diff --git a/.gitignore b/.gitignore index fd33cb142a28..6029cda7a93d 100644 --- a/.gitignore +++ b/.gitignore @@ -73,3 +73,7 @@ nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativebla # Ignore meld temp files *.orig + +# ignore some specific folders related to CLion +libnd4j/tests_cpu/libnd4j_tests/cmake*/ +libnd4j/tests_cpu/libnd4j_tests/cmake-build-debug-kraken/ \ No newline at end of file diff --git a/libnd4j/blas/CMakeLists.txt b/libnd4j/blas/CMakeLists.txt index fb1dc066e606..e67d29ae1e37 100755 --- a/libnd4j/blas/CMakeLists.txt +++ b/libnd4j/blas/CMakeLists.txt @@ -108,10 +108,10 @@ ENDIF() if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "AppleClang" AND SD_X86_BUILD) # apple clang but not ios-arm - SET( CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${ARCH_TUNE}") + SET( CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${ARCH_TUNE} -Wno-defaulted-function-deleted") elseif ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "Clang") # using Clang - SET( CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${ARCH_TUNE}") + SET( CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${ARCH_TUNE} -Wno-defaulted-function-deleted") elseif ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "Intel") # using Intel C++ SET( CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${ARCH_TUNE} -O3 -fp-model fast") @@ -230,9 +230,9 @@ if(SD_CUDA) file(GLOB_RECURSE COMPILATION_UNITS false ../include/loops/cuda/compilation_units/*.cu.in ../include/ops/impl/compilation_units/*.cpp.in) - foreach(FL_ITEM ${COMPILATION_UNITS}) + foreach(FL_ITEM ${COMPILATION_UNITS}) genCompilation(FL_ITEM) - endforeach() + endforeach() if (HAVE_CUDNN) message("cuDNN included") @@ -305,11 +305,11 @@ elseif(SD_CPU) file(GLOB_RECURSE LOOPS_SOURCES false ../include/loops/*.cpp ../include/loops/*.h) - file(GLOB_RECURSE COMPILATION_UNITS false ../include/ops/declarable/helpers/cpu/compilation_units/*.cpp.in + file(GLOB_RECURSE COMPILATION_UNITS false ../include/ops/declarable/helpers/cpu/compilation_units/*.cpp.in ../include/loops/cpu/compilation_units/*.cpp.in ../include/helpers/cpu/loops/*.cpp.in ../include/ops/impl/compilation_units/*.cpp.in) - foreach(FL_ITEM ${COMPILATION_UNITS}) + foreach(FL_ITEM ${COMPILATION_UNITS}) genCompilation(FL_ITEM) endforeach() diff --git a/libnd4j/include/array/ArrayOptions.h b/libnd4j/include/array/ArrayOptions.h index 1f0c2570503a..2b9b98d4cec3 100644 --- a/libnd4j/include/array/ArrayOptions.h +++ b/libnd4j/include/array/ArrayOptions.h @@ -21,22 +21,21 @@ #ifndef ND4J_ARRAY_OPTIONS_H #define ND4J_ARRAY_OPTIONS_H -#include -#include -#include -#include #include +#include #include #include -#include +#include +#include +#include +#include #define ARRAY_SPARSE 2 #define ARRAY_COMPRESSED 4 #define ARRAY_EMPTY 8 #define ARRAY_RAGGED 16 - #define ARRAY_CSR 32 #define ARRAY_CSC 64 #define ARRAY_COO 128 @@ -79,295 +78,304 @@ #define ARRAY_UTF16 4194304 #define ARRAY_UTF32 16777216 -// flag for extras +// flag for extras #define ARRAY_EXTRAS 2097152 - // flag for signed/unsigned integers #define ARRAY_UNSIGNED 8388608 - namespace sd { - class ND4J_EXPORT ArrayOptions { - - private: - static FORCEINLINE _CUDA_HD Nd4jLong& extra(Nd4jLong* shape); +class SD_EXPORT ArrayOptions { + private: + static FORCEINLINE _CUDA_HD Nd4jLong &extra(Nd4jLong *shape); - public: - static FORCEINLINE _CUDA_HD bool isNewFormat(const Nd4jLong *shapeInfo); - static FORCEINLINE _CUDA_HD bool hasPropertyBitSet(const Nd4jLong *shapeInfo, int property); - static FORCEINLINE _CUDA_HD bool togglePropertyBit(Nd4jLong *shapeInfo, int property); - static FORCEINLINE _CUDA_HD void unsetPropertyBit(Nd4jLong *shapeInfo, int property); - static FORCEINLINE _CUDA_HD void setPropertyBit(Nd4jLong *shapeInfo, int property); - static FORCEINLINE _CUDA_HD void setPropertyBits(Nd4jLong *shapeInfo, std::initializer_list properties); + public: + static FORCEINLINE _CUDA_HD bool isNewFormat(const Nd4jLong *shapeInfo); + static FORCEINLINE _CUDA_HD bool hasPropertyBitSet(const Nd4jLong *shapeInfo, + int property); + static FORCEINLINE _CUDA_HD bool togglePropertyBit(Nd4jLong *shapeInfo, + int property); + static FORCEINLINE _CUDA_HD void unsetPropertyBit(Nd4jLong *shapeInfo, + int property); + static FORCEINLINE _CUDA_HD void setPropertyBit(Nd4jLong *shapeInfo, + int property); + static FORCEINLINE _CUDA_HD void setPropertyBits( + Nd4jLong *shapeInfo, std::initializer_list properties); - static FORCEINLINE _CUDA_HD bool isSparseArray(Nd4jLong *shapeInfo); - static FORCEINLINE _CUDA_HD bool isUnsigned(Nd4jLong *shapeInfo); + static FORCEINLINE _CUDA_HD bool isSparseArray(Nd4jLong *shapeInfo); + static FORCEINLINE _CUDA_HD bool isUnsigned(Nd4jLong *shapeInfo); - static FORCEINLINE _CUDA_HD sd::DataType dataType(const Nd4jLong *shapeInfo); + static FORCEINLINE _CUDA_HD sd::DataType dataType(const Nd4jLong *shapeInfo); - static FORCEINLINE _CUDA_HD SpaceType spaceType(Nd4jLong *shapeInfo); - static FORCEINLINE _CUDA_HD SpaceType spaceType(const Nd4jLong *shapeInfo); + static FORCEINLINE _CUDA_HD SpaceType spaceType(Nd4jLong *shapeInfo); + static FORCEINLINE _CUDA_HD SpaceType spaceType(const Nd4jLong *shapeInfo); - static FORCEINLINE _CUDA_HD ArrayType arrayType(Nd4jLong *shapeInfo); - static FORCEINLINE _CUDA_HD ArrayType arrayType(const Nd4jLong *shapeInfo); + static FORCEINLINE _CUDA_HD ArrayType arrayType(Nd4jLong *shapeInfo); + static FORCEINLINE _CUDA_HD ArrayType arrayType(const Nd4jLong *shapeInfo); - static FORCEINLINE _CUDA_HD SparseType sparseType(Nd4jLong *shapeInfo); - static FORCEINLINE _CUDA_HD SparseType sparseType(const Nd4jLong *shapeInfo); + static FORCEINLINE _CUDA_HD SparseType sparseType(Nd4jLong *shapeInfo); + static FORCEINLINE _CUDA_HD SparseType sparseType(const Nd4jLong *shapeInfo); - static FORCEINLINE _CUDA_HD bool hasExtraProperties(Nd4jLong *shapeInfo); + static FORCEINLINE _CUDA_HD bool hasExtraProperties(Nd4jLong *shapeInfo); + static FORCEINLINE _CUDA_HD void resetDataType(Nd4jLong *shapeInfo); + static FORCEINLINE _CUDA_HD void setDataType(Nd4jLong *shapeInfo, + const sd::DataType dataType); - static FORCEINLINE _CUDA_HD void resetDataType(Nd4jLong *shapeInfo); - static FORCEINLINE _CUDA_HD void setDataType(Nd4jLong *shapeInfo, const sd::DataType dataType); + static FORCEINLINE _CUDA_HD void copyDataType(Nd4jLong *to, + const Nd4jLong *from); +}; - static FORCEINLINE _CUDA_HD void copyDataType(Nd4jLong* to, const Nd4jLong* from); - }; - - FORCEINLINE _CUDA_HD Nd4jLong& ArrayOptions::extra(Nd4jLong* shape) { - return shape[shape[0] + shape[0] + 1]; - } - - FORCEINLINE _CUDA_HD bool ArrayOptions::isNewFormat(const Nd4jLong *shapeInfo) { - return (extra(const_cast(shapeInfo)) != 0); - } +FORCEINLINE _CUDA_HD Nd4jLong &ArrayOptions::extra(Nd4jLong *shape) { + return shape[shape[0] + shape[0] + 1]; +} +FORCEINLINE _CUDA_HD bool ArrayOptions::isNewFormat(const Nd4jLong *shapeInfo) { + return (extra(const_cast(shapeInfo)) != 0); +} - FORCEINLINE _CUDA_HD bool ArrayOptions::isSparseArray(Nd4jLong *shapeInfo) { - return hasPropertyBitSet(shapeInfo, ARRAY_SPARSE); - } +FORCEINLINE _CUDA_HD bool ArrayOptions::isSparseArray(Nd4jLong *shapeInfo) { + return hasPropertyBitSet(shapeInfo, ARRAY_SPARSE); +} - FORCEINLINE _CUDA_HD bool ArrayOptions::hasExtraProperties(Nd4jLong *shapeInfo) { - return hasPropertyBitSet(shapeInfo, ARRAY_EXTRAS); - } +FORCEINLINE _CUDA_HD bool ArrayOptions::hasExtraProperties( + Nd4jLong *shapeInfo) { + return hasPropertyBitSet(shapeInfo, ARRAY_EXTRAS); +} - FORCEINLINE _CUDA_HD bool ArrayOptions::hasPropertyBitSet(const Nd4jLong *shapeInfo, int property) { - if (!isNewFormat(shapeInfo)) - return false; +FORCEINLINE _CUDA_HD bool ArrayOptions::hasPropertyBitSet( + const Nd4jLong *shapeInfo, int property) { + if (!isNewFormat(shapeInfo)) return false; - return ((extra(const_cast(shapeInfo)) & property) == property); - } + return ((extra(const_cast(shapeInfo)) & property) == property); +} - FORCEINLINE _CUDA_HD bool ArrayOptions::isUnsigned(Nd4jLong *shapeInfo) { - if (!isNewFormat(shapeInfo)) - return false; +FORCEINLINE _CUDA_HD bool ArrayOptions::isUnsigned(Nd4jLong *shapeInfo) { + if (!isNewFormat(shapeInfo)) return false; - return hasPropertyBitSet(shapeInfo, ARRAY_UNSIGNED); - } + return hasPropertyBitSet(shapeInfo, ARRAY_UNSIGNED); +} - FORCEINLINE _CUDA_HD sd::DataType ArrayOptions::dataType(const Nd4jLong *shapeInfo) { - /*if (hasPropertyBitSet(shapeInfo, ARRAY_QUANTIZED)) - return sd::DataType::QINT8; - else */if (hasPropertyBitSet(shapeInfo, ARRAY_FLOAT)) - return sd::DataType::FLOAT32; - else if (hasPropertyBitSet(shapeInfo, ARRAY_DOUBLE)) - return sd::DataType::DOUBLE; - else if (hasPropertyBitSet(shapeInfo, ARRAY_HALF)) - return sd::DataType::HALF; - else if (hasPropertyBitSet(shapeInfo, ARRAY_BHALF)) - return sd::DataType::BFLOAT16; - else if (hasPropertyBitSet(shapeInfo, ARRAY_BOOL)) - return sd::DataType ::BOOL; - else if (hasPropertyBitSet(shapeInfo, ARRAY_UNSIGNED)) { - if (hasPropertyBitSet(shapeInfo, ARRAY_CHAR)) - return sd::DataType ::UINT8; - else if (hasPropertyBitSet(shapeInfo, ARRAY_SHORT)) - return sd::DataType ::UINT16; - else if (hasPropertyBitSet(shapeInfo, ARRAY_INT)) - return sd::DataType ::UINT32; - else if (hasPropertyBitSet(shapeInfo, ARRAY_LONG)) - return sd::DataType ::UINT64; - else if (hasPropertyBitSet(shapeInfo, ARRAY_UTF8)) - return sd::DataType ::UTF8; - else if (hasPropertyBitSet(shapeInfo, ARRAY_UTF16)) - return sd::DataType ::UTF16; - else if (hasPropertyBitSet(shapeInfo, ARRAY_UTF32)) - return sd::DataType ::UTF32; - else { - //shape::printShapeInfoLinear("Bad unsigned datatype (not)stored in shape", const_cast(shapeInfo)); +FORCEINLINE _CUDA_HD sd::DataType ArrayOptions::dataType( + const Nd4jLong *shapeInfo) { + /*if (hasPropertyBitSet(shapeInfo, ARRAY_QUANTIZED)) + return sd::DataType::QINT8; + else */ + if (hasPropertyBitSet(shapeInfo, ARRAY_FLOAT)) + return sd::DataType::FLOAT32; + else if (hasPropertyBitSet(shapeInfo, ARRAY_DOUBLE)) + return sd::DataType::DOUBLE; + else if (hasPropertyBitSet(shapeInfo, ARRAY_HALF)) + return sd::DataType::HALF; + else if (hasPropertyBitSet(shapeInfo, ARRAY_BHALF)) + return sd::DataType::BFLOAT16; + else if (hasPropertyBitSet(shapeInfo, ARRAY_BOOL)) + return sd::DataType ::BOOL; + else if (hasPropertyBitSet(shapeInfo, ARRAY_UNSIGNED)) { + if (hasPropertyBitSet(shapeInfo, ARRAY_CHAR)) + return sd::DataType ::UINT8; + else if (hasPropertyBitSet(shapeInfo, ARRAY_SHORT)) + return sd::DataType ::UINT16; + else if (hasPropertyBitSet(shapeInfo, ARRAY_INT)) + return sd::DataType ::UINT32; + else if (hasPropertyBitSet(shapeInfo, ARRAY_LONG)) + return sd::DataType ::UINT64; + else if (hasPropertyBitSet(shapeInfo, ARRAY_UTF8)) + return sd::DataType ::UTF8; + else if (hasPropertyBitSet(shapeInfo, ARRAY_UTF16)) + return sd::DataType ::UTF16; + else if (hasPropertyBitSet(shapeInfo, ARRAY_UTF32)) + return sd::DataType ::UTF32; + else { + // shape::printShapeInfoLinear("Bad unsigned datatype (not)stored in + // shape", const_cast(shapeInfo)); #ifndef __CUDA_ARCH__ - throw std::runtime_error("Bad datatype A"); + throw std::runtime_error("Bad datatype A"); #endif - } - } - else if (hasPropertyBitSet(shapeInfo, ARRAY_CHAR)) - return sd::DataType::INT8; - else if (hasPropertyBitSet(shapeInfo, ARRAY_SHORT)) - return sd::DataType::INT16; - else if (hasPropertyBitSet(shapeInfo, ARRAY_INT)) - return sd::DataType::INT32; - else if (hasPropertyBitSet(shapeInfo, ARRAY_LONG)) - return sd::DataType::INT64; - else if (hasPropertyBitSet(shapeInfo, ARRAY_UTF8)) - return sd::DataType::UTF8; - else if (hasPropertyBitSet(shapeInfo, ARRAY_UTF16)) - return sd::DataType::UTF16; - else if (hasPropertyBitSet(shapeInfo, ARRAY_UTF32)) - return sd::DataType::UTF32; - else { - //shape::printShapeInfoLinear("Bad signed datatype (not)stored in shape", const_cast(shapeInfo)); + } + } else if (hasPropertyBitSet(shapeInfo, ARRAY_CHAR)) + return sd::DataType::INT8; + else if (hasPropertyBitSet(shapeInfo, ARRAY_SHORT)) + return sd::DataType::INT16; + else if (hasPropertyBitSet(shapeInfo, ARRAY_INT)) + return sd::DataType::INT32; + else if (hasPropertyBitSet(shapeInfo, ARRAY_LONG)) + return sd::DataType::INT64; + else if (hasPropertyBitSet(shapeInfo, ARRAY_UTF8)) + return sd::DataType::UTF8; + else if (hasPropertyBitSet(shapeInfo, ARRAY_UTF16)) + return sd::DataType::UTF16; + else if (hasPropertyBitSet(shapeInfo, ARRAY_UTF32)) + return sd::DataType::UTF32; + else { + // shape::printShapeInfoLinear("Bad signed datatype (not)stored in shape", + // const_cast(shapeInfo)); #ifndef __CUDA_ARCH__ - throw std::runtime_error("Bad datatype B"); + throw std::runtime_error("Bad datatype B"); #endif - } - } + } +} - FORCEINLINE _CUDA_HD SpaceType ArrayOptions::spaceType(const Nd4jLong *shapeInfo) { - return spaceType(const_cast(shapeInfo)); - } +FORCEINLINE _CUDA_HD SpaceType +ArrayOptions::spaceType(const Nd4jLong *shapeInfo) { + return spaceType(const_cast(shapeInfo)); +} - FORCEINLINE _CUDA_HD SpaceType ArrayOptions::spaceType(Nd4jLong *shapeInfo) { - if (hasPropertyBitSet(shapeInfo, ARRAY_QUANTIZED)) - return SpaceType::QUANTIZED; - if (hasPropertyBitSet(shapeInfo, ARRAY_COMPLEX)) - return SpaceType::COMPLEX; - else // by default we return continuous type here - return SpaceType::CONTINUOUS; - } +FORCEINLINE _CUDA_HD SpaceType ArrayOptions::spaceType(Nd4jLong *shapeInfo) { + if (hasPropertyBitSet(shapeInfo, ARRAY_QUANTIZED)) + return SpaceType::QUANTIZED; + if (hasPropertyBitSet(shapeInfo, ARRAY_COMPLEX)) + return SpaceType::COMPLEX; + else // by default we return continuous type here + return SpaceType::CONTINUOUS; +} - FORCEINLINE _CUDA_HD ArrayType ArrayOptions::arrayType(const Nd4jLong *shapeInfo) { - return arrayType(const_cast(shapeInfo)); - } +FORCEINLINE _CUDA_HD ArrayType +ArrayOptions::arrayType(const Nd4jLong *shapeInfo) { + return arrayType(const_cast(shapeInfo)); +} - FORCEINLINE _CUDA_HD ArrayType ArrayOptions::arrayType(Nd4jLong *shapeInfo) { - if (hasPropertyBitSet(shapeInfo, ARRAY_SPARSE)) - return ArrayType::SPARSE; - else if (hasPropertyBitSet(shapeInfo, ARRAY_COMPRESSED)) - return ArrayType::COMPRESSED; - else if (hasPropertyBitSet(shapeInfo, ARRAY_EMPTY)) - return ArrayType::EMPTY; - else if (hasPropertyBitSet(shapeInfo, ARRAY_RAGGED)) - return ArrayType::RAGGED; - else // by default we return DENSE type here - return ArrayType::DENSE; - } +FORCEINLINE _CUDA_HD ArrayType ArrayOptions::arrayType(Nd4jLong *shapeInfo) { + if (hasPropertyBitSet(shapeInfo, ARRAY_SPARSE)) + return ArrayType::SPARSE; + else if (hasPropertyBitSet(shapeInfo, ARRAY_COMPRESSED)) + return ArrayType::COMPRESSED; + else if (hasPropertyBitSet(shapeInfo, ARRAY_EMPTY)) + return ArrayType::EMPTY; + else if (hasPropertyBitSet(shapeInfo, ARRAY_RAGGED)) + return ArrayType::RAGGED; + else // by default we return DENSE type here + return ArrayType::DENSE; +} - FORCEINLINE _CUDA_HD bool ArrayOptions::togglePropertyBit(Nd4jLong *shapeInfo, int property) { - extra(shapeInfo) ^= property; +FORCEINLINE _CUDA_HD bool ArrayOptions::togglePropertyBit(Nd4jLong *shapeInfo, + int property) { + extra(shapeInfo) ^= property; - return hasPropertyBitSet(shapeInfo, property); - } + return hasPropertyBitSet(shapeInfo, property); +} - FORCEINLINE _CUDA_HD void ArrayOptions::setPropertyBit(Nd4jLong *shapeInfo, int property) { - extra(shapeInfo) |= property; - } +FORCEINLINE _CUDA_HD void ArrayOptions::setPropertyBit(Nd4jLong *shapeInfo, + int property) { + extra(shapeInfo) |= property; +} - FORCEINLINE _CUDA_HD void ArrayOptions::unsetPropertyBit(Nd4jLong *shapeInfo, int property) { - extra(shapeInfo) &= property; - } +FORCEINLINE _CUDA_HD void ArrayOptions::unsetPropertyBit(Nd4jLong *shapeInfo, + int property) { + extra(shapeInfo) &= property; +} - FORCEINLINE _CUDA_HD SparseType ArrayOptions::sparseType(const Nd4jLong *shapeInfo) { - return sparseType(const_cast(shapeInfo)); - } +FORCEINLINE _CUDA_HD SparseType +ArrayOptions::sparseType(const Nd4jLong *shapeInfo) { + return sparseType(const_cast(shapeInfo)); +} - FORCEINLINE _CUDA_HD SparseType ArrayOptions::sparseType(Nd4jLong *shapeInfo) { +FORCEINLINE _CUDA_HD SparseType ArrayOptions::sparseType(Nd4jLong *shapeInfo) { #ifndef __CUDA_ARCH__ - if (!isSparseArray(shapeInfo)) - throw std::runtime_error("Not a sparse array"); + if (!isSparseArray(shapeInfo)) throw std::runtime_error("Not a sparse array"); #endif - if (hasPropertyBitSet(shapeInfo, ARRAY_CSC)) - return SparseType::CSC; - else if (hasPropertyBitSet(shapeInfo, ARRAY_CSR)) - return SparseType::CSR; - else if (hasPropertyBitSet(shapeInfo, ARRAY_COO)) - return SparseType::COO; - else - return SparseType::LIL; - } + if (hasPropertyBitSet(shapeInfo, ARRAY_CSC)) + return SparseType::CSC; + else if (hasPropertyBitSet(shapeInfo, ARRAY_CSR)) + return SparseType::CSR; + else if (hasPropertyBitSet(shapeInfo, ARRAY_COO)) + return SparseType::COO; + else + return SparseType::LIL; +} - FORCEINLINE _CUDA_HD void ArrayOptions::setPropertyBits(Nd4jLong *shapeInfo, std::initializer_list properties) { - for (auto v: properties) { - if (!hasPropertyBitSet(shapeInfo, v)) - setPropertyBit(shapeInfo, v); - } - } +FORCEINLINE _CUDA_HD void ArrayOptions::setPropertyBits( + Nd4jLong *shapeInfo, std::initializer_list properties) { + for (auto v : properties) { + if (!hasPropertyBitSet(shapeInfo, v)) setPropertyBit(shapeInfo, v); + } +} - FORCEINLINE _CUDA_HD void ArrayOptions::resetDataType(Nd4jLong *shapeInfo) { - unsetPropertyBit(shapeInfo, ARRAY_BOOL); - unsetPropertyBit(shapeInfo, ARRAY_HALF); - unsetPropertyBit(shapeInfo, ARRAY_BHALF); - unsetPropertyBit(shapeInfo, ARRAY_FLOAT); - unsetPropertyBit(shapeInfo, ARRAY_DOUBLE); - unsetPropertyBit(shapeInfo, ARRAY_INT); - unsetPropertyBit(shapeInfo, ARRAY_LONG); - unsetPropertyBit(shapeInfo, ARRAY_CHAR); - unsetPropertyBit(shapeInfo, ARRAY_SHORT); - unsetPropertyBit(shapeInfo, ARRAY_UNSIGNED); - } +FORCEINLINE _CUDA_HD void ArrayOptions::resetDataType(Nd4jLong *shapeInfo) { + unsetPropertyBit(shapeInfo, ARRAY_BOOL); + unsetPropertyBit(shapeInfo, ARRAY_HALF); + unsetPropertyBit(shapeInfo, ARRAY_BHALF); + unsetPropertyBit(shapeInfo, ARRAY_FLOAT); + unsetPropertyBit(shapeInfo, ARRAY_DOUBLE); + unsetPropertyBit(shapeInfo, ARRAY_INT); + unsetPropertyBit(shapeInfo, ARRAY_LONG); + unsetPropertyBit(shapeInfo, ARRAY_CHAR); + unsetPropertyBit(shapeInfo, ARRAY_SHORT); + unsetPropertyBit(shapeInfo, ARRAY_UNSIGNED); +} - FORCEINLINE _CUDA_HD void ArrayOptions::setDataType(Nd4jLong *shapeInfo, const sd::DataType dataType) { - resetDataType(shapeInfo); - if (dataType == sd::DataType::UINT8 || - dataType == sd::DataType::UINT16 || - dataType == sd::DataType::UINT32 || - dataType == sd::DataType::UINT64) { - - setPropertyBit(shapeInfo, ARRAY_UNSIGNED); - } - - switch (dataType) { - case sd::DataType::BOOL: - setPropertyBit(shapeInfo, ARRAY_BOOL); - break; - case sd::DataType::HALF: - setPropertyBit(shapeInfo, ARRAY_HALF); - break; - case sd::DataType::BFLOAT16: - setPropertyBit(shapeInfo, ARRAY_BHALF); - break; - case sd::DataType::FLOAT32: - setPropertyBit(shapeInfo, ARRAY_FLOAT); - break; - case sd::DataType::DOUBLE: - setPropertyBit(shapeInfo, ARRAY_DOUBLE); - break; - case sd::DataType::INT8: - setPropertyBit(shapeInfo, ARRAY_CHAR); - break; - case sd::DataType::INT16: - setPropertyBit(shapeInfo, ARRAY_SHORT); - break; - case sd::DataType::INT32: - setPropertyBit(shapeInfo, ARRAY_INT); - break; - case sd::DataType::INT64: - setPropertyBit(shapeInfo, ARRAY_LONG); - break; - case sd::DataType::UINT8: - setPropertyBit(shapeInfo, ARRAY_CHAR); - break; - case sd::DataType::UINT16: - setPropertyBit(shapeInfo, ARRAY_SHORT); - break; - case sd::DataType::UINT32: - setPropertyBit(shapeInfo, ARRAY_INT); - break; - case sd::DataType::UINT64: - setPropertyBit(shapeInfo, ARRAY_LONG); - break; - case sd::DataType::UTF8: - setPropertyBit(shapeInfo, ARRAY_UTF8); - break; - case sd::DataType::UTF16: - setPropertyBit(shapeInfo, ARRAY_UTF16); - break; - case sd::DataType::UTF32: - setPropertyBit(shapeInfo, ARRAY_UTF32); - break; - default: +FORCEINLINE _CUDA_HD void ArrayOptions::setDataType( + Nd4jLong *shapeInfo, const sd::DataType dataType) { + resetDataType(shapeInfo); + if (dataType == sd::DataType::UINT8 || dataType == sd::DataType::UINT16 || + dataType == sd::DataType::UINT32 || dataType == sd::DataType::UINT64) { + setPropertyBit(shapeInfo, ARRAY_UNSIGNED); + } + + switch (dataType) { + case sd::DataType::BOOL: + setPropertyBit(shapeInfo, ARRAY_BOOL); + break; + case sd::DataType::HALF: + setPropertyBit(shapeInfo, ARRAY_HALF); + break; + case sd::DataType::BFLOAT16: + setPropertyBit(shapeInfo, ARRAY_BHALF); + break; + case sd::DataType::FLOAT32: + setPropertyBit(shapeInfo, ARRAY_FLOAT); + break; + case sd::DataType::DOUBLE: + setPropertyBit(shapeInfo, ARRAY_DOUBLE); + break; + case sd::DataType::INT8: + setPropertyBit(shapeInfo, ARRAY_CHAR); + break; + case sd::DataType::INT16: + setPropertyBit(shapeInfo, ARRAY_SHORT); + break; + case sd::DataType::INT32: + setPropertyBit(shapeInfo, ARRAY_INT); + break; + case sd::DataType::INT64: + setPropertyBit(shapeInfo, ARRAY_LONG); + break; + case sd::DataType::UINT8: + setPropertyBit(shapeInfo, ARRAY_CHAR); + break; + case sd::DataType::UINT16: + setPropertyBit(shapeInfo, ARRAY_SHORT); + break; + case sd::DataType::UINT32: + setPropertyBit(shapeInfo, ARRAY_INT); + break; + case sd::DataType::UINT64: + setPropertyBit(shapeInfo, ARRAY_LONG); + break; + case sd::DataType::UTF8: + setPropertyBit(shapeInfo, ARRAY_UTF8); + break; + case sd::DataType::UTF16: + setPropertyBit(shapeInfo, ARRAY_UTF16); + break; + case sd::DataType::UTF32: + setPropertyBit(shapeInfo, ARRAY_UTF32); + break; + default: #ifndef __CUDA_ARCH__ - throw std::runtime_error("Can't set unknown data type"); + throw std::runtime_error("Can't set unknown data type"); #else - printf("Can't set unknown data type"); + printf("Can't set unknown data type"); #endif - } - } + } +} //////////////////////////////////////////////////////////////////////////////// - FORCEINLINE _CUDA_HD void ArrayOptions::copyDataType(Nd4jLong* to, const Nd4jLong* from) { - setDataType(to, dataType(from)); - } +FORCEINLINE _CUDA_HD void ArrayOptions::copyDataType(Nd4jLong *to, + const Nd4jLong *from) { + setDataType(to, dataType(from)); } +} // namespace sd -#endif // ND4J_ARRAY_OPTIONS_H :) \ No newline at end of file +#endif // ND4J_ARRAY_OPTIONS_H :) \ No newline at end of file diff --git a/libnd4j/include/array/ArrayType.h b/libnd4j/include/array/ArrayType.h index 83e80bc0f068..c749c278f2ac 100644 --- a/libnd4j/include/array/ArrayType.h +++ b/libnd4j/include/array/ArrayType.h @@ -22,13 +22,13 @@ #define ND4J_ARRAY_TYPE_H namespace sd { - enum ArrayType { - DENSE = 1, - SPARSE = 2, - COMPRESSED = 3, - EMPTY = 4, - RAGGED = 5, - }; +enum ArrayType { + DENSE = 1, + SPARSE = 2, + COMPRESSED = 3, + EMPTY = 4, + RAGGED = 5, +}; } #endif \ No newline at end of file diff --git a/libnd4j/include/array/ByteOrder.h b/libnd4j/include/array/ByteOrder.h index 121be9e9dcb9..88c20b693ad4 100644 --- a/libnd4j/include/array/ByteOrder.h +++ b/libnd4j/include/array/ByteOrder.h @@ -22,10 +22,10 @@ #define LIBND4J_BYTEORDER_H namespace sd { - enum ByteOrder { - LE = 0, - BE = 1, - }; +enum ByteOrder { + LE = 0, + BE = 1, +}; } -#endif //LIBND4J_BYTEORDER_H +#endif // LIBND4J_BYTEORDER_H diff --git a/libnd4j/include/array/ByteOrderUtils.h b/libnd4j/include/array/ByteOrderUtils.h index 0f335ea65f93..f34b9c1698b4 100644 --- a/libnd4j/include/array/ByteOrderUtils.h +++ b/libnd4j/include/array/ByteOrderUtils.h @@ -22,15 +22,15 @@ #define LIBND4J_BYTEORDERUTILS_H #include -#include "ByteOrder.h" #include -namespace sd { - class ND4J_EXPORT ByteOrderUtils { - public: - static ByteOrder fromFlatByteOrder(sd::graph::ByteOrder order); - }; -} +#include "ByteOrder.h" +namespace sd { +class SD_EXPORT ByteOrderUtils { + public: + static ByteOrder fromFlatByteOrder(sd::graph::ByteOrder order); +}; +} // namespace sd -#endif //LIBND4J_BYTEORDERUTILS_H +#endif // LIBND4J_BYTEORDERUTILS_H diff --git a/libnd4j/include/array/ConstantDataBuffer.h b/libnd4j/include/array/ConstantDataBuffer.h index 197b93307600..42a0c1cf00ca 100644 --- a/libnd4j/include/array/ConstantDataBuffer.h +++ b/libnd4j/include/array/ConstantDataBuffer.h @@ -26,37 +26,37 @@ #include #include - namespace sd { - class ND4J_EXPORT ConstantDataBuffer { - private: - std::shared_ptr _primaryBuffer; - std::shared_ptr _specialBuffer = nullptr; - uint64_t _length = 0; - uint8_t _sizeOf = 0; - - public: - ConstantDataBuffer(const std::shared_ptr& primary, uint64_t numEelements, DataType dype); - ConstantDataBuffer(const std::shared_ptr& primary, const std::shared_ptr& special, uint64_t numEelements, DataType dype); - ConstantDataBuffer(const ConstantDataBuffer &other); - ConstantDataBuffer() = default; - ~ConstantDataBuffer() = default; - - uint8_t sizeOf() const; - uint64_t length() const; - - void* primary() const; - void* special() const; - - ConstantDataBuffer& operator=(const ConstantDataBuffer& other) = default; - ConstantDataBuffer& operator=(ConstantDataBuffer&& other) noexcept = default; - - template - T* primaryAsT() const; - - template - T* specialAsT() const; - }; -} - -#endif //DEV_TESTS_CONSTANTDATABUFFER_H +class SD_EXPORT ConstantDataBuffer { + private: + std::shared_ptr _primaryBuffer; + std::shared_ptr _specialBuffer = nullptr; + uint64_t _length = 0; + uint8_t _sizeOf = 0; + + public: + ConstantDataBuffer(const std::shared_ptr& primary, uint64_t numEelements, DataType dype); + ConstantDataBuffer(const std::shared_ptr& primary, const std::shared_ptr& special, + uint64_t numEelements, DataType dype); + ConstantDataBuffer(const ConstantDataBuffer& other); + ConstantDataBuffer() = default; + ~ConstantDataBuffer() = default; + + uint8_t sizeOf() const; + uint64_t length() const; + + void* primary() const; + void* special() const; + + ConstantDataBuffer& operator=(const ConstantDataBuffer& other) = default; + ConstantDataBuffer& operator=(ConstantDataBuffer&& other) noexcept = default; + + template + T* primaryAsT()const; + + template + T* specialAsT() const; +}; +} // namespace sd + +#endif // SD_CONSTANTDATABUFFER_H diff --git a/libnd4j/include/array/ConstantDescriptor.h b/libnd4j/include/array/ConstantDescriptor.h index 89e36c2a99e7..a377a001fbbd 100644 --- a/libnd4j/include/array/ConstantDescriptor.h +++ b/libnd4j/include/array/ConstantDescriptor.h @@ -18,58 +18,59 @@ // @author raver119@gmail.com // -#ifndef DEV_TESTS_CONSTANTDESCRIPTOR_H -#define DEV_TESTS_CONSTANTDESCRIPTOR_H +#ifndef SD_CONSTANTDESCRIPTOR_H +#define SD_CONSTANTDESCRIPTOR_H +#include #include +#include +#include + #include #include -#include -#include -#include namespace sd { - class ND4J_EXPORT ConstantDescriptor { - private: - std::vector _integerValues; - std::vector _floatValues; - public: - ConstantDescriptor(double* values, int length); - ConstantDescriptor(Nd4jLong const* values, int length); - ConstantDescriptor(std::initializer_list values); +class SD_EXPORT ConstantDescriptor { + private: + std::vector _integerValues; + std::vector _floatValues; - explicit ConstantDescriptor(std::vector &values); - explicit ConstantDescriptor(std::vector &values); + public: + ConstantDescriptor(double *values, int length); + ConstantDescriptor(Nd4jLong const *values, int length); + ConstantDescriptor(std::initializer_list values); - ~ConstantDescriptor() = default; + explicit ConstantDescriptor(std::vector &values); + explicit ConstantDescriptor(std::vector &values); - // equal to operator - bool operator==(const ConstantDescriptor &other) const; + ~ConstantDescriptor() = default; - // less than operator - bool operator<(const ConstantDescriptor &other) const; + // equal to operator + bool operator==(const ConstantDescriptor &other) const; - bool isInteger() const; - bool isFloat() const; + // less than operator + bool operator<(const ConstantDescriptor &other) const; - Nd4jLong length() const; + bool isInteger() const; + bool isFloat() const; - const std::vector& integerValues() const; - const std::vector& floatValues() const; - }; -} + Nd4jLong length() const; + + const std::vector &integerValues() const; + const std::vector &floatValues() const; +}; +} // namespace sd #ifndef __JAVACPP_HACK__ namespace std { - template<> - class ND4J_EXPORT hash { - public: - size_t operator()(const sd::ConstantDescriptor &k) const; - }; -} +template <> +class SD_EXPORT hash { + public: + size_t operator()(const sd::ConstantDescriptor &k) const; +}; +} // namespace std #endif - -#endif //DEV_TESTS_CONSTANTDESCRIPTOR_H +#endif // SD_CONSTANTDESCRIPTOR_H diff --git a/libnd4j/include/array/ConstantHolder.h b/libnd4j/include/array/ConstantHolder.h index a404e580843f..1006692b6e57 100644 --- a/libnd4j/include/array/ConstantHolder.h +++ b/libnd4j/include/array/ConstantHolder.h @@ -22,44 +22,45 @@ #ifndef LIBND4J_CONSTANTHOLDER_H #define LIBND4J_CONSTANTHOLDER_H -#include -#include #include +#include + +#include #include namespace sd { - class ConstantHolder { - private: - int _deviceId = 0; - std::mutex _mutex; +class ConstantHolder { + private: + int _deviceId = 0; + std::mutex _mutex; - std::map _buffers; - public: - ConstantHolder(const ConstantHolder& other); - ConstantHolder() = default; - ~ConstantHolder() = default; + std::map _buffers; - ConstantHolder& operator=(const ConstantHolder& other) = default; - ConstantHolder& operator=(ConstantHolder&& other) = default; + public: + ConstantHolder(const ConstantHolder& other); + ConstantHolder() = default; + ~ConstantHolder() = default; - bool hasBuffer(sd::DataType dataType); + ConstantHolder& operator=(const ConstantHolder& other); + ConstantHolder& operator=(ConstantHolder&& other); - template - bool hasBuffer(); + bool hasBuffer(sd::DataType dataType); - void addBuffer(ConstantDataBuffer &pointer, sd::DataType dataType); + template + bool hasBuffer(); - template - void addBuffer(ConstantDataBuffer &pointer); + void addBuffer(ConstantDataBuffer& pointer, sd::DataType dataType); - ConstantDataBuffer* getConstantDataBuffer(sd::DataType dataType); + template + void addBuffer(ConstantDataBuffer& pointer); - template - ConstantDataBuffer* getConstantDataBuffer(); + ConstantDataBuffer* getConstantDataBuffer(sd::DataType dataType); - std::mutex* mutex(); - }; -} + template + ConstantDataBuffer* getConstantDataBuffer(); + std::mutex* mutex(); +}; +} // namespace sd -#endif //DEV_TESTS_CONSTANTHOLDER_H +#endif // SD_CONSTANTHOLDER_H diff --git a/libnd4j/include/array/ConstantOffsetsBuffer.h b/libnd4j/include/array/ConstantOffsetsBuffer.h index 61c1e381f3aa..d99d91711585 100644 --- a/libnd4j/include/array/ConstantOffsetsBuffer.h +++ b/libnd4j/include/array/ConstantOffsetsBuffer.h @@ -28,7 +28,7 @@ namespace sd { -class ND4J_EXPORT ConstantOffsetsBuffer { +class SD_EXPORT ConstantOffsetsBuffer { private: std::shared_ptr _primaryOffsets; std::shared_ptr _specialOffsets; diff --git a/libnd4j/include/array/ConstantShapeBuffer.h b/libnd4j/include/array/ConstantShapeBuffer.h index 2996532710de..bc2118dab7ca 100644 --- a/libnd4j/include/array/ConstantShapeBuffer.h +++ b/libnd4j/include/array/ConstantShapeBuffer.h @@ -28,7 +28,7 @@ namespace sd { -class ND4J_EXPORT ConstantShapeBuffer { +class SD_EXPORT ConstantShapeBuffer { private: std::shared_ptr _primaryShapeInfo; std::shared_ptr _specialShapeInfo; diff --git a/libnd4j/include/array/CudaPointerDeallocator.h b/libnd4j/include/array/CudaPointerDeallocator.h index c5c817aebca2..8fbda6976955 100644 --- a/libnd4j/include/array/CudaPointerDeallocator.h +++ b/libnd4j/include/array/CudaPointerDeallocator.h @@ -26,7 +26,7 @@ #include namespace sd { -class ND4J_EXPORT CudaPointerDeallocator : public PointerDeallocator { +class SD_EXPORT CudaPointerDeallocator : public PointerDeallocator { public: CudaPointerDeallocator() = default; ~CudaPointerDeallocator() = default; diff --git a/libnd4j/include/array/DataBuffer.h b/libnd4j/include/array/DataBuffer.h index 59ffe3045e08..312cf9ccb60a 100644 --- a/libnd4j/include/array/DataBuffer.h +++ b/libnd4j/include/array/DataBuffer.h @@ -19,135 +19,150 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#ifndef DEV_TESTS_DATABUFFER_H -#define DEV_TESTS_DATABUFFER_H +#ifndef SD_DATABUFFER_H +#define SD_DATABUFFER_H -#include -#include -#include -#include #include -#include #include +#include +#include +#include +#include -namespace sd { - -class ND4J_EXPORT DataBuffer { - - private: - - void* _primaryBuffer = nullptr; - void* _specialBuffer = nullptr; - size_t _lenInBytes = 0; - DataType _dataType; - memory::Workspace* _workspace = nullptr; - bool _isOwnerPrimary; - bool _isOwnerSpecial; - std::atomic _deviceId; - - #ifdef __CUDABLAS__ - mutable std::atomic _counter; - mutable std::atomic _writePrimary; - mutable std::atomic _writeSpecial; - mutable std::atomic _readPrimary; - mutable std::atomic _readSpecial; - #endif - - void setCountersToZero(); - void copyCounters(const DataBuffer& other); - void deleteSpecial(); - void deletePrimary(); - void deleteBuffers(); - void setAllocFlags(const bool isOwnerPrimary, const bool isOwnerSpecial = false); - void allocateBuffers(const bool allocBoth = false); - void setSpecial(void* special, const bool isOwnerSpecial); - void copyBufferFromHost(const void* hostBuffer, size_t sizeToCopyinBytes = 0, const Nd4jLong offsetThis = 0, const Nd4jLong offsetHostBuffer = 0); - - - public: - - DataBuffer(void* primary, void* special, - const size_t lenInBytes, const DataType dataType, - const bool isOwnerPrimary = false, const bool isOwnerSpecial = false, - memory::Workspace* workspace = nullptr); - - DataBuffer(void* primary, - const size_t lenInBytes, const DataType dataType, - const bool isOwnerPrimary = false, - memory::Workspace* workspace = nullptr); - - DataBuffer(const void* hostBuffer, // copies data from hostBuffer to own memory buffer - const DataType dataType, const size_t lenInBytes, - memory::Workspace* workspace = nullptr); - - DataBuffer(const size_t lenInBytes, const DataType dataType, memory::Workspace* workspace = nullptr, const bool allocBoth = false); - - DataBuffer(const DataBuffer& other); - DataBuffer(DataBuffer&& other); - explicit DataBuffer(); - ~DataBuffer(); - - DataBuffer& operator=(const DataBuffer& other); - DataBuffer& operator=(DataBuffer&& other) noexcept; - - DataType getDataType(); - void setDataType(DataType dataType); - size_t getLenInBytes() const; - - void* primary(); - void* special(); - - void allocatePrimary(); - void allocateSpecial(); - - void writePrimary() const; - void writeSpecial() const; - void readPrimary() const; - void readSpecial() const; - bool isPrimaryActual() const; - bool isSpecialActual() const; - - void expand(const uint64_t size); - - int deviceId() const; - void setDeviceId(int deviceId); - void migrate(); - - template FORCEINLINE T* primaryAsT(); - template FORCEINLINE T* specialAsT(); - - void syncToPrimary(const LaunchContext* context, const bool forceSync = false); - void syncToSpecial(const bool forceSync = false); - - void setToZeroBuffers(const bool both = false); - - void copyBufferFrom(const DataBuffer& other, size_t sizeToCopyinBytes = 0, const Nd4jLong offsetThis = 0, const Nd4jLong offsetOther = 0); - - static void memcpy(const DataBuffer &dst, const DataBuffer &src); +#include - void setPrimaryBuffer(void *buffer, size_t length); - void setSpecialBuffer(void *buffer, size_t length); +namespace sd { - /** - * This method deletes buffers, if we're owners - */ - void close(); +class SD_EXPORT DataBuffer { + protected: + void* _primaryBuffer = nullptr; + void* _specialBuffer = nullptr; + size_t _lenInBytes = 0; + DataType _dataType; + memory::Workspace* _workspace = nullptr; + bool _isOwnerPrimary; + bool _isOwnerSpecial; + std::atomic _deviceId; + +#ifdef __CUDABLAS__ + mutable std::atomic _counter; + mutable std::atomic _writePrimary; + mutable std::atomic _writeSpecial; + mutable std::atomic _readPrimary; + mutable std::atomic _readSpecial; +#endif + + void setCountersToZero(); + void copyCounters(const DataBuffer& other); + virtual void deleteSpecial(); + virtual void deletePrimary(); + virtual void deleteBuffers(); + void setAllocFlags(const bool isOwnerPrimary, + const bool isOwnerSpecial = false); + void allocateBuffers(const bool allocBoth = false); + void setSpecial(void* special, const bool isOwnerSpecial); + void copyBufferFromHost(const void* hostBuffer, size_t sizeToCopyinBytes = 0, + const Nd4jLong offsetThis = 0, + const Nd4jLong offsetHostBuffer = 0); + + public: + DataBuffer(void* primary, void* special, const size_t lenInBytes, + const DataType dataType, const bool isOwnerPrimary = false, + const bool isOwnerSpecial = false, + memory::Workspace* workspace = nullptr); + + DataBuffer(void* primary, const size_t lenInBytes, const DataType dataType, + const bool isOwnerPrimary = false, + memory::Workspace* workspace = nullptr); + + DataBuffer(const void* hostBuffer, // copies data from hostBuffer to own + // memory buffer + const DataType dataType, const size_t lenInBytes, + memory::Workspace* workspace = nullptr); + + DataBuffer(const size_t lenInBytes, const DataType dataType, + memory::Workspace* workspace = nullptr, + const bool allocBoth = false); + + DataBuffer(const DataBuffer& other); + DataBuffer(DataBuffer&& other); + explicit DataBuffer(); + ~DataBuffer(); + + virtual DataBuffer& operator=(const DataBuffer& other); + virtual DataBuffer& operator=(DataBuffer&& other) noexcept; + + DataType getDataType(); + void setDataType(DataType dataType); + size_t getLenInBytes() const; + + virtual void* primary(); + virtual void* special(); + virtual void* platform(); + + virtual void allocatePrimary(); + virtual void allocateSpecial(); + + void writePrimary() const; + void writeSpecial() const; + void readPrimary() const; + void readSpecial() const; + bool isPrimaryActual() const; + bool isSpecialActual() const; + + virtual void expand(const uint64_t size); + + int deviceId() const; + void setDeviceId(int deviceId); + + virtual void migrate(); + + template + FORCEINLINE T* primaryAsT(); + template + FORCEINLINE T* specialAsT(); + template + FORCEINLINE T* platformAsT(); + + void syncToPrimary(const LaunchContext* context, + const bool forceSync = false); + void syncToSpecial(const bool forceSync = false); + + virtual void setToZeroBuffers(const bool both = false); + + void copyBufferFrom(const DataBuffer& other, size_t sizeToCopyinBytes = 0, + const Nd4jLong offsetThis = 0, + const Nd4jLong offsetOther = 0); + + static void memcpy(const DataBuffer& dst, const DataBuffer& src); + + virtual void setPrimaryBuffer(void* buffer, size_t length); + virtual void setSpecialBuffer(void* buffer, size_t length); + + /** + * This method deletes buffers, if we're owners + */ + virtual void close(); }; ///// IMLEMENTATION OF INLINE METHODS ///// //////////////////////////////////////////////////////////////////////// - template - T* DataBuffer::primaryAsT() { - return reinterpret_cast(_primaryBuffer); - } +template +T* DataBuffer::primaryAsT() { + return reinterpret_cast(primary()); +} //////////////////////////////////////////////////////////////////////// - template - T* DataBuffer::specialAsT() { - return reinterpret_cast(_specialBuffer); - } - +template +T* DataBuffer::specialAsT() { + return reinterpret_cast(special()); } +//////////////////////////////////////////////////////////////////////// +template +T* DataBuffer::platformAsT() { + return reinterpret_cast(platform()); +} +} // namespace sd -#endif //DEV_TESTS_DATABUFFER_H +#endif // SD_DATABUFFER_H diff --git a/libnd4j/include/array/DataType.h b/libnd4j/include/array/DataType.h index cf8baf7d0324..8aa16ffff4ac 100644 --- a/libnd4j/include/array/DataType.h +++ b/libnd4j/include/array/DataType.h @@ -22,31 +22,31 @@ #define ND4J_DATATYPE_H namespace sd { - enum DataType { - INHERIT = 0, - BOOL = 1, - FLOAT8 = 2, - HALF = 3, - HALF2 = 4, - FLOAT32 = 5, - DOUBLE = 6, - INT8 = 7, - INT16 = 8, - INT32 = 9, - INT64 = 10, - UINT8 = 11, - UINT16 = 12, - UINT32 = 13, - UINT64 = 14, - QINT8 = 15, - QINT16 = 16, - BFLOAT16 = 17, - UTF8 = 50, - UTF16 = 51, - UTF32 = 52, - ANY = 100, - AUTO = 200, - }; +enum DataType { + INHERIT = 0, + BOOL = 1, + FLOAT8 = 2, + HALF = 3, + HALF2 = 4, + FLOAT32 = 5, + DOUBLE = 6, + INT8 = 7, + INT16 = 8, + INT32 = 9, + INT64 = 10, + UINT8 = 11, + UINT16 = 12, + UINT32 = 13, + UINT64 = 14, + QINT8 = 15, + QINT16 = 16, + BFLOAT16 = 17, + UTF8 = 50, + UTF16 = 51, + UTF32 = 52, + ANY = 100, + AUTO = 200, +}; } #endif \ No newline at end of file diff --git a/libnd4j/include/array/DataTypeConversions.h b/libnd4j/include/array/DataTypeConversions.h index 44f55553373b..92043ae15abc 100644 --- a/libnd4j/include/array/DataTypeConversions.h +++ b/libnd4j/include/array/DataTypeConversions.h @@ -21,168 +21,172 @@ #ifndef LIBND4J_DATATYPECONVERSIONS_H #define LIBND4J_DATATYPECONVERSIONS_H -#include -#include -#include #include -#include +#include #include +#include #include #include -#include +#include +#include +#include namespace sd { - template - class ND4J_EXPORT DataTypeConversions { - private: - template - static FORCEINLINE void rconv(bool isBe, bool canKeep, T *buffer, Nd4jLong length, void *src) { - if (std::is_same::value && canKeep) { - memcpy(buffer, src, length * sizeof(T)); - } else { - auto tmp = new T2[length]; - memcpy(tmp, src, length * sizeof(T2)); - +template +class SD_EXPORT DataTypeConversions { + private: + template + static FORCEINLINE void rconv(bool isBe, bool canKeep, T *buffer, + Nd4jLong length, void *src) { + if (std::is_same::value && canKeep) { + memcpy(buffer, src, length * sizeof(T)); + } else { + auto tmp = new T2[length]; + memcpy(tmp, src, length * sizeof(T2)); #if __GNUC__ <= 4 - if (!canKeep) - for (Nd4jLong e = 0; e < length; e++) - buffer[e] = BitwiseUtils::swap_bytes(static_cast(tmp[e])); - else - TypeCast::convertGeneric(nullptr, tmp, length, buffer); + if (!canKeep) + for (Nd4jLong e = 0; e < length; e++) + buffer[e] = BitwiseUtils::swap_bytes(static_cast(tmp[e])); + else + TypeCast::convertGeneric(nullptr, tmp, length, buffer); #else - auto func = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) - buffer[e] = canKeep ? static_cast(tmp[e]) : BitwiseUtils::swap_bytes(static_cast(tmp[e])); - }; - - samediff::Threads::parallel_for(func, 0, length); + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) + buffer[e] = canKeep + ? static_cast(tmp[e]) + : BitwiseUtils::swap_bytes(static_cast(tmp[e])); + }; + + samediff::Threads::parallel_for(func, 0, length); #endif - delete[] tmp; - } - } - - public: - static FORCEINLINE void convertType(void* vbuffer, void* src, DataType dataType, ByteOrder order, Nd4jLong length) { - auto buffer = reinterpret_cast(vbuffer); - bool isBe = BitwiseUtils::isBE(); - bool canKeep = (isBe && order == ByteOrder::BE) || (!isBe && order == ByteOrder::LE); - - switch (dataType) { - case BOOL: { - DataTypeConversions::template rconv(isBe, canKeep, buffer, length, src); - } - break; - case UINT8: { - DataTypeConversions::template rconv(isBe, canKeep, buffer, length, src); - } - break; - case INT8: { - DataTypeConversions::template rconv(isBe, canKeep, buffer, length, src); - } - break; - case INT16: { - DataTypeConversions::template rconv(isBe, canKeep, buffer, length, src); - } - break; - case INT32: { - DataTypeConversions::template rconv(isBe, canKeep, buffer, length, src); - } - break; - case INT64: { - DataTypeConversions::template rconv(isBe, canKeep, buffer, length, src); - } - break; - case FLOAT32: { - if (std::is_same::value && canKeep) { - memcpy(buffer, src, length * sizeof(T)); - } else { - auto tmp = new float[length]; - memcpy(tmp, src, length * sizeof(float)); - + delete[] tmp; + } + } + + public: + static FORCEINLINE void convertType(void *vbuffer, void *src, + DataType dataType, ByteOrder order, + Nd4jLong length) { + auto buffer = reinterpret_cast(vbuffer); + bool isBe = BitwiseUtils::isBE(); + bool canKeep = + (isBe && order == ByteOrder::BE) || (!isBe && order == ByteOrder::LE); + + switch (dataType) { + case BOOL: { + DataTypeConversions::template rconv(isBe, canKeep, buffer, + length, src); + } break; + case UINT8: { + DataTypeConversions::template rconv(isBe, canKeep, buffer, + length, src); + } break; + case INT8: { + DataTypeConversions::template rconv(isBe, canKeep, buffer, + length, src); + } break; + case INT16: { + DataTypeConversions::template rconv(isBe, canKeep, buffer, + length, src); + } break; + case INT32: { + DataTypeConversions::template rconv(isBe, canKeep, buffer, + length, src); + } break; + case INT64: { + DataTypeConversions::template rconv(isBe, canKeep, buffer, + length, src); + } break; + case FLOAT32: { + if (std::is_same::value && canKeep) { + memcpy(buffer, src, length * sizeof(T)); + } else { + auto tmp = new float[length]; + memcpy(tmp, src, length * sizeof(float)); #if __GNUC__ <= 4 - if (!canKeep) - for (Nd4jLong e = 0; e < length; e++) - buffer[e] = BitwiseUtils::swap_bytes(static_cast(tmp[e])); - else - TypeCast::convertGeneric(nullptr, tmp, length, buffer); + if (!canKeep) + for (Nd4jLong e = 0; e < length; e++) + buffer[e] = BitwiseUtils::swap_bytes(static_cast(tmp[e])); + else + TypeCast::convertGeneric(nullptr, tmp, length, buffer); #else - auto func = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) - buffer[e] = canKeep ? static_cast(tmp[e]) : BitwiseUtils::swap_bytes(static_cast(tmp[e])); - }; - - samediff::Threads::parallel_for(func, 0, length); + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) + buffer[e] = + canKeep ? static_cast(tmp[e]) + : BitwiseUtils::swap_bytes(static_cast(tmp[e])); + }; + + samediff::Threads::parallel_for(func, 0, length); #endif - delete[] tmp; - } - } - break; - case DOUBLE: { - if (std::is_same::value && canKeep) { - memcpy(buffer, src, length * sizeof(T)); - } else { - auto tmp = new double[length]; - memcpy(tmp, src, length * sizeof(double)); + delete[] tmp; + } + } break; + case DOUBLE: { + if (std::is_same::value && canKeep) { + memcpy(buffer, src, length * sizeof(T)); + } else { + auto tmp = new double[length]; + memcpy(tmp, src, length * sizeof(double)); #if __GNUC__ <= 4 - if (!canKeep) - for (Nd4jLong e = 0; e < length; e++) - buffer[e] = BitwiseUtils::swap_bytes(static_cast(tmp[e])); - else - TypeCast::convertGeneric(nullptr, tmp, length, buffer); - + if (!canKeep) + for (Nd4jLong e = 0; e < length; e++) + buffer[e] = BitwiseUtils::swap_bytes(static_cast(tmp[e])); + else + TypeCast::convertGeneric(nullptr, tmp, length, buffer); #else - auto func = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) - buffer[e] = canKeep ? static_cast(tmp[e]) : BitwiseUtils::swap_bytes(static_cast(tmp[e])); - }; - - samediff::Threads::parallel_for(func, 0, length); + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) + buffer[e] = + canKeep ? static_cast(tmp[e]) + : BitwiseUtils::swap_bytes(static_cast(tmp[e])); + }; + + samediff::Threads::parallel_for(func, 0, length); #endif - delete[] tmp; - } - } - break; - case HALF: { - - if (std::is_same::value && canKeep) { - memcpy(buffer, src, length * sizeof(T)); - } else { - auto tmp = new float16[length]; - memcpy(tmp, src, length * sizeof(float16)); + delete[] tmp; + } + } break; + case HALF: { + if (std::is_same::value && canKeep) { + memcpy(buffer, src, length * sizeof(T)); + } else { + auto tmp = new float16[length]; + memcpy(tmp, src, length * sizeof(float16)); #if __GNUC__ <= 4 - if (!canKeep) - for (Nd4jLong e = 0; e < length; e++) - buffer[e] = BitwiseUtils::swap_bytes(static_cast(tmp[e])); - else - TypeCast::convertGeneric(nullptr, tmp, length, buffer); + if (!canKeep) + for (Nd4jLong e = 0; e < length; e++) + buffer[e] = BitwiseUtils::swap_bytes(static_cast(tmp[e])); + else + TypeCast::convertGeneric(nullptr, tmp, length, buffer); #else - auto func = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) - buffer[e] = canKeep ? static_cast(tmp[e]) : BitwiseUtils::swap_bytes(static_cast(tmp[e])); - }; - - samediff::Threads::parallel_for(func, 0, length); + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) + buffer[e] = + canKeep ? static_cast(tmp[e]) + : BitwiseUtils::swap_bytes(static_cast(tmp[e])); + }; + + samediff::Threads::parallel_for(func, 0, length); #endif - delete[] tmp; - } - } - break; - default: { - nd4j_printf("Unsupported DataType requested: [%i]\n", static_cast(dataType)); - throw std::runtime_error("Unsupported DataType"); - } - } + delete[] tmp; } - }; -} - - - -#endif //LIBND4J_DATATYPECONVERSIONS_H + } break; + default: { + nd4j_printf("Unsupported DataType requested: [%i]\n", + static_cast(dataType)); + throw std::runtime_error("Unsupported DataType"); + } + } + } +}; +} // namespace sd + +#endif // LIBND4J_DATATYPECONVERSIONS_H diff --git a/libnd4j/include/array/DataTypeUtils.h b/libnd4j/include/array/DataTypeUtils.h index 686b5bc97c27..ae8c2ee5f387 100644 --- a/libnd4j/include/array/DataTypeUtils.h +++ b/libnd4j/include/array/DataTypeUtils.h @@ -22,477 +22,505 @@ #ifndef DATATYPEUTILS_H #define DATATYPEUTILS_H -#include -#include +#include #include #include -#include -#include #include -#include +#include +#include +#include +#include //#include //#include #include namespace sd { - class ND4J_EXPORT DataTypeUtils { - public: - static int asInt(DataType type); - static DataType fromInt(int dtype); - static DataType fromFlatDataType(sd::graph::DType dtype); - FORCEINLINE static std::string asString(DataType dataType); - - template - static FORCEINLINE _CUDA_HD DataType fromT(); - static FORCEINLINE _CUDA_HD size_t sizeOfElement(DataType type); - - // returns the smallest finite value of the given type - template - FORCEINLINE static _CUDA_HD T min(); - - // returns the largest finite value of the given type - template - FORCEINLINE static _CUDA_HD T max(); - - /** - * returns inf for float/double and max for everything else - */ - template - FORCEINLINE static _CUDA_HD T infOrMax(); - - template - FORCEINLINE static _CUDA_HD T nanOrZero(); - - // returns the difference between 1.0 and the next representable value of the given floating-point type - template - FORCEINLINE static T eps(); +class SD_EXPORT DataTypeUtils { + public: + static int asInt(DataType type); + static DataType fromInt(int dtype); + static DataType fromFlatDataType(sd::graph::DType dtype); + FORCEINLINE static std::string asString(DataType dataType); + + template + static FORCEINLINE _CUDA_HD DataType fromT(); + static FORCEINLINE _CUDA_HD size_t sizeOfElement(DataType type); + + // returns the smallest finite value of the given type + template + FORCEINLINE static _CUDA_HD T min(); + + // returns the largest finite value of the given type + template + FORCEINLINE static _CUDA_HD T max(); + + /** + * returns inf for float/double and max for everything else + */ + template + FORCEINLINE static _CUDA_HD T infOrMax(); + + template + FORCEINLINE static _CUDA_HD T nanOrZero(); + + // returns the difference between 1.0 and the next representable value of the + // given floating-point type + template + FORCEINLINE static T eps(); + + FORCEINLINE static _CUDA_HD size_t sizeOf(DataType type); + FORCEINLINE static _CUDA_HD size_t sizeOf(const Nd4jLong* shapeInfo); + + FORCEINLINE static _CUDA_HD bool isR(sd::DataType dataType); + + FORCEINLINE static _CUDA_HD bool isZ(sd::DataType dataType); + + FORCEINLINE static _CUDA_HD bool isB(sd::DataType dataType); + + FORCEINLINE static _CUDA_HD bool isU(sd::DataType dataType); + + FORCEINLINE static _CUDA_HD bool isS(sd::DataType dataType); + + FORCEINLINE static sd::DataType pickPairwiseResultType(sd::DataType typeX, + sd::DataType typeY); + + FORCEINLINE static sd::DataType pickPairwiseResultType( + const Nd4jLong* shapeInfo1, const Nd4jLong* shapeInfo2); + + FORCEINLINE static sd::DataType pickFloatingType(sd::DataType typeX); + + template + FORCEINLINE static std::vector convertVector( + const std::vector& vector); + + template + FORCEINLINE static bool castShapeInfo(const Nd4jLong* originalShapeInfo, + T* newShapeInfo); + + template + // struct scalarTypesForNDarray { static bool const value = + // std::is_same::value || std::is_same::value || + // std::is_same::value || std::is_same::value || + // std::is_same::value || std::is_same::value; }; + struct scalarTypesForNDarray { + static bool const value = + std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value || + std::is_same::value || + std::is_same::value || + std::is_same::value || + std::is_same::value || + std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value || + std::is_same::value; + }; + + template + struct scalarTypesForExecution { + static bool const value = + std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value || + std::is_same::value; + }; +}; - FORCEINLINE static _CUDA_HD size_t sizeOf(DataType type); - FORCEINLINE static _CUDA_HD size_t sizeOf(const Nd4jLong* shapeInfo); - - FORCEINLINE static _CUDA_HD bool isR(sd::DataType dataType); - - FORCEINLINE static _CUDA_HD bool isZ(sd::DataType dataType); - - FORCEINLINE static _CUDA_HD bool isB(sd::DataType dataType); - - FORCEINLINE static _CUDA_HD bool isU(sd::DataType dataType); - - FORCEINLINE static _CUDA_HD bool isS(sd::DataType dataType); - - FORCEINLINE static sd::DataType pickPairwiseResultType(sd::DataType typeX, sd::DataType typeY); - - FORCEINLINE static sd::DataType pickPairwiseResultType(const Nd4jLong* shapeInfo1, const Nd4jLong* shapeInfo2); - - FORCEINLINE static sd::DataType pickFloatingType(sd::DataType typeX); - - template - FORCEINLINE static std::vector convertVector(const std::vector &vector); - - template - FORCEINLINE static bool castShapeInfo(const Nd4jLong *originalShapeInfo, T *newShapeInfo); +////////////////////////////////////////////////////////////////////////// +///// IMLEMENTATION OF INLINE METHODS ///// +////////////////////////////////////////////////////////////////////////// - template - // struct scalarTypesForNDarray { static bool const value = std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value; }; - struct scalarTypesForNDarray { static bool const value = std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value; }; +FORCEINLINE sd::DataType DataTypeUtils::pickFloatingType(sd::DataType typeX) { + // if proposed dataType is already floating point - return it + if (isR(typeX)) return typeX; + return Environment::getInstance().defaultFloatDataType(); +} - template - struct scalarTypesForExecution { static bool const value = std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value; }; +FORCEINLINE bool DataTypeUtils::isR(sd::DataType dataType) { + return dataType == sd::DataType::FLOAT32 || + dataType == sd::DataType::BFLOAT16 || dataType == sd::DataType::HALF || + dataType == sd::DataType::DOUBLE; +} - }; +FORCEINLINE bool DataTypeUtils::isB(sd::DataType dataType) { + return dataType == sd::DataType::BOOL; +} +FORCEINLINE bool DataTypeUtils::isS(sd::DataType dataType) { + return dataType == sd::DataType::UTF8 || dataType == sd::DataType::UTF16 || + dataType == sd::DataType::UTF32; +} -////////////////////////////////////////////////////////////////////////// -///// IMLEMENTATION OF INLINE METHODS ///// -////////////////////////////////////////////////////////////////////////// +FORCEINLINE bool DataTypeUtils::isZ(sd::DataType dataType) { + return !isR(dataType) && !isB(dataType) && !isS(dataType); +} - FORCEINLINE sd::DataType DataTypeUtils::pickFloatingType(sd::DataType typeX) { - // if proposed dataType is already floating point - return it - if (isR(typeX)) - return typeX; - return Environment::getInstance().defaultFloatDataType(); - } +FORCEINLINE bool DataTypeUtils::isU(sd::DataType dataType) { + return dataType == sd::DataType::UINT8 || dataType == sd::DataType::UINT16 || + dataType == sd::DataType::UINT32 || dataType == sd::DataType::UINT64; +} - FORCEINLINE bool DataTypeUtils::isR(sd::DataType dataType) { - return dataType == sd::DataType::FLOAT32 || dataType == sd::DataType::BFLOAT16 || dataType == sd::DataType::HALF || dataType == sd::DataType::DOUBLE; - } +FORCEINLINE sd::DataType DataTypeUtils::pickPairwiseResultType( + sd::DataType typeX, sd::DataType typeY) { + // if both dtypes are the same - just return it + if (typeX == typeY) return typeX; + auto nd4j_max = [](sd::DataType typeX, sd::DataType typeY) { + return typeX > typeY ? typeX : typeY; + }; + auto rX = isR(typeX); + auto rY = isR(typeY); - FORCEINLINE bool DataTypeUtils::isB(sd::DataType dataType) { - return dataType == sd::DataType::BOOL; - } + // if X is float - use it + if (rX && !rY) return typeX; - FORCEINLINE bool DataTypeUtils::isS(sd::DataType dataType) { - return dataType == sd::DataType::UTF8 || dataType == sd::DataType::UTF16 || dataType == sd::DataType::UTF32; - } + // if Y is float - use it + if (!rX && rY) return typeY; - FORCEINLINE bool DataTypeUtils::isZ(sd::DataType dataType) { - return !isR(dataType) && !isB(dataType) && !isS(dataType); + // if both data types are float - return biggest one + if (rX && rY) { + // if we allow precision boost, then we pick bigger data type + if (sd::Environment::getInstance().precisionBoostAllowed()) { + return nd4j_max(typeX, typeY); + } else { + // and we return first operand otherwise + return typeX; } - - FORCEINLINE bool DataTypeUtils::isU(sd::DataType dataType) { - return dataType == sd::DataType::UINT8 || dataType == sd::DataType::UINT16 || dataType == sd::DataType::UINT32 || dataType == sd::DataType::UINT64; + } + + // if that's not real type, we apply same rules + if (!rX && !rY) { + if (sd::Environment::getInstance().precisionBoostAllowed()) { + return nd4j_max(typeX, typeY); + } else { + // and we return first operand otherwise + return typeX; } + } - FORCEINLINE sd::DataType DataTypeUtils::pickPairwiseResultType(sd::DataType typeX, sd::DataType typeY) { - // if both dtypes are the same - just return it - if (typeX == typeY) - return typeX; - auto nd4j_max = [](sd::DataType typeX, sd::DataType typeY) { - return typeX > typeY?typeX:typeY; - }; - auto rX = isR(typeX); - auto rY = isR(typeY); - - // if X is float - use it - if (rX && !rY) - return typeX; - - // if Y is float - use it - if (!rX && rY) - return typeY; - - // if both data types are float - return biggest one - if (rX && rY) { - // if we allow precision boost, then we pick bigger data type - if (sd::Environment::getInstance().precisionBoostAllowed()) { - return nd4j_max(typeX, typeY); - } else { - // and we return first operand otherwise - return typeX; - } - - } - - // if that's not real type, we apply same rules - if (!rX && !rY) { - if (sd::Environment::getInstance().precisionBoostAllowed()) { - return nd4j_max(typeX, typeY); - } else { - // and we return first operand otherwise - return typeX; - } - } - - return typeX; - } + return typeX; +} /////////////////////////////////////////////////////////////////// -FORCEINLINE sd::DataType DataTypeUtils::pickPairwiseResultType(const Nd4jLong* shapeInfo1, const Nd4jLong* shapeInfo2) { - - return pickPairwiseResultType(ArrayOptions::dataType(shapeInfo1), ArrayOptions::dataType(shapeInfo2)); +FORCEINLINE sd::DataType DataTypeUtils::pickPairwiseResultType( + const Nd4jLong* shapeInfo1, const Nd4jLong* shapeInfo2) { + return pickPairwiseResultType(ArrayOptions::dataType(shapeInfo1), + ArrayOptions::dataType(shapeInfo2)); } /////////////////////////////////////////////////////////////////// FORCEINLINE size_t DataTypeUtils::sizeOf(DataType type) { - return sizeOfElement(type); + return sizeOfElement(type); } /////////////////////////////////////////////////////////////////// FORCEINLINE size_t DataTypeUtils::sizeOf(const Nd4jLong* shapeInfo) { - return sizeOfElement(ArrayOptions::dataType(shapeInfo)); + return sizeOfElement(ArrayOptions::dataType(shapeInfo)); } // returns the smallest finite value of the given type -template<> +template <> FORCEINLINE _CUDA_HD int DataTypeUtils::min() { - return 1; + return 1; } -template<> +template <> FORCEINLINE _CUDA_HD char DataTypeUtils::min() { - return 1; + return 1; } template <> FORCEINLINE _CUDA_HD bool DataTypeUtils::min() { - return false; + return false; } -template<> +template <> FORCEINLINE _CUDA_HD Nd4jLong DataTypeUtils::min() { - return 1L; + return 1L; } -template<> +template <> FORCEINLINE _CUDA_HD uint64_t DataTypeUtils::min() { - return 1L; + return 1L; } -template<> +template <> FORCEINLINE _CUDA_HD uint32_t DataTypeUtils::min() { - return 1; + return 1; } -template<> +template <> FORCEINLINE _CUDA_HD float DataTypeUtils::min() { - return 1.175494e-38; + return 1.175494e-38; } -template<> +template <> FORCEINLINE _CUDA_HD float16 DataTypeUtils::min() { - return (float16) 6.1035e-05; + return (float16)6.1035e-05; } -template<> +template <> FORCEINLINE _CUDA_HD bfloat16 DataTypeUtils::min() { - return bfloat16::min(); + return bfloat16::min(); } -template<> +template <> FORCEINLINE _CUDA_HD double DataTypeUtils::min() { - return 2.2250738585072014e-308; + return 2.2250738585072014e-308; } /////////////////////////////////////////////////////////////////// // returns the largest finite value of the given type template <> FORCEINLINE _CUDA_HD int DataTypeUtils::max() { - return 2147483647; + return 2147483647; } template <> FORCEINLINE _CUDA_HD bool DataTypeUtils::max() { - return true; + return true; } template <> FORCEINLINE _CUDA_HD int8_t DataTypeUtils::max() { - return 127; + return 127; } template <> FORCEINLINE _CUDA_HD uint8_t DataTypeUtils::max() { - return 255; + return 255; } template <> FORCEINLINE _CUDA_HD int16_t DataTypeUtils::max() { - return 32767; + return 32767; } template <> FORCEINLINE _CUDA_HD uint16_t DataTypeUtils::max() { - return 65535; + return 65535; } template <> FORCEINLINE _CUDA_HD Nd4jLong DataTypeUtils::max() { - return 9223372036854775807LL; + return 9223372036854775807LL; } template <> FORCEINLINE _CUDA_HD uint32_t DataTypeUtils::max() { - return 4294967295; + return 4294967295; } template <> FORCEINLINE _CUDA_HD Nd4jULong DataTypeUtils::max() { - return 18446744073709551615LLU; + return 18446744073709551615LLU; } template <> FORCEINLINE _CUDA_HD float DataTypeUtils::max() { - return 3.402823e+38; + return 3.402823e+38; } template <> FORCEINLINE _CUDA_HD double DataTypeUtils::max() { - return 1.7976931348623157E308; + return 1.7976931348623157E308; } template <> FORCEINLINE _CUDA_HD float16 DataTypeUtils::max() { - return static_cast(65504.f); + return static_cast(65504.f); } template <> FORCEINLINE _CUDA_HD bfloat16 DataTypeUtils::max() { - return bfloat16::max(); + return bfloat16::max(); } template <> FORCEINLINE _CUDA_HD float DataTypeUtils::infOrMax() { - return std::numeric_limits::infinity(); + return std::numeric_limits::infinity(); } template <> FORCEINLINE _CUDA_HD double DataTypeUtils::infOrMax() { - return std::numeric_limits::infinity(); + return std::numeric_limits::infinity(); } template FORCEINLINE _CUDA_HD T DataTypeUtils::infOrMax() { - return DataTypeUtils::max(); + return DataTypeUtils::max(); } template <> FORCEINLINE _CUDA_HD float DataTypeUtils::nanOrZero() { - return std::numeric_limits::quiet_NaN(); + return std::numeric_limits::quiet_NaN(); } template <> FORCEINLINE _CUDA_HD double DataTypeUtils::nanOrZero() { - return std::numeric_limits::quiet_NaN(); + return std::numeric_limits::quiet_NaN(); } template FORCEINLINE _CUDA_HD T DataTypeUtils::nanOrZero() { - return static_cast(0); + return static_cast(0); } FORCEINLINE std::string DataTypeUtils::asString(DataType dataType) { - switch(dataType) { - case INT8: - return std::string("INT8"); - case INT16: - return std::string("INT16"); - case INT32: - return std::string("INT32"); - case INT64: - return std::string("INT64"); - case BFLOAT16: - return std::string("BFLOAT16"); - case FLOAT32: - return std::string("FLOAT"); - case DOUBLE: - return std::string("DOUBLE"); - case HALF: - return std::string("HALF"); - case BOOL: - return std::string("BOOL"); - case UINT8: - return std::string("UINT8"); - case UINT16: - return std::string("UINT16"); - case UINT32: - return std::string("UINT32"); - case UINT64: - return std::string("UINT64"); - case UTF8: - return std::string("UTF8"); - case UTF16: - return std::string("UTF16"); - case UTF32: - return std::string("UTF32"); - default: - throw std::runtime_error("Unknown data type used"); - } + switch (dataType) { + case INT8: + return std::string("INT8"); + case INT16: + return std::string("INT16"); + case INT32: + return std::string("INT32"); + case INT64: + return std::string("INT64"); + case BFLOAT16: + return std::string("BFLOAT16"); + case FLOAT32: + return std::string("FLOAT"); + case DOUBLE: + return std::string("DOUBLE"); + case HALF: + return std::string("HALF"); + case BOOL: + return std::string("BOOL"); + case UINT8: + return std::string("UINT8"); + case UINT16: + return std::string("UINT16"); + case UINT32: + return std::string("UINT32"); + case UINT64: + return std::string("UINT64"); + case UTF8: + return std::string("UTF8"); + case UTF16: + return std::string("UTF16"); + case UTF32: + return std::string("UTF32"); + default: + throw std::runtime_error("Unknown data type used"); + } } - template -FORCEINLINE bool DataTypeUtils::castShapeInfo(const Nd4jLong *originalShapeInfo, T *newShapeInfo) { - auto shapeInfoLength = *originalShapeInfo * 2 + 4; - for (auto e = 0; e < shapeInfoLength; e++) { - if (originalShapeInfo[e] < static_cast(DataTypeUtils::max())) { - newShapeInfo[e] = static_cast(originalShapeInfo[e]); - } else - return false; - } +FORCEINLINE bool DataTypeUtils::castShapeInfo(const Nd4jLong* originalShapeInfo, + T* newShapeInfo) { + auto shapeInfoLength = *originalShapeInfo * 2 + 4; + for (auto e = 0; e < shapeInfoLength; e++) { + if (originalShapeInfo[e] < static_cast(DataTypeUtils::max())) { + newShapeInfo[e] = static_cast(originalShapeInfo[e]); + } else + return false; + } - return true; + return true; } /////////////////////////////////////////////////////////////////// -// returns the difference between 1.0 and the next representable value of the given floating-point type +// returns the difference between 1.0 and the next representable value of the +// given floating-point type template FORCEINLINE _CUDA_HD T DataTypeUtils::eps() { - if (std::is_same::value) - return std::numeric_limits::epsilon(); - else if (std::is_same::value) - return std::numeric_limits::epsilon(); - else if (std::is_same::value) - return 0.00097656; - else if (std::is_same::value) - return bfloat16::eps(); - else - return 0; -} - - - template - FORCEINLINE std::vector DataTypeUtils::convertVector(const std::vector &vector) { - std::vector result(vector.size()); - Nd4jLong vecSize = vector.size(); - for (Nd4jLong e = 0; e < vecSize; e++) - result[e] = static_cast(vector[e]); - - return result; - } - - FORCEINLINE _CUDA_HD size_t DataTypeUtils::sizeOfElement(sd::DataType type) { - switch (type) { - case sd::DataType::UINT8: - case sd::DataType::INT8: - case sd::DataType::FLOAT8: - case sd::DataType::QINT8: - case sd::DataType::BOOL: return (size_t) 1; - - case sd::DataType::BFLOAT16: - case sd::DataType::HALF: - case sd::DataType::INT16: - case sd::DataType::QINT16: - case sd::DataType::UINT16: return (size_t) 2; - - case sd::DataType::UTF8: - case sd::DataType::UTF16: - case sd::DataType::UTF32: - case sd::DataType::INT32: - case sd::DataType::UINT32: - case sd::DataType::HALF2: - case sd::DataType::FLOAT32: return (size_t) 4; - - case sd::DataType::UINT64: - case sd::DataType::INT64: - case sd::DataType::DOUBLE: return (size_t) 8; - - default: { - nd4j_printf("Unknown DataType used: [%i]\n", asInt(type)); + if (std::is_same::value) + return std::numeric_limits::epsilon(); + else if (std::is_same::value) + return std::numeric_limits::epsilon(); + else if (std::is_same::value) + return 0.00097656; + else if (std::is_same::value) + return bfloat16::eps(); + else + return 0; +} + +template +FORCEINLINE std::vector DataTypeUtils::convertVector( + const std::vector& vector) { + std::vector result(vector.size()); + Nd4jLong vecSize = vector.size(); + for (Nd4jLong e = 0; e < vecSize; e++) result[e] = static_cast(vector[e]); + + return result; +} + +FORCEINLINE _CUDA_HD size_t DataTypeUtils::sizeOfElement(sd::DataType type) { + switch (type) { + case sd::DataType::UINT8: + case sd::DataType::INT8: + case sd::DataType::FLOAT8: + case sd::DataType::QINT8: + case sd::DataType::BOOL: + return (size_t)1; + + case sd::DataType::BFLOAT16: + case sd::DataType::HALF: + case sd::DataType::INT16: + case sd::DataType::QINT16: + case sd::DataType::UINT16: + return (size_t)2; + + case sd::DataType::UTF8: + case sd::DataType::UTF16: + case sd::DataType::UTF32: + case sd::DataType::INT32: + case sd::DataType::UINT32: + case sd::DataType::HALF2: + case sd::DataType::FLOAT32: + return (size_t)4; + + case sd::DataType::UINT64: + case sd::DataType::INT64: + case sd::DataType::DOUBLE: + return (size_t)8; + + default: { + nd4j_printf("Unknown DataType used: [%i]\n", asInt(type)); #ifndef __CUDA_ARCH__ - throw std::runtime_error("Unknown DataType requested"); + throw std::runtime_error("Unknown DataType requested"); #endif - } - } - } - - template - FORCEINLINE _CUDA_HD sd::DataType sd::DataTypeUtils::fromT() { - if (std::is_same::value) { - return sd::DataType::BOOL; - } else if (std::is_same::value) { - return sd::DataType::UTF8; - } else if (std::is_same::value) { - return sd::DataType::UTF16; - } else if (std::is_same::value) { - return sd::DataType::UTF32; - } else if (std::is_same::value) { - return sd::DataType::FLOAT32; - } else if (std::is_same::value) { - return sd::DataType::HALF; - } else if (std::is_same::value) { - return sd::DataType::BFLOAT16; - } else if (std::is_same::value) { - return sd::DataType::DOUBLE; - } else if (std::is_same::value) { - return sd::DataType::INT8; - } else if (std::is_same::value) { - return sd::DataType::INT16; - } else if (std::is_same::value) { - return sd::DataType::INT32; - } else if (std::is_same::value) { - return sd::DataType::INT64; - } else if (std::is_same::value) { - return sd::DataType::UINT8; - } else if (std::is_same::value) { - return sd::DataType::UINT16; - } else if (std::is_same::value) { - return sd::DataType::UINT32; - } else if (std::is_same::value) { - return sd::DataType::UINT64; - } else { - return sd::DataType::INHERIT; - } } + } } -#endif //DATATYPEUTILS_H \ No newline at end of file +template +FORCEINLINE _CUDA_HD sd::DataType sd::DataTypeUtils::fromT() { + if (std::is_same::value) { + return sd::DataType::BOOL; + } else if (std::is_same::value) { + return sd::DataType::UTF8; + } else if (std::is_same::value) { + return sd::DataType::UTF16; + } else if (std::is_same::value) { + return sd::DataType::UTF32; + } else if (std::is_same::value) { + return sd::DataType::FLOAT32; + } else if (std::is_same::value) { + return sd::DataType::HALF; + } else if (std::is_same::value) { + return sd::DataType::BFLOAT16; + } else if (std::is_same::value) { + return sd::DataType::DOUBLE; + } else if (std::is_same::value) { + return sd::DataType::INT8; + } else if (std::is_same::value) { + return sd::DataType::INT16; + } else if (std::is_same::value) { + return sd::DataType::INT32; + } else if (std::is_same::value) { + return sd::DataType::INT64; + } else if (std::is_same::value) { + return sd::DataType::UINT8; + } else if (std::is_same::value) { + return sd::DataType::UINT16; + } else if (std::is_same::value) { + return sd::DataType::UINT32; + } else if (std::is_same::value) { + return sd::DataType::UINT64; + } else { + return sd::DataType::INHERIT; + } +} +} // namespace sd + +#endif // DATATYPEUTILS_H \ No newline at end of file diff --git a/libnd4j/include/array/ExtraArguments.h b/libnd4j/include/array/ExtraArguments.h index 131e8cd92924..4fce4bb73159 100644 --- a/libnd4j/include/array/ExtraArguments.h +++ b/libnd4j/include/array/ExtraArguments.h @@ -18,48 +18,48 @@ // @author raver119@gmail.com // -#ifndef DEV_TESTS_EXTRAARGUMENTS_H -#define DEV_TESTS_EXTRAARGUMENTS_H +#ifndef SD_EXTRAARGUMENTS_H +#define SD_EXTRAARGUMENTS_H +#include +#include #include +#include + #include #include -#include -#include -#include namespace sd { - class ND4J_EXPORT ExtraArguments { - private: - std::vector _fpArgs; - std::vector _intArgs; - - std::vector _pointers; +class SD_EXPORT ExtraArguments { + private: + std::vector _fpArgs; + std::vector _intArgs; - template - void convertAndCopy(Nd4jPointer pointer, Nd4jLong offset); + std::vector _pointers; - void* allocate(size_t length, size_t elementSize); - public: - explicit ExtraArguments(std::initializer_list arguments); - explicit ExtraArguments(std::initializer_list arguments); + template + void convertAndCopy(Nd4jPointer pointer, Nd4jLong offset); - explicit ExtraArguments(const std::vector &arguments); - explicit ExtraArguments(const std::vector &arguments); - explicit ExtraArguments(const std::vector &arguments); + void* allocate(size_t length, size_t elementSize); - explicit ExtraArguments(); - ~ExtraArguments(); + public: + explicit ExtraArguments(std::initializer_list arguments); + explicit ExtraArguments(std::initializer_list arguments); - template - void* argumentsAsT(Nd4jLong offset = 0); + explicit ExtraArguments(const std::vector& arguments); + explicit ExtraArguments(const std::vector& arguments); + explicit ExtraArguments(const std::vector& arguments); - void* argumentsAsT(sd::DataType dataType, Nd4jLong offset = 0); + explicit ExtraArguments(); + ~ExtraArguments(); - size_t length(); - }; -} + template + void* argumentsAsT(Nd4jLong offset = 0); + void* argumentsAsT(sd::DataType dataType, Nd4jLong offset = 0); + size_t length(); +}; +} // namespace sd -#endif //DEV_TESTS_EXTRAARGUMENTS_H +#endif // SD_EXTRAARGUMENTS_H diff --git a/libnd4j/include/array/InteropDataBuffer.h b/libnd4j/include/array/InteropDataBuffer.h index 27b17aabb568..4cafb31d0cf0 100644 --- a/libnd4j/include/array/InteropDataBuffer.h +++ b/libnd4j/include/array/InteropDataBuffer.h @@ -18,54 +18,67 @@ // @author raver119@gmail.com // -#include #include #include +#include + #include #ifndef LIBND4J_INTEROPDATABUFFER_H #define LIBND4J_INTEROPDATABUFFER_H namespace sd { - /** - * This class is a wrapper for DataBuffer, suitable for sharing DataBuffer between front-end and back-end languages - */ - class ND4J_EXPORT InteropDataBuffer { - private: - std::shared_ptr _dataBuffer; - uint64_t _offset = 0; - public: - InteropDataBuffer(InteropDataBuffer &dataBuffer, uint64_t length, uint64_t offset); - InteropDataBuffer(std::shared_ptr databuffer); - InteropDataBuffer(size_t elements, sd::DataType dtype, bool allocateBoth); - ~InteropDataBuffer() = default; +/** + * This class is a wrapper for DataBuffer, suitable for sharing DataBuffer + * between front-end and back-end languages + */ +class SD_EXPORT InteropDataBuffer { + private: + std::shared_ptr _dataBuffer; + uint64_t _offset = 0; + + public: + InteropDataBuffer(InteropDataBuffer& dataBuffer, uint64_t length, + uint64_t offset); + InteropDataBuffer(std::shared_ptr databuffer); + InteropDataBuffer(size_t elements, sd::DataType dtype, bool allocateBoth); + ~InteropDataBuffer() = default; #ifndef __JAVACPP_HACK__ - std::shared_ptr getDataBuffer() const; - std::shared_ptr dataBuffer(); + std::shared_ptr getDataBuffer() const; + std::shared_ptr dataBuffer(); #endif - void* primary() const; - void* special() const; - - uint64_t offset() const ; - void setOffset(uint64_t offset); + void* primary() const; + void* special() const; - void setPrimary(void* ptr, size_t length); - void setSpecial(void* ptr, size_t length); + uint64_t offset() const; + void setOffset(uint64_t offset); - void expand(size_t newlength); + void setPrimary(void* ptr, size_t length); + void setSpecial(void* ptr, size_t length); - int deviceId() const; - void setDeviceId(int deviceId); + void expand(size_t newlength); - static void registerSpecialUse(const std::vector& writeList, const std::vector& readList); - static void prepareSpecialUse(const std::vector& writeList, const std::vector& readList, bool synchronizeWritables = false); + int deviceId() const; + void setDeviceId(int deviceId); - static void registerPrimaryUse(const std::vector& writeList, const std::vector& readList); - static void preparePrimaryUse(const std::vector& writeList, const std::vector& readList, bool synchronizeWritables = false); - }; -} + static void registerSpecialUse( + const std::vector& writeList, + const std::vector& readList); + static void prepareSpecialUse( + const std::vector& writeList, + const std::vector& readList, + bool synchronizeWritables = false); + static void registerPrimaryUse( + const std::vector& writeList, + const std::vector& readList); + static void preparePrimaryUse( + const std::vector& writeList, + const std::vector& readList, + bool synchronizeWritables = false); +}; +} // namespace sd -#endif //LIBND4J_INTEROPDATABUFFER_H +#endif // LIBND4J_INTEROPDATABUFFER_H diff --git a/libnd4j/include/array/ManagedDataBuffer.h b/libnd4j/include/array/ManagedDataBuffer.h new file mode 100644 index 000000000000..05d34d62f202 --- /dev/null +++ b/libnd4j/include/array/ManagedDataBuffer.h @@ -0,0 +1,49 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#ifndef SD_MANAGEDDATABUFFER_H +#define SD_MANAGEDDATABUFFER_H + +#include +#include + +namespace sd { +/** + * This class provides special DataBuffer implementation for use within Graphs + */ +class SD_EXPORT ManagedDataBuffer : public DataBuffer { + private: + graph::GraphMemoryManager& _manager; + + protected: + memory::MemoryZone _zone; + MemoryDescriptor _descriptor; + + public: + ManagedDataBuffer(graph::GraphMemoryManager& manager, uint64_t numberOfBytes, + DataType dtype, memory::MemoryZone zone); + ~ManagedDataBuffer(); + + void* primary() override; + void* special() override; +}; +} // namespace sd + +#endif // SD_MANAGEDDATABUFFER_H diff --git a/libnd4j/include/array/NDArray.h b/libnd4j/include/array/NDArray.h index 7b32b7d490fa..fd5aa62c0d70 100644 --- a/libnd4j/include/array/NDArray.h +++ b/libnd4j/include/array/NDArray.h @@ -17,1718 +17,1968 @@ #ifndef NDARRAY_H #define NDARRAY_H -#include -#include -#include -#include -#include "legacy/NativeOpExecutioner.h" -#include -#include -#include -#include -#include -#include #include #include +#include +#include +#include +#include +#include #include +#include +#include +#include +#include +#include #include -#include -#include +#include +#include +#include +#include #include #include -#include -#include -#include -#include -#include -#include +#include +#include +#include +#include + +#include +#include #include #include #include #include +#include "legacy/NativeOpExecutioner.h" namespace sd { - template ::value>::type> - ND4J_EXPORT NDArray operator+(const NDArray& arr, const T& scalar); - template ::value>::type> - ND4J_EXPORT NDArray operator+(NDArray&& arr, const T& scalar); - template ::value>::type> - ND4J_EXPORT NDArray operator+(const T& scalar, const NDArray& arr); - template ::value>::type> - ND4J_EXPORT NDArray operator+(const T& scalar, NDArray&& arr); - - template ::value>::type> - ND4J_EXPORT NDArray operator-(const NDArray& arr, const T& scalar); - template ::value>::type> - ND4J_EXPORT NDArray operator-(NDArray&& arr, const T& scalar); - template ::value>::type> - ND4J_EXPORT NDArray operator-(const T& scalar, const NDArray& arr); - template ::value>::type> - ND4J_EXPORT NDArray operator-(const T& scalar, NDArray&& arr); - - template ::value>::type> - ND4J_EXPORT NDArray operator*(const NDArray& arr, const T& scalar); - template ::value>::type> - ND4J_EXPORT NDArray operator*(NDArray&& arr, const T& scalar); - template ::value>::type> - ND4J_EXPORT NDArray operator*(const T& scalar, const NDArray& arr); - template ::value>::type> - ND4J_EXPORT NDArray operator*(const T& scalar, NDArray&& arr); - - template ::value>::type> - ND4J_EXPORT NDArray operator/(const NDArray& arr, const T& scalar); - template ::value>::type> - ND4J_EXPORT NDArray operator/(NDArray&& arr, const T& scalar); - template ::value>::type> - ND4J_EXPORT NDArray operator/(const T& scalar, const NDArray& arr); - template ::value>::type> - ND4J_EXPORT NDArray operator/(const T& scalar, NDArray&& arr); - - template ::type>::value && std::is_same::type>::value>::type> - ND4J_EXPORT NDArray operator+(T1&& arr1, T2&& arr2); - template ::type>::value && std::is_same::type>::value>::type> - ND4J_EXPORT NDArray operator-(T1&& arr1, T2&& arr2); - template ::type>::value && std::is_same::type>::value>::type> - ND4J_EXPORT NDArray operator*(T1&& arr1, T2&& arr2); - template ::type>::value && std::is_same::type>::value>::type> - ND4J_EXPORT NDArray operator/(T1&& arr1, T2&& arr2); - - - - - ND4J_EXPORT NDArray mmul(const NDArray&, const NDArray&); - - class ND4J_EXPORT NDArray { - private: - /** - * This method applies given value to the buffer, wrt templates - * @tparam T - * @tparam Y - * @param buffer - * @param indices - * @param value - */ - template - void templatedSet(void *buffer, const Nd4jLong *indices, const void *value); - - template - void templatedSet(void *buffer, const Nd4jLong xOffset, const void *value); - - template - void templatedSet(void *buffer, const Nd4jLong xOfsset, sd::DataType dtype, const void *value); - - template - void templatedAssign(void *xBuffer, const Nd4jLong xOffset, const void *yBuffer, const Nd4jLong yOffset) const; - - template - void templatedDoubleAssign(void *xBuffer, const Nd4jLong xOffset, const void *yBuffer, const Nd4jLong yOffset) const; - - template - FORCEINLINE R templatedGet(void const* buffer, const Nd4jLong index) const; -/* - template - R templatedGetIndex(void *buffer, Nd4jLong *indices) const; -*/ - template - void* templatedPointerShift(const Nd4jLong offset) const; - - FORCEINLINE void copyBufferStatus(const NDArray& other) const; - - protected: - - /** - * if true then array doesn't own buffer and simply points to another's buffer - */ - bool _isView = false; - - /** - * pointer on DataBuffer buffers in cpu/device memory - */ - std::shared_ptr _buffer = std::make_shared(); - - /** - * buffers offset, it is the same both for cpu and device buffers - */ - Nd4jLong _offset = 0L; - - /** - * contains shape info: matrix rank, numbers of elements per each dimension, dimensions strides, element-wise-stride, c-like or fortan-like order - */ - const Nd4jLong *_shapeInfo = nullptr; - const Nd4jLong *_shapeInfoD = nullptr; - - /** - * pointer on device launch context (with all data needed there). - */ - sd::LaunchContext * _context = sd::LaunchContext::defaultContext(); - - // indicates if array's buffer is within workspace - bool _isAttached = false; - - /** - * Field to store cached length - */ - Nd4jLong _length = -1L; - - /** - * type of array elements - */ - sd::DataType _dataType = FLOAT32; - - /** - * deviceID where this NDArray belongs to - */ - int _deviceId = AffinityManager::currentDeviceId(); - - template - std::string toStringValue(T value); - - public: - NDArray() = default; - - /** - * do not allocate memory, memory for array is passed from outside - */ +template ::value>::type> +SD_EXPORT NDArray operator+(const NDArray& arr, const T& scalar); +template ::value>::type> +SD_EXPORT NDArray operator+(NDArray&& arr, const T& scalar); +template ::value>::type> +SD_EXPORT NDArray operator+(const T& scalar, const NDArray& arr); +template ::value>::type> +SD_EXPORT NDArray operator+(const T& scalar, NDArray&& arr); + +template ::value>::type> +SD_EXPORT NDArray operator-(const NDArray& arr, const T& scalar); +template ::value>::type> +SD_EXPORT NDArray operator-(NDArray&& arr, const T& scalar); +template ::value>::type> +SD_EXPORT NDArray operator-(const T& scalar, const NDArray& arr); +template ::value>::type> +SD_EXPORT NDArray operator-(const T& scalar, NDArray&& arr); + +template ::value>::type> +SD_EXPORT NDArray operator*(const NDArray& arr, const T& scalar); +template ::value>::type> +SD_EXPORT NDArray operator*(NDArray&& arr, const T& scalar); +template ::value>::type> +SD_EXPORT NDArray operator*(const T& scalar, const NDArray& arr); +template ::value>::type> +SD_EXPORT NDArray operator*(const T& scalar, NDArray&& arr); + +template ::value>::type> +SD_EXPORT NDArray operator/(const NDArray& arr, const T& scalar); +template ::value>::type> +SD_EXPORT NDArray operator/(NDArray&& arr, const T& scalar); +template ::value>::type> +SD_EXPORT NDArray operator/(const T& scalar, const NDArray& arr); +template ::value>::type> +SD_EXPORT NDArray operator/(const T& scalar, NDArray&& arr); + +template < + typename T1, typename T2, + typename = typename std::enable_if< + std::is_same::type>::value && + std::is_same::type>::value>::type> +SD_EXPORT NDArray operator+(T1&& arr1, T2&& arr2); +template < + typename T1, typename T2, + typename = typename std::enable_if< + std::is_same::type>::value && + std::is_same::type>::value>::type> +SD_EXPORT NDArray operator-(T1&& arr1, T2&& arr2); +template < + typename T1, typename T2, + typename = typename std::enable_if< + std::is_same::type>::value && + std::is_same::type>::value>::type> +SD_EXPORT NDArray operator*(T1&& arr1, T2&& arr2); +template < + typename T1, typename T2, + typename = typename std::enable_if< + std::is_same::type>::value && + std::is_same::type>::value>::type> +SD_EXPORT NDArray operator/(T1&& arr1, T2&& arr2); + +SD_EXPORT NDArray mmul(const NDArray&, const NDArray&); + +class SD_EXPORT NDArray { + private: + /** + * This method applies given value to the buffer, wrt templates + * @tparam T + * @tparam Y + * @param buffer + * @param indices + * @param value + */ + template + void templatedSet(void* buffer, const Nd4jLong* indices, const void* value); + + template + void templatedSet(void* buffer, const Nd4jLong xOffset, const void* value); + + template + void templatedSet(void* buffer, const Nd4jLong xOfsset, sd::DataType dtype, + const void* value); + + template + void templatedAssign(void* xBuffer, const Nd4jLong xOffset, + const void* yBuffer, const Nd4jLong yOffset) const; + + template + void templatedDoubleAssign(void* xBuffer, const Nd4jLong xOffset, + const void* yBuffer, const Nd4jLong yOffset) const; + + template + FORCEINLINE R templatedGet(void const* buffer, const Nd4jLong index) const; + /* + template + R templatedGetIndex(void *buffer, Nd4jLong *indices) const; + */ + template + void* templatedPointerShift(const Nd4jLong offset) const; + + FORCEINLINE void copyBufferStatus(const NDArray& other) const; + + protected: + /** + * if true then array doesn't own buffer and simply points to another's + * buffer + */ + bool _isView = false; + + /** + * pointer on DataBuffer buffers in cpu/device memory + */ + std::shared_ptr _buffer = std::make_shared(); + + /** + * buffers offset, it is the same both for cpu and device buffers + */ + Nd4jLong _offset = 0L; + + /** + * contains shape info: matrix rank, numbers of elements per each dimension, + * dimensions strides, element-wise-stride, c-like or fortan-like order + */ + const Nd4jLong* _shapeInfo = nullptr; + const Nd4jLong* _shapeInfoD = nullptr; + + /** + * pointer on device launch context (with all data needed there). + */ + sd::LaunchContext* _context = sd::LaunchContext::defaultContext(); + + // indicates if array's buffer is within workspace + bool _isAttached = false; + + /** + * Field to store cached length + */ + Nd4jLong _length = -1L; + + /** + * type of array elements + */ + sd::DataType _dataType = FLOAT32; + + /** + * deviceID where this NDArray belongs to + */ + int _deviceId = AffinityManager::currentDeviceId(); + + template + std::string toStringValue(T value) const; + + public: + NDArray() = default; + + /** + * do not allocate memory, memory for array is passed from outside + */ #ifndef __JAVACPP_HACK__ - NDArray(std::shared_ptr buffer, const ShapeDescriptor& descriptor, sd::LaunchContext* context = sd::LaunchContext::defaultContext(), const Nd4jLong offset = 0); - - NDArray(std::shared_ptr buffer, char order, const std::vector &shape, sd::LaunchContext* context = sd::LaunchContext::defaultContext()); - - /** - * This contructors create scalar array containing string utf8 - * - */ - NDArray(const char* str, sd::DataType dtype = sd::DataType::UTF8, sd::LaunchContext* context = sd::LaunchContext::defaultContext()) - : NDArray(std::string(str), dtype, context) { - } - NDArray(const std::string& string, sd::DataType dtype = sd::DataType::UTF8, sd::LaunchContext* context = sd::LaunchContext::defaultContext()); - - /** - * This contructors create scalar array containing string utf16 - * - */ - NDArray(const char16_t* u16string, sd::DataType dtype = sd::DataType::UTF16, sd::LaunchContext* context = sd::LaunchContext::defaultContext()) - : NDArray(std::u16string(u16string), dtype, context) { - } - - NDArray(const std::u16string& u16string, sd::DataType dtype = sd::DataType::UTF16, sd::LaunchContext* context = sd::LaunchContext::defaultContext()); - - /** - * This contructors create scalar array containing string utf32 - * - */ - NDArray(const char32_t* u32string, sd::DataType dtype = sd::DataType::UTF32, sd::LaunchContext* context = sd::LaunchContext::defaultContext()) - : NDArray(std::u32string(u32string), dtype, context) { - } - - NDArray(const std::u32string& u32string, sd::DataType dtype = sd::DataType::UTF32, sd::LaunchContext* context = sd::LaunchContext::defaultContext()); - - /** - * This contructors create array from vector of utf8 strings - * - */ - NDArray(const std::vector& shape, const std::vector& strings, sd::DataType dtype = sd::DataType::UTF8, sd::LaunchContext* context = sd::LaunchContext::defaultContext()); - NDArray(const std::vector& shape, const std::vector& string, sd::DataType dtype = sd::DataType::UTF8, sd::LaunchContext* context = sd::LaunchContext::defaultContext()); - - /** - * This contructors create array from vector of utf16 strings - * - */ - NDArray(const std::vector& shape, const std::vector& strings, sd::DataType dtype = sd::DataType::UTF16, sd::LaunchContext* context = sd::LaunchContext::defaultContext()); - NDArray(const std::vector& shape, const std::vector& string, sd::DataType dtype = sd::DataType::UTF16, sd::LaunchContext* context = sd::LaunchContext::defaultContext()); - - /** - * This contructors create array from vector of utf32 strings - * - */ - NDArray(const std::vector& shape, const std::vector& strings, sd::DataType dtype = sd::DataType::UTF32, sd::LaunchContext* context = sd::LaunchContext::defaultContext()); - NDArray(const std::vector& shape, const std::vector& string, sd::DataType dtype = sd::DataType::UTF32, sd::LaunchContext* context = sd::LaunchContext::defaultContext()); + NDArray(std::shared_ptr buffer, const ShapeDescriptor& descriptor, + sd::LaunchContext* context = sd::LaunchContext::defaultContext(), + const Nd4jLong offset = 0); + + NDArray(std::shared_ptr buffer, char order, + const std::vector& shape, + sd::LaunchContext* context = sd::LaunchContext::defaultContext()); + + /** + * This contructors create scalar array containing string utf8 + * + */ + NDArray(const char* str, sd::DataType dtype = sd::DataType::UTF8, + sd::LaunchContext* context = sd::LaunchContext::defaultContext()) + : NDArray(std::string(str), dtype, context) {} + NDArray(const std::string& string, sd::DataType dtype = sd::DataType::UTF8, + sd::LaunchContext* context = sd::LaunchContext::defaultContext()); + + /** + * This contructors create scalar array containing string utf16 + * + */ + NDArray(const char16_t* u16string, sd::DataType dtype = sd::DataType::UTF16, + sd::LaunchContext* context = sd::LaunchContext::defaultContext()) + : NDArray(std::u16string(u16string), dtype, context) {} + + NDArray(const std::u16string& u16string, + sd::DataType dtype = sd::DataType::UTF16, + sd::LaunchContext* context = sd::LaunchContext::defaultContext()); + + /** + * This contructors create scalar array containing string utf32 + * + */ + NDArray(const char32_t* u32string, sd::DataType dtype = sd::DataType::UTF32, + sd::LaunchContext* context = sd::LaunchContext::defaultContext()) + : NDArray(std::u32string(u32string), dtype, context) {} + + NDArray(const std::u32string& u32string, + sd::DataType dtype = sd::DataType::UTF32, + sd::LaunchContext* context = sd::LaunchContext::defaultContext()); + + /** + * This contructors create array from vector of utf8 strings + * + */ + NDArray(const std::vector& shape, + const std::vector& strings, + sd::DataType dtype = sd::DataType::UTF8, + sd::LaunchContext* context = sd::LaunchContext::defaultContext()); + NDArray(const std::vector& shape, + const std::vector& string, + sd::DataType dtype = sd::DataType::UTF8, + sd::LaunchContext* context = sd::LaunchContext::defaultContext()); + + /** + * This contructors create array from vector of utf16 strings + * + */ + NDArray(const std::vector& shape, + const std::vector& strings, + sd::DataType dtype = sd::DataType::UTF16, + sd::LaunchContext* context = sd::LaunchContext::defaultContext()); + NDArray(const std::vector& shape, + const std::vector& string, + sd::DataType dtype = sd::DataType::UTF16, + sd::LaunchContext* context = sd::LaunchContext::defaultContext()); + + /** + * This contructors create array from vector of utf32 strings + * + */ + NDArray(const std::vector& shape, + const std::vector& strings, + sd::DataType dtype = sd::DataType::UTF32, + sd::LaunchContext* context = sd::LaunchContext::defaultContext()); + NDArray(const std::vector& shape, + const std::vector& string, + sd::DataType dtype = sd::DataType::UTF32, + sd::LaunchContext* context = sd::LaunchContext::defaultContext()); #endif - /** - * do not allocate memory, memory for array is passed from outside - */ - NDArray(void *buffer, Nd4jLong* shapeInfo, sd::LaunchContext* context = sd::LaunchContext::defaultContext(), bool isBuffAlloc = false); - NDArray(void *buffer, const Nd4jLong* shapeInfo, sd::LaunchContext* context = sd::LaunchContext::defaultContext(), bool isBuffAlloc = false); - - /** - * do not allocate memory, memory for array is passed from outside - * we suppose the content of both (device and host) buffers is identical - */ - NDArray(void *buffer, void *bufferD, const Nd4jLong* shapeInfo, sd::LaunchContext* context = sd::LaunchContext::defaultContext(), bool isBuffAlloc = false, bool isBuffDAlloc = false); - - /** - * copy constructor - */ - NDArray(const NDArray& other); - - /** - * move constructor - */ - NDArray(NDArray&& other) noexcept; - - /** - * constructor, create array stored at given workspace - */ - NDArray(sd::LaunchContext * context); - - - /** - * constructor creates new NDArray using shape information from "shapeInfo", set all elements in new array to zeros, if copyStrides is true then use stride values from "shapeInfo", else calculate strides independently - */ - NDArray(const Nd4jLong* shapeInfo, bool copyStrides = false, sd::LaunchContext* context = sd::LaunchContext::defaultContext(), bool nullify = true); - - /** - * constructor creates new NDArray using shape information from "shapeInfo", set all elements in new array to be zeros, if copyStrides is true then use stride values from "shapeInfo", else calculate strides independently - * set dtype as array type - */ - NDArray(const Nd4jLong* shapeInfo, sd::DataType dtype, bool copyStrides = false, sd::LaunchContext* context = sd::LaunchContext::defaultContext(), bool nullify = true); - - /** - * this constructor creates new array using shape information contained in vector argument - */ - NDArray(char order, const std::vector &shape, sd::DataType dtype = DOUBLE, sd::LaunchContext* context = sd::LaunchContext::defaultContext()); - - /** - * This constructor creates new array with elements copied from data and using shape information stored in shape, elements from data will be casted to dtype - */ - NDArray(char order, const std::vector &shape, const std::vector& data, sd::DataType dtype = DOUBLE, sd::LaunchContext* context = sd::LaunchContext::defaultContext()); - - /** - * this constructor creates new array using given buffer (without memory allocation) and shape information stored in shape - */ - NDArray(void *buffer, char order, const std::vector &shape, sd::DataType dtype, sd::LaunchContext* context = sd::LaunchContext::defaultContext(), const bool isBuffAlloc = false); - - /** - * This method returns new array with the same shape & data type - * @return - */ - NDArray like(); - - /** - * This method returns new uninitialized array with the same shape & data type - * @return - */ - NDArray ulike() const; - - - /** - * this constructor creates new NDArray with shape matching "other" array, - * doesn't copy "other" elements into new array !!! - */ - explicit NDArray(const NDArray* other, bool copyStrides = false, sd::LaunchContext* context = sd::LaunchContext ::defaultContext()); - - /** - * this constructor creates scalar(and set its value = 0) or empty array depending on bool argument isScalar - */ - NDArray(sd::DataType dtype, sd::LaunchContext* context = sd::LaunchContext::defaultContext(), bool isScalar = true); - - /** - * This method blocks until asynchronous operation finishes - */ - void synchronize(const char* msg) const; - - /** - * This method allows to set _isAttached flag - * @param reallyAttached - */ - void setAttached(bool reallyAttached); - - void tickWriteHost() const; - void tickWriteDevice() const; - void tickReadHost() const; - void tickReadDevice() const; - void tickBothActual() const; - bool isActualOnHostSide() const; - bool isActualOnDeviceSide() const; - void makeBothBuffersActual() const; - - void syncToHost() const; - void syncToDevice() const; - void syncShape() const; - - /** - * This method can be used on architectures that use special buffers - * @param writeList - * @param readList - */ - static void registerSpecialUse(const std::vector& writeList, const std::vector& readList = {}); - static void prepareSpecialUse(const std::vector& writeList, const std::vector& readList = {}, bool synchronizeWritables = false); - - static void registerPrimaryUse(const std::vector& writeList, const std::vector& readList = {}); - static void preparePrimaryUse(const std::vector& writeList, const std::vector& readList = {}, bool synchronizeWritables = false); - - /** - * This method returns buffer pointer offset by given number of elements, wrt own data type - * @param offset - * @return - */ - void const* bufferWithOffset(Nd4jLong offset) const; - void* bufferWithOffset(Nd4jLong offset); - - void const* specialBufferWithOffset(Nd4jLong offset) const; - void* specialBufferWithOffset(Nd4jLong offset); - /** - * copy assignment operator - * in particular, when _dataType != other._dataType and both shapes are the same, there will be allocation of new _buffer and _dataType acquires other._dataType - */ - NDArray& operator=(const NDArray& other); - - /** - * move assignment operator - */ - NDArray& operator=(NDArray&& other) noexcept; - - /** - * assignment operator, assigns the same scalar to all array elements - */ - template - NDArray& operator=(const T scalar); - - - /** - * operators for memory allocation and deletion - */ - void* operator new(size_t i); - void operator delete(void* p); - - - void setContext(sd::LaunchContext * context); - - /** - * create a new array by replicating current array by repeats times along given dimension - * axis - axis along which to repeat elements - * repeats - number of repetitions - */ - NDArray repeat(const int axis, const std::vector& repeats) const; - - /** - * This method fills this array with zeros - */ - void nullify(); - - /** - * This method returns quantized copy of given array - * - * @param array - * @return - */ - static NDArray quantize(const NDArray &array); - - /** - * fill target array by repeating current array - * axis - axis along which to repeat elements - * repeats - vector containing numbers of repetition for elements at given axis - */ - void repeat(const int axis, const std::vector& repeats, NDArray& target) const; - - /** - * creates array which points on certain sub-range of this array, sub-range is defined by given indices - */ - NDArray subarray(IndicesList& indices) const; - NDArray subarray(const std::initializer_list& idx) const; - NDArray subarray(const Intervals& idx) const; - - /** - * cast array elements to given dtype - */ - NDArray cast(DataType dtype) const; - - void cast(NDArray& target, DataType dtype); - - /** - * returns _context - */ - sd::LaunchContext * getContext() const { - return _context; - } + /** + * do not allocate memory, memory for array is passed from outside + */ + NDArray(void* buffer, Nd4jLong* shapeInfo, + sd::LaunchContext* context = sd::LaunchContext::defaultContext(), + bool isBuffAlloc = false); + NDArray(void* buffer, const Nd4jLong* shapeInfo, + sd::LaunchContext* context = sd::LaunchContext::defaultContext(), + bool isBuffAlloc = false); + + /** + * do not allocate memory, memory for array is passed from outside + * we suppose the content of both (device and host) buffers is identical + */ + NDArray(void* buffer, void* bufferD, const Nd4jLong* shapeInfo, + sd::LaunchContext* context = sd::LaunchContext::defaultContext(), + bool isBuffAlloc = false, bool isBuffDAlloc = false); + + /** + * copy constructor + */ + NDArray(const NDArray& other); + + /** + * move constructor + */ + NDArray(NDArray&& other) noexcept; + + /** + * constructor, create array stored at given workspace + */ + NDArray(sd::LaunchContext* context); + + /** + * constructor creates new NDArray using shape information from "shapeInfo", + * set all elements in new array to zeros, if copyStrides is true then use + * stride values from "shapeInfo", else calculate strides independently + */ + NDArray(const Nd4jLong* shapeInfo, bool copyStrides = false, + sd::LaunchContext* context = sd::LaunchContext::defaultContext(), + bool nullify = true); + + /** + * constructor creates new NDArray using shape information from "shapeInfo", + * set all elements in new array to be zeros, if copyStrides is true then use + * stride values from "shapeInfo", else calculate strides independently set + * dtype as array type + */ + NDArray(const Nd4jLong* shapeInfo, sd::DataType dtype, + bool copyStrides = false, + sd::LaunchContext* context = sd::LaunchContext::defaultContext(), + bool nullify = true); + + /** + * this constructor creates new array using shape information contained in + * vector argument + */ + NDArray(char order, const std::vector& shape, + sd::DataType dtype = DOUBLE, + sd::LaunchContext* context = sd::LaunchContext::defaultContext()); + + /** + * This constructor creates new array with elements copied from data and using + * shape information stored in shape, elements from data will be casted to + * dtype + */ + NDArray(char order, const std::vector& shape, + const std::vector& data, sd::DataType dtype = DOUBLE, + sd::LaunchContext* context = sd::LaunchContext::defaultContext()); + + /** + * this constructor creates new array using given buffer (without memory + * allocation) and shape information stored in shape + */ + NDArray(void* buffer, char order, const std::vector& shape, + sd::DataType dtype, + sd::LaunchContext* context = sd::LaunchContext::defaultContext(), + const bool isBuffAlloc = false); + + /** + * This method returns new array with the same shape & data type + * @return + */ + NDArray like(); + + /** + * This method returns new uninitialized array with the same shape & data type + * @return + */ + NDArray ulike() const; + + /** + * this constructor creates new NDArray with shape matching "other" array, + * doesn't copy "other" elements into new array !!! + */ + explicit NDArray( + const NDArray* other, bool copyStrides = false, + sd::LaunchContext* context = sd::LaunchContext ::defaultContext()); + + /** + * this constructor creates scalar(and set its value = 0) or empty array + * depending on bool argument isScalar + */ + NDArray(sd::DataType dtype, + sd::LaunchContext* context = sd::LaunchContext::defaultContext(), + bool isScalar = true); + + /** + * This method blocks until asynchronous operation finishes + */ + void synchronize(const char* msg) const; + + /** + * This method allows to set _isAttached flag + * @param reallyAttached + */ + void setAttached(bool reallyAttached); + + void tickWriteHost() const; + void tickWriteDevice() const; + void tickReadHost() const; + void tickReadDevice() const; + void tickBothActual() const; + bool isActualOnHostSide() const; + bool isActualOnDeviceSide() const; + void makeBothBuffersActual() const; + + void syncToHost() const; + void syncToDevice() const; + void syncShape() const; + + /** + * This method can be used on architectures that use special buffers + * @param writeList + * @param readList + */ + static void registerSpecialUse(const std::vector& writeList, + const std::vector& readList = {}); + static void prepareSpecialUse(const std::vector& writeList, + const std::vector& readList = {}, + bool synchronizeWritables = false); + + static void registerPrimaryUse(const std::vector& writeList, + const std::vector& readList = {}); + static void preparePrimaryUse(const std::vector& writeList, + const std::vector& readList = {}, + bool synchronizeWritables = false); + + /** + * This method returns buffer pointer offset by given number of elements, wrt + * own data type + * @param offset + * @return + */ + void const* bufferWithOffset(Nd4jLong offset) const; + void* bufferWithOffset(Nd4jLong offset); + + void const* specialBufferWithOffset(Nd4jLong offset) const; + void* specialBufferWithOffset(Nd4jLong offset); + /** + * copy assignment operator + * in particular, when _dataType != other._dataType and both shapes are the + * same, there will be allocation of new _buffer and _dataType acquires + * other._dataType + */ + NDArray& operator=(const NDArray& other); + + /** + * move assignment operator + */ + NDArray& operator=(NDArray&& other) noexcept; + + /** + * assignment operator, assigns the same scalar to all array elements + */ + template + NDArray& operator=(const T scalar); + + /** + * operators for memory allocation and deletion + */ + void* operator new(size_t i); + void operator delete(void* p); + + void setContext(sd::LaunchContext* context); + + /** + * create a new array by replicating current array by repeats times along + * given dimension axis - axis along which to repeat elements repeats - number + * of repetitions + */ + NDArray repeat(const int axis, const std::vector& repeats) const; + + /** + * This method fills this array with zeros + */ + void nullify(); + + /** + * This method returns quantized copy of given array + * + * @param array + * @return + */ + static NDArray quantize(const NDArray& array); + + /** + * fill target array by repeating current array + * axis - axis along which to repeat elements + * repeats - vector containing numbers of repetition for elements at given + * axis + */ + void repeat(const int axis, const std::vector& repeats, + NDArray& target) const; + + /** + * creates array which points on certain sub-range of this array, sub-range + * is defined by given indices + */ + NDArray subarray(IndicesList& indices) const; + NDArray subarray(const std::initializer_list& idx) const; + NDArray subarray(const Intervals& idx) const; + + /** + * cast array elements to given dtype + */ + NDArray cast(DataType dtype) const; + + void cast(NDArray& target, DataType dtype); + + /** + * returns _context + */ + sd::LaunchContext* getContext() const { return _context; } #ifndef __JAVACPP_HACK__ - FORCEINLINE std::shared_ptr getDataBuffer() const; - FORCEINLINE std::shared_ptr dataBuffer(); + FORCEINLINE std::shared_ptr getDataBuffer() const; + FORCEINLINE std::shared_ptr dataBuffer(); #endif - /** - * returns host buffer - */ - FORCEINLINE void* buffer(); - FORCEINLINE const void* buffer() const; - - - /** - * returns buffer offset (offset is the same for host and device buffers) - */ - FORCEINLINE Nd4jLong bufferOffset() const; - - /** - * if _bufferD==nullptr return _buffer, else return _bufferD - */ - void* specialBuffer(); - const void* specialBuffer() const; - - /** - * returns device buffer if compilation is for cuda case, otherwise returns host buffer - */ - void* platformBuffer(); - const void* platformBuffer() const; - - - - template - T* bufferAsT(); - - template - const T* bufferAsT() const; - - /** - * returns _shapeInfo - */ - FORCEINLINE const Nd4jLong* shapeInfo() const; - - - /** - * Returns True if it's legally empty NDArray, or false otherwise - * @return - */ - FORCEINLINE bool isEmpty() const; - - /** - * if _shapeInfoD==nullptr return _shapeInfo, else return _shapeInfoD - */ - FORCEINLINE const Nd4jLong* specialShapeInfo() const; - - const Nd4jLong* platformShapeInfo() const; - - /** - * permutes (in-place) the dimensions in array according to "dimensions" array - */ - bool permutei(const std::initializer_list& dimensions); - bool permutei(const std::vector& dimensions); - bool permutei(const int* dimensions, const int rank); - - bool permutei(const std::initializer_list& dimensions); - bool permutei(const std::vector& dimensions); - bool permutei(const Nd4jLong* dimensions, const int rank); - - bool isFinite(); - bool hasNaNs(); - bool hasInfs(); - - void copyBuffersContinuouslyFrom(const NDArray& other, size_t sizeToCopyInBytes = 0, Nd4jLong offsetThis = 0, Nd4jLong offsetOther = 0); - - /** - * permutes the dimensions in array according to "dimensions" array, new array points on _buffer of this array - */ - NDArray permute(const std::initializer_list& dimensions) const &; - NDArray permute(const std::vector& dimensions) const &; - NDArray permute(const int* dimensions, const int rank) const &; - NDArray permute(const std::initializer_list& dimensions) &&; - NDArray permute(const std::vector& dimensions) &&; - NDArray permute(const int* dimensions, const int rank) &&; - - void permute(const int* dimensions, const int rank, NDArray& target) const; - void permute(const std::vector& dimensions, NDArray& target) const; - - NDArray permute(const std::initializer_list& dimensions) const &; - NDArray permute(const std::vector& dimensions) const &; - NDArray permute(const Nd4jLong* dimensions, const int rank) const &; - NDArray permute(const std::initializer_list& dimensions) &&; - NDArray permute(const std::vector& dimensions) &&; - NDArray permute(const Nd4jLong* dimensions, const int rank) &&; - - void permute(const Nd4jLong* dimensions, const int rank, NDArray& target) const; - void permute(const std::vector& dimensions, NDArray& target) const; - - /** - * This method streamlines given view or permuted array, and reallocates buffer - */ - void streamline(char order = 'a'); - - /** - * prints information about array shape - * msg - message to print out - */ - void printShapeInfo(const char * msg = nullptr) const; - - /** - * prints buffer elements - * msg - message to print out - * limit - number of array elements to print out - * sync - if true check whether host buffer is actual, if it is not then make it so - */ - void printBuffer(const char* msg = nullptr, Nd4jLong limit = -1, const bool sync = true) const; - - /** - * print element by element consequently in a way they (elements) are stored in physical memory - */ - void printLinearBuffer() const; - - /** - * prints _buffer (if host = true) or _bufferD (if host = false) as it is, that is in current state without checking buffer status - */ - template - void printCurrentBuffer(const bool host = true, const char* msg = nullptr, const int precision = 1) const; - - /** - * prints buffer elements, takes into account offset between elements (element-wise-stride) - * msg - message to print out - * limit - number of array elements to print out - */ - void printIndexedBuffer(const char* msg = nullptr, Nd4jLong limit = -1) const; - - std::string asIndexedString(Nd4jLong limit = -1); - std::string asString(Nd4jLong limit = -1); - - /** - * this method assigns values of given array to this one - */ - void assign(const NDArray* other, bool allowParallelism = true); - - /** - * this method assigns values of given array to this one - */ - void assign(const NDArray& other, bool allowParallelism = true); - - /** - * this method assigns given value to all elements in array - */ - template ::value>::type> - void assign(const T& value, bool allowParallelism = true); - - /** - * returns new copy of this array, optionally in different order - */ - NDArray dup(const char newOrder = 'a') const; - - /** - * returns sum of all elements of array - */ - NDArray sumNumber() const; - - /** - * returns mean number of array - */ - NDArray meanNumber() const; + /** + * returns host buffer + */ + FORCEINLINE void* buffer(); + FORCEINLINE const void* buffer() const; + + /** + * returns buffer offset (offset is the same for host and device buffers) + */ + FORCEINLINE Nd4jLong bufferOffset() const; + + /** + * if _bufferD==nullptr return _buffer, else return _bufferD + */ + void* specialBuffer(); + const void* specialBuffer() const; + + /** + * returns device buffer if compilation is for cuda case, otherwise returns + * host buffer + */ + void* platformBuffer(); + const void* platformBuffer() const; + + template + T* bufferAsT(); + + template + const T* bufferAsT() const; + + /** + * returns _shapeInfo + */ + FORCEINLINE const Nd4jLong* shapeInfo() const; + + /** + * Returns True if it's legally empty NDArray, or false otherwise + * @return + */ + FORCEINLINE bool isEmpty() const; + + /** + * if _shapeInfoD==nullptr return _shapeInfo, else return _shapeInfoD + */ + FORCEINLINE const Nd4jLong* specialShapeInfo() const; + + const Nd4jLong* platformShapeInfo() const; + + /** + * permutes (in-place) the dimensions in array according to "dimensions" + * array + */ + bool permutei(const std::initializer_list& dimensions); + bool permutei(const std::vector& dimensions); + bool permutei(const int* dimensions, const int rank); + + bool permutei(const std::initializer_list& dimensions); + bool permutei(const std::vector& dimensions); + bool permutei(const Nd4jLong* dimensions, const int rank); + + bool isFinite(); + bool hasNaNs(); + bool hasInfs(); + + void copyBuffersContinuouslyFrom(const NDArray& other, + size_t sizeToCopyInBytes = 0, + Nd4jLong offsetThis = 0, + Nd4jLong offsetOther = 0); + + /** + * permutes the dimensions in array according to "dimensions" array, new + * array points on _buffer of this array + */ + NDArray permute(const std::initializer_list& dimensions) const&; + NDArray permute(const std::vector& dimensions) const&; + NDArray permute(const int* dimensions, const int rank) const&; + NDArray permute(const std::initializer_list& dimensions) &&; + NDArray permute(const std::vector& dimensions) &&; + NDArray permute(const int* dimensions, const int rank) &&; + + void permute(const int* dimensions, const int rank, NDArray& target) const; + void permute(const std::vector& dimensions, NDArray& target) const; + + NDArray permute(const std::initializer_list& dimensions) const&; + NDArray permute(const std::vector& dimensions) const&; + NDArray permute(const Nd4jLong* dimensions, const int rank) const&; + NDArray permute(const std::initializer_list& dimensions) &&; + NDArray permute(const std::vector& dimensions) &&; + NDArray permute(const Nd4jLong* dimensions, const int rank) &&; + + void permute(const Nd4jLong* dimensions, const int rank, + NDArray& target) const; + void permute(const std::vector& dimensions, NDArray& target) const; + + /** + * This method streamlines given view or permuted array, and reallocates + * buffer + */ + void streamline(char order = 'a'); + + /** + * prints information about array shape + * msg - message to print out + */ + void printShapeInfo(const char* msg = nullptr) const; + + /** + * prints buffer elements + * msg - message to print out + * limit - number of array elements to print out + * sync - if true check whether host buffer is actual, if it is not then make + * it so + */ + void printBuffer(const char* msg = nullptr, Nd4jLong limit = -1, + const bool sync = true) const; + + /** + * make strings for current ndarray - linear or structured by dimensions + * */ + std::string linearString(Nd4jLong limit = -1) const; + std::string indexedBufferString(Nd4jLong limit = -1) const; + + /** + * print element by element consequently in a way they (elements) are stored + * in physical memory + */ + void printLinearBuffer() const; + + /** + * prints _buffer (if host = true) or _bufferD (if host = false) as it is, + * that is in current state without checking buffer status + */ + template + void printCurrentBuffer(const bool host = true, const char* msg = nullptr, + const int precision = 1) const; + + /** + * prints buffer elements, takes into account offset between elements + * (element-wise-stride) msg - message to print out limit - number of array + * elements to print out + */ + void printIndexedBuffer(const char* msg = nullptr, Nd4jLong limit = -1) const; + + std::string asIndexedString(Nd4jLong limit = -1) const; + std::string asString(Nd4jLong limit = -1) const; + + /** + * this method assigns values of given array to this one + */ + void assign(const NDArray* other, bool allowParallelism = true); + + /** + * this method assigns values of given array to this one + */ + void assign(const NDArray& other, bool allowParallelism = true); + + /** + * this method assigns given value to all elements in array + */ + template ::value>::type> + void assign(const T& value, bool allowParallelism = true); + + /** + * returns new copy of this array, optionally in different order + */ + NDArray dup(const char newOrder = 'a') const; + + /** + * returns sum of all elements of array + */ + NDArray sumNumber() const; + + /** + * returns mean number of array + */ + NDArray meanNumber() const; #ifndef __JAVACPP_HACK__ - /** - * This method explicitly enforces new shape for this NDArray, old shape/stride information is lost - */ - void enforce(const std::initializer_list &dimensions, char order = 'a'); - void enforce(std::vector &dimensions, char order = 'a'); - - - /** - * method reduces array by excluding its shapes along dimensions present in given dimensions vector, result is stored in new array to be returned - * dimensions - array of dimensions to reduce along - * keepDims - if true then put unities in place of reduced dimensions - */ - - NDArray reduceAlongDimension(sd::reduce::FloatOps op, const std::vector& dimensions, const bool keepDims = false, const bool supportOldShapes = false) const; - NDArray reduceAlongDimension(sd::reduce::FloatOps op, const std::initializer_list& dimensions, const bool keepDims = false, const bool supportOldShapes = false) const; - - NDArray reduceAlongDimension(sd::reduce::SameOps op, const std::vector& dimensions, const bool keepDims = false, const bool supportOldShapes = false) const; - NDArray reduceAlongDimension(sd::reduce::SameOps op, const std::initializer_list& dimensions, const bool keepDims = false, const bool supportOldShapes = false) const; - - NDArray reduceAlongDimension(sd::reduce::BoolOps op, const std::vector& dimensions, const bool keepDims = false, const bool supportOldShapes = false) const; - NDArray reduceAlongDimension(sd::reduce::BoolOps op, const std::initializer_list& dimensions, const bool keepDims = false, const bool supportOldShapes = false) const; - - NDArray reduceAlongDimension(sd::reduce::LongOps op, const std::vector& dimensions, const bool keepDims = false, const bool supportOldShapes = false) const; - NDArray reduceAlongDimension(sd::reduce::LongOps op, const std::initializer_list& dimensions, const bool keepDims = false, const bool supportOldShapes = false) const; - - /** - * method reduces array by excluding its shapes along dimensions present in given dimensions vector - * target - where to save result of reducing - * dimensions - array of dimensions to reduce along - * keepDims - if true then put unities in place of reduced dimensions - * extras - extra parameters - */ - void reduceAlongDimension(sd::reduce::FloatOps op, NDArray& target, const std::vector& dimensions, const bool keepDims = false, const bool supportOldShapes = false, const bool checkTargetShape = true) const; - void reduceAlongDimension(sd::reduce::SameOps op, NDArray& target, const std::vector& dimensions, const bool keepDims = false, const bool supportOldShapes = false, const bool checkTargetShape = true) const; - void reduceAlongDimension(sd::reduce::BoolOps op, NDArray& target, const std::vector& dimensions, const bool keepDims = false, const bool supportOldShapes = false, const bool checkTargetShape = true) const; - void reduceAlongDimension(sd::reduce::LongOps op, NDArray& target, const std::vector& dimensions, const bool keepDims = false, const bool supportOldShapes = false, const bool checkTargetShape = true) const; - - /** - * return variance of array elements set - * biasCorrected - if true bias correction will be applied - */ - NDArray varianceNumber(sd::variance::Ops op, bool biasCorrected = true); - - /** - * apply scalar operation to array - * extraParams - extra parameters for operation - * returns scalar array - */ - NDArray reduceNumber(sd::reduce::FloatOps ops, void *extraParams = nullptr) const; - NDArray reduceNumber(sd::reduce::SameOps ops, void *extraParams = nullptr) const; - NDArray reduceNumber(sd::reduce::BoolOps ops, void *extraParams = nullptr) const; - NDArray reduceNumber(sd::reduce::LongOps ops, void *extraParams = nullptr) const; - - void reduceNumber(sd::reduce::FloatOps ops, NDArray& target, void *extraParams = nullptr) const; - void reduceNumber(sd::reduce::SameOps ops, NDArray& target, void *extraParams = nullptr) const; - void reduceNumber(sd::reduce::BoolOps ops, NDArray& target, void *extraParams = nullptr) const; - void reduceNumber(sd::reduce::LongOps ops, NDArray& target, void *extraParams = nullptr) const; - - /** - * returns element index which corresponds to some condition imposed by operation - * extraParams - extra parameters for operation - */ - NDArray indexReduceNumber(sd::indexreduce::Ops op, ExtraArguments *extraParams = nullptr); - - /** - * returns index of max element in a given array (optionally: along given dimension(s)) - * dimensions - optional vector with dimensions - */ - Nd4jLong argMax(std::initializer_list dimensions = {}); - - // FIXME: remove this method eventually - void makeBothActual() const { syncToDevice(); syncToHost(); } - - - void applyTransform(sd::transform::FloatOps op, NDArray& target, ExtraArguments *extraParams = nullptr); - void applyTransform(sd::transform::SameOps op, NDArray& target, ExtraArguments *extraParams = nullptr); - void applyTransform(sd::transform::AnyOps op, NDArray& target, ExtraArguments *extraParams = nullptr); - void applyTransform(sd::transform::BoolOps op, NDArray& target, ExtraArguments *extraParams = nullptr); - void applyTransform(sd::transform::StrictOps op, NDArray& target, ExtraArguments *extraParams = nullptr); - - /** - * apply OpName transformation to this array and store result in new array to be returned - * extraParams - extra parameters for operation - */ - NDArray transform(sd::transform::FloatOps op, void *extraParams = nullptr) const &; - NDArray transform(sd::transform::SameOps op, void *extraParams = nullptr) const &; - NDArray transform(sd::transform::BoolOps op, void *extraParams = nullptr) const &; - NDArray transform(sd::transform::StrictOps op, void *extraParams = nullptr) const &; - NDArray transform(sd::transform::FloatOps op, void *extraParams = nullptr) &&; - NDArray transform(sd::transform::SameOps op, void *extraParams = nullptr) &&; - NDArray transform(sd::transform::BoolOps op, void *extraParams = nullptr) &&; - NDArray transform(sd::transform::StrictOps op, void *extraParams = nullptr) &&; - - /** - * apply pairwise OpName transformation based on "this" and "other" arras elements, store result in this array - * other - second array necessary for pairwise operation - * extraParams - extra parameters for operation - */ - void applyPairwiseTransform(sd::pairwise::Ops op, const NDArray& other, ExtraArguments *extraParams = nullptr); - - /** - * apply pairwise OpName transformation based on "this" and "other" arras elements, store result in target array - * other - second array necessary for pairwise operation - * target - where to store result - * extraParams - extra parameters for operation - */ - void applyPairwiseTransform(sd::pairwise::Ops op, const NDArray& other, NDArray& target, ExtraArguments *extraParams = nullptr) const; - - void applyPairwiseTransform(sd::pairwise::BoolOps op, const NDArray& other, NDArray& target, ExtraArguments *extraParams = nullptr) const; - - void applyPairwiseTransform(sd::pairwise::IntOps op, const NDArray& other, NDArray&target, ExtraArguments *extraParams = nullptr) const; - - /** - * apply operation which requires broadcasting, broadcast a smaller array (tad) along bigger one (this) - * tad - array to broadcast - * dimensions - dimensions array to broadcast along - * target - where to store result - * extraParams - extra parameters for operation - */ - void applyBroadcast(sd::broadcast::Ops op, const std::initializer_list dimensions, const NDArray& tad, NDArray& target, ExtraArguments* extraArgs = nullptr); - - void applyBroadcast(sd::broadcast::Ops op, const std::vector &dimensions, const NDArray &tad, NDArray &target, ExtraArguments *extraArgs = nullptr); - - void applyBroadcast(sd::broadcast::BoolOps op, const std::vector &dimensions, const NDArray &tad, NDArray &target, ExtraArguments *extraArgs = nullptr); - - void applyBroadcast(sd::broadcast::IntOps op, const std::vector &dimensions, const NDArray& tad, NDArray &target, ExtraArguments *extraArgs = nullptr); - - /** - * apply operation which requires broadcasting, broadcast one tensor along another, also this method checks the possibility of broadcasting - * other - input array - * extraParams - extra parameters for operation - */ - NDArray applyTrueBroadcast(sd::BroadcastOpsTuple op, const NDArray& other, ExtraArguments *extraArgs = nullptr) const &; - NDArray applyTrueBroadcast(sd::BroadcastOpsTuple op, NDArray&& other, ExtraArguments *extraArgs = nullptr) const &; - NDArray applyTrueBroadcast(sd::BroadcastOpsTuple op, NDArray&& other, ExtraArguments *extraArgs = nullptr) &&; - NDArray applyTrueBroadcast(sd::BroadcastOpsTuple op, const NDArray& other, ExtraArguments *extraArgs = nullptr) &&; - - /** - * apply operation which requires broadcasting, broadcast one tensor along another, also this method checks the possibility of broadcasting - * other - input array - * target - where to store result - * checkTargetShape - if true check whether target shape is suitable for broadcasting - * extraParams - extra parameters for operation - */ - void applyTrueBroadcast(sd::BroadcastOpsTuple op, const NDArray& other, NDArray& target, const bool checkTargetShape = true, ExtraArguments *extraArgs = nullptr) const; - - void applyTrueBroadcast(sd::BroadcastBoolOpsTuple op, const NDArray& other, NDArray& target, const bool checkTargetShape = true, ExtraArguments *extraArgs = nullptr) const; - - void applyTrueBroadcast(sd::BroadcastIntOpsTuple op, const NDArray& other, NDArray& target, const bool checkTargetShape = true, ExtraArguments *extraArgs = nullptr) const; - - - /** - * apply a scalar operation to an array - * scalar - input scalar - * target - where to store result - * extraParams - extra parameters for operation - */ - template - void applyScalar(sd::scalar::Ops op, const T scalar, NDArray& target, ExtraArguments *extraParams = nullptr); - - template - void applyScalar(sd::scalar::BoolOps op, const T scalar, NDArray& target, ExtraArguments *extraParams = nullptr) const; - - template - void applyScalar(sd::scalar::IntOps op, const T scalar, NDArray& target, ExtraArguments *extraParams = nullptr) const; - - /** - * apply a scalar operation to an array - * scalar - input array which is simple scalar - * target - where to store result - * extraParams - extra parameters for operation - */ - void applyScalarArr(sd::scalar::Ops op, const NDArray& scalar, NDArray& target, ExtraArguments *extraParams = nullptr); - - void applyScalarArr(sd::scalar::BoolOps op, const NDArray& scalar, NDArray& target, ExtraArguments *extraParams = nullptr) const; - - void applyScalarArr(sd::scalar::IntOps op, const NDArray& scalar, NDArray& target, ExtraArguments *extraParams = nullptr) const; - -#if defined(__CUDABLAS__) //&& defined(BUILD_TESTS) - template - FORCEINLINE void applyLambda(Lambda func, NDArray& target); - - template - FORCEINLINE void applyPairwiseLambda(const NDArray& other, Lambda func, NDArray& target); - - template - FORCEINLINE void applyIndexedLambda(Lambda func, NDArray& target); - - template - FORCEINLINE void applyIndexedPairwiseLambda(NDArray& other, Lambda func, NDArray& target); - - template - FORCEINLINE void applyTriplewiseLambda(NDArray& second, NDArray& third, Lambda func, NDArray& target); + /** + * This method explicitly enforces new shape for this NDArray, old + * shape/stride information is lost + */ + void enforce(const std::initializer_list& dimensions, + char order = 'a'); + void enforce(std::vector& dimensions, char order = 'a'); + + /** + * method reduces array by excluding its shapes along dimensions present in + * given dimensions vector, result is stored in new array to be returned + * dimensions - array of dimensions to reduce along + * keepDims - if true then put unities in place of reduced dimensions + */ + + NDArray reduceAlongDimension(sd::reduce::FloatOps op, + const std::vector& dimensions, + const bool keepDims = false, + const bool supportOldShapes = false) const; + NDArray reduceAlongDimension(sd::reduce::FloatOps op, + const std::initializer_list& dimensions, + const bool keepDims = false, + const bool supportOldShapes = false) const; + + NDArray reduceAlongDimension(sd::reduce::SameOps op, + const std::vector& dimensions, + const bool keepDims = false, + const bool supportOldShapes = false) const; + NDArray reduceAlongDimension(sd::reduce::SameOps op, + const std::initializer_list& dimensions, + const bool keepDims = false, + const bool supportOldShapes = false) const; + + NDArray reduceAlongDimension(sd::reduce::BoolOps op, + const std::vector& dimensions, + const bool keepDims = false, + const bool supportOldShapes = false) const; + NDArray reduceAlongDimension(sd::reduce::BoolOps op, + const std::initializer_list& dimensions, + const bool keepDims = false, + const bool supportOldShapes = false) const; + + NDArray reduceAlongDimension(sd::reduce::LongOps op, + const std::vector& dimensions, + const bool keepDims = false, + const bool supportOldShapes = false) const; + NDArray reduceAlongDimension(sd::reduce::LongOps op, + const std::initializer_list& dimensions, + const bool keepDims = false, + const bool supportOldShapes = false) const; + + /** + * method reduces array by excluding its shapes along dimensions present in + * given dimensions vector target - where to save result of reducing + * dimensions - array of dimensions to reduce along + * keepDims - if true then put unities in place of reduced dimensions + * extras - extra parameters + */ + void reduceAlongDimension(sd::reduce::FloatOps op, NDArray& target, + const std::vector& dimensions, + const bool keepDims = false, + const bool supportOldShapes = false, + const bool checkTargetShape = true) const; + void reduceAlongDimension(sd::reduce::SameOps op, NDArray& target, + const std::vector& dimensions, + const bool keepDims = false, + const bool supportOldShapes = false, + const bool checkTargetShape = true) const; + void reduceAlongDimension(sd::reduce::BoolOps op, NDArray& target, + const std::vector& dimensions, + const bool keepDims = false, + const bool supportOldShapes = false, + const bool checkTargetShape = true) const; + void reduceAlongDimension(sd::reduce::LongOps op, NDArray& target, + const std::vector& dimensions, + const bool keepDims = false, + const bool supportOldShapes = false, + const bool checkTargetShape = true) const; + + /** + * return variance of array elements set + * biasCorrected - if true bias correction will be applied + */ + NDArray varianceNumber(sd::variance::Ops op, bool biasCorrected = true); + + /** + * apply scalar operation to array + * extraParams - extra parameters for operation + * returns scalar array + */ + NDArray reduceNumber(sd::reduce::FloatOps ops, + void* extraParams = nullptr) const; + NDArray reduceNumber(sd::reduce::SameOps ops, + void* extraParams = nullptr) const; + NDArray reduceNumber(sd::reduce::BoolOps ops, + void* extraParams = nullptr) const; + NDArray reduceNumber(sd::reduce::LongOps ops, + void* extraParams = nullptr) const; + + void reduceNumber(sd::reduce::FloatOps ops, NDArray& target, + void* extraParams = nullptr) const; + void reduceNumber(sd::reduce::SameOps ops, NDArray& target, + void* extraParams = nullptr) const; + void reduceNumber(sd::reduce::BoolOps ops, NDArray& target, + void* extraParams = nullptr) const; + void reduceNumber(sd::reduce::LongOps ops, NDArray& target, + void* extraParams = nullptr) const; + + /** + * returns element index which corresponds to some condition imposed by + * operation extraParams - extra parameters for operation + */ + NDArray indexReduceNumber(sd::indexreduce::Ops op, + ExtraArguments* extraParams = nullptr); + + /** + * returns index of max element in a given array (optionally: along given + * dimension(s)) dimensions - optional vector with dimensions + */ + Nd4jLong argMax(std::initializer_list dimensions = {}); + + void applyTransform(sd::transform::FloatOps op, NDArray& target, + ExtraArguments* extraParams = nullptr); + void applyTransform(sd::transform::SameOps op, NDArray& target, + ExtraArguments* extraParams = nullptr); + void applyTransform(sd::transform::AnyOps op, NDArray& target, + ExtraArguments* extraParams = nullptr); + void applyTransform(sd::transform::BoolOps op, NDArray& target, + ExtraArguments* extraParams = nullptr); + void applyTransform(sd::transform::StrictOps op, NDArray& target, + ExtraArguments* extraParams = nullptr); + + /** + * apply OpName transformation to this array and store result in new array to + * be returned extraParams - extra parameters for operation + */ + NDArray transform(sd::transform::FloatOps op, + void* extraParams = nullptr) const&; + NDArray transform(sd::transform::SameOps op, + void* extraParams = nullptr) const&; + NDArray transform(sd::transform::BoolOps op, + void* extraParams = nullptr) const&; + NDArray transform(sd::transform::StrictOps op, + void* extraParams = nullptr) const&; + NDArray transform(sd::transform::FloatOps op, void* extraParams = nullptr) &&; + NDArray transform(sd::transform::SameOps op, void* extraParams = nullptr) &&; + NDArray transform(sd::transform::BoolOps op, void* extraParams = nullptr) &&; + NDArray transform(sd::transform::StrictOps op, + void* extraParams = nullptr) &&; + + /** + * apply pairwise OpName transformation based on "this" and "other" arras + * elements, store result in this array other - second array necessary for + * pairwise operation extraParams - extra parameters for operation + */ + void applyPairwiseTransform(sd::pairwise::Ops op, const NDArray& other, + ExtraArguments* extraParams = nullptr); + + /** + * apply pairwise OpName transformation based on "this" and "other" arras + * elements, store result in target array other - second array necessary for + * pairwise operation target - where to store result extraParams - extra + * parameters for operation + */ + void applyPairwiseTransform(sd::pairwise::Ops op, const NDArray& other, + NDArray& target, + ExtraArguments* extraParams = nullptr) const; + + void applyPairwiseTransform(sd::pairwise::BoolOps op, const NDArray& other, + NDArray& target, + ExtraArguments* extraParams = nullptr) const; + + void applyPairwiseTransform(sd::pairwise::IntOps op, const NDArray& other, + NDArray& target, + ExtraArguments* extraParams = nullptr) const; + + /** + * apply operation which requires broadcasting, broadcast a smaller array + * (tad) along bigger one (this) tad - array to broadcast dimensions - + * dimensions array to broadcast along target - where to store result + * extraParams - extra parameters for operation + */ + void applyBroadcast(sd::broadcast::Ops op, + const std::initializer_list dimensions, + const NDArray& tad, NDArray& target, + ExtraArguments* extraArgs = nullptr); + + void applyBroadcast(sd::broadcast::Ops op, const std::vector& dimensions, + const NDArray& tad, NDArray& target, + ExtraArguments* extraArgs = nullptr); + + void applyBroadcast(sd::broadcast::BoolOps op, + const std::vector& dimensions, const NDArray& tad, + NDArray& target, ExtraArguments* extraArgs = nullptr); + + void applyBroadcast(sd::broadcast::IntOps op, + const std::vector& dimensions, const NDArray& tad, + NDArray& target, ExtraArguments* extraArgs = nullptr); + + /** + * apply operation which requires broadcasting, broadcast one tensor along + * another, also this method checks the possibility of broadcasting other - + * input array extraParams - extra parameters for operation + */ + NDArray applyTrueBroadcast(sd::BroadcastOpsTuple op, const NDArray& other, + ExtraArguments* extraArgs = nullptr) const&; + NDArray applyTrueBroadcast(sd::BroadcastOpsTuple op, NDArray&& other, + ExtraArguments* extraArgs = nullptr) const&; + NDArray applyTrueBroadcast(sd::BroadcastOpsTuple op, NDArray&& other, + ExtraArguments* extraArgs = nullptr) &&; + NDArray applyTrueBroadcast(sd::BroadcastOpsTuple op, const NDArray& other, + ExtraArguments* extraArgs = nullptr) &&; + + /** + * apply operation which requires broadcasting, broadcast one tensor along + * another, also this method checks the possibility of broadcasting other - + * input array target - where to store result checkTargetShape - if true check + * whether target shape is suitable for broadcasting extraParams - extra + * parameters for operation + */ + void applyTrueBroadcast(sd::BroadcastOpsTuple op, const NDArray& other, + NDArray& target, const bool checkTargetShape = true, + ExtraArguments* extraArgs = nullptr) const; + + void applyTrueBroadcast(sd::BroadcastBoolOpsTuple op, const NDArray& other, + NDArray& target, const bool checkTargetShape = true, + ExtraArguments* extraArgs = nullptr) const; + + void applyTrueBroadcast(sd::BroadcastIntOpsTuple op, const NDArray& other, + NDArray& target, const bool checkTargetShape = true, + ExtraArguments* extraArgs = nullptr) const; + + /** + * apply a scalar operation to an array + * scalar - input scalar + * target - where to store result + * extraParams - extra parameters for operation + */ + template + void applyScalar(sd::scalar::Ops op, const T scalar, NDArray& target, + ExtraArguments* extraParams = nullptr); + + template + void applyScalar(sd::scalar::BoolOps op, const T scalar, NDArray& target, + ExtraArguments* extraParams = nullptr) const; + + template + void applyScalar(sd::scalar::IntOps op, const T scalar, NDArray& target, + ExtraArguments* extraParams = nullptr) const; + + /** + * apply a scalar operation to an array + * scalar - input array which is simple scalar + * target - where to store result + * extraParams - extra parameters for operation + */ + void applyScalarArr(sd::scalar::Ops op, const NDArray& scalar, + NDArray& target, ExtraArguments* extraParams = nullptr); + + void applyScalarArr(sd::scalar::BoolOps op, const NDArray& scalar, + NDArray& target, + ExtraArguments* extraParams = nullptr) const; + + void applyScalarArr(sd::scalar::IntOps op, const NDArray& scalar, + NDArray& target, + ExtraArguments* extraParams = nullptr) const; + +#if defined(__CUDABLAS__) //&& defined(BUILD_TESTS) + template + FORCEINLINE void applyLambda(Lambda func, NDArray& target); + + template + FORCEINLINE void applyPairwiseLambda(const NDArray& other, Lambda func, + NDArray& target); + + template + FORCEINLINE void applyIndexedLambda(Lambda func, NDArray& target); + + template + FORCEINLINE void applyIndexedPairwiseLambda(NDArray& other, Lambda func, + NDArray& target); + + template + FORCEINLINE void applyTriplewiseLambda(NDArray& second, NDArray& third, + Lambda func, NDArray& target); #else - /** - * apply operation "func" to an array - * func - what operation to apply - * target - where to store result - */ - template - void applyLambda(const std::function& func, NDArray& target); - - /** - * apply pairwise operation "func" to an array - * other - input array - * func - what pairwise operation to apply - * target - where to store result - */ - template - void applyPairwiseLambda(const NDArray& other, const std::function& func, NDArray& target); - - template - void applyIndexedLambda(const std::function& func, NDArray& target); - - template - void applyIndexedPairwiseLambda(NDArray& other, const std::function& func, NDArray& target); - - template - void applyTriplewiseLambda(NDArray& second, NDArray& third, const std::function& func, NDArray& target); + /** + * apply operation "func" to an array + * func - what operation to apply + * target - where to store result + */ + template + void applyLambda(const std::function& func, NDArray& target); + + /** + * apply pairwise operation "func" to an array + * other - input array + * func - what pairwise operation to apply + * target - where to store result + */ + template + void applyPairwiseLambda(const NDArray& other, + const std::function& func, NDArray& target); + + template + void applyIndexedLambda(const std::function& func, + NDArray& target); + + template + void applyIndexedPairwiseLambda(NDArray& other, + const std::function& func, + NDArray& target); + + template + void applyTriplewiseLambda(NDArray& second, NDArray& third, + const std::function& func, + NDArray& target); #endif - /** - * reduces dimensions in this array relying on index operation OpName - * dimensions - vector of dimensions to reduce along - * extraArgs - extra parameters for operation - */ - NDArray applyIndexReduce(sd::indexreduce::Ops op, const std::vector& dimensions, const ExtraArguments *extraParams = nullptr) const; - - /** - * reduces dimensions in array relying on index operation OpName - * target - where to store result - * dimensions - vector of dimensions to reduce along - * extraArgs - extra parameters for operation - */ - void applyIndexReduce(sd::indexreduce::Ops op, NDArray& target, const std::vector& dimensions, const ExtraArguments *extraParams = nullptr) const; - - /** - * apply reduce3 operation OpName to this and other array, return result in new output array - * other - input array - * extraArgs - extra parameters for operation - */ - NDArray applyReduce3(sd::reduce3::Ops op, const NDArray& other, const ExtraArguments* extraParams = nullptr) const; - - /** - * apply reduce3 operation OpName to this and other array, return result in new output array - * other - input array - * dimensions - vector of dimensions to reduce along (tads not axis) - * extraArgs - extra parameters for operation - */ - NDArray applyAllReduce3(sd::reduce3::Ops op, const NDArray& other, const std::vector& dimensions, const ExtraArguments* extraParams = nullptr) const; - - /** - * apply reduce3 (exec) operation OpName to this and other array, return result in new output array - * other - input array - * dimensions - vector of dimensions to reduce along (same as reduceAlongDimension) - * extraArgs - extra parameters for operation - */ - NDArray applyReduce3(sd::reduce3::Ops op, const NDArray& other, const std::vector& dimensions, const ExtraArguments* extraParams = nullptr) const; - - /** - * returns variance along given dimensions - * biasCorrected - if true bias correction will be applied - * dimensions - vector of dimensions to calculate variance along - */ - NDArray varianceAlongDimension(sd::variance::Ops op, const bool biasCorrected, const std::vector& dimensions) const; - NDArray varianceAlongDimension(sd::variance::Ops op, const bool biasCorrected, const std::initializer_list& dimensions) const; - - void varianceAlongDimension(sd::variance::Ops op, NDArray& target, const bool biasCorrected, const std::vector& dimensions) const; - void varianceAlongDimension(sd::variance::Ops op, NDArray& target, const bool biasCorrected, const std::initializer_list& dimensions) const; + /** + * reduces dimensions in this array relying on index operation OpName + * dimensions - vector of dimensions to reduce along + * extraArgs - extra parameters for operation + */ + NDArray applyIndexReduce(sd::indexreduce::Ops op, + const std::vector& dimensions, + const ExtraArguments* extraParams = nullptr) const; + + /** + * reduces dimensions in array relying on index operation OpName + * target - where to store result + * dimensions - vector of dimensions to reduce along + * extraArgs - extra parameters for operation + */ + void applyIndexReduce(sd::indexreduce::Ops op, NDArray& target, + const std::vector& dimensions, + const ExtraArguments* extraParams = nullptr) const; + + /** + * apply reduce3 operation OpName to this and other array, return result in + * new output array other - input array extraArgs - extra parameters for + * operation + */ + NDArray applyReduce3(sd::reduce3::Ops op, const NDArray& other, + const ExtraArguments* extraParams = nullptr) const; + + /** + * apply reduce3 operation OpName to this and other array, return result in + * new output array other - input array dimensions - vector of dimensions to + * reduce along (tads not axis) extraArgs - extra parameters for operation + */ + NDArray applyAllReduce3(sd::reduce3::Ops op, const NDArray& other, + const std::vector& dimensions, + const ExtraArguments* extraParams = nullptr) const; + + /** + * apply reduce3 (exec) operation OpName to this and other array, return + * result in new output array other - input array dimensions - vector of + * dimensions to reduce along (same as reduceAlongDimension) extraArgs - extra + * parameters for operation + */ + NDArray applyReduce3(sd::reduce3::Ops op, const NDArray& other, + const std::vector& dimensions, + const ExtraArguments* extraParams = nullptr) const; + + /** + * returns variance along given dimensions + * biasCorrected - if true bias correction will be applied + * dimensions - vector of dimensions to calculate variance along + */ + NDArray varianceAlongDimension(sd::variance::Ops op, const bool biasCorrected, + const std::vector& dimensions) const; + NDArray varianceAlongDimension( + sd::variance::Ops op, const bool biasCorrected, + const std::initializer_list& dimensions) const; + + void varianceAlongDimension(sd::variance::Ops op, NDArray& target, + const bool biasCorrected, + const std::vector& dimensions) const; + void varianceAlongDimension( + sd::variance::Ops op, NDArray& target, const bool biasCorrected, + const std::initializer_list& dimensions) const; #endif - /** - * apply transpose operation to the copy of this array, that is this array remains unaffected - */ - NDArray transpose() const &; - NDArray transpose() &&; - - /** - * perform transpose operation and store result in target, this array remains unaffected - * target - where to store result - */ - void transpose(NDArray& target) const; - - /** - * apply in-place transpose operation to this array, so this array becomes transposed - */ - void transposei(); - - /** - * returns the number of arrays pointing on specified dimension(s) - * dimensions - array of dimensions to point on - */ - Nd4jLong tensorsAlongDimension(const std::initializer_list dimensions) const ; - Nd4jLong tensorsAlongDimension(const std::vector& dimensions) const ; - - /** - * returns true if elements of two arrays are equal to within given epsilon value - * other - input array to compare - * eps - epsilon, this value defines the precision of elements comparison - */ - bool equalsTo(const NDArray *other, double eps = 1e-5) const; - bool equalsTo(const NDArray &other, double eps = 1e-5) const; - - /** - * add given row vector to all rows of this array - * row - row vector to add - */ - void addiRowVector(const NDArray& row); - - /** - * add given row vector to all rows of this array, store result in target - * row - row vector to add - * target - where to store result - */ - void addRowVector(const NDArray& row, NDArray& target) const; - - /** - * subtract given row vector from all rows of this array, store result in target - * row - row vector to subtract - * target - where to store result - */ - void subRowVector(const NDArray& row, NDArray& target) const; - - /** - * multiply all rows of this array on given row vector, store result in target - * row - row vector to multiply on - * target - where to store result - */ - void mulRowVector(const NDArray &row, NDArray& target) const; - - /** - * divide all rows of this array on given row vector, store result in target - * row - row vector to divide on - * target - where to store result - */ - void divRowVector(const NDArray &row, NDArray& target) const; - - /** - * add given column vector to all columns of this array, store result in target - * column - column vector to add - * target - where to store result - */ - void addColumnVector(const NDArray &column, NDArray& target) const; - - /** - * add given column vector to all columns of this array, this array becomes affected (in-place operation) - * column - column vector to add - */ - void addiColumnVector(const NDArray &column); - - /** - * multiply all columns of this array on given column vector, this array becomes affected (in-place operation) - * column - column vector to multiply on - */ - void muliColumnVector(const NDArray &column); - - /** - * returns number of bytes used by _buffer & _shapeInfo - */ - FORCEINLINE Nd4jLong memoryFootprint(); - - /** - * these methods suited for FlatBuffers use - */ - template - std::vector getBufferAsVector() const; - std::vector getShapeAsVector() const; - std::vector getShapeAsVectorInt() const; - std::vector getShapeInfoAsVector() const; - std::vector getShapeInfoAsFlatVector() const; - std::vector getShapeAsFlatVector() const; - - /** - * set new order and shape in case of suitable array length (in-place operation) - * order - order to set - * shape - shape to set - * copyToNewBuff - if true then old buffer will be copied to new buffer if last one will be allocated after reshaping - * if there was permute applied before or there are weird strides, then new buffer is allocated for array - */ - bool reshapei(const char order, const std::initializer_list& shape, const bool copyToNewBuff = true); - bool reshapei(const char order, const std::vector& shape, const bool copyToNewBuff = true); - - bool reshapei(const std::initializer_list& shape, const bool copyToNewBuff = true); - bool reshapei(const std::vector& shape, const bool copyToNewBuff = true); - - /** - * creates new array with corresponding order and shape, new array will point on _buffer of this array - * order - order to set - * shape - shape to set - * - * if permute have been applied before or there are weird strides, then new buffer is allocated for new array - */ - NDArray reshape(const char order, const std::vector& shape, const bool copyToNewBuff = true) const &; - NDArray reshape(const char order, const std::vector& shape, const bool copyToNewBuff = true) &&; - - /** - * calculate strides and set given order - * order - order to set - */ - void updateStrides(const char order); - - /** - * change an array by repeating it the number of times given by reps (in-place operation) - * repeats - contains numbers of repetitions - */ - void tilei(const std::vector& repeats); - - /** - * returns new array which is created by repeating of this array the number of times given by reps - * repeats - contains numbers of repetitions - */ - NDArray tile(const std::vector& repeats) const; - - /** - * change an array by repeating it the number of times given by reps (in-place operation) - * repeats - contains numbers of repetitions - * target - where to store result - */ - void tile(const std::vector& repeats, NDArray& target) const; - - /** - * change an array by repeating it the number of times to acquire the new shape which is the same as target shape - * target - where to store result - */ - void tile(NDArray& target) const; - - /** - * check whether array is identity matrix - */ - bool isIdentityMatrix(); - - /** - * check whether array is unitary matrix - */ - bool isUnitary(); - - /** - * operator returns subarray with buffer pointing at this->_buffer with offset defined by given intervals - * idx - intervals of indexes which define the subarrays to point on, idx has form {dim0Start,dim0End, dim1Start,dim1End, ....} and length (2 * this->rankOf()) - * when (dimStart == dimEnd) then whole range will be used for current dimension - * keepUnitiesInShape - if false then eliminate unities from resulting array shape, for example {1,a,1,b} -> {a,b} - * isStrided - if true then idx has length (3 * this->rankOf()) and contains additional stride numbers which correspond to stride between dimStart and dimEnd, - * so structure of idx is like {dim0Start,dim0End,dim0Stride, dim1Start,dim1End,dim1Stride, ....} - */ - NDArray operator()(const std::vector& idx, const bool keepUnitiesInShape = false, const bool isStrided = false) const; - - /** - * evaluates subarray with buffer pointing at this->_buffer and offset defined by given sequential index subArrIdx and dimensions in dimsToExclude - * subArrIdx - index of current sub-array - * dimsToExclude - MUST BE SORTED, dimensions to evaluate sub-array along, i.e. when shape is [2,3,4,5] and dimsToExclude={0,2}, then there will be 8 sub-arrays with shape [3,5], and subArrIdx must be in range [0,7] - * if dimsToExclude is empty then idxRanges containing all zeros (means whole array) will be returned. - * keepUnitiesInShape - if false then eliminate unities from resulting array shape, for example {1,a,1,b} -> {a,b} - */ - NDArray operator()(const Nd4jLong subArrIdx, const std::vector& dimsToExclude, bool keepUnitiesInShape = false) const; - - /** - * processes whole set of sub-arrays - * evaluates shapeInfo of sub-arrays (all sub-arrays have the same shapeInfo) and their buffer offsets (each sub-array has its own unique offset from original this-buffer) - * dimsToExclude - MUST BE SORTED, dimensions to evaluate sub-array along, i.e. when shape is [2,3,4,5] and dimsToExclude={0,2}, then there will be 8 sub-arrays with shape [3,5] - * if dimsToExclude.size() = array rank it means sub-array is whole array and copy of original_shapeInfo will be returned and one zero offset - * subArrShapeInfo - output argument, contains shapeInfo common for all sub-arrays - * subArrOffsets - output argument, contains successive sub-arrays offsets from original this-buffer - * keepUnitiesInShape - if false then eliminate unities from sub-array shapeInfo, for example {1,a,1,b} -> {a,b} - */ - void getSubArrShapeAndOffsets(const std::vector& dimsToExclude, Nd4jLong* &subArrShapeInfo, Nd4jLong* &subArrOffsets, bool keepUnitiesInShape = false) const; - - /** - * addition unary operator array += other - * other - input array to add - */ - void operator+=(const NDArray& other); - - /** - * subtraction unary operator array -= other - * other - input array to add - */ - void operator-=(const NDArray& other); - - template - void operator+=(const T other); - - template - void operator-=(const T other); - - /** - * negative operator, it changes sign of all array elements on opposite - */ - NDArray operator-() const &; - NDArray operator-() &&; - - /** - * pairwise multiplication unary operator array *= other - * other - input array to multiply on - */ - void operator*=(const NDArray& other); - - /** - * multiplication unary operator array *= scalar - * scalar - input scalar to multiply on - */ - template - void operator*=(const T scalar); - - /** - * pairwise division unary operator: array /= other - * other - input array to divide on - */ - void operator/=(const NDArray& other); - - /** - * division unary operator: array /= scalar - * scalar - input scalar to divide on - */ - template - void operator/=(const T scalar); - - /** - * friend function which implements mathematical multiplication of two arrays - * left - input array - * right - input array - */ - friend NDArray mmul(const NDArray& left, const NDArray& right); - - /** - * return vector containing _buffer as flat binary array - */ - std::vector asByteVector(); - - /** - * makes array to be identity matrix (not necessarily square), that is set all diagonal elements = 1, rest = 0 - */ - void setIdentity(); - - /** - * swaps the contents of tow arrays, - * PLEASE NOTE: method doesn't take into account the shapes of arrays, shapes may be different except one condition: arrays lengths must be the same - */ - void swapUnsafe(NDArray& other); - - /** - * return vector with buffer which points on corresponding diagonal elements of array - * type - means of vector to be returned: column ('c') or row ('r') - */ - NDArray diagonal(const char type ) const; - - /** - * fill target matrix with given value in one or two directions from main diagonal: - * - down from main diagonal starting at subdiagonal number "lower" if direction = 'l' (down) or 'b' (both) - * - up from main diagonal starting at superdiagonal number "upper"if direction = 'u' (up) or 'b' (both) - * direction - in what direction to fill matrix. There are 3 possible directions: - * 'u' - fill up, mathematically this corresponds to lower triangular matrix, subdiagonal "lower" unaffected - * 'l' - fill down, mathematically this corresponds to upper triangular matrix, superdiagonal "upper" remains unaffected - * 'b' - fill in both directions, both "lower" and "upper" are taken into account - * rest of target elements are equal to this array elements - * target and this array should have same shapes, except when this_rank = 1 (in that case should be target_rank = 2) - */ - template - void fillAsTriangular(const float value, int lower, int upper, NDArray& target, const char direction = 'b'); - - /** - * change an array by repeating it the number of times in order to acquire new shape equal to the input shape - * - * shape - contains new shape to broadcast array to - * target - optional argument, if target != nullptr the resulting array will be placed in target, in opposite case tile operation is done in place - */ - NDArray tileToShape(const Nd4jLong* shapeInfo); - void tileToShape(const std::vector& shape, NDArray& target); + /** + * apply transpose operation to the copy of this array, that is this array + * remains unaffected + */ + NDArray transpose() const&; + NDArray transpose() &&; + + /** + * perform transpose operation and store result in target, this array remains + * unaffected target - where to store result + */ + void transpose(NDArray& target) const; + + /** + * apply in-place transpose operation to this array, so this array becomes + * transposed + */ + void transposei(); + + /** + * returns the number of arrays pointing on specified dimension(s) + * dimensions - array of dimensions to point on + */ + Nd4jLong tensorsAlongDimension( + const std::initializer_list dimensions) const; + Nd4jLong tensorsAlongDimension(const std::vector& dimensions) const; + + /** + * returns true if elements of two arrays are equal to within given epsilon + * value other - input array to compare eps - epsilon, this value defines the + * precision of elements comparison + */ + bool equalsTo(const NDArray* other, double eps = 1e-5) const; + bool equalsTo(const NDArray& other, double eps = 1e-5) const; + + /** + * add given row vector to all rows of this array + * row - row vector to add + */ + void addiRowVector(const NDArray& row); + + /** + * add given row vector to all rows of this array, store result in target + * row - row vector to add + * target - where to store result + */ + void addRowVector(const NDArray& row, NDArray& target) const; + + /** + * subtract given row vector from all rows of this array, store result in + * target row - row vector to subtract target - where to store result + */ + void subRowVector(const NDArray& row, NDArray& target) const; + + /** + * multiply all rows of this array on given row vector, store result in + * target row - row vector to multiply on target - where to store result + */ + void mulRowVector(const NDArray& row, NDArray& target) const; + + /** + * divide all rows of this array on given row vector, store result in target + * row - row vector to divide on + * target - where to store result + */ + void divRowVector(const NDArray& row, NDArray& target) const; + + /** + * add given column vector to all columns of this array, store result in + * target column - column vector to add target - where to store result + */ + void addColumnVector(const NDArray& column, NDArray& target) const; + + /** + * add given column vector to all columns of this array, this array becomes + * affected (in-place operation) column - column vector to add + */ + void addiColumnVector(const NDArray& column); + + /** + * multiply all columns of this array on given column vector, this array + * becomes affected (in-place operation) column - column vector to multiply on + */ + void muliColumnVector(const NDArray& column); + + /** + * returns number of bytes used by _buffer & _shapeInfo + */ + FORCEINLINE Nd4jLong memoryFootprint(); + + /** + * these methods suited for FlatBuffers use + */ + template + std::vector getBufferAsVector() const; + std::vector getShapeAsVector() const; + std::vector getShapeAsVectorInt() const; + std::vector getShapeInfoAsVector() const; + std::vector getShapeInfoAsFlatVector() const; + std::vector getShapeAsFlatVector() const; + + /** + * set new order and shape in case of suitable array length (in-place + * operation) order - order to set shape - shape to set copyToNewBuff - if + * true then old buffer will be copied to new buffer if last one will be + * allocated after reshaping if there was permute applied before or there are + * weird strides, then new buffer is allocated for array + */ + bool reshapei(const char order, const std::initializer_list& shape, + const bool copyToNewBuff = true); + bool reshapei(const char order, const std::vector& shape, + const bool copyToNewBuff = true); + + bool reshapei(const std::initializer_list& shape, + const bool copyToNewBuff = true); + bool reshapei(const std::vector& shape, + const bool copyToNewBuff = true); + + /** + * creates new array with corresponding order and shape, new array will point + * on _buffer of this array order - order to set shape - shape to set + * + * if permute have been applied before or there are weird strides, then new + * buffer is allocated for new array + */ + NDArray reshape(const char order, const std::vector& shape, + const bool copyToNewBuff = true) const&; + NDArray reshape(const char order, const std::vector& shape, + const bool copyToNewBuff = true) &&; + + /** + * calculate strides and set given order + * order - order to set + */ + void updateStrides(const char order); + + /** + * change an array by repeating it the number of times given by reps + * (in-place operation) repeats - contains numbers of repetitions + */ + void tilei(const std::vector& repeats); + + /** + * returns new array which is created by repeating of this array the number + * of times given by reps repeats - contains numbers of repetitions + */ + NDArray tile(const std::vector& repeats) const; + + /** + * change an array by repeating it the number of times given by reps + * (in-place operation) repeats - contains numbers of repetitions target - + * where to store result + */ + void tile(const std::vector& repeats, NDArray& target) const; + + /** + * change an array by repeating it the number of times to acquire the new + * shape which is the same as target shape target - where to store result + */ + void tile(NDArray& target) const; + + /** + * check whether array is identity matrix + */ + bool isIdentityMatrix(); + + /** + * check whether array is unitary matrix + */ + bool isUnitary(); + + /** + * operator returns subarray with buffer pointing at this->_buffer with + * offset defined by given intervals idx - intervals of indexes which define + * the subarrays to point on, idx has form {dim0Start,dim0End, + * dim1Start,dim1End, ....} and length (2 * this->rankOf()) when (dimStart == + * dimEnd) then whole range will be used for current dimension + * keepUnitiesInShape - if false then eliminate unities from resulting array + * shape, for example {1,a,1,b} -> {a,b} isStrided - if true then idx has + * length (3 * this->rankOf()) and contains additional stride numbers which + * correspond to stride between dimStart and dimEnd, so structure of idx is + * like {dim0Start,dim0End,dim0Stride, dim1Start,dim1End,dim1Stride, ....} + */ + NDArray operator()(const std::vector& idx, + const bool keepUnitiesInShape = false, + const bool isStrided = false) const; + + /** + * evaluates subarray with buffer pointing at this->_buffer and offset + * defined by given sequential index subArrIdx and dimensions in dimsToExclude + * subArrIdx - index of current sub-array + * dimsToExclude - MUST BE SORTED, dimensions to evaluate sub-array along, + * i.e. when shape is [2,3,4,5] and dimsToExclude={0,2}, then there will be 8 + * sub-arrays with shape [3,5], and subArrIdx must be in range [0,7] if + * dimsToExclude is empty then idxRanges containing all zeros (means whole + * array) will be returned. keepUnitiesInShape - if false then eliminate + * unities from resulting array shape, for example {1,a,1,b} -> {a,b} + */ + NDArray operator()(const Nd4jLong subArrIdx, + const std::vector& dimsToExclude, + bool keepUnitiesInShape = false) const; + + /** + * processes whole set of sub-arrays + * evaluates shapeInfo of sub-arrays (all sub-arrays have the same shapeInfo) + * and their buffer offsets (each sub-array has its own unique offset from + * original this-buffer) dimsToExclude - MUST BE SORTED, dimensions to + * evaluate sub-array along, i.e. when shape is [2,3,4,5] and + * dimsToExclude={0,2}, then there will be 8 sub-arrays with shape [3,5] if + * dimsToExclude.size() = array rank it means sub-array is whole array and + * copy of original_shapeInfo will be returned and one zero offset + * subArrShapeInfo - output argument, contains shapeInfo common for all + * sub-arrays subArrOffsets - output argument, contains successive + * sub-arrays offsets from original this-buffer keepUnitiesInShape - if false + * then eliminate unities from sub-array shapeInfo, for example {1,a,1,b} -> + * {a,b} + */ + void getSubArrShapeAndOffsets(const std::vector& dimsToExclude, + Nd4jLong*& subArrShapeInfo, + Nd4jLong*& subArrOffsets, + bool keepUnitiesInShape = false) const; + + /** + * addition unary operator array += other + * other - input array to add + */ + void operator+=(const NDArray& other); + + /** + * subtraction unary operator array -= other + * other - input array to add + */ + void operator-=(const NDArray& other); + + template + void operator+=(const T other); + + template + void operator-=(const T other); + + /** + * negative operator, it changes sign of all array elements on opposite + */ + NDArray operator-() const&; + NDArray operator-() &&; + + /** + * pairwise multiplication unary operator array *= other + * other - input array to multiply on + */ + void operator*=(const NDArray& other); + + /** + * multiplication unary operator array *= scalar + * scalar - input scalar to multiply on + */ + template + void operator*=(const T scalar); + + /** + * pairwise division unary operator: array /= other + * other - input array to divide on + */ + void operator/=(const NDArray& other); + + /** + * division unary operator: array /= scalar + * scalar - input scalar to divide on + */ + template + void operator/=(const T scalar); + + /** + * friend function which implements mathematical multiplication of two arrays + * left - input array + * right - input array + */ + friend NDArray mmul(const NDArray& left, const NDArray& right); + + /** + * return vector containing _buffer as flat binary array + */ + std::vector asByteVector(); + + /** + * makes array to be identity matrix (not necessarily square), that is set + * all diagonal elements = 1, rest = 0 + */ + void setIdentity(); + + /** + * swaps the contents of tow arrays, + * PLEASE NOTE: method doesn't take into account the shapes of arrays, shapes + * may be different except one condition: arrays lengths must be the same + */ + void swapUnsafe(NDArray& other); + + /** + * return vector with buffer which points on corresponding diagonal elements + * of array type - means of vector to be returned: column ('c') or row ('r') + */ + NDArray diagonal(const char type) const; + + /** + * fill target matrix with given value in one or two directions from main + * diagonal: + * - down from main diagonal starting at subdiagonal number "lower" if + * direction = 'l' (down) or 'b' (both) + * - up from main diagonal starting at superdiagonal number "upper"if + * direction = 'u' (up) or 'b' (both) direction - in what direction to fill + * matrix. There are 3 possible directions: 'u' - fill up, mathematically this + * corresponds to lower triangular matrix, subdiagonal "lower" unaffected 'l' + * - fill down, mathematically this corresponds to upper triangular matrix, + * superdiagonal "upper" remains unaffected 'b' - fill in both directions, + * both "lower" and "upper" are taken into account rest of target elements are + * equal to this array elements target and this array should have same shapes, + * except when this_rank = 1 (in that case should be target_rank = 2) + */ + template + void fillAsTriangular(const float value, int lower, int upper, + NDArray& target, const char direction = 'b'); + + /** + * change an array by repeating it the number of times in order to acquire + * new shape equal to the input shape + * + * shape - contains new shape to broadcast array to + * target - optional argument, if target != nullptr the resulting array will + * be placed in target, in opposite case tile operation is done in place + */ + NDArray tileToShape(const Nd4jLong* shapeInfo); + void tileToShape(const std::vector& shape, NDArray& target); #ifndef __JAVACPP_HACK__ - void tileToShape(const std::initializer_list& shape, NDArray& target); + void tileToShape(const std::initializer_list& shape, + NDArray& target); #endif - template - NDArray asT() const; - - template - NDArray asS() const; - - NDArray asT(DataType dtype) const; - - - void linspace(const double start); - - void linspace(const double start, const double step); - - /** - * calculates the trace of an array, that is sum of elements on main diagonal = sum array[i, i, i, ...] - */ - double getTrace() const; - - ResultSet multipleTensorsAlongDimension(const std::vector& indices, const std::vector& dimensions) const; - - ResultSet allTensorsAlongDimension(const std::initializer_list& dimensions) const; - - ResultSet allTensorsAlongDimension(const std::vector& dimensions) const; - - ResultSet allExamples()const ; - - /** - * set _shapeInfo - */ - void setShapeInfo(const Nd4jLong *shapeInfo); - void setShapeInfo(const Nd4jLong *shapeInfo, const sd::DataType dtype); - void setShapeInfo(const ShapeDescriptor& descriptor); - void setShapeInfo(const ConstantShapeBuffer& shapeBuffer); - - /** - * returns absolute offset which corresponds to given sequential index - */ - Nd4jLong getOffset(const Nd4jLong i) const; - - /** - * returns reference on array element with given index - */ - template - FORCEINLINE T& r(const Nd4jLong index); - template - FORCEINLINE T& r(const Nd4jLong i, const Nd4jLong j); - template - FORCEINLINE T& r(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k); - template - FORCEINLINE T& r(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong w); - - - /** - * returns array element with given index - * i - element index in array - */ - template - FORCEINLINE T t(const Nd4jLong i) const; - template - FORCEINLINE T t(const Nd4jLong i, const Nd4jLong j) const; - template - FORCEINLINE T t(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k) const; - template - FORCEINLINE T t(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong w) const; - - - /** - * default destructor - */ - ~NDArray() noexcept = default; - - /** - * set _shapeInfo - */ - FORCEINLINE void setShapeInfo(Nd4jLong *shapeInfo); - FORCEINLINE void setShapeInfo(Nd4jLong *shapeInfo, const sd::DataType dtype); - - /** - * returns the value of "dim" dimension - */ - Nd4jLong sizeAt(const int dim) const; - - /** - * returns stride of "dim" dimension - */ - Nd4jLong strideAt(const int dim) const; - - /** - * returns order of array - */ - FORCEINLINE char ordering() const; - - /** - * return _isView - */ - FORCEINLINE bool isView() const; - - /** - * returns shape portion of shapeInfo - */ - FORCEINLINE Nd4jLong* shapeOf() const; - - /** - * returns strides portion of shapeInfo - */ - FORCEINLINE Nd4jLong* stridesOf() const; - - /** - * returns rank of array - */ - FORCEINLINE int rankOf() const; - - /** - * returns length of array - */ - FORCEINLINE Nd4jLong lengthOf() const; - - /** - * returns number of rows in array - */ - FORCEINLINE Nd4jLong rows() const; - - /** - * returns number of columns in array - */ - FORCEINLINE Nd4jLong columns() const; - - /** - * returns size of array elements type - */ - FORCEINLINE size_t sizeOfT() const; - - /** - * returns element-wise-stride - */ - FORCEINLINE Nd4jLong ews() const; - - // returns true if arrays have same shape - FORCEINLINE bool isSameShape(const NDArray *other) const; - FORCEINLINE bool isSameShape(const NDArray &other) const; - FORCEINLINE bool isSameShape(const std::initializer_list& shape) const; - FORCEINLINE bool isSameShape(const std::vector& shape) const; - FORCEINLINE bool areSameShapeAndType(const NDArray& other) const; - - /** - * returns true if these two NDArrays have same rank, dimensions, strides, ews and order - */ - FORCEINLINE bool isSameShapeStrict(const NDArray& other) const; - - /** - * returns true if buffer && shapeInfo were defined (non nullptr) - */ - FORCEINLINE bool nonNull() const; - - template - T r(const Nd4jLong i) const; - - /** - * returns array element with given index from linear buffer - * i - element index in array - */ - template - T e(const Nd4jLong i) const; - - /** - * returns element with given indexes from 2D array - * i - number of row - * j - number of column - */ - template - T e(const Nd4jLong i, const Nd4jLong j) const; - - /** - * returns element with given indexes from 3D array - * i - height - * j - width - * k - depth - */ - template - T e(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k) const; - - /** - * returns element with given indexes from DD array - */ - template - T e(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l) const; - - /** - * returns array-scalar containing element of this array with given index - * i - element index in array - */ - NDArray e(const Nd4jLong i) const; - - /** - * assigns given scalar to array element by given index, regards array buffer as linear - * i - element index in array - * value - scalar value to assign - */ - template - void p(const Nd4jLong i, const T value); - - void p(const Nd4jLong i, const NDArray& value); - - /** - * assigns given scalar to 2D array element by given indexes - * i - number of row - * j - number of row - * value - scalar value to assign - */ - template - void p(const Nd4jLong i, const Nd4jLong j, const T value); - - /** - * assigns given scalar to 3D array element by given indexes - * i - height - * j - width - * k - depth - * value - scalar value to assign - */ - template - void p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const T value); - - template - void p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const T value); - void p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, NDArray const& value); - - - template - void pIdx(const Nd4jLong* indices, const T value); - - /** - * returns true if array is 2D - */ - FORCEINLINE bool isMatrix() const; - - /** - * returns true if array is vector - */ - FORCEINLINE bool isVector() const; - - /** - * returns true if array is column vector - */ - FORCEINLINE bool isColumnVector() const; - - /** - * returns true if array is row vector - */ - FORCEINLINE bool isRowVector() const; - - /** - * returns true if all dimensions of array except one are unities, for example: [1,1,n,1], [n,1,1], [n], ... - * posOfNonUnityDim - one dimension with value > 1 - */ - FORCEINLINE bool isCommonVector(int& posOfNonUnityDim) const; - - - /** - * returns true if array is scalar - */ - FORCEINLINE bool isScalar() const; - - /** - * Returns data type of this array - * @return - */ - FORCEINLINE DataType dataType() const; - - /** - * This method returns true if value is from Integer space - * @return - */ - bool isZ() const; - - /** - * This method returns true if array is from Real space - * @return - */ - bool isR() const; - - /** - * This method returns true if array is from Boolean space - * @return - */ - bool isB() const; - - /** - * This method returns true if array contains Complex numbers - * @return - */ - bool isC() const; - - /** - * This method returns true if array contains String - * @return - */ - bool isS() const; - - template - std::vector asVectorT(); - - FORCEINLINE bool isAttached(); - - NDArray* detach(); - - FORCEINLINE bool operator==(const NDArray &other) const; - - FORCEINLINE bool operator!=(const NDArray &other) const; - }; - - - + template + NDArray asT() const; + + template + NDArray asS() const; + + NDArray asT(DataType dtype) const; + + void linspace(const double start); + + void linspace(const double start, const double step); + + /** + * calculates the trace of an array, that is sum of elements on main diagonal + * = sum array[i, i, i, ...] + */ + double getTrace() const; + + ResultSet multipleTensorsAlongDimension( + const std::vector& indices, + const std::vector& dimensions) const; + + ResultSet allTensorsAlongDimension( + const std::initializer_list& dimensions) const; + + ResultSet allTensorsAlongDimension(const std::vector& dimensions) const; + + ResultSet allExamples() const; + + /** + * set _shapeInfo + */ + void setShapeInfo(const Nd4jLong* shapeInfo); + void setShapeInfo(const Nd4jLong* shapeInfo, const sd::DataType dtype); + void setShapeInfo(const ShapeDescriptor& descriptor); + void setShapeInfo(const ConstantShapeBuffer& shapeBuffer); + + /** + * returns absolute offset which corresponds to given sequential index + */ + Nd4jLong getOffset(const Nd4jLong i) const; + + /** + * returns reference on array element with given index + */ + template + FORCEINLINE T& r(const Nd4jLong index); + + template + FORCEINLINE T& r(const Nd4jLong i, const Nd4jLong j); + template + FORCEINLINE T& r(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k); + template + FORCEINLINE T& r(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, + const Nd4jLong w); + + /** + * returns array element with given index + * i - element index in array + */ + template + FORCEINLINE T t(const Nd4jLong i) const; + + template + FORCEINLINE T t(const Nd4jLong i, const Nd4jLong j) const; + template + FORCEINLINE T t(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k) const; + template + FORCEINLINE T t(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, + const Nd4jLong w) const; + + /** + * default destructor + */ + ~NDArray() noexcept = default; + + /** + * set _shapeInfo + */ + FORCEINLINE void setShapeInfo(Nd4jLong* shapeInfo); + FORCEINLINE void setShapeInfo(Nd4jLong* shapeInfo, const sd::DataType dtype); + + /** + * returns the value of "dim" dimension + */ + Nd4jLong sizeAt(const int dim) const; + + /** + * returns stride of "dim" dimension + */ + Nd4jLong strideAt(const int dim) const; + + /** + * returns order of array + */ + FORCEINLINE char ordering() const; + + /** + * return _isView + */ + FORCEINLINE bool isView() const; + + /** + * returns shape portion of shapeInfo + */ + FORCEINLINE Nd4jLong* shapeOf() const; + + /** + * returns strides portion of shapeInfo + */ + FORCEINLINE Nd4jLong* stridesOf() const; + + /** + * returns rank of array + */ + FORCEINLINE int rankOf() const; + + /** + * returns length of array + */ + FORCEINLINE Nd4jLong lengthOf() const; + + /** + * returns number of rows in array + */ + FORCEINLINE Nd4jLong rows() const; + + /** + * returns number of columns in array + */ + FORCEINLINE Nd4jLong columns() const; + + /** + * returns size of array elements type + */ + FORCEINLINE size_t sizeOfT() const; + + /** + * returns element-wise-stride + */ + FORCEINLINE Nd4jLong ews() const; + + // returns true if arrays have same shape + FORCEINLINE bool isSameShape(const NDArray* other) const; + FORCEINLINE bool isSameShape(const NDArray& other) const; + FORCEINLINE bool isSameShape( + const std::initializer_list& shape) const; + FORCEINLINE bool isSameShape(const std::vector& shape) const; + FORCEINLINE bool areSameShapeAndType(const NDArray& other) const; + + /** + * returns true if these two NDArrays have same rank, dimensions, strides, + * ews and order + */ + FORCEINLINE bool isSameShapeStrict(const NDArray& other) const; + + /** + * returns true if buffer && shapeInfo were defined (non nullptr) + */ + FORCEINLINE bool nonNull() const; + + template + T r(const Nd4jLong i) const; + + /** + * returns array element with given index from linear buffer + * i - element index in array + */ + template + T e(const Nd4jLong i) const; + + /** + * returns element with given indexes from 2D array + * i - number of row + * j - number of column + */ + template + T e(const Nd4jLong i, const Nd4jLong j) const; + + /** + * returns element with given indexes from 3D array + * i - height + * j - width + * k - depth + */ + template + T e(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k) const; + + /** + * returns element with given indexes from DD array + */ + template + T e(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, + const Nd4jLong l) const; + + /** + * returns array-scalar containing element of this array with given index + * i - element index in array + */ + NDArray e(const Nd4jLong i) const; + + /** + * assigns given scalar to array element by given index, regards array buffer + * as linear i - element index in array value - scalar value to assign + */ + template + void p(const Nd4jLong i, const T value); + + void p(const Nd4jLong i, const NDArray& value); + + /** + * assigns given scalar to 2D array element by given indexes + * i - number of row + * j - number of row + * value - scalar value to assign + */ + template + void p(const Nd4jLong i, const Nd4jLong j, const T value); + + /** + * assigns given scalar to 3D array element by given indexes + * i - height + * j - width + * k - depth + * value - scalar value to assign + */ + template + void p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const T value); + + template + void p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, + const T value); + void p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, + NDArray const& value); + + template + void pIdx(const Nd4jLong* indices, const T value); + + /** + * returns true if array is 2D + */ + FORCEINLINE bool isMatrix() const; + + /** + * returns true if array is vector + */ + FORCEINLINE bool isVector() const; + + /** + * returns true if array is column vector + */ + FORCEINLINE bool isColumnVector() const; + + /** + * returns true if array is row vector + */ + FORCEINLINE bool isRowVector() const; + + /** + * returns true if all dimensions of array except one are unities, for + * example: [1,1,n,1], [n,1,1], [n], ... posOfNonUnityDim - one dimension with + * value > 1 + */ + FORCEINLINE bool isCommonVector(int& posOfNonUnityDim) const; + + /** + * returns true if array is scalar + */ + FORCEINLINE bool isScalar() const; + + /** + * Returns data type of this array + * @return + */ + FORCEINLINE DataType dataType() const; + + /** + * This method returns true if value is from Integer space + * @return + */ + bool isZ() const; + + /** + * This method returns true if array is from Real space + * @return + */ + bool isR() const; + + /** + * This method returns true if array is from Boolean space + * @return + */ + bool isB() const; + + /** + * This method returns true if array contains Complex numbers + * @return + */ + bool isC() const; + + /** + * This method returns true if array contains String + * @return + */ + bool isS() const; + + template + std::vector asVectorT(); + + FORCEINLINE bool isAttached(); + + NDArray detach(); + + /** + * This method returns true if array is valid array with some shape etc + * @return + */ + bool defined() const; + bool undefined() const; + + FORCEINLINE bool operator==(const NDArray& other) const; + + FORCEINLINE bool operator!=(const NDArray& other) const; +}; // class NDArray + +std::ostream &operator<<(std::ostream &os, const NDArray &m); ////////////////////////////////////////////////////////////////////////// ///// IMLEMENTATION OF INLINE METHODS ///// ////////////////////////////////////////////////////////////////////////// -bool NDArray::isAttached() { - return this->_context->getWorkspace() != nullptr; -} +bool NDArray::isAttached() { return this->_context->getWorkspace() != nullptr; } template FORCEINLINE R NDArray::templatedGet(void const* buffer, Nd4jLong index) const { - auto b = reinterpret_cast(buffer); - auto v = static_cast(b[index]); - return v; + auto b = reinterpret_cast(buffer); + auto v = static_cast(b[index]); + return v; } ////////////////////////////////////////////////////////////////////////// -void NDArray::setShapeInfo(Nd4jLong *shapeInfo) { - auto buffer = ConstantShapeHelper::getInstance().bufferForShapeInfo(shapeInfo); - _shapeInfo = buffer.primary(); - _shapeInfoD = buffer.special(); - - if (shapeInfo != nullptr) { - _dataType = ArrayOptions::dataType(_shapeInfo); - if(ArrayOptions::arrayType(_shapeInfo) == ArrayType::EMPTY) - _length = 0; - else - _length = shape::length(_shapeInfo); - } - else { - _dataType = sd::DataType::INHERIT; - _length = 0; - } +void NDArray::setShapeInfo(Nd4jLong* shapeInfo) { + auto buffer = + ConstantShapeHelper::getInstance().bufferForShapeInfo(shapeInfo); + _shapeInfo = buffer.primary(); + _shapeInfoD = buffer.special(); + + if (shapeInfo != nullptr) { + _dataType = ArrayOptions::dataType(_shapeInfo); + if (ArrayOptions::arrayType(_shapeInfo) == ArrayType::EMPTY) + _length = 0; + else + _length = shape::length(_shapeInfo); + } else { + _dataType = sd::DataType::INHERIT; + _length = 0; + } } ////////////////////////////////////////////////////////////////////////// -void NDArray::setShapeInfo(Nd4jLong *shapeInfo, const sd::DataType dtype) { - auto buffer = ConstantShapeHelper::getInstance().bufferForShapeInfo(shapeInfo); - _shapeInfo = buffer.primary(); - _shapeInfoD = buffer.special(); - - if (shapeInfo != nullptr) { - _dataType = dtype; - if(ArrayOptions::arrayType(_shapeInfo) == ArrayType::EMPTY) - _length = 0; - else - _length = shape::length(_shapeInfo); - } - else { - _dataType = sd::DataType::INHERIT; - _length = 0; - } +void NDArray::setShapeInfo(Nd4jLong* shapeInfo, const sd::DataType dtype) { + auto buffer = + ConstantShapeHelper::getInstance().bufferForShapeInfo(shapeInfo); + _shapeInfo = buffer.primary(); + _shapeInfoD = buffer.special(); + + if (shapeInfo != nullptr) { + _dataType = dtype; + if (ArrayOptions::arrayType(_shapeInfo) == ArrayType::EMPTY) + _length = 0; + else + _length = shape::length(_shapeInfo); + } else { + _dataType = sd::DataType::INHERIT; + _length = 0; + } } ////////////////////////////////////////////////////////////////////////// -char NDArray::ordering() const { - return shape::order(_shapeInfo); -} +char NDArray::ordering() const { return shape::order(_shapeInfo); } ////////////////////////////////////////////////////////////////////////// -bool NDArray::isView() const { - return _isView; -} +bool NDArray::isView() const { return _isView; } ////////////////////////////////////////////////////////////////////////// -Nd4jLong* NDArray::shapeOf() const { - return shape::shapeOf(_shapeInfo); -} +Nd4jLong* NDArray::shapeOf() const { return shape::shapeOf(_shapeInfo); } ////////////////////////////////////////////////////////////////////////// -Nd4jLong* NDArray::stridesOf() const { - return shape::stride(_shapeInfo); -} +Nd4jLong* NDArray::stridesOf() const { return shape::stride(_shapeInfo); } ////////////////////////////////////////////////////////////////////////// -int NDArray::rankOf() const { - return shape::rank(_shapeInfo); -} +int NDArray::rankOf() const { return shape::rank(_shapeInfo); } ////////////////////////////////////////////////////////////////////////// -Nd4jLong NDArray::lengthOf() const { - return _length; -} +Nd4jLong NDArray::lengthOf() const { return _length; } ////////////////////////////////////////////////////////////////////////// Nd4jLong NDArray::rows() const { - if (this->rankOf() == 1) - return 1; + if (this->rankOf() == 1) return 1; - if (this->rankOf() > 2) - throw std::runtime_error("Array with rank > 2 can't have rows"); + if (this->rankOf() > 2) + throw std::runtime_error("Array with rank > 2 can't have rows"); - return shapeOf()[0]; + return shapeOf()[0]; } ////////////////////////////////////////////////////////////////////////// Nd4jLong NDArray::columns() const { - if (this->rankOf() == 1) - return this->lengthOf(); + if (this->rankOf() == 1) return this->lengthOf(); - if (this->rankOf() > 2) - throw std::runtime_error("Array with rank > 2 can't have columns"); + if (this->rankOf() > 2) + throw std::runtime_error("Array with rank > 2 can't have columns"); - return shapeOf()[1]; + return shapeOf()[1]; } ////////////////////////////////////////////////////////////////////////// size_t NDArray::sizeOfT() const { - return DataTypeUtils::sizeOfElement(_dataType); + return DataTypeUtils::sizeOfElement(_dataType); } ////////////////////////////////////////////////////////////////////////// Nd4jLong NDArray::ews() const { - if (this->isEmpty() || this->rankOf() == 0) - return 1; + if (this->isEmpty() || this->rankOf() == 0) return 1; - return shape::elementWiseStride(_shapeInfo); + return shape::elementWiseStride(_shapeInfo); } ////////////////////////////////////////////////////////////////////////// bool NDArray::nonNull() const { - if (isEmpty()) - return true; + if (isEmpty()) return true; - if(!Environment::getInstance().isCPU()) - return getDataBuffer()->special() != nullptr && specialShapeInfo() != nullptr; + if (!Environment::getInstance().isCPU()) + return getDataBuffer()->special() != nullptr && + specialShapeInfo() != nullptr; - return getDataBuffer()->primary() != nullptr && shapeInfo() != nullptr; + return getDataBuffer()->primary() != nullptr && shapeInfo() != nullptr; } ////////////////////////////////////////////////////////////////////////// bool NDArray::isMatrix() const { - if (isEmpty()) - return false; + if (isEmpty()) return false; - return 0 != shape::isMatrix(this->_shapeInfo); + return 0 != shape::isMatrix(this->_shapeInfo); } ////////////////////////////////////////////////////////////////////////// bool NDArray::isVector() const { - if (isEmpty()) - return false; - if (rankOf() == 1) - return true; - return !isScalar() && shape::isVector(this->_shapeInfo); + if (isEmpty()) return false; + if (rankOf() == 1) return true; + return !isScalar() && shape::isVector(this->_shapeInfo); } ////////////////////////////////////////////////////////////////////////// bool NDArray::isColumnVector() const { - if (isEmpty()) - return false; + if (isEmpty()) return false; - return !isScalar() && shape::isColumnVector(this->_shapeInfo); + return !isScalar() && shape::isColumnVector(this->_shapeInfo); } ////////////////////////////////////////////////////////////////////////// bool NDArray::isRowVector() const { - if (isEmpty()) - return false; + if (isEmpty()) return false; - // 1D edge case - if (shape::rank(this->_shapeInfo) == 1) - return true; + // 1D edge case + if (shape::rank(this->_shapeInfo) == 1) return true; - return !isScalar() && shape::isRowVector(this->_shapeInfo); + return !isScalar() && shape::isRowVector(this->_shapeInfo); } ////////////////////////////////////////////////////////////////////////// bool NDArray::isCommonVector(int& posOfNonUnityDim) const { - - return shape::isCommonVector(_shapeInfo, posOfNonUnityDim); + return shape::isCommonVector(_shapeInfo, posOfNonUnityDim); } ////////////////////////////////////////////////////////////////////////// bool NDArray::isScalar() const { - return 0 != shape::isScalar(this->_shapeInfo); + return 0 != shape::isScalar(this->_shapeInfo); } - ////////////////////////////////////////////////////////////////////////// Nd4jLong FORCEINLINE NDArray::memoryFootprint() { - Nd4jLong size = this->lengthOf() * this->sizeOfT(); - size += shape::shapeInfoByteLength(this->rankOf()); - return size; + Nd4jLong size = this->lengthOf() * this->sizeOfT(); + size += shape::shapeInfoByteLength(this->rankOf()); + return size; } ////////////////////////////////////////////////////////////////////////// // still the definition of inline function must be in header file -bool NDArray::isSameShape(const std::vector& shape) const{ - if (this->isScalar() && shape.size() == 1 && shape[0] == 0) - return true; - if (this->rankOf() != (int) shape.size()) - return false; - for (int e = 0; e < this->rankOf(); e++) { - if (this->shapeOf()[e] != shape[e] && shape[e] != -1) - return false; - } - return true; +bool NDArray::isSameShape(const std::vector& shape) const { + if (this->isScalar() && shape.size() == 1 && shape[0] == 0) return true; + if (this->rankOf() != (int)shape.size()) return false; + for (int e = 0; e < this->rankOf(); e++) { + if (this->shapeOf()[e] != shape[e] && shape[e] != -1) return false; + } + return true; } ////////////////////////////////////////////////////////////////////////// -bool NDArray::isSameShape(const NDArray *other) const { - if (this->isEmpty() != other->isEmpty()) - return false; +bool NDArray::isSameShape(const NDArray* other) const { + if (this->isEmpty() != other->isEmpty()) return false; - return isSameShape(std::vector(other->_shapeInfo+1, other->_shapeInfo+1+other->_shapeInfo[0])); + return isSameShape(std::vector( + other->_shapeInfo + 1, other->_shapeInfo + 1 + other->_shapeInfo[0])); } ////////////////////////////////////////////////////////////////////////// -bool NDArray::isSameShape(const NDArray &other) const { - return isSameShape(&other); +bool NDArray::isSameShape(const NDArray& other) const { + return isSameShape(&other); } ////////////////////////////////////////////////////////////////////////// bool NDArray::isSameShape(const std::initializer_list& other) const { - return isSameShape(std::vector(other)); + return isSameShape(std::vector(other)); } ////////////////////////////////////////////////////////////////////////// bool NDArray::areSameShapeAndType(const NDArray& other) const { + if (rankOf() != other.rankOf() || _dataType != other._dataType) return false; - if(rankOf() != other.rankOf() || _dataType != other._dataType) - return false; + for (int i = 0; i < rankOf(); ++i) + if (sizeAt(i) != other.sizeAt(i)) return false; - for(int i = 0; i < rankOf(); ++i) - if(sizeAt(i) != other.sizeAt(i)) - return false; - - return true; + return true; } ////////////////////////////////////////////////////////////////////////// @@ -1736,205 +1986,218 @@ bool NDArray::areSameShapeAndType(const NDArray& other) const { // still the definition of inline function must be in header file bool NDArray::isSameShapeStrict(const NDArray& other) const { - return shape::equalsStrict(_shapeInfo, other._shapeInfo); + return shape::equalsStrict(_shapeInfo, other._shapeInfo); } ////////////////////////////////////////////////////////////////////////// bool NDArray::isEmpty() const { - if (this->_shapeInfo == nullptr) - return false; + if (this->_shapeInfo == nullptr) return false; - return ArrayOptions::arrayType(this->shapeInfo()) == ArrayType::EMPTY; + return ArrayOptions::arrayType(this->shapeInfo()) == ArrayType::EMPTY; } ////////////////////////////////////////////////////////////////////////// -bool NDArray::operator==(const NDArray &other) const { - // if (this->dataType() != other.dataType()) // this comparison is already present in equalsTo - // return false; +bool NDArray::operator==(const NDArray& other) const { + // if (this->dataType() != other.dataType()) // this comparison is already + // present in equalsTo + // return false; - if (!this->isSameShape(&other)) - return false; + if (!this->isSameShape(&other)) return false; - return this->equalsTo(&other); + return this->equalsTo(&other); } ////////////////////////////////////////////////////////////////////////// -bool NDArray::operator!=(const NDArray &other) const { - if (this->dataType() != other.dataType()) - return true; +bool NDArray::operator!=(const NDArray& other) const { + if (this->dataType() != other.dataType()) return true; - if (!this->isSameShape(&other)) - return true; + if (!this->isSameShape(&other)) return true; - return !this->equalsTo(&other); + return !this->equalsTo(&other); } ////////////////////////////////////////////////////////////////////////// DataType NDArray::dataType() const { - return _dataType; - // return ArrayOptions::dataType(_shapeInfo); + return _dataType; + // return ArrayOptions::dataType(_shapeInfo); } //////////////////////////////////////////////////////////////////////// template T& NDArray::r(const Nd4jLong i) { + // if (i >= _length) + // throw std::invalid_argument("NDArray::t(i): input index is out of array + // length !"); + if (DataTypeUtils::fromT() != _dataType) + throw std::invalid_argument( + "NDArray::t(i): type of array is not equal to template type T!"); - // if (i >= _length) - // throw std::invalid_argument("NDArray::t(i): input index is out of array length !"); - if (DataTypeUtils::fromT() != _dataType) - throw std::invalid_argument("NDArray::t(i): type of array is not equal to template type T!"); - - syncToHost(); - tickWriteHost(); + syncToHost(); +tickWriteHost(); - return *(reinterpret_cast(bufferWithOffset(getOffset(i)))); + return *(reinterpret_cast(bufferWithOffset(getOffset(i)))); } //////////////////////////////////////////////////////////////////////// template T& NDArray::r(const Nd4jLong i, const Nd4jLong j) { - - if (rankOf() != 2 || i >= sizeAt(0) || j >= sizeAt(1)) - throw std::invalid_argument("NDArray::t(i,j): one of input indexes is out of array length or rank!=2 !"); - if (DataTypeUtils::fromT() != _dataType) - throw std::invalid_argument("NDArray::t(i,j): type of array is not equal to template type T!"); - - syncToHost(); - tickWriteHost(); - - return *(reinterpret_cast(bufferWithOffset(i * strideAt(0) + j * strideAt(1)))); + if (rankOf() != 2 || i >= sizeAt(0) || j >= sizeAt(1)) + throw std::invalid_argument( + "NDArray::t(i,j): one of input indexes is out of array length or " + "rank!=2 !"); + if (DataTypeUtils::fromT() != _dataType) + throw std::invalid_argument( + "NDArray::t(i,j): type of array is not equal to template type T!"); + + syncToHost(); +tickWriteHost(); + + return *(reinterpret_cast(bufferWithOffset(i * strideAt(0) + j * strideAt(1)))); } template T& NDArray::r(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k) { + if (rankOf() != 3 || i >= sizeAt(0) || j >= sizeAt(1) || k >= sizeAt(2)) + throw std::invalid_argument( + "NDArray::t(i,j,k): one of input indexes is out of array length or " + "rank!=3!"); + if (DataTypeUtils::fromT() != _dataType) + throw std::invalid_argument( + "NDArray::t(i,j,k): type of array is not equal to template type T!"); - if (rankOf() != 3 || i >= sizeAt(0) || j >= sizeAt(1) || k >= sizeAt(2)) - throw std::invalid_argument("NDArray::t(i,j,k): one of input indexes is out of array length or rank!=3!"); - if (DataTypeUtils::fromT() != _dataType) - throw std::invalid_argument("NDArray::t(i,j,k): type of array is not equal to template type T!"); + syncToHost(); - syncToHost(); - tickWriteHost(); + tickWriteHost(); - return *(reinterpret_cast(bufferWithOffset(i * strideAt(0) + j * strideAt(1) + k * strideAt(2)))); + return *(reinterpret_cast(bufferWithOffset(i * strideAt(0) + j * strideAt(1) + k * strideAt(2)))); } template -T& NDArray::r(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong w) { - - if (rankOf() != 4 || i >= sizeAt(0) || j >= sizeAt(1) || k >= sizeAt(2) || w >= sizeAt(3)) - throw std::invalid_argument("NDArray::t(i,j,k,w): one of input indexes is out of array length or rank!=4 !"); - if (DataTypeUtils::fromT() != _dataType) - throw std::invalid_argument("NDArray::t(i,j,k,w): type of array is not equal to template type T!"); - - syncToHost(); - tickWriteHost(); - - return *(reinterpret_cast(bufferWithOffset(i * strideAt(0) + j * strideAt(1) + k * strideAt(2) + w * strideAt(3)))); +T& NDArray::r(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, + const Nd4jLong w) { + if (rankOf() != 4 || i >= sizeAt(0) || j >= sizeAt(1) || k >= sizeAt(2) || + w >= sizeAt(3)) + throw std::invalid_argument( + "NDArray::t(i,j,k,w): one of input indexes is out of array length or " + "rank!=4 !"); + if (DataTypeUtils::fromT() != _dataType) + throw std::invalid_argument( + "NDArray::t(i,j,k,w): type of array is not equal to template type T!"); + + syncToHost(); +tickWriteHost(); + + return *(reinterpret_cast(bufferWithOffset(i * strideAt(0) + j * strideAt(1) + k * strideAt(2) + w * strideAt(3)))); } //////////////////////////////////////////////////////////////////////// template T NDArray::t(const Nd4jLong i) const { + // if (i >= _length) + // throw std::invalid_argument("NDArray::t(i): input index is out of array + // length !"); + if (DataTypeUtils::fromT() != _dataType) + throw std::invalid_argument( + "NDArray::t(i): type of array is not equal to template type T!"); - // if (i >= _length) - // throw std::invalid_argument("NDArray::t(i): input index is out of array length !"); - if (DataTypeUtils::fromT() != _dataType) - throw std::invalid_argument("NDArray::t(i): type of array is not equal to template type T!"); + syncToHost(); - syncToHost(); - return *(reinterpret_cast(bufferWithOffset(getOffset(i)))); + return *(reinterpret_cast(bufferWithOffset(getOffset(i)))); } //////////////////////////////////////////////////////////////////////// template T NDArray::t(const Nd4jLong i, const Nd4jLong j) const { + if (rankOf() != 2 || i >= sizeAt(0) || j >= sizeAt(1)) + throw std::invalid_argument( + "NDArray::t(i,j): one of input indexes is out of array length or " + "rank!=2 !"); + if (DataTypeUtils::fromT() != _dataType) + throw std::invalid_argument( + "NDArray::t(i,j): type of array is not equal to template type T!"); - if (rankOf() != 2 || i >= sizeAt(0) || j >= sizeAt(1)) - throw std::invalid_argument("NDArray::t(i,j): one of input indexes is out of array length or rank!=2 !"); - if (DataTypeUtils::fromT() != _dataType) - throw std::invalid_argument("NDArray::t(i,j): type of array is not equal to template type T!"); + syncToHost(); - syncToHost(); - return *(reinterpret_cast(bufferWithOffset(i * strideAt(0) + j * strideAt(1)))); + return *(reinterpret_cast(bufferWithOffset(i * strideAt(0) + j * strideAt(1)))); } //////////////////////////////////////////////////////////////////////// template T NDArray::t(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k) const { + if (rankOf() != 3 || i >= sizeAt(0) || j >= sizeAt(1) || k >= sizeAt(2)) + throw std::invalid_argument( + "NDArray::t(i,j,k): one of input indexes is out of array length or " + "rank!=3!"); + if (DataTypeUtils::fromT() != _dataType) + throw std::invalid_argument( + "NDArray::t(i,j,k): type of array is not equal to template type T!"); - if (rankOf() != 3 || i >= sizeAt(0) || j >= sizeAt(1) || k >= sizeAt(2)) - throw std::invalid_argument("NDArray::t(i,j,k): one of input indexes is out of array length or rank!=3!"); - if (DataTypeUtils::fromT() != _dataType) - throw std::invalid_argument("NDArray::t(i,j,k): type of array is not equal to template type T!"); + syncToHost(); - syncToHost(); - return *(reinterpret_cast(bufferWithOffset(i * strideAt(0) + j * strideAt(1) + k * strideAt(2)))); + return *(reinterpret_cast(bufferWithOffset(i * strideAt(0) + j * strideAt(1) + k * strideAt(2)))); } //////////////////////////////////////////////////////////////////////// template -T NDArray::t(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong w) const { +T NDArray::t(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, + const Nd4jLong w) const { + if (rankOf() != 4 || i >= sizeAt(0) || j >= sizeAt(1) || k >= sizeAt(2) || + w >= sizeAt(3)) + throw std::invalid_argument( + "NDArray::t(i,j,k,w): one of input indexes is out of array length or " + "rank!=4!"); + if (DataTypeUtils::fromT() != _dataType) + throw std::invalid_argument( + "NDArray::t(i,j,k,w): type of array is not equal to template type T!"); - if (rankOf() != 4 || i >= sizeAt(0) || j >= sizeAt(1) || k >= sizeAt(2) || w >= sizeAt(3)) - throw std::invalid_argument("NDArray::t(i,j,k,w): one of input indexes is out of array length or rank!=4!"); - if (DataTypeUtils::fromT() != _dataType) - throw std::invalid_argument("NDArray::t(i,j,k,w): type of array is not equal to template type T!"); + syncToHost(); - syncToHost(); - return *(reinterpret_cast(bufferWithOffset(i * strideAt(0) + j * strideAt(1) + k * strideAt(2) + w * strideAt(3)))); + return *(reinterpret_cast(bufferWithOffset(i * strideAt(0) + j * strideAt(1) + k * strideAt(2) + w * strideAt(3)))); } #ifndef __JAVACPP_HACK__ //////////////////////////////////////////////////////////////////////// -std::shared_ptr NDArray::getDataBuffer() const { - return _buffer; -} +std::shared_ptr NDArray::getDataBuffer() const { return _buffer; } //////////////////////////////////////////////////////////////////////// -std::shared_ptr NDArray::dataBuffer() { - return _buffer; -} +std::shared_ptr NDArray::dataBuffer() { return _buffer; } #endif //////////////////////////////////////////////////////////////////////// const void* NDArray::buffer() const { - return _buffer->primary() != nullptr ? static_cast(_buffer->primary()) + (_offset * sizeOfT()) : nullptr; + return _buffer->primary() != nullptr + ? static_cast(_buffer->primary()) + (_offset * sizeOfT()) + : nullptr; } ////////////////////////////////////////////////////////////////////////// void* NDArray::buffer() { - return _buffer->primary() != nullptr ? static_cast(_buffer->primary()) + (_offset * sizeOfT()) : nullptr; + return _buffer->primary() != nullptr + ? static_cast(_buffer->primary()) + (_offset * sizeOfT()) + : nullptr; } ////////////////////////////////////////////////////////////////////////// -const Nd4jLong* NDArray::shapeInfo() const { - return _shapeInfo; -} +const Nd4jLong* NDArray::shapeInfo() const { return _shapeInfo; } //////////////////////////////////////////////////////////////////////// const Nd4jLong* NDArray::specialShapeInfo() const { - if (_shapeInfoD == nullptr) - return _shapeInfo; - // FIXME: this should be fixed once CUDA backend added - return _shapeInfoD; + if (_shapeInfoD == nullptr) return _shapeInfo; + return _shapeInfoD; } //////////////////////////////////////////////////////////////////////// -Nd4jLong NDArray::bufferOffset() const { - return _offset; -} +Nd4jLong NDArray::bufferOffset() const { return _offset; } - -#if defined(__CUDACC__) //&& defined(BUILD_TESTS) +#if defined(__CUDACC__) //&& defined(BUILD_TESTS) // for CUDA we need stil stuff inline #include #endif -} +} // namespace sd #endif diff --git a/libnd4j/include/array/NDArray.hXX b/libnd4j/include/array/NDArray.hXX deleted file mode 100644 index eefe169cf556..000000000000 --- a/libnd4j/include/array/NDArray.hXX +++ /dev/null @@ -1,5625 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * Copyright (c) 2019 Konduit K.K. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -// $NDArray.hpp - architech-independent implementations (both cuda and cpu). -// -#ifndef __NDARRAY__HPP__ -#define __NDARRAY__HPP__ - -#include -#include -#include -#include -#include -#include - -namespace sd { - -template <> -ND4J_EXPORT utf8string NDArray::e(const Nd4jLong i) const; -template <> -ND4J_EXPORT std::string NDArray::e(const Nd4jLong i) const; -template <> -ND4J_EXPORT std::u16string NDArray::e(const Nd4jLong i) const; -template <> -ND4J_EXPORT std::u32string NDArray::e(const Nd4jLong i) const; - -//////////////////////////////////////////////////////////////////////// -// copy constructor -NDArray::NDArray(const NDArray& other) { - - _context = other._context; - _offset = 0; - - setShapeInfo(ShapeDescriptor(other.dataType(), other.ordering(), other.shapeOf(), other.rankOf())); - - if(!isEmpty()) { - _buffer = std::make_shared(other.lengthOf() * other.sizeOfT(), other.dataType(), other.getContext()->getWorkspace()); - this->assign(&other); - } - else - _buffer = std::make_shared(); -} - -//////////////////////////////////////////////////////////////////////// -NDArray::NDArray(const char order, const std::vector &shape, sd::DataType dtype, sd::LaunchContext * context) { - - if ((int) shape.size() > MAX_RANK) - throw std::invalid_argument("Rank of NDArray can't exceed 32"); - - _context = context; - _isAttached = _context->getWorkspace() != nullptr; - _offset = 0; - - if (shape.empty()) - setShapeInfo(ShapeDescriptor::emptyDescriptor(dtype)); - else - setShapeInfo(ShapeDescriptor(dtype, order, shape)); - - _buffer = std::make_shared(lengthOf() * DataTypeUtils::sizeOf(dtype), dtype, getContext()->getWorkspace()); - _buffer->setToZeroBuffers(); -} - -//////////////////////////////////////////////////////////////////////// -NDArray::NDArray(const char order, const std::vector &shape, const std::vector& data, sd::DataType dtype, sd::LaunchContext * context) { - - if ((int) shape.size() > MAX_RANK) - throw std::invalid_argument("Rank of NDArray can't exceed 32"); - - _context = context; - _offset = 0; - - if (shape.size() == 0) { - if (data.size() == 0) - setShapeInfo(ShapeDescriptor::emptyDescriptor(dtype)); - else - setShapeInfo(ShapeDescriptor::scalarDescriptor(dtype)); - } else { - setShapeInfo(ShapeDescriptor(dtype, order, shape)); - } - - if (lengthOf() != data.size()) { - nd4j_printf("NDArray constructor: data size [%i] doesn't match shape length [%i]\n", data.size(), lengthOf()); - throw std::runtime_error("Data size doesn't match shape"); - } - - _buffer = std::make_shared(lengthOf() * DataTypeUtils::sizeOf(dtype), dtype, getContext()->getWorkspace(), true); - - for(Nd4jLong i=0; i < lengthOf(); ++i) { - BUILD_SINGLE_PARTIAL_SELECTOR(dtype, templatedDoubleAssign<, double>(buffer(), i, reinterpret_cast(data.data()), i), LIBND4J_TYPES); - } - tickWriteHost(); - syncToDevice(); -} - - -//////////////////////////////////////////////////////////////////////// -NDArray::NDArray(const NDArray *other, const bool copyStrides, sd::LaunchContext* context) { - - _context = context; - _offset = 0; - _isAttached = getContext()->getWorkspace() != nullptr; - - if (copyStrides) - setShapeInfo(ShapeDescriptor(other->_shapeInfo)); - else - setShapeInfo(ShapeDescriptor(other->dataType(), other->ordering(), other->shapeOf(), other->rankOf())); - - if (!isEmpty()) - _buffer = std::make_shared(lengthOf() * sizeOfT(), dataType(), getContext()->getWorkspace()); -} - -//////////////////////////////////////////////////////////////////////// -NDArray::NDArray(void* buffer, const char order, const std::vector &shape, sd::DataType dtype, sd::LaunchContext * context, const bool isBuffAlloc) { - - if (shape.empty()) - throw std::runtime_error("NDArray constructor: input shape is empty !"); - - if ((int) shape.size() > MAX_RANK) - throw std::invalid_argument("Rank of NDArray can't exceed 32"); - - _context = context; - _offset = 0; - _isAttached = getContext()->getWorkspace() != nullptr; - - setShapeInfo(ShapeDescriptor(dtype, order, shape)); - - _buffer = std::make_shared(buffer, lengthOf() * sizeOfT(), dataType(), isBuffAlloc, getContext()->getWorkspace()); -} - -//////////////////////////////////////////////////////////////////////// -// creates new NDArray using shape information from "shapeInfo" array, set all elements in new array to be zeros -NDArray::NDArray(const Nd4jLong* shapeInfo, const sd::DataType dtype, const bool copyStrides, sd::LaunchContext * context, const bool nullify) { - - if (shapeInfo == nullptr) - throw std::runtime_error("NDArray constructor: can't be initalized without shapeinfo"); - - if ((int) shapeInfo[0] > MAX_RANK) - throw std::invalid_argument("Rank of NDArray can't exceed 32"); - - _context = context; - _offset = 0; - - if (copyStrides) - setShapeInfo(ShapeDescriptor(shapeInfo, dtype)); - else - setShapeInfo(ShapeDescriptor(dtype, shape::order(shapeInfo), shape::shapeOf(shapeInfo), shape::rank(shapeInfo))); - - if (!isEmpty()) { - _buffer = std::make_shared(lengthOf() * sizeOfT(), dtype, getContext()->getWorkspace()); - - if (nullify) - _buffer->setToZeroBuffers(); - } -} - -//////////////////////////////////////////////////////////////////////// -// scalar constructor -NDArray::NDArray(sd::DataType dtype, sd::LaunchContext* context, const bool isScalar) { - - _context = context; - _offset = 0; - _isAttached = getContext()->getWorkspace() != nullptr; - - if (isScalar) { - setShapeInfo(ShapeDescriptor::scalarDescriptor(dtype)); - _buffer = std::make_shared(sizeOfT(), dtype, getContext()->getWorkspace()); - _buffer->setToZeroBuffers(); - } - else - setShapeInfo(ConstantShapeHelper::getInstance().emptyShapeInfo(dtype)); -} - -////////////////////////////////////////////////////////////////////////// -// move constructor -NDArray::NDArray(NDArray&& other) noexcept { - - _isView = other._isView; - _buffer = other._buffer; - _shapeInfo = other._shapeInfo; - _shapeInfoD = other._shapeInfoD; - _context = other._context; - _dataType = other._dataType; - _length = other._length; - _offset = other._offset; - - other._buffer = std::make_shared(); - other._shapeInfo = other._shapeInfoD = nullptr; - other._length = 0; -} - -//////////////////////////////////////////////////////////////////////// -//constructor, create empty array at given workspace -NDArray::NDArray(sd::LaunchContext * context) { - _buffer = std::make_shared(); - _shapeInfo = nullptr; - _shapeInfoD = nullptr; - _offset = 0; - _context = context; - _length = 0; -} - - //////////////////////////////////////////////////////////////////////// - // creates new NDArray using shape information from "shapeInfo" array, set all elements in new array to be zeros, set dtype as array type - NDArray::NDArray(const Nd4jLong* shapeInfo, const bool copyStrides, sd::LaunchContext * context, const bool nullify): - NDArray(shapeInfo, ArrayOptions::dataType(shapeInfo), copyStrides, context) { - } - - //////////////////////////////////////////////////////////////////////// - NDArray::NDArray(std::shared_ptr buffer, const ShapeDescriptor& descriptor, sd::LaunchContext* context, const Nd4jLong offset) { - - _context = context; - _offset = offset; - - setShapeInfo(descriptor); - - _buffer = buffer; - - _isView = offset > 0 || _length * DataTypeUtils::sizeOf(_dataType) < buffer->getLenInBytes(); - } - - NDArray::NDArray(void *buffer, Nd4jLong *shapeInfo, sd::LaunchContext * context, const bool isBuffAlloc) : NDArray::NDArray(buffer, const_cast(shapeInfo), context, isBuffAlloc) { - // - } - - //////////////////////////////////////////////////////////////////////// - // do not allocate memory, memory for array is passed from outside - NDArray::NDArray(void *buffer, const Nd4jLong *shapeInfo, sd::LaunchContext * context, const bool isBuffAlloc) { - - if (buffer == nullptr && ArrayOptions::arrayType(shapeInfo) != ArrayType::EMPTY) - throw std::runtime_error("NDArray constructor: can't be initalized with nullptr buffer !"); - - if (shapeInfo == nullptr) - throw std::runtime_error("NDArray constructor: can't be initalized without shapeinfo !"); - - if ((int) shapeInfo[0] > MAX_RANK) - throw std::invalid_argument("NDArray constructor: rank of NDArray can't exceed 32 !"); - - _context = context; - _isAttached = getContext()->getWorkspace() != nullptr; - _offset = 0; - - setShapeInfo(ShapeDescriptor(shapeInfo)); - - if (this->isEmpty()) { - tickReadDevice(); - tickReadHost(); - } - else { - _buffer = std::make_shared(buffer, lengthOf() * sizeOfT(), dataType(), isBuffAlloc, getContext()->getWorkspace()); - } - } - - //////////////////////////////////////////////////////////////////////// - // do not allocate memory, memory for array is passed from outside - // we suppose the content of both (device and host) buffers is identical - NDArray::NDArray(void *buffer, void* bufferD, const Nd4jLong *shapeInfo, sd::LaunchContext * context, const bool isBuffAlloc, const bool isBuffDAlloc) { - - if (shapeInfo == nullptr) - throw std::runtime_error("NDArray constructor cuda: can't be initalized without shapeinfo"); - - if ((int) shapeInfo[0] > MAX_RANK) - throw std::invalid_argument("NDArray constructor cuda: rank of NDArray can't exceed 32"); - - _context = context; - _offset = 0; - - setShapeInfo(ShapeDescriptor(shapeInfo)); - - if (!isEmpty()) - _buffer = std::make_shared(buffer, bufferD, lengthOf() * sizeOfT(), dataType(), isBuffAlloc, isBuffDAlloc, getContext()->getWorkspace()); - } - -////////////////////////////////////////////////////////////////////////// -NDArray::NDArray(std::shared_ptr buffer, const char order, const std::vector &shape, sd::LaunchContext* context) { - - if (shape.empty()) - throw std::runtime_error("NDArray constructor: input shape is empty !"); - - if ((int) shape.size() > MAX_RANK) - throw std::invalid_argument("NDArray constructor: rank of NDArray can't exceed 32"); - - _context = context; - _offset = 0; - - setShapeInfo(ShapeDescriptor(buffer->getDataType(), order, shape)); - - _buffer = buffer; - - _isView = _length * DataTypeUtils::sizeOf(_dataType) < buffer->getLenInBytes(); -} -///////////////////////////////////////////////////////////////////////// -// u16 string constructors -NDArray::NDArray(const std::u16string& u16string, sd::DataType dtype, sd::LaunchContext* context) { - - if (!DataTypeUtils::isS(dtype)) { - throw std::invalid_argument("NDArray::NDArray: invalid DataType, only string dataTypes have to be used"); - } - - if (!unicode::isStringValidU16(u16string.data(), u16string.data() + u16string.size())) { - throw std::invalid_argument("NDArray::NDArray: invalid character in input string"); - } - - // one word that is why used 1 - Nd4jLong headerLength = ShapeUtils::stringBufferHeaderRequirements(1); - - Nd4jLong dataLength = [&] { - if (dtype == DataType::UTF16) { - return static_cast(u16string.size() * sizeof(uint16_t)); - } - if (dtype == DataType::UTF32) { - return unicode::offsetUtf16StringInUtf32(u16string.data(), u16string.size()); - } - return unicode::offsetUtf16StringInUtf8(u16string.data(), u16string.size()); - }(); - - Nd4jLong offsets[2] = { 0 , dataLength }; - - _buffer = std::make_shared(headerLength + dataLength, dtype, context->getWorkspace(), true); - - _context = context; - _isAttached = getContext()->getWorkspace() != nullptr; - _offset = 0; - - setShapeInfo(ShapeDescriptor::scalarDescriptor(dtype)); - - memcpy(bufferAsT(), &offsets[0], 2 * sizeof(Nd4jLong)); - - auto data = reinterpret_cast(bufferAsT() + headerLength); - if (dtype == DataType::UTF8) { - unicode::utf16to8(u16string.data(), data, u16string.size()); - } - else if (dtype == DataType::UTF16) { - memcpy(data, u16string.data(), dataLength); - } - else { - unicode::utf16to32(u16string.data(), data, u16string.size()); - } - - tickWriteHost(); - syncToDevice(); -} - -///////////////////////////////////////////////////////////////////////// -// u32 string constructors -NDArray::NDArray(const std::u32string& u32string, sd::DataType dtype, sd::LaunchContext* context) { - - if (!DataTypeUtils::isS(dtype)) { - throw std::invalid_argument("NDArray::NDArray: invalid DataType, only string dataTypes have to be used"); - } - - if (!unicode::isStringValidU32(u32string.data(), u32string.data() + u32string.size())) { - throw std::invalid_argument("NDArray::NDArray: invalid character in input string"); - } - // one word that is why used 1 - Nd4jLong headerLength = ShapeUtils::stringBufferHeaderRequirements(1); - - Nd4jLong dataLength = [&] { - if (dtype == DataType::UTF16) { - return unicode::offsetUtf32StringInUtf16(u32string.data(), u32string.size()); - } - if (dtype == DataType::UTF32) { - return static_cast(sizeof(uint32_t) * u32string.size()); - } - return unicode::offsetUtf32StringInUtf8(u32string.data(), u32string.size()); - }(); - - Nd4jLong offsets[2] = { 0 , dataLength }; - - _buffer = std::make_shared(headerLength + dataLength, dtype, context->getWorkspace(), true); - - _context = context; - _isAttached = getContext()->getWorkspace() != nullptr; - _offset = 0; - - setShapeInfo(ShapeDescriptor::scalarDescriptor(dtype)); - - memcpy(bufferAsT(), &offsets[0], 2 * sizeof(Nd4jLong)); - - auto data = reinterpret_cast(bufferAsT() + headerLength); - if (dtype == DataType::UTF8) { - unicode::utf32to8(u32string.data(), data, u32string.size()); - } - else if (dtype == DataType::UTF16) { - unicode::utf32to16(u32string.data(), data, u32string.size()); - } - else { - memcpy(data, u32string.data(), u32string.size() * sizeof(uint32_t)); - } - - tickWriteHost(); - syncToDevice(); -} - -///////////////////////////////////////////////////////////////////////// -// u8 string constructors -NDArray::NDArray(const std::string& str, sd::DataType dtype, sd::LaunchContext* context) { - - if (!DataTypeUtils::isS(dtype)) { - throw std::invalid_argument("NDArray::NDArray: invalid DataType, only string dataTypes have to be used"); - } - - if (!unicode::isStringValidU8(str.data(), str.data() + str.size())) { - throw std::invalid_argument("NDArray::NDArray: invalid character in input string"); - } - - // one word that is why used 1 - auto headerLength = ShapeUtils::stringBufferHeaderRequirements(1); - - Nd4jLong dataLength = [&] { - if (dtype == DataType::UTF16) { - return unicode::offsetUtf8StringInUtf16(str.data(), str.size()); - } - if (dtype == DataType::UTF32) { - return unicode::offsetUtf8StringInUtf32(str.data(), str.size()); - } - return static_cast(str.size()); - }(); - - Nd4jLong offsets[2] = { 0 , dataLength }; - - _buffer = std::make_shared(headerLength + dataLength, dtype, context->getWorkspace(), true); - - _context = context; - _isAttached = getContext()->getWorkspace() != nullptr; - _offset = 0; - - setShapeInfo(ShapeDescriptor::scalarDescriptor(dtype)); - - memcpy(bufferAsT(), &offsets[0], 2 * sizeof(Nd4jLong)); - - auto data = reinterpret_cast(bufferAsT() + headerLength); - - if (dtype == DataType::UTF8) { - memcpy(data, str.data(), str.size()); - } - else if (dtype == DataType::UTF16) { - unicode::utf8to16(str.data(), data, str.size()); - } - else { - unicode::utf8to32(str.data(), data, str.size()); - } - - tickWriteHost(); - syncToDevice(); -} -///////////////////////////////////////////////////////////////////////// -// constructors for vector of strings -NDArray::NDArray(const std::vector& shape, const std::vector& string, const sd::DataType dataType, sd::LaunchContext* context) { - - if (!DataTypeUtils::isS(dataType)) - throw std::invalid_argument("NDArray::NDArray: invalid DataType, only string dataTypes have to be used"); - - if (shape::prodLong(shape.data(), shape.size()) != string.size()) - throw std::invalid_argument("NDArray::NDArray: Number of strings should match length of array"); - - for (const auto& str : string) { - if (!unicode::isStringValidU8(str, str + std::char_traits::length(str)) ) { - throw std::invalid_argument("NDArray::NDArray: invalid character in input string"); - } - } - - Nd4jLong headerLength = ShapeUtils::stringBufferHeaderRequirements(string.size()); - - std::vector offsets(string.size() + 1); - Nd4jLong dataLength = 0; - for (int e = 0; e < string.size(); e++) { - offsets[e] = dataLength; - dataLength += [&] { - if (dataType == DataType::UTF16) - return unicode::offsetUtf8StringInUtf16(string[e], std::char_traits::length(string[e])); - if (dataType == DataType::UTF32) - return unicode::offsetUtf8StringInUtf32(string[e], std::char_traits::length(string[e])); - return static_cast(std::char_traits::length(string[e])); - }(); - } - offsets[string.size()] = dataLength; - - _buffer = std::make_shared(headerLength + dataLength, dataType, context->getWorkspace(), true); - - _context = context; - _offset = 0; - - setShapeInfo(ShapeDescriptor(dataType, 'c', shape)); - - _isView = false; - - setAttached(context->getWorkspace() != nullptr); - - memcpy(bufferAsT(), offsets.data(), offsets.size() * sizeof(Nd4jLong)); - - auto data = reinterpret_cast(bufferAsT() + headerLength); - - auto func = PRAGMA_THREADS_FOR{ - for (auto e = start; e < stop; e++) { - auto cdata = data + offsets[e]; - if (dataType == DataType::UTF16) { - unicode::utf8to16(string[e], cdata, std::char_traits::length(string[e])); - } - else if (dataType == DataType::UTF32) { - unicode::utf8to32(string[e], cdata, std::char_traits::length(string[e])); - } - else { - memcpy(cdata, string[e], std::char_traits::length(string[e])); - } - } - }; - - samediff::Threads::parallel_for(func, 0, lengthOf(), 1); - - tickWriteHost(); - syncToDevice(); -} -///////////////////////////////////////////////////////////////////////// -NDArray::NDArray(const std::vector& shape, const std::vector& string, const sd::DataType dataType, sd::LaunchContext* context) { - - if (!DataTypeUtils::isS(dataType)) - throw std::invalid_argument("NDArray::NDArray: invalid DataType, only string dataTypes have to be used"); - - if (shape::prodLong(shape.data(), shape.size()) != string.size()) - throw std::invalid_argument("NDArray::NDArray: Number of strings should match length of array"); - - for (const auto& str : string) { - if (!unicode::isStringValidU8(str.data(), str.data() + str.size())) { - throw std::invalid_argument("NDArray::NDArray: invalid character in input string"); - } - } - - Nd4jLong headerLength = ShapeUtils::stringBufferHeaderRequirements(string.size()); - - std::vector offsets(string.size() + 1); - Nd4jLong dataLength = 0; - for (int e = 0; e < string.size(); e++) { - offsets[e] = dataLength; - dataLength += [&] { - if (dataType == DataType::UTF16) - return unicode::offsetUtf8StringInUtf16(string[e].data(), string[e].size()); - if (dataType == DataType::UTF32) - return unicode::offsetUtf8StringInUtf32(string[e].data(), string[e].size()); - return static_cast(string[e].size()); - }(); - } - - offsets[string.size()] = dataLength; - - _buffer = std::make_shared(headerLength + dataLength, dataType, context->getWorkspace(), true); - - _context = context; - _offset = 0; - - setShapeInfo(ShapeDescriptor(dataType, 'c', shape)); - - _isView = false; - - setAttached(context->getWorkspace() != nullptr); - - memcpy(bufferAsT(), offsets.data(), offsets.size() * sizeof(Nd4jLong)); - - auto data = reinterpret_cast(bufferAsT() + headerLength); - - auto func = PRAGMA_THREADS_FOR{ - for (auto e = start; e < stop; e++) { - auto cdata = data + offsets[e]; - if (dataType == DataType::UTF16) { - unicode::utf8to16(string[e].data(), cdata, string[e].size()); - } - else if (dataType == DataType::UTF32) { - unicode::utf8to32(string[e].data(), cdata, string[e].size()); - } - else { - memcpy(cdata, string[e].data(), string[e].size()); - } - } - }; - - samediff::Threads::parallel_for(func, 0, lengthOf(), 1); - - - tickWriteHost(); - syncToDevice(); -} -///////////////////////////////////////////////////////////////////////// -NDArray::NDArray(const std::vector& shape, const std::vector& string, sd::DataType dtype, sd::LaunchContext* context) { - - if (!DataTypeUtils::isS(dtype)) - throw std::invalid_argument("NDArray::NDArray: invalid DataType, only string dataTypes have to be used"); - - if (shape::prodLong(shape.data(), shape.size()) != string.size()) - throw std::invalid_argument("NDArray::NDArray: Number of strings should match length of array"); - - for (const auto& str : string) { - if (!unicode::isStringValidU16(str.data(), str.data() + str.size())) { - throw std::invalid_argument("NDArray::NDArray: invalid character in input string"); - } - } - - Nd4jLong headerLength = ShapeUtils::stringBufferHeaderRequirements(string.size()); - - std::vector offsets(string.size() + 1); - Nd4jLong dataLength = 0; - for (int e = 0; e < string.size(); e++) { - offsets[e] = dataLength; - dataLength += [&] { - if (dtype == DataType::UTF16) - return static_cast(sizeof(uint16_t) * string[e].size()); - if (dtype == DataType::UTF32) - return unicode::offsetUtf16StringInUtf32(string[e].data(), string[e].size()); - return unicode::offsetUtf16StringInUtf8(string[e].data(), string[e].size()); - }(); - } - offsets[string.size()] = dataLength; - - _buffer = std::make_shared(headerLength + dataLength, dtype, context->getWorkspace(), true); - - _context = context; - _offset = 0; - - setShapeInfo(ShapeDescriptor(dtype, 'c', shape)); - - _isView = false; - - setAttached(context->getWorkspace() != nullptr); - - memcpy(bufferAsT(), offsets.data(), offsets.size() * sizeof(Nd4jLong)); - - auto data = reinterpret_cast(bufferAsT() + headerLength); - - auto func = PRAGMA_THREADS_FOR{ - for (auto e = start; e < stop; e++) { - auto cdata = data + offsets[e]; - if (dtype == DataType::UTF16) { - memcpy(cdata, string[e].data(), string[e].size() * sizeof(uint16_t)); - } - else if (dtype == DataType::UTF32) { - unicode::utf16to32(string[e].data(), cdata, string[e].size()); - } - else { - unicode::utf16to8(string[e].data(), cdata, string[e].size()); - } - } - }; - samediff::Threads::parallel_for(func, 0, lengthOf(), 1); - - tickWriteHost(); - syncToDevice(); -} -///////////////////////////////////////////////////////////////////////// -NDArray::NDArray(const std::vector& shape, const std::vector& string, sd::DataType dtype, sd::LaunchContext* context) { - - if (!DataTypeUtils::isS(dtype)) - throw std::invalid_argument("NDArray::NDArray: invalid DataType, only string dataTypes have to be used"); - - if (shape::prodLong(shape.data(), shape.size()) != string.size()) - throw std::invalid_argument("NDArray::NDArray: Number of strings should match length of array"); - - for (const auto& str : string) { - if (!unicode::isStringValidU16(str, str + std::char_traits::length(str))) { - throw std::invalid_argument("NDArray::NDArray: invalid character in input string"); - } - } - - Nd4jLong headerLength = ShapeUtils::stringBufferHeaderRequirements(string.size()); - - std::vector offsets(string.size() + 1); - Nd4jLong dataLength = 0; - for (int e = 0; e < string.size(); e++) { - offsets[e] = dataLength; - dataLength += [&] { - if (dtype == DataType::UTF16) - return static_cast(sizeof(uint16_t) * std::char_traits::length(string[e])); - if (dtype == DataType::UTF32) - return unicode::offsetUtf16StringInUtf32(string[e], std::char_traits::length(string[e])); - return unicode::offsetUtf16StringInUtf8(string[e], std::char_traits::length(string[e])); - }(); - } - offsets[string.size()] = dataLength; - - _buffer = std::make_shared(headerLength + dataLength, dtype, context->getWorkspace(), true); - - _context = context; - _offset = 0; - - setShapeInfo(ShapeDescriptor(dtype, 'c', shape)); - - _isView = false; - - setAttached(context->getWorkspace() != nullptr); - - memcpy(bufferAsT(), offsets.data(), offsets.size() * sizeof(Nd4jLong)); - - auto data = reinterpret_cast(bufferAsT() + headerLength); - - - auto func = PRAGMA_THREADS_FOR{ - for (auto e = start; e < stop; e++) { - auto cdata = data + offsets[e]; - if (dtype == DataType::UTF16) { - memcpy(cdata, string[e], std::char_traits::length(string[e]) * sizeof(uint16_t)); - } - else if (dtype == DataType::UTF32) { - unicode::utf16to32(string[e], cdata, std::char_traits::length(string[e])); - } - else { - unicode::utf16to8(string[e], cdata, std::char_traits::length(string[e])); - } - } - }; - samediff::Threads::parallel_for(func, 0, lengthOf(), 1); - - tickWriteHost(); - syncToDevice(); -} -///////////////////////////////////////////////////////////////////////// -NDArray::NDArray(const std::vector& shape, const std::vector& string, sd::DataType dtype, sd::LaunchContext* context) { - - if (!DataTypeUtils::isS(dtype)) - throw std::invalid_argument("NDArray::NDArray: invalid DataType, only string dataTypes have to be used"); - - if (shape::prodLong(shape.data(), shape.size()) != string.size()) - throw std::invalid_argument("NDArray::NDArray: Number of strings should match length of array"); - - for (auto str : string) { - if (!unicode::isStringValidU32(str.data(), str.data() + str.size())) { - throw std::invalid_argument("NDArray::NDArray: invalid character in input string"); - } - } - - Nd4jLong headerLength = ShapeUtils::stringBufferHeaderRequirements(string.size()); - - std::vector offsets(string.size() + 1); - - Nd4jLong dataLength = 0; - for (int e = 0; e < string.size(); e++) { - offsets[e] = dataLength; - dataLength += [&] { - if (dtype == DataType::UTF16) - return unicode::offsetUtf32StringInUtf16(string[e].data(), string[e].size()); - if (dtype == DataType::UTF32) - return static_cast(sizeof(uint32_t) * string[e].size()); - return unicode::offsetUtf32StringInUtf16(string[e].data(), string[e].size()); - }(); - } - offsets[string.size()] = dataLength; - - _buffer = std::make_shared(headerLength + dataLength, dtype, context->getWorkspace(), true); - - _context = context; - _offset = 0; - - setShapeInfo(ShapeDescriptor(dtype, 'c', shape)); - - _isView = false; - - setAttached(context->getWorkspace() != nullptr); - - memcpy(bufferAsT(), offsets.data(), offsets.size() * sizeof(Nd4jLong)); - - auto data = reinterpret_cast(bufferAsT() + headerLength); - - auto func = PRAGMA_THREADS_FOR{ - for (auto e = start; e < stop; e++) { - auto cdata = data + offsets[e]; - if (dtype == DataType::UTF16) { - unicode::utf32to16(string[e].data(), cdata, string[e].size()); - } - else if (dtype == DataType::UTF32) { - memcpy(cdata, string[e].data(), string[e].size() * sizeof(uint32_t)); - } - else { - unicode::utf32to8(string[e].data(), cdata, string[e].size()); - } - } - }; - samediff::Threads::parallel_for(func, 0, lengthOf(), 1); - - tickWriteHost(); - syncToDevice(); -} -///////////////////////////////////////////////////////////////////////// -NDArray::NDArray(const std::vector& shape, const std::vector& string, sd::DataType dtype, sd::LaunchContext* context) { - - if (!DataTypeUtils::isS(dtype)) - throw std::invalid_argument("NDArray::NDArray: invalid DataType used"); - - if (shape::prodLong(shape.data(), shape.size()) != string.size()) - throw std::invalid_argument("NDArray::NDArray: Number of strings should match length of array"); - - for (const auto& str : string) { - if (!unicode::isStringValidU32(str, str + std::char_traits::length(str))) { - throw std::invalid_argument("NDArray::NDArray: invalid character in input string"); - } - } - - Nd4jLong headerLength = ShapeUtils::stringBufferHeaderRequirements(string.size()); - - std::vector offsets(string.size() + 1); - - Nd4jLong dataLength = 0; - for (int e = 0; e < string.size(); e++) { - offsets[e] = dataLength; - dataLength += [&] { - if (dtype == DataType::UTF16) - return unicode::offsetUtf32StringInUtf16(string[e], std::char_traits::length(string[e])); - if (dtype == DataType::UTF32) - return static_cast(sizeof(uint32_t) * std::char_traits::length(string[e])); - return unicode::offsetUtf32StringInUtf16(string[e], std::char_traits::length(string[e])); - }(); - } - offsets[string.size()] = dataLength; - - _buffer = std::make_shared(headerLength + dataLength, dtype, context->getWorkspace(), true); - - _context = context; - _offset = 0; - - setShapeInfo(ShapeDescriptor(dtype, 'c', shape)); - - _isView = _length * DataTypeUtils::sizeOf(_dataType) < _buffer->getLenInBytes(); - - setAttached(context->getWorkspace() != nullptr); - - memcpy(bufferAsT(), offsets.data(), offsets.size() * sizeof(Nd4jLong)); - - auto data = reinterpret_cast(bufferAsT() + headerLength); - - auto func = PRAGMA_THREADS_FOR{ - for (auto e = start; e < stop; e++) { - auto cdata = data + offsets[e]; - if (dtype == DataType::UTF16) { - unicode::utf32to16(string[e], cdata, std::char_traits::length(string[e])); - } - else if (dtype == DataType::UTF32) { - memcpy(cdata, string[e], std::char_traits::length(string[e]) * sizeof(uint32_t)); - } - else { - unicode::utf32to8(string[e], cdata, std::char_traits::length(string[e])); - } - } - }; - samediff::Threads::parallel_for(func, 0, lengthOf(), 1); - - tickWriteHost(); - syncToDevice(); -} - -//////////////////////////////////////////////////////////////////////// -// assignment operator - NDArray& NDArray::operator=(const NDArray& other) { - - if (this == &other || (_shapeInfo == other._shapeInfo && _shapeInfo == nullptr)) - return *this; - - if (_shapeInfo != nullptr && shape::equalsTypesAndShapesSoft(_shapeInfo, other._shapeInfo)) { - if(!other.isEmpty()) - this->assign(&other); - } - else { - _context = other._context; - _offset = 0; - setShapeInfo(ShapeDescriptor(other.dataType(), other.ordering(), other.shapeOf(), other.rankOf())); - - if(!other.isEmpty()) { - _buffer = std::make_shared(other.lengthOf() * other.sizeOfT(), other.dataType(), other.getContext()->getWorkspace()); - this->assign(&other); - } - else - _buffer = std::make_shared(); - } - return *this; -} - - -////////////////////////////////////////////////////////////////////////// -bool NDArray::isC() const { - // TODO: this method must be implemented once we add support for complex numbers - return false; -} - -////////////////////////////////////////////////////////////////////////// -bool NDArray::isS() const { - return (dataType() == DataType::UTF8 || - dataType() == DataType::UTF16 || - dataType() == DataType::UTF32); -} - -////////////////////////////////////////////////////////////////////////// -bool NDArray::isR() const { - auto xType = ArrayOptions::dataType(this->_shapeInfo); - return xType == FLOAT32 || xType == HALF || xType == DOUBLE || xType == FLOAT8 || xType == BFLOAT16; -} - -////////////////////////////////////////////////////////////////////////// -bool NDArray::isZ() const { - // TODO: decide if we really want to exclude Bool here - return !isC() && !isR() && !isB() && !isS(); -} - -////////////////////////////////////////////////////////////////////////// -bool NDArray::isB() const { - return ArrayOptions::dataType(this->_shapeInfo) == BOOL; -} - -////////////////////////////////////////////////////////////////////////// -template -std::string NDArray::toStringValue(T value) { - std::ostringstream os ; - //throw the value into the string stream - os << value ; - //convert the string stream into a string and return - return os.str() ; -} - -////////////////////////////////////////////////////////////////////////// -template<> -std::string NDArray::toStringValue(float16 value) { - std::ostringstream os ; - //throw the value into the string stream - os << (float) value ; - //convert the string stream into a string and return - return os.str() ; -} - -////////////////////////////////////////////////////////////////////////// -template<> -std::string NDArray::toStringValue(bfloat16 value) { - std::ostringstream os ; - //throw the value into the string stream - os << (float) value ; - //convert the string stream into a string and return - return os.str() ; -} - -////////////////////////////////////////////////////////////////////////// -std::string NDArray::asIndexedString(Nd4jLong limit) { - std::ostringstream os; - os << "["; - if (limit < 1 || limit > this->lengthOf()) - limit = this->lengthOf(); - for (Nd4jLong e = 0; e < limit; e++) { - os << toStringValue(this->e(e)); - if (e < limit - 1) - os << ", "; - } - os << "]"; - return os.str(); -} - -////////////////////////////////////////////////////////////////////////// -std::string NDArray::asString(Nd4jLong limit) { - std::ostringstream os; - os << "["; - if (limit < 1 || limit > this->lengthOf()) - limit = this->lengthOf(); - for (Nd4jLong e = 0; e < limit; e++) { - if (this->isR()) - os << toStringValue(this->e(e)); - else if (this->isZ()) - os << toStringValue(this->e(e)); - else if (this->isB()) - os << toStringValue(this->e(e)); - else if (this->isS()) // todo add utf16 and utf32 - os << this->e(e); - if (e < limit - 1) - os << ", "; - } - os << "]"; - return os.str(); -} - -//////////////////////////////////////////////////////////////////////// -template -std::vector NDArray::getBufferAsVector() const { - std::vector vector(lengthOf()); - for (Nd4jLong e = 0; e < lengthOf(); e++) - vector[e] = this->e(e); - return vector; -} -BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT std::vector, NDArray::getBufferAsVector() const, LIBND4J_TYPES); - -//////////////////////////////////////////////////////////////////////// -std::vector NDArray::getShapeAsFlatVector() const { - std::vector vector(this->rankOf()); - for (int e = 0; e < this->rankOf(); e++) - vector[e] = static_cast(this->sizeAt(e)); - return vector; -} - -//////////////////////////////////////////////////////////////////////// -std::vector NDArray::getShapeAsVector() const { - - std::vector vector(this->rankOf()); - for (int e = 0; e < this->rankOf(); e++) - vector[e] = this->sizeAt(e); - - return vector; -} - -//////////////////////////////////////////////////////////////////////// -std::vector NDArray::getShapeAsVectorInt() const { - - std::vector vector(this->rankOf()); - for (int e = 0; e < this->rankOf(); e++) - vector[e] = static_cast(this->sizeAt(e)); - - return vector; -} - -//////////////////////////////////////////////////////////////////////// -std::vector NDArray::getShapeInfoAsFlatVector() const { - int magicNumber = shape::shapeInfoLength(this->rankOf()); - std::vector vector(magicNumber); - - for (int e = 0; e < magicNumber; e++) - vector[e] = static_cast(_shapeInfo[e]); - - return vector; -} - -//////////////////////////////////////////////////////////////////////// -std::vector NDArray::getShapeInfoAsVector() const { - int magicNumber = shape::shapeInfoLength(this->rankOf()); - std::vector vector(magicNumber); - for (int e = 0; e < magicNumber; e++) - vector[e] = this->_shapeInfo[e]; - return vector; -} - -//////////////////////////////////////////////////////////////////////// -std::vector NDArray::asByteVector() { - - if (isS()) { - // string data type requires special treatment - syncToHost(); - auto numWords = this->lengthOf(); - auto offsetsBuffer = this->bufferAsT(); - auto headerLength = ShapeUtils::stringBufferHeaderRequirements(numWords); - auto dataLength = offsetsBuffer[numWords]; - std::vector result(headerLength + dataLength); - - memcpy(result.data(), buffer(), headerLength + dataLength); - - return result; - } else { - // all other types are linear - std::vector result((unsigned long long) this->lengthOf() * sizeOfT()); - - if (this->isView()) { - auto tmp = this->dup(this->ordering()); - syncToHost(); - memcpy(result.data(), tmp.buffer(), (unsigned long long) lengthOf() * sizeOfT()); - } else { - syncToHost(); - memcpy(result.data(), buffer(), (unsigned long long) lengthOf() * sizeOfT()); - } - return result; - } -} - -////////////////////////////////////////////////////////////////////////// -void NDArray::linspace(const double start) { - linspace(start, 1); -} - -////////////////////////////////////////////////////////////////////////// -void NDArray::linspace(const double start, const double step) { - if (isS()) - throw std::runtime_error("NDArray::linspace: you can't use this method on String array!"); - Nd4jLong numElements = this->lengthOf(); - for (Nd4jLong e = 0; e < numElements; e++) - this->p(e, start + (step * e)); -} - -//////////////////////////////////////////////////////////////////////// -void NDArray::streamline(char o) { - char order = o == 'a' ? this->ordering() : o; - syncToDevice(); - std::shared_ptr newBuffer = std::make_shared(this->lengthOf() * sizeOfT(), dataType(), getContext()->getWorkspace()); - auto shapeBuffer = ConstantShapeHelper::getInstance().bufferForShapeInfo(dataType(), order, rankOf(), shapeOf()); - NativeOpExecutioner::execTransformSame(getContext(), transform::Copy, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), newBuffer->primary(), - shapeBuffer.primary(), newBuffer->special(), - shapeBuffer.special(), nullptr, nullptr, nullptr); - setShapeInfo(shapeBuffer); - _buffer = newBuffer; - _offset = 0; - tickWriteDevice(); -} - -//////////////////////////////////////////////////////////////////////// -// move assignment operator -NDArray& NDArray::operator=(NDArray&& other) noexcept { - if (this == &other) - return *this; - - _isView = other._isView; - _buffer = other._buffer; - _shapeInfo = other._shapeInfo; - _shapeInfoD = other._shapeInfoD; - _context = other._context; - _dataType = other._dataType; - _length = other._length; - _offset = other._offset; - - other._buffer = std::make_shared(); - other._shapeInfo = other._shapeInfoD = nullptr; - other._length = 0; - - return *this; -} - -//////////////////////////////////////////////////////////////////////// -template -NDArray& NDArray::operator=(const T scalar) { - this->assign(scalar); - return *this; -} -template ND4J_EXPORT NDArray& NDArray::operator=(const double scalar); -template ND4J_EXPORT NDArray& NDArray::operator=(const float scalar); -template ND4J_EXPORT NDArray& NDArray::operator=(const float16 scalar); -template ND4J_EXPORT NDArray& NDArray::operator=(const bfloat16 scalar); -template ND4J_EXPORT NDArray& NDArray::operator=(const Nd4jLong scalar); -template ND4J_EXPORT NDArray& NDArray::operator=(const int scalar); -template ND4J_EXPORT NDArray& NDArray::operator=(const int8_t scalar); -template ND4J_EXPORT NDArray& NDArray::operator=(const uint8_t scalar); -template ND4J_EXPORT NDArray& NDArray::operator=(const uint16_t scalar); -template ND4J_EXPORT NDArray& NDArray::operator=(const uint32_t scalar); -template ND4J_EXPORT NDArray& NDArray::operator=(const uint64_t scalar); -template ND4J_EXPORT NDArray& NDArray::operator=(const int16_t scalar); -template ND4J_EXPORT NDArray& NDArray::operator=(const bool scalar); - -////////////////////////////////////////////////////////////////////////// -void NDArray::copyBuffersContinuouslyFrom(const NDArray& other, size_t sizeToCopyInBytes, Nd4jLong offsetThis, Nd4jLong offsetOther) { - - if(offsetThis == 0) - offsetThis = bufferOffset(); - if(offsetOther == 0) - offsetOther = other.bufferOffset(); - - dataBuffer()->copyBufferFrom(*other.getDataBuffer(), sizeToCopyInBytes, offsetThis, offsetOther); -} - -//////////////////////////////////////////////////////////////////// -// This method assigns values of given NDArray to this one -void NDArray::assign(const NDArray& other, bool allowParallelism) { - - if (this == &other) - return; - - if (other.isEmpty()) { - if (!isEmpty()) { - throw std::runtime_error("Cannot assign empty array to non-empty array"); - } - return; - } - - if(isEmpty()) { - *this = other; - return; - } - - if (other.lengthOf() == 1) { - - if(lengthOf() == 1) { - NDArray::preparePrimaryUse({this}, {&other}); - BUILD_DOUBLE_SELECTOR(dataType(), other.dataType(), templatedDoubleAssign, (buffer(), 0, other.buffer(), 0), LIBND4J_TYPES, LIBND4J_TYPES); - NDArray::registerPrimaryUse({this}, {&other}); - this->syncToDevice(); - } - else { - if (dataType() != other.dataType()) { - auto tmp = other.cast(dataType()); - NDArray::prepareSpecialUse({this}, {&tmp}); - NativeOpExecutioner::execScalar(getContext(), scalar::CopyPws, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr, allowParallelism); - NDArray::registerSpecialUse({this}, {}); - } - else { - NDArray::prepareSpecialUse({this}, {&other}); - NativeOpExecutioner::execScalar(getContext(), scalar::CopyPws, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), nullptr, allowParallelism); - NDArray::registerSpecialUse({this}, {&other}); - } - } - } - else { - if (other.lengthOf() != lengthOf()) { - auto shapeThis = ShapeUtils::shapeAsString(this); - auto shapeThat = ShapeUtils::shapeAsString(&other); - nd4j_printf("Can't assign array: this shape %s; other shape: %s\n", shapeThis.c_str(), shapeThat.c_str()); - throw std::runtime_error("NDArray::assign: lengths of arrays are mismatched"); - } - - NDArray::prepareSpecialUse({this}, {&other}); - NativeOpExecutioner::execTransformAny(getContext(), transform::Assign, other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr, nullptr, nullptr, allowParallelism); - NDArray::registerSpecialUse({this}, {&other}); - } -} - -////////////////////////////////////////////////////////////////////////// -// This method assigns values of given NDArray to this one, wrt order -void NDArray::assign(const NDArray *other, bool allowParallelism) { - assign(*other, allowParallelism); -} - -////////////////////////////////////////////////////////////////////////// -template -void NDArray::assign(const T& value, bool allowParallelism) { - // just fire scalar - auto temp = NDArrayFactory::create(dataType(), value, this->getContext()); - - NDArray::prepareSpecialUse({this}, {&temp}); - NativeOpExecutioner::execScalar(getContext(), sd::scalar::CopyPws, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.specialShapeInfo(), nullptr, allowParallelism); - NDArray::registerSpecialUse({this}, {&temp}); -} -template ND4J_EXPORT void NDArray::assign(const double& value, bool allowParallelism); -template ND4J_EXPORT void NDArray::assign(const float& value, bool allowParallelism); -template ND4J_EXPORT void NDArray::assign(const float16& value, bool allowParallelism); -template ND4J_EXPORT void NDArray::assign(const bfloat16& value, bool allowParallelism); -template ND4J_EXPORT void NDArray::assign(const Nd4jLong& value, bool allowParallelism); -template ND4J_EXPORT void NDArray::assign(const int& value, bool allowParallelism); -template ND4J_EXPORT void NDArray::assign(const int8_t& value, bool allowParallelism); -template ND4J_EXPORT void NDArray::assign(const int16_t& value, bool allowParallelism); -template ND4J_EXPORT void NDArray::assign(const uint8_t& value, bool allowParallelism); -template ND4J_EXPORT void NDArray::assign(const uint16_t& value, bool allowParallelism); -template ND4J_EXPORT void NDArray::assign(const uint32_t& value, bool allowParallelism); -template ND4J_EXPORT void NDArray::assign(const uint64_t& value, bool allowParallelism); -template ND4J_EXPORT void NDArray::assign(const bool& value, bool allowParallelism); - -////////////////////////////////////////////////////////////////////////// -NDArray* NDArray::detach() { - - if (!isAttached()) - return this; - - std::shared_ptr newBuffer = std::make_shared(lengthOf() * sizeOfT(), dataType()); - - auto result = new NDArray(newBuffer, ShapeDescriptor(dataType(), ordering(), shapeOf(), rankOf())); - - result->assign(*this); - - return result; -} - -////////////////////////////////////////////////////////////////////////// -NDArray NDArray::varianceNumber(sd::variance::Ops op, bool biasCorrected) { - - NDArray res(DataTypeUtils::pickFloatingType(dataType()), getContext()); - - NDArray::prepareSpecialUse({&res}, {this}); - NativeOpExecutioner::execSummaryStatsScalar(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr, res.buffer(), res.shapeInfo(), res.specialBuffer(), res.specialShapeInfo(), biasCorrected); - NDArray::registerSpecialUse({&res}, {this}); - - return res; -} - -////////////////////////////////////////////////////////////////////////// -// This method returns sum of all elements of this NDArray -NDArray NDArray::sumNumber() const { - if (isS()) - throw std::runtime_error("NDArray::sumNumber: you can't use this method on String array!"); - NDArray res(dataType(), getContext()); - - NDArray::prepareSpecialUse({&res}, {this}); - NativeOpExecutioner::execReduceSameScalar(getContext(), sd::reduce::SameOps::Sum, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr, res.buffer(), res.shapeInfo(), res.specialBuffer(), res.specialShapeInfo()); - NDArray::registerSpecialUse({&res}, {this}); - - return res; -} - -////////////////////////////////////////////////////////////////////////// -// This method returns mean number of this NDArray -NDArray NDArray::meanNumber() const { - - if (isS()) - throw std::runtime_error("NDArray::meanNumber: you can't use this method on String array!"); - NDArray res(DataTypeUtils::pickFloatingType(dataType()), getContext()); - - NDArray::prepareSpecialUse({&res}, {this}); - NativeOpExecutioner::execReduceFloatScalar(getContext(), sd::reduce::FloatOps::Mean, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr, res.buffer(), res.shapeInfo(), res.specialBuffer(), res.specialShapeInfo()); - NDArray::registerSpecialUse({&res}, {this}); - return res; -} - -////////////////////////////////////////////////////////////////////////// -bool NDArray::hasNaNs() { - if (isS()) - throw std::runtime_error("NDArray::hasNaNs: you can't use this method on String array!"); - return this->reduceNumber(sd::reduce::IsNan, nullptr).e(0) > 0; -} - -////////////////////////////////////////////////////////////////////////// -bool NDArray::hasInfs() { - if (isS()) - throw std::runtime_error("NDArray::hasInfs: you can't use this method on String array!"); - return this->reduceNumber(sd::reduce::IsInf, nullptr).e(0) > 0; -} - -////////////////////////////////////////////////////////////////////////// -bool NDArray::isFinite() { - if (isS()) - throw std::runtime_error("NDArray::isFinite: you can't use this method on String array!"); - return this->reduceNumber(sd::reduce::IsInfOrNan, nullptr).e(0) == 0; -} - -////////////////////////////////////////////////////////////////////////// -template -void NDArray::templatedSet(void *buffer, const Nd4jLong *indices, const void *value) { - auto t = reinterpret_cast(buffer); - const auto y = *(reinterpret_cast(value)); - - auto xOffset = shape::getOffset(shapeInfo(), indices); - t[xOffset] = static_cast(y); -} -BUILD_DOUBLE_TEMPLATE(template ND4J_EXPORT void NDArray::templatedSet, (void *buffer, const Nd4jLong *indices, const void *value), LIBND4J_TYPES, LIBND4J_TYPES); - -////////////////////////////////////////////////////////////////////////// -template -void NDArray::templatedSet(void *buffer, const Nd4jLong offset, const void *value) { - auto t = reinterpret_cast(buffer); - const auto y = *(reinterpret_cast(value)); - - t[offset] = static_cast(y); -} -BUILD_DOUBLE_TEMPLATE(template ND4J_EXPORT void NDArray::templatedSet, (void *buffer, const Nd4jLong offset, const void *value), LIBND4J_TYPES, LIBND4J_TYPES); - -////////////////////////////////////////////////////////////////////////// -void NDArray::setContext(sd::LaunchContext *context) { - - _context = context; - if (getContext() == nullptr) - _context = sd::LaunchContext ::defaultContext(); // empty context for default cases -} - -////////////////////////////////////////////////////////////////////////// -void const* NDArray::bufferWithOffset(Nd4jLong offset) const { - return const_cast(buffer() != nullptr ? static_cast(buffer()) + (offset * sizeOfT()) : nullptr); -} - -////////////////////////////////////////////////////////////////////////// -void* NDArray::bufferWithOffset(Nd4jLong offset) { - return const_cast(buffer() != nullptr ? static_cast(buffer()) + (offset * sizeOfT()) : nullptr); -} - -////////////////////////////////////////////////////////////////////////// -// eventually method reduces array by excluding its shapes along axes present in dimensions vector -NDArray NDArray::reduceAlongDimension(sd::reduce::FloatOps op, const std::vector& dimensions, const bool keepDims, const bool supportOldShapes) const { - - std::vector copy(dimensions); - - auto newShape = ShapeUtils::evalReduceShapeInfo('c', copy, *this, isR() ? dataType() : Environment::getInstance().defaultFloatDataType(), keepDims, supportOldShapes, getContext()->getWorkspace()); - - NDArray result(newShape, true, getContext()); - - this->reduceAlongDimension(op, result, copy, keepDims, supportOldShapes, false); - - return result; -} - -////////////////////////////////////////////////////////////////////////// -NDArray NDArray::reduceAlongDimension(sd::reduce::SameOps op, const std::vector& dimensions, const bool keepDims, const bool supportOldShapes) const { - - std::vector copy(dimensions); - - auto newShape = ShapeUtils::evalReduceShapeInfo('c', copy, *this, keepDims, supportOldShapes, getContext()->getWorkspace()); - - NDArray result(newShape, true, getContext()); - - reduceAlongDimension(op, result, copy, keepDims, supportOldShapes, false); - - return result; -} - -////////////////////////////////////////////////////////////////////////// -NDArray NDArray::reduceAlongDimension(sd::reduce::BoolOps op, const std::vector& dimensions, const bool keepDims, const bool supportOldShapes) const { - - std::vector copy(dimensions); - - auto newShape = ShapeUtils::evalReduceShapeInfo('c', copy, *this, DataType::BOOL, keepDims, supportOldShapes, getContext()->getWorkspace()); - - NDArray result(newShape, true, getContext()); - - reduceAlongDimension(op, result, copy, keepDims, supportOldShapes, false); - - return result; -} - -////////////////////////////////////////////////////////////////////////// -NDArray NDArray::reduceAlongDimension(sd::reduce::LongOps op, const std::vector& dimensions, const bool keepDims, const bool supportOldShapes) const { - - std::vector copy(dimensions); - - auto newShape = ShapeUtils::evalReduceShapeInfo('c', copy, *this, DataType::INT64, keepDims, supportOldShapes, getContext()->getWorkspace()); - - NDArray result(newShape, true, getContext()); - - reduceAlongDimension(op, result, copy, keepDims, supportOldShapes, false); - - return result; -} - -////////////////////////////////////////////////////////////////////////// -// method reduces array by excluding its shapes along axes present in dimensions vector -NDArray NDArray::reduceAlongDimension(sd::reduce::FloatOps op, const std::initializer_list& dimensions, const bool keepDims, const bool supportOldShapes) const { - return reduceAlongDimension(op, std::vector(dimensions), keepDims, supportOldShapes); -} - -////////////////////////////////////////////////////////////////////////// -NDArray NDArray::reduceAlongDimension(sd::reduce::SameOps op, const std::initializer_list& dimensions, const bool keepDims, const bool supportOldShapes) const { - return reduceAlongDimension(op, std::vector(dimensions), keepDims, supportOldShapes); -} - -////////////////////////////////////////////////////////////////////////// -NDArray NDArray::reduceAlongDimension(sd::reduce::BoolOps op, const std::initializer_list& dimensions, const bool keepDims, const bool supportOldShapes) const { - return reduceAlongDimension(op, std::vector(dimensions), keepDims, supportOldShapes); -} - -////////////////////////////////////////////////////////////////////////// -NDArray NDArray::reduceAlongDimension(sd::reduce::LongOps op, const std::initializer_list& dimensions, const bool keepDims, const bool supportOldShapes) const { - return reduceAlongDimension(op, std::vector(dimensions), keepDims, supportOldShapes); -} - -////////////////////////////////////////////////////////////////////////// -NDArray NDArray::reduceNumber(sd::reduce::FloatOps op, void *extraParams) const { - if (isS()) - throw std::runtime_error("NDArray::reduceNumber FloatOps: you can't use this method on String array!"); - - auto shape = ConstantShapeHelper::getInstance().scalarShapeInfo(DataTypeUtils::pickFloatingType(dataType())); - NDArray result(shape, true, this->getContext()); - - NDArray::prepareSpecialUse({&result}, {this}); - NativeOpExecutioner::execReduceFloatScalar(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), extraParams, result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo()); - NDArray::registerSpecialUse({&result}, {this}); - - return result; -} - -////////////////////////////////////////////////////////////////////////// -NDArray NDArray::reduceNumber(sd::reduce::SameOps op, void *extraParams) const { - if (isS()) - throw std::runtime_error("NDArray::reduceNumber SameOps: you can't use this method on String array!"); - - NDArray result(dataType(), getContext()); - - NDArray::prepareSpecialUse({&result}, {this}); - NativeOpExecutioner::execReduceSameScalar(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), extraParams, result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo()); - NDArray::registerSpecialUse({&result}, {this}); - - return result; -} - -////////////////////////////////////////////////////////////////////////// -NDArray NDArray::reduceNumber(sd::reduce::BoolOps op, void *extraParams) const { - if (isS()) - throw std::runtime_error("NDArray::reduceNumber BoolOps: you can't use this method on String array!"); - - auto shape = ConstantShapeHelper::getInstance().scalarShapeInfo(DataType::BOOL); - NDArray result(shape, true, this->getContext()); - - NDArray::prepareSpecialUse({&result}, {this}); - NativeOpExecutioner::execReduceBoolScalar(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), extraParams, result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo()); - NDArray::registerSpecialUse({&result}, {this}); - - return result; -} - -////////////////////////////////////////////////////////////////////////// -NDArray NDArray::reduceNumber(sd::reduce::LongOps op, void *extraParams) const { - if (isS()) - throw std::runtime_error("NDArray::reduceNumber LongOps: you can't use this method on String array!"); - - auto shape = ConstantShapeHelper::getInstance().scalarShapeInfo(DataType::INT64); - NDArray result(shape, true, this->getContext()); - - NDArray::prepareSpecialUse({&result}, {this}); - NativeOpExecutioner::execReduceLongScalar(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), extraParams, result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo()); - NDArray::registerSpecialUse({&result}, {this}); - - return result; -} - -////////////////////////////////////////////////////////////////////////// -void NDArray::reduceNumber(sd::reduce::FloatOps op, NDArray& target, void *extraParams) const { - if (isS()) - throw std::runtime_error("NDArray::reduceNumber FloatOps: you can't use this method on String array!"); - if(target.lengthOf() != 1 || target.dataType() != DataTypeUtils::pickFloatingType(dataType())) - throw std::invalid_argument("NDArray::reduceNumber FloatOps: target array should be scalar and have corresponding float type!"); - - NDArray::prepareSpecialUse({&target}, {this}); - NativeOpExecutioner::execReduceFloatScalar(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), extraParams, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo()); - NDArray::registerSpecialUse({&target}, {this}); -} - -////////////////////////////////////////////////////////////////////////// -void NDArray::reduceNumber(sd::reduce::SameOps op, NDArray& target, void *extraParams) const { - - if (isS()) - throw std::runtime_error("NDArray::reduceNumber SameOps: you can't use this method on String array!"); - if(target.lengthOf() != 1 || target.dataType() != dataType()) - throw std::invalid_argument("NDArray::reduceNumber SameOps: target array should be scalar and have same type as this array!"); - - NDArray::prepareSpecialUse({&target}, {this}); - NativeOpExecutioner::execReduceSameScalar(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), extraParams, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo()); - NDArray::registerSpecialUse({&target}, {this}); -} - -////////////////////////////////////////////////////////////////////////// -void NDArray::reduceNumber(sd::reduce::BoolOps op, NDArray& target, void *extraParams) const { - - if (isS()) - throw std::runtime_error("NDArray::reduceNumber BoolOps: you can't use this method on String array!"); - if(target.lengthOf() != 1 || target.dataType() != DataType::BOOL) - throw std::invalid_argument("NDArray::reduceNumber BoolOps: target array should be scalar and have bool type!"); - - NDArray::prepareSpecialUse({&target}, {this}); - NativeOpExecutioner::execReduceBoolScalar(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), extraParams, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo()); - NDArray::registerSpecialUse({&target}, {this}); -} - -////////////////////////////////////////////////////////////////////////// -void NDArray::reduceNumber(sd::reduce::LongOps op, NDArray& target, void *extraParams) const { - - if (isS()) - throw std::runtime_error("NDArray::reduceNumber LongOps: you can't use this method on String array!"); - if(target.lengthOf() != 1 || target.dataType() != DataType::INT64) - throw std::invalid_argument("NDArray::reduceNumber LongOps: target array should be scalar and have long type!"); - - NDArray::prepareSpecialUse({&target}, {this}); - NativeOpExecutioner::execReduceLongScalar(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), extraParams, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo()); - NDArray::registerSpecialUse({&target}, {this}); -} - -////////////////////////////////////////////////////////////////////////// -NDArray NDArray::indexReduceNumber(sd::indexreduce::Ops op, ExtraArguments *extraParams) { - if (isS()) - throw std::runtime_error("NDArray::indexReduceNumber: you can't use this method on String array!"); - - auto res = NDArrayFactory::create(0); - - NDArray::NDArray::prepareSpecialUse({&res}, {this}); - NativeOpExecutioner::execIndexReduceScalar(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), extraParams == nullptr ? nullptr : extraParams->argumentsAsT(this->dataType()), res.buffer(), res.shapeInfo(), res.specialBuffer(), res.specialShapeInfo()); - NDArray::NDArray::registerSpecialUse({&res}, {this}); - - return res; -} - -////////////////////////////////////////////////////////////////////////// -Nd4jLong NDArray::tensorsAlongDimension(std::initializer_list dimensions) const { - return tensorsAlongDimension(std::vector(dimensions)); -} - -////////////////////////////////////////////////////////////////////////// -Nd4jLong NDArray::tensorsAlongDimension(const std::vector& dimensions) const { - std::vector copy(dimensions); - shape::checkDimensions(rankOf(), copy); - - Nd4jLong tadLength = shape::tadLength(this->_shapeInfo, copy.data(), copy.size()); - Nd4jLong numTads = this->lengthOf() / tadLength; - - return numTads; -} - -////////////////////////////////////////////////////////////////////////// -void NDArray::printShapeInfo(const char * msg) const { - - int rank = shape::rank(_shapeInfo); - int lim = shape::shapeInfoLength(rank); - - if(msg != nullptr) - printf("shapeInfo %s: [", msg); - else - printf("shapeInfo: ["); - - printf("%i, ", rank); - for (int i = 1; i < shape::shapeInfoLength(rank) - 3; i++){ - if(i == rank + 1) - printf(" "); - printf("%lld,", _shapeInfo[i]); - } - printf(" %lld,", shape::type(_shapeInfo)); - printf("%lld,", shape::elementWiseStride(_shapeInfo)); - printf("%lld]\n", (Nd4jLong)shape::order(_shapeInfo)); - - fflush(stdout); -} - -////////////////////////////////////////////////////////////////////////// -void NDArray::printBuffer(const char* msg, Nd4jLong limit, const bool sync) const{ - if (sync) - syncToHost(); - - if (limit == -1) - limit = (int) this->lengthOf(); - - if (msg != nullptr) - printf("%s: [", msg); - else - printf("["); - if (this->isR()) { - for (Nd4jLong e = 0; e < limit; e++) { - if (e) - printf(", "); - printf("%f", this->e(e)); - } - } - else if (this->isZ()) { - for (Nd4jLong e = 0; e < limit; e++) { - if (this->dataType() != sd::DataType::INT64 && this->dataType() != sd::DataType::UINT64) - printf("%d", this->e(e)); - else - printf("%llu", this->e(e)); - if (e < limit - 1) - printf(", "); - } - } - else if (this->isB()) { - for (Nd4jLong e = 0; e < limit; e++) { - if (this->e(e)) - printf("true"); - else - printf("false"); - if (e < limit - 1) - printf(", "); - } - } - else if (this->isS()) { - // todo do we need this print offsets - /* - for (Nd4jLong e = 0; e < limit; e++) { - printf("\"%lld\"", this->getOffset(e)); - if (e < limit - 1) - printf(", "); - } - printf("]\n["); - */ - for (Nd4jLong e = 0; e < limit; e++) { - printf("\"%s\"", this->e(e).c_str()); - if (e < limit - 1) - printf(", "); - } - } - printf("]\n"); - fflush(stdout); -} - -////////////////////////////////////////////////////////////////////////// -// print element by element consequently in a way they (elements) are stored in physical memory -void NDArray::printLinearBuffer() const { - - syncToHost(); - - const auto ews = this->ews() > 0 ? this->ews() : 1; - const auto len = this->lengthOf(); - - printf("["); - - if (this->dataType() == sd::DataType::INT32) { - for(Nd4jLong e = 0; e < len; e++) - printf("%d, ", this->bufferAsT()[e * ews]); - } - else if(this->dataType() == sd::DataType::INT64) { - for(Nd4jLong e = 0; e < len; e++) - printf("%lld, ", this->bufferAsT()[e * ews]); - } - else if(this->dataType() == sd::DataType::FLOAT32) { - for(Nd4jLong e = 0; e < len; e++) - printf("%.8f, ", this->bufferAsT()[e * ews]); - } - else if(this->dataType() == sd::DataType::DOUBLE) { - for(Nd4jLong e = 0; e < len; e++) - printf("%.8f, ", this->bufferAsT()[e * ews]); - } - else - throw std::invalid_argument("NDArray::printLinearBuffer: not implemented yet for this data type !"); - - printf("]\n"); - fflush(stdout); -} -////////////////////////////////////////////////////////////////////////// -static void printFormatted(NDArray const* arr, int depth, int limit) { - - if (arr->rankOf() == 1) { - printf("[ "); - for (Nd4jLong i = 0; i < arr->lengthOf(); ++i) { - if (arr->isR()) - printf("%f, ", arr->e(i)); - else if (arr->isZ()) - printf("%lld, ", arr->e(i)); - else if (arr->isB()) - printf("%s, ", arr->e(i)?"true":"false"); - else if (arr->isS()) { - printf("\"%s\", ", arr->e(i).c_str()); - } - } - printf("]\n"); - } - else if (arr->rankOf() == 2) { - Nd4jLong rows = arr->rows(); - Nd4jLong cols = arr->columns(); - char* padding = new char[depth + 1]; - memset(padding, ' ', depth); - padding[depth] = 0; - printf("["); - for (Nd4jLong row = 0; row < rows; ++row) { - if (row && depth > 0) - printf("%s", padding); - printf("["); - Nd4jLong colLimit = cols > limit?cols:limit; - for (Nd4jLong col = 0; col < colLimit; ++col) { - if (col) - printf(", "); - if (arr->isR()) - printf("%f", arr->e(row, col)); - else if (arr->isZ()) - printf("%lld", arr->e(row, col)); - else if (arr->isB()) - printf("%s", arr->e(row, col)?"true":"false"); - else if (arr->isS()) { - printf("\"%s\"", arr->e(row * cols + col).c_str()); - } - } - if (row < rows - 1) - printf("]\n"); - else - printf("]"); - } - printf("]"); - if (padding) - delete [] padding; - } - else { - //std::unique_ptr arrs(arr->allTensorsAlongDimension({0})); - size_t restCount = 2; - printf("["); - restCount = ShapeUtils::getNumOfSubArrs(arr->shapeInfo(), {0}); - for (size_t arrIndex = 0; arrIndex < restCount; ++arrIndex) { - NDArray subArr = (*arr)(arrIndex, {0}); - printFormatted(&subArr, depth + 1, limit); - if (arrIndex < restCount - 1) { - for (Nd4jLong i = 1; i < arr->rankOf(); ++i) - printf("\n"); - for (Nd4jLong i = 0; i < depth - 2; ++i) - printf(" "); - } - } - printf("]"); - } -} - -////////////////////////////////////////////////////////////////////////// -void NDArray::printIndexedBuffer(const char* msg, Nd4jLong limit) const { - - syncToHost(); - - Nd4jLong rank = this->rankOf(); - - bool rowFlag = (rank < 2) || (rank == 2 && this->sizeAt(0) == 1); - - if (msg) - printf("%s: ", msg); - - if (this->isEmpty()) { - printf("Empty\n"); - } - else if (this->rankOf() == 0) { - if (this->isZ()) - printf("%lld\n", this->e(0)); - else if (this->isR()) - printf("%.8f\n", this->e(0)); - else if (this->isB()) { - printf("%s\n", this->e(0)?"true":"false"); - } - else if (this->isS()) { - // todo do we need this - // printf("\"%lld\"\n", this->getOffset(e)); - printf("\"%s\"\n", this->e(0).c_str()); - } - } - else if (rowFlag && ews()==1) - printBuffer(nullptr, limit); - else { - if (msg) - printf("\n"); - printFormatted(this, 1, limit); - printf("\n"); - } - fflush(stdout); -} - -////////////////////////////////////////////////////////////////////////// -template -void* NDArray::templatedPointerShift(const Nd4jLong offset) const { - return const_cast(reinterpret_cast(buffer()) + offset); -} -BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT void* NDArray::templatedPointerShift, (const Nd4jLong offset) const, LIBND4J_TYPES); - -////////////////////////////////////////////////////////////////////////// -// method makes copy of this array and applies to the copy transpose operation, this array remains unaffected -NDArray NDArray::transpose() const &{ - NDArray newArr(getDataBuffer(), ShapeDescriptor(shapeInfo()), getContext(), bufferOffset()); - newArr.transposei(); - - return newArr; -} - -////////////////////////////////////////////////////////////////////////// -// method makes copy of this array and applies to the copy transpose operation, this array remains unaffected -NDArray NDArray::transpose() && { - - this->transposei(); - return std::move(*this); -} - -//////////////////////////////////////////////////////////////////////// -// method performs transpose operation based on this array and store result in target, this array remains unaffected -void NDArray::transpose(NDArray& target) const { - - auto correctShape = ShapeUtils::evalTranspShapeInfo(*this, getContext()->getWorkspace()); - if(!shape::equalsStrict(correctShape, target.shapeInfo())) - throw std::runtime_error("NDArray::transpose method: the shapeInfo of target array is wrong !"); - - target._buffer = _buffer; - target._offset = _offset; - target._isView = true; -} - -//////////////////////////////////////////////////////////////////////// -// This method applies in-place transpose to this array, so this array becomes transposed -void NDArray::transposei() { - std::vector perm; - for (int e = this->rankOf() - 1; e >= 0; e--) - perm.emplace_back(e); - - this->permutei(perm); -} - -//////////////////////////////////////////////////////////////////////// -bool NDArray::equalsTo(const NDArray &other, double eps) const { - return equalsTo(&other, eps); -} - -////////////////////////////////////////////////////////////////////////// -void NDArray::setAttached(bool reallyAttached) { - _isAttached = reallyAttached; -}; - -////////////////////////////////////////////////////////////////////////// -// calculate strides -void NDArray::updateStrides(const char order) { - throw std::runtime_error("Forbidden method"); -} - -////////////////////////////////////////////////////////////////////////// -// set new order and shape in case of suitable array length -bool NDArray::reshapei(const char order, const std::initializer_list& shape, const bool copyToNewBuff) { - std::vector vShape(shape); - return reshapei(order, vShape, copyToNewBuff); -} - -////////////////////////////////////////////////////////////////////////// -bool NDArray::reshapei(const std::initializer_list& shape, const bool copyToNewBuff) { - return reshapei(ordering(), shape, copyToNewBuff); -} - -////////////////////////////////////////////////////////////////////////// -bool NDArray::reshapei(const std::vector& shape, const bool copyToNewBuff) { - return reshapei(ordering(), shape, copyToNewBuff); -} - -////////////////////////////////////////////////////////////////////////// -void NDArray::enforce(const std::initializer_list &dimensions, char order) { - std::vector dims(dimensions); - enforce(dims, order); -} - -////////////////////////////////////////////////////////////////////////// -void NDArray::enforce(std::vector &dimensions, char o) { - - Nd4jLong prod = 1; - for (int e = 0; e < dimensions.size(); e++) - prod *= dimensions[e]; - - if (prod != this->lengthOf()) { - std::string current = ShapeUtils::shapeAsString(this); - std::string enforced = ShapeUtils::shapeAsString(dimensions); - nd4j_printf("Can't enforce new shape, lengths mismatch. Original shape: %s; Requested shape: %s\n", current.c_str(), enforced.c_str()); - throw std::runtime_error("Incompatible shape"); - } - - char order = o == 'a' ? this->ordering() : o; - setShapeInfo(ShapeDescriptor(dataType(), order, dimensions)); -} - -////////////////////////////////////////////////////////////////////////// -Nd4jLong NDArray::argMax(std::initializer_list dimensions) { - - if (isS()) - throw std::runtime_error("NDArray::argMax: you can't use this method on String array!"); - - if (dimensions.size() == 0) { - Nd4jLong max = 0; - auto mv = -DataTypeUtils::max(); - for (Nd4jLong e = 0; e < this->lengthOf(); e++) { - auto val = this->e(e); - if (mv < val) { - mv = val; - max = e; - } - } - return max; - } - else - throw std::runtime_error("Not implemented yet"); -} - -////////////////////////////////////////////////////////////////////////// -// create new array with corresponding order and shape, new array will point to the same _buffer as this array -NDArray NDArray::reshape(const char order, const std::vector& shape, const bool copyToNewBuff) const & { - - NDArray newArr(getDataBuffer(), ShapeDescriptor(shapeInfo()), getContext(), bufferOffset()); - newArr.reshapei(order, shape, copyToNewBuff); - - return newArr; -} - -////////////////////////////////////////////////////////////////////////// -NDArray NDArray::reshape(const char order, const std::vector& shape, const bool copyToNewBuff) && { - - this->reshapei(order, shape, copyToNewBuff); - return std::move(*this); -} - -////////////////////////////////////////////////////////////////////////// -// change an array by repeating it the number of times given by reps. -void NDArray::tilei(const std::vector& reps) { - *this = this->tile(reps); -} - -////////////////////////////////////////////////////////////////////////// -Nd4jLong NDArray::sizeAt(const int dim) const { - - if (dim >= this->rankOf() || dim < -this->rankOf()) - throw std::runtime_error("NDArray::sizeAt: bad size index requested"); - - if (dim >= 0) - return shape::shapeOf(_shapeInfo)[dim]; - else - return shape::shapeOf(_shapeInfo)[this->rankOf() + dim]; -} - -////////////////////////////////////////////////////////////////////////// -Nd4jLong NDArray::strideAt(const int dim) const { - - if (dim >= this->rankOf() || dim < -this->rankOf()) - throw std::runtime_error("NDArray::strideAt: Bad size index requested"); - - if (dim >= 0) - return shape::stride(_shapeInfo)[dim]; - else - return shape::stride(_shapeInfo)[this->rankOf() + dim]; -} - -////////////////////////////////////////////////////////////////////////// -bool NDArray::permutei(const std::initializer_list& dimensions) { - std::vector vec(dimensions); - return permutei(vec); -} - -////////////////////////////////////////////////////////////////////////// -bool NDArray::permutei(const std::vector& dimensions) { - return permutei(dimensions.data(), rankOf()); -} - -////////////////////////////////////////////////////////////////////////// -bool NDArray::permutei(const std::initializer_list& dimensions) { - std::vector vec(dimensions); - std::vector ivec(dimensions.size()); - - for (int e = 0; e < vec.size(); e++) - ivec[e] = static_cast(vec[e]); - - return permutei(ivec); -} - -////////////////////////////////////////////////////////////////////////// -bool NDArray::permutei(const std::vector& dimensions) { - - std::vector ivec(dimensions.size()); - - for (int e = 0; e < dimensions.size(); e++) - ivec[e] = dimensions[e]; - - return permutei(ivec.data(), rankOf()); -} - -////////////////////////////////////////////////////////////////////////// -NDArray NDArray::permute(const int* dimensions, const int rank) const & { - - // evaluate shapeInfo for output (permuted) array ret - auto shapeInfoPermuted = ShapeUtils::evalPermShapeInfo(dimensions, rank, *this, getContext()->getWorkspace()); - NDArray ret(getDataBuffer(), ShapeDescriptor(shapeInfoPermuted), getContext(), bufferOffset()); - ret._isView = true; - return ret; -} - -////////////////////////////////////////////////////////////////////////// -NDArray NDArray::permute(const int* dimensions, const int rank) && { - - this->permutei(dimensions, rank); - return std::move(*this); -} - -///////////////////////////////////////////////////////////////////////// -NDArray NDArray::permute(const Nd4jLong* dimensions, const int rank) const &{ - int tempDims[MAX_RANK]; - shape::convertT(const_cast(dimensions), tempDims, rank); - return permute(tempDims, rank); -} - -///////////////////////////////////////////////////////////////////////// -NDArray NDArray::permute(const Nd4jLong* dimensions, const int rank) && { - - this->permutei(dimensions, rank); - return std::move(*this); -} - -////////////////////////////////////////////////////////////////////////// -NDArray NDArray::permute(const std::vector& dimensions) const &{ - - return permute(dimensions.data(), rankOf()); -} - -////////////////////////////////////////////////////////////////////////// -NDArray NDArray::permute(const std::vector& dimensions) && { - - this->permutei(dimensions); - return std::move(*this); -} - -////////////////////////////////////////////////////////////////////////// -NDArray NDArray::permute(const std::vector& dimensions) const & { - - return permute(dimensions.data(), rankOf()); -} - -////////////////////////////////////////////////////////////////////////// -NDArray NDArray::permute(const std::vector& dimensions) && { - - this->permutei(dimensions); - return std::move(*this); -} - -////////////////////////////////////////////////////////////////////////// -NDArray NDArray::permute(const std::initializer_list& dimensions) const &{ - - std::vector vec(dimensions); - return permute(vec); -} - -////////////////////////////////////////////////////////////////////////// -NDArray NDArray::permute(const std::initializer_list& dimensions) && { - - this->permutei(dimensions); - return std::move(*this); -} - -////////////////////////////////////////////////////////////////////////// -NDArray NDArray::permute(const std::initializer_list& dimensions) const & { - std::vector vec(dimensions); - return permute(vec); -} - -////////////////////////////////////////////////////////////////////////// -NDArray NDArray::permute(const std::initializer_list& dimensions) && { - - this->permutei(dimensions); - return std::move(*this); -} - -////////////////////////////////////////////////////////////////////////// -void NDArray::permute(const int* dimensions, const int rank, NDArray& target) const { - if (!nonNull() || !target.nonNull() || rank != rankOf() || rank != target.rankOf() ) - throw std::runtime_error("NDArray::permute method: either arrays are nullptr or ranks are not suitable!"); - - auto shapeInfoNew = ShapeUtils::evalPermShapeInfo(dimensions, rank, *this, target.getContext()->getWorkspace()); - - target.setShapeInfo(shapeInfoNew); - target._buffer = _buffer; - target._offset = _offset; -} - -////////////////////////////////////////////////////////////////////////// -void NDArray::permute(const Nd4jLong *dimensions, const int rank, NDArray& target) const { - if (!nonNull() || !target.nonNull() || rank != rankOf() || rank != target.rankOf() ) - throw std::runtime_error("NDArray::permute method: either arrays are nullptr or ranks are not suitable!"); - - auto shapeInfoNew = ShapeUtils::evalPermShapeInfo(dimensions, rank, *this, target.getContext()->getWorkspace()); - - target.setShapeInfo(shapeInfoNew); - target._buffer = _buffer; - target._offset = _offset; -} - -////////////////////////////////////////////////////////////////////////// -void NDArray::permute(const std::vector& dimensions, NDArray& target) const { - permute(dimensions.data(), rankOf(), target); -} - -////////////////////////////////////////////////////////////////////////// -void NDArray::permute(const std::vector& dimensions, NDArray& target) const { - permute(dimensions.data(), rankOf(), target); -} - -////////////////////////////////////////////////////////////////////////// -// check whether array is identity matrix -bool NDArray::isIdentityMatrix() { - if (isS()) - throw std::runtime_error("NDArray::isIdentityMatrix: you can't use this method on String array!"); - if(rankOf() !=2 || rows() != columns()) - throw std::runtime_error("isIdentityMatrix method: matrix must be square and have rank = 2 !"); - - const double eps = 1e-5f; - for(Nd4jLong i=0; i(i,i) - 1.f) > eps) - return false; - - for(Nd4jLong i=0; i(i,j)) > eps) - return false; - } - } - return true; -} - -////////////////////////////////////////////////////////////////////////// -// check whether array is unitary matrix -bool NDArray::isUnitary() { - if (isS()) - throw std::runtime_error("NDArray::isUnitary: you can't use this method on String array!"); - if(rankOf() != 2 || rows() != columns()) - throw std::runtime_error("isUnitary method: matrix must be square and have rank = 2 !"); - - auto tr = this->transpose(); - auto trMul = MmulHelper::mmul(this, &tr, nullptr, 1.f, 0.f); - - bool result = trMul->isIdentityMatrix(); - delete trMul; - - return result; -} - -////////////////////////////////////////////////////////////////////////// -template <> -const std::string* ND4J_EXPORT NDArray::bufferAsT() const { - throw std::runtime_error("This method is NOT supposed to be used"); -} - -////////////////////////////////////////////////////////////////////////// -template -const T* NDArray::bufferAsT() const { - // FIXME: do we REALLY want sync here? - // syncToHost(); - - return reinterpret_cast(buffer()); -} -BUILD_SINGLE_UNCHAINED_TEMPLATE(template ND4J_EXPORT const, * NDArray::bufferAsT() const, LIBND4J_TYPES); - -template -T* NDArray::bufferAsT() { - syncToHost(); - return reinterpret_cast(buffer()); -} -BUILD_SINGLE_UNCHAINED_TEMPLATE(template ND4J_EXPORT, * NDArray::bufferAsT(), LIBND4J_TYPES); - -//////////////////////////////////////////////////////////////////////// -NDArray NDArray::subarray(IndicesList& idx) const { - - const int idxSize = idx.size(); - if (idxSize != this->rankOf()) - throw std::runtime_error("NDArray::subarray: number of indices should match"); - - std::vector indexes(3 * idxSize); - - // convert IndicesList to vector - for (int d = 0; d < idxSize; ++d) { - - if (idx.at(d)->isAll()) { - indexes[3 * d] = 0; // first - indexes[3 * d + 1] = 0; // last - indexes[3 * d + 2] = 1; // stride - } - else if (idx.at(d)->isPoint()) { - indexes[3 * d] = idx.at(d)->getIndices().at(0); // first - indexes[3 * d + 1] = indexes[3 * d] + 1; // last - indexes[3 * d + 2] = 1; // stride - } - else if (idx.at(d)->isInterval()) { - indexes[3 * d] = idx.at(d)->getIndices().at(0); // first - indexes[3 * d + 1] = idx.at(d)->getIndices().size();// last - indexes[3 * d + 2] = idx.at(d)->stride(); // stride - } - else { - indexes[3 * d] = idx.at(d)->getIndices().at(0); // first - indexes[3 * d + 1] = idx.at(d)->getIndices().at(1); // last - indexes[3 * d + 2] = idx.at(d)->getIndices().at(2); // stride - } - } - return NDArray((*this)(indexes, true, true)); -} - -//////////////////////////////////////////////////////////////////////// -NDArray NDArray::subarray(const std::initializer_list& idx) const { - - const int idxSize = idx.size(); - if (idxSize != this->rankOf()) - throw std::runtime_error("NDArray::subarray: number of indices should match the array rank"); - - std::vector indexes(3 * idxSize); - - // convert NDIndex to vector - int d = 0; - for (const auto& item : idx) { - - if (item->isAll()) { - indexes[3 * d] = 0; // first - indexes[3 * d + 1] = 0; // last - indexes[3 * d + 2] = 1; // stride - } - else if (item->isPoint()) { - indexes[3 * d] = item->getIndices().at(0); // first - indexes[3 * d + 1] = indexes[3 * d] + 1; // last - indexes[3 * d + 2] = 1; // stride - } - else if (item->isInterval()) { - indexes[3 * d] = item->getIndices().at(0); // first - indexes[3 * d + 1] = item->getIndices().size(); // last - indexes[3 * d + 2] = item->stride(); // stride - } - else { - indexes[3 * d] = item->getIndices().at(0); // first - indexes[3 * d + 1] = item->getIndices().at(1); // last - indexes[3 * d + 2] = item->getIndices().at(2); // stride - } - ++d; - } - - // release NDIndices - for (auto i: idx) - delete i; - - return NDArray((*this)(indexes, true, true)); -} - -//////////////////////////////////////////////////////////////////////// -NDArray NDArray::subarray(const Intervals& idx) const { - - const int idxSize = idx.size(); - if (idxSize != this->rankOf()) - throw std::runtime_error("NDArray::subarray: number of indices should match the rank of array!"); - - std::vector indexes(2 * idxSize); - - // convert Intervals to vector - for (int d = 0; d < idxSize; ++d) { - - if (idx[d].empty()) { - indexes[2 * d] = 0; // first - indexes[2 * d + 1] = 0; // last - } - else { - indexes[2 * d] = idx[d][0]; // first - indexes[2 * d + 1] = idx[d][1]; // last - } - } - - return NDArray((*this)(indexes, true)); -} - -////////////////////////////////////////////////////////////////////////// -template -NDArray NDArray::asT() const{ - - auto result = isScalar() ? NDArray('c', {}, std::vector{0.}, DataTypeUtils::fromT(), this->getContext()) : NDArray(ordering(), getShapeAsVector(), DataTypeUtils::fromT(), this->getContext()); - - NDArray::prepareSpecialUse({&result}, {this}); - NativeOpExecutioner::execTransformAny(getContext(), transform::AnyOps::Assign, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo(), nullptr, nullptr, nullptr); - NDArray::registerSpecialUse({&result}, {this}); - - return result; -} -BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT NDArray NDArray::asT, () const, LIBND4J_TYPES); - -////////////////////////////////////////////////////////////////////////// -template -NDArray NDArray::asS() const { - - if (!isS()) - throw std::runtime_error("NDArray::asS: you can use this method only for String array!"); - - auto dtype = DataTypeUtils::fromT(); - - if (!(DataTypeUtils::isS(dtype))) - throw std::invalid_argument("NDArray::asS: invalid DataType used"); - - if (dtype == dataType()) { - - Nd4jLong offsetsLength = ShapeUtils::stringBufferHeaderRequirements(lengthOf()); - const auto nInputoffsets = bufferAsT(); - std::shared_ptr pBuffer = std::make_shared(offsetsLength + nInputoffsets[lengthOf()], dtype, getContext()->getWorkspace(), true); - - NDArray res(pBuffer, ShapeDescriptor(dtype, ordering(), getShapeAsVector()), getContext()); - res.setAttached(getContext()->getWorkspace() != nullptr); - - preparePrimaryUse({ &res }, { this }); - memcpy(res.bufferAsT(), nInputoffsets, offsetsLength); - auto data = res.bufferAsT() + offsetsLength; - const auto inData = bufferAsT() + offsetsLength; - memcpy(data, inData, nInputoffsets[lengthOf()]); - - registerPrimaryUse({ &res }, { this }); - return res; - } - - Nd4jLong offsetsLength = ShapeUtils::stringBufferHeaderRequirements(lengthOf()); - - std::vector offsets(lengthOf() + 1); - - const auto nInputoffsets = bufferAsT(); - - Nd4jLong start = 0, stop = 0; - Nd4jLong dataLength = 0; - - auto data = bufferAsT() + offsetsLength; - for (Nd4jLong e = 0; e < lengthOf(); e++) { - offsets[e] = dataLength; - start = nInputoffsets[e]; - stop = nInputoffsets[e + 1]; - if (dataType() == DataType::UTF8) { - dataLength += (dtype == DataType::UTF16) ? unicode::offsetUtf8StringInUtf16(data + start, stop) - : unicode::offsetUtf8StringInUtf32(data + start, stop); - } - else if (dataType() == DataType::UTF16) { - dataLength += (dtype == DataType::UTF32) ? unicode::offsetUtf16StringInUtf32(data + start, (stop / sizeof(char16_t)) ) - : unicode::offsetUtf16StringInUtf8(data + start, (stop / sizeof(char16_t))); - } - else { - dataLength += (dtype == DataType::UTF16) ? unicode::offsetUtf32StringInUtf16(data + start, (stop / sizeof(char32_t))) - : unicode::offsetUtf32StringInUtf8(data + start, (stop / sizeof(char32_t))); - } - } - offsets[lengthOf()] = dataLength; - - std::shared_ptr pBuffer = std::make_shared(offsetsLength + dataLength, dtype, getContext()->getWorkspace(), true); - - NDArray res(pBuffer, ShapeDescriptor(dtype, ordering(), getShapeAsVector()), getContext()); - res.setAttached(getContext()->getWorkspace() != nullptr); - - preparePrimaryUse({ &res }, { this }); - - memcpy(res.bufferAsT(), offsets.data(), offsets.size() * sizeof(Nd4jLong)); - - auto outData = res.bufferAsT() + offsetsLength; - const auto inData = bufferAsT() + offsetsLength; - - auto func = PRAGMA_THREADS_FOR{ - for (int e = start; e < stop; e++) { - auto cdata = outData + offsets[e]; - auto end = nInputoffsets[e + 1]; - auto idata = inData + nInputoffsets[e]; - if (dtype == DataType::UTF16) { - if (dataType() == DataType::UTF8) { - unicode::utf8to16(idata, outData, end); - } - else { - unicode::utf32to16(idata, outData, (end / sizeof(char32_t))); - } - } - else if (dtype == DataType::UTF32) { - if (dataType() == DataType::UTF8) { - unicode::utf8to32(idata, cdata, end); - } - else { - unicode::utf16to32(idata, outData, (end / sizeof(char16_t))); - } - } - else { - if (dataType() == DataType::UTF16) { - unicode::utf16to8(idata, outData, (end / sizeof(char16_t))); - } - else { - unicode::utf32to8(idata, outData, (end / sizeof(char32_t))); - } - } - } - }; - - samediff::Threads::parallel_for(func, 0, lengthOf(), 1); - - registerPrimaryUse({ &res }, { this }); - - return res; -} -BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT NDArray NDArray::asS, () const, LIBND4J_STRINGTYPES); - -//////////////////////////////////////////////////////////////////////// -NDArray NDArray::asT(DataType dtype) const { - - if (isS() && !DataTypeUtils::isS(dtype)) - throw std::runtime_error("NDArray::asT: you can't use this method on String array with not string DataType!"); - - if (!isS() && DataTypeUtils::isS(dtype)) - throw std::runtime_error("NDArray::asT: you can't use this method on not String array with string DataType!"); - - if (isS()){ - BUILD_SINGLE_SELECTOR(dtype, return asS, (), LIBND4J_STRINGTYPES); - } else { - BUILD_SINGLE_SELECTOR(dtype, return asT, (), LIBND4J_TYPES); - } - - return NDArray(); -} - -//////////////////////////////////////////////////////////////////////// -NDArray NDArray::cast(DataType dtype) const { - - if (isS() && !DataTypeUtils::isS(dtype)) - throw std::runtime_error("NDArray::cast: you can't use this method on String array with not string DataType!"); - - if (!isS() && DataTypeUtils::isS(dtype)) - throw std::runtime_error("NDArray::cast: you can't use this method on not String array with string DataType!"); - - return this->asT(dtype); -} - -//////////////////////////////////////////////////////////////////////// -void NDArray::cast(NDArray& target, DataType dtype) { - if (isS()) - throw std::runtime_error("NDArray::cast: you can't use this method on String array!"); - // TODO: to be implemented properly - target.assign(this); -} - -//////////////////////////////////////////////////////////////////////// -void NDArray::operator+=(const NDArray& other) { - - if (isS()) - throw std::runtime_error("NDArray::operator+=: you can't use this method on String array!"); - if (!Environment::getInstance().isExperimentalBuild() && this->dataType() != other.dataType() && (this->dataType() != DataType::BOOL || other.dataType() != BOOL)) - throw sd::datatype_exception::build("NDArray operator+=: Cannot add different types", this->dataType(), other.dataType()); - - if (this->lengthOf() != 1 && other.lengthOf() == 1) { - NDArray::prepareSpecialUse({this}, {this, &other}); - NativeOpExecutioner::execScalar(getContext(), sd::scalar::Add, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({this}, {this, &other}); - } - else if (other.lengthOf() == lengthOf() && this->rankOf() == other.rankOf()) { - NDArray::prepareSpecialUse({this}, {this, &other}); - NativeOpExecutioner::execPairwiseTransform(getContext(), sd::pairwise::Add, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({this}, {this, &other}); - } - else{ - const Nd4jLong *bShape = nullptr; - if(!ShapeUtils::evalBroadcastShapeInfo(*this, other, true, bShape, getContext()->getWorkspace())) - throw std::invalid_argument("NDArray::operator+=: the shapes of this and other arrays are not suitable for broadcast operation !"); - - if(shape::equalsTypesAndShapesSoft(shapeInfo(), bShape)) { - this->applyTrueBroadcast(sd::BroadcastOpsTuple::Add(), other, *this, false); - } - else { - NDArray result(bShape, true, getContext()); - this->applyTrueBroadcast(sd::BroadcastOpsTuple::Add(), other, result, false); - *this = std::move(result); // move assignment operator, zero cost copy - } - } -} - -//////////////////////////////////////////////////////////////////////// -void NDArray::operator-=(const NDArray& other) { - if (isS()) - throw std::runtime_error("NDArray::operator-=: you can't use this method on String array!"); - - if (!Environment::getInstance().isExperimentalBuild() && this->dataType() != other.dataType() && (this->dataType() != DataType::BOOL || other.dataType() != BOOL)) - throw sd::datatype_exception::build("NDArray operator-=: Cannot subtract different types", this->dataType(), other.dataType()); - - if (lengthOf() != 1 && other.lengthOf() == 1) { - NDArray::prepareSpecialUse({this}, {this, &other}); - NativeOpExecutioner::execScalar(getContext(), sd::scalar::Subtract, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({this}, {this, &other}); - } - else if (other.lengthOf() == lengthOf() && this->rankOf() == other.rankOf()) { - NDArray::prepareSpecialUse({this}, {this, &other}); - NativeOpExecutioner::execPairwiseTransform(getContext(), sd::pairwise::Subtract, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({this}, {this, &other}); - } - else{ - const Nd4jLong *bShape = nullptr; - if(!ShapeUtils::evalBroadcastShapeInfo(*this, other, true, bShape, getContext()->getWorkspace())) - throw std::invalid_argument("NDArray::operator-=: the shapes of this and other arrays are not suitable for broadcast operation !"); - - if(shape::equalsTypesAndShapesSoft(shapeInfo(), bShape)) { - this->applyTrueBroadcast(sd::BroadcastOpsTuple::Subtract(), other, *this, false); - } - else { - NDArray result(bShape, true, getContext()); - this->applyTrueBroadcast(sd::BroadcastOpsTuple::Subtract(), other, result, false); - *this = std::move(result); // move assignment operator, zero cost copy - } - } -} - -//////////////////////////////////////////////////////////////////////// -void NDArray::operator*=(const NDArray& other) { - if (isS()) - throw std::runtime_error("NDArray::operator*=: you can't use this method on String array!"); - if (!Environment::getInstance().isExperimentalBuild() && this->dataType() != other.dataType() && (this->dataType() != DataType::BOOL || other.dataType() != BOOL)) - throw sd::datatype_exception::build("NDArray operator*=: Cannot multiply different types", this->dataType(), other.dataType()); - - if (lengthOf() != 1 && other.lengthOf() == 1) { - NDArray::prepareSpecialUse({this}, {this, &other}); - NativeOpExecutioner::execScalar(getContext(), sd::scalar::Multiply, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({this}, {this, &other}); - } - else if (other.lengthOf() == lengthOf() && this->rankOf() == other.rankOf()) { - NDArray::prepareSpecialUse({this}, {this, &other}); - NativeOpExecutioner::execPairwiseTransform(getContext(), sd::pairwise::Multiply, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({this}, {this, &other}); - } - else{ - const Nd4jLong *bShape = nullptr; - if(!ShapeUtils::evalBroadcastShapeInfo(*this, other, true, bShape, getContext()->getWorkspace())) - throw std::invalid_argument("NDArray::operator*=: the shapes of this and other arrays are not suitable for broadcast operation !"); - - if(shape::equalsTypesAndShapesSoft(_shapeInfo, bShape)) { - this->applyTrueBroadcast(sd::BroadcastOpsTuple::Multiply(), other, *this, false); - } - else { - NDArray result(bShape, true, getContext()); - this->applyTrueBroadcast(sd::BroadcastOpsTuple::Multiply(), other, result, false); - *this = std::move(result); // move assignment operator, zero cost copy - } - } -} - -//////////////////////////////////////////////////////////////////////// -void NDArray::operator/=(const NDArray& other) { - if (isS() || other.isS()) - throw std::runtime_error("NDArray::operator/=: you can't use this method on String array!"); - if (other.isB()) - throw std::runtime_error("NDArray::operator/=: you can't divide by bool array!"); - - if (!Environment::getInstance().isExperimentalBuild() && this->dataType() != other.dataType()) { - throw sd::datatype_exception::build("NDArray operator/=: Cannot divide different types", this->dataType(), other.dataType()); - } - - if (lengthOf() != 1 && other.lengthOf() == 1) { - NDArray::prepareSpecialUse({this}, {this, &other}); - NativeOpExecutioner::execScalar(getContext(), sd::scalar::Divide, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({this}, {this, &other}); - } - else if (other.lengthOf() == lengthOf() && this->rankOf() == other.rankOf()) { - NDArray::prepareSpecialUse({this}, {this, &other}); - NativeOpExecutioner::execPairwiseTransform(getContext(), sd::pairwise::Divide, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({this}, {this, &other}); - } - else{ - const Nd4jLong *bShape = nullptr; - if(!ShapeUtils::evalBroadcastShapeInfo(*this, other, true, bShape, getContext()->getWorkspace())) - throw std::invalid_argument("NDArray::operator/=: the shapes of this and other arrays are not suitable for broadcast operation !"); - - if(shape::equalsTypesAndShapesSoft(_shapeInfo, bShape)) { - this->applyTrueBroadcast(sd::BroadcastOpsTuple::Divide(), other, *this, false); - } - else { - NDArray result(bShape, true, getContext()); - this->applyTrueBroadcast(sd::BroadcastOpsTuple::Divide(), other, result, false); - *this = std::move(result); // move assignment operator, zero cost copy - } - } -} - -//////////////////////////////////////////////////////////////////////// -template -void NDArray::operator+=(const T value) { - if (isS()) - throw std::runtime_error("NDArray::operator+=: you can't use this method on String array!"); - - auto other = NDArrayFactory::create(this->dataType(), value, getContext()); - - NDArray::prepareSpecialUse({this}, {this, &other}); - NativeOpExecutioner::execScalar(getContext(), sd::scalar::Add, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({this}, {this, &other}); -} -template ND4J_EXPORT void NDArray::operator+=(const double value); -template ND4J_EXPORT void NDArray::operator+=(const float value); -template ND4J_EXPORT void NDArray::operator+=(const float16 value); -template ND4J_EXPORT void NDArray::operator+=(const bfloat16 value); -template ND4J_EXPORT void NDArray::operator+=(const Nd4jLong value); -template ND4J_EXPORT void NDArray::operator+=(const int value); -template ND4J_EXPORT void NDArray::operator+=(const bool value); - -//////////////////////////////////////////////////////////////////////// -template -void NDArray::operator-=(const T value) { - if (isS()) - throw std::runtime_error("NDArray::operator-=: you can't use this method on String array!"); - - auto other = NDArrayFactory::create(dataType(), value, getContext()); - - NDArray::prepareSpecialUse({this}, {this, &other}); - NativeOpExecutioner::execScalar(getContext(), sd::scalar::Subtract, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({this}, {this, &other}); -} -template ND4J_EXPORT void NDArray::operator-=(const double value); -template ND4J_EXPORT void NDArray::operator-=(const float value); -template ND4J_EXPORT void NDArray::operator-=(const float16 value); -template ND4J_EXPORT void NDArray::operator-=(const bfloat16 value); -template ND4J_EXPORT void NDArray::operator-=(const Nd4jLong value); -template ND4J_EXPORT void NDArray::operator-=(const int value); -template ND4J_EXPORT void NDArray::operator-=(const bool value); - -//////////////////////////////////////////////////////////////////////// -template -void NDArray::operator*=(const T scalar) { - if (isS()) - throw std::runtime_error("NDArray::operator*=: you can't use this method on String array!"); - - auto other = NDArrayFactory::create(this->dataType(), scalar, getContext()); - NDArray::prepareSpecialUse({this}, {this, &other}); - NativeOpExecutioner::execScalar(getContext(), sd::scalar::Multiply, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({this}, {this, &other}); -} -template ND4J_EXPORT void NDArray::operator*=(const double scalar); -template ND4J_EXPORT void NDArray::operator*=(const float scalar); -template ND4J_EXPORT void NDArray::operator*=(const float16 scalar); -template ND4J_EXPORT void NDArray::operator*=(const bfloat16 scalar); -template ND4J_EXPORT void NDArray::operator*=(const Nd4jLong scalar); -template ND4J_EXPORT void NDArray::operator*=(const int scalar); -template ND4J_EXPORT void NDArray::operator*=(const int16_t scalar); -template ND4J_EXPORT void NDArray::operator*=(const int8_t scalar); -template ND4J_EXPORT void NDArray::operator*=(const uint8_t scalar); -template ND4J_EXPORT void NDArray::operator*=(const bool scalar); - -//////////////////////////////////////////////////////////////////////// -template -void NDArray::operator/=(const T scalar) { - if (isS()) - throw std::runtime_error("NDArray::operator/=: you can't use this method on String array!"); - - auto other = NDArrayFactory::create(this->dataType(), scalar, getContext()); - NDArray::prepareSpecialUse({this}, {this, &other}); - NativeOpExecutioner::execScalar(getContext(), sd::scalar::Divide, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({this}, {this, &other}); -} -template ND4J_EXPORT void NDArray::operator/=(const double scalar); -template ND4J_EXPORT void NDArray::operator/=(const float scalar); -template ND4J_EXPORT void NDArray::operator/=(const float16 scalar); -template ND4J_EXPORT void NDArray::operator/=(const bfloat16 scalar); -template ND4J_EXPORT void NDArray::operator/=(const Nd4jLong scalar); -template ND4J_EXPORT void NDArray::operator/=(const int scalar); -template ND4J_EXPORT void NDArray::operator/=(const int16_t scalar); -template ND4J_EXPORT void NDArray::operator/=(const int8_t scalar); -template ND4J_EXPORT void NDArray::operator/=(const uint8_t scalar); -template ND4J_EXPORT void NDArray::operator/=(const bool scalar); - -//////////////////////////////////////////////////////////////////////// -// negative operator, it makes all array elements = -elements -NDArray NDArray::operator-() const & { - if (isS()) - throw std::runtime_error("NDArray::negative-: you can't use this method on String array!"); - - NDArray result(shapeInfo(), false, getContext()); - - NDArray::prepareSpecialUse({&result}, {this}); - NativeOpExecutioner::execTransformSame(getContext(), sd::transform::Neg, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo(), nullptr, nullptr, nullptr); - NDArray::registerSpecialUse({&result}, {this}); - - return result; -} - -//////////////////////////////////////////////////////////////////////// -NDArray NDArray::operator-() && { - if (isS()) - throw std::runtime_error("NDArray::negative-: you can't use this method on String array!"); - - NDArray::prepareSpecialUse({this}, {this}); - NativeOpExecutioner::execTransformSame(getContext(), sd::transform::Neg, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr, nullptr, nullptr); - NDArray::registerSpecialUse({this}, {this}); - - return std::move(*this); -} - -//////////////////////////////////////////////////////////////////////// -// mathematical multiplication of two arrays -NDArray mmul(const NDArray& left, const NDArray& right) { - if (left.isS() || right.isS()) - throw std::runtime_error("mmul friend function: you can't use this function on String array!"); - auto ptr = MmulHelper::mmul(const_cast(&left), const_cast(&right), nullptr, 1., 0.); - NDArray result(std::move(*ptr)); - delete ptr; - return result; -} - -//////////////////////////////////////////////////////////////////////// -void NDArray::tileToShape(const std::vector& shape, NDArray& target) { - if(&target != this) { - this->tile(target); - return; - } - - std::vector thisShape(rankOf()); - for(int i = 0; i < rankOf(); ++i) - thisShape[i] = sizeAt(i); - - if(!ShapeUtils::areShapesBroadcastable(shape, thisShape)) - throw std::runtime_error("NDArray::tileToShape method: the shape of this array and input shape are not suitable for broadcast operation !"); - - const int newRank = shape.size(); - std::vector repeats(newRank); - - for(int i = 1; i <= newRank; ++i) { - if(i > rankOf()) - repeats[newRank-i] = shape[newRank - i]; - else - repeats[newRank-i] = shape[newRank - i] / thisShape[rankOf() - i]; - } - - tilei(repeats); -} - -//////////////////////////////////////////////////////////////////////// -void NDArray::tileToShape(const std::initializer_list& shape, NDArray& target) { - tileToShape(std::vector(shape), target); -} - -//////////////////////////////////////////////////////////////////////// -NDArray NDArray::tileToShape(const Nd4jLong* shapeInfo) { - - NDArray result(const_cast(shapeInfo), false, getContext()); - tile(result); - return result; -} - -//////////////////////////////////////////////////////////////////////// -double NDArray::getTrace() const { - if (isS()) - throw std::runtime_error("NDArray::getTrace: you can't use this method on String array!"); - - int rank = rankOf(); - auto shape = shapeOf(); - int minDim = 100000000; - - Nd4jLong indices[MAX_RANK]; - for(int j = 0; j < rank; ++j) - indices[j] = 1; - - auto offset = shape::getOffset(shapeInfo(), indices); - - for(int i = 0; i < rank; ++i) - if(minDim > shape[i]) - minDim = shape[i]; - - double sum = 0.; - - for(int i = 0; i < minDim; ++i) - sum += e(i * offset); - - return sum; -} - -//////////////////////////////////////////////////////////////////////// -NDArray NDArray::quantize(const NDArray& array) { - - if(!array.isR()) - throw std::invalid_argument("NDArray::quantize: type of array should be from real space!"); - - auto ws = array.getContext()->getWorkspace(); - - Nd4jLong* shapeInfo = ShapeBuilders::copyShapeInfo(array.shapeInfo(), true, ws); - ArrayOptions::setPropertyBit(shapeInfo, ARRAY_QUANTIZED); - - std::shared_ptr buffer = std::make_shared(TypeCast::estimateQuantizedSize(array.lengthOf()), ArrayOptions::dataType(shapeInfo), ws); - - NDArray result(buffer, ShapeDescriptor(shapeInfo), array.getContext()); - - return result; -} - -////////////////////////////////////////////////////////////////////////// -void NDArray::applyTrueBroadcast(sd::BroadcastOpsTuple op, const NDArray& other, NDArray& target, const bool checkTargetShape, ExtraArguments *extraArgs) const { - - if (isS()) - throw std::runtime_error("NDArray::applyTrueBroadcast: you can't use this method on String array!"); - - if(((op.s == scalar::Divide || op.s == scalar::FloorDiv || op.s == scalar::FloorMod) && other.isB()) || (op.s == scalar::ReverseDivide && this->isB())) - throw std::runtime_error("NDArray::applyTrueBroadcast method: you can't divide by bool array !"); - - if (isEmpty() || other.isEmpty()) - return; - - // if (lengthOf() == 1) { - // target.assign(this); - // target.applyPairwiseTransform(op.p, other, extraArgs); - // return; - // } - // if (other.lengthOf() == 1) { - // const_cast(this)->applyScalarArr(op.s, other, target, extraArgs); - // return; - // } - - if(checkTargetShape) { - const Nd4jLong* newShapeInfo = nullptr; - if(!ShapeUtils::evalBroadcastShapeInfo(*this, other, true, newShapeInfo, getContext()->getWorkspace())) // the rank of target array must be equal to max->rankOf)() - throw std::runtime_error("NDArray::applyTrueBroadcast method: the shapes of this and other arrays are not suitable for broadcast operation !"); - if(!shape::equalsTypesAndShapesSoft(target.shapeInfo(), newShapeInfo)) - throw std::runtime_error("NDArray::applyTrueBroadcast method: the shape or type of target array is wrong !"); - } - - Nd4jLong const* xShapeInfoH = shapeInfo(); - Nd4jLong const* yShapeInfoH = other.shapeInfo(); - Nd4jLong const* xShapeInfoD = specialShapeInfo(); - Nd4jLong const* yShapeInfoD = other.specialShapeInfo(); - - if(!isSameShape(target)) { - auto xPack = ConstantShapeHelper::getInstance().createShapeInfoWithUnitiesForBroadcast(target.shapeInfo(), shapeInfo(), getContext()->getWorkspace()); - xShapeInfoH = xPack.primary(); - xShapeInfoD = xPack.special(); - } - if(!other.isSameShape(target)) { - auto yPack = ConstantShapeHelper::getInstance().createShapeInfoWithUnitiesForBroadcast(target.shapeInfo(), other.shapeInfo(), other.getContext()->getWorkspace()); - yShapeInfoH = yPack.primary(); - yShapeInfoD = yPack.special(); - } - - NDArray::prepareSpecialUse({&target}, {this, &other}); - NativeOpExecutioner::execBroadcast(getContext(), op.b, buffer(), xShapeInfoH, specialBuffer(), xShapeInfoD, other.buffer(), yShapeInfoH, other.specialBuffer(), yShapeInfoD, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo()); - registerSpecialUse({&target}, {this, &other}); -} - -////////////////////////////////////////////////////////////////////////// -void NDArray::applyTrueBroadcast(sd::BroadcastBoolOpsTuple op, const NDArray& other, NDArray& target, const bool checkTargetShape, ExtraArguments *extraArgs) const { - - if (isS()) - throw std::runtime_error("NDArray::applyTrueBroadcast bool: you can't use this method on String array!"); - - if (isEmpty() || other.isEmpty()) - return; - - // if (lengthOf() == 1) { - // NDArray temp(target._shapeInfo, dataType(), false, getContext()); - // temp.assign(this); - // temp.applyPairwiseTransform(op.p, other, target, extraArgs); - // return; - // } - // if (other.lengthOf() == 1) { - // this->applyScalarArr(op.s, other, target, extraArgs); - // return; - // } - - if(checkTargetShape) { - const Nd4jLong* newShapeInfo = nullptr; - if(!ShapeUtils::evalBroadcastShapeInfo(*this, other, true, newShapeInfo, getContext()->getWorkspace())) // the rank of target array must be equal to max->rankOf)() - throw std::runtime_error("NDArray::applyTrueBroadcast method: the shapes of this and other arrays are not suitable for broadcast operation !"); - if(!shape::equalsSoft(target._shapeInfo, newShapeInfo) || target.dataType() != DataType::BOOL) - throw std::runtime_error("NDArray::applyTrueBroadcast bool method: the shape or type of target array is wrong !"); - if(dataType() != other.dataType()) - throw std::invalid_argument("NDArray::applyTrueBroadcast bool method: this and other arrays must have the same type !"); - } - - Nd4jLong const* xShapeInfoH = shapeInfo(); - Nd4jLong const* yShapeInfoH = other.shapeInfo(); - Nd4jLong const* xShapeInfoD = specialShapeInfo(); - Nd4jLong const* yShapeInfoD = other.specialShapeInfo(); - - if(!isSameShape(target)) { - auto xPack = ConstantShapeHelper::getInstance().createShapeInfoWithUnitiesForBroadcast(target.shapeInfo(), shapeInfo(), getContext()->getWorkspace()); - xShapeInfoH = xPack.primary(); - xShapeInfoD = xPack.special(); - } - if(!other.isSameShape(target)) { - auto yPack = ConstantShapeHelper::getInstance().createShapeInfoWithUnitiesForBroadcast(target.shapeInfo(), other.shapeInfo(), other.getContext()->getWorkspace()); - yShapeInfoH = yPack.primary(); - yShapeInfoD = yPack.special(); - } - - NDArray::prepareSpecialUse({&target}, {this, &other}); - NativeOpExecutioner::execBroadcastBool(getContext(), op.b, buffer(), xShapeInfoH, specialBuffer(), xShapeInfoD, other.buffer(), yShapeInfoH, other.specialBuffer(), yShapeInfoD, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), nullptr); - registerSpecialUse({&target}, {this, &other}); -} - -////////////////////////////////////////////////////////////////////////// -void NDArray::applyTrueBroadcast(sd::BroadcastIntOpsTuple op, const NDArray& other, NDArray& target, const bool checkTargetShape, ExtraArguments *extraArgs) const { - - if (isS()) - throw std::runtime_error("NDArray::applyTrueBroadcast bool: you can't use this method on String array!"); - - if (isEmpty() || other.isEmpty()) - return; - - // if (lengthOf() == 1) { - // NDArray temp(target._shapeInfo, dataType(), false, getContext()); - // temp.assign(this); - // temp.applyPairwiseTransform(op.p, other, target, extraArgs); - // return; - // } - // if (other.lengthOf() == 1) { - // this->applyScalarArr(op.s, other, target, extraArgs); - // return; - // } - - if(checkTargetShape) { - const Nd4jLong* newShapeInfo = nullptr; - if(!ShapeUtils::evalBroadcastShapeInfo(*this, other, false, newShapeInfo, getContext()->getWorkspace())) // the rank of target array must be equal to max->rankOf)() - throw std::runtime_error("NDArray::applyTrueBroadcast method: the shapes of this and other arrays are not suitable for broadcast operation !"); - if(!shape::equalsSoft(target._shapeInfo, newShapeInfo) || target.dataType() != this->dataType()) - throw std::runtime_error("NDArray::applyTrueBroadcast int method: the shape or type of target array is wrong !"); - if(dataType() != other.dataType()) - throw std::invalid_argument("NDArray::applyTrueBroadcast int method: this and other arrays must have the same type !"); - } - - Nd4jLong const* xShapeInfoH = shapeInfo(); - Nd4jLong const* yShapeInfoH = other.shapeInfo(); - Nd4jLong const* xShapeInfoD = specialShapeInfo(); - Nd4jLong const* yShapeInfoD = other.specialShapeInfo(); - - if(!isSameShape(target)) { - auto xPack = ConstantShapeHelper::getInstance().createShapeInfoWithUnitiesForBroadcast(target.shapeInfo(), shapeInfo(), getContext()->getWorkspace()); - xShapeInfoH = reinterpret_cast(xPack.primary()); - xShapeInfoD = reinterpret_cast(xPack.special()); - } - if(!other.isSameShape(target)) { - auto yPack = ConstantShapeHelper::getInstance().createShapeInfoWithUnitiesForBroadcast(target.shapeInfo(), other.shapeInfo(), other.getContext()->getWorkspace()); - yShapeInfoH = reinterpret_cast(yPack.primary()); - yShapeInfoD = reinterpret_cast(yPack.special()); - } - - NDArray::prepareSpecialUse({&target}, {this, &other}); - NativeOpExecutioner::execBroadcastInt(getContext(), op.b, buffer(), xShapeInfoH, specialBuffer(), xShapeInfoD, other.buffer(), yShapeInfoH, other.specialBuffer(), yShapeInfoD, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo()); - registerSpecialUse({&target}, {this, &other}); -} - -////////////////////////////////////////////////////////////////////////// -NDArray NDArray::applyTrueBroadcast(sd::BroadcastOpsTuple op, const NDArray& other, ExtraArguments *extraArgs) const & { - if (isEmpty() || other.isEmpty()) { - if (isEmpty()) - return NDArray(*this); - else - return NDArray(other); - } - - const Nd4jLong* newShapeInfo = nullptr; - if(!ShapeUtils::evalBroadcastShapeInfo(*this, other, true, newShapeInfo, getContext()->getWorkspace())) // the rank of new array = max->rankOf)() - throw std::runtime_error("NDArray::applyTrueBroadcast method: the shapes of this and other arrays are not suitable for broadcast operation !"); - NDArray result(newShapeInfo, true, getContext()); - - this->applyTrueBroadcast(op, other, result, false, extraArgs); - - return result; -} - -////////////////////////////////////////////////////////////////////////// -NDArray NDArray::applyTrueBroadcast(sd::BroadcastOpsTuple op, NDArray&& other, ExtraArguments *extraArgs) const & { - if (isEmpty() || other.isEmpty()) { - if (isEmpty()) - return NDArray(*this); - else - return NDArray(other); - } - - const Nd4jLong* newShapeInfo = nullptr; - if(!ShapeUtils::evalBroadcastShapeInfo(*this, other, true, newShapeInfo, getContext()->getWorkspace())) // the rank of new array = max->rankOf)() - throw std::runtime_error("NDArray::applyTrueBroadcast method: the shapes of this and other arrays are not suitable for broadcast operation !"); - - if(!shape::shapeEquals(newShapeInfo, other.shapeInfo())) { - - NDArray result(newShapeInfo, true, getContext()); - this->applyTrueBroadcast(op, other, result, false, extraArgs); - return std::move(result); - } - - this->applyTrueBroadcast(op, other, other, false, extraArgs); - return std::move(other); -} - -////////////////////////////////////////////////////////////////////////// -NDArray NDArray::applyTrueBroadcast(sd::BroadcastOpsTuple op, const NDArray& other, ExtraArguments *extraArgs) && { - if (isEmpty() || other.isEmpty()) { - if (isEmpty()) - return NDArray(*this); - else - return NDArray(other); - } - - const Nd4jLong* newShapeInfo = nullptr; - if(!ShapeUtils::evalBroadcastShapeInfo(*this, other, true, newShapeInfo, getContext()->getWorkspace())) // the rank of new array = max->rankOf)() - throw std::runtime_error("NDArray::applyTrueBroadcast method: the shapes of this and other arrays are not suitable for broadcast operation !"); - - if(!shape::shapeEquals(newShapeInfo, shapeInfo())) { - - NDArray result(newShapeInfo, true, getContext()); - this->applyTrueBroadcast(op, other, result, false, extraArgs); - return std::move(result); - } - - this->applyTrueBroadcast(op, other, *this, false, extraArgs); - return std::move(*this); -} - -////////////////////////////////////////////////////////////////////////// -NDArray NDArray::applyTrueBroadcast(sd::BroadcastOpsTuple op, NDArray&& other, ExtraArguments *extraArgs) && { - if (isEmpty() || other.isEmpty()) { - if (isEmpty()) - return NDArray(*this); - else - return NDArray(other); - } - - const Nd4jLong* newShapeInfo = nullptr; - if(!ShapeUtils::evalBroadcastShapeInfo(*this, other, true, newShapeInfo, getContext()->getWorkspace())) // the rank of new array = max->rankOf)() - throw std::runtime_error("NDArray::applyTrueBroadcast method: the shapes of this and other arrays are not suitable for broadcast operation !"); - - const bool thisMove = shape::shapeEquals(newShapeInfo, shapeInfo()); - const bool otherMove = shape::shapeEquals(newShapeInfo, other.shapeInfo()); - - if(!thisMove && !otherMove) { - - NDArray result(newShapeInfo, true, getContext()); - this->applyTrueBroadcast(op, other, result, false, extraArgs); - return std::move(result); - } - - if(thisMove) { - this->applyTrueBroadcast(op, other, *this, false, extraArgs); - return std::move(*this); - } - - // otherMove - this->applyTrueBroadcast(op, other, other, false, extraArgs); - return std::move(other); -} - -////////////////////////////////////////////////////////////////////////// -void NDArray::applyBroadcast(sd::broadcast::Ops op, const std::vector& dimensions, const NDArray& other, NDArray& target, ExtraArguments* extraArgs) { - - if (dimensions.size() == 0) - return; - - if (isS()) - throw std::runtime_error("NDArray::applyBroadcast: you can't use this method on String array!"); - if(((op == broadcast::Divide || op == broadcast::FloorDiv || op == broadcast::FloorMod) && other.isB()) || (op == broadcast::ReverseDivide && this->isB())) - throw std::runtime_error("NDArray::applyBroadcast: you can't divide by array!"); - if(isEmpty() || other.isEmpty()) { - if(!target.isEmpty()) - throw std::runtime_error("NDArray::applyBroadcast method: when some of input arrays (or both) is empty, target array must be empty as well !"); - return; - } - - // if (other.lengthOf() == lengthOf() && this->rankOf() == other.rankOf()) { - // NDArray::prepareSpecialUse({&target}, {this, &other}); - // NativeOpExecutioner::execPairwiseTransform(getContext(), fromBroadcastToPairwise(op), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.special(), nullptr); - // NDArray::registerSpecialUse({&target}, {this, &other}); - // return; - // } - - if(target.dataType() != DataTypeUtils::pickPairwiseResultType(shapeInfo(), other.shapeInfo())) - throw std::invalid_argument("NDArray::applyBroadcast method: wrong type of target array !"); - if(!target.isSameShape(this) && !target.isSameShape(other)) - throw std::invalid_argument("NDArray::applyBroadcast method: one of of two input arrays (this or other) should has the same shape as target array!"); - - std::vector copy(dimensions); - - if (dimensions.size() > 1) - std::sort(copy.begin(), copy.end()); - - Nd4jLong const* xShapeInfoH = shapeInfo(); - Nd4jLong const* yShapeInfoH = other.shapeInfo(); - Nd4jLong const* xShapeInfoD = specialShapeInfo(); - Nd4jLong const* yShapeInfoD = other.specialShapeInfo(); - - if(!isSameShape(target)) { - auto xPack = ConstantShapeHelper::getInstance().createShapeInfoWithUnitiesForBroadcast(target.shapeInfo(), shapeInfo(), getContext()->getWorkspace(), copy); - xShapeInfoH = reinterpret_cast(xPack.primary()); - xShapeInfoD = reinterpret_cast(xPack.special()); - } - if(!other.isSameShape(target)) { - auto yPack = ConstantShapeHelper::getInstance().createShapeInfoWithUnitiesForBroadcast(target.shapeInfo(), other.shapeInfo(), other.getContext()->getWorkspace(), copy); - yShapeInfoH = reinterpret_cast(yPack.primary()); - yShapeInfoD = reinterpret_cast(yPack.special()); - } - - NDArray::prepareSpecialUse({&target}, {this, &other}); - NativeOpExecutioner::execBroadcast(getContext(), op, buffer(), xShapeInfoH, specialBuffer(), xShapeInfoD, other.buffer(), yShapeInfoH, other.specialBuffer(), yShapeInfoD, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo()); - registerSpecialUse({&target}, {this, &other}); -} - -////////////////////////////////////////////////////////////////////////// -void NDArray::applyBroadcast(sd::broadcast::BoolOps op, const std::vector& dimensions, const NDArray& other, NDArray& target, ExtraArguments* extraArgs) { - - if (dimensions.size() == 0) - return; - - if (isS()) - throw std::runtime_error("NDArray::applyBroadcast BoolOps: you can't use this method on String array!"); - if(isEmpty() || other.isEmpty()) { - if(!target.isEmpty()) - throw std::runtime_error("NDArray::applyBroadcast BoolOps: when some of input arrays (or both) is empty, target array must be empty as well !"); - return; - } - - // if (other.lengthOf() == lengthOf() && this->rankOf() == other.rankOf()) { - // NDArray::prepareSpecialUse({&target}, {this, &other}); - // NativeOpExecutioner::execPairwiseBoolTransform(getContext(), fromBroadcastToPairwiseBool(op), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.special(), nullptr); - // NDArray::registerSpecialUse({&target}, {this, &other}); - // return; - // } - - if(target.dataType() != DataType::BOOL) - throw std::invalid_argument("NDArray::applyBroadcast bool method: type of target array must be BOOL!"); - if(!target.isSameShape(this) && !target.isSameShape(other)) - throw std::invalid_argument("NDArray::applyBroadcast bool method: one of of two input arrays (this or other) should has the same shape as target array!"); - if(_dataType != other._dataType) - throw std::invalid_argument("NDArray::applyBroadcast bool method: this and other arrays must have the same type !"); - - std::vector copy(dimensions); - - if (dimensions.size() > 1) - std::sort(copy.begin(), copy.end()); - - Nd4jLong const* xShapeInfoH = shapeInfo(); - Nd4jLong const* yShapeInfoH = other.shapeInfo(); - Nd4jLong const* xShapeInfoD = specialShapeInfo(); - Nd4jLong const* yShapeInfoD = other.specialShapeInfo(); - - if(!isSameShape(target)) { - auto xPack = ConstantShapeHelper::getInstance().createShapeInfoWithUnitiesForBroadcast(target.shapeInfo(), shapeInfo(), getContext()->getWorkspace(), copy); - xShapeInfoH = reinterpret_cast(xPack.primary()); - xShapeInfoD = reinterpret_cast(xPack.special()); - } - if(!other.isSameShape(target)) { - auto yPack = ConstantShapeHelper::getInstance().createShapeInfoWithUnitiesForBroadcast(target.shapeInfo(), other.shapeInfo(), other.getContext()->getWorkspace(), copy); - yShapeInfoH = reinterpret_cast(yPack.primary()); - yShapeInfoD = reinterpret_cast(yPack.special()); - } - - NDArray::prepareSpecialUse({&target}, {this, &other}); - NativeOpExecutioner::execBroadcastBool(getContext(), op, buffer(), xShapeInfoH, specialBuffer(), xShapeInfoD, other.buffer(), yShapeInfoH, other.specialBuffer(), yShapeInfoD, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), nullptr); - registerSpecialUse({&target}, {this, &other}); -} - - -////////////////////////////////////////////////////////////////////////// -void NDArray::applyBroadcast(sd::broadcast::IntOps op, const std::vector& dimensions, const NDArray& other, NDArray& target, ExtraArguments* extraArgs) { - - if (dimensions.empty()) - return; - - if (!isZ()) - throw std::runtime_error("NDArray::applyBroadcast IntOps: you can't use this method on non-Integer array!"); - if(isEmpty() || other.isEmpty()) { - if(!target.isEmpty()) - throw std::runtime_error("NDArray::applyBroadcast IntOps: when some of input arrays (or both) is empty, target array must be empty as well !"); - return; - } - - // if (other.lengthOf() == lengthOf() && this->rankOf() == other.rankOf()) { - // NDArray::prepareSpecialUse({&target}, {this, &other}); - // NativeOpExecutioner::execPairwiseIntTransform(getContext(), fromBroadcastToPairwiseInt(op), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.special(), nullptr); - // NDArray::registerSpecialUse({&target}, {this, &other}); - // return; - // } - - if(target.dataType() != dataType()) - throw std::invalid_argument("NDArray::applyBroadcast int method: type of target array must be the same as input!"); - if(!target.isSameShape(this) && !target.isSameShape(other)) - throw std::invalid_argument("NDArray::applyBroadcast int method: one of of two input arrays (this or other) should has the same shape as target array!"); - if(_dataType != other._dataType) - throw std::invalid_argument("NDArray::applyBroadcast int method: this and other arrays must have the same type !"); - - std::vector copy(dimensions); - - if (dimensions.size() > 1) - std::sort(copy.begin(), copy.end()); - - Nd4jLong const* xShapeInfoH = shapeInfo(); - Nd4jLong const* yShapeInfoH = other.shapeInfo(); - Nd4jLong const* xShapeInfoD = specialShapeInfo(); - Nd4jLong const* yShapeInfoD = other.specialShapeInfo(); - - if(!isSameShape(target)) { - auto xPack = ConstantShapeHelper::getInstance().createShapeInfoWithUnitiesForBroadcast(target.shapeInfo(), shapeInfo(), getContext()->getWorkspace(), copy); - xShapeInfoH = reinterpret_cast(xPack.primary()); - xShapeInfoD = reinterpret_cast(xPack.special()); - } - if(!other.isSameShape(target)) { - auto yPack = ConstantShapeHelper::getInstance().createShapeInfoWithUnitiesForBroadcast(target.shapeInfo(), other.shapeInfo(), other.getContext()->getWorkspace(), copy); - yShapeInfoH = reinterpret_cast(yPack.primary()); - yShapeInfoD = reinterpret_cast(yPack.special()); - } - - NDArray::prepareSpecialUse({&target}, {this, &other}); - NativeOpExecutioner::execBroadcastInt(getContext(), op, buffer(), xShapeInfoH, specialBuffer(), xShapeInfoD, other.buffer(), yShapeInfoH, other.specialBuffer(), yShapeInfoD, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo()); - registerSpecialUse({&target}, {this, &other}); -} - -////////////////////////////////////////////////////////////////////////// -void NDArray::applyBroadcast(sd::broadcast::Ops op, const std::initializer_list dimensions, const NDArray& tadArray, NDArray& target, ExtraArguments* extraArgs) { - std::vector vec(dimensions); - applyBroadcast(op, vec, tadArray, target, extraArgs); -} - -//////////////////////////////////////////////////////////////////////// -void* NDArray::operator new(size_t i) { - if (sd::memory::MemoryRegistrator::getInstance().hasWorkspaceAttached()) { - sd::memory::Workspace* ws = sd::memory::MemoryRegistrator::getInstance().getWorkspace(); - return ws->allocateBytes((Nd4jLong) i); - } - else { - auto p = malloc(i); - CHECK_ALLOC(p, "Failed to allocate new NDArray", i); - return p; - } -} - -//////////////////////////////////////////////////////////////////////// -void NDArray::operator delete(void* p) { - if (!sd::memory::MemoryRegistrator::getInstance().hasWorkspaceAttached()) - free(p); -} - -//////////////////////////////////////////////////////////////////////// -template -std::vector NDArray::asVectorT() { - - std::vector result(this->lengthOf()); - - PRAGMA_OMP_SIMD - for (int e = 0; e < this->lengthOf(); e++) - result[e] = this->e(e); - - return result; -} -BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT std::vector, NDArray::asVectorT(), LIBND4J_TYPES); - -////////////////////////////////////////////////////////////////////////// -// set new order and shape in case of suitable array length -bool NDArray::reshapei(const char order, const std::vector& cshape, const bool copyToNewBuff) { - - // check firstly whether cshape is identical to shape of array, if yes then reshape is unnecessary - if(order == ordering() && shape::shapeEquals(rankOf(), shapeOf(), cshape.size(), cshape.data())) - return true; - - const bool isOutShapeEmpty = std::find(cshape.begin(), cshape.end(), 0) != cshape.end(); - - if(isEmpty() && !isOutShapeEmpty) - throw std::invalid_argument("NDArray::reshapei: can't reshape empty array to non-empty !"); - if(!isEmpty() && isOutShapeEmpty) - throw std::invalid_argument("NDArray::reshapei: can't reshape non-empty array to empty !"); - if(isEmpty() && isOutShapeEmpty) { - Nd4jLong* shapeInfoNew = ShapeBuilders::emptyShapeInfo(dataType(), order, cshape, getContext()->getWorkspace()); - setShapeInfo(shapeInfoNew); - RELEASE(shapeInfoNew, getContext()->getWorkspace()); - return true; - } - - std::vector shape(cshape); - int rank = shape.size(); - - // looking for negative in shape - - int numberNegativesOnes = 0; - - Nd4jLong* shape_ = shape.data(); - for (int i = 0; i < (int) shape.size(); i++) { - if (shape[i] < 0) { - if (numberNegativesOnes >= 1) - throw std::runtime_error("NDArray::reshapei: only one dimension can be negative at once"); - - numberNegativesOnes++; - - int shapeLength = 1; - for (int j = 0; j < (int) shape.size(); j++) - if (i != j) - shapeLength *= shape_[j]; - - Nd4jLong realShape = sd::math::nd4j_abs(lengthOf() / shapeLength); - auto thisNewShape = new Nd4jLong[shape.size()]; - - for (int j = 0; j < (int) shape.size(); j++) - if (i != j) - thisNewShape[j] = shape_[j]; - else - thisNewShape[j] = realShape; - - shape_ = thisNewShape; - } - } - - for (int e = 0; e < (int) shape.size(); e++) - shape[e] = shape_[e]; - - if (numberNegativesOnes > 0) - delete[] shape_; - - Nd4jLong arrLength = 1; - for(const auto& item : shape) - arrLength *= item; - - if(platformBuffer() == nullptr || arrLength != this->lengthOf()) { - this->printShapeInfo("Mismatched shape"); - sd::Logger::printv("Shape requested: ", shape); - nd4j_debug("Requested length in reshape: %i; Existing length: %i;\n", arrLength, this->lengthOf()); - throw std::runtime_error("NDArray::reshapei: bad input shape!"); - } - - Nd4jLong *shapeInfoNew; - ALLOCATE(shapeInfoNew, getContext()->getWorkspace(), shape::shapeInfoLength(rank), Nd4jLong); - - bool canReshape = shape::reshapeC(shapeInfo(), order, shape.size(), shape.data(), shapeInfoNew); - - if (canReshape) { - setShapeInfo(shapeInfoNew); - } - else { - NDArray temp(order, shape, dataType(), getContext()); - if(copyToNewBuff) - this->applyTransform(transform::Assign, temp, nullptr); - *this = std::move(temp); - } - - RELEASE(shapeInfoNew, getContext()->getWorkspace()); - - return canReshape; -} - -////////////////////////////////////////////////////////////////////////// -void NDArray::nullify() { - if (isEmpty()) - return; - - if (isView() || ews() != 1) - assign(0); - else - _buffer->setToZeroBuffers(); - -} - -//////////////////////////////////////////////////////////////////////// -template -void NDArray::templatedSet(void *buffer, const Nd4jLong xOfsset, sd::DataType dtype, const void *value) { - BUILD_SINGLE_PARTIAL_SELECTOR(dtype, templatedSet< , T>(buffer, xOfsset, value), LIBND4J_TYPES); -} -BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT void NDArray::templatedSet, (void *buffer, const Nd4jLong xOfsset, sd::DataType dtype, const void *value), LIBND4J_TYPES); - -//////////////////////////////////////////////////////////////////////// -void NDArray::applyPairwiseTransform(sd::pairwise::Ops op, const NDArray& other, NDArray& target, ExtraArguments *extraParams) const{ - if (isS()) - throw std::runtime_error("NDArray::applyPairwiseTransform: you can't use this method on String array!"); - if (other.lengthOf() != target.lengthOf()) - throw std::invalid_argument("NDArray::applyPairwiseTransform method - lengths of arrays are mismatched"); - if (target.dataType() != this->dataType() && target.dataType() != other.dataType()) - throw std::invalid_argument("NDArray::applyPairwiseTransform method - type of target array must be the same as type of this or other array !"); - - NDArray::prepareSpecialUse({&target}, {this, &other}); - NativeOpExecutioner::execPairwiseTransform(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target.dataType()) : nullptr); - NDArray::registerSpecialUse({&target}, {this, &other}); - - if (extraParams != nullptr) - synchronize("NDArray::applyPairwiseTransform"); -} - -//////////////////////////////////////////////////////////////////////// -void NDArray::applyPairwiseTransform(sd::pairwise::BoolOps op, const NDArray& other, NDArray& target, ExtraArguments *extraParams) const{ - if (isS()) - throw std::runtime_error("NDArray::applyPairwiseTransform BoolOps: you can't use this method on String array!"); - if (other.lengthOf() != target.lengthOf()) - throw std::invalid_argument("NDArray::applyPairwiseTransform BoolOps method - lengths of arrays are mismatched"); - if (!target.isB()) - throw std::invalid_argument("NDArray::applyPairwiseTransform BoolOps method - result must have bool type"); - if (dataType() != other.dataType()) - throw std::invalid_argument("NDArray::applyPairwiseTransform BoolOps method - this and other arrays must have the same type !"); - - NDArray::prepareSpecialUse({&target}, {this, &other}); - NativeOpExecutioner::execPairwiseBoolTransform(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target.dataType()) : nullptr); - NDArray::registerSpecialUse({&target}, {this, &other}); -} - -//////////////////////////////////////////////////////////////////////// -void NDArray::applyPairwiseTransform(sd::pairwise::IntOps op, const NDArray& other, NDArray& target, ExtraArguments *extraParams) const{ - if (isS()) - throw std::runtime_error("NDArray::applyPairwiseTransform IntOps: you can't use this method on String array!"); - if (other.lengthOf() != target.lengthOf()) - throw std::invalid_argument("NDArray::applyPairwiseTransform IntOps method - lengths of arrays are mismatched"); - if (!target.isZ()) - throw std::invalid_argument("NDArray::applyPairwiseTransform IntOps method - result must have bool type"); - if (dataType() != other.dataType()) - throw std::invalid_argument("NDArray::applyPairwiseTransform IntOps method - this and other arrays must have the same type !"); - - NDArray::prepareSpecialUse({&target}, {this, &other}); - NativeOpExecutioner::execPairwiseIntTransform(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target.dataType()) : nullptr); - NDArray::registerSpecialUse({&target}, {this, &other}); -} - -////////////////////////////////////////////////////////////////////////// -void NDArray::applyPairwiseTransform(sd::pairwise::Ops op, const NDArray& other, ExtraArguments *extraParams) { - applyPairwiseTransform(op, other, *this, extraParams); -} - -//////////////////////////////////////////////////////////////////////// -template -void NDArray::templatedDoubleAssign(void *xBuffer, const Nd4jLong xOffset, const void *yBuffer, const Nd4jLong yOffset) const { - auto x = reinterpret_cast(xBuffer); - const auto y = reinterpret_cast(yBuffer); - x[xOffset] = static_cast(y[yOffset]); -} -BUILD_DOUBLE_TEMPLATE(template ND4J_EXPORT void NDArray::templatedDoubleAssign, (void *xBuffer, const Nd4jLong xOffset, const void *yBuffer, const Nd4jLong yOffset) const, LIBND4J_TYPES, LIBND4J_TYPES); - -//////////////////////////////////////////////////////////////////////// -void NDArray::varianceAlongDimension(sd::variance::Ops op, NDArray& target, const bool biasCorrected, const std::vector& dimensions) const { - - if (isS()) - throw std::runtime_error("NDArray::varianceAlongDimension: you can't use this method on String array!"); - - if (!target.isR()) - throw std::runtime_error("NDArray::varianceAlongDimension: target array must have FLOAT type"); - - NDArray::prepareSpecialUse({&target}, {this}); - - if(rankOf() == dimensions.size() || dimensions.empty()) - NativeOpExecutioner::execSummaryStatsScalar(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), biasCorrected); - else { - std::vector copy(dimensions); - auto pDims = sd::Environment::getInstance().isCPU() ? copy.data() : nullptr; - auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(this->shapeInfo(), dimensions); - NativeOpExecutioner::execSummaryStats(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), pDims, dimensions.size(), packX.platformShapeInfo(), packX.platformOffsets(), biasCorrected); - synchronize("NDArray::varianceAlongDimension"); - } - - NDArray::registerSpecialUse({&target}, {this}); -} - -//////////////////////////////////////////////////////////////////////// -NDArray NDArray::varianceAlongDimension(sd::variance::Ops op, const bool biasCorrected, const std::vector& dimensions) const { - if (isS()) - throw std::runtime_error("NDArray::varianceAlongDimension: you can't use this method on String array!"); - - std::vector copy(dimensions); - if (copy.size() > 1) - std::sort(copy.begin(), copy.end()); - - auto newShape = ShapeUtils::evalReduceShapeInfo('c', copy, *this, DataTypeUtils::pickFloatingType(dataType()), false, false, getContext()->getWorkspace()); - NDArray result(newShape, true, getContext()); - - this->varianceAlongDimension(op, result, biasCorrected, dimensions); - - return result; -} - -//////////////////////////////////////////////////////////////////////// -NDArray NDArray::varianceAlongDimension(sd::variance::Ops op, const bool biasCorrected, const std::initializer_list& dimensions) const { - return varianceAlongDimension(op, biasCorrected, std::vector(dimensions)); -} - -//////////////////////////////////////////////////////////////////////// -void NDArray::varianceAlongDimension(sd::variance::Ops op, NDArray &target, const bool biasCorrected, const std::initializer_list& dimensions) const { - varianceAlongDimension(op, target, biasCorrected, std::vector(dimensions)); -} - -//////////////////////////////////////////////////////////////////////// -// This method returns new copy of this NDArray, optionally in different order -NDArray NDArray::dup(const char newOrder) const { - - if (isEmpty()) - return NDArrayFactory::empty(dataType(), getContext()); - - char order = newOrder == 'a' ? ordering() : newOrder; - - // for now string arrays require special treatment - if (isS()) { - if (dataType() == DataType::UTF8) { - std::vector strings(lengthOf()); - - auto func = PRAGMA_THREADS_FOR{ - for (auto i = start; i < stop; i++) { - strings[i] = std::move(this->e(i)); - } - }; - - samediff::Threads::parallel_for(func, 0, lengthOf(), 1); - - return NDArray(getShapeAsVector(), strings, dataType(), getContext()); - } - if (dataType() == DataType::UTF16) { - std::vector strings(lengthOf()); - - auto func = PRAGMA_THREADS_FOR{ - for (auto i = start; i < stop; i++) { - strings[i] = std::move(this->e(i)); - } - }; - - samediff::Threads::parallel_for(func, 0, lengthOf(), 1); - - return NDArray(getShapeAsVector(), strings, dataType(), getContext()); - } - - std::vector strings(lengthOf()); - auto func = PRAGMA_THREADS_FOR{ - for (auto i = start; i < stop; i++) { - strings[i] = std::move(this->e(i)); - } - }; - - samediff::Threads::parallel_for(func, 0, lengthOf(), 1); - - return NDArray(getShapeAsVector(), strings, dataType(), getContext()); - } - - NDArray result(order, isScalar() ? std::vector({0}) : getShapeAsVector(), dataType(), getContext()); - result.assign(*this); - - return result; -} - -//////////////////////////////////////////////////////////////////////// -// This method returns true if two arrays are equal, with custom or default Eps value of 1e-5, false otherwise -bool NDArray::equalsTo(const NDArray *other, double eps) const { - - if (dataType() != other->dataType() || lengthOf() != other->lengthOf()) - return false; - - // we need to be able to compare [1, len] to [len] - if ((rankOf() == 1 && other->rankOf() == 2) || (rankOf() == 2 && other->rankOf() == 1)) { - // FIXME: do something here? - } else if (!shape::equalsSoft(shapeInfo(), other->shapeInfo())) - return false; - - if (isS()) { - // string is special case, we'll compare them one by one, considering both arrays are guaranteed to have the same length - - if (dataType() == DataType::UTF8) { - for (Nd4jLong e = 0; e < this->lengthOf(); e++) { - auto s1 = this->e(e); - auto s2 = other->e(e); - - if (s1 != s2) - return false; - } - } - else if (dataType() == DataType::UTF16) { - for (Nd4jLong e = 0; e < this->lengthOf(); e++) { - auto s1 = this->e(e); - auto s2 = other->e(e); - - if (s1 != s2) - return false; - } - } - else { - for (Nd4jLong e = 0; e < this->lengthOf(); e++) { - auto s1 = this->e(e); - auto s2 = other->e(e); - - if (s1 != s2) - return false; - } - } - - return true; - } else { - // regular numeric types - NDArray tmp(sd::DataType::FLOAT32, getContext()); // scalar = 0 - - ExtraArguments extras({0.0, 0.0, eps}); - - NDArray::prepareSpecialUse({&tmp}, {this, other}); - NativeOpExecutioner::execReduce3Scalar(getContext(), reduce3::EqualsWithEps, buffer(), shapeInfo(), - specialBuffer(), specialShapeInfo(), - extras.argumentsAsT(DataType::FLOAT32), other->buffer(), - other->shapeInfo(), other->specialBuffer(), - other->specialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), - tmp.specialBuffer(), tmp.specialShapeInfo()); - NDArray::registerSpecialUse({&tmp}, {this, other}); - - synchronize("NDArray::equalsTo"); - - if (tmp.e(0) != 0) - return false; - - return true; - } -} - -////////////////////////////////////////////////////////////////////////// -template <> -std::string NDArray::e(const Nd4jLong i) const { - - if (!isS()) - throw std::runtime_error("Can't get std::string out of non-string array"); - - if (i == lengthOf()) - throw std::runtime_error("Can't get std::string for index out of range"); - - - if (this->dataType() == DataType::UTF16) { - auto u16 = this->e(i); - std::string s; - StringUtils::u16StringToU8String(u16, s); - return s; - } - - if (this->dataType() == DataType::UTF32) { - auto u32 = this->e(i); - std::string s; - StringUtils::u32StringToU8String(u32, s); - return s; - } - - NDArray::preparePrimaryUse({}, {this}); - - auto offsets = bufferAsT(); - auto offsetsLength = ShapeUtils::stringBufferHeaderRequirements(lengthOf()); - auto start = offsets[i]; - auto end = offsets[i + 1]; - auto data = bufferAsT() + offsetsLength + start; - - std::string r(reinterpret_cast(data), (end - start)); - - registerPrimaryUse({}, {this}); - - return r; -} - -template <> -std::u16string NDArray::e(const Nd4jLong i) const { - - if (!isS()) - throw std::runtime_error("Can't get std::u16string out of non-string array"); - - if(i == lengthOf()) - throw std::runtime_error("Can't get std::u16string for index out of range"); - - if (this->dataType() == DataType::UTF8) { - auto u = this->e(i); - std::u16string s; - StringUtils::u8StringToU16String(u, s); - return s; - } - - if (this->dataType() == DataType::UTF32) { - auto u32 = this->e(i); - std::u16string s; - StringUtils::u32StringToU16String(u32, s); - return s; - } - - NDArray::preparePrimaryUse({}, { this }); - - auto offsets = bufferAsT(); - Nd4jLong offsetsLength = ShapeUtils::stringBufferHeaderRequirements(lengthOf()); - Nd4jLong start = offsets[i]; - Nd4jLong end = offsets[i + 1]; - auto data = bufferAsT() + offsetsLength + start; - - std::u16string r(reinterpret_cast(data), (end - start) / sizeof(char16_t)); - - registerPrimaryUse({}, { this }); - - return r; -} - -template <> -std::u32string NDArray::e(const Nd4jLong i) const { - - if (!isS()) - throw std::runtime_error("Can't get std::u32string out of non-string array"); - - if (i == lengthOf()) - throw std::runtime_error("Can't get std::u32string for index out of range"); - - if (this->dataType() == DataType::UTF8) { - auto u = this->e(i); - std::u32string s; - StringUtils::u8StringToU32String(u, s); - return s; - } - - if (this->dataType() == DataType::UTF16) { - auto u16 = this->e(i); - std::u32string s; - StringUtils::u16StringToU32String(u16, s); - return s; - } - - NDArray::preparePrimaryUse({}, { this }); - - auto offsets = bufferAsT(); - Nd4jLong offsetsLength = ShapeUtils::stringBufferHeaderRequirements(lengthOf()); - Nd4jLong start = offsets[i]; - Nd4jLong end = offsets[i + 1]; - - auto data = bufferAsT() + offsetsLength + start; - - std::u32string r(reinterpret_cast(data), (end - start) / sizeof(char32_t)); - - registerPrimaryUse({}, { this }); - - return r; -} - -////////////////////////////////////////////////////////////////////////// -template <> -utf8string NDArray::e(const Nd4jLong i) const { - - if (!isS()) - throw std::runtime_error("This method is available for String arrays only"); - - auto rp = getOffset(i); - - syncToHost(); - tickReadHost(); - - return *(reinterpret_cast(buffer())[rp]); -} - -///////////////////////////////////////////////////////////////////////// -template -T NDArray::e(const Nd4jLong i) const { - - const auto rp = getOffset(i); - - NDArray::preparePrimaryUse({}, {this}); - NDArray::registerPrimaryUse({}, {this}); - BUILD_SINGLE_PARTIAL_SELECTOR(dataType(), return templatedGet<, T>(buffer(), rp), LIBND4J_TYPES); - -} -BUILD_SINGLE_UNCHAINED_TEMPLATE(template ND4J_EXPORT , NDArray::e(const Nd4jLong) const, LIBND4J_TYPES); - -////////////////////////////////////////////////////////////////////////// -// Returns value from 2D matrix by coordinates/indexes -template -T NDArray::e(const Nd4jLong i, const Nd4jLong j) const { - - if (rankOf() != 2 || i >= shapeOf()[0] || j >= shapeOf()[1]) - throw std::invalid_argument("NDArray::e(i,j): one of input indexes is out of array length or rank!=2 !"); - - const auto xOffset = i * strideAt(0) + j * strideAt(1); - - NDArray::preparePrimaryUse({}, {this}); - NDArray::registerPrimaryUse({}, {this}); - - BUILD_SINGLE_PARTIAL_SELECTOR(dataType(), return templatedGet<, T>(buffer(), xOffset), LIBND4J_TYPES); - - return static_cast(119); -} -BUILD_SINGLE_UNCHAINED_TEMPLATE(template ND4J_EXPORT , NDArray::e(const Nd4jLong, const Nd4jLong) const, LIBND4J_TYPES); - -////////////////////////////////////////////////////////////////////////// -// returns value from 3D tensor by coordinates -template -T NDArray::e(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k) const { - - if (rankOf() != 3 || i >= shapeOf()[0] || j >= shapeOf()[1] || k >= shapeOf()[2]) - throw std::invalid_argument("NDArray::e(i,j,k): one of input indexes is out of array length or rank!=3 !"); - - const auto xOffset = i * strideAt(0) + j * strideAt(1) + k * strideAt(2); - - NDArray::preparePrimaryUse({}, {this}); - NDArray::registerPrimaryUse({}, {this}); - - BUILD_SINGLE_PARTIAL_SELECTOR(dataType(), return templatedGet<, T>(buffer(), xOffset), LIBND4J_TYPES); - - return static_cast(119); -} -BUILD_SINGLE_UNCHAINED_TEMPLATE(template ND4J_EXPORT , NDArray::e(const Nd4jLong, const Nd4jLong, const Nd4jLong) const, LIBND4J_TYPES); - -////////////////////////////////////////////////////////////////////////// -// returns value from 3D tensor by coordinates -template -T NDArray::e(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l) const { - - if (rankOf() != 4 || i >= shapeOf()[0] || j >= shapeOf()[1] || k >= shapeOf()[2] || l >= shapeOf()[3]) - throw std::invalid_argument("NDArray::e(i,j,k,l): one of input indexes is out of array length or rank!=4 !"); - - const auto xOffset = i * strideAt(0) + j * strideAt(1) + k * strideAt(2) + l * strideAt(3); - - NDArray::preparePrimaryUse({}, {this}); - NDArray::registerPrimaryUse({}, {this}); - - BUILD_SINGLE_PARTIAL_SELECTOR(dataType(), return templatedGet<, T>(buffer(), xOffset), LIBND4J_TYPES); - - return static_cast(119); -} -BUILD_SINGLE_UNCHAINED_TEMPLATE(template ND4J_EXPORT , NDArray::e(const Nd4jLong, const Nd4jLong, const Nd4jLong, const Nd4jLong) const, LIBND4J_TYPES); - -////////////////////////////////////////////////////////////////////////// -NDArray NDArray::e(const Nd4jLong i) const { - - const auto offset = getOffset(i); - - NDArray scalar(dataType(), getContext()); - - scalar.copyBuffersContinuouslyFrom(*this, sizeOfT(), 0, bufferOffset() + offset); - - return scalar; -} - -////////////////////////////////////////////////////////////////////////// -// perform array transformation -void NDArray::applyTransform(sd::transform::FloatOps op, NDArray& target, ExtraArguments *extraParams) { - - if (isS()) - throw std::runtime_error("NDArray::applyTransform FloatOps: you can't use this method on String array!"); - - if (!target.isR()) - throw std::runtime_error("NDArray::applyTransform FloatOps: target array must have one of FLOAT types"); - - NDArray::prepareSpecialUse({&target}, {this}); - NativeOpExecutioner::execTransformFloat(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target.dataType()) : nullptr, nullptr, nullptr); - NDArray::registerSpecialUse({&target}, {this}); -} - -//////////////////////////////////////////////////////////////////////// -void NDArray::applyTransform(sd::transform::AnyOps op, NDArray& target, ExtraArguments *extraParams) { - - if (isS()) - throw std::runtime_error("NDArray::applyTransform AnyOps: you can't use this method on String array!"); - - NDArray::prepareSpecialUse({&target}, {this}); - NativeOpExecutioner::execTransformAny(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target.dataType()) : nullptr, nullptr, nullptr); - NDArray::registerSpecialUse({&target}, {this}); -} - -//////////////////////////////////////////////////////////////////////// -void NDArray::applyTransform(sd::transform::SameOps op, NDArray& target, ExtraArguments *extraParams) { - - if (isS()) - throw std::runtime_error("NDArray::applyTransform SameOps: you can't use this method on String array!"); - - if (target.dataType() != dataType()) - throw std::runtime_error("NDArray::applyTransform SameOps: target array must have the same data type as original array"); - - NDArray::prepareSpecialUse({&target}, {this}); - NativeOpExecutioner::execTransformSame(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target.dataType()) : nullptr, nullptr, nullptr); - NDArray::registerSpecialUse({&target}, {this}); -} - -//////////////////////////////////////////////////////////////////////// -void NDArray::applyTransform(sd::transform::StrictOps op, NDArray& target, ExtraArguments *extraParams) { - if (isS()) - throw std::runtime_error("NDArray::applyTransform StrictOps: you can't use this method on String array!"); - - if (!this->isR() || !target.isR() || (this->dataType() != target.dataType())) - throw std::runtime_error("NDArray::applyTransform StrictOps: both Source and Target array must have same FLOAT type !"); - - NDArray::prepareSpecialUse({&target}, {this}); - NativeOpExecutioner::execTransformStrict(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target.dataType()) : nullptr, nullptr, nullptr); - NDArray::registerSpecialUse({&target}, {this}); -} - -//////////////////////////////////////////////////////////////////////// -void NDArray::applyTransform(sd::transform::BoolOps op, NDArray& target, ExtraArguments *extraParams) { - if (isS()) - throw std::runtime_error("NDArray::applyTransform BoolOps: you can't use this method on String array!"); - - if (!target.isB()) - throw std::runtime_error("NDArray::applyTransform BoolOps: target array must have one of BOOL types"); - - NDArray::prepareSpecialUse({&target}, {this}); - NativeOpExecutioner::execTransformBool(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target.dataType()) : nullptr, nullptr, nullptr); - NDArray::registerSpecialUse({&target}, {this}); -} - -//////////////////////////////////////////////////////////////////////// -NDArray NDArray::transform(sd::transform::FloatOps op, void *extraParams) const & { - if (isS()) - throw std::runtime_error("NDArray::transform FloatOps: you can't use this method on String array!"); - - NDArray result(ordering(), getShapeAsVector(), DataTypeUtils::pickFloatingType(dataType()), getContext()); - - NDArray::prepareSpecialUse({&result}, {this}); - NativeOpExecutioner::execTransformFloat(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo(), extraParams, nullptr, nullptr); - NDArray::registerSpecialUse({&result}, {this}); - - return result; -} - -//////////////////////////////////////////////////////////////////////// -NDArray NDArray::transform(sd::transform::FloatOps op, void *extraParams) && { - if (isS()) - throw std::runtime_error("NDArray::transform SameOps: you can't use this method on String array!"); - - NDArray::prepareSpecialUse({this}, {this}); - NativeOpExecutioner::execTransformFloat(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), extraParams, nullptr, nullptr); - NDArray::registerSpecialUse({this}, {this}); - - return std::move(*this); -} - -//////////////////////////////////////////////////////////////////////// -NDArray NDArray::transform(sd::transform::SameOps op, void *extraParams) const & { - if (isS()) - throw std::runtime_error("NDArray::transform SameOps: you can't use this method on String array!"); - - NDArray result(shapeInfo(), false, getContext()); - - NDArray::prepareSpecialUse({&result}, {this}); - NativeOpExecutioner::execTransformSame(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo(), extraParams, nullptr, nullptr); - NDArray::registerSpecialUse({&result}, {this}); - - return result; -} - -//////////////////////////////////////////////////////////////////////// -NDArray NDArray::transform(sd::transform::SameOps op, void *extraParams) && { - if (isS()) - throw std::runtime_error("NDArray::transform SameOps: you can't use this method on String array!"); - - NDArray::prepareSpecialUse({this}, {this}); - NativeOpExecutioner::execTransformSame(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), extraParams, nullptr, nullptr); - NDArray::registerSpecialUse({this}, {this}); - - return std::move(*this); -} - -//////////////////////////////////////////////////////////////////////// -NDArray NDArray::transform(sd::transform::StrictOps op, void *extraParams) const & { - if (!this->isR()) - throw std::runtime_error("Source array must have one of FLOAT types"); - - NDArray result(shapeInfo(), false, getContext()); - - NDArray::prepareSpecialUse({&result}, {this}); - NativeOpExecutioner::execTransformStrict(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo(), extraParams, nullptr, nullptr); - NDArray::registerSpecialUse({&result}, {this}); - - return result; -} - -//////////////////////////////////////////////////////////////////////// -NDArray NDArray::transform(sd::transform::StrictOps op, void *extraParams) && { - if (!this->isR()) - throw std::runtime_error("Source array must have one of FLOAT types"); - - NDArray::prepareSpecialUse({this}, {this}); - NativeOpExecutioner::execTransformStrict(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), extraParams, nullptr, nullptr); - NDArray::registerSpecialUse({this}, {this}); - - return std::move(*this); -} - -//////////////////////////////////////////////////////////////////////// -NDArray NDArray::transform(sd::transform::BoolOps op, void *extraParams) const & { - if (isS()) - throw std::runtime_error("NDArray::transform BoolOps: you can't use this method on String array!"); - - NDArray result(ordering(), getShapeAsVector(), sd::DataType::BOOL, getContext()); - - NDArray::prepareSpecialUse({&result}, {this}); - NativeOpExecutioner::execTransformBool(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo(), extraParams, nullptr, nullptr); - NDArray::registerSpecialUse({&result}, {this}); - - return result; -} - -//////////////////////////////////////////////////////////////////////// -NDArray NDArray::transform(sd::transform::BoolOps op, void *extraParams) && { - if (isS()) - throw std::runtime_error("NDArray::transform BoolOps: you can't use this method on String array!"); - - NDArray::prepareSpecialUse({this}, {this}); - NativeOpExecutioner::execTransformBool(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), extraParams, nullptr, nullptr); - NDArray::registerSpecialUse({this}, {this}); - - return std::move(*this); -} - -////////////////////////////////////////////////////////////////////////// -void NDArray::applyScalarArr(sd::scalar::Ops op, const NDArray& scalar, NDArray& target, ExtraArguments *extraParams) { - if (isS()) - throw std::runtime_error("NDArray::applyScalarArr: you can't use this method on String array!"); - if (scalar.lengthOf() != 1) - throw std::invalid_argument("NDArray::applyScalarArr method: operand is not a scalar!"); - - if(target.dataType() != DataTypeUtils::pickPairwiseResultType(shapeInfo(), scalar.shapeInfo()) && !(target.dataType() == dataType() || target.dataType() == scalar.dataType())) - throw std::invalid_argument("NDArray::applyScalarArr method: wrong type of target array!"); - - NDArray::prepareSpecialUse({&target}, {this, &scalar}); - NativeOpExecutioner::execScalar(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), scalar.buffer(), scalar.shapeInfo(), scalar.specialBuffer(), scalar.specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target.dataType()): nullptr); - NDArray::registerSpecialUse({&target}, {this, &scalar}); -} - -////////////////////////////////////////////////////////////////////////// -void NDArray::applyScalarArr(sd::scalar::BoolOps op, const NDArray& scalar, NDArray &target, ExtraArguments *extraParams) const { - if (isS()) - throw std::runtime_error("NDArray::applyScalarArr BoolOps: you can't use this method on String array!"); - if (!target.isB()) - throw std::invalid_argument("NDArray::applyScalarArr bool method: target has not bool type!"); - if (dataType() != scalar.dataType()) { - nd4j_printf("NDArray::applyScalarArr BoolOps: this dtype: [%i]; scalar dtype: [%i]\n", this->dataType(), scalar.dataType()); - throw std::invalid_argument("NDArray::applyScalarArr bool method: this and scalar arrays must have the same type!"); - } - - NDArray::prepareSpecialUse({&target}, {this, &scalar}); - NativeOpExecutioner::execScalarBool(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), scalar.buffer(), scalar.shapeInfo(), scalar.specialBuffer(), scalar.specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target.dataType()): nullptr); - NDArray::registerSpecialUse({&target}, {this, &scalar}); -} - -////////////////////////////////////////////////////////////////////////// -void NDArray::applyScalarArr(sd::scalar::IntOps op, const NDArray& scalar, NDArray &target, ExtraArguments *extraParams) const { - if (isS()) - throw std::runtime_error("NDArray::applyScalarArr IntOps: you can't use this method on String array!"); - - if (target.dataType() != this->dataType()) - throw std::invalid_argument("NDArray::applyScalarArr int method: target has not bool type!"); - if (dataType() != scalar.dataType()) { - nd4j_printf("NDArray::applyScalarArr IntOps: this dtype: [%i]; scalar dtype: [%i]\n", this->dataType(), scalar.dataType()); - throw std::invalid_argument("NDArray::applyScalarArr int method: this and scalar arrays must have the same type!"); - } - - NDArray::prepareSpecialUse({&target}, {this, &scalar}); - NativeOpExecutioner::execScalarInt(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), scalar.buffer(), scalar.shapeInfo(), scalar.specialBuffer(), scalar.specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target.dataType()): nullptr); - NDArray::registerSpecialUse({&target}, {this, &scalar}); -} - -//////////////////////////////////////////////////////////////////////// -template -void NDArray::applyScalar(sd::scalar::IntOps op, const T scalar, NDArray& target, ExtraArguments *extraParams) const { - - NDArray scalarArr = NDArrayFactory::create(this->dataType(), scalar, getContext()); - applyScalarArr(op, scalarArr, target, extraParams); -} - -template <> ND4J_EXPORT void NDArray::applyScalar(sd::scalar::IntOps op, const NDArray& scalar, NDArray &target, ExtraArguments *extraParams) const { throw std::runtime_error("NDArray::applyScalar method: do not use me!");} -template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::IntOps op, const double scalar, NDArray &target, ExtraArguments *extraParams) const; -template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::IntOps op, const float scalar, NDArray &target, ExtraArguments *extraParams) const; -template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::IntOps op, const float16 scalar, NDArray &target, ExtraArguments *extraParams) const; -template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::IntOps op, const bfloat16 scalar, NDArray &target, ExtraArguments *extraParams) const; -template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::IntOps op, const Nd4jLong scalar, NDArray &target, ExtraArguments *extraParams) const; -template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::IntOps op, const int scalar, NDArray &target, ExtraArguments *extraParams) const; -template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::IntOps op, const int16_t scalar, NDArray &target, ExtraArguments *extraParams) const; -template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::IntOps op, const int8_t scalar, NDArray &target, ExtraArguments *extraParams) const; -template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::IntOps op, const uint8_t scalar, NDArray &target, ExtraArguments *extraParams) const; -template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::IntOps op, const bool scalar, NDArray &target, ExtraArguments *extraParams) const; - -//////////////////////////////////////////////////////////////////////// -template -void NDArray::applyScalar(sd::scalar::Ops op, const T scalar, NDArray& target, ExtraArguments *extraParams) { - - auto scalarArr = NDArrayFactory::create(dataType(), scalar, this->getContext()); - applyScalarArr(op, scalarArr, target, extraParams); -} -template <> ND4J_EXPORT void NDArray::applyScalar(sd::scalar::Ops op, const NDArray& scalar, NDArray &target, ExtraArguments *extraParams) { throw std::runtime_error("NDArray::applyScalar method: do not use me!");} -template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::Ops op, const double scalar, NDArray &target, ExtraArguments *extraParams); -template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::Ops op, const float scalar, NDArray &target, ExtraArguments *extraParams); -template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::Ops op, const float16 scalar, NDArray &target, ExtraArguments *extraParams); -template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::Ops op, const bfloat16 scalar, NDArray &target, ExtraArguments *extraParams); -template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::Ops op, const Nd4jLong scalar, NDArray &target, ExtraArguments *extraParams); -template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::Ops op, const int scalar, NDArray &target, ExtraArguments *extraParams); -template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::Ops op, const int16_t scalar, NDArray &target, ExtraArguments *extraParams); -template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::Ops op, const int8_t scalar, NDArray &target, ExtraArguments *extraParams); -template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::Ops op, const uint8_t scalar, NDArray &target, ExtraArguments *extraParams); -template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::Ops op, const bool scalar, NDArray &target, ExtraArguments *extraParams); - -//////////////////////////////////////////////////////////////////////// -template -void NDArray::applyScalar(sd::scalar::BoolOps op, const T scalar, NDArray &target, ExtraArguments *extraParams) const { - - NDArray scalarArr = NDArrayFactory::create(scalar, getContext()); - applyScalarArr(op, scalarArr, target, extraParams); -} - -template <> ND4J_EXPORT void NDArray::applyScalar(sd::scalar::BoolOps op, const NDArray& scalar, NDArray &target, ExtraArguments *extraParams) const { throw std::runtime_error("NDArray::applyScalar method: do not use me!");} -template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::BoolOps op, const double scalar, NDArray &target, ExtraArguments *extraParams) const; -template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::BoolOps op, const float scalar, NDArray &target, ExtraArguments *extraParams) const; -template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::BoolOps op, const float16 scalar, NDArray &target, ExtraArguments *extraParams) const; -template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::BoolOps op, const bfloat16 scalar, NDArray &target, ExtraArguments *extraParams) const; -template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::BoolOps op, const Nd4jLong scalar, NDArray &target, ExtraArguments *extraParams) const; -template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::BoolOps op, const int scalar, NDArray &target, ExtraArguments *extraParams) const; -template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::BoolOps op, const int16_t scalar, NDArray &target, ExtraArguments *extraParams) const; -template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::BoolOps op, const int8_t scalar, NDArray &target, ExtraArguments *extraParams) const; -template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::BoolOps op, const uint8_t scalar, NDArray &target, ExtraArguments *extraParams) const; -template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::BoolOps op, const bool scalar, NDArray &target, ExtraArguments *extraParams) const; - -//////////////////////////////////////////////////////////////////////// -void NDArray::applyIndexReduce(sd::indexreduce::Ops op, NDArray& target, const std::vector& dimensions, const ExtraArguments *extraParams) const { - if (isS()) - throw std::runtime_error("NDArray::applyIndexReduce: you can't use this method on String array!"); - - if (target.dataType() != sd::DataType::INT64 && target.dataType() != sd::DataType::INT32) - throw std::runtime_error("NDArray::applyIndexReduce operations return INT32/INT64"); - - void* params = extraParams != nullptr ? const_cast(extraParams)->argumentsAsT(this->dataType()) : nullptr; - - NDArray::prepareSpecialUse({&target}, {this}); - - if (target.lengthOf() == 1) { - NativeOpExecutioner::execIndexReduceScalar(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), params, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo()); - } - else { - std::vector copy = dimensions; - shape::checkDimensions(rankOf(), copy); - auto pDims = sd::Environment::getInstance().isCPU() ? copy.data() : nullptr; - auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(shapeInfo(), copy); - NativeOpExecutioner::execIndexReduce(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), params, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), pDims, copy.size(), packX.platformShapeInfo(), packX.platformOffsets()); - synchronize("NDArray::applyIndexReduce"); - } - - registerSpecialUse({&target}, {this}); -} - -//////////////////////////////////////////////////////////////////////// -// reduce dimensions in this array relying on index operations -NDArray NDArray::applyIndexReduce(sd::indexreduce::Ops op, const std::vector& dimensions, const ExtraArguments* extraParams ) const { - - std::vector copy = dimensions; - auto newShape = ShapeUtils::evalReduceShapeInfo('c', copy, *this, DataType::INT64, false, false, getContext()->getWorkspace()); - NDArray result(newShape, true, getContext()); - - applyIndexReduce(op, result, copy, extraParams); - - return result; -} - -//////////////////////////////////////////////////////////////////////// -// apply reduce3 operations to this and other array, return result in new output array -NDArray NDArray::applyReduce3(sd::reduce3::Ops op, const NDArray& other, const ExtraArguments* extraParams) const { - - if (isS()) - throw std::runtime_error("NDArray::applyReduce3 method: you can't use this method on String array!"); - if(dataType() != other.dataType()) - throw std::runtime_error("NDArray::applyReduce3 method: the types of this and other arrays must be the same !"); - // check shapes consistency - if(!isSameShape(other)) - throw std::runtime_error("NDArray::applyReduce3 method: the shapes of this and other arrays must be the same !"); - // create shapeInfo for scalar - auto newShape = ShapeBuilders::createScalarShapeInfo(DataTypeUtils::pickFloatingType(dataType()), getContext()->getWorkspace()); - // create output array (scalar) - NDArray result(newShape, true, getContext()); - RELEASE(newShape, getContext()->getWorkspace()); - // create dynamic array of extra parameters if array extraParams is empty (==nullptr) - void* params = extraParams != nullptr ? const_cast(extraParams)->argumentsAsT(dataType()) : nullptr; - - NDArray::prepareSpecialUse({&result}, {this, &other}); - NativeOpExecutioner::execReduce3Scalar(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), params, other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo()); - NDArray::registerSpecialUse({&result}, {this, &other}); - - return result; -} - -//////////////////////////////////////////////////////////////////////// -// apply reduce3 (exec) operations to this and other array, return result in new output array -NDArray NDArray::applyReduce3(sd::reduce3::Ops op, const NDArray& other, const std::vector& dimensions, const ExtraArguments* extraParams) const { - - if (isS()) - throw std::runtime_error("NDArray::applyReduce3: you can't use this method on String array!"); - if(dataType() != other.dataType()) - throw std::runtime_error("NDArray::applyReduce3 method: the types of this and other arrays must be the same !"); - - std::vector copy(dimensions); - shape::checkDimensions(rankOf(), copy); - shape::checkDimensions(other.rankOf(), copy); - - auto newShape = ShapeUtils::evalReduceShapeInfo('c', copy, *this, DataTypeUtils::pickFloatingType(dataType()), false, false, getContext()->getWorkspace()); - NDArray result(newShape, true, getContext()); - // create temporary dynamic array of extra parameters if array extraParams is empty (==nullptr) - void* params = extraParams != nullptr ? const_cast(extraParams)->argumentsAsT(dataType()) : nullptr; - - NDArray::prepareSpecialUse({&result}, {this, &other}); - - // perform calculations - if(rankOf() == copy.size() && other.rankOf() == copy.size()) { - NativeOpExecutioner::execReduce3Scalar(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), params, other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo()); - } - else { - - auto pDims = sd::Environment::getInstance().isCPU() ? copy.data() : nullptr; - - auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(shapeInfo(), copy); - auto packY = sd::ConstantTadHelper::getInstance().tadForDimensions(other.shapeInfo(), copy); - - if(!shape::equalsSoft(packX.primaryShapeInfo(), packY.primaryShapeInfo()) || (packX.numberOfTads() != packY.numberOfTads() && packX.numberOfTads() != 1 && packY.numberOfTads() != 1)) - throw std::runtime_error("NDArray::applyReduce3 cuda method: arrays tads are inconsistent !"); - - NativeOpExecutioner::execReduce3(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), params, other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo(), pDims, copy.size(), packX.platformShapeInfo(), packX.platformOffsets(), packY.platformShapeInfo(), packY.platformOffsets()); - } - - registerSpecialUse({&result}, {this, &other}); - - return result; -} - -//////////////////////////////////////////////////////////////////////// -// apply reduce3 (execAll) operations to this and other array, return result in new output array -NDArray NDArray::applyAllReduce3(sd::reduce3::Ops op, const NDArray& other, const std::vector& dimensions, const ExtraArguments* extraParams) const { - if (isS()) - throw std::runtime_error("NDArray::applyAllReduce3: you can't use this method on String array!"); - if(dataType() != other.dataType()) - throw std::runtime_error("NDArray::applyAllReduce3 method: the types of this and other arrays must be the same !"); - - // be careful, copy array may undergo changes (sort, transformation of negative dimensions to positive, duplicates removing ) - std::vector copy(dimensions); - shape::checkDimensions(rankOf(), copy); - shape::checkDimensions(other.rankOf(), copy); - - auto packX = ConstantTadHelper::getInstance().tadForDimensions(shapeInfo(), copy); - auto packY = ConstantTadHelper::getInstance().tadForDimensions(other.shapeInfo(), copy); - - // check tads shapes - if(!shape::equalsSoft(packX.primaryShapeInfo(), packY.primaryShapeInfo())) - throw std::runtime_error("NDArray::applyAllReduce3 method: the shapes of array tads are different !"); - - // set newShape for output array - auto newShape = ConstantShapeHelper::getInstance().createShapeInfo(DataTypeUtils::pickFloatingType(dataType()), 'c', {packX.numberOfTads(), packY.numberOfTads()}); - - // create output array - NDArray result(newShape, true, getContext()); - - // create dynamic array of extra parameters if array extraParams is empty (==nullptr) - void* params = extraParams != nullptr ? const_cast(extraParams)->argumentsAsT(dataType()) : nullptr; - - auto pDims = sd::Environment::getInstance().isCPU() ? copy.data() : nullptr; - - NDArray::prepareSpecialUse({&result}, {this, &other}); - NativeOpExecutioner::execReduce3All(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), params, other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo(), pDims, copy.size(), packX.platformShapeInfo(), packX.platformOffsets(), packY.platformShapeInfo(), packY.platformOffsets()); - NDArray::registerSpecialUse({&result}, {this, &other}); - - return result; -} - -////////////////////////////////////////////////////////////////////////// -// method reduces array by excluding its shapes along axes present in dimensions vector -void NDArray::reduceAlongDimension(sd::reduce::FloatOps op, NDArray& target, const std::vector& dimensions, const bool keepDims, const bool supportOldShapes, const bool checkTargetShape) const { - - if (isS()) - throw std::runtime_error("NDArray::reduceAlongDimension FloatOps: you can't use this method on String array!"); - if (!target.isR()) - throw std::invalid_argument("NDArray::reduceAlongDimension FloatOps: requires target array to be present and have type form real space!"); - - std::vector copy(dimensions); - - if(checkTargetShape) { - auto newShape = ShapeUtils::evalReduceShapeInfo(target.ordering(), copy, *this, keepDims, supportOldShapes, getContext()->getWorkspace()); - if(!shape::shapeEquals(newShape, target.shapeInfo())) - throw std::runtime_error("NDArray::reduceAlongDimension FloatOps: wrong target shape!"); - } - - NDArray::prepareSpecialUse({&target}, {this}); - - if(rankOf() == copy.size() || copy.empty()) { - NativeOpExecutioner::execReduceFloatScalar(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(),nullptr, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo()); - } - else { - auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(shapeInfo(), copy); - NativeOpExecutioner::execReduceFloat(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), copy.data(), copy.size(), packX.platformShapeInfo(), packX.platformOffsets()); - } - synchronize("NDArray::reduceAlongDimension FloatOps"); - - NDArray::registerSpecialUse({&target}, {this}); -} - -////////////////////////////////////////////////////////////////////////// -// method reduces array by excluding its shapes along axes present in dimensions vector -void NDArray::reduceAlongDimension(sd::reduce::SameOps op, NDArray& target, const std::vector& dimensions, const bool keepDims, const bool supportOldShapes, const bool checkTargetShape) const { - - if (isS()) - throw std::runtime_error("NDArray::reduceAlongDimension SameOps: you can't use this method on String array!"); - if (target.dataType() != dataType()) - throw std::runtime_error("NDArray::reduceAlongDimension SameOps: requires target array to be present and have same dtype as input"); - - std::vector copy(dimensions); - - if(checkTargetShape) { - auto newShape = ShapeUtils::evalReduceShapeInfo(target.ordering(), copy, *this, keepDims, supportOldShapes, getContext()->getWorkspace()); - if(!shape::shapeEquals(newShape, target.shapeInfo())) - throw std::runtime_error("NDArray::reduceAlongDimension SameOps: wrong target shape!"); - } - - NDArray::prepareSpecialUse({&target}, {this}); - - if(rankOf() == copy.size() || copy.empty()) { - NativeOpExecutioner::execReduceSameScalar(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo()); - } - else { //if (!isEmpty()) { - auto pDims = sd::Environment::getInstance().isCPU() ? copy.data() : nullptr; - auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(this->shapeInfo(), copy); - NativeOpExecutioner::execReduceSame(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), pDims, copy.size(), packX.platformShapeInfo(), packX.platformOffsets()); - } - synchronize("NDArray::reduceAlongDimension SameOps"); - - NDArray::registerSpecialUse({&target}, {this}); -} - -////////////////////////////////////////////////////////////////////////// -// method reduces array by excluding its shapes along axes present in dimensions vector -void NDArray::reduceAlongDimension(sd::reduce::LongOps op, NDArray& target, const std::vector& dimensions, const bool keepDims, const bool supportOldShapes, const bool checkTargetShape) const { - - if (isS()) - throw std::runtime_error("NDArray::reduceAlongDimension LongOps: you can't use this method on String array!"); - if (target.dataType() != DataType::INT64) - throw std::runtime_error("NDArray::reduceAlongDimension LongOps: requires target array to be present and have type of INT64"); - - std::vector copy(dimensions); - - if(checkTargetShape) { - auto newShape = ShapeUtils::evalReduceShapeInfo(target.ordering(), copy, *this, keepDims, supportOldShapes, getContext()->getWorkspace()); - if(!shape::shapeEquals(newShape, target.shapeInfo())) - throw std::runtime_error("NDArray::reduceAlongDimension LongOps: wrong target shape!"); - } - - NDArray::prepareSpecialUse({&target}, {this}); - - if(rankOf() == copy.size() || copy.empty()) { - NativeOpExecutioner::execReduceLongScalar(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo()); - } - else { - auto pDims = sd::Environment::getInstance().isCPU() ? copy.data() : nullptr; - auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(this->shapeInfo(), copy); - NativeOpExecutioner::execReduceLong(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), pDims, copy.size(), packX.platformShapeInfo(), packX.platformOffsets()); - } - synchronize("NDArray::reduceAlongDimension LongOps"); - - NDArray::registerSpecialUse({&target}, {this}); -} - -////////////////////////////////////////////////////////////////////////// -// method reduces array by excluding its shapes along axes present in dimensions vector -void NDArray::reduceAlongDimension(sd::reduce::BoolOps op, NDArray& target, const std::vector& dimensions, const bool keepDims, const bool supportOldShapes, const bool checkTargetShape) const { - - if (isS()) - throw std::runtime_error("NDArray::reduceAlongDimension BoolOps cuda: you can't use this method on String array!"); - if (!target.isB()) - throw std::invalid_argument("NDArray::reduceAlongDimension BoolOps cuda: requires target array to be present and have BOOL type!"); - - std::vector copy(dimensions); - - if(checkTargetShape) { - auto newShape = ShapeUtils::evalReduceShapeInfo(target.ordering(), copy, *this, keepDims, supportOldShapes, getContext()->getWorkspace()); - if(!shape::shapeEquals(newShape, target.shapeInfo())) - throw std::runtime_error("NDArray::reduceAlongDimension BoolOps cuda: wrong target shape!"); - } - - NDArray::prepareSpecialUse({&target}, {this}); - - if(rankOf() == copy.size() || copy.empty()) { - NativeOpExecutioner::execReduceBoolScalar(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo()); - } - else { - auto pDims = sd::Environment::getInstance().isCPU() ? copy.data() : nullptr; - auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(this->shapeInfo(), copy); - NativeOpExecutioner::execReduceBool(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), pDims, copy.size(), packX.platformShapeInfo(), packX.platformOffsets()); - } - synchronize("NDArray::reduceAlongDimension LongOps"); - - NDArray::registerSpecialUse({&target}, {this}); -} - -////////////////////////////////////////////////////////////////////////// -// This method sets value in linear buffer to position i -template -void NDArray::p(const Nd4jLong i, const T value) { - - if (i >= lengthOf()) - throw std::invalid_argument("NDArray::p(i, value): input index is out of array length !"); - - auto rp = getOffset(i); - const void *pV = reinterpret_cast(const_cast(&value)); - - NDArray::preparePrimaryUse({this}, {}, true); - BUILD_SINGLE_PARTIAL_SELECTOR(this->dataType(), templatedSet<, T>(this->buffer(), rp, pV), LIBND4J_TYPES); - NDArray::registerPrimaryUse({this}, {}); -} - -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const double value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const float value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const float16 value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const bfloat16 value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const int value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const int8_t value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const uint8_t value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const uint16_t value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const uint32_t value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const uint64_t value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const int16_t value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const bool value); - -////////////////////////////////////////////////////////////////////////// -// This method sets value in 2D matrix to position i, j -template -void NDArray::p(const Nd4jLong i, const Nd4jLong j, const T value) { - - if (rankOf() != 2 || i >= shapeOf()[0] || j >= shapeOf()[1]) - throw std::invalid_argument("NDArray:pe(i,j, value): one of input indexes is out of array length or rank!=2 !"); - - void *p = reinterpret_cast(const_cast(&value)); - auto xOffset = i * strideAt(0) + j * strideAt(1); - - NDArray::preparePrimaryUse({this}, {}, true); - BUILD_SINGLE_PARTIAL_SELECTOR(dataType(), templatedSet<, T>(this->buffer(), xOffset, p), LIBND4J_TYPES); - NDArray::registerPrimaryUse({this}, {}); -} -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const double value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const float value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const float16 value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const bfloat16 value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const int value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const int8_t value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const uint8_t value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const uint16_t value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const uint32_t value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const uint64_t value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const int16_t value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const bool value); - -////////////////////////////////////////////////////////////////////////// -// This method sets value in 3D matrix to position i,j,k -template -void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const T value) { - //(*this)(i,j,k) = value; - if (rankOf() != 3 || i >= shapeOf()[0] || j >= shapeOf()[1] || k >= shapeOf()[2]) - throw std::invalid_argument("NDArray:pe(i,j,k, value): one of input indexes is out of array length or rank!=3 !"); - - void *p = reinterpret_cast(const_cast(&value)); - auto xOffset = i * strideAt(0) + j * strideAt(1) + k * strideAt(2); - - NDArray::preparePrimaryUse({this}, {}, true); - BUILD_SINGLE_PARTIAL_SELECTOR(dataType(), templatedSet<, T>(this->buffer(), xOffset, p), LIBND4J_TYPES); - NDArray::registerPrimaryUse({this}, {}); -} -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const double value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const float value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const float16 value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const bfloat16 value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const int value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const int8_t value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const uint8_t value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const uint16_t value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const uint32_t value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const uint64_t value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const int16_t value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const bool value); - -////////////////////////////////////////////////////////////////////////// -template -void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const T value) { - //(*this)(i,j,k) = value; - if (rankOf() != 4 || i >= shapeOf()[0] || j >= shapeOf()[1] || k >= shapeOf()[2] || l >= shapeOf()[3]) - throw std::invalid_argument("NDArray::p(i,j,k,l, value): one of input indexes is out of array length or rank!=4 !"); - - void *p = reinterpret_cast(const_cast(&value)); - auto xOffset = i * strideAt(0) + j * strideAt(1) + k * strideAt(2) + l * strideAt(3); - - NDArray::preparePrimaryUse({this}, {}, true); - BUILD_SINGLE_PARTIAL_SELECTOR(dataType(), templatedSet<, T>(this->buffer(), xOffset, p), LIBND4J_TYPES); - NDArray::registerPrimaryUse({this}, {}); -} -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const double value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const float value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const float16 value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const bfloat16 value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const Nd4jLong value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const int value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const int8_t value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const uint8_t value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const uint16_t value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const uint32_t value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const uint64_t value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const int16_t value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const bool value); - -//////////////////////////////////////////////////////////////////////// -void NDArray::p(const Nd4jLong i, const NDArray& scalar) { - - if(scalar.lengthOf() != 1) - throw std::invalid_argument("NDArray::p method: input array must be scalar!"); - if (i >= _length) - throw std::invalid_argument("NDArray::p(i, NDArray_scalar): input index is out of array length !"); - - NDArray::preparePrimaryUse({this}, {&scalar}, true); - auto rp = getOffset(i); - BUILD_SINGLE_SELECTOR(scalar.dataType(), templatedSet, (buffer(), rp, scalar.dataType(), scalar.buffer()), LIBND4J_TYPES); - NDArray::registerPrimaryUse({this}, {&scalar}); -} - -//////////////////////////////////////////////////////////////////////// - void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const NDArray& scalar) { - - if(scalar.lengthOf() != 1) - throw std::invalid_argument("NDArray::p method: input array must be scalar!"); - if (i >= _length) - throw std::invalid_argument("NDArray::p(i, NDArray_scalar): input index is out of array length !"); - -// void *p = reinterpret_cast(scalar.buffer()); - Nd4jLong coords[4] = {i, j, k, l}; - auto xOffset = shape::getOffset(shapeInfo(), coords); - - NDArray::preparePrimaryUse({this}, {&scalar}, true); -// BUILD_SINGLE_PARTIAL_SELECTOR(dataType(), templatedSet<, T>(this->buffer(), xOffset, p), LIBND4J_TYPES); - BUILD_SINGLE_SELECTOR(scalar.dataType(), templatedSet, (this->buffer(), xOffset, scalar.dataType(), scalar.buffer()), LIBND4J_TYPES); - NDArray::registerPrimaryUse({this}, {&scalar}); - } - -////////////////////////////////////////////////////////////////////////// -void NDArray::addRowVector(const NDArray& row, NDArray& target) const { - - if (isS()) - throw std::runtime_error("NDArray::addRowVector: you can't use this method on String array!"); - if (rankOf() != 2 || target.rankOf() != 2 || rows() != target.rows() || columns() != target.columns() || !row.isRowVector() || columns() != row.lengthOf()) - throw std::invalid_argument("NDArray::addRowVector: wrong arguments !"); - if(target.dataType() != DataTypeUtils::pickPairwiseResultType(dataType(), row.dataType()) && !(isR() && row.isR() && target.isR())) - throw std::invalid_argument("NDArray::addRowVector: wrong type of target array !"); - - int dimension = 1; - - auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(this->shapeInfo(), dimension); - - NDArray::prepareSpecialUse({&target}, {this, &row}); - NativeOpExecutioner::execBroadcast(getContext(), sd::broadcast::Ops::Add, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), row.buffer(), row.shapeInfo(), row.specialBuffer(), row.specialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), nullptr, 1, packX.platformShapeInfo(), packX.platformOffsets(), nullptr, nullptr); - NDArray::registerSpecialUse({&target}, {this, &row}); -} - -////////////////////////////////////////////////////////////////////////// -void NDArray::subRowVector(const NDArray& row, NDArray& target) const { - - if (isS()) - throw std::runtime_error("NDArray::addRowVector: you can't use this method on String array!"); - if (rankOf() != 2 || target.rankOf() != 2 || rows() != target.rows() || columns() != target.columns() || !row.isRowVector() || columns() != row.lengthOf()) - throw std::invalid_argument("NDArray::addRowVector: wrong arguments !"); - if(target.dataType() != DataTypeUtils::pickPairwiseResultType(dataType(), row.dataType()) && !(isR() && row.isR() && target.isR())) - throw std::invalid_argument("NDArray::addRowVector: wrong type of target array !"); - - int dimension = 1; - - auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(this->shapeInfo(), dimension); - - NDArray::prepareSpecialUse({&target}, {this, &row}); - NativeOpExecutioner::execBroadcast(getContext(), sd::broadcast::Ops::Subtract, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), row.buffer(), row.shapeInfo(), row.specialBuffer(), row.specialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), &dimension, 1, packX.platformShapeInfo(), packX.platformOffsets(), nullptr, nullptr); - NDArray::registerSpecialUse({&target}, {this, &row}); -} - - -////////////////////////////////////////////////////////////////////////// -void NDArray::mulRowVector(const NDArray &row, NDArray &target) const { - - if (isS()) - throw std::runtime_error("NDArray::mulRowVector: you can't use this method on String array!"); - if (rankOf() != 2 || target.rankOf() != 2 || rows() != target.rows() || columns() != target.columns() || !row.isRowVector() || columns() != row.columns()) - throw std::invalid_argument("NDArray::divRowVector: wrong arguments !"); - if(target.dataType() != DataTypeUtils::pickPairwiseResultType(dataType(), row.dataType())) - throw std::invalid_argument("NDArray::mulRowVector: wrong type of target array !"); - - int dimension = 1; - - auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(this->shapeInfo(), dimension); - - NDArray::prepareSpecialUse({&target}, {this, &row}); - NativeOpExecutioner::execBroadcast(getContext(), sd::broadcast::Ops::Multiply, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), row.buffer(), row.shapeInfo(), row.specialBuffer(), row.specialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), nullptr, 1, packX.platformShapeInfo(), packX.platformOffsets(), nullptr, nullptr); - NDArray::registerSpecialUse({&target}, {this, &row}); -} - -////////////////////////////////////////////////////////////////////////// -void NDArray::divRowVector(const NDArray &row, NDArray &target) const { - - if (isS()) - throw std::runtime_error("NDArray::divRowVector: you can't use this method on String array!"); - if (row.isB()) - throw std::runtime_error("NDArray::divRowVector: you can't divide by bool row!"); - if (rankOf() != 2 || target.rankOf() != 2 || rows() != target.rows() || columns() != target.columns() || !row.isRowVector() || columns() != row.columns()) - throw std::invalid_argument("NDArray::divRowVector: wrong arguments !"); - if(target.dataType() != DataTypeUtils::pickPairwiseResultType(dataType(), row.dataType())) - throw std::invalid_argument("NDArray::divRowVector: wrong type of target array !"); - - int dimension = 1; - - auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(this->shapeInfo(), dimension); - - NDArray::prepareSpecialUse({&target}, {this, &row}); - NativeOpExecutioner::execBroadcast(getContext(), sd::broadcast::Divide, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), row.buffer(), row.shapeInfo(), row.specialBuffer(), row.specialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), nullptr, 1, packX.platformShapeInfo(), packX.platformOffsets(), nullptr, nullptr); - NDArray::registerSpecialUse({&target}, {this, &row}); -} - -////////////////////////////////////////////////////////////////////////// -// This method adds given row to all rows in this NDArray, this array becomes affected -void NDArray::addiRowVector(const NDArray& row) { - - if (isS()) - throw std::runtime_error("NDArray::addiRowVector: you can't use this method on String array!"); - if (rankOf() != 2 || !row.isRowVector() || columns() != row.lengthOf()) - throw std::invalid_argument("NDArray::addiRowVector: wrong arguments !"); - - int dimension = 1; - - auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(this->shapeInfo(), dimension); - - NDArray::prepareSpecialUse({this}, {&row}); - NativeOpExecutioner::execBroadcast(getContext(), sd::broadcast::Ops::Add, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), row.buffer(), row.shapeInfo(), row.specialBuffer(), row.specialShapeInfo(), this->buffer(), this->shapeInfo(), this->specialBuffer(), this->specialShapeInfo(), nullptr, 1, packX.platformShapeInfo(), packX.platformOffsets(), nullptr, nullptr); - NDArray::registerSpecialUse({this}, {&row}); -} - -////////////////////////////////////////////////////////////////////////// -void NDArray::addColumnVector(const NDArray &column, NDArray &target) const { - if (isS()) - throw std::runtime_error("NDArray::addColumnVector: you can't use this method on String array!"); - if (rankOf() != 2 || target.rankOf() != 2 || rows() != target.rows() || columns() != target.columns() || !column.isColumnVector() || rows() != column.lengthOf()) - throw std::invalid_argument("NDArray::addColumnVector: wrong arguments !"); - if(target.dataType() != DataTypeUtils::pickPairwiseResultType(dataType(), column.dataType())) - throw std::invalid_argument("NDArray::addColumnVector: wrong type of target array !"); - - int dimension = 0; - - auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(this->shapeInfo(), dimension); - - NDArray::prepareSpecialUse({&target}, {this, &column}); - NativeOpExecutioner::execBroadcast(getContext(), sd::broadcast::Ops::Add, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), column.buffer(), column.shapeInfo(), column.specialBuffer(), column.specialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), nullptr, 1, packX.platformShapeInfo(), packX.platformOffsets(), nullptr, nullptr); - NDArray::registerSpecialUse({&target}, {this, &column}); -} - -////////////////////////////////////////////////////////////////////////// -// This method adds given column to all columns in this NDArray, this array becomes affected -void NDArray::addiColumnVector(const NDArray &column) { - if (isS()) - throw std::runtime_error("NDArray::addiColumnVector: you can't use this method on String array!"); - if (rankOf() != 2 || !column.isColumnVector() || rows() != column.lengthOf()) - throw std::invalid_argument("NDArray::addiColumnVector: wrong arguments !"); - - int dimension = 0; - - auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(this->shapeInfo(), dimension); - - NDArray::prepareSpecialUse({this}, {&column}); - NativeOpExecutioner::execBroadcast(getContext(), sd::broadcast::Ops::Add, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), column.buffer(), column.shapeInfo(), column.specialBuffer(), column.specialShapeInfo(), this->buffer(), this->shapeInfo(), this->specialBuffer(), this->specialShapeInfo(), nullptr, 1, packX.platformShapeInfo(), packX.platformOffsets(), nullptr, nullptr); - NDArray::registerSpecialUse({this}, {&column}); -} - -////////////////////////////////////////////////////////////////////////// -// This method multiplies each column of this array by given argument-column, this array becomes affected -void NDArray::muliColumnVector(const NDArray& column) { - if (isS()) - throw std::runtime_error("NDArray::muliColumnVector: you can't use this method on String array!"); - if (rankOf() != 2 || !column.isColumnVector() || rows() != column.lengthOf()) - throw std::invalid_argument("NDArray::muliColumnVector: wrong arguments !"); - - int dimension = 0; - - auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(this->shapeInfo(), dimension); - - NDArray::prepareSpecialUse({this}, {&column}); - NativeOpExecutioner::execBroadcast(getContext(), sd::broadcast::Ops::Multiply, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), column.buffer(), column.shapeInfo(), column.specialBuffer(), column.specialShapeInfo(), this->buffer(), this->shapeInfo(), this->specialBuffer(), this->specialShapeInfo(), nullptr, 1, packX.platformShapeInfo(), packX.platformOffsets(), nullptr, nullptr); - NDArray::registerSpecialUse({this}, {&column}); -} - -////////////////////////////////////////////////////////////////////////// -template -void NDArray::templatedAssign(void *xBuffer, Nd4jLong xOffset, const void *yBuffer, const Nd4jLong yOffset) const { - if (xBuffer != nullptr && yBuffer != nullptr) - *(reinterpret_cast(xBuffer) + xOffset) = *(reinterpret_cast(yBuffer) + yOffset); -} -BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT void NDArray::templatedAssign, (void *xBuffer, const Nd4jLong xOffset, const void *yBuffer, const Nd4jLong yOffset) const, LIBND4J_TYPES); - - -////////////////////////////////////////////////////////////////////////// -bool NDArray::permutei(const int* dimensions, const int rank) { - - auto shapeInfo = ShapeUtils::evalPermShapeInfo(dimensions, rank, *this, getContext()->getWorkspace()); - setShapeInfo(shapeInfo); - - return true; -} - -////////////////////////////////////////////////////////////////////////// -bool NDArray::permutei(const Nd4jLong* dimensions, const int rank) { - - auto shapeInfo = ShapeUtils::evalPermShapeInfo(dimensions, rank, *this, getContext()->getWorkspace()); - setShapeInfo(shapeInfo); - - return true; -} - -//////////////////////////////////////////////////////////////////////// -ResultSet NDArray::multipleTensorsAlongDimension(const std::vector &indices, const std::vector &dimensions) const { - ResultSet result; - - if (indices.size() == 0) - return result; - - auto pack = ConstantTadHelper::getInstance().tadForDimensions(shapeInfo(), const_cast(dimensions.data()), dimensions.size()); - - auto tadLength = shape::length(pack.primaryShapeInfo()); - auto numTads = lengthOf() / tadLength; - - for (auto idx: indices) { - if (idx >= numTads) { - nd4j_printf("NDArray::multipleTensorsAlongDimension: index %i is higher then number of TADs: %i\n", idx, numTads); - throw std::runtime_error("Bad index"); - } - - auto array = new NDArray(getDataBuffer(), ShapeDescriptor(pack.primaryShapeInfo()), getContext(), pack.primaryOffsets()[idx] + bufferOffset()); - result.push_back(array); - } - - return result; -} - -//////////////////////////////////////////////////////////////////////// -ResultSet NDArray::allTensorsAlongDimension(const std::initializer_list& dimensions) const { - return allTensorsAlongDimension(std::vector(dimensions)); -} - -//////////////////////////////////////////////////////////////////////// -ResultSet NDArray::allExamples() const { - std::vector dimensions(rankOf() - 1); - for (int e = 1; e < rankOf(); e++) - dimensions[e-1] = e; - - return allTensorsAlongDimension(dimensions); -} - -//////////////////////////////////////////////////////////////////////// -Nd4jLong NDArray::getOffset(const Nd4jLong i) const { - - if (i >= lengthOf()) - throw std::invalid_argument("NDArray::getOffset: input index is out of array length !"); - - return shape::getIndexOffset(i, _shapeInfo); -} - -//////////////////////////////////////////////////////////////////////// -NDArray NDArray::like() { - - return NDArray(shapeInfo(), this->dataType(), false, getContext()); -} - -//////////////////////////////////////////////////////////////////////// -NDArray NDArray::ulike() const{ - - return NDArray(this, false, getContext()); -} - -//////////////////////////////////////////////////////////////////////// -NDArray NDArray::diagonal(const char type) const { - - if (isS()) - throw std::runtime_error("NDArray::diagonal: you can't use this method on String array!"); - - const char order = ordering(); - const int rank = rankOf(); - Nd4jLong *outShapeInfo; - ALLOCATE(outShapeInfo, getContext()->getWorkspace(), 8, Nd4jLong); - outShapeInfo[0] = 2; - outShapeInfo[5] = 0; - - if(isVector() || isScalar()) { - - outShapeInfo[1] = outShapeInfo[2] = outShapeInfo[3] = outShapeInfo[4] = 1; - outShapeInfo[6] = 1; - outShapeInfo[7] = (int)order; - } - else { - - int diagSize = 100000000; - Nd4jLong indices[MAX_RANK]; - - for(int i = 0; i < rank; ++i) { - if(diagSize > shapeOf()[i]) - diagSize = shapeOf()[i]; - indices[i] = 1; - } - - auto step = shape::getOffset(shapeInfo(), indices); - - if(type == 'c') { - outShapeInfo[1] = diagSize; - outShapeInfo[2] = 1; - } - else { - outShapeInfo[1] = 1; - outShapeInfo[2] = diagSize; - } - shape::updateStrides(outShapeInfo, order); - - outShapeInfo[3] *= step; - outShapeInfo[4] *= step; - outShapeInfo[6] = 0; - } - - ArrayOptions::setDataType(outShapeInfo, this->dataType()); - - NDArray result(_buffer, ShapeDescriptor(outShapeInfo), getContext(), bufferOffset()); - - RELEASE(outShapeInfo, getContext()->getWorkspace()); - - return result; -} - -//////////////////////////////////////////////////////////////////////// -ResultSet NDArray::allTensorsAlongDimension(const std::vector &dimensions) const { - - ResultSet result; - - if(dimensions.size() == 0) - return result; - - if(dimensions.back() >= rankOf()) - throw std::runtime_error("NDArray::allTensorsAlongDimension static function: all input dimensions must be smaller than rank of input array !"); - - - auto pack = ConstantTadHelper::getInstance().tadForDimensions(_shapeInfo, const_cast(dimensions.data()), dimensions.size()); - auto numTads = pack.numberOfTads(); - - for (Nd4jLong idx = 0; idx < numTads; idx++ ) { - auto array = new NDArray(_buffer, ShapeDescriptor(pack.primaryShapeInfo()), getContext(), pack.primaryOffsets()[idx] + bufferOffset()); - array->_isView = true; - result.push_back(array); - } - - return result; -} - -//////////////////////////////////////////////////////////////////////// -// operator returns sub-array with buffer pointing at this->_buffer + certain offset -NDArray NDArray::operator()(const std::vector& idx, const bool keepUnitiesInShape, const bool isStrided) const { - - if(isEmpty()) - throw std::invalid_argument("NDArray::operator(sub-arrays): array is empty !"); - - // Nd4jLong *outShapeInfo = nullptr; - // ALLOCATE(outShapeInfo, workspace, shape::shapeInfoLength(inShapeInfo), Nd4jLong); - - int numOfUntiesInSubArrShape = 0; - - Nd4jLong* subArrShapeInfo = nullptr; - - if(!keepUnitiesInShape) { - - int n(isStrided ? 3 : 2), first, last; - - // calculate the number of unities in shape - for (uint d = 0; d < rankOf(); ++d) { - - if (idx[n * d] != idx[n * d + 1]) { - - first = idx[n * d] >= 0 ? idx[n * d] : idx[n * d] + sizeAt(d) + 1; - last = idx[n * d + 1] >= 0 ? idx[n * d + 1] : idx[n * d + 1] + sizeAt(d) + 1; - if(last - first == 1) - ++numOfUntiesInSubArrShape; - } - } - } - - ALLOCATE(subArrShapeInfo, getContext()->getWorkspace(), shape::shapeInfoLength(rankOf() - numOfUntiesInSubArrShape), Nd4jLong); - - Nd4jLong offset; - - shape::calcSubArrShapeInfoAndOffset(idx.data(), shapeInfo(), subArrShapeInfo, offset, keepUnitiesInShape, isStrided, numOfUntiesInSubArrShape); - - NDArray result(_buffer, ShapeDescriptor(subArrShapeInfo), getContext(), offset + bufferOffset()); - result._isView = true; - - RELEASE(subArrShapeInfo, getContext()->getWorkspace()); - - return result; -} - -//////////////////////////////////////////////////////////////////////// -NDArray NDArray::operator()(const Nd4jLong subArrIdx, const std::vector& dimsToExclude, bool keepUnitiesInShape) const { - - std::vector idxRanges(2 * rankOf()); - - const auto rank = rankOf(); - const auto subArrRank = static_cast(dimsToExclude.size()); - - if(subArrRank > rank) - throw std::invalid_argument("NDArray::operator(const Nd4jLong subArrIdx, const std::vector& dimsToExclude, bool keepUnitiesInShape): static method: dimsToExclude is empty or has size > rank of array !"); - - memset(idxRanges.data(), 0, 2 * rank * sizeof(Nd4jLong)); - - // subArrRank == 0 means whole array, idxRanges should contain zeros only - - if(subArrRank != 0) { - - std::vector shapeOfSubArr(subArrRank), indexes(subArrRank); - for(int i = 0; i < subArrRank; ++i) - shapeOfSubArr[i] = sizeAt(dimsToExclude[i]); - - shape::index2coords(subArrIdx, subArrRank, shapeOfSubArr.data(), indexes.data()); - - for(int i = 0; i < subArrRank; ++i) { - int currIdx = 2 * dimsToExclude[i]; - idxRanges[currIdx] = indexes[i]; - idxRanges[currIdx + 1] = indexes[i] + 1; - } - } - - return (*this)(idxRanges, keepUnitiesInShape); -} - -//////////////////////////////////////////////////////////////////////// -void NDArray::getSubArrShapeAndOffsets(const std::vector& dimsToExclude, Nd4jLong* &subArrShapeInfo, Nd4jLong* &subArrOffsets, bool keepUnitiesInShape) const { - - if(isEmpty()) - throw std::invalid_argument("NDArray::getSubArrShapeAndOffsets: array is empty !"); - - const int rank = rankOf(); - const int subArrRank = (rank == dimsToExclude.size() || keepUnitiesInShape) ? rank : rank - dimsToExclude.size(); - const Nd4jLong numOfSubArrs = ShapeUtils::getNumOfSubArrs(_shapeInfo, dimsToExclude); - - // allocate memory - ALLOCATE(subArrShapeInfo, getContext()->getWorkspace(), shape::shapeInfoLength(subArrRank), Nd4jLong); - ALLOCATE(subArrOffsets, getContext()->getWorkspace(), numOfSubArrs, Nd4jLong); - - shape::calcSubArrsShapeInfoAndOffsets(_shapeInfo, numOfSubArrs, dimsToExclude.size(), dimsToExclude.data(), subArrShapeInfo, subArrOffsets, keepUnitiesInShape); -} - -////////////////////////////////////////////////////////////////////////// -void NDArray::setShapeInfo(const Nd4jLong *shapeInfo) { - - if (shapeInfo != nullptr) { - - ShapeDescriptor descriptor(shapeInfo); - auto shapeBuffer = ConstantShapeHelper::getInstance().bufferForShapeInfo(descriptor); - - _shapeInfo = shapeBuffer.primary(); - #ifdef __CUDABLAS__ - _shapeInfoD = shapeBuffer.special(); - #endif - - if(ArrayOptions::arrayType(_shapeInfo) == ArrayType::EMPTY) - _length = 0; - else - _length = shape::length(_shapeInfo); - - _dataType = ArrayOptions::dataType(_shapeInfo); - } - else { - _dataType = sd::DataType::INHERIT; - _shapeInfoD = _shapeInfo = nullptr; - } -} - -//////////////////////////////////////////////////////////////////////// -void NDArray::setShapeInfo(const Nd4jLong *shapeInfo, const sd::DataType dtype) { - - if (shapeInfo != nullptr) { - - Nd4jLong* shapeInfoTemp = ShapeBuilders::copyShapeInfoAndType(shapeInfo, dtype, true, getContext()->getWorkspace()); - ShapeDescriptor descriptor(shapeInfoTemp); - auto shapeBuffer = ConstantShapeHelper::getInstance().bufferForShapeInfo(descriptor); - - _shapeInfo = shapeBuffer.primary(); - #ifdef __CUDABLAS__ - _shapeInfoD = shapeBuffer.special(); - #endif - - if(ArrayOptions::arrayType(_shapeInfo) == ArrayType::EMPTY) - _length = 0; - else - _length = shape::length(_shapeInfo); - - _dataType = dtype; - } - else { - _dataType = sd::DataType::INHERIT; - _shapeInfoD = _shapeInfo = nullptr; - } -} - -////////////////////////////////////////////////////////////////////////// -void NDArray::setShapeInfo(const ShapeDescriptor& descriptor) { - - auto shapeBuffer = ConstantShapeHelper::getInstance().bufferForShapeInfo(const_cast(descriptor)); - - _shapeInfo = shapeBuffer.primary(); - #ifdef __CUDABLAS__ - _shapeInfoD = shapeBuffer.special(); - #endif - - if(ArrayOptions::arrayType(_shapeInfo) == ArrayType::EMPTY) - _length = 0; - else - _length = shape::length(_shapeInfo); - - _dataType = ArrayOptions::dataType(_shapeInfo); -} - -////////////////////////////////////////////////////////////////////////// -void NDArray::setShapeInfo(const ConstantShapeBuffer& shapeBuffer) { - - _shapeInfo = shapeBuffer.primary(); - #ifdef __CUDABLAS__ - _shapeInfoD = shapeBuffer.special(); - #endif - - if(ArrayOptions::arrayType(_shapeInfo) == ArrayType::EMPTY) - _length = 0; - else - _length = shape::length(_shapeInfo); - - _dataType = ArrayOptions::dataType(_shapeInfo); -} - -/////////////////////////////////////////////////////////////////////// -// addition operator array + scalar -template -NDArray operator+(NDArray&& arr, const T& scalar) { - - if(arr.isView()) // do not use resources of arrays which use buffers of other original arrays - return std::move(arr + scalar); // arr is lvalue inside function body - - if (arr.isS()) - throw std::runtime_error("operator+(NDArray&& arr, const T& scalar): you can't use this method on String array!"); - if (arr.dataType() != DataTypeUtils::pickPairwiseResultType(arr.dataType(), DataTypeUtils::fromT())) - throw std::runtime_error("operator+(NDArray&& arr, const T& scalar): you can't use this method on String array!"); - - auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext()); - - NDArray::prepareSpecialUse({&arr}, {&arr, &tmp}); - NativeOpExecutioner::execScalar(arr.getContext(), sd::scalar::Add, arr.buffer(), arr.shapeInfo(), arr.specialBuffer(), arr.specialShapeInfo(), arr.buffer(), arr.shapeInfo(), arr.specialBuffer(), arr.specialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({&arr}, {&arr, &tmp}); - - return std::move(arr); -} -template ND4J_EXPORT NDArray operator+(NDArray&& arr, const double& scalar); -template ND4J_EXPORT NDArray operator+(NDArray&& arr, const float& scalar); -template ND4J_EXPORT NDArray operator+(NDArray&& arr, const float16& scalar); -template ND4J_EXPORT NDArray operator+(NDArray&& arr, const bfloat16& scalar); -template ND4J_EXPORT NDArray operator+(NDArray&& arr, const int& scalar); - -//////////////////////////////////////////////////////////////////////// -template -NDArray operator+(const NDArray& arr, const T& scalar) { - - if (arr.isS()) - throw std::runtime_error("operator+(const NDArray& arr, const T& scalar): you can't use this method on String array!"); - - auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext()); - NDArray result(arr.shapeInfo(), DataTypeUtils::pickPairwiseResultType(arr.dataType(), DataTypeUtils::fromT()), false, arr.getContext()); - - NDArray::prepareSpecialUse({&result}, {&arr, &tmp}); - NativeOpExecutioner::execScalar(arr.getContext(), sd::scalar::Add, arr.buffer(), arr.shapeInfo(), arr.specialBuffer(), arr.specialShapeInfo(), result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({&result}, {&arr, &tmp}); - - return result; -} -template ND4J_EXPORT NDArray operator+(const NDArray& arr, const double& scalar); -template ND4J_EXPORT NDArray operator+(const NDArray& arr, const float& scalar); -template ND4J_EXPORT NDArray operator+(const NDArray& arr, const float16& scalar); -template ND4J_EXPORT NDArray operator+(const NDArray& arr, const bfloat16& scalar); -template ND4J_EXPORT NDArray operator+(const NDArray& arr, const int& scalar); - -//////////////////////////////////////////////////////////////////////// -template -NDArray operator+(const T& scalar, NDArray&& arr) { - return std::move(arr) + scalar; -} -template ND4J_EXPORT NDArray operator+(const double& scalar, NDArray&& arr); -template ND4J_EXPORT NDArray operator+(const float& scalar, NDArray&& arr); -template ND4J_EXPORT NDArray operator+(const float16& scalar, NDArray&& arr); -template ND4J_EXPORT NDArray operator+(const bfloat16& scalar, NDArray&& arr); -template ND4J_EXPORT NDArray operator+(const int& scalar, NDArray&& arr); - - -//////////////////////////////////////////////////////////////////////// -template -NDArray operator+(const T& scalar, const NDArray& arr) { - return arr + scalar; -} -template ND4J_EXPORT NDArray operator+(const double& scalar, const NDArray& arr); -template ND4J_EXPORT NDArray operator+(const float& scalar, const NDArray& arr); -template ND4J_EXPORT NDArray operator+(const int& scalar, const NDArray& arr); - -/////////////////////////////////////////////////////////////////////// -// addition operator array - scalar -template -NDArray operator-(NDArray&& arr, const T& scalar) { - - if(arr.isView()) // do not use resources of arrays which use buffers of other original arrays - return std::move(arr - scalar); // arr is lvalue inside function body - - if (arr.isS()) - throw std::runtime_error("operator-(NDArray&& arr, const T& scalar): you can't use this method on String array!"); - if (arr.dataType() != DataTypeUtils::pickPairwiseResultType(arr.dataType(), DataTypeUtils::fromT())) - throw std::runtime_error("operator-(NDArray&& arr, const T& scalar): you can't use this method on String array!"); - - auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext()); - - NDArray::prepareSpecialUse({&arr}, {&arr, &tmp}); - NativeOpExecutioner::execScalar(arr.getContext(), sd::scalar::Subtract, arr.buffer(), arr.shapeInfo(), arr.specialBuffer(), arr.specialShapeInfo(), arr.buffer(), arr.shapeInfo(), arr.specialBuffer(), arr.specialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({&arr}, {&arr, &tmp}); - - return std::move(arr); -} -template ND4J_EXPORT NDArray operator-(NDArray&& arr, const double& scalar); -template ND4J_EXPORT NDArray operator-(NDArray&& arr, const float& scalar); - -//////////////////////////////////////////////////////////////////////// -template -NDArray operator-(const NDArray& arr, const T& scalar) { - - if (arr.isS()) - throw std::runtime_error("operator-(const NDArray& arr, const T& scalar): you can't use this method on String array!"); - - auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext()); - NDArray result(arr.shapeInfo(), DataTypeUtils::pickPairwiseResultType(arr.dataType(), DataTypeUtils::fromT()), false, arr.getContext()); - - NDArray::prepareSpecialUse({&result}, {&arr, &tmp}); - NativeOpExecutioner::execScalar(arr.getContext(), sd::scalar::Subtract, arr.buffer(), arr.shapeInfo(), arr.specialBuffer(), arr.specialShapeInfo(), result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({&result}, {&arr, &tmp}); - - return result; -} -template ND4J_EXPORT NDArray operator-(const NDArray& arr, const double& scalar); -template ND4J_EXPORT NDArray operator-(const NDArray& arr, const float& scalar); -template ND4J_EXPORT NDArray operator-(const NDArray& arr, const float16& scalar); -template ND4J_EXPORT NDArray operator-(const NDArray& arr, const bfloat16& scalar); -template ND4J_EXPORT NDArray operator-(const NDArray& arr, const int& scalar); - -//////////////////////////////////////////////////////////////////////// -template -NDArray operator-(const T& scalar, NDArray&& arr) { - - if(arr.isView()) // do not use resources of arrays which use buffers of other original arrays - return std::move(scalar - arr); // arr is lvalue inside function body - - if (arr.isS()) - throw std::runtime_error("operator-(const T& scalar, NDArray&& arr): you can't use this method on String array!"); - - auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext()); - - NDArray::prepareSpecialUse({&arr}, {&arr, &tmp}); - NativeOpExecutioner::execScalar(arr.getContext(), sd::scalar::ReverseSubtract, arr.buffer(), arr.shapeInfo(), arr.specialBuffer(), arr.specialShapeInfo(), arr.buffer(), arr.shapeInfo(), arr.specialBuffer(), arr.specialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({&arr}, {&arr, &tmp}); - - return std::move(arr); - -} -template ND4J_EXPORT NDArray operator-(const double& scalar, NDArray&& arr); -template ND4J_EXPORT NDArray operator-(const float& scalar, NDArray&& arr); -template ND4J_EXPORT NDArray operator-(const float16& scalar, NDArray&& arr); -template ND4J_EXPORT NDArray operator-(const bfloat16& scalar, NDArray&& arr); -template ND4J_EXPORT NDArray operator-(const int& scalar, NDArray&& arr); - -//////////////////////////////////////////////////////////////////////// -template -NDArray operator-(const T& scalar, const NDArray& arr) { - - if (arr.isS()) - throw std::runtime_error("operator-(const T& scalar, const NDArray& arr): you can't use this method on String array!"); - - auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext()); - NDArray result(arr.shapeInfo(), DataTypeUtils::pickPairwiseResultType(arr.dataType(), DataTypeUtils::fromT()), false, arr.getContext()); - - NDArray::prepareSpecialUse({&result}, {&arr, &tmp}); - NativeOpExecutioner::execScalar(arr.getContext(), sd::scalar::ReverseSubtract, arr.buffer(), arr.shapeInfo(), arr.specialBuffer(), arr.specialShapeInfo(), result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({&result}, {&arr, &tmp}); - - return result; -} -template ND4J_EXPORT NDArray operator-(const double& scalar, const NDArray& arr); -template ND4J_EXPORT NDArray operator-(const float& scalar, const NDArray& arr); -template ND4J_EXPORT NDArray operator-(const int& scalar, const NDArray& arr); - -/////////////////////////////////////////////////////////////////////// -// addition operator array + scalar -template -NDArray operator*(NDArray&& arr, const T& scalar) { - - if(arr.isView()) // do not use resources of arrays which use buffers of other original arrays - return std::move(arr * scalar); // arr is lvalue inside function body - - if (arr.isS()) - throw std::runtime_error("operator*(NDArray&& arr, const T& scalar): you can't use this method on String array!"); - if (arr.dataType() != DataTypeUtils::pickPairwiseResultType(arr.dataType(), DataTypeUtils::fromT())) - throw std::runtime_error("operator*(NDArray&& arr, const T& scalar): you can't use this method on String array!"); - - auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext()); - - NDArray::prepareSpecialUse({&arr}, {&arr, &tmp}); - NativeOpExecutioner::execScalar(arr.getContext(), sd::scalar::Multiply, arr.buffer(), arr.shapeInfo(), arr.specialBuffer(), arr.specialShapeInfo(), arr.buffer(), arr.shapeInfo(), arr.specialBuffer(), arr.specialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({&arr}, {&arr, &tmp}); - - return std::move(arr); -} -template ND4J_EXPORT NDArray operator*(NDArray&& arr, const double& scalar); -template ND4J_EXPORT NDArray operator*(NDArray&& arr, const float& scalar); -template ND4J_EXPORT NDArray operator*(NDArray&& arr, const float16& scalar); -template ND4J_EXPORT NDArray operator*(NDArray&& arr, const bfloat16& scalar); -template ND4J_EXPORT NDArray operator*(NDArray&& arr, const int& scalar); -template ND4J_EXPORT NDArray operator*(NDArray&& arr, const long long& scalar); - -//////////////////////////////////////////////////////////////////////// -template -NDArray operator*(const NDArray& arr, const T& scalar) { - - if (arr.isS()) - throw std::runtime_error("operator*(const NDArray& arr, const T& scalar): you can't use this method on String array!"); - - auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext()); - NDArray result(arr.shapeInfo(), DataTypeUtils::pickPairwiseResultType(arr.dataType(), DataTypeUtils::fromT()), false, arr.getContext()); - - NDArray::prepareSpecialUse({&result}, {&arr, &tmp}); - NativeOpExecutioner::execScalar(arr.getContext(), sd::scalar::Multiply, arr.buffer(), arr.shapeInfo(), arr.specialBuffer(), arr.specialShapeInfo(), result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({&result}, {&arr, &tmp}); - - return result; -} - -template ND4J_EXPORT NDArray operator*(const NDArray& arr, const double& scalar); -template ND4J_EXPORT NDArray operator*(const NDArray& arr, const float& scalar); -template ND4J_EXPORT NDArray operator*(const NDArray& arr, const float16& scalar); -template ND4J_EXPORT NDArray operator*(const NDArray& arr, const bfloat16& scalar); -template ND4J_EXPORT NDArray operator*(const NDArray& arr, const int& scalar); -template ND4J_EXPORT NDArray operator*(const NDArray& arr, const long long& scalar); - -//////////////////////////////////////////////////////////////////////// -template -NDArray operator*(const T& scalar, NDArray&& arr) { - return std::move(arr) * scalar; -} -template ND4J_EXPORT NDArray operator*(const double& scalar, NDArray&& arr); -template ND4J_EXPORT NDArray operator*(const float& scalar, NDArray&& arr); -template ND4J_EXPORT NDArray operator*(const float16& scalar, NDArray&& arr); -template ND4J_EXPORT NDArray operator*(const bfloat16& scalar, NDArray&& arr); -template ND4J_EXPORT NDArray operator*(const int& scalar, NDArray&& arr); -template ND4J_EXPORT NDArray operator*(const long long& scalar, NDArray&& arr); - - -//////////////////////////////////////////////////////////////////////// -template -NDArray operator*(const T& scalar, const NDArray& arr) { - return arr * scalar; -} -template ND4J_EXPORT NDArray operator*(const double& scalar, const NDArray& arr); -template ND4J_EXPORT NDArray operator*(const float& scalar, const NDArray& arr); -template ND4J_EXPORT NDArray operator*(const float16& scalar, const NDArray& arr); -template ND4J_EXPORT NDArray operator*(const bfloat16& scalar, const NDArray& arr); -template ND4J_EXPORT NDArray operator*(const int& scalar, const NDArray& arr); -template ND4J_EXPORT NDArray operator*(const long long& scalar, const NDArray& arr); - -/////////////////////////////////////////////////////////////////////// -template -NDArray operator/(NDArray&& arr, const T& scalar) { - - if(arr.isView()) // do not use resources of arrays which use buffers of other original arrays - return std::move(arr / scalar); // arr is lvalue inside function body - - if (arr.isS()) - throw std::runtime_error("operator/(NDArray&& arr, const T& scalar): you can't use this method on String array!"); - if (arr.dataType() != DataTypeUtils::pickPairwiseResultType(arr.dataType(), DataTypeUtils::fromT())) - throw std::runtime_error("operator/(NDArray&& arr, const T& scalar): you can't use this method on String array!"); - - auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext()); - - NDArray::prepareSpecialUse({&arr}, {&arr, &tmp}); - NativeOpExecutioner::execScalar(arr.getContext(), sd::scalar::Divide, arr.buffer(), arr.shapeInfo(), arr.specialBuffer(), arr.specialShapeInfo(), arr.buffer(), arr.shapeInfo(), arr.specialBuffer(), arr.specialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({&arr}, {&arr, &tmp}); - - return std::move(arr); -} -template ND4J_EXPORT NDArray operator/(NDArray&& arr, const double& scalar); -template ND4J_EXPORT NDArray operator/(NDArray&& arr, const float& scalar); -template ND4J_EXPORT NDArray operator/(NDArray&& arr, const float16& scalar); -template ND4J_EXPORT NDArray operator/(NDArray&& arr, const bfloat16& scalar); -template ND4J_EXPORT NDArray operator/(NDArray&& arr, const long long& scalar); - -//////////////////////////////////////////////////////////////////////// -template -NDArray operator/(const NDArray& arr, const T& scalar) { - - if (arr.isS()) - throw std::runtime_error("operator/(const NDArray& arr, const T& scalar): you can't use this method on String array!"); - - auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext()); - NDArray result(arr.shapeInfo(), DataTypeUtils::pickPairwiseResultType(arr.dataType(), DataTypeUtils::fromT()), false, arr.getContext()); - - NDArray::prepareSpecialUse({&result}, {&arr, &tmp}); - NativeOpExecutioner::execScalar(arr.getContext(), sd::scalar::Divide, arr.buffer(), arr.shapeInfo(), arr.specialBuffer(), arr.specialShapeInfo(), result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({&result}, {&arr, &tmp}); - - return result; -} -template ND4J_EXPORT NDArray operator/(const NDArray& arr, const double& scalar); -template ND4J_EXPORT NDArray operator/(const NDArray& arr, const float& scalar); -template ND4J_EXPORT NDArray operator/(const NDArray& arr, const float16& scalar); -template ND4J_EXPORT NDArray operator/(const NDArray& arr, const bfloat16& scalar); -template ND4J_EXPORT NDArray operator/(const NDArray& arr, const int& scalar); -template ND4J_EXPORT NDArray operator/(const NDArray& arr, const long long& scalar); - -//////////////////////////////////////////////////////////////////////// -template -NDArray operator/(const T& scalar, NDArray&& arr) { - - if(arr.isView()) // do not use resources of arrays which use buffers of other original arrays - return std::move(scalar / arr); // arr is lvalue inside function body - - if (arr.isS()) - throw std::runtime_error("operator/(const T& scalar, NDArray&& arr): you can't use this method on String array!"); - - auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext()); - - NDArray::prepareSpecialUse({&arr}, {&arr, &tmp}); - NativeOpExecutioner::execScalar(arr.getContext(), sd::scalar::ReverseDivide, arr.buffer(), arr.shapeInfo(), arr.specialBuffer(), arr.specialShapeInfo(), arr.buffer(), arr.shapeInfo(), arr.specialBuffer(), arr.specialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({&arr}, {&arr, &tmp}); - - return std::move(arr); - -} -template ND4J_EXPORT NDArray operator/(const double& scalar, NDArray&& arr); -template ND4J_EXPORT NDArray operator/(const float& scalar, NDArray&& arr); -template ND4J_EXPORT NDArray operator/(const float16& scalar, NDArray&& arr); -template ND4J_EXPORT NDArray operator/(const bfloat16& scalar, NDArray&& arr); -template ND4J_EXPORT NDArray operator/(const int& scalar, NDArray&& arr); - - -//////////////////////////////////////////////////////////////////////// -template -NDArray operator/(const T& scalar, const NDArray& arr) { - - if (arr.isS()) - throw std::runtime_error("operator/(const T& scalar, const NDArray& arr): you can't use this method on String array!"); - - auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext()); - NDArray result(arr.shapeInfo(), DataTypeUtils::pickPairwiseResultType(arr.dataType(), DataTypeUtils::fromT()), false, arr.getContext()); - - NDArray::prepareSpecialUse({&result}, {&arr, &tmp}); - NativeOpExecutioner::execScalar(arr.getContext(), sd::scalar::ReverseDivide, arr.buffer(), arr.shapeInfo(), arr.specialBuffer(), arr.specialShapeInfo(), result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({&result}, {&arr, &tmp}); - - return result; -} -template ND4J_EXPORT NDArray operator/(const double& scalar, const NDArray& arr); -template ND4J_EXPORT NDArray operator/(const float& scalar, const NDArray& arr); -template ND4J_EXPORT NDArray operator/(const int& scalar, const NDArray& arr); - -//////////////////////////////////////////////////////////////////////// -// addition operator array + array -template -NDArray operator+(T1&& arr1, T2&& arr2) { - - if (arr1.isS() || arr2.isS()) - throw std::runtime_error("operator+(T&& arr1, T&& arr2): you can't use this method on String arrays!"); - if (!Environment::getInstance().isExperimentalBuild() && arr1.dataType() != arr2.dataType() && (arr1.dataType() != DataType::BOOL || arr2.dataType() != BOOL)) - throw sd::datatype_exception::build("operator+(T&& arr1, T&& arr2): Cannot multiply different types", arr1.dataType(), arr2.dataType()); - - PointersManager pointersManager(arr1.getContext(), "operator+(T&& arr1, T&& arr2)"); - - if (arr1.lengthOf() == arr2.lengthOf() && arr1.rankOf() == arr2.rankOf()) { - - const bool isArr1Rvalue = !std::is_reference::value && !arr1.isView(); - const bool isArr2Rvalue = !std::is_reference::value && !arr2.isView(); - - NDArray* result = nullptr; - if(isArr1Rvalue) - result = const_cast(&arr1); - else if(isArr2Rvalue) - result = const_cast(&arr2); - else - result = new NDArray(arr1.shapeInfo(), DataTypeUtils::pickPairwiseResultType(arr1.shapeInfo(), arr2.shapeInfo()), false, arr1.getContext()); - - NDArray::prepareSpecialUse({result}, {&arr1, &arr2}); - NativeOpExecutioner::execPairwiseTransform(arr1.getContext(), sd::pairwise::Add, arr1.buffer(), arr1.shapeInfo(), arr1.specialBuffer(), arr1.specialShapeInfo(), arr2.buffer(), arr2.shapeInfo(), arr2.specialBuffer(), arr2.specialShapeInfo(), result->buffer(), result->shapeInfo(), result->specialBuffer(), result->specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({result}, {&arr1, &arr2}); - - if(!isArr1Rvalue && !isArr2Rvalue) { - NDArray res = std::move(*result); - delete result; - return std::move(res); - } - - return std::move(*result); - } - - return std::forward(arr1).applyTrueBroadcast(sd::BroadcastOpsTuple::Add(), std::forward(arr2)); -} -template ND4J_EXPORT NDArray operator+(NDArray& arr1, NDArray& arr2); -template ND4J_EXPORT NDArray operator+(NDArray& arr1, NDArray&& arr2); -template ND4J_EXPORT NDArray operator+(NDArray&& arr1, NDArray& arr2); -template ND4J_EXPORT NDArray operator+(NDArray& arr1, const NDArray& arr2); -template ND4J_EXPORT NDArray operator+(const NDArray& arr1, NDArray& arr2); -template ND4J_EXPORT NDArray operator+(const NDArray& arr1, NDArray&& arr2); -template ND4J_EXPORT NDArray operator+(const NDArray& arr1, const NDArray& arr2); -template ND4J_EXPORT NDArray operator+(NDArray&& arr1, const NDArray& arr2); -template ND4J_EXPORT NDArray operator+(NDArray&& arr1, NDArray&& arr2); - -//////////////////////////////////////////////////////////////////////// -// addition operator array - array -template -NDArray operator-(T1&& arr1, T2&& arr2) { - - if (arr1.isS() || arr2.isS()) - throw std::runtime_error("operator-(T&& arr1, T&& arr2): you can't use this method on String arrays!"); - if (!Environment::getInstance().isExperimentalBuild() && arr1.dataType() != arr2.dataType() && (arr1.dataType() != DataType::BOOL || arr2.dataType() != BOOL)) - throw sd::datatype_exception::build("operator-(T&& arr1, T&& arr2): Cannot multiply different types", arr1.dataType(), arr2.dataType()); - - PointersManager pointersManager(arr1.getContext(), "operator-(T&& arr1, T&& arr2)"); - - if (arr1.lengthOf() == arr2.lengthOf() && arr1.rankOf() == arr2.rankOf()) { - - const bool isArr1Rvalue = !std::is_reference::value && !arr1.isView(); - const bool isArr2Rvalue = !std::is_reference::value && !arr2.isView(); - - NDArray* result = nullptr; - if(isArr1Rvalue) - result = const_cast(&arr1); - else if(isArr2Rvalue) - result = const_cast(&arr2); - else - result = new NDArray(arr1.shapeInfo(), DataTypeUtils::pickPairwiseResultType(arr1.shapeInfo(), arr2.shapeInfo()), false, arr1.getContext()); - - NDArray::prepareSpecialUse({result}, {&arr1, &arr2}); - NativeOpExecutioner::execPairwiseTransform(arr1.getContext(), sd::pairwise::Subtract, arr1.buffer(), arr1.shapeInfo(), arr1.specialBuffer(), arr1.specialShapeInfo(), arr2.buffer(), arr2.shapeInfo(), arr2.specialBuffer(), arr2.specialShapeInfo(), result->buffer(), result->shapeInfo(), result->specialBuffer(), result->specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({result}, {&arr1, &arr2}); - - if(!isArr1Rvalue && !isArr2Rvalue) { - NDArray res = std::move(*result); - delete result; - return std::move(res); - } - - return std::move(*result); - } - - return std::forward(arr1).applyTrueBroadcast(sd::BroadcastOpsTuple::Subtract(), std::forward(arr2)); -} -template ND4J_EXPORT NDArray operator-(NDArray& arr1, NDArray& arr2); -template ND4J_EXPORT NDArray operator-(NDArray& arr1, NDArray&& arr2); -template ND4J_EXPORT NDArray operator-(NDArray&& arr1, NDArray& arr2); -template ND4J_EXPORT NDArray operator-(NDArray& arr1, const NDArray& arr2); -template ND4J_EXPORT NDArray operator-(const NDArray& arr1, NDArray& arr2); -template ND4J_EXPORT NDArray operator-(const NDArray& arr1, NDArray&& arr2); -template ND4J_EXPORT NDArray operator-(const NDArray& arr1, const NDArray& arr2); -template ND4J_EXPORT NDArray operator-(NDArray&& arr1, const NDArray& arr2); -template ND4J_EXPORT NDArray operator-(NDArray&& arr1, NDArray&& arr2); - -//////////////////////////////////////////////////////////////////////// -// multiplication operator array*array -template -NDArray operator*(T1&& arr1, T2&& arr2) { - - if (arr1.isS() || arr2.isS()) - throw std::runtime_error("operator*(T&& arr1, T&& arr2): you can't use this method on String arrays!"); - if (!Environment::getInstance().isExperimentalBuild() && arr1.dataType() != arr2.dataType() && (arr1.dataType() != DataType::BOOL || arr2.dataType() != BOOL)) - throw sd::datatype_exception::build("operator*(T&& arr1, T&& arr2): Cannot multiply different types", arr1.dataType(), arr2.dataType()); - - PointersManager pointersManager(arr1.getContext(), "operator*(T&& arr1, T&& arr2)"); - - if (arr1.lengthOf() == arr2.lengthOf() && arr1.rankOf() == arr2.rankOf()) { - - const bool isArr1Rvalue = !std::is_reference::value && !arr1.isView(); - const bool isArr2Rvalue = !std::is_reference::value && !arr2.isView(); - - NDArray* result = nullptr; - if(isArr1Rvalue) - result = const_cast(&arr1); - else if(isArr2Rvalue) - result = const_cast(&arr2); - else - result = new NDArray(arr1.shapeInfo(), DataTypeUtils::pickPairwiseResultType(arr1.shapeInfo(), arr2.shapeInfo()), false, arr1.getContext()); - - NDArray::prepareSpecialUse({result}, {&arr1, &arr2}); - NativeOpExecutioner::execPairwiseTransform(arr1.getContext(), sd::pairwise::Multiply, arr1.buffer(), arr1.shapeInfo(), arr1.specialBuffer(), arr1.specialShapeInfo(), arr2.buffer(), arr2.shapeInfo(), arr2.specialBuffer(), arr2.specialShapeInfo(), result->buffer(), result->shapeInfo(), result->specialBuffer(), result->specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({result}, {&arr1, &arr2}); - - if(!isArr1Rvalue && !isArr2Rvalue) { - NDArray res = std::move(*result); - delete result; - return std::move(res); - } - - return std::move(*result); - } - - return std::forward(arr1).applyTrueBroadcast(sd::BroadcastOpsTuple::Multiply(), std::forward(arr2)); -} -template ND4J_EXPORT NDArray operator*(NDArray& arr1, NDArray& arr2); -template ND4J_EXPORT NDArray operator*(NDArray& arr1, NDArray&& arr2); -template ND4J_EXPORT NDArray operator*(NDArray&& arr1, NDArray& arr2); -template ND4J_EXPORT NDArray operator*(NDArray& arr1, const NDArray& arr2); -template ND4J_EXPORT NDArray operator*(const NDArray& arr1, NDArray& arr2); -template ND4J_EXPORT NDArray operator*(const NDArray& arr1, NDArray&& arr2); -template ND4J_EXPORT NDArray operator*(const NDArray& arr1, const NDArray& arr2); -template ND4J_EXPORT NDArray operator*(NDArray&& arr1, const NDArray& arr2); -template ND4J_EXPORT NDArray operator*(NDArray&& arr1, NDArray&& arr2); - -//////////////////////////////////////////////////////////////////////// -// multiplication operator array*array -template -NDArray operator/(T1&& arr1, T2&& arr2) { - - if (arr1.isS() || arr2.isS()) - throw std::runtime_error("operator/(T&& arr1, T&& arr2): you can't use this method on String arrays!"); - if (!Environment::getInstance().isExperimentalBuild() && arr1.dataType() != arr2.dataType() && (arr1.dataType() != DataType::BOOL || arr2.dataType() != BOOL)) - throw sd::datatype_exception::build("operator/(T&& arr1, T&& arr2): Cannot multiply different types", arr1.dataType(), arr2.dataType()); - - PointersManager pointersManager(arr1.getContext(), "operator/(T&& arr1, T&& arr2)"); - - if (arr1.lengthOf() == arr2.lengthOf() && arr1.rankOf() == arr2.rankOf()) { - - const bool isArr1Rvalue = !std::is_reference::value && !arr1.isView(); - const bool isArr2Rvalue = !std::is_reference::value && !arr2.isView(); - - NDArray* result = nullptr; - if(isArr1Rvalue) - result = const_cast(&arr1); - else if(isArr2Rvalue) - result = const_cast(&arr2); - else - result = new NDArray(arr1.shapeInfo(), DataTypeUtils::pickPairwiseResultType(arr1.shapeInfo(), arr2.shapeInfo()), false, arr1.getContext()); - - NDArray::prepareSpecialUse({result}, {&arr1, &arr2}); - NativeOpExecutioner::execPairwiseTransform(arr1.getContext(), sd::pairwise::Divide, arr1.buffer(), arr1.shapeInfo(), arr1.specialBuffer(), arr1.specialShapeInfo(), arr2.buffer(), arr2.shapeInfo(), arr2.specialBuffer(), arr2.specialShapeInfo(), result->buffer(), result->shapeInfo(), result->specialBuffer(), result->specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({result}, {&arr1, &arr2}); - - if(!isArr1Rvalue && !isArr2Rvalue) { - NDArray res = std::move(*result); - delete result; - return std::move(res); - } - - return std::move(*result); - } - - return std::forward(arr1).applyTrueBroadcast(sd::BroadcastOpsTuple::Divide(), std::forward(arr2)); -} -template ND4J_EXPORT NDArray operator/(NDArray& arr1, NDArray& arr2); -template ND4J_EXPORT NDArray operator/(NDArray& arr1, NDArray&& arr2); -template ND4J_EXPORT NDArray operator/(NDArray&& arr1, NDArray& arr2); -template ND4J_EXPORT NDArray operator/(NDArray& arr1, const NDArray& arr2); -template ND4J_EXPORT NDArray operator/(const NDArray& arr1, NDArray& arr2); -template ND4J_EXPORT NDArray operator/(const NDArray& arr1, NDArray&& arr2); -template ND4J_EXPORT NDArray operator/(const NDArray& arr1, const NDArray& arr2); -template ND4J_EXPORT NDArray operator/(NDArray&& arr1, const NDArray& arr2); -template ND4J_EXPORT NDArray operator/(NDArray&& arr1, NDArray&& arr2); - - -/* -#ifndef __CLION_IDE__ -#include "NDArray.macro" -#endif - */ -} - -#endif - - - - - - - - - - - - -////////////////////////////////////////////////////////////////////////// -// check whether array's rows (arg=0) or columns (arg=1) create orthogonal basis -// bool NDArray::hasOrthonormalBasis(const int arg) { -// if (isS()) -// throw std::runtime_error("NDArray::hasOrthonormalBasis: you can't use this method on String array!"); -// if(rankOf() !=2 ) -// throw std::runtime_error("NDArray::hasOrthBasis method: rank of ndarray is not equal 2 !"); - -// if(arg!=0 && arg!=1) -// throw std::runtime_error("NDArray::hasOrthBasis method: input argument is not equal to 0 or 1 !"); - -// const double eps = 1e-5; -// double dot = 0.f; - -// if(arg) { // check whether columns create orthogonal basis -// for(int j=0; j(i,j)*e(i,k); - -// if(sd::math::nd4j_abs(dot) > eps ) -// return false; - -// dot = 0.f; -// } - -// for(int j=0; j(i,j)*e(i,j); -// if(dot != 0.f && sd::math::nd4j_abs(sd::math::nd4j_sqrt(dot) - 1.f) > eps) -// return false; - -// dot = 0.f; -// } -// } -// else { // check whether rows create orthogonal basis -// for(int i=0; i(i,j)*e(k,j); - -// if(sd::math::nd4j_abs(dot) > eps ) -// return false; - -// dot = 0.; -// } - -// for(int i=0; i(i,j)*e(i,j); - -// if(dot!= 0. && sd::math::nd4j_abs(sd::math::nd4j_sqrt(dot) - 1.) > eps) -// return false; -// dot = 0.; -// } -// } -// return true; -// } diff --git a/libnd4j/include/array/NDArrayFactory.h b/libnd4j/include/array/NDArrayFactory.h index f25c68fb4f32..46688ebff078 100644 --- a/libnd4j/include/array/NDArrayFactory.h +++ b/libnd4j/include/array/NDArrayFactory.h @@ -20,172 +20,344 @@ // @author Oleg Semeniv // -#ifndef DEV_TESTS_NDARRAYFACTORY_H -#define DEV_TESTS_NDARRAYFACTORY_H +#ifndef SD_NDARRAYFACTORY_H +#define SD_NDARRAYFACTORY_H -#include -#include #include + +#include +#include //#include #include -#include +#include namespace sd { - class ND4J_EXPORT NDArrayFactory { - private: - template - static void memcpyFromVector(void *ptr, const std::vector &vector); - public: - template - static NDArray* empty_(sd::LaunchContext * context = sd::LaunchContext ::defaultContext()); - - static NDArray* empty_(sd::DataType dataType, sd::LaunchContext * context = sd::LaunchContext ::defaultContext()); - - template - static NDArray empty(sd::LaunchContext * context = sd::LaunchContext ::defaultContext()); - - static NDArray empty(sd::DataType dataType, sd::LaunchContext * context = sd::LaunchContext ::defaultContext()); - - template - static NDArray* valueOf(const std::initializer_list& shape, T value, char order = 'c', sd::LaunchContext * context = sd::LaunchContext ::defaultContext()); - - template - static NDArray* valueOf(const std::vector& shape, T value, char order = 'c', sd::LaunchContext * context = sd::LaunchContext ::defaultContext()); - - static NDArray* valueOf(const std::vector& shape, const NDArray& value, char order = 'c', sd::LaunchContext * context = sd::LaunchContext ::defaultContext()); - - template - static NDArray* linspace(T from, T to, Nd4jLong numElements); - - - template - static NDArray* create_(const T value, sd::LaunchContext * context = sd::LaunchContext ::defaultContext()); - static NDArray* create_(sd::DataType dtype, sd::LaunchContext * context = sd::LaunchContext ::defaultContext()); - - template - static NDArray create(const T value, sd::LaunchContext * context = sd::LaunchContext ::defaultContext()); - static NDArray create(sd::DataType dtype, sd::LaunchContext * context = sd::LaunchContext ::defaultContext()); - template - static NDArray create(DataType type, const T scalar, sd::LaunchContext * context = sd::LaunchContext ::defaultContext()); - - - template - static NDArray* vector(Nd4jLong length, T startingValue = (T) 0, sd::LaunchContext * context = sd::LaunchContext ::defaultContext()); - - template - static NDArray* create_(char order, const std::vector &shape, sd::LaunchContext * context = sd::LaunchContext ::defaultContext()); - - static NDArray* create_( char order, const std::vector &shape, sd::DataType dataType, sd::LaunchContext * context = sd::LaunchContext ::defaultContext()); - - template - static NDArray* create_(char order, const std::vector &shape, const std::vector &data, sd::LaunchContext * context = sd::LaunchContext ::defaultContext()); - - template - static NDArray create(char order, const std::vector &shape, const std::vector &data, sd::LaunchContext * context = sd::LaunchContext ::defaultContext()); - - template - static NDArray create(char order, const std::vector &shape, sd::LaunchContext * context = sd::LaunchContext ::defaultContext()); - static NDArray create(char order, const std::vector &shape, sd::DataType dtype, sd::LaunchContext * context = sd::LaunchContext ::defaultContext()); - - template - static NDArray create(const std::vector &values, sd::LaunchContext * context = sd::LaunchContext ::defaultContext()); +class SD_EXPORT NDArrayFactory { + private: + template + static void memcpyFromVector(void* ptr, const std::vector& vector); + + public: + static NDArray undefined(); + + template + static NDArray* empty_( + sd::LaunchContext* context = sd::LaunchContext ::defaultContext()); + + static NDArray* empty_( + sd::DataType dataType, + sd::LaunchContext* context = sd::LaunchContext ::defaultContext()); + + template + static NDArray empty( + sd::LaunchContext* context = sd::LaunchContext ::defaultContext()); + + static NDArray empty( + sd::DataType dataType, + sd::LaunchContext* context = sd::LaunchContext ::defaultContext()); + + template + static NDArray* valueOf( + const std::initializer_list& shape, T value, char order = 'c', + sd::LaunchContext* context = sd::LaunchContext ::defaultContext()); + + template + static NDArray* valueOf( + const std::vector& shape, T value, char order = 'c', + sd::LaunchContext* context = sd::LaunchContext ::defaultContext()); + + static NDArray* valueOf( + const std::vector& shape, const NDArray& value, + char order = 'c', + sd::LaunchContext* context = sd::LaunchContext ::defaultContext()); + + template + static NDArray* linspace(T from, T to, Nd4jLong numElements); + + template + static NDArray* create_( + const T value, + sd::LaunchContext* context = sd::LaunchContext ::defaultContext()); + static NDArray* create_( + sd::DataType dtype, + sd::LaunchContext* context = sd::LaunchContext ::defaultContext()); + + template + static NDArray create( + const T value, + sd::LaunchContext* context = sd::LaunchContext ::defaultContext()); + static NDArray create( + sd::DataType dtype, + sd::LaunchContext* context = sd::LaunchContext ::defaultContext()); + template + static NDArray create( + DataType type, const T scalar, + sd::LaunchContext* context = sd::LaunchContext ::defaultContext()); + + template + static NDArray* vector( + Nd4jLong length, T startingValue = (T)0, + sd::LaunchContext* context = sd::LaunchContext ::defaultContext()); + + template + static NDArray* create_( + char order, const std::vector& shape, + sd::LaunchContext* context = sd::LaunchContext ::defaultContext()); + + static NDArray* create_( + char order, const std::vector& shape, sd::DataType dataType, + sd::LaunchContext* context = sd::LaunchContext ::defaultContext()); + + template + static NDArray* create_( + char order, const std::vector& shape, + const std::vector& data, + sd::LaunchContext* context = sd::LaunchContext ::defaultContext()); + + template + static NDArray create( + char order, const std::vector& shape, + const std::vector& data, + sd::LaunchContext* context = sd::LaunchContext ::defaultContext()); + + template + static NDArray create( + char order, const std::vector& shape, + sd::LaunchContext* context = sd::LaunchContext ::defaultContext()); + static NDArray create( + char order, const std::vector& shape, sd::DataType dtype, + sd::LaunchContext* context = sd::LaunchContext ::defaultContext()); + + template + static NDArray create( + const std::vector& values, + sd::LaunchContext* context = sd::LaunchContext ::defaultContext()); #ifndef __JAVACPP_HACK__ - // this method only available out of javacpp - /** - * This constructor creates vector of T - * - * @param values - */ - - template - static NDArray create(char order, const std::initializer_list& shape, sd::LaunchContext * context = sd::LaunchContext ::defaultContext()); - - template - static NDArray create(T* buffer, char order, const std::initializer_list& shape, sd::LaunchContext * context = sd::LaunchContext ::defaultContext()); - - template - static NDArray create(char order, const std::vector &shape, const std::initializer_list& data, sd::LaunchContext * context = sd::LaunchContext ::defaultContext()); - - /** - * This method creates NDArray from .npy file - * @param fileName - * @return - */ - static NDArray fromNpyFile(const char *fileName); - - /** - * This factory create array from utf8 string - * @return NDArray default dataType UTF8 - */ - static NDArray string(const char *string, sd::DataType dtype = sd::DataType::UTF8, sd::LaunchContext * context = sd::LaunchContext ::defaultContext()); - static NDArray* string_(const char *string, sd::DataType dtype = sd::DataType::UTF8, sd::LaunchContext * context = sd::LaunchContext ::defaultContext()); - static NDArray* string_(const std::string &string, sd::DataType dtype = sd::DataType::UTF8, sd::LaunchContext * context = sd::LaunchContext ::defaultContext()); - static NDArray string(const std::string& string, sd::DataType dtype = sd::DataType::UTF8, sd::LaunchContext* context = sd::LaunchContext::defaultContext()); - - /** - * This factory create array from utf16 string - * @return NDArray default dataType UTF16 - */ - static NDArray string(const char16_t* u16string, sd::DataType dtype = sd::DataType::UTF16, sd::LaunchContext* context = sd::LaunchContext::defaultContext()); - static NDArray* string_(const char16_t* u16string, sd::DataType dtype = sd::DataType::UTF16, sd::LaunchContext* context = sd::LaunchContext::defaultContext()); - static NDArray* string_(const std::u16string& u16string, sd::DataType dtype = sd::DataType::UTF16, sd::LaunchContext* context = sd::LaunchContext::defaultContext()); - static NDArray string(const std::u16string& u16string, sd::DataType dtype = sd::DataType::UTF16, sd::LaunchContext* context = sd::LaunchContext::defaultContext()); - - /** - * This factory create array from utf32 string - * @return NDArray default dataType UTF32 - */ - static NDArray string(const char32_t* u32string, sd::DataType dtype = sd::DataType::UTF32, sd::LaunchContext* context = sd::LaunchContext::defaultContext()); - static NDArray* string_(const char32_t* u32string, sd::DataType dtype = sd::DataType::UTF32, sd::LaunchContext* context = sd::LaunchContext::defaultContext()); - static NDArray* string_(const std::u32string& u32string, sd::DataType dtype = sd::DataType::UTF32, sd::LaunchContext* context = sd::LaunchContext::defaultContext()); - static NDArray string(const std::u32string& u32string, sd::DataType dtype = sd::DataType::UTF32, sd::LaunchContext* context = sd::LaunchContext::defaultContext()); - - /** - * This factory create array from vector of utf8 strings - * @return NDArray default dataType UTF8 - */ - static NDArray string( const std::vector &shape, const std::initializer_list &strings, sd::DataType dtype = sd::DataType::UTF8, sd::LaunchContext * context = sd::LaunchContext ::defaultContext()); - static NDArray string( const std::vector &shape, const std::initializer_list &string, sd::DataType dtype = sd::DataType::UTF8, sd::LaunchContext * context = sd::LaunchContext ::defaultContext()); - static NDArray string( const std::vector &shape, const std::vector &strings, sd::DataType dtype = sd::DataType::UTF8, sd::LaunchContext * context = sd::LaunchContext ::defaultContext()); - static NDArray string( const std::vector &shape, const std::vector &string, sd::DataType dtype = sd::DataType::UTF8, sd::LaunchContext * context = sd::LaunchContext ::defaultContext()); - static NDArray* string_( const std::vector &shape, const std::initializer_list &strings, sd::DataType dtype = sd::DataType::UTF8, sd::LaunchContext * context = sd::LaunchContext ::defaultContext()); - static NDArray* string_( const std::vector &shape, const std::initializer_list &string, sd::DataType dtype = sd::DataType::UTF8, sd::LaunchContext * context = sd::LaunchContext ::defaultContext()); - static NDArray* string_( const std::vector &shape, const std::vector &strings, sd::DataType dtype = sd::DataType::UTF8, sd::LaunchContext * context = sd::LaunchContext ::defaultContext()); - static NDArray* string_( const std::vector &shape, const std::vector &string, sd::DataType dtype = sd::DataType::UTF8, sd::LaunchContext * context = sd::LaunchContext ::defaultContext()); - - /** - * This factory create array from vector of utf16 strings - * @return NDArray default dataType UTF16 - */ - static NDArray string( const std::vector& shape, const std::initializer_list& strings, sd::DataType dtype = sd::DataType::UTF16, sd::LaunchContext* context = sd::LaunchContext::defaultContext()); - static NDArray string( const std::vector& shape, const std::initializer_list& string, sd::DataType dtype = sd::DataType::UTF16, sd::LaunchContext* context = sd::LaunchContext::defaultContext()); - static NDArray string( const std::vector& shape, const std::vector& strings, sd::DataType dtype = sd::DataType::UTF16, sd::LaunchContext* context = sd::LaunchContext::defaultContext()); - static NDArray string( const std::vector& shape, const std::vector& string, sd::DataType dtype = sd::DataType::UTF16, sd::LaunchContext* context = sd::LaunchContext::defaultContext()); - static NDArray* string_( const std::vector& shape, const std::initializer_list& strings, sd::DataType dtype = sd::DataType::UTF16, sd::LaunchContext* context = sd::LaunchContext::defaultContext()); - static NDArray* string_( const std::vector& shape, const std::initializer_list& string, sd::DataType dtype = sd::DataType::UTF16, sd::LaunchContext* context = sd::LaunchContext::defaultContext()); - static NDArray* string_( const std::vector& shape, const std::vector& strings, sd::DataType dtype = sd::DataType::UTF16, sd::LaunchContext* context = sd::LaunchContext::defaultContext()); - static NDArray* string_( const std::vector& shape, const std::vector& string, sd::DataType dtype = sd::DataType::UTF16, sd::LaunchContext* context = sd::LaunchContext::defaultContext()); - - /** - * This factory create array from vector of utf32 strings - * @return NDArray default dataType UTF32 - */ - static NDArray string( const std::vector& shape, const std::initializer_list& strings, sd::DataType dtype = sd::DataType::UTF32, sd::LaunchContext* context = sd::LaunchContext::defaultContext()); - static NDArray string( const std::vector& shape, const std::initializer_list& string, sd::DataType dtype = sd::DataType::UTF32, sd::LaunchContext* context = sd::LaunchContext::defaultContext()); - static NDArray string( const std::vector& shape, const std::vector& strings, sd::DataType dtype = sd::DataType::UTF32, sd::LaunchContext* context = sd::LaunchContext::defaultContext()); - static NDArray string( const std::vector& shape, const std::vector& string, sd::DataType dtype = sd::DataType::UTF32, sd::LaunchContext* context = sd::LaunchContext::defaultContext()); - static NDArray* string_( const std::vector& shape, const std::initializer_list& strings, sd::DataType dtype = sd::DataType::UTF32, sd::LaunchContext* context = sd::LaunchContext::defaultContext()); - static NDArray* string_( const std::vector& shape, const std::initializer_list& string, sd::DataType dtype = sd::DataType::UTF32, sd::LaunchContext* context = sd::LaunchContext::defaultContext()); - static NDArray* string_( const std::vector& shape, const std::vector& strings, sd::DataType dtype = sd::DataType::UTF32, sd::LaunchContext* context = sd::LaunchContext::defaultContext()); - static NDArray* string_( const std::vector& shape, const std::vector& string, sd::DataType dtype = sd::DataType::UTF32, sd::LaunchContext* context = sd::LaunchContext::defaultContext()); - - - static ResultSet createSetOfArrs(const Nd4jLong numOfArrs, const void* buffer, const Nd4jLong* shapeInfo, const Nd4jLong* offsets, sd::LaunchContext * context = sd::LaunchContext ::defaultContext()); + // this method only available out of javacpp + /** + * This constructor creates vector of T + * + * @param values + */ + + template + static NDArray create( + char order, const std::initializer_list& shape, + sd::LaunchContext* context = sd::LaunchContext ::defaultContext()); + + template + static NDArray create( + T* buffer, char order, const std::initializer_list& shape, + sd::LaunchContext* context = sd::LaunchContext ::defaultContext()); + + template + static NDArray create( + char order, const std::vector& shape, + const std::initializer_list& data, + sd::LaunchContext* context = sd::LaunchContext ::defaultContext()); + + /** + * This method creates NDArray from .npy file + * @param fileName + * @return + */ + static NDArray fromNpyFile(const char* fileName); + + /** + * This factory create array from utf8 string + * @return NDArray default dataType UTF8 + */ + static NDArray string( + const char* string, sd::DataType dtype = sd::DataType::UTF8, + sd::LaunchContext* context = sd::LaunchContext ::defaultContext()); + static NDArray* string_( + const char* string, sd::DataType dtype = sd::DataType::UTF8, + sd::LaunchContext* context = sd::LaunchContext ::defaultContext()); + static NDArray* string_( + const std::string& string, sd::DataType dtype = sd::DataType::UTF8, + sd::LaunchContext* context = sd::LaunchContext ::defaultContext()); + static NDArray string( + const std::string& string, sd::DataType dtype = sd::DataType::UTF8, + sd::LaunchContext* context = sd::LaunchContext::defaultContext()); + + /** + * This factory create array from utf16 string + * @return NDArray default dataType UTF16 + */ + static NDArray string( + const char16_t* u16string, sd::DataType dtype = sd::DataType::UTF16, + sd::LaunchContext* context = sd::LaunchContext::defaultContext()); + static NDArray* string_( + const char16_t* u16string, sd::DataType dtype = sd::DataType::UTF16, + sd::LaunchContext* context = sd::LaunchContext::defaultContext()); + static NDArray* string_( + const std::u16string& u16string, sd::DataType dtype = sd::DataType::UTF16, + sd::LaunchContext* context = sd::LaunchContext::defaultContext()); + static NDArray string( + const std::u16string& u16string, sd::DataType dtype = sd::DataType::UTF16, + sd::LaunchContext* context = sd::LaunchContext::defaultContext()); + + /** + * This factory create array from utf32 string + * @return NDArray default dataType UTF32 + */ + static NDArray string( + const char32_t* u32string, sd::DataType dtype = sd::DataType::UTF32, + sd::LaunchContext* context = sd::LaunchContext::defaultContext()); + static NDArray* string_( + const char32_t* u32string, sd::DataType dtype = sd::DataType::UTF32, + sd::LaunchContext* context = sd::LaunchContext::defaultContext()); + static NDArray* string_( + const std::u32string& u32string, sd::DataType dtype = sd::DataType::UTF32, + sd::LaunchContext* context = sd::LaunchContext::defaultContext()); + static NDArray string( + const std::u32string& u32string, sd::DataType dtype = sd::DataType::UTF32, + sd::LaunchContext* context = sd::LaunchContext::defaultContext()); + + /** + * This factory create array from vector of utf8 strings + * @return NDArray default dataType UTF8 + */ + static NDArray string( + const std::vector& shape, + const std::initializer_list& strings, + sd::DataType dtype = sd::DataType::UTF8, + sd::LaunchContext* context = sd::LaunchContext ::defaultContext()); + static NDArray string( + const std::vector& shape, + const std::initializer_list& string, + sd::DataType dtype = sd::DataType::UTF8, + sd::LaunchContext* context = sd::LaunchContext ::defaultContext()); + static NDArray string( + const std::vector& shape, + const std::vector& strings, + sd::DataType dtype = sd::DataType::UTF8, + sd::LaunchContext* context = sd::LaunchContext ::defaultContext()); + static NDArray string( + const std::vector& shape, + const std::vector& string, + sd::DataType dtype = sd::DataType::UTF8, + sd::LaunchContext* context = sd::LaunchContext ::defaultContext()); + static NDArray* string_( + const std::vector& shape, + const std::initializer_list& strings, + sd::DataType dtype = sd::DataType::UTF8, + sd::LaunchContext* context = sd::LaunchContext ::defaultContext()); + static NDArray* string_( + const std::vector& shape, + const std::initializer_list& string, + sd::DataType dtype = sd::DataType::UTF8, + sd::LaunchContext* context = sd::LaunchContext ::defaultContext()); + static NDArray* string_( + const std::vector& shape, + const std::vector& strings, + sd::DataType dtype = sd::DataType::UTF8, + sd::LaunchContext* context = sd::LaunchContext ::defaultContext()); + static NDArray* string_( + const std::vector& shape, + const std::vector& string, + sd::DataType dtype = sd::DataType::UTF8, + sd::LaunchContext* context = sd::LaunchContext ::defaultContext()); + + /** + * This factory create array from vector of utf16 strings + * @return NDArray default dataType UTF16 + */ + static NDArray string( + const std::vector& shape, + const std::initializer_list& strings, + sd::DataType dtype = sd::DataType::UTF16, + sd::LaunchContext* context = sd::LaunchContext::defaultContext()); + static NDArray string( + const std::vector& shape, + const std::initializer_list& string, + sd::DataType dtype = sd::DataType::UTF16, + sd::LaunchContext* context = sd::LaunchContext::defaultContext()); + static NDArray string( + const std::vector& shape, + const std::vector& strings, + sd::DataType dtype = sd::DataType::UTF16, + sd::LaunchContext* context = sd::LaunchContext::defaultContext()); + static NDArray string( + const std::vector& shape, + const std::vector& string, + sd::DataType dtype = sd::DataType::UTF16, + sd::LaunchContext* context = sd::LaunchContext::defaultContext()); + static NDArray* string_( + const std::vector& shape, + const std::initializer_list& strings, + sd::DataType dtype = sd::DataType::UTF16, + sd::LaunchContext* context = sd::LaunchContext::defaultContext()); + static NDArray* string_( + const std::vector& shape, + const std::initializer_list& string, + sd::DataType dtype = sd::DataType::UTF16, + sd::LaunchContext* context = sd::LaunchContext::defaultContext()); + static NDArray* string_( + const std::vector& shape, + const std::vector& strings, + sd::DataType dtype = sd::DataType::UTF16, + sd::LaunchContext* context = sd::LaunchContext::defaultContext()); + static NDArray* string_( + const std::vector& shape, + const std::vector& string, + sd::DataType dtype = sd::DataType::UTF16, + sd::LaunchContext* context = sd::LaunchContext::defaultContext()); + + /** + * This factory create array from vector of utf32 strings + * @return NDArray default dataType UTF32 + */ + static NDArray string( + const std::vector& shape, + const std::initializer_list& strings, + sd::DataType dtype = sd::DataType::UTF32, + sd::LaunchContext* context = sd::LaunchContext::defaultContext()); + static NDArray string( + const std::vector& shape, + const std::initializer_list& string, + sd::DataType dtype = sd::DataType::UTF32, + sd::LaunchContext* context = sd::LaunchContext::defaultContext()); + static NDArray string( + const std::vector& shape, + const std::vector& strings, + sd::DataType dtype = sd::DataType::UTF32, + sd::LaunchContext* context = sd::LaunchContext::defaultContext()); + static NDArray string( + const std::vector& shape, + const std::vector& string, + sd::DataType dtype = sd::DataType::UTF32, + sd::LaunchContext* context = sd::LaunchContext::defaultContext()); + static NDArray* string_( + const std::vector& shape, + const std::initializer_list& strings, + sd::DataType dtype = sd::DataType::UTF32, + sd::LaunchContext* context = sd::LaunchContext::defaultContext()); + static NDArray* string_( + const std::vector& shape, + const std::initializer_list& string, + sd::DataType dtype = sd::DataType::UTF32, + sd::LaunchContext* context = sd::LaunchContext::defaultContext()); + static NDArray* string_( + const std::vector& shape, + const std::vector& strings, + sd::DataType dtype = sd::DataType::UTF32, + sd::LaunchContext* context = sd::LaunchContext::defaultContext()); + static NDArray* string_( + const std::vector& shape, + const std::vector& string, + sd::DataType dtype = sd::DataType::UTF32, + sd::LaunchContext* context = sd::LaunchContext::defaultContext()); + + static ResultSet createSetOfArrs( + const Nd4jLong numOfArrs, const void* buffer, const Nd4jLong* shapeInfo, + const Nd4jLong* offsets, + sd::LaunchContext* context = sd::LaunchContext ::defaultContext()); #endif - }; -} +}; +} // namespace sd -#endif //DEV_TESTS_NDARRAYFACTORY_H +#endif // SD_NDARRAYFACTORY_H diff --git a/libnd4j/include/array/NDArrayLambda.hXX b/libnd4j/include/array/NDArrayLambda.hXX index f213b6aa6a96..6e1289a45683 100644 --- a/libnd4j/include/array/NDArrayLambda.hXX +++ b/libnd4j/include/array/NDArrayLambda.hXX @@ -17,306 +17,401 @@ #ifndef CUDA_LAMBDA_HELPER #define CUDA_LAMBDA_HELPER -#include -#include -#include #include #include +#include +#include +#include -static Nd4jLong __device__ __noinline__ getIndexOffset(Nd4jLong index, const Nd4jLong *shapeInfo) { - return shape::getIndexOffset(index, shapeInfo); +static Nd4jLong __device__ __noinline__ +getIndexOffset(Nd4jLong index, const Nd4jLong *shapeInfo) { + return shape::getIndexOffset(index, shapeInfo); } static Nd4jLong __device__ __noinline__ length(const Nd4jLong *shapeInfo) { - return shape::length(shapeInfo); + return shape::length(shapeInfo); } -template static _CUDA_G void lambdaKernel(const void* vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo, Lambda lambda); -template static _CUDA_G void lambdaIndexedKernel(const void* vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo, Lambda lambda); -template static _CUDA_G void lambdaIndexedPairwiseKernel(const void* vx, const Nd4jLong *xShapeInfo, const void* vy, const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo, Lambda lambda); -template static _CUDA_G void lambdaPairwiseKernel(const void* vx, const Nd4jLong *xShapeInfo, const void* vy, const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo, Lambda lambda); -template static _CUDA_G void lambdaTriplewiseKernel(const void* vw, const Nd4jLong *wShapeInfo, const void* vx, const Nd4jLong *xShapeInfo, const void* vy, const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo, Lambda lambda); +template +static _CUDA_G void lambdaKernel(const void *vx, const Nd4jLong *xShapeInfo, + void *vz, const Nd4jLong *zShapeInfo, + Lambda lambda); +template +static _CUDA_G void lambdaIndexedKernel(const void *vx, + const Nd4jLong *xShapeInfo, void *vz, + const Nd4jLong *zShapeInfo, + Lambda lambda); +template +static _CUDA_G void lambdaIndexedPairwiseKernel( + const void *vx, const Nd4jLong *xShapeInfo, const void *vy, + const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo, + Lambda lambda); +template +static _CUDA_G void lambdaPairwiseKernel(const void *vx, + const Nd4jLong *xShapeInfo, + const void *vy, + const Nd4jLong *yShapeInfo, void *vz, + const Nd4jLong *zShapeInfo, + Lambda lambda); +template +static _CUDA_G void lambdaTriplewiseKernel( + const void *vw, const Nd4jLong *wShapeInfo, const void *vx, + const Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo, + void *vz, const Nd4jLong *zShapeInfo, Lambda lambda); template class LambdaHelper { -public: - - template - FORCEINLINE static void lambdaLauncher(cudaStream_t *stream, const void* vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo, Lambda lambda) { - lambdaKernel<<<256, 512, 1024, *stream>>>(vx, xShapeInfo, vz, zShapeInfo, lambda); - auto err = cudaStreamSynchronize(*stream); - if (err != 0) - throw std::runtime_error("NDArray::applyLambda execution failed"); - } - - template - FORCEINLINE static void lambdaIndexedLauncher(cudaStream_t *stream, const void* vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo, Lambda lambda) { - lambdaIndexedKernel<<<256, 512, 1024, *stream>>>(vx, xShapeInfo, vz, zShapeInfo, lambda); - auto err = cudaStreamSynchronize(*stream); - if (err != 0) - throw std::runtime_error("NDArray::applyIndexedLambda execution failed"); - } - - template - FORCEINLINE static void lambdaPairwiseLauncher(cudaStream_t *stream, const void* vx, const Nd4jLong *xShapeInfo, const void* vy, const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo, Lambda lambda) { - lambdaPairwiseKernel<<<256, 512, 1024, *stream>>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, lambda); - auto err = cudaStreamSynchronize(*stream); - if (err != 0) - throw std::runtime_error("NDArray::applyPairwiseLambda execution failed"); - } - - template - FORCEINLINE static void lambdaIndexedPairwiseLauncher(cudaStream_t *stream, const void* vx, const Nd4jLong *xShapeInfo, const void* vy, const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo, Lambda lambda) { - lambdaIndexedPairwiseKernel<<<256, 512, 1024, *stream>>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, lambda); - auto err = cudaStreamSynchronize(*stream); - if (err != 0) - throw std::runtime_error("NDArray::applyIndexedPairwiseLambda execution failed"); - } - - template - FORCEINLINE static void lambdaTriplewiseLauncher(cudaStream_t *stream,const void* vw, const Nd4jLong *wShapeInfo, const void* vx, const Nd4jLong *xShapeInfo, const void* vy, const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo, Lambda lambda) { - lambdaTriplewiseKernel<<<256, 512, 1024, *stream>>>(vw, wShapeInfo, vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, lambda); - auto err = cudaStreamSynchronize(*stream); - if (err != 0) - throw std::runtime_error("NDArray::applyTriplewiseLambda execution failed"); - } + public: + template + FORCEINLINE static void lambdaLauncher(cudaStream_t *stream, const void *vx, + const Nd4jLong *xShapeInfo, void *vz, + const Nd4jLong *zShapeInfo, + Lambda lambda) { + lambdaKernel + <<<256, 512, 1024, *stream>>>(vx, xShapeInfo, vz, zShapeInfo, lambda); + auto err = cudaStreamSynchronize(*stream); + if (err != 0) + throw std::runtime_error("NDArray::applyLambda execution failed"); + } + + template + FORCEINLINE static void lambdaIndexedLauncher( + cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, + void *vz, const Nd4jLong *zShapeInfo, Lambda lambda) { + lambdaIndexedKernel + <<<256, 512, 1024, *stream>>>(vx, xShapeInfo, vz, zShapeInfo, lambda); + auto err = cudaStreamSynchronize(*stream); + if (err != 0) + throw std::runtime_error("NDArray::applyIndexedLambda execution failed"); + } + + template + FORCEINLINE static void lambdaPairwiseLauncher( + cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, + const void *vy, const Nd4jLong *yShapeInfo, void *vz, + const Nd4jLong *zShapeInfo, Lambda lambda) { + lambdaPairwiseKernel<<<256, 512, 1024, *stream>>>( + vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, lambda); + auto err = cudaStreamSynchronize(*stream); + if (err != 0) + throw std::runtime_error("NDArray::applyPairwiseLambda execution failed"); + } + + template + FORCEINLINE static void lambdaIndexedPairwiseLauncher( + cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, + const void *vy, const Nd4jLong *yShapeInfo, void *vz, + const Nd4jLong *zShapeInfo, Lambda lambda) { + lambdaIndexedPairwiseKernel<<<256, 512, 1024, *stream>>>( + vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, lambda); + auto err = cudaStreamSynchronize(*stream); + if (err != 0) + throw std::runtime_error( + "NDArray::applyIndexedPairwiseLambda execution failed"); + } + + template + FORCEINLINE static void lambdaTriplewiseLauncher( + cudaStream_t *stream, const void *vw, const Nd4jLong *wShapeInfo, + const void *vx, const Nd4jLong *xShapeInfo, const void *vy, + const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo, + Lambda lambda) { + lambdaTriplewiseKernel<<<256, 512, 1024, *stream>>>( + vw, wShapeInfo, vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, lambda); + auto err = cudaStreamSynchronize(*stream); + if (err != 0) + throw std::runtime_error( + "NDArray::applyTriplewiseLambda execution failed"); + } }; //////////////////////////////////////////////////////////////////////// template -static _CUDA_G void lambdaKernel(const void* vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo, Lambda lambda) { - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); +static _CUDA_G void lambdaKernel(const void *vx, const Nd4jLong *xShapeInfo, + void *vz, const Nd4jLong *zShapeInfo, + Lambda lambda) { + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); - auto xEws = shape::elementWiseStride(xShapeInfo); - auto zEws = shape::elementWiseStride(zShapeInfo); + auto xEws = shape::elementWiseStride(xShapeInfo); + auto zEws = shape::elementWiseStride(zShapeInfo); - auto xOrder = shape::order(xShapeInfo); - auto zOrder = shape::order(zShapeInfo); + auto xOrder = shape::order(xShapeInfo); + auto zOrder = shape::order(zShapeInfo); - auto zLength = length(zShapeInfo); + auto zLength = length(zShapeInfo); - auto tid = threadIdx.x + blockIdx.x * blockDim.x; + auto tid = threadIdx.x + blockIdx.x * blockDim.x; - if (xEws >= 1 && zEws >= 1 && xOrder == zOrder) { - for (uint e = tid; e < zLength; e += blockDim.x * gridDim.x) - z[e * zEws] = lambda(x[e * xEws]); - } else { - for (uint e = tid; e < zLength; e += blockDim.x * gridDim.x) { - auto xOffset = getIndexOffset(e, xShapeInfo); - auto zOffset = getIndexOffset(e, zShapeInfo); + if (xEws >= 1 && zEws >= 1 && xOrder == zOrder) { + for (uint e = tid; e < zLength; e += blockDim.x * gridDim.x) + z[e * zEws] = lambda(x[e * xEws]); + } else { + for (uint e = tid; e < zLength; e += blockDim.x * gridDim.x) { + auto xOffset = getIndexOffset(e, xShapeInfo); + auto zOffset = getIndexOffset(e, zShapeInfo); - z[zOffset] = lambda(x[xOffset]); - } + z[zOffset] = lambda(x[xOffset]); } + } } //////////////////////////////////////////////////////////////////////// template -static _CUDA_G void lambdaIndexedKernel(const void* vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo, Lambda lambda) { - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); +static _CUDA_G void lambdaIndexedKernel(const void *vx, + const Nd4jLong *xShapeInfo, void *vz, + const Nd4jLong *zShapeInfo, + Lambda lambda) { + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); - auto xEws = shape::elementWiseStride(xShapeInfo); - auto zEws = shape::elementWiseStride(zShapeInfo); + auto xEws = shape::elementWiseStride(xShapeInfo); + auto zEws = shape::elementWiseStride(zShapeInfo); - auto xOrder = shape::order(xShapeInfo); - auto zOrder = shape::order(zShapeInfo); + auto xOrder = shape::order(xShapeInfo); + auto zOrder = shape::order(zShapeInfo); - auto zLength = length(zShapeInfo); + auto zLength = length(zShapeInfo); - auto tid = threadIdx.x + blockIdx.x * blockDim.x; + auto tid = threadIdx.x + blockIdx.x * blockDim.x; - if (xEws >= 1 && zEws >= 1 && xOrder == zOrder) { - for (uint e = tid; e < zLength; e += blockDim.x * gridDim.x) - z[e * zEws] = lambda(e, x[e * xEws]); - } else { - for (uint e = tid; e < zLength; e += blockDim.x * gridDim.x) { - auto xOffset = getIndexOffset(e, xShapeInfo); - auto zOffset = getIndexOffset(e, zShapeInfo); + if (xEws >= 1 && zEws >= 1 && xOrder == zOrder) { + for (uint e = tid; e < zLength; e += blockDim.x * gridDim.x) + z[e * zEws] = lambda(e, x[e * xEws]); + } else { + for (uint e = tid; e < zLength; e += blockDim.x * gridDim.x) { + auto xOffset = getIndexOffset(e, xShapeInfo); + auto zOffset = getIndexOffset(e, zShapeInfo); - z[zOffset] = lambda(e, x[xOffset]); - } + z[zOffset] = lambda(e, x[xOffset]); } + } } //////////////////////////////////////////////////////////////////////// template -static _CUDA_G void lambdaIndexedPairwiseKernel(const void* vx, const Nd4jLong *xShapeInfo, const void* vy, const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo, Lambda lambda) { - auto x = reinterpret_cast(vx); - auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); - - auto xEws = shape::elementWiseStride(xShapeInfo); - auto yEws = shape::elementWiseStride(yShapeInfo); - auto zEws = shape::elementWiseStride(zShapeInfo); - - auto xOrder = shape::order(xShapeInfo); - auto yOrder = shape::order(yShapeInfo); - auto zOrder = shape::order(zShapeInfo); - - auto zLength = length(zShapeInfo); - - auto tid = threadIdx.x + blockIdx.x * blockDim.x; - - if (xEws >= 1 && yEws >= 1 && zEws >= 1 && xOrder == zOrder && yOrder == xOrder) { - for (uint e = tid; e < zLength; e += blockDim.x * gridDim.x) - z[e * zEws] = lambda(e, x[e * xEws], y[e * yEws]); - } else { - for (uint e = tid; e < zLength; e += blockDim.x * gridDim.x) { - auto xOffset = getIndexOffset(e, xShapeInfo); - auto yOffset = getIndexOffset(e, yShapeInfo); - auto zOffset = getIndexOffset(e, zShapeInfo); - - z[zOffset] = lambda(e, x[xOffset], y[yOffset]); - } +static _CUDA_G void lambdaIndexedPairwiseKernel( + const void *vx, const Nd4jLong *xShapeInfo, const void *vy, + const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo, + Lambda lambda) { + auto x = reinterpret_cast(vx); + auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + + auto xEws = shape::elementWiseStride(xShapeInfo); + auto yEws = shape::elementWiseStride(yShapeInfo); + auto zEws = shape::elementWiseStride(zShapeInfo); + + auto xOrder = shape::order(xShapeInfo); + auto yOrder = shape::order(yShapeInfo); + auto zOrder = shape::order(zShapeInfo); + + auto zLength = length(zShapeInfo); + + auto tid = threadIdx.x + blockIdx.x * blockDim.x; + + if (xEws >= 1 && yEws >= 1 && zEws >= 1 && xOrder == zOrder && + yOrder == xOrder) { + for (uint e = tid; e < zLength; e += blockDim.x * gridDim.x) + z[e * zEws] = lambda(e, x[e * xEws], y[e * yEws]); + } else { + for (uint e = tid; e < zLength; e += blockDim.x * gridDim.x) { + auto xOffset = getIndexOffset(e, xShapeInfo); + auto yOffset = getIndexOffset(e, yShapeInfo); + auto zOffset = getIndexOffset(e, zShapeInfo); + + z[zOffset] = lambda(e, x[xOffset], y[yOffset]); } + } } //////////////////////////////////////////////////////////////////////// template -static _CUDA_G void lambdaPairwiseKernel(const void* vx, const Nd4jLong *xShapeInfo, const void* vy, const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo, Lambda lambda) { - auto x = reinterpret_cast(vx); - auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); - - auto xEws = shape::elementWiseStride(xShapeInfo); - auto yEws = shape::elementWiseStride(yShapeInfo); - auto zEws = shape::elementWiseStride(zShapeInfo); - - auto xOrder = shape::order(xShapeInfo); - auto yOrder = shape::order(yShapeInfo); - auto zOrder = shape::order(zShapeInfo); - - auto zLength = length(zShapeInfo); - - auto tid = threadIdx.x + blockIdx.x * blockDim.x; - - if (xEws >= 1 && yEws >= 1 && zEws >= 1 && xOrder == zOrder && yOrder == xOrder) { - for (uint e = tid; e < zLength; e += blockDim.x * gridDim.x) - z[e * zEws] = lambda(x[e * xEws], y[e * yEws]); - } else { - for (uint e = tid; e < zLength; e += blockDim.x * gridDim.x) { - auto xOffset = getIndexOffset(e, xShapeInfo); - auto yOffset = getIndexOffset(e, yShapeInfo); - auto zOffset = getIndexOffset(e, zShapeInfo); - - z[zOffset] = lambda(x[xOffset], y[yOffset]); - } +static _CUDA_G void lambdaPairwiseKernel(const void *vx, + const Nd4jLong *xShapeInfo, + const void *vy, + const Nd4jLong *yShapeInfo, void *vz, + const Nd4jLong *zShapeInfo, + Lambda lambda) { + auto x = reinterpret_cast(vx); + auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + + auto xEws = shape::elementWiseStride(xShapeInfo); + auto yEws = shape::elementWiseStride(yShapeInfo); + auto zEws = shape::elementWiseStride(zShapeInfo); + + auto xOrder = shape::order(xShapeInfo); + auto yOrder = shape::order(yShapeInfo); + auto zOrder = shape::order(zShapeInfo); + + auto zLength = length(zShapeInfo); + + auto tid = threadIdx.x + blockIdx.x * blockDim.x; + + if (xEws >= 1 && yEws >= 1 && zEws >= 1 && xOrder == zOrder && + yOrder == xOrder) { + for (uint e = tid; e < zLength; e += blockDim.x * gridDim.x) + z[e * zEws] = lambda(x[e * xEws], y[e * yEws]); + } else { + for (uint e = tid; e < zLength; e += blockDim.x * gridDim.x) { + auto xOffset = getIndexOffset(e, xShapeInfo); + auto yOffset = getIndexOffset(e, yShapeInfo); + auto zOffset = getIndexOffset(e, zShapeInfo); + + z[zOffset] = lambda(x[xOffset], y[yOffset]); } + } } //////////////////////////////////////////////////////////////////////// template -static _CUDA_G void lambdaTriplewiseKernel(const void* vw, const Nd4jLong *wShapeInfo, const void* vx, const Nd4jLong *xShapeInfo, const void* vy, const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo, Lambda lambda) { - auto w = reinterpret_cast(vw); - auto x = reinterpret_cast(vx); - auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); - - auto wEws = shape::elementWiseStride(wShapeInfo); - auto xEws = shape::elementWiseStride(xShapeInfo); - auto yEws = shape::elementWiseStride(yShapeInfo); - auto zEws = shape::elementWiseStride(zShapeInfo); - - auto wOrder = shape::order(wShapeInfo); - auto xOrder = shape::order(xShapeInfo); - auto yOrder = shape::order(yShapeInfo); - auto zOrder = shape::order(zShapeInfo); - - auto zLength = length(zShapeInfo); - - auto tid = threadIdx.x + blockIdx.x * blockDim.x; - - if (wEws > 1 && xEws >= 1 && yEws >= 1 && zEws >= 1 && xOrder == zOrder && yOrder == xOrder && wOrder == xOrder) { - for (uint e = tid; e < zLength; e += blockDim.x * gridDim.x) - z[e * zEws] = lambda(w[e * wEws], x[e * xEws], y[e * yEws]); - } else { - for (uint e = tid; e < zLength; e += blockDim.x * gridDim.x) { - auto wOffset = getIndexOffset(e, wShapeInfo); - auto xOffset = getIndexOffset(e, xShapeInfo); - auto yOffset = getIndexOffset(e, yShapeInfo); - auto zOffset = getIndexOffset(e, zShapeInfo); - - z[zOffset] = lambda(w[wOffset], x[xOffset], y[yOffset]); - } +static _CUDA_G void lambdaTriplewiseKernel( + const void *vw, const Nd4jLong *wShapeInfo, const void *vx, + const Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo, + void *vz, const Nd4jLong *zShapeInfo, Lambda lambda) { + auto w = reinterpret_cast(vw); + auto x = reinterpret_cast(vx); + auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + + auto wEws = shape::elementWiseStride(wShapeInfo); + auto xEws = shape::elementWiseStride(xShapeInfo); + auto yEws = shape::elementWiseStride(yShapeInfo); + auto zEws = shape::elementWiseStride(zShapeInfo); + + auto wOrder = shape::order(wShapeInfo); + auto xOrder = shape::order(xShapeInfo); + auto yOrder = shape::order(yShapeInfo); + auto zOrder = shape::order(zShapeInfo); + + auto zLength = length(zShapeInfo); + + auto tid = threadIdx.x + blockIdx.x * blockDim.x; + + if (wEws > 1 && xEws >= 1 && yEws >= 1 && zEws >= 1 && xOrder == zOrder && + yOrder == xOrder && wOrder == xOrder) { + for (uint e = tid; e < zLength; e += blockDim.x * gridDim.x) + z[e * zEws] = lambda(w[e * wEws], x[e * xEws], y[e * yEws]); + } else { + for (uint e = tid; e < zLength; e += blockDim.x * gridDim.x) { + auto wOffset = getIndexOffset(e, wShapeInfo); + auto xOffset = getIndexOffset(e, xShapeInfo); + auto yOffset = getIndexOffset(e, yShapeInfo); + auto zOffset = getIndexOffset(e, zShapeInfo); + + z[zOffset] = lambda(w[wOffset], x[xOffset], y[yOffset]); } + } } #endif ////////////////////////////////////////////////////////////////////////// -template -void NDArray::applyLambda(Lambda func, NDArray& target) { - - auto dtype = this->dataType(); - - if (dtype != target.dataType()) - throw std::runtime_error("NDArray::applyLambda X/Z data types must be the same"); - //throw datatype_exception::build("NDArray::applyLambda X/Z data types must be the same", dtype, target.dataType()); - prepareSpecialUse({&target}, {this}); - BUILD_SINGLE_SELECTOR(dtype, LambdaHelper ,::lambdaLauncher(this->_context->getCudaStream(), this->specialBuffer(), this->specialShapeInfo(), target.specialBuffer(), target.specialShapeInfo(), func), LIBND4J_TYPES); - registerSpecialUse({&target}, {this}); - +template +void NDArray::applyLambda(Lambda func, NDArray &target) { + auto dtype = this->dataType(); + + if (dtype != target.dataType()) + throw std::runtime_error( + "NDArray::applyLambda X/Z data types must be the same"); + // throw datatype_exception::build("NDArray::applyLambda X/Z data types must + // be the same", dtype, target.dataType()); + prepareSpecialUse({&target}, {this}); + BUILD_SINGLE_SELECTOR( + dtype, LambdaHelper, + ::lambdaLauncher(this->_context->getCudaStream(), this->specialBuffer(), + this->specialShapeInfo(), target.specialBuffer(), + target.specialShapeInfo(), func), + LIBND4J_TYPES); + registerSpecialUse({&target}, {this}); } ////////////////////////////////////////////////////////////////////////// -template -void NDArray::applyPairwiseLambda(const NDArray& other, Lambda func, NDArray& target) { - - auto dtype = this->dataType(); - - if (dtype != target.dataType() || dtype != other.dataType()) - throw std::runtime_error("NDArray::applyPairwiseLambda X/Y/Z data types must be the same"); - //throw datatype_exception::build("NDArray::applyLambda X/Z data types must be the same", dtype, target.dataType()); - - prepareSpecialUse({&target}, {this, &other}); - BUILD_SINGLE_SELECTOR(dtype, LambdaHelper ,::lambdaPairwiseLauncher(this->_context->getCudaStream(), this->specialBuffer(), this->specialShapeInfo(), other.specialBuffer(), other.specialShapeInfo(), target.specialBuffer(), target.specialShapeInfo(), func), LIBND4J_TYPES); - registerSpecialUse({&target}, {this, &other}); - +template +void NDArray::applyPairwiseLambda(const NDArray &other, Lambda func, + NDArray &target) { + auto dtype = this->dataType(); + + if (dtype != target.dataType() || dtype != other.dataType()) + throw std::runtime_error( + "NDArray::applyPairwiseLambda X/Y/Z data types must be the same"); + // throw datatype_exception::build("NDArray::applyLambda X/Z data types must + // be the same", dtype, target.dataType()); + + prepareSpecialUse({&target}, {this, &other}); + BUILD_SINGLE_SELECTOR( + dtype, LambdaHelper, + ::lambdaPairwiseLauncher(this->_context->getCudaStream(), + this->specialBuffer(), this->specialShapeInfo(), + other.specialBuffer(), other.specialShapeInfo(), + target.specialBuffer(), + target.specialShapeInfo(), func), + LIBND4J_TYPES); + registerSpecialUse({&target}, {this, &other}); } ////////////////////////////////////////////////////////////////////////// template -void NDArray::applyIndexedLambda(Lambda func, NDArray& target) { - - auto dtype = this->dataType(); - if (dtype != target.dataType()) - throw std::runtime_error("NDArray::applyIndexedLambda X/Z data types must be the same"); - - prepareSpecialUse({&target}, {this}); - BUILD_SINGLE_SELECTOR(dtype, LambdaHelper ,::lambdaIndexedLauncher(this->_context->getCudaStream(), this->specialBuffer(), this->specialShapeInfo(), target.specialBuffer(), target.specialShapeInfo(), func), LIBND4J_TYPES); - registerSpecialUse({&target}, {this}); +void NDArray::applyIndexedLambda(Lambda func, NDArray &target) { + auto dtype = this->dataType(); + if (dtype != target.dataType()) + throw std::runtime_error( + "NDArray::applyIndexedLambda X/Z data types must be the same"); + + prepareSpecialUse({&target}, {this}); + BUILD_SINGLE_SELECTOR( + dtype, LambdaHelper, + ::lambdaIndexedLauncher(this->_context->getCudaStream(), + this->specialBuffer(), this->specialShapeInfo(), + target.specialBuffer(), target.specialShapeInfo(), + func), + LIBND4J_TYPES); + registerSpecialUse({&target}, {this}); } ////////////////////////////////////////////////////////////////////////// template -void NDArray::applyIndexedPairwiseLambda(NDArray& other, Lambda func, NDArray& target) { - - auto dtype = this->dataType(); - if (dtype != target.dataType() || dtype != other.dataType()) - throw std::runtime_error("NDArray::applyIndexedPairwiseLambda X/Y/Z data types must be the same"); - - prepareSpecialUse({&target}, {this, &other}); - BUILD_SINGLE_SELECTOR(dtype, LambdaHelper ,::lambdaIndexedPairwiseLauncher(this->_context->getCudaStream(), this->specialBuffer(), this->specialShapeInfo(), other.specialBuffer(), other.specialShapeInfo(), target.specialBuffer(), target.specialShapeInfo(), func), LIBND4J_TYPES); - registerSpecialUse({&target}, {this, &other}); +void NDArray::applyIndexedPairwiseLambda(NDArray &other, Lambda func, + NDArray &target) { + auto dtype = this->dataType(); + if (dtype != target.dataType() || dtype != other.dataType()) + throw std::runtime_error( + "NDArray::applyIndexedPairwiseLambda X/Y/Z data types must be the " + "same"); + + prepareSpecialUse({&target}, {this, &other}); + BUILD_SINGLE_SELECTOR( + dtype, LambdaHelper, + ::lambdaIndexedPairwiseLauncher( + this->_context->getCudaStream(), this->specialBuffer(), + this->specialShapeInfo(), other.specialBuffer(), + other.specialShapeInfo(), target.specialBuffer(), + target.specialShapeInfo(), func), + LIBND4J_TYPES); + registerSpecialUse({&target}, {this, &other}); } ////////////////////////////////////////////////////////////////////////// template -void NDArray::applyTriplewiseLambda(NDArray& second, NDArray& third, Lambda func, NDArray& target) { - - auto dtype = this->dataType(); - - if (dtype != target.dataType() || dtype != second.dataType() || dtype != third.dataType()) - throw std::runtime_error("NDArray::applyTriplewiseLambda X/Y/Z data types must be the same"); - - prepareSpecialUse({&target}, {this, &second, &third}); - BUILD_SINGLE_SELECTOR(dtype, LambdaHelper ,::lambdaTriplewiseLauncher(this->_context->getCudaStream(), this->specialBuffer(), this->specialShapeInfo(), second.specialBuffer(), second.specialShapeInfo(), third.specialBuffer(), third.specialShapeInfo(), target.specialBuffer(), target.specialShapeInfo(), func), LIBND4J_TYPES); - registerSpecialUse({&target}, {this, &second, &third}); +void NDArray::applyTriplewiseLambda(NDArray &second, NDArray &third, + Lambda func, NDArray &target) { + auto dtype = this->dataType(); + + if (dtype != target.dataType() || dtype != second.dataType() || + dtype != third.dataType()) + throw std::runtime_error( + "NDArray::applyTriplewiseLambda X/Y/Z data types must be the same"); + + prepareSpecialUse({&target}, {this, &second, &third}); + BUILD_SINGLE_SELECTOR( + dtype, LambdaHelper, + ::lambdaTriplewiseLauncher( + this->_context->getCudaStream(), this->specialBuffer(), + this->specialShapeInfo(), second.specialBuffer(), + second.specialShapeInfo(), third.specialBuffer(), + third.specialShapeInfo(), target.specialBuffer(), + target.specialShapeInfo(), func), + LIBND4J_TYPES); + registerSpecialUse({&target}, {this, &second, &third}); } - - - - - diff --git a/libnd4j/include/array/NDArrayList.h b/libnd4j/include/array/NDArrayList.h index e446213f2ec6..da63086474a7 100644 --- a/libnd4j/include/array/NDArrayList.h +++ b/libnd4j/include/array/NDArrayList.h @@ -23,76 +23,90 @@ #ifndef NDARRAY_LIST_H #define NDARRAY_LIST_H -#include -#include -#include #include #include #include +#include +#include +#include + namespace sd { - class ND4J_EXPORT NDArrayList { - private: - // workspace where chunks belong to - //sd::memory::Workspace* _workspace = nullptr; - sd::LaunchContext * _context = sd::LaunchContext ::defaultContext(); +class SD_EXPORT NDArrayList { + protected: + class InternalArrayList { + public: + // numeric and symbolic ids of this list + std::pair _id; + std::string _name; + + sd::DataType _dtype; + + // stored chunks + MAP_IMPL _chunks; + + // just a counter, for stored elements + std::atomic _elements; + mutable std::atomic _counter; + + // reference shape + std::vector _shape; + + // unstack axis + int _axis = 0; + + // + bool _expandable = false; - // numeric and symbolic ids of this list - std::pair _id; - std::string _name; + // maximum number of elements + int _height = 0; - sd::DataType _dtype; + ////////// + InternalArrayList(int height = 0, bool expandable = false); + ~InternalArrayList() = default; + }; - // stored chunks - MAP_IMPL _chunks; + std::shared_ptr _state; - // just a counter, for stored elements - std::atomic _elements; - std::atomic _counter; + public: + NDArrayList(int height = 0, bool expandable = false); + ~NDArrayList(); - // reference shape - std::vector _shape; + NDArrayList(const sd::NDArrayList& other); + NDArrayList(sd::NDArrayList&& other); - // unstack axis - int _axis = 0; + NDArrayList& operator=(const NDArrayList& other) noexcept; - // - bool _expandable = false; + // move assignment operator + NDArrayList& operator=(NDArrayList&& other) noexcept; - // maximum number of elements - int _height = 0; - public: - NDArrayList(int height, bool expandable = false); - ~NDArrayList(); + sd::DataType dataType() const; - sd::DataType dataType(); + NDArray read(int idx); + NDArray readRaw(int idx); + Nd4jStatus write(int idx, const NDArray& array); - NDArray* read(int idx); - NDArray* readRaw(int idx); - Nd4jStatus write(int idx, NDArray* array); + NDArray pick(const std::vector& indices); + bool isWritten(int index) const; - NDArray* pick(std::initializer_list indices); - NDArray* pick(std::vector& indices); - bool isWritten(int index); + const std::vector& shape() const; + void setShape(const std::vector& shape); - std::vector& shape(); + NDArray stack() const; + void unstack(const NDArray& array, int axis); - NDArray* stack(); - void unstack(NDArray* array, int axis); + const std::pair& id() const; + const std::string& name() const; - std::pair& id(); - std::string& name(); - //sd::memory::Workspace* workspace(); - sd::LaunchContext * context(); - NDArrayList* clone(); + NDArrayList clone(); - bool equals(NDArrayList& other); + bool equals(NDArrayList& other); - int elements(); - int height(); + int elements() const; + int height() const; - int counter(); - }; -} + int counter() const; +}; +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/array/PointerDeallocator.h b/libnd4j/include/array/PointerDeallocator.h index 5bf820421713..ec3e095528d0 100644 --- a/libnd4j/include/array/PointerDeallocator.h +++ b/libnd4j/include/array/PointerDeallocator.h @@ -26,7 +26,7 @@ namespace sd { -class ND4J_EXPORT PointerDeallocator { +class SD_EXPORT PointerDeallocator { public: PointerDeallocator() = default; ~PointerDeallocator() = default; diff --git a/libnd4j/include/array/PointerWrapper.h b/libnd4j/include/array/PointerWrapper.h index 9e15aaaa3398..dad090593d19 100644 --- a/libnd4j/include/array/PointerWrapper.h +++ b/libnd4j/include/array/PointerWrapper.h @@ -27,7 +27,7 @@ #include namespace sd { -class ND4J_EXPORT PointerWrapper { +class SD_EXPORT PointerWrapper { private: void* _pointer = nullptr; std::shared_ptr _deallocator; diff --git a/libnd4j/include/array/PrimaryPointerDeallocator.h b/libnd4j/include/array/PrimaryPointerDeallocator.h index b4fe34764560..6c8e0f553574 100644 --- a/libnd4j/include/array/PrimaryPointerDeallocator.h +++ b/libnd4j/include/array/PrimaryPointerDeallocator.h @@ -26,7 +26,7 @@ #include namespace sd { -class ND4J_EXPORT PrimaryPointerDeallocator : public PointerDeallocator { +class SD_EXPORT PrimaryPointerDeallocator : public PointerDeallocator { public: PrimaryPointerDeallocator() = default; ~PrimaryPointerDeallocator() = default; diff --git a/libnd4j/include/array/ResultSet.h b/libnd4j/include/array/ResultSet.h index 6c80e7b1816a..6b6dc9c0e496 100644 --- a/libnd4j/include/array/ResultSet.h +++ b/libnd4j/include/array/ResultSet.h @@ -25,52 +25,53 @@ #ifndef LIBND4J_RESULTSET_H #define LIBND4J_RESULTSET_H -#include #include -#include #include +#include + +#include namespace sd { - class NDArray; // forward declaration of template class NDArray +class NDArray; // forward declaration of template class NDArray - class ND4J_EXPORT ResultSet { - private: - std::vector _content; - Nd4jStatus _status = ND4J_STATUS_OK; - bool _removable = true; +class SD_EXPORT ResultSet { + private: + std::vector _content; + Nd4jStatus _status = ND4J_STATUS_OK; + bool _removable = true; - void delContent(); + void delContent(); - public: - explicit ResultSet(); + public: + explicit ResultSet(); #ifndef __JAVACPP_HACK__ - ResultSet(const sd::graph::FlatResult* result); + ResultSet(const sd::graph::FlatResult* result); #endif - ResultSet(const ResultSet& other) noexcept; + ResultSet(const ResultSet& other) noexcept; - ResultSet& operator=(const ResultSet& other) noexcept; + ResultSet& operator=(const ResultSet& other) noexcept; - // move constructor - ResultSet(ResultSet&& other) noexcept; + // move constructor + ResultSet(ResultSet&& other) noexcept; - // move assignment operator - ResultSet& operator=(ResultSet&& other) noexcept; + // move assignment operator + ResultSet& operator=(ResultSet&& other) noexcept; - ~ResultSet(); + ~ResultSet(); - int size(); - sd::NDArray* at(const unsigned long idx) const; - sd::NDArray* operator[](const unsigned long idx) const; - void push_back(sd::NDArray* array); + int size(); + sd::NDArray& at(const unsigned long idx) const; + sd::NDArray& operator[](const unsigned long idx) const; + void push_back(const sd::NDArray& array); - Nd4jStatus status(); - void setStatus(Nd4jStatus status); - void purge(); - void setNonRemovable(); - }; -} + Nd4jStatus status(); + void setStatus(Nd4jStatus status); + void purge(); + void setNonRemovable(); +}; +} // namespace sd -#endif //LIBND4J_RESULTSET_H +#endif // LIBND4J_RESULTSET_H diff --git a/libnd4j/include/array/ShapeDescriptor.h b/libnd4j/include/array/ShapeDescriptor.h index 6e2299ba08cb..83b9c875bac2 100644 --- a/libnd4j/include/array/ShapeDescriptor.h +++ b/libnd4j/include/array/ShapeDescriptor.h @@ -18,86 +18,99 @@ // @author raver119@gmail.com // -#ifndef DEV_TESTS_SHAPEDESCRIPTOR_H -#define DEV_TESTS_SHAPEDESCRIPTOR_H +#ifndef SD_SHAPEDESCRIPTOR_H +#define SD_SHAPEDESCRIPTOR_H -#include -#include +#include #include #include -#include + #include +#include +#include namespace sd { -class ND4J_EXPORT ShapeDescriptor { - - private: - int _rank = 0; - std::vector _shape; - std::vector _strides; - Nd4jLong _ews = 1; - char _order = 'c'; - DataType _dataType; - bool _empty = false; - - public: - ShapeDescriptor(const ShapeDescriptor &other); - ShapeDescriptor(const Nd4jLong *shapeInfo, bool inheritDtype = true); - explicit ShapeDescriptor(const Nd4jLong *shapeInfo, const sd::DataType dtypeOverride); - explicit ShapeDescriptor(const Nd4jLong *shapeInfo, const Nd4jLong *dtypeOverride); - explicit ShapeDescriptor(const Nd4jLong *shapeInfo, const Nd4jLong *dtypeOverride, const Nd4jLong *orderOverride); - explicit ShapeDescriptor(const DataType type, const Nd4jLong length); - explicit ShapeDescriptor(const DataType type, const char order, const Nd4jLong *shape, const int rank); - explicit ShapeDescriptor(const DataType type, const char order, const Nd4jLong *shape, const Nd4jLong *strides, const int rank, Nd4jLong ews, const bool empty); - explicit ShapeDescriptor(const DataType type, const char order, const std::initializer_list &shape); - explicit ShapeDescriptor(const DataType type, const char order, const std::vector &shape); - explicit ShapeDescriptor(const DataType type, const char order, const std::vector &shape, const std::vector &strides); - explicit ShapeDescriptor(const DataType type, const char order, const std::vector &shape, const std::vector &strides, const Nd4jLong ews); - ShapeDescriptor() = default; - ~ShapeDescriptor() = default; - - int rank() const; - Nd4jLong ews() const; - Nd4jLong arrLength() const; - char order() const; - DataType dataType() const; - bool isEmpty() const; - std::vector& shape(); - std::vector& strides(); - - // we use default copy assignment operator - ShapeDescriptor& operator=(const ShapeDescriptor& other) = default; - - // we use default move assignment operator - ShapeDescriptor& operator=(ShapeDescriptor&& other) noexcept = default; - - // equal to operator - bool operator==(const ShapeDescriptor &other) const; - - // less than operator - bool operator<(const ShapeDescriptor &other) const; - - Nd4jLong* toShapeInfo() const; - - - static ShapeDescriptor emptyDescriptor(const DataType type); - static ShapeDescriptor scalarDescriptor(const DataType type); - static ShapeDescriptor vectorDescriptor(const Nd4jLong length, const DataType type); - }; -} +class SD_EXPORT ShapeDescriptor { + private: + int _rank = 0; + std::vector _shape; + std::vector _strides; + Nd4jLong _ews = 1; + char _order = 'c'; + DataType _dataType; + bool _empty = false; + + public: + ShapeDescriptor(const ShapeDescriptor &other); + ShapeDescriptor(const Nd4jLong *shapeInfo, bool inheritDtype = true); + explicit ShapeDescriptor(const Nd4jLong *shapeInfo, + const sd::DataType dtypeOverride); + explicit ShapeDescriptor(const Nd4jLong *shapeInfo, + const Nd4jLong *dtypeOverride); + explicit ShapeDescriptor(const Nd4jLong *shapeInfo, + const Nd4jLong *dtypeOverride, + const Nd4jLong *orderOverride); + explicit ShapeDescriptor(const DataType type, const Nd4jLong length); + explicit ShapeDescriptor(const DataType type, const char order, + const Nd4jLong *shape, const int rank); + explicit ShapeDescriptor(const DataType type, const char order, + const Nd4jLong *shape, const Nd4jLong *strides, + const int rank, Nd4jLong ews, const bool empty); + explicit ShapeDescriptor(const DataType type, const char order, + const std::initializer_list &shape); + explicit ShapeDescriptor(const DataType type, const char order, + const std::vector &shape); + explicit ShapeDescriptor(const DataType type, const char order, + const std::vector &shape, + const std::vector &strides); + explicit ShapeDescriptor(const DataType type, const char order, + const std::vector &shape, + const std::vector &strides, + const Nd4jLong ews); + ShapeDescriptor() = default; + ~ShapeDescriptor() = default; + + int rank() const; + Nd4jLong ews() const; + Nd4jLong arrLength() const; + char order() const; + DataType dataType() const; + bool isEmpty() const; + std::vector &shape(); + std::vector &strides(); + + // we use default copy assignment operator + ShapeDescriptor &operator=(const ShapeDescriptor &other) = default; + + // we use default move assignment operator + ShapeDescriptor &operator=(ShapeDescriptor &&other) noexcept = default; + + // equal to operator + bool operator==(const ShapeDescriptor &other) const; + + // less than operator + bool operator<(const ShapeDescriptor &other) const; + + Nd4jLong *toShapeInfo() const; + + static ShapeDescriptor emptyDescriptor(const DataType type); + static ShapeDescriptor scalarDescriptor(const DataType type); + static ShapeDescriptor vectorDescriptor(const Nd4jLong length, + const DataType type); +}; +} // namespace sd #ifndef __JAVACPP_HACK__ namespace std { - template<> - class ND4J_EXPORT hash { - public: - size_t operator()(const sd::ShapeDescriptor &k) const; - }; -} +template <> +class SD_EXPORT hash { + public: + size_t operator()(const sd::ShapeDescriptor &k) const; +}; +} // namespace std #endif - -#endif //DEV_TESTS_SHAPEDESCRIPTOR_H +#endif // SD_SHAPEDESCRIPTOR_H diff --git a/libnd4j/include/array/ShapeList.h b/libnd4j/include/array/ShapeList.h index f0034ac81487..1e5e697b9ca0 100644 --- a/libnd4j/include/array/ShapeList.h +++ b/libnd4j/include/array/ShapeList.h @@ -21,38 +21,40 @@ #ifndef LIBND4J_SHAPELIST_H #define LIBND4J_SHAPELIST_H -#include #include #include +#include + namespace sd { - class ND4J_EXPORT ShapeList { - protected: - std::vector _shapes; - - bool _destroyed = false; - bool _autoremovable = false; - bool _workspace = false; - public: - ShapeList(const Nd4jLong* shape = nullptr); - ShapeList(const std::vector &shapes, bool isWorkspace); - ShapeList(const std::vector& shapes); - //ShapeList(bool autoRemovable); - - ~ShapeList(); - - std::vector* asVector(); - void destroy(); - int size() const; - const Nd4jLong* at(int idx); - void push_back(const Nd4jLong *shape); - - /** - * PLEASE NOTE: This method should be called ONLY if shapes were generated at workspaces. Otherwise you'll get memory leak - */ - void detach(); - }; -} - - -#endif //LIBND4J_SHAPELIST_H +class SD_EXPORT ShapeList { + protected: + std::vector _shapes; + + bool _destroyed = false; + bool _autoremovable = false; + bool _workspace = false; + + public: + ShapeList(const Nd4jLong* shape = nullptr); + ShapeList(const std::vector& shapes, bool isWorkspace); + ShapeList(const std::vector& shapes); + // ShapeList(bool autoRemovable); + + ~ShapeList(); + + std::vector* asVector(); + void destroy(); + int size() const; + const Nd4jLong* at(int idx); + void push_back(const Nd4jLong* shape); + + /** + * PLEASE NOTE: This method should be called ONLY if shapes were generated at + * workspaces. Otherwise you'll get memory leak + */ + void detach(); +}; +} // namespace sd + +#endif // LIBND4J_SHAPELIST_H diff --git a/libnd4j/include/array/SpaceType.h b/libnd4j/include/array/SpaceType.h index b6c6dfbbcfbc..8c65f69850b3 100644 --- a/libnd4j/include/array/SpaceType.h +++ b/libnd4j/include/array/SpaceType.h @@ -22,11 +22,11 @@ #define ND4J_SPACE_TYPE_H namespace sd { - enum SpaceType { - CONTINUOUS = 1, - COMPLEX = 2, - QUANTIZED = 3, - }; +enum SpaceType { + CONTINUOUS = 1, + COMPLEX = 2, + QUANTIZED = 3, +}; } #endif \ No newline at end of file diff --git a/libnd4j/include/array/SparseType.h b/libnd4j/include/array/SparseType.h index 3b77a1626424..c9084d8292f3 100644 --- a/libnd4j/include/array/SparseType.h +++ b/libnd4j/include/array/SparseType.h @@ -22,12 +22,12 @@ #define LIBND4J_SPARSETYPE_H namespace sd { - enum SparseType { - CSR = 1, - CSC = 2, - COO = 3, - LIL = 4, - }; +enum SparseType { + CSR = 1, + CSC = 2, + COO = 3, + LIL = 4, +}; } -#endif //LIBND4J_SPARSETYPE_H +#endif // LIBND4J_SPARSETYPE_H diff --git a/libnd4j/include/array/TadDescriptor.h b/libnd4j/include/array/TadDescriptor.h index 01ea1caa1270..4417c2bd17be 100644 --- a/libnd4j/include/array/TadDescriptor.h +++ b/libnd4j/include/array/TadDescriptor.h @@ -18,57 +18,61 @@ // @author raver119@gmail.com // -#ifndef DEV_TESTS_TADDESCRIPTOR_H -#define DEV_TESTS_TADDESCRIPTOR_H +#ifndef SD_TADDESCRIPTOR_H +#define SD_TADDESCRIPTOR_H -#include "ShapeDescriptor.h" #include +#include "ShapeDescriptor.h" + namespace sd { - class ND4J_EXPORT TadDescriptor { - private: - ShapeDescriptor _originalShape; +class SD_EXPORT TadDescriptor { + private: + ShapeDescriptor _originalShape; - std::vector _axis; + std::vector _axis; - bool _unitiesInShape; + bool _unitiesInShape; - public: - explicit TadDescriptor(const Nd4jLong *originalShape, const int *dimensions, const int length, const bool keepUnitiesInShape = false); - explicit TadDescriptor(const ShapeDescriptor &descriptor, const std::vector &dimensions, const bool keepUnitiesInShape = false); - explicit TadDescriptor(const TadDescriptor &other); - ~TadDescriptor() = default; + public: + explicit TadDescriptor(const Nd4jLong *originalShape, const int *dimensions, + const int length, + const bool keepUnitiesInShape = false); + explicit TadDescriptor(const ShapeDescriptor &descriptor, + const std::vector &dimensions, + const bool keepUnitiesInShape = false); + explicit TadDescriptor(const TadDescriptor &other); + ~TadDescriptor() = default; - // we use default copy assignment operator - TadDescriptor& operator=(const TadDescriptor& other) = default; + // we use default copy assignment operator + TadDescriptor &operator=(const TadDescriptor &other) = default; - // we use default move assignment operator - TadDescriptor& operator=(TadDescriptor&& other) noexcept = default; + // we use default move assignment operator + TadDescriptor &operator=(TadDescriptor &&other) noexcept = default; - // equal to operator - bool operator==(const TadDescriptor &other) const; + // equal to operator + bool operator==(const TadDescriptor &other) const; - // less than operator - bool operator<(const TadDescriptor &other) const; + // less than operator + bool operator<(const TadDescriptor &other) const; - std::vector& axis(); - ShapeDescriptor& originalShape(); - ShapeDescriptor const& originalShapeConst() const; - bool areUnitiesinShape() const; - }; -} + std::vector &axis(); + ShapeDescriptor &originalShape(); + ShapeDescriptor const &originalShapeConst() const; + bool areUnitiesinShape() const; +}; +} // namespace sd #ifndef __JAVACPP_HACK__ namespace std { - template<> - class ND4J_EXPORT hash { - public: - size_t operator()(const sd::TadDescriptor &k) const; - }; -} +template <> +class SD_EXPORT hash { + public: + size_t operator()(const sd::TadDescriptor &k) const; +}; +} // namespace std #endif - -#endif //DEV_TESTS_TADDESCRIPTOR_H +#endif // SD_TADDESCRIPTOR_H diff --git a/libnd4j/include/array/TadPack.h b/libnd4j/include/array/TadPack.h index f7ca15fd98a6..da1524aa5a99 100644 --- a/libnd4j/include/array/TadPack.h +++ b/libnd4j/include/array/TadPack.h @@ -18,41 +18,43 @@ // @author raver119@gmail.com // -#ifndef DEV_TESTS_TADPACK_H -#define DEV_TESTS_TADPACK_H +#ifndef SD_TADPACK_H +#define SD_TADPACK_H #include #include namespace sd { - class ND4J_EXPORT TadPack { - private: - ConstantShapeBuffer _tadShape; - ConstantOffsetsBuffer _tadOffsets; - Nd4jLong _numTads = 0 ; - int _shapeInfoLength = 0; - public: - explicit TadPack(const ConstantShapeBuffer &shapes, const ConstantOffsetsBuffer &offets, Nd4jLong numTads); - TadPack() = default; - ~TadPack() = default; - - const Nd4jLong* primaryShapeInfo() const; - const Nd4jLong* primaryOffsets() const; - - const Nd4jLong* specialShapeInfo() const; - const Nd4jLong* specialOffsets() const; - - Nd4jLong numberOfTads() const; - int shapeInfoLength() const; - - /** - * These methods return either primary or special pointers depending on platform binaries were compiled for - * @return - */ - const Nd4jLong *platformShapeInfo() const; - const Nd4jLong *platformOffsets() const; - }; -} - - -#endif //DEV_TESTS_TADPACK_H +class SD_EXPORT TadPack { + private: + ConstantShapeBuffer _tadShape; + ConstantOffsetsBuffer _tadOffsets; + Nd4jLong _numTads = 0; + int _shapeInfoLength = 0; + + public: + explicit TadPack(const ConstantShapeBuffer& shapes, const ConstantOffsetsBuffer& offets, + Nd4jLong numTads); + TadPack() = default; + ~TadPack() = default; + + const Nd4jLong* primaryShapeInfo() const; + const Nd4jLong* primaryOffsets() const; + + const Nd4jLong* specialShapeInfo() const; + const Nd4jLong* specialOffsets() const; + + Nd4jLong numberOfTads() const; + int shapeInfoLength() const; + + /** + * These methods return either primary or special pointers depending on + * platform binaries were compiled for + * @return + */ + const Nd4jLong* platformShapeInfo() const; + const Nd4jLong* platformOffsets() const; +}; +} // namespace sd + +#endif // SD_TADPACK_H diff --git a/libnd4j/include/array/cpu/DataBuffer.cpp b/libnd4j/include/array/cpu/DataBuffer.cpp index 2575e2ba41bc..7de64cf29ee3 100644 --- a/libnd4j/include/array/cpu/DataBuffer.cpp +++ b/libnd4j/include/array/cpu/DataBuffer.cpp @@ -23,118 +23,111 @@ #include namespace sd { - void DataBuffer::expand(const uint64_t size) { - if (size > _lenInBytes) { - // allocate new buffer - int8_t *newBuffer = nullptr; - ALLOCATE(newBuffer, _workspace, size, int8_t); - - // copy data from existing buffer - std::memcpy(newBuffer, _primaryBuffer, _lenInBytes); - - if (_isOwnerPrimary) { - RELEASE(reinterpret_cast(_primaryBuffer), _workspace); - } - - _primaryBuffer = newBuffer; - _lenInBytes = size; - _isOwnerPrimary = true; - } - } +void DataBuffer::expand(const uint64_t size) { + if (size > _lenInBytes) { + // allocate new buffer + int8_t* newBuffer = nullptr; + ALLOCATE(newBuffer, _workspace, size, int8_t); -//////////////////////////////////////////////////////////////////////// -void DataBuffer::setCountersToZero() { + // copy data from existing buffer + std::memcpy(newBuffer, _primaryBuffer, _lenInBytes); + + if (_isOwnerPrimary) { + RELEASE(reinterpret_cast(_primaryBuffer), _workspace); + } + _primaryBuffer = newBuffer; + _lenInBytes = size; + _isOwnerPrimary = true; + } } //////////////////////////////////////////////////////////////////////// -void DataBuffer::copyCounters(const DataBuffer& other) { +void DataBuffer::setCountersToZero() {} -} //////////////////////////////////////////////////////////////////////// -void DataBuffer::allocateBuffers(const bool allocBoth) { // always allocate primary buffer only (cpu case) - - allocatePrimary(); -} - +void DataBuffer::copyCounters(const DataBuffer& other) {} //////////////////////////////////////////////////////////////////////// -void DataBuffer::copyBufferFrom(const DataBuffer& other, size_t sizeToCopyinBytes, const Nd4jLong offsetThis, const Nd4jLong offsetOther) { - - if(sizeToCopyinBytes == 0) - sizeToCopyinBytes = other.getLenInBytes(); - if(sizeToCopyinBytes == 0) - return; +void DataBuffer::allocateBuffers( + const bool allocBoth) { // always allocate primary buffer only (cpu case) - if(other._primaryBuffer != nullptr) - std::memcpy(static_cast(_primaryBuffer) + offsetThis * DataTypeUtils::sizeOfElement(_dataType), static_cast(other._primaryBuffer) + offsetOther * DataTypeUtils::sizeOfElement(other._dataType), sizeToCopyinBytes); + allocatePrimary(); } //////////////////////////////////////////////////////////////////////// -void DataBuffer::copyBufferFromHost(const void* hostBuffer, size_t sizeToCopyinBytes, const Nd4jLong offsetThis, const Nd4jLong offsetHostBuffer) { - - if(sizeToCopyinBytes == 0) - sizeToCopyinBytes = getLenInBytes(); - if(sizeToCopyinBytes == 0) - return; - - if(hostBuffer != nullptr) - std::memcpy(static_cast(_primaryBuffer) + offsetThis * DataTypeUtils::sizeOfElement(_dataType), static_cast(hostBuffer) + offsetHostBuffer * DataTypeUtils::sizeOfElement(_dataType), sizeToCopyinBytes); +void DataBuffer::copyBufferFrom(const DataBuffer& other, + size_t sizeToCopyinBytes, + const Nd4jLong offsetThis, + const Nd4jLong offsetOther) { + if (sizeToCopyinBytes == 0) sizeToCopyinBytes = other.getLenInBytes(); + if (sizeToCopyinBytes == 0) return; + + if (other._primaryBuffer != nullptr) + std::memcpy(static_cast(_primaryBuffer) + + offsetThis * DataTypeUtils::sizeOfElement(_dataType), + static_cast(other._primaryBuffer) + + offsetOther * DataTypeUtils::sizeOfElement(other._dataType), + sizeToCopyinBytes); } - //////////////////////////////////////////////////////////////////////// -void DataBuffer::deleteSpecial() { - +void DataBuffer::copyBufferFromHost(const void* hostBuffer, + size_t sizeToCopyinBytes, + const Nd4jLong offsetThis, + const Nd4jLong offsetHostBuffer) { + if (sizeToCopyinBytes == 0) sizeToCopyinBytes = getLenInBytes(); + if (sizeToCopyinBytes == 0) return; + + if (hostBuffer != nullptr) + std::memcpy(static_cast(_primaryBuffer) + + offsetThis * DataTypeUtils::sizeOfElement(_dataType), + static_cast(hostBuffer) + + offsetHostBuffer * DataTypeUtils::sizeOfElement(_dataType), + sizeToCopyinBytes); } //////////////////////////////////////////////////////////////////////// -void DataBuffer::setSpecial(void* special, const bool isOwnerSpecail) { +void DataBuffer::deleteSpecial() {} -} +//////////////////////////////////////////////////////////////////////// +void DataBuffer::setSpecial(void* special, const bool isOwnerSpecail) {} //////////////////////////////////////////////////////////////////////// void DataBuffer::setToZeroBuffers(const bool both) { - - memset(primary(), 0, getLenInBytes()); + memset(primary(), 0, getLenInBytes()); } //////////////////////////////////////////////////////////////////////// -void DataBuffer::syncToPrimary(const LaunchContext* context, const bool forceSync) { - -} +void DataBuffer::syncToPrimary(const LaunchContext* context, + const bool forceSync) {} //////////////////////////////////////////////////////////////////////// -void DataBuffer::syncToSpecial(const bool forceSync) { - -} +void DataBuffer::syncToSpecial(const bool forceSync) {} //////////////////////////////////////////////////////////////////////// -void DataBuffer::allocateSpecial() { - -} +void DataBuffer::allocateSpecial() {} //////////////////////////////////////////////////////////////////////// -void DataBuffer::migrate() { - -} +void DataBuffer::migrate() {} ///////////////////////// -void DataBuffer::memcpy(const DataBuffer &dst, const DataBuffer &src) { - if (src._lenInBytes > dst._lenInBytes) - throw std::runtime_error("DataBuffer::memcpy: Source data buffer is larger than destination"); +void DataBuffer::memcpy(const DataBuffer& dst, const DataBuffer& src) { + if (src._lenInBytes > dst._lenInBytes) + throw std::runtime_error( + "DataBuffer::memcpy: Source data buffer is larger than destination"); - std::memcpy(dst._primaryBuffer, src._primaryBuffer, src._lenInBytes); - dst.readPrimary(); + std::memcpy(dst._primaryBuffer, src._primaryBuffer, src._lenInBytes); + dst.readPrimary(); } +void* DataBuffer::platform() { return primary(); } //////////////////////////////////////////////////////////////////////// -void DataBuffer::writePrimary() const { } -void DataBuffer::writeSpecial() const { } -void DataBuffer::readPrimary() const { } -void DataBuffer::readSpecial() const { } -bool DataBuffer::isPrimaryActual() const { return true;} -bool DataBuffer::isSpecialActual() const { return false;} - - -} +void DataBuffer::writePrimary() const {} +void DataBuffer::writeSpecial() const {} +void DataBuffer::readPrimary() const {} +void DataBuffer::readSpecial() const {} +bool DataBuffer::isPrimaryActual() const { return true; } +bool DataBuffer::isSpecialActual() const { return false; } + +} // namespace sd diff --git a/libnd4j/include/graph/execution/impl/LogicExpose.cpp b/libnd4j/include/array/cpu/ManagedDataBuffer.cpp similarity index 70% rename from libnd4j/include/graph/execution/impl/LogicExpose.cpp rename to libnd4j/include/array/cpu/ManagedDataBuffer.cpp index b19e1df55311..610b2faf9a3b 100644 --- a/libnd4j/include/graph/execution/impl/LogicExpose.cpp +++ b/libnd4j/include/array/cpu/ManagedDataBuffer.cpp @@ -1,5 +1,5 @@ /******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -15,16 +15,13 @@ ******************************************************************************/ // -// Created by raver119 on 12.11.2017. +// @author raver119@gmail.com // -#include +#include namespace sd { - namespace graph { - Nd4jStatus LogicExpose::processNode(Graph *graph, Node *node) { - // do we really want this? - return ND4J_STATUS_OK; - } - } -} \ No newline at end of file +void *ManagedDataBuffer::primary() { return _descriptor.address(); } + +void *ManagedDataBuffer::special() { return nullptr; } +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/array/cpu/NDArray.cpp b/libnd4j/include/array/cpu/NDArray.cpp index 398ebe5e8cb5..d57849458edb 100644 --- a/libnd4j/include/array/cpu/NDArray.cpp +++ b/libnd4j/include/array/cpu/NDArray.cpp @@ -19,158 +19,161 @@ #include #include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include #include #include -#include +#include +#include +#include +#include #include -#include +#include #include +#include #include -#include + #include -#include -#include -#include -#include -#include -#include -#include -#include #include -#include -#include -#include -#include -#include -#include - -#include - +#include namespace sd { //////////////////////////////////////////////////////////////////////// -void* NDArray::platformBuffer() { return buffer(); } -void const* NDArray::platformBuffer() const { return buffer(); } +void* NDArray::platformBuffer() { return buffer(); } +void const* NDArray::platformBuffer() const { return buffer(); } Nd4jLong const* NDArray::platformShapeInfo() const { return shapeInfo(); } -void NDArray::syncToDevice() const { } -void NDArray::syncToHost() const { } -void NDArray::tickWriteHost() const { } -void NDArray::tickWriteDevice() const { } -void NDArray::tickReadHost() const { } -void NDArray::tickReadDevice() const { } -void NDArray::tickBothActual() const { } -bool NDArray::isActualOnHostSide() const { return true; } -bool NDArray::isActualOnDeviceSide() const { return true; } -void NDArray::makeBothBuffersActual() const { } - +void NDArray::syncToDevice() const {} +void NDArray::syncToHost() const {} +void NDArray::tickWriteHost() const {} +void NDArray::tickWriteDevice() const {} +void NDArray::tickReadHost() const {} +void NDArray::tickReadDevice() const {} +void NDArray::tickBothActual() const {} +bool NDArray::isActualOnHostSide() const { return true; } +bool NDArray::isActualOnDeviceSide() const { return true; } +void NDArray::makeBothBuffersActual() const {} //////////////////////////////////////////////////////////////////////// template -void NDArray::fillAsTriangular(const float val, int lower, int upper, NDArray& target, const char direction) { - - if (isS()) - throw std::runtime_error("NDArray::fillArrayAsTriangular: you can't use this method on String array!"); - - if(!isSameShape(target) && !(rankOf() == 1 && target.rankOf() == 2 && sizeAt(0) == target.sizeAt(0) && sizeAt(0) == target.sizeAt(1))) - throw std::string("NDArray::fillArrayAsTriangular method: wrong shape of target array !"); - - if (direction == 'u') - lower = -target.sizeAt(-2); - else if (direction == 'l') - upper = target.sizeAt(-1); - - const T value = static_cast(val); - const auto x = reinterpret_cast(buffer()); - auto z = reinterpret_cast(target.buffer()); - - const int xRank = rankOf(); - const int zRank = target.rankOf(); - - const auto zLen = target.lengthOf(); - - const bool areSameOffsets = shape::haveSameShapeAndStrides(shapeInfo(), target.shapeInfo()); - - auto func = PRAGMA_THREADS_FOR { - - int coords[MAX_RANK], temp; - - for (auto i = start; i < stop; i++) { - - shape::index2coordsCPU(start, i, target.shapeInfo(), coords); - const auto zOffset = shape::getOffset(target.shapeInfo(), coords); - - // if( (row + upper < col) || (row + lower > col) ) - if ((coords[zRank - 2] + upper < coords[zRank - 1]) || (coords[zRank - 2] + lower > coords[zRank - 1])) - z[zOffset] = value; - else if (this != &target) { // when this and target are different arrays - if (xRank != zRank) { - temp = coords[0]; - coords[0] = coords[1]; - } +void NDArray::fillAsTriangular(const float val, int lower, int upper, + NDArray& target, const char direction) { + if (isS()) + throw std::runtime_error( + "NDArray::fillArrayAsTriangular: you can't use this method on String " + "array!"); + + if (!isSameShape(target) && + !(rankOf() == 1 && target.rankOf() == 2 && + sizeAt(0) == target.sizeAt(0) && sizeAt(0) == target.sizeAt(1))) + throw std::string( + "NDArray::fillArrayAsTriangular method: wrong shape of target array !"); + + if (direction == 'u') + lower = -target.sizeAt(-2); + else if (direction == 'l') + upper = target.sizeAt(-1); + + const T value = static_cast(val); + const auto x = reinterpret_cast(buffer()); + auto z = reinterpret_cast(target.buffer()); + + const int xRank = rankOf(); + const int zRank = target.rankOf(); + + const auto zLen = target.lengthOf(); + + const bool areSameOffsets = + shape::haveSameShapeAndStrides(shapeInfo(), target.shapeInfo()); + + auto func = PRAGMA_THREADS_FOR { + int coords[MAX_RANK], temp; + + for (auto i = start; i < stop; i++) { + shape::index2coordsCPU(start, i, target.shapeInfo(), coords); + const auto zOffset = shape::getOffset(target.shapeInfo(), coords); + + // if( (row + upper < col) || (row + lower > col) ) + if ((coords[zRank - 2] + upper < coords[zRank - 1]) || + (coords[zRank - 2] + lower > coords[zRank - 1])) + z[zOffset] = value; + else if (this != &target) { // when this and target are different arrays + if (xRank != zRank) { + temp = coords[0]; + coords[0] = coords[1]; + } - const auto xOffset = areSameOffsets ? zOffset : shape::getOffset(shapeInfo(), coords); - z[zOffset] = x[xOffset]; + const auto xOffset = + areSameOffsets ? zOffset : shape::getOffset(shapeInfo(), coords); + z[zOffset] = x[xOffset]; - if (xRank != zRank) // restore first coordinate - coords[0] = temp; - } - } - }; + if (xRank != zRank) // restore first coordinate + coords[0] = temp; + } + } + }; - samediff::Threads::parallel_for(func, 0, zLen); + samediff::Threads::parallel_for(func, 0, zLen); } -BUILD_SINGLE_TEMPLATE(template void NDArray::fillAsTriangular, (const float val, int lower, int upper, NDArray& target, const char direction), LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template void NDArray::fillAsTriangular, + (const float val, int lower, int upper, NDArray& target, + const char direction), + LIBND4J_TYPES); //////////////////////////////////////////////////////////////////////// void NDArray::setIdentity() { - if (isS()) - throw std::runtime_error("NDArray::setIdentity: you can't use this method on String array!"); + if (isS()) + throw std::runtime_error( + "NDArray::setIdentity: you can't use this method on String array!"); - this->nullify(); + this->nullify(); - int rank = rankOf(); - auto shape = shapeOf(); - int minDim = MAX_INT; - Nd4jLong indices[MAX_RANK]; - for(int j = 0; j < rank; ++j) - indices[j] = 1; + int rank = rankOf(); + auto shape = shapeOf(); + int minDim = MAX_INT; + Nd4jLong indices[MAX_RANK]; + for (int j = 0; j < rank; ++j) indices[j] = 1; - Nd4jLong offset = shape::getOffset(shapeInfo(), indices); + Nd4jLong offset = shape::getOffset(shapeInfo(), indices); - for(int i = 0; i < rank; ++i) - if(minDim > shape[i]) - minDim = shape[i]; + for (int i = 0; i < rank; ++i) + if (minDim > shape[i]) minDim = shape[i]; - float v = 1.0f; + float v = 1.0f; - for(int i = 0; i < minDim; ++i) - templatedSet(buffer(), i*offset, this->dataType(), &v); + for (int i = 0; i < minDim; ++i) + templatedSet(buffer(), i * offset, this->dataType(), &v); } //////////////////////////////////////////////////////////////////////// template -static void templatedSwap(void *xBuffer, void *yBuffer, const Nd4jLong* xShapeInfo, const Nd4jLong* yShapeInfo, Nd4jLong length) { - auto x = reinterpret_cast(xBuffer); - auto y = reinterpret_cast(yBuffer); +static void templatedSwap(void* xBuffer, void* yBuffer, const Nd4jLong* xShapeInfo, const Nd4jLong* yShapeInfo, Nd4jLong length) { + auto x = reinterpret_cast(xBuffer); + auto y = reinterpret_cast(yBuffer); - const bool isSameOrders = shape::order(xShapeInfo) == shape::order(xShapeInfo); + const bool isSameOrders = shape::order(xShapeInfo) == shape::order(xShapeInfo); const auto xEws = shape::elementWiseStride(xShapeInfo); - const auto yEws = shape::elementWiseStride(yShapeInfo); - - auto func = PRAGMA_THREADS_FOR { - if(isSameOrders && xEws > 0 && yEws > 0) { + const auto yEws = shape::elementWiseStride(yShapeInfo);auto func = PRAGMA_THREADS_FOR { + if(isSameOrders && xEws > 0 && yEws > 0) { for(auto i = start; i < stop; i++) - sd::math::nd4j_swap(x[i*xEws], y[i*yEws]); + sd::math::nd4j_swap(x[i*xEws], y[i*yEws]); } - else if(shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo)) { + else if(shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo)) { for(auto i = start; i < stop; i++) { const auto ind = shape::getIndexOffset(i, xShapeInfo); - sd::math::nd4j_swap(x[ind], y[ind]); + sd::math::nd4j_swap(x[ind], y[ind]); } } else { @@ -179,280 +182,318 @@ static void templatedSwap(void *xBuffer, void *yBuffer, const Nd4jLong* xShapeIn const auto yInd = shape::getIndexOffset(i, yShapeInfo); sd::math::nd4j_swap(x[xInd], y[yInd]); } - } - }; + } + }; - samediff::Threads::parallel_for(func, 0, length); + samediff::Threads::parallel_for(func, 0, length); } -BUILD_SINGLE_TEMPLATE(template void templatedSwap, (void *xBuffer, void *yBuffer, const Nd4jLong* xShapeInfo, const Nd4jLong* yShapeInfo, Nd4jLong length), LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template void templatedSwap, + (void* xBuffer, void* yBuffer, const Nd4jLong* xShapeInfo, const Nd4jLong* yShapeInfo, Nd4jLong length), + LIBND4J_TYPES); //////////////////////////////////////////////////////////////////////// void NDArray::swapUnsafe(NDArray& other) { - auto xType = this->dataType(); + auto xType = this->dataType(); - if (xType != other.dataType()) - throw std::runtime_error("NDArray::swapUnsage method: both arrays must have the same data type"); + if (xType != other.dataType()) + throw std::runtime_error( + "NDArray::swapUnsage method: both arrays must have the same data type"); - if(buffer() == nullptr || other.buffer() == nullptr) - throw std::runtime_error("NDArray::swapUnsafe method: input array should not be empty!"); + if (buffer() == nullptr || other.buffer() == nullptr) + throw std::runtime_error( + "NDArray::swapUnsafe method: input array should not be empty!"); - if(lengthOf() != other.lengthOf()) - throw std::runtime_error("NDArray::swapUnsafe method: input arrays should have the same length!"); + if (lengthOf() != other.lengthOf()) + throw std::runtime_error( + "NDArray::swapUnsafe method: input arrays should have the same " + "length!"); - BUILD_SINGLE_SELECTOR(xType, templatedSwap, (buffer(), other.buffer(), shapeInfo(), other.shapeInfo(), this->lengthOf()), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR(xType, templatedSwap, + (buffer(), other.buffer(), shapeInfo(), other.shapeInfo(), this->lengthOf()), + LIBND4J_TYPES); } //////////////////////////////////////////////////////////////////////// void NDArray::synchronize(const char* msg) const { - // no-op + // no-op } -void NDArray::prepareSpecialUse(const std::vector& writeList, const std::vector& readList, bool synchronizeWritables) { - // no-op +void NDArray::prepareSpecialUse(const std::vector& writeList, + const std::vector& readList, + bool synchronizeWritables) { + // no-op } -void NDArray::registerSpecialUse(const std::vector& writeList, const std::vector& readList) { - // no-op +void NDArray::registerSpecialUse(const std::vector& writeList, + const std::vector& readList) { + // no-op } -void NDArray::preparePrimaryUse(const std::vector& writeList, const std::vector& readList, bool synchronizeWritables) { - // no-op +void NDArray::preparePrimaryUse(const std::vector& writeList, + const std::vector& readList, + bool synchronizeWritables) { + // no-op } -void NDArray::registerPrimaryUse(const std::vector& writeList, const std::vector& readList) { - // no-op +void NDArray::registerPrimaryUse(const std::vector& writeList, + const std::vector& readList) { + // no-op } void NDArray::syncShape() const { - // no-op + // no-op } ////////////////////////////////////////////////////////////////////////// -template -void NDArray::printCurrentBuffer(const bool host, const char* msg, const int precision) const { - -} +template +void NDArray::printCurrentBuffer(const bool host, const char* msg, + const int precision) const {} - //////////////////////////////////////////////////////////////////////// - void* NDArray::specialBufferWithOffset(Nd4jLong offset) { - return nullptr; - } +//////////////////////////////////////////////////////////////////////// +void* NDArray::specialBufferWithOffset(Nd4jLong offset) { return nullptr; } //////////////////////////////////////////////////////////////////////// const void* NDArray::specialBufferWithOffset(Nd4jLong offset) const { - return nullptr; + return nullptr; } //////////////////////////////////////////////////////////////////////// void* NDArray::specialBuffer() { - if (_buffer->special() == nullptr) - return buffer(); - // FIXME: this should be fixed once CUDA backend added - return static_cast(_buffer->special()) + (_offset * sizeOfT()); + if (_buffer->special() == nullptr) return buffer(); + return static_cast(_buffer->special()) + (_offset * sizeOfT()); } //////////////////////////////////////////////////////////////////////// void const* NDArray::specialBuffer() const { - if (_buffer->special() == nullptr) - return buffer(); - // FIXME: this should be fixed once CUDA backend added - return static_cast(_buffer->special()) + (_offset * sizeOfT()); + if (_buffer->special() == nullptr) return buffer(); + return static_cast(_buffer->special()) + (_offset * sizeOfT()); } - ////////////////////////////////////////////////////////////////////////// // change an array by repeating it the number of times given by reps. NDArray NDArray::tile(const std::vector& reps) const { - const int repsSize = reps.size(); - - Nd4jLong product = 1; - for(const auto& item : reps) - product *= item; - if(product == 0) - throw std::runtime_error("NDArray::tile method: one of the elements in reps array is zero !"); - - int rankOld = rankOf(); - int diff = rankOld - repsSize; - if(product==1) { // in this case 2 possibilities are present: just reshape or nothing to do - NDArray result(*this); - if(diff < 0) { // reshape to higher dimension - std::vector shapeNew = reps; // there is requirement to have unities at first "diff" positions of new shape - memcpy(&shapeNew[-diff], result.shapeInfo()+1, rankOld * sizeof(Nd4jLong)); // put old shape numbers at rest of positions - result.reshapei(ordering(), shapeNew); - } - return result; // nothing to do, if diff >= 0 -> identity tile + const int repsSize = reps.size(); + + Nd4jLong product = 1; + for (const auto& item : reps) product *= item; + if (product == 0) + throw std::runtime_error( + "NDArray::tile method: one of the elements in reps array is zero !"); + + int rankOld = rankOf(); + int diff = rankOld - repsSize; + if (product == 1) { // in this case 2 possibilities are present: just reshape + // or nothing to do + NDArray result(*this); + if (diff < 0) { // reshape to higher dimension + std::vector shapeNew = + reps; // there is requirement to have unities at first "diff" + // positions of new shape + memcpy( + &shapeNew[-diff], result.shapeInfo() + 1, + rankOld * + sizeof(Nd4jLong)); // put old shape numbers at rest of positions + result.reshapei(ordering(), shapeNew); } + return result; // nothing to do, if diff >= 0 -> identity tile + } + + // evaluate shapeInfo for resulting array + auto newShapeInfo = + ShapeUtils::evalTileShapeInfo(*this, reps, getContext()->getWorkspace()); + // create new buffer, in any case the memory amount new buffer points to is + // bigger then those for old _buffer + std::shared_ptr newBuff = + std::make_shared(shape::length(newShapeInfo) * sizeOfT(), + dataType(), getContext()->getWorkspace()); + // assign new shape and new buffer to resulting array + NDArray result(newBuff, ShapeDescriptor(newShapeInfo), getContext()); + + // fill newBuff, loop through all elements of newBuff + // looping through _buffer goes automatically by means of getSubArrayIndex + // applying + const auto resultLen = result.lengthOf(); + auto xType = this->dataType(); + if (result.ordering() == 'c') { // ews == 1 always here - // evaluate shapeInfo for resulting array - auto newShapeInfo = ShapeUtils::evalTileShapeInfo(*this, reps, getContext()->getWorkspace()); - // create new buffer, in any case the memory amount new buffer points to is bigger then those for old _buffer - std::shared_ptr newBuff = std::make_shared(shape::length(newShapeInfo) * sizeOfT(), dataType(), getContext()->getWorkspace()); - // assign new shape and new buffer to resulting array - NDArray result(newBuff, ShapeDescriptor(newShapeInfo), getContext()); - - // fill newBuff, loop through all elements of newBuff - // looping through _buffer goes automatically by means of getSubArrayIndex applying - const auto resultLen = result.lengthOf(); - auto xType = this->dataType(); - if(result.ordering() == 'c') { // ews == 1 always here - - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - auto yOffset = shape::subArrayOffset(i, newShapeInfo, shapeInfo()); - BUILD_SINGLE_SELECTOR(xType, this->template templatedAssign,(result.buffer(), i, this->buffer(), yOffset), LIBND4J_TYPES); - } - }; - - samediff::Threads::parallel_for(func, 0, resultLen); - } - else { + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + auto yOffset = shape::subArrayOffset(i, newShapeInfo, shapeInfo()); + BUILD_SINGLE_SELECTOR(xType, this->template templatedAssign, + (result.buffer(), i, this->buffer(), yOffset), + LIBND4J_TYPES); + } + }; - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - auto xOffset = result.getOffset(i); - auto yOffset = shape::subArrayOffset(i, newShapeInfo, shapeInfo()); - BUILD_SINGLE_SELECTOR(xType, this->template templatedAssign,(result.buffer(), xOffset, this->buffer(), yOffset), LIBND4J_TYPES); - } - }; + samediff::Threads::parallel_for(func, 0, resultLen); + } else { + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + auto xOffset = result.getOffset(i); + auto yOffset = shape::subArrayOffset(i, newShapeInfo, shapeInfo()); + BUILD_SINGLE_SELECTOR( + xType, this->template templatedAssign, + (result.buffer(), xOffset, this->buffer(), yOffset), LIBND4J_TYPES); + } + }; - samediff::Threads::parallel_for(func, 0, resultLen); - } - result.tickWriteHost(); - return result; + samediff::Threads::parallel_for(func, 0, resultLen); + } + result.tickWriteHost(); + return result; } ////////////////////////////////////////////////////////////////////////// // change an array by repeating it the number of times given by reps. void NDArray::tile(const std::vector& reps, NDArray& target) const { - - auto repProd = shape::prodLong(reps.data(), reps.size()); - if (repProd < 1) - throw std::runtime_error("NDArray::tile: reps can't contain 0s"); - - // evaluate true tile shapeInfo for comparison with target shapeInfo - auto newShapeInfo = ShapeUtils::evalTileShapeInfo(*this, reps, getContext()->getWorkspace()); - if(!shape::equalsSoft(newShapeInfo, target.shapeInfo())) { - delete []newShapeInfo; - throw std::runtime_error("NDArray::tile method - shapeInfo of target array is not suitable for tile operation !"); + auto repProd = shape::prodLong(reps.data(), reps.size()); + if (repProd < 1) + throw std::runtime_error("NDArray::tile: reps can't contain 0s"); + + // evaluate true tile shapeInfo for comparison with target shapeInfo + auto newShapeInfo = + ShapeUtils::evalTileShapeInfo(*this, reps, getContext()->getWorkspace()); + if (!shape::equalsSoft(newShapeInfo, target.shapeInfo())) { + delete[] newShapeInfo; + throw std::runtime_error( + "NDArray::tile method - shapeInfo of target array is not suitable for " + "tile operation !"); + } + + // fill newBuff, loop through all elements of newBuff + // looping through _buffer goes automatically by means of getSubArrayIndex + // applying + const int ews = target.ews(); + const auto targetLen = target.lengthOf(); + if (target.ordering() == 'c' && ews == 1) { // ews == 1 always here + //#pragma omp parallel for simd if(targetLen > + // Environment::getInstance().elementwiseThreshold()) schedule(guided) + for (Nd4jLong i = 0; i < targetLen; ++i) { + auto yOffset = shape::subArrayOffset(i, target.shapeInfo(), shapeInfo()); + BUILD_DOUBLE_SELECTOR(target.dataType(), dataType(), + templatedDoubleAssign, + (target.buffer(), i, buffer(), yOffset), + LIBND4J_TYPES, LIBND4J_TYPES); } - - // fill newBuff, loop through all elements of newBuff - // looping through _buffer goes automatically by means of getSubArrayIndex applying - const int ews = target.ews(); - const auto targetLen = target.lengthOf(); - if(target.ordering() == 'c' && ews == 1) { // ews == 1 always here -//#pragma omp parallel for simd if(targetLen > Environment::getInstance().elementwiseThreshold()) schedule(guided) - for(Nd4jLong i=0; i 1) { + for (Nd4jLong i = 0; i < targetLen; ++i) { + auto yOffset = shape::subArrayOffset(i, target.shapeInfo(), shapeInfo()); + BUILD_DOUBLE_SELECTOR(target.dataType(), dataType(), + templatedDoubleAssign, + (target.buffer(), i * ews, buffer(), yOffset), + LIBND4J_TYPES, LIBND4J_TYPES); } - else if(target.ordering() == 'c' && ews > 1) { - for(Nd4jLong i=0; i target.rankOf()) - throw std::runtime_error("NDArray::tile method - rank of target array must be bigger or equal to the rank of this array !"); - - if(!ShapeUtils::areShapesBroadcastable(*this, target)) - throw std::runtime_error("NDArray::tile method - shapeInfo of target array is not suitable for tile operation !"); - - // fill newBuff, loop through all elements of newBuff - // looping through _buffer goes automatically by means of getSubArrayIndex applying - const auto ews = target.ews(); - const auto targetLen = target.lengthOf(); - if(target.ordering() == 'c' && ews >= 1) { - - for(Nd4jLong i=0; i target.rankOf()) + throw std::runtime_error( + "NDArray::tile method - rank of target array must be bigger or equal " + "to the rank of this array !"); + + if (!ShapeUtils::areShapesBroadcastable(*this, target)) + throw std::runtime_error( + "NDArray::tile method - shapeInfo of target array is not suitable for " + "tile operation !"); + + // fill newBuff, loop through all elements of newBuff + // looping through _buffer goes automatically by means of getSubArrayIndex + // applying + const auto ews = target.ews(); + const auto targetLen = target.lengthOf(); + if (target.ordering() == 'c' && ews >= 1) { + for (Nd4jLong i = 0; i < targetLen; ++i) { + auto yOffset = shape::subArrayOffset(i, target.shapeInfo(), shapeInfo()); + BUILD_DOUBLE_SELECTOR(target.dataType(), dataType(), + templatedDoubleAssign, + (target.buffer(), i * ews, buffer(), yOffset), + LIBND4J_TYPES, LIBND4J_TYPES); } - else { - - for(Nd4jLong i=0; i -static void repeat_(const NDArray& input, NDArray& output, const std::vector& repeats, const int axis) { - - const X* x = input.bufferAsT(); - Z* z = output.bufferAsT(); - - const int rank = input.rankOf(); // xRank = zRank - const int zLen = output.lengthOf(); // xLen <= zLen - const uint repSize = repeats.size(); - - // loop through input array - auto func = PRAGMA_THREADS_FOR { - - int coords[MAX_RANK], temp; - - for (auto i = start; i < stop; i++) { - - shape::index2coordsCPU(start, i, output.shapeInfo(), coords); - const auto zOffset = shape::getOffset(output.shapeInfo(), coords); - - temp = coords[axis]; - - if (repSize > 1) { - for (uint j = 0; j < repSize; ++j) { - coords[axis] -= repeats[j]; - if (coords[axis] < 0) { - coords[axis] = j; - break; - } - } - } else - coords[axis] /= repeats[0]; +template +static void repeat_(const NDArray& input, NDArray& output, + const std::vector& repeats, const int axis) { + const X* x = input.bufferAsT(); + Z* z = output.bufferAsT(); + + const int rank = input.rankOf(); // xRank = zRank + const int zLen = output.lengthOf(); // xLen <= zLen + const uint repSize = repeats.size(); + + // loop through input array + auto func = PRAGMA_THREADS_FOR { + int coords[MAX_RANK], temp; + + for (auto i = start; i < stop; i++) { + shape::index2coordsCPU(start, i, output.shapeInfo(), coords); + const auto zOffset = shape::getOffset(output.shapeInfo(), coords); + + temp = coords[axis]; + + if (repSize > 1) { + for (uint j = 0; j < repSize; ++j) { + coords[axis] -= repeats[j]; + if (coords[axis] < 0) { + coords[axis] = j; + break; + } + } + } else + coords[axis] /= repeats[0]; - z[zOffset] = x[shape::getOffset(input.shapeInfo(), coords)]; + z[zOffset] = x[shape::getOffset(input.shapeInfo(), coords)]; - coords[axis] = temp; - } - }; + coords[axis] = temp; + } + }; - samediff::Threads::parallel_for(func, 0, zLen); + samediff::Threads::parallel_for(func, 0, zLen); } ////////////////////////////////////////////////////////////////////////// // create new array by repeating it the number of times given by repeats NDArray NDArray::repeat(const int axis, const std::vector& repeats) const { + NDArray output('c', ShapeUtils::evalRepeatShape(axis, repeats, *this), + dataType(), getContext()); - NDArray output('c', ShapeUtils::evalRepeatShape(axis, repeats, *this), dataType(), getContext()); + BUILD_SINGLE_SELECTOR_TWICE(dataType(), repeat_, + (*this, output, repeats, axis), LIBND4J_TYPES); - BUILD_SINGLE_SELECTOR_TWICE(dataType(), repeat_, (*this, output, repeats, axis), LIBND4J_TYPES); - - return output; + return output; } ////////////////////////////////////////////////////////////////////////// // fill array by repeating it the number of times given by reps -void NDArray::repeat(const int axis, const std::vector& repeats, NDArray& target) const { - - if(!target.isSameShape(ShapeUtils::evalRepeatShape(axis, repeats, *this))) - throw std::invalid_argument("NDArray::repeat(const int axis, const std::vector& repeats, NDArray& target) method: wrong shape of target array!"); - - BUILD_DOUBLE_SELECTOR(dataType(), target.dataType(), repeat_, (*this, target, repeats, axis), LIBND4J_TYPES, LIBND4J_TYPES); +void NDArray::repeat(const int axis, const std::vector& repeats, + NDArray& target) const { + if (!target.isSameShape(ShapeUtils::evalRepeatShape(axis, repeats, *this))) + throw std::invalid_argument( + "NDArray::repeat(const int axis, const std::vector& repeats, " + "NDArray& target) method: wrong shape of target array!"); + + BUILD_DOUBLE_SELECTOR(dataType(), target.dataType(), repeat_, + (*this, target, repeats, axis), LIBND4J_TYPES, + LIBND4J_TYPES); } ////////////////////////////////////////////////////////////////////////// @@ -467,9 +508,6 @@ void NDArray::repeat(const int axis, const std::vector& repeats, NDArray& t #include "NDArray.macro" #endif */ -} - - +} // namespace sd #endif - diff --git a/libnd4j/include/array/cpu/NDArrayLambda.hpp b/libnd4j/include/array/cpu/NDArrayLambda.hpp index bd8742288c80..709cd55924c7 100644 --- a/libnd4j/include/array/cpu/NDArrayLambda.hpp +++ b/libnd4j/include/array/cpu/NDArrayLambda.hpp @@ -1,332 +1,473 @@ +template +void NDArray::applyTriplewiseLambda(NDArray& second, NDArray& third, + const std::function& func, + NDArray& target) { + if (dataType() != DataTypeUtils::fromT()) + throw std::runtime_error( + "NDArray::applyTriplewiseLambda method: wrong template parameter T, " + "its type should be the same as type of this array!"); + if (dataType() != second.dataType() || dataType() != third.dataType() || + dataType() != target.dataType()) + throw std::runtime_error( + "NDArray::applyTriplewiseLambda method: bother four arrays (this, " + "second, third, target) should have the same type !"); + + if (this->lengthOf() != second.lengthOf() || + this->lengthOf() != third.lengthOf() || !this->isSameShape(second) || + !this->isSameShape(third)) { + nd4j_printf( + "applyTriplewiseLambda requires all operands to have the same shape\n", + ""); + throw std::runtime_error("Shapes mismach"); + } + + auto f = this->bufferAsT(); + auto s = second.bufferAsT(); + auto t = third.bufferAsT(); + auto z = target.bufferAsT(); + + if (this->ordering() == second.ordering() && + this->ordering() == third.ordering() && + this->ordering() == target.ordering() && + (this->ews() == 1 && target.ews() == 1) && this->ews() == second.ews() && + this->ews() == third.ews()) { + auto loop = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) z[e] = func(f[e], s[e], t[e]); + }; + + samediff::Threads::parallel_for(loop, 0, _length); + } else { + if (f == z) { + auto loop = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + auto tOffset = this->getOffset(e); + auto uOffset = second.getOffset(e); + auto vOffset = third.getOffset(e); + + f[tOffset] = func(f[tOffset], s[uOffset], t[vOffset]); + } + }; -template -void NDArray::applyTriplewiseLambda(NDArray& second, NDArray& third, const std::function& func, NDArray& target) { - - if(dataType() != DataTypeUtils::fromT()) - throw std::runtime_error("NDArray::applyTriplewiseLambda method: wrong template parameter T, its type should be the same as type of this array!"); - if(dataType() != second.dataType() || dataType() != third.dataType() || dataType() != target.dataType()) - throw std::runtime_error("NDArray::applyTriplewiseLambda method: bother four arrays (this, second, third, target) should have the same type !"); - - if (this->lengthOf() != second.lengthOf() || this->lengthOf() != third.lengthOf() || !this->isSameShape(second) || !this->isSameShape(third)) { - nd4j_printf("applyTriplewiseLambda requires all operands to have the same shape\n",""); - throw std::runtime_error("Shapes mismach"); - } - - auto f = this->bufferAsT(); - auto s = second.bufferAsT(); - auto t = third.bufferAsT(); - auto z = target.bufferAsT(); - - if (this->ordering() == second.ordering() && this->ordering() == third.ordering() && this->ordering() == target.ordering() && (this->ews() == 1 && target.ews() == 1) && this->ews() == second.ews() && this->ews() == third.ews()) { - - auto loop = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) - z[e] = func(f[e], s[e], t[e]); - }; - - samediff::Threads::parallel_for(loop, 0, _length); + samediff::Threads::parallel_for(loop, 0, _length); } else { - if (f == z) { - - auto loop = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) { - auto tOffset = this->getOffset(e); - auto uOffset = second.getOffset(e); - auto vOffset = third.getOffset(e); - - f[tOffset] = func(f[tOffset], s[uOffset], t[vOffset]); - } - }; - - samediff::Threads::parallel_for(loop, 0, _length); - } else { - - auto loop = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) { - auto tOffset = this->getOffset(e); - auto uOffset = second.getOffset(e); - auto vOffset = third.getOffset(e); - auto zOffset = target.getOffset(e); - - z[zOffset] = func(f[tOffset], s[uOffset], t[vOffset]); - } - }; - - samediff::Threads::parallel_for(loop, 0, _length); + auto loop = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + auto tOffset = this->getOffset(e); + auto uOffset = second.getOffset(e); + auto vOffset = third.getOffset(e); + auto zOffset = target.getOffset(e); + + z[zOffset] = func(f[tOffset], s[uOffset], t[vOffset]); } + }; + + samediff::Threads::parallel_for(loop, 0, _length); } + } } -template void NDArray::applyTriplewiseLambda(NDArray& second, NDArray &third, const std::function& func, NDArray& target); -template void NDArray::applyTriplewiseLambda(NDArray& second, NDArray &third, const std::function& func, NDArray& target); -template void NDArray::applyTriplewiseLambda(NDArray& second, NDArray &third, const std::function& func, NDArray& target); -template void NDArray::applyTriplewiseLambda(NDArray& second, NDArray &third, const std::function& func, NDArray& target); -template void NDArray::applyTriplewiseLambda(NDArray& second, NDArray &third, const std::function& func, NDArray& target); -template void NDArray::applyTriplewiseLambda(NDArray& second, NDArray &third, const std::function& func, NDArray& target); -template void NDArray::applyTriplewiseLambda(NDArray& second, NDArray &third, const std::function& func, NDArray& target); -template void NDArray::applyTriplewiseLambda(NDArray& second, NDArray &third, const std::function& func, NDArray& target); -template void NDArray::applyTriplewiseLambda(NDArray& second, NDArray &third, const std::function& func, NDArray& target); -template void NDArray::applyTriplewiseLambda(NDArray& second, NDArray &third, const std::function& func, NDArray& target); -template void NDArray::applyTriplewiseLambda(NDArray& second, NDArray &third, const std::function& func, NDArray& target); -template void NDArray::applyTriplewiseLambda(NDArray& second, NDArray &third, const std::function& func, NDArray& target); -template void NDArray::applyTriplewiseLambda(NDArray& second, NDArray &third, const std::function& func, NDArray& target); +template void NDArray::applyTriplewiseLambda( + NDArray& second, NDArray& third, + const std::function& func, NDArray& target); +template void NDArray::applyTriplewiseLambda( + NDArray& second, NDArray& third, + const std::function& func, NDArray& target); +template void NDArray::applyTriplewiseLambda( + NDArray& second, NDArray& third, + const std::function& func, + NDArray& target); +template void NDArray::applyTriplewiseLambda( + NDArray& second, NDArray& third, + const std::function& func, + NDArray& target); +template void NDArray::applyTriplewiseLambda( + NDArray& second, NDArray& third, + const std::function& func, + NDArray& target); +template void NDArray::applyTriplewiseLambda( + NDArray& second, NDArray& third, + const std::function& func, NDArray& target); +template void NDArray::applyTriplewiseLambda( + NDArray& second, NDArray& third, + const std::function& func, + NDArray& target); +template void NDArray::applyTriplewiseLambda( + NDArray& second, NDArray& third, + const std::function& func, + NDArray& target); +template void NDArray::applyTriplewiseLambda( + NDArray& second, NDArray& third, + const std::function& func, + NDArray& target); +template void NDArray::applyTriplewiseLambda( + NDArray& second, NDArray& third, + const std::function& func, + NDArray& target); +template void NDArray::applyTriplewiseLambda( + NDArray& second, NDArray& third, + const std::function& func, + NDArray& target); +template void NDArray::applyTriplewiseLambda( + NDArray& second, NDArray& third, + const std::function& func, NDArray& target); +template void NDArray::applyTriplewiseLambda( + NDArray& second, NDArray& third, + const std::function& func, NDArray& target); ////////////////////////////////////////////////////////////////////////// -template -void NDArray::applyPairwiseLambda(const NDArray& other, const std::function& func, NDArray& target) { - - if(dataType() != DataTypeUtils::fromT()) - throw std::runtime_error("NDArray::applyPairwiseLambda method: wrong template parameter T, its type should be the same as type of this array!"); - if(dataType() != other.dataType() || dataType() != target.dataType()) - throw std::runtime_error("NDArray::applyPairwiseLambda method: all three arrays (this, other, target) must have the same type !"); - - if (this->lengthOf() != other.lengthOf()) { - nd4j_printf("applyPairwiseLambda requires both operands to have the same shape\n",""); - throw std::runtime_error("Shapes mismach"); - } - - auto f = this->bufferAsT(); - auto s = other.bufferAsT(); - auto z = target.bufferAsT(); - - if (this->ordering() == other.ordering() && this->ordering() == target.ordering() && (this->ews() == 1 && target.ews() == 1) && this->ews() == other.ews()) { - - auto loop = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) - z[e] = func(f[e], s[e]); - }; +template +void NDArray::applyPairwiseLambda(const NDArray& other, + const std::function& func, + NDArray& target) { + if (dataType() != DataTypeUtils::fromT()) + throw std::runtime_error( + "NDArray::applyPairwiseLambda method: wrong template parameter T, " + "its type should be the same as type of this array!"); + if (dataType() != other.dataType() || dataType() != target.dataType()) + throw std::runtime_error( + "NDArray::applyPairwiseLambda method: all three arrays (this, " + "other, target) must have the same type !"); + + if (this->lengthOf() != other.lengthOf()) { + nd4j_printf( + "applyPairwiseLambda requires both operands to have the same shape\n", + ""); + throw std::runtime_error("Shapes mismach"); + } + + auto f = this->bufferAsT(); + auto s = other.bufferAsT(); + auto z = target.bufferAsT(); + + if (this->ordering() == other.ordering() && + this->ordering() == target.ordering() && + (this->ews() == 1 && target.ews() == 1) && this->ews() == other.ews()) { + auto loop = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) z[e] = func(f[e], s[e]); + }; + + samediff::Threads::parallel_for(loop, 0, _length); + } else { + if (f == z) { + auto loop = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + auto xOffset = this->getOffset(e); + auto yOffset = other.getOffset(e); + + f[xOffset] = func(f[xOffset], s[yOffset]); + } + }; - samediff::Threads::parallel_for(loop, 0, _length); + samediff::Threads::parallel_for(loop, 0, _length); } else { - if (f == z) { - - auto loop = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) { - auto xOffset = this->getOffset(e); - auto yOffset = other.getOffset(e); - - f[xOffset] = func(f[xOffset], s[yOffset]); - } - }; + auto loop = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + auto xOffset = this->getOffset(e); + auto yOffset = other.getOffset(e); + auto zOffset = target.getOffset(e); - samediff::Threads::parallel_for(loop, 0, _length); - } else { - - auto loop = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) { - auto xOffset = this->getOffset(e); - auto yOffset = other.getOffset(e); - auto zOffset = target.getOffset(e); - - z[zOffset] = func(f[xOffset], s[yOffset]); - } - }; - - samediff::Threads::parallel_for(loop, 0, _length); + z[zOffset] = func(f[xOffset], s[yOffset]); } + }; + + samediff::Threads::parallel_for(loop, 0, _length); } + } } -template void NDArray::applyPairwiseLambda(const NDArray& other, const std::function& func, NDArray& target); -template void NDArray::applyPairwiseLambda(const NDArray& other, const std::function& func, NDArray& target); -template void NDArray::applyPairwiseLambda(const NDArray& other, const std::function& func, NDArray& target); -template void NDArray::applyPairwiseLambda(const NDArray& other, const std::function& func, NDArray& target); -template void NDArray::applyPairwiseLambda(const NDArray& other, const std::function& func, NDArray& target); -template void NDArray::applyPairwiseLambda(const NDArray& other, const std::function& func, NDArray& target); -template void NDArray::applyPairwiseLambda(const NDArray& other, const std::function& func, NDArray& target); -template void NDArray::applyPairwiseLambda(const NDArray& other, const std::function& func, NDArray& target); -template void NDArray::applyPairwiseLambda(const NDArray& other, const std::function& func, NDArray& target); -template void NDArray::applyPairwiseLambda(const NDArray& other, const std::function& func, NDArray& target); -template void NDArray::applyPairwiseLambda(const NDArray& other, const std::function& func, NDArray& target); -template void NDArray::applyPairwiseLambda(const NDArray& other, const std::function& func, NDArray& target); -template void NDArray::applyPairwiseLambda(const NDArray& other, const std::function& func, NDArray& target); +template void NDArray::applyPairwiseLambda( + const NDArray& other, const std::function& func, + NDArray& target); +template void NDArray::applyPairwiseLambda( + const NDArray& other, const std::function& func, + NDArray& target); +template void NDArray::applyPairwiseLambda( + const NDArray& other, const std::function& func, + NDArray& target); +template void NDArray::applyPairwiseLambda( + const NDArray& other, + const std::function& func, NDArray& target); +template void NDArray::applyPairwiseLambda( + const NDArray& other, + const std::function& func, NDArray& target); +template void NDArray::applyPairwiseLambda( + const NDArray& other, const std::function& func, + NDArray& target); +template void NDArray::applyPairwiseLambda( + const NDArray& other, const std::function& func, + NDArray& target); +template void NDArray::applyPairwiseLambda( + const NDArray& other, const std::function& func, + NDArray& target); +template void NDArray::applyPairwiseLambda( + const NDArray& other, + const std::function& func, NDArray& target); +template void NDArray::applyPairwiseLambda( + const NDArray& other, + const std::function& func, NDArray& target); +template void NDArray::applyPairwiseLambda( + const NDArray& other, + const std::function& func, NDArray& target); +template void NDArray::applyPairwiseLambda( + const NDArray& other, const std::function& func, + NDArray& target); +template void NDArray::applyPairwiseLambda( + const NDArray& other, const std::function& func, + NDArray& target); ////////////////////////////////////////////////////////////////////////// -template +template void NDArray::applyLambda(const std::function& func, NDArray& target) { + if (dataType() != DataTypeUtils::fromT()) + throw std::runtime_error( + "NDArray::applyLambda method: wrong template parameter T, its type " + "should be the same as type of this array!"); + if (dataType() != target.dataType()) + throw std::runtime_error( + "NDArray::applyLambda method: types of this and target array should " + "match !"); + + auto f = this->bufferAsT(); + auto z = target.bufferAsT(); + + if (this->ordering() == target.ordering() && + (this->ews() == 1 && target.ews() == 1)) { + auto loop = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) z[e] = func(f[e]); + }; + + samediff::Threads::parallel_for(loop, 0, _length); + } else { + if (f == z) { + auto loop = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + auto xOffset = this->getOffset(e); + + f[xOffset] = func(f[xOffset]); + } + }; - if(dataType() != DataTypeUtils::fromT()) - throw std::runtime_error("NDArray::applyLambda method: wrong template parameter T, its type should be the same as type of this array!"); - if(dataType() != target.dataType()) - throw std::runtime_error("NDArray::applyLambda method: types of this and target array should match !"); - - auto f = this->bufferAsT(); - auto z = target.bufferAsT(); - - if (this->ordering() == target.ordering() && (this->ews() == 1 && target.ews() == 1)) { - - auto loop = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) - z[e] = func(f[e]); - }; - - samediff::Threads::parallel_for(loop, 0, _length); + samediff::Threads::parallel_for(loop, 0, _length); } else { - if (f == z) { - - auto loop = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) { - auto xOffset = this->getOffset(e); - - f[xOffset] = func(f[xOffset]); - } - }; + auto loop = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + auto xOffset = this->getOffset(e); + auto zOffset = target.getOffset(e); - samediff::Threads::parallel_for(loop, 0, _length); - } else { - - auto loop = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) { - auto xOffset = this->getOffset(e); - auto zOffset = target.getOffset(e); - - z[zOffset] = func(f[xOffset]); - } - }; - - samediff::Threads::parallel_for(loop, 0, _length); + z[zOffset] = func(f[xOffset]); } + }; + + samediff::Threads::parallel_for(loop, 0, _length); } + } } -template void NDArray::applyLambda(const std::function& func, NDArray& target); -template void NDArray::applyLambda(const std::function& func, NDArray& target); -template void NDArray::applyLambda(const std::function& func, NDArray& target); -template void NDArray::applyLambda(const std::function& func, NDArray& target); -template void NDArray::applyLambda(const std::function& func, NDArray& target); -template void NDArray::applyLambda(const std::function& func, NDArray& target); -template void NDArray::applyLambda(const std::function& func, NDArray& target); -template void NDArray::applyLambda(const std::function& func, NDArray& target); -template void NDArray::applyLambda(const std::function& func, NDArray& target); -template void NDArray::applyLambda(const std::function& func, NDArray& target); -template void NDArray::applyLambda(const std::function& func, NDArray& target); -template void NDArray::applyLambda(const std::function& func, NDArray& target); -template void NDArray::applyLambda(const std::function& func, NDArray& target); +template void NDArray::applyLambda(const std::function& func, + NDArray& target); +template void NDArray::applyLambda(const std::function& func, + NDArray& target); +template void NDArray::applyLambda(const std::function& func, + NDArray& target); +template void NDArray::applyLambda( + const std::function& func, NDArray& target); +template void NDArray::applyLambda( + const std::function& func, NDArray& target); +template void NDArray::applyLambda(const std::function& func, + NDArray& target); +template void NDArray::applyLambda(const std::function& func, + NDArray& target); +template void NDArray::applyLambda(const std::function& func, + NDArray& target); +template void NDArray::applyLambda( + const std::function& func, NDArray& target); +template void NDArray::applyLambda( + const std::function& func, NDArray& target); +template void NDArray::applyLambda( + const std::function& func, NDArray& target); +template void NDArray::applyLambda(const std::function& func, + NDArray& target); +template void NDArray::applyLambda(const std::function& func, + NDArray& target); ////////////////////////////////////////////////////////////////////////// -template -void NDArray::applyIndexedLambda(const std::function& func, NDArray& target) { - - if(dataType() != DataTypeUtils::fromT()) - throw std::runtime_error("NDArray::applyIndexedLambda method: wrong template parameter T, its type should be the same as type of this array!"); - if(dataType() != target.dataType()) - throw std::runtime_error("NDArray::applyIndexedLambda method: types of this and target array should match !"); - - auto f = this->bufferAsT(); - auto z = target.bufferAsT(); - - if (this->ordering() == target.ordering() && (this->ews() == 1 && target.ews() == 1)) { - - auto loop = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) - z[e] = func(e, f[e]); - }; +template +void NDArray::applyIndexedLambda(const std::function& func, + NDArray& target) { + if (dataType() != DataTypeUtils::fromT()) + throw std::runtime_error( + "NDArray::applyIndexedLambda method: wrong template parameter T, " + "its type should be the same as type of this array!"); + if (dataType() != target.dataType()) + throw std::runtime_error( + "NDArray::applyIndexedLambda method: types of this and target array " + "should match !"); + + auto f = this->bufferAsT(); + auto z = target.bufferAsT(); + + if (this->ordering() == target.ordering() && + (this->ews() == 1 && target.ews() == 1)) { + auto loop = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) z[e] = func(e, f[e]); + }; + + samediff::Threads::parallel_for(loop, 0, _length); + } else { + if (f == z) { + auto loop = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + auto xOffset = this->getOffset(e); + + f[xOffset] = func(e, f[xOffset]); + } + }; - samediff::Threads::parallel_for(loop, 0, _length); + samediff::Threads::parallel_for(loop, 0, _length); } else { - if (f == z) { - - auto loop = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) { - auto xOffset = this->getOffset(e); - - f[xOffset] = func(e, f[xOffset]); - } - }; + auto loop = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + auto xOffset = this->getOffset(e); + auto zOffset = target.getOffset(e); - samediff::Threads::parallel_for(loop, 0, _length); - } else { - - auto loop = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) { - auto xOffset = this->getOffset(e); - auto zOffset = target.getOffset(e); - - z[zOffset] = func(e, f[xOffset]); - } - }; - - samediff::Threads::parallel_for(loop, 0, _length); + z[zOffset] = func(e, f[xOffset]); } + }; + + samediff::Threads::parallel_for(loop, 0, _length); } + } } -template void NDArray::applyIndexedLambda(const std::function& func, NDArray& target); -template void NDArray::applyIndexedLambda(const std::function& func, NDArray& target); -template void NDArray::applyIndexedLambda(const std::function& func, NDArray& target); -template void NDArray::applyIndexedLambda(const std::function& func, NDArray& target); -template void NDArray::applyIndexedLambda(const std::function& func, NDArray& target); -template void NDArray::applyIndexedLambda(const std::function& func, NDArray& target); -template void NDArray::applyIndexedLambda(const std::function& func, NDArray& target); -template void NDArray::applyIndexedLambda(const std::function& func, NDArray& target); -template void NDArray::applyIndexedLambda(const std::function& func, NDArray& target); -template void NDArray::applyIndexedLambda(const std::function& func, NDArray& target); -template void NDArray::applyIndexedLambda(const std::function& func, NDArray& target); -template void NDArray::applyIndexedLambda(const std::function& func, NDArray& target); -template void NDArray::applyIndexedLambda(const std::function& func, NDArray& target); +template void NDArray::applyIndexedLambda( + const std::function& func, NDArray& target); +template void NDArray::applyIndexedLambda( + const std::function& func, NDArray& target); +template void NDArray::applyIndexedLambda( + const std::function& func, NDArray& target); +template void NDArray::applyIndexedLambda( + const std::function& func, NDArray& target); +template void NDArray::applyIndexedLambda( + const std::function& func, NDArray& target); +template void NDArray::applyIndexedLambda( + const std::function& func, NDArray& target); +template void NDArray::applyIndexedLambda( + const std::function& func, NDArray& target); +template void NDArray::applyIndexedLambda( + const std::function& func, NDArray& target); +template void NDArray::applyIndexedLambda( + const std::function& func, NDArray& target); +template void NDArray::applyIndexedLambda( + const std::function& func, NDArray& target); +template void NDArray::applyIndexedLambda( + const std::function& func, NDArray& target); +template void NDArray::applyIndexedLambda( + const std::function& func, NDArray& target); +template void NDArray::applyIndexedLambda( + const std::function& func, NDArray& target); ////////////////////////////////////////////////////////////////////////// -template -void NDArray::applyIndexedPairwiseLambda(NDArray& other, const std::function& func, NDArray& target) { - - if(dataType() != DataTypeUtils::fromT()) - throw std::runtime_error("NDArray::applyIndexedPairwiseLambda method: wrong template parameter T, its type should be the same as type of this array!"); - if(dataType() != target.dataType()) - throw std::runtime_error("NDArray::applyIndexedPairwiseLambda method: types of this and target array should match !"); - if (this->lengthOf() != other.lengthOf()) { - nd4j_printf("applyIndexedPairwiseLambda requires both operands to have the same shape\n",""); - throw std::runtime_error("Shapes mismach"); - } - - auto f = this->bufferAsT(); - auto s = other.bufferAsT(); - auto z = target.bufferAsT(); - - if (this->ordering() == other.ordering() && this->ordering() == target.ordering() && (this->ews() == 1 && target.ews() == 1) && this->ews() == other.ews()) { - - auto loop = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) - z[e] = func((Nd4jLong) e, f[e], s[e]); - }; +template +void NDArray::applyIndexedPairwiseLambda( + NDArray& other, const std::function& func, + NDArray& target) { + if (dataType() != DataTypeUtils::fromT()) + throw std::runtime_error( + "NDArray::applyIndexedPairwiseLambda method: wrong template " + "parameter T, its type should be the same as type of this array!"); + if (dataType() != target.dataType()) + throw std::runtime_error( + "NDArray::applyIndexedPairwiseLambda method: types of this and " + "target array should match !"); + if (this->lengthOf() != other.lengthOf()) { + nd4j_printf( + "applyIndexedPairwiseLambda requires both operands to have the same " + "shape\n", + ""); + throw std::runtime_error("Shapes mismach"); + } + + auto f = this->bufferAsT(); + auto s = other.bufferAsT(); + auto z = target.bufferAsT(); + + if (this->ordering() == other.ordering() && + this->ordering() == target.ordering() && + (this->ews() == 1 && target.ews() == 1) && this->ews() == other.ews()) { + auto loop = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) z[e] = func((Nd4jLong)e, f[e], s[e]); + }; + + samediff::Threads::parallel_for(loop, 0, _length); + } else { + if (f == z) { + auto loop = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + auto xOffset = this->getOffset(e); + auto yOffset = other.getOffset(e); + + f[xOffset] = func((Nd4jLong)e, f[xOffset], s[yOffset]); + } + }; - samediff::Threads::parallel_for(loop, 0, _length); + samediff::Threads::parallel_for(loop, 0, _length); } else { - if (f == z) { - - auto loop = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) { - auto xOffset = this->getOffset(e); - auto yOffset = other.getOffset(e); - - f[xOffset] = func((Nd4jLong) e, f[xOffset], s[yOffset]); - } - }; - - samediff::Threads::parallel_for(loop, 0, _length); - } else { + auto loop = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + auto xOffset = this->getOffset(e); + auto yOffset = other.getOffset(e); + auto zOffset = target.getOffset(e); - auto loop = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) { - auto xOffset = this->getOffset(e); - auto yOffset = other.getOffset(e); - auto zOffset = target.getOffset(e); - - z[zOffset] = func((Nd4jLong) e, f[xOffset], s[yOffset]); - } - }; - - samediff::Threads::parallel_for(loop, 0, _length); + z[zOffset] = func((Nd4jLong)e, f[xOffset], s[yOffset]); } + }; + + samediff::Threads::parallel_for(loop, 0, _length); } + } } -template void NDArray::applyIndexedPairwiseLambda(NDArray& other, const std::function& func, NDArray& target); -template void NDArray::applyIndexedPairwiseLambda(NDArray& other, const std::function& func, NDArray& target); -template void NDArray::applyIndexedPairwiseLambda(NDArray& other, const std::function& func, NDArray& target); -template void NDArray::applyIndexedPairwiseLambda(NDArray& other, const std::function& func, NDArray& target); -template void NDArray::applyIndexedPairwiseLambda(NDArray& other, const std::function& func, NDArray& target); -template void NDArray::applyIndexedPairwiseLambda(NDArray& other, const std::function& func, NDArray& target); -template void NDArray::applyIndexedPairwiseLambda(NDArray& other, const std::function& func, NDArray& target); -template void NDArray::applyIndexedPairwiseLambda(NDArray& other, const std::function& func, NDArray& target); -template void NDArray::applyIndexedPairwiseLambda(NDArray& other, const std::function& func, NDArray& target); -template void NDArray::applyIndexedPairwiseLambda(NDArray& other, const std::function& func, NDArray& target); -template void NDArray::applyIndexedPairwiseLambda(NDArray& other, const std::function& func, NDArray& target); -template void NDArray::applyIndexedPairwiseLambda(NDArray& other, const std::function& func, NDArray& target); -template void NDArray::applyIndexedPairwiseLambda(NDArray& other, const std::function& func, NDArray& target); \ No newline at end of file +template void NDArray::applyIndexedPairwiseLambda( + NDArray& other, const std::function& func, + NDArray& target); +template void NDArray::applyIndexedPairwiseLambda( + NDArray& other, const std::function& func, + NDArray& target); +template void NDArray::applyIndexedPairwiseLambda( + NDArray& other, + const std::function& func, + NDArray& target); +template void NDArray::applyIndexedPairwiseLambda( + NDArray& other, + const std::function& func, + NDArray& target); +template void NDArray::applyIndexedPairwiseLambda( + NDArray& other, + const std::function& func, + NDArray& target); +template void NDArray::applyIndexedPairwiseLambda( + NDArray& other, const std::function& func, + NDArray& target); +template void NDArray::applyIndexedPairwiseLambda( + NDArray& other, + const std::function& func, + NDArray& target); +template void NDArray::applyIndexedPairwiseLambda( + NDArray& other, + const std::function& func, + NDArray& target); +template void NDArray::applyIndexedPairwiseLambda( + NDArray& other, + const std::function& func, + NDArray& target); +template void NDArray::applyIndexedPairwiseLambda( + NDArray& other, + const std::function& func, + NDArray& target); +template void NDArray::applyIndexedPairwiseLambda( + NDArray& other, + const std::function& func, + NDArray& target); +template void NDArray::applyIndexedPairwiseLambda( + NDArray& other, const std::function& func, + NDArray& target); +template void NDArray::applyIndexedPairwiseLambda( + NDArray& other, const std::function& func, + NDArray& target); \ No newline at end of file diff --git a/libnd4j/include/array/cuda/DataBuffer.cu b/libnd4j/include/array/cuda/DataBuffer.cu index 7e88e06ba791..fb454fa320e4 100644 --- a/libnd4j/include/array/cuda/DataBuffer.cu +++ b/libnd4j/include/array/cuda/DataBuffer.cu @@ -19,274 +19,320 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include "../DataBuffer.h" #include -#include +#include #include #include #include -#include +#include +#include namespace sd { - void DataBuffer::expand(const uint64_t size) { - if (size > _lenInBytes) { - // allocate new buffer - int8_t *newBuffer = nullptr; - int8_t *newSpecialBuffer = nullptr; - ALLOCATE_SPECIAL(newSpecialBuffer, _workspace, size, int8_t); - - // copy data from existing buffer - if (_primaryBuffer != nullptr) { - // there's non-zero chance that primary buffer doesn't exist yet - ALLOCATE(newBuffer, _workspace, size, int8_t); - std::memcpy(newBuffer, _primaryBuffer, _lenInBytes); - - if (_isOwnerPrimary) { - auto ipb = reinterpret_cast(_primaryBuffer); - RELEASE(ipb, _workspace); - } - - _primaryBuffer = newBuffer; - _isOwnerPrimary = true; - } - - cudaMemcpy(newSpecialBuffer, _specialBuffer, _lenInBytes, cudaMemcpyDeviceToDevice); - - if (_isOwnerSpecial) { - auto isb = reinterpret_cast(_specialBuffer); - RELEASE_SPECIAL(isb, _workspace); - } - - _specialBuffer = newSpecialBuffer; - _lenInBytes = size; - _isOwnerSpecial = true; - } +void DataBuffer::expand(const uint64_t size) { + if (size > _lenInBytes) { + // allocate new buffer + int8_t* newBuffer = nullptr; + int8_t* newSpecialBuffer = nullptr; + ALLOCATE_SPECIAL(newSpecialBuffer, _workspace, size, int8_t); + + // copy data from existing buffer + if (_primaryBuffer != nullptr) { + // there's non-zero chance that primary buffer doesn't exist yet + ALLOCATE(newBuffer, _workspace, size, int8_t); + std::memcpy(newBuffer, _primaryBuffer, _lenInBytes); + + if (_isOwnerPrimary) { + auto ipb = reinterpret_cast(_primaryBuffer); + RELEASE(ipb, _workspace); + } + + _primaryBuffer = newBuffer; + _isOwnerPrimary = true; } -//////////////////////////////////////////////////////////////////////// -void DataBuffer::allocateSpecial() { - - if (_specialBuffer == nullptr && getLenInBytes() > 0) { - auto deviceId = sd::AffinityManager::currentDeviceId(); - - if (_workspace == nullptr) - if (!sd::memory::MemoryCounter::getInstance().validate(getLenInBytes())) - throw sd::allocation_exception::build("Requested amount exceeds device limits", sd::memory::MemoryCounter::getInstance().deviceLimit(deviceId), getLenInBytes()); - + cudaMemcpy(newSpecialBuffer, _specialBuffer, _lenInBytes, + cudaMemcpyDeviceToDevice); - ALLOCATE_SPECIAL(_specialBuffer, _workspace, getLenInBytes(), int8_t); - _isOwnerSpecial = true; - - if (_workspace == nullptr) { - sd::memory::MemoryCounter::getInstance().countIn(deviceId, getLenInBytes()); - sd::memory::MemoryCounter::getInstance().countIn(sd::memory::MemoryType::DEVICE, getLenInBytes()); - } + if (_isOwnerSpecial) { + auto isb = reinterpret_cast(_specialBuffer); + RELEASE_SPECIAL(isb, _workspace); } + + _specialBuffer = newSpecialBuffer; + _lenInBytes = size; + _isOwnerSpecial = true; + } } //////////////////////////////////////////////////////////////////////// -void DataBuffer::syncToPrimary(const LaunchContext* context, const bool forceSync) { - if(isPrimaryActual() && !forceSync) { - return; - } - - allocatePrimary(); +void DataBuffer::allocateSpecial() { + if (_specialBuffer == nullptr && getLenInBytes() > 0) { + auto deviceId = sd::AffinityManager::currentDeviceId(); - auto res = cudaStreamSynchronize(*context->getCudaStream()); - if (res != 0) - throw cuda_exception::build("DataBuffer::syncToPrimary failed to to some previous kernel failre", res); + if (_workspace == nullptr) + if (!sd::memory::MemoryCounter::getInstance().validate(getLenInBytes())) + throw sd::allocation_exception::build( + "Requested amount exceeds device limits", + sd::memory::MemoryCounter::getInstance().deviceLimit(deviceId), + getLenInBytes()); - res = cudaMemcpy(_primaryBuffer, _specialBuffer, getLenInBytes(), cudaMemcpyDeviceToHost); - if (res != 0) - throw cuda_exception::build("DataBuffer::syncToPrimary cudaMemcpy failed", res); + ALLOCATE_SPECIAL(_specialBuffer, _workspace, getLenInBytes(), int8_t); + _isOwnerSpecial = true; - readPrimary(); + if (_workspace == nullptr) { + sd::memory::MemoryCounter::getInstance().countIn(deviceId, + getLenInBytes()); + sd::memory::MemoryCounter::getInstance().countIn( + sd::memory::MemoryType::DEVICE, getLenInBytes()); + } + } } +//////////////////////////////////////////////////////////////////////// +void DataBuffer::syncToPrimary(const LaunchContext* context, + const bool forceSync) { + if (isPrimaryActual() && !forceSync) { + return; + } + + allocatePrimary(); + + auto res = cudaStreamSynchronize(*context->getCudaStream()); + if (res != 0) + throw cuda_exception::build( + "DataBuffer::syncToPrimary failed to to some previous kernel failre", + res); + + res = cudaMemcpy(_primaryBuffer, _specialBuffer, getLenInBytes(), + cudaMemcpyDeviceToHost); + if (res != 0) + throw cuda_exception::build("DataBuffer::syncToPrimary cudaMemcpy failed", + res); + + readPrimary(); +} //////////////////////////////////////////////////////////////////////// void DataBuffer::syncToSpecial(const bool forceSync) { - // in this case there's nothing to do here - if (_primaryBuffer == nullptr) - return; + // in this case there's nothing to do here + if (_primaryBuffer == nullptr) return; - if(isSpecialActual() && !forceSync) { - return; - } + if (isSpecialActual() && !forceSync) { + return; + } - allocateSpecial(); + allocateSpecial(); - auto res = cudaMemcpy(_specialBuffer, _primaryBuffer, getLenInBytes(), cudaMemcpyHostToDevice); - if (res != 0) - throw cuda_exception::build("DataBuffer::syncToSpecial cudaMemcpy failed", res); + auto res = cudaMemcpy(_specialBuffer, _primaryBuffer, getLenInBytes(), + cudaMemcpyHostToDevice); + if (res != 0) + throw cuda_exception::build("DataBuffer::syncToSpecial cudaMemcpy failed", + res); - readSpecial(); + readSpecial(); } - //////////////////////////////////////////////////////////////////////// void DataBuffer::deleteSpecial() { - - if(_isOwnerSpecial && _specialBuffer != nullptr && getLenInBytes() != 0) { - auto p = reinterpret_cast(_specialBuffer); - RELEASE_SPECIAL(p, _workspace); - _specialBuffer = nullptr; - _isOwnerSpecial = false; - - // count out towards DataBuffer device, only if we're not in workspace - if (_workspace == nullptr) { - sd::memory::MemoryCounter::getInstance().countOut(_deviceId, getLenInBytes()); - sd::memory::MemoryCounter::getInstance().countOut(sd::memory::MemoryType::DEVICE, getLenInBytes()); - } + if (_isOwnerSpecial && _specialBuffer != nullptr && getLenInBytes() != 0) { + auto p = reinterpret_cast(_specialBuffer); + RELEASE_SPECIAL(p, _workspace); + _specialBuffer = nullptr; + _isOwnerSpecial = false; + + // count out towards DataBuffer device, only if we're not in workspace + if (_workspace == nullptr) { + sd::memory::MemoryCounter::getInstance().countOut(_deviceId, + getLenInBytes()); + sd::memory::MemoryCounter::getInstance().countOut( + sd::memory::MemoryType::DEVICE, getLenInBytes()); } + } } - //////////////////////////////////////////////////////////////////////// void DataBuffer::setCountersToZero() { - - _counter.store(0L); - _writePrimary.store(0L); - _writeSpecial.store(0L); - _readPrimary.store(0L); - _readSpecial.store(0L); + _counter.store(0L); + _writePrimary.store(0L); + _writeSpecial.store(0L); + _readPrimary.store(0L); + _readSpecial.store(0L); } //////////////////////////////////////////////////////////////////////// void DataBuffer::copyCounters(const DataBuffer& other) { - - _counter.store(other._counter); - _writePrimary.store(other._readSpecial); - _writeSpecial.store(other._readPrimary); - _readPrimary.store(other._writeSpecial); - _readSpecial.store(other._writePrimary); + _counter.store(other._counter); + _writePrimary.store(other._readSpecial); + _writeSpecial.store(other._readPrimary); + _readPrimary.store(other._writeSpecial); + _readSpecial.store(other._writePrimary); } //////////////////////////////////////////////////////////////////////// -void DataBuffer::copyBufferFrom(const DataBuffer& other, size_t sizeToCopyinBytes, const Nd4jLong offsetThis, const Nd4jLong offsetOther) { // copies only to special buffer - - if(other._primaryBuffer == nullptr && other._specialBuffer == nullptr) - return; - - if(sizeToCopyinBytes == 0) - sizeToCopyinBytes = other.getLenInBytes(); - if(sizeToCopyinBytes == 0) - return; - - if(other.isPrimaryActual()) { - auto res = cudaMemcpy(static_cast(_specialBuffer) + offsetThis * DataTypeUtils::sizeOfElement(_dataType), static_cast(other._primaryBuffer) + offsetOther * DataTypeUtils::sizeOfElement(other._dataType), sizeToCopyinBytes, cudaMemcpyHostToDevice); - if (res != 0) - throw cuda_exception::build("DataBuffer::copyBufferFrom: cudaMemcpy_cudaMemcpyHostToDevice failed!", res); - other.readPrimary(); - } - else { - auto res = cudaMemcpy(static_cast(_specialBuffer) + offsetThis * DataTypeUtils::sizeOfElement(_dataType), static_cast(other._specialBuffer) + offsetOther * DataTypeUtils::sizeOfElement(other._dataType), sizeToCopyinBytes, cudaMemcpyDeviceToDevice); - if (res != 0) - throw cuda_exception::build("DataBuffer::copyBufferFrom: cudaMemcpy_cudaMemcpyDeviceToDevice failed!", res); - other.readSpecial(); - } - - writeSpecial(); +void DataBuffer::copyBufferFrom( + const DataBuffer& other, size_t sizeToCopyinBytes, + const Nd4jLong offsetThis, + const Nd4jLong offsetOther) { // copies only to special buffer + + if (other._primaryBuffer == nullptr && other._specialBuffer == nullptr) + return; + + if (sizeToCopyinBytes == 0) sizeToCopyinBytes = other.getLenInBytes(); + if (sizeToCopyinBytes == 0) return; + + if (other.isPrimaryActual()) { + auto res = cudaMemcpy( + static_cast(_specialBuffer) + + offsetThis * DataTypeUtils::sizeOfElement(_dataType), + static_cast(other._primaryBuffer) + + offsetOther * DataTypeUtils::sizeOfElement(other._dataType), + sizeToCopyinBytes, cudaMemcpyHostToDevice); + if (res != 0) + throw cuda_exception::build( + "DataBuffer::copyBufferFrom: cudaMemcpy_cudaMemcpyHostToDevice " + "failed!", + res); + other.readPrimary(); + } else { + auto res = cudaMemcpy( + static_cast(_specialBuffer) + + offsetThis * DataTypeUtils::sizeOfElement(_dataType), + static_cast(other._specialBuffer) + + offsetOther * DataTypeUtils::sizeOfElement(other._dataType), + sizeToCopyinBytes, cudaMemcpyDeviceToDevice); + if (res != 0) + throw cuda_exception::build( + "DataBuffer::copyBufferFrom: cudaMemcpy_cudaMemcpyDeviceToDevice " + "failed!", + res); + other.readSpecial(); + } + + writeSpecial(); } //////////////////////////////////////////////////////////////////////// -void DataBuffer::copyBufferFromHost(const void* hostBuffer, size_t sizeToCopyinBytes, const Nd4jLong offsetThis, const Nd4jLong offsetHostBuffer) { // copies only to special buffer - - if(hostBuffer == nullptr) - return; - - if(sizeToCopyinBytes == 0) - sizeToCopyinBytes = getLenInBytes(); - if(sizeToCopyinBytes == 0) - return; - - auto res = cudaMemcpy(static_cast(_specialBuffer) + offsetThis * DataTypeUtils::sizeOfElement(_dataType), static_cast(hostBuffer) + offsetHostBuffer * DataTypeUtils::sizeOfElement(_dataType), sizeToCopyinBytes, cudaMemcpyHostToDevice); - if (res != 0) - throw cuda_exception::build("DataBuffer::copyBufferFromHost: cudaMemcpy_cudaMemcpyHostToDevice failed!", res); - - writeSpecial(); +void DataBuffer::copyBufferFromHost( + const void* hostBuffer, size_t sizeToCopyinBytes, const Nd4jLong offsetThis, + const Nd4jLong offsetHostBuffer) { // copies only to special buffer + + if (hostBuffer == nullptr) return; + + if (sizeToCopyinBytes == 0) sizeToCopyinBytes = getLenInBytes(); + if (sizeToCopyinBytes == 0) return; + + auto res = + cudaMemcpy(static_cast(_specialBuffer) + + offsetThis * DataTypeUtils::sizeOfElement(_dataType), + static_cast(hostBuffer) + + offsetHostBuffer * DataTypeUtils::sizeOfElement(_dataType), + sizeToCopyinBytes, cudaMemcpyHostToDevice); + if (res != 0) + throw cuda_exception::build( + "DataBuffer::copyBufferFromHost: cudaMemcpy_cudaMemcpyHostToDevice " + "failed!", + res); + + writeSpecial(); } //////////////////////////////////////////////////////////////////////// void DataBuffer::setSpecial(void* special, const bool isOwnerSpecial) { - - deleteSpecial(); - _specialBuffer = special; - _isOwnerSpecial = isOwnerSpecial; + deleteSpecial(); + _specialBuffer = special; + _isOwnerSpecial = isOwnerSpecial; } - //////////////////////////////////////////////////////////////////////// -void DataBuffer::allocateBuffers(const bool allocBoth) { // always allocate special buffer only (cuda case) +void DataBuffer::allocateBuffers( + const bool allocBoth) { // always allocate special buffer only (cuda case) - allocateSpecial(); + allocateSpecial(); - if(allocBoth) - allocatePrimary(); + if (allocBoth) allocatePrimary(); } //////////////////////////////////////////////////////////////////////// void DataBuffer::setToZeroBuffers(const bool both) { - cudaMemsetAsync(special(), 0, getLenInBytes(), *LaunchContext::defaultContext()->getCudaStream()); - auto res = cudaStreamSynchronize(*LaunchContext::defaultContext()->getCudaStream()); - if (res != 0) - throw cuda_exception::build("DataBuffer::setToZeroBuffers: streamSync failed!", res); - - writeSpecial(); - - if(both) { - memset(primary(), 0, getLenInBytes()); - readPrimary(); - } + cudaMemsetAsync(special(), 0, getLenInBytes(), + *LaunchContext::defaultContext()->getCudaStream()); + auto res = + cudaStreamSynchronize(*LaunchContext::defaultContext()->getCudaStream()); + if (res != 0) + throw cuda_exception::build( + "DataBuffer::setToZeroBuffers: streamSync failed!", res); + + writeSpecial(); + + if (both) { + memset(primary(), 0, getLenInBytes()); + readPrimary(); + } } ///////////////////////// -void DataBuffer::memcpy(const DataBuffer &dst, const DataBuffer &src) { - if (src._lenInBytes > dst._lenInBytes) - throw std::runtime_error("DataBuffer::memcpy: Source data buffer is larger than destination"); - - - int res = 0; - if (src.isSpecialActual()) { - res = cudaMemcpyAsync(dst._specialBuffer, src._specialBuffer, src.getLenInBytes(), cudaMemcpyDeviceToDevice, *LaunchContext::defaultContext()->getCudaStream()); - } else if (src.isPrimaryActual()) { - res = cudaMemcpyAsync(dst._specialBuffer, src._primaryBuffer, src.getLenInBytes(), cudaMemcpyHostToDevice, *LaunchContext::defaultContext()->getCudaStream()); - } - - if (res != 0) - throw cuda_exception::build("DataBuffer::memcpy: cudaMemcpyAsync failed!", res); - - res = cudaStreamSynchronize(*LaunchContext::defaultContext()->getCudaStream()); - if (res != 0) - throw cuda_exception::build("DataBuffer::memcpy: streamSync failed!", res); - - dst.writeSpecial(); +void DataBuffer::memcpy(const DataBuffer& dst, const DataBuffer& src) { + if (src._lenInBytes > dst._lenInBytes) + throw std::runtime_error( + "DataBuffer::memcpy: Source data buffer is larger than destination"); + + int res = 0; + if (src.isSpecialActual()) { + res = cudaMemcpyAsync(dst._specialBuffer, src._specialBuffer, + src.getLenInBytes(), cudaMemcpyDeviceToDevice, + *LaunchContext::defaultContext()->getCudaStream()); + } else if (src.isPrimaryActual()) { + res = cudaMemcpyAsync(dst._specialBuffer, src._primaryBuffer, + src.getLenInBytes(), cudaMemcpyHostToDevice, + *LaunchContext::defaultContext()->getCudaStream()); + } + + if (res != 0) + throw cuda_exception::build("DataBuffer::memcpy: cudaMemcpyAsync failed!", + res); + + res = + cudaStreamSynchronize(*LaunchContext::defaultContext()->getCudaStream()); + if (res != 0) + throw cuda_exception::build("DataBuffer::memcpy: streamSync failed!", res); + + dst.writeSpecial(); } //////////////////////////////////////////////////////////////////////// void DataBuffer::migrate() { - memory::Workspace* newWorkspace = nullptr; - void* newBuffer; - ALLOCATE_SPECIAL(newBuffer, newWorkspace, getLenInBytes(), int8_t); - auto res = cudaMemcpy(newBuffer, _specialBuffer, getLenInBytes(), cudaMemcpyDeviceToDevice); - if (res != 0) - throw cuda_exception::build("DataBuffer::migrate: cudaMemcpyAsync failed!", res); - - if (_isOwnerSpecial) { - // now we're releasing original buffer - RELEASE_SPECIAL(_specialBuffer, _workspace); - } - - _isOwnerSpecial = true; - _specialBuffer = newBuffer; + memory::Workspace* newWorkspace = nullptr; + void* newBuffer; + ALLOCATE_SPECIAL(newBuffer, newWorkspace, getLenInBytes(), int8_t); + auto res = cudaMemcpy(newBuffer, _specialBuffer, getLenInBytes(), + cudaMemcpyDeviceToDevice); + if (res != 0) + throw cuda_exception::build("DataBuffer::migrate: cudaMemcpyAsync failed!", + res); + + if (_isOwnerSpecial) { + // now we're releasing original buffer + RELEASE_SPECIAL(_specialBuffer, _workspace); + } + + _isOwnerSpecial = true; + _specialBuffer = newBuffer; } -//////////////////////////////////////////////////////////////////////// -void DataBuffer::writePrimary() const {_writePrimary = ++_counter; } -void DataBuffer::writeSpecial() const { _writeSpecial = ++_counter; } -void DataBuffer::readPrimary() const { _readPrimary = ++_counter; } -void DataBuffer::readSpecial() const { _readSpecial = ++_counter; } -bool DataBuffer::isPrimaryActual() const { return (_writePrimary.load() > _writeSpecial.load() || _readPrimary.load() > _writeSpecial.load()); } -bool DataBuffer::isSpecialActual() const { return (_writeSpecial.load() > _writePrimary.load() || _readSpecial.load() > _writePrimary.load()); } +void* DataBuffer::platform() { return special(); } +//////////////////////////////////////////////////////////////////////// +void DataBuffer::writePrimary() const { _writePrimary = ++_counter; } +void DataBuffer::writeSpecial() const { _writeSpecial = ++_counter; } +void DataBuffer::readPrimary() const { _readPrimary = ++_counter; } +void DataBuffer::readSpecial() const { _readSpecial = ++_counter; } +bool DataBuffer::isPrimaryActual() const { + return (_writePrimary.load() > _writeSpecial.load() || + _readPrimary.load() > _writeSpecial.load()); } +bool DataBuffer::isSpecialActual() const { + return (_writeSpecial.load() > _writePrimary.load() || + _readSpecial.load() > _writePrimary.load()); +} + +} // namespace sd diff --git a/libnd4j/include/graph/execution/impl/LogicScope.cpp b/libnd4j/include/array/cuda/ManagedDataBuffer.cu similarity index 66% rename from libnd4j/include/graph/execution/impl/LogicScope.cpp rename to libnd4j/include/array/cuda/ManagedDataBuffer.cu index 1773aa6ea766..610b2faf9a3b 100644 --- a/libnd4j/include/graph/execution/impl/LogicScope.cpp +++ b/libnd4j/include/array/cuda/ManagedDataBuffer.cu @@ -1,5 +1,5 @@ /******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -15,19 +15,13 @@ ******************************************************************************/ // -// Created by raver119 on 20.10.2017. +// @author raver119@gmail.com // -#include -#include - +#include namespace sd { - namespace graph { - Nd4jStatus LogicScope::processNode(Graph *graph, Node *node) { - // this op is basically no-op - // we just know it exists - return sd::Status::OK(); - } - } -} \ No newline at end of file +void *ManagedDataBuffer::primary() { return _descriptor.address(); } + +void *ManagedDataBuffer::special() { return nullptr; } +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/array/cuda/NDArray.cu b/libnd4j/include/array/cuda/NDArray.cu index f28e2ba22316..fd95c73637aa 100644 --- a/libnd4j/include/array/cuda/NDArray.cu +++ b/libnd4j/include/array/cuda/NDArray.cu @@ -19,560 +19,671 @@ #include #include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include #include -#include +#include +#include +#include +#include +#include #include -#include +#include #include +#include +#include #include -#include + #include -#include -#include -#include -#include -#include -#include -#include -#include #include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include +#include namespace sd { -void* NDArray::platformBuffer() { return specialBuffer(); } -void const* NDArray::platformBuffer() const { return specialBuffer(); } +void* NDArray::platformBuffer() { return specialBuffer(); } +void const* NDArray::platformBuffer() const { return specialBuffer(); } -Nd4jLong const* NDArray::platformShapeInfo() const { return specialShapeInfo(); } -//Nd4jLong const* NDArray::platform() { return special(); } +Nd4jLong const* NDArray::platformShapeInfo() const { + return specialShapeInfo(); +} +// Nd4jLong const* NDArray::platformShapeInfo() { return +// specialShapeInfo(); } -void NDArray::syncToDevice() const { - auto currentDeviceId = AffinityManager::currentDeviceId(); - if (currentDeviceId != _deviceId) { - // first of all we update shapeInfo - const_cast(this)->setShapeInfo(this->shapeInfo()); +void NDArray::syncToDevice() const { + auto currentDeviceId = AffinityManager::currentDeviceId(); + if (currentDeviceId != _deviceId) { + // first of all we update shapeInfo + const_cast(this)->setShapeInfo(this->shapeInfo()); - // now we actually migrate data buffer - _buffer->migrate(); - } + // now we actually migrate data buffer + _buffer->migrate(); + } - _buffer->syncToSpecial(); + _buffer->syncToSpecial(); } -void NDArray::syncToHost() const { _buffer->syncToPrimary(getContext()); } -void NDArray::tickWriteHost() const { _buffer->writePrimary(); } -void NDArray::tickWriteDevice() const { _buffer->writeSpecial(); } -void NDArray::tickReadHost() const { _buffer->readPrimary(); } -void NDArray::tickReadDevice() const { _buffer->readSpecial(); } -void NDArray::tickBothActual() const { _buffer->writePrimary(); _buffer->readSpecial(); } -bool NDArray::isActualOnHostSide() const { return _buffer->isPrimaryActual(); } -bool NDArray::isActualOnDeviceSide() const { return _buffer->isSpecialActual(); } -void NDArray::makeBothBuffersActual() const { if(!isActualOnHostSide()) syncToHost(); if(!isActualOnDeviceSide()) syncToDevice(); } +void NDArray::syncToHost() const { _buffer->syncToPrimary(getContext()); } +void NDArray::tickWriteHost() const { _buffer->writePrimary(); } +void NDArray::tickWriteDevice() const { _buffer->writeSpecial(); } +void NDArray::tickReadHost() const { _buffer->readPrimary(); } +void NDArray::tickReadDevice() const { _buffer->readSpecial(); } +void NDArray::tickBothActual() const { + _buffer->writePrimary(); + _buffer->readSpecial(); +} +bool NDArray::isActualOnHostSide() const { return _buffer->isPrimaryActual(); } +bool NDArray::isActualOnDeviceSide() const { + return _buffer->isSpecialActual(); +} +void NDArray::makeBothBuffersActual() const { + if (!isActualOnHostSide()) syncToHost(); + if (!isActualOnDeviceSide()) syncToDevice(); +} /////////////////////////////////////////////////////////////////// -template -__global__ static void fillAsTriangularCuda(const void* vx, const Nd4jLong* xShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const T val, const int lower, const int upper) { - - const auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - - __shared__ int zRank, xRank, areSameOffsets, *sharedMem; // xRank == zRank always, except when xRank = 1, in this case zRank = 2 - __shared__ Nd4jLong zLen, totalThreads; // xLen == zLen, except when xRank = 1, in this case zLen = 2*xLen - - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - sharedMem = reinterpret_cast(shmem); - areSameOffsets = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo); - xRank = shape::rank(xShapeInfo); - zRank = shape::rank(zShapeInfo); - zLen = shape::length(zShapeInfo); - totalThreads = gridDim.x * blockDim.x; - } - __syncthreads(); - - auto coords = sharedMem + threadIdx.x * zRank; - - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - - for (Nd4jLong i = tid; i < zLen; i += totalThreads) { - - shape::index2coords(i, zShapeInfo, coords); - const auto zOffset = shape::getOffset(zShapeInfo, coords); - - // if( (row + upper < col) || (row + lower > col) ) - if((coords[zRank - 2] + upper < coords[zRank - 1]) || (coords[zRank - 2] + lower > coords[zRank - 1])) - z[zOffset] = val; - else if(vx != vz) { // when x and z are different arrays - if(xRank != zRank) - coords[0] = coords[1]; - const auto xOffset = areSameOffsets ? zOffset : shape::getOffset(xShapeInfo, coords); - z[zOffset] = x[xOffset]; - } +template +__global__ static void fillAsTriangularCuda( + const void* vx, const Nd4jLong* xShapeInfo, void* vz, + const Nd4jLong* zShapeInfo, const T val, const int lower, const int upper) { + const auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + + __shared__ int zRank, xRank, areSameOffsets, + *sharedMem; // xRank == zRank always, except when xRank = 1, in this case + // zRank = 2 + __shared__ Nd4jLong zLen, totalThreads; // xLen == zLen, except when xRank = + // 1, in this case zLen = 2*xLen + + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sharedMem = reinterpret_cast(shmem); + areSameOffsets = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo); + xRank = shape::rank(xShapeInfo); + zRank = shape::rank(zShapeInfo); + zLen = shape::length(zShapeInfo); + totalThreads = gridDim.x * blockDim.x; + } + __syncthreads(); + + auto coords = sharedMem + threadIdx.x * zRank; + + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + + for (Nd4jLong i = tid; i < zLen; i += totalThreads) { + shape::index2coords(i, zShapeInfo, coords); + const auto zOffset = shape::getOffset(zShapeInfo, coords); + + // if( (row + upper < col) || (row + lower > col) ) + if ((coords[zRank - 2] + upper < coords[zRank - 1]) || + (coords[zRank - 2] + lower > coords[zRank - 1])) + z[zOffset] = val; + else if (vx != vz) { // when x and z are different arrays + if (xRank != zRank) coords[0] = coords[1]; + const auto xOffset = + areSameOffsets ? zOffset : shape::getOffset(xShapeInfo, coords); + z[zOffset] = x[xOffset]; } + } } /////////////////////////////////////////////////////////////////// -template -void NDArray::fillAsTriangular(const float val, int lower, int upper, NDArray& target, const char direction) { - - if (isS()) - throw std::runtime_error("NDArray::fillAsTriangular: you can't use this method on String array!"); - - if(!isSameShape(target) && !(rankOf() == 1 && target.rankOf() == 2 && sizeAt(0) == target.sizeAt(0) && sizeAt(0) == target.sizeAt(1))) - throw std::string("NDArray::fillAsTriangular method: wrong shape of target array !"); - - if (direction == 'u') - lower = -target.sizeAt(-2); - else if (direction == 'l') - upper = target.sizeAt(-1); - - const int threadsPerBlock = MAX_NUM_THREADS / 4; - const int blocksPerGrid = (target.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - const int sharedMem = threadsPerBlock * sizeof(int) * target.rankOf() + 128; - - PointersManager manager(getContext(), "NDArray::fillAsTriangular"); - - NDArray::prepareSpecialUse({&target}, {this}); - fillAsTriangularCuda<<getCudaStream()>>>(platformBuffer(), platformShapeInfo(), target.platformBuffer(), target.platformShapeInfo(), static_cast(val), lower, upper); - NDArray::registerSpecialUse({&target}, {this}); - - manager.synchronize(); +template +void NDArray::fillAsTriangular(const float val, int lower, int upper, + NDArray& target, const char direction) { + if (isS()) + throw std::runtime_error( + "NDArray::fillAsTriangular: you can't use this method on String " + "array!"); + + if (!isSameShape(target) && + !(rankOf() == 1 && target.rankOf() == 2 && + sizeAt(0) == target.sizeAt(0) && sizeAt(0) == target.sizeAt(1))) + throw std::string( + "NDArray::fillAsTriangular method: wrong shape of target array !"); + + if (direction == 'u') + lower = -target.sizeAt(-2); + else if (direction == 'l') + upper = target.sizeAt(-1); + + const int threadsPerBlock = MAX_NUM_THREADS / 4; + const int blocksPerGrid = + (target.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + const int sharedMem = threadsPerBlock * sizeof(int) * target.rankOf() + 128; + + PointersManager manager(getContext(), "NDArray::fillAsTriangular"); + + NDArray::prepareSpecialUse({&target}, {this}); + fillAsTriangularCuda<<getCudaStream()>>>( + platformBuffer(), platformShapeInfo(), target.platformBuffer(), + target.platformShapeInfo(), static_cast(val), lower, upper); + NDArray::registerSpecialUse({&target}, {this}); + + manager.synchronize(); } -BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT void NDArray::fillAsTriangular, (const float val, int lower, int upper, NDArray& target, const char direction), LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template SD_EXPORT void NDArray::fillAsTriangular, + (const float val, int lower, int upper, NDArray& target, + const char direction), + LIBND4J_TYPES); //////////////////////////////////////////////////////////////////////// -template -__global__ static void identityMatrixCuda(void* vx, const Nd4jLong* xShapeInfo, const T val) { - - auto x = reinterpret_cast(vx); - - __shared__ int rank, *sharedMem; - __shared__ Nd4jLong len, totalThreads; // xLen == zLen, except when xRank = 1, in this case zLen = 2*xLen - - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - sharedMem = reinterpret_cast(shmem); - rank = shape::rank(xShapeInfo); - len = shape::length(xShapeInfo); - totalThreads = gridDim.x * blockDim.x; - } - __syncthreads(); - - auto coords = sharedMem + threadIdx.x * rank; - - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - - for (Nd4jLong i = tid; i < len; i += totalThreads) { - - shape::index2coords(i, xShapeInfo, coords); - const auto offset = shape::getOffset(xShapeInfo, coords); - - if(coords[rank - 2] == coords[rank - 1]) // row == col -> on diagonal - x[offset] = val; - else - x[offset] = static_cast(0); - } +template +__global__ static void identityMatrixCuda(void* vx, const Nd4jLong* xShapeInfo, + const T val) { + auto x = reinterpret_cast(vx); + + __shared__ int rank, *sharedMem; + __shared__ Nd4jLong len, totalThreads; // xLen == zLen, except when xRank = + // 1, in this case zLen = 2*xLen + + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sharedMem = reinterpret_cast(shmem); + rank = shape::rank(xShapeInfo); + len = shape::length(xShapeInfo); + totalThreads = gridDim.x * blockDim.x; + } + __syncthreads(); + + auto coords = sharedMem + threadIdx.x * rank; + + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + + for (Nd4jLong i = tid; i < len; i += totalThreads) { + shape::index2coords(i, xShapeInfo, coords); + const auto offset = shape::getOffset(xShapeInfo, coords); + + if (coords[rank - 2] == coords[rank - 1]) // row == col -> on diagonal + x[offset] = val; + else + x[offset] = static_cast(0); + } } /////////////////////////////////////////////////////////////////// -template -static void identityMatrixCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, void* vx, const Nd4jLong *xShapeInfo, const float val) { - - identityMatrixCuda<<>>(vx, xShapeInfo, static_cast(val)); +template +static void identityMatrixCudaLauncher(const int blocksPerGrid, + const int threadsPerBlock, + const int sharedMem, + const cudaStream_t* stream, void* vx, + const Nd4jLong* xShapeInfo, + const float val) { + identityMatrixCuda<<>>( + vx, xShapeInfo, static_cast(val)); } -BUILD_SINGLE_TEMPLATE(template void identityMatrixCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, void* vx, const Nd4jLong *xShapeInfo, const float val), LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template void identityMatrixCudaLauncher, + (const int blocksPerGrid, const int threadsPerBlock, + const int sharedMem, const cudaStream_t* stream, + void* vx, const Nd4jLong* xShapeInfo, const float val), + LIBND4J_TYPES); //////////////////////////////////////////////////////////////////////// void NDArray::setIdentity() { - if (isS()) - throw std::runtime_error("NDArray::setIdentity: you can't use this method on String array!"); - - // if (rankOf() != 2) - // throw std::runtime_error("NDArray::setIdentity: method should work only for 2D tensors. But " + toStringValue(rankOf()) + " was given."); - - const int threadsPerBlock = MAX_NUM_THREADS / 4; - const int blocksPerGrid = (lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - const int sharedMem = threadsPerBlock * sizeof(int) * rankOf() + 128; - - PointersManager manager(getContext(), "NDArray::setIdentity"); - - syncToDevice(); - BUILD_SINGLE_SELECTOR(dataType(), identityMatrixCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, getContext()->getCudaStream(), platformBuffer(), platformShapeInfo(), 1.f), LIBND4J_TYPES); - tickWriteDevice(); - - manager.synchronize(); + if (isS()) + throw std::runtime_error( + "NDArray::setIdentity: you can't use this method on String array!"); + + // if (rankOf() != 2) + // throw std::runtime_error("NDArray::setIdentity: method should work only + // for 2D tensors. But " + toStringValue(rankOf()) + " was given."); + + const int threadsPerBlock = MAX_NUM_THREADS / 4; + const int blocksPerGrid = + (lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + const int sharedMem = threadsPerBlock * sizeof(int) * rankOf() + 128; + + PointersManager manager(getContext(), "NDArray::setIdentity"); + + syncToDevice(); + BUILD_SINGLE_SELECTOR( + dataType(), identityMatrixCudaLauncher, + (blocksPerGrid, threadsPerBlock, sharedMem, getContext()->getCudaStream(), + platformBuffer(), platformShapeInfo(), 1.f), + LIBND4J_TYPES); + tickWriteDevice(); + + manager.synchronize(); } //////////////////////////////////////////////////////////////////////////////////////////////////////////////////// void NDArray::swapUnsafe(NDArray& other) { - auto xType = this->dataType(); + auto xType = this->dataType(); - if (xType != other.dataType()) - throw std::runtime_error("NDArray::swapUnsage method: both arrays must have the same data type"); + if (xType != other.dataType()) + throw std::runtime_error( + "NDArray::swapUnsage method: both arrays must have the same data type"); - if(specialBuffer() == nullptr || other.specialBuffer() == nullptr) - throw std::runtime_error("NDArray::swapUnsafe method: input array should not be empty!"); + if (specialBuffer() == nullptr || other.specialBuffer() == nullptr) + throw std::runtime_error( + "NDArray::swapUnsafe method: input array should not be empty!"); - if(lengthOf() != other.lengthOf()) - throw std::runtime_error("NDArray::swapUnsafe method: input arrays should have the same length!"); + if (lengthOf() != other.lengthOf()) + throw std::runtime_error( + "NDArray::swapUnsafe method: input arrays should have the same " + "length!"); - PointersManager manager(getContext(), "NDArray::swapUnsafe"); + PointersManager manager(getContext(), "NDArray::swapUnsafe"); - prepareSpecialUse({&other, this}, {&other, this}); - BUILD_SINGLE_SELECTOR(xType, templatedSwapUnsafe, (specialBuffer(), specialShapeInfo(), other.specialBuffer(), other.specialShapeInfo(), getContext()->getCudaStream()), LIBND4J_TYPES); - registerSpecialUse({&other, this}, {&other, this}); + prepareSpecialUse({&other, this}, {&other, this});BUILD_SINGLE_SELECTOR( + xType, templatedSwapUnsafe, + (specialBuffer(), specialShapeInfo(), other.specialBuffer(), + other.specialShapeInfo(), getContext()->getCudaStream()), + LIBND4J_TYPES);registerSpecialUse({&other, this}, {&other, this}); manager.synchronize(); } //////////////////////////////////////////////////////////////////////// void NDArray::synchronize(const char* msg) const { - auto res = cudaStreamSynchronize(*(getContext()->getCudaStream())); - if (res != 0) - throw std::runtime_error(msg + std::string(": synchronization failed !")); + auto res = cudaStreamSynchronize(*(getContext()->getCudaStream())); + if (res != 0) + throw std::runtime_error(msg + std::string(": synchronization failed !")); } //////////////////////////////////////////////////////////////////////// -void NDArray::prepareSpecialUse(const std::vector& writeList, const std::vector& readList, bool synchronizeWritables) { - - for (const auto& a : readList) - if(a != nullptr) - a->syncToDevice(); - - for (const auto& a : writeList) { - if (a != nullptr) { - a->getDataBuffer()->allocateSpecial(); - if (synchronizeWritables) - a->syncToDevice(); - } +void NDArray::prepareSpecialUse(const std::vector& writeList, + const std::vector& readList, + bool synchronizeWritables) { + for (const auto& a : readList) + if (a != nullptr) a->syncToDevice(); + + for (const auto& a : writeList) { + if (a != nullptr) { + a->getDataBuffer()->allocateSpecial(); + if (synchronizeWritables) a->syncToDevice(); } + } } //////////////////////////////////////////////////////////////////////// -void NDArray::registerSpecialUse(const std::vector& writeList, const std::vector& readList) { +void NDArray::registerSpecialUse(const std::vector& writeList, + const std::vector& readList) { + for (const auto& p : readList) + if (p != nullptr) p->tickReadDevice(); - for (const auto& p : readList) - if(p != nullptr) - p->tickReadDevice(); - - for (const auto& p : writeList) - if (p != nullptr) - p->tickWriteDevice(); + for (const auto& p : writeList) + if (p != nullptr) p->tickWriteDevice(); } //////////////////////////////////////////////////////////////////////// -void NDArray::preparePrimaryUse(const std::vector& writeList, const std::vector& readList, bool synchronizeWritables) { - - for (const auto& a : readList) - if(a != nullptr) - a->syncToHost(); - - for (const auto& a : writeList) { - if (a != nullptr) { - a->getDataBuffer()->allocatePrimary(); - if (synchronizeWritables) - a->syncToHost(); - } +void NDArray::preparePrimaryUse(const std::vector& writeList, + const std::vector& readList, + bool synchronizeWritables) { + for (const auto& a : readList) + if (a != nullptr) a->syncToHost(); + + for (const auto& a : writeList) { + if (a != nullptr) { + a->getDataBuffer()->allocatePrimary(); + if (synchronizeWritables) a->syncToHost(); } + } } //////////////////////////////////////////////////////////////////////// -void NDArray::registerPrimaryUse(const std::vector& writeList, const std::vector& readList) { +void NDArray::registerPrimaryUse(const std::vector& writeList, + const std::vector& readList) { + for (const auto& p : readList) + if (p != nullptr) p->tickReadHost(); - for (const auto& p : readList) - if(p != nullptr) - p->tickReadHost(); - - for (const auto& p : writeList) - if (p != nullptr) - p->tickWriteHost(); + for (const auto& p : writeList) + if (p != nullptr) p->tickWriteHost(); } ////////////////////////////////////////////////////////////////////////// void NDArray::syncShape() const { - cudaMemcpy(const_cast(specialShapeInfo()), shapeInfo(), shape::shapeInfoByteLength(shapeInfo()), cudaMemcpyHostToDevice); + cudaMemcpy(const_cast(specialShapeInfo()), shapeInfo(), + shape::shapeInfoByteLength(shapeInfo()), cudaMemcpyHostToDevice); } ////////////////////////////////////////////////////////////////////////// void const* NDArray::specialBufferWithOffset(Nd4jLong offset) const { - return specialBuffer() != nullptr ? static_cast(specialBuffer()) + (offset * sizeOfT()) : nullptr; + return specialBuffer() != nullptr + ? static_cast(specialBuffer()) + + (offset * sizeOfT()) + : nullptr; } -void* NDArray::specialBufferWithOffset(Nd4jLong offset){ - return specialBuffer() != nullptr ? static_cast(specialBuffer()) + (offset * sizeOfT()) : nullptr; +void* NDArray::specialBufferWithOffset(Nd4jLong offset) { + return specialBuffer() != nullptr + ? static_cast(specialBuffer()) + (offset * sizeOfT()) + : nullptr; } ////////////////////////////////////////////////////////////////////////// // change an array by repeating it the number of times given by reps. NDArray NDArray::tile(const std::vector& reps) const { - int dim = reps.size(); - Nd4jLong product = 1; - for(const auto& item : reps) - product *= item; - - if(product < 1) - throw std::runtime_error("NDArray::tile method: one of the elements in reps array is zero !"); - - int rankOld = rankOf(); - int diff = rankOld - dim; - if(product==1) { // in this case 2 possibilities are present: just reshape or nothing to do - NDArray result(*this); - if(diff < 0) { // reshape to higher dimension - std::vector shapeNew = reps; // need to have unities at first "diff" positions of new shape - memcpy(&shapeNew[-diff], result.shapeInfo()+1, rankOld * sizeof(Nd4jLong)); // put old shape numbers at rest of positions - result.reshapei(ordering(), shapeNew); - } - return result; // nothing to do, if diff >= 0 -> identity tile + int dim = reps.size(); + Nd4jLong product = 1; + for (const auto& item : reps) product *= item; + + if (product < 1) + throw std::runtime_error( + "NDArray::tile method: one of the elements in reps array is zero !"); + + int rankOld = rankOf(); + int diff = rankOld - dim; + if (product == 1) { // in this case 2 possibilities are present: just reshape + // or nothing to do + NDArray result(*this); + if (diff < 0) { // reshape to higher dimension + std::vector shapeNew = + reps; // need to have unities at first "diff" positions of new shape + memcpy( + &shapeNew[-diff], result.shapeInfo() + 1, + rankOld * + sizeof(Nd4jLong)); // put old shape numbers at rest of positions + result.reshapei(ordering(), shapeNew); } - - // evaluate shapeInfo for resulting array - auto newShapeInfo = ShapeUtils::evalTileShapeInfo(*this, reps, getContext()->getWorkspace()); - // create new buffer, in any case the memory amount new buffer points to is bigger then those for old _buffer - std::shared_ptr newBuff = std::make_shared(shape::length(newShapeInfo) * sizeOfT(), dataType(), getContext()->getWorkspace(), true); - // assign new shape and new buffer to resulting array - NDArray result(newBuff, ShapeDescriptor(newShapeInfo), getContext()); - - // fill newBuff, loop through all elements of newBuff - // looping through buffer() goes automatically by means of getSubArrayIndex applying - const auto resultLen = result.lengthOf(); - auto xType = this->dataType(); - auto stream = getContext()->getCudaStream(); - - prepareSpecialUse({&result}, {this}); - BUILD_SINGLE_SELECTOR(xType, tileKernelH, (this->specialBuffer(), this->specialShapeInfo(), result.specialBuffer(), result.specialShapeInfo(), resultLen, stream), LIBND4J_TYPES); - registerSpecialUse({&result}, {this}); - - return result; + return result; // nothing to do, if diff >= 0 -> identity tile + } + + // evaluate shapeInfo for resulting array + auto newShapeInfo = + ShapeUtils::evalTileShapeInfo(*this, reps, getContext()->getWorkspace()); + // create new buffer, in any case the memory amount new buffer points to is + // bigger then those for old _buffer + std::shared_ptr newBuff = std::make_shared( + shape::length(newShapeInfo) * sizeOfT(), dataType(), + getContext()->getWorkspace(), true); + // assign new shape and new buffer to resulting array + NDArray result(newBuff, ShapeDescriptor(newShapeInfo), getContext()); + + // fill newBuff, loop through all elements of newBuff + // looping through buffer() goes automatically by means of getSubArrayIndex + // applying + const auto resultLen = result.lengthOf(); + auto xType = this->dataType(); + auto stream = getContext()->getCudaStream(); + + prepareSpecialUse({&result}, {this}); + BUILD_SINGLE_SELECTOR( + xType, tileKernelH, + (this->specialBuffer(), this->specialShapeInfo(), result.specialBuffer(), + result.specialShapeInfo(), resultLen, stream), + LIBND4J_TYPES); + registerSpecialUse({&result}, {this}); + + return result; } ////////////////////////////////////////////////////////////////////////// // change an array by repeating it the number of times given by reps. void NDArray::tile(const std::vector& reps, NDArray& target) const { - - auto repProd = shape::prodLong(reps.data(), reps.size()); - if (repProd < 1) - throw std::runtime_error("NDArray::tile: reps can't contain 0s"); - - // evaluate true tile shapeInfo for comparison with target shapeInfo - auto newShapeInfo = ShapeUtils::evalTileShapeInfo(*this, reps, getContext()->getWorkspace()); - if(!shape::equalsSoft(newShapeInfo, target.shapeInfo())) { - throw std::runtime_error("NDArray::tile method - shapeInfo of target array is not suitable for tile operation !"); - } - - // fill newBuff, loop through all elements of newBuff - // looping through buffer() goes automatically by means of getSubArrayIndex applying - const int ews = target.ews(); - const int targetLen = target.lengthOf(); - auto stream = getContext()->getCudaStream(); - - prepareSpecialUse({&target}, {this}); - BUILD_SINGLE_SELECTOR_TWICE(target.dataType(), tileKernelHH, (specialBuffer(), specialShapeInfo(), target.specialBuffer(), target.specialShapeInfo(), targetLen, ews, stream), LIBND4J_TYPES); - registerSpecialUse({&target}, {this}); + auto repProd = shape::prodLong(reps.data(), reps.size()); + if (repProd < 1) + throw std::runtime_error("NDArray::tile: reps can't contain 0s"); + + // evaluate true tile shapeInfo for comparison with target shapeInfo + auto newShapeInfo = + ShapeUtils::evalTileShapeInfo(*this, reps, getContext()->getWorkspace()); + if (!shape::equalsSoft(newShapeInfo, target.shapeInfo())) { + throw std::runtime_error( + "NDArray::tile method - shapeInfo of target array is not suitable for " + "tile operation !"); + } + + // fill newBuff, loop through all elements of newBuff + // looping through buffer() goes automatically by means of getSubArrayIndex + // applying + const int ews = target.ews(); + const int targetLen = target.lengthOf(); + auto stream = getContext()->getCudaStream(); + + prepareSpecialUse({&target}, {this}); + BUILD_SINGLE_SELECTOR_TWICE( + target.dataType(), tileKernelHH, + (specialBuffer(), specialShapeInfo(), target.specialBuffer(), + target.specialShapeInfo(), targetLen, ews, stream), + LIBND4J_TYPES); + registerSpecialUse({&target}, {this}); } ////////////////////////////////////////////////////////////////////////// void NDArray::tile(NDArray& target) const { - if(rankOf() > target.rankOf()) - throw std::runtime_error("NDArray::tile method - rank of target array must be bigger or equal to the rank of this array !"); - - if(!ShapeUtils::areShapesBroadcastable(*this, target)) - throw std::runtime_error("NDArray::tile method - shapeInfo of target array is not suitable for tile operation !"); - - // fill newBuff, loop through all elements of newBuff - // looping through getBuffer() goes automatically by means of getSubArrayIndex applying - const auto ews = target.ews(); - const auto targetLen = target.lengthOf(); - auto stream = getContext()->getCudaStream(); - - prepareSpecialUse({&target}, {this}); - BUILD_SINGLE_SELECTOR_TWICE(target.dataType(), tileKernelHH, (specialBuffer(), specialShapeInfo(), target.specialBuffer(), target.specialShapeInfo(), targetLen, ews, stream), LIBND4J_TYPES); - registerSpecialUse({&target}, {this}); + if (rankOf() > target.rankOf()) + throw std::runtime_error( + "NDArray::tile method - rank of target array must be bigger or equal " + "to the rank of this array !"); + + if (!ShapeUtils::areShapesBroadcastable(*this, target)) + throw std::runtime_error( + "NDArray::tile method - shapeInfo of target array is not suitable for " + "tile operation !"); + + // fill newBuff, loop through all elements of newBuff + // looping through getBuffer() goes automatically by means of getSubArrayIndex + // applying + const auto ews = target.ews(); + const auto targetLen = target.lengthOf(); + auto stream = getContext()->getCudaStream(); + + prepareSpecialUse({&target}, {this}); + BUILD_SINGLE_SELECTOR_TWICE( + target.dataType(), tileKernelHH, + (specialBuffer(), specialShapeInfo(), target.specialBuffer(), + target.specialShapeInfo(), targetLen, ews, stream), + LIBND4J_TYPES); + registerSpecialUse({&target}, {this}); } //////////////////////////////////////////////////////////////////////// -template +template __global__ static void repeatCuda(const void* vx, const Nd4jLong* xShapeInfo, - void* vz, const Nd4jLong* zShapeInfo, + void* vz, const Nd4jLong* zShapeInfo, const int* repeats, const int repSize, const int axis) { + const X* x = reinterpret_cast(vx); + Z* z = reinterpret_cast(vz); - const X* x = reinterpret_cast(vx); - Z* z = reinterpret_cast(vz); + __shared__ int rank, *sharedMem; + __shared__ Nd4jLong zLen, totalThreads; // xLen = zLen - __shared__ int rank, *sharedMem; - __shared__ Nd4jLong zLen, totalThreads; // xLen = zLen + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sharedMem = reinterpret_cast(shmem); - if (threadIdx.x == 0) { - - extern __shared__ unsigned char shmem[]; - sharedMem = reinterpret_cast(shmem); - - rank = shape::rank(zShapeInfo); // xRank = zRank - zLen = shape::length(zShapeInfo); // xLen <= zLen - - totalThreads = gridDim.x * blockDim.x; - } + rank = shape::rank(zShapeInfo); // xRank = zRank + zLen = shape::length(zShapeInfo); // xLen <= zLen - __syncthreads(); + totalThreads = gridDim.x * blockDim.x; + } - auto coords = sharedMem + threadIdx.x * rank; + __syncthreads(); - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + auto coords = sharedMem + threadIdx.x * rank; - for (Nd4jLong i = tid; i < zLen; i += totalThreads) { + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - shape::index2coords(i, zShapeInfo, coords); + for (Nd4jLong i = tid; i < zLen; i += totalThreads) { + shape::index2coords(i, zShapeInfo, coords); - const auto zOffset = shape::getOffset(zShapeInfo, coords); + const auto zOffset = shape::getOffset(zShapeInfo, coords); - if(repSize > 1) { - for (uint j = 0; j < repSize; ++j) { - coords[axis] -= repeats[j]; - if (coords[axis] < 0) { - coords[axis] = j; - break; - } - } + if (repSize > 1) { + for (uint j = 0; j < repSize; ++j) { + coords[axis] -= repeats[j]; + if (coords[axis] < 0) { + coords[axis] = j; + break; } - else - coords[axis] /= repeats[0]; + } + } else + coords[axis] /= repeats[0]; - z[zOffset] = x[shape::getOffset(xShapeInfo, coords)]; - } + z[zOffset] = x[shape::getOffset(xShapeInfo, coords)]; + } } ////////////////////////////////////////////////////////////////////////// -template -static void repeatCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, - const void* vx, const Nd4jLong* xShapeInfo, - void* vz, const Nd4jLong* zShapeInfo, - const int* repeats, const int repSize, - const int axis) { - - repeatCuda<<>>(vx, xShapeInfo, vz, zShapeInfo, repeats, repSize, axis); +template +static void repeatCudaLauncher(const int blocksPerGrid, + const int threadsPerBlock, const int sharedMem, + const cudaStream_t* stream, const void* vx, + const Nd4jLong* xShapeInfo, void* vz, + const Nd4jLong* zShapeInfo, const int* repeats, + const int repSize, const int axis) { + repeatCuda<<>>( + vx, xShapeInfo, vz, zShapeInfo, repeats, repSize, axis); } -BUILD_DOUBLE_TEMPLATE(template void repeatCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo, const int* repeats, const int repSize, const int axis), LIBND4J_TYPES, LIBND4J_TYPES); +BUILD_DOUBLE_TEMPLATE(template void repeatCudaLauncher, + (const int blocksPerGrid, const int threadsPerBlock, + const int sharedMem, const cudaStream_t* stream, + const void* vx, const Nd4jLong* xShapeInfo, void* vz, + const Nd4jLong* zShapeInfo, const int* repeats, + const int repSize, const int axis), + LIBND4J_TYPES, LIBND4J_TYPES); ////////////////////////////////////////////////////////////////////////// // create new array by repeating it the number of times given by repeats NDArray NDArray::repeat(const int axis, const std::vector& repeats) const { + NDArray output('c', ShapeUtils::evalRepeatShape(axis, repeats, *this), + dataType(), getContext()); - NDArray output('c', ShapeUtils::evalRepeatShape(axis, repeats, *this), dataType(), getContext()); - - const int threadsPerBlock = MAX_NUM_THREADS / 2; - const int blocksPerGrid = (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - const int sharedMem = output.rankOf() * sizeof(int) * threadsPerBlock + 128; + const int threadsPerBlock = MAX_NUM_THREADS / 2; + const int blocksPerGrid = + (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + const int sharedMem = output.rankOf() * sizeof(int) * threadsPerBlock + 128; - PointersManager manager(getContext(), "NDArray::repeat(const int axis, const std::vector& repeats)"); + PointersManager manager( + getContext(), + "NDArray::repeat(const int axis, const std::vector& repeats)"); - const int* reps = reinterpret_cast(manager.replicatePointer(repeats.data(), repeats.size() * sizeof(int))); + const int* reps = reinterpret_cast( + manager.replicatePointer(repeats.data(), repeats.size() * sizeof(int))); - prepareSpecialUse({&output}, {this}); - BUILD_SINGLE_SELECTOR_TWICE(dataType(), repeatCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, getContext()->getCudaStream(), specialBuffer(), specialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), reps, repeats.size(), axis), LIBND4J_TYPES); - prepareSpecialUse({&output}, {this}); + prepareSpecialUse({&output}, {this}); + BUILD_SINGLE_SELECTOR_TWICE( + dataType(), repeatCudaLauncher, + (blocksPerGrid, threadsPerBlock, sharedMem, getContext()->getCudaStream(), + specialBuffer(), specialShapeInfo(), output.specialBuffer(), + output.specialShapeInfo(), reps, repeats.size(), axis), + LIBND4J_TYPES); + prepareSpecialUse({&output}, {this}); - manager.synchronize(); + manager.synchronize(); - return output; + return output; } ////////////////////////////////////////////////////////////////////////// // fill array by repeating it the number of times given by repeats -void NDArray::repeat(const int axis, const std::vector& repeats, NDArray& target) const { - - if(!target.isSameShape(ShapeUtils::evalRepeatShape(axis, repeats, *this))) - throw std::invalid_argument("NDArray::repeat(const int axis, const std::vector& repeats, NDArray& target) method: wrong shape of target array!"); - - const int threadsPerBlock = MAX_NUM_THREADS / 2; - const int blocksPerGrid = (target.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - const int sharedMem = target.rankOf() * sizeof(int) * threadsPerBlock + 128; - - PointersManager manager(getContext(), "NDArray::repeat(const int axis, const std::vector& repeats)"); - - const int* reps = reinterpret_cast(manager.replicatePointer(repeats.data(), repeats.size() * sizeof(int))); - - prepareSpecialUse({&target}, {this}); - BUILD_DOUBLE_SELECTOR(dataType(), target.dataType(), repeatCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, getContext()->getCudaStream(), specialBuffer(), specialShapeInfo(), target.specialBuffer(), target.specialShapeInfo(), reps, repeats.size(), axis), LIBND4J_TYPES, LIBND4J_TYPES); - prepareSpecialUse({&target}, {this}); - - manager.synchronize(); +void NDArray::repeat(const int axis, const std::vector& repeats, + NDArray& target) const { + if (!target.isSameShape(ShapeUtils::evalRepeatShape(axis, repeats, *this))) + throw std::invalid_argument( + "NDArray::repeat(const int axis, const std::vector& repeats, " + "NDArray& target) method: wrong shape of target array!"); + + const int threadsPerBlock = MAX_NUM_THREADS / 2; + const int blocksPerGrid = + (target.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + const int sharedMem = target.rankOf() * sizeof(int) * threadsPerBlock + 128; + + PointersManager manager( + getContext(), + "NDArray::repeat(const int axis, const std::vector& repeats)"); + + const int* reps = reinterpret_cast( + manager.replicatePointer(repeats.data(), repeats.size() * sizeof(int))); + + prepareSpecialUse({&target}, {this}); + BUILD_DOUBLE_SELECTOR( + dataType(), target.dataType(), repeatCudaLauncher, + (blocksPerGrid, threadsPerBlock, sharedMem, getContext()->getCudaStream(), + specialBuffer(), specialShapeInfo(), target.specialBuffer(), + target.specialShapeInfo(), reps, repeats.size(), axis), + LIBND4J_TYPES, LIBND4J_TYPES); + prepareSpecialUse({&target}, {this}); + + manager.synchronize(); } - //////////////////////////////////////////////////////////////////////// void* NDArray::specialBuffer() { + if (_buffer->special() == nullptr) { + syncToDevice(); + tickReadHost(); + } - if (_buffer->special() == nullptr) { - syncToDevice(); - tickReadHost(); - } - // FIXME: this should be fixed once CUDA backend added - return static_cast(_buffer->special()) + (_offset * sizeOfT()); + return static_cast(_buffer->special()) + (_offset * sizeOfT()); } //////////////////////////////////////////////////////////////////////// void const* NDArray::specialBuffer() const { - if (_buffer->special() == nullptr) { - syncToDevice(); - tickReadHost(); - } - // FIXME: this should be fixed once CUDA backend added - return static_cast(_buffer->special()) + (_offset * sizeOfT()); + if (_buffer->special() == nullptr) { + syncToDevice(); + tickReadHost(); + } + + return static_cast(_buffer->special()) + (_offset * sizeOfT()); } ////////////////////////////////////////////////////////////////////////// -template -void NDArray::printCurrentBuffer(const bool host, const char* msg, const int precision) const { - - if(_length == 0) - { printf("NDArray::printActualBuffer: array length is zero !\n"); return; } - - if(msg) - printf("%s", msg); - - if(host) { - if(buffer() == nullptr || _length == 0) - { printf("NDArray::printActualBuffer: host buffer is nullptr !\n"); return; } +template +void NDArray::printCurrentBuffer(const bool host, const char* msg, + const int precision) const { + if (_length == 0) { + printf("NDArray::printActualBuffer: array length is zero !\n"); + return; + } + + if (msg) printf("%s", msg); + + if (host) { + if (buffer() == nullptr || _length == 0) { + printf("NDArray::printActualBuffer: host buffer is nullptr !\n"); + return; + } - const T* buff = bufferAsT(); - for (uint i = 0; i < _length; i++) - printf("%.*f, ", precision, (double)buff[getOffset(i)]); - printf("\n"); + const T* buff = bufferAsT(); + for (uint i = 0; i < _length; i++) + printf("%.*f, ", precision, (double)buff[getOffset(i)]); + printf("\n"); + } else { + if (specialBuffer() == nullptr || _length == 0) { + printf("NDArray::printSpecialBuffer: special buffer is nullptr !\n"); + return; } - else { - if(specialBuffer() == nullptr || _length == 0) - { printf("NDArray::printSpecialBuffer: special buffer is nullptr !\n"); return; } - const auto sizeOfBuffer = sizeOfT() * (getOffset(_length - 1) + 1); + const auto sizeOfBuffer = sizeOfT() * (getOffset(_length - 1) + 1); - void* pHost = operator new(sizeOfBuffer); + void* pHost = operator new(sizeOfBuffer); - cudaMemcpyAsync(pHost, specialBuffer(), sizeOfBuffer, cudaMemcpyDeviceToHost, *getContext()->getCudaStream()); + cudaMemcpyAsync(pHost, specialBuffer(), sizeOfBuffer, + cudaMemcpyDeviceToHost, *getContext()->getCudaStream()); - cudaError_t cudaResult = cudaStreamSynchronize(*getContext()->getCudaStream()); - if(cudaResult != 0) - throw std::runtime_error("NDArray::printSpecialBuffer: cudaStreamSynchronize failed!"); + cudaError_t cudaResult = + cudaStreamSynchronize(*getContext()->getCudaStream()); + if (cudaResult != 0) + throw std::runtime_error( + "NDArray::printSpecialBuffer: cudaStreamSynchronize failed!"); - for (uint i = 0; i < _length; i++) - printf("%.*f, ", precision, (double)reinterpret_cast(pHost)[getOffset(i)]); - printf("\n"); + for (uint i = 0; i < _length; i++) + printf("%.*f, ", precision, (double)reinterpret_cast(pHost)[getOffset(i)]); + printf("\n"); - operator delete(pHost); - } + operator delete(pHost); + } } -template void NDArray::printCurrentBuffer(const bool host,const char* msg, const int precision) const; -template void NDArray::printCurrentBuffer(const bool host, const char* msg, const int precision) const; -template void NDArray::printCurrentBuffer(const bool host, const char* msg, const int precision) const; - +template void NDArray::printCurrentBuffer(const bool host, const char* msg, + const int precision) const; +template void NDArray::printCurrentBuffer(const bool host, + const char* msg, + const int precision) const; +template void NDArray::printCurrentBuffer(const bool host, + const char* msg, + const int precision) const; #if defined(__CUDACC__) && !defined(BUILD_TESTS) @@ -580,6 +691,5 @@ template void NDArray::printCurrentBuffer(const bool host, const char* m #endif -} // end namespace sd +} // end namespace sd #endif - diff --git a/libnd4j/include/array/impl/ByteOrderUtils.cpp b/libnd4j/include/array/impl/ByteOrderUtils.cpp index 0220ccac807b..5ce7f0f44988 100644 --- a/libnd4j/include/array/impl/ByteOrderUtils.cpp +++ b/libnd4j/include/array/impl/ByteOrderUtils.cpp @@ -20,9 +20,8 @@ #include - namespace sd { - ByteOrder ByteOrderUtils::fromFlatByteOrder(sd::graph::ByteOrder order) { - return (ByteOrder) order; - } -} \ No newline at end of file +ByteOrder ByteOrderUtils::fromFlatByteOrder(sd::graph::ByteOrder order) { + return (ByteOrder)order; +} +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/array/impl/ConstantDataBuffer.cpp b/libnd4j/include/array/impl/ConstantDataBuffer.cpp index 2aeda3b6d63d..50ed3ba60516 100644 --- a/libnd4j/include/array/impl/ConstantDataBuffer.cpp +++ b/libnd4j/include/array/impl/ConstantDataBuffer.cpp @@ -22,8 +22,7 @@ #include namespace sd { -ConstantDataBuffer::ConstantDataBuffer( - const std::shared_ptr& primary, +ConstantDataBuffer::ConstantDataBuffer(const std::shared_ptr& primary, uint64_t numEelements, DataType dtype) : ConstantDataBuffer(primary, {}, numEelements, dtype) { // @@ -32,50 +31,44 @@ ConstantDataBuffer::ConstantDataBuffer( ConstantDataBuffer::ConstantDataBuffer( const std::shared_ptr& primary, const std::shared_ptr& special, - uint64_t numEelements, - DataType dtype) : _primaryBuffer(primary), _specialBuffer(special), _length(numEelements) { - _sizeOf = DataTypeUtils::sizeOf(dtype); - } - - void* ConstantDataBuffer::primary() const { - return _primaryBuffer->pointer(); - } + uint64_t numEelements, + DataType dtype) : _primaryBuffer(primary), + _specialBuffer (special), + _length (numEelements) { + _sizeOf = DataTypeUtils::sizeOf(dtype); +} - void* ConstantDataBuffer::special() const { - return _specialBuffer ? _specialBuffer->pointer() : nullptr; - } +void* ConstantDataBuffer::primary() const { return _primaryBuffer->pointer(); } - uint8_t ConstantDataBuffer::sizeOf() const { - return _sizeOf; - } +void* ConstantDataBuffer::special() const { return _specialBuffer? _specialBuffer->pointer() : nullptr; } - uint64_t ConstantDataBuffer::length() const { - return _length; - } +uint8_t ConstantDataBuffer::sizeOf() const { return _sizeOf; } - ConstantDataBuffer::ConstantDataBuffer(const ConstantDataBuffer &other) { - _primaryBuffer = other._primaryBuffer; - _specialBuffer = other._specialBuffer; - _length = other._length; - _sizeOf = other._sizeOf; - } +uint64_t ConstantDataBuffer::length() const { return _length; } - template - T* ConstantDataBuffer::primaryAsT() const { - return reinterpret_cast(_primaryBuffer->pointer()); - } - template ND4J_EXPORT float* ConstantDataBuffer::primaryAsT() const; - template ND4J_EXPORT double* ConstantDataBuffer::primaryAsT() const; - template ND4J_EXPORT int* ConstantDataBuffer::primaryAsT() const; - template ND4J_EXPORT Nd4jLong* ConstantDataBuffer::primaryAsT() const; +ConstantDataBuffer::ConstantDataBuffer(const ConstantDataBuffer& other) { + _primaryBuffer = other._primaryBuffer; + _specialBuffer = other._specialBuffer; + _length = other._length; + _sizeOf = other._sizeOf; +} - template - T* ConstantDataBuffer::specialAsT() const { - return reinterpret_cast(special()); - } - template ND4J_EXPORT float* ConstantDataBuffer::specialAsT() const; - template ND4J_EXPORT double* ConstantDataBuffer::specialAsT() const; - template ND4J_EXPORT int* ConstantDataBuffer::specialAsT() const; - template ND4J_EXPORT Nd4jLong* ConstantDataBuffer::specialAsT() const; +template +T* ConstantDataBuffer::primaryAsT() const { + return reinterpret_cast(_primaryBuffer->pointer()); +} +template SD_EXPORT float* ConstantDataBuffer::primaryAsT() const; +template SD_EXPORT double* ConstantDataBuffer::primaryAsT() const; +template SD_EXPORT int* ConstantDataBuffer::primaryAsT() const; +template SD_EXPORT Nd4jLong* ConstantDataBuffer::primaryAsT() const; +template +T* ConstantDataBuffer::specialAsT() const { + return reinterpret_cast(special()); } +template SD_EXPORT float* ConstantDataBuffer::specialAsT() const; +template SD_EXPORT double* ConstantDataBuffer::specialAsT() const; +template SD_EXPORT int* ConstantDataBuffer::specialAsT() const; +template SD_EXPORT Nd4jLong* ConstantDataBuffer::specialAsT() const; + +} // namespace sd diff --git a/libnd4j/include/array/impl/ConstantDescriptor.cpp b/libnd4j/include/array/impl/ConstantDescriptor.cpp index 829ac5b343f4..1e2b22f5c22d 100644 --- a/libnd4j/include/array/impl/ConstantDescriptor.cpp +++ b/libnd4j/include/array/impl/ConstantDescriptor.cpp @@ -20,80 +20,80 @@ #include #include + #include namespace sd { - ConstantDescriptor::ConstantDescriptor(double* values, int length) { - for (int e = 0; e < length; e++) - _floatValues.emplace_back(values[e]); - } +ConstantDescriptor::ConstantDescriptor(double *values, int length) { + for (int e = 0; e < length; e++) _floatValues.emplace_back(values[e]); +} - ConstantDescriptor::ConstantDescriptor(Nd4jLong const* values, int length) { - for (int e = 0; e < length; e++) - _integerValues.emplace_back(values[e]); - } +ConstantDescriptor::ConstantDescriptor(Nd4jLong const *values, int length) { + for (int e = 0; e < length; e++) _integerValues.emplace_back(values[e]); +} - ConstantDescriptor::ConstantDescriptor(std::initializer_list values) { - _floatValues = values; - } +ConstantDescriptor::ConstantDescriptor(std::initializer_list values) { + _floatValues = values; +} - ConstantDescriptor::ConstantDescriptor(std::vector &values) { - _integerValues = values; - } +ConstantDescriptor::ConstantDescriptor(std::vector &values) { + _integerValues = values; +} - ConstantDescriptor::ConstantDescriptor(std::vector &values) { - _floatValues = values; - } +ConstantDescriptor::ConstantDescriptor(std::vector &values) { + _floatValues = values; +} - // equal to operator - bool ConstantDescriptor::operator==(const ConstantDescriptor &other) const { - return std::tie(_floatValues, _integerValues) == std::tie(other._floatValues, other._integerValues); - } +// equal to operator +bool ConstantDescriptor::operator==(const ConstantDescriptor &other) const { + return std::tie(_floatValues, _integerValues) == + std::tie(other._floatValues, other._integerValues); +} - // less than operator - bool ConstantDescriptor::operator<(const ConstantDescriptor &other) const { - return std::tie(_floatValues, _integerValues) < std::tie(other._floatValues, other._integerValues); - } +// less than operator +bool ConstantDescriptor::operator<(const ConstantDescriptor &other) const { + return std::tie(_floatValues, _integerValues) < + std::tie(other._floatValues, other._integerValues); +} - bool ConstantDescriptor::isInteger() const { - return !_integerValues.empty(); - } +bool ConstantDescriptor::isInteger() const { return !_integerValues.empty(); } - bool ConstantDescriptor::isFloat() const { - return !_floatValues.empty(); - } +bool ConstantDescriptor::isFloat() const { return !_floatValues.empty(); } - const std::vector& ConstantDescriptor::integerValues() const { - return _integerValues; - } +const std::vector &ConstantDescriptor::integerValues() const { + return _integerValues; +} - const std::vector& ConstantDescriptor::floatValues() const { - return _floatValues; - } +const std::vector &ConstantDescriptor::floatValues() const { + return _floatValues; +} - Nd4jLong ConstantDescriptor::length() const { - return isInteger() ? _integerValues.size() : isFloat() ? _floatValues.size() : 0L; - } +Nd4jLong ConstantDescriptor::length() const { + return isInteger() ? _integerValues.size() + : isFloat() ? _floatValues.size() : 0L; } +} // namespace sd namespace std { - size_t hash::operator()(const sd::ConstantDescriptor &k) const { - using std::hash; - // Compute individual hash values for first, - // second and third and combine them using XOR - // and bit shifting: - size_t hashVal = 0; - size_t i = 0; - if (k.isInteger()) { - for (auto v: k.integerValues()) { - hashVal ^= std::hash()(v) + 0x9e3779b9 + (hashVal << 6) + (hashVal >> 2); - } - } - else { - for (auto v: k.floatValues()) { - hashVal ^= std::hash()(v) + 0x9e3779b9 + (hashVal << 6) + (hashVal >> 2); - } - } - return hashVal; +size_t hash::operator()( + const sd::ConstantDescriptor &k) const { + using std::hash; + // Compute individual hash values for first, + // second and third and combine them using XOR + // and bit shifting: + size_t hashVal = 0; + size_t i = 0; + if (k.isInteger()) { + for (auto v : k.integerValues()) { + hashVal ^= std::hash()(v) + 0x9e3779b9 + (hashVal << 6) + + (hashVal >> 2); + } + } else { + for (auto v : k.floatValues()) { + hashVal ^= + std::hash()(v) + 0x9e3779b9 + (hashVal << 6) + (hashVal >> 2); } + } + return hashVal; } +} // namespace std diff --git a/libnd4j/include/array/impl/ConstantHolder.cpp b/libnd4j/include/array/impl/ConstantHolder.cpp index 08637862c5e3..35c45335baf5 100644 --- a/libnd4j/include/array/impl/ConstantHolder.cpp +++ b/libnd4j/include/array/impl/ConstantHolder.cpp @@ -1,5 +1,5 @@ /** -* Copyright (c) 2019 Konduit K.K. + * Copyright (c) 2019 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -18,50 +18,72 @@ // Created by raver on 5/17/2019. // -#include #include +#include #include namespace sd { - ConstantHolder::ConstantHolder(const ConstantHolder& other) { - _buffers = other._buffers; - _deviceId = other._deviceId; - } - - bool ConstantHolder::hasBuffer(sd::DataType dataType) { - return _buffers.count(dataType) > 0; - } - - std::mutex* ConstantHolder::mutex() { - return &_mutex; - } - - template - bool ConstantHolder::hasBuffer() { - return hasBuffer(DataTypeUtils::fromT()); - } - BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT bool ConstantHolder::hasBuffer, (void), LIBND4J_TYPES); - - void ConstantHolder::addBuffer(ConstantDataBuffer &pointer, sd::DataType dataType) { - _buffers[dataType] = pointer; - } - - template - void ConstantHolder::addBuffer(ConstantDataBuffer &pointer) { - addBuffer(pointer, DataTypeUtils::fromT()); - } - BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT void ConstantHolder::addBuffer, (ConstantDataBuffer& cb), LIBND4J_TYPES); - - ConstantDataBuffer* ConstantHolder::getConstantDataBuffer(sd::DataType dataType) { - if (!hasBuffer(dataType)) - throw std::runtime_error("Requested dataType is absent in storage"); - - return &_buffers[dataType]; - } - - template - ConstantDataBuffer* ConstantHolder::getConstantDataBuffer() { - return getConstantDataBuffer(DataTypeUtils::fromT()); - } - BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT ConstantDataBuffer* ConstantHolder::getConstantDataBuffer, (), LIBND4J_TYPES); -} \ No newline at end of file +ConstantHolder::ConstantHolder(const ConstantHolder& other) { + _buffers = other._buffers; + _deviceId = other._deviceId; +} + +ConstantHolder &ConstantHolder::operator=(const ConstantHolder &other) { + if (this == &other) return *this; + + _buffers = other._buffers; + _deviceId = other._deviceId; + + return *this; +} + +ConstantHolder &ConstantHolder::operator=(ConstantHolder &&other) { + if (this == &other) return *this; + + _buffers = std::move(other._buffers); + _deviceId = other._deviceId; + + return *this; +} + +bool ConstantHolder::hasBuffer(sd::DataType dataType) { + return _buffers.count(dataType) > 0; +} + +std::mutex* ConstantHolder::mutex() { return &_mutex; } + +template +bool ConstantHolder::hasBuffer() { + return hasBuffer(DataTypeUtils::fromT()); +} +BUILD_SINGLE_TEMPLATE(template SD_EXPORT bool ConstantHolder::hasBuffer, (void), + LIBND4J_TYPES); + +void ConstantHolder::addBuffer(ConstantDataBuffer& pointer, + sd::DataType dataType) { + _buffers[dataType] = pointer; +} + +template +void ConstantHolder::addBuffer(ConstantDataBuffer& pointer) { + addBuffer(pointer, DataTypeUtils::fromT()); +} +BUILD_SINGLE_TEMPLATE(template SD_EXPORT void ConstantHolder::addBuffer, + (ConstantDataBuffer & cb), LIBND4J_TYPES); + +ConstantDataBuffer* ConstantHolder::getConstantDataBuffer( + sd::DataType dataType) { + if (!hasBuffer(dataType)) + throw std::runtime_error("Requested dataType is absent in storage"); + + return &_buffers[dataType]; +} + +template +ConstantDataBuffer* ConstantHolder::getConstantDataBuffer() { + return getConstantDataBuffer(DataTypeUtils::fromT()); +} +BUILD_SINGLE_TEMPLATE(template SD_EXPORT ConstantDataBuffer* + ConstantHolder::getConstantDataBuffer, + (), LIBND4J_TYPES); +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/array/impl/DataBuffer.cpp b/libnd4j/include/array/impl/DataBuffer.cpp index 89c386c3d8fa..4ef686104c82 100644 --- a/libnd4j/include/array/impl/DataBuffer.cpp +++ b/libnd4j/include/array/impl/DataBuffer.cpp @@ -20,319 +20,308 @@ // #include -#include #include +#include #include +#include #include -#include namespace sd { - ///// IMLEMENTATION OF COMMON METHODS ///// - +///// IMLEMENTATION OF COMMON METHODS ///// //////////////////////////////////////////////////////////////////////// // default constructor - DataBuffer::DataBuffer() { - - _primaryBuffer = nullptr; - _specialBuffer = nullptr; - _lenInBytes = 0; - _dataType = INT8; - _workspace = nullptr; - _isOwnerPrimary = false; - _isOwnerSpecial = false; - _deviceId = sd::AffinityManager::currentDeviceId(); - - setCountersToZero(); - } +DataBuffer::DataBuffer() { + _primaryBuffer = nullptr; + _specialBuffer = nullptr; + _lenInBytes = 0; + _dataType = INT8; + _workspace = nullptr; + _isOwnerPrimary = false; + _isOwnerSpecial = false; + _deviceId = sd::AffinityManager::currentDeviceId(); + + setCountersToZero(); +} //////////////////////////////////////////////////////////////////////// // copy constructor - DataBuffer::DataBuffer(const DataBuffer &other) { +DataBuffer::DataBuffer(const DataBuffer& other) { + throw std::runtime_error( + "DataBuffer copy constructor: we don't expect using of this " + "constructor!"); - throw std::runtime_error("DataBuffer copy constructor: we don't expect using of this constructor!"); + _lenInBytes = other._lenInBytes; + _dataType = other._dataType; + _workspace = other._workspace; - _lenInBytes = other._lenInBytes; - _dataType = other._dataType; - _workspace = other._workspace; + _primaryBuffer = nullptr; + _specialBuffer = nullptr; - _primaryBuffer = nullptr; - _specialBuffer = nullptr; + _deviceId.store(other._deviceId.load()); - _deviceId.store(other._deviceId.load()); + setCountersToZero(); - setCountersToZero(); - - allocateBuffers(); - copyBufferFrom(other); - } + allocateBuffers(); + copyBufferFrom(other); +} //////////////////////////////////////////////////////////////////////// - DataBuffer::DataBuffer(void* primary, void* special, - const size_t lenInBytes, const DataType dataType, - const bool isOwnerPrimary, const bool isOwnerSpecial, - memory::Workspace* workspace) { - - if (primary == nullptr && special == nullptr) - throw std::runtime_error("DataBuffer constructor: can't be initialized with both nullptr buffers !"); - - _primaryBuffer = primary; - _specialBuffer = special; - _lenInBytes = lenInBytes; - _dataType = dataType; - _workspace = workspace; - _isOwnerPrimary = isOwnerPrimary; - _isOwnerSpecial = isOwnerSpecial; - _deviceId = sd::AffinityManager::currentDeviceId(); - - setCountersToZero(); - - if(primary != nullptr) - readPrimary(); - if(special != nullptr) - readSpecial(); - } +DataBuffer::DataBuffer(void* primary, void* special, const size_t lenInBytes, + const DataType dataType, const bool isOwnerPrimary, + const bool isOwnerSpecial, + memory::Workspace* workspace) { + if (primary == nullptr && special == nullptr) + throw std::runtime_error( + "DataBuffer constructor: can't be initialized with both nullptr " + "buffers !"); + + _primaryBuffer = primary; + _specialBuffer = special; + _lenInBytes = lenInBytes; + _dataType = dataType; + _workspace = workspace; + _isOwnerPrimary = isOwnerPrimary; + _isOwnerSpecial = isOwnerSpecial; + _deviceId = sd::AffinityManager::currentDeviceId(); + + setCountersToZero(); + + if (primary != nullptr) readPrimary(); + if (special != nullptr) readSpecial(); +} //////////////////////////////////////////////////////////////////////// - DataBuffer::DataBuffer(void* primary, const size_t lenInBytes, const DataType dataType, const bool isOwnerPrimary, memory::Workspace* workspace): - DataBuffer(primary, nullptr, lenInBytes, dataType, isOwnerPrimary, false, workspace) { - - syncToSpecial(true); - } +DataBuffer::DataBuffer(void* primary, const size_t lenInBytes, + const DataType dataType, const bool isOwnerPrimary, + memory::Workspace* workspace) + : DataBuffer(primary, nullptr, lenInBytes, dataType, isOwnerPrimary, false, + workspace) { + syncToSpecial(true); +} //////////////////////////////////////////////////////////////////////// // copies data from hostBuffer to own memory buffer - DataBuffer::DataBuffer(const void* hostBuffer, const DataType dataType, const size_t lenInBytes, memory::Workspace* workspace) { - - if (hostBuffer == nullptr) - throw std::runtime_error("DataBuffer constructor: can't be initialized with nullptr host buffer !"); - if (lenInBytes == 0) - throw std::runtime_error("DataBuffer constructor: can't be initialized with zero length !"); +DataBuffer::DataBuffer(const void* hostBuffer, const DataType dataType, + const size_t lenInBytes, memory::Workspace* workspace) { + if (hostBuffer == nullptr) + throw std::runtime_error( + "DataBuffer constructor: can't be initialized with nullptr host buffer " + "!"); + if (lenInBytes == 0) + throw std::runtime_error( + "DataBuffer constructor: can't be initialized with zero length !"); - _primaryBuffer = nullptr; - _specialBuffer = nullptr; - _lenInBytes = lenInBytes; - _dataType = dataType; - _workspace = workspace; + _primaryBuffer = nullptr; + _specialBuffer = nullptr; + _lenInBytes = lenInBytes; + _dataType = dataType; + _workspace = workspace; - _deviceId = sd::AffinityManager::currentDeviceId(); + _deviceId = sd::AffinityManager::currentDeviceId(); - setCountersToZero(); + setCountersToZero(); - allocateBuffers(); + allocateBuffers(); - copyBufferFromHost(hostBuffer, lenInBytes); - } + copyBufferFromHost(hostBuffer, lenInBytes); +} //////////////////////////////////////////////////////////////////////// - DataBuffer::DataBuffer(const size_t lenInBytes, const DataType dataType, memory::Workspace* workspace, const bool allocBoth) { +DataBuffer::DataBuffer(const size_t lenInBytes, const DataType dataType, + memory::Workspace* workspace, const bool allocBoth) { + _dataType = dataType; + _workspace = workspace; + _lenInBytes = lenInBytes; - _dataType = dataType; - _workspace = workspace; - _lenInBytes = lenInBytes; + _primaryBuffer = nullptr; + _specialBuffer = nullptr; - _primaryBuffer = nullptr; - _specialBuffer = nullptr; + _deviceId = sd::AffinityManager::currentDeviceId(); - _deviceId = sd::AffinityManager::currentDeviceId(); + setCountersToZero(); - setCountersToZero(); - - if(lenInBytes != 0) { - allocateBuffers(allocBoth); - writeSpecial(); - } - } + if (lenInBytes != 0) { + allocateBuffers(allocBoth); + writeSpecial(); + } +} //////////////////////////////////////////////////////////////////////// // move constructor - DataBuffer::DataBuffer(DataBuffer&& other) { - - _primaryBuffer = other._primaryBuffer; - _specialBuffer = other._specialBuffer; - _lenInBytes = other._lenInBytes; - _dataType = other._dataType; - _workspace = other._workspace; - _isOwnerPrimary = other._isOwnerPrimary; - _isOwnerSpecial = other._isOwnerSpecial; - _deviceId.store(other._deviceId); - - copyCounters(other); - - other._primaryBuffer = other._specialBuffer = nullptr; - other.setAllocFlags(false, false); - other._lenInBytes = 0; - } +DataBuffer::DataBuffer(DataBuffer&& other) { + _primaryBuffer = other._primaryBuffer; + _specialBuffer = other._specialBuffer; + _lenInBytes = other._lenInBytes; + _dataType = other._dataType; + _workspace = other._workspace; + _isOwnerPrimary = other._isOwnerPrimary; + _isOwnerSpecial = other._isOwnerSpecial; + _deviceId.store(other._deviceId); + + copyCounters(other); + + other._primaryBuffer = other._specialBuffer = nullptr; + other.setAllocFlags(false, false); + other._lenInBytes = 0; +} //////////////////////////////////////////////////////////////////////// // assignment operator - DataBuffer& DataBuffer::operator=(const DataBuffer& other) { +DataBuffer& DataBuffer::operator=(const DataBuffer& other) { + if (this == &other) return *this; - if (this == &other) - return *this; + deleteBuffers(); - deleteBuffers(); + _lenInBytes = other._lenInBytes; + _dataType = other._dataType; + _workspace = other._workspace; - _lenInBytes = other._lenInBytes; - _dataType = other._dataType; - _workspace = other._workspace; + allocateBuffers(); + copyBufferFrom(other); - allocateBuffers(); - copyBufferFrom(other); - - return *this; - } + return *this; +} //////////////////////////////////////////////////////////////////////// // move assignment operator - DataBuffer& DataBuffer::operator=(DataBuffer&& other) noexcept { - - if (this == &other) - return *this; - - deleteBuffers(); +DataBuffer& DataBuffer::operator=(DataBuffer&& other) noexcept { + if (this == &other) return *this; - _primaryBuffer = other._primaryBuffer; - _specialBuffer = other._specialBuffer; - _lenInBytes = other._lenInBytes; - _dataType = other._dataType; - _workspace = other._workspace; - _isOwnerPrimary = other._isOwnerPrimary; - _isOwnerSpecial = other._isOwnerSpecial; + deleteBuffers(); - copyCounters(other); + _primaryBuffer = other._primaryBuffer; + _specialBuffer = other._specialBuffer; + _lenInBytes = other._lenInBytes; + _dataType = other._dataType; + _workspace = other._workspace; + _isOwnerPrimary = other._isOwnerPrimary; + _isOwnerSpecial = other._isOwnerSpecial; - other._primaryBuffer = other._specialBuffer = nullptr; - other.setAllocFlags(false, false); - other._lenInBytes = 0; + copyCounters(other); - return *this; - } + other._primaryBuffer = other._specialBuffer = nullptr; + other.setAllocFlags(false, false); + other._lenInBytes = 0; -//////////////////////////////////////////////////////////////////////// - void* DataBuffer::primary() { - return _primaryBuffer; - } + return *this; +} //////////////////////////////////////////////////////////////////////// - void* DataBuffer::special() { - return _specialBuffer; - } +void* DataBuffer::primary() { return _primaryBuffer; } //////////////////////////////////////////////////////////////////////// - DataType DataBuffer::getDataType() { - return _dataType; - } +void* DataBuffer::special() { return _specialBuffer; } //////////////////////////////////////////////////////////////////////// - size_t DataBuffer::getLenInBytes() const { - return _lenInBytes; - } - +DataType DataBuffer::getDataType() { return _dataType; } //////////////////////////////////////////////////////////////////////// - void DataBuffer::allocatePrimary() { - - if (_primaryBuffer == nullptr && getLenInBytes() > 0) { - auto deviceId = sd::AffinityManager::currentDeviceId(); - // check if this allocation won't bring us above limit - if (_workspace == nullptr) { - if (Environment::getInstance().isCPU()) { - // on cpu backend we validate against device 0 for now - if (!sd::memory::MemoryCounter::getInstance().validate(getLenInBytes())) - throw sd::allocation_exception::build("Requested amount exceeds HOST device limits", sd::memory::MemoryCounter::getInstance().deviceLimit(deviceId), getLenInBytes()); - } else { - // in heterogenous mode we valdate against device group - if (!sd::memory::MemoryCounter::getInstance().validateGroup(sd::memory::MemoryType::HOST, getLenInBytes())) - throw sd::allocation_exception::build("Requested amount exceeds HOST group limits", sd::memory::MemoryCounter::getInstance().groupLimit(sd::memory::MemoryType::HOST), getLenInBytes()); - } - } - - ALLOCATE(_primaryBuffer, _workspace, getLenInBytes(), int8_t); - _isOwnerPrimary = true; - - // count in towards current deviceId if we're not in workspace mode - if (_workspace == nullptr) { - if (Environment::getInstance().isCPU()) // we don't want this counter to be added to CUDA device - sd::memory::MemoryCounter::getInstance().countIn(deviceId, getLenInBytes()); - - sd::memory::MemoryCounter::getInstance().countIn(sd::memory::MemoryType::HOST, getLenInBytes()); - } - } - } +size_t DataBuffer::getLenInBytes() const { return _lenInBytes; } //////////////////////////////////////////////////////////////////////// - void DataBuffer::setAllocFlags(const bool isOwnerPrimary, const bool isOwnerSpecial) { - _isOwnerPrimary = isOwnerPrimary; - _isOwnerSpecial = isOwnerSpecial; +void DataBuffer::allocatePrimary() { + if (_primaryBuffer == nullptr && getLenInBytes() > 0) { + auto deviceId = sd::AffinityManager::currentDeviceId(); + // check if this allocation won't bring us above limit + if (_workspace == nullptr) { + if (Environment::getInstance().isCPU()) { + // on cpu backend we validate against device 0 for now + if (!sd::memory::MemoryCounter::getInstance().validate( + getLenInBytes())) + throw sd::allocation_exception::build( + "Requested amount exceeds HOST device limits", + sd::memory::MemoryCounter::getInstance().deviceLimit(deviceId), + getLenInBytes()); + } else { + // in heterogenous mode we valdate against device group + if (!sd::memory::MemoryCounter::getInstance().validateGroup( + sd::memory::MemoryType::HOST, getLenInBytes())) + throw sd::allocation_exception::build( + "Requested amount exceeds HOST group limits", + sd::memory::MemoryCounter::getInstance().groupLimit( + sd::memory::MemoryType::HOST), + getLenInBytes()); + } } -//////////////////////////////////////////////////////////////////////// - void DataBuffer::deletePrimary() { - - if(_isOwnerPrimary && _primaryBuffer != nullptr && getLenInBytes() != 0) { - auto p = reinterpret_cast(_primaryBuffer); - RELEASE(p, _workspace); - _primaryBuffer = nullptr; - _isOwnerPrimary = false; + ALLOCATE(_primaryBuffer, _workspace, getLenInBytes(), int8_t); + _isOwnerPrimary = true; + // count in towards current deviceId if we're not in workspace mode + if (_workspace == nullptr) { + if (Environment::getInstance().isCPU()) // we don't want this counter to + // be added to CUDA device + sd::memory::MemoryCounter::getInstance().countIn(deviceId, + getLenInBytes()); - // count out towards DataBuffer device, only if we're not in workspace - if (_workspace == nullptr) { - if (Environment::getInstance().isCPU()) - sd::memory::MemoryCounter::getInstance().countOut(_deviceId, getLenInBytes()); - - sd::memory::MemoryCounter::getInstance().countOut(sd::memory::MemoryType::HOST, getLenInBytes()); - } - } + sd::memory::MemoryCounter::getInstance().countIn( + sd::memory::MemoryType::HOST, getLenInBytes()); } + } +} //////////////////////////////////////////////////////////////////////// - void DataBuffer::deleteBuffers() { +void DataBuffer::setAllocFlags(const bool isOwnerPrimary, + const bool isOwnerSpecial) { + _isOwnerPrimary = isOwnerPrimary; + _isOwnerSpecial = isOwnerSpecial; +} - deletePrimary(); - deleteSpecial(); - _lenInBytes = 0; +//////////////////////////////////////////////////////////////////////// +void DataBuffer::deletePrimary() { + if (_isOwnerPrimary && _primaryBuffer != nullptr && getLenInBytes() != 0) { + auto p = reinterpret_cast(_primaryBuffer); + RELEASE(p, _workspace); + _primaryBuffer = nullptr; + _isOwnerPrimary = false; + + // count out towards DataBuffer device, only if we're not in workspace + if (_workspace == nullptr) { + if (Environment::getInstance().isCPU()) + sd::memory::MemoryCounter::getInstance().countOut(_deviceId, + getLenInBytes()); + + sd::memory::MemoryCounter::getInstance().countOut( + sd::memory::MemoryType::HOST, getLenInBytes()); } + } +} //////////////////////////////////////////////////////////////////////// - DataBuffer::~DataBuffer() { +void DataBuffer::deleteBuffers() { + deletePrimary(); + deleteSpecial(); + _lenInBytes = 0; +} - deleteBuffers(); - } +//////////////////////////////////////////////////////////////////////// +DataBuffer::~DataBuffer() { deleteBuffers(); } - void DataBuffer::setPrimaryBuffer(void *buffer, size_t length) { - if (_primaryBuffer != nullptr && _isOwnerPrimary) { - deletePrimary(); - } +void DataBuffer::setPrimaryBuffer(void* buffer, size_t length) { + if (_primaryBuffer != nullptr && _isOwnerPrimary) { + deletePrimary(); + } - _primaryBuffer = buffer; - _isOwnerPrimary = false; - _lenInBytes = length * DataTypeUtils::sizeOf(_dataType); - } + _primaryBuffer = buffer; + _isOwnerPrimary = false; + _lenInBytes = length * DataTypeUtils::sizeOf(_dataType); +} - void DataBuffer::setSpecialBuffer(void *buffer, size_t length) { - if (_specialBuffer != nullptr && _isOwnerSpecial) { - deleteSpecial(); - } +void DataBuffer::setSpecialBuffer(void* buffer, size_t length) { + if (_specialBuffer != nullptr && _isOwnerSpecial) { + deleteSpecial(); + } - this->setSpecial(buffer, false); - _lenInBytes = length * DataTypeUtils::sizeOf(_dataType); - } + this->setSpecial(buffer, false); + _lenInBytes = length * DataTypeUtils::sizeOf(_dataType); +} - void DataBuffer::setDataType(DataType dataType) { - _dataType = dataType; - } +void DataBuffer::setDataType(DataType dataType) { _dataType = dataType; } - int DataBuffer::deviceId() const { - return _deviceId.load(); - } +int DataBuffer::deviceId() const { return _deviceId.load(); } - void DataBuffer::close() { - this->deleteBuffers(); - } +void DataBuffer::close() { this->deleteBuffers(); } - void DataBuffer::setDeviceId(int deviceId) { - _deviceId = deviceId; - } -} +void DataBuffer::setDeviceId(int deviceId) { _deviceId = deviceId; } +} // namespace sd diff --git a/libnd4j/include/array/impl/DataTypeUtils.cpp b/libnd4j/include/array/impl/DataTypeUtils.cpp index 481fa4149716..9bfe90eff901 100644 --- a/libnd4j/include/array/impl/DataTypeUtils.cpp +++ b/libnd4j/include/array/impl/DataTypeUtils.cpp @@ -23,15 +23,11 @@ #include namespace sd { - DataType DataTypeUtils::fromInt(int val) { - return (DataType) val; - } +DataType DataTypeUtils::fromInt(int val) { return (DataType)val; } - DataType DataTypeUtils::fromFlatDataType(sd::graph::DType dtype) { - return (DataType) dtype; - } +DataType DataTypeUtils::fromFlatDataType(sd::graph::DType dtype) { + return (DataType)dtype; +} - int DataTypeUtils::asInt(DataType type) { - return (int) type; - } -} \ No newline at end of file +int DataTypeUtils::asInt(DataType type) { return (int)type; } +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/array/impl/ExtraArguments.cpp b/libnd4j/include/array/impl/ExtraArguments.cpp index 084f327cc290..41ac67d8beb0 100644 --- a/libnd4j/include/array/impl/ExtraArguments.cpp +++ b/libnd4j/include/array/impl/ExtraArguments.cpp @@ -18,123 +18,124 @@ // @author raver119@gmail.com // -#include #include #include -#include +#include #include +#include + #ifdef __CUDABLAS__ #include #include #endif namespace sd { - ExtraArguments::ExtraArguments(std::initializer_list arguments) { - _fpArgs = arguments; - } +ExtraArguments::ExtraArguments(std::initializer_list arguments) { + _fpArgs = arguments; +} - ExtraArguments::ExtraArguments(std::initializer_list arguments) { - _intArgs = arguments; - } +ExtraArguments::ExtraArguments(std::initializer_list arguments) { + _intArgs = arguments; +} - ExtraArguments::ExtraArguments(const std::vector &arguments) { - _fpArgs = arguments; - } +ExtraArguments::ExtraArguments(const std::vector &arguments) { + _fpArgs = arguments; +} - ExtraArguments::ExtraArguments(const std::vector &arguments) { - _intArgs = arguments; - } +ExtraArguments::ExtraArguments(const std::vector &arguments) { + _intArgs = arguments; +} - ExtraArguments::ExtraArguments(const std::vector &arguments) { - for (const auto &v:arguments) - _intArgs.emplace_back(static_cast(v)); - } +ExtraArguments::ExtraArguments(const std::vector &arguments) { + for (const auto &v : arguments) + _intArgs.emplace_back(static_cast(v)); +} - ExtraArguments::ExtraArguments() { - // no-op - } +ExtraArguments::ExtraArguments() { + // no-op +} - ExtraArguments::~ExtraArguments() { - for (auto p:_pointers) { +ExtraArguments::~ExtraArguments() { + for (auto p : _pointers) { #ifdef __CUDABLAS__ - cudaFree(p); -#else // CPU branch - delete[] reinterpret_cast(p); + cudaFree(p); +#else // CPU branch + delete[] reinterpret_cast(p); #endif - } - } + } +} - template - void ExtraArguments::convertAndCopy(Nd4jPointer pointer, Nd4jLong offset) { - auto length = this->length(); - auto target = reinterpret_cast(pointer); +template +void ExtraArguments::convertAndCopy(Nd4jPointer pointer, Nd4jLong offset) { + auto length = this->length(); + auto target = reinterpret_cast(pointer); #ifdef __CUDABLAS__ - target = new T[length]; + target = new T[length]; #endif - if (!_fpArgs.empty()) { - for (int e = offset; e < _fpArgs.size(); e++) { - target[e] = static_cast(_fpArgs[e]); - } - } else if (_intArgs.empty()) { - for (int e = offset; e < _intArgs.size(); e++) { - target[e] = static_cast(_intArgs[e]); - } - } + if (!_fpArgs.empty()) { + for (int e = offset; e < _fpArgs.size(); e++) { + target[e] = static_cast(_fpArgs[e]); + } + } else if (_intArgs.empty()) { + for (int e = offset; e < _intArgs.size(); e++) { + target[e] = static_cast(_intArgs[e]); + } + } #ifdef __CUDABLAS__ - // TODO: maybe make it asynchronous eventually? - cudaMemcpy(pointer, target, length * DataTypeUtils::sizeOf(DataTypeUtils::fromT()), cudaMemcpyHostToDevice); - delete[] target; + // TODO: maybe make it asynchronous eventually? + cudaMemcpy(pointer, target, + length * DataTypeUtils::sizeOf(DataTypeUtils::fromT()), + cudaMemcpyHostToDevice); + delete[] target; #endif - } - BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT void ExtraArguments::convertAndCopy, (Nd4jPointer pointer, Nd4jLong offset), LIBND4J_TYPES); +} +BUILD_SINGLE_TEMPLATE(template SD_EXPORT void ExtraArguments::convertAndCopy, + (Nd4jPointer pointer, Nd4jLong offset), LIBND4J_TYPES); - void* ExtraArguments::allocate(size_t length, size_t elementSize) { +void *ExtraArguments::allocate(size_t length, size_t elementSize) { #ifdef __CUDABLAS__ - Nd4jPointer ptr; - auto res = cudaMalloc(reinterpret_cast(&ptr), length * elementSize); - if (res != 0) - throw std::runtime_error("Can't allocate CUDA memory"); -#else // CPU branch - auto ptr = new int8_t[length * elementSize]; - if (!ptr) - throw std::runtime_error("Can't allocate memory"); + Nd4jPointer ptr; + auto res = cudaMalloc(reinterpret_cast(&ptr), length * elementSize); + if (res != 0) throw std::runtime_error("Can't allocate CUDA memory"); +#else // CPU branch + auto ptr = new int8_t[length * elementSize]; + if (!ptr) throw std::runtime_error("Can't allocate memory"); #endif - return ptr; - } - - size_t ExtraArguments::length() { - if (!_fpArgs.empty()) - return _fpArgs.size(); - else if (!_intArgs.empty()) - return _intArgs.size(); - else - return 0; - } + return ptr; +} - template - void* ExtraArguments::argumentsAsT(Nd4jLong offset) { - return argumentsAsT(DataTypeUtils::fromT(), offset); - } - BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT void *ExtraArguments::argumentsAsT, (Nd4jLong offset), LIBND4J_TYPES); +size_t ExtraArguments::length() { + if (!_fpArgs.empty()) + return _fpArgs.size(); + else if (!_intArgs.empty()) + return _intArgs.size(); + else + return 0; +} +template +void *ExtraArguments::argumentsAsT(Nd4jLong offset) { + return argumentsAsT(DataTypeUtils::fromT(), offset); +} +BUILD_SINGLE_TEMPLATE(template SD_EXPORT void *ExtraArguments::argumentsAsT, + (Nd4jLong offset), LIBND4J_TYPES); - void* ExtraArguments::argumentsAsT(sd::DataType dataType, Nd4jLong offset) { - if (_fpArgs.empty() && _intArgs.empty()) - return nullptr; +void *ExtraArguments::argumentsAsT(sd::DataType dataType, Nd4jLong offset) { + if (_fpArgs.empty() && _intArgs.empty()) return nullptr; - // we allocate pointer - auto ptr = allocate(length() - offset, DataTypeUtils::sizeOf(dataType)); + // we allocate pointer + auto ptr = allocate(length() - offset, DataTypeUtils::sizeOf(dataType)); - // fill it with data - BUILD_SINGLE_SELECTOR(dataType, convertAndCopy, (ptr, offset), LIBND4J_TYPES); + // fill it with data + BUILD_SINGLE_SELECTOR(dataType, convertAndCopy, (ptr, offset), LIBND4J_TYPES); - // store it internally for future release - _pointers.emplace_back(ptr); + // store it internally for future release + _pointers.emplace_back(ptr); - return ptr; - } + return ptr; } +} // namespace sd diff --git a/libnd4j/include/array/impl/InteropDataBuffer.cpp b/libnd4j/include/array/impl/InteropDataBuffer.cpp index d0a38161285f..02d1cfc21dad 100644 --- a/libnd4j/include/array/impl/InteropDataBuffer.cpp +++ b/libnd4j/include/array/impl/InteropDataBuffer.cpp @@ -18,129 +18,132 @@ // @author raver119@gmail.com // -#include #include +#include #include #include namespace sd { - InteropDataBuffer::InteropDataBuffer(InteropDataBuffer &dataBuffer, uint64_t length, uint64_t offset) { - _dataBuffer = dataBuffer.getDataBuffer(); - - // offset is always absolute to the original buffer - _offset = offset; - - if (_offset + length > _dataBuffer->getLenInBytes()) { - throw std::runtime_error("offset + length is higher than original length"); - } - } - - InteropDataBuffer::InteropDataBuffer(std::shared_ptr databuffer) { - _dataBuffer = databuffer; - } - - InteropDataBuffer::InteropDataBuffer(size_t elements, sd::DataType dtype, bool allocateBoth) { - if (elements == 0) { - _dataBuffer = std::make_shared(); - _dataBuffer->setDataType(dtype); - } else { - _dataBuffer = std::make_shared(elements, dtype, nullptr, allocateBoth); - } - } - - std::shared_ptr InteropDataBuffer::getDataBuffer() const { - return _dataBuffer; - } - - std::shared_ptr InteropDataBuffer::dataBuffer() { - return _dataBuffer; - } - - void* InteropDataBuffer::primary() const { - return reinterpret_cast(_dataBuffer->primary()) + _offset; - } - - void* InteropDataBuffer::special() const { - return reinterpret_cast(_dataBuffer->special()) + _offset; - } - - void InteropDataBuffer::setPrimary(void* ptr, size_t length) { - _dataBuffer->setPrimaryBuffer(ptr, length); - } - - void InteropDataBuffer::setSpecial(void* ptr, size_t length) { - _dataBuffer->setSpecialBuffer(ptr, length); - } - - uint64_t InteropDataBuffer::offset() const { - return _offset; - } - - void InteropDataBuffer::setOffset(uint64_t offset) { - _offset = offset; - } - - int InteropDataBuffer::deviceId() const { - return _dataBuffer->deviceId(); - } - - - void InteropDataBuffer::registerSpecialUse(const std::vector& writeList, const std::vector& readList) { - for (const auto &v:writeList) { - if (v == nullptr) - continue; - - v->getDataBuffer()->writeSpecial(); - } - } - - void InteropDataBuffer::prepareSpecialUse(const std::vector& writeList, const std::vector& readList, bool synchronizeWritables) { - auto currentDeviceId = sd::AffinityManager::currentDeviceId(); - for (const auto &v:readList) { - if (v == nullptr) - continue; - - if (v->getDataBuffer()->deviceId() != currentDeviceId) - v->getDataBuffer()->migrate(); - - v->getDataBuffer()->syncToSpecial(); - } - - // we don't tick write list, only ensure the same device affinity - for (const auto &v:writeList) { - if (v == nullptr) - continue; - - // special case for legacy ops - views can be updated on host side, thus original array can be not updated - if (!v->getDataBuffer()->isSpecialActual()) - v->getDataBuffer()->syncToSpecial(); - - if (v->getDataBuffer()->deviceId() != currentDeviceId) - v->getDataBuffer()->migrate(); - } - } - - void InteropDataBuffer::registerPrimaryUse(const std::vector& writeList, const std::vector& readList) { - for (const auto &v:writeList) { - if (v == nullptr) - continue; - } - } - - void InteropDataBuffer::preparePrimaryUse(const std::vector& writeList, const std::vector& readList, bool synchronizeWritables) { - for (const auto &v:readList) { - if (v == nullptr) - continue; - - v->getDataBuffer()->syncToPrimary(LaunchContext::defaultContext()); - } - } - - void InteropDataBuffer::expand(size_t newlength) { - _dataBuffer->expand(newlength * DataTypeUtils::sizeOf(_dataBuffer->getDataType())); - } - - void InteropDataBuffer::setDeviceId(int deviceId) { - _dataBuffer->setDeviceId(deviceId); - } +InteropDataBuffer::InteropDataBuffer(InteropDataBuffer& dataBuffer, + uint64_t length, uint64_t offset) { + _dataBuffer = dataBuffer.getDataBuffer(); + + // offset is always absolute to the original buffer + _offset = offset; + + if (_offset + length > _dataBuffer->getLenInBytes()) { + throw std::runtime_error("offset + length is higher than original length"); + } +} + +InteropDataBuffer::InteropDataBuffer(std::shared_ptr databuffer) { + _dataBuffer = databuffer; +} + +InteropDataBuffer::InteropDataBuffer(size_t elements, sd::DataType dtype, + bool allocateBoth) { + if (elements == 0) { + _dataBuffer = std::make_shared(); + _dataBuffer->setDataType(dtype); + } else { + _dataBuffer = + std::make_shared(elements, dtype, nullptr, allocateBoth); + } +} + +std::shared_ptr InteropDataBuffer::getDataBuffer() const { + return _dataBuffer; +} + +std::shared_ptr InteropDataBuffer::dataBuffer() { + return _dataBuffer; +} + +void* InteropDataBuffer::primary() const { + return reinterpret_cast(_dataBuffer->primary()) + _offset; +} + +void* InteropDataBuffer::special() const { + return reinterpret_cast(_dataBuffer->special()) + _offset; +} + +void InteropDataBuffer::setPrimary(void* ptr, size_t length) { + _dataBuffer->setPrimaryBuffer(ptr, length); +} + +void InteropDataBuffer::setSpecial(void* ptr, size_t length) { + _dataBuffer->setSpecialBuffer(ptr, length); +} + +uint64_t InteropDataBuffer::offset() const { return _offset; } + +void InteropDataBuffer::setOffset(uint64_t offset) { _offset = offset; } + +int InteropDataBuffer::deviceId() const { return _dataBuffer->deviceId(); } + +void InteropDataBuffer::registerSpecialUse( + const std::vector& writeList, + const std::vector& readList) { + for (const auto& v : writeList) { + if (v == nullptr) continue; + + v->getDataBuffer()->writeSpecial(); + } +} + +void InteropDataBuffer::prepareSpecialUse( + const std::vector& writeList, + const std::vector& readList, + bool synchronizeWritables) { + auto currentDeviceId = sd::AffinityManager::currentDeviceId(); + for (const auto& v : readList) { + if (v == nullptr) continue; + + if (v->getDataBuffer()->deviceId() != currentDeviceId) + v->getDataBuffer()->migrate(); + + v->getDataBuffer()->syncToSpecial(); + } + + // we don't tick write list, only ensure the same device affinity + for (const auto& v : writeList) { + if (v == nullptr) continue; + + // special case for legacy ops - views can be updated on host side, thus + // original array can be not updated + if (!v->getDataBuffer()->isSpecialActual()) + v->getDataBuffer()->syncToSpecial(); + + if (v->getDataBuffer()->deviceId() != currentDeviceId) + v->getDataBuffer()->migrate(); + } +} + +void InteropDataBuffer::registerPrimaryUse( + const std::vector& writeList, + const std::vector& readList) { + for (const auto& v : writeList) { + if (v == nullptr) continue; + } +} + +void InteropDataBuffer::preparePrimaryUse( + const std::vector& writeList, + const std::vector& readList, + bool synchronizeWritables) { + for (const auto& v : readList) { + if (v == nullptr) continue; + + v->getDataBuffer()->syncToPrimary(LaunchContext::defaultContext()); + } +} + +void InteropDataBuffer::expand(size_t newlength) { + _dataBuffer->expand(newlength * + DataTypeUtils::sizeOf(_dataBuffer->getDataType())); +} + +void InteropDataBuffer::setDeviceId(int deviceId) { + _dataBuffer->setDeviceId(deviceId); } +} // namespace sd diff --git a/libnd4j/include/array/impl/ManagedDataBuffer.cpp b/libnd4j/include/array/impl/ManagedDataBuffer.cpp new file mode 100644 index 000000000000..8f486f27faca --- /dev/null +++ b/libnd4j/include/array/impl/ManagedDataBuffer.cpp @@ -0,0 +1,39 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include + +namespace sd { +ManagedDataBuffer::ManagedDataBuffer(graph::GraphMemoryManager &manager, + uint64_t numberOfBytes, DataType dtype, + memory::MemoryZone zone) + : _manager(manager), + _zone(zone), + _descriptor(manager.allocate(numberOfBytes, zone)) { + _lenInBytes = numberOfBytes; + _dataType = dtype; +} + +ManagedDataBuffer::~ManagedDataBuffer() { + // if we know that MDB can be released - it means that all NDArrays were + // released, so it's really safe to release + _manager.release(_descriptor); +} +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/array/impl/NDArray.cpp b/libnd4j/include/array/impl/NDArray.cpp new file mode 100644 index 000000000000..8abb259e478a --- /dev/null +++ b/libnd4j/include/array/impl/NDArray.cpp @@ -0,0 +1,7164 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// $NDArray.hpp - architech-independent implementations (both cuda and cpu). +// +#ifndef __NDARRAY__HPP__ +#define __NDARRAY__HPP__ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace sd { + +template <> +SD_EXPORT utf8string NDArray::e(const Nd4jLong i) const; +template <> +SD_EXPORT std::string NDArray::e(const Nd4jLong i) const; +template <> +SD_EXPORT std::u16string NDArray::e(const Nd4jLong i) const; +template <> +SD_EXPORT std::u32string NDArray::e(const Nd4jLong i) const; + +//////////////////////////////////////////////////////////////////////// +// copy constructor +NDArray::NDArray(const NDArray& other) { + // setShapeInfo(ShapeDescriptor(other.dataType(), other.ordering(), + // other.shapeOf(), other.rankOf())); + /* + if(!isEmpty()) { + _buffer = std::make_shared(other.lengthOf() * + other.sizeOfT(), other.dataType(), other.getContext()->getWorkspace()); + this->assign(&other); + } + else + _buffer = std::make_shared(); + */ + _buffer = other._buffer; + _shapeInfo = other._shapeInfo; + _shapeInfoD = other._shapeInfoD; + _length = other._length; + _isAttached = other._isAttached; + _isView = other._isView; + _context = other._context; + _dataType = other._dataType; + _deviceId = other._deviceId; + _offset = other._offset; +} + +//////////////////////////////////////////////////////////////////////// +NDArray::NDArray(const char order, const std::vector& shape, + sd::DataType dtype, sd::LaunchContext* context) { + if ((int)shape.size() > MAX_RANK) + throw std::invalid_argument("Rank of NDArray can't exceed 32"); + + _context = context; + _isAttached = _context->getWorkspace() != nullptr; + _offset = 0; + + if (shape.empty()) + setShapeInfo(ShapeDescriptor::emptyDescriptor(dtype)); + else + setShapeInfo(ShapeDescriptor(dtype, order, shape)); + + _buffer = + std::make_shared(lengthOf() * DataTypeUtils::sizeOf(dtype), + dtype, getContext()->getWorkspace()); + _buffer->setToZeroBuffers(); +} + +bool NDArray::defined() const { return _shapeInfo != nullptr; } + +bool NDArray::undefined() const { return _shapeInfo == nullptr; } + +//////////////////////////////////////////////////////////////////////// +NDArray::NDArray(const char order, const std::vector& shape, + const std::vector& data, sd::DataType dtype, + sd::LaunchContext* context) { + if ((int)shape.size() > MAX_RANK) + throw std::invalid_argument("Rank of NDArray can't exceed 32"); + + _context = context; + _offset = 0; + + if (shape.size() == 0) { + if (data.size() == 0) + setShapeInfo(ShapeDescriptor::emptyDescriptor(dtype)); + else + setShapeInfo(ShapeDescriptor::scalarDescriptor(dtype)); + } else { + setShapeInfo(ShapeDescriptor(dtype, order, shape)); + } + + if (lengthOf() != data.size()) { + nd4j_printf( + "NDArray constructor: data size [%i] doesn't match shape length [%i]\n", + data.size(), lengthOf()); + throw std::runtime_error("Data size doesn't match shape"); + } + + _buffer = + std::make_shared(lengthOf() * DataTypeUtils::sizeOf(dtype), + dtype, getContext()->getWorkspace(), true); + + for (Nd4jLong i = 0; i < lengthOf(); ++i) { + BUILD_SINGLE_PARTIAL_SELECTOR( + dtype, + templatedDoubleAssign<, double>( + buffer(), i, reinterpret_cast(data.data()), i), + LIBND4J_TYPES); + } + tickWriteHost(); + syncToDevice(); +} + +//////////////////////////////////////////////////////////////////////// +NDArray::NDArray(const NDArray* other, const bool copyStrides, + sd::LaunchContext* context) { + _context = context; + _offset = 0; + _isAttached = getContext()->getWorkspace() != nullptr; + + if (copyStrides) + setShapeInfo(ShapeDescriptor(other->_shapeInfo)); + else + setShapeInfo(ShapeDescriptor(other->dataType(), other->ordering(), + other->shapeOf(), other->rankOf())); + + if (!isEmpty()) + _buffer = std::make_shared(lengthOf() * sizeOfT(), dataType(), + getContext()->getWorkspace()); +} + +//////////////////////////////////////////////////////////////////////// +NDArray::NDArray(void* buffer, const char order, + const std::vector& shape, sd::DataType dtype, + sd::LaunchContext* context, const bool isBuffAlloc) { + if (shape.empty()) + throw std::runtime_error("NDArray constructor: input shape is empty !"); + + if ((int)shape.size() > MAX_RANK) + throw std::invalid_argument("Rank of NDArray can't exceed 32"); + + _context = context; + _offset = 0; + _isAttached = getContext()->getWorkspace() != nullptr; + + setShapeInfo(ShapeDescriptor(dtype, order, shape)); + + _buffer = + std::make_shared(buffer, lengthOf() * sizeOfT(), dataType(), + isBuffAlloc, getContext()->getWorkspace()); +} + +//////////////////////////////////////////////////////////////////////// +// creates new NDArray using shape information from "shapeInfo" array, set all +// elements in new array to be zeros +NDArray::NDArray(const Nd4jLong* shapeInfo, const sd::DataType dtype, + const bool copyStrides, sd::LaunchContext* context, + const bool nullify) { + if (shapeInfo == nullptr) + throw std::runtime_error( + "NDArray constructor: can't be initalized without shapeinfo"); + + if ((int)shapeInfo[0] > MAX_RANK) + throw std::invalid_argument("Rank of NDArray can't exceed 32"); + + _context = context; + _offset = 0; + + if (copyStrides) + setShapeInfo(ShapeDescriptor(shapeInfo, dtype)); + else + setShapeInfo(ShapeDescriptor(dtype, shape::order(shapeInfo), + shape::shapeOf(shapeInfo), + shape::rank(shapeInfo))); + + if (!isEmpty()) { + _buffer = std::make_shared(lengthOf() * sizeOfT(), dtype, + getContext()->getWorkspace()); + + if (nullify) _buffer->setToZeroBuffers(); + } +} + +//////////////////////////////////////////////////////////////////////// +// scalar constructor +NDArray::NDArray(sd::DataType dtype, sd::LaunchContext* context, + const bool isScalar) { + _context = context; + _offset = 0; + _isAttached = getContext()->getWorkspace() != nullptr; + + if (isScalar) { + setShapeInfo(ShapeDescriptor::scalarDescriptor(dtype)); + _buffer = std::make_shared(sizeOfT(), dtype, + getContext()->getWorkspace()); + _buffer->setToZeroBuffers(); + } else + setShapeInfo(ConstantShapeHelper::getInstance().emptyShapeInfo(dtype)); +} + +////////////////////////////////////////////////////////////////////////// +// move constructor +NDArray::NDArray(NDArray&& other) noexcept { + _isView = other._isView; + _buffer = other._buffer; + _shapeInfo = other._shapeInfo; + _shapeInfoD = other._shapeInfoD; + _context = other._context; + _dataType = other._dataType; + _length = other._length; + _offset = other._offset; + + other._buffer = std::make_shared(); + other._shapeInfo = other._shapeInfoD = nullptr; + other._length = 0; +} + +//////////////////////////////////////////////////////////////////////// +// constructor, create empty array at given workspace +NDArray::NDArray(sd::LaunchContext* context) { + _buffer = std::make_shared(); + _shapeInfo = nullptr; + _shapeInfoD = nullptr; + _offset = 0; + _context = context; + _length = 0; +} + +//////////////////////////////////////////////////////////////////////// +// creates new NDArray using shape information from "shapeInfo" array, set all +// elements in new array to be zeros, set dtype as array type +NDArray::NDArray(const Nd4jLong* shapeInfo, const bool copyStrides, + sd::LaunchContext* context, const bool nullify) + : NDArray(shapeInfo, ArrayOptions::dataType(shapeInfo), copyStrides, + context) {} + +//////////////////////////////////////////////////////////////////////// +NDArray::NDArray(std::shared_ptr buffer, + const ShapeDescriptor& descriptor, sd::LaunchContext* context, + const Nd4jLong offset) { + _context = context; + _offset = offset; + + setShapeInfo(descriptor); + + _buffer = buffer; + + _isView = offset > 0 || _length * DataTypeUtils::sizeOf(_dataType) < + buffer->getLenInBytes(); +} + +NDArray::NDArray(void* buffer, Nd4jLong* shapeInfo, sd::LaunchContext* context, + const bool isBuffAlloc) + : NDArray::NDArray(buffer, const_cast(shapeInfo), context, + isBuffAlloc) { + // +} + +//////////////////////////////////////////////////////////////////////// +// do not allocate memory, memory for array is passed from outside +NDArray::NDArray(void* buffer, const Nd4jLong* shapeInfo, + sd::LaunchContext* context, const bool isBuffAlloc) { + if (buffer == nullptr && + ArrayOptions::arrayType(shapeInfo) != ArrayType::EMPTY) + throw std::runtime_error( + "NDArray constructor: can't be initalized with nullptr buffer !"); + + if (shapeInfo == nullptr) + throw std::runtime_error( + "NDArray constructor: can't be initalized without shapeinfo !"); + + if ((int)shapeInfo[0] > MAX_RANK) + throw std::invalid_argument( + "NDArray constructor: rank of NDArray can't exceed 32 !"); + + _context = context; + _isAttached = getContext()->getWorkspace() != nullptr; + _offset = 0; + + setShapeInfo(ShapeDescriptor(shapeInfo)); + + if (this->isEmpty()) { + tickReadDevice(); + tickReadHost(); + } else { + _buffer = + std::make_shared(buffer, lengthOf() * sizeOfT(), dataType(), + isBuffAlloc, getContext()->getWorkspace()); + } +} + +//////////////////////////////////////////////////////////////////////// +// do not allocate memory, memory for array is passed from outside +// we suppose the content of both (device and host) buffers is identical +NDArray::NDArray(void* buffer, void* bufferD, const Nd4jLong* shapeInfo, + sd::LaunchContext* context, const bool isBuffAlloc, + const bool isBuffDAlloc) { + if (shapeInfo == nullptr) + throw std::runtime_error( + "NDArray constructor cuda: can't be initalized without shapeinfo"); + + if ((int)shapeInfo[0] > MAX_RANK) + throw std::invalid_argument( + "NDArray constructor cuda: rank of NDArray can't exceed 32"); + + _context = context; + _offset = 0; + + setShapeInfo(ShapeDescriptor(shapeInfo)); + + if (!isEmpty()) + _buffer = std::make_shared( + buffer, bufferD, lengthOf() * sizeOfT(), dataType(), isBuffAlloc, + isBuffDAlloc, getContext()->getWorkspace()); +} + +////////////////////////////////////////////////////////////////////////// +NDArray::NDArray(std::shared_ptr buffer, const char order, + const std::vector& shape, + sd::LaunchContext* context) { + if (shape.empty()) + throw std::runtime_error("NDArray constructor: input shape is empty !"); + + if ((int)shape.size() > MAX_RANK) + throw std::invalid_argument( + "NDArray constructor: rank of NDArray can't exceed 32"); + + _context = context; + _offset = 0; + + setShapeInfo(ShapeDescriptor(buffer->getDataType(), order, shape)); + + _buffer = buffer; + + _isView = + _length * DataTypeUtils::sizeOf(_dataType) < buffer->getLenInBytes(); +} +///////////////////////////////////////////////////////////////////////// +// u16 string constructors +NDArray::NDArray(const std::u16string& u16string, sd::DataType dtype, + sd::LaunchContext* context) { + if (!DataTypeUtils::isS(dtype)) { + throw std::invalid_argument( + "NDArray::NDArray: invalid DataType, only string dataTypes have to be " + "used"); + } + + if (!unicode::isStringValidU16(u16string.data(), + u16string.data() + u16string.size())) { + throw std::invalid_argument( + "NDArray::NDArray: invalid character in input string"); + } + + // one word that is why used 1 + Nd4jLong headerLength = ShapeUtils::stringBufferHeaderRequirements(1); + + Nd4jLong dataLength = [&] { + if (dtype == DataType::UTF16) { + return static_cast(u16string.size() * sizeof(uint16_t)); + } + if (dtype == DataType::UTF32) { + return unicode::offsetUtf16StringInUtf32(u16string.data(), + u16string.size()); + } + return unicode::offsetUtf16StringInUtf8(u16string.data(), u16string.size()); + }(); + + Nd4jLong offsets[2] = {0, dataLength}; + + _buffer = std::make_shared(headerLength + dataLength, dtype, + context->getWorkspace(), true); + + _context = context; + _isAttached = getContext()->getWorkspace() != nullptr; + _offset = 0; + + setShapeInfo(ShapeDescriptor::scalarDescriptor(dtype)); + + memcpy(bufferAsT(), &offsets[0], 2 * sizeof(Nd4jLong)); + + auto data = reinterpret_cast(bufferAsT() + headerLength); + if (dtype == DataType::UTF8) { + unicode::utf16to8(u16string.data(), data, u16string.size()); + } else if (dtype == DataType::UTF16) { + memcpy(data, u16string.data(), dataLength); + } else { + unicode::utf16to32(u16string.data(), data, u16string.size()); + } + + tickWriteHost(); + syncToDevice(); +} + +///////////////////////////////////////////////////////////////////////// +// u32 string constructors +NDArray::NDArray(const std::u32string& u32string, sd::DataType dtype, + sd::LaunchContext* context) { + if (!DataTypeUtils::isS(dtype)) { + throw std::invalid_argument( + "NDArray::NDArray: invalid DataType, only string dataTypes have to be " + "used"); + } + + if (!unicode::isStringValidU32(u32string.data(), + u32string.data() + u32string.size())) { + throw std::invalid_argument( + "NDArray::NDArray: invalid character in input string"); + } + // one word that is why used 1 + Nd4jLong headerLength = ShapeUtils::stringBufferHeaderRequirements(1); + + Nd4jLong dataLength = [&] { + if (dtype == DataType::UTF16) { + return unicode::offsetUtf32StringInUtf16(u32string.data(), + u32string.size()); + } + if (dtype == DataType::UTF32) { + return static_cast(sizeof(uint32_t) * u32string.size()); + } + return unicode::offsetUtf32StringInUtf8(u32string.data(), u32string.size()); + }(); + + Nd4jLong offsets[2] = {0, dataLength}; + + _buffer = std::make_shared(headerLength + dataLength, dtype, + context->getWorkspace(), true); + + _context = context; + _isAttached = getContext()->getWorkspace() != nullptr; + _offset = 0; + + setShapeInfo(ShapeDescriptor::scalarDescriptor(dtype)); + + memcpy(bufferAsT(), &offsets[0], 2 * sizeof(Nd4jLong)); + + auto data = reinterpret_cast(bufferAsT() + headerLength); + if (dtype == DataType::UTF8) { + unicode::utf32to8(u32string.data(), data, u32string.size()); + } else if (dtype == DataType::UTF16) { + unicode::utf32to16(u32string.data(), data, u32string.size()); + } else { + memcpy(data, u32string.data(), u32string.size() * sizeof(uint32_t)); + } + + tickWriteHost(); + syncToDevice(); +} + +///////////////////////////////////////////////////////////////////////// +// u8 string constructors +NDArray::NDArray(const std::string& str, sd::DataType dtype, + sd::LaunchContext* context) { + if (!DataTypeUtils::isS(dtype)) { + throw std::invalid_argument( + "NDArray::NDArray: invalid DataType, only string dataTypes have to be " + "used"); + } + + if (!unicode::isStringValidU8(str.data(), str.data() + str.size())) { + throw std::invalid_argument( + "NDArray::NDArray: invalid character in input string"); + } + + // one word that is why used 1 + auto headerLength = ShapeUtils::stringBufferHeaderRequirements(1); + + Nd4jLong dataLength = [&] { + if (dtype == DataType::UTF16) { + return unicode::offsetUtf8StringInUtf16(str.data(), str.size()); + } + if (dtype == DataType::UTF32) { + return unicode::offsetUtf8StringInUtf32(str.data(), str.size()); + } + return static_cast(str.size()); + }(); + + Nd4jLong offsets[2] = {0, dataLength}; + + _buffer = std::make_shared(headerLength + dataLength, dtype, + context->getWorkspace(), true); + + _context = context; + _isAttached = getContext()->getWorkspace() != nullptr; + _offset = 0; + + setShapeInfo(ShapeDescriptor::scalarDescriptor(dtype)); + + memcpy(bufferAsT(), &offsets[0], 2 * sizeof(Nd4jLong)); + + auto data = reinterpret_cast(bufferAsT() + headerLength); + + if (dtype == DataType::UTF8) { + memcpy(data, str.data(), str.size()); + } else if (dtype == DataType::UTF16) { + unicode::utf8to16(str.data(), data, str.size()); + } else { + unicode::utf8to32(str.data(), data, str.size()); + } + + tickWriteHost(); + syncToDevice(); +} +///////////////////////////////////////////////////////////////////////// +// constructors for vector of strings +NDArray::NDArray(const std::vector& shape, + const std::vector& string, + const sd::DataType dataType, sd::LaunchContext* context) { + if (!DataTypeUtils::isS(dataType)) + throw std::invalid_argument( + "NDArray::NDArray: invalid DataType, only string dataTypes have to be " + "used"); + + if (shape::prodLong(shape.data(), shape.size()) != string.size()) + throw std::invalid_argument( + "NDArray::NDArray: Number of strings should match length of array"); + + for (const auto& str : string) { + if (!unicode::isStringValidU8(str, + str + std::char_traits::length(str))) { + throw std::invalid_argument( + "NDArray::NDArray: invalid character in input string"); + } + } + + Nd4jLong headerLength = + ShapeUtils::stringBufferHeaderRequirements(string.size()); + + std::vector offsets(string.size() + 1); + Nd4jLong dataLength = 0; + for (int e = 0; e < string.size(); e++) { + offsets[e] = dataLength; + dataLength += [&] { + if (dataType == DataType::UTF16) + return unicode::offsetUtf8StringInUtf16( + string[e], std::char_traits::length(string[e])); + if (dataType == DataType::UTF32) + return unicode::offsetUtf8StringInUtf32( + string[e], std::char_traits::length(string[e])); + return static_cast(std::char_traits::length(string[e])); + }(); + } + offsets[string.size()] = dataLength; + + _buffer = std::make_shared(headerLength + dataLength, dataType, + context->getWorkspace(), true); + + _context = context; + _offset = 0; + + setShapeInfo(ShapeDescriptor(dataType, 'c', shape)); + + _isView = false; + + setAttached(context->getWorkspace() != nullptr); + + memcpy(bufferAsT(), offsets.data(), + offsets.size() * sizeof(Nd4jLong)); + + auto data = reinterpret_cast(bufferAsT() + headerLength); + + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + auto cdata = data + offsets[e]; + if (dataType == DataType::UTF16) { + unicode::utf8to16(string[e], cdata, + std::char_traits::length(string[e])); + } else if (dataType == DataType::UTF32) { + unicode::utf8to32(string[e], cdata, + std::char_traits::length(string[e])); + } else { + memcpy(cdata, string[e], std::char_traits::length(string[e])); + } + } + }; + + samediff::Threads::parallel_for(func, 0, lengthOf(), 1); + + tickWriteHost(); + syncToDevice(); +} +///////////////////////////////////////////////////////////////////////// +NDArray::NDArray(const std::vector& shape, + const std::vector& string, + const sd::DataType dataType, sd::LaunchContext* context) { + if (!DataTypeUtils::isS(dataType)) + throw std::invalid_argument( + "NDArray::NDArray: invalid DataType, only string dataTypes have to be " + "used"); + + if (shape::prodLong(shape.data(), shape.size()) != string.size()) + throw std::invalid_argument( + "NDArray::NDArray: Number of strings should match length of array"); + + for (const auto& str : string) { + if (!unicode::isStringValidU8(str.data(), str.data() + str.size())) { + throw std::invalid_argument( + "NDArray::NDArray: invalid character in input string"); + } + } + + Nd4jLong headerLength = + ShapeUtils::stringBufferHeaderRequirements(string.size()); + + std::vector offsets(string.size() + 1); + Nd4jLong dataLength = 0; + for (int e = 0; e < string.size(); e++) { + offsets[e] = dataLength; + dataLength += [&] { + if (dataType == DataType::UTF16) + return unicode::offsetUtf8StringInUtf16(string[e].data(), + string[e].size()); + if (dataType == DataType::UTF32) + return unicode::offsetUtf8StringInUtf32(string[e].data(), + string[e].size()); + return static_cast(string[e].size()); + }(); + } + + offsets[string.size()] = dataLength; + + _buffer = std::make_shared(headerLength + dataLength, dataType, + context->getWorkspace(), true); + + _context = context; + _offset = 0; + + setShapeInfo(ShapeDescriptor(dataType, 'c', shape)); + + _isView = false; + + setAttached(context->getWorkspace() != nullptr); + + memcpy(bufferAsT(), offsets.data(), + offsets.size() * sizeof(Nd4jLong)); + + auto data = reinterpret_cast(bufferAsT() + headerLength); + + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + auto cdata = data + offsets[e]; + if (dataType == DataType::UTF16) { + unicode::utf8to16(string[e].data(), cdata, string[e].size()); + } else if (dataType == DataType::UTF32) { + unicode::utf8to32(string[e].data(), cdata, string[e].size()); + } else { + memcpy(cdata, string[e].data(), string[e].size()); + } + } + }; + + samediff::Threads::parallel_for(func, 0, lengthOf(), 1); + + tickWriteHost(); + syncToDevice(); +} +///////////////////////////////////////////////////////////////////////// +NDArray::NDArray(const std::vector& shape, + const std::vector& string, sd::DataType dtype, + sd::LaunchContext* context) { + if (!DataTypeUtils::isS(dtype)) + throw std::invalid_argument( + "NDArray::NDArray: invalid DataType, only string dataTypes have to be " + "used"); + + if (shape::prodLong(shape.data(), shape.size()) != string.size()) + throw std::invalid_argument( + "NDArray::NDArray: Number of strings should match length of array"); + + for (const auto& str : string) { + if (!unicode::isStringValidU16(str.data(), str.data() + str.size())) { + throw std::invalid_argument( + "NDArray::NDArray: invalid character in input string"); + } + } + + Nd4jLong headerLength = + ShapeUtils::stringBufferHeaderRequirements(string.size()); + + std::vector offsets(string.size() + 1); + Nd4jLong dataLength = 0; + for (int e = 0; e < string.size(); e++) { + offsets[e] = dataLength; + dataLength += [&] { + if (dtype == DataType::UTF16) + return static_cast(sizeof(uint16_t) * string[e].size()); + if (dtype == DataType::UTF32) + return unicode::offsetUtf16StringInUtf32(string[e].data(), + string[e].size()); + return unicode::offsetUtf16StringInUtf8(string[e].data(), + string[e].size()); + }(); + } + offsets[string.size()] = dataLength; + + _buffer = std::make_shared(headerLength + dataLength, dtype, + context->getWorkspace(), true); + + _context = context; + _offset = 0; + + setShapeInfo(ShapeDescriptor(dtype, 'c', shape)); + + _isView = false; + + setAttached(context->getWorkspace() != nullptr); + + memcpy(bufferAsT(), offsets.data(), + offsets.size() * sizeof(Nd4jLong)); + + auto data = reinterpret_cast(bufferAsT() + headerLength); + + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + auto cdata = data + offsets[e]; + if (dtype == DataType::UTF16) { + memcpy(cdata, string[e].data(), string[e].size() * sizeof(uint16_t)); + } else if (dtype == DataType::UTF32) { + unicode::utf16to32(string[e].data(), cdata, string[e].size()); + } else { + unicode::utf16to8(string[e].data(), cdata, string[e].size()); + } + } + }; + samediff::Threads::parallel_for(func, 0, lengthOf(), 1); + + tickWriteHost(); + syncToDevice(); +} +///////////////////////////////////////////////////////////////////////// +NDArray::NDArray(const std::vector& shape, + const std::vector& string, sd::DataType dtype, + sd::LaunchContext* context) { + if (!DataTypeUtils::isS(dtype)) + throw std::invalid_argument( + "NDArray::NDArray: invalid DataType, only string dataTypes have to be " + "used"); + + if (shape::prodLong(shape.data(), shape.size()) != string.size()) + throw std::invalid_argument( + "NDArray::NDArray: Number of strings should match length of array"); + + for (const auto& str : string) { + if (!unicode::isStringValidU16( + str, str + std::char_traits::length(str))) { + throw std::invalid_argument( + "NDArray::NDArray: invalid character in input string"); + } + } + + Nd4jLong headerLength = + ShapeUtils::stringBufferHeaderRequirements(string.size()); + + std::vector offsets(string.size() + 1); + Nd4jLong dataLength = 0; + for (int e = 0; e < string.size(); e++) { + offsets[e] = dataLength; + dataLength += [&] { + if (dtype == DataType::UTF16) + return static_cast( + sizeof(uint16_t) * std::char_traits::length(string[e])); + if (dtype == DataType::UTF32) + return unicode::offsetUtf16StringInUtf32( + string[e], std::char_traits::length(string[e])); + return unicode::offsetUtf16StringInUtf8( + string[e], std::char_traits::length(string[e])); + }(); + } + offsets[string.size()] = dataLength; + + _buffer = std::make_shared(headerLength + dataLength, dtype, + context->getWorkspace(), true); + + _context = context; + _offset = 0; + + setShapeInfo(ShapeDescriptor(dtype, 'c', shape)); + + _isView = false; + + setAttached(context->getWorkspace() != nullptr); + + memcpy(bufferAsT(), offsets.data(), + offsets.size() * sizeof(Nd4jLong)); + + auto data = reinterpret_cast(bufferAsT() + headerLength); + + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + auto cdata = data + offsets[e]; + if (dtype == DataType::UTF16) { + memcpy( + cdata, string[e], + std::char_traits::length(string[e]) * sizeof(uint16_t)); + } else if (dtype == DataType::UTF32) { + unicode::utf16to32(string[e], cdata, + std::char_traits::length(string[e])); + } else { + unicode::utf16to8(string[e], cdata, + std::char_traits::length(string[e])); + } + } + }; + samediff::Threads::parallel_for(func, 0, lengthOf(), 1); + + tickWriteHost(); + syncToDevice(); +} +///////////////////////////////////////////////////////////////////////// +NDArray::NDArray(const std::vector& shape, + const std::vector& string, sd::DataType dtype, + sd::LaunchContext* context) { + if (!DataTypeUtils::isS(dtype)) + throw std::invalid_argument( + "NDArray::NDArray: invalid DataType, only string dataTypes have to be " + "used"); + + if (shape::prodLong(shape.data(), shape.size()) != string.size()) + throw std::invalid_argument( + "NDArray::NDArray: Number of strings should match length of array"); + + for (auto str : string) { + if (!unicode::isStringValidU32(str.data(), str.data() + str.size())) { + throw std::invalid_argument( + "NDArray::NDArray: invalid character in input string"); + } + } + + Nd4jLong headerLength = + ShapeUtils::stringBufferHeaderRequirements(string.size()); + + std::vector offsets(string.size() + 1); + + Nd4jLong dataLength = 0; + for (int e = 0; e < string.size(); e++) { + offsets[e] = dataLength; + dataLength += [&] { + if (dtype == DataType::UTF16) + return unicode::offsetUtf32StringInUtf16(string[e].data(), + string[e].size()); + if (dtype == DataType::UTF32) + return static_cast(sizeof(uint32_t) * string[e].size()); + return unicode::offsetUtf32StringInUtf16(string[e].data(), + string[e].size()); + }(); + } + offsets[string.size()] = dataLength; + + _buffer = std::make_shared(headerLength + dataLength, dtype, + context->getWorkspace(), true); + + _context = context; + _offset = 0; + + setShapeInfo(ShapeDescriptor(dtype, 'c', shape)); + + _isView = false; + + setAttached(context->getWorkspace() != nullptr); + + memcpy(bufferAsT(), offsets.data(), + offsets.size() * sizeof(Nd4jLong)); + + auto data = reinterpret_cast(bufferAsT() + headerLength); + + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + auto cdata = data + offsets[e]; + if (dtype == DataType::UTF16) { + unicode::utf32to16(string[e].data(), cdata, string[e].size()); + } else if (dtype == DataType::UTF32) { + memcpy(cdata, string[e].data(), string[e].size() * sizeof(uint32_t)); + } else { + unicode::utf32to8(string[e].data(), cdata, string[e].size()); + } + } + }; + samediff::Threads::parallel_for(func, 0, lengthOf(), 1); + + tickWriteHost(); + syncToDevice(); +} + + +std::ostream& operator<<(std::ostream &os, const NDArray &m) { + os << m.indexedBufferString(); + return os; +} + +///////////////////////////////////////////////////////////////////////// +NDArray::NDArray(const std::vector& shape, + const std::vector& string, sd::DataType dtype, + sd::LaunchContext* context) { + if (!DataTypeUtils::isS(dtype)) + throw std::invalid_argument("NDArray::NDArray: invalid DataType used"); + + if (shape::prodLong(shape.data(), shape.size()) != string.size()) + throw std::invalid_argument( + "NDArray::NDArray: Number of strings should match length of array"); + + for (const auto& str : string) { + if (!unicode::isStringValidU32( + str, str + std::char_traits::length(str))) { + throw std::invalid_argument( + "NDArray::NDArray: invalid character in input string"); + } + } + + Nd4jLong headerLength = + ShapeUtils::stringBufferHeaderRequirements(string.size()); + + std::vector offsets(string.size() + 1); + + Nd4jLong dataLength = 0; + for (int e = 0; e < string.size(); e++) { + offsets[e] = dataLength; + dataLength += [&] { + if (dtype == DataType::UTF16) + return unicode::offsetUtf32StringInUtf16( + string[e], std::char_traits::length(string[e])); + if (dtype == DataType::UTF32) + return static_cast( + sizeof(uint32_t) * std::char_traits::length(string[e])); + return unicode::offsetUtf32StringInUtf16( + string[e], std::char_traits::length(string[e])); + }(); + } + offsets[string.size()] = dataLength; + + _buffer = std::make_shared(headerLength + dataLength, dtype, + context->getWorkspace(), true); + + _context = context; + _offset = 0; + + setShapeInfo(ShapeDescriptor(dtype, 'c', shape)); + + _isView = + _length * DataTypeUtils::sizeOf(_dataType) < _buffer->getLenInBytes(); + + setAttached(context->getWorkspace() != nullptr); + + memcpy(bufferAsT(), offsets.data(), + offsets.size() * sizeof(Nd4jLong)); + + auto data = reinterpret_cast(bufferAsT() + headerLength); + + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + auto cdata = data + offsets[e]; + if (dtype == DataType::UTF16) { + unicode::utf32to16(string[e], cdata, + std::char_traits::length(string[e])); + } else if (dtype == DataType::UTF32) { + memcpy( + cdata, string[e], + std::char_traits::length(string[e]) * sizeof(uint32_t)); + } else { + unicode::utf32to8(string[e], cdata, + std::char_traits::length(string[e])); + } + } + }; + samediff::Threads::parallel_for(func, 0, lengthOf(), 1); + + tickWriteHost(); + syncToDevice(); +} + +//////////////////////////////////////////////////////////////////////// +// assignment operator +NDArray& NDArray::operator=(const NDArray& other) { + if (this == &other || + (_shapeInfo == other._shapeInfo && _shapeInfo == nullptr)) + return *this; + + _buffer = other._buffer; + _shapeInfo = other._shapeInfo; + _shapeInfoD = other._shapeInfoD; + _length = other._length; + _isAttached = other._isAttached; + _isView = other._isView; + _context = other._context; + _dataType = other._dataType; + _deviceId = other._deviceId; + _offset = other._offset; + + /* + if (_shapeInfo != nullptr && shape::equalsTypesAndShapesSoft(_shapeInfo, + other._shapeInfo)) { if(!other.isEmpty()) this->assign(&other); + } + else { + _context = other._context; + _offset = 0; + setShapeInfo(ShapeDescriptor(other.dataType(), other.ordering(), + other.shapeOf(), other.rankOf())); + + if(!other.isEmpty()) { + _buffer = std::make_shared(other.lengthOf() * + other.sizeOfT(), other.dataType(), other.getContext()->getWorkspace()); + this->assign(&other); + } + else + _buffer = std::make_shared(); + } + */ + + return *this; +} + +////////////////////////////////////////////////////////////////////////// +bool NDArray::isC() const { + // TODO: this method must be implemented once we add support for complex + // numbers + return false; +} + +////////////////////////////////////////////////////////////////////////// +bool NDArray::isS() const { + return (dataType() == DataType::UTF8 || dataType() == DataType::UTF16 || + dataType() == DataType::UTF32); +} + +////////////////////////////////////////////////////////////////////////// +bool NDArray::isR() const { + auto xType = ArrayOptions::dataType(this->_shapeInfo); + return xType == FLOAT32 || xType == HALF || xType == DOUBLE || + xType == FLOAT8 || xType == BFLOAT16; +} + +////////////////////////////////////////////////////////////////////////// +bool NDArray::isZ() const { + return !isC() && !isR() && !isB() && !isS(); +} + +////////////////////////////////////////////////////////////////////////// +bool NDArray::isB() const { + return ArrayOptions::dataType(this->_shapeInfo) == BOOL; +} + +////////////////////////////////////////////////////////////////////////// +template +std::string NDArray::toStringValue(T value) const { + std::ostringstream os; + // throw the value into the string stream + os << value; + // convert the string stream into a string and return + return os.str(); +} + +////////////////////////////////////////////////////////////////////////// +template <> +std::string NDArray::toStringValue(float16 value) const { + std::ostringstream os; + // throw the value into the string stream + os << (float)value; + // convert the string stream into a string and return + return os.str(); +} + +////////////////////////////////////////////////////////////////////////// +template <> +std::string NDArray::toStringValue(bfloat16 value) const { + std::ostringstream os; + // throw the value into the string stream + os << (float)value; + // convert the string stream into a string and return + return os.str(); +} + +////////////////////////////////////////////////////////////////////////// +std::string NDArray::asIndexedString(Nd4jLong limit) const { + std::ostringstream os; + os << "["; + if (limit < 1 || limit > this->lengthOf()) limit = this->lengthOf(); + for (Nd4jLong e = 0; e < limit; e++) { + os << toStringValue(this->e(e)); + if (e < limit - 1) os << ", "; + } + os << "]"; + return os.str(); +} + +////////////////////////////////////////////////////////////////////////// +std::string NDArray::asString(Nd4jLong limit) const { + std::ostringstream os; + os << "["; + if (limit < 1 || limit > this->lengthOf()) limit = this->lengthOf(); + for (Nd4jLong e = 0; e < limit; e++) { + if (this->isR()) + os << toStringValue(this->e(e)); + else if (this->isZ()) + os << toStringValue(this->e(e)); + else if (this->isB()) + os << toStringValue(this->e(e)); + else if (this->isS()) // todo add utf16 and utf32 + os << this->e(e); + if (e < limit - 1) os << ", "; + } + os << "]"; + return os.str(); +} + +//////////////////////////////////////////////////////////////////////// +template +std::vector NDArray::getBufferAsVector() const { + std::vector vector(lengthOf()); + for (Nd4jLong e = 0; e < lengthOf(); e++) vector[e] = this->e(e); + return vector; +} +BUILD_SINGLE_TEMPLATE(template SD_EXPORT std::vector, + NDArray::getBufferAsVector() const, LIBND4J_TYPES); + +//////////////////////////////////////////////////////////////////////// +std::vector NDArray::getShapeAsFlatVector() const { + std::vector vector(this->rankOf()); + for (int e = 0; e < this->rankOf(); e++) + vector[e] = static_cast(this->sizeAt(e)); + return vector; +} + +//////////////////////////////////////////////////////////////////////// +std::vector NDArray::getShapeAsVector() const { + std::vector vector(this->rankOf()); + for (int e = 0; e < this->rankOf(); e++) vector[e] = this->sizeAt(e); + + return vector; +} + +//////////////////////////////////////////////////////////////////////// +std::vector NDArray::getShapeAsVectorInt() const { + std::vector vector(this->rankOf()); + for (int e = 0; e < this->rankOf(); e++) + vector[e] = static_cast(this->sizeAt(e)); + + return vector; +} + +//////////////////////////////////////////////////////////////////////// +std::vector NDArray::getShapeInfoAsFlatVector() const { + int magicNumber = shape::shapeInfoLength(this->rankOf()); + std::vector vector(magicNumber); + + for (int e = 0; e < magicNumber; e++) + vector[e] = static_cast(_shapeInfo[e]); + + return vector; +} + +//////////////////////////////////////////////////////////////////////// +std::vector NDArray::getShapeInfoAsVector() const { + int magicNumber = shape::shapeInfoLength(this->rankOf()); + std::vector vector(magicNumber); + for (int e = 0; e < magicNumber; e++) vector[e] = this->_shapeInfo[e]; + return vector; +} + +//////////////////////////////////////////////////////////////////////// +std::vector NDArray::asByteVector() { + if (isS()) { + // string data type requires special treatment + syncToHost(); + auto numWords = this->lengthOf(); + auto offsetsBuffer = this->bufferAsT(); + auto headerLength = ShapeUtils::stringBufferHeaderRequirements(numWords); + auto dataLength = offsetsBuffer[numWords]; + std::vector result(headerLength + dataLength); + + memcpy(result.data(), buffer(), headerLength + dataLength); + + return result; + } else { + // all other types are linear + std::vector result((unsigned long long)this->lengthOf() * + sizeOfT()); + + if (this->isView()) { + auto tmp = this->dup(this->ordering()); + syncToHost(); + memcpy(result.data(), tmp.buffer(), + (unsigned long long)lengthOf() * sizeOfT()); + } else { + syncToHost(); + memcpy(result.data(), buffer(), + (unsigned long long)lengthOf() * sizeOfT()); + } + return result; + } +} + +////////////////////////////////////////////////////////////////////////// +void NDArray::linspace(const double start) { linspace(start, 1); } + +////////////////////////////////////////////////////////////////////////// +void NDArray::linspace(const double start, const double step) { + if (isS()) + throw std::runtime_error( + "NDArray::linspace: you can't use this method on String array!"); + Nd4jLong numElements = this->lengthOf(); + for (Nd4jLong e = 0; e < numElements; e++) this->p(e, start + (step * e)); +} + +//////////////////////////////////////////////////////////////////////// +void NDArray::streamline(char o) { + char order = o == 'a' ? this->ordering() : o; + syncToDevice(); + std::shared_ptr newBuffer = std::make_shared( + this->lengthOf() * sizeOfT(), dataType(), getContext()->getWorkspace()); + auto shapeBuffer = ConstantShapeHelper::getInstance().bufferForShapeInfo( + dataType(), order, rankOf(), shapeOf()); + NativeOpExecutioner::execTransformSame( + getContext(), transform::Copy, buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), newBuffer->primary(), + shapeBuffer.primary(), newBuffer->special(), + shapeBuffer.special(), nullptr, nullptr, nullptr); + + setShapeInfo(shapeBuffer); + _buffer = newBuffer; + _offset = 0; + tickWriteDevice(); +} + +//////////////////////////////////////////////////////////////////////// +// move assignment operator +NDArray& NDArray::operator=(NDArray&& other) noexcept { + if (this == &other) return *this; + + _isView = other._isView; + _buffer = other._buffer; + _shapeInfo = other._shapeInfo; + _shapeInfoD = other._shapeInfoD; + _context = other._context; + _dataType = other._dataType; + _length = other._length; + _offset = other._offset; + + other._buffer = std::make_shared(); + other._shapeInfo = other._shapeInfoD = nullptr; + other._length = 0; + + return *this; +} + +//////////////////////////////////////////////////////////////////////// +template +NDArray& NDArray::operator=(const T scalar) { + this->assign(scalar); + return *this; +} +template SD_EXPORT NDArray& NDArray::operator=(const double scalar); +template SD_EXPORT NDArray& NDArray::operator=(const float scalar); +template SD_EXPORT NDArray& NDArray::operator=(const float16 scalar); +template SD_EXPORT NDArray& NDArray::operator=(const bfloat16 scalar); +template SD_EXPORT NDArray& NDArray::operator=(const Nd4jLong scalar); +template SD_EXPORT NDArray& NDArray::operator=(const int scalar); +template SD_EXPORT NDArray& NDArray::operator=(const int8_t scalar); +template SD_EXPORT NDArray& NDArray::operator=(const uint8_t scalar); +template SD_EXPORT NDArray& NDArray::operator=(const uint16_t scalar); +template SD_EXPORT NDArray& NDArray::operator=(const uint32_t scalar); +template SD_EXPORT NDArray& NDArray::operator=(const uint64_t scalar); +template SD_EXPORT NDArray& NDArray::operator=(const int16_t scalar); +template SD_EXPORT NDArray& NDArray::operator=(const bool scalar); + +////////////////////////////////////////////////////////////////////////// +void NDArray::copyBuffersContinuouslyFrom(const NDArray& other, + size_t sizeToCopyInBytes, + Nd4jLong offsetThis, + Nd4jLong offsetOther) { + if (offsetThis == 0) offsetThis = bufferOffset(); + if (offsetOther == 0) offsetOther = other.bufferOffset(); + + dataBuffer()->copyBufferFrom(*other.getDataBuffer(), sizeToCopyInBytes, + offsetThis, offsetOther); +} + +//////////////////////////////////////////////////////////////////// +// This method assigns values of given NDArray to this one +void NDArray::assign(const NDArray& other, bool allowParallelism) { + if (this == &other) return; + + if (other.isEmpty()) { + if (!isEmpty()) { + throw std::runtime_error("Cannot assign empty array to non-empty array"); + } + return; + } + + if (isEmpty()) { + *this = other; + return; + } + + if (other.lengthOf() == 1) { + if (lengthOf() == 1) { + NDArray::preparePrimaryUse({this}, {&other}); + BUILD_DOUBLE_SELECTOR(dataType(), other.dataType(), templatedDoubleAssign, + (buffer(), 0, other.buffer(), 0), LIBND4J_TYPES, + LIBND4J_TYPES); + NDArray::registerPrimaryUse({this}, {&other}); + this->syncToDevice(); + } else { + if (dataType() != other.dataType()) { + auto tmp = other.cast(dataType()); + NDArray::prepareSpecialUse({this}, {&tmp}); + NativeOpExecutioner::execScalar( + getContext(), scalar::CopyPws, buffer(), shapeInfo(), + specialBuffer(), specialShapeInfo(), buffer(), shapeInfo(), + specialBuffer(), specialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), + tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr, + allowParallelism); + NDArray::registerSpecialUse({this}, {}); + } else { + NDArray::prepareSpecialUse({this}, {&other}); + NativeOpExecutioner::execScalar( + getContext(), scalar::CopyPws, buffer(), shapeInfo(), + specialBuffer(), specialShapeInfo(), buffer(), shapeInfo(), + specialBuffer(), specialShapeInfo(), other.buffer(), + other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), + nullptr, allowParallelism); + NDArray::registerSpecialUse({this}, {&other}); + } + } + } else { + if (other.lengthOf() != lengthOf()) { + auto shapeThis = ShapeUtils::shapeAsString(this); + auto shapeThat = ShapeUtils::shapeAsString(&other); + nd4j_printf("Can't assign array: this shape %s; other shape: %s\n", + shapeThis.c_str(), shapeThat.c_str()); + throw std::runtime_error( + "NDArray::assign: lengths of arrays are mismatched"); + } + + NDArray::prepareSpecialUse({this}, {&other}); + NativeOpExecutioner::execTransformAny( + getContext(), transform::Assign, other.buffer(), other.shapeInfo(), + other.specialBuffer(), other.specialShapeInfo(), buffer(), shapeInfo(), + specialBuffer(), specialShapeInfo(), nullptr, nullptr, nullptr, + allowParallelism); + NDArray::registerSpecialUse({this}, {&other}); + } +} + +////////////////////////////////////////////////////////////////////////// +// This method assigns values of given NDArray to this one, wrt order +void NDArray::assign(const NDArray* other, bool allowParallelism) { + assign(*other, allowParallelism); +} + +////////////////////////////////////////////////////////////////////////// +template +void NDArray::assign(const T& value, bool allowParallelism) { + // just fire scalar + auto temp = NDArrayFactory::create(dataType(), value, this->getContext()); + + NDArray::prepareSpecialUse({this}, {&temp}); + NativeOpExecutioner::execScalar( + getContext(), sd::scalar::CopyPws, buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), + temp.specialShapeInfo(), nullptr, allowParallelism); + NDArray::registerSpecialUse({this}, {&temp}); +} +template SD_EXPORT void NDArray::assign(const double& value, + bool allowParallelism); +template SD_EXPORT void NDArray::assign(const float& value, + bool allowParallelism); +template SD_EXPORT void NDArray::assign(const float16& value, + bool allowParallelism); +template SD_EXPORT void NDArray::assign(const bfloat16& value, + bool allowParallelism); +template SD_EXPORT void NDArray::assign(const Nd4jLong& value, + bool allowParallelism); +template SD_EXPORT void NDArray::assign(const int& value, + bool allowParallelism); +template SD_EXPORT void NDArray::assign(const int8_t& value, + bool allowParallelism); +template SD_EXPORT void NDArray::assign(const int16_t& value, + bool allowParallelism); +template SD_EXPORT void NDArray::assign(const uint8_t& value, + bool allowParallelism); +template SD_EXPORT void NDArray::assign(const uint16_t& value, + bool allowParallelism); +template SD_EXPORT void NDArray::assign(const uint32_t& value, + bool allowParallelism); +template SD_EXPORT void NDArray::assign(const uint64_t& value, + bool allowParallelism); +template SD_EXPORT void NDArray::assign(const bool& value, + bool allowParallelism); + +////////////////////////////////////////////////////////////////////////// +NDArray NDArray::detach() { + if (!isAttached()) return *this; + + std::shared_ptr newBuffer = + std::make_shared(lengthOf() * sizeOfT(), dataType()); + + NDArray result(newBuffer, + ShapeDescriptor(dataType(), ordering(), shapeOf(), rankOf())); + + result.assign(*this); + + return result; +} + +////////////////////////////////////////////////////////////////////////// +NDArray NDArray::varianceNumber(sd::variance::Ops op, bool biasCorrected) { + NDArray res(DataTypeUtils::pickFloatingType(dataType()), getContext()); + + NDArray::prepareSpecialUse({&res}, {this}); + NativeOpExecutioner::execSummaryStatsScalar( + getContext(), op, buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), nullptr, res.buffer(), res.shapeInfo(), + res.specialBuffer(), res.specialShapeInfo(), biasCorrected); + NDArray::registerSpecialUse({&res}, {this}); + + return res; +} + +////////////////////////////////////////////////////////////////////////// +// This method returns sum of all elements of this NDArray +NDArray NDArray::sumNumber() const { + if (isS()) + throw std::runtime_error( + "NDArray::sumNumber: you can't use this method on String array!"); + NDArray res(dataType(), getContext()); + + NDArray::prepareSpecialUse({&res}, {this}); + NativeOpExecutioner::execReduceSameScalar( + getContext(), sd::reduce::SameOps::Sum, buffer(), shapeInfo(), + specialBuffer(), specialShapeInfo(), nullptr, res.buffer(), + res.shapeInfo(), res.specialBuffer(), res.specialShapeInfo()); + NDArray::registerSpecialUse({&res}, {this}); + + return res; +} + +////////////////////////////////////////////////////////////////////////// +// This method returns mean number of this NDArray +NDArray NDArray::meanNumber() const { + if (isS()) + throw std::runtime_error( + "NDArray::meanNumber: you can't use this method on String array!"); + NDArray res(DataTypeUtils::pickFloatingType(dataType()), getContext()); + + NDArray::prepareSpecialUse({&res}, {this}); + NativeOpExecutioner::execReduceFloatScalar( + getContext(), sd::reduce::FloatOps::Mean, buffer(), shapeInfo(), + specialBuffer(), specialShapeInfo(), nullptr, res.buffer(), + res.shapeInfo(), res.specialBuffer(), res.specialShapeInfo()); + NDArray::registerSpecialUse({&res}, {this}); + return res; +} + +////////////////////////////////////////////////////////////////////////// +bool NDArray::hasNaNs() { + if (isS()) + throw std::runtime_error( + "NDArray::hasNaNs: you can't use this method on String array!"); + return this->reduceNumber(sd::reduce::IsNan, nullptr).e(0) > 0; +} + +////////////////////////////////////////////////////////////////////////// +bool NDArray::hasInfs() { + if (isS()) + throw std::runtime_error( + "NDArray::hasInfs: you can't use this method on String array!"); + return this->reduceNumber(sd::reduce::IsInf, nullptr).e(0) > 0; +} + +////////////////////////////////////////////////////////////////////////// +bool NDArray::isFinite() { + if (isS()) + throw std::runtime_error( + "NDArray::isFinite: you can't use this method on String array!"); + return this->reduceNumber(sd::reduce::IsInfOrNan, nullptr).e(0) == 0; +} + +////////////////////////////////////////////////////////////////////////// +template +void NDArray::templatedSet(void* buffer, const Nd4jLong* indices, + const void* value) { + auto t = reinterpret_cast(buffer); + const auto y = *(reinterpret_cast(value)); + + auto xOffset = shape::getOffset(shapeInfo(), indices); + t[xOffset] = static_cast(y); +} +BUILD_DOUBLE_TEMPLATE(template SD_EXPORT void NDArray::templatedSet, + (void* buffer, const Nd4jLong* indices, + const void* value), + LIBND4J_TYPES, LIBND4J_TYPES); + +////////////////////////////////////////////////////////////////////////// +template +void NDArray::templatedSet(void* buffer, const Nd4jLong offset, + const void* value) { + auto t = reinterpret_cast(buffer); + const auto y = *(reinterpret_cast(value)); + + t[offset] = static_cast(y); +} +BUILD_DOUBLE_TEMPLATE(template SD_EXPORT void NDArray::templatedSet, + (void* buffer, const Nd4jLong offset, const void* value), + LIBND4J_TYPES, LIBND4J_TYPES); + +////////////////////////////////////////////////////////////////////////// +void NDArray::setContext(sd::LaunchContext* context) { + _context = context; + if (getContext() == nullptr) + _context = sd::LaunchContext ::defaultContext(); // empty context for + // default cases +} + +////////////////////////////////////////////////////////////////////////// +void const* NDArray::bufferWithOffset(Nd4jLong offset) const { + return const_cast(buffer() != nullptr + ? static_cast(buffer()) + + (offset * sizeOfT()) + : nullptr); +} + +////////////////////////////////////////////////////////////////////////// +void* NDArray::bufferWithOffset(Nd4jLong offset) { + return const_cast(buffer() != nullptr + ? static_cast(buffer()) + + (offset * sizeOfT()) + : nullptr); +} + +////////////////////////////////////////////////////////////////////////// +// eventually method reduces array by excluding its shapes along axes present in +// dimensions vector +NDArray NDArray::reduceAlongDimension(sd::reduce::FloatOps op, + const std::vector& dimensions, + const bool keepDims, + const bool supportOldShapes) const { + std::vector copy(dimensions); + + auto newShape = ShapeUtils::evalReduceShapeInfo( + 'c', copy, *this, + isR() ? dataType() : Environment::getInstance().defaultFloatDataType(), + keepDims, supportOldShapes, getContext()->getWorkspace()); + + NDArray result(newShape, true, getContext()); + + this->reduceAlongDimension(op, result, copy, keepDims, supportOldShapes, + false); + + return result; +} + +////////////////////////////////////////////////////////////////////////// +NDArray NDArray::reduceAlongDimension(sd::reduce::SameOps op, + const std::vector& dimensions, + const bool keepDims, + const bool supportOldShapes) const { + std::vector copy(dimensions); + + auto newShape = ShapeUtils::evalReduceShapeInfo('c', copy, *this, keepDims, + supportOldShapes, + getContext()->getWorkspace()); + + NDArray result(newShape, true, getContext()); + + reduceAlongDimension(op, result, copy, keepDims, supportOldShapes, false); + + return result; +} + +////////////////////////////////////////////////////////////////////////// +NDArray NDArray::reduceAlongDimension(sd::reduce::BoolOps op, + const std::vector& dimensions, + const bool keepDims, + const bool supportOldShapes) const { + std::vector copy(dimensions); + + auto newShape = ShapeUtils::evalReduceShapeInfo( + 'c', copy, *this, DataType::BOOL, keepDims, supportOldShapes, + getContext()->getWorkspace()); + + NDArray result(newShape, true, getContext()); + + reduceAlongDimension(op, result, copy, keepDims, supportOldShapes, false); + + return result; +} + +////////////////////////////////////////////////////////////////////////// +NDArray NDArray::reduceAlongDimension(sd::reduce::LongOps op, + const std::vector& dimensions, + const bool keepDims, + const bool supportOldShapes) const { + std::vector copy(dimensions); + + auto newShape = ShapeUtils::evalReduceShapeInfo( + 'c', copy, *this, DataType::INT64, keepDims, supportOldShapes, + getContext()->getWorkspace()); + + NDArray result(newShape, true, getContext()); + + reduceAlongDimension(op, result, copy, keepDims, supportOldShapes, false); + + return result; +} + +////////////////////////////////////////////////////////////////////////// +// method reduces array by excluding its shapes along axes present in dimensions +// vector +NDArray NDArray::reduceAlongDimension( + sd::reduce::FloatOps op, const std::initializer_list& dimensions, + const bool keepDims, const bool supportOldShapes) const { + return reduceAlongDimension(op, std::vector(dimensions), keepDims, + supportOldShapes); +} + +////////////////////////////////////////////////////////////////////////// +NDArray NDArray::reduceAlongDimension( + sd::reduce::SameOps op, const std::initializer_list& dimensions, + const bool keepDims, const bool supportOldShapes) const { + return reduceAlongDimension(op, std::vector(dimensions), keepDims, + supportOldShapes); +} + +////////////////////////////////////////////////////////////////////////// +NDArray NDArray::reduceAlongDimension( + sd::reduce::BoolOps op, const std::initializer_list& dimensions, + const bool keepDims, const bool supportOldShapes) const { + return reduceAlongDimension(op, std::vector(dimensions), keepDims, + supportOldShapes); +} + +////////////////////////////////////////////////////////////////////////// +NDArray NDArray::reduceAlongDimension( + sd::reduce::LongOps op, const std::initializer_list& dimensions, + const bool keepDims, const bool supportOldShapes) const { + return reduceAlongDimension(op, std::vector(dimensions), keepDims, + supportOldShapes); +} + +////////////////////////////////////////////////////////////////////////// +NDArray NDArray::reduceNumber(sd::reduce::FloatOps op, + void* extraParams) const { + if (isS()) + throw std::runtime_error( + "NDArray::reduceNumber FloatOps: you can't use this method on String " + "array!"); + + auto shape = ConstantShapeHelper::getInstance().scalarShapeInfo( + DataTypeUtils::pickFloatingType(dataType())); + NDArray result(shape, true, this->getContext()); + + NDArray::prepareSpecialUse({&result}, {this}); + NativeOpExecutioner::execReduceFloatScalar( + getContext(), op, buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), extraParams, result.buffer(), result.shapeInfo(), + result.specialBuffer(), result.specialShapeInfo()); + NDArray::registerSpecialUse({&result}, {this}); + + return result; +} + +////////////////////////////////////////////////////////////////////////// +NDArray NDArray::reduceNumber(sd::reduce::SameOps op, void* extraParams) const { + if (isS()) + throw std::runtime_error( + "NDArray::reduceNumber SameOps: you can't use this method on String " + "array!"); + + NDArray result(dataType(), getContext()); + + NDArray::prepareSpecialUse({&result}, {this}); + NativeOpExecutioner::execReduceSameScalar( + getContext(), op, buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), extraParams, result.buffer(), result.shapeInfo(), + result.specialBuffer(), result.specialShapeInfo()); + NDArray::registerSpecialUse({&result}, {this}); + + return result; +} + +////////////////////////////////////////////////////////////////////////// +NDArray NDArray::reduceNumber(sd::reduce::BoolOps op, void* extraParams) const { + if (isS()) + throw std::runtime_error( + "NDArray::reduceNumber BoolOps: you can't use this method on String " + "array!"); + + auto shape = + ConstantShapeHelper::getInstance().scalarShapeInfo(DataType::BOOL); + NDArray result(shape, true, this->getContext()); + + NDArray::prepareSpecialUse({&result}, {this}); + NativeOpExecutioner::execReduceBoolScalar( + getContext(), op, buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), extraParams, result.buffer(), result.shapeInfo(), + result.specialBuffer(), result.specialShapeInfo()); + NDArray::registerSpecialUse({&result}, {this}); + + return result; +} + +////////////////////////////////////////////////////////////////////////// +NDArray NDArray::reduceNumber(sd::reduce::LongOps op, void* extraParams) const { + if (isS()) + throw std::runtime_error( + "NDArray::reduceNumber LongOps: you can't use this method on String " + "array!"); + + auto shape = + ConstantShapeHelper::getInstance().scalarShapeInfo(DataType::INT64); + NDArray result(shape, true, this->getContext()); + + NDArray::prepareSpecialUse({&result}, {this}); + NativeOpExecutioner::execReduceLongScalar( + getContext(), op, buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), extraParams, result.buffer(), result.shapeInfo(), + result.specialBuffer(), result.specialShapeInfo()); + NDArray::registerSpecialUse({&result}, {this}); + + return result; +} + +////////////////////////////////////////////////////////////////////////// +void NDArray::reduceNumber(sd::reduce::FloatOps op, NDArray& target, + void* extraParams) const { + if (isS()) + throw std::runtime_error( + "NDArray::reduceNumber FloatOps: you can't use this method on String " + "array!"); + if (target.lengthOf() != 1 || + target.dataType() != DataTypeUtils::pickFloatingType(dataType())) + throw std::invalid_argument( + "NDArray::reduceNumber FloatOps: target array should be scalar and " + "have corresponding float type!"); + + NDArray::prepareSpecialUse({&target}, {this}); + NativeOpExecutioner::execReduceFloatScalar( + getContext(), op, buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), extraParams, target.buffer(), target.shapeInfo(), + target.specialBuffer(), target.specialShapeInfo()); + NDArray::registerSpecialUse({&target}, {this}); +} + +////////////////////////////////////////////////////////////////////////// +void NDArray::reduceNumber(sd::reduce::SameOps op, NDArray& target, + void* extraParams) const { + if (isS()) + throw std::runtime_error( + "NDArray::reduceNumber SameOps: you can't use this method on String " + "array!"); + if (target.lengthOf() != 1 || target.dataType() != dataType()) + throw std::invalid_argument( + "NDArray::reduceNumber SameOps: target array should be scalar and have " + "same type as this array!"); + + NDArray::prepareSpecialUse({&target}, {this}); + NativeOpExecutioner::execReduceSameScalar( + getContext(), op, buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), extraParams, target.buffer(), target.shapeInfo(), + target.specialBuffer(), target.specialShapeInfo()); + NDArray::registerSpecialUse({&target}, {this}); +} + +////////////////////////////////////////////////////////////////////////// +void NDArray::reduceNumber(sd::reduce::BoolOps op, NDArray& target, + void* extraParams) const { + if (isS()) + throw std::runtime_error( + "NDArray::reduceNumber BoolOps: you can't use this method on String " + "array!"); + if (target.lengthOf() != 1 || target.dataType() != DataType::BOOL) + throw std::invalid_argument( + "NDArray::reduceNumber BoolOps: target array should be scalar and have " + "bool type!"); + + NDArray::prepareSpecialUse({&target}, {this}); + NativeOpExecutioner::execReduceBoolScalar( + getContext(), op, buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), extraParams, target.buffer(), target.shapeInfo(), + target.specialBuffer(), target.specialShapeInfo()); + NDArray::registerSpecialUse({&target}, {this}); +} + +////////////////////////////////////////////////////////////////////////// +void NDArray::reduceNumber(sd::reduce::LongOps op, NDArray& target, + void* extraParams) const { + if (isS()) + throw std::runtime_error( + "NDArray::reduceNumber LongOps: you can't use this method on String " + "array!"); + if (target.lengthOf() != 1 || target.dataType() != DataType::INT64) + throw std::invalid_argument( + "NDArray::reduceNumber LongOps: target array should be scalar and have " + "long type!"); + + NDArray::prepareSpecialUse({&target}, {this}); + NativeOpExecutioner::execReduceLongScalar( + getContext(), op, buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), extraParams, target.buffer(), target.shapeInfo(), + target.specialBuffer(), target.specialShapeInfo()); + NDArray::registerSpecialUse({&target}, {this}); +} + +////////////////////////////////////////////////////////////////////////// +NDArray NDArray::indexReduceNumber(sd::indexreduce::Ops op, + ExtraArguments* extraParams) { + if (isS()) + throw std::runtime_error( + "NDArray::indexReduceNumber: you can't use this method on String " + "array!"); + + auto res = NDArrayFactory::create(0); + + NDArray::NDArray::prepareSpecialUse({&res}, {this}); + NativeOpExecutioner::execIndexReduceScalar( + getContext(), op, buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), + extraParams == nullptr ? nullptr + : extraParams->argumentsAsT(this->dataType()), + res.buffer(), res.shapeInfo(), res.specialBuffer(), + res.specialShapeInfo()); + NDArray::NDArray::registerSpecialUse({&res}, {this}); + + return res; +} + +////////////////////////////////////////////////////////////////////////// +Nd4jLong NDArray::tensorsAlongDimension( + std::initializer_list dimensions) const { + return tensorsAlongDimension(std::vector(dimensions)); +} + +////////////////////////////////////////////////////////////////////////// +Nd4jLong NDArray::tensorsAlongDimension( + const std::vector& dimensions) const { + std::vector copy(dimensions); + shape::checkDimensions(rankOf(), copy); + + Nd4jLong tadLength = + shape::tadLength(this->_shapeInfo, copy.data(), copy.size()); + Nd4jLong numTads = this->lengthOf() / tadLength; + + return numTads; +} + +////////////////////////////////////////////////////////////////////////// +void NDArray::printShapeInfo(const char* msg) const { + int rank = shape::rank(_shapeInfo); + int lim = shape::shapeInfoLength(rank); + + if (msg != nullptr) + printf("shapeInfo %s: [", msg); + else + printf("shapeInfo: ["); + + printf("%i, ", rank); + for (int i = 1; i < shape::shapeInfoLength(rank) - 3; i++) { + if (i == rank + 1) printf(" "); + printf("%lld,", _shapeInfo[i]); + } + printf(" %lld,", shape::type(_shapeInfo)); + printf("%lld,", shape::elementWiseStride(_shapeInfo)); + printf("%lld]\n", (Nd4jLong)shape::order(_shapeInfo)); + + fflush(stdout); +} + +static std::string formattedString(NDArray const* arr, int depth, int limit, std::stringstream& ss); + +////////////////////////////////////////////////////////////////////////// +void NDArray::printBuffer(const char* msg, Nd4jLong limit, + const bool sync) const { + if (sync) syncToHost(); + + if (limit == -1) limit = (int)this->lengthOf(); + + if (msg != nullptr) + printf("%s: [", msg); + else + printf("["); + if (this->isR()) { + for (Nd4jLong e = 0; e < limit; e++) { + if (e) printf(", "); + printf("%f", this->e(e)); + } + } else if (this->isZ()) { + for (Nd4jLong e = 0; e < limit; e++) { + if (this->dataType() != sd::DataType::INT64 && + this->dataType() != sd::DataType::UINT64) + printf("%d", this->e(e)); + else + printf("%llu", this->e(e)); + if (e < limit - 1) printf(", "); + } + } else if (this->isB()) { + for (Nd4jLong e = 0; e < limit; e++) { + if (this->e(e)) + printf("true"); + else + printf("false"); + if (e < limit - 1) printf(", "); + } + } else if (this->isS()) { + // todo do we need this print offsets + /* + for (Nd4jLong e = 0; e < limit; e++) { + printf("\"%lld\"", this->getOffset(e)); + if (e < limit - 1) + printf(", "); + } + printf("]\n["); + */ + for (Nd4jLong e = 0; e < limit; e++) { + printf("\"%s\"", this->e(e).c_str()); + if (e < limit - 1) printf(", "); + } + } + printf("]\n"); + fflush(stdout); +} + +////////////////////////////////////////////////////////////////////////// +// print element by element consequently in a way they (elements) are stored in +// physical memory +void NDArray::printLinearBuffer() const { + syncToHost(); + + const auto ews = this->ews() > 0 ? this->ews() : 1; + const auto len = this->lengthOf(); + + printf("["); + + if (this->dataType() == sd::DataType::INT32) { + for (Nd4jLong e = 0; e < len; e++) + printf("%d, ", this->bufferAsT()[e * ews]); + } else if (this->dataType() == sd::DataType::INT64) { + for (Nd4jLong e = 0; e < len; e++) + printf("%lld, ", this->bufferAsT()[e * ews]); + } else if (this->dataType() == sd::DataType::FLOAT32) { + for (Nd4jLong e = 0; e < len; e++) + printf("%.3f, ", this->bufferAsT()[e * ews]); + } else if (this->dataType() == sd::DataType::DOUBLE) { + for (Nd4jLong e = 0; e < len; e++) + printf("%.3f, ", this->bufferAsT()[e * ews]); + } else + throw std::invalid_argument( + "NDArray::printLinearBuffer: not implemented yet for this data type !"); + + printf("]\n"); + fflush(stdout); +} + + std::string NDArray::linearString(Nd4jLong limit) const { + syncToHost(); + + const auto ews = this->ews() > 0 ? this->ews() : 1; + const auto len = this->lengthOf(); + std::stringstream ss; + ss << "["; + + for (Nd4jLong e = 0; e < len; e++) { + if (e) + ss << ", "; + switch (this->dataType()) { + case sd::DataType::INT32: + ss << this->bufferAsT()[e * ews]; + break; + case sd::DataType::INT64: + ss << this->bufferAsT()[e * ews]; + break; + case sd::DataType::FLOAT32: + ss << std::setprecision(6) << this->bufferAsT()[e * ews]; + break; + case sd::DataType::DOUBLE: + ss << std::setprecision(6) << this->bufferAsT()[e * ews]; + break; + //case sd::DataType::UTF8: ss << this->bufferAsT()[e * ews]; break; + default: + throw std::invalid_argument("NDArray::linearString: not implemented yet for this data type !"); + } + + } + ss << "]"; + return ss.str(); + } + +////////////////////////////////////////////////////////////////////////// +static void printFormatted(NDArray const* arr, int depth, int limit) { + if (arr->rankOf() == 1) { + printf("[ "); + for (Nd4jLong i = 0; i < arr->lengthOf(); ++i) { + if (arr->isR()) + printf("%f, ", arr->e(i)); + else if (arr->isZ()) + printf("%lld, ", arr->e(i)); + else if (arr->isB()) + printf("%s, ", arr->e(i) ? "true" : "false"); + else if (arr->isS()) { + printf("\"%s\", ", arr->e(i).c_str()); + } + } + printf("]\n"); + } else if (arr->rankOf() == 2) { + Nd4jLong rows = arr->rows(); + Nd4jLong cols = arr->columns(); + char* padding = new char[depth + 1]; + memset(padding, ' ', depth); + padding[depth] = 0; + printf("["); + for (Nd4jLong row = 0; row < rows; ++row) { + if (row && depth > 0) printf("%s", padding); + printf("["); + Nd4jLong colLimit = cols > limit ? cols : limit; + for (Nd4jLong col = 0; col < colLimit; ++col) { + if (col) printf(", "); + if (arr->isR()) + printf("%f", arr->e(row, col)); + else if (arr->isZ()) + printf("%lld", arr->e(row, col)); + else if (arr->isB()) + printf("%s", arr->e(row, col) ? "true" : "false"); + else if (arr->isS()) { + printf("\"%s\"", arr->e(row * cols + col).c_str()); + } + } + if (row < rows - 1) + printf("]\n"); + else + printf("]"); + } + printf("]"); + if (padding) delete[] padding; + } else { + // std::unique_ptr arrs(arr->allTensorsAlongDimension({0})); + size_t restCount = 2; + printf("["); + restCount = ShapeUtils::getNumOfSubArrs(arr->shapeInfo(), {0}); + for (size_t arrIndex = 0; arrIndex < restCount; ++arrIndex) { + NDArray subArr = (*arr)(arrIndex, {0}); + printFormatted(&subArr, depth + 1, limit); + if (arrIndex < restCount - 1) { + for (Nd4jLong i = 1; i < arr->rankOf(); ++i) printf("\n"); + for (Nd4jLong i = 0; i < depth - 2; ++i) printf(" "); + } + } + printf("]"); + } +} + + static std::string formattedString(NDArray const* arr, int depth, int limit, std::stringstream& ss) { + + if (arr->rankOf() == 1) { + ss << "[ "; + for (Nd4jLong i = 0; i < arr->lengthOf(); ++i) { + if (arr->isR()) + ss << arr->e(i); + else if (arr->isZ()) + ss << arr->e(i); + else if (arr->isB()) + ss << (arr->e(i) ? "true" : "false"); + else if (arr->isS()) { + ss << "\"" << arr->e(i).c_str() << "\""; + } + } + ss << "]"; + } else if (arr->rankOf() == 2) { + Nd4jLong rows = arr->rows(); + Nd4jLong cols = arr->columns(); + //memset(padding, ' ', depth); + ss << "["; + for (Nd4jLong row = 0; row < rows; ++row) { + if (row && depth > 0) + ss << std::setfill(' ') << std::setw(depth) << ' '; + ss << "["; + Nd4jLong colLimit = cols > limit ? cols : limit; + for (Nd4jLong col = 0; col < colLimit; ++col) { + if (col) ss << (", "); + if (arr->isR()) + ss << std::setw(12) << std::setprecision(6) << arr->e(row, col); + else if (arr->isZ()) + ss << arr->e(row, col); + else if (arr->isB()) + ss << (arr->e(row, col) ? "true" : "false"); + else if (arr->isS()) { + ss << "\"" << arr->e(row * cols + col).c_str() <<"\""; + } + } + if (row < rows - 1) + ss << "]" << std::endl; + else + ss << "]"; + } + ss << "]"; + } else { + // std::unique_ptr arrs(arr->allTensorsAlongDimension({0})); + size_t restCount = 2; + ss << "["; + restCount = ShapeUtils::getNumOfSubArrs(arr->shapeInfo(), {0}); + for (size_t arrIndex = 0; arrIndex < restCount; ++arrIndex) { + NDArray subArr = (*arr)(arrIndex, {0}); + formattedString(&subArr, depth + 1, limit, ss); + if (arrIndex < restCount - 1) { + for (Nd4jLong i = 1; i < arr->rankOf(); ++i) printf("\n"); + for (Nd4jLong i = 0; i < depth - 2; ++i) printf(" "); + } + } + ss << "]"; + } + return ss.str(); + } + +////////////////////////////////////////////////////////////////////////// + void NDArray::printIndexedBuffer(const char* msg, Nd4jLong limit) const { + auto indexedString = indexedBufferString(limit); + if (msg) + printf("%s:\n%s\n", msg, indexedString.c_str()); + else + printf("%s\n", indexedString.c_str()); + fflush(stdout); + } +////////////////////////////////////////////////////////////////////////// +std::string NDArray::indexedBufferString(Nd4jLong limit) const { + syncToHost(); + std::string output; + Nd4jLong rank = this->rankOf(); + + bool rowFlag = (rank < 2) || (rank == 2 && this->sizeAt(0) == 1); + + if (this->isEmpty()) { + return std::string("Empty"); + } else if (this->rankOf() == 0) { + std::stringstream ss; + if (this->isZ()) + ss << this->e(0); + else if (this->isR()) + ss << this->e(0); + else if (this->isB()) { + ss << (this->e(0) ? "true" : "false"); + } else if (this->isS()) { + // todo do we need this + // printf("\"%lld\"\n", this->getOffset(e)); + ss << "\"" << this->e(0) << "\n"; + } + return ss.str(); + } else if (rowFlag && ews() == 1) + return linearString(limit); + else { + std::stringstream ss; + return formattedString(this, 1, limit, ss); + } +} + +////////////////////////////////////////////////////////////////////////// +template +void* NDArray::templatedPointerShift(const Nd4jLong offset) const { + return const_cast(reinterpret_cast(buffer()) + offset); +} +BUILD_SINGLE_TEMPLATE(template SD_EXPORT void* NDArray::templatedPointerShift, + (const Nd4jLong offset) const, LIBND4J_TYPES); + +////////////////////////////////////////////////////////////////////////// +// method makes copy of this array and applies to the copy transpose operation, +// this array remains unaffected +NDArray NDArray::transpose() const& { + NDArray newArr(getDataBuffer(), ShapeDescriptor(shapeInfo()), getContext(), + bufferOffset()); + newArr.transposei(); + + return newArr; +} + +////////////////////////////////////////////////////////////////////////// +// method makes copy of this array and applies to the copy transpose operation, +// this array remains unaffected +NDArray NDArray::transpose() && { + this->transposei(); + return std::move(*this); +} + +//////////////////////////////////////////////////////////////////////// +// method performs transpose operation based on this array and store result in +// target, this array remains unaffected +void NDArray::transpose(NDArray& target) const { + auto correctShape = + ShapeUtils::evalTranspShapeInfo(*this, getContext()->getWorkspace()); + if (!shape::equalsStrict(correctShape, target.shapeInfo())) + throw std::runtime_error( + "NDArray::transpose method: the shapeInfo of target array is wrong !"); + + target._buffer = _buffer; + target._offset = _offset; + target._isView = true; +} + +//////////////////////////////////////////////////////////////////////// +// This method applies in-place transpose to this array, so this array becomes +// transposed +void NDArray::transposei() { + std::vector perm; + for (int e = this->rankOf() - 1; e >= 0; e--) perm.emplace_back(e); + + this->permutei(perm); +} + +//////////////////////////////////////////////////////////////////////// +bool NDArray::equalsTo(const NDArray& other, double eps) const { + return equalsTo(&other, eps); +} + +////////////////////////////////////////////////////////////////////////// +void NDArray::setAttached(bool reallyAttached) { + _isAttached = reallyAttached; +}; + +////////////////////////////////////////////////////////////////////////// +// calculate strides +void NDArray::updateStrides(const char order) { + throw std::runtime_error("Very bad method was invoked"); +} + +////////////////////////////////////////////////////////////////////////// +// set new order and shape in case of suitable array length +bool NDArray::reshapei(const char order, + const std::initializer_list& shape, + const bool copyToNewBuff) { + std::vector vShape(shape); + return reshapei(order, vShape, copyToNewBuff); +} + +////////////////////////////////////////////////////////////////////////// +bool NDArray::reshapei(const std::initializer_list& shape, + const bool copyToNewBuff) { + return reshapei(ordering(), shape, copyToNewBuff); +} + +////////////////////////////////////////////////////////////////////////// +bool NDArray::reshapei(const std::vector& shape, + const bool copyToNewBuff) { + return reshapei(ordering(), shape, copyToNewBuff); +} + +////////////////////////////////////////////////////////////////////////// +void NDArray::enforce(const std::initializer_list& dimensions, + char order) { + std::vector dims(dimensions); + enforce(dims, order); +} + +////////////////////////////////////////////////////////////////////////// +void NDArray::enforce(std::vector& dimensions, char o) { + Nd4jLong prod = 1; + for (int e = 0; e < dimensions.size(); e++) prod *= dimensions[e]; + + if (prod != this->lengthOf()) { + std::string current = ShapeUtils::shapeAsString(this); + std::string enforced = ShapeUtils::shapeAsString(dimensions); + nd4j_printf( + "Can't enforce new shape, lengths mismatch. Original shape: %s; " + "Requested shape: %s\n", + current.c_str(), enforced.c_str()); + throw std::runtime_error("Incompatible shape"); + } + + char order = o == 'a' ? this->ordering() : o; + setShapeInfo(ShapeDescriptor(dataType(), order, dimensions)); +} + +////////////////////////////////////////////////////////////////////////// +Nd4jLong NDArray::argMax(std::initializer_list dimensions) { + if (isS()) + throw std::runtime_error( + "NDArray::argMax: you can't use this method on String array!"); + + if (dimensions.size() == 0) { + Nd4jLong max = 0; + auto mv = -DataTypeUtils::max(); + for (Nd4jLong e = 0; e < this->lengthOf(); e++) { + auto val = this->e(e); + if (mv < val) { + mv = val; + max = e; + } + } + return max; + } else + throw std::runtime_error("NDArray::argMax() - Not implemented yet"); +} + +////////////////////////////////////////////////////////////////////////// +// create new array with corresponding order and shape, new array will point to +// the same _buffer as this array +NDArray NDArray::reshape(const char order, const std::vector& shape, + const bool copyToNewBuff) const& { + NDArray newArr(getDataBuffer(), ShapeDescriptor(shapeInfo()), getContext(), + bufferOffset()); + newArr.reshapei(order, shape, copyToNewBuff); + + return newArr; +} + +////////////////////////////////////////////////////////////////////////// +NDArray NDArray::reshape(const char order, const std::vector& shape, + const bool copyToNewBuff) && { + this->reshapei(order, shape, copyToNewBuff); + return std::move(*this); +} + +////////////////////////////////////////////////////////////////////////// +// change an array by repeating it the number of times given by reps. +void NDArray::tilei(const std::vector& reps) { + *this = this->tile(reps); +} + +////////////////////////////////////////////////////////////////////////// +Nd4jLong NDArray::sizeAt(const int dim) const { + if (dim >= this->rankOf() || dim < -this->rankOf()) + throw std::runtime_error("NDArray::sizeAt: bad size index requested"); + + if (dim >= 0) + return shape::shapeOf(_shapeInfo)[dim]; + else + return shape::shapeOf(_shapeInfo)[this->rankOf() + dim]; +} + +////////////////////////////////////////////////////////////////////////// +Nd4jLong NDArray::strideAt(const int dim) const { + if (dim >= this->rankOf() || dim < -this->rankOf()) + throw std::runtime_error("NDArray::strideAt: Bad size index requested"); + + if (dim >= 0) + return shape::stride(_shapeInfo)[dim]; + else + return shape::stride(_shapeInfo)[this->rankOf() + dim]; +} + +////////////////////////////////////////////////////////////////////////// +bool NDArray::permutei(const std::initializer_list& dimensions) { + std::vector vec(dimensions); + return permutei(vec); +} + +////////////////////////////////////////////////////////////////////////// +bool NDArray::permutei(const std::vector& dimensions) { + return permutei(dimensions.data(), rankOf()); +} + +////////////////////////////////////////////////////////////////////////// +bool NDArray::permutei(const std::initializer_list& dimensions) { + std::vector vec(dimensions); + std::vector ivec(dimensions.size()); + + for (int e = 0; e < vec.size(); e++) ivec[e] = static_cast(vec[e]); + + return permutei(ivec); +} + +////////////////////////////////////////////////////////////////////////// +bool NDArray::permutei(const std::vector& dimensions) { + std::vector ivec(dimensions.size()); + + for (int e = 0; e < dimensions.size(); e++) ivec[e] = dimensions[e]; + + return permutei(ivec.data(), rankOf()); +} + +////////////////////////////////////////////////////////////////////////// +NDArray NDArray::permute(const int* dimensions, const int rank) const& { + // evaluate shapeInfo for output (permuted) array ret + auto shapeInfoPermuted = ShapeUtils::evalPermShapeInfo( + dimensions, rank, *this, getContext()->getWorkspace()); + NDArray ret(getDataBuffer(), ShapeDescriptor(shapeInfoPermuted), getContext(), + bufferOffset()); + ret._isView = true; + return ret; +} + +////////////////////////////////////////////////////////////////////////// +NDArray NDArray::permute(const int* dimensions, const int rank) && { + this->permutei(dimensions, rank); + return std::move(*this); +} + +///////////////////////////////////////////////////////////////////////// +NDArray NDArray::permute(const Nd4jLong* dimensions, const int rank) const& { + int tempDims[MAX_RANK]; + shape::convertT(const_cast(dimensions), tempDims, + rank); + return permute(tempDims, rank); +} + +///////////////////////////////////////////////////////////////////////// +NDArray NDArray::permute(const Nd4jLong* dimensions, const int rank) && { + this->permutei(dimensions, rank); + return std::move(*this); +} + +////////////////////////////////////////////////////////////////////////// +NDArray NDArray::permute(const std::vector& dimensions) const& { + return permute(dimensions.data(), rankOf()); +} + +////////////////////////////////////////////////////////////////////////// +NDArray NDArray::permute(const std::vector& dimensions) && { + this->permutei(dimensions); + return std::move(*this); +} + +////////////////////////////////////////////////////////////////////////// +NDArray NDArray::permute(const std::vector& dimensions) const& { + return permute(dimensions.data(), rankOf()); +} + +////////////////////////////////////////////////////////////////////////// +NDArray NDArray::permute(const std::vector& dimensions) && { + this->permutei(dimensions); + return std::move(*this); +} + +////////////////////////////////////////////////////////////////////////// +NDArray NDArray::permute(const std::initializer_list& dimensions) const& { + std::vector vec(dimensions); + return permute(vec); +} + +////////////////////////////////////////////////////////////////////////// +NDArray NDArray::permute(const std::initializer_list& dimensions) && { + this->permutei(dimensions); + return std::move(*this); +} + +////////////////////////////////////////////////////////////////////////// +NDArray NDArray::permute( + const std::initializer_list& dimensions) const& { + std::vector vec(dimensions); + return permute(vec); +} + +////////////////////////////////////////////////////////////////////////// +NDArray NDArray::permute(const std::initializer_list& dimensions) && { + this->permutei(dimensions); + return std::move(*this); +} + +////////////////////////////////////////////////////////////////////////// +void NDArray::permute(const int* dimensions, const int rank, + NDArray& target) const { + if (!nonNull() || !target.nonNull() || rank != rankOf() || + rank != target.rankOf()) + throw std::runtime_error( + "NDArray::permute method: either arrays are nullptr or ranks are " + "not suitable!"); + + auto shapeInfoNew = ShapeUtils::evalPermShapeInfo( + dimensions, rank, *this, target.getContext()->getWorkspace()); + + target.setShapeInfo(shapeInfoNew); + target._buffer = _buffer; + target._offset = _offset; +} + +////////////////////////////////////////////////////////////////////////// +void NDArray::permute(const Nd4jLong* dimensions, const int rank, + NDArray& target) const { + if (!nonNull() || !target.nonNull() || rank != rankOf() || + rank != target.rankOf()) + throw std::runtime_error( + "NDArray::permute method: either arrays are nullptr or ranks are " + "not suitable!"); + + auto shapeInfoNew = ShapeUtils::evalPermShapeInfo( + dimensions, rank, *this, target.getContext()->getWorkspace()); + + target.setShapeInfo(shapeInfoNew); + target._buffer = _buffer; + target._offset = _offset; +} + +////////////////////////////////////////////////////////////////////////// +void NDArray::permute(const std::vector& dimensions, + NDArray& target) const { + permute(dimensions.data(), rankOf(), target); +} + +////////////////////////////////////////////////////////////////////////// +void NDArray::permute(const std::vector& dimensions, + NDArray& target) const { + permute(dimensions.data(), rankOf(), target); +} + +////////////////////////////////////////////////////////////////////////// +// check whether array is identity matrix +bool NDArray::isIdentityMatrix() { + if (isS()) + throw std::runtime_error( + "NDArray::isIdentityMatrix: you can't use this method on String " + "array!"); + if (rankOf() != 2 || rows() != columns()) + throw std::runtime_error( + "isIdentityMatrix method: matrix must be square and have rank = 2 !"); + + const double eps = 1e-5f; + for (Nd4jLong i = 0; i < rows(); ++i) + if (sd::math::nd4j_abs(e(i, i) - 1.f) > eps) return false; + + for (Nd4jLong i = 0; i < rows(); ++i) { + for (Nd4jLong j = 0; j < columns(); ++j) { + if (i == j) continue; + if (sd::math::nd4j_abs(e(i, j)) > eps) return false; + } + } + return true; +} + +////////////////////////////////////////////////////////////////////////// +// check whether array is unitary matrix +bool NDArray::isUnitary() { + if (isS()) + throw std::runtime_error( + "NDArray::isUnitary: you can't use this method on String array!"); + if (rankOf() != 2 || rows() != columns()) + throw std::runtime_error( + "isUnitary method: matrix must be square and have rank = 2 !"); + + auto tr = this->transpose(); + auto trMul = MmulHelper::mmul(this, &tr, nullptr, 1.f, 0.f); + + bool result = trMul->isIdentityMatrix(); + delete trMul; + + return result; +} + +////////////////////////////////////////////////////////////////////////// +template <> +const std::string* SD_EXPORT NDArray::bufferAsT() const { + throw std::runtime_error("This method is NOT supposed to be used"); +} + +////////////////////////////////////////////////////////////////////////// +template +const T* NDArray::bufferAsT() const { + // FIXME: do we REALLY want sync here? + syncToHost(); + + return reinterpret_cast(buffer()); +} +BUILD_SINGLE_UNCHAINED_TEMPLATE(template SD_EXPORT const, + *NDArray::bufferAsT() const, LIBND4J_TYPES); + +template +T* NDArray::bufferAsT() { + syncToHost(); + return reinterpret_cast(buffer()); +} +BUILD_SINGLE_UNCHAINED_TEMPLATE(template SD_EXPORT, *NDArray::bufferAsT(), + LIBND4J_TYPES); + +//////////////////////////////////////////////////////////////////////// +NDArray NDArray::subarray(IndicesList& idx) const { + const int idxSize = idx.size(); + if (idxSize != this->rankOf()) + throw std::runtime_error( + "NDArray::subarray: number of indices should match"); + + std::vector indexes(3 * idxSize); + + // convert IndicesList to vector + for (int d = 0; d < idxSize; ++d) { + if (idx.at(d)->isAll()) { + indexes[3 * d] = 0; // first + indexes[3 * d + 1] = 0; // last + indexes[3 * d + 2] = 1; // stride + } else if (idx.at(d)->isPoint()) { + indexes[3 * d] = idx.at(d)->getIndices().at(0); // first + indexes[3 * d + 1] = indexes[3 * d] + 1; // last + indexes[3 * d + 2] = 1; // stride + } else if (idx.at(d)->isInterval()) { + indexes[3 * d] = idx.at(d)->getIndices().at(0); // first + indexes[3 * d + 1] = idx.at(d)->getIndices().size(); // last + indexes[3 * d + 2] = idx.at(d)->stride(); // stride + } else { + indexes[3 * d] = idx.at(d)->getIndices().at(0); // first + indexes[3 * d + 1] = idx.at(d)->getIndices().at(1); // last + indexes[3 * d + 2] = idx.at(d)->getIndices().at(2); // stride + } + } + return NDArray((*this)(indexes, true, true)); +} + +//////////////////////////////////////////////////////////////////////// +NDArray NDArray::subarray(const std::initializer_list& idx) const { + const int idxSize = idx.size(); + if (idxSize != this->rankOf()) + throw std::runtime_error( + "NDArray::subarray: number of indices should match the array rank"); + + std::vector indexes(3 * idxSize); + + // convert NDIndex to vector + int d = 0; + for (const auto& item : idx) { + if (item->isAll()) { + indexes[3 * d] = 0; // first + indexes[3 * d + 1] = 0; // last + indexes[3 * d + 2] = 1; // stride + } else if (item->isPoint()) { + indexes[3 * d] = item->getIndices().at(0); // first + indexes[3 * d + 1] = indexes[3 * d] + 1; // last + indexes[3 * d + 2] = 1; // stride + } else if (item->isInterval()) { + indexes[3 * d] = item->getIndices().at(0); // first + indexes[3 * d + 1] = item->getIndices().size(); // last + indexes[3 * d + 2] = item->stride(); // stride + } else { + indexes[3 * d] = item->getIndices().at(0); // first + indexes[3 * d + 1] = item->getIndices().at(1); // last + indexes[3 * d + 2] = item->getIndices().at(2); // stride + } + ++d; + } + + // release NDIndices + for (auto i : idx) delete i; + + return NDArray((*this)(indexes, true, true)); +} + +//////////////////////////////////////////////////////////////////////// +NDArray NDArray::subarray(const Intervals& idx) const { + const int idxSize = idx.size(); + if (idxSize != this->rankOf()) + throw std::runtime_error( + "NDArray::subarray: number of indices should match the rank of array!"); + + std::vector indexes(2 * idxSize); + + // convert Intervals to vector + for (int d = 0; d < idxSize; ++d) { + if (idx[d].empty()) { + indexes[2 * d] = 0; // first + indexes[2 * d + 1] = 0; // last + } else { + indexes[2 * d] = idx[d][0]; // first + indexes[2 * d + 1] = idx[d][1]; // last + } + } + + return NDArray((*this)(indexes, true)); +} + +////////////////////////////////////////////////////////////////////////// +template +NDArray NDArray::asT() const { + auto result = isScalar() + ? NDArray('c', {}, std::vector{0.}, + DataTypeUtils::fromT(), this->getContext()) + : NDArray(ordering(), getShapeAsVector(), + DataTypeUtils::fromT(), this->getContext()); + + NDArray::prepareSpecialUse({&result}, {this}); + NativeOpExecutioner::execTransformAny( + getContext(), transform::AnyOps::Assign, buffer(), shapeInfo(), + specialBuffer(), specialShapeInfo(), result.buffer(), result.shapeInfo(), + result.specialBuffer(), result.specialShapeInfo(), nullptr, nullptr, + nullptr); + NDArray::registerSpecialUse({&result}, {this}); + + return result; +} +BUILD_SINGLE_TEMPLATE(template SD_EXPORT NDArray NDArray::asT, () const, + LIBND4J_TYPES); + +////////////////////////////////////////////////////////////////////////// +template +NDArray NDArray::asS() const { + if (!isS()) + throw std::runtime_error( + "NDArray::asS: you can use this method only for String array!"); + + auto dtype = DataTypeUtils::fromT(); + + if (!(DataTypeUtils::isS(dtype))) + throw std::invalid_argument("NDArray::asS: invalid DataType used"); + + if (dtype == dataType()) { + Nd4jLong offsetsLength = + ShapeUtils::stringBufferHeaderRequirements(lengthOf()); + const auto nInputoffsets = bufferAsT(); + std::shared_ptr pBuffer = + std::make_shared(offsetsLength + nInputoffsets[lengthOf()], + dtype, getContext()->getWorkspace(), true); + + NDArray res(pBuffer, ShapeDescriptor(dtype, ordering(), getShapeAsVector()), + getContext()); + res.setAttached(getContext()->getWorkspace() != nullptr); + + preparePrimaryUse({&res}, {this}); + memcpy(res.bufferAsT(), nInputoffsets, offsetsLength); + auto data = res.bufferAsT() + offsetsLength; + const auto inData = bufferAsT() + offsetsLength; + memcpy(data, inData, nInputoffsets[lengthOf()]); + + registerPrimaryUse({&res}, {this}); + return res; + } + + Nd4jLong offsetsLength = + ShapeUtils::stringBufferHeaderRequirements(lengthOf()); + + std::vector offsets(lengthOf() + 1); + + const auto nInputoffsets = bufferAsT(); + + Nd4jLong start = 0, stop = 0; + Nd4jLong dataLength = 0; + + auto data = bufferAsT() + offsetsLength; + for (Nd4jLong e = 0; e < lengthOf(); e++) { + offsets[e] = dataLength; + start = nInputoffsets[e]; + stop = nInputoffsets[e + 1]; + if (dataType() == DataType::UTF8) { + dataLength += (dtype == DataType::UTF16) + ? unicode::offsetUtf8StringInUtf16(data + start, stop) + : unicode::offsetUtf8StringInUtf32(data + start, stop); + } else if (dataType() == DataType::UTF16) { + dataLength += (dtype == DataType::UTF32) + ? unicode::offsetUtf16StringInUtf32( + data + start, (stop / sizeof(char16_t))) + : unicode::offsetUtf16StringInUtf8( + data + start, (stop / sizeof(char16_t))); + } else { + dataLength += (dtype == DataType::UTF16) + ? unicode::offsetUtf32StringInUtf16( + data + start, (stop / sizeof(char32_t))) + : unicode::offsetUtf32StringInUtf8( + data + start, (stop / sizeof(char32_t))); + } + } + offsets[lengthOf()] = dataLength; + + std::shared_ptr pBuffer = std::make_shared( + offsetsLength + dataLength, dtype, getContext()->getWorkspace(), true); + + NDArray res(pBuffer, ShapeDescriptor(dtype, ordering(), getShapeAsVector()), + getContext()); + res.setAttached(getContext()->getWorkspace() != nullptr); + + preparePrimaryUse({&res}, {this}); + + memcpy(res.bufferAsT(), offsets.data(), + offsets.size() * sizeof(Nd4jLong)); + + auto outData = res.bufferAsT() + offsetsLength; + const auto inData = bufferAsT() + offsetsLength; + + auto func = PRAGMA_THREADS_FOR { + for (int e = start; e < stop; e++) { + auto cdata = outData + offsets[e]; + auto end = nInputoffsets[e + 1]; + auto idata = inData + nInputoffsets[e]; + if (dtype == DataType::UTF16) { + if (dataType() == DataType::UTF8) { + unicode::utf8to16(idata, outData, end); + } else { + unicode::utf32to16(idata, outData, (end / sizeof(char32_t))); + } + } else if (dtype == DataType::UTF32) { + if (dataType() == DataType::UTF8) { + unicode::utf8to32(idata, cdata, end); + } else { + unicode::utf16to32(idata, outData, (end / sizeof(char16_t))); + } + } else { + if (dataType() == DataType::UTF16) { + unicode::utf16to8(idata, outData, (end / sizeof(char16_t))); + } else { + unicode::utf32to8(idata, outData, (end / sizeof(char32_t))); + } + } + } + }; + + samediff::Threads::parallel_for(func, 0, lengthOf(), 1); + + registerPrimaryUse({&res}, {this}); + + return res; +} +BUILD_SINGLE_TEMPLATE(template SD_EXPORT NDArray NDArray::asS, () const, + LIBND4J_STRINGTYPES); + +//////////////////////////////////////////////////////////////////////// +NDArray NDArray::asT(DataType dtype) const { + if (isS() && !DataTypeUtils::isS(dtype)) + throw std::runtime_error( + "NDArray::asT: you can't use this method on String array with not " + "string DataType!"); + + if (!isS() && DataTypeUtils::isS(dtype)) + throw std::runtime_error( + "NDArray::asT: you can't use this method on not String array with " + "string DataType!"); + + if (isS()) { + BUILD_SINGLE_SELECTOR(dtype, return asS, (), LIBND4J_STRINGTYPES); + } else { + BUILD_SINGLE_SELECTOR(dtype, return asT, (), LIBND4J_TYPES); + } + + return NDArray(); +} + +//////////////////////////////////////////////////////////////////////// +NDArray NDArray::cast(DataType dtype) const { + if (isS() && !DataTypeUtils::isS(dtype)) + throw std::runtime_error( + "NDArray::cast: you can't use this method on String array with not " + "string DataType!"); + + if (!isS() && DataTypeUtils::isS(dtype)) + throw std::runtime_error( + "NDArray::cast: you can't use this method on not String array with " + "string DataType!"); + + return this->asT(dtype); +} + +//////////////////////////////////////////////////////////////////////// +void NDArray::cast(NDArray& target, DataType dtype) { + if (isS()) + throw std::runtime_error("NDArray::cast: you can't use this method on String array!"); + + target.assign(this); +} + +//////////////////////////////////////////////////////////////////////// +void NDArray::operator+=(const NDArray& other) { + if (isS()) + throw std::runtime_error( + "NDArray::operator+=: you can't use this method on String array!"); + if (!Environment::getInstance().isExperimentalBuild() && + this->dataType() != other.dataType() && + (this->dataType() != DataType::BOOL || other.dataType() != BOOL)) + throw sd::datatype_exception::build( + "NDArray operator+=: Cannot add different types", this->dataType(), + other.dataType()); + + if (this->lengthOf() != 1 && other.lengthOf() == 1) { + NDArray::prepareSpecialUse({this}, {this, &other}); + NativeOpExecutioner::execScalar( + getContext(), sd::scalar::Add, buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), other.buffer(), other.shapeInfo(), + other.specialBuffer(), other.specialShapeInfo(), nullptr); + NDArray::registerSpecialUse({this}, {this, &other}); + } else if (other.lengthOf() == lengthOf() && + this->rankOf() == other.rankOf()) { + NDArray::prepareSpecialUse({this}, {this, &other}); + NativeOpExecutioner::execPairwiseTransform( + getContext(), sd::pairwise::Add, buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), other.buffer(), other.shapeInfo(), + other.specialBuffer(), other.specialShapeInfo(), buffer(), shapeInfo(), + specialBuffer(), specialShapeInfo(), nullptr); + NDArray::registerSpecialUse({this}, {this, &other}); + } else { + const Nd4jLong* bShape = nullptr; + if (!ShapeUtils::evalBroadcastShapeInfo(*this, other, true, bShape, + getContext()->getWorkspace())) + throw std::invalid_argument( + "NDArray::operator+=: the shapes of this and other arrays are not " + "suitable for broadcast operation !"); + + if (shape::equalsTypesAndShapesSoft(shapeInfo(), bShape)) { + this->applyTrueBroadcast(sd::BroadcastOpsTuple::Add(), other, *this, + false); + } else { + NDArray result(bShape, true, getContext()); + this->applyTrueBroadcast(sd::BroadcastOpsTuple::Add(), other, result, + false); + *this = std::move(result); // move assignment operator, zero cost copy + } + } +} + +//////////////////////////////////////////////////////////////////////// +void NDArray::operator-=(const NDArray& other) { + if (isS()) + throw std::runtime_error( + "NDArray::operator-=: you can't use this method on String array!"); + + if (!Environment::getInstance().isExperimentalBuild() && + this->dataType() != other.dataType() && + (this->dataType() != DataType::BOOL || other.dataType() != BOOL)) + throw sd::datatype_exception::build( + "NDArray operator-=: Cannot subtract different types", this->dataType(), + other.dataType()); + + if (lengthOf() != 1 && other.lengthOf() == 1) { + NDArray::prepareSpecialUse({this}, {this, &other}); + NativeOpExecutioner::execScalar( + getContext(), sd::scalar::Subtract, buffer(), shapeInfo(), + specialBuffer(), specialShapeInfo(), buffer(), shapeInfo(), + specialBuffer(), specialShapeInfo(), other.buffer(), other.shapeInfo(), + other.specialBuffer(), other.specialShapeInfo(), nullptr); + NDArray::registerSpecialUse({this}, {this, &other}); + } else if (other.lengthOf() == lengthOf() && + this->rankOf() == other.rankOf()) { + NDArray::prepareSpecialUse({this}, {this, &other}); + NativeOpExecutioner::execPairwiseTransform( + getContext(), sd::pairwise::Subtract, buffer(), shapeInfo(), + specialBuffer(), specialShapeInfo(), other.buffer(), other.shapeInfo(), + other.specialBuffer(), other.specialShapeInfo(), buffer(), shapeInfo(), + specialBuffer(), specialShapeInfo(), nullptr); + NDArray::registerSpecialUse({this}, {this, &other}); + } else { + const Nd4jLong* bShape = nullptr; + if (!ShapeUtils::evalBroadcastShapeInfo(*this, other, true, bShape, + getContext()->getWorkspace())) + throw std::invalid_argument( + "NDArray::operator-=: the shapes of this and other arrays are not " + "suitable for broadcast operation !"); + + if (shape::equalsTypesAndShapesSoft(shapeInfo(), bShape)) { + this->applyTrueBroadcast(sd::BroadcastOpsTuple::Subtract(), other, *this, + false); + } else { + NDArray result(bShape, true, getContext()); + this->applyTrueBroadcast(sd::BroadcastOpsTuple::Subtract(), other, result, + false); + *this = std::move(result); // move assignment operator, zero cost copy + } + } +} + +//////////////////////////////////////////////////////////////////////// +void NDArray::operator*=(const NDArray& other) { + if (isS()) + throw std::runtime_error( + "NDArray::operator*=: you can't use this method on String array!"); + if (!Environment::getInstance().isExperimentalBuild() && + this->dataType() != other.dataType() && + (this->dataType() != DataType::BOOL || other.dataType() != BOOL)) + throw sd::datatype_exception::build( + "NDArray operator*=: Cannot multiply different types", this->dataType(), + other.dataType()); + + if (lengthOf() != 1 && other.lengthOf() == 1) { + NDArray::prepareSpecialUse({this}, {this, &other}); + NativeOpExecutioner::execScalar( + getContext(), sd::scalar::Multiply, buffer(), shapeInfo(), + specialBuffer(), specialShapeInfo(), buffer(), shapeInfo(), + specialBuffer(), specialShapeInfo(), other.buffer(), other.shapeInfo(), + other.specialBuffer(), other.specialShapeInfo(), nullptr); + NDArray::registerSpecialUse({this}, {this, &other}); + } else if (other.lengthOf() == lengthOf() && + this->rankOf() == other.rankOf()) { + NDArray::prepareSpecialUse({this}, {this, &other}); + NativeOpExecutioner::execPairwiseTransform( + getContext(), sd::pairwise::Multiply, buffer(), shapeInfo(), + specialBuffer(), specialShapeInfo(), other.buffer(), other.shapeInfo(), + other.specialBuffer(), other.specialShapeInfo(), buffer(), shapeInfo(), + specialBuffer(), specialShapeInfo(), nullptr); + NDArray::registerSpecialUse({this}, {this, &other}); + } else { + const Nd4jLong* bShape = nullptr; + if (!ShapeUtils::evalBroadcastShapeInfo(*this, other, true, bShape, + getContext()->getWorkspace())) + throw std::invalid_argument( + "NDArray::operator*=: the shapes of this and other arrays are not " + "suitable for broadcast operation !"); + + if (shape::equalsTypesAndShapesSoft(_shapeInfo, bShape)) { + this->applyTrueBroadcast(sd::BroadcastOpsTuple::Multiply(), other, *this, + false); + } else { + NDArray result(bShape, true, getContext()); + this->applyTrueBroadcast(sd::BroadcastOpsTuple::Multiply(), other, result, + false); + *this = std::move(result); // move assignment operator, zero cost copy + } + } +} + +//////////////////////////////////////////////////////////////////////// +void NDArray::operator/=(const NDArray& other) { + if (isS() || other.isS()) + throw std::runtime_error( + "NDArray::operator/=: you can't use this method on String array!"); + if (other.isB()) + throw std::runtime_error( + "NDArray::operator/=: you can't divide by bool array!"); + + if (!Environment::getInstance().isExperimentalBuild() && + this->dataType() != other.dataType()) { + throw sd::datatype_exception::build( + "NDArray operator/=: Cannot divide different types", this->dataType(), + other.dataType()); + } + + if (lengthOf() != 1 && other.lengthOf() == 1) { + NDArray::prepareSpecialUse({this}, {this, &other}); + NativeOpExecutioner::execScalar( + getContext(), sd::scalar::Divide, buffer(), shapeInfo(), + specialBuffer(), specialShapeInfo(), buffer(), shapeInfo(), + specialBuffer(), specialShapeInfo(), other.buffer(), other.shapeInfo(), + other.specialBuffer(), other.specialShapeInfo(), nullptr); + NDArray::registerSpecialUse({this}, {this, &other}); + } else if (other.lengthOf() == lengthOf() && + this->rankOf() == other.rankOf()) { + NDArray::prepareSpecialUse({this}, {this, &other}); + NativeOpExecutioner::execPairwiseTransform( + getContext(), sd::pairwise::Divide, buffer(), shapeInfo(), + specialBuffer(), specialShapeInfo(), other.buffer(), other.shapeInfo(), + other.specialBuffer(), other.specialShapeInfo(), buffer(), shapeInfo(), + specialBuffer(), specialShapeInfo(), nullptr); + NDArray::registerSpecialUse({this}, {this, &other}); + } else { + const Nd4jLong* bShape = nullptr; + if (!ShapeUtils::evalBroadcastShapeInfo(*this, other, true, bShape, + getContext()->getWorkspace())) + throw std::invalid_argument( + "NDArray::operator/=: the shapes of this and other arrays are not " + "suitable for broadcast operation !"); + + if (shape::equalsTypesAndShapesSoft(_shapeInfo, bShape)) { + this->applyTrueBroadcast(sd::BroadcastOpsTuple::Divide(), other, *this, + false); + } else { + NDArray result(bShape, true, getContext()); + this->applyTrueBroadcast(sd::BroadcastOpsTuple::Divide(), other, result, + false); + *this = std::move(result); // move assignment operator, zero cost copy + } + } +} + +//////////////////////////////////////////////////////////////////////// +template +void NDArray::operator+=(const T value) { + if (isS()) + throw std::runtime_error( + "NDArray::operator+=: you can't use this method on String array!"); + + auto other = NDArrayFactory::create(this->dataType(), value, getContext()); + + NDArray::prepareSpecialUse({this}, {&other}); + + NativeOpExecutioner::execScalar( + getContext(), sd::scalar::Add, buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), other.buffer(), other.shapeInfo(), + other.specialBuffer(), other.specialShapeInfo(), nullptr); + + NDArray::registerSpecialUse({this}, {}); +} +template SD_EXPORT void NDArray::operator+=(const double value); +template SD_EXPORT void NDArray::operator+=(const float value); +template SD_EXPORT void NDArray::operator+=(const float16 value); +template SD_EXPORT void NDArray::operator+=(const bfloat16 value); +template SD_EXPORT void NDArray::operator+=(const Nd4jLong value); +template SD_EXPORT void NDArray::operator+=(const int value); +template SD_EXPORT void NDArray::operator+=(const bool value); + +//////////////////////////////////////////////////////////////////////// +template +void NDArray::operator-=(const T value) { + if (isS()) + throw std::runtime_error( + "NDArray::operator-=: you can't use this method on String array!"); + + auto other = NDArrayFactory::create(dataType(), value, getContext()); + + NDArray::prepareSpecialUse({this}, {&other}); + + NativeOpExecutioner::execScalar( + getContext(), sd::scalar::Subtract, buffer(), shapeInfo(), + specialBuffer(), specialShapeInfo(), buffer(), shapeInfo(), + specialBuffer(), specialShapeInfo(), other.buffer(), other.shapeInfo(), + other.specialBuffer(), other.specialShapeInfo(), nullptr); + + NDArray::registerSpecialUse({this}, {}); +} +template SD_EXPORT void NDArray::operator-=(const double value); +template SD_EXPORT void NDArray::operator-=(const float value); +template SD_EXPORT void NDArray::operator-=(const float16 value); +template SD_EXPORT void NDArray::operator-=(const bfloat16 value); +template SD_EXPORT void NDArray::operator-=(const Nd4jLong value); +template SD_EXPORT void NDArray::operator-=(const int value); +template SD_EXPORT void NDArray::operator-=(const bool value); + +//////////////////////////////////////////////////////////////////////// +template +void NDArray::operator*=(const T scalar) { + if (isS()) + throw std::runtime_error( + "NDArray::operator*=: you can't use this method on String array!"); + + auto other = NDArrayFactory::create(this->dataType(), scalar, getContext()); + NDArray::prepareSpecialUse({this}, {&other}); + NativeOpExecutioner::execScalar( + getContext(), sd::scalar::Multiply, buffer(), shapeInfo(), + specialBuffer(), specialShapeInfo(), buffer(), shapeInfo(), + specialBuffer(), specialShapeInfo(), other.buffer(), other.shapeInfo(), + other.specialBuffer(), other.specialShapeInfo(), nullptr); + + NDArray::registerSpecialUse({this}, {}); +} +template SD_EXPORT void NDArray::operator*=(const double scalar); +template SD_EXPORT void NDArray::operator*=(const float scalar); +template SD_EXPORT void NDArray::operator*=(const float16 scalar); +template SD_EXPORT void NDArray::operator*=(const bfloat16 scalar); +template SD_EXPORT void NDArray::operator*=(const Nd4jLong scalar); +template SD_EXPORT void NDArray::operator*=(const int scalar); +template SD_EXPORT void NDArray::operator*=(const int16_t scalar); +template SD_EXPORT void NDArray::operator*=(const int8_t scalar); +template SD_EXPORT void NDArray::operator*=(const uint8_t scalar); +template SD_EXPORT void NDArray::operator*=(const bool scalar); + +//////////////////////////////////////////////////////////////////////// +template +void NDArray::operator/=(const T scalar) { + if (isS()) + throw std::runtime_error( + "NDArray::operator/=: you can't use this method on String array!"); + + auto other = NDArrayFactory::create(this->dataType(), scalar, getContext()); + NDArray::prepareSpecialUse({this}, {&other}); + NativeOpExecutioner::execScalar( + getContext(), sd::scalar::Divide, buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), other.buffer(), other.shapeInfo(), + other.specialBuffer(), other.specialShapeInfo(), nullptr); + NDArray::registerSpecialUse({this}, {}); +} +template SD_EXPORT void NDArray::operator/=(const double scalar); +template SD_EXPORT void NDArray::operator/=(const float scalar); +template SD_EXPORT void NDArray::operator/=(const float16 scalar); +template SD_EXPORT void NDArray::operator/=(const bfloat16 scalar); +template SD_EXPORT void NDArray::operator/=(const Nd4jLong scalar); +template SD_EXPORT void NDArray::operator/=(const int scalar); +template SD_EXPORT void NDArray::operator/=(const int16_t scalar); +template SD_EXPORT void NDArray::operator/=(const int8_t scalar); +template SD_EXPORT void NDArray::operator/=(const uint8_t scalar); +template SD_EXPORT void NDArray::operator/=(const bool scalar); + +//////////////////////////////////////////////////////////////////////// +// negative operator, it makes all array elements = -elements +NDArray NDArray::operator-() const& { + if (isS()) + throw std::runtime_error( + "NDArray::negative-: you can't use this method on String array!"); + + NDArray result(shapeInfo(), false, getContext()); + + NDArray::prepareSpecialUse({&result}, {this}); + NativeOpExecutioner::execTransformSame( + getContext(), sd::transform::Neg, buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), result.buffer(), result.shapeInfo(), + result.specialBuffer(), result.specialShapeInfo(), nullptr, nullptr, + nullptr); + NDArray::registerSpecialUse({&result}, {this}); + + return result; +} + +//////////////////////////////////////////////////////////////////////// +NDArray NDArray::operator-() && { + if (isS()) + throw std::runtime_error( + "NDArray::negative-: you can't use this method on String array!"); + + NDArray::prepareSpecialUse({this}, {this}); + NativeOpExecutioner::execTransformSame( + getContext(), sd::transform::Neg, buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), nullptr, nullptr, nullptr); + NDArray::registerSpecialUse({this}, {this}); + + return std::move(*this); +} + +//////////////////////////////////////////////////////////////////////// +// mathematical multiplication of two arrays +NDArray mmul(const NDArray& left, const NDArray& right) { + if (left.isS() || right.isS()) + throw std::runtime_error( + "mmul friend function: you can't use this function on String array!"); + auto ptr = MmulHelper::mmul(const_cast(&left), + const_cast(&right), nullptr, 1., 0.); + NDArray result(std::move(*ptr)); + delete ptr; + return result; +} + +//////////////////////////////////////////////////////////////////////// +void NDArray::tileToShape(const std::vector& shape, NDArray& target) { + if (&target != this) { + this->tile(target); + return; + } + + std::vector thisShape(rankOf()); + for (int i = 0; i < rankOf(); ++i) thisShape[i] = sizeAt(i); + + if (!ShapeUtils::areShapesBroadcastable(shape, thisShape)) + throw std::runtime_error( + "NDArray::tileToShape method: the shape of this array and input shape " + "are not suitable for broadcast operation !"); + + const int newRank = shape.size(); + std::vector repeats(newRank); + + for (int i = 1; i <= newRank; ++i) { + if (i > rankOf()) + repeats[newRank - i] = shape[newRank - i]; + else + repeats[newRank - i] = shape[newRank - i] / thisShape[rankOf() - i]; + } + + tilei(repeats); +} + +//////////////////////////////////////////////////////////////////////// +void NDArray::tileToShape(const std::initializer_list& shape, + NDArray& target) { + tileToShape(std::vector(shape), target); +} + +//////////////////////////////////////////////////////////////////////// +NDArray NDArray::tileToShape(const Nd4jLong* shapeInfo) { + NDArray result(const_cast(shapeInfo), false, getContext()); + tile(result); + return result; +} + +//////////////////////////////////////////////////////////////////////// +double NDArray::getTrace() const { + if (isS()) + throw std::runtime_error( + "NDArray::getTrace: you can't use this method on String array!"); + + int rank = rankOf(); + auto shape = shapeOf(); + int minDim = 100000000; + + Nd4jLong indices[MAX_RANK]; + for (int j = 0; j < rank; ++j) indices[j] = 1; + + auto offset = shape::getOffset(shapeInfo(), indices); + + for (int i = 0; i < rank; ++i) + if (minDim > shape[i]) minDim = shape[i]; + + double sum = 0.; + + for (int i = 0; i < minDim; ++i) sum += e(i * offset); + + return sum; +} + +//////////////////////////////////////////////////////////////////////// +NDArray NDArray::quantize(const NDArray& array) { + if (!array.isR()) + throw std::invalid_argument( + "NDArray::quantize: type of array should be from real space!"); + + auto ws = array.getContext()->getWorkspace(); + + Nd4jLong* shapeInfo = + ShapeBuilders::copyShapeInfo(array.shapeInfo(), true, ws); + ArrayOptions::setPropertyBit(shapeInfo, ARRAY_QUANTIZED); + + std::shared_ptr buffer = std::make_shared( + TypeCast::estimateQuantizedSize(array.lengthOf()), + ArrayOptions::dataType(shapeInfo), ws); + + NDArray result(buffer, ShapeDescriptor(shapeInfo), array.getContext()); + + return result; +} + +////////////////////////////////////////////////////////////////////////// +void NDArray::applyTrueBroadcast(sd::BroadcastOpsTuple op, const NDArray& other, + NDArray& target, const bool checkTargetShape, + ExtraArguments* extraArgs) const { + if (isS()) + throw std::runtime_error( + "NDArray::applyTrueBroadcast: you can't use this method on String " + "array!"); + + if (((op.s == scalar::Divide || op.s == scalar::FloorDiv || + op.s == scalar::FloorMod) && + other.isB()) || + (op.s == scalar::ReverseDivide && this->isB())) + throw std::runtime_error( + "NDArray::applyTrueBroadcast method: you can't divide by bool array !"); + + if (isEmpty() || other.isEmpty()) return; + + // if (lengthOf() == 1) { + // target.assign(this); + // target.applyPairwiseTransform(op.p, other, extraArgs); + // return; + // } + // if (other.lengthOf() == 1) { + // const_cast(this)->applyScalarArr(op.s, other, target, + // extraArgs); return; + // } + + if (checkTargetShape) { + const Nd4jLong* newShapeInfo = nullptr; + if (!ShapeUtils::evalBroadcastShapeInfo( + *this, other, true, newShapeInfo, + getContext()->getWorkspace())) // the rank of target array must be + // equal to max->rankOf)() + throw std::runtime_error( + "NDArray::applyTrueBroadcast method: the shapes of this and other " + "arrays are not suitable for broadcast operation !"); + if (!shape::equalsTypesAndShapesSoft(target.shapeInfo(), newShapeInfo)) + throw std::runtime_error( + "NDArray::applyTrueBroadcast method: the shape or type of target " + "array is wrong !"); + } + + Nd4jLong const* xShapeInfoH = shapeInfo(); + Nd4jLong const* yShapeInfoH = other.shapeInfo(); + Nd4jLong const* xShapeInfoD = specialShapeInfo(); + Nd4jLong const* yShapeInfoD = other.specialShapeInfo(); + + if (!isSameShape(target)) { + auto xPack = ConstantShapeHelper::getInstance().createShapeInfoWithUnitiesForBroadcast( + target.shapeInfo(), shapeInfo(), getContext()->getWorkspace()); + xShapeInfoH = reinterpret_cast(xPack.primary()); + xShapeInfoD = reinterpret_cast(xPack.special()); + } + if (!other.isSameShape(target)) { + auto yPack = ConstantShapeHelper::getInstance().createShapeInfoWithUnitiesForBroadcast( + target.shapeInfo(), other.shapeInfo(), + other.getContext()->getWorkspace()); + yShapeInfoH = reinterpret_cast(yPack.primary()); + yShapeInfoD = reinterpret_cast(yPack.special()); + } + + NDArray::prepareSpecialUse({&target}, {this, &other}); + NativeOpExecutioner::execBroadcast( + getContext(), op.b, buffer(), xShapeInfoH, specialBuffer(), xShapeInfoD, + other.buffer(), yShapeInfoH, other.specialBuffer(), yShapeInfoD, + target.buffer(), target.shapeInfo(), target.specialBuffer(), + target.specialShapeInfo()); + registerSpecialUse({&target}, {this, &other}); +} + +////////////////////////////////////////////////////////////////////////// +void NDArray::applyTrueBroadcast(sd::BroadcastBoolOpsTuple op, + const NDArray& other, NDArray& target, + const bool checkTargetShape, + ExtraArguments* extraArgs) const { + if (isS()) + throw std::runtime_error( + "NDArray::applyTrueBroadcast bool: you can't use this method on String " + "array!"); + + if (isEmpty() || other.isEmpty()) return; + + // if (lengthOf() == 1) { + // NDArray temp(target._shapeInfo, dataType(), false, getContext()); + // temp.assign(this); + // temp.applyPairwiseTransform(op.p, other, target, extraArgs); + // return; + // } + // if (other.lengthOf() == 1) { + // this->applyScalarArr(op.s, other, target, extraArgs); + // return; + // } + + if (checkTargetShape) { + const Nd4jLong* newShapeInfo = nullptr; + if (!ShapeUtils::evalBroadcastShapeInfo( + *this, other, true, newShapeInfo, + getContext()->getWorkspace())) // the rank of target array must be + // equal to max->rankOf)() + throw std::runtime_error( + "NDArray::applyTrueBroadcast method: the shapes of this and other " + "arrays are not suitable for broadcast operation !"); + if (!shape::equalsSoft(target._shapeInfo, newShapeInfo) || + target.dataType() != DataType::BOOL) + throw std::runtime_error( + "NDArray::applyTrueBroadcast bool method: the shape or type of " + "target array is wrong !"); + if (dataType() != other.dataType()) + throw std::invalid_argument( + "NDArray::applyTrueBroadcast bool method: this and other arrays must " + "have the same type !"); + } + + Nd4jLong const* xShapeInfoH = shapeInfo(); + Nd4jLong const* yShapeInfoH = other.shapeInfo(); + Nd4jLong const* xShapeInfoD = specialShapeInfo(); + Nd4jLong const* yShapeInfoD = other.specialShapeInfo(); + + if (!isSameShape(target)) { + auto xPack = ConstantShapeHelper::getInstance().createShapeInfoWithUnitiesForBroadcast( + target.shapeInfo(), shapeInfo(), getContext()->getWorkspace()); + xShapeInfoH = reinterpret_cast(xPack.primary()); + xShapeInfoD = reinterpret_cast(xPack.special()); + } + if (!other.isSameShape(target)) { + auto yPack = ConstantShapeHelper::getInstance().createShapeInfoWithUnitiesForBroadcast( + target.shapeInfo(), other.shapeInfo(), + other.getContext()->getWorkspace()); + yShapeInfoH = reinterpret_cast(yPack.primary()); + yShapeInfoD = reinterpret_cast(yPack.special()); + } + + NDArray::prepareSpecialUse({&target}, {this, &other}); + NativeOpExecutioner::execBroadcastBool( + getContext(), op.b, buffer(), xShapeInfoH, specialBuffer(), xShapeInfoD, + other.buffer(), yShapeInfoH, other.specialBuffer(), yShapeInfoD, + target.buffer(), target.shapeInfo(), target.specialBuffer(), + target.specialShapeInfo(), nullptr); + registerSpecialUse({&target}, {this, &other}); +} + +////////////////////////////////////////////////////////////////////////// +void NDArray::applyTrueBroadcast(sd::BroadcastIntOpsTuple op, + const NDArray& other, NDArray& target, + const bool checkTargetShape, + ExtraArguments* extraArgs) const { + if (isS()) + throw std::runtime_error( + "NDArray::applyTrueBroadcast bool: you can't use this method on String " + "array!"); + + if (isEmpty() || other.isEmpty()) return; + + // if (lengthOf() == 1) { + // NDArray temp(target._shapeInfo, dataType(), false, getContext()); + // temp.assign(this); + // temp.applyPairwiseTransform(op.p, other, target, extraArgs); + // return; + // } + // if (other.lengthOf() == 1) { + // this->applyScalarArr(op.s, other, target, extraArgs); + // return; + // } + + if (checkTargetShape) { + const Nd4jLong* newShapeInfo = nullptr; + if (!ShapeUtils::evalBroadcastShapeInfo( + *this, other, false, newShapeInfo, + getContext()->getWorkspace())) // the rank of target array must be + // equal to max->rankOf)() + throw std::runtime_error( + "NDArray::applyTrueBroadcast method: the shapes of this and other " + "arrays are not suitable for broadcast operation !"); + if (!shape::equalsSoft(target._shapeInfo, newShapeInfo) || + target.dataType() != this->dataType()) + throw std::runtime_error( + "NDArray::applyTrueBroadcast int method: the shape or type of target " + "array is wrong !"); + if (dataType() != other.dataType()) + throw std::invalid_argument( + "NDArray::applyTrueBroadcast int method: this and other arrays must " + "have the same type !"); + } + + Nd4jLong const* xShapeInfoH = shapeInfo(); + Nd4jLong const* yShapeInfoH = other.shapeInfo(); + Nd4jLong const* xShapeInfoD = specialShapeInfo(); + Nd4jLong const* yShapeInfoD = other.specialShapeInfo(); + + if (!isSameShape(target)) { + auto xPack = ConstantShapeHelper::getInstance().createShapeInfoWithUnitiesForBroadcast( + target.shapeInfo(), shapeInfo(), getContext()->getWorkspace()); + + xShapeInfoH = reinterpret_cast(xPack.primary()); + xShapeInfoD = reinterpret_cast(xPack.special()); + } + if (!other.isSameShape(target)) { + auto yPack = ConstantShapeHelper::getInstance().createShapeInfoWithUnitiesForBroadcast( + target.shapeInfo(), other.shapeInfo(), + other.getContext()->getWorkspace()); + yShapeInfoH = reinterpret_cast(yPack.primary()); + yShapeInfoD = reinterpret_cast(yPack.special()); + } + + NDArray::prepareSpecialUse({&target}, {this, &other}); + NativeOpExecutioner::execBroadcastInt( + getContext(), op.b, buffer(), xShapeInfoH, specialBuffer(), xShapeInfoD, + other.buffer(), yShapeInfoH, other.specialBuffer(), yShapeInfoD, + target.buffer(), target.shapeInfo(), target.specialBuffer(), + target.specialShapeInfo()); + registerSpecialUse({&target}, {this, &other}); +} + +////////////////////////////////////////////////////////////////////////// +NDArray NDArray::applyTrueBroadcast(sd::BroadcastOpsTuple op, + const NDArray& other, + ExtraArguments* extraArgs) const& { + if (isEmpty() || other.isEmpty()) { + if (isEmpty()) + return NDArray(*this); + else + return NDArray(other); + } + + const Nd4jLong* newShapeInfo = nullptr; + if (!ShapeUtils::evalBroadcastShapeInfo( + *this, other, true, newShapeInfo, + getContext() + ->getWorkspace())) // the rank of new array = max->rankOf)() + throw std::runtime_error( + "NDArray::applyTrueBroadcast method: the shapes of this and other " + "arrays are not suitable for broadcast operation !"); + NDArray result(newShapeInfo, true, getContext()); + + this->applyTrueBroadcast(op, other, result, false, extraArgs); + + return result; +} + +////////////////////////////////////////////////////////////////////////// +NDArray NDArray::applyTrueBroadcast(sd::BroadcastOpsTuple op, NDArray&& other, + ExtraArguments* extraArgs) const& { + if (isEmpty() || other.isEmpty()) { + if (isEmpty()) + return NDArray(*this); + else + return NDArray(other); + } + + const Nd4jLong* newShapeInfo = nullptr; + if (!ShapeUtils::evalBroadcastShapeInfo( + *this, other, true, newShapeInfo, + getContext() + ->getWorkspace())) // the rank of new array = max->rankOf)() + throw std::runtime_error( + "NDArray::applyTrueBroadcast method: the shapes of this and other " + "arrays are not suitable for broadcast operation !"); + + if (!shape::shapeEquals(newShapeInfo, other.shapeInfo())) { + NDArray result(newShapeInfo, true, getContext()); + this->applyTrueBroadcast(op, other, result, false, extraArgs); + return std::move(result); + } + + this->applyTrueBroadcast(op, other, other, false, extraArgs); + return std::move(other); +} + +////////////////////////////////////////////////////////////////////////// +NDArray NDArray::applyTrueBroadcast(sd::BroadcastOpsTuple op, + const NDArray& other, + ExtraArguments* extraArgs) && { + if (isEmpty() || other.isEmpty()) { + if (isEmpty()) + return NDArray(*this); + else + return NDArray(other); + } + + const Nd4jLong* newShapeInfo = nullptr; + if (!ShapeUtils::evalBroadcastShapeInfo( + *this, other, true, newShapeInfo, + getContext() + ->getWorkspace())) // the rank of new array = max->rankOf)() + throw std::runtime_error( + "NDArray::applyTrueBroadcast method: the shapes of this and other " + "arrays are not suitable for broadcast operation !"); + + if (!shape::shapeEquals(newShapeInfo, shapeInfo())) { + NDArray result(newShapeInfo, true, getContext()); + this->applyTrueBroadcast(op, other, result, false, extraArgs); + return std::move(result); + } + + this->applyTrueBroadcast(op, other, *this, false, extraArgs); + return std::move(*this); +} + +////////////////////////////////////////////////////////////////////////// +NDArray NDArray::applyTrueBroadcast(sd::BroadcastOpsTuple op, NDArray&& other, + ExtraArguments* extraArgs) && { + if (isEmpty() || other.isEmpty()) { + if (isEmpty()) + return NDArray(*this); + else + return NDArray(other); + } + + const Nd4jLong* newShapeInfo = nullptr; + if (!ShapeUtils::evalBroadcastShapeInfo( + *this, other, true, newShapeInfo, + getContext() + ->getWorkspace())) // the rank of new array = max->rankOf)() + throw std::runtime_error( + "NDArray::applyTrueBroadcast method: the shapes of this and other " + "arrays are not suitable for broadcast operation !"); + + const bool thisMove = shape::shapeEquals(newShapeInfo, shapeInfo()); + const bool otherMove = shape::shapeEquals(newShapeInfo, other.shapeInfo()); + + if (!thisMove && !otherMove) { + NDArray result(newShapeInfo, true, getContext()); + this->applyTrueBroadcast(op, other, result, false, extraArgs); + return std::move(result); + } + + if (thisMove) { + this->applyTrueBroadcast(op, other, *this, false, extraArgs); + return std::move(*this); + } + + // otherMove + this->applyTrueBroadcast(op, other, other, false, extraArgs); + return std::move(other); +} + +////////////////////////////////////////////////////////////////////////// +void NDArray::applyBroadcast(sd::broadcast::Ops op, + const std::vector& dimensions, + const NDArray& other, NDArray& target, + ExtraArguments* extraArgs) { + if (dimensions.size() == 0) return; + + if (isS()) + throw std::runtime_error( + "NDArray::applyBroadcast: you can't use this method on String array!"); + if (((op == broadcast::Divide || op == broadcast::FloorDiv || + op == broadcast::FloorMod) && + other.isB()) || + (op == broadcast::ReverseDivide && this->isB())) + throw std::runtime_error( + "NDArray::applyBroadcast: you can't divide by array!"); + if (isEmpty() || other.isEmpty()) { + if (!target.isEmpty()) + throw std::runtime_error( + "NDArray::applyBroadcast method: when some of input arrays (or both) " + "is empty, target array must be empty as well !"); + return; + } + + // if (other.lengthOf() == lengthOf() && this->rankOf() == other.rankOf()) { + // NDArray::prepareSpecialUse({&target}, {this, &other}); + // NativeOpExecutioner::execPairwiseTransform(getContext(), + // fromBroadcastToPairwise(op), buffer(), shapeInfo(), specialBuffer(), + // specialShapeInfo(), other.buffer(), other.shapeInfo(), + // other.specialBuffer(), other.specialShapeInfo(), target.buffer(), + // target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), + // nullptr); NDArray::registerSpecialUse({&target}, {this, &other}); + // return; + // } + + if (target.dataType() != + DataTypeUtils::pickPairwiseResultType(shapeInfo(), other.shapeInfo())) + throw std::invalid_argument( + "NDArray::applyBroadcast method: wrong type of target array !"); + if (!target.isSameShape(this) && !target.isSameShape(other)) + throw std::invalid_argument( + "NDArray::applyBroadcast method: one of of two input arrays (this or " + "other) should has the same shape as target array!"); + + std::vector copy(dimensions); + + if (dimensions.size() > 1) std::sort(copy.begin(), copy.end()); + + Nd4jLong const* xShapeInfoH = shapeInfo(); + Nd4jLong const* yShapeInfoH = other.shapeInfo(); + Nd4jLong const* xShapeInfoD = specialShapeInfo(); + Nd4jLong const* yShapeInfoD = other.specialShapeInfo(); + + if (!isSameShape(target)) { + auto xPack = ConstantShapeHelper::getInstance().createShapeInfoWithUnitiesForBroadcast( + target.shapeInfo(), shapeInfo(), + getContext()->getWorkspace(), copy); + xShapeInfoH = reinterpret_cast(xPack.primary()); + xShapeInfoD = reinterpret_cast(xPack.special()); + } + if (!other.isSameShape(target)) { + auto yPack = ConstantShapeHelper::getInstance().createShapeInfoWithUnitiesForBroadcast( + target.shapeInfo(), other.shapeInfo(), + other.getContext()->getWorkspace(), copy); + yShapeInfoH = reinterpret_cast(yPack.primary()); + yShapeInfoD = reinterpret_cast(yPack.special()); + } + + NDArray::prepareSpecialUse({&target}, {this, &other}); + NativeOpExecutioner::execBroadcast( + getContext(), op, buffer(), xShapeInfoH, specialBuffer(), xShapeInfoD, + other.buffer(), yShapeInfoH, other.specialBuffer(), yShapeInfoD, + target.buffer(), target.shapeInfo(), target.specialBuffer(), + target.specialShapeInfo()); + registerSpecialUse({&target}, {this, &other}); +} + +////////////////////////////////////////////////////////////////////////// +void NDArray::applyBroadcast(sd::broadcast::BoolOps op, + const std::vector& dimensions, + const NDArray& other, NDArray& target, + ExtraArguments* extraArgs) { + if (dimensions.size() == 0) return; + + if (isS()) + throw std::runtime_error( + "NDArray::applyBroadcast BoolOps: you can't use this method on String " + "array!"); + if (isEmpty() || other.isEmpty()) { + if (!target.isEmpty()) + throw std::runtime_error( + "NDArray::applyBroadcast BoolOps: when some of input arrays (or " + "both) is empty, target array must be empty as well !"); + return; + } + + // if (other.lengthOf() == lengthOf() && this->rankOf() == other.rankOf()) { + // NDArray::prepareSpecialUse({&target}, {this, &other}); + // NativeOpExecutioner::execPairwiseBoolTransform(getContext(), + // fromBroadcastToPairwiseBool(op), buffer(), shapeInfo(), + // specialBuffer(), specialShapeInfo(), other.buffer(), other.shapeInfo(), + // other.specialBuffer(), other.specialShapeInfo(), target.buffer(), + // target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), + // nullptr); NDArray::registerSpecialUse({&target}, {this, &other}); + // return; + // } + + if (target.dataType() != DataType::BOOL) + throw std::invalid_argument( + "NDArray::applyBroadcast bool method: type of target array must be " + "BOOL!"); + if (!target.isSameShape(this) && !target.isSameShape(other)) + throw std::invalid_argument( + "NDArray::applyBroadcast bool method: one of of two input arrays (this " + "or other) should has the same shape as target array!"); + if (_dataType != other._dataType) + throw std::invalid_argument( + "NDArray::applyBroadcast bool method: this and other arrays must have " + "the same type !"); + + std::vector copy(dimensions); + + if (dimensions.size() > 1) std::sort(copy.begin(), copy.end()); + + Nd4jLong const* xShapeInfoH = shapeInfo(); + Nd4jLong const* yShapeInfoH = other.shapeInfo(); + Nd4jLong const* xShapeInfoD = specialShapeInfo(); + Nd4jLong const* yShapeInfoD = other.specialShapeInfo(); + + if (!isSameShape(target)) { + auto xPack = ConstantShapeHelper::getInstance().createShapeInfoWithUnitiesForBroadcast( + target.shapeInfo(), shapeInfo(), + getContext()->getWorkspace(), copy); + xShapeInfoH = reinterpret_cast(xPack.primary()); + xShapeInfoD = reinterpret_cast(xPack.special()); + } + if (!other.isSameShape(target)) { + auto yPack = ConstantShapeHelper::getInstance().createShapeInfoWithUnitiesForBroadcast( + target.shapeInfo(), other.shapeInfo(), + other.getContext()->getWorkspace(), copy); + yShapeInfoH = reinterpret_cast(yPack.primary()); + yShapeInfoD = reinterpret_cast(yPack.special()); + } + + NDArray::prepareSpecialUse({&target}, {this, &other}); + NativeOpExecutioner::execBroadcastBool( + getContext(), op, buffer(), xShapeInfoH, specialBuffer(), xShapeInfoD, + other.buffer(), yShapeInfoH, other.specialBuffer(), yShapeInfoD, + target.buffer(), target.shapeInfo(), target.specialBuffer(), + target.specialShapeInfo(), nullptr); + registerSpecialUse({&target}, {this, &other}); +} + +////////////////////////////////////////////////////////////////////////// +void NDArray::applyBroadcast(sd::broadcast::IntOps op, + const std::vector& dimensions, + const NDArray& other, NDArray& target, + ExtraArguments* extraArgs) { + if (dimensions.empty()) return; + + if (!isZ()) + throw std::runtime_error( + "NDArray::applyBroadcast IntOps: you can't use this method on " + "non-Integer array!"); + if (isEmpty() || other.isEmpty()) { + if (!target.isEmpty()) + throw std::runtime_error( + "NDArray::applyBroadcast IntOps: when some of input arrays (or both) " + "is empty, target array must be empty as well !"); + return; + } + + // if (other.lengthOf() == lengthOf() && this->rankOf() == other.rankOf()) { + // NDArray::prepareSpecialUse({&target}, {this, &other}); + // NativeOpExecutioner::execPairwiseIntTransform(getContext(), + // fromBroadcastToPairwiseInt(op), buffer(), shapeInfo(), specialBuffer(), + // specialShapeInfo(), other.buffer(), other.shapeInfo(), + // other.specialBuffer(), other.specialShapeInfo(), target.buffer(), + // target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), + // nullptr); NDArray::registerSpecialUse({&target}, {this, &other}); + // return; + // } + + if (target.dataType() != dataType()) + throw std::invalid_argument( + "NDArray::applyBroadcast int method: type of target array must be the " + "same as input!"); + if (!target.isSameShape(this) && !target.isSameShape(other)) + throw std::invalid_argument( + "NDArray::applyBroadcast int method: one of of two input arrays (this " + "or other) should has the same shape as target array!"); + if (_dataType != other._dataType) + throw std::invalid_argument( + "NDArray::applyBroadcast int method: this and other arrays must have " + "the same type !"); + + std::vector copy(dimensions); + + if (dimensions.size() > 1) std::sort(copy.begin(), copy.end()); + + Nd4jLong const* xShapeInfoH = shapeInfo(); + Nd4jLong const* yShapeInfoH = other.shapeInfo(); + Nd4jLong const* xShapeInfoD = specialShapeInfo(); + Nd4jLong const* yShapeInfoD = other.specialShapeInfo(); + + if (!isSameShape(target)) { + auto xPack = ConstantShapeHelper::getInstance().createShapeInfoWithUnitiesForBroadcast( + target.shapeInfo(), shapeInfo(), + getContext()->getWorkspace(), copy); + xShapeInfoH = reinterpret_cast(xPack.primary()); + xShapeInfoD = reinterpret_cast(xPack.special()); + } + if (!other.isSameShape(target)) { + auto yPack = ConstantShapeHelper::getInstance().createShapeInfoWithUnitiesForBroadcast( + target.shapeInfo(), other.shapeInfo(), + other.getContext()->getWorkspace(), copy); + yShapeInfoH = reinterpret_cast(yPack.primary()); + yShapeInfoD = reinterpret_cast(yPack.special()); + } + + NDArray::prepareSpecialUse({&target}, {this, &other}); + NativeOpExecutioner::execBroadcastInt( + getContext(), op, buffer(), xShapeInfoH, specialBuffer(), xShapeInfoD, + other.buffer(), yShapeInfoH, other.specialBuffer(), yShapeInfoD, + target.buffer(), target.shapeInfo(), target.specialBuffer(), + target.specialShapeInfo()); + registerSpecialUse({&target}, {this, &other}); +} + +////////////////////////////////////////////////////////////////////////// +void NDArray::applyBroadcast(sd::broadcast::Ops op, + const std::initializer_list dimensions, + const NDArray& tadArray, NDArray& target, + ExtraArguments* extraArgs) { + std::vector vec(dimensions); + applyBroadcast(op, vec, tadArray, target, extraArgs); +} + +//////////////////////////////////////////////////////////////////////// +void* NDArray::operator new(size_t i) { + if (sd::memory::MemoryRegistrator::getInstance().hasWorkspaceAttached()) { + sd::memory::Workspace* ws = + sd::memory::MemoryRegistrator::getInstance().getWorkspace(); + return ws->allocateBytes((Nd4jLong)i); + } else { + auto p = malloc(i); + CHECK_ALLOC(p, "Failed to allocate new NDArray", i); + return p; + } +} + +//////////////////////////////////////////////////////////////////////// +void NDArray::operator delete(void* p) { + if (!sd::memory::MemoryRegistrator::getInstance().hasWorkspaceAttached()) + free(p); +} + +//////////////////////////////////////////////////////////////////////// +template +std::vector NDArray::asVectorT() { + std::vector result(this->lengthOf()); + + PRAGMA_OMP_SIMD + for (int e = 0; e < this->lengthOf(); e++) result[e] = this->e(e); + + return result; +} +BUILD_SINGLE_TEMPLATE(template SD_EXPORT std::vector, NDArray::asVectorT(), + LIBND4J_TYPES); + +////////////////////////////////////////////////////////////////////////// +// set new order and shape in case of suitable array length +bool NDArray::reshapei(const char order, const std::vector& cshape, + const bool copyToNewBuff) { + // check firstly whether cshape is identical to shape of array, if yes then + // reshape is unnecessary + if (order == ordering() && + shape::shapeEquals(rankOf(), shapeOf(), cshape.size(), cshape.data())) + return true; + + const bool isOutShapeEmpty = + std::find(cshape.begin(), cshape.end(), 0) != cshape.end(); + + if (isEmpty() && !isOutShapeEmpty) + throw std::invalid_argument( + "NDArray::reshapei: can't reshape empty array to non-empty !"); + if (!isEmpty() && isOutShapeEmpty) + throw std::invalid_argument( + "NDArray::reshapei: can't reshape non-empty array to empty !"); + if (isEmpty() && isOutShapeEmpty) { + Nd4jLong* shapeInfoNew = ShapeBuilders::emptyShapeInfo( + dataType(), order, cshape, getContext()->getWorkspace()); + setShapeInfo(shapeInfoNew); + RELEASE(shapeInfoNew, getContext()->getWorkspace()); + return true; + } + + std::vector shape(cshape); + int rank = shape.size(); + + // looking for negative in shape + + int numberNegativesOnes = 0; + + Nd4jLong* shape_ = shape.data(); + for (int i = 0; i < (int)shape.size(); i++) { + if (shape[i] < 0) { + if (numberNegativesOnes >= 1) + throw std::runtime_error( + "NDArray::reshapei: only one dimension can be negative at once"); + + numberNegativesOnes++; + + int shapeLength = 1; + for (int j = 0; j < (int)shape.size(); j++) + if (i != j) shapeLength *= shape_[j]; + + Nd4jLong realShape = sd::math::nd4j_abs(lengthOf() / shapeLength); + auto thisNewShape = new Nd4jLong[shape.size()]; + + for (int j = 0; j < (int)shape.size(); j++) + if (i != j) + thisNewShape[j] = shape_[j]; + else + thisNewShape[j] = realShape; + + shape_ = thisNewShape; + } + } + + for (int e = 0; e < (int)shape.size(); e++) shape[e] = shape_[e]; + + if (numberNegativesOnes > 0) delete[] shape_; + + Nd4jLong arrLength = 1; + for (const auto& item : shape) arrLength *= item; + + if (platformBuffer() == nullptr || arrLength != this->lengthOf()) { + this->printShapeInfo("Mismatched shape"); + sd::Logger::printv("Shape requested: ", shape); + nd4j_debug("Requested length in reshape: %i; Existing length: %i;\n", + arrLength, this->lengthOf()); + throw std::runtime_error("NDArray::reshapei: bad input shape!"); + } + + Nd4jLong* shapeInfoNew; + ALLOCATE(shapeInfoNew, getContext()->getWorkspace(), + shape::shapeInfoLength(rank), Nd4jLong); + + bool canReshape = shape::reshapeC(shapeInfo(), order, shape.size(), + shape.data(), shapeInfoNew); + + if (canReshape) { + setShapeInfo(shapeInfoNew); + } else { + NDArray temp(order, shape, dataType(), getContext()); + if (copyToNewBuff) this->applyTransform(transform::Assign, temp, nullptr); + *this = std::move(temp); + } + + RELEASE(shapeInfoNew, getContext()->getWorkspace()); + + return canReshape; +} + +////////////////////////////////////////////////////////////////////////// +void NDArray::nullify() { + if (isEmpty()) return; + + if (isView() || ews() != 1) + assign(0); + else + _buffer->setToZeroBuffers(); +} + +//////////////////////////////////////////////////////////////////////// +template +void NDArray::templatedSet(void* buffer, const Nd4jLong xOfsset, + sd::DataType dtype, const void* value) { + BUILD_SINGLE_PARTIAL_SELECTOR( + dtype, templatedSet<, T>(buffer, xOfsset, value), LIBND4J_TYPES); +} +BUILD_SINGLE_TEMPLATE(template SD_EXPORT void NDArray::templatedSet, + (void* buffer, const Nd4jLong xOfsset, sd::DataType dtype, + const void* value), + LIBND4J_TYPES); + +//////////////////////////////////////////////////////////////////////// +void NDArray::applyPairwiseTransform(sd::pairwise::Ops op, const NDArray& other, + NDArray& target, + ExtraArguments* extraParams) const { + if (isS()) + throw std::runtime_error( + "NDArray::applyPairwiseTransform: you can't use this method on String " + "array!"); + if (other.lengthOf() != target.lengthOf()) + throw std::invalid_argument( + "NDArray::applyPairwiseTransform method - lengths of arrays are " + "mismatched"); + if (target.dataType() != this->dataType() && + target.dataType() != other.dataType()) + throw std::invalid_argument( + "NDArray::applyPairwiseTransform method - type of target array must be " + "the same as type of this or other array !"); + + NDArray::prepareSpecialUse({&target}, {this, &other}); + NativeOpExecutioner::execPairwiseTransform( + getContext(), op, buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), other.buffer(), other.shapeInfo(), + other.specialBuffer(), other.specialShapeInfo(), target.buffer(), + target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), + extraParams != nullptr ? extraParams->argumentsAsT(target.dataType()) + : nullptr); + NDArray::registerSpecialUse({&target}, {this, &other}); + + if (extraParams != nullptr) synchronize("NDArray::applyPairwiseTransform"); +} + +//////////////////////////////////////////////////////////////////////// +void NDArray::applyPairwiseTransform(sd::pairwise::BoolOps op, + const NDArray& other, NDArray& target, + ExtraArguments* extraParams) const { + if (isS()) + throw std::runtime_error( + "NDArray::applyPairwiseTransform BoolOps: you can't use this method on " + "String array!"); + if (other.lengthOf() != target.lengthOf()) + throw std::invalid_argument( + "NDArray::applyPairwiseTransform BoolOps method - lengths of arrays " + "are mismatched"); + if (!target.isB()) + throw std::invalid_argument( + "NDArray::applyPairwiseTransform BoolOps method - result must have " + "bool type"); + if (dataType() != other.dataType()) + throw std::invalid_argument( + "NDArray::applyPairwiseTransform BoolOps method - this and other " + "arrays must have the same type !"); + + NDArray::prepareSpecialUse({&target}, {this, &other}); + NativeOpExecutioner::execPairwiseBoolTransform( + getContext(), op, buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), other.buffer(), other.shapeInfo(), + other.specialBuffer(), other.specialShapeInfo(), target.buffer(), + target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), + extraParams != nullptr ? extraParams->argumentsAsT(target.dataType()) + : nullptr); + NDArray::registerSpecialUse({&target}, {this, &other}); +} + +//////////////////////////////////////////////////////////////////////// +void NDArray::applyPairwiseTransform(sd::pairwise::IntOps op, + const NDArray& other, NDArray& target, + ExtraArguments* extraParams) const { + if (isS()) + throw std::runtime_error( + "NDArray::applyPairwiseTransform IntOps: you can't use this method on " + "String array!"); + if (other.lengthOf() != target.lengthOf()) + throw std::invalid_argument( + "NDArray::applyPairwiseTransform IntOps method - lengths of arrays are " + "mismatched"); + if (!target.isZ()) + throw std::invalid_argument( + "NDArray::applyPairwiseTransform IntOps method - result must have bool " + "type"); + if (dataType() != other.dataType()) + throw std::invalid_argument( + "NDArray::applyPairwiseTransform IntOps method - this and other arrays " + "must have the same type !"); + + NDArray::prepareSpecialUse({&target}, {this, &other}); + NativeOpExecutioner::execPairwiseIntTransform( + getContext(), op, buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), other.buffer(), other.shapeInfo(), + other.specialBuffer(), other.specialShapeInfo(), target.buffer(), + target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), + extraParams != nullptr ? extraParams->argumentsAsT(target.dataType()) + : nullptr); + NDArray::registerSpecialUse({&target}, {this, &other}); +} + +////////////////////////////////////////////////////////////////////////// +void NDArray::applyPairwiseTransform(sd::pairwise::Ops op, const NDArray& other, + ExtraArguments* extraParams) { + applyPairwiseTransform(op, other, *this, extraParams); +} + +//////////////////////////////////////////////////////////////////////// +template +void NDArray::templatedDoubleAssign(void* xBuffer, const Nd4jLong xOffset, + const void* yBuffer, + const Nd4jLong yOffset) const { + auto x = reinterpret_cast(xBuffer); + const auto y = reinterpret_cast(yBuffer); + x[xOffset] = static_cast(y[yOffset]); +} +BUILD_DOUBLE_TEMPLATE(template SD_EXPORT void NDArray::templatedDoubleAssign, + (void* xBuffer, const Nd4jLong xOffset, + const void* yBuffer, const Nd4jLong yOffset) const, + LIBND4J_TYPES, LIBND4J_TYPES); + +//////////////////////////////////////////////////////////////////////// +void NDArray::varianceAlongDimension(sd::variance::Ops op, NDArray& target, + const bool biasCorrected, + const std::vector& dimensions) const { + if (isS()) + throw std::runtime_error( + "NDArray::varianceAlongDimension: you can't use this method on String " + "array!"); + + if (!target.isR()) + throw std::runtime_error( + "NDArray::varianceAlongDimension: target array must have FLOAT type"); + + NDArray::prepareSpecialUse({&target}, {this}); + + if (rankOf() == dimensions.size() || dimensions.empty()) + NativeOpExecutioner::execSummaryStatsScalar( + getContext(), op, buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), nullptr, target.buffer(), target.shapeInfo(), + target.specialBuffer(), target.specialShapeInfo(), biasCorrected); + else { + std::vector copy(dimensions); + auto pDims = + sd::Environment::getInstance().isCPU() ? copy.data() : nullptr; + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions( + this->shapeInfo(), dimensions); + NativeOpExecutioner::execSummaryStats( + getContext(), op, buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), nullptr, target.buffer(), target.shapeInfo(), + target.specialBuffer(), target.specialShapeInfo(), pDims, + dimensions.size(), packX.platformShapeInfo(), packX.platformOffsets(), + biasCorrected); + synchronize("NDArray::varianceAlongDimension"); + } + + NDArray::registerSpecialUse({&target}, {this}); +} + +//////////////////////////////////////////////////////////////////////// +NDArray NDArray::varianceAlongDimension( + sd::variance::Ops op, const bool biasCorrected, + const std::vector& dimensions) const { + if (isS()) + throw std::runtime_error( + "NDArray::varianceAlongDimension: you can't use this method on String " + "array!"); + + std::vector copy(dimensions); + if (copy.size() > 1) std::sort(copy.begin(), copy.end()); + + auto newShape = ShapeUtils::evalReduceShapeInfo( + 'c', copy, *this, DataTypeUtils::pickFloatingType(dataType()), false, + false, getContext()->getWorkspace()); + NDArray result(newShape, true, getContext()); + + this->varianceAlongDimension(op, result, biasCorrected, dimensions); + + return result; +} + +//////////////////////////////////////////////////////////////////////// +NDArray NDArray::varianceAlongDimension( + sd::variance::Ops op, const bool biasCorrected, + const std::initializer_list& dimensions) const { + return varianceAlongDimension(op, biasCorrected, + std::vector(dimensions)); +} + +//////////////////////////////////////////////////////////////////////// +void NDArray::varianceAlongDimension( + sd::variance::Ops op, NDArray& target, const bool biasCorrected, + const std::initializer_list& dimensions) const { + varianceAlongDimension(op, target, biasCorrected, + std::vector(dimensions)); +} + +//////////////////////////////////////////////////////////////////////// +// This method returns new copy of this NDArray, optionally in different order +NDArray NDArray::dup(const char newOrder) const { + if (isEmpty()) return NDArrayFactory::empty(dataType(), getContext()); + + char order = newOrder == 'a' ? ordering() : newOrder; + + // for now string arrays require special treatment + if (isS()) { + if (dataType() == DataType::UTF8) { + std::vector strings(lengthOf()); + + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + strings[i] = std::move(this->e(i)); + } + }; + + samediff::Threads::parallel_for(func, 0, lengthOf(), 1); + + return NDArray(getShapeAsVector(), strings, dataType(), getContext()); + } + if (dataType() == DataType::UTF16) { + std::vector strings(lengthOf()); + + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + strings[i] = std::move(this->e(i)); + } + }; + + samediff::Threads::parallel_for(func, 0, lengthOf(), 1); + + return NDArray(getShapeAsVector(), strings, dataType(), getContext()); + } + + std::vector strings(lengthOf()); + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + strings[i] = std::move(this->e(i)); + } + }; + + samediff::Threads::parallel_for(func, 0, lengthOf(), 1); + + return NDArray(getShapeAsVector(), strings, dataType(), getContext()); + } + + NDArray result(order, + isScalar() ? std::vector({0}) : getShapeAsVector(), + dataType(), getContext()); + result.assign(*this); + + return result; +} + +//////////////////////////////////////////////////////////////////////// +// This method returns true if two arrays are equal, with custom or default Eps +// value of 1e-5, false otherwise +bool NDArray::equalsTo(const NDArray* other, double eps) const { + if (dataType() != other->dataType() || lengthOf() != other->lengthOf()) + return false; + + // we need to be able to compare [1, len] to [len] + if ((rankOf() == 1 && other->rankOf() == 2) || + (rankOf() == 2 && other->rankOf() == 1)) { + // FIXME: do something here? + } else if (!shape::equalsSoft(shapeInfo(), other->shapeInfo())) + return false; + + if (isS()) { + // string is special case, we'll compare them one by one, considering both + // arrays are guaranteed to have the same length + + if (dataType() == DataType::UTF8) { + for (Nd4jLong e = 0; e < this->lengthOf(); e++) { + auto s1 = this->e(e); + auto s2 = other->e(e); + + if (s1 != s2) return false; + } + } else if (dataType() == DataType::UTF16) { + for (Nd4jLong e = 0; e < this->lengthOf(); e++) { + auto s1 = this->e(e); + auto s2 = other->e(e); + + if (s1 != s2) return false; + } + } else { + for (Nd4jLong e = 0; e < this->lengthOf(); e++) { + auto s1 = this->e(e); + auto s2 = other->e(e); + + if (s1 != s2) return false; + } + } + + return true; + } else { + // regular numeric types + NDArray tmp(sd::DataType::FLOAT32, getContext()); // scalar = 0 + + ExtraArguments extras({0.0, 0.0, eps}); + + NDArray::prepareSpecialUse({&tmp}, {this, other}); + NativeOpExecutioner::execReduce3Scalar( + getContext(), reduce3::EqualsWithEps, buffer(), shapeInfo(), + specialBuffer(), specialShapeInfo(), + extras.argumentsAsT(DataType::FLOAT32), other->buffer(), + other->shapeInfo(), other->specialBuffer(), other->specialShapeInfo(), + tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), + tmp.specialShapeInfo()); + NDArray::registerSpecialUse({&tmp}, {this, other}); + + synchronize("NDArray::equalsTo"); + + if (tmp.e(0) != 0) return false; + + return true; + } +} + +////////////////////////////////////////////////////////////////////////// +template <> +std::string NDArray::e(const Nd4jLong i) const { + if (!isS()) + throw std::runtime_error("Can't get std::string out of non-string array"); + + if (i == lengthOf()) + throw std::runtime_error("Can't get std::string for index out of range"); + + if (this->dataType() == DataType::UTF16) { + auto u16 = this->e(i); + std::string s; + StringUtils::u16StringToU8String(u16, s); + return s; + } + + if (this->dataType() == DataType::UTF32) { + auto u32 = this->e(i); + std::string s; + StringUtils::u32StringToU8String(u32, s); + return s; + } + + NDArray::preparePrimaryUse({}, {this}); + + auto offsets = bufferAsT(); + auto offsetsLength = ShapeUtils::stringBufferHeaderRequirements(lengthOf()); + auto start = offsets[i]; + auto end = offsets[i + 1]; + auto data = bufferAsT() + offsetsLength + start; + + std::string r(reinterpret_cast(data), (end - start)); + + registerPrimaryUse({}, {this}); + + return r; +} + +template <> +std::u16string NDArray::e(const Nd4jLong i) const { + if (!isS()) + throw std::runtime_error( + "Can't get std::u16string out of non-string array"); + + if (i == lengthOf()) + throw std::runtime_error("Can't get std::u16string for index out of range"); + + if (this->dataType() == DataType::UTF8) { + auto u = this->e(i); + std::u16string s; + StringUtils::u8StringToU16String(u, s); + return s; + } + + if (this->dataType() == DataType::UTF32) { + auto u32 = this->e(i); + std::u16string s; + StringUtils::u32StringToU16String(u32, s); + return s; + } + + NDArray::preparePrimaryUse({}, {this}); + + auto offsets = bufferAsT(); + Nd4jLong offsetsLength = + ShapeUtils::stringBufferHeaderRequirements(lengthOf()); + Nd4jLong start = offsets[i]; + Nd4jLong end = offsets[i + 1]; + auto data = bufferAsT() + offsetsLength + start; + + std::u16string r(reinterpret_cast(data), + (end - start) / sizeof(char16_t)); + + registerPrimaryUse({}, {this}); + + return r; +} + +template <> +std::u32string NDArray::e(const Nd4jLong i) const { + if (!isS()) + throw std::runtime_error( + "Can't get std::u32string out of non-string array"); + + if (i == lengthOf()) + throw std::runtime_error("Can't get std::u32string for index out of range"); + + if (this->dataType() == DataType::UTF8) { + auto u = this->e(i); + std::u32string s; + StringUtils::u8StringToU32String(u, s); + return s; + } + + if (this->dataType() == DataType::UTF16) { + auto u16 = this->e(i); + std::u32string s; + StringUtils::u16StringToU32String(u16, s); + return s; + } + + NDArray::preparePrimaryUse({}, {this}); + + auto offsets = bufferAsT(); + Nd4jLong offsetsLength = + ShapeUtils::stringBufferHeaderRequirements(lengthOf()); + Nd4jLong start = offsets[i]; + Nd4jLong end = offsets[i + 1]; + + auto data = bufferAsT() + offsetsLength + start; + + std::u32string r(reinterpret_cast(data), + (end - start) / sizeof(char32_t)); + + registerPrimaryUse({}, {this}); + + return r; +} + +////////////////////////////////////////////////////////////////////////// +template <> +utf8string NDArray::e(const Nd4jLong i) const { + if (!isS()) + throw std::runtime_error("This method is available for String arrays only"); + + auto rp = getOffset(i); + + syncToHost(); + tickReadHost(); + + return *(reinterpret_cast(buffer())[rp]); +} + +///////////////////////////////////////////////////////////////////////// +template +T NDArray::e(const Nd4jLong i) const { + const auto rp = getOffset(i); + + NDArray::preparePrimaryUse({}, {this}); + NDArray::registerPrimaryUse({}, {this}); + BUILD_SINGLE_PARTIAL_SELECTOR( + dataType(), return templatedGet<, T>(buffer(), rp), LIBND4J_TYPES); +} +BUILD_SINGLE_UNCHAINED_TEMPLATE(template SD_EXPORT, + NDArray::e(const Nd4jLong) const, + LIBND4J_TYPES); + +////////////////////////////////////////////////////////////////////////// +// Returns value from 2D matrix by coordinates/indexes +template +T NDArray::e(const Nd4jLong i, const Nd4jLong j) const { + if (rankOf() != 2 || i >= shapeOf()[0] || j >= shapeOf()[1]) + throw std::invalid_argument( + "NDArray::e(i,j): one of input indexes is out of array length or " + "rank!=2 !"); + + const Nd4jLong coords[2] = {i, j}; + const auto xOffset = shape::getOffset(shapeInfo(), coords); + + NDArray::preparePrimaryUse({}, {this}); + NDArray::registerPrimaryUse({}, {this}); + + BUILD_SINGLE_PARTIAL_SELECTOR( + dataType(), return templatedGet<, T>(buffer(), xOffset), LIBND4J_TYPES); + + return static_cast(119); +} +BUILD_SINGLE_UNCHAINED_TEMPLATE(template SD_EXPORT, + NDArray::e(const Nd4jLong, const Nd4jLong) + const, + LIBND4J_TYPES); + +////////////////////////////////////////////////////////////////////////// +// returns value from 3D tensor by coordinates +template +T NDArray::e(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k) const { + if (rankOf() != 3 || i >= shapeOf()[0] || j >= shapeOf()[1] || + k >= shapeOf()[2]) + throw std::invalid_argument( + "NDArray::e(i,j,k): one of input indexes is out of array length or " + "rank!=3 !"); + + const Nd4jLong coords[3] = {i, j, k}; + const auto xOffset = shape::getOffset(shapeInfo(), coords); + + NDArray::preparePrimaryUse({}, {this}); + NDArray::registerPrimaryUse({}, {this}); + + BUILD_SINGLE_PARTIAL_SELECTOR( + dataType(), return templatedGet<, T>(buffer(), xOffset), LIBND4J_TYPES); + + return static_cast(119); +} +BUILD_SINGLE_UNCHAINED_TEMPLATE(template SD_EXPORT, + NDArray::e(const Nd4jLong, const Nd4jLong, + const Nd4jLong) const, + LIBND4J_TYPES); + +////////////////////////////////////////////////////////////////////////// +// returns value from 3D tensor by coordinates +template +T NDArray::e(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, + const Nd4jLong l) const { + if (rankOf() != 4 || i >= shapeOf()[0] || j >= shapeOf()[1] || + k >= shapeOf()[2] || l >= shapeOf()[3]) + throw std::invalid_argument( + "NDArray::e(i,j,k,l): one of input indexes is out of array length or " + "rank!=4 !"); + + const Nd4jLong coords[4] = {i, j, k, l}; + const auto xOffset = shape::getOffset(shapeInfo(), coords); + + NDArray::preparePrimaryUse({}, {this}); + NDArray::registerPrimaryUse({}, {this}); + + BUILD_SINGLE_PARTIAL_SELECTOR( + dataType(), return templatedGet<, T>(buffer(), xOffset), LIBND4J_TYPES); + + return static_cast(119); +} +BUILD_SINGLE_UNCHAINED_TEMPLATE(template SD_EXPORT, + NDArray::e(const Nd4jLong, const Nd4jLong, + const Nd4jLong, const Nd4jLong) + const, + LIBND4J_TYPES); + +////////////////////////////////////////////////////////////////////////// +NDArray NDArray::e(const Nd4jLong i) const { + const auto offset = getOffset(i); + + NDArray scalar(dataType(), getContext()); + + scalar.copyBuffersContinuouslyFrom(*this, sizeOfT(), 0, + bufferOffset() + offset); + + return scalar; +} + +////////////////////////////////////////////////////////////////////////// +// perform array transformation +void NDArray::applyTransform(sd::transform::FloatOps op, NDArray& target, + ExtraArguments* extraParams) { + if (isS()) + throw std::runtime_error( + "NDArray::applyTransform FloatOps: you can't use this method on String " + "array!"); + + if (!target.isR()) + throw std::runtime_error( + "NDArray::applyTransform FloatOps: target array must have one of FLOAT " + "types"); + + NDArray::prepareSpecialUse({&target}, {this}); + NativeOpExecutioner::execTransformFloat( + getContext(), op, buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), target.buffer(), target.shapeInfo(), + target.specialBuffer(), target.specialShapeInfo(), + extraParams != nullptr ? extraParams->argumentsAsT(target.dataType()) + : nullptr, + nullptr, nullptr); + NDArray::registerSpecialUse({&target}, {this}); +} + +//////////////////////////////////////////////////////////////////////// +void NDArray::applyTransform(sd::transform::AnyOps op, NDArray& target, + ExtraArguments* extraParams) { + if (isS()) + throw std::runtime_error( + "NDArray::applyTransform AnyOps: you can't use this method on String " + "array!"); + + NDArray::prepareSpecialUse({&target}, {this}); + NativeOpExecutioner::execTransformAny( + getContext(), op, buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), target.buffer(), target.shapeInfo(), + target.specialBuffer(), target.specialShapeInfo(), + extraParams != nullptr ? extraParams->argumentsAsT(target.dataType()) + : nullptr, + nullptr, nullptr); + NDArray::registerSpecialUse({&target}, {this}); +} + +//////////////////////////////////////////////////////////////////////// +void NDArray::applyTransform(sd::transform::SameOps op, NDArray& target, + ExtraArguments* extraParams) { + if (isS()) + throw std::runtime_error( + "NDArray::applyTransform SameOps: you can't use this method on String " + "array!"); + + if (target.dataType() != dataType()) + throw std::runtime_error( + "NDArray::applyTransform SameOps: target array must have the same data " + "type as original array"); + + NDArray::prepareSpecialUse({&target}, {this}); + NativeOpExecutioner::execTransformSame( + getContext(), op, buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), target.buffer(), target.shapeInfo(), + target.specialBuffer(), target.specialShapeInfo(), + extraParams != nullptr ? extraParams->argumentsAsT(target.dataType()) + : nullptr, + nullptr, nullptr); + NDArray::registerSpecialUse({&target}, {this}); +} + +//////////////////////////////////////////////////////////////////////// +void NDArray::applyTransform(sd::transform::StrictOps op, NDArray& target, + ExtraArguments* extraParams) { + if (isS()) + throw std::runtime_error( + "NDArray::applyTransform StrictOps: you can't use this method on " + "String array!"); + + if (!this->isR() || !target.isR() || (this->dataType() != target.dataType())) + throw std::runtime_error( + "NDArray::applyTransform StrictOps: both Source and Target array must " + "have same FLOAT type !"); + + NDArray::prepareSpecialUse({&target}, {this}); + NativeOpExecutioner::execTransformStrict( + getContext(), op, buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), target.buffer(), target.shapeInfo(), + target.specialBuffer(), target.specialShapeInfo(), + extraParams != nullptr ? extraParams->argumentsAsT(target.dataType()) + : nullptr, + nullptr, nullptr); + NDArray::registerSpecialUse({&target}, {this}); +} + +//////////////////////////////////////////////////////////////////////// +void NDArray::applyTransform(sd::transform::BoolOps op, NDArray& target, + ExtraArguments* extraParams) { + if (isS()) + throw std::runtime_error( + "NDArray::applyTransform BoolOps: you can't use this method on String " + "array!"); + + if (!target.isB()) + throw std::runtime_error( + "NDArray::applyTransform BoolOps: target array must have one of BOOL " + "types"); + + NDArray::prepareSpecialUse({&target}, {this}); + NativeOpExecutioner::execTransformBool( + getContext(), op, buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), target.buffer(), target.shapeInfo(), + target.specialBuffer(), target.specialShapeInfo(), + extraParams != nullptr ? extraParams->argumentsAsT(target.dataType()) + : nullptr, + nullptr, nullptr); + NDArray::registerSpecialUse({&target}, {this}); +} + +//////////////////////////////////////////////////////////////////////// +NDArray NDArray::transform(sd::transform::FloatOps op, + void* extraParams) const& { + if (isS()) + throw std::runtime_error( + "NDArray::transform FloatOps: you can't use this method on String " + "array!"); + + NDArray result(ordering(), getShapeAsVector(), + DataTypeUtils::pickFloatingType(dataType()), getContext()); + + NDArray::prepareSpecialUse({&result}, {this}); + NativeOpExecutioner::execTransformFloat( + getContext(), op, buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), result.buffer(), result.shapeInfo(), + result.specialBuffer(), result.specialShapeInfo(), extraParams, nullptr, + nullptr); + NDArray::registerSpecialUse({&result}, {this}); + + return result; +} + +//////////////////////////////////////////////////////////////////////// +NDArray NDArray::transform(sd::transform::FloatOps op, void* extraParams) && { + if (isS()) + throw std::runtime_error( + "NDArray::transform SameOps: you can't use this method on String " + "array!"); + + NDArray::prepareSpecialUse({this}, {this}); + NativeOpExecutioner::execTransformFloat( + getContext(), op, buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), extraParams, nullptr, nullptr); + NDArray::registerSpecialUse({this}, {this}); + + return std::move(*this); +} + +//////////////////////////////////////////////////////////////////////// +NDArray NDArray::transform(sd::transform::SameOps op, + void* extraParams) const& { + if (isS()) + throw std::runtime_error( + "NDArray::transform SameOps: you can't use this method on String " + "array!"); + + NDArray result(shapeInfo(), false, getContext()); + + NDArray::prepareSpecialUse({&result}, {this}); + NativeOpExecutioner::execTransformSame( + getContext(), op, buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), result.buffer(), result.shapeInfo(), + result.specialBuffer(), result.specialShapeInfo(), extraParams, nullptr, + nullptr); + NDArray::registerSpecialUse({&result}, {this}); + + return result; +} + +//////////////////////////////////////////////////////////////////////// +NDArray NDArray::transform(sd::transform::SameOps op, void* extraParams) && { + if (isS()) + throw std::runtime_error( + "NDArray::transform SameOps: you can't use this method on String " + "array!"); + + NDArray::prepareSpecialUse({this}, {this}); + NativeOpExecutioner::execTransformSame( + getContext(), op, buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), extraParams, nullptr, nullptr); + NDArray::registerSpecialUse({this}, {this}); + + return std::move(*this); +} + +//////////////////////////////////////////////////////////////////////// +NDArray NDArray::transform(sd::transform::StrictOps op, + void* extraParams) const& { + if (!this->isR()) + throw std::runtime_error("Source array must have one of FLOAT types"); + + NDArray result(shapeInfo(), false, getContext()); + + NDArray::prepareSpecialUse({&result}, {this}); + NativeOpExecutioner::execTransformStrict( + getContext(), op, buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), result.buffer(), result.shapeInfo(), + result.specialBuffer(), result.specialShapeInfo(), extraParams, nullptr, + nullptr); + NDArray::registerSpecialUse({&result}, {this}); + + return result; +} + +//////////////////////////////////////////////////////////////////////// +NDArray NDArray::transform(sd::transform::StrictOps op, void* extraParams) && { + if (!this->isR()) + throw std::runtime_error("Source array must have one of FLOAT types"); + + NDArray::prepareSpecialUse({this}, {this}); + NativeOpExecutioner::execTransformStrict( + getContext(), op, buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), extraParams, nullptr, nullptr); + NDArray::registerSpecialUse({this}, {this}); + + return std::move(*this); +} + +//////////////////////////////////////////////////////////////////////// +NDArray NDArray::transform(sd::transform::BoolOps op, + void* extraParams) const& { + if (isS()) + throw std::runtime_error( + "NDArray::transform BoolOps: you can't use this method on String " + "array!"); + + NDArray result(ordering(), getShapeAsVector(), sd::DataType::BOOL, + getContext()); + + NDArray::prepareSpecialUse({&result}, {this}); + NativeOpExecutioner::execTransformBool( + getContext(), op, buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), result.buffer(), result.shapeInfo(), + result.specialBuffer(), result.specialShapeInfo(), extraParams, nullptr, + nullptr); + NDArray::registerSpecialUse({&result}, {this}); + + return result; +} + +//////////////////////////////////////////////////////////////////////// +NDArray NDArray::transform(sd::transform::BoolOps op, void* extraParams) && { + if (isS()) + throw std::runtime_error( + "NDArray::transform BoolOps: you can't use this method on String " + "array!"); + + NDArray::prepareSpecialUse({this}, {this}); + NativeOpExecutioner::execTransformBool( + getContext(), op, buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), extraParams, nullptr, nullptr); + NDArray::registerSpecialUse({this}, {this}); + + return std::move(*this); +} + +////////////////////////////////////////////////////////////////////////// +void NDArray::applyScalarArr(sd::scalar::Ops op, const NDArray& scalar, + NDArray& target, ExtraArguments* extraParams) { + if (isS()) + throw std::runtime_error( + "NDArray::applyScalarArr: you can't use this method on String array!"); + if (scalar.lengthOf() != 1) + throw std::invalid_argument( + "NDArray::applyScalarArr method: operand is not a scalar!"); + + if (target.dataType() != DataTypeUtils::pickPairwiseResultType( + shapeInfo(), scalar.shapeInfo()) && + !(target.dataType() == dataType() || + target.dataType() == scalar.dataType())) + throw std::invalid_argument( + "NDArray::applyScalarArr method: wrong type of target array!"); + + NDArray::prepareSpecialUse({&target}, {this, &scalar}); + NativeOpExecutioner::execScalar( + getContext(), op, buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), target.buffer(), target.shapeInfo(), + target.specialBuffer(), target.specialShapeInfo(), scalar.buffer(), + scalar.shapeInfo(), scalar.specialBuffer(), scalar.specialShapeInfo(), + extraParams != nullptr ? extraParams->argumentsAsT(target.dataType()) + : nullptr); + NDArray::registerSpecialUse({&target}, {this, &scalar}); +} + +////////////////////////////////////////////////////////////////////////// +void NDArray::applyScalarArr(sd::scalar::BoolOps op, const NDArray& scalar, + NDArray& target, + ExtraArguments* extraParams) const { + if (isS()) + throw std::runtime_error( + "NDArray::applyScalarArr BoolOps: you can't use this method on String " + "array!"); + if (!target.isB()) + throw std::invalid_argument( + "NDArray::applyScalarArr bool method: target has not bool type!"); + if (dataType() != scalar.dataType()) { + nd4j_printf( + "NDArray::applyScalarArr BoolOps: this dtype: [%i]; scalar dtype: " + "[%i]\n", + this->dataType(), scalar.dataType()); + throw std::invalid_argument( + "NDArray::applyScalarArr bool method: this and scalar arrays must have " + "the same type!"); + } + + NDArray::prepareSpecialUse({&target}, {this, &scalar}); + NativeOpExecutioner::execScalarBool( + getContext(), op, buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), target.buffer(), target.shapeInfo(), + target.specialBuffer(), target.specialShapeInfo(), scalar.buffer(), + scalar.shapeInfo(), scalar.specialBuffer(), scalar.specialShapeInfo(), + extraParams != nullptr ? extraParams->argumentsAsT(target.dataType()) + : nullptr); + NDArray::registerSpecialUse({&target}, {this, &scalar}); +} + +////////////////////////////////////////////////////////////////////////// +void NDArray::applyScalarArr(sd::scalar::IntOps op, const NDArray& scalar, + NDArray& target, + ExtraArguments* extraParams) const { + if (isS()) + throw std::runtime_error( + "NDArray::applyScalarArr IntOps: you can't use this method on String " + "array!"); + + if (target.dataType() != this->dataType()) + throw std::invalid_argument( + "NDArray::applyScalarArr int method: target has not bool type!"); + if (dataType() != scalar.dataType()) { + nd4j_printf( + "NDArray::applyScalarArr IntOps: this dtype: [%i]; scalar dtype: " + "[%i]\n", + this->dataType(), scalar.dataType()); + throw std::invalid_argument( + "NDArray::applyScalarArr int method: this and scalar arrays must have " + "the same type!"); + } + + NDArray::prepareSpecialUse({&target}, {this, &scalar}); + NativeOpExecutioner::execScalarInt( + getContext(), op, buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), target.buffer(), target.shapeInfo(), + target.specialBuffer(), target.specialShapeInfo(), scalar.buffer(), + scalar.shapeInfo(), scalar.specialBuffer(), scalar.specialShapeInfo(), + extraParams != nullptr ? extraParams->argumentsAsT(target.dataType()) + : nullptr); + NDArray::registerSpecialUse({&target}, {this, &scalar}); +} + +//////////////////////////////////////////////////////////////////////// +template +void NDArray::applyScalar(sd::scalar::IntOps op, const T scalar, + NDArray& target, ExtraArguments* extraParams) const { + NDArray scalarArr = + NDArrayFactory::create(this->dataType(), scalar, getContext()); + applyScalarArr(op, scalarArr, target, extraParams); +} + +template <> +SD_EXPORT void NDArray::applyScalar(sd::scalar::IntOps op, + const NDArray& scalar, NDArray& target, + ExtraArguments* extraParams) const { + throw std::runtime_error( + "NDArray::applyScalar method: do not use me!"); +} +template SD_EXPORT void NDArray::applyScalar( + sd::scalar::IntOps op, const double scalar, NDArray& target, + ExtraArguments* extraParams) const; +template SD_EXPORT void NDArray::applyScalar( + sd::scalar::IntOps op, const float scalar, NDArray& target, + ExtraArguments* extraParams) const; +template SD_EXPORT void NDArray::applyScalar( + sd::scalar::IntOps op, const float16 scalar, NDArray& target, + ExtraArguments* extraParams) const; +template SD_EXPORT void NDArray::applyScalar( + sd::scalar::IntOps op, const bfloat16 scalar, NDArray& target, + ExtraArguments* extraParams) const; +template SD_EXPORT void NDArray::applyScalar( + sd::scalar::IntOps op, const Nd4jLong scalar, NDArray& target, + ExtraArguments* extraParams) const; +template SD_EXPORT void NDArray::applyScalar( + sd::scalar::IntOps op, const int scalar, NDArray& target, + ExtraArguments* extraParams) const; +template SD_EXPORT void NDArray::applyScalar( + sd::scalar::IntOps op, const int16_t scalar, NDArray& target, + ExtraArguments* extraParams) const; +template SD_EXPORT void NDArray::applyScalar( + sd::scalar::IntOps op, const int8_t scalar, NDArray& target, + ExtraArguments* extraParams) const; +template SD_EXPORT void NDArray::applyScalar( + sd::scalar::IntOps op, const uint8_t scalar, NDArray& target, + ExtraArguments* extraParams) const; +template SD_EXPORT void NDArray::applyScalar( + sd::scalar::IntOps op, const bool scalar, NDArray& target, + ExtraArguments* extraParams) const; + +//////////////////////////////////////////////////////////////////////// +template +void NDArray::applyScalar(sd::scalar::Ops op, const T scalar, NDArray& target, + ExtraArguments* extraParams) { + auto scalarArr = + NDArrayFactory::create(dataType(), scalar, this->getContext()); + applyScalarArr(op, scalarArr, target, extraParams); +} +template <> +SD_EXPORT void NDArray::applyScalar(sd::scalar::Ops op, const NDArray& scalar, + NDArray& target, + ExtraArguments* extraParams) { + throw std::runtime_error( + "NDArray::applyScalar method: do not use me!"); +} +template SD_EXPORT void NDArray::applyScalar(sd::scalar::Ops op, + const double scalar, + NDArray& target, + ExtraArguments* extraParams); +template SD_EXPORT void NDArray::applyScalar(sd::scalar::Ops op, + const float scalar, + NDArray& target, + ExtraArguments* extraParams); +template SD_EXPORT void NDArray::applyScalar(sd::scalar::Ops op, + const float16 scalar, + NDArray& target, + ExtraArguments* extraParams); +template SD_EXPORT void NDArray::applyScalar(sd::scalar::Ops op, + const bfloat16 scalar, + NDArray& target, + ExtraArguments* extraParams); +template SD_EXPORT void NDArray::applyScalar(sd::scalar::Ops op, + const Nd4jLong scalar, + NDArray& target, + ExtraArguments* extraParams); +template SD_EXPORT void NDArray::applyScalar(sd::scalar::Ops op, + const int scalar, NDArray& target, + ExtraArguments* extraParams); +template SD_EXPORT void NDArray::applyScalar(sd::scalar::Ops op, + const int16_t scalar, + NDArray& target, + ExtraArguments* extraParams); +template SD_EXPORT void NDArray::applyScalar(sd::scalar::Ops op, + const int8_t scalar, + NDArray& target, + ExtraArguments* extraParams); +template SD_EXPORT void NDArray::applyScalar(sd::scalar::Ops op, + const uint8_t scalar, + NDArray& target, + ExtraArguments* extraParams); +template SD_EXPORT void NDArray::applyScalar(sd::scalar::Ops op, + const bool scalar, NDArray& target, + ExtraArguments* extraParams); + +//////////////////////////////////////////////////////////////////////// +template +void NDArray::applyScalar(sd::scalar::BoolOps op, const T scalar, + NDArray& target, ExtraArguments* extraParams) const { + NDArray scalarArr = NDArrayFactory::create(scalar, getContext()); + applyScalarArr(op, scalarArr, target, extraParams); +} + +template <> +SD_EXPORT void NDArray::applyScalar(sd::scalar::BoolOps op, + const NDArray& scalar, NDArray& target, + ExtraArguments* extraParams) const { + throw std::runtime_error( + "NDArray::applyScalar method: do not use me!"); +} +template SD_EXPORT void NDArray::applyScalar( + sd::scalar::BoolOps op, const double scalar, NDArray& target, + ExtraArguments* extraParams) const; +template SD_EXPORT void NDArray::applyScalar( + sd::scalar::BoolOps op, const float scalar, NDArray& target, + ExtraArguments* extraParams) const; +template SD_EXPORT void NDArray::applyScalar( + sd::scalar::BoolOps op, const float16 scalar, NDArray& target, + ExtraArguments* extraParams) const; +template SD_EXPORT void NDArray::applyScalar( + sd::scalar::BoolOps op, const bfloat16 scalar, NDArray& target, + ExtraArguments* extraParams) const; +template SD_EXPORT void NDArray::applyScalar( + sd::scalar::BoolOps op, const Nd4jLong scalar, NDArray& target, + ExtraArguments* extraParams) const; +template SD_EXPORT void NDArray::applyScalar( + sd::scalar::BoolOps op, const int scalar, NDArray& target, + ExtraArguments* extraParams) const; +template SD_EXPORT void NDArray::applyScalar( + sd::scalar::BoolOps op, const int16_t scalar, NDArray& target, + ExtraArguments* extraParams) const; +template SD_EXPORT void NDArray::applyScalar( + sd::scalar::BoolOps op, const int8_t scalar, NDArray& target, + ExtraArguments* extraParams) const; +template SD_EXPORT void NDArray::applyScalar( + sd::scalar::BoolOps op, const uint8_t scalar, NDArray& target, + ExtraArguments* extraParams) const; +template SD_EXPORT void NDArray::applyScalar( + sd::scalar::BoolOps op, const bool scalar, NDArray& target, + ExtraArguments* extraParams) const; + +//////////////////////////////////////////////////////////////////////// +void NDArray::applyIndexReduce(sd::indexreduce::Ops op, NDArray& target, + const std::vector& dimensions, + const ExtraArguments* extraParams) const { + if (isS()) + throw std::runtime_error( + "NDArray::applyIndexReduce: you can't use this method on String " + "array!"); + + if (target.dataType() != sd::DataType::INT64 && + target.dataType() != sd::DataType::INT32) + throw std::runtime_error( + "NDArray::applyIndexReduce operations return INT32/INT64"); + + void* params = extraParams != nullptr + ? const_cast(extraParams) + ->argumentsAsT(this->dataType()) + : nullptr; + + NDArray::prepareSpecialUse({&target}, {this}); + + if (target.lengthOf() == 1) { + NativeOpExecutioner::execIndexReduceScalar( + getContext(), op, buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), params, target.buffer(), target.shapeInfo(), + target.specialBuffer(), target.specialShapeInfo()); + } else { + std::vector copy = dimensions; + shape::checkDimensions(rankOf(), copy); + auto pDims = + sd::Environment::getInstance().isCPU() ? copy.data() : nullptr; + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions( + shapeInfo(), copy); + NativeOpExecutioner::execIndexReduce( + getContext(), op, buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), params, target.buffer(), target.shapeInfo(), + target.specialBuffer(), target.specialShapeInfo(), pDims, copy.size(), + packX.platformShapeInfo(), packX.platformOffsets()); + synchronize("NDArray::applyIndexReduce"); + } + + registerSpecialUse({&target}, {this}); +} + +//////////////////////////////////////////////////////////////////////// +// reduce dimensions in this array relying on index operations +NDArray NDArray::applyIndexReduce(sd::indexreduce::Ops op, + const std::vector& dimensions, + const ExtraArguments* extraParams) const { + std::vector copy = dimensions; + auto newShape = + ShapeUtils::evalReduceShapeInfo('c', copy, *this, DataType::INT64, false, + false, getContext()->getWorkspace()); + NDArray result(newShape, true, getContext()); + + applyIndexReduce(op, result, copy, extraParams); + + return result; +} + +//////////////////////////////////////////////////////////////////////// +// apply reduce3 operations to this and other array, return result in new output +// array +NDArray NDArray::applyReduce3(sd::reduce3::Ops op, const NDArray& other, + const ExtraArguments* extraParams) const { + if (isS()) + throw std::runtime_error( + "NDArray::applyReduce3 method: you can't use this method on String " + "array!"); + if (dataType() != other.dataType()) + throw std::runtime_error( + "NDArray::applyReduce3 method: the types of this and other arrays must " + "be the same !"); + // check shapes consistency + if (!isSameShape(other)) + throw std::runtime_error( + "NDArray::applyReduce3 method: the shapes of this and other arrays " + "must be the same !"); + // create shapeInfo for scalar + auto newShape = ShapeBuilders::createScalarShapeInfo( + DataTypeUtils::pickFloatingType(dataType()), + getContext()->getWorkspace()); + // create output array (scalar) + NDArray result(newShape, true, getContext()); + RELEASE(newShape, getContext()->getWorkspace()); + // create dynamic array of extra parameters if array extraParams is empty + // (==nullptr) + void* params = + extraParams != nullptr + ? const_cast(extraParams)->argumentsAsT(dataType()) + : nullptr; + + NDArray::prepareSpecialUse({&result}, {this, &other}); + NativeOpExecutioner::execReduce3Scalar( + getContext(), op, buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), params, other.buffer(), other.shapeInfo(), + other.specialBuffer(), other.specialShapeInfo(), result.buffer(), + result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo()); + NDArray::registerSpecialUse({&result}, {this, &other}); + + return result; +} + +//////////////////////////////////////////////////////////////////////// +// apply reduce3 (exec) operations to this and other array, return result in new +// output array +NDArray NDArray::applyReduce3(sd::reduce3::Ops op, const NDArray& other, + const std::vector& dimensions, + const ExtraArguments* extraParams) const { + if (isS()) + throw std::runtime_error( + "NDArray::applyReduce3: you can't use this method on String array!"); + if (dataType() != other.dataType()) + throw std::runtime_error( + "NDArray::applyReduce3 method: the types of this and other arrays must " + "be the same !"); + + std::vector copy(dimensions); + shape::checkDimensions(rankOf(), copy); + shape::checkDimensions(other.rankOf(), copy); + + auto newShape = ShapeUtils::evalReduceShapeInfo( + 'c', copy, *this, DataTypeUtils::pickFloatingType(dataType()), false, + false, getContext()->getWorkspace()); + NDArray result(newShape, true, getContext()); + // create temporary dynamic array of extra parameters if array extraParams is + // empty (==nullptr) + void* params = + extraParams != nullptr + ? const_cast(extraParams)->argumentsAsT(dataType()) + : nullptr; + + NDArray::prepareSpecialUse({&result}, {this, &other}); + + // perform calculations + if (rankOf() == copy.size() && other.rankOf() == copy.size()) { + NativeOpExecutioner::execReduce3Scalar( + getContext(), op, buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), params, other.buffer(), other.shapeInfo(), + other.specialBuffer(), other.specialShapeInfo(), result.buffer(), + result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo()); + } else { + auto pDims = + sd::Environment::getInstance().isCPU() ? copy.data() : nullptr; + + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions( + shapeInfo(), copy); + auto packY = sd::ConstantTadHelper::getInstance().tadForDimensions( + other.shapeInfo(), copy); + + if (!shape::equalsSoft(packX.primaryShapeInfo(), + packY.primaryShapeInfo()) || + (packX.numberOfTads() != packY.numberOfTads() && + packX.numberOfTads() != 1 && packY.numberOfTads() != 1)) + throw std::runtime_error( + "NDArray::applyReduce3 cuda method: arrays tads are inconsistent !"); + + NativeOpExecutioner::execReduce3( + getContext(), op, buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), params, other.buffer(), other.shapeInfo(), + other.specialBuffer(), other.specialShapeInfo(), result.buffer(), + result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo(), + pDims, copy.size(), packX.platformShapeInfo(), packX.platformOffsets(), + packY.platformShapeInfo(), packY.platformOffsets()); + } + + registerSpecialUse({&result}, {this, &other}); + + return result; +} + +//////////////////////////////////////////////////////////////////////// +// apply reduce3 (execAll) operations to this and other array, return result in +// new output array +NDArray NDArray::applyAllReduce3(sd::reduce3::Ops op, const NDArray& other, + const std::vector& dimensions, + const ExtraArguments* extraParams) const { + if (isS()) + throw std::runtime_error( + "NDArray::applyAllReduce3: you can't use this method on String array!"); + if (dataType() != other.dataType()) + throw std::runtime_error( + "NDArray::applyAllReduce3 method: the types of this and other arrays " + "must be the same !"); + + // be careful, copy array may undergo changes (sort, transformation of + // negative dimensions to positive, duplicates removing ) + std::vector copy(dimensions); + shape::checkDimensions(rankOf(), copy); + shape::checkDimensions(other.rankOf(), copy); + + auto packX = + ConstantTadHelper::getInstance().tadForDimensions(shapeInfo(), copy); + auto packY = ConstantTadHelper::getInstance().tadForDimensions( + other.shapeInfo(), copy); + + // check tads shapes + if (!shape::equalsSoft(packX.primaryShapeInfo(), packY.primaryShapeInfo())) + throw std::runtime_error( + "NDArray::applyAllReduce3 method: the shapes of array tads are " + "different !"); + + // set newShape for output array + auto newShape = ConstantShapeHelper::getInstance().createShapeInfo( + DataTypeUtils::pickFloatingType(dataType()), 'c', + {packX.numberOfTads(), packY.numberOfTads()}); + + // create output array + NDArray result(newShape, true, getContext()); + + // create dynamic array of extra parameters if array extraParams is empty + // (==nullptr) + void* params = + extraParams != nullptr + ? const_cast(extraParams)->argumentsAsT(dataType()) + : nullptr; + + auto pDims = sd::Environment::getInstance().isCPU() ? copy.data() : nullptr; + + NDArray::prepareSpecialUse({&result}, {this, &other}); + NativeOpExecutioner::execReduce3All( + getContext(), op, buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), params, other.buffer(), other.shapeInfo(), + other.specialBuffer(), other.specialShapeInfo(), result.buffer(), + result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo(), + pDims, copy.size(), packX.platformShapeInfo(), packX.platformOffsets(), + packY.platformShapeInfo(), packY.platformOffsets()); + NDArray::registerSpecialUse({&result}, {this, &other}); + + return result; +} + +////////////////////////////////////////////////////////////////////////// +// method reduces array by excluding its shapes along axes present in dimensions +// vector +void NDArray::reduceAlongDimension(sd::reduce::FloatOps op, NDArray& target, + const std::vector& dimensions, + const bool keepDims, + const bool supportOldShapes, + const bool checkTargetShape) const { + if (isS()) + throw std::runtime_error( + "NDArray::reduceAlongDimension FloatOps: you can't use this method on " + "String array!"); + if (!target.isR()) + throw std::invalid_argument( + "NDArray::reduceAlongDimension FloatOps: requires target array to be " + "present and have type form real space!"); + + std::vector copy(dimensions); + + if (checkTargetShape) { + auto newShape = ShapeUtils::evalReduceShapeInfo( + target.ordering(), copy, *this, keepDims, supportOldShapes, + getContext()->getWorkspace()); + if (!shape::shapeEquals(newShape, target.shapeInfo())) + throw std::runtime_error( + "NDArray::reduceAlongDimension FloatOps: wrong target shape!"); + } + + NDArray::prepareSpecialUse({&target}, {this}); + + if (rankOf() == copy.size() || copy.empty()) { + NativeOpExecutioner::execReduceFloatScalar( + getContext(), op, buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), nullptr, target.buffer(), target.shapeInfo(), + target.specialBuffer(), target.specialShapeInfo()); + } else { + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions( + shapeInfo(), copy); + NativeOpExecutioner::execReduceFloat( + getContext(), op, buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), nullptr, target.buffer(), target.shapeInfo(), + target.specialBuffer(), target.specialShapeInfo(), copy.data(), + copy.size(), packX.platformShapeInfo(), packX.platformOffsets()); + } + synchronize("NDArray::reduceAlongDimension FloatOps"); + + NDArray::registerSpecialUse({&target}, {this}); +} + +////////////////////////////////////////////////////////////////////////// +// method reduces array by excluding its shapes along axes present in dimensions +// vector +void NDArray::reduceAlongDimension(sd::reduce::SameOps op, NDArray& target, + const std::vector& dimensions, + const bool keepDims, + const bool supportOldShapes, + const bool checkTargetShape) const { + if (isS()) + throw std::runtime_error( + "NDArray::reduceAlongDimension SameOps: you can't use this method on " + "String array!"); + if (target.dataType() != dataType()) + throw std::runtime_error( + "NDArray::reduceAlongDimension SameOps: requires target array to be " + "present and have same dtype as input"); + + std::vector copy(dimensions); + + if (checkTargetShape) { + auto newShape = ShapeUtils::evalReduceShapeInfo( + target.ordering(), copy, *this, keepDims, supportOldShapes, + getContext()->getWorkspace()); + if (!shape::shapeEquals(newShape, target.shapeInfo())) + throw std::runtime_error( + "NDArray::reduceAlongDimension SameOps: wrong target shape!"); + } + + NDArray::prepareSpecialUse({&target}, {this}); + + if (rankOf() == copy.size() || copy.empty()) { + NativeOpExecutioner::execReduceSameScalar( + getContext(), op, buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), nullptr, target.buffer(), target.shapeInfo(), + target.specialBuffer(), target.specialShapeInfo()); + } else { // if (!isEmpty()) { + auto pDims = + sd::Environment::getInstance().isCPU() ? copy.data() : nullptr; + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions( + this->shapeInfo(), copy); + NativeOpExecutioner::execReduceSame( + getContext(), op, buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), nullptr, target.buffer(), target.shapeInfo(), + target.specialBuffer(), target.specialShapeInfo(), pDims, copy.size(), + packX.platformShapeInfo(), packX.platformOffsets()); + } + synchronize("NDArray::reduceAlongDimension SameOps"); + + NDArray::registerSpecialUse({&target}, {this}); +} + +////////////////////////////////////////////////////////////////////////// +// method reduces array by excluding its shapes along axes present in dimensions +// vector +void NDArray::reduceAlongDimension(sd::reduce::LongOps op, NDArray& target, + const std::vector& dimensions, + const bool keepDims, + const bool supportOldShapes, + const bool checkTargetShape) const { + if (isS()) + throw std::runtime_error( + "NDArray::reduceAlongDimension LongOps: you can't use this method on " + "String array!"); + if (target.dataType() != DataType::INT64) + throw std::runtime_error( + "NDArray::reduceAlongDimension LongOps: requires target array to be " + "present and have type of INT64"); + + std::vector copy(dimensions); + + if (checkTargetShape) { + auto newShape = ShapeUtils::evalReduceShapeInfo( + target.ordering(), copy, *this, keepDims, supportOldShapes, + getContext()->getWorkspace()); + if (!shape::shapeEquals(newShape, target.shapeInfo())) + throw std::runtime_error( + "NDArray::reduceAlongDimension LongOps: wrong target shape!"); + } + + NDArray::prepareSpecialUse({&target}, {this}); + + if (rankOf() == copy.size() || copy.empty()) { + NativeOpExecutioner::execReduceLongScalar( + getContext(), op, buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), nullptr, target.buffer(), target.shapeInfo(), + target.specialBuffer(), target.specialShapeInfo()); + } else { + auto pDims = + sd::Environment::getInstance().isCPU() ? copy.data() : nullptr; + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions( + this->shapeInfo(), copy); + NativeOpExecutioner::execReduceLong( + getContext(), op, buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), nullptr, target.buffer(), target.shapeInfo(), + target.specialBuffer(), target.specialShapeInfo(), pDims, copy.size(), + packX.platformShapeInfo(), packX.platformOffsets()); + } + synchronize("NDArray::reduceAlongDimension LongOps"); + + NDArray::registerSpecialUse({&target}, {this}); +} + +////////////////////////////////////////////////////////////////////////// +// method reduces array by excluding its shapes along axes present in dimensions +// vector +void NDArray::reduceAlongDimension(sd::reduce::BoolOps op, NDArray& target, + const std::vector& dimensions, + const bool keepDims, + const bool supportOldShapes, + const bool checkTargetShape) const { + if (isS()) + throw std::runtime_error( + "NDArray::reduceAlongDimension BoolOps cuda: you can't use this method " + "on String array!"); + if (!target.isB()) + throw std::invalid_argument( + "NDArray::reduceAlongDimension BoolOps cuda: requires target array to " + "be present and have BOOL type!"); + + std::vector copy(dimensions); + + if (checkTargetShape) { + auto newShape = ShapeUtils::evalReduceShapeInfo( + target.ordering(), copy, *this, keepDims, supportOldShapes, + getContext()->getWorkspace()); + if (!shape::shapeEquals(newShape, target.shapeInfo())) + throw std::runtime_error( + "NDArray::reduceAlongDimension BoolOps cuda: wrong target shape!"); + } + + NDArray::prepareSpecialUse({&target}, {this}); + + if (rankOf() == copy.size() || copy.empty()) { + NativeOpExecutioner::execReduceBoolScalar( + getContext(), op, buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), nullptr, target.buffer(), target.shapeInfo(), + target.specialBuffer(), target.specialShapeInfo()); + } else { + auto pDims = + sd::Environment::getInstance().isCPU() ? copy.data() : nullptr; + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions( + this->shapeInfo(), copy); + NativeOpExecutioner::execReduceBool( + getContext(), op, buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), nullptr, target.buffer(), target.shapeInfo(), + target.specialBuffer(), target.specialShapeInfo(), pDims, copy.size(), + packX.platformShapeInfo(), packX.platformOffsets()); + } + synchronize("NDArray::reduceAlongDimension LongOps"); + + NDArray::registerSpecialUse({&target}, {this}); +} + +////////////////////////////////////////////////////////////////////////// +// This method sets value in linear buffer to position i +template +void NDArray::p(const Nd4jLong i, const T value) { + if (i >= lengthOf()) + throw std::invalid_argument( + "NDArray::p(i, value): input index is out of array length !"); + + auto rp = getOffset(i); + const void* pV = reinterpret_cast(const_cast(&value)); + + NDArray::preparePrimaryUse({this}, {}, true); + BUILD_SINGLE_PARTIAL_SELECTOR(this->dataType(), + templatedSet<, T>(this->buffer(), rp, pV), + LIBND4J_TYPES); + NDArray::registerPrimaryUse({this}, {}); +} + +template SD_EXPORT void NDArray::p(const Nd4jLong i, const double value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const float value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const float16 value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const bfloat16 value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const int value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const int8_t value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const uint8_t value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const uint16_t value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const uint32_t value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const uint64_t value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const int16_t value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const bool value); + +////////////////////////////////////////////////////////////////////////// +// This method sets value in 2D matrix to position i, j +template +void NDArray::p(const Nd4jLong i, const Nd4jLong j, const T value) { + if (rankOf() != 2 || i >= shapeOf()[0] || j >= shapeOf()[1]) + throw std::invalid_argument( + "NDArray:pe(i,j, value): one of input indexes is out of array length " + "or rank!=2 !"); + + void* p = reinterpret_cast(const_cast(&value)); + Nd4jLong coords[2] = {i, j}; + auto xOffset = shape::getOffset(shapeInfo(), coords); + + NDArray::preparePrimaryUse({this}, {}, true); + BUILD_SINGLE_PARTIAL_SELECTOR( + dataType(), templatedSet<, T>(this->buffer(), xOffset, p), LIBND4J_TYPES); + NDArray::registerPrimaryUse({this}, {}); +} +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, + const double value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, + const float value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, + const float16 value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, + const bfloat16 value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, + const Nd4jLong value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, + const int value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, + const int8_t value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, + const uint8_t value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, + const uint16_t value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, + const uint32_t value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, + const uint64_t value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, + const int16_t value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, + const bool value); + +////////////////////////////////////////////////////////////////////////// +// This method sets value in 3D matrix to position i,j,k +template +void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, + const T value) { + //(*this)(i,j,k) = value; + if (rankOf() != 3 || i >= shapeOf()[0] || j >= shapeOf()[1] || + k >= shapeOf()[2]) + throw std::invalid_argument( + "NDArray:pe(i,j,k, value): one of input indexes is out of array length " + "or rank!=3 !"); + + NDArray::preparePrimaryUse({this}, {}, true); + + void* p = reinterpret_cast(const_cast(&value)); + Nd4jLong coords[3] = {i, j, k}; + auto xOffset = shape::getOffset(shapeInfo(), coords); + BUILD_SINGLE_PARTIAL_SELECTOR( + dataType(), templatedSet<, T>(this->buffer(), xOffset, p), LIBND4J_TYPES); + NDArray::registerPrimaryUse({this}, {}); +} +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, + const Nd4jLong k, const double value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, + const Nd4jLong k, const float value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, + const Nd4jLong k, const float16 value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, + const Nd4jLong k, const bfloat16 value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, + const Nd4jLong k, const Nd4jLong value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, + const Nd4jLong k, const int value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, + const Nd4jLong k, const int8_t value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, + const Nd4jLong k, const uint8_t value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, + const Nd4jLong k, const uint16_t value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, + const Nd4jLong k, const uint32_t value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, + const Nd4jLong k, const uint64_t value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, + const Nd4jLong k, const int16_t value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, + const Nd4jLong k, const bool value); + +////////////////////////////////////////////////////////////////////////// +template +void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, + const Nd4jLong l, const T value) { + //(*this)(i,j,k) = value; + if (rankOf() != 4 || i >= shapeOf()[0] || j >= shapeOf()[1] || + k >= shapeOf()[2] || l >= shapeOf()[3]) + throw std::invalid_argument( + "NDArray::p(i,j,k,l, value): one of input indexes is out of array " + "length or rank!=4 !"); + + void* p = reinterpret_cast(const_cast(&value)); + Nd4jLong coords[4] = {i, j, k, l}; + auto xOffset = shape::getOffset(shapeInfo(), coords); + + NDArray::preparePrimaryUse({this}, {}, true); + BUILD_SINGLE_PARTIAL_SELECTOR( + dataType(), templatedSet<, T>(this->buffer(), xOffset, p), LIBND4J_TYPES); + NDArray::registerPrimaryUse({this}, {}); +} +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, + const Nd4jLong k, const Nd4jLong l, + const double value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, + const Nd4jLong k, const Nd4jLong l, + const float value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, + const Nd4jLong k, const Nd4jLong l, + const float16 value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, + const Nd4jLong k, const Nd4jLong l, + const bfloat16 value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, + const Nd4jLong k, const Nd4jLong l, + const Nd4jLong value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, + const Nd4jLong k, const Nd4jLong l, + const int value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, + const Nd4jLong k, const Nd4jLong l, + const int8_t value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, + const Nd4jLong k, const Nd4jLong l, + const uint8_t value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, + const Nd4jLong k, const Nd4jLong l, + const uint16_t value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, + const Nd4jLong k, const Nd4jLong l, + const uint32_t value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, + const Nd4jLong k, const Nd4jLong l, + const uint64_t value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, + const Nd4jLong k, const Nd4jLong l, + const int16_t value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, + const Nd4jLong k, const Nd4jLong l, + const bool value); + +//////////////////////////////////////////////////////////////////////// +void NDArray::p(const Nd4jLong i, const NDArray& scalar) { + if (scalar.lengthOf() != 1) + throw std::invalid_argument( + "NDArray::p method: input array must be scalar!"); + if (i >= _length) + throw std::invalid_argument( + "NDArray::p(i, NDArray_scalar): input index is out of array length !"); + + NDArray::preparePrimaryUse({this}, {&scalar}, true); + auto rp = getOffset(i); + BUILD_SINGLE_SELECTOR(scalar.dataType(), templatedSet, + (buffer(), rp, scalar.dataType(), scalar.buffer()), + LIBND4J_TYPES); + NDArray::registerPrimaryUse({this}, {&scalar}); +} + +//////////////////////////////////////////////////////////////////////// +void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, + const Nd4jLong l, const NDArray& scalar) { + if (scalar.lengthOf() != 1) + throw std::invalid_argument( + "NDArray::p method: input array must be scalar!"); + if (i >= _length) + throw std::invalid_argument( + "NDArray::p(i, NDArray_scalar): input index is out of array length !"); + + // void *p = reinterpret_cast(scalar.buffer()); + Nd4jLong coords[4] = {i, j, k, l}; + auto xOffset = shape::getOffset(shapeInfo(), coords); + + NDArray::preparePrimaryUse({this}, {&scalar}, true); + // BUILD_SINGLE_PARTIAL_SELECTOR(dataType(), templatedSet<, + // T>(this->buffer(), xOffset, p), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR( + scalar.dataType(), templatedSet, + (this->buffer(), xOffset, scalar.dataType(), scalar.buffer()), + LIBND4J_TYPES); + NDArray::registerPrimaryUse({this}, {&scalar}); +} + +////////////////////////////////////////////////////////////////////////// +void NDArray::addRowVector(const NDArray& row, NDArray& target) const { + if (isS()) + throw std::runtime_error( + "NDArray::addRowVector: you can't use this method on String array!"); + if (rankOf() != 2 || target.rankOf() != 2 || rows() != target.rows() || + columns() != target.columns() || !row.isRowVector() || + columns() != row.lengthOf()) + throw std::invalid_argument("NDArray::addRowVector: wrong arguments !"); + if (target.dataType() != + DataTypeUtils::pickPairwiseResultType(dataType(), row.dataType()) && + !(isR() && row.isR() && target.isR())) + throw std::invalid_argument( + "NDArray::addRowVector: wrong type of target array !"); + + int dimension = 1; + + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions( + this->shapeInfo(), dimension); + + NDArray::prepareSpecialUse({&target}, {this, &row}); + NativeOpExecutioner::execBroadcast( + getContext(), sd::broadcast::Ops::Add, buffer(), shapeInfo(), + specialBuffer(), specialShapeInfo(), row.buffer(), row.shapeInfo(), + row.specialBuffer(), row.specialShapeInfo(), target.buffer(), + target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), + nullptr, 1, packX.platformShapeInfo(), packX.platformOffsets(), nullptr, + nullptr); + NDArray::registerSpecialUse({&target}, {this, &row}); +} + +////////////////////////////////////////////////////////////////////////// +void NDArray::subRowVector(const NDArray& row, NDArray& target) const { + if (isS()) + throw std::runtime_error( + "NDArray::addRowVector: you can't use this method on String array!"); + if (rankOf() != 2 || target.rankOf() != 2 || rows() != target.rows() || + columns() != target.columns() || !row.isRowVector() || + columns() != row.lengthOf()) + throw std::invalid_argument("NDArray::addRowVector: wrong arguments !"); + if (target.dataType() != + DataTypeUtils::pickPairwiseResultType(dataType(), row.dataType()) && + !(isR() && row.isR() && target.isR())) + throw std::invalid_argument( + "NDArray::addRowVector: wrong type of target array !"); + + int dimension = 1; + + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions( + this->shapeInfo(), dimension); + + NDArray::prepareSpecialUse({&target}, {this, &row}); + NativeOpExecutioner::execBroadcast( + getContext(), sd::broadcast::Ops::Subtract, buffer(), shapeInfo(), + specialBuffer(), specialShapeInfo(), row.buffer(), row.shapeInfo(), + row.specialBuffer(), row.specialShapeInfo(), target.buffer(), + target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), + &dimension, 1, packX.platformShapeInfo(), packX.platformOffsets(), + nullptr, nullptr); + NDArray::registerSpecialUse({&target}, {this, &row}); +} + +////////////////////////////////////////////////////////////////////////// +void NDArray::mulRowVector(const NDArray& row, NDArray& target) const { + if (isS()) + throw std::runtime_error( + "NDArray::mulRowVector: you can't use this method on String array!"); + if (rankOf() != 2 || target.rankOf() != 2 || rows() != target.rows() || + columns() != target.columns() || !row.isRowVector() || + columns() != row.columns()) + throw std::invalid_argument("NDArray::divRowVector: wrong arguments !"); + if (target.dataType() != + DataTypeUtils::pickPairwiseResultType(dataType(), row.dataType())) + throw std::invalid_argument( + "NDArray::mulRowVector: wrong type of target array !"); + + int dimension = 1; + + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions( + this->shapeInfo(), dimension); + + NDArray::prepareSpecialUse({&target}, {this, &row}); + NativeOpExecutioner::execBroadcast( + getContext(), sd::broadcast::Ops::Multiply, buffer(), shapeInfo(), + specialBuffer(), specialShapeInfo(), row.buffer(), row.shapeInfo(), + row.specialBuffer(), row.specialShapeInfo(), target.buffer(), + target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), + nullptr, 1, packX.platformShapeInfo(), packX.platformOffsets(), nullptr, + nullptr); + NDArray::registerSpecialUse({&target}, {this, &row}); +} + +////////////////////////////////////////////////////////////////////////// +void NDArray::divRowVector(const NDArray& row, NDArray& target) const { + if (isS()) + throw std::runtime_error( + "NDArray::divRowVector: you can't use this method on String array!"); + if (row.isB()) + throw std::runtime_error( + "NDArray::divRowVector: you can't divide by bool row!"); + if (rankOf() != 2 || target.rankOf() != 2 || rows() != target.rows() || + columns() != target.columns() || !row.isRowVector() || + columns() != row.columns()) + throw std::invalid_argument("NDArray::divRowVector: wrong arguments !"); + if (target.dataType() != + DataTypeUtils::pickPairwiseResultType(dataType(), row.dataType())) + throw std::invalid_argument( + "NDArray::divRowVector: wrong type of target array !"); + + int dimension = 1; + + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions( + this->shapeInfo(), dimension); + + NDArray::prepareSpecialUse({&target}, {this, &row}); + NativeOpExecutioner::execBroadcast( + getContext(), sd::broadcast::Divide, buffer(), shapeInfo(), + specialBuffer(), specialShapeInfo(), row.buffer(), row.shapeInfo(), + row.specialBuffer(), row.specialShapeInfo(), target.buffer(), + target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), + nullptr, 1, packX.platformShapeInfo(), packX.platformOffsets(), nullptr, + nullptr); + NDArray::registerSpecialUse({&target}, {this, &row}); +} + +////////////////////////////////////////////////////////////////////////// +// This method adds given row to all rows in this NDArray, this array becomes +// affected +void NDArray::addiRowVector(const NDArray& row) { + if (isS()) + throw std::runtime_error( + "NDArray::addiRowVector: you can't use this method on String array!"); + if (rankOf() != 2 || !row.isRowVector() || columns() != row.lengthOf()) + throw std::invalid_argument("NDArray::addiRowVector: wrong arguments !"); + + int dimension = 1; + + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions( + this->shapeInfo(), dimension); + + NDArray::prepareSpecialUse({this}, {&row}); + NativeOpExecutioner::execBroadcast( + getContext(), sd::broadcast::Ops::Add, buffer(), shapeInfo(), + specialBuffer(), specialShapeInfo(), row.buffer(), row.shapeInfo(), + row.specialBuffer(), row.specialShapeInfo(), this->buffer(), + this->shapeInfo(), this->specialBuffer(), this->specialShapeInfo(), + nullptr, 1, packX.platformShapeInfo(), packX.platformOffsets(), nullptr, + nullptr); + NDArray::registerSpecialUse({this}, {&row}); +} + +////////////////////////////////////////////////////////////////////////// +void NDArray::addColumnVector(const NDArray& column, NDArray& target) const { + if (isS()) + throw std::runtime_error( + "NDArray::addColumnVector: you can't use this method on String array!"); + if (rankOf() != 2 || target.rankOf() != 2 || rows() != target.rows() || + columns() != target.columns() || !column.isColumnVector() || + rows() != column.lengthOf()) + throw std::invalid_argument("NDArray::addColumnVector: wrong arguments !"); + if (target.dataType() != + DataTypeUtils::pickPairwiseResultType(dataType(), column.dataType())) + throw std::invalid_argument( + "NDArray::addColumnVector: wrong type of target array !"); + + int dimension = 0; + + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions( + this->shapeInfo(), dimension); + + NDArray::prepareSpecialUse({&target}, {this, &column}); + NativeOpExecutioner::execBroadcast( + getContext(), sd::broadcast::Ops::Add, buffer(), shapeInfo(), + specialBuffer(), specialShapeInfo(), column.buffer(), column.shapeInfo(), + column.specialBuffer(), column.specialShapeInfo(), target.buffer(), + target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), + nullptr, 1, packX.platformShapeInfo(), packX.platformOffsets(), nullptr, + nullptr); + NDArray::registerSpecialUse({&target}, {this, &column}); +} + +////////////////////////////////////////////////////////////////////////// +// This method adds given column to all columns in this NDArray, this array +// becomes affected +void NDArray::addiColumnVector(const NDArray& column) { + if (isS()) + throw std::runtime_error( + "NDArray::addiColumnVector: you can't use this method on String " + "array!"); + if (rankOf() != 2 || !column.isColumnVector() || rows() != column.lengthOf()) + throw std::invalid_argument("NDArray::addiColumnVector: wrong arguments !"); + + int dimension = 0; + + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions( + this->shapeInfo(), dimension); + + NDArray::prepareSpecialUse({this}, {&column}); + NativeOpExecutioner::execBroadcast( + getContext(), sd::broadcast::Ops::Add, buffer(), shapeInfo(), + specialBuffer(), specialShapeInfo(), column.buffer(), column.shapeInfo(), + column.specialBuffer(), column.specialShapeInfo(), this->buffer(), + this->shapeInfo(), this->specialBuffer(), this->specialShapeInfo(), + nullptr, 1, packX.platformShapeInfo(), packX.platformOffsets(), nullptr, + nullptr); + NDArray::registerSpecialUse({this}, {&column}); +} + +////////////////////////////////////////////////////////////////////////// +// This method multiplies each column of this array by given argument-column, +// this array becomes affected +void NDArray::muliColumnVector(const NDArray& column) { + if (isS()) + throw std::runtime_error( + "NDArray::muliColumnVector: you can't use this method on String " + "array!"); + if (rankOf() != 2 || !column.isColumnVector() || rows() != column.lengthOf()) + throw std::invalid_argument("NDArray::muliColumnVector: wrong arguments !"); + + int dimension = 0; + + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions( + this->shapeInfo(), dimension); + + NDArray::prepareSpecialUse({this}, {&column}); + NativeOpExecutioner::execBroadcast( + getContext(), sd::broadcast::Ops::Multiply, buffer(), shapeInfo(), + specialBuffer(), specialShapeInfo(), column.buffer(), column.shapeInfo(), + column.specialBuffer(), column.specialShapeInfo(), this->buffer(), + this->shapeInfo(), this->specialBuffer(), this->specialShapeInfo(), + nullptr, 1, packX.platformShapeInfo(), packX.platformOffsets(), nullptr, + nullptr); + NDArray::registerSpecialUse({this}, {&column}); +} + +////////////////////////////////////////////////////////////////////////// +template +void NDArray::templatedAssign(void* xBuffer, Nd4jLong xOffset, + const void* yBuffer, + const Nd4jLong yOffset) const { + if (xBuffer != nullptr && yBuffer != nullptr) + *(reinterpret_cast(xBuffer) + xOffset) = + *(reinterpret_cast(yBuffer) + yOffset); +} +BUILD_SINGLE_TEMPLATE(template SD_EXPORT void NDArray::templatedAssign, + (void* xBuffer, const Nd4jLong xOffset, + const void* yBuffer, const Nd4jLong yOffset) const, + LIBND4J_TYPES); + +////////////////////////////////////////////////////////////////////////// +bool NDArray::permutei(const int* dimensions, const int rank) { + auto shapeInfo = ShapeUtils::evalPermShapeInfo(dimensions, rank, *this, + getContext()->getWorkspace()); + setShapeInfo(shapeInfo); + + return true; +} + +////////////////////////////////////////////////////////////////////////// +bool NDArray::permutei(const Nd4jLong* dimensions, const int rank) { + auto shapeInfo = ShapeUtils::evalPermShapeInfo(dimensions, rank, *this, + getContext()->getWorkspace()); + setShapeInfo(shapeInfo); + + return true; +} + +//////////////////////////////////////////////////////////////////////// +ResultSet NDArray::multipleTensorsAlongDimension( + const std::vector& indices, const std::vector& dimensions) const { + ResultSet result; + + if (indices.size() == 0) return result; + + auto pack = ConstantTadHelper::getInstance().tadForDimensions( + shapeInfo(), const_cast(dimensions.data()), dimensions.size()); + + auto tadLength = shape::length(pack.primaryShapeInfo()); + auto numTads = lengthOf() / tadLength; + + for (auto idx : indices) { + if (idx >= numTads) { + nd4j_printf( + "NDArray::multipleTensorsAlongDimension: index %i is higher then " + "number of TADs: %i\n", + idx, numTads); + throw std::runtime_error("Bad index"); + } + + NDArray array(getDataBuffer(), ShapeDescriptor(pack.primaryShapeInfo()), + getContext(), pack.primaryOffsets()[idx] + bufferOffset()); + result.push_back(array); + } + + return result; +} + +//////////////////////////////////////////////////////////////////////// +ResultSet NDArray::allTensorsAlongDimension( + const std::initializer_list& dimensions) const { + return allTensorsAlongDimension(std::vector(dimensions)); +} + +//////////////////////////////////////////////////////////////////////// +ResultSet NDArray::allExamples() const { + std::vector dimensions(rankOf() - 1); + for (int e = 1; e < rankOf(); e++) dimensions[e - 1] = e; + + return allTensorsAlongDimension(dimensions); +} + +//////////////////////////////////////////////////////////////////////// +Nd4jLong NDArray::getOffset(const Nd4jLong i) const { + if (i >= lengthOf()) + throw std::invalid_argument( + "NDArray::getOffset: input index is out of array length !"); + + return shape::getIndexOffset(i, _shapeInfo); +} + +//////////////////////////////////////////////////////////////////////// +NDArray NDArray::like() { + return NDArray(shapeInfo(), this->dataType(), false, getContext()); +} + +//////////////////////////////////////////////////////////////////////// +NDArray NDArray::ulike() const { return NDArray(this, false, getContext()); } + +//////////////////////////////////////////////////////////////////////// +NDArray NDArray::diagonal(const char type) const { + if (isS()) + throw std::runtime_error( + "NDArray::diagonal: you can't use this method on String array!"); + + const char order = ordering(); + const int rank = rankOf(); + Nd4jLong* outShapeInfo; + ALLOCATE(outShapeInfo, getContext()->getWorkspace(), 8, Nd4jLong); + outShapeInfo[0] = 2; + outShapeInfo[5] = 0; + + if (isVector() || isScalar()) { + outShapeInfo[1] = outShapeInfo[2] = outShapeInfo[3] = outShapeInfo[4] = 1; + outShapeInfo[6] = 1; + outShapeInfo[7] = (int)order; + } else { + int diagSize = 100000000; + Nd4jLong indices[MAX_RANK]; + + for (int i = 0; i < rank; ++i) { + if (diagSize > shapeOf()[i]) diagSize = shapeOf()[i]; + indices[i] = 1; + } + + auto step = shape::getOffset(shapeInfo(), indices); + + if (type == 'c') { + outShapeInfo[1] = diagSize; + outShapeInfo[2] = 1; + } else { + outShapeInfo[1] = 1; + outShapeInfo[2] = diagSize; + } + shape::updateStrides(outShapeInfo, order); + + outShapeInfo[3] *= step; + outShapeInfo[4] *= step; + outShapeInfo[6] = 0; + } + + ArrayOptions::setDataType(outShapeInfo, this->dataType()); + + NDArray result(_buffer, ShapeDescriptor(outShapeInfo), getContext(), + bufferOffset()); + + RELEASE(outShapeInfo, getContext()->getWorkspace()); + + return result; +} + +//////////////////////////////////////////////////////////////////////// +ResultSet NDArray::allTensorsAlongDimension( + const std::vector& dimensions) const { + ResultSet result; + + if (dimensions.size() == 0) return result; + + if (dimensions.back() >= rankOf()) + throw std::runtime_error( + "NDArray::allTensorsAlongDimension static function: all input " + "dimensions must be smaller than rank of input array !"); + + auto pack = ConstantTadHelper::getInstance().tadForDimensions( + _shapeInfo, const_cast(dimensions.data()), dimensions.size()); + auto numTads = pack.numberOfTads(); + + for (Nd4jLong idx = 0; idx < numTads; idx++) { + NDArray array(_buffer, ShapeDescriptor(pack.primaryShapeInfo()), + getContext(), pack.primaryOffsets()[idx] + bufferOffset()); + array._isView = true; + result.push_back(array); + } + + return result; +} + +//////////////////////////////////////////////////////////////////////// +// operator returns sub-array with buffer pointing at this->_buffer + certain +// offset +NDArray NDArray::operator()(const std::vector& idx, + const bool keepUnitiesInShape, + const bool isStrided) const { + if (isEmpty()) + throw std::invalid_argument( + "NDArray::operator(sub-arrays): array is empty !"); + + // Nd4jLong *outShapeInfo = nullptr; + // ALLOCATE(outShapeInfo, workspace, shape::shapeInfoLength(inShapeInfo), + // Nd4jLong); + + int numOfUntiesInSubArrShape = 0; + + Nd4jLong* subArrShapeInfo = nullptr; + + if (!keepUnitiesInShape) { + int n(isStrided ? 3 : 2), first, last; + + // calculate the number of unities in shape + for (uint d = 0; d < rankOf(); ++d) { + if (idx[n * d] != idx[n * d + 1]) { + first = idx[n * d] >= 0 ? idx[n * d] : idx[n * d] + sizeAt(d) + 1; + last = idx[n * d + 1] >= 0 ? idx[n * d + 1] + : idx[n * d + 1] + sizeAt(d) + 1; + if (last - first == 1) ++numOfUntiesInSubArrShape; + } + } + } + + ALLOCATE(subArrShapeInfo, getContext()->getWorkspace(), + shape::shapeInfoLength(rankOf() - numOfUntiesInSubArrShape), + Nd4jLong); + + Nd4jLong offset; + + shape::calcSubArrShapeInfoAndOffset(idx.data(), shapeInfo(), subArrShapeInfo, + offset, keepUnitiesInShape, isStrided, + numOfUntiesInSubArrShape); + + NDArray result(_buffer, ShapeDescriptor(subArrShapeInfo), getContext(), + offset + bufferOffset()); + result._isView = true; + + RELEASE(subArrShapeInfo, getContext()->getWorkspace()); + + return result; +} + +//////////////////////////////////////////////////////////////////////// +NDArray NDArray::operator()(const Nd4jLong subArrIdx, + const std::vector& dimsToExclude, + bool keepUnitiesInShape) const { + std::vector idxRanges(2 * rankOf()); + + const auto rank = rankOf(); + const auto subArrRank = static_cast(dimsToExclude.size()); + + if (subArrRank > rank) + throw std::invalid_argument( + "NDArray::operator(const Nd4jLong subArrIdx, const std::vector& " + "dimsToExclude, bool keepUnitiesInShape): static method: dimsToExclude " + "is empty or has size > rank of array !"); + + memset(idxRanges.data(), 0, 2 * rank * sizeof(Nd4jLong)); + + // subArrRank == 0 means whole array, idxRanges should contain zeros only + + if (subArrRank != 0) { + std::vector shapeOfSubArr(subArrRank), indexes(subArrRank); + for (int i = 0; i < subArrRank; ++i) + shapeOfSubArr[i] = sizeAt(dimsToExclude[i]); + + shape::index2coords(subArrIdx, subArrRank, shapeOfSubArr.data(), + indexes.data()); + + for (int i = 0; i < subArrRank; ++i) { + int currIdx = 2 * dimsToExclude[i]; + idxRanges[currIdx] = indexes[i]; + idxRanges[currIdx + 1] = indexes[i] + 1; + } + } + + return (*this)(idxRanges, keepUnitiesInShape); +} + +//////////////////////////////////////////////////////////////////////// +void NDArray::getSubArrShapeAndOffsets(const std::vector& dimsToExclude, + Nd4jLong*& subArrShapeInfo, + Nd4jLong*& subArrOffsets, + bool keepUnitiesInShape) const { + if (isEmpty()) + throw std::invalid_argument( + "NDArray::getSubArrShapeAndOffsets: array is empty !"); + + const int rank = rankOf(); + const int subArrRank = (rank == dimsToExclude.size() || keepUnitiesInShape) + ? rank + : rank - dimsToExclude.size(); + const Nd4jLong numOfSubArrs = + ShapeUtils::getNumOfSubArrs(_shapeInfo, dimsToExclude); + + // allocate memory + ALLOCATE(subArrShapeInfo, getContext()->getWorkspace(), + shape::shapeInfoLength(subArrRank), Nd4jLong); + ALLOCATE(subArrOffsets, getContext()->getWorkspace(), numOfSubArrs, Nd4jLong); + + shape::calcSubArrsShapeInfoAndOffsets( + _shapeInfo, numOfSubArrs, dimsToExclude.size(), dimsToExclude.data(), + subArrShapeInfo, subArrOffsets, keepUnitiesInShape); +} + +////////////////////////////////////////////////////////////////////////// +void NDArray::setShapeInfo(const Nd4jLong* shapeInfo) { + if (shapeInfo != nullptr) { + ShapeDescriptor descriptor(shapeInfo); + auto shapeBuffer = + ConstantShapeHelper::getInstance().bufferForShapeInfo(descriptor); + + _shapeInfo = shapeBuffer.primary(); + _shapeInfoD = shapeBuffer.special(); + + if (ArrayOptions::arrayType(_shapeInfo) == ArrayType::EMPTY) + _length = 0; + else + _length = shape::length(_shapeInfo); + + _dataType = ArrayOptions::dataType(_shapeInfo); + } else { + _dataType = sd::DataType::INHERIT; + _shapeInfoD = _shapeInfo = nullptr; + } +} + +//////////////////////////////////////////////////////////////////////// +void NDArray::setShapeInfo(const Nd4jLong* shapeInfo, + const sd::DataType dtype) { + if (shapeInfo != nullptr) { + Nd4jLong* shapeInfoTemp = ShapeBuilders::copyShapeInfoAndType( + shapeInfo, dtype, true, getContext()->getWorkspace()); + ShapeDescriptor descriptor(shapeInfoTemp); + auto shapeBuffer = + ConstantShapeHelper::getInstance().bufferForShapeInfo(descriptor); + + _shapeInfo = shapeBuffer.primary(); + _shapeInfoD = shapeBuffer.special(); + + if (ArrayOptions::arrayType(_shapeInfo) == ArrayType::EMPTY) + _length = 0; + else + _length = shape::length(_shapeInfo); + + _dataType = dtype; + } else { + _dataType = sd::DataType::INHERIT; + _shapeInfoD = _shapeInfo = nullptr; + } +} + +////////////////////////////////////////////////////////////////////////// +void NDArray::setShapeInfo(const ShapeDescriptor& descriptor) { + auto shapeBuffer = ConstantShapeHelper::getInstance().bufferForShapeInfo( + const_cast(descriptor)); + + _shapeInfo = shapeBuffer.primary(); + _shapeInfoD = shapeBuffer.special(); + + + if (ArrayOptions::arrayType(_shapeInfo) == ArrayType::EMPTY) + _length = 0; + else + _length = shape::length(_shapeInfo); + + _dataType = ArrayOptions::dataType(_shapeInfo); +} + +////////////////////////////////////////////////////////////////////////// +void NDArray::setShapeInfo(const ConstantShapeBuffer& shapeBuffer) { + _shapeInfo = shapeBuffer.primary(); + _shapeInfoD = shapeBuffer.special(); + + + if (ArrayOptions::arrayType(_shapeInfo) == ArrayType::EMPTY) + _length = 0; + else + _length = shape::length(_shapeInfo); + + _dataType = ArrayOptions::dataType(_shapeInfo); +} + +/////////////////////////////////////////////////////////////////////// +// addition operator array + scalar +template +NDArray operator+(NDArray&& arr, const T& scalar) { + if (arr.isView()) // do not use resources of arrays which use buffers of + // other original arrays + return std::move(arr + scalar); // arr is lvalue inside function body + + if (arr.isS()) + throw std::runtime_error( + "operator+(NDArray&& arr, const T& scalar): you can't use this method " + "on String array!"); + if (arr.dataType() != DataTypeUtils::pickPairwiseResultType( + arr.dataType(), DataTypeUtils::fromT())) + throw std::runtime_error( + "operator+(NDArray&& arr, const T& scalar): you can't use this method " + "on String array!"); + + auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext()); + + NDArray::prepareSpecialUse({&arr}, {&arr, &tmp}); + NativeOpExecutioner::execScalar( + arr.getContext(), sd::scalar::Add, arr.buffer(), arr.shapeInfo(), + arr.specialBuffer(), arr.specialShapeInfo(), arr.buffer(), + arr.shapeInfo(), arr.specialBuffer(), arr.specialShapeInfo(), + tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), + tmp.specialShapeInfo(), nullptr); + NDArray::registerSpecialUse({&arr}, {&arr, &tmp}); + + return std::move(arr); +} +template SD_EXPORT NDArray operator+(NDArray&& arr, const double& scalar); +template SD_EXPORT NDArray operator+(NDArray&& arr, const float& scalar); +template SD_EXPORT NDArray operator+(NDArray&& arr, const float16& scalar); +template SD_EXPORT NDArray operator+(NDArray&& arr, const bfloat16& scalar); +template SD_EXPORT NDArray operator+(NDArray&& arr, const int& scalar); + +//////////////////////////////////////////////////////////////////////// +template +NDArray operator+(const NDArray& arr, const T& scalar) { + if (arr.isS()) + throw std::runtime_error( + "operator+(const NDArray& arr, const T& scalar): you can't use this " + "method on String array!"); + + auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext()); + NDArray result(arr.shapeInfo(), + DataTypeUtils::pickPairwiseResultType( + arr.dataType(), DataTypeUtils::fromT()), + false, arr.getContext()); + + NDArray::prepareSpecialUse({&result}, {&arr, &tmp}); + NativeOpExecutioner::execScalar( + arr.getContext(), sd::scalar::Add, arr.buffer(), arr.shapeInfo(), + arr.specialBuffer(), arr.specialShapeInfo(), result.buffer(), + result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo(), + tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), + tmp.specialShapeInfo(), nullptr); + NDArray::registerSpecialUse({&result}, {&arr, &tmp}); + + return result; +} +template SD_EXPORT NDArray operator+(const NDArray& arr, const double& scalar); +template SD_EXPORT NDArray operator+(const NDArray& arr, const float& scalar); +template SD_EXPORT NDArray operator+(const NDArray& arr, const float16& scalar); +template SD_EXPORT NDArray operator+(const NDArray& arr, + const bfloat16& scalar); +template SD_EXPORT NDArray operator+(const NDArray& arr, const int& scalar); + +//////////////////////////////////////////////////////////////////////// +template +NDArray operator+(const T& scalar, NDArray&& arr) { + return std::move(arr) + scalar; +} +template SD_EXPORT NDArray operator+(const double& scalar, NDArray&& arr); +template SD_EXPORT NDArray operator+(const float& scalar, NDArray&& arr); +template SD_EXPORT NDArray operator+(const float16& scalar, NDArray&& arr); +template SD_EXPORT NDArray operator+(const bfloat16& scalar, NDArray&& arr); +template SD_EXPORT NDArray operator+(const int& scalar, NDArray&& arr); + +//////////////////////////////////////////////////////////////////////// +template +NDArray operator+(const T& scalar, const NDArray& arr) { + return arr + scalar; +} +template SD_EXPORT NDArray operator+(const double& scalar, const NDArray& arr); +template SD_EXPORT NDArray operator+(const float& scalar, const NDArray& arr); +template SD_EXPORT NDArray operator+(const int& scalar, const NDArray& arr); + +/////////////////////////////////////////////////////////////////////// +// addition operator array - scalar +template +NDArray operator-(NDArray&& arr, const T& scalar) { + if (arr.isView()) // do not use resources of arrays which use buffers of + // other original arrays + return std::move(arr - scalar); // arr is lvalue inside function body + + if (arr.isS()) + throw std::runtime_error( + "operator-(NDArray&& arr, const T& scalar): you can't use this method " + "on String array!"); + if (arr.dataType() != DataTypeUtils::pickPairwiseResultType( + arr.dataType(), DataTypeUtils::fromT())) + throw std::runtime_error( + "operator-(NDArray&& arr, const T& scalar): you can't use this method " + "on String array!"); + + auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext()); + + NDArray::prepareSpecialUse({&arr}, {&arr, &tmp}); + NativeOpExecutioner::execScalar( + arr.getContext(), sd::scalar::Subtract, arr.buffer(), arr.shapeInfo(), + arr.specialBuffer(), arr.specialShapeInfo(), arr.buffer(), + arr.shapeInfo(), arr.specialBuffer(), arr.specialShapeInfo(), + tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), + tmp.specialShapeInfo(), nullptr); + NDArray::registerSpecialUse({&arr}, {&arr, &tmp}); + + return std::move(arr); +} +template SD_EXPORT NDArray operator-(NDArray&& arr, const double& scalar); +template SD_EXPORT NDArray operator-(NDArray&& arr, const float& scalar); + +//////////////////////////////////////////////////////////////////////// +template +NDArray operator-(const NDArray& arr, const T& scalar) { + if (arr.isS()) + throw std::runtime_error( + "operator-(const NDArray& arr, const T& scalar): you can't use this " + "method on String array!"); + + auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext()); + NDArray result(arr.shapeInfo(), + DataTypeUtils::pickPairwiseResultType( + arr.dataType(), DataTypeUtils::fromT()), + false, arr.getContext()); + + NDArray::prepareSpecialUse({&result}, {&arr, &tmp}); + NativeOpExecutioner::execScalar( + arr.getContext(), sd::scalar::Subtract, arr.buffer(), arr.shapeInfo(), + arr.specialBuffer(), arr.specialShapeInfo(), result.buffer(), + result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo(), + tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), + tmp.specialShapeInfo(), nullptr); + NDArray::registerSpecialUse({&result}, {&arr, &tmp}); + + return result; +} +template SD_EXPORT NDArray operator-(const NDArray& arr, const double& scalar); +template SD_EXPORT NDArray operator-(const NDArray& arr, const float& scalar); +template SD_EXPORT NDArray operator-(const NDArray& arr, const float16& scalar); +template SD_EXPORT NDArray operator-(const NDArray& arr, + const bfloat16& scalar); +template SD_EXPORT NDArray operator-(const NDArray& arr, const int& scalar); + +//////////////////////////////////////////////////////////////////////// +template +NDArray operator-(const T& scalar, NDArray&& arr) { + if (arr.isView()) // do not use resources of arrays which use buffers of + // other original arrays + return std::move(scalar - arr); // arr is lvalue inside function body + + if (arr.isS()) + throw std::runtime_error( + "operator-(const T& scalar, NDArray&& arr): you can't use this method " + "on String array!"); + + auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext()); + + NDArray::prepareSpecialUse({&arr}, {&arr, &tmp}); + NativeOpExecutioner::execScalar( + arr.getContext(), sd::scalar::ReverseSubtract, arr.buffer(), + arr.shapeInfo(), arr.specialBuffer(), arr.specialShapeInfo(), + arr.buffer(), arr.shapeInfo(), arr.specialBuffer(), + arr.specialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), + tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr); + NDArray::registerSpecialUse({&arr}, {&arr, &tmp}); + + return std::move(arr); +} +template SD_EXPORT NDArray operator-(const double& scalar, NDArray&& arr); +template SD_EXPORT NDArray operator-(const float& scalar, NDArray&& arr); +template SD_EXPORT NDArray operator-(const float16& scalar, NDArray&& arr); +template SD_EXPORT NDArray operator-(const bfloat16& scalar, NDArray&& arr); +template SD_EXPORT NDArray operator-(const int& scalar, NDArray&& arr); + +//////////////////////////////////////////////////////////////////////// +template +NDArray operator-(const T& scalar, const NDArray& arr) { + if (arr.isS()) + throw std::runtime_error( + "operator-(const T& scalar, const NDArray& arr): you can't use this " + "method on String array!"); + + auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext()); + NDArray result(arr.shapeInfo(), + DataTypeUtils::pickPairwiseResultType( + arr.dataType(), DataTypeUtils::fromT()), + false, arr.getContext()); + + NDArray::prepareSpecialUse({&result}, {&arr, &tmp}); + NativeOpExecutioner::execScalar( + arr.getContext(), sd::scalar::ReverseSubtract, arr.buffer(), + arr.shapeInfo(), arr.specialBuffer(), arr.specialShapeInfo(), + result.buffer(), result.shapeInfo(), result.specialBuffer(), + result.specialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), + tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr); + NDArray::registerSpecialUse({&result}, {&arr, &tmp}); + + return result; +} +template SD_EXPORT NDArray operator-(const double& scalar, const NDArray& arr); +template SD_EXPORT NDArray operator-(const float& scalar, const NDArray& arr); +template SD_EXPORT NDArray operator-(const int& scalar, const NDArray& arr); + +/////////////////////////////////////////////////////////////////////// +// addition operator array + scalar +template +NDArray operator*(NDArray&& arr, const T& scalar) { + if (arr.isView()) // do not use resources of arrays which use buffers of + // other original arrays + return std::move(arr * scalar); // arr is lvalue inside function body + + if (arr.isS()) + throw std::runtime_error( + "operator*(NDArray&& arr, const T& scalar): you can't use this method " + "on String array!"); + if (arr.dataType() != DataTypeUtils::pickPairwiseResultType( + arr.dataType(), DataTypeUtils::fromT())) + throw std::runtime_error( + "operator*(NDArray&& arr, const T& scalar): you can't use this method " + "on String array!"); + + auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext()); + + NDArray::prepareSpecialUse({&arr}, {&arr, &tmp}); + NativeOpExecutioner::execScalar( + arr.getContext(), sd::scalar::Multiply, arr.buffer(), arr.shapeInfo(), + arr.specialBuffer(), arr.specialShapeInfo(), arr.buffer(), + arr.shapeInfo(), arr.specialBuffer(), arr.specialShapeInfo(), + tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), + tmp.specialShapeInfo(), nullptr); + NDArray::registerSpecialUse({&arr}, {&arr, &tmp}); + + return std::move(arr); +} +template SD_EXPORT NDArray operator*(NDArray&& arr, const double& scalar); +template SD_EXPORT NDArray operator*(NDArray&& arr, const float& scalar); +template SD_EXPORT NDArray operator*(NDArray&& arr, const float16& scalar); +template SD_EXPORT NDArray operator*(NDArray&& arr, const bfloat16& scalar); +template SD_EXPORT NDArray operator*(NDArray&& arr, const int& scalar); +template SD_EXPORT NDArray operator*(NDArray&& arr, const long long& scalar); + +//////////////////////////////////////////////////////////////////////// +template +NDArray operator*(const NDArray& arr, const T& scalar) { + if (arr.isS()) + throw std::runtime_error( + "operator*(const NDArray& arr, const T& scalar): you can't use this " + "method on String array!"); + + auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext()); + NDArray result(arr.shapeInfo(), + DataTypeUtils::pickPairwiseResultType( + arr.dataType(), DataTypeUtils::fromT()), + false, arr.getContext()); + + NDArray::prepareSpecialUse({&result}, {&arr, &tmp}); + NativeOpExecutioner::execScalar( + arr.getContext(), sd::scalar::Multiply, arr.buffer(), arr.shapeInfo(), + arr.specialBuffer(), arr.specialShapeInfo(), result.buffer(), + result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo(), + tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), + tmp.specialShapeInfo(), nullptr); + NDArray::registerSpecialUse({&result}, {&arr, &tmp}); + + return result; +} + +template SD_EXPORT NDArray operator*(const NDArray& arr, const double& scalar); +template SD_EXPORT NDArray operator*(const NDArray& arr, const float& scalar); +template SD_EXPORT NDArray operator*(const NDArray& arr, const float16& scalar); +template SD_EXPORT NDArray operator*(const NDArray& arr, + const bfloat16& scalar); +template SD_EXPORT NDArray operator*(const NDArray& arr, const int& scalar); +template SD_EXPORT NDArray operator*(const NDArray& arr, + const long long& scalar); + +//////////////////////////////////////////////////////////////////////// +template +NDArray operator*(const T& scalar, NDArray&& arr) { + return std::move(arr) * scalar; +} +template SD_EXPORT NDArray operator*(const double& scalar, NDArray&& arr); +template SD_EXPORT NDArray operator*(const float& scalar, NDArray&& arr); +template SD_EXPORT NDArray operator*(const float16& scalar, NDArray&& arr); +template SD_EXPORT NDArray operator*(const bfloat16& scalar, NDArray&& arr); +template SD_EXPORT NDArray operator*(const int& scalar, NDArray&& arr); +template SD_EXPORT NDArray operator*(const long long& scalar, NDArray&& arr); + +//////////////////////////////////////////////////////////////////////// +template +NDArray operator*(const T& scalar, const NDArray& arr) { + return arr * scalar; +} +template SD_EXPORT NDArray operator*(const double& scalar, const NDArray& arr); +template SD_EXPORT NDArray operator*(const float& scalar, const NDArray& arr); +template SD_EXPORT NDArray operator*(const float16& scalar, const NDArray& arr); +template SD_EXPORT NDArray operator*(const bfloat16& scalar, + const NDArray& arr); +template SD_EXPORT NDArray operator*(const int& scalar, const NDArray& arr); +template SD_EXPORT NDArray operator*(const long long& scalar, + const NDArray& arr); + +/////////////////////////////////////////////////////////////////////// +template +NDArray operator/(NDArray&& arr, const T& scalar) { + if (arr.isView()) // do not use resources of arrays which use buffers of + // other original arrays + return std::move(arr / scalar); // arr is lvalue inside function body + + if (arr.isS()) + throw std::runtime_error( + "operator/(NDArray&& arr, const T& scalar): you can't use this method " + "on String array!"); + if (arr.dataType() != DataTypeUtils::pickPairwiseResultType( + arr.dataType(), DataTypeUtils::fromT())) + throw std::runtime_error( + "operator/(NDArray&& arr, const T& scalar): you can't use this method " + "on String array!"); + + auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext()); + + NDArray::prepareSpecialUse({&arr}, {&arr, &tmp}); + NativeOpExecutioner::execScalar( + arr.getContext(), sd::scalar::Divide, arr.buffer(), arr.shapeInfo(), + arr.specialBuffer(), arr.specialShapeInfo(), arr.buffer(), + arr.shapeInfo(), arr.specialBuffer(), arr.specialShapeInfo(), + tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), + tmp.specialShapeInfo(), nullptr); + NDArray::registerSpecialUse({&arr}, {&arr, &tmp}); + + return std::move(arr); +} +template SD_EXPORT NDArray operator/(NDArray&& arr, const double& scalar); +template SD_EXPORT NDArray operator/(NDArray&& arr, const float& scalar); +template SD_EXPORT NDArray operator/(NDArray&& arr, const float16& scalar); +template SD_EXPORT NDArray operator/(NDArray&& arr, const bfloat16& scalar); +template SD_EXPORT NDArray operator/(NDArray&& arr, const long long& scalar); + +//////////////////////////////////////////////////////////////////////// +template +NDArray operator/(const NDArray& arr, const T& scalar) { + if (arr.isS()) + throw std::runtime_error( + "operator/(const NDArray& arr, const T& scalar): you can't use this " + "method on String array!"); + + auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext()); + NDArray result(arr.shapeInfo(), + DataTypeUtils::pickPairwiseResultType( + arr.dataType(), DataTypeUtils::fromT()), + false, arr.getContext()); + + NDArray::prepareSpecialUse({&result}, {&arr, &tmp}); + NativeOpExecutioner::execScalar( + arr.getContext(), sd::scalar::Divide, arr.buffer(), arr.shapeInfo(), + arr.specialBuffer(), arr.specialShapeInfo(), result.buffer(), + result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo(), + tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), + tmp.specialShapeInfo(), nullptr); + NDArray::registerSpecialUse({&result}, {&arr, &tmp}); + + return result; +} +template SD_EXPORT NDArray operator/(const NDArray& arr, const double& scalar); +template SD_EXPORT NDArray operator/(const NDArray& arr, const float& scalar); +template SD_EXPORT NDArray operator/(const NDArray& arr, const float16& scalar); +template SD_EXPORT NDArray operator/(const NDArray& arr, + const bfloat16& scalar); +template SD_EXPORT NDArray operator/(const NDArray& arr, const int& scalar); +template SD_EXPORT NDArray operator/(const NDArray& arr, + const long long& scalar); + +//////////////////////////////////////////////////////////////////////// +template +NDArray operator/(const T& scalar, NDArray&& arr) { + if (arr.isView()) // do not use resources of arrays which use buffers of + // other original arrays + return std::move(scalar / arr); // arr is lvalue inside function body + + if (arr.isS()) + throw std::runtime_error( + "operator/(const T& scalar, NDArray&& arr): you can't use this method " + "on String array!"); + + auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext()); + + NDArray::prepareSpecialUse({&arr}, {&arr, &tmp}); + NativeOpExecutioner::execScalar( + arr.getContext(), sd::scalar::ReverseDivide, arr.buffer(), + arr.shapeInfo(), arr.specialBuffer(), arr.specialShapeInfo(), + arr.buffer(), arr.shapeInfo(), arr.specialBuffer(), + arr.specialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), + tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr); + NDArray::registerSpecialUse({&arr}, {&arr, &tmp}); + + return std::move(arr); +} +template SD_EXPORT NDArray operator/(const double& scalar, NDArray&& arr); +template SD_EXPORT NDArray operator/(const float& scalar, NDArray&& arr); +template SD_EXPORT NDArray operator/(const float16& scalar, NDArray&& arr); +template SD_EXPORT NDArray operator/(const bfloat16& scalar, NDArray&& arr); +template SD_EXPORT NDArray operator/(const int& scalar, NDArray&& arr); + +//////////////////////////////////////////////////////////////////////// +template +NDArray operator/(const T& scalar, const NDArray& arr) { + if (arr.isS()) + throw std::runtime_error( + "operator/(const T& scalar, const NDArray& arr): you can't use this " + "method on String array!"); + + auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext()); + NDArray result(arr.shapeInfo(), + DataTypeUtils::pickPairwiseResultType( + arr.dataType(), DataTypeUtils::fromT()), + false, arr.getContext()); + + NDArray::prepareSpecialUse({&result}, {&arr, &tmp}); + NativeOpExecutioner::execScalar( + arr.getContext(), sd::scalar::ReverseDivide, arr.buffer(), + arr.shapeInfo(), arr.specialBuffer(), arr.specialShapeInfo(), + result.buffer(), result.shapeInfo(), result.specialBuffer(), + result.specialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), + tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr); + NDArray::registerSpecialUse({&result}, {&arr, &tmp}); + + return result; +} +template SD_EXPORT NDArray operator/(const double& scalar, const NDArray& arr); +template SD_EXPORT NDArray operator/(const float& scalar, const NDArray& arr); +template SD_EXPORT NDArray operator/(const int& scalar, const NDArray& arr); + +//////////////////////////////////////////////////////////////////////// +// addition operator array + array +template +NDArray operator+(T1&& arr1, T2&& arr2) { + if (arr1.isS() || arr2.isS()) + throw std::runtime_error( + "operator+(T&& arr1, T&& arr2): you can't use this method on String " + "arrays!"); + if (!Environment::getInstance().isExperimentalBuild() && + arr1.dataType() != arr2.dataType() && + (arr1.dataType() != DataType::BOOL || arr2.dataType() != BOOL)) + throw sd::datatype_exception::build( + "operator+(T&& arr1, T&& arr2): Cannot multiply different types", + arr1.dataType(), arr2.dataType()); + + PointersManager pointersManager(arr1.getContext(), + "operator+(T&& arr1, T&& arr2)"); + + if (arr1.lengthOf() == arr2.lengthOf() && arr1.rankOf() == arr2.rankOf()) { + const bool isArr1Rvalue = !std::is_reference::value && !arr1.isView(); + const bool isArr2Rvalue = !std::is_reference::value && !arr2.isView(); + + NDArray* result = nullptr; + if (isArr1Rvalue) + result = const_cast(&arr1); + else if (isArr2Rvalue) + result = const_cast(&arr2); + else + result = new NDArray(arr1.shapeInfo(), + DataTypeUtils::pickPairwiseResultType( + arr1.shapeInfo(), arr2.shapeInfo()), + false, arr1.getContext()); + + NDArray::prepareSpecialUse({result}, {&arr1, &arr2}); + NativeOpExecutioner::execPairwiseTransform( + arr1.getContext(), sd::pairwise::Add, arr1.buffer(), arr1.shapeInfo(), + arr1.specialBuffer(), arr1.specialShapeInfo(), arr2.buffer(), + arr2.shapeInfo(), arr2.specialBuffer(), arr2.specialShapeInfo(), + result->buffer(), result->shapeInfo(), result->specialBuffer(), + result->specialShapeInfo(), nullptr); + NDArray::registerSpecialUse({result}, {&arr1, &arr2}); + + if (!isArr1Rvalue && !isArr2Rvalue) { + NDArray res = std::move(*result); + delete result; + return std::move(res); + } + + return std::move(*result); + } + + return std::forward(arr1).applyTrueBroadcast(sd::BroadcastOpsTuple::Add(), + std::forward(arr2)); +} +template SD_EXPORT NDArray operator+ + (NDArray& arr1, NDArray& arr2); +template SD_EXPORT NDArray operator+ + (NDArray& arr1, NDArray&& arr2); +template SD_EXPORT NDArray operator+ + (NDArray&& arr1, NDArray& arr2); +template SD_EXPORT NDArray operator+ + (NDArray& arr1, const NDArray& arr2); +template SD_EXPORT NDArray operator+ + (const NDArray& arr1, NDArray& arr2); +template SD_EXPORT NDArray operator+ + (const NDArray& arr1, NDArray&& arr2); +template SD_EXPORT NDArray operator+ + (const NDArray& arr1, + const NDArray& arr2); +template SD_EXPORT NDArray operator+ + (NDArray&& arr1, const NDArray& arr2); +template SD_EXPORT NDArray operator+ + (NDArray&& arr1, NDArray&& arr2); + +//////////////////////////////////////////////////////////////////////// +// addition operator array - array +template +NDArray operator-(T1&& arr1, T2&& arr2) { + if (arr1.isS() || arr2.isS()) + throw std::runtime_error( + "operator-(T&& arr1, T&& arr2): you can't use this method on String " + "arrays!"); + if (!Environment::getInstance().isExperimentalBuild() && + arr1.dataType() != arr2.dataType() && + (arr1.dataType() != DataType::BOOL || arr2.dataType() != BOOL)) + throw sd::datatype_exception::build( + "operator-(T&& arr1, T&& arr2): Cannot multiply different types", + arr1.dataType(), arr2.dataType()); + + PointersManager pointersManager(arr1.getContext(), + "operator-(T&& arr1, T&& arr2)"); + + if (arr1.lengthOf() == arr2.lengthOf() && arr1.rankOf() == arr2.rankOf()) { + const bool isArr1Rvalue = !std::is_reference::value && !arr1.isView(); + const bool isArr2Rvalue = !std::is_reference::value && !arr2.isView(); + + NDArray* result = nullptr; + if (isArr1Rvalue) + result = const_cast(&arr1); + else if (isArr2Rvalue) + result = const_cast(&arr2); + else + result = new NDArray(arr1.shapeInfo(), + DataTypeUtils::pickPairwiseResultType( + arr1.shapeInfo(), arr2.shapeInfo()), + false, arr1.getContext()); + + NDArray::prepareSpecialUse({result}, {&arr1, &arr2}); + NativeOpExecutioner::execPairwiseTransform( + arr1.getContext(), sd::pairwise::Subtract, arr1.buffer(), + arr1.shapeInfo(), arr1.specialBuffer(), arr1.specialShapeInfo(), + arr2.buffer(), arr2.shapeInfo(), arr2.specialBuffer(), + arr2.specialShapeInfo(), result->buffer(), result->shapeInfo(), + result->specialBuffer(), result->specialShapeInfo(), nullptr); + NDArray::registerSpecialUse({result}, {&arr1, &arr2}); + + if (!isArr1Rvalue && !isArr2Rvalue) { + NDArray res = std::move(*result); + delete result; + return std::move(res); + } + + return std::move(*result); + } + + return std::forward(arr1).applyTrueBroadcast( + sd::BroadcastOpsTuple::Subtract(), std::forward(arr2)); +} +template SD_EXPORT NDArray operator- + (NDArray& arr1, NDArray& arr2); +template SD_EXPORT NDArray operator- + (NDArray& arr1, NDArray&& arr2); +template SD_EXPORT NDArray operator- + (NDArray&& arr1, NDArray& arr2); +template SD_EXPORT NDArray operator- + (NDArray& arr1, const NDArray& arr2); +template SD_EXPORT NDArray operator- + (const NDArray& arr1, NDArray& arr2); +template SD_EXPORT NDArray operator- + (const NDArray& arr1, NDArray&& arr2); +template SD_EXPORT NDArray operator- + (const NDArray& arr1, + const NDArray& arr2); +template SD_EXPORT NDArray operator- + (NDArray&& arr1, const NDArray& arr2); +template SD_EXPORT NDArray operator- + (NDArray&& arr1, NDArray&& arr2); + +//////////////////////////////////////////////////////////////////////// +// multiplication operator array*array +template +NDArray operator*(T1&& arr1, T2&& arr2) { + if (arr1.isS() || arr2.isS()) + throw std::runtime_error( + "operator*(T&& arr1, T&& arr2): you can't use this method on String " + "arrays!"); + if (!Environment::getInstance().isExperimentalBuild() && + arr1.dataType() != arr2.dataType() && + (arr1.dataType() != DataType::BOOL || arr2.dataType() != BOOL)) + throw sd::datatype_exception::build( + "operator*(T&& arr1, T&& arr2): Cannot multiply different types", + arr1.dataType(), arr2.dataType()); + + PointersManager pointersManager(arr1.getContext(), + "operator*(T&& arr1, T&& arr2)"); + + if (arr1.lengthOf() == arr2.lengthOf() && arr1.rankOf() == arr2.rankOf()) { + const bool isArr1Rvalue = !std::is_reference::value && !arr1.isView(); + const bool isArr2Rvalue = !std::is_reference::value && !arr2.isView(); + + NDArray* result = nullptr; + if (isArr1Rvalue) + result = const_cast(&arr1); + else if (isArr2Rvalue) + result = const_cast(&arr2); + else + result = new NDArray(arr1.shapeInfo(), + DataTypeUtils::pickPairwiseResultType( + arr1.shapeInfo(), arr2.shapeInfo()), + false, arr1.getContext()); + + NDArray::prepareSpecialUse({result}, {&arr1, &arr2}); + NativeOpExecutioner::execPairwiseTransform( + arr1.getContext(), sd::pairwise::Multiply, arr1.buffer(), + arr1.shapeInfo(), arr1.specialBuffer(), arr1.specialShapeInfo(), + arr2.buffer(), arr2.shapeInfo(), arr2.specialBuffer(), + arr2.specialShapeInfo(), result->buffer(), result->shapeInfo(), + result->specialBuffer(), result->specialShapeInfo(), nullptr); + NDArray::registerSpecialUse({result}, {&arr1, &arr2}); + + if (!isArr1Rvalue && !isArr2Rvalue) { + NDArray res = std::move(*result); + delete result; + return std::move(res); + } + + return std::move(*result); + } + + return std::forward(arr1).applyTrueBroadcast( + sd::BroadcastOpsTuple::Multiply(), std::forward(arr2)); +} +template SD_EXPORT NDArray operator* + (NDArray& arr1, NDArray& arr2); +template SD_EXPORT NDArray operator* + (NDArray& arr1, NDArray&& arr2); +template SD_EXPORT NDArray operator* + (NDArray&& arr1, NDArray& arr2); +template SD_EXPORT NDArray operator* + (NDArray& arr1, const NDArray& arr2); +template SD_EXPORT NDArray operator* + (const NDArray& arr1, NDArray& arr2); +template SD_EXPORT NDArray operator* + (const NDArray& arr1, NDArray&& arr2); +template SD_EXPORT NDArray operator* + (const NDArray& arr1, + const NDArray& arr2); +template SD_EXPORT NDArray operator* + (NDArray&& arr1, const NDArray& arr2); +template SD_EXPORT NDArray operator* + (NDArray&& arr1, NDArray&& arr2); + +//////////////////////////////////////////////////////////////////////// +// multiplication operator array*array +template +NDArray operator/(T1&& arr1, T2&& arr2) { + if (arr1.isS() || arr2.isS()) + throw std::runtime_error( + "operator/(T&& arr1, T&& arr2): you can't use this method on String " + "arrays!"); + if (!Environment::getInstance().isExperimentalBuild() && + arr1.dataType() != arr2.dataType() && + (arr1.dataType() != DataType::BOOL || arr2.dataType() != BOOL)) + throw sd::datatype_exception::build( + "operator/(T&& arr1, T&& arr2): Cannot multiply different types", + arr1.dataType(), arr2.dataType()); + + PointersManager pointersManager(arr1.getContext(), + "operator/(T&& arr1, T&& arr2)"); + + if (arr1.lengthOf() == arr2.lengthOf() && arr1.rankOf() == arr2.rankOf()) { + const bool isArr1Rvalue = !std::is_reference::value && !arr1.isView(); + const bool isArr2Rvalue = !std::is_reference::value && !arr2.isView(); + + NDArray* result = nullptr; + if (isArr1Rvalue) + result = const_cast(&arr1); + else if (isArr2Rvalue) + result = const_cast(&arr2); + else + result = new NDArray(arr1.shapeInfo(), + DataTypeUtils::pickPairwiseResultType( + arr1.shapeInfo(), arr2.shapeInfo()), + false, arr1.getContext()); + + NDArray::prepareSpecialUse({result}, {&arr1, &arr2}); + NativeOpExecutioner::execPairwiseTransform( + arr1.getContext(), sd::pairwise::Divide, arr1.buffer(), + arr1.shapeInfo(), arr1.specialBuffer(), arr1.specialShapeInfo(), + arr2.buffer(), arr2.shapeInfo(), arr2.specialBuffer(), + arr2.specialShapeInfo(), result->buffer(), result->shapeInfo(), + result->specialBuffer(), result->specialShapeInfo(), nullptr); + NDArray::registerSpecialUse({result}, {&arr1, &arr2}); + + if (!isArr1Rvalue && !isArr2Rvalue) { + NDArray res = std::move(*result); + delete result; + return std::move(res); + } + + return std::move(*result); + } + + return std::forward(arr1).applyTrueBroadcast( + sd::BroadcastOpsTuple::Divide(), std::forward(arr2)); +} +template SD_EXPORT NDArray operator/ + (NDArray& arr1, NDArray& arr2); +template SD_EXPORT NDArray operator/ + (NDArray& arr1, NDArray&& arr2); +template SD_EXPORT NDArray operator/ + (NDArray&& arr1, NDArray& arr2); +template SD_EXPORT NDArray operator/ + (NDArray& arr1, const NDArray& arr2); +template SD_EXPORT NDArray operator/ + (const NDArray& arr1, NDArray& arr2); +template SD_EXPORT NDArray operator/ + (const NDArray& arr1, NDArray&& arr2); +template SD_EXPORT NDArray operator/ + (const NDArray& arr1, + const NDArray& arr2); +template SD_EXPORT NDArray operator/ + (NDArray&& arr1, const NDArray& arr2); +template SD_EXPORT NDArray operator/ + (NDArray&& arr1, NDArray&& arr2); + +/* +#ifndef __CLION_IDE__ +#include "NDArray.macro" +#endif + */ +} // namespace sd + +#endif + +////////////////////////////////////////////////////////////////////////// +// check whether array's rows (arg=0) or columns (arg=1) create orthogonal basis +// bool NDArray::hasOrthonormalBasis(const int arg) { +// if (isS()) +// throw std::runtime_error("NDArray::hasOrthonormalBasis: you can't use +// this method on String array!"); +// if(rankOf() !=2 ) +// throw std::runtime_error("NDArray::hasOrthBasis method: rank of +// ndarray is not equal 2 !"); + +// if(arg!=0 && arg!=1) +// throw std::runtime_error("NDArray::hasOrthBasis method: input +// argument is not equal to 0 or 1 !"); + +// const double eps = 1e-5; +// double dot = 0.f; + +// if(arg) { // check whether columns create orthogonal +// basis +// for(int j=0; j(i,j)*e(i,k); + +// if(sd::math::nd4j_abs(dot) > eps ) +// return false; + +// dot = 0.f; +// } + +// for(int j=0; j(i,j)*e(i,j); +// if(dot != 0.f && sd::math::nd4j_abs(sd::math::nd4j_sqrt(dot) - 1.f) > eps) +// return false; + +// dot = 0.f; +// } +// } +// else { // check whether rows create orthogonal basis +// for(int i=0; i(i,j)*e(k,j); + +// if(sd::math::nd4j_abs(dot) > eps ) +// return false; + +// dot = 0.; +// } + +// for(int i=0; i(i,j)*e(i,j); + +// if(dot!= 0. && sd::math::nd4j_abs(sd::math::nd4j_sqrt(dot) - 1.) > eps) +// return false; +// dot = 0.; +// } +// } +// return true; +// } diff --git a/libnd4j/include/array/impl/NDArrayFactory.cpp b/libnd4j/include/array/impl/NDArrayFactory.cpp index f14aa9dbb653..a84e639ec516 100644 --- a/libnd4j/include/array/impl/NDArrayFactory.cpp +++ b/libnd4j/include/array/impl/NDArrayFactory.cpp @@ -24,695 +24,1101 @@ #include #include #include -#include +#include #include -#include - - - - #include #include -namespace sd { +#include - //////////////////////////////////////////////////////////////////////// - template <> - ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::vector &data, sd::LaunchContext * context) { +namespace sd { - if ((int) shape.size() > MAX_RANK) - throw std::invalid_argument("NDArrayFactory::create: rank of NDArray can't exceed 32 !"); +SD_EXPORT NDArray NDArrayFactory::undefined() { return NDArray(); } - ShapeDescriptor descriptor(sd::DataType::BOOL, order, shape); +//////////////////////////////////////////////////////////////////////// +template <> +SD_EXPORT NDArray NDArrayFactory::create( + const char order, const std::vector& shape, + const std::vector& data, sd::LaunchContext* context) { + if ((int)shape.size() > MAX_RANK) + throw std::invalid_argument( + "NDArrayFactory::create: rank of NDArray can't exceed 32 !"); + + ShapeDescriptor descriptor(sd::DataType::BOOL, order, shape); + + if (descriptor.arrLength() != data.size()) { + nd4j_printf( + "NDArrayFactory::create: data size [%i] doesn't match shape length " + "[%lld]\n", + data.size(), descriptor.arrLength()); + throw std::runtime_error( + "NDArrayFactory::create: data size doesn't match shape"); + } + + bool* hostBuffer = nullptr; + ALLOCATE(hostBuffer, context->getWorkspace(), data.size(), bool); + std::copy(data.begin(), data.end(), hostBuffer); + + std::shared_ptr buffer = std::make_shared( + hostBuffer, data.size() * sizeof(bool), sd::DataType::BOOL, true, + context->getWorkspace()); + + NDArray result(buffer, descriptor, context); + + return result; +} - if (descriptor.arrLength() != data.size()) { - nd4j_printf("NDArrayFactory::create: data size [%i] doesn't match shape length [%lld]\n", data.size(), descriptor.arrLength()); - throw std::runtime_error("NDArrayFactory::create: data size doesn't match shape"); - } +//////////////////////////////////////////////////////////////////////// +template +NDArray NDArrayFactory::create(const char order, + const std::vector& shape, + const std::vector& data, + sd::LaunchContext* context) { + if ((int)shape.size() > MAX_RANK) + throw std::invalid_argument( + "NDArrayFactory::create: rank of NDArray can't exceed 32 !"); + + ShapeDescriptor descriptor(DataTypeUtils::fromT(), order, shape); + + if (descriptor.arrLength() != data.size()) { + nd4j_printf( + "NDArrayFactory::create: data size [%i] doesn't match shape length " + "[%lld]\n", + data.size(), descriptor.arrLength()); + throw std::runtime_error( + "NDArrayFactory::create: data size doesn't match shape"); + } + + std::shared_ptr buffer = std::make_shared( + data.data(), DataTypeUtils::fromT(), + descriptor.arrLength() * sizeof(T), context->getWorkspace()); + + NDArray result(buffer, descriptor, context); + + return result; +} +template SD_EXPORT NDArray NDArrayFactory::create( + const char order, const std::vector& shape, + const std::vector& data, sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create( + const char order, const std::vector& shape, + const std::vector& data, sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create( + const char order, const std::vector& shape, + const std::vector& data, sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create( + const char order, const std::vector& shape, + const std::vector& data, sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create( + const char order, const std::vector& shape, + const std::vector& data, sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create( + const char order, const std::vector& shape, + const std::vector& data, sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create( + const char order, const std::vector& shape, + const std::vector& data, sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create( + const char order, const std::vector& shape, + const std::vector& data, sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create( + const char order, const std::vector& shape, + const std::vector& data, sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create( + const char order, const std::vector& shape, + const std::vector& data, sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create( + const char order, const std::vector& shape, + const std::vector& data, sd::LaunchContext* context); - bool* hostBuffer = nullptr; - ALLOCATE(hostBuffer, context->getWorkspace(), data.size(), bool); - std::copy(data.begin(), data.end(), hostBuffer); +//////////////////////////////////////////////////////////////////////// +template +NDArray* NDArrayFactory::create_(const char order, + const std::vector& shape, + sd::LaunchContext* context) { + return create_(order, shape, DataTypeUtils::fromT(), context); +} +BUILD_SINGLE_TEMPLATE(template SD_EXPORT NDArray* NDArrayFactory::create_, + (const char order, const std::vector& shape, + sd::LaunchContext* context), + LIBND4J_TYPES); - std::shared_ptr buffer = std::make_shared(hostBuffer, data.size() * sizeof(bool), sd::DataType::BOOL, true, context->getWorkspace()); +//////////////////////////////////////////////////////////////////////// +template +void NDArrayFactory::memcpyFromVector(void* ptr, const std::vector& vector) { + memcpy(ptr, vector.data(), vector.size() * sizeof(T)); +} - NDArray result(buffer, descriptor, context); +template <> +void SD_EXPORT +NDArrayFactory::memcpyFromVector(void* ptr, const std::vector& vector) { + auto p = reinterpret_cast(ptr); + for (Nd4jLong e = 0; e < vector.size(); e++) p[e] = vector[e]; +} - return result; - } +template SD_EXPORT void NDArrayFactory::memcpyFromVector( + void* ptr, const std::vector& vector); +template SD_EXPORT void NDArrayFactory::memcpyFromVector( + void* ptr, const std::vector& vector); +template SD_EXPORT void NDArrayFactory::memcpyFromVector( + void* ptr, const std::vector& vector); +template SD_EXPORT void NDArrayFactory::memcpyFromVector( + void* ptr, const std::vector& vector); +template SD_EXPORT void NDArrayFactory::memcpyFromVector( + void* ptr, const std::vector& vector); +template SD_EXPORT void NDArrayFactory::memcpyFromVector( + void* ptr, const std::vector& vector); +template SD_EXPORT void NDArrayFactory::memcpyFromVector( + void* ptr, const std::vector& vector); +template SD_EXPORT void NDArrayFactory::memcpyFromVector( + void* ptr, const std::vector& vector); - //////////////////////////////////////////////////////////////////////// - template - NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::vector &data, sd::LaunchContext * context) { +#ifndef __JAVACPP_HACK__ +//////////////////////////////////////////////////////////////////////// +template +NDArray* NDArrayFactory::valueOf(const std::initializer_list& shape, + const T value, const char order, + sd::LaunchContext* context) { + return valueOf(std::vector(shape), value, order); +} +template SD_EXPORT NDArray* NDArrayFactory::valueOf( + const std::initializer_list& shape, const double value, + const char order, sd::LaunchContext* context); +template SD_EXPORT NDArray* NDArrayFactory::valueOf( + const std::initializer_list& shape, const float value, + const char order, sd::LaunchContext* context); +template SD_EXPORT NDArray* NDArrayFactory::valueOf( + const std::initializer_list& shape, const float16 value, + const char order, sd::LaunchContext* context); +template SD_EXPORT NDArray* NDArrayFactory::valueOf( + const std::initializer_list& shape, const bfloat16 value, + const char order, sd::LaunchContext* context); +template SD_EXPORT NDArray* NDArrayFactory::valueOf( + const std::initializer_list& shape, const Nd4jLong value, + const char order, sd::LaunchContext* context); +template SD_EXPORT NDArray* NDArrayFactory::valueOf( + const std::initializer_list& shape, const int value, + const char order, sd::LaunchContext* context); +template SD_EXPORT NDArray* NDArrayFactory::valueOf( + const std::initializer_list& shape, const uint8_t value, + const char order, sd::LaunchContext* context); +template SD_EXPORT NDArray* NDArrayFactory::valueOf( + const std::initializer_list& shape, const int8_t value, + const char order, sd::LaunchContext* context); +template SD_EXPORT NDArray* NDArrayFactory::valueOf( + const std::initializer_list& shape, const int16_t value, + const char order, sd::LaunchContext* context); +template SD_EXPORT NDArray* NDArrayFactory::valueOf( + const std::initializer_list& shape, const bool value, + const char order, sd::LaunchContext* context); - if ((int) shape.size() > MAX_RANK) - throw std::invalid_argument("NDArrayFactory::create: rank of NDArray can't exceed 32 !"); +//////////////////////////////////////////////////////////////////////// +template +NDArray NDArrayFactory::create(const char order, + const std::vector& shape, + const std::initializer_list& data, + sd::LaunchContext* context) { + std::vector vec(data); + return create(order, shape, vec, context); +} +template SD_EXPORT NDArray NDArrayFactory::create( + const char order, const std::vector& shape, + const std::initializer_list& data, sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create( + const char order, const std::vector& shape, + const std::initializer_list& data, sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create( + const char order, const std::vector& shape, + const std::initializer_list& data, sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create( + const char order, const std::vector& shape, + const std::initializer_list& data, sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create( + const char order, const std::vector& shape, + const std::initializer_list& data, sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create( + const char order, const std::vector& shape, + const std::initializer_list& data, sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create( + const char order, const std::vector& shape, + const std::initializer_list& data, sd::LaunchContext* context); +template SD_EXPORT NDArray +NDArrayFactory::create(const char order, const std::vector& shape, + const std::initializer_list& data, + sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create( + const char order, const std::vector& shape, + const std::initializer_list& data, sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create( + const char order, const std::vector& shape, + const std::initializer_list& data, sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create( + const char order, const std::vector& shape, + const std::initializer_list& data, sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create( + const char order, const std::vector& shape, + const std::initializer_list& data, sd::LaunchContext* context); - ShapeDescriptor descriptor(DataTypeUtils::fromT(), order, shape); +#endif - if (descriptor.arrLength() != data.size()) { - nd4j_printf("NDArrayFactory::create: data size [%i] doesn't match shape length [%lld]\n", data.size(), descriptor.arrLength()); - throw std::runtime_error("NDArrayFactory::create: data size doesn't match shape"); - } +//////////////////////////////////////////////////////////////////////// +template +NDArray* NDArrayFactory::create_(const T scalar, sd::LaunchContext* context) { + std::shared_ptr buffer = std::make_shared( + 1 * sizeof(T), DataTypeUtils::fromT(), context->getWorkspace(), true); - std::shared_ptr buffer = std::make_shared(data.data(), DataTypeUtils::fromT(), descriptor.arrLength() * sizeof(T), context->getWorkspace()); + NDArray* res = new NDArray( + buffer, ShapeDescriptor::scalarDescriptor(DataTypeUtils::fromT()), + context); - NDArray result(buffer, descriptor, context); + res->bufferAsT()[0] = scalar; - return result; - } - template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::vector& data, sd::LaunchContext * context); - template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::vector& data, sd::LaunchContext * context); - template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::vector& data, sd::LaunchContext * context); - template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::vector& data, sd::LaunchContext * context); - template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::vector& data, sd::LaunchContext * context); - template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::vector& data, sd::LaunchContext * context); - template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::vector& data, sd::LaunchContext * context); - template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::vector& data, sd::LaunchContext * context); - template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::vector& data, sd::LaunchContext * context); - template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::vector& data, sd::LaunchContext * context); - template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::vector& data, sd::LaunchContext * context); - template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::vector& data, sd::LaunchContext * context); + res->tickWriteHost(); + res->syncToDevice(); -//////////////////////////////////////////////////////////////////////// -template -NDArray* NDArrayFactory::create_(const char order, const std::vector &shape, sd::LaunchContext * context) { - return create_(order, shape, DataTypeUtils::fromT(), context); + return res; } -BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT NDArray* NDArrayFactory::create_, (const char order, const std::vector &shape, sd::LaunchContext * context), LIBND4J_TYPES); +template SD_EXPORT NDArray* NDArrayFactory::create_(const double scalar, + sd::LaunchContext* context); +template SD_EXPORT NDArray* NDArrayFactory::create_(const float scalar, + sd::LaunchContext* context); +template SD_EXPORT NDArray* NDArrayFactory::create_(const float16 scalar, + sd::LaunchContext* context); +template SD_EXPORT NDArray* NDArrayFactory::create_(const bfloat16 scalar, + sd::LaunchContext* context); +template SD_EXPORT NDArray* NDArrayFactory::create_(const Nd4jLong scalar, + sd::LaunchContext* context); +template SD_EXPORT NDArray* NDArrayFactory::create_(const int scalar, + sd::LaunchContext* context); +template SD_EXPORT NDArray* NDArrayFactory::create_(const bool scalar, + sd::LaunchContext* context); +template SD_EXPORT NDArray* NDArrayFactory::create_(const int8_t scalar, + sd::LaunchContext* context); +template SD_EXPORT NDArray* NDArrayFactory::create_(const uint8_t scalar, + sd::LaunchContext* context); +template SD_EXPORT NDArray* NDArrayFactory::create_(const uint16_t scalar, + sd::LaunchContext* context); +template SD_EXPORT NDArray* NDArrayFactory::create_(const uint32_t scalar, + sd::LaunchContext* context); +template SD_EXPORT NDArray* NDArrayFactory::create_(const uint64_t scalar, + sd::LaunchContext* context); +template SD_EXPORT NDArray* NDArrayFactory::create_(const int16_t scalar, + sd::LaunchContext* context); -//////////////////////////////////////////////////////////////////////// template -void NDArrayFactory::memcpyFromVector(void *ptr, const std::vector &vector) { +NDArray NDArrayFactory::create(sd::DataType type, const T scalar, + sd::LaunchContext* context) { + if (type == DataTypeUtils::fromT()) + return NDArrayFactory::create(scalar, context); - memcpy(ptr, vector.data(), vector.size() * sizeof(T)); -} + NDArray res(type, context); + res.p(0, scalar); + res.syncToDevice(); -template <> -void ND4J_EXPORT NDArrayFactory::memcpyFromVector(void *ptr, const std::vector &vector) { - auto p = reinterpret_cast(ptr); - for (Nd4jLong e = 0; e < vector.size(); e++) - p[e] = vector[e]; + return res; } +// BUILD_DOUBLE_TEMPLATE(template SD_EXPORT NDArray NDArrayFactory::create, +// (DataType type, const T scalar, sd::LaunchContext * context), +// LIBND4J_TYPES); +template SD_EXPORT NDArray NDArrayFactory::create(DataType type, + const double scalar, + sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create(DataType type, + const float scalar, + sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create(DataType type, + const float16 scalar, + sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create(DataType type, + const bfloat16 scalar, + sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create(DataType type, + const Nd4jLong scalar, + sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create(DataType type, + const int scalar, + sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create(DataType type, + const int8_t scalar, + sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create(DataType type, + const uint8_t scalar, + sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create(DataType type, + const uint16_t scalar, + sd::LaunchContext* workspace); +template SD_EXPORT NDArray NDArrayFactory::create(DataType type, + const uint32_t scalar, + sd::LaunchContext* workspace); +template SD_EXPORT NDArray NDArrayFactory::create(DataType type, + const uint64_t scalar, + sd::LaunchContext* workspace); +template SD_EXPORT NDArray NDArrayFactory::create(DataType type, + const int16_t scalar, + sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create(DataType type, + const bool scalar, + sd::LaunchContext* context); -template ND4J_EXPORT void NDArrayFactory::memcpyFromVector(void *ptr, const std::vector &vector); -template ND4J_EXPORT void NDArrayFactory::memcpyFromVector(void *ptr, const std::vector &vector); -template ND4J_EXPORT void NDArrayFactory::memcpyFromVector(void *ptr, const std::vector &vector); -template ND4J_EXPORT void NDArrayFactory::memcpyFromVector(void *ptr, const std::vector &vector); -template ND4J_EXPORT void NDArrayFactory::memcpyFromVector(void *ptr, const std::vector &vector); -template ND4J_EXPORT void NDArrayFactory::memcpyFromVector(void *ptr, const std::vector &vector); -template ND4J_EXPORT void NDArrayFactory::memcpyFromVector(void *ptr, const std::vector &vector); -template ND4J_EXPORT void NDArrayFactory::memcpyFromVector(void *ptr, const std::vector &vector); +template +NDArray NDArrayFactory::create(const T scalar, sd::LaunchContext* context) { + std::shared_ptr buffer = std::make_shared( + 1 * sizeof(T), DataTypeUtils::fromT(), context->getWorkspace(), true); + NDArray res(buffer, + ShapeDescriptor::scalarDescriptor(DataTypeUtils::fromT()), + context); -#ifndef __JAVACPP_HACK__ - //////////////////////////////////////////////////////////////////////// - template - NDArray* NDArrayFactory::valueOf(const std::initializer_list& shape, const T value, const char order, sd::LaunchContext * context) { - return valueOf(std::vector(shape), value, order); - } - template ND4J_EXPORT NDArray* NDArrayFactory::valueOf(const std::initializer_list& shape, const double value, const char order, sd::LaunchContext * context); - template ND4J_EXPORT NDArray* NDArrayFactory::valueOf(const std::initializer_list& shape, const float value, const char order, sd::LaunchContext * context); - template ND4J_EXPORT NDArray* NDArrayFactory::valueOf(const std::initializer_list& shape, const float16 value, const char order, sd::LaunchContext * context); - template ND4J_EXPORT NDArray* NDArrayFactory::valueOf(const std::initializer_list& shape, const bfloat16 value, const char order, sd::LaunchContext * context); - template ND4J_EXPORT NDArray* NDArrayFactory::valueOf(const std::initializer_list& shape, const Nd4jLong value, const char order, sd::LaunchContext * context); - template ND4J_EXPORT NDArray* NDArrayFactory::valueOf(const std::initializer_list& shape, const int value, const char order, sd::LaunchContext * context); - template ND4J_EXPORT NDArray* NDArrayFactory::valueOf(const std::initializer_list& shape, const uint8_t value, const char order, sd::LaunchContext * context); - template ND4J_EXPORT NDArray* NDArrayFactory::valueOf(const std::initializer_list& shape, const int8_t value, const char order, sd::LaunchContext * context); - template ND4J_EXPORT NDArray* NDArrayFactory::valueOf(const std::initializer_list& shape, const int16_t value, const char order, sd::LaunchContext * context); - template ND4J_EXPORT NDArray* NDArrayFactory::valueOf(const std::initializer_list& shape, const bool value, const char order, sd::LaunchContext * context); + res.bufferAsT()[0] = scalar; -//////////////////////////////////////////////////////////////////////// - template - NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::initializer_list& data, sd::LaunchContext * context) { - std::vector vec(data); - return create(order, shape, vec, context); - } - template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::initializer_list& data, sd::LaunchContext * context); - template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::initializer_list& data, sd::LaunchContext * context); - template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::initializer_list& data, sd::LaunchContext * context); - template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::initializer_list& data, sd::LaunchContext * context); - template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::initializer_list& data, sd::LaunchContext * context); - template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::initializer_list& data, sd::LaunchContext * context); - template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::initializer_list& data, sd::LaunchContext * context); - template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::initializer_list& data, sd::LaunchContext * context); - template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::initializer_list& data, sd::LaunchContext * context); - template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::initializer_list& data, sd::LaunchContext * context); - template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::initializer_list& data, sd::LaunchContext * context); - template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::initializer_list& data, sd::LaunchContext * context); + res.tickWriteHost(); + res.syncToDevice(); -#endif + return res; +} +template SD_EXPORT NDArray NDArrayFactory::create(const double scalar, + sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create(const float scalar, + sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create(const float16 scalar, + sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create(const bfloat16 scalar, + sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create(const Nd4jLong scalar, + sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create(const int scalar, + sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create(const int8_t scalar, + sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create(const uint8_t scalar, + sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create(const int16_t scalar, + sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create(const uint16_t scalar, + sd::LaunchContext* workspace); +template SD_EXPORT NDArray NDArrayFactory::create(const uint32_t scalar, + sd::LaunchContext* workspace); +template SD_EXPORT NDArray NDArrayFactory::create(const uint64_t scalar, + sd::LaunchContext* workspace); +template SD_EXPORT NDArray NDArrayFactory::create(const bool scalar, + sd::LaunchContext* context); //////////////////////////////////////////////////////////////////////// - template - NDArray* NDArrayFactory::create_(const T scalar, sd::LaunchContext * context) { - - std::shared_ptr buffer = std::make_shared(1 * sizeof(T), DataTypeUtils::fromT(), context->getWorkspace(), true); - - NDArray* res = new NDArray(buffer, ShapeDescriptor::scalarDescriptor(DataTypeUtils::fromT()), context); - - res->bufferAsT()[0] = scalar; - - res->tickWriteHost(); - res->syncToDevice(); - - return res; - } - template ND4J_EXPORT NDArray* NDArrayFactory::create_(const double scalar, sd::LaunchContext * context); - template ND4J_EXPORT NDArray* NDArrayFactory::create_(const float scalar, sd::LaunchContext * context); - template ND4J_EXPORT NDArray* NDArrayFactory::create_(const float16 scalar, sd::LaunchContext * context); - template ND4J_EXPORT NDArray* NDArrayFactory::create_(const bfloat16 scalar, sd::LaunchContext * context); - template ND4J_EXPORT NDArray* NDArrayFactory::create_(const Nd4jLong scalar, sd::LaunchContext * context); - template ND4J_EXPORT NDArray* NDArrayFactory::create_(const int scalar, sd::LaunchContext * context); - template ND4J_EXPORT NDArray* NDArrayFactory::create_(const bool scalar, sd::LaunchContext * context); - template ND4J_EXPORT NDArray* NDArrayFactory::create_(const int8_t scalar, sd::LaunchContext * context); - template ND4J_EXPORT NDArray* NDArrayFactory::create_(const uint8_t scalar, sd::LaunchContext * context); - template ND4J_EXPORT NDArray* NDArrayFactory::create_(const uint16_t scalar, sd::LaunchContext * context); - template ND4J_EXPORT NDArray* NDArrayFactory::create_(const uint32_t scalar, sd::LaunchContext * context); - template ND4J_EXPORT NDArray* NDArrayFactory::create_(const uint64_t scalar, sd::LaunchContext * context); - template ND4J_EXPORT NDArray* NDArrayFactory::create_(const int16_t scalar, sd::LaunchContext * context); - - template - NDArray NDArrayFactory::create(sd::DataType type, const T scalar, sd::LaunchContext * context) { - - if (type == DataTypeUtils::fromT()) - return NDArrayFactory::create(scalar, context); - - NDArray res(type, context); - res.p(0, scalar); - res.syncToDevice(); - - return res; - } -// BUILD_DOUBLE_TEMPLATE(template ND4J_EXPORT NDArray NDArrayFactory::create, (DataType type, const T scalar, sd::LaunchContext * context), LIBND4J_TYPES); - template ND4J_EXPORT NDArray NDArrayFactory::create(DataType type, const double scalar, sd::LaunchContext * context); - template ND4J_EXPORT NDArray NDArrayFactory::create(DataType type, const float scalar, sd::LaunchContext * context); - template ND4J_EXPORT NDArray NDArrayFactory::create(DataType type, const float16 scalar, sd::LaunchContext * context); - template ND4J_EXPORT NDArray NDArrayFactory::create(DataType type, const bfloat16 scalar, sd::LaunchContext * context); - template ND4J_EXPORT NDArray NDArrayFactory::create(DataType type, const Nd4jLong scalar, sd::LaunchContext * context); - template ND4J_EXPORT NDArray NDArrayFactory::create(DataType type, const int scalar, sd::LaunchContext * context); - template ND4J_EXPORT NDArray NDArrayFactory::create(DataType type, const int8_t scalar, sd::LaunchContext * context); - template ND4J_EXPORT NDArray NDArrayFactory::create(DataType type, const uint8_t scalar, sd::LaunchContext * context); - template ND4J_EXPORT NDArray NDArrayFactory::create(DataType type, const uint16_t scalar, sd::LaunchContext* workspace); - template ND4J_EXPORT NDArray NDArrayFactory::create(DataType type, const uint32_t scalar, sd::LaunchContext* workspace); - template ND4J_EXPORT NDArray NDArrayFactory::create(DataType type, const uint64_t scalar, sd::LaunchContext* workspace); - template ND4J_EXPORT NDArray NDArrayFactory::create(DataType type, const int16_t scalar, sd::LaunchContext * context); - template ND4J_EXPORT NDArray NDArrayFactory::create(DataType type, const bool scalar, sd::LaunchContext * context); - - template - NDArray NDArrayFactory::create(const T scalar, sd::LaunchContext * context) { - - std::shared_ptr buffer = std::make_shared(1 * sizeof(T), DataTypeUtils::fromT(), context->getWorkspace(), true); - - NDArray res(buffer, ShapeDescriptor::scalarDescriptor(DataTypeUtils::fromT()), context); - - res.bufferAsT()[0] = scalar; - - res.tickWriteHost(); - res.syncToDevice(); - - return res; - } - template ND4J_EXPORT NDArray NDArrayFactory::create(const double scalar, sd::LaunchContext * context); - template ND4J_EXPORT NDArray NDArrayFactory::create(const float scalar, sd::LaunchContext * context); - template ND4J_EXPORT NDArray NDArrayFactory::create(const float16 scalar, sd::LaunchContext * context); - template ND4J_EXPORT NDArray NDArrayFactory::create(const bfloat16 scalar, sd::LaunchContext * context); - template ND4J_EXPORT NDArray NDArrayFactory::create(const Nd4jLong scalar, sd::LaunchContext * context); - template ND4J_EXPORT NDArray NDArrayFactory::create(const int scalar, sd::LaunchContext * context); - template ND4J_EXPORT NDArray NDArrayFactory::create(const int8_t scalar, sd::LaunchContext * context); - template ND4J_EXPORT NDArray NDArrayFactory::create(const uint8_t scalar, sd::LaunchContext * context); - template ND4J_EXPORT NDArray NDArrayFactory::create(const int16_t scalar, sd::LaunchContext * context); - template ND4J_EXPORT NDArray NDArrayFactory::create(const uint16_t scalar, sd::LaunchContext* workspace); - template ND4J_EXPORT NDArray NDArrayFactory::create(const uint32_t scalar, sd::LaunchContext* workspace); - template ND4J_EXPORT NDArray NDArrayFactory::create(const uint64_t scalar, sd::LaunchContext* workspace); - template ND4J_EXPORT NDArray NDArrayFactory::create(const bool scalar, sd::LaunchContext * context); - +template +NDArray* NDArrayFactory::create_(const char order, + const std::vector& shape, + const std::vector& data, + sd::LaunchContext* context) { + return new NDArray(NDArrayFactory::create(order, shape, data, context)); +} +template SD_EXPORT NDArray* NDArrayFactory::create_( + const char order, const std::vector& shape, + const std::vector& data, sd::LaunchContext* context); +template SD_EXPORT NDArray* NDArrayFactory::create_( + const char order, const std::vector& shape, + const std::vector& data, sd::LaunchContext* context); +template SD_EXPORT NDArray* NDArrayFactory::create_( + const char order, const std::vector& shape, + const std::vector& data, sd::LaunchContext* context); +template SD_EXPORT NDArray* NDArrayFactory::create_( + const char order, const std::vector& shape, + const std::vector& data, sd::LaunchContext* context); +template SD_EXPORT NDArray* NDArrayFactory::create_( + const char order, const std::vector& shape, + const std::vector& data, sd::LaunchContext* context); +template SD_EXPORT NDArray* NDArrayFactory::create_( + const char order, const std::vector& shape, + const std::vector& data, sd::LaunchContext* context); +template SD_EXPORT NDArray* NDArrayFactory::create_( + const char order, const std::vector& shape, + const std::vector& data, sd::LaunchContext* context); +template SD_EXPORT NDArray* NDArrayFactory::create_( + const char order, const std::vector& shape, + const std::vector& data, sd::LaunchContext* context); +template SD_EXPORT NDArray* NDArrayFactory::create_( + const char order, const std::vector& shape, + const std::vector& data, sd::LaunchContext* context); +template SD_EXPORT NDArray* NDArrayFactory::create_( + const char order, const std::vector& shape, + const std::vector& data, sd::LaunchContext* context); +template SD_EXPORT NDArray* NDArrayFactory::create_( + const char order, const std::vector& shape, + const std::vector& data, sd::LaunchContext* context); +template SD_EXPORT NDArray* NDArrayFactory::create_( + const char order, const std::vector& shape, + const std::vector& data, sd::LaunchContext* context); +template SD_EXPORT NDArray* NDArrayFactory::create_( + const char order, const std::vector& shape, + const std::vector& data, sd::LaunchContext* context); //////////////////////////////////////////////////////////////////////// -template -NDArray* NDArrayFactory::create_(const char order, const std::vector &shape, const std::vector &data, sd::LaunchContext * context) { - - return new NDArray(NDArrayFactory::create(order, shape, data, context)); -} -template ND4J_EXPORT NDArray* NDArrayFactory::create_(const char order, const std::vector &shape, const std::vector &data, sd::LaunchContext * context); -template ND4J_EXPORT NDArray* NDArrayFactory::create_(const char order, const std::vector &shape, const std::vector &data, sd::LaunchContext * context); -template ND4J_EXPORT NDArray* NDArrayFactory::create_(const char order, const std::vector &shape, const std::vector &data, sd::LaunchContext * context); -template ND4J_EXPORT NDArray* NDArrayFactory::create_(const char order, const std::vector &shape, const std::vector &data, sd::LaunchContext * context); -template ND4J_EXPORT NDArray* NDArrayFactory::create_(const char order, const std::vector &shape, const std::vector &data, sd::LaunchContext * context); -template ND4J_EXPORT NDArray* NDArrayFactory::create_(const char order, const std::vector &shape, const std::vector &data, sd::LaunchContext * context); -template ND4J_EXPORT NDArray* NDArrayFactory::create_(const char order, const std::vector &shape, const std::vector &data, sd::LaunchContext * context); -template ND4J_EXPORT NDArray* NDArrayFactory::create_(const char order, const std::vector &shape, const std::vector &data, sd::LaunchContext * context); -template ND4J_EXPORT NDArray* NDArrayFactory::create_(const char order, const std::vector &shape, const std::vector &data, sd::LaunchContext * context); -template ND4J_EXPORT NDArray* NDArrayFactory::create_(const char order, const std::vector &shape, const std::vector &data, sd::LaunchContext * context); -template ND4J_EXPORT NDArray* NDArrayFactory::create_(const char order, const std::vector &shape, const std::vector &data, sd::LaunchContext * context); -template ND4J_EXPORT NDArray* NDArrayFactory::create_(const char order, const std::vector &shape, const std::vector &data, sd::LaunchContext * context); -template ND4J_EXPORT NDArray* NDArrayFactory::create_(const char order, const std::vector &shape, const std::vector &data, sd::LaunchContext * context); - - - //////////////////////////////////////////////////////////////////////// - template <> - ND4J_EXPORT NDArray* NDArrayFactory::valueOf(const std::vector& shape, NDArray* value, const char order, sd::LaunchContext * context) { - auto result = create_(order, shape, value->dataType(), context); - result->assign(*value); - return result; - } - - template <> - ND4J_EXPORT NDArray* NDArrayFactory::valueOf(const std::vector& shape, NDArray& value, const char order, sd::LaunchContext * context) { - auto result = create_(order, shape, value.dataType(), context); - result->assign(value); - return result; - } - - template - NDArray* NDArrayFactory::valueOf(const std::vector& shape, const T value, const char order, sd::LaunchContext * context) { - auto result = create_(order, shape, DataTypeUtils::fromT()); - result->assign(value); - return result; - } - template ND4J_EXPORT NDArray* NDArrayFactory::valueOf(const std::vector& shape, const double value, const char order, sd::LaunchContext * context); - template ND4J_EXPORT NDArray* NDArrayFactory::valueOf(const std::vector& shape, const float value, const char order, sd::LaunchContext * context); - template ND4J_EXPORT NDArray* NDArrayFactory::valueOf(const std::vector& shape, const float16 value, const char order, sd::LaunchContext * context); - template ND4J_EXPORT NDArray* NDArrayFactory::valueOf(const std::vector& shape, const bfloat16 value, const char order, sd::LaunchContext * context); - template ND4J_EXPORT NDArray* NDArrayFactory::valueOf(const std::vector& shape, const Nd4jLong value, const char order, sd::LaunchContext * context); - template ND4J_EXPORT NDArray* NDArrayFactory::valueOf(const std::vector& shape, const int value, const char order, sd::LaunchContext * context); - template ND4J_EXPORT NDArray* NDArrayFactory::valueOf(const std::vector& shape, const int16_t value, const char order, sd::LaunchContext * context); - template ND4J_EXPORT NDArray* NDArrayFactory::valueOf(const std::vector& shape, const int8_t value, const char order, sd::LaunchContext * context); - template ND4J_EXPORT NDArray* NDArrayFactory::valueOf(const std::vector& shape, const uint8_t value, const char order, sd::LaunchContext * context); - template ND4J_EXPORT NDArray* NDArrayFactory::valueOf(const std::vector& shape, const bool value, const char order, sd::LaunchContext * context); - - - //////////////////////////////////////////////////////////////////////// - template - NDArray* NDArrayFactory::linspace(const T from, const T to, const Nd4jLong numElements) { - NDArray* result = NDArrayFactory::vector(numElements); - //TO DO: linspace should be executed on DEVICE, but only CPU version implemnted! - for (Nd4jLong e = 0; e < numElements; e++) { - T step = (T) e / ((T) numElements - (T) 1); - result->p(e, (from * ((T) 1 - step) + step * to)); - } - result->syncToDevice(); - - return result; - } - template ND4J_EXPORT NDArray* NDArrayFactory::linspace(const double from, const double to, const Nd4jLong numElements); - template ND4J_EXPORT NDArray* NDArrayFactory::linspace(const float from, const float to, const Nd4jLong numElements); - template ND4J_EXPORT NDArray* NDArrayFactory::linspace(const float16 from, const float16 to, const Nd4jLong numElements); - template ND4J_EXPORT NDArray* NDArrayFactory::linspace(const bfloat16 from, const bfloat16 to, const Nd4jLong numElements); - template ND4J_EXPORT NDArray* NDArrayFactory::linspace(const Nd4jLong from, const Nd4jLong to, const Nd4jLong numElements); - template ND4J_EXPORT NDArray* NDArrayFactory::linspace(const int from, const int to, const Nd4jLong numElements); - template ND4J_EXPORT NDArray* NDArrayFactory::linspace(const int16_t from, const int16_t to, const Nd4jLong numElements); - template ND4J_EXPORT NDArray* NDArrayFactory::linspace(const uint8_t from, const uint8_t to, const Nd4jLong numElements); - template ND4J_EXPORT NDArray* NDArrayFactory::linspace(const uint16_t from, const uint16_t to, const Nd4jLong numElements); - template ND4J_EXPORT NDArray* NDArrayFactory::linspace(const uint32_t from, const uint32_t to, const Nd4jLong numElements); - template ND4J_EXPORT NDArray* NDArrayFactory::linspace(const uint64_t from, const uint64_t to, const Nd4jLong numElements); - template ND4J_EXPORT NDArray* NDArrayFactory::linspace(const int8_t from, const int8_t to, const Nd4jLong numElements); - template ND4J_EXPORT NDArray* NDArrayFactory::linspace(const bool from, const bool to, const Nd4jLong numElements); +template <> +SD_EXPORT NDArray* NDArrayFactory::valueOf(const std::vector& shape, + NDArray* value, const char order, + sd::LaunchContext* context) { + auto result = create_(order, shape, value->dataType(), context); + result->assign(*value); + return result; +} + +template <> +SD_EXPORT NDArray* NDArrayFactory::valueOf(const std::vector& shape, + NDArray& value, const char order, + sd::LaunchContext* context) { + auto result = create_(order, shape, value.dataType(), context); + result->assign(value); + return result; +} + +template +NDArray* NDArrayFactory::valueOf(const std::vector& shape, + const T value, const char order, + sd::LaunchContext* context) { + auto result = create_(order, shape, DataTypeUtils::fromT()); + result->assign(value); + return result; +} +template SD_EXPORT NDArray* NDArrayFactory::valueOf( + const std::vector& shape, const double value, const char order, + sd::LaunchContext* context); +template SD_EXPORT NDArray* NDArrayFactory::valueOf( + const std::vector& shape, const float value, const char order, + sd::LaunchContext* context); +template SD_EXPORT NDArray* NDArrayFactory::valueOf( + const std::vector& shape, const float16 value, const char order, + sd::LaunchContext* context); +template SD_EXPORT NDArray* NDArrayFactory::valueOf( + const std::vector& shape, const bfloat16 value, const char order, + sd::LaunchContext* context); +template SD_EXPORT NDArray* NDArrayFactory::valueOf( + const std::vector& shape, const Nd4jLong value, const char order, + sd::LaunchContext* context); +template SD_EXPORT NDArray* NDArrayFactory::valueOf( + const std::vector& shape, const int value, const char order, + sd::LaunchContext* context); +template SD_EXPORT NDArray* NDArrayFactory::valueOf( + const std::vector& shape, const int16_t value, const char order, + sd::LaunchContext* context); +template SD_EXPORT NDArray* NDArrayFactory::valueOf( + const std::vector& shape, const int8_t value, const char order, + sd::LaunchContext* context); +template SD_EXPORT NDArray* NDArrayFactory::valueOf( + const std::vector& shape, const uint8_t value, const char order, + sd::LaunchContext* context); +template SD_EXPORT NDArray* NDArrayFactory::valueOf( + const std::vector& shape, const bool value, const char order, + sd::LaunchContext* context); //////////////////////////////////////////////////////////////////////// - template - NDArray* NDArrayFactory::vector(Nd4jLong length, const T value, sd::LaunchContext * context) { - - std::shared_ptr buffer = std::make_shared(length * sizeof(T), DataTypeUtils::fromT(), context->getWorkspace(), true); - - auto res = new NDArray(buffer, ShapeDescriptor::vectorDescriptor(length, DataTypeUtils::fromT()), context); - - if (value == (T)0.0f) - res->nullify(); - else - res->assign(value); - - return res; - } - template ND4J_EXPORT NDArray* NDArrayFactory::vector(Nd4jLong length, const double startingValue, sd::LaunchContext * context); - template ND4J_EXPORT NDArray* NDArrayFactory::vector(Nd4jLong length, const float startingValue, sd::LaunchContext * context); - template ND4J_EXPORT NDArray* NDArrayFactory::vector(Nd4jLong length, const float16 startingValue, sd::LaunchContext * context); - template ND4J_EXPORT NDArray* NDArrayFactory::vector(Nd4jLong length, const bfloat16 startingValue, sd::LaunchContext * context); - template ND4J_EXPORT NDArray* NDArrayFactory::vector(Nd4jLong length, const Nd4jLong startingValue, sd::LaunchContext * context); - template ND4J_EXPORT NDArray* NDArrayFactory::vector(Nd4jLong length, const int startingValue, sd::LaunchContext * context); - template ND4J_EXPORT NDArray* NDArrayFactory::vector(Nd4jLong length, const uint8_t startingValue, sd::LaunchContext * context); - template ND4J_EXPORT NDArray* NDArrayFactory::vector(Nd4jLong length, const uint16_t startingValue, sd::LaunchContext *workspace); - template ND4J_EXPORT NDArray* NDArrayFactory::vector(Nd4jLong length, const uint32_t startingValue, sd::LaunchContext *workspace); - template ND4J_EXPORT NDArray* NDArrayFactory::vector(Nd4jLong length, const uint64_t startingValue, sd::LaunchContext *workspace); - template ND4J_EXPORT NDArray* NDArrayFactory::vector(Nd4jLong length, const int8_t startingValue, sd::LaunchContext * context); - template ND4J_EXPORT NDArray* NDArrayFactory::vector(Nd4jLong length, const int16_t startingValue, sd::LaunchContext * context); - template ND4J_EXPORT NDArray* NDArrayFactory::vector(Nd4jLong length, const bool startingValue, sd::LaunchContext * context); +template +NDArray* NDArrayFactory::linspace(const T from, const T to, + const Nd4jLong numElements) { + NDArray* result = NDArrayFactory::vector(numElements); + // TO DO: linspace should be executed on DEVICE, but only CPU version + // implemnted! + for (Nd4jLong e = 0; e < numElements; e++) { + T step = (T)e / ((T)numElements - (T)1); + result->p(e, (from * ((T)1 - step) + step * to)); + } + result->syncToDevice(); + + return result; +} +template SD_EXPORT NDArray* NDArrayFactory::linspace( + const double from, const double to, const Nd4jLong numElements); +template SD_EXPORT NDArray* NDArrayFactory::linspace( + const float from, const float to, const Nd4jLong numElements); +template SD_EXPORT NDArray* NDArrayFactory::linspace( + const float16 from, const float16 to, const Nd4jLong numElements); +template SD_EXPORT NDArray* NDArrayFactory::linspace( + const bfloat16 from, const bfloat16 to, const Nd4jLong numElements); +template SD_EXPORT NDArray* NDArrayFactory::linspace( + const Nd4jLong from, const Nd4jLong to, const Nd4jLong numElements); +template SD_EXPORT NDArray* NDArrayFactory::linspace( + const int from, const int to, const Nd4jLong numElements); +template SD_EXPORT NDArray* NDArrayFactory::linspace( + const int16_t from, const int16_t to, const Nd4jLong numElements); +template SD_EXPORT NDArray* NDArrayFactory::linspace( + const uint8_t from, const uint8_t to, const Nd4jLong numElements); +template SD_EXPORT NDArray* NDArrayFactory::linspace( + const uint16_t from, const uint16_t to, const Nd4jLong numElements); +template SD_EXPORT NDArray* NDArrayFactory::linspace( + const uint32_t from, const uint32_t to, const Nd4jLong numElements); +template SD_EXPORT NDArray* NDArrayFactory::linspace( + const uint64_t from, const uint64_t to, const Nd4jLong numElements); +template SD_EXPORT NDArray* NDArrayFactory::linspace( + const int8_t from, const int8_t to, const Nd4jLong numElements); +template SD_EXPORT NDArray* NDArrayFactory::linspace( + const bool from, const bool to, const Nd4jLong numElements); //////////////////////////////////////////////////////////////////////// - template - NDArray NDArrayFactory::create(const char order, const std::initializer_list& shape, sd::LaunchContext * context) { - std::vector vec(shape); - return create(order, vec, context); - } - BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT NDArray NDArrayFactory::create, (const char, const std::initializer_list&, sd::LaunchContext * context), LIBND4J_TYPES); +template +NDArray* NDArrayFactory::vector(Nd4jLong length, const T value, + sd::LaunchContext* context) { + std::shared_ptr buffer = std::make_shared( + length * sizeof(T), DataTypeUtils::fromT(), context->getWorkspace(), + true); + + auto res = new NDArray( + buffer, + ShapeDescriptor::vectorDescriptor(length, DataTypeUtils::fromT()), + context); + + if (value == (T)0.0f) + res->nullify(); + else + res->assign(value); + + return res; +} +template SD_EXPORT NDArray* NDArrayFactory::vector(Nd4jLong length, + const double startingValue, + sd::LaunchContext* context); +template SD_EXPORT NDArray* NDArrayFactory::vector(Nd4jLong length, + const float startingValue, + sd::LaunchContext* context); +template SD_EXPORT NDArray* NDArrayFactory::vector(Nd4jLong length, + const float16 startingValue, + sd::LaunchContext* context); +template SD_EXPORT NDArray* NDArrayFactory::vector(Nd4jLong length, + const bfloat16 startingValue, + sd::LaunchContext* context); +template SD_EXPORT NDArray* NDArrayFactory::vector(Nd4jLong length, + const Nd4jLong startingValue, + sd::LaunchContext* context); +template SD_EXPORT NDArray* NDArrayFactory::vector(Nd4jLong length, + const int startingValue, + sd::LaunchContext* context); +template SD_EXPORT NDArray* NDArrayFactory::vector(Nd4jLong length, + const uint8_t startingValue, + sd::LaunchContext* context); +template SD_EXPORT NDArray* NDArrayFactory::vector( + Nd4jLong length, const uint16_t startingValue, + sd::LaunchContext* workspace); +template SD_EXPORT NDArray* NDArrayFactory::vector( + Nd4jLong length, const uint32_t startingValue, + sd::LaunchContext* workspace); +template SD_EXPORT NDArray* NDArrayFactory::vector( + Nd4jLong length, const uint64_t startingValue, + sd::LaunchContext* workspace); +template SD_EXPORT NDArray* NDArrayFactory::vector(Nd4jLong length, + const int8_t startingValue, + sd::LaunchContext* context); +template SD_EXPORT NDArray* NDArrayFactory::vector(Nd4jLong length, + const int16_t startingValue, + sd::LaunchContext* context); +template SD_EXPORT NDArray* NDArrayFactory::vector(Nd4jLong length, + const bool startingValue, + sd::LaunchContext* context); //////////////////////////////////////////////////////////////////////// - template - NDArray NDArrayFactory::create(const char order, const std::vector &shape, sd::LaunchContext * context) { - return create(order, shape, DataTypeUtils::fromT(), context); - } - BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT NDArray NDArrayFactory::create, (const char order, const std::vector &shape, sd::LaunchContext * context), LIBND4J_TYPES); +template +NDArray NDArrayFactory::create(const char order, + const std::initializer_list& shape, + sd::LaunchContext* context) { + std::vector vec(shape); + return create(order, vec, context); +} +BUILD_SINGLE_TEMPLATE(template SD_EXPORT NDArray NDArrayFactory::create, + (const char, const std::initializer_list&, + sd::LaunchContext* context), + LIBND4J_TYPES); //////////////////////////////////////////////////////////////////////// -NDArray NDArrayFactory::create(const char order, const std::vector &shape, sd::DataType dtype, sd::LaunchContext* context) { +template +NDArray NDArrayFactory::create(const char order, + const std::vector& shape, + sd::LaunchContext* context) { + return create(order, shape, DataTypeUtils::fromT(), context); +} +BUILD_SINGLE_TEMPLATE(template SD_EXPORT NDArray NDArrayFactory::create, + (const char order, const std::vector& shape, + sd::LaunchContext* context), + LIBND4J_TYPES); - if ((int) shape.size() > MAX_RANK) - throw std::invalid_argument("NDArrayFactory::create: rank of NDArray can't exceed 32"); +//////////////////////////////////////////////////////////////////////// +NDArray NDArrayFactory::create(const char order, + const std::vector& shape, + sd::DataType dtype, sd::LaunchContext* context) { + if ((int)shape.size() > MAX_RANK) + throw std::invalid_argument( + "NDArrayFactory::create: rank of NDArray can't exceed 32"); - ShapeDescriptor descriptor(dtype, order, shape); + ShapeDescriptor descriptor(dtype, order, shape); - std::shared_ptr buffer = std::make_shared(descriptor.arrLength() * DataTypeUtils::sizeOfElement(dtype), dtype, context->getWorkspace()); + std::shared_ptr buffer = std::make_shared( + descriptor.arrLength() * DataTypeUtils::sizeOfElement(dtype), dtype, + context->getWorkspace()); - NDArray result(buffer, descriptor, context); + NDArray result(buffer, descriptor, context); - result.nullify(); + result.nullify(); - return result; + return result; } - //////////////////////////////////////////////////////////////////////// -NDArray NDArrayFactory::create(sd::DataType dtype, sd::LaunchContext * context) { +NDArray NDArrayFactory::create(sd::DataType dtype, sd::LaunchContext* context) { + std::shared_ptr buffer = + std::make_shared(DataTypeUtils::sizeOfElement(dtype), dtype, + context->getWorkspace(), true); - std::shared_ptr buffer = std::make_shared(DataTypeUtils::sizeOfElement(dtype), dtype, context->getWorkspace(), true); + NDArray res(buffer, ShapeDescriptor::scalarDescriptor(dtype), context); - NDArray res(buffer, ShapeDescriptor::scalarDescriptor(dtype), context); + res.nullify(); - res.nullify(); - - return res; + return res; } -NDArray* NDArrayFactory::create_(sd::DataType dtype, sd::LaunchContext * context) { - auto result = new NDArray(); - *result = NDArrayFactory::create(dtype, context); - return result; +NDArray* NDArrayFactory::create_(sd::DataType dtype, + sd::LaunchContext* context) { + auto result = new NDArray(); + *result = NDArrayFactory::create(dtype, context); + return result; } //////////////////////////////////////////////////////////////////////// template -NDArray NDArrayFactory::create(const std::vector &values, sd::LaunchContext * context) { - - std::shared_ptr buffer = std::make_shared(values.size() * sizeof(T), DataTypeUtils::fromT(), context->getWorkspace(), true); +NDArray NDArrayFactory::create(const std::vector& values, + sd::LaunchContext* context) { + std::shared_ptr buffer = std::make_shared( + values.size() * sizeof(T), DataTypeUtils::fromT(), + context->getWorkspace(), true); - NDArray res(buffer, ShapeDescriptor::vectorDescriptor(values.size(), DataTypeUtils::fromT()), context); + NDArray res(buffer, + ShapeDescriptor::vectorDescriptor(values.size(), + DataTypeUtils::fromT()), + context); - memcpyFromVector(res.buffer(), values); + memcpyFromVector(res.buffer(), values); - res.tickWriteHost(); - res.syncToDevice(); + res.tickWriteHost(); + res.syncToDevice(); - return res; + return res; } -template ND4J_EXPORT NDArray NDArrayFactory::create(const std::vector &values, sd::LaunchContext * context); -template ND4J_EXPORT NDArray NDArrayFactory::create(const std::vector &values, sd::LaunchContext * context); -template ND4J_EXPORT NDArray NDArrayFactory::create(const std::vector &values, sd::LaunchContext * context); -template ND4J_EXPORT NDArray NDArrayFactory::create(const std::vector &values, sd::LaunchContext * context); -template ND4J_EXPORT NDArray NDArrayFactory::create(const std::vector &values, sd::LaunchContext * context); -template ND4J_EXPORT NDArray NDArrayFactory::create(const std::vector &values, sd::LaunchContext * context); -template ND4J_EXPORT NDArray NDArrayFactory::create(const std::vector &values, sd::LaunchContext * context); -template ND4J_EXPORT NDArray NDArrayFactory::create(const std::vector &values, sd::LaunchContext * context); -template ND4J_EXPORT NDArray NDArrayFactory::create(const std::vector &values, sd::LaunchContext * context); -template ND4J_EXPORT NDArray NDArrayFactory::create(const std::vector &values, sd::LaunchContext * context); -template ND4J_EXPORT NDArray NDArrayFactory::create(const std::vector &values, sd::LaunchContext * context); +template SD_EXPORT NDArray NDArrayFactory::create( + const std::vector& values, sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create( + const std::vector& values, sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create( + const std::vector& values, sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create( + const std::vector& values, sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create( + const std::vector& values, sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create( + const std::vector& values, sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create( + const std::vector& values, sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create( + const std::vector& values, sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create( + const std::vector& values, sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create( + const std::vector& values, sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create( + const std::vector& values, sd::LaunchContext* context); //////////////////////////////////////////////////////////////////////// - template - NDArray* NDArrayFactory::empty_(sd::LaunchContext * context) { - auto shapeInfo = ShapeBuilders::createScalarShapeInfo(DataTypeUtils::fromT(), context->getWorkspace()); - ArrayOptions::setPropertyBit(shapeInfo, ARRAY_EMPTY); - auto result = new NDArray(nullptr, shapeInfo, context, false); +template +NDArray* NDArrayFactory::empty_(sd::LaunchContext* context) { + auto shapeInfo = ShapeBuilders::createScalarShapeInfo( + DataTypeUtils::fromT(), context->getWorkspace()); + ArrayOptions::setPropertyBit(shapeInfo, ARRAY_EMPTY); + auto result = new NDArray(nullptr, shapeInfo, context, false); - RELEASE(shapeInfo, context->getWorkspace()); + RELEASE(shapeInfo, context->getWorkspace()); - return result; - } - BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT NDArray* NDArrayFactory::empty_, (sd::LaunchContext * context), LIBND4J_TYPES); + return result; +} +BUILD_SINGLE_TEMPLATE(template SD_EXPORT NDArray* NDArrayFactory::empty_, + (sd::LaunchContext * context), LIBND4J_TYPES); - NDArray* NDArrayFactory::empty_(sd::DataType dataType, sd::LaunchContext * context) { - if (context == nullptr) - context = sd::LaunchContext ::defaultContext(); +NDArray* NDArrayFactory::empty_(sd::DataType dataType, + sd::LaunchContext* context) { + if (context == nullptr) context = sd::LaunchContext ::defaultContext(); - auto shapeInfo = ShapeBuilders::createScalarShapeInfo(dataType, context->getWorkspace()); - ArrayOptions::setPropertyBit(shapeInfo, ARRAY_EMPTY); - auto result = new NDArray(nullptr, shapeInfo, context, false); + auto shapeInfo = + ShapeBuilders::createScalarShapeInfo(dataType, context->getWorkspace()); + ArrayOptions::setPropertyBit(shapeInfo, ARRAY_EMPTY); + auto result = new NDArray(nullptr, shapeInfo, context, false); - RELEASE(shapeInfo, context->getWorkspace()); + RELEASE(shapeInfo, context->getWorkspace()); - return result; - } + return result; +} - //////////////////////////////////////////////////////////////////////// - template - NDArray NDArrayFactory::empty(sd::LaunchContext * context) { - return empty(DataTypeUtils::fromT(), context); - } - BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT NDArray NDArrayFactory::empty, (sd::LaunchContext * context), LIBND4J_TYPES); +//////////////////////////////////////////////////////////////////////// +template +NDArray NDArrayFactory::empty(sd::LaunchContext* context) { + return empty(DataTypeUtils::fromT(), context); +} +BUILD_SINGLE_TEMPLATE(template SD_EXPORT NDArray NDArrayFactory::empty, + (sd::LaunchContext * context), LIBND4J_TYPES); - //////////////////////////////////////////////////////////////////////// - NDArray NDArrayFactory::empty(sd::DataType dataType, sd::LaunchContext * context) { - auto shapeInfo = ShapeBuilders::createScalarShapeInfo(dataType, context->getWorkspace()); - ArrayOptions::setPropertyBit(shapeInfo, ARRAY_EMPTY); - NDArray result(nullptr, shapeInfo, context, false); +//////////////////////////////////////////////////////////////////////// +NDArray NDArrayFactory::empty(sd::DataType dataType, + sd::LaunchContext* context) { + auto shapeInfo = + ShapeBuilders::createScalarShapeInfo(dataType, context->getWorkspace()); + ArrayOptions::setPropertyBit(shapeInfo, ARRAY_EMPTY); + NDArray result(nullptr, shapeInfo, context, false); - RELEASE(shapeInfo, context->getWorkspace()); + RELEASE(shapeInfo, context->getWorkspace()); - return result; - } + return result; +} //////////////////////////////////////////////////////////////////////// - NDArray* NDArrayFactory::valueOf(const std::vector& shape, const NDArray& value, const char order, sd::LaunchContext * context) { - auto res = NDArrayFactory::create_(order, shape, value.dataType(), context); - res->assign(const_cast(value)); - return res; - } +NDArray* NDArrayFactory::valueOf(const std::vector& shape, + const NDArray& value, const char order, + sd::LaunchContext* context) { + auto res = NDArrayFactory::create_(order, shape, value.dataType(), context); + res->assign(const_cast(value)); + return res; +} //////////////////////////////////////////////////////////////////////// - NDArray* NDArrayFactory::create_( const char order, const std::vector &shape, sd::DataType dataType, sd::LaunchContext * context) { - - return new NDArray(order, shape, dataType, context); - } +NDArray* NDArrayFactory::create_(const char order, + const std::vector& shape, + sd::DataType dataType, + sd::LaunchContext* context) { + return new NDArray(order, shape, dataType, context); +} //////////////////////////////////////////////////////////////////////// template -NDArray NDArrayFactory::create(T* buffer, const char order, const std::initializer_list& shape, sd::LaunchContext * context) { - - if ((int) shape.size() > MAX_RANK) - throw std::invalid_argument("NDArrayFactory::create: Rank of NDArray can't exceed 32"); - - std::vector shp(shape); - ShapeDescriptor descriptor(DataTypeUtils::fromT(), order, shp); - - std::shared_ptr pBuffer = std::make_shared(buffer, descriptor.arrLength() * sizeof(T), descriptor.dataType(), false, context->getWorkspace()); - - NDArray result(pBuffer, descriptor, context); - - return result; -} - -template ND4J_EXPORT NDArray NDArrayFactory::create(double* buffer, const char order, const std::initializer_list& shape, sd::LaunchContext * context); -template ND4J_EXPORT NDArray NDArrayFactory::create(float* buffer, const char order, const std::initializer_list& shape, sd::LaunchContext * context); -template ND4J_EXPORT NDArray NDArrayFactory::create(float16* buffer, const char order, const std::initializer_list& shape, sd::LaunchContext * context); -template ND4J_EXPORT NDArray NDArrayFactory::create(bfloat16* buffer, const char order, const std::initializer_list& shape, sd::LaunchContext * context); -template ND4J_EXPORT NDArray NDArrayFactory::create(Nd4jLong * buffer, const char order, const std::initializer_list& shape, sd::LaunchContext * context); -template ND4J_EXPORT NDArray NDArrayFactory::create(int* buffer, const char order, const std::initializer_list& shape, sd::LaunchContext * context); -template ND4J_EXPORT NDArray NDArrayFactory::create(bool* buffer, const char order, const std::initializer_list& shape, sd::LaunchContext * context); -template ND4J_EXPORT NDArray NDArrayFactory::create(uint8_t * buffer, const char order, const std::initializer_list& shape, sd::LaunchContext * context); -template ND4J_EXPORT NDArray NDArrayFactory::create(int8_t* buffer, const char order, const std::initializer_list& shape, sd::LaunchContext * context); -template ND4J_EXPORT NDArray NDArrayFactory::create(int16_t* buffer, const char order, const std::initializer_list& shape, sd::LaunchContext * context); - - ///////////////////////////////////////////////////////////////////////////////////// - NDArray NDArrayFactory::string(const char16_t* u16string, sd::DataType dtype, sd::LaunchContext* context) { - return NDArray(u16string, dtype, context); - } - ///////////////////////////////////////////////////////////////////////// - NDArray* NDArrayFactory::string_(const char16_t* u16string, sd::DataType dtype, sd::LaunchContext* context) { - return string_(std::u16string(u16string), dtype, context); - } - ///////////////////////////////////////////////////////////////////////// - NDArray* NDArrayFactory::string_(const std::u16string& u16string, sd::DataType dtype, sd::LaunchContext* context) { - auto res = new NDArray(); - *res = NDArray(u16string, dtype, context); - return res; - } - ///////////////////////////////////////////////////////////////////////// - NDArray NDArrayFactory::string(const std::u16string& u16string, sd::DataType dtype, sd::LaunchContext* context) { - return NDArray(u16string, dtype, context); - } - ///////////////////////////////////////////////////////////////////////// - NDArray NDArrayFactory::string(const char32_t* u32string, sd::DataType dtype, sd::LaunchContext* context) { - return NDArray(u32string, dtype, context); - } - ///////////////////////////////////////////////////////////////////////// - NDArray* NDArrayFactory::string_(const char32_t* u32string, sd::DataType dtype, sd::LaunchContext* context) { - return string_(std::u32string(u32string), dtype, context); - } - ///////////////////////////////////////////////////////////////////////// - NDArray* NDArrayFactory::string_(const std::u32string& u32string, sd::DataType dtype, sd::LaunchContext* context) { - auto res = new NDArray(); - *res = NDArray(u32string, dtype, context); - return res; - } - ///////////////////////////////////////////////////////////////////////// - NDArray NDArrayFactory::string(const std::u32string& u32string, sd::DataType dtype, sd::LaunchContext* context) { - return NDArray(u32string, dtype, context); - } - ///////////////////////////////////////////////////////////////////////// - NDArray NDArrayFactory::string(const char* str, sd::DataType dtype, sd::LaunchContext* context) { - return NDArray(str, dtype, context); - } - ///////////////////////////////////////////////////////////////////////// - NDArray* NDArrayFactory::string_(const char* str, sd::DataType dtype, sd::LaunchContext* context) { - return string_(std::string(str), dtype, context); - } - ///////////////////////////////////////////////////////////////////////// - NDArray* NDArrayFactory::string_(const std::string& str, sd::DataType dtype, sd::LaunchContext* context) { - auto res = new NDArray(); - *res = NDArray(str, dtype, context); - return res; - } - ///////////////////////////////////////////////////////////////////////// - NDArray NDArrayFactory::string(const std::string& str, sd::DataType dtype, sd::LaunchContext* context) { - return NDArray(str, dtype, context); - } - ///////////////////////////////////////////////////////////////////////// - NDArray NDArrayFactory::string(const std::vector &shape, const std::initializer_list &strings, sd::DataType dataType, sd::LaunchContext * context) { - return NDArray(shape, std::vector(strings), dataType, context); - } - ///////////////////////////////////////////////////////////////////////// - NDArray NDArrayFactory::string( const std::vector &shape, const std::vector &strings, sd::DataType dataType, sd::LaunchContext * context) { - return NDArray( shape, strings, dataType, context); - } - ///////////////////////////////////////////////////////////////////////// - NDArray NDArrayFactory::string( const std::vector &shape, const std::initializer_list &string, sd::DataType dataType, sd::LaunchContext * context) { - return NDArray( shape, std::vector(string), dataType, context); - } - ///////////////////////////////////////////////////////////////////////// - NDArray* NDArrayFactory::string_( const std::vector &shape, const std::initializer_list &strings, sd::DataType dataType, sd::LaunchContext * context) { - return NDArrayFactory::string_( shape, std::vector(strings), dataType, context); - } - ///////////////////////////////////////////////////////////////////////// - NDArray* NDArrayFactory::string_( const std::vector &shape, const std::vector &strings, sd::DataType dataType, sd::LaunchContext * context) { - std::vector vec(strings.size()); - int cnt = 0; - for (auto s:strings) - vec[cnt++] = std::string(s); - - return NDArrayFactory::string_( shape, vec, dataType, context); - } - ///////////////////////////////////////////////////////////////////////// - NDArray* NDArrayFactory::string_( const std::vector &shape, const std::initializer_list &string, sd::DataType dataType, sd::LaunchContext * context) { - return NDArrayFactory::string_( shape, std::vector(string), dataType, context); - } - ///////////////////////////////////////////////////////////////////////// - NDArray NDArrayFactory::string( const std::vector &shape, const std::vector &string, sd::DataType dataType, sd::LaunchContext * context) { - return NDArray(shape, string, dataType, context); - } - ///////////////////////////////////////////////////////////////////////// - NDArray* NDArrayFactory::string_(const std::vector &shape, const std::vector &string, sd::DataType dataType, sd::LaunchContext * context) { - auto res = new NDArray(); - *res = NDArray( shape, string, dataType, context); - return res; - } - ///////////////////////////////////////////////////////////////////////// - NDArray NDArrayFactory::string(const std::vector& shape, const std::initializer_list& strings, sd::DataType dataType, sd::LaunchContext* context) { - return NDArray( shape, std::vector(strings), dataType, context); - } - ///////////////////////////////////////////////////////////////////////// - NDArray NDArrayFactory::string( const std::vector& shape, const std::vector& strings, sd::DataType dataType, sd::LaunchContext* context) { - return NDArray( shape, strings, dataType, context); - } - ///////////////////////////////////////////////////////////////////////// - NDArray NDArrayFactory::string( const std::vector& shape, const std::initializer_list& string, sd::DataType dataType, sd::LaunchContext* context) { - return NDArray( shape, std::vector(string), dataType, context); - } - ///////////////////////////////////////////////////////////////////////// - NDArray* NDArrayFactory::string_( const std::vector& shape, const std::initializer_list& strings, sd::DataType dataType, sd::LaunchContext* context) { - return NDArrayFactory::string_( shape, std::vector(strings), dataType, context); - } - ///////////////////////////////////////////////////////////////////////// - NDArray* NDArrayFactory::string_( const std::vector& shape, const std::vector& strings, sd::DataType dataType, sd::LaunchContext* context) { - std::vector vec(strings.size()); - int cnt = 0; - for (auto s : strings) - vec[cnt++] = std::u16string(s); - - return NDArrayFactory::string_( shape, vec, dataType, context); - } - ///////////////////////////////////////////////////////////////////////// - NDArray* NDArrayFactory::string_( const std::vector& shape, const std::initializer_list& string, sd::DataType dataType, sd::LaunchContext* context) { - return NDArrayFactory::string_( shape, std::vector(string), dataType, context); - } - ///////////////////////////////////////////////////////////////////////// - NDArray* NDArrayFactory::string_( const std::vector& shape, const std::vector& string, sd::DataType dataType, sd::LaunchContext* context) { - auto res = new NDArray(); - *res = NDArray( shape, string, dataType, context); - return res; - } - ///////////////////////////////////////////////////////////////////////// - NDArray NDArrayFactory::string( const std::vector& shape, const std::vector& string, sd::DataType dtype, sd::LaunchContext* context) { - return NDArray( shape, string, dtype, context); - } - ///////////////////////////////////////////////////////////////////////// - NDArray NDArrayFactory::string( const std::vector& shape, const std::initializer_list& strings, sd::DataType dataType, sd::LaunchContext* context) { - return NDArray( shape, std::vector(strings), dataType, context); - } - ///////////////////////////////////////////////////////////////////////// - NDArray NDArrayFactory::string( const std::vector& shape, const std::vector& strings, sd::DataType dataType, sd::LaunchContext* context) { - return NDArray( shape, strings, dataType, context); - } - ///////////////////////////////////////////////////////////////////////// - NDArray NDArrayFactory::string( const std::vector& shape, const std::initializer_list& string, sd::DataType dataType, sd::LaunchContext* context) { - return NDArray(shape, std::vector(string), dataType, context); - } - ///////////////////////////////////////////////////////////////////////// - NDArray* NDArrayFactory::string_( const std::vector& shape, const std::initializer_list& strings, sd::DataType dataType, sd::LaunchContext* context) { - return NDArrayFactory::string_( shape, std::vector(strings), dataType, context); - } - ///////////////////////////////////////////////////////////////////////// - NDArray* NDArrayFactory::string_( const std::vector& shape, const std::vector& strings, sd::DataType dataType, sd::LaunchContext* context) { - std::vector vec(strings.size()); - int cnt = 0; - for (auto s : strings) - vec[cnt++] = std::u32string(s); - return NDArrayFactory::string_( shape, vec, dataType, context); - } - ///////////////////////////////////////////////////////////////////////// - NDArray* NDArrayFactory::string_( const std::vector& shape, const std::initializer_list& string, sd::DataType dataType, sd::LaunchContext* context) { - return NDArrayFactory::string_( shape, std::vector(string), dataType, context); - } - ///////////////////////////////////////////////////////////////////////// - NDArray* NDArrayFactory::string_( const std::vector& shape, const std::vector& string, sd::DataType dataType, sd::LaunchContext* context) { - auto res = new NDArray(); - *res = NDArray( shape, string, dataType, context); - return res; - } - ///////////////////////////////////////////////////////////////////////// - NDArray NDArrayFactory::string(const std::vector& shape, const std::vector& string, sd::DataType dtype, sd::LaunchContext* context) { - return NDArray( shape, string, dtype, context); - } - - - NDArray NDArrayFactory::fromNpyFile(const char *fileName) { - auto size = sd::graph::getFileSize(fileName); - if (size < 0) - throw std::runtime_error("File doesn't exit"); - - auto pNPY = reinterpret_cast(::numpyFromFile(std::string(fileName))); - - auto nBuffer = reinterpret_cast(::dataPointForNumpy(pNPY)); - auto shape = reinterpret_cast(::shapeBufferForNumpy(pNPY)); - - auto length = shape::length(shape); - int8_t *buffer = nullptr; - sd::memory::Workspace *workspace = nullptr; - auto byteLen = length * DataTypeUtils::sizeOfElement(ArrayOptions::dataType(shape)); - - ALLOCATE(buffer, workspace, byteLen, int8_t); - memcpy(buffer, nBuffer, byteLen); - - free(pNPY); - - return NDArray(buffer, shape, LaunchContext::defaultContext(), true); - } +NDArray NDArrayFactory::create(T* buffer, const char order, + const std::initializer_list& shape, + sd::LaunchContext* context) { + if ((int)shape.size() > MAX_RANK) + throw std::invalid_argument( + "NDArrayFactory::create: Rank of NDArray can't exceed 32"); + + std::vector shp(shape); + ShapeDescriptor descriptor(DataTypeUtils::fromT(), order, shp); + + std::shared_ptr pBuffer = std::make_shared( + buffer, descriptor.arrLength() * sizeof(T), descriptor.dataType(), false, + context->getWorkspace()); + + NDArray result(pBuffer, descriptor, context); + + return result; +} + +template SD_EXPORT NDArray NDArrayFactory::create( + double* buffer, const char order, + const std::initializer_list& shape, sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create( + float* buffer, const char order, + const std::initializer_list& shape, sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create( + float16* buffer, const char order, + const std::initializer_list& shape, sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create( + bfloat16* buffer, const char order, + const std::initializer_list& shape, sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create( + Nd4jLong* buffer, const char order, + const std::initializer_list& shape, sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create( + int* buffer, const char order, const std::initializer_list& shape, + sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create( + bool* buffer, const char order, + const std::initializer_list& shape, sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create( + uint8_t* buffer, const char order, + const std::initializer_list& shape, sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create( + int8_t* buffer, const char order, + const std::initializer_list& shape, sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create( + int16_t* buffer, const char order, + const std::initializer_list& shape, sd::LaunchContext* context); + +///////////////////////////////////////////////////////////////////////////////////// +NDArray NDArrayFactory::string(const char16_t* u16string, sd::DataType dtype, + sd::LaunchContext* context) { + return NDArray(u16string, dtype, context); +} +///////////////////////////////////////////////////////////////////////// +NDArray* NDArrayFactory::string_(const char16_t* u16string, sd::DataType dtype, + sd::LaunchContext* context) { + return string_(std::u16string(u16string), dtype, context); +} +///////////////////////////////////////////////////////////////////////// +NDArray* NDArrayFactory::string_(const std::u16string& u16string, + sd::DataType dtype, + sd::LaunchContext* context) { + auto res = new NDArray(); + *res = NDArray(u16string, dtype, context); + return res; +} +///////////////////////////////////////////////////////////////////////// +NDArray NDArrayFactory::string(const std::u16string& u16string, + sd::DataType dtype, sd::LaunchContext* context) { + return NDArray(u16string, dtype, context); +} +///////////////////////////////////////////////////////////////////////// +NDArray NDArrayFactory::string(const char32_t* u32string, sd::DataType dtype, + sd::LaunchContext* context) { + return NDArray(u32string, dtype, context); +} +///////////////////////////////////////////////////////////////////////// +NDArray* NDArrayFactory::string_(const char32_t* u32string, sd::DataType dtype, + sd::LaunchContext* context) { + return string_(std::u32string(u32string), dtype, context); +} +///////////////////////////////////////////////////////////////////////// +NDArray* NDArrayFactory::string_(const std::u32string& u32string, + sd::DataType dtype, + sd::LaunchContext* context) { + auto res = new NDArray(); + *res = NDArray(u32string, dtype, context); + return res; +} +///////////////////////////////////////////////////////////////////////// +NDArray NDArrayFactory::string(const std::u32string& u32string, + sd::DataType dtype, sd::LaunchContext* context) { + return NDArray(u32string, dtype, context); +} +///////////////////////////////////////////////////////////////////////// +NDArray NDArrayFactory::string(const char* str, sd::DataType dtype, + sd::LaunchContext* context) { + return NDArray(str, dtype, context); +} +///////////////////////////////////////////////////////////////////////// +NDArray* NDArrayFactory::string_(const char* str, sd::DataType dtype, + sd::LaunchContext* context) { + return string_(std::string(str), dtype, context); +} +///////////////////////////////////////////////////////////////////////// +NDArray* NDArrayFactory::string_(const std::string& str, sd::DataType dtype, + sd::LaunchContext* context) { + auto res = new NDArray(); + *res = NDArray(str, dtype, context); + return res; +} +///////////////////////////////////////////////////////////////////////// +NDArray NDArrayFactory::string(const std::string& str, sd::DataType dtype, + sd::LaunchContext* context) { + return NDArray(str, dtype, context); +} +///////////////////////////////////////////////////////////////////////// +NDArray NDArrayFactory::string( + const std::vector& shape, + const std::initializer_list& strings, sd::DataType dataType, + sd::LaunchContext* context) { + return NDArray(shape, std::vector(strings), dataType, context); +} +///////////////////////////////////////////////////////////////////////// +NDArray NDArrayFactory::string(const std::vector& shape, + const std::vector& strings, + sd::DataType dataType, + sd::LaunchContext* context) { + return NDArray(shape, strings, dataType, context); +} +///////////////////////////////////////////////////////////////////////// +NDArray NDArrayFactory::string(const std::vector& shape, + const std::initializer_list& string, + sd::DataType dataType, + sd::LaunchContext* context) { + return NDArray(shape, std::vector(string), dataType, context); +} +///////////////////////////////////////////////////////////////////////// +NDArray* NDArrayFactory::string_( + const std::vector& shape, + const std::initializer_list& strings, sd::DataType dataType, + sd::LaunchContext* context) { + return NDArrayFactory::string_(shape, std::vector(strings), + dataType, context); +} +///////////////////////////////////////////////////////////////////////// +NDArray* NDArrayFactory::string_(const std::vector& shape, + const std::vector& strings, + sd::DataType dataType, + sd::LaunchContext* context) { + std::vector vec(strings.size()); + int cnt = 0; + for (auto s : strings) vec[cnt++] = std::string(s); + + return NDArrayFactory::string_(shape, vec, dataType, context); +} +///////////////////////////////////////////////////////////////////////// +NDArray* NDArrayFactory::string_( + const std::vector& shape, + const std::initializer_list& string, sd::DataType dataType, + sd::LaunchContext* context) { + return NDArrayFactory::string_(shape, std::vector(string), + dataType, context); +} +///////////////////////////////////////////////////////////////////////// +NDArray NDArrayFactory::string(const std::vector& shape, + const std::vector& string, + sd::DataType dataType, + sd::LaunchContext* context) { + return NDArray(shape, string, dataType, context); +} +///////////////////////////////////////////////////////////////////////// +NDArray* NDArrayFactory::string_(const std::vector& shape, + const std::vector& string, + sd::DataType dataType, + sd::LaunchContext* context) { + auto res = new NDArray(); + *res = NDArray(shape, string, dataType, context); + return res; +} +///////////////////////////////////////////////////////////////////////// +NDArray NDArrayFactory::string( + const std::vector& shape, + const std::initializer_list& strings, + sd::DataType dataType, sd::LaunchContext* context) { + return NDArray(shape, std::vector(strings), dataType, + context); +} +///////////////////////////////////////////////////////////////////////// +NDArray NDArrayFactory::string(const std::vector& shape, + const std::vector& strings, + sd::DataType dataType, + sd::LaunchContext* context) { + return NDArray(shape, strings, dataType, context); +} +///////////////////////////////////////////////////////////////////////// +NDArray NDArrayFactory::string( + const std::vector& shape, + const std::initializer_list& string, sd::DataType dataType, + sd::LaunchContext* context) { + return NDArray(shape, std::vector(string), dataType, context); +} +///////////////////////////////////////////////////////////////////////// +NDArray* NDArrayFactory::string_( + const std::vector& shape, + const std::initializer_list& strings, + sd::DataType dataType, sd::LaunchContext* context) { + return NDArrayFactory::string_(shape, std::vector(strings), + dataType, context); +} +///////////////////////////////////////////////////////////////////////// +NDArray* NDArrayFactory::string_(const std::vector& shape, + const std::vector& strings, + sd::DataType dataType, + sd::LaunchContext* context) { + std::vector vec(strings.size()); + int cnt = 0; + for (auto s : strings) vec[cnt++] = std::u16string(s); + + return NDArrayFactory::string_(shape, vec, dataType, context); +} +///////////////////////////////////////////////////////////////////////// +NDArray* NDArrayFactory::string_( + const std::vector& shape, + const std::initializer_list& string, sd::DataType dataType, + sd::LaunchContext* context) { + return NDArrayFactory::string_(shape, std::vector(string), + dataType, context); +} +///////////////////////////////////////////////////////////////////////// +NDArray* NDArrayFactory::string_(const std::vector& shape, + const std::vector& string, + sd::DataType dataType, + sd::LaunchContext* context) { + auto res = new NDArray(); + *res = NDArray(shape, string, dataType, context); + return res; +} +///////////////////////////////////////////////////////////////////////// +NDArray NDArrayFactory::string(const std::vector& shape, + const std::vector& string, + sd::DataType dtype, sd::LaunchContext* context) { + return NDArray(shape, string, dtype, context); +} +///////////////////////////////////////////////////////////////////////// +NDArray NDArrayFactory::string( + const std::vector& shape, + const std::initializer_list& strings, + sd::DataType dataType, sd::LaunchContext* context) { + return NDArray(shape, std::vector(strings), dataType, + context); +} +///////////////////////////////////////////////////////////////////////// +NDArray NDArrayFactory::string(const std::vector& shape, + const std::vector& strings, + sd::DataType dataType, + sd::LaunchContext* context) { + return NDArray(shape, strings, dataType, context); +} +///////////////////////////////////////////////////////////////////////// +NDArray NDArrayFactory::string( + const std::vector& shape, + const std::initializer_list& string, sd::DataType dataType, + sd::LaunchContext* context) { + return NDArray(shape, std::vector(string), dataType, context); +} +///////////////////////////////////////////////////////////////////////// +NDArray* NDArrayFactory::string_( + const std::vector& shape, + const std::initializer_list& strings, + sd::DataType dataType, sd::LaunchContext* context) { + return NDArrayFactory::string_(shape, std::vector(strings), + dataType, context); +} +///////////////////////////////////////////////////////////////////////// +NDArray* NDArrayFactory::string_(const std::vector& shape, + const std::vector& strings, + sd::DataType dataType, + sd::LaunchContext* context) { + std::vector vec(strings.size()); + int cnt = 0; + for (auto s : strings) vec[cnt++] = std::u32string(s); + return NDArrayFactory::string_(shape, vec, dataType, context); +} +///////////////////////////////////////////////////////////////////////// +NDArray* NDArrayFactory::string_( + const std::vector& shape, + const std::initializer_list& string, sd::DataType dataType, + sd::LaunchContext* context) { + return NDArrayFactory::string_(shape, std::vector(string), + dataType, context); +} +///////////////////////////////////////////////////////////////////////// +NDArray* NDArrayFactory::string_(const std::vector& shape, + const std::vector& string, + sd::DataType dataType, + sd::LaunchContext* context) { + auto res = new NDArray(); + *res = NDArray(shape, string, dataType, context); + return res; +} +///////////////////////////////////////////////////////////////////////// +NDArray NDArrayFactory::string(const std::vector& shape, + const std::vector& string, + sd::DataType dtype, sd::LaunchContext* context) { + return NDArray(shape, string, dtype, context); +} + +NDArray NDArrayFactory::fromNpyFile(const char* fileName) { + if (!FileUtils::fileExists(fileName)) + throw std::runtime_error("File doesn't exit"); + + auto pNPY = reinterpret_cast(::numpyFromFile(std::string(fileName))); + + auto nBuffer = reinterpret_cast(::dataPointForNumpy(pNPY)); + auto shape = reinterpret_cast(::shapeBufferForNumpy(pNPY)); + + auto length = shape::length(shape); + int8_t* buffer = nullptr; + sd::memory::Workspace* workspace = nullptr; + auto byteLen = + length * DataTypeUtils::sizeOfElement(ArrayOptions::dataType(shape)); + + ALLOCATE(buffer, workspace, byteLen, int8_t); + memcpy(buffer, nBuffer, byteLen); + + free(pNPY); + + return NDArray(buffer, shape, LaunchContext::defaultContext(), true); } +} // namespace sd diff --git a/libnd4j/include/array/impl/NDArrayList.cpp b/libnd4j/include/array/impl/NDArrayList.cpp index 1aa9d2d4b2b3..4f17f4479cb4 100644 --- a/libnd4j/include/array/impl/NDArrayList.cpp +++ b/libnd4j/include/array/impl/NDArrayList.cpp @@ -18,255 +18,262 @@ // @author raver119@gmail.com // - -#include #include #include #include -#include - -namespace sd { - NDArrayList::NDArrayList(int height, bool expandable) { - _expandable = expandable; - _elements.store(0); - _counter.store(0); - _id.first = 0; - _id.second = 0; - _height = height; - //nd4j_printf("\nCreating NDArrayList\n",""); - } - - NDArrayList::~NDArrayList() { - //nd4j_printf("\nDeleting NDArrayList: [%i]\n", _chunks.size()); - for (auto const& v : _chunks) - delete v.second; +#include - _chunks.clear(); - } - - NDArray* NDArrayList::read(int idx) { - return new NDArray(readRaw(idx)->dup()); - } - - sd::DataType NDArrayList::dataType() { - return _dtype; - } +#include - NDArray* NDArrayList::readRaw(int idx) { - if (_chunks.count(idx) < 1) { - nd4j_printf("Non-existent chunk requested: [%i]\n", idx); - throw std::invalid_argument("Bad index"); +namespace sd { +NDArrayList::InternalArrayList::InternalArrayList(int height, bool expandable) { + _expandable = expandable; + _elements.store(0); + _counter.store(0); + _id.first = 0; + _id.second = 0; + _height = height; +} + +NDArrayList::NDArrayList(const NDArrayList &other) { _state = other._state; } + +NDArrayList::NDArrayList(NDArrayList &&other) { + _state = std::move(other._state); +} + +NDArrayList &NDArrayList::operator=(const NDArrayList &other) noexcept { + if (this == &other) return *this; + + _state = other._state; + + return *this; +} + +NDArrayList &NDArrayList::operator=(NDArrayList &&other) noexcept { + if (this == &other) return *this; + + _state = std::move(other._state); + + return *this; +} + +NDArrayList::NDArrayList(int height, bool expandable) { + _state = std::make_shared(height, expandable); +} + +NDArrayList::~NDArrayList() {} + +NDArray NDArrayList::read(int idx) { return readRaw(idx); } + +sd::DataType NDArrayList::dataType() const { return _state->_dtype; } + +NDArray NDArrayList::readRaw(int idx) { + if (_state->_chunks.count(idx) < 1) { + nd4j_printf("Non-existent chunk requested: [%i]\n", idx); + throw std::invalid_argument("Bad index"); + } + + return _state->_chunks.at(idx); +} + +Nd4jStatus NDArrayList::write(int idx, const NDArray &array) { + if (_state->_chunks.count(idx) == 0) + _state->_elements++; + else { + _state->_chunks.erase(idx); + } + + // we store reference shape on first write + if (_state->_chunks.empty()) { + _state->_dtype = array.dataType(); + + if (_state->_shape.empty()) { + // adding leading 1 to shape + _state->_shape.emplace_back(1); + for (int e = 0; e < array.rankOf(); e++) + _state->_shape.emplace_back(array.sizeAt(e)); + } else { + // if shape is inferred (say, from split_list) + if (array.rankOf() == _state->_shape.size()) { + // skipping first dim + for (int e = 1; e < _state->_shape.size(); e++) { + if (_state->_shape[e] != array.sizeAt(e)) + return Status::CODE(ND4J_STATUS_BAD_INPUT, + "NDArrayList: all arrays must have same size " + "along inner dimensions"); } - - return _chunks[idx]; + } else if (array.rankOf() == _state->_shape.size() - 1) { + // case like 2d _shape, and 1D rows + for (int e = 1; e < _state->_shape.size(); e++) + if (_state->_shape[e] != array.sizeAt(e - 1)) + return Status::CODE(ND4J_STATUS_BAD_INPUT, + "NDArrayList: all arrays must have same size " + "along inner dimensions"); + } else + return Status::CODE(ND4J_STATUS_BAD_INPUT, + "NDArrayList: all arrays must have same size along " + "inner dimensions"); } - - Nd4jStatus NDArrayList::write(int idx, NDArray* array) { - if (_chunks.count(idx) == 0) - _elements++; - else { - delete _chunks[idx]; - } - - - // we store reference shape on first write - if (_chunks.empty()) { - _dtype = array->dataType(); - - if (_shape.empty()) { - //adding leading 1 to shape - _shape.emplace_back(1); - for (int e = 0; e < array->rankOf(); e++) - _shape.emplace_back(array->sizeAt(e)); - } else { - // if shape is inferred (say, from split_list) - if (array->rankOf() == _shape.size()) { - // skipping first dim - for (int e = 1; e < _shape.size(); e++) { - if (_shape[e] != array->sizeAt(e)) - return Status::CODE(ND4J_STATUS_BAD_INPUT, "NDArrayList: all arrays must have same size along inner dimensions"); - } - } else if (array->rankOf() == _shape.size() - 1) { - // case like 2d _shape, and 1D rows - for (int e = 1; e < _shape.size(); e++) - if (_shape[e] != array->sizeAt(e - 1)) - return Status::CODE(ND4J_STATUS_BAD_INPUT, "NDArrayList: all arrays must have same size along inner dimensions"); - } else - return Status::CODE(ND4J_STATUS_BAD_INPUT, "NDArrayList: all arrays must have same size along inner dimensions"); - - } + } else { + if (array.dataType() != _state->_dtype) + return Status::CODE(ND4J_STATUS_BAD_INPUT, + "NDArrayList: all arrays must have same data type"); + + // if shape is inferred (say, from split_list) + if (array.rankOf() == _state->_shape.size()) { + // skipping first dim + for (int e = 1; e < _state->_shape.size(); e++) { + if (_state->_shape[e] != array.sizeAt(e)) + return Status::CODE(ND4J_STATUS_BAD_INPUT, + "NDArrayList: all arrays must have same size " + "along inner dimensions"); + } + } else if (array.rankOf() == _state->_shape.size() - 1) { + // case like 2d _shape, and 1D rows + for (int e = 1; e < _state->_shape.size(); e++) + if (_state->_shape[e] != array.sizeAt(e - 1)) + return Status::CODE(ND4J_STATUS_BAD_INPUT, + "NDArrayList: all arrays must have same size " + "along inner dimensions"); + } else + return Status::CODE( + ND4J_STATUS_BAD_INPUT, + "NDArrayList: all arrays must have same size along inner dimensions"); + } + + // storing reference + _state->_chunks.insert({idx, array}); + + return Status::OK(); +} + +const std::vector &NDArrayList::shape() const { + return _state->_shape; +} + +int NDArrayList::counter() const { return _state->_counter++; } + +void NDArrayList::unstack(const NDArray &array, int axis) { + _state->_axis = axis; + std::vector args({axis}); + auto newAxis = ShapeUtils::evalDimsToExclude(array.rankOf(), args); + auto result = array.allTensorsAlongDimension(newAxis); + for (int e = 0; e < result.size(); e++) { + auto chunk = result.at(e); + write(e, chunk.dup(array.ordering())); + } +} + +void NDArrayList::setShape(const std::vector &shape) { + _state->_shape = shape; +} + +NDArray NDArrayList::stack() const { + // FIXME: this is bad for perf, but ok as poc + + int numElements = _state->_elements.load(); + std::vector inputs(numElements); + for (int e = 0; e < numElements; e++) { + _state->_chunks.at(e).syncToDevice(); + inputs[e] = &_state->_chunks.at(e); + } + + auto inShapeInfo = inputs[0]->shapeInfo(); + int rank = shape::rank(inShapeInfo); + NDArray array; + + if (shape::isEmpty(inShapeInfo)) { + switch (rank) { + case 0: { + if (numElements == 1) { + array = NDArray(inputs[0]->ordering(), {0}, + ArrayOptions::dataType(inShapeInfo), + inputs[0]->getContext()); } else { - if (array->dataType() != _dtype) - return Status::CODE(ND4J_STATUS_BAD_INPUT, "NDArrayList: all arrays must have same data type"); - - - // if shape is inferred (say, from split_list) - if (array->rankOf() == _shape.size()) { - // skipping first dim - for (int e = 1; e < _shape.size(); e++) { - if (_shape[e] != array->sizeAt(e)) - return Status::CODE(ND4J_STATUS_BAD_INPUT, "NDArrayList: all arrays must have same size along inner dimensions"); - } - } else if (array->rankOf() == _shape.size() - 1) { - // case like 2d _shape, and 1D rows - for (int e = 1; e < _shape.size(); e++) - if (_shape[e] != array->sizeAt(e - 1)) - return Status::CODE(ND4J_STATUS_BAD_INPUT, "NDArrayList: all arrays must have same size along inner dimensions"); - } else - return Status::CODE(ND4J_STATUS_BAD_INPUT, "NDArrayList: all arrays must have same size along inner dimensions"); - } - - //_elements++; - - // storing reference - _chunks[idx] = array; - - return Status::OK(); - } - - std::vector& NDArrayList::shape() { - return _shape; - } - - int NDArrayList::counter() { - return _counter++; - } - - void NDArrayList::unstack(NDArray* array, int axis) { - _axis = axis; - std::vector args({axis}); - auto newAxis = ShapeUtils::evalDimsToExclude(array->rankOf(), args); - auto result = array->allTensorsAlongDimension(newAxis); - for (int e = 0; e < result.size(); e++) { - auto chunk = result.at(e);//->dup(array->ordering()); - write(e, new NDArray(chunk->dup(array->ordering()))); + array = NDArray('c', {(Nd4jLong)numElements, 0}, + ArrayOptions::dataType(inShapeInfo), + inputs[0]->getContext()); } + } } + } else { + std::vector outShape(inShapeInfo + 1, inShapeInfo + 1 + rank); + outShape.insert(outShape.begin(), (Nd4jLong)numElements); + array = + NDArray(shape::order(inShapeInfo), outShape, + ArrayOptions::dataType(inShapeInfo), inputs[0]->getContext()); + } - NDArray* NDArrayList::stack() { - // FIXME: this is bad for perf, but ok as poc - - int numElements = _elements.load(); - std::vector inputs(numElements); - for (int e = 0; e < numElements; e++) { - _chunks[e]->syncToDevice(); - inputs[e] = _chunks[e]; - } - - auto inShapeInfo = inputs[0]->shapeInfo(); - int rank = shape::rank(inShapeInfo); - NDArray* array = nullptr; - - if (shape::isEmpty(inShapeInfo)) { - switch (rank) { - case 0: { - if (numElements == 1) { - array = new NDArray(inputs[0]->ordering(), {0}, ArrayOptions::dataType(inShapeInfo), inputs[0]->getContext()); - } else { - array = new NDArray('c', {(Nd4jLong) numElements, 0}, ArrayOptions::dataType(inShapeInfo), inputs[0]->getContext() ) ; - } - } - } - } - else{ - std::vector outShape(inShapeInfo + 1, inShapeInfo + 1 + rank); - outShape.insert(outShape.begin(), (Nd4jLong) numElements); - array = new NDArray( shape::order(inShapeInfo), outShape, ArrayOptions::dataType(inShapeInfo), inputs[0]->getContext()); - } - - ops::helpers::stack(inputs[0]->getContext(), inputs, *array, 0); - - return array; - } - - std::pair& NDArrayList::id() { - return _id; - } + ops::helpers::stack(inputs[0]->getContext(), inputs, array, 0); - std::string& NDArrayList::name() { - return _name; - } + return array; +} - sd::LaunchContext * NDArrayList::context() { - return _context; - } +const std::pair &NDArrayList::id() const { return _state->_id; } - int NDArrayList::elements() { - return _elements.load(); - } +const std::string &NDArrayList::name() const { return _state->_name; } - int NDArrayList::height() { - //if (_height != 0) - // return _height; - //else - return (int) _chunks.size(); - } +int NDArrayList::elements() const { return (int)_state->_chunks.size(); } - bool NDArrayList::isWritten(int index) { - if (_chunks.count(index) > 0) - return true; - else - return false; - } +int NDArrayList::height() const { return (int)_state->_chunks.size(); } - NDArray* NDArrayList::pick(std::initializer_list indices) { - std::vector idcs(indices); - return pick(idcs); - } +bool NDArrayList::isWritten(int index) const { + if (_state->_chunks.count(index) > 0) + return true; + else + return false; +} - NDArray* NDArrayList::pick(std::vector &indices) { - std::vector shape(_shape); +NDArray NDArrayList::pick(const std::vector &indices) { + std::vector shape(_state->_shape); - //shape.insert(shape.begin() + _axis, indices.size()); - shape[_axis] = indices.size(); - // do we have to enforce C order here? - auto array = new NDArray('c', shape, _chunks[0]->dataType(), _context); - std::vector axis = ShapeUtils::evalDimsToExclude(shape.size(), {_axis}); - auto tads = array->allTensorsAlongDimension(axis); - int indicesSize = indices.size(); + // shape.insert(shape.begin() + _axis, indices.size()); + shape[_state->_axis] = indices.size(); + // do we have to enforce C order here? + NDArray array('c', shape, _state->_chunks.at(0).dataType()); + std::vector axis = + ShapeUtils::evalDimsToExclude(shape.size(), {_state->_axis}); + auto tads = array.allTensorsAlongDimension(axis); + int indicesSize = indices.size(); - if (tads.size() != indicesSize) - throw std::runtime_error("Number of TADs should match number of indices"); + if (tads.size() != indicesSize) + throw std::runtime_error("Number of TADs should match number of indices"); - for (int e = 0; e < indicesSize; e++) - tads.at(e)->assign(_chunks[indices[e]]); + for (int e = 0; e < indicesSize; e++) + tads.at(e).assign(_state->_chunks.at(indices[e])); - return array; - } + return array; +} - NDArrayList* NDArrayList::clone() { - auto list = new NDArrayList(_height, _expandable); - list->_axis = _axis; - list->_id.first = _id.first; - list->_id.second = _id.second; - list->_name = _name; - list->_elements.store(_elements.load()); +NDArrayList NDArrayList::clone() { + NDArrayList list(_state->_height, _state->_expandable); + list._state->_axis = _state->_axis; + list._state->_id.first = _state->_id.first; + list._state->_id.second = _state->_id.second; + list._state->_name = _state->_name; + list._state->_elements.store(_state->_elements.load()); - for (auto const& v : _chunks) { - list->_chunks[v.first] = new NDArray(v.second->dup()); - } + for (auto const &v : _state->_chunks) { + list._state->_chunks.insert({v.first, v.second.dup()}); + } - return list; - } + return list; +} - bool NDArrayList::equals(NDArrayList& other) { - if (_axis != other._axis) - return false; +bool NDArrayList::equals(NDArrayList &other) { + if (_state->_axis != other._state->_axis) return false; - if (_chunks.size() != other._chunks.size()) - return false; + if (_state->_chunks.size() != other._state->_chunks.size()) return false; - for (auto const& v : _chunks) { - if (other._chunks.count(v.first) == 0) - return false; + for (auto const &v : _state->_chunks) { + if (other._state->_chunks.count(v.first) == 0) return false; - auto arrThis = _chunks[v.first]; - auto arrThat = other._chunks[v.first]; + auto arrThis = _state->_chunks.at(v.first); + auto arrThat = other._state->_chunks.at(v.first); - if (!arrThis->equalsTo(arrThat)) - return false; - } + if (!arrThis.equalsTo(arrThat)) return false; + } - return true; - } -} \ No newline at end of file + return true; +} +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/array/impl/ResultSet.cpp b/libnd4j/include/array/impl/ResultSet.cpp index d9d824d4689e..41c218954d04 100644 --- a/libnd4j/include/array/impl/ResultSet.cpp +++ b/libnd4j/include/array/impl/ResultSet.cpp @@ -22,135 +22,115 @@ #include namespace sd { - ResultSet::ResultSet() { - // - } - - ResultSet::ResultSet(const sd::graph::FlatResult* result) { - for (int e = 0; e < result->variables()->size(); e++) { - auto var = result->variables()->Get(e); +ResultSet::ResultSet() { + // +} - NDArray* array; +ResultSet::ResultSet(const sd::graph::FlatResult* result) { + for (int e = 0; e < result->variables()->size(); e++) { + auto var = result->variables()->Get(e); - if (var->ndarray() != nullptr) { - array = sd::graph::FlatUtils::fromFlatArray(var->ndarray()); - } else if (var->shape() != nullptr) { - std::vector shapeInfo; - for (int i = 0; i < var->shape()->size(); i++) { - shapeInfo.emplace_back(var->shape()->Get(i)); - } + NDArray array; - // we just create empty array here - int s0 = shapeInfo.at(0); + if (var->ndarray() != nullptr) { + array = sd::graph::FlatUtils::fromFlatArray(var->ndarray()); + } else if (var->shape() != nullptr) { + std::vector shapeInfo; + for (int i = 0; i < var->shape()->size(); i++) { + shapeInfo.emplace_back(var->shape()->Get(i)); + } - std::vector shape; - for (int i = 0; i < s0; i++) { - shape.emplace_back(shapeInfo.at(i + 1)); - } + // we just create empty array here + int s0 = shapeInfo.at(0); - array = new NDArray((char) shapeInfo.at(shapeInfo.size() - 1), shape, DataTypeUtils::fromFlatDataType(var->dtype())); - } else { - nd4j_printf("Either shape or NDArray should be defined in FlatResult variable\n",""); - throw std::runtime_error("Empty variable"); - } + std::vector shape; + for (int i = 0; i < s0; i++) { + shape.emplace_back(shapeInfo.at(i + 1)); + } - _content.push_back(array); - } + array = NDArray((char)shapeInfo.at(shapeInfo.size() - 1), shape, + DataTypeUtils::fromFlatDataType(var->dtype())); + } else { + nd4j_printf( + "Either shape or NDArray should be defined in FlatResult variable\n", + ""); + throw std::runtime_error("Empty variable"); } - ResultSet::ResultSet(const ResultSet& other) noexcept{ - for (const auto v:other._content) - _content.emplace_back(v); + _content.push_back(array); + } +} - _status = other._status; - _removable = false; - } +ResultSet::ResultSet(const ResultSet& other) noexcept { + for (const auto v : other._content) _content.emplace_back(v); - //////////////////////////////////////////////////////////////////////// - // move constructor - ResultSet::ResultSet(ResultSet&& other) noexcept { + _status = other._status; + _removable = false; +} - _content = std::move(other._content); - _status = other._status; - _removable = other._removable; - other._removable = false; - } +//////////////////////////////////////////////////////////////////////// +// move constructor +ResultSet::ResultSet(ResultSet&& other) noexcept { + _content = std::move(other._content); + _status = other._status; + _removable = other._removable; + other._removable = false; +} - //////////////////////////////////////////////////////////////////////// +//////////////////////////////////////////////////////////////////////// // move assignment operator - ResultSet& ResultSet::operator=(ResultSet&& other) noexcept { +ResultSet& ResultSet::operator=(ResultSet&& other) noexcept { + if (this == &other) return *this; - if (this == &other) - return *this; + delContent(); - delContent(); + _content = std::move(other._content); - _content = std::move(other._content); + _status = other._status; + _removable = other._removable; + other._removable = false; - _status = other._status; - _removable = other._removable; - other._removable = false; - - return *this; - } - - ResultSet& ResultSet::operator=(const ResultSet& other) noexcept { - - if (this == &other) - return *this; - - delContent(); + return *this; +} - for (const auto v : other._content) - _content.push_back(v); +ResultSet& ResultSet::operator=(const ResultSet& other) noexcept { + if (this == &other) return *this; - _status = other._status; - _removable = false; + delContent(); - return *this; - } + for (const auto v : other._content) _content.push_back(v); - void ResultSet::delContent() { - if (_removable) - for (auto v : _content) - delete v; - } + _status = other._status; + _removable = false; - ResultSet::~ResultSet() { + return *this; +} - delContent(); - } +void ResultSet::delContent() { + // +} - void ResultSet::setNonRemovable() { - _removable = false; - } +ResultSet::~ResultSet() { delContent(); } - int ResultSet::size() { - return (int) _content.size(); - } +void ResultSet::setNonRemovable() { _removable = false; } - sd::NDArray* ResultSet::at(const unsigned long idx) const { - return _content.at(idx); - } +int ResultSet::size() { return (int)_content.size(); } - sd::NDArray* ResultSet::operator[](const unsigned long idx) const { - return _content[idx]; - } +sd::NDArray& ResultSet::at(const unsigned long idx) const { + return const_cast(_content[idx]); +} - void ResultSet::push_back(sd::NDArray *array) { - _content.emplace_back(array); - } +sd::NDArray& ResultSet::operator[](const unsigned long idx) const { + return const_cast(_content[idx]); +} - Nd4jStatus ResultSet::status() { - return _status; - } +void ResultSet::push_back(const sd::NDArray& array) { + _content.emplace_back(array); +} - void ResultSet::setStatus(Nd4jStatus status) { - _status = status; - } +Nd4jStatus ResultSet::status() { return _status; } - void ResultSet::purge() { - _content.clear(); - } -} +void ResultSet::setStatus(Nd4jStatus status) { _status = status; } +void ResultSet::purge() { _content.clear(); } +} // namespace sd diff --git a/libnd4j/include/array/impl/ShapeDescriptor.cpp b/libnd4j/include/array/impl/ShapeDescriptor.cpp index 3ef096312d35..a097ce276a08 100644 --- a/libnd4j/include/array/impl/ShapeDescriptor.cpp +++ b/libnd4j/include/array/impl/ShapeDescriptor.cpp @@ -19,362 +19,337 @@ // #include -#include #include +#include namespace sd { ////////////////////////////////////////////////////////////////////////// // equal to operator - bool ShapeDescriptor::operator==(const ShapeDescriptor &other) const { - - if (_empty != other._empty) - return false; - if (_rank != other._rank) - return false; - if (_order != other._order) - return false; - if (_dataType != other._dataType) - return false; - if (_ews != other._ews) - return false; - - if (_shape != other._shape) - return false; - - if (_strides != other._strides) - return false; - - return true; - } +bool ShapeDescriptor::operator==(const ShapeDescriptor &other) const { + if (_empty != other._empty) return false; + if (_rank != other._rank) return false; + if (_order != other._order) return false; + if (_dataType != other._dataType) return false; + if (_ews != other._ews) return false; + + if (_shape != other._shape) return false; + + if (_strides != other._strides) return false; + + return true; +} ////////////////////////////////////////////////////////////////////////// // less than operator - bool ShapeDescriptor::operator<(const ShapeDescriptor &other) const { - return std::tie(_empty, _rank, _dataType, _ews, _order, _shape, _strides) < - std::tie(other._empty, other._rank, other._dataType, other._ews, other._order, other._shape, - other._strides); +bool ShapeDescriptor::operator<(const ShapeDescriptor &other) const { + return std::tie(_empty, _rank, _dataType, _ews, _order, _shape, _strides) < + std::tie(other._empty, other._rank, other._dataType, other._ews, + other._order, other._shape, other._strides); +} + +Nd4jLong *ShapeDescriptor::toShapeInfo() const { + if (_empty) { + if (_rank == 0) + return ShapeBuilders::emptyShapeInfo(_dataType); + else { + return ShapeBuilders::emptyShapeInfo(_dataType, _order, _shape); } + } - Nd4jLong *ShapeDescriptor::toShapeInfo() const { - if (_empty) { - if (_rank == 0) - return ShapeBuilders::emptyShapeInfo(_dataType); - else { - return ShapeBuilders::emptyShapeInfo(_dataType, _order, _shape); - } - } - - - switch (_rank) { - case 0: { - auto shapeInfo = ShapeBuilders::createScalarShapeInfo(_dataType); - shapeInfo[2] = _ews; - return shapeInfo; - } - case 1: { - auto shapeInfo = ShapeBuilders::createVectorShapeInfo(_dataType, _shape[0]); - shapeInfo[2 + _rank * 2] = _ews; - shapeInfo[2] = _strides[0]; - shapeInfo[2 + _rank * 2 + 1] = _order; - return shapeInfo; - } - default: { - auto shapeInfo = ShapeBuilders::createShapeInfo(_dataType, _order, _shape); - - for (int e = 0; e < _rank; e++) - shapeInfo[e + 1 + _rank] = _strides[e]; - - shapeInfo[2 + _rank * 2] = _ews; - - return shapeInfo; - } - } + switch (_rank) { + case 0: { + auto shapeInfo = ShapeBuilders::createScalarShapeInfo(_dataType); + shapeInfo[2] = _ews; + return shapeInfo; } + case 1: { + auto shapeInfo = + ShapeBuilders::createVectorShapeInfo(_dataType, _shape[0]); + shapeInfo[2 + _rank * 2] = _ews; + shapeInfo[2] = _strides[0]; + shapeInfo[2 + _rank * 2 + 1] = _order; + return shapeInfo; + } + default: { + auto shapeInfo = + ShapeBuilders::createShapeInfo(_dataType, _order, _shape); - ShapeDescriptor::ShapeDescriptor(const DataType type, const char order, const Nd4jLong *shape, const int rank) - : _dataType(type), _order(order), _rank(rank), _ews(1) { - _shape.resize(rank); - _strides.resize(rank); + for (int e = 0; e < _rank; e++) shapeInfo[e + 1 + _rank] = _strides[e]; - for (int e = 0; e < rank; e++) - _shape[e] = shape[e]; + shapeInfo[2 + _rank * 2] = _ews; - if (order == 'c') - shape::calcStrides(_shape.data(), _shape.size(), _strides.data()); - else - shape::calcStridesFortran(_shape.data(), _shape.size(), _strides.data()); + return shapeInfo; + } + } +} +ShapeDescriptor::ShapeDescriptor(const DataType type, const char order, + const Nd4jLong *shape, const int rank) + : _dataType(type), _order(order), _rank(rank), _ews(1) { + _shape.resize(rank); + _strides.resize(rank); - for (auto v:_shape) { - if (v == 0) { - _empty = true; - break; - } - } - } + for (int e = 0; e < rank; e++) _shape[e] = shape[e]; - ShapeDescriptor::ShapeDescriptor(const DataType type, const char order, const Nd4jLong *shape, - const Nd4jLong *strides, const int rank, Nd4jLong ews, const bool empty) { - _shape.resize(rank); - _strides.resize(rank); + if (order == 'c') + shape::calcStrides(_shape.data(), _shape.size(), _strides.data()); + else + shape::calcStridesFortran(_shape.data(), _shape.size(), _strides.data()); - _dataType = type; - _order = order; - _rank = rank; - _empty = empty; - _ews = ews; + for (auto v : _shape) { + if (v == 0) { + _empty = true; + break; + } + } +} + +ShapeDescriptor::ShapeDescriptor(const DataType type, const char order, + const Nd4jLong *shape, const Nd4jLong *strides, + const int rank, Nd4jLong ews, + const bool empty) { + _shape.resize(rank); + _strides.resize(rank); - for (int e = 0; e < rank; e++) - _shape[e] = shape[e]; + _dataType = type; + _order = order; + _rank = rank; + _empty = empty; + _ews = ews; - for (int e = 0; e < rank; e++) - _strides[e] = strides[e]; + for (int e = 0; e < rank; e++) _shape[e] = shape[e]; + for (int e = 0; e < rank; e++) _strides[e] = strides[e]; - for (auto v:_shape) { - if (v == 0) { - _empty = true; - break; - } - } + for (auto v : _shape) { + if (v == 0) { + _empty = true; + break; } + } +} ////////////////////////////////////////////////////////////////////////// - ShapeDescriptor::ShapeDescriptor(const DataType type, const char order, const std::vector &shape) - : _dataType(type), _order(order), _shape(shape) { - _rank = shape.size(); - _ews = 1; - - if (_rank > 0) { - _strides.resize(_rank); - - for (auto v:_shape) { - if (v == 0) { - _empty = true; - break; - } - } - - // no point calculating strides for empty arrays - if (!_empty) { - if (order == 'c') - shape::calcStrides(_shape.data(), shape.size(), _strides.data()); - else - shape::calcStridesFortran(_shape.data(), shape.size(), _strides.data()); - } else { - // all strides set to 0 - memset(_strides.data(), 0, sizeof(Nd4jLong) * shape.size()); - } - } +ShapeDescriptor::ShapeDescriptor(const DataType type, const char order, + const std::vector &shape) + : _dataType(type), _order(order), _shape(shape) { + _rank = shape.size(); + _ews = 1; + + if (_rank > 0) { + _strides.resize(_rank); + + for (auto v : _shape) { + if (v == 0) { + _empty = true; + break; + } } -////////////////////////////////////////////////////////////////////////// - ShapeDescriptor::ShapeDescriptor(const DataType type, const char order, - const std::initializer_list &shape) : _dataType(type), _order(order), - _shape(shape) { - _rank = shape.size(); - _ews = 1; - - _strides.resize(shape.size()); - if (order == 'c') - shape::calcStrides(_shape.data(), shape.size(), _strides.data()); - else - shape::calcStridesFortran(_shape.data(), shape.size(), _strides.data()); - - for (auto v:_shape) { - if (v == 0) { - _empty = true; - break; - } - } + // no point calculating strides for empty arrays + if (!_empty) { + if (order == 'c') + shape::calcStrides(_shape.data(), shape.size(), _strides.data()); + else + shape::calcStridesFortran(_shape.data(), shape.size(), _strides.data()); + } else { + // all strides set to 0 + memset(_strides.data(), 0, sizeof(Nd4jLong) * shape.size()); } + } +} ////////////////////////////////////////////////////////////////////////// - ShapeDescriptor::ShapeDescriptor(const DataType type, const char order, const std::vector &shape, - const std::vector &strides, const Nd4jLong ews) : ShapeDescriptor(type, - order, - shape, - strides) { - _ews = ews; +ShapeDescriptor::ShapeDescriptor(const DataType type, const char order, + const std::initializer_list &shape) + : _dataType(type), _order(order), _shape(shape) { + _rank = shape.size(); + _ews = 1; + + _strides.resize(shape.size()); + if (order == 'c') + shape::calcStrides(_shape.data(), shape.size(), _strides.data()); + else + shape::calcStridesFortran(_shape.data(), shape.size(), _strides.data()); + + for (auto v : _shape) { + if (v == 0) { + _empty = true; + break; } + } +} - ShapeDescriptor::ShapeDescriptor(const DataType type, const Nd4jLong length) : _dataType(type), _ews(1), - _order('c'), _rank(1), - _empty(false) { - _shape = {length}; - _strides = {1}; - } +////////////////////////////////////////////////////////////////////////// +ShapeDescriptor::ShapeDescriptor(const DataType type, const char order, + const std::vector &shape, + const std::vector &strides, + const Nd4jLong ews) + : ShapeDescriptor(type, order, shape, strides) { + _ews = ews; +} - ShapeDescriptor::ShapeDescriptor(const Nd4jLong *shapeInfo, bool inheritDtype) { - _order = shape::order(shapeInfo); - _ews = shape::elementWiseStride(shapeInfo); - _rank = shape::rank(shapeInfo); +ShapeDescriptor::ShapeDescriptor(const DataType type, const Nd4jLong length) + : _dataType(type), _ews(1), _order('c'), _rank(1), _empty(false) { + _shape = {length}; + _strides = {1}; +} - if (inheritDtype) - _dataType = ArrayOptions::dataType(shapeInfo); +ShapeDescriptor::ShapeDescriptor(const Nd4jLong *shapeInfo, bool inheritDtype) { + _order = shape::order(shapeInfo); + _ews = shape::elementWiseStride(shapeInfo); + _rank = shape::rank(shapeInfo); - _empty = shape::isEmpty(shapeInfo); + if (inheritDtype) _dataType = ArrayOptions::dataType(shapeInfo); - for (int e = 0; e < _rank; e++) { - _shape.emplace_back(shapeInfo[e + 1]); - if (shapeInfo[e + 1] == 0) - _empty = true; - } + _empty = shape::isEmpty(shapeInfo); - for (int e = 0; e < _rank; e++) - _strides.emplace_back(shapeInfo[e + 1 + _rank]); - } + for (int e = 0; e < _rank; e++) { + _shape.emplace_back(shapeInfo[e + 1]); + if (shapeInfo[e + 1] == 0) _empty = true; + } - ShapeDescriptor::ShapeDescriptor(const Nd4jLong *shapeInfo, const sd::DataType dtypeOverride) - : ShapeDescriptor::ShapeDescriptor(shapeInfo, false) { - _dataType = dtypeOverride; - } + for (int e = 0; e < _rank; e++) + _strides.emplace_back(shapeInfo[e + 1 + _rank]); +} - ShapeDescriptor::ShapeDescriptor(const Nd4jLong *shapeInfo, const Nd4jLong *dtypeOverride) - : ShapeDescriptor::ShapeDescriptor(shapeInfo, ArrayOptions::dataType(dtypeOverride)) { - // - } +ShapeDescriptor::ShapeDescriptor(const Nd4jLong *shapeInfo, + const sd::DataType dtypeOverride) + : ShapeDescriptor::ShapeDescriptor(shapeInfo, false) { + _dataType = dtypeOverride; +} - ShapeDescriptor::ShapeDescriptor(const Nd4jLong *shapeInfo, const Nd4jLong *dtypeOverride, - const Nd4jLong *orderOverride) : ShapeDescriptor::ShapeDescriptor(shapeInfo, - ArrayOptions::dataType( - dtypeOverride)) { - _order = shape::order(orderOverride); - } +ShapeDescriptor::ShapeDescriptor(const Nd4jLong *shapeInfo, + const Nd4jLong *dtypeOverride) + : ShapeDescriptor::ShapeDescriptor(shapeInfo, + ArrayOptions::dataType(dtypeOverride)) { + // +} - int ShapeDescriptor::rank() const { - return _rank; - } +ShapeDescriptor::ShapeDescriptor(const Nd4jLong *shapeInfo, + const Nd4jLong *dtypeOverride, + const Nd4jLong *orderOverride) + : ShapeDescriptor::ShapeDescriptor(shapeInfo, + ArrayOptions::dataType(dtypeOverride)) { + _order = shape::order(orderOverride); +} - Nd4jLong ShapeDescriptor::ews() const { - return _ews; - } +int ShapeDescriptor::rank() const { return _rank; } - Nd4jLong ShapeDescriptor::arrLength() const { +Nd4jLong ShapeDescriptor::ews() const { return _ews; } - Nd4jLong len = 1; - for (const auto &dim : const_cast(this)->shape()) - len *= dim; - return len; - } +Nd4jLong ShapeDescriptor::arrLength() const { + Nd4jLong len = 1; + for (const auto &dim : const_cast(this)->shape()) + len *= dim; + return len; +} - char ShapeDescriptor::order() const { - return _order; - } +char ShapeDescriptor::order() const { return _order; } - DataType ShapeDescriptor::dataType() const { - return _dataType; - } +DataType ShapeDescriptor::dataType() const { return _dataType; } - bool ShapeDescriptor::isEmpty() const { - return _empty; - } +bool ShapeDescriptor::isEmpty() const { return _empty; } - std::vector &ShapeDescriptor::shape() { - return _shape; - } +std::vector &ShapeDescriptor::shape() { return _shape; } - std::vector &ShapeDescriptor::strides() { - return _strides; - } +std::vector &ShapeDescriptor::strides() { return _strides; } - ShapeDescriptor::ShapeDescriptor(const ShapeDescriptor &other) { - _rank = other._rank; - _ews = other._ews; - _empty = other._empty; - _dataType = other._dataType; - _order = other._order; - _shape = other._shape; - _strides = other._strides; - } +ShapeDescriptor::ShapeDescriptor(const ShapeDescriptor &other) { + _rank = other._rank; + _ews = other._ews; + _empty = other._empty; + _dataType = other._dataType; + _order = other._order; + _shape = other._shape; + _strides = other._strides; +} ////////////////////////////////////////////////////////////////////////// - ShapeDescriptor::ShapeDescriptor(const DataType type, const char order, const std::vector &shape, - const std::vector &strides) : _dataType(type), _order(order), - _shape(shape) { - - if (strides.empty() && !shape.empty()) { - _strides.resize(shape.size()); - if (order == 'c') - shape::calcStrides(_shape.data(), shape.size(), _strides.data()); - else - shape::calcStridesFortran(_shape.data(), shape.size(), _strides.data()); - } else { - _strides = strides; - } - - - for (auto v:_shape) { - if (v == 0) { - _empty = true; - break; - } - } +ShapeDescriptor::ShapeDescriptor(const DataType type, const char order, + const std::vector &shape, + const std::vector &strides) + : _dataType(type), _order(order), _shape(shape) { + if (strides.empty() && !shape.empty()) { + _strides.resize(shape.size()); + if (order == 'c') + shape::calcStrides(_shape.data(), shape.size(), _strides.data()); + else + shape::calcStridesFortran(_shape.data(), shape.size(), _strides.data()); + } else { + _strides = strides; + } + + for (auto v : _shape) { + if (v == 0) { + _empty = true; + break; } + } +} - ShapeDescriptor ShapeDescriptor::emptyDescriptor(const DataType type) { - ShapeDescriptor descriptor; - descriptor._dataType = type; - descriptor._empty = true; - descriptor._rank = 0; - descriptor._order = 'c'; - descriptor._ews = 1; +ShapeDescriptor ShapeDescriptor::emptyDescriptor(const DataType type) { + ShapeDescriptor descriptor; + descriptor._dataType = type; + descriptor._empty = true; + descriptor._rank = 0; + descriptor._order = 'c'; + descriptor._ews = 1; - return descriptor; - } + return descriptor; +} - ShapeDescriptor ShapeDescriptor::scalarDescriptor(const DataType type) { - ShapeDescriptor descriptor; - descriptor._dataType = type; - descriptor._empty = false; - descriptor._rank = 0; - descriptor._order = 'c'; - descriptor._ews = 1; +ShapeDescriptor ShapeDescriptor::scalarDescriptor(const DataType type) { + ShapeDescriptor descriptor; + descriptor._dataType = type; + descriptor._empty = false; + descriptor._rank = 0; + descriptor._order = 'c'; + descriptor._ews = 1; - return descriptor; - } + return descriptor; +} - ShapeDescriptor ShapeDescriptor::vectorDescriptor(const Nd4jLong length, const DataType type) { - ShapeDescriptor descriptor; - descriptor._dataType = type; - descriptor._shape.emplace_back(length); +ShapeDescriptor ShapeDescriptor::vectorDescriptor(const Nd4jLong length, + const DataType type) { + ShapeDescriptor descriptor; + descriptor._dataType = type; + descriptor._shape.emplace_back(length); - if (length > 0) - descriptor._strides.emplace_back(1); - else { - descriptor._strides.emplace_back(0); - descriptor._empty = true; - } + if (length > 0) + descriptor._strides.emplace_back(1); + else { + descriptor._strides.emplace_back(0); + descriptor._empty = true; + } - descriptor._order = 'c'; - descriptor._ews = 1; - descriptor._rank = 1; + descriptor._order = 'c'; + descriptor._ews = 1; + descriptor._rank = 1; - return descriptor; - } + return descriptor; } +} // namespace sd namespace std { - size_t hash::operator()(const sd::ShapeDescriptor &k) const { - auto res = std::hash()(k.arrLength()); - res ^= std::hash()(k.order()) + 0x9e3779b9 + (res << 6) + (res >> 2); - res ^= k.dataType() + 0x9e3779b9 + (res << 6) + (res >> 2); - res ^= std::hash()(k.rank()) + 0x9e3779b9 + (res << 6) + (res >> 2); - res ^= std::hash()(k.ews()) + 0x9e3779b9 + (res << 6) + (res >> 2); - auto shapes = const_cast(k).shape(); - auto strides = const_cast(k).strides(); - for (auto s: shapes) { - res ^= std::hash()(s) + 0x9e3779b9 + (res << 6) + (res >> 2); - } - - for (auto s: strides) { - res ^= std::hash()(s) + 0x9e3779b9 + (res << 6) + (res >> 2); - } - - return res; - } +size_t hash::operator()( + const sd::ShapeDescriptor &k) const { + auto res = std::hash()(k.arrLength()); + res ^= std::hash()(k.order()) + 0x9e3779b9 + (res << 6) + (res >> 2); + res ^= k.dataType() + 0x9e3779b9 + (res << 6) + (res >> 2); + res ^= std::hash()(k.rank()) + 0x9e3779b9 + (res << 6) + (res >> 2); + res ^= std::hash()(k.ews()) + 0x9e3779b9 + (res << 6) + (res >> 2); + auto shapes = const_cast(k).shape(); + auto strides = const_cast(k).strides(); + for (auto s : shapes) { + res ^= std::hash()(s) + 0x9e3779b9 + (res << 6) + (res >> 2); + } + + for (auto s : strides) { + res ^= std::hash()(s) + 0x9e3779b9 + (res << 6) + (res >> 2); + } + + return res; } - - - +} // namespace std diff --git a/libnd4j/include/array/impl/ShapeList.cpp b/libnd4j/include/array/impl/ShapeList.cpp index d26132516227..ce610e23805e 100644 --- a/libnd4j/include/array/impl/ShapeList.cpp +++ b/libnd4j/include/array/impl/ShapeList.cpp @@ -18,69 +18,61 @@ // @author raver119@gmail.com // -#include #include +#include namespace sd { - //ShapeList::ShapeList(bool autoRemovable) { +// ShapeList::ShapeList(bool autoRemovable) { // _autoremovable = autoRemovable; // } - ShapeList::ShapeList(const Nd4jLong* shape) { - if (shape != nullptr) - _shapes.push_back(shape); - } +ShapeList::ShapeList(const Nd4jLong* shape) { + if (shape != nullptr) _shapes.push_back(shape); +} - ShapeList::~ShapeList() { - if (_autoremovable) - destroy(); - } +ShapeList::~ShapeList() { + if (_autoremovable) destroy(); +} - ShapeList::ShapeList(const std::vector &shapes, bool isWorkspace) : ShapeList(shapes){ - _workspace = isWorkspace; - } +ShapeList::ShapeList(const std::vector& shapes, + bool isWorkspace) + : ShapeList(shapes) { + _workspace = isWorkspace; +} - ShapeList::ShapeList(const std::vector& shapes) { - _shapes = shapes; - } +ShapeList::ShapeList(const std::vector& shapes) { + _shapes = shapes; +} - std::vector* ShapeList::asVector() { - return &_shapes; - } +std::vector* ShapeList::asVector() { return &_shapes; } - void ShapeList::destroy() { - if (_destroyed) - return; +void ShapeList::destroy() { + if (_destroyed) return; - if (!_workspace) - for (auto v:_shapes) - if(v != nullptr) - delete[] v; + if (!_workspace) + for (auto v : _shapes) + if (v != nullptr) delete[] v; - _destroyed = true; - } + _destroyed = true; +} - int ShapeList::size() const { - return (int) _shapes.size(); - } +int ShapeList::size() const { return (int)_shapes.size(); } - const Nd4jLong* ShapeList::at(int idx) { - if (_shapes.size() <= idx) - throw std::runtime_error("Can't find requested variable by index"); +const Nd4jLong* ShapeList::at(int idx) { + if (_shapes.size() <= idx) + throw std::runtime_error("Can't find requested variable by index"); - return _shapes.at(idx); - } + return _shapes.at(idx); +} - void ShapeList::push_back(const Nd4jLong *shape) { - _shapes.push_back(shape); - } +void ShapeList::push_back(const Nd4jLong* shape) { _shapes.push_back(shape); } - void ShapeList::detach() { - for (int e = 0; e < _shapes.size(); e++) { - _shapes[e] = shape::detachShape(_shapes[e]); - } +void ShapeList::detach() { + for (int e = 0; e < _shapes.size(); e++) { + _shapes[e] = shape::detachShape(_shapes[e]); + } - _autoremovable = true; - _workspace = false; - } -} \ No newline at end of file + _autoremovable = true; + _workspace = false; +} +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/array/impl/TadDescriptor.cpp b/libnd4j/include/array/impl/TadDescriptor.cpp index e2ec7480ee8f..8ceb6e2d4bb3 100644 --- a/libnd4j/include/array/impl/TadDescriptor.cpp +++ b/libnd4j/include/array/impl/TadDescriptor.cpp @@ -18,77 +18,74 @@ // @author raver119@gmail.com // +#include "../TadDescriptor.h" #include -#include "../TadDescriptor.h" namespace sd { - TadDescriptor::TadDescriptor(const TadDescriptor &other) { - _originalShape = other._originalShape; - _axis = other._axis; - _unitiesInShape = other._unitiesInShape; - } - - TadDescriptor::TadDescriptor(const Nd4jLong *originalShape, const int *dimensions, const int length, const bool keepUnitiesInShape) { - ShapeDescriptor descriptor(originalShape); - - _axis.resize(length); - for (int e = 0; e < length; e++) - _axis[e] = dimensions[e]; +TadDescriptor::TadDescriptor(const TadDescriptor &other) { + _originalShape = other._originalShape; + _axis = other._axis; + _unitiesInShape = other._unitiesInShape; +} - if (length > 1) - std::sort(_axis.begin(), _axis.end()); +TadDescriptor::TadDescriptor(const Nd4jLong *originalShape, + const int *dimensions, const int length, + const bool keepUnitiesInShape) { + ShapeDescriptor descriptor(originalShape); - _originalShape = descriptor; - _unitiesInShape = keepUnitiesInShape; - } + _axis.resize(length); + for (int e = 0; e < length; e++) _axis[e] = dimensions[e]; - TadDescriptor::TadDescriptor(const ShapeDescriptor &descriptor, const std::vector &dimensions, const bool keepUnitiesInShape) { - _originalShape = descriptor; - _axis = dimensions; - _unitiesInShape = keepUnitiesInShape; + if (length > 1) std::sort(_axis.begin(), _axis.end()); - if (_axis.size() > 1) - std::sort(_axis.begin(), _axis.end()); - } + _originalShape = descriptor; + _unitiesInShape = keepUnitiesInShape; +} - bool TadDescriptor::operator==(const TadDescriptor &other) const { - return std::tie(_originalShape, _axis, _unitiesInShape) == std::tie(other._originalShape, other._axis, other._unitiesInShape); - } +TadDescriptor::TadDescriptor(const ShapeDescriptor &descriptor, + const std::vector &dimensions, + const bool keepUnitiesInShape) { + _originalShape = descriptor; + _axis = dimensions; + _unitiesInShape = keepUnitiesInShape; + if (_axis.size() > 1) std::sort(_axis.begin(), _axis.end()); +} - bool TadDescriptor::operator<(const TadDescriptor &other) const { - return std::tie(_originalShape, _axis, _unitiesInShape) < std::tie(other._originalShape, other._axis, other._unitiesInShape); - } +bool TadDescriptor::operator==(const TadDescriptor &other) const { + return std::tie(_originalShape, _axis, _unitiesInShape) == + std::tie(other._originalShape, other._axis, other._unitiesInShape); +} - std::vector& TadDescriptor::axis() { - return _axis; - } +bool TadDescriptor::operator<(const TadDescriptor &other) const { + return std::tie(_originalShape, _axis, _unitiesInShape) < + std::tie(other._originalShape, other._axis, other._unitiesInShape); +} - ShapeDescriptor& TadDescriptor::originalShape(){ - return _originalShape; - } +std::vector &TadDescriptor::axis() { return _axis; } - ShapeDescriptor const& TadDescriptor::originalShapeConst() const{ - return _originalShape; - } +ShapeDescriptor &TadDescriptor::originalShape() { return _originalShape; } - bool TadDescriptor::areUnitiesinShape() const { - return _unitiesInShape; - } +ShapeDescriptor const &TadDescriptor::originalShapeConst() const { + return _originalShape; } +bool TadDescriptor::areUnitiesinShape() const { return _unitiesInShape; } +} // namespace sd + namespace std { - size_t hash::operator()(const sd::TadDescriptor &k) const { - // Compute individual hash values for first, - // second and third and combine them using XOR - // and bit shifting: - auto res = std::hash()((int)k.areUnitiesinShape()); - res ^= std::hash()(k.originalShapeConst()) + 0x9e3779b9 + (res << 6) + (res >> 2); - auto axes = const_cast(k).axis(); - for (auto a: axes) { - res ^= std::hash()(a) + 0x9e3779b9 + (res << 6) + (res >> 2); - } - return res; - } -} \ No newline at end of file +size_t hash::operator()(const sd::TadDescriptor &k) const { + // Compute individual hash values for first, + // second and third and combine them using XOR + // and bit shifting: + auto res = std::hash()((int)k.areUnitiesinShape()); + res ^= std::hash()(k.originalShapeConst()) + 0x9e3779b9 + + (res << 6) + (res >> 2); + auto axes = const_cast(k).axis(); + for (auto a : axes) { + res ^= std::hash()(a) + 0x9e3779b9 + (res << 6) + (res >> 2); + } + return res; +} +} // namespace std \ No newline at end of file diff --git a/libnd4j/include/array/impl/TadPack.cpp b/libnd4j/include/array/impl/TadPack.cpp index e489d0e83cb1..4daab4ddf67b 100644 --- a/libnd4j/include/array/impl/TadPack.cpp +++ b/libnd4j/include/array/impl/TadPack.cpp @@ -18,44 +18,47 @@ // @author raver119@gmail.com // -#include "../TadPack.h" -#include +#include #include +#include namespace sd { - TadPack::TadPack(const ConstantShapeBuffer &shapes, const ConstantOffsetsBuffer &offets, Nd4jLong numTads) : _tadShape(shapes), _tadOffsets(offets) { - _numTads = numTads; - } - - const Nd4jLong* TadPack::primaryShapeInfo() const { - return _tadShape.primary(); - } - - const Nd4jLong* TadPack::primaryOffsets() const { - return _tadOffsets.primary(); - } - - const Nd4jLong* TadPack::specialShapeInfo() const { - return _tadShape.special(); - } - - const Nd4jLong* TadPack::specialOffsets() const { - return _tadOffsets.special(); - } - - Nd4jLong TadPack::numberOfTads() const { - return _numTads; - } - - const Nd4jLong* TadPack::platformShapeInfo() const { - return sd::Environment::getInstance().isCPU() ? primaryShapeInfo() : specialShapeInfo(); - } - - const Nd4jLong* TadPack::platformOffsets() const { - return sd::Environment::getInstance().isCPU() ? primaryOffsets() : specialOffsets(); - } - - int TadPack::shapeInfoLength() const { - return (int) shape::shapeInfoLength(primaryShapeInfo()); - } -} \ No newline at end of file +TadPack::TadPack(const ConstantShapeBuffer& shapes, const ConstantOffsetsBuffer& offets, + Nd4jLong numTads) : + _tadShape (shapes), + _tadOffsets (offets) { + _numTads = numTads; +} + +const Nd4jLong* TadPack::primaryShapeInfo() const { + return _tadShape.primary(); +} + +const Nd4jLong* TadPack::primaryOffsets() const { + return _tadOffsets.primary(); +} + +const Nd4jLong* TadPack::specialShapeInfo() const { + return _tadShape.special(); +} + +const Nd4jLong* TadPack::specialOffsets() const { + return _tadOffsets.special(); +} + +Nd4jLong TadPack::numberOfTads() const { return _numTads; } + +const Nd4jLong* TadPack::platformShapeInfo() const { + return sd::Environment::getInstance().isCPU() ? primaryShapeInfo() + : specialShapeInfo(); +} + +const Nd4jLong* TadPack::platformOffsets() const { + return sd::Environment::getInstance().isCPU() ? primaryOffsets() + : specialOffsets(); +} + +int TadPack::shapeInfoLength() const { + return (int)shape::shapeInfoLength(primaryShapeInfo()); +} +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/cblas.h b/libnd4j/include/cblas.h old mode 100755 new mode 100644 index 18970a9b074f..27c7c605bbc4 --- a/libnd4j/include/cblas.h +++ b/libnd4j/include/cblas.h @@ -55,12 +55,16 @@ extern "C" { #endif #ifndef CBLAS_ENUM_DEFINED_H #define CBLAS_ENUM_DEFINED_H -enum CBLAS_ORDER {CblasRowMajor=101, CblasColMajor=102 }; -enum CBLAS_TRANSPOSE {CblasNoTrans=111, CblasTrans=112, CblasConjTrans=113, - AtlasConj=114}; -enum CBLAS_UPLO {CblasUpper=121, CblasLower=122}; -enum CBLAS_DIAG {CblasNonUnit=131, CblasUnit=132}; -enum CBLAS_SIDE {CblasLeft=141, CblasRight=142}; +enum CBLAS_ORDER { CblasRowMajor = 101, CblasColMajor = 102 }; +enum CBLAS_TRANSPOSE { + CblasNoTrans = 111, + CblasTrans = 112, + CblasConjTrans = 113, + AtlasConj = 114 +}; +enum CBLAS_UPLO { CblasUpper = 121, CblasLower = 122 }; +enum CBLAS_DIAG { CblasNonUnit = 131, CblasUnit = 132 }; +enum CBLAS_SIDE { CblasLeft = 141, CblasRight = 142 }; #endif #ifndef CBLAS_ENUM_ONLY @@ -68,7 +72,7 @@ enum CBLAS_SIDE {CblasLeft=141, CblasRight=142}; #define CBLAS_INDEX int int cblas_errprn(int ierr, int info, char *form, ...); -void cblas_xerbla(int p, char *rout, char *form, ...); +void cblas_xerbla(int p, char *rout, char *form, ...); #ifdef __MKL void MKL_Set_Num_Threads(int num); @@ -80,57 +84,46 @@ void openblas_set_num_threads(int num); // do nothing #endif - /* * =========================================================================== * Prototypes for level 1 BLAS functions (complex are recast as routines) * =========================================================================== */ -float cblas_sdsdot(int N, float alpha, float *X, - int incX, float *Y, int incY); -double cblas_dsdot(int N, float *X, int incX, float *Y, - int incY); -float cblas_sdot(int N, float *X, int incX, - float *Y, int incY); -double cblas_ddot(int N, double *X, int incX, - double *Y, int incY); +float cblas_sdsdot(int N, float alpha, float *X, int incX, float *Y, int incY); +double cblas_dsdot(int N, float *X, int incX, float *Y, int incY); +float cblas_sdot(int N, float *X, int incX, float *Y, int incY); +double cblas_ddot(int N, double *X, int incX, double *Y, int incY); /* * Functions having prefixes Z and C only */ -void cblas_cdotu_sub(int N, void *X, int incX, - void *Y, int incY, void *dotu); -void cblas_cdotc_sub(int N, void *X, int incX, - void *Y, int incY, void *dotc); - -void cblas_zdotu_sub(int N, void *X, int incX, - void *Y, int incY, void *dotu); -void cblas_zdotc_sub(int N, void *X, int incX, - void *Y, int incY, void *dotc); +void cblas_cdotu_sub(int N, void *X, int incX, void *Y, int incY, void *dotu); +void cblas_cdotc_sub(int N, void *X, int incX, void *Y, int incY, void *dotc); +void cblas_zdotu_sub(int N, void *X, int incX, void *Y, int incY, void *dotu); +void cblas_zdotc_sub(int N, void *X, int incX, void *Y, int incY, void *dotc); /* * Functions having prefixes S D SC DZ */ -float cblas_snrm2(int N, float *X, int incX); -float cblas_sasum(int N, float *X, int incX); - -double cblas_dnrm2(int N, double *X, int incX); -double cblas_dasum(int N, double *X, int incX); +float cblas_snrm2(int N, float *X, int incX); +float cblas_sasum(int N, float *X, int incX); -float cblas_scnrm2(int N, void *X, int incX); -float cblas_scasum(int N, void *X, int incX); +double cblas_dnrm2(int N, double *X, int incX); +double cblas_dasum(int N, double *X, int incX); -double cblas_dznrm2(int N, void *X, int incX); -double cblas_dzasum(int N, void *X, int incX); +float cblas_scnrm2(int N, void *X, int incX); +float cblas_scasum(int N, void *X, int incX); +double cblas_dznrm2(int N, void *X, int incX); +double cblas_dzasum(int N, void *X, int incX); /* * Functions having standard 4 prefixes (S D C Z) */ -CBLAS_INDEX cblas_isamax(int N, float *X, int incX); -CBLAS_INDEX cblas_idamax(int N, double *X, int incX); -CBLAS_INDEX cblas_icamax(int N, void *X, int incX); -CBLAS_INDEX cblas_izamax(int N, void *X, int incX); +CBLAS_INDEX cblas_isamax(int N, float *X, int incX); +CBLAS_INDEX cblas_idamax(int N, double *X, int incX); +CBLAS_INDEX cblas_icamax(int N, void *X, int incX); +CBLAS_INDEX cblas_izamax(int N, void *X, int incX); /* * =========================================================================== @@ -141,88 +134,67 @@ CBLAS_INDEX cblas_izamax(int N, void *X, int incX); /* * Routines with standard 4 prefixes (s, d, c, z) */ -void cblas_sswap(int N, float *X, int incX, - float *Y, int incY); -void cblas_scopy(int N, float *X, int incX, - float *Y, int incY); -void cblas_saxpy(int N, float alpha, float *X, - int incX, float *Y, int incY); -void catlas_saxpby(int N, float alpha, float *X, - int incX, float beta, float *Y, int incY); -void catlas_sset - (int N, float alpha, float *X, int incX); - -void cblas_dswap(int N, double *X, int incX, - double *Y, int incY); -void cblas_dcopy(int N, double *X, int incX, - double *Y, int incY); -void cblas_daxpy(int N, double alpha, double *X, - int incX, double *Y, int incY); -void catlas_daxpby(int N, double alpha, double *X, - int incX, double beta, double *Y, int incY); -void catlas_dset - (int N, double alpha, double *X, int incX); - -void cblas_cswap(int N, void *X, int incX, - void *Y, int incY); -void cblas_ccopy(int N, void *X, int incX, - void *Y, int incY); -void cblas_caxpy(int N, void *alpha, void *X, - int incX, void *Y, int incY); -void catlas_caxpby(int N, void *alpha, void *X, - int incX, void *beta, void *Y, int incY); -void catlas_cset - (int N, void *alpha, void *X, int incX); - -void cblas_zswap(int N, void *X, int incX, - void *Y, int incY); -void cblas_zcopy(int N, void *X, int incX, - void *Y, int incY); -void cblas_zaxpy(int N, void *alpha, void *X, - int incX, void *Y, int incY); -void catlas_zaxpby(int N, void *alpha, void *X, - int incX, void *beta, void *Y, int incY); -void catlas_zset - (int N, void *alpha, void *X, int incX); +void cblas_sswap(int N, float *X, int incX, float *Y, int incY); +void cblas_scopy(int N, float *X, int incX, float *Y, int incY); +void cblas_saxpy(int N, float alpha, float *X, int incX, float *Y, int incY); +void catlas_saxpby(int N, float alpha, float *X, int incX, float beta, float *Y, + int incY); +void catlas_sset(int N, float alpha, float *X, int incX); + +void cblas_dswap(int N, double *X, int incX, double *Y, int incY); +void cblas_dcopy(int N, double *X, int incX, double *Y, int incY); +void cblas_daxpy(int N, double alpha, double *X, int incX, double *Y, int incY); +void catlas_daxpby(int N, double alpha, double *X, int incX, double beta, + double *Y, int incY); +void catlas_dset(int N, double alpha, double *X, int incX); +void cblas_cswap(int N, void *X, int incX, void *Y, int incY); +void cblas_ccopy(int N, void *X, int incX, void *Y, int incY); +void cblas_caxpy(int N, void *alpha, void *X, int incX, void *Y, int incY); +void catlas_caxpby(int N, void *alpha, void *X, int incX, void *beta, void *Y, + int incY); +void catlas_cset(int N, void *alpha, void *X, int incX); + +void cblas_zswap(int N, void *X, int incX, void *Y, int incY); +void cblas_zcopy(int N, void *X, int incX, void *Y, int incY); +void cblas_zaxpy(int N, void *alpha, void *X, int incX, void *Y, int incY); +void catlas_zaxpby(int N, void *alpha, void *X, int incX, void *beta, void *Y, + int incY); +void catlas_zset(int N, void *alpha, void *X, int incX); /* * Routines with S and D prefix only */ void cblas_srotg(float *a, float *b, float *c, float *s); -void cblas_srotmg(float *d1, float *d2, float *b1, float b2, float *P); -void cblas_srot(int N, float *X, int incX, - float *Y, int incY, float c, float s); -void cblas_srotm(int N, float *X, int incX, - float *Y, int incY, float *P); +void cblas_srotmg(float *d1, float *d2, float *b1, float b2, float *P); +void cblas_srot(int N, float *X, int incX, float *Y, int incY, float c, + float s); +void cblas_srotm(int N, float *X, int incX, float *Y, int incY, float *P); void cblas_drotg(double *a, double *b, double *c, double *s); -void cblas_drotmg(double *d1, double *d2, double *b1, double b2, double *P); -void cblas_drot(int N, double *X, int incX, - double *Y, int incY, double c, double s); -void cblas_drotm(int N, double *X, int incX, - double *Y, int incY, double *P); - +void cblas_drotmg(double *d1, double *d2, double *b1, double b2, double *P); +void cblas_drot(int N, double *X, int incX, double *Y, int incY, double c, + double s); +void cblas_drotm(int N, double *X, int incX, double *Y, int incY, double *P); /* * Routines with S D C Z CS and ZD prefixes */ -void cblas_sscal(int N, float alpha, float *X, int incX); -void cblas_dscal(int N, double alpha, double *X, int incX); -void cblas_cscal(int N, void *alpha, void *X, int incX); -void cblas_zscal(int N, void *alpha, void *X, int incX); -void cblas_csscal(int N, float alpha, void *X, int incX); -void cblas_zdscal(int N, double alpha, void *X, int incX); +void cblas_sscal(int N, float alpha, float *X, int incX); +void cblas_dscal(int N, double alpha, double *X, int incX); +void cblas_cscal(int N, void *alpha, void *X, int incX); +void cblas_zscal(int N, void *alpha, void *X, int incX); +void cblas_csscal(int N, float alpha, void *X, int incX); +void cblas_zdscal(int N, double alpha, void *X, int incX); /* * Extra reference routines provided by ATLAS, but not mandated by the standard */ void cblas_crotg(void *a, void *b, void *c, void *s); void cblas_zrotg(void *a, void *b, void *c, void *s); -void cblas_csrot(int N, void *X, int incX, void *Y, int incY, - float c, float s); -void cblas_zdrot(int N, void *X, int incX, void *Y, int incY, - double c, double s); +void cblas_csrot(int N, void *X, int incX, void *Y, int incY, float c, float s); +void cblas_zdrot(int N, void *X, int incX, void *Y, int incY, double c, + double s); /* * =========================================================================== @@ -233,265 +205,200 @@ void cblas_zdrot(int N, void *X, int incX, void *Y, int incY, /* * Routines with standard 4 prefixes (S, D, C, Z) */ -void cblas_sgemv( enum CBLAS_ORDER Order, - enum CBLAS_TRANSPOSE TransA, int M, int N, - float alpha, float *A, int lda, - float *X, int incX, float beta, - float *Y, int incY); -void cblas_sgbmv( enum CBLAS_ORDER Order, - enum CBLAS_TRANSPOSE TransA, int M, int N, - int KL, int KU, float alpha, - float *A, int lda, float *X, - int incX, float beta, float *Y, int incY); -void cblas_strmv( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, - int N, float *A, int lda, - float *X, int incX); -void cblas_stbmv( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, - int N, int K, float *A, int lda, - float *X, int incX); -void cblas_stpmv( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, - int N, float *Ap, float *X, int incX); -void cblas_strsv( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, - int N, float *A, int lda, float *X, - int incX); -void cblas_stbsv( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, - int N, int K, float *A, int lda, - float *X, int incX); -void cblas_stpsv( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, - int N, float *Ap, float *X, int incX); - -void cblas_dgemv( enum CBLAS_ORDER Order, - enum CBLAS_TRANSPOSE TransA, int M, int N, - double alpha, double *A, int lda, - double *X, int incX, double beta, - double *Y, int incY); -void cblas_dgbmv( enum CBLAS_ORDER Order, - enum CBLAS_TRANSPOSE TransA, int M, int N, - int KL, int KU, double alpha, - double *A, int lda, double *X, - int incX, double beta, double *Y, int incY); -void cblas_dtrmv( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, - int N, double *A, int lda, - double *X, int incX); -void cblas_dtbmv( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, - int N, int K, double *A, int lda, - double *X, int incX); -void cblas_dtpmv( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, - int N, double *Ap, double *X, int incX); -void cblas_dtrsv( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, - int N, double *A, int lda, double *X, - int incX); -void cblas_dtbsv( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, - int N, int K, double *A, int lda, - double *X, int incX); -void cblas_dtpsv( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, - int N, double *Ap, double *X, int incX); - -void cblas_cgemv( enum CBLAS_ORDER Order, - enum CBLAS_TRANSPOSE TransA, int M, int N, - void *alpha, void *A, int lda, - void *X, int incX, void *beta, - void *Y, int incY); -void cblas_cgbmv( enum CBLAS_ORDER Order, - enum CBLAS_TRANSPOSE TransA, int M, int N, - int KL, int KU, void *alpha, - void *A, int lda, void *X, - int incX, void *beta, void *Y, int incY); -void cblas_ctrmv( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, - int N, void *A, int lda, - void *X, int incX); -void cblas_ctbmv( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, - int N, int K, void *A, int lda, - void *X, int incX); -void cblas_ctpmv( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, - int N, void *Ap, void *X, int incX); -void cblas_ctrsv( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, - int N, void *A, int lda, void *X, - int incX); -void cblas_ctbsv( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, - int N, int K, void *A, int lda, - void *X, int incX); -void cblas_ctpsv( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, - int N, void *Ap, void *X, int incX); - -void cblas_zgemv( enum CBLAS_ORDER Order, - enum CBLAS_TRANSPOSE TransA, int M, int N, - void *alpha, void *A, int lda, - void *X, int incX, void *beta, - void *Y, int incY); -void cblas_zgbmv( enum CBLAS_ORDER Order, - enum CBLAS_TRANSPOSE TransA, int M, int N, - int KL, int KU, void *alpha, - void *A, int lda, void *X, - int incX, void *beta, void *Y, int incY); -void cblas_ztrmv( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, - int N, void *A, int lda, - void *X, int incX); -void cblas_ztbmv( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, - int N, int K, void *A, int lda, - void *X, int incX); -void cblas_ztpmv( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, - int N, void *Ap, void *X, int incX); -void cblas_ztrsv( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, - int N, void *A, int lda, void *X, - int incX); -void cblas_ztbsv( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, - int N, int K, void *A, int lda, - void *X, int incX); -void cblas_ztpsv( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, - int N, void *Ap, void *X, int incX); - +void cblas_sgemv(enum CBLAS_ORDER Order, enum CBLAS_TRANSPOSE TransA, int M, + int N, float alpha, float *A, int lda, float *X, int incX, + float beta, float *Y, int incY); +void cblas_sgbmv(enum CBLAS_ORDER Order, enum CBLAS_TRANSPOSE TransA, int M, + int N, int KL, int KU, float alpha, float *A, int lda, + float *X, int incX, float beta, float *Y, int incY); +void cblas_strmv(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, + enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, int N, + float *A, int lda, float *X, int incX); +void cblas_stbmv(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, + enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, int N, + int K, float *A, int lda, float *X, int incX); +void cblas_stpmv(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, + enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, int N, + float *Ap, float *X, int incX); +void cblas_strsv(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, + enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, int N, + float *A, int lda, float *X, int incX); +void cblas_stbsv(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, + enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, int N, + int K, float *A, int lda, float *X, int incX); +void cblas_stpsv(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, + enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, int N, + float *Ap, float *X, int incX); + +void cblas_dgemv(enum CBLAS_ORDER Order, enum CBLAS_TRANSPOSE TransA, int M, + int N, double alpha, double *A, int lda, double *X, int incX, + double beta, double *Y, int incY); +void cblas_dgbmv(enum CBLAS_ORDER Order, enum CBLAS_TRANSPOSE TransA, int M, + int N, int KL, int KU, double alpha, double *A, int lda, + double *X, int incX, double beta, double *Y, int incY); +void cblas_dtrmv(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, + enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, int N, + double *A, int lda, double *X, int incX); +void cblas_dtbmv(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, + enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, int N, + int K, double *A, int lda, double *X, int incX); +void cblas_dtpmv(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, + enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, int N, + double *Ap, double *X, int incX); +void cblas_dtrsv(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, + enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, int N, + double *A, int lda, double *X, int incX); +void cblas_dtbsv(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, + enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, int N, + int K, double *A, int lda, double *X, int incX); +void cblas_dtpsv(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, + enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, int N, + double *Ap, double *X, int incX); + +void cblas_cgemv(enum CBLAS_ORDER Order, enum CBLAS_TRANSPOSE TransA, int M, + int N, void *alpha, void *A, int lda, void *X, int incX, + void *beta, void *Y, int incY); +void cblas_cgbmv(enum CBLAS_ORDER Order, enum CBLAS_TRANSPOSE TransA, int M, + int N, int KL, int KU, void *alpha, void *A, int lda, void *X, + int incX, void *beta, void *Y, int incY); +void cblas_ctrmv(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, + enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, int N, + void *A, int lda, void *X, int incX); +void cblas_ctbmv(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, + enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, int N, + int K, void *A, int lda, void *X, int incX); +void cblas_ctpmv(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, + enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, int N, + void *Ap, void *X, int incX); +void cblas_ctrsv(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, + enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, int N, + void *A, int lda, void *X, int incX); +void cblas_ctbsv(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, + enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, int N, + int K, void *A, int lda, void *X, int incX); +void cblas_ctpsv(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, + enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, int N, + void *Ap, void *X, int incX); + +void cblas_zgemv(enum CBLAS_ORDER Order, enum CBLAS_TRANSPOSE TransA, int M, + int N, void *alpha, void *A, int lda, void *X, int incX, + void *beta, void *Y, int incY); +void cblas_zgbmv(enum CBLAS_ORDER Order, enum CBLAS_TRANSPOSE TransA, int M, + int N, int KL, int KU, void *alpha, void *A, int lda, void *X, + int incX, void *beta, void *Y, int incY); +void cblas_ztrmv(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, + enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, int N, + void *A, int lda, void *X, int incX); +void cblas_ztbmv(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, + enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, int N, + int K, void *A, int lda, void *X, int incX); +void cblas_ztpmv(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, + enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, int N, + void *Ap, void *X, int incX); +void cblas_ztrsv(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, + enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, int N, + void *A, int lda, void *X, int incX); +void cblas_ztbsv(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, + enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, int N, + int K, void *A, int lda, void *X, int incX); +void cblas_ztpsv(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, + enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, int N, + void *Ap, void *X, int incX); /* * Routines with S and D prefixes only */ -void cblas_ssymv( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - int N, float alpha, float *A, - int lda, float *X, int incX, - float beta, float *Y, int incY); -void cblas_ssbmv( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - int N, int K, float alpha, float *A, - int lda, float *X, int incX, - float beta, float *Y, int incY); -void cblas_sspmv( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - int N, float alpha, float *Ap, - float *X, int incX, - float beta, float *Y, int incY); -void cblas_sger( enum CBLAS_ORDER Order, int M, int N, - float alpha, float *X, int incX, - float *Y, int incY, float *A, int lda); -void cblas_ssyr( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - int N, float alpha, float *X, - int incX, float *A, int lda); -void cblas_sspr( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - int N, float alpha, float *X, - int incX, float *Ap); -void cblas_ssyr2( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - int N, float alpha, float *X, - int incX, float *Y, int incY, float *A, - int lda); -void cblas_sspr2( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - int N, float alpha, float *X, - int incX, float *Y, int incY, float *A); - -void cblas_dsymv( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - int N, double alpha, double *A, - int lda, double *X, int incX, - double beta, double *Y, int incY); -void cblas_dsbmv( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - int N, int K, double alpha, double *A, - int lda, double *X, int incX, - double beta, double *Y, int incY); -void cblas_dspmv( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - int N, double alpha, double *Ap, - double *X, int incX, - double beta, double *Y, int incY); -void cblas_dger( enum CBLAS_ORDER Order, int M, int N, - double alpha, double *X, int incX, - double *Y, int incY, double *A, int lda); -void cblas_dsyr( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - int N, double alpha, double *X, - int incX, double *A, int lda); -void cblas_dspr( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - int N, double alpha, double *X, - int incX, double *Ap); -void cblas_dsyr2( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - int N, double alpha, double *X, - int incX, double *Y, int incY, double *A, +void cblas_ssymv(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, int N, + float alpha, float *A, int lda, float *X, int incX, float beta, + float *Y, int incY); +void cblas_ssbmv(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, int N, int K, + float alpha, float *A, int lda, float *X, int incX, float beta, + float *Y, int incY); +void cblas_sspmv(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, int N, + float alpha, float *Ap, float *X, int incX, float beta, + float *Y, int incY); +void cblas_sger(enum CBLAS_ORDER Order, int M, int N, float alpha, float *X, + int incX, float *Y, int incY, float *A, int lda); +void cblas_ssyr(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, int N, + float alpha, float *X, int incX, float *A, int lda); +void cblas_sspr(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, int N, + float alpha, float *X, int incX, float *Ap); +void cblas_ssyr2(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, int N, + float alpha, float *X, int incX, float *Y, int incY, float *A, int lda); -void cblas_dspr2( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - int N, double alpha, double *X, - int incX, double *Y, int incY, double *A); - +void cblas_sspr2(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, int N, + float alpha, float *X, int incX, float *Y, int incY, float *A); + +void cblas_dsymv(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, int N, + double alpha, double *A, int lda, double *X, int incX, + double beta, double *Y, int incY); +void cblas_dsbmv(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, int N, int K, + double alpha, double *A, int lda, double *X, int incX, + double beta, double *Y, int incY); +void cblas_dspmv(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, int N, + double alpha, double *Ap, double *X, int incX, double beta, + double *Y, int incY); +void cblas_dger(enum CBLAS_ORDER Order, int M, int N, double alpha, double *X, + int incX, double *Y, int incY, double *A, int lda); +void cblas_dsyr(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, int N, + double alpha, double *X, int incX, double *A, int lda); +void cblas_dspr(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, int N, + double alpha, double *X, int incX, double *Ap); +void cblas_dsyr2(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, int N, + double alpha, double *X, int incX, double *Y, int incY, + double *A, int lda); +void cblas_dspr2(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, int N, + double alpha, double *X, int incX, double *Y, int incY, + double *A); /* * Routines with C and Z prefixes only */ -void cblas_chemv( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - int N, void *alpha, void *A, - int lda, void *X, int incX, - void *beta, void *Y, int incY); -void cblas_chbmv( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - int N, int K, void *alpha, void *A, - int lda, void *X, int incX, - void *beta, void *Y, int incY); -void cblas_chpmv( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - int N, void *alpha, void *Ap, - void *X, int incX, - void *beta, void *Y, int incY); -void cblas_cgeru( enum CBLAS_ORDER Order, int M, int N, - void *alpha, void *X, int incX, - void *Y, int incY, void *A, int lda); -void cblas_cgerc( enum CBLAS_ORDER Order, int M, int N, - void *alpha, void *X, int incX, - void *Y, int incY, void *A, int lda); -void cblas_cher( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - int N, float alpha, void *X, int incX, - void *A, int lda); -void cblas_chpr( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - int N, float alpha, void *X, - int incX, void *A); -void cblas_cher2( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, int N, - void *alpha, void *X, int incX, - void *Y, int incY, void *A, int lda); -void cblas_chpr2( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, int N, - void *alpha, void *X, int incX, - void *Y, int incY, void *Ap); - -void cblas_zhemv( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - int N, void *alpha, void *A, - int lda, void *X, int incX, - void *beta, void *Y, int incY); -void cblas_zhbmv( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - int N, int K, void *alpha, void *A, - int lda, void *X, int incX, - void *beta, void *Y, int incY); -void cblas_zhpmv( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - int N, void *alpha, void *Ap, - void *X, int incX, - void *beta, void *Y, int incY); -void cblas_zgeru( enum CBLAS_ORDER Order, int M, int N, - void *alpha, void *X, int incX, - void *Y, int incY, void *A, int lda); -void cblas_zgerc( enum CBLAS_ORDER Order, int M, int N, - void *alpha, void *X, int incX, - void *Y, int incY, void *A, int lda); -void cblas_zher( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - int N, double alpha, void *X, int incX, - void *A, int lda); -void cblas_zhpr( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - int N, double alpha, void *X, - int incX, void *A); -void cblas_zher2( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, int N, - void *alpha, void *X, int incX, - void *Y, int incY, void *A, int lda); -void cblas_zhpr2( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, int N, - void *alpha, void *X, int incX, - void *Y, int incY, void *Ap); +void cblas_chemv(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, int N, + void *alpha, void *A, int lda, void *X, int incX, void *beta, + void *Y, int incY); +void cblas_chbmv(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, int N, int K, + void *alpha, void *A, int lda, void *X, int incX, void *beta, + void *Y, int incY); +void cblas_chpmv(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, int N, + void *alpha, void *Ap, void *X, int incX, void *beta, void *Y, + int incY); +void cblas_cgeru(enum CBLAS_ORDER Order, int M, int N, void *alpha, void *X, + int incX, void *Y, int incY, void *A, int lda); +void cblas_cgerc(enum CBLAS_ORDER Order, int M, int N, void *alpha, void *X, + int incX, void *Y, int incY, void *A, int lda); +void cblas_cher(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, int N, + float alpha, void *X, int incX, void *A, int lda); +void cblas_chpr(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, int N, + float alpha, void *X, int incX, void *A); +void cblas_cher2(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, int N, + void *alpha, void *X, int incX, void *Y, int incY, void *A, + int lda); +void cblas_chpr2(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, int N, + void *alpha, void *X, int incX, void *Y, int incY, void *Ap); + +void cblas_zhemv(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, int N, + void *alpha, void *A, int lda, void *X, int incX, void *beta, + void *Y, int incY); +void cblas_zhbmv(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, int N, int K, + void *alpha, void *A, int lda, void *X, int incX, void *beta, + void *Y, int incY); +void cblas_zhpmv(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, int N, + void *alpha, void *Ap, void *X, int incX, void *beta, void *Y, + int incY); +void cblas_zgeru(enum CBLAS_ORDER Order, int M, int N, void *alpha, void *X, + int incX, void *Y, int incY, void *A, int lda); +void cblas_zgerc(enum CBLAS_ORDER Order, int M, int N, void *alpha, void *X, + int incX, void *Y, int incY, void *A, int lda); +void cblas_zher(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, int N, + double alpha, void *X, int incX, void *A, int lda); +void cblas_zhpr(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, int N, + double alpha, void *X, int incX, void *A); +void cblas_zher2(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, int N, + void *alpha, void *X, int incX, void *Y, int incY, void *A, + int lda); +void cblas_zhpr2(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, int N, + void *alpha, void *X, int incX, void *Y, int incY, void *Ap); /* * =========================================================================== @@ -502,163 +409,126 @@ void cblas_zhpr2( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, int N, /* * Routines with standard 4 prefixes (S, D, C, Z) */ -void cblas_sgemm( enum CBLAS_ORDER Order, enum CBLAS_TRANSPOSE TransA, - enum CBLAS_TRANSPOSE TransB, int M, int N, - int K, float alpha, float *A, - int lda, float *B, int ldb, - float beta, float *C, int ldc); -void cblas_ssymm( enum CBLAS_ORDER Order, enum CBLAS_SIDE Side, - enum CBLAS_UPLO Uplo, int M, int N, - float alpha, float *A, int lda, - float *B, int ldb, float beta, - float *C, int ldc); -void cblas_ssyrk( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - enum CBLAS_TRANSPOSE Trans, int N, int K, - float alpha, float *A, int lda, - float beta, float *C, int ldc); -void cblas_ssyr2k( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - enum CBLAS_TRANSPOSE Trans, int N, int K, - float alpha, float *A, int lda, - float *B, int ldb, float beta, - float *C, int ldc); -void cblas_strmm( enum CBLAS_ORDER Order, enum CBLAS_SIDE Side, - enum CBLAS_UPLO Uplo, enum CBLAS_TRANSPOSE TransA, - enum CBLAS_DIAG Diag, int M, int N, - float alpha, float *A, int lda, - float *B, int ldb); -void cblas_strsm( enum CBLAS_ORDER Order, enum CBLAS_SIDE Side, - enum CBLAS_UPLO Uplo, enum CBLAS_TRANSPOSE TransA, - enum CBLAS_DIAG Diag, int M, int N, - float alpha, float *A, int lda, - float *B, int ldb); - -void cblas_dgemm( enum CBLAS_ORDER Order, enum CBLAS_TRANSPOSE TransA, - enum CBLAS_TRANSPOSE TransB, int M, int N, - int K, double alpha, double *A, - int lda, double *B, int ldb, - double beta, double *C, int ldc); -void cblas_dsymm( enum CBLAS_ORDER Order, enum CBLAS_SIDE Side, - enum CBLAS_UPLO Uplo, int M, int N, - double alpha, double *A, int lda, - double *B, int ldb, double beta, - double *C, int ldc); -void cblas_dsyrk( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - enum CBLAS_TRANSPOSE Trans, int N, int K, - double alpha, double *A, int lda, - double beta, double *C, int ldc); -void cblas_dsyr2k( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - enum CBLAS_TRANSPOSE Trans, int N, int K, - double alpha, double *A, int lda, - double *B, int ldb, double beta, +void cblas_sgemm(enum CBLAS_ORDER Order, enum CBLAS_TRANSPOSE TransA, + enum CBLAS_TRANSPOSE TransB, int M, int N, int K, float alpha, + float *A, int lda, float *B, int ldb, float beta, float *C, + int ldc); +void cblas_ssymm(enum CBLAS_ORDER Order, enum CBLAS_SIDE Side, + enum CBLAS_UPLO Uplo, int M, int N, float alpha, float *A, + int lda, float *B, int ldb, float beta, float *C, int ldc); +void cblas_ssyrk(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, + enum CBLAS_TRANSPOSE Trans, int N, int K, float alpha, + float *A, int lda, float beta, float *C, int ldc); +void cblas_ssyr2k(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, + enum CBLAS_TRANSPOSE Trans, int N, int K, float alpha, + float *A, int lda, float *B, int ldb, float beta, float *C, + int ldc); +void cblas_strmm(enum CBLAS_ORDER Order, enum CBLAS_SIDE Side, + enum CBLAS_UPLO Uplo, enum CBLAS_TRANSPOSE TransA, + enum CBLAS_DIAG Diag, int M, int N, float alpha, float *A, + int lda, float *B, int ldb); +void cblas_strsm(enum CBLAS_ORDER Order, enum CBLAS_SIDE Side, + enum CBLAS_UPLO Uplo, enum CBLAS_TRANSPOSE TransA, + enum CBLAS_DIAG Diag, int M, int N, float alpha, float *A, + int lda, float *B, int ldb); + +void cblas_dgemm(enum CBLAS_ORDER Order, enum CBLAS_TRANSPOSE TransA, + enum CBLAS_TRANSPOSE TransB, int M, int N, int K, double alpha, + double *A, int lda, double *B, int ldb, double beta, double *C, + int ldc); +void cblas_dsymm(enum CBLAS_ORDER Order, enum CBLAS_SIDE Side, + enum CBLAS_UPLO Uplo, int M, int N, double alpha, double *A, + int lda, double *B, int ldb, double beta, double *C, int ldc); +void cblas_dsyrk(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, + enum CBLAS_TRANSPOSE Trans, int N, int K, double alpha, + double *A, int lda, double beta, double *C, int ldc); +void cblas_dsyr2k(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, + enum CBLAS_TRANSPOSE Trans, int N, int K, double alpha, + double *A, int lda, double *B, int ldb, double beta, double *C, int ldc); -void cblas_dtrmm( enum CBLAS_ORDER Order, enum CBLAS_SIDE Side, - enum CBLAS_UPLO Uplo, enum CBLAS_TRANSPOSE TransA, - enum CBLAS_DIAG Diag, int M, int N, - double alpha, double *A, int lda, - double *B, int ldb); -void cblas_dtrsm( enum CBLAS_ORDER Order, enum CBLAS_SIDE Side, - enum CBLAS_UPLO Uplo, enum CBLAS_TRANSPOSE TransA, - enum CBLAS_DIAG Diag, int M, int N, - double alpha, double *A, int lda, - double *B, int ldb); - -void cblas_cgemm( enum CBLAS_ORDER Order, enum CBLAS_TRANSPOSE TransA, - enum CBLAS_TRANSPOSE TransB, int M, int N, - int K, void *alpha, void *A, - int lda, void *B, int ldb, - void *beta, void *C, int ldc); -void cblas_csymm( enum CBLAS_ORDER Order, enum CBLAS_SIDE Side, - enum CBLAS_UPLO Uplo, int M, int N, - void *alpha, void *A, int lda, - void *B, int ldb, void *beta, - void *C, int ldc); -void cblas_csyrk( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - enum CBLAS_TRANSPOSE Trans, int N, int K, - void *alpha, void *A, int lda, - void *beta, void *C, int ldc); -void cblas_csyr2k( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - enum CBLAS_TRANSPOSE Trans, int N, int K, - void *alpha, void *A, int lda, - void *B, int ldb, void *beta, - void *C, int ldc); -void cblas_ctrmm( enum CBLAS_ORDER Order, enum CBLAS_SIDE Side, - enum CBLAS_UPLO Uplo, enum CBLAS_TRANSPOSE TransA, - enum CBLAS_DIAG Diag, int M, int N, - void *alpha, void *A, int lda, - void *B, int ldb); -void cblas_ctrsm( enum CBLAS_ORDER Order, enum CBLAS_SIDE Side, - enum CBLAS_UPLO Uplo, enum CBLAS_TRANSPOSE TransA, - enum CBLAS_DIAG Diag, int M, int N, - void *alpha, void *A, int lda, - void *B, int ldb); - -void cblas_zgemm( enum CBLAS_ORDER Order, enum CBLAS_TRANSPOSE TransA, - enum CBLAS_TRANSPOSE TransB, int M, int N, - int K, void *alpha, void *A, - int lda, void *B, int ldb, - void *beta, void *C, int ldc); -void cblas_zsymm( enum CBLAS_ORDER Order, enum CBLAS_SIDE Side, - enum CBLAS_UPLO Uplo, int M, int N, - void *alpha, void *A, int lda, - void *B, int ldb, void *beta, - void *C, int ldc); -void cblas_zsyrk( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - enum CBLAS_TRANSPOSE Trans, int N, int K, - void *alpha, void *A, int lda, - void *beta, void *C, int ldc); -void cblas_zsyr2k( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - enum CBLAS_TRANSPOSE Trans, int N, int K, - void *alpha, void *A, int lda, - void *B, int ldb, void *beta, - void *C, int ldc); -void cblas_ztrmm( enum CBLAS_ORDER Order, enum CBLAS_SIDE Side, - enum CBLAS_UPLO Uplo, enum CBLAS_TRANSPOSE TransA, - enum CBLAS_DIAG Diag, int M, int N, - void *alpha, void *A, int lda, - void *B, int ldb); -void cblas_ztrsm( enum CBLAS_ORDER Order, enum CBLAS_SIDE Side, - enum CBLAS_UPLO Uplo, enum CBLAS_TRANSPOSE TransA, - enum CBLAS_DIAG Diag, int M, int N, - void *alpha, void *A, int lda, - void *B, int ldb); - +void cblas_dtrmm(enum CBLAS_ORDER Order, enum CBLAS_SIDE Side, + enum CBLAS_UPLO Uplo, enum CBLAS_TRANSPOSE TransA, + enum CBLAS_DIAG Diag, int M, int N, double alpha, double *A, + int lda, double *B, int ldb); +void cblas_dtrsm(enum CBLAS_ORDER Order, enum CBLAS_SIDE Side, + enum CBLAS_UPLO Uplo, enum CBLAS_TRANSPOSE TransA, + enum CBLAS_DIAG Diag, int M, int N, double alpha, double *A, + int lda, double *B, int ldb); + +void cblas_cgemm(enum CBLAS_ORDER Order, enum CBLAS_TRANSPOSE TransA, + enum CBLAS_TRANSPOSE TransB, int M, int N, int K, void *alpha, + void *A, int lda, void *B, int ldb, void *beta, void *C, + int ldc); +void cblas_csymm(enum CBLAS_ORDER Order, enum CBLAS_SIDE Side, + enum CBLAS_UPLO Uplo, int M, int N, void *alpha, void *A, + int lda, void *B, int ldb, void *beta, void *C, int ldc); +void cblas_csyrk(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, + enum CBLAS_TRANSPOSE Trans, int N, int K, void *alpha, void *A, + int lda, void *beta, void *C, int ldc); +void cblas_csyr2k(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, + enum CBLAS_TRANSPOSE Trans, int N, int K, void *alpha, + void *A, int lda, void *B, int ldb, void *beta, void *C, + int ldc); +void cblas_ctrmm(enum CBLAS_ORDER Order, enum CBLAS_SIDE Side, + enum CBLAS_UPLO Uplo, enum CBLAS_TRANSPOSE TransA, + enum CBLAS_DIAG Diag, int M, int N, void *alpha, void *A, + int lda, void *B, int ldb); +void cblas_ctrsm(enum CBLAS_ORDER Order, enum CBLAS_SIDE Side, + enum CBLAS_UPLO Uplo, enum CBLAS_TRANSPOSE TransA, + enum CBLAS_DIAG Diag, int M, int N, void *alpha, void *A, + int lda, void *B, int ldb); + +void cblas_zgemm(enum CBLAS_ORDER Order, enum CBLAS_TRANSPOSE TransA, + enum CBLAS_TRANSPOSE TransB, int M, int N, int K, void *alpha, + void *A, int lda, void *B, int ldb, void *beta, void *C, + int ldc); +void cblas_zsymm(enum CBLAS_ORDER Order, enum CBLAS_SIDE Side, + enum CBLAS_UPLO Uplo, int M, int N, void *alpha, void *A, + int lda, void *B, int ldb, void *beta, void *C, int ldc); +void cblas_zsyrk(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, + enum CBLAS_TRANSPOSE Trans, int N, int K, void *alpha, void *A, + int lda, void *beta, void *C, int ldc); +void cblas_zsyr2k(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, + enum CBLAS_TRANSPOSE Trans, int N, int K, void *alpha, + void *A, int lda, void *B, int ldb, void *beta, void *C, + int ldc); +void cblas_ztrmm(enum CBLAS_ORDER Order, enum CBLAS_SIDE Side, + enum CBLAS_UPLO Uplo, enum CBLAS_TRANSPOSE TransA, + enum CBLAS_DIAG Diag, int M, int N, void *alpha, void *A, + int lda, void *B, int ldb); +void cblas_ztrsm(enum CBLAS_ORDER Order, enum CBLAS_SIDE Side, + enum CBLAS_UPLO Uplo, enum CBLAS_TRANSPOSE TransA, + enum CBLAS_DIAG Diag, int M, int N, void *alpha, void *A, + int lda, void *B, int ldb); /* * Routines with prefixes C and Z only */ -void cblas_chemm( enum CBLAS_ORDER Order, enum CBLAS_SIDE Side, - enum CBLAS_UPLO Uplo, int M, int N, - void *alpha, void *A, int lda, - void *B, int ldb, void *beta, - void *C, int ldc); -void cblas_cherk( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - enum CBLAS_TRANSPOSE Trans, int N, int K, - float alpha, void *A, int lda, - float beta, void *C, int ldc); -void cblas_cher2k( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - enum CBLAS_TRANSPOSE Trans, int N, int K, - void *alpha, void *A, int lda, - void *B, int ldb, float beta, - void *C, int ldc); -void cblas_zhemm( enum CBLAS_ORDER Order, enum CBLAS_SIDE Side, - enum CBLAS_UPLO Uplo, int M, int N, - void *alpha, void *A, int lda, - void *B, int ldb, void *beta, - void *C, int ldc); -void cblas_zherk( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - enum CBLAS_TRANSPOSE Trans, int N, int K, - double alpha, void *A, int lda, - double beta, void *C, int ldc); -void cblas_zher2k( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - enum CBLAS_TRANSPOSE Trans, int N, int K, - void *alpha, void *A, int lda, - void *B, int ldb, double beta, - void *C, int ldc); +void cblas_chemm(enum CBLAS_ORDER Order, enum CBLAS_SIDE Side, + enum CBLAS_UPLO Uplo, int M, int N, void *alpha, void *A, + int lda, void *B, int ldb, void *beta, void *C, int ldc); +void cblas_cherk(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, + enum CBLAS_TRANSPOSE Trans, int N, int K, float alpha, void *A, + int lda, float beta, void *C, int ldc); +void cblas_cher2k(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, + enum CBLAS_TRANSPOSE Trans, int N, int K, void *alpha, + void *A, int lda, void *B, int ldb, float beta, void *C, + int ldc); +void cblas_zhemm(enum CBLAS_ORDER Order, enum CBLAS_SIDE Side, + enum CBLAS_UPLO Uplo, int M, int N, void *alpha, void *A, + int lda, void *B, int ldb, void *beta, void *C, int ldc); +void cblas_zherk(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, + enum CBLAS_TRANSPOSE Trans, int N, int K, double alpha, + void *A, int lda, double beta, void *C, int ldc); +void cblas_zher2k(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, + enum CBLAS_TRANSPOSE Trans, int N, int K, void *alpha, + void *A, int lda, void *B, int ldb, double beta, void *C, + int ldc); int cblas_errprn(int ierr, int info, char *form, ...); #ifdef __cplusplus } #endif -#endif /* end #ifdef CBLAS_ENUM_ONLY */ +#endif /* end #ifdef CBLAS_ENUM_ONLY */ #endif -#endif //NATIVEOPERATIONS_CBLAS_H +#endif // NATIVEOPERATIONS_CBLAS_H diff --git a/libnd4j/include/cblas_enum_conversion.h b/libnd4j/include/cblas_enum_conversion.h old mode 100755 new mode 100644 index 6ff6fe557f96..987b5e80557a --- a/libnd4j/include/cblas_enum_conversion.h +++ b/libnd4j/include/cblas_enum_conversion.h @@ -78,5 +78,4 @@ CBLAS_SIDE convertSide(int from); } #endif - -#endif //NATIVEOPERATIONS_CBLAS_ENUM_CONVERSION_H_H +#endif // NATIVEOPERATIONS_CBLAS_ENUM_CONVERSION_H_H diff --git a/libnd4j/include/cnpy/cnpy.h b/libnd4j/include/cnpy/cnpy.h index c84623599791..5ac2985027d6 100644 --- a/libnd4j/include/cnpy/cnpy.h +++ b/libnd4j/include/cnpy/cnpy.h @@ -22,270 +22,251 @@ * THE SOFTWARE. ******************************************************************************/ -//Copyright (C) 2011 Carl Rogers -//Released under MIT License -//license available in LICENSE file, or at http://www.opensource.org/licenses/mit-license.php +// Copyright (C) 2011 Carl Rogers +// Released under MIT License +// license available in LICENSE file, or at +// http://www.opensource.org/licenses/mit-license.php #ifndef LIBCNPY_H_ #define LIBCNPY_H_ - /** * */ -#include -#include -#include -#include -#include -#include +#include #include -#include -#include -#include -#include -#include +#include +#include -#include +#include +#include +#include +#include #include +#include +#include #include -#include -#include -#include - +#include +#include +#include +#include namespace cnpy { - /** - * The numpy array - */ - struct ND4J_EXPORT NpyArray { - char* data; - std::vector shape; - unsigned int wordSize; - bool fortranOrder; - void destruct() { - delete[] data; - } - }; - - struct ND4J_EXPORT npz_t : public std::map { - void destruct() { - npz_t::iterator it = this->begin(); - for(; it != this->end(); ++it) (*it).second.destruct(); - } - }; - - /** - * - * @param path - * @return - */ - ND4J_EXPORT char* loadFile(const char *path); - - - - /** - * - * @return - */ - char BigEndianTest(); - +/** + * The numpy array + */ +struct SD_EXPORT NpyArray { + char *data; + std::vector shape; + unsigned int wordSize; + bool fortranOrder; + void destruct() { delete[] data; } +}; + +struct SD_EXPORT npz_t : public std::map { + void destruct() { + npz_t::iterator it = this->begin(); + for (; it != this->end(); ++it) (*it).second.destruct(); + } +}; - /** - * - * @param t - * @return - */ - ND4J_EXPORT char mapType(const std::type_info &t); +/** + * + * @param path + * @return + */ +SD_EXPORT char *loadFile(const char *path); - template - ND4J_EXPORT char mapType(); +/** + * + * @return + */ +char BigEndianTest(); - /** - * - * @param T the type of the ndarray - * @param data the data for the ndarray - * @param shape the shape of the ndarray - * @param ndims the rank of the ndarray - * @return - */ - template - ND4J_EXPORT std::vector createNpyHeader(const void *data, - const unsigned int *shape, - const unsigned int ndims, - unsigned int wordSize = 4); - /** - * Parse the numpy header from - * the given file - * based on the pointers passed in - * @param fp the file to parse from - * @param wordSize the size of - * the individual elements - * @param shape - * @param ndims - * @param fortranOrder - */ - ND4J_EXPORT void parseNpyHeader(FILE *fp, - unsigned int &wordSize, - unsigned int *&shape, - unsigned int &ndims, - bool &fortranOrder); +/** + * + * @param t + * @return + */ +SD_EXPORT char mapType(const std::type_info &t); - /** - * Parse the numpy header from - * the given file - * based on the pointers passed in - * @param header the file to parse from - * @param word_size the size of - * the individual elements - * @param shape - * @param ndims - * @param fortran_order - */ - ND4J_EXPORT void parseNpyHeaderPointer( - const char *header, - unsigned int& word_size, - unsigned int*& shape, - unsigned int& ndims, - bool& fortran_order); - /** - * - * @param fp - * @param nrecs - * @param global_header_size - * @param global_header_offset - */ - ND4J_EXPORT void parseZipFooter(FILE *fp, - unsigned short &nrecs, - unsigned int &global_header_size, - unsigned int &global_header_offset); +template +SD_EXPORT char mapType(); - /** - * - * @param fname - * @param varname - * @return - */ - ND4J_EXPORT NpyArray npzLoad(std::string fname, std::string varname); +/** + * + * @param T the type of the ndarray + * @param data the data for the ndarray + * @param shape the shape of the ndarray + * @param ndims the rank of the ndarray + * @return + */ +template +SD_EXPORT std::vector createNpyHeader(const void *data, + const unsigned int *shape, + const unsigned int ndims, + unsigned int wordSize = 4); +/** + * Parse the numpy header from + * the given file + * based on the pointers passed in + * @param fp the file to parse from + * @param wordSize the size of + * the individual elements + * @param shape + * @param ndims + * @param fortranOrder + */ +SD_EXPORT void parseNpyHeader(FILE *fp, unsigned int &wordSize, + unsigned int *&shape, unsigned int &ndims, + bool &fortranOrder); - /** - * - * @param fname - * @return - */ - ND4J_EXPORT NpyArray npyLoad(std::string fname); +/** + * Parse the numpy header from + * the given file + * based on the pointers passed in + * @param header the file to parse from + * @param word_size the size of + * the individual elements + * @param shape + * @param ndims + * @param fortran_order + */ +SD_EXPORT void parseNpyHeaderPointer(const char *header, + unsigned int &word_size, + unsigned int *&shape, unsigned int &ndims, + bool &fortran_order); +/** + * + * @param fp + * @param nrecs + * @param global_header_size + * @param global_header_offset + */ +SD_EXPORT void parseZipFooter(FILE *fp, unsigned short &nrecs, + unsigned int &global_header_size, + unsigned int &global_header_offset); - /** - * Parse the numpy header from - * the given file - * based on the pointers passed in - * @param fp the file to parse from - * @param wordSize the size of - * the individual elements - * @param shape - * @param ndims - * @param fortranOrder - */ - ND4J_EXPORT void parseNpyHeaderStr(std::string header, - unsigned int &wordSize, - unsigned int *&shape, - unsigned int &ndims, - bool &fortranOrder); +/** + * + * @param fname + * @param varname + * @return + */ +SD_EXPORT NpyArray npzLoad(std::string fname, std::string varname); +/** + * + * @param fname + * @return + */ +SD_EXPORT NpyArray npyLoad(std::string fname); - /** - * - * @param fp - * @return - */ - ND4J_EXPORT int* shapeFromFile(FILE *fp); +/** + * Parse the numpy header from + * the given file + * based on the pointers passed in + * @param fp the file to parse from + * @param wordSize the size of + * the individual elements + * @param shape + * @param ndims + * @param fortranOrder + */ +SD_EXPORT void parseNpyHeaderStr(std::string header, unsigned int &wordSize, + unsigned int *&shape, unsigned int &ndims, + bool &fortranOrder); - /** - * - * @param data - * @return - */ - ND4J_EXPORT int* shapeFromPointer(char *data); +/** + * + * @param fp + * @return + */ +SD_EXPORT int *shapeFromFile(FILE *fp); - /** - * Load the numpy array from the given file. - * @param fp the file to load - * @return the loaded array - */ - ND4J_EXPORT NpyArray loadNpyFromFile(FILE *fp); +/** + * + * @param data + * @return + */ +SD_EXPORT int *shapeFromPointer(char *data); - /** - * Load the numpy array archive from the given file. - * @param fp the file to load - * @return the loaded archive - */ - ND4J_EXPORT npz_t npzLoad(FILE* fp); - /** - * - * @param data - * @return - */ - ND4J_EXPORT NpyArray loadNpyFromPointer(char *data); +/** + * Load the numpy array from the given file. + * @param fp the file to load + * @return the loaded array + */ +SD_EXPORT NpyArray loadNpyFromFile(FILE *fp); - /** - * - * @param data - * @return - */ - ND4J_EXPORT NpyArray loadNpyFromHeader(char *data); +/** + * Load the numpy array archive from the given file. + * @param fp the file to load + * @return the loaded archive + */ +SD_EXPORT npz_t npzLoad(FILE *fp); +/** + * + * @param data + * @return + */ +SD_EXPORT NpyArray loadNpyFromPointer(char *data); +/** + * + * @param data + * @return + */ +SD_EXPORT NpyArray loadNpyFromHeader(char *data); - ND4J_EXPORT npz_t npzLoad(std::string fname); +SD_EXPORT npz_t npzLoad(std::string fname); - ND4J_EXPORT sd::DataType dataTypeFromHeader(char *data); +SD_EXPORT sd::DataType dataTypeFromHeader(char *data); /** -* Parse the numpy header from -* the given file -* based on the pointers passed in -* @param fp the file to parse from -* @param word_size the size of -* the individual elements -* @param shape -* @param ndims -* @param fortran_order -*/ - ND4J_EXPORT void parseNpyHeader(std::string header, - unsigned int &word_size, - unsigned int *&shape, - unsigned int &ndims, - bool &fortran_order); - - /** - * - * @tparam T - * @param i - * @param pad - * @param padval - * @return - */ - template - FORCEINLINE std::string tostring(T i, int pad = 0, char padval = ' ') { - std::stringstream s; - s << i; - return s.str(); - } + * Parse the numpy header from + * the given file + * based on the pointers passed in + * @param fp the file to parse from + * @param word_size the size of + * the individual elements + * @param shape + * @param ndims + * @param fortran_order + */ +SD_EXPORT void parseNpyHeader(std::string header, unsigned int &word_size, + unsigned int *&shape, unsigned int &ndims, + bool &fortran_order); +/** + * + * @tparam T + * @param i + * @param pad + * @param padval + * @return + */ +template +FORCEINLINE std::string tostring(T i, int pad = 0, char padval = ' ') { + std::stringstream s; + s << i; + return s.str(); +} - template - ND4J_EXPORT void npy_save(std::string fname, const T* data, const unsigned int* shape, const unsigned int ndims, std::string mode = "w"); +template +SD_EXPORT void npy_save(std::string fname, const T *data, + const unsigned int *shape, const unsigned int ndims, + std::string mode = "w"); -} +} // namespace cnpy /** - * - * @tparam T - * @param lhs - * @param rhs - * @return - */ - template - ND4J_EXPORT std::vector& operator+=(std::vector& lhs, const T rhs); - + * + * @tparam T + * @param lhs + * @param rhs + * @return + */ +template +SD_EXPORT std::vector &operator+=(std::vector &lhs, const T rhs); #endif diff --git a/libnd4j/include/exceptions/allocation_exception.h b/libnd4j/include/exceptions/allocation_exception.h index 1e9b6653b550..beec84e40d3d 100644 --- a/libnd4j/include/exceptions/allocation_exception.h +++ b/libnd4j/include/exceptions/allocation_exception.h @@ -18,31 +18,33 @@ // @author raver119@gmail.com // -#ifndef DEV_TESTS_ALLOCATION_EXCEPTION_H -#define DEV_TESTS_ALLOCATION_EXCEPTION_H +#ifndef SD_ALLOCATION_EXCEPTION_H +#define SD_ALLOCATION_EXCEPTION_H -#include -#include -#include #include +#include + +#include +#include #if defined(_MSC_VER) -// we're ignoring warning about non-exportable parent class, since std::runtime_error is a part of Standard C++ Library -#pragma warning( disable : 4275 ) +// we're ignoring warning about non-exportable parent class, since +// std::runtime_error is a part of Standard C++ Library +#pragma warning(disable : 4275) #endif namespace sd { - class ND4J_EXPORT allocation_exception : public std::runtime_error { - public: - allocation_exception(std::string message); - ~allocation_exception() = default; - - static allocation_exception build(std::string message, Nd4jLong bytes); - static allocation_exception build(std::string message, Nd4jLong limit, Nd4jLong bytes); - }; -} - - -#endif //DEV_TESTS_ALLOCATION_EXCEPTION_H +class SD_EXPORT allocation_exception : public std::runtime_error { + public: + allocation_exception(std::string message); + ~allocation_exception() = default; + + static allocation_exception build(std::string message, Nd4jLong bytes); + static allocation_exception build(std::string message, Nd4jLong limit, + Nd4jLong bytes); +}; +} // namespace sd + +#endif // SD_ALLOCATION_EXCEPTION_H diff --git a/libnd4j/include/exceptions/cuda_exception.h b/libnd4j/include/exceptions/cuda_exception.h index 2dc98eec3dff..6c8f288d3ee6 100644 --- a/libnd4j/include/exceptions/cuda_exception.h +++ b/libnd4j/include/exceptions/cuda_exception.h @@ -18,30 +18,30 @@ // @author raver119@gmail.com // -#ifndef DEV_TESTS_CUDA_EXCEPTION_H -#define DEV_TESTS_CUDA_EXCEPTION_H +#ifndef SD_CUDA_EXCEPTION_H +#define SD_CUDA_EXCEPTION_H -#include -#include #include +#include +#include + #if defined(_MSC_VER) -// we're ignoring warning about non-exportable parent class, since std::runtime_error is a part of Standard C++ Library -#pragma warning( disable : 4275 ) +// we're ignoring warning about non-exportable parent class, since +// std::runtime_error is a part of Standard C++ Library +#pragma warning(disable : 4275) #endif namespace sd { - class ND4J_EXPORT cuda_exception : public std::runtime_error { - public: - cuda_exception(std::string message); - ~cuda_exception() = default; - - static cuda_exception build(std::string message, int errorCode); - }; -} - +class SD_EXPORT cuda_exception : public std::runtime_error { + public: + cuda_exception(std::string message); + ~cuda_exception() = default; + static cuda_exception build(std::string message, int errorCode); +}; +} // namespace sd -#endif //DEV_TESTS_CUDA_EXCEPTION_H +#endif // SD_CUDA_EXCEPTION_H diff --git a/libnd4j/include/exceptions/datatype_exception.h b/libnd4j/include/exceptions/datatype_exception.h index 74829d54c647..b3effeaf4e31 100644 --- a/libnd4j/include/exceptions/datatype_exception.h +++ b/libnd4j/include/exceptions/datatype_exception.h @@ -18,33 +18,37 @@ // Created by raver on 11/26/2018. // -#ifndef DEV_TESTS_DATATYPE_EXCEPTION_H -#define DEV_TESTS_DATATYPE_EXCEPTION_H +#ifndef SD_DATATYPE_EXCEPTION_H +#define SD_DATATYPE_EXCEPTION_H -#include -#include #include #include +#include +#include + #if defined(_MSC_VER) -// we're ignoring warning about non-exportable parent class, since std::runtime_error is a part of Standard C++ Library -#pragma warning( disable : 4275 ) +// we're ignoring warning about non-exportable parent class, since +// std::runtime_error is a part of Standard C++ Library +#pragma warning(disable : 4275) #endif namespace sd { - class ND4J_EXPORT datatype_exception : public std::runtime_error { - public: - datatype_exception(std::string message); - ~datatype_exception() = default; - - - static datatype_exception build(std::string message, sd::DataType actual); - static datatype_exception build(std::string message, sd::DataType expected, sd::DataType actual); - static datatype_exception build(std::string message, sd::DataType expected, sd::DataType actualX, sd::DataType actualY); - }; -} - - -#endif //DEV_TESTS_DATATYPE_EXCEPTION_H +class SD_EXPORT datatype_exception : public std::runtime_error { + public: + datatype_exception(const std::string &message); + ~datatype_exception() = default; + + static datatype_exception build(const std::string &message, + sd::DataType actual); + static datatype_exception build(const std::string &message, + sd::DataType expected, sd::DataType actual); + static datatype_exception build(const std::string &message, + sd::DataType expected, sd::DataType actualX, + sd::DataType actualY); +}; +} // namespace sd + +#endif // SD_DATATYPE_EXCEPTION_H diff --git a/libnd4j/include/exceptions/graph_exception.h b/libnd4j/include/exceptions/graph_exception.h index 7c9345a4deed..66bb5aad4792 100644 --- a/libnd4j/include/exceptions/graph_exception.h +++ b/libnd4j/include/exceptions/graph_exception.h @@ -21,37 +21,40 @@ #ifndef LIBND4J_GRAPH_EXCEPTION_H #define LIBND4J_GRAPH_EXCEPTION_H -#include -#include -#include #include +#include + +#include +#include #if defined(_MSC_VER) -// we're ignoring warning about non-exportable parent class, since std::runtime_error is a part of Standard C++ Library -#pragma warning( disable : 4275 ) +// we're ignoring warning about non-exportable parent class, since +// std::runtime_error is a part of Standard C++ Library +#pragma warning(disable : 4275) #endif namespace sd { - class ND4J_EXPORT graph_exception : public std::runtime_error { - protected: - Nd4jLong _graphId; - std::string _message; - std::string _description; - public: - graph_exception(std::string message, Nd4jLong graphId); - graph_exception(std::string message, std::string description, Nd4jLong graphId); - graph_exception(std::string message, const char *description, Nd4jLong graphId); - ~graph_exception() = default; - - Nd4jLong graphId(); - - const char * message(); - const char * description(); - }; -} - - - -#endif //DEV_TESTS_GRAPH_EXCEPTION_H +class SD_EXPORT graph_exception : public std::runtime_error { + protected: + Nd4jLong _graphId; + std::string _message; + std::string _description; + + public: + graph_exception(std::string message, Nd4jLong graphId); + graph_exception(std::string message, std::string description, + Nd4jLong graphId); + graph_exception(std::string message, const char *description, + Nd4jLong graphId); + ~graph_exception() = default; + + Nd4jLong graphId(); + + const char *message(); + const char *description(); +}; +} // namespace sd + +#endif // SD_GRAPH_EXCEPTION_H diff --git a/libnd4j/include/exceptions/graph_execution_exception.h b/libnd4j/include/exceptions/graph_execution_exception.h index 37f8e636e878..2b9ce263af12 100644 --- a/libnd4j/include/exceptions/graph_execution_exception.h +++ b/libnd4j/include/exceptions/graph_execution_exception.h @@ -18,27 +18,31 @@ // Created by raver on 8/31/2018. // -#ifndef DEV_TESTS_GRAPH_EXECUTION_EXCEPTION_H -#define DEV_TESTS_GRAPH_EXECUTION_EXCEPTION_H +#ifndef SD_GRAPH_EXECUTION_EXCEPTION_H +#define SD_GRAPH_EXECUTION_EXCEPTION_H +#include +#include #include #include + #include -#include -#include #if defined(_MSC_VER) -// we're ignoring warning about non-exportable parent class, since std::runtime_error is a part of Standard C++ Library -#pragma warning( disable : 4275 ) +// we're ignoring warning about non-exportable parent class, since +// std::runtime_error is a part of Standard C++ Library +#pragma warning(disable : 4275) #endif namespace sd { - class ND4J_EXPORT graph_execution_exception: public graph_exception { - public: - explicit graph_execution_exception(Nd4jLong graphId); - }; -} - -#endif //DEV_TESTS_UNKNOWN_GRAPH_EXCEPTION_H +class SD_EXPORT graph_execution_exception : public graph_exception { + public: + explicit graph_execution_exception(Nd4jLong graphId); + explicit graph_execution_exception(const std::string &message, + Nd4jStatus status); +}; +} // namespace sd + +#endif // SD_UNKNOWN_GRAPH_EXCEPTION_H diff --git a/libnd4j/include/exceptions/graph_exists_exception.h b/libnd4j/include/exceptions/graph_exists_exception.h index 63554c31b386..0023a49fa0a3 100644 --- a/libnd4j/include/exceptions/graph_exists_exception.h +++ b/libnd4j/include/exceptions/graph_exists_exception.h @@ -18,27 +18,29 @@ // Created by raver on 8/31/2018. // -#ifndef DEV_TESTS_GRAPH_EXISTS_EXCEPTION_H -#define DEV_TESTS_GRAPH_EXISTS_EXCEPTION_H +#ifndef SD_GRAPH_EXISTS_EXCEPTION_H +#define SD_GRAPH_EXISTS_EXCEPTION_H +#include +#include #include #include + #include -#include -#include #if defined(_MSC_VER) -// we're ignoring warning about non-exportable parent class, since std::runtime_error is a part of Standard C++ Library -#pragma warning( disable : 4275 ) +// we're ignoring warning about non-exportable parent class, since +// std::runtime_error is a part of Standard C++ Library +#pragma warning(disable : 4275) #endif namespace sd { - class ND4J_EXPORT graph_exists_exception: public graph_exception { - public: - explicit graph_exists_exception(Nd4jLong graphId); - }; -} +class SD_EXPORT graph_exists_exception : public graph_exception { + public: + explicit graph_exists_exception(Nd4jLong graphId); +}; +} // namespace sd -#endif //DEV_TESTS_UNKNOWN_GRAPH_EXCEPTION_H +#endif // SD_UNKNOWN_GRAPH_EXCEPTION_H diff --git a/libnd4j/include/exceptions/impl/allocation_exception.cpp b/libnd4j/include/exceptions/impl/allocation_exception.cpp index 46f2ef5c86ae..f29f4d0d07f7 100644 --- a/libnd4j/include/exceptions/impl/allocation_exception.cpp +++ b/libnd4j/include/exceptions/impl/allocation_exception.cpp @@ -22,20 +22,24 @@ #include namespace sd { - allocation_exception::allocation_exception(std::string message) : std::runtime_error(message){ - // - } +allocation_exception::allocation_exception(std::string message) + : std::runtime_error(message) { + // +} - allocation_exception allocation_exception::build(std::string message, Nd4jLong numBytes) { - auto bytes = StringUtils::valueToString(numBytes); - message += "; Requested bytes: [" + bytes + "]"; - return allocation_exception(message); - } +allocation_exception allocation_exception::build(std::string message, + Nd4jLong numBytes) { + auto bytes = StringUtils::valueToString(numBytes); + message += "; Requested bytes: [" + bytes + "]"; + return allocation_exception(message); +} - allocation_exception allocation_exception::build(std::string message, Nd4jLong limit, Nd4jLong numBytes) { - auto bytes = StringUtils::valueToString(numBytes); - auto lim = StringUtils::valueToString(limit); - message += "; Limit bytes: [" + lim + "]; Requested bytes: [" + bytes + "]"; - return allocation_exception(message); - } -} \ No newline at end of file +allocation_exception allocation_exception::build(std::string message, + Nd4jLong limit, + Nd4jLong numBytes) { + auto bytes = StringUtils::valueToString(numBytes); + auto lim = StringUtils::valueToString(limit); + message += "; Limit bytes: [" + lim + "]; Requested bytes: [" + bytes + "]"; + return allocation_exception(message); +} +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/exceptions/impl/cuda_exception.cpp b/libnd4j/include/exceptions/impl/cuda_exception.cpp index 91de6c2516a5..93854168a338 100644 --- a/libnd4j/include/exceptions/impl/cuda_exception.cpp +++ b/libnd4j/include/exceptions/impl/cuda_exception.cpp @@ -22,13 +22,14 @@ #include namespace sd { - cuda_exception::cuda_exception(std::string message) : std::runtime_error(message){ - // - } +cuda_exception::cuda_exception(std::string message) + : std::runtime_error(message) { + // +} - cuda_exception cuda_exception::build(std::string message, int errorCode) { - auto ec = StringUtils::valueToString(errorCode); - message += "; Error code: [" + ec + "]"; - return cuda_exception(message); - } -} \ No newline at end of file +cuda_exception cuda_exception::build(std::string message, int errorCode) { + auto ec = StringUtils::valueToString(errorCode); + message += "; Error code: [" + ec + "]"; + return cuda_exception(message); +} +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/exceptions/impl/datatype_exception.cpp b/libnd4j/include/exceptions/impl/datatype_exception.cpp index 9aab37951032..4b7e36010291 100644 --- a/libnd4j/include/exceptions/impl/datatype_exception.cpp +++ b/libnd4j/include/exceptions/impl/datatype_exception.cpp @@ -22,28 +22,36 @@ #include namespace sd { - datatype_exception::datatype_exception(std::string message) : std::runtime_error(message){ - // - } +datatype_exception::datatype_exception(const std::string &message) + : std::runtime_error(message) { + // +} - datatype_exception datatype_exception::build(std::string message, sd::DataType expected, sd::DataType actual) { - auto exp = DataTypeUtils::asString(expected); - auto act = DataTypeUtils::asString(actual); - message += "; Expected: [" + exp + "]; Actual: [" + act + "]"; - return datatype_exception(message); - } +datatype_exception datatype_exception::build(const std::string &message, + sd::DataType expected, + sd::DataType actual) { + auto exp = DataTypeUtils::asString(expected); + auto act = DataTypeUtils::asString(actual); + auto fmessage = message + "; Expected: [" + exp + "]; Actual: [" + act + "]"; + return datatype_exception(fmessage); +} - datatype_exception datatype_exception::build(std::string message, sd::DataType expected, sd::DataType actualX, sd::DataType actualY) { - auto exp = DataTypeUtils::asString(expected); - auto actX = DataTypeUtils::asString(actualX); - auto actY = DataTypeUtils::asString(actualY); - message += "; Expected: [" + exp + "]; Actual: [" + actX + ", " + actY + "]"; - return datatype_exception(message); - } +datatype_exception datatype_exception::build(const std::string &message, + sd::DataType expected, + sd::DataType actualX, + sd::DataType actualY) { + auto exp = DataTypeUtils::asString(expected); + auto actX = DataTypeUtils::asString(actualX); + auto actY = DataTypeUtils::asString(actualY); + auto fmessage = message + "; Expected: [" + exp + "]; Actual: [" + actX + + ", " + actY + "]"; + return datatype_exception(fmessage); +} - datatype_exception datatype_exception::build(std::string message, sd::DataType actual) { - auto act = DataTypeUtils::asString(actual); - message += "; Actual: [" + act + "]"; - return datatype_exception(message); - } -} \ No newline at end of file +datatype_exception datatype_exception::build(const std::string &message, + sd::DataType actual) { + auto act = DataTypeUtils::asString(actual); + auto fmessage = message + "; Actual: [" + act + "]"; + return datatype_exception(fmessage); +} +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/exceptions/impl/graph_exception.cpp b/libnd4j/include/exceptions/impl/graph_exception.cpp index fa2210a1d729..bb289ad54de4 100644 --- a/libnd4j/include/exceptions/impl/graph_exception.cpp +++ b/libnd4j/include/exceptions/impl/graph_exception.cpp @@ -22,33 +22,31 @@ #include namespace sd { - graph_exception::graph_exception(std::string message, Nd4jLong graphId) : std::runtime_error(message) { - this->_message = message; - this->_graphId = graphId; - } - - graph_exception::graph_exception(std::string message, std::string description, Nd4jLong graphId) : std::runtime_error(message) { - this->_message = message; - this->_description = description; - this->_graphId = graphId; - } - - graph_exception::graph_exception(std::string message, const char *description, Nd4jLong graphId) : std::runtime_error(message) { - this->_message = message; - this->_description = description; - this->_graphId = graphId; - } - - - Nd4jLong graph_exception::graphId() { - return _graphId; - } - - const char* graph_exception::message() { - return _message.c_str(); - } - - const char* graph_exception::description() { - return _description.c_str(); - } +graph_exception::graph_exception(std::string message, Nd4jLong graphId) + : std::runtime_error(message) { + this->_message = message; + this->_graphId = graphId; } + +graph_exception::graph_exception(std::string message, std::string description, + Nd4jLong graphId) + : std::runtime_error(message) { + this->_message = message; + this->_description = description; + this->_graphId = graphId; +} + +graph_exception::graph_exception(std::string message, const char* description, + Nd4jLong graphId) + : std::runtime_error(message) { + this->_message = message; + this->_description = description; + this->_graphId = graphId; +} + +Nd4jLong graph_exception::graphId() { return _graphId; } + +const char* graph_exception::message() { return _message.c_str(); } + +const char* graph_exception::description() { return _description.c_str(); } +} // namespace sd diff --git a/libnd4j/include/exceptions/impl/graph_execution_exception.cpp b/libnd4j/include/exceptions/impl/graph_execution_exception.cpp index 086796517f7c..340340f34398 100644 --- a/libnd4j/include/exceptions/impl/graph_execution_exception.cpp +++ b/libnd4j/include/exceptions/impl/graph_execution_exception.cpp @@ -18,11 +18,20 @@ // Created by raver on 8/31/2018. // -#include #include +#include namespace sd { - graph_execution_exception::graph_execution_exception(Nd4jLong graphId) : graph_exception(StringUtils::buildGraphErrorMessage("Caught exception during graph execution", graphId), graphId) { - _graphId = graphId; - } +graph_execution_exception::graph_execution_exception(Nd4jLong graphId) + : graph_exception(StringUtils::buildGraphErrorMessage( + "Caught exception during graph execution", graphId), + graphId) { + _graphId = graphId; +} + +graph_execution_exception::graph_execution_exception(const std::string &message, + Nd4jStatus status) + : graph_exception(message, status) { + // } +} // namespace sd diff --git a/libnd4j/include/exceptions/impl/graph_exists_exception.cpp b/libnd4j/include/exceptions/impl/graph_exists_exception.cpp index 535a74a6a0fe..8889aed2a628 100644 --- a/libnd4j/include/exceptions/impl/graph_exists_exception.cpp +++ b/libnd4j/include/exceptions/impl/graph_exists_exception.cpp @@ -18,11 +18,14 @@ // Created by raver on 8/31/2018. // -#include #include +#include namespace sd { - graph_exists_exception::graph_exists_exception(Nd4jLong graphId) : graph_exception(StringUtils::buildGraphErrorMessage("Graph with given ID already exists", graphId), graphId) { - _graphId = graphId; - } +graph_exists_exception::graph_exists_exception(Nd4jLong graphId) + : graph_exception(StringUtils::buildGraphErrorMessage( + "Graph with given ID already exists", graphId), + graphId) { + _graphId = graphId; } +} // namespace sd diff --git a/libnd4j/include/exceptions/impl/no_results_exception.cpp b/libnd4j/include/exceptions/impl/no_results_exception.cpp index ce3122ffbd8a..a20cd77c97fa 100644 --- a/libnd4j/include/exceptions/impl/no_results_exception.cpp +++ b/libnd4j/include/exceptions/impl/no_results_exception.cpp @@ -18,11 +18,14 @@ // Created by raver on 8/31/2018. // -#include #include +#include namespace sd { - no_results_exception::no_results_exception(Nd4jLong graphId) : graph_exception(StringUtils::buildGraphErrorMessage("Got no results after graph execution", graphId), graphId) { - _graphId = graphId; - } +no_results_exception::no_results_exception(Nd4jLong graphId) + : graph_exception(StringUtils::buildGraphErrorMessage( + "Got no results after graph execution", graphId), + graphId) { + _graphId = graphId; } +} // namespace sd diff --git a/libnd4j/include/exceptions/impl/shape_mismatch_exception.cpp b/libnd4j/include/exceptions/impl/shape_mismatch_exception.cpp new file mode 100644 index 000000000000..da8ee1241eda --- /dev/null +++ b/libnd4j/include/exceptions/impl/shape_mismatch_exception.cpp @@ -0,0 +1,39 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include +#include + +namespace sd { +shape_mismatch_exception::shape_mismatch_exception(const std::string &message) + : std::runtime_error(message) { + // +} + +shape_mismatch_exception shape_mismatch_exception::build( + const std::string &message, const std::vector &expected, + const std::vector &actual) { + auto exp = ShapeUtils::shapeAsString(expected); + auto act = ShapeUtils::shapeAsString(actual); + auto fmessage = + message + "; Expected shape: " + exp + "; Actual shape: " + act + ";"; + return shape_mismatch_exception(fmessage); +} +} // namespace sd diff --git a/libnd4j/include/exceptions/impl/unknown_graph_exception.cpp b/libnd4j/include/exceptions/impl/unknown_graph_exception.cpp index ad73f3d33353..5cb64defe169 100644 --- a/libnd4j/include/exceptions/impl/unknown_graph_exception.cpp +++ b/libnd4j/include/exceptions/impl/unknown_graph_exception.cpp @@ -18,11 +18,14 @@ // Created by raver on 8/31/2018. // -#include #include +#include namespace sd { - unknown_graph_exception::unknown_graph_exception(Nd4jLong graphId) : graph_exception(StringUtils::buildGraphErrorMessage("Unknown graph", graphId), graphId) { - _graphId = graphId; - } +unknown_graph_exception::unknown_graph_exception(Nd4jLong graphId) + : graph_exception( + StringUtils::buildGraphErrorMessage("Unknown graph", graphId), + graphId) { + _graphId = graphId; } +} // namespace sd diff --git a/libnd4j/include/exceptions/no_results_exception.h b/libnd4j/include/exceptions/no_results_exception.h index b2687854b25d..415a76018d91 100644 --- a/libnd4j/include/exceptions/no_results_exception.h +++ b/libnd4j/include/exceptions/no_results_exception.h @@ -18,27 +18,29 @@ // Created by raver on 8/31/2018. // -#ifndef DEV_TESTS_NO_RESULTS_EXCEPTION_H -#define DEV_TESTS_NO_RESULTS_EXCEPTION_H +#ifndef SD_NO_RESULTS_EXCEPTION_H +#define SD_NO_RESULTS_EXCEPTION_H +#include +#include #include #include + #include -#include -#include #if defined(_MSC_VER) -// we're ignoring warning about non-exportable parent class, since std::runtime_error is a part of Standard C++ Library -#pragma warning( disable : 4275 ) +// we're ignoring warning about non-exportable parent class, since +// std::runtime_error is a part of Standard C++ Library +#pragma warning(disable : 4275) #endif namespace sd { - class ND4J_EXPORT no_results_exception: public graph_exception { - public: - explicit no_results_exception(Nd4jLong graphId); - }; -} +class SD_EXPORT no_results_exception : public graph_exception { + public: + explicit no_results_exception(Nd4jLong graphId); +}; +} // namespace sd -#endif //DEV_TESTS_UNKNOWN_GRAPH_EXCEPTION_H +#endif // SD_UNKNOWN_GRAPH_EXCEPTION_H diff --git a/libnd4j/include/exceptions/shape_mismatch_exception.h b/libnd4j/include/exceptions/shape_mismatch_exception.h new file mode 100644 index 000000000000..48868f892697 --- /dev/null +++ b/libnd4j/include/exceptions/shape_mismatch_exception.h @@ -0,0 +1,52 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#ifndef SD_SHAPE_MISMATCH_EXCEPTION_H +#define SD_SHAPE_MISMATCH_EXCEPTION_H + +#include +#include +#include +#include + +#include +#include + +#if defined(_MSC_VER) + +// we're ignoring warning about non-exportable parent class, since +// std::runtime_error is a part of Standard C++ Library +#pragma warning(disable : 4275) + +#endif + +namespace sd { +class SD_EXPORT shape_mismatch_exception : public std::runtime_error { + public: + shape_mismatch_exception(const std::string &message); + ~shape_mismatch_exception() = default; + + static shape_mismatch_exception build(const std::string &message, + const std::vector &expected, + const std::vector &actual); +}; +} // namespace sd + +#endif // SD_SHAPE_MISMATCH_EXCEPTION_H diff --git a/libnd4j/include/exceptions/unknown_graph_exception.h b/libnd4j/include/exceptions/unknown_graph_exception.h index 917aeb757954..2ca79c6d272d 100644 --- a/libnd4j/include/exceptions/unknown_graph_exception.h +++ b/libnd4j/include/exceptions/unknown_graph_exception.h @@ -18,27 +18,29 @@ // Created by raver on 8/31/2018. // -#ifndef DEV_TESTS_UNKNOWN_GRAPH_EXCEPTION_H -#define DEV_TESTS_UNKNOWN_GRAPH_EXCEPTION_H +#ifndef SD_UNKNOWN_GRAPH_EXCEPTION_H +#define SD_UNKNOWN_GRAPH_EXCEPTION_H +#include +#include #include #include + #include -#include -#include #if defined(_MSC_VER) -// we're ignoring warning about non-exportable parent class, since std::runtime_error is a part of Standard C++ Library -#pragma warning( disable : 4275 ) +// we're ignoring warning about non-exportable parent class, since +// std::runtime_error is a part of Standard C++ Library +#pragma warning(disable : 4275) #endif namespace sd { - class ND4J_EXPORT unknown_graph_exception: public graph_exception { - public: - explicit unknown_graph_exception(Nd4jLong graphId); - }; -} +class SD_EXPORT unknown_graph_exception : public graph_exception { + public: + explicit unknown_graph_exception(Nd4jLong graphId); +}; +} // namespace sd -#endif //DEV_TESTS_UNKNOWN_GRAPH_EXCEPTION_H +#endif // SD_UNKNOWN_GRAPH_EXCEPTION_H diff --git a/libnd4j/include/execution/AffinityManager.h b/libnd4j/include/execution/AffinityManager.h index 757f637cee05..ae525b9c955e 100644 --- a/libnd4j/include/execution/AffinityManager.h +++ b/libnd4j/include/execution/AffinityManager.h @@ -23,24 +23,25 @@ #include #include + #include #include namespace sd { - class ND4J_EXPORT AffinityManager { - private: - static std::atomic _lastDevice; - static int _numberOfDevices; - static std::mutex _currentMutex; - static std::mutex _numberMutex; +class SD_EXPORT AffinityManager { + private: + static std::atomic _lastDevice; + static int _numberOfDevices; + static std::mutex _currentMutex; + static std::mutex _numberMutex; - public: - static int currentNativeDeviceId(); - static int currentDeviceId(); - static int numberOfDevices(); - static void setCurrentDevice(int deviceId); - static void setCurrentNativeDevice(int deviceId); - }; -} + public: + static int currentNativeDeviceId(); + static int currentDeviceId(); + static int numberOfDevices(); + static void setCurrentDevice(int deviceId); + static void setCurrentNativeDevice(int deviceId); +}; +} // namespace sd -#endif //DEV_TESTS_AFFINITYMANAGER_H +#endif // SD_AFFINITYMANAGER_H diff --git a/libnd4j/include/execution/BlockingQueue.h b/libnd4j/include/execution/BlockingQueue.h index a78196dfc745..de6957f13063 100644 --- a/libnd4j/include/execution/BlockingQueue.h +++ b/libnd4j/include/execution/BlockingQueue.h @@ -21,32 +21,33 @@ #ifndef SAMEDIFF_BLOCKINGQUEUE_H #define SAMEDIFF_BLOCKINGQUEUE_H -#include -#include -#include #include #include +#include +#include +#include namespace samediff { - template - class BlockingQueue { - private: - std::queue _queue; - std::mutex _lock; - std::atomic _size; - std::atomic _available; - - std::condition_variable _condition; - public: - BlockingQueue(int queueSize); - ~BlockingQueue() = default; - T poll(); - void put(const T &t); - - bool available(); - void markAvailable(); - void markUnavailable(); - }; -} - -#endif //DEV_TESTS_BLOCKINGQUEUE_H +template +class BlockingQueue { + private: + std::queue _queue; + std::mutex _lock; + std::atomic _size; + std::atomic _available; + + std::condition_variable _condition; + + public: + BlockingQueue(int queueSize); + ~BlockingQueue() = default; + T poll(); + void put(const T &t); + + bool available(); + void markAvailable(); + void markUnavailable(); +}; +} // namespace samediff + +#endif // SD_BLOCKINGQUEUE_H diff --git a/libnd4j/include/execution/CallableInterface.h b/libnd4j/include/execution/CallableInterface.h index aad83b379222..06a4a7ca7243 100644 --- a/libnd4j/include/execution/CallableInterface.h +++ b/libnd4j/include/execution/CallableInterface.h @@ -22,73 +22,82 @@ #define SAMEDIFF_CALLABLEINTERFACE_H #include + +#include +#include +#include #include #include -#include -#include #include -#include namespace samediff { - /** - * This class is suited for passing functions to execution threads without queues - */ - class CallableInterface { - private: - // parallel_for functions - FUNC_1D _function_1d; - FUNC_2D _function_2d; - FUNC_3D _function_3d; - - // parallel function - FUNC_DO _function_do; - - // reduction functions - FUNC_RL _function_rl; - FUNC_RD _function_rd; - - std::array _arguments; - - volatile int _branch = 0; - volatile uint32_t _thread_id = 0; - volatile uint32_t _num_threads = 0; - - std::atomic _finished; - std::atomic _filled; - std::atomic _available; - - std::condition_variable _starter; - std::condition_variable _finisher; - - int64_t* _lptr = nullptr; - double* _dptr = nullptr; - - std::mutex _ms; - std::mutex _mf; - public: - CallableInterface(); - ~CallableInterface() = default; - - void waitForTask(); - void waitForCompletion(); - - void fill(int thread_id, int num_threads, int64_t *lpt, FUNC_RL func, int64_t start_x, int64_t stop_x, int64_t inc_x); - void fill(int thread_id, int num_threads, double *dpt, FUNC_RD func, int64_t start_x, int64_t stop_x, int64_t inc_x); - - void fill(int thread_id, int num_threads, FUNC_DO func); - void fill(int thread_id, int num_threads, FUNC_1D func, int64_t start_x, int64_t stop_x, int64_t inc_x); - void fill(int thread_id, int num_threads, FUNC_2D func, int64_t start_x, int64_t stop_x, int64_t inc_x, int64_t start_y, int64_t stop_y, int64_t inc_y); - void fill(int thread_id, int num_threads, FUNC_3D func, int64_t start_x, int64_t stop_x, int64_t inc_x, int64_t start_y, int64_t stop_y, int64_t inc_y, int64_t start_z, int64_t stop_z, int64_t inc_z); - - bool available(); - void markAvailable(); - void markUnavailable(); - - void finish(); - - void execute(); - }; -} - - -#endif //DEV_TESTS_CALLABLEINTERFACE_H +/** + * This class is suited for passing functions to execution threads without + * queues + */ +class CallableInterface { + private: + // parallel_for functions + FUNC_1D _function_1d; + FUNC_2D _function_2d; + FUNC_3D _function_3d; + + // parallel function + FUNC_DO _function_do; + + // reduction functions + FUNC_RL _function_rl; + FUNC_RD _function_rd; + + std::array _arguments; + + volatile int _branch = 0; + volatile uint32_t _thread_id = 0; + volatile uint32_t _num_threads = 0; + + std::atomic _finished; + std::atomic _filled; + std::atomic _available; + + std::condition_variable _starter; + std::condition_variable _finisher; + + int64_t* _lptr = nullptr; + double* _dptr = nullptr; + + std::mutex _ms; + std::mutex _mf; + + public: + CallableInterface(); + ~CallableInterface() = default; + + void waitForTask(); + void waitForCompletion(); + + void fill(int thread_id, int num_threads, int64_t* lpt, FUNC_RL func, + int64_t start_x, int64_t stop_x, int64_t inc_x); + void fill(int thread_id, int num_threads, double* dpt, FUNC_RD func, + int64_t start_x, int64_t stop_x, int64_t inc_x); + + void fill(int thread_id, int num_threads, FUNC_DO func); + void fill(int thread_id, int num_threads, FUNC_1D func, int64_t start_x, + int64_t stop_x, int64_t inc_x); + void fill(int thread_id, int num_threads, FUNC_2D func, int64_t start_x, + int64_t stop_x, int64_t inc_x, int64_t start_y, int64_t stop_y, + int64_t inc_y); + void fill(int thread_id, int num_threads, FUNC_3D func, int64_t start_x, + int64_t stop_x, int64_t inc_x, int64_t start_y, int64_t stop_y, + int64_t inc_y, int64_t start_z, int64_t stop_z, int64_t inc_z); + + bool available(); + void markAvailable(); + void markUnavailable(); + + void finish(); + + void execute(); +}; +} // namespace samediff + +#endif // SD_CALLABLEINTERFACE_H diff --git a/libnd4j/include/execution/CallableWithArguments.h b/libnd4j/include/execution/CallableWithArguments.h index 28ef8433e3ad..f926492df491 100644 --- a/libnd4j/include/execution/CallableWithArguments.h +++ b/libnd4j/include/execution/CallableWithArguments.h @@ -18,75 +18,80 @@ // @author raver119@gmail.com // -#ifndef DEV_TESTS_CALLABLEWITHARGUMENTS_H -#define DEV_TESTS_CALLABLEWITHARGUMENTS_H +#ifndef SD_CALLABLEWITHARGUMENTS_H +#define SD_CALLABLEWITHARGUMENTS_H + +#include -#include -#include #include #include -#include +#include +#include namespace samediff { - class CallableWithArguments { - FUNC_DO _function_do; - FUNC_1D _function_1d; - FUNC_2D _function_2d; - FUNC_3D _function_3d; - - std::vector _arguments; - - std::atomic _finished; - - std::condition_variable _condition; - - std::mutex _lock; - - int _dimensions = 0; - - uint64_t _threadId; - uint64_t _numThreads; - public: - CallableWithArguments(FUNC_DO func, uint64_t thread_id, uint64_t numThreads); - CallableWithArguments(FUNC_1D func, uint64_t thread_id, int64_t start_x, int64_t stop_x, int64_t increment_x); - CallableWithArguments(FUNC_2D func, uint64_t thread_id, int64_t start_x, int64_t stop_x, int64_t increment_x, int64_t start_y, int64_t stop_y, int64_t increment_y); - CallableWithArguments(FUNC_3D func, uint64_t thread_id, int64_t start_x, int64_t stop_x, int64_t increment_x, int64_t start_y, int64_t stop_y, int64_t increment_y, int64_t start_z, int64_t stop_z, int64_t increment_z); - - - /** - * This method returns number of dimensions - * @return - */ - int dimensions(); - - /** - * This method checks if this callable is finished - * @return - */ - bool finished(); - - /** - * this method marks this Callable as finished - */ - void finish(); - - /** - * This method blocks until callable is finished - */ - void waitUntilFinished(); - - std::vector& arguments(); - FUNC_DO function_do(); - FUNC_1D function_1d(); - FUNC_2D function_2d(); - FUNC_3D function_3d(); - - - uint64_t threadId(); - - uint64_t numThreads(); - }; -} - - -#endif //DEV_TESTS_CALLABLEWITHARGUMENTS_H +class CallableWithArguments { + FUNC_DO _function_do; + FUNC_1D _function_1d; + FUNC_2D _function_2d; + FUNC_3D _function_3d; + + std::vector _arguments; + + std::atomic _finished; + + std::condition_variable _condition; + + std::mutex _lock; + + int _dimensions = 0; + + uint64_t _threadId; + uint64_t _numThreads; + + public: + CallableWithArguments(FUNC_DO func, uint64_t thread_id, uint64_t numThreads); + CallableWithArguments(FUNC_1D func, uint64_t thread_id, int64_t start_x, + int64_t stop_x, int64_t increment_x); + CallableWithArguments(FUNC_2D func, uint64_t thread_id, int64_t start_x, + int64_t stop_x, int64_t increment_x, int64_t start_y, + int64_t stop_y, int64_t increment_y); + CallableWithArguments(FUNC_3D func, uint64_t thread_id, int64_t start_x, + int64_t stop_x, int64_t increment_x, int64_t start_y, + int64_t stop_y, int64_t increment_y, int64_t start_z, + int64_t stop_z, int64_t increment_z); + + /** + * This method returns number of dimensions + * @return + */ + int dimensions(); + + /** + * This method checks if this callable is finished + * @return + */ + bool finished(); + + /** + * this method marks this Callable as finished + */ + void finish(); + + /** + * This method blocks until callable is finished + */ + void waitUntilFinished(); + + std::vector& arguments(); + FUNC_DO function_do(); + FUNC_1D function_1d(); + FUNC_2D function_2d(); + FUNC_3D function_3d(); + + uint64_t threadId(); + + uint64_t numThreads(); +}; +} // namespace samediff + +#endif // SD_CALLABLEWITHARGUMENTS_H diff --git a/libnd4j/include/execution/ContextBuffers.h b/libnd4j/include/execution/ContextBuffers.h index c14671e426f9..8cb2891febdc 100644 --- a/libnd4j/include/execution/ContextBuffers.h +++ b/libnd4j/include/execution/ContextBuffers.h @@ -21,56 +21,57 @@ #ifndef LIBND4J_CONTEXTBUFFERS_H #define LIBND4J_CONTEXTBUFFERS_H +#include #include #include -#include namespace sd { - class ND4J_EXPORT ContextBuffers { - private: - void* _reductionPointer = nullptr; - void* _scalarPointer = nullptr; - void* _allocationPointer = nullptr; - void* _execStream = nullptr; - void* _specialStream = nullptr; - sd::ErrorReference _errorReference; - bool _allocated = false; - bool _initialized = false; +class SD_EXPORT ContextBuffers { + private: + void* _reductionPointer = nullptr; + void* _scalarPointer = nullptr; + void* _allocationPointer = nullptr; + void* _execStream = nullptr; + void* _specialStream = nullptr; + sd::ErrorReference _errorReference; + bool _allocated = false; + bool _initialized = false; - int _deviceId = -1; + int _deviceId = -1; - void initialize(); - public: - ContextBuffers(); - ContextBuffers(const ContextBuffers &other); - ContextBuffers(void* rPointer, void* sPointer, void* aPointer, bool isOwner = false); - ~ContextBuffers(); + void initialize(); - ContextBuffers& operator=(const ContextBuffers& other); - ContextBuffers& operator=(ContextBuffers&& other); + public: + ContextBuffers(); + ContextBuffers(const ContextBuffers& other); + ContextBuffers(void* rPointer, void* sPointer, void* aPointer, + bool isOwner = false); + ~ContextBuffers(); - void release(); + ContextBuffers& operator=(const ContextBuffers& other); + ContextBuffers& operator=(ContextBuffers&& other); - void* reductionBuffer(); - void* scalarBuffer(); - void* allocationBuffer(); + void release(); - void* execStream(); - void* specialStream(); + void* reductionBuffer(); + void* scalarBuffer(); + void* allocationBuffer(); - void setReductionBuffer(void* pointer); - void setScalarBuffer(void* pointer); - void setAllocationBuffer(void* pointer); + void* execStream(); + void* specialStream(); - sd::ErrorReference* errorReference(); + void setReductionBuffer(void* pointer); + void setScalarBuffer(void* pointer); + void setAllocationBuffer(void* pointer); - void triggerOwnership(bool isOwner); + sd::ErrorReference* errorReference(); - int deviceId(); + void triggerOwnership(bool isOwner); - bool isInitialized(); - }; -} + int deviceId(); + bool isInitialized(); +}; +} // namespace sd -#endif //DEV_TESTS_CONTEXTBUFFERS_H +#endif // SD_CONTEXTBUFFERS_H diff --git a/libnd4j/include/execution/Engine.h b/libnd4j/include/execution/Engine.h index cd30867a9bb2..69b16570491a 100644 --- a/libnd4j/include/execution/Engine.h +++ b/libnd4j/include/execution/Engine.h @@ -22,10 +22,10 @@ #define SD_ENGINE_H namespace samediff { - enum Engine { - ENGINE_CPU = 0, - ENGINE_CUDA = 1, - }; +enum Engine { + ENGINE_CPU = 0, + ENGINE_CUDA = 1, +}; } -#endif //SD_ENGINE_H +#endif // SD_ENGINE_H diff --git a/libnd4j/include/execution/ErrorReference.h b/libnd4j/include/execution/ErrorReference.h index b71090248994..2cbbeef4c66b 100644 --- a/libnd4j/include/execution/ErrorReference.h +++ b/libnd4j/include/execution/ErrorReference.h @@ -18,29 +18,30 @@ // @author raver119@gmail.com // -#ifndef DEV_TESTS_ERRORREFERENCE_H -#define DEV_TESTS_ERRORREFERENCE_H +#ifndef SD_ERRORREFERENCE_H +#define SD_ERRORREFERENCE_H -#include #include +#include + namespace sd { - class ND4J_EXPORT ErrorReference { - private: - int _errorCode = 0; - std::string _errorMessage; - public: - ErrorReference() = default; - ~ErrorReference() = default; +class SD_EXPORT ErrorReference { + private: + int _errorCode = 0; + std::string _errorMessage; - int errorCode(); - const char* errorMessage(); + public: + ErrorReference() = default; + ~ErrorReference() = default; - void setErrorCode(int errorCode); - void setErrorMessage(std::string message); - void setErrorMessage(const char* message); - }; -} + int errorCode(); + const char* errorMessage(); + void setErrorCode(int errorCode); + void setErrorMessage(std::string message); + void setErrorMessage(const char* message); +}; +} // namespace sd -#endif //DEV_TESTS_ERRORREFERENCE_H +#endif // SD_ERRORREFERENCE_H diff --git a/libnd4j/include/execution/ExecutionMode.h b/libnd4j/include/execution/ExecutionMode.h index ea97e3fc9bdf..0c37cdfdbb83 100644 --- a/libnd4j/include/execution/ExecutionMode.h +++ b/libnd4j/include/execution/ExecutionMode.h @@ -22,11 +22,11 @@ #define SD_EXECUTIONMODE_H namespace samediff { - enum ExecutionMode { - MODE_UNDEFINED = 0, - MODE_TRAINING = 1, - MODE_INFERENCE = 2, - }; +enum ExecutionMode { + MODE_UNDEFINED = 0, + MODE_TRAINING = 1, + MODE_INFERENCE = 2, +}; } -#endif //SD_EXECUTIONMODE_H +#endif // SD_EXECUTIONMODE_H diff --git a/libnd4j/include/execution/Executor.h b/libnd4j/include/execution/Executor.h index a9eaa6ad36c0..02646cf73f46 100644 --- a/libnd4j/include/execution/Executor.h +++ b/libnd4j/include/execution/Executor.h @@ -22,12 +22,12 @@ #define SD_EXECUTOR_H namespace sd { - class Executor { - public: - static void execute() { - // - } - }; -} +class Executor { + public: + static void execute() { + // + } +}; +} // namespace sd -#endif //SD_EXECUTOR_H +#endif // SD_EXECUTOR_H diff --git a/libnd4j/include/execution/LaunchContext.h b/libnd4j/include/execution/LaunchContext.h index 4eaf2ca0f1bc..49910af3c093 100644 --- a/libnd4j/include/execution/LaunchContext.h +++ b/libnd4j/include/execution/LaunchContext.h @@ -21,12 +21,12 @@ #ifndef LIBND4J_CUDACONTEXT_H #define LIBND4J_CUDACONTEXT_H - #ifdef __CUDABLAS__ #include -#include -#include #include +#include +#include + #include "config.h" #endif @@ -35,101 +35,100 @@ #include "config.h" #endif -#include -#include -#include -#include -#include -#include #include #include +#include +#include +#include +#include +#include +#include +namespace sd { -namespace sd { - -class ND4J_EXPORT LaunchContext { - - private: - static std::vector> _contexts; - static std::mutex _mutex; +class SD_EXPORT LaunchContext { + private: + static std::vector> _contexts; + static std::mutex _mutex; - static MAP_IMPL _deviceMutexes; + static MAP_IMPL _deviceMutexes; - // used for MKLDNN - void *_engine = nullptr; + // used for MKLDNN + void* _engine = nullptr; #ifdef __CUDABLAS__ #ifndef __JAVACPP_HACK__ - void* _cublasHandle = nullptr; - void* _cusolverHandle = nullptr; + void* _cublasHandle = nullptr; + void* _cusolverHandle = nullptr; -#endif // JCPP +#endif // JCPP - bool _isAllocated = false; -#endif // CUDA - sd::memory::Workspace* _workspace = nullptr; - int _deviceID = 0; + bool _isAllocated = false; +#endif // CUDA + sd::memory::Workspace* _workspace = nullptr; + int _deviceID = 0; - public: + public: #ifdef __CUDABLAS__ #ifndef __JAVACPP_HACK__ - LaunchContext(cudaStream_t* cudaStream, cudaStream_t& specialCudaStream, void* reductionPointer = nullptr, void* scalarPointer = nullptr, int* allocationPointer = nullptr); - - void* getReductionPointer () const; - void* getScalarPointer() const; - int* getAllocationPointer() const; - void* getCublasHandle() const; - void* getCusolverHandle() const; - void* getCuDnnHandle() const; - cudaStream_t* getCudaStream() const; - cudaStream_t* getCudaSpecialStream() const; - - void setReductionPointer (void* reductionPointer); - void setScalarPointer(void* scalarPointer); - void setAllocationPointer(int* allocationPointer); - void setCudaStream(cudaStream_t* cudaStream); - void setCudaSpecialStream(cudaStream_t* cudaStream); - void setCublasHandle(void *handle); - -#endif // JCPP - -#endif // CUDA - LaunchContext(Nd4jPointer cudaStream, Nd4jPointer reductionPointer = nullptr, Nd4jPointer scalarPointer = nullptr, Nd4jPointer allocationPointer = nullptr); - LaunchContext(); - ~LaunchContext(); - sd::memory::Workspace* getWorkspace() const { return _workspace; } - void setWorkspace(sd::memory::Workspace* theWorkspace) { - _workspace = theWorkspace; - } - - void* engine(); - - int getDeviceID() const {return _deviceID;} - void setDeviceID(int deviceID) { _deviceID = deviceID; } - sd::ErrorReference* errorReference(); + LaunchContext(cudaStream_t* cudaStream, cudaStream_t& specialCudaStream, + void* reductionPointer = nullptr, void* scalarPointer = nullptr, + int* allocationPointer = nullptr); + + void* getReductionPointer() const; + void* getScalarPointer() const; + int* getAllocationPointer() const; + void* getCublasHandle() const; + void* getCusolverHandle() const; + void* getCuDnnHandle() const; + cudaStream_t* getCudaStream() const; + cudaStream_t* getCudaSpecialStream() const; + + void setReductionPointer(void* reductionPointer); + void setScalarPointer(void* scalarPointer); + void setAllocationPointer(int* allocationPointer); + void setCudaStream(cudaStream_t* cudaStream); + void setCudaSpecialStream(cudaStream_t* cudaStream); + void setCublasHandle(void* handle); + +#endif // JCPP + +#endif // CUDA + LaunchContext(Nd4jPointer cudaStream, Nd4jPointer reductionPointer = nullptr, + Nd4jPointer scalarPointer = nullptr, + Nd4jPointer allocationPointer = nullptr); + LaunchContext(); + ~LaunchContext(); + sd::memory::Workspace* getWorkspace() const { return _workspace; } + void setWorkspace(sd::memory::Workspace* theWorkspace) { + _workspace = theWorkspace; + } + + void* engine(); + + int getDeviceID() const { return _deviceID; } + void setDeviceID(int deviceID) { _deviceID = deviceID; } + sd::ErrorReference* errorReference(); #ifndef __JAVACPP_HACK__ - // this method returns mutex shared between all threads that use the same device - static std::mutex* deviceMutex(); + // this method returns mutex shared between all threads that use the same + // device + static std::mutex* deviceMutex(); #endif - static bool isInitialized(); - static void releaseBuffers(); - + static bool isInitialized(); + static void releaseBuffers(); - static LaunchContext* defaultContext(); - - - static void swapContextBuffers(ContextBuffers &buffers); + static LaunchContext* defaultContext(); + static void swapContextBuffers(ContextBuffers& buffers); }; -} - +} // namespace sd -#endif //LIBND4J_CUDACONTEXT_H +#endif // LIBND4J_CUDACONTEXT_H diff --git a/libnd4j/include/execution/ThreadPool.h b/libnd4j/include/execution/ThreadPool.h index ce44d5ae281b..8040eb951ea6 100644 --- a/libnd4j/include/execution/ThreadPool.h +++ b/libnd4j/include/execution/ThreadPool.h @@ -21,49 +21,55 @@ #ifndef SAMEDIFF_THREADPOOL_H #define SAMEDIFF_THREADPOOL_H -#include -#include -#include -#include -#include #include -#include #include +#include #include + +#include +#include +#include #include +#include +#include namespace samediff { - class ND4J_EXPORT ThreadPool { - private: - std::vector _threads; - std::vector*> _queues; - std::vector _interfaces; +class SD_EXPORT ThreadPool { + private: + + + std::vector _threads; + std::vector*> _queues; + std::vector _interfaces; + + std::mutex _lock; + std::atomic _available; + std::queue _tickets; - std::mutex _lock; - std::atomic _available; - std::queue _tickets; - protected: - ThreadPool(); - ~ThreadPool(); - public: - static ThreadPool& getInstance(); + protected: + ThreadPool(); + ~ThreadPool(); - /** - * This method returns list of pointers to threads ONLY if num_threads of threads were available upon request, returning empty list otherwise - * @param num_threads - * @return - */ - Ticket* tryAcquire(int num_threads); + public: + static ThreadPool& getInstance(); - /** - * This method marks specified number of threads as released, and available for use - * @param num_threads - */ - void release(int num_threads = 1); + /** + * This method returns list of pointers to threads ONLY if num_threads of + * threads were available upon request, returning empty list otherwise + * @param num_threads + * @return + */ + Ticket* tryAcquire(int num_threads); - void release(Ticket *ticket); - }; -} + /** + * This method marks specified number of threads as released, and available + * for use + * @param num_threads + */ + void release(int num_threads = 1); + void release(Ticket* ticket); +}; +} // namespace samediff -#endif //DEV_TESTS_THREADPOOL_H +#endif // SD_THREADPOOL_H diff --git a/libnd4j/include/execution/Threads.h b/libnd4j/include/execution/Threads.h index bf35de089ff5..1b48dcb1ade9 100644 --- a/libnd4j/include/execution/Threads.h +++ b/libnd4j/include/execution/Threads.h @@ -14,167 +14,209 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - // - // @author raver119@gmail.com - // +// +// @author raver119@gmail.com +// #ifndef SAMEDIFF_THREADS_H #define SAMEDIFF_THREADS_H -#include -#include -#include #include +#include #include +#include + +#include namespace samediff { - class ND4J_EXPORT ThreadsHelper { - public: - static int numberOfThreads(int maxThreads, uint64_t numberOfElements); - static int numberOfThreads2d(int maxThreads, uint64_t iters_x, uint64_t iters_y); - static int numberOfThreads3d(int maxThreads, uint64_t iters_x, uint64_t iters_y, uint64_t iters_z); - static int pickLoop2d(int numThreads, uint64_t iters_x, uint64_t iters_y); - static int pickLoop3d(int numThreads, uint64_t iters_x, uint64_t iters_y, uint64_t iters_z); - }; - - class ND4J_EXPORT Span { - private: - int64_t _startX, _stopX, _incX; - public: - Span(int64_t start_x, int64_t stop_x, int64_t inc_x); - ~Span() = default; - - int64_t startX() const; - int64_t stopX() const; - int64_t incX() const; - - static Span build(uint64_t thread_id, uint64_t num_threads, int64_t start_x, int64_t stop_x, int64_t inc_x); - }; - - class ND4J_EXPORT Span2 { - private: - int64_t _startX, _stopX, _incX; - int64_t _startY, _stopY, _incY; - public: - Span2(int64_t start_x, int64_t stop_x, int64_t inc_x, int64_t start_y, int64_t stop_y, int64_t inc_y); - ~Span2() = default; - - int64_t startX() const; - int64_t startY() const; - - int64_t stopX() const; - int64_t stopY() const; - - int64_t incX() const; - int64_t incY() const; - - static Span2 build(int loop, uint64_t thread_id, uint64_t num_threads, int64_t start_x, int64_t stop_x, int64_t inc_x, int64_t start_y, int64_t stop_y, int64_t inc_y); - }; - - class ND4J_EXPORT Span3 { - private: - int64_t _startX, _stopX, _incX; - int64_t _startY, _stopY, _incY; - int64_t _startZ, _stopZ, _incZ; - public: - Span3(int64_t start_x, int64_t stop_x, int64_t inc_x, int64_t start_y, int64_t stop_y, int64_t inc_y, int64_t start_z, int64_t stop_z, int64_t inc_z); - ~Span3() = default; - - int64_t startX() const; - int64_t startY() const; - int64_t startZ() const; - - int64_t stopX() const; - int64_t stopY() const; - int64_t stopZ() const; - - int64_t incX() const; - int64_t incY() const; - int64_t incZ() const; - - static Span3 build(int loop, uint64_t thread_id, uint64_t num_threads, int64_t start_x, int64_t stop_x, int64_t inc_x, int64_t start_y, int64_t stop_y, int64_t inc_y, int64_t start_z, int64_t stop_z, int64_t inc_z); - }; - - class ND4J_EXPORT Threads { - public: - /** - * This function executes 1 dimensional loop for a given number of threads - * PLEASE NOTE: this function can use smaller number of threads than requested. - * - * @param function - * @param numThreads - * @param start - * @param stop - * @param increment - * @return - */ - static int parallel_for(FUNC_1D function, int64_t start, int64_t stop, int64_t increment = 1, uint32_t numThreads = sd::Environment::getInstance().maxMasterThreads()); - - /** - * This function executes 1 dimensional loop for a given number of threads - * - * @param function - * @param start - * @param stop - * @param increment - * @param numThreads - * @return - */ - static int parallel_tad(FUNC_1D function, int64_t start, int64_t stop, int64_t increment = 1, uint32_t numThreads = sd::Environment::getInstance().maxMasterThreads()); - - /** - * This method will execute function splitting 2 nested loops space with multiple threads - * - * @param function - * @param numThreads - * @param start_x - * @param stop_x - * @param inc_x - * @param start_y - * @param stop_y - * @param inc_y - * @return - */ - static int parallel_for(FUNC_2D function, int64_t start_x, int64_t stop_x, int64_t inc_x, int64_t start_y, int64_t stop_y, int64_t inc_y, uint64_t numThreads = sd::Environment::getInstance().maxMasterThreads(), bool debug = false); - - /** - * This method will execute function splitting 3 nested loops space with multiple threads - * - * @param function - * @param numThreads - * @param start_x - * @param stop_x - * @param inc_x - * @param start_y - * @param stop_y - * @param inc_y - * @param start_z - * @param stop_z - * @param inc_z - * @return - */ - static int parallel_for(FUNC_3D function, int64_t start_x, int64_t stop_x, int64_t inc_x, int64_t start_y, int64_t stop_y, int64_t inc_y, int64_t start_z, int64_t stop_z, int64_t inc_z, uint64_t numThreads = sd::Environment::getInstance().maxMasterThreads()); - - /** - * - * @param function - * @param numThreads - * @return - */ - static int parallel_do(FUNC_DO function, uint64_t numThreads = sd::Environment::getInstance().maxMasterThreads()); - - static int64_t parallel_long(FUNC_RL function, FUNC_AL aggregator, int64_t start, int64_t stop, int64_t increment = 1, uint64_t numThreads = sd::Environment::getInstance().maxMasterThreads()); - - static double parallel_double(FUNC_RD function, FUNC_AD aggregator, int64_t start, int64_t stop, int64_t increment = 1, uint64_t numThreads = sd::Environment::getInstance().maxMasterThreads()); - - /** - * This method will execute function in parallel preserving the parts to be aligned increment size - * PLEASE NOTE: this function can use smaller number of threads than requested. - * - */ - static int parallel_aligned_increment(FUNC_1D function, int64_t start, int64_t stop, int64_t increment, size_t type_size = sizeof(float), uint32_t req_numThreads = sd::Environment::getInstance().maxMasterThreads()); - - }; -} - - -#endif //SAMEDIFF_THREADS_H +class SD_EXPORT ThreadsHelper { + public: + static int numberOfThreads(int maxThreads, uint64_t numberOfElements); + static int numberOfThreads2d(int maxThreads, uint64_t iters_x, + uint64_t iters_y); + static int numberOfThreads3d(int maxThreads, uint64_t iters_x, + uint64_t iters_y, uint64_t iters_z); + static int pickLoop2d(int numThreads, uint64_t iters_x, uint64_t iters_y); + static int pickLoop3d(int numThreads, uint64_t iters_x, uint64_t iters_y, + uint64_t iters_z); +}; + +class SD_EXPORT Span { + private: + int64_t _startX, _stopX, _incX; + + public: + Span(int64_t start_x, int64_t stop_x, int64_t inc_x); + ~Span() = default; + + int64_t startX() const; + int64_t stopX() const; + int64_t incX() const; + + static Span build(uint64_t thread_id, uint64_t num_threads, int64_t start_x, + int64_t stop_x, int64_t inc_x); +}; + +class SD_EXPORT Span2 { + private: + int64_t _startX, _stopX, _incX; + int64_t _startY, _stopY, _incY; + + public: + Span2(int64_t start_x, int64_t stop_x, int64_t inc_x, int64_t start_y, + int64_t stop_y, int64_t inc_y); + ~Span2() = default; + + int64_t startX() const; + int64_t startY() const; + + int64_t stopX() const; + int64_t stopY() const; + + int64_t incX() const; + int64_t incY() const; + + static Span2 build(int loop, uint64_t thread_id, uint64_t num_threads, + int64_t start_x, int64_t stop_x, int64_t inc_x, + int64_t start_y, int64_t stop_y, int64_t inc_y); +}; + +class SD_EXPORT Span3 { + private: + int64_t _startX, _stopX, _incX; + int64_t _startY, _stopY, _incY; + int64_t _startZ, _stopZ, _incZ; + + public: + Span3(int64_t start_x, int64_t stop_x, int64_t inc_x, int64_t start_y, + int64_t stop_y, int64_t inc_y, int64_t start_z, int64_t stop_z, + int64_t inc_z); + ~Span3() = default; + + int64_t startX() const; + int64_t startY() const; + int64_t startZ() const; + + int64_t stopX() const; + int64_t stopY() const; + int64_t stopZ() const; + + int64_t incX() const; + int64_t incY() const; + int64_t incZ() const; + + static Span3 build(int loop, uint64_t thread_id, uint64_t num_threads, + int64_t start_x, int64_t stop_x, int64_t inc_x, + int64_t start_y, int64_t stop_y, int64_t inc_y, + int64_t start_z, int64_t stop_z, int64_t inc_z); +}; + +class SD_EXPORT Threads { + public: + /** + * This function executes 1 dimensional loop for a given number of threads + * PLEASE NOTE: this function can use smaller number of threads than + * requested. + * + * @param function + * @param numThreads + * @param start + * @param stop + * @param increment + * @return + */ + static int parallel_for( + FUNC_1D function, int64_t start, int64_t stop, int64_t increment = 1, + uint32_t numThreads = sd::Environment::getInstance().maxMasterThreads()); + + /** + * This function executes 1 dimensional loop for a given number of threads + * + * @param function + * @param start + * @param stop + * @param increment + * @param numThreads + * @return + */ + static int parallel_tad( + FUNC_1D function, int64_t start, int64_t stop, int64_t increment = 1, + uint32_t numThreads = sd::Environment::getInstance().maxMasterThreads()); + + /** + * This method will execute function splitting 2 nested loops space with + * multiple threads + * + * @param function + * @param numThreads + * @param start_x + * @param stop_x + * @param inc_x + * @param start_y + * @param stop_y + * @param inc_y + * @return + */ + static int parallel_for( + FUNC_2D function, int64_t start_x, int64_t stop_x, int64_t inc_x, + int64_t start_y, int64_t stop_y, int64_t inc_y, + uint64_t numThreads = sd::Environment::getInstance().maxMasterThreads(), + bool debug = false); + + /** + * This method will execute function splitting 3 nested loops space with + * multiple threads + * + * @param function + * @param numThreads + * @param start_x + * @param stop_x + * @param inc_x + * @param start_y + * @param stop_y + * @param inc_y + * @param start_z + * @param stop_z + * @param inc_z + * @return + */ + static int parallel_for( + FUNC_3D function, int64_t start_x, int64_t stop_x, int64_t inc_x, + int64_t start_y, int64_t stop_y, int64_t inc_y, int64_t start_z, + int64_t stop_z, int64_t inc_z, + uint64_t numThreads = sd::Environment::getInstance().maxMasterThreads()); + + /** + * + * @param function + * @param numThreads + * @return + */ + static int parallel_do( + FUNC_DO function, + uint64_t numThreads = sd::Environment::getInstance().maxMasterThreads()); + + static int64_t parallel_long( + FUNC_RL function, FUNC_AL aggregator, int64_t start, int64_t stop, + int64_t increment = 1, + uint64_t numThreads = sd::Environment::getInstance().maxMasterThreads()); + + static double parallel_double( + FUNC_RD function, FUNC_AD aggregator, int64_t start, int64_t stop, + int64_t increment = 1, + uint64_t numThreads = sd::Environment::getInstance().maxMasterThreads()); + + /** + * This method will execute function in parallel preserving the parts to be + * aligned increment size PLEASE NOTE: this function can use smaller number of + * threads than requested. + * + */ + static int parallel_aligned_increment( + FUNC_1D function, int64_t start, int64_t stop, int64_t increment, + size_t type_size = sizeof(float), + uint32_t req_numThreads = + sd::Environment::getInstance().maxMasterThreads()); +}; +} // namespace samediff + +#endif // SAMEDIFF_THREADS_H diff --git a/libnd4j/include/execution/Ticket.h b/libnd4j/include/execution/Ticket.h index 80bf54145661..d66039c6620e 100644 --- a/libnd4j/include/execution/Ticket.h +++ b/libnd4j/include/execution/Ticket.h @@ -21,47 +21,57 @@ #ifndef SAMEDIFF_TICKET_H #define SAMEDIFF_TICKET_H -#include #include -#include #include +#include + #include #include +#include namespace samediff { - class ND4J_EXPORT Ticket { - private: - bool _acquired = false; - std::vector*> _queues; - std::vector _callables; - std::vector _interfaces; +class SD_EXPORT Ticket { + private: + bool _acquired = false; + std::vector *> _queues; + std::vector _callables; + std::vector _interfaces; - uint32_t _acquiredThreads = 0; - public: - explicit Ticket(const std::vector*> &queues); - Ticket(); - ~Ticket() = default; + uint32_t _acquiredThreads = 0; - bool acquired(); + public: + explicit Ticket( + const std::vector *> &queues); + Ticket(); + ~Ticket() = default; - void acquiredThreads(uint32_t threads); + bool acquired(); - void attach(uint32_t thread_id, CallableInterface *interface); + void acquiredThreads(uint32_t threads); - // deprecated one - void enqueue(int thread_id, CallableWithArguments* callable); + void attach(uint32_t thread_id, CallableInterface *interface); - void enqueue(uint32_t thread_id, uint32_t num_threads, int64_t *lpt, FUNC_RL func, int64_t start_x, int64_t stop_x, int64_t inc_x); - void enqueue(uint32_t thread_id, uint32_t num_threads, double *lpt, FUNC_RD func, int64_t start_x, int64_t stop_x, int64_t inc_x); + // deprecated one + void enqueue(int thread_id, CallableWithArguments *callable); - void enqueue(uint32_t thread_id, uint32_t num_threads, FUNC_DO func); - void enqueue(uint32_t thread_id, uint32_t num_threads, FUNC_1D func, int64_t start_x, int64_t stop_x, int64_t inc_x); - void enqueue(uint32_t thread_id, uint32_t num_threads, FUNC_2D func, int64_t start_x, int64_t stop_x, int64_t inc_x, int64_t start_y, int64_t stop_y, int64_t inc_y); - void enqueue(uint32_t thread_id, uint32_t num_threads, FUNC_3D func, int64_t start_x, int64_t stop_x, int64_t inc_x, int64_t start_y, int64_t stop_y, int64_t inc_y, int64_t start_, int64_t stop_z, int64_t inc_z); + void enqueue(uint32_t thread_id, uint32_t num_threads, int64_t *lpt, + FUNC_RL func, int64_t start_x, int64_t stop_x, int64_t inc_x); + void enqueue(uint32_t thread_id, uint32_t num_threads, double *lpt, + FUNC_RD func, int64_t start_x, int64_t stop_x, int64_t inc_x); - void waitAndRelease(); - }; -} + void enqueue(uint32_t thread_id, uint32_t num_threads, FUNC_DO func); + void enqueue(uint32_t thread_id, uint32_t num_threads, FUNC_1D func, + int64_t start_x, int64_t stop_x, int64_t inc_x); + void enqueue(uint32_t thread_id, uint32_t num_threads, FUNC_2D func, + int64_t start_x, int64_t stop_x, int64_t inc_x, int64_t start_y, + int64_t stop_y, int64_t inc_y); + void enqueue(uint32_t thread_id, uint32_t num_threads, FUNC_3D func, + int64_t start_x, int64_t stop_x, int64_t inc_x, int64_t start_y, + int64_t stop_y, int64_t inc_y, int64_t start_, int64_t stop_z, + int64_t inc_z); + void waitAndRelease(); +}; +} // namespace samediff -#endif //DEV_TESTS_TICKET_H +#endif // SD_TICKET_H diff --git a/libnd4j/include/execution/cpu/AffinityManager.cpp b/libnd4j/include/execution/cpu/AffinityManager.cpp index 32df63d0d1e8..f00c1e35eb85 100644 --- a/libnd4j/include/execution/cpu/AffinityManager.cpp +++ b/libnd4j/include/execution/cpu/AffinityManager.cpp @@ -21,23 +21,17 @@ #include namespace sd { - int AffinityManager::currentDeviceId() { - return 0; - } +int AffinityManager::currentDeviceId() { return 0; } - int AffinityManager::currentNativeDeviceId() { - return 0; - } +int AffinityManager::currentNativeDeviceId() { return 0; } - int AffinityManager::numberOfDevices() { - return 1; - } +int AffinityManager::numberOfDevices() { return 1; } - void AffinityManager::setCurrentDevice(int deviceId) { - // no-op - } +void AffinityManager::setCurrentDevice(int deviceId) { + // no-op +} - void AffinityManager::setCurrentNativeDevice(int deviceId) { - // no-op - } -} \ No newline at end of file +void AffinityManager::setCurrentNativeDevice(int deviceId) { + // no-op +} +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/execution/cpu/ContextBuffers.cpp b/libnd4j/include/execution/cpu/ContextBuffers.cpp index 3b1c566a837d..0e1d6a8d2242 100644 --- a/libnd4j/include/execution/cpu/ContextBuffers.cpp +++ b/libnd4j/include/execution/cpu/ContextBuffers.cpp @@ -17,90 +17,75 @@ // // @author raver119@gmail.com // -#include #include +#include namespace sd { - ContextBuffers::ContextBuffers() { - _deviceId = AffinityManager::currentDeviceId(); - } - - ContextBuffers::~ContextBuffers() { - // no-op - } - - ContextBuffers::ContextBuffers(void* rPointer, void* sPointer, void* aPointer, bool isOwner) { - _reductionPointer = rPointer; - _scalarPointer = sPointer; - _allocationPointer = aPointer; - _allocated = isOwner; - } - - ContextBuffers::ContextBuffers(const ContextBuffers &other) { - // - } - - void ContextBuffers::initialize() { - // no-op - } - - void* ContextBuffers::reductionBuffer() { - return _reductionPointer; - } - - void* ContextBuffers::scalarBuffer() { - return _scalarPointer; - } - - void* ContextBuffers::allocationBuffer() { - return _allocationPointer; - } - - void ContextBuffers::setReductionBuffer(void* pointer) { - _reductionPointer = pointer; - } - - void ContextBuffers::setScalarBuffer(void* pointer) { - _scalarPointer = pointer; - } - - void ContextBuffers::setAllocationBuffer(void* pointer) { - _allocationPointer = pointer; - } - - void ContextBuffers::triggerOwnership(bool isOwner) { - _allocated = isOwner; - } - - int ContextBuffers::deviceId() { - return _deviceId; - } - - void* ContextBuffers::execStream() { - return _execStream; - } - - void* ContextBuffers::specialStream() { - return _specialStream; - } - - bool ContextBuffers::isInitialized() { - return true; - } - - void ContextBuffers::release() { - // - } - - ContextBuffers& ContextBuffers::operator=(const ContextBuffers& other) { - return *this; - } - - ContextBuffers& ContextBuffers::operator=(ContextBuffers&& other) { - return *this; - } - - sd::ErrorReference* ContextBuffers::errorReference() { - return &_errorReference; - } -} \ No newline at end of file +ContextBuffers::ContextBuffers() { + _deviceId = AffinityManager::currentDeviceId(); +} + +ContextBuffers::~ContextBuffers() { + // no-op +} + +ContextBuffers::ContextBuffers(void* rPointer, void* sPointer, void* aPointer, + bool isOwner) { + _reductionPointer = rPointer; + _scalarPointer = sPointer; + _allocationPointer = aPointer; + _allocated = isOwner; +} + +ContextBuffers::ContextBuffers(const ContextBuffers& other) { + // +} + +void ContextBuffers::initialize() { + // no-op +} + +void* ContextBuffers::reductionBuffer() { return _reductionPointer; } + +void* ContextBuffers::scalarBuffer() { return _scalarPointer; } + +void* ContextBuffers::allocationBuffer() { return _allocationPointer; } + +void ContextBuffers::setReductionBuffer(void* pointer) { + _reductionPointer = pointer; +} + +void ContextBuffers::setScalarBuffer(void* pointer) { + _scalarPointer = pointer; +} + +void ContextBuffers::setAllocationBuffer(void* pointer) { + _allocationPointer = pointer; +} + +void ContextBuffers::triggerOwnership(bool isOwner) { _allocated = isOwner; } + +int ContextBuffers::deviceId() { return _deviceId; } + +void* ContextBuffers::execStream() { return _execStream; } + +void* ContextBuffers::specialStream() { return _specialStream; } + +bool ContextBuffers::isInitialized() { return true; } + +void ContextBuffers::release() { + // +} + +ContextBuffers& ContextBuffers::operator=(const ContextBuffers& other) { + return *this; +} + +ContextBuffers& ContextBuffers::operator=(ContextBuffers&& other) { + return *this; +} + +sd::ErrorReference* ContextBuffers::errorReference() { + return &_errorReference; +} +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/execution/cpu/LaunchContext.cpp b/libnd4j/include/execution/cpu/LaunchContext.cpp index 31cb6889d850..8a0e9ad1045b 100644 --- a/libnd4j/include/execution/cpu/LaunchContext.cpp +++ b/libnd4j/include/execution/cpu/LaunchContext.cpp @@ -18,13 +18,15 @@ // Created by raver119 on 30.11.17. // -#include +#include #include +#include #include -#include + #include -#if defined(SD_IOS_BUILD) || defined(SD_APPLE_BUILD) || defined(SD_ANDROID_BUILD) +#if defined(SD_IOS_BUILD) || defined(SD_APPLE_BUILD) || \ + defined(SD_ANDROID_BUILD) sd::ContextBuffers contextBuffers = sd::ContextBuffers(); #else thread_local sd::ContextBuffers contextBuffers = sd::ContextBuffers(); @@ -36,67 +38,60 @@ thread_local sd::ContextBuffers contextBuffers = sd::ContextBuffers(); namespace sd { - LaunchContext::~LaunchContext() { +LaunchContext::~LaunchContext() { #ifdef HAVE_MKLDNN - delete reinterpret_cast(_engine); + delete reinterpret_cast(_engine); #endif - } +} - std::vector> LaunchContext::_contexts = std::vector>(); - MAP_IMPL LaunchContext::_deviceMutexes; - std::mutex LaunchContext::_mutex; +std::vector> LaunchContext::_contexts = + std::vector>(); +MAP_IMPL LaunchContext::_deviceMutexes; +std::mutex LaunchContext::_mutex; //////////////////////////////////////////////////////////////////////// - LaunchContext::LaunchContext() { - // default constructor, just to make clang/ranlib happy - _workspace = nullptr; - _deviceID = 0; +LaunchContext::LaunchContext() { + // default constructor, just to make clang/ranlib happy + _workspace = nullptr; + _deviceID = 0; #ifdef HAVE_MKLDNN - _engine = new dnnl::engine(dnnl::engine::kind::cpu, 0); + _engine = new dnnl::engine(dnnl::engine::kind::cpu, 0); #endif - } +} - LaunchContext::LaunchContext(Nd4jPointer cudaStream, Nd4jPointer reductionPointer, Nd4jPointer scalarPointer, Nd4jPointer allocationPointer) { +LaunchContext::LaunchContext(Nd4jPointer cudaStream, + Nd4jPointer reductionPointer, + Nd4jPointer scalarPointer, + Nd4jPointer allocationPointer) {} - } +static std::mutex _lock; - static std::mutex _lock; +LaunchContext* LaunchContext::defaultContext() { + { + // synchronous block goes here + std::lock_guard lock(_lock); + // TODO: we need it to be device-aware, but only once we add NUMA support for cpu + if (LaunchContext::_contexts.empty()) + LaunchContext::_contexts.emplace_back(std::make_shared()); + } - LaunchContext* LaunchContext::defaultContext() { - { - // synchronous block goes here - std::lock_guard lock(_lock); - // TODO: we need it to be device-aware, but only once we add NUMA support for cpu - if (LaunchContext::_contexts.empty()) - LaunchContext::_contexts.emplace_back(std::make_shared()); - } + // return context for current device + return LaunchContext::_contexts[0].get(); +} - // return context for current device - return LaunchContext::_contexts[0].get(); - } +std::mutex* LaunchContext::deviceMutex() { return &_mutex; } - std::mutex* LaunchContext::deviceMutex() { - return &_mutex; - } +void LaunchContext::swapContextBuffers(ContextBuffers& buffers) { } - void LaunchContext::swapContextBuffers(ContextBuffers &buffers) { - // - } +bool LaunchContext::isInitialized() { return true; } - bool LaunchContext::isInitialized() { - return true; - } +void LaunchContext::releaseBuffers() { } - void LaunchContext::releaseBuffers() { - // - } +sd::ErrorReference* LaunchContext::errorReference() { + return contextBuffers.errorReference(); +} - sd::ErrorReference* LaunchContext::errorReference() { - return contextBuffers.errorReference(); - } +void* LaunchContext::engine() { return _engine; } - void* LaunchContext::engine() { - return _engine; - } -} \ No newline at end of file +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/execution/cuda/AffinityManager.cu b/libnd4j/include/execution/cuda/AffinityManager.cu index cdfe7c1079d8..d0c9f95b39b9 100644 --- a/libnd4j/include/execution/cuda/AffinityManager.cu +++ b/libnd4j/include/execution/cuda/AffinityManager.cu @@ -18,111 +18,109 @@ // @author raver119@gmail.com // -#include -#include #include +#include #include +#include thread_local int globalThreadToDevice = -1; namespace sd { - std::mutex AffinityManager::_currentMutex; - std::mutex AffinityManager::_numberMutex; - int AffinityManager::_numberOfDevices = -1; +std::mutex AffinityManager::_currentMutex; +std::mutex AffinityManager::_numberMutex; +int AffinityManager::_numberOfDevices = -1; + +int AffinityManager::currentDeviceId() { + // if there's no affinity set - set it now + if (globalThreadToDevice < 0) { + // this block must be thread-local + _currentMutex.lock(); + + globalThreadToDevice = _lastDevice++; + + // we need to check if we've got deviceId >= number of actual devices, and + // reset to zero otherwise + if (globalThreadToDevice >= numberOfDevices()) { + globalThreadToDevice = 0; + _lastDevice = numberOfDevices() > 1 ? 1 : 0; + } - int AffinityManager::currentDeviceId() { - // if there's no affinity set - set it now - if (globalThreadToDevice < 0) { + _currentMutex.unlock(); - // this block must be thread-local - _currentMutex.lock(); + setCurrentNativeDevice(globalThreadToDevice); + } - globalThreadToDevice = _lastDevice++; + // if we already know affinity - just return it + if (globalThreadToDevice >= 0) return globalThreadToDevice; - // we need to check if we've got deviceId >= number of actual devices, and reset to zero otherwise - if (globalThreadToDevice >= numberOfDevices()) { - globalThreadToDevice = 0; - _lastDevice = numberOfDevices() > 1 ? 1 : 0; - } + int dev = 0; + auto res = cudaGetDevice(&dev); - _currentMutex.unlock(); + if (res != 0) throw cuda_exception::build("cudaGetDevice failed", res); - setCurrentNativeDevice(globalThreadToDevice); - } + return dev; +} - // if we already know affinity - just return it - if (globalThreadToDevice >= 0) - return globalThreadToDevice; +int AffinityManager::currentNativeDeviceId() { + int dev = 0; + auto res = cudaGetDevice(&dev); - int dev = 0; - auto res = cudaGetDevice(&dev); + if (res != 0) throw cuda_exception::build("cudaGetDevice failed", res); - if (res != 0) - throw cuda_exception::build("cudaGetDevice failed", res); + return dev; +} - return dev; - } +int AffinityManager::numberOfDevices() { + _numberMutex.lock(); + // we want to cache number of devices + if (_numberOfDevices <= 0) { + int dev = 0; + auto res = cudaGetDeviceCount(&dev); - int AffinityManager::currentNativeDeviceId() { - int dev = 0; - auto res = cudaGetDevice(&dev); + if (res != 0) throw cuda_exception::build("cudaGetDeviceCount failed", res); - if (res != 0) - throw cuda_exception::build("cudaGetDevice failed", res); + _numberOfDevices = dev; + } + _numberMutex.unlock(); - return dev; - } + return _numberOfDevices; +} - int AffinityManager::numberOfDevices() { - _numberMutex.lock(); - // we want to cache number of devices - if (_numberOfDevices <= 0) { - int dev = 0; - auto res = cudaGetDeviceCount(&dev); +void AffinityManager::setCurrentNativeDevice(int deviceId) { + auto res = cudaSetDevice(deviceId); + if (res != 0) throw cuda_exception::build("setCurrentDevice failed", res); +} - if (res != 0) - throw cuda_exception::build("cudaGetDeviceCount failed", res); +void AffinityManager::setCurrentDevice(int deviceId) { + auto previousDeviceId = globalThreadToDevice; + if (previousDeviceId >= 0 && LaunchContext::isInitialized()) { + auto res = cudaStreamSynchronize( + *LaunchContext::defaultContext()->getCudaStream()); + if (res != 0) + throw cuda_exception::build("setCurrentDevice -> sync failed", res); - _numberOfDevices = dev; - } - _numberMutex.unlock(); + res = cudaStreamSynchronize( + *LaunchContext::defaultContext()->getCudaSpecialStream()); + if (res != 0) + throw cuda_exception::build("setCurrentDevice -> specialSync failed", + res); - return _numberOfDevices; + if (deviceId != previousDeviceId) { + // discard existing stuff + // nd4j_printf("AffinityManager::setCurrentDevice() was invoked, releasing + // buffers\n", ""); + LaunchContext::releaseBuffers(); } + } - void AffinityManager::setCurrentNativeDevice(int deviceId) { - auto res = cudaSetDevice(deviceId); - if (res != 0) - throw cuda_exception::build("setCurrentDevice failed", res); - } + if (deviceId != previousDeviceId) { + auto res = cudaSetDevice(deviceId); + if (res != 0) throw cuda_exception::build("cudaSetDevice failed", res); + } - void AffinityManager::setCurrentDevice(int deviceId) { - auto previousDeviceId = globalThreadToDevice; - if (previousDeviceId >= 0 && LaunchContext::isInitialized()) { - auto res = cudaStreamSynchronize(*LaunchContext::defaultContext()->getCudaStream()); - if (res != 0) - throw cuda_exception::build("setCurrentDevice -> sync failed", res); - - res = cudaStreamSynchronize(*LaunchContext::defaultContext()->getCudaSpecialStream()); - if (res != 0) - throw cuda_exception::build("setCurrentDevice -> specialSync failed", res); - - if (deviceId != previousDeviceId) { - // discard existing stuff - //nd4j_printf("AffinityManager::setCurrentDevice() was invoked, releasing buffers\n", ""); - LaunchContext::releaseBuffers(); - } - } - - if (deviceId != previousDeviceId) { - auto res = cudaSetDevice(deviceId); - if (res != 0) - throw cuda_exception::build("cudaSetDevice failed", res); - } - - // update thread-device affinity - globalThreadToDevice = deviceId; - } + // update thread-device affinity + globalThreadToDevice = deviceId; +} - std::atomic AffinityManager::_lastDevice;// = std::atomic(initialV); -} \ No newline at end of file +std::atomic AffinityManager::_lastDevice; // = std::atomic(initialV); +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/execution/cuda/ContextBuffers.cu b/libnd4j/include/execution/cuda/ContextBuffers.cu index 0c17ba614939..d1195562100c 100644 --- a/libnd4j/include/execution/cuda/ContextBuffers.cu +++ b/libnd4j/include/execution/cuda/ContextBuffers.cu @@ -18,220 +18,211 @@ // @author raver119@gmail.com // -#include -#include -#include -#include - #include -#include -#include #include +#include +#include +#include +#include +#include +#include namespace sd { - ContextBuffers::ContextBuffers() { - //nd4j_printf("Creating ContextBuffers for device [%i]\n", AffinityManager::currentDeviceId()); - _deviceId = AffinityManager::currentDeviceId(); - } +ContextBuffers::ContextBuffers() { + // nd4j_printf("Creating ContextBuffers for device [%i]\n", + // AffinityManager::currentDeviceId()); + _deviceId = AffinityManager::currentDeviceId(); +} + +ContextBuffers::ContextBuffers(const ContextBuffers& other) { + release(); - ContextBuffers::ContextBuffers(const ContextBuffers &other) { - release(); - - this->_initialized = other._initialized; - this->_allocated = other._allocated; - this->_deviceId = other._deviceId; - - this->_specialStream = other._specialStream; - this->_execStream = other._execStream; - this->_allocationPointer = other._allocationPointer; - this->_reductionPointer = other._reductionPointer; - this->_scalarPointer = other._scalarPointer; - } + this->_initialized = other._initialized; + this->_allocated = other._allocated; + this->_deviceId = other._deviceId; - ContextBuffers& ContextBuffers::operator=(const ContextBuffers& other) { - release(); + this->_specialStream = other._specialStream; + this->_execStream = other._execStream; + this->_allocationPointer = other._allocationPointer; + this->_reductionPointer = other._reductionPointer; + this->_scalarPointer = other._scalarPointer; +} - this->_initialized = other._initialized; - this->_allocated = other._allocated; - this->_deviceId = other._deviceId; +ContextBuffers& ContextBuffers::operator=(const ContextBuffers& other) { + release(); - this->_specialStream = other._specialStream; - this->_execStream = other._execStream; - this->_allocationPointer = other._allocationPointer; - this->_reductionPointer = other._reductionPointer; - this->_scalarPointer = other._scalarPointer; - - return *this; - } + this->_initialized = other._initialized; + this->_allocated = other._allocated; + this->_deviceId = other._deviceId; - ContextBuffers& ContextBuffers::operator=(ContextBuffers&& other) { - release(); + this->_specialStream = other._specialStream; + this->_execStream = other._execStream; + this->_allocationPointer = other._allocationPointer; + this->_reductionPointer = other._reductionPointer; + this->_scalarPointer = other._scalarPointer; - this->_initialized = other._initialized; - this->_allocated = other._allocated; - this->_deviceId = other._deviceId; + return *this; +} - this->_specialStream = other._specialStream; - this->_execStream = other._execStream; - this->_allocationPointer = other._allocationPointer; - this->_reductionPointer = other._reductionPointer; - this->_scalarPointer = other._scalarPointer; +ContextBuffers& ContextBuffers::operator=(ContextBuffers&& other) { + release(); - return *this; - } + this->_initialized = other._initialized; + this->_allocated = other._allocated; + this->_deviceId = other._deviceId; - void ContextBuffers::release() { - if (_allocated) { - //nd4j_printf("Releasing ContextBuffers on device [%i]\n", _deviceId); + this->_specialStream = other._specialStream; + this->_execStream = other._execStream; + this->_allocationPointer = other._allocationPointer; + this->_reductionPointer = other._reductionPointer; + this->_scalarPointer = other._scalarPointer; - if (_allocationPointer != nullptr) - cudaFree(_allocationPointer); + return *this; +} - if (_scalarPointer != nullptr) - cudaFreeHost(_scalarPointer); +void ContextBuffers::release() { + if (_allocated) { + // nd4j_printf("Releasing ContextBuffers on device [%i]\n", _deviceId); - if (_allocationPointer != nullptr) - cudaFree(_reductionPointer); + if (_allocationPointer != nullptr) cudaFree(_allocationPointer); - auto _cudaStream = reinterpret_cast(_execStream); - auto _cudaSpecialStream = reinterpret_cast(_specialStream); + if (_scalarPointer != nullptr) cudaFreeHost(_scalarPointer); - cudaStreamSynchronize(*_cudaStream); - cudaStreamSynchronize(*_cudaSpecialStream); + if (_allocationPointer != nullptr) cudaFree(_reductionPointer); - cudaStreamDestroy(*_cudaStream); - cudaStreamDestroy(*_cudaSpecialStream); + auto _cudaStream = reinterpret_cast(_execStream); + auto _cudaSpecialStream = reinterpret_cast(_specialStream); - delete _cudaStream; - delete _cudaSpecialStream; + cudaStreamSynchronize(*_cudaStream); + cudaStreamSynchronize(*_cudaSpecialStream); - ////// - _allocated = false; - _deviceId = -1; - - this->_specialStream = nullptr; - this->_execStream = nullptr; - this->_allocationPointer = nullptr; - this->_reductionPointer = nullptr; - this->_scalarPointer = nullptr; - } + cudaStreamDestroy(*_cudaStream); + cudaStreamDestroy(*_cudaSpecialStream); - _initialized = false; - } - - ContextBuffers::~ContextBuffers() { - release(); - } - - ContextBuffers::ContextBuffers(void* rPointer, void* sPointer, void* aPointer, bool isOwner) { - _reductionPointer = rPointer; - _scalarPointer = sPointer; - _allocationPointer = aPointer; - _allocated = isOwner; - } - - void ContextBuffers::initialize() { - _deviceId = AffinityManager::currentNativeDeviceId(); - //nd4j_printf("Initializing buffers on deviceId [%i]\n", _deviceId); + delete _cudaStream; + delete _cudaSpecialStream; - auto res = cudaMalloc(reinterpret_cast(&_reductionPointer), 1024 * 1024 * 8); - if (res != 0) - throw cuda_exception::build("_reductionPointer allocation failed", res); + ////// + _allocated = false; + _deviceId = -1; - res = cudaHostAlloc(reinterpret_cast(&_scalarPointer), 16, cudaHostAllocDefault); - if (res != 0) - throw cuda_exception::build("_scalarPointer allocation failed", res); + this->_specialStream = nullptr; + this->_execStream = nullptr; + this->_allocationPointer = nullptr; + this->_reductionPointer = nullptr; + this->_scalarPointer = nullptr; + } - res = cudaMalloc(reinterpret_cast(&_allocationPointer), 1024 * 1024 * 8); - if (res != 0) - throw cuda_exception::build("_allocationPointer allocation failed", res); + _initialized = false; +} - _execStream = new cudaStream_t(); - _specialStream = new cudaStream_t(); - if (nullptr == _execStream || nullptr == _specialStream) - throw std::runtime_error("Failed to allocate memory for new CUDA stream"); +ContextBuffers::~ContextBuffers() { release(); } - res = cudaStreamCreate(reinterpret_cast(_execStream)); - if (res != 0) - throw cuda_exception::build("Failed to create default CUDA stream with launch context", res); +ContextBuffers::ContextBuffers(void* rPointer, void* sPointer, void* aPointer, + bool isOwner) { + _reductionPointer = rPointer; + _scalarPointer = sPointer; + _allocationPointer = aPointer; + _allocated = isOwner; +} - res = cudaStreamCreate(reinterpret_cast(_specialStream)); - if (res != 0) - throw cuda_exception::build("Failed to create special CUDA stream with launch context", res); +void ContextBuffers::initialize() { + _deviceId = AffinityManager::currentNativeDeviceId(); + // nd4j_printf("Initializing buffers on deviceId [%i]\n", _deviceId); + + auto res = + cudaMalloc(reinterpret_cast(&_reductionPointer), 1024 * 1024 * 8); + if (res != 0) + throw cuda_exception::build("_reductionPointer allocation failed", res); + + res = cudaHostAlloc(reinterpret_cast(&_scalarPointer), 16, + cudaHostAllocDefault); + if (res != 0) + throw cuda_exception::build("_scalarPointer allocation failed", res); + + res = cudaMalloc(reinterpret_cast(&_allocationPointer), + 1024 * 1024 * 8); + if (res != 0) + throw cuda_exception::build("_allocationPointer allocation failed", res); + + _execStream = new cudaStream_t(); + _specialStream = new cudaStream_t(); + if (nullptr == _execStream || nullptr == _specialStream) + throw std::runtime_error("Failed to allocate memory for new CUDA stream"); + + res = cudaStreamCreate(reinterpret_cast(_execStream)); + if (res != 0) + throw cuda_exception::build( + "Failed to create default CUDA stream with launch context", res); + + res = cudaStreamCreate(reinterpret_cast(_specialStream)); + if (res != 0) + throw cuda_exception::build( + "Failed to create special CUDA stream with launch context", res); + + _allocated = true; + _initialized = true; +} - _allocated = true; - _initialized = true; - } +void* ContextBuffers::reductionBuffer() { + if (!_initialized) initialize(); - void* ContextBuffers::reductionBuffer() { - if (!_initialized) - initialize(); + return _reductionPointer; +} - return _reductionPointer; - } +void* ContextBuffers::scalarBuffer() { + if (!_initialized) initialize(); - void* ContextBuffers::scalarBuffer() { - if (!_initialized) - initialize(); + return _scalarPointer; +} - return _scalarPointer; - } +void* ContextBuffers::allocationBuffer() { + if (!_initialized) initialize(); - void* ContextBuffers::allocationBuffer() { - if (!_initialized) - initialize(); + return _allocationPointer; +} - return _allocationPointer; - } +void ContextBuffers::setReductionBuffer(void* pointer) { + _reductionPointer = pointer; +} - void ContextBuffers::setReductionBuffer(void* pointer) { - _reductionPointer = pointer; - } +void ContextBuffers::setScalarBuffer(void* pointer) { + _scalarPointer = pointer; +} - void ContextBuffers::setScalarBuffer(void* pointer) { - _scalarPointer = pointer; - } +void ContextBuffers::setAllocationBuffer(void* pointer) { + _allocationPointer = pointer; +} - void ContextBuffers::setAllocationBuffer(void* pointer) { - _allocationPointer = pointer; - } +void ContextBuffers::triggerOwnership(bool isOwner) { _allocated = isOwner; } - void ContextBuffers::triggerOwnership(bool isOwner) { - _allocated = isOwner; - } - - int ContextBuffers::deviceId() { - return _deviceId; - } - - void* ContextBuffers::execStream() { - if (!_initialized) { - //nd4j_printf("execStream not initialized\n", ""); - initialize(); - } else { - //nd4j_printf("execStream is initialized\n", ""); - } - - return _execStream; - } +int ContextBuffers::deviceId() { return _deviceId; } - void* ContextBuffers::specialStream() { - if (!_initialized) { - //nd4j_printf("specialStream not initialized\n", ""); - initialize(); - } else { - //nd4j_printf("specialStream is initialized\n", ""); - } +void* ContextBuffers::execStream() { + if (!_initialized) { + // nd4j_printf("execStream not initialized\n", ""); + initialize(); + } else { + // nd4j_printf("execStream is initialized\n", ""); + } - return _specialStream; - } + return _execStream; +} - bool ContextBuffers::isInitialized() { - return _initialized; - } +void* ContextBuffers::specialStream() { + if (!_initialized) { + // nd4j_printf("specialStream not initialized\n", ""); + initialize(); + } else { + // nd4j_printf("specialStream is initialized\n", ""); + } - sd::ErrorReference* ContextBuffers::errorReference() { - return &_errorReference; - } + return _specialStream; } +bool ContextBuffers::isInitialized() { return _initialized; } + +sd::ErrorReference* ContextBuffers::errorReference() { + return &_errorReference; +} +} // namespace sd diff --git a/libnd4j/include/execution/cuda/LaunchContext.cu b/libnd4j/include/execution/cuda/LaunchContext.cu index bd51c350445e..3d824f84d8c4 100644 --- a/libnd4j/include/execution/cuda/LaunchContext.cu +++ b/libnd4j/include/execution/cuda/LaunchContext.cu @@ -19,172 +19,170 @@ // @author raver119@gmail.com // -#include -#include #include +#include +#include #include +#include + #include -#include thread_local sd::ContextBuffers contextBuffers = sd::ContextBuffers(); namespace sd { - std::vector> LaunchContext::_contexts = std::vector>(); - std::mutex LaunchContext::_mutex; - MAP_IMPL LaunchContext::_deviceMutexes; +std::vector> LaunchContext::_contexts = + std::vector>(); +std::mutex LaunchContext::_mutex; +MAP_IMPL LaunchContext::_deviceMutexes; //////////////////////////////////////////////////////////////////////// -LaunchContext::LaunchContext(cudaStream_t *cudaStream, cudaStream_t& specialCudaStream, void* reductionPointer, void* scalarPointer, int* allocationPointer) { - - //_cudaStream = cudaStream; - //_cudaSpecialStream = &specialCudaStream; // ideal is = new cudaStream_t; *_cudaSpecialStream = specialCudaStream; - //_reductionPointer = reductionPointer; - //_scalarPointer = scalarPointer; - //_allocationPointer = allocationPointer; - _workspace = nullptr; - _isAllocated = false; +LaunchContext::LaunchContext(cudaStream_t* cudaStream, + cudaStream_t& specialCudaStream, + void* reductionPointer, void* scalarPointer, + int* allocationPointer) { + //_cudaStream = cudaStream; + //_cudaSpecialStream = &specialCudaStream; // ideal is = new cudaStream_t; + //*_cudaSpecialStream = specialCudaStream; _reductionPointer = + // reductionPointer; _scalarPointer = scalarPointer; _allocationPointer = + // allocationPointer; + _workspace = nullptr; + _isAllocated = false; } - std::mutex* LaunchContext::deviceMutex() { - auto deviceId = AffinityManager::currentDeviceId(); - return _deviceMutexes[deviceId]; - } +std::mutex* LaunchContext::deviceMutex() { + auto deviceId = AffinityManager::currentDeviceId(); + return _deviceMutexes[deviceId]; +} LaunchContext::~LaunchContext() { - if (_isAllocated) { - - } + if (_isAllocated) { + } } //////////////////////////////////////////////////////////////////////// LaunchContext::LaunchContext() { - // default constructor, just to make clang/ranlib happy - _workspace = nullptr; - _deviceID = 0; + // default constructor, just to make clang/ranlib happy + _workspace = nullptr; + _deviceID = 0; - _isAllocated = true; + _isAllocated = true; } - LaunchContext::LaunchContext(Nd4jPointer cudaStream, Nd4jPointer reductionPointer, Nd4jPointer scalarPointer, Nd4jPointer allocationPointer) { - _isAllocated = false; - //_cudaStream = reinterpret_cast(cudaStream); - // _cudaSpecialStream = reinterpret_cast(cudaStream); - //_reductionPointer = reductionPointer; - //_scalarPointer = scalarPointer; - //_allocationPointer = reinterpret_cast(allocationPointer); - } +LaunchContext::LaunchContext(Nd4jPointer cudaStream, + Nd4jPointer reductionPointer, + Nd4jPointer scalarPointer, + Nd4jPointer allocationPointer) { + _isAllocated = false; + //_cudaStream = reinterpret_cast(cudaStream); + // _cudaSpecialStream = reinterpret_cast(cudaStream); + //_reductionPointer = reductionPointer; + //_scalarPointer = scalarPointer; + //_allocationPointer = reinterpret_cast(allocationPointer); +} - LaunchContext* LaunchContext::defaultContext() { - /** - * This method returns LaunchContext, that has multiple entities within: - * 1) temporary buffers. they must be per-thread - * 2) CUDA stream. it must be either per-thread or per-device - * 3) cuBLAS handle. it must be per-device - */ - auto deviceId = AffinityManager::currentDeviceId(); - - { - // we need this block synchronous, to avoid double initialization etc - std::lock_guard lock(_mutex); - if (LaunchContext::_contexts.empty()) { - // create one context per device - auto numDevices = AffinityManager::numberOfDevices(); - - _contexts.resize(numDevices); - for (int e = 0; e < numDevices; e++) { - _deviceMutexes[e] = new std::mutex(); - - AffinityManager::setCurrentNativeDevice(e); - - LaunchContext::_contexts[e] = std::make_shared(); - } - - // don't forget to restore device back again - AffinityManager::setCurrentNativeDevice(deviceId); - } - } - - // return context for current device - return LaunchContext::_contexts[deviceId].get(); +LaunchContext* LaunchContext::defaultContext() { + /** + * This method returns LaunchContext, that has multiple entities within: + * 1) temporary buffers. they must be per-thread + * 2) CUDA stream. it must be either per-thread or per-device + * 3) cuBLAS handle. it must be per-device + */ + auto deviceId = AffinityManager::currentDeviceId(); + + {// we need this block synchronous, to avoid double initialization etc + std::lock_guard lock(_mutex); + if (LaunchContext::_contexts.empty()) { + // create one context per device + auto numDevices = AffinityManager::numberOfDevices(); + + _contexts.resize(numDevices); + for (int e = 0; e < numDevices; e++) { + _deviceMutexes[e] = new std::mutex(); + + AffinityManager::setCurrentNativeDevice(e); + + LaunchContext::_contexts[e] = std::make_shared(); } + // don't forget to restore device back again + AffinityManager::setCurrentNativeDevice(deviceId); + } + } - void* LaunchContext::getReductionPointer () const { - return contextBuffers.reductionBuffer(); - }; + // return context for current device + return LaunchContext::_contexts[deviceId].get(); +} - void* LaunchContext::getScalarPointer() const { - return contextBuffers.scalarBuffer(); - }; +void* LaunchContext::getReductionPointer() const { + return contextBuffers.reductionBuffer(); +}; - int* LaunchContext::getAllocationPointer() const { - return reinterpret_cast(contextBuffers.allocationBuffer()); - }; +void* LaunchContext::getScalarPointer() const { + return contextBuffers.scalarBuffer(); +}; - void* LaunchContext::getCublasHandle() const { - return CublasHelper::getInstance().handle(); - }; +int* LaunchContext::getAllocationPointer() const { + return reinterpret_cast(contextBuffers.allocationBuffer()); +}; - void* LaunchContext::getCusolverHandle() const { - return CublasHelper::getInstance().solver(); - }; +void* LaunchContext::getCublasHandle() const { + return CublasHelper::getInstance().handle(); +}; - cudaStream_t* LaunchContext::getCudaStream() const { - return reinterpret_cast(contextBuffers.execStream()); - }; +void* LaunchContext::getCusolverHandle() const { + return CublasHelper::getInstance().solver(); +}; - cudaStream_t* LaunchContext::getCudaSpecialStream() const { - return reinterpret_cast(contextBuffers.specialStream());; - }; +cudaStream_t* LaunchContext::getCudaStream() const { + return reinterpret_cast(contextBuffers.execStream()); +}; +cudaStream_t* LaunchContext::getCudaSpecialStream() const { + return reinterpret_cast(contextBuffers.specialStream()); + ; +}; - void LaunchContext::setReductionPointer (void* reductionPointer) { - contextBuffers.setReductionBuffer(reductionPointer); - }; +void LaunchContext::setReductionPointer(void* reductionPointer) { + contextBuffers.setReductionBuffer(reductionPointer); +}; - void LaunchContext::setScalarPointer(void* scalarPointer) { - contextBuffers.setScalarBuffer(scalarPointer); - }; +void LaunchContext::setScalarPointer(void* scalarPointer) { + contextBuffers.setScalarBuffer(scalarPointer); +}; - void LaunchContext::setAllocationPointer(int* allocationPointer) { - contextBuffers.setAllocationBuffer(allocationPointer); - }; +void LaunchContext::setAllocationPointer(int* allocationPointer) { + contextBuffers.setAllocationBuffer(allocationPointer); +}; - void LaunchContext::setCudaStream(cudaStream_t* cudaStream) { - //_cudaStream = cudaStream; - }; +void LaunchContext::setCudaStream(cudaStream_t* cudaStream){ + //_cudaStream = cudaStream; +}; - void LaunchContext::setCudaSpecialStream(cudaStream_t* cudaStream) { - //_cudaSpecialStream = cudaStream; - }; +void LaunchContext::setCudaSpecialStream(cudaStream_t* cudaStream){ + //_cudaSpecialStream = cudaStream; +}; - void LaunchContext::setCublasHandle(void *handle) { - _cublasHandle = handle; - }; +void LaunchContext::setCublasHandle(void* handle) { _cublasHandle = handle; }; - void LaunchContext::swapContextBuffers(ContextBuffers &buffers) { - contextBuffers = buffers; - }; +void LaunchContext::swapContextBuffers(ContextBuffers& buffers) { + contextBuffers = buffers; +}; - void LaunchContext::releaseBuffers() { - //nd4j_printf("LaunchContext::releaseBuffers() was invoked\n", ""); - contextBuffers.release(); - } +void LaunchContext::releaseBuffers() { + // nd4j_printf("LaunchContext::releaseBuffers() was invoked\n", ""); + contextBuffers.release(); +} - bool LaunchContext::isInitialized() { - return contextBuffers.isInitialized(); - } +bool LaunchContext::isInitialized() { return contextBuffers.isInitialized(); } - void* LaunchContext::getCuDnnHandle() const { - return CublasHelper::getInstance().cudnn(); - } +void* LaunchContext::getCuDnnHandle() const { + return CublasHelper::getInstance().cudnn(); +} - sd::ErrorReference* LaunchContext::errorReference() { - return contextBuffers.errorReference(); - } +sd::ErrorReference* LaunchContext::errorReference() { + return contextBuffers.errorReference(); +} - void* LaunchContext::engine() { - return _engine; - } -} \ No newline at end of file +void* LaunchContext::engine() { return _engine; } +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/execution/impl/BlockingQueue.cpp b/libnd4j/include/execution/impl/BlockingQueue.cpp index 21c3b4c6a5f4..80f5f9a58280 100644 --- a/libnd4j/include/execution/impl/BlockingQueue.cpp +++ b/libnd4j/include/execution/impl/BlockingQueue.cpp @@ -20,54 +20,55 @@ #include #include + #include namespace samediff { - template - BlockingQueue::BlockingQueue(int queueSize) { - _size = 0; - _available = true; - } - - template - T BlockingQueue::poll() { - // locking untill there's something within queue - std::unique_lock lock(_lock); - _condition.wait(lock, [&]{ return this->_size.load() != 0; }); +template +BlockingQueue::BlockingQueue(int queueSize) { + _size = 0; + _available = true; +} - T t(std::move(_queue.front())); - _queue.pop(); - _size--; - return t; - } +template +T BlockingQueue::poll() { + // locking untill there's something within queue + std::unique_lock lock(_lock); + _condition.wait(lock, [&] { return this->_size.load() != 0; }); - template - void BlockingQueue::put(const T &t) { - { - // locking before push, unlocking after - std::unique_lock lock(_lock); - _queue.push(t); - _size++; - } + T t(std::move(_queue.front())); + _queue.pop(); + _size--; + return t; +} - // notifying condition - _condition.notify_one(); - } +template +void BlockingQueue::put(const T &t) { + { + // locking before push, unlocking after + std::unique_lock lock(_lock); + _queue.push(t); + _size++; + } - template - bool BlockingQueue::available() { - return _available.load(); - } + // notifying condition + _condition.notify_one(); +} - template - void BlockingQueue::markAvailable() { - _available = true; - } +template +bool BlockingQueue::available() { + return _available.load(); +} - template - void BlockingQueue::markUnavailable() { - _available = false; - } +template +void BlockingQueue::markAvailable() { + _available = true; +} - template class BlockingQueue; +template +void BlockingQueue::markUnavailable() { + _available = false; } + +template class BlockingQueue; +} // namespace samediff diff --git a/libnd4j/include/execution/impl/CallableInterface.cpp b/libnd4j/include/execution/impl/CallableInterface.cpp index a719af848576..8e8c121ac7ab 100644 --- a/libnd4j/include/execution/impl/CallableInterface.cpp +++ b/libnd4j/include/execution/impl/CallableInterface.cpp @@ -22,192 +22,201 @@ #include namespace samediff { - CallableInterface::CallableInterface() { - // initial state is available - _available = true; - _filled = false; - _finished = false; - } - - bool CallableInterface::available() { - return _available.load(); - } - - void CallableInterface::markUnavailable() { - _available = false; - } - - void CallableInterface::markAvailable() { - _available = true; - } - - void CallableInterface::fill(int threadID, int numThreads, FUNC_DO func) { - _function_do = std::move(func); - - _branch = 0; - _num_threads = numThreads; - _thread_id = threadID; - _finished = false; - { - std::unique_lock l(_ms); - _filled = true; - } - _starter.notify_one(); - } - - void CallableInterface::fill(int threadID, int numThreads, FUNC_1D func, int64_t startX, int64_t stopX, int64_t incX) { - _function_1d = std::move(func); - _arguments[0] = startX; - _arguments[1] = stopX; - _arguments[2] = incX; - - _branch = 1; - _num_threads = numThreads; - _thread_id = threadID; - _finished = false; - - { - std::unique_lock l(_ms); - _filled = true; - } - _starter.notify_one(); - } - - void CallableInterface::fill(int threadID, int numThreads, FUNC_2D func, int64_t startX, int64_t stopX, int64_t incX, int64_t start_y, int64_t stop_y, int64_t inc_y) { - _function_2d = std::move(func); - _arguments[0] = startX; - _arguments[1] = stopX; - _arguments[2] = incX; - _arguments[3] = start_y; - _arguments[4] = stop_y; - _arguments[5] = inc_y; - - _branch = 2; - _num_threads = numThreads; - _thread_id = threadID; - _finished = false; - - { - std::unique_lock l(_ms); - _filled = true; - } - _starter.notify_one(); - } - - void CallableInterface::fill(int threadID, int numThreads, FUNC_3D func, int64_t startX, int64_t stopX, int64_t incX, int64_t start_y, int64_t stop_y, int64_t inc_y, int64_t start_z, int64_t stop_z, int64_t inc_z) { - _function_3d = std::move(func); - _arguments[0] = startX; - _arguments[1] = stopX; - _arguments[2] = incX; - _arguments[3] = start_y; - _arguments[4] = stop_y; - _arguments[5] = inc_y; - _arguments[6] = start_z; - _arguments[7] = stop_z; - _arguments[8] = inc_z; - - _branch = 3; - _num_threads = numThreads; - _thread_id = threadID; - _finished = false; - - { - std::unique_lock l(_ms); - _filled = true; - } - _starter.notify_one(); - } - - void CallableInterface::fill(int threadID, int numThreads, int64_t *lptr, FUNC_RL func, int64_t startX, int64_t stopX, int64_t incX) { - _function_rl = std::move(func); - _arguments[0] = startX; - _arguments[1] = stopX; - _arguments[2] = incX; - - _lptr = lptr; - - _branch = 4; - _num_threads = numThreads; - _thread_id = threadID; - _finished = false; - - { - std::unique_lock l(_ms); - _filled = true; - } - _starter.notify_one(); - } - - void CallableInterface::fill(int threadID, int numThreads, double *dptr, FUNC_RD func, int64_t startX, int64_t stopX, int64_t incX) { - _function_rd = std::move(func); - _arguments[0] = startX; - _arguments[1] = stopX; - _arguments[2] = incX; - - _dptr = dptr; - - _branch = 5; - _num_threads = numThreads; - _thread_id = threadID; - _finished = false; - - { - std::unique_lock l(_ms); - _filled = true; - } - _starter.notify_one(); - } - - void CallableInterface::waitForTask() { - // block until task is available - std::unique_lock lock(_ms); - _starter.wait(lock, [&]{ return _filled.load(); }); - } - - void CallableInterface::waitForCompletion() { - //while (!_finished.load()); - - // block until finished - std::unique_lock lock(_mf); - _finisher.wait(lock, [&] { return _finished.load(); }); - } - - void CallableInterface::finish() { - // mark as finished - { - std::unique_lock l(_mf); - _finished.store(true); - } - _finisher.notify_one(); - } - - void CallableInterface::execute() { - // mark it as consumed - _filled = false; - - // actually executing op - switch (_branch) { - case 0: - _function_do(_thread_id, _num_threads); - break; - case 1: - _function_1d(_thread_id, _arguments[0], _arguments[1], _arguments[2]); - break; - case 2: - _function_2d(_thread_id, _arguments[0], _arguments[1], _arguments[2], _arguments[3], _arguments[4], _arguments[5]); - break; - case 3: - _function_3d(_thread_id, _arguments[0], _arguments[1], _arguments[2], _arguments[3], _arguments[4], _arguments[5], _arguments[6], _arguments[7], _arguments[8]); - break; - case 4: - _lptr[0] = _function_rl(_thread_id, _arguments[0], _arguments[1], _arguments[2]); - break; - case 5: - _dptr[0] = _function_rd(_thread_id, _arguments[0], _arguments[1], _arguments[2]); - break; - } - - // notify that thread finished the job - this->finish(); - } -} \ No newline at end of file +CallableInterface::CallableInterface() { + // initial state is available + _available = true; + _filled = false; + _finished = false; +} + +bool CallableInterface::available() { return _available.load(); } + +void CallableInterface::markUnavailable() { _available = false; } + +void CallableInterface::markAvailable() { _available = true; } + +void CallableInterface::fill(int threadID, int numThreads, FUNC_DO func) { + _function_do = std::move(func); + + _branch = 0; + _num_threads = numThreads; + _thread_id = threadID; + _finished = false; + { + std::unique_lock l(_ms); + _filled = true; + } + _starter.notify_one(); +} + +void CallableInterface::fill(int threadID, int numThreads, FUNC_1D func, + int64_t startX, int64_t stopX, int64_t incX) { + _function_1d = std::move(func); + _arguments[0] = startX; + _arguments[1] = stopX; + _arguments[2] = incX; + + _branch = 1; + _num_threads = numThreads; + _thread_id = threadID; + _finished = false; + + { + std::unique_lock l(_ms); + _filled = true; + } + _starter.notify_one(); +} + +void CallableInterface::fill(int threadID, int numThreads, FUNC_2D func, + int64_t startX, int64_t stopX, int64_t incX, + int64_t start_y, int64_t stop_y, int64_t inc_y) { + _function_2d = std::move(func); + _arguments[0] = startX; + _arguments[1] = stopX; + _arguments[2] = incX; + _arguments[3] = start_y; + _arguments[4] = stop_y; + _arguments[5] = inc_y; + + _branch = 2; + _num_threads = numThreads; + _thread_id = threadID; + _finished = false; + + { + std::unique_lock l(_ms); + _filled = true; + } + _starter.notify_one(); +} + +void CallableInterface::fill(int threadID, int numThreads, FUNC_3D func, + int64_t startX, int64_t stopX, int64_t incX, + int64_t start_y, int64_t stop_y, int64_t inc_y, + int64_t start_z, int64_t stop_z, int64_t inc_z) { + _function_3d = std::move(func); + _arguments[0] = startX; + _arguments[1] = stopX; + _arguments[2] = incX; + _arguments[3] = start_y; + _arguments[4] = stop_y; + _arguments[5] = inc_y; + _arguments[6] = start_z; + _arguments[7] = stop_z; + _arguments[8] = inc_z; + + _branch = 3; + _num_threads = numThreads; + _thread_id = threadID; + _finished = false; + + { + std::unique_lock l(_ms); + _filled = true; + } + _starter.notify_one(); +} + +void CallableInterface::fill(int threadID, int numThreads, int64_t *lptr, + FUNC_RL func, int64_t startX, int64_t stopX, + int64_t incX) { + _function_rl = std::move(func); + _arguments[0] = startX; + _arguments[1] = stopX; + _arguments[2] = incX; + + _lptr = lptr; + + _branch = 4; + _num_threads = numThreads; + _thread_id = threadID; + _finished = false; + + { + std::unique_lock l(_ms); + _filled = true; + } + _starter.notify_one(); +} + +void CallableInterface::fill(int threadID, int numThreads, double *dptr, + FUNC_RD func, int64_t startX, int64_t stopX, + int64_t incX) { + _function_rd = std::move(func); + _arguments[0] = startX; + _arguments[1] = stopX; + _arguments[2] = incX; + + _dptr = dptr; + + _branch = 5; + _num_threads = numThreads; + _thread_id = threadID; + _finished = false; + + { + std::unique_lock l(_ms); + _filled = true; + } + _starter.notify_one(); +} + +void CallableInterface::waitForTask() { + // block until task is available + std::unique_lock lock(_ms); + _starter.wait(lock, [&] { return _filled.load(); }); +} + +void CallableInterface::waitForCompletion() { + // while (!_finished.load()); + + // block until finished + std::unique_lock lock(_mf); + _finisher.wait(lock, [&] { return _finished.load(); }); +} + +void CallableInterface::finish() { + // mark as finished + { + std::unique_lock l(_mf); + _finished.store(true); + } + _finisher.notify_one(); +} + +void CallableInterface::execute() { + // mark it as consumed + _filled = false; + + // actually executing op + switch (_branch) { + case 0: + _function_do(_thread_id, _num_threads); + break; + case 1: + _function_1d(_thread_id, _arguments[0], _arguments[1], _arguments[2]); + break; + case 2: + _function_2d(_thread_id, _arguments[0], _arguments[1], _arguments[2], + _arguments[3], _arguments[4], _arguments[5]); + break; + case 3: + _function_3d(_thread_id, _arguments[0], _arguments[1], _arguments[2], + _arguments[3], _arguments[4], _arguments[5], _arguments[6], + _arguments[7], _arguments[8]); + break; + case 4: + _lptr[0] = + _function_rl(_thread_id, _arguments[0], _arguments[1], _arguments[2]); + break; + case 5: + _dptr[0] = + _function_rd(_thread_id, _arguments[0], _arguments[1], _arguments[2]); + break; + } + + // notify that thread finished the job + this->finish(); +} +} // namespace samediff \ No newline at end of file diff --git a/libnd4j/include/execution/impl/CallableWithArguments.cpp b/libnd4j/include/execution/impl/CallableWithArguments.cpp index 8f17622b733b..a05fcf1f8eba 100644 --- a/libnd4j/include/execution/impl/CallableWithArguments.cpp +++ b/libnd4j/include/execution/impl/CallableWithArguments.cpp @@ -21,83 +21,75 @@ #include namespace samediff { - CallableWithArguments::CallableWithArguments(FUNC_DO func, uint64_t thread_id, uint64_t numThreads) { - _function_do = func; - _finished = false; - _threadId = thread_id; - _numThreads = numThreads; - _dimensions = 0; - } - - CallableWithArguments::CallableWithArguments(FUNC_3D func, uint64_t thread_id, int64_t start_x, int64_t stop_x, int64_t increment_x, int64_t start_y, int64_t stop_y, int64_t increment_y, int64_t start_z, int64_t stop_z, int64_t increment_z) { - _function_3d = func; - _arguments = {start_x, stop_x, increment_x, start_y, stop_y, increment_y, start_z, stop_z, increment_z}; - _finished = false; - _threadId = thread_id; - _dimensions = 3; - } - - CallableWithArguments::CallableWithArguments(FUNC_1D func, uint64_t thread_id, int64_t start_x, int64_t stop_x, int64_t increment_x) { - _function_1d = func; - _arguments = {start_x, stop_x, increment_x}; - _finished = false; - _threadId = thread_id; - _dimensions = 1; - } - - CallableWithArguments::CallableWithArguments(FUNC_2D func, uint64_t thread_id, int64_t start_x, int64_t stop_x, int64_t increment_x, int64_t start_y, int64_t stop_y, int64_t increment_y) { - _function_2d = func; - _arguments = {start_x, stop_x, increment_x, start_y, stop_y, increment_y}; - _finished = false; - _threadId = thread_id; - _dimensions = 2; - } - - int CallableWithArguments::dimensions() { - return _dimensions; - } - - std::vector& CallableWithArguments::arguments() { - return _arguments; - } - - bool CallableWithArguments::finished() { - return _finished.load(); - } - - void CallableWithArguments::finish() { - std::lock_guard lock(_lock); - _finished = true; - _condition.notify_one(); - } - - void CallableWithArguments::waitUntilFinished() { - std::unique_lock lock(_lock); - _condition.wait(lock, [&]{ return _finished.load(); }); - } - - - FUNC_1D CallableWithArguments::function_1d() { - return _function_1d; - } - - FUNC_2D CallableWithArguments::function_2d() { - return _function_2d; - } - - FUNC_DO CallableWithArguments::function_do() { - return _function_do; - } - - FUNC_3D CallableWithArguments::function_3d() { - return _function_3d; - } - - uint64_t CallableWithArguments::threadId() { - return _threadId; - } - - uint64_t CallableWithArguments::numThreads() { - return _numThreads; - } -} \ No newline at end of file +CallableWithArguments::CallableWithArguments(FUNC_DO func, uint64_t thread_id, + uint64_t numThreads) { + _function_do = func; + _finished = false; + _threadId = thread_id; + _numThreads = numThreads; + _dimensions = 0; +} + +CallableWithArguments::CallableWithArguments( + FUNC_3D func, uint64_t thread_id, int64_t start_x, int64_t stop_x, + int64_t increment_x, int64_t start_y, int64_t stop_y, int64_t increment_y, + int64_t start_z, int64_t stop_z, int64_t increment_z) { + _function_3d = func; + _arguments = {start_x, stop_x, increment_x, start_y, stop_y, + increment_y, start_z, stop_z, increment_z}; + _finished = false; + _threadId = thread_id; + _dimensions = 3; +} + +CallableWithArguments::CallableWithArguments(FUNC_1D func, uint64_t thread_id, + int64_t start_x, int64_t stop_x, + int64_t increment_x) { + _function_1d = func; + _arguments = {start_x, stop_x, increment_x}; + _finished = false; + _threadId = thread_id; + _dimensions = 1; +} + +CallableWithArguments::CallableWithArguments(FUNC_2D func, uint64_t thread_id, + int64_t start_x, int64_t stop_x, + int64_t increment_x, + int64_t start_y, int64_t stop_y, + int64_t increment_y) { + _function_2d = func; + _arguments = {start_x, stop_x, increment_x, start_y, stop_y, increment_y}; + _finished = false; + _threadId = thread_id; + _dimensions = 2; +} + +int CallableWithArguments::dimensions() { return _dimensions; } + +std::vector& CallableWithArguments::arguments() { return _arguments; } + +bool CallableWithArguments::finished() { return _finished.load(); } + +void CallableWithArguments::finish() { + std::lock_guard lock(_lock); + _finished = true; + _condition.notify_one(); +} + +void CallableWithArguments::waitUntilFinished() { + std::unique_lock lock(_lock); + _condition.wait(lock, [&] { return _finished.load(); }); +} + +FUNC_1D CallableWithArguments::function_1d() { return _function_1d; } + +FUNC_2D CallableWithArguments::function_2d() { return _function_2d; } + +FUNC_DO CallableWithArguments::function_do() { return _function_do; } + +FUNC_3D CallableWithArguments::function_3d() { return _function_3d; } + +uint64_t CallableWithArguments::threadId() { return _threadId; } + +uint64_t CallableWithArguments::numThreads() { return _numThreads; } +} // namespace samediff \ No newline at end of file diff --git a/libnd4j/include/execution/impl/ErrorReference.cpp b/libnd4j/include/execution/impl/ErrorReference.cpp index 7b3409aa1339..307ed874c602 100644 --- a/libnd4j/include/execution/impl/ErrorReference.cpp +++ b/libnd4j/include/execution/impl/ErrorReference.cpp @@ -18,29 +18,25 @@ // @author raver119@gmail.com // - #include namespace sd { - int ErrorReference::errorCode() { - return _errorCode; - } +int ErrorReference::errorCode() { return _errorCode; } - const char* ErrorReference::errorMessage() { - // since we're fetching error message - error code will be assumed consumed & nullified - _errorCode = 0; - return _errorMessage.c_str(); - } +const char* ErrorReference::errorMessage() { + // since we're fetching error message - error code will be assumed consumed & + // nullified + _errorCode = 0; + return _errorMessage.c_str(); +} - void ErrorReference::setErrorCode(int errorCode) { - _errorCode = errorCode; - } +void ErrorReference::setErrorCode(int errorCode) { _errorCode = errorCode; } - void ErrorReference::setErrorMessage(std::string message) { - _errorMessage = message; - } +void ErrorReference::setErrorMessage(std::string message) { + _errorMessage = message; +} - void ErrorReference::setErrorMessage(const char* message) { - _errorMessage = std::string(message); - } +void ErrorReference::setErrorMessage(const char* message) { + _errorMessage = std::string(message); } +} // namespace sd diff --git a/libnd4j/include/execution/impl/ThreadPool.cpp b/libnd4j/include/execution/impl/ThreadPool.cpp index f6c3fdaca9f9..4557b496c804 100644 --- a/libnd4j/include/execution/impl/ThreadPool.cpp +++ b/libnd4j/include/execution/impl/ThreadPool.cpp @@ -19,176 +19,180 @@ // #include -#include #include +#include + #if defined(_WIN32) || defined(_WIN64) //#include #endif namespace samediff { - // this function executed once per thread, it polls functions from queue, and executes them via wrapper - static void executionLoop_(int thread_id, BlockingQueue *queue) { - while (true) { - // this method blocks until there's something within queue - auto c = queue->poll(); - //nd4j_printf("ThreadPool: starting thread %i\n", c->threadId()); - switch (c->dimensions()) { - case 0: { - c->function_do()(c->threadId(), c->numThreads()); - c->finish(); - } - break; - case 1: { - auto args = c->arguments(); - c->function_1d()(c->threadId(), args[0], args[1], args[2]); - c->finish(); - } - break; - case 2: { - auto args = c->arguments(); - c->function_2d()(c->threadId(), args[0], args[1], args[2], args[3], args[4], args[5]); - c->finish(); - //nd4j_printf("ThreadPool: finished thread %i\n", c->threadId()); - } - break; - case 3: { - auto args = c->arguments(); - c->function_3d()(c->threadId(), args[0], args[1], args[2], args[3], args[4], args[5], args[6], args[7], args[8]); - c->finish(); - } - break; - default: - throw std::runtime_error("Don't know what to do with provided Callable"); - } - } +// this function executed once per thread, it polls functions from queue, and +// executes them via wrapper +static void executionLoop_(int thread_id, + BlockingQueue *queue) { + while (true) { + // this method blocks until there's something within queue + auto c = queue->poll(); + // nd4j_printf("ThreadPool: starting thread %i\n", c->threadId()); + switch (c->dimensions()) { + case 0: { + c->function_do()(c->threadId(), c->numThreads()); + c->finish(); + } break; + case 1: { + auto args = c->arguments(); + c->function_1d()(c->threadId(), args[0], args[1], args[2]); + c->finish(); + } break; + case 2: { + auto args = c->arguments(); + c->function_2d()(c->threadId(), args[0], args[1], args[2], args[3], + args[4], args[5]); + c->finish(); + // nd4j_printf("ThreadPool: finished thread %i\n", c->threadId()); + } break; + case 3: { + auto args = c->arguments(); + c->function_3d()(c->threadId(), args[0], args[1], args[2], args[3], + args[4], args[5], args[6], args[7], args[8]); + c->finish(); + } break; + default: + throw std::runtime_error( + "Don't know what to do with provided Callable"); } + } +} - static void executionLoopWithInterface_(int thread_id, CallableInterface *c) { - while (true) { - // blocking here until there's something to do - c->waitForTask(); +static void executionLoopWithInterface_(int thread_id, CallableInterface *c) { + while (true) { + // blocking here until there's something to do + c->waitForTask(); - // execute whatever we have - c->execute(); - } - } + // execute whatever we have + c->execute(); + } +} - ThreadPool::ThreadPool() { - // TODO: number of threads must reflect number of cores for UMA system. In case of NUMA it should be per-device pool - // FIXME: on mobile phones this feature must NOT be used - _available = sd::Environment::getInstance().maxThreads(); - - _queues.resize(_available.load()); - _threads.resize(_available.load()); - _interfaces.resize(_available.load()); - - // creating threads here - for (int e = 0; e < _available.load(); e++) { - _queues[e] = new BlockingQueue(2); - _interfaces[e] = new CallableInterface(); - _threads[e] = std::thread(executionLoopWithInterface_, e, _interfaces[e]); - _tickets.push(new Ticket()); - // _threads[e] = new std::thread(executionLoop_, e, _queues[e]); - - // TODO: add other platforms here as well - // now we must set affinity, and it's going to be platform-specific thing +ThreadPool::ThreadPool() { + // TODO: number of threads must reflect number of cores for UMA system. In + // case of NUMA it should be per-device pool + // FIXME: on mobile phones this feature must NOT be used + _available = sd::Environment::getInstance().maxThreads(); + + _queues.resize(_available.load()); + _threads.resize(_available.load()); + _interfaces.resize(_available.load()); + + // creating threads here + for (int e = 0; e < _available.load(); e++) { + _queues[e] = new BlockingQueue(2); + _interfaces[e] = new CallableInterface(); + _threads[e] = + std::thread(executionLoopWithInterface_, e, _interfaces[e]); + _tickets.push(new Ticket()); + // _threads[e] = new std::thread(executionLoop_, e, _queues[e]); + + // TODO: add other platforms here as well + // now we must set affinity, and it's going to be platform-specific thing #ifdef LINUX_BUILD - cpu_set_t cpuset; - CPU_ZERO(&cpuset); - CPU_SET(e, &cpuset); - int rc = pthread_setaffinity_np(_threads[e]->native_handle(), sizeof(cpu_set_t), &cpuset); - if (rc != 0) - throw std::runtime_error("Failed to set pthread affinity"); + cpu_set_t cpuset; + CPU_ZERO(&cpuset); + CPU_SET(e, &cpuset); + int rc = pthread_setaffinity_np(_threads[e]->native_handle(), + sizeof(cpu_set_t), &cpuset); + if (rc != 0) throw std::runtime_error("Failed to set pthread affinity"); #endif - /* + /* #if defined(_WIN32) || defined(_WIN64) - // we can't set affinity to more than 64 cores - if (e <= 64) { - auto mask = (static_cast(1) << e); - auto result = SetThreadAffinityMask(_threads[e]->native_handle(), mask); - if (!result) - throw std::runtime_error("Failed to set pthread affinity"); - } - - // that's fine. no need for time_critical here - SetThreadPriority(_threads[e]->native_handle(), THREAD_PRIORITY_HIGHEST); -#endif - */ - } + // we can't set affinity to more than 64 cores + if (e <= 64) { + auto mask = (static_cast(1) << e); + auto result = SetThreadAffinityMask(_threads[e]->native_handle(), mask); + if (!result) + throw std::runtime_error("Failed to set pthread affinity"); } - ThreadPool::~ThreadPool() { - // TODO: implement this one properly - for (int e = 0; e < _queues.size(); e++) { - // stop each and every thread + // that's fine. no need for time_critical here + SetThreadPriority(_threads[e]->native_handle(), THREAD_PRIORITY_HIGHEST); +#endif + */ + } +} - // release queue and thread - delete _queues[e]; - _threads[e].detach(); - //delete _interfaces[e]; - } +ThreadPool::~ThreadPool() { + // TODO: implement this one properly + for (int e = 0; e < _queues.size(); e++) { + // stop each and every thread - while (!_tickets.empty()) { + // release queue and thread + delete _queues[e]; + _threads[e].detach();// delete _interfaces[e]; + } +while (!_tickets.empty()) { auto t = _tickets.front(); _tickets.pop(); delete t; } - } - - ThreadPool& ThreadPool::getInstance() { - static ThreadPool instance; - return instance; - } +} - void ThreadPool::release(int numThreads) { - _available += numThreads; - } +ThreadPool &ThreadPool::getInstance() { + static ThreadPool instance; - Ticket* ThreadPool::tryAcquire(int numThreads) { - //std::vector*> queues; - - Ticket *t = nullptr; - // we check for threads availability first - bool threaded = false; - { - // we lock before checking availability - std::unique_lock lock(_lock); - if (_available >= numThreads) { - threaded = true; - _available -= numThreads; - - // getting a ticket from the queue - t = _tickets.front(); - _tickets.pop(); - - // ticket must contain information about number of threads for the current session - t->acquiredThreads(numThreads); - - // filling ticket with executable interfaces - for (int e = 0, i = 0; e < _queues.size() && i < numThreads; e++) { - if (_interfaces[e]->available()) { - t->attach(i++, _interfaces[e]); - _interfaces[e]->markUnavailable(); - } - } - } - } + return instance; +} - // we either dispatch tasks to threads, or run single-threaded - if (threaded) { - return t; - } else { - // if there's no threads available - return nullptr - return nullptr; +void ThreadPool::release(int numThreads) { _available += numThreads; } + +Ticket *ThreadPool::tryAcquire(int numThreads) { + // std::vector*> queues; + + Ticket *t = nullptr; + // we check for threads availability first + bool threaded = false; + { + // we lock before checking availability + std::unique_lock lock(_lock); + if (_available >= numThreads) { + threaded = true; + _available -= numThreads; + + // getting a ticket from the queue + t = _tickets.front(); + _tickets.pop(); + + // ticket must contain information about number of threads for the current + // session + t->acquiredThreads(numThreads); + + // filling ticket with executable interfaces + for (int e = 0, i = 0; e < _queues.size() && i < numThreads; e++) { + if (_interfaces[e]->available()) { + t->attach(i++, _interfaces[e]); + _interfaces[e]->markUnavailable(); } + } } + } + + // we either dispatch tasks to threads, or run single-threaded + if (threaded) { + return t; + } else { + // if there's no threads available - return nullptr + return nullptr; + } +} - void ThreadPool::release(samediff::Ticket *ticket) { - // returning ticket back to the queue - std::unique_lock lock(_lock); - _tickets.push(ticket); - } +void ThreadPool::release(samediff::Ticket *ticket) { + // returning ticket back to the queue + std::unique_lock lock(_lock); + _tickets.push(ticket); } + + +} // namespace samediff diff --git a/libnd4j/include/execution/impl/Threads.cpp b/libnd4j/include/execution/impl/Threads.cpp index 90dd519b11b9..06aa2e1dc937 100644 --- a/libnd4j/include/execution/impl/Threads.cpp +++ b/libnd4j/include/execution/impl/Threads.cpp @@ -17,707 +17,685 @@ // // @author raver119@gmail.com // -#include #include -#include -#include +#include #include -#include #include +#include +#include +#include namespace samediff { - int ThreadsHelper::numberOfThreads(int maxThreads, uint64_t numberOfElements) { - // let's see how many threads we actually need first - auto optimalThreads = sd::math::nd4j_max(1, numberOfElements / 1024); - - // now return the smallest value - return sd::math::nd4j_min(optimalThreads, maxThreads); - } - - Span3::Span3(int64_t startX, int64_t stopX, int64_t incX, int64_t startY, int64_t stopY, int64_t incY, int64_t startZ, int64_t stopZ, int64_t incZ) { - _startX = startX; - _startY = startY; - _startZ = startZ; - _stopX = stopX; - _stopY = stopY; - _stopZ = stopZ; - _incX = incX; - _incY = incY; - _incZ = incZ; - } - - Span3 Span3::build(int loop, uint64_t threadID, uint64_t numThreads, int64_t startX, int64_t stopX, int64_t incX, int64_t startY, int64_t stopY, int64_t incY, int64_t startZ, int64_t stopZ, int64_t incZ) { - switch (loop) { - case 1: { - auto span = (stopX - startX) / numThreads; - auto s = span * threadID; - auto e = s + span; - if (threadID == numThreads - 1) - e = stopX; - - return Span3(s, e, incX, startY, stopY, incY, startZ, stopZ, incZ); - } - break; - case 2: { - auto span = (stopY - startY) / numThreads; - auto s = span * threadID; - auto e = s + span; - if (threadID == numThreads - 1) - e = stopY; - - return Span3(startX, stopX, incX, s, e, incY, startZ, stopZ, incZ); - } - break; - case 3: { - auto span = (stopZ - startZ) / numThreads; - auto s = span * threadID; - auto e = s + span; - if (threadID == numThreads - 1) - e = stopZ; - - return Span3(startX, stopX, incX, startY, stopY, incY, s, e, incZ); - } - break; - default: - throw std::runtime_error(""); - } - return Span3(startX, stopX, incX, startY, stopY, incY, startZ, stopZ, incZ); - } - - Span::Span(int64_t startX, int64_t stopX, int64_t incX) { - _startX = startX; - _stopX = stopX; - _incX = incX; - } - - Span Span::build(uint64_t threadID, uint64_t numThreads, int64_t startX, int64_t stopX, int64_t incX) { - auto span = (stopX - startX) / numThreads; - auto s = span * threadID; - auto e = s + span; - if (threadID == numThreads - 1) - e = stopX; - - return Span(s, e, incX); - } - - Span2::Span2(int64_t startX, int64_t stopX, int64_t incX, int64_t startY, int64_t stopY, int64_t incY) { - _startX = startX; - _startY = startY; - _stopX = stopX; - _stopY = stopY; - _incX = incX; - _incY = incY; - } - - - Span2 Span2::build(int loop, uint64_t threadID, uint64_t numThreads, int64_t startX, int64_t stopX, int64_t incX, int64_t startY, int64_t stopY, int64_t incY) { - - switch (loop) { - case 1: { - auto span = (stopX - startX) / numThreads; - auto s = span * threadID; - auto e = s + span; - if (threadID == numThreads - 1) - e = stopX; - - return Span2(s, e, incX, startY, stopY, incY); - } - break; - case 2: { - auto span = (stopY - startY) / numThreads; - auto s = span * threadID; - auto e = s + span; - if (threadID == numThreads - 1) - e = stopY; - - return Span2(startX, stopX, incX, s, e, incY); - } - break; - default: - throw std::runtime_error(""); - } - } - - int64_t Span::startX() const { - return _startX; - } - - int64_t Span::stopX() const { - return _stopX; - } - - int64_t Span::incX() const { - return _incX; - } - - int64_t Span2::startX() const { - return _startX; - } - - int64_t Span2::startY() const { - return _startY; - } - - int64_t Span2::stopX() const { - return _stopX; - } - - int64_t Span2::stopY() const { - return _stopY; - } - - int64_t Span2::incX() const { - return _incX; - } - - int64_t Span2::incY() const { - return _incY; - } - - int64_t Span3::startX() const { - return _startX; - } - - int64_t Span3::startY() const { - return _startY; - } - - int64_t Span3::startZ() const { - return _startZ; - } - - int64_t Span3::stopX() const { - return _stopX; - } - - int64_t Span3::stopY() const { - return _stopY; - } - - int64_t Span3::stopZ() const { - return _stopZ; - } - - int64_t Span3::incX() const { - return _incX; - } - - int64_t Span3::incY() const { - return _incY; - } - - int64_t Span3::incZ() const { - return _incZ; - } - - int ThreadsHelper::pickLoop2d(int numThreads, uint64_t itersX, uint64_t itersY) { - // if one of dimensions is definitely too small - we just pick the other one - if (itersX < numThreads && itersY >= numThreads) - return 2; - if (itersY < numThreads && itersX >= numThreads) - return 1; - - // next step - we pick the most balanced dimension - auto remX = itersX % numThreads; - auto remY = itersY % numThreads; - auto splitY = itersY / numThreads; - - // if there's no remainder left in some dimension - we're picking that dimension, because it'll be the most balanced work distribution - if (remX == 0) - return 1; - if (remY == 0) - return 2; - - // if there's no loop without a remainder - we're picking one with smaller remainder - if (remX < remY) - return 1; - if (remY < remX && splitY >= 64) // we don't want too small splits over last dimension, or vectorization will fail - return 2; - // if loops are equally sized - give the preference to the first thread - return 1; - } - - - static int threads_(int maxThreads, uint64_t elements) { - - if (elements == maxThreads) { - return maxThreads; - } - else if (elements > maxThreads) { - // if we have full load across thread, or at least half of threads can be utilized - auto rem = elements % maxThreads; - if (rem == 0 || rem >= maxThreads / 3) - return maxThreads; - else - return threads_(maxThreads - 1, elements); - - } - else if (elements < maxThreads) { - return elements; - } - - return 1; - } - - int ThreadsHelper::numberOfThreads2d(int maxThreads, uint64_t iters_x, uint64_t iters_y) { - // in some cases there's nothing to think about, part 1 - if (iters_x < maxThreads && iters_y < maxThreads) - return sd::math::nd4j_max(iters_x, iters_y); - - auto remX = iters_x % maxThreads; - auto remY = iters_y % maxThreads; - - // in some cases there's nothing to think about, part 2 - if ((iters_x >= maxThreads && remX == 0 )|| (iters_y >= maxThreads && remY == 0)) - return maxThreads; - - // at this point we suppose that there's no loop perfectly matches number of our threads - // so let's pick something as equal as possible - if (iters_x > maxThreads || iters_y > maxThreads) - return maxThreads; - else - return numberOfThreads2d(maxThreads - 1, iters_x, iters_y); - } - - int ThreadsHelper::numberOfThreads3d(int maxThreads, uint64_t itersX, uint64_t itersY, uint64_t itersZ) { - // we don't want to run underloaded threads - if (itersX * itersY * itersZ <= 32) - return 1; - - auto remX = itersX % maxThreads; - auto remY = itersY % maxThreads; - auto remZ = itersZ % maxThreads; - - // if we have perfect balance across one of dimensions - just go for it - if ((itersX >= maxThreads && remX == 0) || (itersY >= maxThreads && remY == 0) || (itersZ >= maxThreads && remZ == 0)) - return maxThreads; - - int threadsX = 0, threadsY = 0, threadsZ = 0; - - // now we look into possible number of - threadsX = threads_(maxThreads, itersX); - threadsY = threads_(maxThreads, itersY); - threadsZ = threads_(maxThreads, itersZ); - - // we want to split as close to outer loop as possible, so checking it out first - if (threadsX >= threadsY && threadsX >= threadsZ) - return threadsX; - else if (threadsY >= threadsX && threadsY >= threadsZ) - return threadsY; - else if (threadsZ >= threadsX && threadsZ >= threadsY) - return threadsZ; - - return 1; - } - - int ThreadsHelper::pickLoop3d(int numThreads, uint64_t itersX, uint64_t itersY, uint64_t itersZ) { - auto remX = itersX % numThreads; - auto remY = itersY % numThreads; - auto remZ = itersZ % numThreads; - - auto splitX = itersX / numThreads; - auto splitY = itersY / numThreads; - auto splitZ = itersZ / numThreads; - - // if there's no remainder left in some dimension - we're picking that dimension, because it'll be the most balanced work distribution - if (remX == 0) - return 1; - else if (remY == 0) - return 2; - else if (remZ == 0) // TODO: we don't want too smal splits over last dimension? or we do? - return 3; - - if (itersX > numThreads) - return 1; - else if (itersY > numThreads) - return 2; - else if (itersZ > numThreads) - return 3; - - return 1; - } - - int Threads::parallel_tad(FUNC_1D function, int64_t start, int64_t stop, int64_t increment, uint32_t numThreads) { - if (start > stop) - throw std::runtime_error("Threads::parallel_for got start > stop"); - - auto delta = (stop - start); - - if (numThreads > delta) - numThreads = delta; - - if (numThreads == 0) - return 0; - - // shortcut - if (numThreads == 1) { - function(0, start, stop, increment); - return 1; - } - - auto ticket = ThreadPool::getInstance().tryAcquire(numThreads); - if (ticket != nullptr) { - // if we got our threads - we'll run our jobs here - auto span = delta / numThreads; - - for (uint32_t e = 0; e < numThreads; e++) { - auto start_ = span * e + start; - auto stop_ = start_ + span; - - // last thread will process tail - if (e == numThreads - 1) - stop_ = stop; - - // putting the task into the queue for a given thread - ticket->enqueue(e, numThreads, function, start_, stop_, increment); - } - - // block and wait till all threads finished the job - ticket->waitAndRelease(); - - // we tell that parallelism request succeeded - return numThreads; - } else { - // if there were no threads available - we'll execute function right within current thread - function(0, start, stop, increment); - - // we tell that parallelism request declined - return 1; - } - } - - int Threads::parallel_for(FUNC_1D function, int64_t start, int64_t stop, int64_t increment, uint32_t numThreads) { - if (start > stop) - throw std::runtime_error("Threads::parallel_for got start > stop"); - - auto delta = (stop - start); - - // in some cases we just fire func as is - if (delta == 0 || numThreads == 1) { - function(0, start, stop, increment); - return 1; - } - - auto numElements = delta / increment; - - // we decide what's optimal number of threads we need here, and execute it in parallel_tad. - numThreads = ThreadsHelper::numberOfThreads(numThreads, numElements); - return parallel_tad(function, start, stop, increment, numThreads); - } - - int Threads::parallel_for(FUNC_2D function, int64_t startX, int64_t stopX, int64_t incX, int64_t startY, int64_t stopY, int64_t incY, uint64_t numThreads, bool debug) { - if (startX > stopX) - throw std::runtime_error("Threads::parallel_for got startX > stopX"); - - if (startY > stopY) - throw std::runtime_error("Threads::parallel_for got startY > stopY"); - - // number of elements per loop - auto delta_x = (stopX - startX); - auto delta_y = (stopY - startY); - - // number of iterations per loop - auto itersX = delta_x / incX; - auto itersY = delta_y / incY; - - // total number of iterations - auto iters_t = itersX * itersY; - - // we are checking the case of number of requested threads was smaller - numThreads = ThreadsHelper::numberOfThreads2d(numThreads, itersX, itersY); - - // basic shortcut for no-threading cases - if (numThreads == 1) { - function(0, startX, stopX, incX, startY, stopY, incY); - return 1; - } - - // We have couple of scenarios: - // either we split workload along 1st loop, or 2nd - auto splitLoop = ThreadsHelper::pickLoop2d(numThreads, itersX, itersY); - - // for debug mode we execute things inplace, without any threads - if (debug) { - for (int e = 0; e < numThreads; e++) { - auto span = Span2::build(splitLoop, e, numThreads, startX, stopX, incX, startY, stopY, incY); - - function(e, span.startX(), span.stopX(), span.incX(), span.startY(), span.stopY(), span.incY()); - } - - // but we still mimic multithreaded execution - return numThreads; - } else { - auto ticket = ThreadPool::getInstance().tryAcquire(numThreads); - if (ticket != nullptr) { - - for (int e = 0; e < numThreads; e++) { - auto threadId = numThreads - e - 1; - auto span = Span2::build(splitLoop, threadId, numThreads, startX, stopX, incX, startY, stopY, incY); - - ticket->enqueue(e, numThreads, function, span.startX(), span.stopX(), span.incX(), span.startY(), span.stopY(), span.incY()); - } - - // block until all threads finish their job - ticket->waitAndRelease(); - - return numThreads; - } else { - // if there were no threads available - we'll execute function right within current thread - function(0, startX, stopX, incX, startY, stopY, incY); - - // we tell that parallelism request declined - return 1; - } - }; - } - - - int Threads::parallel_for(FUNC_3D function, int64_t startX, int64_t stopX, int64_t incX, int64_t startY, int64_t stopY, int64_t incY, int64_t startZ, int64_t stopZ, int64_t incZ, uint64_t numThreads) { - if (startX > stopX) - throw std::runtime_error("Threads::parallel_for got startX > stopX"); - - if (startY > stopY) - throw std::runtime_error("Threads::parallel_for got startY > stopY"); - - if (startZ > stopZ) - throw std::runtime_error("Threads::parallel_for got startZ > stopZ"); - - auto delta_x = stopX - startX; - auto delta_y = stopY - startY; - auto delta_z = stopZ - startZ; - - auto itersX = delta_x / incX; - auto itersY = delta_y / incY; - auto itersZ = delta_z / incZ; - - numThreads = ThreadsHelper::numberOfThreads3d(numThreads, itersX, itersY, itersZ); - if (numThreads == 1) { - // loop is too small - executing function as is - function(0, startX, stopX, incX, startY, stopY, incY, startZ, stopZ, incZ); - return 1; - } - - auto ticket = ThreadPool::getInstance().tryAcquire(numThreads); - if (ticket != nullptr) { - auto splitLoop = ThreadsHelper::pickLoop3d(numThreads, itersX, itersY, itersZ); - - for (int e = 0; e < numThreads; e++) { - auto thread_id = numThreads - e - 1; - auto span = Span3::build(splitLoop, thread_id, numThreads, startX, stopX, incX, startY, stopY, incY, startZ, stopZ, incZ); - - ticket->enqueue(e, numThreads, function, span.startX(), span.stopX(), span.incX(), span.startY(), span.stopY(), span.incY(), span.startZ(), span.stopZ(), span.incZ()); - } - - // block until we're done - ticket->waitAndRelease(); - - // we tell that parallelism request succeeded - return numThreads; - } else { - // if there were no threads available - we'll execute function right within current thread - function(0, startX, stopX, incX, startY, stopY, incY, startZ, stopZ, incZ); - - // we tell that parallelism request declined - return 1; - } - - } - - int Threads::parallel_do(FUNC_DO function, uint64_t numThreads) { - auto ticket = ThreadPool::getInstance().tryAcquire(numThreads - 1); - if (ticket != nullptr) { - - // submit tasks one by one - for (uint64_t e = 0; e < numThreads - 1; e++) - ticket->enqueue(e, numThreads, function); - - function(numThreads - 1, numThreads); - - ticket->waitAndRelease(); - - return numThreads; - } else { - // if there's no threads available - we'll execute function sequentially one by one - for (uint64_t e = 0; e < numThreads; e++) - function(e, numThreads); - - return numThreads; - } - - - return numThreads; - } - - int64_t Threads::parallel_long(FUNC_RL function, FUNC_AL aggregator, int64_t start, int64_t stop, int64_t increment, uint64_t numThreads) { - if (start > stop) - throw std::runtime_error("Threads::parallel_long got start > stop"); - - auto delta = (stop - start); - if (delta == 0 || numThreads == 1) - return function(0, start, stop, increment); - - auto numElements = delta / increment; - - // we decide what's optimal number of threads we need here, and execute it - numThreads = ThreadsHelper::numberOfThreads(numThreads, numElements); - if (numThreads == 1) - return function(0, start, stop, increment); - - auto ticket = ThreadPool::getInstance().tryAcquire(numThreads - 1); - if (ticket == nullptr) - return function(0, start, stop, increment); - - // create temporary array - int64_t intermediatery[256]; - auto span = (numElements / numThreads) - (numElements % numThreads); - - // execute threads in parallel - for (uint32_t e = 0; e < numThreads; e++) { - auto start_ = span * e + start; - auto stop_ = span * (e + 1) + start; - - if (e == numThreads - 1) - intermediatery[e] = function(e, start_, stop, increment); - else - ticket->enqueue(e, numThreads, &intermediatery[e], function, start_, stop_, increment); - } - - ticket->waitAndRelease(); - - // aggregate results in single thread - for (uint64_t e = 1; e < numThreads; e++) - intermediatery[0] = aggregator(intermediatery[0], intermediatery[e]); - - // return accumulated result - return intermediatery[0]; - } - - double Threads::parallel_double(FUNC_RD function, FUNC_AD aggregator, int64_t start, int64_t stop, int64_t increment, uint64_t numThreads) { - if (start > stop) - throw std::runtime_error("Threads::parallel_long got start > stop"); - - auto delta = (stop - start); - if (delta == 0 || numThreads == 1) - return function(0, start, stop, increment); - - auto numElements = delta / increment; - - // we decide what's optimal number of threads we need here, and execute it - numThreads = ThreadsHelper::numberOfThreads(numThreads, numElements); - if (numThreads == 1) - return function(0, start, stop, increment); - - auto ticket = ThreadPool::getInstance().tryAcquire(numThreads - 1); - if (ticket == nullptr) - return function(0, start, stop, increment); - - // create temporary array - double intermediatery[256]; - auto span = (numElements / numThreads) - (numElements % numThreads); - - // execute threads in parallel - for (uint32_t e = 0; e < numThreads; e++) { - auto start_ = span * e + start; - auto stop_ = span * (e + 1) + start; - - if (e == numThreads - 1) - intermediatery[e] = function(e, start_, stop, increment); - else - ticket->enqueue(e, numThreads, &intermediatery[e], function, start_, stop_, increment); - } - - ticket->waitAndRelease(); - - // aggregate results in single thread - for (uint64_t e = 1; e < numThreads; e++) - intermediatery[0] = aggregator(intermediatery[0], intermediatery[e]); - - // return accumulated result - return intermediatery[0]; - } - - - int Threads::parallel_aligned_increment(FUNC_1D function, int64_t start, int64_t stop, int64_t increment, size_t type_size , uint32_t req_numThreads) { - if (start > stop) - throw std::runtime_error("Threads::parallel_for got start > stop"); - auto num_elements = (stop - start); - //this way we preserve increment starts offset - //so we will parition considering delta but not total elements - auto delta = (stop - start) / increment; - - // in some cases we just fire func as is - if (delta == 0 || req_numThreads == 1) { - function(0, start, stop, increment); - return 1; - } - int numThreads = 0; - - int adjusted_numThreads = samediff::ThreadsHelper::numberOfThreads(req_numThreads, (num_elements * sizeof(double)) / (200 * type_size)); - - if (adjusted_numThreads > delta) - adjusted_numThreads = delta; - // shortcut - if (adjusted_numThreads <= 1) { - function(0, start, stop, increment); - return 1; - } - //take span as ceil - auto spand = std::ceil((double)delta / (double)adjusted_numThreads); - numThreads = static_cast(std::ceil((double)delta / spand)); - auto span = static_cast(spand); - - auto ticket = samediff::ThreadPool::getInstance().tryAcquire(numThreads); - if (ticket != nullptr) { - //tail_add is additional value of the last part - //it could be negative or positive - //we will spread that value across - auto tail_add = delta - numThreads * span; - Nd4jLong begin = 0; - Nd4jLong end = 0; - - //we will try enqueu bigger parts first - decltype(span) span1, span2; - int last = 0; - if (tail_add >= 0) { - //for span == 1 , tail_add is 0 - last = tail_add; - span1 = span + 1; - span2 = span; - } - else { - last = numThreads + tail_add;// -std::abs(tail_add); - span1 = span; - span2 = span - 1; - } - for (int i = 0; i < last; i++) { - end = begin + span1 * increment; - // putting the task into the queue for a given thread - ticket->enqueue(i, numThreads, function, begin, end, increment); - begin = end; - } - for (int i = last; i < numThreads - 1; i++) { - end = begin + span2 * increment; - // putting the task into the queue for a given thread - ticket->enqueue(i, numThreads, function, begin, end, increment); - begin = end; - } - //for last one enqueue last offset as stop - //we need it in case our ((stop-start) % increment ) > 0 - ticket->enqueue(numThreads - 1, numThreads, function, begin, stop, increment); - // block and wait till all threads finished the job - ticket->waitAndRelease(); - // we tell that parallelism request succeeded - return numThreads; - } - else { - // if there were no threads available - we'll execute function right within current thread - function(0, start, stop, increment); - // we tell that parallelism request declined - return 1; - } - } - +int ThreadsHelper::numberOfThreads(int maxThreads, uint64_t numberOfElements) { + // let's see how many threads we actually need first + auto optimalThreads = + sd::math::nd4j_max(1, numberOfElements / 1024); + + // now return the smallest value + return sd::math::nd4j_min(optimalThreads, maxThreads); +} + +Span3::Span3(int64_t startX, int64_t stopX, int64_t incX, int64_t startY, + int64_t stopY, int64_t incY, int64_t startZ, int64_t stopZ, + int64_t incZ) { + _startX = startX; + _startY = startY; + _startZ = startZ; + _stopX = stopX; + _stopY = stopY; + _stopZ = stopZ; + _incX = incX; + _incY = incY; + _incZ = incZ; +} + +Span3 Span3::build(int loop, uint64_t threadID, uint64_t numThreads, + int64_t startX, int64_t stopX, int64_t incX, int64_t startY, + int64_t stopY, int64_t incY, int64_t startZ, int64_t stopZ, + int64_t incZ) { + switch (loop) { + case 1: { + auto span = (stopX - startX) / numThreads; + auto s = span * threadID; + auto e = s + span; + if (threadID == numThreads - 1) e = stopX; + + return Span3(s, e, incX, startY, stopY, incY, startZ, stopZ, incZ); + } break; + case 2: { + auto span = (stopY - startY) / numThreads; + auto s = span * threadID; + auto e = s + span; + if (threadID == numThreads - 1) e = stopY; + + return Span3(startX, stopX, incX, s, e, incY, startZ, stopZ, incZ); + } break; + case 3: { + auto span = (stopZ - startZ) / numThreads; + auto s = span * threadID; + auto e = s + span; + if (threadID == numThreads - 1) e = stopZ; + + return Span3(startX, stopX, incX, startY, stopY, incY, s, e, incZ); + } break; + default: + throw std::runtime_error(""); + } + return Span3(startX, stopX, incX, startY, stopY, incY, startZ, stopZ, incZ); +} + +Span::Span(int64_t startX, int64_t stopX, int64_t incX) { + _startX = startX; + _stopX = stopX; + _incX = incX; +} + +Span Span::build(uint64_t threadID, uint64_t numThreads, int64_t startX, + int64_t stopX, int64_t incX) { + auto span = (stopX - startX) / numThreads; + auto s = span * threadID; + auto e = s + span; + if (threadID == numThreads - 1) e = stopX; + + return Span(s, e, incX); +} + +Span2::Span2(int64_t startX, int64_t stopX, int64_t incX, int64_t startY, + int64_t stopY, int64_t incY) { + _startX = startX; + _startY = startY; + _stopX = stopX; + _stopY = stopY; + _incX = incX; + _incY = incY; +} + +Span2 Span2::build(int loop, uint64_t threadID, uint64_t numThreads, + int64_t startX, int64_t stopX, int64_t incX, int64_t startY, + int64_t stopY, int64_t incY) { + switch (loop) { + case 1: { + auto span = (stopX - startX) / numThreads; + auto s = span * threadID; + auto e = s + span; + if (threadID == numThreads - 1) e = stopX; + + return Span2(s, e, incX, startY, stopY, incY); + } break; + case 2: { + auto span = (stopY - startY) / numThreads; + auto s = span * threadID; + auto e = s + span; + if (threadID == numThreads - 1) e = stopY; + + return Span2(startX, stopX, incX, s, e, incY); + } break; + default: + throw std::runtime_error(""); + } +} + +int64_t Span::startX() const { return _startX; } + +int64_t Span::stopX() const { return _stopX; } + +int64_t Span::incX() const { return _incX; } + +int64_t Span2::startX() const { return _startX; } + +int64_t Span2::startY() const { return _startY; } + +int64_t Span2::stopX() const { return _stopX; } + +int64_t Span2::stopY() const { return _stopY; } + +int64_t Span2::incX() const { return _incX; } + +int64_t Span2::incY() const { return _incY; } + +int64_t Span3::startX() const { return _startX; } + +int64_t Span3::startY() const { return _startY; } + +int64_t Span3::startZ() const { return _startZ; } + +int64_t Span3::stopX() const { return _stopX; } + +int64_t Span3::stopY() const { return _stopY; } + +int64_t Span3::stopZ() const { return _stopZ; } + +int64_t Span3::incX() const { return _incX; } + +int64_t Span3::incY() const { return _incY; } + +int64_t Span3::incZ() const { return _incZ; } + +int ThreadsHelper::pickLoop2d(int numThreads, uint64_t itersX, + uint64_t itersY) { + // if one of dimensions is definitely too small - we just pick the other one + if (itersX < numThreads && itersY >= numThreads) return 2; + if (itersY < numThreads && itersX >= numThreads) return 1; + + // next step - we pick the most balanced dimension + auto remX = itersX % numThreads; + auto remY = itersY % numThreads; + auto splitY = itersY / numThreads; + + // if there's no remainder left in some dimension - we're picking that + // dimension, because it'll be the most balanced work distribution + if (remX == 0) return 1; + if (remY == 0) return 2; + + // if there's no loop without a remainder - we're picking one with smaller + // remainder + if (remX < remY) return 1; + if (remY < remX && splitY >= 64) // we don't want too small splits over last + // dimension, or vectorization will fail + return 2; + // if loops are equally sized - give the preference to the first thread + return 1; +} + +static int threads_(int maxThreads, uint64_t elements) { + if (elements == maxThreads) { + return maxThreads; + } else if (elements > maxThreads) { + // if we have full load across thread, or at least half of threads can be + // utilized + auto rem = elements % maxThreads; + if (rem == 0 || rem >= maxThreads / 3) + return maxThreads; + else + return threads_(maxThreads - 1, elements); + + } else if (elements < maxThreads) { + return elements; + } + + return 1; +} + +int ThreadsHelper::numberOfThreads2d(int maxThreads, uint64_t iters_x, + uint64_t iters_y) { + // in some cases there's nothing to think about, part 1 + if (iters_x < maxThreads && iters_y < maxThreads) + return sd::math::nd4j_max(iters_x, iters_y); + + auto remX = iters_x % maxThreads; + auto remY = iters_y % maxThreads; + + // in some cases there's nothing to think about, part 2 + if ((iters_x >= maxThreads && remX == 0) || + (iters_y >= maxThreads && remY == 0)) + return maxThreads; + + // at this point we suppose that there's no loop perfectly matches number of + // our threads so let's pick something as equal as possible + if (iters_x > maxThreads || iters_y > maxThreads) + return maxThreads; + else + return numberOfThreads2d(maxThreads - 1, iters_x, iters_y); +} -} \ No newline at end of file +int ThreadsHelper::numberOfThreads3d(int maxThreads, uint64_t itersX, + uint64_t itersY, uint64_t itersZ) { + // we don't want to run underloaded threads + if (itersX * itersY * itersZ <= 32) return 1; + + auto remX = itersX % maxThreads; + auto remY = itersY % maxThreads; + auto remZ = itersZ % maxThreads; + + // if we have perfect balance across one of dimensions - just go for it + if ((itersX >= maxThreads && remX == 0) || + (itersY >= maxThreads && remY == 0) || + (itersZ >= maxThreads && remZ == 0)) + return maxThreads; + + int threadsX = 0, threadsY = 0, threadsZ = 0; + + // now we look into possible number of + threadsX = threads_(maxThreads, itersX); + threadsY = threads_(maxThreads, itersY); + threadsZ = threads_(maxThreads, itersZ); + + // we want to split as close to outer loop as possible, so checking it out + // first + if (threadsX >= threadsY && threadsX >= threadsZ) + return threadsX; + else if (threadsY >= threadsX && threadsY >= threadsZ) + return threadsY; + else if (threadsZ >= threadsX && threadsZ >= threadsY) + return threadsZ; + + return 1; +} + +int ThreadsHelper::pickLoop3d(int numThreads, uint64_t itersX, uint64_t itersY, + uint64_t itersZ) { + auto remX = itersX % numThreads; + auto remY = itersY % numThreads; + auto remZ = itersZ % numThreads; + + auto splitX = itersX / numThreads; + auto splitY = itersY / numThreads; + auto splitZ = itersZ / numThreads; + + // if there's no remainder left in some dimension - we're picking that + // dimension, because it'll be the most balanced work distribution + if (remX == 0) + return 1; + else if (remY == 0) + return 2; + else if (remZ == 0) // TODO: we don't want too smal splits over last + // dimension? or we do? + return 3; + + if (itersX > numThreads) + return 1; + else if (itersY > numThreads) + return 2; + else if (itersZ > numThreads) + return 3; + + return 1; +} + +int Threads::parallel_tad(FUNC_1D function, int64_t start, int64_t stop, + int64_t increment, uint32_t numThreads) { + if (start > stop) + throw std::runtime_error("Threads::parallel_for got start > stop"); + + auto delta = (stop - start); + + if (numThreads > delta) numThreads = delta; + + if (numThreads == 0) return 0; + + // shortcut + if (numThreads == 1) { + function(0, start, stop, increment); + return 1; + } + + auto ticket = ThreadPool::getInstance().tryAcquire(numThreads); + if (ticket != nullptr) { + // if we got our threads - we'll run our jobs here + auto span = delta / numThreads; + + for (uint32_t e = 0; e < numThreads; e++) { + auto start_ = span * e + start; + auto stop_ = start_ + span; + + // last thread will process tail + if (e == numThreads - 1) stop_ = stop; + + // putting the task into the queue for a given thread + ticket->enqueue(e, numThreads, function, start_, stop_, increment); + } + + // block and wait till all threads finished the job + ticket->waitAndRelease(); + + // we tell that parallelism request succeeded + return numThreads; + } else { + // if there were no threads available - we'll execute function right within + // current thread + function(0, start, stop, increment); + + // we tell that parallelism request declined + return 1; + } +} + +int Threads::parallel_for(FUNC_1D function, int64_t start, int64_t stop, + int64_t increment, uint32_t numThreads) { + if (start > stop) + throw std::runtime_error("Threads::parallel_for got start > stop"); + + auto delta = (stop - start); + + // in some cases we just fire func as is + if (delta == 0 || numThreads == 1) { + function(0, start, stop, increment); + return 1; + } + + auto numElements = delta / increment; + + // we decide what's optimal number of threads we need here, and execute it in + // parallel_tad. + numThreads = ThreadsHelper::numberOfThreads(numThreads, numElements); + return parallel_tad(function, start, stop, increment, numThreads); +} + +int Threads::parallel_for(FUNC_2D function, int64_t startX, int64_t stopX, + int64_t incX, int64_t startY, int64_t stopY, + int64_t incY, uint64_t numThreads, bool debug) { + if (startX > stopX) + throw std::runtime_error("Threads::parallel_for got startX > stopX"); + + if (startY > stopY) + throw std::runtime_error("Threads::parallel_for got startY > stopY"); + + // number of elements per loop + auto delta_x = (stopX - startX); + auto delta_y = (stopY - startY); + + // number of iterations per loop + auto itersX = delta_x / incX; + auto itersY = delta_y / incY; + + // total number of iterations + auto iters_t = itersX * itersY; + + // we are checking the case of number of requested threads was smaller + numThreads = ThreadsHelper::numberOfThreads2d(numThreads, itersX, itersY); + + // basic shortcut for no-threading cases + if (numThreads == 1) { + function(0, startX, stopX, incX, startY, stopY, incY); + return 1; + } + + // We have couple of scenarios: + // either we split workload along 1st loop, or 2nd + auto splitLoop = ThreadsHelper::pickLoop2d(numThreads, itersX, itersY); + + // for debug mode we execute things inplace, without any threads + if (debug) { + for (int e = 0; e < numThreads; e++) { + auto span = Span2::build(splitLoop, e, numThreads, startX, stopX, incX, + startY, stopY, incY); + + function(e, span.startX(), span.stopX(), span.incX(), span.startY(), + span.stopY(), span.incY()); + } + + // but we still mimic multithreaded execution + return numThreads; + } else { + auto ticket = ThreadPool::getInstance().tryAcquire(numThreads); + if (ticket != nullptr) { + for (int e = 0; e < numThreads; e++) { + auto threadId = numThreads - e - 1; + auto span = Span2::build(splitLoop, threadId, numThreads, startX, stopX, + incX, startY, stopY, incY); + + ticket->enqueue(e, numThreads, function, span.startX(), span.stopX(), + span.incX(), span.startY(), span.stopY(), span.incY()); + } + + // block until all threads finish their job + ticket->waitAndRelease(); + + return numThreads; + } else { + // if there were no threads available - we'll execute function right + // within current thread + function(0, startX, stopX, incX, startY, stopY, incY); + + // we tell that parallelism request declined + return 1; + } + }; +} + +int Threads::parallel_for(FUNC_3D function, int64_t startX, int64_t stopX, + int64_t incX, int64_t startY, int64_t stopY, + int64_t incY, int64_t startZ, int64_t stopZ, + int64_t incZ, uint64_t numThreads) { + if (startX > stopX) + throw std::runtime_error("Threads::parallel_for got startX > stopX"); + + if (startY > stopY) + throw std::runtime_error("Threads::parallel_for got startY > stopY"); + + if (startZ > stopZ) + throw std::runtime_error("Threads::parallel_for got startZ > stopZ"); + + auto delta_x = stopX - startX; + auto delta_y = stopY - startY; + auto delta_z = stopZ - startZ; + + auto itersX = delta_x / incX; + auto itersY = delta_y / incY; + auto itersZ = delta_z / incZ; + + numThreads = + ThreadsHelper::numberOfThreads3d(numThreads, itersX, itersY, itersZ); + if (numThreads == 1) { + // loop is too small - executing function as is + function(0, startX, stopX, incX, startY, stopY, incY, startZ, stopZ, incZ); + return 1; + } + + auto ticket = ThreadPool::getInstance().tryAcquire(numThreads); + if (ticket != nullptr) { + auto splitLoop = + ThreadsHelper::pickLoop3d(numThreads, itersX, itersY, itersZ); + + for (int e = 0; e < numThreads; e++) { + auto thread_id = numThreads - e - 1; + auto span = Span3::build(splitLoop, thread_id, numThreads, startX, stopX, + incX, startY, stopY, incY, startZ, stopZ, incZ); + + ticket->enqueue(e, numThreads, function, span.startX(), span.stopX(), + span.incX(), span.startY(), span.stopY(), span.incY(), + span.startZ(), span.stopZ(), span.incZ()); + } + + // block until we're done + ticket->waitAndRelease(); + + // we tell that parallelism request succeeded + return numThreads; + } else { + // if there were no threads available - we'll execute function right within + // current thread + function(0, startX, stopX, incX, startY, stopY, incY, startZ, stopZ, incZ); + + // we tell that parallelism request declined + return 1; + } +} + +int Threads::parallel_do(FUNC_DO function, uint64_t numThreads) { + auto ticket = ThreadPool::getInstance().tryAcquire(numThreads - 1); + if (ticket != nullptr) { + // submit tasks one by one + for (uint64_t e = 0; e < numThreads - 1; e++) + ticket->enqueue(e, numThreads, function); + + function(numThreads - 1, numThreads); + + ticket->waitAndRelease(); + + return numThreads; + } else { + // if there's no threads available - we'll execute function sequentially one + // by one + for (uint64_t e = 0; e < numThreads; e++) function(e, numThreads); + + return numThreads; + } + + return numThreads; +} + +int64_t Threads::parallel_long(FUNC_RL function, FUNC_AL aggregator, + int64_t start, int64_t stop, int64_t increment, + uint64_t numThreads) { + if (start > stop) + throw std::runtime_error("Threads::parallel_long got start > stop"); + + auto delta = (stop - start); + if (delta == 0 || numThreads == 1) return function(0, start, stop, increment); + + auto numElements = delta / increment; + + // we decide what's optimal number of threads we need here, and execute it + numThreads = ThreadsHelper::numberOfThreads(numThreads, numElements); + if (numThreads == 1) return function(0, start, stop, increment); + + auto ticket = ThreadPool::getInstance().tryAcquire(numThreads - 1); + if (ticket == nullptr) return function(0, start, stop, increment); + + // create temporary array + int64_t intermediatery[256]; + auto span = (numElements / numThreads) - (numElements % numThreads); + + // execute threads in parallel + for (uint32_t e = 0; e < numThreads; e++) { + auto start_ = span * e + start; + auto stop_ = span * (e + 1) + start; + + if (e == numThreads - 1) + intermediatery[e] = function(e, start_, stop, increment); + else + ticket->enqueue(e, numThreads, &intermediatery[e], function, start_, + stop_, increment); + } + + ticket->waitAndRelease(); + + // aggregate results in single thread + for (uint64_t e = 1; e < numThreads; e++) + intermediatery[0] = aggregator(intermediatery[0], intermediatery[e]); + + // return accumulated result + return intermediatery[0]; +} + +double Threads::parallel_double(FUNC_RD function, FUNC_AD aggregator, + int64_t start, int64_t stop, int64_t increment, + uint64_t numThreads) { + if (start > stop) + throw std::runtime_error("Threads::parallel_long got start > stop"); + + auto delta = (stop - start); + if (delta == 0 || numThreads == 1) return function(0, start, stop, increment); + + auto numElements = delta / increment; + + // we decide what's optimal number of threads we need here, and execute it + numThreads = ThreadsHelper::numberOfThreads(numThreads, numElements); + if (numThreads == 1) return function(0, start, stop, increment); + + auto ticket = ThreadPool::getInstance().tryAcquire(numThreads - 1); + if (ticket == nullptr) return function(0, start, stop, increment); + + // create temporary array + double intermediatery[256]; + auto span = (numElements / numThreads) - (numElements % numThreads); + + // execute threads in parallel + for (uint32_t e = 0; e < numThreads; e++) { + auto start_ = span * e + start; + auto stop_ = span * (e + 1) + start; + + if (e == numThreads - 1) + intermediatery[e] = function(e, start_, stop, increment); + else + ticket->enqueue(e, numThreads, &intermediatery[e], function, start_, + stop_, increment); + } + + ticket->waitAndRelease(); + + // aggregate results in single thread + for (uint64_t e = 1; e < numThreads; e++) + intermediatery[0] = aggregator(intermediatery[0], intermediatery[e]); + + // return accumulated result + return intermediatery[0]; +} + +int Threads::parallel_aligned_increment(FUNC_1D function, int64_t start, + int64_t stop, int64_t increment, + size_t type_size, + uint32_t req_numThreads) { + if (start > stop) + throw std::runtime_error("Threads::parallel_for got start > stop"); + auto num_elements = (stop - start); + // this way we preserve increment starts offset + // so we will parition considering delta but not total elements + auto delta = (stop - start) / increment; + + // in some cases we just fire func as is + if (delta == 0 || req_numThreads == 1) { + function(0, start, stop, increment); + return 1; + } + int numThreads = 0; + + int adjusted_numThreads = samediff::ThreadsHelper::numberOfThreads( + req_numThreads, (num_elements * sizeof(double)) / (200 * type_size)); + + if (adjusted_numThreads > delta) adjusted_numThreads = delta; + // shortcut + if (adjusted_numThreads <= 1) { + function(0, start, stop, increment); + return 1; + } + // take span as ceil + auto spand = std::ceil((double)delta / (double)adjusted_numThreads); + numThreads = static_cast(std::ceil((double)delta / spand)); + auto span = static_cast(spand); + + auto ticket = samediff::ThreadPool::getInstance().tryAcquire(numThreads); + if (ticket != nullptr) { + // tail_add is additional value of the last part + // it could be negative or positive + // we will spread that value across + auto tail_add = delta - numThreads * span; + Nd4jLong begin = 0; + Nd4jLong end = 0; + + // we will try enqueu bigger parts first + decltype(span) span1, span2; + int last = 0; + if (tail_add >= 0) { + // for span == 1 , tail_add is 0 + last = tail_add; + span1 = span + 1; + span2 = span; + } else { + last = numThreads + tail_add; // -std::abs(tail_add); + span1 = span; + span2 = span - 1; + } + for (int i = 0; i < last; i++) { + end = begin + span1 * increment; + // putting the task into the queue for a given thread + ticket->enqueue(i, numThreads, function, begin, end, increment); + begin = end; + } + for (int i = last; i < numThreads - 1; i++) { + end = begin + span2 * increment; + // putting the task into the queue for a given thread + ticket->enqueue(i, numThreads, function, begin, end, increment); + begin = end; + } + // for last one enqueue last offset as stop + // we need it in case our ((stop-start) % increment ) > 0 + ticket->enqueue(numThreads - 1, numThreads, function, begin, stop, + increment); + // block and wait till all threads finished the job + ticket->waitAndRelease(); + // we tell that parallelism request succeeded + return numThreads; + } else { + // if there were no threads available - we'll execute function right within + // current thread + function(0, start, stop, increment); + // we tell that parallelism request declined + return 1; + } +} + +} // namespace samediff \ No newline at end of file diff --git a/libnd4j/include/execution/impl/Ticket.cpp b/libnd4j/include/execution/impl/Ticket.cpp index b50b8f7712fc..1290360e16a5 100644 --- a/libnd4j/include/execution/impl/Ticket.cpp +++ b/libnd4j/include/execution/impl/Ticket.cpp @@ -18,77 +18,91 @@ // @author raver119@gmail.com // -#include #include +#include #include + #include namespace samediff { - Ticket::Ticket(const std::vector*> &queues) { - _acquired = true; - _queues = queues; - } - - Ticket::Ticket() { - _acquired = true; - _interfaces.resize(sd::Environment::getInstance().maxThreads()); - } - - bool Ticket::acquired() { - return _acquired; - } - - void Ticket::enqueue(int thread_id, samediff::CallableWithArguments *callable) { - _queues[thread_id]->put(callable); - _callables.emplace_back(callable); - } - - void Ticket::enqueue(uint32_t thread_id, uint32_t num_threads, FUNC_DO func) { - _interfaces[thread_id]->fill(thread_id, num_threads, func); - } - - void Ticket::enqueue(uint32_t thread_id, uint32_t num_threads, FUNC_1D func, int64_t start_x, int64_t stop_x, int64_t inc_x) { - _interfaces[thread_id]->fill(thread_id, num_threads, func, start_x, stop_x, inc_x); - } - - void Ticket::enqueue(uint32_t thread_id, uint32_t num_threads, int64_t *lpt, FUNC_RL func, int64_t start_x, int64_t stop_x, int64_t inc_x) { - _interfaces[thread_id]->fill(thread_id, num_threads, lpt, func, start_x, stop_x, inc_x); - } - - void Ticket::enqueue(uint32_t thread_id, uint32_t num_threads, double *dpt, FUNC_RD func, int64_t start_x, int64_t stop_x, int64_t inc_x) { - _interfaces[thread_id]->fill(thread_id, num_threads, dpt, func, start_x, stop_x, inc_x); - } - - void Ticket::enqueue(uint32_t thread_id, uint32_t num_threads, FUNC_2D func, int64_t start_x, int64_t stop_x, int64_t inc_x, int64_t start_y, int64_t stop_y, int64_t inc_y) { - _interfaces[thread_id]->fill(thread_id, num_threads, std::move(func), start_x, stop_x, inc_x, start_y, stop_y, inc_y); - } - - void Ticket::enqueue(uint32_t thread_id, uint32_t num_threads, FUNC_3D func, int64_t start_x, int64_t stop_x, int64_t inc_x, int64_t start_y, int64_t stop_y, int64_t inc_y, int64_t start_z, int64_t stop_z, int64_t inc_z) { - _interfaces[thread_id]->fill(thread_id, num_threads, func, start_x, stop_x, inc_x, start_y, stop_y, inc_y, start_z, stop_z, inc_z); - } - - void Ticket::acquiredThreads(uint32_t threads) { - _acquiredThreads = threads; - } - - void Ticket::waitAndRelease() { - for (uint32_t e = 0; e < this->_acquiredThreads; e++) { - // block until finished - _interfaces[e]->waitForCompletion(); - - // mark available - _interfaces[e]->markAvailable(); - - // increment availability counter - ThreadPool::getInstance().release(); - } - - // return this ticket back to the pool - ThreadPool::getInstance().release(this); - } - - - void Ticket::attach(uint32_t thread_id, samediff::CallableInterface *interface) { - _interfaces[thread_id] = interface; - } -} \ No newline at end of file +Ticket::Ticket( + const std::vector *> &queues) { + _acquired = true; + _queues = queues; +} + +Ticket::Ticket() { + _acquired = true; + _interfaces.resize(sd::Environment::getInstance().maxThreads()); +} + +bool Ticket::acquired() { return _acquired; } + +void Ticket::enqueue(int thread_id, samediff::CallableWithArguments *callable) { + _queues[thread_id]->put(callable); + _callables.emplace_back(callable); +} + +void Ticket::enqueue(uint32_t thread_id, uint32_t num_threads, FUNC_DO func) { + _interfaces[thread_id]->fill(thread_id, num_threads, func); +} + +void Ticket::enqueue(uint32_t thread_id, uint32_t num_threads, FUNC_1D func, + int64_t start_x, int64_t stop_x, int64_t inc_x) { + _interfaces[thread_id]->fill(thread_id, num_threads, func, start_x, stop_x, + inc_x); +} + +void Ticket::enqueue(uint32_t thread_id, uint32_t num_threads, int64_t *lpt, + FUNC_RL func, int64_t start_x, int64_t stop_x, + int64_t inc_x) { + _interfaces[thread_id]->fill(thread_id, num_threads, lpt, func, start_x, + stop_x, inc_x); +} + +void Ticket::enqueue(uint32_t thread_id, uint32_t num_threads, double *dpt, + FUNC_RD func, int64_t start_x, int64_t stop_x, + int64_t inc_x) { + _interfaces[thread_id]->fill(thread_id, num_threads, dpt, func, start_x, + stop_x, inc_x); +} + +void Ticket::enqueue(uint32_t thread_id, uint32_t num_threads, FUNC_2D func, + int64_t start_x, int64_t stop_x, int64_t inc_x, + int64_t start_y, int64_t stop_y, int64_t inc_y) { + _interfaces[thread_id]->fill(thread_id, num_threads, std::move(func), start_x, + stop_x, inc_x, start_y, stop_y, inc_y); +} + +void Ticket::enqueue(uint32_t thread_id, uint32_t num_threads, FUNC_3D func, + int64_t start_x, int64_t stop_x, int64_t inc_x, + int64_t start_y, int64_t stop_y, int64_t inc_y, + int64_t start_z, int64_t stop_z, int64_t inc_z) { + _interfaces[thread_id]->fill(thread_id, num_threads, func, start_x, stop_x, + inc_x, start_y, stop_y, inc_y, start_z, stop_z, + inc_z); +} + +void Ticket::acquiredThreads(uint32_t threads) { _acquiredThreads = threads; } + +void Ticket::waitAndRelease() { + for (uint32_t e = 0; e < this->_acquiredThreads; e++) { + // block until finished + _interfaces[e]->waitForCompletion(); + + // mark available + _interfaces[e]->markAvailable(); + + // increment availability counter + ThreadPool::getInstance().release(); + } + + // return this ticket back to the pool + ThreadPool::getInstance().release(this); +} + +void Ticket::attach(uint32_t thread_id, + samediff::CallableInterface *interface) { + _interfaces[thread_id] = interface; +} +} // namespace samediff \ No newline at end of file diff --git a/libnd4j/include/graph/ArgumentsList.h b/libnd4j/include/graph/ArgumentsList.h index 75bdf857a2b3..79fd49d4c0ec 100644 --- a/libnd4j/include/graph/ArgumentsList.h +++ b/libnd4j/include/graph/ArgumentsList.h @@ -21,40 +21,42 @@ #ifndef LIBND4J_INPUTLIST_H #define LIBND4J_INPUTLIST_H +#include #include #include -#include -#include #include +#include + namespace sd { namespace graph { - class ND4J_EXPORT ArgumentsList { - protected: - std::vector _arguments; - public: - explicit ArgumentsList() = default; - ArgumentsList(std::initializer_list arguments); - ArgumentsList(std::initializer_list arguments); - - ~ArgumentsList() = default; - - /** - * This method returns number of argument pairs available - * - * @return - */ - int size(); - - /** - * This method returns Pair at specified index - * - * @param index - * @return - */ - Pair &at(int index); - }; -} -} - -#endif //LIBND4J_INPUTLIST_H +class SD_EXPORT ArgumentsList { + protected: + std::vector _arguments; + + public: + explicit ArgumentsList() = default; + ArgumentsList(std::initializer_list arguments); + ArgumentsList(std::initializer_list arguments); + + ~ArgumentsList() = default; + + /** + * This method returns number of argument pairs available + * + * @return + */ + int size(); + + /** + * This method returns Pair at specified index + * + * @param index + * @return + */ + Pair &at(int index); +}; +} // namespace graph +} // namespace sd + +#endif // LIBND4J_INPUTLIST_H diff --git a/libnd4j/include/graph/Context.h b/libnd4j/include/graph/Context.h index de6608b464c6..f6cffd6a7b7f 100644 --- a/libnd4j/include/graph/Context.h +++ b/libnd4j/include/graph/Context.h @@ -22,222 +22,219 @@ #ifndef LIBND4J_CONTEXT_H #define LIBND4J_CONTEXT_H -#include #include +#include +#include #include #include -#include +#include #include -#include + +#include // CUDA-specific includes #ifdef __CUDACC__ #include -#include -#include #include +#include +#include #endif namespace sd { - namespace graph { - /** - * This class defines input desired for any given node/operation within graph - */ - class ND4J_EXPORT Context : public sd::graph::ContextPrototype { - protected: - sd::memory::Workspace* _workspace = nullptr; - sd::graph::VariableSpace* _variableSpace = nullptr; - std::pair _executionTime; - sd::random::RandomBuffer* _rng = nullptr; - - sd::DataType _dataType = sd::DataType::FLOAT32; - // branch for divergent_op - int _branch = 0; - - // temporary context for standalone ops execution - LaunchContext* _context = nullptr; - - std::vector _dataTypes; - - // fields for fast execution (out-of-graph ops use) - std::vector _fastpath_in; - std::vector _fastpath_out; - std::vector _handles; - - bool _helpersAllowed = true; - - // in some cases we might be able to skip shape function for validation purposes - bool _shapeFunctionOverride = false; - - // special flag used during conversion from Graph exec to FastPath exec - bool _forbidFastPath = false; - public: - Context(ContextPrototype* prototype, VariableSpace* variableSpace); - - explicit Context(int nodeId, VariableSpace *variableSpace = nullptr); - Context(int nodeId, VariableSpace *variableSpace, bool isInplace); - - // default destructor - ~Context(); - - // these methods are for execution timing - void setOuterTime(Nd4jLong time); - void setInnerTime(Nd4jLong time); - Nd4jLong getOuterTime(); - Nd4jLong getInnerTime(); - - sd::DataType dataType() override; - - sd::DataType dataType(int index) override; - void setDataType(int index, sd::DataType type) override; - // these methods are related to Workspace abstraction - bool hasWorkspaceProvided(); - void attachWorkspace(sd::memory::Workspace* workspace); - void forgetWorkspace(); - - // these methods return full-time workspace - sd::memory::Workspace* getWorkspace(); - sd::memory::Workspace* workspace(); - sd::memory::Workspace* fWorkspace(); - - // this method returns workspace for temporary allocations - sd::memory::Workspace* tWorkspace(); - - // this method returns workspace for object allocations - sd::memory::Workspace* oWorkspace(); - - void setVariableSpace(VariableSpace* variableSpace); - - sd::random::RandomBuffer* getRNG(); - void setRNG(sd::random::RandomBuffer* rng); - - void setTargetEngine(samediff::Engine engine); - - VariableSpace *getVariableSpace(); - - LaunchContext* launchContext(); - - // these fields define, if we can execute specific node in-place, without generating new array - - - // these variables are only for Divergent Nodes - int getBranch(); - void setBranch(int branch); - - /** - * - * @return - */ - Stash* getStash(); - - /** - * - */ - void trackList(NDArrayList* list); - - - /** - * This method returns variable for a given input index for this block - * @param idx - * @return - */ - Variable* getVariable(int idx); - Variable* variable(int idx); - - /** - * This method is shortcut to getVariable(int idx); - * - * + it check fastpath for array availability (preferred) - * @return - */ - NDArray* getNDArray(int idx); - NDArray* array(int idx); - - - /** - * This method fetches variable from VariableSpace DIRECTLY - * @param p - * @return - */ - Variable* variable(int node, int index); - Variable* variable(std::pair& p); - Variable* variable(std::initializer_list p); - - - void pushNDArrayToVariableSpace(int nodeId, int index, NDArray* array, bool removable = true); - void pushNDArrayToVariableSpace(std::pair& pair, NDArray* array, bool removable = true); - - void pushNDArrayListToVariableSpace(int nodeId, int index, NDArrayList* list, bool track = true); - void pushNDArrayListToVariableSpace(std::pair& pair, NDArrayList* list, bool track = true); - - bool isValueAvailable(int idx = 0); - - Variable* ensureVariable(int idx = 0); - - unsigned long width() override; - - // methods used in java interop - /** - * This method checks if Context uses fastpath variable access - * @return - */ - bool isFastPath(); - - /** - * Method allows to forbid FastPath execution - * @param reallyForbid - */ - void forbidFastPath(bool reallyForbid); +namespace graph { +/** + * This class defines input desired for any given node/operation within graph + */ +class SD_EXPORT Context : public sd::graph::ContextPrototype { + protected: + sd::graph::GraphMemoryManager *_memoryManager = nullptr; + sd::memory::Workspace *_workspace = nullptr; + + sd::graph::VariableSpace *_variableSpace = nullptr; + std::pair _executionTime; + + sd::DataType _dataType = sd::DataType::FLOAT32; + // branch for divergent_op + int _branch = 0; + + // temporary context for standalone ops execution + LaunchContext *_context = nullptr; + + std::vector _dataTypes; + + // fields for fast execution (out-of-graph ops use) + std::vector> _fastpath_in; + std::vector> _fastpath_out; + + bool _helpersAllowed = true; + + // in some cases we might be able to skip shape function for validation + // purposes + bool _shapeFunctionOverride = false; + + // special flag used during conversion from Graph exec to FastPath exec + bool _forbidFastPath = false; + + public: + Context(const ContextPrototype &prototype, VariableSpace *variableSpace, + GraphMemoryManager *memoryManager = nullptr); + + explicit Context(int nodeId, VariableSpace *variableSpace = nullptr); + Context(int nodeId, VariableSpace *variableSpace, bool isInplace); + + // default destructor + ~Context(); + + // these methods are for execution timing + void setOuterTime(Nd4jLong time); + void setInnerTime(Nd4jLong time); + Nd4jLong outerTime() const; + Nd4jLong innerTime() const; + + // these methods are related to Workspace abstraction + bool hasWorkspaceProvided() const; + void attachWorkspace(sd::memory::Workspace *workspace); + + sd::memory::Workspace *workspace() const; + + void setVariableSpace(VariableSpace *variableSpace); + + void setTargetEngine(samediff::Engine engine); + + VariableSpace *getVariableSpace(); + + LaunchContext *launchContext(); + + /** + * + * @return + */ + Stash *stash() const; + + /** + * This method returns variable for a given input index for this block + * @param idx + * @return + */ + std::shared_ptr getVariable(int idx) const; + std::shared_ptr variable(int idx) const; + + /** + * This method is shortcut to getVariable(int idx); + * + * + it check fastpath for array availability (preferred) + * @return + */ + std::shared_ptr getNDArray(int idx) const; + std::shared_ptr array(int idx) const; + + /** + * This is special method, used only within Graph + * @param idx + * @return + */ + NDArray *arrayForOp(int idx) const; + + /** + * This method fetches variable from VariableSpace DIRECTLY + * @param p + * @return + */ + std::shared_ptr variable(int node, int index) const; + std::shared_ptr variable(const std::pair &p) const; + std::shared_ptr variable(std::initializer_list p) const; + + void pushNDArrayToVariableSpace(int nodeId, int index, const NDArray &array); + void pushNDArrayToVariableSpace(const std::pair &pair, + const NDArray &array); + + void pushNDArrayListToVariableSpace(int nodeId, int index, std::shared_ptr list); + void pushNDArrayListToVariableSpace(int nodeId, int index, + const NDArrayList &list, + bool track = true); + void pushNDArrayListToVariableSpace(const std::pair &pair, + const NDArrayList &list, + bool track = true); + + bool isValueAvailable(const std::string &name, int id, int idx = 0) const; + + std::shared_ptr ensureVariable(const std::string &name, int id, + int idx = 0); + + unsigned long width() const override; + + // methods used in java interop + /** + * This method checks if Context uses fastpath variable access + * @return + */ + bool isFastPath() const; + + /** + * Method allows to forbid FastPath execution + * @param reallyForbid + */ + void forbidFastPath(bool reallyForbid); #ifndef __JAVACPP_HACK__ - std::vector& fastpath_in(); - std::vector& fastpath_out(); + const std::vector> &fastpath_in() const; + const std::vector> &fastpath_out() const; #endif - void setInputArray(int index, NDArray *array, bool removable = false); - void setInputArray(int index, void *buffer, void const* shapeInfo, void *specialBuffer, void const* specialShapeInfo); - void setInputArray(int index, void *buffer, void * shapeInfo, void *specialBuffer, void * specialShapeInfo); - void setInputArray(int index, void *databuffer, void const* shapeInfo, void const* specialShapeInfo); - - void setOutputArray(int index, NDArray *array, bool removable = false); - void setOutputArray(int index, void *buffer, const void * shapeInfo, void *specialBuffer, const void * specialShapeInfo); - void setOutputArray(int index, void *buffer, void * shapeInfo, void *specialBuffer, void * specialShapeInfo); - void setOutputArray(int index, void *databuffer, void const* shapeInfo, void const* specialShapeInfo); - - void setTArguments(double *arguments, int numberOfArguments); - void setIArguments(Nd4jLong *arguments, int numberOfArguments); - void setBArguments(bool *arguments, int numberOfArguments); - void setDArguments(sd::DataType *arguments, int numberOfArguments); - - void setTArguments(const std::vector &tArgs); - void setIArguments(const std::vector &tArgs); - void setBArguments(const std::vector &tArgs); - void setDArguments(const std::vector &dArgs); - - /** - * This method purges fastpath in/out contents and releases all the handles. - * - * PLEASE NOTE: I/T/B/D args will stay intact - */ - void clearFastPath(); - - void setCudaContext(Nd4jPointer cudaStream, Nd4jPointer reductionPointer, Nd4jPointer allocationPointer); - - void allowHelpers(bool reallyAllow); - bool helpersAllowed(); - - void setShapeFunctionOverride(bool reallyOverride); - bool shapeFunctionOverride(); - - samediff::ExecutionMode executionMode(); - void setExecutionMode(samediff::ExecutionMode executionMode); - - bool isTraining(); - bool isInference(); - }; - } -} - - -#endif //LIBND4J_BLOCK_H + void setInputArray(int index, const std::shared_ptr &array); + void setInputArray(int index, const NDArray &array); + void setInputArray(int index, void *buffer, void const *shapeInfo, + void *specialBuffer, void const *specialShapeInfo); + void setInputArray(int index, void *buffer, void *shapeInfo, + void *specialBuffer, void *specialShapeInfo); + void setInputArray(int index, void *databuffer, void const *shapeInfo, + void const *specialShapeInfo); + + void setOutputArray(int index, const std::shared_ptr &array); + void setOutputArray(int index, const NDArray &array); + void setOutputArray(int index, void *buffer, const void *shapeInfo, + void *specialBuffer, const void *specialShapeInfo); + void setOutputArray(int index, void *buffer, void *shapeInfo, + void *specialBuffer, void *specialShapeInfo); + void setOutputArray(int index, void *databuffer, void const *shapeInfo, + void const *specialShapeInfo); + + void setTArguments(double *arguments, int numberOfArguments); + void setIArguments(Nd4jLong *arguments, int numberOfArguments); + void setBArguments(bool *arguments, int numberOfArguments); + void setDArguments(sd::DataType *arguments, int numberOfArguments); + + void setTArguments(const std::vector &tArgs); + void setIArguments(const std::vector &tArgs); + void setBArguments(const std::vector &tArgs); + void setDArguments(const std::vector &dArgs); + + /** + * This method purges fastpath in/out contents and releases all the handles. + * + * PLEASE NOTE: I/T/B/D args will stay intact + */ + void clearFastPath(); + + void setCudaContext(Nd4jPointer cudaStream, Nd4jPointer reductionPointer, + Nd4jPointer allocationPointer); + + void allowHelpers(bool reallyAllow); + bool helpersAllowed() const; + + void setShapeFunctionOverride(bool reallyOverride); + bool shapeFunctionOverride() const; + + samediff::ExecutionMode executionMode() const; + void setExecutionMode(samediff::ExecutionMode executionMode); + + bool isTraining() const; + bool isInference() const; + + const GraphMemoryManager &memoryManager() const; +}; +} // namespace graph +} // namespace sd + +#endif // LIBND4J_BLOCK_H diff --git a/libnd4j/include/graph/ContextPrototype.h b/libnd4j/include/graph/ContextPrototype.h index e61831fa7072..e780c4aa93a0 100644 --- a/libnd4j/include/graph/ContextPrototype.h +++ b/libnd4j/include/graph/ContextPrototype.h @@ -22,118 +22,143 @@ #ifndef ND4J_CONTEXT_PROTOTYPE_H #define ND4J_CONTEXT_PROTOTYPE_H -#include -#include #include -#include -#include -#include #include #include +#include +#include +#include +#include + +#include #ifndef __STANDALONE_BUILD__ #include #endif namespace sd { - namespace graph { - - class ND4J_EXPORT ContextPrototype { - protected: - // int ids of the input nodes - std::vector> _inputs; - int _nodeId; - std::vector _tArgs; - std::vector _iArgs; - std::vector _bArgs; - std::vector _axis; - std::vector _dArgs; - - // TODO: remove this field - sd::DataType _dataType = sd::DataType::FLOAT32; - bool _isInplace; - - // opNum for legacy XYZ ops - int _opNum = -1; - uint64_t _rootSeed; - RandomGenerator _randomGenerator; - - std::vector _dataTypes; - - sd::ops::OpDescriptor* _opDescriptor; - bool _useMKLDNN = sd::Environment::getInstance().isUseMKLDNN(); - - // target engine for execution - samediff::Engine _engine = DEFAULT_ENGINE; - - samediff::ExecutionMode _execMode = samediff::ExecutionMode::MODE_UNDEFINED; - public: - explicit ContextPrototype(sd::ops::OpDescriptor* opDescriptor = nullptr, int nodeId = 1, bool inPlace = false); - ~ContextPrototype() = default; - - int getNodeId(); - int nodeId(); - - // this method returns true, if inputs are defined - bool hasVariablesFilled(); - - void setOpDescriptor(sd::ops::OpDescriptor* opDescriptor); - - virtual sd::DataType dataType(); - virtual sd::DataType dataType(int index); - virtual void setDataType(int index, sd::DataType type); - - bool isInplace(); - void markInplace(bool reallyInplace); - - void pickInput(int input); - void pickInput(int input, int index); - void pickInput(std::pair& p); - void fillInputs(std::initializer_list inputs); - void fillInputs(std::vector& inputs); - std::vector>* inputs(); - - std::vector* getTArguments(); - std::vector* getIArguments(); - std::vector* getBArguments(); - std::vector* getDArguments(); - std::vector* getAxis(); - - samediff::Engine engine(); - - size_t numT(); - size_t numI(); - size_t numB(); - size_t numD(); - - std::pair* input(int idx); - - int opNum(); - void setOpNum(int opNum); - - bool isUseMKLDNN() { return _useMKLDNN; } - void setUseMKLDNN(bool useMKLDNN) { _useMKLDNN = useMKLDNN; } - - /** - * This method returns number of inputs available in this block - * @return - */ - virtual unsigned long width(); - - // just a clone - ContextPrototype* clone(); - - template - ContextPrototype* asT(); - - RandomGenerator& randomGenerator() {return _randomGenerator;} - RandomGenerator const& getRng()const { return _randomGenerator; } - void setRng(RandomGenerator const& anotherRng) { _randomGenerator = anotherRng; } - void setRandomGenerator(RandomGenerator const& anotherRng) { _randomGenerator = anotherRng; } - uint64_t randomSeed() const { return _rootSeed; } - void setRandomSeed(uint64_t seed) { _rootSeed = seed; } - }; - } -} - -#endif //ND4J_CONTEXT_PROTOTYPE_H +namespace graph { + +class SD_EXPORT ContextPrototype { + protected: + // int ids of the input nodes + std::vector> _inputs; + int _nodeId; + std::string _name; + std::vector _tArgs; + std::vector _iArgs; + std::vector _bArgs; + std::vector _axis; + std::vector _dArgs; + + bool _isInplace; + + // opNum for legacy XYZ ops + int _opNum = -1; + uint64_t _rootSeed; + RandomGenerator _randomGenerator; + + sd::ops::OpDescriptor* _opDescriptor; + bool _useMKLDNN = sd::Environment::getInstance().isUseMKLDNN(); + + // target engine for execution + samediff::Engine _engine = DEFAULT_ENGINE; + + samediff::ExecutionMode _execMode = samediff::ExecutionMode::MODE_UNDEFINED; + + public: + explicit ContextPrototype(sd::ops::OpDescriptor* opDescriptor = nullptr, + int nodeId = 1, bool inPlace = false); + ~ContextPrototype() = default; + + ContextPrototype(const ContextPrototype& other) noexcept; + + ContextPrototype& operator=(const ContextPrototype& other) noexcept; + + // move constructor + ContextPrototype(ContextPrototype&& other) noexcept; + + // move assignment operator + ContextPrototype& operator=(ContextPrototype&& other) noexcept; + + int getNodeId() const; + int nodeId() const; + void setNodeId(int id); + + // this method returns true, if inputs are defined + bool hasVariablesFilled() const; + + void setOpDescriptor(sd::ops::OpDescriptor* opDescriptor); + + bool isInplace() const; + void markInplace(bool reallyInplace); + + void pickInput(int input); + void pickInput(int input, int index); + void pickInput(const std::pair& p); + void fillInputs(std::initializer_list inputs); + void fillInputs(std::vector& inputs); + const std::vector>& inputs() const; + + const std::vector& getTArguments() const; + const std::vector& getIArguments() const; + const std::vector& getBArguments() const; + const std::vector& getDArguments() const; + const std::vector& getAxis() const; + + void appendI(const std::vector& value); + void appendT(const std::vector& value); + void appendB(const std::vector& value); + void appendD(const std::vector& value); + + void appendA(Nd4jLong value); + void appendI(Nd4jLong value); + void appendT(double value); + void appendB(bool value); + void appendD(DataType value); + + samediff::Engine engine() const; + + size_t numT() const; + size_t numI() const; + size_t numB() const; + size_t numD() const; + + const std::pair& input(int idx) const; + + int opNum() const; + void setOpNum(int opNum); + + bool isUseMKLDNN() const { return _useMKLDNN; } + void setUseMKLDNN(bool useMKLDNN) { _useMKLDNN = useMKLDNN; } + + const std::string& name() const; + void setName(const std::string& name); + + /** + * This method returns number of inputs available in this block + * @return + */ + virtual unsigned long width() const; + + // just a clone + ContextPrototype* clone(); + + template + ContextPrototype* asT(); + + RandomGenerator& randomGenerator() { return _randomGenerator; } + RandomGenerator const& getRng() const { return _randomGenerator; } + void setRng(RandomGenerator const& anotherRng) { + _randomGenerator = anotherRng; + } + void setRandomGenerator(RandomGenerator const& anotherRng) { + _randomGenerator = anotherRng; + } + uint64_t randomSeed() const { return _rootSeed; } + void setRandomSeed(uint64_t seed) { _rootSeed = seed; } +}; +} // namespace graph +} // namespace sd + +#endif // ND4J_CONTEXT_PROTOTYPE_H diff --git a/libnd4j/include/graph/ExecutionResult.h b/libnd4j/include/graph/ExecutionResult.h index b1f16032c314..b67320ef546d 100644 --- a/libnd4j/include/graph/ExecutionResult.h +++ b/libnd4j/include/graph/ExecutionResult.h @@ -21,74 +21,77 @@ #ifndef LIBND4J_EXECUTION_RESULT #define LIBND4J_EXECUTION_RESULT -#include +#include +#include + #include -#include #include #include -#include -#include +#include +#include namespace sd { - namespace graph { - class ExecutionResult { - private: - std::vector _variables; - MAP_IMPL _stringIdMap; - MAP_IMPL, Variable *> _pairIdMap; - - // this flag is used to optionally release variables - bool _releasable = false; - public: - ExecutionResult(const FlatResult* flatResult); - ExecutionResult(std::initializer_list variables); - ExecutionResult() = default; - ~ExecutionResult(); - - /** - * This method adds variable pointer to result - */ - void emplace_back(Variable *variable); - - /** - * This method returns Variable by its position in output - */ - Variable* at(int position); - - /** - * This method returns Variable by its string id - */ - Variable* byId(std::string &id); - - /** - * This method returns Variable by its string id - */ - Variable* byId(const char *str); - - /** - * This method returns Variable by its numeric id:index pair - */ - Variable* byId(std::pair &id); - - /** - * This method returns Variable by its numeric id with index 0 - */ - Variable* byId(int id); - - /** - * This method returns number of elements stored in this entity - * @return - */ - Nd4jLong size(); +namespace graph { +class ExecutionResult { + private: + std::vector _variables; + MAP_IMPL _stringIdMap; + MAP_IMPL, Variable *> _pairIdMap; + + // this flag is used to optionally release variables + bool _releasable = false; + + public: + ExecutionResult(const FlatResult *flatResult); + ExecutionResult(std::initializer_list variables); + ExecutionResult() = default; + ~ExecutionResult(); + + /** + * This method adds variable pointer to result + */ + void emplace_back(Variable *variable); + + /** + * This method returns Variable by its position in output + */ + Variable *at(int position); + + /** + * This method returns Variable by its string id + */ + Variable *byId(std::string &id); + + /** + * This method returns Variable by its string id + */ + Variable *byId(const char *str); + + /** + * This method returns Variable by its numeric id:index pair + */ + Variable *byId(std::pair &id); + + /** + * This method returns Variable by its numeric id with index 0 + */ + Variable *byId(int id); + + /** + * This method returns number of elements stored in this entity + * @return + */ + Nd4jLong size(); #ifndef __JAVACPP_HACK__ - /** - * This method converts ExecutionResult entity to FlatResult - */ - flatbuffers::Offset asFlatResult(flatbuffers::FlatBufferBuilder &builder); + /** + * This method converts ExecutionResult entity to FlatResult + */ + flatbuffers::Offset asFlatResult( + flatbuffers::FlatBufferBuilder &builder); #endif - }; - } -} +}; +} // namespace graph +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/graph/ExecutorConfiguration.h b/libnd4j/include/graph/ExecutorConfiguration.h index 40f299f02513..758300595e1e 100644 --- a/libnd4j/include/graph/ExecutorConfiguration.h +++ b/libnd4j/include/graph/ExecutorConfiguration.h @@ -22,31 +22,33 @@ #define LIBND4J_EXECUTORCONFIGURATION_H #include -#include #include +#include namespace sd { - namespace graph { - class ND4J_EXPORT ExecutorConfiguration { - public: - sd::graph::ProfilingMode _profilingMode; - sd::graph::ExecutionMode _executionMode; - sd::graph::OutputMode _outputMode; - bool _timestats; - Nd4jLong _footprintForward = 0L; - Nd4jLong _footprintBackward = 0L; - Direction _direction = Direction_FORWARD_ONLY; - - explicit ExecutorConfiguration(const sd::graph::FlatConfiguration *conf = nullptr); - ~ExecutorConfiguration() = default; - - ExecutorConfiguration* clone(); +namespace graph { +class SD_EXPORT ExecutorConfiguration { + public: + sd::graph::ProfilingMode _profilingMode; + sd::graph::ExecutionMode _executionMode; + sd::graph::OutputMode _outputMode; + bool _timestats; + Nd4jLong _footprintForward = 0L; + Nd4jLong _footprintBackward = 0L; + Direction _direction = Direction_FORWARD_ONLY; + + explicit ExecutorConfiguration( + const sd::graph::FlatConfiguration *conf = nullptr); + ~ExecutorConfiguration() = default; + + ExecutorConfiguration clone() const; #ifndef __JAVACPP_HACK__ - flatbuffers::Offset asFlatConfiguration(flatbuffers::FlatBufferBuilder &builder); + flatbuffers::Offset asFlatConfiguration( + flatbuffers::FlatBufferBuilder &builder); #endif - }; - } -} +}; +} // namespace graph +} // namespace sd -#endif //LIBND4J_EXECUTORCONFIGURATION_H +#endif // LIBND4J_EXECUTORCONFIGURATION_H diff --git a/libnd4j/include/graph/FlatUtils.h b/libnd4j/include/graph/FlatUtils.h index 1b2a02dca841..180c141a3d2f 100644 --- a/libnd4j/include/graph/FlatUtils.h +++ b/libnd4j/include/graph/FlatUtils.h @@ -21,25 +21,27 @@ #ifndef LIBND4J_FLATUTILS_H #define LIBND4J_FLATUTILS_H -#include -#include +#include #include #include -#include +#include + +#include namespace sd { - namespace graph { - class ND4J_EXPORT FlatUtils { - public: - static std::pair fromIntPair(IntPair* pair); +namespace graph { +class SD_EXPORT FlatUtils { + public: + static std::pair fromIntPair(IntPair* pair); - static std::pair fromLongPair(LongPair* pair); + static std::pair fromLongPair(LongPair* pair); - static NDArray* fromFlatArray(const sd::graph::FlatArray* flatArray); + static NDArray fromFlatArray(const sd::graph::FlatArray* flatArray); - static flatbuffers::Offset toFlatArray(flatbuffers::FlatBufferBuilder &builder, NDArray &array); - }; - } -} + static flatbuffers::Offset toFlatArray( + flatbuffers::FlatBufferBuilder& builder, NDArray& array); +}; +} // namespace graph +} // namespace sd -#endif //LIBND4J_FLATUTILS_H +#endif // LIBND4J_FLATUTILS_H diff --git a/libnd4j/include/graph/FlowPath.h b/libnd4j/include/graph/FlowPath.h index 59752024929e..e9e8ddc9122f 100644 --- a/libnd4j/include/graph/FlowPath.h +++ b/libnd4j/include/graph/FlowPath.h @@ -21,67 +21,68 @@ #ifndef LIBND4J_FLOWPATH_H #define LIBND4J_FLOWPATH_H -#include -#include -#include -#include -#include #include +#include #include #include +#include +#include + +#include +#include namespace sd { - namespace graph { - class ND4J_EXPORT FlowPath { - private: - MAP_IMPL _states; - MAP_IMPL _frames; +namespace graph { +class SD_EXPORT FlowPath { + private: + MAP_IMPL _states; + MAP_IMPL _frames; - void ensureNode(int nodeId); - void ensureFrame(int nodeId); + void ensureNode(int nodeId); + void ensureFrame(int nodeId); - GraphProfile _profile; - public: - FlowPath() = default; - ~FlowPath() = default; + GraphProfile _profile; - void setInnerTime(int nodeId, Nd4jLong time); - void setOuterTime(int nodeId, Nd4jLong time); + public: + FlowPath() = default; + ~FlowPath() = default; - Nd4jLong innerTime(int nodeId); - Nd4jLong outerTime(int nodeId); + void setInnerTime(int nodeId, Nd4jLong time); + void setOuterTime(int nodeId, Nd4jLong time); - bool isNodeActive(int nodeId); - void markNodeActive(int nodeId, bool isActive); + Nd4jLong innerTime(int nodeId); + Nd4jLong outerTime(int nodeId); - bool wasExecuted(int nodeId); - void markExecuted(int nodeId, bool wasExecuted); + bool isNodeActive(int nodeId); + void markNodeActive(int nodeId, bool isActive); - int branch(int nodeId); - void markBranch(int nodeId, int index); + bool wasExecuted(int nodeId); + void markExecuted(int nodeId, bool wasExecuted); - // Frame-related methods + int branch(int nodeId); + void markBranch(int nodeId, int index); - void registerFrame(Nd4jLong frameId); - void forgetFrame(Nd4jLong frameId); + // Frame-related methods - bool isFrameActive(Nd4jLong frameId); - void markFrameActive(Nd4jLong frameId, bool isActive); + void registerFrame(Nd4jLong frameId); + void forgetFrame(Nd4jLong frameId); - bool isRewindPlanned(Nd4jLong frameId); - void planRewind(Nd4jLong frameId, bool reallyRewind); + bool isFrameActive(Nd4jLong frameId); + void markFrameActive(Nd4jLong frameId, bool isActive); - int getRewindPosition(Nd4jLong frameId); - void setRewindPosition(Nd4jLong frameId, int position); - void setRewindPositionOnce(Nd4jLong frameId, int position); + bool isRewindPlanned(Nd4jLong frameId); + void planRewind(Nd4jLong frameId, bool reallyRewind); - void incrementNumberOfCycles(Nd4jLong frameId); - Nd4jLong getNumberOfCycles(Nd4jLong frameId); + int getRewindPosition(Nd4jLong frameId); + void setRewindPosition(Nd4jLong frameId, int position); + void setRewindPositionOnce(Nd4jLong frameId, int position); - GraphProfile* profile(); - }; - } -} + void incrementNumberOfCycles(Nd4jLong frameId); + Nd4jLong getNumberOfCycles(Nd4jLong frameId); + GraphProfile* profile(); +}; +} // namespace graph +} // namespace sd -#endif //LIBND4J_FLOWPATH_H +#endif // LIBND4J_FLOWPATH_H diff --git a/libnd4j/include/graph/FrameState.h b/libnd4j/include/graph/FrameState.h index 1c0edbc0bbcb..564cf5750ebe 100644 --- a/libnd4j/include/graph/FrameState.h +++ b/libnd4j/include/graph/FrameState.h @@ -21,87 +21,89 @@ #ifndef LIBND4J_FRAMESTATE_H #define LIBND4J_FRAMESTATE_H -#include -#include #include +#include + +#include namespace sd { - namespace graph { - class ND4J_EXPORT FrameState { - private: - std::string _name; - Nd4jLong _id = 0; - int _numberOfCycles = 0; - bool _activated = false; - - bool _rewindPlanned = false; - int _rewindPosition = -1; - public: - FrameState(Nd4jLong id = 0); - ~FrameState() = default; - - /** - * This method returns number of cycles passed for this Frame - * - * @return - */ - int getNumberOfCycles(); - - /** - * This method increments number of cycles by 1 for this Frame - */ - void incrementNumberOfCycles(); - - /** - * This method returns TRUE is frame was activated at LoopCond - * @return - */ - bool wasActivated(); - - /** - * This method allows to toggle activated state of this Frame - * @param reallyActivated - */ - void markActivated(bool reallyActivated); - - /** - * This method returns of this Frame (if it's set) - * @return - */ - std::string& getFrameName(); - - /** - * This method returns TRUE if reset is planned for this Frame - * @return - */ - bool isRewindPlanned(); - - /** - * This method allows you to toggle flag for planned rewind - * @param reallyPlanning - */ - void planRewind(bool reallyPlanning); - - /** - * This method returns planned reset position for given Frame - * @return - */ - int getRewindPosition(); - - /** - * This method allows to set rewind position for this Frame - * @param pos - */ - void setRewindPosition(int pos); - - /** - * This method allows to set rewind position for this Frame, but only if it wasn't set earlier - * @param pos - */ - void setRewindPositionOnce(int pos); - }; - } -} - - -#endif //LIBND4J_FRAMESTATE_H +namespace graph { +class SD_EXPORT FrameState { + private: + std::string _name; + Nd4jLong _id = 0; + int _numberOfCycles = 0; + bool _activated = false; + + bool _rewindPlanned = false; + int _rewindPosition = -1; + + public: + FrameState(Nd4jLong id = 0); + ~FrameState() = default; + + /** + * This method returns number of cycles passed for this Frame + * + * @return + */ + int getNumberOfCycles(); + + /** + * This method increments number of cycles by 1 for this Frame + */ + void incrementNumberOfCycles(); + + /** + * This method returns TRUE is frame was activated at LoopCond + * @return + */ + bool wasActivated(); + + /** + * This method allows to toggle activated state of this Frame + * @param reallyActivated + */ + void markActivated(bool reallyActivated); + + /** + * This method returns of this Frame (if it's set) + * @return + */ + std::string& getFrameName(); + + /** + * This method returns TRUE if reset is planned for this Frame + * @return + */ + bool isRewindPlanned(); + + /** + * This method allows you to toggle flag for planned rewind + * @param reallyPlanning + */ + void planRewind(bool reallyPlanning); + + /** + * This method returns planned reset position for given Frame + * @return + */ + int getRewindPosition(); + + /** + * This method allows to set rewind position for this Frame + * @param pos + */ + void setRewindPosition(int pos); + + /** + * This method allows to set rewind position for this Frame, but only if it + * wasn't set earlier + * @param pos + */ + void setRewindPositionOnce(int pos); +}; +} // namespace graph +} // namespace sd + +#endif // LIBND4J_FRAMESTATE_H diff --git a/libnd4j/include/graph/Graph.h b/libnd4j/include/graph/Graph.h index a160872fd2db..81dfc67c31ce 100644 --- a/libnd4j/include/graph/Graph.h +++ b/libnd4j/include/graph/Graph.h @@ -21,262 +21,183 @@ #ifndef LIBND4J_GRAPH_H #define LIBND4J_GRAPH_H -#include #include -#include +#include #include +#include //#include +#include #include +#include #include -#include #include #include -#include -#include +#include #include -#include +#include +#include +#include #include namespace sd { - namespace graph { +namespace graph { + +class NodeInfo; +class SD_EXPORT Graph { + protected: + ExecutorConfiguration _configuration; + VariableSpace _variableSpace; + + // TODO: these 2 fields should be deleted + memory::Workspace _workspace; + Stash _stash; + + MAP_IMPL _unmapped; + + // string -> id conversion table + MAP_IMPL _symbolicLookupTable; + + std::mutex _mutexPreprocessing; + std::atomic _built; + + // we want to know last node id + int _maxId = 1; + + const GraphMemoryManager &_memoryManager; + + //////////////////////////////////////// + Nd4jStatus validateNode(Node *node); + + int idByName(const std::string &nodeName) const; + + void printOutNode(const Node &node) const; + + std::vector _placeholders; + + mutable OptimizedGraph _optimized; + + mutable std::mutex _optimizedLock; + - class ND4J_EXPORT Graph { - protected: - ExecutorConfiguration *_configuration; - VariableSpace *_variableSpace; - Stash* _stash; + std::vector _handles; - // this list holds references to Node ptrs, which should be free'd in Graph destructor - std::vector _handles; + public: + Graph(const FlatGraph *flatGraph = nullptr, + const GraphMemoryManager &memoryManager = GraphMemoryManager()); - // vector holds ID's of top nodes only - std::vector *_nodes; - MAP_IMPL *_mapped; + ~Graph(); - MAP_IMPL *> *_onion; - MAP_IMPL _unmapped; - std::vector _unmappedMap; // macOS? + Graph(const Graph &other); - std::mutex _mutexPreprocessing; - std::atomic _built; + Graph &operator=(const Graph &other) noexcept; - std::vector _output; - std::vector _autos; + // move constructor + Graph(Graph &&other); + // move assignment operator + Graph &operator=(Graph &&other) noexcept; - MAP_IMPL _mappedScopes; - std::vector _scopes; + /** + * Methods that allow Graph imports + */ + static Graph importFromTensorFlow(const char *fileName); + static Graph fromFlatBuffers( + const char *fileName, + const GraphMemoryManager &memoryManager = GraphMemoryManager()); + static Graph fromFlatPointer( + void *ptr, + const GraphMemoryManager &memoryManager = GraphMemoryManager()); -//////////////////////////////////////// - Nd4jStatus validateNode(sd::graph::Node *node); + // method that'll print out graph + Nd4jStatus validate(); - void expandOnion(int newLayer); + // this method returns total number of nodes in this graph + int size() const; - void injectNode(sd::graph::Node *node); + int numberOfPlaceholders() const; - void pushToOutputOnce(int id); + const std::vector> &placeholders() const; - void printOutNode(Node* node); + /** + * This method returns pointer to thread_local VariableSpace + * @return + */ + VariableSpace &variableSpace() const; + /** + * This method returns unmapped nodes + * @return + */ + const MAP_IMPL &unmappedNodes() const { return _unmapped; }; - void prepareOutputs(); + const GraphMemoryManager &memoryManager() const; - public: - Graph(const FlatGraph *flatGraph = nullptr, VariableSpace *variableSpace = nullptr); + /** + * These methods add given node to the graph + * @param node + */ + void addNode(Node &&node, const std::vector &inputs); - ~Graph(); + void addNode(Node &node, const std::vector &inputs); - // this method applies toposort to nodes - void toposortNodes(); - // method that'll print out graph - Nd4jStatus validate(); + void addVariable(const std::string &name, NDArray &array); + void addVariable(const std::string &name, NDArray &&array); - // this method will build structured representation of graph - Nd4jStatus buildGraph(); + /** + * This method allows to add placeholder with some pre-defined properties + */ + void addPlaceholder(const std::string &nodeName, + const DataType dataType = sd::DataType::ANY, + const std::vector &shape = {}); - // this method will return estimated memory size (in bytes) required for 1 full graph execution round - Nd4jLong estimateRequiredMemory(); + /** + * This method returns pointer to ExecutorConfiguration + * + * @return + */ + const ExecutorConfiguration &getExecutorConfiguration() const; - // this method returns number of root nodes in this graph - int rootNodes(); + /** + * This method prints out Graph op-by-op, and respective inputs + */ + void printOut(); - // this method returns total number of nodes in this graph - int totalNodes(); + /** + * This method returns clone of the graph + */ + Graph *clone() const; - int numberOfPlaceholders(); + /** + * This method returns clone of the graph, backed by VariableProxy instead of + * VariableSpace + */ + Graph cloneWithProxy() const; - std::vector* getPlaceholders(); - - /** - * This method returns pointer to thread_local VariableSpace - * @return - */ - sd::graph::VariableSpace *getVariableSpace(); - - /** - * This method adds given node to the graph - * - * @param node - */ - void addNode(sd::graph::Node *node); - - /** - * This method returns layered representation of the graph - * - * @return - */ - MAP_IMPL *> *getOnion(); - - /** - * This method returns map of all nodes of the graph - * @return - */ - MAP_IMPL* getMapped(); - - /** - * This method returns outputs of this graph - * @return - */ - std::vector *fetchOutputs(); - - /** - * This method returns pointer to ExecutorConfiguration - * - * @return - */ - ExecutorConfiguration *getExecutorConfiguration(); + /** + * This method returns hash of given Graph instance + */ + Nd4jLong hashCode() const; - /** - * This method adds specified node (by ID) to de - * @param id - */ - void addOutput(int id); - - /** - * This method returns all nodes at once (order is NOT guaranteed) - * @return - */ - std::vector *getAllNodes(); - - /** - * This method prints out Graph op-by-op, and respective inputs - */ - void printOut(); - - /** - * This method collect all ops from the graph into ops vector - */ - std::vector getOperations(); - - /** - * This method returns Scope ptr specified with id - * - * @param id - * @return - */ - Scope* scopeById(int id); - - /** - * This method returns TRUE if specified ID refers to Scope, and false otherwise - * @param id - * @return - */ - bool hasScope(int id); - - /** - * This method returns clone of the graph - */ - Graph* clone(); - - /** - * This method returns clone of the graph, backed by VariableProxy instead of VariableSpace - */ - Graph* cloneWithProxy(); - - /** - * This method removes reference to VariableSpace from this Graph - */ - void forgetVariableSpace(); - - /** - * This method returns Node with given Id - */ - Node* nodeById(int nodeId); - - /** - * This method returns True if node with given ID exists, False otherwise - * @param nodeId - * @return - */ - bool hasNode(int nodeId); - - /** - * This method returns hash of given Graph instance - */ - Nd4jLong hashCode(); - - /** - * PLEASE NOTE: This method will be moved to private section - */ - void tagInplaceNodes(); - - void replaceState(VariableSpace *state, ExecutorConfiguration *configuration); - - FORCEINLINE std::vector* nodes() { - return _nodes; - } - - FORCEINLINE std::vector* autos() { - return &_autos; - } - - FORCEINLINE std::vector* output() { - return &_output; - } - - FORCEINLINE MAP_IMPL* scopes() { - return &_mappedScopes; - } - - FORCEINLINE bool built() { - return _built.load(); - } - - FORCEINLINE void pullState(Graph *other) { - for (int e = 0; e < other->nodes()->size(); e++) - this->_nodes->emplace_back(other->nodes()->at(e)); - - for (int e = 0; e < other->output()->size(); e++) - this->_output.emplace_back(other->output()->at(e)); - - for (int e = 0; e < other->autos()->size(); e++) - this->_autos.emplace_back(other->autos()->at(e)); - - for (auto &v: *other->scopes()) { - auto scp = v.second->clone(); - this->_mappedScopes[v.first] = scp; - this->_scopes.emplace_back(scp); - } - - for (auto &v: *other->getOnion()) { - auto vec = this->_onion->count(v.first) > 0 ? this->_onion->at(v.first) : new std::vector(); - - auto ovec = (*other->getOnion())[v.first]; - for (auto x: *(ovec)) { - auto n = x->clone(); - vec->emplace_back(n); - _handles.emplace_back(n); - (*this->_mapped)[n->id()] = n; - } - - if (this->_onion->count(v.first) < 1) - (*this->_onion)[v.first] = vec; - } - - this->_built.store(other->built()); - } - }; - } -} - -#endif //LIBND4J_GRAPH_H + void replaceState(VariableSpace *state, + const ExecutorConfiguration &configuration); + + FORCEINLINE bool built(); + + const OptimizedGraph &optimizedGraph() const; + + /** + * This method executes this Graph instance and returns execution results + * @param dictionary + * @return + */ + std::map execute( + const std::map &dictionary = {}, + const std::vector &outputs = {}, + const GraphExecutor &executor = GraphExecutor()) const; +}; + +FORCEINLINE bool Graph::built() { return _built.load(); } +} // namespace graph +} // namespace sd + +#endif // LIBND4J_GRAPH_H diff --git a/libnd4j/include/graph/GraphExecutioner.h b/libnd4j/include/graph/GraphExecutioner.h deleted file mode 100644 index 148b27951a01..000000000000 --- a/libnd4j/include/graph/GraphExecutioner.h +++ /dev/null @@ -1,84 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -// -// @author raver119@gmail.com -// - -#ifndef LIBND4J_GRAPHEXECUTIONER_H -#define LIBND4J_GRAPHEXECUTIONER_H - -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include - -#define TF_INPUT "Placeholder" -#define TF_CONST "Const" -#define TF_VAR "VariableV2" - -namespace sd { - namespace graph { - - class ND4J_EXPORT GraphExecutioner { - protected: - - - public: - //static Nd4jStatus executeFlatNode(sd::graph::Graph *graph, sd::graph::Node *node, sd::graph::VariableSpace *variableSpace); - - static Nd4jStatus executeFlatNode(Graph *graph, Node *node, VariableSpace *variableSpace); - - /** - * This method executes given Graph - * @return - */ - static Nd4jStatus execute(Graph *graph, VariableSpace *variableSpace = nullptr); - - - /** - * This method executes graph stored at given FlatBuffers pointer - * - * @param pointer Pointer to FlatBuffer - * @return pointer to FlatBuffer with result - */ - static sd::graph::ResultWrapper* executeFlatBuffer(Nd4jPointer pointer); - - static flatbuffers::Offset execute(Graph *graph, flatbuffers::FlatBufferBuilder &builder, const FlatInferenceRequest* request); - - static Graph *importFromTensorFlow(const char *fileName); - - - static Graph *importFromFlatBuffers(const char *filename); - - static Graph *importFromFlatPointer(Nd4jPointer ptr); - }; - - long getFileSize(const char * filename); - uint8_t* readFlatBuffers(const char * filename); - - } -} - - -#endif //LIBND4J_GRAPHEXECUTIONER_H diff --git a/libnd4j/include/graph/GraphHolder.h b/libnd4j/include/graph/GraphHolder.h index 84aebd6948c1..7c166036e48d 100644 --- a/libnd4j/include/graph/GraphHolder.h +++ b/libnd4j/include/graph/GraphHolder.h @@ -18,76 +18,45 @@ // @author raver119@gmail.com // +#include +#include +#include #include #include -#include + #include -#include -#include -#include +#include namespace sd { - namespace graph { - class ND4J_EXPORT GraphHolder { - private: - MAP_IMPL _graphF; - - MAP_IMPL _locks; - - GraphHolder() = default; - ~GraphHolder() = default; - public: - static GraphHolder& getInstance(); - - void registerGraph(Nd4jLong graphId, Graph *graph); - - Graph* cloneGraph(Nd4jLong graphId); - - Graph* pullGraph(Nd4jLong graphId); - - void forgetGraph(Nd4jLong graphId); - - void dropGraph(Nd4jLong graphId); - - void dropGraphAny(Nd4jLong graphId); - - bool hasGraph(Nd4jLong graphId); - - bool hasGraphAny(Nd4jLong graphId); +namespace graph { +class SD_EXPORT GraphHolder { + private: - flatbuffers::Offset execute(Nd4jLong graphId, flatbuffers::FlatBufferBuilder &builder, const FlatInferenceRequest* request); + MAP_IMPL _graphs; - void replaceGraph(Nd4jLong graphId, Graph *graph); + std::mutex _mutex; - ///////////////////////////// + GraphHolder() = default; + ~GraphHolder() = default; - FORCEINLINE void lockWrite(Nd4jLong graphId) { - if (_locks.count(graphId) == 0) - return; + public: + static GraphHolder &getInstance(); - _locks[graphId].lockWrite(); - } + void registerGraph(Nd4jLong graphId, const Graph &graph); - FORCEINLINE void unlockWrite(Nd4jLong graphId) { - if (_locks.count(graphId) == 0) - return; + Graph &graph(Nd4jLong graphId); - _locks[graphId].unlockWrite(); - } + void forgetGraph(Nd4jLong graphId); - FORCEINLINE void lockRead(Nd4jLong graphId) { - if (_locks.count(graphId) == 0) - return; + void dropGraph(Nd4jLong graphId); - _locks[graphId].lockRead(); - } + bool hasGraph(Nd4jLong graphId); - FORCEINLINE void unlockRead(Nd4jLong graphId) { - if (_locks.count(graphId) == 0) - return; + flatbuffers::Offset execute( + Nd4jLong graphId, flatbuffers::FlatBufferBuilder &builder, + const FlatInferenceRequest *request); - _locks[graphId].unlockRead(); - } - }; - } -} \ No newline at end of file + void replaceGraph(Nd4jLong graphId, const Graph &graph); +}; +} // namespace graph +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/graph/GraphState.h b/libnd4j/include/graph/GraphState.h deleted file mode 100644 index 89343997fa40..000000000000 --- a/libnd4j/include/graph/GraphState.h +++ /dev/null @@ -1,142 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -// -// Created by raver119 on 23.01.18. -// - -#ifndef LIBND4J_GRAPHSTATE_H -#define LIBND4J_GRAPHSTATE_H - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace sd { -namespace graph { - - class ND4J_EXPORT GraphState { - protected: - // id of this GraphState instance - Nd4jLong _id = 0; - - // map of scopes. Scope id is used as key, since it's referred in calls later anyway - MAP_IMPL _scopes; - - // this variable space holds temp references - VariableSpace _variableSpace; - - Graph *_graph; - - public: - explicit GraphState(Nd4jLong id); - ~GraphState(); - - /** - * - * @return - */ - Nd4jLong id(); - - /** - * This method adds scope to this state tracker - * - * @param scopeId - * @return - */ - Nd4jStatus registerScope(int scopeId); - - /** - * This method cheks if scope with given ID exists - * - * @param scopeId - ID of the scope - * @return - TRUE if scope exists, FALSE otherwise - */ - bool hasScope(int scopeId); - - /** - * This method removes specified scope from this state tracker - * - * @param scopeId - * @return - */ - Nd4jStatus forgetScope(int scopeId); - -#ifndef __JAVACPP_HACK__ - /** - * This method adds given op to the end of specified scope - * PLEASE NOTE: This method is used for tests mostly - * - * @param scopeId - * @param op - * @return - */ - Nd4jStatus attachOpToScope(int scopeId, int nodeId, sd::ops::DeclarableOp *op, ArgumentsList inputs); - - /** - * This method returns pointer to the scope with given id - * - * @param scopeId - id of the scope - */ - Scope* getScope(int scopeId); - - Graph* graph(); -#endif - /** - * This method adds given op to the end of specified scope - * - * @param scopeId - * @param opNum - * @param type - * @return - */ - Nd4jStatus attachOpToScope(int scopeId, Nd4jLong opNum, int type, ArgumentsList inputs); - - /** - * This method adds return statement to specified scope - * - * PLEASE NOTE: should be used only in body scopes - * - * @param scopeId - * @param nodeId - * @param args - * @return - */ - Nd4jStatus defineReturn(int scopeId, int nodeId, ArgumentsList args); - - /** - * This method returns current variable space of this state holder - * - * @return - */ - VariableSpace* variableSpace(); - }; -} -} - - - -#endif //LIBND4J_GRAPHSTATE_H diff --git a/libnd4j/include/graph/GraphUtils.h b/libnd4j/include/graph/GraphUtils.h index 3aaf820aeb8e..21e79e051a3d 100644 --- a/libnd4j/include/graph/GraphUtils.h +++ b/libnd4j/include/graph/GraphUtils.h @@ -21,23 +21,24 @@ #ifndef __H__GRAPH_UTILS__ #define __H__GRAPH_UTILS__ -#include -#include #include +#include + +#include namespace sd { namespace graph { -class ND4J_EXPORT GraphUtils { -public: - typedef std::vector OpList; +class SD_EXPORT GraphUtils { + public: + typedef std::vector OpList; -public: - static bool filterOperations(OpList& ops); - static std::string makeCommandLine(OpList& ops); - static int runPreprocessor(char const* input, char const* output); + public: + static bool filterOperations(OpList& ops); + static std::string makeCommandLine(OpList& ops); + static int runPreprocessor(char const* input, char const* output); }; -} -} +} // namespace graph +} // namespace sd #endif diff --git a/libnd4j/include/graph/InferenceRequest.h b/libnd4j/include/graph/InferenceRequest.h index b445fa0e1daa..ea195eaf4539 100644 --- a/libnd4j/include/graph/InferenceRequest.h +++ b/libnd4j/include/graph/InferenceRequest.h @@ -17,44 +17,45 @@ // // @author raver119@gmail.com // -#ifndef DEV_TESTS_INFERENCEREQUEST_H -#define DEV_TESTS_INFERENCEREQUEST_H +#ifndef SD_INFERENCEREQUEST_H +#define SD_INFERENCEREQUEST_H +#include +#include #include #include -#include -#include + #include "ExecutorConfiguration.h" namespace sd { - namespace graph { - class ND4J_EXPORT InferenceRequest { - private: - Nd4jLong _id; - std::vector _variables; - std::vector _deletables; +namespace graph { +class SD_EXPORT InferenceRequest { + private: + Nd4jLong _id; + std::vector> _variables; - ExecutorConfiguration *_configuration = nullptr; + ExecutorConfiguration _configuration; - void insertVariable(Variable* variable); - public: + void insertVariable(std::shared_ptr variable); - InferenceRequest(Nd4jLong graphId, ExecutorConfiguration *configuration = nullptr); - ~InferenceRequest(); + public: + InferenceRequest(Nd4jLong graphId, + const ExecutorConfiguration &configuration); + ~InferenceRequest(); - void appendVariable(int id, NDArray *array); - void appendVariable(int id, int index, NDArray *array); - void appendVariable(std::string &name, NDArray *array); - void appendVariable(std::string &name, int id, int index, NDArray *array); - void appendVariable(Variable *variable); + void appendVariable(int id, const NDArray &array); + void appendVariable(int id, int index, const NDArray &array); + void appendVariable(const std::string &name, const NDArray &array); + void appendVariable(const std::string &name, int id, int index, + const NDArray &array); + void appendVariable(std::shared_ptr variable); #ifndef __JAVACPP_HACK__ - flatbuffers::Offset asFlatInferenceRequest(flatbuffers::FlatBufferBuilder &builder); + flatbuffers::Offset asFlatInferenceRequest( + flatbuffers::FlatBufferBuilder &builder); #endif - }; - } -} - - +}; +} // namespace graph +} // namespace sd -#endif //DEV_TESTS_INFERENCEREQUEST_H +#endif // SD_INFERENCEREQUEST_H diff --git a/libnd4j/include/graph/Intervals.h b/libnd4j/include/graph/Intervals.h index 3a796407608d..e0ec274eb690 100644 --- a/libnd4j/include/graph/Intervals.h +++ b/libnd4j/include/graph/Intervals.h @@ -21,36 +21,33 @@ #ifndef LIBND4J_INTERVALS_H #define LIBND4J_INTERVALS_H +#include #include -#include + #include -#include +#include namespace sd { - class ND4J_EXPORT Intervals { - - private: - std::vector> _content; - - public: +class SD_EXPORT Intervals { + private: + std::vector> _content; - // default constructor - Intervals(); - - // constructor - Intervals(const std::initializer_list>& content ); - Intervals(const std::vector>& content ); - - // accessing operator - std::vector operator[](const Nd4jLong i) const; + public: + // default constructor + Intervals(); - // returns size of _content - int size() const; + // constructor + Intervals(const std::initializer_list>& content); + Intervals(const std::vector>& content); - }; + // accessing operator + std::vector operator[](const Nd4jLong i) const; + // returns size of _content + int size() const; +}; -} +} // namespace sd -#endif //LIBND4J_INTERVALS_H +#endif // LIBND4J_INTERVALS_H diff --git a/libnd4j/include/graph/Node.h b/libnd4j/include/graph/Node.h index 5fde65f3c16d..cc0803ab93bc 100644 --- a/libnd4j/include/graph/Node.h +++ b/libnd4j/include/graph/Node.h @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -18,226 +19,179 @@ // @author raver119@gmail.com // -#ifndef LIBND4J_GNODE_H -#define LIBND4J_GNODE_H +#ifndef SD_GNODE_H +#define SD_GNODE_H -#include -#include -#include #include -#include "Context.h" -#include #include +#include +#include +#include +#include -namespace sd { - namespace graph { - - - class Graph; - - class ND4J_EXPORT Node { - protected: - // TODO: this field must be removed - sd::DataType _dataType; - - OpType _opType; - ContextPrototype* _protoContext = nullptr; - Nd4jLong _opNum; - int _id; - std::vector> _input; - std::vector> _output; - std::vector _dimensions; - - std::vector _referencedBy; - - int * _dim = nullptr; - std::string _name; - - - // this variable points to onion layer within graph - int _layer = -1; - - // many ops require extra parameters to run - double *_extraParams = nullptr; - - - // optional scalar. used in scalar ops and in summary stats - // TODO: this field must be removed - NDArray _scalar; - - bool _hasExternalOutputs; - bool _hasExternalInputs; - bool _hasInternalOutputs; - bool _hasInternalInputs; - - // this field is used to check, if op should be used in-place (so it can/will modify its inputs) - bool _isInplace = false; - - // this field is used to delete attached customOp - bool _isDeductable = false; - - OpClass _opClass; - - // these fields are used to store embedded CustomOps and Graph in case of Graph-in-Graph scenario - sd::graph::Graph * _graph= nullptr; - sd::ops::DeclarableOp *_customOp = nullptr; - - // each node can be active or inactive, if used with divergents, like IF statements - bool _active = true; - - // these fields contain information about Scope these ops are related to - int _scope_id = 0; - std::string _scope_name; - - // TODO: these 3 fields should be removed - int _rewindNode = -1; - std::pair _rewindLayer = {-1, -1}; - Nd4jLong _frameId = -1; - - public: - explicit Node(sd::ops::DeclarableOp *customOp, int id = 0, std::initializer_list input = {}, std::initializer_list output = {}, std::initializer_list dimensions = {}, float scalar = 0.0f, std::initializer_list tArgs = {}, std::initializer_list iArgs = {}); - explicit Node(OpType opType = OpType_TRANSFORM_SAME, int opNum = 0, int id = 0, std::initializer_list input = {}, std::initializer_list output = {}, std::initializer_list dimensions = {}, float scalar = 0.0f, std::initializer_list tArgs = {}, std::initializer_list iArgs = {}); - explicit Node(const sd::graph::FlatNode *node); - ~Node(); - - bool equals(Node *other); - - sd::DataType dataType(); - ContextPrototype *protoContext(); - OpType opType(); - Nd4jLong opNum(); - int id(); - std::vector> *input(); - std::vector> *output(); - - Nd4jLong getFrameId(); - void setFrameId(Nd4jLong frameId); - - int getRewindNode(); - void setRewindNode(int nodeId); - - std::pair& getRewindLayer(); - void setRewindLayer(int layerId, int stepId = 0); - - void setId(int id); - - double *extraParams(); - - bool isMultiInput(); - bool isMultiOutput(); - - int getLayer(); - void setLayer(int layer); - - bool isDivergencePoint(); - void setActive(bool reallyActive); - bool isActive(); - - bool hasExternalOutputs(); - bool hasExternalInputs(); - bool hasInternalOutputs(); - bool hasInternalInputs(); - - double scalar(); - - std::vector * getDimensions(); - int * getDimensionsPtr(); - - - void pickOutputOnce(int outputId); - void pickOutput(int outputId); - void pickOutput(int nodeId, int outputId); - void pickExternalOutput(int outputId); - void pickInput(int inputId); - void pickInput(int nodeId, int outputId); - void pickInput(std::pair& id); - - bool isDeductable(); - void setDeductable(bool reallyDeductable); - - void setName(std::string *name); - void setName(const std::string& name); - std::string * getName(); - std::string * name(); - - int totalReferences(); - void addReference(int nodeId); - - void setContextPrototype(ContextPrototype *block); - ContextPrototype* getContextPrototype(); - bool hasBlockAttached(); - - void setCustomOp(sd::ops::DeclarableOp *customOp = nullptr); - sd::ops::DeclarableOp* getCustomOp(); - bool hasCustomOp(); - - void setGraph(sd::graph::Graph* graph = nullptr); - sd::graph::Graph* getGraph(); - bool hasGraphEmbedded(); - - bool isInplace(); - void markInplace(bool reallyInplace); - - - OpClass getOpClass(); - - // these methods are used for internal profiling - void setOuterTime(Nd4jLong time); - void setInnerTime(Nd4jLong time); - - // methods related to scopes - bool isScoped(); - void setScopeInfo(int id, const char* name = nullptr); - int scopeId(); - std::string* scopeName(); - - void setOpType(OpType opType); - - // clone Node - Node* clone(); - - template - Node* asT(); - - FORCEINLINE void pullValues(Node *other) { - - if (this->_protoContext != nullptr) - delete _protoContext; - - this->_dataType = other->dataType(); - this->_protoContext = other->protoContext()->clone(); - this->_scalar = other->scalar(); - this->_hasExternalInputs = other->hasExternalInputs(); - this->_hasExternalOutputs = other->hasExternalOutputs(); - this->_hasInternalInputs = other->hasInternalInputs(); - this->_hasInternalOutputs = other->hasInternalOutputs(); - - this->markInplace(other->isInplace()); - this->setActive(other->isActive()); - this->setScopeInfo(other->scopeId(), other->scopeName()->c_str()); - this->setLayer(other->getLayer()); - this->setDeductable(other->isDeductable()); - - - if (this->_customOp != nullptr && _isDeductable) - delete this->_customOp; - - for (auto v: *other->input()) - this->_input.emplace_back(v); - - for (auto v: *other->output()) - this->_output.emplace_back(v); - - for (auto v: *other->getDimensions()) - this->_dimensions.emplace_back(v); - - } +#include "Context.h" - static sd::ops::DeclarableOp* buildOpByType(OpType opType, int numInputs, int numIArgs, int numTArgs, int opNum, NDArray *scalar); - static void deleteOpByType(OpType opType, void *op); - }; - } -} +namespace sd { +namespace graph { + +class Graph; + +class SD_EXPORT Node { + protected: + // int and string IDs + int _id = 0; + std::string _name; + + // Node state, basically + ContextPrototype _protoContext; + + // these 2 fields are used for Logic ops only + OpType _opType = OpType_GRAPH; + OpClass _opClass = OpClass_GRAPH; + Nd4jLong _opNum = 0; + + // Inputs are stored in format + std::vector> _input; + + // Outputs are stored in format + std::vector> _output; + + // Control flow dependencies for Node + std::vector> _dependencies; + std::vector _stringDependencies; + + // TODO: these fields should be removed + // service state fields + bool _hasExternalOutputs = false; + bool _hasExternalInputs = false; + bool _hasInternalOutputs = false; + bool _hasInternalInputs = false; + + std::shared_ptr _customOp; + + // this field is for Enter nodes only + mutable int _frameId = -1; + mutable int _exitId = -1; + + public: + explicit Node(const sd::ops::DeclarableOp &op, + const std::string &nodeName = {}, + const std::vector &tArgs = {}, + const std::vector &iArgs = {}, + const std::vector &bArgs = {}, + const std::vector &dArgs = {}); + explicit Node(const std::string &opName, const std::string &nodeName = {}, + const std::vector &tArgs = {}, + const std::vector &iArgs = {}, + const std::vector &bArgs = {}, + const std::vector &dArgs = {}); + explicit Node(const FlatNode *node); + ~Node(); + + /* + * FIXME: deprecated methods, to be removed + */ + explicit Node(const std::string &opName, const std::string &nodeName, + const int id, const std::vector &inputs = {}, + const std::vector &tArgs = {}, + const std::vector &iArgs = {}); + explicit Node(const std::string &opName, const int id = 0, + const std::vector> &inputs = {}, + const std::vector &tArgs = {}, + const std::vector &iArgs = {}); + explicit Node(sd::ops::DeclarableOp *customOp, int id = 0, + std::initializer_list input = {}, + std::initializer_list output = {}, + std::initializer_list dimensions = {}, float scalar = 0.0f, + std::initializer_list tArgs = {}, + std::initializer_list iArgs = {}); + explicit Node(std::shared_ptr customOp, int id = 0, + std::initializer_list input = {}, + std::initializer_list output = {}, + std::initializer_list dimensions = {}, float scalar = 0.0f, + std::initializer_list tArgs = {}, + std::initializer_list iArgs = {}); + explicit Node(OpType opType = OpType_TRANSFORM_SAME, int opNum = 0, + int id = 0, std::initializer_list input = {}, + std::initializer_list output = {}, + std::initializer_list dimensions = {}, float scalar = 0.0f, + std::initializer_list tArgs = {}, + std::initializer_list iArgs = {}); + + Node(const Node &other) noexcept; + + Node &operator=(const Node &other) noexcept; + + // move constructor + Node(Node &&other) noexcept; + + // move assignment operator + Node &operator=(Node &&other) noexcept; + + bool equals(const Node *other) const; + bool equals(const Node &other) const; + + OpType opType() const { return _opType; }; + OpClass opClass() const { return _opClass;}; + + int id() const; + const std::string &name() const; + + void setName(const std::string &name); + + Nd4jLong opNum() const; + + const std::vector> &inputs() const; + const std::vector> &outputs() const; + const std::vector> &dependencies() const; + + void setId(int id); + + bool isMultiInput(); + bool isMultiOutput(); + + bool isDivergencePoint(); + + bool hasExternalOutputs() const; + bool hasExternalInputs() const; + bool hasInternalOutputs() const; + bool hasInternalInputs() const; + + void pickOutputOnce(int outputId); + void pickOutput(int outputId); + void pickOutput(int nodeId, int outputId); + + void pickExternalOutput(int outputId); + + void pickInput(int inputId); + void pickInput(int nodeId, int outputId); + void pickInput(const std::pair &id); + void pickInput(const std::string &id); + + const ContextPrototype &contextPrototype() const; + void setContextPrototype(const ContextPrototype &block); + + void setCustomOp(const std::shared_ptr &customOp); + std::shared_ptr customOp() const; + + bool hasCustomOp() const; + +#ifndef __JAVACPP_HACK__ + // this method converts string deps to int deps + void actualizeDependencies(const MAP_IMPL &lookupTable) const; +#endif + + int frameId() const; + void setFrameId(int frameId); + + int exitId() const; + void setExitId(int exitId) const; + + // utility method that generates legacy ops out of OpType and OpNum + static std::shared_ptr buildOpByType(OpType opType, int numInputs, int numIArgs, int numTArgs, int opNum); +}; +} // namespace graph +} // namespace sd -#endif //LIBND4J_GNODE_H +#endif // SD_GNODE_H diff --git a/libnd4j/include/graph/NodeState.h b/libnd4j/include/graph/NodeState.h index 5e0a7a6d2dcf..199bdcb6aa9f 100644 --- a/libnd4j/include/graph/NodeState.h +++ b/libnd4j/include/graph/NodeState.h @@ -21,49 +21,49 @@ #ifndef LIBND4J_NODESTATE_H #define LIBND4J_NODESTATE_H -#include #include +#include namespace sd { - namespace graph { - class ND4J_EXPORT NodeState { - private: - // inner time spent on specific node - Nd4jLong _inner = 0; +namespace graph { +class SD_EXPORT NodeState { + private: + // inner time spent on specific node + Nd4jLong _inner = 0; + + // outer time spent on specific node + Nd4jLong _outer = 0; - // outer time spent on specific node - Nd4jLong _outer = 0; - - // flag that shows if node is active or disabled (i.e. after Switch op) - bool _active = true; + // flag that shows if node is active or disabled (i.e. after Switch op) + bool _active = true; - bool _executed = false; + bool _executed = false; - // active divergence branch - int _branch = 0; + // active divergence branch + int _branch = 0; - int _id = 0; - public: - NodeState(int id = 0); - ~NodeState() = default; + int _id = 0; - void setInnerTime(Nd4jLong time); - void setOuterTime(Nd4jLong time); + public: + NodeState(int id = 0); + ~NodeState() = default; - Nd4jLong innerTime(); - Nd4jLong outerTime(); + void setInnerTime(Nd4jLong time); + void setOuterTime(Nd4jLong time); - void markActive(bool isActive); - bool isActive(); + Nd4jLong innerTime(); + Nd4jLong outerTime(); - int branch(); - void markBranch(int index); + void markActive(bool isActive); + bool isActive(); - bool wasExecuted(); - void markExecuted(bool wasExecuted); - }; - } -} + int branch(); + void markBranch(int index); + bool wasExecuted(); + void markExecuted(bool wasExecuted); +}; +} // namespace graph +} // namespace sd -#endif //LIBND4J_NODESTATE_H +#endif // LIBND4J_NODESTATE_H diff --git a/libnd4j/include/graph/OptimizedGraph.h b/libnd4j/include/graph/OptimizedGraph.h new file mode 100644 index 000000000000..0c553f967988 --- /dev/null +++ b/libnd4j/include/graph/OptimizedGraph.h @@ -0,0 +1,271 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#ifndef SD_OPTIMIZEDGRAPH_H +#define SD_OPTIMIZEDGRAPH_H + +#include +#include +#include + +#include + +namespace sd { +namespace graph { + + +class SD_EXPORT OptimizedGraph { + private: + std::vector _sortedGraph; + MAP_IMPL _nodesMap; + + void sortUsualGraph(const VariableSpace& varSpace); + void sortGraphWithFrames(const VariableSpace& varSpace); + + /** + * This method removes all empty ExecutionLayers from this graph + */ + void purgeEmptyLayers(); + public: + OptimizedGraph(const MAP_IMPL& map, const VariableSpace& varSpace); + // move constructor + OptimizedGraph(OptimizedGraph&& other) noexcept; + // default constructor + OptimizedGraph() = default; + + /** + * returns number of nodes in this graph instance + * @return + */ + size_t size() const; + + /** + * returns OpSequences stored in a given layer + * @param index + * @return + */ + const ExecutionLayer& layer(const uint64_t& index) const { return _sortedGraph.at(index); } + + + /** + * returns number of layers within OptimizedGraph + * @return + */ + uint64_t numOfLayers() const { return _sortedGraph.size(); } + + + // move assignment operator + OptimizedGraph& operator=(OptimizedGraph&& other) noexcept; + + /** + * prints out graph content + */ + void printOut() const; + + /** + * Returns number of layers within OptimizedGraph + * @return + */ + uint64_t layers() const { return _sortedGraph.size(); } + + /** + * This method adds given OpSequence to execution queue + * @param sequence + */ + void append(const OpSequence &sequence); + + /** + * returns reference on _nodesMap + * @return + */ + const MAP_IMPL& nodesMap() const { return _nodesMap; } + + int nodeLayer(int nodeId) const; + int nodeSequence(int nodeId) const; + int nodeIndex(int nodeId) const; +}; + + +// class Graph; +// class NodeInfo; +// /** +// * This class acts as a topologically sorted & optimized Graph representation, +// * ready for execution +// */ +// class SD_EXPORT OptimizedGraph { +// protected: +// // here we store independent OpSequences +// // Graph starts from layer 0, and goes deeper step by step +// // on each layer we can have 1+ OpSequences that can be executed independent +// std::map _onion; + +// GraphMemoryManager* _memoryManager = nullptr; +// Graph* _originalGraph = nullptr; + +// mutable std::mutex _mutex; + +// mutable size_t _size = 0; + +// public: +// OptimizedGraph(Graph* original); +// OptimizedGraph() = default; +// ~OptimizedGraph() = default; + +// OptimizedGraph(const OptimizedGraph& other) noexcept; + +// OptimizedGraph& operator=(const OptimizedGraph& other) noexcept; + +// // move constructor +// OptimizedGraph(OptimizedGraph&& other) noexcept; + +// // move assignment operator +// OptimizedGraph& operator=(OptimizedGraph&& other) noexcept; + + + + +// /** +// * This method allows to append layer to this OptimizedGraph instance +// */ +// // FIXME: this method should be removed or made private +// void append(const std::vector& layer); +// void append(const ExecutionLayer& layer); +// void append(OpSequence& sequence); + +// /** +// * This method returns GraphMemoryManager instance that manages this Graph +// * @return +// */ +// const GraphMemoryManager& memoryManager() const; + +// /** +// * This method returns pointer to original Graph +// * @return +// */ +// const Graph& originalGraph() const; + +// /** +// * This method returns number of nodes in this graph instance +// * @return +// */ +// size_t size() const; + + +// protected: +// /* +// * optimize original graph +// */ +// void createOptimizedGraph(); +// /* +// * Topological graph analysis +// * @param const start node for search +// * @param const reference for nodes infor container +// * @param operation gather +// * @return stop iterating +// */ +// bool topolSearch(const int startNode, +// std::unordered_map& nodesConnections, +// std::vector>& opSeq) const; + +// * Optimized graph analysis prototyping, gather nodes infor +// * @param reference to node information collector +// * @param reference to start nodes +// * @param reference to input branching nodes (input branching node - atleast 2 +// * internal inputs) +// * @return stop iterating + +// bool opGraphProto(std::unordered_map& collector, +// std::set& startNodes, +// std::set& inBranchingNodes) const; +// /* +// * Define layers and sequence positions based on nodes infor +// * @param reference to node information collector +// * @param node ID +// * @param layer ID +// * @param sequence ID +// * @param map of layers and max sequence +// * @return stop iterating +// */ +// bool layersSeqDefine(std::unordered_map& collection, int ID, +// int layer, int nStartSeq, +// std::unordered_map& layersMaxSeq) const; +// /* +// * Initialize container with operations and context +// * @param const reference to layers and sequence collection +// * @param reference to opSequence collector +// * @return stop iterating +// */ +// bool initOpSeqContainer(const std::unordered_map& layersMaxSeq, +// std::vector>& vOpSeq) const; +// }; + +// class NodeInfo { +// private: +// std::set sConnections; + +// bool bInBranching; +// bool bOutBranching; +// bool bProcessed; + +// int nLayer; +// int nSequence; + +// sd::graph::OpType opType; + +// public: +// NodeInfo() { reset(); } +// ~NodeInfo() { reset(); } + +// void setInBranching(bool bValue) { bInBranching = bValue; } +// void setOutBranching(bool bValue) { bOutBranching = bValue; } +// void setProcessed(bool bValue = true) { bProcessed = bValue; } + +// void reset() { +// sConnections.clear(); +// bProcessed = bInBranching = bOutBranching = false; +// nLayer = 0; +// nSequence = -1; +// opType = OpType_CUSTOM; +// } + +// int layer() const { return nLayer; } +// void setLayer(int layer) { nLayer = layer; } + +// int sequence() const { return nSequence; } +// void setSequence(int sequence) { nSequence = sequence; } + +// void addConnection(int id) { sConnections.emplace(id); } +// const std::set& connections() const { return sConnections; } + +// void setType(sd::graph::OpType value) { opType = value; } +// sd::graph::OpType type() const { return opType; } +// bool isLogic() { return opType == OpType_LOGIC; } + +// bool isInBranching() const { return bInBranching; } +// bool isOutBranching() const { return bOutBranching; } +// bool isProcessed() const { return bProcessed; } +// }; + + + +} // namespace graph +} // namespace sd + +#endif // SD_OPTIMIZEDGRAPH_H diff --git a/libnd4j/include/graph/RandomGenerator.h b/libnd4j/include/graph/RandomGenerator.h index 407993a09ab0..68cdee7fabcc 100644 --- a/libnd4j/include/graph/RandomGenerator.h +++ b/libnd4j/include/graph/RandomGenerator.h @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -29,6 +30,12 @@ #include #include #include +#include +#include +#include +#include + +#include #include #include @@ -38,143 +45,142 @@ #endif namespace sd { - namespace graph { +namespace graph { #ifdef __CUDACC__ - class ND4J_EXPORT CudaManagedRandomGenerator { - private: - - protected: - void *devHolder; - - public: - void *operator new(size_t len) { - void *ptr; - auto res = cudaHostAlloc(&ptr, len, cudaHostAllocDefault); - if (res != 0) - throw std::runtime_error("CudaManagedRandomGenerator: failed to allocate memory"); - - return ptr; - } - - void operator delete(void *ptr) { - cudaFreeHost(ptr); - } - }; - - class ND4J_EXPORT RandomGenerator : public CudaManagedRandomGenerator { +class SD_EXPORT CudaManagedRandomGenerator { + private: + protected: + void* devHolder; + + public: + void* operator new(size_t len) { + void* ptr; + auto res = cudaHostAlloc(&ptr, len, cudaHostAllocDefault); + if (res != 0) + throw std::runtime_error( + "CudaManagedRandomGenerator: failed to allocate memory"); + + return ptr; + } + + void operator delete(void* ptr) { cudaFreeHost(ptr); } +}; + +class SD_EXPORT RandomGenerator : public CudaManagedRandomGenerator { #else - class ND4J_EXPORT RandomGenerator { +class SD_EXPORT RandomGenerator { #endif - private: + private: #ifndef __CUDACC__ - void *placeHolder; + void* placeHolder; #endif - // GRAPH-LEVEL STATE - u64 _rootState; + // GRAPH-LEVEL STATE + u64 _rootState; - // NODE-LEVEL STATE - u64 _nodeState; + // NODE-LEVEL STATE + u64 _nodeState; - /** - * Utility method, returns number of milliseconds since 1970 - * Leave this static if possible to avoid problems in constructor - */ - static FORCEINLINE Nd4jLong currentMilliseconds(); + /** + * Utility method, returns number of milliseconds since 1970 + * Leave this static if possible to avoid problems in constructor + */ + static FORCEINLINE Nd4jLong currentMilliseconds(); - public: +public: FORCEINLINE _CUDA_HD uint32_t xoroshiro32(uint64_t index); - FORCEINLINE _CUDA_HD uint64_t xoroshiro64(uint64_t index); - - /** - * This method returns integer value between 0 and MAX_UINT - */ - //uint32_t relativeUInt32(Nd4jLong index); - - public: - FORCEINLINE RandomGenerator(Nd4jLong rootSeed = 0, Nd4jLong nodeSeed = 0); - - /** - * This method allows to change graph-level state in runtime. - * PLEASE NOTE: this method will change state of node as well. - */ - FORCEINLINE _CUDA_H void setStates(Nd4jLong rootSeed, Nd4jLong nodeState = 0); - - - - /** - * This method returns T value between from and to - */ - template - FORCEINLINE _CUDA_HD T relativeT(Nd4jLong index, T from, T to); - - /** - * This method returns T value between 0 and MAX_T - */ - template - FORCEINLINE _CUDA_HD T relativeT(Nd4jLong index); - - /** - * These two methods are made for JVM - * @param index - * @return - */ - FORCEINLINE _CUDA_HD int relativeInt(Nd4jLong index); - FORCEINLINE _CUDA_HD Nd4jLong relativeLong(Nd4jLong index); - - FORCEINLINE _CUDA_HD void rewindH(uint64_t steps); - - /** - * These methods set up only node states, with non-changed root ones - */ - FORCEINLINE _CUDA_H void setSeed(int seed) { - _nodeState._ulong = static_cast(seed); - } - - FORCEINLINE _CUDA_H void setSeed(uint64_t seed) { - _nodeState._ulong = seed; - } - - FORCEINLINE _CUDA_HD Nd4jLong rootState() { - return _rootState._long; - } - - FORCEINLINE _CUDA_HD Nd4jLong nodeState() { - return _nodeState._long; - } - }; - - - FORCEINLINE RandomGenerator::RandomGenerator(Nd4jLong rootSeed, Nd4jLong nodeSeed) { - // this seed is used graph-level state - if (rootSeed == 0) - rootSeed = currentMilliseconds(); - - // graph-level state is just first seed - _rootState._long = rootSeed; - - // used to build second, node state - _nodeState._long = (nodeSeed != 0 ? nodeSeed: 1298567341LL); - } + FORCEINLINE _CUDA_HD uint64_t xoroshiro64(uint64_t index); - FORCEINLINE void RandomGenerator::setStates(Nd4jLong rootSeed, Nd4jLong nodeSeed) { - // this seed is used graph-level state - if (rootSeed == 0) - rootSeed = currentMilliseconds(); + /** + * This method returns integer value between 0 and MAX_UINT + */ + // uint32_t relativeUInt32(Nd4jLong index); - // graph-level state is just first seed - _rootState._long = rootSeed; + public: + FORCEINLINE RandomGenerator(Nd4jLong rootSeed = 0, Nd4jLong nodeSeed = 0); - // used to build second, node state - _nodeState._long = (nodeSeed != 0 ? nodeSeed: 1298567341LL); - } + RandomGenerator(const RandomGenerator& other) noexcept; - FORCEINLINE Nd4jLong RandomGenerator::currentMilliseconds() { - auto s = std::chrono::system_clock::now().time_since_epoch(); - auto v = std::chrono::duration_cast(s).count(); - return v; - } + RandomGenerator& operator=(const RandomGenerator& other) noexcept; - template <> + // move constructor + RandomGenerator(RandomGenerator&& other) noexcept; + + // move assignment operator + RandomGenerator& operator=(RandomGenerator&& other) noexcept; + + /** + * This method allows to change graph-level state in runtime. + * PLEASE NOTE: this method will change state of node as well. + */ + FORCEINLINE _CUDA_H void setStates(Nd4jLong rootSeed, Nd4jLong nodeState = 0); + + /** + * This method returns T value between from and to + */ + template + FORCEINLINE _CUDA_HD T relativeT(Nd4jLong index, T from, T to); + + /** + * This method returns T value between 0 and MAX_T + */ + template + FORCEINLINE _CUDA_HD T relativeT(Nd4jLong index); + + /** + * These two methods are made for JVM + * @param index + * @return + */ + FORCEINLINE _CUDA_HD int relativeInt(Nd4jLong index); + FORCEINLINE _CUDA_HD Nd4jLong relativeLong(Nd4jLong index); + + FORCEINLINE _CUDA_HD void rewindH(uint64_t steps); + + /** + * These methods set up only node states, with non-changed root ones + */ + FORCEINLINE _CUDA_H void setSeed(int seed) { + _nodeState._ulong = static_cast(seed); + } + + FORCEINLINE _CUDA_H void setSeed(uint64_t seed) { _nodeState._ulong = seed; } + + FORCEINLINE _CUDA_HD Nd4jLong rootState() { return _rootState._long; } + + FORCEINLINE _CUDA_HD Nd4jLong nodeState() { return _nodeState._long; } +}; + +FORCEINLINE RandomGenerator::RandomGenerator(Nd4jLong rootSeed, + Nd4jLong nodeSeed) { + // this seed is used graph-level state + if (rootSeed == 0) rootSeed = currentMilliseconds(); + + // graph-level state is just first seed + _rootState._long = rootSeed; + + // used to build second, node state + _nodeState._long = (nodeSeed != 0 ? nodeSeed : 1298567341LL); +} + +FORCEINLINE void RandomGenerator::setStates(Nd4jLong rootSeed, + Nd4jLong nodeSeed) { + // this seed is used graph-level state + if (rootSeed == 0) rootSeed = currentMilliseconds(); + + // graph-level state is just first seed + _rootState._long = rootSeed; + + // used to build second, node state + _nodeState._long = (nodeSeed != 0 ? nodeSeed : 1298567341LL); +} + +FORCEINLINE Nd4jLong RandomGenerator::currentMilliseconds() { + auto s = std::chrono::system_clock::now().time_since_epoch(); + auto v = std::chrono::duration_cast(s).count(); + return v; +} + +template <> _CUDA_HD FORCEINLINE float RandomGenerator::relativeT(Nd4jLong index) { u32 u; u._u32 = (0x3f800000 | (this->xoroshiro32(index) >> 9)); @@ -190,115 +196,123 @@ namespace sd { #else return (double) relativeT(index); #endif - } + }template <> +_CUDA_HD FORCEINLINE uint64_t +RandomGenerator::relativeT(Nd4jLong index) { + return this->xoroshiro64(index); +} - template <> - _CUDA_HD FORCEINLINE uint64_t RandomGenerator::relativeT(Nd4jLong index) { - return this->xoroshiro64(index); - } +template <> +_CUDA_HD FORCEINLINE uint32_t +RandomGenerator::relativeT(Nd4jLong index) { + return this->xoroshiro32(index); +} - template <> - _CUDA_HD FORCEINLINE uint32_t RandomGenerator::relativeT(Nd4jLong index) { - return this->xoroshiro32(index); - } +template <> +_CUDA_HD FORCEINLINE int RandomGenerator::relativeT(Nd4jLong index) { + auto r =relativeT(index); + return r <= DataTypeUtils::max() ? r : r % DataTypeUtils::max(); - template <> - _CUDA_HD FORCEINLINE int RandomGenerator::relativeT(Nd4jLong index) { - auto r = relativeT(index); - return r <= DataTypeUtils::max() ? r : r % DataTypeUtils::max(); - } +} - template <> - _CUDA_HD FORCEINLINE Nd4jLong RandomGenerator::relativeT(Nd4jLong index) { - auto r = relativeT(index); - return r <= DataTypeUtils::max() ? r : r % DataTypeUtils::max(); - } +template <> +_CUDA_HD FORCEINLINE Nd4jLong +RandomGenerator::relativeT(Nd4jLong index) { + auto r =relativeT(index); + return r <= DataTypeUtils::max() ? r : r % DataTypeUtils::max(); - template - _CUDA_HD FORCEINLINE T RandomGenerator::relativeT(Nd4jLong index, T from, T to) { - auto t = this->relativeT(index); - auto z = from + T(t * (to - from)); - return z; - } +} - template <> - _CUDA_HD FORCEINLINE Nd4jLong RandomGenerator::relativeT(Nd4jLong index, Nd4jLong from, Nd4jLong to) { - auto t = this->relativeT(index); - auto z = from + Nd4jLong(t * (to - from)); - return z; - } +template +_CUDA_HD FORCEINLINE T RandomGenerator::relativeT(Nd4jLong index, T from, + T to) { + auto t = this->relativeT(index); + auto z = from + T(t * (to - from)); + return z; +} - template <> - _CUDA_HD FORCEINLINE int RandomGenerator::relativeT(Nd4jLong index, int from, int to) { - auto t = this->relativeT(index); - auto z = from + float(t * (to - from)); - return z; - } +template <> +_CUDA_HD FORCEINLINE Nd4jLong RandomGenerator::relativeT(Nd4jLong index, + Nd4jLong from, + Nd4jLong to) { + auto t = this->relativeT(index); + auto z = from + Nd4jLong(t * (to - from)); + return z; +} - template - _CUDA_HD FORCEINLINE T RandomGenerator::relativeT(Nd4jLong index) { - // This is default implementation for floating point types - return static_cast(relativeT(index)); - } +template <> +_CUDA_HD FORCEINLINE int RandomGenerator::relativeT(Nd4jLong index, int from, + int to) { + auto t = this->relativeT(index); + auto z = from + float(t * (to - from)); + return z; +} +template +_CUDA_HD FORCEINLINE T RandomGenerator::relativeT(Nd4jLong index) { + // This is default implementation for floating point types + + return static_cast(relativeT(index)); +} - _CUDA_HD FORCEINLINE int RandomGenerator::relativeInt(Nd4jLong index) { - auto r = relativeT(index); +_CUDA_HD FORCEINLINE int RandomGenerator::relativeInt(Nd4jLong index) { + auto r = relativeT(index); return r <= DataTypeUtils::max() ? r : r % DataTypeUtils::max(); - } +} - _CUDA_HD FORCEINLINE Nd4jLong RandomGenerator::relativeLong(Nd4jLong index) { - auto r = relativeT(index); +_CUDA_HD FORCEINLINE Nd4jLong RandomGenerator::relativeLong(Nd4jLong index) { + auto r = relativeT(index); return r <= DataTypeUtils::max() ? r : r % DataTypeUtils::max(); - } +} - ////// - static FORCEINLINE _CUDA_HD uint32_t rotl(const uint32_t x, int k) { - return (x << k) | (x >> (32 - k)); - } +////// +static FORCEINLINE _CUDA_HD uint32_t rotl(const uint32_t x, int k) { + return (x << k) | (x >> (32 - k)); +} - static FORCEINLINE _CUDA_HD uint64_t rotl(const uint64_t x, int k) { - return (x << k) | (x >> (64 - k)); - } +static FORCEINLINE _CUDA_HD uint64_t rotl(const uint64_t x, int k) { + return (x << k) | (x >> (64 - k)); +} - static FORCEINLINE _CUDA_HD uint32_t next(uint32_t s0, uint32_t s1, uint32_t s2, uint32_t s3) { +static FORCEINLINE _CUDA_HD uint32_t next(uint32_t s0, uint32_t s1, uint32_t s2, uint32_t s3) { const uint32_t result = rotl(s0 + s3, 7) + s0; return result; } _CUDA_HD FORCEINLINE uint32_t RandomGenerator::xoroshiro32(uint64_t index) { - auto s0 = _rootState._ulong; - auto s1 = _nodeState._ulong; + auto s0 = _rootState._ulong; + auto s1 = _nodeState._ulong; - // xor by idx - s0 |= ((index + 2) * (s1 + 24243287)); - s1 ^= ((index + 2) * (s0 + 723829)); + // xor by idx + s0 |= ((index + 2) * (s1 + 24243287)); + s1 ^= ((index + 2) * (s0 + 723829)); - unsigned long val = 0; - val = s1 ^ s0; - int* pHalf = reinterpret_cast(&val); + unsigned long val = 0; + val = s1 ^ s0; + int* pHalf = reinterpret_cast(&val); - return rotl(*pHalf * 0x9E3779BB, 5) * 5; - } + return rotl(*pHalf * 0x9E3779BB, 5) * 5; +} - _CUDA_HD FORCEINLINE uint64_t RandomGenerator::xoroshiro64(uint64_t index) { - uint64_t upper = ((uint64_t) xoroshiro32(index)) << 32; - uint32_t lower = xoroshiro32(sd::math::nd4j_rotl(index, 32)); - return upper + lower; - } +_CUDA_HD FORCEINLINE uint64_t RandomGenerator::xoroshiro64(uint64_t index) { + uint64_t upper = ((uint64_t) xoroshiro32(index)) << 32; + uint32_t lower = xoroshiro32(sd::math::nd4j_rotl(index, 32)); + + return upper + lower; +} - _CUDA_HD FORCEINLINE void RandomGenerator::rewindH(uint64_t steps) { +_CUDA_HD FORCEINLINE void RandomGenerator::rewindH(uint64_t steps) { // we only update node state, if any - auto s0 = _nodeState._du32._v0; - auto s1 = _nodeState._du32._v1; + auto s0 = _nodeState._du32._v0; + auto s1 = _nodeState._du32._v1; - s1 ^= s0; - _nodeState._du32._v0 = rotl(s0, 26) ^ s1 ^ (s1 << 9); // a, b - _nodeState._du32._v1 = rotl(s1, 13); // c + s1 ^= s0; + _nodeState._du32._v0 = rotl(s0, 26) ^ s1 ^ (s1 << 9); // a, b + _nodeState._du32._v1 = rotl(s1, 13); // c - _nodeState._long ^= (steps ^ 0xdeadbeef); - } - } + _nodeState._long ^= (steps ^ 0xdeadbeef); } +} // namespace graph +} // namespace sd #endif diff --git a/libnd4j/include/graph/RandomGenerator.hpp b/libnd4j/include/graph/RandomGenerator.hpp index fbbc8bad1749..5ddd3e0cd19b 100644 --- a/libnd4j/include/graph/RandomGenerator.hpp +++ b/libnd4j/include/graph/RandomGenerator.hpp @@ -19,26 +19,27 @@ // // relies on xoroshiro64** and xoroshiro128 implementations +#include +#include +#include #include #include -#include + #include -#include -#include namespace sd { - namespace graph { - - +namespace graph { - template _CUDA_HD int RandomGenerator::relativeT(Nd4jLong, int, int); - template _CUDA_HD float16 RandomGenerator::relativeT(Nd4jLong, float16, float16); - template _CUDA_HD float RandomGenerator::relativeT(Nd4jLong, float, float); - template _CUDA_HD double RandomGenerator::relativeT(Nd4jLong, double, double); - template _CUDA_HD Nd4jLong RandomGenerator::relativeT(Nd4jLong, Nd4jLong, Nd4jLong); +template _CUDA_HD int RandomGenerator::relativeT(Nd4jLong, int, int); +template _CUDA_HD float16 RandomGenerator::relativeT(Nd4jLong, float16, + float16); +template _CUDA_HD float RandomGenerator::relativeT(Nd4jLong, float, float); +template _CUDA_HD double RandomGenerator::relativeT(Nd4jLong, double, double); +template _CUDA_HD Nd4jLong RandomGenerator::relativeT(Nd4jLong, Nd4jLong, + Nd4jLong); - template _CUDA_HD float16 RandomGenerator::relativeT(Nd4jLong); - template _CUDA_HD float RandomGenerator::relativeT(Nd4jLong); - template _CUDA_HD double RandomGenerator::relativeT(Nd4jLong); - } -} \ No newline at end of file +template _CUDA_HD float16 RandomGenerator::relativeT(Nd4jLong); +template _CUDA_HD float RandomGenerator::relativeT(Nd4jLong); +template _CUDA_HD double RandomGenerator::relativeT(Nd4jLong); +} // namespace graph +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/graph/ResultWrapper.h b/libnd4j/include/graph/ResultWrapper.h index fe5193097818..9ee2bcb6405f 100644 --- a/libnd4j/include/graph/ResultWrapper.h +++ b/libnd4j/include/graph/ResultWrapper.h @@ -21,27 +21,26 @@ #ifndef LIBND4J_RESULTWRAPPER_H #define LIBND4J_RESULTWRAPPER_H +#include #include #include -#include namespace sd { - namespace graph { - class ND4J_EXPORT ResultWrapper { - private: - Nd4jLong _size = 0L; - Nd4jPointer _pointer = nullptr; - - public: - ResultWrapper(Nd4jLong size, Nd4jPointer ptr); - ~ResultWrapper(); +namespace graph { +class SD_EXPORT ResultWrapper { + private: + Nd4jLong _size = 0L; + Nd4jPointer _pointer = nullptr; - Nd4jLong size(); + public: + ResultWrapper(Nd4jLong size, Nd4jPointer ptr); + ~ResultWrapper(); - Nd4jPointer pointer(); - }; - } -} + Nd4jLong size(); + Nd4jPointer pointer(); +}; +} // namespace graph +} // namespace sd -#endif //LIBND4J_RESULTWRAPPER_H +#endif // LIBND4J_RESULTWRAPPER_H diff --git a/libnd4j/include/graph/Scope.h b/libnd4j/include/graph/Scope.h deleted file mode 100644 index 42b99c18e9bc..000000000000 --- a/libnd4j/include/graph/Scope.h +++ /dev/null @@ -1,106 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -// -// Created by raver119 on 14.10.2017. -// - -#ifndef LIBND4J_SCOPE_H -#define LIBND4J_SCOPE_H - -#include -#include -#include - -namespace sd { - namespace graph { - - /** - * Scope holds sequential list of operations, and made suitable for continuous - * re-execution of multiple operations. - * - * @tparam T - */ - class ND4J_EXPORT Scope { - protected: - // Graph-unique IDs for Scope instances - int _id; - std::string _name; - - // list of nodes to run, always sequential - // Graph takes care of topo sort - std::vector _nodes; - public: - // attach GiG here, with shared namespace? - // or just rebuilt graph leaf? - // ¯\_(ツ)_/¯ - - // default consructor - explicit Scope(int id, const char* name = nullptr); - - // default destructor - ~Scope(); - - /** - * this method adds op node to the scope - * - * PLEASE NOTE: We assume that ops are being added ORDERED - */ - void push_back(Node* node); - - /** - * This method returns list of ops stored earlier, ready for execution - * - * PLEASE NOTE: If the scope is conditional - last op in list should be BooleanOp - * @return - */ - std::vector * nodes(); - - /** - * This function returns number of nodes in this scope - * - * @return - */ - int size(); - - /** - * Returns ID of this scope - * @return - */ - int id(); - - /** - * Returns name of this scope - * - * @return - */ - std::string* name(); - - /** - * This method returns clone of this Scope - */ - Scope* clone(); - - /** - * This method removes all Nodes from this scope - */ - void forgetNodes(); - }; - } -} - - -#endif //LIBND4J_SCOPE_H diff --git a/libnd4j/include/graph/SessionLocalStorage.h b/libnd4j/include/graph/SessionLocalStorage.h deleted file mode 100644 index 3cb77ec3a5f0..000000000000 --- a/libnd4j/include/graph/SessionLocalStorage.h +++ /dev/null @@ -1,65 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -// -// @author raver119@gmail.com -// - -#ifndef LIBND4J_SESSIONLOCALSTORAGE_H -#define LIBND4J_SESSIONLOCALSTORAGE_H - -#include -#include -#include -#include "VariableSpace.h" -#include "Context.h" -#include "Stash.h" -#include - -namespace sd{ - namespace graph { - class ND4J_EXPORT SessionLocalStorage { - protected: - std::atomic _sessionCounter; - MAP_IMPL _threadSession; - MAP_IMPL _threadVariableSpace; - - VariableSpace* _variableSpace; - Stash* _stash; - - std::mutex _mutex; - - Nd4jLong getSessionId(); - Nd4jLong getThreadId(); - public: - SessionLocalStorage(VariableSpace* variableSpace = nullptr, Stash* stash = nullptr); - - ~SessionLocalStorage(); - - VariableSpace* localVariableSpace(); - VariableSpace* localVariableSpace(Nd4jLong sessionId); - - - Nd4jLong startSession(); - void endSession(Nd4jLong sessionId); - void endSession(); - - int numberOfSessions(); - }; - } -} - -#endif //LIBND4J_SESSIONLOCALSTORAGE_H diff --git a/libnd4j/include/graph/Stash.h b/libnd4j/include/graph/Stash.h index ba431d05756e..d5d375d8d351 100644 --- a/libnd4j/include/graph/Stash.h +++ b/libnd4j/include/graph/Stash.h @@ -23,72 +23,70 @@ //#include #include -#include -#include -#include +#include + #include #include -#include +#include +#include +#include namespace sd { - namespace graph { - class ND4J_EXPORT KeyPair { - int _node; - std::string _name; - public: - KeyPair(int node = 0, const char *name = nullptr); +namespace graph { +class SD_EXPORT KeyPair { + int _node; + std::string _name; + + public: + KeyPair(int node = 0, const char *name = nullptr); - bool operator<(const KeyPair &other) const; + bool operator<(const KeyPair &other) const; - bool operator==(const KeyPair &other) const { - return _node == other._node; - } + bool operator==(const KeyPair &other) const { return _node == other._node; } - int key() const { return _node; } - std::string name() const { return _name; } - }; - } -} + int key() const { return _node; } + std::string name() const { return _name; } +}; +} // namespace graph +} // namespace sd #ifndef __JAVACPP_HACK__ namespace std { - template <> - class ND4J_EXPORT hash { - public: - size_t operator()(const sd::graph::KeyPair& k) const; - }; +template <> +class SD_EXPORT hash { + public: + size_t operator()(const sd::graph::KeyPair &k) const; }; +}; // namespace std #endif namespace sd { - namespace graph { - class ND4J_EXPORT Stash { - protected: - std::map _stash; - std::vector _handles; - - public: - Stash(); - ~Stash(); +namespace graph { +class SD_EXPORT Stash { + protected: + std::map _stash; + std::vector _handles; - //void storeArray(sd::graph::Block& block, const char *name, sd::NDArray *array); - void storeArray(int nodeId, const char *name, sd::NDArray *array); + public: + Stash(); + ~Stash(); - //bool checkStash(sd::graph::Block& block, const char *name); - bool checkStash(int nodeId, const char *name); + // void storeArray(sd::graph::Block& block, const char *name, + // sd::NDArray *array); + void storeArray(int nodeId, const char *name, sd::NDArray *array); - //sd::NDArray* extractArray(sd::graph::Block& block, const char *name); - sd::NDArray* extractArray(int nodeId, const char *name); - - void clear(); - }; - } - -} + // bool checkStash(sd::graph::Block& block, const char *name); + bool checkStash(int nodeId, const char *name); + // sd::NDArray* extractArray(sd::graph::Block& block, const char *name); + sd::NDArray *extractArray(int nodeId, const char *name); + void clear(); +}; +} // namespace graph +} // namespace sd -#endif //LIBND4J_STASH_H +#endif // LIBND4J_STASH_H diff --git a/libnd4j/include/graph/Status.h b/libnd4j/include/graph/Status.h index 42794488dd6b..93a957846c8c 100644 --- a/libnd4j/include/graph/Status.h +++ b/libnd4j/include/graph/Status.h @@ -21,30 +21,28 @@ #ifndef ND4J_STATUS_H #define ND4J_STATUS_H -#include -#include -#include #include +#include +#include +#include namespace sd { - class ND4J_EXPORT Status { - public: - static FORCEINLINE Nd4jStatus OK() { - return ND4J_STATUS_OK; - }; +class SD_EXPORT Status { + public: + static FORCEINLINE Nd4jStatus OK() { return ND4J_STATUS_OK; }; - static FORCEINLINE Nd4jStatus CODE(Nd4jStatus code, const char *message) { - nd4j_printf("%s\n", message); - return code; - } + static FORCEINLINE Nd4jStatus CODE(Nd4jStatus code, const char *message) { + nd4j_printf("%s\n", message); + return code; + } - static FORCEINLINE Nd4jStatus THROW(const char *message = nullptr) { - if (message != nullptr) { - nd4j_printf("%s\n", message); - } - return ND4J_STATUS_KERNEL_FAILURE; - } - }; -} + static FORCEINLINE Nd4jStatus THROW(const char *message = nullptr) { + if (message != nullptr) { + nd4j_printf("%s\n", message); + } + return ND4J_STATUS_KERNEL_FAILURE; + } +}; +} // namespace sd -#endif // STATUS_H \ No newline at end of file +#endif // STATUS_H \ No newline at end of file diff --git a/libnd4j/include/graph/TimeHolder.h b/libnd4j/include/graph/TimeHolder.h index 191a75bace13..9d332226cbad 100644 --- a/libnd4j/include/graph/TimeHolder.h +++ b/libnd4j/include/graph/TimeHolder.h @@ -21,32 +21,29 @@ #ifndef LIBND4J_TIMEHOLDER_H #define LIBND4J_TIMEHOLDER_H -#include -#include #include +#include -namespace sd { - namespace graph { - class ND4J_EXPORT TimeHolder { - private: - std::map _outer; - std::map _inner; - - - public: - - TimeHolder() = default; - ~TimeHolder() = default; - - - void setOuterTime(int nodeId, Nd4jLong time); - void setInnerTime(int nodeId, Nd4jLong time); - - - Nd4jLong outerTime(int nodeId); - Nd4jLong innerTime(int nodeId); - }; - } -} +#include -#endif //LIBND4J_TIMEHOLDER_H +namespace sd { +namespace graph { +class SD_EXPORT TimeHolder { + private: + std::map _outer; + std::map _inner; + + public: + TimeHolder() = default; + ~TimeHolder() = default; + + void setOuterTime(int nodeId, Nd4jLong time); + void setInnerTime(int nodeId, Nd4jLong time); + + Nd4jLong outerTime(int nodeId); + Nd4jLong innerTime(int nodeId); +}; +} // namespace graph +} // namespace sd + +#endif // LIBND4J_TIMEHOLDER_H diff --git a/libnd4j/include/graph/Variable.h b/libnd4j/include/graph/Variable.h index b3ac74533f34..b948dc993c2c 100644 --- a/libnd4j/include/graph/Variable.h +++ b/libnd4j/include/graph/Variable.h @@ -21,131 +21,133 @@ #ifndef LIBND4J_VARIABLE_H #define LIBND4J_VARIABLE_H -#include #include #include #include #include -#include #include +#include + +#include #ifndef __JAVACPP_HACK__ namespace std { - template <> - class ND4J_EXPORT hash> { - public: - size_t operator()(const std::pair& k) const; - }; - - template <> - class ND4J_EXPORT hash { - public: - size_t operator()(const bfloat16& k) const; - }; - - template <> - class ND4J_EXPORT hash { - public: - size_t operator()(const float16& k) const; - }; +template <> +class SD_EXPORT hash> { + public: + size_t operator()(const std::pair &k) const; +}; + +template <> +class SD_EXPORT hash { + public: + size_t operator()(const bfloat16 &k) const; }; +template <> +class SD_EXPORT hash { + public: + size_t operator()(const float16 &k) const; +}; +}; // namespace std + #endif namespace sd { - namespace graph { - class ND4J_EXPORT Variable { - protected: - int _id = 0; - int _index = 0; - sd::NDArray *_ndarray = nullptr; - std::string _name; - - std::vector _shape; - - bool _external = false; - bool _readOnly = false; - bool _placeholder = false; - bool _removable = true; - - // for now we're setting default to numeric - // in future we'll be fetching it right from the array, - //InputType _variableType = InputType_UNDEFINED; - //DataType _dataType = INHERIT; - - sd::NDArrayList *_list = nullptr; - - VariableType _variableType = VariableType::NDARRAY; - - public: - Variable(bool placeHolder); - Variable(sd::NDArray *arrayw, const char *name, int id, int idx = 0); - Variable(sd::NDArray *array = nullptr, const char *name = nullptr); +namespace graph { +class SD_EXPORT Variable { + protected: + int _id = 0; + int _index = 0; + std::string _name; + + std::vector _shape; + DataType _dtype; + + bool _external = false; + bool _readOnly = false; + bool _placeholder = false; + bool _removable = true; + + // actual content + std::shared_ptr _ndarray; + std::shared_ptr _list; + + VariableType _variableType = VariableType::NDARRAY; + + std::vector> _dependencies; + std::vector _stringDependencies; + + public: + explicit Variable(bool placeHolder, DataType dataType = DataType::ANY, const std::vector &shape = {}); + explicit Variable(const sd::NDArray &array, const std::string &name, int id, int idx = 0); + explicit Variable(std::shared_ptr array, const std::string &name, int id, int idx = 0); + explicit Variable(std::shared_ptr array, const std::string &name, int id, int idx = 0); + explicit Variable(std::shared_ptr array, const char *name = nullptr); + explicit Variable(const NDArrayList &arrayList, const std::string &name, int id, int idx = 0); + explicit Variable(); #ifndef __JAVACPP_HACK__ - Variable(const sd::graph::FlatVariable *flatVariable); + explicit Variable(const sd::graph::FlatVariable *flatVariable); #endif - ~Variable(); - - Variable* clone(); - - template - ND4J_EXPORT Variable* asT(); + ~Variable(); - bool hasNDArray(); - sd::NDArray* getNDArray(); - void setNDArray(sd::NDArray *array); + bool hasNDArray() const; + std::shared_ptr getNDArray() const; + void setNDArray(std::shared_ptr array); - bool hasNDArrayList(); - sd::NDArrayList* getNDArrayList(); - void setNDArrayList(sd::NDArrayList* list); + bool hasNDArrayList() const; + std::shared_ptr getNDArrayList() const; + void setNDArrayList(std::shared_ptr list); - bool isExternal(); - bool isReadOnly(); - bool isEmpty(); - bool isRemovable(); + bool isExternal() const; + bool isReadOnly() const; + bool isEmpty() const; + bool isRemovable() const; - bool isPlaceholder(); + bool isPlaceholder() const; - VariableType variableType(); - void setVariableType(VariableType variableType); + VariableType variableType() const; + void setVariableType(VariableType variableType); - /** - * This method returns InputType of this variable - */ - //InputType variableType() { - // return _variableType; - //} + void markExternal(bool reallyExternal); + void markReadOnly(bool reallyReadOnly); + void markRemovable(bool reallyRemovable); - void markExternal(bool reallyExternal); - void markReadOnly(bool reallyReadOnly); - void markRemovable(bool reallyRemovable); + int id() const; + int index() const; + void setIndex(int index); + void setId(int id); + void setId(int id, int idx); - int id(); - int index(); - void setIndex(int index); - void setId(int id); - void setId(int id, int idx); + const std::string &name() const; + const std::string &getName() const; + void setName(const std::string &name); - std::string *getName(); - void setName(std::string *name); + const std::vector &shape() const; + DataType dataType() const; - std::vector& shape(); + const std::vector>& dependencies() const; #ifndef __JAVACPP_HACK__ - /** - * This method returns offset to this Variable in FlatBuffer - * @param builder - * @return - */ - flatbuffers::Offset asFlatVariable(flatbuffers::FlatBufferBuilder &builder); + // this method converts string deps to int deps + void actualizeDependencies(const MAP_IMPL &lookupTable) const; #endif - }; - } -} +#ifndef __JAVACPP_HACK__ + /** + * This method returns offset to this Variable in FlatBuffer + * @param builder + * @return + */ + flatbuffers::Offset asFlatVariable( + flatbuffers::FlatBufferBuilder &builder); +#endif +}; +} // namespace graph +} // namespace sd -#endif //LIBND4J_VARIABLE_H +#endif // LIBND4J_VARIABLE_H diff --git a/libnd4j/include/graph/VariableProxy.h b/libnd4j/include/graph/VariableProxy.h index 1569b477d23b..554c5296aa2f 100644 --- a/libnd4j/include/graph/VariableProxy.h +++ b/libnd4j/include/graph/VariableProxy.h @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -18,72 +19,96 @@ // @author raver119@gmail.com // +#ifndef SD_VARIABLEPROXY_H +#define SD_VARIABLEPROXY_H + #include namespace sd { - namespace graph { - class ND4J_EXPORT VariableProxy: public VariableSpace { - protected: - VariableSpace* _backed = nullptr; - VariableSpace* _current = nullptr; - public: - explicit VariableProxy(VariableSpace* reference); - ~VariableProxy(); - - virtual VariableSpace& operator=(const VariableSpace& other); - - virtual int numberOfPlaceholders(); - virtual std::vector* getPlaceholders(); - - virtual sd::memory::Workspace *workspace(); - - virtual bool hasExternalVariable(int it); - virtual bool hasExternalVariable(std::pair& pair); - virtual bool hasExternalVariable(std::string *symbol); - - virtual bool hasVariable(int id); - virtual bool hasVariable(int id, int idx); - virtual bool hasVariable(std::pair& pair); - virtual bool hasVariable(std::string *symbol); - - virtual sd::graph::Variable *getVariable(int id); - virtual sd::graph::Variable *getVariable(int id, int idx); - virtual sd::graph::Variable *getVariable(std::pair& pair); - virtual sd::graph::Variable *getVariable(std::string *symbol); - - virtual std::vector getVariables(); - - virtual Variable* putVariable(std::pair& pair, NDArray *array); - virtual void putVariable(std::pair& pair, Variable *variable); - virtual void putVariable(int id, Variable *variable); - virtual void putVariable(int id, NDArray *array); - virtual Variable* putVariable(int id, int idx, NDArray *array); - virtual void putVariable(int id, int idx, NDArray &array); - virtual void putVariable(int id, int idx, Variable *array); - - virtual void replaceVariable(Variable *variable); - - virtual void dropVariable(std::pair &pair); - virtual void dropVariable(int id, int idx); - - virtual void putOutputVariable(Variable *variable); - - virtual void trackList(sd::NDArrayList *list); - - // memory-related statistics - virtual Nd4jLong externalMemory(); - virtual Nd4jLong internalMemory(); - virtual Nd4jLong totalMemory(); - - virtual int externalEntries(); - virtual int internalEntries(); - virtual int totalEntries(); - - virtual sd::graph::VariableSpace *clone(); - - virtual sd::graph::Stash* getStash(); - virtual void setFlowPath(FlowPath* timers); - virtual FlowPath* flowPath(); - }; - } -} \ No newline at end of file +namespace graph { +class SD_EXPORT VariableProxy : public VariableSpace { + protected: + const VariableSpace* _backed; + VariableSpace _current; + + public: + explicit VariableProxy(const VariableSpace* reference); + ~VariableProxy(); + + virtual VariableSpace& operator=(const VariableSpace& other) override; + + virtual int numberOfPlaceholders() const override; +#ifndef __JAVACPP_HACK__ + virtual const std::vector>& placeholders() const override; + virtual std::vector> variables() const override; +#endif + + virtual const MAP_IMPL, std::shared_ptr>& externalPaired() const; + + virtual bool hasExternalVariable(int it) const override; + virtual bool hasExternalVariable( + const std::pair& pair) const override; + virtual bool hasExternalVariable(const std::string& symbol) const override; + + virtual bool hasVariable(int id) const override; + virtual bool hasVariable(int id, int idx) const override; + virtual bool hasVariable(const std::pair& pair) const override; + virtual bool hasVariable(const std::string& symbol) const override; + + virtual std::shared_ptr getVariable(int id) const override; + virtual std::shared_ptr getVariable(int id, int idx) const override; + virtual std::shared_ptr getVariable( + const std::pair& pair) const override; + virtual std::shared_ptr getVariable( + const std::string& symbol) const override; + + virtual std::shared_ptr putVariable(const std::pair& pair, + const NDArray& array) override; + virtual std::shared_ptr putVariable(int id, + const NDArray& array) override; + virtual std::shared_ptr putVariable(int id, int idx, + const NDArray& array) override; + virtual std::shared_ptr putVariable(const std::string& name, int id, + int idx, + const NDArray& array) override; + virtual void putVariable(const std::string& name, int id, int idx, + const std::shared_ptr& variable) override; + virtual void putVariable(const std::pair& pair, + const std::shared_ptr& variable) override; + virtual void putVariable(int id, + const std::shared_ptr& variable) override; + + virtual void replaceVariable(std::shared_ptr variable) override; + + virtual void dropVariable(const std::pair& pair) override; + virtual void dropVariable(int id, int idx) override; + + virtual void putOutputVariable(std::shared_ptr variable) override; + + // memory-related statistics + virtual Nd4jLong externalMemory() const override; + virtual Nd4jLong internalMemory() const override; + virtual Nd4jLong totalMemory() const override; + + virtual int externalEntries() const override; + virtual int internalEntries() const override; + virtual int totalEntries() const override; + + virtual Stash* stash() const override; + + /** + * This method updates this proxy with entries from a given VariableProxy + * @param proxy + */ + void pullFrom(const VariableProxy &proxy); + + /** + * This method will update given VariableProxy with values from this proxy + * @param proxy + */ + void pushTo(VariableProxy &proxy) const; +}; +} // namespace graph +} // namespace sd + +#endif // SD_VARIABLEPROXY_H \ No newline at end of file diff --git a/libnd4j/include/graph/VariableSpace.h b/libnd4j/include/graph/VariableSpace.h index ea3c6370d16c..3a78337d7a5a 100644 --- a/libnd4j/include/graph/VariableSpace.h +++ b/libnd4j/include/graph/VariableSpace.h @@ -21,123 +21,117 @@ #ifndef LIBND4J_VARIABLESPACE_H #define LIBND4J_VARIABLESPACE_H -#include -#include -#include -#include -#include -#include -#include #include #include +#include +#include #include +#include +#include #include -#include -#include +#include +#include +#include +#include +#include namespace sd { - namespace graph { - class ND4J_EXPORT VariableSpace { - protected: - sd::memory::Workspace *_workspace; - - // stash is NOT cloned - sd::graph::Stash _stash; - - MAP_IMPL, Variable*> _paired; - MAP_IMPL _symbolic; - MAP_IMPL _variables; - std::vector _external; - std::vector _internal; - - std::vector _lists; - - std::vector _placeholders; - - void silentPutVariable(std::pair& pair, Variable *variable); - - int _auto_counter = -1; - - std::mutex _varmap; - - MAP_IMPL _temporary; - - std::vector *_handles; +namespace graph { +class SD_EXPORT VariableSpace { + friend class VariableProxy; + protected: + // stash is NOT cloned + Stash _stash; - FlowPath* _flow = nullptr; + // lookup tables: by name, by id, by id:idx + MAP_IMPL, std::shared_ptr> _paired; + MAP_IMPL> _symbolic; + MAP_IMPL> _variables; - public: - VariableSpace(); - virtual ~VariableSpace(); + // direct references to external variables and internally-generated variables + std::vector> _external; + std::vector> _internal; - virtual VariableSpace& operator=(const VariableSpace& other); + // meh + std::vector> _lists; - virtual int numberOfPlaceholders(); - virtual std::vector* getPlaceholders(); - virtual void setWorkspace(sd::memory::Workspace *workspace); + // placeholders. must be resolved before Graph execution + std::vector> _placeholders; - virtual LaunchContext* launchContext(); + void silentPutVariable(const std::pair &pair, + const std::shared_ptr &variable); - virtual bool hasExternalVariable(int it); - virtual bool hasExternalVariable(std::pair& pair); - virtual bool hasExternalVariable(std::string *symbol); + int _auto_counter = -1; - virtual bool hasVariable(int id); - virtual bool hasVariable(int id, int idx); - virtual bool hasVariable(std::pair& pair); - virtual bool hasVariable(std::string *symbol); + std::mutex _varmap; - virtual sd::graph::Variable* getVariable(int id); - virtual sd::graph::Variable* getVariable(int id, int idx); - virtual sd::graph::Variable* getVariable(std::pair& pair); - virtual sd::graph::Variable* getVariable(std::string *symbol); + public: + VariableSpace(); + virtual ~VariableSpace(); - virtual std::vector getVariables(); + VariableSpace(const sd::graph::VariableSpace &variableSpace); + VariableSpace(sd::graph::VariableSpace &&variableSpace); - virtual Variable* putVariable(std::pair& pair, NDArray *array); - virtual void putVariable(std::pair& pair, Variable *variable); - virtual void putVariable(int id, Variable *variable); - virtual void putVariable(int id, NDArray *array); - virtual Variable* putVariable(int id, int idx, NDArray *array); - virtual void putVariable(int id, int idx, NDArray &array); - virtual void putVariable(int id, int idx, Variable *array); + virtual VariableSpace &operator=(const VariableSpace &other); + virtual VariableSpace &operator=(VariableSpace &&other); - virtual void dropVariable(std::pair &pair); - virtual void dropVariable(int id, int idx); + virtual int numberOfPlaceholders() const; - virtual void trackList(sd::NDArrayList *list); +#ifndef __JAVACPP_HACK__ + virtual const std::vector>& placeholders() const; + virtual std::vector> variables() const; +#endif - virtual void putOutputVariable(Variable *variable); + virtual bool hasExternalVariable(int it) const; + virtual bool hasExternalVariable(const std::pair &pair) const; + virtual bool hasExternalVariable(const std::string &symbol) const; - virtual void replaceVariable(Variable *variable); + virtual bool hasVariable(int id) const; + virtual bool hasVariable(int id, int idx) const; + virtual bool hasVariable(const std::pair &pair) const; + virtual bool hasVariable(const std::string &symbol) const; - // memory-related statistics - virtual Nd4jLong externalMemory(); - virtual Nd4jLong internalMemory(); - virtual Nd4jLong totalMemory(); + virtual std::shared_ptr getVariable(int id) const; + virtual std::shared_ptr getVariable(int id, int idx) const; + virtual std::shared_ptr getVariable( + const std::pair &pair) const; + virtual std::shared_ptr getVariable( + const std::string &symbol) const; - virtual int externalEntries(); - virtual int internalEntries(); - virtual int totalEntries(); + virtual std::shared_ptr putVariable(const std::pair &pair, const NDArray &array); + virtual std::shared_ptr putVariable(int id, const NDArray &array); + virtual std::shared_ptr putVariable(int id, int idx, const std::shared_ptr &array); + virtual std::shared_ptr putVariable(int id, int idx, const NDArray &array); + virtual std::shared_ptr putVariable(const std::string &name, int id, int idx, const NDArray &array); - virtual sd::graph::VariableSpace* clone(); + virtual void putVariable(const std::string &name, int id, int idx, const std::shared_ptr &variable); + virtual void putVariable(const std::pair &pair, const std::shared_ptr &variable); + virtual void putVariable(int id, const std::shared_ptr &variable); - std::vector *handles(); + virtual void dropVariable(const std::string &pair); + virtual void dropVariable(const std::pair &pair); + virtual void dropVariable(int id, int idx); + virtual void putOutputVariable(std::shared_ptr variable); - sd::graph::VariableSpace* asT(); - void injectVariable(std::pair &pair, Variable* variable); + virtual void replaceVariable(std::shared_ptr variable); - virtual sd::graph::Stash* getStash(); + // memory-related statistics + virtual Nd4jLong externalMemory() const; + virtual Nd4jLong internalMemory() const; + virtual Nd4jLong totalMemory() const; - virtual std::vector * getExternalVariables(); + virtual int externalEntries() const; + virtual int internalEntries() const; + virtual int totalEntries() const; - virtual void setFlowPath(FlowPath* timers); - virtual FlowPath* flowPath(); - }; - } -} + void injectVariable(const std::pair &pair, + std::shared_ptr variable); + virtual Stash *stash() const; +}; +} // namespace graph +} // namespace sd -#endif //LIBND4J_VARIABLESPACE_H +#endif // LIBND4J_VARIABLESPACE_H diff --git a/libnd4j/include/graph/VariableType.h b/libnd4j/include/graph/VariableType.h index 28883f9b1e52..1e9c580e5c00 100644 --- a/libnd4j/include/graph/VariableType.h +++ b/libnd4j/include/graph/VariableType.h @@ -22,15 +22,15 @@ #define ND4J_VARIABLE_TYPE_H namespace sd { - namespace graph { - enum VariableType { - NDARRAY = 0, - ARRAY_LIST = 1, - FLOW = 2, - CONSTANT = 3, - PLACEHOLDER = 4, - }; - } +namespace graph { +enum VariableType { + NDARRAY = 0, + ARRAY_LIST = 1, + FLOW = 2, + CONSTANT = 3, + PLACEHOLDER = 4, +}; } +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/graph/VariablesSet.h b/libnd4j/include/graph/VariablesSet.h index 682b7fce4dab..b5e5a4d5e87f 100644 --- a/libnd4j/include/graph/VariablesSet.h +++ b/libnd4j/include/graph/VariablesSet.h @@ -21,35 +21,33 @@ #ifndef LIBND4J_VARIABLESSET_H #define LIBND4J_VARIABLESSET_H -#include -#include -#include -#include #include +#include +#include +#include +#include namespace sd { - namespace graph { - class ND4J_EXPORT VariablesSet { - protected: - std::vector _holder; - Nd4jStatus _status; - public: - VariablesSet(Nd4jStatus status = ND4J_STATUS_OK); - ~VariablesSet(); - - Nd4jStatus status(); - - int size(); +namespace graph { +class SD_EXPORT VariablesSet { + protected: + std::vector _holder; + Nd4jStatus _status; - void push_back(Variable* variable); + public: + VariablesSet(Nd4jStatus status = ND4J_STATUS_OK); + ~VariablesSet(); - Variable* at(int index); + Nd4jStatus status(); - }; - } -} + int size(); + void push_back(Variable* variable); + Variable* at(int index); +}; +} // namespace graph +} // namespace sd -#endif //LIBND4J_VARIABLESSET_H +#endif // LIBND4J_VARIABLESSET_H diff --git a/libnd4j/include/graph/exceptions/impl/unresolved_input_exception.cpp b/libnd4j/include/graph/exceptions/impl/unresolved_input_exception.cpp index fe6e45875dda..78b8c058533f 100644 --- a/libnd4j/include/graph/exceptions/impl/unresolved_input_exception.cpp +++ b/libnd4j/include/graph/exceptions/impl/unresolved_input_exception.cpp @@ -22,29 +22,35 @@ #include namespace sd { - namespace graph { - unresolved_input_exception::unresolved_input_exception(std::string message) : std::runtime_error(message) { - // - } +namespace graph { +unresolved_input_exception::unresolved_input_exception( + const std::string &message) + : std::runtime_error(message) { + // +} - unresolved_input_exception unresolved_input_exception::build(std::string message, int nodeId, std::pair &varIndex) { - auto node = StringUtils::valueToString(nodeId); - auto varId = StringUtils::valueToString(varIndex.first); - auto outputIdx = StringUtils::valueToString(varIndex.second); - message += "; Node: [" + node +":0]; Variable: [" + varId + ":" + outputIdx + "]"; - return unresolved_input_exception(message); - } +unresolved_input_exception unresolved_input_exception::build( + const std::string &message, int nodeId, std::pair &varIndex) { + auto node = StringUtils::valueToString(nodeId); + auto varId = StringUtils::valueToString(varIndex.first); + auto outputIdx = StringUtils::valueToString(varIndex.second); + auto rmessage = message + "; Node: [" + node + ":0]; Variable: [" + varId + + ":" + outputIdx + "]"; + return unresolved_input_exception(rmessage); +} - unresolved_input_exception unresolved_input_exception::build(std::string message, std::pair &varIndex) { - auto nodeId = StringUtils::valueToString(varIndex.first); - auto outputIdx = StringUtils::valueToString(varIndex.second); - message += "; Variable: [" + nodeId + ":" + outputIdx + "]"; - return unresolved_input_exception(message); - } +unresolved_input_exception unresolved_input_exception::build( + const std::string &message, std::pair &varIndex) { + auto nodeId = StringUtils::valueToString(varIndex.first); + auto outputIdx = StringUtils::valueToString(varIndex.second); + auto rmessage = message + "; Variable: [" + nodeId + ":" + outputIdx + "]"; + return unresolved_input_exception(rmessage); +} - unresolved_input_exception unresolved_input_exception::build(std::string message, std::string &varName) { - message += "; Variable: [" + varName + "]"; - return unresolved_input_exception(message); - } - } -} \ No newline at end of file +unresolved_input_exception unresolved_input_exception::build( + const std::string &message, const std::string &varName) { + auto rmessage = message + "; Variable: [" + varName + "]"; + return unresolved_input_exception(rmessage); +} +} // namespace graph +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/graph/exceptions/impl/unresolved_output_exception.cpp b/libnd4j/include/graph/exceptions/impl/unresolved_output_exception.cpp index df8b5eb00c53..49e08b56e472 100644 --- a/libnd4j/include/graph/exceptions/impl/unresolved_output_exception.cpp +++ b/libnd4j/include/graph/exceptions/impl/unresolved_output_exception.cpp @@ -20,30 +20,36 @@ #include #include + #include namespace sd { - namespace graph { - unresolved_output_exception::unresolved_output_exception(std::string message) : std::runtime_error(message) { - // - } +namespace graph { +unresolved_output_exception::unresolved_output_exception( + const std::string &message) + : std::runtime_error(message) { + // +} - unresolved_output_exception unresolved_output_exception::build(std::string message, std::pair &varIndex) { - auto nodeId = StringUtils::valueToString(varIndex.first); - auto outputIdx = StringUtils::valueToString(varIndex.second); - message += "; Variable: [" + nodeId + ":" + outputIdx + "]"; - return unresolved_output_exception(message); - } +unresolved_output_exception unresolved_output_exception::build( + const std::string &message, std::pair &varIndex) { + auto nodeId = StringUtils::valueToString(varIndex.first); + auto outputIdx = StringUtils::valueToString(varIndex.second); + auto rmessage = message + "; Variable: [" + nodeId + ":" + outputIdx + "]"; + return unresolved_output_exception(rmessage); +} - unresolved_output_exception unresolved_output_exception::build(std::string message, int nodeId, int outputIndex) { - std::pair p(nodeId, outputIndex); - return build(message, p); - } +unresolved_output_exception unresolved_output_exception::build( + const std::string &message, int nodeId, int outputIndex) { + std::pair p(nodeId, outputIndex); + return build(message, p); +} - unresolved_output_exception unresolved_output_exception::build(std::string message, std::string &varName, int outputIndex) { - auto outputIdx = StringUtils::valueToString(outputIndex); - message += "; Variable: [" + varName + ":" + outputIdx + "]"; - return unresolved_output_exception(message); - } - } -} \ No newline at end of file +unresolved_output_exception unresolved_output_exception::build( + const std::string &message, const std::string &varName, int outputIndex) { + auto outputIdx = StringUtils::valueToString(outputIndex); + auto rmessage = message + "; Variable: [" + varName + ":" + outputIdx + "]"; + return unresolved_output_exception(rmessage); +} +} // namespace graph +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/graph/exceptions/unresolved_input_exception.h b/libnd4j/include/graph/exceptions/unresolved_input_exception.h index 5e38977a99e7..6cea61bacc23 100644 --- a/libnd4j/include/graph/exceptions/unresolved_input_exception.h +++ b/libnd4j/include/graph/exceptions/unresolved_input_exception.h @@ -18,25 +18,29 @@ // @author raver119@gmail.com // -#ifndef DEV_TESTS_UNRESOLVED_INPUT_H -#define DEV_TESTS_UNRESOLVED_INPUT_H +#ifndef SD_UNRESOLVED_INPUT_H +#define SD_UNRESOLVED_INPUT_H -#include -#include #include +#include +#include namespace sd { - namespace graph { - class unresolved_input_exception : public std::runtime_error { - public: - unresolved_input_exception(std::string message); - ~unresolved_input_exception() = default; +namespace graph { +class unresolved_input_exception : public std::runtime_error { + public: + unresolved_input_exception(const std::string &message); + ~unresolved_input_exception() = default; - static unresolved_input_exception build(std::string message, int nodeId, std::pair &varIndex); - static unresolved_input_exception build(std::string message, std::pair &varIndex); - static unresolved_input_exception build(std::string message, std::string &varName); - }; - } -} + static unresolved_input_exception build(const std::string &message, + int nodeId, + std::pair &varIndex); + static unresolved_input_exception build(const std::string &message, + std::pair &varIndex); + static unresolved_input_exception build(const std::string &message, + const std::string &varName); +}; +} // namespace graph +} // namespace sd -#endif //DEV_TESTS_UNRESOLVED_INPUT_H +#endif // SD_UNRESOLVED_INPUT_H diff --git a/libnd4j/include/graph/exceptions/unresolved_output_exception.h b/libnd4j/include/graph/exceptions/unresolved_output_exception.h index 05d39c514818..688028472d12 100644 --- a/libnd4j/include/graph/exceptions/unresolved_output_exception.h +++ b/libnd4j/include/graph/exceptions/unresolved_output_exception.h @@ -18,26 +18,29 @@ // @author raver119@gmail.com // -#ifndef DEV_TESTS_UNRESOLVED_OUTPUT_H -#define DEV_TESTS_UNRESOLVED_OUTPUT_H +#ifndef SD_UNRESOLVED_OUTPUT_H +#define SD_UNRESOLVED_OUTPUT_H -#include -#include #include +#include +#include namespace sd { - namespace graph { - class unresolved_output_exception : public std::runtime_error { - public: - unresolved_output_exception(std::string message); - ~unresolved_output_exception() = default; - - static unresolved_output_exception build(std::string message, int nodeId, int outputIndex); - static unresolved_output_exception build(std::string message, std::pair &varIndex); - static unresolved_output_exception build(std::string message, std::string &varName, int outputIndex); - }; - } -} +namespace graph { +class unresolved_output_exception : public std::runtime_error { + public: + unresolved_output_exception(const std::string &message); + ~unresolved_output_exception() = default; + static unresolved_output_exception build(const std::string &message, + int nodeId, int outputIndex); + static unresolved_output_exception build(const std::string &message, + std::pair &varIndex); + static unresolved_output_exception build(const std::string &message, + const std::string &varName, + int outputIndex = 0); +}; +} // namespace graph +} // namespace sd -#endif //DEV_TESTS_UNRESOLVED_INPUT_H +#endif // SD_UNRESOLVED_INPUT_H diff --git a/libnd4j/include/graph/execution/ExecutionLayer.h b/libnd4j/include/graph/execution/ExecutionLayer.h new file mode 100644 index 000000000000..16ce1bf4096f --- /dev/null +++ b/libnd4j/include/graph/execution/ExecutionLayer.h @@ -0,0 +1,94 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#ifndef SD_EXECUTIONLAYER_H +#define SD_EXECUTIONLAYER_H + +#include +#include + +#include + +namespace sd { +namespace graph { +class SD_EXPORT ExecutionLayer { + protected: + std::vector _sequences; + + public: + ExecutionLayer(const std::vector& sequences = {}); + ~ExecutionLayer() = default; + + ExecutionLayer(const ExecutionLayer& other) noexcept; + + ExecutionLayer& operator=(const ExecutionLayer& other) noexcept; + + // move constructor + ExecutionLayer(ExecutionLayer&& other) noexcept; + + // move assignment operator + ExecutionLayer& operator=(ExecutionLayer&& other) noexcept; + + /** + * This method returns number of sequences in this layer + * @return + */ + uint64_t width() const; + + /** + * This method returns specified OpSequence from this layer + * @return + */ + const OpSequence& at(uint64_t index) const; + const OpSequence& operator[](uint64_t index) const; + + OpSequence& at(uint64_t index); + OpSequence& operator[](uint64_t index); + + /** + * This method appends OpSequence to the end of this layer + * @param sequence + */ + void append(OpSequence&& sequence); + void append(const OpSequence& sequence); + + /** + * sort OpSequences in increasing order in respect to id of fist node in sequence + * @param sequence + */ + void sortOpSequences(); + + /** + * This method checks if specified Node resides within this ExecutionLayer + * @param nodeId + * @return + */ + bool hasNode(int nodeId) const; + + /** + * This method removes all empty OpSequences from this layer + */ + void purgeEmptySequences(); +}; + +} // namespace graph +} // namespace sd + +#endif // SD_EXECUTIONLAYER_H diff --git a/libnd4j/include/graph/execution/ExecutionTask.h b/libnd4j/include/graph/execution/ExecutionTask.h new file mode 100644 index 000000000000..be15ee185b6f --- /dev/null +++ b/libnd4j/include/graph/execution/ExecutionTask.h @@ -0,0 +1,65 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#ifndef SD_EXECUTIONTASK_H +#define SD_EXECUTIONTASK_H + +#include +#include +#include + +#include + +namespace sd { +namespace graph { +class SD_EXPORT ExecutionTask { + protected: + // TODO: smart pointers here? + const Node _node; + + // FIXME: this field can be removed. Node contains ContextPrototype. + const ContextPrototype _context; + + public: + ExecutionTask(const Node& node, + const ContextPrototype& ctx); + + ~ExecutionTask() = default; + + ExecutionTask(const ExecutionTask& other); + + ExecutionTask& operator=(const ExecutionTask& other) noexcept; + + // move constructor + ExecutionTask(ExecutionTask&& other); + + // move assignment operator + ExecutionTask& operator=(ExecutionTask&& other) noexcept; + + void printOut() const; + + const Node& node() const; + + const ContextPrototype& protoContext() const; +}; +} // namespace graph +} // namespace sd + +#endif // SD_EXECUTIONTASK_H diff --git a/libnd4j/include/graph/execution/GraphExecutor.h b/libnd4j/include/graph/execution/GraphExecutor.h new file mode 100644 index 000000000000..281fd419fddb --- /dev/null +++ b/libnd4j/include/graph/execution/GraphExecutor.h @@ -0,0 +1,102 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#ifndef SD_GRAPHEXECUTOR_H +#define SD_GRAPHEXECUTOR_H + +#include +#include +#include +#include +#include +#include +#include + +namespace sd { +namespace graph { +class Graph; + +class SD_EXPORT GraphExecutor { + protected: + virtual Context prepareContext(const ContextPrototype &contextPrototype, + VariableProxy &variableSpace, + const GraphMemoryManager &memoryManager) const; + + /* + * preprocessor call involves: + * - ensure all inputs reside in HOT memory zone + * - shape function call + * - open workspace + */ + virtual Nd4jStatus preprocess(sd::ops::DeclarableOp *op, + Context &context) const; + + /** + * postporcessor call involves: + * - remove all inputs that are not going to be used later from HOT memory + * zone + * - close workspace + * @return + */ + virtual Nd4jStatus postprocess(sd::ops::DeclarableOp *op, + Context *context) const; + + public: + GraphExecutor() = default; + virtual ~GraphExecutor() = default; + + /** + * This method executes OptimizedGraph instance + * @param graph + * @return + */ + virtual Nd4jStatus execute(const OptimizedGraph &graph, + VariableProxy &proxy, + bool isInference = true) const; + + /** + * This method executes OpSequence + * @param seq + * @param deviceId - this argument allows to override device affinity + * specified in OpSequence, keep it < 0 to follow OpSequence + * @return + */ + virtual Nd4jStatus execute(const OpSequence &seq, + const OptimizedGraph &graph, + Stack &stack, + int deviceId) const; + + /** + * This method executes given op + * @param op + * @param contextPrototype + * @return + */ + virtual Nd4jStatus execute(const std::shared_ptr &op, + const ContextPrototype &contextPrototype, + const OpSequence &sequence, + const OptimizedGraph &graph, + VariableProxy &proxy, + const int deviceId) const; +}; +} // namespace graph +} // namespace sd + +#endif // SD_GRAPHEXECUTOR_H diff --git a/libnd4j/include/graph/execution/LogicConditional.h b/libnd4j/include/graph/execution/LogicConditional.h deleted file mode 100644 index ffaf6f098f99..000000000000 --- a/libnd4j/include/graph/execution/LogicConditional.h +++ /dev/null @@ -1,49 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -// -// Created by raver119 on 20.10.2017. -// - -#ifndef LIBND4J_LOGICCONDITIONAL_H -#define LIBND4J_LOGICCONDITIONAL_H - -#include -#include -#include - -namespace sd { - namespace graph { - /** - * This class is responsible for execution logic of Conditional logical abstraction - * - * TL/DR: Class takes 2 ops/scopes with the same number of inputs/outputs and condtion. - * Condition is evaluated, and based on its result - one of ops/scopes is executed. - * Results of this execution will be copied to Conditional node, and every other op - * in the graph will be sure that it's Conditional own result, both alternative nodes will - * stay in disguise. - * - * @tparam T - */ - class LogicConditional { - public: - static Nd4jStatus processNode(Graph* graph, Node* node); - }; - } -} - - -#endif //LIBND4J_LOGICCONDITIONAL_H diff --git a/libnd4j/include/graph/execution/LogicScope.h b/libnd4j/include/graph/execution/LogicScope.h deleted file mode 100644 index a7a8d6b7a9c6..000000000000 --- a/libnd4j/include/graph/execution/LogicScope.h +++ /dev/null @@ -1,45 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -// -// Created by raver119 on 20.10.2017. -// - -#ifndef LIBND4J_LOGICSCOPE_H -#define LIBND4J_LOGICSCOPE_H - -#include -#include -#include - -namespace sd { - namespace graph { - /** - * This class is responsible for execution logic of Scope logical abstraction - * - * It's ultra-simple. It does nothing, and can't be executed directly. - * - * @tparam T - */ - class LogicScope { - public: - static Nd4jStatus processNode(Graph* graph, Node* node); - }; - } -} - - -#endif //LIBND4J_LOGICSCOPE_H diff --git a/libnd4j/include/graph/execution/LogicWhile.h b/libnd4j/include/graph/execution/LogicWhile.h deleted file mode 100644 index 6e4b2ea3ae24..000000000000 --- a/libnd4j/include/graph/execution/LogicWhile.h +++ /dev/null @@ -1,44 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -// -// Created by raver119 on 20.10.2017. -// - -#ifndef LIBND4J_LOGICWHILE_H -#define LIBND4J_LOGICWHILE_H - -#include -#include -#include - -namespace sd { - namespace graph { - /** - * This class is responsible for execution logic of While logical abstraction - * - * Basic idea is simple: we take 2 scopes, one for condition and other one for body. and we re-execute body as long, as condition scope evaluates to TRUE - * @tparam T - */ - class LogicWhile { - public: - static Nd4jStatus processNode(Graph* graph, Node* node); - }; - } -} - - -#endif //LIBND4J_LOGICWHILE_H diff --git a/libnd4j/include/graph/execution/OpSequence.h b/libnd4j/include/graph/execution/OpSequence.h new file mode 100644 index 000000000000..9ad7b10f8d21 --- /dev/null +++ b/libnd4j/include/graph/execution/OpSequence.h @@ -0,0 +1,144 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#ifndef SD_OPSEQUENCE_H +#define SD_OPSEQUENCE_H + +#include +#include + +#include + +namespace sd { +namespace graph { +/** + * This class represents independent and immutable sequence of operations + */ +class SD_EXPORT OpSequence + : public std::iterator { + // our internal iterator for OpSequence + class iterator; + + protected: + // main thing here. sorted list of operations and their contexts + std::vector _ops; + + int _deviceId = 0; + + // this map contains Node::id() -> OpSequence index mappings + MAP_IMPL _idToIndex; + + // this map contains OpSequence index -> Node::id() mapping + MAP_IMPL _indexToId; + + public: + explicit OpSequence(const std::vector& ops, int deviceId = 0); + OpSequence(int deviceId = 0); + ~OpSequence() = default; + + OpSequence(const OpSequence& other) noexcept; + + OpSequence& operator=(const OpSequence& other) noexcept; + + // move constructor + OpSequence(OpSequence&& other) noexcept; + + // move assignment operator + OpSequence& operator=(OpSequence&& other) noexcept; + + int deviceId() const; + + /** + * This method blocks until all operations within sequence are processed + * @return + */ + Nd4jStatus wait() const; + + /** + * This method prints out content of the sequence + */ + void printOut() const; + + /** + * This method returns number of individual operations within this sequence + * @return + */ + uint64_t length() const; + + /** + * This method returns specific Op/ContextPrototype pair for specified index + * @param index + * @return + */ + const ExecutionTask& at(uint64_t index) const; + ExecutionTask& at(uint64_t index); + + const ExecutionTask& operator[](uint64_t index) const; + ExecutionTask& operator[](uint64_t index); + + + /** + * This method allows to add DeclarableOp to the end of execution queue + * @param op - Op to be executed + * @param ctx - ContextPrototype for this operation with inputs/outputs/args + * defined + */ + void append(const Node& node, const sd::graph::ContextPrototype& ctx); + void append(const ExecutionTask& task); + void append(ExecutionTask&& task); + void append(const OpSequence &sequence); + + /** + * These two methods provide access to index/id dictionalries + * @param index + * @return + */ + int nodeId(int index) const; + int nodeIndex(int id) const; + bool hasNode(int id) const; + + /** + * Iterator functionality for OpSequence + * @return + */ + + OpSequence::iterator begin(); + OpSequence::iterator end(); + + // additional private section + private: + class iterator + : public std::iterator { + private: + uint64_t _position = 0; + OpSequence& _container; + + public: + explicit iterator(OpSequence& container, uint64_t index = 0); + const ExecutionTask& operator*() const; + iterator& operator++(); + iterator& operator++(int); + bool operator!=(const iterator&) const; + }; +}; +} // namespace graph +} // namespace sd + +#endif // SD_OPSEQUENCE_H diff --git a/libnd4j/include/graph/execution/LogicReturn.h b/libnd4j/include/graph/execution/Stack.h similarity index 53% rename from libnd4j/include/graph/execution/LogicReturn.h rename to libnd4j/include/graph/execution/Stack.h index 2cc6107c5f6b..20ffbaea0f93 100644 --- a/libnd4j/include/graph/execution/LogicReturn.h +++ b/libnd4j/include/graph/execution/Stack.h @@ -1,5 +1,5 @@ /******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -15,32 +15,40 @@ ******************************************************************************/ // -// Created by raver119 on 28.10.2017. +// @author raver119@gmail.com // -#ifndef LIBND4J_LOGICRETURN_H -#define LIBND4J_LOGICRETURN_H +#ifndef SD_STACK_H_ +#define SD_STACK_H_ - -#include -#include -#include +#include +#include +#include namespace sd { - namespace graph { - /** - * This class is responsible for execution logic of Return logical abstraction - * - * Basically we're just transferring input variable(s) to output variable(s), nothing beyond that - * @tparam T - */ - class LogicReturn { - public: - static Nd4jStatus processNode(Graph* graph, Node* node); - }; - } -} - - - -#endif //LIBND4J_LOGICRETURN_H +namespace graph { + +class SD_EXPORT Stack { + private: + std::deque _frames; + + int _counter = 0; + public: + Stack(const VariableProxy &root); + ~Stack() = default; + + StackFrame& back(); + StackFrame& front(); + StackFrame& root(); + + const VariableProxy& rootVariableSpace() const; + + void openFrame(int frameId, int enterId); + void iterateFrame(int frameId, int enterId); + void closeFrame(); +}; + +} // namespace graph +} // namespace sd + +#endif //SD_STACK_H_ diff --git a/libnd4j/include/graph/execution/StackFrame.h b/libnd4j/include/graph/execution/StackFrame.h new file mode 100644 index 000000000000..a94cda1f2df2 --- /dev/null +++ b/libnd4j/include/graph/execution/StackFrame.h @@ -0,0 +1,76 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#ifndef SD_STACKFRAME_H_ +#define SD_STACKFRAME_H_ + +#include +#include +#include + +namespace sd { +namespace graph { + +class SD_EXPORT StackFrame { + private: + int _id; + VariableProxy _proxy; + StackFrame *_parent = nullptr; + + MAP_IMPL _disabledNodes; + + // these fields are used + int _frameId = -119; + int _enterId = -119; + mutable int _rewindId = -119; + mutable int _exitId = -119; + public: + explicit StackFrame(const VariableProxy &proxy, int id, int frameId, int enterId); + explicit StackFrame(const VariableProxy &proxy, int id, int frameId, int enterId, StackFrame &parent); + ~StackFrame() = default; + + const VariableProxy& variableProxy() const { return _proxy; } + + + + void disableNode(int nodeId); + bool isDisabled(int nodeId) const; + + int frameId() const; + int enterId() const; + int exitId() const; + int rewindId() const; + + void setRewindId(int id) const; + void setExitId(int id) const; + + /** + * This method returns parent frame + * @return + */ + StackFrame& parent() const; + + int id() const { return _id; } +}; + +} // namespace graph +} // namespace sd + +#endif // SD_STACKFRAME_H_ diff --git a/libnd4j/include/graph/execution/impl/ExecutionLayer.cpp b/libnd4j/include/graph/execution/impl/ExecutionLayer.cpp new file mode 100644 index 000000000000..50f8a600d6df --- /dev/null +++ b/libnd4j/include/graph/execution/impl/ExecutionLayer.cpp @@ -0,0 +1,111 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include + +namespace sd { +namespace graph { +ExecutionLayer::ExecutionLayer(const std::vector &sequences) { + _sequences = sequences; +} + +uint64_t ExecutionLayer::width() const { return _sequences.size(); } + +const OpSequence &ExecutionLayer::at(uint64_t index) const { + return _sequences[index]; +} + +OpSequence &ExecutionLayer::at(uint64_t index) { + return _sequences[index]; +} + +const OpSequence &ExecutionLayer::operator[](uint64_t index) const { + return at(index); +} + +OpSequence &ExecutionLayer::operator[](uint64_t index) { + return at(index); +} + +bool ExecutionLayer::hasNode(int nodeId) const { + for (const auto &v:_sequences) + if (v.hasNode(nodeId)) + return true; + + return false; +} + +void ExecutionLayer::append(OpSequence&& sequence) { + _sequences.emplace_back(std::move(sequence)); +} +void ExecutionLayer::append(const OpSequence& sequence) { + _sequences.emplace_back(sequence); +} + + +ExecutionLayer::ExecutionLayer(const ExecutionLayer &other) noexcept { + _sequences = other._sequences; +} + +ExecutionLayer &ExecutionLayer::operator=( + const ExecutionLayer &other) noexcept { + if (this == &other) return *this; + + _sequences = other._sequences; + + return *this; +} + +// move constructor +ExecutionLayer::ExecutionLayer(ExecutionLayer &&other) noexcept: _sequences(std::move(other._sequences)) { + +} + +ExecutionLayer &ExecutionLayer::operator=(ExecutionLayer &&other) noexcept { + if (this == &other) return *this; + + _sequences = std::move(other._sequences); + + return *this; +} + +//////////////////////////////////////////////////////////////////////// +void ExecutionLayer::sortOpSequences() { // bubble sort + + const int numOfOpSequences = this->width(); + + OpSequence temp; + + for (int i = 0; i < numOfOpSequences - 1; ++i) + for (int j = 0; j < numOfOpSequences - 1 - i; ++j) + if (_sequences[j][0].protoContext().nodeId() > _sequences[j + 1][0].protoContext().nodeId()) { + temp = _sequences[j]; + _sequences[j] = _sequences[j + 1]; + _sequences[j + 1] = temp; + + } +} + +void ExecutionLayer::purgeEmptySequences() { + _sequences.erase(std::remove_if(_sequences.begin(), _sequences.end(), [](OpSequence &seq) -> bool { return seq.length() == 0; }), _sequences.end()); +} + +} // namespace graph +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/graph/execution/impl/ExecutionTask.cpp b/libnd4j/include/graph/execution/impl/ExecutionTask.cpp new file mode 100644 index 000000000000..5677852bc28f --- /dev/null +++ b/libnd4j/include/graph/execution/impl/ExecutionTask.cpp @@ -0,0 +1,103 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include + +namespace sd { +namespace graph { +ExecutionTask::ExecutionTask(const Node& node, + const ContextPrototype &ctx) + : _node(node), _context(ctx) { } + +const Node& ExecutionTask::node() const { return _node; } + +const ContextPrototype &ExecutionTask::protoContext() const { return _context; } + +ExecutionTask::ExecutionTask(const ExecutionTask &other) + : _node(other._node), _context(other._context) { } + +ExecutionTask &ExecutionTask::operator=(const ExecutionTask &other) noexcept { + if (this == &other) return *this; + + const_cast(_node) = other._node; + const_cast(_context) = other._context; + + return *this; +} + +ExecutionTask::ExecutionTask(ExecutionTask &&other) + : _node(other._node), _context(other._context) { } + +void ExecutionTask::printOut() const { + if (_context.name().empty()) { + if (_node.hasCustomOp()) + printf(" <%i:0>: {Op: %s}; ", _context.nodeId(), _node.customOp()->getOpName().c_str()); + else + printf(" <%i:0>: ", _context.nodeId()); + } else { + printf(" <%s> <%i>: ", _context.name().c_str(), _context.nodeId()); + } + + auto sz = _context.inputs().size(); + if (sz) { + printf(" Inputs: ["); + int cnt = 0; + for (const auto &v : _node.inputs()) { + printf("<%i:%i>", v.first, v.second); + + if (cnt < sz - 1) printf(", "); + cnt++; + } + + printf("]; "); + } else { + printf(" No inputs; "); + } + + sz = _node.outputs().size(); + if (sz) { + printf(" Outputs: ["); + int cnt = 0; + for (const auto &v : _node.outputs()) { + printf("<%i:%i>", v.first, v.second); + + if (cnt < sz - 1) printf(", "); + cnt++; + } + + printf("]; "); + } else { + printf(" No outputs; "); + } + + printf("\n"); + fflush(stdout); +} + +ExecutionTask &ExecutionTask::operator=(ExecutionTask &&other) noexcept { + if (this == &other) return *this; + + const_cast(_node) = other._node; + const_cast(_context) = std::move(other._context); + + return *this; +} +} // namespace graph +} // namespace sd diff --git a/libnd4j/include/graph/execution/impl/GraphExecutor.cpp b/libnd4j/include/graph/execution/impl/GraphExecutor.cpp new file mode 100644 index 000000000000..6ffb29976e95 --- /dev/null +++ b/libnd4j/include/graph/execution/impl/GraphExecutor.cpp @@ -0,0 +1,163 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include +#include +#include + + +namespace sd { +namespace graph { +Context GraphExecutor::prepareContext( + const ContextPrototype &contextPrototype, VariableProxy &variableProxy, + const GraphMemoryManager &memoryManager) const { + // TODO: maybe we'll want to do something here? + return Context(contextPrototype, &variableProxy, + const_cast(&memoryManager)); +} + +Nd4jStatus GraphExecutor::preprocess(sd::ops::DeclarableOp *op, + Context &context) const { + // time to allocate outputs, if that's not inplace op + // inplace case is covered there + op->prepareOutputs(context); + + // once prepareOutputs method was called - we don't need shape function + // anymore + context.setShapeFunctionOverride(true); + + return Status::OK(); +} + +Nd4jStatus GraphExecutor::postprocess(sd::ops::DeclarableOp *op, + Context *context) const { + return Status::OK(); +} + +Nd4jStatus GraphExecutor::execute( + const std::shared_ptr &op, + const ContextPrototype &contextPrototype, const OpSequence &sequence, + const OptimizedGraph &graph, VariableProxy &proxy, + const int deviceId) const { + auto ctx = prepareContext(contextPrototype, proxy, GraphMemoryManager()/*graph.memoryManager()*/); + return op->execute(&ctx); + // throw std::runtime_error("GraphExecutor::execute - Not implemented yet"); +} + +Nd4jStatus GraphExecutor::execute(const OpSequence &seq, + const OptimizedGraph &graph, + Stack &stack, + const int deviceId) const { + // we either follow or override target deviceId specified in OpSequence + auto targetDevice = deviceId >= 0 + ? deviceId + : seq.deviceId(); + + /* + * this is a basic implementation that works without dispatching etc + */ + auto result = Status::OK(); + for (int e = 0; e < seq.length(); e++) { + auto &v = seq[e]; + auto &f = stack.back(); + auto &p = f.variableProxy(); + + // we skip all disabled nodes + if (f.isDisabled(v.node().id())) + continue; + + + if (v.node().opType() == OpType_LOGIC) { + nd4j_debug("Node <%i:%s> is a logic op; Current frame: [%i]\n", + v.node().id(), + v.node().name().empty() ? "" : v.node().name().c_str(), + f.id()); + + LogicExecutor::processNode(&v.node(), stack, graph); + + // next iteration is special case: we might rewind + if (v.node().opNum() == sd::logic::NextIteration) { + if (f.rewindId() == v.node().id()) + e = seq.nodeIndex(f.enterId()) - 1; + } + } else if (v.node().hasCustomOp()) { + nd4j_debug("Node <%i:%s> is a custom op; Current frame: [%i]\n", + v.node().id(), + v.node().name().empty() ? "" : v.node().name().c_str(), + f.id()); + + // only Ops can be executed this way :( + result = execute(v.node().customOp(), v.protoContext(), seq, graph, const_cast(p), targetDevice); + } else { + nd4j_debug("Node <%i:%s> has no customOp set; Current frame: [%i]\n", + v.node().id(), + v.node().name().empty() ? "" : v.node().name().c_str(), + f.id()); + } + + // if any one op fails - there will be no sense in executing other ops + if (result != Status::OK()) return result; + } + + return Status::OK(); +} + +Nd4jStatus GraphExecutor::execute(const OptimizedGraph &graph, + VariableProxy &proxy, + bool isInference) const { + /* + * this is a basic exection logic: roll through layers and sequences and + * execute them one by one sequentially + */ + + // now we create out dequeue of frames with one root StackFrame. current one. + Stack stack(proxy); + + const auto numDevices = AffinityManager::numberOfDevices(); + Nd4jStatus result = Status::OK(); // + for (uint64_t l = 0; l < graph.layers(); l++) { + const auto &layer = graph.layer(l); + + //TODO: this loop is executable in parallel, so we should do this eventually + for (uint64_t o = 0; o < layer.width(); o++) { + result = execute(layer[o], graph, stack, -1); + } + + // early termination + if (result != Status::OK()) return result; + + // optionally block until all sequences in this layer processed + if (layer.width() > 0 && numDevices > 1) + for (uint64_t o = 0; o < layer.width(); o++) { + result = layer[o].wait(); + + // early termination + if (result != Status::OK()) return result; + } + } + + // update original VariableSpace from the top-level VariableSpace + proxy.pullFrom(stack.front().variableProxy()); + + return result; +} + +} // namespace graph +} // namespace sd diff --git a/libnd4j/include/graph/execution/impl/LogicConditional.cpp b/libnd4j/include/graph/execution/impl/LogicConditional.cpp deleted file mode 100644 index 25627df4564d..000000000000 --- a/libnd4j/include/graph/execution/impl/LogicConditional.cpp +++ /dev/null @@ -1,136 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -// -// Created by raver119 on 20.10.2017. -// - -#include -#include -#include -#include - - -namespace sd { - namespace graph { - Nd4jStatus LogicConditional::processNode(Graph *graph, Node *node) { - auto __variableSpace = graph->getVariableSpace(); - - auto size = node->input()->size(); - - // propagating inputs (optional) - for (int e = 0; e < size - 3; e++) { - std::pair pair(node->id(), e); - if (!__variableSpace->hasVariable(pair)) { - __variableSpace->putVariable(pair, new Variable(nullptr, nullptr, node->id(), e)); - } - - auto va = node->input()->at(e); - - auto inputVar = __variableSpace->getVariable(va); - - auto innerVar = __variableSpace->getVariable(pair); - if (innerVar->hasNDArray()) { - // TODO: ??? - } else { - // FIXME: in some cases it's possible to have no NDArray - if (inputVar->hasNDArray()) - innerVar->setNDArray(new NDArray(inputVar->getNDArray()->dup())); - } - } - - - int scopeConditionIndex = node->input()->at(size - 3).first; - int scopeFalseIndex = node->input()->at(size - 2).first; - int scopeTrueIndex = node->input()->at(size - 1).first; - - auto scopeCondition = graph->scopeById(scopeConditionIndex); - int lastNode = 0; - for (auto v: *scopeCondition->nodes()) { - GraphExecutioner::executeFlatNode(graph, v, __variableSpace); - lastNode = v->id(); - } - - // now we should take result of the Scope run, and evaluate it - //nd4j_debug("", ""); - auto result = __variableSpace->getVariable(lastNode)->getNDArray(); - //result->printBuffer("Result of the last node:"); - - bool isReturn = false; - - // now we're executing one of the scopes, depending on condition evaluation - if (result->e(0) == 0) { - auto scopeFalse = graph->scopeById(scopeFalseIndex); - lastNode = 0; - int nodes = scopeFalse->nodes()->size(); - for (int e = 0; e < nodes - 1; e++) { - auto v = scopeFalse->nodes()->at(e); - GraphExecutioner::executeFlatNode(graph, v, __variableSpace); - lastNode = v->id(); - } - - // last node is either return or just last op - auto *node = scopeFalse->nodes()->at(nodes -1); - if (node->opType() == OpType_LOGIC && node->opNum() == 40) { - isReturn = true; - LogicReturn::processNode(graph, node); - } else { - GraphExecutioner::executeFlatNode(graph, node, __variableSpace); - lastNode = node->id(); - } - } else { - auto scopeTrue = graph->scopeById(scopeTrueIndex); - lastNode = 0; - int nodes = scopeTrue->nodes()->size(); - for (int e = 0; e < nodes - 1; e++) { - auto v = scopeTrue->nodes()->at(e); - GraphExecutioner::executeFlatNode(graph, v, __variableSpace); - lastNode = v->id(); - } - - // last node is either return or just last op - auto node = scopeTrue->nodes()->at(nodes -1); - if (node->opType() == OpType_LOGIC && node->opNum() == 40) { - isReturn = true; - LogicReturn::processNode(graph, node); - } else { - GraphExecutioner::executeFlatNode(graph, node, __variableSpace); - lastNode = node->id(); - } - } - - // now fetch and transfer variables to Conditional node - // but only if return wasn't called at the end of scope - if (!isReturn) { - for (int e = 0; e < DataTypeUtils::max(); e++) { - std::pair pair(lastNode, e); - std::pair pairNew(node->id(), e); - if (__variableSpace->hasVariable(pair)) { - auto array = __variableSpace->getVariable(pair)->getNDArray(); - auto newVar = new Variable(array); - newVar->setId(lastNode, e); - newVar->markRemovable(false); - - __variableSpace->putVariable(pairNew, newVar); - } else - break; - } - } - - return sd::Status::OK(); - } - } -} \ No newline at end of file diff --git a/libnd4j/include/graph/execution/impl/LogicEnter.cpp b/libnd4j/include/graph/execution/impl/LogicEnter.cpp deleted file mode 100644 index f10ff792f765..000000000000 --- a/libnd4j/include/graph/execution/impl/LogicEnter.cpp +++ /dev/null @@ -1,74 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -// -// @author raver119@gmail.com -// - -#include -#include - - -namespace sd { - namespace graph { - Nd4jStatus LogicEnter::processNode(Graph *graph, Node *node) { - // this op replicates input variable into the frame. basically happens once for single loop. - // sure, if there's inner loop within outer loop, it'll be called once for outer loop and multiple times for inner loop - - auto __variableSpace = graph->getVariableSpace(); - auto __flowPath = __variableSpace->flowPath(); - - // basically, first non-null variable is our target - for (int e = 0; e < node->input()->size(); e++) { - auto inputAddr = node->input()->at(e); - - if (__variableSpace->hasVariable(inputAddr)) { - auto var = __variableSpace->getVariable(inputAddr); - if (var->hasNDArray()) { - Variable *lvar = nullptr; - if (__variableSpace->hasVariable(node->id(), 0)) - lvar = __variableSpace->getVariable(node->id(), 0); - else - lvar = new Variable(nullptr, node->getName()->c_str(), node->id(), 0); - - auto array = var->getNDArray(); - lvar->setNDArray(array); - lvar->markReadOnly(true); - - break; - } else if (var->hasNDArrayList()) { - Variable *lvar = nullptr; - if (__variableSpace->hasVariable(node->id(), 0)) - lvar = __variableSpace->getVariable(node->id(), 0); - else - lvar = new Variable(nullptr, node->getName()->c_str(), node->id(), 0); - - auto list = var->getNDArrayList(); - lvar->setNDArrayList(list); - lvar->markReadOnly(true); - - break; - } else { - // FIXME: can we really have third case here? - continue; - } - } - } - - return sd::Status::OK(); - } - } -} \ No newline at end of file diff --git a/libnd4j/include/graph/execution/impl/LogicExecutor.cpp b/libnd4j/include/graph/execution/impl/LogicExecutor.cpp deleted file mode 100644 index fd7ce3e852e4..000000000000 --- a/libnd4j/include/graph/execution/impl/LogicExecutor.cpp +++ /dev/null @@ -1,71 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -// -// Created by raver119 on 20.10.2017. -// - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - - -namespace sd { - namespace graph { - Nd4jStatus LogicExecutor::processNode(Graph *graph, Node *node) { - switch (node->opNum()) { - case sd::logic::While: - return LogicWhile::processNode(graph, node); - case sd::logic::Scope: - return LogicScope::processNode(graph, node); - case sd::logic::Conditional: - return LogicConditional::processNode(graph, node); - case sd::logic::Switch: - return LogicSwitch::processNode(graph, node); - case sd::logic::Return: - return LogicReturn::processNode(graph, node); - case sd::logic::Expose: - return LogicExpose::processNode(graph, node); - case sd::logic::Merge: - return LogicMerge::processNode(graph, node); - case sd::logic::LoopCond: - return LogicLoopCond::processNode(graph, node); - case sd::logic::NextIteration: - return LogicNextIeration::processNode(graph, node); - case sd::logic::Exit: - return LogicExit::processNode(graph, node); - case sd::logic::Enter: - return LogicEnter::processNode(graph, node); - } - - if (node->getName() == nullptr) { - nd4j_printf("Unknown LogicOp used at node [%i]: [%i]\n", node->id(), node->opNum()); - } else { - nd4j_printf("Unknown LogicOp used at node [%i:<%s>]: [%i]\n", node->id(), node->getName()->c_str(), node->opNum()); - } - return ND4J_STATUS_BAD_INPUT; - } - } -} \ No newline at end of file diff --git a/libnd4j/include/graph/execution/impl/LogicExit.cpp b/libnd4j/include/graph/execution/impl/LogicExit.cpp deleted file mode 100644 index 9a0e217938a8..000000000000 --- a/libnd4j/include/graph/execution/impl/LogicExit.cpp +++ /dev/null @@ -1,47 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -// -// @author raver119@gmail.com -// - -#include - - -namespace sd { - namespace graph { - Nd4jStatus LogicExit::processNode(Graph *graph, Node *node) { - // this op is basically no-op - // we just know it exists - - auto __variableSpace = graph->getVariableSpace(); - auto __flowPath = __variableSpace->flowPath(); - - Context ctx(node->getContextPrototype(), __variableSpace); - auto input = ctx.variable(0)->getNDArray(); - - std::pair pair0(node->id(), 0); - - if (!__variableSpace->hasVariable(pair0)) - __variableSpace->putVariable(pair0, new Variable(nullptr, nullptr, node->id(), 0)); - - __variableSpace->getVariable(pair0)->setNDArray(input); - __variableSpace->getVariable(pair0)->markRemovable(false); - - return ND4J_STATUS_OK; - } - } -} \ No newline at end of file diff --git a/libnd4j/include/graph/execution/impl/LogicLoopCond.cpp b/libnd4j/include/graph/execution/impl/LogicLoopCond.cpp deleted file mode 100644 index 292452719770..000000000000 --- a/libnd4j/include/graph/execution/impl/LogicLoopCond.cpp +++ /dev/null @@ -1,54 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -// -// @author raver119@gmail.com -// - -#include - - -namespace sd { - namespace graph { - Nd4jStatus LogicLoopCond::processNode(Graph *graph, Node *node) { - auto __variableSpace = graph->getVariableSpace(); - auto __flowPath = __variableSpace->flowPath(); - - Context ctx(node->getContextPrototype(), __variableSpace); - auto input = ctx.variable(0)->getNDArray(); - - std::pair pair0(node->id(), 0); - - if (!__variableSpace->hasVariable(pair0)) - __variableSpace->putVariable(pair0, new Variable(nullptr, nullptr, node->id(), 0)); - - __variableSpace->getVariable(pair0)->setNDArray(input); - __variableSpace->getVariable(pair0)->markRemovable(false); - - // pass further - if (input->e(0) > 0) { - // if condition is TRUE body will be invoked some time soon - // __flowPath->markFrameActive(node->getFrameId(), true); - //__flowPath->i - } else { - // body won't be activated - // __flowPath->markFrameActive(node->getFrameId(), false); - } - - return ND4J_STATUS_OK; - } - } -} \ No newline at end of file diff --git a/libnd4j/include/graph/execution/impl/LogicMerge.cpp b/libnd4j/include/graph/execution/impl/LogicMerge.cpp deleted file mode 100644 index 9d032a93f110..000000000000 --- a/libnd4j/include/graph/execution/impl/LogicMerge.cpp +++ /dev/null @@ -1,134 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -// -// Created by raver119 on 30.01.18. -// - -#include -#include - -namespace sd { - namespace graph { - Nd4jStatus LogicMerge::processNode(Graph *graph, Node *node) { - // at merge node only one of inputs exist if that's just switch and other node isn't LogicNextItration - auto __variableSpace = graph->getVariableSpace(); - auto __flowPath = __variableSpace->flowPath(); - - // merge MUST have 2 inputs - auto inputAddr0 = node->input()->at(0); - auto inputAddr1 = node->input()->at(1); - - bool isWhile = false; - - // now we want to check if second input is NextIteration - if (graph->hasNode(inputAddr1.first)) { - auto secondNode = graph->nodeById(inputAddr1.first); - - // checking for NextIteration - if (secondNode->opType() == OpType_LOGIC && secondNode->opNum() == 80L) { - isWhile = true; - - // notifying NextIteration node for rewind index - secondNode->setRewindLayer(node->getLayer()); - secondNode->setRewindNode(node->id()); - } - - } - - // FIXME: we don't need this check. Just last input should survive, IF it exists - if (isWhile){ - - if (node->getFrameId() >= 0) - __flowPath->markFrameActive(node->getFrameId(), true); - - bool hasVar = __variableSpace->hasVariable(inputAddr1); - if ( hasVar && __flowPath->wasExecuted(inputAddr1.first)) { - nd4j_debug("Node_%i: propagating second input\n", node->id()); - auto var = __variableSpace->getVariable(inputAddr1); - - Variable *lvar = nullptr; - if (__variableSpace->hasVariable(node->id(), 0)) - lvar = __variableSpace->getVariable(node->id(), 0); - else - lvar = new Variable(nullptr, node->getName()->c_str(), node->id(), 0); - -// if (lvar->hasNDArray()) -// delete lvar->getNDArray(); - - auto array = var->getNDArray(); - - //array->printIndexedBuffer("propagated"); - - lvar->setNDArray(array); - lvar->markReadOnly(true); - - __flowPath->markExecuted(inputAddr1.first, false); - - - } else { - nd4j_debug("Node_%i: propagating first input\n", node->id()); - auto var = __variableSpace->getVariable(inputAddr0); - - Variable *lvar = nullptr; - if (__variableSpace->hasVariable(node->id(), 0)) - lvar = __variableSpace->getVariable(node->id(), 0); - else - lvar = new Variable(nullptr, node->getName()->c_str(), node->id(), 0); - -// if (lvar->hasNDArray()) -// delete lvar->getNDArray(); - - auto array = var->getNDArray(); - lvar->setNDArray(array); - lvar->markReadOnly(true); - - - } - } else { - - // basically, first non-null variable is our target - for (int e = 0; e < node->input()->size(); e++) { - auto inputAddr = node->input()->at(e); - - if (__variableSpace->hasVariable(inputAddr)) { - auto var = __variableSpace->getVariable(inputAddr); - if (!var->hasNDArray() || !__flowPath->isNodeActive(inputAddr.first)) - continue; - - Variable *lvar = nullptr; - if (__variableSpace->hasVariable(node->id(), 0)) - lvar = __variableSpace->getVariable(node->id(), 0); - else - lvar = new Variable(nullptr, node->getName()->c_str(), node->id(), 0); - - if (lvar->hasNDArray()) - delete lvar->getNDArray(); - - auto array = var->getNDArray(); - lvar->setNDArray(array); - lvar->markReadOnly(true); - //lvar->markExternal(false);h - - break; - } - } - } - - return Status::OK(); - } - } -} diff --git a/libnd4j/include/graph/execution/impl/LogicNextIteration.cpp b/libnd4j/include/graph/execution/impl/LogicNextIteration.cpp deleted file mode 100644 index fb7eaa513872..000000000000 --- a/libnd4j/include/graph/execution/impl/LogicNextIteration.cpp +++ /dev/null @@ -1,50 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -// -// @author raver119@gmail.com -// - -#include - - -namespace sd { - namespace graph { - Nd4jStatus LogicNextIeration::processNode(Graph *graph, Node *node) { - auto __variableSpace = graph->getVariableSpace(); - auto __flowPath = __variableSpace->flowPath(); - - auto inputAddr = node->input()->at(0); - - auto var = __variableSpace->getVariable(inputAddr); - - Variable *lvar = nullptr; - if (__variableSpace->hasVariable(node->id(), 0)) - lvar = __variableSpace->getVariable(node->id(), 0); - else - lvar = new Variable(nullptr, node->getName()->c_str(), node->id(), 0); - -// if (lvar->hasNDArray()) -// delete lvar->getNDArray(); - - auto array = var->getNDArray(); - lvar->setNDArray(array); - lvar->markReadOnly(true); - - return ND4J_STATUS_OK; - } - } -} \ No newline at end of file diff --git a/libnd4j/include/graph/execution/impl/LogicReturn.cpp b/libnd4j/include/graph/execution/impl/LogicReturn.cpp deleted file mode 100644 index 0ee62e9453c2..000000000000 --- a/libnd4j/include/graph/execution/impl/LogicReturn.cpp +++ /dev/null @@ -1,55 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -// -// Created by raver119 on 28.10.2017. -// - -#include "graph/execution/LogicReturn.h" -#include -#include - -namespace sd { - namespace graph { - Nd4jStatus LogicReturn::processNode(Graph *graph, Node *node) { - auto __variableSpace = graph->getVariableSpace(); - - for (int e = 0; e < node->input()->size(); e++) { - auto inputAddr = node->input()->at(e); - auto outputAddr = node->output()->at(e); - - // FIXME!! - outputAddr.second = e; - - if (Environment::getInstance().isDebugAndVerbose()) - nd4j_debug("Return input: <%i, %i>; Return output: <%i, %i>\n", inputAddr.first, inputAddr.second, outputAddr.first, outputAddr.second); - - auto varIn = __variableSpace->getVariable(inputAddr); - auto varOut = __variableSpace->getVariable(outputAddr); - - nd4j_debug("Returning varType: [%s]\n", EnumUtils::_VariableTypeToString(varIn->variableType())); - - // FIXME: this is obviously wrong, we should keep depth track for backprop here - varOut->getNDArray()->assign(varIn->getNDArray()); - - if (Environment::getInstance().isDebugAndVerbose()) - nd4j_debug("In after: [%f]; Out after: [%f]\n", varIn->getNDArray()->meanNumber().e(0), varOut->getNDArray()->meanNumber().e(0)); - } - - return sd::Status::OK(); - } - } -} diff --git a/libnd4j/include/graph/execution/impl/LogicSwitch.cpp b/libnd4j/include/graph/execution/impl/LogicSwitch.cpp deleted file mode 100644 index 1089046a3546..000000000000 --- a/libnd4j/include/graph/execution/impl/LogicSwitch.cpp +++ /dev/null @@ -1,108 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -// -// Created by raver119 on 21.10.17. -// - -#include -#include -#include -#include - -namespace sd { - namespace graph { - Nd4jStatus LogicSwitch::processNode(Graph* graph, Node* node) { - auto __variableSpace = graph->getVariableSpace(); - auto __flowPath = __variableSpace->flowPath(); - - Context ctx(node->getContextPrototype(), __variableSpace); - - // this can be either our format, or compatible format. - if (graph->hasScope(node->input()->at(0).first)) { - nd4j_debug("Node_%i: Scoped mode.\n", node->id()); - // first input is Scope, so it's ours - int scopeConditionIndex = node->input()->at(0).first; - auto input = ctx.variable(1); - - auto scopeCondition = graph->scopeById(scopeConditionIndex); - int lastNode = 0; - for (auto v: *scopeCondition->nodes()) { - GraphExecutioner::executeFlatNode(graph, v, __variableSpace); - lastNode = v->id(); - } - - // now we should take result of the Scope run, and evaluate it - auto result = __variableSpace->getVariable(lastNode)->getNDArray(); - //result->printBuffer("Result of the last node"); - - - std::pair pair0(node->id(), 0); - std::pair pair1(node->id(), 1); - - if (!__variableSpace->hasVariable(pair0)) - __variableSpace->putVariable(pair0, new Variable(nullptr, nullptr, node->id(), 0)); - - if (!__variableSpace->hasVariable(pair1)) - __variableSpace->putVariable(pair1, new Variable(nullptr, nullptr, node->id(), 1)); - - if (!result->e(0)) { - __flowPath->markBranch(node->id(), 0); - __variableSpace->getVariable(pair0)->setNDArray(input->getNDArray()); - __variableSpace->getVariable(pair0)->markRemovable(false); - } else { - __flowPath->markBranch(node->id(), 1); - __variableSpace->getVariable(pair1)->setNDArray(input->getNDArray()); - __variableSpace->getVariable(pair1)->markRemovable(false); - } - } else { - // first input is NOT a Scope, so it's compatible format - nd4j_debug("Node_%i: Compatible mode.\n", node->id()); - - auto input = ctx.variable(0)->getNDArray(); - auto boolean = ctx.variable(1)->getNDArray(); - - //input->printIndexedBuffer("0"); - //boolean->printIndexedBuffer("1"); - - std::pair pair0(node->id(), 0); - std::pair pair1(node->id(), 1); - - if (!__variableSpace->hasVariable(pair0)) - __variableSpace->putVariable(pair0, new Variable(nullptr, nullptr, node->id(), 0)); - - if (!__variableSpace->hasVariable(pair1)) - __variableSpace->putVariable(pair1, new Variable(nullptr, nullptr, node->id(), 1)); - - if (!boolean->e(0)) { - // false - nd4j_debug("Node_%i: FALSE branch active\n", node->id()); - __flowPath->markBranch(node->id(), 0); - __variableSpace->getVariable(pair0)->setNDArray(input); - __variableSpace->getVariable(pair0)->markRemovable(false); - } else { - //true - nd4j_debug("Node_%i: TRUE branch active\n", node->id()); - __flowPath->markBranch(node->id(), 1); - __variableSpace->getVariable(pair1)->setNDArray(input); - __variableSpace->getVariable(pair1)->markRemovable(false); - } - } - - return sd::Status::OK(); - }; - } -} diff --git a/libnd4j/include/graph/execution/impl/LogicWhile.cpp b/libnd4j/include/graph/execution/impl/LogicWhile.cpp deleted file mode 100644 index fec9a0d30b9c..000000000000 --- a/libnd4j/include/graph/execution/impl/LogicWhile.cpp +++ /dev/null @@ -1,144 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -// -// Created by raver119 on 20.10.2017. -// - -#include -#include -#include -#include -#include - - -namespace sd { - namespace graph { - Nd4jStatus LogicWhile::processNode(Graph *graph, Node *node) { - auto __variableSpace = graph->getVariableSpace(); - - nd4j_debug("Starting on WHILE loop: [%i]\n", node->id()); - - // total number of inputs. 2 last inputs are scopes - int inputs = node->input()->size(); - - if (inputs < 3) { - nd4j_printf("While [%i]: loop should have at least 1 external variable announced\n", node->id()); - return ND4J_STATUS_BAD_INPUT; - } - - for (int e = 0; e < inputs - 2; e++) { - std::pair pair(node->id(), e); - if (!__variableSpace->hasVariable(pair)) { - __variableSpace->putVariable(pair, new Variable(nullptr, nullptr, node->id(), e)); - } - - auto va = node->input()->at(e); - - auto inputVar = __variableSpace->getVariable(va); - - auto innerVar = __variableSpace->getVariable(pair); - if (innerVar->hasNDArray()) { - // TODO: ??? - } else { - // FIXME: in some cases it's possible to have no NDArray - if (inputVar->hasNDArray()) - innerVar->setNDArray(new NDArray(inputVar->getNDArray()->dup())); - } - } - - int scopeConditionIndex = node->input()->at(inputs - 2).first; - int scopeBodyIndex = node->input()->at(inputs - 1).first; - - nd4j_debug("While [%i]: got [%i] inputs\n", node->id(), node->input()->size()); - - // we're running condition nodes now - auto scope = graph->scopeById(scopeConditionIndex); - int breaker = 0; - while (true && breaker < 10000000) { - int lastNode = 0; - // we're running condition scope first - nd4j_debug("While [%i]: got [%i] ops in condition scope [%i]\n", node->id(), scope->nodes()->size(), scopeConditionIndex); - - for (Node* v: *scope->nodes()) { - //v->getBlock()->updateVariables(); - if (v->opType() == OpType_LOGIC) { - nd4j_debug("Falling back to logic\n",""); - LogicExecutor::processNode(graph, v); - } else { - nd4j_debug("Op [<%s>]\n", v->getName()->c_str()); - Nd4jStatus status = GraphExecutioner::executeFlatNode(graph, v, __variableSpace); - if (status != ND4J_STATUS_OK) - return status; - } - - lastNode = v->id(); - } - - if (!__variableSpace->hasVariable(lastNode)) { - nd4j_printf("While [%i]: got no results out of conditional loop\n", node->id()); - return ND4J_STATUS_KERNEL_FAILURE; - } - - // now we should take result of the Scope run, and evaluate it - auto result = __variableSpace->getVariable(lastNode)->getNDArray(); - - if (Environment::getInstance().isDebugAndVerbose()) - result->printBuffer("Result of the last node:"); - - // if result evaluates to 0.0 - condition returned FALSE - if (result->e(0) == 0) - break; - else { - auto scopeBody = graph->scopeById(scopeBodyIndex); - int lastNode = 0; - int e = 0; - nd4j_debug("While [%i] got [%i] ops in body scope [%i]\n", node->id(), scopeBody->nodes()->size(), scopeBodyIndex); - for (; e < scopeBody->nodes()->size() - 1; e++) { - Node* v = scopeBody->nodes()->at(e); - - if (v->opType() == OpType_LOGIC) { - nd4j_debug("Falling back to logic\n",""); - LogicExecutor::processNode(graph, v); - } else { - nd4j_debug("Op [<%s>]\n", v->getName()->c_str()); - //v->getBlock()->updateVariables(); - Nd4jStatus status = GraphExecutioner::executeFlatNode(graph, v, __variableSpace); - if (status != ND4J_STATUS_OK) - return status; - } - - lastNode = v->id(); - } - - // now execute return statement - Node* ret = scopeBody->nodes()->at(e); - LogicReturn::processNode(graph, ret); - } - - breaker++; - } - - // if we've hit breaker limit - we should notify about that - if (breaker >= 10000000) { - nd4j_printf("While condition seems to be never ending, aborting...\n", breaker); - return ND4J_STATUS_KERNEL_FAILURE; - } - - return sd::Status::OK(); - } - } -} diff --git a/libnd4j/include/graph/execution/impl/OpSequence.cpp b/libnd4j/include/graph/execution/impl/OpSequence.cpp new file mode 100644 index 000000000000..cc073d3dd26a --- /dev/null +++ b/libnd4j/include/graph/execution/impl/OpSequence.cpp @@ -0,0 +1,182 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include +#include +#include + +namespace sd { +namespace graph { +OpSequence::OpSequence(const int deviceId) : _deviceId(deviceId) { + // +} + +OpSequence::OpSequence(const std::vector &ops, + const int deviceId) { + _deviceId = deviceId; + for (const auto &v : ops) _ops.emplace_back(v); +} + +OpSequence::OpSequence(const OpSequence &other) noexcept { + _ops.clear(); + + for (const auto &v : other._ops) _ops.emplace_back(v); + + _idToIndex = other._idToIndex; + _indexToId = other._indexToId; +} + +//////////////////////////////////////////////////////////////////////// +// move constructor +OpSequence::OpSequence(OpSequence &&other) noexcept: _ops(std::move(other._ops)) { + _idToIndex = std::move(other._idToIndex); + _indexToId = std::move(other._indexToId); +} + +OpSequence &OpSequence::operator=(OpSequence &&other) noexcept { + if (this == &other) return *this; + + _ops = std::move(other._ops); + _idToIndex = std::move(other._idToIndex); + _indexToId = std::move(other._indexToId); + + return *this; +} + +OpSequence &OpSequence::operator=(const OpSequence &other) noexcept { + if (this == &other) return *this; + + _ops.clear(); + for (const auto &v : other._ops) _ops.emplace_back(v); + + _idToIndex = other._idToIndex; + _indexToId = other._indexToId; + + return *this; +} + +void OpSequence::printOut() const { + for (const auto &v : _ops) v.printOut(); +} + +int OpSequence::deviceId() const { return _deviceId; } + +const ExecutionTask &OpSequence::at(uint64_t index) const { + return _ops[index]; +} + +ExecutionTask &OpSequence::at(uint64_t index) { + return _ops[index]; +} + +const ExecutionTask &OpSequence::operator[](uint64_t index) const { + return at(index); +} + +ExecutionTask &OpSequence::operator[](uint64_t index) { + return at(index); +} + +uint64_t OpSequence::length() const { return _ops.size(); } + +void OpSequence::append(const Node &node, + const sd::graph::ContextPrototype &ctx) { + ExecutionTask task(node, ctx); + append(task); +} + +void OpSequence::append(const ExecutionTask& task) { + _ops.emplace_back(task); + + // updating dictionaries + auto index = _ops.size() - 1; + _idToIndex[task.node().id()] = index; + _indexToId[index] = task.node().id(); +} + +void OpSequence::append(ExecutionTask&& task) { + _ops.emplace_back(std::move(task)); + + // updating dictionaries + auto index = _ops.size() - 1; + _idToIndex[task.node().id()] = index; + _indexToId[index] = task.node().id(); +} + +void OpSequence::append(const OpSequence &sequence) { + for (const auto &v:sequence._ops) { + this->append(v); + } +} + +int OpSequence::nodeId(int index) const { + if (index < 0 || index >= _ops.size() || _indexToId.count(index) < 1) + throw std::runtime_error("Out-of-size index requested: " + StringUtils::valueToString(index)); + + return _indexToId.at(index); +} + +int OpSequence::nodeIndex(int id) const { + if ( _idToIndex.count(id) < 1) + throw std::runtime_error("Unknown Node ID requested: " + StringUtils::valueToString(id)); + + return _idToIndex.at(id); +} + +bool OpSequence::hasNode(int id) const { + return _idToIndex.count(id) > 0; +} + +OpSequence::iterator OpSequence::begin() { + return OpSequence::iterator(*this, 0); +} + +OpSequence::iterator OpSequence::end() { + return OpSequence::iterator(*this, length()); +} + +OpSequence::iterator::iterator(OpSequence &container, uint64_t index) + : _container(container), _position(index) { + // +} + +const ExecutionTask &OpSequence::iterator::operator*() const { + return _container._ops[_position]; +} + +OpSequence::iterator &OpSequence::iterator::operator++() { + _position++; + return *this; +} + +OpSequence::iterator &OpSequence::iterator::operator++(int inc) { + return ++(*this); +} + +bool OpSequence::iterator::operator!=(const OpSequence::iterator &other) const { + return _position != other._position; +} + +Nd4jStatus OpSequence::wait() const { + // TODO: to be implemented + return Status::OK(); +} +} // namespace graph +} // namespace sd diff --git a/libnd4j/include/graph/execution/impl/Stack.cpp b/libnd4j/include/graph/execution/impl/Stack.cpp new file mode 100644 index 000000000000..127a45d3cbb5 --- /dev/null +++ b/libnd4j/include/graph/execution/impl/Stack.cpp @@ -0,0 +1,76 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include + +namespace sd { +namespace graph { + +Stack::Stack(const VariableProxy &root) { + _frames.push_back(StackFrame(const_cast(root), _counter++, -1, 0)); +} + +const VariableProxy &Stack::rootVariableSpace() const { + return _frames.front().variableProxy(); +} + +StackFrame &Stack::back() { + return _frames.back(); +} + +StackFrame &Stack::front() { + return _frames.front(); +} + +StackFrame &Stack::root() { + return _frames.front(); +} + +void Stack::openFrame(int frameId, int enterId) { + _frames.emplace_back(StackFrame(_frames.back().variableProxy(), _counter++, frameId, enterId, _frames.back())); + nd4j_printf("Opening frame [%i], parent: [%i]\n", _frames.back().id(), _frames.back().parent().id()); +} + +void Stack::iterateFrame(int frameId, int enterId) { + auto ¤t = this->back(); + auto &parent = current.parent(); + _frames.emplace_back(StackFrame(_frames.back().variableProxy(), _counter++, frameId, enterId, parent)); + nd4j_printf("Iterating frame, parent: [%i]\n", parent.id()); +} + +void Stack::closeFrame() { + // we should remove all frames untl we hit parent frame + auto &parent = this->back().parent(); + + nd4j_printf("Collapsed frame [%i], parent: [%i]\n", this->back().id(), parent.id()); + + while (!_frames.empty()) { + auto ¤t = this->back(); + + // if ID's match - we'll stop + if (current.id() == parent.id()) + break; + + _frames.pop_back(); + } +} + +} // namespace graph +} // namespace sd diff --git a/libnd4j/include/graph/execution/impl/StackFrame.cpp b/libnd4j/include/graph/execution/impl/StackFrame.cpp new file mode 100644 index 000000000000..115aa03c5b3b --- /dev/null +++ b/libnd4j/include/graph/execution/impl/StackFrame.cpp @@ -0,0 +1,72 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + + +#include + +namespace sd { +namespace graph { + +StackFrame::StackFrame(const VariableProxy &proxy, int id, int frameId, int enterId) + : _proxy(proxy), _frameId(frameId), _enterId(enterId), _id(id) { } + +StackFrame::StackFrame(const VariableProxy &proxy, int id, int frameId, int enterId, StackFrame &parent) + : StackFrame(proxy, id, frameId, enterId) { + _parent = &parent; +} + +void StackFrame::disableNode(int nodeId) { + _disabledNodes[nodeId] = 1; +} + +bool StackFrame::isDisabled(int nodeId) const { + return _disabledNodes.count(nodeId) > 0; +} + +int StackFrame::frameId() const { + return _frameId; +} + +int StackFrame::enterId() const { + return _enterId; +} + +int StackFrame::exitId() const { + return _exitId; +} + +void StackFrame::setExitId(int id) const { + _exitId = id; +} + +int StackFrame::rewindId() const { + return _rewindId; +} + +void StackFrame::setRewindId(int id) const { + _rewindId = id; +} + +StackFrame &StackFrame::parent() const { + return *_parent; +} + +} // namespace graph +} // namespace sd diff --git a/libnd4j/include/graph/generated/array_generated.h b/libnd4j/include/graph/generated/array_generated.h index 5c4c0d7af947..1d1feaefab9f 100644 --- a/libnd4j/include/graph/generated/array_generated.h +++ b/libnd4j/include/graph/generated/array_generated.h @@ -1,6 +1,5 @@ // automatically generated by the FlatBuffers compiler, do not modify - #ifndef FLATBUFFERS_GENERATED_ARRAY_ND4J_GRAPH_H_ #define FLATBUFFERS_GENERATED_ARRAY_ND4J_GRAPH_H_ @@ -19,19 +18,12 @@ enum ByteOrder { }; inline const ByteOrder (&EnumValuesByteOrder())[2] { - static const ByteOrder values[] = { - ByteOrder_LE, - ByteOrder_BE - }; + static const ByteOrder values[] = {ByteOrder_LE, ByteOrder_BE}; return values; } -inline const char * const *EnumNamesByteOrder() { - static const char * const names[] = { - "LE", - "BE", - nullptr - }; +inline const char *const *EnumNamesByteOrder() { + static const char *const names[] = {"LE", "BE", nullptr}; return names; } @@ -68,88 +60,24 @@ enum DType { inline const DType (&EnumValuesDType())[21] { static const DType values[] = { - DType_INHERIT, - DType_BOOL, - DType_FLOAT8, - DType_HALF, - DType_HALF2, - DType_FLOAT, - DType_DOUBLE, - DType_INT8, - DType_INT16, - DType_INT32, - DType_INT64, - DType_UINT8, - DType_UINT16, - DType_UINT32, - DType_UINT64, - DType_QINT8, - DType_QINT16, - DType_BFLOAT16, - DType_UTF8, - DType_UTF16, - DType_UTF32 - }; + DType_INHERIT, DType_BOOL, DType_FLOAT8, DType_HALF, DType_HALF2, + DType_FLOAT, DType_DOUBLE, DType_INT8, DType_INT16, DType_INT32, + DType_INT64, DType_UINT8, DType_UINT16, DType_UINT32, DType_UINT64, + DType_QINT8, DType_QINT16, DType_BFLOAT16, DType_UTF8, DType_UTF16, + DType_UTF32}; return values; } -inline const char * const *EnumNamesDType() { - static const char * const names[] = { - "INHERIT", - "BOOL", - "FLOAT8", - "HALF", - "HALF2", - "FLOAT", - "DOUBLE", - "INT8", - "INT16", - "INT32", - "INT64", - "UINT8", - "UINT16", - "UINT32", - "UINT64", - "QINT8", - "QINT16", - "BFLOAT16", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "UTF8", - "UTF16", - "UTF32", - nullptr - }; +inline const char *const *EnumNamesDType() { + static const char *const names[] = { + "INHERIT", "BOOL", "FLOAT8", "HALF", "HALF2", "FLOAT", "DOUBLE", + "INT8", "INT16", "INT32", "INT64", "UINT8", "UINT16", "UINT32", + "UINT64", "QINT8", "QINT16", "BFLOAT16", "", "", "", + "", "", "", "", "", "", "", + "", "", "", "", "", "", "", + "", "", "", "", "", "", "", + "", "", "", "", "", "", "", + "", "UTF8", "UTF16", "UTF32", nullptr}; return names; } @@ -159,12 +87,7 @@ inline const char *EnumNameDType(DType e) { } struct FlatArray FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { - enum { - VT_SHAPE = 4, - VT_BUFFER = 6, - VT_DTYPE = 8, - VT_BYTEORDER = 10 - }; + enum { VT_SHAPE = 4, VT_BUFFER = 6, VT_DTYPE = 8, VT_BYTEORDER = 10 }; const flatbuffers::Vector *shape() const { return GetPointer *>(VT_SHAPE); } @@ -178,14 +101,12 @@ struct FlatArray FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { return static_cast(GetField(VT_BYTEORDER, 0)); } bool Verify(flatbuffers::Verifier &verifier) const { - return VerifyTableStart(verifier) && - VerifyOffset(verifier, VT_SHAPE) && + return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_SHAPE) && verifier.VerifyVector(shape()) && VerifyOffset(verifier, VT_BUFFER) && verifier.VerifyVector(buffer()) && VerifyField(verifier, VT_DTYPE) && - VerifyField(verifier, VT_BYTEORDER) && - verifier.EndTable(); + VerifyField(verifier, VT_BYTEORDER) && verifier.EndTable(); } }; @@ -202,10 +123,10 @@ struct FlatArrayBuilder { fbb_.AddElement(FlatArray::VT_DTYPE, static_cast(dtype), 0); } void add_byteOrder(ByteOrder byteOrder) { - fbb_.AddElement(FlatArray::VT_BYTEORDER, static_cast(byteOrder), 0); + fbb_.AddElement(FlatArray::VT_BYTEORDER, + static_cast(byteOrder), 0); } - explicit FlatArrayBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + explicit FlatArrayBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); } FlatArrayBuilder &operator=(const FlatArrayBuilder &); @@ -220,8 +141,7 @@ inline flatbuffers::Offset CreateFlatArray( flatbuffers::FlatBufferBuilder &_fbb, flatbuffers::Offset> shape = 0, flatbuffers::Offset> buffer = 0, - DType dtype = DType_INHERIT, - ByteOrder byteOrder = ByteOrder_LE) { + DType dtype = DType_INHERIT, ByteOrder byteOrder = ByteOrder_LE) { FlatArrayBuilder builder_(_fbb); builder_.add_buffer(buffer); builder_.add_shape(shape); @@ -233,15 +153,11 @@ inline flatbuffers::Offset CreateFlatArray( inline flatbuffers::Offset CreateFlatArrayDirect( flatbuffers::FlatBufferBuilder &_fbb, const std::vector *shape = nullptr, - const std::vector *buffer = nullptr, - DType dtype = DType_INHERIT, + const std::vector *buffer = nullptr, DType dtype = DType_INHERIT, ByteOrder byteOrder = ByteOrder_LE) { return sd::graph::CreateFlatArray( - _fbb, - shape ? _fbb.CreateVector(*shape) : 0, - buffer ? _fbb.CreateVector(*buffer) : 0, - dtype, - byteOrder); + _fbb, shape ? _fbb.CreateVector(*shape) : 0, + buffer ? _fbb.CreateVector(*buffer) : 0, dtype, byteOrder); } inline const sd::graph::FlatArray *GetFlatArray(const void *buf) { @@ -252,13 +168,11 @@ inline const sd::graph::FlatArray *GetSizePrefixedFlatArray(const void *buf) { return flatbuffers::GetSizePrefixedRoot(buf); } -inline bool VerifyFlatArrayBuffer( - flatbuffers::Verifier &verifier) { +inline bool VerifyFlatArrayBuffer(flatbuffers::Verifier &verifier) { return verifier.VerifyBuffer(nullptr); } -inline bool VerifySizePrefixedFlatArrayBuffer( - flatbuffers::Verifier &verifier) { +inline bool VerifySizePrefixedFlatArrayBuffer(flatbuffers::Verifier &verifier) { return verifier.VerifySizePrefixedBuffer(nullptr); } diff --git a/libnd4j/include/graph/generated/config_generated.h b/libnd4j/include/graph/generated/config_generated.h index 2c12027a2b0c..505fd5767056 100644 --- a/libnd4j/include/graph/generated/config_generated.h +++ b/libnd4j/include/graph/generated/config_generated.h @@ -1,6 +1,5 @@ // automatically generated by the FlatBuffers compiler, do not modify - #ifndef FLATBUFFERS_GENERATED_CONFIG_ND4J_GRAPH_H_ #define FLATBUFFERS_GENERATED_CONFIG_ND4J_GRAPH_H_ @@ -22,22 +21,14 @@ enum ProfilingMode { inline const ProfilingMode (&EnumValuesProfilingMode())[4] { static const ProfilingMode values[] = { - ProfilingMode_NONE, - ProfilingMode_NAN_PANIC, - ProfilingMode_INF_PANIC, - ProfilingMode_ANY_PANIC - }; + ProfilingMode_NONE, ProfilingMode_NAN_PANIC, ProfilingMode_INF_PANIC, + ProfilingMode_ANY_PANIC}; return values; } -inline const char * const *EnumNamesProfilingMode() { - static const char * const names[] = { - "NONE", - "NAN_PANIC", - "INF_PANIC", - "ANY_PANIC", - nullptr - }; +inline const char *const *EnumNamesProfilingMode() { + static const char *const names[] = {"NONE", "NAN_PANIC", "INF_PANIC", + "ANY_PANIC", nullptr}; return names; } @@ -56,20 +47,12 @@ enum ExecutionMode { inline const ExecutionMode (&EnumValuesExecutionMode())[3] { static const ExecutionMode values[] = { - ExecutionMode_SEQUENTIAL, - ExecutionMode_STRICT, - ExecutionMode_AUTO - }; + ExecutionMode_SEQUENTIAL, ExecutionMode_STRICT, ExecutionMode_AUTO}; return values; } -inline const char * const *EnumNamesExecutionMode() { - static const char * const names[] = { - "SEQUENTIAL", - "STRICT", - "AUTO", - nullptr - }; +inline const char *const *EnumNamesExecutionMode() { + static const char *const names[] = {"SEQUENTIAL", "STRICT", "AUTO", nullptr}; return names; } @@ -89,25 +72,17 @@ enum OutputMode { }; inline const OutputMode (&EnumValuesOutputMode())[5] { - static const OutputMode values[] = { - OutputMode_IMPLICIT, - OutputMode_EXPLICIT, - OutputMode_EXPLICIT_AND_IMPLICIT, - OutputMode_VARIABLE_SPACE, - OutputMode_OPTIMIZED - }; + static const OutputMode values[] = {OutputMode_IMPLICIT, OutputMode_EXPLICIT, + OutputMode_EXPLICIT_AND_IMPLICIT, + OutputMode_VARIABLE_SPACE, + OutputMode_OPTIMIZED}; return values; } -inline const char * const *EnumNamesOutputMode() { - static const char * const names[] = { - "IMPLICIT", - "EXPLICIT", - "EXPLICIT_AND_IMPLICIT", - "VARIABLE_SPACE", - "OPTIMIZED", - nullptr - }; +inline const char *const *EnumNamesOutputMode() { + static const char *const names[] = { + "IMPLICIT", "EXPLICIT", "EXPLICIT_AND_IMPLICIT", + "VARIABLE_SPACE", "OPTIMIZED", nullptr}; return names; } @@ -125,21 +100,15 @@ enum Direction { }; inline const Direction (&EnumValuesDirection())[3] { - static const Direction values[] = { - Direction_FORWARD_ONLY, - Direction_FORWARD_AND_BACKWARD, - Direction_BACKWARD_ONLY - }; + static const Direction values[] = {Direction_FORWARD_ONLY, + Direction_FORWARD_AND_BACKWARD, + Direction_BACKWARD_ONLY}; return values; } -inline const char * const *EnumNamesDirection() { - static const char * const names[] = { - "FORWARD_ONLY", - "FORWARD_AND_BACKWARD", - "BACKWARD_ONLY", - nullptr - }; +inline const char *const *EnumNamesDirection() { + static const char *const names[] = {"FORWARD_ONLY", "FORWARD_AND_BACKWARD", + "BACKWARD_ONLY", nullptr}; return names; } @@ -159,9 +128,7 @@ struct FlatConfiguration FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { VT_FOOTPRINTBACKWARD = 16, VT_DIRECTION = 18 }; - int64_t id() const { - return GetField(VT_ID, 0); - } + int64_t id() const { return GetField(VT_ID, 0); } ExecutionMode executionMode() const { return static_cast(GetField(VT_EXECUTIONMODE, 0)); } @@ -171,9 +138,7 @@ struct FlatConfiguration FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { OutputMode outputMode() const { return static_cast(GetField(VT_OUTPUTMODE, 0)); } - bool timestats() const { - return GetField(VT_TIMESTATS, 0) != 0; - } + bool timestats() const { return GetField(VT_TIMESTATS, 0) != 0; } int64_t footprintForward() const { return GetField(VT_FOOTPRINTFORWARD, 0); } @@ -192,8 +157,7 @@ struct FlatConfiguration FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { VerifyField(verifier, VT_TIMESTATS) && VerifyField(verifier, VT_FOOTPRINTFORWARD) && VerifyField(verifier, VT_FOOTPRINTBACKWARD) && - VerifyField(verifier, VT_DIRECTION) && - verifier.EndTable(); + VerifyField(verifier, VT_DIRECTION) && verifier.EndTable(); } }; @@ -204,28 +168,35 @@ struct FlatConfigurationBuilder { fbb_.AddElement(FlatConfiguration::VT_ID, id, 0); } void add_executionMode(ExecutionMode executionMode) { - fbb_.AddElement(FlatConfiguration::VT_EXECUTIONMODE, static_cast(executionMode), 0); + fbb_.AddElement(FlatConfiguration::VT_EXECUTIONMODE, + static_cast(executionMode), 0); } void add_profilingMode(ProfilingMode profilingMode) { - fbb_.AddElement(FlatConfiguration::VT_PROFILINGMODE, static_cast(profilingMode), 0); + fbb_.AddElement(FlatConfiguration::VT_PROFILINGMODE, + static_cast(profilingMode), 0); } void add_outputMode(OutputMode outputMode) { - fbb_.AddElement(FlatConfiguration::VT_OUTPUTMODE, static_cast(outputMode), 0); + fbb_.AddElement(FlatConfiguration::VT_OUTPUTMODE, + static_cast(outputMode), 0); } void add_timestats(bool timestats) { - fbb_.AddElement(FlatConfiguration::VT_TIMESTATS, static_cast(timestats), 0); + fbb_.AddElement(FlatConfiguration::VT_TIMESTATS, + static_cast(timestats), 0); } void add_footprintForward(int64_t footprintForward) { - fbb_.AddElement(FlatConfiguration::VT_FOOTPRINTFORWARD, footprintForward, 0); + fbb_.AddElement(FlatConfiguration::VT_FOOTPRINTFORWARD, + footprintForward, 0); } void add_footprintBackward(int64_t footprintBackward) { - fbb_.AddElement(FlatConfiguration::VT_FOOTPRINTBACKWARD, footprintBackward, 0); + fbb_.AddElement(FlatConfiguration::VT_FOOTPRINTBACKWARD, + footprintBackward, 0); } void add_direction(Direction direction) { - fbb_.AddElement(FlatConfiguration::VT_DIRECTION, static_cast(direction), 0); + fbb_.AddElement(FlatConfiguration::VT_DIRECTION, + static_cast(direction), 0); } explicit FlatConfigurationBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + : fbb_(_fbb) { start_ = fbb_.StartTable(); } FlatConfigurationBuilder &operator=(const FlatConfigurationBuilder &); @@ -237,14 +208,11 @@ struct FlatConfigurationBuilder { }; inline flatbuffers::Offset CreateFlatConfiguration( - flatbuffers::FlatBufferBuilder &_fbb, - int64_t id = 0, + flatbuffers::FlatBufferBuilder &_fbb, int64_t id = 0, ExecutionMode executionMode = ExecutionMode_SEQUENTIAL, ProfilingMode profilingMode = ProfilingMode_NONE, - OutputMode outputMode = OutputMode_IMPLICIT, - bool timestats = false, - int64_t footprintForward = 0, - int64_t footprintBackward = 0, + OutputMode outputMode = OutputMode_IMPLICIT, bool timestats = false, + int64_t footprintForward = 0, int64_t footprintBackward = 0, Direction direction = Direction_FORWARD_ONLY) { FlatConfigurationBuilder builder_(_fbb); builder_.add_footprintBackward(footprintBackward); @@ -258,22 +226,24 @@ inline flatbuffers::Offset CreateFlatConfiguration( return builder_.Finish(); } -inline const sd::graph::FlatConfiguration *GetFlatConfiguration(const void *buf) { +inline const sd::graph::FlatConfiguration *GetFlatConfiguration( + const void *buf) { return flatbuffers::GetRoot(buf); } -inline const sd::graph::FlatConfiguration *GetSizePrefixedFlatConfiguration(const void *buf) { +inline const sd::graph::FlatConfiguration *GetSizePrefixedFlatConfiguration( + const void *buf) { return flatbuffers::GetSizePrefixedRoot(buf); } -inline bool VerifyFlatConfigurationBuffer( - flatbuffers::Verifier &verifier) { +inline bool VerifyFlatConfigurationBuffer(flatbuffers::Verifier &verifier) { return verifier.VerifyBuffer(nullptr); } inline bool VerifySizePrefixedFlatConfigurationBuffer( flatbuffers::Verifier &verifier) { - return verifier.VerifySizePrefixedBuffer(nullptr); + return verifier.VerifySizePrefixedBuffer( + nullptr); } inline void FinishFlatConfigurationBuffer( diff --git a/libnd4j/include/graph/generated/graph.grpc.fb.h b/libnd4j/include/graph/generated/graph.grpc.fb.h index 0167f48d568c..7f9e27d42163 100644 --- a/libnd4j/include/graph/generated/graph.grpc.fb.h +++ b/libnd4j/include/graph/generated/graph.grpc.fb.h @@ -4,9 +4,6 @@ #ifndef GRPC_graph__INCLUDED #define GRPC_graph__INCLUDED -#include "graph_generated.h" -#include "flatbuffers/grpc.h" - #include #include #include @@ -17,6 +14,9 @@ #include #include +#include "flatbuffers/grpc.h" +#include "graph_generated.h" + namespace grpc { class CompletionQueue; class Channel; @@ -35,196 +35,449 @@ class GraphInferenceServer final { class StubInterface { public: virtual ~StubInterface() {} - virtual ::grpc::Status RegisterGraph(::grpc::ClientContext* context, const flatbuffers::grpc::Message& request, flatbuffers::grpc::Message* response) = 0; - std::unique_ptr< ::grpc::ClientAsyncResponseReaderInterface< flatbuffers::grpc::Message>> AsyncRegisterGraph(::grpc::ClientContext* context, const flatbuffers::grpc::Message& request, ::grpc::CompletionQueue* cq) { - return std::unique_ptr< ::grpc::ClientAsyncResponseReaderInterface< flatbuffers::grpc::Message>>(AsyncRegisterGraphRaw(context, request, cq)); - } - std::unique_ptr< ::grpc::ClientAsyncResponseReaderInterface< flatbuffers::grpc::Message>> PrepareAsyncRegisterGraph(::grpc::ClientContext* context, const flatbuffers::grpc::Message& request, ::grpc::CompletionQueue* cq) { - return std::unique_ptr< ::grpc::ClientAsyncResponseReaderInterface< flatbuffers::grpc::Message>>(PrepareAsyncRegisterGraphRaw(context, request, cq)); - } - virtual ::grpc::Status ForgetGraph(::grpc::ClientContext* context, const flatbuffers::grpc::Message& request, flatbuffers::grpc::Message* response) = 0; - std::unique_ptr< ::grpc::ClientAsyncResponseReaderInterface< flatbuffers::grpc::Message>> AsyncForgetGraph(::grpc::ClientContext* context, const flatbuffers::grpc::Message& request, ::grpc::CompletionQueue* cq) { - return std::unique_ptr< ::grpc::ClientAsyncResponseReaderInterface< flatbuffers::grpc::Message>>(AsyncForgetGraphRaw(context, request, cq)); - } - std::unique_ptr< ::grpc::ClientAsyncResponseReaderInterface< flatbuffers::grpc::Message>> PrepareAsyncForgetGraph(::grpc::ClientContext* context, const flatbuffers::grpc::Message& request, ::grpc::CompletionQueue* cq) { - return std::unique_ptr< ::grpc::ClientAsyncResponseReaderInterface< flatbuffers::grpc::Message>>(PrepareAsyncForgetGraphRaw(context, request, cq)); - } - virtual ::grpc::Status ReplaceGraph(::grpc::ClientContext* context, const flatbuffers::grpc::Message& request, flatbuffers::grpc::Message* response) = 0; - std::unique_ptr< ::grpc::ClientAsyncResponseReaderInterface< flatbuffers::grpc::Message>> AsyncReplaceGraph(::grpc::ClientContext* context, const flatbuffers::grpc::Message& request, ::grpc::CompletionQueue* cq) { - return std::unique_ptr< ::grpc::ClientAsyncResponseReaderInterface< flatbuffers::grpc::Message>>(AsyncReplaceGraphRaw(context, request, cq)); - } - std::unique_ptr< ::grpc::ClientAsyncResponseReaderInterface< flatbuffers::grpc::Message>> PrepareAsyncReplaceGraph(::grpc::ClientContext* context, const flatbuffers::grpc::Message& request, ::grpc::CompletionQueue* cq) { - return std::unique_ptr< ::grpc::ClientAsyncResponseReaderInterface< flatbuffers::grpc::Message>>(PrepareAsyncReplaceGraphRaw(context, request, cq)); - } - virtual ::grpc::Status InferenceRequest(::grpc::ClientContext* context, const flatbuffers::grpc::Message& request, flatbuffers::grpc::Message* response) = 0; - std::unique_ptr< ::grpc::ClientAsyncResponseReaderInterface< flatbuffers::grpc::Message>> AsyncInferenceRequest(::grpc::ClientContext* context, const flatbuffers::grpc::Message& request, ::grpc::CompletionQueue* cq) { - return std::unique_ptr< ::grpc::ClientAsyncResponseReaderInterface< flatbuffers::grpc::Message>>(AsyncInferenceRequestRaw(context, request, cq)); - } - std::unique_ptr< ::grpc::ClientAsyncResponseReaderInterface< flatbuffers::grpc::Message>> PrepareAsyncInferenceRequest(::grpc::ClientContext* context, const flatbuffers::grpc::Message& request, ::grpc::CompletionQueue* cq) { - return std::unique_ptr< ::grpc::ClientAsyncResponseReaderInterface< flatbuffers::grpc::Message>>(PrepareAsyncInferenceRequestRaw(context, request, cq)); - } - private: - virtual ::grpc::ClientAsyncResponseReaderInterface< flatbuffers::grpc::Message>* AsyncRegisterGraphRaw(::grpc::ClientContext* context, const flatbuffers::grpc::Message& request, ::grpc::CompletionQueue* cq) = 0; - virtual ::grpc::ClientAsyncResponseReaderInterface< flatbuffers::grpc::Message>* PrepareAsyncRegisterGraphRaw(::grpc::ClientContext* context, const flatbuffers::grpc::Message& request, ::grpc::CompletionQueue* cq) = 0; - virtual ::grpc::ClientAsyncResponseReaderInterface< flatbuffers::grpc::Message>* AsyncForgetGraphRaw(::grpc::ClientContext* context, const flatbuffers::grpc::Message& request, ::grpc::CompletionQueue* cq) = 0; - virtual ::grpc::ClientAsyncResponseReaderInterface< flatbuffers::grpc::Message>* PrepareAsyncForgetGraphRaw(::grpc::ClientContext* context, const flatbuffers::grpc::Message& request, ::grpc::CompletionQueue* cq) = 0; - virtual ::grpc::ClientAsyncResponseReaderInterface< flatbuffers::grpc::Message>* AsyncReplaceGraphRaw(::grpc::ClientContext* context, const flatbuffers::grpc::Message& request, ::grpc::CompletionQueue* cq) = 0; - virtual ::grpc::ClientAsyncResponseReaderInterface< flatbuffers::grpc::Message>* PrepareAsyncReplaceGraphRaw(::grpc::ClientContext* context, const flatbuffers::grpc::Message& request, ::grpc::CompletionQueue* cq) = 0; - virtual ::grpc::ClientAsyncResponseReaderInterface< flatbuffers::grpc::Message>* AsyncInferenceRequestRaw(::grpc::ClientContext* context, const flatbuffers::grpc::Message& request, ::grpc::CompletionQueue* cq) = 0; - virtual ::grpc::ClientAsyncResponseReaderInterface< flatbuffers::grpc::Message>* PrepareAsyncInferenceRequestRaw(::grpc::ClientContext* context, const flatbuffers::grpc::Message& request, ::grpc::CompletionQueue* cq) = 0; + virtual ::grpc::Status RegisterGraph( + ::grpc::ClientContext* context, + const flatbuffers::grpc::Message& request, + flatbuffers::grpc::Message* response) = 0; + std::unique_ptr<::grpc::ClientAsyncResponseReaderInterface< + flatbuffers::grpc::Message>> + AsyncRegisterGraph(::grpc::ClientContext* context, + const flatbuffers::grpc::Message& request, + ::grpc::CompletionQueue* cq) { + return std::unique_ptr<::grpc::ClientAsyncResponseReaderInterface< + flatbuffers::grpc::Message>>( + AsyncRegisterGraphRaw(context, request, cq)); + } + std::unique_ptr<::grpc::ClientAsyncResponseReaderInterface< + flatbuffers::grpc::Message>> + PrepareAsyncRegisterGraph( + ::grpc::ClientContext* context, + const flatbuffers::grpc::Message& request, + ::grpc::CompletionQueue* cq) { + return std::unique_ptr<::grpc::ClientAsyncResponseReaderInterface< + flatbuffers::grpc::Message>>( + PrepareAsyncRegisterGraphRaw(context, request, cq)); + } + virtual ::grpc::Status ForgetGraph( + ::grpc::ClientContext* context, + const flatbuffers::grpc::Message& request, + flatbuffers::grpc::Message* response) = 0; + std::unique_ptr<::grpc::ClientAsyncResponseReaderInterface< + flatbuffers::grpc::Message>> + AsyncForgetGraph(::grpc::ClientContext* context, + const flatbuffers::grpc::Message& request, + ::grpc::CompletionQueue* cq) { + return std::unique_ptr<::grpc::ClientAsyncResponseReaderInterface< + flatbuffers::grpc::Message>>( + AsyncForgetGraphRaw(context, request, cq)); + } + std::unique_ptr<::grpc::ClientAsyncResponseReaderInterface< + flatbuffers::grpc::Message>> + PrepareAsyncForgetGraph( + ::grpc::ClientContext* context, + const flatbuffers::grpc::Message& request, + ::grpc::CompletionQueue* cq) { + return std::unique_ptr<::grpc::ClientAsyncResponseReaderInterface< + flatbuffers::grpc::Message>>( + PrepareAsyncForgetGraphRaw(context, request, cq)); + } + virtual ::grpc::Status ReplaceGraph( + ::grpc::ClientContext* context, + const flatbuffers::grpc::Message& request, + flatbuffers::grpc::Message* response) = 0; + std::unique_ptr<::grpc::ClientAsyncResponseReaderInterface< + flatbuffers::grpc::Message>> + AsyncReplaceGraph(::grpc::ClientContext* context, + const flatbuffers::grpc::Message& request, + ::grpc::CompletionQueue* cq) { + return std::unique_ptr<::grpc::ClientAsyncResponseReaderInterface< + flatbuffers::grpc::Message>>( + AsyncReplaceGraphRaw(context, request, cq)); + } + std::unique_ptr<::grpc::ClientAsyncResponseReaderInterface< + flatbuffers::grpc::Message>> + PrepareAsyncReplaceGraph( + ::grpc::ClientContext* context, + const flatbuffers::grpc::Message& request, + ::grpc::CompletionQueue* cq) { + return std::unique_ptr<::grpc::ClientAsyncResponseReaderInterface< + flatbuffers::grpc::Message>>( + PrepareAsyncReplaceGraphRaw(context, request, cq)); + } + virtual ::grpc::Status InferenceRequest( + ::grpc::ClientContext* context, + const flatbuffers::grpc::Message& request, + flatbuffers::grpc::Message* response) = 0; + std::unique_ptr<::grpc::ClientAsyncResponseReaderInterface< + flatbuffers::grpc::Message>> + AsyncInferenceRequest( + ::grpc::ClientContext* context, + const flatbuffers::grpc::Message& request, + ::grpc::CompletionQueue* cq) { + return std::unique_ptr<::grpc::ClientAsyncResponseReaderInterface< + flatbuffers::grpc::Message>>( + AsyncInferenceRequestRaw(context, request, cq)); + } + std::unique_ptr<::grpc::ClientAsyncResponseReaderInterface< + flatbuffers::grpc::Message>> + PrepareAsyncInferenceRequest( + ::grpc::ClientContext* context, + const flatbuffers::grpc::Message& request, + ::grpc::CompletionQueue* cq) { + return std::unique_ptr<::grpc::ClientAsyncResponseReaderInterface< + flatbuffers::grpc::Message>>( + PrepareAsyncInferenceRequestRaw(context, request, cq)); + } + + private: + virtual ::grpc::ClientAsyncResponseReaderInterface< + flatbuffers::grpc::Message>* + AsyncRegisterGraphRaw(::grpc::ClientContext* context, + const flatbuffers::grpc::Message& request, + ::grpc::CompletionQueue* cq) = 0; + virtual ::grpc::ClientAsyncResponseReaderInterface< + flatbuffers::grpc::Message>* + PrepareAsyncRegisterGraphRaw( + ::grpc::ClientContext* context, + const flatbuffers::grpc::Message& request, + ::grpc::CompletionQueue* cq) = 0; + virtual ::grpc::ClientAsyncResponseReaderInterface< + flatbuffers::grpc::Message>* + AsyncForgetGraphRaw( + ::grpc::ClientContext* context, + const flatbuffers::grpc::Message& request, + ::grpc::CompletionQueue* cq) = 0; + virtual ::grpc::ClientAsyncResponseReaderInterface< + flatbuffers::grpc::Message>* + PrepareAsyncForgetGraphRaw( + ::grpc::ClientContext* context, + const flatbuffers::grpc::Message& request, + ::grpc::CompletionQueue* cq) = 0; + virtual ::grpc::ClientAsyncResponseReaderInterface< + flatbuffers::grpc::Message>* + AsyncReplaceGraphRaw(::grpc::ClientContext* context, + const flatbuffers::grpc::Message& request, + ::grpc::CompletionQueue* cq) = 0; + virtual ::grpc::ClientAsyncResponseReaderInterface< + flatbuffers::grpc::Message>* + PrepareAsyncReplaceGraphRaw( + ::grpc::ClientContext* context, + const flatbuffers::grpc::Message& request, + ::grpc::CompletionQueue* cq) = 0; + virtual ::grpc::ClientAsyncResponseReaderInterface< + flatbuffers::grpc::Message>* + AsyncInferenceRequestRaw( + ::grpc::ClientContext* context, + const flatbuffers::grpc::Message& request, + ::grpc::CompletionQueue* cq) = 0; + virtual ::grpc::ClientAsyncResponseReaderInterface< + flatbuffers::grpc::Message>* + PrepareAsyncInferenceRequestRaw( + ::grpc::ClientContext* context, + const flatbuffers::grpc::Message& request, + ::grpc::CompletionQueue* cq) = 0; }; class Stub final : public StubInterface { public: - Stub(const std::shared_ptr< ::grpc::ChannelInterface>& channel); - ::grpc::Status RegisterGraph(::grpc::ClientContext* context, const flatbuffers::grpc::Message& request, flatbuffers::grpc::Message* response) override; - std::unique_ptr< ::grpc::ClientAsyncResponseReader< flatbuffers::grpc::Message>> AsyncRegisterGraph(::grpc::ClientContext* context, const flatbuffers::grpc::Message& request, ::grpc::CompletionQueue* cq) { - return std::unique_ptr< ::grpc::ClientAsyncResponseReader< flatbuffers::grpc::Message>>(AsyncRegisterGraphRaw(context, request, cq)); - } - std::unique_ptr< ::grpc::ClientAsyncResponseReader< flatbuffers::grpc::Message>> PrepareAsyncRegisterGraph(::grpc::ClientContext* context, const flatbuffers::grpc::Message& request, ::grpc::CompletionQueue* cq) { - return std::unique_ptr< ::grpc::ClientAsyncResponseReader< flatbuffers::grpc::Message>>(PrepareAsyncRegisterGraphRaw(context, request, cq)); + Stub(const std::shared_ptr<::grpc::ChannelInterface>& channel); + ::grpc::Status RegisterGraph( + ::grpc::ClientContext* context, + const flatbuffers::grpc::Message& request, + flatbuffers::grpc::Message* response) override; + std::unique_ptr<::grpc::ClientAsyncResponseReader< + flatbuffers::grpc::Message>> + AsyncRegisterGraph(::grpc::ClientContext* context, + const flatbuffers::grpc::Message& request, + ::grpc::CompletionQueue* cq) { + return std::unique_ptr<::grpc::ClientAsyncResponseReader< + flatbuffers::grpc::Message>>( + AsyncRegisterGraphRaw(context, request, cq)); + } + std::unique_ptr<::grpc::ClientAsyncResponseReader< + flatbuffers::grpc::Message>> + PrepareAsyncRegisterGraph( + ::grpc::ClientContext* context, + const flatbuffers::grpc::Message& request, + ::grpc::CompletionQueue* cq) { + return std::unique_ptr<::grpc::ClientAsyncResponseReader< + flatbuffers::grpc::Message>>( + PrepareAsyncRegisterGraphRaw(context, request, cq)); + } + ::grpc::Status ForgetGraph( + ::grpc::ClientContext* context, + const flatbuffers::grpc::Message& request, + flatbuffers::grpc::Message* response) override; + std::unique_ptr<::grpc::ClientAsyncResponseReader< + flatbuffers::grpc::Message>> + AsyncForgetGraph(::grpc::ClientContext* context, + const flatbuffers::grpc::Message& request, + ::grpc::CompletionQueue* cq) { + return std::unique_ptr<::grpc::ClientAsyncResponseReader< + flatbuffers::grpc::Message>>( + AsyncForgetGraphRaw(context, request, cq)); + } + std::unique_ptr<::grpc::ClientAsyncResponseReader< + flatbuffers::grpc::Message>> + PrepareAsyncForgetGraph( + ::grpc::ClientContext* context, + const flatbuffers::grpc::Message& request, + ::grpc::CompletionQueue* cq) { + return std::unique_ptr<::grpc::ClientAsyncResponseReader< + flatbuffers::grpc::Message>>( + PrepareAsyncForgetGraphRaw(context, request, cq)); + } + ::grpc::Status ReplaceGraph( + ::grpc::ClientContext* context, + const flatbuffers::grpc::Message& request, + flatbuffers::grpc::Message* response) override; + std::unique_ptr<::grpc::ClientAsyncResponseReader< + flatbuffers::grpc::Message>> + AsyncReplaceGraph(::grpc::ClientContext* context, + const flatbuffers::grpc::Message& request, + ::grpc::CompletionQueue* cq) { + return std::unique_ptr<::grpc::ClientAsyncResponseReader< + flatbuffers::grpc::Message>>( + AsyncReplaceGraphRaw(context, request, cq)); + } + std::unique_ptr<::grpc::ClientAsyncResponseReader< + flatbuffers::grpc::Message>> + PrepareAsyncReplaceGraph( + ::grpc::ClientContext* context, + const flatbuffers::grpc::Message& request, + ::grpc::CompletionQueue* cq) { + return std::unique_ptr<::grpc::ClientAsyncResponseReader< + flatbuffers::grpc::Message>>( + PrepareAsyncReplaceGraphRaw(context, request, cq)); + } + ::grpc::Status InferenceRequest( + ::grpc::ClientContext* context, + const flatbuffers::grpc::Message& request, + flatbuffers::grpc::Message* response) override; + std::unique_ptr<::grpc::ClientAsyncResponseReader< + flatbuffers::grpc::Message>> + AsyncInferenceRequest( + ::grpc::ClientContext* context, + const flatbuffers::grpc::Message& request, + ::grpc::CompletionQueue* cq) { + return std::unique_ptr<::grpc::ClientAsyncResponseReader< + flatbuffers::grpc::Message>>( + AsyncInferenceRequestRaw(context, request, cq)); + } + std::unique_ptr<::grpc::ClientAsyncResponseReader< + flatbuffers::grpc::Message>> + PrepareAsyncInferenceRequest( + ::grpc::ClientContext* context, + const flatbuffers::grpc::Message& request, + ::grpc::CompletionQueue* cq) { + return std::unique_ptr<::grpc::ClientAsyncResponseReader< + flatbuffers::grpc::Message>>( + PrepareAsyncInferenceRequestRaw(context, request, cq)); } - ::grpc::Status ForgetGraph(::grpc::ClientContext* context, const flatbuffers::grpc::Message& request, flatbuffers::grpc::Message* response) override; - std::unique_ptr< ::grpc::ClientAsyncResponseReader< flatbuffers::grpc::Message>> AsyncForgetGraph(::grpc::ClientContext* context, const flatbuffers::grpc::Message& request, ::grpc::CompletionQueue* cq) { - return std::unique_ptr< ::grpc::ClientAsyncResponseReader< flatbuffers::grpc::Message>>(AsyncForgetGraphRaw(context, request, cq)); - } - std::unique_ptr< ::grpc::ClientAsyncResponseReader< flatbuffers::grpc::Message>> PrepareAsyncForgetGraph(::grpc::ClientContext* context, const flatbuffers::grpc::Message& request, ::grpc::CompletionQueue* cq) { - return std::unique_ptr< ::grpc::ClientAsyncResponseReader< flatbuffers::grpc::Message>>(PrepareAsyncForgetGraphRaw(context, request, cq)); - } - ::grpc::Status ReplaceGraph(::grpc::ClientContext* context, const flatbuffers::grpc::Message& request, flatbuffers::grpc::Message* response) override; - std::unique_ptr< ::grpc::ClientAsyncResponseReader< flatbuffers::grpc::Message>> AsyncReplaceGraph(::grpc::ClientContext* context, const flatbuffers::grpc::Message& request, ::grpc::CompletionQueue* cq) { - return std::unique_ptr< ::grpc::ClientAsyncResponseReader< flatbuffers::grpc::Message>>(AsyncReplaceGraphRaw(context, request, cq)); - } - std::unique_ptr< ::grpc::ClientAsyncResponseReader< flatbuffers::grpc::Message>> PrepareAsyncReplaceGraph(::grpc::ClientContext* context, const flatbuffers::grpc::Message& request, ::grpc::CompletionQueue* cq) { - return std::unique_ptr< ::grpc::ClientAsyncResponseReader< flatbuffers::grpc::Message>>(PrepareAsyncReplaceGraphRaw(context, request, cq)); - } - ::grpc::Status InferenceRequest(::grpc::ClientContext* context, const flatbuffers::grpc::Message& request, flatbuffers::grpc::Message* response) override; - std::unique_ptr< ::grpc::ClientAsyncResponseReader< flatbuffers::grpc::Message>> AsyncInferenceRequest(::grpc::ClientContext* context, const flatbuffers::grpc::Message& request, ::grpc::CompletionQueue* cq) { - return std::unique_ptr< ::grpc::ClientAsyncResponseReader< flatbuffers::grpc::Message>>(AsyncInferenceRequestRaw(context, request, cq)); - } - std::unique_ptr< ::grpc::ClientAsyncResponseReader< flatbuffers::grpc::Message>> PrepareAsyncInferenceRequest(::grpc::ClientContext* context, const flatbuffers::grpc::Message& request, ::grpc::CompletionQueue* cq) { - return std::unique_ptr< ::grpc::ClientAsyncResponseReader< flatbuffers::grpc::Message>>(PrepareAsyncInferenceRequestRaw(context, request, cq)); - } - + private: - std::shared_ptr< ::grpc::ChannelInterface> channel_; - ::grpc::ClientAsyncResponseReader< flatbuffers::grpc::Message>* AsyncRegisterGraphRaw(::grpc::ClientContext* context, const flatbuffers::grpc::Message& request, ::grpc::CompletionQueue* cq) override; - ::grpc::ClientAsyncResponseReader< flatbuffers::grpc::Message>* PrepareAsyncRegisterGraphRaw(::grpc::ClientContext* context, const flatbuffers::grpc::Message& request, ::grpc::CompletionQueue* cq) override; - ::grpc::ClientAsyncResponseReader< flatbuffers::grpc::Message>* AsyncForgetGraphRaw(::grpc::ClientContext* context, const flatbuffers::grpc::Message& request, ::grpc::CompletionQueue* cq) override; - ::grpc::ClientAsyncResponseReader< flatbuffers::grpc::Message>* PrepareAsyncForgetGraphRaw(::grpc::ClientContext* context, const flatbuffers::grpc::Message& request, ::grpc::CompletionQueue* cq) override; - ::grpc::ClientAsyncResponseReader< flatbuffers::grpc::Message>* AsyncReplaceGraphRaw(::grpc::ClientContext* context, const flatbuffers::grpc::Message& request, ::grpc::CompletionQueue* cq) override; - ::grpc::ClientAsyncResponseReader< flatbuffers::grpc::Message>* PrepareAsyncReplaceGraphRaw(::grpc::ClientContext* context, const flatbuffers::grpc::Message& request, ::grpc::CompletionQueue* cq) override; - ::grpc::ClientAsyncResponseReader< flatbuffers::grpc::Message>* AsyncInferenceRequestRaw(::grpc::ClientContext* context, const flatbuffers::grpc::Message& request, ::grpc::CompletionQueue* cq) override; - ::grpc::ClientAsyncResponseReader< flatbuffers::grpc::Message>* PrepareAsyncInferenceRequestRaw(::grpc::ClientContext* context, const flatbuffers::grpc::Message& request, ::grpc::CompletionQueue* cq) override; + std::shared_ptr<::grpc::ChannelInterface> channel_; + ::grpc::ClientAsyncResponseReader>* + AsyncRegisterGraphRaw(::grpc::ClientContext* context, + const flatbuffers::grpc::Message& request, + ::grpc::CompletionQueue* cq) override; + ::grpc::ClientAsyncResponseReader>* + PrepareAsyncRegisterGraphRaw( + ::grpc::ClientContext* context, + const flatbuffers::grpc::Message& request, + ::grpc::CompletionQueue* cq) override; + ::grpc::ClientAsyncResponseReader>* + AsyncForgetGraphRaw( + ::grpc::ClientContext* context, + const flatbuffers::grpc::Message& request, + ::grpc::CompletionQueue* cq) override; + ::grpc::ClientAsyncResponseReader>* + PrepareAsyncForgetGraphRaw( + ::grpc::ClientContext* context, + const flatbuffers::grpc::Message& request, + ::grpc::CompletionQueue* cq) override; + ::grpc::ClientAsyncResponseReader>* + AsyncReplaceGraphRaw(::grpc::ClientContext* context, + const flatbuffers::grpc::Message& request, + ::grpc::CompletionQueue* cq) override; + ::grpc::ClientAsyncResponseReader>* + PrepareAsyncReplaceGraphRaw( + ::grpc::ClientContext* context, + const flatbuffers::grpc::Message& request, + ::grpc::CompletionQueue* cq) override; + ::grpc::ClientAsyncResponseReader>* + AsyncInferenceRequestRaw( + ::grpc::ClientContext* context, + const flatbuffers::grpc::Message& request, + ::grpc::CompletionQueue* cq) override; + ::grpc::ClientAsyncResponseReader>* + PrepareAsyncInferenceRequestRaw( + ::grpc::ClientContext* context, + const flatbuffers::grpc::Message& request, + ::grpc::CompletionQueue* cq) override; const ::grpc::internal::RpcMethod rpcmethod_RegisterGraph_; const ::grpc::internal::RpcMethod rpcmethod_ForgetGraph_; const ::grpc::internal::RpcMethod rpcmethod_ReplaceGraph_; const ::grpc::internal::RpcMethod rpcmethod_InferenceRequest_; }; - static std::unique_ptr NewStub(const std::shared_ptr< ::grpc::ChannelInterface>& channel, const ::grpc::StubOptions& options = ::grpc::StubOptions()); - + static std::unique_ptr NewStub( + const std::shared_ptr<::grpc::ChannelInterface>& channel, + const ::grpc::StubOptions& options = ::grpc::StubOptions()); + class Service : public ::grpc::Service { public: Service(); virtual ~Service(); - virtual ::grpc::Status RegisterGraph(::grpc::ServerContext* context, const flatbuffers::grpc::Message* request, flatbuffers::grpc::Message* response); - virtual ::grpc::Status ForgetGraph(::grpc::ServerContext* context, const flatbuffers::grpc::Message* request, flatbuffers::grpc::Message* response); - virtual ::grpc::Status ReplaceGraph(::grpc::ServerContext* context, const flatbuffers::grpc::Message* request, flatbuffers::grpc::Message* response); - virtual ::grpc::Status InferenceRequest(::grpc::ServerContext* context, const flatbuffers::grpc::Message* request, flatbuffers::grpc::Message* response); + virtual ::grpc::Status RegisterGraph( + ::grpc::ServerContext* context, + const flatbuffers::grpc::Message* request, + flatbuffers::grpc::Message* response); + virtual ::grpc::Status ForgetGraph( + ::grpc::ServerContext* context, + const flatbuffers::grpc::Message* request, + flatbuffers::grpc::Message* response); + virtual ::grpc::Status ReplaceGraph( + ::grpc::ServerContext* context, + const flatbuffers::grpc::Message* request, + flatbuffers::grpc::Message* response); + virtual ::grpc::Status InferenceRequest( + ::grpc::ServerContext* context, + const flatbuffers::grpc::Message* request, + flatbuffers::grpc::Message* response); }; template class WithAsyncMethod_RegisterGraph : public BaseClass { private: - void BaseClassMustBeDerivedFromService(const Service *service) {} + void BaseClassMustBeDerivedFromService(const Service* service) {} + public: - WithAsyncMethod_RegisterGraph() { - ::grpc::Service::MarkMethodAsync(0); - } + WithAsyncMethod_RegisterGraph() { ::grpc::Service::MarkMethodAsync(0); } ~WithAsyncMethod_RegisterGraph() override { BaseClassMustBeDerivedFromService(this); } // disable synchronous version of this method - ::grpc::Status RegisterGraph(::grpc::ServerContext* context, const flatbuffers::grpc::Message* request, flatbuffers::grpc::Message* response) final override { + ::grpc::Status RegisterGraph( + ::grpc::ServerContext* context, + const flatbuffers::grpc::Message* request, + flatbuffers::grpc::Message* response) final override { abort(); return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, ""); } - void RequestRegisterGraph(::grpc::ServerContext* context, flatbuffers::grpc::Message* request, ::grpc::ServerAsyncResponseWriter< flatbuffers::grpc::Message>* response, ::grpc::CompletionQueue* new_call_cq, ::grpc::ServerCompletionQueue* notification_cq, void *tag) { - ::grpc::Service::RequestAsyncUnary(0, context, request, response, new_call_cq, notification_cq, tag); + void RequestRegisterGraph( + ::grpc::ServerContext* context, + flatbuffers::grpc::Message* request, + ::grpc::ServerAsyncResponseWriter< + flatbuffers::grpc::Message>* response, + ::grpc::CompletionQueue* new_call_cq, + ::grpc::ServerCompletionQueue* notification_cq, void* tag) { + ::grpc::Service::RequestAsyncUnary(0, context, request, response, + new_call_cq, notification_cq, tag); } }; template class WithAsyncMethod_ForgetGraph : public BaseClass { private: - void BaseClassMustBeDerivedFromService(const Service *service) {} + void BaseClassMustBeDerivedFromService(const Service* service) {} + public: - WithAsyncMethod_ForgetGraph() { - ::grpc::Service::MarkMethodAsync(1); - } + WithAsyncMethod_ForgetGraph() { ::grpc::Service::MarkMethodAsync(1); } ~WithAsyncMethod_ForgetGraph() override { BaseClassMustBeDerivedFromService(this); } // disable synchronous version of this method - ::grpc::Status ForgetGraph(::grpc::ServerContext* context, const flatbuffers::grpc::Message* request, flatbuffers::grpc::Message* response) final override { + ::grpc::Status ForgetGraph( + ::grpc::ServerContext* context, + const flatbuffers::grpc::Message* request, + flatbuffers::grpc::Message* response) final override { abort(); return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, ""); } - void RequestForgetGraph(::grpc::ServerContext* context, flatbuffers::grpc::Message* request, ::grpc::ServerAsyncResponseWriter< flatbuffers::grpc::Message>* response, ::grpc::CompletionQueue* new_call_cq, ::grpc::ServerCompletionQueue* notification_cq, void *tag) { - ::grpc::Service::RequestAsyncUnary(1, context, request, response, new_call_cq, notification_cq, tag); + void RequestForgetGraph( + ::grpc::ServerContext* context, + flatbuffers::grpc::Message* request, + ::grpc::ServerAsyncResponseWriter< + flatbuffers::grpc::Message>* response, + ::grpc::CompletionQueue* new_call_cq, + ::grpc::ServerCompletionQueue* notification_cq, void* tag) { + ::grpc::Service::RequestAsyncUnary(1, context, request, response, + new_call_cq, notification_cq, tag); } }; template class WithAsyncMethod_ReplaceGraph : public BaseClass { private: - void BaseClassMustBeDerivedFromService(const Service *service) {} + void BaseClassMustBeDerivedFromService(const Service* service) {} + public: - WithAsyncMethod_ReplaceGraph() { - ::grpc::Service::MarkMethodAsync(2); - } + WithAsyncMethod_ReplaceGraph() { ::grpc::Service::MarkMethodAsync(2); } ~WithAsyncMethod_ReplaceGraph() override { BaseClassMustBeDerivedFromService(this); } // disable synchronous version of this method - ::grpc::Status ReplaceGraph(::grpc::ServerContext* context, const flatbuffers::grpc::Message* request, flatbuffers::grpc::Message* response) final override { + ::grpc::Status ReplaceGraph( + ::grpc::ServerContext* context, + const flatbuffers::grpc::Message* request, + flatbuffers::grpc::Message* response) final override { abort(); return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, ""); } - void RequestReplaceGraph(::grpc::ServerContext* context, flatbuffers::grpc::Message* request, ::grpc::ServerAsyncResponseWriter< flatbuffers::grpc::Message>* response, ::grpc::CompletionQueue* new_call_cq, ::grpc::ServerCompletionQueue* notification_cq, void *tag) { - ::grpc::Service::RequestAsyncUnary(2, context, request, response, new_call_cq, notification_cq, tag); + void RequestReplaceGraph( + ::grpc::ServerContext* context, + flatbuffers::grpc::Message* request, + ::grpc::ServerAsyncResponseWriter< + flatbuffers::grpc::Message>* response, + ::grpc::CompletionQueue* new_call_cq, + ::grpc::ServerCompletionQueue* notification_cq, void* tag) { + ::grpc::Service::RequestAsyncUnary(2, context, request, response, + new_call_cq, notification_cq, tag); } }; template class WithAsyncMethod_InferenceRequest : public BaseClass { private: - void BaseClassMustBeDerivedFromService(const Service *service) {} + void BaseClassMustBeDerivedFromService(const Service* service) {} + public: - WithAsyncMethod_InferenceRequest() { - ::grpc::Service::MarkMethodAsync(3); - } + WithAsyncMethod_InferenceRequest() { ::grpc::Service::MarkMethodAsync(3); } ~WithAsyncMethod_InferenceRequest() override { BaseClassMustBeDerivedFromService(this); } // disable synchronous version of this method - ::grpc::Status InferenceRequest(::grpc::ServerContext* context, const flatbuffers::grpc::Message* request, flatbuffers::grpc::Message* response) final override { + ::grpc::Status InferenceRequest( + ::grpc::ServerContext* context, + const flatbuffers::grpc::Message* request, + flatbuffers::grpc::Message* response) final override { abort(); return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, ""); } - void RequestInferenceRequest(::grpc::ServerContext* context, flatbuffers::grpc::Message* request, ::grpc::ServerAsyncResponseWriter< flatbuffers::grpc::Message>* response, ::grpc::CompletionQueue* new_call_cq, ::grpc::ServerCompletionQueue* notification_cq, void *tag) { - ::grpc::Service::RequestAsyncUnary(3, context, request, response, new_call_cq, notification_cq, tag); + void RequestInferenceRequest( + ::grpc::ServerContext* context, + flatbuffers::grpc::Message* request, + ::grpc::ServerAsyncResponseWriter< + flatbuffers::grpc::Message>* response, + ::grpc::CompletionQueue* new_call_cq, + ::grpc::ServerCompletionQueue* notification_cq, void* tag) { + ::grpc::Service::RequestAsyncUnary(3, context, request, response, + new_call_cq, notification_cq, tag); } }; - typedef WithAsyncMethod_RegisterGraph< WithAsyncMethod_ForgetGraph< WithAsyncMethod_ReplaceGraph< WithAsyncMethod_InferenceRequest< Service > > > > AsyncService; + typedef WithAsyncMethod_RegisterGraph>>> + AsyncService; template class WithGenericMethod_RegisterGraph : public BaseClass { private: - void BaseClassMustBeDerivedFromService(const Service *service) {} + void BaseClassMustBeDerivedFromService(const Service* service) {} + public: - WithGenericMethod_RegisterGraph() { - ::grpc::Service::MarkMethodGeneric(0); - } + WithGenericMethod_RegisterGraph() { ::grpc::Service::MarkMethodGeneric(0); } ~WithGenericMethod_RegisterGraph() override { BaseClassMustBeDerivedFromService(this); } // disable synchronous version of this method - ::grpc::Status RegisterGraph(::grpc::ServerContext* context, const flatbuffers::grpc::Message* request, flatbuffers::grpc::Message* response) final override { + ::grpc::Status RegisterGraph( + ::grpc::ServerContext* context, + const flatbuffers::grpc::Message* request, + flatbuffers::grpc::Message* response) final override { abort(); return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, ""); } @@ -232,16 +485,18 @@ class GraphInferenceServer final { template class WithGenericMethod_ForgetGraph : public BaseClass { private: - void BaseClassMustBeDerivedFromService(const Service *service) {} + void BaseClassMustBeDerivedFromService(const Service* service) {} + public: - WithGenericMethod_ForgetGraph() { - ::grpc::Service::MarkMethodGeneric(1); - } + WithGenericMethod_ForgetGraph() { ::grpc::Service::MarkMethodGeneric(1); } ~WithGenericMethod_ForgetGraph() override { BaseClassMustBeDerivedFromService(this); } // disable synchronous version of this method - ::grpc::Status ForgetGraph(::grpc::ServerContext* context, const flatbuffers::grpc::Message* request, flatbuffers::grpc::Message* response) final override { + ::grpc::Status ForgetGraph( + ::grpc::ServerContext* context, + const flatbuffers::grpc::Message* request, + flatbuffers::grpc::Message* response) final override { abort(); return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, ""); } @@ -249,16 +504,18 @@ class GraphInferenceServer final { template class WithGenericMethod_ReplaceGraph : public BaseClass { private: - void BaseClassMustBeDerivedFromService(const Service *service) {} + void BaseClassMustBeDerivedFromService(const Service* service) {} + public: - WithGenericMethod_ReplaceGraph() { - ::grpc::Service::MarkMethodGeneric(2); - } + WithGenericMethod_ReplaceGraph() { ::grpc::Service::MarkMethodGeneric(2); } ~WithGenericMethod_ReplaceGraph() override { BaseClassMustBeDerivedFromService(this); } // disable synchronous version of this method - ::grpc::Status ReplaceGraph(::grpc::ServerContext* context, const flatbuffers::grpc::Message* request, flatbuffers::grpc::Message* response) final override { + ::grpc::Status ReplaceGraph( + ::grpc::ServerContext* context, + const flatbuffers::grpc::Message* request, + flatbuffers::grpc::Message* response) final override { abort(); return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, ""); } @@ -266,7 +523,8 @@ class GraphInferenceServer final { template class WithGenericMethod_InferenceRequest : public BaseClass { private: - void BaseClassMustBeDerivedFromService(const Service *service) {} + void BaseClassMustBeDerivedFromService(const Service* service) {} + public: WithGenericMethod_InferenceRequest() { ::grpc::Service::MarkMethodGeneric(3); @@ -275,7 +533,10 @@ class GraphInferenceServer final { BaseClassMustBeDerivedFromService(this); } // disable synchronous version of this method - ::grpc::Status InferenceRequest(::grpc::ServerContext* context, const flatbuffers::grpc::Message* request, flatbuffers::grpc::Message* response) final override { + ::grpc::Status InferenceRequest( + ::grpc::ServerContext* context, + const flatbuffers::grpc::Message* request, + flatbuffers::grpc::Message* response) final override { abort(); return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, ""); } @@ -283,90 +544,147 @@ class GraphInferenceServer final { template class WithStreamedUnaryMethod_RegisterGraph : public BaseClass { private: - void BaseClassMustBeDerivedFromService(const Service *service) {} + void BaseClassMustBeDerivedFromService(const Service* service) {} + public: WithStreamedUnaryMethod_RegisterGraph() { - ::grpc::Service::MarkMethodStreamed(0, - new ::grpc::internal::StreamedUnaryHandler< flatbuffers::grpc::Message, flatbuffers::grpc::Message>(std::bind(&WithStreamedUnaryMethod_RegisterGraph::StreamedRegisterGraph, this, std::placeholders::_1, std::placeholders::_2))); + ::grpc::Service::MarkMethodStreamed( + 0, new ::grpc::internal::StreamedUnaryHandler< + flatbuffers::grpc::Message, + flatbuffers::grpc::Message>(std::bind( + &WithStreamedUnaryMethod_RegisterGraph< + BaseClass>::StreamedRegisterGraph, + this, std::placeholders::_1, std::placeholders::_2))); } ~WithStreamedUnaryMethod_RegisterGraph() override { BaseClassMustBeDerivedFromService(this); } // disable regular version of this method - ::grpc::Status RegisterGraph(::grpc::ServerContext* context, const flatbuffers::grpc::Message* request, flatbuffers::grpc::Message* response) final override { + ::grpc::Status RegisterGraph( + ::grpc::ServerContext* context, + const flatbuffers::grpc::Message* request, + flatbuffers::grpc::Message* response) final override { abort(); return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, ""); } // replace default version of method with streamed unary - virtual ::grpc::Status StreamedRegisterGraph(::grpc::ServerContext* context, ::grpc::ServerUnaryStreamer< flatbuffers::grpc::Message,flatbuffers::grpc::Message>* server_unary_streamer) = 0; + virtual ::grpc::Status StreamedRegisterGraph( + ::grpc::ServerContext* context, + ::grpc::ServerUnaryStreamer, + flatbuffers::grpc::Message>* + server_unary_streamer) = 0; }; template class WithStreamedUnaryMethod_ForgetGraph : public BaseClass { private: - void BaseClassMustBeDerivedFromService(const Service *service) {} + void BaseClassMustBeDerivedFromService(const Service* service) {} + public: WithStreamedUnaryMethod_ForgetGraph() { - ::grpc::Service::MarkMethodStreamed(1, - new ::grpc::internal::StreamedUnaryHandler< flatbuffers::grpc::Message, flatbuffers::grpc::Message>(std::bind(&WithStreamedUnaryMethod_ForgetGraph::StreamedForgetGraph, this, std::placeholders::_1, std::placeholders::_2))); + ::grpc::Service::MarkMethodStreamed( + 1, new ::grpc::internal::StreamedUnaryHandler< + flatbuffers::grpc::Message, + flatbuffers::grpc::Message>(std::bind( + &WithStreamedUnaryMethod_ForgetGraph< + BaseClass>::StreamedForgetGraph, + this, std::placeholders::_1, std::placeholders::_2))); } ~WithStreamedUnaryMethod_ForgetGraph() override { BaseClassMustBeDerivedFromService(this); } // disable regular version of this method - ::grpc::Status ForgetGraph(::grpc::ServerContext* context, const flatbuffers::grpc::Message* request, flatbuffers::grpc::Message* response) final override { + ::grpc::Status ForgetGraph( + ::grpc::ServerContext* context, + const flatbuffers::grpc::Message* request, + flatbuffers::grpc::Message* response) final override { abort(); return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, ""); } // replace default version of method with streamed unary - virtual ::grpc::Status StreamedForgetGraph(::grpc::ServerContext* context, ::grpc::ServerUnaryStreamer< flatbuffers::grpc::Message,flatbuffers::grpc::Message>* server_unary_streamer) = 0; + virtual ::grpc::Status StreamedForgetGraph( + ::grpc::ServerContext* context, + ::grpc::ServerUnaryStreamer, + flatbuffers::grpc::Message>* + server_unary_streamer) = 0; }; template class WithStreamedUnaryMethod_ReplaceGraph : public BaseClass { private: - void BaseClassMustBeDerivedFromService(const Service *service) {} + void BaseClassMustBeDerivedFromService(const Service* service) {} + public: WithStreamedUnaryMethod_ReplaceGraph() { - ::grpc::Service::MarkMethodStreamed(2, - new ::grpc::internal::StreamedUnaryHandler< flatbuffers::grpc::Message, flatbuffers::grpc::Message>(std::bind(&WithStreamedUnaryMethod_ReplaceGraph::StreamedReplaceGraph, this, std::placeholders::_1, std::placeholders::_2))); + ::grpc::Service::MarkMethodStreamed( + 2, new ::grpc::internal::StreamedUnaryHandler< + flatbuffers::grpc::Message, + flatbuffers::grpc::Message>(std::bind( + &WithStreamedUnaryMethod_ReplaceGraph< + BaseClass>::StreamedReplaceGraph, + this, std::placeholders::_1, std::placeholders::_2))); } ~WithStreamedUnaryMethod_ReplaceGraph() override { BaseClassMustBeDerivedFromService(this); } // disable regular version of this method - ::grpc::Status ReplaceGraph(::grpc::ServerContext* context, const flatbuffers::grpc::Message* request, flatbuffers::grpc::Message* response) final override { + ::grpc::Status ReplaceGraph( + ::grpc::ServerContext* context, + const flatbuffers::grpc::Message* request, + flatbuffers::grpc::Message* response) final override { abort(); return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, ""); } // replace default version of method with streamed unary - virtual ::grpc::Status StreamedReplaceGraph(::grpc::ServerContext* context, ::grpc::ServerUnaryStreamer< flatbuffers::grpc::Message,flatbuffers::grpc::Message>* server_unary_streamer) = 0; + virtual ::grpc::Status StreamedReplaceGraph( + ::grpc::ServerContext* context, + ::grpc::ServerUnaryStreamer, + flatbuffers::grpc::Message>* + server_unary_streamer) = 0; }; template class WithStreamedUnaryMethod_InferenceRequest : public BaseClass { private: - void BaseClassMustBeDerivedFromService(const Service *service) {} + void BaseClassMustBeDerivedFromService(const Service* service) {} + public: WithStreamedUnaryMethod_InferenceRequest() { - ::grpc::Service::MarkMethodStreamed(3, - new ::grpc::internal::StreamedUnaryHandler< flatbuffers::grpc::Message, flatbuffers::grpc::Message>(std::bind(&WithStreamedUnaryMethod_InferenceRequest::StreamedInferenceRequest, this, std::placeholders::_1, std::placeholders::_2))); + ::grpc::Service::MarkMethodStreamed( + 3, new ::grpc::internal::StreamedUnaryHandler< + flatbuffers::grpc::Message, + flatbuffers::grpc::Message>(std::bind( + &WithStreamedUnaryMethod_InferenceRequest< + BaseClass>::StreamedInferenceRequest, + this, std::placeholders::_1, std::placeholders::_2))); } ~WithStreamedUnaryMethod_InferenceRequest() override { BaseClassMustBeDerivedFromService(this); } // disable regular version of this method - ::grpc::Status InferenceRequest(::grpc::ServerContext* context, const flatbuffers::grpc::Message* request, flatbuffers::grpc::Message* response) final override { + ::grpc::Status InferenceRequest( + ::grpc::ServerContext* context, + const flatbuffers::grpc::Message* request, + flatbuffers::grpc::Message* response) final override { abort(); return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, ""); } // replace default version of method with streamed unary - virtual ::grpc::Status StreamedInferenceRequest(::grpc::ServerContext* context, ::grpc::ServerUnaryStreamer< flatbuffers::grpc::Message,flatbuffers::grpc::Message>* server_unary_streamer) = 0; + virtual ::grpc::Status StreamedInferenceRequest( + ::grpc::ServerContext* context, + ::grpc::ServerUnaryStreamer< + flatbuffers::grpc::Message, + flatbuffers::grpc::Message>* server_unary_streamer) = 0; }; - typedef WithStreamedUnaryMethod_RegisterGraph< WithStreamedUnaryMethod_ForgetGraph< WithStreamedUnaryMethod_ReplaceGraph< WithStreamedUnaryMethod_InferenceRequest< Service > > > > StreamedUnaryService; - typedef Service SplitStreamedService; - typedef WithStreamedUnaryMethod_RegisterGraph< WithStreamedUnaryMethod_ForgetGraph< WithStreamedUnaryMethod_ReplaceGraph< WithStreamedUnaryMethod_InferenceRequest< Service > > > > StreamedService; + typedef WithStreamedUnaryMethod_RegisterGraph< + WithStreamedUnaryMethod_ForgetGraph>>> + StreamedUnaryService; + typedef Service SplitStreamedService; + typedef WithStreamedUnaryMethod_RegisterGraph< + WithStreamedUnaryMethod_ForgetGraph>>> + StreamedService; }; } // namespace graph } // namespace sd - #endif // GRPC_graph__INCLUDED diff --git a/libnd4j/include/graph/generated/graph_generated.h b/libnd4j/include/graph/generated/graph_generated.h index 1285e4607e91..6421ccdbee29 100644 --- a/libnd4j/include/graph/generated/graph_generated.h +++ b/libnd4j/include/graph/generated/graph_generated.h @@ -1,13 +1,11 @@ // automatically generated by the FlatBuffers compiler, do not modify - #ifndef FLATBUFFERS_GENERATED_GRAPH_ND4J_GRAPH_H_ #define FLATBUFFERS_GENERATED_GRAPH_ND4J_GRAPH_H_ -#include "flatbuffers/flatbuffers.h" - #include "array_generated.h" #include "config_generated.h" +#include "flatbuffers/flatbuffers.h" #include "node_generated.h" #include "properties_generated.h" #include "request_generated.h" @@ -27,23 +25,24 @@ struct FlatDropRequest; struct FlatResponse; struct UpdaterState FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { - enum { - VT_PARAMNAME = 4, - VT_UPDATERSTATEKEYS = 6, - VT_UPDATERSTATEVALUES = 8 - }; + enum { VT_PARAMNAME = 4, VT_UPDATERSTATEKEYS = 6, VT_UPDATERSTATEVALUES = 8 }; const flatbuffers::String *paramName() const { return GetPointer(VT_PARAMNAME); } - const flatbuffers::Vector> *updaterStateKeys() const { - return GetPointer> *>(VT_UPDATERSTATEKEYS); + const flatbuffers::Vector> + *updaterStateKeys() const { + return GetPointer< + const flatbuffers::Vector> *>( + VT_UPDATERSTATEKEYS); } - const flatbuffers::Vector> *updaterStateValues() const { - return GetPointer> *>(VT_UPDATERSTATEVALUES); + const flatbuffers::Vector> + *updaterStateValues() const { + return GetPointer< + const flatbuffers::Vector> *>( + VT_UPDATERSTATEVALUES); } bool Verify(flatbuffers::Verifier &verifier) const { - return VerifyTableStart(verifier) && - VerifyOffset(verifier, VT_PARAMNAME) && + return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_PARAMNAME) && verifier.VerifyString(paramName()) && VerifyOffset(verifier, VT_UPDATERSTATEKEYS) && verifier.VerifyVector(updaterStateKeys()) && @@ -61,14 +60,19 @@ struct UpdaterStateBuilder { void add_paramName(flatbuffers::Offset paramName) { fbb_.AddOffset(UpdaterState::VT_PARAMNAME, paramName); } - void add_updaterStateKeys(flatbuffers::Offset>> updaterStateKeys) { + void add_updaterStateKeys( + flatbuffers::Offset< + flatbuffers::Vector>> + updaterStateKeys) { fbb_.AddOffset(UpdaterState::VT_UPDATERSTATEKEYS, updaterStateKeys); } - void add_updaterStateValues(flatbuffers::Offset>> updaterStateValues) { + void add_updaterStateValues( + flatbuffers::Offset>> + updaterStateValues) { fbb_.AddOffset(UpdaterState::VT_UPDATERSTATEVALUES, updaterStateValues); } explicit UpdaterStateBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + : fbb_(_fbb) { start_ = fbb_.StartTable(); } UpdaterStateBuilder &operator=(const UpdaterStateBuilder &); @@ -82,8 +86,11 @@ struct UpdaterStateBuilder { inline flatbuffers::Offset CreateUpdaterState( flatbuffers::FlatBufferBuilder &_fbb, flatbuffers::Offset paramName = 0, - flatbuffers::Offset>> updaterStateKeys = 0, - flatbuffers::Offset>> updaterStateValues = 0) { + flatbuffers::Offset< + flatbuffers::Vector>> + updaterStateKeys = 0, + flatbuffers::Offset>> + updaterStateValues = 0) { UpdaterStateBuilder builder_(_fbb); builder_.add_updaterStateValues(updaterStateValues); builder_.add_updaterStateKeys(updaterStateKeys); @@ -92,15 +99,20 @@ inline flatbuffers::Offset CreateUpdaterState( } inline flatbuffers::Offset CreateUpdaterStateDirect( - flatbuffers::FlatBufferBuilder &_fbb, - const char *paramName = nullptr, - const std::vector> *updaterStateKeys = nullptr, - const std::vector> *updaterStateValues = nullptr) { + flatbuffers::FlatBufferBuilder &_fbb, const char *paramName = nullptr, + const std::vector> + *updaterStateKeys = nullptr, + const std::vector> *updaterStateValues = + nullptr) { return sd::graph::CreateUpdaterState( - _fbb, - paramName ? _fbb.CreateString(paramName) : 0, - updaterStateKeys ? _fbb.CreateVector>(*updaterStateKeys) : 0, - updaterStateValues ? _fbb.CreateVector>(*updaterStateValues) : 0); + _fbb, paramName ? _fbb.CreateString(paramName) : 0, + updaterStateKeys + ? _fbb.CreateVector>( + *updaterStateKeys) + : 0, + updaterStateValues ? _fbb.CreateVector>( + *updaterStateValues) + : 0); } struct FlatGraph FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { @@ -115,32 +127,44 @@ struct FlatGraph FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { VT_TRAININGCONFIG = 18, VT_UPDATERSTATE = 20 }; - int64_t id() const { - return GetField(VT_ID, 0); - } - const flatbuffers::Vector> *variables() const { - return GetPointer> *>(VT_VARIABLES); + int64_t id() const { return GetField(VT_ID, 0); } + const flatbuffers::Vector> *variables() + const { + return GetPointer< + const flatbuffers::Vector> *>( + VT_VARIABLES); } const flatbuffers::Vector> *nodes() const { - return GetPointer> *>(VT_NODES); + return GetPointer< + const flatbuffers::Vector> *>(VT_NODES); } const flatbuffers::Vector> *outputs() const { - return GetPointer> *>(VT_OUTPUTS); + return GetPointer< + const flatbuffers::Vector> *>(VT_OUTPUTS); } const FlatConfiguration *configuration() const { return GetPointer(VT_CONFIGURATION); } - const flatbuffers::Vector> *placeholders() const { - return GetPointer> *>(VT_PLACEHOLDERS); + const flatbuffers::Vector> + *placeholders() const { + return GetPointer< + const flatbuffers::Vector> *>( + VT_PLACEHOLDERS); } - const flatbuffers::Vector> *lossVariables() const { - return GetPointer> *>(VT_LOSSVARIABLES); + const flatbuffers::Vector> + *lossVariables() const { + return GetPointer< + const flatbuffers::Vector> *>( + VT_LOSSVARIABLES); } const flatbuffers::String *trainingConfig() const { return GetPointer(VT_TRAININGCONFIG); } - const flatbuffers::Vector> *updaterState() const { - return GetPointer> *>(VT_UPDATERSTATE); + const flatbuffers::Vector> *updaterState() + const { + return GetPointer< + const flatbuffers::Vector> *>( + VT_UPDATERSTATE); } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && @@ -148,8 +172,7 @@ struct FlatGraph FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { VerifyOffset(verifier, VT_VARIABLES) && verifier.VerifyVector(variables()) && verifier.VerifyVectorOfTables(variables()) && - VerifyOffset(verifier, VT_NODES) && - verifier.VerifyVector(nodes()) && + VerifyOffset(verifier, VT_NODES) && verifier.VerifyVector(nodes()) && verifier.VerifyVectorOfTables(nodes()) && VerifyOffset(verifier, VT_OUTPUTS) && verifier.VerifyVector(outputs()) && @@ -166,43 +189,54 @@ struct FlatGraph FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { verifier.VerifyString(trainingConfig()) && VerifyOffset(verifier, VT_UPDATERSTATE) && verifier.VerifyVector(updaterState()) && - verifier.VerifyVectorOfTables(updaterState()) && - verifier.EndTable(); + verifier.VerifyVectorOfTables(updaterState()) && verifier.EndTable(); } }; struct FlatGraphBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; - void add_id(int64_t id) { - fbb_.AddElement(FlatGraph::VT_ID, id, 0); - } - void add_variables(flatbuffers::Offset>> variables) { + void add_id(int64_t id) { fbb_.AddElement(FlatGraph::VT_ID, id, 0); } + void add_variables(flatbuffers::Offset< + flatbuffers::Vector>> + variables) { fbb_.AddOffset(FlatGraph::VT_VARIABLES, variables); } - void add_nodes(flatbuffers::Offset>> nodes) { + void add_nodes( + flatbuffers::Offset>> + nodes) { fbb_.AddOffset(FlatGraph::VT_NODES, nodes); } - void add_outputs(flatbuffers::Offset>> outputs) { + void add_outputs( + flatbuffers::Offset>> + outputs) { fbb_.AddOffset(FlatGraph::VT_OUTPUTS, outputs); } void add_configuration(flatbuffers::Offset configuration) { fbb_.AddOffset(FlatGraph::VT_CONFIGURATION, configuration); } - void add_placeholders(flatbuffers::Offset>> placeholders) { + void add_placeholders( + flatbuffers::Offset< + flatbuffers::Vector>> + placeholders) { fbb_.AddOffset(FlatGraph::VT_PLACEHOLDERS, placeholders); } - void add_lossVariables(flatbuffers::Offset>> lossVariables) { + void add_lossVariables( + flatbuffers::Offset< + flatbuffers::Vector>> + lossVariables) { fbb_.AddOffset(FlatGraph::VT_LOSSVARIABLES, lossVariables); } - void add_trainingConfig(flatbuffers::Offset trainingConfig) { + void add_trainingConfig( + flatbuffers::Offset trainingConfig) { fbb_.AddOffset(FlatGraph::VT_TRAININGCONFIG, trainingConfig); } - void add_updaterState(flatbuffers::Offset>> updaterState) { + void add_updaterState(flatbuffers::Offset< + flatbuffers::Vector>> + updaterState) { fbb_.AddOffset(FlatGraph::VT_UPDATERSTATE, updaterState); } - explicit FlatGraphBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + explicit FlatGraphBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); } FlatGraphBuilder &operator=(const FlatGraphBuilder &); @@ -214,16 +248,23 @@ struct FlatGraphBuilder { }; inline flatbuffers::Offset CreateFlatGraph( - flatbuffers::FlatBufferBuilder &_fbb, - int64_t id = 0, - flatbuffers::Offset>> variables = 0, - flatbuffers::Offset>> nodes = 0, - flatbuffers::Offset>> outputs = 0, + flatbuffers::FlatBufferBuilder &_fbb, int64_t id = 0, + flatbuffers::Offset>> + variables = 0, + flatbuffers::Offset>> + nodes = 0, + flatbuffers::Offset>> + outputs = 0, flatbuffers::Offset configuration = 0, - flatbuffers::Offset>> placeholders = 0, - flatbuffers::Offset>> lossVariables = 0, + flatbuffers::Offset< + flatbuffers::Vector>> + placeholders = 0, + flatbuffers::Offset< + flatbuffers::Vector>> + lossVariables = 0, flatbuffers::Offset trainingConfig = 0, - flatbuffers::Offset>> updaterState = 0) { + flatbuffers::Offset>> + updaterState = 0) { FlatGraphBuilder builder_(_fbb); builder_.add_id(id); builder_.add_updaterState(updaterState); @@ -238,40 +279,46 @@ inline flatbuffers::Offset CreateFlatGraph( } inline flatbuffers::Offset CreateFlatGraphDirect( - flatbuffers::FlatBufferBuilder &_fbb, - int64_t id = 0, + flatbuffers::FlatBufferBuilder &_fbb, int64_t id = 0, const std::vector> *variables = nullptr, const std::vector> *nodes = nullptr, const std::vector> *outputs = nullptr, flatbuffers::Offset configuration = 0, - const std::vector> *placeholders = nullptr, - const std::vector> *lossVariables = nullptr, + const std::vector> *placeholders = + nullptr, + const std::vector> *lossVariables = + nullptr, const char *trainingConfig = nullptr, - const std::vector> *updaterState = nullptr) { + const std::vector> *updaterState = + nullptr) { return sd::graph::CreateFlatGraph( - _fbb, - id, - variables ? _fbb.CreateVector>(*variables) : 0, + _fbb, id, + variables + ? _fbb.CreateVector>(*variables) + : 0, nodes ? _fbb.CreateVector>(*nodes) : 0, outputs ? _fbb.CreateVector>(*outputs) : 0, configuration, - placeholders ? _fbb.CreateVector>(*placeholders) : 0, - lossVariables ? _fbb.CreateVector>(*lossVariables) : 0, + placeholders + ? _fbb.CreateVector>( + *placeholders) + : 0, + lossVariables + ? _fbb.CreateVector>( + *lossVariables) + : 0, trainingConfig ? _fbb.CreateString(trainingConfig) : 0, - updaterState ? _fbb.CreateVector>(*updaterState) : 0); + updaterState + ? _fbb.CreateVector>(*updaterState) + : 0); } struct FlatDropRequest FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { - enum { - VT_ID = 4 - }; - int64_t id() const { - return GetField(VT_ID, 0); - } + enum { VT_ID = 4 }; + int64_t id() const { return GetField(VT_ID, 0); } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && - VerifyField(verifier, VT_ID) && - verifier.EndTable(); + VerifyField(verifier, VT_ID) && verifier.EndTable(); } }; @@ -282,7 +329,7 @@ struct FlatDropRequestBuilder { fbb_.AddElement(FlatDropRequest::VT_ID, id, 0); } explicit FlatDropRequestBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + : fbb_(_fbb) { start_ = fbb_.StartTable(); } FlatDropRequestBuilder &operator=(const FlatDropRequestBuilder &); @@ -294,24 +341,18 @@ struct FlatDropRequestBuilder { }; inline flatbuffers::Offset CreateFlatDropRequest( - flatbuffers::FlatBufferBuilder &_fbb, - int64_t id = 0) { + flatbuffers::FlatBufferBuilder &_fbb, int64_t id = 0) { FlatDropRequestBuilder builder_(_fbb); builder_.add_id(id); return builder_.Finish(); } struct FlatResponse FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { - enum { - VT_STATUS = 4 - }; - int32_t status() const { - return GetField(VT_STATUS, 0); - } + enum { VT_STATUS = 4 }; + int32_t status() const { return GetField(VT_STATUS, 0); } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && - VerifyField(verifier, VT_STATUS) && - verifier.EndTable(); + VerifyField(verifier, VT_STATUS) && verifier.EndTable(); } }; @@ -322,7 +363,7 @@ struct FlatResponseBuilder { fbb_.AddElement(FlatResponse::VT_STATUS, status, 0); } explicit FlatResponseBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + : fbb_(_fbb) { start_ = fbb_.StartTable(); } FlatResponseBuilder &operator=(const FlatResponseBuilder &); @@ -334,8 +375,7 @@ struct FlatResponseBuilder { }; inline flatbuffers::Offset CreateFlatResponse( - flatbuffers::FlatBufferBuilder &_fbb, - int32_t status = 0) { + flatbuffers::FlatBufferBuilder &_fbb, int32_t status = 0) { FlatResponseBuilder builder_(_fbb); builder_.add_status(status); return builder_.Finish(); @@ -349,13 +389,11 @@ inline const sd::graph::FlatGraph *GetSizePrefixedFlatGraph(const void *buf) { return flatbuffers::GetSizePrefixedRoot(buf); } -inline bool VerifyFlatGraphBuffer( - flatbuffers::Verifier &verifier) { +inline bool VerifyFlatGraphBuffer(flatbuffers::Verifier &verifier) { return verifier.VerifyBuffer(nullptr); } -inline bool VerifySizePrefixedFlatGraphBuffer( - flatbuffers::Verifier &verifier) { +inline bool VerifySizePrefixedFlatGraphBuffer(flatbuffers::Verifier &verifier) { return verifier.VerifySizePrefixedBuffer(nullptr); } diff --git a/libnd4j/include/graph/generated/node_generated.h b/libnd4j/include/graph/generated/node_generated.h index a39f2490c9d2..503cded18db6 100644 --- a/libnd4j/include/graph/generated/node_generated.h +++ b/libnd4j/include/graph/generated/node_generated.h @@ -1,12 +1,10 @@ // automatically generated by the FlatBuffers compiler, do not modify - #ifndef FLATBUFFERS_GENERATED_NODE_ND4J_GRAPH_H_ #define FLATBUFFERS_GENERATED_NODE_ND4J_GRAPH_H_ -#include "flatbuffers/flatbuffers.h" - #include "array_generated.h" +#include "flatbuffers/flatbuffers.h" #include "properties_generated.h" #include "utils_generated.h" @@ -41,26 +39,27 @@ struct FlatNode FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { VT_CONTROLDEPFOR = 46, VT_EXTRATYPES = 48 }; - int32_t id() const { - return GetField(VT_ID, 0); - } + int32_t id() const { return GetField(VT_ID, 0); } const flatbuffers::String *name() const { return GetPointer(VT_NAME); } OpType opType() const { return static_cast(GetField(VT_OPTYPE, 0)); } - int64_t opNum() const { - return GetField(VT_OPNUM, 0); - } - const flatbuffers::Vector> *properties() const { - return GetPointer> *>(VT_PROPERTIES); + int64_t opNum() const { return GetField(VT_OPNUM, 0); } + const flatbuffers::Vector> *properties() + const { + return GetPointer< + const flatbuffers::Vector> *>( + VT_PROPERTIES); } const flatbuffers::Vector *input() const { return GetPointer *>(VT_INPUT); } const flatbuffers::Vector> *inputPaired() const { - return GetPointer> *>(VT_INPUTPAIRED); + return GetPointer< + const flatbuffers::Vector> *>( + VT_INPUTPAIRED); } const flatbuffers::Vector *output() const { return GetPointer *>(VT_OUTPUT); @@ -77,17 +76,16 @@ struct FlatNode FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { const flatbuffers::Vector *dimensions() const { return GetPointer *>(VT_DIMENSIONS); } - int32_t device() const { - return GetField(VT_DEVICE, 0); - } - int32_t scope_id() const { - return GetField(VT_SCOPE_ID, 0); - } + int32_t device() const { return GetField(VT_DEVICE, 0); } + int32_t scope_id() const { return GetField(VT_SCOPE_ID, 0); } const flatbuffers::String *scope_name() const { return GetPointer(VT_SCOPE_NAME); } - const flatbuffers::Vector> *outputNames() const { - return GetPointer> *>(VT_OUTPUTNAMES); + const flatbuffers::Vector> + *outputNames() const { + return GetPointer< + const flatbuffers::Vector> *>( + VT_OUTPUTNAMES); } const flatbuffers::String *opName() const { return GetPointer(VT_OPNAME); @@ -98,14 +96,23 @@ struct FlatNode FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { const FlatArray *scalar() const { return GetPointer(VT_SCALAR); } - const flatbuffers::Vector> *controlDeps() const { - return GetPointer> *>(VT_CONTROLDEPS); + const flatbuffers::Vector> + *controlDeps() const { + return GetPointer< + const flatbuffers::Vector> *>( + VT_CONTROLDEPS); } - const flatbuffers::Vector> *varControlDeps() const { - return GetPointer> *>(VT_VARCONTROLDEPS); + const flatbuffers::Vector> + *varControlDeps() const { + return GetPointer< + const flatbuffers::Vector> *>( + VT_VARCONTROLDEPS); } - const flatbuffers::Vector> *controlDepFor() const { - return GetPointer> *>(VT_CONTROLDEPFOR); + const flatbuffers::Vector> + *controlDepFor() const { + return GetPointer< + const flatbuffers::Vector> *>( + VT_CONTROLDEPFOR); } const flatbuffers::Vector *extraTypes() const { return GetPointer *>(VT_EXTRATYPES); @@ -113,15 +120,13 @@ struct FlatNode FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyField(verifier, VT_ID) && - VerifyOffset(verifier, VT_NAME) && - verifier.VerifyString(name()) && + VerifyOffset(verifier, VT_NAME) && verifier.VerifyString(name()) && VerifyField(verifier, VT_OPTYPE) && VerifyField(verifier, VT_OPNUM) && VerifyOffset(verifier, VT_PROPERTIES) && verifier.VerifyVector(properties()) && verifier.VerifyVectorOfTables(properties()) && - VerifyOffset(verifier, VT_INPUT) && - verifier.VerifyVector(input()) && + VerifyOffset(verifier, VT_INPUT) && verifier.VerifyVector(input()) && VerifyOffset(verifier, VT_INPUTPAIRED) && verifier.VerifyVector(inputPaired()) && verifier.VerifyVectorOfTables(inputPaired()) && @@ -158,48 +163,54 @@ struct FlatNode FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { verifier.VerifyVector(controlDepFor()) && verifier.VerifyVectorOfStrings(controlDepFor()) && VerifyOffset(verifier, VT_EXTRATYPES) && - verifier.VerifyVector(extraTypes()) && - verifier.EndTable(); + verifier.VerifyVector(extraTypes()) && verifier.EndTable(); } }; struct FlatNodeBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; - void add_id(int32_t id) { - fbb_.AddElement(FlatNode::VT_ID, id, 0); - } + void add_id(int32_t id) { fbb_.AddElement(FlatNode::VT_ID, id, 0); } void add_name(flatbuffers::Offset name) { fbb_.AddOffset(FlatNode::VT_NAME, name); } void add_opType(OpType opType) { - fbb_.AddElement(FlatNode::VT_OPTYPE, static_cast(opType), 0); + fbb_.AddElement(FlatNode::VT_OPTYPE, static_cast(opType), + 0); } void add_opNum(int64_t opNum) { fbb_.AddElement(FlatNode::VT_OPNUM, opNum, 0); } - void add_properties(flatbuffers::Offset>> properties) { + void add_properties(flatbuffers::Offset< + flatbuffers::Vector>> + properties) { fbb_.AddOffset(FlatNode::VT_PROPERTIES, properties); } void add_input(flatbuffers::Offset> input) { fbb_.AddOffset(FlatNode::VT_INPUT, input); } - void add_inputPaired(flatbuffers::Offset>> inputPaired) { + void add_inputPaired( + flatbuffers::Offset>> + inputPaired) { fbb_.AddOffset(FlatNode::VT_INPUTPAIRED, inputPaired); } void add_output(flatbuffers::Offset> output) { fbb_.AddOffset(FlatNode::VT_OUTPUT, output); } - void add_extraParams(flatbuffers::Offset> extraParams) { + void add_extraParams( + flatbuffers::Offset> extraParams) { fbb_.AddOffset(FlatNode::VT_EXTRAPARAMS, extraParams); } - void add_extraInteger(flatbuffers::Offset> extraInteger) { + void add_extraInteger( + flatbuffers::Offset> extraInteger) { fbb_.AddOffset(FlatNode::VT_EXTRAINTEGER, extraInteger); } - void add_extraBools(flatbuffers::Offset> extraBools) { + void add_extraBools( + flatbuffers::Offset> extraBools) { fbb_.AddOffset(FlatNode::VT_EXTRABOOLS, extraBools); } - void add_dimensions(flatbuffers::Offset> dimensions) { + void add_dimensions( + flatbuffers::Offset> dimensions) { fbb_.AddOffset(FlatNode::VT_DIMENSIONS, dimensions); } void add_device(int32_t device) { @@ -211,32 +222,45 @@ struct FlatNodeBuilder { void add_scope_name(flatbuffers::Offset scope_name) { fbb_.AddOffset(FlatNode::VT_SCOPE_NAME, scope_name); } - void add_outputNames(flatbuffers::Offset>> outputNames) { + void add_outputNames( + flatbuffers::Offset< + flatbuffers::Vector>> + outputNames) { fbb_.AddOffset(FlatNode::VT_OUTPUTNAMES, outputNames); } void add_opName(flatbuffers::Offset opName) { fbb_.AddOffset(FlatNode::VT_OPNAME, opName); } - void add_outputTypes(flatbuffers::Offset> outputTypes) { + void add_outputTypes( + flatbuffers::Offset> outputTypes) { fbb_.AddOffset(FlatNode::VT_OUTPUTTYPES, outputTypes); } void add_scalar(flatbuffers::Offset scalar) { fbb_.AddOffset(FlatNode::VT_SCALAR, scalar); } - void add_controlDeps(flatbuffers::Offset>> controlDeps) { + void add_controlDeps( + flatbuffers::Offset< + flatbuffers::Vector>> + controlDeps) { fbb_.AddOffset(FlatNode::VT_CONTROLDEPS, controlDeps); } - void add_varControlDeps(flatbuffers::Offset>> varControlDeps) { + void add_varControlDeps( + flatbuffers::Offset< + flatbuffers::Vector>> + varControlDeps) { fbb_.AddOffset(FlatNode::VT_VARCONTROLDEPS, varControlDeps); } - void add_controlDepFor(flatbuffers::Offset>> controlDepFor) { + void add_controlDepFor( + flatbuffers::Offset< + flatbuffers::Vector>> + controlDepFor) { fbb_.AddOffset(FlatNode::VT_CONTROLDEPFOR, controlDepFor); } - void add_extraTypes(flatbuffers::Offset> extraTypes) { + void add_extraTypes( + flatbuffers::Offset> extraTypes) { fbb_.AddOffset(FlatNode::VT_EXTRATYPES, extraTypes); } - explicit FlatNodeBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + explicit FlatNodeBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); } FlatNodeBuilder &operator=(const FlatNodeBuilder &); @@ -248,29 +272,37 @@ struct FlatNodeBuilder { }; inline flatbuffers::Offset CreateFlatNode( - flatbuffers::FlatBufferBuilder &_fbb, - int32_t id = 0, + flatbuffers::FlatBufferBuilder &_fbb, int32_t id = 0, flatbuffers::Offset name = 0, - OpType opType = OpType_TRANSFORM_FLOAT, - int64_t opNum = 0, - flatbuffers::Offset>> properties = 0, + OpType opType = OpType_TRANSFORM_FLOAT, int64_t opNum = 0, + flatbuffers::Offset< + flatbuffers::Vector>> + properties = 0, flatbuffers::Offset> input = 0, - flatbuffers::Offset>> inputPaired = 0, + flatbuffers::Offset>> + inputPaired = 0, flatbuffers::Offset> output = 0, flatbuffers::Offset> extraParams = 0, flatbuffers::Offset> extraInteger = 0, flatbuffers::Offset> extraBools = 0, flatbuffers::Offset> dimensions = 0, - int32_t device = 0, - int32_t scope_id = 0, + int32_t device = 0, int32_t scope_id = 0, flatbuffers::Offset scope_name = 0, - flatbuffers::Offset>> outputNames = 0, + flatbuffers::Offset< + flatbuffers::Vector>> + outputNames = 0, flatbuffers::Offset opName = 0, flatbuffers::Offset> outputTypes = 0, flatbuffers::Offset scalar = 0, - flatbuffers::Offset>> controlDeps = 0, - flatbuffers::Offset>> varControlDeps = 0, - flatbuffers::Offset>> controlDepFor = 0, + flatbuffers::Offset< + flatbuffers::Vector>> + controlDeps = 0, + flatbuffers::Offset< + flatbuffers::Vector>> + varControlDeps = 0, + flatbuffers::Offset< + flatbuffers::Vector>> + controlDepFor = 0, flatbuffers::Offset> extraTypes = 0) { FlatNodeBuilder builder_(_fbb); builder_.add_opNum(opNum); @@ -300,54 +332,62 @@ inline flatbuffers::Offset CreateFlatNode( } inline flatbuffers::Offset CreateFlatNodeDirect( - flatbuffers::FlatBufferBuilder &_fbb, - int32_t id = 0, - const char *name = nullptr, - OpType opType = OpType_TRANSFORM_FLOAT, + flatbuffers::FlatBufferBuilder &_fbb, int32_t id = 0, + const char *name = nullptr, OpType opType = OpType_TRANSFORM_FLOAT, int64_t opNum = 0, - const std::vector> *properties = nullptr, + const std::vector> *properties = + nullptr, const std::vector *input = nullptr, const std::vector> *inputPaired = nullptr, const std::vector *output = nullptr, const std::vector *extraParams = nullptr, const std::vector *extraInteger = nullptr, const std::vector *extraBools = nullptr, - const std::vector *dimensions = nullptr, - int32_t device = 0, - int32_t scope_id = 0, - const char *scope_name = nullptr, - const std::vector> *outputNames = nullptr, + const std::vector *dimensions = nullptr, int32_t device = 0, + int32_t scope_id = 0, const char *scope_name = nullptr, + const std::vector> *outputNames = + nullptr, const char *opName = nullptr, const std::vector *outputTypes = nullptr, flatbuffers::Offset scalar = 0, - const std::vector> *controlDeps = nullptr, - const std::vector> *varControlDeps = nullptr, - const std::vector> *controlDepFor = nullptr, + const std::vector> *controlDeps = + nullptr, + const std::vector> + *varControlDeps = nullptr, + const std::vector> *controlDepFor = + nullptr, const std::vector *extraTypes = nullptr) { return sd::graph::CreateFlatNode( - _fbb, - id, - name ? _fbb.CreateString(name) : 0, - opType, - opNum, - properties ? _fbb.CreateVector>(*properties) : 0, + _fbb, id, name ? _fbb.CreateString(name) : 0, opType, opNum, + properties + ? _fbb.CreateVector>(*properties) + : 0, input ? _fbb.CreateVector(*input) : 0, - inputPaired ? _fbb.CreateVector>(*inputPaired) : 0, + inputPaired + ? _fbb.CreateVector>(*inputPaired) + : 0, output ? _fbb.CreateVector(*output) : 0, extraParams ? _fbb.CreateVector(*extraParams) : 0, extraInteger ? _fbb.CreateVector(*extraInteger) : 0, extraBools ? _fbb.CreateVector(*extraBools) : 0, - dimensions ? _fbb.CreateVector(*dimensions) : 0, - device, - scope_id, - scope_name ? _fbb.CreateString(scope_name) : 0, - outputNames ? _fbb.CreateVector>(*outputNames) : 0, + dimensions ? _fbb.CreateVector(*dimensions) : 0, device, + scope_id, scope_name ? _fbb.CreateString(scope_name) : 0, + outputNames ? _fbb.CreateVector>( + *outputNames) + : 0, opName ? _fbb.CreateString(opName) : 0, - outputTypes ? _fbb.CreateVector(*outputTypes) : 0, - scalar, - controlDeps ? _fbb.CreateVector>(*controlDeps) : 0, - varControlDeps ? _fbb.CreateVector>(*varControlDeps) : 0, - controlDepFor ? _fbb.CreateVector>(*controlDepFor) : 0, + outputTypes ? _fbb.CreateVector(*outputTypes) : 0, scalar, + controlDeps ? _fbb.CreateVector>( + *controlDeps) + : 0, + varControlDeps + ? _fbb.CreateVector>( + *varControlDeps) + : 0, + controlDepFor + ? _fbb.CreateVector>( + *controlDepFor) + : 0, extraTypes ? _fbb.CreateVector(*extraTypes) : 0); } @@ -359,13 +399,11 @@ inline const sd::graph::FlatNode *GetSizePrefixedFlatNode(const void *buf) { return flatbuffers::GetSizePrefixedRoot(buf); } -inline bool VerifyFlatNodeBuffer( - flatbuffers::Verifier &verifier) { +inline bool VerifyFlatNodeBuffer(flatbuffers::Verifier &verifier) { return verifier.VerifyBuffer(nullptr); } -inline bool VerifySizePrefixedFlatNodeBuffer( - flatbuffers::Verifier &verifier) { +inline bool VerifySizePrefixedFlatNodeBuffer(flatbuffers::Verifier &verifier) { return verifier.VerifySizePrefixedBuffer(nullptr); } diff --git a/libnd4j/include/graph/generated/properties_generated.h b/libnd4j/include/graph/generated/properties_generated.h index 34138fe86264..21986255d6e1 100644 --- a/libnd4j/include/graph/generated/properties_generated.h +++ b/libnd4j/include/graph/generated/properties_generated.h @@ -1,12 +1,10 @@ // automatically generated by the FlatBuffers compiler, do not modify - #ifndef FLATBUFFERS_GENERATED_PROPERTIES_ND4J_GRAPH_H_ #define FLATBUFFERS_GENERATED_PROPERTIES_ND4J_GRAPH_H_ -#include "flatbuffers/flatbuffers.h" - #include "array_generated.h" +#include "flatbuffers/flatbuffers.h" namespace sd { namespace graph { @@ -37,37 +35,32 @@ struct FlatProperties FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { return GetPointer *>(VT_D); } const flatbuffers::Vector> *a() const { - return GetPointer> *>(VT_A); + return GetPointer< + const flatbuffers::Vector> *>(VT_A); } const flatbuffers::Vector *b() const { return GetPointer *>(VT_B); } - const flatbuffers::Vector> *s() const { - return GetPointer> *>(VT_S); + const flatbuffers::Vector> *s() + const { + return GetPointer< + const flatbuffers::Vector> *>( + VT_S); } const flatbuffers::Vector *shape() const { return GetPointer *>(VT_SHAPE); } bool Verify(flatbuffers::Verifier &verifier) const { - return VerifyTableStart(verifier) && - VerifyOffset(verifier, VT_NAME) && - verifier.VerifyString(name()) && - VerifyOffset(verifier, VT_I) && - verifier.VerifyVector(i()) && - VerifyOffset(verifier, VT_L) && - verifier.VerifyVector(l()) && - VerifyOffset(verifier, VT_D) && - verifier.VerifyVector(d()) && - VerifyOffset(verifier, VT_A) && - verifier.VerifyVector(a()) && - verifier.VerifyVectorOfTables(a()) && - VerifyOffset(verifier, VT_B) && - verifier.VerifyVector(b()) && - VerifyOffset(verifier, VT_S) && - verifier.VerifyVector(s()) && + return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_NAME) && + verifier.VerifyString(name()) && VerifyOffset(verifier, VT_I) && + verifier.VerifyVector(i()) && VerifyOffset(verifier, VT_L) && + verifier.VerifyVector(l()) && VerifyOffset(verifier, VT_D) && + verifier.VerifyVector(d()) && VerifyOffset(verifier, VT_A) && + verifier.VerifyVector(a()) && verifier.VerifyVectorOfTables(a()) && + VerifyOffset(verifier, VT_B) && verifier.VerifyVector(b()) && + VerifyOffset(verifier, VT_S) && verifier.VerifyVector(s()) && verifier.VerifyVectorOfStrings(s()) && - VerifyOffset(verifier, VT_SHAPE) && - verifier.VerifyVector(shape()) && + VerifyOffset(verifier, VT_SHAPE) && verifier.VerifyVector(shape()) && verifier.EndTable(); } }; @@ -87,20 +80,24 @@ struct FlatPropertiesBuilder { void add_d(flatbuffers::Offset> d) { fbb_.AddOffset(FlatProperties::VT_D, d); } - void add_a(flatbuffers::Offset>> a) { + void add_a( + flatbuffers::Offset>> + a) { fbb_.AddOffset(FlatProperties::VT_A, a); } void add_b(flatbuffers::Offset> b) { fbb_.AddOffset(FlatProperties::VT_B, b); } - void add_s(flatbuffers::Offset>> s) { + void add_s(flatbuffers::Offset< + flatbuffers::Vector>> + s) { fbb_.AddOffset(FlatProperties::VT_S, s); } void add_shape(flatbuffers::Offset> shape) { fbb_.AddOffset(FlatProperties::VT_SHAPE, shape); } explicit FlatPropertiesBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + : fbb_(_fbb) { start_ = fbb_.StartTable(); } FlatPropertiesBuilder &operator=(const FlatPropertiesBuilder &); @@ -117,9 +114,12 @@ inline flatbuffers::Offset CreateFlatProperties( flatbuffers::Offset> i = 0, flatbuffers::Offset> l = 0, flatbuffers::Offset> d = 0, - flatbuffers::Offset>> a = 0, + flatbuffers::Offset>> a = + 0, flatbuffers::Offset> b = 0, - flatbuffers::Offset>> s = 0, + flatbuffers::Offset< + flatbuffers::Vector>> + s = 0, flatbuffers::Offset> shape = 0) { FlatPropertiesBuilder builder_(_fbb); builder_.add_shape(shape); @@ -134,8 +134,7 @@ inline flatbuffers::Offset CreateFlatProperties( } inline flatbuffers::Offset CreateFlatPropertiesDirect( - flatbuffers::FlatBufferBuilder &_fbb, - const char *name = nullptr, + flatbuffers::FlatBufferBuilder &_fbb, const char *name = nullptr, const std::vector *i = nullptr, const std::vector *l = nullptr, const std::vector *d = nullptr, @@ -144,8 +143,7 @@ inline flatbuffers::Offset CreateFlatPropertiesDirect( const std::vector> *s = nullptr, const std::vector *shape = nullptr) { return sd::graph::CreateFlatProperties( - _fbb, - name ? _fbb.CreateString(name) : 0, + _fbb, name ? _fbb.CreateString(name) : 0, i ? _fbb.CreateVector(*i) : 0, l ? _fbb.CreateVector(*l) : 0, d ? _fbb.CreateVector(*d) : 0, @@ -159,12 +157,12 @@ inline const sd::graph::FlatProperties *GetFlatProperties(const void *buf) { return flatbuffers::GetRoot(buf); } -inline const sd::graph::FlatProperties *GetSizePrefixedFlatProperties(const void *buf) { +inline const sd::graph::FlatProperties *GetSizePrefixedFlatProperties( + const void *buf) { return flatbuffers::GetSizePrefixedRoot(buf); } -inline bool VerifyFlatPropertiesBuffer( - flatbuffers::Verifier &verifier) { +inline bool VerifyFlatPropertiesBuffer(flatbuffers::Verifier &verifier) { return verifier.VerifyBuffer(nullptr); } diff --git a/libnd4j/include/graph/generated/request_generated.h b/libnd4j/include/graph/generated/request_generated.h index 00c782311b05..970336dd3393 100644 --- a/libnd4j/include/graph/generated/request_generated.h +++ b/libnd4j/include/graph/generated/request_generated.h @@ -1,13 +1,11 @@ // automatically generated by the FlatBuffers compiler, do not modify - #ifndef FLATBUFFERS_GENERATED_REQUEST_ND4J_GRAPH_H_ #define FLATBUFFERS_GENERATED_REQUEST_ND4J_GRAPH_H_ -#include "flatbuffers/flatbuffers.h" - #include "array_generated.h" #include "config_generated.h" +#include "flatbuffers/flatbuffers.h" #include "utils_generated.h" #include "variable_generated.h" @@ -16,17 +14,15 @@ namespace graph { struct FlatInferenceRequest; -struct FlatInferenceRequest FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { - enum { - VT_ID = 4, - VT_VARIABLES = 6, - VT_CONFIGURATION = 8 - }; - int64_t id() const { - return GetField(VT_ID, 0); - } - const flatbuffers::Vector> *variables() const { - return GetPointer> *>(VT_VARIABLES); +struct FlatInferenceRequest FLATBUFFERS_FINAL_CLASS + : private flatbuffers::Table { + enum { VT_ID = 4, VT_VARIABLES = 6, VT_CONFIGURATION = 8 }; + int64_t id() const { return GetField(VT_ID, 0); } + const flatbuffers::Vector> *variables() + const { + return GetPointer< + const flatbuffers::Vector> *>( + VT_VARIABLES); } const FlatConfiguration *configuration() const { return GetPointer(VT_CONFIGURATION); @@ -38,8 +34,7 @@ struct FlatInferenceRequest FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table verifier.VerifyVector(variables()) && verifier.VerifyVectorOfTables(variables()) && VerifyOffset(verifier, VT_CONFIGURATION) && - verifier.VerifyTable(configuration()) && - verifier.EndTable(); + verifier.VerifyTable(configuration()) && verifier.EndTable(); } }; @@ -49,14 +44,16 @@ struct FlatInferenceRequestBuilder { void add_id(int64_t id) { fbb_.AddElement(FlatInferenceRequest::VT_ID, id, 0); } - void add_variables(flatbuffers::Offset>> variables) { + void add_variables(flatbuffers::Offset< + flatbuffers::Vector>> + variables) { fbb_.AddOffset(FlatInferenceRequest::VT_VARIABLES, variables); } void add_configuration(flatbuffers::Offset configuration) { fbb_.AddOffset(FlatInferenceRequest::VT_CONFIGURATION, configuration); } explicit FlatInferenceRequestBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + : fbb_(_fbb) { start_ = fbb_.StartTable(); } FlatInferenceRequestBuilder &operator=(const FlatInferenceRequestBuilder &); @@ -68,9 +65,9 @@ struct FlatInferenceRequestBuilder { }; inline flatbuffers::Offset CreateFlatInferenceRequest( - flatbuffers::FlatBufferBuilder &_fbb, - int64_t id = 0, - flatbuffers::Offset>> variables = 0, + flatbuffers::FlatBufferBuilder &_fbb, int64_t id = 0, + flatbuffers::Offset>> + variables = 0, flatbuffers::Offset configuration = 0) { FlatInferenceRequestBuilder builder_(_fbb); builder_.add_id(id); @@ -79,34 +76,37 @@ inline flatbuffers::Offset CreateFlatInferenceRequest( return builder_.Finish(); } -inline flatbuffers::Offset CreateFlatInferenceRequestDirect( - flatbuffers::FlatBufferBuilder &_fbb, - int64_t id = 0, +inline flatbuffers::Offset +CreateFlatInferenceRequestDirect( + flatbuffers::FlatBufferBuilder &_fbb, int64_t id = 0, const std::vector> *variables = nullptr, flatbuffers::Offset configuration = 0) { return sd::graph::CreateFlatInferenceRequest( - _fbb, - id, - variables ? _fbb.CreateVector>(*variables) : 0, + _fbb, id, + variables + ? _fbb.CreateVector>(*variables) + : 0, configuration); } -inline const sd::graph::FlatInferenceRequest *GetFlatInferenceRequest(const void *buf) { +inline const sd::graph::FlatInferenceRequest *GetFlatInferenceRequest( + const void *buf) { return flatbuffers::GetRoot(buf); } -inline const sd::graph::FlatInferenceRequest *GetSizePrefixedFlatInferenceRequest(const void *buf) { +inline const sd::graph::FlatInferenceRequest * +GetSizePrefixedFlatInferenceRequest(const void *buf) { return flatbuffers::GetSizePrefixedRoot(buf); } -inline bool VerifyFlatInferenceRequestBuffer( - flatbuffers::Verifier &verifier) { +inline bool VerifyFlatInferenceRequestBuffer(flatbuffers::Verifier &verifier) { return verifier.VerifyBuffer(nullptr); } inline bool VerifySizePrefixedFlatInferenceRequestBuffer( flatbuffers::Verifier &verifier) { - return verifier.VerifySizePrefixedBuffer(nullptr); + return verifier.VerifySizePrefixedBuffer( + nullptr); } inline void FinishFlatInferenceRequestBuffer( diff --git a/libnd4j/include/graph/generated/result_generated.h b/libnd4j/include/graph/generated/result_generated.h index 04c458a9fde1..b05a0ed44300 100644 --- a/libnd4j/include/graph/generated/result_generated.h +++ b/libnd4j/include/graph/generated/result_generated.h @@ -1,12 +1,10 @@ // automatically generated by the FlatBuffers compiler, do not modify - #ifndef FLATBUFFERS_GENERATED_RESULT_ND4J_GRAPH_H_ #define FLATBUFFERS_GENERATED_RESULT_ND4J_GRAPH_H_ -#include "flatbuffers/flatbuffers.h" - #include "array_generated.h" +#include "flatbuffers/flatbuffers.h" #include "node_generated.h" #include "properties_generated.h" #include "utils_generated.h" @@ -20,14 +18,8 @@ struct FlatTiming; struct FlatResult; struct FlatTiming FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { - enum { - VT_ID = 4, - VT_NAME = 6, - VT_TIMING = 8 - }; - int32_t id() const { - return GetField(VT_ID, 0); - } + enum { VT_ID = 4, VT_NAME = 6, VT_TIMING = 8 }; + int32_t id() const { return GetField(VT_ID, 0); } const flatbuffers::String *name() const { return GetPointer(VT_NAME); } @@ -37,11 +29,9 @@ struct FlatTiming FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyField(verifier, VT_ID) && - VerifyOffset(verifier, VT_NAME) && - verifier.VerifyString(name()) && + VerifyOffset(verifier, VT_NAME) && verifier.VerifyString(name()) && VerifyOffset(verifier, VT_TIMING) && - verifier.VerifyTable(timing()) && - verifier.EndTable(); + verifier.VerifyTable(timing()) && verifier.EndTable(); } }; @@ -58,7 +48,7 @@ struct FlatTimingBuilder { fbb_.AddOffset(FlatTiming::VT_TIMING, timing); } explicit FlatTimingBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + : fbb_(_fbb) { start_ = fbb_.StartTable(); } FlatTimingBuilder &operator=(const FlatTimingBuilder &); @@ -70,8 +60,7 @@ struct FlatTimingBuilder { }; inline flatbuffers::Offset CreateFlatTiming( - flatbuffers::FlatBufferBuilder &_fbb, - int32_t id = 0, + flatbuffers::FlatBufferBuilder &_fbb, int32_t id = 0, flatbuffers::Offset name = 0, flatbuffers::Offset timing = 0) { FlatTimingBuilder builder_(_fbb); @@ -82,15 +71,10 @@ inline flatbuffers::Offset CreateFlatTiming( } inline flatbuffers::Offset CreateFlatTimingDirect( - flatbuffers::FlatBufferBuilder &_fbb, - int32_t id = 0, - const char *name = nullptr, - flatbuffers::Offset timing = 0) { + flatbuffers::FlatBufferBuilder &_fbb, int32_t id = 0, + const char *name = nullptr, flatbuffers::Offset timing = 0) { return sd::graph::CreateFlatTiming( - _fbb, - id, - name ? _fbb.CreateString(name) : 0, - timing); + _fbb, id, name ? _fbb.CreateString(name) : 0, timing); } struct FlatResult FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { @@ -101,14 +85,17 @@ struct FlatResult FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { VT_FOOTPRINTFORWARD = 10, VT_FOOTPRINTBACKWARD = 12 }; - int64_t id() const { - return GetField(VT_ID, 0); - } - const flatbuffers::Vector> *variables() const { - return GetPointer> *>(VT_VARIABLES); + int64_t id() const { return GetField(VT_ID, 0); } + const flatbuffers::Vector> *variables() + const { + return GetPointer< + const flatbuffers::Vector> *>( + VT_VARIABLES); } const flatbuffers::Vector> *timing() const { - return GetPointer> *>(VT_TIMING); + return GetPointer< + const flatbuffers::Vector> *>( + VT_TIMING); } int64_t footprintForward() const { return GetField(VT_FOOTPRINTFORWARD, 0); @@ -137,20 +124,26 @@ struct FlatResultBuilder { void add_id(int64_t id) { fbb_.AddElement(FlatResult::VT_ID, id, 0); } - void add_variables(flatbuffers::Offset>> variables) { + void add_variables(flatbuffers::Offset< + flatbuffers::Vector>> + variables) { fbb_.AddOffset(FlatResult::VT_VARIABLES, variables); } - void add_timing(flatbuffers::Offset>> timing) { + void add_timing( + flatbuffers::Offset>> + timing) { fbb_.AddOffset(FlatResult::VT_TIMING, timing); } void add_footprintForward(int64_t footprintForward) { - fbb_.AddElement(FlatResult::VT_FOOTPRINTFORWARD, footprintForward, 0); + fbb_.AddElement(FlatResult::VT_FOOTPRINTFORWARD, footprintForward, + 0); } void add_footprintBackward(int64_t footprintBackward) { - fbb_.AddElement(FlatResult::VT_FOOTPRINTBACKWARD, footprintBackward, 0); + fbb_.AddElement(FlatResult::VT_FOOTPRINTBACKWARD, + footprintBackward, 0); } explicit FlatResultBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + : fbb_(_fbb) { start_ = fbb_.StartTable(); } FlatResultBuilder &operator=(const FlatResultBuilder &); @@ -162,12 +155,12 @@ struct FlatResultBuilder { }; inline flatbuffers::Offset CreateFlatResult( - flatbuffers::FlatBufferBuilder &_fbb, - int64_t id = 0, - flatbuffers::Offset>> variables = 0, - flatbuffers::Offset>> timing = 0, - int64_t footprintForward = 0, - int64_t footprintBackward = 0) { + flatbuffers::FlatBufferBuilder &_fbb, int64_t id = 0, + flatbuffers::Offset>> + variables = 0, + flatbuffers::Offset>> + timing = 0, + int64_t footprintForward = 0, int64_t footprintBackward = 0) { FlatResultBuilder builder_(_fbb); builder_.add_footprintBackward(footprintBackward); builder_.add_footprintForward(footprintForward); @@ -178,19 +171,17 @@ inline flatbuffers::Offset CreateFlatResult( } inline flatbuffers::Offset CreateFlatResultDirect( - flatbuffers::FlatBufferBuilder &_fbb, - int64_t id = 0, + flatbuffers::FlatBufferBuilder &_fbb, int64_t id = 0, const std::vector> *variables = nullptr, const std::vector> *timing = nullptr, - int64_t footprintForward = 0, - int64_t footprintBackward = 0) { + int64_t footprintForward = 0, int64_t footprintBackward = 0) { return sd::graph::CreateFlatResult( - _fbb, - id, - variables ? _fbb.CreateVector>(*variables) : 0, + _fbb, id, + variables + ? _fbb.CreateVector>(*variables) + : 0, timing ? _fbb.CreateVector>(*timing) : 0, - footprintForward, - footprintBackward); + footprintForward, footprintBackward); } inline const sd::graph::FlatResult *GetFlatResult(const void *buf) { @@ -201,8 +192,7 @@ inline const sd::graph::FlatResult *GetSizePrefixedFlatResult(const void *buf) { return flatbuffers::GetSizePrefixedRoot(buf); } -inline bool VerifyFlatResultBuffer( - flatbuffers::Verifier &verifier) { +inline bool VerifyFlatResultBuffer(flatbuffers::Verifier &verifier) { return verifier.VerifyBuffer(nullptr); } diff --git a/libnd4j/include/graph/generated/uigraphevents_generated.h b/libnd4j/include/graph/generated/uigraphevents_generated.h index b3430a5c7ca8..5d10f9d75906 100644 --- a/libnd4j/include/graph/generated/uigraphevents_generated.h +++ b/libnd4j/include/graph/generated/uigraphevents_generated.h @@ -1,12 +1,10 @@ // automatically generated by the FlatBuffers compiler, do not modify - #ifndef FLATBUFFERS_GENERATED_UIGRAPHEVENTS_ND4J_GRAPH_H_ #define FLATBUFFERS_GENERATED_UIGRAPHEVENTS_ND4J_GRAPH_H_ -#include "flatbuffers/flatbuffers.h" - #include "array_generated.h" +#include "flatbuffers/flatbuffers.h" namespace sd { namespace graph { @@ -40,33 +38,29 @@ enum UIEventType { }; inline const UIEventType (&EnumValuesUIEventType())[9] { - static const UIEventType values[] = { - UIEventType_ADD_NAME, - UIEventType_SCALAR, - UIEventType_ARRAY, - UIEventType_ARRAY_LIST, - UIEventType_HISTOGRAM, - UIEventType_IMAGE, - UIEventType_SUMMARY_STATISTICS, - UIEventType_OP_TIMING, - UIEventType_HARDWARE_STATE - }; + static const UIEventType values[] = {UIEventType_ADD_NAME, + UIEventType_SCALAR, + UIEventType_ARRAY, + UIEventType_ARRAY_LIST, + UIEventType_HISTOGRAM, + UIEventType_IMAGE, + UIEventType_SUMMARY_STATISTICS, + UIEventType_OP_TIMING, + UIEventType_HARDWARE_STATE}; return values; } -inline const char * const *EnumNamesUIEventType() { - static const char * const names[] = { - "ADD_NAME", - "SCALAR", - "ARRAY", - "ARRAY_LIST", - "HISTOGRAM", - "IMAGE", - "SUMMARY_STATISTICS", - "OP_TIMING", - "HARDWARE_STATE", - nullptr - }; +inline const char *const *EnumNamesUIEventType() { + static const char *const names[] = {"ADD_NAME", + "SCALAR", + "ARRAY", + "ARRAY_LIST", + "HISTOGRAM", + "IMAGE", + "SUMMARY_STATISTICS", + "OP_TIMING", + "HARDWARE_STATE", + nullptr}; return names; } @@ -92,34 +86,19 @@ enum UIEventSubtype { inline const UIEventSubtype (&EnumValuesUIEventSubtype())[10] { static const UIEventSubtype values[] = { - UIEventSubtype_NONE, - UIEventSubtype_EVALUATION, - UIEventSubtype_LOSS, - UIEventSubtype_LEARNING_RATE, - UIEventSubtype_TUNING_METRIC, - UIEventSubtype_PERFORMANCE, - UIEventSubtype_PROFILING, - UIEventSubtype_FEATURE_LABEL, - UIEventSubtype_PREDICTION, - UIEventSubtype_USER_CUSTOM - }; + UIEventSubtype_NONE, UIEventSubtype_EVALUATION, + UIEventSubtype_LOSS, UIEventSubtype_LEARNING_RATE, + UIEventSubtype_TUNING_METRIC, UIEventSubtype_PERFORMANCE, + UIEventSubtype_PROFILING, UIEventSubtype_FEATURE_LABEL, + UIEventSubtype_PREDICTION, UIEventSubtype_USER_CUSTOM}; return values; } -inline const char * const *EnumNamesUIEventSubtype() { - static const char * const names[] = { - "NONE", - "EVALUATION", - "LOSS", - "LEARNING_RATE", - "TUNING_METRIC", - "PERFORMANCE", - "PROFILING", - "FEATURE_LABEL", - "PREDICTION", - "USER_CUSTOM", - nullptr - }; +inline const char *const *EnumNamesUIEventSubtype() { + static const char *const names[] = { + "NONE", "EVALUATION", "LOSS", "LEARNING_RATE", + "TUNING_METRIC", "PERFORMANCE", "PROFILING", "FEATURE_LABEL", + "PREDICTION", "USER_CUSTOM", nullptr}; return names; } @@ -137,21 +116,15 @@ enum UIHistogramType { }; inline const UIHistogramType (&EnumValuesUIHistogramType())[3] { - static const UIHistogramType values[] = { - UIHistogramType_DISCRETE, - UIHistogramType_EQUAL_SPACING, - UIHistogramType_CUSTOM - }; + static const UIHistogramType values[] = {UIHistogramType_DISCRETE, + UIHistogramType_EQUAL_SPACING, + UIHistogramType_CUSTOM}; return values; } -inline const char * const *EnumNamesUIHistogramType() { - static const char * const names[] = { - "DISCRETE", - "EQUAL_SPACING", - "CUSTOM", - nullptr - }; +inline const char *const *EnumNamesUIHistogramType() { + static const char *const names[] = {"DISCRETE", "EQUAL_SPACING", "CUSTOM", + nullptr}; return names; } @@ -178,27 +151,15 @@ struct UIEvent FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { UIEventSubtype eventSubType() const { return static_cast(GetField(VT_EVENTSUBTYPE, 0)); } - int32_t nameIdx() const { - return GetField(VT_NAMEIDX, 0); - } - int64_t timestamp() const { - return GetField(VT_TIMESTAMP, 0); - } - int32_t iteration() const { - return GetField(VT_ITERATION, 0); - } - int32_t epoch() const { - return GetField(VT_EPOCH, 0); - } - int16_t variableId() const { - return GetField(VT_VARIABLEID, 0); - } + int32_t nameIdx() const { return GetField(VT_NAMEIDX, 0); } + int64_t timestamp() const { return GetField(VT_TIMESTAMP, 0); } + int32_t iteration() const { return GetField(VT_ITERATION, 0); } + int32_t epoch() const { return GetField(VT_EPOCH, 0); } + int16_t variableId() const { return GetField(VT_VARIABLEID, 0); } const FrameIteration *frameIter() const { return GetPointer(VT_FRAMEITER); } - uint16_t plugin() const { - return GetField(VT_PLUGIN, 0); - } + uint16_t plugin() const { return GetField(VT_PLUGIN, 0); } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyField(verifier, VT_EVENTTYPE) && @@ -210,8 +171,7 @@ struct UIEvent FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { VerifyField(verifier, VT_VARIABLEID) && VerifyOffset(verifier, VT_FRAMEITER) && verifier.VerifyTable(frameIter()) && - VerifyField(verifier, VT_PLUGIN) && - verifier.EndTable(); + VerifyField(verifier, VT_PLUGIN) && verifier.EndTable(); } }; @@ -219,10 +179,12 @@ struct UIEventBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; void add_eventType(UIEventType eventType) { - fbb_.AddElement(UIEvent::VT_EVENTTYPE, static_cast(eventType), 0); + fbb_.AddElement(UIEvent::VT_EVENTTYPE, + static_cast(eventType), 0); } void add_eventSubType(UIEventSubtype eventSubType) { - fbb_.AddElement(UIEvent::VT_EVENTSUBTYPE, static_cast(eventSubType), 0); + fbb_.AddElement(UIEvent::VT_EVENTSUBTYPE, + static_cast(eventSubType), 0); } void add_nameIdx(int32_t nameIdx) { fbb_.AddElement(UIEvent::VT_NAMEIDX, nameIdx, 0); @@ -245,8 +207,7 @@ struct UIEventBuilder { void add_plugin(uint16_t plugin) { fbb_.AddElement(UIEvent::VT_PLUGIN, plugin, 0); } - explicit UIEventBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + explicit UIEventBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); } UIEventBuilder &operator=(const UIEventBuilder &); @@ -260,13 +221,9 @@ struct UIEventBuilder { inline flatbuffers::Offset CreateUIEvent( flatbuffers::FlatBufferBuilder &_fbb, UIEventType eventType = UIEventType_ADD_NAME, - UIEventSubtype eventSubType = UIEventSubtype_NONE, - int32_t nameIdx = 0, - int64_t timestamp = 0, - int32_t iteration = 0, - int32_t epoch = 0, - int16_t variableId = 0, - flatbuffers::Offset frameIter = 0, + UIEventSubtype eventSubType = UIEventSubtype_NONE, int32_t nameIdx = 0, + int64_t timestamp = 0, int32_t iteration = 0, int32_t epoch = 0, + int16_t variableId = 0, flatbuffers::Offset frameIter = 0, uint16_t plugin = 0) { UIEventBuilder builder_(_fbb); builder_.add_timestamp(timestamp); @@ -282,22 +239,15 @@ inline flatbuffers::Offset CreateUIEvent( } struct FrameIteration FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { - enum { - VT_FRAME = 4, - VT_ITERATION = 6 - }; + enum { VT_FRAME = 4, VT_ITERATION = 6 }; const flatbuffers::String *frame() const { return GetPointer(VT_FRAME); } - uint16_t iteration() const { - return GetField(VT_ITERATION, 0); - } + uint16_t iteration() const { return GetField(VT_ITERATION, 0); } bool Verify(flatbuffers::Verifier &verifier) const { - return VerifyTableStart(verifier) && - VerifyOffset(verifier, VT_FRAME) && + return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_FRAME) && verifier.VerifyString(frame()) && - VerifyField(verifier, VT_ITERATION) && - verifier.EndTable(); + VerifyField(verifier, VT_ITERATION) && verifier.EndTable(); } }; @@ -311,7 +261,7 @@ struct FrameIterationBuilder { fbb_.AddElement(FrameIteration::VT_ITERATION, iteration, 0); } explicit FrameIterationBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + : fbb_(_fbb) { start_ = fbb_.StartTable(); } FrameIterationBuilder &operator=(const FrameIterationBuilder &); @@ -333,31 +283,22 @@ inline flatbuffers::Offset CreateFrameIteration( } inline flatbuffers::Offset CreateFrameIterationDirect( - flatbuffers::FlatBufferBuilder &_fbb, - const char *frame = nullptr, + flatbuffers::FlatBufferBuilder &_fbb, const char *frame = nullptr, uint16_t iteration = 0) { return sd::graph::CreateFrameIteration( - _fbb, - frame ? _fbb.CreateString(frame) : 0, - iteration); + _fbb, frame ? _fbb.CreateString(frame) : 0, iteration); } struct UIAddName FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { - enum { - VT_NAMEIDX = 4, - VT_NAME = 6 - }; - int32_t nameIdx() const { - return GetField(VT_NAMEIDX, 0); - } + enum { VT_NAMEIDX = 4, VT_NAME = 6 }; + int32_t nameIdx() const { return GetField(VT_NAMEIDX, 0); } const flatbuffers::String *name() const { return GetPointer(VT_NAME); } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyField(verifier, VT_NAMEIDX) && - VerifyOffset(verifier, VT_NAME) && - verifier.VerifyString(name()) && + VerifyOffset(verifier, VT_NAME) && verifier.VerifyString(name()) && verifier.EndTable(); } }; @@ -371,8 +312,7 @@ struct UIAddNameBuilder { void add_name(flatbuffers::Offset name) { fbb_.AddOffset(UIAddName::VT_NAME, name); } - explicit UIAddNameBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + explicit UIAddNameBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); } UIAddNameBuilder &operator=(const UIAddNameBuilder &); @@ -384,8 +324,7 @@ struct UIAddNameBuilder { }; inline flatbuffers::Offset CreateUIAddName( - flatbuffers::FlatBufferBuilder &_fbb, - int32_t nameIdx = 0, + flatbuffers::FlatBufferBuilder &_fbb, int32_t nameIdx = 0, flatbuffers::Offset name = 0) { UIAddNameBuilder builder_(_fbb); builder_.add_name(name); @@ -394,39 +333,35 @@ inline flatbuffers::Offset CreateUIAddName( } inline flatbuffers::Offset CreateUIAddNameDirect( - flatbuffers::FlatBufferBuilder &_fbb, - int32_t nameIdx = 0, + flatbuffers::FlatBufferBuilder &_fbb, int32_t nameIdx = 0, const char *name = nullptr) { - return sd::graph::CreateUIAddName( - _fbb, - nameIdx, - name ? _fbb.CreateString(name) : 0); + return sd::graph::CreateUIAddName(_fbb, nameIdx, + name ? _fbb.CreateString(name) : 0); } struct FlatArrayList FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { - enum { - VT_LIST = 4 - }; + enum { VT_LIST = 4 }; const flatbuffers::Vector> *list() const { - return GetPointer> *>(VT_LIST); + return GetPointer< + const flatbuffers::Vector> *>(VT_LIST); } bool Verify(flatbuffers::Verifier &verifier) const { - return VerifyTableStart(verifier) && - VerifyOffset(verifier, VT_LIST) && + return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_LIST) && verifier.VerifyVector(list()) && - verifier.VerifyVectorOfTables(list()) && - verifier.EndTable(); + verifier.VerifyVectorOfTables(list()) && verifier.EndTable(); } }; struct FlatArrayListBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; - void add_list(flatbuffers::Offset>> list) { + void add_list( + flatbuffers::Offset>> + list) { fbb_.AddOffset(FlatArrayList::VT_LIST, list); } explicit FlatArrayListBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + : fbb_(_fbb) { start_ = fbb_.StartTable(); } FlatArrayListBuilder &operator=(const FlatArrayListBuilder &); @@ -439,7 +374,8 @@ struct FlatArrayListBuilder { inline flatbuffers::Offset CreateFlatArrayList( flatbuffers::FlatBufferBuilder &_fbb, - flatbuffers::Offset>> list = 0) { + flatbuffers::Offset>> + list = 0) { FlatArrayListBuilder builder_(_fbb); builder_.add_list(list); return builder_.Finish(); @@ -464,30 +400,26 @@ struct UIHistogram FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { UIHistogramType type() const { return static_cast(GetField(VT_TYPE, 0)); } - uint32_t numbins() const { - return GetField(VT_NUMBINS, 0); - } + uint32_t numbins() const { return GetField(VT_NUMBINS, 0); } const FlatArray *binranges() const { return GetPointer(VT_BINRANGES); } - const FlatArray *y() const { - return GetPointer(VT_Y); - } - const flatbuffers::Vector> *binlabels() const { - return GetPointer> *>(VT_BINLABELS); + const FlatArray *y() const { return GetPointer(VT_Y); } + const flatbuffers::Vector> + *binlabels() const { + return GetPointer< + const flatbuffers::Vector> *>( + VT_BINLABELS); } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyField(verifier, VT_TYPE) && VerifyField(verifier, VT_NUMBINS) && VerifyOffset(verifier, VT_BINRANGES) && - verifier.VerifyTable(binranges()) && - VerifyOffset(verifier, VT_Y) && - verifier.VerifyTable(y()) && - VerifyOffset(verifier, VT_BINLABELS) && + verifier.VerifyTable(binranges()) && VerifyOffset(verifier, VT_Y) && + verifier.VerifyTable(y()) && VerifyOffset(verifier, VT_BINLABELS) && verifier.VerifyVector(binlabels()) && - verifier.VerifyVectorOfStrings(binlabels()) && - verifier.EndTable(); + verifier.VerifyVectorOfStrings(binlabels()) && verifier.EndTable(); } }; @@ -506,11 +438,14 @@ struct UIHistogramBuilder { void add_y(flatbuffers::Offset y) { fbb_.AddOffset(UIHistogram::VT_Y, y); } - void add_binlabels(flatbuffers::Offset>> binlabels) { + void add_binlabels( + flatbuffers::Offset< + flatbuffers::Vector>> + binlabels) { fbb_.AddOffset(UIHistogram::VT_BINLABELS, binlabels); } explicit UIHistogramBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + : fbb_(_fbb) { start_ = fbb_.StartTable(); } UIHistogramBuilder &operator=(const UIHistogramBuilder &); @@ -523,11 +458,12 @@ struct UIHistogramBuilder { inline flatbuffers::Offset CreateUIHistogram( flatbuffers::FlatBufferBuilder &_fbb, - UIHistogramType type = UIHistogramType_DISCRETE, - uint32_t numbins = 0, + UIHistogramType type = UIHistogramType_DISCRETE, uint32_t numbins = 0, flatbuffers::Offset binranges = 0, flatbuffers::Offset y = 0, - flatbuffers::Offset>> binlabels = 0) { + flatbuffers::Offset< + flatbuffers::Vector>> + binlabels = 0) { UIHistogramBuilder builder_(_fbb); builder_.add_binlabels(binlabels); builder_.add_y(y); @@ -539,21 +475,20 @@ inline flatbuffers::Offset CreateUIHistogram( inline flatbuffers::Offset CreateUIHistogramDirect( flatbuffers::FlatBufferBuilder &_fbb, - UIHistogramType type = UIHistogramType_DISCRETE, - uint32_t numbins = 0, + UIHistogramType type = UIHistogramType_DISCRETE, uint32_t numbins = 0, flatbuffers::Offset binranges = 0, flatbuffers::Offset y = 0, - const std::vector> *binlabels = nullptr) { + const std::vector> *binlabels = + nullptr) { return sd::graph::CreateUIHistogram( - _fbb, - type, - numbins, - binranges, - y, - binlabels ? _fbb.CreateVector>(*binlabels) : 0); + _fbb, type, numbins, binranges, y, + binlabels ? _fbb.CreateVector>( + *binlabels) + : 0); } -struct UISummaryStatistics FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { +struct UISummaryStatistics FLATBUFFERS_FINAL_CLASS + : private flatbuffers::Table { enum { VT_BITMASK = 4, VT_MIN = 6, @@ -566,51 +501,32 @@ struct UISummaryStatistics FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table VT_COUNTNAN = 20, VT_COUNTINF = 22 }; - uint32_t bitmask() const { - return GetField(VT_BITMASK, 0); - } - const FlatArray *min() const { - return GetPointer(VT_MIN); - } - const FlatArray *max() const { - return GetPointer(VT_MAX); - } - double mean() const { - return GetField(VT_MEAN, 0.0); - } - double stdev() const { - return GetField(VT_STDEV, 0.0); - } - int64_t countzero() const { - return GetField(VT_COUNTZERO, 0); - } + uint32_t bitmask() const { return GetField(VT_BITMASK, 0); } + const FlatArray *min() const { return GetPointer(VT_MIN); } + const FlatArray *max() const { return GetPointer(VT_MAX); } + double mean() const { return GetField(VT_MEAN, 0.0); } + double stdev() const { return GetField(VT_STDEV, 0.0); } + int64_t countzero() const { return GetField(VT_COUNTZERO, 0); } int64_t countpositive() const { return GetField(VT_COUNTPOSITIVE, 0); } int64_t countnegative() const { return GetField(VT_COUNTNEGATIVE, 0); } - int64_t countnan() const { - return GetField(VT_COUNTNAN, 0); - } - int64_t countinf() const { - return GetField(VT_COUNTINF, 0); - } + int64_t countnan() const { return GetField(VT_COUNTNAN, 0); } + int64_t countinf() const { return GetField(VT_COUNTINF, 0); } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyField(verifier, VT_BITMASK) && - VerifyOffset(verifier, VT_MIN) && - verifier.VerifyTable(min()) && - VerifyOffset(verifier, VT_MAX) && - verifier.VerifyTable(max()) && + VerifyOffset(verifier, VT_MIN) && verifier.VerifyTable(min()) && + VerifyOffset(verifier, VT_MAX) && verifier.VerifyTable(max()) && VerifyField(verifier, VT_MEAN) && VerifyField(verifier, VT_STDEV) && VerifyField(verifier, VT_COUNTZERO) && VerifyField(verifier, VT_COUNTPOSITIVE) && VerifyField(verifier, VT_COUNTNEGATIVE) && VerifyField(verifier, VT_COUNTNAN) && - VerifyField(verifier, VT_COUNTINF) && - verifier.EndTable(); + VerifyField(verifier, VT_COUNTINF) && verifier.EndTable(); } }; @@ -636,10 +552,12 @@ struct UISummaryStatisticsBuilder { fbb_.AddElement(UISummaryStatistics::VT_COUNTZERO, countzero, 0); } void add_countpositive(int64_t countpositive) { - fbb_.AddElement(UISummaryStatistics::VT_COUNTPOSITIVE, countpositive, 0); + fbb_.AddElement(UISummaryStatistics::VT_COUNTPOSITIVE, + countpositive, 0); } void add_countnegative(int64_t countnegative) { - fbb_.AddElement(UISummaryStatistics::VT_COUNTNEGATIVE, countnegative, 0); + fbb_.AddElement(UISummaryStatistics::VT_COUNTNEGATIVE, + countnegative, 0); } void add_countnan(int64_t countnan) { fbb_.AddElement(UISummaryStatistics::VT_COUNTNAN, countnan, 0); @@ -648,7 +566,7 @@ struct UISummaryStatisticsBuilder { fbb_.AddElement(UISummaryStatistics::VT_COUNTINF, countinf, 0); } explicit UISummaryStatisticsBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + : fbb_(_fbb) { start_ = fbb_.StartTable(); } UISummaryStatisticsBuilder &operator=(const UISummaryStatisticsBuilder &); @@ -660,17 +578,11 @@ struct UISummaryStatisticsBuilder { }; inline flatbuffers::Offset CreateUISummaryStatistics( - flatbuffers::FlatBufferBuilder &_fbb, - uint32_t bitmask = 0, + flatbuffers::FlatBufferBuilder &_fbb, uint32_t bitmask = 0, flatbuffers::Offset min = 0, - flatbuffers::Offset max = 0, - double mean = 0.0, - double stdev = 0.0, - int64_t countzero = 0, - int64_t countpositive = 0, - int64_t countnegative = 0, - int64_t countnan = 0, - int64_t countinf = 0) { + flatbuffers::Offset max = 0, double mean = 0.0, + double stdev = 0.0, int64_t countzero = 0, int64_t countpositive = 0, + int64_t countnegative = 0, int64_t countnan = 0, int64_t countinf = 0) { UISummaryStatisticsBuilder builder_(_fbb); builder_.add_countinf(countinf); builder_.add_countnan(countnan); @@ -686,36 +598,30 @@ inline flatbuffers::Offset CreateUISummaryStatistics( } struct UIHardwareState FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { - enum { - VT_GPUMEMORY = 4, - VT_HOSTMEMORY = 6 - }; + enum { VT_GPUMEMORY = 4, VT_HOSTMEMORY = 6 }; const flatbuffers::Vector *gpuMemory() const { return GetPointer *>(VT_GPUMEMORY); } - int64_t hostMemory() const { - return GetField(VT_HOSTMEMORY, 0); - } + int64_t hostMemory() const { return GetField(VT_HOSTMEMORY, 0); } bool Verify(flatbuffers::Verifier &verifier) const { - return VerifyTableStart(verifier) && - VerifyOffset(verifier, VT_GPUMEMORY) && + return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_GPUMEMORY) && verifier.VerifyVector(gpuMemory()) && - VerifyField(verifier, VT_HOSTMEMORY) && - verifier.EndTable(); + VerifyField(verifier, VT_HOSTMEMORY) && verifier.EndTable(); } }; struct UIHardwareStateBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; - void add_gpuMemory(flatbuffers::Offset> gpuMemory) { + void add_gpuMemory( + flatbuffers::Offset> gpuMemory) { fbb_.AddOffset(UIHardwareState::VT_GPUMEMORY, gpuMemory); } void add_hostMemory(int64_t hostMemory) { fbb_.AddElement(UIHardwareState::VT_HOSTMEMORY, hostMemory, 0); } explicit UIHardwareStateBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + : fbb_(_fbb) { start_ = fbb_.StartTable(); } UIHardwareStateBuilder &operator=(const UIHardwareStateBuilder &); @@ -738,12 +644,9 @@ inline flatbuffers::Offset CreateUIHardwareState( inline flatbuffers::Offset CreateUIHardwareStateDirect( flatbuffers::FlatBufferBuilder &_fbb, - const std::vector *gpuMemory = nullptr, - int64_t hostMemory = 0) { + const std::vector *gpuMemory = nullptr, int64_t hostMemory = 0) { return sd::graph::CreateUIHardwareState( - _fbb, - gpuMemory ? _fbb.CreateVector(*gpuMemory) : 0, - hostMemory); + _fbb, gpuMemory ? _fbb.CreateVector(*gpuMemory) : 0, hostMemory); } } // namespace graph diff --git a/libnd4j/include/graph/generated/uigraphstatic_generated.h b/libnd4j/include/graph/generated/uigraphstatic_generated.h index b6545f53a5e6..e0e0e8a6547d 100644 --- a/libnd4j/include/graph/generated/uigraphstatic_generated.h +++ b/libnd4j/include/graph/generated/uigraphstatic_generated.h @@ -1,12 +1,10 @@ // automatically generated by the FlatBuffers compiler, do not modify - #ifndef FLATBUFFERS_GENERATED_UIGRAPHSTATIC_ND4J_GRAPH_H_ #define FLATBUFFERS_GENERATED_UIGRAPHSTATIC_ND4J_GRAPH_H_ -#include "flatbuffers/flatbuffers.h" - #include "array_generated.h" +#include "flatbuffers/flatbuffers.h" #include "utils_generated.h" #include "variable_generated.h" @@ -32,21 +30,15 @@ enum UIInfoType { }; inline const UIInfoType (&EnumValuesUIInfoType())[3] { - static const UIInfoType values[] = { - UIInfoType_GRAPH_STRUCTURE, - UIInfoType_SYTEM_INFO, - UIInfoType_START_EVENTS - }; + static const UIInfoType values[] = {UIInfoType_GRAPH_STRUCTURE, + UIInfoType_SYTEM_INFO, + UIInfoType_START_EVENTS}; return values; } -inline const char * const *EnumNamesUIInfoType() { - static const char * const names[] = { - "GRAPH_STRUCTURE", - "SYTEM_INFO", - "START_EVENTS", - nullptr - }; +inline const char *const *EnumNamesUIInfoType() { + static const char *const names[] = {"GRAPH_STRUCTURE", "SYTEM_INFO", + "START_EVENTS", nullptr}; return names; } @@ -56,16 +48,13 @@ inline const char *EnumNameUIInfoType(UIInfoType e) { } struct UIStaticInfoRecord FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { - enum { - VT_INFOTYPE = 4 - }; + enum { VT_INFOTYPE = 4 }; UIInfoType infoType() const { return static_cast(GetField(VT_INFOTYPE, 0)); } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && - VerifyField(verifier, VT_INFOTYPE) && - verifier.EndTable(); + VerifyField(verifier, VT_INFOTYPE) && verifier.EndTable(); } }; @@ -73,10 +62,11 @@ struct UIStaticInfoRecordBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; void add_infoType(UIInfoType infoType) { - fbb_.AddElement(UIStaticInfoRecord::VT_INFOTYPE, static_cast(infoType), 0); + fbb_.AddElement(UIStaticInfoRecord::VT_INFOTYPE, + static_cast(infoType), 0); } explicit UIStaticInfoRecordBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + : fbb_(_fbb) { start_ = fbb_.StartTable(); } UIStaticInfoRecordBuilder &operator=(const UIStaticInfoRecordBuilder &); @@ -96,9 +86,7 @@ inline flatbuffers::Offset CreateUIStaticInfoRecord( } struct UISystemInfo FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { - enum { - VT_PHYSICALCORES = 4 - }; + enum { VT_PHYSICALCORES = 4 }; int32_t physicalCores() const { return GetField(VT_PHYSICALCORES, 0); } @@ -116,7 +104,7 @@ struct UISystemInfoBuilder { fbb_.AddElement(UISystemInfo::VT_PHYSICALCORES, physicalCores, 0); } explicit UISystemInfoBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + : fbb_(_fbb) { start_ = fbb_.StartTable(); } UISystemInfoBuilder &operator=(const UISystemInfoBuilder &); @@ -128,8 +116,7 @@ struct UISystemInfoBuilder { }; inline flatbuffers::Offset CreateUISystemInfo( - flatbuffers::FlatBufferBuilder &_fbb, - int32_t physicalCores = 0) { + flatbuffers::FlatBufferBuilder &_fbb, int32_t physicalCores = 0) { UISystemInfoBuilder builder_(_fbb); builder_.add_physicalCores(physicalCores); return builder_.Finish(); @@ -143,24 +130,35 @@ struct UIGraphStructure FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { VT_VARIABLES = 10, VT_OPS = 12 }; - const flatbuffers::Vector> *inputs() const { - return GetPointer> *>(VT_INPUTS); + const flatbuffers::Vector> *inputs() + const { + return GetPointer< + const flatbuffers::Vector> *>( + VT_INPUTS); } const flatbuffers::Vector> *inputsPair() const { - return GetPointer> *>(VT_INPUTSPAIR); - } - const flatbuffers::Vector> *outputs() const { - return GetPointer> *>(VT_OUTPUTS); - } - const flatbuffers::Vector> *variables() const { - return GetPointer> *>(VT_VARIABLES); + return GetPointer< + const flatbuffers::Vector> *>( + VT_INPUTSPAIR); + } + const flatbuffers::Vector> *outputs() + const { + return GetPointer< + const flatbuffers::Vector> *>( + VT_OUTPUTS); + } + const flatbuffers::Vector> *variables() + const { + return GetPointer< + const flatbuffers::Vector> *>( + VT_VARIABLES); } const flatbuffers::Vector> *ops() const { - return GetPointer> *>(VT_OPS); + return GetPointer> *>( + VT_OPS); } bool Verify(flatbuffers::Verifier &verifier) const { - return VerifyTableStart(verifier) && - VerifyOffset(verifier, VT_INPUTS) && + return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_INPUTS) && verifier.VerifyVector(inputs()) && verifier.VerifyVectorOfStrings(inputs()) && VerifyOffset(verifier, VT_INPUTSPAIR) && @@ -172,33 +170,41 @@ struct UIGraphStructure FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { VerifyOffset(verifier, VT_VARIABLES) && verifier.VerifyVector(variables()) && verifier.VerifyVectorOfTables(variables()) && - VerifyOffset(verifier, VT_OPS) && - verifier.VerifyVector(ops()) && - verifier.VerifyVectorOfTables(ops()) && - verifier.EndTable(); + VerifyOffset(verifier, VT_OPS) && verifier.VerifyVector(ops()) && + verifier.VerifyVectorOfTables(ops()) && verifier.EndTable(); } }; struct UIGraphStructureBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; - void add_inputs(flatbuffers::Offset>> inputs) { + void add_inputs(flatbuffers::Offset< + flatbuffers::Vector>> + inputs) { fbb_.AddOffset(UIGraphStructure::VT_INPUTS, inputs); } - void add_inputsPair(flatbuffers::Offset>> inputsPair) { + void add_inputsPair( + flatbuffers::Offset>> + inputsPair) { fbb_.AddOffset(UIGraphStructure::VT_INPUTSPAIR, inputsPair); } - void add_outputs(flatbuffers::Offset>> outputs) { + void add_outputs( + flatbuffers::Offset< + flatbuffers::Vector>> + outputs) { fbb_.AddOffset(UIGraphStructure::VT_OUTPUTS, outputs); } - void add_variables(flatbuffers::Offset>> variables) { + void add_variables( + flatbuffers::Offset>> + variables) { fbb_.AddOffset(UIGraphStructure::VT_VARIABLES, variables); } - void add_ops(flatbuffers::Offset>> ops) { + void add_ops( + flatbuffers::Offset>> ops) { fbb_.AddOffset(UIGraphStructure::VT_OPS, ops); } explicit UIGraphStructureBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + : fbb_(_fbb) { start_ = fbb_.StartTable(); } UIGraphStructureBuilder &operator=(const UIGraphStructureBuilder &); @@ -211,11 +217,18 @@ struct UIGraphStructureBuilder { inline flatbuffers::Offset CreateUIGraphStructure( flatbuffers::FlatBufferBuilder &_fbb, - flatbuffers::Offset>> inputs = 0, - flatbuffers::Offset>> inputsPair = 0, - flatbuffers::Offset>> outputs = 0, - flatbuffers::Offset>> variables = 0, - flatbuffers::Offset>> ops = 0) { + flatbuffers::Offset< + flatbuffers::Vector>> + inputs = 0, + flatbuffers::Offset>> + inputsPair = 0, + flatbuffers::Offset< + flatbuffers::Vector>> + outputs = 0, + flatbuffers::Offset>> + variables = 0, + flatbuffers::Offset>> ops = + 0) { UIGraphStructureBuilder builder_(_fbb); builder_.add_ops(ops); builder_.add_variables(variables); @@ -227,17 +240,25 @@ inline flatbuffers::Offset CreateUIGraphStructure( inline flatbuffers::Offset CreateUIGraphStructureDirect( flatbuffers::FlatBufferBuilder &_fbb, - const std::vector> *inputs = nullptr, + const std::vector> *inputs = + nullptr, const std::vector> *inputsPair = nullptr, - const std::vector> *outputs = nullptr, + const std::vector> *outputs = + nullptr, const std::vector> *variables = nullptr, const std::vector> *ops = nullptr) { return sd::graph::CreateUIGraphStructure( _fbb, - inputs ? _fbb.CreateVector>(*inputs) : 0, - inputsPair ? _fbb.CreateVector>(*inputsPair) : 0, - outputs ? _fbb.CreateVector>(*outputs) : 0, - variables ? _fbb.CreateVector>(*variables) : 0, + inputs + ? _fbb.CreateVector>(*inputs) + : 0, + inputsPair ? _fbb.CreateVector>(*inputsPair) + : 0, + outputs ? _fbb.CreateVector>( + *outputs) + : 0, + variables ? _fbb.CreateVector>(*variables) + : 0, ops ? _fbb.CreateVector>(*ops) : 0); } @@ -257,9 +278,7 @@ struct UIVariable FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { VT_UILABELEXTRA = 26, VT_CONSTANTVALUE = 28 }; - const IntPair *id() const { - return GetPointer(VT_ID); - } + const IntPair *id() const { return GetPointer(VT_ID); } const flatbuffers::String *name() const { return GetPointer(VT_NAME); } @@ -272,20 +291,32 @@ struct UIVariable FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { const flatbuffers::Vector *shape() const { return GetPointer *>(VT_SHAPE); } - const flatbuffers::Vector> *controlDeps() const { - return GetPointer> *>(VT_CONTROLDEPS); + const flatbuffers::Vector> + *controlDeps() const { + return GetPointer< + const flatbuffers::Vector> *>( + VT_CONTROLDEPS); } const flatbuffers::String *outputOfOp() const { return GetPointer(VT_OUTPUTOFOP); } - const flatbuffers::Vector> *inputsForOp() const { - return GetPointer> *>(VT_INPUTSFOROP); + const flatbuffers::Vector> + *inputsForOp() const { + return GetPointer< + const flatbuffers::Vector> *>( + VT_INPUTSFOROP); } - const flatbuffers::Vector> *controlDepsForOp() const { - return GetPointer> *>(VT_CONTROLDEPSFOROP); + const flatbuffers::Vector> + *controlDepsForOp() const { + return GetPointer< + const flatbuffers::Vector> *>( + VT_CONTROLDEPSFOROP); } - const flatbuffers::Vector> *controlDepsForVar() const { - return GetPointer> *>(VT_CONTROLDEPSFORVAR); + const flatbuffers::Vector> + *controlDepsForVar() const { + return GetPointer< + const flatbuffers::Vector> *>( + VT_CONTROLDEPSFORVAR); } const flatbuffers::String *gradientVariable() const { return GetPointer(VT_GRADIENTVARIABLE); @@ -297,15 +328,12 @@ struct UIVariable FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { return GetPointer(VT_CONSTANTVALUE); } bool Verify(flatbuffers::Verifier &verifier) const { - return VerifyTableStart(verifier) && - VerifyOffset(verifier, VT_ID) && - verifier.VerifyTable(id()) && - VerifyOffset(verifier, VT_NAME) && + return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_ID) && + verifier.VerifyTable(id()) && VerifyOffset(verifier, VT_NAME) && verifier.VerifyString(name()) && VerifyField(verifier, VT_TYPE) && VerifyField(verifier, VT_DATATYPE) && - VerifyOffset(verifier, VT_SHAPE) && - verifier.VerifyVector(shape()) && + VerifyOffset(verifier, VT_SHAPE) && verifier.VerifyVector(shape()) && VerifyOffset(verifier, VT_CONTROLDEPS) && verifier.VerifyVector(controlDeps()) && verifier.VerifyVectorOfStrings(controlDeps()) && @@ -325,8 +353,7 @@ struct UIVariable FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { VerifyOffset(verifier, VT_UILABELEXTRA) && verifier.VerifyString(uiLabelExtra()) && VerifyOffset(verifier, VT_CONSTANTVALUE) && - verifier.VerifyTable(constantValue()) && - verifier.EndTable(); + verifier.VerifyTable(constantValue()) && verifier.EndTable(); } }; @@ -343,27 +370,41 @@ struct UIVariableBuilder { fbb_.AddElement(UIVariable::VT_TYPE, static_cast(type), 0); } void add_datatype(DType datatype) { - fbb_.AddElement(UIVariable::VT_DATATYPE, static_cast(datatype), 0); + fbb_.AddElement(UIVariable::VT_DATATYPE, + static_cast(datatype), 0); } void add_shape(flatbuffers::Offset> shape) { fbb_.AddOffset(UIVariable::VT_SHAPE, shape); } - void add_controlDeps(flatbuffers::Offset>> controlDeps) { + void add_controlDeps( + flatbuffers::Offset< + flatbuffers::Vector>> + controlDeps) { fbb_.AddOffset(UIVariable::VT_CONTROLDEPS, controlDeps); } void add_outputOfOp(flatbuffers::Offset outputOfOp) { fbb_.AddOffset(UIVariable::VT_OUTPUTOFOP, outputOfOp); } - void add_inputsForOp(flatbuffers::Offset>> inputsForOp) { + void add_inputsForOp( + flatbuffers::Offset< + flatbuffers::Vector>> + inputsForOp) { fbb_.AddOffset(UIVariable::VT_INPUTSFOROP, inputsForOp); } - void add_controlDepsForOp(flatbuffers::Offset>> controlDepsForOp) { + void add_controlDepsForOp( + flatbuffers::Offset< + flatbuffers::Vector>> + controlDepsForOp) { fbb_.AddOffset(UIVariable::VT_CONTROLDEPSFOROP, controlDepsForOp); } - void add_controlDepsForVar(flatbuffers::Offset>> controlDepsForVar) { + void add_controlDepsForVar( + flatbuffers::Offset< + flatbuffers::Vector>> + controlDepsForVar) { fbb_.AddOffset(UIVariable::VT_CONTROLDEPSFORVAR, controlDepsForVar); } - void add_gradientVariable(flatbuffers::Offset gradientVariable) { + void add_gradientVariable( + flatbuffers::Offset gradientVariable) { fbb_.AddOffset(UIVariable::VT_GRADIENTVARIABLE, gradientVariable); } void add_uiLabelExtra(flatbuffers::Offset uiLabelExtra) { @@ -373,7 +414,7 @@ struct UIVariableBuilder { fbb_.AddOffset(UIVariable::VT_CONSTANTVALUE, constantValue); } explicit UIVariableBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + : fbb_(_fbb) { start_ = fbb_.StartTable(); } UIVariableBuilder &operator=(const UIVariableBuilder &); @@ -385,17 +426,23 @@ struct UIVariableBuilder { }; inline flatbuffers::Offset CreateUIVariable( - flatbuffers::FlatBufferBuilder &_fbb, - flatbuffers::Offset id = 0, + flatbuffers::FlatBufferBuilder &_fbb, flatbuffers::Offset id = 0, flatbuffers::Offset name = 0, - VarType type = VarType_VARIABLE, - DType datatype = DType_INHERIT, + VarType type = VarType_VARIABLE, DType datatype = DType_INHERIT, flatbuffers::Offset> shape = 0, - flatbuffers::Offset>> controlDeps = 0, + flatbuffers::Offset< + flatbuffers::Vector>> + controlDeps = 0, flatbuffers::Offset outputOfOp = 0, - flatbuffers::Offset>> inputsForOp = 0, - flatbuffers::Offset>> controlDepsForOp = 0, - flatbuffers::Offset>> controlDepsForVar = 0, + flatbuffers::Offset< + flatbuffers::Vector>> + inputsForOp = 0, + flatbuffers::Offset< + flatbuffers::Vector>> + controlDepsForOp = 0, + flatbuffers::Offset< + flatbuffers::Vector>> + controlDepsForVar = 0, flatbuffers::Offset gradientVariable = 0, flatbuffers::Offset uiLabelExtra = 0, flatbuffers::Offset constantValue = 0) { @@ -417,35 +464,40 @@ inline flatbuffers::Offset CreateUIVariable( } inline flatbuffers::Offset CreateUIVariableDirect( - flatbuffers::FlatBufferBuilder &_fbb, - flatbuffers::Offset id = 0, - const char *name = nullptr, - VarType type = VarType_VARIABLE, - DType datatype = DType_INHERIT, - const std::vector *shape = nullptr, - const std::vector> *controlDeps = nullptr, + flatbuffers::FlatBufferBuilder &_fbb, flatbuffers::Offset id = 0, + const char *name = nullptr, VarType type = VarType_VARIABLE, + DType datatype = DType_INHERIT, const std::vector *shape = nullptr, + const std::vector> *controlDeps = + nullptr, const char *outputOfOp = nullptr, - const std::vector> *inputsForOp = nullptr, - const std::vector> *controlDepsForOp = nullptr, - const std::vector> *controlDepsForVar = nullptr, - const char *gradientVariable = nullptr, - const char *uiLabelExtra = nullptr, + const std::vector> *inputsForOp = + nullptr, + const std::vector> + *controlDepsForOp = nullptr, + const std::vector> + *controlDepsForVar = nullptr, + const char *gradientVariable = nullptr, const char *uiLabelExtra = nullptr, flatbuffers::Offset constantValue = 0) { return sd::graph::CreateUIVariable( - _fbb, - id, - name ? _fbb.CreateString(name) : 0, - type, - datatype, + _fbb, id, name ? _fbb.CreateString(name) : 0, type, datatype, shape ? _fbb.CreateVector(*shape) : 0, - controlDeps ? _fbb.CreateVector>(*controlDeps) : 0, + controlDeps ? _fbb.CreateVector>( + *controlDeps) + : 0, outputOfOp ? _fbb.CreateString(outputOfOp) : 0, - inputsForOp ? _fbb.CreateVector>(*inputsForOp) : 0, - controlDepsForOp ? _fbb.CreateVector>(*controlDepsForOp) : 0, - controlDepsForVar ? _fbb.CreateVector>(*controlDepsForVar) : 0, + inputsForOp ? _fbb.CreateVector>( + *inputsForOp) + : 0, + controlDepsForOp + ? _fbb.CreateVector>( + *controlDepsForOp) + : 0, + controlDepsForVar + ? _fbb.CreateVector>( + *controlDepsForVar) + : 0, gradientVariable ? _fbb.CreateString(gradientVariable) : 0, - uiLabelExtra ? _fbb.CreateString(uiLabelExtra) : 0, - constantValue); + uiLabelExtra ? _fbb.CreateString(uiLabelExtra) : 0, constantValue); } struct UIOp FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { @@ -463,23 +515,30 @@ struct UIOp FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { const flatbuffers::String *opName() const { return GetPointer(VT_OPNAME); } - const flatbuffers::Vector> *inputs() const { - return GetPointer> *>(VT_INPUTS); + const flatbuffers::Vector> *inputs() + const { + return GetPointer< + const flatbuffers::Vector> *>( + VT_INPUTS); } - const flatbuffers::Vector> *outputs() const { - return GetPointer> *>(VT_OUTPUTS); + const flatbuffers::Vector> *outputs() + const { + return GetPointer< + const flatbuffers::Vector> *>( + VT_OUTPUTS); } - const flatbuffers::Vector> *controlDeps() const { - return GetPointer> *>(VT_CONTROLDEPS); + const flatbuffers::Vector> + *controlDeps() const { + return GetPointer< + const flatbuffers::Vector> *>( + VT_CONTROLDEPS); } const flatbuffers::String *uiLabelExtra() const { return GetPointer(VT_UILABELEXTRA); } bool Verify(flatbuffers::Verifier &verifier) const { - return VerifyTableStart(verifier) && - VerifyOffset(verifier, VT_NAME) && - verifier.VerifyString(name()) && - VerifyOffset(verifier, VT_OPNAME) && + return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_NAME) && + verifier.VerifyString(name()) && VerifyOffset(verifier, VT_OPNAME) && verifier.VerifyString(opName()) && VerifyOffset(verifier, VT_INPUTS) && verifier.VerifyVector(inputs()) && @@ -491,8 +550,7 @@ struct UIOp FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { verifier.VerifyVector(controlDeps()) && verifier.VerifyVectorOfStrings(controlDeps()) && VerifyOffset(verifier, VT_UILABELEXTRA) && - verifier.VerifyString(uiLabelExtra()) && - verifier.EndTable(); + verifier.VerifyString(uiLabelExtra()) && verifier.EndTable(); } }; @@ -505,20 +563,27 @@ struct UIOpBuilder { void add_opName(flatbuffers::Offset opName) { fbb_.AddOffset(UIOp::VT_OPNAME, opName); } - void add_inputs(flatbuffers::Offset>> inputs) { + void add_inputs(flatbuffers::Offset< + flatbuffers::Vector>> + inputs) { fbb_.AddOffset(UIOp::VT_INPUTS, inputs); } - void add_outputs(flatbuffers::Offset>> outputs) { + void add_outputs( + flatbuffers::Offset< + flatbuffers::Vector>> + outputs) { fbb_.AddOffset(UIOp::VT_OUTPUTS, outputs); } - void add_controlDeps(flatbuffers::Offset>> controlDeps) { + void add_controlDeps( + flatbuffers::Offset< + flatbuffers::Vector>> + controlDeps) { fbb_.AddOffset(UIOp::VT_CONTROLDEPS, controlDeps); } void add_uiLabelExtra(flatbuffers::Offset uiLabelExtra) { fbb_.AddOffset(UIOp::VT_UILABELEXTRA, uiLabelExtra); } - explicit UIOpBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + explicit UIOpBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); } UIOpBuilder &operator=(const UIOpBuilder &); @@ -533,9 +598,15 @@ inline flatbuffers::Offset CreateUIOp( flatbuffers::FlatBufferBuilder &_fbb, flatbuffers::Offset name = 0, flatbuffers::Offset opName = 0, - flatbuffers::Offset>> inputs = 0, - flatbuffers::Offset>> outputs = 0, - flatbuffers::Offset>> controlDeps = 0, + flatbuffers::Offset< + flatbuffers::Vector>> + inputs = 0, + flatbuffers::Offset< + flatbuffers::Vector>> + outputs = 0, + flatbuffers::Offset< + flatbuffers::Vector>> + controlDeps = 0, flatbuffers::Offset uiLabelExtra = 0) { UIOpBuilder builder_(_fbb); builder_.add_uiLabelExtra(uiLabelExtra); @@ -548,20 +619,27 @@ inline flatbuffers::Offset CreateUIOp( } inline flatbuffers::Offset CreateUIOpDirect( - flatbuffers::FlatBufferBuilder &_fbb, - const char *name = nullptr, + flatbuffers::FlatBufferBuilder &_fbb, const char *name = nullptr, const char *opName = nullptr, - const std::vector> *inputs = nullptr, - const std::vector> *outputs = nullptr, - const std::vector> *controlDeps = nullptr, + const std::vector> *inputs = + nullptr, + const std::vector> *outputs = + nullptr, + const std::vector> *controlDeps = + nullptr, const char *uiLabelExtra = nullptr) { return sd::graph::CreateUIOp( - _fbb, - name ? _fbb.CreateString(name) : 0, + _fbb, name ? _fbb.CreateString(name) : 0, opName ? _fbb.CreateString(opName) : 0, - inputs ? _fbb.CreateVector>(*inputs) : 0, - outputs ? _fbb.CreateVector>(*outputs) : 0, - controlDeps ? _fbb.CreateVector>(*controlDeps) : 0, + inputs + ? _fbb.CreateVector>(*inputs) + : 0, + outputs ? _fbb.CreateVector>( + *outputs) + : 0, + controlDeps ? _fbb.CreateVector>( + *controlDeps) + : 0, uiLabelExtra ? _fbb.CreateString(uiLabelExtra) : 0); } diff --git a/libnd4j/include/graph/generated/utils_generated.h b/libnd4j/include/graph/generated/utils_generated.h index 8e7896bb4956..24853534d86a 100644 --- a/libnd4j/include/graph/generated/utils_generated.h +++ b/libnd4j/include/graph/generated/utils_generated.h @@ -1,6 +1,5 @@ // automatically generated by the FlatBuffers compiler, do not modify - #ifndef FLATBUFFERS_GENERATED_UTILS_ND4J_GRAPH_H_ #define FLATBUFFERS_GENERATED_UTILS_ND4J_GRAPH_H_ @@ -50,160 +49,144 @@ enum OpType { inline const OpType (&EnumValuesOpType())[26] { static const OpType values[] = { - OpType_TRANSFORM_FLOAT, - OpType_TRANSFORM_SAME, - OpType_TRANSFORM_BOOL, - OpType_TRANSFORM_STRICT, - OpType_TRANSFORM_ANY, - OpType_REDUCE_FLOAT, - OpType_REDUCE_SAME, - OpType_REDUCE_LONG, - OpType_REDUCE_BOOL, - OpType_INDEX_REDUCE, - OpType_SCALAR, - OpType_SCALAR_BOOL, - OpType_BROADCAST, - OpType_BROADCAST_BOOL, - OpType_PAIRWISE, - OpType_PAIRWISE_BOOL, - OpType_REDUCE_3, - OpType_SUMMARYSTATS, - OpType_SHAPE, - OpType_AGGREGATION, - OpType_RANDOM, - OpType_CUSTOM, - OpType_GRAPH, - OpType_VARIABLE, - OpType_BOOLEAN, - OpType_LOGIC - }; + OpType_TRANSFORM_FLOAT, OpType_TRANSFORM_SAME, + OpType_TRANSFORM_BOOL, OpType_TRANSFORM_STRICT, + OpType_TRANSFORM_ANY, OpType_REDUCE_FLOAT, + OpType_REDUCE_SAME, OpType_REDUCE_LONG, + OpType_REDUCE_BOOL, OpType_INDEX_REDUCE, + OpType_SCALAR, OpType_SCALAR_BOOL, + OpType_BROADCAST, OpType_BROADCAST_BOOL, + OpType_PAIRWISE, OpType_PAIRWISE_BOOL, + OpType_REDUCE_3, OpType_SUMMARYSTATS, + OpType_SHAPE, OpType_AGGREGATION, + OpType_RANDOM, OpType_CUSTOM, + OpType_GRAPH, OpType_VARIABLE, + OpType_BOOLEAN, OpType_LOGIC}; return values; } -inline const char * const *EnumNamesOpType() { - static const char * const names[] = { - "TRANSFORM_FLOAT", - "TRANSFORM_SAME", - "TRANSFORM_BOOL", - "TRANSFORM_STRICT", - "TRANSFORM_ANY", - "REDUCE_FLOAT", - "REDUCE_SAME", - "REDUCE_LONG", - "REDUCE_BOOL", - "INDEX_REDUCE", - "SCALAR", - "SCALAR_BOOL", - "BROADCAST", - "BROADCAST_BOOL", - "PAIRWISE", - "PAIRWISE_BOOL", - "REDUCE_3", - "SUMMARYSTATS", - "SHAPE", - "AGGREGATION", - "RANDOM", - "CUSTOM", - "GRAPH", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "VARIABLE", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "BOOLEAN", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "LOGIC", - nullptr - }; +inline const char *const *EnumNamesOpType() { + static const char *const names[] = {"TRANSFORM_FLOAT", + "TRANSFORM_SAME", + "TRANSFORM_BOOL", + "TRANSFORM_STRICT", + "TRANSFORM_ANY", + "REDUCE_FLOAT", + "REDUCE_SAME", + "REDUCE_LONG", + "REDUCE_BOOL", + "INDEX_REDUCE", + "SCALAR", + "SCALAR_BOOL", + "BROADCAST", + "BROADCAST_BOOL", + "PAIRWISE", + "PAIRWISE_BOOL", + "REDUCE_3", + "SUMMARYSTATS", + "SHAPE", + "AGGREGATION", + "RANDOM", + "CUSTOM", + "GRAPH", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "VARIABLE", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "BOOLEAN", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "LOGIC", + nullptr}; return names; } @@ -224,24 +207,15 @@ enum InputType { inline const InputType (&EnumValuesInputType())[5] { static const InputType values[] = { - InputType_UNDEFINED, - InputType_NUMERIC, - InputType_STRINGULAR, - InputType_NUMERIC_SET, - InputType_STRINGULAR_SET - }; + InputType_UNDEFINED, InputType_NUMERIC, InputType_STRINGULAR, + InputType_NUMERIC_SET, InputType_STRINGULAR_SET}; return values; } -inline const char * const *EnumNamesInputType() { - static const char * const names[] = { - "UNDEFINED", - "NUMERIC", - "STRINGULAR", - "NUMERIC_SET", - "STRINGULAR_SET", - nullptr - }; +inline const char *const *EnumNamesInputType() { + static const char *const names[] = {"UNDEFINED", "NUMERIC", + "STRINGULAR", "NUMERIC_SET", + "STRINGULAR_SET", nullptr}; return names; } @@ -262,27 +236,16 @@ enum OpClass { }; inline const OpClass (&EnumValuesOpClass())[6] { - static const OpClass values[] = { - OpClass_TRANSFORM, - OpClass_REDUCTION, - OpClass_MULTIPLICATOR, - OpClass_GRAPH, - OpClass_CONDITIONAL, - OpClass_LOOP - }; + static const OpClass values[] = {OpClass_TRANSFORM, OpClass_REDUCTION, + OpClass_MULTIPLICATOR, OpClass_GRAPH, + OpClass_CONDITIONAL, OpClass_LOOP}; return values; } -inline const char * const *EnumNamesOpClass() { - static const char * const names[] = { - "TRANSFORM", - "REDUCTION", - "MULTIPLICATOR", - "GRAPH", - "CONDITIONAL", - "LOOP", - nullptr - }; +inline const char *const *EnumNamesOpClass() { + static const char *const names[] = { + "TRANSFORM", "REDUCTION", "MULTIPLICATOR", "GRAPH", + "CONDITIONAL", "LOOP", nullptr}; return names; } @@ -292,21 +255,13 @@ inline const char *EnumNameOpClass(OpClass e) { } struct LongPair FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { - enum { - VT_FIRST = 4, - VT_SECOND = 6 - }; - int64_t first() const { - return GetField(VT_FIRST, 0); - } - int64_t second() const { - return GetField(VT_SECOND, 0); - } + enum { VT_FIRST = 4, VT_SECOND = 6 }; + int64_t first() const { return GetField(VT_FIRST, 0); } + int64_t second() const { return GetField(VT_SECOND, 0); } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyField(verifier, VT_FIRST) && - VerifyField(verifier, VT_SECOND) && - verifier.EndTable(); + VerifyField(verifier, VT_SECOND) && verifier.EndTable(); } }; @@ -319,8 +274,7 @@ struct LongPairBuilder { void add_second(int64_t second) { fbb_.AddElement(LongPair::VT_SECOND, second, 0); } - explicit LongPairBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + explicit LongPairBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); } LongPairBuilder &operator=(const LongPairBuilder &); @@ -332,8 +286,7 @@ struct LongPairBuilder { }; inline flatbuffers::Offset CreateLongPair( - flatbuffers::FlatBufferBuilder &_fbb, - int64_t first = 0, + flatbuffers::FlatBufferBuilder &_fbb, int64_t first = 0, int64_t second = 0) { LongPairBuilder builder_(_fbb); builder_.add_second(second); @@ -342,26 +295,15 @@ inline flatbuffers::Offset CreateLongPair( } struct LongTriple FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { - enum { - VT_FIRST = 4, - VT_SECOND = 6, - VT_THIRD = 8 - }; - int64_t first() const { - return GetField(VT_FIRST, 0); - } - int64_t second() const { - return GetField(VT_SECOND, 0); - } - int64_t third() const { - return GetField(VT_THIRD, 0); - } + enum { VT_FIRST = 4, VT_SECOND = 6, VT_THIRD = 8 }; + int64_t first() const { return GetField(VT_FIRST, 0); } + int64_t second() const { return GetField(VT_SECOND, 0); } + int64_t third() const { return GetField(VT_THIRD, 0); } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyField(verifier, VT_FIRST) && VerifyField(verifier, VT_SECOND) && - VerifyField(verifier, VT_THIRD) && - verifier.EndTable(); + VerifyField(verifier, VT_THIRD) && verifier.EndTable(); } }; @@ -378,7 +320,7 @@ struct LongTripleBuilder { fbb_.AddElement(LongTriple::VT_THIRD, third, 0); } explicit LongTripleBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + : fbb_(_fbb) { start_ = fbb_.StartTable(); } LongTripleBuilder &operator=(const LongTripleBuilder &); @@ -390,9 +332,7 @@ struct LongTripleBuilder { }; inline flatbuffers::Offset CreateLongTriple( - flatbuffers::FlatBufferBuilder &_fbb, - int64_t first = 0, - int64_t second = 0, + flatbuffers::FlatBufferBuilder &_fbb, int64_t first = 0, int64_t second = 0, int64_t third = 0) { LongTripleBuilder builder_(_fbb); builder_.add_third(third); @@ -402,21 +342,13 @@ inline flatbuffers::Offset CreateLongTriple( } struct IntPair FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { - enum { - VT_FIRST = 4, - VT_SECOND = 6 - }; - int32_t first() const { - return GetField(VT_FIRST, 0); - } - int32_t second() const { - return GetField(VT_SECOND, 0); - } + enum { VT_FIRST = 4, VT_SECOND = 6 }; + int32_t first() const { return GetField(VT_FIRST, 0); } + int32_t second() const { return GetField(VT_SECOND, 0); } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyField(verifier, VT_FIRST) && - VerifyField(verifier, VT_SECOND) && - verifier.EndTable(); + VerifyField(verifier, VT_SECOND) && verifier.EndTable(); } }; @@ -429,8 +361,7 @@ struct IntPairBuilder { void add_second(int32_t second) { fbb_.AddElement(IntPair::VT_SECOND, second, 0); } - explicit IntPairBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + explicit IntPairBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); } IntPairBuilder &operator=(const IntPairBuilder &); @@ -442,8 +373,7 @@ struct IntPairBuilder { }; inline flatbuffers::Offset CreateIntPair( - flatbuffers::FlatBufferBuilder &_fbb, - int32_t first = 0, + flatbuffers::FlatBufferBuilder &_fbb, int32_t first = 0, int32_t second = 0) { IntPairBuilder builder_(_fbb); builder_.add_second(second); @@ -452,26 +382,15 @@ inline flatbuffers::Offset CreateIntPair( } struct IntTriple FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { - enum { - VT_FIRST = 4, - VT_SECOND = 6, - VT_THIRD = 8 - }; - int32_t first() const { - return GetField(VT_FIRST, 0); - } - int32_t second() const { - return GetField(VT_SECOND, 0); - } - int32_t third() const { - return GetField(VT_THIRD, 0); - } + enum { VT_FIRST = 4, VT_SECOND = 6, VT_THIRD = 8 }; + int32_t first() const { return GetField(VT_FIRST, 0); } + int32_t second() const { return GetField(VT_SECOND, 0); } + int32_t third() const { return GetField(VT_THIRD, 0); } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyField(verifier, VT_FIRST) && VerifyField(verifier, VT_SECOND) && - VerifyField(verifier, VT_THIRD) && - verifier.EndTable(); + VerifyField(verifier, VT_THIRD) && verifier.EndTable(); } }; @@ -487,8 +406,7 @@ struct IntTripleBuilder { void add_third(int32_t third) { fbb_.AddElement(IntTriple::VT_THIRD, third, 0); } - explicit IntTripleBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + explicit IntTripleBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); } IntTripleBuilder &operator=(const IntTripleBuilder &); @@ -500,9 +418,7 @@ struct IntTripleBuilder { }; inline flatbuffers::Offset CreateIntTriple( - flatbuffers::FlatBufferBuilder &_fbb, - int32_t first = 0, - int32_t second = 0, + flatbuffers::FlatBufferBuilder &_fbb, int32_t first = 0, int32_t second = 0, int32_t third = 0) { IntTripleBuilder builder_(_fbb); builder_.add_third(third); diff --git a/libnd4j/include/graph/generated/variable_generated.h b/libnd4j/include/graph/generated/variable_generated.h index a0e43a5af7b3..61fe314ee156 100644 --- a/libnd4j/include/graph/generated/variable_generated.h +++ b/libnd4j/include/graph/generated/variable_generated.h @@ -1,12 +1,10 @@ // automatically generated by the FlatBuffers compiler, do not modify - #ifndef FLATBUFFERS_GENERATED_VARIABLE_ND4J_GRAPH_H_ #define FLATBUFFERS_GENERATED_VARIABLE_ND4J_GRAPH_H_ -#include "flatbuffers/flatbuffers.h" - #include "array_generated.h" +#include "flatbuffers/flatbuffers.h" #include "utils_generated.h" namespace sd { @@ -24,23 +22,14 @@ enum VarType { }; inline const VarType (&EnumValuesVarType())[4] { - static const VarType values[] = { - VarType_VARIABLE, - VarType_CONSTANT, - VarType_ARRAY, - VarType_PLACEHOLDER - }; + static const VarType values[] = {VarType_VARIABLE, VarType_CONSTANT, + VarType_ARRAY, VarType_PLACEHOLDER}; return values; } -inline const char * const *EnumNamesVarType() { - static const char * const names[] = { - "VARIABLE", - "CONSTANT", - "ARRAY", - "PLACEHOLDER", - nullptr - }; +inline const char *const *EnumNamesVarType() { + static const char *const names[] = {"VARIABLE", "CONSTANT", "ARRAY", + "PLACEHOLDER", nullptr}; return names; } @@ -62,9 +51,7 @@ struct FlatVariable FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { VT_CONTROLDEPFOROP = 20, VT_CONTROLDEPSFORVAR = 22 }; - const IntPair *id() const { - return GetPointer(VT_ID); - } + const IntPair *id() const { return GetPointer(VT_ID); } const flatbuffers::String *name() const { return GetPointer(VT_NAME); } @@ -77,30 +64,34 @@ struct FlatVariable FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { const FlatArray *ndarray() const { return GetPointer(VT_NDARRAY); } - int32_t device() const { - return GetField(VT_DEVICE, 0); - } + int32_t device() const { return GetField(VT_DEVICE, 0); } VarType variabletype() const { return static_cast(GetField(VT_VARIABLETYPE, 0)); } - const flatbuffers::Vector> *controlDeps() const { - return GetPointer> *>(VT_CONTROLDEPS); + const flatbuffers::Vector> + *controlDeps() const { + return GetPointer< + const flatbuffers::Vector> *>( + VT_CONTROLDEPS); } - const flatbuffers::Vector> *controlDepForOp() const { - return GetPointer> *>(VT_CONTROLDEPFOROP); + const flatbuffers::Vector> + *controlDepForOp() const { + return GetPointer< + const flatbuffers::Vector> *>( + VT_CONTROLDEPFOROP); } - const flatbuffers::Vector> *controlDepsForVar() const { - return GetPointer> *>(VT_CONTROLDEPSFORVAR); + const flatbuffers::Vector> + *controlDepsForVar() const { + return GetPointer< + const flatbuffers::Vector> *>( + VT_CONTROLDEPSFORVAR); } bool Verify(flatbuffers::Verifier &verifier) const { - return VerifyTableStart(verifier) && - VerifyOffset(verifier, VT_ID) && - verifier.VerifyTable(id()) && - VerifyOffset(verifier, VT_NAME) && + return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_ID) && + verifier.VerifyTable(id()) && VerifyOffset(verifier, VT_NAME) && verifier.VerifyString(name()) && VerifyField(verifier, VT_DTYPE) && - VerifyOffset(verifier, VT_SHAPE) && - verifier.VerifyVector(shape()) && + VerifyOffset(verifier, VT_SHAPE) && verifier.VerifyVector(shape()) && VerifyOffset(verifier, VT_NDARRAY) && verifier.VerifyTable(ndarray()) && VerifyField(verifier, VT_DEVICE) && @@ -128,7 +119,8 @@ struct FlatVariableBuilder { fbb_.AddOffset(FlatVariable::VT_NAME, name); } void add_dtype(DType dtype) { - fbb_.AddElement(FlatVariable::VT_DTYPE, static_cast(dtype), 0); + fbb_.AddElement(FlatVariable::VT_DTYPE, static_cast(dtype), + 0); } void add_shape(flatbuffers::Offset> shape) { fbb_.AddOffset(FlatVariable::VT_SHAPE, shape); @@ -140,19 +132,29 @@ struct FlatVariableBuilder { fbb_.AddElement(FlatVariable::VT_DEVICE, device, 0); } void add_variabletype(VarType variabletype) { - fbb_.AddElement(FlatVariable::VT_VARIABLETYPE, static_cast(variabletype), 0); + fbb_.AddElement(FlatVariable::VT_VARIABLETYPE, + static_cast(variabletype), 0); } - void add_controlDeps(flatbuffers::Offset>> controlDeps) { + void add_controlDeps( + flatbuffers::Offset< + flatbuffers::Vector>> + controlDeps) { fbb_.AddOffset(FlatVariable::VT_CONTROLDEPS, controlDeps); } - void add_controlDepForOp(flatbuffers::Offset>> controlDepForOp) { + void add_controlDepForOp( + flatbuffers::Offset< + flatbuffers::Vector>> + controlDepForOp) { fbb_.AddOffset(FlatVariable::VT_CONTROLDEPFOROP, controlDepForOp); } - void add_controlDepsForVar(flatbuffers::Offset>> controlDepsForVar) { + void add_controlDepsForVar( + flatbuffers::Offset< + flatbuffers::Vector>> + controlDepsForVar) { fbb_.AddOffset(FlatVariable::VT_CONTROLDEPSFORVAR, controlDepsForVar); } explicit FlatVariableBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + : fbb_(_fbb) { start_ = fbb_.StartTable(); } FlatVariableBuilder &operator=(const FlatVariableBuilder &); @@ -164,17 +166,21 @@ struct FlatVariableBuilder { }; inline flatbuffers::Offset CreateFlatVariable( - flatbuffers::FlatBufferBuilder &_fbb, - flatbuffers::Offset id = 0, + flatbuffers::FlatBufferBuilder &_fbb, flatbuffers::Offset id = 0, flatbuffers::Offset name = 0, DType dtype = DType_INHERIT, flatbuffers::Offset> shape = 0, - flatbuffers::Offset ndarray = 0, - int32_t device = 0, + flatbuffers::Offset ndarray = 0, int32_t device = 0, VarType variabletype = VarType_VARIABLE, - flatbuffers::Offset>> controlDeps = 0, - flatbuffers::Offset>> controlDepForOp = 0, - flatbuffers::Offset>> controlDepsForVar = 0) { + flatbuffers::Offset< + flatbuffers::Vector>> + controlDeps = 0, + flatbuffers::Offset< + flatbuffers::Vector>> + controlDepForOp = 0, + flatbuffers::Offset< + flatbuffers::Vector>> + controlDepsForVar = 0) { FlatVariableBuilder builder_(_fbb); builder_.add_controlDepsForVar(controlDepsForVar); builder_.add_controlDepForOp(controlDepForOp); @@ -190,41 +196,44 @@ inline flatbuffers::Offset CreateFlatVariable( } inline flatbuffers::Offset CreateFlatVariableDirect( - flatbuffers::FlatBufferBuilder &_fbb, - flatbuffers::Offset id = 0, - const char *name = nullptr, - DType dtype = DType_INHERIT, + flatbuffers::FlatBufferBuilder &_fbb, flatbuffers::Offset id = 0, + const char *name = nullptr, DType dtype = DType_INHERIT, const std::vector *shape = nullptr, - flatbuffers::Offset ndarray = 0, - int32_t device = 0, + flatbuffers::Offset ndarray = 0, int32_t device = 0, VarType variabletype = VarType_VARIABLE, - const std::vector> *controlDeps = nullptr, - const std::vector> *controlDepForOp = nullptr, - const std::vector> *controlDepsForVar = nullptr) { + const std::vector> *controlDeps = + nullptr, + const std::vector> + *controlDepForOp = nullptr, + const std::vector> + *controlDepsForVar = nullptr) { return sd::graph::CreateFlatVariable( - _fbb, - id, - name ? _fbb.CreateString(name) : 0, - dtype, - shape ? _fbb.CreateVector(*shape) : 0, - ndarray, - device, + _fbb, id, name ? _fbb.CreateString(name) : 0, dtype, + shape ? _fbb.CreateVector(*shape) : 0, ndarray, device, variabletype, - controlDeps ? _fbb.CreateVector>(*controlDeps) : 0, - controlDepForOp ? _fbb.CreateVector>(*controlDepForOp) : 0, - controlDepsForVar ? _fbb.CreateVector>(*controlDepsForVar) : 0); + controlDeps ? _fbb.CreateVector>( + *controlDeps) + : 0, + controlDepForOp + ? _fbb.CreateVector>( + *controlDepForOp) + : 0, + controlDepsForVar + ? _fbb.CreateVector>( + *controlDepsForVar) + : 0); } inline const sd::graph::FlatVariable *GetFlatVariable(const void *buf) { return flatbuffers::GetRoot(buf); } -inline const sd::graph::FlatVariable *GetSizePrefixedFlatVariable(const void *buf) { +inline const sd::graph::FlatVariable *GetSizePrefixedFlatVariable( + const void *buf) { return flatbuffers::GetSizePrefixedRoot(buf); } -inline bool VerifyFlatVariableBuffer( - flatbuffers::Verifier &verifier) { +inline bool VerifyFlatVariableBuffer(flatbuffers::Verifier &verifier) { return verifier.VerifyBuffer(nullptr); } diff --git a/libnd4j/include/graph/impl/ArgumentsList.cpp b/libnd4j/include/graph/impl/ArgumentsList.cpp index 71ba8c479b91..a73f9aeb52ab 100644 --- a/libnd4j/include/graph/impl/ArgumentsList.cpp +++ b/libnd4j/include/graph/impl/ArgumentsList.cpp @@ -22,25 +22,20 @@ namespace sd { namespace graph { - ArgumentsList::ArgumentsList(std::initializer_list arguments) { - _arguments = arguments; - } - - ArgumentsList::ArgumentsList(std::initializer_list arguments) { - std::vector args(arguments); - for (int e = 0; e < args.size(); e++) { - Pair pair(args[e]); - _arguments.emplace_back(pair); - } +ArgumentsList::ArgumentsList(std::initializer_list arguments) { + _arguments = arguments; +} - } +ArgumentsList::ArgumentsList(std::initializer_list arguments) { + std::vector args(arguments); + for (int e = 0; e < args.size(); e++) { + Pair pair(args[e]); + _arguments.emplace_back(pair); + } +} - int ArgumentsList::size() { - return (int) _arguments.size(); - } +int ArgumentsList::size() { return (int)_arguments.size(); } - Pair& ArgumentsList::at(int index) { - return _arguments.at(index); - } -} -} +Pair& ArgumentsList::at(int index) { return _arguments.at(index); } +} // namespace graph +} // namespace sd diff --git a/libnd4j/include/graph/impl/Context.cpp b/libnd4j/include/graph/impl/Context.cpp index f76f66bbea13..1b92e6f8d3de 100644 --- a/libnd4j/include/graph/impl/Context.cpp +++ b/libnd4j/include/graph/impl/Context.cpp @@ -18,581 +18,518 @@ // @author raver119@gmail.com // +#include #include #include -#include -#include - namespace sd { - namespace graph { - Context::Context(ContextPrototype* prototype, VariableSpace* variableSpace) { - _variableSpace = variableSpace; - _dataType = prototype->dataType(); - - if (prototype != nullptr) { - for (const auto &v: *(prototype->inputs())) { - this->_inputs.push_back(v); - } - - for (const auto &v: *(prototype->getTArguments())) { - this->_tArgs.push_back(v); - } - - for (const auto &v: *(prototype->getIArguments())) { - this->_iArgs.push_back(v); - } - - for (const auto &v: *(prototype->getBArguments())) { - this->_bArgs.push_back(v); - } - - for (const auto &v: *(prototype->getAxis())) { - this->_axis.push_back(v); - } - - this->_opNum = prototype->opNum(); - this->_isInplace = prototype->isInplace(); - this->_nodeId = prototype->nodeId(); - this->_useMKLDNN = prototype->isUseMKLDNN(); - } - - - if (variableSpace != nullptr && variableSpace->launchContext()->getWorkspace() != nullptr) - this->_workspace = variableSpace->launchContext()->getWorkspace(); - } - sd::DataType Context::dataType(int index) { - - return _dataType; - } - - sd::DataType Context::dataType() { - return dataType(0); - } - - void Context::setDataType(int index, sd::DataType type) { - if (this->_dataTypes.size() > (size_t)index) - _dataTypes[index] = type; - _dataType = type; - } - - Context::Context(int nodeId, VariableSpace *variableSpace) { - this->_nodeId = nodeId; - this->_variableSpace = variableSpace; - this->_isInplace = false; - this->_workspace = nullptr; - - this->_executionTime.first = 0; - this->_executionTime.second = 0; - - if (variableSpace != nullptr && variableSpace->launchContext()->getWorkspace() != nullptr) - this->_workspace = variableSpace->launchContext()->getWorkspace(); - } - - Context::Context(int nodeId, VariableSpace *variableSpace, bool isInplace) : Context(nodeId, variableSpace) { - this->_isInplace = isInplace; - } - - Context::~Context() { - this->_iArgs.clear(); - this->_tArgs.clear(); - this->_inputs.clear(); - this->_fastpath_in.clear(); - this->_fastpath_out.clear(); - - for (auto v:_handles) - delete v; - - if (_context != nullptr) - delete _context; - } - - void Context::setTargetEngine(samediff::Engine engine) { - _engine = engine; - } - - bool Context::hasWorkspaceProvided() { - return this->_workspace != nullptr; - } - - void Context::attachWorkspace(sd::memory::Workspace* workspace) { - this->_workspace = workspace; - } - - void Context::setVariableSpace(VariableSpace *variableSpace) { - this->_variableSpace = variableSpace; - } - - void Context::forgetWorkspace() { - _workspace = nullptr; - } - - std::vector& Context::fastpath_in() { - return _fastpath_in; - } - - std::vector& Context::fastpath_out() { - return _fastpath_out; - } - - bool Context::isFastPath() { - auto ie = _fastpath_in.empty(); - auto io = _fastpath_out.empty(); - // two options here. - // either both IN/OUT are filled - auto b1 = (!ie && !io) || (!ie && _isInplace); - - // or at least something is filled, and FastPath is NOT forbidden - auto b2 = (!ie || !io) && !_forbidFastPath; - return b1 || b2; - } - - void Context::forbidFastPath(bool reallyForbid) { - _forbidFastPath = reallyForbid; - } - - VariableSpace *Context::getVariableSpace() { - return _variableSpace; - } - - sd::memory::Workspace* Context::getWorkspace() { - return _workspace; - } - - sd::memory::Workspace* Context::workspace() { - return _workspace; - } - - sd::random::RandomBuffer* Context::getRNG() { - return _rng; - } - - void Context::setRNG(sd::random::RandomBuffer* rng) { - _rng = rng; - } - - /** - * This method returns variableSpace used in this block - * @return - */ - /* - VariableSpace* Context::getVariableSpace() { - return _variableSpace; - } -*/ - - Stash* Context::getStash() { - return _variableSpace->getStash(); - } - - void Context::trackList(NDArrayList* list) { - _variableSpace->trackList(list); - } - -/* - void Block::updateVariables() { - _variables.clear(); - auto x = _inputs.size(); - for (auto &v:_inputs) { - auto var = _variableSpace->getVariable(v); - _variables.emplace_back(var); - } - } -*/ - int Context::getBranch() { - return _variableSpace->flowPath()->branch(this->nodeId()); - } - - void Context::setBranch(int branch) { - //_branch = branch; - if (_variableSpace->flowPath() != nullptr) - _variableSpace->flowPath()->markBranch(this->nodeId(), branch); - } - - Nd4jLong sd::graph::Context::getOuterTime(){ - return this->_executionTime.first; - } - - Nd4jLong sd::graph::Context::getInnerTime(){ - return this->_executionTime.second; - } - - void sd::graph::Context::setOuterTime(Nd4jLong time){ - this->_executionTime.first = time; - } - - void sd::graph::Context::setInnerTime(Nd4jLong time){ - this->_executionTime.second = time; - } - - - Variable* Context::getVariable(int idx) { - if (idx >= this->_inputs.size()) { - nd4j_printf("Node %i; Variable [%i] requested, but only %i inputs available\n", this->_nodeId, idx, this->_inputs.size()); - throw std::runtime_error("Context: bad Variable index"); - } - - auto p = this->_inputs[idx]; - - auto v = variable(p); - - if (Environment::getInstance().isDebugAndVerbose() && v != nullptr && v->getNDArray() != nullptr) { - auto array = v->getNDArray(); - std::string shape_ = ShapeUtils::shapeAsString(array); - auto type = DataTypeUtils::asString(array->dataType()); - float m = std::numeric_limits::quiet_NaN(); - if (!array->isEmpty()) { - auto values = array->asIndexedString(16); - - nd4j_printf("Debug info for node_%i input[%i]; shape: %s; ews: [%i]; order: [%i]; dtype: [%s]; first values: %s\n", this->_nodeId, idx, shape_.c_str(), array->ews(), array->ordering(), type.c_str(), values.c_str()); - } else { - nd4j_printf("Debug info for node_%i input[%i]; shape: %s; ews: [%i]; order: [%i]; dtype: [%s]; mean value: [%f]\n", this->_nodeId, idx, shape_.c_str(), array->ews(), array->ordering(), type.c_str(), m); - } - } - - return v; - } - - Variable* Context::variable(int idx) { - return getVariable(idx); - } - - Variable* Context::variable(std::initializer_list p) { - if (p.size() != 2) - throw std::runtime_error("Variable address should have size of 2"); - - // FIXME: lol - std::vector vec(p); - std::pair pair(vec[0], vec[1]); - return variable(pair); - } - - Variable* Context::variable(int node, int idx) { - std::pair pair(node, idx); - return variable(pair); - } - - Variable* Context::variable(std::pair& p) { - try { - return _variableSpace->getVariable(p); - } catch (std::exception &e) { - nd4j_printf("Node %i; Non-existent variable requested: [%i:%i]\n", this->_nodeId, p.first, p.second); - throw std::runtime_error("Bad variable"); - } - } - - void Context::pushNDArrayToVariableSpace(int nodeId, int index, NDArray *array, bool removable) { - std::pair pair(nodeId, index); - pushNDArrayToVariableSpace(pair, array, removable); - } - - void Context::pushNDArrayToVariableSpace(std::pair &pair, NDArray *array, bool removable) { - if (_variableSpace != nullptr) { - if (!_variableSpace->hasVariable(pair)) { - auto var = new Variable(array, nullptr, pair.first, pair.second); - _variableSpace->putVariable(pair, var); - var->markRemovable(removable); - } else { - auto var = _variableSpace->getVariable(pair); - if (var->hasNDArray()) { - if (var->getNDArray() != array) { - if (var->isRemovable() && var->hasNDArray()) - delete var->getNDArray(); - - var->setNDArray(array); - var->markRemovable(removable); - } - } else { - var->setNDArray(array); - var->markRemovable(removable); - } - } - } - } - - void Context::pushNDArrayListToVariableSpace(int nodeId, int index, NDArrayList* list, bool track) { - std::pair pair(nodeId, index); - pushNDArrayListToVariableSpace(pair, list, track); - } - - void Context::pushNDArrayListToVariableSpace(std::pair& pair, NDArrayList* list, bool track) { - if (!_variableSpace->hasVariable(pair)) { - auto var = new Variable(nullptr, nullptr, pair.first, pair.second); - var->setNDArrayList(list); - _variableSpace->putVariable(pair, var); - } else { - auto var = _variableSpace->getVariable(pair); - var->setNDArrayList(list); - } - - if (track) - _variableSpace->trackList(list); - } - - Variable* Context::ensureVariable(int idx) { - std::pair pair(this->nodeId(), idx); - - if (_variableSpace == nullptr) - throw std::runtime_error("Context::ensureVariable VariableSpace is NULL!"); - - if (!_variableSpace->hasVariable(pair)) { - auto var = new Variable(nullptr, nullptr, this->nodeId(), idx); - _variableSpace->putVariable(pair, var); - return var; - } else { - return _variableSpace->getVariable(pair); - } - } - - bool Context::isValueAvailable(int idx) { - auto var = ensureVariable(idx); - - if (var->variableType() == VariableType::NDARRAY) { - return var->hasNDArray(); - } else if (var->variableType() == VariableType::ARRAY_LIST) { - return var->hasNDArrayList(); - } - - return false; - } - - NDArray* Context::getNDArray(int idx) { - return array(idx); - } - - NDArray* Context::array(int idx) { - // we check for fastpath first - if (!_fastpath_in.empty() && _fastpath_in.size() > idx) { - return _fastpath_in[idx]; - } - - // if no luck for fastpath - return whatever is available - return getVariable(idx)->getNDArray(); - } - - sd::memory::Workspace *Context::fWorkspace() { - return workspace(); - } - - sd::memory::Workspace *Context::tWorkspace() { - return nullptr; - } - - sd::memory::Workspace *Context::oWorkspace() { - return nullptr; - } - - LaunchContext* Context::launchContext() { - //FIXME: we need proper context to be shared here - if (_context == nullptr) { - return LaunchContext::defaultContext(); - } else { - return _context; - } - } - - unsigned long Context::width() { - if (!_fastpath_in.empty()) - return _fastpath_in.size(); - else - return _inputs.size(); - } - - void Context::setInputArray(int index, NDArray *array, bool removable) { - if (_fastpath_in.size() < index + 1) - _fastpath_in.resize(index+1); - - _fastpath_in[index] = array; - if (removable) - _handles.emplace_back(array); - } - - void Context::setInputArray(int index, void *buffer, void * shapeInfo, void *specialBuffer, void * specialShapeInfo) { - this->setInputArray(index, buffer, const_cast(shapeInfo), specialBuffer, const_cast(specialShapeInfo)); - } - - void Context::setInputArray(int index, void *buffer, void const* shapeInfo, void *specialBuffer, void const* specialShapeInfo) { - auto array = new NDArray(buffer, specialBuffer, reinterpret_cast(shapeInfo)); - - if (_fastpath_in.size() < index + 1) - _fastpath_in.resize(index+1); - - _fastpath_in[index] = array; - _handles.emplace_back(array); - - if (_context != nullptr) - array->setContext(_context); - } - - void Context::setOutputArray(int index, NDArray *array, bool removable) { - if (_fastpath_out.size() < index + 1) - _fastpath_out.resize(index+1); - - _fastpath_out[index] = array; - - if (removable) - _handles.emplace_back(array); - } - - void Context::setOutputArray(int index, void *buffer, void * shapeInfo, void *specialBuffer, void * specialShapeInfo) { - this->setOutputArray(index, buffer, const_cast(shapeInfo), specialBuffer, const_cast(specialShapeInfo)); - } - - void Context::setOutputArray(int index, void *buffer, const void * shapeInfo, void *specialBuffer, const void * specialShapeInfo) { - if (_fastpath_out.size() < index + 1) - _fastpath_out.resize(index+1); - - auto array = new NDArray(buffer, specialBuffer, reinterpret_cast(shapeInfo)); - - _fastpath_out[index] = array; - _handles.emplace_back(array); - - if (_context != nullptr) - array->setContext(_context); - } - - void Context::setInputArray(int index, void *vdatabuffer, void const* shapeInfo, void const* specialShapeInfo) { - auto dataBuffer = reinterpret_cast(vdatabuffer); - - if (_fastpath_in.size() < index + 1) - _fastpath_in.resize(index+1); - - NDArray *array; - if (dataBuffer != nullptr) - array = new NDArray(dataBuffer->dataBuffer(), reinterpret_cast(shapeInfo), sd::LaunchContext::defaultContext(), dataBuffer->offset() / DataTypeUtils::sizeOf(ArrayOptions::dataType(reinterpret_cast(shapeInfo)))); - else - array = new NDArray(nullptr, nullptr, reinterpret_cast(shapeInfo)); - - _fastpath_in[index] = array; - _handles.emplace_back(array); - - if (_context != nullptr) - array->setContext(_context); - } - - void Context::setOutputArray(int index, void *vdatabuffer, void const* shapeInfo, void const* specialShapeInfo) { - auto dataBuffer = reinterpret_cast(vdatabuffer); - - if (_fastpath_out.size() < index + 1) - _fastpath_out.resize(index+1); - - NDArray *array; - if (dataBuffer != nullptr) - array = new NDArray(dataBuffer->dataBuffer(), reinterpret_cast(shapeInfo), sd::LaunchContext::defaultContext(), dataBuffer->offset() / DataTypeUtils::sizeOf(ArrayOptions::dataType(reinterpret_cast(shapeInfo)))); - else - array = new NDArray(nullptr, nullptr, reinterpret_cast(shapeInfo)); - - _fastpath_out[index] = array; - _handles.emplace_back(array); - - if (_context != nullptr) - array->setContext(_context); - } +namespace graph { +Context::Context(const ContextPrototype &prototype, + VariableSpace *variableSpace, + GraphMemoryManager *memoryManager) { + _memoryManager = memoryManager; + _variableSpace = variableSpace; + + for (const auto &v : prototype.inputs()) { + this->_inputs.push_back(v); + } + + for (const auto &v : prototype.getTArguments()) { + this->_tArgs.push_back(v); + } + + for (const auto &v : prototype.getIArguments()) { + this->_iArgs.push_back(v); + } + + for (const auto &v : prototype.getBArguments()) { + this->_bArgs.push_back(v); + } + + for (const auto &v : prototype.getAxis()) { + this->_axis.push_back(v); + } + + this->_opNum = prototype.opNum(); + this->_isInplace = prototype.isInplace(); + this->_nodeId = prototype.nodeId(); + this->_name = prototype.name(); + this->_useMKLDNN = prototype.isUseMKLDNN(); +} + +Context::Context(int nodeId, VariableSpace *variableSpace) { + this->_nodeId = nodeId; + this->_variableSpace = variableSpace; + this->_isInplace = false; + this->_workspace = nullptr; + + this->_executionTime.first = 0; + this->_executionTime.second = 0; +} + +Context::Context(int nodeId, VariableSpace *variableSpace, bool isInplace) + : Context(nodeId, variableSpace) { + this->_isInplace = isInplace; +} - void Context::setTArguments(double *arguments, int numberOfArguments) { - _tArgs.clear(); - _tArgs.reserve(numberOfArguments); - for (int e = 0; e < numberOfArguments; e++) - _tArgs.push_back(arguments[e]); - } +Context::~Context() { + this->_iArgs.clear(); + this->_tArgs.clear(); + this->_inputs.clear(); + this->_fastpath_in.clear(); + this->_fastpath_out.clear(); - void Context::setIArguments(Nd4jLong *arguments, int numberOfArguments) { - _iArgs.clear(); - _iArgs.reserve(numberOfArguments); - for (int e = 0; e < numberOfArguments; e++) - _iArgs.push_back(arguments[e]); - } + if (_context != nullptr) delete _context; +} + +void Context::setTargetEngine(samediff::Engine engine) { _engine = engine; } + +void Context::attachWorkspace(sd::memory::Workspace *workspace) { + this->_workspace = workspace; +} + +void Context::setVariableSpace(VariableSpace *variableSpace) { + this->_variableSpace = variableSpace; +} + +const std::vector> &Context::fastpath_in() const { + return _fastpath_in; +} + +const std::vector> &Context::fastpath_out() const { + return _fastpath_out; +} - void Context::setBArguments(bool *arguments, int numberOfArguments) { - _bArgs.clear(); - _bArgs.reserve(numberOfArguments); - for (int e = 0; e < numberOfArguments; e++) - _bArgs.push_back(arguments[e]); - } +bool Context::isFastPath() const { + auto ie = _fastpath_in.empty(); + auto io = _fastpath_out.empty(); + // two options here. + // either both IN/OUT are filled + auto b1 = (!ie && !io) || (!ie && _isInplace); + + // or at least something is filled, and FastPath is NOT forbidden + auto b2 = (!ie || !io) && !_forbidFastPath; + return b1 || b2; +} + +void Context::forbidFastPath(bool reallyForbid) { + _forbidFastPath = reallyForbid; +} + +VariableSpace *Context::getVariableSpace() { return _variableSpace; } + +sd::memory::Workspace *Context::workspace() const { return _workspace; } + +Stash *Context::stash() const { return _variableSpace->stash(); } + +Nd4jLong sd::graph::Context::outerTime() const { + return this->_executionTime.first; +} + +Nd4jLong sd::graph::Context::innerTime() const { + return this->_executionTime.second; +} + +void sd::graph::Context::setOuterTime(Nd4jLong time) { + this->_executionTime.first = time; +} + +void sd::graph::Context::setInnerTime(Nd4jLong time) { + this->_executionTime.second = time; +} + +std::shared_ptr Context::getVariable(int idx) const { + if (idx >= this->_inputs.size()) { + nd4j_printf( + "Node %i; Variable [%i] requested, but only %i inputs available\n", + this->_nodeId, idx, this->_inputs.size()); + throw std::runtime_error("Context: bad Variable index"); + } + + auto p = this->_inputs[idx]; + + auto v = variable(p); + + if (Environment::getInstance().isDebugAndVerbose() && v != nullptr && + v->getNDArray() != nullptr) { + auto array = v->getNDArray(); + std::string shape_ = ShapeUtils::shapeAsString(array.get()); + auto type = DataTypeUtils::asString(array->dataType()); + float m = std::numeric_limits::quiet_NaN(); + if (!array->isEmpty()) { + auto values = array->asIndexedString(16); + + nd4j_printf( + "Debug info for node_%i input[%i]; shape: %s; ews: [%i]; order: " + "[%i]; dtype: [%s]; first values: %s\n", + this->_nodeId, idx, shape_.c_str(), array->ews(), array->ordering(), + type.c_str(), values.c_str()); + } else { + nd4j_printf( + "Debug info for node_%i input[%i]; shape: %s; ews: [%i]; order: " + "[%i]; dtype: [%s]; mean value: [%f]\n", + this->_nodeId, idx, shape_.c_str(), array->ews(), array->ordering(), + type.c_str(), m); + } + } + + return v; +} + +std::shared_ptr Context::variable(int idx) const { + return getVariable(idx); +} + +std::shared_ptr Context::variable( + std::initializer_list p) const { + if (p.size() != 2) + throw std::runtime_error("Variable address should have size of 2"); + + // FIXME: lol + std::vector vec(p); + std::pair pair(vec[0], vec[1]); + return variable(pair); +} + +std::shared_ptr Context::variable(int node, int idx) const { + std::pair pair(node, idx); + return variable(pair); +} + +std::shared_ptr Context::variable( + const std::pair &p) const { + try { + return _variableSpace->getVariable(p); + } catch (std::exception &e) { + nd4j_printf("Node %i; Non-existent variable requested: [%i:%i]\n", + this->_nodeId, p.first, p.second); + throw std::runtime_error("Bad variable"); + } +} + +void Context::pushNDArrayToVariableSpace(int nodeId, int index, + const NDArray &array) { + std::pair pair(nodeId, index); + pushNDArrayToVariableSpace(pair, array); +} + +void Context::pushNDArrayToVariableSpace(const std::pair &pair, + const NDArray &array) { + if (_variableSpace != nullptr) { + if (!_variableSpace->hasVariable(pair)) { + auto var = std::make_shared(array, "", pair.first, pair.second); + _variableSpace->putVariable(pair, var); + } else { + auto var = _variableSpace->getVariable(pair); + var->setNDArray(std::make_shared(array)); + } + } +} + +void Context::pushNDArrayListToVariableSpace(int nodeId, int index, + const NDArrayList &list, + bool track) { + std::pair pair(nodeId, index); + pushNDArrayListToVariableSpace(pair, list, track); +} + + +void Context::pushNDArrayListToVariableSpace(int nodeId, int index, + std::shared_ptr list) { + std::pair pair(nodeId, index); + if (!_variableSpace->hasVariable(pair)) { + auto var = std::make_shared(); + var->setId(pair.first, pair.second); + var->setNDArrayList(list); + _variableSpace->putVariable(pair, var); + } else { + auto var = _variableSpace->getVariable(pair); + var->setNDArrayList(list); + } +} + +void Context::pushNDArrayListToVariableSpace(const std::pair &pair, + const NDArrayList &list, + bool track) { + if (!_variableSpace->hasVariable(pair)) { + auto var = std::make_shared(); + var->setId(pair.first, pair.second); + var->setNDArrayList(std::make_shared(list)); + _variableSpace->putVariable(pair, var); + } else { + auto var = _variableSpace->getVariable(pair); + var->setNDArrayList(std::make_shared(list)); + } +} + +std::shared_ptr Context::ensureVariable(const std::string &name, + int id, int idx) { + std::pair pair(this->nodeId(), idx); + + if (_variableSpace == nullptr) + throw std::runtime_error("Context::ensureVariable VariableSpace is NULL!"); + + if (!_variableSpace->hasVariable(pair)) { + auto var = std::make_shared(); + var->setId(this->nodeId(), idx); + auto name = this->name(); + + if (!name.empty()) var->setName(name); + + _variableSpace->putVariable(pair, var); + return var; + } else { + return _variableSpace->getVariable(pair); + } +} + +bool Context::isValueAvailable(const std::string &name, int id, int idx) const { + auto var = const_cast(this)->ensureVariable(name, id, idx); + + if (var->variableType() == VariableType::NDARRAY) { + return var->hasNDArray(); + } else if (var->variableType() == VariableType::ARRAY_LIST) { + return var->hasNDArrayList(); + } + + return false; +} + +std::shared_ptr Context::getNDArray(int idx) const { + return array(idx); +} + +std::shared_ptr Context::array(int idx) const { + // we check for fastpath first + if (!_fastpath_in.empty() && _fastpath_in.size() > idx) { + return _fastpath_in[idx]; + } + + // if no luck for fastpath - return whatever is available + return getVariable(idx)->getNDArray(); +} + +LaunchContext *Context::launchContext() { + // FIXME: we need proper context to be shared here + if (_context == nullptr) { + return LaunchContext::defaultContext(); + } else { + return _context; + } +} - void Context::setCudaContext(Nd4jPointer cudaStream, Nd4jPointer reductionPointer, Nd4jPointer allocationPointer) { +unsigned long Context::width() const { + if (!_fastpath_in.empty()) + return _fastpath_in.size(); + else + return _inputs.size(); +} + +void Context::setInputArray(int index, const NDArray &array) { + if (_fastpath_in.size() < index + 1) _fastpath_in.resize(index + 1); + + _fastpath_in[index] = std::make_shared(array); +} + +NDArray *Context::arrayForOp(int idx) const { + auto ptr = array(idx); + + if (ptr.get() != nullptr && ptr->undefined()) return nullptr; + + return ptr.get(); +} + +void Context::setInputArray(int index, void *buffer, void *shapeInfo, + void *specialBuffer, void *specialShapeInfo) { + this->setInputArray(index, buffer, const_cast(shapeInfo), + specialBuffer, + const_cast(specialShapeInfo)); +} + +void Context::setInputArray(int index, void *buffer, void const *shapeInfo, + void *specialBuffer, void const *specialShapeInfo) { + auto array = std::make_shared( + buffer, specialBuffer, reinterpret_cast(shapeInfo)); + + if (_fastpath_in.size() < index + 1) _fastpath_in.resize(index + 1); + + _fastpath_in[index] = array; + + if (_context != nullptr) array->setContext(_context); +} + +void Context::setOutputArray(int index, const NDArray &array) { + if (_fastpath_out.size() < index + 1) _fastpath_out.resize(index + 1); + + _fastpath_out[index] = std::make_shared(array); +} + +void Context::setOutputArray(int index, void *buffer, void *shapeInfo, + void *specialBuffer, void *specialShapeInfo) { + this->setOutputArray(index, buffer, const_cast(shapeInfo), + specialBuffer, + const_cast(specialShapeInfo)); +} + +void Context::setOutputArray(int index, void *buffer, const void *shapeInfo, + void *specialBuffer, + const void *specialShapeInfo) { + if (_fastpath_out.size() < index + 1) _fastpath_out.resize(index + 1); + + auto array = std::make_shared( + buffer, specialBuffer, reinterpret_cast(shapeInfo)); + + _fastpath_out[index] = array; + + if (_context != nullptr) array->setContext(_context); +} + +void Context::setInputArray(int index, void *vdatabuffer, void const *shapeInfo, + void const *specialShapeInfo) { + auto dataBuffer = reinterpret_cast(vdatabuffer); + + if (_fastpath_in.size() < index + 1) _fastpath_in.resize(index + 1); + + std::shared_ptr array; + if (dataBuffer != nullptr) + array = std::make_shared( + dataBuffer->dataBuffer(), reinterpret_cast(shapeInfo), + sd::LaunchContext::defaultContext(), + dataBuffer->offset() / + DataTypeUtils::sizeOf(ArrayOptions::dataType( + reinterpret_cast(shapeInfo)))); + else + array = std::make_shared( + nullptr, nullptr, reinterpret_cast(shapeInfo)); + + _fastpath_in[index] = array; + + if (_context != nullptr) array->setContext(_context); +} + +void Context::setOutputArray(int index, void *vdatabuffer, + void const *shapeInfo, + void const *specialShapeInfo) { + auto dataBuffer = reinterpret_cast(vdatabuffer); + + if (_fastpath_out.size() < index + 1) _fastpath_out.resize(index + 1); + + std::shared_ptr array; + if (dataBuffer != nullptr) + array = std::make_shared( + dataBuffer->dataBuffer(), reinterpret_cast(shapeInfo), + sd::LaunchContext::defaultContext(), + dataBuffer->offset() / + DataTypeUtils::sizeOf(ArrayOptions::dataType( + reinterpret_cast(shapeInfo)))); + else + array = std::make_shared( + nullptr, nullptr, reinterpret_cast(shapeInfo)); + + _fastpath_out[index] = array; + + if (_context != nullptr) array->setContext(_context); +} + +void Context::setTArguments(double *arguments, int numberOfArguments) { + _tArgs.clear(); + _tArgs.reserve(numberOfArguments); + for (int e = 0; e < numberOfArguments; e++) _tArgs.push_back(arguments[e]); +} + +void Context::setIArguments(Nd4jLong *arguments, int numberOfArguments) { + _iArgs.clear(); + _iArgs.reserve(numberOfArguments); + for (int e = 0; e < numberOfArguments; e++) _iArgs.push_back(arguments[e]); +} + +void Context::setBArguments(bool *arguments, int numberOfArguments) { + _bArgs.clear(); + _bArgs.reserve(numberOfArguments); + for (int e = 0; e < numberOfArguments; e++) _bArgs.push_back(arguments[e]); +} + +void Context::setCudaContext(Nd4jPointer cudaStream, + Nd4jPointer reductionPointer, + Nd4jPointer allocationPointer) { #ifdef __CUDABLAS__ - _context = new LaunchContext(cudaStream, reductionPointer, allocationPointer); + _context = new LaunchContext(cudaStream, reductionPointer, allocationPointer); - // FIXME: either pass handle from outside, or make sure outside we use the same handle - _context->setCublasHandle(LaunchContext::defaultContext()->getCublasHandle()); + // FIXME: either pass handle from outside, or make sure outside we use the + // same handle + _context->setCublasHandle(LaunchContext::defaultContext()->getCublasHandle()); - for (auto v: _fastpath_out) - v->setContext(_context); + for (auto v : _fastpath_out) v->setContext(_context); - for (auto v: _fastpath_in) - v->setContext(_context); + for (auto v : _fastpath_in) v->setContext(_context); #endif - } - - void Context::allowHelpers(bool reallyAllow) { - _helpersAllowed = reallyAllow; - } - - bool Context::helpersAllowed() { - return _helpersAllowed; - } - - void Context::setTArguments(const std::vector &tArgs) { - for (auto t:tArgs) - _tArgs.emplace_back(t); - } - - void Context::setIArguments(const std::vector &iArgs) { - for (auto i:iArgs) - _iArgs.emplace_back(i); - } - - void Context::setBArguments(const std::vector &bArgs) { - for (auto b:bArgs) - _bArgs.push_back(b); - } - - void Context::setShapeFunctionOverride(bool reallyOverride) { - _shapeFunctionOverride = reallyOverride; - } - - bool Context::shapeFunctionOverride() { - return _shapeFunctionOverride; - } - - samediff::ExecutionMode Context::executionMode() { - return _execMode; - } - - void Context::setExecutionMode(samediff::ExecutionMode executionMode) { - _execMode = executionMode; - } - - bool Context::isTraining() { - return _execMode == samediff::ExecutionMode::MODE_TRAINING; - } - - bool Context::isInference() { - return _execMode == samediff::ExecutionMode::MODE_INFERENCE; - } - - void Context::setDArguments(sd::DataType *arguments, int numberOfArguments) { - _dArgs.clear(); - for (int e = 0; e < numberOfArguments; e++) - _dArgs.emplace_back(arguments[e]); - } - - void Context::setDArguments(const std::vector &dArgs) { - _dArgs.clear(); - for (auto d:dArgs) - _dArgs.emplace_back(d); - } - - void Context::clearFastPath() { - _fastpath_in.clear(); - _fastpath_out.clear(); - - for (auto v:_handles) - delete v; - - _handles.clear(); - } - } } +void Context::allowHelpers(bool reallyAllow) { _helpersAllowed = reallyAllow; } + +bool Context::helpersAllowed() const { return _helpersAllowed; } + +void Context::setTArguments(const std::vector &tArgs) { + for (auto t : tArgs) _tArgs.emplace_back(t); +} + +void Context::setIArguments(const std::vector &iArgs) { + for (auto i : iArgs) _iArgs.emplace_back(i); +} + +void Context::setBArguments(const std::vector &bArgs) { + for (auto b : bArgs) _bArgs.push_back(b); +} + +void Context::setShapeFunctionOverride(bool reallyOverride) { + _shapeFunctionOverride = reallyOverride; +} + +bool Context::shapeFunctionOverride() const { return _shapeFunctionOverride; } + +samediff::ExecutionMode Context::executionMode() const { return _execMode; } + +void Context::setExecutionMode(samediff::ExecutionMode executionMode) { + _execMode = executionMode; +} + +bool Context::isTraining() const { + return _execMode == samediff::ExecutionMode::MODE_TRAINING; +} + +bool Context::isInference() const { + return _execMode == samediff::ExecutionMode::MODE_INFERENCE; +} + +void Context::setDArguments(sd::DataType *arguments, int numberOfArguments) { + _dArgs.clear(); + for (int e = 0; e < numberOfArguments; e++) _dArgs.emplace_back(arguments[e]); +} + +void Context::setDArguments(const std::vector &dArgs) { + _dArgs.clear(); + for (auto d : dArgs) _dArgs.emplace_back(d); +} + +void Context::clearFastPath() { + _fastpath_in.clear(); + _fastpath_out.clear(); +} + +const GraphMemoryManager &Context::memoryManager() const { + return *_memoryManager; +} + +void Context::setInputArray(int index, const std::shared_ptr &array) { + if (_fastpath_in.size() < index + 1) _fastpath_in.resize(index + 1); + + _fastpath_in[index] = array; +} + +void Context::setOutputArray(int index, const std::shared_ptr &array) { + if (_fastpath_out.size() < index + 1) _fastpath_out.resize(index + 1); + + _fastpath_out[index] = array; +} +} // namespace graph +} // namespace sd diff --git a/libnd4j/include/graph/impl/ContextPrototype.cpp b/libnd4j/include/graph/impl/ContextPrototype.cpp index 417c46b3a11b..8f333b8f32ad 100644 --- a/libnd4j/include/graph/impl/ContextPrototype.cpp +++ b/libnd4j/include/graph/impl/ContextPrototype.cpp @@ -18,168 +18,253 @@ // @author raver119@gmail.com // -#include +#include #include +#include #include -#include namespace sd { - namespace graph { - ContextPrototype::ContextPrototype(sd::ops::OpDescriptor* opDescriptor, int nodeId, bool inPlace) { - _nodeId = nodeId; - _isInplace = inPlace; - _opDescriptor = opDescriptor; - } - - void ContextPrototype::pickInput(std::pair& p) { - this->_inputs.emplace_back(p); - } - - void ContextPrototype::pickInput(int input, int index) { - std::pair pair(input, index); - pickInput(pair); - } - - int ContextPrototype::opNum() { - return this->_opNum; - } - - void ContextPrototype::setOpNum(int opNum) { - this->_opNum = opNum; - } - - std::vector>* ContextPrototype::inputs() { - return &_inputs; - } - - void ContextPrototype::fillInputs(std::vector& inputs) { - for (int e = 0; e < inputs.size(); e++) { - auto v = inputs.at(e); - pickInput(v); - } - } - - samediff::Engine ContextPrototype::engine() { - return _engine; - } - - bool ContextPrototype::hasVariablesFilled() { - return this->_inputs.size() > 0; - } - - bool ContextPrototype::isInplace() { - return this->_isInplace; - } - - std::vector* ContextPrototype::getTArguments() { - return &(this->_tArgs); - } - - std::vector* ContextPrototype::getIArguments() { - return &(this->_iArgs); - } - - std::vector* ContextPrototype::getBArguments() { - return &(this->_bArgs); - } - - std::vector* ContextPrototype::getAxis() { - return &(this->_axis); - } - - void ContextPrototype::pickInput(int input) { - std::pair pair(input, 0); - this->_inputs.emplace_back(pair); - } - - std::pair* ContextPrototype::input(int idx) { - return &(this->_inputs.at(idx)); - } - - void ContextPrototype::fillInputs(std::initializer_list inputs) { - for (auto v: inputs) { - pickInput(v); - } - } - - int ContextPrototype::nodeId() { - return getNodeId(); - } - - sd::DataType ContextPrototype::dataType() { - return dataType(0); - } - - sd::DataType ContextPrototype::dataType(int index) { - return _dataType; - } - - void ContextPrototype::setDataType(int index, sd::DataType type) { - // if (_outputs->size() == 0) - _dataType = type; - } - - size_t ContextPrototype::numT() { - return (int) _tArgs.size(); - } - - size_t ContextPrototype::numI() { - return (int) _iArgs.size(); - } - - size_t ContextPrototype::numB() { - return (int) _bArgs.size(); - } - - int ContextPrototype::getNodeId() { - return this->_nodeId; - } - - /** - * This method returns number of inputs available in this block - * @return - */ - unsigned long ContextPrototype::width() { - return this->_inputs.size(); - }; - - void ContextPrototype::markInplace(bool reallyInplace) { - this->_isInplace = reallyInplace; - } - - template - ContextPrototype* ContextPrototype::asT() { - auto clone = new ContextPrototype(_opDescriptor, _nodeId, _isInplace); - - return clone; - } - - void ContextPrototype::setOpDescriptor(sd::ops::OpDescriptor* opDescriptor) { - _opDescriptor = opDescriptor; - } - - ContextPrototype* ContextPrototype::clone() { - auto clone = new ContextPrototype(_opDescriptor, _nodeId, _isInplace); - clone->_opNum = _opNum; - - for (auto v: _inputs) - clone->_inputs.emplace_back(v); - - for (auto v: _tArgs) - clone->_tArgs.emplace_back(v); - - for (auto v: _iArgs) - clone->_iArgs.emplace_back(v); - - return clone; - } - - std::vector *ContextPrototype::getDArguments() { - return &_dArgs; - } - - size_t ContextPrototype::numD() { - return _dArgs.size(); - } - } -} \ No newline at end of file +namespace graph { +ContextPrototype::ContextPrototype(sd::ops::OpDescriptor *opDescriptor, + int nodeId, bool inPlace) { + _nodeId = nodeId; + _isInplace = inPlace; + _opDescriptor = opDescriptor; +} + +void ContextPrototype::pickInput(const std::pair &p) { + this->_inputs.emplace_back(p); +} + +void ContextPrototype::pickInput(int input, int index) { + std::pair pair(input, index); + pickInput(pair); +} + +int ContextPrototype::opNum() const { return this->_opNum; } + +void ContextPrototype::setOpNum(int opNum) { this->_opNum = opNum; } + +const std::vector> &ContextPrototype::inputs() const { + return const_cast> &>(_inputs); +} + +void ContextPrototype::fillInputs(std::vector &inputs) { + for (int e = 0; e < inputs.size(); e++) { + auto v = inputs.at(e); + pickInput(v); + } +} + +samediff::Engine ContextPrototype::engine() const { return _engine; } + +bool ContextPrototype::hasVariablesFilled() const { + return this->_inputs.size() > 0; +} + +bool ContextPrototype::isInplace() const { return this->_isInplace; } + +const std::vector &ContextPrototype::getTArguments() const { + return const_cast &>(_tArgs); +} + +const std::vector &ContextPrototype::getIArguments() const { + return const_cast &>(_iArgs); +} + +const std::vector &ContextPrototype::getBArguments() const { + return const_cast &>(_bArgs); +} + +const std::vector &ContextPrototype::getAxis() const { + return const_cast &>(_axis); +} + +void ContextPrototype::pickInput(int input) { + std::pair pair(input, 0); + this->_inputs.emplace_back(pair); +} + +const std::pair &ContextPrototype::input(int idx) const { + return this->_inputs.at(idx); +} + +void ContextPrototype::fillInputs(std::initializer_list inputs) { + for (auto v : inputs) { + pickInput(v); + } +} + +int ContextPrototype::nodeId() const { return getNodeId(); } + +size_t ContextPrototype::numT() const { return (int)_tArgs.size(); } + +size_t ContextPrototype::numI() const { return (int)_iArgs.size(); } + +size_t ContextPrototype::numB() const { return (int)_bArgs.size(); } + +int ContextPrototype::getNodeId() const { return this->_nodeId; } + +/** + * This method returns number of inputs available in this block + * @return + */ +unsigned long ContextPrototype::width() const { return this->_inputs.size(); }; + +void ContextPrototype::markInplace(bool reallyInplace) { + this->_isInplace = reallyInplace; +} + +template +ContextPrototype *ContextPrototype::asT() { + auto clone = new ContextPrototype(_opDescriptor, _nodeId, _isInplace); + + return clone; +} + +void ContextPrototype::setOpDescriptor(sd::ops::OpDescriptor *opDescriptor) { + _opDescriptor = opDescriptor; +} + +ContextPrototype *ContextPrototype::clone() { + auto clone = new ContextPrototype(_opDescriptor, _nodeId, _isInplace); + clone->_opNum = _opNum; + + for (auto v : _inputs) clone->_inputs.emplace_back(v); + + for (auto v : _tArgs) clone->_tArgs.emplace_back(v); + + for (auto v : _iArgs) clone->_iArgs.emplace_back(v); + + return clone; +} + +const std::vector &ContextPrototype::getDArguments() const { + return const_cast &>(_dArgs); +} + +size_t ContextPrototype::numD() const { return _dArgs.size(); } + +void ContextPrototype::appendI(const std::vector &value) { + for (auto v : value) _iArgs.emplace_back(v); +} + +void ContextPrototype::appendT(const std::vector &value) { + for (auto v : value) _tArgs.emplace_back(v); +} + +void ContextPrototype::appendB(const std::vector &value) { + for (auto v : value) _bArgs.emplace_back(v); +} + +void ContextPrototype::appendD(const std::vector &value) { + for (auto v : value) _dArgs.emplace_back(v); +} + +void ContextPrototype::appendA(Nd4jLong value) { _axis.emplace_back(value); } + +void ContextPrototype::appendI(Nd4jLong value) { _iArgs.emplace_back(value); } + +void ContextPrototype::appendT(double value) { _tArgs.emplace_back(value); } + +void ContextPrototype::appendB(bool value) { _bArgs.emplace_back(value); } + +void ContextPrototype::appendD(DataType value) { _dArgs.emplace_back(value); } + +ContextPrototype::ContextPrototype(const ContextPrototype &other) noexcept { + _inputs = other._inputs; + _tArgs = other._tArgs; + _iArgs = other._iArgs; + _bArgs = other._bArgs; + _dArgs = other._dArgs; + _name = other._name; + _axis = other._axis; + + _nodeId = other._nodeId; + _isInplace = other._isInplace; + _opNum = other._opNum; + _rootSeed = other._rootSeed; + _randomGenerator = other._randomGenerator; + _opDescriptor = other._opDescriptor; + _useMKLDNN = other._useMKLDNN; + _engine = other._engine; + _execMode = other._execMode; +} + +ContextPrototype &ContextPrototype::operator=( + const ContextPrototype &other) noexcept { + if (this == &other) return *this; + + _inputs = other._inputs; + _tArgs = other._tArgs; + _iArgs = other._iArgs; + _bArgs = other._bArgs; + _dArgs = other._dArgs; + _name = other._name; + _axis = other._axis; + + _nodeId = other._nodeId; + _isInplace = other._isInplace; + _opNum = other._opNum; + _rootSeed = other._rootSeed; + _randomGenerator = other._randomGenerator; + _opDescriptor = other._opDescriptor; + _useMKLDNN = other._useMKLDNN; + _engine = other._engine; + _execMode = other._execMode; + + return *this; +} + +ContextPrototype::ContextPrototype(ContextPrototype &&other) noexcept { + _inputs = std::move(other._inputs); + _tArgs = std::move(other._tArgs); + _iArgs = std::move(other._iArgs); + _bArgs = std::move(other._bArgs); + _dArgs = std::move(other._dArgs); + _name = std::move(other._name); + _axis = std::move(other._axis); + + _nodeId = other._nodeId; + _isInplace = other._isInplace; + _opNum = other._opNum; + _rootSeed = other._rootSeed; + _randomGenerator = other._randomGenerator; + _opDescriptor = other._opDescriptor; + _useMKLDNN = other._useMKLDNN; + _engine = other._engine; + _execMode = other._execMode; +} + +ContextPrototype &ContextPrototype::operator=( + ContextPrototype &&other) noexcept { + if (this == &other) return *this; + + _inputs = std::move(other._inputs); + _tArgs = std::move(other._tArgs); + _iArgs = std::move(other._iArgs); + _bArgs = std::move(other._bArgs); + _dArgs = std::move(other._dArgs); + _name = std::move(other._name); + _axis = std::move(other._axis); + + _nodeId = other._nodeId; + _isInplace = other._isInplace; + _opNum = other._opNum; + _rootSeed = other._rootSeed; + _randomGenerator = other._randomGenerator; + _opDescriptor = other._opDescriptor; + _useMKLDNN = other._useMKLDNN; + _engine = other._engine; + _execMode = other._execMode; + + return *this; +} + +void ContextPrototype::setNodeId(int id) { _nodeId = id; } + +const std::string& ContextPrototype::name() const { return _name; } + +void ContextPrototype::setName(const std::string &name) { _name = name; } +} // namespace graph +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/graph/impl/ExecutionResult.cpp b/libnd4j/include/graph/impl/ExecutionResult.cpp index fd2bed054201..586bc151ab6a 100644 --- a/libnd4j/include/graph/impl/ExecutionResult.cpp +++ b/libnd4j/include/graph/impl/ExecutionResult.cpp @@ -18,91 +18,88 @@ // @author raver119@gmail.com // +#include +#include #include #include -#include -#include namespace sd { - namespace graph { - ExecutionResult::ExecutionResult(const FlatResult* flatResult) { - if (flatResult->variables() != nullptr) { - for (int e = 0; e < flatResult->variables()->size(); e++) { - auto fv = flatResult->variables()->Get(e); - auto v = new Variable(fv); - this->emplace_back(v); - } - - _releasable = true; - } - } - - ExecutionResult::~ExecutionResult(){ - if (_releasable) - for (auto v : _variables) - delete v; - } - - Nd4jLong ExecutionResult::size() { - return _variables.size(); - } - - ExecutionResult::ExecutionResult(std::initializer_list variables) { - for (auto v: variables) - this->emplace_back(v); - } - - void ExecutionResult::emplace_back(Variable *variable) { - _variables.emplace_back(variable); - - if (variable->getName() != nullptr) - _stringIdMap[*variable->getName()] = variable; - - std::pair p(variable->id(), variable->index()); - _pairIdMap[p] = variable; - } - - Variable* ExecutionResult::at(int position) { - if (position >= _variables.size()) - throw std::runtime_error("Position index is higher then number of variables stored"); - - return _variables.at(position); - } - - Variable* ExecutionResult::byId(std::string &id) { - if (_stringIdMap.count(id) == 0) - throw std::runtime_error("Can't find specified ID"); - - return _stringIdMap.at(id); - } - - Variable* ExecutionResult::byId(std::pair &id) { - if (_pairIdMap.count(id) == 0) - throw std::runtime_error("Can't find specified ID"); - - return _pairIdMap.at(id); - } - - Variable* ExecutionResult::byId(int id) { - std::pair p(id, 0); - return byId(p); - } - - Variable* ExecutionResult::byId(const char *str) { - std::string p(str); - return byId(p); - } - - flatbuffers::Offset ExecutionResult::asFlatResult(flatbuffers::FlatBufferBuilder &builder) { - - std::vector> vec; - for (Variable* v : _variables) { - vec.emplace_back(v->asFlatVariable(builder)); - } - - auto vecOffset = builder.CreateVector(vec); - - return CreateFlatResult(builder, 0, vecOffset); - } +namespace graph { +ExecutionResult::ExecutionResult(const FlatResult* flatResult) { + if (flatResult->variables() != nullptr) { + for (int e = 0; e < flatResult->variables()->size(); e++) { + auto fv = flatResult->variables()->Get(e); + auto v = new Variable(fv); + this->emplace_back(v); } -} \ No newline at end of file + + _releasable = true; + } +} + +ExecutionResult::~ExecutionResult() { + if (_releasable) + for (auto v : _variables) delete v; +} + +Nd4jLong ExecutionResult::size() { return _variables.size(); } + +ExecutionResult::ExecutionResult(std::initializer_list variables) { + for (auto v : variables) this->emplace_back(v); +} + +void ExecutionResult::emplace_back(Variable* variable) { + _variables.emplace_back(variable); + + if (!variable->getName().empty()) + _stringIdMap[variable->getName()] = variable; + + std::pair p(variable->id(), variable->index()); + _pairIdMap[p] = variable; +} + +Variable* ExecutionResult::at(int position) { + if (position >= _variables.size()) + throw std::runtime_error( + "Position index is higher then number of variables stored"); + + return _variables.at(position); +} + +Variable* ExecutionResult::byId(std::string& id) { + if (_stringIdMap.count(id) == 0) + throw std::runtime_error("Can't find specified ID"); + + return _stringIdMap.at(id); +} + +Variable* ExecutionResult::byId(std::pair& id) { + if (_pairIdMap.count(id) == 0) + throw std::runtime_error("Can't find specified ID"); + + return _pairIdMap.at(id); +} + +Variable* ExecutionResult::byId(int id) { + std::pair p(id, 0); + return byId(p); +} + +Variable* ExecutionResult::byId(const char* str) { + std::string p(str); + return byId(p); +} + +flatbuffers::Offset ExecutionResult::asFlatResult( + flatbuffers::FlatBufferBuilder& builder) { + std::vector> vec; + for (Variable* v : _variables) { + vec.emplace_back(v->asFlatVariable(builder)); + } + + auto vecOffset = builder.CreateVector(vec); + + return CreateFlatResult(builder, 0, vecOffset); +} +} // namespace graph +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/graph/impl/ExecutorConfiguration.cpp b/libnd4j/include/graph/impl/ExecutorConfiguration.cpp index f296ef3cd4df..f189fa850d3c 100644 --- a/libnd4j/include/graph/impl/ExecutorConfiguration.cpp +++ b/libnd4j/include/graph/impl/ExecutorConfiguration.cpp @@ -21,39 +21,35 @@ #include namespace sd { - namespace graph { - ExecutorConfiguration::ExecutorConfiguration(const sd::graph::FlatConfiguration *conf) { - if (conf != nullptr) { - _profilingMode = conf->profilingMode(); - _executionMode = conf->executionMode(); - _outputMode = conf->outputMode(); - _timestats = conf->timestats(); - _footprintForward = conf->footprintForward(); - _footprintBackward = conf->footprintBackward(); - _direction = conf->direction(); - } else { - _profilingMode = ProfilingMode_NONE; - _executionMode = ExecutionMode_SEQUENTIAL; - _outputMode = OutputMode_IMPLICIT; - _timestats = false; - } - }; +namespace graph { +ExecutorConfiguration::ExecutorConfiguration( + const sd::graph::FlatConfiguration *conf) { + if (conf != nullptr) { + _profilingMode = conf->profilingMode(); + _executionMode = conf->executionMode(); + _outputMode = conf->outputMode(); + _timestats = conf->timestats(); + _footprintForward = conf->footprintForward(); + _footprintBackward = conf->footprintBackward(); + _direction = conf->direction(); + } else { + _profilingMode = ProfilingMode_NONE; + _executionMode = ExecutionMode_SEQUENTIAL; + _outputMode = OutputMode_IMPLICIT; + _timestats = false; + } +}; - ExecutorConfiguration* ExecutorConfiguration::clone() { - auto clone = new ExecutorConfiguration(); - clone->_profilingMode = _profilingMode; - clone->_executionMode = _executionMode; - clone->_outputMode = _outputMode; - clone->_timestats = _timestats; - clone->_direction = _direction; - clone->_footprintForward = _footprintForward; - clone->_footprintBackward = _footprintBackward; +ExecutorConfiguration ExecutorConfiguration::clone() const { + return ExecutorConfiguration(*this); +}; - return clone; - }; - - flatbuffers::Offset ExecutorConfiguration::asFlatConfiguration(flatbuffers::FlatBufferBuilder &builder) { - return CreateFlatConfiguration(builder, 0, _executionMode, _profilingMode, _outputMode, _timestats, _footprintBackward, _footprintBackward); - } - } -} \ No newline at end of file +flatbuffers::Offset +ExecutorConfiguration::asFlatConfiguration( + flatbuffers::FlatBufferBuilder &builder) { + return CreateFlatConfiguration(builder, 0, _executionMode, _profilingMode, + _outputMode, _timestats, _footprintBackward, + _footprintBackward); +} +} // namespace graph +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/graph/impl/FlatUtils.cpp b/libnd4j/include/graph/impl/FlatUtils.cpp index e6984bb9705d..40378847cff3 100644 --- a/libnd4j/include/graph/impl/FlatUtils.cpp +++ b/libnd4j/include/graph/impl/FlatUtils.cpp @@ -18,99 +18,106 @@ // Created by raver119 on 22.11.2017. // -#include #include +#include #include #include -#include #include - +#include namespace sd { - namespace graph { - std::pair FlatUtils::fromIntPair(IntPair *pair) { - return std::pair(pair->first(), pair->second()); - } - - std::pair FlatUtils::fromLongPair(LongPair *pair) { - return std::pair(pair->first(), pair->second()); - } - - NDArray* FlatUtils::fromFlatArray(const sd::graph::FlatArray *flatArray) { - auto rank = static_cast(flatArray->shape()->Get(0)); - auto newShape = new Nd4jLong[shape::shapeInfoLength(rank)]; - memcpy(newShape, flatArray->shape()->data(), shape::shapeInfoByteLength(rank)); - - auto length = shape::length(newShape); - auto dtype = DataTypeUtils::fromFlatDataType(flatArray->dtype()); - - // empty arrays is special case, nothing to restore here - if (shape::isEmpty(newShape)) { - delete[] newShape; - return NDArrayFactory::empty_(dtype, nullptr); - } - // TODO fix UTF16 and UTF32 - if (dtype == UTF8) { - bool isBe = BitwiseUtils::isBE(); - bool canKeep = (isBe && flatArray->byteOrder() == sd::graph::ByteOrder_BE) || (!isBe && flatArray->byteOrder() == sd::graph::ByteOrder_LE); - - std::vector substrings(length); - std::vector shapeVector(rank); - for (int e = 0; e < rank; e++) - shapeVector[e] = newShape[e+1]; - - auto rawPtr = (void *)flatArray->buffer()->data(); - auto longPtr = reinterpret_cast(rawPtr); - auto charPtr = reinterpret_cast(longPtr + length + 1); - auto offsets = new Nd4jLong[length+1]; - for (Nd4jLong e = 0; e <= length; e++) { - auto o = longPtr[e]; - // FIXME: BE vs LE on partials - //auto v = canKeep ? o : BitwiseUtils::swap_bytes(o); - offsets[e] = o; - } - - for (Nd4jLong e = 0; e < length; e++) { - auto start = offsets[e]; - auto end = offsets[e+1]; - auto len = end - start; - - auto c = (char *) malloc(len+1); - CHECK_ALLOC(c, "Failed temp allocation", len + 1); - memset(c, '\0', len + 1); - memcpy(c, charPtr + start, len); - - std::string val(c); - substrings[e] = val; - free(c); - } - - delete[] offsets; - delete[] newShape; - // string order always 'c' - return NDArrayFactory::string_(shapeVector, substrings); - } - - - auto newBuffer = new int8_t[length * DataTypeUtils::sizeOf(dtype)]; - - BUILD_SINGLE_SELECTOR(dtype, DataTypeConversions, ::convertType(newBuffer, (void *)flatArray->buffer()->data(), dtype, ByteOrderUtils::fromFlatByteOrder(flatArray->byteOrder()), length), LIBND4J_TYPES); - - auto array = new NDArray(newBuffer, newShape, sd::LaunchContext::defaultContext(), true); - - delete[] newShape; - return array; - } - - flatbuffers::Offset FlatUtils::toFlatArray(flatbuffers::FlatBufferBuilder &builder, NDArray &array) { - auto byteVector = array.asByteVector(); - - auto fBuffer = builder.CreateVector(byteVector); - auto fShape = builder.CreateVector(array.getShapeInfoAsFlatVector()); - - auto bo = static_cast(BitwiseUtils::asByteOrder()); - - return CreateFlatArray(builder, fShape, fBuffer, static_cast(array.dataType()), bo); - } +namespace graph { +std::pair FlatUtils::fromIntPair(IntPair *pair) { + return std::pair(pair->first(), pair->second()); +} + +std::pair FlatUtils::fromLongPair(LongPair *pair) { + return std::pair(pair->first(), pair->second()); +} + +NDArray FlatUtils::fromFlatArray(const sd::graph::FlatArray *flatArray) { + auto rank = static_cast(flatArray->shape()->Get(0)); + auto newShape = new Nd4jLong[shape::shapeInfoLength(rank)]; + memcpy(newShape, flatArray->shape()->data(), + shape::shapeInfoByteLength(rank)); + + auto length = shape::length(newShape); + auto dtype = DataTypeUtils::fromFlatDataType(flatArray->dtype()); + + // empty arrays is special case, nothing to restore here + if (shape::isEmpty(newShape)) { + delete[] newShape; + return NDArrayFactory::empty(dtype); + } + // TODO fix UTF16 and UTF32 + if (dtype == UTF8) { + bool isBe = BitwiseUtils::isBE(); + bool canKeep = + (isBe && flatArray->byteOrder() == sd::graph::ByteOrder_BE) || + (!isBe && flatArray->byteOrder() == sd::graph::ByteOrder_LE); + + std::vector substrings(length); + std::vector shapeVector(rank); + for (int e = 0; e < rank; e++) shapeVector[e] = newShape[e + 1]; + + auto rawPtr = (void *)flatArray->buffer()->data(); + auto longPtr = reinterpret_cast(rawPtr); + auto charPtr = reinterpret_cast(longPtr + length + 1); + auto offsets = new Nd4jLong[length + 1]; + for (Nd4jLong e = 0; e <= length; e++) { + auto o = longPtr[e]; + // FIXME: BE vs LE on partials + // auto v = canKeep ? o : BitwiseUtils::swap_bytes(o); + offsets[e] = o; + } + + for (Nd4jLong e = 0; e < length; e++) { + auto start = offsets[e]; + auto end = offsets[e + 1]; + auto len = end - start; + + auto c = (char *)malloc(len + 1); + CHECK_ALLOC(c, "Failed temp allocation", len + 1); + memset(c, '\0', len + 1); + memcpy(c, charPtr + start, len); + + std::string val(c); + substrings[e] = val; + free(c); } -} \ No newline at end of file + + delete[] offsets; + delete[] newShape; + // string order always 'c' + return NDArrayFactory::string(shapeVector, substrings); + } + + auto newBuffer = new int8_t[length * DataTypeUtils::sizeOf(dtype)]; + + BUILD_SINGLE_SELECTOR( + dtype, DataTypeConversions, + ::convertType(newBuffer, (void *)flatArray->buffer()->data(), dtype, + ByteOrderUtils::fromFlatByteOrder(flatArray->byteOrder()), + length), + LIBND4J_TYPES); + + NDArray array(newBuffer, newShape, sd::LaunchContext::defaultContext(), true); + + delete[] newShape; + return array; +} + +flatbuffers::Offset FlatUtils::toFlatArray( + flatbuffers::FlatBufferBuilder &builder, NDArray &array) { + auto byteVector = array.asByteVector(); + + auto fBuffer = builder.CreateVector(byteVector); + auto fShape = builder.CreateVector(array.getShapeInfoAsFlatVector()); + + auto bo = static_cast(BitwiseUtils::asByteOrder()); + + return CreateFlatArray(builder, fShape, fBuffer, + static_cast(array.dataType()), bo); +} +} // namespace graph +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/graph/impl/FlowPath.cpp b/libnd4j/include/graph/impl/FlowPath.cpp index 79fe67b30110..f4119896289b 100644 --- a/libnd4j/include/graph/impl/FlowPath.cpp +++ b/libnd4j/include/graph/impl/FlowPath.cpp @@ -21,131 +21,124 @@ #include namespace sd { - namespace graph { +namespace graph { - void FlowPath::ensureNode(int nodeId) { - if (_states.count(nodeId) == 0) { - NodeState state(nodeId); - _states[nodeId] = state; - } - } +void FlowPath::ensureNode(int nodeId) { + if (_states.count(nodeId) == 0) { + NodeState state(nodeId); + _states[nodeId] = state; + } +} - void FlowPath::ensureFrame(int frameId) { - if (_frames.count(frameId) == 0) { - FrameState state(frameId); - _frames[frameId] = state; - } - } +void FlowPath::ensureFrame(int frameId) { + if (_frames.count(frameId) == 0) { + FrameState state(frameId); + _frames[frameId] = state; + } +} - void FlowPath::setInnerTime(int nodeId, Nd4jLong time) { - ensureNode(nodeId); +void FlowPath::setInnerTime(int nodeId, Nd4jLong time) { + ensureNode(nodeId); - _states[nodeId].setInnerTime(time); - } + _states[nodeId].setInnerTime(time); +} - void FlowPath::setOuterTime(int nodeId, Nd4jLong time) { - ensureNode(nodeId); +void FlowPath::setOuterTime(int nodeId, Nd4jLong time) { + ensureNode(nodeId); - _states[nodeId].setOuterTime(time); - } + _states[nodeId].setOuterTime(time); +} - Nd4jLong FlowPath::innerTime(int nodeId) { - ensureNode(nodeId); +Nd4jLong FlowPath::innerTime(int nodeId) { + ensureNode(nodeId); - return _states[nodeId].innerTime(); - } + return _states[nodeId].innerTime(); +} - Nd4jLong FlowPath::outerTime(int nodeId) { - ensureNode(nodeId); +Nd4jLong FlowPath::outerTime(int nodeId) { + ensureNode(nodeId); - return _states[nodeId].outerTime(); - } + return _states[nodeId].outerTime(); +} - bool FlowPath::isNodeActive(int nodeId) { - ensureNode(nodeId); +bool FlowPath::isNodeActive(int nodeId) { + ensureNode(nodeId); - return _states[nodeId].isActive(); - } - - void FlowPath::markNodeActive(int nodeId, bool isActive) { - ensureNode(nodeId); + return _states[nodeId].isActive(); +} - _states[nodeId].markActive(isActive); - } +void FlowPath::markNodeActive(int nodeId, bool isActive) { + ensureNode(nodeId); - int FlowPath::branch(int nodeId){ - ensureNode(nodeId); + _states[nodeId].markActive(isActive); +} - return _states[nodeId].branch(); - } +int FlowPath::branch(int nodeId) { + ensureNode(nodeId); - void FlowPath::markBranch(int nodeId, int index) { - ensureNode(nodeId); + return _states[nodeId].branch(); +} - _states[nodeId].markBranch(index); - } +void FlowPath::markBranch(int nodeId, int index) { + ensureNode(nodeId); - bool FlowPath::isFrameActive(Nd4jLong frameId) { - ensureFrame(frameId); + _states[nodeId].markBranch(index); +} - return _frames[frameId].wasActivated(); - } +bool FlowPath::isFrameActive(Nd4jLong frameId) { + ensureFrame(frameId); - void FlowPath::markFrameActive(Nd4jLong frameId, bool isActive) { - ensureFrame(frameId); + return _frames[frameId].wasActivated(); +} - _frames[frameId].markActivated(isActive); - } +void FlowPath::markFrameActive(Nd4jLong frameId, bool isActive) { + ensureFrame(frameId); - bool FlowPath::isRewindPlanned(Nd4jLong frameId) { - return _frames[frameId].isRewindPlanned(); - } + _frames[frameId].markActivated(isActive); +} - void FlowPath::planRewind(Nd4jLong frameId, bool reallyRewind) { - _frames[frameId].planRewind(reallyRewind); - } +bool FlowPath::isRewindPlanned(Nd4jLong frameId) { + return _frames[frameId].isRewindPlanned(); +} - int FlowPath::getRewindPosition(Nd4jLong frameId) { - return _frames[frameId].getRewindPosition(); - } +void FlowPath::planRewind(Nd4jLong frameId, bool reallyRewind) { + _frames[frameId].planRewind(reallyRewind); +} - void FlowPath::setRewindPosition(Nd4jLong frameId, int position) { - _frames[frameId].setRewindPosition(position); - } +int FlowPath::getRewindPosition(Nd4jLong frameId) { + return _frames[frameId].getRewindPosition(); +} - void FlowPath::setRewindPositionOnce(Nd4jLong frameId, int position) { - _frames[frameId].setRewindPositionOnce(position); - } +void FlowPath::setRewindPosition(Nd4jLong frameId, int position) { + _frames[frameId].setRewindPosition(position); +} - void FlowPath::registerFrame(Nd4jLong frameId) { - if (_frames.count(frameId) == 0) - ensureFrame(frameId); - } +void FlowPath::setRewindPositionOnce(Nd4jLong frameId, int position) { + _frames[frameId].setRewindPositionOnce(position); +} - void FlowPath::forgetFrame(Nd4jLong frameId) { - if (_frames.count(frameId) > 0) - _frames.erase(frameId); - } +void FlowPath::registerFrame(Nd4jLong frameId) { + if (_frames.count(frameId) == 0) ensureFrame(frameId); +} - void FlowPath::incrementNumberOfCycles(Nd4jLong frameId) { - _frames[frameId].incrementNumberOfCycles(); - } +void FlowPath::forgetFrame(Nd4jLong frameId) { + if (_frames.count(frameId) > 0) _frames.erase(frameId); +} - Nd4jLong FlowPath::getNumberOfCycles(Nd4jLong frameId) { - return _frames[frameId].getNumberOfCycles(); - } +void FlowPath::incrementNumberOfCycles(Nd4jLong frameId) { + _frames[frameId].incrementNumberOfCycles(); +} +Nd4jLong FlowPath::getNumberOfCycles(Nd4jLong frameId) { + return _frames[frameId].getNumberOfCycles(); +} - bool FlowPath::wasExecuted(int nodeId) { - return _states[nodeId].wasExecuted(); - } +bool FlowPath::wasExecuted(int nodeId) { return _states[nodeId].wasExecuted(); } - void FlowPath::markExecuted(int nodeId, bool wasExecuted) { - _states[nodeId].markExecuted(wasExecuted); - } +void FlowPath::markExecuted(int nodeId, bool wasExecuted) { + _states[nodeId].markExecuted(wasExecuted); +} - GraphProfile* FlowPath::profile() { - return &_profile; - } - } -} \ No newline at end of file +GraphProfile* FlowPath::profile() { return &_profile; } +} // namespace graph +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/graph/impl/FrameState.cpp b/libnd4j/include/graph/impl/FrameState.cpp index d312a4f39069..fc3fbb9d4b73 100644 --- a/libnd4j/include/graph/impl/FrameState.cpp +++ b/libnd4j/include/graph/impl/FrameState.cpp @@ -20,52 +20,34 @@ #include - namespace sd { - namespace graph { - FrameState::FrameState(Nd4jLong id) { - this->_id = id; - } +namespace graph { +FrameState::FrameState(Nd4jLong id) { this->_id = id; } - int FrameState::getNumberOfCycles() { - return _numberOfCycles; - } +int FrameState::getNumberOfCycles() { return _numberOfCycles; } - void FrameState::incrementNumberOfCycles() { - ++_numberOfCycles; - } +void FrameState::incrementNumberOfCycles() { ++_numberOfCycles; } - bool FrameState::wasActivated() { - return _activated; - } +bool FrameState::wasActivated() { return _activated; } - void FrameState::markActivated(bool reallyActivated) { - _activated = reallyActivated; - } +void FrameState::markActivated(bool reallyActivated) { + _activated = reallyActivated; +} - std::string &FrameState::getFrameName() { - return _name; - } +std::string &FrameState::getFrameName() { return _name; } - bool FrameState::isRewindPlanned() { - return _rewindPlanned; - } +bool FrameState::isRewindPlanned() { return _rewindPlanned; } - int FrameState::getRewindPosition() { - return _rewindPosition; - } +int FrameState::getRewindPosition() { return _rewindPosition; } - void FrameState::setRewindPosition(int pos) { - _rewindPosition = pos; - } +void FrameState::setRewindPosition(int pos) { _rewindPosition = pos; } - void FrameState::setRewindPositionOnce(int pos) { - if (_rewindPosition < 0) - _rewindPosition = pos; - } +void FrameState::setRewindPositionOnce(int pos) { + if (_rewindPosition < 0) _rewindPosition = pos; +} - void FrameState::planRewind(bool reallyPlanning) { - _rewindPlanned = reallyPlanning; - } - } -} \ No newline at end of file +void FrameState::planRewind(bool reallyPlanning) { + _rewindPlanned = reallyPlanning; +} +} // namespace graph +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/graph/impl/Graph.cpp b/libnd4j/include/graph/impl/Graph.cpp index a50d1f4b6edc..c67979bea9b7 100644 --- a/libnd4j/include/graph/impl/Graph.cpp +++ b/libnd4j/include/graph/impl/Graph.cpp @@ -18,1441 +18,711 @@ // @author raver119@gmail.com // -#include #include -#include +#include +#include +#include +#include #include -#include -#include -#include -#include +#include #include -#include #include #include +#include +#include +#include +#include +#include -namespace sd { - namespace graph { - std::vector* Graph::getAllNodes() { - return &_handles; - } - - std::vector* Graph::getPlaceholders() { - return _variableSpace->getPlaceholders(); - } - - int Graph::numberOfPlaceholders() { - return _variableSpace->numberOfPlaceholders(); - }; - - Nd4jLong Graph::estimateRequiredMemory() { - - Nd4jLong result = 0L; - Nd4jLong lastStep = 0L; - - std::vector shapes; - MAP_IMPL, Nd4jLong const*> shapesMap; - - int cntFD = 0; +#include +#include - // we loop in similar way to execution - for (int l = 0; l < (int) _onion->size(); l++) { - int layerSize = _onion->count(l) == 1 ? _onion->at(l)->size() : 0; +namespace sd { +namespace graph { +const std::vector> &Graph::placeholders() const { + return _variableSpace.placeholders(); +} +int Graph::numberOfPlaceholders() const { + return _variableSpace.numberOfPlaceholders(); +}; - for (int n = 0; n < layerSize; n++) { - Node* node = _onion->at(l)->at(n); +const ExecutorConfiguration &Graph::getExecutorConfiguration() const { + return _configuration; +} - /* - * Limited number of options here: - * - * 1) Op is inplace, so adds nothing to total - * 2) Op is not inplace, and 1:1 transform - * 3) Op is reduction (i.e. sum) - * 4) Op is multiplicator (i.e. im2col) - */ - if (node->hasCustomOp()) { - //if (node->isInplace()) { - // continue; - //} +VariableSpace &Graph::variableSpace() const { + return const_cast(_variableSpace); +} +Graph::~Graph() { - nd4j_debug("Trying estimation [%i] on [%s]\n", node->id(), node->getCustomOp()->getOpName()->c_str()); +} - auto op = node->getCustomOp(); - auto in = node->input()->at(0); +int Graph::idByName(const std::string &nodeName) const { + if (_symbolicLookupTable.count(nodeName) == 0) + throw std::runtime_error("Can't find node [" + nodeName + "]"); - auto block = node->getContextPrototype(); - std::vector inputShapes; - int *oldShape; - for (auto v: *node->input()) { - nd4j_debug(" inputs for estimation are: %i:%i\n", v.first, v.second); - if (v.first < 0) { - inputShapes.push_back(_variableSpace->getVariable(v.first)->getNDArray()->shapeInfo()); - } else { - inputShapes.push_back(shapesMap.at(v)); - } - } + return _symbolicLookupTable.at(nodeName); +} - Context ctx(block, _variableSpace); +void Graph::addVariable(const std::string &name, NDArray &array) { + int id = _maxId++; + _symbolicLookupTable[name] = id; + _variableSpace.putVariable(id, 0, array); +} - ShapeList inSha(inputShapes); - auto outSha = op->calculateOutputShape(&inSha, ctx); +void Graph::addVariable(const std::string &name, NDArray &&array) { + auto lvalue = array; + addVariable(name, lvalue); +} - int cnt = 0; - for (auto newShape: *outSha->asVector()) { - std::pair pairAddr(node->id(), cnt++); - std::pair, Nd4jLong const*> pairShape(pairAddr, newShape); +void Graph::addNode(Node &&node, + const std::vector &inputs) { + auto lvalue = std::move(node); + addNode(lvalue, inputs); +} - shapesMap.insert(pairShape); +void Graph::addNode(Node &node, const std::vector &inputs) { + // temporary check. basically we're okay if Node has id defined + if (node.id() != 0) + throw std::runtime_error("Graph::addNode - Node has id defined"); + + if (node.name().empty()) { + // if name is empty we'll make up a name based on Op name + } else { + if (_symbolicLookupTable.count(node.name()) > 0) + throw std::runtime_error("Graph::addNode - Graph alread has Node [" + + node.name() + "] defined"); + } + + // node must have numeric id + node.setId(_maxId++); + _symbolicLookupTable[node.name()] = node.id(); + + // converting string ids to numeric ones + for (auto &v : inputs) { + // we don't allow self-references + if (v == node.name()) + throw unresolved_input_exception::build( + "Graph::addNode - Node references itself", v); + + node.pickInput(idByName(v), 0); + } + + // actually storing the node. Later, topological sort will be applied on this + // map + _unmapped[node.id()] = node; +} - if (!block->isInplace() && !node->isInplace()) - result += shape::length(newShape) * DataTypeUtils::sizeOfElement(node->dataType()); +Graph::Graph(const FlatGraph *flatGraph, const GraphMemoryManager &memoryManager): _memoryManager(memoryManager) { + bool trusted = flatGraph != nullptr; + + // if there was no exec configuration in flatgraph - create default one + if (flatGraph != nullptr && flatGraph->configuration() != nullptr) { + _configuration = ExecutorConfiguration(flatGraph->configuration()); + } else + _configuration = ExecutorConfiguration(); + + // if memory reqs were set - initialize workspace + if (_configuration._footprintForward > 0) { + _workspace.expandBy(_configuration._footprintForward); + } + + // parsing variables here + if (flatGraph != nullptr && flatGraph->variables() != nullptr && + flatGraph->variables()->size() > 0) { + for (unsigned int e = 0; e < flatGraph->variables()->size(); e++) { + auto flatVar = flatGraph->variables()->Get(e); + std::pair pair(flatVar->id()->first(), flatVar->id()->second()); + + auto var = std::make_shared(flatVar); + if (flatVar->name() != nullptr) { + var->setName(flatVar->name()->str()); + _symbolicLookupTable[var->name()] = pair.first; + } + + _variableSpace.putVariable(pair, var); + + if (var->isPlaceholder()) + _placeholders.emplace_back(var->name()); + } + } + + // at this point we expect all variables are already registered + // we're saving outputs only if explicit mode is set + if (_configuration._outputMode == OutputMode_EXPLICIT || + _configuration._outputMode == OutputMode_EXPLICIT_AND_IMPLICIT) { + if (flatGraph != nullptr && flatGraph->outputs() != nullptr) { + for (unsigned int e = 0; e < flatGraph->outputs()->size(); e++) { + auto out = flatGraph->outputs()->Get(e); + std::pair vp(out->first(), out->second()); + if (!_variableSpace.hasVariable(vp)) { + nd4j_verbose("Non-existent variable requested: %i\n", out); + throw std::runtime_error("Non-existent variable requested"); + } + } + } + } + + // rolling through nodes + if (flatGraph != nullptr && flatGraph->nodes() != nullptr && + flatGraph->nodes()->size() > 0) { + for (unsigned int e = 0; e < flatGraph->nodes()->size(); e++) { + auto node = flatGraph->nodes()->Get(e); + + if (node->output() == nullptr || node->output()->size() == 0) { + nd4j_verbose("Orphan node detected: %i; AutoOutput to be considered\n", + node->id()); + } + + nd4j_debug("Node name: [%s]\n", node->name()->c_str()); + Node nnode(node); + // just filling list of nodes + _unmapped[nnode.id()] = nnode; + + if (!nnode.name().empty()) + _symbolicLookupTable[nnode.name()] = nnode.id(); + } + } - shapes.push_back(newShape); - } + // now, once everything is deserializerd, time to roll through Variables/Nodes and update dependencies + for (const auto &v: _unmapped) + v.second.actualizeDependencies(_symbolicLookupTable); - delete outSha; - } else if (node->getOpClass() == OpClass_TRANSFORM) { - auto vec = node->input(); + for (const auto &v:_variableSpace.variables()) + v->actualizeDependencies(_symbolicLookupTable); +} - auto in = node->input()->at(0); - if (in.first < 0) { +/** + * This method returns total number of nodes in this graph + * @return + */ +int Graph::size() const { return _unmapped.size(); } + +Nd4jStatus Graph::validate() { + throw std::runtime_error("Graph::validate - method not implemented"); +}; + +void Graph::printOutNode(const Node &node) const { + nd4j_printf("%i. ", node.id()); + switch (node.opType()) { + case OpType_CUSTOM: { + printf("%s; ", node.customOp()->getOpName().c_str()); + } break; + case OpType_LOGIC: { + printf("%s; ", EnumUtils::_LogicOpToString(node.opNum())); + } break; + default: { + printf("%s:{%i}; ", EnumUtils::_OpTypeToString(node.opType()), + (int)node.opNum()); + } + } + + nd4j_printf("Inputs: [", ""); + // auto block = node->getBlock(); + for (int e = 0; e < node.inputs().size(); e++) { + auto in = node.inputs()[e]; + printf("{%i:%i}", in.first, in.second); + if (e < node.inputs().size() - 1) nd4j_printf(", ", ""); + } + + if (node.opType() == OpType_CUSTOM) { + auto ctx = node.contextPrototype(); + if (ctx.numI() > 0) { + printf("]; iArgs: ["); + + for (int e = 0; e < ctx.numI(); e++) { + printf("%i", ctx.getIArguments().at(e)); + if (e < ctx.getIArguments().size() - 1) nd4j_printf(", ", ""); + } + } + } - auto x = _variableSpace->getVariable(in); - auto z = _variableSpace->getVariable(node->id()); + nd4j_printf("]; \n", ""); - auto newShape = new Nd4jLong[shape::shapeInfoLength(x->getNDArray()->shapeInfo())]; - memcpy(newShape, x->getNDArray()->shapeInfo(), shape::shapeInfoByteLength(x->getNDArray()->shapeInfo())); + // printf("\n"); + fflush(stdout); +} - std::pair pairAddr(node->id(), 0); - std::pair, Nd4jLong const*> pairShape(pairAddr, newShape); +void Graph::printOut() { + // print placeholders + if (_placeholders.size() > 0) { + nd4j_printf("\nPrinting out Placeholders...\n", ""); + for (auto &v:_placeholders) { + auto var = _variableSpace.getVariable(v); + auto shape = ShapeUtils::shapeAsString(var->shape()); + auto dtype = DataTypeUtils::asString(var->dataType()); + nd4j_printf("<%s> <%i> dtype: %s; shape: %s; \n", v.c_str(), var->id(), dtype.c_str(), shape.c_str()); + } + } + + // print variables + if (_variableSpace.totalEntries() > 0) { + nd4j_printf("\nPrinting out Variables...\n", ""); + auto vars = _variableSpace.variables(); + + for (auto &v : vars) { + if (v->hasNDArray()) { + auto shape = ShapeUtils::shapeAsString(v->getNDArray().get()); + auto values = v->getNDArray()->asString(16); + auto dtype = DataTypeUtils::asString(v->getNDArray()->dataType()); + + if (!v->getName().empty()) { + nd4j_printf("<%s> <%i:%i> dtype: %s; shape: %s; values: %s;\n", + v->getName().c_str(), v->id(), v->index(), dtype.c_str(), + shape.c_str(), values.c_str()); + } else { + nd4j_printf("<%i:%i> dtype: %s; shape: %s; values: %s;\n", v->id(), + v->index(), dtype.c_str(), shape.c_str(), values.c_str()); + } + } else if (v->hasNDArrayList()) { + // TODO: add better NDArrayList printout + nd4j_printf("<%i:%i> holds ArrayList", v->id(), v->index()); + } + } + } - shapesMap.insert(pairShape); + fflush(stdout); - if (!node->isInplace()) - result += shape::length(newShape) * DataTypeUtils::sizeOfElement(node->dataType()); + if (size() > 0) { + nd4j_printf("\nPrinting out Nodes...\n", ""); + optimizedGraph().printOut(); + } +} - shapes.push_back(newShape); - } else { - auto prevShape = shapesMap.at(in); +Nd4jStatus Graph::validateNode(Node *node) { + // TODO: to be implemented + return ND4J_STATUS_OK; +} - auto newShape = new Nd4jLong[shape::shapeInfoLength(prevShape)]; - memcpy(newShape, prevShape, shape::shapeInfoByteLength(prevShape)); +void Graph::replaceState(VariableSpace *state, + const ExecutorConfiguration &configuration) { + _variableSpace = *state; + _configuration = configuration; +} - std::pair pairAddr(node->id(), 0); - std::pair, Nd4jLong const*> pairShape(pairAddr, newShape); +Graph Graph::cloneWithProxy() const { + Graph clone; - shapesMap.insert(pairShape); + // clone.replaceState(new VariableProxy(&this->_variableSpace), + // this->_configuration); - if (!node->isInplace()) - result += shape::length(newShape) * DataTypeUtils::sizeOfElement(node->dataType()); + // return clone; + throw std::runtime_error("Graph::cloneWithProxy - Not implemented yet"); +} - shapes.push_back(newShape); - } +Graph *Graph::clone() const { + auto clone = new Graph(); - } else if (node->getOpClass() == OpClass_REDUCTION) { - Nd4jLong const* newShape = nullptr; + // clone->replaceState(&this->_variableSpace, this->_configuration.clone()); - // if that's scalar output - we don't care about previous node - if (node->getDimensions()->size() == 0 || (node->getDimensions()->size() == 1 && node->getDimensions()->at(0) == sd::DataTypeUtils::max())) { -// auto aNewShape = new Nd4jLong[8]; -// -// aNewShape[0] = 2; -// aNewShape[1] = 1; -// aNewShape[2] = 1; -// aNewShape[3] = 1; -// aNewShape[4] = 1; -// aNewShape[5] = 8192; // set type as FLOAT32 by default -// aNewShape[6] = 1; -// aNewShape[7] = 99; - newShape = ConstantShapeHelper::getInstance().createShapeInfo(DataType::FLOAT32, 'c', {1,1}); - } else { - auto in = node->input()->at(0); - - Nd4jLong const* oldShape = nullptr; - // calculate tads here - if (in.first < 0) { - auto x = _variableSpace->getVariable(in)->getNDArray(); - - oldShape = x->shapeInfo(); - } else { - - oldShape = shapesMap.at(in); - } - - //shape::TAD tad(oldShape, node->getDimensions()->data(), node->getDimensions()->size()); - auto numTads = shape::tadLength(oldShape, node->getDimensions()->data(), node->getDimensions()->size()); - Nd4jLong shape[2] = {1, (int) numTads}; - newShape = ConstantShapeHelper::getInstance().createShapeInfo(ArrayOptions::dataType(oldShape), 'c', 2, shape); - } - - std::pair pairAddr(node->id(), 0); - std::pair, Nd4jLong const*> pairShape(pairAddr, newShape); - - shapesMap.insert(pairShape); - - result += shape::length(newShape) * DataTypeUtils::sizeOfElement(node->dataType()); - - shapes.push_back(newShape); - } else if (node->getOpClass() == OpClass_MULTIPLICATOR) { - // can't be in non-special op - } - - - cntFD++; - } - } - - // this is the only place where we deallocate shapes. - //if (_variableSpace->launchContext()->getWorkspace() == nullptr) - // for (auto v: shapes) - // delete[] v; - - return result; - } + throw std::runtime_error("Graph::clone - not implemented yet"); +} - void Graph::pushToOutputOnce(int id) { - if (std::find(_output.begin(), _output.end(), id) == _output.end()) - _output.emplace_back(id); - } +Nd4jLong Graph::hashCode() const { + throw std::runtime_error("Graph::hashCode - not implemented yet"); +} - void Graph::addOutput(int id) { - if (_configuration->_outputMode == OutputMode_EXPLICIT || _configuration->_outputMode == OutputMode_EXPLICIT_AND_IMPLICIT) - pushToOutputOnce(id); - } +Graph Graph::fromFlatBuffers(const char *fileName, + const GraphMemoryManager &memoryManager) { + // check if file exists + if (!FileUtils::fileExists(fileName)) + throw std::runtime_error("Graph file doesn't exist"); + + // get file size + auto fsize = FileUtils::fileSize(fileName); + Nd4jLong *ref; + void *ptrGraph; + + // TODO: check if mmap is supported + if (true) { + // mmap this file + ref = ::mmapFile(nullptr, fileName, fsize); + ptrGraph = reinterpret_cast(ref[0]); + memoryManager.track(std::make_shared(ref, std::make_shared())); + } else { + // if mmap is not supported - load it directly + + ptrGraph = new uint8_t[fsize]; + auto data = reinterpret_cast(ptrGraph); + + FILE *in = fopen(fileName, "rb"); + int cnt = 0; + int b = 0; + while (cnt < fsize) { + b = fread(data + cnt, 1, fsize < 16384 ? fsize : 16384, in); + + cnt += b; + } + fclose(in); + } - ExecutorConfiguration * Graph::getExecutorConfiguration() { - return _configuration; - } + return fromFlatPointer(ptrGraph, memoryManager); +} - std::vector * Graph::fetchOutputs() { - auto res = new std::vector(); - - nd4j_debug("Graph output size: %i\n", _output.size()); - for (int e = 0; e < (int) _output.size(); e++) { - auto nodeId = _output.at(e); - nd4j_debug("Output node: %i\n", nodeId); - - for (int e = 0; e < DataTypeUtils::max(); e++) { - if (_variableSpace->hasVariable(nodeId, e)) { - res->push_back(_variableSpace->getVariable(nodeId, e)); - } else { - if (e == 0) { - throw unresolved_output_exception::build("Can't find output variable", nodeId, e); - } else - break; - } - } - } - - return res; - } +Graph Graph::fromFlatPointer(void *ptr, + const GraphMemoryManager &memoryManager) { + // get FlatGraph out of it + auto fg = GetFlatGraph(reinterpret_cast(ptr)); - MAP_IMPL * Graph::getMapped() { - return _mapped; - } + // return Graph from this FlatGraph + return Graph(fg, memoryManager); +} - MAP_IMPL *>* Graph::getOnion() { - return _onion; - } +Graph Graph::importFromTensorFlow(const char *fileName) { + throw std::runtime_error("Graph::importFromTensorFlow() not implemented yet"); + /* + if (fileName == nullptr) + return nullptr; - void Graph::injectNode(Node *node) { - if (node->getLayer() < 0) - throw std::runtime_error("Only nodes with non-negative layer defined can be inserted"); + int fd = open(fileName, O_RDONLY); - std::pair pair(node->id(), node); - if (_mapped->count(pair.first) > 0) - return; + if (fd < 0) { + nd4j_printf("File not found: [%s]\n", fileName); + return nullptr; + } - nd4j_debug("Node_%i mapped to layer_%i\n", node->id(), node->getLayer()); + nd4j_verbose("Trying to load TF GraphDef from file [%s]\n", fileName); + tensorflow::GraphDef graphDef; + bool res = graphDef.ParseFromFileDescriptor(fd); - _onion->at(node->getLayer())->push_back(node); - _mapped->insert(pair); - } + // trying to read graph as text + if(!res) { + close(fd); + fd = open(fileName, O_RDONLY); - void Graph::expandOnion(int newLayer) { - if (_onion->count(newLayer) > 0) - return; + google::protobuf::io::FileInputStream fileInput(fd); + fileInput.SetCloseOnDelete(true); - std::vector *rootList = new std::vector(); - std::pair*> pair(newLayer, rootList); - _onion->insert(pair); - } + if (!google::protobuf::TextFormat::Parse(&fileInput, &graphDef)) { + nd4j_printf("Failed to read file\n",""); + } else { + res = true; + } + } + + close(fd); + + if (!res) + return nullptr; + + auto graph = new Graph(); + auto variableSpace = graph->variableSpace(); + + std::map variablesMap; + + int variablesCounter = 0; + int nodesCounter = 0; + nd4j_verbose("Number of nodes in graphDef: %i\n", graphDef.node_size()); + for (int n = 0; n < graphDef.node_size(); n++) { + auto node = graphDef.node(n); + + // if that's external variable - we put it to variable space + if (strcmp(TF_VAR, node.op().c_str()) == 0 || strcmp(TF_CONST, + node.op().c_str()) == 0 || strcmp(TF_INPUT, node.op().c_str()) == 0) { + nd4j_printf("Variable found: %s\n", node.name().c_str()); + auto variable = new Variable(); + variable->setName(new std::string(node.name().c_str())); + variable->setId(--variablesCounter); + variableSpace->putVariable(variable->id(), variable); + + std::pair pair(node.name(), variable->id()); + variablesMap.insert(pair); + + // TODO: we might want to have something like that. + // it basically just gives input validation option, since settles + expectations for input if (strcmp(TF_INPUT, node.op().c_str()) == 0) continue; + + // checking shape, not applicable to input, since it can vary + if (node.attr().count("shape")) { + auto attr = node.attr().at("shape"); + int dims = attr.shape().dim_size(); + + if (dims > 0) { + std::vector __shape; + + // we don't have rank1 arrays. vector is 2d. + if (dims == 1) + __shape.push_back(1); + + // roll through dimensions + for (auto s: attr.shape().dim()) { + __shape.push_back((int) s.size()) ; + } - VariableSpace * Graph::getVariableSpace() { - return _variableSpace; - } - - Graph::~Graph() { - for (auto &v: *_mapped) - delete v.second; + variable->setNDArray(new NDArray('c', __shape)); - for (auto &v: _unmapped) - delete v.second; + nd4j_printf("Shape found: %i dims;\n", dims); + variable->getNDArray()->printShapeInfo(); + } + } - for (auto &v: *_onion) - delete v.second; + // checking tensor attached + if (node.attr().count("value")) { + auto attr = node.attr().at("value"); + // int + if (attr.tensor().dtype() == ::tensorflow::DataType::DT_INT32) { + nd4j_verbose("Int size: %i\n", attr.tensor().int_val_size()); - for (auto v: _scopes) - delete v; + Nd4jLong __length = 0; - delete _mapped; - delete _nodes; - delete _variableSpace; - delete _onion; - delete _configuration; - } + nd4j_verbose("Tensor has shape: %i\n", + attr.tensor().has_tensor_shape()); if (attr.tensor().has_tensor_shape()) { + auto shape = attr.tensor().tensor_shape(); + int dims = shape.dim_size(); - void Graph::addNode(Node *node) { - _built.store(false); - - if (node->opType() == OpType_LOGIC) { - // nd4j_debug("Adding LogicOp [%i]\n", node->opNum()); - // SCOPE - if (node->opNum() == logic::Scope) { - auto scope = new Scope(node->id(), node->getName() != nullptr ? node->getName()->c_str() : ""); - _mappedScopes[node->id()] = scope; - _scopes.push_back(scope); - } - } - - auto cname = node->getName() == nullptr ? nullptr : node->getName()->c_str(); - auto nodeState = _variableSpace->hasVariable(node->id()) ? _variableSpace->getVariable(node->id()) :new Variable(nullptr, cname, node->id()); - if (node->getName() != nullptr) - nodeState->setName(node->getName()); - - - if (node->isInplace()) - nodeState->markRemovable(false); - - _handles.push_back(node); - - - _nodes->emplace_back(node->id()); - - // storing node state now - _variableSpace->putVariable(node->id(), nodeState); - - // here we're filling our blocks with future variables - if (node->opType() == OpType_LOGIC && node->opNum() == logic::While) { - // filling while - int inputs = node->input()->size(); - for (int e = 0; e < inputs - 2; e++){ - auto deepVar = new Variable(nullptr, nullptr, node->id(), e); - - std::pair id(node->id(), e); - _variableSpace->putVariable(id, deepVar); - } - - } else if (node->hasCustomOp()) { - // custom ops require Block inside. but we'll set it inside buildGraph - - // TODO: we want to change this, to make blocks thread-local/session-local - ContextPrototype* block = nullptr; - - if (!node->hasBlockAttached()) { - block = new ContextPrototype(node->getCustomOp()->getOpDescriptor(), node->id()); - node->setContextPrototype(block); - } else - block = node->getContextPrototype(); - - - if (!block->hasVariablesFilled()) { - - for (uint32_t e = 0; e < node->input()->size(); e++) { - auto p = node->input()->at(e); - - block->pickInput(p); - } - } - - // and might have > 1 output - if (node->getCustomOp()->getOpDescriptor()->getNumberOfOutputs() > 1) { - for (int e = 1; e < node->getCustomOp()->getOpDescriptor()->getNumberOfOutputs(); e++) { - auto deepVar = new Variable(nullptr, nullptr, node->id()); - //deepVar->setId(node->id()); - deepVar->setId(node->id(), e); - if (node->isInplace()) - deepVar->markRemovable(false); - - std::pair id(node->id(), e); - _variableSpace->putVariable(id, deepVar); - } - } else { - // we need to check, if we should propagate output of this variable somewhere - for (int e = 0; e < node->output()->size(); e++) { - auto out = node->output()->at(e); - if (out.first < 0) { - nd4j_debug("Node [%i] will be propagating its output to Variable [%i]\n", node->id(), out.first); - auto extVar = _variableSpace->getVariable(out); - if (extVar->hasNDArray()) { - nodeState->setNDArray(extVar->getNDArray()); - nodeState->markRemovable(false); - } - } - } - } - } - - // we're saving only ops that have internal outpus here - if (_configuration->_outputMode == OutputMode_VARIABLE_SPACE) - if (node->hasInternalOutputs()) - pushToOutputOnce(node->id()); - - // if outputs are undefined, we have to auto-create variable - if (node->output()->size() == 0 || (node->output()->size() == 1 && node->output()->at(0).first == 0)){ - Variable* var; - if (!_variableSpace->hasVariable(node->id())) { - var = new Variable(); - } else { - var = _variableSpace->getVariable(node->id()); - } - // nd4j_logger("Adding auto output variable; Output size: %i\n", node->output()->size()); - - var->setId(node->id()); - var->setName(node->getName()); - _variableSpace->putOutputVariable(var); - //node->pickExternalOutput(var->id()); - - this->_autos.push_back(var->id()); - -// } - } else if (node->hasExternalOutputs()) { - // TODO: we might want this behavior configurable! - nd4j_logger("Adding specific output variable: Outputs: %i; HasInternal: %i;\n", node->output()->size(), node->hasInternalOutputs()); - - // we're pushing this node to output only - if ((!node->hasInternalOutputs() && (_configuration->_outputMode == OutputMode_IMPLICIT || _configuration->_outputMode == OutputMode_EXPLICIT_AND_IMPLICIT)) ) { - for (int e = 0; e < (int) node->output()->size(); e++) { - if (node->output()->at(e).first < 0) - pushToOutputOnce(node->output()->at(e).first); - } - - nd4j_logger("Loop finished: %i outputs now\n", this->_output.size()); - } - } - - // ops that are tied to specific scope are never placed into the structure. - if (node->isScoped()) { - if (_mappedScopes.count(node->scopeId()) < 1) { - nd4j_printf("Requested scope [%i/%s] wasn't created yet\n", node->scopeId(), node->scopeName()->c_str()); - throw std::invalid_argument("Unknown scope requested"); - } - - Scope* scope = _mappedScopes.at(node->scopeId()); - scope->push_back(node); - - return; - } - - std::pair pair(node->id(), node); - // nd4j_debug("Adding node_%i\n", node->id()); - // if model has only external variables as input - it goes to first layer, no matter what. - if (node->hasExternalInputs() && !node->hasInternalInputs()) { - node->setLayer(0); - - injectNode(node); - - // nd4j_logger("A Node_%i mapped to layer_%i; Output: %i;\n", node->id(), node->getLayer(), node->output()->at(0)); - - return; - } else { - // in some cases we're able to put stuff immediately - if (node->hasInternalInputs() && !node->hasExternalInputs() && node->input()->size() == 1) { - - bool automapAllowed = true; - for (int e = 0; e < node->input()->size(); e++) { - auto cInput = node->input()->at(e); - int cFirst = cInput.first; - if (_mapped->count(cFirst) == 0) { - automapAllowed = false; - break; - } - } - - // we only can put single input nodes, whose outputs were not mapped yet - //if (_mapped->count(node->input()->at(0).first) == 1 && (node->output()->size() == 0 || _mapped->count(node->output()->at(0).first) == 0)) { - if (automapAllowed) { - auto parent = _mapped->at(node->input()->at(0).first); - int nLayer = parent->getLayer() + 1; - - expandOnion(nLayer); - node->setLayer(nLayer); - injectNode(node); - - if (node->output()->size() > 0) { - nd4j_logger("Node_%i mapped to layer_%i; Output: %i;\n", node->id(), node->getLayer(), node->output()->at(0)); - } else { - nd4j_logger("Node_%i mapped to layer_%i; Output: none;\n", node->id(), node->getLayer()); - } - - return; - } - } /*else if (node->opType() == OpType_LOGIC && node->opNum() == 10) { - // Scopes are just being added. They won't be executed on their own anyway. - - int nLayer = _onion->size(); - - expandOnion(nLayer); - node->setLayer(nLayer); - injectNode(node); - - nd4j_logger("Node_%i mapped Scope to layer_%i; Output: %i;\n", node->id(), node->getLayer(), node->output()->at(0)); - - return; - } -*/ - // otherwise we're putting it to unmapped space for further sorting - _unmapped.insert(pair); - _unmappedMap.emplace_back(pair.first); - nd4j_debug("adding: %i\n", pair.first); - } - } + if (dims > 0) { + std::vector __shape; + // we don't have rank1 arrays. vector is 2d. + if (dims == 1) + __shape.push_back(1); - Nd4jStatus Graph::buildGraph() { - if (_built.load()) { - prepareOutputs(); - return ND4J_STATUS_OK; - } + // roll through dimensions + for (auto s: shape.dim()) { + __shape.push_back((int) s.size()); + } - typename MAP_IMPL::iterator fit; - int cnts = 0; - for ( fit = _unmapped.begin(); fit != _unmapped.end(); fit++ ) { - int tK = fit->first; - int tF = _unmappedMap.at(cnts++); - } - - int buildCnt = 0; - int buildLimit = _unmapped.size() * 2; - while (_unmapped.size() > 0) { - - int sz = _unmapped.size(); - int sf = _unmappedMap.size(); - - std::vector queue; - - // first pass for unmapped nodes, we try to build tale here - typename MAP_IMPL::iterator it; - int cntf = 0; - nd4j_debug("-----------\n",""); - for ( it = _unmapped.begin(); it != _unmapped.end(); it++ ) { - auto node = it->second; - int tK = it->first; - int tF = _unmappedMap.at(cntf++); - - //nd4j_printf("tK: %i; tF: %i\n", tK, tF); - //for (int f = 0; f < sz; f++) { - // auto node = _unmapped.at(_unmappedMap.at(f)); - - - // single-input node - if (node->input()->size() == 1) { - - if (node->getName() == nullptr) { - nd4j_debug("Trying SI Node_%i\n", node->id()); - } else { - nd4j_debug("Trying SI Node_%i:[%s]\n", node->id(), node->getName()->c_str()); - } - - int iNode = node->input()->at(0).first; - if (iNode < 0 || _variableSpace->hasExternalVariable(iNode)) { - // this is external variable, should we check, who's the last user of this variable? - int lastLayer = _onion->size(); - expandOnion(lastLayer); - - node->setLayer(lastLayer); - this->injectNode(node); - - if (node->hasCustomOp()) { - ContextPrototype* block = nullptr; - - if (!node->hasBlockAttached()) { - block = new ContextPrototype(node->getCustomOp()->getOpDescriptor(), node->id()); - node->setContextPrototype(block); - } else - block = node->getContextPrototype(); - - - if (!block->hasVariablesFilled()) { - - for (int e = 0; e < node->input()->size(); e++) { - auto p = node->input()->at(e); - - block->pickInput(p); - } - } - } - } else if (_mapped->count(iNode) > 0) { - int maxLayer = _mapped->at(iNode)->getLayer() + 1; - - node->setLayer(maxLayer); - if (_onion->count(maxLayer) == 0) - expandOnion(maxLayer); - - this->injectNode(node); - queue.emplace_back(node->id()); - - if (node->hasCustomOp()) { - ContextPrototype* block = nullptr; - - if (!node->hasBlockAttached()) { - block = new ContextPrototype(node->getCustomOp()->getOpDescriptor(), node->id()); - node->setContextPrototype(block); - } else - block = node->getContextPrototype(); - - - if (!block->hasVariablesFilled()) { - - for (uint32_t e = 0; e < node->input()->size(); e++) { - auto p = node->input()->at(e); - - block->pickInput(p); - } - } - } - } else - continue; - - //_unmapped.erase(node->id()); - queue.emplace_back(node->id()); - } else { - // multi-input node - if (node->getName() == nullptr) { - nd4j_debug("Trying MI Node_%i\n", node->id()); - } else { - std::string np = *(node->getName()); - nd4j_debug("Trying MI Node_%i:[%s]\n", node->id(), node->getName()->c_str()); - } - - int maxLayer = 0; - bool breaker = false; - for (unsigned int e = 0; e < node->input()->size(); e++) { - int nodeId = node->input()->at(e).first; - - // if input node wasn't mapped yet - we'll have skip it in this round - if (_mapped->count(nodeId) == 1) { - auto iNode = _mapped->at(nodeId); - - if (maxLayer < iNode->getLayer()) - maxLayer = iNode->getLayer(); - } else - if (node->opType() == OpType_LOGIC) { - // just allow it? - } else // checking if that's static variable - if (nodeId > 0 && !_variableSpace->hasExternalVariable(nodeId)) { - breaker = true; - break; - } - } - - if (breaker) - continue; - - maxLayer++; - if (_onion->count(maxLayer) == 0) - expandOnion(maxLayer); - - node->setLayer(maxLayer); - injectNode(node); - queue.emplace_back(node->id()); - - if (node->hasCustomOp()) { - ContextPrototype* block = nullptr; - - if (!node->hasBlockAttached()) { - block = new ContextPrototype(node->getCustomOp()->getOpDescriptor(), node->id()); - node->setContextPrototype(block); - } else - block = node->getContextPrototype(); - - if (!block->hasVariablesFilled()) { - - for (uint32_t e = 0; e < node->input()->size(); e++) { - auto p = node->input()->at(e); - - block->pickInput(p); - } - } - } - } - } - - for (auto &v: queue) - _unmapped.erase(v); - - // second pass is mover, we'll be moving onion layers around here - buildCnt++; - if (buildCnt > buildLimit) { - nd4j_printf("Unable to build graph, probably unmapped nodes, or something: %i nodes left\n", _unmapped.size()); - for (auto v: _unmapped) { - Node* node = v.second; - nd4j_printf("Unmapped node: [%i]\n", node->id()); - } - - throw std::runtime_error("Unable to build graph"); - } - } - - if (_unmapped.size() == 0) - _built.store(true); - - prepareOutputs(); - - return sd::Status::OK(); - } + variable->setNDArray(new NDArray('c', __shape)); + __length = variable->getNDArray()->lengthOf(); - void Graph::tagInplaceNodes() { - // just calling, in case it wasn't built before - if (!_built.load()) - this->buildGraph(); - - bool buildRef = false; - - // checking for non-refenenced nodes - for (auto v: *_nodes) { - // skipping unmapped nodes - if (_mapped->count(v) == 0) - continue; - - Node* node = _mapped->at(v); - if (node->totalReferences() == 0) { - buildRef = true; - break; - } - } - - if (buildRef) { - for (auto v: *_nodes) { - // skipping unmapped nodes - if (_mapped->count(v) == 0) - continue; - - Node* node = _mapped->at(v); - auto inputs = node->input(); - for (auto &t: *inputs) { - if (_mapped->count(t.first) == 0) - continue; - - Node* inode = _mapped->at(t.first); - inode->addReference(node->id()); - } - } - } - - - for (auto v: *_nodes) { - // skipping unmapped nodes - if (_mapped->count(v) == 0) - continue; - - Node* node = _mapped->at(v); - - /** - * Node can be inplace if 2 requirements met: - * 1) current node allows in-place modification - * 2) source node has only 1 output - */ - - // checking for first requirement first - if (node->getCustomOp() != nullptr) - if (node->getCustomOp()->getOpDescriptor()->allowsInplace()){ - bool singleInput = true; - auto inputs = node->input(); - for (auto &t: *inputs) { - if (_mapped->count(t.first) == 0) - continue; - - Node* inode = _mapped->at(t.first); - - int output_size = inode->output()->size(); - - // checking for second requirement: inputNode must not be used as input anywhere - if (inode->totalReferences() > 1) { - singleInput = false; - break; - } - } - - node->markInplace(singleInput); - } - } - } + nd4j_printf("Tensor shape found: %i dims;\n", dims); + variable->getNDArray()->printShapeInfo(); + } + } - void Graph::prepareOutputs() { - // if we're dumping everything out there - we'll add external variables as well - if (_configuration->_outputMode == OutputMode_VARIABLE_SPACE) { - auto ext = _variableSpace->getExternalVariables(); - nd4j_verbose("Number of external variables: %i\n", ext->size()) - for (unsigned int e = 0; e < ext->size(); e++) { - pushToOutputOnce(ext->at(e)->id()); - } - - for (auto v: *_nodes) { - if (_mapped->count(v) == 0) - continue; - - Node* node = _mapped->at(v); - - if (std::find(_output.begin(), _output.end(), node->id()) == _output.end()) - _output.emplace_back(node->id()); - } - - } else if (_configuration->_outputMode == OutputMode_IMPLICIT) { - // we're adding final nodes of the graph. those, not used as input anywhere - nd4j_debug("Paring nodes... \n", ""); - - if (Environment::getInstance().isDebugAndVerbose()) { - // nd4j_printv("current _output", _output); - } - //_output.clear(); - - for (auto v: *_nodes) { - // we should check for scopes, and other possible non-mapped stuff - if (_mapped->count(v) == 0) - continue; - - Node* node = _mapped->at(v); - if (node->name() != nullptr) { - nd4j_debug("Node %i; Name: [%s]\n", v, node->name()->c_str()); - } else { - nd4j_debug("Node %i\n", v); - } - - // updating outputs now - for (int e = 0; e < node->input()->size(); e++) { - auto inP = node->input()->at(e); - - // input can be variable, or node. we only care about nodes - if (_mapped->count(inP.first) > 0) { - _mapped->at(inP.first)->pickOutputOnce(v); - } - } - } - // at this point all nodes have filled inputs/outputs, so we know nodes that do not have any connected outputs - - for (auto v: *_nodes) { - // we should check for scopes, and other possible non-mapped stuff - if (_mapped->count(v) == 0) - continue; - - Node* node = _mapped->at(v); - - if (!node->hasInternalOutputs()) { - if (node->name() != nullptr) { - nd4j_debug("Output node found: [%i:<%s>]\n", v, node->name()->c_str()); - } else { - nd4j_debug("Output node found: [%i]\n", v); - } - - // FIXME: we don't really need search here. - - if (std::find(_output.begin(), _output.end(), node->id()) == _output.end()) - _output.emplace_back(node->id()); - } else if (Environment::getInstance().isDebugAndVerbose()) { - nd4j_debug("Node [%i:<%s>] has %i outputs announced:\n", v, node->name()->c_str(), node->output()->size()); - printf("{"); - for (auto s : *node->output()) { - printf("[%i:%i], ", s.first, s.second); - } - printf("}\n"); - fflush(stdout); - } - } - } - } + // it can be valueOf array + if (attr.tensor().int_val_size() == 1 && __length > 0) { + variable->getNDArray()->assign((T) + attr.tensor().int_val(0)); + } + } + } + } else { + nd4j_verbose("Node id: [%i]; name: [%s]; opName: [%s]\n", n + 1, + node.name().c_str(), node.op().c_str()); - Graph::Graph(const FlatGraph *flatGraph, VariableSpace *variableSpace) { - this->_onion = new MAP_IMPL *>(); - this->_mapped = new MAP_IMPL (); - this->_nodes = new std::vector(); - this->_variableSpace = variableSpace == nullptr ? new VariableSpace() : variableSpace; - bool trusted = flatGraph != nullptr; - - // add 0 layer - this->expandOnion(0); - - // if there was no exec configuration in flatgraph - create default one - if (flatGraph != nullptr && flatGraph->configuration() != nullptr) { - _configuration = new ExecutorConfiguration(flatGraph->configuration()); - } else - _configuration = new ExecutorConfiguration(); - - // if memory reqs were set - initialize workspace - if (_configuration->_footprintForward > 0) { - sd::memory::Workspace *workspace = this->_variableSpace->launchContext()->getWorkspace(); - workspace->expandBy(_configuration->_footprintForward); - } - - // parsing variables here - if (flatGraph != nullptr && flatGraph->variables() != nullptr && flatGraph->variables()->size() > 0) { - for (unsigned int e = 0; e < flatGraph->variables()->size(); e++) { - auto flatVar = flatGraph->variables()->Get(e); - - auto var = new Variable(flatVar); - std::pair pair(flatVar->id()->first(), flatVar->id()->second()); - _variableSpace->putVariable(pair, var); - - // if that's VariableSpace mode - we're pushing it to _output - if (_configuration->_outputMode == OutputMode_VARIABLE_SPACE) - pushToOutputOnce(var->id()); - - } - } - - // at this point we expect all variables are already registered - // we're saving outputs only if explicit mode is set - if (_configuration->_outputMode == OutputMode_EXPLICIT || _configuration->_outputMode == OutputMode_EXPLICIT_AND_IMPLICIT) { - if (flatGraph != nullptr && flatGraph->outputs() != nullptr) { - for (unsigned int e = 0; e < flatGraph->outputs()->size(); e++) { - auto out = flatGraph->outputs()->Get(e); - std::pair vp(out->first(), out->second()); - if (!_variableSpace->hasVariable(vp)) { - nd4j_verbose("Non-existent variable requested: %i\n", out); - throw std::runtime_error("Non-existent variable requested"); - } - - // TODO: fix this .first - pushToOutputOnce(vp.first); - } - } - } - - // rolling through nodes - if (flatGraph != nullptr && flatGraph->nodes() != nullptr && flatGraph->nodes()->size() > 0) { - for (unsigned int e = 0; e < flatGraph->nodes()->size(); e++) { - auto node = flatGraph->nodes()->Get(e); - - if (node->output() == nullptr || node->output()->size() == 0) { - nd4j_verbose("Orphan node detected: %i; AutoOutput to be considered\n", node->id()); - } - - nd4j_debug("Node name: [%s]\n", node->name()->c_str()); - auto nnode = new Node(node); - /* - expandOnion(e); - nnode->setLayer(e); - this->addNode(nnode); - injectNode(nnode); - _unmapped.erase(nnode->id()); - */ - // just filling list of nodes - _unmapped[nnode->id()] = nnode; - } - - - this->toposortNodes(); - - _built = true; - } - - /** - * we allow in-place execution optimizations ONLY if 2 requirements met: - * 1) this is FeedForward pass ONLY - * 2) OPTIMIZED mode is set, so no intermediate results are going to be used - */ - if (_configuration->_direction == Direction_FORWARD_ONLY && _configuration->_outputMode == OutputMode_OPTIMIZED) - this->tagInplaceNodes(); - } + sd::ops::DeclarableOp *op = + sd::ops::OpRegistrator::getInstance().getOperationFloat(node.op().c_str()); + if (op == nullptr) { + nd4j_verbose("Op wasn't found: %s\n", node.op().c_str()); + return nullptr; + } + + auto jNode = new Node(); + jNode->setName(node.name()); + jNode->setId(++nodesCounter); + jNode->setCustomOp(op); + jNode->setBlock(new Block(jNode->id(), variableSpace)); - void Graph::toposortNodes() { - int attempts = 0; - - // in worst possible case number of rolls equals to the number of nodes - int limit = _unmapped.size() + 1; - - std::vector tbd; - while (!_unmapped.empty() && attempts < limit) { - for (auto np:_unmapped) { - auto id = np.first; - tbd.emplace_back(id); - } - - // rolling through unmapped nodes - for (auto np:tbd) { - auto id = np; - auto node = _unmapped[id]; - - // this variables contains the layer of maximal dependency - int maxDependencyLayer = -1; - - // simple flag - bool canMap = true; - - // looping through inputs, to check if they were already mapped - auto inputs = node->input(); - for (auto in:*inputs) { - // only 2 options here, in either refers to the node, or to the variable - // however, node can be already mapped, or not yet. this makes it 3 options :) - - if (hasNode(in.first)) { // node was mapped - auto dependency = nodeById(in.first); - auto layer = dependency->getLayer(); - if (layer > maxDependencyLayer) - maxDependencyLayer = layer; - - } else if (_unmapped.count(in.first) > 0) { // node is unmapped yet - // can't map this node yet, due to non-resolved dependencies - canMap = false; - } else if (_variableSpace->hasVariable(in.first)){ // that's probably variable. if not - we'll throw exception later - // do nothing, maxDepLayer is -1 here, because it's a variable input - } else { - throw graph::unresolved_input_exception::build("Unknown input specified", id, in); - } - } - - if (canMap) { - auto layer = maxDependencyLayer + 1; - this->expandOnion(layer); - node->setLayer(layer); - this->addNode(node); - this->injectNode(node); - _unmapped.erase(id); - } - } - - // if something was successfully mapped - remove it from unmapped entries - if (!tbd.empty()) - tbd.clear(); - - attempts++; - } - - if (!_unmapped.empty()) - throw graph_exception("Graph wasn't toposorted", 0); - - _built = true; - } + std::pair pair(node.name(), jNode->id()); + variablesMap.insert(pair); + + // multi-output nodes require special treatment + for (int e = 0; e < op->getOpDescriptor()->getNumberOfOutputs(); e++) + { std::string deepName(node.name()); deepName += ":" + std::to_string(e); auto + deepVar = new Variable(); deepVar->setName(&deepName); + if (e > 0) + deepVar->setId(--variablesCounter); + else + deepVar->setId(jNode->id()); -/** - * This method returns number of root nodes in this graph - * @return - */ - int Graph::rootNodes() { - return this->_onion->at(0)->size(); - } + std::pair pair(deepName, deepVar->id()); + variablesMap.insert(pair); - /** - * This method returns total number of nodes in this graph - * @return - */ - int Graph::totalNodes() { - if (_built.load() != true) - buildGraph(); + variableSpace->putVariable(deepVar->id(), deepVar); - return _mapped->size(); - } + std::pair nodepair(jNode->id(), e); + variableSpace->putVariable(nodepair, deepVar); + } - Nd4jStatus Graph::validate() { - if (!_built) { - _mutexPreprocessing.lock(); - if (!_built) { - this->buildGraph(); - } - _mutexPreprocessing.unlock(); - } - - if (_built.load() != true) - return ND4J_STATUS_BAD_GRAPH; - - return ND4J_STATUS_OK; - }; - - void Graph::printOutNode(Node* node) { - nd4j_printf("%i. ", node->id()); - switch(node->opType()) { - case OpType_CUSTOM: { - printf("%s; ", node->getCustomOp()->getOpName()->c_str()); - } - break; - case OpType_LOGIC: { - printf("%s; ", EnumUtils::_LogicOpToString(node->opNum())); - } - break; - default: { - printf("%s:{%i}; ", EnumUtils::_OpTypeToString(node->opType()), (int) node->opNum()); - } - } - - nd4j_printf("Inputs: [", ""); - //auto block = node->getBlock(); - for (int e = 0; e < node->input()->size(); e++) { - - auto in = node->input()->at(e); - printf("{%i:%i}", in.first, in.second); - if (e < node->input()->size() - 1) - nd4j_printf(", ", ""); - } - - if (node->opType() == OpType_CUSTOM) { - auto ctx = node->protoContext(); - if (ctx->getIArguments()->size() > 0) { - printf("]; iArgs: ["); - - for (int e = 0; e < ctx->getIArguments()->size(); e++) { - printf("%i", ctx->getIArguments()->at(e)); - if (e < ctx->getIArguments()->size() - 1) - nd4j_printf(", ", ""); - } - } - } - - nd4j_printf("]; \n", ""); - - -// printf("\n"); - fflush(stdout); - } - void Graph::printOut() { - buildGraph(); - - // print variables first - if (_variableSpace->totalEntries() > 0) { - nd4j_printf("\nPrinting out Variables...\n", ""); - auto vars = _variableSpace->getVariables(); - - for (Variable* v: vars) { - if (v->hasNDArray()) { - auto shape = ShapeUtils::shapeAsString(v->getNDArray()); - auto values = v->getNDArray()->asString(16); - auto dtype = DataTypeUtils::asString(v->getNDArray()->dataType()); - - if (v->getName() != nullptr && !v->getName()->empty()) { - nd4j_printf("<%s> <%i:%i> dtype: %s; shape: %s; values: %s;\n", v->getName()->c_str(), v->id(), v->index(), dtype.c_str(), shape.c_str(), values.c_str()); - } else { - nd4j_printf("<%i:%i> dtype: %s; shape: %s; values: %s;\n", v->id(), v->index(), dtype.c_str(), shape.c_str(), values.c_str()); - } - } else if (v->hasNDArrayList()) { - // TODO: add better NDArrayList printout - nd4j_printf("<%i:%i> holds ArrayList", v->id(), v->index()); - } - } - } - - if (_onion->size() > 0) - nd4j_printf("\nPrinting out Graph...\n", ""); - - int opCnt = 0; - for (int l = 0; l < _onion->size(); l++) { - int layerSize = _onion->count(l) == 1 ? _onion->at(l)->size() : 0; - - for (int n = 0; n < layerSize; n++) { - Node* node = _onion->at(l)->at(n); - - // we're skipping Scopes here - if (node->opType() == OpType_LOGIC && node->opNum() == logic::Scope) - continue; - - printOutNode(node); - } - } - - - if (_scopes.size() > 0) - nd4j_printf("\nPrinting out Scopes...\n",""); - - for (int s = 0; s < _scopes.size(); s++) { - Scope* scope = _scopes.at(s); - nd4j_printf("Scope %i:<%s>:\n", scope->id(), scope->name()->c_str()); - - for (int n = 0; n < scope->nodes()->size(); n++) { - Node* node = scope->nodes()->at(n); - printOutNode(node); - } - } - - fflush(stdout); - } + printf(" Inputs: ["); + for (int i = 0; i < node.input_size(); i++) { + nd4j_printf("Trying input: %s\n", node.input(i).c_str()); - Nd4jStatus Graph::validateNode(Node *node) { - // TODO: to be implemented - return ND4J_STATUS_OK; - } + // if this fails - we're probably on partial input :) + if (!variablesMap.count(node.input(i))) + return nullptr; - std::vector Graph::getOperations() { - buildGraph(); - // nd4j_printf("\nRetrieving ops from the Graph and collect them...\n", ""); - std::vector res; - - int opCnt = 0; - for (int l = 0; l < _onion->size(); l++) { - int layerSize = _onion->count(l) == 1 ? _onion->at(l)->size() : 0; - - for (int n = 0; n < layerSize; n++) { - Node* node = _onion->at(l)->at(n); - if (node->name() == nullptr) continue; - sd::ops::OpDescriptor* pOpDescriptor = nullptr; - std::string opNameStr; //node->name(); - int numInputs = 0; - int numOutputs = 0; - - switch(node->opType()) { - case OpType_CUSTOM: { - pOpDescriptor = node->getCustomOp()->getOpDescriptor(); - } - break; - case OpType_LOGIC: { - opNameStr = std::string(EnumUtils::_LogicOpToString(node->opNum())); - } - break; - default: { - opNameStr = std::string(EnumUtils::_OpTypeToString(node->opType()))+"{" + ops::OpRegistrator::getInstance().local_to_string((int) node->opNum()) + "}"; - } - } - - if (node->input()) - numInputs = node->input()->size(); - - if (node->output()) - numOutputs = node->output()->size(); - bool inplace = node->isInplace(); - - //OpDescriptor opDescriptor(numInputs, numOutputs, opNameStr, inplace); - - // we're skipping Scopes here - if (node->opType() == OpType_LOGIC && node->opNum() == logic::Scope) - continue; - if (pOpDescriptor) - res.emplace_back(*pOpDescriptor); - else - res.emplace_back(sd::ops::OpDescriptor(numInputs, numOutputs, opNameStr, inplace)); - } - } - - - // nd4j_printf("\nCollecting out Scopes...\n",""); - for (int s = 0; s < _scopes.size(); s++) { - Scope* scope = _scopes.at(s); - // nd4j_printf("Scope %i:<%s>:\n", scope->id(), scope->name()->c_str()); - - for (int n = 0; n < scope->nodes()->size(); n++) { - Node* node = scope->nodes()->at(n); - //printOutNode(node); - if (node->name() == nullptr) continue; - std::string opNameStr; //node->name(); - sd::ops::OpDescriptor* pOpDescriptor = nullptr; - int numInputs = 0; - int numOutputs = 0; - - switch(node->opType()) { - case OpType_CUSTOM: { - pOpDescriptor = node->getCustomOp()->getOpDescriptor(); - } - break; - case OpType_LOGIC: { - opNameStr = std::string(EnumUtils::_LogicOpToString(node->opNum())); - } - break; - default: { - opNameStr = std::string(EnumUtils::_OpTypeToString(node->opType()))+"{" + ops::OpRegistrator::getInstance().local_to_string((int) node->opNum()) + "}"; - } - } - - if (node->input()) - numInputs = node->input()->size(); - - if (node->output()) - numOutputs = node->output()->size(); - bool inplace = node->isInplace(); - - if (pOpDescriptor != nullptr) - res.emplace_back(*pOpDescriptor); - else - res.emplace_back(sd::ops::OpDescriptor(numInputs, numOutputs, opNameStr, inplace)); - } - } - - return res; - } + printf("%s (%i)", node.input(i).c_str(), + variablesMap.at(node.input(i))); - Scope *Graph::scopeById(int id) { - if (_mappedScopes.count(id) == 0) { - nd4j_printf("Requested Scope [%i] doesn't exist\n", id); - throw std::runtime_error("Non-existent Scope was requested"); - } - return _mappedScopes.at(id); - } + jNode->pickInput(variablesMap.at(node.input(i))); + jNode->getBlock()->pickInput(variablesMap.at(node.input(i))); - void Graph::forgetVariableSpace() { - _variableSpace = nullptr; - } - void Graph::replaceState(VariableSpace *state, ExecutorConfiguration *configuration) { - delete _variableSpace; - delete _configuration; + if (i < node.input_size() + 1) + printf(", "); + } + printf("]\n"); - _variableSpace = state; - _configuration = configuration; - } + graph->addNode(jNode); + } + } - Graph* Graph::cloneWithProxy() { - auto clone = new Graph(); + return graph; + */ +} - clone->replaceState(new VariableProxy(this->_variableSpace), this->_configuration->clone()); +void Graph::addPlaceholder(const std::string &nodeName, DataType dataType, + const std::vector &shape) { + int id = _maxId++; - // transfer nodes - for (int e = 0; e < _nodes->size(); e++) - clone->_nodes->emplace_back(_nodes->at(e)); + _symbolicLookupTable[nodeName] = id; - // transfer outputs - for (auto v: _output) - clone->_output.emplace_back(v); + auto var = std::make_shared(true, dataType, shape); + var->setName(nodeName); + _variableSpace.putVariable(id, var); - // transfer autos - for (auto v: _autos) - clone->_autos.emplace_back(v); + _placeholders.emplace_back(nodeName); +} - // transfer scopes - for (auto &v: _mappedScopes) { - auto scp = v.second->clone(); - clone->_mappedScopes[v.first] = scp; - clone->_scopes.emplace_back(scp); - } +std::map Graph::execute( + const std::map &dictionary, + const std::vector &outputs, + const GraphExecutor &executor) const { + // creating our proxy, we'll use it for actual execution + VariableProxy proxy(&_variableSpace); + + // first of all we check existence of placeholders in dictionary + int placeholdersCount = 0; + for (const auto &v : dictionary) { + if (_symbolicLookupTable.count(v.first) == 0) + throw unresolved_input_exception::build("Dictionary entry doesn't exist", + v.first); + + // we also check if arrays provided here do match placeholder restrictions + // of shape and dtype + auto var = _variableSpace.getVariable(v.first); + if (var->dataType() != DataType::ANY && + var->dataType() != v.second.dataType()) + throw datatype_exception::build("Placeholder requires another data type", + var->dataType(), v.second.dataType()); + + auto shape = v.second.getShapeAsVector(); + if (shape != var->shape()) + throw shape_mismatch_exception::build( + "Placeholder requires specific shape", var->shape(), shape); + + // update the placeholder + proxy.putVariable(v.first, var->id(), var->index(), v.second); + + // we must also check if all placeholders were resolved + placeholdersCount++; + } + + // TODO: it would be nice if we'll print out unresolved placeholders + if (placeholdersCount != _placeholders.size()) { + std::string missing; + for (const auto &v:_placeholders) { + if (dictionary.count(v) == 0) + missing += "<" + v + ">, "; + } + throw std::runtime_error("Placeholders were not resolved: [" + missing + "]"); + } - // transfer mapped nodes - for (auto &v: *_onion) { - auto vec = clone->_onion->count(v.first) > 0 ? clone->_onion->at(v.first) : new std::vector(); + // we also must check existence of requested outputs + for (const auto &v : outputs) { + if (_symbolicLookupTable.count(v) == 0) + throw unresolved_output_exception::build("Requested output doesn't exist", + v); + } + // execute optimized version of this graph + auto status = executor.execute(optimizedGraph(), proxy); + if (status != Status::OK()) + throw graph_execution_exception("Graph execution failed, error code: ", status); - // cloning actual nodes - auto ovec = (*_onion)[v.first]; - for (auto x: *(ovec)) { - auto n = x->clone(); - vec->emplace_back(n); - _handles.emplace_back(n); - (*clone->_mapped)[n->id()] = n; - } + // fetch outputs from our VariableProxy + std::map result; + for (const auto &v : outputs) { + // resolve string -> int dep + int id = -119; + if (_symbolicLookupTable.count(v) > 0) + id = _symbolicLookupTable.at(v); - if (clone->_onion->count(v.first) < 1) - (*clone->_onion)[v.first] = vec; - } - // transfer mapped nodes - for (auto &v: _unmapped) - clone->_unmapped[v.first] = v.second->clone(); + if (!proxy.hasVariable(id)) + throw unresolved_output_exception::build( + "Requested output doesn't exist after execution", v); - clone->_built.store(_built.load()); + auto var = proxy.getVariable(id); - return clone; - } + // TODO: we want to make sure ManagedDataBuffer doesn't leak here + result[v] = *var->getNDArray(); + } - Graph* Graph::clone() { - auto clone = new Graph(); + return result; +} - clone->replaceState(this->_variableSpace->clone(), this->_configuration->clone()); +Graph::Graph(const Graph &other) : _memoryManager(other._memoryManager) { + _configuration = other._configuration; + _variableSpace = other._variableSpace; + _stash = other._stash; + _unmapped = other._unmapped; + _symbolicLookupTable = other._symbolicLookupTable; + _built = false; + _maxId = other._maxId; +} - // transfer nodes - for (int e = 0; e < _nodes->size(); e++) - clone->_nodes->emplace_back(_nodes->at(e)); +Graph &Graph::operator=(const Graph &other) noexcept { + if (this == &other) return *this; - // transfer outputs - for (auto v: _output) - clone->_output.emplace_back(v); + _configuration = other._configuration; + _variableSpace = other._variableSpace; + _stash = other._stash; + _unmapped = other._unmapped; + _symbolicLookupTable = other._symbolicLookupTable; + _built = false; + _maxId = other._maxId; - // transfer autos - for (auto v: _autos) - clone->_autos.emplace_back(v); + return *this; +} - // transfer scopes - for (auto &v: _mappedScopes) { - auto scp = v.second->clone(); - clone->_mappedScopes[v.first] = scp; - clone->_scopes.emplace_back(scp); - } +Graph::Graph(Graph &&other) : _memoryManager(other._memoryManager) { + _configuration = other._configuration; + _variableSpace = other._variableSpace; + _stash = other._stash; - // transfer mapped nodes - for (auto &v: *_onion) { - auto vec = clone->_onion->count(v.first) > 0 ? clone->_onion->at(v.first) : new std::vector(); + _unmapped = std::move(other._unmapped); + _symbolicLookupTable = std::move(other._symbolicLookupTable); + _built = false; + _maxId = other._maxId; +} - // cloning actual nodes - auto ovec = (*_onion)[v.first]; - for (auto x: *(ovec)) { - auto n = x->clone(); - vec->emplace_back(n); - _handles.emplace_back(n); - (*clone->_mapped)[n->id()] = n; - } +Graph &Graph::operator=(Graph &&other) noexcept { + if (this == &other) return *this; - if (clone->_onion->count(v.first) < 1) - (*clone->_onion)[v.first] = vec; - } + _configuration = other._configuration; + _variableSpace = other._variableSpace; + _stash = other._stash; - // transfer mapped nodes - for (auto &v: _unmapped) - clone->_unmapped[v.first] = v.second->clone(); + _unmapped = std::move(other._unmapped); + _symbolicLookupTable = std::move(other._symbolicLookupTable); - clone->_built.store(_built.load()); + _built = false; + _maxId = other._maxId; - return clone; - } + return *this; +} - bool Graph::hasNode(int id) { - return _mapped->count(id) > 0; - } +const GraphMemoryManager &Graph::memoryManager() const { return _memoryManager; } - Node* Graph::nodeById(int id) { - return _mapped->at(id); - } +const OptimizedGraph &Graph::optimizedGraph() const { + std::lock_guard lock(_optimizedLock); - bool Graph::hasScope(int id) { - return _mappedScopes.count(id) > 0; - } + // optionally rebuild optimized graph, if it's out of date + if (_optimized.size() != size()) + _optimized = OptimizedGraph(unmappedNodes(), variableSpace()); - Nd4jLong Graph::hashCode() { - if (!_built.load()) - this->buildGraph(); - - Nd4jLong hash = 0L; - std::string localStamp; - /** - * Plan is: - * 1) get shapes of existing variables - * 2) get hash codes of individual ops - * 3) optionally: get node names, if they are defined - * 4) use long hash on that - */ - int cnt = 0; - /* - if (_variableSpace != nullptr) { - // loop over existing variables - for (auto v: *(_variableSpace->handles())) { - if (v->hasNDArray()) { - NDArray *arr = v->getNDArray(); - if (arr != nullptr && arr->nonNull()) { - auto shape = arr->getShapeAsVector(); - auto string = ShapeUtils::shapeAsString(shape); - localStamp += string; - } - } - } - } - */ - - // loop over nodes in graph - for (auto &v: *_mapped) { - Node *node = v.second; - - // optional part: node names - if (!node->name()->empty()) { - localStamp += *(node->name()); - } - } - - - hash = ops::HashHelper::getInstance().getLongHash(localStamp); - - nd4j_debug("Graph hash: %lld\n", hash); - - return hash; - } - } + return _optimized; } - +} // namespace graph +} // namespace sd diff --git a/libnd4j/include/graph/impl/GraphExecutioner.cpp b/libnd4j/include/graph/impl/GraphExecutioner.cpp deleted file mode 100644 index abc3b2e0c2c9..000000000000 --- a/libnd4j/include/graph/impl/GraphExecutioner.cpp +++ /dev/null @@ -1,906 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -// -// @author raver119@gmail.com -// - -#include -#include -#include - -//#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -//#include -//#include - -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace sd{ -namespace graph { - -/** - * This method executes given Node (as in Op within Node) - * - * Basically it just does DeclarableOp::execute(Block), and ops to their job. However, there are some additional functionality. - * - * @param graph - Graph instance pointer - * @param node - Node instance pointer, which will be executed - * @param variableSpace - VariableSpace instance pointer - varspace specific to current Thread/Session - * @return - */ - Nd4jStatus GraphExecutioner::executeFlatNode(Graph *graph, Node *node, VariableSpace *variableSpace) { - OpType opType = node->opType(); - int opNum = node->opNum(); -// std::string opName = *(node->getCustomOp()->getOpName()); - - if (opType == OpType_BOOLEAN) { - nd4j_debug("Executing boolean graph node_%i", node->id()); - } else if (opType == OpType_LOGIC) { - nd4j_debug("Executing logic graph node_%i", node->id()); - } else if (opType == OpType_GRAPH) { - nd4j_debug("Executing embedded graph node_%i", node->id()); - } else if (opType != OpType_CUSTOM) { - nd4j_debug("Executing node_%i{%i}\n", node->id(), opNum); - } else { - nd4j_debug("Executing node_%i{%s}\n", node->id(), node->getCustomOp()->getOpName()->c_str()); - } - - Context context(node->getContextPrototype(), variableSpace); - - if (sd::Environment::getInstance().isDebugAndVerbose()) { - //nd4j_debug("Input variables: %i\n", node->input()->size()); - printf(" Inputs: {"); - for (int e = 0; e < node->input()->size(); e++) { - printf("[%i:%i]", node->input()->at(e).first, node->input()->at(e).second); - - if (e < node->input()->size() - 1) - printf(", "); - } - printf("}\n"); - fflush(stdout); - } - - if (node->id() == 13) - nd4j_debug("",""); - - // if true - this is special case: Graph-in-Graph. - if (node->hasGraphEmbedded()) { - auto embedded = node->getGraph(); - - /** - * basically, we should do following things here: - * 1) fill embedded graph with input variables from this graph, if anything should be filled in - * 2) invoke embedded graph - * 3) announce its results as corresponding output variables in current VariableSpace - */ - - // enforcing IMPLICIT mode. or not... should we try to be smarter then user? - //embedded->getExecutorConfiguration()->_outputMode = OutputMode_IMPLICIT; - - if (node->input()->size() != embedded->numberOfPlaceholders()) { - nd4j_debug("Placeholders amount mismatch: %i expected, and %i available\n",node->input()->size(), embedded->numberOfPlaceholders()); - return ND4J_STATUS_BAD_INPUT; - } - - // we need to propagate required variables to the embedded graph - ResultSet deletables; - int cnt = 0; - for (Variable* v: *embedded->getPlaceholders()) { - if (v->getName() != nullptr && v->getName()->size() > 0) { - - // trying symbolic lookup first - if (variableSpace->hasVariable(v->getName())) { - // symbolic feeder - auto array = variableSpace->getVariable(v->getName())->getNDArray(); - auto vr = new NDArray(array->dup()); -// deletables.push_back(vr); - v->setNDArray(vr); - } else { - nd4j_debug("Can't find variable [%s] in parent graph...", v->getName()->c_str()); - return ND4J_STATUS_BAD_INPUT; - //throw "Can't find desired variable"; - } - } else { - // if we're not using symbolic lookup - we'll use sequential approach then - auto p = node->input()->at(cnt); - auto array = variableSpace->getVariable(p)->getNDArray(); - auto vr = new NDArray(array->dup()); - //deletables.push_back(vr); - v->setNDArray(vr); - } - - cnt++; - } - - // executing embedded graph as independent one - Nd4jStatus status = GraphExecutioner::execute(embedded); - if (status != ND4J_STATUS_OK) - return status; - - // now we should migrate its results to this node, as its own outputs - cnt = 0; - auto outputs = embedded->fetchOutputs(); - - for (auto v: *outputs){ - NDArray *array = v->getNDArray(); - v->setNDArray(nullptr); - - std::pair pair(node->id(), cnt++); - - auto var = variableSpace->getVariable(pair); - - //nd4j_printf("HasArray: [%i]; Removable: [%i]\n", var->hasNDArray(), var->isRemovable()); - var->setNDArray(array); - var->markRemovable(true); - } - deletables.size(); - delete outputs; - nd4j_debug("Embedded graph execution finished. %i variable(s) migrated\n", cnt); - - } else if (node->hasCustomOp()) { - // now, if we have something to execute - lets just execute it. - auto status = node->getCustomOp()->execute(&context); - if (status != ND4J_STATUS_OK) - return status; - - // propagate variables - if (node->hasExternalOutputs()) { - for (auto v: *node->output()) { - if (variableSpace->hasExternalVariable(v.first)) { - variableSpace->getVariable(v.first)->getNDArray()->assign(variableSpace->getVariable(node->id())->getNDArray()); - } - } - } - - return status; - } - return ND4J_STATUS_OK; -} - - -/** - * This method executes given Graph instance, and returns error code. - * - * @param graph - * @return one of error codes defined in pointercast.h - */ -Nd4jStatus GraphExecutioner::execute(Graph *graph, VariableSpace* variableSpace) { - auto __variableSpace = variableSpace == nullptr ? graph->getVariableSpace() : variableSpace; - - bool tempFlow = false; - if (__variableSpace->flowPath() == nullptr) { - tempFlow = true; - __variableSpace->setFlowPath(new FlowPath()); - } - auto flowPath = __variableSpace->flowPath(); - - Nd4jLong tb0 = Environment::getInstance().isProfiling() ? GraphProfile::currentTime() : 0L; - graph->buildGraph(); - - auto footprintForward = sd::memory::MemoryRegistrator::getInstance().getGraphMemoryFootprint(graph->hashCode()); - if (footprintForward > 0) { - if (__variableSpace->launchContext()->getWorkspace() != nullptr) { - // this method will work only if current workspace size is smaller then proposed value - nd4j_debug("Setting workspace to %lld bytes\n", footprintForward); - __variableSpace->launchContext()->getWorkspace()->expandTo(footprintForward); - } - } - - // optionally saving graph build time - if (Environment::getInstance().isProfiling()) - flowPath->profile()->setBuildTime(GraphProfile::relativeTime(tb0)); - - Nd4jLong timeStart = Environment::getInstance().isProfiling() ? GraphProfile::currentTime() : 0L; - - bool pe = graph->getExecutorConfiguration()->_executionMode == ExecutionMode_AUTO; - - - // basically if at some point code diverges, code branch might be _DISABLED_, and all nodes within that branch will be disabled as well - - std::deque frames; - bool inFrame = false; - bool leftFrame = false; - - auto nodeTime = GraphProfile::currentTime(); - int lastId = -10000000; - Nd4jLong exec_counter = 0; - // we loop through op layers here - for (int l = 0; l < (int) graph->getOnion()->size(); l++) { - int layerSize = graph->getOnion()->count(l) == 1 ? graph->getOnion()->at(l)->size() : 0; - - int n = 0; -// this omp block will probably never be the case - for (; n < layerSize; n++) { - if (++exec_counter > 10000) { - l = graph->getOnion()->size(); - return Status::THROW("Early termination hit"); - } - - Node* node = graph->getOnion()->at(l)->at(n); - - if (Environment::getInstance().isProfiling()) - flowPath->profile()->nodeById(node->id(), node->name()->c_str()); - - if (lastId != node->id() && Environment::getInstance().isProfiling()) { - if (lastId != -10000000) - flowPath->profile()->nodeById(lastId)->setTotalTime(GraphProfile::relativeTime(nodeTime)); - - lastId = node->id(); - nodeTime = GraphProfile::currentTime(); - } - - nd4j_debug("Step: %lld; Node: %i <%s>\n", exec_counter, node->id(), node->name()->c_str()); - - // on first non-Exit node after loop we can rewind (if planned) - if (!(node->opType() == OpType_LOGIC && node->opNum() == sd::logic::Exit)) { - // VALIDATED - - // if we're out of frame - let's remove it from queue - if (leftFrame) { - auto frame_id = frames.back(); - frames.pop_back(); - flowPath->markFrameActive(frame_id, false); - flowPath->forgetFrame(frame_id); - - leftFrame = false; - } - - - // TODO: move inactivity check right here - bool shouldSkip = false; - if (node->opType() == OpType_LOGIC && node->opNum() == sd::logic::Merge) { - // Merge node has own checkout logic - - auto inputId0 = node->input()->at(0); - auto inputId1 = node->input()->at(1); - - // Merge node can be skipped only both inputs are inactive - if (!flowPath->isNodeActive(inputId0.first) && !flowPath->isNodeActive(inputId1.first)) - shouldSkip = true; - - } else { - // let's check for input nodes, if they are disabled or contain divergents - for (int e = 0; e < node->input()->size(); e++) { - auto inputId = node->input()->at(e); - - // not a node. skipping checks - if (graph->getMapped()->count(inputId.first) == 0) - continue; - - /** - * We can skip current node, in two cases: - * 1) If previous node was disabled - * 2) If previous node was divergent node (i.e. IF op) and code went other way - */ - Node *prevNode = graph->getMapped()->at(inputId.first); - if (!flowPath->isNodeActive(inputId.first)) { - shouldSkip = true; - flowPath->markNodeActive(node->id(), false); - - nd4j_debug("Skipping Node_%i due to inactive input [%i]\n", node->id(), inputId.first); - break; - - } else if (prevNode->isDivergencePoint()) { // literally checking for switch here - if (flowPath->branch(inputId.first) != inputId.second) { - shouldSkip = true; - flowPath->markNodeActive(node->id(), false); - nd4j_debug("Skipping Node_%i due to divergent branch [%i]\n", node->id(), - inputId.first); - break; - } - } - } - } - - if (shouldSkip) - continue; - } - - // we're propagating frameId here (but only if wasn't set earlier) - if (frames.size() > 0 && node->getFrameId() < 0) - node->setFrameId(frames.back()); - - - flowPath->markNodeActive(node->id(), true); - - if (node->opType() == OpType_LOGIC && node->opNum() == sd::logic::Enter) { - // Enter operation - // VALIDATED - - // we expect this node to have frameId set - auto frame_id = node->getFrameId(); - - // new frame starts here - if (frames.size() == 0 || (frames.size() > 0 && frames.back() != frame_id)) { - flowPath->registerFrame(frame_id); - frames.emplace_back(frame_id); - inFrame = true; - } - - - auto status = LogicExecutor::processNode(graph, node); - if (status != Status::OK()) - return status; - - } else if (node->opType() == OpType_LOGIC && node->opNum() == sd::logic::NextIteration) { - /** - * NextIteration is special case: after successful execution of this op - we're changing execution position - */ - // VALIDATED - auto inputId = node->input()->at(0); - - auto status = LogicExecutor::processNode(graph, node); - if (status != Status::OK()) - return status; - - auto frame_id = frames.back(); - - flowPath->markNodeActive(node->id(), true); - flowPath->markExecuted(node->id(), true); - - if (!flowPath->isRewindPlanned(frame_id)) { - auto nextLayer = node->getRewindLayer(); - - nd4j_debug("Node_%i planned rewind to Node_%i at [%i:%i]\n", node->id(), node->getRewindNode(), nextLayer.first, nextLayer.second); - - flowPath->planRewind(frame_id, true); - flowPath->setRewindPositionOnce(frame_id, nextLayer.first - 1); - - continue; - } - - - } else if (node->opType() == OpType_LOGIC && node->opNum() == sd::logic::Exit) { - // Exit node is another special case: it can rewind executioner to specific point in graph - // VALIDATED - - auto frame_id = frames.back(); - - // if this loop frame wasn't activated - just skip it - if (!flowPath->isFrameActive(frame_id)) { - flowPath->markNodeActive(node->id(), false); - - leftFrame = true; - continue; - } - - if (flowPath->isRewindPlanned(frame_id)) { - // just break loop here - l = flowPath->getRewindPosition(frame_id); - flowPath->setRewindPosition(frame_id, -1); - flowPath->planRewind(frame_id, false); - - break; - } else { - // execute Exit node otherwise - - auto status = LogicExecutor::processNode(graph, node); - if (status != Status::OK()) - return status; - - leftFrame = true; - } - - - } else if (node->opType() == OpType_LOGIC) { - /** - * If this LOGIC op, we'll use another execution model here - */ - auto status = LogicExecutor::processNode(graph, node); - - if (status != Status::OK()) - return status; - } else { - - - auto timeStart = std::chrono::system_clock::now(); - - // actual node execution happens right here - Nd4jStatus status = executeFlatNode(graph, node, __variableSpace); - - auto timeEnd = std::chrono::system_clock::now(); - - auto outerTime = std::chrono::duration_cast(timeEnd - timeStart).count(); - - - flowPath->setOuterTime(node->id(), outerTime); - - if (status != ND4J_STATUS_OK) - return status; - - - // here we should handle divergent ops, and disable nodes accordingly - if (node->isDivergencePoint()) { - auto activeBranch = flowPath->branch(node->id()); - nd4j_debug("Active branch at node [%i]: %i\n", node->id(), activeBranch); - - // now we skip all branches except of this active one - } - - if (sd::Environment::getInstance().isDebugAndVerbose()) { - - if (__variableSpace->getVariable(node->id())->hasNDArray()) { - auto array = __variableSpace->getVariable(node->id())->getNDArray(); - auto shape = ShapeUtils::shapeAsString(array); - auto values = array->asIndexedString(16); - auto type = DataTypeUtils::asString(array->dataType()); - nd4j_debug("node_%i finished. result shape: %s; data type: %s; first values: %s\n", node->id(), shape.c_str(), type.c_str(), values.c_str()); - } else if (__variableSpace->getVariable(node->id())->hasNDArrayList()) { - auto list = __variableSpace->getVariable(node->id())->hasNDArrayList() ? __variableSpace->getVariable(node->id())->getNDArrayList() : nullptr; - nd4j_debug("node_% is ListOp, skipping evaluation", node->id()); - } else { - nd4j_debug("node_% is Unknown: has no NDArray or NDArrayList", node->id()); - } - } - } - - // if node was executed - tag it as active - flowPath->markExecuted(node->id(), true); - } - } - - // optionally saving execution time - if (Environment::getInstance().isProfiling()) { - flowPath->profile()->nodeById(lastId)->setTotalTime(GraphProfile::relativeTime(nodeTime)); - flowPath->profile()->setExecutionTime(GraphProfile::relativeTime(timeStart)); - //flowPath->profile().printOut(); - } - - // saving memory footprint for current run - if (__variableSpace->launchContext()->getWorkspace() != nullptr) { - auto m = __variableSpace->launchContext()->getWorkspace()->getAllocatedSize(); - auto h = graph->hashCode(); - sd::memory::MemoryRegistrator::getInstance().setGraphMemoryFootprintIfGreater(h, m); - } - - if (tempFlow) { - delete flowPath; - __variableSpace->setFlowPath(nullptr); - } - - return Status::OK(); -} - -/** - * This method is provided for IPC: - * 1) it accepts pointer to FlatBuffers buffer - * 2) restores Graph from it - * 3) Executes this Graph - * 4) Packs execution results into FlatBuffers (FlatResults instance) - * 5) Returns pointer to FlatBuffer results buffer - * - */ - sd::graph::ResultWrapper* GraphExecutioner::executeFlatBuffer(Nd4jPointer pointer) { - uint8_t *buffer = reinterpret_cast(pointer); - - // nd4j_debug("Trying to restore graph\n", 0); - - auto restoredGraph = GetFlatGraph(buffer); - - // nd4j_debug("Graph restored\n", 0); - - // converting FlatGraph to internal representation - auto nativeGraph = new Graph(restoredGraph); - - if (Environment::getInstance().isDebugAndVerbose()) { - nativeGraph->printOut(); - } - - FlowPath flowPath; - nativeGraph->getVariableSpace()->setFlowPath(&flowPath); - - - // nd4j_debug("Going to execute graph\n", 0); - - // executing internal representation - auto status = GraphExecutioner::execute(nativeGraph); - if (status != ND4J_STATUS_OK) { - nd4j_printf("Graph execution failed with status: [%i]\n", status) - return nullptr; - } - - // nd4j_debug("Building output...\n", 0); - - flatbuffers::FlatBufferBuilder builder(1024); - - // fetching time reports - std::vector> timings_vector; - for (int e = 0; e < (int) nativeGraph->getAllNodes()->size(); e++) { - Node *node = nativeGraph->getAllNodes()->at(e); - - if (node->getContextPrototype() == nullptr) - continue; - - auto pair = CreateLongPair(builder, flowPath.outerTime(node->id()), flowPath.innerTime(node->id())); - if (node->getName() != nullptr) { - auto name = builder.CreateString(node->getName()->c_str()); - auto fr = CreateFlatTiming(builder, node->id(), name, pair); - timings_vector.push_back(fr); - } else { - auto fr = CreateFlatTiming(builder, node->id(), 0, pair); - timings_vector.push_back(fr); - } - } - - - // now, we'll prepare output, depending on given outputmode - auto outputs = nativeGraph->fetchOutputs(); - auto size = static_cast(outputs->size()); - int arrays = 0; - std::vector> variables_vector; - for (int e = 0; e < size; e++) { - auto var = outputs->at(e); - - // FIXME: we want to export multi-output nodes as well - // FIXME: we want to export NDArrayList and skip nodes without outputs - if (!var->hasNDArray()) - continue; - - - auto array = var->getNDArray(); - - auto fArray = FlatUtils::toFlatArray(builder, *array); - - auto fName = builder.CreateString(*(var->getName())); - auto id = CreateIntPair(builder, var->id(), var->index()); - - auto fv = CreateFlatVariable(builder, id, fName, static_cast(array->dataType()), 0, fArray); - - variables_vector.push_back(fv); - arrays++; - } - - nd4j_debug("Returning %i variables back\n", arrays); - - auto varTimings = builder.CreateVector(timings_vector); - auto varVectors = builder.CreateVector(variables_vector); - auto result = CreateFlatResult(builder, restoredGraph->id(), varVectors, varTimings); - builder.Finish(result); - - // we might want to keep this graph for future - delete outputs; - delete nativeGraph; - - char* res = new char[builder.GetSize()]; - memcpy(res, builder.GetBufferPointer(), builder.GetSize()); - - nd4j_debug("Buffer size: %lld\n", static_cast(builder.GetSize())); - - return new ResultWrapper(builder.GetSize(), reinterpret_cast(res)); -} - -Graph* GraphExecutioner::importFromTensorFlow(const char *fileName) { - /* - if (fileName == nullptr) - return nullptr; - - int fd = open(fileName, O_RDONLY); - - if (fd < 0) { - nd4j_printf("File not found: [%s]\n", fileName); - return nullptr; - } - - nd4j_verbose("Trying to load TF GraphDef from file [%s]\n", fileName); - - tensorflow::GraphDef graphDef; - bool res = graphDef.ParseFromFileDescriptor(fd); - - // trying to read graph as text - if(!res) { - close(fd); - fd = open(fileName, O_RDONLY); - - google::protobuf::io::FileInputStream fileInput(fd); - fileInput.SetCloseOnDelete(true); - - if (!google::protobuf::TextFormat::Parse(&fileInput, &graphDef)) { - nd4j_printf("Failed to read file\n",""); - } else { - res = true; - } - } - - close(fd); - - if (!res) - return nullptr; - - auto graph = new Graph(); - auto variableSpace = graph->getVariableSpace(); - - std::map variablesMap; - - int variablesCounter = 0; - int nodesCounter = 0; - nd4j_verbose("Number of nodes in graphDef: %i\n", graphDef.node_size()); - for (int n = 0; n < graphDef.node_size(); n++) { - auto node = graphDef.node(n); - - // if that's external variable - we put it to variable space - if (strcmp(TF_VAR, node.op().c_str()) == 0 || strcmp(TF_CONST, node.op().c_str()) == 0 || strcmp(TF_INPUT, node.op().c_str()) == 0) { - nd4j_printf("Variable found: %s\n", node.name().c_str()); - auto variable = new Variable(); - variable->setName(new std::string(node.name().c_str())); - variable->setId(--variablesCounter); - variableSpace->putVariable(variable->id(), variable); - - std::pair pair(node.name(), variable->id()); - variablesMap.insert(pair); - - // TODO: we might want to have something like that. - // it basically just gives input validation option, since settles expectations for input - if (strcmp(TF_INPUT, node.op().c_str()) == 0) - continue; - - // checking shape, not applicable to input, since it can vary - if (node.attr().count("shape")) { - auto attr = node.attr().at("shape"); - int dims = attr.shape().dim_size(); - - if (dims > 0) { - std::vector __shape; - - // we don't have rank1 arrays. vector is 2d. - if (dims == 1) - __shape.push_back(1); - - // roll through dimensions - for (auto s: attr.shape().dim()) { - __shape.push_back((int) s.size()) ; - } - - variable->setNDArray(new NDArray('c', __shape)); - - nd4j_printf("Shape found: %i dims;\n", dims); - variable->getNDArray()->printShapeInfo(); - } - } - - // checking tensor attached - if (node.attr().count("value")) { - auto attr = node.attr().at("value"); - - // int - if (attr.tensor().dtype() == ::tensorflow::DataType::DT_INT32) { - nd4j_verbose("Int size: %i\n", attr.tensor().int_val_size()); - - Nd4jLong __length = 0; - - nd4j_verbose("Tensor has shape: %i\n", attr.tensor().has_tensor_shape()); - if (attr.tensor().has_tensor_shape()) { - auto shape = attr.tensor().tensor_shape(); - int dims = shape.dim_size(); - - if (dims > 0) { - std::vector __shape; - // we don't have rank1 arrays. vector is 2d. - if (dims == 1) - __shape.push_back(1); - - // roll through dimensions - for (auto s: shape.dim()) { - __shape.push_back((int) s.size()); - } - - variable->setNDArray(new NDArray('c', __shape)); - __length = variable->getNDArray()->lengthOf(); - - nd4j_printf("Tensor shape found: %i dims;\n", dims); - variable->getNDArray()->printShapeInfo(); - } - } - - // it can be valueOf array - if (attr.tensor().int_val_size() == 1 && __length > 0) { - variable->getNDArray()->assign((T) attr.tensor().int_val(0)); - } - } - } - } else { - nd4j_verbose("Node id: [%i]; name: [%s]; opName: [%s]\n", n + 1, node.name().c_str(), - node.op().c_str()); - - sd::ops::DeclarableOp *op = sd::ops::OpRegistrator::getInstance().getOperationFloat(node.op().c_str()); - - if (op == nullptr) { - nd4j_verbose("Op wasn't found: %s\n", node.op().c_str()); - return nullptr; - } - - auto jNode = new Node(); - jNode->setName(node.name()); - jNode->setId(++nodesCounter); - jNode->setCustomOp(op); - jNode->setBlock(new Block(jNode->id(), variableSpace)); - - std::pair pair(node.name(), jNode->id()); - variablesMap.insert(pair); - - // multi-output nodes require special treatment - for (int e = 0; e < op->getOpDescriptor()->getNumberOfOutputs(); e++) { - std::string deepName(node.name()); - deepName += ":" + std::to_string(e); - auto deepVar = new Variable(); - deepVar->setName(&deepName); - - if (e > 0) - deepVar->setId(--variablesCounter); - else - deepVar->setId(jNode->id()); - - std::pair pair(deepName, deepVar->id()); - variablesMap.insert(pair); - - variableSpace->putVariable(deepVar->id(), deepVar); - - std::pair nodepair(jNode->id(), e); - variableSpace->putVariable(nodepair, deepVar); - } - - - printf(" Inputs: ["); - for (int i = 0; i < node.input_size(); i++) { - nd4j_printf("Trying input: %s\n", node.input(i).c_str()); - - // if this fails - we're probably on partial input :) - if (!variablesMap.count(node.input(i))) - return nullptr; - - printf("%s (%i)", node.input(i).c_str(), variablesMap.at(node.input(i))); - - - jNode->pickInput(variablesMap.at(node.input(i))); - jNode->getBlock()->pickInput(variablesMap.at(node.input(i))); - - - if (i < node.input_size() + 1) - printf(", "); - } - printf("]\n"); - - graph->addNode(jNode); - } - } - - return graph; - */ - return nullptr; -} - -/** -* This function returns file size for the given file name, or -1 if something went wrong -*/ -long getFileSize(const char * filename) { - struct stat stat_buf; - int rc = stat(filename, &stat_buf); - return rc == 0 ? stat_buf.st_size : -1; -} - -/** -* Helper function, that loads given filename into uint8_t array -* -*/ -uint8_t* readFlatBuffers(const char * filename) { - long fileLen = getFileSize(filename); - if (fileLen < 0) { - nd4j_printf("File [%s] wasn't found. Please check path and permissions\n", filename); - throw std::runtime_error("File not found"); - } - - nd4j_debug("File length: %i\n", fileLen); - - uint8_t * data = new uint8_t[fileLen]; - - FILE *in = fopen(filename, "rb"); - int cnt = 0; - int b = 0; - while (cnt < fileLen) { - b = fread(data + cnt, 1, 1, in); - - cnt += b; - } - fclose(in); - - return data; -} - -flatbuffers::Offset GraphExecutioner::execute(Graph *graph, flatbuffers::FlatBufferBuilder &builder, const FlatInferenceRequest* request) { - ExecutionResult result; - auto varSpace = graph->getVariableSpace(); - - if (request != nullptr && request->variables() != nullptr) { - auto vars = request->variables(); - for (int e = 0; e < vars->size(); e++) { - auto fv = vars->Get(e); - auto v = new Variable(fv); - varSpace->replaceVariable(v); - } - } - - if (Environment::getInstance().isDebugAndVerbose()) - graph->printOut(); - - auto status = GraphExecutioner::execute(graph); - if (status != sd::Status::OK()) - throw graph_execution_exception(request->id()); - - auto outputs = graph->fetchOutputs(); - - if (outputs->size() == 0) - throw no_results_exception(request->id()); - - - for (auto v: *outputs) { - result.emplace_back(v); - } - - auto t = result.asFlatResult(builder); - - delete outputs; - - return t; -} - - - /** - * This method reads given FlatBuffers file, and returns Graph instance - * - * PLEASE NOTE: This method is mostly suited for tests and debugging/profiling - */ - Graph* GraphExecutioner::importFromFlatBuffers(const char *filename) { - auto data = readFlatBuffers(filename); - auto restoredGraph = importFromFlatPointer(reinterpret_cast(data)); - delete[] data; - return restoredGraph; - } - - Graph *GraphExecutioner::importFromFlatPointer(Nd4jPointer ptr) { - auto fg = GetFlatGraph(reinterpret_cast(ptr)); - auto restoredGraph = new Graph(fg); - - return restoredGraph; - } - } -} \ No newline at end of file diff --git a/libnd4j/include/graph/impl/GraphHolder.cpp b/libnd4j/include/graph/impl/GraphHolder.cpp index 13c4e38965ac..642df4d4e464 100644 --- a/libnd4j/include/graph/impl/GraphHolder.cpp +++ b/libnd4j/include/graph/impl/GraphHolder.cpp @@ -18,111 +18,78 @@ // @author raver119@gmail.com // -#include -#include -#include #include +#include +#include namespace sd { - namespace graph { - GraphHolder& GraphHolder::getInstance() { - static GraphHolder instance; - return instance; - }; - - void GraphHolder::registerGraph(Nd4jLong graphId, Graph* graph) { - if (hasGraphAny(graphId)) - throw graph_exists_exception(graphId); - - _graphF[graphId] = graph; - - sd::SimpleReadWriteLock lock; - _locks[graphId] = lock; - } - - Graph* GraphHolder::cloneGraph(Nd4jLong graphId) { - if (!this->hasGraph(graphId)) { - nd4j_printf("GraphHolder doesn't have graph stored for [%lld]\n", graphId); - throw std::runtime_error("Bad argument"); - } - - auto graph = _graphF[graphId]->cloneWithProxy(); - - return graph; - } - - Graph* GraphHolder::pullGraph(Nd4jLong graphId) { - if (!this->hasGraph(graphId)) { - nd4j_printf("GraphHolder doesn't have graph stored for [%lld]\n", graphId); - throw std::runtime_error("Bad argument"); - } - - auto graph = _graphF[graphId]; +namespace graph { - return graph; - } +GraphHolder& GraphHolder::getInstance() { + static GraphHolder instance; + return instance; +}; - void GraphHolder::forgetGraph(Nd4jLong graphId) { - if (this->hasGraph(graphId)) - _graphF.erase(graphId); - } +void GraphHolder::registerGraph(Nd4jLong graphId, const Graph& graph) { + if (hasGraph(graphId)) throw graph_exists_exception(graphId); - void GraphHolder::dropGraph(Nd4jLong graphId) { - if (this->hasGraph(graphId)) { - auto g = _graphF[graphId]; - forgetGraph(graphId); - delete g; - } - } - - void GraphHolder::dropGraphAny(Nd4jLong graphId) { - if (!hasGraphAny(graphId)) - return; - - this->lockWrite(graphId); - - this->dropGraph(graphId); - - this->unlockWrite(graphId); - } - - bool GraphHolder::hasGraphAny(Nd4jLong graphId) { - return this->hasGraph(graphId); - } - - bool GraphHolder::hasGraph(Nd4jLong graphId) { - return _graphF.count(graphId) > 0; - } + std::lock_guard lock(_mutex); + _graphs[graphId] = graph; +} - void GraphHolder::replaceGraph(Nd4jLong graphId, Graph* graph) { - if (!hasGraph(graphId)) { - registerGraph(graphId, graph); - return; - } +Graph& GraphHolder::graph(Nd4jLong graphId) { + if (!this->hasGraph(graphId)) { + nd4j_printf("GraphHolder doesn't have graph stored for [%lld]\n", graphId); + throw std::runtime_error("Bad argument"); + } - this->lockWrite(graphId); + std::lock_guard lock(_mutex); + return _graphs[graphId]; +} - _graphF[graphId] = graph; +void GraphHolder::forgetGraph(Nd4jLong graphId) { + if (this->hasGraph(graphId)) { + std::lock_guard lock(_mutex); + _graphs.erase(graphId); + } +} - this->unlockWrite(graphId); - } +void GraphHolder::dropGraph(Nd4jLong graphId) { forgetGraph(graphId); } +bool GraphHolder::hasGraph(Nd4jLong graphId) { + std::lock_guard lock(_mutex); + return _graphs.count(graphId) > 0; +} +void GraphHolder::replaceGraph(Nd4jLong graphId, const Graph& graph) { + if (!hasGraph(graphId)) { + registerGraph(graphId, graph); + return; + } + forgetGraph(graphId); - flatbuffers::Offset GraphHolder::execute(Nd4jLong graphId, flatbuffers::FlatBufferBuilder &builder, const FlatInferenceRequest* request) { - if (!hasGraph(graphId)) - throw unknown_graph_exception(graphId); + std::lock_guard lock(_mutex); + _graphs[graphId] = graph; +} - lockRead(graphId); +flatbuffers::Offset GraphHolder::execute( + Nd4jLong graphId, flatbuffers::FlatBufferBuilder& builder, + const FlatInferenceRequest* request) { + if (!hasGraph(graphId)) throw unknown_graph_exception(graphId); + /* + lockRead(graphId); - auto graph = cloneGraph(graphId); - auto res = GraphExecutioner::execute(graph, builder, request); - delete graph; + auto graph = cloneGraph(graphId); + auto res = GraphExecutioner::execute(graph, builder, request); + delete graph; - unlockRead(graphId); + unlockRead(graphId); - return res; - } - } + return res; + */ + throw std::runtime_error("GraphHolder::execute - not implemented yet"); } + +} // namespace graph +} // namespace sd diff --git a/libnd4j/include/graph/impl/GraphState.cpp b/libnd4j/include/graph/impl/GraphState.cpp deleted file mode 100644 index a8b25603a512..000000000000 --- a/libnd4j/include/graph/impl/GraphState.cpp +++ /dev/null @@ -1,167 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -// -// Created by raver119 on 23.01.18. -// - -#include -#include - - -namespace sd { -namespace graph { - GraphState::GraphState(Nd4jLong id) { - _id = id; - _graph = new Graph(nullptr, &_variableSpace); - }; - - GraphState::~GraphState() { - // stupid. we should get rid of pointers here i think. no need to bother - for (const auto &v: _scopes) { - v.second->forgetNodes(); - delete v.second; - } - - // we must remove reference to VariableSpace - _graph->forgetVariableSpace(); - - delete _graph; - }; - - Nd4jStatus GraphState::registerScope(int scopeId) { - auto scope = new Scope(scopeId); - _scopes[scopeId] = scope; - - auto scopeWrapper = new Node(OpType_LOGIC, 10, scopeId); - _graph->addNode(scopeWrapper); - - return Status::OK(); - }; - - Nd4jStatus GraphState::forgetScope(int scopeId) { - if (_scopes.count(scopeId) > 0) - _scopes.erase(scopeId); - else - return Status::THROW("Non-existent scope requested"); - - return Status::OK(); - }; - -#ifndef __JAVACPP_HACK__ - Nd4jStatus GraphState::attachOpToScope(int scopeId, int nodeId, ops::DeclarableOp *op, ArgumentsList inputs) { - if (_scopes.count(scopeId) == 0) - return Status::THROW("GraphState: can't attach op to unknown scope"); - - auto scope = _scopes[scopeId]; - - // creating new Node -// auto node = new Node(OpType_CUSTOM, 0, nodeId); - auto node = new Node(op, nodeId); -// node->setCustomOp(op); - node->setScopeInfo(scopeId); - - // mapping inputs here - for (int e = 0; e < inputs.size(); e++) { - auto p = inputs.at(e); - - // each expected input is Variable in current VariableSpace - // it should have it's numerical and symbolic ID - - if (!_variableSpace.hasVariable(p.first(), p.second())) { - auto var = new Variable(); - var->setId(p.first(), p.second()); - _variableSpace.putVariable(p.first(), p.second(), var); - } - - // nd4j_printf("Node_%i: adding input [%i:%i]\n", node->id(), p.first(), p.second()); - node->pickInput(p.first(), p.second()); - } - - scope->push_back(node); - - _graph->addNode(node); - - return Status::OK(); - }; - - Graph* GraphState::graph() { - return _graph; - } - - Scope* GraphState::getScope(int scopeId) { - if (_scopes.count(scopeId) == 0) { - nd4j_printf("GraphState: Unknown scope requested %i\n", scopeId); - return nullptr; - } - - return _scopes[scopeId]; - } -#endif - Nd4jStatus GraphState::defineReturn(int scopeId, int nodeId, ArgumentsList args) { - if (_scopes.count(scopeId) == 0) - return Status::THROW("GraphState: can't attach op to unknown scope"); - - auto scope = _scopes[scopeId]; - - // creating new Node for RETURN - auto node = new Node(OpType_LOGIC, 40, nodeId); - node->setScopeInfo(scopeId); - - // mapping inputs here - for (int e = 0; e < args.size(); e++) { - auto p = args.at(e); - - // each expected input is Variable in current VariableSpace - // it should have it's numerical and symbolic ID - - if (!_variableSpace.hasVariable(p.first(), p.second())) { - auto var = new Variable(); - var->setId(p.first(), p.second()); - _variableSpace.putVariable(p.first(), p.second(), var); - } - - // nd4j_printf("Node_%i: adding input [%i:%i]\n", node->id(), p.first(), p.second()); - node->pickInput(p.first(), p.second()); - node->pickOutput(0, e); - } - - scope->push_back(node); - - _graph->addNode(node); - - - return Status::OK(); - } - - bool GraphState::hasScope(int scopeId) { - return _scopes.count(scopeId) > 0; - } - - VariableSpace* GraphState::variableSpace() { - return &_variableSpace; - }; - - Nd4jLong GraphState::id() { - return _id; - } - - Nd4jStatus GraphState::attachOpToScope(int scopeId, Nd4jLong opNum, int type, ArgumentsList inputs) { - // we should use OpRegistrator here, to create Node and push it to specific scope - return Status::OK(); - } -} -} \ No newline at end of file diff --git a/libnd4j/include/graph/impl/GraphUtils.cpp b/libnd4j/include/graph/impl/GraphUtils.cpp index 15f674ce1c70..de8c5a06d6d6 100644 --- a/libnd4j/include/graph/impl/GraphUtils.cpp +++ b/libnd4j/include/graph/impl/GraphUtils.cpp @@ -23,14 +23,15 @@ #endif #include -#include + #include +#include -#ifdef __linux__ //_WIN32 -#include +#ifdef __linux__ //_WIN32 +#include #include #include -#include +#include //#eldef __APPLE__ //#include //#include @@ -39,81 +40,81 @@ namespace sd { namespace graph { bool GraphUtils::filterOperations(GraphUtils::OpList& ops) { - bool modified = false; - - std::vector filtered(ops); - - std::sort(filtered.begin(), filtered.end(), [](ops::OpDescriptor a, ops::OpDescriptor b) { - return a.getOpName()->compare(*(b.getOpName())) < 0; - }); - std::string name = *(filtered[0].getOpName()); - - for (int e = 1; e < filtered.size(); e++) { -// nd4j_printf(">%s<, %lu %lu\n", name.c_str(), ops.size(), filtered.size()); - if (0 == filtered[e].getOpName()->compare(name)) { - // there is a match - auto fi = std::find_if(ops.begin(), ops.end(), - [name](ops::OpDescriptor a) { - return a.getOpName()->compare(name) == 0; + bool modified = false; + + std::vector filtered(ops); + + std::sort(filtered.begin(), filtered.end(), + [](ops::OpDescriptor a, ops::OpDescriptor b) { + return a.getOpName()->compare(*(b.getOpName())) < 0; }); - if (fi != ops.end()) - ops.erase(fi); - modified = true; - } - name = *(filtered[e].getOpName()); + std::string name = *(filtered[0].getOpName()); + + for (int e = 1; e < filtered.size(); e++) { + // nd4j_printf(">%s<, %lu %lu\n", name.c_str(), ops.size(), + // filtered.size()); + if (0 == filtered[e].getOpName()->compare(name)) { + // there is a match + auto fi = + std::find_if(ops.begin(), ops.end(), [name](ops::OpDescriptor a) { + return a.getOpName()->compare(name) == 0; + }); + if (fi != ops.end()) ops.erase(fi); + modified = true; } - return modified; + name = *(filtered[e].getOpName()); + } + return modified; } std::string GraphUtils::makeCommandLine(GraphUtils::OpList& ops) { - std::string res; - - if (!ops.empty()) { - res += std::string(" -g \"-DSD_OPS_LIST='"); - //res += *(ops[0].getOpName()); - for (int i = 0; i < ops.size(); i++) { - res += std::string("-DOP_"); - res += *(ops[i].getOpName()); - res += "=true "; - } - res += "'\""; + std::string res; + + if (!ops.empty()) { + res += std::string(" -g \"-DSD_OPS_LIST='"); + // res += *(ops[0].getOpName()); + for (int i = 0; i < ops.size(); i++) { + res += std::string("-DOP_"); + res += *(ops[i].getOpName()); + res += "=true "; } + res += "'\""; + } - return res; + return res; } -int -GraphUtils::runPreprocessor(char const* input, char const* output) { - int status = 0; +int GraphUtils::runPreprocessor(char const* input, char const* output) { + int status = 0; -#ifdef __linux__ //_WIN32 - int pipefd[2]; - status = pipe(pipefd); - pid_t pid = fork(); - if (pid == 0) - { - close(pipefd[0]); // close reading end in the child - - dup2(pipefd[1], 1); // send stdout to the pipe - dup2(pipefd[1], 2); // send stderr to the pipe +#ifdef __linux__ //_WIN32 + int pipefd[2]; + status = pipe(pipefd); + pid_t pid = fork(); + if (pid == 0) { + close(pipefd[0]); // close reading end in the child - close(pipefd[1]); // this descriptor is no longer needed + dup2(pipefd[1], 1); // send stdout to the pipe + dup2(pipefd[1], 2); // send stderr to the pipe - #if __CNUC__ < 4 && __GNUC_MINOR__ < 9 - #pragma error "Compiler version should be greater then 4.9" - #endif + close(pipefd[1]); // this descriptor is no longer needed + +#if __CNUC__ < 4 && __GNUC_MINOR__ < 9 +#pragma error "Compiler version should be greater then 4.9" +#endif // just stacking everything together -// std::string cmdline = "./buildnativeoperations.sh " + -/// std::string(name_arg) + -// std::string(build_arg) + -/// std::string(arch_arg) + -// std::string(opts_arg); - - FILE *f = popen("which c++", "r"); - if(f == NULL) { - std::cerr << "Cannot find c++ compiler with 'which' command." << std::endl; - exit(1); + // std::string cmdline = "./buildnativeoperations.sh " + + /// std::string(name_arg) + + // std::string(build_arg) + + /// std::string(arch_arg) + + // std::string(opts_arg); + + FILE* f = popen("which c++", "r"); + if (f == NULL) { + std::cerr << "Cannot find c++ compiler with 'which' command." + << std::endl; + exit(1); } #if _POSIX_C_SOURCE >= 200809L char* line = nullptr; @@ -121,11 +122,11 @@ GraphUtils::runPreprocessor(char const* input, char const* output) { ssize_t len; if ((len = getdelim(&line, &size, '\n', f)) < 2) { - std::cerr << "Cannot find c++ compiler with 'which' command." << std::endl; - exit(2); + std::cerr << "Cannot find c++ compiler with 'which' command." + << std::endl; + exit(2); } - if (line[len - 1] == '\n') - line[len - 1] = '\0'; + if (line[len - 1] == '\n') line[len - 1] = '\0'; std::string cmd(line); @@ -135,38 +136,39 @@ GraphUtils::runPreprocessor(char const* input, char const* output) { #else std::string cmd; { - - char szLine[PATH_MAX]; - if (NULL == fgets(szLine, sizeof(szLine), f)) { - std::cerr << "Cannot find c++ compiler with 'which' command." << std::endl; - exit(3); - } - char* p = strchr(szLine, '\n'); - if (p) { - *p = '\0'; - } - cmd = szLine; + char szLine[PATH_MAX]; + if (NULL == fgets(szLine, sizeof(szLine), f)) { + std::cerr << "Cannot find c++ compiler with 'which' command." + << std::endl; + exit(3); + } + char* p = strchr(szLine, '\n'); + if (p) { + *p = '\0'; + } + cmd = szLine; } #endif - char const* cxx = cmd.c_str(); //;getenv("CXX"); -// if (cxx == nullptr) { -// nd4j_printf("Cannot retrieve mandatory environment variable 'CXX'. Please set up the variable and try again.", ""); -// exit(3); -// } - //char* pathEnv = getenv("PATH"); - //std::string pathStr("PATH=./;"); - //pathStr += pathEnv; - - //nd4j_printf("%s\n", pathStr.c_str()); -// char const* env[] = {// "HOME=/tmp", -// pathStr.c_str(), -// (char *)0 }; - -// to retrieve c++ version (hardcoded 6): c++ -v 2>&1 | tail -1 | awk '{v = int($3); print v;}' - - std::vector params;//(9); - std::vector args;//(9); + char const* cxx = cmd.c_str(); //;getenv("CXX"); + // if (cxx == nullptr) { + // nd4j_printf("Cannot retrieve mandatory environment variable 'CXX'. + // Please set up the variable and try again.", ""); exit(3); + // } + // char* pathEnv = getenv("PATH"); + // std::string pathStr("PATH=./;"); + // pathStr += pathEnv; + + // nd4j_printf("%s\n", pathStr.c_str()); + // char const* env[] = {// "HOME=/tmp", + // pathStr.c_str(), + // (char *)0 }; + + // to retrieve c++ version (hardcoded 6): c++ -v 2>&1 | tail -1 | awk '{v = + // int($3); print v;}' + + std::vector params; //(9); + std::vector args; //(9); args.emplace_back(cmd); args.emplace_back(std::string("-E")); args.emplace_back(std::string("-P")); @@ -182,7 +184,7 @@ GraphUtils::runPreprocessor(char const* input, char const* output) { args.emplace_back(std::string("-I../include/types")); args.emplace_back(std::string("-I../include/array")); args.emplace_back(std::string("-I../include/cnpy")); - args.emplace_back(std::string("-I../include/graph")); + args.emplace_back(std::string("-I../include/graph")); args.emplace_back(std::string("-I../include/ops/declarable")); #ifdef MKLDNN_PATH args.emplace_back(std::string("-I" MKLDNN_PATH "/include")); @@ -197,14 +199,13 @@ GraphUtils::runPreprocessor(char const* input, char const* output) { std::string preprocessorCmd(cxx); bool skip = true; - for (auto& arg: args) { - if (!skip) { - preprocessorCmd += ' '; - preprocessorCmd += arg; - } - else - skip = false; - params.emplace_back(const_cast(arg.data())); + for (auto& arg : args) { + if (!skip) { + preprocessorCmd += ' '; + preprocessorCmd += arg; + } else + skip = false; + params.emplace_back(const_cast(arg.data())); } params.emplace_back(nullptr); nd4j_printf("Run: \n\t %s\n", preprocessorCmd.c_str()); @@ -212,26 +213,24 @@ GraphUtils::runPreprocessor(char const* input, char const* output) { int err = execvp(cmd.c_str(), ¶ms[0]); if (err < 0) { - perror("\nCannot run Preprocessor properly due \n"); + perror("\nCannot run Preprocessor properly due \n"); } status = err; nd4j_printf("Header file %s was generated.\n", output); -// nd4j_printf("Running build script\n%s\n", cmdline.c_str()); - } - else - { + // nd4j_printf("Running build script\n%s\n", cmdline.c_str()); + } else { // parent - char buffer[1024]; - close(pipefd[1]); // close the write end of the pipe in the parent - memset(buffer, 0, sizeof(buffer)); - while (read(pipefd[0], buffer, sizeof(buffer)) != 0) { - printf("%s\n", buffer); - } - waitpid(pid, &status, 0); + char buffer[1024]; + close(pipefd[1]); // close the write end of the pipe in the parent + memset(buffer, 0, sizeof(buffer)); + while (read(pipefd[0], buffer, sizeof(buffer)) != 0) { + printf("%s\n", buffer); } + waitpid(pid, &status, 0); + } #endif - return status; + return status; } -} -} +} // namespace graph +} // namespace sd diff --git a/libnd4j/include/graph/impl/InferenceRequest.cpp b/libnd4j/include/graph/impl/InferenceRequest.cpp index 29fde1eb1d5b..079491eb39d5 100644 --- a/libnd4j/include/graph/impl/InferenceRequest.cpp +++ b/libnd4j/include/graph/impl/InferenceRequest.cpp @@ -20,60 +20,65 @@ #include - namespace sd { - namespace graph { - InferenceRequest::InferenceRequest(Nd4jLong graphId, ExecutorConfiguration *configuration) { - this->_id = graphId; - this->_configuration = configuration; - } +namespace graph { +InferenceRequest::InferenceRequest(Nd4jLong graphId, + const ExecutorConfiguration &configuration) { + this->_id = graphId; + this->_configuration = configuration; +} - InferenceRequest::~InferenceRequest() { - for (auto v : _deletables) - delete v; - } +InferenceRequest::~InferenceRequest() { + // +} - void InferenceRequest::appendVariable(int id, NDArray *array) { - appendVariable(id, 0, array); - } +void InferenceRequest::appendVariable(int id, const NDArray &array) { + appendVariable(id, 0, array); +} - void InferenceRequest::appendVariable(int id, int index, NDArray *array) { - auto v = new Variable(array, nullptr, id, index); - insertVariable(v); - } +void InferenceRequest::appendVariable(int id, int index, const NDArray &array) { + auto v = std::make_shared(std::make_shared(array), + nullptr, id, index); + insertVariable(v); +} - void InferenceRequest::appendVariable(std::string &id, NDArray *array) { - auto v = new Variable(array, id.c_str()); - insertVariable(v); - } +void InferenceRequest::appendVariable(const std::string &id, + const NDArray &array) { + auto v = std::make_shared(std::make_shared(array), + id.c_str()); + insertVariable(v); +} - void InferenceRequest::appendVariable(std::string &name, int id, int index, NDArray *array) { - auto v = new Variable(array, name.c_str(), id, index); - insertVariable(v); - } +void InferenceRequest::appendVariable(const std::string &name, int id, + int index, const NDArray &array) { + auto v = std::make_shared(std::make_shared(array), + name, id, index); + insertVariable(v); +} - void InferenceRequest::insertVariable(Variable *variable) { - variable->markRemovable(false); - variable->markReadOnly(true); - _variables.emplace_back(variable); - _deletables.emplace_back(variable); - } +void InferenceRequest::insertVariable(std::shared_ptr variable) { + variable->markRemovable(false); + variable->markReadOnly(true); + _variables.emplace_back(variable); +} - void InferenceRequest::appendVariable(Variable *variable) { - _variables.emplace_back(variable); - } +void InferenceRequest::appendVariable(std::shared_ptr variable) { + _variables.emplace_back(variable); +} - flatbuffers::Offset InferenceRequest::asFlatInferenceRequest(flatbuffers::FlatBufferBuilder &builder) { - std::vector> vec; - for (Variable* v : _variables) { - vec.emplace_back(v->asFlatVariable(builder)); - } +flatbuffers::Offset +InferenceRequest::asFlatInferenceRequest( + flatbuffers::FlatBufferBuilder &builder) { + std::vector> vec; + for (const auto &v : _variables) { + vec.emplace_back(v->asFlatVariable(builder)); + } - auto confOffset = _configuration != nullptr ? _configuration->asFlatConfiguration(builder) : 0; + auto confOffset = _configuration.asFlatConfiguration(builder); - auto vecOffset = builder.CreateVector(vec); + auto vecOffset = builder.CreateVector(vec); - return CreateFlatInferenceRequest(builder, _id, vecOffset, confOffset); - } - } -} \ No newline at end of file + return CreateFlatInferenceRequest(builder, _id, vecOffset, confOffset); +} +} // namespace graph +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/graph/impl/Intervals.cpp b/libnd4j/include/graph/impl/Intervals.cpp index 1a89c797fc92..b8f0ff1a63e8 100644 --- a/libnd4j/include/graph/impl/Intervals.cpp +++ b/libnd4j/include/graph/impl/Intervals.cpp @@ -21,33 +21,30 @@ namespace sd { - // default constructor - Intervals::Intervals(): _content({{}}) {} - - // constructor - Intervals::Intervals(const std::initializer_list>& content ): _content(content) {} - Intervals::Intervals(const std::vector>& content ): _content(content) {} - - ////////////////////////////////////////////////////////////////////////// - // accessing operator - std::vector Intervals::operator[](const Nd4jLong i) const { - - return *(_content.begin() + i); - } - - ////////////////////////////////////////////////////////////////////////// - // returns size of _content - int Intervals::size() const { - - return _content.size(); - } - - ////////////////////////////////////////////////////////////////////////// - // modifying operator - // std::vector& Intervals::operator()(const int i) { - // return _content[i]; - // } +// default constructor +Intervals::Intervals() : _content({{}}) {} + +// constructor +Intervals::Intervals( + const std::initializer_list>& content) + : _content(content) {} +Intervals::Intervals(const std::vector>& content) + : _content(content) {} + +////////////////////////////////////////////////////////////////////////// +// accessing operator +std::vector Intervals::operator[](const Nd4jLong i) const { + return *(_content.begin() + i); +} +////////////////////////////////////////////////////////////////////////// +// returns size of _content +int Intervals::size() const { return _content.size(); } -} +////////////////////////////////////////////////////////////////////////// +// modifying operator +// std::vector& Intervals::operator()(const int i) { +// return _content[i]; +// } +} // namespace sd diff --git a/libnd4j/include/graph/impl/Node.cpp b/libnd4j/include/graph/impl/Node.cpp index a3baf1a9bea2..4a4081e7d779 100644 --- a/libnd4j/include/graph/impl/Node.cpp +++ b/libnd4j/include/graph/impl/Node.cpp @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -18,854 +19,720 @@ // @author raver119@gmail.com // +#include +#include #include -#include -#include -#include -#include -#include -#include -#include -#include +#include #include -#include +#include +#include +#include #include #include -#include -#include +#include #include -#include +#include +#include +#include #include -#include -#include +#include +#include #include -#include -#include +#include +#include +#include +#include +#include namespace sd { - namespace graph { - void sd::graph::Node::setOuterTime(Nd4jLong time){ -// if (hasBlockAttached()) -// _block->setOuterTime(time); - } +namespace graph { +Node::Node(const ops::DeclarableOp &opName, const std::string &nodeName, + const std::vector &tArgs, const std::vector &iArgs, + const std::vector &bArgs, const std::vector &dArgs) { + auto customOp = + ops::OpRegistrator::getInstance().getOperation(opName.getOpHash()); + + this->_name = nodeName; + this->_opType = OpType_CUSTOM; + this->_opNum = customOp->getOpHash(); + this->_customOp = customOp; + + _hasExternalInputs = false; + _hasExternalOutputs = false; + _hasInternalInputs = false; + _hasInternalOutputs = false; + + ContextPrototype block(this->customOp()->getOpDescriptor(), this->id(), + false); + block.setName(nodeName); + + block.appendI(iArgs); + block.appendT(tArgs); + block.appendB(bArgs); + block.appendD(dArgs); + + this->setContextPrototype(block); +} - void sd::graph::Node::setInnerTime(Nd4jLong time){ -// if (hasBlockAttached()) -// _block->setInnerTime(time); - } +Node::Node(const std::string &opName, const std::string &nodeName, + const std::vector &tArgs, const std::vector &iArgs, + const std::vector &bArgs, const std::vector &dArgs) { + auto customOp = ops::OpRegistrator::getInstance().getOperation(opName); - void sd::graph::Node::setGraph(sd::graph::Graph* graph) { - _graph = graph; - } + this->_name = nodeName; + this->_opType = OpType_CUSTOM; + this->_opNum = customOp->getOpHash(); + this->_customOp = customOp; - sd::graph::Graph* sd::graph::Node::getGraph() { - return _graph; - } + _hasExternalInputs = false; + _hasExternalOutputs = false; + _hasInternalInputs = false; + _hasInternalOutputs = false; - bool sd::graph::Node::hasGraphEmbedded() { - return _graph != nullptr; - } + ContextPrototype block(this->customOp()->getOpDescriptor(), this->id(), + false); + block.setName(nodeName); - void sd::graph::Node::markInplace(bool reallyInplace) { - _isInplace = reallyInplace; - if (_protoContext != nullptr) { - _protoContext->markInplace(reallyInplace); - } - } + block.appendI(iArgs); + block.appendT(tArgs); + block.appendB(bArgs); + block.appendD(dArgs); - OpClass sd::graph::Node::getOpClass() { - return _opClass; - } + this->setContextPrototype(block); +} - bool sd::graph::Node::hasBlockAttached() { - return _protoContext != nullptr; - } +bool Node::isDivergencePoint() { + if (hasCustomOp()) { + return _customOp->getOpDescriptor()->isDivergent(); + } else if (opType() == OpType_LOGIC && opNum() == sd::logic::Switch) + return true; + else + return false; +} - bool sd::graph::Node::isInplace() { - return _isInplace; - } +void Node::setContextPrototype(const ContextPrototype &block) { + _protoContext = block; +} - bool sd::graph::Node::isDivergencePoint() { - if (hasCustomOp()) { - return _customOp->getOpDescriptor()->isDivergent(); - } else if (opType() == OpType_LOGIC && opNum() == 30) - return true; - else - return false; - } +void Node::setId(int id) { + _id = id; + _protoContext.setNodeId(id); +} - void sd::graph::Node::setActive(bool reallyActive) { - _active = reallyActive; - } +std::shared_ptr Node::customOp() const { + return _customOp; +} - bool sd::graph::Node::isActive() { - return _active; - } +void Node::setCustomOp(const std::shared_ptr& customOp) { + _customOp = customOp; +} - Nd4jLong Node::getFrameId() { - return _frameId; - } +bool Node::hasCustomOp() const { return _customOp != nullptr; } - void Node::setFrameId(Nd4jLong frameId) { - _frameId = frameId; - } +const std::string &Node::name() const { return _name; } - ContextPrototype * sd::graph::Node::getContextPrototype() { - if (_protoContext == nullptr) - _protoContext = new ContextPrototype(this->getCustomOp() != nullptr ? this->getCustomOp()->getOpDescriptor() : nullptr, this->id()); - if (_protoContext->inputs()->empty()) { - for (int e = 0; e < this->input()->size(); e++) { - _protoContext->inputs()->emplace_back(this->input()->at(e)); - } - } - return _protoContext; - } - void sd::graph::Node::setContextPrototype(ContextPrototype *block) { - if (_protoContext != nullptr) - throw std::runtime_error("Block already exists"); +void Node::setName(const std::string &name) { _name = name; } - _protoContext = block; - } +void Node::pickInput(const std::pair &pair) { + _input.emplace_back(pair); + _protoContext.pickInput(pair); +} - void sd::graph::Node::setId(int id) { - _id = id; - } +void Node::pickInput(const std::string &id) { + throw std::runtime_error("Node::pickInput - Not implemented yet"); +} - sd::ops::DeclarableOp* sd::graph::Node::getCustomOp() { - return _customOp; - } +void Node::pickInput(int inputId, int outputId) { + std::pair p(inputId, outputId); + pickInput(p); +} - void sd::graph::Node::setCustomOp(sd::ops::DeclarableOp *customOp) { - _customOp = customOp; +void Node::pickInput(int inputId) { + pickInput(inputId, 0); - // divergent ops (Switch etc) are always inplace, they don't allocate anything - if (_customOp != nullptr && customOp->getOpDescriptor()->isDivergent()) - _isInplace = true; - } + if (inputId < 0) + _hasExternalInputs = true; + else + _hasInternalInputs = true; +} - bool sd::graph::Node::hasCustomOp() { - return _customOp != nullptr; - } +void Node::pickExternalOutput(int outputId) { + std::pair pair(outputId, 0); + _output.emplace_back(pair); - std::string * sd::graph::Node::name() { - return this->getName(); - } + _hasExternalOutputs = true; +} - std::string * sd::graph::Node::getName() { - return &_name; - } +void Node::pickOutputOnce(int outputId) { + std::pair pair(outputId, 0); + if (std::find(_output.begin(), _output.end(), pair) == _output.end()) + pickOutput(outputId); +} - void sd::graph::Node::setName(const std::string& name) { - _name = name.c_str(); - } +void Node::pickOutput(int nodeId, int outputId) { + std::pair pair(nodeId, outputId); + _output.emplace_back(pair); +} - void sd::graph::Node::setName(std::string *name) { - _name = *name; - } +void Node::pickOutput(int outputId) { + std::pair pair(outputId, 0); + _output.emplace_back(pair); - double sd::graph::Node::scalar() { - return _scalar.e(0); - }; + if (outputId < 0) + _hasExternalOutputs = true; + else + _hasInternalOutputs = true; +} - void sd::graph::Node::pickInput(std::pair& pair) { - _input.push_back(pair); - } +bool Node::hasExternalOutputs() const { return _hasExternalOutputs; } - void sd::graph::Node::pickInput(int inputId, int outputId) { - std::pair p(inputId,outputId); - pickInput(p); - } +bool Node::hasExternalInputs() const { return _hasExternalInputs; } - void sd::graph::Node::pickInput(int inputId) { - pickInput(inputId, 0); +bool Node::hasInternalOutputs() const { return _hasInternalOutputs; } - if (inputId < 0) - _hasExternalInputs = true; - else - _hasInternalInputs = true; - } +bool Node::hasInternalInputs() const { return _hasInternalInputs; } - void sd::graph::Node::pickExternalOutput(int outputId) { - std::pair pair(outputId, 0); - _output.push_back(pair); +bool Node::isMultiInput() { return _input.size() > 1; } - _hasExternalOutputs = true; - } +bool Node::isMultiOutput() { return _output.size() > 1; } - void sd::graph::Node::pickOutputOnce(int outputId) { - std::pair pair(outputId, 0); - if (std::find(_output.begin(), _output.end(), pair) == _output.end()) - pickOutput(outputId); - } +int Node::id() const { return _id; } - void sd::graph::Node::pickOutput(int nodeId, int outputId) { - std::pair pair(nodeId, outputId); - _output.emplace_back(pair); - } +Nd4jLong Node::opNum() const { return _opNum; } - void sd::graph::Node::pickOutput(int outputId) { - std::pair pair(outputId, 0); - _output.emplace_back(pair); +const std::vector> &Node::inputs() const { return _input; } - if (outputId < 0) - _hasExternalOutputs = true; - else - _hasInternalOutputs = true; - } +const std::vector> &Node::outputs() const { return _output; } - int * sd::graph::Node::getDimensionsPtr() { - return _dim; - } +Node::Node(const std::string &opName, const std::string &nodeName, const int id, + const std::vector &inputs, + const std::vector &tArgs, + const std::vector &iArgs) { + auto customOp = ops::OpRegistrator::getInstance().getOperation(opName); - std::vector * sd::graph::Node::getDimensions() { - return &_dimensions; - } + this->_opType = OpType_CUSTOM; + this->_id = id; + this->_opNum = customOp->getOpHash(); + this->_customOp = customOp; - int sd::graph::Node::getLayer() { - return _layer; - } + for (auto i : inputs) pickInput(i); - void sd::graph::Node::setLayer(int layer) { - _layer = layer; - } + ContextPrototype block(this->customOp()->getOpDescriptor(), this->id(), + false); - bool sd::graph::Node::hasExternalOutputs() { - return _hasExternalOutputs; - } + block.appendI(iArgs); + block.appendT(tArgs); - bool sd::graph::Node::hasExternalInputs() { - return _hasExternalInputs; - } + this->setContextPrototype(block); +} - bool sd::graph::Node::hasInternalOutputs() { - return _hasInternalOutputs; - } +Node::Node(const std::string &opName, const int id, + const std::vector> &inputs, + const std::vector &tArgs, + const std::vector &iArgs) { + auto customOp = ops::OpRegistrator::getInstance().getOperation(opName); - bool sd::graph::Node::hasInternalInputs() { - return _hasInternalInputs; - } + this->_opType = OpType_CUSTOM; + this->_id = id; + this->_opNum = customOp->getOpHash(); + this->_customOp = customOp; - bool sd::graph::Node::isMultiInput() { - return _input.size() > 1; - } - bool sd::graph::Node::isMultiOutput() { - return _output.size() > 1; - } + for (auto i : inputs) pickInput(i); - double * sd::graph::Node::extraParams() { - return _extraParams; - } + ContextPrototype block(this->customOp()->getOpDescriptor(), this->id(), + false); - int Node::totalReferences() { - return _referencedBy.size(); - } + block.appendI(iArgs); + block.appendT(tArgs); - void Node::addReference(int nodeId) { - _referencedBy.emplace_back(nodeId); - } + this->setContextPrototype(block); +} - sd::graph::OpType sd::graph::Node::opType() { - return _opType; - } +Node::Node(sd::ops::DeclarableOp *customOp, int id, + std::initializer_list input, std::initializer_list output, + std::initializer_list dimensions, float scalar, + std::initializer_list tArgs, + std::initializer_list iArgs) { + this->_opType = OpType_CUSTOM; + this->_id = id; + this->_opNum = customOp->getOpHash(); - int sd::graph::Node::id() { - return _id; - } + // if custom op is a registered one - pull it from cache, otherwise - clone + // locally + if (sd::ops::OpRegistrator::getInstance().hasOperation(_opNum)) + this->_customOp = + sd::ops::OpRegistrator::getInstance().getOperation(_opNum); + else + throw std::runtime_error( + "Can't create a node with custom operation within"); - Nd4jLong sd::graph::Node::opNum() { - return _opNum; - } + for (auto i : input) pickInput(i); - std::vector> *sd::graph::Node::input() { - return &_input; - } + for (auto o : output) pickOutput(o); - std::vector> *sd::graph::Node::output() { - return &_output; - } + ContextPrototype block(this->customOp()->getOpDescriptor(), this->id(), + false); - bool Node::isScoped() { - return _scope_id != 0; - } + for (auto v : dimensions) block.appendA(v); - void Node::setScopeInfo(int id, const char* name) { - _scope_id = id; + for (auto v : iArgs) block.appendI(v); - if (name != nullptr) - _scope_name = name; - } + for (auto v : tArgs) block.appendT(v); - int Node::scopeId() { - return _scope_id; - } + this->setContextPrototype(block); +} - std::string* Node::scopeName() { - return &_scope_name; - } +const std::vector>& Node::dependencies() const { + return _dependencies; +} - template - Node* Node::asT() { - auto node = this->clone(); - node->_dataType = DataTypeUtils::fromT(); - return node; - } - BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT Node* Node::asT, (), LIBND4J_TYPES); +void Node::actualizeDependencies(const MAP_IMPL &lookupTable) const { + for (const auto &v: _stringDependencies) { + if (lookupTable.count(v) == 0) + throw std::runtime_error("Unknown Node dependency found: [" + v + "]"); - sd::graph::Node::Node(sd::ops::DeclarableOp *customOp, int id, std::initializer_list input, std::initializer_list output, std::initializer_list dimensions, float scalar, std::initializer_list tArgs, std::initializer_list iArgs) { - this->_opType = OpType_CUSTOM; - this->_id = id; - this->_opNum = customOp->getOpHash(); - this->_extraParams = nullptr; - this->_dataType = sd::DataType::FLOAT32; // float as default - this->_dim = nullptr; - this->_customOp = customOp; + const_cast(this)->_dependencies.emplace_back(std::pair{lookupTable.at(v), 0}); + } +} - _hasExternalInputs = false; - _hasExternalOutputs = false; - _hasInternalInputs = false; - _hasInternalOutputs = false; +Node::Node(OpType opType, int opNum, int id, std::initializer_list input, + std::initializer_list output, + std::initializer_list dimensions, float scalar, + std::initializer_list tArgs, + std::initializer_list iArgs) { + this->_opType = opType; + this->_id = id; + this->_opNum = opNum; + + _hasExternalInputs = false; + _hasExternalOutputs = false; + _hasInternalInputs = false; + _hasInternalOutputs = false; + + for (auto i : input) pickInput(i); + + for (auto o : output) pickOutput(o); + + // these ops allow in-place execution by design + if (opType == OpType_TRANSFORM_SAME || opType == OpType_TRANSFORM_FLOAT || + opType == OpType_TRANSFORM_STRICT || opType == OpType_TRANSFORM_BOOL || + opType == OpType_SCALAR || opType == OpType_BROADCAST) { + + _opClass = OpClass_TRANSFORM; + } else if (opType == OpType_REDUCE_SAME || opType == OpType_REDUCE_FLOAT || + opType == OpType_REDUCE_BOOL || opType == OpType_REDUCE_LONG || + opType == OpType_SUMMARYSTATS) { + _opClass = OpClass_REDUCTION; + } + + if (opType == OpType_BROADCAST || opType == OpType_BROADCAST_BOOL || + opType == OpType_INDEX_REDUCE || opType == OpType_SUMMARYSTATS || + opType == OpType_REDUCE_BOOL || opType == OpType_REDUCE_SAME || + opType == OpType_REDUCE_FLOAT || opType == OpType_REDUCE_3 || + opType == OpType_TRANSFORM_STRICT || opType == OpType_TRANSFORM_SAME || + opType == OpType_TRANSFORM_FLOAT || opType == OpType_TRANSFORM_BOOL || + opType == OpType_RANDOM || opType == OpType_PAIRWISE || + opType == OpType_PAIRWISE_BOOL || opType == OpType_SCALAR_BOOL || + opType == OpType_SCALAR) { + ContextPrototype block(nullptr, this->id(), false); + + for (auto v : dimensions) block.appendA(v); + + for (auto v : iArgs) block.appendI(v); + + for (auto v : tArgs) block.appendT(v); + + this->setContextPrototype(block); + + this->setCustomOp(Node::buildOpByType( + opType, (int)input.size(), (int)block.getIArguments().size(), + (int)block.getTArguments().size(), opNum)); + block.setOpDescriptor(this->customOp()->getOpDescriptor()); + } else if (opType == OpType_CUSTOM) { + if (this->customOp()) { + ContextPrototype block(this->customOp()->getOpDescriptor(), this->id(), + false); + + for (auto v : dimensions) block.appendA(v); + + for (auto v : iArgs) block.appendI(v); + + for (auto v : tArgs) block.appendT(v); + + this->setContextPrototype(block); + } else + throw std::runtime_error("wrong custom operation given"); + } +}; + +int Node::frameId() const { + return _frameId; +} - _scalar = NDArrayFactory::create(scalar); +void Node::setFrameId(int frameId) { + _frameId = frameId; +} - for (auto i: input) - pickInput(i); +int Node::exitId() const { + return _exitId; +} - for (auto o: output) - pickOutput(o); +void Node::setExitId(int exitId) const { + _exitId = exitId; +} - if (dimensions.size() > 0) { - _dim = new int[dimensions.size()]; - int cnt = 0; - for (auto d: dimensions) { - _dimensions.push_back(d); - _dim[cnt++] = d; - } - } +Node::Node(const FlatNode *node) { + // temporary holders _dimensions, for transferring axis into ContextPrototype + std::vector axis; - auto block = new ContextPrototype(this->getCustomOp()->getOpDescriptor(), this->id(), false); + if (node->scalar() != nullptr) + throw std::runtime_error("FlatNode has scalar defined, it's deprecated"); - for (auto v: dimensions) - block->getAxis()->emplace_back(v); + if (node != nullptr) { + this->_id = node->id(); + // this->_dataType = DataTypeUtils::fromFlatDataType(node->dataType()); + this->_opNum = node->opNum(); + this->_opType = node->opType(); - for (auto v: iArgs) - block->getIArguments()->emplace_back(v); + if (node->name() != nullptr && node->name()->c_str() != nullptr) { + this->_name = node->name()->str(); + } - for (auto v: tArgs) - block->getTArguments()->emplace_back(v); + if (node->inputPaired() != nullptr && node->inputPaired()->size() > 0) { + for (int e = 0; e < (int)node->inputPaired()->size(); e++) { + auto pair = node->inputPaired()->Get(e); + pickInput(pair->first(), pair->second()); + } + } else if (node->input() != nullptr && node->input()->size() > 0) { + for (int e = 0; e < (int)node->input()->size(); e++) + pickInput(node->input()->Get(e)); + } else { + if (this->opType() != OpType_LOGIC) { + if (this->_name.size() > 0) { + nd4j_debug("Node [%i:<%s>] has no inputs defined\n", this->_id, this->_name.c_str()); + } else { + nd4j_debug("Node [%i:] has no inputs defined\n", this->_id); + } + } + } - this->setContextPrototype(block); - } + // reading control deps, and filling _dependencies field + if (node->varControlDeps() != nullptr && node->varControlDeps()->size() > 0) + for (int e = 0; e < node->varControlDeps()->size(); e++) + _stringDependencies.emplace_back(node->varControlDeps()->Get(e)->str()); - void sd::graph::Node::setOpType(OpType opType) { - this->_opType = opType; - } + if (node->controlDepFor() != nullptr && node->controlDepFor()->size() > 0) + for (int e = 0; e < node->controlDepFor()->size(); e++) + _stringDependencies.emplace_back(node->controlDepFor()->Get(e)->str()); - sd::graph::Node::Node(OpType opType, int opNum, int id, std::initializer_list input, std::initializer_list output, std::initializer_list dimensions, float scalar, std::initializer_list tArgs, std::initializer_list iArgs) { - this->_opType = opType; - this->_id = id; - this->_opNum = opNum; - this->_extraParams = nullptr; - this->_dataType = sd::DataType::FLOAT32; // float as default - this->_dim = nullptr; - - _hasExternalInputs = false; - _hasExternalOutputs = false; - _hasInternalInputs = false; - _hasInternalOutputs = false; - - _scalar = NDArrayFactory::create(scalar); - - for (auto i: input) - pickInput(i); - - for (auto o: output) - pickOutput(o); - - if (dimensions.size() > 0) { - _dim = new int[dimensions.size()]; - int cnt = 0; - for (auto d: dimensions) { - _dimensions.push_back(d); - _dim[cnt++] = d; - } - } - - // these ops allow in-place execution by design - if (opType == OpType_TRANSFORM_SAME || opType == OpType_TRANSFORM_FLOAT || opType == OpType_TRANSFORM_STRICT || opType == OpType_TRANSFORM_BOOL || opType == OpType_SCALAR || opType == OpType_BROADCAST) { - if (_output.size() <= 1) { - _isInplace = true; - } - _opClass = OpClass_TRANSFORM; - } else if (opType == OpType_REDUCE_SAME || opType == OpType_REDUCE_FLOAT || opType == OpType_REDUCE_BOOL || opType == OpType_REDUCE_LONG || opType == OpType_SUMMARYSTATS) { - _opClass = OpClass_REDUCTION; - } - - - if (opType == OpType_BROADCAST || - opType == OpType_BROADCAST_BOOL || - opType == OpType_INDEX_REDUCE || - opType == OpType_SUMMARYSTATS || - opType == OpType_REDUCE_BOOL || - opType == OpType_REDUCE_SAME || - opType == OpType_REDUCE_FLOAT || - opType == OpType_REDUCE_3 || - opType == OpType_TRANSFORM_STRICT || - opType == OpType_TRANSFORM_SAME || - opType == OpType_TRANSFORM_FLOAT || - opType == OpType_TRANSFORM_BOOL || - opType == OpType_RANDOM || - opType == OpType_PAIRWISE || - opType == OpType_PAIRWISE_BOOL || - opType == OpType_SCALAR_BOOL || - opType == OpType_SCALAR) { - - this->_isDeductable = true; - - auto block = new ContextPrototype(nullptr, this->id(), false); - - for (auto v: dimensions) - block->getAxis()->emplace_back(v); - - for (auto v: iArgs) - block->getIArguments()->emplace_back(v); - - for (auto v: tArgs) - block->getTArguments()->emplace_back(v); - - this->setContextPrototype(block); - this->setCustomOp(Node::buildOpByType(opType, (int) input.size(), (int) block->getIArguments()->size(), (int) block->getTArguments()->size(), opNum, &_scalar)); - block->setOpDescriptor(this->getCustomOp()->getOpDescriptor()); - } else if (opType == OpType_CUSTOM) { - if (this->getCustomOp()) { - auto block = new ContextPrototype(this->getCustomOp()->getOpDescriptor(), this->id(), false); - - for (auto v: dimensions) - block->getAxis()->emplace_back(v); - - for (auto v: iArgs) - block->getIArguments()->emplace_back(v); - - for (auto v: tArgs) - block->getTArguments()->emplace_back(v); - - this->setContextPrototype(block); - } else throw std::runtime_error("wrong custom operation given"); - } - }; - - sd::graph::Node::Node(const sd::graph::FlatNode *node) { - _hasExternalInputs = false; - _hasExternalOutputs = false; - _hasInternalInputs = false; - _hasInternalOutputs = false; - _extraParams = nullptr; - _dim = nullptr; - _dataType = sd::DataType::FLOAT32; // float as default - if (node->scope_id() != 0) - this->_scope_id = node->scope_id(); - - if (node->scope_name() != nullptr && node->scope_name()->size() > 0) - this->_scope_name = node->scope_name()->str(); - - if (node->scalar() != nullptr) { - auto scalar = sd::graph::FlatUtils::fromFlatArray(node->scalar()); - _scalar = *scalar; - delete scalar; - } - - if (node != nullptr) { - this->_id = node->id(); - //this->_dataType = DataTypeUtils::fromFlatDataType(node->dataType()); - this->_opNum = node->opNum(); - this->_opType = node->opType(); - - if (node->name() != nullptr && node->name()->c_str() != nullptr) { - this->_name = node->name()->str(); - } - - if (node->inputPaired() != nullptr && node->inputPaired()->size() > 0) { - for (int e = 0; e < (int) node->inputPaired()->size(); e++) { - auto pair = node->inputPaired()->Get(e); - pickInput(pair->first(), pair->second()); - } - } else if (node->input() != nullptr && node->input()->size() > 0) { - for (int e = 0; e < (int) node->input()->size(); e++) - pickInput(node->input()->Get(e)); - } else { - if (this->opType() != OpType_LOGIC) { - if (this->_name.size() > 0) { - nd4j_debug("Node [%i:<%s>] has no inputs defined\n", this->_id, this->_name.c_str()); - } else { - nd4j_debug("Node [%i:] has no inputs defined\n", this->_id); - } - } - } - - /* - if (node->output() != nullptr) - for (int e = 0; e < (int) node->output()->size(); e++) { - auto oid = node->output()->Get(e); - if (oid != this->_id && oid != 0) { - nd4j_verbose("Picking output: %i\n", node->output()->Get(e)); - pickOutput(oid); - } - } - */ - - - if (node->extraParams() != nullptr && node->extraParams()->size() > 0) { - _extraParams = new double[node->extraParams()->size()]; - for (int e = 0; e < (int) node->extraParams()->size(); e++) { - _extraParams[e] = static_cast(node->extraParams()->Get(e)); - } - } - - if (node->dimensions() != nullptr && node->dimensions()->size() > 0) { - _dim = new int[node->dimensions()->size()]; - for (int e = 0; e < (int) node->dimensions()->size(); e++) { - _dimensions.emplace_back(node->dimensions()->Get(e)); - _dim[e] = node->dimensions()->Get(e); - } - } - - if (this->opType() == OpType_LOGIC && this->opNum() == 100L) { - if (node->extraInteger()->size() < 1) { - nd4j_printf("Node_%i is type of Enter, but has no FrameID defined\n", this->id()); - throw std::runtime_error("Enter node must have FrameID specified"); - } - - this->setFrameId(node->extraInteger()->Get(0)); - } - - - // these ops allow in-place execution by design - if (_opType == OpType_BROADCAST || - _opType == OpType_BROADCAST_BOOL || - _opType == OpType_INDEX_REDUCE || - _opType == OpType_SUMMARYSTATS || - _opType == OpType_REDUCE_BOOL || - _opType == OpType_REDUCE_SAME || - _opType == OpType_REDUCE_FLOAT || - _opType == OpType_REDUCE_3 || - _opType == OpType_TRANSFORM_STRICT || - _opType == OpType_TRANSFORM_SAME || - _opType == OpType_TRANSFORM_FLOAT || - _opType == OpType_TRANSFORM_BOOL || - _opType == OpType_RANDOM || - _opType == OpType_PAIRWISE || - _opType == OpType_PAIRWISE_BOOL || - _opType == OpType_SCALAR_BOOL || - _opType == OpType_SCALAR) { - - if (_output.size() <= 1) { - _isInplace = true; - } - - if (node->input() != nullptr && node->input()->size() > 0) { - this->_isDeductable = true; - - auto block = new ContextPrototype(nullptr, this->id(), false); - - - for (auto v: _dimensions) - block->getAxis()->emplace_back(v); - - if (node->extraParams() != nullptr && node->extraParams()->size() > 0) - for (int e = 0; e < (int) node->extraParams()->size(); e++) { - block->getTArguments()->emplace_back(static_cast(node->extraParams()->Get(e))); - } - - if (node->extraBools() != nullptr && node->extraBools()->size() > 0) - for (int e = 0; e < (int) node->extraBools()->size(); e++) { - block->getBArguments()->push_back(node->extraBools()->Get(e)); - } - - if (node->extraInteger() != nullptr && node->extraInteger()->size() > 0) - for (int e = 0; e < (int) node->extraInteger()->size(); e++) { - block->getIArguments()->emplace_back(node->extraInteger()->Get(e)); - } - - if (node->extraTypes() != nullptr && node->extraTypes()->size() > 0) { - for (int e = 0; e < (int) node->extraTypes()->size(); e++) { - block->getDArguments()->emplace_back((sd::DataType) node->extraTypes()->Get(e)); - } - } - - this->setContextPrototype(block); - this->setCustomOp(Node::buildOpByType(_opType, (int) node->input()->size(), (int) block->getIArguments()->size(), (int) block->getTArguments()->size(), (int) _opNum, &_scalar)); - block->setOpDescriptor(this->getCustomOp()->getOpDescriptor()); - } else if (node->inputPaired() != nullptr && node->inputPaired()->size() > 0) { - this->_isDeductable = true; - - auto block = new ContextPrototype(nullptr, this->id(), false); - - for (int e = 0; e < this->input()->size(); e++) { - block->inputs()->emplace_back(this->input()->at(e)); - } - - // there's no other IArgs in legacy options, actually - for (auto v: _dimensions) - block->getAxis()->emplace_back(v); - - if (node->extraParams() != nullptr && node->extraParams()->size() > 0) - for (int e = 0; e < (int) node->extraParams()->size(); e++) { - block->getTArguments()->emplace_back(static_cast(node->extraParams()->Get(e))); - } - - if (node->extraBools() != nullptr && node->extraBools()->size() > 0) - for (int e = 0; e < (int) node->extraBools()->size(); e++) { - block->getBArguments()->push_back(node->extraBools()->Get(e)); - } - - if (node->extraInteger() != nullptr && node->extraInteger()->size() > 0) - for (int e = 0; e < (int) node->extraInteger()->size(); e++) { - block->getIArguments()->emplace_back(node->extraInteger()->Get(e)); - } - - if (node->extraTypes() != nullptr && node->extraTypes()->size() > 0) { - for (int e = 0; e < (int) node->extraTypes()->size(); e++) { - block->getDArguments()->emplace_back((sd::DataType) node->extraTypes()->Get(e)); - } - } - - this->setContextPrototype(block); - - this->setCustomOp(Node::buildOpByType(_opType, (int) node->inputPaired()->size(), (int) block->getIArguments()->size(), (int) block->getTArguments()->size(), (int) _opNum, &_scalar)); - block->setOpDescriptor(this->getCustomOp()->getOpDescriptor()); - } - } else if (this->_opType == OpType_CUSTOM) { - auto op = sd::ops::OpRegistrator::getInstance().getOperation(this->opNum()); - if (op == nullptr) { - nd4j_verbose("Can't find operation: %lld\n", this->opNum()); - throw std::runtime_error("Can't find requested operation"); - } - - auto block = new ContextPrototype(nullptr, this->id()); - - for (int e = 0; e < this->input()->size(); e++) { - block->inputs()->emplace_back(this->input()->at(e)); - } - - if (node->extraInteger() != nullptr) - for (uint32_t e = 0; e < node->extraInteger()->size(); e++) { - auto v = node->extraInteger()->Get(e); - // FIXME: remove this static_cast, iArgs should be Nd4jLong - block->getIArguments()->emplace_back(static_cast(v)); - } - - if (node->extraParams() != nullptr) - for (uint32_t e = 0; e < node->extraParams()->size(); e++) - block->getTArguments()->emplace_back(static_cast(node->extraParams()->Get(e))); - - if (node->extraBools() != nullptr && node->extraBools()->size() > 0) - for (int e = 0; e < (int) node->extraBools()->size(); e++) { - block->getBArguments()->push_back(node->extraBools()->Get(e)); - } - - if (node->extraTypes() != nullptr && node->extraTypes()->size() > 0) { - for (int e = 0; e < (int) node->extraTypes()->size(); e++) { - block->getDArguments()->emplace_back((sd::DataType) node->extraTypes()->Get(e)); - } - } - - for (auto v: _dimensions) - block->getAxis()->emplace_back(v); - - this->setContextPrototype(block); - this->setCustomOp(op); - block->setOpDescriptor(this->getCustomOp()->getOpDescriptor()); - } - } else { - // empty dynamic node, tests probably - } - } + if (node->controlDeps() != nullptr && node->controlDeps()->size() > 0) + for (int e = 0; e < node->controlDeps()->size(); e++) + _stringDependencies.emplace_back(node->controlDeps()->Get(e)->str()); - sd::DataType Node::dataType() { - return _dataType; - } - ContextPrototype* Node::protoContext() { - return _protoContext; - } + // transferring dimensions. Used for legacy ops only + if (node->dimensions() != nullptr && node->dimensions()->size() > 0) { + axis.resize(node->dimensions()->size()); + + for (int e = 0; e < (int) node->dimensions()->size(); e++) + axis[e] = node->dimensions()->Get(e); + } - sd::graph::Node::~Node() { - if (_extraParams != nullptr) - delete[] _extraParams; + if (this->opType() == OpType_LOGIC && this->opNum() == 100L) { + if (node->extraInteger()->size() < 1) + throw std::runtime_error("Enter Node [" + StringUtils::valueToString(this->id()) + "] must have FrameID specified"); - if (_dim != nullptr) - delete[] _dim; + this->setFrameId(node->extraInteger()->Get(0)); + } - if (_protoContext != nullptr) - delete _protoContext; + // these ops allow in-place execution by design + if (_opType == OpType_BROADCAST || _opType == OpType_BROADCAST_BOOL || + _opType == OpType_INDEX_REDUCE || _opType == OpType_SUMMARYSTATS || + _opType == OpType_REDUCE_BOOL || _opType == OpType_REDUCE_SAME || + _opType == OpType_REDUCE_FLOAT || _opType == OpType_REDUCE_3 || + _opType == OpType_TRANSFORM_STRICT || + _opType == OpType_TRANSFORM_SAME || _opType == OpType_TRANSFORM_FLOAT || + _opType == OpType_TRANSFORM_BOOL || _opType == OpType_RANDOM || + _opType == OpType_PAIRWISE || _opType == OpType_PAIRWISE_BOOL || + _opType == OpType_SCALAR_BOOL || _opType == OpType_SCALAR) { + + if (node->input() != nullptr && node->input()->size() > 0) { + ContextPrototype block(nullptr, this->id(), false); + if (!this->name().empty()) + block.setName(this->name()); + + for (auto v : axis) block.appendA(v); + + if (node->extraParams() != nullptr && node->extraParams()->size() > 0) + for (int e = 0; e < (int) node->extraParams()->size(); e++) + block.appendT(static_cast(node->extraParams()->Get(e))); - if (_isDeductable && _customOp != nullptr) { - Node::deleteOpByType(_opType, _customOp); - } - } + if (node->extraBools() != nullptr && node->extraBools()->size() > 0) + for (int e = 0; e < (int) node->extraBools()->size(); e++) + block.appendB(node->extraBools()->Get(e)); + + if (node->extraInteger() != nullptr && node->extraInteger()->size() > 0) + for (int e = 0; e < (int) node->extraInteger()->size(); e++) + block.appendI(node->extraInteger()->Get(e)); + + if (node->extraTypes() != nullptr && node->extraTypes()->size() > 0) + for (int e = 0; e < (int) node->extraTypes()->size(); e++) + block.appendD((sd::DataType) node->extraTypes()->Get(e)); + + this->setContextPrototype(block); + this->setCustomOp(Node::buildOpByType( + _opType, (int) node->input()->size(), + (int) block.getIArguments().size(), + (int) block.getTArguments().size(), (int) _opNum)); + block.setOpDescriptor(this->customOp()->getOpDescriptor()); + } else if (node->inputPaired() != nullptr && + node->inputPaired()->size() > 0) { + ContextPrototype block(nullptr, this->id(), false); + + for (int e = 0; e < this->inputs().size(); e++) { + block.pickInput(this->inputs().at(e)); + } + + // there's no other IArgs in legacy options, actually + for (auto v : axis) block.appendA(v); + + if (node->extraParams() != nullptr && node->extraParams()->size() > 0) + for (int e = 0; e < (int) node->extraParams()->size(); e++) + block.appendT(static_cast(node->extraParams()->Get(e))); + + if (node->extraBools() != nullptr && node->extraBools()->size() > 0) + for (int e = 0; e < (int) node->extraBools()->size(); e++) + block.appendB(node->extraBools()->Get(e)); + + if (node->extraInteger() != nullptr && node->extraInteger()->size() > 0) + for (int e = 0; e < (int) node->extraInteger()->size(); e++) + block.appendI(node->extraInteger()->Get(e)); + + if (node->extraTypes() != nullptr && node->extraTypes()->size() > 0) + for (int e = 0; e < (int) node->extraTypes()->size(); e++) + block.appendD((sd::DataType) node->extraTypes()->Get(e)); + + this->setContextPrototype(block); + + this->setCustomOp(Node::buildOpByType( + _opType, (int) node->inputPaired()->size(), + (int) block.getIArguments().size(), + (int) block.getTArguments().size(), (int) _opNum)); + block.setOpDescriptor(this->customOp()->getOpDescriptor()); + } + } else if (this->_opType == OpType_LOGIC) { + ContextPrototype block(nullptr, this->id()); + if (!this->name().empty()) + block.setName(this->name()); + + for (int e = 0; e < this->inputs().size(); e++) + block.pickInput(this->inputs().at(e)); + + this->setContextPrototype(block); + } else if (this->_opType == OpType_CUSTOM) { + auto op = + sd::ops::OpRegistrator::getInstance().getOperation(this->opNum()); + if (op == nullptr) + throw std::runtime_error("Can't find requested operation [" + StringUtils::valueToString(this->opNum()) + "]"); + + ContextPrototype block(nullptr, this->id()); + if (!this->name().empty()) + block.setName(this->name()); + + for (int e = 0; e < this->inputs().size(); e++) + block.pickInput(this->inputs().at(e)); + + if (node->extraInteger() != nullptr) + for (uint32_t e = 0; e < node->extraInteger()->size(); e++) { + auto v = node->extraInteger()->Get(e); + // FIXME: remove this static_cast, iArgs should be Nd4jLong + block.appendI(static_cast(v)); + } + + if (node->extraParams() != nullptr) + for (uint32_t e = 0; e < node->extraParams()->size(); e++) + block.appendT(static_cast(node->extraParams()->Get(e))); - int sd::graph::Node::getRewindNode() { - return _rewindNode; + if (node->extraBools() != nullptr && node->extraBools()->size() > 0) + for (int e = 0; e < (int)node->extraBools()->size(); e++) { + block.appendB(node->extraBools()->Get(e)); } - void sd::graph::Node::setRewindNode(int nodeId) { - _rewindNode = nodeId; + if (node->extraTypes() != nullptr && node->extraTypes()->size() > 0) { + for (int e = 0; e < (int)node->extraTypes()->size(); e++) { + block.appendD((sd::DataType)node->extraTypes()->Get(e)); } + } - std::pair& sd::graph::Node::getRewindLayer() { - return _rewindLayer; - }; + for (auto v : axis) block.appendA(v); - void sd::graph::Node::setRewindLayer(int layerId, int stepId) { - _rewindLayer.first = layerId; - _rewindLayer.second = stepId; - } + this->setContextPrototype(block); + this->setCustomOp(op); + block.setOpDescriptor(this->customOp()->getOpDescriptor()); + } + } else { + // empty dynamic node, tests probably + } +} - bool sd::graph::Node::equals(Node *other) { - if (_opType == other->_opType && _dataType == other->_dataType && _opNum == other->_opNum) - return true; - return false; - } +const ContextPrototype &Node::contextPrototype() const { return _protoContext; } - void sd::graph::Node::deleteOpByType(OpType opType, void *op) { - switch (opType) { - case OpType_PAIRWISE: - delete reinterpret_cast(op); - break; - case OpType_PAIRWISE_BOOL: - delete reinterpret_cast(op); - break; - case OpType_TRANSFORM_STRICT: - delete reinterpret_cast(op); - break; - case OpType_TRANSFORM_SAME: - delete reinterpret_cast(op); - break; - case OpType_TRANSFORM_FLOAT: - delete reinterpret_cast(op); - break; - case OpType_TRANSFORM_BOOL: - delete reinterpret_cast(op); - break; - case OpType_SCALAR: - delete reinterpret_cast(op); - break; - case OpType_SCALAR_BOOL: - delete reinterpret_cast(op); - break; - case OpType_REDUCE_3: - delete reinterpret_cast(op); - break; - case OpType_REDUCE_SAME: - delete reinterpret_cast(op); - break; - case OpType_REDUCE_FLOAT: - delete reinterpret_cast(op); - break; - case OpType_REDUCE_LONG: - delete reinterpret_cast(op); - break; - case OpType_REDUCE_BOOL: - delete reinterpret_cast(op); - break; - case OpType_INDEX_REDUCE: - delete reinterpret_cast(op); - break; - case OpType_SUMMARYSTATS: - delete reinterpret_cast(op); - break; - case OpType_RANDOM: - delete reinterpret_cast(op); - break; - case OpType_BROADCAST: - delete reinterpret_cast(op); - break; - case OpType_BROADCAST_BOOL: - delete reinterpret_cast(op); - break; - case OpType_CUSTOM: - delete reinterpret_cast(op); - break; - default: - throw std::runtime_error("Bad opType passed in"); - } - } +Node::~Node() { } - sd::ops::DeclarableOp* sd::graph::Node::buildOpByType(OpType opType, int numInputs, int numIArgs, int numTArgs, int opNum, NDArray *scalar) { - switch (opType) { - case OpType_PAIRWISE: - return new sd::ops::LegacyPairwiseTransformOp(opNum); - case OpType_PAIRWISE_BOOL: - return new sd::ops::LegacyPairwiseTransformBoolOp(opNum); - case OpType_TRANSFORM_STRICT: - return new sd::ops::LegacyTransformStrictOp(opNum); - case OpType_TRANSFORM_SAME: - return new sd::ops::LegacyTransformSameOp(opNum); - case OpType_TRANSFORM_FLOAT: - return new sd::ops::LegacyTransformFloatOp(opNum); - case OpType_TRANSFORM_BOOL: - return new sd::ops::LegacyTransformBoolOp(opNum); - case OpType_SCALAR: - return scalar == nullptr ? new sd::ops::LegacyScalarOp(opNum) : new sd::ops::LegacyScalarOp(opNum, *scalar); - case OpType_SCALAR_BOOL: - return scalar == nullptr ? new sd::ops::LegacyScalarBoolOp(opNum) : new sd::ops::LegacyScalarBoolOp(opNum, *scalar); - case OpType_REDUCE_3: - return new sd::ops::LegacyReduce3Op(opNum); - case OpType_REDUCE_SAME: - return new sd::ops::LegacyReduceSameOp(opNum); - case OpType_REDUCE_FLOAT: - return new sd::ops::LegacyReduceFloatOp(opNum); - case OpType_REDUCE_LONG: - return new sd::ops::LegacyReduceLongOp(opNum); - case OpType_REDUCE_BOOL: - return new sd::ops::LegacyReduceBoolOp(opNum); - case OpType_INDEX_REDUCE: - return new sd::ops::LegacyIndexReduceOp(opNum); - case OpType_SUMMARYSTATS: - return new sd::ops::LegacyStatsOp(opNum); - case OpType_RANDOM: - return new sd::ops::LegacyRandomOp(opNum); - case OpType_BROADCAST: - return new sd::ops::LegacyBroadcastOp(opNum); - case OpType_BROADCAST_BOOL: - return new sd::ops::LegacyBroadcastBoolOp(opNum); - default: - throw std::runtime_error("Bad opType passed in"); - } - } +bool Node::equals(const Node *other) const { + if (_opType == other->_opType && _opNum == other->_opNum) + return true; - bool Node::isDeductable() { - return _isDeductable; - } + return false; +} - void Node::setDeductable(bool reallyDeductable) { - _isDeductable = reallyDeductable; - } +bool Node::equals(const Node &other) const { + return this->equals(&other); +} +Node::Node(const Node &other) noexcept { + _opType = other._opType; + _opClass = other._opClass; + _opNum = other._opNum; + _customOp = other._customOp; + _name = other._name; + _id = other._id; + _frameId = other._frameId; + _exitId = other._exitId; + + _hasExternalOutputs = other._hasExternalOutputs; + _hasExternalInputs = other._hasExternalInputs; + _hasInternalOutputs = other._hasInternalOutputs; + _hasInternalInputs = other._hasInternalInputs; + + _customOp = other._customOp; + _protoContext = other._protoContext; + + _input = other._input; + _output = other._output; +} - Node* Node::clone() { - if (this->_customOp && this->_opType == sd::graph::OpType_CUSTOM) { - auto clone = new Node(this->_customOp, _id); - clone->pullValues(this); - return clone; - } - else { - auto clone = new Node(_opType, _opNum, _id); +Node &Node::operator=(const Node &other) noexcept { + if (this == &other) return *this; - clone->pullValues(this); + _opType = other._opType; + _opClass = other._opClass; + _opNum = other._opNum; + _customOp = other._customOp; + _name = other._name; + _id = other._id; + _frameId = other._frameId; + _exitId = other._exitId; - // op time - if (!_isDeductable) - clone->_customOp = _customOp; - else { - auto c = dynamic_cast(_customOp); - clone->_customOp = c->clone(); - } + _hasExternalOutputs = other._hasExternalOutputs; + _hasExternalInputs = other._hasExternalInputs; + _hasInternalOutputs = other._hasInternalOutputs; + _hasInternalInputs = other._hasInternalInputs; - return clone; - } - } - } + _customOp = other._customOp; + _protoContext = other._protoContext; + + _input = other._input; + _output = other._output; + + return *this; +} + +Node::Node(Node &&other) noexcept { + + _opType = other._opType; + _opClass = other._opClass; + _opNum = other._opNum; + _customOp = other._customOp; + _name = std::move(other._name); + _id = other._id; + _frameId = other._frameId; + _exitId = other._exitId; + + _hasExternalOutputs = other._hasExternalOutputs; + _hasExternalInputs = other._hasExternalInputs; + _hasInternalOutputs = other._hasInternalOutputs; + _hasInternalInputs = other._hasInternalInputs; + + _protoContext = std::move(other._protoContext); + + _customOp = std::move(other._customOp); + _input = std::move(other._input); + _output = std::move(other._output); } + +Node &Node::operator=(Node &&other) noexcept { + if (this == &other) return *this; + + _opType = other._opType; + _opClass = other._opClass; + _opNum = other._opNum; + _customOp = other._customOp; + _name = std::move(other._name); + _id = other._id; + _frameId = other._frameId; + _exitId = other._exitId; + + _hasExternalOutputs = other._hasExternalOutputs; + _hasExternalInputs = other._hasExternalInputs; + _hasInternalOutputs = other._hasInternalOutputs; + _hasInternalInputs = other._hasInternalInputs; + + _protoContext = std::move(other._protoContext); + + _customOp = std::move(other._customOp); + _input = std::move(other._input); + _output = std::move(other._output); + + return *this; +} + +std::shared_ptr Node::buildOpByType( + OpType opType, int numInputs, int numIArgs, int numTArgs, int opNum) { + switch (opType) { + case OpType_PAIRWISE: + return std::make_shared(opNum); + case OpType_PAIRWISE_BOOL: + return std::make_shared(opNum); + case OpType_TRANSFORM_STRICT: + return std::make_shared(opNum); + case OpType_TRANSFORM_SAME: + return std::make_shared(opNum); + case OpType_TRANSFORM_FLOAT: + return std::make_shared(opNum); + case OpType_TRANSFORM_BOOL: + return std::make_shared(opNum); + case OpType_SCALAR: + return std::make_shared(opNum); + case OpType_SCALAR_BOOL: + return std::make_shared(opNum); + case OpType_REDUCE_3: + return std::make_shared(opNum); + case OpType_REDUCE_SAME: + return std::make_shared(opNum); + case OpType_REDUCE_FLOAT: + return std::make_shared(opNum); + case OpType_REDUCE_LONG: + return std::make_shared(opNum); + case OpType_REDUCE_BOOL: + return std::make_shared(opNum); + case OpType_INDEX_REDUCE: + return std::make_shared(opNum); + case OpType_SUMMARYSTATS: + return std::make_shared(opNum); + case OpType_RANDOM: + return std::make_shared(opNum); + case OpType_BROADCAST: + return std::make_shared(opNum); + case OpType_BROADCAST_BOOL: + return std::make_shared(opNum); + default: + throw std::runtime_error("Bad opType passed in"); + } +} + +} // namespace graph +} // namespace sd diff --git a/libnd4j/include/graph/impl/NodeState.cpp b/libnd4j/include/graph/impl/NodeState.cpp index d09de9c57bbc..09c4ba461f5c 100644 --- a/libnd4j/include/graph/impl/NodeState.cpp +++ b/libnd4j/include/graph/impl/NodeState.cpp @@ -21,49 +21,27 @@ #include namespace sd { - namespace graph { - NodeState::NodeState(int id) { - _id = id; - } +namespace graph { +NodeState::NodeState(int id) { _id = id; } - void NodeState::setInnerTime(Nd4jLong time) { - _inner = time; - } +void NodeState::setInnerTime(Nd4jLong time) { _inner = time; } - void NodeState::setOuterTime(Nd4jLong time) { - _outer = time; - } +void NodeState::setOuterTime(Nd4jLong time) { _outer = time; } - Nd4jLong NodeState::innerTime() { - return _inner; - } +Nd4jLong NodeState::innerTime() { return _inner; } - Nd4jLong NodeState::outerTime() { - return _outer; - } +Nd4jLong NodeState::outerTime() { return _outer; } - void NodeState::markActive(bool isActive) { - _active = isActive; - } +void NodeState::markActive(bool isActive) { _active = isActive; } - bool NodeState::isActive() { - return _active; - } +bool NodeState::isActive() { return _active; } - int NodeState::branch() { - return _branch; - } +int NodeState::branch() { return _branch; } - void NodeState::markBranch(int index) { - _branch = index; - } +void NodeState::markBranch(int index) { _branch = index; } - bool NodeState::wasExecuted() { - return _executed; - } +bool NodeState::wasExecuted() { return _executed; } - void NodeState::markExecuted(bool wasExecuted) { - _executed = wasExecuted; - } - } -} \ No newline at end of file +void NodeState::markExecuted(bool wasExecuted) { _executed = wasExecuted; } +} // namespace graph +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/graph/impl/OptimizedGraph.cpp b/libnd4j/include/graph/impl/OptimizedGraph.cpp new file mode 100644 index 000000000000..f01a6d0e8c2c --- /dev/null +++ b/libnd4j/include/graph/impl/OptimizedGraph.cpp @@ -0,0 +1,798 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Yurii Shyrma (iuriish@yahoo.com) +// @author raver119@gmail.com +// + +#include +#include +#include +#include + +namespace sd { +namespace graph { + +/////////////////////////////////////////////////////////////////// +// move constructor +OptimizedGraph::OptimizedGraph(OptimizedGraph &&other) noexcept: _sortedGraph(std::move(other._sortedGraph)), _nodesMap(std::move(other._nodesMap)) { } + +/////////////////////////////////////////////////////////////////// +// move assignment operator +OptimizedGraph& OptimizedGraph::operator=(OptimizedGraph &&other) noexcept { + + if (this == &other) + return *this; + + _sortedGraph = std::move(other._sortedGraph); + _nodesMap = std::move(other._nodesMap); + + return *this; +} + +/////////////////////////////////////////////////////////////////// +void OptimizedGraph::sortGraphWithFrames(const VariableSpace& varSpace) { + + struct NodeInfo { + std::vector _in = {}; + bool _isActive = false; + }; + + MAP_IMPL workMap; // key is node id, value is class NodeInfo containing auxiliary information (layer number this node belongs to, input/output nodes) + + // create workMap, fill vectors containing input and output nodes per each node, and find start nodes + std::unordered_map> nextItersOfFrame; // key - id of frame, value is vector containing ids of all NextIterations of this frame + for (auto& p : _nodesMap) { + + const auto& id = p.first; + const auto& inputs = p.second.inputs(); + const auto& nameOfNode = p.second.name(); + + if(nameOfNode.find("Exit") != std::string::npos) { + + const int idOfEnter = _nodesMap[_nodesMap[inputs[0].first].inputs()[0].first].inputs()[0].first; + p.second.setFrameId(_nodesMap[idOfEnter].frameId()); + _nodesMap[idOfEnter].setExitId(id); + + } else if (nameOfNode.find("NextIteration") != std::string::npos) { + + const std::string frameName = nameOfNode.substr(0, nameOfNode.find_last_of("/")); + nextItersOfFrame[frameName].push_back(id); + } + + for (int i = 0; i < inputs.size(); ++i) { + + const auto& inId = inputs[i].first; + if (_nodesMap.count(inId) != 0) { // is op + + _nodesMap[inId].pickOutput(id, inputs[i].second); + workMap[id]._in.push_back(inId); + + } else { // is variable + + const auto& depends = varSpace.getVariable(inId).get()->dependencies(); + + for (int j = 0; j < depends.size(); ++j) { + + if(std::find(workMap[id]._in.begin(), workMap[id]._in.end(), depends[j].first) == workMap[id]._in.end()) { + + _nodesMap[depends[j].first].pickOutput(id, depends[j].second); + workMap[id]._in.push_back(depends[j].first); + } + } + } + } + } + + std::vector seq; + uint numOfActive = 0; + auto it = workMap.begin(); + + // perform linear sort + while (numOfActive < workMap.size()) { + + bool makeActive = true; + + if(!it->second._isActive) { + + const auto& nameOfNode = _nodesMap[it->first].name(); + + for (const auto& inId : it->second._in) { + + if (_nodesMap[inId].name().find("NextIteration") != std::string::npos) + continue; + + if (!workMap[inId]._isActive) { + makeActive = false; + } + else if (nameOfNode.find("Exit") != std::string::npos) { + const std::string frameName = nameOfNode.substr(0, nameOfNode.find_last_of("/")); + + for (const auto& j : nextItersOfFrame[frameName]) { + if(!workMap[j]._isActive){ + makeActive = false; + break; + } + } + } + if(!makeActive) + break; + } + } + + if(makeActive && !it->second._isActive) { + ++numOfActive; + it->second._isActive = true; + seq.push_back(it->first); + } + + if(++it == workMap.end()) + it = workMap.begin(); + } + + // fill _sortedGraph + OpSequence sequence; + for (auto i : seq) + sequence.append(_nodesMap.at(i), _nodesMap.at(i).contextPrototype()); + _sortedGraph.emplace_back(ExecutionLayer({sequence})); +} + +/////////////////////////////////////////////////////////////////// +void OptimizedGraph::sortUsualGraph(const VariableSpace& varSpace) { + + struct NodeInfo { + uint _layerNum = 0; + std::vector _opSeq = {}; + std::vector _in = {}; + std::vector _out = {}; + }; + + MAP_IMPL workMap; // key is node id, value is class NodeInfo containing auxiliary information (layer number this node belongs to, input/output nodes, OpSequence that starts from this node) + + // create workMap, fill vectors containing input and output nodes per each node, and find start nodes + std::vector startNodes, idsOfExits, idsOfNextIters; + for (auto& p : _nodesMap) { + + const auto& id = p.first; + const auto& inputs = p.second.inputs(); + + if(p.second.name().find("Exit") != std::string::npos) { + idsOfExits.push_back(id); + const int idOfEnter = _nodesMap[_nodesMap[inputs[0].first].inputs()[0].first].inputs()[0].first; + p.second.setFrameId(_nodesMap[idOfEnter].frameId()); + _nodesMap[idOfEnter].setExitId(id); + } + else if(p.second.name().find("NextIteration") != std::string::npos) { + idsOfNextIters.push_back(id); + } + + for (int i = 0; i < inputs.size(); ++i) { + + const auto& inId = inputs[i].first; + if (_nodesMap.count(inId) != 0) { // is op + + _nodesMap[inId].pickOutput(id, inputs[i].second); + + if(_nodesMap[inId].name().find("NextIteration") == std::string::npos) { + workMap[inId]._out.push_back(id); + workMap[id]._in.push_back(inId); + } + } + else { // is variable + + const auto& depends = varSpace.getVariable(inId).get()->dependencies(); + + for (int j = 0; j < depends.size(); ++j) { + if(std::find(workMap[id]._in.begin(), workMap[id]._in.end(), depends[j].first) == workMap[id]._in.end()) { + _nodesMap[depends[j].first].pickOutput(id, depends[j].second); + if(_nodesMap[depends[j].first].name().find("NextIteration") == std::string::npos) { + workMap[depends[j].first]._out.push_back(id); + workMap[id]._in.push_back(depends[j].first); + } + } + } + } + } + + if(workMap[id]._in.empty()) + startNodes.push_back(id); + } + + // make all NextIteration within current frame to be inputs for all Exits within the same frame + for (const auto& i0 : idsOfExits) { + + const std::string frameName0 = _nodesMap[i0].name().substr(0, _nodesMap[i0].name().find_last_of("/")); + + for (const auto& i1 : idsOfNextIters) { + + const std::string frameName1 = _nodesMap[i1].name().substr(0, _nodesMap[i1].name().find_last_of("/")); + + if(frameName0 != frameName1) + continue; + + workMap[i0]._in.push_back(i1); + workMap[i1]._out.push_back(i0); + } + } + + // collect OpSequences (fill _opSeq) + std::vector nodesToDelete; + for (auto& p : workMap) { + + auto& out = p.second._out; + while(out.size() == 1 && workMap[out[0]]._in.size() == 1) { + nodesToDelete.push_back(out[0]); + p.second._opSeq.push_back(out[0]); + out = std::move(workMap[out[0]]._out); + } + } + + // delete nodes present in _opSeq, their ids are already stored in nodesToDelete + for (const auto& i : nodesToDelete) + workMap.erase(i); + + // lambda for topological sort + std::function visit = [&visit, &workMap, this] (const int id, const uint layerNum, uint& numOfLayers) { + + if(layerNum <= workMap[id]._layerNum) + return; + workMap[id]._layerNum = layerNum; + if(numOfLayers < layerNum) + numOfLayers = layerNum; + + // const bool isSwitch = this->_nodesMap[id].name().find("Switch") != std::string::npos; + + // if(!isSwitch) { + for (const auto& nextId : workMap[id]._out) + visit(nextId, layerNum+1, numOfLayers); + // } + // else { + // if(this->_nodesMap[id].outputs()[0].second == 1) { + // visit(_nodesMap[id].outputs()[0].first, layerNum+1, numOfLayers); // true branch + // visit(_nodesMap[id].outputs()[1].first, numOfLayers+1, numOfLayers); // false branch + // } + // else { + // visit(_nodesMap[id].outputs()[1].first, layerNum+1, numOfLayers); // true branch + // visit(_nodesMap[id].outputs()[0].first, numOfLayers+1, numOfLayers); // false branch + // } + // } + }; + + // perform topological sort + uint numOfLayers = 0; + for (const auto& id : startNodes) + for (const auto& nextId : workMap[id]._out) + visit(nextId, 1, numOfLayers); + + + // fill vectors with layers + std::vector sortedGraphTemp(numOfLayers+1); + for (const auto& p : workMap) { + + OpSequence seq; + seq.append(_nodesMap.at(p.first), _nodesMap.at(p.first).contextPrototype()); + + for (const auto& id : p.second._opSeq) + seq.append(_nodesMap.at(id), _nodesMap.at(id).contextPrototype()); + + // while(sortedGraphTemp.size() <= p.second._layerNum) + // sortedGraphTemp.emplace_back(ExecutionLayer()); + + sortedGraphTemp[p.second._layerNum].append(std::move(seq)); + } + + + // check whether there are layers with one OpSequence which in turn contains only one op + bool isLayerWithOneOp = false; + for(auto& layer : sortedGraphTemp) { + if(layer.width() == 1 && layer.at(0).length() == 1) { + isLayerWithOneOp = true; + break; + } + } + + // fill _sortedGraph + if(!isLayerWithOneOp) { + _sortedGraph = std::move(sortedGraphTemp); + } + else { + for (uint i = 0; i < sortedGraphTemp.size();) { + + OpSequence seq; + while(i < sortedGraphTemp.size() && sortedGraphTemp[i].width() == 1 && sortedGraphTemp[i].at(0).length() == 1) + seq.append(std::move(sortedGraphTemp[i++].at(0).at(0))); + + if(seq.length() != 0) + _sortedGraph.emplace_back(ExecutionLayer({seq})); + else + _sortedGraph.emplace_back(std::move(sortedGraphTemp[i++])); + } + + } + + purgeEmptyLayers(); +} + +/////////////////////////////////////////////////////////////////// +OptimizedGraph::OptimizedGraph(const MAP_IMPL& inMap, const VariableSpace& varSpace): _nodesMap(inMap) { + + bool hasEnter = false; + for (const auto& p : _nodesMap) { + if(p.second.name().find("Enter") != std::string::npos) { + hasEnter = true; + break; + } + } + + if(hasEnter) + sortGraphWithFrames(varSpace); + else + sortUsualGraph(varSpace); +} + +/////////////////////////////////////////////////////////////////// +void OptimizedGraph::append(const OpSequence &sequence) { + _sortedGraph.emplace_back(ExecutionLayer({sequence})); +} + +/////////////////////////////////////////////////////////////////// +size_t OptimizedGraph::size() const { + // std::lock_guard lock(_mutex); + + size_t size = 0; + + std::for_each(_sortedGraph.begin(), _sortedGraph.end(), [&size] (const ExecutionLayer &l) { + for (int e = 0; e < l.width(); e++) { + size += l.at(0).length(); + } + ;}); + + return size; +} + +int OptimizedGraph::nodeLayer(int nodeId) const { + int cnt = 0; + for (const auto &v:_sortedGraph) { + if (v.hasNode(nodeId)) + return cnt; + + cnt++; + } + + throw std::runtime_error("Layer of Node [" + StringUtils::valueToString(nodeId) + "] wasn't found in OptimizedGraph"); +} + +int OptimizedGraph::nodeIndex(int nodeId) const { + for (const auto &v:_sortedGraph) { + if (v.hasNode(nodeId)) { + for (int e = 0; e < v.width(); e++) { + if (v[e].hasNode(nodeId)) + return v[e].nodeIndex(nodeId); + } + } + } + + throw std::runtime_error("Index of Node [" + StringUtils::valueToString(nodeId) + "] wasn't found in OptimizedGraph"); +} + +int OptimizedGraph::nodeSequence(int nodeId) const { + for (const auto &v:_sortedGraph) { + if (v.hasNode(nodeId)) { + for (int e = 0; e < v.width(); e++) { + if (v[e].hasNode(nodeId)) + return e; + } + } + } + + throw std::runtime_error("Sequence of Node [" + StringUtils::valueToString(nodeId) + "] wasn't found in OptimizedGraph"); +} + +void OptimizedGraph::purgeEmptyLayers() { + // purge empty sequences first, if any + for (auto &v:_sortedGraph) + v.purgeEmptySequences(); + + // now purge all layers without sequences + _sortedGraph.erase(std::remove_if(_sortedGraph.begin(), _sortedGraph.end(), [](ExecutionLayer &layer) -> bool { return layer.width() == 0; }), _sortedGraph.end()); +} + +void OptimizedGraph::printOut() const { + for (uint i = 0; i < _sortedGraph.size(); ++i) { + printf("Layer [%u] {\n", i); + for (uint j = 0; j < _sortedGraph[i].width(); ++j) { + printf(" Sequence [%u] {\n", j); + _sortedGraph[i][j].printOut(); + printf(" }\n"); + } + printf("}\n"); + } + +/* + printf("And simple print:\n"); + for (int i = 0; i < _sortedGraph.size(); ++i) { + printf("layer %i: ", i); + for (int j = 0; j < _sortedGraph[i].width(); ++j) { + printf("("); + for (int k = 0; k < _sortedGraph[i][j].length(); ++k) { + printf("%i, ", _sortedGraph[i][j][k].protoContext().nodeId()); + } + printf("), "); + } + printf("\n"); + } +*/ +} + + +// std::for_each(inMap.begin(), inMap.end(), [] (const std::pair &p) {printf("node id %i \n", p.first);}); + // _sortedGraph = std::vector(numOfLayers+1); + // std::vector> printGraph = std::vector>(numOfLayers+1); + // for (const auto& p : workMap) { + + // OpSequence seq; + // seq.append(inMap.at(p.first).customOp(), inMap.at(p.first).contextPrototype()); + // printGraph[p.second._layerNum].push_back(p.first); + + // for (const auto& id : p.second._opSeq) { + // seq.append(inMap.at(id).customOp(), inMap.at(id).contextPrototype()); + // printGraph[p.second._layerNum].push_back(id); + // } + + // _sortedGraph[p.second._layerNum].append(std::move(seq)); + // } + + // for (int i = 0; i < printGraph.size(); ++i) { + // printf("layer %i: ", i); + // for (int j = 0; j < printGraph[i].size(); ++j) + // printf("%i, ", printGraph[i][j]); + // printf("\n"); + // } + + // for (const auto& p : workMap) { + // printf("node %i: , layerNum %i: , opSeq: ", p.first,p.second._layerNum); + // std::for_each(p.second._opSeq.begin(), p.second._opSeq.end(), [] (const int &j) {printf("%i, ", j);}); + // printf(", ins: "); + // std::for_each(p.second._in.begin(), p.second._in.end(), [] (const int &j) {printf("%i, ", j);}); + // printf(", outs: "); + // std::for_each(p.second._out.begin(), p.second._out.end(), [] (const int &j) {printf("%i, ", j);}); + // printf("\n"); + // } + // printf("\n-----------------\n");; + + + +// OptimizedGraph::OptimizedGraph(Graph* original) { +// _originalGraph = original; +// _memoryManager = const_cast(&original->memoryManager()); +// // create optimized graph +// createOptimizedGraph(); +// } + +// OptimizedGraph::OptimizedGraph(const OptimizedGraph& other) noexcept { +// _onion = other._onion; +// _memoryManager = other._memoryManager; +// _originalGraph = other._originalGraph; +// } + +// OptimizedGraph& OptimizedGraph::operator=( +// const OptimizedGraph& other) noexcept { +// if (this == &other) return *this; + +// _onion = other._onion; +// _memoryManager = other._memoryManager; +// _originalGraph = other._originalGraph; + +// return *this; +// } + +// OptimizedGraph::OptimizedGraph(OptimizedGraph&& other) noexcept { +// _onion = std::move(other._onion); +// _memoryManager = other._memoryManager; +// _originalGraph = other._originalGraph; +// } + +// OptimizedGraph& OptimizedGraph::operator=(OptimizedGraph&& other) noexcept { +// if (this == &other) return *this; + +// _onion = std::move(other._onion); +// _memoryManager = other._memoryManager; +// _originalGraph = other._originalGraph; + +// return *this; +// } + + + + + +// void OptimizedGraph::append(const std::vector& layer) { +// std::lock_guard lock(_mutex); +// _onion[_onion.size()] = layer; +// _size = 0; +// } + +// void OptimizedGraph::append(OpSequence& sequence) { +// append(ExecutionLayer({sequence})); +// } + +// void OptimizedGraph::append(const ExecutionLayer& layer) { +// std::lock_guard lock(_mutex); +// _onion[_onion.size()] = layer; +// _size = 0; +// } + +// const GraphMemoryManager& OptimizedGraph::memoryManager() const { +// return *_memoryManager; +// } + +// const Graph& OptimizedGraph::originalGraph() const { return *_originalGraph; } + +// bool OptimizedGraph::opGraphProto(std::unordered_map& collector, +// std::set& startNodes, +// std::set& inBranchingNodes) const { +// // double check to avoid unstable behavior +// if (originalGraph().unmappedNodes().empty()) return false; + +// const auto& unmappedNodes = originalGraph().unmappedNodes(); +// // iterate via original graph nodes to gather node information +// for (const auto& it : unmappedNodes) { +// const auto& ID = it.first; +// const auto& inputs = it.second.input(); +// // if node info is not in collecter add it +// if (collector.find(ID) == collector.end()) collector[ID] = NodeInfo(); + +// NodeInfo& parentNode = collector[ID]; +// // count external and internal inputs to find out the type of the node +// // (start, in-branching, out-branching) +// int inExCounts = 0, inInternalCounts = 0; +// for (const auto& in : inputs) { +// // find input id in original graph +// if (unmappedNodes.find(in.first) == unmappedNodes.end()) { +// // count external inputs, all inputs which id is not in unmapped +// // container will be treaded as external +// inExCounts++; +// } else { +// // count iternal inputs, all inputs that are not in external variable +// // space will be treated as outputs from other nodes +// inInternalCounts++; +// // if node info is not in collector add it +// if (collector.find(in.first) == collector.end()) +// collector[in.first] = NodeInfo(); +// // input node connection with discovered +// collector[in.first].addConnection(ID); +// } +// } +// // set operation type +// parentNode.setType(it.second.opType()); + +// // if move then 1 internal input this is in-branching node +// parentNode.setInBranching(inInternalCounts > 1); +// // gather start and in-branching nodes for the loop when operations are put +// // to OpSequence (topolSearch) +// if (inExCounts == inputs.size()) { +// startNodes.emplace(ID); +// } else { +// if (parentNode.isInBranching()) inBranchingNodes.emplace(ID); +// } +// } +// return true; +// } + +// bool OptimizedGraph::topolSearch( +// const int startNode, std::unordered_map& collector, +// std::vector>& opSeq) const { +// // double check to avoid unstable behavior +// if (originalGraph().unmappedNodes().empty() || collector.empty()) +// return false; + +// // skip nodes which are not pre-collected and pre-processed +// auto itParent = collector.find(startNode); +// if (itParent != collector.end()) { +// // iterate via start (in-branching) nodes connections in depth +// for (const auto& itNodes : itParent->second.connections()) { +// auto itChild = collector.find(itNodes); +// // double check +// if (itChild != collector.end()) { +// // if the child is in-branching node it will be treated as start node or +// // it was proceed +// if (itChild->second.isInBranching() || itChild->second.isProcessed()) { +// continue; +// } +// // put operation to OpSequence container +// const auto it = originalGraph().unmappedNodes().find(itNodes); +// auto& child = itChild->second; +// // the layer and sequence are pre-defined in layersSeqDefine method +// opSeq[child.layer()][child.sequence()].append( +// it->second.customOp(), it->second.contextPrototype()); +// child.setProcessed(); +// // go to the child node connections +// topolSearch(itNodes, collector, opSeq); +// } +// } +// } +// return true; +// } + +// void OptimizedGraph::createOptimizedGraph() { +// // container to store node infor +// std::unordered_map collector; +// // containers to store start and in-branching nodes +// std::set startNodes, inBranching; +// // container to store max sequences per layer +// std::unordered_map layersMaxSeq; + +// // optimizing graph prototyping +// // select start nodes +// // create connections between nodes +// // select in-branching nodes ( more then one iternal input -> outputs from +// // other nodes) +// if (!opGraphProto(collector, startNodes, inBranching)) +// throw std::runtime_error( +// "OptimizedGraph::optimizedGraph() - not prototyped!"); + +// // next step set the node layer and it sequence in layer +// // define max layers and max sequence per layer +// int startSeq = 0; +// bool bOnlyStartNodes = collector.empty(); +// for (const auto& id : startNodes) { +// layersMaxSeq[0] = startSeq; +// // if only start nodes exists they have to be add to connections +// if (bOnlyStartNodes) { +// auto node = NodeInfo(); +// node.setLayer(0); +// node.setProcessed(true); +// node.setSequence(startSeq); +// collector[id] = node; +// } else { +// layersSeqDefine(collector, id, 0, startSeq, layersMaxSeq); +// } +// startSeq++; +// } + +// // init container to collect operations per node position (layer:sequence) +// std::vector> vOpSeq; +// if (!initOpSeqContainer(layersMaxSeq, vOpSeq)) +// throw std::runtime_error( +// "OptimizedGraph::initOpSeqContainer() - cannot initialize OpSequence, " +// "not all nodes properly prototyped!"); + +// // combine start nodes and in-branching nodes +// startNodes.insert(inBranching.begin(), inBranching.end()); +// // re-init proceed NodeInfo member to avoid append sequence several times +// for (auto& it : collector) { +// it.second.setProcessed(false); +// } + +// // iterate via start and in-branching nodes +// for (const auto& id : startNodes) { +// const auto it = originalGraph().unmappedNodes().find(id); +// auto& nodeInfo = collector[id]; +// // append start/in-branching node operation to sequence +// if (!nodeInfo.isProcessed()) { +// vOpSeq[nodeInfo.layer()][nodeInfo.sequence()].append( +// it->second.customOp(), it->second.contextPrototype()); +// nodeInfo.setProcessed(); +// } + +// // search in depth via connections of "start" node +// if (!topolSearch(id, collector, vOpSeq)) +// throw std::runtime_error( +// "OptimizedGraph::topolSearch() - cannot run topological search, " +// "inputs incorrect!"); +// } +// // put results to optimized graph +// for (auto& vSeq : vOpSeq) { +// this->append(vSeq); +// } +// } + +// bool OptimizedGraph::initOpSeqContainer( +// const std::unordered_map& layersMaxSeq, +// std::vector>& vOpSeq) const { +// // double check to avoid unstable behavior +// if (layersMaxSeq.empty()) return false; +// // pre-init op-sequence size layers/per-layer sequence +// vOpSeq.resize(layersMaxSeq.size()); +// for (const auto& it : layersMaxSeq) { +// vOpSeq[it.first].resize(it.second + 1); +// } +// return true; +// } + +// bool OptimizedGraph::layersSeqDefine( +// std::unordered_map& collection, int ID, int layer, +// int startSeq, std::unordered_map& layersMaxSeq) const { +// // double check to avoid unstable behavior +// auto parent = collection.find(ID); +// if (parent == collection.end()) return false; + +// // if node was proceed and the current layer is less of it own return +// if (parent->second.isProcessed() && parent->second.layer() >= layer) +// return true; + +// // put layer and sequence to container that collects layers and max sequence +// // per layer +// auto layerFound = layersMaxSeq.find(layer); +// if (layerFound == layersMaxSeq.end()) { +// // if layer was not treated before, create pair for it +// layersMaxSeq[layer] = 0; +// // set sequence value to 0, as this is first sequence in layer +// startSeq = 0; +// } else { +// // if node sequence position was not checked use it for max sequence +// // selection sequence have to be incremented as max + 1, without any jumps +// if (startSeq > (layerFound->second + 1)) startSeq = layerFound->second + 1; + +// layerFound->second = +// (layerFound->second < startSeq && parent->second.sequence() < 0) +// ? startSeq +// : layerFound->second; +// } + +// // double check if the layer is higher and set node layer +// if (parent->second.layer() < layer) parent->second.setLayer(layer); +// // double check if sequence was init, if not set current sequence +// if (parent->second.sequence() < 0) parent->second.setSequence(startSeq); +// // set is node out-branching +// parent->second.setOutBranching(parent->second.connections().size() > 1); +// // set that node was processed, to avoid it double processing (only for some +// // cases it can be processed several times) +// parent->second.setProcessed(); + +// // if current node is out-branching it childs will be put to next layer +// if (parent->second.isOutBranching() && !parent->second.isLogic()) layer++; + +// // childs sequence position have to start from max defined sequence position +// // in layer or if it is first node in layer from 0 +// int seq = (layersMaxSeq.find(layer) == layersMaxSeq.end()) +// ? 0 +// : layersMaxSeq[layer]; +// // if parent is out-branching node sequence have to be increment +// // on the next stage the sequence value will be double checked with max per +// // layer todo check logic part maybe here have to be check operation class +// // (something likke Switch, If, While etc) probably for each of them could be +// // other behavior +// seq = (parent->second.isOutBranching() && !parent->second.isLogic()) ? seq + 1 +// : seq; + +// // loop via childs (connected nodes) +// for (const auto& id : parent->second.connections()) { +// // double check to avoid unstable behavior +// auto child = collection.find(id); +// if (child == collection.end()) return false; + +// // in case parent was not out-branching node but child is in branching it +// // will be put to next layer todo check logic part +// if (!parent->second.isOutBranching() && child->second.isInBranching() && +// !child->second.isLogic()) +// layer++; + +// // move in depth of connections +// layersSeqDefine(collection, id, layer, seq, layersMaxSeq); +// // increment sequence as childs are on the one layer in case if child was +// // not processed earlier todo check logic part +// if (!parent->second.isLogic()) seq++; +// } + +// return true; +// } + + + +} // namespace graph +} // namespace sd diff --git a/libnd4j/include/graph/impl/RandomGenerator.cpp b/libnd4j/include/graph/impl/RandomGenerator.cpp new file mode 100644 index 000000000000..c6e1b7b81f59 --- /dev/null +++ b/libnd4j/include/graph/impl/RandomGenerator.cpp @@ -0,0 +1,56 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include + +namespace sd { +namespace graph { +RandomGenerator::RandomGenerator(const RandomGenerator& other) noexcept { + _rootState = other._rootState; + _nodeState = other._nodeState; +} + +RandomGenerator& RandomGenerator::operator=( + const RandomGenerator& other) noexcept { + if (this == &other) return *this; + + _rootState = other._rootState; + _nodeState = other._nodeState; + + return *this; +} + +// move constructor +RandomGenerator::RandomGenerator(RandomGenerator&& other) noexcept { + _rootState = other._rootState; + _nodeState = other._nodeState; +} + +// move assignment operator +RandomGenerator& RandomGenerator::operator=(RandomGenerator&& other) noexcept { + if (this == &other) return *this; + + _rootState = other._rootState; + _nodeState = other._nodeState; + + return *this; +} +} // namespace graph +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/graph/impl/ResultWrapper.cpp b/libnd4j/include/graph/impl/ResultWrapper.cpp index 277644acf913..204445148c47 100644 --- a/libnd4j/include/graph/impl/ResultWrapper.cpp +++ b/libnd4j/include/graph/impl/ResultWrapper.cpp @@ -19,33 +19,27 @@ // #include -#include +#include namespace sd { - namespace graph { - ResultWrapper::ResultWrapper(Nd4jLong size, Nd4jPointer ptr) { - if (size <= 0) - throw std::runtime_error("FlatResult size should be > 0"); - - _size = size; - _pointer = ptr; - } - - ResultWrapper::~ResultWrapper() { - if (_pointer != nullptr && _size > 0) { - auto ptr = reinterpret_cast(_pointer); - delete[] ptr; - } - } - - - Nd4jLong ResultWrapper::size() { - return _size; - } - - Nd4jPointer ResultWrapper::pointer() { - return _pointer; - } - } -} \ No newline at end of file +namespace graph { +ResultWrapper::ResultWrapper(Nd4jLong size, Nd4jPointer ptr) { + if (size <= 0) throw std::runtime_error("FlatResult size should be > 0"); + + _size = size; + _pointer = ptr; +} + +ResultWrapper::~ResultWrapper() { + if (_pointer != nullptr && _size > 0) { + auto ptr = reinterpret_cast(_pointer); + delete[] ptr; + } +} + +Nd4jLong ResultWrapper::size() { return _size; } + +Nd4jPointer ResultWrapper::pointer() { return _pointer; } +} // namespace graph +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/graph/impl/Scope.cpp b/libnd4j/include/graph/impl/Scope.cpp deleted file mode 100644 index 84a8f2f0d074..000000000000 --- a/libnd4j/include/graph/impl/Scope.cpp +++ /dev/null @@ -1,73 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -// -// Created by raver119 on 14.10.2017. -// - -#include - -namespace sd { - namespace graph { - Scope::Scope(int id, const char *name) { - _id = id; - - if (name != nullptr) - _name = name; - else - name = ""; - } - - Scope::~Scope() { - for (auto v: _nodes) - delete v; - } - - void Scope::push_back(Node *node) { - _nodes.emplace_back(node); - } - - std::vector* Scope::nodes() { - return &_nodes; - } - - int Scope::size() { - return (int) _nodes.size(); - } - - int Scope::id() { - return _id; - } - - std::string* Scope::name() { - return &_name; - } - - void Scope::forgetNodes() { - _nodes.clear(); - } - - Scope* Scope::clone() { - auto clone = new Scope(_id, _name.c_str()); - - for (auto v: _nodes) - clone->_nodes.emplace_back(v->clone()); - - return clone; - } - } -} - diff --git a/libnd4j/include/graph/impl/SessionLocalStorage.cpp b/libnd4j/include/graph/impl/SessionLocalStorage.cpp deleted file mode 100644 index 9c512b0b6f48..000000000000 --- a/libnd4j/include/graph/impl/SessionLocalStorage.cpp +++ /dev/null @@ -1,123 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -// -// @author raver119@gmail.com -// - -#include -#include -#include - -namespace sd { - namespace graph { - SessionLocalStorage::SessionLocalStorage(VariableSpace* variableSpace, Stash* stash) { - // we start from 1, since key 0 holds original VariableSpace - _sessionCounter.store(1); - _variableSpace = variableSpace; - _stash = stash; - } - - VariableSpace* SessionLocalStorage::localVariableSpace(Nd4jLong sessionId) { - _mutex.lock(); - auto varSpace = _threadVariableSpace.at(sessionId); - _mutex.unlock(); - - return varSpace; - } - - VariableSpace* SessionLocalStorage::localVariableSpace() { - return localVariableSpace(getSessionId()); - } - - SessionLocalStorage::~SessionLocalStorage() { - for (const auto & v: _threadVariableSpace) { - delete v.second; - } - } - - - Nd4jLong SessionLocalStorage::getThreadId() { -#ifdef __APPLE__ - // syscall? -#elif _WIN32 - // some win32api -#else - // syscall! -#endif - auto id=std::this_thread::get_id(); - uint64_t* ptr=(uint64_t*) &id; - return (*ptr); - } - - int SessionLocalStorage::numberOfSessions() { - _mutex.lock(); - int size = (int) _threadSession.size(); - _mutex.unlock(); - return size; - } - - void SessionLocalStorage::endSession(Nd4jLong sessionId) { - // we should delete specific holders here - _mutex.lock(); - auto vs = _threadVariableSpace[sessionId]; - _threadVariableSpace.erase(sessionId); - - delete vs; - _mutex.unlock(); - } - - void SessionLocalStorage::endSession() { - auto tid = getThreadId(); - - _mutex.lock(); - - auto ntid = _threadSession[tid]; - _threadSession.erase(tid); - - _mutex.unlock(); - - endSession(ntid); - } - - Nd4jLong SessionLocalStorage::getSessionId() { - auto tid = getThreadId(); - - _mutex.lock(); - auto ntid = _threadSession[tid]; - - _mutex.unlock(); - - return ntid; - } - - Nd4jLong sd::graph::SessionLocalStorage::startSession() { - auto tid = getThreadId(); - - nd4j_debug("Adding ThreadId: %i;\n", (int) tid); - Nd4jLong ntid = _sessionCounter++; - _mutex.lock(); - - _threadSession[tid] = ntid; - _threadVariableSpace[ntid] = _variableSpace->clone(); - - _mutex.unlock(); - - return ntid; - } - } -} - diff --git a/libnd4j/include/graph/impl/Stash.cpp b/libnd4j/include/graph/impl/Stash.cpp index 618e01c8745c..5215e988a4dd 100644 --- a/libnd4j/include/graph/impl/Stash.cpp +++ b/libnd4j/include/graph/impl/Stash.cpp @@ -20,40 +20,38 @@ #include - namespace std { - size_t hash::operator()(const sd::graph::KeyPair& k) const { - using std::hash; - auto res = std::hash()(k.name()); - res ^= std::hash()(k.key()) + 0x9e3779b9 + (res << 6) + (res >> 2); - return res; - } +size_t hash::operator()(const sd::graph::KeyPair &k) const { + using std::hash; + auto res = std::hash()(k.name()); + res ^= std::hash()(k.key()) + 0x9e3779b9 + (res << 6) + (res >> 2); + return res; } +} // namespace std namespace sd { - namespace graph { - sd::graph::KeyPair::KeyPair(int node, const char * name) { - _node = node; - _name = std::string(name); - } +namespace graph { +sd::graph::KeyPair::KeyPair(int node, const char *name) { + _node = node; + _name = std::string(name); +} - bool sd::graph::KeyPair::operator<(const KeyPair& other) const { - if (_node < other._node) - return true; - else if (_node > other._node) - return false; - else - return _name < other._name; - } +bool sd::graph::KeyPair::operator<(const KeyPair &other) const { + if (_node < other._node) + return true; + else if (_node > other._node) + return false; + else + return _name < other._name; +} - sd::graph::Stash::Stash() { - // - } +sd::graph::Stash::Stash() { + // +} - sd::graph::Stash::~Stash() { - if (_handles.size() > 0) - this->clear(); - } +sd::graph::Stash::~Stash() { + if (_handles.size() > 0) this->clear(); +} /* bool sd::graph::Stash::checkStash(sd::graph::Block& block, const char *name) { @@ -61,40 +59,40 @@ bool sd::graph::Stash::checkStash(sd::graph::Block& block, const char *name) { } */ - bool sd::graph::Stash::checkStash(int nodeId, const char *name) { - KeyPair kp(nodeId, name); - return _stash.count(kp) > 0; - } +bool sd::graph::Stash::checkStash(int nodeId, const char *name) { + KeyPair kp(nodeId, name); + return _stash.count(kp) > 0; +} /* -sd::NDArray* sd::graph::Stash::extractArray(sd::graph::Block& block, const char *name) { - return extractArray(block.getNodeId(), name); +sd::NDArray* sd::graph::Stash::extractArray(sd::graph::Block& block, const char +*name) { return extractArray(block.getNodeId(), name); } */ - sd::NDArray* sd::graph::Stash::extractArray(int nodeId, const char *name) { - KeyPair kp(nodeId, name); - return _stash[kp]; - } +sd::NDArray *sd::graph::Stash::extractArray(int nodeId, const char *name) { + KeyPair kp(nodeId, name); + return _stash[kp]; +} /* -void sd::graph::Stash::storeArray(sd::graph::Block& block, const char *name, sd::NDArray *array) { - storeArray(block.getNodeId(), name, array); +void sd::graph::Stash::storeArray(sd::graph::Block& block, const char *name, +sd::NDArray *array) { storeArray(block.getNodeId(), name, array); } */ - void sd::graph::Stash::storeArray(int nodeId, const char *name, sd::NDArray *array) { - KeyPair kp(nodeId, name); - _stash[kp] = array; +void sd::graph::Stash::storeArray(int nodeId, const char *name, + sd::NDArray *array) { + KeyPair kp(nodeId, name); + _stash[kp] = array; - // storing reference to delete it once it's not needed anymore - _handles.push_back(array); - } + // storing reference to delete it once it's not needed anymore + _handles.push_back(array); +} - void sd::graph::Stash::clear() { - for (auto v: _handles) - delete v; +void sd::graph::Stash::clear() { + for (auto v : _handles) delete v; - _handles.clear(); - _stash.clear(); - } - } -} \ No newline at end of file + _handles.clear(); + _stash.clear(); +} +} // namespace graph +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/graph/impl/TimeHolder.cpp b/libnd4j/include/graph/impl/TimeHolder.cpp index a292a89974cf..a01ae29f7bd4 100644 --- a/libnd4j/include/graph/impl/TimeHolder.cpp +++ b/libnd4j/include/graph/impl/TimeHolder.cpp @@ -21,28 +21,26 @@ #include namespace sd { - namespace graph { +namespace graph { - void TimeHolder::setOuterTime(int nodeId, Nd4jLong time) { - _outer[nodeId] = time; - } +void TimeHolder::setOuterTime(int nodeId, Nd4jLong time) { + _outer[nodeId] = time; +} - void TimeHolder::setInnerTime(int nodeId, Nd4jLong time) { - _inner[nodeId] = time; - } +void TimeHolder::setInnerTime(int nodeId, Nd4jLong time) { + _inner[nodeId] = time; +} - Nd4jLong TimeHolder::outerTime(int nodeId) { - if (_outer.count(nodeId) == 0) - return 0; +Nd4jLong TimeHolder::outerTime(int nodeId) { + if (_outer.count(nodeId) == 0) return 0; - return _outer[nodeId]; - } + return _outer[nodeId]; +} - Nd4jLong TimeHolder::innerTime(int nodeId) { - if (_inner.count(nodeId) == 0) - return 0; +Nd4jLong TimeHolder::innerTime(int nodeId) { + if (_inner.count(nodeId) == 0) return 0; - return _inner[nodeId]; - } - } + return _inner[nodeId]; } +} // namespace graph +} // namespace sd diff --git a/libnd4j/include/graph/impl/Variable.cpp b/libnd4j/include/graph/impl/Variable.cpp index e87c51897ebb..dc6f74e2541f 100644 --- a/libnd4j/include/graph/impl/Variable.cpp +++ b/libnd4j/include/graph/impl/Variable.cpp @@ -18,343 +18,335 @@ // @author raver119@gmail.com // -#include -#include -#include #include #include +#include #include +#include +#include #include namespace sd { - namespace graph { - - template - Variable* Variable::asT() { - auto result = new Variable(this->isPlaceholder()); - - result->markExternal(this->_external); - result->setId(this->_id); - result->markReadOnly(this->_readOnly); - result->setName(&this->_name); - result->setIndex(this->_index); - - if (this->_ndarray != nullptr) - result->setNDArray(new NDArray(this->_ndarray->template asT())); - - // FIXME: add support for ArrayList - if (this->_list != nullptr) { - nd4j_printf("ArrayList not supported yet\n", ""); - throw std::runtime_error("ArrayList not supported yet for asT"); - } - - return result; - } - BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT Variable* Variable::asT, (), LIBND4J_TYPES); - - sd::graph::Variable* sd::graph::Variable::clone() { - auto result = new Variable(this->isPlaceholder()); - result->_external = this->_external; - result->_id = this->_id; - result->_readOnly = this->_readOnly; - result->_name = this->_name; - result->_index = this->_index; - - if (this->_ndarray != nullptr) { - result->_ndarray = new NDArray(this->_ndarray->dup(this->_ndarray->ordering())); - result->_readOnly = false; - result->_removable = true; - } - - if (this->_list != nullptr) - result->_list = this->_list->clone(); - - return result; - } - - void sd::graph::Variable::setIndex(int index) { - _index = index; - } - - bool sd::graph::Variable::hasNDArray() { - return _ndarray != nullptr; - } - - void sd::graph::Variable::setVariableType(VariableType variableType) { - _variableType = variableType; - } - - bool sd::graph::Variable::hasNDArrayList() { - return _list != nullptr; - } - - bool sd::graph::Variable::isPlaceholder() { - return _placeholder; - } - - std::string * sd::graph::Variable::getName() { - return &_name; - } - - void sd::graph::Variable::setName(std::string *name) { - _name = *name; - } - - int sd::graph::Variable::id() { - return _id; - } - - int sd::graph::Variable::index() { - return _index; - } +namespace graph { +Variable::Variable(const NDArrayList &arrayList, const std::string &name, int id, int idx) + : Variable(std::make_shared(arrayList), name, id, idx) { } + +Variable::Variable(std::shared_ptr list, const std::string &name, int id, int idx) { + _list = list; + if (!name.empty()) _name = name; + + _id = id; + _index = idx; + _variableType = VariableType::ARRAY_LIST; +} + +Variable::Variable(const NDArray &array, const std::string &name, int id, + int idx) { + _ndarray = std::make_shared(array); + + if (!name.empty()) _name = name; + + _id = id; + _index = idx; +} + +Variable::Variable() { + // +} + +void sd::graph::Variable::setIndex(int index) { _index = index; } + +bool sd::graph::Variable::hasNDArray() const { + return _ndarray.get() != nullptr; +} + +void sd::graph::Variable::setVariableType(VariableType variableType) { + _variableType = variableType; +} - void sd::graph::Variable::setId(int id) { - _id = id; - } - - bool sd::graph::Variable::isEmpty() { - if (_variableType == VariableType::NDARRAY) - return _ndarray == nullptr || !_ndarray->nonNull(); - else if (_variableType == VariableType::ARRAY_LIST) - return _list == nullptr; - - return false; - } - - bool sd::graph::Variable::isExternal() { - return _external; - } - - bool sd::graph::Variable::isReadOnly() { - return _readOnly; - } - - void sd::graph::Variable::markExternal(bool reallyExternal) { - this->_external = reallyExternal; - } - - void sd::graph::Variable::markRemovable(bool reallyRemovable) { - if (!reallyRemovable) - nd4j_debug("",""); - this->_removable = reallyRemovable; - } - - void sd::graph::Variable::markReadOnly(bool reallyReadOnly) { - this->_readOnly = reallyReadOnly; - } - - sd::NDArray * sd::graph::Variable::getNDArray() { - if (_variableType != VariableType::NDARRAY) { - nd4j_printf("Variable[%i:%i/<%s>] is has [%s] type, but NDArray was requested\n", this->_id, this->_index, this->_name.c_str(), EnumUtils::_VariableTypeToString(_variableType)); - } - - if (this->_ndarray == nullptr) { - if (_name.empty()) { - auto nodeId = StringUtils::valueToString(this->id()); - auto outputIndex = StringUtils::valueToString(this->index()); - throw std::runtime_error("Array doesn't exist for Variable <" + nodeId + ":" + outputIndex + ">"); - } else { - auto outputIndex = StringUtils::valueToString(this->index()); - throw std::runtime_error("Array doesn't exist for Variable <" + this->_name + ":" + outputIndex+ ">"); - } - } - - return this->_ndarray; - } - - sd::NDArrayList * sd::graph::Variable::getNDArrayList() { - if (_variableType != VariableType::ARRAY_LIST) { - nd4j_debug("Variable[%i:%i/<%s>] is has [%s] type, but NDArrayList was requested\n", this->_id, this->_index, this->_name.c_str(), EnumUtils::_VariableTypeToString(_variableType)); - } - return this->_list; - } - - - bool Variable::isRemovable() { - return _removable; - } - - - void sd::graph::Variable::setNDArrayList(sd::NDArrayList * list) { - this->_variableType = VariableType::ARRAY_LIST; - this->_list = list; - } - - - void sd::graph::Variable::setNDArray(sd::NDArray * array) { - this->_variableType = VariableType::NDARRAY; - this->_ndarray = array; - } - - - VariableType sd::graph::Variable::variableType() { - return _variableType; - } - - - sd::graph::Variable::Variable(const sd::graph::FlatVariable *flatVariable) { - auto vid = flatVariable->id(); - this->_id = vid->first(); - this->_index = vid->second(); - - if (flatVariable->name() != nullptr && flatVariable->name()->size() != 0) - this->_name = flatVariable->name()->str(); - - _external = true; - _readOnly = false; - - int8_t *buffer = nullptr; - - switch (flatVariable->variabletype()) { - case VarType_VARIABLE: { - - // ????? - if (flatVariable->ndarray() != nullptr) { - auto ar = flatVariable->ndarray(); - _ndarray = sd::graph::FlatUtils::fromFlatArray(ar); - } - - _variableType = VariableType::NDARRAY; - } - break; - case VarType_CONSTANT: { - if (flatVariable->ndarray() == nullptr) - throw std::runtime_error("CONSTANT variable must have NDArray bundled"); - - auto ar = flatVariable->ndarray(); - if (ar->dtype() == DType_UTF8) { - _ndarray = sd::graph::FlatUtils::fromFlatArray(ar); - } else { - _ndarray = sd::graph::FlatUtils::fromFlatArray(ar); - } +bool sd::graph::Variable::hasNDArrayList() const { return _list != nullptr; } - _variableType = VariableType::NDARRAY; - } - break; - case VarType_ARRAY: { +bool sd::graph::Variable::isPlaceholder() const { return _placeholder || _variableType == sd::graph::VariableType::PLACEHOLDER; } - // ????? - if (flatVariable->ndarray() != nullptr) { - auto ar = flatVariable->ndarray(); - _ndarray = sd::graph::FlatUtils::fromFlatArray(ar); - // _ndarray->triggerAllocationFlag(true); - } - - _variableType = VariableType::NDARRAY; - } - break; - case VarType_PLACEHOLDER: { - if (flatVariable->shape() == nullptr && flatVariable->ndarray() == nullptr) - throw std::runtime_error("PLACEHOLDER variable must have shape defined"); - - if (flatVariable->ndarray() != nullptr) { - auto ar = flatVariable->ndarray(); - _ndarray = sd::graph::FlatUtils::fromFlatArray(ar); - // _ndarray->triggerAllocationFlag(true); - - _variableType = VariableType::NDARRAY; - } +const std::string &sd::graph::Variable::name() const { return _name; } - if (flatVariable->shape() != nullptr) { - int shapeLen = flatVariable->shape()->Length(); - for (int i = 0; i < flatVariable->shape()->size(); i++) - _shape.emplace_back(flatVariable->shape()->Get(i)); +const std::string &sd::graph::Variable::getName() const { return _name; } - if (_ndarray == nullptr) - _variableType = VariableType::PLACEHOLDER; - } - } - break; - default: - throw std::runtime_error("Unknown variable type used"); - } - } +void sd::graph::Variable::setName(const std::string &name) { _name = name; } - std::vector& sd::graph::Variable::shape() { - return _shape; - } +int sd::graph::Variable::id() const { return _id; } - sd::graph::Variable::Variable(bool placeholder) { - _placeholder = placeholder; - } +int sd::graph::Variable::index() const { return _index; } +void sd::graph::Variable::setId(int id) { _id = id; } - sd::graph::Variable::Variable(NDArray *array, const char *name ) { - _ndarray = array; - - _external = false; - _readOnly = false; +bool sd::graph::Variable::isEmpty() const { + if (_variableType == VariableType::NDARRAY) + return _ndarray == nullptr || !_ndarray->nonNull(); + else if (_variableType == VariableType::ARRAY_LIST) + return _list == nullptr; + + return false; +} - if (name != nullptr) - _name = std::string(name); - - if (_ndarray != nullptr) - _variableType = VariableType::NDARRAY; - } - - - sd::graph::Variable::Variable(NDArray *array, const char *name, int id, int idx) : Variable(array, name) { - _id = id; - _index = idx; - } - - - sd::graph::Variable::~Variable() { - //nd4j_printf("Removing variable [%i:%i]\n", _id, _index); - if (_variableType == VariableType::NDARRAY) { - nd4j_debug("Removing variable <%i:%i>\n", _id, _index); - if (_ndarray != nullptr && _removable && !_readOnly) - delete _ndarray; - } - } - - - void Variable::setId(int id, int idx) { - _id = id; - _index = idx; - } - - - flatbuffers::Offset Variable::asFlatVariable(flatbuffers::FlatBufferBuilder &builder) { - if (this->hasNDArray()) { - auto array = this->getNDArray(); - auto fShape = builder.CreateVector(array->getShapeInfoAsFlatVector()); - - auto fBuffer = builder.CreateVector(array->asByteVector()); - - // packing array - auto fArray = CreateFlatArray(builder, fShape, fBuffer, (sd::graph::DType) array->dataType()); - - // packing id/index of this var - auto fVid = CreateIntPair(builder, this->_id, this->_index); - - // name is still optional - flatbuffers::Offset stringId = 0; - if (!this->_name.empty()) - stringId = builder.CreateString(this->_name); +bool sd::graph::Variable::isExternal() const { return _external; } - // returning array - return CreateFlatVariable(builder, fVid, stringId, static_cast(array->dataType()), 0, fArray); - } else { - throw std::runtime_error("Variable::asFlatVariable isn't possible for NDArrayList"); - } - } +bool sd::graph::Variable::isReadOnly() const { return _readOnly; } + +void sd::graph::Variable::markExternal(bool reallyExternal) { + this->_external = reallyExternal; +} + +void sd::graph::Variable::markRemovable(bool reallyRemovable) { + if (!reallyRemovable) nd4j_debug("", ""); + this->_removable = reallyRemovable; +} + +void sd::graph::Variable::markReadOnly(bool reallyReadOnly) { + this->_readOnly = reallyReadOnly; +} + +std::shared_ptr sd::graph::Variable::getNDArray() const { + if (_variableType != VariableType::NDARRAY) { + nd4j_printf( + "Variable[%i:%i/<%s>] is has [%s] type, but NDArray was requested\n", + this->_id, this->_index, this->_name.c_str(), + EnumUtils::_VariableTypeToString(_variableType)); + } + + if (this->_ndarray.get() == nullptr) { + if (_name.empty()) { + auto nodeId = StringUtils::valueToString(this->id()); + auto outputIndex = StringUtils::valueToString(this->index()); + throw std::runtime_error("Array doesn't exist for Variable <" + nodeId + + ":" + outputIndex + ">"); + } else { + auto outputIndex = StringUtils::valueToString(this->index()); + throw std::runtime_error("Array doesn't exist for Variable <" + + this->_name + ":" + outputIndex + ">"); } + } + + return this->_ndarray; +} + +std::shared_ptr sd::graph::Variable::getNDArrayList() const { + if (_variableType != VariableType::ARRAY_LIST) { + nd4j_debug( + "Variable[%i:%i/<%s>] is has [%s] type, but NDArrayList was " + "requested\n", + this->_id, this->_index, this->_name.c_str(), + EnumUtils::_VariableTypeToString(_variableType)); + } + return this->_list; +} + +bool Variable::isRemovable() const { return _removable; } + +void sd::graph::Variable::setNDArrayList( + std::shared_ptr list) { + this->_variableType = VariableType::ARRAY_LIST; + this->_list = list; +} + +const std::vector>& Variable::dependencies() const { + return _dependencies; +} + +void Variable::actualizeDependencies(const MAP_IMPL &lookupTable) const { + for (const auto &v: _stringDependencies) { + if (lookupTable.count(v) == 0) + throw std::runtime_error("Unknown Variable dependency found: [" + v + "]"); + + const_cast(this)->_dependencies.emplace_back(std::pair{lookupTable.at(v), 0}); + } +} + +void sd::graph::Variable::setNDArray(std::shared_ptr array) { + this->_variableType = VariableType::NDARRAY; + this->_ndarray = array; +} + +VariableType sd::graph::Variable::variableType() const { return _variableType; } + +sd::graph::Variable::Variable(const sd::graph::FlatVariable *flatVariable) { + auto vid = flatVariable->id(); + this->_id = vid->first(); + this->_index = vid->second(); + + if (flatVariable->name() != nullptr && flatVariable->name()->size() != 0) + this->_name = flatVariable->name()->str(); + + _external = true; + _readOnly = false; + + int8_t *buffer = nullptr; + + // reading control deps, and filling _dependencies field + if (flatVariable->controlDepsForVar() != nullptr && flatVariable->controlDepsForVar()->size() > 0) { + for (int e = 0; e < flatVariable->controlDepsForVar()->size(); e++) + _stringDependencies.emplace_back(flatVariable->controlDepsForVar()->Get(e)->str()); + } + + if (flatVariable->controlDepForOp() != nullptr && flatVariable->controlDepForOp()->size() > 0) { + for (int e = 0; e < flatVariable->controlDepForOp()->size(); e++) + _stringDependencies.emplace_back(flatVariable->controlDepForOp()->Get(e)->str()); + } + + if (flatVariable->controlDeps() != nullptr && flatVariable->controlDeps()->size() > 0) { + for (int e = 0; e < flatVariable->controlDeps()->size(); e++) + _stringDependencies.emplace_back(flatVariable->controlDeps()->Get(e)->str()); + } + + switch (flatVariable->variabletype()) { + case VarType_VARIABLE: { + // ????? + if (flatVariable->ndarray() != nullptr) { + auto ar = flatVariable->ndarray(); + _ndarray = std::make_shared( + sd::graph::FlatUtils::fromFlatArray(ar)); + } + + _variableType = VariableType::NDARRAY; + } break; + case VarType_CONSTANT: { + if (flatVariable->ndarray() == nullptr) + throw std::runtime_error("CONSTANT variable must have NDArray bundled"); + + auto ar = flatVariable->ndarray(); + if (ar->dtype() == DType_UTF8) { + _ndarray = std::make_shared( + sd::graph::FlatUtils::fromFlatArray(ar)); + } else { + _ndarray = std::make_shared( + sd::graph::FlatUtils::fromFlatArray(ar)); + } + + _variableType = VariableType::NDARRAY; + } break; + case VarType_ARRAY: { + // ????? + if (flatVariable->ndarray() != nullptr) { + auto ar = flatVariable->ndarray(); + _ndarray = std::make_shared( + sd::graph::FlatUtils::fromFlatArray(ar)); + // _ndarray->triggerAllocationFlag(true); + } + + _variableType = VariableType::NDARRAY; + } break; + case VarType_PLACEHOLDER: { + if (flatVariable->shape() == nullptr && flatVariable->ndarray() == nullptr) + throw std::runtime_error("PLACEHOLDER variable must have shape defined"); + + if (flatVariable->ndarray() != nullptr) { + auto ar = flatVariable->ndarray(); + _ndarray = std::make_shared(FlatUtils::fromFlatArray(ar)); + + _variableType = VariableType::NDARRAY; + } + + if (flatVariable->shape() != nullptr) { + int shapeLen = flatVariable->shape()->Length(); + for (int i = 0; i < flatVariable->shape()->size(); i++) + _shape.emplace_back(flatVariable->shape()->Get(i)); + + _dtype = (sd::DataType) flatVariable->dtype(); + + if (_ndarray == nullptr) _variableType = VariableType::PLACEHOLDER; + } + } break; + default: + throw std::runtime_error("Unknown variable type used"); + } +} + +const std::vector &sd::graph::Variable::shape() const { + return _shape; +} + +sd::graph::Variable::Variable(bool placeholder, DataType dataType, + const std::vector &shape) { + _placeholder = placeholder; + _dtype = dataType; + _shape = shape; +} + +sd::graph::Variable::Variable(std::shared_ptr array, + const char *name) { + _ndarray = array; + + _external = false; + _readOnly = false; + + if (name != nullptr) _name = std::string(name); + + if (_ndarray != nullptr) _variableType = VariableType::NDARRAY; +} + +DataType Variable::dataType() const { return _dtype; } + +sd::graph::Variable::Variable(std::shared_ptr array, + const std::string &name, int id, int idx) + : Variable(array, name.c_str()) { + _id = id; + _index = idx; +} + +sd::graph::Variable::~Variable() { + // +} + +void Variable::setId(int id, int idx) { + _id = id; + _index = idx; } +flatbuffers::Offset Variable::asFlatVariable( + flatbuffers::FlatBufferBuilder &builder) { + if (this->hasNDArray()) { + auto array = this->getNDArray(); + auto fShape = builder.CreateVector(array->getShapeInfoAsFlatVector()); + + auto fBuffer = builder.CreateVector(array->asByteVector()); + + // packing array + auto fArray = CreateFlatArray(builder, fShape, fBuffer, + (sd::graph::DType)array->dataType()); + + // packing id/index of this var + auto fVid = CreateIntPair(builder, this->_id, this->_index); + + // name is still optional + flatbuffers::Offset stringId = 0; + if (!this->_name.empty()) stringId = builder.CreateString(this->_name); + + // returning array + return CreateFlatVariable(builder, fVid, stringId, + static_cast(array->dataType()), + 0, fArray); + } else { + throw std::runtime_error( + "Variable::asFlatVariable isn't possible for NDArrayList"); + } +} +} // namespace graph +} // namespace sd + namespace std { - size_t hash>::operator()(const std::pair& k) const { - auto v = std::hash()(k.first); - v ^= std::hash()(k.second) + 0x9e3779b9 + (v << 6) + (v >> 2); - return v; - } +size_t hash>::operator()( + const std::pair &k) const { + auto v = std::hash()(k.first); + v ^= std::hash()(k.second) + 0x9e3779b9 + (v << 6) + (v >> 2); + return v; +} - size_t hash::operator()(const bfloat16& k) const { - return std::hash()((float)k); - } +size_t hash::operator()(const bfloat16 &k) const { + return std::hash()((float)k); +} - size_t hash::operator()(const float16& k) const { - return std::hash()((float)k); - } -} \ No newline at end of file +size_t hash::operator()(const float16 &k) const { + return std::hash()((float)k); +} +} // namespace std \ No newline at end of file diff --git a/libnd4j/include/graph/impl/VariableProxy.cpp b/libnd4j/include/graph/impl/VariableProxy.cpp index 2736e2a9e32b..809a0c1fa1ae 100644 --- a/libnd4j/include/graph/impl/VariableProxy.cpp +++ b/libnd4j/include/graph/impl/VariableProxy.cpp @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -18,269 +19,221 @@ // @author raver119@gmail.com // -#include #include +#include namespace sd { - namespace graph { - - VariableProxy::VariableProxy(VariableSpace* ref) { - if (ref == nullptr) - _backed = new VariableSpace(); - - _backed = ref; - _current = new VariableSpace(); - } - - - VariableProxy::~VariableProxy() { - delete _current; - } - - - int VariableProxy::numberOfPlaceholders() { - return _backed->numberOfPlaceholders(); - } - - - std::vector* VariableProxy::getPlaceholders() { - return _backed->getPlaceholders(); - } - - bool VariableProxy::hasExternalVariable(int it) { - return _backed->hasExternalVariable(it); - } - - - bool VariableProxy::hasExternalVariable(std::pair& pair) { - return _backed->hasExternalVariable(pair); - } - - - bool VariableProxy::hasExternalVariable(std::string *symbol) { - return _backed->hasExternalVariable(symbol); - } - - - bool VariableProxy::hasVariable(int id) { - return _current->hasVariable(id) || _backed->hasVariable(id); - } - - - bool VariableProxy::hasVariable(int id, int idx) { - return _current->hasVariable(id, idx) || _backed->hasVariable(id, idx); - } - - - bool VariableProxy::hasVariable(std::pair& pair) { - return _current->hasVariable(pair) || _backed->hasVariable(pair); - } - - - void VariableProxy::dropVariable(std::pair &pair) { - dropVariable(pair.first, pair.second); - } - - - void VariableProxy::dropVariable(int id, int idx) { - assert(_current->hasVariable(id, idx)); - - _current->dropVariable(id, idx); - } - - - std::vector VariableProxy::getVariables() { - std::vector result; - - auto b = _backed->getVariables(); - auto c = _current->getVariables(); - - for (auto v: b) - result.emplace_back(v); - - for (auto v: c) - result.emplace_back(v); - - return result; - } - - - bool VariableProxy::hasVariable(std::string *symbol) { - return _current->hasVariable(symbol) || _backed->hasVariable(symbol); - } - - - sd::graph::Variable *VariableProxy::getVariable(int id) { - if (_current->hasVariable(id)) - return _current->getVariable(id); - - if (_backed->hasVariable(id)) - return _backed->getVariable(id); - - nd4j_printf("Unable to get Variable to proxy: [%i]\n", id); - throw std::runtime_error("Bad arguments"); - } - - - sd::graph::Variable *VariableProxy::getVariable(int id, int idx) { - if (_current->hasVariable(id, idx)) - return _current->getVariable(id, idx); - - if (_backed->hasVariable(id, idx)) - return _backed->getVariable(id, idx); - - nd4j_printf("Unable to get Variable to proxy: [%i:%i]\n", id, idx); - throw std::runtime_error("Bad arguments"); - } - - - sd::graph::Variable *VariableProxy::getVariable(std::pair& pair) { - if (_current->hasVariable(pair)) - return _current->getVariable(pair); - - if (_backed->hasVariable(pair)) - return _backed->getVariable(pair); - - nd4j_printf("Unable to get Variable to proxy: [%i:%i]\n", pair.first, pair.second); - throw std::runtime_error("Bad arguments"); - } - - - sd::graph::Variable *VariableProxy::getVariable(std::string *symbol) { - if (_current->hasVariable(symbol)) - return _current->getVariable(symbol); - - if (_backed->hasVariable(symbol)) - return _backed->getVariable(symbol); - - nd4j_printf("Unable to get Variable to proxy: [%s]\n", symbol->c_str()); - throw std::runtime_error("Bad arguments"); - } - - - void VariableProxy::replaceVariable(Variable *variable) { - if (variable->getName() != nullptr && !variable->getName()->empty()) { - // if variable has name defined - we should resolve it via backing var space - if (_backed->hasVariable(variable->getName())) { - auto origVar = _backed->getVariable(variable->getName()); - variable->setId(origVar->id(), origVar->index()); - _current->replaceVariable(variable); - } else - _current->replaceVariable(variable); - } else // if proxy has variable - that's one story - _current->replaceVariable(variable); - } - - - Variable* VariableProxy::putVariable(std::pair& pair, NDArray *array) { - return _current->putVariable(pair, array); - } - - - void VariableProxy::putVariable(std::pair& pair, Variable *variable) { - _current->putVariable(pair, variable); - } - - - void VariableProxy::putVariable(int id, Variable *variable) { - _current->putVariable(id, variable); - } - - - void VariableProxy::putVariable(int id, NDArray *array) { - _current->putVariable(id, array); - } - - void sd::graph::VariableProxy::putVariable(int id, int idx, NDArray &array) { - _current->putVariable(id, idx, array); - } - - Variable* VariableProxy::putVariable(int id, int idx, NDArray *array) { - return _current->putVariable(id, idx, array); - } - - - void VariableProxy::putVariable(int id, int idx, Variable *array) { - _current->putVariable(id, idx, array); - } - - - void VariableProxy::trackList(sd::NDArrayList* list) { - _current->trackList(list); - } - - - sd::graph::Stash* VariableProxy::getStash() { - return _current->getStash(); - } - - - void VariableProxy::setFlowPath(FlowPath* timers) { - _current->setFlowPath(timers); - } - - - FlowPath* VariableProxy::flowPath() { - return _current->flowPath(); - } - - - void VariableProxy::putOutputVariable(Variable *variable) { - _current->putOutputVariable(variable); - } - - - Nd4jLong VariableProxy::externalMemory() { - return _backed->externalMemory() + _current->externalMemory(); - } - - - Nd4jLong VariableProxy::internalMemory() { - return _backed->internalMemory() + _current->internalMemory(); - } - - - Nd4jLong VariableProxy::totalMemory() { - return _backed->totalMemory() + _current->totalMemory(); - } - - - int VariableProxy::externalEntries() { - return _backed->externalEntries() + _current->externalEntries(); - } - - - int VariableProxy::internalEntries() { - return _backed->internalEntries() + _current->internalEntries(); - } - - - int VariableProxy::totalEntries() { - return _backed->totalEntries() + _current->totalEntries(); - } - - - sd::graph::VariableSpace* VariableProxy::clone() { - auto clone = new VariableProxy(_backed); - - delete clone->_current; - clone->_current = _current->clone(); - - return clone; - } - - - VariableSpace& VariableProxy::operator=(const VariableSpace& other) { - if (this == &other) return *this; - - nd4j_printf("VariableProxy = not implemented\n",""); - - return *this; - } - - - sd::memory::Workspace * sd::graph::VariableProxy::workspace() { - return _workspace; - } - } +namespace graph { + +VariableProxy::VariableProxy(const VariableSpace *ref) { + if (ref == nullptr) _backed = new VariableSpace(); + + _backed = ref; +} + +VariableProxy::~VariableProxy() { } + +int VariableProxy::numberOfPlaceholders() const { + return _backed->numberOfPlaceholders(); +} + +const std::vector> &VariableProxy::placeholders() + const { + return _backed->placeholders(); +} + +bool VariableProxy::hasExternalVariable(int it) const { + return _backed->hasExternalVariable(it); +} + +bool VariableProxy::hasExternalVariable(const std::pair &pair) const { + return _backed->hasExternalVariable(pair); +} + +bool VariableProxy::hasExternalVariable(const std::string &symbol) const { + return _backed->hasExternalVariable(symbol); +} + +bool VariableProxy::hasVariable(int id) const { + return _current.hasVariable(id) || _backed->hasVariable(id); +} + +bool VariableProxy::hasVariable(int id, int idx) const { + return _current.hasVariable(id, idx) || _backed->hasVariable(id, idx); +} + +bool VariableProxy::hasVariable(const std::pair &pair) const { + return _current.hasVariable(pair) || _backed->hasVariable(pair); +} + +void VariableProxy::dropVariable(const std::pair &pair) { + dropVariable(pair.first, pair.second); +} + +void VariableProxy::dropVariable(int id, int idx) { + assert(_current.hasVariable(id, idx)); + + _current.dropVariable(id, idx); +} + +std::vector> VariableProxy::variables() const { + std::vector> result; + + auto b = _backed->variables(); + auto c = _current.variables(); + + for (auto v : b) result.emplace_back(v); + + for (auto v : c) result.emplace_back(v); + + return result; +} + +bool VariableProxy::hasVariable(const std::string &symbol) const { + return _current.hasVariable(symbol) || _backed->hasVariable(symbol); +} + +std::shared_ptr VariableProxy::getVariable(int id) const { + if (_current.hasVariable(id)) return _current.getVariable(id); + + if (_backed->hasVariable(id)) return _backed->getVariable(id); + + nd4j_printf("Unable to get Variable from proxy: [%i]\n", id); + throw std::runtime_error("Bad arguments"); +} + +std::shared_ptr VariableProxy::getVariable(int id, int idx) const { + if (_current.hasVariable(id, idx)) return _current.getVariable(id, idx); + + if (_backed->hasVariable(id, idx)) return _backed->getVariable(id, idx); + + nd4j_printf("Unable to get Variable from proxy: [%i:%i]\n", id, idx); + throw std::runtime_error("Bad arguments"); +} + +std::shared_ptr VariableProxy::getVariable( + const std::pair &pair) const { + if (_current.hasVariable(pair)) return _current.getVariable(pair); + + if (_backed->hasVariable(pair)) return _backed->getVariable(pair); + + nd4j_printf("Unable to get Variable from proxy: [%i:%i]\n", pair.first, + pair.second); + throw std::runtime_error("Bad arguments"); } + +std::shared_ptr VariableProxy::getVariable( + const std::string &symbol) const { + if (_current.hasVariable(symbol)) return _current.getVariable(symbol); + + if (_backed->hasVariable(symbol)) return _backed->getVariable(symbol); + + nd4j_printf("Unable to get Variable from proxy: [%s]\n", symbol.c_str()); + throw std::runtime_error("Bad arguments"); +} + +void VariableProxy::replaceVariable(std::shared_ptr variable) { + if (!variable->getName().empty()) { + // if variable has name defined - we should resolve it via backing var space + if (_backed->hasVariable(variable->getName())) { + auto origVar = _backed->getVariable(variable->getName()); + variable->setId(origVar->id(), origVar->index()); + _current.replaceVariable(variable); + } else + _current.replaceVariable(variable); + } else // if proxy has variable - that's one story + _current.replaceVariable(variable); +} + +std::shared_ptr VariableProxy::putVariable(const std::string &name, + int id, int idx, + const NDArray &array) { + return _current.putVariable(name, id, idx, array); +} + +void VariableProxy::putOutputVariable(std::shared_ptr variable) { + _current.putOutputVariable(variable); +} + +std::shared_ptr VariableProxy::putVariable( + const std::pair &pair, const NDArray &array) { + return _current.putVariable(pair, array); +} + +const MAP_IMPL, std::shared_ptr> &VariableProxy::externalPaired() const { + return _backed->_paired; +} + +void VariableProxy::putVariable(const std::pair &pair, + const std::shared_ptr &variable) { + _current.putVariable(pair, variable); +} + +void VariableProxy::putVariable(int id, + const std::shared_ptr &variable) { + _current.putVariable(id, variable); +} + +std::shared_ptr VariableProxy::putVariable(int id, + const NDArray &array) { + return _current.putVariable(id, array); +} + +std::shared_ptr VariableProxy::putVariable(int id, int idx, + const NDArray &array) { + return _current.putVariable(id, idx, array); +} + +void VariableProxy::putVariable(const std::string &name, int id, int idx, + const std::shared_ptr &array) { + _current.putVariable(name, id, idx, array); +} + +Stash *VariableProxy::stash() const { return _current.stash(); } + +Nd4jLong VariableProxy::externalMemory() const { + return _backed->externalMemory() + _current.externalMemory(); +} + +Nd4jLong VariableProxy::internalMemory() const { + return _backed->internalMemory() + _current.internalMemory(); +} + +Nd4jLong VariableProxy::totalMemory() const { + return _backed->totalMemory() + _current.totalMemory(); +} + +int VariableProxy::externalEntries() const { + return _backed->externalEntries() + _current.externalEntries(); +} + +int VariableProxy::internalEntries() const { + return _backed->internalEntries() + _current.internalEntries(); +} + +int VariableProxy::totalEntries() const { + return _backed->totalEntries() + _current.totalEntries(); +} + +VariableSpace &VariableProxy::operator=(const VariableSpace &other) { + if (this == &other) return *this; + + nd4j_printf("VariableProxy = not implemented\n", ""); + + return *this; +} + +void VariableProxy::pullFrom(const VariableProxy &proxy) { + for (const auto &v:proxy._current.variables()) { + _current.replaceVariable(v); + } +} + +void VariableProxy::pushTo(VariableProxy &proxy) const { + for (const auto &v:_current.variables()) { + proxy._current.replaceVariable(v); + } +} + +} // namespace graph +} // namespace sd diff --git a/libnd4j/include/graph/impl/VariableSpace.cpp b/libnd4j/include/graph/impl/VariableSpace.cpp index 0e8634d07303..a0c8e64acf13 100644 --- a/libnd4j/include/graph/impl/VariableSpace.cpp +++ b/libnd4j/include/graph/impl/VariableSpace.cpp @@ -22,423 +22,372 @@ #include namespace sd { - namespace graph { - std::vector * sd::graph::VariableSpace::getExternalVariables() { - return &_external; - } +namespace graph { +Stash *VariableSpace::stash() const { return const_cast(&_stash); } - sd::graph::Stash* sd::graph::VariableSpace::getStash() { - return &_stash; - } +void VariableSpace::injectVariable(const std::pair &pair, + std::shared_ptr variable) { + if (pair.second == 0) { + this->_variables[pair.first] = variable; + } - sd::graph::VariableSpace* sd::graph::VariableSpace::clone() { - auto result = new VariableSpace(); + if (!variable->getName().empty()) + this->_symbolic[variable->getName()] = variable; - for (auto const& x : _paired) { - std::pair pair(x.first.first, x.first.second); + this->_paired[pair] = variable; +} - Variable* clonedVar = x.second->clone(); +const std::vector> &VariableSpace::placeholders() + const { + return _placeholders; +} - result->injectVariable(pair, clonedVar); - } +int VariableSpace::numberOfPlaceholders() const { return _placeholders.size(); } - return result; - } +bool VariableSpace::hasVariable(const std::string &symbol) const { + return _symbolic.count(symbol) > 0; +} - void VariableSpace::setWorkspace(sd::memory::Workspace *workspace) { - //_workspace = *workspace; - } +std::shared_ptr VariableSpace::getVariable( + const std::string &symbol) const { + return _symbolic.at(symbol); +} - - sd::graph::VariableSpace* sd::graph::VariableSpace::asT() { - auto result = new VariableSpace(); +bool VariableSpace::hasVariable(int id, int index) const { + std::pair pair(id, index); + return hasVariable(pair); +} - for (auto const& x : _paired) { - std::pair pair(x.first.first, x.first.second); +bool VariableSpace::hasExternalVariable(int id) const { + if (!hasVariable(id)) return false; - //Variable* clonedVar = x.second->template asT(); + auto var = getVariable(id); + return var->isExternal(); +} - //result->injectVariable(pair, clonedVar); - } +bool VariableSpace::hasExternalVariable(const std::pair &pair) const { + if (!hasVariable(pair)) return false; - return result; - } - - - void sd::graph::VariableSpace::injectVariable(std::pair &pair, Variable* variable) { - if (pair.second == 0) { - if (pair.first < 0) - this->_variables[pair.first] = variable; - else - this->_temporary[pair.first] = variable; - } - - if (variable->getName() != nullptr && variable->getName()->length() > 0) - this->_symbolic[*(variable->getName())] = variable; - - this->_paired[pair] = variable; - - this->_handles->push_back(variable); - } - - std::vector * sd::graph::VariableSpace::getPlaceholders() { - return &_placeholders; - } - - int sd::graph::VariableSpace ::numberOfPlaceholders() { - return _placeholders.size(); - } - - bool sd::graph::VariableSpace::hasVariable(std::string *symbol) { - return _symbolic.count(*symbol) == 1; - } - - sd::graph::Variable * sd::graph::VariableSpace::getVariable(std::string *symbol) { - return _symbolic.at(*symbol); - } - - bool sd::graph::VariableSpace::hasVariable(int id, int index) { - std::pair pair(id, index); - return hasVariable(pair); - } - - bool VariableSpace::hasExternalVariable(int id) { - if (!hasVariable(id)) - return false; - - auto var = getVariable(id); - return var->isExternal(); - } - - bool VariableSpace::hasExternalVariable(std::pair& pair) { - if (!hasVariable(pair)) - return false; - - auto var = getVariable(pair); - return var->isExternal(); - } - - bool VariableSpace::hasExternalVariable(std::string *symbol) { - if (!hasVariable(symbol)) - return false; - - auto var = getVariable(symbol); - return var->isExternal(); - } - - sd::graph::Variable * sd::graph::VariableSpace::getVariable(int id, int index) { - std::pair pair(id, index); - return getVariable(pair); - } - - sd::graph::Variable * sd::graph::VariableSpace::getVariable(std::pair& pair) { - if (pair.first < 0) - return getVariable(pair.first); - else - return _paired.at(pair); - - nd4j_printf("Unknown variable requested: [%i,%i]\n", pair.first, pair.second); - throw std::runtime_error("Unknown variable requested"); - } - - bool sd::graph::VariableSpace::hasVariable(int id) { - return _variables.count(id) == 1 || _temporary.count(id) == 1; - } - - bool sd::graph::VariableSpace::hasVariable(std::pair& id) { - return _paired.count(id) > 0; - } - - void sd::graph::VariableSpace::putOutputVariable(Variable *variable) { - //putVariable(_auto_counter--, variable); - putVariable(variable->id(), variable); - } - - int sd::graph::VariableSpace::externalEntries() { - return _external.size(); - } - - int sd::graph::VariableSpace::internalEntries() { - return _internal.size(); - } - - int sd::graph::VariableSpace::totalEntries() { - return externalEntries() + internalEntries(); - } - - Nd4jLong sd::graph::VariableSpace::externalMemory() { - Nd4jLong size = 0; - for (auto n: _external) { - size += n->getNDArray()->memoryFootprint(); - } - - return size; - } - - std::vector VariableSpace::getVariables() { - std::vector result; + auto var = getVariable(pair); + return var->isExternal(); +} + +bool VariableSpace::hasExternalVariable(const std::string &symbol) const { + if (!hasVariable(symbol)) return false; + + auto var = getVariable(symbol); + return var->isExternal(); +} + +std::shared_ptr VariableSpace::getVariable(int id, int index) const { + std::pair pair(id, index); + return getVariable(pair); +} + +std::shared_ptr VariableSpace::getVariable( + const std::pair &pair) const { + if (pair.first < 0) + return getVariable(pair.first); + else + return _paired.at(pair); + + nd4j_printf("Unknown variable requested: [%i,%i]\n", pair.first, pair.second); + throw std::runtime_error("Unknown variable requested"); +} + +bool VariableSpace::hasVariable(int id) const { + return _variables.count(id) > 0; +} + +bool VariableSpace::hasVariable(const std::pair &id) const { + return _paired.count(id) > 0; +} + +void VariableSpace::putOutputVariable(std::shared_ptr variable) { + // putVariable(_auto_counter--, variable); + putVariable(variable->id(), variable); +} + +int VariableSpace::externalEntries() const { return _external.size(); } + +int VariableSpace::internalEntries() const { return _internal.size(); } + +int VariableSpace::totalEntries() const { + return externalEntries() + internalEntries(); +} + +Nd4jLong VariableSpace::externalMemory() const { + Nd4jLong size = 0; + for (auto n : _external) { + size += n->getNDArray()->memoryFootprint(); + } + + return size; +} + +std::vector> VariableSpace::variables() const { + std::vector> result; + + for (auto v : _internal) result.emplace_back(v); + + for (auto v : _external) result.emplace_back(v); + + return result; +} + +Nd4jLong VariableSpace::internalMemory() const { + Nd4jLong size = 0; + for (auto n : _internal) { + size += n->getNDArray()->memoryFootprint(); + } + + return size; +} + +Nd4jLong VariableSpace::totalMemory() const { + return externalMemory() + internalMemory(); +} + +std::shared_ptr VariableSpace::putVariable( + int id, int idx, const std::shared_ptr &array) { + auto variable = std::make_shared(array, "", id, idx); + this->putVariable({id, idx}, variable); + return variable; +} + +std::shared_ptr VariableSpace::putVariable(const std::string &name, + int id, int idx, + const NDArray &array) { + auto variable = std::make_shared(array, name, id, idx); + this->putVariable({id, idx}, variable); + return variable; +} + +void VariableSpace::dropVariable(const std::string &pair) { + throw std::runtime_error("VariableSpace::dropVariable - not implemented yet"); +} + +std::shared_ptr VariableSpace::putVariable( + const std::pair &pair, const NDArray &array) { + auto variable = + std::make_shared(array, "", pair.first, pair.second); + this->putVariable(pair, variable); + return variable; +} + +std::shared_ptr VariableSpace::putVariable(int node, int idx, + const NDArray &array) { + std::pair pair(node, idx); + return this->putVariable(pair, array); +} + +void VariableSpace::putVariable(const std::string &name, int node, int idx, + const std::shared_ptr &variable) { + std::pair pair(node, idx); + variable->setName(name); + this->putVariable(pair, variable); +} + +void VariableSpace::silentPutVariable( + const std::pair &pair, + const std::shared_ptr &variable) { + std::lock_guard lock(_varmap); + + _paired[pair] = variable; +} + +void VariableSpace::putVariable(const std::pair &pair, + const std::shared_ptr &variable) { + silentPutVariable(pair, variable); + + if (variable->isPlaceholder()) _placeholders.emplace_back(variable); + + // copying duplicate for compatibility + if (pair.second == 0 && !this->hasVariable(pair.first)) { + this->putVariable(pair.first, variable); + } + + if (!variable->getName().empty()) { + _symbolic[variable->getName()] = variable; + } +} + +void VariableSpace::putVariable(int id, + const std::shared_ptr &variable) { + // we don't want to add variables more then once + if (_variables.count(id) > 0) { + throw std::runtime_error("VariableSpace::putVariable - duplicate found"); + } + + { + std::lock_guard lock(_varmap); + + if (_auto_counter >= id) _auto_counter = id - 1; + + variable->setId(id); + + if (!variable->getName().empty()) { + // std::pair pair(*(variable->getName()), + // variable); + _symbolic[variable->name()] = variable; + } - for (auto v: _internal) - result.emplace_back(v); + // we have special list for external variables to ensure graph completeness + if (id < 0) { + _external.emplace_back(variable); + } else { + _internal.emplace_back(variable); + } - for (auto v: _external) - result.emplace_back(v); + _variables[id] = variable; + } - return result; - } + std::pair pair(id, 0); + if (!hasVariable(pair)) { + this->silentPutVariable(pair, variable); - Nd4jLong sd::graph::VariableSpace::internalMemory() { - Nd4jLong size = 0; - for (auto n: _internal) { - size += n->getNDArray()->memoryFootprint(); - } + if (variable->isPlaceholder()) _placeholders.emplace_back(variable); + } +} - return size; - } +std::shared_ptr VariableSpace::putVariable(int id, + const NDArray &array) { + auto var = std::make_shared(array, "", id, 0); + this->putVariable(id, var); + return var; +} - Nd4jLong sd::graph::VariableSpace::totalMemory() { - return externalMemory() + internalMemory(); - } +std::shared_ptr VariableSpace::getVariable(int id) const { + return _variables.at(id); +} - Variable* sd::graph::VariableSpace::putVariable(std::pair& pair, NDArray *array) { - auto variable = new Variable(array, nullptr, pair.first, pair.second); - this->putVariable(pair, variable); - return variable; - } +VariableSpace::~VariableSpace() { } - Variable* sd::graph::VariableSpace::putVariable(int node, int idx, NDArray *array) { - std::pair pair(node, idx); - return this->putVariable(pair, array); - } +VariableSpace::VariableSpace(const VariableSpace &other) { + _stash = other._stash; - void sd::graph::VariableSpace::putVariable(int node, int idx, Variable *variable) { - std::pair pair(node, idx); - this->putVariable(pair, variable); - } + _paired = other._paired; + _symbolic = other._symbolic; + _variables = other._variables; - void sd::graph::VariableSpace::silentPutVariable(std::pair& pair, Variable *variable) { - _varmap.lock(); + _external = other._external; + _internal = other._internal; - //std::pair, sd::graph::Variable *> p(pair, variable); - _paired[pair] = variable; + _lists = other._lists; + _placeholders = other._placeholders; - _varmap.unlock(); - } + _auto_counter = other._auto_counter; +} - void sd::graph::VariableSpace::putVariable(std::pair& pair, Variable *variable) { - silentPutVariable(pair, variable); +VariableSpace::VariableSpace(VariableSpace &&other) { + _stash = std::move(other._stash); - if (variable->isPlaceholder()) - _placeholders.push_back(variable); + _paired = std::move(other._paired); + _symbolic = std::move(other._symbolic); + _variables = std::move(other._variables); - // copying duplicate for compatibility - if (pair.second == 0 && !this->hasVariable(pair.first)) { - this->putVariable(pair.first, variable); - } else { - if (variable->getName() != nullptr && variable->getName()->length() != 0) { - _symbolic[*(variable->getName())] = variable; - } + _external = std::move(other._external); + _internal = std::move(other._internal); - _varmap.lock(); + _lists = std::move(other._lists); + _placeholders = std::move(other._placeholders); - _handles->push_back(variable); + _auto_counter = other._auto_counter; +} - _varmap.unlock(); - } - } +VariableSpace &VariableSpace::operator=(VariableSpace &&other) { + if (this == &other) return *this; - void VariableSpace::trackList(sd::NDArrayList* list) { - _lists.emplace_back(list); - } + _stash = std::move(other._stash); - void sd::graph::VariableSpace::putVariable(int id, Variable *variable) { - // we don't want to add variables more then once - if (_variables.count(id) > 0 || _temporary.count(id) > 0) { - auto local = id < 0 ? _variables.at(id) : _temporary.at(id); + _paired = std::move(other._paired); + _symbolic = std::move(other._symbolic); + _variables = std::move(other._variables); - if (!local->hasNDArray() && variable->hasNDArray()) { - local->setNDArray(variable->getNDArray()); + _external = std::move(other._external); + _internal = std::move(other._internal); - // we're inheriting this from Variable - local->markReadOnly(variable->isReadOnly()); - local->markRemovable(variable->isRemovable()); - } + _lists = std::move(other._lists); + _placeholders = std::move(other._placeholders); - return; - } + _auto_counter = other._auto_counter; - _varmap.lock(); + return *this; +} - _handles->emplace_back(variable); +VariableSpace &VariableSpace::operator=(const VariableSpace &other) { + if (this == &other) return *this; - if (_auto_counter >= id) - _auto_counter = id - 1; + _stash = other._stash; - variable->setId(id); + _paired = other._paired; + _symbolic = other._symbolic; + _variables = other._variables; - if (variable->getName() != nullptr && variable->getName()->length() != 0) { - //std::pair pair(*(variable->getName()), variable); - _symbolic[*(variable->getName())] = variable; - } + _external = other._external; + _internal = other._internal; - // we have special list for external variables to ensure graph completeness + _lists = other._lists; + _placeholders = other._placeholders; - if (id < 0) { - //if (variable->isExternal()) - _external.push_back(variable); + _auto_counter = other._auto_counter; - _variables[id] = variable; - } else { - _internal.push_back(variable); + return *this; +} - _temporary[id] = variable; - } +void VariableSpace::replaceVariable(std::shared_ptr variable) { + bool replaced = false; + // trying name lookup first + if (!variable->getName().empty()) { + if (hasVariable(variable->getName())) { + auto vs = getVariable(variable->getName()); + dropVariable(vs->id(), vs->index()); - _varmap.unlock(); + putVariable({vs->id(), vs->index()}, variable); - std::pair pair(id, 0); - if (!hasVariable(pair)) { - this->silentPutVariable(pair, variable); + // if we're on zero index, we also must update index-less reference + if (vs->index() == 0) + _variables[vs->id()] = variable; - if (variable->isPlaceholder()) - _placeholders.push_back(variable); - } - } + replaced = true; + } + } else { + if (hasVariable(variable->id(), variable->index())) { + auto vs = getVariable(variable->id(), variable->index()); + dropVariable(variable->id(), variable->index()); + putVariable({vs->id(), vs->index()}, variable); - void sd::graph::VariableSpace::putVariable(int id, int idx, NDArray &array) { - auto *var = new sd::graph::Variable(&array, "", id, idx); - var->markRemovable(false); - var->markReadOnly(true); - - // let's see if this op needs - bool d = this->hasVariable(id, idx); - - this->putVariable(id, var); - - // if var for this nodeid already exists - we'll just delete variable - if (d) - delete var; - } - - void sd::graph::VariableSpace::putVariable(int id, NDArray *array) { - auto *var = new sd::graph::Variable(array); - this->putVariable(id, var); - } - - sd::graph::Variable * sd::graph::VariableSpace::getVariable(int id) { - if (id < 0) { - return _variables.at(id); - } else { - return _temporary.at(id); - } - } - - LaunchContext* sd::graph::VariableSpace::launchContext() { - return LaunchContext::defaultContext(); - } + // if we're on zero index, we also must update index-less reference + if (vs->index() == 0) + _variables[vs->id()] = variable; - std::vector* sd::graph::VariableSpace::handles() { - return _handles; - } - -/* - * FIXME: this thing have nice chances to become backend-specific! - */ - sd::graph::VariableSpace::~VariableSpace() { - // loop through variables and release them - for (auto p: *_handles) { - delete p; - } - - delete _handles; - - //_internal.clear(); - //_external.clear(); - //_temporary.clear(); - - //nd4j_printf("Number of NDArrayLists in this space: [%i]\n", _lists.size()) - for (auto p: _lists) - delete p; - - _lists.clear(); - } - - VariableSpace& VariableSpace::operator=(const VariableSpace& other) { - if (this == &other) return *this; - - for (auto const& x : other._paired) { - std::pair pair(x.first.first, x.first.second); - - Variable* clonedVar = x.second->clone(); - - if (pair.second == 0) { - if (pair.first < 0) - this->_variables[pair.first] = clonedVar; - else - this->_temporary[pair.first] = clonedVar; - } - - if (clonedVar->getName() != nullptr && clonedVar->getName()->length() > 0) - this->_symbolic[*(clonedVar->getName())] = clonedVar; - - this->_paired[pair] = clonedVar; - - this->_handles->push_back(clonedVar); - } - - return *this; - } - - void VariableSpace::replaceVariable(Variable *variable) { - bool replaced = false; - // trying name first - if (variable->getName() != nullptr && !variable->getName()->empty()) { - nd4j_printf("Trying to replace variable by name: [%s]\n", variable->getName()->c_str()); - if (hasVariable(variable->getName())) { - nd4j_printf("Replacing by name: [%s]\n", variable->getName()->c_str()); - auto vs = getVariable(variable->getName()); - dropVariable(vs->id(), vs->index()); - putVariable(vs->id(), vs->index(), variable); - //delete vs; - replaced = true; - } - } else { - nd4j_printf("Trying to replace variable by id: [%i:%i]\n", variable->id(), variable->index()); - if (hasVariable(variable->id(), variable->index())) { - nd4j_printf("Replacing by id: [%i:%i]\n", variable->id(), variable->index()); - auto vs = getVariable(variable->id(), variable->index()); - dropVariable(variable->id(), variable->index()); - putVariable(vs->id(), vs->index(), variable); - //delete vs; - replaced = true; - } - } - - if (!replaced) { - nd4j_printf("wasn't able to replace variable, putting\n", ""); - putVariable(variable->id(), variable->index(), variable); - } - } - - void VariableSpace::dropVariable(std::pair &pair) { - dropVariable(pair.first, pair.second); - } - - void VariableSpace::dropVariable(int id, int idx) { - - } - - - void VariableSpace::setFlowPath(FlowPath* flow) { - _flow = flow; - } - - FlowPath* VariableSpace::flowPath() { - return _flow; - } - - VariableSpace::VariableSpace() { - _handles = new std::vector; - } + replaced = true; } -} \ No newline at end of file + } + + if (!replaced) { + putVariable({variable->id(), variable->index()}, variable); + } +} + +void VariableSpace::dropVariable(const std::pair &pair) { + dropVariable(pair.first, pair.second); +} + +void VariableSpace::dropVariable(int id, int idx) { + if (hasVariable(id, idx)) { + // we must check if string name is defined for this Variable + auto v = getVariable(id, idx); + if (!v->name().empty()) + _symbolic.erase(v->name()); + + _paired.erase(std::pair{id, idx}); + } + + if (idx == 0 && hasVariable(id)) + _variables.erase(id); +} + +VariableSpace::VariableSpace() {} +} // namespace graph +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/graph/impl/VariablesSet.cpp b/libnd4j/include/graph/impl/VariablesSet.cpp index 80f8e3728949..2920014cdd44 100644 --- a/libnd4j/include/graph/impl/VariablesSet.cpp +++ b/libnd4j/include/graph/impl/VariablesSet.cpp @@ -21,30 +21,21 @@ #include namespace sd { - namespace graph { - Nd4jStatus VariablesSet::status() { - return _status; - } - - int VariablesSet::size() { - return _holder.size(); - } - - void VariablesSet::push_back(Variable *variable) { - _holder.push_back(variable); - } - - Variable *VariablesSet::at(int index) { - return _holder.at(index); - } - - VariablesSet::VariablesSet(Nd4jStatus status) { - _status = status; - } - - VariablesSet::~VariablesSet() { - for (auto v: _holder) - delete v; - } - } +namespace graph { +Nd4jStatus VariablesSet::status() { return _status; } + +int VariablesSet::size() { return _holder.size(); } + +void VariablesSet::push_back(Variable *variable) { + _holder.push_back(variable); +} + +Variable *VariablesSet::at(int index) { return _holder.at(index); } + +VariablesSet::VariablesSet(Nd4jStatus status) { _status = status; } + +VariablesSet::~VariablesSet() { + for (auto v : _holder) delete v; } +} // namespace graph +} // namespace sd diff --git a/libnd4j/include/graph/execution/LogicEnter.h b/libnd4j/include/graph/logic/LogicEnter.h similarity index 75% rename from libnd4j/include/graph/execution/LogicEnter.h rename to libnd4j/include/graph/logic/LogicEnter.h index d770ff10a443..f341595ce34f 100644 --- a/libnd4j/include/graph/execution/LogicEnter.h +++ b/libnd4j/include/graph/logic/LogicEnter.h @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -18,21 +19,19 @@ // Created by raver119 on 30.01.18. // -#ifndef LIBND4J_LOGICENTER_H -#define LIBND4J_LOGICENTER_H +#ifndef SD_LOGICENTER_H +#define SD_LOGICENTER_H -#include #include +#include namespace sd { - namespace graph { - class LogicEnter { - public: - static Nd4jStatus processNode(Graph* graph, Node* node); - }; - } -} - - +namespace graph { +class LogicEnter { + public: + static Nd4jStatus processNode(const Node* node, Stack &stack, const OptimizedGraph& graph); +}; +} // namespace graph +} // namespace sd -#endif //LIBND4J_LOGICEXIT_H +#endif // SD_LOGICEXIT_H diff --git a/libnd4j/include/graph/execution/LogicExecutor.h b/libnd4j/include/graph/logic/LogicExecutor.h similarity index 64% rename from libnd4j/include/graph/execution/LogicExecutor.h rename to libnd4j/include/graph/logic/LogicExecutor.h index 541b3fc8425b..83a4e2a3e6e1 100644 --- a/libnd4j/include/graph/execution/LogicExecutor.h +++ b/libnd4j/include/graph/logic/LogicExecutor.h @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -18,26 +19,27 @@ // Created by raver119 on 20.10.2017. // -#ifndef LIBND4J_LOGICEXECUTOR_H -#define LIBND4J_LOGICEXECUTOR_H +#ifndef SD_LOGICEXECUTOR_H +#define SD_LOGICEXECUTOR_H -#include -#include #include +#include +#include +#include +#include namespace sd { - namespace graph { - /** - * This class acts as switch for picking logic execution based on opNum, unique for each logical op - * @tparam T - */ - class LogicExecutor { - public: - static Nd4jStatus processNode(Graph* graph, Node* node); - }; - } -} - - +namespace graph { +/** + * This class acts as switch for picking logic execution based on opNum, unique + * for each logical op + * @tparam T + */ +class LogicExecutor { + public: + static Nd4jStatus processNode(const Node* node, Stack &stack, const OptimizedGraph& graph); +}; +} // namespace graph +} // namespace sd -#endif //LIBND4J_LOGICEXECUTOR_H +#endif // SD_LOGICEXECUTOR_H diff --git a/libnd4j/include/graph/execution/LogicExit.h b/libnd4j/include/graph/logic/LogicExit.h similarity index 74% rename from libnd4j/include/graph/execution/LogicExit.h rename to libnd4j/include/graph/logic/LogicExit.h index d182e26fbf39..82e4e134f086 100644 --- a/libnd4j/include/graph/execution/LogicExit.h +++ b/libnd4j/include/graph/logic/LogicExit.h @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -18,21 +19,19 @@ // Created by raver119 on 30.01.18. // -#ifndef LIBND4J_LOGICEXIT_H -#define LIBND4J_LOGICEXIT_H +#ifndef SD_LOGICEXIT_H +#define SD_LOGICEXIT_H -#include #include +#include namespace sd { - namespace graph { - class LogicExit { - public: - static Nd4jStatus processNode(Graph* graph, Node* node); - }; - } -} - - +namespace graph { +class LogicExit { + public: + static Nd4jStatus processNode(const Node* node, Stack &stack, const OptimizedGraph& graph); +}; +} // namespace graph +} // namespace sd -#endif //LIBND4J_LOGICEXIT_H +#endif // LIBND4J_LOGICEXIT_H diff --git a/libnd4j/include/graph/execution/LogicLoopCond.h b/libnd4j/include/graph/logic/LogicLoopCond.h similarity index 74% rename from libnd4j/include/graph/execution/LogicLoopCond.h rename to libnd4j/include/graph/logic/LogicLoopCond.h index 36693232be90..5564ac6613f6 100644 --- a/libnd4j/include/graph/execution/LogicLoopCond.h +++ b/libnd4j/include/graph/logic/LogicLoopCond.h @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -18,21 +19,19 @@ // Created by raver119 on 30.01.18. // -#ifndef LIBND4J_LOGICLOOPCOND_H -#define LIBND4J_LOGICLOOPCOND_H +#ifndef SD_LOGICLOOPCOND_H +#define SD_LOGICLOOPCOND_H -#include #include +#include namespace sd { - namespace graph { - class LogicLoopCond { - public: - static Nd4jStatus processNode(Graph* graph, Node* node); - }; - } -} - - +namespace graph { +class LogicLoopCond { + public: + static Nd4jStatus processNode(const Node* node, Stack &stack, const OptimizedGraph& graph); +}; +} // namespace graph +} // namespace sd -#endif //LIBND4J_LOGICLOOPCOND_H +#endif // SD_LOGICLOOPCOND_H diff --git a/libnd4j/include/graph/execution/LogicMerge.h b/libnd4j/include/graph/logic/LogicMerge.h similarity index 75% rename from libnd4j/include/graph/execution/LogicMerge.h rename to libnd4j/include/graph/logic/LogicMerge.h index fe20c9d660ed..541b7396bff2 100644 --- a/libnd4j/include/graph/execution/LogicMerge.h +++ b/libnd4j/include/graph/logic/LogicMerge.h @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -18,21 +19,21 @@ // Created by raver119 on 30.01.18. // -#ifndef LIBND4J_LOGICMERGE_H -#define LIBND4J_LOGICMERGE_H +#ifndef SD_LOGICMERGE_H +#define SD_LOGICMERGE_H -#include #include +#include namespace sd { - namespace graph { - class LogicMerge { - public: - static Nd4jStatus processNode(Graph* graph, Node* node); - }; - } -} +namespace graph { +class LogicMerge { + public: + static Nd4jStatus processNode(const Node* node, Stack &stack, const OptimizedGraph& graph); +}; +} // namespace graph +} // namespace sd -#endif //LIBND4J_LOGICMERGE_H +#endif // SD_LOGICMERGE_H diff --git a/libnd4j/include/graph/execution/LogicNextIteration.h b/libnd4j/include/graph/logic/LogicNextIteration.h similarity index 73% rename from libnd4j/include/graph/execution/LogicNextIteration.h rename to libnd4j/include/graph/logic/LogicNextIteration.h index 5b9600909ea9..aeb2b16bd7e3 100644 --- a/libnd4j/include/graph/execution/LogicNextIteration.h +++ b/libnd4j/include/graph/logic/LogicNextIteration.h @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -18,21 +19,21 @@ // Created by raver119 on 30.01.18. // -#ifndef LIBND4J_LOGICNEXTITERATION_H -#define LIBND4J_LOGICNEXTITERATION_H +#ifndef SD_LOGICNEXTITERATION_H +#define SD_LOGICNEXTITERATION_H -#include #include +#include namespace sd { - namespace graph { - class LogicNextIeration { - public: - static Nd4jStatus processNode(Graph* graph, Node* node); - }; - } -} +namespace graph { +class LogicNextIeration { + public: + static Nd4jStatus processNode(const Node* node, Stack &stack, const OptimizedGraph& graph); +}; +} // namespace graph +} // namespace sd -#endif //LIBND4J_LOGICNEXTITERATION_H +#endif // SD_LOGICNEXTITERATION_H diff --git a/libnd4j/include/graph/execution/LogicSwitch.h b/libnd4j/include/graph/logic/LogicSwitch.h similarity index 61% rename from libnd4j/include/graph/execution/LogicSwitch.h rename to libnd4j/include/graph/logic/LogicSwitch.h index d91959d91eff..a7d21fd053fe 100644 --- a/libnd4j/include/graph/execution/LogicSwitch.h +++ b/libnd4j/include/graph/logic/LogicSwitch.h @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -15,31 +16,32 @@ ******************************************************************************/ // -// Created by raver119 on 21.10.17. +// @author raver119@gmail.com // -#ifndef LIBND4J_LOGICSWITCH_H -#define LIBND4J_LOGICSWITCH_H +#ifndef SD_LOGICSWITCH_H +#define SD_LOGICSWITCH_H -#include -#include #include +#include +#include +#include namespace sd { - namespace graph { - /** - * This class is responsible for execution logic of Switch logical abstraction - * - * It's ultra-simple. It does nothing, and can't be executed directly. - * - * @tparam T - */ - class LogicSwitch { - public: - static Nd4jStatus processNode(Graph* graph, Node* node); - }; - } -} +namespace graph { +/** + * This class is responsible for execution logic of Switch logical abstraction + * + * It's ultra-simple. It does nothing, and can't be executed directly. + * + * @tparam T + */ +class LogicSwitch { + public: + static Nd4jStatus processNode(const Node* node, Stack &stack, const OptimizedGraph& graph); +}; +} // namespace graph +} // namespace sd -#endif //LIBND4J_LOGICSWITCH_H +#endif // SD_LOGICSWITCH_H diff --git a/libnd4j/include/graph/logic/LogicUtilities.h b/libnd4j/include/graph/logic/LogicUtilities.h new file mode 100644 index 000000000000..93f9d4677085 --- /dev/null +++ b/libnd4j/include/graph/logic/LogicUtilities.h @@ -0,0 +1,42 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#ifndef SD_LOGICUTILITIES_H +#define SD_LOGICUTILITIES_H + +#include +#include +#include +#include + +namespace sd { +namespace graph { + +class SD_EXPORT LogicUtilities { + public: + static void disableBranch(StackFrame &frame, VariableProxy &varSpace, const OptimizedGraph &graph, const Node* node, bool branchToDisable); + static void disableBranch(StackFrame &frame, VariableProxy &varSpace, const OptimizedGraph &graph, const Node* node); +}; + + +} +} + +#endif //SD_LOGICUTILITIES_H diff --git a/libnd4j/include/graph/logic/impl/LogicEnter.cpp b/libnd4j/include/graph/logic/impl/LogicEnter.cpp new file mode 100644 index 000000000000..76e29c18dfd7 --- /dev/null +++ b/libnd4j/include/graph/logic/impl/LogicEnter.cpp @@ -0,0 +1,89 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include +#include + +namespace sd { +namespace graph { + +/** + * This function does 2 things: + * - Propagates input Variable + * - Opens new StackFrame (only if that's the first Enter in this Loop) + */ +Nd4jStatus LogicEnter::processNode(const Node *node, Stack &stack, const OptimizedGraph& graph) { + if (node->exitId() >= 0) { + // the only possible case for this equality is NextIteration rewind + if (node->id() == stack.back().enterId()) { + stack.iterateFrame(node->frameId(), node->id()); + } else { + // if current frameName isn't equal to node frame name - we'll open new StackFrame then + if (node->frameId() != stack.back().frameId()) { + // since this is the loop entrance, we'll rewind to this Node once iteration ends + stack.openFrame(node->frameId(), node->id()); + } + } + } else { + //nd4j_printf("Neg\n", ""); + } + + // getting current frame (it might be the new one!) + const auto &frame = stack.back(); + const auto &inputs = node->inputs(); + const auto &outputs = node->outputs(); + auto &varSpace = const_cast(frame.variableProxy()); + + // and we need to find exit point - it has to be Exit node, with max index within OpSequence + if (node->exitId() >= 0) { + auto currentExitIndex = frame.exitId() >= 0 ? graph.nodeIndex(frame.exitId()) : -1; + auto thisExitIndex = graph.nodeIndex(node->exitId()); + + // we want to exit after the last Exit node + if (thisExitIndex > currentExitIndex) + frame.setExitId(node->exitId()); + + // we need to find rewind point - it has to be NextIteration node with max index within OpSequence + const auto &merge = graph.nodesMap().at(outputs[0].first); + const auto &iter = graph.nodesMap().at(merge.inputs()[1].first); + + // we must compare index of this NextIteration Node within OpSequence to the current one, if it's set + auto currentRewindIndex = frame.rewindId() >= 0 ? graph.nodeIndex(frame.rewindId()) : -1; + auto thisRewindIndex = graph.nodeIndex(iter.id()); + + // we want to rewind after the last NextIteration node + if (thisRewindIndex > currentRewindIndex) + frame.setRewindId(iter.id()); + } + + // validate Node state + REQUIRE_TRUE(inputs.size() == 1, 0, "Enter: op must have exactly 1 inputs"); + REQUIRE_TRUE(varSpace.hasVariable(inputs[0]), 0, "Enter: input Variable doesn't exist"); + + // now we propagate input as own output + // ssince we've opened new StackFrame, this Variable will end up in current VariableProxy + varSpace.putVariable(std::pair{node->id(), 0}, *varSpace.getVariable(inputs[0])->getNDArray()); + + return sd::Status::OK(); +} + +} // namespace graph +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/graph/logic/impl/LogicExecutor.cpp b/libnd4j/include/graph/logic/impl/LogicExecutor.cpp new file mode 100644 index 000000000000..a128f6b9ed39 --- /dev/null +++ b/libnd4j/include/graph/logic/impl/LogicExecutor.cpp @@ -0,0 +1,58 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include +#include +#include +#include +#include +#include +#include + +namespace sd { +namespace graph { +Nd4jStatus LogicExecutor::processNode(const Node *node, Stack &stack, const OptimizedGraph& graph) { + switch (node->opNum()) { + case sd::logic::Switch: + return LogicSwitch::processNode(node, stack, graph); + case sd::logic::Merge: + return LogicMerge::processNode(node, stack, graph); + case sd::logic::LoopCond: + return LogicLoopCond::processNode(node, stack, graph); + case sd::logic::NextIteration: + return LogicNextIeration::processNode(node, stack, graph); + case sd::logic::Exit: + return LogicExit::processNode(node, stack, graph); + case sd::logic::Enter: + return LogicEnter::processNode(node, stack, graph); + } + + if (node->name().empty()) { + nd4j_printf("Unknown LogicOp used at node [%i]: [%i]\n", node->id(), + node->opNum()); + } else { + nd4j_printf("Unknown LogicOp used at node [%i:<%s>]: [%i]\n", node->id(), + node->name().c_str(), node->opNum()); + } + return ND4J_STATUS_BAD_INPUT; +} +} // namespace graph +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/graph/logic/impl/LogicExit.cpp b/libnd4j/include/graph/logic/impl/LogicExit.cpp new file mode 100644 index 000000000000..647152c544af --- /dev/null +++ b/libnd4j/include/graph/logic/impl/LogicExit.cpp @@ -0,0 +1,58 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include + +namespace sd { +namespace graph { + +/** + * This funciton does 2 things: + * - Propagates input Variable to outer StackFrame + * - closes current StackFrame (only if this is the last Exit node in this loop) + */ +Nd4jStatus LogicExit::processNode(const Node *node, Stack &stack, const OptimizedGraph& graph) { + // getting current frame (it must be the StackFrame created for a While loop) + const auto &frame = stack.back(); + + // we must propagate variable from this frame to parent one + const auto &parent = frame.parent(); + + const auto &inputs = node->inputs(); + + REQUIRE_TRUE(inputs.size() == 1, 0, "Exit: op must have exactly 1 input1"); + REQUIRE_TRUE(frame.variableProxy().hasVariable(inputs[0]), 0, "Exit: input Variable doesn't exist"); + + nd4j_printf("Propagating Variable to the Frame [%i]\n", parent.id()); + + // get Variable from current VariableProxy and put to the ParentOne + auto var = frame.variableProxy().getVariable(inputs[0]); + const_cast(parent.variableProxy()).putVariable({node->id(), 0}, *var->getNDArray()); + + // if this is the last Exit node - we close current StackFrame + if (frame.exitId() == node->id()) + stack.closeFrame(); + + return Status::OK(); +} + +} // namespace graph +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/graph/logic/impl/LogicLoopCond.cpp b/libnd4j/include/graph/logic/impl/LogicLoopCond.cpp new file mode 100644 index 000000000000..cdb595b678b2 --- /dev/null +++ b/libnd4j/include/graph/logic/impl/LogicLoopCond.cpp @@ -0,0 +1,44 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include + +namespace sd { +namespace graph { + +Nd4jStatus LogicLoopCond::processNode(const Node *node, Stack &stack, const OptimizedGraph& graph) { + auto &frame = stack.back(); + + const auto &inputs = node->inputs(); + auto &varSpace = const_cast(frame.variableProxy()); + + REQUIRE_TRUE(inputs.size() == 1, 0, "LoopCond: op must have exactly 1 input1"); + REQUIRE_TRUE(frame.variableProxy().hasVariable(inputs[0]), 0, "LoopCond: input Variable doesn't exist"); + + // Propagate Variable + auto var = varSpace.getVariable(inputs[0]); + varSpace.putVariable({node->id(), 0}, *var->getNDArray()); + + return Status::OK(); +} + +} // namespace graph +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/graph/logic/impl/LogicMerge.cpp b/libnd4j/include/graph/logic/impl/LogicMerge.cpp new file mode 100644 index 000000000000..815b322bf658 --- /dev/null +++ b/libnd4j/include/graph/logic/impl/LogicMerge.cpp @@ -0,0 +1,99 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include +#include +#include +#include + +namespace sd { +namespace graph { + +static bool isNextIterationCase(const OptimizedGraph& graph, int firstId, int secondId) { + const auto firstNode = graph.nodesMap().count(firstId) > 0 ? &graph.nodesMap().at(firstId) : nullptr; + const auto secondNode = graph.nodesMap().count(secondId) > 0 ? &graph.nodesMap().at(secondId) : nullptr; + + return (firstNode != nullptr && firstNode->opType() == OpType_LOGIC && firstNode->opNum() == sd::logic::NextIteration) || (secondNode != nullptr && secondNode->opType() == OpType_LOGIC && secondNode->opNum() == sd::logic::NextIteration); +} + +static bool checkViability(Stack &stack, const std::pair &first, const std::pair &second) { + auto &frame = stack.back(); + auto &varSpace = const_cast(frame.variableProxy()); + + return (!frame.isDisabled(first.first) && varSpace.hasVariable(first) && varSpace.getVariable(first)->hasNDArray()) || (!frame.isDisabled(second.first) && varSpace.hasVariable(second) && varSpace.getVariable(second)->hasNDArray()); +} + +Nd4jStatus LogicMerge::processNode(const Node *node, Stack &stack, const OptimizedGraph& graph) { + // getting current frame first + auto &frame = stack.back(); + + const auto &inputs = node->inputs(); + auto &varSpace = const_cast(frame.variableProxy()); + + REQUIRE_TRUE(inputs.size() == 2, 0, "Merge: op expects exactly 2 inputs, but only %i defined", (int) inputs.size()); + + // if both inputs are unavailable - this node is disabled and must be disabled + if (!checkViability(stack, inputs[0], inputs[1])) { + nd4j_printf("Both inputs absent, skipping\n", ""); + LogicUtilities::disableBranch(frame, varSpace, graph, node); + return Status::OK(); + } + + if (frame.isDisabled(inputs[0].first) && frame.isDisabled(inputs[1].first)) { + REQUIRE_TRUE(false, 0, "Merge: only 1 input should be disabled, but got both of them down"); + } + + + + // if we're on NextIteration merge, we'll propagate its output regardless of first arg existence + if (isNextIterationCase(graph, inputs[0].first, inputs[1].first)) { + const auto firstNode = graph.nodesMap().count(inputs[0].first) > 0 ? &graph.nodesMap().at(inputs[0].first) : nullptr; + const auto secondNode = graph.nodesMap().count(inputs[1].first) > 0 ? &graph.nodesMap().at(inputs[1].first) : nullptr; + + if (firstNode != nullptr && firstNode->opType() == OpType_LOGIC && firstNode->opNum() == sd::logic::NextIteration) { + // we must check, if NextIteration Node already was executed. Or, pick initial value first + auto id = varSpace.hasVariable(inputs[0]) && varSpace.getVariable(inputs[0])->hasNDArray() ? inputs[0] : inputs[1]; + if (!varSpace.hasVariable(id) || !varSpace.getVariable(id)->hasNDArray()) + throw std::runtime_error("Non-existent NDArray requested: [" + StringUtils::valueToString(id) +"]"); + + varSpace.putVariable({node->id(), 0}, *varSpace.getVariable(id)->getNDArray()); + } else { + // we must check, if NextIteration Node already was executed. Or, pick initial value first + auto id = varSpace.hasVariable(inputs[1]) && varSpace.getVariable(inputs[1])->hasNDArray() ? inputs[1] : inputs[0]; + if (!varSpace.hasVariable(id) || !varSpace.getVariable(id)->hasNDArray()) + throw std::runtime_error("Non-existent NDArray requested: [" + StringUtils::valueToString(id) +"]"); + + varSpace.putVariable({node->id(), 0}, *varSpace.getVariable(id)->getNDArray()); + } + } else { + // we're getting first non-disabled input and propagate it + const auto &p = frame.isDisabled(inputs[0].first) ? inputs[1] : inputs[0]; + + REQUIRE_TRUE(frame.variableProxy().hasVariable(p), 0, "Merge: Variable [%i:%i] doesn't exist", p.first, p.second); + + varSpace.putVariable({node->id(), 0}, *varSpace.getVariable(p)->getNDArray()); + } + + return Status::OK(); +} + +} // namespace graph +} // namespace sd diff --git a/libnd4j/include/graph/logic/impl/LogicNextIteration.cpp b/libnd4j/include/graph/logic/impl/LogicNextIteration.cpp new file mode 100644 index 000000000000..86ecab4e54c1 --- /dev/null +++ b/libnd4j/include/graph/logic/impl/LogicNextIteration.cpp @@ -0,0 +1,44 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include + +namespace sd { +namespace graph { + +Nd4jStatus LogicNextIeration::processNode(const Node *node, Stack &stack, const OptimizedGraph& graph) { + auto &frame = stack.back(); + + const auto &inputs = node->inputs(); + auto &varSpace = const_cast(frame.variableProxy()); + + REQUIRE_TRUE(inputs.size() == 1, 0, "LoopCond: op must have exactly 1 input1"); + REQUIRE_TRUE(frame.variableProxy().hasVariable(inputs[0]), 0, "LoopCond: input Variable doesn't exist"); + + // Propagate Variable + auto var = varSpace.getVariable(inputs[0]); + varSpace.putVariable({node->id(), 0}, *var->getNDArray()); + + return Status::OK(); +} + +} // namespace graph +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/graph/logic/impl/LogicSwitch.cpp b/libnd4j/include/graph/logic/impl/LogicSwitch.cpp new file mode 100644 index 000000000000..53310a79a997 --- /dev/null +++ b/libnd4j/include/graph/logic/impl/LogicSwitch.cpp @@ -0,0 +1,65 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include +#include +#include +#include + +namespace sd { +namespace graph { + + +Nd4jStatus LogicSwitch::processNode(const Node* node, Stack &stack, const OptimizedGraph& graph) { + // getting current frame first + auto &frame = stack.back(); + + const auto &inputs = node->inputs(); + const auto &outputs = node->outputs(); + + auto &varSpace = const_cast(frame.variableProxy()); + + REQUIRE_TRUE(inputs.size() == 2, 0, "Switch: op must have exactly 2 inputs"); + REQUIRE_TRUE(varSpace.hasVariable(inputs[0]), 0, "Switch: input Variable doesn't exist"); + REQUIRE_TRUE(varSpace.hasVariable(inputs[1]), 0, "Switch: condition Variable doesn't exist"); + + auto input = varSpace.getVariable(inputs[0]); + auto boolean = varSpace.getVariable(inputs[1]); + + REQUIRE_TRUE(boolean->hasNDArray(), 0, "Switch: boolean Variable must have NDArray defined"); + + nd4j_printf("Switch [%i] evaluated as [%s]\n", node->id(), boolean->getNDArray()->e(0) ? "true" : "false"); + + if (boolean->getNDArray()->e(0)) { + // true branch + varSpace.putVariable(std::pair{node->id(), 1}, *input->getNDArray()); + LogicUtilities::disableBranch(frame, varSpace, graph, node, false); + } else { + // false branch + varSpace.putVariable(std::pair{node->id(), 0}, *input->getNDArray()); + LogicUtilities::disableBranch(frame, varSpace, graph, node, true); + } + + return Status::OK(); +}; + +} // namespace graph +} // namespace sd diff --git a/libnd4j/include/graph/logic/impl/LogicUtilities.cpp b/libnd4j/include/graph/logic/impl/LogicUtilities.cpp new file mode 100644 index 000000000000..1adca748a95a --- /dev/null +++ b/libnd4j/include/graph/logic/impl/LogicUtilities.cpp @@ -0,0 +1,77 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include + +namespace sd { +namespace graph { + +void LogicUtilities::disableBranch(StackFrame &frame, VariableProxy &varSpace, const OptimizedGraph &graph, const Node* node) { + const auto &outputs = node->outputs(); + + // we're going to disable certain external variables, if they depend on a current disabled node + // FIXME: it can be done in a better way rather than O(n^2) + for (const auto &var: varSpace.externalPaired()) { + for (const auto &d: var.second->dependencies()) { + if (d.first == node->id()) + frame.disableNode(var.second->id()); + } + } + + // we're going to roll through all consumers + for (const auto &o:outputs) { + if (graph.nodesMap().count(o.first) == 0) + throw std::runtime_error("pew-pew"); + + // now fetch disabled node + const auto &n = graph.nodesMap().at(o.first); + + // edge case here: don't disable Merge node + if (n.opType() == OpType_LOGIC && n.opNum() == sd::logic::Merge) + continue; + + // disable each consumer + frame.disableNode(o.first); + + // do recursive magic + disableBranch(frame, varSpace, graph, &n); + } +} + +void LogicUtilities::disableBranch(StackFrame &frame, VariableProxy &varSpace, const OptimizedGraph &graph, const Node* node, bool branchToDisable) { + const auto &outputs = node->outputs(); + int second = branchToDisable ? 1 : 0; + + for (const auto &o:outputs) { + if (o.second == second) { + frame.disableNode(o.first); + + if (graph.nodesMap().count(o.first) == 0) + throw std::runtime_error("pew-pew"); + + const auto &n = graph.nodesMap().at(o.first); + + disableBranch(frame, varSpace, graph, &n); + } + } +} + +} +} diff --git a/libnd4j/include/graph/execution/LogicExpose.h b/libnd4j/include/graph/optimization/GraphOptimizer.h similarity index 64% rename from libnd4j/include/graph/execution/LogicExpose.h rename to libnd4j/include/graph/optimization/GraphOptimizer.h index 046f3e64e9a6..5432b6f39d47 100644 --- a/libnd4j/include/graph/execution/LogicExpose.h +++ b/libnd4j/include/graph/optimization/GraphOptimizer.h @@ -1,5 +1,5 @@ /******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -15,25 +15,26 @@ ******************************************************************************/ // -// Created by raver119 on 12.11.2017. +// @author raver119@gmail.com // -#ifndef LIBND4J_LOGICEXPOSE_H -#define LIBND4J_LOGICEXPOSE_H +#ifndef SD_GRAPHOPTIMIZER_H +#define SD_GRAPHOPTIMIZER_H -#include -#include #include namespace sd { - namespace graph { - class LogicExpose { - public: - static Nd4jStatus processNode(Graph* graph, Node* node); - }; - } -} +namespace graph { +class SD_EXPORT GraphOptimizer { + public: + /** + * This method optimizes given Graph and returns independent cloned Graph + * @param graph + * @return + */ + static Graph* optimize(const Graph& graph); +}; +} // namespace graph +} // namespace sd - - -#endif //LIBND4J_LOGICEXPOSE_H +#endif // SD_GRAPHOPTIMIZER_H diff --git a/libnd4j/include/graph/optimization/NodeOptimizer.h b/libnd4j/include/graph/optimization/NodeOptimizer.h new file mode 100644 index 000000000000..8fb39c012898 --- /dev/null +++ b/libnd4j/include/graph/optimization/NodeOptimizer.h @@ -0,0 +1,59 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#ifndef SD_NODEOPTIMIZER_H +#define SD_NODEOPTIMIZER_H + +#include +#include + +#include + +namespace sd { +namespace graph { +/** + * This abstract class defines basic methods needed for Inputs/Outputs + * optimizations. I.e. weight format changes or data types changes for a + * specific backend + */ +class SD_EXPORT NodeOptimizer { + protected: + std::string _target = {}; + + public: + NodeOptimizer() = default; + virtual ~NodeOptimizer() = default; + + /** + * This method applu + * @param node + */ + virtual void optimize(Node& node) = 0; + + /** + * This method returns target Op name for this optimizer + * @return + */ + const std::string& targetOp() const; +}; +} // namespace graph +} // namespace sd + +#endif // DEV_TESTS_NODEOPTIMIZER_H diff --git a/libnd4j/include/graph/optimization/cudnn/README.md b/libnd4j/include/graph/optimization/cudnn/README.md new file mode 100644 index 000000000000..7c872fe4132c --- /dev/null +++ b/libnd4j/include/graph/optimization/cudnn/README.md @@ -0,0 +1 @@ +This folder contains NodeOptimizer implementations for cuDNN \ No newline at end of file diff --git a/libnd4j/include/graph/optimization/generic/README.md b/libnd4j/include/graph/optimization/generic/README.md new file mode 100644 index 000000000000..535e905f3fb4 --- /dev/null +++ b/libnd4j/include/graph/optimization/generic/README.md @@ -0,0 +1 @@ +This folder contains generic NodeOptimizer implementations suitable for all platforms \ No newline at end of file diff --git a/libnd4j/include/graph/optimization/impl/GraphOptimizer.cpp b/libnd4j/include/graph/optimization/impl/GraphOptimizer.cpp new file mode 100644 index 000000000000..1170b0b6191c --- /dev/null +++ b/libnd4j/include/graph/optimization/impl/GraphOptimizer.cpp @@ -0,0 +1,33 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include + +namespace sd { +namespace graph { +Graph *GraphOptimizer::optimize(const Graph &graph) { + auto clone = graph.clone(); + + // TODO: implement this method + + return clone; +} +} // namespace graph +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/graph/optimization/impl/NodeOptimizer.cpp b/libnd4j/include/graph/optimization/impl/NodeOptimizer.cpp new file mode 100644 index 000000000000..b8ed74ef84ae --- /dev/null +++ b/libnd4j/include/graph/optimization/impl/NodeOptimizer.cpp @@ -0,0 +1,27 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include + +namespace sd { +namespace graph { +const std::string &NodeOptimizer::targetOp() const { return _target; } +} // namespace graph +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/graph/optimization/mkldnn/README.md b/libnd4j/include/graph/optimization/mkldnn/README.md new file mode 100644 index 000000000000..1cedbb91e007 --- /dev/null +++ b/libnd4j/include/graph/optimization/mkldnn/README.md @@ -0,0 +1 @@ +This folder contains NodeOptimizer implementations for MKL-DNN \ No newline at end of file diff --git a/libnd4j/include/graph/profiling/GraphProfile.h b/libnd4j/include/graph/profiling/GraphProfile.h index f0ada4f90f19..5c89ce2ff13e 100644 --- a/libnd4j/include/graph/profiling/GraphProfile.h +++ b/libnd4j/include/graph/profiling/GraphProfile.h @@ -21,100 +21,105 @@ #ifndef ND4J_GRAPH_PROFILE_H #define ND4J_GRAPH_PROFILE_H -#include "NodeProfile.h" -#include #include -#include -#include -#include +#include + #include +#include +#include +#include + +#include "NodeProfile.h" namespace sd { - namespace graph { - class ND4J_EXPORT GraphProfile { - private: - // this variable - Nd4jLong _merges = 1L; - - /** - * This is global memory values - */ - Nd4jLong _memoryTotal = 0L; - Nd4jLong _memoryActivations = 0L; - Nd4jLong _memoryTemporary = 0L; - Nd4jLong _memoryObjects = 0L; - - // time spent for graph construction - Nd4jLong _buildTime = 0L; - - // time spent for graph execution - Nd4jLong _executionTime = 0L; - - // collection of pointers to profile results - std::vector _profiles; - std::map _profilesById; - - // collection of various timing reports - std::map _timings; - std::chrono::time_point _last; - - std::map> _timers; - - void updateLast(); - public: - GraphProfile(); - ~GraphProfile(); - - /** - * These methods just adding amount of bytes to various counters - */ - void addToTotal(Nd4jLong bytes); - void addToActivations(Nd4jLong bytes); - void addToTemporary(Nd4jLong bytes); - void addToObjects(Nd4jLong bytes); - - /** - * This method allows to set graph construction (i.e. deserialization) time in nanoseconds - */ - void setBuildTime(Nd4jLong nanos); - - /** - * This method sets graph execution time in nanoseconds. - */ - void setExecutionTime(Nd4jLong nanos); - - void startEvent(const char *name); - void recordEvent(const char *name); - void deleteEvent(const char *name); - - /** - * This method saves time as delta from last saved time - */ - void spotEvent(const char *name); - - /** - * This method returns pointer to NodeProfile by ID - * PLEASE NOTE: this method will create new NodeProfile if there's none - */ - NodeProfile* nodeById(int id, const char *name = nullptr); - bool nodeExists(int id); - - /** - * This method merges values from other profile report - * @param other - */ - void merge(GraphProfile *other); - void assign(GraphProfile *other); - - /** - * These methods are just utility methods for time - */ - static Nd4jLong currentTime(); - static Nd4jLong relativeTime(Nd4jLong time); - - void printOut(); - }; - } -} +namespace graph { +class SD_EXPORT GraphProfile { + private: + // this variable + Nd4jLong _merges = 1L; + + /** + * This is global memory values + */ + Nd4jLong _memoryTotal = 0L; + Nd4jLong _memoryActivations = 0L; + Nd4jLong _memoryTemporary = 0L; + Nd4jLong _memoryObjects = 0L; + + // time spent for graph construction + Nd4jLong _buildTime = 0L; + + // time spent for graph execution + Nd4jLong _executionTime = 0L; + + // collection of pointers to profile results + std::vector _profiles; + std::map _profilesById; + + // collection of various timing reports + std::map _timings; + std::chrono::time_point _last; + + std::map> + _timers; + + void updateLast(); + + public: + GraphProfile(); + ~GraphProfile(); + + /** + * These methods just adding amount of bytes to various counters + */ + void addToTotal(Nd4jLong bytes); + void addToActivations(Nd4jLong bytes); + void addToTemporary(Nd4jLong bytes); + void addToObjects(Nd4jLong bytes); + + /** + * This method allows to set graph construction (i.e. deserialization) time in + * nanoseconds + */ + void setBuildTime(Nd4jLong nanos); + + /** + * This method sets graph execution time in nanoseconds. + */ + void setExecutionTime(Nd4jLong nanos); + + void startEvent(const char *name); + void recordEvent(const char *name); + void deleteEvent(const char *name); + + /** + * This method saves time as delta from last saved time + */ + void spotEvent(const char *name); + + /** + * This method returns pointer to NodeProfile by ID + * PLEASE NOTE: this method will create new NodeProfile if there's none + */ + NodeProfile *nodeById(int id, const char *name = nullptr); + bool nodeExists(int id); + + /** + * This method merges values from other profile report + * @param other + */ + void merge(GraphProfile *other); + void assign(GraphProfile *other); + + /** + * These methods are just utility methods for time + */ + static Nd4jLong currentTime(); + static Nd4jLong relativeTime(Nd4jLong time); + + void printOut(); +}; +} // namespace graph +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/graph/profiling/GraphProfilingHelper.h b/libnd4j/include/graph/profiling/GraphProfilingHelper.h index d32d99374660..e87f6a509663 100644 --- a/libnd4j/include/graph/profiling/GraphProfilingHelper.h +++ b/libnd4j/include/graph/profiling/GraphProfilingHelper.h @@ -21,17 +21,17 @@ #ifndef LIBND4J_GRAPHPROFILINGHELPER_H #define LIBND4J_GRAPHPROFILINGHELPER_H - #include + #include "GraphProfile.h" namespace sd { - namespace graph { - class GraphProfilingHelper { - public: - static GraphProfile* profile(Graph *graph, int iterations); - }; - } -} +namespace graph { +class GraphProfilingHelper { + public: + static GraphProfile* profile(Graph* graph, int iterations); +}; +} // namespace graph +} // namespace sd -#endif //LIBND4J_GRAPHPROFILINGHELPER_H +#endif // LIBND4J_GRAPHPROFILINGHELPER_H diff --git a/libnd4j/include/graph/profiling/NodeProfile.h b/libnd4j/include/graph/profiling/NodeProfile.h index 83f0b88fca0b..6c1f639ce726 100644 --- a/libnd4j/include/graph/profiling/NodeProfile.h +++ b/libnd4j/include/graph/profiling/NodeProfile.h @@ -21,91 +21,93 @@ #ifndef LIBND4J_NODE_PROFILE_H #define LIBND4J_NODE_PROFILE_H -#include #include +#include + #include #include namespace sd { - namespace graph { - class ND4J_EXPORT NodeProfile { - private: - int _id; - std::string _name; +namespace graph { +class SD_EXPORT NodeProfile { + private: + int _id; + std::string _name; + + Nd4jLong _merges = 1L; + + // time spent during deserialization + Nd4jLong _buildTime = 0L; - Nd4jLong _merges = 1L; + // time spent before op execution + Nd4jLong _preparationTime = 0L; - // time spent during deserialization - Nd4jLong _buildTime = 0L; - - // time spent before op execution - Nd4jLong _preparationTime = 0L; + // time spent for op execution + Nd4jLong _executionTime = 0L; - // time spent for op execution - Nd4jLong _executionTime = 0L; + // total time spent during node execution + Nd4jLong _totalTime = 0L; - // total time spent during node execution - Nd4jLong _totalTime = 0L; + // time spent for output shape creation + Nd4jLong _shapeTime = 0L; - // time spent for output shape creation - Nd4jLong _shapeTime = 0L; + // time spent for output arrays creation + Nd4jLong _arrayTime = 0L; - // time spent for output arrays creation - Nd4jLong _arrayTime = 0L; + Nd4jLong _inputTime = 0L; - Nd4jLong _inputTime = 0L; + // amount of memory used for outputs + Nd4jLong _memoryActivations = 0L; - // amount of memory used for outputs - Nd4jLong _memoryActivations = 0L; + // amount of memory used internally for temporary arrays + Nd4jLong _memoryTemporary = 0L; - // amount of memory used internally for temporary arrays - Nd4jLong _memoryTemporary = 0L; + // amount of memory used internally for objects + Nd4jLong _memoryObjects = 0L; - // amount of memory used internally for objects - Nd4jLong _memoryObjects = 0L; + // total amount of memory used during execution + Nd4jLong _memoryTotal = 0L; - // total amount of memory used during execution - Nd4jLong _memoryTotal = 0L; + std::vector _inputShapes; + std::vector _outputShapes; - std::vector _inputShapes; - std::vector _outputShapes; - public: - NodeProfile() = default; - ~NodeProfile() = default; + public: + NodeProfile() = default; + ~NodeProfile() = default; - explicit NodeProfile(int id, const char *name); + explicit NodeProfile(int id, const char* name); - void setBuildTime(Nd4jLong time); - void setPreparationTime(Nd4jLong time); - void setExecutionTime(Nd4jLong time); - void setTotalTime(Nd4jLong time); - void setShapeFunctionTime(Nd4jLong time); - void setArrayTime(Nd4jLong time); - void setInputTime(Nd4jLong time); + void setBuildTime(Nd4jLong time); + void setPreparationTime(Nd4jLong time); + void setExecutionTime(Nd4jLong time); + void setTotalTime(Nd4jLong time); + void setShapeFunctionTime(Nd4jLong time); + void setArrayTime(Nd4jLong time); + void setInputTime(Nd4jLong time); - void setActivationsSize(Nd4jLong bytes); - void setTemporarySize(Nd4jLong bytes); - void setObjectsSize(Nd4jLong bytes); - void setTotalSize(Nd4jLong bytes); + void setActivationsSize(Nd4jLong bytes); + void setTemporarySize(Nd4jLong bytes); + void setObjectsSize(Nd4jLong bytes); + void setTotalSize(Nd4jLong bytes); - void addInputShape(Nd4jLong const* shapeInfo); - void addOutputShape(Nd4jLong const* shapeInfo); + void addInputShape(Nd4jLong const* shapeInfo); + void addOutputShape(Nd4jLong const* shapeInfo); - Nd4jLong getActivationsSize() const; - Nd4jLong getTemporarySize() const; - Nd4jLong getObjectsSize() const; - Nd4jLong getTotalSize() const; + Nd4jLong getActivationsSize() const; + Nd4jLong getTemporarySize() const; + Nd4jLong getObjectsSize() const; + Nd4jLong getTotalSize() const; - Nd4jLong getExecutionTime() const; + Nd4jLong getExecutionTime() const; - std::string& name(); + std::string& name(); - void merge(NodeProfile *other); - void assign(NodeProfile *other); + void merge(NodeProfile* other); + void assign(NodeProfile* other); - void printOut(); - }; - } -} + void printOut(); +}; +} // namespace graph +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/graph/profiling/impl/GraphProfile.cpp b/libnd4j/include/graph/profiling/impl/GraphProfile.cpp index 14ce54a0f645..bb37c0764c78 100644 --- a/libnd4j/include/graph/profiling/impl/GraphProfile.cpp +++ b/libnd4j/include/graph/profiling/impl/GraphProfile.cpp @@ -20,198 +20,183 @@ #include #include -#include #include + #include +#include namespace sd { - namespace graph { - GraphProfile::GraphProfile() { - updateLast(); - } - - GraphProfile::~GraphProfile() { - // releasing NodeProfile pointers - for (auto v: _profiles) - delete v; - - _timings.clear(); - } - - void GraphProfile::addToTotal(Nd4jLong bytes) { - _memoryTotal += bytes; - } - - void GraphProfile::addToActivations(Nd4jLong bytes) { - _memoryActivations += bytes; - } - - void GraphProfile::addToTemporary(Nd4jLong bytes) { - _memoryTemporary += bytes; - } - - void GraphProfile::addToObjects(Nd4jLong bytes) { - _memoryObjects += bytes; - } - - void GraphProfile::setBuildTime(Nd4jLong nanos) { - _buildTime = nanos; - } - - void GraphProfile::setExecutionTime(Nd4jLong nanos) { - _executionTime = nanos; - } - - - Nd4jLong GraphProfile::currentTime() { - auto t = std::chrono::system_clock::now(); - auto v = std::chrono::time_point_cast (t); - auto epoch = v.time_since_epoch(); - return (Nd4jLong) std::chrono::duration_cast(epoch).count(); - } - - Nd4jLong GraphProfile::relativeTime(Nd4jLong time) { - auto t1 = currentTime(); - return t1 - time; - } - - void GraphProfile::updateLast() { - _last = std::chrono::system_clock::now(); - } - - void GraphProfile::startEvent(const char *name) { - std::string k = name; - _timers[k] = std::chrono::system_clock::now(); - } - - void GraphProfile::recordEvent(const char *name) { - std::string k = name; - if (_timers.count(k) == 0) { - nd4j_printf("Can't find timer key: [%s]", name); - throw std::runtime_error("Missing timer key"); - } - auto t0 = _timers[k]; - auto t1 = std::chrono::system_clock::now(); - auto v = (Nd4jLong) std::chrono::duration_cast(t1 - t0).count(); - - _timings[k] = v; - _timers.erase(k); - } - - void GraphProfile::deleteEvent(const char *name) { - std::string k = name; - _timers.erase(k); - } - - void GraphProfile::spotEvent(const char *name) { - auto t = std::chrono::system_clock::now(); - auto d = (Nd4jLong) std::chrono::duration_cast(t - _last).count(); - std::string k = name; - _timings[k] = d; - updateLast(); - } - - NodeProfile* GraphProfile::nodeById(int id, const char *name) { - if (_profilesById.count(id) == 0) { - auto node = new NodeProfile(id, name); - _profiles.emplace_back(node); - _profilesById[id] = node; - return node; - } - - return _profilesById[id]; - } - - void GraphProfile::merge(GraphProfile *other) { - _merges += other->_merges; - _memoryActivations += other->_memoryActivations; - _memoryTemporary += other->_memoryTemporary; - _memoryTotal += other->_memoryTotal; - _memoryObjects += other->_memoryObjects; - - _executionTime += other->_executionTime; - _buildTime += other->_buildTime; - - - for (auto v:_profilesById) { - if (!other->nodeExists(v.first)) - continue; - - v.second->merge(other->nodeById(v.first)); - } - } - - void GraphProfile::assign(GraphProfile *other) { - _merges = other->_merges; - _memoryActivations = other->_memoryActivations; - _memoryTemporary = other->_memoryTemporary; - _memoryTotal = other->_memoryTotal; - _memoryObjects = other->_memoryObjects; - - _executionTime = other->_executionTime; - _buildTime = other->_buildTime; - - - for (auto v: other->_profilesById) { - nodeById(v.first, v.second->name().c_str())->assign(v.second); - } - } - - bool GraphProfile::nodeExists(int id) { - return _profilesById.count(id) > 0; - } - - void GraphProfile::printOut() { - nd4j_printf("Graph profile: %i executions\n", _merges); - nd4j_printf("\nMemory:\n", ""); - - Nd4jLong tmp = 0L; - Nd4jLong obj = 0L; - Nd4jLong act = 0L; - Nd4jLong ttl = 0L; - for (auto v: _profiles) { - tmp += v->getTemporarySize(); - obj += v->getObjectsSize(); - act += v->getActivationsSize(); - ttl += v->getTotalSize(); - } - - nd4j_printf("ACT: %lld; TMP: %lld; OBJ: %lld; TTL: %lld;\n", act / _merges, tmp / _merges, obj / _merges, ttl / _merges); - - nd4j_printf("\nTime:\n", ""); - nd4j_printf("Construction time: %lld ns;\n", _buildTime / _merges); - nd4j_printf("Execution time: %lld ns;\n", _executionTime / _merges); - - nd4j_printf("\nPer-node reports:\n", ""); - if (_profiles.empty()) - nd4j_printf("No nodes in graph\n",""); - - // printint out stuff - std::vector sorted; - for (auto v: _profiles) { - v->printOut(); - sorted.emplace_back(v); - } - - if (_profiles.size() > 1) { - // building hot spots - std::sort(sorted.begin(), sorted.end(), [](const NodeProfile *a, const NodeProfile *b) -> bool { - return a->getExecutionTime() > b->getExecutionTime(); - }); - - nd4j_printf("\nTop 50 reports by EXEC:\n", ""); - auto limit = sd::math::nd4j_min(50, sorted.size()); - for (int e = 0; e < limit; e++) { - sorted[e]->printOut(); - } - } - - nd4j_printf("\nSpecial timers:\n", ""); - if (_timings.empty()) - nd4j_printf("No special timers were set\n",""); - - for (auto v: _timings) - nd4j_printf("%s: %lld ns;\n", v.first.c_str(), v.second); - } +namespace graph { +GraphProfile::GraphProfile() { updateLast(); } + +GraphProfile::~GraphProfile() { + // releasing NodeProfile pointers + for (auto v : _profiles) delete v; + + _timings.clear(); +} + +void GraphProfile::addToTotal(Nd4jLong bytes) { _memoryTotal += bytes; } + +void GraphProfile::addToActivations(Nd4jLong bytes) { + _memoryActivations += bytes; +} + +void GraphProfile::addToTemporary(Nd4jLong bytes) { _memoryTemporary += bytes; } + +void GraphProfile::addToObjects(Nd4jLong bytes) { _memoryObjects += bytes; } + +void GraphProfile::setBuildTime(Nd4jLong nanos) { _buildTime = nanos; } + +void GraphProfile::setExecutionTime(Nd4jLong nanos) { _executionTime = nanos; } + +Nd4jLong GraphProfile::currentTime() { + auto t = std::chrono::system_clock::now(); + auto v = std::chrono::time_point_cast(t); + auto epoch = v.time_since_epoch(); + return (Nd4jLong)std::chrono::duration_cast(epoch) + .count(); +} + +Nd4jLong GraphProfile::relativeTime(Nd4jLong time) { + auto t1 = currentTime(); + return t1 - time; +} + +void GraphProfile::updateLast() { _last = std::chrono::system_clock::now(); } + +void GraphProfile::startEvent(const char *name) { + std::string k = name; + _timers[k] = std::chrono::system_clock::now(); +} + +void GraphProfile::recordEvent(const char *name) { + std::string k = name; + if (_timers.count(k) == 0) { + nd4j_printf("Can't find timer key: [%s]", name); + throw std::runtime_error("Missing timer key"); + } + auto t0 = _timers[k]; + auto t1 = std::chrono::system_clock::now(); + auto v = + (Nd4jLong)std::chrono::duration_cast(t1 - t0) + .count(); + + _timings[k] = v; + _timers.erase(k); +} + +void GraphProfile::deleteEvent(const char *name) { + std::string k = name; + _timers.erase(k); +} + +void GraphProfile::spotEvent(const char *name) { + auto t = std::chrono::system_clock::now(); + auto d = + (Nd4jLong)std::chrono::duration_cast(t - _last) + .count(); + std::string k = name; + _timings[k] = d; + updateLast(); +} + +NodeProfile *GraphProfile::nodeById(int id, const char *name) { + if (_profilesById.count(id) == 0) { + auto node = new NodeProfile(id, name); + _profiles.emplace_back(node); + _profilesById[id] = node; + return node; + } + + return _profilesById[id]; +} + +void GraphProfile::merge(GraphProfile *other) { + _merges += other->_merges; + _memoryActivations += other->_memoryActivations; + _memoryTemporary += other->_memoryTemporary; + _memoryTotal += other->_memoryTotal; + _memoryObjects += other->_memoryObjects; + + _executionTime += other->_executionTime; + _buildTime += other->_buildTime; + + for (auto v : _profilesById) { + if (!other->nodeExists(v.first)) continue; + + v.second->merge(other->nodeById(v.first)); + } +} + +void GraphProfile::assign(GraphProfile *other) { + _merges = other->_merges; + _memoryActivations = other->_memoryActivations; + _memoryTemporary = other->_memoryTemporary; + _memoryTotal = other->_memoryTotal; + _memoryObjects = other->_memoryObjects; + + _executionTime = other->_executionTime; + _buildTime = other->_buildTime; + + for (auto v : other->_profilesById) { + nodeById(v.first, v.second->name().c_str())->assign(v.second); + } +} + +bool GraphProfile::nodeExists(int id) { return _profilesById.count(id) > 0; } + +void GraphProfile::printOut() { + nd4j_printf("Graph profile: %i executions\n", _merges); + nd4j_printf("\nMemory:\n", ""); + + Nd4jLong tmp = 0L; + Nd4jLong obj = 0L; + Nd4jLong act = 0L; + Nd4jLong ttl = 0L; + for (auto v : _profiles) { + tmp += v->getTemporarySize(); + obj += v->getObjectsSize(); + act += v->getActivationsSize(); + ttl += v->getTotalSize(); + } + + nd4j_printf("ACT: %lld; TMP: %lld; OBJ: %lld; TTL: %lld;\n", act / _merges, + tmp / _merges, obj / _merges, ttl / _merges); + + nd4j_printf("\nTime:\n", ""); + nd4j_printf("Construction time: %lld ns;\n", _buildTime / _merges); + nd4j_printf("Execution time: %lld ns;\n", _executionTime / _merges); + + nd4j_printf("\nPer-node reports:\n", ""); + if (_profiles.empty()) nd4j_printf("No nodes in graph\n", ""); + + // printint out stuff + std::vector sorted; + for (auto v : _profiles) { + v->printOut(); + sorted.emplace_back(v); + } + + if (_profiles.size() > 1) { + // building hot spots + std::sort(sorted.begin(), sorted.end(), + [](const NodeProfile *a, const NodeProfile *b) -> bool { + return a->getExecutionTime() > b->getExecutionTime(); + }); + + nd4j_printf("\nTop 50 reports by EXEC:\n", ""); + auto limit = sd::math::nd4j_min(50, sorted.size()); + for (int e = 0; e < limit; e++) { + sorted[e]->printOut(); } -} \ No newline at end of file + } + + nd4j_printf("\nSpecial timers:\n", ""); + if (_timings.empty()) nd4j_printf("No special timers were set\n", ""); + + for (auto v : _timings) + nd4j_printf("%s: %lld ns;\n", v.first.c_str(), v.second); +} +} // namespace graph +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/graph/profiling/impl/GraphProfilingHelper.cpp b/libnd4j/include/graph/profiling/impl/GraphProfilingHelper.cpp index 03c2411e28d4..bf72392e21a4 100644 --- a/libnd4j/include/graph/profiling/impl/GraphProfilingHelper.cpp +++ b/libnd4j/include/graph/profiling/impl/GraphProfilingHelper.cpp @@ -19,53 +19,53 @@ // #include -#include namespace sd { - namespace graph { - GraphProfile *GraphProfilingHelper::profile(Graph *graph, int iterations) { +namespace graph { +GraphProfile *GraphProfilingHelper::profile(Graph *graph, int iterations) { + if (1 > 0) + throw std::runtime_error( + "GraphProfilingHelper::profile - Not implemented yet"); - // saving original workspace - auto varSpace = graph->getVariableSpace()->clone(); + // saving original workspace + // auto varSpace = graph->variableSpace(); - // printing out graph structure - // graph->printOut(); + // printing out graph structure + // graph->printOut(); - // warm up - for (int e = 0; e < iterations; e++) { - FlowPath fp; + // warm up + for (int e = 0; e < iterations; e++) { + FlowPath fp; - auto _vs = varSpace->clone(); - //_vs->workspace()->expandTo(100000); - _vs->setFlowPath(&fp); - GraphExecutioner::execute(graph, _vs); + // auto _vs = varSpace->clone(); + //_vs->workspace()->expandTo(100000); + //_vs->setFlowPath(&fp); + // GraphExecutioner::execute(graph, _vs); - delete _vs; - } + // delete _vs; + } + auto profile = new GraphProfile(); + for (int e = 0; e < iterations; e++) { + FlowPath fp; + /* + // we're always starting from "fresh" varspace here + auto _vs = varSpace->clone(); + //_vs->workspace()->expandTo(100000); + _vs->setFlowPath(&fp); + //GraphExecutioner::execute(graph, _vs); - auto profile = new GraphProfile(); - for (int e = 0; e < iterations; e++) { - FlowPath fp; + auto p = fp.profile(); + if (e == 0) + profile->assign(p); + else + profile->merge(p); - // we're always starting from "fresh" varspace here - auto _vs = varSpace->clone(); - //_vs->workspace()->expandTo(100000); - _vs->setFlowPath(&fp); - GraphExecutioner::execute(graph, _vs); + delete _vs; + */ + } - auto p = fp.profile(); - if (e == 0) - profile->assign(p); - else - profile->merge(p); - - delete _vs; - } - - delete varSpace; - - return profile; - } - } + return profile; } +} // namespace graph +} // namespace sd diff --git a/libnd4j/include/graph/profiling/impl/NodeProfile.cpp b/libnd4j/include/graph/profiling/impl/NodeProfile.cpp index 8db4472e6bc9..b78d8535ee93 100644 --- a/libnd4j/include/graph/profiling/impl/NodeProfile.cpp +++ b/libnd4j/include/graph/profiling/impl/NodeProfile.cpp @@ -18,150 +18,119 @@ // @author raver119@gmail.com // -#include #include #include +#include namespace sd { - namespace graph { - NodeProfile::NodeProfile(int id, const char *name) { - _id = id; - - if (name != nullptr) - _name = name; - }; - - void NodeProfile::printOut() { - nd4j_printf("Node: <%i:%s>\n", _id, _name.c_str()); - nd4j_printf(" Memory: ACT: %lld; TMP: %lld; OBJ: %lld; TTL: %lld;\n", _memoryActivations / _merges, _memoryTemporary / _merges, _memoryObjects / _merges, _memoryTotal / _merges); - nd4j_printf(" Time: PREP: %lld ns; EXEC: %lld ns; TTL: %lld ns;\n", _preparationTime / _merges, _executionTime / _merges, _totalTime / _merges); - nd4j_printf(" PREP: INPUT: %lld ns; SHAPE: %lld ns; ARRAY: %lld ns;\n", _inputTime / _merges, _shapeTime / _merges, _arrayTime / _merges); - - std::string inputs; - std::string outputs; - - int cnt = 0; - for (const auto &v: _inputShapes) - inputs += v + " "; - - for (const auto &v: _outputShapes) - outputs += v + " "; - - - nd4j_printf(" Inputs: %s\n", inputs.c_str()); - nd4j_printf(" Outputs: %s\n", outputs.c_str()); - }; - - Nd4jLong NodeProfile::getActivationsSize() const { - return _memoryActivations; - } - - void NodeProfile::setShapeFunctionTime(Nd4jLong time) { - _shapeTime = time; - } - - void NodeProfile::setArrayTime(Nd4jLong time) { - _arrayTime = time; - } - - void NodeProfile::setInputTime(Nd4jLong time) { - _inputTime = time; - } - - Nd4jLong NodeProfile::getTemporarySize() const{ - return _memoryTemporary; - } - - Nd4jLong NodeProfile::getObjectsSize() const{ - return _memoryObjects; - } - - Nd4jLong NodeProfile::getTotalSize() const{ - return _memoryTotal; - } - - void NodeProfile::setBuildTime(Nd4jLong time) { - _buildTime = time; - } - - void NodeProfile::setPreparationTime(Nd4jLong time) { - _preparationTime = time; - } - - void NodeProfile::setExecutionTime(Nd4jLong time) { - _executionTime = time; - } - - void NodeProfile::setTotalTime(Nd4jLong time) { - _totalTime = time; - } - - void NodeProfile::setActivationsSize(Nd4jLong bytes) { - _memoryActivations = bytes; - } - - void NodeProfile::setTemporarySize(Nd4jLong bytes) { - _memoryTemporary = bytes; - } - - void NodeProfile::setObjectsSize(Nd4jLong bytes) { - _memoryObjects = bytes; - } - - void NodeProfile::setTotalSize(Nd4jLong bytes) { - _memoryTotal = bytes; - } - - Nd4jLong NodeProfile::getExecutionTime() const { - return _executionTime; - } - - void NodeProfile::addInputShape(Nd4jLong const* shapeInfo) { - _inputShapes.emplace_back(ShapeUtils::shapeInfoAsString(shapeInfo)); - } - - void NodeProfile::addOutputShape(Nd4jLong const*shapeInfo) { - _outputShapes.emplace_back(ShapeUtils::shapeInfoAsString(shapeInfo)); - } - - void NodeProfile::merge(NodeProfile *other) { - _merges += other->_merges; - _memoryObjects += other->_memoryObjects; - _memoryActivations += other->_memoryActivations; - _memoryTemporary += other->_memoryTemporary; - _memoryTotal += other->_memoryTotal; - - _preparationTime += other->_preparationTime; - _executionTime += other->_executionTime; - _totalTime += other->_totalTime; - _shapeTime += other->_shapeTime; - _arrayTime += other->_arrayTime; - _inputTime += other->_inputTime; - - _inputShapes = other->_inputShapes; - _outputShapes = other->_outputShapes; - } - - std::string& NodeProfile::name() { - return _name; - } - - void NodeProfile::assign(NodeProfile *other) { - _merges = other->_merges; - _memoryObjects = other->_memoryObjects; - _memoryActivations = other->_memoryActivations; - _memoryTemporary = other->_memoryTemporary; - _memoryTotal = other->_memoryTotal; - - _preparationTime = other->_preparationTime; - _executionTime = other->_executionTime; - _totalTime = other->_totalTime; - _shapeTime = other->_shapeTime; - _arrayTime = other->_arrayTime; - _inputTime = other->_inputTime; - - _inputShapes = other->_inputShapes; - _outputShapes = other->_outputShapes; - } - } -} \ No newline at end of file +namespace graph { +NodeProfile::NodeProfile(int id, const char *name) { + _id = id; + + if (name != nullptr) _name = name; +}; + +void NodeProfile::printOut() { + nd4j_printf("Node: <%i:%s>\n", _id, _name.c_str()); + nd4j_printf(" Memory: ACT: %lld; TMP: %lld; OBJ: %lld; TTL: %lld;\n", + _memoryActivations / _merges, _memoryTemporary / _merges, + _memoryObjects / _merges, _memoryTotal / _merges); + nd4j_printf(" Time: PREP: %lld ns; EXEC: %lld ns; TTL: %lld ns;\n", + _preparationTime / _merges, _executionTime / _merges, + _totalTime / _merges); + nd4j_printf(" PREP: INPUT: %lld ns; SHAPE: %lld ns; ARRAY: %lld ns;\n", + _inputTime / _merges, _shapeTime / _merges, _arrayTime / _merges); + + std::string inputs; + std::string outputs; + + int cnt = 0; + for (const auto &v : _inputShapes) inputs += v + " "; + + for (const auto &v : _outputShapes) outputs += v + " "; + + nd4j_printf(" Inputs: %s\n", inputs.c_str()); + nd4j_printf(" Outputs: %s\n", outputs.c_str()); +}; + +Nd4jLong NodeProfile::getActivationsSize() const { return _memoryActivations; } + +void NodeProfile::setShapeFunctionTime(Nd4jLong time) { _shapeTime = time; } + +void NodeProfile::setArrayTime(Nd4jLong time) { _arrayTime = time; } + +void NodeProfile::setInputTime(Nd4jLong time) { _inputTime = time; } + +Nd4jLong NodeProfile::getTemporarySize() const { return _memoryTemporary; } + +Nd4jLong NodeProfile::getObjectsSize() const { return _memoryObjects; } + +Nd4jLong NodeProfile::getTotalSize() const { return _memoryTotal; } + +void NodeProfile::setBuildTime(Nd4jLong time) { _buildTime = time; } + +void NodeProfile::setPreparationTime(Nd4jLong time) { _preparationTime = time; } + +void NodeProfile::setExecutionTime(Nd4jLong time) { _executionTime = time; } + +void NodeProfile::setTotalTime(Nd4jLong time) { _totalTime = time; } + +void NodeProfile::setActivationsSize(Nd4jLong bytes) { + _memoryActivations = bytes; +} + +void NodeProfile::setTemporarySize(Nd4jLong bytes) { _memoryTemporary = bytes; } + +void NodeProfile::setObjectsSize(Nd4jLong bytes) { _memoryObjects = bytes; } + +void NodeProfile::setTotalSize(Nd4jLong bytes) { _memoryTotal = bytes; } + +Nd4jLong NodeProfile::getExecutionTime() const { return _executionTime; } + +void NodeProfile::addInputShape(Nd4jLong const *shapeInfo) { + _inputShapes.emplace_back(ShapeUtils::shapeInfoAsString(shapeInfo)); +} + +void NodeProfile::addOutputShape(Nd4jLong const *shapeInfo) { + _outputShapes.emplace_back(ShapeUtils::shapeInfoAsString(shapeInfo)); +} + +void NodeProfile::merge(NodeProfile *other) { + _merges += other->_merges; + _memoryObjects += other->_memoryObjects; + _memoryActivations += other->_memoryActivations; + _memoryTemporary += other->_memoryTemporary; + _memoryTotal += other->_memoryTotal; + + _preparationTime += other->_preparationTime; + _executionTime += other->_executionTime; + _totalTime += other->_totalTime; + _shapeTime += other->_shapeTime; + _arrayTime += other->_arrayTime; + _inputTime += other->_inputTime; + + _inputShapes = other->_inputShapes; + _outputShapes = other->_outputShapes; +} + +std::string &NodeProfile::name() { return _name; } + +void NodeProfile::assign(NodeProfile *other) { + _merges = other->_merges; + _memoryObjects = other->_memoryObjects; + _memoryActivations = other->_memoryActivations; + _memoryTemporary = other->_memoryTemporary; + _memoryTotal = other->_memoryTotal; + + _preparationTime = other->_preparationTime; + _executionTime = other->_executionTime; + _totalTime = other->_totalTime; + _shapeTime = other->_shapeTime; + _arrayTime = other->_arrayTime; + _inputTime = other->_inputTime; + + _inputShapes = other->_inputShapes; + _outputShapes = other->_outputShapes; +} +} // namespace graph +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/helpers/ArrayUtils.h b/libnd4j/include/helpers/ArrayUtils.h index 2ecebeb4aa98..6392cce00d32 100644 --- a/libnd4j/include/helpers/ArrayUtils.h +++ b/libnd4j/include/helpers/ArrayUtils.h @@ -21,23 +21,23 @@ #ifndef LIBND4J_ARRAYUTILS_H #define LIBND4J_ARRAYUTILS_H +#include + +#include #include #include -#include -#include namespace sd { - namespace ArrayUtils { - void toIntPtr(std::initializer_list list, int* target); - void toIntPtr(std::vector& list, int* target); - - void toLongPtr(std::initializer_list list, Nd4jLong* target); - void toLongPtr(std::vector& list, Nd4jLong* target); +namespace ArrayUtils { +void toIntPtr(std::initializer_list list, int* target); +void toIntPtr(std::vector& list, int* target); +void toLongPtr(std::initializer_list list, Nd4jLong* target); +void toLongPtr(std::vector& list, Nd4jLong* target); - std::vector toLongVector(std::vector vec); - std::vector toLongVector(std::vector vec); - } -} +std::vector toLongVector(std::vector vec); +std::vector toLongVector(std::vector vec); +} // namespace ArrayUtils +} // namespace sd -#endif //LIBND4J_ARRAYUTILS_H +#endif // LIBND4J_ARRAYUTILS_H diff --git a/libnd4j/include/helpers/AttentionHelper.h b/libnd4j/include/helpers/AttentionHelper.h index 02d9da995541..d6794b45db53 100644 --- a/libnd4j/include/helpers/AttentionHelper.h +++ b/libnd4j/include/helpers/AttentionHelper.h @@ -24,13 +24,17 @@ #include "array/NDArray.h" namespace sd { - class ND4J_EXPORT AttentionHelper { - - public: - static sd::NDArray multiHeadProject(const sd::NDArray* input, const sd::NDArray* projectionMatrix, sd::LaunchContext * context = sd::LaunchContext ::defaultContext()); - static void multiHeadProjectBp(const sd::NDArray* input, const sd::NDArray* projectionMatrix, const sd::NDArray* eps, sd::NDArray* dLdInput, sd::NDArray* dLdProjectionMatrix, sd::LaunchContext * context = sd::LaunchContext ::defaultContext()); - }; -} - +class SD_EXPORT AttentionHelper { + public: + static sd::NDArray multiHeadProject( + const sd::NDArray* input, const sd::NDArray* projectionMatrix, + sd::LaunchContext* context = sd::LaunchContext ::defaultContext()); + static void multiHeadProjectBp( + const sd::NDArray* input, const sd::NDArray* projectionMatrix, + const sd::NDArray* eps, sd::NDArray* dLdInput, + sd::NDArray* dLdProjectionMatrix, + sd::LaunchContext* context = sd::LaunchContext ::defaultContext()); +}; +} // namespace sd #endif diff --git a/libnd4j/include/helpers/BenchmarkHelper.h b/libnd4j/include/helpers/BenchmarkHelper.h index f76f787d8b69..31411d8a2ff2 100644 --- a/libnd4j/include/helpers/BenchmarkHelper.h +++ b/libnd4j/include/helpers/BenchmarkHelper.h @@ -21,69 +21,117 @@ #ifndef LIBND4J_BENCHMARKHELPER_H #define LIBND4J_BENCHMARKHELPER_H - +#include +#include +#include #include -#include -#include -#include -#include +#include +#include #include +#include +#include #include -#include -#include -#include -#include +#include #include -#include #include #include -#include -#include -#include -#include +#include +#include +#include +#include +#include namespace sd { - class ND4J_EXPORT BenchmarkHelper { - private: - unsigned int _wIterations; - unsigned int _rIterations; - - protected: - std::string benchmarkOperation(OpBenchmark &benchmark); - - void benchmarkScalarOperation(scalar::Ops op, std::string testName, double value, NDArray &x, NDArray &z); - - void benchmarkDeclarableOp(sd::ops::DeclarableOp &op, std::string testName, Context &context); - - void benchmarkGEMM(char orderA, std::initializer_list shapeA, char orderB, std::initializer_list shapeB, char orderC, std::initializer_list shapeC); - - std::string printHeader(); - public: - BenchmarkHelper(unsigned int warmUpIterations = 10, unsigned int runIterations = 100); - - std::string runOperationSuit(std::initializer_list benchmarks, const char *msg = nullptr); - std::string runOperationSuit(std::vector &benchmarks, bool postHeaders, const char *msg = nullptr); - std::string runOperationSuit(OpBenchmark* benchmark); - - std::string runOperationSuit(ScalarBenchmark *op, const std::function& func, const char *message = nullptr); - std::string runOperationSuit(TransformBenchmark *op, const std::function& func, const char *message = nullptr); - std::string runOperationSuit(ReductionBenchmark *op, const std::function& func, const char *message = nullptr); - std::string runOperationSuit(ReductionBenchmark *op, const std::function& func, const char *message = nullptr); - std::string runOperationSuit(PairwiseBenchmark *op, const std::function& func, const char *message = nullptr); - - - std::string runOperationSuit(TransformBenchmark *op, const std::function& func, ParametersBatch ¶metersBatch, const char *message = nullptr); - std::string runOperationSuit(ScalarBenchmark *op, const std::function& func, ParametersBatch ¶metersBatch, const char *message = nullptr); - std::string runOperationSuit(ReductionBenchmark *op, const std::function& func, ParametersBatch ¶metersBatch, const char *message = nullptr); - std::string runOperationSuit(ReductionBenchmark *op, const std::function& func, ParametersBatch ¶metersBatch, const char *message = nullptr); - std::string runOperationSuit(BroadcastBenchmark *op, const std::function& func, ParametersBatch ¶metersBatch, const char *message = nullptr); - std::string runOperationSuit(PairwiseBenchmark *op, const std::function& func, ParametersBatch ¶metersBatch, const char *message = nullptr); - std::string runOperationSuit(MatrixBenchmark *op, const std::function& func, ParametersBatch ¶metersBatch, const char *message = nullptr); - - std::string runOperationSuit(DeclarableBenchmark *op, const std::function& func, ParametersBatch ¶metersBatch, const char *message = nullptr); - }; -} - - -#endif //DEV_TESTS_BENCHMARKHELPER_H +class SD_EXPORT BenchmarkHelper { + private: + unsigned int _wIterations; + unsigned int _rIterations; + + protected: + std::string benchmarkOperation(OpBenchmark &benchmark); + + void benchmarkScalarOperation(scalar::Ops op, std::string testName, + double value, NDArray &x, NDArray &z); + + void benchmarkDeclarableOp(sd::ops::DeclarableOp &op, std::string testName, + Context &context); + + void benchmarkGEMM(char orderA, std::initializer_list shapeA, + char orderB, std::initializer_list shapeB, + char orderC, std::initializer_list shapeC); + + std::string printHeader(); + + public: + BenchmarkHelper(unsigned int warmUpIterations = 10, + unsigned int runIterations = 100); + + std::string runOperationSuit(std::initializer_list benchmarks, + const char *msg = nullptr); + std::string runOperationSuit(std::vector &benchmarks, + bool postHeaders, const char *msg = nullptr); + std::string runOperationSuit(OpBenchmark *benchmark); + + std::string runOperationSuit( + ScalarBenchmark *op, + const std::function &func, + const char *message = nullptr); + std::string runOperationSuit( + TransformBenchmark *op, + const std::function &func, + const char *message = nullptr); + std::string runOperationSuit( + ReductionBenchmark *op, + const std::function &func, + const char *message = nullptr); + std::string runOperationSuit( + ReductionBenchmark *op, + const std::function &func, + const char *message = nullptr); + std::string runOperationSuit( + PairwiseBenchmark *op, + const std::function &func, + const char *message = nullptr); + + std::string runOperationSuit( + TransformBenchmark *op, + const std::function &func, + ParametersBatch ¶metersBatch, const char *message = nullptr); + std::string runOperationSuit( + ScalarBenchmark *op, + const std::function &func, + ParametersBatch ¶metersBatch, const char *message = nullptr); + std::string runOperationSuit( + ReductionBenchmark *op, + const std::function &func, + ParametersBatch ¶metersBatch, const char *message = nullptr); + std::string runOperationSuit( + ReductionBenchmark *op, + const std::function &func, + ParametersBatch ¶metersBatch, const char *message = nullptr); + std::string runOperationSuit( + BroadcastBenchmark *op, + const std::function &func, + ParametersBatch ¶metersBatch, const char *message = nullptr); + std::string runOperationSuit( + PairwiseBenchmark *op, + const std::function &func, + ParametersBatch ¶metersBatch, const char *message = nullptr); + std::string runOperationSuit( + MatrixBenchmark *op, + const std::function &func, + ParametersBatch ¶metersBatch, const char *message = nullptr); + + std::string runOperationSuit( + DeclarableBenchmark *op, + const std::function &func, + ParametersBatch ¶metersBatch, const char *message = nullptr); +}; +} // namespace sd + +#endif // SD_BENCHMARKHELPER_H diff --git a/libnd4j/include/helpers/BitwiseUtils.h b/libnd4j/include/helpers/BitwiseUtils.h index 6b7e5c231fcf..931a47d51f51 100644 --- a/libnd4j/include/helpers/BitwiseUtils.h +++ b/libnd4j/include/helpers/BitwiseUtils.h @@ -21,109 +21,90 @@ #ifndef LIBND4J_BITWISEUTILS_H #define LIBND4J_BITWISEUTILS_H -#include #include #include #include + #include +#include namespace sd { - class ND4J_EXPORT BitwiseUtils { - public: - - - /** - * This method returns first non-zero bit index - * @param holder - * @return - */ - static int valueBit(int holder); - - /** - * This method returns vector representation of bits. - * - * PLEASE NOTE: Result is ALWAYS left-to-right - */ - static std::vector valueBits(int holder); - - /** - * This method returns TRUE if it's called on Big-Endian system, and false otherwise - */ - static bool isBE(); - - /** - * This method returns enum - * @return - */ - static sd::ByteOrder asByteOrder(); - - /** - * This method swaps bytes: LE vs BE - * @tparam T - * @param v - * @return - */ - template - static FORCEINLINE T swap_bytes(T v) { - static_assert (CHAR_BIT == 8, "CHAR_BIT != 8"); - - union S { - T v; - unsigned char u8[sizeof(T)]; - S(T val) { - v = val; - } - }; - - S source(v); - S dest(v); - - for (size_t k = 0; k < sizeof(T); k++) - dest.u8[k] = source.u8[sizeof(T) - k - 1]; - - return dest.v; - } - - /** - * This method flips bits in given value - * - * @tparam T - * @param v - * @return - */ - static int FORCEINLINE flip_bits(int v) { - return ~v; - } - - static int8_t FORCEINLINE flip_bits(int8_t v) { - return ~v; - } - - static int16_t FORCEINLINE flip_bits(int16_t v) { - return ~v; - } - - static uint8_t FORCEINLINE flip_bits(uint8_t v) { - return ~v; - } - - static uint16_t FORCEINLINE flip_bits(uint16_t v) { - return ~v; - } - - static uint32_t FORCEINLINE flip_bits(uint32_t v) { - return ~v; - } - - static uint64_t FORCEINLINE flip_bits(uint64_t v) { - return ~v; - } - - static Nd4jLong FORCEINLINE flip_bits(Nd4jLong v) { - return ~v; - } +class SD_EXPORT BitwiseUtils { + public: + /** + * This method returns first non-zero bit index + * @param holder + * @return + */ + static int valueBit(int holder); + + /** + * This method returns vector representation of bits. + * + * PLEASE NOTE: Result is ALWAYS left-to-right + */ + static std::vector valueBits(int holder); + + /** + * This method returns TRUE if it's called on Big-Endian system, and false + * otherwise + */ + static bool isBE(); + + /** + * This method returns enum + * @return + */ + static sd::ByteOrder asByteOrder(); + + /** + * This method swaps bytes: LE vs BE + * @tparam T + * @param v + * @return + */ + template + static FORCEINLINE T swap_bytes(T v) { + static_assert(CHAR_BIT == 8, "CHAR_BIT != 8"); + + union S { + T v; + unsigned char u8[sizeof(T)]; + S(T val) { v = val; } }; -} + S source(v); + S dest(v); + + for (size_t k = 0; k < sizeof(T); k++) + dest.u8[k] = source.u8[sizeof(T) - k - 1]; + + return dest.v; + } + + /** + * This method flips bits in given value + * + * @tparam T + * @param v + * @return + */ + static int FORCEINLINE flip_bits(int v) { return ~v; } + + static int8_t FORCEINLINE flip_bits(int8_t v) { return ~v; } + + static int16_t FORCEINLINE flip_bits(int16_t v) { return ~v; } + + static uint8_t FORCEINLINE flip_bits(uint8_t v) { return ~v; } + + static uint16_t FORCEINLINE flip_bits(uint16_t v) { return ~v; } + + static uint32_t FORCEINLINE flip_bits(uint32_t v) { return ~v; } + + static uint64_t FORCEINLINE flip_bits(uint64_t v) { return ~v; } + + static Nd4jLong FORCEINLINE flip_bits(Nd4jLong v) { return ~v; } +}; +} // namespace sd -#endif //LIBND4J_BITWISEUTILS_H +#endif // LIBND4J_BITWISEUTILS_H diff --git a/libnd4j/include/helpers/BlasHelper.h b/libnd4j/include/helpers/BlasHelper.h index 038df67b5c6e..265c262be4ce 100644 --- a/libnd4j/include/helpers/BlasHelper.h +++ b/libnd4j/include/helpers/BlasHelper.h @@ -21,98 +21,90 @@ #ifndef LIBND4J_BLAS_HELPER_H #define LIBND4J_BLAS_HELPER_H -#include -#include #include #include +#include +#include #ifdef _WIN32 #define CUBLASWINAPI __stdcall #define CUSOLVERAPI __stdcall #else -#define CUBLASWINAPI -#define CUSOLVERAPI +#define CUBLASWINAPI +#define CUSOLVERAPI #endif namespace sd { - typedef enum{ - CUBLAS_STATUS_SUCCESS =0, - CUBLAS_STATUS_NOT_INITIALIZED =1, - CUBLAS_STATUS_ALLOC_FAILED =3, - CUBLAS_STATUS_INVALID_VALUE =7, - CUBLAS_STATUS_ARCH_MISMATCH =8, - CUBLAS_STATUS_MAPPING_ERROR =11, - CUBLAS_STATUS_EXECUTION_FAILED=13, - CUBLAS_STATUS_INTERNAL_ERROR =14, - CUBLAS_STATUS_NOT_SUPPORTED =15, - CUBLAS_STATUS_LICENSE_ERROR =16 - } cublasStatus_t; - - typedef enum { - CUBLAS_OP_N=0, - CUBLAS_OP_T=1, - CUBLAS_OP_C=2 - } cublasOperation_t; - - struct cublasContext; - typedef struct cublasContext *cublasHandle_t; - - typedef enum - { - CUDA_R_16F= 2, /* real as a half */ - CUDA_C_16F= 6, /* complex as a pair of half numbers */ - CUDA_R_32F= 0, /* real as a float */ - CUDA_C_32F= 4, /* complex as a pair of float numbers */ - CUDA_R_64F= 1, /* real as a double */ - CUDA_C_64F= 5, /* complex as a pair of double numbers */ - CUDA_R_8I = 3, /* real as a signed char */ - CUDA_C_8I = 7, /* complex as a pair of signed char numbers */ - CUDA_R_8U = 8, /* real as a unsigned char */ - CUDA_C_8U = 9, /* complex as a pair of unsigned char numbers */ - CUDA_R_32I= 10, /* real as a signed int */ - CUDA_C_32I= 11, /* complex as a pair of signed int numbers */ - CUDA_R_32U= 12, /* real as a unsigned int */ - CUDA_C_32U= 13 /* complex as a pair of unsigned int numbers */ - } cublasDataType_t; - - typedef void (*CblasSgemv)(CBLAS_ORDER Layout, - CBLAS_TRANSPOSE TransA, int M, int N, - float alpha, float *A, int lda, - float *X, int incX, float beta, - float *Y, int incY); - - typedef void (*CblasDgemv)(CBLAS_ORDER Layout, - CBLAS_TRANSPOSE TransA, int M, int N, - double alpha, double *A, int lda, - double *X, int incX, double beta, - double *Y, int incY); - - - typedef void (*CblasSgemm)(CBLAS_ORDER Layout, CBLAS_TRANSPOSE TransA, - CBLAS_TRANSPOSE TransB, int M, int N, - int K, float alpha, float *A, - int lda, float *B, int ldb, - float beta, float *C, int ldc); - - typedef void (*CblasDgemm)(CBLAS_ORDER Layout, CBLAS_TRANSPOSE TransA, - CBLAS_TRANSPOSE TransB, int M, int N, - int K, double alpha, double *A, - int lda, double *B, int ldb, - double beta, double *C, int ldc); - - typedef void (*CblasSgemmBatch)(CBLAS_ORDER Layout, CBLAS_TRANSPOSE *TransA_Array, - CBLAS_TRANSPOSE *TransB_Array, int *M_Array, int *N_Array, - int *K_Array, float *alpha_Array, float **A_Array, - int *lda_Array, float **B_Array, int *ldb_Array, - float *beta_Array, float **C_Array, int *ldc_Array, - int group_count, int *group_size); - - typedef void (*CblasDgemmBatch)(CBLAS_ORDER Layout, CBLAS_TRANSPOSE *TransA_Array, - CBLAS_TRANSPOSE *TransB_Array, int *M_Array, int *N_Array, - int *K_Array, double *alpha_Array, double **A_Array, - int *lda_Array, double **B_Array, int* ldb_Array, - double *beta_Array, double **C_Array, int *ldc_Array, - int group_count, int *group_size); +typedef enum { + CUBLAS_STATUS_SUCCESS = 0, + CUBLAS_STATUS_NOT_INITIALIZED = 1, + CUBLAS_STATUS_ALLOC_FAILED = 3, + CUBLAS_STATUS_INVALID_VALUE = 7, + CUBLAS_STATUS_ARCH_MISMATCH = 8, + CUBLAS_STATUS_MAPPING_ERROR = 11, + CUBLAS_STATUS_EXECUTION_FAILED = 13, + CUBLAS_STATUS_INTERNAL_ERROR = 14, + CUBLAS_STATUS_NOT_SUPPORTED = 15, + CUBLAS_STATUS_LICENSE_ERROR = 16 +} cublasStatus_t; + +typedef enum { + CUBLAS_OP_N = 0, + CUBLAS_OP_T = 1, + CUBLAS_OP_C = 2 +} cublasOperation_t; + +struct cublasContext; +typedef struct cublasContext *cublasHandle_t; + +typedef enum { + CUDA_R_16F = 2, /* real as a half */ + CUDA_C_16F = 6, /* complex as a pair of half numbers */ + CUDA_R_32F = 0, /* real as a float */ + CUDA_C_32F = 4, /* complex as a pair of float numbers */ + CUDA_R_64F = 1, /* real as a double */ + CUDA_C_64F = 5, /* complex as a pair of double numbers */ + CUDA_R_8I = 3, /* real as a signed char */ + CUDA_C_8I = 7, /* complex as a pair of signed char numbers */ + CUDA_R_8U = 8, /* real as a unsigned char */ + CUDA_C_8U = 9, /* complex as a pair of unsigned char numbers */ + CUDA_R_32I = 10, /* real as a signed int */ + CUDA_C_32I = 11, /* complex as a pair of signed int numbers */ + CUDA_R_32U = 12, /* real as a unsigned int */ + CUDA_C_32U = 13 /* complex as a pair of unsigned int numbers */ +} cublasDataType_t; + +typedef void (*CblasSgemv)(CBLAS_ORDER Layout, CBLAS_TRANSPOSE TransA, int M, + int N, float alpha, float *A, int lda, float *X, + int incX, float beta, float *Y, int incY); + +typedef void (*CblasDgemv)(CBLAS_ORDER Layout, CBLAS_TRANSPOSE TransA, int M, + int N, double alpha, double *A, int lda, double *X, + int incX, double beta, double *Y, int incY); + +typedef void (*CblasSgemm)(CBLAS_ORDER Layout, CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, int M, int N, int K, + float alpha, float *A, int lda, float *B, int ldb, + float beta, float *C, int ldc); + +typedef void (*CblasDgemm)(CBLAS_ORDER Layout, CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, int M, int N, int K, + double alpha, double *A, int lda, double *B, int ldb, + double beta, double *C, int ldc); + +typedef void (*CblasSgemmBatch)( + CBLAS_ORDER Layout, CBLAS_TRANSPOSE *TransA_Array, + CBLAS_TRANSPOSE *TransB_Array, int *M_Array, int *N_Array, int *K_Array, + float *alpha_Array, float **A_Array, int *lda_Array, float **B_Array, + int *ldb_Array, float *beta_Array, float **C_Array, int *ldc_Array, + int group_count, int *group_size); + +typedef void (*CblasDgemmBatch)( + CBLAS_ORDER Layout, CBLAS_TRANSPOSE *TransA_Array, + CBLAS_TRANSPOSE *TransB_Array, int *M_Array, int *N_Array, int *K_Array, + double *alpha_Array, double **A_Array, int *lda_Array, double **B_Array, + int *ldb_Array, double *beta_Array, double **C_Array, int *ldc_Array, + int group_count, int *group_size); #ifdef LAPACK_ROW_MAJOR #undef LAPACK_ROW_MAJOR @@ -121,322 +113,215 @@ namespace sd { #ifdef LAPACK_COL_MAJOR #undef LAPACK_COL_MAJOR #endif - enum LAPACK_LAYOUT { LAPACK_ROW_MAJOR=101, LAPACK_COL_MAJOR=102 }; - - typedef int (*LapackeSgesvd)(LAPACK_LAYOUT matrix_layout, char jobu, char jobvt, - int m, int n, float* a, int lda, - float* s, float* u, int ldu, float* vt, - int ldvt, float* superb); - - typedef int (*LapackeDgesvd)(LAPACK_LAYOUT matrix_layout, char jobu, char jobvt, - int m, int n, double* a, - int lda, double* s, double* u, int ldu, - double* vt, int ldvt, double* superb); - - typedef int (*LapackeSgesdd)(LAPACK_LAYOUT matrix_layout, char jobz, int m, - int n, float* a, int lda, float* s, - float* u, int ldu, float* vt, - int ldvt); - typedef int (*LapackeDgesdd)(LAPACK_LAYOUT matrix_layout, char jobz, int m, - int n, double* a, int lda, double* s, - double* u, int ldu, double* vt, - int ldvt); - - typedef cublasStatus_t (CUBLASWINAPI *CublasSgemv)(cublasHandle_t handle, - cublasOperation_t trans, - int m, - int n, - float *alpha, /* host or device pointer */ - float *A, - int lda, - float *x, - int incx, - float *beta, /* host or device pointer */ - float *y, - int incy); - - typedef cublasStatus_t (CUBLASWINAPI *CublasDgemv)(cublasHandle_t handle, - cublasOperation_t trans, - int m, - int n, - double *alpha, /* host or device pointer */ - double *A, - int lda, - double *x, - int incx, - double *beta, /* host or device pointer */ - double *y, - int incy); - - typedef cublasStatus_t (CUBLASWINAPI *CublasHgemm)(cublasHandle_t handle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - __half *alpha, /* host or device pointer */ - __half *A, - int lda, - __half *B, - int ldb, - __half *beta, /* host or device pointer */ - __half *C, - int ldc); - - typedef cublasStatus_t (CUBLASWINAPI *CublasSgemm)(cublasHandle_t handle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - float *alpha, /* host or device pointer */ - float *A, - int lda, - float *B, - int ldb, - float *beta, /* host or device pointer */ - float *C, - int ldc); - - typedef cublasStatus_t (CUBLASWINAPI *CublasDgemm)(cublasHandle_t handle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - double *alpha, /* host or device pointer */ - double *A, - int lda, - double *B, - int ldb, - double *beta, /* host or device pointer */ - double *C, - int ldc); - - typedef cublasStatus_t (CUBLASWINAPI *CublasSgemmEx)(cublasHandle_t handle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - float *alpha, /* host or device pointer */ - void *A, - cublasDataType_t Atype, - int lda, - void *B, - cublasDataType_t Btype, - int ldb, - float *beta, /* host or device pointer */ - void *C, - cublasDataType_t Ctype, - int ldc); - - typedef cublasStatus_t (CUBLASWINAPI *CublasHgemmBatched)(cublasHandle_t handle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - __half *alpha, /* host or device pointer */ - __half *Aarray[], - int lda, - __half *Barray[], - int ldb, - __half *beta, /* host or device pointer */ - __half *Carray[], - int ldc, - int batchCount); - - typedef cublasStatus_t (CUBLASWINAPI *CublasSgemmBatched)(cublasHandle_t handle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - float *alpha, /* host or device pointer */ - float *Aarray[], - int lda, - float *Barray[], - int ldb, - float *beta, /* host or device pointer */ - float *Carray[], - int ldc, - int batchCount); - - typedef cublasStatus_t (CUBLASWINAPI *CublasDgemmBatched)(cublasHandle_t handle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - double *alpha, /* host or device pointer */ - double *Aarray[], - int lda, - double *Barray[], - int ldb, - double *beta, /* host or device pointer */ - double *Carray[], - int ldc, - int batchCount); - - typedef enum{ - CUSOLVER_STATUS_SUCCESS=0, - CUSOLVER_STATUS_NOT_INITIALIZED=1, - CUSOLVER_STATUS_ALLOC_FAILED=2, - CUSOLVER_STATUS_INVALID_VALUE=3, - CUSOLVER_STATUS_ARCH_MISMATCH=4, - CUSOLVER_STATUS_MAPPING_ERROR=5, - CUSOLVER_STATUS_EXECUTION_FAILED=6, - CUSOLVER_STATUS_INTERNAL_ERROR=7, - CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED=8, - CUSOLVER_STATUS_NOT_SUPPORTED = 9, - CUSOLVER_STATUS_ZERO_PIVOT=10, - CUSOLVER_STATUS_INVALID_LICENSE=11 - } cusolverStatus_t; - - typedef enum { - CUSOLVER_EIG_TYPE_1=1, - CUSOLVER_EIG_TYPE_2=2, - CUSOLVER_EIG_TYPE_3=3 - } cusolverEigType_t ; - - typedef enum { - CUSOLVER_EIG_MODE_NOVECTOR=0, - CUSOLVER_EIG_MODE_VECTOR=1 - } cusolverEigMode_t ; - - struct cusolverDnContext; - typedef struct cusolverDnContext *cusolverDnHandle_t; - - typedef cusolverStatus_t (CUSOLVERAPI *CusolverDnSgesvdBufferSize)( - cusolverDnHandle_t handle, - int m, - int n, - int *lwork); - - typedef cusolverStatus_t (CUSOLVERAPI *CusolverDnDgesvdBufferSize)( - cusolverDnHandle_t handle, - int m, - int n, - int *lwork); - - typedef cusolverStatus_t (CUSOLVERAPI *CusolverDnSgesvd)( - cusolverDnHandle_t handle, - signed char jobu, - signed char jobvt, - int m, - int n, - float *A, - int lda, - float *S, - float *U, - int ldu, - float *VT, - int ldvt, - float *work, - int lwork, - float *rwork, - int *info); - - typedef cusolverStatus_t (CUSOLVERAPI *CusolverDnDgesvd)( - cusolverDnHandle_t handle, - signed char jobu, - signed char jobvt, - int m, - int n, - double *A, - int lda, - double *S, - double *U, - int ldu, - double *VT, - int ldvt, - double *work, - int lwork, - double *rwork, - int *info); - - - enum BlasFunctions { - GEMV = 0, - GEMM = 1, - }; - - class BlasHelper { - private: - bool _hasHgemv = false; - bool _hasHgemm = false; - bool _hasHgemmBatch = false; - - bool _hasSgemv = false; - bool _hasSgemm = false; - bool _hasSgemmBatch = false; - - bool _hasDgemv = false; - bool _hasDgemm = false; - bool _hasDgemmBatch = false; - - CblasSgemv cblasSgemv; - CblasDgemv cblasDgemv; - CblasSgemm cblasSgemm; - CblasDgemm cblasDgemm; - CblasSgemmBatch cblasSgemmBatch; - CblasDgemmBatch cblasDgemmBatch; - LapackeSgesvd lapackeSgesvd; - LapackeDgesvd lapackeDgesvd; - LapackeSgesdd lapackeSgesdd; - LapackeDgesdd lapackeDgesdd; - - CublasSgemv cublasSgemv; - CublasDgemv cublasDgemv; - CublasHgemm cublasHgemm; - CublasSgemm cublasSgemm; - CublasDgemm cublasDgemm; - CublasSgemmEx cublasSgemmEx; - CublasHgemmBatched cublasHgemmBatched; - CublasSgemmBatched cublasSgemmBatched; - CublasDgemmBatched cublasDgemmBatched; - CusolverDnSgesvdBufferSize cusolverDnSgesvdBufferSize; - CusolverDnDgesvdBufferSize cusolverDnDgesvdBufferSize; - CusolverDnSgesvd cusolverDnSgesvd; - CusolverDnDgesvd cusolverDnDgesvd; - - public: - static BlasHelper& getInstance(); - - void initializeFunctions(Nd4jPointer *functions); - void initializeDeviceFunctions(Nd4jPointer *functions); - - template - bool hasGEMV(); - - template - bool hasGEMM(); - - bool hasGEMM(const sd::DataType dtype); - bool hasGEMV(const sd::DataType dtype); - - template - bool hasBatchedGEMM(); - - CblasSgemv sgemv(); - CblasDgemv dgemv(); - - CblasSgemm sgemm(); - CblasDgemm dgemm(); - - CblasSgemmBatch sgemmBatched(); - CblasDgemmBatch dgemmBatched(); - - LapackeSgesvd sgesvd(); - LapackeDgesvd dgesvd(); - - LapackeSgesdd sgesdd(); - LapackeDgesdd dgesdd(); - - // destructor - ~BlasHelper() noexcept; - }; -} +enum LAPACK_LAYOUT { LAPACK_ROW_MAJOR = 101, LAPACK_COL_MAJOR = 102 }; + +typedef int (*LapackeSgesvd)(LAPACK_LAYOUT matrix_layout, char jobu, char jobvt, + int m, int n, float *a, int lda, float *s, + float *u, int ldu, float *vt, int ldvt, + float *superb); + +typedef int (*LapackeDgesvd)(LAPACK_LAYOUT matrix_layout, char jobu, char jobvt, + int m, int n, double *a, int lda, double *s, + double *u, int ldu, double *vt, int ldvt, + double *superb); + +typedef int (*LapackeSgesdd)(LAPACK_LAYOUT matrix_layout, char jobz, int m, + int n, float *a, int lda, float *s, float *u, + int ldu, float *vt, int ldvt); +typedef int (*LapackeDgesdd)(LAPACK_LAYOUT matrix_layout, char jobz, int m, + int n, double *a, int lda, double *s, double *u, + int ldu, double *vt, int ldvt); + +typedef cublasStatus_t(CUBLASWINAPI *CublasSgemv)( + cublasHandle_t handle, cublasOperation_t trans, int m, int n, + float *alpha, /* host or device pointer */ + float *A, int lda, float *x, int incx, + float *beta, /* host or device pointer */ + float *y, int incy); + +typedef cublasStatus_t(CUBLASWINAPI *CublasDgemv)( + cublasHandle_t handle, cublasOperation_t trans, int m, int n, + double *alpha, /* host or device pointer */ + double *A, int lda, double *x, int incx, + double *beta, /* host or device pointer */ + double *y, int incy); + +typedef cublasStatus_t(CUBLASWINAPI *CublasHgemm)( + cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, + int m, int n, int k, __half *alpha, /* host or device pointer */ + __half *A, int lda, __half *B, int ldb, + __half *beta, /* host or device pointer */ + __half *C, int ldc); + +typedef cublasStatus_t(CUBLASWINAPI *CublasSgemm)( + cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, + int m, int n, int k, float *alpha, /* host or device pointer */ + float *A, int lda, float *B, int ldb, + float *beta, /* host or device pointer */ + float *C, int ldc); + +typedef cublasStatus_t(CUBLASWINAPI *CublasDgemm)( + cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, + int m, int n, int k, double *alpha, /* host or device pointer */ + double *A, int lda, double *B, int ldb, + double *beta, /* host or device pointer */ + double *C, int ldc); + +typedef cublasStatus_t(CUBLASWINAPI *CublasSgemmEx)( + cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, + int m, int n, int k, float *alpha, /* host or device pointer */ + void *A, cublasDataType_t Atype, int lda, void *B, cublasDataType_t Btype, + int ldb, float *beta, /* host or device pointer */ + void *C, cublasDataType_t Ctype, int ldc); + +typedef cublasStatus_t(CUBLASWINAPI *CublasHgemmBatched)( + cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, + int m, int n, int k, __half *alpha, /* host or device pointer */ + __half *Aarray[], int lda, __half *Barray[], int ldb, + __half *beta, /* host or device pointer */ + __half *Carray[], int ldc, int batchCount); + +typedef cublasStatus_t(CUBLASWINAPI *CublasSgemmBatched)( + cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, + int m, int n, int k, float *alpha, /* host or device pointer */ + float *Aarray[], int lda, float *Barray[], int ldb, + float *beta, /* host or device pointer */ + float *Carray[], int ldc, int batchCount); + +typedef cublasStatus_t(CUBLASWINAPI *CublasDgemmBatched)( + cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, + int m, int n, int k, double *alpha, /* host or device pointer */ + double *Aarray[], int lda, double *Barray[], int ldb, + double *beta, /* host or device pointer */ + double *Carray[], int ldc, int batchCount); + +typedef enum { + CUSOLVER_STATUS_SUCCESS = 0, + CUSOLVER_STATUS_NOT_INITIALIZED = 1, + CUSOLVER_STATUS_ALLOC_FAILED = 2, + CUSOLVER_STATUS_INVALID_VALUE = 3, + CUSOLVER_STATUS_ARCH_MISMATCH = 4, + CUSOLVER_STATUS_MAPPING_ERROR = 5, + CUSOLVER_STATUS_EXECUTION_FAILED = 6, + CUSOLVER_STATUS_INTERNAL_ERROR = 7, + CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED = 8, + CUSOLVER_STATUS_NOT_SUPPORTED = 9, + CUSOLVER_STATUS_ZERO_PIVOT = 10, + CUSOLVER_STATUS_INVALID_LICENSE = 11 +} cusolverStatus_t; + +typedef enum { + CUSOLVER_EIG_TYPE_1 = 1, + CUSOLVER_EIG_TYPE_2 = 2, + CUSOLVER_EIG_TYPE_3 = 3 +} cusolverEigType_t; + +typedef enum { + CUSOLVER_EIG_MODE_NOVECTOR = 0, + CUSOLVER_EIG_MODE_VECTOR = 1 +} cusolverEigMode_t; + +struct cusolverDnContext; +typedef struct cusolverDnContext *cusolverDnHandle_t; + +typedef cusolverStatus_t(CUSOLVERAPI *CusolverDnSgesvdBufferSize)( + cusolverDnHandle_t handle, int m, int n, int *lwork); + +typedef cusolverStatus_t(CUSOLVERAPI *CusolverDnDgesvdBufferSize)( + cusolverDnHandle_t handle, int m, int n, int *lwork); + +typedef cusolverStatus_t(CUSOLVERAPI *CusolverDnSgesvd)( + cusolverDnHandle_t handle, signed char jobu, signed char jobvt, int m, + int n, float *A, int lda, float *S, float *U, int ldu, float *VT, int ldvt, + float *work, int lwork, float *rwork, int *info); + +typedef cusolverStatus_t(CUSOLVERAPI *CusolverDnDgesvd)( + cusolverDnHandle_t handle, signed char jobu, signed char jobvt, int m, + int n, double *A, int lda, double *S, double *U, int ldu, double *VT, + int ldvt, double *work, int lwork, double *rwork, int *info); + +enum BlasFunctions { + GEMV = 0, + GEMM = 1, +}; + +class BlasHelper { + private: + + + bool _hasHgemv = false; + bool _hasHgemm = false; + bool _hasHgemmBatch = false; + + bool _hasSgemv = false; + bool _hasSgemm = false; + bool _hasSgemmBatch = false; + + bool _hasDgemv = false; + bool _hasDgemm = false; + bool _hasDgemmBatch = false; + + CblasSgemv cblasSgemv; + CblasDgemv cblasDgemv; + CblasSgemm cblasSgemm; + CblasDgemm cblasDgemm; + CblasSgemmBatch cblasSgemmBatch; + CblasDgemmBatch cblasDgemmBatch; + LapackeSgesvd lapackeSgesvd; + LapackeDgesvd lapackeDgesvd; + LapackeSgesdd lapackeSgesdd; + LapackeDgesdd lapackeDgesdd; + + CublasSgemv cublasSgemv; + CublasDgemv cublasDgemv; + CublasHgemm cublasHgemm; + CublasSgemm cublasSgemm; + CublasDgemm cublasDgemm; + CublasSgemmEx cublasSgemmEx; + CublasHgemmBatched cublasHgemmBatched; + CublasSgemmBatched cublasSgemmBatched; + CublasDgemmBatched cublasDgemmBatched; + CusolverDnSgesvdBufferSize cusolverDnSgesvdBufferSize; + CusolverDnDgesvdBufferSize cusolverDnDgesvdBufferSize; + CusolverDnSgesvd cusolverDnSgesvd; + CusolverDnDgesvd cusolverDnDgesvd; + + public: + static BlasHelper &getInstance(); + + void initializeFunctions(Nd4jPointer *functions); + void initializeDeviceFunctions(Nd4jPointer *functions); + + template + bool hasGEMV(); + + template + bool hasGEMM(); + + bool hasGEMM(const sd::DataType dtype); + bool hasGEMV(const sd::DataType dtype); + + template + bool hasBatchedGEMM(); + + CblasSgemv sgemv(); + CblasDgemv dgemv(); + + CblasSgemm sgemm(); + CblasDgemm dgemm(); + + CblasSgemmBatch sgemmBatched(); + CblasDgemmBatch dgemmBatched(); + + LapackeSgesvd sgesvd(); + LapackeDgesvd dgesvd(); + + LapackeSgesdd sgesdd(); + LapackeDgesdd dgesdd(); + + // destructor + ~BlasHelper() noexcept; +}; +} // namespace sd #endif diff --git a/libnd4j/include/helpers/ConstantHelper.h b/libnd4j/include/helpers/ConstantHelper.h index 7d4446d34eeb..187b5a49d6d2 100644 --- a/libnd4j/include/helpers/ConstantHelper.h +++ b/libnd4j/include/helpers/ConstantHelper.h @@ -18,46 +18,51 @@ // @author raver119@gmail.com // -#ifndef DEV_TESTS_CONSTANTHELPER_H -#define DEV_TESTS_CONSTANTHELPER_H +#ifndef SD_CONSTANTHELPER_H +#define SD_CONSTANTHELPER_H -#include +#include +#include +#include +#include #include +#include #include -#include -#include + #include #include -#include -#include -#include +#include namespace sd { - class ND4J_EXPORT ConstantHelper { - private: - ConstantHelper(); +class SD_EXPORT ConstantHelper { + private: + + ConstantHelper(); + + std::vector> _cache; - std::vector> _cache; + // tracking of per-device constant memory buffers (CUDA only atm) + std::vector _devicePointers; + std::vector _deviceOffsets; + std::mutex _mutex; + std::mutex _mutexHolder; - // tracking of per-device constant memory buffers (CUDA only atm) - std::vector _devicePointers; - std::vector _deviceOffsets; - std::mutex _mutex; - std::mutex _mutexHolder; + std::vector _counters; - std::vector _counters; - public: - ~ConstantHelper(); + public: + ~ConstantHelper(); - static ConstantHelper& getInstance(); - static int getCurrentDevice(); - static int getNumberOfDevices(); - void* replicatePointer(void *src, size_t numBytes, memory::Workspace *workspace = nullptr); + static ConstantHelper& getInstance(); + static int getCurrentDevice(); + static int getNumberOfDevices(); + void* replicatePointer(void* src, size_t numBytes, + memory::Workspace* workspace = nullptr); - ConstantDataBuffer* constantBuffer(const ConstantDescriptor &descriptor, sd::DataType dataType); + ConstantDataBuffer* constantBuffer(const ConstantDescriptor& descriptor, + sd::DataType dataType); - Nd4jLong getCachedAmount(int deviceId); - }; -} + Nd4jLong getCachedAmount(int deviceId); +}; +} // namespace sd -#endif //DEV_TESTS_CONSTANTHELPER_H +#endif // SD_CONSTANTHELPER_H diff --git a/libnd4j/include/helpers/ConstantShapeHelper.h b/libnd4j/include/helpers/ConstantShapeHelper.h index 25440e05c58c..bc413af7a884 100644 --- a/libnd4j/include/helpers/ConstantShapeHelper.h +++ b/libnd4j/include/helpers/ConstantShapeHelper.h @@ -18,11 +18,16 @@ // @author raver119@gmail.com // -#ifndef DEV_TESTS_CONSTANTSHAPEHELPER_H -#define DEV_TESTS_CONSTANTSHAPEHELPER_H +#ifndef SD_CONSTANTSHAPEHELPER_H +#define SD_CONSTANTSHAPEHELPER_H +#include +#include +#include #include +#include #include + #include #include #include @@ -33,64 +38,73 @@ namespace sd { - class ND4J_EXPORT ConstantShapeHelper { - private: - std::mutex _mutex; - std::vector> _cache; - - - ConstantShapeHelper(); - public: - ~ConstantShapeHelper() = default; - - static ConstantShapeHelper & getInstance(); - - - ConstantShapeBuffer& bufferForShapeInfo(sd::DataType dataType, char order, const std::vector &shape); - ConstantShapeBuffer& bufferForShapeInfo(const ShapeDescriptor &descriptor); - ConstantShapeBuffer& bufferForShapeInfo(const Nd4jLong *shapeInfo); - ConstantShapeBuffer& bufferForShapeInfo(sd::DataType dataType, char order, int rank, const Nd4jLong* shape); - ConstantShapeBuffer& createShapeInfoWithUnitiesForBroadcast(const Nd4jLong* maxShapeInfo, const Nd4jLong* minShapeInfo, sd::memory::Workspace* workspace = nullptr, const std::vector &dimensions = {}); - - - const Nd4jLong* emptyShapeInfo(sd::DataType dataType); - const Nd4jLong* scalarShapeInfo(sd::DataType dataType); - const Nd4jLong* vectorShapeInfo(Nd4jLong length, sd::DataType dataType); - const Nd4jLong* createShapeInfo(const ShapeDescriptor &descriptor); - const Nd4jLong* createShapeInfo(sd::DataType dataType, char order, const std::vector &shape); - const Nd4jLong* createShapeInfo(sd::DataType dataType, char order, int rank, const Nd4jLong* shape); - const Nd4jLong* createShapeInfo(sd::DataType dataType, const Nd4jLong* shapeInfo); - - const Nd4jLong* createFromExisting(Nd4jLong *shapeInfo, sd::memory::Workspace *workspace); - const Nd4jLong* createFromExisting(Nd4jLong *shapeInfo, bool destroyOriginal = true); - - bool checkBufferExistenceForShapeInfo(ShapeDescriptor &descriptor); - - - /** - * This method returns number of cached TAD shapes/offsets on specific device - * @return - */ - FORCEINLINE int cachedEntriesForDevice(int deviceId) { - if (deviceId > _cache.size()) - throw std::runtime_error("deviceId > number of actual devices"); - - return _cache[deviceId].size(); - } - - /** - * This method returns total number of cached TAD shapes/offsets on all devices - * @return - */ - FORCEINLINE int totalCachedEntries() { - int total = 0; - - for (int e = 0; e < _cache.size(); e++) - total += _cache[e].size(); - - return total; - } - }; -} - -#endif //DEV_TESTS_CONSTANTSHAPEHELPER_H +class SD_EXPORT ConstantShapeHelper { + private: + + + std::mutex _mutex; + std::vector> _cache; + + ConstantShapeHelper(); + + public: + ~ConstantShapeHelper() = default; + + static ConstantShapeHelper & getInstance(); + + ConstantShapeBuffer& bufferForShapeInfo(sd::DataType dataType, char order, + const std::vector& shape); + ConstantShapeBuffer& bufferForShapeInfo(const ShapeDescriptor& descriptor); + ConstantShapeBuffer& bufferForShapeInfo(const Nd4jLong* shapeInfo); + ConstantShapeBuffer& bufferForShapeInfo(sd::DataType dataType, char order, + int rank, const Nd4jLong* shape); + ConstantShapeBuffer& createShapeInfoWithUnitiesForBroadcast( + const Nd4jLong* maxShapeInfo, const Nd4jLong* minShapeInfo, + sd::memory::Workspace* workspace = nullptr, + const std::vector& dimensions = {}); + + const Nd4jLong* emptyShapeInfo(sd::DataType dataType); + const Nd4jLong* scalarShapeInfo(sd::DataType dataType); + const Nd4jLong* vectorShapeInfo(Nd4jLong length, sd::DataType dataType); + const Nd4jLong* createShapeInfo(const ShapeDescriptor& descriptor); + const Nd4jLong* createShapeInfo(sd::DataType dataType, char order, + const std::vector& shape); + const Nd4jLong* createShapeInfo(sd::DataType dataType, char order, int rank, + const Nd4jLong* shape); + const Nd4jLong* createShapeInfo(sd::DataType dataType, + const Nd4jLong* shapeInfo); + + const Nd4jLong* createFromExisting(Nd4jLong* shapeInfo, + sd::memory::Workspace* workspace); + const Nd4jLong* createFromExisting(Nd4jLong* shapeInfo, + bool destroyOriginal = true); + + bool checkBufferExistenceForShapeInfo(ShapeDescriptor& descriptor); + + /** + * This method returns number of cached TAD shapes/offsets on specific device + * @return + */ + FORCEINLINE int cachedEntriesForDevice(int deviceId) { + if (deviceId > _cache.size()) + throw std::runtime_error("deviceId > number of actual devices"); + + return _cache[deviceId].size(); + } + + /** + * This method returns total number of cached TAD shapes/offsets on all + * devices + * @return + */ + FORCEINLINE int totalCachedEntries() { + int total = 0; + + for (int e = 0; e < _cache.size(); e++) total += _cache[e].size(); + + return total; + } +}; +} // namespace sd + +#endif // SD_CONSTANTSHAPEHELPER_H diff --git a/libnd4j/include/helpers/ConstantTadHelper.h b/libnd4j/include/helpers/ConstantTadHelper.h index 10bdd108d7c4..0070b5b73734 100644 --- a/libnd4j/include/helpers/ConstantTadHelper.h +++ b/libnd4j/include/helpers/ConstantTadHelper.h @@ -18,70 +18,80 @@ // @author raver119@gmail.com // +#ifndef SD_CONSTANTTADHELPER_H +#define SD_CONSTANTTADHELPER_H -#ifndef DEV_TESTS_CONSTANTTADHELPER_H -#define DEV_TESTS_CONSTANTTADHELPER_H - +#include +#include +#include #include #include #include + #include -#include #include -#include -#include -#include +#include namespace sd { - class ND4J_EXPORT ConstantTadHelper { - private: - std::mutex _mutex; - std::vector> _cache; - - ConstantTadHelper(); - public: - ~ConstantTadHelper() = default; - - static ConstantTadHelper & getInstance(); - - /** - * These methods calculate Tensor-Along-Dimension(s) shape and offsets - * - * @param originalShape - * @param dimensions - * @param keepUnitiesInShape - * @return - */ - TadPack tadForDimensions(const Nd4jLong *originalShape, const std::vector &dimensions, const bool keepUnitiesInShape = false); - TadPack tadForDimensions(const Nd4jLong *originalShape, int* dimensions, int dimLength, const bool keepUnitiesInShape = false); - TadPack tadForDimensions(const Nd4jLong *originalShape, int dimensions, const bool keepUnitiesInShape = false); - TadPack tadForDimensions(ShapeDescriptor &descriptor, std::vector &dimensions, const bool keepUnitiesInShape = false); - TadPack tadForDimensions(TadDescriptor &descriptor); - - /** - * This method returns number of cached TAD shapes/offsets on specific device - * @return - */ - FORCEINLINE int cachedEntriesForDevice(int deviceId) { - if (deviceId > _cache.size()) - throw std::runtime_error("deviceId > number of actual devices"); - - return _cache[deviceId].size(); - } - - /** - * This method returns total number of cached TAD shapes/offsets on all devices - * @return - */ - FORCEINLINE int totalCachedEntries() { - int total = 0; - - for (int e = 0; e < _cache.size(); e++) - total += _cache[e].size(); - - return total; - } - }; -} - -#endif //DEV_TESTS_CONSTANTTADHELPER_H +class SD_EXPORT ConstantTadHelper { + private: + + + std::mutex _mutex; + std::vector> _cache; + + ConstantTadHelper(); + + public: + ~ConstantTadHelper() = default; + + static ConstantTadHelper &getInstance(); + + /** + * These methods calculate Tensor-Along-Dimension(s) shape and offsets + * + * @param originalShape + * @param dimensions + * @param keepUnitiesInShape + * @return + */ + TadPack tadForDimensions(const Nd4jLong *originalShape, + const std::vector &dimensions, + const bool keepUnitiesInShape = false); + TadPack tadForDimensions(const Nd4jLong *originalShape, int *dimensions, + int dimLength, + const bool keepUnitiesInShape = false); + TadPack tadForDimensions(const Nd4jLong *originalShape, int dimensions, + const bool keepUnitiesInShape = false); + TadPack tadForDimensions(ShapeDescriptor &descriptor, + std::vector &dimensions, + const bool keepUnitiesInShape = false); + TadPack tadForDimensions(TadDescriptor &descriptor); + + /** + * This method returns number of cached TAD shapes/offsets on specific device + * @return + */ + FORCEINLINE int cachedEntriesForDevice(int deviceId) { + if (deviceId > _cache.size()) + throw std::runtime_error("deviceId > number of actual devices"); + + return _cache[deviceId].size(); + } + + /** + * This method returns total number of cached TAD shapes/offsets on all + * devices + * @return + */ + FORCEINLINE int totalCachedEntries() { + int total = 0; + + for (int e = 0; e < _cache.size(); e++) total += _cache[e].size(); + + return total; + } +}; +} // namespace sd + +#endif // SD_CONSTANTTADHELPER_H diff --git a/libnd4j/include/helpers/CudaLaunchHelper.h b/libnd4j/include/helpers/CudaLaunchHelper.h index 6bf44317fd86..8d5cc889f397 100644 --- a/libnd4j/include/helpers/CudaLaunchHelper.h +++ b/libnd4j/include/helpers/CudaLaunchHelper.h @@ -21,19 +21,18 @@ #ifndef LIBND4J_CUDALAUNCHHELPER_H #define LIBND4J_CUDALAUNCHHELPER_H - -#include #include #include +#include #include namespace sd { - class ND4J_EXPORT CudaLaunchHelper { - public: - static Triple getFlatLaunchParams(Nd4jLong length, int SM, int CORES, int SHARED_MEMORY); - static int getReductionBlocks(Nd4jLong xLength, int blockSize = 512); - }; -} - +class SD_EXPORT CudaLaunchHelper { + public: + static Triple getFlatLaunchParams(Nd4jLong length, int SM, int CORES, + int SHARED_MEMORY); + static int getReductionBlocks(Nd4jLong xLength, int blockSize = 512); +}; +} // namespace sd -#endif //LIBND4J_CUDALAUNCHHELPER_H +#endif // LIBND4J_CUDALAUNCHHELPER_H diff --git a/libnd4j/include/helpers/DebugHelper.h b/libnd4j/include/helpers/DebugHelper.h index 10bb1dc90647..2cbca70496fc 100644 --- a/libnd4j/include/helpers/DebugHelper.h +++ b/libnd4j/include/helpers/DebugHelper.h @@ -21,60 +21,64 @@ #ifndef LIBND4J_DEBUGHELPER_H #define LIBND4J_DEBUGHELPER_H -#include -#include -#include #include -#include +#include +#include +#include +#include #ifdef __CUDACC__ #include -#include #include +#include #endif #include namespace sd { - class NDArray; - class ND4J_EXPORT DebugHelper { - public: - - // cuda-specific debug functions +class NDArray; +class SD_EXPORT DebugHelper { + public: + // cuda-specific debug functions #ifdef __CUDACC__ - static FORCEINLINE void checkErrorCode(cudaStream_t *stream, int opNum = 0) { - if (Environment::getInstance().isDebug()) { - cudaError_t res = cudaStreamSynchronize(*stream); + static FORCEINLINE void checkErrorCode(cudaStream_t* stream, int opNum = 0) { + if (Environment::getInstance().isDebug()) { + cudaError_t res = cudaStreamSynchronize(*stream); - if (res != 0) { - //PRINT_FIRST("Kernel OpNum failed: [%i]\n", opNum); - std::string op = "Kernel OpNum failed: ["; - op += StringUtils::valueToString(opNum); - op += "]"; + if (res != 0) { + // PRINT_FIRST("Kernel OpNum failed: [%i]\n", opNum); + std::string op = "Kernel OpNum failed: ["; + op += StringUtils::valueToString(opNum); + op += "]"; - throw std::runtime_error(op); - } - } - } + throw std::runtime_error(op); + } + } + } - static FORCEINLINE void checkErrorCode(cudaStream_t *stream, const char *failMessage = nullptr) { - cudaError_t res = cudaStreamSynchronize(*stream); - if (res != 0) { - if (failMessage == nullptr) { - std::string op = "CUDA call ended with error code [" + StringUtils::valueToString(res) + std::string("]"); - throw std::runtime_error(op); - } else { - std::string op = std::string(failMessage) + std::string("Error code [") + StringUtils::valueToString(res) + std::string("]"); - throw std::runtime_error(op); - } - } - } + static FORCEINLINE void checkErrorCode(cudaStream_t* stream, + const char* failMessage = nullptr) { + cudaError_t res = cudaStreamSynchronize(*stream); + if (res != 0) { + if (failMessage == nullptr) { + std::string op = "CUDA call ended with error code [" + + StringUtils::valueToString(res) + + std::string("]"); + throw std::runtime_error(op); + } else { + std::string op = + std::string(failMessage) + std::string("Error code [") + + StringUtils::valueToString(res) + std::string("]"); + throw std::runtime_error(op); + } + } + } #endif - static DebugInfo debugStatistics(NDArray const* input); - static void retrieveDebugStatistics(DebugInfo* statistics, NDArray const* input); - }; -} - + static DebugInfo debugStatistics(NDArray const* input); + static void retrieveDebugStatistics(DebugInfo* statistics, + NDArray const* input); +}; +} // namespace sd -#endif //LIBND4J_DEBUGHELPER_H +#endif // LIBND4J_DEBUGHELPER_H diff --git a/libnd4j/include/helpers/DebugInfo.h b/libnd4j/include/helpers/DebugInfo.h index c2efb00fe576..3d82928d2bcb 100644 --- a/libnd4j/include/helpers/DebugInfo.h +++ b/libnd4j/include/helpers/DebugInfo.h @@ -21,49 +21,49 @@ #ifndef LIBND4J__DEBUG_INFO_HELPER__H #define LIBND4J__DEBUG_INFO_HELPER__H -#include -#include -#include #include -#include -#include #include +#include +#include +#include +#include + +#include #ifdef __CUDACC__ #include -#include #include +#include #endif namespace sd { - struct ND4J_EXPORT DebugInfo { - double _minValue; - double _maxValue; - double _meanValue; - double _stdDevValue; - Nd4jLong _zeroCount; - Nd4jLong _positiveCount; - Nd4jLong _negativeCount; - Nd4jLong _infCount; - Nd4jLong _nanCount; - }; - - FORCEINLINE bool operator==(DebugInfo const& first, DebugInfo const& second) { - return sd::math::nd4j_abs(first._minValue - second._minValue) < 0.000001 && - sd::math::nd4j_abs(first._maxValue - second._maxValue) < 0.000001 && - sd::math::nd4j_abs(first._meanValue - second._meanValue) < 0.000001 && - sd::math::nd4j_abs(first._stdDevValue - second._stdDevValue) < 0.000001 && - first._zeroCount == second._zeroCount && - first._positiveCount == second._positiveCount && - first._negativeCount == second._negativeCount && - first._infCount == second._infCount && - first._nanCount == second._nanCount; - - } +struct SD_EXPORT DebugInfo { + double _minValue; + double _maxValue; + double _meanValue; + double _stdDevValue; + Nd4jLong _zeroCount; + Nd4jLong _positiveCount; + Nd4jLong _negativeCount; + Nd4jLong _infCount; + Nd4jLong _nanCount; +}; +FORCEINLINE bool operator==(DebugInfo const& first, DebugInfo const& second) { + return sd::math::nd4j_abs(first._minValue - second._minValue) < 0.000001 && + sd::math::nd4j_abs(first._maxValue - second._maxValue) < 0.000001 && + sd::math::nd4j_abs(first._meanValue - second._meanValue) < 0.000001 && + sd::math::nd4j_abs(first._stdDevValue - second._stdDevValue) < + 0.000001 && + first._zeroCount == second._zeroCount && + first._positiveCount == second._positiveCount && + first._negativeCount == second._negativeCount && + first._infCount == second._infCount && + first._nanCount == second._nanCount; } +} // namespace sd -#endif //LIBND4J_DEBUGHELPER_H +#endif // LIBND4J_DEBUGHELPER_H diff --git a/libnd4j/include/helpers/EigenValsAndVecs.h b/libnd4j/include/helpers/EigenValsAndVecs.h index 222b9c36ed36..f4d52f97f868 100644 --- a/libnd4j/include/helpers/EigenValsAndVecs.h +++ b/libnd4j/include/helpers/EigenValsAndVecs.h @@ -30,51 +30,52 @@ namespace helpers { // this class calculates eigenvalues and eigenvectors of given input matrix template class EigenValsAndVecs { + public: + // suppose we got input square NxN matrix - public: - // suppose we got input square NxN matrix + // {N,2} matrix of eigenvalues, 2 means real and imaginary part + NDArray _Vals; + // {N,N,2} matrix, whose columns are the eigenvectors (complex), 2 means real and imaginary part + NDArray _Vecs; - NDArray _Vals; // {N,2} matrix of eigenvalues, 2 means real and imaginary part - NDArray _Vecs; // {N,N,2} matrix, whose columns are the eigenvectors (complex), 2 means real and imaginary part + explicit EigenValsAndVecs(const NDArray& matrix); - explicit EigenValsAndVecs(const NDArray& matrix); + ////////////////////////////////////////////////////////////////////////// + FORCEINLINE static void divideComplexNums(const T& a1, const T& b1, const T& a2, const T& b2, T& a3, T& b3) { + T norm2 = a2*a2 + b2*b2; + a3 = (a1*a2 + b1*b2) / norm2; + b3 = (a2*b1 - a1*b2) / norm2; + } - ////////////////////////////////////////////////////////////////////////// - FORCEINLINE static void divideComplexNums(const T& a1, const T& b1, const T& a2, const T& b2, T& a3, T& b3) { + ////////////////////////////////////////////////////////////////////////// + FORCEINLINE static void multiplyComplexNums(const T& a1, const T& b1, const T& a2, const T& b2, T& a3, T& b3) { + a3 = (a1*a2 - b1*b2); + b3 = (a1*b2 + b1*a2); + } - T norm2 = a2*a2 + b2*b2; + ////////////////////////////////////////////////////////////////////////// + FORCEINLINE static void sqrtComplexNum(T& a, T& b) { + T norm = math::nd4j_sqrt(a*a + b*b); - a3 = (a1*a2 + b1*b2) / norm2; - b3 = (a2*b1 - a1*b2) / norm2; - } + if(b < (T)0) + b = -math::nd4j_sqrt((T)0.5 * (norm - a)); + else + b = math::nd4j_sqrt((T)0.5 * (norm - a)); - ////////////////////////////////////////////////////////////////////////// - FORCEINLINE static void multiplyComplexNums(const T& a1, const T& b1, const T& a2, const T& b2, T& a3, T& b3) { + a = math::nd4j_sqrt((T)0.5 * (norm + a)); + } - a3 = (a1*a2 - b1*b2); - b3 = (a1*b2 + b1*a2); - } - ////////////////////////////////////////////////////////////////////////// - FORCEINLINE static void sqrtComplexNum(T& a, T& b) { + private: + // calculates _Vals + void calcEigenVals(const NDArray& schurMatrixT); - T norm = math::nd4j_sqrt(a*a + b*b); - - if(b < (T)0) - b = -math::nd4j_sqrt((T)0.5 * (norm - a)); - else - b = math::nd4j_sqrt((T)0.5 * (norm - a)); - a = math::nd4j_sqrt((T)0.5 * (norm + a)); - } - - - private: - - void calcEigenVals(const NDArray& schurMatrixT); // calculates _Vals - void calcPseudoEigenVecs(NDArray& schurMatrixT, NDArray& schurMatrixU); // makes changes both in schurMatrixT(NxN) and schurMatrixU(NxN), also calculates and stores pseudo-eigenvectors (real) in schurMatrixU columns - void calcEigenVecs(const NDArray& schurMatrixU); // calculates _Vecs + // makes changes both in schurMatrixT(NxN) and schurMatrixU(NxN), also calculates and stores pseudo-eigenvectors (real) in schurMatrixU columns + void calcPseudoEigenVecs(NDArray& schurMatrixT, NDArray& schurMatrixU); + // calculates _Vecs + void calcEigenVecs(const NDArray& schurMatrixU); }; diff --git a/libnd4j/include/helpers/EnumUtils.h b/libnd4j/include/helpers/EnumUtils.h index 6138117c7e60..6d6fe1facbc8 100644 --- a/libnd4j/include/helpers/EnumUtils.h +++ b/libnd4j/include/helpers/EnumUtils.h @@ -25,12 +25,13 @@ #include namespace sd { - class EnumUtils { - public: - static const char * _VariableTypeToString(sd::graph::VariableType variableType); - static const char * _OpTypeToString(sd::graph::OpType opType); - static const char * _LogicOpToString(int opNum); - }; -} +class EnumUtils { + public: + static const char* _VariableTypeToString( + sd::graph::VariableType variableType); + static const char* _OpTypeToString(sd::graph::OpType opType); + static const char* _LogicOpToString(int opNum); +}; +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/helpers/FileUtils.h b/libnd4j/include/helpers/FileUtils.h new file mode 100644 index 000000000000..cbcda199a4da --- /dev/null +++ b/libnd4j/include/helpers/FileUtils.h @@ -0,0 +1,37 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#ifndef SD_FILEUTILS_H +#define SD_FILEUTILS_H + +#include + +#include + +namespace sd { +class SD_EXPORT FileUtils { + public: + static bool fileExists(const char *filename); + + static int64_t fileSize(const char *filename); +}; +} // namespace sd + +#endif // SD_FILEUTILS_H diff --git a/libnd4j/include/helpers/GradCheck.h b/libnd4j/include/helpers/GradCheck.h index 9ca18a82b907..f7abcaee50cc 100644 --- a/libnd4j/include/helpers/GradCheck.h +++ b/libnd4j/include/helpers/GradCheck.h @@ -21,50 +21,54 @@ #ifndef LIBND4J_GRADCHECK_H #define LIBND4J_GRADCHECK_H - #include #include namespace sd { -class ND4J_EXPORT GradCheck { - - public: - enum LossFunc {MEAN = 0, SUM = 1}; - private: - static constexpr double EPSILON = 1e-5; - static constexpr double MAXRELERR = 1e-5; - static constexpr double MINABSERR = 1e-6; - static void fillGradArrays(const LossFunc loss, const std::vector& gradArrs); - - - public: - - /** - * performs numerical check of gradients in back prop - * - * opFF - feed forward operation - * opBP - back propagation operation - * argsHolderFF - argument holder for feed forward operation - * argsHolderBP - argument holder for back propagation operation - * whatArrsToCheck - specifies what output gradient arrays to check, for example {0, 1, 0} means that only second output gradient array will be checked, default value is empty std::vector which means to check all arrays - * IdxRange - specifies indexes range over which array elements will be checked, for example {0.2, 0.7} means range [0.2*array_length, 0.7*array_length), default value is {0., 1.} - * loss - type of scalar loss function, it specifies what elements values will be filled into input gradient arrays automatically, default value is SUM - */ - static bool checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, const OpArgsHolder& argsHolderFF, const OpArgsHolder& argsHolderBP, - const std::vector& whatArrsToCheck = std::vector(), const std::vector& IdxRange = {0., 1.}, const LossFunc loss = SUM); +class SD_EXPORT GradCheck { + public: + enum LossFunc { MEAN = 0, SUM = 1 }; + + private: + static constexpr double EPSILON = 1e-5; + static constexpr double MAXRELERR = 1e-5; + static constexpr double MINABSERR = 1e-6; + static void fillGradArrays(const LossFunc loss, + const std::vector& gradArrs); + + public: + /** + * performs numerical check of gradients in back prop + * + * opFF - feed forward operation + * opBP - back propagation operation + * argsHolderFF - argument holder for feed forward operation + * argsHolderBP - argument holder for back propagation operation + * whatArrsToCheck - specifies what output gradient arrays to check, for + * example {0, 1, 0} means that only second output gradient array will be + * checked, default value is empty std::vector which means to check all arrays + * IdxRange - specifies indexes range over which array elements will be + * checked, for example {0.2, 0.7} means range [0.2*array_length, + * 0.7*array_length), default value is {0., 1.} loss - type of scalar loss + * function, it specifies what elements values will be filled into input + * gradient arrays automatically, default value is SUM + */ + static bool checkGrad( + ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, + const OpArgsHolder& argsHolderFF, const OpArgsHolder& argsHolderBP, + const std::vector& whatArrsToCheck = std::vector(), + const std::vector& IdxRange = {0., 1.}, + const LossFunc loss = SUM); }; - - - - // ////////////////////////////////////////////////////////////////////////// // ///// IMLEMENTATION OF INLINE METHODS ///// // ////////////////////////////////////////////////////////////////////////// // template -// FORCEINLINE bool ShapeUtils::isPermutNecessary(const std::vector& permut) { +// FORCEINLINE bool ShapeUtils::isPermutNecessary(const std::vector& +// permut) { // for(int i=0; i #include // #include @@ -33,243 +32,262 @@ namespace sd { - -class ND4J_EXPORT LoopKind { - - public: - enum Kind { SMALLARR2DX, EWS1, EWSNONZERO, RANK1, RANK2, RANK3, RANK4, RANK5, X_EWSNONZERO, Y_EWSNONZERO, Z_EWSNONZERO, COMMON, BROADCAST_SCALAR_X, BROADCAST_SCALAR_Y, BROADCAST_3D, BROADCAST_4D, BROADCAST_5D }; - - static FORCEINLINE Kind deduceKindOfLoopXZ(const Nd4jLong* xShapeInfo, const Nd4jLong* zShapeInfo); - static FORCEINLINE Kind deduceKindOfLoopXYZ(const Nd4jLong* xShapeInfo, const Nd4jLong* yShapeInfo, const Nd4jLong* zShapeInfo); - static FORCEINLINE Kind deduceKindOfLoopTadXZ(const Nd4jLong* xShapeInfo, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo); - static FORCEINLINE Kind deduceKindOfLoopTadXYZ(const Nd4jLong* xTadShapeInfo, const Nd4jLong* yTadShapeInfo, const Nd4jLong* zShapeInfo); - static FORCEINLINE Kind deduceKindOfLoopBroadcast(const Nd4jLong* xShapeInfo, const Nd4jLong* yShapeInfo, const Nd4jLong* zShapeInfo); - +class SD_EXPORT LoopKind { + public: + enum Kind { + SMALLARR2DX, + EWS1, + EWSNONZERO, + RANK1, + RANK2, + RANK3, + RANK4, + RANK5, + X_EWSNONZERO, + Y_EWSNONZERO, + Z_EWSNONZERO, + COMMON, + BROADCAST_SCALAR_X, + BROADCAST_SCALAR_Y, + BROADCAST_3D, + BROADCAST_4D, + BROADCAST_5D + }; + + static FORCEINLINE Kind deduceKindOfLoopXZ(const Nd4jLong* xShapeInfo, + const Nd4jLong* zShapeInfo); + static FORCEINLINE Kind deduceKindOfLoopXYZ(const Nd4jLong* xShapeInfo, + const Nd4jLong* yShapeInfo, + const Nd4jLong* zShapeInfo); + static FORCEINLINE Kind deduceKindOfLoopTadXZ(const Nd4jLong* xShapeInfo, + const Nd4jLong* zShapeInfo, + const Nd4jLong* tadShapeInfo); + static FORCEINLINE Kind deduceKindOfLoopTadXYZ(const Nd4jLong* xTadShapeInfo, + const Nd4jLong* yTadShapeInfo, + const Nd4jLong* zShapeInfo); + static FORCEINLINE Kind deduceKindOfLoopBroadcast(const Nd4jLong* xShapeInfo, + const Nd4jLong* yShapeInfo, + const Nd4jLong* zShapeInfo); }; ////////////////////////////////////////////////////////////////////////////// -LoopKind::Kind LoopKind::deduceKindOfLoopXZ(const Nd4jLong* xShapeInfo, const Nd4jLong* zShapeInfo) { - - const int xRank = shape::rank(xShapeInfo); - const Nd4jLong xEws = shape::elementWiseStride(xShapeInfo); - const Nd4jLong zEws = shape::elementWiseStride(zShapeInfo); - - const char xOrder = shape::order(xShapeInfo); - const char zOrder = shape::order(zShapeInfo); - - int temp; - const bool xVectorOrC = shape::isCommonVector(xShapeInfo, temp) || xOrder == 'c'; - const bool zVectorOrC = shape::isCommonVector(zShapeInfo, temp) || zOrder == 'c'; - const bool shapesSame = shape::shapeEquals(xShapeInfo, zShapeInfo); - - if (xEws == 1 && zEws == 1 && xOrder == zOrder && (shapesSame || xOrder == 'c')) - return EWS1; - if(xEws > 0 && zEws > 0 && ((xOrder == zOrder && (shapesSame || xOrder == 'c')) || (xVectorOrC && zVectorOrC))) - return EWSNONZERO; - if(xRank == 1 && shapesSame) - return RANK1; - if(xRank == 2 && shapesSame) - return RANK2; - if(xRank == 3 && shapesSame) - return RANK3; - if(xRank == 4 && shapesSame) - return RANK4; - if(xRank == 5 && shapesSame) - return RANK5; - if(xEws > 0 && xVectorOrC) - return X_EWSNONZERO; - if(zEws > 0 && zVectorOrC) - return Z_EWSNONZERO; - return COMMON; +LoopKind::Kind LoopKind::deduceKindOfLoopXZ(const Nd4jLong* xShapeInfo, + const Nd4jLong* zShapeInfo) { + const int xRank = shape::rank(xShapeInfo); + const Nd4jLong xEws = shape::elementWiseStride(xShapeInfo); + const Nd4jLong zEws = shape::elementWiseStride(zShapeInfo); + + const char xOrder = shape::order(xShapeInfo); + const char zOrder = shape::order(zShapeInfo); + + int temp; + const bool xVectorOrC = + shape::isCommonVector(xShapeInfo, temp) || xOrder == 'c'; + const bool zVectorOrC = + shape::isCommonVector(zShapeInfo, temp) || zOrder == 'c'; + const bool shapesSame = shape::shapeEquals(xShapeInfo, zShapeInfo); + + if (xEws == 1 && zEws == 1 && xOrder == zOrder && + (shapesSame || xOrder == 'c')) + return EWS1; + if (xEws > 0 && zEws > 0 && + ((xOrder == zOrder && (shapesSame || xOrder == 'c')) || + (xVectorOrC && zVectorOrC))) + return EWSNONZERO; + if (xRank == 1 && shapesSame) return RANK1; + if (xRank == 2 && shapesSame) return RANK2; + if (xRank == 3 && shapesSame) return RANK3; + if (xRank == 4 && shapesSame) return RANK4; + if (xRank == 5 && shapesSame) return RANK5; + if (xEws > 0 && xVectorOrC) return X_EWSNONZERO; + if (zEws > 0 && zVectorOrC) return Z_EWSNONZERO; + return COMMON; } -LoopKind::Kind LoopKind::deduceKindOfLoopBroadcast(const Nd4jLong* xShapeInfo, const Nd4jLong* yShapeInfo, const Nd4jLong* zShapeInfo) { - auto xRank = shape::rank(xShapeInfo); - auto yRank = shape::rank(yShapeInfo); - auto zRank = shape::rank(zShapeInfo); - - auto xOrder = shape::order(xShapeInfo); - auto yOrder = shape::order(yShapeInfo); - auto zOrder = shape::order(zShapeInfo); - - auto xEws = shape::elementWiseStride(xShapeInfo); - auto yEws = shape::elementWiseStride(yShapeInfo); - auto zEws = shape::elementWiseStride(zShapeInfo); - - bool bNDLoopsRanks = (xRank == zRank && yRank <= xRank && yRank >= 2); - - int countUnityDimsInY = 0, countUnityDimsInX = 0; - for (int i = 0; i < xRank; i++) { - if (i < yRank) - countUnityDimsInY += (1 == shape::sizeAt(yShapeInfo, i)) ? 1 : 0; - countUnityDimsInX += (1 == shape::sizeAt(xShapeInfo, i)) ? 1 : 0; +LoopKind::Kind LoopKind::deduceKindOfLoopBroadcast(const Nd4jLong* xShapeInfo, + const Nd4jLong* yShapeInfo, + const Nd4jLong* zShapeInfo) { + auto xRank = shape::rank(xShapeInfo); + auto yRank = shape::rank(yShapeInfo); + auto zRank = shape::rank(zShapeInfo); + + auto xOrder = shape::order(xShapeInfo); + auto yOrder = shape::order(yShapeInfo); + auto zOrder = shape::order(zShapeInfo); + + auto xEws = shape::elementWiseStride(xShapeInfo); + auto yEws = shape::elementWiseStride(yShapeInfo); + auto zEws = shape::elementWiseStride(zShapeInfo); + + bool bNDLoopsRanks = (xRank == zRank && yRank <= xRank && yRank >= 2); + + int countUnityDimsInY = 0, countUnityDimsInX = 0; + for (int i = 0; i < xRank; i++) { + if (i < yRank) + countUnityDimsInY += (1 == shape::sizeAt(yShapeInfo, i)) ? 1 : 0; + countUnityDimsInX += (1 == shape::sizeAt(xShapeInfo, i)) ? 1 : 0; + } + + bool bNotCommonVectorCase = + (countUnityDimsInY != yRank - 1) && (countUnityDimsInX != xRank - 1); + + if (bNDLoopsRanks && bNotCommonVectorCase) { + // case x[3,4,5] * y[1,4,5] = z[3,4,5] or reverse x[1,4,5] + y[3,4,5] = + // z[3,4,5] + if (sd::LoopKind::EWS1 == + deduceKindOfLoopXYZ(xShapeInfo, yShapeInfo, zShapeInfo) && + (1 == shape::sizeAt(yShapeInfo, 0) || + 1 == shape::sizeAt(xShapeInfo, 0))) { + return EWS1; } - bool bNotCommonVectorCase = (countUnityDimsInY != yRank - 1) && (countUnityDimsInX != xRank - 1); - - - if (bNDLoopsRanks && bNotCommonVectorCase) { - // case x[3,4,5] * y[1,4,5] = z[3,4,5] or reverse x[1,4,5] + y[3,4,5] = z[3,4,5] - if (sd::LoopKind::EWS1 == deduceKindOfLoopXYZ(xShapeInfo, yShapeInfo, zShapeInfo) - && (1 == shape::sizeAt(yShapeInfo, 0) || 1 == shape::sizeAt(xShapeInfo, 0))) { - return EWS1; - } - - if (3 == xRank) - return sd::LoopKind::BROADCAST_3D; - if (4 == xRank) - return sd::LoopKind::BROADCAST_4D; - if (5 == xRank) - return sd::LoopKind::BROADCAST_5D; + if (3 == xRank) return sd::LoopKind::BROADCAST_3D; + if (4 == xRank) return sd::LoopKind::BROADCAST_4D; + if (5 == xRank) return sd::LoopKind::BROADCAST_5D; + } + if (xRank == yRank && xRank == zRank && xOrder == 'c' && yOrder == 'c' && + zOrder == 'c' && xEws == 1 && yEws == 1 && zEws == 1 && xRank >= 2) { + // we validate that shapes are equal till the last dim + for (int e = 0; e < xRank - 1; e++) { + if (xShapeInfo[e + 1] != yShapeInfo[e + 1]) return COMMON; } + // now, if one of the shapes has 1 as last dim + auto detect = + xShapeInfo[xRank] == 1 ? -1 : (yShapeInfo[xRank] == 1) ? 1 : 0; - if (xRank == yRank && xRank == zRank && xOrder == 'c' && yOrder == 'c' && zOrder == 'c' && xEws == 1 && yEws == 1 && zEws == 1 && xRank >= 2) { - // we validate that shapes are equal till the last dim - for (int e = 0; e < xRank - 1; e++) { - if (xShapeInfo[e+1] != yShapeInfo[e+1]) - return COMMON; - } - - // now, if one of the shapes has 1 as last dim - auto detect = xShapeInfo[xRank] == 1 ? -1 : (yShapeInfo[xRank] == 1) ? 1 : 0; + if (detect == 1) + return sd::LoopKind::BROADCAST_SCALAR_Y; + else if (detect == -1) + return sd::LoopKind::BROADCAST_SCALAR_X; + } - if (detect == 1) - return sd::LoopKind::BROADCAST_SCALAR_Y; - else if (detect == -1) - return sd::LoopKind::BROADCAST_SCALAR_X; - } - - return sd::LoopKind::COMMON; + return sd::LoopKind::COMMON; } ////////////////////////////////////////////////////////////////////////////// -LoopKind::Kind LoopKind::deduceKindOfLoopXYZ(const Nd4jLong* xShapeInfo, const Nd4jLong* yShapeInfo, const Nd4jLong* zShapeInfo) { - - const int xRank = shape::rank(xShapeInfo); - const Nd4jLong xEws = shape::elementWiseStride(xShapeInfo); - const Nd4jLong yEws = shape::elementWiseStride(yShapeInfo); - const Nd4jLong zEws = shape::elementWiseStride(zShapeInfo); - - const char xOrder = shape::order(xShapeInfo); - const char yOrder = shape::order(yShapeInfo); - const char zOrder = shape::order(zShapeInfo); - - int temp; - const bool xVectorOrC = shape::isCommonVector(xShapeInfo, temp) || xOrder == 'c'; - const bool yVectorOrC = shape::isCommonVector(yShapeInfo, temp) || yOrder == 'c'; - const bool zVectorOrC = shape::isCommonVector(zShapeInfo, temp) || zOrder == 'c'; - const bool shapesSame = shape::shapeEquals(xShapeInfo, yShapeInfo, zShapeInfo); - - if (xEws == 1 && yEws == 1 && zEws == 1 && xOrder == yOrder && xOrder == zOrder && (shapesSame || xOrder == 'c')) - return EWS1; - if(xEws > 0 && yEws > 0 && zEws > 0 && ((xOrder == yOrder && xOrder == zOrder && (shapesSame || xOrder == 'c')) || (xVectorOrC && yVectorOrC && zVectorOrC))) - return EWSNONZERO; - if(xRank == 1 && shapesSame) - return RANK1; - if(xRank == 2 && shapesSame) - return RANK2; - if(xRank == 3 && shapesSame) - return RANK3; - if(xRank == 4 && shapesSame) - return RANK4; - if(xRank == 5 && shapesSame) - return RANK5; - if(xEws > 0 && xVectorOrC) - return X_EWSNONZERO; - if(yEws > 0 && yVectorOrC) - return Y_EWSNONZERO; - if(zEws > 0 && zVectorOrC) - return Z_EWSNONZERO; - return COMMON; +LoopKind::Kind LoopKind::deduceKindOfLoopXYZ(const Nd4jLong* xShapeInfo, + const Nd4jLong* yShapeInfo, + const Nd4jLong* zShapeInfo) { + const int xRank = shape::rank(xShapeInfo); + const Nd4jLong xEws = shape::elementWiseStride(xShapeInfo); + const Nd4jLong yEws = shape::elementWiseStride(yShapeInfo); + const Nd4jLong zEws = shape::elementWiseStride(zShapeInfo); + + const char xOrder = shape::order(xShapeInfo); + const char yOrder = shape::order(yShapeInfo); + const char zOrder = shape::order(zShapeInfo); + + int temp; + const bool xVectorOrC = + shape::isCommonVector(xShapeInfo, temp) || xOrder == 'c'; + const bool yVectorOrC = + shape::isCommonVector(yShapeInfo, temp) || yOrder == 'c'; + const bool zVectorOrC = + shape::isCommonVector(zShapeInfo, temp) || zOrder == 'c'; + const bool shapesSame = + shape::shapeEquals(xShapeInfo, yShapeInfo, zShapeInfo); + + if (xEws == 1 && yEws == 1 && zEws == 1 && xOrder == yOrder && + xOrder == zOrder && (shapesSame || xOrder == 'c')) + return EWS1; + if (xEws > 0 && yEws > 0 && zEws > 0 && + ((xOrder == yOrder && xOrder == zOrder && + (shapesSame || xOrder == 'c')) || + (xVectorOrC && yVectorOrC && zVectorOrC))) + return EWSNONZERO; + if (xRank == 1 && shapesSame) return RANK1; + if (xRank == 2 && shapesSame) return RANK2; + if (xRank == 3 && shapesSame) return RANK3; + if (xRank == 4 && shapesSame) return RANK4; + if (xRank == 5 && shapesSame) return RANK5; + if (xEws > 0 && xVectorOrC) return X_EWSNONZERO; + if (yEws > 0 && yVectorOrC) return Y_EWSNONZERO; + if (zEws > 0 && zVectorOrC) return Z_EWSNONZERO; + return COMMON; } - ////////////////////////////////////////////////////////////////////////////// -LoopKind::Kind LoopKind::deduceKindOfLoopTadXZ(const Nd4jLong* xShapeInfo, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo) { - - const int xRank = shape::rank(xShapeInfo); - const int tRank = shape::rank(tadShapeInfo); - - const Nd4jLong xEws = shape::elementWiseStride(xShapeInfo); - const Nd4jLong tEws = shape::elementWiseStride(tadShapeInfo); - const Nd4jLong zEws = shape::elementWiseStride(zShapeInfo); - - const char xOrder = shape::order(xShapeInfo); - const char tOrder = shape::order(tadShapeInfo); - const char zOrder = shape::order(zShapeInfo); - - const bool allC = (tOrder == zOrder && zOrder == 'c'); - - int temp; - const bool tVectorOrC = shape::isCommonVector(tadShapeInfo, temp) || tOrder == 'c'; - const bool zVectorOrC = shape::isCommonVector(zShapeInfo, temp) || zOrder == 'c';; - - if(shape::length(tadShapeInfo) * shape::length(zShapeInfo) <= Environment::getInstance().elementwiseThreshold() && xEws == 1 && xOrder == 'c' && xRank == 2 && - tEws > 1 && zEws == 1 && (allC || (tVectorOrC && zVectorOrC))) - return SMALLARR2DX; - if(tEws == 1 && zEws == 1 && (allC || (tVectorOrC && zVectorOrC))) - return EWS1; - if(tEws > 0 && zEws > 0 && (allC || (tVectorOrC && zVectorOrC))) - return EWSNONZERO; - if(tRank == 1 && zEws == 1 && zVectorOrC) - return RANK1; - if(tRank == 2 && zEws == 1 && zVectorOrC) - return RANK2; - if(tRank == 3 && zEws == 1 && zVectorOrC) - return RANK3; - if(tRank == 4 && zEws == 1 && zVectorOrC) - return RANK4; - if(tRank == 5 && zEws == 1 && zVectorOrC) - return RANK5; - if(tEws > 0 && tVectorOrC && zEws == 0) - return X_EWSNONZERO; - if(zEws > 0 && zVectorOrC && tEws == 0) - return Z_EWSNONZERO; - return COMMON; +LoopKind::Kind LoopKind::deduceKindOfLoopTadXZ(const Nd4jLong* xShapeInfo, + const Nd4jLong* zShapeInfo, + const Nd4jLong* tadShapeInfo) { + const int xRank = shape::rank(xShapeInfo); + const int tRank = shape::rank(tadShapeInfo); + + const Nd4jLong xEws = shape::elementWiseStride(xShapeInfo); + const Nd4jLong tEws = shape::elementWiseStride(tadShapeInfo); + const Nd4jLong zEws = shape::elementWiseStride(zShapeInfo); + + const char xOrder = shape::order(xShapeInfo); + const char tOrder = shape::order(tadShapeInfo); + const char zOrder = shape::order(zShapeInfo); + + const bool allC = (tOrder == zOrder && zOrder == 'c'); + + int temp; + const bool tVectorOrC = + shape::isCommonVector(tadShapeInfo, temp) || tOrder == 'c'; + const bool zVectorOrC = + shape::isCommonVector(zShapeInfo, temp) || zOrder == 'c'; + ; + + if (shape::length(tadShapeInfo) * shape::length(zShapeInfo) <= + Environment::getInstance().elementwiseThreshold() && + xEws == 1 && xOrder == 'c' && xRank == 2 && tEws > 1 && zEws == 1 && + (allC || (tVectorOrC && zVectorOrC))) + return SMALLARR2DX; + if (tEws == 1 && zEws == 1 && (allC || (tVectorOrC && zVectorOrC))) + return EWS1; + if (tEws > 0 && zEws > 0 && (allC || (tVectorOrC && zVectorOrC))) + return EWSNONZERO; + if (tRank == 1 && zEws == 1 && zVectorOrC) return RANK1; + if (tRank == 2 && zEws == 1 && zVectorOrC) return RANK2; + if (tRank == 3 && zEws == 1 && zVectorOrC) return RANK3; + if (tRank == 4 && zEws == 1 && zVectorOrC) return RANK4; + if (tRank == 5 && zEws == 1 && zVectorOrC) return RANK5; + if (tEws > 0 && tVectorOrC && zEws == 0) return X_EWSNONZERO; + if (zEws > 0 && zVectorOrC && tEws == 0) return Z_EWSNONZERO; + return COMMON; } ////////////////////////////////////////////////////////////////////////////// -LoopKind::Kind LoopKind::deduceKindOfLoopTadXYZ(const Nd4jLong* xTadShapeInfo, const Nd4jLong* yTadShapeInfo, const Nd4jLong* zShapeInfo) { - - // both tad shapes are the same, but strides and ews may be different - - const int tadRank = shape::rank(xTadShapeInfo); - - const Nd4jLong xTadEws = shape::elementWiseStride(xTadShapeInfo); - const Nd4jLong yTadEws = shape::elementWiseStride(yTadShapeInfo); - const Nd4jLong zEws = shape::elementWiseStride(zShapeInfo); - - const char xTadOrder = shape::order(xTadShapeInfo); - const char yTadOrder = shape::order(xTadShapeInfo); - const char zOrder = shape::order(zShapeInfo); - - int position; - const bool xTadVectorOrC = shape::isCommonVector(xTadShapeInfo, position) || xTadOrder == 'c'; - const bool yTadVectorOrC = shape::isCommonVector(yTadShapeInfo, position) || yTadOrder == 'c'; - const bool zVectorOrC = shape::isCommonVector(zShapeInfo, position) || zOrder == 'c'; - const bool allC = (xTadOrder == yTadOrder && xTadOrder == zOrder && zOrder == 'c'); - - if(xTadEws == 1 && yTadEws == 1 && zEws == 1 && allC) - return EWS1; - if(xTadEws > 0 && yTadEws > 0 && zEws > 0 && (allC || (xTadVectorOrC && yTadVectorOrC && zVectorOrC))) - return EWSNONZERO; - if(tadRank == 1 && zEws > 0 && zVectorOrC) - return RANK1; - if(tadRank == 2 && zEws > 0 && zVectorOrC) - return RANK2; - if(tadRank == 3 && zEws > 0 && zVectorOrC) - return RANK3; - if(tadRank == 4 && zEws > 0 && zVectorOrC) - return RANK4; - if(tadRank == 5 && zEws > 0 && zVectorOrC) - return RANK5; - return COMMON; +LoopKind::Kind LoopKind::deduceKindOfLoopTadXYZ(const Nd4jLong* xTadShapeInfo, + const Nd4jLong* yTadShapeInfo, + const Nd4jLong* zShapeInfo) { + // both tad shapes are the same, but strides and ews may be different + + const int tadRank = shape::rank(xTadShapeInfo); + + const Nd4jLong xTadEws = shape::elementWiseStride(xTadShapeInfo); + const Nd4jLong yTadEws = shape::elementWiseStride(yTadShapeInfo); + const Nd4jLong zEws = shape::elementWiseStride(zShapeInfo); + + const char xTadOrder = shape::order(xTadShapeInfo); + const char yTadOrder = shape::order(xTadShapeInfo); + const char zOrder = shape::order(zShapeInfo); + + int position; + const bool xTadVectorOrC = + shape::isCommonVector(xTadShapeInfo, position) || xTadOrder == 'c'; + const bool yTadVectorOrC = + shape::isCommonVector(yTadShapeInfo, position) || yTadOrder == 'c'; + const bool zVectorOrC = + shape::isCommonVector(zShapeInfo, position) || zOrder == 'c'; + const bool allC = + (xTadOrder == yTadOrder && xTadOrder == zOrder && zOrder == 'c'); + + if (xTadEws == 1 && yTadEws == 1 && zEws == 1 && allC) return EWS1; + if (xTadEws > 0 && yTadEws > 0 && zEws > 0 && + (allC || (xTadVectorOrC && yTadVectorOrC && zVectorOrC))) + return EWSNONZERO; + if (tadRank == 1 && zEws > 0 && zVectorOrC) return RANK1; + if (tadRank == 2 && zEws > 0 && zVectorOrC) return RANK2; + if (tadRank == 3 && zEws > 0 && zVectorOrC) return RANK3; + if (tadRank == 4 && zEws > 0 && zVectorOrC) return RANK4; + if (tadRank == 5 && zEws > 0 && zVectorOrC) return RANK5; + return COMMON; } - - - -} -#endif //LIBND4J_LOOPKIND_H +} // namespace sd +#endif // LIBND4J_LOOPKIND_H diff --git a/libnd4j/include/helpers/Loops.h b/libnd4j/include/helpers/Loops.h index 9bf3daede64b..ce9d273bb773 100644 --- a/libnd4j/include/helpers/Loops.h +++ b/libnd4j/include/helpers/Loops.h @@ -14,1239 +14,1392 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - // - // @author Yurii Shyrma (iuriish@yahoo.com), created on 14.03.2019 - // +// +// @author Yurii Shyrma (iuriish@yahoo.com), created on 14.03.2019 +// #ifndef LIBND4J_LOOPS_H #define LIBND4J_LOOPS_H -#include -#include -#include +#include +#include +#include #include #include -#include -#include +#include #include -#include +#include #include -#include - -namespace sd { - - template - class ND4J_EXPORT ReductionLoops { - protected: - public: - - template - static FORCEINLINE void loopReduce(const X* x, const Nd4jLong* xShapeInfo, Z* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, E* extraParams, int64_t start, int64_t stop); - }; - - template - class ReductionFloatLoops : public ReductionLoops { - public: - static void wrapper(int opNum, const X* x, const Nd4jLong* xShapeInfo, Z* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, Z* extraParams, int64_t start, int64_t stop); - - template - static void innerloopReduce(const X* x, const Nd4jLong* xShapeInfo, Z* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, Z* extraParams, int64_t start, int64_t stop); - }; - - template - class ND4J_EXPORT ReductionBoolLoops : public ReductionLoops { - public: - static void wrapper(int opNum, const X* x, const Nd4jLong* xShapeInfo, Z* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, X* extraParams, int64_t start, int64_t stop); - - template - static void innerloopReduce(const X* x, const Nd4jLong* xShapeInfo, Z* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, X* extraParams, int64_t start, int64_t stop); - }; - - template - class ND4J_EXPORT ReductionLongLoops : public ReductionLoops { - public: - static void wrapper(int opNum, const X* x, const Nd4jLong* xShapeInfo, Z* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, X* extraParams, int64_t start, int64_t stop); - - template - static void innerloopReduce(const X* x, const Nd4jLong* xShapeInfo, Z* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, X* extraParams, int64_t start, int64_t stop); - }; - - template - class ND4J_EXPORT ReductionSameLoops : public ReductionLoops { - public: - static void wrapper(int opNum, const X* x, const Nd4jLong* xShapeInfo, X* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, X* extraParams, int64_t start, int64_t stop); - - template - static void innerloopReduce(const X* x, const Nd4jLong* xShapeInfo, X* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, X* extraParams, int64_t start, int64_t stop); - }; - - - template - class ND4J_EXPORT IndexReductionLoops { - private: - public: - static void wrapIndexReduce(int opNum, const void* x, const Nd4jLong* xShapeInfo, void* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, void* extraParams); - - template - static void loopIndexReduce(const X* x, const Nd4jLong* xShapeInfo, Z* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, X* extraParams); - }; - - - template - class ND4J_EXPORT TransformLoops { - - public: - - template - static FORCEINLINE void loopTransform(const X* x, const Nd4jLong* xShapeInfo, Z* z, const Nd4jLong* zShapeInfo, E* extraParams, uint64_t threadId, uint64_t numThreads); - }; - - template - class ND4J_EXPORT Reduction3Loops { - public: - - template - static FORCEINLINE void loopReduce3(const X* x, const Nd4jLong* xShapeInfo, const X* y, const Nd4jLong* yShapeInfo, Z* z, const Nd4jLong* zShapeInfo, int* dims, int dimsLen, Z* extraParams, int64_t start, int64_t stop); - - template - static FORCEINLINE void loopReduce3All(const X* x, const Nd4jLong* xShapeInfo, const X* y, const Nd4jLong* yShapeInfo, Z* z, const Nd4jLong* zShapeInfo, const Nd4jLong* xTadShapeInfo, const Nd4jLong* xTadOffsets, const Nd4jLong* yTadShapeInfo, const Nd4jLong* yTadOffsets, Z* extraParams, int64_t start, int64_t stop); - - static void wrapper(int opNum, const X* x, const Nd4jLong* xShapeInfo, const X* y, const Nd4jLong* yShapeInfo, Z* z, const Nd4jLong* zShapeInfo, int* dims, int dimsLen, Z* extraParams, int64_t start, int64_t stop); - - static void wrapperAll(int opNum, const X* x, const Nd4jLong* xShapeInfo, const X* y, const Nd4jLong* yShapeInfo, Z* z, const Nd4jLong* zShapeInfo, const Nd4jLong* xTadShapeInfo, const Nd4jLong* xTadOffsets, const Nd4jLong* yTadShapeInfo, const Nd4jLong* yTadOffsets, Z* extraParams, int64_t start, int64_t stop); - - template - static void innerloopReduce3(const X* x, const Nd4jLong* xShapeInfo, const X* y, const Nd4jLong* yShapeInfo, Z* z, const Nd4jLong* zShapeInfo, int* dims, int dimsLen, Z* extraParams, int64_t start, int64_t stop); - - template - static void innerloopReduce3All(const X* x, const Nd4jLong* xShapeInfo, const X* y, const Nd4jLong* yShapeInfo, Z* z, const Nd4jLong* zShapeInfo, const Nd4jLong* xTadShapeInfo, const Nd4jLong* xTadOffsets, const Nd4jLong* yTadShapeInfo, const Nd4jLong* yTadOffsets, Z* extraParams, int64_t start, int64_t stop); - }; - - - - - /* - ////////////////////////////////////////////////////////////////////////////// - template - void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo, - const Y* y, const Nd4jLong* yShapeInfo, - Z* z, const Nd4jLong* zShapeInfo, - Z* extraParams, - std::function op) { - - const LoopKind::Kind kindOfLoop = LoopKind::deduceKindOfLoopXYZ(xShapeInfo, yShapeInfo, zShapeInfo); - - const Nd4jLong* xShape = shape::shapeOf(xShapeInfo); - const Nd4jLong* xStride = shape::stride(xShapeInfo); - const Nd4jLong* yStride = shape::stride(yShapeInfo); - const Nd4jLong* zStride = shape::stride(zShapeInfo); - - const Nd4jLong len = shape::length(xShapeInfo); - - OmpLaunchHelper threadsInfo(len); - - switch (kindOfLoop) { - - case LoopKind::EWS1: { - PRAGMA_OMP_PARALLEL_THREADS(threadsInfo._numThreads) - { - const auto threadNum = omp_get_thread_num(); - const auto threadOffset = threadsInfo.getThreadOffset(threadNum); - const auto lenPerThread = static_cast(threadsInfo.getItersPerThread(threadNum)); - - const auto xi = x + threadOffset; - const auto yi = y + threadOffset; - auto zi = z + threadOffset; - - PRAGMA_OMP_SIMD - for (uint i = 0; i < lenPerThread; i++) - zi[i] = op(xi[i], yi[i], extraParams); - } - } - break; - - case LoopKind::EWSNONZERO: { - const uint xEws = shape::elementWiseStride(xShapeInfo); - const uint yEws = shape::elementWiseStride(yShapeInfo); - const uint zEws = shape::elementWiseStride(zShapeInfo); - - PRAGMA_OMP_PARALLEL_THREADS(threadsInfo._numThreads) - { - const auto threadNum = omp_get_thread_num(); - const auto threadOffset = threadsInfo.getThreadOffset(threadNum); - const auto lenPerThread = static_cast(threadsInfo.getItersPerThread(threadNum)); - const auto xi = x + threadOffset * xEws; - const auto yi = y + threadOffset * yEws; - auto zi = z + threadOffset * zEws; - - PRAGMA_OMP_SIMD - for (uint i = 0; i < lenPerThread; i++) - zi[i*zEws] = op(xi[i*xEws], yi[i*yEws], extraParams); - } - } - break; - - case LoopKind::RANK1: { - PRAGMA_OMP_PARALLEL_FOR - for (uint i0 = 0; i0 < len; ++i0) - z[i0 * zStride[0]] = op(x[i0 * xStride[0]], y[i0 * yStride[0]], extraParams); - } - break; - - case LoopKind::RANK2: { - PRAGMA_OMP_PARALLEL_FOR_SIMD - for (uint i0 = 0; i0 < xShape[0]; ++i0) - for (uint i1 = 0; i1 < xShape[1]; ++i1) - z[i0 * zStride[0] + i1 * zStride[1]] = op(x[i0 * xStride[0] + i1 * xStride[1]], y[i0 * yStride[0] + i1 * yStride[1]], extraParams); - } - break; - - case LoopKind::RANK3: { - PRAGMA_OMP_PARALLEL_FOR_SIMD_COLLAPSE(2) - for (uint i0 = 0; i0 < xShape[0]; ++i0) - for (uint i1 = 0; i1 < xShape[1]; ++i1) - for (uint i2 = 0; i2 < xShape[2]; ++i2) - z[i0*zStride[0]+i1*zStride[1]+i2*zStride[2]] = op(x[i0*xStride[0]+i1*xStride[1]+i2*xStride[2]], y[i0*yStride[0]+i1*yStride[1]+i2*yStride[2]], extraParams); - } - break; - - case LoopKind::RANK4: { - PRAGMA_OMP_PARALLEL_FOR_SIMD_COLLAPSE(3) - for (uint i0 = 0; i0 < xShape[0]; ++i0) - for (uint i1 = 0; i1 < xShape[1]; ++i1) - for (uint i2 = 0; i2 < xShape[2]; ++i2) - for (uint i3 = 0; i3 < xShape[3]; ++i3) - z[i0*zStride[0]+i1*zStride[1]+i2*zStride[2]+i3*zStride[3]] = op(x[i0*xStride[0]+i1*xStride[1]+i2*xStride[2]+i3*xStride[3]], y[i0*yStride[0]+i1*yStride[1]+i2*yStride[2]+i3*yStride[3]], extraParams); - } - break; - - case LoopKind::RANK5: { - PRAGMA_OMP_PARALLEL_FOR_SIMD_COLLAPSE(4) - for (uint i0 = 0; i0 < xShape[0]; ++i0) - for (uint i1 = 0; i1 < xShape[1]; ++i1) - for (uint i2 = 0; i2 < xShape[2]; ++i2) - for (uint i3 = 0; i3 < xShape[3]; ++i3) - for (uint i4 = 0; i4 < xShape[4]; ++i4) - z[i0*zStride[0]+i1*zStride[1]+i2*zStride[2]+i3*zStride[3]+i4*zStride[4]] = op(x[i0*xStride[0]+i1*xStride[1]+i2*xStride[2]+i3*xStride[3]+i4*xStride[4]], y[i0*yStride[0]+i1*yStride[1]+i2*yStride[2]+i3*yStride[3]+i4*yStride[4]], extraParams); - } - break; - - default: { - uint xShapeInfoCast[MAX_RANK]; - uint yShapeInfoCast[MAX_RANK]; - uint zShapeInfoCast[MAX_RANK]; - - bool canCastX = DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); - bool canCastY = DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); - bool canCastZ = DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast); - - PRAGMA_OMP_PARALLEL_THREADS(threadsInfo._numThreads) - { - auto threadNum = omp_get_thread_num(); - auto threadOffset = threadsInfo.getThreadOffset(threadNum); - auto lenPerThread = static_cast(threadsInfo.getItersPerThread(threadNum)); - PRAGMA_OMP_SIMD - for (uint i = 0; i < lenPerThread; i++) { - auto xOffset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX); - auto yOffset = shape::indexOffset(i + threadOffset, yShapeInfo, yShapeInfoCast, canCastY); - auto zOffset = shape::indexOffset(i + threadOffset, zShapeInfo, zShapeInfoCast, canCastZ); - z[zOffset] = op(x[xOffset], y[yOffset], extraParams); - } - } - } - } - } - */ - - - - ////////////////////////////////////////////////////////////////////////////// - template - template - void sd::ReductionLoops::loopReduce(const X* x, const Nd4jLong* xShapeInfo, - Z* z, const Nd4jLong* zShapeInfo, - const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, - E* extraParams, - int64_t start, int64_t stop) { - - const LoopKind::Kind kindOfLoop = LoopKind::deduceKindOfLoopTadXZ(xShapeInfo, zShapeInfo, tadShapeInfo); - - const Nd4jLong zLen = shape::length(zShapeInfo); - const Nd4jLong tadLen = shape::length(tadShapeInfo); - - const uint tadEws = shape::elementWiseStride(tadShapeInfo); - const uint zEws = shape::elementWiseStride(zShapeInfo); - - const Nd4jLong* tadShape = shape::shapeOf(tadShapeInfo); - const Nd4jLong* tadStride = shape::stride(tadShapeInfo); - - int numThreads = OmpLaunchHelper::tadThreads(tadLen, zLen); - - switch (kindOfLoop) { - - //*********************************************// - // case LoopKind::SMALLARR2DX: { - // shape::printShapeInfoLinear(xShapeInfo); - // shape::printShapeInfoLinear(zShapeInfo); - // const auto xLen = zLen * tadLen; - // for (uint i = 0; i < xLen; ++i) { - // const auto zOffset = shape::subArrayOffset(i, xShapeInfo, zShapeInfo, dimsToExclude, dimsLen); - // const uint tadInd = (i / tadEws) % tadLen; - // auto startVal = tadInd ? z[zOffset] : static_cast(OpType::startingValue(x)); - // z[zOffset] = OpType::update(startVal, OpType::op(x[i], extraParams), extraParams); - // if(tadInd == tadLen - 1) - // z[zOffset] = OpType::postProcess(z[zOffset], tadLen, extraParams); - // printf("%u - %lld\n", i, zOffset); - // } - // } - case LoopKind::SMALLARR2DX: { - const auto uTadLen = static_cast(tadLen); - const auto uZLenMinusOne = static_cast(zLen - 1); - const auto xLen = static_cast(zLen * uTadLen); - const auto sv = static_cast(OpType::startingValue(x)); - - for (uint i = 0; i <= uZLenMinusOne; i++) - z[i] = OpType::startingValue(x); - - uint zOffset = 0; - for (uint i = 0; i < xLen; ++i) { - z[zOffset] = OpType::update(z[zOffset], OpType::op(x[i], extraParams), extraParams); - zOffset = zOffset == uZLenMinusOne ? 0 : zOffset + 1; - } - - for (uint i = 0; i <= uZLenMinusOne; i++) - z[i] = OpType::postProcess(z[i], tadLen, extraParams); - } - break; - - //*********************************************// - case LoopKind::EWS1: { - for (auto i = start; i < stop; i++) { - auto tad = x + tadOffsets[i]; - auto s = OpType::startingValue(tad); - - for (Nd4jLong j = 0; j < tadLen; j++) - s = OpType::update(s, OpType::op(tad[j], extraParams), extraParams); - - z[i] = OpType::postProcess(s, tadLen, extraParams); - }; - } - break; - - //*********************************************// - case LoopKind::EWSNONZERO: { - for (auto i = start; i < stop; i++) { - auto tad = x + tadOffsets[i]; - auto s = OpType::startingValue(tad); - - for (Nd4jLong j = 0; j < tadLen; j++) - s = OpType::update(s, OpType::op(tad[j * tadEws], extraParams), extraParams); - - z[i * zEws] = OpType::postProcess(s, tadLen, extraParams); - }; - } - break; - - //*********************************************// - case LoopKind::RANK1: { - for (auto i = start; i < stop; i++) { - auto tad = x + tadOffsets[i]; - auto s = OpType::startingValue(tad); - - for (Nd4jLong i0 = 0; i0 < tadLen; ++i0) - s = OpType::update(s, OpType::op(tad[i0 * tadStride[0]], extraParams), extraParams); - - z[i] = OpType::postProcess(s, tadLen, extraParams); - }; - } - break; - - //*********************************************// - case LoopKind::RANK2: { - for (auto i = start; i < stop; i++) { - auto tad = x + tadOffsets[i]; - auto s = OpType::startingValue(tad); - - for (Nd4jLong i0 = 0; i0 < tadShape[0]; ++i0) - for (Nd4jLong i1 = 0; i1 < tadShape[1]; ++i1) - s = OpType::update(s, OpType::op(tad[i0 * tadStride[0] + i1 * tadStride[1]], extraParams), extraParams); - - z[i] = OpType::postProcess(s, tadLen, extraParams); - }; - } - break; - - //*********************************************// - case LoopKind::RANK3: { - for (auto i = start; i < stop; i++) { - auto tad = x + tadOffsets[i]; - auto s = OpType::startingValue(tad); - - for (Nd4jLong i0 = 0; i0 < tadShape[0]; ++i0) - for (Nd4jLong i1 = 0; i1 < tadShape[1]; ++i1) - for (Nd4jLong i2 = 0; i2 < tadShape[2]; ++i2) - s = OpType::update(s, OpType::op(tad[i0 * tadStride[0] + i1 * tadStride[1] + i2 * tadStride[2]], extraParams), extraParams); - - z[i] = OpType::postProcess(s, tadLen, extraParams); - }; - } - break; - - //*********************************************// - case LoopKind::RANK4: { - for (auto i = start; i < stop; i++) { - auto tad = x + tadOffsets[i]; - auto s = OpType::startingValue(tad); - - for (Nd4jLong i0 = 0; i0 < tadShape[0]; ++i0) - for (Nd4jLong i1 = 0; i1 < tadShape[1]; ++i1) - for (Nd4jLong i2 = 0; i2 < tadShape[2]; ++i2) - for (Nd4jLong i3 = 0; i3 < tadShape[3]; ++i3) - s = OpType::update(s, OpType::op(tad[i0 * tadStride[0] + i1 * tadStride[1] + i2 * tadStride[2] + i3 * tadStride[3]], extraParams), extraParams); - - z[i] = OpType::postProcess(s, tadLen, extraParams); - }; - } - break; - - //*********************************************// - case LoopKind::RANK5: { - for (auto i = start; i < stop; i++) { - auto tad = x + tadOffsets[i]; - auto s = OpType::startingValue(tad); - - for (Nd4jLong i0 = 0; i0 < tadShape[0]; ++i0) - for (Nd4jLong i1 = 0; i1 < tadShape[1]; ++i1) - for (Nd4jLong i2 = 0; i2 < tadShape[2]; ++i2) - for (Nd4jLong i3 = 0; i3 < tadShape[3]; ++i3) - for (Nd4jLong i4 = 0; i4 < tadShape[4]; ++i4) - s = OpType::update(s, OpType::op(tad[i0 * tadStride[0] + i1 * tadStride[1] + i2 * tadStride[2] + i3 * tadStride[3] + i4 * tadStride[4]], extraParams), extraParams); - - z[i] = OpType::postProcess(s, tadLen, extraParams); - }; - } - break; - - //*********************************************// - case LoopKind::X_EWSNONZERO: { - uint castZShapeInfo[MAX_RANK]; - const bool canCastZ = sd::DataTypeUtils::castShapeInfo(zShapeInfo, castZShapeInfo); - - for (auto i = start; i < stop; i++) { - auto tad = x + tadOffsets[i]; - auto s = OpType::startingValue(tad); - - for (Nd4jLong j = 0; j < tadLen; j++) - s = OpType::update(s, OpType::op(tad[j * tadEws], extraParams), extraParams); - - auto zOffset = shape::indexOffset(i, zShapeInfo, castZShapeInfo, canCastZ); - z[zOffset] = OpType::postProcess(s, tadLen, extraParams); - }; - } - break; - - //*********************************************// - case LoopKind::Z_EWSNONZERO: { - uint castTadShapeInfo[MAX_RANK]; - const bool canCastTad = sd::DataTypeUtils::castShapeInfo(tadShapeInfo, castTadShapeInfo); - - for (auto i = start; i < stop; i++) { - auto tad = x + tadOffsets[i]; - auto s = OpType::startingValue(tad); - - for (Nd4jLong j = 0; j < tadLen; j++) { - auto tadOffset = shape::indexOffset(j, tadShapeInfo, castTadShapeInfo, canCastTad); - s = OpType::update(s, OpType::op(tad[tadOffset], extraParams), extraParams); - } - - z[i * zEws] = OpType::postProcess(s, tadLen, extraParams); - }; - } - break; - - //*********************************************// - default: { - auto innertadOffsets = new Nd4jLong[tadLen]; - shape::calcOffsets(tadShapeInfo, innertadOffsets); - - uint castZShapeInfo[MAX_RANK]; - const bool canCastZ = sd::DataTypeUtils::castShapeInfo(zShapeInfo, castZShapeInfo); - - for (auto i = start; i < stop; i++) { - auto tad = x + tadOffsets[i]; - auto s = OpType::startingValue(tad); - - for (Nd4jLong j = 0; j < tadLen; j++) - s = OpType::update(s, OpType::op(tad[innertadOffsets[j]], extraParams), extraParams); - - auto zOffset = shape::indexOffset(i, zShapeInfo, castZShapeInfo, canCastZ); - z[zOffset] = OpType::postProcess(s, tadLen, extraParams); - }; - - delete[] innertadOffsets; - } - } - } - - - - ////////////////////////////////////////////////////////////////////////////// - template - template - void sd::TransformLoops::loopTransform(const X* x, const Nd4jLong* xShapeInfo, - Z* z, const Nd4jLong* zShapeInfo, - E* extraParams, - uint64_t threadId, uint64_t numThreads) { - - const LoopKind::Kind kindOfLoop = LoopKind::deduceKindOfLoopXZ(xShapeInfo, zShapeInfo); - - const Nd4jLong* xShape = shape::shapeOf(const_cast(xShapeInfo)); - const Nd4jLong* xStride = shape::stride(const_cast(xShapeInfo)); - const Nd4jLong* zStride = shape::stride(const_cast(zShapeInfo)); +#include - const Nd4jLong len = shape::length(xShapeInfo); +#include - if (len == 0) - return; +namespace sd { - switch (kindOfLoop) { +template +class SD_EXPORT ReductionLoops { + protected: + public: + template + static FORCEINLINE void loopReduce(const X* x, const Nd4jLong* xShapeInfo, + Z* z, const Nd4jLong* zShapeInfo, + const Nd4jLong* tadShapeInfo, + const Nd4jLong* tadOffsets, E* extraParams, + int64_t start, int64_t stop); +}; + +template +class ReductionFloatLoops : public ReductionLoops { + public: + static void wrapper(int opNum, const X* x, const Nd4jLong* xShapeInfo, Z* z, + const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, + const Nd4jLong* tadOffsets, Z* extraParams, int64_t start, + int64_t stop); + + template + static void innerloopReduce(const X* x, const Nd4jLong* xShapeInfo, Z* z, + const Nd4jLong* zShapeInfo, + const Nd4jLong* tadShapeInfo, + const Nd4jLong* tadOffsets, Z* extraParams, + int64_t start, int64_t stop); +}; + +template +class SD_EXPORT ReductionBoolLoops : public ReductionLoops { + public: + static void wrapper(int opNum, const X* x, const Nd4jLong* xShapeInfo, Z* z, + const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, + const Nd4jLong* tadOffsets, X* extraParams, int64_t start, + int64_t stop); + + template + static void innerloopReduce(const X* x, const Nd4jLong* xShapeInfo, Z* z, + const Nd4jLong* zShapeInfo, + const Nd4jLong* tadShapeInfo, + const Nd4jLong* tadOffsets, X* extraParams, + int64_t start, int64_t stop); +}; + +template +class SD_EXPORT ReductionLongLoops : public ReductionLoops { + public: + static void wrapper(int opNum, const X* x, const Nd4jLong* xShapeInfo, Z* z, + const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, + const Nd4jLong* tadOffsets, X* extraParams, int64_t start, + int64_t stop); + + template + static void innerloopReduce(const X* x, const Nd4jLong* xShapeInfo, Z* z, + const Nd4jLong* zShapeInfo, + const Nd4jLong* tadShapeInfo, + const Nd4jLong* tadOffsets, X* extraParams, + int64_t start, int64_t stop); +}; + +template +class SD_EXPORT ReductionSameLoops : public ReductionLoops { + public: + static void wrapper(int opNum, const X* x, const Nd4jLong* xShapeInfo, X* z, + const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, + const Nd4jLong* tadOffsets, X* extraParams, int64_t start, + int64_t stop); + + template + static void innerloopReduce(const X* x, const Nd4jLong* xShapeInfo, X* z, + const Nd4jLong* zShapeInfo, + const Nd4jLong* tadShapeInfo, + const Nd4jLong* tadOffsets, X* extraParams, + int64_t start, int64_t stop); +}; + +template +class SD_EXPORT IndexReductionLoops { + private: + public: + static void wrapIndexReduce(int opNum, const void* x, + const Nd4jLong* xShapeInfo, void* z, + const Nd4jLong* zShapeInfo, + const Nd4jLong* tadShapeInfo, + const Nd4jLong* tadOffsets, void* extraParams); + + template + static void loopIndexReduce(const X* x, const Nd4jLong* xShapeInfo, Z* z, + const Nd4jLong* zShapeInfo, + const Nd4jLong* tadShapeInfo, + const Nd4jLong* tadOffsets, X* extraParams); +}; + +template +class SD_EXPORT TransformLoops { + public: + template + static FORCEINLINE void loopTransform(const X* x, const Nd4jLong* xShapeInfo, + Z* z, const Nd4jLong* zShapeInfo, + E* extraParams, uint64_t threadId, + uint64_t numThreads); +}; + +template +class SD_EXPORT Reduction3Loops { + public: + template + static FORCEINLINE void loopReduce3(const X* x, const Nd4jLong* xShapeInfo, + const X* y, const Nd4jLong* yShapeInfo, + Z* z, const Nd4jLong* zShapeInfo, + int* dims, int dimsLen, Z* extraParams, + int64_t start, int64_t stop); + + template + static FORCEINLINE void loopReduce3All( + const X* x, const Nd4jLong* xShapeInfo, const X* y, + const Nd4jLong* yShapeInfo, Z* z, const Nd4jLong* zShapeInfo, + const Nd4jLong* xTadShapeInfo, const Nd4jLong* xTadOffsets, + const Nd4jLong* yTadShapeInfo, const Nd4jLong* yTadOffsets, + Z* extraParams, int64_t start, int64_t stop); + + static void wrapper(int opNum, const X* x, const Nd4jLong* xShapeInfo, + const X* y, const Nd4jLong* yShapeInfo, Z* z, + const Nd4jLong* zShapeInfo, int* dims, int dimsLen, + Z* extraParams, int64_t start, int64_t stop); + + static void wrapperAll(int opNum, const X* x, const Nd4jLong* xShapeInfo, + const X* y, const Nd4jLong* yShapeInfo, Z* z, + const Nd4jLong* zShapeInfo, + const Nd4jLong* xTadShapeInfo, + const Nd4jLong* xTadOffsets, + const Nd4jLong* yTadShapeInfo, + const Nd4jLong* yTadOffsets, Z* extraParams, + int64_t start, int64_t stop); + + template + static void innerloopReduce3(const X* x, const Nd4jLong* xShapeInfo, + const X* y, const Nd4jLong* yShapeInfo, Z* z, + const Nd4jLong* zShapeInfo, int* dims, + int dimsLen, Z* extraParams, int64_t start, + int64_t stop); + + template + static void innerloopReduce3All(const X* x, const Nd4jLong* xShapeInfo, + const X* y, const Nd4jLong* yShapeInfo, Z* z, + const Nd4jLong* zShapeInfo, + const Nd4jLong* xTadShapeInfo, + const Nd4jLong* xTadOffsets, + const Nd4jLong* yTadShapeInfo, + const Nd4jLong* yTadOffsets, Z* extraParams, + int64_t start, int64_t stop); +}; + +/* +////////////////////////////////////////////////////////////////////////////// +template +void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo, + const Y* y, const Nd4jLong* yShapeInfo, + Z* z, const Nd4jLong* zShapeInfo, + Z* extraParams, + std::function op) { + + const LoopKind::Kind kindOfLoop = LoopKind::deduceKindOfLoopXYZ(xShapeInfo, +yShapeInfo, zShapeInfo); + + const Nd4jLong* xShape = shape::shapeOf(xShapeInfo); + const Nd4jLong* xStride = shape::stride(xShapeInfo); + const Nd4jLong* yStride = shape::stride(yShapeInfo); + const Nd4jLong* zStride = shape::stride(zShapeInfo); + + const Nd4jLong len = shape::length(xShapeInfo); + + OmpLaunchHelper threadsInfo(len); + + switch (kindOfLoop) { - //*********************************************// case LoopKind::EWS1: { - auto span = samediff::Span::build(threadId, numThreads, 0, len, 1); - int64_t start = span.startX(), stop = span.stopX(); - - for (auto i = start; i < stop; i++) - z[i] = OpType::op(x[i], extraParams); + PRAGMA_OMP_PARALLEL_THREADS(threadsInfo._numThreads) + { + const auto threadNum = omp_get_thread_num(); + const auto threadOffset = +threadsInfo.getThreadOffset(threadNum); const auto lenPerThread = +static_cast(threadsInfo.getItersPerThread(threadNum)); + + const auto xi = x + threadOffset; + const auto yi = y + threadOffset; + auto zi = z + threadOffset; + + PRAGMA_OMP_SIMD + for (uint i = 0; i < lenPerThread; i++) + zi[i] = op(xi[i], yi[i], extraParams); + } } - break; + break; - //*********************************************// case LoopKind::EWSNONZERO: { const uint xEws = shape::elementWiseStride(xShapeInfo); + const uint yEws = shape::elementWiseStride(yShapeInfo); const uint zEws = shape::elementWiseStride(zShapeInfo); - auto span = samediff::Span::build(threadId, numThreads, 0, len, 1); - int64_t start = span.startX(), stop = span.stopX(); - - for (auto i = start; i < stop; i++) - z[i * zEws] = OpType::op(x[i * xEws], extraParams); - } - break; - - //*********************************************// - case LoopKind::Z_EWSNONZERO: { - const uint zEws = shape::elementWiseStride(zShapeInfo); - uint castXShapeInfo[MAX_RANK]; - const bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, castXShapeInfo); - - auto span = samediff::Span::build(threadId, numThreads, 0, len, 1); - int64_t start = span.startX(), stop = span.stopX(); - - if (zEws > 1) { - for (auto i = start; i < stop; i++) { - const auto xOffset = shape::indexOffset(i, xShapeInfo, castXShapeInfo, canCastX); - z[i * zEws] = OpType::op(x[xOffset], extraParams); - } - } - else { - for (auto i = start; i < stop; i++) { - const auto xOffset = shape::indexOffset(i, xShapeInfo, castXShapeInfo, canCastX); - z[i] = OpType::op(x[xOffset], extraParams); - } + PRAGMA_OMP_PARALLEL_THREADS(threadsInfo._numThreads) + { + const auto threadNum = omp_get_thread_num(); + const auto threadOffset = +threadsInfo.getThreadOffset(threadNum); const auto lenPerThread = +static_cast(threadsInfo.getItersPerThread(threadNum)); const auto xi = x + +threadOffset * xEws; const auto yi = y + threadOffset * yEws; auto zi = z + +threadOffset * zEws; + + PRAGMA_OMP_SIMD + for (uint i = 0; i < lenPerThread; i++) + zi[i*zEws] = op(xi[i*xEws], yi[i*yEws], extraParams); } } - break; + break; - //*********************************************// case LoopKind::RANK1: { - auto span = samediff::Span::build(threadId, numThreads, 0, len, 1); - - for (auto i0 = span.startX(); i0 < span.stopX(); i0++) - z[i0 * zStride[0]] = OpType::op(x[i0 * xStride[0]], extraParams); + PRAGMA_OMP_PARALLEL_FOR + for (uint i0 = 0; i0 < len; ++i0) + z[i0 * zStride[0]] = op(x[i0 * xStride[0]], y[i0 * yStride[0]], +extraParams); } - break; + break; - //*********************************************// case LoopKind::RANK2: { - auto uXShape0 = static_cast(xShape[0]); - auto uXShape1 = static_cast(xShape[1]); - - auto loop = samediff::ThreadsHelper::pickLoop2d(numThreads, uXShape0, uXShape1); - auto span = samediff::Span2::build(loop, threadId, numThreads, 0, uXShape0, 1, 0, uXShape1, 1); - - for (auto i0 = span.startX(); i0 < span.stopX(); i0++) { - auto z0 = i0 * zStride[0]; - auto x0 = i0 * xStride[0]; - - for (auto i1 = span.startY(); i1 < span.stopY(); ++i1) - z[z0 + i1 * zStride[1]] = OpType::op(x[x0 + i1 * xStride[1]], extraParams); - } + PRAGMA_OMP_PARALLEL_FOR_SIMD + for (uint i0 = 0; i0 < xShape[0]; ++i0) + for (uint i1 = 0; i1 < xShape[1]; ++i1) + z[i0 * zStride[0] + i1 * zStride[1]] = op(x[i0 * xStride[0] ++ i1 * xStride[1]], y[i0 * yStride[0] + i1 * yStride[1]], extraParams); } - break; + break; - //*********************************************// case LoopKind::RANK3: { - auto uXShape0 = xShape[0]; - auto uXShape1 = xShape[1]; - auto uXShape2 = xShape[2]; - - auto loop = samediff::ThreadsHelper::pickLoop2d(numThreads, uXShape0, uXShape1); - auto span = samediff::Span2::build(loop, threadId, numThreads, 0, uXShape0, 1, 0, uXShape1, 1); - - - for (auto i0 = span.startX(); i0 < span.stopX(); i0++) - for (auto i1 = span.startY(); i1 < span.stopY(); i1++) { - auto z0 = i0 * zStride[0] + i1 * zStride[1]; - auto x0 = i0 * xStride[0] + i1 * xStride[1]; - - for (Nd4jLong i2 = 0; i2 < uXShape2; ++i2) - z[z0 + i2 * zStride[2]] = OpType::op(x[x0 + i2 * xStride[2]], extraParams); - } + PRAGMA_OMP_PARALLEL_FOR_SIMD_COLLAPSE(2) + for (uint i0 = 0; i0 < xShape[0]; ++i0) + for (uint i1 = 0; i1 < xShape[1]; ++i1) + for (uint i2 = 0; i2 < xShape[2]; ++i2) + z[i0*zStride[0]+i1*zStride[1]+i2*zStride[2]] = +op(x[i0*xStride[0]+i1*xStride[1]+i2*xStride[2]], +y[i0*yStride[0]+i1*yStride[1]+i2*yStride[2]], extraParams); } - break; + break; - //*********************************************// case LoopKind::RANK4: { - auto uXShape0 = xShape[0]; - auto uXShape1 = xShape[1]; - auto uXShape2 = xShape[2]; - auto uXShape3 = xShape[3]; - - auto loop = samediff::ThreadsHelper::pickLoop3d(numThreads, uXShape0, uXShape1, uXShape2); - auto span = samediff::Span3::build(loop, threadId, numThreads, 0, uXShape0, 1, 0, uXShape1, 1, 0, uXShape2, 1); - - for (auto i0 = span.startX(); i0 < span.stopX(); i0++) - for (auto i1 = span.startY(); i1 < span.stopY(); i1++) - for (auto i2 = span.startZ(); i2 < span.stopZ(); i2++) { - auto x0 = i0 * xStride[0] + i1 * xStride[1] + i2 * xStride[2]; - auto z0 = i0 * zStride[0] + i1 * zStride[1] + i2 * zStride[2]; - - for (Nd4jLong i3 = 0; i3 < uXShape3; ++i3) - z[z0 + i3 * zStride[3]] = OpType::op(x[x0 + i3 * xStride[3]], extraParams); - } + PRAGMA_OMP_PARALLEL_FOR_SIMD_COLLAPSE(3) + for (uint i0 = 0; i0 < xShape[0]; ++i0) + for (uint i1 = 0; i1 < xShape[1]; ++i1) + for (uint i2 = 0; i2 < xShape[2]; ++i2) + for (uint i3 = 0; i3 < xShape[3]; ++i3) + z[i0*zStride[0]+i1*zStride[1]+i2*zStride[2]+i3*zStride[3]] += op(x[i0*xStride[0]+i1*xStride[1]+i2*xStride[2]+i3*xStride[3]], +y[i0*yStride[0]+i1*yStride[1]+i2*yStride[2]+i3*yStride[3]], extraParams); } - break; + break; - //*********************************************// case LoopKind::RANK5: { - auto uXShape0 = xShape[0]; - auto uXShape1 = xShape[1]; - auto uXShape2 = xShape[2]; - auto uXShape3 = xShape[3]; - auto uXShape4 = xShape[4]; - - auto loop = samediff::ThreadsHelper::pickLoop3d(numThreads, uXShape0, uXShape1, uXShape2); - auto span = samediff::Span3::build(loop, threadId, numThreads, 0, uXShape0, 1, 0, uXShape1, 1, 0, uXShape2, 1); - - - for (auto i0 = span.startX(); i0 < span.stopX(); i0++) - for (auto i1 = span.startY(); i1 < span.stopY(); i1++) - for (auto i2 = span.startZ(); i2 < span.stopZ(); i2++) { - auto z0 = i0 * zStride[0] + i1 * zStride[1] + i2 * zStride[2]; - auto x0 = i0 * xStride[0] + i1 * xStride[1] + i2 * xStride[2]; - - for (Nd4jLong i3 = 0; i3 < uXShape3; ++i3) { - - auto z1 = z0 + i3 * zStride[3]; - auto x1 = x0 + i3 * xStride[3]; - - for (Nd4jLong i4 = 0; i4 < uXShape4; ++i4) - z[z1 + i4 * zStride[4]] = OpType::op(x[x1 + i4 * xStride[4]], extraParams); - - } - } - + PRAGMA_OMP_PARALLEL_FOR_SIMD_COLLAPSE(4) + for (uint i0 = 0; i0 < xShape[0]; ++i0) + for (uint i1 = 0; i1 < xShape[1]; ++i1) + for (uint i2 = 0; i2 < xShape[2]; ++i2) + for (uint i3 = 0; i3 < xShape[3]; ++i3) + for (uint i4 = 0; i4 < xShape[4]; ++i4) + z[i0*zStride[0]+i1*zStride[1]+i2*zStride[2]+i3*zStride[3]+i4*zStride[4]] += op(x[i0*xStride[0]+i1*xStride[1]+i2*xStride[2]+i3*xStride[3]+i4*xStride[4]], +y[i0*yStride[0]+i1*yStride[1]+i2*yStride[2]+i3*yStride[3]+i4*yStride[4]], +extraParams); } - break; + break; - //*********************************************// default: { uint xShapeInfoCast[MAX_RANK]; + uint yShapeInfoCast[MAX_RANK]; uint zShapeInfoCast[MAX_RANK]; - bool canCastX = DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); - bool canCastZ = DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast); - - auto span = samediff::Span::build(threadId, numThreads, 0, len, 1); - - for (auto i = span.startX(); i < span.stopX(); i++) { - auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); - auto zOffset = shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ); - z[zOffset] = OpType::op(x[xOffset], extraParams); + bool canCastX = DataTypeUtils::castShapeInfo(xShapeInfo, +xShapeInfoCast); bool canCastY = DataTypeUtils::castShapeInfo(yShapeInfo, +yShapeInfoCast); bool canCastZ = DataTypeUtils::castShapeInfo(zShapeInfo, +zShapeInfoCast); + + PRAGMA_OMP_PARALLEL_THREADS(threadsInfo._numThreads) + { + auto threadNum = omp_get_thread_num(); + auto threadOffset = threadsInfo.getThreadOffset(threadNum); + auto lenPerThread = +static_cast(threadsInfo.getItersPerThread(threadNum)); PRAGMA_OMP_SIMD for +(uint i = 0; i < lenPerThread; i++) { auto xOffset = shape::indexOffset(i + +threadOffset, xShapeInfo, xShapeInfoCast, canCastX); auto yOffset = +shape::indexOffset(i + threadOffset, yShapeInfo, yShapeInfoCast, canCastY); auto +zOffset = shape::indexOffset(i + threadOffset, zShapeInfo, zShapeInfoCast, +canCastZ); z[zOffset] = op(x[xOffset], y[yOffset], extraParams); + } } } - - } } - - - ////////////////////////////////////////////////////////////////////////////// - template - template - void sd::Reduction3Loops::loopReduce3(const X* x, const Nd4jLong* xShapeInfo, - const X* y, const Nd4jLong* yShapeInfo, - Z* z, const Nd4jLong* zShapeInfo, - int* dims, int dimsLen, - Z* extraParameters, int64_t start, int64_t stop) { - - // both tads have same shape, however strides and ews may differ - - Z param0(OpType::startingValue(x)), param1(OpType::startingValue(x)), param2(extraParameters ? extraParameters[0] : OpType::startingValue(x)); - - const Nd4jLong xLen = shape::length(xShapeInfo); - const Nd4jLong yLen = shape::length(yShapeInfo); - - const Nd4jLong* xTadShapeInfo = nullptr, * yTadShapeInfo = nullptr, * xTadOffsets = nullptr, * yTadOffsets = nullptr; - TadPack tadPackX, tadPackY; - std::vector zeroOffsets; - - if (xLen == yLen) { - tadPackX = sd::ConstantTadHelper::getInstance().tadForDimensions(xShapeInfo, dims, dimsLen); - tadPackY = sd::ConstantTadHelper::getInstance().tadForDimensions(yShapeInfo, dims, dimsLen); - xTadShapeInfo = tadPackX.primaryShapeInfo(); - yTadShapeInfo = tadPackY.primaryShapeInfo(); - xTadOffsets = tadPackX.primaryOffsets(); - yTadOffsets = tadPackY.primaryOffsets(); - } - else if (yLen > xLen) { - tadPackY = sd::ConstantTadHelper::getInstance().tadForDimensions(yShapeInfo, dims, dimsLen); - xTadShapeInfo = xShapeInfo; - yTadShapeInfo = tadPackY.primaryShapeInfo(); - yTadOffsets = tadPackY.primaryOffsets(); - } - else { - tadPackX = sd::ConstantTadHelper::getInstance().tadForDimensions(xShapeInfo, dims, dimsLen); - yTadShapeInfo = yShapeInfo; - xTadShapeInfo = tadPackX.primaryShapeInfo(); - xTadOffsets = tadPackX.primaryOffsets(); +} +*/ + +////////////////////////////////////////////////////////////////////////////// +template +template +void sd::ReductionLoops::loopReduce( + const X* x, const Nd4jLong* xShapeInfo, Z* z, const Nd4jLong* zShapeInfo, + const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, E* extraParams, + int64_t start, int64_t stop) { + const LoopKind::Kind kindOfLoop = + LoopKind::deduceKindOfLoopTadXZ(xShapeInfo, zShapeInfo, tadShapeInfo); + + const Nd4jLong zLen = shape::length(zShapeInfo); + const Nd4jLong tadLen = shape::length(tadShapeInfo); + + const uint tadEws = shape::elementWiseStride(tadShapeInfo); + const uint zEws = shape::elementWiseStride(zShapeInfo); + + const Nd4jLong* tadShape = shape::shapeOf(tadShapeInfo); + const Nd4jLong* tadStride = shape::stride(tadShapeInfo); + + int numThreads = OmpLaunchHelper::tadThreads(tadLen, zLen); + + switch (kindOfLoop) { + //*********************************************// + // case LoopKind::SMALLARR2DX: { + // shape::printShapeInfoLinear(xShapeInfo); + // shape::printShapeInfoLinear(zShapeInfo); + // const auto xLen = zLen * tadLen; + // for (uint i = 0; i < xLen; ++i) { + // const auto zOffset = shape::subArrayOffset(i, xShapeInfo, + // zShapeInfo, dimsToExclude, dimsLen); const uint tadInd = (i / + // tadEws) % tadLen; auto startVal = tadInd ? z[zOffset] : + // static_cast(OpType::startingValue(x)); z[zOffset] = + // OpType::update(startVal, OpType::op(x[i], extraParams), + // extraParams); if(tadInd == tadLen - 1) + // z[zOffset] = OpType::postProcess(z[zOffset], tadLen, + // extraParams); + // printf("%u - %lld\n", i, zOffset); + // } + // } + case LoopKind::SMALLARR2DX: { + const auto uTadLen = static_cast(tadLen); + const auto uZLenMinusOne = static_cast(zLen - 1); + const auto xLen = static_cast(zLen * uTadLen); + const auto sv = static_cast(OpType::startingValue(x)); + + for (uint i = 0; i <= uZLenMinusOne; i++) z[i] = OpType::startingValue(x); + + uint zOffset = 0; + for (uint i = 0; i < xLen; ++i) { + z[zOffset] = OpType::update(z[zOffset], OpType::op(x[i], extraParams), + extraParams); + zOffset = zOffset == uZLenMinusOne ? 0 : zOffset + 1; + } + + for (uint i = 0; i <= uZLenMinusOne; i++) + z[i] = OpType::postProcess(z[i], tadLen, extraParams); + } break; + + //*********************************************// + case LoopKind::EWS1: { + for (auto i = start; i < stop; i++) { + auto tad = x + tadOffsets[i]; + auto s = OpType::startingValue(tad); + + for (Nd4jLong j = 0; j < tadLen; j++) + s = OpType::update(s, OpType::op(tad[j], extraParams), extraParams); + + z[i] = OpType::postProcess(s, tadLen, extraParams); + }; + } break; + + //*********************************************// + case LoopKind::EWSNONZERO: { + for (auto i = start; i < stop; i++) { + auto tad = x + tadOffsets[i]; + auto s = OpType::startingValue(tad); + + for (Nd4jLong j = 0; j < tadLen; j++) + s = OpType::update(s, OpType::op(tad[j * tadEws], extraParams), + extraParams); + + z[i * zEws] = OpType::postProcess(s, tadLen, extraParams); + }; + } break; + + //*********************************************// + case LoopKind::RANK1: { + for (auto i = start; i < stop; i++) { + auto tad = x + tadOffsets[i]; + auto s = OpType::startingValue(tad); + + for (Nd4jLong i0 = 0; i0 < tadLen; ++i0) + s = OpType::update(s, OpType::op(tad[i0 * tadStride[0]], extraParams), + extraParams); + + z[i] = OpType::postProcess(s, tadLen, extraParams); + }; + } break; + + //*********************************************// + case LoopKind::RANK2: { + for (auto i = start; i < stop; i++) { + auto tad = x + tadOffsets[i]; + auto s = OpType::startingValue(tad); + + for (Nd4jLong i0 = 0; i0 < tadShape[0]; ++i0) + for (Nd4jLong i1 = 0; i1 < tadShape[1]; ++i1) + s = OpType::update( + s, + OpType::op(tad[i0 * tadStride[0] + i1 * tadStride[1]], + extraParams), + extraParams); + + z[i] = OpType::postProcess(s, tadLen, extraParams); + }; + } break; + + //*********************************************// + case LoopKind::RANK3: { + for (auto i = start; i < stop; i++) { + auto tad = x + tadOffsets[i]; + auto s = OpType::startingValue(tad); + + for (Nd4jLong i0 = 0; i0 < tadShape[0]; ++i0) + for (Nd4jLong i1 = 0; i1 < tadShape[1]; ++i1) + for (Nd4jLong i2 = 0; i2 < tadShape[2]; ++i2) + s = OpType::update( + s, + OpType::op(tad[i0 * tadStride[0] + i1 * tadStride[1] + + i2 * tadStride[2]], + extraParams), + extraParams); + + z[i] = OpType::postProcess(s, tadLen, extraParams); + }; + } break; + + //*********************************************// + case LoopKind::RANK4: { + for (auto i = start; i < stop; i++) { + auto tad = x + tadOffsets[i]; + auto s = OpType::startingValue(tad); + + for (Nd4jLong i0 = 0; i0 < tadShape[0]; ++i0) + for (Nd4jLong i1 = 0; i1 < tadShape[1]; ++i1) + for (Nd4jLong i2 = 0; i2 < tadShape[2]; ++i2) + for (Nd4jLong i3 = 0; i3 < tadShape[3]; ++i3) + s = OpType::update( + s, + OpType::op(tad[i0 * tadStride[0] + i1 * tadStride[1] + + i2 * tadStride[2] + i3 * tadStride[3]], + extraParams), + extraParams); + + z[i] = OpType::postProcess(s, tadLen, extraParams); + }; + } break; + + //*********************************************// + case LoopKind::RANK5: { + for (auto i = start; i < stop; i++) { + auto tad = x + tadOffsets[i]; + auto s = OpType::startingValue(tad); + + for (Nd4jLong i0 = 0; i0 < tadShape[0]; ++i0) + for (Nd4jLong i1 = 0; i1 < tadShape[1]; ++i1) + for (Nd4jLong i2 = 0; i2 < tadShape[2]; ++i2) + for (Nd4jLong i3 = 0; i3 < tadShape[3]; ++i3) + for (Nd4jLong i4 = 0; i4 < tadShape[4]; ++i4) + s = OpType::update( + s, + OpType::op(tad[i0 * tadStride[0] + i1 * tadStride[1] + + i2 * tadStride[2] + i3 * tadStride[3] + + i4 * tadStride[4]], + extraParams), + extraParams); + + z[i] = OpType::postProcess(s, tadLen, extraParams); + }; + } break; + + //*********************************************// + case LoopKind::X_EWSNONZERO: { + uint castZShapeInfo[MAX_RANK]; + const bool canCastZ = + sd::DataTypeUtils::castShapeInfo(zShapeInfo, castZShapeInfo); + + for (auto i = start; i < stop; i++) { + auto tad = x + tadOffsets[i]; + auto s = OpType::startingValue(tad); + + for (Nd4jLong j = 0; j < tadLen; j++) + s = OpType::update(s, OpType::op(tad[j * tadEws], extraParams), + extraParams); + + auto zOffset = + shape::indexOffset(i, zShapeInfo, castZShapeInfo, canCastZ); + z[zOffset] = OpType::postProcess(s, tadLen, extraParams); + }; + } break; + + //*********************************************// + case LoopKind::Z_EWSNONZERO: { + uint castTadShapeInfo[MAX_RANK]; + const bool canCastTad = sd::DataTypeUtils::castShapeInfo( + tadShapeInfo, castTadShapeInfo); + + for (auto i = start; i < stop; i++) { + auto tad = x + tadOffsets[i]; + auto s = OpType::startingValue(tad); + + for (Nd4jLong j = 0; j < tadLen; j++) { + auto tadOffset = + shape::indexOffset(j, tadShapeInfo, castTadShapeInfo, canCastTad); + s = OpType::update(s, OpType::op(tad[tadOffset], extraParams), + extraParams); } + z[i * zEws] = OpType::postProcess(s, tadLen, extraParams); + }; + } break; - const LoopKind::Kind kindOfLoop = LoopKind::deduceKindOfLoopTadXYZ(xTadShapeInfo, yTadShapeInfo, zShapeInfo); - - const auto xTadEws = shape::elementWiseStride(xTadShapeInfo); - const auto yTadEws = shape::elementWiseStride(yTadShapeInfo); - const auto zEws = shape::elementWiseStride(zShapeInfo); + //*********************************************// + default: { + auto innertadOffsets = new Nd4jLong[tadLen]; + shape::calcOffsets(tadShapeInfo, innertadOffsets); - const auto zLen = shape::length(zShapeInfo); - const auto tadLen = shape::length(xTadShapeInfo); + uint castZShapeInfo[MAX_RANK]; + const bool canCastZ = + sd::DataTypeUtils::castShapeInfo(zShapeInfo, castZShapeInfo); - const auto tadShape = shape::shapeOf(xTadShapeInfo); - const auto xTadStride = shape::stride(xTadShapeInfo); - const auto yTadStride = shape::stride(xTadShapeInfo); + for (auto i = start; i < stop; i++) { + auto tad = x + tadOffsets[i]; + auto s = OpType::startingValue(tad); - int numThreads = OmpLaunchHelper::tadThreads(tadLen, zLen); + for (Nd4jLong j = 0; j < tadLen; j++) + s = OpType::update( + s, OpType::op(tad[innertadOffsets[j]], extraParams), extraParams); - switch (kindOfLoop) { + auto zOffset = + shape::indexOffset(i, zShapeInfo, castZShapeInfo, canCastZ); + z[zOffset] = OpType::postProcess(s, tadLen, extraParams); + }; - //*********************************************// - case LoopKind::EWS1: { - Z extraParams[3]; - for (auto i = start; i < stop; i++) { - extraParams[0] = param0; - extraParams[1] = param1; - extraParams[2] = param2; - - const auto xTad = xTadOffsets ? x + xTadOffsets[i] : x; - const auto yTad = yTadOffsets ? y + yTadOffsets[i] : y; - auto s = OpType::startingValue(xTad); - - for (Nd4jLong j = 0; j < tadLen; ++j) - s = OpType::update(s, OpType::op(xTad[j], yTad[j], extraParams), extraParams); + delete[] innertadOffsets; + } + } +} - z[i] = OpType::postProcess(s, tadLen, extraParams); - }; +////////////////////////////////////////////////////////////////////////////// +template +template +void sd::TransformLoops::loopTransform( + const X* x, const Nd4jLong* xShapeInfo, Z* z, const Nd4jLong* zShapeInfo, + E* extraParams, uint64_t threadId, uint64_t numThreads) { + const LoopKind::Kind kindOfLoop = + LoopKind::deduceKindOfLoopXZ(xShapeInfo, zShapeInfo); + + const Nd4jLong* xShape = shape::shapeOf(const_cast(xShapeInfo)); + const Nd4jLong* xStride = shape::stride(const_cast(xShapeInfo)); + const Nd4jLong* zStride = shape::stride(const_cast(zShapeInfo)); + + const Nd4jLong len = shape::length(xShapeInfo); + + if (len == 0) return; + + switch (kindOfLoop) { + //*********************************************// + case LoopKind::EWS1: { + auto span = samediff::Span::build(threadId, numThreads, 0, len, 1); + int64_t start = span.startX(), stop = span.stopX(); + + for (auto i = start; i < stop; i++) z[i] = OpType::op(x[i], extraParams); + } break; + + //*********************************************// + case LoopKind::EWSNONZERO: { + const uint xEws = shape::elementWiseStride(xShapeInfo); + const uint zEws = shape::elementWiseStride(zShapeInfo); + + auto span = samediff::Span::build(threadId, numThreads, 0, len, 1); + int64_t start = span.startX(), stop = span.stopX(); + + for (auto i = start; i < stop; i++) + z[i * zEws] = OpType::op(x[i * xEws], extraParams); + } break; + + //*********************************************// + case LoopKind::Z_EWSNONZERO: { + const uint zEws = shape::elementWiseStride(zShapeInfo); + uint castXShapeInfo[MAX_RANK]; + const bool canCastX = + sd::DataTypeUtils::castShapeInfo(xShapeInfo, castXShapeInfo); + + auto span = samediff::Span::build(threadId, numThreads, 0, len, 1); + int64_t start = span.startX(), stop = span.stopX(); + + if (zEws > 1) { + for (auto i = start; i < stop; i++) { + const auto xOffset = + shape::indexOffset(i, xShapeInfo, castXShapeInfo, canCastX); + z[i * zEws] = OpType::op(x[xOffset], extraParams); + } + } else { + for (auto i = start; i < stop; i++) { + const auto xOffset = + shape::indexOffset(i, xShapeInfo, castXShapeInfo, canCastX); + z[i] = OpType::op(x[xOffset], extraParams); } - break; + } + } break; + + //*********************************************// + case LoopKind::RANK1: { + auto span = samediff::Span::build(threadId, numThreads, 0, len, 1); + + for (auto i0 = span.startX(); i0 < span.stopX(); i0++) + z[i0 * zStride[0]] = OpType::op(x[i0 * xStride[0]], extraParams); + } break; + + //*********************************************// + case LoopKind::RANK2: { + auto uXShape0 = static_cast(xShape[0]); + auto uXShape1 = static_cast(xShape[1]); + + auto loop = + samediff::ThreadsHelper::pickLoop2d(numThreads, uXShape0, uXShape1); + auto span = samediff::Span2::build(loop, threadId, numThreads, 0, + uXShape0, 1, 0, uXShape1, 1); + + for (auto i0 = span.startX(); i0 < span.stopX(); i0++) { + auto z0 = i0 * zStride[0]; + auto x0 = i0 * xStride[0]; + + for (auto i1 = span.startY(); i1 < span.stopY(); ++i1) + z[z0 + i1 * zStride[1]] = + OpType::op(x[x0 + i1 * xStride[1]], extraParams); + } + } break; + + //*********************************************// + case LoopKind::RANK3: { + auto uXShape0 = xShape[0]; + auto uXShape1 = xShape[1]; + auto uXShape2 = xShape[2]; + + auto loop = + samediff::ThreadsHelper::pickLoop2d(numThreads, uXShape0, uXShape1); + auto span = samediff::Span2::build(loop, threadId, numThreads, 0, + uXShape0, 1, 0, uXShape1, 1); + + for (auto i0 = span.startX(); i0 < span.stopX(); i0++) + for (auto i1 = span.startY(); i1 < span.stopY(); i1++) { + auto z0 = i0 * zStride[0] + i1 * zStride[1]; + auto x0 = i0 * xStride[0] + i1 * xStride[1]; + + for (Nd4jLong i2 = 0; i2 < uXShape2; ++i2) + z[z0 + i2 * zStride[2]] = + OpType::op(x[x0 + i2 * xStride[2]], extraParams); + } + } break; + + //*********************************************// + case LoopKind::RANK4: { + auto uXShape0 = xShape[0]; + auto uXShape1 = xShape[1]; + auto uXShape2 = xShape[2]; + auto uXShape3 = xShape[3]; + + auto loop = samediff::ThreadsHelper::pickLoop3d(numThreads, uXShape0, + uXShape1, uXShape2); + auto span = + samediff::Span3::build(loop, threadId, numThreads, 0, uXShape0, 1, 0, + uXShape1, 1, 0, uXShape2, 1); + + for (auto i0 = span.startX(); i0 < span.stopX(); i0++) + for (auto i1 = span.startY(); i1 < span.stopY(); i1++) + for (auto i2 = span.startZ(); i2 < span.stopZ(); i2++) { + auto x0 = i0 * xStride[0] + i1 * xStride[1] + i2 * xStride[2]; + auto z0 = i0 * zStride[0] + i1 * zStride[1] + i2 * zStride[2]; + + for (Nd4jLong i3 = 0; i3 < uXShape3; ++i3) + z[z0 + i3 * zStride[3]] = + OpType::op(x[x0 + i3 * xStride[3]], extraParams); + } + } break; + + //*********************************************// + case LoopKind::RANK5: { + auto uXShape0 = xShape[0]; + auto uXShape1 = xShape[1]; + auto uXShape2 = xShape[2]; + auto uXShape3 = xShape[3]; + auto uXShape4 = xShape[4]; + + auto loop = samediff::ThreadsHelper::pickLoop3d(numThreads, uXShape0, + uXShape1, uXShape2); + auto span = + samediff::Span3::build(loop, threadId, numThreads, 0, uXShape0, 1, 0, + uXShape1, 1, 0, uXShape2, 1); + + for (auto i0 = span.startX(); i0 < span.stopX(); i0++) + for (auto i1 = span.startY(); i1 < span.stopY(); i1++) + for (auto i2 = span.startZ(); i2 < span.stopZ(); i2++) { + auto z0 = i0 * zStride[0] + i1 * zStride[1] + i2 * zStride[2]; + auto x0 = i0 * xStride[0] + i1 * xStride[1] + i2 * xStride[2]; + + for (Nd4jLong i3 = 0; i3 < uXShape3; ++i3) { + auto z1 = z0 + i3 * zStride[3]; + auto x1 = x0 + i3 * xStride[3]; + + for (Nd4jLong i4 = 0; i4 < uXShape4; ++i4) + z[z1 + i4 * zStride[4]] = + OpType::op(x[x1 + i4 * xStride[4]], extraParams); + } + } - //*********************************************// - case LoopKind::EWSNONZERO: { - Z extraParams[3]; - for (auto i = start; i < stop; i++) { - extraParams[0] = param0; - extraParams[1] = param1; - extraParams[2] = param2; + } break; - const auto xTad = xTadOffsets ? x + xTadOffsets[i] : x; - const auto yTad = yTadOffsets ? y + yTadOffsets[i] : y; - auto s = OpType::startingValue(xTad); + //*********************************************// + default: { + uint xShapeInfoCast[MAX_RANK]; + uint zShapeInfoCast[MAX_RANK]; - for (Nd4jLong j = 0; j < tadLen; ++j) - s = OpType::update(s, OpType::op(xTad[j * xTadEws], yTad[j * yTadEws], extraParams), extraParams); + bool canCastX = DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + bool canCastZ = DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast); - z[i * zEws] = OpType::postProcess(s, tadLen, extraParams); - }; - } - break; + auto span = samediff::Span::build(threadId, numThreads, 0, len, 1); - //*********************************************// - case LoopKind::RANK1: { - Z extraParams[3]; - for (auto i = start; i < stop; i++) { - extraParams[0] = param0; - extraParams[1] = param1; - extraParams[2] = param2; - - const auto xTad = xTadOffsets ? x + xTadOffsets[i] : x; - const auto yTad = yTadOffsets ? y + yTadOffsets[i] : y; - auto s = OpType::startingValue(xTad); - - for (Nd4jLong i0 = 0; i0 < tadLen; ++i0) { - const auto xTadOffset = i0 * xTadStride[0]; - const auto yTadOffset = i0 * yTadStride[0]; - s = OpType::update(s, OpType::op(xTad[xTadOffset], yTad[yTadOffset], extraParams), extraParams); - } + for (auto i = span.startX(); i < span.stopX(); i++) { + auto xOffset = + shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); + auto zOffset = + shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ); + z[zOffset] = OpType::op(x[xOffset], extraParams); + } + } + } +} - z[i * zEws] = OpType::postProcess(s, tadLen, extraParams); - }; +////////////////////////////////////////////////////////////////////////////// +template +template +void sd::Reduction3Loops::loopReduce3( + const X* x, const Nd4jLong* xShapeInfo, const X* y, + const Nd4jLong* yShapeInfo, Z* z, const Nd4jLong* zShapeInfo, int* dims, + int dimsLen, Z* extraParameters, int64_t start, int64_t stop) { + // both tads have same shape, however strides and ews may differ + + Z param0(OpType::startingValue(x)), param1(OpType::startingValue(x)), + param2(extraParameters ? extraParameters[0] : OpType::startingValue(x)); + + const Nd4jLong xLen = shape::length(xShapeInfo); + const Nd4jLong yLen = shape::length(yShapeInfo); + + const Nd4jLong *xTadShapeInfo = nullptr, *yTadShapeInfo = nullptr, + *xTadOffsets = nullptr, *yTadOffsets = nullptr; + TadPack tadPackX, tadPackY; + std::vector zeroOffsets; + + if (xLen == yLen) { + tadPackX = sd::ConstantTadHelper::getInstance().tadForDimensions( + xShapeInfo, dims, dimsLen); + tadPackY = sd::ConstantTadHelper::getInstance().tadForDimensions( + yShapeInfo, dims, dimsLen); + xTadShapeInfo = tadPackX.primaryShapeInfo(); + yTadShapeInfo = tadPackY.primaryShapeInfo(); + xTadOffsets = tadPackX.primaryOffsets(); + yTadOffsets = tadPackY.primaryOffsets(); + } else if (yLen > xLen) { + tadPackY = sd::ConstantTadHelper::getInstance().tadForDimensions( + yShapeInfo, dims, dimsLen); + xTadShapeInfo = xShapeInfo; + yTadShapeInfo = tadPackY.primaryShapeInfo(); + yTadOffsets = tadPackY.primaryOffsets(); + } else { + tadPackX = sd::ConstantTadHelper::getInstance().tadForDimensions( + xShapeInfo, dims, dimsLen); + yTadShapeInfo = yShapeInfo; + xTadShapeInfo = tadPackX.primaryShapeInfo(); + xTadOffsets = tadPackX.primaryOffsets(); + } + + const LoopKind::Kind kindOfLoop = LoopKind::deduceKindOfLoopTadXYZ( + xTadShapeInfo, yTadShapeInfo, zShapeInfo); + + const auto xTadEws = shape::elementWiseStride(xTadShapeInfo); + const auto yTadEws = shape::elementWiseStride(yTadShapeInfo); + const auto zEws = shape::elementWiseStride(zShapeInfo); + + const auto zLen = shape::length(zShapeInfo); + const auto tadLen = shape::length(xTadShapeInfo); + + const auto tadShape = shape::shapeOf(xTadShapeInfo); + const auto xTadStride = shape::stride(xTadShapeInfo); + const auto yTadStride = shape::stride(xTadShapeInfo); + + int numThreads = OmpLaunchHelper::tadThreads(tadLen, zLen); + + switch (kindOfLoop) { + //*********************************************// + case LoopKind::EWS1: { + Z extraParams[3]; + for (auto i = start; i < stop; i++) { + extraParams[0] = param0; + extraParams[1] = param1; + extraParams[2] = param2; + + const auto xTad = xTadOffsets ? x + xTadOffsets[i] : x; + const auto yTad = yTadOffsets ? y + yTadOffsets[i] : y; + auto s = OpType::startingValue(xTad); + + for (Nd4jLong j = 0; j < tadLen; ++j) + s = OpType::update(s, OpType::op(xTad[j], yTad[j], extraParams), + extraParams); + + z[i] = OpType::postProcess(s, tadLen, extraParams); + }; + } break; + + //*********************************************// + case LoopKind::EWSNONZERO: { + Z extraParams[3]; + for (auto i = start; i < stop; i++) { + extraParams[0] = param0; + extraParams[1] = param1; + extraParams[2] = param2; + + const auto xTad = xTadOffsets ? x + xTadOffsets[i] : x; + const auto yTad = yTadOffsets ? y + yTadOffsets[i] : y; + auto s = OpType::startingValue(xTad); + + for (Nd4jLong j = 0; j < tadLen; ++j) + s = OpType::update( + s, OpType::op(xTad[j * xTadEws], yTad[j * yTadEws], extraParams), + extraParams); + + z[i * zEws] = OpType::postProcess(s, tadLen, extraParams); + }; + } break; + + //*********************************************// + case LoopKind::RANK1: { + Z extraParams[3]; + for (auto i = start; i < stop; i++) { + extraParams[0] = param0; + extraParams[1] = param1; + extraParams[2] = param2; + + const auto xTad = xTadOffsets ? x + xTadOffsets[i] : x; + const auto yTad = yTadOffsets ? y + yTadOffsets[i] : y; + auto s = OpType::startingValue(xTad); + + for (Nd4jLong i0 = 0; i0 < tadLen; ++i0) { + const auto xTadOffset = i0 * xTadStride[0]; + const auto yTadOffset = i0 * yTadStride[0]; + s = OpType::update( + s, OpType::op(xTad[xTadOffset], yTad[yTadOffset], extraParams), + extraParams); } - break; - //*********************************************// - case LoopKind::RANK2: { - Z extraParams[3]; - for (auto i = start; i < stop; i++) { - extraParams[0] = param0; - extraParams[1] = param1; - extraParams[2] = param2; - - const auto xTad = xTadOffsets ? x + xTadOffsets[i] : x; - const auto yTad = yTadOffsets ? y + yTadOffsets[i] : y; - auto s = OpType::startingValue(xTad); - - for (Nd4jLong i0 = 0; i0 < tadShape[0]; ++i0) { - for (Nd4jLong i1 = 0; i1 < tadShape[1]; ++i1) { - const auto xTadOffset = i0 * xTadStride[0] + i1 * xTadStride[1]; - const auto yTadOffset = i0 * yTadStride[0] + i1 * yTadStride[1]; - s = OpType::update(s, OpType::op(xTad[xTadOffset], yTad[yTadOffset], extraParams), extraParams); - } - } - z[i * zEws] = OpType::postProcess(s, tadLen, extraParams); - }; + z[i * zEws] = OpType::postProcess(s, tadLen, extraParams); + }; + } break; + + //*********************************************// + case LoopKind::RANK2: { + Z extraParams[3]; + for (auto i = start; i < stop; i++) { + extraParams[0] = param0; + extraParams[1] = param1; + extraParams[2] = param2; + + const auto xTad = xTadOffsets ? x + xTadOffsets[i] : x; + const auto yTad = yTadOffsets ? y + yTadOffsets[i] : y; + auto s = OpType::startingValue(xTad); + + for (Nd4jLong i0 = 0; i0 < tadShape[0]; ++i0) { + for (Nd4jLong i1 = 0; i1 < tadShape[1]; ++i1) { + const auto xTadOffset = i0 * xTadStride[0] + i1 * xTadStride[1]; + const auto yTadOffset = i0 * yTadStride[0] + i1 * yTadStride[1]; + s = OpType::update( + s, OpType::op(xTad[xTadOffset], yTad[yTadOffset], extraParams), + extraParams); + } } - break; - - //*********************************************// - case LoopKind::RANK3: { - Z extraParams[3]; - for (auto i = start; i < stop; i++) { - extraParams[0] = param0; - extraParams[1] = param1; - extraParams[2] = param2; - - const auto xTad = xTadOffsets ? x + xTadOffsets[i] : x; - const auto yTad = yTadOffsets ? y + yTadOffsets[i] : y; - auto s = OpType::startingValue(xTad); - - for (Nd4jLong i0 = 0; i0 < tadShape[0]; ++i0) { - for (Nd4jLong i1 = 0; i1 < tadShape[1]; ++i1) { - for (Nd4jLong i2 = 0; i2 < tadShape[2]; ++i2) { - const auto xTadOffset = i0 * xTadStride[0] + i1 * xTadStride[1] + i2 * xTadStride[2]; - const auto yTadOffset = i0 * yTadStride[0] + i1 * yTadStride[1] + i2 * yTadStride[2]; - s = OpType::update(s, OpType::op(xTad[xTadOffset], yTad[yTadOffset], extraParams), extraParams); - } - } - } - z[i * zEws] = OpType::postProcess(s, tadLen, extraParams); - }; + z[i * zEws] = OpType::postProcess(s, tadLen, extraParams); + }; + } break; + + //*********************************************// + case LoopKind::RANK3: { + Z extraParams[3]; + for (auto i = start; i < stop; i++) { + extraParams[0] = param0; + extraParams[1] = param1; + extraParams[2] = param2; + + const auto xTad = xTadOffsets ? x + xTadOffsets[i] : x; + const auto yTad = yTadOffsets ? y + yTadOffsets[i] : y; + auto s = OpType::startingValue(xTad); + + for (Nd4jLong i0 = 0; i0 < tadShape[0]; ++i0) { + for (Nd4jLong i1 = 0; i1 < tadShape[1]; ++i1) { + for (Nd4jLong i2 = 0; i2 < tadShape[2]; ++i2) { + const auto xTadOffset = + i0 * xTadStride[0] + i1 * xTadStride[1] + i2 * xTadStride[2]; + const auto yTadOffset = + i0 * yTadStride[0] + i1 * yTadStride[1] + i2 * yTadStride[2]; + s = OpType::update( + s, + OpType::op(xTad[xTadOffset], yTad[yTadOffset], extraParams), + extraParams); + } + } } - break; - - //*********************************************// - case LoopKind::RANK4: { - Z extraParams[3]; - for (auto i = start; i < stop; i++) { - extraParams[0] = param0; - extraParams[1] = param1; - extraParams[2] = param2; - - const auto xTad = xTadOffsets ? x + xTadOffsets[i] : x; - const auto yTad = yTadOffsets ? y + yTadOffsets[i] : y; - auto s = OpType::startingValue(xTad); - - for (Nd4jLong i0 = 0; i0 < tadShape[0]; ++i0) { - for (Nd4jLong i1 = 0; i1 < tadShape[1]; ++i1) { - for (Nd4jLong i2 = 0; i2 < tadShape[2]; ++i2) { - for (Nd4jLong i3 = 0; i3 < tadShape[3]; ++i3) { - const auto xTadOffset = i0 * xTadStride[0] + i1 * xTadStride[1] + i2 * xTadStride[2] + i3 * xTadStride[3]; - const auto yTadOffset = i0 * yTadStride[0] + i1 * yTadStride[1] + i2 * yTadStride[2] + i3 * yTadStride[3]; - s = OpType::update(s, OpType::op(xTad[xTadOffset], yTad[yTadOffset], extraParams), extraParams); - } - } - } - } - z[i * zEws] = OpType::postProcess(s, tadLen, extraParams); - }; + z[i * zEws] = OpType::postProcess(s, tadLen, extraParams); + }; + } break; + + //*********************************************// + case LoopKind::RANK4: { + Z extraParams[3]; + for (auto i = start; i < stop; i++) { + extraParams[0] = param0; + extraParams[1] = param1; + extraParams[2] = param2; + + const auto xTad = xTadOffsets ? x + xTadOffsets[i] : x; + const auto yTad = yTadOffsets ? y + yTadOffsets[i] : y; + auto s = OpType::startingValue(xTad); + + for (Nd4jLong i0 = 0; i0 < tadShape[0]; ++i0) { + for (Nd4jLong i1 = 0; i1 < tadShape[1]; ++i1) { + for (Nd4jLong i2 = 0; i2 < tadShape[2]; ++i2) { + for (Nd4jLong i3 = 0; i3 < tadShape[3]; ++i3) { + const auto xTadOffset = i0 * xTadStride[0] + + i1 * xTadStride[1] + + i2 * xTadStride[2] + i3 * xTadStride[3]; + const auto yTadOffset = i0 * yTadStride[0] + + i1 * yTadStride[1] + + i2 * yTadStride[2] + i3 * yTadStride[3]; + s = OpType::update( + s, + OpType::op(xTad[xTadOffset], yTad[yTadOffset], extraParams), + extraParams); + } + } + } } - break; - - //*********************************************// - case LoopKind::RANK5: { - Z extraParams[3]; - for (auto i = start; i < stop; i++) { - extraParams[0] = param0; - extraParams[1] = param1; - extraParams[2] = param2; - - const auto xTad = xTadOffsets ? x + xTadOffsets[i] : x; - const auto yTad = yTadOffsets ? y + yTadOffsets[i] : y; - auto s = OpType::startingValue(xTad); - - for (Nd4jLong i0 = 0; i0 < tadShape[0]; ++i0) { - for (Nd4jLong i1 = 0; i1 < tadShape[1]; ++i1) { - for (Nd4jLong i2 = 0; i2 < tadShape[2]; ++i2) { - for (Nd4jLong i3 = 0; i3 < tadShape[3]; ++i3) { - for (Nd4jLong i4 = 0; i4 < tadShape[4]; ++i4) { - const auto xTadOffset = i0 * xTadStride[0] + i1 * xTadStride[1] + i2 * xTadStride[2] + i3 * xTadStride[3] + i4 * xTadStride[4]; - const auto yTadOffset = i0 * yTadStride[0] + i1 * yTadStride[1] + i2 * yTadStride[2] + i3 * yTadStride[3] + i4 * yTadStride[4]; - s = OpType::update(s, OpType::op(xTad[xTadOffset], yTad[yTadOffset], extraParams), extraParams); - } - } - } - } + z[i * zEws] = OpType::postProcess(s, tadLen, extraParams); + }; + } break; + + //*********************************************// + case LoopKind::RANK5: { + Z extraParams[3]; + for (auto i = start; i < stop; i++) { + extraParams[0] = param0; + extraParams[1] = param1; + extraParams[2] = param2; + + const auto xTad = xTadOffsets ? x + xTadOffsets[i] : x; + const auto yTad = yTadOffsets ? y + yTadOffsets[i] : y; + auto s = OpType::startingValue(xTad); + + for (Nd4jLong i0 = 0; i0 < tadShape[0]; ++i0) { + for (Nd4jLong i1 = 0; i1 < tadShape[1]; ++i1) { + for (Nd4jLong i2 = 0; i2 < tadShape[2]; ++i2) { + for (Nd4jLong i3 = 0; i3 < tadShape[3]; ++i3) { + for (Nd4jLong i4 = 0; i4 < tadShape[4]; ++i4) { + const auto xTadOffset = + i0 * xTadStride[0] + i1 * xTadStride[1] + + i2 * xTadStride[2] + i3 * xTadStride[3] + + i4 * xTadStride[4]; + const auto yTadOffset = + i0 * yTadStride[0] + i1 * yTadStride[1] + + i2 * yTadStride[2] + i3 * yTadStride[3] + + i4 * yTadStride[4]; + s = OpType::update(s, + OpType::op(xTad[xTadOffset], + yTad[yTadOffset], extraParams), + extraParams); } - z[i * zEws] = OpType::postProcess(s, tadLen, extraParams); - }; - } - break; - - //*********************************************// - default: { - uint castXTadShapeInfo[MAX_RANK]; - const bool canCastXTad = sd::DataTypeUtils::castShapeInfo(xTadShapeInfo, castXTadShapeInfo); - - if (shape::haveSameShapeAndStrides(xTadShapeInfo, yTadShapeInfo)) { - Z extraParams[3]; - for (auto i = start; i < stop; i++) { - extraParams[0] = param0; - extraParams[1] = param1; - extraParams[2] = param2; - - const auto xTad = xTadOffsets ? x + xTadOffsets[i] : x; - const auto yTad = yTadOffsets ? y + yTadOffsets[i] : y; - auto s = OpType::startingValue(xTad); - - for (Nd4jLong j = 0; j < tadLen; ++j) { - const auto tadOffset = shape::indexOffset(j, xTadShapeInfo, castXTadShapeInfo, canCastXTad); - s = OpType::update(s, OpType::op(xTad[tadOffset], yTad[tadOffset], extraParams), extraParams); - } - - z[i * zEws] = OpType::postProcess(s, tadLen, extraParams); - }; + } } - else { - uint castYTadShapeInfo[MAX_RANK]; - const bool canCastYTad = sd::DataTypeUtils::castShapeInfo(yTadShapeInfo, castYTadShapeInfo); - - Z extraParams[3]; - for (auto i = start; i < stop; i++) { - extraParams[0] = param0; - extraParams[1] = param1; - extraParams[2] = param2; - - const auto xTad = xTadOffsets ? x + xTadOffsets[i] : x; - const auto yTad = yTadOffsets ? y + yTadOffsets[i] : y; - auto s = OpType::startingValue(xTad); - - for (Nd4jLong j = 0; j < tadLen; ++j) { - const auto xTadOffset = shape::indexOffset(j, xTadShapeInfo, castXTadShapeInfo, canCastXTad); - const auto yTadOffset = shape::indexOffset(j, yTadShapeInfo, castYTadShapeInfo, canCastYTad); - s = OpType::update(s, OpType::op(xTad[xTadOffset], yTad[yTadOffset], extraParams), extraParams); - } - z[i * zEws] = OpType::postProcess(s, tadLen, extraParams); - }; - } - } + } } + z[i * zEws] = OpType::postProcess(s, tadLen, extraParams); + }; + } break; + + //*********************************************// + default: { + uint castXTadShapeInfo[MAX_RANK]; + const bool canCastXTad = sd::DataTypeUtils::castShapeInfo( + xTadShapeInfo, castXTadShapeInfo); + + if (shape::haveSameShapeAndStrides(xTadShapeInfo, yTadShapeInfo)) { + Z extraParams[3]; + for (auto i = start; i < stop; i++) { + extraParams[0] = param0; + extraParams[1] = param1; + extraParams[2] = param2; + + const auto xTad = xTadOffsets ? x + xTadOffsets[i] : x; + const auto yTad = yTadOffsets ? y + yTadOffsets[i] : y; + auto s = OpType::startingValue(xTad); + + for (Nd4jLong j = 0; j < tadLen; ++j) { + const auto tadOffset = shape::indexOffset( + j, xTadShapeInfo, castXTadShapeInfo, canCastXTad); + s = OpType::update( + s, OpType::op(xTad[tadOffset], yTad[tadOffset], extraParams), + extraParams); + } + + z[i * zEws] = OpType::postProcess(s, tadLen, extraParams); + }; + } else { + uint castYTadShapeInfo[MAX_RANK]; + const bool canCastYTad = sd::DataTypeUtils::castShapeInfo( + yTadShapeInfo, castYTadShapeInfo); + + Z extraParams[3]; + for (auto i = start; i < stop; i++) { + extraParams[0] = param0; + extraParams[1] = param1; + extraParams[2] = param2; + + const auto xTad = xTadOffsets ? x + xTadOffsets[i] : x; + const auto yTad = yTadOffsets ? y + yTadOffsets[i] : y; + auto s = OpType::startingValue(xTad); + + for (Nd4jLong j = 0; j < tadLen; ++j) { + const auto xTadOffset = shape::indexOffset( + j, xTadShapeInfo, castXTadShapeInfo, canCastXTad); + const auto yTadOffset = shape::indexOffset( + j, yTadShapeInfo, castYTadShapeInfo, canCastYTad); + s = OpType::update( + s, OpType::op(xTad[xTadOffset], yTad[yTadOffset], extraParams), + extraParams); + } + z[i * zEws] = OpType::postProcess(s, tadLen, extraParams); + }; + } } + } +} - ////////////////////////////////////////////////////////////////////////////// - template - template - void sd::Reduction3Loops::loopReduce3All(const X* x, const Nd4jLong* xShapeInfo, - const X* y, const Nd4jLong* yShapeInfo, - Z* z, const Nd4jLong* zShapeInfo, - const Nd4jLong* xTadShapeInfo, const Nd4jLong* xTadOffsets, - const Nd4jLong* yTadShapeInfo, const Nd4jLong* yTadOffsets, - Z* extraParameters, - int64_t start, int64_t stop) { - - // both tads have same shape, however strides and ews may differ - - Z param0(OpType::startingValue(x)), param1(OpType::startingValue(x)), param2(extraParameters ? extraParameters[0] : OpType::startingValue(x)); - - const LoopKind::Kind kindOfLoop = LoopKind::deduceKindOfLoopTadXYZ(xTadShapeInfo, yTadShapeInfo, zShapeInfo); - - const auto xTadEws = shape::elementWiseStride(xTadShapeInfo); - const auto yTadEws = shape::elementWiseStride(yTadShapeInfo); - const auto zEws = shape::elementWiseStride(zShapeInfo); - - const auto zLen = shape::length(zShapeInfo); - const auto tadLen = shape::length(xTadShapeInfo); - - const auto numXTads = shape::length(xShapeInfo) / tadLen; - const auto numYTads = shape::length(yShapeInfo) / tadLen; - - const auto tadShape = shape::shapeOf(xTadShapeInfo); - const auto xTadStride = shape::stride(xTadShapeInfo); - const auto yTadStride = shape::stride(yTadShapeInfo); - - const auto startVal = OpType::startingValue(x); - - int numThreads = OmpLaunchHelper::tadThreads(tadLen, numXTads * numYTads); - - switch (kindOfLoop) { - //*********************************************// - case LoopKind::EWS1: { - Z extraParams[3]; - for (Nd4jLong ix = 0; ix < numXTads; ix++) { - for (Nd4jLong iy = 0; iy < numYTads; iy++) { - extraParams[0] = param0; - extraParams[1] = param1; - extraParams[2] = param2; - - const auto xTad = x + xTadOffsets[ix]; - const auto yTad = y + yTadOffsets[iy]; - const auto zInd = ix * numYTads + iy; - auto s = startVal; - - for (Nd4jLong j = 0; j < tadLen; ++j) - s = OpType::update(s, OpType::op(xTad[j], yTad[j], extraParams), extraParams); - - z[zInd] = OpType::postProcess(s, tadLen, extraParams); - } - }; +////////////////////////////////////////////////////////////////////////////// +template +template +void sd::Reduction3Loops::loopReduce3All( + const X* x, const Nd4jLong* xShapeInfo, const X* y, + const Nd4jLong* yShapeInfo, Z* z, const Nd4jLong* zShapeInfo, + const Nd4jLong* xTadShapeInfo, const Nd4jLong* xTadOffsets, + const Nd4jLong* yTadShapeInfo, const Nd4jLong* yTadOffsets, + Z* extraParameters, int64_t start, int64_t stop) { + // both tads have same shape, however strides and ews may differ + + Z param0(OpType::startingValue(x)), param1(OpType::startingValue(x)), + param2(extraParameters ? extraParameters[0] : OpType::startingValue(x)); + + const LoopKind::Kind kindOfLoop = LoopKind::deduceKindOfLoopTadXYZ( + xTadShapeInfo, yTadShapeInfo, zShapeInfo); + + const auto xTadEws = shape::elementWiseStride(xTadShapeInfo); + const auto yTadEws = shape::elementWiseStride(yTadShapeInfo); + const auto zEws = shape::elementWiseStride(zShapeInfo); + + const auto zLen = shape::length(zShapeInfo); + const auto tadLen = shape::length(xTadShapeInfo); + + const auto numXTads = shape::length(xShapeInfo) / tadLen; + const auto numYTads = shape::length(yShapeInfo) / tadLen; + + const auto tadShape = shape::shapeOf(xTadShapeInfo); + const auto xTadStride = shape::stride(xTadShapeInfo); + const auto yTadStride = shape::stride(yTadShapeInfo); + + const auto startVal = OpType::startingValue(x); + + int numThreads = OmpLaunchHelper::tadThreads(tadLen, numXTads * numYTads); + + switch (kindOfLoop) { + //*********************************************// + case LoopKind::EWS1: { + Z extraParams[3]; + for (Nd4jLong ix = 0; ix < numXTads; ix++) { + for (Nd4jLong iy = 0; iy < numYTads; iy++) { + extraParams[0] = param0; + extraParams[1] = param1; + extraParams[2] = param2; + + const auto xTad = x + xTadOffsets[ix]; + const auto yTad = y + yTadOffsets[iy]; + const auto zInd = ix * numYTads + iy; + auto s = startVal; + + for (Nd4jLong j = 0; j < tadLen; ++j) + s = OpType::update(s, OpType::op(xTad[j], yTad[j], extraParams), + extraParams); + + z[zInd] = OpType::postProcess(s, tadLen, extraParams); } - break; - - //*********************************************// - case LoopKind::EWSNONZERO: { - Z extraParams[3]; - for (Nd4jLong ix = 0; ix < numXTads; ix++) { - for (Nd4jLong iy = 0; iy < numYTads; iy++) { - extraParams[0] = param0; - extraParams[1] = param1; - extraParams[2] = param2; - - const auto xTad = x + xTadOffsets[ix]; - const auto yTad = y + yTadOffsets[iy]; - const auto zInd = ix * numYTads + iy; - auto s = startVal; - - for (Nd4jLong j = 0; j < tadLen; ++j) - s = OpType::update(s, OpType::op(xTad[j * xTadEws], yTad[j * yTadEws], extraParams), extraParams); - - z[zInd * zEws] = OpType::postProcess(s, tadLen, extraParams); - } - }; + }; + } break; + + //*********************************************// + case LoopKind::EWSNONZERO: { + Z extraParams[3]; + for (Nd4jLong ix = 0; ix < numXTads; ix++) { + for (Nd4jLong iy = 0; iy < numYTads; iy++) { + extraParams[0] = param0; + extraParams[1] = param1; + extraParams[2] = param2; + + const auto xTad = x + xTadOffsets[ix]; + const auto yTad = y + yTadOffsets[iy]; + const auto zInd = ix * numYTads + iy; + auto s = startVal; + + for (Nd4jLong j = 0; j < tadLen; ++j) + s = OpType::update( + s, + OpType::op(xTad[j * xTadEws], yTad[j * yTadEws], extraParams), + extraParams); + + z[zInd * zEws] = OpType::postProcess(s, tadLen, extraParams); } - break; - - //*********************************************// - case LoopKind::RANK1: { - Z extraParams[3]; - for (Nd4jLong ix = 0; ix < numXTads; ix++) { - for (Nd4jLong iy = 0; iy < numYTads; iy++) { - extraParams[0] = param0; - extraParams[1] = param1; - extraParams[2] = param2; - - const auto xTad = x + xTadOffsets[ix]; - const auto yTad = y + yTadOffsets[iy]; - const auto zInd = ix * numYTads + iy; - auto s = startVal; - - for (Nd4jLong i0 = 0; i0 < tadLen; ++i0) { - const auto xTadOffset = i0 * xTadStride[0]; - const auto yTadOffset = i0 * yTadStride[0]; - s = OpType::update(s, OpType::op(xTad[xTadOffset], yTad[yTadOffset], extraParams), extraParams); - } - z[zInd * zEws] = OpType::postProcess(s, tadLen, extraParams); - } - }; + }; + } break; + + //*********************************************// + case LoopKind::RANK1: { + Z extraParams[3]; + for (Nd4jLong ix = 0; ix < numXTads; ix++) { + for (Nd4jLong iy = 0; iy < numYTads; iy++) { + extraParams[0] = param0; + extraParams[1] = param1; + extraParams[2] = param2; + + const auto xTad = x + xTadOffsets[ix]; + const auto yTad = y + yTadOffsets[iy]; + const auto zInd = ix * numYTads + iy; + auto s = startVal; + + for (Nd4jLong i0 = 0; i0 < tadLen; ++i0) { + const auto xTadOffset = i0 * xTadStride[0]; + const auto yTadOffset = i0 * yTadStride[0]; + s = OpType::update( + s, OpType::op(xTad[xTadOffset], yTad[yTadOffset], extraParams), + extraParams); + } + z[zInd * zEws] = OpType::postProcess(s, tadLen, extraParams); } - break; - - //*********************************************// - case LoopKind::RANK2: { - Z extraParams[3]; - for (Nd4jLong ix = 0; ix < numXTads; ix++) { - for (Nd4jLong iy = 0; iy < numYTads; iy++) { - extraParams[0] = param0; - extraParams[1] = param1; - extraParams[2] = param2; - - const auto xTad = x + xTadOffsets[ix]; - const auto yTad = y + yTadOffsets[iy]; - const auto zInd = ix * numYTads + iy; - auto s = startVal; - - for (Nd4jLong i0 = 0; i0 < tadShape[0]; ++i0) { - for (Nd4jLong i1 = 0; i1 < tadShape[1]; ++i1) { - const auto xTadOffset = i0 * xTadStride[0] + i1 * xTadStride[1]; - const auto yTadOffset = i0 * yTadStride[0] + i1 * yTadStride[1]; - s = OpType::update(s, OpType::op(xTad[xTadOffset], yTad[yTadOffset], extraParams), extraParams); - } - } - z[zInd * zEws] = OpType::postProcess(s, tadLen, extraParams); - } - }; + }; + } break; + + //*********************************************// + case LoopKind::RANK2: { + Z extraParams[3]; + for (Nd4jLong ix = 0; ix < numXTads; ix++) { + for (Nd4jLong iy = 0; iy < numYTads; iy++) { + extraParams[0] = param0; + extraParams[1] = param1; + extraParams[2] = param2; + + const auto xTad = x + xTadOffsets[ix]; + const auto yTad = y + yTadOffsets[iy]; + const auto zInd = ix * numYTads + iy; + auto s = startVal; + + for (Nd4jLong i0 = 0; i0 < tadShape[0]; ++i0) { + for (Nd4jLong i1 = 0; i1 < tadShape[1]; ++i1) { + const auto xTadOffset = i0 * xTadStride[0] + i1 * xTadStride[1]; + const auto yTadOffset = i0 * yTadStride[0] + i1 * yTadStride[1]; + s = OpType::update( + s, + OpType::op(xTad[xTadOffset], yTad[yTadOffset], extraParams), + extraParams); + } + } + z[zInd * zEws] = OpType::postProcess(s, tadLen, extraParams); } - break; - - //*********************************************// - case LoopKind::RANK3: { - Z extraParams[3]; - for (Nd4jLong ix = 0; ix < numXTads; ix++) { - for (Nd4jLong iy = 0; iy < numYTads; iy++) { - extraParams[0] = param0; - extraParams[1] = param1; - extraParams[2] = param2; - - const auto xTad = x + xTadOffsets[ix]; - const auto yTad = y + yTadOffsets[iy]; - const auto zInd = ix * numYTads + iy; - auto s = startVal; - - for (Nd4jLong i0 = 0; i0 < tadShape[0]; ++i0) { - for (Nd4jLong i1 = 0; i1 < tadShape[1]; ++i1) { - for (Nd4jLong i2 = 0; i2 < tadShape[2]; ++i2) { - const auto xTadOffset = i0 * xTadStride[0] + i1 * xTadStride[1] + i2 * xTadStride[2]; - const auto yTadOffset = i0 * yTadStride[0] + i1 * yTadStride[1] + i2 * yTadStride[2]; - s = OpType::update(s, OpType::op(xTad[xTadOffset], yTad[yTadOffset], extraParams), extraParams); - } - } - } - z[zInd * zEws] = OpType::postProcess(s, tadLen, extraParams); - } - }; + }; + } break; + + //*********************************************// + case LoopKind::RANK3: { + Z extraParams[3]; + for (Nd4jLong ix = 0; ix < numXTads; ix++) { + for (Nd4jLong iy = 0; iy < numYTads; iy++) { + extraParams[0] = param0; + extraParams[1] = param1; + extraParams[2] = param2; + + const auto xTad = x + xTadOffsets[ix]; + const auto yTad = y + yTadOffsets[iy]; + const auto zInd = ix * numYTads + iy; + auto s = startVal; + + for (Nd4jLong i0 = 0; i0 < tadShape[0]; ++i0) { + for (Nd4jLong i1 = 0; i1 < tadShape[1]; ++i1) { + for (Nd4jLong i2 = 0; i2 < tadShape[2]; ++i2) { + const auto xTadOffset = i0 * xTadStride[0] + + i1 * xTadStride[1] + i2 * xTadStride[2]; + const auto yTadOffset = i0 * yTadStride[0] + + i1 * yTadStride[1] + i2 * yTadStride[2]; + s = OpType::update( + s, + OpType::op(xTad[xTadOffset], yTad[yTadOffset], extraParams), + extraParams); + } + } + } + z[zInd * zEws] = OpType::postProcess(s, tadLen, extraParams); } - break; - - //*********************************************// - case LoopKind::RANK4: { - Z extraParams[3]; - for (Nd4jLong ix = 0; ix < numXTads; ix++) { - for (Nd4jLong iy = 0; iy < numYTads; iy++) { - extraParams[0] = param0; - extraParams[1] = param1; - extraParams[2] = param2; - - const auto xTad = x + xTadOffsets[ix]; - const auto yTad = y + yTadOffsets[iy]; - const auto zInd = ix * numYTads + iy; - auto s = startVal; - - for (Nd4jLong i0 = 0; i0 < tadShape[0]; ++i0) { - for (Nd4jLong i1 = 0; i1 < tadShape[1]; ++i1) { - for (Nd4jLong i2 = 0; i2 < tadShape[2]; ++i2) { - for (Nd4jLong i3 = 0; i3 < tadShape[3]; ++i3) { - const auto xTadOffset = i0 * xTadStride[0] + i1 * xTadStride[1] + i2 * xTadStride[2] + i3 * xTadStride[3]; - const auto yTadOffset = i0 * yTadStride[0] + i1 * yTadStride[1] + i2 * yTadStride[2] + i3 * yTadStride[3]; - s = OpType::update(s, OpType::op(xTad[xTadOffset], yTad[yTadOffset], extraParams), extraParams); - } - } - } - } - z[zInd * zEws] = OpType::postProcess(s, tadLen, extraParams); + }; + } break; + + //*********************************************// + case LoopKind::RANK4: { + Z extraParams[3]; + for (Nd4jLong ix = 0; ix < numXTads; ix++) { + for (Nd4jLong iy = 0; iy < numYTads; iy++) { + extraParams[0] = param0; + extraParams[1] = param1; + extraParams[2] = param2; + + const auto xTad = x + xTadOffsets[ix]; + const auto yTad = y + yTadOffsets[iy]; + const auto zInd = ix * numYTads + iy; + auto s = startVal; + + for (Nd4jLong i0 = 0; i0 < tadShape[0]; ++i0) { + for (Nd4jLong i1 = 0; i1 < tadShape[1]; ++i1) { + for (Nd4jLong i2 = 0; i2 < tadShape[2]; ++i2) { + for (Nd4jLong i3 = 0; i3 < tadShape[3]; ++i3) { + const auto xTadOffset = + i0 * xTadStride[0] + i1 * xTadStride[1] + + i2 * xTadStride[2] + i3 * xTadStride[3]; + const auto yTadOffset = + i0 * yTadStride[0] + i1 * yTadStride[1] + + i2 * yTadStride[2] + i3 * yTadStride[3]; + s = OpType::update(s, + OpType::op(xTad[xTadOffset], + yTad[yTadOffset], extraParams), + extraParams); } - }; + } + } + } + z[zInd * zEws] = OpType::postProcess(s, tadLen, extraParams); } - break; - - //*********************************************// - case LoopKind::RANK5: { - Z extraParams[3]; - for (Nd4jLong ix = 0; ix < numXTads; ix++) { - for (Nd4jLong iy = 0; iy < numYTads; iy++) { - extraParams[0] = param0; - extraParams[1] = param1; - extraParams[2] = param2; - - const auto xTad = x + xTadOffsets[ix]; - const auto yTad = y + yTadOffsets[iy]; - const auto zInd = ix * numYTads + iy; - auto s = startVal; - - for (Nd4jLong i0 = 0; i0 < tadShape[0]; ++i0) { - for (Nd4jLong i1 = 0; i1 < tadShape[1]; ++i1) { - for (Nd4jLong i2 = 0; i2 < tadShape[2]; ++i2) { - for (Nd4jLong i3 = 0; i3 < tadShape[3]; ++i3) { - for (Nd4jLong i4 = 0; i4 < tadShape[4]; ++i4) { - const auto xTadOffset = i0 * xTadStride[0] + i1 * xTadStride[1] + i2 * xTadStride[2] + i3 * xTadStride[3] + i4 * xTadStride[4]; - const auto yTadOffset = i0 * yTadStride[0] + i1 * yTadStride[1] + i2 * yTadStride[2] + i3 * yTadStride[3] + i4 * yTadStride[4]; - s = OpType::update(s, OpType::op(xTad[xTadOffset], yTad[yTadOffset], extraParams), extraParams); - } - } - } - } - } - z[zInd * zEws] = OpType::postProcess(start, tadLen, extraParams); + }; + } break; + + //*********************************************// + case LoopKind::RANK5: { + Z extraParams[3]; + for (Nd4jLong ix = 0; ix < numXTads; ix++) { + for (Nd4jLong iy = 0; iy < numYTads; iy++) { + extraParams[0] = param0; + extraParams[1] = param1; + extraParams[2] = param2; + + const auto xTad = x + xTadOffsets[ix]; + const auto yTad = y + yTadOffsets[iy]; + const auto zInd = ix * numYTads + iy; + auto s = startVal; + + for (Nd4jLong i0 = 0; i0 < tadShape[0]; ++i0) { + for (Nd4jLong i1 = 0; i1 < tadShape[1]; ++i1) { + for (Nd4jLong i2 = 0; i2 < tadShape[2]; ++i2) { + for (Nd4jLong i3 = 0; i3 < tadShape[3]; ++i3) { + for (Nd4jLong i4 = 0; i4 < tadShape[4]; ++i4) { + const auto xTadOffset = + i0 * xTadStride[0] + i1 * xTadStride[1] + + i2 * xTadStride[2] + i3 * xTadStride[3] + + i4 * xTadStride[4]; + const auto yTadOffset = + i0 * yTadStride[0] + i1 * yTadStride[1] + + i2 * yTadStride[2] + i3 * yTadStride[3] + + i4 * yTadStride[4]; + s = OpType::update( + s, + OpType::op(xTad[xTadOffset], yTad[yTadOffset], + extraParams), + extraParams); + } } - }; + } + } + } + z[zInd * zEws] = OpType::postProcess(start, tadLen, extraParams); } - break; - - //*********************************************// - default: { - uint castXTadShapeInfo[MAX_RANK]; - const bool canCastXTad = sd::DataTypeUtils::castShapeInfo(xTadShapeInfo, castXTadShapeInfo); - - if (shape::haveSameShapeAndStrides(xTadShapeInfo, yTadShapeInfo)) { - Z extraParams[3]; - for (Nd4jLong ix = 0; ix < numXTads; ix++) { - for (Nd4jLong iy = 0; iy < numYTads; iy++) { - extraParams[0] = param0; - extraParams[1] = param1; - extraParams[2] = param2; - - const auto xTad = x + xTadOffsets[ix]; - const auto yTad = y + yTadOffsets[iy]; - const auto zInd = ix * numYTads + iy; - auto s = startVal; - - for (Nd4jLong j = 0; j < tadLen; ++j) { - const auto tadOffset = shape::indexOffset(j, xTadShapeInfo, castXTadShapeInfo, canCastXTad); - s = OpType::update(s, OpType::op(xTad[tadOffset], yTad[tadOffset], extraParams), extraParams); - } - z[zInd * zEws] = OpType::postProcess(s, tadLen, extraParams); - } - }; + }; + } break; + + //*********************************************// + default: { + uint castXTadShapeInfo[MAX_RANK]; + const bool canCastXTad = sd::DataTypeUtils::castShapeInfo( + xTadShapeInfo, castXTadShapeInfo); + + if (shape::haveSameShapeAndStrides(xTadShapeInfo, yTadShapeInfo)) { + Z extraParams[3]; + for (Nd4jLong ix = 0; ix < numXTads; ix++) { + for (Nd4jLong iy = 0; iy < numYTads; iy++) { + extraParams[0] = param0; + extraParams[1] = param1; + extraParams[2] = param2; + + const auto xTad = x + xTadOffsets[ix]; + const auto yTad = y + yTadOffsets[iy]; + const auto zInd = ix * numYTads + iy; + auto s = startVal; + + for (Nd4jLong j = 0; j < tadLen; ++j) { + const auto tadOffset = shape::indexOffset( + j, xTadShapeInfo, castXTadShapeInfo, canCastXTad); + s = OpType::update( + s, OpType::op(xTad[tadOffset], yTad[tadOffset], extraParams), + extraParams); } - else { - uint castYTadShapeInfo[MAX_RANK]; - const bool canCastYTad = sd::DataTypeUtils::castShapeInfo(yTadShapeInfo, castYTadShapeInfo); - - Z extraParams[3]; - for (Nd4jLong ix = 0; ix < numXTads; ix++) { - for (Nd4jLong iy = 0; iy < numYTads; iy++) { - extraParams[0] = param0; - extraParams[1] = param1; - extraParams[2] = param2; - - const auto xTad = x + xTadOffsets[ix]; - const auto yTad = y + yTadOffsets[iy]; - const auto zInd = ix * numYTads + iy; - auto s = startVal; - - for (Nd4jLong j = 0; j < tadLen; ++j) { - const auto xTadOffset = shape::indexOffset(j, xTadShapeInfo, castXTadShapeInfo, canCastXTad); - const auto yTadOffset = shape::indexOffset(j, yTadShapeInfo, castYTadShapeInfo, canCastYTad); - s = OpType::update(s, OpType::op(xTad[xTadOffset], yTad[yTadOffset], extraParams), extraParams); - } - - z[zInd * zEws] = OpType::postProcess(s, tadLen, extraParams); - } - }; + z[zInd * zEws] = OpType::postProcess(s, tadLen, extraParams); + } + }; + } else { + uint castYTadShapeInfo[MAX_RANK]; + const bool canCastYTad = sd::DataTypeUtils::castShapeInfo( + yTadShapeInfo, castYTadShapeInfo); + + Z extraParams[3]; + for (Nd4jLong ix = 0; ix < numXTads; ix++) { + for (Nd4jLong iy = 0; iy < numYTads; iy++) { + extraParams[0] = param0; + extraParams[1] = param1; + extraParams[2] = param2; + + const auto xTad = x + xTadOffsets[ix]; + const auto yTad = y + yTadOffsets[iy]; + const auto zInd = ix * numYTads + iy; + auto s = startVal; + + for (Nd4jLong j = 0; j < tadLen; ++j) { + const auto xTadOffset = shape::indexOffset( + j, xTadShapeInfo, castXTadShapeInfo, canCastXTad); + const auto yTadOffset = shape::indexOffset( + j, yTadShapeInfo, castYTadShapeInfo, canCastYTad); + s = OpType::update( + s, + OpType::op(xTad[xTadOffset], yTad[yTadOffset], extraParams), + extraParams); } - } - } - } - - + z[zInd * zEws] = OpType::postProcess(s, tadLen, extraParams); + } + }; + } + } + } } +} // namespace sd -#endif //LIBND4J_LOOPS_H +#endif // LIBND4J_LOOPS_H diff --git a/libnd4j/include/helpers/Loops.hpp b/libnd4j/include/helpers/Loops.hpp index 852ef4808c3d..ee849b654a44 100644 --- a/libnd4j/include/helpers/Loops.hpp +++ b/libnd4j/include/helpers/Loops.hpp @@ -22,17 +22,22 @@ //#define LIBND4J_LOOPS_CPP #include -#include -#include #include +#include +#include - -namespace sd { - - -} -//template void Loops::loopReduce(const double* x, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, double* z, const Nd4jLong* zShapeInfo, double* extraParams, std::function startVal, std::function update, std::function op, std::function postPr); -//template void Loops::loopReduce(const float* x, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, float* z, const Nd4jLong* zShapeInfo, float* extraParams, std::function startVal, std::function update, std::function op, std::function postPr); +namespace sd {} +// template void Loops::loopReduce(const double* x, const +// Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, double* z, const +// Nd4jLong* zShapeInfo, double* extraParams, std::function startVal, std::function update, +// std::function op, +// std::function postPr); template void +// Loops::loopReduce(const float* x, const Nd4jLong* tadShapeInfo, +// const Nd4jLong* tadOffsets, float* z, const Nd4jLong* zShapeInfo, float* +// extraParams, std::function startVal, +// std::function update, +// std::function op, +// std::function postPr); //#endif // LIBND4J_LOOPS_CPP - diff --git a/libnd4j/include/helpers/LoopsCoordsHelper.h b/libnd4j/include/helpers/LoopsCoordsHelper.h index 8a1160aea78a..ac1fdcdfb4f6 100644 --- a/libnd4j/include/helpers/LoopsCoordsHelper.h +++ b/libnd4j/include/helpers/LoopsCoordsHelper.h @@ -14,9 +14,9 @@ * * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - // - // @author AbdelRauf - // +// +// @author AbdelRauf +// #ifndef LIBND4J_LOOPCOORDSHELPER_H #define LIBND4J_LOOPCOORDSHELPER_H #include @@ -29,412 +29,399 @@ namespace sd { #if defined(__GNUC__) -#define likely(x) __builtin_expect( (x), 1) -#define unlikely(x) __builtin_expect( (x), 0) +#define likely(x) __builtin_expect((x), 1) +#define unlikely(x) __builtin_expect((x), 0) #else -#define likely(x) (x) -#define unlikely(x) (x) +#define likely(x) (x) +#define unlikely(x) (x) #endif - using zip_size_t = std::pair; - - template - struct CoordsState :CoordsState { - Nd4jLong coord; - Nd4jLong last_num; - Nd4jLong stride; - Nd4jLong adjust; - CoordsState() :CoordsState() {} - }; - - template<> - struct CoordsState<0> { - Nd4jLong coord; - Nd4jLong last_num; - Nd4jLong stride; - Nd4jLong adjust; - CoordsState() {} - }; - - - template - struct ZipCoordsState :ZipCoordsState { - Nd4jLong coord; - Nd4jLong last_num; - Nd4jLong stride1; - Nd4jLong stride2; - Nd4jLong adjust1; - Nd4jLong adjust2; - ZipCoordsState() : ZipCoordsState() {} - }; - - template<> - struct ZipCoordsState<0> { - Nd4jLong coord; - Nd4jLong last_num; - Nd4jLong stride1; - Nd4jLong stride2; - Nd4jLong adjust1; - Nd4jLong adjust2; - ZipCoordsState() {} - }; - -#define COORDS(x,index) ((x).::sd::CoordsState<(index)>::coord) -#define STRIDE(x,index) ((x).::sd::CoordsState<(index)>::stride) -#define LAST_NUM(x,index) ((x).::sd::CoordsState<(index)>::last_num) -#define OF_ADJUST(x,index) ((x).::sd::CoordsState<(index)>::adjust) -#define ZIP_LAST_NUM(x,index) ((x).::sd::ZipCoordsState<(index)>::last_num) -#define ZIP_COORDS(x,index) ((x).::sd::ZipCoordsState<(index)>::coord) -#define ZIP_STRIDE1(x,index) ((x).::sd::ZipCoordsState<(index)>::stride1) -#define ZIP_STRIDE2(x,index) ((x).::sd::ZipCoordsState<(index)>::stride2) -#define ZIP_OF_ADJUST1(x,index) ((x).::sd::ZipCoordsState<(index)>::adjust1) -#define ZIP_OF_ADJUST2(x,index) ((x).::sd::ZipCoordsState<(index)>::adjust2) - - - FORCEINLINE void index2coords_C(Nd4jLong index, const Nd4jLong rank, const Nd4jLong* bases, Nd4jLong* coords) { - for (size_t i = rank - 1; i > 0; --i) { - coords[i] = index % bases[i]; - index /= bases[i]; - } - coords[0] = index; // last iteration - } - - FORCEINLINE void index2coords_F(Nd4jLong index, const Nd4jLong rank, const Nd4jLong* bases, Nd4jLong* coords) { - - for (size_t i = 0; i < rank - 1; i++) { - coords[i] = index % bases[i]; - index /= bases[i]; - } - coords[rank - 1] = index; // last iteration - } - - FORCEINLINE size_t offset_from_coords(const Nd4jLong* strides, const Nd4jLong* coords, const Nd4jLong& rank) { - - size_t offset = 0; - size_t rank_4 = rank & -4; - for (int i = 0; i < rank_4; i += 4) { - offset = offset - + coords[i] * strides[i] - + coords[i + 1] * strides[i + 1] - + coords[i + 2] * strides[i + 2] - + coords[i + 3] * strides[i + 3]; - } - for (int i = rank_4; i < rank; i++) { - offset += coords[i] * strides[i]; - } - return offset; - } - - - FORCEINLINE zip_size_t offset_from_coords(const Nd4jLong* x_strides, const Nd4jLong* z_strides, const Nd4jLong* coords, const Nd4jLong& rank) { - - zip_size_t offset = { 0,0 }; - size_t rank_4 = rank & -4; - for (int i = 0; i < rank_4; i += 4) { - offset.first = offset.first - + coords[i] * x_strides[i] - + coords[i + 1] * x_strides[i + 1] - + coords[i + 2] * x_strides[i + 2] - + coords[i + 3] * x_strides[i + 3]; - offset.second = offset.second - + coords[i] * z_strides[i] - + coords[i + 1] * z_strides[i + 1] - + coords[i + 2] * z_strides[i + 2] - + coords[i + 3] * z_strides[i + 3]; - } - for (int i = rank_4; i < rank; i++) { - offset.first += coords[i] * x_strides[i]; - offset.second += coords[i] * z_strides[i]; - } - return offset; - } - - template - constexpr size_t StridesOrderInd() { - return Last_Index_Faster ? Rank - Index - 1 : Index; - } - - template - FORCEINLINE - typename std::enable_if<(Rank - 1 == Index), size_t>::type - coord_inc_n(CoordsState& cbs, size_t last_offset) { - - constexpr size_t Ind = StridesOrderInd(); - - if (likely(COORDS(cbs, Ind) < LAST_NUM(cbs, Ind))) { - last_offset += cbs.CoordsState::stride; - COORDS(cbs, Ind) = COORDS(cbs, Ind) + 1; - return last_offset; - } - //overflow case should not happen - COORDS(cbs, Ind) = 0; - //last_offset = 0;// last_offset + strides[Ind] - adjust_stride; - return 0; - } - - template - FORCEINLINE - typename std::enable_if<(Rank - 1 != Index), size_t >::type - coord_inc_n(CoordsState& cbs, size_t last_offset) { - - constexpr size_t Ind = StridesOrderInd(); - - if (likely(COORDS(cbs, Ind) < LAST_NUM(cbs, Ind))) { - last_offset = last_offset + cbs.CoordsState::stride; - COORDS(cbs, Ind) = COORDS(cbs, Ind) + 1; - } - else { - //lets adjust offset - last_offset -= OF_ADJUST(cbs, Ind); - COORDS(cbs, Ind) = 0; - last_offset = coord_inc_n(cbs, last_offset); - } - - return last_offset; - - } - - template - FORCEINLINE size_t inc_coords(CoordsState& cbs, size_t last_offset) { - - return coord_inc_n(cbs,/* 1,*/ last_offset/*, 0*/); - } - - template - FORCEINLINE size_t inc_coords_ews(CoordsState& cbs, size_t last_offset, size_t ews) { - if (ews == 1) { - constexpr size_t Ind = StridesOrderInd(); - return last_offset + STRIDE(cbs, Ind); - } - return coord_inc_n(cbs,/* 1,*/ last_offset/*, 0*/); - } - - template - FORCEINLINE - typename std::enable_if<(Rank - 1 == rankIndex), zip_size_t>::type - coord_inc_n(ZipCoordsState& cbs, zip_size_t last_offset) { - - constexpr size_t Ind = StridesOrderInd(); - - if (likely(ZIP_COORDS(cbs, Ind) < ZIP_LAST_NUM(cbs, Ind))) { - last_offset.first += ZIP_STRIDE1(cbs, Ind); - last_offset.second += ZIP_STRIDE2(cbs, Ind); - ZIP_COORDS(cbs, Ind) = ZIP_COORDS(cbs, Ind) + 1; - return last_offset; - } - //overflow case should not happen - ZIP_COORDS(cbs, Ind) = 0; - //last_offset = 0;// last_offset + strides[Ind] - adjust_stride; - return { 0,0 }; - } - - template - FORCEINLINE - typename std::enable_if<(Rank - 1 != rankIndex), zip_size_t >::type - coord_inc_n(ZipCoordsState& cbs, zip_size_t last_offset) { - - constexpr size_t Ind = StridesOrderInd(); - - if (likely(ZIP_COORDS(cbs, Ind) < ZIP_LAST_NUM(cbs, Ind))) { - last_offset.first += ZIP_STRIDE1(cbs, Ind); - last_offset.second += ZIP_STRIDE2(cbs, Ind); - ZIP_COORDS(cbs, Ind) = ZIP_COORDS(cbs, Ind) + 1; - } - else { - - //lets adjust offset - last_offset.first -= ZIP_OF_ADJUST1(cbs, Ind); - last_offset.second -= ZIP_OF_ADJUST2(cbs, Ind); - ZIP_COORDS(cbs, Ind) = 0; - last_offset = coord_inc_n(cbs, last_offset); - } - - return last_offset; - - } - - template - FORCEINLINE zip_size_t inc_coords(ZipCoordsState& cbs, zip_size_t last_offset) { - - return coord_inc_n(cbs, last_offset); - } - - - template - FORCEINLINE - typename std::enable_if<(Rank - 1 == rankIndex), size_t>::type - init_coords(CoordsState& cbs, const Nd4jLong index, const Nd4jLong* bases, const Nd4jLong* strides, size_t offset = 0) { - constexpr size_t Ind = StridesOrderInd(); - COORDS(cbs, Ind) = index % bases[Ind]; - LAST_NUM(cbs, Ind) = bases[Ind] - 1; - STRIDE(cbs, Ind) = strides[Ind]; - OF_ADJUST(cbs, Ind) = bases[Ind] * strides[Ind] - strides[Ind]; - offset += COORDS(cbs, Ind) * strides[Ind]; - return offset; - } - +using zip_size_t = std::pair; + +template +struct CoordsState : CoordsState { + Nd4jLong coord; + Nd4jLong last_num; + Nd4jLong stride; + Nd4jLong adjust; + CoordsState() : CoordsState() {} +}; + +template <> +struct CoordsState<0> { + Nd4jLong coord; + Nd4jLong last_num; + Nd4jLong stride; + Nd4jLong adjust; + CoordsState() {} +}; + +template +struct ZipCoordsState : ZipCoordsState { + Nd4jLong coord; + Nd4jLong last_num; + Nd4jLong stride1; + Nd4jLong stride2; + Nd4jLong adjust1; + Nd4jLong adjust2; + ZipCoordsState() : ZipCoordsState() {} +}; + +template <> +struct ZipCoordsState<0> { + Nd4jLong coord; + Nd4jLong last_num; + Nd4jLong stride1; + Nd4jLong stride2; + Nd4jLong adjust1; + Nd4jLong adjust2; + ZipCoordsState() {} +}; + +#define COORDS(x, index) ((x).::sd::CoordsState<(index)>::coord) +#define STRIDE(x, index) ((x).::sd::CoordsState<(index)>::stride) +#define LAST_NUM(x, index) ((x).::sd::CoordsState<(index)>::last_num) +#define OF_ADJUST(x, index) ((x).::sd::CoordsState<(index)>::adjust) +#define ZIP_LAST_NUM(x, index) ((x).::sd::ZipCoordsState<(index)>::last_num) +#define ZIP_COORDS(x, index) ((x).::sd::ZipCoordsState<(index)>::coord) +#define ZIP_STRIDE1(x, index) ((x).::sd::ZipCoordsState<(index)>::stride1) +#define ZIP_STRIDE2(x, index) ((x).::sd::ZipCoordsState<(index)>::stride2) +#define ZIP_OF_ADJUST1(x, index) ((x).::sd::ZipCoordsState<(index)>::adjust1) +#define ZIP_OF_ADJUST2(x, index) ((x).::sd::ZipCoordsState<(index)>::adjust2) + +FORCEINLINE void index2coords_C(Nd4jLong index, const Nd4jLong rank, + const Nd4jLong* bases, Nd4jLong* coords) { + for (size_t i = rank - 1; i > 0; --i) { + coords[i] = index % bases[i]; + index /= bases[i]; + } + coords[0] = index; // last iteration +} +FORCEINLINE void index2coords_F(Nd4jLong index, const Nd4jLong rank, + const Nd4jLong* bases, Nd4jLong* coords) { + for (size_t i = 0; i < rank - 1; i++) { + coords[i] = index % bases[i]; + index /= bases[i]; + } + coords[rank - 1] = index; // last iteration +} - template - FORCEINLINE - typename std::enable_if<(Rank - 1 != rankIndex), size_t>::type - init_coords(CoordsState& cbs, const Nd4jLong index, const Nd4jLong* bases, const Nd4jLong* strides, size_t offset = 0) { - constexpr size_t Ind = StridesOrderInd(); - COORDS(cbs, Ind) = index % bases[Ind]; - LAST_NUM(cbs, Ind) = bases[Ind] - 1; - STRIDE(cbs, Ind) = strides[Ind]; - OF_ADJUST(cbs, Ind) = bases[Ind] * strides[Ind] - strides[Ind]; - offset += COORDS(cbs, Ind) * strides[Ind]; - return init_coords(cbs, index / bases[Ind], bases, strides, offset); - } +FORCEINLINE size_t offset_from_coords(const Nd4jLong* strides, + const Nd4jLong* coords, + const Nd4jLong& rank) { + size_t offset = 0; + size_t rank_4 = rank & -4; + for (int i = 0; i < rank_4; i += 4) { + offset = offset + coords[i] * strides[i] + coords[i + 1] * strides[i + 1] + + coords[i + 2] * strides[i + 2] + coords[i + 3] * strides[i + 3]; + } + for (int i = rank_4; i < rank; i++) { + offset += coords[i] * strides[i]; + } + return offset; +} +FORCEINLINE zip_size_t offset_from_coords(const Nd4jLong* x_strides, + const Nd4jLong* z_strides, + const Nd4jLong* coords, + const Nd4jLong& rank) { + zip_size_t offset = {0, 0}; + size_t rank_4 = rank & -4; + for (int i = 0; i < rank_4; i += 4) { + offset.first = offset.first + coords[i] * x_strides[i] + + coords[i + 1] * x_strides[i + 1] + + coords[i + 2] * x_strides[i + 2] + + coords[i + 3] * x_strides[i + 3]; + offset.second = offset.second + coords[i] * z_strides[i] + + coords[i + 1] * z_strides[i + 1] + + coords[i + 2] * z_strides[i + 2] + + coords[i + 3] * z_strides[i + 3]; + } + for (int i = rank_4; i < rank; i++) { + offset.first += coords[i] * x_strides[i]; + offset.second += coords[i] * z_strides[i]; + } + return offset; +} +template +constexpr size_t StridesOrderInd() { + return Last_Index_Faster ? Rank - Index - 1 : Index; +} +template +FORCEINLINE typename std::enable_if<(Rank - 1 == Index), size_t>::type +coord_inc_n(CoordsState& cbs, size_t last_offset) { + constexpr size_t Ind = StridesOrderInd(); + + if (likely(COORDS(cbs, Ind) < LAST_NUM(cbs, Ind))) { + last_offset += cbs.CoordsState::stride; + COORDS(cbs, Ind) = COORDS(cbs, Ind) + 1; + return last_offset; + } + // overflow case should not happen + COORDS(cbs, Ind) = 0; + // last_offset = 0;// last_offset + strides[Ind] - adjust_stride; + return 0; +} - template - FORCEINLINE - typename std::enable_if<(Rank - 1 == rankIndex), bool>::type - eq_coords(CoordsState& cbs, const Nd4jLong* coords) { - return COORDS(cbs, rankIndex) == coords[rankIndex]; - } +template +FORCEINLINE typename std::enable_if<(Rank - 1 != Index), size_t>::type +coord_inc_n(CoordsState& cbs, size_t last_offset) { + constexpr size_t Ind = StridesOrderInd(); + + if (likely(COORDS(cbs, Ind) < LAST_NUM(cbs, Ind))) { + last_offset = last_offset + cbs.CoordsState::stride; + COORDS(cbs, Ind) = COORDS(cbs, Ind) + 1; + } else { + // lets adjust offset + last_offset -= OF_ADJUST(cbs, Ind); + COORDS(cbs, Ind) = 0; + last_offset = + coord_inc_n(cbs, last_offset); + } + + return last_offset; +} - template - FORCEINLINE - typename std::enable_if<(Rank - 1 != rankIndex), bool>::type - eq_coords(CoordsState& cbs, const Nd4jLong* coords) { - return COORDS(cbs, rankIndex) == coords[rankIndex] && eq_coords(cbs, coords); - } +template +FORCEINLINE size_t inc_coords(CoordsState& cbs, size_t last_offset) { + return coord_inc_n( + cbs, /* 1,*/ last_offset /*, 0*/); +} +template +FORCEINLINE size_t inc_coords_ews(CoordsState& cbs, + size_t last_offset, size_t ews) { + if (ews == 1) { + constexpr size_t Ind = + StridesOrderInd(); + return last_offset + STRIDE(cbs, Ind); + } + return coord_inc_n( + cbs, /* 1,*/ last_offset /*, 0*/); +} - template - FORCEINLINE - typename std::enable_if<(Rank - 1 == rankIndex), bool>::type - eq_zip_coords(ZipCoordsState& cbs, const Nd4jLong* coords) { - return ZIP_COORDS(cbs, rankIndex) == coords[rankIndex]; - } +template +FORCEINLINE typename std::enable_if<(Rank - 1 == rankIndex), zip_size_t>::type +coord_inc_n(ZipCoordsState& cbs, zip_size_t last_offset) { + constexpr size_t Ind = StridesOrderInd(); + + if (likely(ZIP_COORDS(cbs, Ind) < ZIP_LAST_NUM(cbs, Ind))) { + last_offset.first += ZIP_STRIDE1(cbs, Ind); + last_offset.second += ZIP_STRIDE2(cbs, Ind); + ZIP_COORDS(cbs, Ind) = ZIP_COORDS(cbs, Ind) + 1; + return last_offset; + } + // overflow case should not happen + ZIP_COORDS(cbs, Ind) = 0; + // last_offset = 0;// last_offset + strides[Ind] - adjust_stride; + return {0, 0}; +} - template - FORCEINLINE - typename std::enable_if<(Rank - 1 != rankIndex), bool>::type - eq_zip_coords(ZipCoordsState& cbs, const Nd4jLong* coords) { - return ZIP_COORDS(cbs, rankIndex) == coords[rankIndex] && eq_zip_coords(cbs, coords); - } +template +FORCEINLINE typename std::enable_if<(Rank - 1 != rankIndex), zip_size_t>::type +coord_inc_n(ZipCoordsState& cbs, zip_size_t last_offset) { + constexpr size_t Ind = StridesOrderInd(); + + if (likely(ZIP_COORDS(cbs, Ind) < ZIP_LAST_NUM(cbs, Ind))) { + last_offset.first += ZIP_STRIDE1(cbs, Ind); + last_offset.second += ZIP_STRIDE2(cbs, Ind); + ZIP_COORDS(cbs, Ind) = ZIP_COORDS(cbs, Ind) + 1; + } else { + // lets adjust offset + last_offset.first -= ZIP_OF_ADJUST1(cbs, Ind); + last_offset.second -= ZIP_OF_ADJUST2(cbs, Ind); + ZIP_COORDS(cbs, Ind) = 0; + last_offset = + coord_inc_n(cbs, last_offset); + } + + return last_offset; +} - template - FORCEINLINE - typename std::enable_if<(Rank - 1 == rankIndex), zip_size_t>::type - init_coords(ZipCoordsState& cbs, const Nd4jLong index, const Nd4jLong* bases, const Nd4jLong* x_strides, const Nd4jLong* z_strides, zip_size_t offset = {}) { - constexpr size_t Ind = StridesOrderInd(); - ZIP_COORDS(cbs, Ind) = index % bases[Ind]; - ZIP_LAST_NUM(cbs, Ind) = bases[Ind] - 1; - ZIP_STRIDE1(cbs, Ind) = x_strides[Ind]; - ZIP_STRIDE2(cbs, Ind) = z_strides[Ind]; - ZIP_OF_ADJUST1(cbs, Ind) = ZIP_LAST_NUM(cbs, Ind) * ZIP_STRIDE1(cbs, Ind); - ZIP_OF_ADJUST2(cbs, Ind) = ZIP_LAST_NUM(cbs, Ind) * ZIP_STRIDE2(cbs, Ind); - offset.first += ZIP_COORDS(cbs, Ind) * ZIP_STRIDE1(cbs, Ind); - offset.second += ZIP_COORDS(cbs, Ind) * ZIP_STRIDE2(cbs, Ind); - return offset; - } +template +FORCEINLINE zip_size_t inc_coords(ZipCoordsState& cbs, + zip_size_t last_offset) { + return coord_inc_n(cbs, last_offset); +} - template - FORCEINLINE - typename std::enable_if<(Rank - 1 != rankIndex), zip_size_t>::type - init_coords(ZipCoordsState& cbs, const Nd4jLong index, const Nd4jLong* bases, const Nd4jLong* x_strides, const Nd4jLong* z_strides, zip_size_t offset = {}) { - constexpr size_t Ind = StridesOrderInd(); - ZIP_COORDS(cbs, Ind) = index % bases[Ind]; - ZIP_LAST_NUM(cbs, Ind) = bases[Ind] - 1; - ZIP_STRIDE1(cbs, Ind) = x_strides[Ind]; - ZIP_STRIDE2(cbs, Ind) = z_strides[Ind]; - ZIP_OF_ADJUST1(cbs, Ind) = ZIP_LAST_NUM(cbs, Ind) * ZIP_STRIDE1(cbs, Ind); - ZIP_OF_ADJUST2(cbs, Ind) = ZIP_LAST_NUM(cbs, Ind) * ZIP_STRIDE2(cbs, Ind); - offset.first += ZIP_COORDS(cbs, Ind) * ZIP_STRIDE1(cbs, Ind); - offset.second += ZIP_COORDS(cbs, Ind) * ZIP_STRIDE2(cbs, Ind); - return init_coords(cbs, index / bases[Ind], bases, x_strides, z_strides, offset); - } +template +FORCEINLINE typename std::enable_if<(Rank - 1 == rankIndex), size_t>::type +init_coords(CoordsState& cbs, const Nd4jLong index, + const Nd4jLong* bases, const Nd4jLong* strides, size_t offset = 0) { + constexpr size_t Ind = StridesOrderInd(); + COORDS(cbs, Ind) = index % bases[Ind]; + LAST_NUM(cbs, Ind) = bases[Ind] - 1; + STRIDE(cbs, Ind) = strides[Ind]; + OF_ADJUST(cbs, Ind) = bases[Ind] * strides[Ind] - strides[Ind]; + offset += COORDS(cbs, Ind) * strides[Ind]; + return offset; +} +template +FORCEINLINE typename std::enable_if<(Rank - 1 != rankIndex), size_t>::type +init_coords(CoordsState& cbs, const Nd4jLong index, + const Nd4jLong* bases, const Nd4jLong* strides, size_t offset = 0) { + constexpr size_t Ind = StridesOrderInd(); + COORDS(cbs, Ind) = index % bases[Ind]; + LAST_NUM(cbs, Ind) = bases[Ind] - 1; + STRIDE(cbs, Ind) = strides[Ind]; + OF_ADJUST(cbs, Ind) = bases[Ind] * strides[Ind] - strides[Ind]; + offset += COORDS(cbs, Ind) * strides[Ind]; + return init_coords( + cbs, index / bases[Ind], bases, strides, offset); +} - //inc coords for non constant Ranks - template - FORCEINLINE size_t inc_coords(const Nd4jLong* bases, const Nd4jLong* strides, Nd4jLong* coords, size_t last_offset, const size_t rank, const size_t skip = 0) { +template +FORCEINLINE typename std::enable_if<(Rank - 1 == rankIndex), bool>::type +eq_coords(CoordsState& cbs, const Nd4jLong* coords) { + return COORDS(cbs, rankIndex) == coords[rankIndex]; +} - Nd4jLong val; - for (int i = rank - skip - 1; i >= 0; i--) { - val = coords[i] + 1; - if (likely(val < bases[i])) { - coords[i] = val; - last_offset += strides[i]; - break; - } - else { - last_offset -= coords[i] * strides[i]; - coords[i] = 0; - } - } - return last_offset; - } +template +FORCEINLINE typename std::enable_if<(Rank - 1 != rankIndex), bool>::type +eq_coords(CoordsState& cbs, const Nd4jLong* coords) { + return COORDS(cbs, rankIndex) == coords[rankIndex] && + eq_coords(cbs, coords); +} - template<> - FORCEINLINE size_t inc_coords(const Nd4jLong* bases, const Nd4jLong* strides, Nd4jLong* coords, size_t last_offset, const size_t rank, const size_t skip) { +template +FORCEINLINE typename std::enable_if<(Rank - 1 == rankIndex), bool>::type +eq_zip_coords(ZipCoordsState& cbs, const Nd4jLong* coords) { + return ZIP_COORDS(cbs, rankIndex) == coords[rankIndex]; +} - Nd4jLong val; - for (int i = skip; i < rank; i++) { - val = coords[i] + 1; - if (likely(val < bases[i])) { - coords[i] = val; - last_offset += strides[i]; - break; - } - else { - last_offset -= coords[i] * strides[i]; - coords[i] = 0; - } - } - return last_offset; - } +template +FORCEINLINE typename std::enable_if<(Rank - 1 != rankIndex), bool>::type +eq_zip_coords(ZipCoordsState& cbs, const Nd4jLong* coords) { + return ZIP_COORDS(cbs, rankIndex) == coords[rankIndex] && + eq_zip_coords(cbs, coords); +} +template +FORCEINLINE typename std::enable_if<(Rank - 1 == rankIndex), zip_size_t>::type +init_coords(ZipCoordsState& cbs, const Nd4jLong index, + const Nd4jLong* bases, const Nd4jLong* x_strides, + const Nd4jLong* z_strides, zip_size_t offset = {}) { + constexpr size_t Ind = StridesOrderInd(); + ZIP_COORDS(cbs, Ind) = index % bases[Ind]; + ZIP_LAST_NUM(cbs, Ind) = bases[Ind] - 1; + ZIP_STRIDE1(cbs, Ind) = x_strides[Ind]; + ZIP_STRIDE2(cbs, Ind) = z_strides[Ind]; + ZIP_OF_ADJUST1(cbs, Ind) = ZIP_LAST_NUM(cbs, Ind) * ZIP_STRIDE1(cbs, Ind); + ZIP_OF_ADJUST2(cbs, Ind) = ZIP_LAST_NUM(cbs, Ind) * ZIP_STRIDE2(cbs, Ind); + offset.first += ZIP_COORDS(cbs, Ind) * ZIP_STRIDE1(cbs, Ind); + offset.second += ZIP_COORDS(cbs, Ind) * ZIP_STRIDE2(cbs, Ind); + return offset; +} - template - FORCEINLINE zip_size_t inc_coords(const Nd4jLong* bases, const Nd4jLong* x_strides, const Nd4jLong* z_strides, Nd4jLong* coords, zip_size_t last_offset, const size_t rank, const size_t skip = 0) { +template +FORCEINLINE typename std::enable_if<(Rank - 1 != rankIndex), zip_size_t>::type +init_coords(ZipCoordsState& cbs, const Nd4jLong index, + const Nd4jLong* bases, const Nd4jLong* x_strides, + const Nd4jLong* z_strides, zip_size_t offset = {}) { + constexpr size_t Ind = StridesOrderInd(); + ZIP_COORDS(cbs, Ind) = index % bases[Ind]; + ZIP_LAST_NUM(cbs, Ind) = bases[Ind] - 1; + ZIP_STRIDE1(cbs, Ind) = x_strides[Ind]; + ZIP_STRIDE2(cbs, Ind) = z_strides[Ind]; + ZIP_OF_ADJUST1(cbs, Ind) = ZIP_LAST_NUM(cbs, Ind) * ZIP_STRIDE1(cbs, Ind); + ZIP_OF_ADJUST2(cbs, Ind) = ZIP_LAST_NUM(cbs, Ind) * ZIP_STRIDE2(cbs, Ind); + offset.first += ZIP_COORDS(cbs, Ind) * ZIP_STRIDE1(cbs, Ind); + offset.second += ZIP_COORDS(cbs, Ind) * ZIP_STRIDE2(cbs, Ind); + return init_coords( + cbs, index / bases[Ind], bases, x_strides, z_strides, offset); +} - Nd4jLong val = 0; - for (int i = rank - skip - 1; i >= 0; i--) { - val = coords[i] + 1; - if (likely(val < bases[i])) { - coords[i] = val; - last_offset.first += x_strides[i]; - last_offset.second += z_strides[i]; - break; - } - else { - last_offset.first -= coords[i] * x_strides[i]; - last_offset.second -= coords[i] * z_strides[i]; - coords[i] = 0; - } - } - return last_offset; - } +// inc coords for non constant Ranks +template +FORCEINLINE size_t inc_coords(const Nd4jLong* bases, const Nd4jLong* strides, + Nd4jLong* coords, size_t last_offset, + const size_t rank, const size_t skip = 0) { + Nd4jLong val; + for (int i = rank - skip - 1; i >= 0; i--) { + val = coords[i] + 1; + if (likely(val < bases[i])) { + coords[i] = val; + last_offset += strides[i]; + break; + } else { + last_offset -= coords[i] * strides[i]; + coords[i] = 0; + } + } + return last_offset; +} - template<> - FORCEINLINE zip_size_t inc_coords(const Nd4jLong* bases, const Nd4jLong* x_strides, const Nd4jLong* z_strides, Nd4jLong* coords, zip_size_t last_offset, const size_t rank, const size_t skip) { +template <> +FORCEINLINE size_t inc_coords(const Nd4jLong* bases, + const Nd4jLong* strides, Nd4jLong* coords, + size_t last_offset, const size_t rank, + const size_t skip) { + Nd4jLong val; + for (int i = skip; i < rank; i++) { + val = coords[i] + 1; + if (likely(val < bases[i])) { + coords[i] = val; + last_offset += strides[i]; + break; + } else { + last_offset -= coords[i] * strides[i]; + coords[i] = 0; + } + } + return last_offset; +} - Nd4jLong val = 0; - for (int i = skip; i < rank; i++) { - val = coords[i] + 1; - if (likely(val < bases[i])) { - coords[i] = val; +template +FORCEINLINE zip_size_t inc_coords(const Nd4jLong* bases, + const Nd4jLong* x_strides, + const Nd4jLong* z_strides, Nd4jLong* coords, + zip_size_t last_offset, const size_t rank, + const size_t skip = 0) { + Nd4jLong val = 0; + for (int i = rank - skip - 1; i >= 0; i--) { + val = coords[i] + 1; + if (likely(val < bases[i])) { + coords[i] = val; + last_offset.first += x_strides[i]; + last_offset.second += z_strides[i]; + break; + } else { + last_offset.first -= coords[i] * x_strides[i]; + last_offset.second -= coords[i] * z_strides[i]; + coords[i] = 0; + } + } + return last_offset; +} - last_offset.first += x_strides[i]; - last_offset.second += z_strides[i]; - break; - } - else { - last_offset.first -= coords[i] * x_strides[i]; - last_offset.second -= coords[i] * z_strides[i]; - coords[i] = 0; - } - } - return last_offset; - } +template <> +FORCEINLINE zip_size_t inc_coords(const Nd4jLong* bases, + const Nd4jLong* x_strides, + const Nd4jLong* z_strides, + Nd4jLong* coords, + zip_size_t last_offset, + const size_t rank, const size_t skip) { + Nd4jLong val = 0; + for (int i = skip; i < rank; i++) { + val = coords[i] + 1; + if (likely(val < bases[i])) { + coords[i] = val; + + last_offset.first += x_strides[i]; + last_offset.second += z_strides[i]; + break; + } else { + last_offset.first -= coords[i] * x_strides[i]; + last_offset.second -= coords[i] * z_strides[i]; + coords[i] = 0; + } + } + return last_offset; +} struct triple_size_t { @@ -641,7 +628,7 @@ namespace sd { first_end = ind; first_begin = 0; //nd4j_printf("rffrr ss & %d ind-- %d %d\n", first_rank, first_begin, first_end); - //squash output rank + //squash output rank if (first_squash && first_rank > 1) { if (order == 'c') { @@ -661,7 +648,7 @@ namespace sd { first_begin = first_end - first_rank; } else { - //squash fortran + //squash fortran int uniq_ind = 0; for (int i = 1; i < first_end; i++) { if (new_strides[i] == new_bases[uniq_ind] * new_strides[uniq_ind]) { @@ -677,7 +664,7 @@ namespace sd { } first_end = first_begin + first_rank; - } + } ind = first_end; } @@ -696,7 +683,7 @@ namespace sd { second_begin = first_end; } - + if (second_squash && second_rank > 1) { @@ -733,7 +720,7 @@ namespace sd { second_end = second_begin + second_rank; } - + } return; diff --git a/libnd4j/include/helpers/MKLDNNStream.h b/libnd4j/include/helpers/MKLDNNStream.h index f575c48d9f92..2db2cdac5077 100644 --- a/libnd4j/include/helpers/MKLDNNStream.h +++ b/libnd4j/include/helpers/MKLDNNStream.h @@ -27,48 +27,56 @@ #if defined(HAVE_MKLDNN) +#include + +#include +#include + namespace sd { - class MKLDNNStream { - protected: - std::string _opName; +class MKLDNNStream { + protected: + std::string _opName; - std::vector _inputs; - std::vector _outputs; - std::vector _floatArguments; - std::vector _intArguments; + std::vector _inputs; + std::vector _outputs; + std::vector _floatArguments; + std::vector _intArguments; - public: - template - static bool isSupported() { - // FIXME: strict float support doesn't work anymore - return typeid(X) == typeid(float) && typeid(Y) == typeid(float); - } + public: + template + static bool isSupported() { + // FIXME: strict float support doesn't work anymore + return typeid(X) == typeid(float) && typeid(Y) == typeid(float); + } - static bool isSupported(const std::vector &arrays) { - // FIXME: strict float support doesn't work anymore - for (auto v:arrays) { - if (v != nullptr && v->dataType() != sd::DataType::FLOAT32) { - return false; - } - } - return true; - } + static bool isSupported(const std::vector &arrays) { + // FIXME: strict float support doesn't work anymore + for (auto v : arrays) { + if (v != nullptr && v->dataType() != sd::DataType::FLOAT32) { + return false; + } + } + return true; + } - explicit MKLDNNStream(const std::string &opName) : _opName(opName) { } + explicit MKLDNNStream(const std::string &opName) : _opName(opName) {} - bool checkAndReset(const std::vector &inputs, const std::vector &outputs, - const std::vector &floatArguments, const std::vector &intArguments) { - if (inputs != _inputs || outputs != _outputs || floatArguments != _floatArguments || intArguments != _intArguments) { - _inputs = inputs; - _outputs = outputs; - _floatArguments = floatArguments; - _intArguments = intArguments; - return true; - } - return false; - } - }; -} + bool checkAndReset(const std::vector &inputs, + const std::vector &outputs, + const std::vector &floatArguments, + const std::vector &intArguments) { + if (inputs != _inputs || outputs != _outputs || + floatArguments != _floatArguments || intArguments != _intArguments) { + _inputs = inputs; + _outputs = outputs; + _floatArguments = floatArguments; + _intArguments = intArguments; + return true; + } + return false; + } +}; +} // namespace sd #endif -#endif //LIBND4J_MKLDNNSTREAM_H +#endif // LIBND4J_MKLDNNSTREAM_H diff --git a/libnd4j/include/helpers/MmulHelper.h b/libnd4j/include/helpers/MmulHelper.h index 517ca98881de..549d20ea7c68 100644 --- a/libnd4j/include/helpers/MmulHelper.h +++ b/libnd4j/include/helpers/MmulHelper.h @@ -25,44 +25,68 @@ #include "array/NDArray.h" namespace sd { - class ND4J_EXPORT MmulHelper { - - private: - - // multiptication N-dimensions tensor on other N-dimensions one - static sd::NDArray* mmulNxN(const sd::NDArray* A, const sd::NDArray* B, sd::NDArray* C, const double alpha = 1.0, const double beta = 0.0, const char outOrder = 'f'); - - // dot product of vectors (X * Y) = Z[0] - static sd::NDArray* dot(const sd::NDArray* X, const sd::NDArray* Y, sd::NDArray* Z, const double alpha = 1.0, const double beta = 0.0); - - // multiptication Matrix to Matrix - static sd::NDArray* mmulMxM(const sd::NDArray* A, const sd::NDArray* B, sd::NDArray* C, double alpha = 1.0, double beta = 0.0, const char outOrder = 'f'); - - // multiptication Matrix to vector - static sd::NDArray* mmulMxV(const sd::NDArray* A, const sd::NDArray* B, sd::NDArray* C, double alpha = 1.0, double beta = 0.0, const char outOrder = 'f'); - - public: - - static sd::NDArray* mmul(const sd::NDArray* A, const sd::NDArray* B, sd::NDArray* C = nullptr, const double alpha = 1.0, const double beta = 0.0, const char outOrder = 'f'); - - static sd::NDArray* tensorDot(const sd::NDArray* A, const sd::NDArray* B, const std::initializer_list& axesA, const std::initializer_list& axesB = {}); - - static sd::NDArray* tensorDot(const sd::NDArray* A, const sd::NDArray* B, const std::vector& axesA, const std::vector& axesB); - - static void tensorDot(const sd::NDArray* a, const sd::NDArray* b, sd::NDArray* c, const std::vector& axes_a, const std::vector& axes_b, const std::vector& permutForC = {}); - +class SD_EXPORT MmulHelper { + private: + // multiptication N-dimensions tensor on other N-dimensions one + static sd::NDArray* mmulNxN(const sd::NDArray* A, const sd::NDArray* B, + sd::NDArray* C, const double alpha = 1.0, + const double beta = 0.0, + const char outOrder = 'f'); + + // dot product of vectors (X * Y) = Z[0] + static sd::NDArray* dot(const sd::NDArray* X, const sd::NDArray* Y, + sd::NDArray* Z, const double alpha = 1.0, + const double beta = 0.0); + + // multiptication Matrix to Matrix + static sd::NDArray* mmulMxM(const sd::NDArray* A, const sd::NDArray* B, + sd::NDArray* C, double alpha = 1.0, + double beta = 0.0, const char outOrder = 'f'); + + // multiptication Matrix to vector + static sd::NDArray* mmulMxV(const sd::NDArray* A, const sd::NDArray* B, + sd::NDArray* C, double alpha = 1.0, + double beta = 0.0, const char outOrder = 'f'); + + public: + static sd::NDArray* mmul(const sd::NDArray* A, const sd::NDArray* B, + sd::NDArray* C = nullptr, const double alpha = 1.0, + const double beta = 0.0, const char outOrder = 'f'); + + static sd::NDArray* tensorDot(const sd::NDArray* A, const sd::NDArray* B, + const std::initializer_list& axesA, + const std::initializer_list& axesB = {}); + + static sd::NDArray* tensorDot(const sd::NDArray* A, const sd::NDArray* B, + const std::vector& axesA, + const std::vector& axesB); + + static void tensorDot(const sd::NDArray* a, const sd::NDArray* b, + sd::NDArray* c, const std::vector& axes_a, + const std::vector& axes_b, + const std::vector& permutForC = {}); #ifndef __JAVACPP_HACK__ - /** - * modif - (can be empty) vector containing a subsequence of permutation/reshaping arrays (in any order), user must take care of correctness of such arrays by himself - */ - static void tensorDot(const sd::NDArray* a, const sd::NDArray* b, sd::NDArray* c, const std::vector>& modifA, const std::vector>& modifB, const std::vector>& modifC); - static sd::NDArray* tensorDot(const sd::NDArray* a, const sd::NDArray* b, const std::vector>& modifA, const std::vector>& modifB); + /** + * modif - (can be empty) vector containing a subsequence of + * permutation/reshaping arrays (in any order), user must take care of + * correctness of such arrays by himself + */ + static void tensorDot(const sd::NDArray* a, const sd::NDArray* b, + sd::NDArray* c, + const std::vector>& modifA, + const std::vector>& modifB, + const std::vector>& modifC); + static sd::NDArray* tensorDot( + const sd::NDArray* a, const sd::NDArray* b, + const std::vector>& modifA, + const std::vector>& modifB); #endif - static void matmul(const sd::NDArray* x, const sd::NDArray* y, sd::NDArray* z, const bool transX, const bool transY, double alpha = 1.0, double beta = 0.0); - }; -} - + static void matmul(const sd::NDArray* x, const sd::NDArray* y, sd::NDArray* z, + const bool transX, const bool transY, double alpha = 1.0, + double beta = 0.0); +}; +} // namespace sd -#endif //LIBND4J_MMULHELPER_H \ No newline at end of file +#endif // LIBND4J_MMULHELPER_H \ No newline at end of file diff --git a/libnd4j/include/helpers/OmpLaunchHelper.h b/libnd4j/include/helpers/OmpLaunchHelper.h index 3e0e50391472..289f91253b63 100644 --- a/libnd4j/include/helpers/OmpLaunchHelper.h +++ b/libnd4j/include/helpers/OmpLaunchHelper.h @@ -22,49 +22,48 @@ #ifndef LIBND4J_OMPLAUNCHHELPER_H #define LIBND4J_OMPLAUNCHHELPER_H -#include -#include #include +#include + +#include namespace sd { -class ND4J_EXPORT OmpLaunchHelper { - - public: - - OmpLaunchHelper() = delete; - - OmpLaunchHelper(const Nd4jLong N, float desiredNumThreads = -1); +class SD_EXPORT OmpLaunchHelper { + public: + OmpLaunchHelper() = delete; - FORCEINLINE Nd4jLong getThreadOffset(const int threadNum); - FORCEINLINE Nd4jLong getItersPerThread(const int threadNum); + OmpLaunchHelper(const Nd4jLong N, float desiredNumThreads = -1); - static Nd4jLong betterSpan(Nd4jLong N); - static Nd4jLong betterSpan(Nd4jLong N, Nd4jLong numThreads); - - static int betterThreads(Nd4jLong N); - static int betterThreads(Nd4jLong N, int maxThreads); + FORCEINLINE Nd4jLong getThreadOffset(const int threadNum); + FORCEINLINE Nd4jLong getItersPerThread(const int threadNum); - static int tadThreads(Nd4jLong tadLength, Nd4jLong numTads); + static Nd4jLong betterSpan(Nd4jLong N); + static Nd4jLong betterSpan(Nd4jLong N, Nd4jLong numThreads); - int _numThreads; - unsigned int _itersPerThread; - unsigned int _remainder; + static int betterThreads(Nd4jLong N); + static int betterThreads(Nd4jLong N, int maxThreads); + + static int tadThreads(Nd4jLong tadLength, Nd4jLong numTads); + + int _numThreads; + unsigned int _itersPerThread; + unsigned int _remainder; }; //////////////////////////////////////////////////////////////////////////////// FORCEINLINE Nd4jLong OmpLaunchHelper::getThreadOffset(const int threadNum) { - - return threadNum * _itersPerThread; + return threadNum * _itersPerThread; } //////////////////////////////////////////////////////////////////////////////// FORCEINLINE Nd4jLong OmpLaunchHelper::getItersPerThread(const int threadNum) { - - return (threadNum == _numThreads - 1) ? _itersPerThread + _remainder : _itersPerThread; // last thread may contain bigger number of iterations -} - + return (threadNum == _numThreads - 1) + ? _itersPerThread + _remainder + : _itersPerThread; // last thread may contain bigger number of + // iterations } +} // namespace sd -#endif //LIBND4J_OMPLAUNCHHELPER_H +#endif // LIBND4J_OMPLAUNCHHELPER_H diff --git a/libnd4j/include/helpers/OpArgsHolder.h b/libnd4j/include/helpers/OpArgsHolder.h index a9432f134aa9..f19569a017f7 100644 --- a/libnd4j/include/helpers/OpArgsHolder.h +++ b/libnd4j/include/helpers/OpArgsHolder.h @@ -21,85 +21,71 @@ #ifndef LIBND4J_OPARGSHOLDER_H #define LIBND4J_OPARGSHOLDER_H - #include #include namespace sd { -class ND4J_EXPORT OpArgsHolder { - -private: - - std::vector _inArrs = std::vector(); - std::vector _tArgs = std::vector(); - std::vector _iArgs = std::vector(); - std::vector _bArgs = std::vector(); - - std::vector _isArrAlloc = std::vector(); - - int _numInArrs = _inArrs.size(); - int _numTArgs = _tArgs.size(); - int _numIArgs = _iArgs.size(); - int _numBArgs = _bArgs.size(); +class SD_EXPORT OpArgsHolder { + private: + std::vector _inArrs = std::vector(); + std::vector _tArgs = std::vector(); + std::vector _iArgs = std::vector(); + std::vector _bArgs = std::vector(); -public: + std::vector _isArrAlloc = std::vector(); - // default constructor - OpArgsHolder(); + int _numInArrs = _inArrs.size(); + int _numTArgs = _tArgs.size(); + int _numIArgs = _iArgs.size(); + int _numBArgs = _bArgs.size(); - // copy constructor - OpArgsHolder(const OpArgsHolder& other); + public: + // default constructor + OpArgsHolder(); - // constructor - OpArgsHolder(const std::vector& inArrs, const std::vector& tArgs = std::vector(), const std::vector& iArgs = std::vector(), const std::vector& bArgs = std::vector()); + // copy constructor + OpArgsHolder(const OpArgsHolder& other); - // move constructor - OpArgsHolder(OpArgsHolder&& other) noexcept; + // constructor + OpArgsHolder(const std::vector& inArrs, + const std::vector& tArgs = std::vector(), + const std::vector& iArgs = std::vector(), + const std::vector& bArgs = std::vector()); - // assignment operator - OpArgsHolder& operator=(const OpArgsHolder& other); + // move constructor + OpArgsHolder(OpArgsHolder&& other) noexcept; - // move assignment operator - OpArgsHolder& operator=(OpArgsHolder&& other) noexcept; + // assignment operator + OpArgsHolder& operator=(const OpArgsHolder& other); - const std::vector& getInArrs() const - {return _inArrs; } + // move assignment operator + OpArgsHolder& operator=(OpArgsHolder&& other) noexcept; - const std::vector& getTArgs() const - {return _tArgs; } + const std::vector& getInArrs() const { return _inArrs; } - const std::vector& getIArgs() const - {return _iArgs; } + const std::vector& getTArgs() const { return _tArgs; } - const std::vector& getBArgs() const - {return _bArgs; } + const std::vector& getIArgs() const { return _iArgs; } - const std::vector& getAllocInfo() const - {return _isArrAlloc; } + const std::vector& getBArgs() const { return _bArgs; } - int getNumInArrs() const - {return _numInArrs; } + const std::vector& getAllocInfo() const { return _isArrAlloc; } - int getNumTArgs() const - {return _numTArgs; } + int getNumInArrs() const { return _numInArrs; } - int getNumIArgs() const - {return _numIArgs; } + int getNumTArgs() const { return _numTArgs; } - int getNumBArgs() const - {return _numBArgs; } + int getNumIArgs() const { return _numIArgs; } - OpArgsHolder createArgsHolderForBP(const std::vector& inGradArrs, const bool isInPlace = false) const; + int getNumBArgs() const { return _numBArgs; } - ~OpArgsHolder() noexcept; + OpArgsHolder createArgsHolderForBP(const std::vector& inGradArrs, + const bool isInPlace = false) const; + ~OpArgsHolder() noexcept; }; +} // namespace sd - - - -} - -#endif //LIBND4J_OPARGSHOLDER_H +#endif // LIBND4J_OPARGSHOLDER_H diff --git a/libnd4j/include/helpers/OpBenchmark.h b/libnd4j/include/helpers/OpBenchmark.h index 328b20dce3ab..9ffa67de93f8 100644 --- a/libnd4j/include/helpers/OpBenchmark.h +++ b/libnd4j/include/helpers/OpBenchmark.h @@ -18,59 +18,60 @@ // Created by raver on 2/28/2019. // -#ifndef DEV_TESTS_OPEXECUTIONER_H -#define DEV_TESTS_OPEXECUTIONER_H +#ifndef SD_OPEXECUTIONER_H +#define SD_OPEXECUTIONER_H -#include #include -#include -#include #include +#include +#include +#include namespace sd { - class ND4J_EXPORT OpBenchmark { - protected: - int _opNum = 0; - std::string _testName; - NDArray *_x = nullptr; - NDArray *_y = nullptr; - NDArray *_z = nullptr; - std::vector _axis; - public: - OpBenchmark() = default; - OpBenchmark(std::string name, NDArray *x, NDArray *y, NDArray *z); - OpBenchmark(std::string name, NDArray *x, NDArray *z); - OpBenchmark(std::string name, NDArray *x, NDArray *z, std::initializer_list axis); - OpBenchmark(std::string name, NDArray *x, NDArray *z, std::vector axis); - OpBenchmark(std::string name, NDArray *x, NDArray *y, NDArray *z, std::initializer_list axis); - OpBenchmark(std::string name, NDArray *x, NDArray *y, NDArray *z, std::vector axis); +class SD_EXPORT OpBenchmark { + protected: + int _opNum = 0; + std::string _testName; + NDArray _x; + NDArray _y; + NDArray _z; + std::vector _axis; - void setOpNum(int opNum); - void setTestName(std::string testName); - void setX(NDArray *array); - void setY(NDArray *array); - void setZ(NDArray *array); - void setAxis(std::vector axis); - void setAxis(std::initializer_list axis); + public: + OpBenchmark() = default; + OpBenchmark(const std::string &name, const NDArray &x, const NDArray &y, + const NDArray &z); + OpBenchmark(const std::string &name, const NDArray &x, const NDArray &z); + OpBenchmark(const std::string &name, const NDArray &x, const NDArray &z, + const std::vector &axis); + OpBenchmark(const std::string &name, const NDArray &x, const NDArray &y, + const NDArray &z, const std::vector &axis); - NDArray& x(); - int opNum(); - std::string testName(); - std::vector getAxis(); + void setOpNum(int opNum); + void setTestName(const std::string &testName); + void setX(const NDArray &array); + void setY(const NDArray &array); + void setZ(const NDArray &array); + void setAxis(std::vector axis); + void setAxis(std::initializer_list axis); - virtual std::string extra(); - virtual std::string dataType(); - virtual std::string axis() = 0; - virtual std::string orders() = 0; - virtual std::string strides() = 0; - virtual std::string shape(); - virtual std::string inplace() = 0; + NDArray &x(); + int opNum() const; + const std::string &testName() const; + std::vector getAxis(); - virtual void executeOnce() = 0; + virtual std::string extra(); + virtual std::string dataType(); + virtual std::string axis() = 0; + virtual std::string orders() = 0; + virtual std::string strides() = 0; + virtual std::string shape(); + virtual std::string inplace() = 0; - virtual OpBenchmark* clone() = 0; - }; -} + virtual void executeOnce() = 0; + virtual OpBenchmark *clone() = 0; +}; +} // namespace sd -#endif //DEV_TESTS_OPEXECUTIONER_H +#endif // SD_OPEXECUTIONER_H diff --git a/libnd4j/include/helpers/OpTracker.h b/libnd4j/include/helpers/OpTracker.h index dfccf5e5ddbd..7094fbc56980 100644 --- a/libnd4j/include/helpers/OpTracker.h +++ b/libnd4j/include/helpers/OpTracker.h @@ -18,41 +18,47 @@ // @author raver119@gmail.com // -#ifndef LIBND4J_OP_TRACKER_H -#define LIBND4J_OP_TRACKER_H +#ifndef SD_OP_TRACKER_H +#define SD_OP_TRACKER_H -#include -#include -#include -#include #include #include #include +#include + +#include +#include +#include namespace sd { - class ND4J_EXPORT OpTracker { - private: - std::string _export; +class SD_EXPORT OpTracker { + private: + + + std::string _export; + + int _operations = 0; + std::map> _map; - int _operations = 0; - std::map> _map; + OpTracker() = default; + ~OpTracker() = default; - OpTracker() = default; - ~OpTracker() = default; + template + std::string local_to_string(T value); - template - std::string local_to_string(T value); - public: - static OpTracker& getInstance(); + public: + static OpTracker& getInstance(); - int totalGroups(); - int totalOperations(); + int totalGroups(); + int totalOperations(); - void storeOperation(sd::graph::OpType opType, const sd::ops::OpDescriptor& descriptor); - void storeOperation(sd::graph::OpType opType, const char* opName, const Nd4jLong opNum); + void storeOperation(sd::graph::OpType opType, + const sd::ops::OpDescriptor& descriptor); + void storeOperation(sd::graph::OpType opType, const char* opName, + const Nd4jLong opNum); - const char* exportOperations(); - }; -} + const char* exportOperations(); +}; +} // namespace sd #endif diff --git a/libnd4j/include/helpers/PointersManager.h b/libnd4j/include/helpers/PointersManager.h index 4f7af94098eb..6572d7a2ee0c 100644 --- a/libnd4j/include/helpers/PointersManager.h +++ b/libnd4j/include/helpers/PointersManager.h @@ -22,60 +22,56 @@ #ifndef CUDAMANAGER_H #define CUDAMANAGER_H -#include -#include #include - #include -namespace sd { - -class ND4J_EXPORT PointersManager { - - private: - - sd::LaunchContext *_context; - std::vector _pOnGlobMem; - std::string _funcName; +#include +#include - public: +namespace sd { - PointersManager(const sd::LaunchContext* context, const std::string& funcName = ""); +class SD_EXPORT PointersManager { + private: + sd::LaunchContext* _context; + std::vector _pOnGlobMem; + std::string _funcName; - ~PointersManager(); + public: + PointersManager(const sd::LaunchContext* context, + const std::string& funcName = ""); - void* replicatePointer(const void* src, const size_t size); + ~PointersManager(); - void synchronize() const; + void* replicatePointer(const void* src, const size_t size); - template - void printDevContentOnHost(const void* pDev, const Nd4jLong len) const; + void synchronize() const; + template + void printDevContentOnHost(const void* pDev, const Nd4jLong len) const; #ifdef __CUDABLAS__ - template - static void printDevContentOnDevFromHost(const void* pDev, const Nd4jLong len, const int tid = 0); + template + static void printDevContentOnDevFromHost(const void* pDev, const Nd4jLong len, + const int tid = 0); #endif #ifdef __CUDACC__ - template - static FORCEINLINE __device__ void printDevContentOnDev(const void* pDev, const Nd4jLong len, const int tid = 0) { - if(blockIdx.x * blockDim.x + threadIdx.x != tid) - return; + template + static FORCEINLINE __device__ void printDevContentOnDev(const void* pDev, + const Nd4jLong len, + const int tid = 0) { + if (blockIdx.x * blockDim.x + threadIdx.x != tid) return; - printf("device print out: \n"); - for(Nd4jLong i = 0; i < len; ++i) - printf("%f, ", (double)reinterpret_cast(pDev)[i]); + printf("device print out: \n"); + for (Nd4jLong i = 0; i < len; ++i) + printf("%f, ", (double)reinterpret_cast(pDev)[i]); - printf("\n"); - } + printf("\n"); + } #endif - }; -} - - +} // namespace sd -#endif // CUDAMANAGER_H +#endif // CUDAMANAGER_H diff --git a/libnd4j/include/helpers/RandomLauncher.h b/libnd4j/include/helpers/RandomLauncher.h index 49e961062d96..0e22eda3b079 100644 --- a/libnd4j/include/helpers/RandomLauncher.h +++ b/libnd4j/include/helpers/RandomLauncher.h @@ -19,29 +19,51 @@ // #include -#include -#include #include +#include +#include namespace sd { - class ND4J_EXPORT RandomLauncher { - public: - static void applyDropOut(sd::LaunchContext *context, sd::graph::RandomGenerator& rng, NDArray *array, double retainProb, NDArray* z = nullptr); - static void applyInvertedDropOut(sd::LaunchContext *context, sd::graph::RandomGenerator& rng, NDArray *array, double retainProb, NDArray* z = nullptr); - static void applyAlphaDropOut(sd::LaunchContext *context, sd::graph::RandomGenerator& rng, NDArray *array, double retainProb, double alpha, double beta, double alphaPrime, NDArray* z = nullptr); +class SD_EXPORT RandomLauncher { + public: + static void applyDropOut(sd::LaunchContext* context, + sd::graph::RandomGenerator& rng, NDArray* array, + double retainProb, NDArray* z = nullptr); + static void applyInvertedDropOut(sd::LaunchContext* context, + sd::graph::RandomGenerator& rng, + NDArray* array, double retainProb, + NDArray* z = nullptr); + static void applyAlphaDropOut(sd::LaunchContext* context, + sd::graph::RandomGenerator& rng, NDArray* array, + double retainProb, double alpha, double beta, + double alphaPrime, NDArray* z = nullptr); - static void fillUniform(sd::LaunchContext *context, sd::graph::RandomGenerator& rng, NDArray* array, double from, double to); + static void fillUniform(sd::LaunchContext* context, + sd::graph::RandomGenerator& rng, NDArray* array, + double from, double to); - static void fillGaussian(sd::LaunchContext *context, sd::graph::RandomGenerator& rng, NDArray* array, double mean, double stdev); + static void fillGaussian(sd::LaunchContext* context, + sd::graph::RandomGenerator& rng, NDArray* array, + double mean, double stdev); - static void fillExponential(sd::LaunchContext *context, sd::graph::RandomGenerator& rng, NDArray* array, double lambda); + static void fillExponential(sd::LaunchContext* context, + sd::graph::RandomGenerator& rng, NDArray* array, + double lambda); - static void fillLogNormal(sd::LaunchContext *context, sd::graph::RandomGenerator& rng, NDArray* array, double mean, double stdev); + static void fillLogNormal(sd::LaunchContext* context, + sd::graph::RandomGenerator& rng, NDArray* array, + double mean, double stdev); - static void fillTruncatedNormal(sd::LaunchContext *context, sd::graph::RandomGenerator& rng, NDArray* array, double mean, double stdev); + static void fillTruncatedNormal(sd::LaunchContext* context, + sd::graph::RandomGenerator& rng, + NDArray* array, double mean, double stdev); - static void fillBinomial(sd::LaunchContext *context, sd::graph::RandomGenerator& rng, NDArray* array, int trials, double prob); + static void fillBinomial(sd::LaunchContext* context, + sd::graph::RandomGenerator& rng, NDArray* array, + int trials, double prob); - static void fillBernoulli(sd::LaunchContext *context, sd::graph::RandomGenerator& rng, NDArray* array, double prob); - }; -} \ No newline at end of file + static void fillBernoulli(sd::LaunchContext* context, + sd::graph::RandomGenerator& rng, NDArray* array, + double prob); +}; +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/helpers/ShapeBuilders.h b/libnd4j/include/helpers/ShapeBuilders.h index e2c29a280405..455a4fdb0600 100644 --- a/libnd4j/include/helpers/ShapeBuilders.h +++ b/libnd4j/include/helpers/ShapeBuilders.h @@ -18,52 +18,79 @@ // @author raver119@gmail.com // -#ifndef DEV_TESTS_SHAPEBUILDERS_H -#define DEV_TESTS_SHAPEBUILDERS_H +#ifndef SD_SHAPEBUILDERS_H +#define SD_SHAPEBUILDERS_H -#include +#include +#include #include -#include #include -#include -#include - -namespace sd { - class ND4J_EXPORT ShapeBuilders { - public: - static Nd4jLong* createScalarShapeInfo(sd::DataType dataType, sd::memory::Workspace* workspace = nullptr); +#include - static Nd4jLong* createVectorShapeInfo(const sd::DataType dataType, const Nd4jLong length, sd::memory::Workspace* workspace = nullptr); +#include - /** - * create shapeInfo for given order basing on shape stored in shapeOnly vector - * memory allocation for shapeInfo is on given workspace - */ - static Nd4jLong* createShapeInfo(const sd::DataType dataType, const char order, int rank, const Nd4jLong* shapeOnly, memory::Workspace* workspace = nullptr); - static Nd4jLong* createShapeInfo(const sd::DataType dataType, const char order, const std::vector& shapeOnly, memory::Workspace* workspace = nullptr); - static Nd4jLong* createShapeInfo(const sd::DataType dataType, const char order, const std::initializer_list& shapeOnly, memory::Workspace* workspace = nullptr); +namespace sd { +class SD_EXPORT ShapeBuilders { + public: + static Nd4jLong* createScalarShapeInfo( + sd::DataType dataType, sd::memory::Workspace* workspace = nullptr); - /** - * allocates memory for new shapeInfo and copy all information from inShapeInfo to new shapeInfo - * if copyStrides is false then strides for new shapeInfo are recalculated - */ - static Nd4jLong* copyShapeInfo(const Nd4jLong* inShapeInfo, const bool copyStrides, memory::Workspace* workspace = nullptr); - static Nd4jLong* copyShapeInfoAndType(const Nd4jLong* inShapeInfo, const DataType dtype, const bool copyStrides, memory::Workspace* workspace = nullptr); - static Nd4jLong* copyShapeInfoAndType(const Nd4jLong* inShapeInfo, const Nd4jLong* shapeInfoToGetTypeFrom, const bool copyStrides, memory::Workspace* workspace = nullptr); + static Nd4jLong* createVectorShapeInfo( + const sd::DataType dataType, const Nd4jLong length, + sd::memory::Workspace* workspace = nullptr); - /** - * allocates memory for new shapeInfo and copy all information from inShapeInfo to new shapeInfo except dimensions in dimsToExclude (unit dimensions) and corresponding strides - * for example inShapeInfo is {3, 2,1,3,1,4, 12,12,4,4,1, 16384,1,99}, dimsToExclude = {2,3}, dimsSize = 2 - * then outShapeInfo will contain {3, 2,3,4, 12,4,1, 16384,1,99} - */ - static Nd4jLong* copyShapeInfoWithoutUnites(const Nd4jLong* inShapeInfo, const int dimsSize, const int* dimsToExclude, memory::Workspace* workspace = nullptr); + /** + * create shapeInfo for given order basing on shape stored in shapeOnly + * vector memory allocation for shapeInfo is on given workspace + */ + static Nd4jLong* createShapeInfo(const sd::DataType dataType, + const char order, int rank, + const Nd4jLong* shapeOnly, + memory::Workspace* workspace = nullptr); + static Nd4jLong* createShapeInfo(const sd::DataType dataType, + const char order, + const std::vector& shapeOnly, + memory::Workspace* workspace = nullptr); + static Nd4jLong* createShapeInfo( + const sd::DataType dataType, const char order, + const std::initializer_list& shapeOnly, + memory::Workspace* workspace = nullptr); - static Nd4jLong* emptyShapeInfo(const sd::DataType dataType, memory::Workspace* workspace = nullptr); + /** + * allocates memory for new shapeInfo and copy all information from + * inShapeInfo to new shapeInfo if copyStrides is false then strides for new + * shapeInfo are recalculated + */ + static Nd4jLong* copyShapeInfo(const Nd4jLong* inShapeInfo, + const bool copyStrides, + memory::Workspace* workspace = nullptr); + static Nd4jLong* copyShapeInfoAndType(const Nd4jLong* inShapeInfo, + const DataType dtype, + const bool copyStrides, + memory::Workspace* workspace = nullptr); + static Nd4jLong* copyShapeInfoAndType(const Nd4jLong* inShapeInfo, + const Nd4jLong* shapeInfoToGetTypeFrom, + const bool copyStrides, + memory::Workspace* workspace = nullptr); - static Nd4jLong* emptyShapeInfo(const sd::DataType dataType, const char order, const std::vector &shape, memory::Workspace* workspace = nullptr); + /** + * allocates memory for new shapeInfo and copy all information from + * inShapeInfo to new shapeInfo except dimensions in dimsToExclude (unit + * dimensions) and corresponding strides for example inShapeInfo is {3, + * 2,1,3,1,4, 12,12,4,4,1, 16384,1,99}, dimsToExclude = {2,3}, dimsSize = 2 + * then outShapeInfo will contain {3, 2,3,4, 12,4,1, 16384,1,99} + */ + static Nd4jLong* copyShapeInfoWithoutUnites( + const Nd4jLong* inShapeInfo, const int dimsSize, const int* dimsToExclude, + memory::Workspace* workspace = nullptr); - }; -} + static Nd4jLong* emptyShapeInfo(const sd::DataType dataType, + memory::Workspace* workspace = nullptr); + static Nd4jLong* emptyShapeInfo(const sd::DataType dataType, const char order, + const std::vector& shape, + memory::Workspace* workspace = nullptr); +}; +} // namespace sd -#endif //DEV_TESTS_SHAPEBUILDERS_H +#endif // SD_SHAPEBUILDERS_H diff --git a/libnd4j/include/helpers/ShapeUtils.h b/libnd4j/include/helpers/ShapeUtils.h index cb2faa43da76..a4573268e5f0 100644 --- a/libnd4j/include/helpers/ShapeUtils.h +++ b/libnd4j/include/helpers/ShapeUtils.h @@ -21,202 +21,296 @@ #ifndef LIBND4J_SHAPEUTILS_H #define LIBND4J_SHAPEUTILS_H -#include #include -namespace sd { - - class ND4J_EXPORT ShapeUtils { - - public: - - // evaluate shape for array resulting from tensorDot operation, also evaluate shapes and permutation dimensions for transposition of two input arrays - static std::vector evalShapeForTensorDot(const Nd4jLong* aShapeInfo, const Nd4jLong* bShapeInfo, std::vector axesA, std::vector axesB, std::vector& permutAt, std::vector& permutBt, std::vector& shapeAt, std::vector& shapeBt); - static std::vector evalShapeForTensorDot(const NDArray* a, const NDArray* b, const std::vector& axesA, const std::vector& axesB, std::vector& permutAt, std::vector& permutBt, std::vector& shapeAt, std::vector& shapeBt); - - // evaluate resulting shape after reduce operation - static const Nd4jLong* evalReduceShapeInfo(const char order, std::vector& dimensions, const NDArray& arr, const sd::DataType dataType, const bool keepDims = false, const bool supportOldShapes = false, sd::memory::Workspace* workspace = nullptr); - static const Nd4jLong* evalReduceShapeInfo(const char order, std::vector& dimensions, const Nd4jLong* shapeInfo, const sd::DataType dataType, const bool keepDims = false, const bool supportOldShapes = false, sd::memory::Workspace* workspace = nullptr); - static const Nd4jLong* evalReduceShapeInfo(const char order, std::vector& dimensions, const NDArray& arr, const bool keepDims = false, const bool supportOldShapes = false, sd::memory::Workspace* workspace = nullptr); - static const Nd4jLong* evalReduceShapeInfo(const char order, std::vector& dimensions, const Nd4jLong* shapeInfo, const bool keepDims = false, const bool supportOldShapes = false, sd::memory::Workspace* workspace = nullptr); - - /** - * evaluate output shape for reduce operation when input shape is empty - * behavior is analogous to tf - */ - static const Nd4jLong* evalReduceShapeInfoEmpty(const char order, std::vector& dimensions, const Nd4jLong *shapeInfo, const sd::DataType dataType, const bool keepDims, sd::memory::Workspace* workspace); - - // evaluate shape for array which is result of repeat operation applied to arr - static std::vector evalRepeatShape(int axis, const std::vector& repeats, const NDArray& arr); - - // evaluate shapeInfo of permuted array - // if setContigStrides = true, then set contiguous strides in output shapeInfo in accordance with arr order - static const Nd4jLong* evalPermShapeInfo(const int* dimensions, const int rank, const NDArray& arr, sd::memory::Workspace* workspace, const bool setContigStrides = false); - static const Nd4jLong* evalPermShapeInfo(const Nd4jLong* dimensions, const int rank, const NDArray& arr, sd::memory::Workspace* workspace); - - // evaluate shapeInfo of transposed array - // if setContigStrides = true, then set contiguous strides in output shapeInfo in accordance with arr order - static const Nd4jLong* evalTranspShapeInfo(const NDArray& arr, sd::memory::Workspace* workspace, const bool setContigStrides = false); - - static bool copyVectorPart(std::vector& target, std::vector& source, int rank, int offset); - - // return new (shorter) sorted dimensions array without dimensions that are present in input vector - static std::vector evalDimsToExclude(const int rank, const int dimsLen, const int* dimensions); - static std::vector evalDimsToExclude(const int rank, const std::vector& dimensions); - - // check whether 2 arrays have mutually broadcastable shapes - // shape comparison starts from the end - static bool areShapesBroadcastable(const NDArray &arr1, const NDArray &arr2); - static bool areShapesBroadcastable(const Nd4jLong* shapeX, const Nd4jLong* shapeY); - static bool areShapesBroadcastable(const std::vector& shape1, const std::vector& shape2); - - // check the possibility of broadcast operation, if true then return shapeInfo of resulting array - // if evalMinMax == false then array with larger rank has to be passed as first argument - static bool evalBroadcastShapeInfo(const NDArray& max, const NDArray& min, const bool evalMinMax, const Nd4jLong*& resultShapeInfo, sd::memory::Workspace* workspace); - static bool evalBroadcastShapeInfo(const Nd4jLong *max, const Nd4jLong *min, const bool evalMinMax, const Nd4jLong*& resultShapeInfo, sd::memory::Workspace* workspace); - - // evaluate sorted vector of max axes to create tads along in case of simple broadcast operation - // if simple broadcast is not possible then empty vector is returned - // PLEASE NOTE: condition (rank_max >= rank_min) should be satisfied ! - static std::vector tadAxesForSimpleBroadcast(const NDArray& max, const NDArray& min); - - // check the possibility of broadcast operation for set of arrays, if true then return resulting broadcasted shapeInfo - static bool evalCommonBroadcastShapeInfo(const std::vector& arrays, Nd4jLong*& resultShapeInfo, memory::Workspace* workspace = nullptr); - - // return sorted vector of dimensions common (same) for two arrays, dimensions values corresponds to array with bigger rank - // for example if arr1{2,7}, arr2{2,5,4,7} then vector = {0,3} - static std::vector getDimsWithSameShape(const NDArray& max, const NDArray& min); - - // evaluate shapeInfo for resulting array of tile operation - static const Nd4jLong* evalTileShapeInfo(const NDArray& arr, const std::vector& reps, sd::memory::Workspace* workspace); - - // returns shape part of shapeInfo as std::vector - static std::vector pullShapeFromShapeInfo(const Nd4jLong *shapeInfo); - - static std::string shapeAsString(const NDArray* array); - static std::string shapeAsString(const std::vector& shape); - static std::string shapeAsString(const Nd4jLong* shapeInfo); - static std::string shapeAsString(const int rank, const Nd4jLong* shapeInfo); - static std::string strideAsString(const NDArray* array); - - static std::string shapeInfoAsString(const Nd4jLong* shapeInfo); - - static std::vector shapeAsVector(const Nd4jLong* shapeInfo); - - // evaluate shapeInfo for diagonal array which is made using input arr elements as diagonal - static const Nd4jLong* evalDiagShapeInfo(const Nd4jLong* shapeInfo, sd::memory::Workspace* workspace); - - static std::vector evalBroadcastBackwardAxis(const Nd4jLong *operand, const Nd4jLong *result); - - // utility to calculate matrix product shape with give source shapes and additional params - // returns ShapeList pointer with result shape - static const Nd4jLong* matrixProductShape(const Nd4jLong* theFirstShape, const Nd4jLong* theSecondShape, bool shouldTranspondFirst, bool shouldTranspondSecond, sd::DataType dtype, sd::memory::Workspace* workspace); - - /** - * This method evaluates permutation vector necessary for reducing of shapeFrom to shapeTo - * if shapeFrom is identical to shapeTo (permutation is unnecessary) then empty vector is returned - * in case of permutation is impossible an exception is thrown - */ - static std::vector evalPermutFromTo(const std::vector& shapeFrom, const std::vector& shapeTo); - - /** - * This method composes shape (shape only, not whole shapeInfo!) using dimensions values and corresponding indexes, - * please note: the size of input vector dimsAndIdx must always be even, since the numbers of dimensions and indexes are the same, - * for example if dimsAndIdx = {dimC,dimB,dimA, 2,1,0} then output vector = {dimA,dimB,dimC} - */ - static std::vector composeShapeUsingDimsAndIdx(const std::vector& dimsAndIdx); - - /** - * x * y = c, evaluate shape for array resulting from mmul operation - * possible cases: dot product (xRank=yRank=1), matrix-vector product (xRank=2, yRank=1), vector-matrix product (xRank=1, yRank=2), matrix-matrix product (xRank=yRank and rank >=2) - */ - static std::vector evalShapeForMatmul(const Nd4jLong* xShapeInfo, const Nd4jLong* yShapeInfo, const bool transX, const bool transY); - - /** - * evaluate number of sub-arrays along dimensions stored in dimsToExclude - * i.e. if shape is [2,3,4,5] and dimsToExclude={0,2}, then number of sub-arrays = 8 - */ - static Nd4jLong getNumOfSubArrs(const Nd4jLong* shapeInfo, const std::vector& dimsToExclude); - - /** - * return shape without unities, for example if shape is [1,2,1,3] then [2,3] will be returned - * if unities are not present in given shapeInfo then exactly identical shape will be returned, for example [2,3] -> [2,3] - * edge case: if given shape is [1,1,1,...,1] (all dims are unities) then output will be empty and means scalar - */ - static std::vector evalDimsWithoutUnities(const Nd4jLong* shapeInfo); - - /** - * method returns false if permut == {0,1,2,...permut.size()-1} - in that case permutation is unnecessary - */ - FORCEINLINE static bool isPermutNecessary(const std::vector& permut); - - /** - * calculates strides using "dest" shape and given "order", also copies data type from "source" to "dest" - */ - static void updateStridesAndType(Nd4jLong* dest, const Nd4jLong* source, const char order); - - /** - * calculates strides using "dest" shape and "order", also set "dtype" into "dest" - */ - static void updateStridesAndType(Nd4jLong* dest, const DataType dtype, const char order); - - /** - * This method retuns number of bytes required for string tensor - * @param numStrings - * @return - */ - static FORCEINLINE Nd4jLong stringBufferHeaderRequirements(Nd4jLong numStrings) { - // we store +1 offset - return (numStrings + 1) * sizeof(Nd4jLong); - } - - /** - * This method selects strides based on dimentions required for broadcasting - * @param const pointer to input (Y) shape info for strides selection - * @param rank of input (X) to broadcasting - * @param dimentions size - * @param const pointer to dimentions for broadcasting - * @param pointer to output strides have to be pre allocated by 0 - * @return - */ - static void copyCertainStridesFromShapeInfo(const Nd4jLong* inShapeInfo, const int nRank, const int dimsSize, const int* dims, Nd4jLong* outStrides); - - /* - * check whether arr1/arr2 is sub-array of arr2/arr1, - * this method do not evaluate what array is sub-array, it returns true if arr1 is sub-array of arr2 or arr2 is sub-array of arr1 - * sameDims is filled (and sorted) with dimensions values that match both in arr1 and arr2 shapes (unities are ignored) - * for example: - * if arr1{2,3} and arr2{2,4,3,7} then return true and sameDims contains {0,2} - * if arr1{1,1,3,1,3,1,1} and arr2{1,2,3,1,3} then return true and sameDims contains {2,4} - * if arr1{2,1,4,1,7,5} and arr2{1,1,4,5} then return true and sameDims contains {2,5} - - static bool isSubArrayCase(const NDArray& arr1, const NDArray& arr2, std::vector& sameDims); - */ - - /* - * comparing of shapes, not strides - */ - static bool areShapesEqual(const Nd4jLong* shapeInfo, const std::vector& shapeOnly); - }; - - +#include +namespace sd { +class SD_EXPORT ShapeUtils { + public: + // evaluate shape for array resulting from tensorDot operation, also evaluate + // shapes and permutation dimensions for transposition of two input arrays + static std::vector evalShapeForTensorDot( + const Nd4jLong* aShapeInfo, const Nd4jLong* bShapeInfo, + std::vector axesA, std::vector axesB, + std::vector& permutAt, std::vector& permutBt, + std::vector& shapeAt, std::vector& shapeBt); + static std::vector evalShapeForTensorDot( + const NDArray* a, const NDArray* b, const std::vector& axesA, + const std::vector& axesB, std::vector& permutAt, + std::vector& permutBt, std::vector& shapeAt, + std::vector& shapeBt); + + // evaluate resulting shape after reduce operation + static const Nd4jLong* evalReduceShapeInfo( + const char order, const std::vector& dimensions, const NDArray& arr, + const sd::DataType dataType, const bool keepDims = false, + const bool supportOldShapes = false, + sd::memory::Workspace* workspace = nullptr); + static const Nd4jLong* evalReduceShapeInfo( + const char order, const std::vector& dimensions, + const Nd4jLong* shapeInfo, const sd::DataType dataType, + const bool keepDims = false, const bool supportOldShapes = false, + sd::memory::Workspace* workspace = nullptr); + static const Nd4jLong* evalReduceShapeInfo( + const char order, const std::vector& dimensions, const NDArray& arr, + const bool keepDims = false, const bool supportOldShapes = false, + sd::memory::Workspace* workspace = nullptr); + static const Nd4jLong* evalReduceShapeInfo( + const char order, const std::vector& dimensions, + const Nd4jLong* shapeInfo, const bool keepDims = false, + const bool supportOldShapes = false, + sd::memory::Workspace* workspace = nullptr); + + /** + * evaluate output shape for reduce operation when input shape is empty + * behavior is analogous to tf + */ + static const Nd4jLong* evalReduceShapeInfoEmpty( + const char order, const std::vector& dimensions, + const Nd4jLong* shapeInfo, const sd::DataType dataType, + const bool keepDims, sd::memory::Workspace* workspace); + + // evaluate shape for array which is result of repeat operation applied to arr + static std::vector evalRepeatShape(int axis, + const std::vector& repeats, + const NDArray& arr); + + // evaluate shapeInfo of permuted array + // if setContigStrides = true, then set contiguous strides in output shapeInfo + // in accordance with arr order + static const Nd4jLong* evalPermShapeInfo(const int* dimensions, + const int rank, const NDArray& arr, + sd::memory::Workspace* workspace, + const bool setContigStrides = false); + static const Nd4jLong* evalPermShapeInfo(const Nd4jLong* dimensions, + const int rank, const NDArray& arr, + sd::memory::Workspace* workspace); + + // evaluate shapeInfo of transposed array + // if setContigStrides = true, then set contiguous strides in output shapeInfo + // in accordance with arr order + static const Nd4jLong* evalTranspShapeInfo( + const NDArray& arr, sd::memory::Workspace* workspace, + const bool setContigStrides = false); + + static bool copyVectorPart(std::vector& target, std::vector& source, + int rank, int offset); + + // return new (shorter) sorted dimensions array without dimensions that are + // present in input vector + static std::vector evalDimsToExclude(const int rank, const int dimsLen, + const int* dimensions); + static std::vector evalDimsToExclude(const int rank, + const std::vector& dimensions); + + // check whether 2 arrays have mutually broadcastable shapes + // shape comparison starts from the end + static bool areShapesBroadcastable(const NDArray& arr1, const NDArray& arr2); + static bool areShapesBroadcastable(const Nd4jLong* shapeX, + const Nd4jLong* shapeY); + static bool areShapesBroadcastable(const std::vector& shape1, + const std::vector& shape2); + + // check the possibility of broadcast operation, if true then return shapeInfo + // of resulting array if evalMinMax == false then array with larger rank has + // to be passed as first argument + static bool evalBroadcastShapeInfo(const NDArray& max, const NDArray& min, + const bool evalMinMax, + const Nd4jLong*& resultShapeInfo, + sd::memory::Workspace* workspace); + static bool evalBroadcastShapeInfo(const Nd4jLong* max, const Nd4jLong* min, + const bool evalMinMax, + const Nd4jLong*& resultShapeInfo, + sd::memory::Workspace* workspace); + + // evaluate sorted vector of max axes to create tads along in case of simple + // broadcast operation if simple broadcast is not possible then empty vector + // is returned PLEASE NOTE: condition (rank_max >= rank_min) should be + // satisfied ! + static std::vector tadAxesForSimpleBroadcast(const NDArray& max, + const NDArray& min); + + // check the possibility of broadcast operation for set of arrays, if true + // then return resulting broadcasted shapeInfo + static bool evalCommonBroadcastShapeInfo( + const std::vector& arrays, Nd4jLong*& resultShapeInfo, + memory::Workspace* workspace = nullptr); + + // return sorted vector of dimensions common (same) for two arrays, dimensions + // values corresponds to array with bigger rank for example if arr1{2,7}, + // arr2{2,5,4,7} then vector = {0,3} + static std::vector getDimsWithSameShape(const NDArray& max, + const NDArray& min); + + // evaluate shapeInfo for resulting array of tile operation + static const Nd4jLong* evalTileShapeInfo(const NDArray& arr, + const std::vector& reps, + sd::memory::Workspace* workspace); + + // returns shape part of shapeInfo as std::vector + static std::vector pullShapeFromShapeInfo( + const Nd4jLong* shapeInfo); + + static std::string shapeAsString(const NDArray& array); + static std::string shapeAsString(const NDArray* array); + static std::string shapeAsString(const std::vector& shape); + static std::string shapeAsString(const Nd4jLong* shapeInfo); + static std::string shapeAsString(const int rank, const Nd4jLong* shapeInfo); + static std::string strideAsString(const NDArray& array); + static std::string strideAsString(const NDArray* array); + + static std::string shapeInfoAsString(const Nd4jLong* shapeInfo); + + static std::vector shapeAsVector(const Nd4jLong* shapeInfo); + + // evaluate shapeInfo for diagonal array which is made using input arr + // elements as diagonal + static const Nd4jLong* evalDiagShapeInfo(const Nd4jLong* shapeInfo, + sd::memory::Workspace* workspace); + + static std::vector evalBroadcastBackwardAxis(const Nd4jLong* operand, + const Nd4jLong* result); + + // utility to calculate matrix product shape with give source shapes and + // additional params returns ShapeList pointer with result shape + static const Nd4jLong* matrixProductShape(const Nd4jLong* theFirstShape, + const Nd4jLong* theSecondShape, + bool shouldTranspondFirst, + bool shouldTranspondSecond, + sd::DataType dtype, + sd::memory::Workspace* workspace); + + /** + * This method evaluates permutation vector necessary for reducing of + * shapeFrom to shapeTo if shapeFrom is identical to shapeTo (permutation is + * unnecessary) then empty vector is returned in case of permutation is + * impossible an exception is thrown + */ + static std::vector evalPermutFromTo( + const std::vector& shapeFrom, + const std::vector& shapeTo); + + /** + * This method composes shape (shape only, not whole shapeInfo!) using + * dimensions values and corresponding indexes, please note: the size of input + * vector dimsAndIdx must always be even, since the numbers of dimensions and + * indexes are the same, for example if dimsAndIdx = {dimC,dimB,dimA, 2,1,0} + * then output vector = {dimA,dimB,dimC} + */ + static std::vector composeShapeUsingDimsAndIdx( + const std::vector& dimsAndIdx); + + /** + * x * y = c, evaluate shape for array resulting from mmul operation + * possible cases: dot product (xRank=yRank=1), matrix-vector product + * (xRank=2, yRank=1), vector-matrix product (xRank=1, yRank=2), matrix-matrix + * product (xRank=yRank and rank >=2) + */ + static std::vector evalShapeForMatmul(const Nd4jLong* xShapeInfo, + const Nd4jLong* yShapeInfo, + const bool transX, + const bool transY); + + /** + * evaluate number of sub-arrays along dimensions stored in dimsToExclude + * i.e. if shape is [2,3,4,5] and dimsToExclude={0,2}, then number of + * sub-arrays = 8 + */ + static Nd4jLong getNumOfSubArrs(const Nd4jLong* shapeInfo, + const std::vector& dimsToExclude); + + /** + * return shape without unities, for example if shape is [1,2,1,3] then + * [2,3] will be returned if unities are not present in given shapeInfo then + * exactly identical shape will be returned, for example [2,3] -> [2,3] edge + * case: if given shape is [1,1,1,...,1] (all dims are unities) then output + * will be empty and means scalar + */ + static std::vector evalDimsWithoutUnities( + const Nd4jLong* shapeInfo); + + /** + * method returns false if permut == {0,1,2,...permut.size()-1} - in that + * case permutation is unnecessary + */ + FORCEINLINE static bool isPermutNecessary(const std::vector& permut); + + /** + * calculates strides using "dest" shape and given "order", also copies data + * type from "source" to "dest" + */ + static void updateStridesAndType(Nd4jLong* dest, const Nd4jLong* source, + const char order); + + /** + * calculates strides using "dest" shape and "order", also set "dtype" into + * "dest" + */ + static void updateStridesAndType(Nd4jLong* dest, const DataType dtype, + const char order); + + /** + * This method retuns number of bytes required for string tensor + * @param numStrings + * @return + */ + static FORCEINLINE Nd4jLong + stringBufferHeaderRequirements(Nd4jLong numStrings) { + // we store +1 offset + return (numStrings + 1) * sizeof(Nd4jLong); + } + + /** + * This method selects strides based on dimentions required for broadcasting + * @param const pointer to input (Y) shape info for strides selection + * @param rank of input (X) to broadcasting + * @param dimentions size + * @param const pointer to dimentions for broadcasting + * @param pointer to output strides have to be pre allocated by 0 + * @return + */ + static void copyCertainStridesFromShapeInfo(const Nd4jLong* inShapeInfo, + const int nRank, + const int dimsSize, + const int* dims, + Nd4jLong* outStrides); + + /* + * check whether arr1/arr2 is sub-array of arr2/arr1, + * this method do not evaluate what array is sub-array, it returns true if arr1 + is sub-array of arr2 or arr2 is sub-array of arr1 + * sameDims is filled (and sorted) with dimensions values that match both in + arr1 and arr2 shapes (unities are ignored) + * for example: + * if arr1{2,3} and arr2{2,4,3,7} then return true and sameDims contains {0,2} + * if arr1{1,1,3,1,3,1,1} and arr2{1,2,3,1,3} then return true and sameDims + contains {2,4} + * if arr1{2,1,4,1,7,5} and arr2{1,1,4,5} then return true and sameDims + contains {2,5} + + static bool isSubArrayCase(const NDArray& arr1, const NDArray& arr2, + std::vector& sameDims); + */ + + /* + * comparing of shapes, not strides + */ + static bool areShapesEqual(const Nd4jLong* shapeInfo, + const std::vector& shapeOnly); +}; ////////////////////////////////////////////////////////////////////////// ///// IMLEMENTATION OF INLINE METHODS ///// ////////////////////////////////////////////////////////////////////////// FORCEINLINE bool ShapeUtils::isPermutNecessary(const std::vector& permut) { + for (int i = 0; i < permut.size(); ++i) + if (permut[i] != i) return true; - for(int i=0; i + #include #include -#include /** - * This class provides PRIMITIVE read-write lock, and should NOT be used outside of GraphServer due to its inefficiency. - * However, since GraphServer isn't supposed to have Reads/Writes ration even close to 1.0, it'll work just fine. + * This class provides PRIMITIVE read-write lock, and should NOT be used outside + * of GraphServer due to its inefficiency. However, since GraphServer isn't + * supposed to have Reads/Writes ration even close to 1.0, it'll work just fine. * * Basic idea: write lock won't be obtained before all read requests served */ namespace sd { - class ND4J_EXPORT SimpleReadWriteLock { - private: - std::atomic _read_locks; - std::atomic _write_locks; - std::mutex _mutex; - - public: - - explicit SimpleReadWriteLock(); - SimpleReadWriteLock(const SimpleReadWriteLock& other); - ~SimpleReadWriteLock() = default; - - // read lock - void lockRead(); - void unlockRead(); - - // write lock - void lockWrite(); - void unlockWrite(); - - SimpleReadWriteLock& operator= ( const SimpleReadWriteLock &other); - }; -} - - -#endif //DEV_TESTS_READWRITELOCK_H +class SD_EXPORT SimpleReadWriteLock { + private: + std::atomic _read_locks; + std::atomic _write_locks; + std::mutex _mutex; + + public: + explicit SimpleReadWriteLock(); + SimpleReadWriteLock(const SimpleReadWriteLock& other); + ~SimpleReadWriteLock() = default; + + // read lock + void lockRead(); + void unlockRead(); + + // write lock + void lockWrite(); + void unlockWrite(); + + SimpleReadWriteLock& operator=(const SimpleReadWriteLock& other); +}; +} // namespace sd + +#endif // SD_READWRITELOCK_H diff --git a/libnd4j/include/helpers/StringUtils.h b/libnd4j/include/helpers/StringUtils.h index e5f9f299091d..2eaba7e7e916 100644 --- a/libnd4j/include/helpers/StringUtils.h +++ b/libnd4j/include/helpers/StringUtils.h @@ -16,140 +16,156 @@ ******************************************************************************/ // -// Created by raver119 on 20/04/18. +// @author raver110@gmail.com // @author Oleg Semeniv // #ifndef LIBND4J_STRINGUTILS_H #define LIBND4J_STRINGUTILS_H -#include +#include +#include #include -#include +#include + #include +#include #include -#include -#include namespace sd { - class ND4J_EXPORT StringUtils { - public: - template - static FORCEINLINE std::string valueToString(T value) { - std::ostringstream os; - - os << value ; - - //convert the string stream into a string and return - return os.str(); - } - - /** - * These methods convert integer values to string with 0s and 1s - * @param value - * @return - */ - template - static std::string bitsToString(T value); - - /** - * This method just concatenates error message with a given graphId - * @param message - * @param graphId - * @return - */ - static FORCEINLINE std::string buildGraphErrorMessage(const char *message, Nd4jLong graphId) { - std::string result(message); - result += " ["; - result += valueToString(graphId); - result += "]"; - - return result; - } - - /** - * This method returns number of needle matches within haystack - * PLEASE NOTE: this method operates on 8-bit arrays interpreted as uint8 - * - * @param haystack - * @param haystackLength - * @param needle - * @param needleLength - * @return - */ - static uint64_t countSubarrays(const void *haystack, uint64_t haystackLength, const void *needle, uint64_t needleLength); - - /** - * This method returns number of bytes used for string NDArrays content - * PLEASE NOTE: this doesn't include header - * - * @param array - * @return - */ - static uint64_t byteLength(const NDArray &array); - - /** - * This method splits a string into substring by delimiter - * - * @param haystack - * @param delimiter - * @return - */ - static std::vector split(const std::string &haystack, const std::string &delimiter); - - - /** - * This method convert u8 string to u16 - * @param const reference to input string - * @param reference to output u16string - * @return boolean status - */ - static bool u8StringToU16String(const std::string& u8, std::u16string& u16); - - /** - * This method convert u8 string to u32 - * @param const reference to input string - * @param reference to output u32string - * @return boolean status - */ - static bool u8StringToU32String(const std::string& u8, std::u32string& u32); - - /** - * This method convert u16 string to u32 - * @param const reference to input u16string - * @param reference to output u32string - * @return boolean status - */ - static bool u16StringToU32String(const std::u16string& u16, std::u32string& u32); - - /** - * This method convert u16 string to u8 string - * @param const reference to input u16string - * @param reference to output string - * @return boolean status - */ - static bool u16StringToU8String(const std::u16string& u16, std::string& u8); - - /** - * This method convert u32 string to u16 string - * @param const reference to input u32string - * @param reference to output u16string - * @return boolean status - */ - static bool u32StringToU16String(const std::u32string& u32, std::u16string& u16); - - /** - * This method convert u32 string to u8 string - * @param const reference to input u32string - * @param reference to output string - * @return boolean status - */ - static bool u32StringToU8String(const std::u32string& u32, std::string& u8); - - template - static std::string vectorToString(const std::vector &vec); - }; +class SD_EXPORT StringUtils { + public: + template + static FORCEINLINE std::string valueToString(T value); + + /** + * These methods convert integer values to string with 0s and 1s + * @param value + * @return + */ + template + static std::string bitsToString(T value); + + /** + * This method just concatenates error message with a given graphId + * @param message + * @param graphId + * @return + */ + static FORCEINLINE std::string buildGraphErrorMessage(const char* message, Nd4jLong graphId) { + std::string result(message); + result += " ["; + result += valueToString(graphId); + result += "]"; + + return result; + } + + /** + * This method returns number of needle matches within haystack + * PLEASE NOTE: this method operates on 8-bit arrays interpreted as uint8 + * + * @param haystack + * @param haystackLength + * @param needle + * @param needleLength + * @return + */ + static uint64_t countSubarrays(const void* haystack, uint64_t haystackLength, + const void* needle, uint64_t needleLength); + + /** + * This method returns number of bytes used for string NDArrays content + * PLEASE NOTE: this doesn't include header + * + * @param array + * @return + */ + static uint64_t byteLength(const NDArray& array); + + /** + * This method splits a string into substring by delimiter + * + * @param haystack + * @param delimiter + * @return + */ + static std::vector split(const std::string& haystack, const std::string& delimiter); + + /** + * This method convert u8 string to u16 + * @param const reference to input string + * @param reference to output u16string + * @return boolean status + */ + static bool u8StringToU16String(const std::string& u8, std::u16string& u16); + + /** + * This method convert u8 string to u32 + * @param const reference to input string + * @param reference to output u32string + * @return boolean status + */ + static bool u8StringToU32String(const std::string& u8, std::u32string& u32); + + /** + * This method convert u16 string to u32 + * @param const reference to input u16string + * @param reference to output u32string + * @return boolean status + */ + static bool u16StringToU32String(const std::u16string& u16, std::u32string& u32); + + /** + * This method convert u16 string to u8 string + * @param const reference to input u16string + * @param reference to output string + * @return boolean status + */ + static bool u16StringToU8String(const std::u16string& u16, std::string& u8); + + /** + * This method convert u32 string to u16 string + * @param const reference to input u32string + * @param reference to output u16string + * @return boolean status + */ + static bool u32StringToU16String(const std::u32string& u32, std::u16string& u16); + + /** + * This method convert u32 string to u8 string + * @param const reference to input u32string + * @param reference to output string + * @return boolean status + */ + static bool u32StringToU8String(const std::u32string& u32, std::string& u8); + + template + static std::string vectorToString(const std::vector &vec); +}; + +template <> +FORCEINLINE std::string StringUtils::valueToString(std::pair value) { + std::ostringstream os; + + os << value.first; + os << ":"; + os << value.second; + + // convert the string stream into a string and return + return os.str(); +} + +template +FORCEINLINE std::string StringUtils::valueToString(T value) { + std::ostringstream os; + + os << value; + + // convert the string stream into a string and return + return os.str(); } +} // namespace sd -#endif //LIBND4J_STRINGUTILS_H +#endif // LIBND4J_STRINGUTILS_H diff --git a/libnd4j/include/helpers/TAD.h b/libnd4j/include/helpers/TAD.h index cd58e421e5e8..2f763b0c8c4b 100644 --- a/libnd4j/include/helpers/TAD.h +++ b/libnd4j/include/helpers/TAD.h @@ -21,239 +21,242 @@ #ifndef LIBND4J_TAD_H #define LIBND4J_TAD_H - #include #include - namespace shape { - /** - * Dimension collapse is an algorithm - * for collapsing singular dimensions. - * This algorithm will adjust the dimensions - * wrt the original. - * - * The algorithm has 3 components: - * trailing ones - * middle ones - * beginning ones - * - * dimensions that are specified to reduce along - * that are singular should be truncated - * - * dimensions that are specified that are singular - * at the beginning should be removed with middle dimensions - * decremented. - * - * For any time there is a no op, a collapse will - * set the first dimension to be -1. - * - * - */ - class TAD { - public: - Nd4jLong tadIndex = 0; - int dimensionLength; - int* dimension = nullptr; - Nd4jLong const* shapeInfo = nullptr; - Nd4jLong* tadOnlyShapeInfo = nullptr; - Nd4jLong numTads = 0; - int tadRank = 0; - Nd4jLong* tadShape = nullptr; - Nd4jLong* tadStride = nullptr; - Nd4jLong* tadOffsets = nullptr; - Nd4jLong tadOffsetForBlock = 0; - int rank = 0; - int numOnes = 0; - //pointers to original - int originalDimensionLength; - int const* originalDimension = nullptr; - Nd4jLong const* originalShapeInfo = nullptr; - bool squeezed = false; - bool newSqueezeDimensions = false; - int numOnesInMiddle = 0; - bool wholeThing = false; - //need to track whether we create a new dimension array or not, we could have just moved the pointer forward - //due to leading ones - bool createdNewDimension = false; - - // special case for CUDA, we're passing in __shared__ memory pointers to be used instead of new/malloc - void *ptrManager = nullptr; - int *ptrOutput = nullptr; - - INLINEDEF bool dimensionsDescending(int rank, int const* dimensions, int length); +/** + * Dimension collapse is an algorithm + * for collapsing singular dimensions. + * This algorithm will adjust the dimensions + * wrt the original. + * + * The algorithm has 3 components: + * trailing ones + * middle ones + * beginning ones + * + * dimensions that are specified to reduce along + * that are singular should be truncated + * + * dimensions that are specified that are singular + * at the beginning should be removed with middle dimensions + * decremented. + * + * For any time there is a no op, a collapse will + * set the first dimension to be -1. + * + * + */ +class TAD { + public: + Nd4jLong tadIndex = 0; + int dimensionLength; + int *dimension = nullptr; + Nd4jLong const *shapeInfo = nullptr; + Nd4jLong *tadOnlyShapeInfo = nullptr; + Nd4jLong numTads = 0; + int tadRank = 0; + Nd4jLong *tadShape = nullptr; + Nd4jLong *tadStride = nullptr; + Nd4jLong *tadOffsets = nullptr; + Nd4jLong tadOffsetForBlock = 0; + int rank = 0; + int numOnes = 0; + // pointers to original + int originalDimensionLength; + int const *originalDimension = nullptr; + Nd4jLong const *originalShapeInfo = nullptr; + bool squeezed = false; + bool newSqueezeDimensions = false; + int numOnesInMiddle = 0; + bool wholeThing = false; + // need to track whether we create a new dimension array or not, we could have + // just moved the pointer forward due to leading ones + bool createdNewDimension = false; + + // special case for CUDA, we're passing in __shared__ memory pointers to be + // used instead of new/malloc + void *ptrManager = nullptr; + int *ptrOutput = nullptr; + + INLINEDEF bool dimensionsDescending(int rank, int const *dimensions, + int length); #ifdef __CUDACC__ - __host__ __device__ + __host__ __device__ #endif - INLINEDEF TAD() {} - + INLINEDEF + TAD() { + } #ifdef __CUDACC__ - __host__ __device__ + __host__ __device__ #endif - INLINEDEF void setExternalBuffers(void *ptrManager); - - + INLINEDEF void + setExternalBuffers(void *ptrManager); #ifdef __CUDACC__ - __host__ __device__ + __host__ __device__ #endif - INLINEDEF void setOutputBuffer(int *ptrOutput); + INLINEDEF void + setOutputBuffer(int *ptrOutput); #ifdef __CUDACC__ - __host__ __device__ + __host__ __device__ #endif - /** - * This method is for GPU mostly, it allows to initialize TAD instance with precalculated tadOnlyShapeInfo - */ - INLINEDEF void initWithExternalTAD(Nd4jLong *existingTAD, Nd4jLong *originalShape, int *dimension, int dimensionLength); - - + /** + * This method is for GPU mostly, it allows to initialize TAD instance + * with precalculated tadOnlyShapeInfo + */ + INLINEDEF void + initWithExternalTAD(Nd4jLong *existingTAD, Nd4jLong *originalShape, + int *dimension, int dimensionLength); #ifdef __CUDACC__ - __host__ __device__ + __host__ __device__ #endif - INLINEDEF void init(Nd4jLong const* shapeInfo,int const* dimension,int dimensionLength); + INLINEDEF void + init(Nd4jLong const *shapeInfo, int const *dimension, + int dimensionLength); #ifdef __CUDACC__ - __host__ __device__ + __host__ __device__ #endif - INLINEDEF void init(int index, Nd4jLong const* shapeInfo,int const* dimension,int dimensionLength); - - + INLINEDEF void + init(int index, Nd4jLong const *shapeInfo, int const *dimension, + int dimensionLength); - template + template #ifdef __CUDACC__ - __host__ __device__ + __host__ __device__ #endif - INLINEDEF void printTADsND(T *x); - - + INLINEDEF void + printTADsND(T *x); #ifdef __CUDACC__ - __host__ __device__ + __host__ __device__ #endif - INLINEDEF void permuteShapeBufferInPlace(Nd4jLong const* shapeBuffer, int const* rearrange, Nd4jLong *out); + INLINEDEF void + permuteShapeBufferInPlace(Nd4jLong const *shapeBuffer, + int const *rearrange, Nd4jLong *out); #ifdef __CUDACC__ - __host__ __device__ + __host__ __device__ #endif - INLINEDEF Nd4jLong* permuteShapeBuffer(Nd4jLong const* shapeBuffer, int *rearrange); - - - + INLINEDEF Nd4jLong * + permuteShapeBuffer(Nd4jLong const *shapeBuffer, int *rearrange); #ifdef __CUDACC__ - __host__ __device__ + __host__ __device__ #endif - INLINEDEF void createTadOnlyShapeInfo(); - + INLINEDEF void + createTadOnlyShapeInfo(); #ifdef __CUDACC__ - __host__ __device__ + __host__ __device__ #endif - INLINEDEF Nd4jLong lengthPerSlice(Nd4jLong const* shapeBuffer); - + INLINEDEF Nd4jLong + lengthPerSlice(Nd4jLong const *shapeBuffer); #ifdef __CUDACC__ - __host__ __device__ + __host__ __device__ #endif - INLINEDEF Nd4jLong* tad2Sub(Nd4jLong index); - - + INLINEDEF Nd4jLong * + tad2Sub(Nd4jLong index); #ifdef __CUDACC__ - __host__ __device__ + __host__ __device__ #endif - INLINEDEF ~TAD(); - + INLINEDEF ~TAD(); #ifdef __CUDACC__ - __host__ __device__ + __host__ __device__ #endif - INLINEDEF int* permuteDims(); - - - /** - * Compute the tad offset given a dimension. - * - * The general pattern for computing a tad offset is as follows: - * Every $STRIDE that was removed (the first dimension) - * do a jump by the major stride of the parent array - * (stride[0] of the parent array) - * - * For example given a c ordered 2,2,3,2 with stride 12,6,2,1 - * A tad of dimension 1 will jump 12 every 6 tads. - * - * You then end up with offsets of: - * 0 - * 1 - * 2 - * 3 - * 4 - * 5 - * 12 - * 13 - * 14 - * 15 - * 16 - * 17 - * - * notice there are 12 tads here. This same incremental jump will happen - * every time. - * Note here that by default the - * stride of element wise stride is used for the hops. - * - * Sometimes a jump doesn't happen. If there are less tads - * than the stride of the dimension you removed, the - * element wise stride will always be used. - * - * For example in a dimension of 0,1, you end up with offsets of: - * 0,1,2,3,4,5 - * - * Given that the inner most stride of the dimensions that was removed (1) - * had a stride of 6, we never need to do a major stride jump. - * - */ + INLINEDEF int * + permuteDims(); + + /** + * Compute the tad offset given a dimension. + * + * The general pattern for computing a tad offset is as follows: + * Every $STRIDE that was removed (the first dimension) + * do a jump by the major stride of the parent array + * (stride[0] of the parent array) + * + * For example given a c ordered 2,2,3,2 with stride 12,6,2,1 + * A tad of dimension 1 will jump 12 every 6 tads. + * + * You then end up with offsets of: + * 0 + * 1 + * 2 + * 3 + * 4 + * 5 + * 12 + * 13 + * 14 + * 15 + * 16 + * 17 + * + * notice there are 12 tads here. This same incremental jump will happen + * every time. + * Note here that by default the + * stride of element wise stride is used for the hops. + * + * Sometimes a jump doesn't happen. If there are less tads + * than the stride of the dimension you removed, the + * element wise stride will always be used. + * + * For example in a dimension of 0,1, you end up with offsets of: + * 0,1,2,3,4,5 + * + * Given that the inner most stride of the dimensions that was removed (1) + * had a stride of 6, we never need to do a major stride jump. + * + */ #ifdef __CUDACC__ - __host__ __device__ + __host__ __device__ #endif - INLINEDEF Nd4jLong tadOffset(Nd4jLong index); - + INLINEDEF Nd4jLong + tadOffset(Nd4jLong index); #ifdef __CUDACC__ - __host__ __device__ + __host__ __device__ #endif - INLINEDEF Nd4jLong* tensorShape(); + INLINEDEF Nd4jLong * + tensorShape(); #ifdef __CUDACC__ - __host__ __device__ + __host__ __device__ #endif - INLINEDEF Nd4jLong* tad2Sub(Nd4jLong index, void *ptrManager); - + INLINEDEF Nd4jLong * + tad2Sub(Nd4jLong index, void *ptrManager); #ifdef __CUDACC__ - __host__ __device__ + __host__ __device__ #endif - INLINEDEF void createOffsets(); - + INLINEDEF void + createOffsets(); #ifdef __CUDACC__ - __host__ __device__ + __host__ __device__ #endif - INLINEDEF Nd4jLong* shapeInfoOnlyShapeAndStride(); - + INLINEDEF Nd4jLong * + shapeInfoOnlyShapeAndStride(); - - /** - * Length of a tad given - * the shape information - */ + /** + * Length of a tad given + * the shape information + */ #ifdef __CUDACC__ - __host__ __device__ + __host__ __device__ #endif - INLINEDEF Nd4jLong tadLength(Nd4jLong const* shapeInfo, int const* dimension, int dimensionLength); + INLINEDEF Nd4jLong + tadLength(Nd4jLong const *shapeInfo, int const *dimension, + int dimensionLength); /** * Computes the number @@ -261,42 +264,33 @@ namespace shape { * a given dimension */ #ifdef __CUDACC__ - __host__ __device__ + __host__ __device__ #endif - INLINEDEF Nd4jLong tensorsAlongDimension(Nd4jLong const* shapeInfo, int const* dimension, int dimensionLength); - + INLINEDEF Nd4jLong + tensorsAlongDimension(Nd4jLong const *shapeInfo, int const *dimension, + int dimensionLength); #ifdef __CUDACC__ - __host__ __device__ - INLINEDEF void createOffsetForBlock(int blockIdx) { - this->tadOffsetForBlock = this->tadOffset(blockIdx); - } + __host__ __device__ INLINEDEF void createOffsetForBlock(int blockIdx) { + this->tadOffsetForBlock = this->tadOffset(blockIdx); + } #endif - #ifdef __CUDACC__ - __host__ __device__ + __host__ __device__ #endif - INLINEDEF void collapse(); - }; - + INLINEDEF void + collapse(); +}; - - - - - - - - - //// +//// /* #ifdef __CUDACC__ __host__ __device__ #endif - INLINEDEF TAD::TAD(int tadIndex,Nd4jLong *shapeInfo,int *dimension,int dimensionLength) { - this->tadIndex = tadIndex; - this->init(shapeInfo, dimension, dimensionLength); + INLINEDEF TAD::TAD(int tadIndex,Nd4jLong *shapeInfo,int *dimension,int +dimensionLength) { this->tadIndex = tadIndex; this->init(shapeInfo, dimension, +dimensionLength); } @@ -308,785 +302,798 @@ namespace shape { } */ - INLINEDEF void TAD::setExternalBuffers(void *ptrManager) { - this->ptrManager = ptrManager; - } - - INLINEDEF void TAD::setOutputBuffer(int *ptrOutput) { - this->ptrOutput = ptrOutput; - } - - INLINEDEF void TAD::initWithExternalTAD(Nd4jLong *existingTAD, Nd4jLong *originalShape, int *dimension, int dimensionLength) { - this->tadOnlyShapeInfo = existingTAD; - this->rank = shape::rank(originalShape); - - this->originalShapeInfo = originalShape; - this->originalDimension = dimension; - this->originalDimensionLength = dimensionLength; - - this->shapeInfo = originalShape; - this->dimension = dimension; - this->dimensionLength = dimensionLength; +INLINEDEF void TAD::setExternalBuffers(void *ptrManager) { + this->ptrManager = ptrManager; +} - this->tadShape = shape::shapeOf(existingTAD); - this->tadStride = shape::stride(existingTAD); +INLINEDEF void TAD::setOutputBuffer(int *ptrOutput) { + this->ptrOutput = ptrOutput; +} - Nd4jLong ews = shape::elementWiseStride(originalShape); +INLINEDEF void TAD::initWithExternalTAD(Nd4jLong *existingTAD, + Nd4jLong *originalShape, int *dimension, + int dimensionLength) { + this->tadOnlyShapeInfo = existingTAD; + this->rank = shape::rank(originalShape); + + this->originalShapeInfo = originalShape; + this->originalDimension = dimension; + this->originalDimensionLength = dimensionLength; + + this->shapeInfo = originalShape; + this->dimension = dimension; + this->dimensionLength = dimensionLength; + + this->tadShape = shape::shapeOf(existingTAD); + this->tadStride = shape::stride(existingTAD); + + Nd4jLong ews = shape::elementWiseStride(originalShape); + + this->numTads = + shape::length(originalShape) / + shape::length( + existingTAD); // this->tensorsAlongDimension(this->shapeInfo, + // this->dimension, + // this->dimensionLength);//shape::length(originalShape) + // / shape::length(existingTAD); + this->wholeThing = this->numTads == 1 || + ((this->dimensionLength == this->rank || + this->numTads == shape::length(this->shapeInfo)) && + ews == 1); +} - this->numTads = shape::length(originalShape) / shape::length(existingTAD); // this->tensorsAlongDimension(this->shapeInfo, this->dimension, this->dimensionLength);//shape::length(originalShape) / shape::length(existingTAD); - this->wholeThing = this->numTads == 1 || ((this->dimensionLength == this->rank || this->numTads == shape::length(this->shapeInfo)) && ews == 1); - } +INLINEDEF void TAD::init(int tadIndex, Nd4jLong const *shapeInfo, + int const *dimension, int dimensionLength) { + this->tadIndex = tadIndex; + this->init(shapeInfo, dimension, dimensionLength); +} - INLINEDEF void TAD::init(int tadIndex, Nd4jLong const* shapeInfo,int const* dimension,int dimensionLength) { - this->tadIndex = tadIndex; - this->init(shapeInfo, dimension, dimensionLength); +INLINEDEF void TAD::init(Nd4jLong const *shapeInfo, int const *dimension, + int dimensionLength) { + this->originalShapeInfo = shapeInfo; + this->originalDimension = dimension; + this->originalDimensionLength = dimensionLength; + // start off as original references + this->shapeInfo = shapeInfo; + this->dimensionLength = dimensionLength; + this->dimension = const_cast(dimension); + this->rank = shape::rank(shapeInfo); + this->numTads = dimensionLength == 0 ? 1 + : this->tensorsAlongDimension( + this->shapeInfo, this->dimension, + this->dimensionLength); + + Nd4jLong ews = shape::elementWiseStride(shapeInfo); + + if (dimensionLength == 0) { + wholeThing = true; + } else if (!shape::isVector(shapeInfo)) { + wholeThing = + this->numTads == + 1 // if number of TADs is 1, we just have input shape == TAD shape + || + ((this->dimensionLength == + this->rank // if number of dimensions is the same as input rank, + // that'll be wholeTad too, but only if EWS==1 (aka - + // not a View) + || (this->numTads == shape::length(shapeInfo) && + shape::order(shapeInfo) == + 'c')) // OR number of tads equals to shapeInfo length AND + // input is in C order. if order is F - we'll have to + // calculate offsets + && + ews == + 1); // as mentioned above - last 2 rules apply only to non-views + } else if (shape::isScalar(shapeInfo)) { + wholeThing = true; + // vector case + } else { + // if(dimensionLength == 1 && shape::shapeOf(shapeInfo)[dimension[0]] == 1) + // { + // if(dimension == 0 && ) { + if (dimensionLength != 0 && dimension != nullptr && + shape::shapeOf(shapeInfo)[dimension[0]] == 1) { + wholeThing = true; } + } +} - INLINEDEF void TAD::init(Nd4jLong const* shapeInfo, int const* dimension,int dimensionLength) { - this->originalShapeInfo = shapeInfo; - this->originalDimension = dimension; - this->originalDimensionLength = dimensionLength; - //start off as original references - this->shapeInfo = shapeInfo; - this->dimensionLength = dimensionLength; - this->dimension = const_cast(dimension); - this->rank = shape::rank(shapeInfo); - this->numTads = dimensionLength == 0 ? 1 : this->tensorsAlongDimension(this->shapeInfo, this->dimension, this->dimensionLength); - - Nd4jLong ews = shape::elementWiseStride(shapeInfo); - - if (dimensionLength == 0) { - wholeThing = true; - } else if(!shape::isVector(shapeInfo)) { - wholeThing = this->numTads == 1 // if number of TADs is 1, we just have input shape == TAD shape - || ((this->dimensionLength == this->rank // if number of dimensions is the same as input rank, that'll be wholeTad too, but only if EWS==1 (aka - not a View) - || (this->numTads == shape::length(shapeInfo) && shape::order(shapeInfo) == 'c')) // OR number of tads equals to shapeInfo length AND input is in C order. if order is F - we'll have to calculate offsets - && ews == 1); // as mentioned above - last 2 rules apply only to non-views - } else if(shape::isScalar(shapeInfo)) { - wholeThing = true; - //vector case - } else { - // if(dimensionLength == 1 && shape::shapeOf(shapeInfo)[dimension[0]] == 1) { - //if(dimension == 0 && ) { - if(dimensionLength != 0 && dimension != nullptr && shape::shapeOf(shapeInfo)[dimension[0]] == 1) { - wholeThing = true; - } - } +template +INLINEDEF void TAD::printTADsND(T *x) { + if (wholeThing) { + for (int i = 0; i < shape::length(tadOnlyShapeInfo); i++) { + printf(" %f ", x[i]); } - - template - INLINEDEF void TAD::printTADsND(T *x) { - if(wholeThing) { - for(int i = 0; i < shape::length(tadOnlyShapeInfo); i++) { - printf(" %f ",x[i]); - } - printf("\n"); - } - else { - for (int i = 0; i < numTads; i++) { - auto offset = tadOffsets[i]; - Nd4jLong shapeIter[MAX_RANK]; - Nd4jLong coord[MAX_RANK]; - int dim; - int rankIter = shape::rank(tadOnlyShapeInfo); - Nd4jLong xStridesIter[MAX_RANK]; - T *xPointer = x + offset; - if (PrepareOneRawArrayIter(rankIter, - shape::shapeOf(tadOnlyShapeInfo), - xPointer, - shape::stride(tadOnlyShapeInfo), - &rankIter, - shapeIter, - &xPointer, - xStridesIter) >= 0) { - ND4J_RAW_ITER_START(dim, shape::rank(tadOnlyShapeInfo), coord, shapeIter); { - /* Process the innermost dimension */ - printf(" %f ",xPointer[0]); - } - ND4J_RAW_ITER_ONE_NEXT(dim, - rankIter, - coord, - shapeIter, - xPointer, - xStridesIter); - printf("\n"); - - } - else { - printf("Unable to prepare array\n"); - } - } + printf("\n"); + } else { + for (int i = 0; i < numTads; i++) { + auto offset = tadOffsets[i]; + Nd4jLong shapeIter[MAX_RANK]; + Nd4jLong coord[MAX_RANK]; + int dim; + int rankIter = shape::rank(tadOnlyShapeInfo); + Nd4jLong xStridesIter[MAX_RANK]; + T *xPointer = x + offset; + if (PrepareOneRawArrayIter(rankIter, shape::shapeOf(tadOnlyShapeInfo), + xPointer, shape::stride(tadOnlyShapeInfo), + &rankIter, shapeIter, &xPointer, + xStridesIter) >= 0) { + ND4J_RAW_ITER_START(dim, shape::rank(tadOnlyShapeInfo), coord, + shapeIter); + { + /* Process the innermost dimension */ + printf(" %f ", xPointer[0]); } - } + ND4J_RAW_ITER_ONE_NEXT(dim, rankIter, coord, shapeIter, xPointer, + xStridesIter); + printf("\n"); - - INLINEDEF void TAD::permuteShapeBufferInPlace(Nd4jLong const* shapeBuffer, int const* rearrange, Nd4jLong* out) { - memcpy(out, shapeBuffer, sizeof(Nd4jLong) * shape::shapeInfoLength(this->rank)); - doPermuteShapeInfo(out, rearrange); - } - - INLINEDEF Nd4jLong* TAD::permuteShapeBuffer(Nd4jLong const* shapeBuffer, int *rearrange) { - int len = shape::shapeInfoLength(this->rank); - Nd4jLong *copy = shape::copyOf(len,shapeBuffer); - doPermuteShapeInfo(copy,rearrange); - return copy; + } else { + printf("Unable to prepare array\n"); + } } + } +} - INLINEDEF bool TAD::dimensionsDescending(int rank, int const* dimensions, int length) { - int desired = rank - 1; - for (int e = length - 1; e >= 0; e--) { - if (dimensions[e] != desired--) - return false; - } - return true; - } +INLINEDEF void TAD::permuteShapeBufferInPlace(Nd4jLong const *shapeBuffer, + int const *rearrange, + Nd4jLong *out) { + memcpy(out, shapeBuffer, + sizeof(Nd4jLong) * shape::shapeInfoLength(this->rank)); + doPermuteShapeInfo(out, rearrange); +} - INLINEDEF void TAD::createTadOnlyShapeInfo() { - this->tadOnlyShapeInfo = this->shapeInfoOnlyShapeAndStride(); - sd::ArrayOptions::setDataType(this->tadOnlyShapeInfo, sd::ArrayOptions::dataType(this->originalShapeInfo)); +INLINEDEF Nd4jLong *TAD::permuteShapeBuffer(Nd4jLong const *shapeBuffer, + int *rearrange) { + int len = shape::shapeInfoLength(this->rank); + Nd4jLong *copy = shape::copyOf(len, shapeBuffer); + doPermuteShapeInfo(copy, rearrange); + return copy; +} - // possible optimization goes here - if (shape::order(this->originalShapeInfo) == 'c' - && shape::strideDescendingCAscendingF(this->originalShapeInfo) - && dimensionsDescending(shape::rank(this->originalShapeInfo), this->originalDimension, this->originalDimensionLength)) { - // for C order, if outer dimensions are used, continuous layout is preserved - this->tadOnlyShapeInfo[shape::shapeInfoLength(this->tadOnlyShapeInfo) - 2] = this->originalShapeInfo[shape::shapeInfoLength(this->originalShapeInfo) - 2]; - } +INLINEDEF bool TAD::dimensionsDescending(int rank, int const *dimensions, + int length) { + int desired = rank - 1; + for (int e = length - 1; e >= 0; e--) { + if (dimensions[e] != desired--) return false; + } + return true; +} - // do not swap order if positive elementwise stride preserved - if (shape::elementWiseStride(this->tadOnlyShapeInfo) >= 1) { - this->tadOnlyShapeInfo[shape::shapeInfoLength(this->tadOnlyShapeInfo) - 1] = shape::order(this->originalShapeInfo); - } +INLINEDEF void TAD::createTadOnlyShapeInfo() { + this->tadOnlyShapeInfo = this->shapeInfoOnlyShapeAndStride(); + sd::ArrayOptions::setDataType( + this->tadOnlyShapeInfo, + sd::ArrayOptions::dataType(this->originalShapeInfo)); + + // possible optimization goes here + if (shape::order(this->originalShapeInfo) == 'c' && + shape::strideDescendingCAscendingF(this->originalShapeInfo) && + dimensionsDescending(shape::rank(this->originalShapeInfo), + this->originalDimension, + this->originalDimensionLength)) { + // for C order, if outer dimensions are used, continuous layout is preserved + this->tadOnlyShapeInfo[shape::shapeInfoLength(this->tadOnlyShapeInfo) - 2] = + this->originalShapeInfo[shape::shapeInfoLength( + this->originalShapeInfo) - + 2]; + } + + // do not swap order if positive elementwise stride preserved + if (shape::elementWiseStride(this->tadOnlyShapeInfo) >= 1) { + this->tadOnlyShapeInfo[shape::shapeInfoLength(this->tadOnlyShapeInfo) - 1] = + shape::order(this->originalShapeInfo); + } + + if (this->tadShape != nullptr) delete[] this->tadShape; + + this->tadShape = shape::shapeOf(this->tadOnlyShapeInfo); + this->tadStride = shape::stride(this->tadOnlyShapeInfo); +} - if (this->tadShape != nullptr) - delete[] this->tadShape; +INLINEDEF Nd4jLong TAD::lengthPerSlice(Nd4jLong const *shapeBuffer) { + int dimension = 0; + Nd4jLong *remove = shape::removeIndex(shape::shapeOf(shapeBuffer), &dimension, + shape::rank(shapeBuffer), 1); + Nd4jLong prod = shape::prodLong(remove, shape::rank(shapeBuffer) - 1); + delete[] remove; + return prod; +} - this->tadShape = shape::shapeOf(this->tadOnlyShapeInfo); - this->tadStride = shape::stride(this->tadOnlyShapeInfo); +INLINEDEF Nd4jLong *TAD::tad2Sub(Nd4jLong index) { + Nd4jLong *shape = shape::shapeOf(shapeInfo); + int rank = shape::rank(shapeInfo); + int leftOverIndexLen = rank - originalDimensionLength; + + Nd4jLong *ret = new Nd4jLong[rank]; + // shape of the tad + Nd4jLong *tadShape = new Nd4jLong[leftOverIndexLen]; + Nd4jLong *leftOverIndexes = new Nd4jLong[leftOverIndexLen]; + Nd4jLong *sub = new Nd4jLong[rank]; + + // indexes not specified in the tad indexes + + // every coordinate starts as zero + memset(ret, 0, shape::shapeInfoByteLength(rank)); + + // find the length of the elements we + // are iterating over + Nd4jLong len = 1; + // left over index cursor for initializing elements + int leftOverIndex = 0; + for (int i = 0; i < rank; i++) { + // look for dimensions NOT found in dimension length (basically compute + // shape - dimension (set difference) + bool found = false; + for (int j = 0; j < originalDimensionLength; j++) { + // skip over specified dimensions when computing left over length + if (i == originalDimension[j]) { + found = true; + break; + } } - INLINEDEF Nd4jLong TAD::lengthPerSlice(Nd4jLong const* shapeBuffer) { - int dimension = 0; - Nd4jLong *remove = shape::removeIndex(shape::shapeOf(shapeBuffer),&dimension,shape::rank(shapeBuffer),1); - Nd4jLong prod = shape::prodLong(remove, shape::rank(shapeBuffer) - 1); - delete[] remove; - return prod; + // add to the indexes that aren't specified as part of the tad dimension + // indexes + if (!found) { + // accumulate the list of indexes left over used for initializing the + // return value + leftOverIndexes[leftOverIndex] = i; + // accumulate the tad shape + tadShape[leftOverIndex] = shape[i]; + // accumulate the length (product) of the indexes that will be iterated + // over + len *= shape[i]; + leftOverIndex++; } + } + // sub for indices + /* int *sub = new int[leftOverIndexLen]; + shape::ind2subOrder(tadShape,index,len,sub); + */ + shape::index2coords(index, leftOverIndexLen, tadShape, sub); - INLINEDEF Nd4jLong* TAD::tad2Sub(Nd4jLong index) { - Nd4jLong *shape = shape::shapeOf(shapeInfo); - int rank = shape::rank(shapeInfo); - int leftOverIndexLen = rank - originalDimensionLength; - - Nd4jLong *ret = new Nd4jLong[rank]; - //shape of the tad - Nd4jLong *tadShape = new Nd4jLong[leftOverIndexLen]; - Nd4jLong *leftOverIndexes = new Nd4jLong[leftOverIndexLen]; - Nd4jLong *sub = new Nd4jLong[rank]; - - //indexes not specified in the tad indexes - - //every coordinate starts as zero - memset(ret,0, shape::shapeInfoByteLength(rank)); - - //find the length of the elements we - //are iterating over - Nd4jLong len = 1; - //left over index cursor for initializing elements - int leftOverIndex = 0; - for(int i = 0; i < rank; i++) { - //look for dimensions NOT found in dimension length (basically compute shape - dimension (set difference) - bool found = false; - for(int j = 0; j < originalDimensionLength; j++) { - //skip over specified dimensions when computing left over length - if(i == originalDimension[j]) { - found = true; - break; - } - - } - - //add to the indexes that aren't specified as part of the tad dimension - //indexes - if(!found) { - //accumulate the list of indexes left over used for initializing the return value - leftOverIndexes[leftOverIndex] = i; - //accumulate the tad shape - tadShape[leftOverIndex] = shape[i]; - //accumulate the length (product) of the indexes that will be iterated over - len *= shape[i]; - leftOverIndex++; - - } - } + for (int i = 0; i < leftOverIndexLen; i++) { + ret[leftOverIndexes[i]] = sub[i]; + } + if (ptrManager == nullptr) { + delete[] tadShape; + delete[] leftOverIndexes; + delete[] sub; + } - //sub for indices - /* int *sub = new int[leftOverIndexLen]; - shape::ind2subOrder(tadShape,index,len,sub); - */ - shape::index2coords(index, leftOverIndexLen,tadShape, sub); + return ret; +} +INLINEDEF TAD::~TAD() { + // we may have just moved the pointer forward, we may not need to delete the + // pointer here + if (originalDimension != this->dimension && createdNewDimension) { + delete[] this->dimension; + } + if (this->originalShapeInfo != this->shapeInfo) { + delete[] this->shapeInfo; + } + if (this->tadOffsets != nullptr) { + delete[] this->tadOffsets; + } + + if (this->tadOnlyShapeInfo != nullptr && + this->tadOnlyShapeInfo != shapeInfo) { + delete[] this->tadOnlyShapeInfo; + } +} - for(int i = 0; i < leftOverIndexLen; i++) { - ret[leftOverIndexes[i]] = sub[i]; - } +INLINEDEF int *TAD::permuteDims() { + // permute dimensions for tad + int dimIdx = 0; + // loop backwards assuming dimension is sorted - if (ptrManager == nullptr) { - delete[] tadShape; - delete[] leftOverIndexes; - delete[] sub; - } + int *permuteDims = new int[shape::rank(shapeInfo)]; - return ret; + for (int i = 0; i < shape::rank(shapeInfo); i++) { + bool found = false; + for (int j = 0; j < originalDimensionLength; j++) { + if (i == originalDimension[j]) { + found = true; + break; + } } + // not found, append it to the end for permute + if (!found) permuteDims[dimIdx++] = i; + } - INLINEDEF TAD::~TAD() { - //we may have just moved the pointer forward, we may not need to delete the pointer here - if(originalDimension != this->dimension && createdNewDimension) { - delete[] this->dimension; - } - if(this->originalShapeInfo != this->shapeInfo) { - delete[] this->shapeInfo; - } - if(this->tadOffsets != nullptr) { - delete[] this->tadOffsets; - } + for (int i = originalDimensionLength - 1; i >= 0; i--) { + permuteDims[dimIdx++] = originalDimension[i]; + } - if(this->tadOnlyShapeInfo != nullptr && this->tadOnlyShapeInfo != shapeInfo) { - delete[] this->tadOnlyShapeInfo; - } - } + /* + for (int i = 0; i < originalDimensionLength; i++) { + permuteDims[i] = originalDimension[i]; + } + */ - INLINEDEF int* TAD::permuteDims() { - //permute dimensions for tad - int dimIdx = 0; - //loop backwards assuming dimension is sorted - - int *permuteDims = new int[shape::rank(shapeInfo)]; - - for(int i = 0; i < shape::rank(shapeInfo); i++) { - bool found = false; - for(int j = 0; j < originalDimensionLength; j++) { - if(i == originalDimension[j]) { - found = true; - break; - } - } - - //not found, append it to the end for permute - if(!found) - permuteDims[dimIdx++] = i; - } + // permute dimensions for tad + return permuteDims; +} +INLINEDEF Nd4jLong TAD::tadOffset(Nd4jLong index) { + if (tadOnlyShapeInfo == nullptr) { + this->createTadOnlyShapeInfo(); + } + if (wholeThing) return index; - for(int i = originalDimensionLength - 1; i >= 0; i--) { - permuteDims[dimIdx++] = originalDimension[i]; - } + if (dimensionLength > 1) { + Nd4jLong *tad2Sub = this->tad2Sub(index, ptrManager); -/* - for (int i = 0; i < originalDimensionLength; i++) { - permuteDims[i] = originalDimension[i]; - } -*/ + Nd4jLong ret = shape::getOffset(shapeInfo, tad2Sub); - //permute dimensions for tad - return permuteDims; + if (ret < 0) { + if (ptrManager == nullptr) delete[] tad2Sub; + return -1; } + if (ptrManager == nullptr) delete[] tad2Sub; + return ret; - INLINEDEF Nd4jLong TAD::tadOffset(Nd4jLong index) { - if(tadOnlyShapeInfo == nullptr) { - this->createTadOnlyShapeInfo(); - } - - if(wholeThing) - return index; - - if(dimensionLength > 1) { - Nd4jLong *tad2Sub = this->tad2Sub(index, ptrManager); + } else { + Nd4jLong *tad2Sub = this->tad2Sub(index, ptrManager); - Nd4jLong ret = shape::getOffset(shapeInfo, tad2Sub); + Nd4jLong ret = shape::getOffset(shapeInfo, tad2Sub); - if(ret < 0) { - if (ptrManager == nullptr) - delete[] tad2Sub; - return -1; - } - if (ptrManager == nullptr) - delete[] tad2Sub; + if (ptrManager == nullptr) delete[] tad2Sub; - return ret; - - } - else { - Nd4jLong *tad2Sub = this->tad2Sub(index, ptrManager); + return ret; + } +} - Nd4jLong ret = shape::getOffset(shapeInfo, tad2Sub); +INLINEDEF Nd4jLong *TAD::tensorShape() { + if (this->tadShape != nullptr) return this->tadShape; - if (ptrManager == nullptr) - delete[] tad2Sub; + Nd4jLong *theShape = shape::shapeOf(shapeInfo); + Nd4jLong *tensorShape = shape::keep(theShape, this->dimension, + dimensionLength, shape::rank(shapeInfo)); + this->tadShape = tensorShape; + this->tadRank = dimensionLength; + return tensorShape; +} - return ret; - } +INLINEDEF Nd4jLong *TAD::tad2Sub(Nd4jLong index, void *ptrManager) { + auto shape = shape::shapeOf(shapeInfo); + int rank = shape::rank(shapeInfo); + int leftOverIndexLen = rank - originalDimensionLength; + Nd4jLong *tadShape; + Nd4jLong *leftOverIndexes; + Nd4jLong *sub; + Nd4jLong *ret; + + ret = new Nd4jLong[rank]; + // shape of the tad + leftOverIndexes = new Nd4jLong[leftOverIndexLen]; + sub = new Nd4jLong[rank]; + tadShape = new Nd4jLong[leftOverIndexLen]; + + // indexes not specified in the tad indexes + + // every coordinate starts as zero + memset(ret, 0, sizeof(Nd4jLong) * rank); + + // find the length of the elements we + // are iterating over + Nd4jLong len = 1; + // left over index cursor for initializing elements + int leftOverIndex = 0; + for (int i = 0; i < rank; i++) { + // look for dimensions NOT found in dimension length (basically compute + // shape - dimension (set difference) + bool found = false; + for (int j = 0; j < originalDimensionLength; j++) { + // skip over specified dimensions when computing left over length + if (i == originalDimension[j]) { + found = true; + break; + } } - - INLINEDEF Nd4jLong* TAD::tensorShape(){ - if(this->tadShape != nullptr) - return this->tadShape; - - Nd4jLong *theShape = shape::shapeOf(shapeInfo); - Nd4jLong *tensorShape = shape::keep(theShape, this->dimension, dimensionLength,shape::rank(shapeInfo)); - this->tadShape = tensorShape; - this->tadRank = dimensionLength; - return tensorShape; + // add to the indexes that aren't specified as part of the tad dimension + // indexes + if (!found) { + // accumulate the list of indexes left over used for initializing the + // return value + leftOverIndexes[leftOverIndex] = i; + // accumulate the tad shape + tadShape[leftOverIndex] = shape[i]; + // accumulate the length (product) of the indexes that will be iterated + // over + leftOverIndex++; + len *= shape[i]; } + } - INLINEDEF Nd4jLong* TAD::tad2Sub(Nd4jLong index, void *ptrManager) { - auto shape = shape::shapeOf(shapeInfo); - int rank = shape::rank(shapeInfo); - int leftOverIndexLen = rank - originalDimensionLength; - Nd4jLong *tadShape; - Nd4jLong *leftOverIndexes; - Nd4jLong *sub; - Nd4jLong *ret; - - ret = new Nd4jLong[rank]; - //shape of the tad - leftOverIndexes = new Nd4jLong[leftOverIndexLen]; - sub = new Nd4jLong[rank]; - tadShape = new Nd4jLong[leftOverIndexLen]; - - //indexes not specified in the tad indexes - - //every coordinate starts as zero - memset(ret,0,sizeof(Nd4jLong) * rank); - - - //find the length of the elements we - //are iterating over - Nd4jLong len = 1; - //left over index cursor for initializing elements - int leftOverIndex = 0; - for(int i = 0; i < rank; i++) { - //look for dimensions NOT found in dimension length (basically compute shape - dimension (set difference) - bool found = false; - for(int j = 0; j < originalDimensionLength; j++) { - //skip over specified dimensions when computing left over length - if(i == originalDimension[j]) { - found = true; - break; - } - - } - - //add to the indexes that aren't specified as part of the tad dimension - //indexes - if(!found) { - //accumulate the list of indexes left over used for initializing the return value - leftOverIndexes[leftOverIndex] = i; - //accumulate the tad shape - tadShape[leftOverIndex] = shape[i]; - //accumulate the length (product) of the indexes that will be iterated over - leftOverIndex++; - len *= shape[i]; - - } - } + // sub for indices + /* int *sub = new int[leftOverIndexLen]; + shape::ind2subOrder(tadShape,index,len,sub); + */ + shape::index2coords(index, leftOverIndexLen, tadShape, sub); + for (int i = 0; i < leftOverIndexLen; i++) { + ret[leftOverIndexes[i]] = sub[i]; + } - //sub for indices - /* int *sub = new int[leftOverIndexLen]; - shape::ind2subOrder(tadShape,index,len,sub); - */ - shape::index2coords(index, leftOverIndexLen,tadShape, sub); + if (ptrManager == nullptr) { + delete[] leftOverIndexes; + delete[] tadShape; + delete[] sub; + } - for(int i = 0; i < leftOverIndexLen; i++) { - ret[leftOverIndexes[i]] = sub[i]; - } + return ret; +} - if (ptrManager == nullptr) { - delete[] leftOverIndexes; - delete[] tadShape; - delete[] sub; - } +INLINEDEF void TAD::createOffsets() { + this->tadOffsets = new Nd4jLong[this->numTads]; + uint nT = this->numTads; - return ret; - } - - INLINEDEF void TAD::createOffsets() { - this->tadOffsets = new Nd4jLong[this->numTads]; - uint nT = this->numTads; + for (uint i = 0; i < nT; i++) this->tadOffsets[i] = this->tadOffset(i); +} - for(uint i = 0; i < nT; i++) - this->tadOffsets[i] = this->tadOffset(i); +INLINEDEF Nd4jLong *TAD::shapeInfoOnlyShapeAndStride() { + // if(wholeThing || (dimensionLength == 1 && dimension[0] == MAX_DIMENSION) || + // shape::isScalar(shapeInfo)) + // return shape::createScalarShapeInfo(); + + // ensure tad shapes get setup right for vectors + if (dimensionLength > 1 && shape::isVector(shapeInfo)) + return shape::copyOf(shape::shapeInfoLength(shape::rank(shapeInfo)), + shapeInfo); + + // case when tad coincides with whole array + if (this->numTads == 1 && + ((shape::rank(originalShapeInfo) == originalDimensionLength) || + originalDimensionLength == 0)) { + // we might have special case here: skipped dimensions might be just full of + // ones + Nd4jLong *ret = shape::copyOf( + shape::shapeInfoLength(shape::rank(shapeInfo)), shapeInfo); + if (shape::isDimPermuted( + dimension, + (Nd4jLong)dimensionLength)) // check whether we need permutation + doPermuteShapeInfo(ret, dimension); + + return ret; + } + + Nd4jLong *theShape = shape::shapeOf(shapeInfo); + int rank = shape::rank(shapeInfo); + + if (dimensionLength == 1) { + if (dimension[0] == 0 && shape::isVector(shapeInfo) && theShape[1] == 1) { + int permuted[2] = {1, 0}; + Nd4jLong *permutedRet2 = shape::permuteShapeBuffer(shapeInfo, permuted); + return permutedRet2; + } else if (dimension[0] == 1 && shape::isVector(shapeInfo) && + theShape[0] == 1) { + Nd4jLong *ret = shape::copyOf( + shape::shapeInfoLength(shape::rank(shapeInfo)), shapeInfo); + return ret; + } else if (shape::shapeOf(shapeInfo)[dimension[0]] == 1) { + Nd4jLong *scalarInfo = shape::createScalarShapeInfo(); + scalarInfo[shape::shapeInfoLength(shape::rank(scalarInfo)) - 3] = + this->tadIndex; + return scalarInfo; } - - - INLINEDEF Nd4jLong* TAD::shapeInfoOnlyShapeAndStride() { - //if(wholeThing || (dimensionLength == 1 && dimension[0] == MAX_DIMENSION) || shape::isScalar(shapeInfo)) - // return shape::createScalarShapeInfo(); - - //ensure tad shapes get setup right for vectors - if(dimensionLength > 1 && shape::isVector(shapeInfo)) - return shape::copyOf(shape::shapeInfoLength(shape::rank(shapeInfo)),shapeInfo); - - // case when tad coincides with whole array - if( this->numTads == 1 && ((shape::rank(originalShapeInfo) == originalDimensionLength) || originalDimensionLength == 0)) { - // we might have special case here: skipped dimensions might be just full of ones - Nd4jLong *ret = shape::copyOf(shape::shapeInfoLength(shape::rank(shapeInfo)), shapeInfo); - if (shape::isDimPermuted(dimension, (Nd4jLong) dimensionLength)) // check whether we need permutation - doPermuteShapeInfo(ret, dimension); - - return ret; - } - - Nd4jLong *theShape = shape::shapeOf(shapeInfo); - int rank = shape::rank(shapeInfo); - - if(dimensionLength == 1) { - if(dimension[0] == 0 && shape::isVector(shapeInfo) && theShape[1] == 1) { - int permuted[2] = {1,0}; - Nd4jLong *permutedRet2 = shape::permuteShapeBuffer(shapeInfo, permuted); - return permutedRet2; - } else if(dimension[0] == 1 && shape::isVector(shapeInfo) && theShape[0] == 1) { - Nd4jLong *ret = shape::copyOf(shape::shapeInfoLength(shape::rank(shapeInfo)),shapeInfo); - return ret; - } - else if(shape::shapeOf(shapeInfo)[dimension[0]] == 1) { - Nd4jLong *scalarInfo = shape::createScalarShapeInfo(); - scalarInfo[shape::shapeInfoLength(shape::rank(scalarInfo)) - 3] = this->tadIndex; - return scalarInfo; - } - } - - Nd4jLong *tensorShape = this->tensorShape(); - int *reverseDimensions = shape::reverseCopy(dimension, dimensionLength); - int *rankRange = shape::range(0, rank); - int *remove = shape::removeIndex(rankRange, dimension, (Nd4jLong) rank, (Nd4jLong) dimensionLength); - //concat is wrong here with the length - int *newPermuteDims = shape::concat(remove,rank - dimensionLength,reverseDimensions,dimensionLength); - - Nd4jLong* permuted = shape::permuteShapeBuffer(shapeInfo,newPermuteDims); - - - Nd4jLong sliceIndex = shape::sliceOffsetForTensor(shape::rank(permuted), - this->tadIndex, - shape::shapeOf(shapeInfo), - tensorShape, - dimensionLength, - dimension, - dimensionLength); - - - - Nd4jLong *ret2 = shape::sliceOfShapeBuffer(sliceIndex, permuted); - Nd4jLong tensorLength = shape::prodLong(tensorShape,tadRank); - - Nd4jLong compLength = shape::isVector(ret2) ? shape::length(ret2) : shape::prodLong(tensorShape,tadRank); - // int temp; - // const bool isLikeVector = shape::isLikeVector(ret2, temp); - - // if(dimensionLength == tadRank && compLength == shape::length(ret2) && !isLikeVector) { - if(dimensionLength == tadRank && compLength == shape::length(ret2)) { - if(dimensionLength == 1 && shape::isVector(ret2) && shape::shapeOf(ret2)[0] == 1) { - //go to the bottom and return ret2 after proper freeing of pointers - //basic idea; we *don't* permute row vectors - } - else if(dimensionLength > 1) { - //permute *then* return ret2 - int *finalPermuteDims = new int[shape::rank(ret2)]; - int forward = 0; - for(int i = shape::rank(ret2) - 1; i >= 0; i--) { - finalPermuteDims[forward++] = i; - } - shape::permuteShapeBufferInPlace(ret2,finalPermuteDims,ret2); - delete[] finalPermuteDims; - - } - - } - else { - Nd4jLong length = tensorLength; - Nd4jLong lengthPerSlice = this->lengthPerSlice(ret2); - if(lengthPerSlice < 1) { - return ret2; - } - - Nd4jLong offset = tadIndex * tensorLength /lengthPerSlice; - if(sliceIndex == 0 && length == lengthPerSlice) { - Nd4jLong *newRet2 = shape::sliceOfShapeBuffer(offset, ret2); - delete[] ret2; - ret2 = newRet2; - int *finalPermuteDims = new int[shape::rank(ret2)]; - int forward = 0; - for(int i = shape::rank(ret2) - 1; i >= 0; i--) { - finalPermuteDims[forward++] = i; - } - // bool isRowVector2 = shape::isRowVector(ret2) && !isLikeVector; - bool isRowVector2 = shape::isRowVector(ret2); - if(isRowVector2 == false) { - shape::permuteShapeBufferInPlace(ret2, finalPermuteDims, ret2); - } - - delete[] finalPermuteDims; - - } - else if(length == lengthPerSlice) { - offset -= shape::slices(ret2) * (offset / shape::slices(ret2)); - Nd4jLong *newRet2 = shape::sliceOfShapeBuffer(offset,ret2); - delete[] ret2; - ret2 = newRet2; - if(dimensionLength == 1 && shape::isVector(ret2) && shape::shapeOf(ret2)[0] == 1) { - //go to the bottom and return ret2 after proper freeing of pointers - //basic idea; we *don't* permute row vectors - } - else { - int *finalPermuteDims = new int[shape::rank(ret2)]; - int forward = 0; - for(int i = shape::rank(ret2) - 1; i >= 0; i--) { - finalPermuteDims[forward++] = i; - } - Nd4jLong *newRet = shape::permuteShapeBuffer(ret2, finalPermuteDims); - delete[] ret2; - delete[] finalPermuteDims; - ret2 = newRet; - - } - - } - else { - //execute final part, note that this is mainly so delete[] gets called - //at the bottom of the method - while(shape::length(ret2) > length) { - auto lengthPerSlice2 = this->lengthPerSlice(ret2); - sliceIndex = sliceOffsetForTensor(sliceIndex,shape::length(ret2),lengthPerSlice2); - sliceIndex -= shape::slices(ret2) * (sliceIndex / shape::slices(ret2)); - auto newRet2 = shape::sliceOfShapeBuffer(sliceIndex,ret2); - delete[] ret2; - ret2 = newRet2; - } - - //don't permute on a row vector - if(dimensionLength == 1 && shape::isVector(ret2) && shape::shapeOf(ret2)[0] == 1) { - //go to the bottom and return ret2 after proper freeing of pointers - //basic idea; we *don't* permute row vectors - } - else if(dimensionLength > 1){ - //permute *then* return ret - int *finalPermuteDims = new int[shape::rank(ret2)]; - int forward = 0; - for(int i = shape::rank(ret2) - 1; i >= 0; i--) { - finalPermuteDims[forward++] = i; - } - auto newPermute = shape::permuteShapeBuffer(ret2,finalPermuteDims); - delete[] ret2; - delete[] finalPermuteDims; - ret2 = newPermute; - } - - } - } - - - delete[] permuted; - delete[] newPermuteDims; - delete[] rankRange; - delete[] remove; - delete[] reverseDimensions; - return ret2; + } + + Nd4jLong *tensorShape = this->tensorShape(); + int *reverseDimensions = shape::reverseCopy(dimension, dimensionLength); + int *rankRange = shape::range(0, rank); + int *remove = shape::removeIndex(rankRange, dimension, (Nd4jLong)rank, + (Nd4jLong)dimensionLength); + // concat is wrong here with the length + int *newPermuteDims = shape::concat(remove, rank - dimensionLength, + reverseDimensions, dimensionLength); + + Nd4jLong *permuted = shape::permuteShapeBuffer(shapeInfo, newPermuteDims); + + Nd4jLong sliceIndex = shape::sliceOffsetForTensor( + shape::rank(permuted), this->tadIndex, shape::shapeOf(shapeInfo), + tensorShape, dimensionLength, dimension, dimensionLength); + + Nd4jLong *ret2 = shape::sliceOfShapeBuffer(sliceIndex, permuted); + Nd4jLong tensorLength = shape::prodLong(tensorShape, tadRank); + + Nd4jLong compLength = shape::isVector(ret2) + ? shape::length(ret2) + : shape::prodLong(tensorShape, tadRank); + // int temp; + // const bool isLikeVector = shape::isLikeVector(ret2, temp); + + // if(dimensionLength == tadRank && compLength == shape::length(ret2) && + // !isLikeVector) { + if (dimensionLength == tadRank && compLength == shape::length(ret2)) { + if (dimensionLength == 1 && shape::isVector(ret2) && + shape::shapeOf(ret2)[0] == 1) { + // go to the bottom and return ret2 after proper freeing of pointers + // basic idea; we *don't* permute row vectors + } else if (dimensionLength > 1) { + // permute *then* return ret2 + int *finalPermuteDims = new int[shape::rank(ret2)]; + int forward = 0; + for (int i = shape::rank(ret2) - 1; i >= 0; i--) { + finalPermuteDims[forward++] = i; + } + shape::permuteShapeBufferInPlace(ret2, finalPermuteDims, ret2); + delete[] finalPermuteDims; } + } else { + Nd4jLong length = tensorLength; + Nd4jLong lengthPerSlice = this->lengthPerSlice(ret2); + if (lengthPerSlice < 1) { + return ret2; + } - INLINEDEF Nd4jLong TAD::tadLength(Nd4jLong const* shapeInfo, int const* dimension, int dimensionLength) { - if(dimensionLength == 1) { - return shape::shapeOf(shapeInfo)[dimension[0]]; + Nd4jLong offset = tadIndex * tensorLength / lengthPerSlice; + if (sliceIndex == 0 && length == lengthPerSlice) { + Nd4jLong *newRet2 = shape::sliceOfShapeBuffer(offset, ret2); + delete[] ret2; + ret2 = newRet2; + int *finalPermuteDims = new int[shape::rank(ret2)]; + int forward = 0; + for (int i = shape::rank(ret2) - 1; i >= 0; i--) { + finalPermuteDims[forward++] = i; + } + // bool isRowVector2 = shape::isRowVector(ret2) && !isLikeVector; + bool isRowVector2 = shape::isRowVector(ret2); + if (isRowVector2 == false) { + shape::permuteShapeBufferInPlace(ret2, finalPermuteDims, ret2); + } + + delete[] finalPermuteDims; + + } else if (length == lengthPerSlice) { + offset -= shape::slices(ret2) * (offset / shape::slices(ret2)); + Nd4jLong *newRet2 = shape::sliceOfShapeBuffer(offset, ret2); + delete[] ret2; + ret2 = newRet2; + if (dimensionLength == 1 && shape::isVector(ret2) && + shape::shapeOf(ret2)[0] == 1) { + // go to the bottom and return ret2 after proper freeing of pointers + // basic idea; we *don't* permute row vectors + } else { + int *finalPermuteDims = new int[shape::rank(ret2)]; + int forward = 0; + for (int i = shape::rank(ret2) - 1; i >= 0; i--) { + finalPermuteDims[forward++] = i; } - else { - Nd4jLong ret = 1; - for(int i = 0; i < shape::rank(shapeInfo); i++) { - for(int j = 0; j < dimensionLength; j++) { - if(i == dimension[j]) - ret *= shape::shapeOf(shapeInfo)[dimension[j]]; - } - } - return ret; + Nd4jLong *newRet = shape::permuteShapeBuffer(ret2, finalPermuteDims); + delete[] ret2; + delete[] finalPermuteDims; + ret2 = newRet; + } + + } else { + // execute final part, note that this is mainly so delete[] gets called + // at the bottom of the method + while (shape::length(ret2) > length) { + auto lengthPerSlice2 = this->lengthPerSlice(ret2); + sliceIndex = sliceOffsetForTensor(sliceIndex, shape::length(ret2), + lengthPerSlice2); + sliceIndex -= shape::slices(ret2) * (sliceIndex / shape::slices(ret2)); + auto newRet2 = shape::sliceOfShapeBuffer(sliceIndex, ret2); + delete[] ret2; + ret2 = newRet2; + } + + // don't permute on a row vector + if (dimensionLength == 1 && shape::isVector(ret2) && + shape::shapeOf(ret2)[0] == 1) { + // go to the bottom and return ret2 after proper freeing of pointers + // basic idea; we *don't* permute row vectors + } else if (dimensionLength > 1) { + // permute *then* return ret + int *finalPermuteDims = new int[shape::rank(ret2)]; + int forward = 0; + for (int i = shape::rank(ret2) - 1; i >= 0; i--) { + finalPermuteDims[forward++] = i; } + auto newPermute = shape::permuteShapeBuffer(ret2, finalPermuteDims); + delete[] ret2; + delete[] finalPermuteDims; + ret2 = newPermute; + } } + } + + delete[] permuted; + delete[] newPermuteDims; + delete[] rankRange; + delete[] remove; + delete[] reverseDimensions; + return ret2; +} - - INLINEDEF Nd4jLong TAD::tensorsAlongDimension(Nd4jLong const* shapeInfo, int const* dimension, int dimensionLength) { - return shape::length(shapeInfo) / this->tadLength(shapeInfo,dimension,dimensionLength); +INLINEDEF Nd4jLong TAD::tadLength(Nd4jLong const *shapeInfo, + int const *dimension, int dimensionLength) { + if (dimensionLength == 1) { + return shape::shapeOf(shapeInfo)[dimension[0]]; + } else { + Nd4jLong ret = 1; + for (int i = 0; i < shape::rank(shapeInfo); i++) { + for (int j = 0; j < dimensionLength; j++) { + if (i == dimension[j]) ret *= shape::shapeOf(shapeInfo)[dimension[j]]; + } } + return ret; + } +} +INLINEDEF Nd4jLong TAD::tensorsAlongDimension(Nd4jLong const *shapeInfo, + int const *dimension, + int dimensionLength) { + return shape::length(shapeInfo) / + this->tadLength(shapeInfo, dimension, dimensionLength); +} - INLINEDEF void TAD::collapse() { - auto shape = shape::shapeOf(shapeInfo); - //handle negative dimensions/backwards indexing - for(int i = 0; i < dimensionLength; i++) { - if((dimension)[i] < 0) - (dimension)[i] += shape::rank(this->shapeInfo); - } - - this->dimension = new int[dimensionLength]; - memcpy(this->dimension,this->originalDimension, sizeof(int) * dimensionLength); - - //we can drop trailing dimensions where it's all singular for example: - // shape: 4,3,1,2 - //dimension: 0,2 - // the problem for 0,2 is equivalent to: 0 - //the rest of the algorithm handles cases suchas - //shape: 4,1,1,2 - //dimension: 0,1 - //when this happens there are other dimensions (eg: at the end) that matter - int trailingOneDimensions = 0; - //trailing ones - for(int i = dimensionLength - 1; i >= 0; i--) { - if(shape[dimension[i]] != 1) { - break; - } - else if(shape[dimension[i]] == 1) - trailingOneDimensions++; - } +INLINEDEF void TAD::collapse() { + auto shape = shape::shapeOf(shapeInfo); + // handle negative dimensions/backwards indexing + for (int i = 0; i < dimensionLength; i++) { + if ((dimension)[i] < 0) (dimension)[i] += shape::rank(this->shapeInfo); + } + + this->dimension = new int[dimensionLength]; + memcpy(this->dimension, this->originalDimension, + sizeof(int) * dimensionLength); + + // we can drop trailing dimensions where it's all singular for example: + // shape: 4,3,1,2 + // dimension: 0,2 + // the problem for 0,2 is equivalent to: 0 + // the rest of the algorithm handles cases suchas + // shape: 4,1,1,2 + // dimension: 0,1 + // when this happens there are other dimensions (eg: at the end) that matter + int trailingOneDimensions = 0; + // trailing ones + for (int i = dimensionLength - 1; i >= 0; i--) { + if (shape[dimension[i]] != 1) { + break; + } else if (shape[dimension[i]] == 1) + trailingOneDimensions++; + } + + dimensionLength -= trailingOneDimensions; + + int leadingOneDimensions = 0; + // trailing ones + for (int i = 0; i < dimensionLength; i++) { + if (shape[dimension[i]] != 1) { + break; + } else if (shape[dimension[i]] == 1) + leadingOneDimensions++; + } + + // bump the dimension pointer forward for however many leadingones there are + dimension += leadingOneDimensions; + // decrease the dimension length by the amount of leading ones + dimensionLength -= leadingOneDimensions; + + bool preConverged = true; + for (int i = 0; i < dimensionLength; i++) { + if (shape[dimension[i]] == 1) { + preConverged = false; + break; + } + } + + // we took away all the singular dimensions, we can just return + if (preConverged) return; + + // no more singular dimensions specified + bool done = false; + int onesDecrement = 0; + bool changed = false; + while (!done) { + // terminate early: only singular dimensions specified for reduce + if ((dimensionLength) < 1) { + done = true; + // signal as a no op + dimension[0] = -1; + break; + } + // captures intermediary result from the for loop + traceNew(3); - dimensionLength -= trailingOneDimensions; + int intermediaryResult[MAX_RANK]; + for (int i = 0; i < dimensionLength; i++) { + intermediaryResult[i] = (dimension)[i]; + } - int leadingOneDimensions = 0; - //trailing ones - for(int i = 0; i < dimensionLength; i++) { - if(shape[dimension[i]] != 1) { - break; - } - else if(shape[dimension[i]] == 1) - leadingOneDimensions++; + bool oneEncountered = false; + bool nonOneEncountered = false; + bool hitBeginning = false; + // assume intermediate collapsing of dimensions + bool collapseMiddleDimensions = true; + // note here that dimension length MAY end up being zero + for (int i = (dimensionLength)-1; i >= 0; i--) { + if (shape[(dimension)[i]] == 1) { + oneEncountered = true; + // trailing ones + if (!nonOneEncountered) { + // just drop trailing ones + dimensionLength--; + nonOneEncountered = false; + collapseMiddleDimensions = false; + // intermediary result just needs to have the results copied from + // dimension since we're just removing the tail + memcpy(intermediaryResult, dimension, sizeof(int) * dimensionLength); + changed = true; + // break the for loop and force it to go back around starting from the + // new index + break; + } else { + // already decremented all dimensions + // this was a result of hitting beginning ones + // we will only need to loop once + if (i == 0) { + hitBeginning = true; + } + // will need to shift dimensions that aren't trailing ones + // back by onesDecrement + // mark the intermediary result as -1 for non inclusion + intermediaryResult[i] = -1; + onesDecrement++; } + } else { + intermediaryResult[i] = (dimension)[i]; + nonOneEncountered = true; + } + } - //bump the dimension pointer forward for however many leadingones there are - dimension += leadingOneDimensions; - //decrease the dimension length by the amount of leading ones - dimensionLength -= leadingOneDimensions; - - - bool preConverged = true; - for(int i = 0; i < dimensionLength; i++) { - if(shape[dimension[i]] == 1) { - preConverged = false; - break; - } + if (collapseMiddleDimensions && oneEncountered) { + // collapse dimensions + int newIntermediary[MAX_RANK]; + int idx = 0; + for (int i = 0; i < dimensionLength; i++) { + // of note: dimension will decrease by the number of ones encountered + if (intermediaryResult[i] >= 0) { + // dimension 0 doesn't need to be decremented + if (intermediaryResult[i] > 0) + newIntermediary[idx++] = intermediaryResult[i] - onesDecrement; + else + newIntermediary[idx++] = intermediaryResult[i]; } + } - //we took away all the singular dimensions, we can just return - if(preConverged) - return; - - //no more singular dimensions specified - bool done = false; - int onesDecrement = 0; - bool changed = false; - while(!done) { - //terminate early: only singular dimensions specified for reduce - if((dimensionLength) < 1) { - done = true; - //signal as a no op - dimension[0] = -1; - break; - } - //captures intermediary result from the for loop - traceNew(3); - - int intermediaryResult[MAX_RANK]; - for(int i = 0; i < dimensionLength; i++) { - intermediaryResult[i] = (dimension)[i]; - } - - bool oneEncountered = false; - bool nonOneEncountered = false; - bool hitBeginning = false; - //assume intermediate collapsing of dimensions - bool collapseMiddleDimensions = true; - //note here that dimension length MAY end up being zero - for(int i = (dimensionLength) - 1; i >= 0; i--) { - if(shape[(dimension)[i]] == 1) { - oneEncountered = true; - //trailing ones - if(!nonOneEncountered) { - //just drop trailing ones - dimensionLength--; - nonOneEncountered = false; - collapseMiddleDimensions = false; - //intermediary result just needs to have the results copied from dimension since we're just removing the tail - memcpy(intermediaryResult,dimension, sizeof(int) * dimensionLength); - changed = true; - //break the for loop and force it to go back around starting from the new index - break; - } - else { - //already decremented all dimensions - //this was a result of hitting beginning ones - //we will only need to loop once - if(i == 0) { - hitBeginning = true; - } - //will need to shift dimensions that aren't trailing ones - //back by onesDecrement - //mark the intermediary result as -1 for non inclusion - intermediaryResult[i] = -1; - onesDecrement++; - } - } - else { - intermediaryResult[i] = (dimension)[i]; - nonOneEncountered = true; - } - } - - if(collapseMiddleDimensions && oneEncountered) { - //collapse dimensions - int newIntermediary[MAX_RANK]; - int idx = 0; - for(int i = 0; i < dimensionLength; i++) { - //of note: dimension will decrease by the number of ones encountered - if(intermediaryResult[i] >= 0) { - //dimension 0 doesn't need to be decremented - if(intermediaryResult[i] > 0) - newIntermediary[idx++] = intermediaryResult[i] - onesDecrement; - else - newIntermediary[idx++] = intermediaryResult[i]; - } - } - - - //decrement by the number of dimensions where ones appeared - (dimensionLength) -= onesDecrement; - //update to current result - memcpy(dimension,newIntermediary, sizeof(int) * (dimensionLength)); - changed = true; - - } - //converged: no need to change result - else { - //update to current result - memcpy(dimension,intermediaryResult, sizeof(int) * dimensionLength); - } - - //converge when there are no singular dimensions specified in the reduce - done = (!oneEncountered && nonOneEncountered) || hitBeginning; - //delete[] intermediaryResult; - } + // decrement by the number of dimensions where ones appeared + (dimensionLength) -= onesDecrement; + // update to current result + memcpy(dimension, newIntermediary, sizeof(int) * (dimensionLength)); + changed = true; - //nothing changed but need to collapse dimension - if(!changed && this->numOnes > 0) { - for(int i = 0; i < dimensionLength ;i++) { - dimension[i] -= numOnes; - } - } + } + // converged: no need to change result + else { + // update to current result + memcpy(dimension, intermediaryResult, sizeof(int) * dimensionLength); + } + // converge when there are no singular dimensions specified in the reduce + done = (!oneEncountered && nonOneEncountered) || hitBeginning; + // delete[] intermediaryResult; + } + // nothing changed but need to collapse dimension + if (!changed && this->numOnes > 0) { + for (int i = 0; i < dimensionLength; i++) { + dimension[i] -= numOnes; } + } } +} // namespace shape -#endif //LIBND4J_TAD_H +#endif // LIBND4J_TAD_H diff --git a/libnd4j/include/helpers/benchmark/BasicSuit.h b/libnd4j/include/helpers/benchmark/BasicSuit.h index 1e4c156fbe72..dda0593d09d7 100644 --- a/libnd4j/include/helpers/benchmark/BasicSuit.h +++ b/libnd4j/include/helpers/benchmark/BasicSuit.h @@ -18,16 +18,14 @@ * @author raver119@gmail.com */ -#ifndef DEV_TESTS_BASICSUIT_H -#define DEV_TESTS_BASICSUIT_H +#ifndef SD_BASICSUIT_H +#define SD_BASICSUIT_H namespace sd { - class BasicSuit { - protected: +class BasicSuit { + protected: + public: +}; +} // namespace sd - public: - - }; -} - -#endif //DEV_TESTS_BASICSUIT_H \ No newline at end of file +#endif // SD_BASICSUIT_H \ No newline at end of file diff --git a/libnd4j/include/helpers/benchmark/BoolParameters.h b/libnd4j/include/helpers/benchmark/BoolParameters.h index bac8a0c5cb69..ce98a1bbac57 100644 --- a/libnd4j/include/helpers/benchmark/BoolParameters.h +++ b/libnd4j/include/helpers/benchmark/BoolParameters.h @@ -18,31 +18,29 @@ // @author raver119@gmail.com // -#ifndef DEV_TESTS_BOOLPARAMETERS_H -#define DEV_TESTS_BOOLPARAMETERS_H +#ifndef SD_BOOLPARAMETERS_H +#define SD_BOOLPARAMETERS_H #include -#include #include +#include + #include "Parameters.h" #include "ParametersSpace.h" namespace sd { - class BoolParameters : public ParametersSpace { - protected: - - public: - BoolParameters(std::string name) : ParametersSpace() { - _name = name; - } +class BoolParameters : public ParametersSpace { + protected: + public: + BoolParameters(std::string name) : ParametersSpace() { _name = name; } - std::vector evaluate() override { - std::vector result; - result.emplace_back(0); - result.emplace_back(1); - return result; - } - }; -} + std::vector evaluate() override { + std::vector result; + result.emplace_back(0); + result.emplace_back(1); + return result; + } +}; +} // namespace sd -#endif //DEV_TESTS_PARAMETERSPACE_H \ No newline at end of file +#endif // SD_PARAMETERSPACE_H \ No newline at end of file diff --git a/libnd4j/include/helpers/benchmark/BroadcastBenchmark.h b/libnd4j/include/helpers/benchmark/BroadcastBenchmark.h index 8c61bda23e80..5c100444f9b2 100644 --- a/libnd4j/include/helpers/benchmark/BroadcastBenchmark.h +++ b/libnd4j/include/helpers/benchmark/BroadcastBenchmark.h @@ -18,116 +18,118 @@ // @author Alex Black // +#include + #include "../OpBenchmark.h" -#ifndef DEV_TESTS_BROADCASTBENCHMARK_H -#define DEV_TESTS_BROADCASTBENCHMARK_H +#ifndef SD_BROADCASTBENCHMARK_H +#define SD_BROADCASTBENCHMARK_H namespace sd { - class ND4J_EXPORT BroadcastBenchmark : public OpBenchmark { - public: - BroadcastBenchmark() : OpBenchmark() { - // - } - - BroadcastBenchmark(broadcast::Ops op, std::string testName, NDArray *x, NDArray *y, NDArray *z, std::vector axis) : OpBenchmark(testName, x, y, z, axis) { - _opNum = (int) op; - } - - BroadcastBenchmark(broadcast::Ops op, std::string testName, NDArray *x, NDArray *y, NDArray *z, std::initializer_list axis) : OpBenchmark(testName, x, y, z, axis) { - _opNum = (int) op; - } - - BroadcastBenchmark(broadcast::Ops op, std::string name, std::vector axis) : OpBenchmark() { - _opNum = (int) op; - _testName = name; - _axis = axis; - } - - BroadcastBenchmark(broadcast::Ops op, std::string name, std::initializer_list axis) : OpBenchmark() { - _opNum = (int) op; - _testName = name; - _axis = axis; - } - - ~BroadcastBenchmark(){ - if (_x != _y && _x != _z && _y != _z) { - delete _x; - delete _y; - delete _z; - } else if (_x == _y && _x == _z) { - delete _x; - } else if (_x == _z) { - delete _x; - delete _y; - } else if (_y == _z) { - delete _x; - delete _y; - } - } - -void executeOnce() override { +class SD_EXPORT BroadcastBenchmark : public OpBenchmark { + public: + BroadcastBenchmark() : OpBenchmark() { + // + } + + BroadcastBenchmark(broadcast::Ops op, const std::string &testName, + const NDArray &x, const NDArray &y, const NDArray &z, + const std::vector &axis) + : OpBenchmark(testName, x, y, z, axis) { + _opNum = (int)op; + } + + BroadcastBenchmark(broadcast::Ops op, const std::string &name, + const std::vector &axis) + : OpBenchmark() { + _opNum = (int)op; + _testName = name; + _axis = axis; + } + + ~BroadcastBenchmark() { + // + } + + void executeOnce() override { PointersManager manager(LaunchContext::defaultContext(), "BroadcastBM"); - auto packX = ConstantTadHelper::getInstance().tadForDimensions(_x->shapeInfo(), _axis); - auto packZ = ConstantTadHelper::getInstance().tadForDimensions(_z->shapeInfo(), _axis); - - auto tadOnlyShapeInfo = Environment::getInstance().isCPU() ? packX.primaryShapeInfo() : packX.specialShapeInfo(); - auto tadOffsets = Environment::getInstance().isCPU() ? packX.primaryOffsets() : packX.specialOffsets(); - - auto tadOnlyShapeInfoZ = Environment::getInstance().isCPU() ? packZ.primaryShapeInfo() : packZ.specialShapeInfo(); - auto tadOffsetsZ = Environment::getInstance().isCPU() ? packZ.primaryOffsets() : packZ.specialOffsets(); - - NativeOpExecutioner::execBroadcast(LaunchContext::defaultContext(), _opNum, _x->buffer(), _x->shapeInfo(), _x->specialBuffer(), _x->specialShapeInfo(), _y->buffer(), _y->shapeInfo(), _y->specialBuffer(), _y->specialShapeInfo(), _z->buffer(), _z->shapeInfo(), _z->specialBuffer(), _z->specialShapeInfo(), nullptr, _axis.size(), - /*Nd4jLong **/ tadOnlyShapeInfo, /*Nd4jLong */ tadOffsets, /*Nd4jLong */ tadOnlyShapeInfoZ, /*Nd4jLong */ tadOffsetsZ); + auto packX = ConstantTadHelper::getInstance().tadForDimensions( + _x.shapeInfo(), _axis); + auto packZ = ConstantTadHelper::getInstance().tadForDimensions( + _z.shapeInfo(), _axis); + + auto tadOnlyShapeInfo = Environment::getInstance().isCPU() + ? packX.primaryShapeInfo() + : packX.specialShapeInfo(); + auto tadOffsets = Environment::getInstance().isCPU() + ? packX.primaryOffsets() + : packX.specialOffsets(); + + auto tadOnlyShapeInfoZ = Environment::getInstance().isCPU() + ? packZ.primaryShapeInfo() + : packZ.specialShapeInfo(); + auto tadOffsetsZ = Environment::getInstance().isCPU() + ? packZ.primaryOffsets() + : packZ.specialOffsets(); + + NativeOpExecutioner::execBroadcast( + LaunchContext::defaultContext(), _opNum, _x.buffer(), _x.shapeInfo(), + _x.specialBuffer(), _x.specialShapeInfo(), _y.buffer(), _y.shapeInfo(), + _y.specialBuffer(), _y.specialShapeInfo(), _z.buffer(), _z.shapeInfo(), + _z.specialBuffer(), _z.specialShapeInfo(), nullptr, _axis.size(), + /*Nd4jLong **/ tadOnlyShapeInfo, /*Nd4jLong */ tadOffsets, + /*Nd4jLong */ tadOnlyShapeInfoZ, /*Nd4jLong */ tadOffsetsZ); manager.synchronize(); - } - - std::string axis() override { - if (_axis.empty()) - return ""; - else { - std::string result; - for (auto v:_axis) { - auto s = StringUtils::valueToString(v); - result += s; - result += ","; - } - return result; - } - } - - std::string inplace() override { - std::string result; - result += (_x == _z ? "true" : "false"); - return result; - } - - std::string orders() override { - std::string result; - result += _x->ordering(); - result += "/"; - result += _y->ordering(); - result += "/"; - result += _z == nullptr ? _x->ordering() : _z->ordering(); - return result; - } - - std::string strides() override { - std::string result; - result += ShapeUtils::strideAsString(_x); - result += "/"; - result += ShapeUtils::strideAsString(_y); - result += "/"; - result += _z == nullptr ? ShapeUtils::strideAsString(_x) : ShapeUtils::strideAsString(_z); - return result; - } - - OpBenchmark* clone() override { - return new BroadcastBenchmark((broadcast::Ops) _opNum, _testName, _x, _y, _z, _axis); - } - }; -} - -#endif //DEV_TESTS_BROADCASTBENCHMARK_H \ No newline at end of file + } + + std::string axis() override { + if (_axis.empty()) + return ""; + else { + std::string result; + for (auto v : _axis) { + auto s = StringUtils::valueToString(v); + result += s; + result += ","; + } + return result; + } + } + + std::string inplace() override { + std::string result; + result += (_x == _z ? "true" : "false"); + return result; + } + + std::string orders() override { + std::string result; + result += _x.ordering(); + result += "/"; + result += _y.ordering(); + result += "/"; + result += _z.shapeInfo() == nullptr ? _x.ordering() : _z.ordering(); + return result; + } + + std::string strides() override { + std::string result; + result += ShapeUtils::strideAsString(_x); + result += "/"; + result += ShapeUtils::strideAsString(_y); + result += "/"; + result += _z.shapeInfo() == nullptr ? ShapeUtils::strideAsString(_x) + : ShapeUtils::strideAsString(_z); + return result; + } + + OpBenchmark *clone() override { + return new BroadcastBenchmark((broadcast::Ops)_opNum, _testName, _x, _y, _z, + _axis); + } +}; +} // namespace sd + +#endif // SD_BROADCASTBENCHMARK_H \ No newline at end of file diff --git a/libnd4j/include/helpers/benchmark/DeclarableBenchmark.h b/libnd4j/include/helpers/benchmark/DeclarableBenchmark.h index 58c018a5b602..9091257494a0 100644 --- a/libnd4j/include/helpers/benchmark/DeclarableBenchmark.h +++ b/libnd4j/include/helpers/benchmark/DeclarableBenchmark.h @@ -14,164 +14,154 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - // // Created by raver on 3/2/2019. // -#ifndef DEV_TESTS_DECLARABLEBENCHMARK_H -#define DEV_TESTS_DECLARABLEBENCHMARK_H +#ifndef SD_DECLARABLEBENCHMARK_H +#define SD_DECLARABLEBENCHMARK_H #include #include #include +#include #include #include -#include namespace sd { - class ND4J_EXPORT DeclarableBenchmark : public OpBenchmark { - protected: - sd::ops::DeclarableOp *_op = nullptr; - sd::graph::Context *_context = nullptr; - public: - DeclarableBenchmark(sd::ops::DeclarableOp &op, std::string name = 0) : OpBenchmark() { - _op = &op; //ops::OpRegistrator::getInstance().getOperation(op.getOpHash()); - _testName = name; - } - - void setContext(sd::graph::Context *ctx) { - _context = ctx; - } - - std::string axis() override { - return "N/A"; - } - - std::string orders() override { - if(_context != nullptr && _context->isFastPath()){ - std::vector& ins = _context->fastpath_in(); - std::string s; - for( int i=0; i 0){ - s += "/"; - } - s += ShapeUtils::strideAsString(_context->getNDArray(i)); - } - return s; - } - return "N/A"; - } - - std::string strides() override { - if (_context != nullptr && _context->isFastPath()) { - std::vector& ins = _context->fastpath_in(); - std::string s(""); - for( int i=0; i 0){ - s += "/"; - } - s += ShapeUtils::strideAsString(_context->getNDArray(i)); - } - return s; - } else - return "N/A"; +class SD_EXPORT DeclarableBenchmark : public OpBenchmark { + protected: + sd::ops::DeclarableOp *_op = nullptr; + sd::graph::Context *_context = nullptr; + + public: + DeclarableBenchmark(sd::ops::DeclarableOp &op, std::string name = 0) + : OpBenchmark() { + _op = + &op; // ops::OpRegistrator::getInstance().getOperation(op.getOpHash()); + _testName = name; + } + + void setContext(sd::graph::Context *ctx) { _context = ctx; } + + std::string axis() override { return "N/A"; } + + std::string orders() override { + if (_context != nullptr && _context->isFastPath()) { + auto &ins = _context->fastpath_in(); + std::string s; + for (int i = 0; i < ins.size(); i++) { + if (i > 0) { + s += "/"; } - - std::string inplace() override { - return "N/A"; + s += ShapeUtils::strideAsString(_context->getNDArray(i).get()); + } + return s; + } + return "N/A"; + } + + std::string strides() override { + if (_context != nullptr && _context->isFastPath()) { + auto &ins = _context->fastpath_in(); + std::string s(""); + for (int i = 0; i < ins.size(); i++) { + if (i > 0) { + s += "/"; } - - void executeOnce() override { - PointersManager pm(LaunchContext::defaultContext(), "DeclarableBenchmark"); - _op->execute(_context); - pm.synchronize(); - } - - OpBenchmark *clone() override { - return new DeclarableBenchmark(*_op, _testName); + s += ShapeUtils::strideAsString(_context->getNDArray(i).get()); + } + return s; + } else + return "N/A"; + } + + std::string inplace() override { return "N/A"; } + + void executeOnce() override { + PointersManager pm(LaunchContext::defaultContext(), "DeclarableBenchmark"); + _op->execute(_context); + pm.synchronize(); + } + + OpBenchmark *clone() override { + return new DeclarableBenchmark(*_op, _testName); + } + + std::string shape() override { + if (_context != nullptr && _context->isFastPath()) { + auto &ins = _context->fastpath_in(); + std::string s; + for (int i = 0; i < ins.size(); i++) { + if (i > 0) { + s += "/"; } - - std::string shape() override { - if (_context != nullptr && _context->isFastPath()) { - std::vector& ins = _context->fastpath_in(); - std::string s; - for( int i=0; i 0){ - s += "/"; - } - s += ShapeUtils::shapeAsString(_context->getNDArray(i)); - } - return s; - } else - return "N/A"; + s += ShapeUtils::shapeAsString(_context->getNDArray(i).get()); + } + return s; + } else + return "N/A"; + } + + std::string dataType() override { + if (_context != nullptr && _context->isFastPath()) { + auto &ins = _context->fastpath_in(); + std::string s; + for (int i = 0; i < ins.size(); i++) { + if (i > 0) { + s += "/"; } - - std::string dataType() override { - if (_context != nullptr && _context->isFastPath()){ - std::vector& ins = _context->fastpath_in(); - std::string s; - for( int i=0; i 0){ - s += "/"; - } - s += DataTypeUtils::asString(_context->getNDArray(i)->dataType()); - } - return s; - } else - return "N/A"; + s += DataTypeUtils::asString(_context->getNDArray(i)->dataType()); + } + return s; + } else + return "N/A"; + } + + std::string extra() override { + if (_context != nullptr) { + auto iargs = _context->getIArguments(); + auto targs = _context->getTArguments(); + auto bargs = _context->getBArguments(); + std::string e; + bool any = false; + if (!iargs.empty()) { + e += "iargs=["; + for (int i = 0; i < iargs.size(); i++) { + if (i > 0) e += ","; + e += std::to_string(iargs.at(i)); } - - std::string extra() override { - if(_context != nullptr){ - std::vector* iargs = _context->getIArguments(); - std::vector* targs = _context->getTArguments(); - std::vector* bargs = _context->getBArguments(); - std::string e; - bool any = false; - if(iargs != nullptr){ - e += "iargs=["; - for( int i=0; isize(); i++ ){ - if(i > 0) - e += ","; - e += std::to_string(iargs->at(i)); - } - e += "]"; - any = true; - } - if(targs != nullptr){ - if(any) - e += ","; - e += "targs=["; - for( int i=0; isize(); i++ ){ - if(i > 0) - e += ","; - e += std::to_string(targs->at(i)); - } - e += "]"; - any = true; - } - if(bargs != nullptr){ - if(any) - e += ","; - e += "bargs=["; - for( int i=0; isize(); i++ ){ - if(i > 0) - e += ","; - e += std::to_string(bargs->at(i)); - } - e += "]"; - } - return e; - } - return "N/A"; + e += "]"; + any = true; + } + if (!targs.empty()) { + if (any) e += ","; + e += "targs=["; + for (int i = 0; i < targs.size(); i++) { + if (i > 0) e += ","; + e += std::to_string(targs.at(i)); } - - ~DeclarableBenchmark() { - if (_context != nullptr) - delete _context; + e += "]"; + any = true; + } + if (!bargs.empty()) { + if (any) e += ","; + e += "bargs=["; + for (int i = 0; i < bargs.size(); i++) { + if (i > 0) e += ","; + e += std::to_string(bargs.at(i)); } - }; -} - -#endif //DEV_TESTS_DECLARABLEBENCHMARKS_H \ No newline at end of file + e += "]"; + } + return e; + } + return "N/A"; + } + + ~DeclarableBenchmark() { + if (_context != nullptr) delete _context; + } +}; +} // namespace sd + +#endif // SD_DECLARABLEBENCHMARKS_H \ No newline at end of file diff --git a/libnd4j/include/helpers/benchmark/IntParameters.h b/libnd4j/include/helpers/benchmark/IntParameters.h index 10a1763e4979..d630becbe0ae 100644 --- a/libnd4j/include/helpers/benchmark/IntParameters.h +++ b/libnd4j/include/helpers/benchmark/IntParameters.h @@ -18,38 +18,40 @@ // @author raver119@gmail.com // -#ifndef DEV_TESTS_INTPARAMETERS_H -#define DEV_TESTS_INTPARAMETERS_H +#ifndef SD_INTPARAMETERS_H +#define SD_INTPARAMETERS_H #include -#include #include +#include + #include "Parameters.h" #include "ParametersSpace.h" namespace sd { - class IntParameters : public ParametersSpace { - protected: - int _start; - int _stop; - int _step; - - public: - IntParameters(std::string name, int start, int stop, int step = 1) : ParametersSpace() { - _start = start; - _stop = stop; - _step = step; - _name = name; - } - - std::vector evaluate() override { - std::vector result; - for (int e = _start; e <= _stop; e += _step) { - result.emplace_back(e); - } - return result; - } - }; -} - -#endif //DEV_TESTS_INTPARAMETERS_H \ No newline at end of file +class IntParameters : public ParametersSpace { + protected: + int _start; + int _stop; + int _step; + + public: + IntParameters(std::string name, int start, int stop, int step = 1) + : ParametersSpace() { + _start = start; + _stop = stop; + _step = step; + _name = name; + } + + std::vector evaluate() override { + std::vector result; + for (int e = _start; e <= _stop; e += _step) { + result.emplace_back(e); + } + return result; + } +}; +} // namespace sd + +#endif // SD_INTPARAMETERS_H \ No newline at end of file diff --git a/libnd4j/include/helpers/benchmark/IntPowerParameters.h b/libnd4j/include/helpers/benchmark/IntPowerParameters.h index 82c58bb2317e..d9bac1cd24df 100644 --- a/libnd4j/include/helpers/benchmark/IntPowerParameters.h +++ b/libnd4j/include/helpers/benchmark/IntPowerParameters.h @@ -18,40 +18,43 @@ // @author raver119@gmail.com // -#ifndef DEV_TESTS_INTPOWERPARAMETERS_H -#define DEV_TESTS_INTPOWERPARAMETERS_H +#ifndef SD_INTPOWERPARAMETERS_H +#define SD_INTPOWERPARAMETERS_H #include -#include #include +#include + #include "Parameters.h" #include "ParametersSpace.h" namespace sd { - class IntPowerParameters : public ParametersSpace { - protected: - int _base; - int _start; - int _stop; - int _step; - - public: - IntPowerParameters(std::string name, int base, int start, int stop, int step = 1) : ParametersSpace() { - _base = base; - _start = start; - _stop = stop; - _step = step; - _name = name; - } - - std::vector evaluate() override { - std::vector result; - for (int e = _start; e <= _stop; e += _step) { - result.emplace_back(sd::math::nd4j_pow(_base, e)); - } - return result; - } - }; -} - -#endif //DEV_TESTS_INTPOWERPARAMETERS_H \ No newline at end of file +class IntPowerParameters : public ParametersSpace { + protected: + int _base; + int _start; + int _stop; + int _step; + + public: + IntPowerParameters(std::string name, int base, int start, int stop, + int step = 1) + : ParametersSpace() { + _base = base; + _start = start; + _stop = stop; + _step = step; + _name = name; + } + + std::vector evaluate() override { + std::vector result; + for (int e = _start; e <= _stop; e += _step) { + result.emplace_back(sd::math::nd4j_pow(_base, e)); + } + return result; + } +}; +} // namespace sd + +#endif // SD_INTPOWERPARAMETERS_H \ No newline at end of file diff --git a/libnd4j/include/helpers/benchmark/MatrixBenchmark.h b/libnd4j/include/helpers/benchmark/MatrixBenchmark.h index eb8fd2619993..c1d81df89df1 100644 --- a/libnd4j/include/helpers/benchmark/MatrixBenchmark.h +++ b/libnd4j/include/helpers/benchmark/MatrixBenchmark.h @@ -18,107 +18,98 @@ // @author raver119@gmail.com // -#include #include +#include -#ifndef DEV_TESTS_MATRIXBENCHMARK_H -#define DEV_TESTS_MATRIXBENCHMARK_H +#ifndef SD_MATRIXBENCHMARK_H +#define SD_MATRIXBENCHMARK_H namespace sd { - class ND4J_EXPORT MatrixBenchmark : public OpBenchmark { - private: - float _alpha = 1.0f; - float _beta = 0.0f; - bool _tA; - bool _tB; - public: - MatrixBenchmark() : OpBenchmark() { - // - } - - MatrixBenchmark(float alpha, float beta, std::string testName, NDArray *x, NDArray *y, NDArray *z) : OpBenchmark(testName, x, y, z) { - _alpha = alpha; - _beta = beta; - _tA = false; - _tB = false; - } - - MatrixBenchmark(float alpha, float beta, bool tA, bool tB, std::string name) : OpBenchmark() { - _testName = name; - _alpha = alpha; - _beta = beta; - _tA = tA; - _tB = tB; - } - - ~MatrixBenchmark(){ - if (_x != _y && _x != _z && _y != _z) { - delete _x; - delete _y; - delete _z; - } else if (_x == _y && _x == _z) { - delete _x; - } else if (_x == _z) { - delete _x; - delete _y; - } else if (_y == _z) { - delete _x; - delete _y; - } - } - - void executeOnce() override { - auto xT = (_tA ? _x->transpose() : *_x); - auto yT = (_tB ? _y->transpose() : *_y); - - MmulHelper::mmul(&xT, &yT, _z, _alpha, _beta); - } - - std::string axis() override { - return "N/A"; - } - - std::string inplace() override { - return "N/A"; - } - - std::string orders() override { - std::string result; - result += _x->ordering(); - result += "/"; - result += _y->ordering(); - result += "/"; - result += _z == nullptr ? _x->ordering() : _z->ordering(); - return result; - } - - std::string strides() override { - std::string result; - result += ShapeUtils::strideAsString(_x); - result += "/"; - result += ShapeUtils::strideAsString(_y); - result += "/"; - result += _z == nullptr ? ShapeUtils::strideAsString(_x) : ShapeUtils::strideAsString(_z); - return result; - } - - std::string shape() override { - std::string result; - result += ShapeUtils::shapeAsString(_x); - result += "x"; - result += ShapeUtils::shapeAsString(_y); - result += "="; - result += _z == nullptr ? "" : ShapeUtils::shapeAsString(_z); - return result; - } - - OpBenchmark* clone() override { - MatrixBenchmark* mb = new MatrixBenchmark(_alpha, _beta, _testName, _x, _y, _z); - mb->_tA = _tA; - mb->_tB = _tB; - return mb; - } - }; -} - -#endif //DEV_TESTS_SCALARBENCHMARK_H \ No newline at end of file +class SD_EXPORT MatrixBenchmark : public OpBenchmark { + private: + float _alpha = 1.0f; + float _beta = 0.0f; + bool _tA; + bool _tB; + + public: + MatrixBenchmark() : OpBenchmark() { + // + } + + MatrixBenchmark(float alpha, float beta, const std::string &testName, + const NDArray &x, const NDArray &y, const NDArray &z) + : OpBenchmark(testName, x, y, z) { + _alpha = alpha; + _beta = beta; + _tA = false; + _tB = false; + } + + MatrixBenchmark(float alpha, float beta, bool tA, bool tB, + const std::string &name) + : OpBenchmark() { + _testName = name; + _alpha = alpha; + _beta = beta; + _tA = tA; + _tB = tB; + } + + ~MatrixBenchmark() { + // + } + + void executeOnce() override { + auto xT = (_tA ? _x.transpose() : _x); + auto yT = (_tB ? _y.transpose() : _y); + + MmulHelper::mmul(&xT, &yT, &_z, _alpha, _beta); + } + + std::string axis() override { return "N/A"; } + + std::string inplace() override { return "N/A"; } + + std::string orders() override { + std::string result; + result += _x.ordering(); + result += "/"; + result += _y.ordering(); + result += "/"; + result += _z.shapeInfo() == nullptr ? _x.ordering() : _z.ordering(); + return result; + } + + std::string strides() override { + std::string result; + result += ShapeUtils::strideAsString(_x); + result += "/"; + result += ShapeUtils::strideAsString(_y); + result += "/"; + result += _z.shapeInfo() == nullptr ? ShapeUtils::strideAsString(_x) + : ShapeUtils::strideAsString(_z); + return result; + } + + std::string shape() override { + std::string result; + result += ShapeUtils::shapeAsString(_x); + result += "x"; + result += ShapeUtils::shapeAsString(_y); + result += "="; + result += _z.shapeInfo() == nullptr ? "" : ShapeUtils::shapeAsString(_z); + return result; + } + + OpBenchmark *clone() override { + MatrixBenchmark *mb = + new MatrixBenchmark(_alpha, _beta, _testName, _x, _y, _z); + mb->_tA = _tA; + mb->_tB = _tB; + return mb; + } +}; +} // namespace sd + +#endif // SD_SCALARBENCHMARK_H \ No newline at end of file diff --git a/libnd4j/include/helpers/benchmark/PairwiseBenchmark.h b/libnd4j/include/helpers/benchmark/PairwiseBenchmark.h index ca92e96b3219..bc1b8c053be5 100644 --- a/libnd4j/include/helpers/benchmark/PairwiseBenchmark.h +++ b/libnd4j/include/helpers/benchmark/PairwiseBenchmark.h @@ -20,89 +20,82 @@ #include "../OpBenchmark.h" -#ifndef DEV_TESTS_PAIRWISEBENCHMARK_H -#define DEV_TESTS_PAIRWISEBENCHMARK_H +#ifndef SD_PAIRWISEBENCHMARK_H +#define SD_PAIRWISEBENCHMARK_H using namespace sd::graph; namespace sd { - class ND4J_EXPORT PairwiseBenchmark : public OpBenchmark { - public: - PairwiseBenchmark() : OpBenchmark() { - // - } - - PairwiseBenchmark(pairwise::Ops op, std::string testName, NDArray *x, NDArray *y, NDArray *z) : OpBenchmark(testName, x, y, z) { - _opNum = (int) op; - } - - PairwiseBenchmark(pairwise::Ops op, std::string name) : OpBenchmark() { - _opNum = (int) op; - _testName = name; - } - - ~PairwiseBenchmark(){ - if (_x != _y && _x != _z && _y != _z) { - delete _x; - delete _y; - delete _z; - } else if (_x == _y && _x == _z) { - delete _x; - } else if (_x == _z) { - delete _x; - delete _y; - } else if (_y == _z) { - delete _x; - delete _y; - } - } - - void executeOnce() override { - PointersManager manager(LaunchContext::defaultContext(), "PairwiseBM"); - - NativeOpExecutioner::execPairwiseTransform(LaunchContext::defaultContext(), _opNum, _x->buffer(), _x->shapeInfo(), _x->specialBuffer(), _x->specialShapeInfo(), _y->buffer(), _y->shapeInfo(), _y->specialBuffer(), _y->specialShapeInfo(), _z->buffer(), _z->shapeInfo(), _z->specialBuffer(), _z->specialShapeInfo(), nullptr); - - manager.synchronize(); - } - - std::string axis() override { - return "N/A"; - } - - std::string inplace() override { - std::string result; - result += (_x == _y ? "x==y" : "x!=y"); - result += "/"; - result += (_x == _z ? "x==z" : "x!=z"); - result += "/"; - result += (_y == _z ? "y==z" : "y!=z"); - return result; - } - - std::string orders() override { - std::string result; - result += _x->ordering(); - result += "/"; - result += _y->ordering(); - result += "/"; - result += _z == nullptr ? _x->ordering() : _z->ordering(); - return result; - } - - std::string strides() override { - std::string result; - result += ShapeUtils::strideAsString(_x); - result += "/"; - result += ShapeUtils::strideAsString(_y); - result += "/"; - result += _z == nullptr ? ShapeUtils::strideAsString(_x) : ShapeUtils::strideAsString(_z); - return result; - } - - OpBenchmark* clone() override { - return new PairwiseBenchmark((pairwise::Ops) _opNum, _testName, _x, _y, _z); - } - }; -} - -#endif //DEV_TESTS_SCALARBENCHMARK_H \ No newline at end of file +class SD_EXPORT PairwiseBenchmark : public OpBenchmark { + public: + PairwiseBenchmark() : OpBenchmark() { + // + } + + PairwiseBenchmark(pairwise::Ops op, const std::string &testName, + const NDArray &x, const NDArray &y, const NDArray &z) + : OpBenchmark(testName, x, y, z) { + _opNum = (int)op; + } + + PairwiseBenchmark(pairwise::Ops op, std::string name) : OpBenchmark() { + _opNum = (int)op; + _testName = name; + } + + ~PairwiseBenchmark() { + // + } + + void executeOnce() override { + PointersManager manager(LaunchContext::defaultContext(), "PairwiseBM"); + + NativeOpExecutioner::execPairwiseTransform( + LaunchContext::defaultContext(), _opNum, _x.buffer(), _x.shapeInfo(), + _x.specialBuffer(), _x.specialShapeInfo(), _y.buffer(), _y.shapeInfo(), + _y.specialBuffer(), _y.specialShapeInfo(), _z.buffer(), _z.shapeInfo(), + _z.specialBuffer(), _z.specialShapeInfo(), nullptr); + + manager.synchronize(); + } + + std::string axis() override { return "N/A"; } + + std::string inplace() override { + std::string result; + result += (_x.platformBuffer() == _y.platformBuffer() ? "x==y" : "x!=y"); + result += "/"; + result += (_x.platformBuffer() == _z.platformBuffer() ? "x==z" : "x!=z"); + result += "/"; + result += (_y.platformBuffer() == _z.platformBuffer() ? "y==z" : "y!=z"); + return result; + } + + std::string orders() override { + std::string result; + result += _x.ordering(); + result += "/"; + result += _y.ordering(); + result += "/"; + result += _z.shapeInfo() == nullptr ? _x.ordering() : _z.ordering(); + return result; + } + + std::string strides() override { + std::string result; + result += ShapeUtils::strideAsString(_x); + result += "/"; + result += ShapeUtils::strideAsString(_y); + result += "/"; + result += _z.shapeInfo() == nullptr ? ShapeUtils::strideAsString(_x) + : ShapeUtils::strideAsString(_z); + return result; + } + + OpBenchmark *clone() override { + return new PairwiseBenchmark((pairwise::Ops)_opNum, _testName, _x, _y, _z); + } +}; +} // namespace sd + +#endif // SD_SCALARBENCHMARK_H \ No newline at end of file diff --git a/libnd4j/include/helpers/benchmark/Parameters.h b/libnd4j/include/helpers/benchmark/Parameters.h index eee443574b81..54efd9be2e0d 100644 --- a/libnd4j/include/helpers/benchmark/Parameters.h +++ b/libnd4j/include/helpers/benchmark/Parameters.h @@ -18,35 +18,41 @@ // @author raver119@gmail.com // -#ifndef DEV_TESTS_PARAMETERS_H -#define DEV_TESTS_PARAMETERS_H +#ifndef SD_PARAMETERS_H +#define SD_PARAMETERS_H #include #include #include namespace sd { - class Parameters { - private: - std::map _intParams; - std::map _boolParams; - std::map> _arrayParams; - public: - Parameters() = default; - - Parameters* addIntParam(std::string string, int param); - Parameters* addIntParam(std::initializer_list strings, std::initializer_list params); - - Parameters* addBoolParam(std::string string, bool param); - Parameters* addBoolParam(std::initializer_list strings, std::initializer_list params); - - Parameters* addArrayParam(std::string string, std::initializer_list param); - Parameters* addArrayParam(std::initializer_list strings, std::initializer_list> params); - - int getIntParam(std::string string) const ; - bool getBoolParam(std::string string) const; - std::vector getArrayParam(std::string string) const; - }; -} - -#endif //DEV_TESTS_PARAMETERS_H \ No newline at end of file +class Parameters { + private: + std::map _intParams; + std::map _boolParams; + std::map> _arrayParams; + + public: + Parameters() = default; + + Parameters* addIntParam(std::string string, int param); + Parameters* addIntParam(std::initializer_list strings, + std::initializer_list params); + + Parameters* addBoolParam(std::string string, bool param); + Parameters* addBoolParam(std::initializer_list strings, + std::initializer_list params); + + Parameters* addArrayParam(std::string string, + std::initializer_list param); + Parameters* addArrayParam( + std::initializer_list strings, + std::initializer_list> params); + + int getIntParam(std::string string) const; + bool getBoolParam(std::string string) const; + std::vector getArrayParam(std::string string) const; +}; +} // namespace sd + +#endif // SD_PARAMETERS_H \ No newline at end of file diff --git a/libnd4j/include/helpers/benchmark/ParametersBatch.h b/libnd4j/include/helpers/benchmark/ParametersBatch.h index 68c4dfb9f726..235cd5f6e8b3 100644 --- a/libnd4j/include/helpers/benchmark/ParametersBatch.h +++ b/libnd4j/include/helpers/benchmark/ParametersBatch.h @@ -18,67 +18,64 @@ // @author raver119@gmail.com // -#ifndef DEV_TESTS_PARAMETERSBATCH_H -#define DEV_TESTS_PARAMETERSBATCH_H +#ifndef SD_PARAMETERSBATCH_H +#define SD_PARAMETERSBATCH_H #include -#include #include +#include + namespace sd { - class ParametersBatch { - protected: - std::vector _spaces; - public: - ParametersBatch() = default; - ParametersBatch(std::initializer_list spaces) { - _spaces = spaces; - } - - ParametersBatch(std::vector spaces) { - _spaces = spaces; - } - - - std::vector parameters() { - std::vector result; - std::vector> vectors; - int totalIterations = 1; - - // hehe - int xCoords[MAX_RANK]; - Nd4jLong xShape[MAX_RANK]; - int xRank = _spaces.size(); - - for (int e = 0; e < _spaces.size(); e++) { - auto space = _spaces[e]; - auto values = space->evaluate(); - vectors.emplace_back(values); - - totalIterations *= values.size(); - xShape[e] = values.size(); - } - - //nd4j_printf("Total Iterations: %i\n", totalIterations); - - for (int i = 0; i < totalIterations; i++) { - if (xRank > 0) - shape::index2coords(i, xRank, xShape, xCoords); - - Parameters params; - for (int j = 0; j < xRank; j++) { - int value = vectors[j][xCoords[j]]; - std::string name = _spaces[j]->name(); - params.addIntParam(name, value); - } - - result.emplace_back(params); - } - - - return result; - } - }; -} - -#endif //DEV_TESTS_PARAMETERSBATCH_H \ No newline at end of file +class ParametersBatch { + protected: + std::vector _spaces; + + public: + ParametersBatch() = default; + ParametersBatch(std::initializer_list spaces) { + _spaces = spaces; + } + + ParametersBatch(std::vector spaces) { _spaces = spaces; } + + std::vector parameters() { + std::vector result; + std::vector> vectors; + int totalIterations = 1; + + // hehe + int xCoords[MAX_RANK]; + Nd4jLong xShape[MAX_RANK]; + int xRank = _spaces.size(); + + for (int e = 0; e < _spaces.size(); e++) { + auto space = _spaces[e]; + auto values = space->evaluate(); + vectors.emplace_back(values); + + totalIterations *= values.size(); + xShape[e] = values.size(); + } + + // nd4j_printf("Total Iterations: %i\n", totalIterations); + + for (int i = 0; i < totalIterations; i++) { + if (xRank > 0) shape::index2coords(i, xRank, xShape, xCoords); + + Parameters params; + for (int j = 0; j < xRank; j++) { + int value = vectors[j][xCoords[j]]; + std::string name = _spaces[j]->name(); + params.addIntParam(name, value); + } + + result.emplace_back(params); + } + + return result; + } +}; +} // namespace sd + +#endif // SD_PARAMETERSBATCH_H \ No newline at end of file diff --git a/libnd4j/include/helpers/benchmark/ParametersSpace.h b/libnd4j/include/helpers/benchmark/ParametersSpace.h index a7c59f9a6d45..42438bfc97c2 100644 --- a/libnd4j/include/helpers/benchmark/ParametersSpace.h +++ b/libnd4j/include/helpers/benchmark/ParametersSpace.h @@ -18,25 +18,24 @@ // @author raver119@gmail.com // -#ifndef DEV_TESTS_PARAMETERSPACE_H -#define DEV_TESTS_PARAMETERSPACE_H +#ifndef SD_PARAMETERSPACE_H +#define SD_PARAMETERSPACE_H #include namespace sd { - class ParametersSpace { - protected: - std::string _name; - public: - ParametersSpace() = default; - ~ParametersSpace() = default; - - std::string name() { - return _name; - } - - virtual std::vector evaluate() = 0; - }; -} - -#endif //DEV_TESTS_PARAMETERSPACE_H \ No newline at end of file +class ParametersSpace { + protected: + std::string _name; + + public: + ParametersSpace() = default; + ~ParametersSpace() = default; + + std::string name() { return _name; } + + virtual std::vector evaluate() = 0; +}; +} // namespace sd + +#endif // SD_PARAMETERSPACE_H \ No newline at end of file diff --git a/libnd4j/include/helpers/benchmark/PredefinedParameters.h b/libnd4j/include/helpers/benchmark/PredefinedParameters.h index f2a7fc347655..ef4e55dc7868 100644 --- a/libnd4j/include/helpers/benchmark/PredefinedParameters.h +++ b/libnd4j/include/helpers/benchmark/PredefinedParameters.h @@ -18,29 +18,30 @@ // @author raver119@gmail.com // -#ifndef DEV_TESTS_PREDEFINEDPARAMETERS_H -#define DEV_TESTS_PREDEFINEDPARAMETERS_H +#ifndef SD_PREDEFINEDPARAMETERS_H +#define SD_PREDEFINEDPARAMETERS_H #include "ParametersSpace.h" namespace sd { - class PredefinedParameters : public ParametersSpace{ - std::vector _params; - public: - PredefinedParameters(std::string name, std::initializer_list parameters) : ParametersSpace() { - _name = name; - _params = parameters; - } - - PredefinedParameters(std::string name, std::vector parameters) : ParametersSpace() { - _name = name; - _params = parameters; - } - - std::vector evaluate() override { - return _params; - } - }; -} - -#endif //DEV_TESTS_PREDEFINEDPARAMETERS_H \ No newline at end of file +class PredefinedParameters : public ParametersSpace { + std::vector _params; + + public: + PredefinedParameters(std::string name, std::initializer_list parameters) + : ParametersSpace() { + _name = name; + _params = parameters; + } + + PredefinedParameters(std::string name, std::vector parameters) + : ParametersSpace() { + _name = name; + _params = parameters; + } + + std::vector evaluate() override { return _params; } +}; +} // namespace sd + +#endif // SD_PREDEFINEDPARAMETERS_H \ No newline at end of file diff --git a/libnd4j/include/helpers/benchmark/ReductionBenchmark.h b/libnd4j/include/helpers/benchmark/ReductionBenchmark.h index d87c20d3c556..68d91143bf08 100644 --- a/libnd4j/include/helpers/benchmark/ReductionBenchmark.h +++ b/libnd4j/include/helpers/benchmark/ReductionBenchmark.h @@ -20,133 +20,169 @@ #include #include -#include "../OpBenchmark.h" -#ifndef DEV_TESTS_REDUCEBENCHMARK_H -#define DEV_TESTS_REDUCEBENCHMARK_H +#include + +#ifndef SD_REDUCEBENCHMARK_H +#define SD_REDUCEBENCHMARK_H using namespace sd::graph; namespace sd { - class ND4J_EXPORT ReductionBenchmark : public OpBenchmark { - protected: - int _opType; //0=Float, 1=Same - public: - ReductionBenchmark() : OpBenchmark() { - // - } - - ReductionBenchmark(reduce::FloatOps op, std::string testName, NDArray *x, NDArray *z, std::initializer_list axis) : OpBenchmark(testName, x, z, axis) { - _opNum = (int) op; - _opType = 0; - } - - ReductionBenchmark(reduce::SameOps op, std::string testName, NDArray *x, NDArray *z, std::initializer_list axis) : OpBenchmark(testName, x, z, axis) { - _opNum = (int) op; - _opType = 1; - } - - - ReductionBenchmark(reduce::FloatOps op) : OpBenchmark() { - _opNum = (int) op; - _opType = 0; - } - - ReductionBenchmark(reduce::FloatOps op, std::string testName) : OpBenchmark() { - _opNum = (int) op; - _opType = 0; - _testName = testName; - } - - ReductionBenchmark(reduce::SameOps op) : OpBenchmark() { - _opNum = (int) op; - _opType = 1; - } - - ReductionBenchmark(reduce::SameOps op, std::string testName) : OpBenchmark() { - _opNum = (int) op; - _opType = 1; - _testName = testName; - } - - ReductionBenchmark(reduce::FloatOps op, std::string testName, NDArray *x, NDArray *z, std::vector axis) : OpBenchmark(testName ,x, z, axis) { - _opNum = (int) op; - _opType = 0; - } - - ReductionBenchmark(reduce::SameOps op, std::string testName, NDArray *x, NDArray *z, std::vector axis) : OpBenchmark(testName ,x, z, axis) { - _opNum = (int) op; - _opType = 1; - } - - void executeOnce() override { - PointersManager manager(LaunchContext::defaultContext(), "reductionBM"); - - if (_z->isScalar() || _y == nullptr) - if (_opType == 0) - NativeOpExecutioner::execReduceFloatScalar(LaunchContext::defaultContext(), _opNum, _x->buffer(), _x->shapeInfo(), _x->specialBuffer(), _x->specialShapeInfo(), nullptr, _z->buffer(), _z->shapeInfo(), _z->specialBuffer(), _z->specialShapeInfo()); - else - NativeOpExecutioner::execReduceSameScalar(LaunchContext::defaultContext(), _opNum, _x->buffer(), _x->shapeInfo(), _x->specialBuffer(), _x->specialShapeInfo(), nullptr, _z->buffer(), _z->shapeInfo(), _z->specialBuffer(), _z->specialShapeInfo()); - else { - auto pack = ConstantTadHelper::getInstance().tadForDimensions(_x->shapeInfo(), _axis); - - auto tadOnlyShapeInfo = Environment::getInstance().isCPU() ? pack.primaryShapeInfo() : pack.specialShapeInfo(); - auto tadOffsets = Environment::getInstance().isCPU() ? pack.primaryOffsets() : pack.specialOffsets(); - - if (_opType == 0) - NativeOpExecutioner::execReduceFloat(LaunchContext::defaultContext(), _opNum, _x->buffer(), _x->shapeInfo(), _x->specialBuffer(), _x->specialShapeInfo(), nullptr, _z->buffer(), _z->shapeInfo(), _z->specialBuffer(), _z->specialShapeInfo(), nullptr, _axis.size(), tadOnlyShapeInfo, tadOffsets); - else - NativeOpExecutioner::execReduceSame(LaunchContext::defaultContext(), _opNum, _x->buffer(), _x->shapeInfo(), _x->specialBuffer(), _x->specialShapeInfo(), nullptr, _z->buffer(), _z->shapeInfo(), _z->specialBuffer(), _z->specialShapeInfo(), nullptr, _axis.size(), tadOnlyShapeInfo, tadOffsets); - } - - manager.synchronize(); - } - - std::string orders() override { - std::string result; - result += _x->ordering(); - result += "/"; - result += _z == nullptr ? _x->ordering() : _z->ordering(); - return result; - } - - std::string strides() override { - std::string result; - result += ShapeUtils::strideAsString(_x); - return result; - } - - std::string inplace() override { - return "n/a"; - } - - ~ReductionBenchmark(){ - delete _x; - delete _z; - } - - std::string axis() override { - if (_axis.empty()) - return "ALL"; - else { - std::string result; - for (auto v:_axis) { - auto s = StringUtils::valueToString(v); - result += s; - result += ","; - } - - return result; - } - } - - OpBenchmark* clone() override { - if (_opType == 0) - return new ReductionBenchmark((reduce::FloatOps) _opNum, _testName, _x, _z, _axis); - else - return new ReductionBenchmark((reduce::SameOps) _opNum, _testName, _x, _z, _axis); - } - }; -} - -#endif //DEV_TESTS_SCALARBENCHMARK_H \ No newline at end of file +class SD_EXPORT ReductionBenchmark : public OpBenchmark { + protected: + int _opType; // 0=Float, 1=Same + public: + ReductionBenchmark() : OpBenchmark() { + // + } + + ReductionBenchmark(reduce::FloatOps op, const std::string &testName, + const NDArray &x, const NDArray &z, + std::initializer_list axis) + : OpBenchmark(testName, x, z, axis) { + _opNum = (int)op; + _opType = 0; + } + + ReductionBenchmark(reduce::SameOps op, const std::string &testName, + const NDArray &x, const NDArray &z, + std::initializer_list axis) + : OpBenchmark(testName, x, z, axis) { + _opNum = (int)op; + _opType = 1; + } + + ReductionBenchmark(reduce::FloatOps op) : OpBenchmark() { + _opNum = (int)op; + _opType = 0; + } + + ReductionBenchmark(reduce::FloatOps op, const std::string &testName) + : OpBenchmark() { + _opNum = (int)op; + _opType = 0; + _testName = testName; + } + + ReductionBenchmark(reduce::SameOps op) : OpBenchmark() { + _opNum = (int)op; + _opType = 1; + } + + ReductionBenchmark(reduce::SameOps op, const std::string &testName) + : OpBenchmark() { + _opNum = (int)op; + _opType = 1; + _testName = testName; + } + + ReductionBenchmark(reduce::FloatOps op, const std::string &testName, + const NDArray &x, const NDArray &z, + const std::vector &axis) + : OpBenchmark(testName, x, z, axis) { + _opNum = (int)op; + _opType = 0; + } + + ReductionBenchmark(reduce::SameOps op, const std::string &testName, + const NDArray &x, const NDArray &z, + const std::vector &axis) + : OpBenchmark(testName, x, z, axis) { + _opNum = (int)op; + _opType = 1; + } + + void executeOnce() override { + PointersManager manager(LaunchContext::defaultContext(), "reductionBM"); + + if (_z.isScalar() || _y.shapeInfo() == nullptr) + if (_opType == 0) + NativeOpExecutioner::execReduceFloatScalar( + LaunchContext::defaultContext(), _opNum, _x.buffer(), + _x.shapeInfo(), _x.specialBuffer(), _x.specialShapeInfo(), nullptr, + _z.buffer(), _z.shapeInfo(), _z.specialBuffer(), + _z.specialShapeInfo()); + else + NativeOpExecutioner::execReduceSameScalar( + LaunchContext::defaultContext(), _opNum, _x.buffer(), + _x.shapeInfo(), _x.specialBuffer(), _x.specialShapeInfo(), nullptr, + _z.buffer(), _z.shapeInfo(), _z.specialBuffer(), + _z.specialShapeInfo()); + else { + auto pack = ConstantTadHelper::getInstance().tadForDimensions( + _x.shapeInfo(), _axis); + + auto tadOnlyShapeInfo = Environment::getInstance().isCPU() + ? pack.primaryShapeInfo() + : pack.specialShapeInfo(); + auto tadOffsets = Environment::getInstance().isCPU() + ? pack.primaryOffsets() + : pack.specialOffsets(); + + if (_opType == 0) + NativeOpExecutioner::execReduceFloat( + LaunchContext::defaultContext(), _opNum, _x.buffer(), + _x.shapeInfo(), _x.specialBuffer(), _x.specialShapeInfo(), nullptr, + _z.buffer(), _z.shapeInfo(), _z.specialBuffer(), + _z.specialShapeInfo(), nullptr, _axis.size(), tadOnlyShapeInfo, + tadOffsets); + else + NativeOpExecutioner::execReduceSame( + LaunchContext::defaultContext(), _opNum, _x.buffer(), + _x.shapeInfo(), _x.specialBuffer(), _x.specialShapeInfo(), nullptr, + _z.buffer(), _z.shapeInfo(), _z.specialBuffer(), + _z.specialShapeInfo(), nullptr, _axis.size(), tadOnlyShapeInfo, + tadOffsets); + } + + manager.synchronize(); + } + + std::string orders() override { + std::string result; + result += _x.ordering(); + result += "/"; + result += _z.shapeInfo() == nullptr ? _x.ordering() : _z.ordering(); + return result; + } + + std::string strides() override { + std::string result; + result += ShapeUtils::strideAsString(_x); + return result; + } + + std::string inplace() override { return "n/a"; } + + ~ReductionBenchmark() { + // + } + + std::string axis() override { + if (_axis.empty()) + return "ALL"; + else { + std::string result; + for (auto v : _axis) { + auto s = StringUtils::valueToString(v); + result += s; + result += ","; + } + + return result; + } + } + + OpBenchmark *clone() override { + if (_opType == 0) + return new ReductionBenchmark((reduce::FloatOps)_opNum, _testName, _x, _z, + _axis); + else + return new ReductionBenchmark((reduce::SameOps)_opNum, _testName, _x, _z, + _axis); + } +}; +} // namespace sd + +#endif // SD_SCALARBENCHMARK_H \ No newline at end of file diff --git a/libnd4j/include/helpers/benchmark/ScalarBenchmark.h b/libnd4j/include/helpers/benchmark/ScalarBenchmark.h index 3b0cdecf5bcc..97292124f606 100644 --- a/libnd4j/include/helpers/benchmark/ScalarBenchmark.h +++ b/libnd4j/include/helpers/benchmark/ScalarBenchmark.h @@ -19,83 +19,83 @@ // #include "../OpBenchmark.h" -#ifndef DEV_TESTS_SCALARBENCHMARK_H -#define DEV_TESTS_SCALARBENCHMARK_H +#ifndef SD_SCALARBENCHMARK_H +#define SD_SCALARBENCHMARK_H using namespace sd::graph; namespace sd { - class ND4J_EXPORT ScalarBenchmark : public OpBenchmark { - public: - ScalarBenchmark() : OpBenchmark() { - // - } - - ~ScalarBenchmark(){ - if (_x != _y && _x != _z && _y != _z) { - delete _x; - delete _y; - delete _z; - } else if (_x == _y && _x == _z) { - delete _x; - } else if (_x == _z) { - delete _x; - delete _y; - } - } - - ScalarBenchmark(scalar::Ops op) : OpBenchmark() { - _opNum = (int) op; - } - - ScalarBenchmark(scalar::Ops op, std::string testName) : OpBenchmark() { - _opNum = (int) op; - _testName = testName; - } - - ScalarBenchmark(scalar::Ops op, std::string testName, NDArray *x, NDArray *y, NDArray *z) : OpBenchmark(testName, x, y, z) { - _opNum = (int) op; - } - - void executeOnce() override { - PointersManager manager(LaunchContext::defaultContext(), "ScalarBM"); - - if (_z == nullptr) - NativeOpExecutioner::execScalar(LaunchContext::defaultContext(), _opNum, _x->buffer(), _x->shapeInfo(), _x->specialBuffer(), _x->specialShapeInfo(), _x->buffer(), _x->shapeInfo(), _x->specialBuffer(), _x->specialShapeInfo(), _y->buffer(), _y->shapeInfo(), _y->specialBuffer(), _y->specialShapeInfo(), nullptr); - else - NativeOpExecutioner::execScalar(LaunchContext::defaultContext(), _opNum, _x->buffer(), _x->shapeInfo(), _x->specialBuffer(), _x->specialShapeInfo(), _z->buffer(), _z->shapeInfo(), _z->specialBuffer(), _z->specialShapeInfo(), _y->buffer(), _y->shapeInfo(), _y->specialBuffer(), _y->specialShapeInfo(), nullptr); - - manager.synchronize(); - } - - std::string orders() override { - std::string result; - result += _x->ordering(); - result += "/"; - result += _z == nullptr ? _x->ordering() : _z->ordering(); - return result; - } - - std::string strides() override { - std::string result; - result += ShapeUtils::strideAsString(_x); - result += "/"; - result += _z == nullptr ? ShapeUtils::strideAsString(_x) : ShapeUtils::strideAsString(_z); - return result; - } - - std::string axis() override { - return "N/A"; - } - - std::string inplace() override { - return _x == _z ? "true" : "false"; - } - - OpBenchmark* clone() override { - return new ScalarBenchmark((scalar::Ops) _opNum, _testName, _x == nullptr ? _x : new NDArray(_x->dup()) , _y == nullptr ? _y : new NDArray(_y->dup()), _z == nullptr ? _z : new NDArray(_z->dup())); - } - }; -} - -#endif //DEV_TESTS_SCALARBENCHMARK_H \ No newline at end of file +class SD_EXPORT ScalarBenchmark : public OpBenchmark { + public: + ScalarBenchmark() : OpBenchmark() { + // + } + + ~ScalarBenchmark() {} + + ScalarBenchmark(scalar::Ops op) : OpBenchmark() { _opNum = (int)op; } + + ScalarBenchmark(scalar::Ops op, const std::string &testName) : OpBenchmark() { + _opNum = (int)op; + _testName = testName; + } + + ScalarBenchmark(scalar::Ops op, const std::string &testName, const NDArray &x, + const NDArray &y, const NDArray &z) + : OpBenchmark(testName, x, y, z) { + _opNum = (int)op; + } + + void executeOnce() override { + PointersManager manager(LaunchContext::defaultContext(), "ScalarBM"); + + if (_z.shapeInfo() == nullptr) + NativeOpExecutioner::execScalar( + LaunchContext::defaultContext(), _opNum, _x.buffer(), _x.shapeInfo(), + _x.specialBuffer(), _x.specialShapeInfo(), _x.buffer(), + _x.shapeInfo(), _x.specialBuffer(), _x.specialShapeInfo(), + _y.buffer(), _y.shapeInfo(), _y.specialBuffer(), + _y.specialShapeInfo(), nullptr); + else + NativeOpExecutioner::execScalar( + LaunchContext::defaultContext(), _opNum, _x.buffer(), _x.shapeInfo(), + _x.specialBuffer(), _x.specialShapeInfo(), _z.buffer(), + _z.shapeInfo(), _z.specialBuffer(), _z.specialShapeInfo(), + _y.buffer(), _y.shapeInfo(), _y.specialBuffer(), + _y.specialShapeInfo(), nullptr); + + manager.synchronize(); + } + + std::string orders() override { + std::string result; + result += _x.ordering(); + result += "/"; + result += _z.shapeInfo() == nullptr ? _x.ordering() : _z.ordering(); + return result; + } + + std::string strides() override { + std::string result; + result += ShapeUtils::strideAsString(_x); + result += "/"; + result += _z.shapeInfo() == nullptr ? ShapeUtils::strideAsString(_x) + : ShapeUtils::strideAsString(_z); + return result; + } + + std::string axis() override { return "N/A"; } + + std::string inplace() override { return _x == _z ? "true" : "false"; } + + OpBenchmark *clone() override { + return new ScalarBenchmark( + (scalar::Ops)_opNum, _testName, + _x.shapeInfo() == nullptr ? _x : NDArray(_x.dup()), + _y.shapeInfo() == nullptr ? _y : NDArray(_y.dup()), + _z.shapeInfo() == nullptr ? _z : NDArray(_z.dup())); + } +}; +} // namespace sd + +#endif // SD_SCALARBENCHMARK_H \ No newline at end of file diff --git a/libnd4j/include/helpers/benchmark/TransformBenchmark.h b/libnd4j/include/helpers/benchmark/TransformBenchmark.h index 024857633490..b5f1874737fc 100644 --- a/libnd4j/include/helpers/benchmark/TransformBenchmark.h +++ b/libnd4j/include/helpers/benchmark/TransformBenchmark.h @@ -19,117 +19,129 @@ // #include "../OpBenchmark.h" -#ifndef DEV_TESTS_TRANSFORMBENCHMARK_H -#define DEV_TESTS_TRANSFORMBENCHMARK_H +#ifndef SD_TRANSFORMBENCHMARK_H +#define SD_TRANSFORMBENCHMARK_H using namespace sd::graph; namespace sd { - class ND4J_EXPORT TransformBenchmark : public OpBenchmark { - - protected: - int _opType; // 0=StrictOps, 1=Same, 2=Any, 3=Float - - public: - TransformBenchmark() : OpBenchmark() { - // - } - - TransformBenchmark(int opNum, int opType, std::string testName, NDArray *x, NDArray *z) : OpBenchmark(testName, x, z) { - _opNum = opNum; - _opType = opType; - } - - TransformBenchmark(transform::StrictOps op, std::string testName, NDArray *x, NDArray *z) : OpBenchmark(testName, x, z) { - _opNum = (int) op; - _opType = 0; - } - - TransformBenchmark(transform::StrictOps op, std::string name) : OpBenchmark() { - _opNum = (int) op; - _opType = 0; - _testName = name; - } - - TransformBenchmark(transform::SameOps op, std::string name) : OpBenchmark() { - _opNum = (int) op; - _opType = 1; - _testName = name; - } - - TransformBenchmark(transform::AnyOps op, std::string name) : OpBenchmark() { - _opNum = (int) op; - _opType = 2; - _testName = name; - } - - TransformBenchmark(transform::FloatOps op, std::string name) : OpBenchmark() { - _opNum = (int) op; - _opType = 3; - _testName = name; - } - - ~TransformBenchmark(){ - - if (_x == _z) { - delete _x; - } else { - delete _x; - delete _z; - } - } - - void executeOnce() override { - PointersManager manager(LaunchContext::defaultContext(), "TransformBM"); - - auto z = _z == nullptr ? _x : _z; - - switch (_opType) { - case 0: - NativeOpExecutioner::execTransformStrict(LaunchContext::defaultContext(), _opNum, _x->buffer(), _x->shapeInfo(), _x->specialBuffer(), _x->specialShapeInfo(), z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), nullptr, nullptr, nullptr); - break; - case 1: - NativeOpExecutioner::execTransformSame(LaunchContext::defaultContext(), _opNum, _x->buffer(), _x->shapeInfo(), _x->specialBuffer(), _x->specialShapeInfo(), z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), nullptr, nullptr, nullptr); - break; - case 2: - NativeOpExecutioner::execTransformAny(LaunchContext::defaultContext(), _opNum, _x->buffer(), _x->shapeInfo(), _x->specialBuffer(), _x->specialShapeInfo(), z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), nullptr, nullptr, nullptr); - break; - case 3: - NativeOpExecutioner::execTransformFloat(LaunchContext::defaultContext(), _opNum, _x->buffer(), _x->shapeInfo(), _x->specialBuffer(), _x->specialShapeInfo(), z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), nullptr, nullptr, nullptr); - break; - } - - manager.synchronize(); - } - - std::string axis() override { - return "N/A"; - } - - std::string orders() override { - std::string result; - result += _x->ordering(); - result += "/"; - result += _z == nullptr ? _x->ordering() : _z->ordering(); - return result; - } - - std::string strides() override { - std::string result; - result += ShapeUtils::strideAsString(_x); - result += "/"; - result += _z == nullptr ? ShapeUtils::strideAsString(_x) : ShapeUtils::strideAsString(_z); - return result; - } - - std::string inplace() override { - return _x == _z ? "true" : "false"; - } - - OpBenchmark* clone() override { - return new TransformBenchmark(_opNum, _opType, _testName, _x, _z); - } - }; -} - -#endif //DEV_TESTS_SCALARBENCHMARK_H \ No newline at end of file +class SD_EXPORT TransformBenchmark : public OpBenchmark { + protected: + int _opType; // 0=StrictOps, 1=Same, 2=Any, 3=Float + + public: + TransformBenchmark() : OpBenchmark() { + // + } + + TransformBenchmark(int opNum, int opType, const std::string &testName, + const NDArray &x, const NDArray &z) + : OpBenchmark(testName, x, z) { + _opNum = opNum; + _opType = opType; + } + + TransformBenchmark(transform::StrictOps op, const std::string &testName, + const NDArray &x, const NDArray &z) + : OpBenchmark(testName, x, z) { + _opNum = (int)op; + _opType = 0; + } + + TransformBenchmark(transform::StrictOps op, const std::string &name) + : OpBenchmark() { + _opNum = (int)op; + _opType = 0; + _testName = name; + } + + TransformBenchmark(transform::SameOps op, const std::string &name) + : OpBenchmark() { + _opNum = (int)op; + _opType = 1; + _testName = name; + } + + TransformBenchmark(transform::AnyOps op, const std::string &name) + : OpBenchmark() { + _opNum = (int)op; + _opType = 2; + _testName = name; + } + + TransformBenchmark(transform::FloatOps op, const std::string &name) + : OpBenchmark() { + _opNum = (int)op; + _opType = 3; + _testName = name; + } + + ~TransformBenchmark() {} + + void executeOnce() override { + PointersManager manager(LaunchContext::defaultContext(), "TransformBM"); + + auto z = _z.shapeInfo() == nullptr ? _x : _z; + + switch (_opType) { + case 0: + NativeOpExecutioner::execTransformStrict( + LaunchContext::defaultContext(), _opNum, _x.buffer(), + _x.shapeInfo(), _x.specialBuffer(), _x.specialShapeInfo(), + _z.buffer(), _z.shapeInfo(), _z.specialBuffer(), + _z.specialShapeInfo(), nullptr, nullptr, nullptr); + break; + case 1: + NativeOpExecutioner::execTransformSame( + LaunchContext::defaultContext(), _opNum, _x.buffer(), + _x.shapeInfo(), _x.specialBuffer(), _x.specialShapeInfo(), + _z.buffer(), _z.shapeInfo(), _z.specialBuffer(), + _z.specialShapeInfo(), nullptr, nullptr, nullptr); + break; + case 2: + NativeOpExecutioner::execTransformAny( + LaunchContext::defaultContext(), _opNum, _x.buffer(), + _x.shapeInfo(), _x.specialBuffer(), _x.specialShapeInfo(), + _z.buffer(), _z.shapeInfo(), _z.specialBuffer(), + _z.specialShapeInfo(), nullptr, nullptr, nullptr); + break; + case 3: + NativeOpExecutioner::execTransformFloat( + LaunchContext::defaultContext(), _opNum, _x.buffer(), + _x.shapeInfo(), _x.specialBuffer(), _x.specialShapeInfo(), + _z.buffer(), _z.shapeInfo(), _z.specialBuffer(), + _z.specialShapeInfo(), nullptr, nullptr, nullptr); + break; + } + + manager.synchronize(); + } + + std::string axis() override { return "N/A"; } + + std::string orders() override { + std::string result; + result += _x.ordering(); + result += "/"; + result += _z.shapeInfo() == nullptr ? _x.ordering() : _z.ordering(); + return result; + } + + std::string strides() override { + std::string result; + result += ShapeUtils::strideAsString(_x); + result += "/"; + result += _z.shapeInfo() == nullptr ? ShapeUtils::strideAsString(_x) + : ShapeUtils::strideAsString(_z); + return result; + } + + std::string inplace() override { return _x == _z ? "true" : "false"; } + + OpBenchmark *clone() override { + return new TransformBenchmark(_opNum, _opType, _testName, _x, _z); + } +}; +} // namespace sd + +#endif // SD_SCALARBENCHMARK_H \ No newline at end of file diff --git a/libnd4j/include/helpers/biDiagonalUp.h b/libnd4j/include/helpers/biDiagonalUp.h index dc44057a9942..2a4c839b69d5 100644 --- a/libnd4j/include/helpers/biDiagonalUp.h +++ b/libnd4j/include/helpers/biDiagonalUp.h @@ -15,66 +15,62 @@ ******************************************************************************/ // -// Created by Yurii Shyrma on 18.12.2017. +// @author Yurii Shyrma (iuriish@yahoo.com) // #ifndef LIBND4J_BIDIAGONALUP_H #define LIBND4J_BIDIAGONALUP_H -#include #include +#include namespace sd { namespace ops { namespace helpers { - class BiDiagonalUp { - - public: - - NDArray _HHmatrix; // 2D Householder matrix - NDArray _HHbidiag; // vector which contains Householder coefficients - NDArray _hhCoeffs; // vector of Householder coefficients - - /** - * constructor - * - * matrix - input matrix expected to be bi-diagonalized, remains unaffected - */ - BiDiagonalUp(const NDArray& matrix); - - /** - * this method evaluates data (coeff, normX, tail) used in Householder transformation - * formula for Householder matrix: P = identity_matrix - coeff * w * w^T - * P * x = [normX, 0, 0 , 0, ...] - * coeff - scalar - * w = [1, w1, w2, w3, ...], "tail" is w except first unity element, that is "tail" = [w1, w2, w3, ...] - * tail and coeff are stored in _HHmatrix - * normX are stored in _HHbidiag - */ - template - void _evalData(); - - void evalData(); - - /** - * this method evaluates product of Householder sequence matrices (transformations) acting on columns - * - * type - type of sequence, type = 'u' (acting on columns) or type = 'v' (acting on rows) - */ - template - HHsequence makeHHsequence_(const char type); - - HHsequence makeHHsequence(const char type); - + public: + NDArray _HHmatrix; // 2D Householder matrix + NDArray _HHbidiag; // vector which contains Householder coefficients + NDArray _hhCoeffs; // vector of Householder coefficients + + + /** + * constructor + * + * matrix - input matrix expected to be bi-diagonalized, remains unaffected + */ + BiDiagonalUp(const NDArray& matrix); + + /** + * this method evaluates data (coeff, normX, tail) used in Householder transformation + * formula for Householder matrix: P = identity_matrix - coeff * w * w^T + * P * x = [normX, 0, 0 , 0, ...] + * coeff - scalar + * w = [1, w1, w2, w3, ...], "tail" is w except first unity element, that is "tail" = [w1, w2, w3, ...] + * tail and coeff are stored in _HHmatrix + * normX are stored in _HHbidiag + */ + template + void _evalData(); + + void evalData(); + + /** + * this method evaluates product of Householder sequence matrices + * (transformations) acting on columns + * + * type - type of sequence, type = 'u' (acting on columns) or type = 'v' + * (acting on rows) + */ + template + HHsequence makeHHsequence_(const char type); + + HHsequence makeHHsequence(const char type); }; +} // namespace helpers +} // namespace ops +} // namespace sd - -} -} -} - - -#endif //LIBND4J_BIDIAGONALUP_H +#endif // LIBND4J_BIDIAGONALUP_H diff --git a/libnd4j/include/helpers/cpu/ConstantHelper.cpp b/libnd4j/include/helpers/cpu/ConstantHelper.cpp index be6eff65c403..a15d64f0cd0c 100644 --- a/libnd4j/include/helpers/cpu/ConstantHelper.cpp +++ b/libnd4j/include/helpers/cpu/ConstantHelper.cpp @@ -21,114 +21,127 @@ #ifndef __CUDABLAS__ -#include #include -#include +#include #include #include +#include + #include #include namespace sd { - ConstantHelper::ConstantHelper() { - int numDevices = getNumberOfDevices(); - _cache.resize(numDevices); - _counters.resize(numDevices); - for (int e = 0; e < numDevices; e++) { - MAP_IMPL map; +ConstantHelper::ConstantHelper() { + int numDevices = getNumberOfDevices(); + _cache.resize(numDevices); + _counters.resize(numDevices); + for (int e = 0; e < numDevices; e++) { + MAP_IMPL map; - _cache[e] = map; - _counters[e] = 0L; - } - } + _cache[e] = map; + _counters[e] = 0L; + } +} -ConstantHelper::~ConstantHelper() { +ConstantHelper ::~ConstantHelper() { for (const auto &v:_cache) { for (const auto &c:v) { delete c.second; - } + } } } ConstantHelper& ConstantHelper::getInstance() { static ConstantHelper instance; return instance; - } - - void* ConstantHelper::replicatePointer(void *src, size_t numBytes, memory::Workspace *workspace) { - if (workspace == nullptr) { - auto deviceId = getCurrentDevice(); - _counters[deviceId] += numBytes; - } - - int8_t *ptr = nullptr; - ALLOCATE(ptr, workspace, numBytes, int8_t); - - std::memcpy(ptr, src, numBytes); - return ptr; - } - - int ConstantHelper::getCurrentDevice() { - return AffinityManager::currentDeviceId(); - } +} - int ConstantHelper::getNumberOfDevices() { - return AffinityManager::numberOfDevices(); - } +void *ConstantHelper::replicatePointer(void *src, size_t numBytes, + memory::Workspace *workspace) { + if (workspace == nullptr) { + auto deviceId = getCurrentDevice(); + _counters[deviceId] += numBytes; + } - ConstantDataBuffer* ConstantHelper::constantBuffer(const ConstantDescriptor &descriptor, sd::DataType dataType) { - const auto deviceId = getCurrentDevice(); + int8_t *ptr = nullptr; + ALLOCATE(ptr, workspace, numBytes, int8_t); - // we're locking away cache modification - _mutexHolder.lock(); + std::memcpy(ptr, src, numBytes); + return ptr; +} - if (_cache[deviceId].count(descriptor) == 0) { - _cache[deviceId][descriptor] = new ConstantHolder(); - } +int ConstantHelper::getCurrentDevice() { + return AffinityManager::currentDeviceId(); +} - auto holder = _cache[deviceId][descriptor]; +int ConstantHelper::getNumberOfDevices() { + return AffinityManager::numberOfDevices(); +} - // releasing cache lock - _mutexHolder.unlock(); +ConstantDataBuffer *ConstantHelper::constantBuffer( + const ConstantDescriptor &descriptor, sd::DataType dataType) { + const auto deviceId = getCurrentDevice(); + // we're locking away cache modification + _mutexHolder.lock(); - ConstantDataBuffer* result; + if (_cache[deviceId].count(descriptor) == 0) { + _cache[deviceId][descriptor] = new ConstantHolder(); + } - // access to this holder instance is synchronous - holder->mutex()->lock(); + auto holder = _cache[deviceId][descriptor]; + + // releasing cache lock + _mutexHolder.unlock(); + + ConstantDataBuffer *result; + + // access to this holder instance is synchronous + holder->mutex()->lock(); + + if (holder->hasBuffer(dataType)) + result = holder->getConstantDataBuffer(dataType); + else { + auto size = descriptor.length() * DataTypeUtils::sizeOf(dataType); + auto cbuff = std::make_shared(new int8_t[size], std::make_shared()); + _counters[deviceId] += size; + + // create buffer with this dtype + if (descriptor.isFloat()) { + BUILD_DOUBLE_SELECTOR( + sd::DataType::DOUBLE, dataType, sd::TypeCast::convertGeneric, + (nullptr, const_cast(descriptor.floatValues().data()), + descriptor.length(), cbuff->pointer()), + (sd::DataType::DOUBLE, double), LIBND4J_TYPES); + } else if (descriptor.isInteger()) { + BUILD_DOUBLE_SELECTOR( + sd::DataType::INT64, dataType, sd::TypeCast::convertGeneric, + (nullptr, const_cast(descriptor.integerValues().data()), + descriptor.length(), cbuff->pointer()), + (sd::DataType::INT64, Nd4jLong), LIBND4J_TYPES); + } - if (holder->hasBuffer(dataType)) - result = holder->getConstantDataBuffer(dataType); - else { - auto size = descriptor.length() * DataTypeUtils::sizeOf(dataType); - auto cbuff = std::make_shared(new int8_t[size], std::make_shared()); - _counters[deviceId] += size; + ConstantDataBuffer dataBuffer(cbuff, descriptor.length(), + dataType); + holder->addBuffer(dataBuffer, dataType); - // create buffer with this dtype - if (descriptor.isFloat()) { - BUILD_DOUBLE_SELECTOR(sd::DataType::DOUBLE, dataType, sd::TypeCast::convertGeneric, (nullptr, const_cast(descriptor.floatValues().data()), descriptor.length(), cbuff->pointer()), (sd::DataType::DOUBLE, double), LIBND4J_TYPES); - } else if (descriptor.isInteger()) { - BUILD_DOUBLE_SELECTOR(sd::DataType::INT64, dataType, sd::TypeCast::convertGeneric, (nullptr, const_cast(descriptor.integerValues().data()), descriptor.length(), cbuff->pointer()), (sd::DataType::INT64, Nd4jLong), LIBND4J_TYPES); - } + result = holder->getConstantDataBuffer(dataType); + } + holder->mutex()->unlock(); - ConstantDataBuffer dataBuffer(cbuff, descriptor.length(), dataType); - holder->addBuffer(dataBuffer, dataType); + return result; +} - result = holder->getConstantDataBuffer(dataType); - } - holder->mutex()->unlock(); +Nd4jLong ConstantHelper::getCachedAmount(int deviceId) { + int numDevices = getNumberOfDevices(); + if (deviceId > numDevices || deviceId < 0) + return 0L; + else + return _counters[deviceId]; +} - return result; - } - Nd4jLong ConstantHelper::getCachedAmount(int deviceId) { - int numDevices = getNumberOfDevices(); - if (deviceId > numDevices || deviceId < 0) - return 0L; - else - return _counters[deviceId]; - } -} +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/helpers/cpu/ConstantShapeHelper.cpp b/libnd4j/include/helpers/cpu/ConstantShapeHelper.cpp index 528527f363cb..538208e90e1d 100644 --- a/libnd4j/include/helpers/cpu/ConstantShapeHelper.cpp +++ b/libnd4j/include/helpers/cpu/ConstantShapeHelper.cpp @@ -21,37 +21,41 @@ #ifndef __CUDABLAS__ #include -#include #include #include +#include #include namespace sd { - ConstantShapeHelper::ConstantShapeHelper() { - _cache.resize(32); - for (int e = 0; e < 32; e++) { - MAP_IMPL cache; - _cache[e] = cache; - } - } +ConstantShapeHelper::ConstantShapeHelper() { + _cache.resize(32); + for (int e = 0; e < 32; e++) { + MAP_IMPL cache; + _cache[e] = cache; + } +} - ConstantShapeHelper& ConstantShapeHelper::getInstance() { - static ConstantShapeHelper instance; - return instance; - } +ConstantShapeHelper& ConstantShapeHelper::getInstance() { + static ConstantShapeHelper instance; -ConstantShapeBuffer& ConstantShapeHelper::bufferForShapeInfo(sd::DataType dataType, char order, const std::vector &shape) { - ShapeDescriptor descriptor(dataType, order, shape); - return bufferForShapeInfo(descriptor); - } + return instance; +} -ConstantShapeBuffer& ConstantShapeHelper::bufferForShapeInfo(const sd::DataType dataType, const char order, const int rank, const Nd4jLong* shape) { - ShapeDescriptor descriptor(dataType, order, shape, rank); - return bufferForShapeInfo(descriptor); - } +ConstantShapeBuffer& ConstantShapeHelper::bufferForShapeInfo( + sd::DataType dataType, char order, const std::vector& shape) { + ShapeDescriptor descriptor(dataType, order, shape); + return bufferForShapeInfo(descriptor); +} +ConstantShapeBuffer& ConstantShapeHelper::bufferForShapeInfo( + const sd::DataType dataType, const char order, const int rank, + const Nd4jLong* shape) { + ShapeDescriptor descriptor(dataType, order, shape, rank); + return bufferForShapeInfo(descriptor); +} -ConstantShapeBuffer& ConstantShapeHelper::bufferForShapeInfo(const ShapeDescriptor &descriptor) { +ConstantShapeBuffer& ConstantShapeHelper::bufferForShapeInfo( + const ShapeDescriptor& descriptor) { int deviceId = 0; std::lock_guard lock(_mutex); @@ -67,122 +71,136 @@ ConstantShapeBuffer& ConstantShapeHelper::bufferForShapeInfo(const ShapeDescript } } -ConstantShapeBuffer& ConstantShapeHelper::bufferForShapeInfo(const Nd4jLong *shapeInfo) { - ShapeDescriptor descriptor(shapeInfo); - return bufferForShapeInfo(descriptor); - } - - bool ConstantShapeHelper::checkBufferExistenceForShapeInfo(ShapeDescriptor &descriptor) { - bool result; - int deviceId = 0; - std::lock_guard lock(_mutex); +ConstantShapeBuffer& ConstantShapeHelper::bufferForShapeInfo( + const Nd4jLong* shapeInfo) { + ShapeDescriptor descriptor(shapeInfo); + return bufferForShapeInfo(descriptor); +} - return _cache[deviceId].count(descriptor) != 0; - } +bool ConstantShapeHelper::checkBufferExistenceForShapeInfo( + ShapeDescriptor& descriptor) { + bool result; + int deviceId = 0; + std::lock_guard lock(_mutex); - const Nd4jLong* ConstantShapeHelper::createShapeInfo(const sd::DataType dataType, const char order, const int rank, const Nd4jLong* shape) { - ShapeDescriptor descriptor(dataType, order, shape, rank); - return bufferForShapeInfo(descriptor).primary(); - } + return _cache[deviceId].count(descriptor) != 0; +} - const Nd4jLong* ConstantShapeHelper::createShapeInfo(const sd::DataType dataType, const Nd4jLong* shapeInfo) { - return ConstantShapeHelper::createShapeInfo(dataType, shape::order(shapeInfo), shape::rank(shapeInfo), shape::shapeOf(const_cast(shapeInfo))); - } +const Nd4jLong* ConstantShapeHelper::createShapeInfo( + const sd::DataType dataType, const char order, const int rank, + const Nd4jLong* shape) { + ShapeDescriptor descriptor(dataType, order, shape, rank); + return bufferForShapeInfo(descriptor).primary(); +} - const Nd4jLong* ConstantShapeHelper::emptyShapeInfo(const sd::DataType dataType) { - auto descriptor = ShapeDescriptor::emptyDescriptor(dataType); - return bufferForShapeInfo(descriptor).primary(); - } +const Nd4jLong* ConstantShapeHelper::createShapeInfo( + const sd::DataType dataType, const Nd4jLong* shapeInfo) { + return ConstantShapeHelper::createShapeInfo( + dataType, shape::order(shapeInfo), shape::rank(shapeInfo), + shape::shapeOf(const_cast(shapeInfo))); +} - const Nd4jLong* ConstantShapeHelper::scalarShapeInfo(const sd::DataType dataType) { - auto descriptor = ShapeDescriptor::scalarDescriptor(dataType); - return bufferForShapeInfo(descriptor).primary(); - } +const Nd4jLong* ConstantShapeHelper::emptyShapeInfo( + const sd::DataType dataType) { + auto descriptor = ShapeDescriptor::emptyDescriptor(dataType); + return bufferForShapeInfo(descriptor).primary(); +} - const Nd4jLong* ConstantShapeHelper::vectorShapeInfo(const Nd4jLong length, const sd::DataType dataType) { - auto descriptor = ShapeDescriptor::vectorDescriptor(length, dataType); - return bufferForShapeInfo(descriptor).primary(); - } +const Nd4jLong* ConstantShapeHelper::scalarShapeInfo( + const sd::DataType dataType) { + auto descriptor = ShapeDescriptor::scalarDescriptor(dataType); + return bufferForShapeInfo(descriptor).primary(); +} - const Nd4jLong* ConstantShapeHelper::createShapeInfo(const sd::DataType dataType, const char order, const std::vector &shape) { - ShapeDescriptor descriptor(dataType, order, shape); - return bufferForShapeInfo(descriptor).primary(); - } +const Nd4jLong* ConstantShapeHelper::vectorShapeInfo( + const Nd4jLong length, const sd::DataType dataType) { + auto descriptor = ShapeDescriptor::vectorDescriptor(length, dataType); + return bufferForShapeInfo(descriptor).primary(); +} - const Nd4jLong* ConstantShapeHelper::createShapeInfo(const ShapeDescriptor &descriptor) { - return bufferForShapeInfo(descriptor).primary(); - } +const Nd4jLong* ConstantShapeHelper::createShapeInfo( + const sd::DataType dataType, const char order, + const std::vector& shape) { + ShapeDescriptor descriptor(dataType, order, shape); + return bufferForShapeInfo(descriptor).primary(); +} - const Nd4jLong* ConstantShapeHelper::createFromExisting(Nd4jLong *shapeInfo, bool destroyOriginal) { - ShapeDescriptor descriptor(shapeInfo); - auto result = createShapeInfo(descriptor); +const Nd4jLong* ConstantShapeHelper::createShapeInfo( + const ShapeDescriptor& descriptor) { + return bufferForShapeInfo(descriptor).primary(); +} - if (destroyOriginal) - RELEASE(shapeInfo, nullptr) +const Nd4jLong* ConstantShapeHelper::createFromExisting(Nd4jLong* shapeInfo, + bool destroyOriginal) { + ShapeDescriptor descriptor(shapeInfo); + auto result = createShapeInfo(descriptor); - return result; - } + if (destroyOriginal) RELEASE(shapeInfo, nullptr) - const Nd4jLong* ConstantShapeHelper::createFromExisting(Nd4jLong *shapeInfo, sd::memory::Workspace *workspace) { - ShapeDescriptor descriptor(shapeInfo); - auto result = createShapeInfo(descriptor); + return result; +} - RELEASE(shapeInfo, workspace); +const Nd4jLong* ConstantShapeHelper::createFromExisting( + Nd4jLong* shapeInfo, sd::memory::Workspace* workspace) { + ShapeDescriptor descriptor(shapeInfo); + auto result = createShapeInfo(descriptor); - return result; - } + RELEASE(shapeInfo, workspace); + return result; +} //////////////////////////////////////////////////////////////////////// -ConstantShapeBuffer& ConstantShapeHelper::createShapeInfoWithUnitiesForBroadcast(const Nd4jLong* maxShapeInfo, const Nd4jLong* minShapeInfo, sd::memory::Workspace* workspace, const std::vector &dimensions) { - - Nd4jLong* newShapeInfo = nullptr; - ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(shape::rank(maxShapeInfo)), Nd4jLong); - - newShapeInfo[0] = shape::rank(maxShapeInfo); - - sd::ArrayOptions::copyDataType(newShapeInfo, minShapeInfo); // type - newShapeInfo[2 * newShapeInfo[0] + 2] = shape::elementWiseStride(minShapeInfo); // ews - newShapeInfo[2 * newShapeInfo[0] + 3] = shape::order(minShapeInfo); // order - - if(!dimensions.empty()) { - - for(uint k = 0, j = 0, i = 0; i < shape::rank(maxShapeInfo); ++i) { - - if(j < dimensions.size() && dimensions[j] == i) { - shape::shapeOf(newShapeInfo)[i] = shape::shapeOf(minShapeInfo)[k]; - shape::stride(newShapeInfo)[i] = shape::stride(minShapeInfo)[k++]; - ++j; - } - else{ - shape::shapeOf(newShapeInfo)[i] = 1; - shape::stride(newShapeInfo)[i] = 0; - if(shape::sizeAt(minShapeInfo, k) == 1 && dimensions.size() != shape::rank(minShapeInfo)) - ++k; - } - } +ConstantShapeBuffer& ConstantShapeHelper::createShapeInfoWithUnitiesForBroadcast( + const Nd4jLong* maxShapeInfo, const Nd4jLong* minShapeInfo, + sd::memory::Workspace* workspace, const std::vector& dimensions) { + Nd4jLong* newShapeInfo = nullptr; + ALLOCATE(newShapeInfo, workspace, + shape::shapeInfoLength(shape::rank(maxShapeInfo)), Nd4jLong); + + newShapeInfo[0] = shape::rank(maxShapeInfo); + + sd::ArrayOptions::copyDataType(newShapeInfo, minShapeInfo); // type + newShapeInfo[2 * newShapeInfo[0] + 2] = + shape::elementWiseStride(minShapeInfo); // ews + newShapeInfo[2 * newShapeInfo[0] + 3] = shape::order(minShapeInfo); // order + + if (!dimensions.empty()) { + for (uint k = 0, j = 0, i = 0; i < shape::rank(maxShapeInfo); ++i) { + if (j < dimensions.size() && dimensions[j] == i) { + shape::shapeOf(newShapeInfo)[i] = shape::shapeOf(minShapeInfo)[k]; + shape::stride(newShapeInfo)[i] = shape::stride(minShapeInfo)[k++]; + ++j; + } else { + shape::shapeOf(newShapeInfo)[i] = 1; + shape::stride(newShapeInfo)[i] = 0; + if (shape::sizeAt(minShapeInfo, k) == 1 && + dimensions.size() != shape::rank(minShapeInfo)) + ++k; + } } - else{ - - for(int j = shape::rank(minShapeInfo) - 1, i = shape::rank(maxShapeInfo) - 1; i >=0 ; --i) { - - if(j >= 0) { - shape::shapeOf(newShapeInfo)[i] = shape::shapeOf(minShapeInfo)[j]; - shape::stride(newShapeInfo)[i] = shape::shapeOf(minShapeInfo)[j] == 1 ? 0 : shape::stride(minShapeInfo)[j]; - --j; - } - else { - shape::shapeOf(newShapeInfo)[i] = 1; - shape::stride(newShapeInfo)[i] = 0; - } - } + } else { + for (int j = shape::rank(minShapeInfo) - 1, + i = shape::rank(maxShapeInfo) - 1; + i >= 0; --i) { + if (j >= 0) { + shape::shapeOf(newShapeInfo)[i] = shape::shapeOf(minShapeInfo)[j]; + shape::stride(newShapeInfo)[i] = shape::shapeOf(minShapeInfo)[j] == 1 + ? 0 + : shape::stride(minShapeInfo)[j]; + --j; + } else { + shape::shapeOf(newShapeInfo)[i] = 1; + shape::stride(newShapeInfo)[i] = 0; + } } + } - ShapeDescriptor descriptor(newShapeInfo); + ShapeDescriptor descriptor(newShapeInfo); - RELEASE(newShapeInfo, workspace); + RELEASE(newShapeInfo, workspace); - return bufferForShapeInfo(descriptor); + return bufferForShapeInfo(descriptor); } } // namespace sd diff --git a/libnd4j/include/helpers/cpu/ConstantTadHelper.cpp b/libnd4j/include/helpers/cpu/ConstantTadHelper.cpp index 9f859ee3e7d0..767aa7d35fcc 100644 --- a/libnd4j/include/helpers/cpu/ConstantTadHelper.cpp +++ b/libnd4j/include/helpers/cpu/ConstantTadHelper.cpp @@ -18,73 +18,90 @@ // @author raver119@gmail.com // -#include "../ConstantTadHelper.h" -#include +#include #include #include #include +#include #ifndef __CUDABLAS__ - namespace sd { - ConstantTadHelper::ConstantTadHelper() { - MAP_IMPL pack; - _cache.emplace_back(pack); - } - - ConstantTadHelper& ConstantTadHelper::getInstance() { - static ConstantTadHelper instance; - return instance; - } - - TadPack ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, int dimension, const bool keepUnitiesInShape) { - return tadForDimensions(originalShape, &dimension, 1, keepUnitiesInShape); - } - - TadPack ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, const std::vector &dimensions, const bool keepUnitiesInShape) { - return tadForDimensions(originalShape, const_cast(dimensions.data()), dimensions.size(), keepUnitiesInShape); - } - - TadPack ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, int* dimensions, int dimLength, const bool keepUnitiesInShape) { - TadDescriptor tadDescriptor(originalShape, dimensions, dimLength, keepUnitiesInShape); - return tadForDimensions(tadDescriptor); - } - - TadPack ConstantTadHelper::tadForDimensions(ShapeDescriptor &descriptor, std::vector &dimensions, const bool keepUnitiesInShape) { - TadDescriptor tadDescriptor(descriptor, dimensions, keepUnitiesInShape); - return tadForDimensions(tadDescriptor); - } - - TadPack ConstantTadHelper::tadForDimensions(TadDescriptor &descriptor) { - const int deviceId = 0; - - std::lock_guard lock(_mutex); - if (_cache[deviceId].count(descriptor) == 0) { - // if there's no TadPack matching this descriptor - create one - const auto shapeInfo = descriptor.originalShape().toShapeInfo(); - const int rank = shape::rank(shapeInfo); - const std::vector dimsToExclude = ShapeUtils::evalDimsToExclude(rank, descriptor.axis()); - const Nd4jLong numOfSubArrs = ShapeUtils::getNumOfSubArrs(shapeInfo, dimsToExclude); - const int subArrRank = (rank == dimsToExclude.size() || descriptor.areUnitiesinShape()) ? rank : rank - dimsToExclude.size(); - - auto sPtr = std::make_shared(new Nd4jLong[shape::shapeInfoLength(subArrRank)], std::make_shared()); // shape of sub-arrays (same for all for them) - auto oPtr = std::make_shared(new Nd4jLong[numOfSubArrs], std::make_shared()); - - if (numOfSubArrs > 0) - shape::calcSubArrsShapeInfoAndOffsets(shapeInfo, numOfSubArrs, dimsToExclude.size(), dimsToExclude.data(), sPtr->pointerAsT(), oPtr->pointerAsT(), descriptor.areUnitiesinShape()); - - ConstantShapeBuffer shapeBuffer(sPtr); - ConstantOffsetsBuffer offsetsBuffer(oPtr); - TadPack t(shapeBuffer, offsetsBuffer, numOfSubArrs); - _cache[deviceId][descriptor] = t; - - delete[] shapeInfo; - } - - return _cache[deviceId][descriptor]; - } +ConstantTadHelper::ConstantTadHelper() { + MAP_IMPL pack; + _cache.emplace_back(pack); } +ConstantTadHelper &ConstantTadHelper::getInstance() { + static ConstantTadHelper instance; + + return instance; +} + +TadPack ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, + int dimension, + const bool keepUnitiesInShape) { + return tadForDimensions(originalShape, &dimension, 1, keepUnitiesInShape); +} + +TadPack ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, + const std::vector &dimensions, + const bool keepUnitiesInShape) { + return tadForDimensions(originalShape, const_cast(dimensions.data()), + dimensions.size(), keepUnitiesInShape); +} + +TadPack ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, + int *dimensions, int dimLength, + const bool keepUnitiesInShape) { + TadDescriptor tadDescriptor(originalShape, dimensions, dimLength, + keepUnitiesInShape); + return tadForDimensions(tadDescriptor); +} + +TadPack ConstantTadHelper::tadForDimensions(ShapeDescriptor &descriptor, + std::vector &dimensions, + const bool keepUnitiesInShape) { + TadDescriptor tadDescriptor(descriptor, dimensions, keepUnitiesInShape); + return tadForDimensions(tadDescriptor); +} + +TadPack ConstantTadHelper::tadForDimensions(TadDescriptor &descriptor) { + const int deviceId = 0; + + std::lock_guard lock(_mutex); + if (_cache[deviceId].count(descriptor) == 0) { +// if there's no TadPack matching this descriptor - create one + const auto shapeInfo = descriptor.originalShape().toShapeInfo(); + const int rank = shape::rank(shapeInfo); + const std::vector dimsToExclude = + ShapeUtils::evalDimsToExclude(rank, descriptor.axis()); + const Nd4jLong numOfSubArrs = + ShapeUtils::getNumOfSubArrs(shapeInfo, dimsToExclude); + const int subArrRank = + (rank == dimsToExclude.size() || descriptor.areUnitiesinShape()) + ? rank + : rank - dimsToExclude.size(); + + auto sPtr = std::make_shared(new Nd4jLong[shape::shapeInfoLength( + subArrRank)], std::make_shared()); // shape of sub-arrays (same for all for them) + auto oPtr = std::make_shared(new Nd4jLong[numOfSubArrs], std::make_shared()); + + if (numOfSubArrs > 0) + shape::calcSubArrsShapeInfoAndOffsets(shapeInfo, numOfSubArrs, dimsToExclude.size(), dimsToExclude.data(), sPtr->pointerAsT(), oPtr->pointerAsT(), descriptor.areUnitiesinShape()); + + ConstantShapeBuffer shapeBuffer(sPtr); + ConstantOffsetsBuffer offsetsBuffer(oPtr); + TadPack t(shapeBuffer, offsetsBuffer, numOfSubArrs); + _cache[deviceId][descriptor] = t; + + delete[] shapeInfo; + } + + return _cache[deviceId][descriptor]; +} + +} // namespace sd + #endif \ No newline at end of file diff --git a/libnd4j/include/helpers/cpu/MmulHelper.cpp b/libnd4j/include/helpers/cpu/MmulHelper.cpp index 437eebe1d606..17a672c275b4 100644 --- a/libnd4j/include/helpers/cpu/MmulHelper.cpp +++ b/libnd4j/include/helpers/cpu/MmulHelper.cpp @@ -19,384 +19,429 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // #include "../MmulHelper.h" + #include -#include -#include #include #include - +#include +#include namespace sd { ////////////////////////////////////////////////////////////////////////////// // MXK x KxN = MxN -> actual sequence of axes doesn't matter template -static void usualGemm(const NDArray* vA, const NDArray* vB, NDArray* vC, - const int aMaxis, const int aKaxis, const int bKaxis, const int bNaxis, const int cMaxis, const int cNaxis, - const double alpha, const double beta) { - - const T1* A = vA->bufferAsT(); - const T2* B = vB->bufferAsT(); - T3* C = vC->bufferAsT(); - - const T3 alphaZ = alpha; - const T3 betaZ = beta; - - const bool betaPersent = beta; +static void usualGemm(const NDArray* vA, const NDArray* vB, NDArray* vC, + const int aMaxis, const int aKaxis, const int bKaxis, + const int bNaxis, const int cMaxis, const int cNaxis, + const double alpha, const double beta) { + const T1* A = vA->bufferAsT(); + const T2* B = vB->bufferAsT(); + T3* C = vC->bufferAsT(); - const Nd4jLong* aShapeInfo = vA->shapeInfo(); - const Nd4jLong* bShapeInfo = vB->shapeInfo(); - const Nd4jLong* cShapeInfo = vC->shapeInfo(); + const T3 alphaZ = alpha; + const T3 betaZ = beta; - const int aRank = vA->rankOf(); - const int bRank = vB->rankOf(); - const int cRank = vC->rankOf(); + const bool betaPersent = beta; - const Nd4jLong cLen = vC->lengthOf(); + const Nd4jLong* aShapeInfo = vA->shapeInfo(); + const Nd4jLong* bShapeInfo = vB->shapeInfo(); + const Nd4jLong* cShapeInfo = vC->shapeInfo(); - const int K = vA->sizeAt(aKaxis); + const int aRank = vA->rankOf(); + const int bRank = vB->rankOf(); + const int cRank = vC->rankOf(); - auto func = PRAGMA_THREADS_FOR { + const Nd4jLong cLen = vC->lengthOf(); - std::vector aCoords(2), bCoords(2), cCoords(2); + const int K = vA->sizeAt(aKaxis); - for (auto i = start; i < stop; ++i) { + auto func = PRAGMA_THREADS_FOR { + std::vector aCoords(2), bCoords(2), cCoords(2); - // evaluate C coordinates - shape::index2coordsCPU(start, i, cShapeInfo, cCoords.data()); + for (auto i = start; i < stop; ++i) { + // evaluate C coordinates + shape::index2coordsCPU(start, i, cShapeInfo, cCoords.data()); - // evaluate A coordinates - aCoords[aMaxis] = cCoords[cMaxis]; - aCoords[aKaxis] = 0; + // evaluate A coordinates + aCoords[aMaxis] = cCoords[cMaxis]; + aCoords[aKaxis] = 0; - // evaluate B coordinates - bCoords[bKaxis] = 0; - bCoords[bNaxis] = cCoords[cNaxis]; + // evaluate B coordinates + bCoords[bKaxis] = 0; + bCoords[bNaxis] = cCoords[cNaxis]; - auto aOffset = shape::getOffset(aShapeInfo, aCoords.data()); - auto bOffset = shape::getOffset(bShapeInfo, bCoords.data()); + auto aOffset = shape::getOffset(aShapeInfo, aCoords.data()); + auto bOffset = shape::getOffset(bShapeInfo, bCoords.data()); - T3 val = A[aOffset] * B[bOffset]; // first iteration + T3 val = A[aOffset] * B[bOffset]; // first iteration - for (int j = 1; j < K; ++j) { // rest iterations - aOffset += shape::stride(aShapeInfo)[aKaxis]; - bOffset += shape::stride(bShapeInfo)[bKaxis]; - val = val + A[aOffset] * B[bOffset]; - } + for (int j = 1; j < K; ++j) { // rest iterations + aOffset += shape::stride(aShapeInfo)[aKaxis]; + bOffset += shape::stride(bShapeInfo)[bKaxis]; + val = val + A[aOffset] * B[bOffset]; + } - auto cOffset = shape::getOffset(cShapeInfo, cCoords.data()); + auto cOffset = shape::getOffset(cShapeInfo, cCoords.data()); - if(betaPersent) - C[cOffset] = alphaZ * val + betaZ * C[cOffset]; - else - C[cOffset] = alphaZ * val; - } - }; + if (betaPersent) + C[cOffset] = alphaZ * val + betaZ * C[cOffset]; + else + C[cOffset] = alphaZ * val; + } + }; - samediff::Threads::parallel_tad(func, 0, cLen); + samediff::Threads::parallel_tad(func, 0, cLen); } - ////////////////////////////////////////////////////////////////////////////// // MXN x N = M -> actual sequence of {M,N} axes doesn't matter template -static void usualGemv(const NDArray* vA, const NDArray* vX, NDArray* vY, const int incx, const int incy, const int aMaxis, const double alpha, const double beta) { +static void usualGemv(const NDArray* vA, const NDArray* vX, NDArray* vY, + const int incx, const int incy, const int aMaxis, + const double alpha, const double beta) { + const T1* A = vA->bufferAsT(); + const T2* X = vX->bufferAsT(); + T3* Y = vY->bufferAsT(); - const T1* A = vA->bufferAsT(); - const T2* X = vX->bufferAsT(); - T3* Y = vY->bufferAsT(); + const T3 alphaZ = alpha; + const T3 betaZ = beta; - const T3 alphaZ = alpha; - const T3 betaZ = beta; + const bool betaPersent = beta; - const bool betaPersent = beta; + const Nd4jLong* aShapeInfo = vA->shapeInfo(); + const Nd4jLong* xShapeInfo = vX->shapeInfo(); + const Nd4jLong* yShapeInfo = vY->shapeInfo(); - const Nd4jLong* aShapeInfo = vA->shapeInfo(); - const Nd4jLong* xShapeInfo = vX->shapeInfo(); - const Nd4jLong* yShapeInfo = vY->shapeInfo(); + const int N = vX->lengthOf(); + const int M = vY->lengthOf(); - const int N = vX->lengthOf(); - const int M = vY->lengthOf(); + const auto aMstride = vA->strideAt(aMaxis); + const auto aNstride = vA->strideAt(aMaxis == 0 ? 1 : 0); - const auto aMstride = vA->strideAt(aMaxis); - const auto aNstride = vA->strideAt(aMaxis == 0 ? 1 : 0); - - auto func = PRAGMA_THREADS_FOR { + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; ++i) { + // evaluate offsets + auto aOffset = i * aMstride; + auto xOffset = 0; - for (auto i = start; i < stop; ++i) { + T3 val = A[aOffset] * X[xOffset]; // first iteration - // evaluate offsets - auto aOffset = i * aMstride; - auto xOffset = 0; + for (int j = 1; j < N; ++j) { // rest iterations + aOffset += aNstride; + xOffset += incx; + val = val + A[aOffset] * X[xOffset]; + } - T3 val = A[aOffset] * X[xOffset]; // first iteration + auto yOffset = i * incy; - for (int j = 1; j < N; ++j) { // rest iterations - aOffset += aNstride; - xOffset += incx; - val = val + A[aOffset] * X[xOffset]; - } - - auto yOffset = i * incy; - - if(betaPersent) - Y[yOffset] = alphaZ * val + betaZ * Y[yOffset]; - else - Y[yOffset] = alphaZ * val; - } - }; + if (betaPersent) + Y[yOffset] = alphaZ * val + betaZ * Y[yOffset]; + else + Y[yOffset] = alphaZ * val; + } + }; - samediff::Threads::parallel_tad(func, 0, M); + samediff::Threads::parallel_tad(func, 0, M); } ////////////////////////////////////////////////////////////////////////////// // (X*Y) = Z[0] template -static void usualDot(const Nd4jLong length, const double alpha, const void* vX, const Nd4jLong incx, const void* vY, const Nd4jLong incy, const double beta, void* vZ) { - - T1* X = reinterpret_cast(const_cast(vX)); - T2* Y = reinterpret_cast(const_cast(vY)); - T3* Z = reinterpret_cast(vZ); - T3 alphaZ(alpha), betaZ(beta); - - const bool betaPersent = beta; - - T3 sum = 0; - PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(length > Environment::getInstance().elementwiseThreshold()) schedule(guided) reduction(OMP_SUMT:sum)) - for(Nd4jLong i = 0; i < length; ++i) - sum += X[i * incx] * Y[i * incy]; - - if(betaPersent) - *Z = alphaZ * sum + betaZ * *Z; - else - *Z = alphaZ * sum; +static void usualDot(const Nd4jLong length, const double alpha, const void* vX, + const Nd4jLong incx, const void* vY, const Nd4jLong incy, + const double beta, void* vZ) { + T1* X = reinterpret_cast(const_cast(vX)); + T2* Y = reinterpret_cast(const_cast(vY)); + T3* Z = reinterpret_cast(vZ); + T3 alphaZ(alpha), betaZ(beta); + + const bool betaPersent = beta; + + T3 sum = 0; + PRAGMA_OMP_PARALLEL_FOR_ARGS( + OMP_IF(length > Environment::getInstance().elementwiseThreshold()) + schedule(guided) reduction(OMP_SUMT + : sum)) + for (Nd4jLong i = 0; i < length; ++i) sum += X[i * incx] * Y[i * incy]; + + if (betaPersent) + *Z = alphaZ * sum + betaZ * *Z; + else + *Z = alphaZ * sum; } ////////////////////////////////////////////////////////////////////////////// // MXK x KxN = MxN -NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, const double alpha, const double beta, const char outOrder) { - if (A->dataType() != B->dataType()) - throw datatype_exception::build("mmulMxM expects all data types to be the same", A->dataType(), B->dataType()); - - if (C != nullptr && A->dataType() != C->dataType()) - throw datatype_exception::build("mmulMxM expects all data types to be the same", A->dataType(), C->dataType()); - - if(A->rankOf() != 2) - throw std::runtime_error("MmulHelper::mmulMxM: rank of A array is not equal 2 !"); - if(B->rankOf() != 2) - throw std::runtime_error("MmulHelper::mmulMxM: rank of B array is not equal 2 !"); - - const auto M = A->sizeAt(0); - const auto K = A->sizeAt(1); - const auto N = B->sizeAt(1); - - if(C != nullptr && C->rankOf() != 2) - throw std::runtime_error("MmulHelper::mmulMxM: rank of C array is not equal 2 !"); - if(B->sizeAt(0) != K) - throw std::runtime_error("MmulHelper::mmulMxM: B array has wrong number of rows !"); - if(C != nullptr && C->sizeAt(0) != M) - throw std::runtime_error("MmulHelper::mmulMxM: C array has wrong number of rows !"); - if(C != nullptr && C->sizeAt(1) != N) - throw std::runtime_error("MmulHelper::mmulMxM: C array has wrong number of columns !"); - - if(C == nullptr) - C = new NDArray(outOrder, {M,N}, DataTypeUtils::pickPairwiseResultType(A->dataType(), B->dataType()), A->getContext()); - - if (C->isEmpty()) - return C; - - const auto aType = A->dataType(); - const auto bType = B->dataType(); - const auto cType = C->dataType(); - - const bool AB(aType == bType), AC(aType == cType), ABC(AB && AC); - const bool hasGemm = BlasHelper::getInstance().hasGEMM(aType); - - const bool typeDouble = hasGemm && ABC && aType == DataType::DOUBLE; - const bool typeFloat = hasGemm && ABC && aType == DataType::FLOAT32; - - if(!typeFloat && !typeDouble) { - BUILD_SINGLE_SELECTOR_THRICE(aType, usualGemm, (A, B, C, 0, 1, 0, 1, 0, 1, alpha, beta), NUMERIC_TYPES); - // BUILD_TRIPLE_SELECTOR(aType, bType, cType, usualGemm, (A, B, C, 0, 1, 0, 1, 0, 1, alpha, beta), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES); +NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, + const double alpha, const double beta, + const char outOrder) { + if (A->dataType() != B->dataType()) + throw datatype_exception::build( + "mmulMxM expects all data types to be the same", A->dataType(), + B->dataType()); + + if (C != nullptr && A->dataType() != C->dataType()) + throw datatype_exception::build( + "mmulMxM expects all data types to be the same", A->dataType(), + C->dataType()); + + if (A->rankOf() != 2) + throw std::runtime_error( + "MmulHelper::mmulMxM: rank of A array is not equal 2 !"); + if (B->rankOf() != 2) + throw std::runtime_error( + "MmulHelper::mmulMxM: rank of B array is not equal 2 !"); + + const auto M = A->sizeAt(0); + const auto K = A->sizeAt(1); + const auto N = B->sizeAt(1); + + if (C != nullptr && C->rankOf() != 2) + throw std::runtime_error( + "MmulHelper::mmulMxM: rank of C array is not equal 2 !"); + if (B->sizeAt(0) != K) + throw std::runtime_error( + "MmulHelper::mmulMxM: B array has wrong number of rows !"); + if (C != nullptr && C->sizeAt(0) != M) + throw std::runtime_error( + "MmulHelper::mmulMxM: C array has wrong number of rows !"); + if (C != nullptr && C->sizeAt(1) != N) + throw std::runtime_error( + "MmulHelper::mmulMxM: C array has wrong number of columns !"); + + if (C == nullptr) + C = new NDArray( + outOrder, {M, N}, + DataTypeUtils::pickPairwiseResultType(A->dataType(), B->dataType()), + A->getContext()); + + if (C->isEmpty()) return C; + + const auto aType = A->dataType(); + const auto bType = B->dataType(); + const auto cType = C->dataType(); + + const bool AB(aType == bType), AC(aType == cType), ABC(AB && AC); + const bool hasGemm = BlasHelper::getInstance().hasGEMM(aType); + + const bool typeDouble = hasGemm && ABC && aType == DataType::DOUBLE; + const bool typeFloat = hasGemm && ABC && aType == DataType::FLOAT32; + + if (!typeFloat && !typeDouble) { + BUILD_SINGLE_SELECTOR_THRICE(aType, usualGemm, + (A, B, C, 0, 1, 0, 1, 0, 1, alpha, beta), + NUMERIC_TYPES); + // BUILD_TRIPLE_SELECTOR(aType, bType, cType, usualGemm, (A, B, C, 0, 1, 0, + // 1, 0, 1, alpha, beta), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES); + } else { + std::vector toDelete; + + NDArray *pA(const_cast(A)), *pB(const_cast(B)), + *pC(const_cast(C)); + + bool aMcont = M == 1 || A->strideAt(0) == 1; + bool aKcont = K == 1 || A->strideAt(1) == 1; + bool bKcont = K == 1 || B->strideAt(0) == 1; + bool bNcont = N == 1 || B->strideAt(1) == 1; + bool cMcont = M == 1 || C->strideAt(0) == 1; + bool cNcont = N == 1 || C->strideAt(1) == 1; + + if (!aMcont && !aKcont) { + pA = new NDArray(A->dup('f')); + toDelete.push_back(pA); + aMcont = true; + } + if (!bKcont && !bNcont) { + pB = new NDArray(B->dup('f')); + toDelete.push_back(pB); + bKcont = true; + } + if (!cMcont && !cNcont) { + pC = new NDArray(C->dup('f')); + toDelete.push_back(pC); + cMcont = true; } - else { - - std::vector toDelete; - - NDArray *pA(const_cast(A)), *pB(const_cast(B)), *pC(const_cast(C)); - - bool aMcont = M == 1 || A->strideAt(0) == 1; - bool aKcont = K == 1 || A->strideAt(1) == 1; - bool bKcont = K == 1 || B->strideAt(0) == 1; - bool bNcont = N == 1 || B->strideAt(1) == 1; - bool cMcont = M == 1 || C->strideAt(0) == 1; - bool cNcont = N == 1 || C->strideAt(1) == 1; - - if(!aMcont && !aKcont) { - pA = new NDArray(A->dup('f')); - toDelete.push_back(pA); - aMcont = true; - } - if(!bKcont && !bNcont) { - pB = new NDArray(B->dup('f')); - toDelete.push_back(pB); - bKcont = true; - } - if(!cMcont && !cNcont) { - pC = new NDArray(C->dup('f')); - toDelete.push_back(pC); - cMcont = true; - } - - const CBLAS_ORDER blasOrder = cMcont ? CblasColMajor : CblasRowMajor; - - const bool transA = (!aMcont && cMcont) || (aMcont && !cMcont); - const bool transB = (!bKcont && cMcont) || (bKcont && !cMcont); - - const CBLAS_TRANSPOSE transAblas = transA ? CblasTrans : CblasNoTrans; - const CBLAS_TRANSPOSE transBblas = transB ? CblasTrans : CblasNoTrans; - - const int lda = (aMcont && aKcont) ? M : !aMcont ? pA->strideAt(0) : pA->strideAt(1); - const int ldb = (bKcont && bNcont) ? K : !bKcont ? pB->strideAt(0) : pB->strideAt(1); - const int ldc = (cMcont && cNcont) ? M : !cMcont ? pC->strideAt(0) : pC->strideAt(1); - if(typeFloat) { - BlasHelper::getInstance().sgemm()(blasOrder, transAblas, transBblas, M, N, K, (float) alpha, pA->bufferAsT(), lda, pB->bufferAsT(), ldb, (float) beta, pC->bufferAsT(), ldc); - } - else if(typeDouble) { - BlasHelper::getInstance().dgemm()(blasOrder, transAblas, transBblas, M, N, K, (double) alpha, pA->bufferAsT(), lda, pB->bufferAsT(), ldb, (double) beta, pC->bufferAsT(), ldc); - } + const CBLAS_ORDER blasOrder = cMcont ? CblasColMajor : CblasRowMajor; + + const bool transA = (!aMcont && cMcont) || (aMcont && !cMcont); + const bool transB = (!bKcont && cMcont) || (bKcont && !cMcont); + + const CBLAS_TRANSPOSE transAblas = transA ? CblasTrans : CblasNoTrans; + const CBLAS_TRANSPOSE transBblas = transB ? CblasTrans : CblasNoTrans; + + const int lda = + (aMcont && aKcont) ? M : !aMcont ? pA->strideAt(0) : pA->strideAt(1); + const int ldb = + (bKcont && bNcont) ? K : !bKcont ? pB->strideAt(0) : pB->strideAt(1); + const int ldc = + (cMcont && cNcont) ? M : !cMcont ? pC->strideAt(0) : pC->strideAt(1); + + if (typeFloat) { + BlasHelper::getInstance().sgemm()( + blasOrder, transAblas, transBblas, M, N, K, (float)alpha, + pA->bufferAsT(), lda, pB->bufferAsT(), ldb, (float)beta, + pC->bufferAsT(), ldc); + } else if (typeDouble) { + BlasHelper::getInstance().dgemm()( + blasOrder, transAblas, transBblas, M, N, K, (double)alpha, + pA->bufferAsT(), lda, pB->bufferAsT(), ldb, + (double)beta, pC->bufferAsT(), ldc); + } - if(pC != C) { - C->assign(pC); - delete pC; - } - if(pA != A) - delete pA; - if(pB != B) - delete pB; + if (pC != C) { + C->assign(pC); + delete pC; } + if (pA != A) delete pA; + if (pB != B) delete pB; + } - return C; + return C; } //////////////////////////////////////////////////////////////////////////// // MXN x N = M -NDArray* MmulHelper::mmulMxV(const NDArray* A, const NDArray* X, sd::NDArray* Y, const double alpha, const double beta, const char outOrder) { - - if (X->dataType() != A->dataType()) - throw datatype_exception::build("mmulMxV expects all data types to be the same", A->dataType(), X->dataType()); - - if (Y != nullptr && X->dataType() != Y->dataType()) - throw datatype_exception::build("mmulMxV expects all data types to be the same", A->dataType(), Y->dataType()); - - int xLenDim, yLenDim(0); - - if(A->rankOf() != 2) - throw std::runtime_error("MmulHelper::mmulMxV: rank of A array is not equal 2 !"); - if(!shape::isCommonVector(X->shapeInfo(), xLenDim)) - throw std::runtime_error("MmulHelper::mmulMxV: X array must be vector !"); - - const auto M = A->sizeAt(0); - const auto N = A->sizeAt(1); - - if(Y != nullptr && !shape::isCommonVector(Y->shapeInfo(), yLenDim)) - throw std::runtime_error("MmulHelper::mmulMxV: Y array must be vector !"); - if(X->lengthOf() != N) - throw std::runtime_error("MmulHelper::mmulMxV: X vector has wrong length !"); - if(Y != nullptr && Y->lengthOf() != M) - throw std::runtime_error("MmulHelper::mmulMxV: Y array has wrong length !"); - - if(Y == nullptr) - Y = new NDArray(outOrder, {M}, DataTypeUtils::pickPairwiseResultType(A->dataType(), X->dataType()), A->getContext()); - - if (Y->isEmpty()) - return Y; - - const int incx = X->stridesOf()[xLenDim]; - const int incy = Y->stridesOf()[yLenDim]; - - const auto aType = A->dataType(); - const auto xType = X->dataType(); - const auto yType = Y->dataType(); - - const bool AX(aType == xType), AY(aType == yType), AXY(AX && AY); - const bool hasGemv = BlasHelper::getInstance().hasGEMV(aType); - - const bool typeDouble = hasGemv && AXY && aType == DataType::DOUBLE; - const bool typeFloat = hasGemv && AXY && aType == DataType::FLOAT32; - - if(!typeDouble && !typeFloat) { - BUILD_SINGLE_SELECTOR_THRICE(aType, usualGemv, (A, X, Y, incx, incy, 0, alpha, beta), NUMERIC_TYPES); - // BUILD_TRIPLE_SELECTOR(aType, xType, yType, usualGemv, (A, X, Y, incx, incy, 0, alpha, beta), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES); +NDArray* MmulHelper::mmulMxV(const NDArray* A, const NDArray* X, sd::NDArray* Y, + const double alpha, const double beta, + const char outOrder) { + if (X->dataType() != A->dataType()) + throw datatype_exception::build( + "mmulMxV expects all data types to be the same", A->dataType(), + X->dataType()); + + if (Y != nullptr && X->dataType() != Y->dataType()) + throw datatype_exception::build( + "mmulMxV expects all data types to be the same", A->dataType(), + Y->dataType()); + + int xLenDim, yLenDim(0); + + if (A->rankOf() != 2) + throw std::runtime_error( + "MmulHelper::mmulMxV: rank of A array is not equal 2 !"); + if (!shape::isCommonVector(X->shapeInfo(), xLenDim)) + throw std::runtime_error("MmulHelper::mmulMxV: X array must be vector !"); + + const auto M = A->sizeAt(0); + const auto N = A->sizeAt(1); + + if (Y != nullptr && !shape::isCommonVector(Y->shapeInfo(), yLenDim)) + throw std::runtime_error("MmulHelper::mmulMxV: Y array must be vector !"); + if (X->lengthOf() != N) + throw std::runtime_error( + "MmulHelper::mmulMxV: X vector has wrong length !"); + if (Y != nullptr && Y->lengthOf() != M) + throw std::runtime_error("MmulHelper::mmulMxV: Y array has wrong length !"); + + if (Y == nullptr) + Y = new NDArray( + outOrder, {M}, + DataTypeUtils::pickPairwiseResultType(A->dataType(), X->dataType()), + A->getContext()); + + if (Y->isEmpty()) return Y; + + const int incx = X->stridesOf()[xLenDim]; + const int incy = Y->stridesOf()[yLenDim]; + + const auto aType = A->dataType(); + const auto xType = X->dataType(); + const auto yType = Y->dataType(); + + const bool AX(aType == xType), AY(aType == yType), AXY(AX && AY); + const bool hasGemv = BlasHelper::getInstance().hasGEMV(aType); + + const bool typeDouble = hasGemv && AXY && aType == DataType::DOUBLE; + const bool typeFloat = hasGemv && AXY && aType == DataType::FLOAT32; + + if (!typeDouble && !typeFloat) { + BUILD_SINGLE_SELECTOR_THRICE( + aType, usualGemv, (A, X, Y, incx, incy, 0, alpha, beta), NUMERIC_TYPES); + // BUILD_TRIPLE_SELECTOR(aType, xType, yType, usualGemv, (A, X, Y, incx, + // incy, 0, alpha, beta), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES); + } else { + NDArray* pA(const_cast(A)); + + bool aMcont = M == 1 || A->strideAt(0) == 1; + bool aNcont = N == 1 || A->strideAt(1) == 1; + + if (!aMcont && !aNcont) { + pA = new NDArray(A->dup('f')); + aMcont = true; } - else { - - NDArray *pA(const_cast(A)); - - bool aMcont = M == 1 || A->strideAt(0) == 1; - bool aNcont = N == 1 || A->strideAt(1) == 1; - - if(!aMcont && !aNcont) { - pA = new NDArray(A->dup('f')); - aMcont = true; - } - const CBLAS_ORDER blasOrder = aMcont ? CblasColMajor : CblasRowMajor; - - const int lda = (aMcont && aNcont) ? M : !aMcont ? pA->strideAt(0) : pA->strideAt(1); - - // choose appropriate cuda gemm api depending on data types - if(typeDouble) { - BlasHelper::getInstance().dgemv()(blasOrder, CblasNoTrans, M, N, alpha, (double*)pA->buffer(), lda, (double*)X->buffer(), incx, beta, (double*)Y->buffer(), incy); - } - else if(typeFloat) { - BlasHelper::getInstance().sgemv()(blasOrder, CblasNoTrans, M, N, (float)alpha, (float*)pA->buffer(), lda, (float*)X->buffer(), incx, (float)beta, (float*)Y->buffer(), incy); - } - - if(pA != A) - delete pA; + const CBLAS_ORDER blasOrder = aMcont ? CblasColMajor : CblasRowMajor; + + const int lda = + (aMcont && aNcont) ? M : !aMcont ? pA->strideAt(0) : pA->strideAt(1); + + // choose appropriate cuda gemm api depending on data types + if (typeDouble) { + BlasHelper::getInstance().dgemv()( + blasOrder, CblasNoTrans, M, N, alpha, (double*)pA->buffer(), lda, + (double*)X->buffer(), incx, beta, (double*)Y->buffer(), incy); + } else if (typeFloat) { + BlasHelper::getInstance().sgemv()( + blasOrder, CblasNoTrans, M, N, (float)alpha, (float*)pA->buffer(), + lda, (float*)X->buffer(), incx, (float)beta, (float*)Y->buffer(), + incy); } - return Y; + if (pA != A) delete pA; + } + + return Y; } //////////////////////////////////////////////////////////////////////////// // (X * Y) = Z[0] -NDArray* MmulHelper::dot(const NDArray* X, const NDArray* Y, sd::NDArray* Z, const double alpha, const double beta) { - if (X->dataType() != Y->dataType()) - throw datatype_exception::build("Dot expects all data types to be the same", X->dataType(), Y->dataType()); - - if (Z != nullptr && X->dataType() != Z->dataType()) - throw datatype_exception::build("Dot expects all data types to be the same", X->dataType(), Z->dataType()); - - int xLenDim(0), yLenDim(0); - - if(!shape::isCommonVector(X->shapeInfo(), xLenDim)) - throw std::runtime_error("MmulHelper::dot: X array must be vector !"); - if(!shape::isCommonVector(Y->shapeInfo(), yLenDim)) - throw std::runtime_error("MmulHelper::dot: Y array must be vector !"); - if(Z != nullptr && !Z->isScalar()) - throw std::runtime_error("MmulHelper::dot: Z array must be scalar !"); - - const auto length = X->lengthOf(); - - if(Y->lengthOf() != length) - throw std::runtime_error("MmulHelper::dot: lengths of input vectors are different !"); - - if(Z == nullptr) - Z = new NDArray(DataTypeUtils::pickPairwiseResultType(X->dataType(), Y->dataType()), X->getContext()); - - const Nd4jLong incx = X->stridesOf()[xLenDim]; - const Nd4jLong incy = Y->stridesOf()[yLenDim]; - - const auto xType = X->dataType(); - const auto yType = Y->dataType(); - const auto zType = Z->dataType(); - - BUILD_SINGLE_SELECTOR_THRICE(xType, usualDot, (length, alpha, X->buffer(), incx, Y->buffer(), incy, beta, Z->buffer()), NUMERIC_TYPES); - //BUILD_TRIPLE_SELECTOR(xType, yType, zType, usualDot, (length, alpha, X->buffer(), incx, Y->buffer(), incy, beta, Z->buffer()), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES); - - return Z; +NDArray* MmulHelper::dot(const NDArray* X, const NDArray* Y, sd::NDArray* Z, + const double alpha, const double beta) { + if (X->dataType() != Y->dataType()) + throw datatype_exception::build("Dot expects all data types to be the same", + X->dataType(), Y->dataType()); + + if (Z != nullptr && X->dataType() != Z->dataType()) + throw datatype_exception::build("Dot expects all data types to be the same", + X->dataType(), Z->dataType()); + + int xLenDim(0), yLenDim(0); + + if (!shape::isCommonVector(X->shapeInfo(), xLenDim)) + throw std::runtime_error("MmulHelper::dot: X array must be vector !"); + if (!shape::isCommonVector(Y->shapeInfo(), yLenDim)) + throw std::runtime_error("MmulHelper::dot: Y array must be vector !"); + if (Z != nullptr && !Z->isScalar()) + throw std::runtime_error("MmulHelper::dot: Z array must be scalar !"); + + const auto length = X->lengthOf(); + + if (Y->lengthOf() != length) + throw std::runtime_error( + "MmulHelper::dot: lengths of input vectors are different !"); + + if (Z == nullptr) + Z = new NDArray( + DataTypeUtils::pickPairwiseResultType(X->dataType(), Y->dataType()), + X->getContext()); + + const Nd4jLong incx = X->stridesOf()[xLenDim]; + const Nd4jLong incy = Y->stridesOf()[yLenDim]; + + const auto xType = X->dataType(); + const auto yType = Y->dataType(); + const auto zType = Z->dataType(); + + BUILD_SINGLE_SELECTOR_THRICE( + xType, usualDot, + (length, alpha, X->buffer(), incx, Y->buffer(), incy, beta, Z->buffer()), + NUMERIC_TYPES); + // BUILD_TRIPLE_SELECTOR(xType, yType, zType, usualDot, (length, alpha, + // X->buffer(), incx, Y->buffer(), incy, beta, Z->buffer()), LIBND4J_TYPES, + // FLOAT_TYPES, FLOAT_TYPES); + + return Z; } ////////////////////////////////////////////////////////////////////////////// @@ -405,79 +450,81 @@ NDArray* MmulHelper::dot(const NDArray* X, const NDArray* Y, sd::NDArray* Z, con // [M,K] x [bS,K,N] = [bS,M,N] // bS could stand for several axes template -static void batchedGemm(const NDArray* vA, const NDArray* vB, NDArray* vC, - const int* aBatchDims, const int* bBatchDims, const int* cBatchDims, - const int aMaxis, const int aKaxis, const int bKaxis, const int bNaxis, const int cMaxis, const int cNaxis, - const double alpha, const double beta) { - - const T1* A = vA->bufferAsT(); - const T2* B = vB->bufferAsT(); - T3* C = vC->bufferAsT(); - - const T3 alphaZ = alpha; - const T3 betaZ = beta; - - const bool betaPersent = beta; - - const Nd4jLong* aShapeInfo = vA->shapeInfo(); - const Nd4jLong* bShapeInfo = vB->shapeInfo(); - const Nd4jLong* cShapeInfo = vC->shapeInfo(); - - const int aRank = vA->rankOf(); - const int bRank = vB->rankOf(); - const int cRank = vC->rankOf(); - - const Nd4jLong cLen = vC->lengthOf(); - - const int K = vA->sizeAt(aKaxis); - - auto func = PRAGMA_THREADS_FOR { - - std::vector aCoords(aRank), bCoords(bRank), cCoords(cRank); - - for (auto i = start; i < stop; ++i) { - - // evaluate C coordinates - shape::index2coordsCPU(start, i, cShapeInfo, cCoords.data()); - - // calculate index of current batch - Nd4jLong batchInd; - if(cRank > 2) - batchInd = shape::coords2index(cShapeInfo, cCoords.data(), cRank - 2, cBatchDims); - - // evaluate A coordinates - if(aRank > 2) - shape::index2coords(batchInd, aShapeInfo, aCoords.data(), aRank - 2, aBatchDims); - aCoords[aMaxis] = cCoords[cMaxis]; - aCoords[aKaxis] = 0; - - // evaluate B coordinates - if(bRank > 2) - shape::index2coords(batchInd, bShapeInfo, bCoords.data(), bRank - 2, bBatchDims); - bCoords[bKaxis] = 0; - bCoords[bNaxis] = cCoords[cNaxis]; - - auto aOffset = shape::getOffset(aShapeInfo, aCoords.data()); - auto bOffset = shape::getOffset(bShapeInfo, bCoords.data()); - - T3 val = A[aOffset] * B[bOffset]; // first iteration - - for (int j = 1; j < K; ++j) { // rest iterations - aOffset += shape::stride(aShapeInfo)[aKaxis]; - bOffset += shape::stride(bShapeInfo)[bKaxis]; - val = val + A[aOffset] * B[bOffset]; - } - - auto cOffset = shape::getOffset(cShapeInfo, cCoords.data()); - - if(betaPersent) - C[cOffset] = alphaZ * val + betaZ * C[cOffset]; - else - C[cOffset] = alphaZ * val; - } - }; +static void batchedGemm(const NDArray* vA, const NDArray* vB, NDArray* vC, + const int* aBatchDims, const int* bBatchDims, + const int* cBatchDims, const int aMaxis, + const int aKaxis, const int bKaxis, const int bNaxis, + const int cMaxis, const int cNaxis, const double alpha, + const double beta) { + const T1* A = vA->bufferAsT(); + const T2* B = vB->bufferAsT(); + T3* C = vC->bufferAsT(); + + const T3 alphaZ = alpha; + const T3 betaZ = beta; + + const bool betaPersent = beta; + + const Nd4jLong* aShapeInfo = vA->shapeInfo(); + const Nd4jLong* bShapeInfo = vB->shapeInfo(); + const Nd4jLong* cShapeInfo = vC->shapeInfo(); + + const int aRank = vA->rankOf(); + const int bRank = vB->rankOf(); + const int cRank = vC->rankOf(); + + const Nd4jLong cLen = vC->lengthOf(); + + const int K = vA->sizeAt(aKaxis); + + auto func = PRAGMA_THREADS_FOR { + std::vector aCoords(aRank), bCoords(bRank), cCoords(cRank); + + for (auto i = start; i < stop; ++i) { + // evaluate C coordinates + shape::index2coordsCPU(start, i, cShapeInfo, cCoords.data()); + + // calculate index of current batch + Nd4jLong batchInd; + if (cRank > 2) + batchInd = shape::coords2index(cShapeInfo, cCoords.data(), cRank - 2, + cBatchDims); + + // evaluate A coordinates + if (aRank > 2) + shape::index2coords(batchInd, aShapeInfo, aCoords.data(), aRank - 2, + aBatchDims); + aCoords[aMaxis] = cCoords[cMaxis]; + aCoords[aKaxis] = 0; + + // evaluate B coordinates + if (bRank > 2) + shape::index2coords(batchInd, bShapeInfo, bCoords.data(), bRank - 2, + bBatchDims); + bCoords[bKaxis] = 0; + bCoords[bNaxis] = cCoords[cNaxis]; + + auto aOffset = shape::getOffset(aShapeInfo, aCoords.data()); + auto bOffset = shape::getOffset(bShapeInfo, bCoords.data()); + + T3 val = A[aOffset] * B[bOffset]; // first iteration + + for (int j = 1; j < K; ++j) { // rest iterations + aOffset += shape::stride(aShapeInfo)[aKaxis]; + bOffset += shape::stride(bShapeInfo)[bKaxis]; + val = val + A[aOffset] * B[bOffset]; + } + + auto cOffset = shape::getOffset(cShapeInfo, cCoords.data()); + + if (betaPersent) + C[cOffset] = alphaZ * val + betaZ * C[cOffset]; + else + C[cOffset] = alphaZ * val; + } + }; - samediff::Threads::parallel_tad(func, 0, cLen); + samediff::Threads::parallel_tad(func, 0, cLen); } ////////////////////////////////////////////////////////////////////////// @@ -485,89 +532,109 @@ static void batchedGemm(const NDArray* vA, const NDArray* vB, NDArray* vC, // [bS,M,K] x [K,N] = [bS,M,N] // [M,K] x [bS,K,N] = [bS,M,N] // bS could stand for several axes -NDArray* MmulHelper::mmulNxN(const NDArray* A, const NDArray* B, NDArray* C, const double alpha, const double beta, const char outOrder) { - - const int aRank = A->rankOf(); - const int bRank = B->rankOf(); - - // input ranks validation - if(aRank > bRank && bRank != 2) - throw std::runtime_error("MmulHelper::mmulNxN: rank of B array should be equal 2 !"); - else if(bRank > aRank && aRank != 2) - throw std::runtime_error("MmulHelper::mmulNxN: rank of A array should be equal 2 !"); - else if (aRank == bRank ) { - for(int i = 0; i < aRank - 2; ++i) - if(A->sizeAt(i) != B->sizeAt(i)) - throw std::runtime_error("MmulHelper::mmulNxN: shapes of A and B arrays are not suitable for matrix multiplication !"); - } - - if(A->sizeAt(-1) != B->sizeAt(-2)) - throw std::runtime_error("MmulHelper::mmulNxN: shapes of A and B arrays are not suitable for matrix multiplication !"); - - // validation of C array - std::vector cExpectedShape = aRank > bRank ? A->getShapeAsVector() : B->getShapeAsVector(); - cExpectedShape[cExpectedShape.size() - 2] = A->sizeAt(-2); - cExpectedShape[cExpectedShape.size() - 1] = B->sizeAt(-1); - - if(C != nullptr ) { - if(!C->isSameShape(cExpectedShape)) - throw std::runtime_error("MmulHelper::mmulNxN: shape of C array is not suitable for AxB matrix multiplication !"); - } - else { - C = new NDArray(outOrder, cExpectedShape, B->dataType()); - } - - if (C->isEmpty()) - return C; - - const int cRank = C->rankOf(); - - const int aMaxis(aRank-2), aKaxis(aRank-1), bKaxis(bRank-2), bNaxis(bRank-1), cMaxis(cRank-2), cNaxis(cRank-1); - - std::vector aBatchDims, bBatchDims, cBatchDims; - - if(aRank > 2) - aBatchDims = ShapeUtils::evalDimsToExclude(aRank, {aMaxis, aKaxis}); - if(bRank > 2) - bBatchDims = ShapeUtils::evalDimsToExclude(bRank, {bKaxis, bNaxis}); - if(cRank > 2) - cBatchDims = ShapeUtils::evalDimsToExclude(cRank, {cMaxis, cNaxis}); - - // BUILD_TRIPLE_SELECTOR(A->dataType(), B->dataType(), C->dataType(), batchedGemm, (A, B, C, aBatchDims.data(), bBatchDims.data(), cBatchDims.data(), aMaxis, aKaxis, bKaxis, bNaxis, cMaxis, cNaxis, alpha, beta), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES); - BUILD_SINGLE_SELECTOR_THRICE(A->dataType(), batchedGemm, (A, B, C, aBatchDims.data(), bBatchDims.data(), cBatchDims.data(), aMaxis, aKaxis, bKaxis, bNaxis, cMaxis, cNaxis, alpha, beta), NUMERIC_TYPES); - - return C; +NDArray* MmulHelper::mmulNxN(const NDArray* A, const NDArray* B, NDArray* C, + const double alpha, const double beta, + const char outOrder) { + const int aRank = A->rankOf(); + const int bRank = B->rankOf(); + + // input ranks validation + if (aRank > bRank && bRank != 2) + throw std::runtime_error( + "MmulHelper::mmulNxN: rank of B array should be equal 2 !"); + else if (bRank > aRank && aRank != 2) + throw std::runtime_error( + "MmulHelper::mmulNxN: rank of A array should be equal 2 !"); + else if (aRank == bRank) { + for (int i = 0; i < aRank - 2; ++i) + if (A->sizeAt(i) != B->sizeAt(i)) + throw std::runtime_error( + "MmulHelper::mmulNxN: shapes of A and B arrays are not suitable " + "for matrix multiplication !"); + } + + if (A->sizeAt(-1) != B->sizeAt(-2)) + throw std::runtime_error( + "MmulHelper::mmulNxN: shapes of A and B arrays are not suitable for " + "matrix multiplication !"); + + // validation of C array + std::vector cExpectedShape = + aRank > bRank ? A->getShapeAsVector() : B->getShapeAsVector(); + cExpectedShape[cExpectedShape.size() - 2] = A->sizeAt(-2); + cExpectedShape[cExpectedShape.size() - 1] = B->sizeAt(-1); + + if (C != nullptr) { + if (!C->isSameShape(cExpectedShape)) + throw std::runtime_error( + "MmulHelper::mmulNxN: shape of C array is not suitable for AxB " + "matrix multiplication !"); + } else { + C = new NDArray(outOrder, cExpectedShape, B->dataType()); + } + + if (C->isEmpty()) return C; + + const int cRank = C->rankOf(); + + const int aMaxis(aRank - 2), aKaxis(aRank - 1), bKaxis(bRank - 2), + bNaxis(bRank - 1), cMaxis(cRank - 2), cNaxis(cRank - 1); + + std::vector aBatchDims, bBatchDims, cBatchDims; + + if (aRank > 2) + aBatchDims = ShapeUtils::evalDimsToExclude(aRank, {aMaxis, aKaxis}); + if (bRank > 2) + bBatchDims = ShapeUtils::evalDimsToExclude(bRank, {bKaxis, bNaxis}); + if (cRank > 2) + cBatchDims = ShapeUtils::evalDimsToExclude(cRank, {cMaxis, cNaxis}); + + // BUILD_TRIPLE_SELECTOR(A->dataType(), B->dataType(), C->dataType(), + // batchedGemm, (A, B, C, aBatchDims.data(), bBatchDims.data(), + // cBatchDims.data(), aMaxis, aKaxis, bKaxis, bNaxis, cMaxis, cNaxis, alpha, + // beta), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES); + BUILD_SINGLE_SELECTOR_THRICE( + A->dataType(), batchedGemm, + (A, B, C, aBatchDims.data(), bBatchDims.data(), cBatchDims.data(), aMaxis, + aKaxis, bKaxis, bNaxis, cMaxis, cNaxis, alpha, beta), + NUMERIC_TYPES); + + return C; } /* ////////////////////////////////////////////////////////////////////////// -NDArray* MmulHelper::mmulNxN(const NDArray* A, const NDArray* B, NDArray* C, const double alpha, const double beta, const char outOrder) { +NDArray* MmulHelper::mmulNxN(const NDArray* A, const NDArray* B, NDArray* C, +const double alpha, const double beta, const char outOrder) { const int aRank = A->rankOf(); const int bRank = B->rankOf(); // input ranks validation if(aRank > bRank && bRank != 2) - throw std::runtime_error("MmulHelper::mmulNxN: rank of B array should be equal 2 !"); - else if(bRank > aRank && aRank != 2) - throw std::runtime_error("MmulHelper::mmulNxN: rank of A array should be equal 2 !"); + throw std::runtime_error("MmulHelper::mmulNxN: rank of B array should be +equal 2 !"); else if(bRank > aRank && aRank != 2) throw +std::runtime_error("MmulHelper::mmulNxN: rank of A array should be equal 2 !"); else if (aRank == bRank ) { for(int i = 0; i < aRank - 2; ++i) if(A->sizeAt(i) != B->sizeAt(i)) - throw std::runtime_error("MmulHelper::mmulNxN: shapes of A and B arrays are not suitable for matrix multiplication !"); + throw std::runtime_error("MmulHelper::mmulNxN: shapes of A and B +arrays are not suitable for matrix multiplication !"); } if(A->sizeAt(-1) != B->sizeAt(-2)) - throw std::runtime_error("MmulHelper::mmulNxN: shapes of A and B arrays are not suitable for matrix multiplication !"); + throw std::runtime_error("MmulHelper::mmulNxN: shapes of A and B arrays +are not suitable for matrix multiplication !"); // validation of C array - std::vector cExpectedShape = aRank > bRank ? A->getShapeAsVector() : B->getShapeAsVector(); - cExpectedShape[cExpectedShape.size() - 2] = A->sizeAt(-2); - cExpectedShape[cExpectedShape.size() - 1] = B->sizeAt(-1); + std::vector cExpectedShape = aRank > bRank ? A->getShapeAsVector() +: B->getShapeAsVector(); cExpectedShape[cExpectedShape.size() - 2] = +A->sizeAt(-2); cExpectedShape[cExpectedShape.size() - 1] = B->sizeAt(-1); if(C != nullptr ) { if(!C->isSameShape(cExpectedShape)) - throw std::runtime_error("MmulHelper::mmulNxN: shape of C array is not suitable for AxB matrix multiplication !"); + throw std::runtime_error("MmulHelper::mmulNxN: shape of C array is +not suitable for AxB matrix multiplication !"); } else { C = new NDArray(outOrder, cExpectedShape, B->dataType()); @@ -575,15 +642,16 @@ NDArray* MmulHelper::mmulNxN(const NDArray* A, const NDArray* B, NDArray* C, con // multiplication - const std::vector dimsToExclude = ShapeUtils::evalDimsToExclude(C->rankOf(), {-2, -1}); - const Nd4jLong numOfSubArrs = ShapeUtils::getNumOfSubArrs(C->shapeInfo(), dimsToExclude); + const std::vector dimsToExclude = +ShapeUtils::evalDimsToExclude(C->rankOf(), {-2, -1}); const Nd4jLong +numOfSubArrs = ShapeUtils::getNumOfSubArrs(C->shapeInfo(), dimsToExclude); std::vector idxRanges(2 * C->rankOf()); // #pragma omp parallel for schedule(guided) firstprivate(idxRanges) for(Nd4jLong i = 0; i < numOfSubArrs; ++i) { - ShapeUtils::evalIdxRangesForSubArr(i, C->shapeInfo(), dimsToExclude, idxRanges.data()); - NDArray cSubArr = (*C)(idxRanges); + ShapeUtils::evalIdxRangesForSubArr(i, C->shapeInfo(), dimsToExclude, +idxRanges.data()); NDArray cSubArr = (*C)(idxRanges); if(aRank > bRank) { NDArray aSubArr = (*A)(idxRanges); @@ -606,7 +674,10 @@ NDArray* MmulHelper::mmulNxN(const NDArray* A, const NDArray* B, NDArray* C, con ////////////////////////////////////////////////////////////////////////////// // MXK x KxN = MxN template -static void usualGemm(const char cOrder, const bool transA, const bool transB, const int M, const int N, const int K, const double alpha, const void* vA, const int lda, const void* vB, const int ldb, const double beta, void* vC, const int ldc) { +static void usualGemm(const char cOrder, const bool transA, const bool transB, +const int M, const int N, const int K, const double alpha, const void* vA, const +int lda, const void* vB, const int ldb, const double beta, void* vC, const int +ldc) { T1* A = reinterpret_cast(const_cast(vA)); T2* B = reinterpret_cast(const_cast(vB)); @@ -617,7 +688,8 @@ static void usualGemm(const char cOrder, const bool transA, const bool transB, c const bool flagA = (flagC && transA) || (!flagC && !transA); const bool flagB = (flagC && transB) || (!flagC && !transB); - // PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(M*N > Environment::getInstance().elementwiseThreshold()) schedule(guided)) + // PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(M*N > +Environment::getInstance().elementwiseThreshold()) schedule(guided)) // for(uint row = 0; row < M; ++row) { // T3* c = flagC ? (C + row) : (C + row * ldc); @@ -633,7 +705,8 @@ static void usualGemm(const char cOrder, const bool transA, const bool transB, c // if(flagC) { // for(uint col = 0; col < N; ++col) { // if(betaZ) - // c[col * ldc] += a * b[flagB ? col : col * ldb] + betaZ * c[col * ldc]; + // c[col * ldc] += a * b[flagB ? col : col * ldb] + +betaZ * c[col * ldc]; // else // c[col * ldc] += a * b[flagB ? col : col * ldb]; // } @@ -641,7 +714,8 @@ static void usualGemm(const char cOrder, const bool transA, const bool transB, c // else { // for(uint col = 0; col < N; ++col) { // if(betaZ) - // c[col] += a * b[flagB ? col : col * ldb] + betaZ * c[col]; + // c[col] += a * b[flagB ? col : col * ldb] + betaZ * +c[col]; // else // c[col] += a * b[flagB ? col : col * ldb]; // } @@ -675,7 +749,9 @@ static void usualGemm(const char cOrder, const bool transA, const bool transB, c ////////////////////////////////////////////////////////////////////////////// // MXN x N = M template -static void usualGemv(const char aOrder, const int M, const int N, const double alpha, const void* vA, const int lda, const void* vX, const int incx, const double beta, void* vY, const int incy) { +static void usualGemv(const char aOrder, const int M, const int N, const double +alpha, const void* vA, const int lda, const void* vX, const int incx, const +double beta, void* vY, const int incy) { T1* A = reinterpret_cast(const_cast(vA)); T2* X = reinterpret_cast(const_cast(vX)); @@ -707,8 +783,17 @@ static void usualGemv(const char aOrder, const int M, const int N, const double } */ -//BUILD_TRIPLE_TEMPLATE(template void usualGemm, (const char cOrder, const bool transA, const bool transB, const int M, const int N, const int K, const double alpha, const void* A, const int lda, const void* B, const int ldb, const double beta, void* C, const int ldc), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES); -//BUILD_TRIPLE_TEMPLATE(template void usualGemv, (const char aOrder, const int M, const int N, const double alpha, const void* A, const int lda, const void* B, const int incx, const double beta, void* C, const int incy), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES); -//BUILD_TRIPLE_TEMPLATE(template void usualDot, (const Nd4jLong length, const double alpha, const void* vX, const Nd4jLong incx, const void* vY, const Nd4jLong incy, const double beta, void* vZ), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES); - -} +// BUILD_TRIPLE_TEMPLATE(template void usualGemm, (const char cOrder, const bool +// transA, const bool transB, const int M, const int N, const int K, const +// double alpha, const void* A, const int lda, const void* B, const int ldb, +// const double beta, void* C, const int ldc), LIBND4J_TYPES, FLOAT_TYPES, +// FLOAT_TYPES); BUILD_TRIPLE_TEMPLATE(template void usualGemv, (const char +// aOrder, const int M, const int N, const double alpha, const void* A, const +// int lda, const void* B, const int incx, const double beta, void* C, const int +// incy), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES); +// BUILD_TRIPLE_TEMPLATE(template void usualDot, (const Nd4jLong length, const +// double alpha, const void* vX, const Nd4jLong incx, const void* vY, const +// Nd4jLong incy, const double beta, void* vZ), LIBND4J_TYPES, FLOAT_TYPES, +// FLOAT_TYPES); + +} // namespace sd diff --git a/libnd4j/include/helpers/cpu/PointersManager.cpp b/libnd4j/include/helpers/cpu/PointersManager.cpp index 61eb7b2ecccf..839474254393 100644 --- a/libnd4j/include/helpers/cpu/PointersManager.cpp +++ b/libnd4j/include/helpers/cpu/PointersManager.cpp @@ -20,35 +20,37 @@ #ifndef __CUDABLAS__ -#include #include +#include #include #include namespace sd { ////////////////////////////////////////////////////////////////////////// -PointersManager::PointersManager(const sd::LaunchContext *context, const std::string& funcName) { - _context = const_cast(context); - _funcName = funcName; +PointersManager::PointersManager(const sd::LaunchContext* context, + const std::string& funcName) { + _context = const_cast(context); + _funcName = funcName; } ////////////////////////////////////////////////////////////////////////// -void* PointersManager::replicatePointer(const void* src, const size_t numberOfBytes) { - // no-op - return const_cast(src); +void* PointersManager::replicatePointer(const void* src, + const size_t numberOfBytes) { + // no-op + return const_cast(src); } ////////////////////////////////////////////////////////////////////////// void PointersManager::synchronize() const { - // no-op + // no-op } ////////////////////////////////////////////////////////////////////////// PointersManager::~PointersManager() { - // no-op + // no-op } -} +} // namespace sd #endif diff --git a/libnd4j/include/helpers/cpu/cublasHelper.cpp b/libnd4j/include/helpers/cpu/cublasHelper.cpp index 4b17e601d27a..803cc709c088 100644 --- a/libnd4j/include/helpers/cpu/cublasHelper.cpp +++ b/libnd4j/include/helpers/cpu/cublasHelper.cpp @@ -18,39 +18,26 @@ // @author raver119@gmail.com // -#include "../cublasHelper.h" +#include namespace sd { - static void* handle_() { - return nullptr; - } +static void* handle_() { return nullptr; } - static void destroyHandle_(void* handle) { +static void destroyHandle_(void* handle) {} - } +CublasHelper::CublasHelper() {} - CublasHelper::CublasHelper() { +CublasHelper::~CublasHelper() {} - } +CublasHelper& CublasHelper::getInstance() { + static CublasHelper instance; - CublasHelper::~CublasHelper() { + return instance; +} - } +void* CublasHelper::handle() { return nullptr; } - CublasHelper& CublasHelper::getInstance() { - static CublasHelper instance; - return instance; - } +void* CublasHelper::solver() { return nullptr; } - void* CublasHelper::handle() { - return nullptr; - } - - void* CublasHelper::solver() { - return nullptr; - } - - void* CublasHelper::handle(int deviceId) { - return nullptr; - } -} \ No newline at end of file +void* CublasHelper::handle(int deviceId) { return nullptr; } +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/helpers/cpu/loops/IndexReductionLoops.hpp b/libnd4j/include/helpers/cpu/loops/IndexReductionLoops.hpp index a64f0fc913ad..bfcca574dcb2 100644 --- a/libnd4j/include/helpers/cpu/loops/IndexReductionLoops.hpp +++ b/libnd4j/include/helpers/cpu/loops/IndexReductionLoops.hpp @@ -22,293 +22,309 @@ using namespace simdOps; - ////////////////////////////////////////////////////////////////////////////// template template -void sd::IndexReductionLoops::loopIndexReduce(const X* x, const Nd4jLong* xShapeInfo, - Z* z, const Nd4jLong* zShapeInfo, - const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, - X* extraParams) { - - sd::LoopKind::Kind kindOfLoop = sd::LoopKind::deduceKindOfLoopTadXZ(xShapeInfo, zShapeInfo, tadShapeInfo); - if(kindOfLoop == sd::LoopKind::SMALLARR2DX) - kindOfLoop = sd::LoopKind::EWSNONZERO; - - const Nd4jLong zLen = shape::length(zShapeInfo); - const Nd4jLong tadLen = shape::length(tadShapeInfo); - - const uint tadEws = shape::elementWiseStride(tadShapeInfo); - const uint zEws = shape::elementWiseStride(zShapeInfo); - - const Nd4jLong* tadShape = shape::shapeOf(const_cast(tadShapeInfo)); - const Nd4jLong* tadStride = shape::stride(const_cast(tadShapeInfo)); - - switch (kindOfLoop) { - //*********************************************// - case sd::LoopKind::EWS1: { - - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - auto tad = const_cast(x) + tadOffsets[i]; - auto indexValue = OpType::startingIndexValue(tad); - - for (Nd4jLong j = 0; j < tadLen; j++) { - functions::indexreduce::IndexValue comp(tad[j], j); - indexValue = OpType::update(indexValue, comp, extraParams); - } - - z[i] = (Z) indexValue.index; - } - }; - - samediff::Threads::parallel_tad(func, 0, zLen); +void sd::IndexReductionLoops::loopIndexReduce( + const X* x, const Nd4jLong* xShapeInfo, Z* z, const Nd4jLong* zShapeInfo, + const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, X* extraParams) { + sd::LoopKind::Kind kindOfLoop = + sd::LoopKind::deduceKindOfLoopTadXZ(xShapeInfo, zShapeInfo, tadShapeInfo); + if (kindOfLoop == sd::LoopKind::SMALLARR2DX) + kindOfLoop = sd::LoopKind::EWSNONZERO; + + const Nd4jLong zLen = shape::length(zShapeInfo); + const Nd4jLong tadLen = shape::length(tadShapeInfo); + + const uint tadEws = shape::elementWiseStride(tadShapeInfo); + const uint zEws = shape::elementWiseStride(zShapeInfo); + + const Nd4jLong* tadShape = + shape::shapeOf(const_cast(tadShapeInfo)); + const Nd4jLong* tadStride = + shape::stride(const_cast(tadShapeInfo)); + + switch (kindOfLoop) { + //*********************************************// + case sd::LoopKind::EWS1: { + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + auto tad = const_cast(x) + tadOffsets[i]; + auto indexValue = OpType::startingIndexValue(tad); + + for (Nd4jLong j = 0; j < tadLen; j++) { + functions::indexreduce::IndexValue comp(tad[j], j); + indexValue = OpType::update(indexValue, comp, extraParams); + } + + z[i] = (Z)indexValue.index; } - break; + }; - //*********************************************// - case sd::LoopKind::EWSNONZERO: { + samediff::Threads::parallel_tad(func, 0, zLen); + } break; - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - auto tad = const_cast(x) + tadOffsets[i]; - auto indexValue = OpType::startingIndexValue(tad); + //*********************************************// + case sd::LoopKind::EWSNONZERO: { + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + auto tad = const_cast(x) + tadOffsets[i]; + auto indexValue = OpType::startingIndexValue(tad); - for (Nd4jLong j = 0; j < tadLen; j++) { - functions::indexreduce::IndexValue comp(tad[j * tadEws], j); - indexValue = OpType::update(indexValue, comp, extraParams); - } + for (Nd4jLong j = 0; j < tadLen; j++) { + functions::indexreduce::IndexValue comp(tad[j * tadEws], j); + indexValue = OpType::update(indexValue, comp, extraParams); + } - z[i * zEws] = (Z) indexValue.index; - } - }; - - samediff::Threads::parallel_tad(func, 0, zLen); + z[i * zEws] = (Z)indexValue.index; } - break; + }; - //*********************************************// - case sd::LoopKind::RANK1: { + samediff::Threads::parallel_tad(func, 0, zLen); + } break; - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - auto tad = const_cast(x) + tadOffsets[i]; - auto indexValue = OpType::startingIndexValue(tad); + //*********************************************// + case sd::LoopKind::RANK1: { + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + auto tad = const_cast(x) + tadOffsets[i]; + auto indexValue = OpType::startingIndexValue(tad); - for (Nd4jLong i0 = 0; i0 < tadLen; ++i0) { - functions::indexreduce::IndexValue comp(tad[i0 * tadStride[0]], i0); - indexValue = OpType::update(indexValue, comp, extraParams); - } + for (Nd4jLong i0 = 0; i0 < tadLen; ++i0) { + functions::indexreduce::IndexValue comp(tad[i0 * tadStride[0]], + i0); + indexValue = OpType::update(indexValue, comp, extraParams); + } - z[i] = (Z) indexValue.index; - } - }; - - samediff::Threads::parallel_tad(func, 0, zLen); + z[i] = (Z)indexValue.index; } - break; - - //*********************************************// - case sd::LoopKind::RANK2: { - Nd4jLong newStride[2]; - shape::updateStrides(2, tadShape, newStride, 'c'); - - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - auto tad = const_cast(x) + tadOffsets[i]; - auto indexValue = OpType::startingIndexValue(tad); - - for (Nd4jLong i0 = 0; i0 < tadShape[0]; ++i0) { - for (Nd4jLong i1 = 0; i1 < tadShape[1]; ++i1) { - const auto tadOffset = i0 * tadStride[0] + i1 * tadStride[1]; - const auto tadIndex = i0 * newStride[0] + i1; - functions::indexreduce::IndexValue comp(tad[tadOffset], tadIndex); - indexValue = OpType::update(indexValue, comp, extraParams); - } - } - - z[i] = (Z) indexValue.index; - } - }; - - samediff::Threads::parallel_tad(func, 0, zLen); + }; + + samediff::Threads::parallel_tad(func, 0, zLen); + } break; + + //*********************************************// + case sd::LoopKind::RANK2: { + Nd4jLong newStride[2]; + shape::updateStrides(2, tadShape, newStride, 'c'); + + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + auto tad = const_cast(x) + tadOffsets[i]; + auto indexValue = OpType::startingIndexValue(tad); + + for (Nd4jLong i0 = 0; i0 < tadShape[0]; ++i0) { + for (Nd4jLong i1 = 0; i1 < tadShape[1]; ++i1) { + const auto tadOffset = i0 * tadStride[0] + i1 * tadStride[1]; + const auto tadIndex = i0 * newStride[0] + i1; + functions::indexreduce::IndexValue comp(tad[tadOffset], + tadIndex); + indexValue = OpType::update(indexValue, comp, extraParams); + } + } + + z[i] = (Z)indexValue.index; } - break; - - //*********************************************// - case sd::LoopKind::RANK3: { - Nd4jLong newStride[3]; - shape::updateStrides(3, tadShape, newStride, 'c'); - - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - auto tad = const_cast(x) + tadOffsets[i]; - auto indexValue = OpType::startingIndexValue(tad); - - for (Nd4jLong i0 = 0; i0 < tadShape[0]; ++i0) { - for (Nd4jLong i1 = 0; i1 < tadShape[1]; ++i1) { - for (Nd4jLong i2 = 0; i2 < tadShape[2]; ++i2) { - const auto tadOffset = i0 * tadStride[0] + i1 * tadStride[1] + i2 * tadStride[2]; - const auto tadIndex = i0 * newStride[0] + i1 * newStride[1] + i2; - functions::indexreduce::IndexValue comp(tad[tadOffset], tadIndex); - indexValue = OpType::update(indexValue, comp, extraParams); - } - } - } - - z[i] = (Z) indexValue.index; - } - }; - - samediff::Threads::parallel_tad(func, 0, zLen); + }; + + samediff::Threads::parallel_tad(func, 0, zLen); + } break; + + //*********************************************// + case sd::LoopKind::RANK3: { + Nd4jLong newStride[3]; + shape::updateStrides(3, tadShape, newStride, 'c'); + + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + auto tad = const_cast(x) + tadOffsets[i]; + auto indexValue = OpType::startingIndexValue(tad); + + for (Nd4jLong i0 = 0; i0 < tadShape[0]; ++i0) { + for (Nd4jLong i1 = 0; i1 < tadShape[1]; ++i1) { + for (Nd4jLong i2 = 0; i2 < tadShape[2]; ++i2) { + const auto tadOffset = + i0 * tadStride[0] + i1 * tadStride[1] + i2 * tadStride[2]; + const auto tadIndex = + i0 * newStride[0] + i1 * newStride[1] + i2; + functions::indexreduce::IndexValue comp(tad[tadOffset], + tadIndex); + indexValue = OpType::update(indexValue, comp, extraParams); + } + } + } + + z[i] = (Z)indexValue.index; } - break; - - //*********************************************// - case sd::LoopKind::RANK4: { - Nd4jLong newStride[4]; - shape::updateStrides(4, tadShape, newStride, 'c'); - - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - auto tad = const_cast(x) + tadOffsets[i]; - auto indexValue = OpType::startingIndexValue(tad); - - for (Nd4jLong i0 = 0; i0 < tadShape[0]; ++i0) { - for (Nd4jLong i1 = 0; i1 < tadShape[1]; ++i1) { - for (Nd4jLong i2 = 0; i2 < tadShape[2]; ++i2) { - for (Nd4jLong i3 = 0; i3 < tadShape[3]; ++i3) { - const auto tadOffset = i0 * tadStride[0] + i1 * tadStride[1] + i2 * tadStride[2] + i3 * tadStride[3]; - const auto tadIndex = i0 * newStride[0] + i1 * newStride[1] + i2 * newStride[2] + i3; - functions::indexreduce::IndexValue comp(tad[tadOffset], tadIndex); - indexValue = OpType::update(indexValue, comp, extraParams); - } - } - } - } - - z[i] = (Z) indexValue.index; + }; + + samediff::Threads::parallel_tad(func, 0, zLen); + } break; + + //*********************************************// + case sd::LoopKind::RANK4: { + Nd4jLong newStride[4]; + shape::updateStrides(4, tadShape, newStride, 'c'); + + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + auto tad = const_cast(x) + tadOffsets[i]; + auto indexValue = OpType::startingIndexValue(tad); + + for (Nd4jLong i0 = 0; i0 < tadShape[0]; ++i0) { + for (Nd4jLong i1 = 0; i1 < tadShape[1]; ++i1) { + for (Nd4jLong i2 = 0; i2 < tadShape[2]; ++i2) { + for (Nd4jLong i3 = 0; i3 < tadShape[3]; ++i3) { + const auto tadOffset = i0 * tadStride[0] + i1 * tadStride[1] + + i2 * tadStride[2] + i3 * tadStride[3]; + const auto tadIndex = i0 * newStride[0] + i1 * newStride[1] + + i2 * newStride[2] + i3; + functions::indexreduce::IndexValue comp(tad[tadOffset], + tadIndex); + indexValue = OpType::update(indexValue, comp, extraParams); } - }; + } + } + } - samediff::Threads::parallel_tad(func, 0, zLen); + z[i] = (Z)indexValue.index; } - break; - - //*********************************************// - case sd::LoopKind::RANK5: { - Nd4jLong newStride[5]; - shape::updateStrides(5, tadShape, newStride, 'c'); - - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - auto tad = const_cast(x) + tadOffsets[i]; - auto indexValue = OpType::startingIndexValue(tad); - - for (Nd4jLong i0 = 0; i0 < tadShape[0]; ++i0) { - for (Nd4jLong i1 = 0; i1 < tadShape[1]; ++i1) { - for (Nd4jLong i2 = 0; i2 < tadShape[2]; ++i2) { - for (Nd4jLong i3 = 0; i3 < tadShape[3]; ++i3) { - for (Nd4jLong i4 = 0; i4 < tadShape[4]; ++i4) { - const auto tadOffset = i0 * tadStride[0] + i1 * tadStride[1] + i2 * tadStride[2] + i3 * tadStride[3] + i4 * tadStride[4]; - const auto tadIndex = i0 * newStride[0] + i1 * newStride[1] + i2 * newStride[2] + i3 * newStride[3] + i4; - functions::indexreduce::IndexValue comp(tad[tadOffset], tadIndex); - indexValue = OpType::update(indexValue, comp, extraParams); - } - } - } - } - } - - z[i] = (Z) indexValue.index; + }; + + samediff::Threads::parallel_tad(func, 0, zLen); + } break; + + //*********************************************// + case sd::LoopKind::RANK5: { + Nd4jLong newStride[5]; + shape::updateStrides(5, tadShape, newStride, 'c'); + + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + auto tad = const_cast(x) + tadOffsets[i]; + auto indexValue = OpType::startingIndexValue(tad); + + for (Nd4jLong i0 = 0; i0 < tadShape[0]; ++i0) { + for (Nd4jLong i1 = 0; i1 < tadShape[1]; ++i1) { + for (Nd4jLong i2 = 0; i2 < tadShape[2]; ++i2) { + for (Nd4jLong i3 = 0; i3 < tadShape[3]; ++i3) { + for (Nd4jLong i4 = 0; i4 < tadShape[4]; ++i4) { + const auto tadOffset = + i0 * tadStride[0] + i1 * tadStride[1] + + i2 * tadStride[2] + i3 * tadStride[3] + + i4 * tadStride[4]; + const auto tadIndex = + i0 * newStride[0] + i1 * newStride[1] + + i2 * newStride[2] + i3 * newStride[3] + i4; + functions::indexreduce::IndexValue comp(tad[tadOffset], + tadIndex); + indexValue = OpType::update(indexValue, comp, extraParams); + } } - }; + } + } + } - samediff::Threads::parallel_tad(func, 0, zLen); + z[i] = (Z)indexValue.index; } - break; - - //*********************************************// - case sd::LoopKind::X_EWSNONZERO: { - uint castZShapeInfo[MAX_RANK]; - const bool canCastZ = sd::DataTypeUtils::castShapeInfo(zShapeInfo, castZShapeInfo); - - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - auto tad = const_cast(x) + tadOffsets[i]; - auto indexValue = OpType::startingIndexValue(tad); - - for (Nd4jLong j = 0; j < tadLen; j++) { - functions::indexreduce::IndexValue comp(tad[j * tadEws], j); - indexValue = OpType::update(indexValue, comp, extraParams); - } - - auto zOffset = shape::indexOffset(i, zShapeInfo, castZShapeInfo, canCastZ); - z[zOffset] = (Z) indexValue.index; - } - }; - - samediff::Threads::parallel_tad(func, 0, zLen); + }; + + samediff::Threads::parallel_tad(func, 0, zLen); + } break; + + //*********************************************// + case sd::LoopKind::X_EWSNONZERO: { + uint castZShapeInfo[MAX_RANK]; + const bool canCastZ = + sd::DataTypeUtils::castShapeInfo(zShapeInfo, castZShapeInfo); + + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + auto tad = const_cast(x) + tadOffsets[i]; + auto indexValue = OpType::startingIndexValue(tad); + + for (Nd4jLong j = 0; j < tadLen; j++) { + functions::indexreduce::IndexValue comp(tad[j * tadEws], j); + indexValue = OpType::update(indexValue, comp, extraParams); + } + + auto zOffset = + shape::indexOffset(i, zShapeInfo, castZShapeInfo, canCastZ); + z[zOffset] = (Z)indexValue.index; } - break; - - //*********************************************// - case sd::LoopKind::Z_EWSNONZERO: { - uint castTadShapeInfo[MAX_RANK]; - const bool canCastTad = sd::DataTypeUtils::castShapeInfo(tadShapeInfo, castTadShapeInfo); - - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - auto tad = const_cast(x) + tadOffsets[i]; - auto indexValue = OpType::startingIndexValue(tad); - - for (Nd4jLong j = 0; j < tadLen; j++) { - auto tadOffset = shape::indexOffset(j, tadShapeInfo, castTadShapeInfo, canCastTad); - functions::indexreduce::IndexValue comp(tad[tadOffset], j); - indexValue = OpType::update(indexValue, comp, extraParams); - } - - z[i * zEws] = (Z) indexValue.index; - } - }; - - samediff::Threads::parallel_tad(func, 0, zLen); + }; + + samediff::Threads::parallel_tad(func, 0, zLen); + } break; + + //*********************************************// + case sd::LoopKind::Z_EWSNONZERO: { + uint castTadShapeInfo[MAX_RANK]; + const bool canCastTad = sd::DataTypeUtils::castShapeInfo( + tadShapeInfo, castTadShapeInfo); + + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + auto tad = const_cast(x) + tadOffsets[i]; + auto indexValue = OpType::startingIndexValue(tad); + + for (Nd4jLong j = 0; j < tadLen; j++) { + auto tadOffset = shape::indexOffset(j, tadShapeInfo, + castTadShapeInfo, canCastTad); + functions::indexreduce::IndexValue comp(tad[tadOffset], j); + indexValue = OpType::update(indexValue, comp, extraParams); + } + + z[i * zEws] = (Z)indexValue.index; } - break; - - //*********************************************// - default: { - uint castTadShapeInfo[MAX_RANK]; - uint castZShapeInfo[MAX_RANK]; - const bool canCastTad = sd::DataTypeUtils::castShapeInfo(tadShapeInfo, castTadShapeInfo); - const bool canCastZ = sd::DataTypeUtils::castShapeInfo(zShapeInfo, castZShapeInfo); - - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - auto tad = const_cast(x) + tadOffsets[i]; - auto indexValue = OpType::startingIndexValue(tad); - - for (Nd4jLong j = 0; j < tadLen; j++) { - auto tadOffset = shape::indexOffset(j, tadShapeInfo, castTadShapeInfo, canCastTad); - functions::indexreduce::IndexValue comp(tad[tadOffset], j); - indexValue = OpType::update(indexValue, comp, extraParams); - } - - auto zOffset = shape::indexOffset(i, zShapeInfo, castZShapeInfo, canCastZ); - z[zOffset] = (Z) indexValue.index; - } - }; - - samediff::Threads::parallel_tad(func, 0, zLen); + }; + + samediff::Threads::parallel_tad(func, 0, zLen); + } break; + + //*********************************************// + default: { + uint castTadShapeInfo[MAX_RANK]; + uint castZShapeInfo[MAX_RANK]; + const bool canCastTad = sd::DataTypeUtils::castShapeInfo( + tadShapeInfo, castTadShapeInfo); + const bool canCastZ = + sd::DataTypeUtils::castShapeInfo(zShapeInfo, castZShapeInfo); + + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + auto tad = const_cast(x) + tadOffsets[i]; + auto indexValue = OpType::startingIndexValue(tad); + + for (Nd4jLong j = 0; j < tadLen; j++) { + auto tadOffset = shape::indexOffset(j, tadShapeInfo, + castTadShapeInfo, canCastTad); + functions::indexreduce::IndexValue comp(tad[tadOffset], j); + indexValue = OpType::update(indexValue, comp, extraParams); + } + + auto zOffset = + shape::indexOffset(i, zShapeInfo, castZShapeInfo, canCastZ); + z[zOffset] = (Z)indexValue.index; } + }; + + samediff::Threads::parallel_tad(func, 0, zLen); } + } } template -void sd::IndexReductionLoops::wrapIndexReduce(const int opNum, const void* vx, const Nd4jLong* xShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, void* vextraParams) { - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - auto extraParams = reinterpret_cast(vextraParams); - - DISPATCH_BY_OPNUM_TT(loopIndexReduce, PARAMS(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams), INDEX_REDUCE_OPS); +void sd::IndexReductionLoops::wrapIndexReduce( + const int opNum, const void* vx, const Nd4jLong* xShapeInfo, void* vz, + const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, + const Nd4jLong* tadOffsets, void* vextraParams) { + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + auto extraParams = reinterpret_cast(vextraParams); + + DISPATCH_BY_OPNUM_TT(loopIndexReduce, + PARAMS(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, + tadOffsets, extraParams), + INDEX_REDUCE_OPS); } \ No newline at end of file diff --git a/libnd4j/include/helpers/cpu/loops/Reduction3Loops.cpp.in b/libnd4j/include/helpers/cpu/loops/Reduction3Loops.cpp.in index 4f38b4d8f397..c762b192e7fe 100644 --- a/libnd4j/include/helpers/cpu/loops/Reduction3Loops.cpp.in +++ b/libnd4j/include/helpers/cpu/loops/Reduction3Loops.cpp.in @@ -22,6 +22,5 @@ #cmakedefine FLOAT_TYPE_GEN namespace sd { - - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT Reduction3Loops, , LIBND4J_TYPES, FLOAT_TYPES_@FL_TYPE_INDEX@); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduction3Loops, , LIBND4J_TYPES, FLOAT_TYPES_@FL_TYPE_INDEX@); } diff --git a/libnd4j/include/helpers/cpu/loops/Reduction3Loops.hpp b/libnd4j/include/helpers/cpu/loops/Reduction3Loops.hpp index 241dc7e8cd25..d9edf2bdd2e7 100644 --- a/libnd4j/include/helpers/cpu/loops/Reduction3Loops.hpp +++ b/libnd4j/include/helpers/cpu/loops/Reduction3Loops.hpp @@ -18,9 +18,9 @@ // @author raver119@gmail.com // +#include #include #include -#include using namespace simdOps; diff --git a/libnd4j/include/helpers/cpu/loops/ReductionLoops_bool.cpp b/libnd4j/include/helpers/cpu/loops/ReductionLoops_bool.cpp index e122717fc983..553a5f906f0d 100644 --- a/libnd4j/include/helpers/cpu/loops/ReductionLoops_bool.cpp +++ b/libnd4j/include/helpers/cpu/loops/ReductionLoops_bool.cpp @@ -24,25 +24,32 @@ using namespace simdOps; namespace sd { - template - template - void ReductionBoolLoops::innerloopReduce(const X* x, const Nd4jLong* xShapeInfo, Z* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, X* extraParams, int64_t start, int64_t stop) { +template +template +void ReductionBoolLoops::innerloopReduce( + const X* x, const Nd4jLong* xShapeInfo, Z* z, const Nd4jLong* zShapeInfo, + const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, X* extraParams, + int64_t start, int64_t stop) { #ifndef INLINE_LOOPS - ReductionLoops::template loopReduce(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, start, stop); + ReductionLoops::template loopReduce( + x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, + start, stop); #endif - } +} - template - void ReductionBoolLoops::wrapper(const int opNum, - const X *x, const Nd4jLong *xShapeInfo, - Y *z, const Nd4jLong *zShapeInfo, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets, - X *extraParams, int64_t start, int64_t stop) { +template +void ReductionBoolLoops::wrapper( + const int opNum, const X* x, const Nd4jLong* xShapeInfo, Y* z, + const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, + const Nd4jLong* tadOffsets, X* extraParams, int64_t start, int64_t stop) { #ifndef INLINE_LOOPS - DISPATCH_BY_OPNUM_TT(innerloopReduce, PARAMS(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, start, stop), REDUCE_BOOL_OPS); + DISPATCH_BY_OPNUM_TT(innerloopReduce, + PARAMS(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, + tadOffsets, extraParams, start, stop), + REDUCE_BOOL_OPS); #endif - } - - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT ReductionBoolLoops, , LIBND4J_TYPES, BOOL_TYPES); } +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT ReductionBoolLoops, , + LIBND4J_TYPES, BOOL_TYPES); +} // namespace sd diff --git a/libnd4j/include/helpers/cpu/loops/ReductionLoops_float.cpp.in b/libnd4j/include/helpers/cpu/loops/ReductionLoops_float.cpp.in index 5c1bb227d8d3..9443dc497230 100644 --- a/libnd4j/include/helpers/cpu/loops/ReductionLoops_float.cpp.in +++ b/libnd4j/include/helpers/cpu/loops/ReductionLoops_float.cpp.in @@ -22,7 +22,7 @@ #cmakedefine FLOAT_TYPE_GEN namespace sd { - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT ReductionFloatLoops, , LIBND4J_TYPES, FLOAT_TYPES_@FL_TYPE_INDEX@); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT ReductionFloatLoops, , LIBND4J_TYPES, FLOAT_TYPES_@FL_TYPE_INDEX@); } diff --git a/libnd4j/include/helpers/cpu/loops/ReductionLoops_long.cpp b/libnd4j/include/helpers/cpu/loops/ReductionLoops_long.cpp index be6cb28bdda4..81da1b348d8f 100644 --- a/libnd4j/include/helpers/cpu/loops/ReductionLoops_long.cpp +++ b/libnd4j/include/helpers/cpu/loops/ReductionLoops_long.cpp @@ -22,31 +22,41 @@ using namespace simdOps; - -#include "ReductionLoops.hpp" #include #include +#include "ReductionLoops.hpp" + using namespace simdOps; namespace sd { - template - template - void ReductionLongLoops::innerloopReduce(const X * x, const Nd4jLong* xShapeInfo, Z *z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, X* extraParams, int64_t start, int64_t stop) { +template +template +void ReductionLongLoops::innerloopReduce( + const X *x, const Nd4jLong *xShapeInfo, Z *z, const Nd4jLong *zShapeInfo, + const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets, X *extraParams, + int64_t start, int64_t stop) { #ifndef INLINE_LOOPS - ReductionLoops::template loopReduce(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, start, stop); + ReductionLoops::template loopReduce( + x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, + start, stop); #endif - } +} - template - void ReductionLongLoops::wrapper(const int opNum, const X *x, const Nd4jLong *xShapeInfo, Y *z, - const Nd4jLong *zShapeInfo, const Nd4jLong *tadShapeInfo, - const Nd4jLong *tadOffsets, X *extraParams, int64_t start, int64_t stop) { +template +void ReductionLongLoops::wrapper( + const int opNum, const X *x, const Nd4jLong *xShapeInfo, Y *z, + const Nd4jLong *zShapeInfo, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets, X *extraParams, int64_t start, int64_t stop) { #ifndef INLINE_LOOPS - DISPATCH_BY_OPNUM_TT(innerloopReduce, PARAMS(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, start, stop), REDUCE_LONG_OPS); + DISPATCH_BY_OPNUM_TT(innerloopReduce, + PARAMS(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, + tadOffsets, extraParams, start, stop), + REDUCE_LONG_OPS); #endif - } - - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT ReductionLongLoops, , LIBND4J_TYPES, LONG_TYPES); } + +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT ReductionLongLoops, , + LIBND4J_TYPES, LONG_TYPES); +} // namespace sd diff --git a/libnd4j/include/helpers/cpu/loops/ReductionLoops_same.cpp b/libnd4j/include/helpers/cpu/loops/ReductionLoops_same.cpp index 53725de83846..382cabd60680 100644 --- a/libnd4j/include/helpers/cpu/loops/ReductionLoops_same.cpp +++ b/libnd4j/include/helpers/cpu/loops/ReductionLoops_same.cpp @@ -24,27 +24,37 @@ using namespace simdOps; namespace sd { - template - template - void ReductionSameLoops::innerloopReduce(const X* x, const Nd4jLong* xShapeInfo, X* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, X* extraParams, int64_t start, int64_t stop) { +template +template +void ReductionSameLoops::innerloopReduce( + const X *x, const Nd4jLong *xShapeInfo, X *z, const Nd4jLong *zShapeInfo, + const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets, X *extraParams, + int64_t start, int64_t stop) { #ifndef INLINE_LOOPS - ReductionLoops::template loopReduce(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, start, stop); + ReductionLoops::template loopReduce( + x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, + start, stop); #endif - } +} - template - void ReductionSameLoops::wrapper(const int opNum, const X *vx, const Nd4jLong *xShapeInfo, X *vz, - const Nd4jLong *zShapeInfo, const Nd4jLong *tadShapeInfo, - const Nd4jLong *tadOffsets, - X *vextraParams, int64_t start, int64_t stop) { +template +void ReductionSameLoops::wrapper(const int opNum, const X *vx, + const Nd4jLong *xShapeInfo, X *vz, + const Nd4jLong *zShapeInfo, + const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets, X *vextraParams, + int64_t start, int64_t stop) { #ifndef INLINE_LOOPS - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - auto extraParams = reinterpret_cast(vextraParams); - - DISPATCH_BY_OPNUM_T(innerloopReduce, PARAMS(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, start, stop), REDUCE_SAME_OPS); + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + auto extraParams = reinterpret_cast(vextraParams); + + DISPATCH_BY_OPNUM_T(innerloopReduce, + PARAMS(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, + tadOffsets, extraParams, start, stop), + REDUCE_SAME_OPS); #endif - } - - BUILD_SINGLE_TEMPLATE(template class ReductionSameLoops, , LIBND4J_TYPES); } + +BUILD_SINGLE_TEMPLATE(template class ReductionSameLoops, , LIBND4J_TYPES); +} // namespace sd diff --git a/libnd4j/include/helpers/cpu/svd.cpp b/libnd4j/include/helpers/cpu/svd.cpp index 8a320f6de4a2..dfe100dc96ca 100644 --- a/libnd4j/include/helpers/cpu/svd.cpp +++ b/libnd4j/include/helpers/cpu/svd.cpp @@ -28,278 +28,291 @@ namespace sd { namespace ops { namespace helpers { - ////////////////////////////////////////////////////////////////////////// template -SVD::SVD(const NDArray& matrix, const int switchSize, const bool calcU, const bool calcV, const bool fullUV ) { - - if(matrix.rankOf() != 2 || matrix.isScalar()) - throw std::runtime_error("ops::helpers::SVD constructor: input array must be 2D matrix !"); - - const int rows = matrix.sizeAt(0); - const int cols = matrix.sizeAt(1); - - if(cols > rows) { - - _transp = true; - _diagSize = rows; - } - else { - - _transp = false; - _diagSize = cols; - } - - _switchSize = switchSize; - _calcU = calcU; - _calcV = calcV; - _fullUV = fullUV; - - if (_transp) - math::nd4j_swap(_calcU, _calcV); - - _s = NDArray(matrix.ordering(), {_diagSize, 1}, matrix.dataType(), matrix.getContext()); - _m = NDArray(matrix.ordering(), {_diagSize + 1, _diagSize}, matrix.dataType(), matrix.getContext()); - // _m.assign(0.); - - if (_calcU) - _u = NDArray(matrix.ordering(), {_diagSize + 1, _diagSize + 1}, matrix.dataType(), matrix.getContext()); - else - _u = NDArray(matrix.ordering(), {2, _diagSize + 1}, matrix.dataType(), matrix.getContext()); - // _u.assign(0.); - - if (_calcV) { - _v = NDArray(matrix.ordering(), {_diagSize, _diagSize}, matrix.dataType(), matrix.getContext()); - // _v.assign(0.); - } - - evalData(matrix); +SVD::SVD(const NDArray& matrix, const int switchSize, const bool calcU, + const bool calcV, const bool fullUV) { + if (matrix.rankOf() != 2 || matrix.isScalar()) + throw std::runtime_error( + "ops::helpers::SVD constructor: input array must be 2D matrix !"); + + const int rows = matrix.sizeAt(0); + const int cols = matrix.sizeAt(1); + + if (cols > rows) { + _transp = true; + _diagSize = rows; + } else { + _transp = false; + _diagSize = cols; + } + + _switchSize = switchSize; + _calcU = calcU; + _calcV = calcV; + _fullUV = fullUV; + + if (_transp) math::nd4j_swap(_calcU, _calcV); + + _s = NDArray(matrix.ordering(), {_diagSize, 1}, + matrix.dataType(), matrix.getContext()); + _m = NDArray(matrix.ordering(), {_diagSize + 1, _diagSize}, + matrix.dataType(), matrix.getContext()); + // _m.assign(0.); + + if (_calcU) + _u = NDArray( + matrix.ordering(), {_diagSize + 1, _diagSize + 1}, matrix.dataType(),matrix.getContext()); + else + _u = NDArray(matrix.ordering(), {2, _diagSize + 1}, + matrix.dataType(), matrix.getContext()); + // _u.assign(0.); + + if (_calcV) { + _v = NDArray(matrix.ordering(), {_diagSize, _diagSize}, + matrix.dataType(), matrix.getContext()); + //_v.assign(0.); + } + + evalData(matrix); } ////////////////////////////////////////////////////////////////////////// template -SVD::SVD(const NDArray& matrix, const int switchSize, const bool calcU, const bool calcV, const bool fullUV, const char t) { - - if(matrix.rankOf() != 2 || matrix.isScalar()) - throw std::runtime_error("ops::helpers::SVD constructor: input array must be 2D matrix !"); - - const int rows = matrix.sizeAt(0); - const int cols = matrix.sizeAt(1); - - if(cols > rows) { - - _transp = true; - _diagSize = rows; - } - else { - - _transp = false; - _diagSize = cols; - } - - _switchSize = switchSize; - _calcU = calcU; - _calcV = calcV; - _fullUV = fullUV; - - if (_transp) - math::nd4j_swap(_calcU, _calcV); - - _s = NDArray(matrix.ordering(), {_diagSize, 1}, matrix.dataType(), matrix.getContext()); - _m = NDArray(matrix.ordering(), {_diagSize + 1, _diagSize}, matrix.dataType(), matrix.getContext()); - // _m.assign(0.f); - - if (_calcU) - _u = NDArray(matrix.ordering(), {_diagSize + 1, _diagSize + 1}, matrix.dataType(), matrix.getContext()); - else - _u = NDArray(matrix.ordering(), {2, _diagSize + 1}, matrix.dataType(), matrix.getContext()); - // _u.assign(0.); - - if (_calcV) { - _v = NDArray(matrix.ordering(), {_diagSize, _diagSize}, matrix.dataType(), matrix.getContext()); - // _v.assign(0.); - } +SVD::SVD(const NDArray& matrix, const int switchSize, const bool calcU, + const bool calcV, const bool fullUV, const char t) { + if (matrix.rankOf() != 2 || matrix.isScalar()) + throw std::runtime_error( + "ops::helpers::SVD constructor: input array must be 2D matrix !"); + + const int rows = matrix.sizeAt(0); + const int cols = matrix.sizeAt(1); + + if (cols > rows) { + _transp = true; + _diagSize = rows; + } else { + _transp = false; + _diagSize = cols; + } + + _switchSize = switchSize; + _calcU = calcU; + _calcV = calcV; + _fullUV = fullUV; + + if (_transp) math::nd4j_swap(_calcU, _calcV); + + _s = NDArray(matrix.ordering(), {_diagSize, 1}, + matrix.dataType(), matrix.getContext()); + _m = NDArray(matrix.ordering(), {_diagSize + 1, _diagSize}, + matrix.dataType(), matrix.getContext()); + // _m.assign(0.f); + + if (_calcU) + _u = NDArray( + matrix.ordering(), {_diagSize + 1, _diagSize + 1}, matrix.dataType(),matrix.getContext()); + else + _u = NDArray(matrix.ordering(), {2, _diagSize + 1}, + matrix.dataType(), matrix.getContext()); + // _u.assign(0.); + + if (_calcV) { + _v = NDArray(matrix.ordering(), {_diagSize, _diagSize}, + matrix.dataType(), matrix.getContext()); + //_v.assign(0.); + } } - ////////////////////////////////////////////////////////////////////////// template void SVD::deflation1(int col1, int shift, int ind, int size) { - - if(ind <= 0) - throw std::runtime_error("ops::helpers::SVD::deflation1 method: input int must satisfy condition ind > 0 !"); - - int first = col1 + shift; - T cos = _m.t(first, first); - T sin = _m.t(first+ind, first); - T denom = math::nd4j_sqrt(cos*cos + sin*sin); - - if (denom == (T)0.) { - _m.r(first+ind, first+ind) = (T)0; - return; - } - - cos /= denom; - sin /= denom; - - _m.r(first,first) = denom; - _m.r(first+ind, first) = (T)0; - _m.r(first+ind, first+ind) = (T)0; - - NDArray rotation(_m.ordering(), {2, 2}, _m.dataType(), _m.getContext()); - - rotation.r(0,0) = rotation.r(1,1) = cos; - rotation.r(0,1) = -sin; - rotation.r(1,0) = sin; - - if (_calcU) { - auto temp = _u({col1,col1+size+1, 0,0}, true); - JacobiSVD::mulRotationOnRight(col1, col1+ind, temp, rotation); - } - else - JacobiSVD::mulRotationOnRight(col1, col1+ind, _u, rotation); + if (ind <= 0) + throw std::runtime_error( + "ops::helpers::SVD::deflation1 method: input int must satisfy " + "condition ind > 0 !"); + + int first = col1 + shift; + T cos = _m.t(first, first); + T sin = _m.t(first + ind, first); + T denom = math::nd4j_sqrt(cos * cos + sin * sin); + + if (denom == (T)0.) { + _m.r(first + ind, first + ind) = (T)0; + return; + } + + cos /= denom; + sin /= denom; + + _m.r(first, first) = denom; + _m.r(first + ind, first) = (T)0; + _m.r(first + ind, first + ind) = (T)0; + + NDArray rotation(_m.ordering(), {2, 2}, _m.dataType(),_m.getContext()); + rotation.r(0,0) = rotation.r(1,1) = cos; + rotation.r(0,1) = -sin; + rotation.r(1,0) = sin; + + if (_calcU) { + auto temp = _u({col1, col1 + size + 1, 0, 0}, true); + JacobiSVD::mulRotationOnRight(col1, col1 + ind, temp, rotation); + } else + JacobiSVD::mulRotationOnRight(col1, col1 + ind, _u, rotation); } ////////////////////////////////////////////////////////////////////////// template -void SVD::deflation2(int col1U , int col1M, int row1W, int col1W, int ind1, int ind2, int size) { - - if(ind1 >= ind2) - throw std::runtime_error("ops::helpers::SVD::deflation2 method: input intes must satisfy condition ind1 < ind2 !"); - - if(size <= 0) - throw std::runtime_error("ops::helpers::SVD::deflation2 method: input size must satisfy condition size > 0 !"); - - T cos = _m.t(col1M+ind1, col1M); - T sin = _m.t(col1M+ind2, col1M); - T denom = math::nd4j_sqrt(cos*cos + sin*sin); - - if (denom == (T)0.) { - _m.r(col1M+ind1, col1M+ind1) = _m.t(col1M+ind2, col1M+ind2); - return; - } - - cos /= denom; - sin /= denom; - _m.r(col1M+ind1, col1M) = denom; - _m.r(col1M+ind2, col1M+ind2) = _m.t(col1M+ind1, col1M+ind1); - _m.r(col1M+ind2, col1M) = (T)0; - - NDArray rotation(_m.ordering(), {2, 2}, _m.dataType(), _m.getContext()); - - rotation.r(0,0) = rotation.r(1,1) = cos; - rotation.r(0,1) = -sin; - rotation.r(1,0) = sin; - - if (_calcU) { - auto temp = _u({col1U,col1U+size+1, 0,0}, true); - JacobiSVD::mulRotationOnRight(col1U+ind1, col1U+ind2, temp, rotation); - } - else - JacobiSVD::mulRotationOnRight(col1U+ind1, col1U+ind2, _u, rotation); - - if (_calcV) { - auto temp = _v({row1W,row1W+size, 0,0}, true); - JacobiSVD::mulRotationOnRight(col1W+ind1, col1W+ind2, temp, rotation); - } +void SVD::deflation2(int col1U, int col1M, int row1W, int col1W, int ind1, + int ind2, int size) { + if (ind1 >= ind2) + throw std::runtime_error( + "ops::helpers::SVD::deflation2 method: input intes must satisfy " + "condition ind1 < ind2 !"); + + if (size <= 0) + throw std::runtime_error( + "ops::helpers::SVD::deflation2 method: input size must satisfy " + "condition size > 0 !"); + + T cos = _m.t(col1M + ind1, col1M); + T sin = _m.t(col1M + ind2, col1M); + T denom = math::nd4j_sqrt(cos * cos + sin * sin); + + if (denom == (T)0.) { + _m.r(col1M+ind1, col1M+ind1) = _m.t(col1M+ind2, col1M+ind2); + return; + } + + cos /= denom; + sin /= denom; + _m.r(col1M+ind1, col1M) = denom; + _m.r(col1M+ind2, col1M+ind2) = _m.t(col1M+ind1, col1M+ind1); + _m.r(col1M+ind2, col1M) = (T)0; + + NDArray rotation(_m.ordering(), {2, 2}, _m.dataType(),_m.getContext()); + + + rotation.r(0,0) = rotation.r(1,1) = cos; + rotation.r(0, 1) = -sin; + rotation.r(1, 0) = sin; + + if (_calcU) { + auto temp = _u({col1U, col1U + size + 1, 0, 0}, true); + JacobiSVD::mulRotationOnRight(col1U + ind1, col1U + ind2, temp, + rotation); + } else + JacobiSVD::mulRotationOnRight(col1U + ind1, col1U + ind2, _u, rotation); + + if (_calcV) { + auto temp = _v({row1W, row1W + size, 0, 0}, true); + JacobiSVD::mulRotationOnRight(col1W + ind1, col1W + ind2, temp, + rotation); + } } ////////////////////////////////////////////////////////////////////////// -// has effect on block from (col1+shift, col1+shift) to (col2+shift, col2+shift) inclusively +// has effect on block from (col1+shift, col1+shift) to (col2+shift, col2+shift) +// inclusively template -void SVD::deflation(int col1, int col2, int ind, int row1W, int col1W, int shift) -{ - - const int len = col2 + 1 - col1; +void SVD::deflation(int col1, int col2, int ind, int row1W, int col1W, + int shift) { + const int len = col2 + 1 - col1; NDArray colVec0 = _m({col1+shift,col1+shift+len, col1+shift,col1+shift+1}, true); - NDArray diagInterval = _m({col1+shift,col1+shift+len, col1+shift,col1+shift+len}, true).diagonal('c'); + NDArray diagInterval = + _m({col1 + shift, col1 + shift + len, col1 + shift, col1 + shift + len}, + true) + .diagonal('c'); - const T almostZero = DataTypeUtils::min(); - T maxElem; - if(len == 1) - maxElem = math::nd4j_abs(diagInterval.template t(0)); - else - maxElem = diagInterval({1,-1, 0,0}, true).reduceNumber(reduce::AMax).template t(0); - T maxElem0 = colVec0.reduceNumber(reduce::AMax).template t(0); + const T almostZero = DataTypeUtils::min(); + T maxElem; + if (len == 1) + maxElem = math::nd4j_abs(diagInterval.template t(0)); + else + maxElem = diagInterval({1, -1, 0, 0}, true) + .reduceNumber(reduce::AMax) + .template t(0); + T maxElem0 = colVec0.reduceNumber(reduce::AMax).template t(0); - T eps = math::nd4j_max(almostZero, DataTypeUtils::eps() * maxElem); - T epsBig = (T)8. * DataTypeUtils::eps() * math::nd4j_max(maxElem0, maxElem); + T eps = math::nd4j_max(almostZero, DataTypeUtils::eps() * maxElem); + T epsBig = + (T)8. * DataTypeUtils::eps() * math::nd4j_max(maxElem0, maxElem); - if(diagInterval.template t(0) < epsBig) - diagInterval.r(0) = epsBig; + if (diagInterval.template t(0) < epsBig) + diagInterval.r(0) = epsBig; for(int i=1; i < len; ++i) if(math::nd4j_abs(colVec0.template t(i)) < eps) colVec0.r(i) = (T)0; - for(int i=1; i < len; i++) - if(diagInterval.template t(i) < epsBig) { - deflation1(col1, shift, i, len); - for(int i = 0; i < len; ++i) - diagInterval.r(i) = _m.t(col1+shift+i,col1+shift+i); - } - - { - - bool totDefl = true; - for(int i=1; i < len; i++) - if(colVec0.template t(i) >= almostZero) { - totDefl = false; - break; - } + for (int i = 1; i < len; i++) + if (diagInterval.template t(i) < epsBig) { + deflation1(col1, shift, i, len); + for (int i = 0; i < len; ++i) + diagInterval.r(i) = _m.t(col1 + shift + i, col1 + shift + i); + } - int* permut = nullptr; - ALLOCATE(permut, _m.getContext()->getWorkspace(), 3*_diagSize, int); - { - permut[0] = 0; - int p = 1; - - for(int i=1; i(diagInterval.template t(i)) < almostZero) - permut[p++] = i; - - int k = 1, m = ind+1; - - for( ; p < len; ++p) { - if(k > ind) - permut[p] = m++; - else if(m >= len) - permut[p] = k++; - else if(diagInterval.template t(k) < diagInterval.template t(m)) - permut[p] = m++; - else - permut[p] = k++; - } - } + { + bool totDefl = true; + for (int i = 1; i < len; i++) + if (colVec0.template t(i) >= almostZero) { + totDefl = false; + break; + } - if(totDefl) { - for(int i=1; i(diagInterval.template t(ki)) < almostZero || diagInterval.template t(0) < diagInterval.template t(ki)) - permut[i-1] = permut[i]; - else { - permut[i-1] = 0; - break; - } - } + int* permut = nullptr; + ALLOCATE(permut, _m.getContext()->getWorkspace(), 3 * _diagSize, int); + { + permut[0] = 0; + int p = 1; + + for (int i = 1; i < len; ++i) + if (math::nd4j_abs(diagInterval.template t(i)) < almostZero) + permut[p++] = i; + + int k = 1, m = ind + 1; + + for (; p < len; ++p) { + if (k > ind) + permut[p] = m++; + else if (m >= len) + permut[p] = k++; + else if (diagInterval.template t(k) < diagInterval.template t(m)) + permut[p] = m++; + else + permut[p] = k++; + } + } + + if (totDefl) { + for (int i = 1; i < len; ++i) { + int ki = permut[i]; + if (math::nd4j_abs(diagInterval.template t(ki)) < almostZero || + diagInterval.template t(0) < diagInterval.template t(ki)) + permut[i - 1] = permut[i]; + else { + permut[i - 1] = 0; + break; } + } + } - int *tInd = permut + len; - int *tCol = permut + 2*len; + int* tInd = permut + len; + int* tCol = permut + 2 * len; - for(int m = 0; m < len; m++) { - tCol[m] = m; - tInd[m] = m; - } + for (int m = 0; m < len; m++) { + tCol[m] = m; + tInd[m] = m; + } - for(int i = totDefl ? 0 : 1; i < len; i++) { + for (int i = totDefl ? 0 : 1; i < len; i++) { + const int ki = permut[len - (totDefl ? i + 1 : i)]; + const int jac = tCol[ki]; - const int ki = permut[len - (totDefl ? i+1 : i)]; - const int jac = tCol[ki]; + math::nd4j_swap(diagInterval.r(i), diagInterval.r(jac)); - math::nd4j_swap(diagInterval.r(i), diagInterval.r(jac)); if(i!=0 && jac!=0) math::nd4j_swap(colVec0.r(i), colVec0.r(jac)); @@ -321,299 +334,313 @@ void SVD::deflation(int col1, int col2, int ind, int row1W, int col1W, int sh temp1.swapUnsafe(temp2); } - const int tI = tInd[i]; - tCol[tI] = jac; - tCol[ki] = i; - tInd[jac] = tI; - tInd[i] = ki; - } - - RELEASE(permut, _m.getContext()->getWorkspace()); + const int tI = tInd[i]; + tCol[tI] = jac; + tCol[ki] = i; + tInd[jac] = tI; + tInd[i] = ki; } - { - int i = len-1; + RELEASE(permut, _m.getContext()->getWorkspace()); + } - while(i > 0 && (math::nd4j_abs(diagInterval.template t(i)) < almostZero || math::nd4j_abs(colVec0.template t(i)) < almostZero)) - --i; + { + int i = len - 1; - for(; i > 1; --i) { - if( (diagInterval.template t(i) - diagInterval.template t(i-1)) < DataTypeUtils::eps()*maxElem ) { - if (math::nd4j_abs(diagInterval.template t(i) - diagInterval.template t(i-1)) >= epsBig) - throw std::runtime_error("ops::helpers::SVD::deflation: diagonal elements are not properly sorted !"); - deflation2(col1, col1 + shift, row1W, col1W, i-1, i, len); - } - } + while (i > 0 && + (math::nd4j_abs(diagInterval.template t(i)) < almostZero || + math::nd4j_abs(colVec0.template t(i)) < almostZero)) + --i; + + for (; i > 1; --i) { + if ((diagInterval.template t(i) - diagInterval.template t(i - 1)) < + DataTypeUtils::eps() * maxElem) { + if (math::nd4j_abs(diagInterval.template t(i) - + diagInterval.template t(i - 1)) >= epsBig) + throw std::runtime_error( + "ops::helpers::SVD::deflation: diagonal elements are not " + "properly sorted !"); + deflation2(col1, col1 + shift, row1W, col1W, i - 1, i, len); + } } + } } - ////////////////////////////////////////////////////////////////////////// template -T SVD::secularEq(const T diff, const NDArray& col0, const NDArray& diag, const NDArray& permut, const NDArray& diagShifted, const T shift) { - - auto len = permut.lengthOf(); - T res = 1.; - T item; - for(int i=0; i(i); - item = col0.t(j) / ((diagShifted.t(j) - diff) * (diag.t(j) + shift + diff)); - res += item * col0.t(j); - } - - return res; +T SVD::secularEq(const T diff, const NDArray& col0, const NDArray& diag, + const NDArray& permut, const NDArray& diagShifted, + const T shift) { + auto len = permut.lengthOf(); + T res = 1.; + T item; + for (int i = 0; i < len; ++i) { + int j = (int)permut.t(i); + item = col0.t(j) / + ((diagShifted.t(j) - diff) * (diag.t(j) + shift + diff)); + res += item * col0.t(j); + } + + return res; } ////////////////////////////////////////////////////////////////////////// template -void SVD::calcSingVals(const NDArray& col0, const NDArray& diag, const NDArray& permut, NDArray& singVals, NDArray& shifts, NDArray& mus) { - - auto len = col0.lengthOf(); - auto curLen = len; +void SVD::calcSingVals(const NDArray& col0, const NDArray& diag, + const NDArray& permut, NDArray& singVals, + NDArray& shifts, NDArray& mus) { + auto len = col0.lengthOf(); + auto curLen = len; - while(curLen > 1 && col0.t(curLen-1) == (T)0.f) - --curLen; + while (curLen > 1 && col0.t(curLen - 1) == (T)0.f) --curLen; - for (Nd4jLong k = 0; k < len; ++k) { - - if (col0.t(k) == (T)0.f || curLen==1) { - - singVals.r(k) = k==0 ? col0.t(0) : diag.t(k); - mus.r(k) = (T)0; - shifts.r(k) = k==0 ? col0.t(0) : diag.t(k); - continue; - } - - T left = diag.t(k); - T right; - - if(k==curLen-1) - right = diag.t(curLen-1) + col0.reduceNumber(reduce::Norm2).t(0); - else { - - int l = k+1; - while(col0.t(l) == (T)0.f) { - ++l; - if(l >= curLen) - throw std::runtime_error("ops::helpers::SVD::calcSingVals method: l >= curLen !"); - } - - right = diag.t(l); - } - - T mid = left + (right - left) / (T)2.; - T fMid = secularEq(mid, col0, diag, permut, diag, 0.); - T shift = (k == curLen-1 || fMid > (T)0.) ? left : right; + for (Nd4jLong k = 0; k < len; ++k) { + if (col0.t(k) == (T)0.f || curLen == 1) { + singVals.r(k) = k == 0 ? col0.t(0) : diag.t(k); + mus.r(k) = (T)0; + shifts.r(k) = k == 0 ? col0.t(0) : diag.t(k); + continue; + } - auto diagShifted = diag - shift; + T left = diag.t(k); + T right; - T muPrev, muCur; - if (shift == left) { - muPrev = (right - left) * 0.1; - if (k == curLen-1) - muCur = right - left; - else - muCur = (right - left) * 0.5; - } + if (k == curLen - 1) + right = diag.t(curLen - 1) + col0.reduceNumber(reduce::Norm2).t(0); + else { + int l = k + 1; + while (col0.t(l) == (T)0.f) { + ++l; + if (l >= curLen) + throw std::runtime_error( + "ops::helpers::SVD::calcSingVals method: l >= curLen !"); + } + + right = diag.t(l); + } + + T mid = left + (right - left) / (T)2.; + T fMid = secularEq(mid, col0, diag, permut, diag, 0.); + T shift = (k == curLen - 1 || fMid > (T)0.) ? left : right; + + auto diagShifted = diag - shift; + + T muPrev, muCur; + if (shift == left) { + muPrev = (right - left) * 0.1; + if (k == curLen - 1) + muCur = right - left; + else + muCur = (right - left) * 0.5; + } else { + muPrev = -(right - left) * 0.1; + muCur = -(right - left) * 0.5; + } + + T fPrev = secularEq(muPrev, col0, diag, permut, diagShifted, shift); + T fCur = secularEq(muCur, col0, diag, permut, diagShifted, shift); + + if (math::nd4j_abs(fPrev) < math::nd4j_abs(fCur)) { + math::nd4j_swap(fPrev, fCur); + math::nd4j_swap(muPrev, muCur); + } + + bool useBisection = fPrev * fCur > (T)0.; + while (fCur != (T).0 && + math::nd4j_abs(muCur - muPrev) > + (T)8. * DataTypeUtils::eps() * + math::nd4j_max(math::nd4j_abs(muCur), + math::nd4j_abs(muPrev)) && + math::nd4j_abs(fCur - fPrev) > DataTypeUtils::eps() && + !useBisection) { + T a = (fCur - fPrev) / ((T)1. / muCur - (T)1. / muPrev); + T jac = fCur - a / muCur; + T muZero = -a / jac; + T fZero = secularEq(muZero, col0, diag, permut, diagShifted, shift); + + muPrev = muCur; + fPrev = fCur; + muCur = muZero; + fCur = fZero; + + if (shift == left && (muCur < (T)0. || muCur > right - left)) + useBisection = true; + else if (shift == right && (muCur < -(right - left) || muCur > (T)0.)) + useBisection = true; + else if (math::nd4j_abs(fCur) > math::nd4j_abs(fPrev) && + math::nd4j_abs(fCur - fPrev) > (T)16. * DataTypeUtils::eps()) + useBisection = true; + } + + if (useBisection) { + T leftShifted, rightShifted; + if (shift == left) { + leftShifted = DataTypeUtils::min(); + rightShifted = (k == curLen - 1) ? right : ((right - left) * (T)0.6); + } else { + leftShifted = -(right - left) * (T)0.6; + rightShifted = -DataTypeUtils::min(); + } + + T fLeft = secularEq(leftShifted, col0, diag, permut, diagShifted, shift); + T fRight = + secularEq(rightShifted, col0, diag, permut, diagShifted, shift); + // if(fLeft * fRight >= (T)0.) + // throw "ops::helpers::SVD::calcSingVals method: fLeft * fRight >= (T)0. + // !"; + + while (rightShifted - leftShifted > + (T)2.f * DataTypeUtils::eps() * + math::nd4j_max(math::nd4j_abs(leftShifted), + math::nd4j_abs(rightShifted))) { + T midShifted = (leftShifted + rightShifted) / (T)2.; + fMid = secularEq(midShifted, col0, diag, permut, diagShifted, shift); + if (fLeft * fMid < (T)0.) + rightShifted = midShifted; else { - muPrev = -(right - left) * 0.1; - muCur = -(right - left) * 0.5; - } - - T fPrev = secularEq(muPrev, col0, diag, permut, diagShifted, shift); - T fCur = secularEq(muCur, col0, diag, permut, diagShifted, shift); - - if (math::nd4j_abs(fPrev) < math::nd4j_abs(fCur)) { - math::nd4j_swap(fPrev, fCur); - math::nd4j_swap(muPrev, muCur); + leftShifted = midShifted; + fLeft = fMid; } - - bool useBisection = fPrev * fCur > (T)0.; - while (fCur != (T).0 && - math::nd4j_abs(muCur - muPrev) > (T)8. * DataTypeUtils::eps() * math::nd4j_max(math::nd4j_abs(muCur), math::nd4j_abs(muPrev)) - && math::nd4j_abs(fCur - fPrev) > DataTypeUtils::eps() && !useBisection) { - - T a = (fCur - fPrev) / ((T)1./muCur - (T)1./muPrev); - T jac = fCur - a / muCur; - T muZero = -a/jac; - T fZero = secularEq(muZero, col0, diag, permut, diagShifted, shift); - - muPrev = muCur; - fPrev = fCur; - muCur = muZero; - fCur = fZero; - - if (shift == left && (muCur < (T)0. || muCur > right - left)) - useBisection = true; - else if (shift == right && (muCur < -(right - left) || muCur > (T)0.)) - useBisection = true; - else if (math::nd4j_abs(fCur) > math::nd4j_abs(fPrev) && math::nd4j_abs(fCur - fPrev) > (T)16. * DataTypeUtils::eps()) - useBisection = true; - } - - if (useBisection) { - - T leftShifted, rightShifted; - if (shift == left) { - leftShifted = DataTypeUtils::min(); - rightShifted = (k==curLen-1) ? right : ((right - left) * (T)0.6); - } - else { - leftShifted = -(right - left) * (T)0.6; - rightShifted = -DataTypeUtils::min(); - } - - T fLeft = secularEq(leftShifted, col0, diag, permut, diagShifted, shift); - T fRight = secularEq(rightShifted, col0, diag, permut, diagShifted, shift); - // if(fLeft * fRight >= (T)0.) - // throw "ops::helpers::SVD::calcSingVals method: fLeft * fRight >= (T)0. !"; - - while (rightShifted - leftShifted > (T)2.f * DataTypeUtils::eps() * math::nd4j_max(math::nd4j_abs(leftShifted), math::nd4j_abs(rightShifted))) { - - T midShifted = (leftShifted + rightShifted) / (T)2.; - fMid = secularEq(midShifted, col0, diag, permut, diagShifted, shift); - if (fLeft * fMid < (T)0.) - rightShifted = midShifted; - else { - leftShifted = midShifted; - fLeft = fMid; - } - } - muCur = (leftShifted + rightShifted) / (T)2.; - } - singVals.r(k) = shift + muCur; - shifts.r(k) = shift; - mus.r(k) = muCur; + } + muCur = (leftShifted + rightShifted) / (T)2.; } + singVals.r(k) = shift + muCur; + shifts.r(k) = shift; + mus.r(k) = muCur; + } } ////////////////////////////////////////////////////////////////////////// template -void SVD::perturb(const NDArray& col0, const NDArray& diag, const NDArray& permut, const NDArray& singVals, const NDArray& shifts, const NDArray& mus, NDArray& zhat) { - - int n = col0.lengthOf(); - int m = permut.lengthOf(); - if(m==0) { - zhat.nullify(); - return; - } - - int last = permut.t(m-1); - - for (int k = 0; k < n; ++k) { - - if (col0.t(k) == (T)0.f) - zhat.r(k) = (T)0; - else { - T dk = diag.t(k); - T prod = (singVals.t(last) + dk) * (mus.t(last) + (shifts.t(last) - dk)); - - for(int l = 0; l(l); - if(i!=k) { - int j = i(l-1); - prod *= ((singVals.t(j)+dk) / ((diag.t(i)+dk))) * ((mus.t(j)+(shifts.t(j)-dk)) / ((diag.t(i)-dk))); - } - } - T tmp = math::nd4j_sqrt(prod); - zhat.r(k) = col0.t(k) > (T)0 ? tmp : -tmp; +void SVD::perturb(const NDArray& col0, const NDArray& diag, + const NDArray& permut, const NDArray& singVals, + const NDArray& shifts, const NDArray& mus, NDArray& zhat) { + int n = col0.lengthOf(); + int m = permut.lengthOf(); + if (m == 0) { + zhat.nullify(); + return; + } + + int last = permut.t(m - 1); + + for (int k = 0; k < n; ++k) { + if (col0.t(k) == (T)0.f) + zhat.r(k) = (T)0; + else { + T dk = diag.t(k); + T prod = (singVals.t(last) + dk) * + (mus.t(last) + (shifts.t(last) - dk)); + + for (int l = 0; l < m; ++l) { + int i = (int)permut.t(l); + if (i != k) { + int j = i < k ? i : (int)permut.t(l - 1); + prod *= + ((singVals.t(j) + dk) / ((diag.t(i) + dk))) * + ((mus.t(j) + (shifts.t(j) - dk)) / ((diag.t(i) - dk))); } + } + T tmp = math::nd4j_sqrt(prod); + zhat.r(k) = col0.t(k) > (T)0 ? tmp : -tmp; } + } } - ////////////////////////////////////////////////////////////////////////// template -void SVD::calcSingVecs(const NDArray& zhat, const NDArray& diag, const NDArray& perm, const NDArray& singVals, - const NDArray& shifts, const NDArray& mus, NDArray& U, NDArray& V) { - - int n = zhat.lengthOf(); - int m = perm.lengthOf(); - - for (int k = 0; k < n; ++k) { - - NDArray colU = U({0,0, k,k+1}); - colU.nullify(); - - NDArray colV; - - if (_calcV) { - colV = V({0,0, k,k+1}); - colV.nullify(); - } - - if (zhat.t(k) == (T)0.f) { - colU.r(k) = (T)1; - - if (_calcV) - colV.r(k) = (T)1; - } - else { - - for(int l = 0; l < m; ++l) { - int i = (int)perm.t(l); - U.r(i,k) = zhat.t(i)/(((diag.t(i) - shifts.t(k)) - mus.t(k)) )/( (diag.t(i) + singVals.t(k))); - } - U.r(n,k) = (T)0; - colU /= colU.reduceNumber(reduce::Norm2); - - if (_calcV) { +void SVD::calcSingVecs(const NDArray& zhat, const NDArray& diag, + const NDArray& perm, const NDArray& singVals, + const NDArray& shifts, const NDArray& mus, NDArray& U, + NDArray& V) { + int n = zhat.lengthOf(); + int m = perm.lengthOf(); + + for (int k = 0; k < n; ++k) { + NDArray colU =U({0, 0, k, k + 1}); + colU.nullify(); + NDArray colV; - for(int l = 1; l < m; ++l){ - int i = perm.t(l); - V.r(i,k) = diag.t(i) * zhat.t(i) / (((diag.t(i) - shifts.t(k)) - mus.t(k)) )/( (diag.t(i) + singVals.t(k))); - } - V.r(0,k) = (T)-1; - colV /= colV.reduceNumber(reduce::Norm2); - } + if (_calcV) { + colV = V({0, 0, k, k + 1}); + colV.nullify(); + } + + if (zhat.t(k) == (T)0.f) { + colU.r(k) = (T)1; + + if (_calcV) colV.r(k) = (T)1; + } else { + for (int l = 0; l < m; ++l) { + int i = (int)perm.t(l); + U.r(i, k) = + zhat.t(i) / (((diag.t(i) - shifts.t(k)) - mus.t(k))) / + ((diag.t(i) + singVals.t(k))); + } + U.r(n, k) = (T)0; + colU /= colU.reduceNumber(reduce::Norm2); + + if (_calcV) { + for (int l = 1; l < m; ++l) { + int i = perm.t(l); + V.r(i, k) = + diag.t(i) * zhat.t(i) / + (((diag.t(i) - shifts.t(k)) - mus.t(k))) / + ((diag.t(i) + singVals.t(k))); } + V.r(0, k) = (T)-1; + colV /= colV.reduceNumber(reduce::Norm2); + } } - NDArray colU = U({0,0, n,n+1}); - colU.nullify(); - colU.r(n) = (T)1; -} + } + NDArray colU = U({0, 0, n, n + 1}); + colU .nullify(); + colU.r(n) = (T)1; +} ////////////////////////////////////////////////////////////////////////// template -void SVD::calcBlockSVD(int col1, int size, NDArray& U, NDArray& singVals, NDArray& V) { - - const T almostZero = DataTypeUtils::min(); - auto col0 = _m({col1, col1+size, col1, col1+1}, true); - auto diag = static_cast(_m({col1, col1+size, col1, col1+size}, true).diagonal('c')); - - diag.r(0) = (T)0; - singVals = NDArray(_m.ordering(), {size, 1}, _m.dataType(), _m.getContext()); - U = NDArray(_u.ordering(), {size+1, size+1}, _u.dataType(), _u.getContext()); - if (_calcV) - V = NDArray(_v.ordering(), {size, size}, _v.dataType(), _v.getContext()); - - int curSize = size; - while(curSize > 1 && diag.template t(curSize-1) == (T)0.f) - --curSize; - - int m = 0; - std::vector indices; - for(int k = 0; k < curSize; ++k) - if(math::nd4j_abs(col0.template t(k)) > almostZero) - indices.push_back(k); - - NDArray permut(_m.ordering(), {(int)indices.size()}, _m.dataType(), _m.getContext()); - for(int k = 0; k < indices.size(); ++k) +void SVD::calcBlockSVD(int col1, int size, NDArray& U, NDArray& singVals, + NDArray& V) { + const T almostZero = DataTypeUtils::min(); + auto col0 = _m({col1, col1 + size, col1, col1 + 1}, true); + auto diag = + _m({col1, col1 + size, col1, col1 + size}, true).diagonal('c').dup(); + + diag.r(0) = (T)0; + singVals = + NDArray(_m.ordering(), {size, 1}, _m.dataType(),_m.getContext()); + U = NDArray(_u.ordering(), {size + 1, size + 1}, + _u.dataType(), _u.getContext()); + if (_calcV) + V = NDArray(_v.ordering(), {size, size}, _v.dataType(), _v.getContext()); + + int curSize = size; + while (curSize > 1 && diag.template t(curSize - 1) == (T)0.f) --curSize; + + int m = 0; + std::vector indices; + for (int k = 0; k < curSize; ++k) + if (math::nd4j_abs(col0.template t(k)) > almostZero) + indices.push_back(k); + + NDArray permut( + _m.ordering(), {(int)indices.size()}, _m.dataType(), _m.getContext()); + for(int k = 0; k < indices.size(); ++k) permut.r(k) = (T)indices[k]; - NDArray shifts(_m.ordering(), {size, 1}, _m.dataType(), _m.getContext()); - NDArray mus(_m.ordering(), {size, 1}, _m.dataType(), _m.getContext()); - NDArray zhat(_m.ordering(), {size, 1}, _m.dataType(), _m.getContext()); - - calcSingVals(col0, diag, permut, singVals, shifts, mus); - perturb(col0, diag, permut, singVals, shifts, mus, zhat); - calcSingVecs(zhat, diag, permut, singVals, shifts, mus, U, V); + NDArray shifts(_m.ordering(), {size, 1}, _m.dataType(),_m.getContext()); + NDArray mus (_m.ordering(), {size, 1}, _m.dataType(),_m.getContext()); + NDArray zhat (_m.ordering(), {size, 1}, _m.dataType(), _m.getContext()); - for(int i=0; i(i) > singVals.t(i+1)) { + for (int i = 0; i < curSize - 1; ++i) { + if (singVals.t(i) > singVals.t(i+1)) { math::nd4j_swap(singVals.r(i), singVals.r(i+1)); @@ -629,9 +656,10 @@ void SVD::calcBlockSVD(int col1, int size, NDArray& U, NDArray& singVals, NDA } } - auto temp1 = singVals({0,curSize, 0,0}); - for (int e = 0; e < curSize / 2; ++e) - math::nd4j_swap(temp1.r(e), temp1.r(curSize-1-e)); + auto temp1 = singVals({0, curSize, 0, 0}); + for (int e = 0; e < curSize / 2; ++e) + math::nd4j_swap(temp1.r(e), temp1.r(curSize - 1 - e)); + auto temp2 = U({0,0, 0,curSize}, true); for(int i = 0; i < curSize/2; ++i) { @@ -650,117 +678,119 @@ void SVD::calcBlockSVD(int col1, int size, NDArray& U, NDArray& singVals, NDA } } - ////////////////////////////////////////////////////////////////////////// -template -void SVD::DivideAndConquer(int col1, int col2, int row1W, int col1W, int shift) { - - // requires rows = cols + 1; - const int n = col2 - col1 + 1; - const int k = n/2; - const T almostZero = DataTypeUtils::min(); - T alphaK, betaK, r0, lambda, phi, c0, s0; - - NDArray l(_u.ordering(), {1, k}, _u.dataType(), _u.getContext()); - NDArray f(_u.ordering(), {1, n-k-1}, _u.dataType(), _u.getContext()); +template +void SVD::DivideAndConquer(int col1, int col2, int row1W, int col1W, + int shift) { + // requires rows = cols + 1; + const int n = col2 - col1 + 1; + const int k = n / 2; + const T almostZero = DataTypeUtils::min(); + T alphaK, betaK, r0, lambda, phi, c0, s0; + NDArray l(_u.ordering(), {1, k}, _u.dataType(),_u.getContext()); + NDArray f(_u.ordering(), {1, n - k - 1}, _u.dataType(), _u.getContext()); + + if (n < _switchSize) { + JacobiSVD jac(_m({col1, col1 + n + 1, col1, col1 + n}, true), _calcU, + _calcV, _fullUV); - if(n < _switchSize) { + if (_calcU) + _u({col1, col1 + n + 1, col1, col1 + n + 1}, true).assign(jac._u); + else { + _u({0, 1, col1, col1 + n + 1}, true).assign(jac._u({0, 1, 0, 0}, true)); + _u({1, 2, col1, col1 + n + 1}, true).assign(jac._u({n, n + 1, 0, 0}, true)); + } - JacobiSVD jac(_m({col1,col1+n+1, col1,col1+n}, true), _calcU, _calcV, _fullUV); + if (_calcV) + _v({row1W, row1W + n, col1W, col1W + n}, true).assign(jac._v); - if (_calcU) - _u({col1,col1+n+1, col1,col1+n+1}, true).assign(jac._u); - else { - _u({0,1, col1,col1+n+1}, true).assign(jac._u({0,1, 0,0}, true)); - _u({1,2, col1,col1+n+1}, true).assign(jac._u({n,n+1, 0,0}, true)); - } - if (_calcV) - _v({row1W,row1W+n, col1W,col1W+n}, true).assign(jac._v); - _m({col1+shift,col1+shift+n+1, col1+shift,col1+shift+n}, true).nullify(); - auto diag = _m.diagonal('c'); - diag({col1+shift, col1+shift+n, 0,0}, true).assign(jac._s({0,n, 0,0}, true)); + _m({col1 + shift, col1 + shift + n + 1, col1 + shift, col1 + shift + n}, + true).nullify(); + auto diag = _m.diagonal('c'); + diag({col1 + shift, col1 + shift + n, 0, 0}, true) + .assign(jac._s({0, n, 0, 0}, true)); - return; - } + return; + } - alphaK = _m.t(col1 + k, col1 + k); - betaK = _m.t(col1 + k + 1, col1 + k); + alphaK = _m.t(col1 + k, col1 + k); + betaK = _m.t(col1 + k + 1, col1 + k); - DivideAndConquer(k + 1 + col1, col2, k + 1 + row1W, k + 1 + col1W, shift); - DivideAndConquer(col1, k - 1 + col1, row1W, col1W + 1, shift + 1); + DivideAndConquer(k + 1 + col1, col2, k + 1 + row1W, k + 1 + col1W, shift); + DivideAndConquer(col1, k - 1 + col1, row1W, col1W + 1, shift + 1); - if (_calcU) { - lambda = _u.t(col1 + k, col1 + k); - phi = _u.t(col1 + k + 1, col2 + 1); - } - else { - lambda = _u.t(1, col1 + k); - phi = _u.t(0, col2 + 1); - } + if (_calcU) { + lambda = _u.t(col1 + k, col1 + k); + phi = _u.t(col1 + k + 1, col2 + 1); + } else { + lambda = _u.t(1, col1 + k); + phi = _u.t(0, col2 + 1); + } - r0 = math::nd4j_sqrt((math::nd4j_abs(alphaK * lambda) * math::nd4j_abs(alphaK * lambda)) + math::nd4j_abs(betaK * phi) * math::nd4j_abs(betaK * phi)); + r0 = math::nd4j_sqrt((math::nd4j_abs(alphaK * lambda) * + math::nd4j_abs(alphaK * lambda)) + + math::nd4j_abs(betaK * phi) * + math::nd4j_abs(betaK * phi)); - if(_calcU) { - l.assign(_u({col1+k, col1+k+1, col1,col1+k}, true)); - f.assign(_u({col1+k+1,col1+k+2, col1+k+1,col1+n}, true)); - } - else { - l.assign(_u({1,2, col1, col1+k}, true)); - f.assign(_u({0,1, col1+k+1, col1+n}, true)); - } + if (_calcU) { + l.assign(_u({col1 + k, col1 + k + 1, col1, col1 + k}, true)); + f.assign(_u({col1 + k + 1, col1 + k + 2, col1 + k + 1, col1 + n}, true)); + } else { + l.assign(_u({1, 2, col1, col1 + k}, true)); + f.assign(_u({0, 1, col1 + k + 1, col1 + n}, true)); + } - if (_calcV) - _v.r(row1W+k, col1W) = (T)1; + if (_calcV) _v.r(row1W + k, col1W) = (T)1; - if (r0 < almostZero){ - c0 = 1.; - s0 = 0.; - } - else { - c0 = alphaK * lambda / r0; - s0 = betaK * phi / r0; - } + if (r0 < almostZero) { + c0 = 1.; + s0 = 0.; + } else { + c0 = alphaK * lambda / r0; + s0 = betaK * phi / r0; + } if (_calcU) { NDArray q1 = _u({col1,col1+k+1, col1+k,col1+k+1}, true).dup(); - for (int i = col1 + k - 1; i >= col1; --i) - _u({col1,col1+k+1, i+1,i+2}, true).assign(_u({col1,col1+k+1, i,i+1}, true)); - - NDArray temp1 = _u({col1+k+1,col1+n+1, col2+1,col2+2}, true); - - _u({col1,col1+k+1, col1,col1+1}, true).assign(q1 * c0); - _u({col1,col1+k+1, col2+1,col2+2}, true).assign(q1 * (-s0)); - _u({col1+k+1,col1+n+1, col1,col1+1}, true).assign(temp1 * s0); - temp1 *= c0; - } - else { - - T q1 = _u.t(0, col1 + k); - - for (int i = col1 + k - 1; i >= col1; --i) - _u.r(0, i+1) = _u.r(0, i); - - _u.r(0, col1) = q1 * c0; - _u.r(0, col2+1) = -q1*s0; - _u.r(1, col1) = _u.t(1, col2+1) * s0; - _u.r(1, col2+1) = _u.t(1, col2+1) * c0; - _u({1,2, col1+1, col1+k+1}).nullify(); - _u({0,1, col1+k+1, col1+n}).nullify(); - } - - _m.r(col1+shift, col1+shift) = r0; - - _m({col1+shift+1,col1+shift+k+1, col1+shift,col1+shift+1}, true).assign(l*alphaK); - _m({col1+shift+k+1,col1+shift+n, col1+shift,col1+shift+1}, true).assign(f*betaK); - - deflation(col1, col2, k, row1W, col1W, shift); - - NDArray UofSVD, VofSVD, singVals; - calcBlockSVD(col1 + shift, n, UofSVD, singVals, VofSVD); + for (int i = col1 + k - 1; i >= col1; --i) + _u({col1, col1 + k + 1, i + 1, i + 2}, true).assign(_u({col1, col1 + k + 1, i, i + 1}, true)); + NDArray temp1 = _u({col1+k+1,col1+n+1, col2+1,col2+2}, true); + + _u({col1, col1 + k + 1, col1, col1 + 1}, true).assign(q1 * c0); + _u({col1, col1 + k + 1, col2 + 1, col2 + 2}, true).assign(q1 * (-s0)); + _u({col1 + k + 1, col1 + n + 1, col1, col1 + 1}, true) + .assign(temp1 * + s0); + temp1 *= c0; + } else { + T q1 = _u.t(0, col1 + k); + + for (int i = col1 + k - 1; i >= col1; --i) _u.r(0, i + 1) = _u.r(0, i); + + _u.r(0, col1) = q1 * c0; + _u.r(0, col2 + 1) = -q1 * s0; + _u.r(1, col1) = _u.t(1, col2 + 1) * s0; + _u.r(1, col2+1) = _u.t(1, col2+1) * c0; + _u({1, 2, col1 + 1, col1 + k + 1}).nullify(); + _u({0, 1, col1 + k + 1, col1 + n}).nullify(); + } + + _m.r(col1+shift, col1+shift) = r0; + _m( + {col1 + shift + 1, col1 + shift + k + 1, col1 + shift, col1 + shift + 1}, + true).assign(l * alphaK); + _m( + {col1 + shift + k + 1, col1 + shift + n, col1 + shift, col1 + shift + 1}, + true).assign(f * betaK); + + deflation(col1, col2, k, row1W, col1W, shift); + + NDArray UofSVD, VofSVD, singVals; + calcBlockSVD(col1 + shift, n, UofSVD, singVals, VofSVD); if(_calcU) { auto temp = _u({col1, col1+n+1, col1,col1+n+1}, true); @@ -776,100 +806,91 @@ void SVD::DivideAndConquer(int col1, int col2, int row1W, int col1W, int shif temp.assign(mmul(temp, VofSVD)); } - auto blockM = _m({col1+shift,col1+shift+n, col1+shift,col1+shift+n}, true); - blockM.nullify(); - blockM.diagonal('c').assign(singVals); + auto blockM = _m( + {col1 + shift, col1 + shift + n, col1 + shift, col1 + shift + n}, true); + blockM .nullify(); + blockM.diagonal('c').assign(singVals); } ////////////////////////////////////////////////////////////////////////// -template -void SVD::exchangeUV(const HHsequence& hhU, const HHsequence& hhV, const NDArray& U, const NDArray& V) { - - if (_calcU) { - - int colsU = _fullUV ? hhU.rows() : _diagSize; - NDArray temp1(_u.ordering(), {hhU.rows(), colsU}, _u.dataType(), _u.getContext()); - temp1.setIdentity(); - _u = temp1; - - _u({0,_diagSize, 0,_diagSize}, true).assign(V({0,_diagSize, 0,_diagSize}, true)); - const_cast(hhU).mulLeft(_u); - } - - if (_calcV) { - - int colsV = _fullUV ? hhV.rows() : _diagSize; - NDArray temp1(_v.ordering(), {hhV.rows(), colsV}, _v.dataType(), _v.getContext()); - temp1.setIdentity(); - _v = temp1; - - _v({0,_diagSize, 0,_diagSize}, true).assign(U({0,_diagSize, 0,_diagSize}, true)); - const_cast(hhV).mulLeft(_v); - } +template +void SVD::exchangeUV(const HHsequence& hhU, const HHsequence& hhV, + const NDArray& U, const NDArray& V) { + if (_calcU) { + int colsU = _fullUV ? hhU.rows() : _diagSize; + NDArray temp1(_u.ordering(), {hhU.rows(), colsU}, + _u.dataType(), _u.getContext()); + temp1.setIdentity(); + _u = temp1; + + _u({0, _diagSize, 0, _diagSize}, true).assign(V({0, _diagSize, 0, _diagSize}, true)); + const_cast(hhU).mulLeft(_u); + } + + if (_calcV) { + int colsV = _fullUV ? hhV.rows() : _diagSize; + NDArray temp1(_v.ordering(), {hhV.rows(), colsV}, + _v.dataType(), _v.getContext()); + temp1.setIdentity(); + _v = temp1; + + _v({0, _diagSize, 0, _diagSize}, true).assign(U({0, _diagSize, 0, _diagSize}, true)); + const_cast(hhV).mulLeft(_v); + } } ////////////////////////////////////////////////////////////////////////// template void SVD::evalData(const NDArray& matrix) { + const T almostZero = DataTypeUtils::min(); - const T almostZero = DataTypeUtils::min(); + if (matrix.sizeAt(1) < _switchSize) { + JacobiSVD jac(matrix, _calcU, _calcV, _fullUV); - if(matrix.sizeAt(1) < _switchSize) { + if (_calcU) _u = jac._u; + if (_calcV) _v = jac._v; - JacobiSVD jac(matrix, _calcU, _calcV, _fullUV); + _s.assign(jac._s); - if(_calcU) - _u = jac._u; - if(_calcV) - _v = jac._v; + return; + } - _s.assign(jac._s); + T scale = matrix.reduceNumber(reduce::AMax).t(0); - return; - } - - T scale = matrix.reduceNumber(reduce::AMax).t(0); - - if(scale == (T)0.) - scale = 1.; + if (scale == (T)0.) scale = 1.; - BiDiagonalUp biDiag(_transp ? matrix.transpose() : matrix / scale); + BiDiagonalUp biDiag (_transp? matrix.transpose() : matrix / scale); - _u.nullify(); - _v.nullify(); + _u.nullify(); + _v.nullify(); - _m({0,_diagSize, 0,0}, true).assign(biDiag._HHbidiag.transpose()); + _m({0, _diagSize, 0, 0}, true).assign(biDiag._HHbidiag.transpose()); - _m({_m.sizeAt(0)-1,_m.sizeAt(0), 0,0}).nullify(); + _m({_m.sizeAt(0) - 1, _m.sizeAt(0), 0, 0}).nullify(); - DivideAndConquer(0, _diagSize - 1, 0, 0, 0); + DivideAndConquer(0, _diagSize - 1, 0, 0, 0); - for (int i = 0; i < _diagSize; ++i) { - T a = math::nd4j_abs(_m.t(i, i)); - _s.r(i) = a * scale; - if (a < almostZero) { - _s({i+1,_diagSize, 0,0}).nullify(); - break; - } - else if (i == _diagSize-1) - break; - } + for (int i = 0; i < _diagSize; ++i) { + T a = math::nd4j_abs(_m.t(i, i)); + _s.r(i) = a * scale; + if (a < almostZero) { + _s({i + 1, _diagSize, 0, 0}).nullify(); + break; + } else if (i == _diagSize - 1) + break; + } - HHsequence hhV = biDiag.makeHHsequence('v'); + HHsequence hhV = biDiag.makeHHsequence('v'); HHsequence hhU = biDiag.makeHHsequence('u'); if(_transp) exchangeUV(hhV, hhU, _v, _u); - else - exchangeUV(hhU, hhV, _u, _v); + else + exchangeUV(hhU, hhV, _u, _v); } +BUILD_SINGLE_TEMPLATE(template class SD_EXPORT SVD, , FLOAT_TYPES); -BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT SVD,,FLOAT_TYPES); - - - -} -} -} - +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/helpers/cublasHelper.h b/libnd4j/include/helpers/cublasHelper.h index 8ebdc66a7307..f2e8cf3986cf 100644 --- a/libnd4j/include/helpers/cublasHelper.h +++ b/libnd4j/include/helpers/cublasHelper.h @@ -18,34 +18,35 @@ // @author raver119@gmail.com // -#ifndef DEV_TESTS_CUBLASHELPER_H -#define DEV_TESTS_CUBLASHELPER_H +#ifndef SD_CUBLASHELPER_H +#define SD_CUBLASHELPER_H #include #include -#include + #include +#include namespace sd { - class ND4J_EXPORT CublasHelper { - private: - static std::mutex _mutex; +class SD_EXPORT CublasHelper { + private: + static std::mutex _mutex; + + std::vector _cache; + std::vector _solvers; + std::vector _cudnn; - std::vector _cache; - std::vector _solvers; - std::vector _cudnn; + CublasHelper(); - CublasHelper(); - public: - ~CublasHelper(); - static CublasHelper& getInstance(); + public:~CublasHelper(); + static CublasHelper& getInstance(); - void* cudnn(); - void* solver(); + void* cudnn(); + void* solver(); - void* handle(); - void* handle(int deviceId); - }; -} + void* handle(); + void* handle(int deviceId); +}; +} // namespace sd -#endif //DEV_TESTS_CUBLASHELPER_H +#endif // SD_CUBLASHELPER_H diff --git a/libnd4j/include/helpers/cuda/ConstantHelper.cu b/libnd4j/include/helpers/cuda/ConstantHelper.cu index 7eb9273e5fff..32f4a5b41a14 100644 --- a/libnd4j/include/helpers/cuda/ConstantHelper.cu +++ b/libnd4j/include/helpers/cuda/ConstantHelper.cu @@ -1,6 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. - * Copyright (c) 2019 Konduit K.K. + * Copyright (c) 2019 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -19,16 +19,17 @@ // @author raver119@gmail.com // +#include +#include +#include #include +#include +#include #include -#include +#include #include -#include #include #include -#include -#include -#include #include #define CONSTANT_LIMIT 49152 @@ -36,161 +37,174 @@ __constant__ char deviceConstantMemory[CONSTANT_LIMIT]; namespace sd { - static void* getConstantSpace() { - Nd4jPointer dConstAddr; - auto dZ = cudaGetSymbolAddress(reinterpret_cast(&dConstAddr), deviceConstantMemory); - - if (dZ != 0) - throw cuda_exception::build("cudaGetSymbolAddress(...) failed", dZ); +static void *getConstantSpace() { + Nd4jPointer dConstAddr; + auto dZ = cudaGetSymbolAddress(reinterpret_cast(&dConstAddr), + deviceConstantMemory); - return dConstAddr; - } + if (dZ != 0) + throw cuda_exception::build("cudaGetSymbolAddress(...) failed", dZ); - int ConstantHelper::getCurrentDevice() { - return AffinityManager::currentDeviceId(); - } + return dConstAddr; +} - int ConstantHelper::getNumberOfDevices() { - return AffinityManager::numberOfDevices(); - } +int ConstantHelper::getCurrentDevice() { + return AffinityManager::currentDeviceId(); +} +int ConstantHelper::getNumberOfDevices() { + return AffinityManager::numberOfDevices(); +} - ConstantHelper::ConstantHelper() { - auto initialDevice = getCurrentDevice(); +ConstantHelper::ConstantHelper() { + auto initialDevice = getCurrentDevice(); - auto numDevices = getNumberOfDevices(); - _devicePointers.resize(numDevices); - _deviceOffsets.resize(numDevices); - _cache.resize(numDevices); - _counters.resize(numDevices); + auto numDevices = getNumberOfDevices(); + _devicePointers.resize(numDevices); + _deviceOffsets.resize(numDevices); + _cache.resize(numDevices); + _counters.resize(numDevices); - // filling all pointers - for (int e = 0; e < numDevices; e++) { - auto res = cudaSetDevice(e); - if (res != 0) - throw cuda_exception::build("cudaSetDevice failed", res); - auto constant = getConstantSpace(); + // filling all pointers + for (int e = 0; e < numDevices; e++) { + auto res = cudaSetDevice(e); + if (res != 0) throw cuda_exception::build("cudaSetDevice failed", res); + auto constant = getConstantSpace(); - MAP_IMPL devCache; + MAP_IMPL devCache; - _devicePointers[e] = constant; - _deviceOffsets[e] = 0; - _cache[e] = devCache; - _counters[e] = 0L; - } + _devicePointers[e] = constant; + _deviceOffsets[e] = 0; + _cache[e] = devCache; + _counters[e] = 0L; + } - // - auto res = cudaSetDevice(initialDevice); - if (res != 0) - throw cuda_exception::build("Final cudaSetDevice failed", res); - } + // + auto res = cudaSetDevice(initialDevice); + if (res != 0) throw cuda_exception::build("Final cudaSetDevice failed", res); +} -ConstantHelper::~ConstantHelper() { +ConstantHelper ::~ConstantHelper() { for (const auto &v:_cache) { for (const auto &c:v) { delete c.second; - } + } } } - ConstantHelper& ConstantHelper::getInstance() { + ConstantHelper& ConstantHelper::getInstance() { static ConstantHelper instance; return instance; - } +} - void* ConstantHelper::replicatePointer(void *src, size_t numBytes, memory::Workspace *workspace) { - std::lock_guard lock(_mutex); - - auto deviceId = getCurrentDevice(); - Nd4jPointer constantPtr = nullptr; - Nd4jLong constantOffset = 0L; - if (_devicePointers[deviceId] == 0) { - auto constant = getConstantSpace(); - - // filling default ptr, which will be 0 probably - _devicePointers[deviceId] = constant; - _deviceOffsets[deviceId] = 0; - constantPtr = constant; - } else { - constantPtr = _devicePointers[deviceId]; - constantOffset = _deviceOffsets[deviceId]; - } - - if (constantOffset + numBytes >= CONSTANT_LIMIT) { - int8_t *ptr = nullptr; - ALLOCATE_SPECIAL(ptr, workspace, numBytes, int8_t); - auto res = cudaMemcpy(ptr, src, numBytes, cudaMemcpyHostToDevice); - if (res != 0) - throw cuda_exception::build("cudaMemcpy failed", res); - - return ptr; - } else { - auto originalBytes = numBytes; - auto rem = numBytes % 8; - if (rem != 0) - numBytes += 8 - rem; - - _deviceOffsets[deviceId] += numBytes; - - auto res = cudaMemcpyToSymbol(deviceConstantMemory, const_cast(src), originalBytes, constantOffset, cudaMemcpyHostToDevice); - if (res != 0) - throw cuda_exception::build("cudaMemcpyToSymbol failed", res); - - return reinterpret_cast(constantPtr) + constantOffset; - } - } +void *ConstantHelper::replicatePointer(void *src, size_t numBytes, + memory::Workspace *workspace) { + std::lock_guard lock(_mutex); + + auto deviceId = getCurrentDevice(); + Nd4jPointer constantPtr = nullptr; + Nd4jLong constantOffset = 0L; + if (_devicePointers[deviceId] == 0) { + auto constant = getConstantSpace(); + + // filling default ptr, which will be 0 probably + _devicePointers[deviceId] = constant; + _deviceOffsets[deviceId] = 0; + constantPtr = constant; + } else { + constantPtr = _devicePointers[deviceId]; + constantOffset = _deviceOffsets[deviceId]; + } - ConstantDataBuffer* ConstantHelper::constantBuffer(const ConstantDescriptor &descriptor, sd::DataType dataType) { - const auto deviceId = getCurrentDevice(); + if (constantOffset + numBytes >= CONSTANT_LIMIT) { + int8_t *ptr = nullptr; + ALLOCATE_SPECIAL(ptr, workspace, numBytes, int8_t); + auto res = cudaMemcpy(ptr, src, numBytes, cudaMemcpyHostToDevice); + if (res != 0) throw cuda_exception::build("cudaMemcpy failed", res); - // all cache modifications are synchronous - _mutexHolder.lock(); + return ptr; + } else { + auto originalBytes = numBytes; + auto rem = numBytes % 8; + if (rem != 0) numBytes += 8 - rem; - if (_cache[deviceId].count(descriptor) == 0) { - _cache[deviceId][descriptor] = new ConstantHolder(); - } - auto holder = _cache[deviceId][descriptor]; + _deviceOffsets[deviceId] += numBytes; - // release cache lock - _mutexHolder.unlock(); + auto res = cudaMemcpyToSymbol(deviceConstantMemory, + const_cast(src), originalBytes, + constantOffset, cudaMemcpyHostToDevice); + if (res != 0) throw cuda_exception::build("cudaMemcpyToSymbol failed", res); - ConstantDataBuffer* result; + return reinterpret_cast(constantPtr) + constantOffset; + } +} - // access to this holder instance is synchronous - std::lock_guard lock(*holder->mutex()); +ConstantDataBuffer *ConstantHelper::constantBuffer( + const ConstantDescriptor &descriptor, sd::DataType dataType) { + const auto deviceId = getCurrentDevice(); - if (holder->hasBuffer(dataType)) { - result = holder->getConstantDataBuffer(dataType); - } else { - auto numBytes = descriptor.length() * DataTypeUtils::sizeOf(dataType); - auto cbuff = std::make_shared(new int8_t[numBytes], std::make_shared()); - _counters[deviceId] += numBytes; + // all cache modifications are synchronous + _mutexHolder.lock(); - // create buffer with this dtype - if (descriptor.isFloat()) { - BUILD_DOUBLE_SELECTOR(sd::DataType::DOUBLE, dataType, sd::SpecialTypeConverter::convertGeneric, (nullptr, const_cast(descriptor.floatValues().data()), descriptor.length(), cbuff->pointer()), (sd::DataType::DOUBLE, double), LIBND4J_TYPES); - } else if (descriptor.isInteger()) { - BUILD_DOUBLE_SELECTOR(sd::DataType::INT64, dataType, sd::SpecialTypeConverter::convertGeneric, (nullptr, const_cast(descriptor.integerValues().data()), descriptor.length(), cbuff->pointer()), (sd::DataType::INT64, Nd4jLong), LIBND4J_TYPES); - } + if (_cache[deviceId].count(descriptor) == 0) { + _cache[deviceId][descriptor] = new ConstantHolder(); + } + auto holder = _cache[deviceId][descriptor]; + + // release cache lock + _mutexHolder.unlock(); + + ConstantDataBuffer *result; + + // access to this holder instance is synchronous + std::lock_guard lock(*holder->mutex()); + + if (holder->hasBuffer(dataType)) { + result = holder->getConstantDataBuffer(dataType); + } else { + auto numBytes = descriptor.length() * DataTypeUtils::sizeOf(dataType); + auto cbuff = std::make_shared(new int8_t[numBytes], std::make_shared()); + _counters[deviceId] += numBytes; + + // create buffer with this dtype + if (descriptor.isFloat()) { + BUILD_DOUBLE_SELECTOR( + sd::DataType::DOUBLE, dataType, + sd::SpecialTypeConverter::convertGeneric, + (nullptr, const_cast(descriptor.floatValues().data()), + descriptor.length(), cbuff->pointer()), + (sd::DataType::DOUBLE, double), LIBND4J_TYPES); + } else if (descriptor.isInteger()) { + BUILD_DOUBLE_SELECTOR( + sd::DataType::INT64, dataType, + sd::SpecialTypeConverter::convertGeneric, + (nullptr, const_cast(descriptor.integerValues().data()), + descriptor.length(), cbuff->pointer()), + (sd::DataType::INT64, Nd4jLong), LIBND4J_TYPES); + } - // we don't have deallocator here. + // we don't have deallocator here. // TODO: we probably want to make use deallocator here, if we're not using constant memory - auto dbuff = std::make_shared(replicatePointer(cbuff->pointer(), descriptor.length() * DataTypeUtils::sizeOf(dataType))); + auto dbuff = std::make_shared(replicatePointer( + cbuff->pointer(), descriptor.length() * DataTypeUtils::sizeOf(dataType))); - ConstantDataBuffer dataBuffer(cbuff, dbuff, descriptor.length(), dataType); + ConstantDataBuffer dataBuffer(cbuff, dbuff, descriptor.length(), + dataType); - holder->addBuffer(dataBuffer, dataType); - result = holder->getConstantDataBuffer(dataType); - } + holder->addBuffer(dataBuffer, dataType); + result = holder->getConstantDataBuffer(dataType); + } - return result; - } + return result; +} - Nd4jLong ConstantHelper::getCachedAmount(int deviceId) { - int numDevices = getNumberOfDevices(); - if (deviceId > numDevices || deviceId < 0) - return 0L; - else - return _counters[deviceId]; - } -} \ No newline at end of file +Nd4jLong ConstantHelper::getCachedAmount(int deviceId) { + int numDevices = getNumberOfDevices(); + if (deviceId > numDevices || deviceId < 0) + return 0L; + else + return _counters[deviceId]; +} + + +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/helpers/cuda/ConstantShapeHelper.cu b/libnd4j/include/helpers/cuda/ConstantShapeHelper.cu index 35ba60ca97e2..4b2bad85ed03 100644 --- a/libnd4j/include/helpers/cuda/ConstantShapeHelper.cu +++ b/libnd4j/include/helpers/cuda/ConstantShapeHelper.cu @@ -18,173 +18,196 @@ // @author raver119@gmail.com // -#include "../ConstantShapeHelper.h" -#include #include -#include +#include #include #include #include #include +#include +#include namespace sd { - ConstantShapeHelper::ConstantShapeHelper() { - auto numDevices = AffinityManager::numberOfDevices(); - - _cache.resize(numDevices); - for (int e = 0; e < numDevices; e++) { - MAP_IMPL cache; - _cache[e] = cache; - } - } - - ConstantShapeHelper& ConstantShapeHelper::getInstance() { - static ConstantShapeHelper instance; - return instance; - } - - ConstantShapeBuffer& ConstantShapeHelper::bufferForShapeInfo(sd::DataType dataType, char order, const std::vector &shape) { - ShapeDescriptor descriptor(dataType, order, shape); - return bufferForShapeInfo(descriptor); - } - -ConstantShapeBuffer& ConstantShapeHelper::bufferForShapeInfo(const sd::DataType dataType, const char order, const int rank, const Nd4jLong* shape) { - ShapeDescriptor descriptor(dataType, order, shape, rank); - return bufferForShapeInfo(descriptor); - } - -ConstantShapeBuffer& ConstantShapeHelper::bufferForShapeInfo(const ShapeDescriptor &descriptor) { - int deviceId = AffinityManager::currentDeviceId(); - - std::lock_guard lock(_mutex); - - if (_cache[deviceId].count(descriptor) == 0) { - auto hPtr = std::make_shared(descriptor.toShapeInfo(), std::make_shared()); - auto dPtr = std::make_shared(ConstantHelper::getInstance().replicatePointer(hPtr->pointer(), shape::shapeInfoByteLength(hPtr->pointerAsT())), std::make_shared()); - ConstantShapeBuffer buffer(hPtr, dPtr); - ShapeDescriptor descriptor1(descriptor); - _cache[deviceId][descriptor1] = buffer; - return _cache[deviceId][descriptor1]; - } else { - return _cache[deviceId].at(descriptor); - } - } - -ConstantShapeBuffer& ConstantShapeHelper::bufferForShapeInfo(const Nd4jLong *shapeInfo) { - ShapeDescriptor descriptor(shapeInfo); - return bufferForShapeInfo(descriptor); - } +ConstantShapeHelper::ConstantShapeHelper() { + auto numDevices = AffinityManager::numberOfDevices(); - bool ConstantShapeHelper::checkBufferExistenceForShapeInfo(ShapeDescriptor &descriptor) { - auto deviceId = AffinityManager::currentDeviceId(); - std::lock_guard lock(_mutex); + _cache.resize(numDevices); + for (int e = 0; e < numDevices; e++) { + MAP_IMPL cache; + _cache[e] = cache; + } +} - return _cache[deviceId].count(descriptor) != 0; - } +ConstantShapeHelper& ConstantShapeHelper::getInstance() { + static ConstantShapeHelper instance; - Nd4jLong const* ConstantShapeHelper::createShapeInfo(const sd::DataType dataType, const char order, const int rank, const Nd4jLong* shape) { - ShapeDescriptor descriptor(dataType, order, shape, rank); - return bufferForShapeInfo(descriptor).primary(); - } + return instance; +} - Nd4jLong const* ConstantShapeHelper::createShapeInfo(const sd::DataType dataType, const Nd4jLong* shapeInfo) { - return ConstantShapeHelper::createShapeInfo(dataType, shape::order(shapeInfo), shape::rank(shapeInfo), shape::shapeOf(const_cast(shapeInfo))); - } +ConstantShapeBuffer& ConstantShapeHelper::bufferForShapeInfo( + sd::DataType dataType, char order, const std::vector& shape) { + ShapeDescriptor descriptor(dataType, order, shape); + return bufferForShapeInfo(descriptor); +} - Nd4jLong const* ConstantShapeHelper::emptyShapeInfo(const sd::DataType dataType) { - auto descriptor = ShapeDescriptor::emptyDescriptor(dataType); - return bufferForShapeInfo(descriptor).primary(); - } +ConstantShapeBuffer& ConstantShapeHelper::bufferForShapeInfo( + const sd::DataType dataType, const char order, const int rank, + const Nd4jLong* shape) { + ShapeDescriptor descriptor(dataType, order, shape, rank); + return bufferForShapeInfo(descriptor); +} - Nd4jLong const* ConstantShapeHelper::scalarShapeInfo(const sd::DataType dataType) { - auto descriptor = ShapeDescriptor::scalarDescriptor(dataType); - return bufferForShapeInfo(descriptor).primary(); - } +ConstantShapeBuffer& ConstantShapeHelper::bufferForShapeInfo( + const ShapeDescriptor& descriptor) { + int deviceId = AffinityManager::currentDeviceId(); + + std::lock_guard lock(_mutex); + + if (_cache[deviceId].count(descriptor) == 0) { + auto hPtr = std::make_shared(descriptor.toShapeInfo(), std::make_shared()); + auto dPtr = std::make_shared(ConstantHelper::getInstance().replicatePointer( + hPtr->pointer(), shape::shapeInfoByteLength(hPtr->pointerAsT())), std::make_shared()); + ConstantShapeBuffer buffer(hPtr, dPtr); + ShapeDescriptor descriptor1(descriptor); + _cache[deviceId][descriptor1] = buffer; + return _cache[deviceId][descriptor1]; + } else { + return _cache[deviceId].at(descriptor); + } +} - Nd4jLong const* ConstantShapeHelper::vectorShapeInfo(const Nd4jLong length, const sd::DataType dataType) { - auto descriptor = ShapeDescriptor::vectorDescriptor(length, dataType); - return bufferForShapeInfo(descriptor).primary(); - } +ConstantShapeBuffer& ConstantShapeHelper::bufferForShapeInfo( + const Nd4jLong* shapeInfo) { + ShapeDescriptor descriptor(shapeInfo); + return bufferForShapeInfo(descriptor); +} - Nd4jLong const* ConstantShapeHelper::createShapeInfo(const sd::DataType dataType, const char order, const std::vector &shape) { - ShapeDescriptor descriptor(dataType, order, shape); - return bufferForShapeInfo(descriptor).primary(); - } +bool ConstantShapeHelper::checkBufferExistenceForShapeInfo( + ShapeDescriptor& descriptor) { + auto deviceId = AffinityManager::currentDeviceId(); + std::lock_guard lock(_mutex); - Nd4jLong const* ConstantShapeHelper::createShapeInfo(const ShapeDescriptor &descriptor) { - return bufferForShapeInfo(descriptor).primary(); - } + return _cache[deviceId].count(descriptor) != 0; +} - Nd4jLong const* ConstantShapeHelper::createFromExisting(Nd4jLong *shapeInfo, bool destroyOriginal) { - ShapeDescriptor descriptor(shapeInfo); - auto result = createShapeInfo(descriptor); +Nd4jLong const* ConstantShapeHelper::createShapeInfo( + const sd::DataType dataType, const char order, const int rank, + const Nd4jLong* shape) { + ShapeDescriptor descriptor(dataType, order, shape, rank); + return bufferForShapeInfo(descriptor).primary(); +} - if (destroyOriginal) - RELEASE(shapeInfo, nullptr); +Nd4jLong const* ConstantShapeHelper::createShapeInfo( + const sd::DataType dataType, const Nd4jLong* shapeInfo) { + return ConstantShapeHelper::createShapeInfo( + dataType, shape::order(shapeInfo), shape::rank(shapeInfo), + shape::shapeOf(const_cast(shapeInfo))); +} - return result; - } +Nd4jLong const* ConstantShapeHelper::emptyShapeInfo( + const sd::DataType dataType) { + auto descriptor = ShapeDescriptor::emptyDescriptor(dataType); + return bufferForShapeInfo(descriptor).primary(); +} - Nd4jLong const* ConstantShapeHelper::createFromExisting(Nd4jLong *shapeInfo, sd::memory::Workspace *workspace) { - ShapeDescriptor descriptor(shapeInfo); - auto result = createShapeInfo(descriptor); +Nd4jLong const* ConstantShapeHelper::scalarShapeInfo( + const sd::DataType dataType) { + auto descriptor = ShapeDescriptor::scalarDescriptor(dataType); + return bufferForShapeInfo(descriptor).primary(); +} - RELEASE(shapeInfo, workspace); +Nd4jLong const* ConstantShapeHelper::vectorShapeInfo( + const Nd4jLong length, const sd::DataType dataType) { + auto descriptor = ShapeDescriptor::vectorDescriptor(length, dataType); + return bufferForShapeInfo(descriptor).primary(); +} - return result; - } +Nd4jLong const* ConstantShapeHelper::createShapeInfo( + const sd::DataType dataType, const char order, + const std::vector& shape) { + ShapeDescriptor descriptor(dataType, order, shape); + return bufferForShapeInfo(descriptor).primary(); +} -//////////////////////////////////////////////////////////////////////// -ConstantShapeBuffer& ConstantShapeHelper::createShapeInfoWithUnitiesForBroadcast(const Nd4jLong* maxShapeInfo, const Nd4jLong* minShapeInfo, sd::memory::Workspace* workspace, const std::vector& dimensions) { +Nd4jLong const* ConstantShapeHelper::createShapeInfo( + const ShapeDescriptor& descriptor) { + return bufferForShapeInfo(descriptor).primary(); +} - Nd4jLong* newShapeInfo = nullptr; - ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(shape::rank(maxShapeInfo)), Nd4jLong); +Nd4jLong const* ConstantShapeHelper::createFromExisting(Nd4jLong* shapeInfo, + bool destroyOriginal) { + ShapeDescriptor descriptor(shapeInfo); + auto result = createShapeInfo(descriptor); - newShapeInfo[0] = shape::rank(maxShapeInfo); + if (destroyOriginal) RELEASE(shapeInfo, nullptr); - sd::ArrayOptions::copyDataType(newShapeInfo, minShapeInfo); // type - newShapeInfo[2 * newShapeInfo[0] + 2] = shape::elementWiseStride(minShapeInfo); // ews - newShapeInfo[2 * newShapeInfo[0] + 3] = shape::order(minShapeInfo); // order + return result; +} - if(!dimensions.empty()) { +Nd4jLong const* ConstantShapeHelper::createFromExisting( + Nd4jLong* shapeInfo, sd::memory::Workspace* workspace) { + ShapeDescriptor descriptor(shapeInfo); + auto result = createShapeInfo(descriptor); - for(uint k = 0, j = 0, i = 0; i < shape::rank(maxShapeInfo); ++i) { + RELEASE(shapeInfo, workspace); - if(j < dimensions.size() && dimensions[j] == i) { - shape::shapeOf(newShapeInfo)[i] = shape::shapeOf(minShapeInfo)[k]; - shape::stride(newShapeInfo)[i] = shape::stride(minShapeInfo)[k++]; - ++j; - } - else{ - shape::shapeOf(newShapeInfo)[i] = 1; - shape::stride(newShapeInfo)[i] = 0; - if(shape::sizeAt(minShapeInfo, k) == 1 && dimensions.size() != shape::rank(minShapeInfo)) - ++k; - } - } - } - else{ - - for(int j = shape::rank(minShapeInfo) - 1, i = shape::rank(maxShapeInfo) - 1; i >=0 ; --i) { - - if(j >= 0) { - shape::shapeOf(newShapeInfo)[i] = shape::shapeOf(minShapeInfo)[j]; - shape::stride(newShapeInfo)[i] = shape::shapeOf(minShapeInfo)[j] == 1 ? 0 : shape::stride(minShapeInfo)[j]; - --j; - } - else { - shape::shapeOf(newShapeInfo)[i] = 1; - shape::stride(newShapeInfo)[i] = 0; - } - } - } + return result; +} - ShapeDescriptor descriptor(newShapeInfo); +//////////////////////////////////////////////////////////////////////// +ConstantShapeBuffer& ConstantShapeHelper::createShapeInfoWithUnitiesForBroadcast( + const Nd4jLong* maxShapeInfo, const Nd4jLong* minShapeInfo, + sd::memory::Workspace* workspace, const std::vector& dimensions) { + Nd4jLong* newShapeInfo = nullptr; + ALLOCATE(newShapeInfo, workspace, + shape::shapeInfoLength(shape::rank(maxShapeInfo)), Nd4jLong); + + newShapeInfo[0] = shape::rank(maxShapeInfo); + + sd::ArrayOptions::copyDataType(newShapeInfo, minShapeInfo); // type + newShapeInfo[2 * newShapeInfo[0] + 2] = + shape::elementWiseStride(minShapeInfo); // ews + newShapeInfo[2 * newShapeInfo[0] + 3] = shape::order(minShapeInfo); // order + + if (!dimensions.empty()) { + for (uint k = 0, j = 0, i = 0; i < shape::rank(maxShapeInfo); ++i) { + if (j < dimensions.size() && dimensions[j] == i) { + shape::shapeOf(newShapeInfo)[i] = shape::shapeOf(minShapeInfo)[k]; + shape::stride(newShapeInfo)[i] = shape::stride(minShapeInfo)[k++]; + ++j; + } else { + shape::shapeOf(newShapeInfo)[i] = 1; + shape::stride(newShapeInfo)[i] = 0; + if (shape::sizeAt(minShapeInfo, k) == 1 && + dimensions.size() != shape::rank(minShapeInfo)) + ++k; + } + } + } else { + for (int j = shape::rank(minShapeInfo) - 1, + i = shape::rank(maxShapeInfo) - 1; + i >= 0; --i) { + if (j >= 0) { + shape::shapeOf(newShapeInfo)[i] = shape::shapeOf(minShapeInfo)[j]; + shape::stride(newShapeInfo)[i] = shape::shapeOf(minShapeInfo)[j] == 1 + ? 0 + : shape::stride(minShapeInfo)[j]; + --j; + } else { + shape::shapeOf(newShapeInfo)[i] = 1; + shape::stride(newShapeInfo)[i] = 0; + } + } + } + + ShapeDescriptor descriptor(newShapeInfo); + + RELEASE(newShapeInfo, workspace); + + return bufferForShapeInfo(descriptor); +} - RELEASE(newShapeInfo, workspace); - return bufferForShapeInfo(descriptor); -} -} \ No newline at end of file +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/helpers/cuda/ConstantTadHelper.cu b/libnd4j/include/helpers/cuda/ConstantTadHelper.cu index 662c99e7ca5b..e33f8b8cf95c 100644 --- a/libnd4j/include/helpers/cuda/ConstantTadHelper.cu +++ b/libnd4j/include/helpers/cuda/ConstantTadHelper.cu @@ -18,96 +18,119 @@ // @author raver119@gmail.com // -#include "../ConstantTadHelper.h" -#include -#include -#include #include +#include #include +#include #include #include #include +#include +#include namespace sd { - ConstantTadHelper::ConstantTadHelper() { - auto numDevices = AffinityManager::numberOfDevices(); - - for (int e = 0; e < numDevices; e++) { - MAP_IMPL pack; - _cache.emplace_back(pack); - } - } - - ConstantTadHelper& ConstantTadHelper::getInstance() { - static ConstantTadHelper instance; - return instance; - } - - TadPack ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, int dimension, const bool keepUnitiesInShape) { - return tadForDimensions(originalShape, &dimension, 1, keepUnitiesInShape); - } - - TadPack ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, const std::vector &dimensions, const bool keepUnitiesInShape) { - return tadForDimensions(originalShape, const_cast(dimensions.data()), dimensions.size(), keepUnitiesInShape); - } - - TadPack ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, int* dimensions, int dimLength, const bool keepUnitiesInShape) { - TadDescriptor tadDescriptor(originalShape, dimensions, dimLength, keepUnitiesInShape); - return tadForDimensions(tadDescriptor); - } - - TadPack ConstantTadHelper::tadForDimensions(ShapeDescriptor &descriptor, std::vector &dimensions, const bool keepUnitiesInShape) { - TadDescriptor tadDescriptor(descriptor, dimensions, keepUnitiesInShape); - return tadForDimensions(tadDescriptor); - } - - TadPack ConstantTadHelper::tadForDimensions(TadDescriptor &descriptor) { - const int deviceId = AffinityManager::currentDeviceId(); - - std::lock_guard lock(_mutex); - - if (_cache[deviceId].count(descriptor) == 0) { - const auto shapeInfo = descriptor.originalShape().toShapeInfo(); - const int rank = shape::rank(shapeInfo); - const std::vector dimsToExclude = ShapeUtils::evalDimsToExclude(rank, descriptor.axis()); - const Nd4jLong numOfSubArrs = ShapeUtils::getNumOfSubArrs(shapeInfo, dimsToExclude); - const int subArrRank = (rank == dimsToExclude.size() || descriptor.areUnitiesinShape()) ? rank : rank - dimsToExclude.size(); - - auto sPtr = std::make_shared(new Nd4jLong[shape::shapeInfoLength(subArrRank)], std::make_shared()); - auto oPtr = std::make_shared(new Nd4jLong[numOfSubArrs], std::make_shared()); - - if (numOfSubArrs > 0) - shape::calcSubArrsShapeInfoAndOffsets(shapeInfo, numOfSubArrs, dimsToExclude.size(), dimsToExclude.data(), sPtr->pointerAsT(), oPtr->pointerAsT(), descriptor.areUnitiesinShape()); - - Nd4jPointer soPtr; - auto res = cudaMalloc(reinterpret_cast(&soPtr), numOfSubArrs * sizeof(Nd4jLong)); - if (res != 0) - throw cuda_exception::build("Memory allocation for tadOffsets failed", res); - - res = cudaMemcpy(soPtr, oPtr->pointer(), numOfSubArrs * sizeof(Nd4jLong), cudaMemcpyHostToDevice); - if (res != 0) - throw cuda_exception::build("tadOffsets copy failed", res); - - // TODO: add deallocator here? - auto ssPtr = std::make_shared(ConstantHelper::getInstance().replicatePointer(sPtr->pointer(), shape::shapeInfoByteLength(subArrRank))); - - - - ConstantShapeBuffer shapesBuffer(sPtr, ssPtr); - ConstantOffsetsBuffer offsetsBuffer(oPtr, std::make_shared(soPtr, std::make_shared())); - - TadPack t(shapesBuffer, offsetsBuffer, numOfSubArrs); - _cache[deviceId][descriptor] = t; - - TadPack r = _cache[deviceId][descriptor]; - - delete[] shapeInfo; - - return r; - } else { - TadPack r = _cache[deviceId][descriptor]; - - return r; - } - } -} \ No newline at end of file +ConstantTadHelper::ConstantTadHelper() { + auto numDevices = AffinityManager::numberOfDevices(); + + for (int e = 0; e < numDevices; e++) { + MAP_IMPL pack; + _cache.emplace_back(pack); + } +} + +ConstantTadHelper &ConstantTadHelper::getInstance() { + static ConstantTadHelper instance; + + return instance; +} + +TadPack ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, + int dimension, + const bool keepUnitiesInShape) { + return tadForDimensions(originalShape, &dimension, 1, keepUnitiesInShape); +} + +TadPack ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, + const std::vector &dimensions, + const bool keepUnitiesInShape) { + return tadForDimensions(originalShape, const_cast(dimensions.data()), + dimensions.size(), keepUnitiesInShape); +} + +TadPack ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, + int *dimensions, int dimLength, + const bool keepUnitiesInShape) { + TadDescriptor tadDescriptor(originalShape, dimensions, dimLength, + keepUnitiesInShape); + return tadForDimensions(tadDescriptor); +} + +TadPack ConstantTadHelper::tadForDimensions(ShapeDescriptor &descriptor, + std::vector &dimensions, + const bool keepUnitiesInShape) { + TadDescriptor tadDescriptor(descriptor, dimensions, keepUnitiesInShape); + return tadForDimensions(tadDescriptor); +} + +TadPack ConstantTadHelper::tadForDimensions(TadDescriptor &descriptor) { + const int deviceId = AffinityManager::currentDeviceId(); + + std::lock_guard lock(_mutex); + + if (_cache[deviceId].count(descriptor) == 0) { + const auto shapeInfo = descriptor.originalShape().toShapeInfo(); + const int rank = shape::rank(shapeInfo); + const std::vector dimsToExclude = + ShapeUtils::evalDimsToExclude(rank, descriptor.axis()); + const Nd4jLong numOfSubArrs = + ShapeUtils::getNumOfSubArrs(shapeInfo, dimsToExclude); + const int subArrRank = + (rank == dimsToExclude.size() || descriptor.areUnitiesinShape()) + ? rank + : rank - dimsToExclude.size(); + + auto sPtr = std::make_shared(new Nd4jLong[shape::shapeInfoLength(subArrRank)], std::make_shared()); + auto oPtr = std::make_shared(new Nd4jLong[numOfSubArrs], std::make_shared()); + + if (numOfSubArrs > 0) + shape::calcSubArrsShapeInfoAndOffsets( + shapeInfo, numOfSubArrs, dimsToExclude.size(), dimsToExclude.data(), + sPtr->pointerAsT(), oPtr->pointerAsT(), descriptor.areUnitiesinShape()); + + Nd4jPointer soPtr; + auto res = cudaMalloc(reinterpret_cast(&soPtr), + numOfSubArrs * sizeof(Nd4jLong)); + if (res != 0) + throw cuda_exception::build("Memory allocation for tadOffsets failed", + res); + + res = cudaMemcpy(soPtr, oPtr->pointer(), numOfSubArrs * sizeof(Nd4jLong), + cudaMemcpyHostToDevice); + if (res != 0) throw cuda_exception::build("tadOffsets copy failed", res); + + // TODO: add deallocator here? + auto ssPtr = std::make_shared(ConstantHelper::getInstance().replicatePointer( + sPtr->pointer(), shape::shapeInfoByteLength(subArrRank))); + + ConstantShapeBuffer shapesBuffer( + sPtr, ssPtr); + ConstantOffsetsBuffer offsetsBuffer( + oPtr, std::make_shared(soPtr, std::make_shared())); + + TadPack t(shapesBuffer, offsetsBuffer, numOfSubArrs); + _cache[deviceId][descriptor] = t; + + TadPack r = _cache[deviceId][descriptor]; + + delete[] shapeInfo; + + return r; + } else { + TadPack r = _cache[deviceId][descriptor]; + + return r; + } +} + + +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/helpers/cuda/PointersManager.cu b/libnd4j/include/helpers/cuda/PointersManager.cu index dc5fe15f50b1..990097c39cc1 100644 --- a/libnd4j/include/helpers/cuda/PointersManager.cu +++ b/libnd4j/include/helpers/cuda/PointersManager.cu @@ -19,8 +19,8 @@ // @author raver119@gmail.com // -#include #include +#include #include #include #include @@ -28,97 +28,121 @@ namespace sd { ////////////////////////////////////////////////////////////////////////// -PointersManager::PointersManager(const sd::LaunchContext* context, const std::string& funcName) { - _context = const_cast(context); - _funcName = funcName; +PointersManager::PointersManager(const sd::LaunchContext* context, + const std::string& funcName) { + _context = const_cast(context); + _funcName = funcName; } ////////////////////////////////////////////////////////////////////////// -void* PointersManager::replicatePointer(const void* src, const size_t numberOfBytes) { - - void* dst = nullptr; - if (_context->getWorkspace() == nullptr) { - cudaError_t cudaResult = cudaMalloc(reinterpret_cast(&dst), numberOfBytes); - if (cudaResult != 0) - throw cuda_exception::build(_funcName + ": cannot allocate global memory on device!", cudaResult); - } else { - dst = _context->getWorkspace()->allocateBytes(sd::memory::MemoryType::DEVICE, numberOfBytes); - } - - if (_context != nullptr) - cudaMemcpyAsync(dst, src, numberOfBytes, cudaMemcpyHostToDevice, *_context->getCudaStream()); - else - cudaMemcpy(dst, src, numberOfBytes, cudaMemcpyHostToDevice); - - _pOnGlobMem.emplace_back(dst); - - return dst; +void* PointersManager::replicatePointer(const void* src, + const size_t numberOfBytes) { + void* dst = nullptr; + if (_context->getWorkspace() == nullptr) { + cudaError_t cudaResult = + cudaMalloc(reinterpret_cast(&dst), numberOfBytes); + if (cudaResult != 0) + throw cuda_exception::build( + _funcName + ": cannot allocate global memory on device!", cudaResult); + } else { + dst = _context->getWorkspace()->allocateBytes( + sd::memory::MemoryType::DEVICE, numberOfBytes); + } + + if (_context != nullptr) + cudaMemcpyAsync(dst, src, numberOfBytes, cudaMemcpyHostToDevice, + *_context->getCudaStream()); + else + cudaMemcpy(dst, src, numberOfBytes, cudaMemcpyHostToDevice); + + _pOnGlobMem.emplace_back(dst); + + return dst; } ////////////////////////////////////////////////////////////////////////// void PointersManager::synchronize() const { - if (_context != nullptr) { - cudaError_t cudaResult = cudaStreamSynchronize(*_context->getCudaStream()); - if (cudaResult != 0) - throw cuda_exception::build(_funcName + ": cuda stream synchronization failed !", cudaResult); - } else { - nd4j_printf("<%s> syncStream isn't possible: no stream set!", _funcName.c_str()); - } + if (_context != nullptr) { + cudaError_t cudaResult = cudaStreamSynchronize(*_context->getCudaStream()); + if (cudaResult != 0) + throw cuda_exception::build( + _funcName + ": cuda stream synchronization failed !", cudaResult); + } else { + nd4j_printf("<%s> syncStream isn't possible: no stream set!", + _funcName.c_str()); + } } ////////////////////////////////////////////////////////////////////////// PointersManager::~PointersManager() { - - for (auto& p :_pOnGlobMem) - cudaFree(p); + for (auto& p : _pOnGlobMem) cudaFree(p); } - //////////////////////////////////////////////////////////////////////// template -static __global__ void printDevContentOnDev_(const void* pDev, const Nd4jLong len, const int tid) { - - PointersManager::printDevContentOnDev(pDev, len, tid); +static __global__ void printDevContentOnDev_(const void* pDev, + const Nd4jLong len, + const int tid) { + PointersManager::printDevContentOnDev(pDev, len, tid); } //////////////////////////////////////////////////////////////////////// -template -void PointersManager::printDevContentOnDevFromHost(const void* pDev, const Nd4jLong len, const int tid) { - printDevContentOnDev_<<<512, 512, 1024, *sd::LaunchContext ::defaultContext()->getCudaStream()>>>(pDev, len, tid); - auto res = cudaStreamSynchronize(*sd::LaunchContext ::defaultContext()->getCudaStream()); - if (res != 0) - throw std::runtime_error("PointersManager::printDevContentOnDevFromHost: cudaStreamSynchronize failed!"); +template +void PointersManager::printDevContentOnDevFromHost(const void* pDev, + const Nd4jLong len, + const int tid) { + printDevContentOnDev_ + <<<512, 512, 1024, + *sd::LaunchContext ::defaultContext()->getCudaStream()>>>(pDev, len, + tid); + auto res = cudaStreamSynchronize( + *sd::LaunchContext ::defaultContext()->getCudaStream()); + if (res != 0) + throw std::runtime_error( + "PointersManager::printDevContentOnDevFromHost: cudaStreamSynchronize " + "failed!"); } -template void PointersManager::printDevContentOnDevFromHost(const void* pDev, const Nd4jLong len, const int tid); -template void PointersManager::printDevContentOnDevFromHost(const void* pDev, const Nd4jLong len, const int tid); -template void PointersManager::printDevContentOnDevFromHost(const void* pDev, const Nd4jLong len, const int tid); -template void PointersManager::printDevContentOnDevFromHost(const void* pDev, const Nd4jLong len, const int tid); - -//BUILD_SINGLE_TEMPLATE(template void PointersManager::printDevContentOnDevFromHost, (void* pDev, Nd4jLong len, int tid), LIBND4J_TYPES); +template void PointersManager::printDevContentOnDevFromHost( + const void* pDev, const Nd4jLong len, const int tid); +template void PointersManager::printDevContentOnDevFromHost( + const void* pDev, const Nd4jLong len, const int tid); +template void PointersManager::printDevContentOnDevFromHost( + const void* pDev, const Nd4jLong len, const int tid); +template void PointersManager::printDevContentOnDevFromHost( + const void* pDev, const Nd4jLong len, const int tid); + +// BUILD_SINGLE_TEMPLATE(template void +// PointersManager::printDevContentOnDevFromHost, (void* pDev, Nd4jLong len, int +// tid), LIBND4J_TYPES); //////////////////////////////////////////////////////////////////////// -template -void PointersManager::printDevContentOnHost(const void* pDev, const Nd4jLong len) const { - printf("host print out\n"); - void* pHost = operator new(sizeof(T) * len); - - cudaMemcpyAsync(pHost, pDev, sizeof(T) * len, cudaMemcpyDeviceToHost, *_context->getCudaStream()); - cudaError_t cudaResult = cudaStreamSynchronize(*_context->getCudaStream()); - if(cudaResult != 0) - throw std::runtime_error("PointersManager::printCudaHost: cudaStreamSynchronize failed!"); - - for(Nd4jLong i = 0; i < len; ++i) - printf("%f, ", (double)reinterpret_cast(pHost)[i]); - printf("\n"); - - operator delete(pHost); +template +void PointersManager::printDevContentOnHost(const void* pDev, + const Nd4jLong len) const { + printf("host print out\n"); + void* pHost = operator new(sizeof(T) * len); + + cudaMemcpyAsync(pHost, pDev, sizeof(T) * len, cudaMemcpyDeviceToHost, + *_context->getCudaStream()); + cudaError_t cudaResult = cudaStreamSynchronize(*_context->getCudaStream()); + if (cudaResult != 0) + throw std::runtime_error( + "PointersManager::printCudaHost: cudaStreamSynchronize failed!"); + + for (Nd4jLong i = 0; i < len; ++i) + printf("%f, ", (double)reinterpret_cast(pHost)[i]); + printf("\n"); + + operator delete(pHost); } +template void PointersManager::printDevContentOnHost( + const void* pDev, const Nd4jLong len) const; +template void PointersManager::printDevContentOnHost( + const void* pDev, const Nd4jLong len) const; +template void PointersManager::printDevContentOnHost( + const void* pDev, const Nd4jLong len) const; +template void PointersManager::printDevContentOnHost( + const void* pDev, const Nd4jLong len) const; -template void PointersManager::printDevContentOnHost(const void* pDev, const Nd4jLong len) const; -template void PointersManager::printDevContentOnHost(const void* pDev, const Nd4jLong len) const; -template void PointersManager::printDevContentOnHost(const void* pDev, const Nd4jLong len) const; -template void PointersManager::printDevContentOnHost(const void* pDev, const Nd4jLong len) const; - - -} +} // namespace sd diff --git a/libnd4j/include/helpers/cuda_off/MmulHelper.cu b/libnd4j/include/helpers/cuda_off/MmulHelper.cu index d1122d794839..381837a37cd8 100644 --- a/libnd4j/include/helpers/cuda_off/MmulHelper.cu +++ b/libnd4j/include/helpers/cuda_off/MmulHelper.cu @@ -19,502 +19,597 @@ // @author raver119@gmail.com // @author Yurii Shyrma (iuriish@yahoo.com) // -#include #include -#include "../MmulHelper.h" -#include -#include +#include #include +#include +#include + #include +#include "../MmulHelper.h" + namespace sd { ////////////////////////////////////////////////////////////////////////////// // MXK x KxN = MxN -> actual sequence of axes doesn't matter template -static __global__ void usualCudaGemm(const void* vA, const Nd4jLong* aShapeInfo, const void* vB, const Nd4jLong* bShapeInfo, void* vC, const Nd4jLong* cShapeInfo, - const int aMaxis, const int aKaxis, const int bKaxis, const int bNaxis, const int cMaxis, const int cNaxis, +static __global__ void usualCudaGemm(const void* vA, const Nd4jLong* aShapeInfo, + const void* vB, const Nd4jLong* bShapeInfo, + void* vC, const Nd4jLong* cShapeInfo, + const int aMaxis, const int aKaxis, + const int bKaxis, const int bNaxis, + const int cMaxis, const int cNaxis, const double alpha, const double beta) { + const T1* A = reinterpret_cast(vA); + const T2* B = reinterpret_cast(vB); + T3* C = reinterpret_cast(vC); - const T1* A = reinterpret_cast(vA); - const T2* B = reinterpret_cast(vB); - T3* C = reinterpret_cast< T3*>(vC); - - __shared__ int K, *coords; - __shared__ bool betaPresent; - __shared__ Nd4jLong cLen, totalThreads; - __shared__ T3 alphaZ, betaZ; - - if (threadIdx.x == 0) { + __shared__ int K, *coords; + __shared__ bool betaPresent; + __shared__ Nd4jLong cLen, totalThreads; + __shared__ T3 alphaZ, betaZ; - extern __shared__ unsigned char shmem[]; - coords = reinterpret_cast(shmem); - cLen = shape::length(cShapeInfo); + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + coords = reinterpret_cast(shmem); + cLen = shape::length(cShapeInfo); - K = shape::shapeOf(const_cast(aShapeInfo))[aKaxis]; + K = shape::shapeOf(const_cast(aShapeInfo))[aKaxis]; - betaPresent = beta; + betaPresent = beta; - totalThreads = gridDim.x * blockDim.x; + totalThreads = gridDim.x * blockDim.x; - alphaZ = alpha; - betaZ = beta; - } - __syncthreads(); - - auto aCoords = coords + threadIdx.x * 6; // 6 = (aRank + bRank + cRank) - auto bCoords = aCoords + 2; - auto cCoords = bCoords + 2; + alphaZ = alpha; + betaZ = beta; + } + __syncthreads(); - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + auto aCoords = coords + threadIdx.x * 6; // 6 = (aRank + bRank + cRank) + auto bCoords = aCoords + 2; + auto cCoords = bCoords + 2; - for (Nd4jLong i = tid; i < cLen; i += totalThreads) { + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - // evaluate C coordinates - shape::index2coords(i, cShapeInfo, cCoords); + for (Nd4jLong i = tid; i < cLen; i += totalThreads) { + // evaluate C coordinates + shape::index2coords(i, cShapeInfo, cCoords); - // evaluate A coordinates - aCoords[aMaxis] = cCoords[cMaxis]; - aCoords[aKaxis] = 0; + // evaluate A coordinates + aCoords[aMaxis] = cCoords[cMaxis]; + aCoords[aKaxis] = 0; - // evaluate B coordinates - bCoords[bKaxis] = 0; - bCoords[bNaxis] = cCoords[cNaxis]; + // evaluate B coordinates + bCoords[bKaxis] = 0; + bCoords[bNaxis] = cCoords[cNaxis]; - auto aOffset = shape::getOffset(aShapeInfo, aCoords); - auto bOffset = shape::getOffset(bShapeInfo, bCoords); + auto aOffset = shape::getOffset(aShapeInfo, aCoords); + auto bOffset = shape::getOffset(bShapeInfo, bCoords); - T3 val = A[aOffset] * B[bOffset]; // first iteration + T3 val = A[aOffset] * B[bOffset]; // first iteration - for (uint j = 1; j < K; ++j) { // rest iterations - aOffset += shape::stride(aShapeInfo)[aKaxis]; - bOffset += shape::stride(bShapeInfo)[bKaxis]; - val = val + A[aOffset] * B[bOffset]; - } + for (uint j = 1; j < K; ++j) { // rest iterations + aOffset += shape::stride(aShapeInfo)[aKaxis]; + bOffset += shape::stride(bShapeInfo)[bKaxis]; + val = val + A[aOffset] * B[bOffset]; + } - auto cOffset = shape::getOffset(cShapeInfo, cCoords); + auto cOffset = shape::getOffset(cShapeInfo, cCoords); - if(betaPresent) - C[cOffset] = alphaZ * val + betaZ * C[cOffset]; - else - C[cOffset] = alphaZ * val; - } + if (betaPresent) + C[cOffset] = alphaZ * val + betaZ * C[cOffset]; + else + C[cOffset] = alphaZ * val; + } } //////////////////////////////////////////////////////////////////////// template -__host__ static void usualGemm(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, cudaStream_t *stream, const void* vA, const Nd4jLong* aShapeInfo, const void* vB, const Nd4jLong* bShapeInfo, void* vC, const Nd4jLong* cShapeInfo, const int aMaxis, const int aKaxis, const int bKaxis, const int bNaxis, const int cMaxis, const int cNaxis, const double alpha, const double beta) { - - usualCudaGemm<<>>(vA, aShapeInfo, vB, bShapeInfo, vC, cShapeInfo, aMaxis, aKaxis, bKaxis, bNaxis, cMaxis, cNaxis, alpha, beta); +__host__ static void usualGemm( + const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, + cudaStream_t* stream, const void* vA, const Nd4jLong* aShapeInfo, + const void* vB, const Nd4jLong* bShapeInfo, void* vC, + const Nd4jLong* cShapeInfo, const int aMaxis, const int aKaxis, + const int bKaxis, const int bNaxis, const int cMaxis, const int cNaxis, + const double alpha, const double beta) { + usualCudaGemm + <<>>( + vA, aShapeInfo, vB, bShapeInfo, vC, cShapeInfo, aMaxis, aKaxis, + bKaxis, bNaxis, cMaxis, cNaxis, alpha, beta); } //////////////////////////////////////////////////////////////////////// // MXN x N = M -> actual sequence of {M,N} axes doesn't matter template -static __global__ void usualCudaGemv(const void* vA, const Nd4jLong* aShapeInfo, const void* vX, const Nd4jLong* xShapeInfo, void* vY, const Nd4jLong* yShapeInfo, - const int incx, const int incy, const int aMaxis, const double alpha, const double beta) { - - const T1* A = reinterpret_cast(vA); - const T2* X = reinterpret_cast(vX); - T3* Y = reinterpret_cast< T3*>(vY); - - __shared__ int M, N; - __shared__ bool betaPresent; - __shared__ Nd4jLong cLen, totalThreads, aNstride, aMstride; - __shared__ T3 alphaZ, betaZ; - - if (threadIdx.x == 0) { - - N = shape::length(xShapeInfo); - M = shape::length(yShapeInfo); - - aMstride = shape::stride(aShapeInfo)[aMaxis]; - aNstride = shape::stride(aShapeInfo)[aMaxis == 0 ? 1 : 0]; - - totalThreads = gridDim.x * blockDim.x; - - betaPresent = beta; - - alphaZ = alpha; - betaZ = beta; +static __global__ void usualCudaGemv(const void* vA, const Nd4jLong* aShapeInfo, + const void* vX, const Nd4jLong* xShapeInfo, + void* vY, const Nd4jLong* yShapeInfo, + const int incx, const int incy, + const int aMaxis, const double alpha, + const double beta) { + const T1* A = reinterpret_cast(vA); + const T2* X = reinterpret_cast(vX); + T3* Y = reinterpret_cast(vY); + + __shared__ int M, N; + __shared__ bool betaPresent; + __shared__ Nd4jLong cLen, totalThreads, aNstride, aMstride; + __shared__ T3 alphaZ, betaZ; + + if (threadIdx.x == 0) { + N = shape::length(xShapeInfo); + M = shape::length(yShapeInfo); + + aMstride = shape::stride(aShapeInfo)[aMaxis]; + aNstride = shape::stride(aShapeInfo)[aMaxis == 0 ? 1 : 0]; + + totalThreads = gridDim.x * blockDim.x; + + betaPresent = beta; + + alphaZ = alpha; + betaZ = beta; + } + __syncthreads(); + + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + + for (Nd4jLong i = tid; i < M; i += totalThreads) { + // evaluate offsets + auto aOffset = i * aMstride; + auto xOffset = 0; + + T3 val = A[aOffset] * X[xOffset]; // first iteration + + for (uint j = 1; j < N; ++j) { // rest iterations + aOffset += aNstride; + xOffset += incx; + val = val + A[aOffset] * X[xOffset]; } - __syncthreads(); - - - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - - for (Nd4jLong i = tid; i < M; i += totalThreads) { - // evaluate offsets - auto aOffset = i * aMstride; - auto xOffset = 0; + auto yOffset = i * incy; - T3 val = A[aOffset] * X[xOffset]; // first iteration - - for (uint j = 1; j < N; ++j) { // rest iterations - aOffset += aNstride; - xOffset += incx; - val = val + A[aOffset] * X[xOffset]; - } - - auto yOffset = i * incy; - - if(betaPresent) - Y[yOffset] = alphaZ * val + betaZ * Y[yOffset]; - else - Y[yOffset] = alphaZ * val; - } + if (betaPresent) + Y[yOffset] = alphaZ * val + betaZ * Y[yOffset]; + else + Y[yOffset] = alphaZ * val; + } } //////////////////////////////////////////////////////////////////////// template -__host__ static void usualGemv(const int blocksPerGrid, const int threadsPerBlock, cudaStream_t *stream, const void* vA, const Nd4jLong* aShapeInfo, const void* vX, const Nd4jLong* xShapeInfo, void* vY, const Nd4jLong* yShapeInfo, const int incx, const int incy, const int aMaxis, const double alpha, const double beta) { - - usualCudaGemv<<>>(vA, aShapeInfo, vX, xShapeInfo, vY, yShapeInfo, incx, incy, aMaxis, alpha, beta); +__host__ static void usualGemv(const int blocksPerGrid, + const int threadsPerBlock, cudaStream_t* stream, + const void* vA, const Nd4jLong* aShapeInfo, + const void* vX, const Nd4jLong* xShapeInfo, + void* vY, const Nd4jLong* yShapeInfo, + const int incx, const int incy, const int aMaxis, + const double alpha, const double beta) { + usualCudaGemv<<>>( + vA, aShapeInfo, vX, xShapeInfo, vY, yShapeInfo, incx, incy, aMaxis, alpha, + beta); } - ////////////////////////////////////////////////////////////////////////////// template -static __global__ void usualCudaDot(const Nd4jLong length, const double alpha, const void* vX, const Nd4jLong incx, const void* vY, const Nd4jLong incy, const double beta, void* vZ) { - - T1* X = reinterpret_cast(const_cast(vX)); - T2* Y = reinterpret_cast(const_cast(vY)); - T3* Z = reinterpret_cast(vZ); +static __global__ void usualCudaDot(const Nd4jLong length, const double alpha, + const void* vX, const Nd4jLong incx, + const void* vY, const Nd4jLong incy, + const double beta, void* vZ) { + T1* X = reinterpret_cast(const_cast(vX)); + T2* Y = reinterpret_cast(const_cast(vY)); + T3* Z = reinterpret_cast(vZ); - extern __shared__ unsigned char shmem[]; - auto pairwiseMul = reinterpret_cast(shmem); + extern __shared__ unsigned char shmem[]; + auto pairwiseMul = reinterpret_cast(shmem); - const int tid = blockIdx.x * blockDim.x + threadIdx.x; - if(tid < length) - pairwiseMul[tid] = X[tid * incx] * Y[tid * incy]; + const int tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid < length) pairwiseMul[tid] = X[tid * incx] * Y[tid * incy]; - __syncthreads(); + __syncthreads(); - if(tid == 0) { - T3 sum = 0; - for(Nd4jLong i = 0; i < length; ++i) - sum = sum + pairwiseMul[i]; + if (tid == 0) { + T3 sum = 0; + for (Nd4jLong i = 0; i < length; ++i) sum = sum + pairwiseMul[i]; - if(beta) - *Z = (T3)alpha * sum + (T3)beta * *Z; - else - *Z = (T3)alpha * sum; - } + if (beta) + *Z = (T3)alpha * sum + (T3)beta * *Z; + else + *Z = (T3)alpha * sum; + } } //////////////////////////////////////////////////////////////////////// template -__host__ static void usualDot(const dim3 &blocksPerGrid, const dim3 &threadsPerBlock, cudaStream_t *stream, const Nd4jLong length, const double alpha, const void* vX, const Nd4jLong incx, const void* vY, const Nd4jLong incy, const double beta, void* vZ) { - - usualCudaDot<<>>(length, alpha, vX, incx, vY, incy, beta, vZ); +__host__ static void usualDot(const dim3& blocksPerGrid, + const dim3& threadsPerBlock, cudaStream_t* stream, + const Nd4jLong length, const double alpha, + const void* vX, const Nd4jLong incx, + const void* vY, const Nd4jLong incy, + const double beta, void* vZ) { + usualCudaDot + <<>>( + length, alpha, vX, incx, vY, incy, beta, vZ); } ////////////////////////////////////////////////////////////////////////////// // MXK x KxN = MxN -NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, double alpha, double beta, const char outOrder) { - - if(A->rankOf() != 2) - throw std::runtime_error("MmulHelper::mmulMxM cuda: rank of A array is not equal 2 !"); - if(B->rankOf() != 2) - throw std::runtime_error("MmulHelper::mmulMxM cuda: rank of B array is not equal 2 !"); - - const auto M = A->sizeAt(0); - const auto K = A->sizeAt(1); - const auto N = B->sizeAt(1); - - if(C != nullptr && C->rankOf() != 2) - throw std::runtime_error("MmulHelper::mmulMxM cuda: rank of C array is not equal 2 !"); - if(B->sizeAt(0) != K) - throw std::runtime_error("MmulHelper::mmulMxM cuda: B array has wrong number of rows !"); - if(C != nullptr && C->sizeAt(0) != M) - throw std::runtime_error("MmulHelper::mmulMxM cuda: C array has wrong number of rows !"); - if(C != nullptr && C->sizeAt(1) != N) - throw std::runtime_error("MmulHelper::mmulMxM cuda: C array has wrong number of columns !"); - - if(C == nullptr) - C = new NDArray(outOrder, {M,N}, DataTypeUtils::pickPairwiseResultType(A->dataType(), B->dataType()), A->getContext()); - - if (C->isEmpty()) - return C; - - const int major = Environment::getInstance().capabilities()[AffinityManager::currentDeviceId()].first(); - - const auto aType = A->dataType(); - const auto bType = B->dataType(); - const auto cType = C->dataType(); - - const bool AB(aType == bType), AC(aType == cType), ABC(AB && AC); - - const bool typeDouble = ABC && aType == DataType::DOUBLE; - const bool typeFloat = ABC && aType == DataType::FLOAT32; - const bool typeHalf = ABC && aType == DataType::HALF && major >= 6; - const bool typeIntFloat = AB && aType == DataType::INT8 && cType == DataType::FLOAT32 && major >= 6; - const bool typeHalfFloat = AB && aType == DataType::HALF && cType == DataType::FLOAT32 && major >= 6; - - std::lock_guard lock(*LaunchContext::deviceMutex()); +NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, + double alpha, double beta, const char outOrder) { + if (A->rankOf() != 2) + throw std::runtime_error( + "MmulHelper::mmulMxM cuda: rank of A array is not equal 2 !"); + if (B->rankOf() != 2) + throw std::runtime_error( + "MmulHelper::mmulMxM cuda: rank of B array is not equal 2 !"); + + const auto M = A->sizeAt(0); + const auto K = A->sizeAt(1); + const auto N = B->sizeAt(1); + + if (C != nullptr && C->rankOf() != 2) + throw std::runtime_error( + "MmulHelper::mmulMxM cuda: rank of C array is not equal 2 !"); + if (B->sizeAt(0) != K) + throw std::runtime_error( + "MmulHelper::mmulMxM cuda: B array has wrong number of rows !"); + if (C != nullptr && C->sizeAt(0) != M) + throw std::runtime_error( + "MmulHelper::mmulMxM cuda: C array has wrong number of rows !"); + if (C != nullptr && C->sizeAt(1) != N) + throw std::runtime_error( + "MmulHelper::mmulMxM cuda: C array has wrong number of columns !"); + + if (C == nullptr) + C = new NDArray( + outOrder, {M, N}, + DataTypeUtils::pickPairwiseResultType(A->dataType(), B->dataType()), + A->getContext()); + + if (C->isEmpty()) return C; + + const int major = Environment::getInstance() + .capabilities()[AffinityManager::currentDeviceId()] + .first(); + + const auto aType = A->dataType(); + const auto bType = B->dataType(); + const auto cType = C->dataType(); + + const bool AB(aType == bType), AC(aType == cType), ABC(AB && AC); + + const bool typeDouble = ABC && aType == DataType::DOUBLE; + const bool typeFloat = ABC && aType == DataType::FLOAT32; + const bool typeHalf = ABC && aType == DataType::HALF && major >= 6; + const bool typeIntFloat = + AB && aType == DataType::INT8 && cType == DataType::FLOAT32 && major >= 6; + const bool typeHalfFloat = + AB && aType == DataType::HALF && cType == DataType::FLOAT32 && major >= 6; + + std::lock_guard lock(*LaunchContext::deviceMutex()); + + auto handle = + reinterpret_cast(A->getContext()->getCublasHandle()); + auto stream = A->getContext()->getCudaStream(); + + auto status = cublasSetStream_v2(*handle, *stream); + if (status != CUBLAS_STATUS_SUCCESS) + throw cuda_exception::build("MmulHelper::mmulMxM cuda failed !", status); + + if (!typeDouble && !typeFloat && !typeHalf && !typeIntFloat && + !typeHalfFloat) { + const int threadsPerBlock = MAX_NUM_THREADS / 2; + const int blocksPerGrid = + (C->lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + const int sharedMem = + threadsPerBlock * sizeof(int) * 6 + 128; // 6 = aRank + bRank + cRank - auto handle = reinterpret_cast(A->getContext()->getCublasHandle()); - auto stream = A->getContext()->getCudaStream(); - - auto status = cublasSetStream_v2(*handle, *stream); - if (status != CUBLAS_STATUS_SUCCESS) - throw cuda_exception::build("MmulHelper::mmulMxM cuda failed !", status); + NDArray::prepareSpecialUse({C}, {A, B}); + // BUILD_TRIPLE_SELECTOR(aType, bType, cType, usualGemm, (blocksPerGrid, + // threadsPerBlock, sharedMem, stream, A->specialBuffer(), + // A->specialShapeInfo(), B->specialBuffer(), B->specialShapeInfo(), + // C->specialBuffer(), C->special(), 0, 1, 0, 1, 0, 1, alpha, + // beta), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES); + BUILD_SINGLE_SELECTOR_THRICE( + aType, usualGemm, + (blocksPerGrid, threadsPerBlock, sharedMem, stream, A->specialBuffer(), + A->specialShapeInfo(), B->specialBuffer(), B->specialShapeInfo(), + C->specialBuffer(), C->specialShapeInfo(), 0, 1, 0, 1, 0, 1, alpha, + beta), + NUMERIC_TYPES) + NDArray::registerSpecialUse({C}, {A, B}); - if(!typeDouble && !typeFloat && !typeHalf && !typeIntFloat && !typeHalfFloat) { + auto cudaResult = cudaStreamSynchronize(*stream); + if (cudaResult != 0) + throw cuda_exception::build("MmulHelper::mmulMxM cuda failed !", + cudaResult); + } else { + std::vector toDelete; - const int threadsPerBlock = MAX_NUM_THREADS / 2; - const int blocksPerGrid = (C->lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - const int sharedMem = threadsPerBlock * sizeof(int) * 6 + 128; // 6 = aRank + bRank + cRank + NDArray *pA(const_cast(A)), *pB(const_cast(B)), + *pC(const_cast(C)); - NDArray::prepareSpecialUse({C}, {A, B}); - // BUILD_TRIPLE_SELECTOR(aType, bType, cType, usualGemm, (blocksPerGrid, threadsPerBlock, sharedMem, stream, A->specialBuffer(), A->specialShapeInfo(), B->specialBuffer(), B->specialShapeInfo(), C->specialBuffer(), C->special(), 0, 1, 0, 1, 0, 1, alpha, beta), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES); - BUILD_SINGLE_SELECTOR_THRICE(aType, usualGemm, (blocksPerGrid, threadsPerBlock, sharedMem, stream, A->specialBuffer(), A->specialShapeInfo(), B->specialBuffer(), B->specialShapeInfo(), C->specialBuffer(), C->specialShapeInfo(), 0, 1, 0, 1, 0, 1, alpha, beta), NUMERIC_TYPES) - NDArray::registerSpecialUse({C}, {A, B}); + bool aMcont = M == 1 || A->strideAt(0) == 1; + bool aKcont = K == 1 || A->strideAt(1) == 1; + bool bKcont = K == 1 || B->strideAt(0) == 1; + bool bNcont = N == 1 || B->strideAt(1) == 1; + bool cMcont = M == 1 || C->strideAt(0) == 1; + bool cNcont = N == 1 || C->strideAt(1) == 1; - auto cudaResult = cudaStreamSynchronize(*stream); - if (cudaResult != 0) - throw cuda_exception::build("MmulHelper::mmulMxM cuda failed !", cudaResult); + if (!aMcont && !aKcont) { + pA = new NDArray(A->dup('f')); + toDelete.push_back(pA); + aMcont = true; + } + if (!bKcont && !bNcont) { + pB = new NDArray(B->dup('f')); + toDelete.push_back(pB); + bKcont = true; + } + if (!cMcont) { + pC = new NDArray(C->dup('f')); + toDelete.push_back(pC); + cMcont = true; } - else { - - std::vector toDelete; - - NDArray *pA(const_cast(A)), *pB(const_cast(B)), *pC(const_cast(C)); - - bool aMcont = M == 1 || A->strideAt(0) == 1; - bool aKcont = K == 1 || A->strideAt(1) == 1; - bool bKcont = K == 1 || B->strideAt(0) == 1; - bool bNcont = N == 1 || B->strideAt(1) == 1; - bool cMcont = M == 1 || C->strideAt(0) == 1; - bool cNcont = N == 1 || C->strideAt(1) == 1; - if(!aMcont && !aKcont) { - pA = new NDArray(A->dup('f')); - toDelete.push_back(pA); - aMcont = true; - } - if(!bKcont && !bNcont) { - pB = new NDArray(B->dup('f')); - toDelete.push_back(pB); - bKcont = true; - } - if(!cMcont) { - pC = new NDArray(C->dup('f')); - toDelete.push_back(pC); - cMcont = true; - } + const bool transA = !aMcont; + const bool transB = !bKcont; - const bool transA = !aMcont; - const bool transB = !bKcont; + const int lda = + (aMcont && aKcont) ? M : transA ? pA->strideAt(0) : pA->strideAt(1); + const int ldb = + (bKcont && bNcont) ? K : transB ? pB->strideAt(0) : pB->strideAt(1); + const int ldc = (cMcont && cNcont) ? M : pC->strideAt(1); - const int lda = (aMcont && aKcont) ? M : transA ? pA->strideAt(0) : pA->strideAt(1); - const int ldb = (bKcont && bNcont) ? K : transB ? pB->strideAt(0) : pB->strideAt(1); - const int ldc = (cMcont && cNcont) ? M : pC->strideAt(1); + const cublasOperation_t transAblas = transA ? CUBLAS_OP_T : CUBLAS_OP_N; + const cublasOperation_t transBblas = transB ? CUBLAS_OP_T : CUBLAS_OP_N; - const cublasOperation_t transAblas = transA ? CUBLAS_OP_T : CUBLAS_OP_N; - const cublasOperation_t transBblas = transB ? CUBLAS_OP_T : CUBLAS_OP_N; - - NDArray::prepareSpecialUse({pC}, {pA, pB}); + NDArray::prepareSpecialUse({pC}, {pA, pB}); - // choose appropriate cuda gemm api depending on data types - if(typeDouble) { - status = cublasDgemm(*handle, transAblas, transBblas, M, N, K, &alpha, (double*)pA->specialBuffer(), lda, (double*)pB->specialBuffer(), ldb, &beta, (double*)pC->specialBuffer(), ldc); - } - else if(typeFloat) { - float alphaF(alpha), betaF(beta); - status = cublasSgemm(*handle, transAblas, transBblas, M, N, K, &alphaF, (float*)pA->specialBuffer(), lda, (float*)pB->specialBuffer(), ldb, &betaF, (float*)pC->specialBuffer(), ldc); - } - else if(typeHalf) { - float16 alphaH(alpha), betaH(beta); - status = cublasHgemm(*handle, transAblas, transBblas, M, N, K, &alphaH.data, (__half*)pA->specialBuffer(), lda, (__half*)pB->specialBuffer(), ldb, &betaH.data, (__half*)pC->specialBuffer(), ldc); - } - else if(typeIntFloat) { - float alphaF(alpha), betaF(beta); - status = cublasSgemmEx(*handle, transAblas, transBblas, M, N, K, &alphaF, pA->specialBuffer(), CUDA_R_8I, lda, pB->specialBuffer(), CUDA_R_8I, ldb, &betaF, pC->specialBuffer(), CUDA_R_32F, ldc); - } - else if(typeHalfFloat) { - float alphaF(alpha), betaF(beta); - status = cublasSgemmEx(*handle, transAblas, transBblas, M, N, K, &alphaF, pA->specialBuffer(), CUDA_R_16F, lda, pB->specialBuffer(), CUDA_R_16F, ldb, &betaF, pC->specialBuffer(), CUDA_R_32F, ldc); - } + // choose appropriate cuda gemm api depending on data types + if (typeDouble) { + status = cublasDgemm(*handle, transAblas, transBblas, M, N, K, &alpha, + (double*)pA->specialBuffer(), lda, + (double*)pB->specialBuffer(), ldb, &beta, + (double*)pC->specialBuffer(), ldc); + } else if (typeFloat) { + float alphaF(alpha), betaF(beta); + status = cublasSgemm(*handle, transAblas, transBblas, M, N, K, &alphaF, + (float*)pA->specialBuffer(), lda, + (float*)pB->specialBuffer(), ldb, &betaF, + (float*)pC->specialBuffer(), ldc); + } else if (typeHalf) { + float16 alphaH(alpha), betaH(beta); + status = cublasHgemm(*handle, transAblas, transBblas, M, N, K, + &alphaH.data, (__half*)pA->specialBuffer(), lda, + (__half*)pB->specialBuffer(), ldb, &betaH.data, + (__half*)pC->specialBuffer(), ldc); + } else if (typeIntFloat) { + float alphaF(alpha), betaF(beta); + status = cublasSgemmEx(*handle, transAblas, transBblas, M, N, K, &alphaF, + pA->specialBuffer(), CUDA_R_8I, lda, + pB->specialBuffer(), CUDA_R_8I, ldb, &betaF, + pC->specialBuffer(), CUDA_R_32F, ldc); + } else if (typeHalfFloat) { + float alphaF(alpha), betaF(beta); + status = cublasSgemmEx(*handle, transAblas, transBblas, M, N, K, &alphaF, + pA->specialBuffer(), CUDA_R_16F, lda, + pB->specialBuffer(), CUDA_R_16F, ldb, &betaF, + pC->specialBuffer(), CUDA_R_32F, ldc); + } - if (status != CUBLAS_STATUS_SUCCESS) - throw cuda_exception::build("MmulHelper::mmulMxM cuda failed !", status); + if (status != CUBLAS_STATUS_SUCCESS) + throw cuda_exception::build("MmulHelper::mmulMxM cuda failed !", status); - NDArray::registerSpecialUse({pC}, {pA, pB}); + NDArray::registerSpecialUse({pC}, {pA, pB}); - auto cudaResult = cudaStreamSynchronize(*stream); - if (cudaResult != 0) - throw cuda_exception::build("MmulHelper::mmulMxM cuda failed !", cudaResult); + auto cudaResult = cudaStreamSynchronize(*stream); + if (cudaResult != 0) + throw cuda_exception::build("MmulHelper::mmulMxM cuda failed !", + cudaResult); - if(C != pC) - C->assign(pC); + if (C != pC) C->assign(pC); - for(int i = toDelete.size() - 1; i >= 0; --i) - delete toDelete[i]; - } + for (int i = toDelete.size() - 1; i >= 0; --i) delete toDelete[i]; + } - return C; + return C; } //////////////////////////////////////////////////////////////////////////// // MXN x N = M -NDArray* MmulHelper::mmulMxV(const NDArray* A, const NDArray* X, sd::NDArray* Y, const double alpha, const double beta, const char outOrder) { +NDArray* MmulHelper::mmulMxV(const NDArray* A, const NDArray* X, sd::NDArray* Y, + const double alpha, const double beta, + const char outOrder) { + int xLenDim, yLenDim(0); + + if (A->rankOf() != 2) + throw std::runtime_error( + "MmulHelper::mmulMxV cuda: rank of A array is not equal 2 !"); + if (!shape::isCommonVector(X->shapeInfo(), xLenDim)) + throw std::runtime_error( + "MmulHelper::mmulMxV cuda: X array must be vector !"); + + const auto M = A->sizeAt(0); + const auto N = A->sizeAt(1); + + if (Y != nullptr && !shape::isCommonVector(Y->shapeInfo(), yLenDim)) + throw std::runtime_error( + "MmulHelper::mmulMxV cuda: Y array must be vector !"); + if (X->lengthOf() != N) + throw std::runtime_error( + "MmulHelper::mmulMxV cuda: X vector has wrong length !"); + if (Y != nullptr && Y->lengthOf() != M) + throw std::runtime_error( + "MmulHelper::mmulMxV cuda: Y array has wrong length !"); + + if (Y == nullptr) + Y = new NDArray( + outOrder, {M}, + DataTypeUtils::pickPairwiseResultType(A->dataType(), X->dataType()), + A->getContext()); + + if (Y->isEmpty()) return Y; + + const int incx = X->strideAt(xLenDim); + const int incy = Y->strideAt(yLenDim); + + const auto aType = A->dataType(); + const auto xType = X->dataType(); + const auto yType = Y->dataType(); + + const bool AX(aType == xType), AY(aType == yType), AXY(AX && AY); + + const bool typeDouble = AXY && aType == DataType::DOUBLE; + const bool typeFloat = AXY && aType == DataType::FLOAT32; + + std::lock_guard lock(*LaunchContext::deviceMutex()); + + auto handle = + reinterpret_cast(A->getContext()->getCublasHandle()); + auto stream = A->getContext()->getCudaStream(); + + auto status = cublasSetStream_v2(*handle, *stream); + if (status != CUBLAS_STATUS_SUCCESS) + throw cuda_exception::build("MmulHelper::mmulMxV cuda failed !", status); + + if (!typeDouble && !typeFloat) { + const int threadsPerBlock = MAX_NUM_THREADS; + const int blocksPerGrid = (M + threadsPerBlock - 1) / threadsPerBlock; + + NDArray::prepareSpecialUse({Y}, {A, X}); + // BUILD_TRIPLE_SELECTOR(aType, xType, yType, usualGemv, (blocksPerGrid, + // threadsPerBlock, stream, A->specialBuffer(), A->specialShapeInfo(), + // X->specialBuffer(), X->specialShapeInfo(), Y->specialBuffer(), + // Y->special(), incx, incy, 0, alpha, beta), NUMERIC_TYPES, + // NUMERIC_TYPES, FLOAT_TYPES); + BUILD_SINGLE_SELECTOR_THRICE( + xType, usualGemv, + (blocksPerGrid, threadsPerBlock, stream, A->specialBuffer(), + A->specialShapeInfo(), X->specialBuffer(), X->specialShapeInfo(), + Y->specialBuffer(), Y->specialShapeInfo(), incx, incy, 0, alpha, beta), + NUMERIC_TYPES) + NDArray::registerSpecialUse({Y}, {A, X}); - int xLenDim, yLenDim(0); - - if(A->rankOf() != 2) - throw std::runtime_error("MmulHelper::mmulMxV cuda: rank of A array is not equal 2 !"); - if(!shape::isCommonVector(X->shapeInfo(), xLenDim)) - throw std::runtime_error("MmulHelper::mmulMxV cuda: X array must be vector !"); - - const auto M = A->sizeAt(0); - const auto N = A->sizeAt(1); - - if(Y != nullptr && !shape::isCommonVector(Y->shapeInfo(), yLenDim)) - throw std::runtime_error("MmulHelper::mmulMxV cuda: Y array must be vector !"); - if(X->lengthOf() != N) - throw std::runtime_error("MmulHelper::mmulMxV cuda: X vector has wrong length !"); - if(Y != nullptr && Y->lengthOf() != M) - throw std::runtime_error("MmulHelper::mmulMxV cuda: Y array has wrong length !"); - - if(Y == nullptr) - Y = new NDArray(outOrder, {M}, DataTypeUtils::pickPairwiseResultType(A->dataType(), X->dataType()), A->getContext()); - - if (Y->isEmpty()) - return Y; - - const int incx = X->strideAt(xLenDim); - const int incy = Y->strideAt(yLenDim); - - const auto aType = A->dataType(); - const auto xType = X->dataType(); - const auto yType = Y->dataType(); - - const bool AX(aType == xType), AY(aType == yType), AXY(AX && AY); - - const bool typeDouble = AXY && aType == DataType::DOUBLE; - const bool typeFloat = AXY && aType == DataType::FLOAT32; - - std::lock_guard lock(*LaunchContext::deviceMutex()); - - auto handle = reinterpret_cast(A->getContext()->getCublasHandle()); - auto stream = A->getContext()->getCudaStream(); - - auto status = cublasSetStream_v2(*handle, *stream); - if (status != CUBLAS_STATUS_SUCCESS) - throw cuda_exception::build("MmulHelper::mmulMxV cuda failed !", status); - - if(!typeDouble && !typeFloat) { + auto cudaResult = cudaStreamSynchronize(*stream); + if (cudaResult != 0) + throw cuda_exception::build("MmulHelper::mmulMxV cuda failed !", + cudaResult); - const int threadsPerBlock = MAX_NUM_THREADS; - const int blocksPerGrid = (M + threadsPerBlock - 1) / threadsPerBlock; + } else { + NDArray* pA(const_cast(A)); - NDArray::prepareSpecialUse({Y}, {A, X}); - // BUILD_TRIPLE_SELECTOR(aType, xType, yType, usualGemv, (blocksPerGrid, threadsPerBlock, stream, A->specialBuffer(), A->specialShapeInfo(), X->specialBuffer(), X->specialShapeInfo(), Y->specialBuffer(), Y->special(), incx, incy, 0, alpha, beta), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES); - BUILD_SINGLE_SELECTOR_THRICE(xType, usualGemv, (blocksPerGrid, threadsPerBlock, stream, A->specialBuffer(), A->specialShapeInfo(), X->specialBuffer(), X->specialShapeInfo(), Y->specialBuffer(), Y->specialShapeInfo(), incx, incy, 0, alpha, beta), NUMERIC_TYPES) - NDArray::registerSpecialUse({Y}, {A, X}); - - auto cudaResult = cudaStreamSynchronize(*stream); - if (cudaResult != 0) - throw cuda_exception::build("MmulHelper::mmulMxV cuda failed !", cudaResult); + bool aMcont = M == 1 || A->strideAt(0) == 1; + bool aNcont = N == 1 || A->strideAt(1) == 1; + if (!aMcont && !aNcont) { + pA = new NDArray(A->dup('f')); + aMcont = true; } - else { - NDArray *pA(const_cast(A)); + const bool transA = !aMcont; - bool aMcont = M == 1 || A->strideAt(0) == 1; - bool aNcont = N == 1 || A->strideAt(1) == 1; + const int lda = + (aMcont && aNcont) ? M : transA ? pA->strideAt(0) : pA->strideAt(1); - if(!aMcont && !aNcont) { - pA = new NDArray(A->dup('f')); - aMcont = true; - } + const cublasOperation_t transAblas = transA ? CUBLAS_OP_T : CUBLAS_OP_N; - const bool transA = !aMcont; - - const int lda = (aMcont && aNcont) ? M : transA ? pA->strideAt(0) : pA->strideAt(1); - - const cublasOperation_t transAblas = transA ? CUBLAS_OP_T : CUBLAS_OP_N; - - NDArray::prepareSpecialUse({Y}, {pA, X}); + NDArray::prepareSpecialUse({Y}, {pA, X}); - // choose appropriate cuda gemm api depending on data types - if(typeDouble) { - status = cublasDgemv(*handle, transAblas, transA ? N : M, transA ? M : N, &alpha, (double*)pA->specialBuffer(), lda, (double*)X->specialBuffer(), incx, &beta, (double*)Y->specialBuffer(), incy); - } - else if(typeFloat) { - float alphaF(alpha), betaF(beta); - status = cublasSgemv(*handle, transAblas, transA ? N : M, transA ? M : N, &alphaF, (float*)pA->specialBuffer(), lda, (float*)X->specialBuffer(), incx, &betaF, (float*)Y->specialBuffer(), incy); - } + // choose appropriate cuda gemm api depending on data types + if (typeDouble) { + status = cublasDgemv(*handle, transAblas, transA ? N : M, transA ? M : N, + &alpha, (double*)pA->specialBuffer(), lda, + (double*)X->specialBuffer(), incx, &beta, + (double*)Y->specialBuffer(), incy); + } else if (typeFloat) { + float alphaF(alpha), betaF(beta); + status = cublasSgemv(*handle, transAblas, transA ? N : M, transA ? M : N, + &alphaF, (float*)pA->specialBuffer(), lda, + (float*)X->specialBuffer(), incx, &betaF, + (float*)Y->specialBuffer(), incy); + } - if (status != CUBLAS_STATUS_SUCCESS) - throw cuda_exception::build("MmulHelper::mmulMxV cuda failed !", status); + if (status != CUBLAS_STATUS_SUCCESS) + throw cuda_exception::build("MmulHelper::mmulMxV cuda failed !", status); - auto cudaResult = cudaStreamSynchronize(*stream); - if (cudaResult != 0) - throw cuda_exception::build("MmulHelper::mmulMxV cuda failed !", cudaResult); + auto cudaResult = cudaStreamSynchronize(*stream); + if (cudaResult != 0) + throw cuda_exception::build("MmulHelper::mmulMxV cuda failed !", + cudaResult); - NDArray::registerSpecialUse({Y}, {pA, X}); + NDArray::registerSpecialUse({Y}, {pA, X}); - if(pA != A) - delete pA; - } + if (pA != A) delete pA; + } - return Y; + return Y; } //////////////////////////////////////////////////////////////////////////// // (X * Y) = Z[0] -NDArray* MmulHelper::dot(const NDArray* X, const NDArray* Y, sd::NDArray* Z, const double alpha, const double beta) { - - int xLenDim(0), yLenDim(0); - - if(!shape::isCommonVector(X->shapeInfo(), xLenDim)) - throw std::runtime_error("MmulHelper::dot cuda: X array must be vector !"); - if(!shape::isCommonVector(Y->shapeInfo(), yLenDim)) - throw std::runtime_error("MmulHelper::dot cuda: Y array must be vector !"); - if(Z != nullptr && !Z->isScalar()) - throw std::runtime_error("MmulHelper::dot cuda: Z array must be scalar !"); - - const auto length = X->lengthOf(); - - if(Y->lengthOf() != length) - throw std::runtime_error("MmulHelper::dot cuda: lengths of input vectors are different !"); - - if(Z == nullptr) - Z = new NDArray(DataTypeUtils::pickPairwiseResultType(X->dataType(), Y->dataType()), X->getContext()); - - const Nd4jLong incx = X->strideAt(xLenDim); - const Nd4jLong incy = Y->strideAt(yLenDim); - - const auto xType = X->dataType(); - const auto yType = Y->dataType(); - const auto zType = Z->dataType(); - - if(!X->isActualOnDeviceSide()) X->syncToDevice(); - if(!Y->isActualOnDeviceSide()) Y->syncToDevice(); - if(!Z->isActualOnDeviceSide()) Z->syncToDevice(); - - cudaStream_t* stream = X->getContext()->getCudaStream(); - - dim3 threadsPerBlock(512); - dim3 blocksPerGrid(1); - if (length > 512) - threadsPerBlock.x = math::nd4j_ceil(static_cast(length) / 512); - - NDArray::prepareSpecialUse({Z}, {X, Y}); - - //BUILD_TRIPLE_SELECTOR(xType, yType, zType, usualDot, (blocksPerGrid, threadsPerBlock, stream, length, alpha, X->specialBuffer(), incx, Y->specialBuffer(), incy, beta, Z->specialBuffer()), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES); - BUILD_SINGLE_SELECTOR_THRICE(xType, usualDot, (blocksPerGrid, threadsPerBlock, stream, length, alpha, X->specialBuffer(), incx, Y->specialBuffer(), incy, beta, Z->specialBuffer()), NUMERIC_TYPES) - - auto cudaResult = cudaStreamSynchronize(*stream); - if (cudaResult != 0) throw cuda_exception::build("MmulHelper::dot cuda failed !", cudaResult); - - NDArray::registerSpecialUse({Z}, {X, Y}); - - return Z; +NDArray* MmulHelper::dot(const NDArray* X, const NDArray* Y, sd::NDArray* Z, + const double alpha, const double beta) { + int xLenDim(0), yLenDim(0); + + if (!shape::isCommonVector(X->shapeInfo(), xLenDim)) + throw std::runtime_error("MmulHelper::dot cuda: X array must be vector !"); + if (!shape::isCommonVector(Y->shapeInfo(), yLenDim)) + throw std::runtime_error("MmulHelper::dot cuda: Y array must be vector !"); + if (Z != nullptr && !Z->isScalar()) + throw std::runtime_error("MmulHelper::dot cuda: Z array must be scalar !"); + + const auto length = X->lengthOf(); + + if (Y->lengthOf() != length) + throw std::runtime_error( + "MmulHelper::dot cuda: lengths of input vectors are different !"); + + if (Z == nullptr) + Z = new NDArray( + DataTypeUtils::pickPairwiseResultType(X->dataType(), Y->dataType()), + X->getContext()); + + const Nd4jLong incx = X->strideAt(xLenDim); + const Nd4jLong incy = Y->strideAt(yLenDim); + + const auto xType = X->dataType(); + const auto yType = Y->dataType(); + const auto zType = Z->dataType(); + + if (!X->isActualOnDeviceSide()) X->syncToDevice(); + if (!Y->isActualOnDeviceSide()) Y->syncToDevice(); + if (!Z->isActualOnDeviceSide()) Z->syncToDevice(); + + cudaStream_t* stream = X->getContext()->getCudaStream(); + + dim3 threadsPerBlock(512); + dim3 blocksPerGrid(1); + if (length > 512) + threadsPerBlock.x = + math::nd4j_ceil(static_cast(length) / 512); + + NDArray::prepareSpecialUse({Z}, {X, Y}); + + // BUILD_TRIPLE_SELECTOR(xType, yType, zType, usualDot, (blocksPerGrid, + // threadsPerBlock, stream, length, alpha, X->specialBuffer(), incx, + // Y->specialBuffer(), incy, beta, Z->specialBuffer()), NUMERIC_TYPES, + // NUMERIC_TYPES, FLOAT_TYPES); + BUILD_SINGLE_SELECTOR_THRICE( + xType, usualDot, + (blocksPerGrid, threadsPerBlock, stream, length, alpha, + X->specialBuffer(), incx, Y->specialBuffer(), incy, beta, + Z->specialBuffer()), + NUMERIC_TYPES) + + auto cudaResult = cudaStreamSynchronize(*stream); + if (cudaResult != 0) + throw cuda_exception::build("MmulHelper::dot cuda failed !", cudaResult); + + NDArray::registerSpecialUse({Z}, {X, Y}); + + return Z; } ////////////////////////////////////////////////////////////////////////////// @@ -523,165 +618,208 @@ NDArray* MmulHelper::dot(const NDArray* X, const NDArray* Y, sd::NDArray* Z, con // [M,K] x [bS,K,N] = [bS,M,N] // bS could stand for several axes template -static __global__ void batchedCudaGemm(const void* vA, const Nd4jLong* aShapeInfo, const void* vB, const Nd4jLong* bShapeInfo, void* vC, const Nd4jLong* cShapeInfo, - const int* aBatchDims, const int* bBatchDims, const int* cBatchDims, - const int aMaxis, const int aKaxis, const int bKaxis, const int bNaxis, const int cMaxis, const int cNaxis, - const double alpha, const double beta) { - - const T1* A = reinterpret_cast(vA); - const T2* B = reinterpret_cast(vB); - T3* C = reinterpret_cast< T3*>(vC); - - __shared__ bool betaPresent; - __shared__ int aRank, bRank, cRank, K, *coords; - __shared__ Nd4jLong cLen, totalThreads; - __shared__ T3 alphaZ, betaZ; - - if (threadIdx.x == 0) { - - extern __shared__ unsigned char shmem[]; - coords = reinterpret_cast(shmem); - cLen = shape::length(cShapeInfo); +static __global__ void batchedCudaGemm( + const void* vA, const Nd4jLong* aShapeInfo, const void* vB, + const Nd4jLong* bShapeInfo, void* vC, const Nd4jLong* cShapeInfo, + const int* aBatchDims, const int* bBatchDims, const int* cBatchDims, + const int aMaxis, const int aKaxis, const int bKaxis, const int bNaxis, + const int cMaxis, const int cNaxis, const double alpha, const double beta) { + const T1* A = reinterpret_cast(vA); + const T2* B = reinterpret_cast(vB); + T3* C = reinterpret_cast(vC); + + __shared__ bool betaPresent; + __shared__ int aRank, bRank, cRank, K, *coords; + __shared__ Nd4jLong cLen, totalThreads; + __shared__ T3 alphaZ, betaZ; + + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + coords = reinterpret_cast(shmem); + cLen = shape::length(cShapeInfo); - K = shape::shapeOf(const_cast(aShapeInfo))[aKaxis]; + K = shape::shapeOf(const_cast(aShapeInfo))[aKaxis]; - totalThreads = gridDim.x * blockDim.x; - aRank = shape::rank(aShapeInfo); - bRank = shape::rank(bShapeInfo); - cRank = shape::rank(cShapeInfo); + totalThreads = gridDim.x * blockDim.x; + aRank = shape::rank(aShapeInfo); + bRank = shape::rank(bShapeInfo); + cRank = shape::rank(cShapeInfo); - betaPresent = beta; + betaPresent = beta; - alphaZ = alpha; - betaZ = beta; - } - __syncthreads(); + alphaZ = alpha; + betaZ = beta; + } + __syncthreads(); - auto aCoords = coords + threadIdx.x * (aRank + bRank + cRank); - auto bCoords = aCoords + aRank; - auto cCoords = bCoords + bRank; + auto aCoords = coords + threadIdx.x * (aRank + bRank + cRank); + auto bCoords = aCoords + aRank; + auto cCoords = bCoords + bRank; - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - for (Nd4jLong i = tid; i < cLen; i += totalThreads) { + for (Nd4jLong i = tid; i < cLen; i += totalThreads) { + // evaluate C coordinates + shape::index2coords(i, cShapeInfo, cCoords); - // evaluate C coordinates - shape::index2coords(i, cShapeInfo, cCoords); + // calculate index of current batch + Nd4jLong batchInd; + if (cBatchDims != nullptr) + batchInd = + shape::coords2index(cShapeInfo, cCoords, cRank - 2, cBatchDims); - // calculate index of current batch - Nd4jLong batchInd; - if(cBatchDims != nullptr) - batchInd = shape::coords2index(cShapeInfo, cCoords, cRank - 2, cBatchDims); + // evaluate A coordinates + if (aBatchDims != nullptr) + shape::index2coords(batchInd, aShapeInfo, aCoords, aRank - 2, aBatchDims); + aCoords[aMaxis] = cCoords[cMaxis]; + aCoords[aKaxis] = 0; - // evaluate A coordinates - if(aBatchDims != nullptr) - shape::index2coords(batchInd, aShapeInfo, aCoords, aRank - 2, aBatchDims); - aCoords[aMaxis] = cCoords[cMaxis]; - aCoords[aKaxis] = 0; + // evaluate B coordinates + if (bBatchDims != nullptr) + shape::index2coords(batchInd, bShapeInfo, bCoords, bRank - 2, bBatchDims); + bCoords[bKaxis] = 0; + bCoords[bNaxis] = cCoords[cNaxis]; - // evaluate B coordinates - if(bBatchDims != nullptr) - shape::index2coords(batchInd, bShapeInfo, bCoords, bRank - 2, bBatchDims); - bCoords[bKaxis] = 0; - bCoords[bNaxis] = cCoords[cNaxis]; + auto aOffset = shape::getOffset(aShapeInfo, aCoords); + auto bOffset = shape::getOffset(bShapeInfo, bCoords); - auto aOffset = shape::getOffset(aShapeInfo, aCoords); - auto bOffset = shape::getOffset(bShapeInfo, bCoords); + T3 val = A[aOffset] * B[bOffset]; // first iteration - T3 val = A[aOffset] * B[bOffset]; // first iteration - - for (uint j = 1; j < K; ++j) { // rest iterations - aOffset += shape::stride(aShapeInfo)[aKaxis]; - bOffset += shape::stride(bShapeInfo)[bKaxis]; - val = val + A[aOffset] * B[bOffset]; - } + for (uint j = 1; j < K; ++j) { // rest iterations + aOffset += shape::stride(aShapeInfo)[aKaxis]; + bOffset += shape::stride(bShapeInfo)[bKaxis]; + val = val + A[aOffset] * B[bOffset]; + } - auto cOffset = shape::getOffset(cShapeInfo, cCoords); + auto cOffset = shape::getOffset(cShapeInfo, cCoords); - if(betaPresent) - C[cOffset] = alphaZ * val + betaZ * C[cOffset]; - else - C[cOffset] = alphaZ * val; - } + if (betaPresent) + C[cOffset] = alphaZ * val + betaZ * C[cOffset]; + else + C[cOffset] = alphaZ * val; + } } //////////////////////////////////////////////////////////////////////// template -__host__ static void batchedGemm(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, cudaStream_t *stream, const void* vA, const Nd4jLong* aShapeInfo, const void* vB, const Nd4jLong* bShapeInfo, void* vC, const Nd4jLong* cShapeInfo, const int* aBatchDims, const int* bBatchDims, const int* cBatchDims, const int aMaxis, const int aKaxis, const int bKaxis, const int bNaxis, const int cMaxis, const int cNaxis, const double alpha, const double beta) { - - batchedCudaGemm<<>>(vA, aShapeInfo, vB, bShapeInfo, vC, cShapeInfo, aBatchDims, bBatchDims, cBatchDims, aMaxis, aKaxis, bKaxis, bNaxis, cMaxis, cNaxis, alpha, beta); +__host__ static void batchedGemm( + const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, + cudaStream_t* stream, const void* vA, const Nd4jLong* aShapeInfo, + const void* vB, const Nd4jLong* bShapeInfo, void* vC, + const Nd4jLong* cShapeInfo, const int* aBatchDims, const int* bBatchDims, + const int* cBatchDims, const int aMaxis, const int aKaxis, const int bKaxis, + const int bNaxis, const int cMaxis, const int cNaxis, const double alpha, + const double beta) { + batchedCudaGemm + <<>>( + vA, aShapeInfo, vB, bShapeInfo, vC, cShapeInfo, aBatchDims, + bBatchDims, cBatchDims, aMaxis, aKaxis, bKaxis, bNaxis, cMaxis, + cNaxis, alpha, beta); } /////////////////////////////////////////////////////////////////// -NDArray* MmulHelper::mmulNxN(const NDArray* A, const NDArray* B, NDArray* C, const double alpha, const double beta, const char outOrder) { - - const int aRank = A->rankOf(); - const int bRank = B->rankOf(); - - // input ranks validation - if(aRank > bRank && bRank != 2) - throw std::runtime_error("MmulHelper::mmulNxN: rank of B array should be equal 2 !"); - else if(bRank > aRank && aRank != 2) - throw std::runtime_error("MmulHelper::mmulNxN: rank of A array should be equal 2 !"); - else if (aRank == bRank ) { - for(int i = 0; i < aRank - 2; ++i) - if(A->sizeAt(i) != B->sizeAt(i)) - throw std::runtime_error("MmulHelper::mmulNxN: shapes of A and B arrays are not suitable for matrix multiplication !"); - } - - if(A->sizeAt(-1) != B->sizeAt(-2)) - throw std::runtime_error("MmulHelper::mmulNxN: shapes of A and B arrays are not suitable for matrix multiplication !"); - - // validation of C array - std::vector cExpectedShape = aRank > bRank ? A->getShapeAsVector() : B->getShapeAsVector(); - cExpectedShape[cExpectedShape.size() - 2] = A->sizeAt(-2); - cExpectedShape[cExpectedShape.size() - 1] = B->sizeAt(-1); - - if(C != nullptr ) { - if(!C->isSameShape(cExpectedShape)) - throw std::runtime_error("MmulHelper::mmulNxN: shape of C array is not suitable for AxB matrix multiplication !"); - } - else - C = new NDArray(outOrder, cExpectedShape, DataTypeUtils::pickPairwiseResultType(A->dataType(), B->dataType()), A->getContext()); - - if (C->isEmpty()) - return C; - - const int cRank = C->rankOf(); - - const int aMaxis(aRank-2), aKaxis(aRank-1), bKaxis(bRank-2), bNaxis(bRank-1), cMaxis(cRank-2), cNaxis(cRank-1); - - const int threadsPerBlock = MAX_NUM_THREADS / 8; - const int blocksPerGrid = (C->lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - const int sharedMem = threadsPerBlock * sizeof(int) * (aRank + bRank + cRank) + 128; - - PointersManager manager(A->getContext(), "MmulHelper::mmulNxN"); - - const int *aBatchDims(nullptr), *bBatchDims(nullptr), *cBatchDims(nullptr); - - if(aRank > 2) - aBatchDims = reinterpret_cast(manager.replicatePointer(ShapeUtils::evalDimsToExclude(aRank, {aMaxis, aKaxis}).data(), (aRank - 2) * sizeof(int))); - if(bRank > 2) - bBatchDims = reinterpret_cast(manager.replicatePointer(ShapeUtils::evalDimsToExclude(bRank, {bKaxis, bNaxis}).data(), (bRank - 2) * sizeof(int))); - if(cRank > 2) - cBatchDims = reinterpret_cast(manager.replicatePointer(ShapeUtils::evalDimsToExclude(cRank, {cMaxis, cNaxis}).data(), (cRank - 2) * sizeof(int))); - - NDArray::prepareSpecialUse({C}, {A, B}); - // BUILD_TRIPLE_SELECTOR(A->dataType(), b->dataType(), C->dataType(), batchedGemm, (blocksPerGrid, threadsPerBlock, A->getContext()->getCudaStream(), A->specialBuffer(), A->specialShapeInfo(), B->specialBuffer(), B->specialShapeInfo(), C->specialBuffer(), C->special(), aMaxis, aKaxis, bKaxis, bNaxis, cMaxis, cNaxis, alpha, beta), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES); - BUILD_SINGLE_SELECTOR_THRICE(A->dataType(), batchedGemm, (blocksPerGrid, threadsPerBlock, sharedMem, A->getContext()->getCudaStream(), A->specialBuffer(), A->specialShapeInfo(), B->specialBuffer(), B->specialShapeInfo(), C->specialBuffer(), C->specialShapeInfo(), aBatchDims, bBatchDims, cBatchDims, aMaxis, aKaxis, bKaxis, bNaxis, cMaxis, cNaxis, alpha, beta), NUMERIC_TYPES) - NDArray::registerSpecialUse({C}, {A, B}); - - manager.synchronize(); - - return C; +NDArray* MmulHelper::mmulNxN(const NDArray* A, const NDArray* B, NDArray* C, + const double alpha, const double beta, + const char outOrder) { + const int aRank = A->rankOf(); + const int bRank = B->rankOf(); + + // input ranks validation + if (aRank > bRank && bRank != 2) + throw std::runtime_error( + "MmulHelper::mmulNxN: rank of B array should be equal 2 !"); + else if (bRank > aRank && aRank != 2) + throw std::runtime_error( + "MmulHelper::mmulNxN: rank of A array should be equal 2 !"); + else if (aRank == bRank) { + for (int i = 0; i < aRank - 2; ++i) + if (A->sizeAt(i) != B->sizeAt(i)) + throw std::runtime_error( + "MmulHelper::mmulNxN: shapes of A and B arrays are not suitable " + "for matrix multiplication !"); + } + + if (A->sizeAt(-1) != B->sizeAt(-2)) + throw std::runtime_error( + "MmulHelper::mmulNxN: shapes of A and B arrays are not suitable for " + "matrix multiplication !"); + + // validation of C array + std::vector cExpectedShape = + aRank > bRank ? A->getShapeAsVector() : B->getShapeAsVector(); + cExpectedShape[cExpectedShape.size() - 2] = A->sizeAt(-2); + cExpectedShape[cExpectedShape.size() - 1] = B->sizeAt(-1); + + if (C != nullptr) { + if (!C->isSameShape(cExpectedShape)) + throw std::runtime_error( + "MmulHelper::mmulNxN: shape of C array is not suitable for AxB " + "matrix multiplication !"); + } else + C = new NDArray( + outOrder, cExpectedShape, + DataTypeUtils::pickPairwiseResultType(A->dataType(), B->dataType()), + A->getContext()); + + if (C->isEmpty()) return C; + + const int cRank = C->rankOf(); + + const int aMaxis(aRank - 2), aKaxis(aRank - 1), bKaxis(bRank - 2), + bNaxis(bRank - 1), cMaxis(cRank - 2), cNaxis(cRank - 1); + + const int threadsPerBlock = MAX_NUM_THREADS / 8; + const int blocksPerGrid = + (C->lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + const int sharedMem = + threadsPerBlock * sizeof(int) * (aRank + bRank + cRank) + 128; + + PointersManager manager(A->getContext(), "MmulHelper::mmulNxN"); + + const int *aBatchDims(nullptr), *bBatchDims(nullptr), *cBatchDims(nullptr); + + if (aRank > 2) + aBatchDims = reinterpret_cast(manager.replicatePointer( + ShapeUtils::evalDimsToExclude(aRank, {aMaxis, aKaxis}).data(), + (aRank - 2) * sizeof(int))); + if (bRank > 2) + bBatchDims = reinterpret_cast(manager.replicatePointer( + ShapeUtils::evalDimsToExclude(bRank, {bKaxis, bNaxis}).data(), + (bRank - 2) * sizeof(int))); + if (cRank > 2) + cBatchDims = reinterpret_cast(manager.replicatePointer( + ShapeUtils::evalDimsToExclude(cRank, {cMaxis, cNaxis}).data(), + (cRank - 2) * sizeof(int))); + + NDArray::prepareSpecialUse({C}, {A, B}); + // BUILD_TRIPLE_SELECTOR(A->dataType(), b->dataType(), C->dataType(), + // batchedGemm, (blocksPerGrid, threadsPerBlock, + // A->getContext()->getCudaStream(), A->specialBuffer(), + // A->specialShapeInfo(), B->specialBuffer(), B->specialShapeInfo(), + // C->specialBuffer(), C->special(), aMaxis, aKaxis, bKaxis, bNaxis, + // cMaxis, cNaxis, alpha, beta), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES); + BUILD_SINGLE_SELECTOR_THRICE( + A->dataType(), batchedGemm, + (blocksPerGrid, threadsPerBlock, sharedMem, + A->getContext()->getCudaStream(), A->specialBuffer(), + A->specialShapeInfo(), B->specialBuffer(), B->specialShapeInfo(), + C->specialBuffer(), C->specialShapeInfo(), aBatchDims, bBatchDims, + cBatchDims, aMaxis, aKaxis, bKaxis, bNaxis, cMaxis, cNaxis, alpha, beta), + NUMERIC_TYPES) + NDArray::registerSpecialUse({C}, {A, B}); + + manager.synchronize(); + + return C; } - /* ////////////////////////////////////////////////////////////////////////////// // MXN x N = M template -static __global__ void usualCudaGemv(const bool transA, const int M, const int N, const double alpha, const void* vA, const int lda, const void* vX, const int incx, const double beta, void* vY, const int incy) { +static __global__ void usualCudaGemv(const bool transA, const int M, const int +N, const double alpha, const void* vA, const int lda, const void* vX, const int +incx, const double beta, void* vY, const int incy) { T1* A = reinterpret_cast(const_cast(vA)); T2* X = reinterpret_cast(const_cast(vX)); @@ -697,7 +835,8 @@ static __global__ void usualCudaGemv(const bool transA, const int M, const int N alphaZ = alpha; betaZ = beta; - if(transA) { strideArow = lda; strideAcol = 1; } else { strideArow = 1; strideAcol = lda; } + if(transA) { strideArow = lda; strideAcol = 1; } else { strideArow = 1; +strideAcol = lda; } } __syncthreads(); @@ -712,9 +851,13 @@ static __global__ void usualCudaGemv(const bool transA, const int M, const int N //////////////////////////////////////////////////////////////////////// template -__host__ static void usualGemv(const dim3 &blocksPerGrid, const dim3 &threadsPerBlock, cudaStream_t *stream, const bool transA, const int M, const int N, const double alpha, const void* vA, const int lda, const void* vX, const int incx, const double beta, void* vY, const int incy) { +__host__ static void usualGemv(const dim3 &blocksPerGrid, const dim3 +&threadsPerBlock, cudaStream_t *stream, const bool transA, const int M, const +int N, const double alpha, const void* vA, const int lda, const void* vX, const +int incx, const double beta, void* vY, const int incy) { - usualCudaGemv<<>>(transA, M, N, alpha, vA, lda, vX, incx, beta, vY, incy); + usualCudaGemv<<>>(transA, M, N, alpha, vA, lda, vX, incx, beta, vY, incy); } */ /* @@ -722,7 +865,10 @@ __host__ static void usualGemv(const dim3 &blocksPerGrid, const dim3 &threadsPer MXK x KxN = MxN C array must be in f order template -static __global__ void usualCudaGemm(const bool transA, const bool transB, const int M, const int N, const int K, const double alpha, const void* vA, const int lda, const void* vB, const int ldb, const double beta, void* vC, const int ldc) { +static __global__ void usualCudaGemm(const bool transA, const bool transB, const +int M, const int N, const int K, const double alpha, const void* vA, const int +lda, const void* vB, const int ldb, const double beta, void* vC, const int ldc) +{ T1* A = reinterpret_cast(const_cast(vA)); T2* B = reinterpret_cast(const_cast(vB)); @@ -739,8 +885,9 @@ static __global__ void usualCudaGemm(const bool transA, const bool transB, const alphaZ = alpha; betaZ = beta; - if(transA) { strideArow = lda; strideAcol = 1; } else { strideArow = 1; strideAcol = lda; } - if(transB) { strideBrow = ldb; strideBcol = 1; } else { strideBrow = 1; strideBcol = ldb; } + if(transA) { strideArow = lda; strideAcol = 1; } else { strideArow = 1; +strideAcol = lda; } if(transB) { strideBrow = ldb; strideBcol = 1; } else { +strideBrow = 1; strideBcol = ldb; } } __syncthreads(); @@ -748,47 +895,57 @@ static __global__ void usualCudaGemm(const bool transA, const bool transB, const T3 val = 0; if (row < M && col < N) for (int i = 0; i < K; i++) - val = val + A[row * strideArow + i * strideAcol] * B[i * strideBrow + col * strideBcol]; + val = val + A[row * strideArow + i * strideAcol] * B[i * strideBrow ++ col * strideBcol]; C[row + col * ldc] = alphaZ * val + betaZ * C[row + col * ldc]; } ////////////////////////////////////////////////////////////////////////////// template -__host__ static void usualGemm(const dim3 &blocksPerGrid, const dim3 &threadsPerBlock, cudaStream_t *stream, const bool transA, const bool transB, const int M, const int N, const int K, const double alpha, const void* vA, const int lda, const void* vB, const int ldb, const double beta, void* vC, const int ldc) { - - usualCudaGemm<<>>(transA, transB, M, N, K, alpha, vA, lda, vB, ldb, beta, vC, ldc); +__host__ static void usualGemm(const dim3 &blocksPerGrid, const dim3 +&threadsPerBlock, cudaStream_t *stream, const bool transA, const bool transB, +const int M, const int N, const int K, const double alpha, const void* vA, const +int lda, const void* vB, const int ldb, const double beta, void* vC, const int +ldc) { + + usualCudaGemm<<>>(transA, transB, M, N, K, alpha, vA, lda, vB, ldb, beta, vC, ldc); } */ ////////////////////////////////////////////////////////////////////////// /* -NDArray* MmulHelper::mmulNxNold1(const NDArray* A, const NDArray* B, NDArray* C, const double alpha, const double beta, const char outOrder) { +NDArray* MmulHelper::mmulNxNold1(const NDArray* A, const NDArray* B, NDArray* C, +const double alpha, const double beta, const char outOrder) { const int aRank = A->rankOf(); const int bRank = B->rankOf(); // input ranks validation if(aRank > bRank && bRank != 2) - throw std::runtime_error("MmulHelper::mmulNxN: rank of B array should be equal 2 !"); - else if(bRank > aRank && aRank != 2) - throw std::runtime_error("MmulHelper::mmulNxN: rank of A array should be equal 2 !"); + throw std::runtime_error("MmulHelper::mmulNxN: rank of B array should be +equal 2 !"); else if(bRank > aRank && aRank != 2) throw +std::runtime_error("MmulHelper::mmulNxN: rank of A array should be equal 2 !"); else if (aRank == bRank ) { for(int i = 0; i < aRank - 2; ++i) if(A->sizeAt(i) != B->sizeAt(i)) - throw std::runtime_error("MmulHelper::mmulNxN: shapes of A and B arrays are not suitable for matrix multiplication !"); + throw std::runtime_error("MmulHelper::mmulNxN: shapes of A and B +arrays are not suitable for matrix multiplication !"); } if(A->sizeAt(-1) != B->sizeAt(-2)) - throw std::runtime_error("MmulHelper::mmulNxN: shapes of A and B arrays are not suitable for matrix multiplication !"); + throw std::runtime_error("MmulHelper::mmulNxN: shapes of A and B arrays +are not suitable for matrix multiplication !"); // validation of C array - std::vector cExpectedShape = aRank > bRank ? A->getShapeAsVector() : B->getShapeAsVector(); - cExpectedShape[cExpectedShape.size() - 2] = A->sizeAt(-2); - cExpectedShape[cExpectedShape.size() - 1] = B->sizeAt(-1); + std::vector cExpectedShape = aRank > bRank ? A->getShapeAsVector() +: B->getShapeAsVector(); cExpectedShape[cExpectedShape.size() - 2] = +A->sizeAt(-2); cExpectedShape[cExpectedShape.size() - 1] = B->sizeAt(-1); if(C != nullptr ) { if(!C->isSameShape(cExpectedShape)) - throw std::runtime_error("MmulHelper::mmulNxN: shape of C array is not suitable for AxB matrix multiplication !"); + throw std::runtime_error("MmulHelper::mmulNxN: shape of C array is +not suitable for AxB matrix multiplication !"); } else { C = new NDArray(outOrder, cExpectedShape, B->dataType()); @@ -796,15 +953,16 @@ NDArray* MmulHelper::mmulNxNold1(const NDArray* A, const NDArray* B, NDArray* C, // multiplication - const std::vector dimsToExclude = ShapeUtils::evalDimsToExclude(C->rankOf(), {-2, -1}); - const Nd4jLong numOfSubArrs = ShapeUtils::getNumOfSubArrs(C->shapeInfo(), dimsToExclude); + const std::vector dimsToExclude = +ShapeUtils::evalDimsToExclude(C->rankOf(), {-2, -1}); const Nd4jLong +numOfSubArrs = ShapeUtils::getNumOfSubArrs(C->shapeInfo(), dimsToExclude); std::vector idxRanges(2 * C->rankOf()); // #pragma omp parallel for schedule(guided) firstprivate(idxRanges) for(Nd4jLong i = 0; i < numOfSubArrs; ++i) { - ShapeUtils::evalIdxRangesForSubArr(i, C->shapeInfo(), dimsToExclude, idxRanges.data()); - NDArray cSubArr = (*C)(idxRanges); + ShapeUtils::evalIdxRangesForSubArr(i, C->shapeInfo(), dimsToExclude, +idxRanges.data()); NDArray cSubArr = (*C)(idxRanges); if(aRank > bRank) { NDArray aSubArr = (*A)(idxRanges); @@ -831,33 +989,37 @@ NDArray* MmulHelper::mmulNxNold1(const NDArray* A, const NDArray* B, NDArray* C, // [M,K] x [bS,K,N] = [bS,M,N] // bS could stand for several axes /* -NDArray* MmulHelper::mmulNxNold2(const NDArray* A, const NDArray* B, NDArray* C, const double alpha, const double beta, const char outOrder) { +NDArray* MmulHelper::mmulNxNold2(const NDArray* A, const NDArray* B, NDArray* C, +const double alpha, const double beta, const char outOrder) { const int aRank = A->rankOf(); const int bRank = B->rankOf(); // input ranks validation if(aRank > bRank && bRank != 2) - throw std::runtime_error("MmulHelper::mmulNxN: rank of B array should be equal 2 !"); - else if(bRank > aRank && aRank != 2) - throw std::runtime_error("MmulHelper::mmulNxN: rank of A array should be equal 2 !"); + throw std::runtime_error("MmulHelper::mmulNxN: rank of B array should be +equal 2 !"); else if(bRank > aRank && aRank != 2) throw +std::runtime_error("MmulHelper::mmulNxN: rank of A array should be equal 2 !"); else if (aRank == bRank ) { for(int i = 0; i < aRank - 2; ++i) if(A->sizeAt(i) != B->sizeAt(i)) - throw std::runtime_error("MmulHelper::mmulNxN: shapes of A and B arrays are not suitable for matrix multiplication !"); + throw std::runtime_error("MmulHelper::mmulNxN: shapes of A and B +arrays are not suitable for matrix multiplication !"); } if(A->sizeAt(-1) != B->sizeAt(-2)) - throw std::runtime_error("MmulHelper::mmulNxN: shapes of A and B arrays are not suitable for matrix multiplication !"); + throw std::runtime_error("MmulHelper::mmulNxN: shapes of A and B arrays +are not suitable for matrix multiplication !"); // validation of C array - std::vector cExpectedShape = aRank > bRank ? A->getShapeAsVector() : B->getShapeAsVector(); - cExpectedShape[cExpectedShape.size() - 2] = A->sizeAt(-2); - cExpectedShape[cExpectedShape.size() - 1] = B->sizeAt(-1); + std::vector cExpectedShape = aRank > bRank ? A->getShapeAsVector() +: B->getShapeAsVector(); cExpectedShape[cExpectedShape.size() - 2] = +A->sizeAt(-2); cExpectedShape[cExpectedShape.size() - 1] = B->sizeAt(-1); if(C != nullptr ) { if(!C->isSameShape(cExpectedShape)) - throw std::runtime_error("MmulHelper::mmulNxN: shape of C array is not suitable for AxB matrix multiplication !"); + throw std::runtime_error("MmulHelper::mmulNxN: shape of C array is +not suitable for AxB matrix multiplication !"); } else C = new NDArray(outOrder, cExpectedShape, B->dataType()); @@ -868,8 +1030,8 @@ NDArray* MmulHelper::mmulNxNold2(const NDArray* A, const NDArray* B, NDArray* C, const auto K = A->sizeAt(-1); const auto N = B->sizeAt(-1); - NDArray *pA(const_cast(A)), *pB(const_cast(B)), *pC(const_cast(C)); - std::vector toDelete; + NDArray *pA(const_cast(A)), *pB(const_cast(B)), +*pC(const_cast(C)); std::vector toDelete; bool aMcont = M == 1 || A->strideAt(-2) == 1; bool aKcont = K == 1 || A->strideAt(-1) == 1; @@ -892,9 +1054,9 @@ NDArray* MmulHelper::mmulNxNold2(const NDArray* A, const NDArray* B, NDArray* C, if(!cMcont) { std::iota(permut.begin(), permut.end(), 0); permut[cRank - 2] = cRank - 1; - permut[cRank - 1] = cRank - 2; // swap two last dimensions [..., M,N] -> [..., N,M] - auto Cpermut = C->permute(permut); - pC = new NDArray('c', Cpermut.getShapeAsVector(), Cpermut.dataType(), A->getContext()); + permut[cRank - 1] = cRank - 2; // swap two last dimensions [..., M,N] +-> [..., N,M] auto Cpermut = C->permute(permut); pC = new NDArray('c', +Cpermut.getShapeAsVector(), Cpermut.dataType(), A->getContext()); pC->assign(Cpermut); toDelete.push_back(pC); cMcont = true; @@ -932,43 +1094,53 @@ NDArray* MmulHelper::mmulNxNold2(const NDArray* A, const NDArray* B, NDArray* C, const int bS = pC->lengthOf() / (M*N); - const std::vector dimsToExclude = ShapeUtils::evalDimsToExclude(cRank, {-2, -1}); + const std::vector dimsToExclude = ShapeUtils::evalDimsToExclude(cRank, +{-2, -1}); NDArray::prepareSpecialUse({pC}, {pA, pB}); if(!badTypes) { std::vector subArrOffsets(bS); - std::vector subArrShapeInfo(shape::shapeInfoLength(2)); // all sub-arrays have rank = 2 + std::vector subArrShapeInfo(shape::shapeInfoLength(2)); // all +sub-arrays have rank = 2 std::vector aSubArrs(bS), bSubArrs(bS), cSubArrs(bS); if(aRank > 2) - shape::calcSubArrsShapeInfoAndOffsets(pA->shapeInfo(), bS, dimsToExclude.size(), dimsToExclude.data(), subArrShapeInfo.data(), subArrOffsets.data()); - for (int i = 0; i < bS; ++i) - aSubArrs[i] = aRank == 2 ? pA->specialBuffer() : pA->specialBuffer() + subArrOffsets[i] * pA->sizeOfT(); + shape::calcSubArrsShapeInfoAndOffsets(pA->shapeInfo(), bS, +dimsToExclude.size(), dimsToExclude.data(), subArrShapeInfo.data(), +subArrOffsets.data()); for (int i = 0; i < bS; ++i) aSubArrs[i] = aRank == 2 ? +pA->specialBuffer() : pA->specialBuffer() + subArrOffsets[i] * pA->sizeOfT(); if(bRank > 2) - shape::calcSubArrsShapeInfoAndOffsets(pB->shapeInfo(), bS, dimsToExclude.size(), dimsToExclude.data(), subArrShapeInfo.data(), subArrOffsets.data()); - for (int i = 0; i < bS; ++i) - bSubArrs[i] = bRank == 2 ? pB->specialBuffer() : pB->specialBuffer() + subArrOffsets[i] * pB->sizeOfT(); + shape::calcSubArrsShapeInfoAndOffsets(pB->shapeInfo(), bS, +dimsToExclude.size(), dimsToExclude.data(), subArrShapeInfo.data(), +subArrOffsets.data()); for (int i = 0; i < bS; ++i) bSubArrs[i] = bRank == 2 ? +pB->specialBuffer() : pB->specialBuffer() + subArrOffsets[i] * pB->sizeOfT(); - shape::calcSubArrsShapeInfoAndOffsets(pC->shapeInfo(), bS, dimsToExclude.size(), dimsToExclude.data(), subArrShapeInfo.data(), subArrOffsets.data()); - for (int i = 0; i < bS; ++i) - cSubArrs[i] = pC->specialBuffer() + subArrOffsets[i] * pC->sizeOfT(); + shape::calcSubArrsShapeInfoAndOffsets(pC->shapeInfo(), bS, +dimsToExclude.size(), dimsToExclude.data(), subArrShapeInfo.data(), +subArrOffsets.data()); for (int i = 0; i < bS; ++i) cSubArrs[i] = +pC->specialBuffer() + subArrOffsets[i] * pC->sizeOfT(); PointersManager manager(A->getContext(), "mmulNxN"); - const void** aSubArrsCuda = reinterpret_cast(manager.replicatePointer(aSubArrs.data(), aSubArrs.size() * sizeof(void*))); - const void** bSubArrsCuda = reinterpret_cast(manager.replicatePointer(bSubArrs.data(), bSubArrs.size() * sizeof(void*))); - void** cSubArrsCuda = reinterpret_cast< void **>(manager.replicatePointer(cSubArrs.data(), cSubArrs.size() * sizeof(void*))); + const void** aSubArrsCuda = reinterpret_cast(manager.replicatePointer(aSubArrs.data(), aSubArrs.size() * +sizeof(void*))); const void** bSubArrsCuda = reinterpret_cast(manager.replicatePointer(bSubArrs.data(), bSubArrs.size() * +sizeof(void*))); void** cSubArrsCuda = reinterpret_cast< void +**>(manager.replicatePointer(cSubArrs.data(), cSubArrs.size() * +sizeof(void*))); const bool transA = !aMcont; const bool transB = !bKcont; - const int lda = (aMcont && aKcont) ? M : transA ? pA->strideAt(-2) : pA->strideAt(-1); - const int ldb = (bKcont && bNcont) ? K : transB ? pB->strideAt(-2) : pB->strideAt(-1); - const int ldc = (cMcont && cNcont) ? M : C != pC ? pC->strideAt(-2) : pC->strideAt(-1); + const int lda = (aMcont && aKcont) ? M : transA ? pA->strideAt(-2) : +pA->strideAt(-1); const int ldb = (bKcont && bNcont) ? K : transB ? +pB->strideAt(-2) : pB->strideAt(-1); const int ldc = (cMcont && cNcont) ? M : C +!= pC ? pC->strideAt(-2) : pC->strideAt(-1); const cublasOperation_t transAblas = transA ? CUBLAS_OP_T : CUBLAS_OP_N; const cublasOperation_t transBblas = transB ? CUBLAS_OP_T : CUBLAS_OP_N; @@ -989,21 +1161,27 @@ NDArray* MmulHelper::mmulNxNold2(const NDArray* A, const NDArray* B, NDArray* C, uBeta._d = beta; } - auto handle = reinterpret_cast(A->getContext()->getCublasHandle()); - auto stream = A->getContext()->getCudaStream(); + auto handle = reinterpret_cast(A->getContext()->getCublasHandle()); auto stream = +A->getContext()->getCudaStream(); auto status = cublasSetStream_v2(*handle, *stream); if (status != CUBLAS_STATUS_SUCCESS) - throw cuda_exception::build("MmulHelper::mmulNxN cuda failed !", status); + throw cuda_exception::build("MmulHelper::mmulNxN cuda failed !", +status); - status = cublasGemmBatchedEx(*handle, transAblas, transBblas, M, N, K, &uAlpha, aSubArrsCuda, cudaAType, lda, bSubArrsCuda, cudaBType, ldb, &uBeta, cSubArrsCuda, cudaCType, ldc, bS, cudaType, CUBLAS_GEMM_DEFAULT); + status = cublasGemmBatchedEx(*handle, transAblas, transBblas, M, N, K, +&uAlpha, aSubArrsCuda, cudaAType, lda, bSubArrsCuda, cudaBType, ldb, &uBeta, +cSubArrsCuda, cudaCType, ldc, bS, cudaType, CUBLAS_GEMM_DEFAULT); if (status != CUBLAS_STATUS_SUCCESS) - throw cuda_exception::build("MmulHelper::mmulNxN cuda failed !", status); + throw cuda_exception::build("MmulHelper::mmulNxN cuda failed !", +status); auto cudaResult = cudaStreamSynchronize(*stream); if (cudaResult != 0) - throw cuda_exception::build("MmulHelper::mmulNxN cuda failed !", cudaResult); + throw cuda_exception::build("MmulHelper::mmulNxN cuda failed !", +cudaResult); } else { @@ -1011,8 +1189,8 @@ NDArray* MmulHelper::mmulNxNold2(const NDArray* A, const NDArray* B, NDArray* C, for(Nd4jLong i = 0; i < bS; ++i) { - ShapeUtils::evalIdxRangesForSubArr(i, pC->shapeInfo(), dimsToExclude, idxRanges.data()); - NDArray cSubArr = (*pC)(idxRanges); + ShapeUtils::evalIdxRangesForSubArr(i, pC->shapeInfo(), +dimsToExclude, idxRanges.data()); NDArray cSubArr = (*pC)(idxRanges); if(aRank > bRank) { NDArray aSubArr = (*pA)(idxRanges); @@ -1042,8 +1220,19 @@ NDArray* MmulHelper::mmulNxNold2(const NDArray* A, const NDArray* B, NDArray* C, } */ -//BUILD_TRIPLE_TEMPLATE(template void usualGemm, (const dim3 &blocksPerGrid, const dim3 &threadsPerBlock, cudaStream_t *stream, const bool transA, const bool transB, const int M, const int N, const int K, const double alpha, const void* vA, const int lda, const void* vB, const int ldb, const double beta, void* vC, const int ldc), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES); -//BUILD_TRIPLE_TEMPLATE(template void usualGemv, (const dim3 &blocksPerGrid, const dim3 &threadsPerBlock, cudaStream_t *stream, const bool transA, const int M, const int N, const double alpha, const void* vA, const int lda, const void* vB, const int incx, const double beta, void* vC, const int incy), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES); -//BUILD_TRIPLE_TEMPLATE(template void usualDot, (const dim3 &blocksPerGrid, const dim3 &threadsPerBlock, cudaStream_t *stream, const Nd4jLong length, const double alpha, const void* vX, const Nd4jLong incx, const void* vY, const Nd4jLong incy, const double beta, void* vZ), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES); - -} \ No newline at end of file +// BUILD_TRIPLE_TEMPLATE(template void usualGemm, (const dim3 &blocksPerGrid, +// const dim3 &threadsPerBlock, cudaStream_t *stream, const bool transA, const +// bool transB, const int M, const int N, const int K, const double alpha, const +// void* vA, const int lda, const void* vB, const int ldb, const double beta, +// void* vC, const int ldc), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES); +// BUILD_TRIPLE_TEMPLATE(template void usualGemv, (const dim3 &blocksPerGrid, +// const dim3 &threadsPerBlock, cudaStream_t *stream, const bool transA, const +// int M, const int N, const double alpha, const void* vA, const int lda, const +// void* vB, const int incx, const double beta, void* vC, const int incy), +// NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES); BUILD_TRIPLE_TEMPLATE(template +// void usualDot, (const dim3 &blocksPerGrid, const dim3 &threadsPerBlock, +// cudaStream_t *stream, const Nd4jLong length, const double alpha, const void* +// vX, const Nd4jLong incx, const void* vY, const Nd4jLong incy, const double +// beta, void* vZ), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES); + +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/helpers/cuda_off/cublasHelper.cu b/libnd4j/include/helpers/cuda_off/cublasHelper.cu index b179b093049d..065353a9e966 100644 --- a/libnd4j/include/helpers/cuda_off/cublasHelper.cu +++ b/libnd4j/include/helpers/cuda_off/cublasHelper.cu @@ -18,13 +18,13 @@ // @author raver119@gmail.com // - #include #include -#include "../cublasHelper.h" #include -#include #include +#include + +#include #include "config.h" #ifdef HAVE_CUDNN @@ -34,103 +34,109 @@ #endif namespace sd { - std::mutex CublasHelper::_mutex; +std::mutex CublasHelper::_mutex; - static void* handle_() { - auto _handle = new cublasHandle_t(); - auto status = cublasCreate_v2(_handle); // initialize CUBLAS context - if (status != CUBLAS_STATUS_SUCCESS) - throw cuda_exception::build("cuBLAS handle creation failed !", status); +static void* handle_() { + auto _handle = new cublasHandle_t(); + auto status = cublasCreate_v2(_handle); // initialize CUBLAS context + if (status != CUBLAS_STATUS_SUCCESS) + throw cuda_exception::build("cuBLAS handle creation failed !", status); - return reinterpret_cast(_handle); - } + return reinterpret_cast(_handle); +} - static void* solver_() { - auto cusolverH = new cusolverDnHandle_t(); - auto status = cusolverDnCreate(cusolverH); - if (status != CUSOLVER_STATUS_SUCCESS) - throw cuda_exception::build("cuSolver handle creation failed !", status); +static void* solver_() { + auto cusolverH = new cusolverDnHandle_t(); + auto status = cusolverDnCreate(cusolverH); + if (status != CUSOLVER_STATUS_SUCCESS) + throw cuda_exception::build("cuSolver handle creation failed !", status); - return cusolverH; - } + return cusolverH; +} - static void* cudnn_() { +static void* cudnn_() { #ifdef HAVE_CUDNN - auto cudnnH = new cudnnHandle_t(); - auto status = cudnnCreate(cudnnH); - if (status != CUDNN_STATUS_SUCCESS) - throw cuda_exception::build("cuDNN handle creation failed !", status); + auto cudnnH = new cudnnHandle_t(); + auto status = cudnnCreate(cudnnH); + if (status != CUDNN_STATUS_SUCCESS) + throw cuda_exception::build("cuDNN handle creation failed !", status); - return cudnnH; + return cudnnH; #endif - return nullptr; - } - - static void destroyHandle_(void* handle) { - auto ch = reinterpret_cast(handle); - auto status = cublasDestroy_v2(*ch); - if (status != CUBLAS_STATUS_SUCCESS) - throw cuda_exception::build("cuBLAS handle destruction failed !", status); - - delete ch; - } - - CublasHelper::CublasHelper() { - //nd4j_printf("Initializing cuBLAS\n",""); - auto numDevices = AffinityManager::numberOfDevices(); - auto currentDevice = AffinityManager::currentDeviceId(); - _cache.resize(numDevices); - _solvers.resize(numDevices); - _cudnn.resize(numDevices); - for (int e = 0; e < numDevices; e++) { - AffinityManager::setCurrentNativeDevice(e); - - _cache[e] = handle_(); - _solvers[e] = solver_(); - _cudnn[e] = cudnn_(); - } - - // don't forget to restore back original device - AffinityManager::setCurrentNativeDevice(currentDevice); - } - - CublasHelper::~CublasHelper() { - auto numDevices = AffinityManager::numberOfDevices(); - - for (int e = 0; e < numDevices; e++) - destroyHandle_(_cache[e]); - } - - CublasHelper& CublasHelper::getInstance() { - static CublasHelper instance; - return instance; - } - - void* CublasHelper::cudnn() { - auto deviceId = AffinityManager::currentDeviceId(); - if (deviceId < 0 || deviceId > _cudnn.size()) - throw cuda_exception::build("requested deviceId doesn't look valid", deviceId); - - return _cudnn[deviceId]; - } - - void* CublasHelper::handle() { - auto deviceId = AffinityManager::currentDeviceId(); - return handle(deviceId); - } - - void* CublasHelper::solver() { - auto deviceId = AffinityManager::currentDeviceId(); - if (deviceId < 0 || deviceId > _solvers.size()) - throw cuda_exception::build("requested deviceId doesn't look valid", deviceId); - - return _solvers[deviceId]; - } - - void* CublasHelper::handle(int deviceId) { - if (deviceId < 0 || deviceId > _cache.size()) - throw cuda_exception::build("requested deviceId doesn't look valid", deviceId); - - return _cache[deviceId]; - } -} \ No newline at end of file + return nullptr; +} + +static void destroyHandle_(void* handle) { + auto ch = reinterpret_cast(handle); + auto status = cublasDestroy_v2(*ch); + if (status != CUBLAS_STATUS_SUCCESS) + throw cuda_exception::build("cuBLAS handle destruction failed !", status); + + delete ch; +} + +CublasHelper::CublasHelper() { + // nd4j_printf("Initializing cuBLAS\n",""); + auto numDevices = AffinityManager::numberOfDevices(); + auto currentDevice = AffinityManager::currentDeviceId(); + _cache.resize(numDevices); + _solvers.resize(numDevices); + _cudnn.resize(numDevices); + for (int e = 0; e < numDevices; e++) { + AffinityManager::setCurrentNativeDevice(e); + + _cache[e] = handle_(); + _solvers[e] = solver_(); + _cudnn[e] = cudnn_(); + } + + // don't forget to restore back original device + AffinityManager::setCurrentNativeDevice(currentDevice); +} + +CublasHelper::~CublasHelper() { + + auto numDevices = AffinityManager::numberOfDevices(); + + for (int e = 0; e < numDevices; e++) destroyHandle_(_cache[e]); +} + +CublasHelper& CublasHelper::getInstance() { + static CublasHelper instance; + + return instance; +} + +void* CublasHelper::cudnn() { + auto deviceId = AffinityManager::currentDeviceId(); + if (deviceId < 0 || deviceId > _cudnn.size()) + throw cuda_exception::build("requested deviceId doesn't look valid", + deviceId); + + return _cudnn[deviceId]; +} + +void* CublasHelper::handle() { + auto deviceId = AffinityManager::currentDeviceId(); + return handle(deviceId); +} + +void* CublasHelper::solver() { + auto deviceId = AffinityManager::currentDeviceId(); + if (deviceId < 0 || deviceId > _solvers.size()) + throw cuda_exception::build("requested deviceId doesn't look valid", + deviceId); + + return _solvers[deviceId]; +} + +void* CublasHelper::handle(int deviceId) { + if (deviceId < 0 || deviceId > _cache.size()) + throw cuda_exception::build("requested deviceId doesn't look valid", + deviceId); + + return _cache[deviceId]; +} + + +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/helpers/data_gen.h b/libnd4j/include/helpers/data_gen.h index 1640396178a9..4145de70e0ad 100644 --- a/libnd4j/include/helpers/data_gen.h +++ b/libnd4j/include/helpers/data_gen.h @@ -29,15 +29,14 @@ * @return the linear spaced array */ template -T * linspace(int lower, int upper, int num) { - T *data = new T[num]; - for (int i = 0; i < num; i++) { - T t = (T) i / (num - 1); - data[i] = lower * (1 - t) + t * upper; +T *linspace(int lower, int upper, int num) { + T *data = new T[num]; + for (int i = 0; i < num; i++) { + T t = (T)i / (num - 1); + data[i] = lower * (1 - t) + t * upper; + } - } - - return data; + return data; } -#endif //LIBND4J_DATA_GEN_H_H +#endif // LIBND4J_DATA_GEN_H_H diff --git a/libnd4j/include/helpers/files.h b/libnd4j/include/helpers/files.h index c49cedbb737b..6e4de204e2e3 100644 --- a/libnd4j/include/helpers/files.h +++ b/libnd4j/include/helpers/files.h @@ -16,113 +16,110 @@ // // Methods to lookup files in $PATH -// adopted from https://stackoverflow.com/questions/2718915/check-if-file-exists-including-on-path +// adopted from +// https://stackoverflow.com/questions/2718915/check-if-file-exists-including-on-path // #ifndef LIBND4J_FILES_H #define LIBND4J_FILES_H -#include -#include -#include #include - +#include +#include void *malloc_check(const char *what, size_t n); char *strsave(const char *s, const char *lim); -char ** shellpath(void); -void freeshellpath (char *shellpath[]); +char **shellpath(void); +void freeshellpath(char *shellpath[]); unsigned maxpathlen(char *path[], const char *base); -bool file_exists(char *name); +bool file_exists(const char *name); void *malloc_check(const char *what, size_t n) { - void *p = malloc(n); - if (p == NULL) { - fprintf(stderr, "Cannot allocate %zu bytes to %s\n", n, what); - exit(2); - } - return p; + void *p = malloc(n); + if (p == NULL) { + fprintf(stderr, "Cannot allocate %zu bytes to %s\n", n, what); + exit(2); + } + return p; } char *strsave(const char *s, const char *lim) { - if (lim == NULL) - lim = s + strlen(s); - char *p = (char *) malloc_check("save string", lim - s + 1); - strncpy(p, s, lim-s); - p[lim-s] = '\0'; - return p; + if (lim == NULL) lim = s + strlen(s); + char *p = (char *)malloc_check("save string", lim - s + 1); + strncpy(p, s, lim - s); + p[lim - s] = '\0'; + return p; } -char ** shellpath(void) { - const char *path = getenv("PATH"); - if (!path) - path = "./"; +char **shellpath(void) { + const char *path = getenv("PATH"); + if (!path) path = "./"; - char **vector = // size is overkill - (char **) malloc_check("hold path elements", strlen(path) * sizeof(*vector)); - const char *p = path; - int next = 0; - while (p) { + char **vector = // size is overkill + (char **)malloc_check("hold path elements", + strlen(path) * sizeof(*vector)); + const char *p = path; + int next = 0; + while (p) { #ifdef _WIN32 - char *q = strchr(p, ';'); // windows uses ; as delimiter + char *q = + strchr(const_cast(p), ';'); // windows uses ; as delimiter #else - char *q = strchr(p, ':'); // linux and derivatives use : as delimiter + char *q = strchr(const_cast(p), + ':'); // linux and derivatives use : as delimiter #endif - vector[next++] = strsave(p, q); - p = q ? q + 1 : NULL; - } - vector[next] = NULL; - return vector; + vector[next++] = strsave(p, q); + p = q ? q + 1 : NULL; + } + vector[next] = NULL; + return vector; } -void freeshellpath (char *shellpath[]) { - for (int i = 0; shellpath[i]; i++) - free(shellpath[i]); - free(shellpath); +void freeshellpath(char *shellpath[]) { + for (int i = 0; shellpath[i]; i++) free(shellpath[i]); + free(shellpath); } unsigned maxpathlen(char *path[], const char *base) { - unsigned blen = strlen(base); - unsigned n = 0; - for (int i = 0; path[i]; i++) { - unsigned pn = strlen(path[i]); - if (pn > n) n = pn; - } - return blen+n+1; + unsigned blen = strlen(base); + unsigned n = 0; + for (int i = 0; path[i]; i++) { + unsigned pn = strlen(path[i]); + if (pn > n) n = pn; + } + return blen + n + 1; } -bool file_exists(char *name){ - printf("Trying file: [%s]\n", name); - FILE *file; - if (file = fopen(name, "r")) { - fclose(file); - return true; - } - return false; +bool file_exists(const char *name) { + FILE *file; + if ((file = fopen(name, "r"))) { + fclose(file); + return true; + } + return false; } bool checkFileInPath(const char *file) { - char *path = getenv("PATH"); - char **listed = shellpath(); - size_t maxlen = maxpathlen(listed, file)+1; - char *buf = (char *) malloc_check("hold path", maxlen); - bool found = false; - for (int i = 0; listed[i]; i++) { - if (strlen(listed[i]) > 0) { + char *path = getenv("PATH"); + char **listed = shellpath(); + size_t maxlen = maxpathlen(listed, file) + 1; + char *buf = (char *)malloc_check("hold path", maxlen); + bool found = false; + for (int i = 0; listed[i]; i++) { + if (strlen(listed[i]) > 0) { #ifdef _WIN32 - snprintf(buf, maxlen, "%s\\%s", listed[i], file); + snprintf(buf, maxlen, "%s\\%s", listed[i], file); #else - snprintf(buf, maxlen, "%s/%s", listed[i], file); + snprintf(buf, maxlen, "%s/%s", listed[i], file); #endif - if (file_exists(buf)) { - found = true; - break; - } - } + if (file_exists(buf)) { + found = true; + break; + } } - free(buf); - freeshellpath(listed); + } + free(buf); + freeshellpath(listed); - return found; + return found; } - -#endif //LIBND4J_FILES_H +#endif // LIBND4J_FILES_H diff --git a/libnd4j/include/helpers/helper_generator.h b/libnd4j/include/helpers/helper_generator.h index ecf87ae81f10..e6ce41cc968b 100644 --- a/libnd4j/include/helpers/helper_generator.h +++ b/libnd4j/include/helpers/helper_generator.h @@ -21,10 +21,10 @@ #ifndef LIBND4J_HELPER_GENERATOR_H #define LIBND4J_HELPER_GENERATOR_H -#include -#include #include #include +#include +#include #ifdef _MSC_VER // include for uint64_t on MSVC @@ -34,587 +34,541 @@ #ifndef UINT64_C #if defined(__LP64__) -#define UINT64_C(c) c ## UL +#define UINT64_C(c) c##UL #else -#define UINT64_C(c) c ## ULL -#endif //LP64 -#endif // UINT64 - -#endif // MSVC/ANDROID +#define UINT64_C(c) c##ULL +#endif // LP64 +#endif // UINT64 +#endif // MSVC/ANDROID #ifdef __GNUC__ #include #endif - namespace sd { - namespace random { +namespace random { #ifdef __CUDACC__ - class ND4J_EXPORT CudaManaged { - private: - - protected: - void *devHolder; - - public: - void *operator new(size_t len) { - void *ptr; - cudaHostAlloc(&ptr, len, cudaHostAllocDefault); - return ptr; - } - - void operator delete(void *ptr) { - cudaFreeHost(ptr); - } - }; - - class ND4J_EXPORT RandomBuffer : public CudaManaged { +class SD_EXPORT CudaManaged { + private: + protected: + void *devHolder; + + public: + void *operator new(size_t len) { + void *ptr; + cudaHostAlloc(&ptr, len, cudaHostAllocDefault); + return ptr; + } + + void operator delete(void *ptr) { cudaFreeHost(ptr); } +}; + +class SD_EXPORT RandomBuffer : public CudaManaged { #else - class ND4J_EXPORT RandomBuffer { +class SD_EXPORT RandomBuffer { #endif - private: - void *devHolder; - Nd4jLong size; - uint64_t *buffer; - uint64_t *devBuffer; - Nd4jLong offset; - Nd4jLong seed; - Nd4jLong position; - Nd4jLong generation; - Nd4jLong currentPosition; - Nd4jLong amplifier; - unsigned int synchronizer; + private: + void *devHolder; + Nd4jLong size; + uint64_t *buffer; + uint64_t *devBuffer; + Nd4jLong offset; + Nd4jLong seed; + Nd4jLong position; + Nd4jLong generation; + Nd4jLong currentPosition; + Nd4jLong amplifier; + unsigned int synchronizer; #ifdef __CUDACC__ - curandGenerator_t gen; + curandGenerator_t gen; #endif - public: - /** - * This method allocates buffer of size * sizeof(Nd4jLong) - * - * @param size - * @return - */ + public: + /** + * This method allocates buffer of size * sizeof(Nd4jLong) + * + * @param size + * @return + */ #ifdef __CUDACC__ - __host__ - RandomBuffer(Nd4jLong seed, Nd4jLong size, uint64_t *hostBuffer, uint64_t *devBuffer) { - this->buffer = hostBuffer; - this->seed = seed; - this->size = size; - this->generation = 1; - this->currentPosition = 0; - this->offset = 0; - this->amplifier = seed; - this->synchronizer = 0; - this->devBuffer = devBuffer; - - cudaMalloc(&devHolder, sizeof(sd::random::RandomBuffer)); - } - - __host__ - Nd4jPointer getDevicePointer() { - return reinterpret_cast(devHolder); - } - - __host__ - ~RandomBuffer() { - cudaFree(devHolder); - } - - __host__ - void propagateToDevice(sd::random::RandomBuffer *buffer, cudaStream_t stream) { - cudaMemcpyAsync(devHolder, buffer, sizeof(sd::random::RandomBuffer), cudaMemcpyHostToDevice, stream); - } - - __host__ __device__ + __host__ RandomBuffer(Nd4jLong seed, Nd4jLong size, uint64_t *hostBuffer, + uint64_t *devBuffer) { + this->buffer = hostBuffer; + this->seed = seed; + this->size = size; + this->generation = 1; + this->currentPosition = 0; + this->offset = 0; + this->amplifier = seed; + this->synchronizer = 0; + this->devBuffer = devBuffer; + + cudaMalloc(&devHolder, sizeof(sd::random::RandomBuffer)); + } + + __host__ Nd4jPointer getDevicePointer() { + return reinterpret_cast(devHolder); + } + + __host__ ~RandomBuffer() { cudaFree(devHolder); } + + __host__ void propagateToDevice(sd::random::RandomBuffer *buffer, + cudaStream_t stream) { + cudaMemcpyAsync(devHolder, buffer, sizeof(sd::random::RandomBuffer), + cudaMemcpyHostToDevice, stream); + } + + __host__ __device__ #endif - RandomBuffer(Nd4jLong seed, Nd4jLong size, uint64_t *buffer) { - this->buffer = buffer; - this->seed = seed; - this->size = size; - this->generation = 1; - this->currentPosition = 0; - this->offset = 0; - this->amplifier = seed; - this->synchronizer = 0; - this->devBuffer = buffer; - } - - inline _CUDA_HD uint64_t *getBuffer() { - return this->buffer; - } - - inline _CUDA_HD uint64_t *getDeviceBuffer() { - return this->devBuffer; - } + RandomBuffer(Nd4jLong seed, Nd4jLong size, uint64_t *buffer) { + this->buffer = buffer; + this->seed = seed; + this->size = size; + this->generation = 1; + this->currentPosition = 0; + this->offset = 0; + this->amplifier = seed; + this->synchronizer = 0; + this->devBuffer = buffer; + } + + inline _CUDA_HD uint64_t *getBuffer() { return this->buffer; } + + inline _CUDA_HD uint64_t *getDeviceBuffer() { return this->devBuffer; } #ifdef __CUDACC__ - _CUDA_HD curandGenerator_t *getGeneratorPointer() { - return &gen; - } - - _CUDA_HD curandGenerator_t getGenerator() { - return gen; - } + _CUDA_HD curandGenerator_t *getGeneratorPointer() { return &gen; } + _CUDA_HD curandGenerator_t getGenerator() { return gen; } - _CUDA_H void setBuffer(uint64_t *ptr) { - this->buffer = ptr; - } + _CUDA_H void setBuffer(uint64_t *ptr) { this->buffer = ptr; } #endif - inline _CUDA_HD Nd4jLong getSize() { - return this->size; - } - - inline _CUDA_HD Nd4jLong getSeed() { - return this->seed; - } - - void _CUDA_HD setSeed(Nd4jLong seed) { - this->seed = seed; - this->amplifier = seed; - this->generation = 1; - } - - Nd4jLong _CUDA_HD getAllocatedSize() { - return this->size * sizeof(double); - } - - inline _CUDA_HD Nd4jLong getOffset() { - return this->currentPosition; - } - - void _CUDA_HD setOffset(Nd4jLong offset) { - this->currentPosition = offset; - } - - void _CUDA_HD reSeed(Nd4jLong amplifier) { - this->amplifier = amplifier; - } - - inline _CUDA_D uint64_t getElement(Nd4jLong position) { - Nd4jLong actualPosition = this->getOffset() + position; - Nd4jLong tempGen = generation; - if (actualPosition >= this->size) { - tempGen += actualPosition / this->size; - actualPosition = actualPosition % this->size; - } + inline _CUDA_HD Nd4jLong getSize() { return this->size; } + + inline _CUDA_HD Nd4jLong getSeed() { return this->seed; } + + void _CUDA_HD setSeed(Nd4jLong seed) { + this->seed = seed; + this->amplifier = seed; + this->generation = 1; + } + + Nd4jLong _CUDA_HD getAllocatedSize() { return this->size * sizeof(double); } + + inline _CUDA_HD Nd4jLong getOffset() { return this->currentPosition; } + + void _CUDA_HD setOffset(Nd4jLong offset) { this->currentPosition = offset; } + + void _CUDA_HD reSeed(Nd4jLong amplifier) { this->amplifier = amplifier; } + + inline _CUDA_D uint64_t getElement(Nd4jLong position) { + Nd4jLong actualPosition = this->getOffset() + position; + Nd4jLong tempGen = generation; + if (actualPosition >= this->size) { + tempGen += actualPosition / this->size; + actualPosition = actualPosition % this->size; + } #ifdef __CUDACC__ -// __syncthreads(); + // __syncthreads(); - auto ret = static_cast(devBuffer[actualPosition]); + auto ret = static_cast(devBuffer[actualPosition]); #else - auto ret = static_cast(buffer[actualPosition]); + auto ret = static_cast(buffer[actualPosition]); #endif - if (tempGen != generation) - ret = safeShift(ret, tempGen); + if (tempGen != generation) ret = safeShift(ret, tempGen); - if(generation > 1) - ret = safeShift(ret, generation); + if (generation > 1) ret = safeShift(ret, generation); - if (amplifier != seed) - ret = safeShift(ret, amplifier); + if (amplifier != seed) ret = safeShift(ret, amplifier); #ifdef __CUDACC__ // __syncthreads(); #endif - if (amplifier != seed || generation > 1 || tempGen != generation) - ret = next64(seedConv(static_cast(ret))); - - return ret; - } - - uint64_t _CUDA_HD next64(uint64_t shiftedSeed) { - const auto s0 = static_cast(shiftedSeed); - auto s1 = static_cast(shiftedSeed) % sd::DataTypeUtils::max() + 11; - uint64_t r0, r1; - - s1 ^= s0; - r0 = rotl(s0, 55) ^ s1 ^ (s1 << 14); // a, b - r1 = rotl(s1, 36); // c - - return r0 + r1; - } - - static _CUDA_HD inline uint64_t rotl(const uint64_t x, uint64_t k) { - return (x << k) | (x >> (64 - k)); - } - - uint64_t static _CUDA_HD inline safeShift(uint64_t x, uint64_t y) { - if (y != 0 && x > sd::DataTypeUtils::max() / y) { - return x / y + 11; - } else return (x * y) + 11; - } - - uint64_t _CUDA_HD seedConv(Nd4jLong seed) { - uint64_t x = static_cast(seed); - uint64_t z = (x += UINT64_C(0x9E3779B97F4A7C15)); - z = (z ^ (z >> 30)) * UINT64_C(0xBF58476D1CE4E5B9); - z = (z ^ (z >> 27)) * UINT64_C(0x94D049BB133111EB); - return z ^ (z >> 31); - } - - void _CUDA_HD incrementGeneration() { - this->generation++; - } - - Nd4jLong _CUDA_HD getNextIndex() { - currentPosition++; - if (currentPosition >= size) { - currentPosition = 0; - generation++; - } - Nd4jLong ret = currentPosition; - - return ret; - } - - uint64_t _CUDA_HD getNextElement() { - // TODO: proper implementation needed here - return generation == 1 ? buffer[getNextIndex()] : buffer[getNextIndex()] * generation; - } - - - /** - * This method skips X elements from buffer - * - * @param numberOfElements number of elements to skip - */ + if (amplifier != seed || generation > 1 || tempGen != generation) + ret = next64(seedConv(static_cast(ret))); + + return ret; + } + + uint64_t _CUDA_HD next64(uint64_t shiftedSeed) { + const auto s0 = static_cast(shiftedSeed); + auto s1 = + static_cast(shiftedSeed) % sd::DataTypeUtils::max() + 11; + uint64_t r0, r1; + + s1 ^= s0; + r0 = rotl(s0, 55) ^ s1 ^ (s1 << 14); // a, b + r1 = rotl(s1, 36); // c + + return r0 + r1; + } + + static _CUDA_HD inline uint64_t rotl(const uint64_t x, uint64_t k) { + return (x << k) | (x >> (64 - k)); + } + + uint64_t static _CUDA_HD inline safeShift(uint64_t x, uint64_t y) { + if (y != 0 && x > sd::DataTypeUtils::max() / y) { + return x / y + 11; + } else + return (x * y) + 11; + } + + uint64_t _CUDA_HD seedConv(Nd4jLong seed) { + uint64_t x = static_cast(seed); + uint64_t z = (x += UINT64_C(0x9E3779B97F4A7C15)); + z = (z ^ (z >> 30)) * UINT64_C(0xBF58476D1CE4E5B9); + z = (z ^ (z >> 27)) * UINT64_C(0x94D049BB133111EB); + return z ^ (z >> 31); + } + + void _CUDA_HD incrementGeneration() { this->generation++; } + + Nd4jLong _CUDA_HD getNextIndex() { + currentPosition++; + if (currentPosition >= size) { + currentPosition = 0; + generation++; + } + Nd4jLong ret = currentPosition; + + return ret; + } + + uint64_t _CUDA_HD getNextElement() { + // TODO: proper implementation needed here + return generation == 1 ? buffer[getNextIndex()] + : buffer[getNextIndex()] * generation; + } + + /** + * This method skips X elements from buffer + * + * @param numberOfElements number of elements to skip + */ #ifdef __CUDACC__ - __device__ - void rewind(Nd4jLong numberOfElements) { - if (gridDim.x > 1) { - __shared__ bool amLast; - - if (threadIdx.x == 0) { - unsigned int ticket = atomicInc(&synchronizer, gridDim.x); - amLast = (ticket == gridDim.x - 1); - } - __syncthreads(); - - if (amLast) { - if (threadIdx.x == 0) { - synchronizer = 0; - - Nd4jLong newPos = this->getOffset() + numberOfElements; - if (newPos > this->getSize()) { - generation += newPos / this->size; - newPos = newPos % this->size; - } else if (newPos == this->getSize()) { - newPos = 0; - generation++; - } - - this->setOffset(newPos); - } - } - } else { - if (threadIdx.x == 0) { - Nd4jLong newPos = this->getOffset() + numberOfElements; - if (newPos > this->getSize()) { - generation += newPos / this->size; - newPos = newPos % this->size; - } else if (newPos == this->getSize()) { - generation++; - newPos = 0; - } - - this->setOffset(newPos); - } - } - } + __device__ void rewind(Nd4jLong numberOfElements) { + if (gridDim.x > 1) { + __shared__ bool amLast; + + if (threadIdx.x == 0) { + unsigned int ticket = atomicInc(&synchronizer, gridDim.x); + amLast = (ticket == gridDim.x - 1); + } + __syncthreads(); + + if (amLast) { + if (threadIdx.x == 0) { + synchronizer = 0; + + Nd4jLong newPos = this->getOffset() + numberOfElements; + if (newPos > this->getSize()) { + generation += newPos / this->size; + newPos = newPos % this->size; + } else if (newPos == this->getSize()) { + newPos = 0; + generation++; + } + + this->setOffset(newPos); + } + } + } else { + if (threadIdx.x == 0) { + Nd4jLong newPos = this->getOffset() + numberOfElements; + if (newPos > this->getSize()) { + generation += newPos / this->size; + newPos = newPos % this->size; + } else if (newPos == this->getSize()) { + generation++; + newPos = 0; + } + + this->setOffset(newPos); + } + } + } #endif - void rewindH(Nd4jLong numberOfElements) { - Nd4jLong newPos = this->getOffset() + numberOfElements; - if (newPos > this->getSize()) { - generation += newPos / this->size; - newPos = newPos % this->size; - } - else if (newPos == this->getSize()) { - generation++; - newPos = 0; - } - - this->setOffset(newPos); - } - - /** - * This method returns random int in range [0..MAX_INT] - * @return - */ - int _CUDA_D nextInt() { - auto u = nextUInt64(); - return u <= sd::DataTypeUtils::max() ? static_cast(u) : static_cast(u % sd::DataTypeUtils::max()); - }; - - uint64_t _CUDA_D nextUInt64() { - return getNextElement(); - } - - /** - * This method returns random int in range [0..to] - * @param to - * @return - */ - int _CUDA_D nextInt(int to) { - int r = nextInt(); - int m = to - 1; - if ((to & m) == 0) // i.e., bound is a power of 2 - r = ((to * (Nd4jLong) r) >> 31); - else { - for (int u = r; - u - (r = u % to) + m < 0; - u = nextInt()); - } - return r; - }; - - /** - * This method returns random int in range [from..to] - * @param from - * @param to - * @return - */ - int _CUDA_D nextInt(int from, int to) { - if (from == 0) - return nextInt(to); - - return from + nextInt(to - from); - }; - - - /** - * This method returns random T in range of [0..1] - * @return - */ - template - _CUDA_D T nextT() { - auto u = static_cast(nextUInt64()); - auto m = static_cast(sd::DataTypeUtils::max()); - return static_cast(u / m); - } - - /** - * This method returns random T in range of [0..to] - * @param to - * @return - */ - template - _CUDA_D T nextT(T to) { - if (to == static_cast(1.0f)) - return nextT(); - - return nextT(static_cast(0.0f), to); - } - - /** - * This method returns random T in range [from..to] - * @param from - * @param to - * @return - */ - template - _CUDA_D T inline nextT(T from, T to) { - return from + (nextT() * (to - from)); - } - - inline _CUDA_D uint64_t relativeUInt64(Nd4jLong index) { - return getElement(index); - } - - /** - * relative methods are made as workaround for lock-free concurrent execution - */ - inline int _CUDA_D relativeInt(Nd4jLong index) { - auto u = relativeUInt64(index); - return u <= sd::DataTypeUtils::max() ? static_cast(u) : static_cast(u % sd::DataTypeUtils::max()); - } - - /** - * This method returns random int within [0..to] - * - * @param index - * @param to - * @return - */ - inline int _CUDA_D relativeInt(Nd4jLong index, int to) { - auto rel = relativeInt(index); - return rel % to; - } - - /** - * This method returns random int within [from..to] - * - * @param index - * @param to - * @param from - * @return - */ - inline _CUDA_D int relativeInt(Nd4jLong index, int from, int to) { - if (from == 0) - return relativeInt(index, to); - - return from + relativeInt(index, to - from); - } - - /** - * This method returns random T within [0..1] - * - * @param index - * @return - */ - template - inline _CUDA_D T relativeT(Nd4jLong index) { - /** - * Basically we just get float u/m value, and convert into to - * - * FIXME: once we add support for additional datatypes this code must be tweaked - */ - auto u = static_cast(relativeUInt64(index)); - auto m = static_cast (sd::DataTypeUtils::max()); - return static_cast(u / m); - } - -/** - * This method returns random T within [0..to] - * - * @param index - * @param to - * @return - */ - - template - _CUDA_D T relativeT(Nd4jLong index, T to) { - if (to == static_cast(1.0f)) - return relativeT(index); - - return relativeT(index, static_cast(0.0f), to); - } + void rewindH(Nd4jLong numberOfElements) { + Nd4jLong newPos = this->getOffset() + numberOfElements; + if (newPos > this->getSize()) { + generation += newPos / this->size; + newPos = newPos % this->size; + } else if (newPos == this->getSize()) { + generation++; + newPos = 0; + } -/** - * This method returns random T within [from..to] - * - * @param index - * @param from - * @param to - * @return - */ - template - _CUDA_D T relativeT(Nd4jLong index, T from, T to) { - return from + (relativeT(index) * (to - from)); - } - - }; - - class ND4J_EXPORT IGenerator { - protected: - Nd4jLong limit; - Nd4jLong seed; - uint64_t *buffer; - sd::random::RandomBuffer *realBuffer; - - public: - - _CUDA_HD IGenerator(sd::random::RandomBuffer *buffer) { - this->limit = buffer->getSize(); - this->buffer = reinterpret_cast(buffer->getBuffer()); - this->realBuffer = buffer; - this->seed = buffer->getSeed(); - } - - - _CUDA_HD RandomBuffer *getBuffer() { - return realBuffer; - } - - _CUDA_HD void setOffset(Nd4jLong offset) { - this->realBuffer->setOffset(offset); - } - - _CUDA_HD Nd4jLong getElementAbsolute(Nd4jLong position) { - return buffer[position]; - } - - _CUDA_HD Nd4jLong getElementRelative(Nd4jLong position) { - return buffer[realBuffer->getOffset() + position]; - } - - virtual _CUDA_HD void refreshBuffer() = 0; - }; - - - - class ND4J_EXPORT Xoroshiro128 : public IGenerator { - protected: - uint64_t state[2]; - - static inline _CUDA_HD uint64_t rotl(const uint64_t x, int k) { - return (x << k) | (x >> (64 - k)); - } - - /** - * This method returns 64 random bits - * @return - */ - uint64_t _CUDA_HD next64() { - const uint64_t s0 = state[0]; - uint64_t s1 = state[1]; - const uint64_t result = s0 + s1; - - s1 ^= s0; - state[0] = rotl(s0, 55) ^ s1 ^ (s1 << 14); // a, b - state[1] = rotl(s1, 36); // c - - return result; - } - - uint64_t _CUDA_HD seedConv(Nd4jLong seed) { - uint64_t x = static_cast(seed); - uint64_t z = (x += UINT64_C(0x9E3779B97F4A7C15)); - z = (z ^ (z >> 30)) * UINT64_C(0xBF58476D1CE4E5B9); - z = (z ^ (z >> 27)) * UINT64_C(0x94D049BB133111EB); - return z ^ (z >> 31); - } - - void _CUDA_H jump(void) { - static const uint64_t JUMP[] = { 0xbeac0467eba5facb, 0xd86b048b86aa9922 }; - - uint64_t s0 = 0; - uint64_t s1 = 0; - for(unsigned int i = 0; i < sizeof JUMP / sizeof *JUMP; i++) - for(int b = 0; b < 64; b++) { - if (JUMP[i] & 1ULL << b) { - s0 ^= state[0]; - s1 ^= state[1]; - } - next64(); - } - - state[0] = s0; - state[1] = s1; - } - - public: - _CUDA_HD Xoroshiro128(sd::random::RandomBuffer *buffer) : IGenerator(buffer) { - // - } - - _CUDA_HD void refreshBuffer() { - state[0] = seedConv(this->seed); - state[1] = seedConv(this->seed * 119 + 3); - - int fd = 3 + 3; - - for (Nd4jLong i = 0; i < limit; i++) { - buffer[i] = next64(); - } - } - }; + this->setOffset(newPos); + } + + /** + * This method returns random int in range [0..MAX_INT] + * @return + */ + int _CUDA_D nextInt() { + auto u = nextUInt64(); + return u <= sd::DataTypeUtils::max() + ? static_cast(u) + : static_cast(u % sd::DataTypeUtils::max()); + }; + + uint64_t _CUDA_D nextUInt64() { return getNextElement(); } + + /** + * This method returns random int in range [0..to] + * @param to + * @return + */ + int _CUDA_D nextInt(int to) { + int r = nextInt(); + int m = to - 1; + if ((to & m) == 0) // i.e., bound is a power of 2 + r = ((to * (Nd4jLong)r) >> 31); + else { + for (int u = r; u - (r = u % to) + m < 0; u = nextInt()) + ; + } + return r; + }; + + /** + * This method returns random int in range [from..to] + * @param from + * @param to + * @return + */ + int _CUDA_D nextInt(int from, int to) { + if (from == 0) return nextInt(to); + + return from + nextInt(to - from); + }; + + /** + * This method returns random T in range of [0..1] + * @return + */ + template + _CUDA_D T nextT() { + auto u = static_cast(nextUInt64()); + auto m = static_cast(sd::DataTypeUtils::max()); + return static_cast(u / m); + } + + /** + * This method returns random T in range of [0..to] + * @param to + * @return + */ + template + _CUDA_D T nextT(T to) { + if (to == static_cast(1.0f)) return nextT(); + + return nextT(static_cast(0.0f), to); + } + + /** + * This method returns random T in range [from..to] + * @param from + * @param to + * @return + */ + template + _CUDA_D T inline nextT(T from, T to) { + return from + (nextT() * (to - from)); + } + + inline _CUDA_D uint64_t relativeUInt64(Nd4jLong index) { + return getElement(index); + } + + /** + * relative methods are made as workaround for lock-free concurrent execution + */ + inline int _CUDA_D relativeInt(Nd4jLong index) { + auto u = relativeUInt64(index); + return u <= sd::DataTypeUtils::max() + ? static_cast(u) + : static_cast(u % sd::DataTypeUtils::max()); + } + + /** + * This method returns random int within [0..to] + * + * @param index + * @param to + * @return + */ + inline int _CUDA_D relativeInt(Nd4jLong index, int to) { + auto rel = relativeInt(index); + return rel % to; + } + + /** + * This method returns random int within [from..to] + * + * @param index + * @param to + * @param from + * @return + */ + inline _CUDA_D int relativeInt(Nd4jLong index, int from, int to) { + if (from == 0) return relativeInt(index, to); + + return from + relativeInt(index, to - from); + } + + /** + * This method returns random T within [0..1] + * + * @param index + * @return + */ + template + inline _CUDA_D T relativeT(Nd4jLong index) { + /** + * Basically we just get float u/m value, and convert into to + * + * FIXME: once we add support for additional datatypes this code must be + * tweaked + */ + auto u = static_cast(relativeUInt64(index)); + auto m = static_cast(sd::DataTypeUtils::max()); + return static_cast(u / m); + } + + /** + * This method returns random T within [0..to] + * + * @param index + * @param to + * @return + */ + + template + _CUDA_D T relativeT(Nd4jLong index, T to) { + if (to == static_cast(1.0f)) return relativeT(index); + + return relativeT(index, static_cast(0.0f), to); + } + + /** + * This method returns random T within [from..to] + * + * @param index + * @param from + * @param to + * @return + */ + template + _CUDA_D T relativeT(Nd4jLong index, T from, T to) { + return from + (relativeT(index) * (to - from)); + } +}; + +class SD_EXPORT IGenerator { + protected: + Nd4jLong limit; + Nd4jLong seed; + uint64_t *buffer; + sd::random::RandomBuffer *realBuffer; + + public: + _CUDA_HD IGenerator(sd::random::RandomBuffer *buffer) { + this->limit = buffer->getSize(); + this->buffer = reinterpret_cast(buffer->getBuffer()); + this->realBuffer = buffer; + this->seed = buffer->getSeed(); + } + + _CUDA_HD RandomBuffer *getBuffer() { return realBuffer; } + + _CUDA_HD void setOffset(Nd4jLong offset) { + this->realBuffer->setOffset(offset); + } + + _CUDA_HD Nd4jLong getElementAbsolute(Nd4jLong position) { + return buffer[position]; + } + + _CUDA_HD Nd4jLong getElementRelative(Nd4jLong position) { + return buffer[realBuffer->getOffset() + position]; + } + + virtual _CUDA_HD void refreshBuffer() = 0; +}; + +class SD_EXPORT Xoroshiro128 : public IGenerator { + protected: + uint64_t state[2]; + + static inline _CUDA_HD uint64_t rotl(const uint64_t x, int k) { + return (x << k) | (x >> (64 - k)); + } + + /** + * This method returns 64 random bits + * @return + */ + uint64_t _CUDA_HD next64() { + const uint64_t s0 = state[0]; + uint64_t s1 = state[1]; + const uint64_t result = s0 + s1; + + s1 ^= s0; + state[0] = rotl(s0, 55) ^ s1 ^ (s1 << 14); // a, b + state[1] = rotl(s1, 36); // c + + return result; + } + + uint64_t _CUDA_HD seedConv(Nd4jLong seed) { + uint64_t x = static_cast(seed); + uint64_t z = (x += UINT64_C(0x9E3779B97F4A7C15)); + z = (z ^ (z >> 30)) * UINT64_C(0xBF58476D1CE4E5B9); + z = (z ^ (z >> 27)) * UINT64_C(0x94D049BB133111EB); + return z ^ (z >> 31); + } + + void _CUDA_H jump(void) { + static const uint64_t JUMP[] = {0xbeac0467eba5facb, 0xd86b048b86aa9922}; + + uint64_t s0 = 0; + uint64_t s1 = 0; + for (unsigned int i = 0; i < sizeof JUMP / sizeof *JUMP; i++) + for (int b = 0; b < 64; b++) { + if (JUMP[i] & 1ULL << b) { + s0 ^= state[0]; + s1 ^= state[1]; + } + next64(); + } + + state[0] = s0; + state[1] = s1; + } + + public: + _CUDA_HD Xoroshiro128(sd::random::RandomBuffer *buffer) : IGenerator(buffer) { + // + } + + _CUDA_HD void refreshBuffer() { + state[0] = seedConv(this->seed); + state[1] = seedConv(this->seed * 119 + 3); + + int fd = 3 + 3; + + for (Nd4jLong i = 0; i < limit; i++) { + buffer[i] = next64(); } -} -#endif //LIBND4J_HELPER_GENERATOR_H + } +}; +} // namespace random +} // namespace sd +#endif // LIBND4J_HELPER_GENERATOR_H diff --git a/libnd4j/include/helpers/helper_hash.h b/libnd4j/include/helpers/helper_hash.h index fa44b04b755d..456f622302a4 100644 --- a/libnd4j/include/helpers/helper_hash.h +++ b/libnd4j/include/helpers/helper_hash.h @@ -15,34 +15,39 @@ ******************************************************************************/ // -// Stronger 64-bit hash function helper, as described here: http://www.javamex.com/tutorials/collections/strong_hash_code_implementation.shtml +// Stronger 64-bit hash function helper, as described here: +// http://www.javamex.com/tutorials/collections/strong_hash_code_implementation.shtml // @author raver119@gmail.com // -#ifndef LIBND4J_HELPER_HASH_H -#define LIBND4J_HELPER_HASH_H +#ifndef SD_HELPER_HASH_H +#define SD_HELPER_HASH_H -#include #include #include + #include +#include namespace sd { - namespace ops { - class ND4J_EXPORT HashHelper { - private: - Nd4jLong _byteTable[256]; - const Nd4jLong HSTART = 0xBB40E64DA205B064L; - const Nd4jLong HMULT = 7664345821815920749L; - - bool _isInit = false; - std::mutex _locker; - - public: - static HashHelper& getInstance(); - Nd4jLong getLongHash(std::string& str); - }; - } -} - -#endif //LIBND4J_HELPER_HASH_H +namespace ops { +class SD_EXPORT HashHelper { + private: + + + Nd4jLong _byteTable[256]; + const Nd4jLong HSTART = 0xBB40E64DA205B064L; + const Nd4jLong HMULT = 7664345821815920749L; + + bool _isInit = false; + std::mutex _locker; + + public: + static HashHelper& getInstance(); + Nd4jLong getLongHash(const std::string& str); +}; + +} // namespace ops +} // namespace sd + +#endif // SD_HELPER_HASH_H diff --git a/libnd4j/include/helpers/helper_ptrmap.h b/libnd4j/include/helpers/helper_ptrmap.h index 4f2ec128c9ec..bc115af776b9 100644 --- a/libnd4j/include/helpers/helper_ptrmap.h +++ b/libnd4j/include/helpers/helper_ptrmap.h @@ -29,191 +29,199 @@ namespace sd { - /** - * This class is a simple wrapper to represent batch arguments as single surface of parameters. - * So we pass batch parameters as single surface, and then we use this helper to extract arguments for each aggregates. - * - * Surface map format is simple: - * [0] we put numbers for num*Arguments - * [1] then we put indexing arguments, since their size is constant - * [2] here we put block of JVM IntArrays by value, batchLimit * maxIntArrays * maxArraySize; - * [3] then we put real arguments - * [4] then we put arguments pointers - * [5] then we put shape pointers - * - */ - template - class PointersHelper { - private: - int aggregates; - void *ptrGeneral; - - // we enforce maximal batch size limit, to simplify +/** + * This class is a simple wrapper to represent batch arguments as single surface + * of parameters. So we pass batch parameters as single surface, and then we use + * this helper to extract arguments for each aggregates. + * + * Surface map format is simple: + * [0] we put numbers for num*Arguments + * [1] then we put indexing arguments, since their size is constant + * [2] here we put block of JVM IntArrays by value, batchLimit * maxIntArrays * + * maxArraySize; [3] then we put real arguments [4] then we put arguments + * pointers [5] then we put shape pointers + * + */ +template +class PointersHelper { + private: + int aggregates; + void *ptrGeneral; + + // we enforce maximal batch size limit, to simplify #ifdef __CUDACC__ - const int batchLimit = 8192; + const int batchLimit = 8192; #else - const int batchLimit = 512; + const int batchLimit = 512; #endif - // we have 5 diff kinds of arguments: arguments, shapeArguments, intArrayArguments, indexArguments, realArguments - const int argTypes = 5; - - int maxIntArrays; - int maxArraySize; - - // right now we hardcode maximas, but we'll probably change that later - int maxIndexArguments; - int maxRealArguments; - - // since that's pointers (which is 64-bit on 64bit systems), we limit number of maximum arguments to 1/2 of maxIndex arguments - int maxArguments; - int maxShapeArguments; - - int sizeT; - int sizePtr; - public: - - /** - * We accept single memory chunk and number of jobs stored. - * - * @param ptrToParams pointer to "surface" - * @param numAggregates actual number of aggregates being passed in - * @return - */ + // we have 5 diff kinds of arguments: arguments, shapeArguments, + // intArrayArguments, indexArguments, realArguments + const int argTypes = 5; + + int maxIntArrays; + int maxArraySize; + + // right now we hardcode maximas, but we'll probably change that later + int maxIndexArguments; + int maxRealArguments; + + // since that's pointers (which is 64-bit on 64bit systems), we limit number + // of maximum arguments to 1/2 of maxIndex arguments + int maxArguments; + int maxShapeArguments; + + int sizeT; + int sizePtr; + + public: + /** + * We accept single memory chunk and number of jobs stored. + * + * @param ptrToParams pointer to "surface" + * @param numAggregates actual number of aggregates being passed in + * @return + */ #ifdef __CUDACC__ - __host__ __device__ + __host__ __device__ #endif - PointersHelper(void *ptrToParams, int numAggregates, int maxArgs, int maxShapes, int maxIntArrays, int maxIntArraySize, int maxIdx, int maxReals) { - aggregates = numAggregates; - ptrGeneral = ptrToParams; - - // ptrSize for hypothetical 32-bit compatibility - sizePtr = sizeof(ptrToParams); - - // unfortunately we have to know sizeOf(T) - sizeT = sizeof(T); - - this->maxIntArrays = maxIntArrays; - this->maxArraySize = maxIntArraySize; - this->maxIndexArguments = maxIdx; - this->maxArguments = maxArgs; - this->maxShapeArguments = maxShapes; - this->maxRealArguments = maxReals; - } - - /** - * This method returns point - * - * @param aggregateIdx - * @return - */ - - ptr_def T **getArguments(int aggregateIdx) { - T **aPtr = (T **) getRealArguments(batchLimit); - - return aPtr + (aggregateIdx * maxArguments); - } - - /** - * This method returns number of array arguments for specified aggregate - * - * @param aggregateIdx - * @return - */ - ptr_def int getNumArguments(int aggregateIdx) { - int *tPtr = (int *) ptrGeneral; - return tPtr[aggregateIdx * argTypes]; - } - - /** - * This method returns set of pointers to shape aruments for specified aggregates - * - * @param aggregateIdx - * @return - */ - ptr_def Nd4jLong **getShapeArguments(int aggregateIdx) { - Nd4jLong **sPtr = (Nd4jLong **)getArguments(batchLimit); - - return sPtr + (aggregateIdx * maxShapeArguments); - } - - /** - * This methor returns number of shape arguments for specified aggregate - * - * @param aggregateIdx - * @return - */ - ptr_def int getNumShapeArguments(int aggregateIdx) { - int *tPtr = (int *) ptrGeneral; - return tPtr[aggregateIdx * argTypes + 1]; - } - - /** - * This method returns pointer to array of int/index arguments for specified aggregate - * - * @param aggregateIdx - * @return - */ - ptr_def int *getIndexArguments(int aggregateIdx) { - // we skip first numeric num*arguments - int *ptr = ((int *) ptrGeneral) + (batchLimit * argTypes); - - // and return index for requested aggregate - return ptr + (aggregateIdx * maxIndexArguments) ; - } - - /** - * This method returns number of int/index arguments for specified aggregate - * - * @param aggregateIdx - * @return - */ - ptr_def int getNumIndexArguments(int aggregateIdx) { - int *tPtr = (int *) ptrGeneral; - return tPtr[aggregateIdx * argTypes + 2]; - } - - /** - * This method returns pointer to array of jvm IntArray arguments - */ - ptr_def int *getIntArrayArguments(int aggregateIdx, int argumentIdx) { - int *ptr = (int * )getIndexArguments(batchLimit); - - return ptr + (aggregateIdx * maxIntArrays * maxArraySize) + (argumentIdx * maxArraySize); - } - - /** - * This method returns number of jvm IntArray arguments - */ - ptr_def int getNumIntArrayArguments(int aggregateIdx) { - int *tPtr = (int *) ptrGeneral; - return tPtr[aggregateIdx * argTypes + 4]; - } - - /** - * This method returns real arguments for specific aggregate - * - * @param aggregateIdx - * @return - */ - ptr_def T *getRealArguments(int aggregateIdx) { - // we get pointer for last batchElement + 1, so that'll be pointer for 0 realArgument - T *ptr = (T * ) getIntArrayArguments(batchLimit, 0); - - return ptr + (aggregateIdx * maxRealArguments); - } - - /** - * This methor returns number of real arguments for specified aggregate - * - * @param aggregateIdx - * @return - */ - ptr_def int getNumRealArguments(int aggregateIdx) { - int *tPtr = (int *) ptrGeneral; - return tPtr[aggregateIdx * argTypes + 3]; - } - }; -} - -#endif //LIBND4J_HELPER_PTRMAP_H + PointersHelper(void *ptrToParams, int numAggregates, int maxArgs, + int maxShapes, int maxIntArrays, int maxIntArraySize, + int maxIdx, int maxReals) { + aggregates = numAggregates; + ptrGeneral = ptrToParams; + + // ptrSize for hypothetical 32-bit compatibility + sizePtr = sizeof(ptrToParams); + + // unfortunately we have to know sizeOf(T) + sizeT = sizeof(T); + + this->maxIntArrays = maxIntArrays; + this->maxArraySize = maxIntArraySize; + this->maxIndexArguments = maxIdx; + this->maxArguments = maxArgs; + this->maxShapeArguments = maxShapes; + this->maxRealArguments = maxReals; + } + + /** + * This method returns point + * + * @param aggregateIdx + * @return + */ + + ptr_def T **getArguments(int aggregateIdx) { + T **aPtr = (T **)getRealArguments(batchLimit); + + return aPtr + (aggregateIdx * maxArguments); + } + + /** + * This method returns number of array arguments for specified aggregate + * + * @param aggregateIdx + * @return + */ + ptr_def int getNumArguments(int aggregateIdx) { + int *tPtr = (int *)ptrGeneral; + return tPtr[aggregateIdx * argTypes]; + } + + /** + * This method returns set of pointers to shape aruments for specified + * aggregates + * + * @param aggregateIdx + * @return + */ + ptr_def Nd4jLong **getShapeArguments(int aggregateIdx) { + Nd4jLong **sPtr = (Nd4jLong **)getArguments(batchLimit); + + return sPtr + (aggregateIdx * maxShapeArguments); + } + + /** + * This methor returns number of shape arguments for specified aggregate + * + * @param aggregateIdx + * @return + */ + ptr_def int getNumShapeArguments(int aggregateIdx) { + int *tPtr = (int *)ptrGeneral; + return tPtr[aggregateIdx * argTypes + 1]; + } + + /** + * This method returns pointer to array of int/index arguments for specified + * aggregate + * + * @param aggregateIdx + * @return + */ + ptr_def int *getIndexArguments(int aggregateIdx) { + // we skip first numeric num*arguments + int *ptr = ((int *)ptrGeneral) + (batchLimit * argTypes); + + // and return index for requested aggregate + return ptr + (aggregateIdx * maxIndexArguments); + } + + /** + * This method returns number of int/index arguments for specified aggregate + * + * @param aggregateIdx + * @return + */ + ptr_def int getNumIndexArguments(int aggregateIdx) { + int *tPtr = (int *)ptrGeneral; + return tPtr[aggregateIdx * argTypes + 2]; + } + + /** + * This method returns pointer to array of jvm IntArray arguments + */ + ptr_def int *getIntArrayArguments(int aggregateIdx, int argumentIdx) { + int *ptr = (int *)getIndexArguments(batchLimit); + + return ptr + (aggregateIdx * maxIntArrays * maxArraySize) + + (argumentIdx * maxArraySize); + } + + /** + * This method returns number of jvm IntArray arguments + */ + ptr_def int getNumIntArrayArguments(int aggregateIdx) { + int *tPtr = (int *)ptrGeneral; + return tPtr[aggregateIdx * argTypes + 4]; + } + + /** + * This method returns real arguments for specific aggregate + * + * @param aggregateIdx + * @return + */ + ptr_def T *getRealArguments(int aggregateIdx) { + // we get pointer for last batchElement + 1, so that'll be pointer for 0 + // realArgument + T *ptr = (T *)getIntArrayArguments(batchLimit, 0); + + return ptr + (aggregateIdx * maxRealArguments); + } + + /** + * This methor returns number of real arguments for specified aggregate + * + * @param aggregateIdx + * @return + */ + ptr_def int getNumRealArguments(int aggregateIdx) { + int *tPtr = (int *)ptrGeneral; + return tPtr[aggregateIdx * argTypes + 3]; + } +}; +} // namespace sd + +#endif // LIBND4J_HELPER_PTRMAP_H diff --git a/libnd4j/include/helpers/helper_random.h b/libnd4j/include/helpers/helper_random.h index 6f2523e05fa6..829d33209f5a 100644 --- a/libnd4j/include/helpers/helper_random.h +++ b/libnd4j/include/helpers/helper_random.h @@ -33,204 +33,194 @@ #endif - namespace sd { - namespace random { - - template - class RandomHelper { - private: - sd::random::IGenerator *generator; - sd::random::RandomBuffer *buffer; - - - public: - - _CUDA_HD RandomHelper(sd::random::IGenerator *generator) { - this->generator = generator; - this->buffer = generator->getBuffer(); - } - - _CUDA_HD RandomHelper(sd::random::RandomBuffer *buffer) { - this->buffer = buffer; - } - - - /** - * This method returns random int in range [0..MAX_INT] - * @return - */ - inline _CUDA_D int nextInt() { - int r = (int) nextUInt(); - return r < 0 ? -1 * r : r; - }; - - inline _CUDA_D uint64_t nextUInt() { - return buffer->getNextElement(); - } - - /** - * This method returns random int in range [0..to] - * @param to - * @return - */ - inline _CUDA_D int nextInt(int to) { - int r = nextInt(); - int m = to - 1; - if ((to & m) == 0) // i.e., bound is a power of 2 - r = (int) ((to * (long) r) >> 31); - else { - for (int u = r; - u - (r = u % to) + m < 0; - u = nextInt()); - } - return r; - }; - - /** - * This method returns random int in range [from..to] - * @param from - * @param to - * @return - */ - inline _CUDA_D int nextInt(int from, int to) { - if (from == 0) - return nextInt(to); - - return from + nextInt(to - from); - }; - - - /** - * This method returns random T in range of [0..MAX_FLOAT] - * @return - */ - inline _CUDA_D T nextMaxT() { - T rnd = (T) buffer->getNextElement(); - return rnd < 0 ? -1 * rnd : rnd; - }; - - - /** - * This method returns random T in range of [0..1] - * @return - */ - inline _CUDA_D T nextT() { - return (T) nextUInt() / (T) sd::DataTypeUtils::max(); - } - - /** - * This method returns random T in range of [0..to] - * @param to - * @return - */ - inline _CUDA_D T nextT(T to) { - if (to == (T) 1.0f) - return nextT(); - - return nextT((T) 0.0f, to); - }; - - /** - * This method returns random T in range [from..to] - * @param from - * @param to - * @return - */ - inline _CUDA_D T nextT(T from, T to) { - return from + (nextT() * (to - from)); - } - - inline _CUDA_D uint64_t relativeUInt(Nd4jLong index) { - return buffer->getElement(index); - } - - /** - * relative methods are made as workaround for lock-free concurrent execution - */ - inline _CUDA_D int relativeInt(Nd4jLong index) { - return (int) (relativeUInt(index) % (sd::DataTypeUtils::max() + 1)); - } - - /** - * This method returns random int within [0..to] - * - * @param index - * @param to - * @return - */ - inline _CUDA_D int relativeInt(Nd4jLong index, int to) { - int rel = relativeInt(index); - return rel % to; - } - - /** - * This method returns random int within [from..to] - * - * @param index - * @param to - * @param from - * @return - */ - inline int _CUDA_D relativeInt(Nd4jLong index, int to, int from) { - if (from == 0) - return relativeInt(index, to); - - return from + relativeInt(index, to - from); - } - - /** - * This method returns random T within [0..1] - * - * @param index - * @return - */ - - inline _CUDA_D T relativeT(Nd4jLong index) { - if (sizeof(T) < 4) { - // FIXME: this is fast hack for short types, like fp16. This should be improved. - return (T)((float) relativeUInt(index) / (float) sd::DataTypeUtils::max()); - } else return (T) relativeUInt(index) / (T) sd::DataTypeUtils::max(); - } - - /** - * This method returns random T within [0..to] - * - * @param index - * @param to - * @return - */ - inline _CUDA_D T relativeT(Nd4jLong index, T to) { - if (to == (T) 1.0f) - return relativeT(index); - - return relativeT(index, (T) 0.0f, to); - } - - /** - * This method returns random T within [from..to] - * - * @param index - * @param from - * @param to - * @return - */ - inline _CUDA_D T relativeT(Nd4jLong index, T from, T to) { - return from + (relativeT(index) * (to - from)); - } - - - /** - * This method skips X elements from buffer - * - * @param numberOfElements number of elements to skip - */ - inline _CUDA_D void rewind(Nd4jLong numberOfElements) { - buffer->rewindH(numberOfElements); - } - }; +namespace random { + +template +class RandomHelper { + private: + sd::random::IGenerator *generator; + sd::random::RandomBuffer *buffer; + + public: + _CUDA_HD RandomHelper(sd::random::IGenerator *generator) { + this->generator = generator; + this->buffer = generator->getBuffer(); + } + + _CUDA_HD RandomHelper(sd::random::RandomBuffer *buffer) { + this->buffer = buffer; + } + + /** + * This method returns random int in range [0..MAX_INT] + * @return + */ + inline _CUDA_D int nextInt() { + int r = (int)nextUInt(); + return r < 0 ? -1 * r : r; + }; + + inline _CUDA_D uint64_t nextUInt() { return buffer->getNextElement(); } + + /** + * This method returns random int in range [0..to] + * @param to + * @return + */ + inline _CUDA_D int nextInt(int to) { + int r = nextInt(); + int m = to - 1; + if ((to & m) == 0) // i.e., bound is a power of 2 + r = (int)((to * (long)r) >> 31); + else { + for (int u = r; u - (r = u % to) + m < 0; u = nextInt()) + ; } -} - -#endif //LIBND4J_HELPER_RANDOM_H + return r; + }; + + /** + * This method returns random int in range [from..to] + * @param from + * @param to + * @return + */ + inline _CUDA_D int nextInt(int from, int to) { + if (from == 0) return nextInt(to); + + return from + nextInt(to - from); + }; + + /** + * This method returns random T in range of [0..MAX_FLOAT] + * @return + */ + inline _CUDA_D T nextMaxT() { + T rnd = (T)buffer->getNextElement(); + return rnd < 0 ? -1 * rnd : rnd; + }; + + /** + * This method returns random T in range of [0..1] + * @return + */ + inline _CUDA_D T nextT() { + return (T)nextUInt() / (T)sd::DataTypeUtils::max(); + } + + /** + * This method returns random T in range of [0..to] + * @param to + * @return + */ + inline _CUDA_D T nextT(T to) { + if (to == (T)1.0f) return nextT(); + + return nextT((T)0.0f, to); + }; + + /** + * This method returns random T in range [from..to] + * @param from + * @param to + * @return + */ + inline _CUDA_D T nextT(T from, T to) { + return from + (nextT() * (to - from)); + } + + inline _CUDA_D uint64_t relativeUInt(Nd4jLong index) { + return buffer->getElement(index); + } + + /** + * relative methods are made as workaround for lock-free concurrent execution + */ + inline _CUDA_D int relativeInt(Nd4jLong index) { + return (int)(relativeUInt(index) % + (sd::DataTypeUtils::max() + 1)); + } + + /** + * This method returns random int within [0..to] + * + * @param index + * @param to + * @return + */ + inline _CUDA_D int relativeInt(Nd4jLong index, int to) { + int rel = relativeInt(index); + return rel % to; + } + + /** + * This method returns random int within [from..to] + * + * @param index + * @param to + * @param from + * @return + */ + inline int _CUDA_D relativeInt(Nd4jLong index, int to, int from) { + if (from == 0) return relativeInt(index, to); + + return from + relativeInt(index, to - from); + } + + /** + * This method returns random T within [0..1] + * + * @param index + * @return + */ + + inline _CUDA_D T relativeT(Nd4jLong index) { + if (sizeof(T) < 4) { + // FIXME: this is fast hack for short types, like fp16. This should be + // improved. + return (T)((float)relativeUInt(index) / + (float)sd::DataTypeUtils::max()); + } else + return (T)relativeUInt(index) / (T)sd::DataTypeUtils::max(); + } + + /** + * This method returns random T within [0..to] + * + * @param index + * @param to + * @return + */ + inline _CUDA_D T relativeT(Nd4jLong index, T to) { + if (to == (T)1.0f) return relativeT(index); + + return relativeT(index, (T)0.0f, to); + } + + /** + * This method returns random T within [from..to] + * + * @param index + * @param from + * @param to + * @return + */ + inline _CUDA_D T relativeT(Nd4jLong index, T from, T to) { + return from + (relativeT(index) * (to - from)); + } + + /** + * This method skips X elements from buffer + * + * @param numberOfElements number of elements to skip + */ + inline _CUDA_D void rewind(Nd4jLong numberOfElements) { + buffer->rewindH(numberOfElements); + } +}; +} // namespace random +} // namespace sd + +#endif // LIBND4J_HELPER_RANDOM_H diff --git a/libnd4j/include/helpers/hhColPivQR.h b/libnd4j/include/helpers/hhColPivQR.h index 28dd42f64a0b..7cbff757270a 100644 --- a/libnd4j/include/helpers/hhColPivQR.h +++ b/libnd4j/include/helpers/hhColPivQR.h @@ -21,36 +21,31 @@ #ifndef LIBND4J_HHCOLPICQR_H #define LIBND4J_HHCOLPICQR_H -#include #include +#include namespace sd { namespace ops { namespace helpers { class HHcolPivQR { + public: + NDArray _qr; + NDArray _coeffs; + NDArray _permut; + int _diagSize; - public: - - NDArray _qr; - NDArray _coeffs; - NDArray _permut; - int _diagSize; - - HHcolPivQR() = delete; - HHcolPivQR(const NDArray& matrix); + HHcolPivQR() = delete; + HHcolPivQR(const NDArray& matrix); - template - void _evalData(); + template + void _evalData(); - void evalData(); + void evalData(); }; +} // namespace helpers +} // namespace ops +} // namespace sd - -} -} -} - - -#endif //LIBND4J_HHCOLPICQR_H +#endif // LIBND4J_HHCOLPICQR_H diff --git a/libnd4j/include/helpers/hhSequence.h b/libnd4j/include/helpers/hhSequence.h index 1e1f8ecadaa8..07bf7281fcd8 100644 --- a/libnd4j/include/helpers/hhSequence.h +++ b/libnd4j/include/helpers/hhSequence.h @@ -27,74 +27,67 @@ namespace sd { namespace ops { namespace helpers { - class HHsequence { - - public: - - /* - * matrix containing the Householder vectors - */ - const NDArray& _vectors; - - /* - * vector containing the Householder coefficients - */ - const NDArray& _coeffs; - - /* - * shift of the Householder sequence - */ - int _shift; - - /* - * length of the Householder sequence - */ - int _diagSize; - - /* - * type of sequence, type = 'u' (acting on columns, left) or type = 'v' (acting on rows, right) - */ - char _type; - - /* - * constructor - */ - HHsequence(const NDArray& vectors, const NDArray& coeffs, const char type); - - /** - * this method mathematically multiplies input matrix on Householder sequence from the left H0*H1*...Hn * matrix - * - * matrix - input matrix to be multiplied - */ - template - void mulLeft_(NDArray& matrix); - - void mulLeft(NDArray& matrix); - - NDArray getTail(const int idx) const; - - template - void applyTo_(NDArray& dest); - - void applyTo(NDArray& dest); - - FORCEINLINE int rows() const; - + public: + /* + * matrix containing the Householder vectors + */ + const NDArray& _vectors; + + /* + * vector containing the Householder coefficients + */ + const NDArray& _coeffs; + + /* + * shift of the Householder sequence + */ + int _shift; + + /* + * length of the Householder sequence + */ + int _diagSize; + + /* + * type of sequence, type = 'u' (acting on columns, left) or type = 'v' + * (acting on rows, right) + */ + char _type; + + /* + * constructor + */ + HHsequence(const NDArray& vectors, const NDArray& coeffs, const char type); + + /** + * this method mathematically multiplies input matrix on Householder sequence + * from the left H0*H1*...Hn * matrix + * + * matrix - input matrix to be multiplied + */ + template + void mulLeft_(NDArray& matrix); + + void mulLeft(NDArray& matrix); + + NDArray getTail(const int idx) const; + + template + void applyTo_(NDArray& dest); + + void applyTo(NDArray& dest); + + FORCEINLINE int rows() const; }; - ////////////////////////////////////////////////////////////////////////// FORCEINLINE int HHsequence::rows() const { - - return _type == 'u' ? _vectors.sizeAt(0) : _vectors.sizeAt(1); -} - - - -} -} + return _type == 'u' ? _vectors.sizeAt(0) : _vectors.sizeAt(1); } +} // namespace helpers +} // namespace ops +} // namespace sd -#endif //LIBND4J_HHSEQUENCE_H +#endif // LIBND4J_HHSEQUENCE_H diff --git a/libnd4j/include/helpers/householder.h b/libnd4j/include/helpers/householder.h index 7811fafa008e..072fef9b4de2 100644 --- a/libnd4j/include/helpers/householder.h +++ b/libnd4j/include/helpers/householder.h @@ -21,8 +21,7 @@ #ifndef LIBND4J_HOUSEHOLDER_H #define LIBND4J_HOUSEHOLDER_H - -#include "array/NDArray.h" +#include namespace sd { namespace ops { @@ -30,92 +29,89 @@ namespace helpers { template class Householder { - - public: - - /** - * this method calculates Householder matrix P = identity_matrix - coeff * w * w^T - * P * x = [normX, 0, 0 , 0, ...] - * coeff - scalar - * w = [1, w1, w2, w3, ...] - * w = u / u0 - * u = x - |x|*e0 - * u0 = x0 - |x| - * e0 = [1, 0, 0 , 0, ...] - * - * x - input vector, remains unaffected - */ - // static NDArray evalHHmatrix(const NDArray& x); - - /** - * this method evaluates data required for calculation of Householder matrix P = identity_matrix - coeff * w * w^T - * P * x = [normX, 0, 0 , 0, ...] - * coeff - scalar - * w = [1, w1, w2, w3, ...] - * w = u / u0 - * u = x - |x|*e0 - * u0 = x0 - |x| - * e0 = [1, 0, 0 , 0, ...] - * - * x - input vector, remains unaffected - * tail - the essential part of the vector w: [w1, w2, w3, ...] - * normX - this scalar is the first non-zero element in vector resulting from Householder transformation -> (P*x) - * coeff - scalar, scaling factor in Householder matrix formula - */ - static void evalHHmatrixData(const NDArray& x, NDArray& tail, T& coeff, T& normX); - - static void evalHHmatrixDataI(NDArray& x, T& coeff, T& normX); // in-place, x to be affected - - /** - * this method mathematically multiplies input matrix on Householder from the left P * matrix - * - * matrix - input matrix - * tail - the essential part of the Householder vector w: [w1, w2, w3, ...] - * coeff - scalar, scaling factor in Householder matrix formula - */ - static void mulLeft(NDArray& matrix, const NDArray& tail, const T coeff); - - /** - * this method mathematically multiplies input matrix on Householder from the right matrix * P - * - * matrix - input matrix - * tail - the essential part of the Householder vector w: [w1, w2, w3, ...] - * coeff - scalar, scaling factor in Householder matrix formula - */ - static void mulRight(NDArray& matrix, const NDArray& tail, const T coeff); - - - + public: + /** + * this method calculates Householder matrix P = identity_matrix - coeff * w + * * w^T P * x = [normX, 0, 0 , 0, ...] coeff - scalar w = [1, w1, w2, w3, + * ...] w = u / u0 u = x - |x|*e0 u0 = x0 - |x| e0 = [1, 0, 0 , 0, ...] + * + * x - input vector, remains unaffected + */ + // static NDArray evalHHmatrix(const NDArray& x); + + /** + * this method evaluates data required for calculation of Householder matrix + * P = identity_matrix - coeff * w * w^T P * x = [normX, 0, 0 , 0, ...] coeff + * - scalar w = [1, w1, w2, w3, ...] w = u / u0 u = x - |x|*e0 u0 = x0 - |x| + * e0 = [1, 0, 0 , 0, ...] + * + * x - input vector, remains unaffected + * tail - the essential part of the vector w: [w1, w2, w3, ...] + * normX - this scalar is the first non-zero element in vector resulting from + * Householder transformation -> (P*x) coeff - scalar, scaling factor in + * Householder matrix formula + */ + static void evalHHmatrixData(const NDArray& x, NDArray& tail, T& coeff, + T& normX); + + static void evalHHmatrixDataI(NDArray& x, T& coeff, T& normX); // in-place, x to be affected + + /** + * this method mathematically multiplies input matrix on Householder from the + * left P * matrix + * + * matrix - input matrix + * tail - the essential part of the Householder vector w: [w1, w2, w3, ...] + * coeff - scalar, scaling factor in Householder matrix formula + */ + static void mulLeft(NDArray& matrix, const NDArray& tail, const T coeff); + + /** + * this method mathematically multiplies input matrix on Householder from the + * right matrix * P + * + * matrix - input matrix + * tail - the essential part of the Householder vector w: [w1, w2, w3, ...] + * coeff - scalar, scaling factor in Householder matrix formula + */ + static void mulRight(NDArray& matrix, const NDArray& tail, const T coeff); }; - - // /** - // * this function reduce given matrix to upper bidiagonal form (in-place operation), matrix must satisfy following condition rows >= cols - // * - // * matrix - input 2D matrix to be reduced to upper bidiagonal from - // */ - // template - // void biDiagonalizeUp(NDArray& matrix); - - // /** - // * given a matrix [m,n], this function computes its singular value decomposition matrix = u * s * v^T - // * - // * matrix - input 2D matrix to decompose, [m, n] - // * u - unitary matrix containing left singular vectors of input matrix, [m, m] - // * s - diagonal matrix with singular values of input matrix (non-negative) on the diagonal sorted in decreasing order, - // * actually the mathematically correct dimension of s is [m, n], however for memory saving we work with s as vector [1, p], where p is smaller among m and n - // * v - unitary matrix containing right singular vectors of input matrix, [n, n] - // * calcUV - if true then u and v will be computed, in opposite case function works significantly faster - // * fullUV - if false then only p (p is smaller among m and n) first columns of u and v will be calculated and their dimensions in this case are [m, p] and [n, p] - // * - // */ - // void svd(const NDArray& matrix, NDArray& u, NDArray& s, NDArray& v, const bool calcUV = false, const bool fullUV = false) - - - -} -} -} - - -#endif //LIBND4J_HOUSEHOLDER_H +// /** +// * this function reduce given matrix to upper bidiagonal form (in-place +// operation), matrix must satisfy following condition rows >= cols +// * +// * matrix - input 2D matrix to be reduced to upper bidiagonal from +// */ +// template +// void biDiagonalizeUp(NDArray& matrix); + +// /** +// * given a matrix [m,n], this function computes its singular value +// decomposition matrix = u * s * v^T +// * +// * matrix - input 2D matrix to decompose, [m, n] +// * u - unitary matrix containing left singular vectors of input matrix, [m, +// m] +// * s - diagonal matrix with singular values of input matrix (non-negative) on +// the diagonal sorted in decreasing order, +// * actually the mathematically correct dimension of s is [m, n], however +// for memory saving we work with s as vector [1, p], where p is smaller among m +// and n +// * v - unitary matrix containing right singular vectors of input matrix, [n, +// n] +// * calcUV - if true then u and v will be computed, in opposite case function +// works significantly faster +// * fullUV - if false then only p (p is smaller among m and n) first columns +// of u and v will be calculated and their dimensions in this case are [m, p] +// and [n, p] +// * +// */ +// void svd(const NDArray& matrix, NDArray& u, NDArray& s, NDArray& v, const +// bool calcUV = false, const bool fullUV = false) + +} // namespace helpers +} // namespace ops +} // namespace sd + +#endif // LIBND4J_HOUSEHOLDER_H diff --git a/libnd4j/include/helpers/impl/ArrayUtils.cpp b/libnd4j/include/helpers/impl/ArrayUtils.cpp index 004cb15469cc..04ea00a06f10 100644 --- a/libnd4j/include/helpers/impl/ArrayUtils.cpp +++ b/libnd4j/include/helpers/impl/ArrayUtils.cpp @@ -21,37 +21,34 @@ #include namespace sd { - namespace ArrayUtils { - void toIntPtr(std::initializer_list list, int* target) { - std::vector vec(list); - toIntPtr(vec, target); - } - - void toIntPtr(std::vector& list, int* target) { - memcpy(target, list.data(), list.size() * sizeof(int)); - } - - void toLongPtr(std::initializer_list list, Nd4jLong* target) { - std::vector vec(list); - toLongPtr(vec, target); - } - - void toLongPtr(std::vector& list, Nd4jLong* target) { - memcpy(target, list.data(), list.size() * sizeof(Nd4jLong)); - } - - std::vector toLongVector(std::vector vec) { - std::vector result(vec.size()); - Nd4jLong vecSize = vec.size(); - - for (Nd4jLong e = 0; e < vecSize; e++) - result[e] = vec[e]; - - return result; - } - - std::vector toLongVector(std::vector vec) { - return vec; - } - } +namespace ArrayUtils { +void toIntPtr(std::initializer_list list, int* target) { + std::vector vec(list); + toIntPtr(vec, target); } + +void toIntPtr(std::vector& list, int* target) { + memcpy(target, list.data(), list.size() * sizeof(int)); +} + +void toLongPtr(std::initializer_list list, Nd4jLong* target) { + std::vector vec(list); + toLongPtr(vec, target); +} + +void toLongPtr(std::vector& list, Nd4jLong* target) { + memcpy(target, list.data(), list.size() * sizeof(Nd4jLong)); +} + +std::vector toLongVector(std::vector vec) { + std::vector result(vec.size()); + Nd4jLong vecSize = vec.size(); + + for (Nd4jLong e = 0; e < vecSize; e++) result[e] = vec[e]; + + return result; +} + +std::vector toLongVector(std::vector vec) { return vec; } +} // namespace ArrayUtils +} // namespace sd diff --git a/libnd4j/include/helpers/impl/AttentionHelper.cpp b/libnd4j/include/helpers/impl/AttentionHelper.cpp index bd5d006f2b17..00ec60aad546 100644 --- a/libnd4j/include/helpers/impl/AttentionHelper.cpp +++ b/libnd4j/include/helpers/impl/AttentionHelper.cpp @@ -21,61 +21,82 @@ #ifndef LIBND4J_ATTENTIONHELPER_CPP #define LIBND4J_ATTENTIONHELPER_CPP -#include - #include "../AttentionHelper.h" + +#include #include namespace sd { - sd::NDArray AttentionHelper::multiHeadProject(const sd::NDArray *input, const sd::NDArray *projectionMatrix, sd::LaunchContext * context) { - auto miniBatchSize = input->sizeAt(0); - auto seqLength = input->sizeAt(2); - auto numHeads = projectionMatrix->sizeAt(0); - auto projectedSize = projectionMatrix->sizeAt(1); - - auto inputPerm = input->permute({1, 0, 2}); //[batch, nIn, timeSteps] -> [nIn, batch, timeSteps] - auto inputPrep = inputPerm.reshape('c', {input->sizeAt(1), (miniBatchSize * seqLength)}); //[nIn, batch*timeSteps] - auto projectionPrep = projectionMatrix->reshape('c', {numHeads * projectionMatrix->sizeAt(1), projectionMatrix->sizeAt(2)}); //[nHeads, hS, nIn] -> [nHeads*hS, nIn] - - NDArray projected('c', {numHeads * projectionMatrix->sizeAt(1), (miniBatchSize * seqLength)}, input->dataType(), context); //[nHeads*hS, batch*timeSteps] - sd::ops::matmul mmul; - mmul.execute({&projectionPrep, &inputPrep}, {&projected}); - - projected.reshapei({numHeads, projectedSize, miniBatchSize, seqLength}); - projected.permutei({2, 0, 1, 3}); //[minibatch, numHeads, projectedSize, seqLength] - - return projected; - } - - void AttentionHelper::multiHeadProjectBp(const sd::NDArray *input, const sd::NDArray *projectionMatrix, - const sd::NDArray *eps, sd::NDArray *dLdInput, - sd::NDArray *dLdProjectionMatrix, sd::LaunchContext * context) { - auto miniBatchSize = input->sizeAt(0); - auto seqLength = input->sizeAt(2); - auto numHeads = projectionMatrix->sizeAt(0); - auto projectedSize = projectionMatrix->sizeAt(1); - - auto epsPerm = eps->permute({1, 2, 0, 3}); - auto epsReshaped = epsPerm.reshape('c', {numHeads * projectedSize, miniBatchSize * seqLength}); - - auto inputPerm = input->permute({1, 0, 2}); - auto inputPrep = inputPerm.reshape('c', {input->sizeAt(1), (miniBatchSize * seqLength)}); - auto projectionPrep = projectionMatrix->reshape('c', {numHeads * projectionMatrix->sizeAt(1), projectionMatrix->sizeAt(2)}); - - sd::ops::matmul_bp mmulBp; - NDArray dLdProjectionPrep(projectionPrep.shapeInfo(), false, context); - NDArray dLdInputPrep(inputPrep.shapeInfo(), false, context); - mmulBp.execute({&projectionPrep, &inputPrep, &epsReshaped}, std::vector{&dLdProjectionPrep, &dLdInputPrep}, {}, {}, {}); - - dLdProjectionPrep.reshapei({numHeads, projectionMatrix->sizeAt(1), projectionMatrix->sizeAt(2)}); - dLdProjectionMatrix->assign(dLdProjectionPrep); - - dLdInputPrep.reshapei({input->sizeAt(1), miniBatchSize, seqLength}); - dLdInputPrep.permutei({1, 0, 2}); - dLdInput->assign(dLdInputPrep); - } +sd::NDArray AttentionHelper::multiHeadProject( + const sd::NDArray *input, const sd::NDArray *projectionMatrix, + sd::LaunchContext *context) { + auto miniBatchSize = input->sizeAt(0); + auto seqLength = input->sizeAt(2); + auto numHeads = projectionMatrix->sizeAt(0); + auto projectedSize = projectionMatrix->sizeAt(1); + + auto inputPerm = input->permute( + {1, 0, 2}); //[batch, nIn, timeSteps] -> [nIn, batch, timeSteps] + auto inputPrep = inputPerm.reshape( + 'c', {input->sizeAt(1), + (miniBatchSize * seqLength)}); //[nIn, batch*timeSteps] + auto projectionPrep = projectionMatrix->reshape( + 'c', + {numHeads * projectionMatrix->sizeAt(1), + projectionMatrix->sizeAt(2)}); //[nHeads, hS, nIn] -> [nHeads*hS, nIn] + + NDArray projected( + 'c', + {numHeads * projectionMatrix->sizeAt(1), (miniBatchSize * seqLength)}, + input->dataType(), context); //[nHeads*hS, batch*timeSteps] + sd::ops::matmul mmul; + mmul.execute({&projectionPrep, &inputPrep}, {&projected}); + + projected.reshapei({numHeads, projectedSize, miniBatchSize, seqLength}); + projected.permutei( + {2, 0, 1, 3}); //[minibatch, numHeads, projectedSize, seqLength] + + return projected; } +void AttentionHelper::multiHeadProjectBp(const sd::NDArray *input, + const sd::NDArray *projectionMatrix, + const sd::NDArray *eps, + sd::NDArray *dLdInput, + sd::NDArray *dLdProjectionMatrix, + sd::LaunchContext *context) { + auto miniBatchSize = input->sizeAt(0); + auto seqLength = input->sizeAt(2); + auto numHeads = projectionMatrix->sizeAt(0); + auto projectedSize = projectionMatrix->sizeAt(1); + + auto epsPerm = eps->permute({1, 2, 0, 3}); + auto epsReshaped = epsPerm.reshape( + 'c', {numHeads * projectedSize, miniBatchSize * seqLength}); + + auto inputPerm = input->permute({1, 0, 2}); + auto inputPrep = + inputPerm.reshape('c', {input->sizeAt(1), (miniBatchSize * seqLength)}); + auto projectionPrep = projectionMatrix->reshape( + 'c', + {numHeads * projectionMatrix->sizeAt(1), projectionMatrix->sizeAt(2)}); + + sd::ops::matmul_bp mmulBp; + NDArray dLdProjectionPrep(projectionPrep.shapeInfo(), false, context); + NDArray dLdInputPrep(inputPrep.shapeInfo(), false, context); + mmulBp.execute({&projectionPrep, &inputPrep, &epsReshaped}, + std::vector{&dLdProjectionPrep, &dLdInputPrep}, {}, + {}, {}); + + dLdProjectionPrep.reshapei( + {numHeads, projectionMatrix->sizeAt(1), projectionMatrix->sizeAt(2)}); + dLdProjectionMatrix->assign(dLdProjectionPrep); + + dLdInputPrep.reshapei({input->sizeAt(1), miniBatchSize, seqLength}); + dLdInputPrep.permutei({1, 0, 2}); + dLdInput->assign(dLdInputPrep); +} +} // namespace sd #endif diff --git a/libnd4j/include/helpers/impl/BenchmarkHelper.cpp b/libnd4j/include/helpers/impl/BenchmarkHelper.cpp index 9e85cc5b73fd..a672a0a99979 100644 --- a/libnd4j/include/helpers/impl/BenchmarkHelper.cpp +++ b/libnd4j/include/helpers/impl/BenchmarkHelper.cpp @@ -18,685 +18,756 @@ * @author raver119@gmail.com */ - #include "../BenchmarkHelper.h" + #include -#include #include +#include + namespace sd { - BenchmarkHelper::BenchmarkHelper(unsigned int warmUpIterations, unsigned int runIterations) { - _wIterations = warmUpIterations; - _rIterations = runIterations; - } +BenchmarkHelper::BenchmarkHelper(unsigned int warmUpIterations, + unsigned int runIterations) { + _wIterations = warmUpIterations; + _rIterations = runIterations; +} + +std::string BenchmarkHelper::printHeader() { + return std::string( + "TestName\tOpNum\tWarmup\tNumIter\tDataType\tInplace\tShape\tStrides\tAxi" + "s\tOrders\tavg (us)\tmedian (us)\tmin (us)\tmax (us)\tstdev (us)\n"); +} + +std::string BenchmarkHelper::benchmarkOperation(OpBenchmark &benchmark) { + for (uint i = 0; i < _wIterations; i++) benchmark.executeOnce(); + + std::vector timings(_rIterations); + double sumT = 0.0; + + for (uint i = 0; i < _rIterations; i++) { + auto timeStart = std::chrono::system_clock::now(); + + benchmark.executeOnce(); + + auto timeEnd = std::chrono::system_clock::now(); + auto loopTime = std::chrono::duration_cast( + (timeEnd - timeStart)) + .count(); + timings[i] = loopTime; + sumT += loopTime; + } + sumT /= _rIterations; + + std::sort(timings.begin(), timings.end()); + Nd4jLong median = timings[_rIterations / 2]; + + auto n = NDArrayFactory::create(timings, LaunchContext::defaultContext()); + + auto stdev = + n.varianceNumber(sd::variance::SummaryStatsStandardDeviation, false) + .e(0); + auto min = n.reduceNumber(sd::reduce::Min).e(0); + auto max = n.reduceNumber(sd::reduce::Max).e(0); + + // opNum, DataType, Shape, average time, median time + auto t = benchmark.dataType(); + auto s = benchmark.shape(); + auto strides = benchmark.strides(); + auto o = benchmark.orders(); + auto a = benchmark.axis(); + auto inpl = benchmark.inplace(); + + std::string temp; + temp.resize(65536); + + // printing out stuff + snprintf( + const_cast(temp.data()), temp.length(), + "%s\t%i\t%i\t%i\t%s\t%s\t%s\t%s\t%s\t%s\t%lld\t%lld\t%lld\t%lld\t%.2f\n", + benchmark.testName().c_str(), benchmark.opNum(), _wIterations, + _rIterations, t.c_str(), inpl.c_str(), s.c_str(), strides.c_str(), + a.c_str(), o.c_str(), sd::math::nd4j_floor(sumT), + median, min, max, stdev); + + auto pos = temp.find('\n'); + return temp.substr(0, pos + 1); +} + +void BenchmarkHelper::benchmarkScalarOperation(scalar::Ops op, + std::string testName, + double value, NDArray &x, + NDArray &z) { + auto y = NDArrayFactory::create(x.dataType(), value); + + // for (uint i = 0; i < _wIterations; i++) + // NativeOpExecutioner::execScalar(op, x.buffer(), x.shapeInfo(), z.buffer(), + // z.shapeInfo(), y.buffer(), y.shapeInfo(), nullptr); + + std::vector timings(_rIterations); + double sumT = 0.0; + + for (uint i = 0; i < _rIterations; i++) { + auto timeStart = std::chrono::system_clock::now(); + + // NativeOpExecutioner::execScalar(op, x.buffer(), x.shapeInfo(), + // z.buffer(), z.shapeInfo(), y.buffer(), y.shapeInfo(), nullptr); + + auto timeEnd = std::chrono::system_clock::now(); + auto loopTime = std::chrono::duration_cast( + (timeEnd - timeStart)) + .count(); + timings[i] = loopTime; + sumT += loopTime; + } + sumT /= _rIterations; + + std::sort(timings.begin(), timings.end()); + Nd4jLong median = timings[_rIterations / 2]; + + NDArray n = NDArrayFactory::create(timings, nullptr); + double stdev = + n.varianceNumber(sd::variance::SummaryStatsStandardDeviation, false) + .e(0); + Nd4jLong min = n.reduceNumber(sd::reduce::Min).e(0); + Nd4jLong max = n.reduceNumber(sd::reduce::Max).e(0); + + // opNum, DataType, Shape, average time, median time + auto t = DataTypeUtils::asString(x.dataType()); + auto s = ShapeUtils::shapeAsString(&x); + auto stride = ShapeUtils::strideAsString(&x); + stride += "/"; + stride += ShapeUtils::strideAsString(&z); + std::string o; + o += x.ordering(); + o += "/"; + o += z.ordering(); + std::string inpl; + inpl += (x == z ? "true" : "false"); + + // printing out stuff + nd4j_printf( + "%s\t%i\t%i\t%i\t%s\t%s\t%s\t%s\t%s\tn/a\t%lld\t%lld\t%lld\t%lld\t%.2f\n", + testName.c_str(), op, _wIterations, _rIterations, t.c_str(), inpl.c_str(), + s.c_str(), stride.c_str(), o.c_str(), + sd::math::nd4j_floor(sumT), median, min, max, stdev); +} + +std::string BenchmarkHelper::runOperationSuit( + std::initializer_list benchmarks, const char *msg) { + std::vector ops(benchmarks); + return runOperationSuit(ops, msg); +} + +std::string BenchmarkHelper::runOperationSuit(OpBenchmark *benchmark) { + return benchmarkOperation(*benchmark); +} + +std::string BenchmarkHelper::runOperationSuit( + std::vector &benchmarks, bool postHeaders, const char *msg) { + std::string result; + + if (msg != nullptr && postHeaders) { + result += "\n"; + result += msg; + result += "\n"; + } + + if (postHeaders) result += printHeader(); + + for (auto v : benchmarks) result += benchmarkOperation(*v); + + return result; +} + +std::string BenchmarkHelper::runOperationSuit( + DeclarableBenchmark *op, const std::function &func, + ParametersBatch ¶metersBatch, const char *message) { + auto parameters = parametersBatch.parameters(); + std::string result; + + if (message != nullptr) { + result += "\n"; + result += message; + result += "\n"; + } + + result += printHeader(); + + std::vector list; + + for (auto &p : parameters) { + auto ctx = func(p); + + auto clone = reinterpret_cast(op->clone()); + clone->setContext(ctx); + + result += runOperationSuit(clone); + + delete clone; + } + + return result; +} + +std::string BenchmarkHelper::runOperationSuit( + ScalarBenchmark *op, + const std::function &func, + ParametersBatch ¶metersBatch, const char *message) { + auto parameters = parametersBatch.parameters(); + std::string output; + + if (message != nullptr) { + output += "\n"; + output += message; + output += "\n"; + } + + output += printHeader(); + + for (auto &p : parameters) { + ResultSet x; + x.setNonRemovable(); + ResultSet z; + z.setNonRemovable(); + func(p, x, z); + std::vector result; + + if (x.size() != z.size()) + throw std::runtime_error( + "ScalarBenchmark: number of X and Z arrays should match"); + + for (int e = 0; e < x.size(); e++) { + auto x_ = x.at(e); + auto z_ = z.at(e); + + auto clone = op->clone(); + clone->setX(x_); + clone->setZ(z_); - std::string BenchmarkHelper::printHeader() { - return std::string("TestName\tOpNum\tWarmup\tNumIter\tDataType\tInplace\tShape\tStrides\tAxis\tOrders\tavg (us)\tmedian (us)\tmin (us)\tmax (us)\tstdev (us)\n"); + result.emplace_back(clone); } - std::string BenchmarkHelper::benchmarkOperation(OpBenchmark &benchmark) { - - for (uint i = 0; i < _wIterations; i++) - benchmark.executeOnce(); - - std::vector timings(_rIterations); - double sumT = 0.0; - - for (uint i = 0; i < _rIterations; i++) { - auto timeStart = std::chrono::system_clock::now(); - - benchmark.executeOnce(); - - auto timeEnd = std::chrono::system_clock::now(); - auto loopTime = std::chrono::duration_cast ((timeEnd - timeStart)).count(); - timings[i] = loopTime; - sumT += loopTime; - } - sumT /= _rIterations; - - std::sort(timings.begin(), timings.end()); - Nd4jLong median = timings[_rIterations / 2]; - - auto n = NDArrayFactory::create(timings, LaunchContext::defaultContext()); + output += runOperationSuit(result, false); - auto stdev = n.varianceNumber(sd::variance::SummaryStatsStandardDeviation, false).e(0); - auto min = n.reduceNumber(sd::reduce::Min).e(0); - auto max = n.reduceNumber(sd::reduce::Max).e(0); - - // opNum, DataType, Shape, average time, median time - auto t = benchmark.dataType(); - auto s = benchmark.shape(); - auto strides = benchmark.strides(); - auto o = benchmark.orders(); - auto a = benchmark.axis(); - auto inpl = benchmark.inplace(); - - std::string temp; - temp.resize(65536); - - // printing out stuff - snprintf(const_cast(temp.data()), temp.length(), "%s\t%i\t%i\t%i\t%s\t%s\t%s\t%s\t%s\t%s\t%lld\t%lld\t%lld\t%lld\t%.2f\n", benchmark.testName().c_str(), benchmark.opNum(), - _wIterations, _rIterations, t.c_str(), inpl.c_str(), s.c_str(), strides.c_str(), a.c_str(), o.c_str(), - sd::math::nd4j_floor(sumT), median, min, max, stdev); - - auto pos = temp.find('\n'); - return temp.substr(0, pos + 1); + // removing everything + for (auto v : result) { + delete reinterpret_cast(v); } - - void BenchmarkHelper::benchmarkScalarOperation(scalar::Ops op, std::string testName, double value, NDArray &x, NDArray &z) { - auto y = NDArrayFactory::create(x.dataType(), value); - - //for (uint i = 0; i < _wIterations; i++) - //NativeOpExecutioner::execScalar(op, x.buffer(), x.shapeInfo(), z.buffer(), z.shapeInfo(), y.buffer(), y.shapeInfo(), nullptr); - - - std::vector timings(_rIterations); - double sumT = 0.0; - - for (uint i = 0; i < _rIterations; i++) { - auto timeStart = std::chrono::system_clock::now(); - - //NativeOpExecutioner::execScalar(op, x.buffer(), x.shapeInfo(), z.buffer(), z.shapeInfo(), y.buffer(), y.shapeInfo(), nullptr); - - auto timeEnd = std::chrono::system_clock::now(); - auto loopTime = std::chrono::duration_cast ((timeEnd - timeStart)).count(); - timings[i] = loopTime; - sumT += loopTime; - } - sumT /= _rIterations; - - std::sort(timings.begin(), timings.end()); - Nd4jLong median = timings[_rIterations / 2]; - - NDArray n = NDArrayFactory::create(timings, nullptr); - double stdev = n.varianceNumber(sd::variance::SummaryStatsStandardDeviation, false).e(0); - Nd4jLong min = n.reduceNumber(sd::reduce::Min).e(0); - Nd4jLong max = n.reduceNumber(sd::reduce::Max).e(0); - - // opNum, DataType, Shape, average time, median time - auto t = DataTypeUtils::asString(x.dataType()); - auto s = ShapeUtils::shapeAsString(&x); - auto stride = ShapeUtils::strideAsString(&x); - stride += "/"; - stride += ShapeUtils::strideAsString(&z); - std::string o; - o += x.ordering(); - o += "/"; - o += z.ordering(); - std::string inpl; - inpl += (x == z ? "true" : "false"); - - // printing out stuff - nd4j_printf("%s\t%i\t%i\t%i\t%s\t%s\t%s\t%s\t%s\tn/a\t%lld\t%lld\t%lld\t%lld\t%.2f\n", testName.c_str(), op, - _wIterations, _rIterations, t.c_str(), inpl.c_str(), s.c_str(), stride.c_str(), o.c_str(), - sd::math::nd4j_floor(sumT), median, min, max, stdev); + } + + return output; +} + +std::string BenchmarkHelper::runOperationSuit( + ScalarBenchmark *op, + const std::function &func, + const char *message) { + std::string output; + + ResultSet x; + x.setNonRemovable(); + ResultSet z; + z.setNonRemovable(); + func(x, z); + std::vector result; + + if (x.size() != z.size()) + throw std::runtime_error( + "ScalarBenchmark: number of X and Z arrays should match"); + + for (int e = 0; e < x.size(); e++) { + auto x_ = x.at(e); + auto z_ = z.at(e); + + auto clone = op->clone(); + clone->setX(x_); + clone->setZ(z_); + + result.emplace_back(clone); + } + + output += runOperationSuit(result, message); + + // removing everything + for (auto v : result) { + delete reinterpret_cast(v); + } + + return output; +} + +std::string BenchmarkHelper::runOperationSuit( + TransformBenchmark *op, + const std::function &func, + ParametersBatch ¶metersBatch, const char *message) { + auto parameters = parametersBatch.parameters(); + std::string output; + + if (message != nullptr) { + output += "\n"; + output += message; + output += "\n"; + } + + output += printHeader(); + + for (auto &p : parameters) { + ResultSet x; + x.setNonRemovable(); + ResultSet z; + z.setNonRemovable(); + func(p, x, z); + std::vector result; + + if (x.size() != z.size()) + throw std::runtime_error( + "TransformBenchmark: number of X and Z arrays should match"); + + for (int e = 0; e < x.size(); e++) { + auto x_ = x.at(e); + auto z_ = z.at(e); + + auto clone = op->clone(); + clone->setX(x_); + clone->setZ(z_); + + result.emplace_back(clone); } - std::string BenchmarkHelper::runOperationSuit(std::initializer_list benchmarks, const char *msg) { - std::vector ops(benchmarks); - return runOperationSuit(ops, msg); - } + output += runOperationSuit(result, false); - std::string BenchmarkHelper::runOperationSuit(OpBenchmark* benchmark) { - return benchmarkOperation(*benchmark); + // removing everything + for (auto v : result) { + delete reinterpret_cast(v); } - - std::string BenchmarkHelper::runOperationSuit(std::vector &benchmarks, bool postHeaders, const char *msg) { - std::string result; - - if (msg != nullptr && postHeaders) { - result += "\n"; - result += msg; - result += "\n"; - } - - if (postHeaders) - result += printHeader(); - - for (auto v:benchmarks) - result += benchmarkOperation(*v); - - return result; + } + + return output; +} + +std::string BenchmarkHelper::runOperationSuit( + TransformBenchmark *op, + const std::function &func, + const char *message) { + std::string output; + + ResultSet x; + x.setNonRemovable(); + ResultSet z; + z.setNonRemovable(); + func(x, z); + std::vector result; + + if (x.size() != z.size()) + throw std::runtime_error( + "TransformBenchmark: number of X and Z arrays should match"); + + for (int e = 0; e < x.size(); e++) { + auto x_ = x.at(e); + auto z_ = z.at(e); + + auto clone = op->clone(); + clone->setX(x_); + clone->setZ(z_); + + result.emplace_back(clone); + } + + output += runOperationSuit(result, message); + + // removing everything + for (auto v : result) { + delete reinterpret_cast(v); + } + + return output; +} + +std::string BenchmarkHelper::runOperationSuit( + ReductionBenchmark *op, + const std::function &func, + ParametersBatch ¶metersBatch, const char *message) { + std::string output; + auto parameters = parametersBatch.parameters(); + + if (message != nullptr) { + output += "\n"; + output += message; + output += "\n"; + } + + output += printHeader(); + + for (auto &p : parameters) { + ResultSet x; + x.setNonRemovable(); + ResultSet z; + z.setNonRemovable(); + func(p, x, z); + std::vector result; + + if (x.size() != z.size()) + throw std::runtime_error( + "ReductionBenchmark: number of X and Z arrays should match"); + + for (int e = 0; e < x.size(); e++) { + auto x_ = x.at(e); + auto z_ = z.at(e); + + auto clone = op->clone(); + clone->setX(x_); + clone->setZ(z_); + + result.emplace_back(clone); } - std::string BenchmarkHelper::runOperationSuit(DeclarableBenchmark *op, const std::function& func, ParametersBatch ¶metersBatch, const char *message) { - auto parameters = parametersBatch.parameters(); - std::string result; - - if (message != nullptr) { - result += "\n"; - result += message; - result += "\n"; - } + output += runOperationSuit(result, false); - result += printHeader(); - - std::vector list; - - for (auto &p : parameters) { - auto ctx = func(p); - - auto clone = reinterpret_cast(op->clone()); - clone->setContext(ctx); - - result += runOperationSuit(clone); - - delete clone; - } - - return result; - } - - std::string BenchmarkHelper::runOperationSuit(ScalarBenchmark *op, const std::function& func, ParametersBatch ¶metersBatch, const char *message) { - auto parameters = parametersBatch.parameters(); - std::string output; - - if (message != nullptr) { - output += "\n"; - output += message; - output += "\n"; - } - - output += printHeader(); - - for (auto &p: parameters) { - ResultSet x; - x.setNonRemovable(); - ResultSet z; - z.setNonRemovable(); - func(p, x, z); - std::vector result; - - if (x.size() != z.size()) - throw std::runtime_error("ScalarBenchmark: number of X and Z arrays should match"); - - for (int e = 0; e < x.size(); e++) { - auto x_ = x.at(e); - auto z_ = z.at(e); - - auto clone = op->clone(); - clone->setX(x_); - clone->setZ(z_); - - result.emplace_back(clone); - } - - output += runOperationSuit(result, false); - - // removing everything - for (auto v:result) { - delete reinterpret_cast(v); - } - } - - return output; + // removing everything + for (auto v : result) { + delete reinterpret_cast(v); } - - std::string BenchmarkHelper::runOperationSuit(ScalarBenchmark *op, const std::function& func, const char *message) { - std::string output; - - ResultSet x; - x.setNonRemovable(); - ResultSet z; - z.setNonRemovable(); - func(x, z); - std::vector result; - - if (x.size() != z.size()) - throw std::runtime_error("ScalarBenchmark: number of X and Z arrays should match"); - - for (int e = 0; e < x.size(); e++) { - auto x_ = x.at(e); - auto z_ = z.at(e); - - auto clone = op->clone(); - clone->setX(x_); - clone->setZ(z_); - - result.emplace_back(clone); - } - - output += runOperationSuit(result, message); - - // removing everything - for (auto v:result) { - delete reinterpret_cast(v); - } - - return output; + } + + return output; +} + +std::string BenchmarkHelper::runOperationSuit( + ReductionBenchmark *op, + const std::function &func, + const char *message) { + std::string output; + ResultSet x; + x.setNonRemovable(); + ResultSet z; + z.setNonRemovable(); + func(x, z); + std::vector result; + + if (x.size() != z.size()) + throw std::runtime_error( + "ReductionBenchmark: number of X and Z arrays should match"); + + for (int e = 0; e < x.size(); e++) { + auto x_ = x.at(e); + auto z_ = z.at(e); + + auto clone = op->clone(); + clone->setX(x_); + clone->setZ(z_); + + result.emplace_back(clone); + } + + output += runOperationSuit(result, message); + + // removing everything + for (auto v : result) { + delete reinterpret_cast(v); + } + + return output; +} + +std::string BenchmarkHelper::runOperationSuit( + ReductionBenchmark *op, + const std::function &func, + ParametersBatch ¶metersBatch, const char *message) { + auto parameters = parametersBatch.parameters(); + std::string output; + + if (message != nullptr) { + output += "\n"; + output += message; + output += "\n"; + } + + printHeader(); + + for (auto &p : parameters) { + ResultSet x; + x.setNonRemovable(); + ResultSet y; + y.setNonRemovable(); + ResultSet z; + z.setNonRemovable(); + func(p, x, y, z); + std::vector result; + + if (x.size() != z.size() || x.size() != y.size()) + throw std::runtime_error( + "ReductionBenchmark: number of X and Z arrays should match"); + + for (int e = 0; e < x.size(); e++) { + auto x_ = x.at(e); + auto y_ = y.at(e); + auto z_ = z.at(e); + + auto clone = op->clone(); + clone->setX(x_); + clone->setZ(z_); + + if (y_.shapeInfo() != nullptr) { + clone->setAxis(y_.asVectorT()); + } + + result.emplace_back(clone); } - std::string BenchmarkHelper::runOperationSuit(TransformBenchmark *op, const std::function& func, ParametersBatch ¶metersBatch, const char *message) { - auto parameters = parametersBatch.parameters(); - std::string output; - - if (message != nullptr) { - output += "\n"; - output += message; - output += "\n"; - } - - output += printHeader(); - - for (auto &p: parameters) { - ResultSet x; - x.setNonRemovable(); - ResultSet z; - z.setNonRemovable(); - func(p, x, z); - std::vector result; - - if (x.size() != z.size()) - throw std::runtime_error("TransformBenchmark: number of X and Z arrays should match"); + output += runOperationSuit(result, false); - for (int e = 0; e < x.size(); e++) { - auto x_ = x.at(e); - auto z_ = z.at(e); - - auto clone = op->clone(); - clone->setX(x_); - clone->setZ(z_); - - result.emplace_back(clone); - } - - output += runOperationSuit(result, false); - - // removing everything - for (auto v:result) { - delete reinterpret_cast(v); - } - } - - return output; + // removing everything + for (auto v : result) { + delete reinterpret_cast(v); } - - std::string BenchmarkHelper::runOperationSuit(TransformBenchmark *op, const std::function& func, const char *message) { - std::string output; - - ResultSet x; - x.setNonRemovable(); - ResultSet z; - z.setNonRemovable(); - func(x, z); - std::vector result; - - if (x.size() != z.size()) - throw std::runtime_error("TransformBenchmark: number of X and Z arrays should match"); - - for (int e = 0; e < x.size(); e++) { - auto x_ = x.at(e); - auto z_ = z.at(e); - - auto clone = op->clone(); - clone->setX(x_); - clone->setZ(z_); - - result.emplace_back(clone); - } - - output += runOperationSuit(result, message); - - // removing everything - for (auto v:result) { - delete reinterpret_cast(v); - } - - return output; + } + + return output; +} + +std::string BenchmarkHelper::runOperationSuit( + ReductionBenchmark *op, + const std::function &func, + const char *message) { + std::string output; + + ResultSet x; + x.setNonRemovable(); + ResultSet y; + y.setNonRemovable(); + ResultSet z; + z.setNonRemovable(); + func(x, y, z); + std::vector result; + + if (x.size() != z.size() || x.size() != y.size()) + throw std::runtime_error( + "ReductionBenchmark: number of X and Z arrays should match"); + + for (int e = 0; e < x.size(); e++) { + auto x_ = x.at(e); + auto y_ = y.at(e); + auto z_ = z.at(e); + + auto clone = op->clone(); + clone->setX(x_); + clone->setZ(z_); + + if (y_.shapeInfo() != nullptr) { + clone->setAxis(y_.asVectorT()); } - - std::string BenchmarkHelper::runOperationSuit(ReductionBenchmark *op, const std::function& func, ParametersBatch ¶metersBatch, const char *message) { - std::string output; - auto parameters = parametersBatch.parameters(); - - if (message != nullptr) { - output += "\n"; - output += message; - output += "\n"; - } - - output += printHeader(); - - for (auto &p: parameters) { - ResultSet x; - x.setNonRemovable(); - ResultSet z; - z.setNonRemovable(); - func(p, x, z); - std::vector result; - - if (x.size() != z.size()) - throw std::runtime_error("ReductionBenchmark: number of X and Z arrays should match"); - - for (int e = 0; e < x.size(); e++) { - auto x_ = x.at(e); - auto z_ = z.at(e); - - auto clone = op->clone(); - clone->setX(x_); - clone->setZ(z_); - - result.emplace_back(clone); - } - - output += runOperationSuit(result, false); - - // removing everything - for (auto v:result) { - delete reinterpret_cast(v); - } - } - - return output; + result.emplace_back(clone); + } + + output += runOperationSuit(result, message); + + // removing everything + for (auto v : result) { + delete reinterpret_cast(v); + } + + return output; +} + +std::string BenchmarkHelper::runOperationSuit( + BroadcastBenchmark *op, + const std::function &func, + ParametersBatch ¶metersBatch, const char *message) { + auto parameters = parametersBatch.parameters(); + std::string output; + + if (message != nullptr) { + output += "\n"; + output += message; + output += "\n"; + } + + output += printHeader(); + + for (auto &p : parameters) { + ResultSet x; + x.setNonRemovable(); + ResultSet y; + y.setNonRemovable(); + ResultSet z; + z.setNonRemovable(); + func(p, x, y, z); + std::vector result; + + if (x.size() != z.size()) + throw std::runtime_error( + "BroadcastBenchmark: number of X and Z arrays should match"); + + for (int e = 0; e < x.size(); e++) { + auto x_ = x.at(e); + auto y_ = y.at(e); + auto z_ = z.at(e); + + auto clone = op->clone(); + clone->setX(x_); + clone->setY(y_); + clone->setZ(z_); + + clone->setAxis(op->getAxis()); + result.emplace_back(clone); } - std::string BenchmarkHelper::runOperationSuit(ReductionBenchmark *op, const std::function& func, const char *message) { - std::string output; - ResultSet x; - x.setNonRemovable(); - ResultSet z; - z.setNonRemovable(); - func(x, z); - std::vector result; - - if (x.size() != z.size()) - throw std::runtime_error("ReductionBenchmark: number of X and Z arrays should match"); - - for (int e = 0; e < x.size(); e++) { - auto x_ = x.at(e); - auto z_ = z.at(e); - - auto clone = op->clone(); - clone->setX(x_); - clone->setZ(z_); - - result.emplace_back(clone); - } - - output += runOperationSuit(result, message); + output += runOperationSuit(result, false); - // removing everything - for (auto v:result) { - delete reinterpret_cast(v); - } - - return output; + // removing everything + for (auto v : result) { + delete reinterpret_cast(v); } - - std::string BenchmarkHelper::runOperationSuit(ReductionBenchmark *op, const std::function& func, ParametersBatch ¶metersBatch, const char *message) { - auto parameters = parametersBatch.parameters(); - std::string output; - - if (message != nullptr) { - output += "\n"; - output += message; - output += "\n"; - } - - printHeader(); - - for (auto &p: parameters) { - ResultSet x; - x.setNonRemovable(); - ResultSet y; - y.setNonRemovable(); - ResultSet z; - z.setNonRemovable(); - func(p, x, y, z); - std::vector result; - - if (x.size() != z.size() || x.size() != y.size()) - throw std::runtime_error("ReductionBenchmark: number of X and Z arrays should match"); - - for (int e = 0; e < x.size(); e++) { - auto x_ = x.at(e); - auto y_ = y.at(e); - auto z_ = z.at(e); - - auto clone = op->clone(); - clone->setX(x_); - clone->setZ(z_); - - if (y_ != nullptr) { - clone->setAxis(y_->asVectorT()); - delete y_; - } - - result.emplace_back(clone); - } - - output += runOperationSuit(result, false); - - // removing everything - for (auto v:result) { - delete reinterpret_cast(v); - } - } - - return output; + } + + return output; +} + +std::string BenchmarkHelper::runOperationSuit( + PairwiseBenchmark *op, + const std::function &func, + ParametersBatch ¶metersBatch, const char *message) { + auto parameters = parametersBatch.parameters(); + std::string output; + + if (message != nullptr) { + output += "\n"; + output += message; + output += "\n"; + } + + output += printHeader(); + + for (auto &p : parameters) { + ResultSet x; + x.setNonRemovable(); + ResultSet y; + y.setNonRemovable(); + ResultSet z; + z.setNonRemovable(); + func(p, x, y, z); + std::vector result; + + if (x.size() != z.size() || x.size() != y.size()) + throw std::runtime_error( + "PairwiseBenchmark: number of X and Z arrays should match"); + + for (int e = 0; e < x.size(); e++) { + auto x_ = x.at(e); + auto y_ = y.at(e); + auto z_ = z.at(e); + + auto clone = op->clone(); + clone->setX(x_); + clone->setY(y_); + clone->setZ(z_); + + result.emplace_back(clone); } - std::string BenchmarkHelper::runOperationSuit(ReductionBenchmark *op, const std::function& func, const char *message) { - std::string output; - - ResultSet x; - x.setNonRemovable(); - ResultSet y; - y.setNonRemovable(); - ResultSet z; - z.setNonRemovable(); - func(x, y, z); - std::vector result; - - if (x.size() != z.size() || x.size() != y.size()) - throw std::runtime_error("ReductionBenchmark: number of X and Z arrays should match"); - - for (int e = 0; e < x.size(); e++) { - auto x_ = x.at(e); - auto y_ = y.at(e); - auto z_ = z.at(e); - - auto clone = op->clone(); - clone->setX(x_); - clone->setZ(z_); - - if (y_ != nullptr) { - clone->setAxis(y_->asVectorT()); - delete y_; - } - result.emplace_back(clone); - } - - output += runOperationSuit(result, message); - - // removing everything - for (auto v:result) { - delete reinterpret_cast(v); - } - - return output; - } + output += runOperationSuit(result, false); - std::string BenchmarkHelper::runOperationSuit(BroadcastBenchmark *op, const std::function& func, ParametersBatch ¶metersBatch, const char *message) { - auto parameters = parametersBatch.parameters(); - std::string output; - - if (message != nullptr) { - output += "\n"; - output += message; - output += "\n"; - } - - output += printHeader(); - - for (auto &p: parameters) { - ResultSet x; - x.setNonRemovable(); - ResultSet y; - y.setNonRemovable(); - ResultSet z; - z.setNonRemovable(); - func(p, x, y, z); - std::vector result; - - if (x.size() != z.size() ) - throw std::runtime_error("BroadcastBenchmark: number of X and Z arrays should match"); - - for (int e = 0; e < x.size(); e++) { - auto x_ = x.at(e); - auto y_ = y.at(e); - auto z_ = z.at(e); - - auto clone = op->clone(); - clone->setX(x_); - clone->setY(y_); - clone->setZ(z_); - - clone->setAxis(op->getAxis()); - result.emplace_back(clone); - } - - output += runOperationSuit(result, false); - - // removing everything - for (auto v:result) { - delete reinterpret_cast(v); - } - } - - return output; + // removing everything + for (auto v : result) { + delete reinterpret_cast(v); } - - std::string BenchmarkHelper::runOperationSuit(PairwiseBenchmark *op, const std::function& func, ParametersBatch ¶metersBatch, const char *message) { - auto parameters = parametersBatch.parameters(); - std::string output; - - if (message != nullptr) { - output += "\n"; - output += message; - output += "\n"; - } - - output += printHeader(); - - for (auto &p: parameters) { - ResultSet x; - x.setNonRemovable(); - ResultSet y; - y.setNonRemovable(); - ResultSet z; - z.setNonRemovable(); - func(p, x, y, z); - std::vector result; - - if (x.size() != z.size() || x.size() != y.size()) - throw std::runtime_error("PairwiseBenchmark: number of X and Z arrays should match"); - - for (int e = 0; e < x.size(); e++) { - auto x_ = x.at(e); - auto y_ = y.at(e); - auto z_ = z.at(e); - - auto clone = op->clone(); - clone->setX(x_); - clone->setY(y_); - clone->setZ(z_); - - result.emplace_back(clone); - } - - output += runOperationSuit(result, false); - - // removing everything - for (auto v:result) { - delete reinterpret_cast(v); - } - } - - return output; + } + + return output; +} + +std::string BenchmarkHelper::runOperationSuit( + PairwiseBenchmark *op, + const std::function &func, + const char *message) { + std::string output; + + ResultSet x; + x.setNonRemovable(); + ResultSet y; + y.setNonRemovable(); + ResultSet z; + z.setNonRemovable(); + func(x, y, z); + std::vector result; + + if (x.size() != z.size() || x.size() != y.size()) + throw std::runtime_error( + "PairwiseBenchmark: number of X and Z arrays should match"); + + for (int e = 0; e < x.size(); e++) { + auto x_ = x.at(e); + auto y_ = y.at(e); + auto z_ = z.at(e); + + auto clone = op->clone(); + clone->setX(x_); + clone->setY(y_); + clone->setZ(z_); + + result.emplace_back(clone); + } + + output += runOperationSuit(result, message); + + // removing everything + for (auto v : result) { + delete reinterpret_cast(v); + } + + return output; +} + +std::string BenchmarkHelper::runOperationSuit( + MatrixBenchmark *op, + const std::function &func, + ParametersBatch ¶metersBatch, const char *message) { + auto parameters = parametersBatch.parameters(); + std::string output; + + if (message != nullptr) { + output += "\n"; + output += message; + output += "\n"; + } + + output += printHeader(); + + for (auto &p : parameters) { + ResultSet x; + x.setNonRemovable(); + ResultSet y; + y.setNonRemovable(); + ResultSet z; + z.setNonRemovable(); + func(p, x, y, z); + std::vector result; + + for (int e = 0; e < x.size(); e++) { + auto x_ = x.at(e); + auto y_ = y.at(e); + auto z_ = z.at(e); + + auto clone = op->clone(); + clone->setX(x_); + clone->setY(y_); + clone->setZ(z_); + + result.emplace_back(clone); } - std::string BenchmarkHelper::runOperationSuit(PairwiseBenchmark *op, const std::function& func, const char *message) { - std::string output; - - ResultSet x; - x.setNonRemovable(); - ResultSet y; - y.setNonRemovable(); - ResultSet z; - z.setNonRemovable(); - func(x, y, z); - std::vector result; - - if (x.size() != z.size() || x.size() != y.size()) - throw std::runtime_error("PairwiseBenchmark: number of X and Z arrays should match"); + output += runOperationSuit(result, false); - for (int e = 0; e < x.size(); e++) { - auto x_ = x.at(e); - auto y_ = y.at(e); - auto z_ = z.at(e); - - auto clone = op->clone(); - clone->setX(x_); - clone->setY(y_); - clone->setZ(z_); - - result.emplace_back(clone); - } - - output += runOperationSuit(result, message); - - // removing everything - for (auto v:result) { - delete reinterpret_cast(v); - } - - return output; + // removing everything + for (auto v : result) { + delete reinterpret_cast(v); } + } - std::string BenchmarkHelper::runOperationSuit(MatrixBenchmark *op, const std::function& func, ParametersBatch ¶metersBatch, const char *message) { - auto parameters = parametersBatch.parameters(); - std::string output; - - if (message != nullptr) { - output += "\n"; - output += message; - output += "\n"; - } - - output += printHeader(); - - for (auto &p: parameters) { - ResultSet x; - x.setNonRemovable(); - ResultSet y; - y.setNonRemovable(); - ResultSet z; - z.setNonRemovable(); - func(p, x, y, z); - std::vector result; - - for (int e = 0; e < x.size(); e++) { - auto x_ = x.at(e); - auto y_ = y.at(e); - auto z_ = z.at(e); - - auto clone = op->clone(); - clone->setX(x_); - clone->setY(y_); - clone->setZ(z_); - - result.emplace_back(clone); - } - - output += runOperationSuit(result, false); - - // removing everything - for (auto v:result) { - delete reinterpret_cast(v); - } - } - - return output; - } -} \ No newline at end of file + return output; +} +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/helpers/impl/BitwiseUtils.cpp b/libnd4j/include/helpers/impl/BitwiseUtils.cpp index 9bd3fa8cfaa4..3a560a8c7c0e 100644 --- a/libnd4j/include/helpers/impl/BitwiseUtils.cpp +++ b/libnd4j/include/helpers/impl/BitwiseUtils.cpp @@ -18,63 +18,60 @@ // Created by raver119 on 10.11.2017. // -#include #include +#include #include #include namespace sd { - bool BitwiseUtils::isBE() { - short int word = 0x0001; - char *byte = (char *) &word; - return(byte[0] ? false : true); - } +bool BitwiseUtils::isBE() { + short int word = 0x0001; + char *byte = (char *)&word; + return (byte[0] ? false : true); +} - int BitwiseUtils::valueBit(int holder) { - if (holder == 0) - return -1; +int BitwiseUtils::valueBit(int holder) { + if (holder == 0) return -1; #ifdef REVERSE_BITS - for (int e = 32; e >= 0; e--) { + for (int e = 32; e >= 0; e--) { #else - for (int e = 0; e < 32; e++) { + for (int e = 0; e < 32; e++) { #endif - bool isOne = (holder & 1 << e) != 0; + bool isOne = (holder & 1 << e) != 0; - if (isOne) - return e; - } + if (isOne) return e; + } - return -1; - } + return -1; +} - std::vector BitwiseUtils::valueBits(int holder) { - std::vector bits; - if (holder == 0) { - for (int e = 0; e < 32; e++) - bits.emplace_back(0); +std::vector BitwiseUtils::valueBits(int holder) { + std::vector bits; + if (holder == 0) { + for (int e = 0; e < 32; e++) bits.emplace_back(0); - return bits; - } + return bits; + } #ifdef REVERSE_BITS - for (int e = 32; e >= 0; e--) { + for (int e = 32; e >= 0; e--) { #else - for (int e = 0; e < 32; e++) { + for (int e = 0; e < 32; e++) { #endif - bool isOne = (holder & 1 << e) != 0; + bool isOne = (holder & 1 << e) != 0; - if (isOne) - bits.emplace_back(1); - else - bits.emplace_back(0); - } + if (isOne) + bits.emplace_back(1); + else + bits.emplace_back(0); + } - return bits; - } + return bits; +} - sd::ByteOrder BitwiseUtils::asByteOrder() { - return isBE() ? ByteOrder::BE : ByteOrder::LE; - } +sd::ByteOrder BitwiseUtils::asByteOrder() { + return isBE() ? ByteOrder::BE : ByteOrder::LE; } +} // namespace sd diff --git a/libnd4j/include/helpers/impl/BlasHelper.cpp b/libnd4j/include/helpers/impl/BlasHelper.cpp index 70839fe2dea9..564757195a9e 100644 --- a/libnd4j/include/helpers/impl/BlasHelper.cpp +++ b/libnd4j/include/helpers/impl/BlasHelper.cpp @@ -20,345 +20,322 @@ #include namespace sd { - BlasHelper& BlasHelper::getInstance() { - static BlasHelper instance; - return instance; - } - - - void BlasHelper::initializeFunctions(Nd4jPointer *functions) { - nd4j_debug("Initializing BLAS\n",""); - - _hasSgemv = functions[0] != nullptr; - _hasSgemm = functions[2] != nullptr; - - _hasDgemv = functions[1] != nullptr; - _hasDgemm = functions[3] != nullptr; - - _hasSgemmBatch = functions[4] != nullptr; - _hasDgemmBatch = functions[5] != nullptr; - - this->cblasSgemv = (CblasSgemv)functions[0]; - this->cblasDgemv = (CblasDgemv)functions[1]; - this->cblasSgemm = (CblasSgemm)functions[2]; - this->cblasDgemm = (CblasDgemm)functions[3]; - this->cblasSgemmBatch = (CblasSgemmBatch)functions[4]; - this->cblasDgemmBatch = (CblasDgemmBatch)functions[5]; - this->lapackeSgesvd = (LapackeSgesvd)functions[6]; - this->lapackeDgesvd = (LapackeDgesvd)functions[7]; - this->lapackeSgesdd = (LapackeSgesdd)functions[8]; - this->lapackeDgesdd = (LapackeDgesdd)functions[9]; - } - - void BlasHelper::initializeDeviceFunctions(Nd4jPointer *functions) { - nd4j_debug("Initializing device BLAS\n",""); - - /* - this->cublasSgemv = (CublasSgemv)functions[0]; - this->cublasDgemv = (CublasDgemv)functions[1]; - this->cublasHgemm = (CublasHgemm)functions[2]; - this->cublasSgemm = (CublasSgemm)functions[3]; - this->cublasDgemm = (CublasDgemm)functions[4]; - this->cublasSgemmEx = (CublasSgemmEx)functions[5]; - this->cublasHgemmBatched = (CublasHgemmBatched)functions[6]; - this->cublasSgemmBatched = (CublasSgemmBatched)functions[7]; - this->cublasDgemmBatched = (CublasDgemmBatched)functions[8]; - this->cusolverDnSgesvdBufferSize = (CusolverDnSgesvdBufferSize)functions[9]; - this->cusolverDnDgesvdBufferSize = (CusolverDnDgesvdBufferSize)functions[10]; - this->cusolverDnSgesvd = (CusolverDnSgesvd)functions[11]; - this->cusolverDnDgesvd = (CusolverDnDgesvd)functions[12]; - */ - } - - - template <> - bool BlasHelper::hasGEMV() { - if (sd::Environment::getInstance().blasFallback()) - return false; +BlasHelper& BlasHelper::getInstance() { + static BlasHelper instance; + return instance; +} + +void BlasHelper::initializeFunctions(Nd4jPointer* functions) { + nd4j_debug("Initializing BLAS\n", ""); + + _hasSgemv = functions[0] != nullptr; + _hasSgemm = functions[2] != nullptr; + + _hasDgemv = functions[1] != nullptr; + _hasDgemm = functions[3] != nullptr; + + _hasSgemmBatch = functions[4] != nullptr; + _hasDgemmBatch = functions[5] != nullptr; + + this->cblasSgemv = (CblasSgemv)functions[0]; + this->cblasDgemv = (CblasDgemv)functions[1]; + this->cblasSgemm = (CblasSgemm)functions[2]; + this->cblasDgemm = (CblasDgemm)functions[3]; + this->cblasSgemmBatch = (CblasSgemmBatch)functions[4]; + this->cblasDgemmBatch = (CblasDgemmBatch)functions[5]; + this->lapackeSgesvd = (LapackeSgesvd)functions[6]; + this->lapackeDgesvd = (LapackeDgesvd)functions[7]; + this->lapackeSgesdd = (LapackeSgesdd)functions[8]; + this->lapackeDgesdd = (LapackeDgesdd)functions[9]; +} + +void BlasHelper::initializeDeviceFunctions(Nd4jPointer* functions) { + nd4j_debug("Initializing device BLAS\n", ""); + + /* + this->cublasSgemv = (CublasSgemv)functions[0]; + this->cublasDgemv = (CublasDgemv)functions[1]; + this->cublasHgemm = (CublasHgemm)functions[2]; + this->cublasSgemm = (CublasSgemm)functions[3]; + this->cublasDgemm = (CublasDgemm)functions[4]; + this->cublasSgemmEx = (CublasSgemmEx)functions[5]; + this->cublasHgemmBatched = (CublasHgemmBatched)functions[6]; + this->cublasSgemmBatched = (CublasSgemmBatched)functions[7]; + this->cublasDgemmBatched = (CublasDgemmBatched)functions[8]; + this->cusolverDnSgesvdBufferSize = (CusolverDnSgesvdBufferSize)functions[9]; + this->cusolverDnDgesvdBufferSize = (CusolverDnDgesvdBufferSize)functions[10]; + this->cusolverDnSgesvd = (CusolverDnSgesvd)functions[11]; + this->cusolverDnDgesvd = (CusolverDnDgesvd)functions[12]; + */ +} + +template <> +bool BlasHelper::hasGEMV() { + if (sd::Environment::getInstance().blasFallback()) return false; #if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS) - return true; + return true; #else - return _hasSgemv; + return _hasSgemv; #endif - } +} + +template <> +bool BlasHelper::hasGEMV() { + if (sd::Environment::getInstance().blasFallback()) return false; + +#if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS) + return true; +#else + return _hasDgemv; +#endif +} + +template <> +bool BlasHelper::hasGEMV() { + return false; +} + +template <> +bool BlasHelper::hasGEMV() { + return false; +} + +template <> +bool BlasHelper::hasGEMV() { + return false; +} + +template <> +bool BlasHelper::hasGEMV() { + return false; +} - template <> - bool BlasHelper::hasGEMV() { - if (sd::Environment::getInstance().blasFallback()) - return false; +template <> +bool BlasHelper::hasGEMV() { + return false; +} + +template <> +bool BlasHelper::hasGEMV() { + return false; +} + +template <> +bool BlasHelper::hasGEMV() { + return false; +} + +template <> +bool BlasHelper::hasGEMV() { + return false; +} + +bool BlasHelper::hasGEMV(const sd::DataType dtype) { + if (dtype == DataType::FLOAT32) { + if (sd::Environment::getInstance().blasFallback()) return false; + +#if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS) + return true; +#else + return _hasSgemv; +#endif + } + if (dtype == DataType::DOUBLE) { + if (sd::Environment::getInstance().blasFallback()) return false; #if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS) - return true; + return true; #else return _hasDgemv; #endif - } - - template <> - bool BlasHelper::hasGEMV() { - return false; - } - - template <> - bool BlasHelper::hasGEMV() { - return false; - } - - template <> - bool BlasHelper::hasGEMV() { - return false; - } - - template <> - bool BlasHelper::hasGEMV() { - return false; - } - - template <> - bool BlasHelper::hasGEMV() { - return false; - } - - template <> - bool BlasHelper::hasGEMV() { - return false; - } - - template <> - bool BlasHelper::hasGEMV() { - return false; - } - - template <> - bool BlasHelper::hasGEMV() { - return false; - } - - bool BlasHelper::hasGEMV(const sd::DataType dtype) { - if(dtype == DataType::FLOAT32) { - if (sd::Environment::getInstance().blasFallback()) - return false; - - #if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS) - return true; - #else - return _hasSgemv; - #endif - } - if(dtype == DataType::DOUBLE) { - if (sd::Environment::getInstance().blasFallback()) - return false; - - #if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS) - return true; - #else - return _hasDgemv; - #endif - } - return false; - } - - template <> - bool BlasHelper::hasGEMM() { - if (sd::Environment::getInstance().blasFallback()) - return false; + } + return false; +} + +template <> +bool BlasHelper::hasGEMM() { + if (sd::Environment::getInstance().blasFallback()) return false; #if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS) - return true; + return true; #else - return _hasSgemm; + return _hasSgemm; +#endif +} + +template <> +bool BlasHelper::hasGEMM() { + if (sd::Environment::getInstance().blasFallback()) return false; + +#if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS) + return true; +#else + return _hasDgemm; #endif - } +} + +template <> +bool BlasHelper::hasGEMM() { + return false; +} + +template <> +bool BlasHelper::hasGEMM() { + return false; +} + +template <> +bool BlasHelper::hasGEMM() { + return false; +} + +template <> +bool BlasHelper::hasGEMM() { + return false; +} + +template <> +bool BlasHelper::hasGEMM() { + return false; +} + +template <> +bool BlasHelper::hasGEMM() { + return false; +} + +template <> +bool BlasHelper::hasGEMM() { + return false; +} + +template <> +bool BlasHelper::hasGEMM() { + return false; +} - template <> - bool BlasHelper::hasGEMM() { - if (sd::Environment::getInstance().blasFallback()) - return false; +bool BlasHelper::hasGEMM(const sd::DataType dtype) { + if (dtype == DataType::FLOAT32) { + if (sd::Environment::getInstance().blasFallback()) return false; #if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS) - return true; + return true; +#else + return _hasSgemm; +#endif + } + if (dtype == DataType::DOUBLE) { + if (sd::Environment::getInstance().blasFallback()) return false; + +#if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS) + return true; #else return _hasDgemm; #endif - } - - template <> - bool BlasHelper::hasGEMM() { - return false; - } - - template <> - bool BlasHelper::hasGEMM() { - return false; - } - - template <> - bool BlasHelper::hasGEMM() { - return false; - } - - template <> - bool BlasHelper::hasGEMM() { - return false; - } - - template <> - bool BlasHelper::hasGEMM() { - return false; - } - - template <> - bool BlasHelper::hasGEMM() { - return false; - } - - template <> - bool BlasHelper::hasGEMM() { - return false; - } - - template <> - bool BlasHelper::hasGEMM() { - return false; - } - - bool BlasHelper:: hasGEMM(const sd::DataType dtype) { - if(dtype == DataType::FLOAT32) { - if (sd::Environment::getInstance().blasFallback()) - return false; - - #if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS) - return true; - #else - return _hasSgemm; - #endif - } - if(dtype == DataType::DOUBLE) { - if (sd::Environment::getInstance().blasFallback()) - return false; - - #if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS) - return true; - #else - return _hasDgemm; - #endif - } - return false; - } - - - template <> - bool BlasHelper::hasBatchedGEMM() { - if (sd::Environment::getInstance().blasFallback()) - return false; - - return _hasSgemmBatch; - } - - template <> - bool BlasHelper::hasBatchedGEMM() { - if (sd::Environment::getInstance().blasFallback()) - return false; - - return _hasDgemmBatch; - } - - template <> - bool BlasHelper::hasBatchedGEMM() { - return false; - } - - template <> - bool BlasHelper::hasBatchedGEMM() { - return false; - } - - template <> - bool BlasHelper::hasBatchedGEMM() { - return false; - } - - template <> - bool BlasHelper::hasBatchedGEMM() { - return false; - } - - template <> - bool BlasHelper::hasBatchedGEMM() { - return false; - } - - template <> - bool BlasHelper::hasBatchedGEMM() { - return false; - } - - template <> - bool BlasHelper::hasBatchedGEMM() { - return false; - } - - template <> - bool BlasHelper::hasBatchedGEMM() { - return false; - } - - CblasSgemv BlasHelper::sgemv() { -#if defined(__EXTERNAL_BLAS__)|| defined(HAVE_OPENBLAS) - return (CblasSgemv)&cblas_sgemv; + } + return false; +} + +template <> +bool BlasHelper::hasBatchedGEMM() { + if (sd::Environment::getInstance().blasFallback()) return false; + + return _hasSgemmBatch; +} + +template <> +bool BlasHelper::hasBatchedGEMM() { + if (sd::Environment::getInstance().blasFallback()) return false; + + return _hasDgemmBatch; +} + +template <> +bool BlasHelper::hasBatchedGEMM() { + return false; +} + +template <> +bool BlasHelper::hasBatchedGEMM() { + return false; +} + +template <> +bool BlasHelper::hasBatchedGEMM() { + return false; +} + +template <> +bool BlasHelper::hasBatchedGEMM() { + return false; +} + +template <> +bool BlasHelper::hasBatchedGEMM() { + return false; +} + +template <> +bool BlasHelper::hasBatchedGEMM() { + return false; +} + +template <> +bool BlasHelper::hasBatchedGEMM() { + return false; +} + +template <> +bool BlasHelper::hasBatchedGEMM() { + return false; +} + +CblasSgemv BlasHelper::sgemv() { +#if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS) + return (CblasSgemv)&cblas_sgemv; #else - return this->cblasSgemv; + return this->cblasSgemv; #endif - } - CblasDgemv BlasHelper::dgemv() { +} +CblasDgemv BlasHelper::dgemv() { #if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS) - return (CblasDgemv)&cblas_dgemv; + return (CblasDgemv)&cblas_dgemv; #else - return this->cblasDgemv; + return this->cblasDgemv; #endif - } +} - CblasSgemm BlasHelper::sgemm() { +CblasSgemm BlasHelper::sgemm() { #if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS) - return (CblasSgemm)&cblas_sgemm; + return (CblasSgemm)&cblas_sgemm; #else - return this->cblasSgemm; + return this->cblasSgemm; #endif - } +} - CblasDgemm BlasHelper::dgemm() { +CblasDgemm BlasHelper::dgemm() { #if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS) - return (CblasDgemm)&cblas_dgemm; + return (CblasDgemm)&cblas_dgemm; #else - return this->cblasDgemm; + return this->cblasDgemm; #endif - } +} - CblasSgemmBatch BlasHelper::sgemmBatched() { - return this->cblasSgemmBatch; - } +CblasSgemmBatch BlasHelper::sgemmBatched() { return this->cblasSgemmBatch; } - CblasDgemmBatch BlasHelper::dgemmBatched() { - return this->cblasDgemmBatch; - } +CblasDgemmBatch BlasHelper::dgemmBatched() { return this->cblasDgemmBatch; } - LapackeSgesvd BlasHelper::sgesvd() { - return this->lapackeSgesvd; - } +LapackeSgesvd BlasHelper::sgesvd() { return this->lapackeSgesvd; } - LapackeDgesvd BlasHelper::dgesvd() { - return this->lapackeDgesvd; - } +LapackeDgesvd BlasHelper::dgesvd() { return this->lapackeDgesvd; } - LapackeSgesdd BlasHelper::sgesdd() { - return this->lapackeSgesdd; - } +LapackeSgesdd BlasHelper::sgesdd() { return this->lapackeSgesdd; } - LapackeDgesdd BlasHelper::dgesdd() { - return this->lapackeDgesdd; - } +LapackeDgesdd BlasHelper::dgesdd() { return this->lapackeDgesdd; } - // destructor - BlasHelper::~BlasHelper() noexcept { } -} +// destructor +BlasHelper::~BlasHelper() noexcept {} + + +} // namespace sd diff --git a/libnd4j/include/helpers/impl/CudaLaunchHelper.cpp b/libnd4j/include/helpers/impl/CudaLaunchHelper.cpp index d0bcce11e62b..3c4b91d7fb5d 100644 --- a/libnd4j/include/helpers/impl/CudaLaunchHelper.cpp +++ b/libnd4j/include/helpers/impl/CudaLaunchHelper.cpp @@ -22,20 +22,20 @@ #include namespace sd { - Triple CudaLaunchHelper::getFlatLaunchParams(Nd4jLong length, int SM, int CORES, int SHARED_MEMORY) { - // TODO: to be implemented - Triple triple(1, 2, 3); +Triple CudaLaunchHelper::getFlatLaunchParams(Nd4jLong length, int SM, int CORES, + int SHARED_MEMORY) { + // TODO: to be implemented + Triple triple(1, 2, 3); - return triple; - } + return triple; +} - int CudaLaunchHelper::getReductionBlocks(Nd4jLong xLength, int blockSize) { - int div = xLength / blockSize; - int can = sd::math::nd4j_max(div, 1); - if (xLength % blockSize != 0 && xLength > blockSize) - can++; +int CudaLaunchHelper::getReductionBlocks(Nd4jLong xLength, int blockSize) { + int div = xLength / blockSize; + int can = sd::math::nd4j_max(div, 1); + if (xLength % blockSize != 0 && xLength > blockSize) can++; - // not more then 512 blocks - return sd::math::nd4j_min(can, 512); - } + // not more then 512 blocks + return sd::math::nd4j_min(can, 512); } +} // namespace sd diff --git a/libnd4j/include/helpers/impl/DebugHelper.cpp b/libnd4j/include/helpers/impl/DebugHelper.cpp index d24068a65484..cd8df4261971 100644 --- a/libnd4j/include/helpers/impl/DebugHelper.cpp +++ b/libnd4j/include/helpers/impl/DebugHelper.cpp @@ -18,92 +18,97 @@ // Created by raver119 on 20/04/18. // -#include #include #include -#include -#include #include +#include +#include +#include namespace sd { - DebugInfo DebugHelper::debugStatistics(NDArray const* input) { - DebugInfo info; - DebugHelper::retrieveDebugStatistics(&info, input); - return info; - } - void - DebugHelper::retrieveDebugStatistics(DebugInfo* info, NDArray const* input) { - if (nullptr == info) - return; - - info->_minValue = 0.; - info->_maxValue = -1; - info->_meanValue = 0.; - info->_stdDevValue = 1.; - info->_zeroCount = 0; - info->_positiveCount = 0; - info->_negativeCount = 0; - info->_infCount = 0; - info->_nanCount = 0; - if (input->lengthOf() == 1) { // scalar case - info->_minValue = input->e(0); - info->_maxValue = info->_minValue; - info->_meanValue = info->_minValue; - info->_stdDevValue = info->_minValue; - info->_zeroCount = sd::math::nd4j_abs(input->e(0)) > 0.00001? 0: 1; - info->_positiveCount = input->e(0) > 0?1:0; - info->_negativeCount = input->e(0) < 0?1:0; - info->_infCount = sd::math::nd4j_isinf(input->e(0)); - info->_nanCount = sd::math::nd4j_isnan(input->e(0)); - } - else if (input->lengthOf() > 0) { - // TO DO: here processing for all elements with array - auto _minValue = input->e(0); - auto _maxValue = input->e(0); - auto _meanValue = input->e(0); - auto _stdDevValue = 0.; //info->_minValue; - auto _zeroCount = sd::math::nd4j_abs(input->e(0)) > 0.00001? 0L : 1L; - auto _positiveCount = input->e(0) > 0? 1L : 0L; - auto _negativeCount = input->e(0) < 0? 1L : 0L; - auto _infCount = sd::math::nd4j_isinf(input->e(0)) ? 1L : 0L; - auto _nanCount = sd::math::nd4j_isnan(input->e(0)) ? 1L : 0L; +DebugInfo DebugHelper::debugStatistics(NDArray const* input) { + DebugInfo info; + DebugHelper::retrieveDebugStatistics(&info, input); + return info; +} +void DebugHelper::retrieveDebugStatistics(DebugInfo* info, + NDArray const* input) { + if (nullptr == info) return; -PRAGMA_OMP_PARALLEL_FOR_ARGS(schedule(guided) reduction(+:_nanCount,_infCount,_meanValue,_zeroCount,_positiveCount,_negativeCount) reduction(min:_minValue) reduction(max:_maxValue)) - for (Nd4jLong e = 1; e < input->lengthOf(); e++) { - auto current = input->e(e); - auto n = e + 1.; -// auto delta = current - _meanValue; -// auto delta2 = delta * delta; - _minValue = sd::math::nd4j_min(current, _minValue); - _maxValue = sd::math::nd4j_max(current, _maxValue); + info->_minValue = 0.; + info->_maxValue = -1; + info->_meanValue = 0.; + info->_stdDevValue = 1.; + info->_zeroCount = 0; + info->_positiveCount = 0; + info->_negativeCount = 0; + info->_infCount = 0; + info->_nanCount = 0; + if (input->lengthOf() == 1) { // scalar case + info->_minValue = input->e(0); + info->_maxValue = info->_minValue; + info->_meanValue = info->_minValue; + info->_stdDevValue = info->_minValue; + info->_zeroCount = + sd::math::nd4j_abs(input->e(0)) > 0.00001 ? 0 : 1; + info->_positiveCount = input->e(0) > 0 ? 1 : 0; + info->_negativeCount = input->e(0) < 0 ? 1 : 0; + info->_infCount = sd::math::nd4j_isinf(input->e(0)); + info->_nanCount = sd::math::nd4j_isnan(input->e(0)); + } else if (input->lengthOf() > 0) { + // TO DO: here processing for all elements with array + auto _minValue = input->e(0); + auto _maxValue = input->e(0); + auto _meanValue = input->e(0); + auto _stdDevValue = 0.; // info->_minValue; + auto _zeroCount = + sd::math::nd4j_abs(input->e(0)) > 0.00001 ? 0L : 1L; + auto _positiveCount = input->e(0) > 0 ? 1L : 0L; + auto _negativeCount = input->e(0) < 0 ? 1L : 0L; + auto _infCount = sd::math::nd4j_isinf(input->e(0)) ? 1L : 0L; + auto _nanCount = sd::math::nd4j_isnan(input->e(0)) ? 1L : 0L; - _meanValue += current; - //_meanValue += delta / n; // this is a perfect formula but not working with omp in this notation - //_stdDevValue += delta2 * e / n; + PRAGMA_OMP_PARALLEL_FOR_ARGS(schedule(guided) reduction(+:_nanCount,_infCount,_meanValue,_zeroCount,_positiveCount,_negativeCount) reduction(min:_minValue) reduction(max:_maxValue)) + for (Nd4jLong e = 1; e < input->lengthOf(); e++) { + auto current = input->e(e); + auto n = e + 1.; + // auto delta = current - _meanValue; + // auto delta2 = delta * delta; + _minValue = sd::math::nd4j_min(current, _minValue); + _maxValue = sd::math::nd4j_max(current, _maxValue); - _zeroCount += sd::math::nd4j_abs(current) > 0.00001 ? 0 : 1; - _positiveCount += current > 0 ? 1 : 0; - _negativeCount += current < 0 ? 1 : 0; - _infCount += sd::math::nd4j_isinf(current); - _nanCount += sd::math::nd4j_isnan(current); - } - *info = {_minValue, _maxValue, _meanValue / input->lengthOf(), _stdDevValue, _zeroCount, _positiveCount, _negativeCount, _infCount, _nanCount}; - _stdDevValue = 0; //math::nd4j_sqrt(info->_stdDevValue / (input->lengthOf() - 1)); + _meanValue += current; + //_meanValue += delta / n; // this is a perfect formula but not working + // with omp in this notation _stdDevValue += delta2 * e / n; - auto func = PRAGMA_REDUCE_DOUBLE { - auto _stdDevValue = 0.0; - for (auto e = start; e < stop; e++) { - double current = input->e(e); - _stdDevValue += (info->_meanValue - current) * (info->_meanValue - current); //info->_minValue; - } + _zeroCount += sd::math::nd4j_abs(current) > 0.00001 ? 0 : 1; + _positiveCount += current > 0 ? 1 : 0; + _negativeCount += current < 0 ? 1 : 0; + _infCount += sd::math::nd4j_isinf(current); + _nanCount += sd::math::nd4j_isnan(current); + } + *info = {_minValue, _maxValue, _meanValue / input->lengthOf(), + _stdDevValue, _zeroCount, _positiveCount, + _negativeCount, _infCount, _nanCount}; + _stdDevValue = 0; // math::nd4j_sqrt(info->_stdDevValue / + // (input->lengthOf() - 1)); - return _stdDevValue; - }; - _stdDevValue = samediff::Threads::parallel_double(func, LAMBDA_AD { return _old + _new; }, 0, input->lengthOf()); + auto func = PRAGMA_REDUCE_DOUBLE { + auto _stdDevValue = 0.0; + for (auto e = start; e < stop; e++) { + double current = input->e(e); + _stdDevValue += (info->_meanValue - current) * + (info->_meanValue - current); // info->_minValue; + } - info->_stdDevValue = math::nd4j_sqrt(_stdDevValue / input->lengthOf()); + return _stdDevValue; + }; + _stdDevValue = samediff::Threads::parallel_double( + func, LAMBDA_AD { return _old + _new; }, 0, input->lengthOf()); - } -// else - no statistics for empty - } + info->_stdDevValue = + math::nd4j_sqrt(_stdDevValue / input->lengthOf()); + } + // else - no statistics for empty } +} // namespace sd diff --git a/libnd4j/include/helpers/impl/EigenValsAndVecs.cpp b/libnd4j/include/helpers/impl/EigenValsAndVecs.cpp index 6eeb0c28bff1..8f11363cb7ec 100644 --- a/libnd4j/include/helpers/impl/EigenValsAndVecs.cpp +++ b/libnd4j/include/helpers/impl/EigenValsAndVecs.cpp @@ -283,10 +283,10 @@ void EigenValsAndVecs::calcEigenVecs(const NDArray& schurMatrixU) { } -template class ND4J_EXPORT EigenValsAndVecs; -template class ND4J_EXPORT EigenValsAndVecs; -template class ND4J_EXPORT EigenValsAndVecs; -template class ND4J_EXPORT EigenValsAndVecs; +template class SD_EXPORT EigenValsAndVecs; +template class SD_EXPORT EigenValsAndVecs; +template class SD_EXPORT EigenValsAndVecs; +template class SD_EXPORT EigenValsAndVecs; } } diff --git a/libnd4j/include/helpers/impl/EnumUtils.cpp b/libnd4j/include/helpers/impl/EnumUtils.cpp index a18592d0b962..c848a7407d94 100644 --- a/libnd4j/include/helpers/impl/EnumUtils.cpp +++ b/libnd4j/include/helpers/impl/EnumUtils.cpp @@ -24,55 +24,91 @@ using namespace sd::graph; namespace sd { - const char * EnumUtils::_VariableTypeToString(sd::graph::VariableType variableType) { - switch (variableType) { - case NDARRAY: return "NDARRAY"; - case ARRAY_LIST: return "ARRAY_LIST"; - case FLOW: return "FLOW"; - default: return "UNKNOWN VariableType"; - } - } +const char* EnumUtils::_VariableTypeToString( + sd::graph::VariableType variableType) { + switch (variableType) { + case NDARRAY: + return "NDARRAY"; + case ARRAY_LIST: + return "ARRAY_LIST"; + case FLOW: + return "FLOW"; + default: + return "UNKNOWN VariableType"; + } +} - const char * EnumUtils::_OpTypeToString(sd::graph::OpType opType) { - switch(opType) { - case OpType_REDUCE_SAME: return "REDUCE_SAME"; - case OpType_REDUCE_BOOL: return "REDUCE_BOOL"; - case OpType_REDUCE_LONG: return "REDUCE_LONG"; - case OpType_REDUCE_FLOAT: return "REDUCE_FLOAT"; - case OpType_BOOLEAN: return "BOOLEAN"; - case OpType_BROADCAST: return "BROADCAST"; - case OpType_BROADCAST_BOOL: return "BROADCAST_BOOL"; - case OpType_PAIRWISE: return "PAIRWISE"; - case OpType_PAIRWISE_BOOL: return "PAIRWISE_BOOL"; - case OpType_CUSTOM: return "CUSTOM"; - case OpType_LOGIC: return "LOGIC"; - case OpType_TRANSFORM_SAME: return "TRANSFORM_SAME"; - case OpType_TRANSFORM_FLOAT: return "TRANSFORM_FLOAT"; - case OpType_TRANSFORM_BOOL: return "TRANSFORM_BOOL"; - case OpType_TRANSFORM_STRICT: return "TRANSFORM_STRICT"; - case OpType_TRANSFORM_ANY: return "TRANSFORM_ANY"; - case OpType_INDEX_REDUCE: return "INDEX_ACCUMULATION"; - case OpType_SCALAR: return "SCALAR"; - case OpType_SCALAR_BOOL: return "SCALAR_BOOL"; - case OpType_SHAPE: return "SHAPE"; - default: return "UNKNOWN OpType"; - } - } +const char* EnumUtils::_OpTypeToString(sd::graph::OpType opType) { + switch (opType) { + case OpType_REDUCE_SAME: + return "REDUCE_SAME"; + case OpType_REDUCE_BOOL: + return "REDUCE_BOOL"; + case OpType_REDUCE_LONG: + return "REDUCE_LONG"; + case OpType_REDUCE_FLOAT: + return "REDUCE_FLOAT"; + case OpType_BOOLEAN: + return "BOOLEAN"; + case OpType_BROADCAST: + return "BROADCAST"; + case OpType_BROADCAST_BOOL: + return "BROADCAST_BOOL"; + case OpType_PAIRWISE: + return "PAIRWISE"; + case OpType_PAIRWISE_BOOL: + return "PAIRWISE_BOOL"; + case OpType_CUSTOM: + return "CUSTOM"; + case OpType_LOGIC: + return "LOGIC"; + case OpType_TRANSFORM_SAME: + return "TRANSFORM_SAME"; + case OpType_TRANSFORM_FLOAT: + return "TRANSFORM_FLOAT"; + case OpType_TRANSFORM_BOOL: + return "TRANSFORM_BOOL"; + case OpType_TRANSFORM_STRICT: + return "TRANSFORM_STRICT"; + case OpType_TRANSFORM_ANY: + return "TRANSFORM_ANY"; + case OpType_INDEX_REDUCE: + return "INDEX_ACCUMULATION"; + case OpType_SCALAR: + return "SCALAR"; + case OpType_SCALAR_BOOL: + return "SCALAR_BOOL"; + case OpType_SHAPE: + return "SHAPE"; + default: + return "UNKNOWN OpType"; + } +} - - const char * EnumUtils::_LogicOpToString(int opNum) { - switch(opNum) { - case 0: return "WHILE"; - case 10: return "SCOPE"; - case 20: return "CONDITIONAL"; - case 30: return "SWITCH"; - case 40: return "RETURN"; - case 60: return "MERGE"; - case 70: return "LOOP_COND"; - case 80: return "NEXT_ITERATION"; - case 90: return "EXIT"; - case 100: return "ENTER"; - default: return "UNKNOWN OPERATION"; - } - } -} \ No newline at end of file +const char* EnumUtils::_LogicOpToString(int opNum) { + switch (opNum) { + case 0: + return "WHILE"; + case 10: + return "SCOPE"; + case 20: + return "CONDITIONAL"; + case 30: + return "SWITCH"; + case 40: + return "RETURN"; + case 60: + return "MERGE"; + case 70: + return "LOOP_COND"; + case 80: + return "NEXT_ITERATION"; + case 90: + return "EXIT"; + case 100: + return "ENTER"; + default: + return "UNKNOWN OPERATION"; + } +} +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/helpers/impl/FileUtils.cpp b/libnd4j/include/helpers/impl/FileUtils.cpp new file mode 100644 index 000000000000..c0e02db27225 --- /dev/null +++ b/libnd4j/include/helpers/impl/FileUtils.cpp @@ -0,0 +1,38 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include +#include +#include +#include + +namespace sd { +bool FileUtils::fileExists(const char *filename) { + if (filename == nullptr) return false; + + return file_exists(filename); +} + +int64_t FileUtils::fileSize(const char *filename) { + struct stat stat_buf; + int rc = stat(filename, &stat_buf); + return rc == 0 ? stat_buf.st_size : -1; +} +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/helpers/impl/FullPivLU.cpp b/libnd4j/include/helpers/impl/FullPivLU.cpp index efb7571ed0a4..6e7993b1bc50 100644 --- a/libnd4j/include/helpers/impl/FullPivLU.cpp +++ b/libnd4j/include/helpers/impl/FullPivLU.cpp @@ -160,10 +160,10 @@ void FullPivLU::solve(const NDArray& A, const NDArray& b, NDArray& x) { x({colsPermut[i],colsPermut[i]+1, 0,0}, true).nullify(); } -template class ND4J_EXPORT FullPivLU; -template class ND4J_EXPORT FullPivLU; -template class ND4J_EXPORT FullPivLU; -template class ND4J_EXPORT FullPivLU; +template class SD_EXPORT FullPivLU; +template class SD_EXPORT FullPivLU; +template class SD_EXPORT FullPivLU; +template class SD_EXPORT FullPivLU; } } diff --git a/libnd4j/include/helpers/impl/GradCheck.cpp b/libnd4j/include/helpers/impl/GradCheck.cpp index 12ecab75f033..e9ac5f0d5c98 100644 --- a/libnd4j/include/helpers/impl/GradCheck.cpp +++ b/libnd4j/include/helpers/impl/GradCheck.cpp @@ -18,134 +18,156 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 16.07.2018 // -#include #include - +#include namespace sd { ////////////////////////////////////////////////////////////////////////// -void GradCheck::fillGradArrays(const LossFunc loss, const std::vector& gradArrs) { - - const int numInGradArrs = gradArrs.size(); - - // fill input gradient arrays in accordance to type of loss function - switch(loss) { - - case MEAN: - for(int i = 0; i < numInGradArrs; ++i) - *gradArrs[i] = 1. / gradArrs[i]->lengthOf(); - break; - - case SUM: - for(int i = 0; i < numInGradArrs; ++i) - *gradArrs[i] = 1.; - break; - - default: - throw std::invalid_argument("GradCheck::fillGradArrays: invalid type of loss function !"); - } +void GradCheck::fillGradArrays(const LossFunc loss, + const std::vector& gradArrs) { + const int numInGradArrs = gradArrs.size(); + + // fill input gradient arrays in accordance to type of loss function + switch (loss) { + case MEAN: + for (int i = 0; i < numInGradArrs; ++i) + *gradArrs[i] = 1. / gradArrs[i]->lengthOf(); + break; + + case SUM: + for (int i = 0; i < numInGradArrs; ++i) *gradArrs[i] = 1.; + break; + + default: + throw std::invalid_argument( + "GradCheck::fillGradArrays: invalid type of loss function !"); + } } ////////////////////////////////////////////////////////////////////////// -bool GradCheck::checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, const OpArgsHolder& argsHolderFF, const OpArgsHolder& argsHolderBP, - const std::vector& whatArrsToCheck, const std::vector& idxRange, const LossFunc loss) { - - const int numInArrsFF = argsHolderFF.getNumInArrs(); // at the same time numInArrsFF = number of output arrays in opBP - const int numInGradArrsBP = argsHolderBP.getNumInArrs() - numInArrsFF; // because argsHolderBP.getNumInArrs() = numInArrsFF + numInGradArrsBP - const std::vector& inArrsFF = argsHolderFF.getInArrs(); - const std::vector& inArrsBP = argsHolderBP.getInArrs(); - - // fill input gradient arrays in accordance to kind of loss function - fillGradArrays(loss, std::vector(&inArrsBP[numInArrsFF], &inArrsBP[numInArrsFF + numInGradArrsBP])); - - // back prop pass - ResultSet outArrsBP = opBP.execute(argsHolderBP); // number of output arrays in back prop = numInArrsFF; - - NDArray tmpScalar(sd::DataType::DOUBLE, inArrsFF[0]->getContext()); // scalar = 0 - - for(int i = 0; i < numInArrsFF; ++i) { // loop through input array - - if(!whatArrsToCheck.empty() && static_cast(whatArrsToCheck[i]) == false) - continue; - - const Nd4jLong idxStart = static_cast(idxRange[0] * inArrsFF[i]->lengthOf()); - const Nd4jLong idxEnd = static_cast(idxRange[1] * inArrsFF[i]->lengthOf()); - - for(Nd4jLong j = idxStart; j < idxEnd; ++j) { // loop through all elements for current array - - const double orig = inArrsFF[i]->e(j); - - // add epsilon, feed forward - inArrsFF[i]->p(j, orig + EPSILON); - ResultSet outArrsFF = opFF.execute(argsHolderFF); - int numOutArrs = outArrsFF.size(); - double scorePlus = 0.; - - for(int k = 0; k < numOutArrs; ++k) { // loop through output arrays - if(loss == SUM) - outArrsFF.at(k)->reduceNumber(reduce::Sum, tmpScalar); - else - outArrsFF.at(k)->reduceNumber(reduce::Mean, tmpScalar); - scorePlus += tmpScalar.e(0); - } - - // subtract epsilon, feed forward - inArrsFF[i]->p(j, orig - EPSILON); - outArrsFF = opFF.execute(argsHolderFF); - double scoreMinus = 0.; - - for(int k = 0; k < numOutArrs; ++k) { // loop through output arrays - if(loss == SUM) - outArrsFF.at(k)->reduceNumber(reduce::Sum, tmpScalar); - else - outArrsFF.at(k)->reduceNumber(reduce::Mean, tmpScalar); - scoreMinus += tmpScalar.e(0); - } - - // restore initial element value - inArrsFF[i]->p(j, orig); - - // calculate numerical gradient - const double numericalGrad = (scorePlus - scoreMinus) / (2 * EPSILON); - if(std::isnan(numericalGrad) || std::isinf(numericalGrad)) { - printf("GradCheck::checkGrad: got wrong value for numerical gradient for input array # %i and its element at position %lld ! \n", i, j); - throw std::runtime_error(""); - } - - // get analytical gradient - const double analyticGrad = outArrsBP.at(i)->e(j); - if(std::isnan(analyticGrad) || std::isinf(analyticGrad)) { - printf("GradCheck::checkGrad: got wrong value for analytical gradient for input array # %i and its element at position %lld ! \n", i, j); - throw std::runtime_error(""); - } - - // printf("%lld: num = %.15f, ana = %.15f\n", j, numericalGrad, analyticGrad); - - // calculate relative error - double relError; - if(numericalGrad == 0. && analyticGrad == 0.) - relError = 0.; - else - relError = math::nd4j_abs(analyticGrad - numericalGrad) / (math::nd4j_abs(analyticGrad) + math::nd4j_abs(numericalGrad)); - - // verify result - if(relError > MAXRELERR || std::isnan(relError)) { - - if(math::nd4j_abs(analyticGrad - numericalGrad) < MINABSERR) - continue; - printf("numericalGrad = %.15f, analyticGrad = %.15f \n", numericalGrad, analyticGrad); - printf("GradCheck::checkGrad: got RELERROR = %f > MAXRELERROR(%f) for input array # %i and its element at position %lld ! \n", relError, MAXRELERR, i, j); - return false; - } - } - } - - return true; -} - - - +bool GradCheck::checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, + const OpArgsHolder& argsHolderFF, + const OpArgsHolder& argsHolderBP, + const std::vector& whatArrsToCheck, + const std::vector& idxRange, + const LossFunc loss) { + const int numInArrsFF = + argsHolderFF.getNumInArrs(); // at the same time numInArrsFF = number of + // output arrays in opBP + const int numInGradArrsBP = + argsHolderBP.getNumInArrs() - + numInArrsFF; // because argsHolderBP.getNumInArrs() = numInArrsFF + + // numInGradArrsBP + const std::vector& inArrsFF = argsHolderFF.getInArrs(); + const std::vector& inArrsBP = argsHolderBP.getInArrs(); + + // fill input gradient arrays in accordance to kind of loss function + fillGradArrays( + loss, std::vector(&inArrsBP[numInArrsFF], + &inArrsBP[numInArrsFF + numInGradArrsBP])); + + // back prop pass + ResultSet outArrsBP = opBP.execute( + argsHolderBP); // number of output arrays in back prop = numInArrsFF; + + NDArray tmpScalar(sd::DataType::DOUBLE, + inArrsFF[0]->getContext()); // scalar = 0 + + for (int i = 0; i < numInArrsFF; ++i) { // loop through input array + + if (!whatArrsToCheck.empty() && + static_cast(whatArrsToCheck[i]) == false) + continue; + + const Nd4jLong idxStart = + static_cast(idxRange[0] * inArrsFF[i]->lengthOf()); + const Nd4jLong idxEnd = + static_cast(idxRange[1] * inArrsFF[i]->lengthOf()); + + for (Nd4jLong j = idxStart; j < idxEnd; + ++j) { // loop through all elements for current array + + const double orig = inArrsFF[i]->e(j); + + // add epsilon, feed forward + inArrsFF[i]->p(j, orig + EPSILON); + ResultSet outArrsFF = opFF.execute(argsHolderFF); + int numOutArrs = outArrsFF.size(); + double scorePlus = 0.; + + for (int k = 0; k < numOutArrs; ++k) { // loop through output arrays + if (loss == SUM) + outArrsFF.at(k).reduceNumber(reduce::Sum, tmpScalar); + else + outArrsFF.at(k).reduceNumber(reduce::Mean, tmpScalar); + scorePlus += tmpScalar.e(0); + } + + // subtract epsilon, feed forward + inArrsFF[i]->p(j, orig - EPSILON); + outArrsFF = opFF.execute(argsHolderFF); + double scoreMinus = 0.; + + for (int k = 0; k < numOutArrs; ++k) { // loop through output arrays + if (loss == SUM) + outArrsFF.at(k).reduceNumber(reduce::Sum, tmpScalar); + else + outArrsFF.at(k).reduceNumber(reduce::Mean, tmpScalar); + scoreMinus += tmpScalar.e(0); + } + + // restore initial element value + inArrsFF[i]->p(j, orig); + + // calculate numerical gradient + const double numericalGrad = (scorePlus - scoreMinus) / (2 * EPSILON); + if (std::isnan(numericalGrad) || std::isinf(numericalGrad)) { + printf( + "GradCheck::checkGrad: got wrong value for numerical gradient for " + "input array # %i and its element at position %lld ! \n", + i, j); + throw std::runtime_error(""); + } + + // get analytical gradient + const double analyticGrad = outArrsBP.at(i).e(j); + if (std::isnan(analyticGrad) || std::isinf(analyticGrad)) { + printf( + "GradCheck::checkGrad: got wrong value for analytical gradient for " + "input array # %i and its element at position %lld ! \n", + i, j); + throw std::runtime_error(""); + } + + // printf("%lld: num = %.15f, ana = %.15f\n", j, numericalGrad, + // analyticGrad); + + // calculate relative error + double relError; + if (numericalGrad == 0. && analyticGrad == 0.) + relError = 0.; + else + relError = math::nd4j_abs(analyticGrad - numericalGrad) / + (math::nd4j_abs(analyticGrad) + + math::nd4j_abs(numericalGrad)); + + // verify result + if (relError > MAXRELERR || std::isnan(relError)) { + if (math::nd4j_abs(analyticGrad - numericalGrad) < MINABSERR) + continue; + printf("numericalGrad = %.15f, analyticGrad = %.15f \n", numericalGrad, + analyticGrad); + printf( + "GradCheck::checkGrad: got RELERROR = %f > MAXRELERROR(%f) for " + "input array # %i and its element at position %lld ! \n", + relError, MAXRELERR, i, j); + return false; + } + } + } + + return true; } - +} // namespace sd diff --git a/libnd4j/include/helpers/impl/HessenbergAndSchur.cpp b/libnd4j/include/helpers/impl/HessenbergAndSchur.cpp index 31495cab9f10..1422cf261ceb 100644 --- a/libnd4j/include/helpers/impl/HessenbergAndSchur.cpp +++ b/libnd4j/include/helpers/impl/HessenbergAndSchur.cpp @@ -368,15 +368,15 @@ void Schur::calcFromHessenberg() { } } -template class ND4J_EXPORT Hessenberg; -template class ND4J_EXPORT Hessenberg; -template class ND4J_EXPORT Hessenberg; -template class ND4J_EXPORT Hessenberg; - -template class ND4J_EXPORT Schur; -template class ND4J_EXPORT Schur; -template class ND4J_EXPORT Schur; -template class ND4J_EXPORT Schur; +template class SD_EXPORT Hessenberg; +template class SD_EXPORT Hessenberg; +template class SD_EXPORT Hessenberg; +template class SD_EXPORT Hessenberg; + +template class SD_EXPORT Schur; +template class SD_EXPORT Schur; +template class SD_EXPORT Schur; +template class SD_EXPORT Schur; } } diff --git a/libnd4j/include/helpers/impl/MmulHelper.cpp b/libnd4j/include/helpers/impl/MmulHelper.cpp index ba86bb1b5436..700f551213ee 100644 --- a/libnd4j/include/helpers/impl/MmulHelper.cpp +++ b/libnd4j/include/helpers/impl/MmulHelper.cpp @@ -22,292 +22,377 @@ #define LIBND4J_MMULHELPER_CPP #include "../MmulHelper.h" -#include -#include + #include +#include +#include namespace sd { ////////////////////////////////////////////////////////////////////////// -sd::NDArray* sd::MmulHelper::tensorDot(const sd::NDArray* A, const sd::NDArray* B, const std::initializer_list& axesA, const std::initializer_list& axesB) { - std::vector aA(axesA); - std::vector aB(axesB); - return tensorDot(A, B, aA, aB); +sd::NDArray* sd::MmulHelper::tensorDot( + const sd::NDArray* A, const sd::NDArray* B, + const std::initializer_list& axesA, + const std::initializer_list& axesB) { + std::vector aA(axesA); + std::vector aB(axesB); + return tensorDot(A, B, aA, aB); } ////////////////////////////////////////////////////////////////////////// -sd::NDArray* sd::MmulHelper::tensorDot(const sd::NDArray* a, const sd::NDArray* b, const std::vector& axes_0, const std::vector& axes_1) { - - std::vector permutAt, permutBt; - std::vector shapeAt, shapeBt; - - auto outShape = ShapeUtils::evalShapeForTensorDot(a, b, axes_0, axes_1, permutAt, permutBt, shapeAt, shapeBt); - - // check whether permutation is necessary - const NDArray* aP = permutAt.empty() ? a : new NDArray(a->permute(permutAt)); - const NDArray* bP = permutBt.empty() ? b : new NDArray(b->permute(permutBt)); - - // check whether reshape is necessary - const NDArray* aPR = aP->isSameShape(shapeAt) ? aP : new NDArray(aP->reshape(aP->ordering(), shapeAt)); - const NDArray* bPR = bP->isSameShape(shapeAt) ? bP : new NDArray(bP->reshape(bP->ordering(), shapeBt)); - - NDArray* c = mmul(aPR, bPR, nullptr, 1.0, 0.0); - - c->reshapei(outShape); - - if(aP != aPR) - delete aPR; - if(bP != bPR) - delete bPR; - if(a != aP) - delete aP; - if(b != bP) - delete bP; - - return c; +sd::NDArray* sd::MmulHelper::tensorDot(const sd::NDArray* a, + const sd::NDArray* b, + const std::vector& axes_0, + const std::vector& axes_1) { + std::vector permutAt, permutBt; + std::vector shapeAt, shapeBt; + + auto outShape = ShapeUtils::evalShapeForTensorDot( + a, b, axes_0, axes_1, permutAt, permutBt, shapeAt, shapeBt); + + // check whether permutation is necessary + const NDArray* aP = permutAt.empty() ? a : new NDArray(a->permute(permutAt)); + const NDArray* bP = permutBt.empty() ? b : new NDArray(b->permute(permutBt)); + + // check whether reshape is necessary + const NDArray* aPR = aP->isSameShape(shapeAt) + ? aP + : new NDArray(aP->reshape(aP->ordering(), shapeAt)); + const NDArray* bPR = bP->isSameShape(shapeAt) + ? bP + : new NDArray(bP->reshape(bP->ordering(), shapeBt)); + + NDArray* c = mmul(aPR, bPR, nullptr, 1.0, 0.0); + + c->reshapei(outShape); + + if (aP != aPR) delete aPR; + if (bP != bPR) delete bPR; + if (a != aP) delete aP; + if (b != bP) delete bP; + + return c; } ////////////////////////////////////////////////////////////////////////// -void sd::MmulHelper::tensorDot(const sd::NDArray* a, const sd::NDArray* b, sd::NDArray* c, const std::vector& axes_a, const std::vector& axes_b, const std::vector& permutForC) { - - std::vector permutAt, permutBt; - std::vector shapeAt, shapeBt; - ShapeUtils::evalShapeForTensorDot(a, b, axes_a, axes_b, permutAt, permutBt, shapeAt, shapeBt); - - // check whether permutation is required - NDArray* cP = permutForC.empty() ? c : new NDArray(c->permute(permutForC)); - - // check whether permutation is necessary - const NDArray* aP = permutAt.empty() ? a : new NDArray(a->permute(permutAt)); - const NDArray* bP = permutBt.empty() ? b : new NDArray(b->permute(permutBt)); - - // check whether reshape is necessary - const NDArray* aPR = aP->isSameShape(shapeAt) ? aP : new NDArray(aP->reshape(aP->ordering(), shapeAt)); - const NDArray* bPR = bP->isSameShape(shapeAt) ? bP : new NDArray(bP->reshape(bP->ordering(), shapeBt)); - - std::vector requiredCshape = {aPR->sizeAt(0), bPR->sizeAt(1)}; - - NDArray* cPR = cP->isSameShape(requiredCshape) ? cP : new NDArray(cP->reshape(cP->ordering(), requiredCshape, false)); - - mmul(aPR, bPR, cPR, 1.0, 0.0); - - if(cPR->buffer() != cP->buffer() || cPR->specialBuffer() != cP->specialBuffer() ) // this means both permute and reshape have been performed on c, cP always points on c->buffer() - cP->assign(cPR); - - if(aP != aPR) - delete aPR; - if(bP != bPR) - delete bPR; - if(a != aP) - delete aP; - if(b != bP) - delete bP; - - if(cP != cPR) - delete cPR; - if(c != cP) - delete cP; +void sd::MmulHelper::tensorDot(const sd::NDArray* a, const sd::NDArray* b, + sd::NDArray* c, const std::vector& axes_a, + const std::vector& axes_b, + const std::vector& permutForC) { + std::vector permutAt, permutBt; + std::vector shapeAt, shapeBt; + ShapeUtils::evalShapeForTensorDot(a, b, axes_a, axes_b, permutAt, permutBt, + shapeAt, shapeBt); + + // check whether permutation is required + NDArray* cP = permutForC.empty() ? c : new NDArray(c->permute(permutForC)); + + // check whether permutation is necessary + const NDArray* aP = permutAt.empty() ? a : new NDArray(a->permute(permutAt)); + const NDArray* bP = permutBt.empty() ? b : new NDArray(b->permute(permutBt)); + + // check whether reshape is necessary + const NDArray* aPR = aP->isSameShape(shapeAt) + ? aP + : new NDArray(aP->reshape(aP->ordering(), shapeAt)); + const NDArray* bPR = bP->isSameShape(shapeAt) + ? bP + : new NDArray(bP->reshape(bP->ordering(), shapeBt)); + + std::vector requiredCshape = {aPR->sizeAt(0), bPR->sizeAt(1)}; + + NDArray* cPR = + cP->isSameShape(requiredCshape) + ? cP + : new NDArray(cP->reshape(cP->ordering(), requiredCshape, false)); + + mmul(aPR, bPR, cPR, 1.0, 0.0); + + if (cPR->buffer() != cP->buffer() || + cPR->specialBuffer() != + cP->specialBuffer()) // this means both permute and reshape have been + // performed on c, cP always points on + // c->buffer() + cP->assign(cPR); + + if (aP != aPR) delete aPR; + if (bP != bPR) delete bPR; + if (a != aP) delete aP; + if (b != bP) delete bP; + + if (cP != cPR) delete cPR; + if (c != cP) delete cP; } - #ifndef __JAVACPP_HACK__ ////////////////////////////////////////////////////////////////////////// -void sd::MmulHelper::tensorDot(const NDArray* a, const NDArray* b, NDArray* c, const std::vector>& modifA, const std::vector>& modifB, const std::vector>& modifC) { - - NDArray *aPR(const_cast(a)), *bPR(const_cast(b)); - std::string whatToDoWithA, whatToDoWithB, whatToDoWithC; // "" - nothing; "p" - permutation; "r" - reshaping; "pr" - permutation+reshaping; "rp" - reshaping/permutation, and so on; if another string is produced - throw exception - - for(const auto& arr : modifA) - whatToDoWithA = (std::find(arr.begin(), arr.end(), 0) != arr.end()) ? whatToDoWithA + "p" : whatToDoWithA + "r"; // when 0 is present in arr then it is permutation array, otherwise - it is reshaping array - for(const auto& arr : modifB) - whatToDoWithB = (std::find(arr.begin(), arr.end(), 0) != arr.end()) ? whatToDoWithB + "p" : whatToDoWithB + "r"; - for(const auto& arr : modifC) - whatToDoWithC = (std::find(arr.begin(), arr.end(), 0) != arr.end()) ? whatToDoWithC + "p" : whatToDoWithC + "r"; - - // first step for a array - if(!whatToDoWithA.empty()) - aPR = (whatToDoWithA[0] == 'p') ? new NDArray(a->permute(modifA[0])) : new NDArray(a->reshape(a->ordering(), modifA[0])); - // first step for b array - if(!whatToDoWithB.empty()) - bPR = (whatToDoWithB[0] == 'p') ? new NDArray(b->permute(modifB[0])) : new NDArray(b->reshape(b->ordering(), modifB[0])); - // rest steps for a array - for(int i = 1; i < whatToDoWithA.size(); ++i) - if(whatToDoWithA[i] == 'p') aPR->permutei(modifA[i]); else aPR->reshapei(modifA[i]); - // rest steps for b array - for(int i = 1; i < whatToDoWithB.size(); ++i) - if(whatToDoWithB[i] == 'p') bPR->permutei(modifB[i]); else bPR->reshapei(modifB[i]); - - // now work with c array - std::vector cArrs = {c}; - if(!whatToDoWithC.empty()) { - cArrs = std::vector(whatToDoWithC.size()+1, c); - for(int i = 0; i < cArrs.size()-1; ++i) - cArrs[i+1] = (whatToDoWithC[i] == 'p') ? new NDArray(cArrs[i]->permute(modifC[i])) : new NDArray(cArrs[i]->reshape(c->ordering(), modifC[i], false)); // since we ignore first element in cArrs (that is cArrs[0]) then it is always equal to c +void sd::MmulHelper::tensorDot( + const NDArray* a, const NDArray* b, NDArray* c, + const std::vector>& modifA, + const std::vector>& modifB, + const std::vector>& modifC) { + NDArray *aPR(const_cast(a)), *bPR(const_cast(b)); + std::string whatToDoWithA, whatToDoWithB, + whatToDoWithC; // "" - nothing; "p" - permutation; "r" - reshaping; "pr" + // - permutation+reshaping; "rp" - reshaping/permutation, + // and so on; if another string is produced - throw + // exception + + for (const auto& arr : modifA) + whatToDoWithA = + (std::find(arr.begin(), arr.end(), 0) != arr.end()) + ? whatToDoWithA + "p" + : whatToDoWithA + + "r"; // when 0 is present in arr then it is permutation + // array, otherwise - it is reshaping array + for (const auto& arr : modifB) + whatToDoWithB = (std::find(arr.begin(), arr.end(), 0) != arr.end()) + ? whatToDoWithB + "p" + : whatToDoWithB + "r"; + for (const auto& arr : modifC) + whatToDoWithC = (std::find(arr.begin(), arr.end(), 0) != arr.end()) + ? whatToDoWithC + "p" + : whatToDoWithC + "r"; + + // first step for a array + if (!whatToDoWithA.empty()) + aPR = (whatToDoWithA[0] == 'p') + ? new NDArray(a->permute(modifA[0])) + : new NDArray(a->reshape(a->ordering(), modifA[0])); + // first step for b array + if (!whatToDoWithB.empty()) + bPR = (whatToDoWithB[0] == 'p') + ? new NDArray(b->permute(modifB[0])) + : new NDArray(b->reshape(b->ordering(), modifB[0])); + // rest steps for a array + for (int i = 1; i < whatToDoWithA.size(); ++i) + if (whatToDoWithA[i] == 'p') + aPR->permutei(modifA[i]); + else + aPR->reshapei(modifA[i]); + // rest steps for b array + for (int i = 1; i < whatToDoWithB.size(); ++i) + if (whatToDoWithB[i] == 'p') + bPR->permutei(modifB[i]); + else + bPR->reshapei(modifB[i]); + + // now work with c array + std::vector cArrs = {c}; + if (!whatToDoWithC.empty()) { + cArrs = std::vector(whatToDoWithC.size() + 1, c); + for (int i = 0; i < cArrs.size() - 1; ++i) + cArrs[i + 1] = + (whatToDoWithC[i] == 'p') + ? new NDArray(cArrs[i]->permute(modifC[i])) + : new NDArray(cArrs[i]->reshape( + c->ordering(), modifC[i], + false)); // since we ignore first element in cArrs (that is + // cArrs[0]) then it is always equal to c + } + + mmul(aPR, bPR, cArrs[cArrs.size() - 1], 1.0, 0.0); + + // check whether new buffer allocation was happened for c array + if (!whatToDoWithC.empty()) { + for (int i = cArrs.size() - 1; i > 0; --i) { + if (cArrs[i]->buffer() != cArrs[i - 1]->buffer() || + cArrs[i]->specialBuffer() != cArrs[i - 1]->specialBuffer()) + cArrs[i - 1]->assign(cArrs[i]); + delete cArrs[i]; } + } - mmul(aPR, bPR, cArrs[cArrs.size()-1], 1.0, 0.0); - - // check whether new buffer allocation was happened for c array - if(!whatToDoWithC.empty()) { - for(int i = cArrs.size()-1; i > 0; --i) { - if(cArrs[i]->buffer() != cArrs[i-1]->buffer() || cArrs[i]->specialBuffer() != cArrs[i-1]->specialBuffer()) - cArrs[i-1]->assign(cArrs[i]); - delete cArrs[i]; - } - } - - if(aPR != a) - delete aPR; - if(bPR != b) - delete bPR; + if (aPR != a) delete aPR; + if (bPR != b) delete bPR; } ////////////////////////////////////////////////////////////////////////// -NDArray* sd::MmulHelper::tensorDot(const sd::NDArray* a, const sd::NDArray* b, const std::vector>& modifA, const std::vector>& modifB) { - - NDArray *aPR(const_cast(a)), *bPR(const_cast(b)); - std::string whatToDoWithA, whatToDoWithB; // "" - nothing; "p" - permutation only; "r" - reshaping only; "pr" - permutation+reshaping; "rp" - reshaping/permutation; another string - throw exception - - for(const auto& arr : modifA) - whatToDoWithA = (std::find(arr.begin(), arr.end(), 0) != arr.end()) ? whatToDoWithA + "p" : whatToDoWithA + "r"; // when 0 is present in arr then it is permutation array, otherwise - it is reshaping array - for(const auto& arr : modifB) - whatToDoWithB = (std::find(arr.begin(), arr.end(), 0) != arr.end()) ? whatToDoWithB + "p" : whatToDoWithB + "r"; - - // first step for a array - if(!whatToDoWithA.empty()) - aPR = (whatToDoWithA[0] == 'p') ? new NDArray(a->permute(modifA[0])) : new NDArray(a->reshape(a->ordering(), modifA[0])); - // first step for b array - if(!whatToDoWithB.empty()) - bPR = (whatToDoWithB[0] == 'p') ? new NDArray(b->permute(modifB[0])) : new NDArray(b->reshape(b->ordering(), modifB[0])); - // rest steps for a array - for(int i = 1; i < whatToDoWithA.size(); ++i) - if(whatToDoWithA[i] == 'p') aPR->permutei(modifA[i]); else aPR->reshapei(modifA[i]); - // rest steps for b array - for(int i = 1; i < whatToDoWithB.size(); ++i) - if(whatToDoWithB[i] == 'p') bPR->permutei(modifB[i]); else bPR->reshapei(modifB[i]); - - NDArray* result = mmul(aPR, bPR, nullptr, 1.0, 0.0); - - if(aPR != a) - delete aPR; - if(bPR != b) - delete bPR; - return result; +NDArray* sd::MmulHelper::tensorDot( + const sd::NDArray* a, const sd::NDArray* b, + const std::vector>& modifA, + const std::vector>& modifB) { + NDArray *aPR(const_cast(a)), *bPR(const_cast(b)); + std::string whatToDoWithA, + whatToDoWithB; // "" - nothing; "p" - permutation only; "r" - reshaping + // only; "pr" - permutation+reshaping; "rp" - + // reshaping/permutation; another string - throw exception + + for (const auto& arr : modifA) + whatToDoWithA = + (std::find(arr.begin(), arr.end(), 0) != arr.end()) + ? whatToDoWithA + "p" + : whatToDoWithA + + "r"; // when 0 is present in arr then it is permutation + // array, otherwise - it is reshaping array + for (const auto& arr : modifB) + whatToDoWithB = (std::find(arr.begin(), arr.end(), 0) != arr.end()) + ? whatToDoWithB + "p" + : whatToDoWithB + "r"; + + // first step for a array + if (!whatToDoWithA.empty()) + aPR = (whatToDoWithA[0] == 'p') + ? new NDArray(a->permute(modifA[0])) + : new NDArray(a->reshape(a->ordering(), modifA[0])); + // first step for b array + if (!whatToDoWithB.empty()) + bPR = (whatToDoWithB[0] == 'p') + ? new NDArray(b->permute(modifB[0])) + : new NDArray(b->reshape(b->ordering(), modifB[0])); + // rest steps for a array + for (int i = 1; i < whatToDoWithA.size(); ++i) + if (whatToDoWithA[i] == 'p') + aPR->permutei(modifA[i]); + else + aPR->reshapei(modifA[i]); + // rest steps for b array + for (int i = 1; i < whatToDoWithB.size(); ++i) + if (whatToDoWithB[i] == 'p') + bPR->permutei(modifB[i]); + else + bPR->reshapei(modifB[i]); + + NDArray* result = mmul(aPR, bPR, nullptr, 1.0, 0.0); + + if (aPR != a) delete aPR; + if (bPR != b) delete bPR; + return result; } #endif - ////////////////////////////////////////////////////////////////////////// -sd::NDArray* MmulHelper::mmul(const sd::NDArray* A, const sd::NDArray* B, sd::NDArray* C , const double alpha, const double beta, const char outOrder) { - - int lenDim; - const int aRank = A->rankOf(); - const int bRank = B->rankOf(); - const bool isAVector = shape::isCommonVector(A->shapeInfo(), lenDim); - const bool isBVector = shape::isCommonVector(B->shapeInfo(), lenDim); - - // dot product of 2 vectors - if(A->lengthOf() == B->lengthOf() && isAVector && isBVector && (aRank != 2 || aRank == 2 && (A->isSameShape(B) || bRank == 1 && A->sizeAt(1) == 1))) // (1x1x1 * 1x1) or (1x4 * 1*4) or (4x1 * 4x1) or (4x1 * 4) - return dot(A, B, C, alpha, beta); - - // matrix x matrix - if(aRank == 2 && bRank == 2) - return mmulMxM(A, B, C, alpha, beta, outOrder); - - // matrix x vector - if(aRank == 2 && isBVector) - return mmulMxV(A, B, C, alpha, beta, outOrder); - - // vector x matrix, A{M} x B{M,N} = C{N} -> reduce to matrix x matrix A2{1,M} x B{M,N} = C2{1,N}, since there is no corresponding blas operation sgevm - if(isAVector && bRank == 2) { - NDArray* A2 = new NDArray(A->reshape(A->ordering(), {1, A->lengthOf()})); // A{M} -> A2{1,M} - NDArray* C2 = C ? new NDArray(C->reshape(C->ordering(), {1, C->lengthOf()}, false)) : nullptr; // C{N} -> C2{1,N} - auto result = mmulMxM(A2, B, C2, alpha, beta, outOrder); // result{1,N} - delete A2; - delete C2; - - if(!C) { - result->reshapei({result->lengthOf()}); // result{1,N} -> result{N} - return result; - } - return C; +sd::NDArray* MmulHelper::mmul(const sd::NDArray* A, const sd::NDArray* B, + sd::NDArray* C, const double alpha, + const double beta, const char outOrder) { + int lenDim; + const int aRank = A->rankOf(); + const int bRank = B->rankOf(); + const bool isAVector = shape::isCommonVector(A->shapeInfo(), lenDim); + const bool isBVector = shape::isCommonVector(B->shapeInfo(), lenDim); + + // dot product of 2 vectors + if (A->lengthOf() == B->lengthOf() && isAVector && isBVector && + (aRank != 2 || + aRank == 2 && + (A->isSameShape(B) || + bRank == 1 && A->sizeAt(1) == 1))) // (1x1x1 * 1x1) or (1x4 * 1*4) + // or (4x1 * 4x1) or (4x1 * 4) + return dot(A, B, C, alpha, beta); + + // matrix x matrix + if (aRank == 2 && bRank == 2) return mmulMxM(A, B, C, alpha, beta, outOrder); + + // matrix x vector + if (aRank == 2 && isBVector) return mmulMxV(A, B, C, alpha, beta, outOrder); + + // vector x matrix, A{M} x B{M,N} = C{N} -> reduce to matrix x matrix A2{1,M} + // x B{M,N} = C2{1,N}, since there is no corresponding blas operation sgevm + if (isAVector && bRank == 2) { + NDArray* A2 = new NDArray( + A->reshape(A->ordering(), {1, A->lengthOf()})); // A{M} -> A2{1,M} + NDArray* C2 = + C ? new NDArray(C->reshape(C->ordering(), {1, C->lengthOf()}, false)) + : nullptr; // C{N} -> C2{1,N} + auto result = mmulMxM(A2, B, C2, alpha, beta, outOrder); // result{1,N} + delete A2; + delete C2; + + if (!C) { + result->reshapei({result->lengthOf()}); // result{1,N} -> result{N} + return result; } + return C; + } - // batched matrix multiplication - return mmulNxN(A, B, C, alpha, beta, outOrder); + // batched matrix multiplication + return mmulNxN(A, B, C, alpha, beta, outOrder); } - ////////////////////////////////////////////////////////////////////////// - void MmulHelper::matmul(const sd::NDArray* x, const sd::NDArray* y, sd::NDArray* z, const bool transX, const bool transY, double alpha, double beta) { - int xRank = x->rankOf(); - int yRank = y->rankOf(); - - auto outShape = ShapeUtils::evalShapeForMatmul(x->shapeInfo(), y->shapeInfo(), transX, transY); - if(!z->isSameShape(outShape)) { - nd4j_printf("NDArrayFactory::matmul static method: input shape of output array is wrong, actual is %s and expected is %s ! \n", ShapeUtils::shapeAsString(z).c_str(), ShapeUtils::shapeAsString(outShape).c_str()); - throw std::invalid_argument(""); - } - - if (z->isEmpty()) - return; - - NDArray* xT(const_cast(x)), *yT(const_cast(y)), *zT(z); - - if((transX && xRank > 1) || (transY && yRank > 1)) { - const int rank = xRank >= yRank ? xRank : yRank; - std::vector permut(rank); - for (int i = 0; i < rank-2; ++i) - permut[i] = i; - permut[rank-2] = rank - 1; - permut[rank-1] = rank - 2; - - if(transX) - xT = new NDArray(x->permute(permut)); - - if(transY) - yT = new NDArray(y->permute(permut)); - } - - if(xRank <= 2 && yRank <= 2) { // dot (1Dx1D), vector-matrix (1Dx2D), matrix-vector (2Dx1D), matrix-matrix (2Dx2D) product cases - - if(xRank == 1 && yRank == 2) { // reduce vector-matrix to matrix-matrix case - xT = new NDArray(x->reshape(x->ordering(), {1, x->lengthOf()})); // please note x is not transposed in this case (since xRank=1) - zT = new NDArray(z->reshape(z->ordering(), {1, z->lengthOf()})); - } - - mmul(xT, yT, zT, alpha, beta); - } - else { // rest cases - batched mmul - - const int batchRank = xRank - 2; - std::vector dimsToExclude(batchRank); - for(int i = 0; i < batchRank; ++i) - dimsToExclude[i] = i; - - const Nd4jLong numOfSubArrs = ShapeUtils::getNumOfSubArrs(xT->shapeInfo(), dimsToExclude); - -//PRAGMA_OMP_PARALLEL_FOR - for(Nd4jLong i = 0; i < numOfSubArrs; ++i) { - auto xSubArr = (*xT)(i, dimsToExclude); - auto ySubArr = (*yT)(i, dimsToExclude); - auto zSubArr = (*zT)(i, dimsToExclude); - mmul(&xSubArr, &ySubArr, &zSubArr, alpha, beta); - } - } - - if(xT != x) - delete xT; - if(yT != y) - delete yT; - if(zT != z) - delete zT; +void MmulHelper::matmul(const sd::NDArray* x, const sd::NDArray* y, + sd::NDArray* z, const bool transX, const bool transY, + double alpha, double beta) { + int xRank = x->rankOf(); + int yRank = y->rankOf(); + + auto outShape = ShapeUtils::evalShapeForMatmul(x->shapeInfo(), y->shapeInfo(), + transX, transY); + if (!z->isSameShape(outShape)) { + nd4j_printf( + "NDArrayFactory::matmul static method: input shape of output array is " + "wrong, actual is %s and expected is %s ! \n", + ShapeUtils::shapeAsString(z).c_str(), + ShapeUtils::shapeAsString(outShape).c_str()); + throw std::invalid_argument(""); + } + + if (z->isEmpty()) return; + + NDArray *xT(const_cast(x)), *yT(const_cast(y)), *zT(z); + + if ((transX && xRank > 1) || (transY && yRank > 1)) { + const int rank = xRank >= yRank ? xRank : yRank; + std::vector permut(rank); + for (int i = 0; i < rank - 2; ++i) permut[i] = i; + permut[rank - 2] = rank - 1; + permut[rank - 1] = rank - 2; + + if (transX) xT = new NDArray(x->permute(permut)); + + if (transY) yT = new NDArray(y->permute(permut)); + } + + if (xRank <= 2 && + yRank <= 2) { // dot (1Dx1D), vector-matrix (1Dx2D), matrix-vector + // (2Dx1D), matrix-matrix (2Dx2D) product cases + + if (xRank == 1 && + yRank == 2) { // reduce vector-matrix to matrix-matrix case + xT = new NDArray( + x->reshape(x->ordering(), + {1, x->lengthOf()})); // please note x is not transposed + // in this case (since xRank=1) + zT = new NDArray(z->reshape(z->ordering(), {1, z->lengthOf()})); } -//BUILD_TRIPLE_TEMPLATE(template void usualGemm, (const char cOrder, const bool transA, const bool transB, const int M, const int N, const int K, const double alpha, const void* A, const int lda, const void* B, const int ldb, const double beta, void* C, const int ldc), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES); -//BUILD_TRIPLE_TEMPLATE(template void usualGemv, (const char aOrder, const int M, const int N, const double alpha, const void* A, const int lda, const void* B, const int incx, const double beta, void* C, const int incy), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES); -//BUILD_TRIPLE_TEMPLATE(template void usualDot, (const Nd4jLong length, const double alpha, const void* vX, const Nd4jLong incx, const void* vY, const Nd4jLong incy, const double beta, void* vZ), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES); + mmul(xT, yT, zT, alpha, beta); + } else { // rest cases - batched mmul + + const int batchRank = xRank - 2; + std::vector dimsToExclude(batchRank); + for (int i = 0; i < batchRank; ++i) dimsToExclude[i] = i; + const Nd4jLong numOfSubArrs = + ShapeUtils::getNumOfSubArrs(xT->shapeInfo(), dimsToExclude); + + // PRAGMA_OMP_PARALLEL_FOR + for (Nd4jLong i = 0; i < numOfSubArrs; ++i) { + auto xSubArr = (*xT)(i, dimsToExclude); + auto ySubArr = (*yT)(i, dimsToExclude); + auto zSubArr = (*zT)(i, dimsToExclude); + mmul(&xSubArr, &ySubArr, &zSubArr, alpha, beta); + } + } + + if (xT != x) delete xT; + if (yT != y) delete yT; + if (zT != z) delete zT; } +// BUILD_TRIPLE_TEMPLATE(template void usualGemm, (const char cOrder, const bool +// transA, const bool transB, const int M, const int N, const int K, const +// double alpha, const void* A, const int lda, const void* B, const int ldb, +// const double beta, void* C, const int ldc), LIBND4J_TYPES, FLOAT_TYPES, +// FLOAT_TYPES); BUILD_TRIPLE_TEMPLATE(template void usualGemv, (const char +// aOrder, const int M, const int N, const double alpha, const void* A, const +// int lda, const void* B, const int incx, const double beta, void* C, const int +// incy), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES); +// BUILD_TRIPLE_TEMPLATE(template void usualDot, (const Nd4jLong length, const +// double alpha, const void* vX, const Nd4jLong incx, const void* vY, const +// Nd4jLong incy, const double beta, void* vZ), LIBND4J_TYPES, FLOAT_TYPES, +// FLOAT_TYPES); + +} // namespace sd + #endif \ No newline at end of file diff --git a/libnd4j/include/helpers/impl/OmpLaunchHelper.cpp b/libnd4j/include/helpers/impl/OmpLaunchHelper.cpp index b0ef974570e8..e8e694c9f827 100644 --- a/libnd4j/include/helpers/impl/OmpLaunchHelper.cpp +++ b/libnd4j/include/helpers/impl/OmpLaunchHelper.cpp @@ -20,92 +20,93 @@ // #include -#include #include +#include #ifdef _OPENMP #include #endif namespace sd { - //////////////////////////////////////////////////////////////////////////////// -OmpLaunchHelper::OmpLaunchHelper(const Nd4jLong N, float desiredNumThreads) { - - auto maxItersPerThread = Environment::getInstance().elementwiseThreshold(); - - if(N < maxItersPerThread) - _numThreads = 1; - else { - #ifdef _OPENMP - if(desiredNumThreads == -1) - desiredNumThreads = omp_get_max_threads(); - else if(desiredNumThreads < 1) - desiredNumThreads = 1; - else - desiredNumThreads = sd::math::nd4j_min(omp_get_max_threads(), desiredNumThreads); - #else - desiredNumThreads = sd::Environment::getInstance().maxThreads(); - #endif - _numThreads = sd::math::nd4j_min(N / maxItersPerThread, desiredNumThreads); - } - - _itersPerThread = N / _numThreads; - _remainder = N % _numThreads; // last thread may contain bigger number of iterations -} +OmpLaunchHelper::OmpLaunchHelper(const Nd4jLong N, float desiredNumThreads) { + auto maxItersPerThread = Environment::getInstance().elementwiseThreshold(); + if (N < maxItersPerThread) + _numThreads = 1; + else { +#ifdef _OPENMP + if (desiredNumThreads == -1) + desiredNumThreads = omp_get_max_threads(); + else if (desiredNumThreads < 1) + desiredNumThreads = 1; + else + desiredNumThreads = + sd::math::nd4j_min(omp_get_max_threads(), desiredNumThreads); +#else + desiredNumThreads = sd::Environment::getInstance().maxThreads(); +#endif + _numThreads = + sd::math::nd4j_min(N / maxItersPerThread, desiredNumThreads); + } + + _itersPerThread = N / _numThreads; + _remainder = + N % _numThreads; // last thread may contain bigger number of iterations +} Nd4jLong OmpLaunchHelper::betterSpan(Nd4jLong N) { - return OmpLaunchHelper::betterSpan(N, OmpLaunchHelper::betterThreads(N)); - } - - Nd4jLong OmpLaunchHelper::betterSpan(Nd4jLong N, Nd4jLong numThreads) { - auto r = N % numThreads; - auto t = N / numThreads; - - if (r == 0) - return t; - else { - // breaks alignment - return t + 1; - } - } - - int OmpLaunchHelper::betterThreads(Nd4jLong N) { - #ifdef _OPENMP - return betterThreads(N, omp_get_max_threads()); - #else - return betterThreads(N, sd::Environment::getInstance().maxThreads());; - #endif - } - - int OmpLaunchHelper::betterThreads(Nd4jLong N, int maxThreads) { - auto t = Environment::getInstance().elementwiseThreshold(); - if (N < t) - return 1; - else { - return static_cast(sd::math::nd4j_min(N / t, maxThreads)); - } - } - - int OmpLaunchHelper::tadThreads(Nd4jLong tadLength, Nd4jLong numTads) { + return OmpLaunchHelper::betterSpan(N, OmpLaunchHelper::betterThreads(N)); +} + +Nd4jLong OmpLaunchHelper::betterSpan(Nd4jLong N, Nd4jLong numThreads) { + auto r = N % numThreads; + auto t = N / numThreads; + + if (r == 0) + return t; + else { + // breaks alignment + return t + 1; + } +} + +int OmpLaunchHelper::betterThreads(Nd4jLong N) { +#ifdef _OPENMP + return betterThreads(N, omp_get_max_threads()); +#else + return betterThreads(N, sd::Environment::getInstance().maxThreads()); + ; +#endif +} + +int OmpLaunchHelper::betterThreads(Nd4jLong N, int maxThreads) { + auto t = Environment::getInstance().elementwiseThreshold(); + if (N < t) + return 1; + else { + return static_cast(sd::math::nd4j_min(N / t, maxThreads)); + } +} + +int OmpLaunchHelper::tadThreads(Nd4jLong tadLength, Nd4jLong numTads) { #ifdef _OPENMP - auto maxThreads = omp_get_max_threads(); + auto maxThreads = omp_get_max_threads(); #else - auto maxThreads = sd::Environment::getInstance().maxThreads(); + auto maxThreads = sd::Environment::getInstance().maxThreads(); #endif - // if there's only 1 thread allowed - nothing to do here - if (maxThreads <= 1) - return 1; + // if there's only 1 thread allowed - nothing to do here + if (maxThreads <= 1) return 1; - auto totalLength = tadLength * numTads; + auto totalLength = tadLength * numTads; - // if array is tiny - no need to spawn any threeds - if (totalLength < Environment::getInstance().elementwiseThreshold()) - return 1; + // if array is tiny - no need to spawn any threeds + if (totalLength < Environment::getInstance().elementwiseThreshold()) + return 1; - // by default we're spawning as many threads we can, but not more than number of TADs - return sd::math::nd4j_min(numTads, maxThreads); - } + // by default we're spawning as many threads we can, but not more than number + // of TADs + return sd::math::nd4j_min(numTads, maxThreads); } +} // namespace sd diff --git a/libnd4j/include/helpers/impl/OpArgsHolder.cpp b/libnd4j/include/helpers/impl/OpArgsHolder.cpp index 7b82a85d937b..ab1a0986a722 100644 --- a/libnd4j/include/helpers/impl/OpArgsHolder.cpp +++ b/libnd4j/include/helpers/impl/OpArgsHolder.cpp @@ -20,140 +20,129 @@ #include - namespace sd { //////////////////////////////////////////////////////////////////////// // default constructor OpArgsHolder::OpArgsHolder() { + _inArrs = std::vector(); + _tArgs = std::vector(); + _iArgs = std::vector(); + _bArgs = std::vector(); - _inArrs = std::vector(); - _tArgs = std::vector(); - _iArgs = std::vector(); - _bArgs = std::vector(); - - _isArrAlloc = std::vector(); + _isArrAlloc = std::vector(); - _numInArrs = 0; - _numTArgs = 0; - _numIArgs = 0; - _numBArgs = 0; + _numInArrs = 0; + _numTArgs = 0; + _numIArgs = 0; + _numBArgs = 0; } //////////////////////////////////////////////////////////////////////// // copy constructor OpArgsHolder::OpArgsHolder(const OpArgsHolder& other) { - - throw std::runtime_error("OpArgsHolder::OpArgsHolder copy constructor: don't use me !"); + throw std::runtime_error( + "OpArgsHolder::OpArgsHolder copy constructor: don't use me !"); } - //////////////////////////////////////////////////////////////////////// // constructor OpArgsHolder::OpArgsHolder(const std::vector& inArrs, - const std::vector& tArgs, - const std::vector& iArgs, - const std::vector& bArgs) { - _inArrs = inArrs; - _tArgs = tArgs; - _iArgs = iArgs; - _bArgs = bArgs; - - _isArrAlloc = std::vector(); - - _numInArrs = _inArrs.size(); - _numTArgs = _tArgs.size(); - _numIArgs = _iArgs.size(); - _numBArgs = _bArgs.size(); + const std::vector& tArgs, + const std::vector& iArgs, + const std::vector& bArgs) { + _inArrs = inArrs; + _tArgs = tArgs; + _iArgs = iArgs; + _bArgs = bArgs; + + _isArrAlloc = std::vector(); + + _numInArrs = _inArrs.size(); + _numTArgs = _tArgs.size(); + _numIArgs = _iArgs.size(); + _numBArgs = _bArgs.size(); } //////////////////////////////////////////////////////////////////////// // move constructor -OpArgsHolder::OpArgsHolder(OpArgsHolder&& other) noexcept: _inArrs(std::move(other._inArrs)), - _tArgs(std::move(other._tArgs)), - _iArgs(std::move(other._iArgs)), - _bArgs(std::move(other._bArgs)), - _isArrAlloc(std::move(other._isArrAlloc)) { - - other._isArrAlloc = std::vector(); - - _numInArrs = _inArrs.size(); - _numTArgs = _tArgs.size(); - _numIArgs = _iArgs.size(); - _numBArgs = _bArgs.size(); +OpArgsHolder::OpArgsHolder(OpArgsHolder&& other) noexcept + : _inArrs(std::move(other._inArrs)), + _tArgs(std::move(other._tArgs)), + _iArgs(std::move(other._iArgs)), + _bArgs(std::move(other._bArgs)), + _isArrAlloc(std::move(other._isArrAlloc)) { + other._isArrAlloc = std::vector(); + + _numInArrs = _inArrs.size(); + _numTArgs = _tArgs.size(); + _numIArgs = _iArgs.size(); + _numBArgs = _bArgs.size(); } //////////////////////////////////////////////////////////////////////// // assignment operator OpArgsHolder& OpArgsHolder::operator=(const OpArgsHolder& other) { - - throw std::runtime_error("OpArgsHolder::OpArgsHolder assignment operator: don't use me !"); + throw std::runtime_error( + "OpArgsHolder::OpArgsHolder assignment operator: don't use me !"); } - //////////////////////////////////////////////////////////////////////// // move assignment operator OpArgsHolder& OpArgsHolder::operator=(OpArgsHolder&& other) noexcept { + if (this == &other) return *this; - if (this == &other) - return *this; - - for (int i = 0; i < _isArrAlloc.size(); ++i) // delete arrays if necessary - if(_isArrAlloc[i]) - delete _inArrs[i]; + for (int i = 0; i < _isArrAlloc.size(); ++i) // delete arrays if necessary + if (_isArrAlloc[i]) delete _inArrs[i]; - _inArrs = std::move(other._inArrs); - _tArgs = std::move(other._tArgs); - _iArgs = std::move(other._iArgs); - _bArgs = std::move(other._bArgs); - _isArrAlloc = std::move(other._isArrAlloc); + _inArrs = std::move(other._inArrs); + _tArgs = std::move(other._tArgs); + _iArgs = std::move(other._iArgs); + _bArgs = std::move(other._bArgs); + _isArrAlloc = std::move(other._isArrAlloc); - other._isArrAlloc = std::vector(); + other._isArrAlloc = std::vector(); - _numInArrs = _inArrs.size(); - _numTArgs = _tArgs.size(); - _numIArgs = _iArgs.size(); - _numBArgs = _bArgs.size(); + _numInArrs = _inArrs.size(); + _numTArgs = _tArgs.size(); + _numIArgs = _iArgs.size(); + _numBArgs = _bArgs.size(); - return *this; + return *this; } //////////////////////////////////////////////////////////////////////// -OpArgsHolder OpArgsHolder::createArgsHolderForBP(const std::vector& inGradArrs, const bool isInPlace) const { - - const int numInGradArrs = inGradArrs.size(); - - OpArgsHolder result(std::vector(_numInArrs + numInGradArrs, nullptr), _tArgs, _iArgs); - - if(isInPlace) - result._isArrAlloc = std::vector(_numInArrs + numInGradArrs, false); - - for (int i = 0; i < _numInArrs; ++i) { - - if(isInPlace) { - result._inArrs[i] = new NDArray(*_inArrs[i]); // make copy - result._isArrAlloc[i] = true; - } - else - result._inArrs[i] = _inArrs[i]; - } - - // input gradients - for (int i = 0; i < numInGradArrs; ++i) - result._inArrs[_numInArrs + i] = inGradArrs[i]; - - return result; +OpArgsHolder OpArgsHolder::createArgsHolderForBP( + const std::vector& inGradArrs, const bool isInPlace) const { + const int numInGradArrs = inGradArrs.size(); + + OpArgsHolder result( + std::vector(_numInArrs + numInGradArrs, nullptr), _tArgs, + _iArgs); + + if (isInPlace) + result._isArrAlloc = std::vector(_numInArrs + numInGradArrs, false); + + for (int i = 0; i < _numInArrs; ++i) { + if (isInPlace) { + result._inArrs[i] = new NDArray(*_inArrs[i]); // make copy + result._isArrAlloc[i] = true; + } else + result._inArrs[i] = _inArrs[i]; + } + + // input gradients + for (int i = 0; i < numInGradArrs; ++i) + result._inArrs[_numInArrs + i] = inGradArrs[i]; + + return result; } //////////////////////////////////////////////////////////////////////// // default destructor OpArgsHolder::~OpArgsHolder() noexcept { - - for (int i = 0; i < _isArrAlloc.size(); ++i) - if(_isArrAlloc[i]) - delete _inArrs[i]; -} - + for (int i = 0; i < _isArrAlloc.size(); ++i) + if (_isArrAlloc[i]) delete _inArrs[i]; } - +} // namespace sd diff --git a/libnd4j/include/helpers/impl/OpBenchmark.cpp b/libnd4j/include/helpers/impl/OpBenchmark.cpp index 6cb0dc08a788..8fe4dd500f0c 100644 --- a/libnd4j/include/helpers/impl/OpBenchmark.cpp +++ b/libnd4j/include/helpers/impl/OpBenchmark.cpp @@ -21,109 +21,81 @@ #include "../OpBenchmark.h" namespace sd { - OpBenchmark::OpBenchmark(std::string name, NDArray *x, NDArray *y, NDArray *z) { - _testName = name; - _x = x; - _y = y; - _z = z; - } - - OpBenchmark::OpBenchmark(std::string name, NDArray *x, NDArray *z) { - _testName = name; - _x = x; - _z = z; - } - - OpBenchmark::OpBenchmark(std::string name, NDArray *x, NDArray *z, std::initializer_list axis) : OpBenchmark(name, x, nullptr, z, axis){ } - - OpBenchmark::OpBenchmark(std::string name, NDArray *x, NDArray *y, NDArray *z, std::initializer_list axis){ - _testName = name; - _x = x; - _y = y; - _z = z; - _axis = std::vector(axis); - - if (_axis.size() > 1) - std::sort(_axis.begin(), _axis.end()); - - } - - OpBenchmark::OpBenchmark(std::string name, NDArray *x, NDArray *z, std::vector axis) : OpBenchmark(name, x, nullptr, z, axis) { } - - OpBenchmark::OpBenchmark(std::string name, NDArray *x, NDArray *y, NDArray *z, std::vector axis) { - _testName = name; - _x = x; - _y = y; - _z = z; - _axis = axis; - - if (_axis.size() > 1) - std::sort(_axis.begin(), _axis.end()); - } - - - NDArray& OpBenchmark::x() { - return *_x; - } - - int OpBenchmark::opNum() { - return _opNum; - } - std::string OpBenchmark::testName(){ - return _testName; - } - - void OpBenchmark::setOpNum(int opNum) { - _opNum = opNum; - } - - void OpBenchmark::setTestName(std::string name){ - _testName = name; - } - - void OpBenchmark::setX(NDArray *array) { - _x = array; - } - - void OpBenchmark::setY(NDArray *array) { - _y = array; - } - - void OpBenchmark::setZ(NDArray *array) { - _z = array; - } - - void OpBenchmark::setAxis(std::vector axis) { - _axis = axis; - } - - void OpBenchmark::setAxis(std::initializer_list axis) { - _axis = axis; - } - - std::vector OpBenchmark::getAxis(){ - return _axis; - } - - std::string OpBenchmark::extra() { - return "N/A"; - } - - std::string OpBenchmark::shape() { - if (_x != nullptr) - return ShapeUtils::shapeAsString(_x); - else if (_z != nullptr) - return ShapeUtils::shapeAsString(_z); - else - return "N/A"; - } - - std::string OpBenchmark::dataType() { - if (_x != nullptr) - return DataTypeUtils::asString(_x->dataType()); - else if (_z != nullptr) - return DataTypeUtils::asString(_z->dataType()); - else - return "N/A"; - } -} \ No newline at end of file +OpBenchmark::OpBenchmark(const std::string &name, const NDArray &x, + const NDArray &y, const NDArray &z) { + _testName = name; + _x = x; + _y = y; + _z = z; +} + +OpBenchmark::OpBenchmark(const std::string &name, const NDArray &x, + const NDArray &z) { + _testName = name; + _x = x; + _z = z; +} + +OpBenchmark::OpBenchmark(const std::string &name, const NDArray &x, + const NDArray &z, const std::vector &axis) { + _testName = name; + _x = x; + _z = z; + _axis = axis; + + if (_axis.size() > 1) std::sort(_axis.begin(), _axis.end()); +} + +OpBenchmark::OpBenchmark(const std::string &name, const NDArray &x, + const NDArray &y, const NDArray &z, + const std::vector &axis) { + _testName = name; + _x = x; + _y = y; + _z = z; + _axis = axis; + + if (_axis.size() > 1) std::sort(_axis.begin(), _axis.end()); +} + +NDArray &OpBenchmark::x() { return _x; } + +int OpBenchmark::opNum() const { return _opNum; } +const std::string &OpBenchmark::testName() const { return _testName; } + +void OpBenchmark::setOpNum(int opNum) { _opNum = opNum; } + +void OpBenchmark::setTestName(const std::string &name) { _testName = name; } + +void OpBenchmark::setX(const NDArray &array) { _x = array; } + +void OpBenchmark::setY(const NDArray &array) { _y = array; } + +void OpBenchmark::setZ(const NDArray &array) { _z = array; } + +void OpBenchmark::setAxis(std::vector axis) { _axis = axis; } + +void OpBenchmark::setAxis(std::initializer_list axis) { _axis = axis; } + +std::vector OpBenchmark::getAxis() { return _axis; } + +std::string OpBenchmark::extra() { return "N/A"; } + +std::string OpBenchmark::shape() { + if (_x.shapeInfo() != nullptr) + return ShapeUtils::shapeAsString(_x); + else if (_z.shapeInfo() != nullptr) + return ShapeUtils::shapeAsString(_z); + else + return "N/A"; +} + +std::string OpBenchmark::dataType() { + if (_x.shapeInfo() != nullptr) + return DataTypeUtils::asString(_x.dataType()); + else if (_z.shapeInfo() != nullptr) + return DataTypeUtils::asString(_z.dataType()); + else + return "N/A"; +} +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/helpers/impl/OpTracker.cpp b/libnd4j/include/helpers/impl/OpTracker.cpp index e36d4ab5a4c6..0676abdd8883 100644 --- a/libnd4j/include/helpers/impl/OpTracker.cpp +++ b/libnd4j/include/helpers/impl/OpTracker.cpp @@ -19,101 +19,102 @@ // #include -#include #include #include +#include using namespace sd::ops; using namespace sd::graph; namespace sd { - - OpTracker& OpTracker::getInstance() { - static OpTracker instance; - return instance; - } - void OpTracker::storeOperation(sd::graph::OpType opType, const OpDescriptor& descriptor) { - // check out CPU features - if (!::isMinimalRequirementsMet()) { - - auto binaryLevel = ::binaryLevel(); - auto optimalLevel = ::optimalLevel(); - - switch (binaryLevel) { - case 3: { - nd4j_printf("libnd4j binary was built with AVX512 support, but current CPU doesn't have this instruction set. Exiting now...",""); - } - break; - case 2: { - nd4j_printf("libnd4j binary was built with AVX/AVX2 support, but current CPU doesn't have this instruction set. Exiting now...",""); - } - break; - default: { - nd4j_printf("Unknown binary validation error. Exiting now...",""); - } - break; - } - - // we're exiting now - exit(119); - } - // - if (_map.count(opType) < 1) { - std::vector vec; - _map[opType] = vec; - } - - _operations++; - - auto vec = _map[opType]; - - if (std::find(vec.begin(), vec.end(), descriptor) == vec.end()) - _map[opType].emplace_back(descriptor); - } +OpTracker& OpTracker::getInstance() { + static OpTracker instance; - void OpTracker::storeOperation(sd::graph::OpType opType, const char* opName, const Nd4jLong opNum) { - OpDescriptor descriptor(0, opName, false); - descriptor.setOpNum((int) opNum); - descriptor.setHash(-1); + return instance; +} - storeOperation(opType, descriptor); +void OpTracker::storeOperation(sd::graph::OpType opType, + const OpDescriptor& descriptor) { + // check out CPU features + if (!::isMinimalRequirementsMet()) { + auto binaryLevel = ::binaryLevel(); + auto optimalLevel = ::optimalLevel(); + + switch (binaryLevel) { + case 3: { + nd4j_printf( + "libnd4j binary was built with AVX512 support, but current CPU " + "doesn't have this instruction set. Exiting now...", + ""); + } break; + case 2: { + nd4j_printf( + "libnd4j binary was built with AVX/AVX2 support, but current CPU " + "doesn't have this instruction set. Exiting now...", + ""); + } break; + default: { + nd4j_printf("Unknown binary validation error. Exiting now...", ""); + } break; } + // we're exiting now + exit(119); + } + // + if (_map.count(opType) < 1) { + std::vector vec; + _map[opType] = vec; + } - template - std::string OpTracker::local_to_string(T value) { - std::ostringstream os ; - os << value ; - return os.str() ; - } + _operations++; + auto vec = _map[opType]; - int OpTracker::totalGroups() { - return (int) _map.size(); - } + if (std::find(vec.begin(), vec.end(), descriptor) == vec.end()) + _map[opType].emplace_back(descriptor); +} - int OpTracker::totalOperations() { - return _operations; - } +void OpTracker::storeOperation(sd::graph::OpType opType, const char* opName, + const Nd4jLong opNum) { + OpDescriptor descriptor(0, opName, false); + descriptor.setOpNum((int)opNum); + descriptor.setHash(-1); - const char* OpTracker::exportOperations() { - if (_export.length() == 0) { - for (auto &v: _map) { - std::string block = local_to_string(v.first) + " "; + storeOperation(opType, descriptor); +} + +template +std::string OpTracker::local_to_string(T value) { + std::ostringstream os; + os << value; + return os.str(); +} + +int OpTracker::totalGroups() { return (int)_map.size(); } - for (auto &i: v.second) { - block += local_to_string(i.getHash()) + ":"; - block += local_to_string(i.getOpNum()) + ":"; - block += *i.getOpName() + "<<"; - } +int OpTracker::totalOperations() { return _operations; } - block += ">>"; - _export += block; - } - } +const char* OpTracker::exportOperations() { + if (_export.length() == 0) { + for (auto& v : _map) { + std::string block = local_to_string(v.first) + " "; - return _export.c_str(); + for (auto& i : v.second) { + block += local_to_string(i.getHash()) + ":"; + block += local_to_string(i.getOpNum()) + ":"; + block += *i.getOpName() + "<<"; + } + + block += ">>"; + _export += block; } + } + + return _export.c_str(); } + + +} // namespace sd diff --git a/libnd4j/include/helpers/impl/Parameters.cpp b/libnd4j/include/helpers/impl/Parameters.cpp index 356ad5a5aa7e..a609ce4cb302 100644 --- a/libnd4j/include/helpers/impl/Parameters.cpp +++ b/libnd4j/include/helpers/impl/Parameters.cpp @@ -19,81 +19,88 @@ // #include "../benchmark/Parameters.h" + #include namespace sd { - Parameters* Parameters::addIntParam(std::string string, int param) { - _intParams[string] = param; - return this; - } +Parameters* Parameters::addIntParam(std::string string, int param) { + _intParams[string] = param; + return this; +} - int Parameters::getIntParam(std::string string) const { - if (_intParams.count(string) == 0) - throw std::runtime_error("Not available intParameter requested"); +int Parameters::getIntParam(std::string string) const { + if (_intParams.count(string) == 0) + throw std::runtime_error("Not available intParameter requested"); - return _intParams.at(string); - } + return _intParams.at(string); +} - Parameters* Parameters::addIntParam(std::initializer_list strings, std::initializer_list params) { - std::vector s(strings); - std::vector p(params); +Parameters* Parameters::addIntParam(std::initializer_list strings, + std::initializer_list params) { + std::vector s(strings); + std::vector p(params); - if (s.size() != p.size()) - throw std::runtime_error("addIntParam: number of keys and values should match"); + if (s.size() != p.size()) + throw std::runtime_error( + "addIntParam: number of keys and values should match"); - for (int e = 0; e < s.size(); e++) - _intParams[s[e]] = p[e]; + for (int e = 0; e < s.size(); e++) _intParams[s[e]] = p[e]; - return this; - } + return this; +} - Parameters* Parameters::addBoolParam(std::string string, bool param) { - _boolParams[string] = param; - return this; - } +Parameters* Parameters::addBoolParam(std::string string, bool param) { + _boolParams[string] = param; + return this; +} - Parameters* Parameters::addBoolParam(std::initializer_list strings, std::initializer_list params) { - std::vector s(strings); - std::vector p(params); +Parameters* Parameters::addBoolParam(std::initializer_list strings, + std::initializer_list params) { + std::vector s(strings); + std::vector p(params); - if (s.size() != p.size()) - throw std::runtime_error("addIntParam: number of keys and values should match"); + if (s.size() != p.size()) + throw std::runtime_error( + "addIntParam: number of keys and values should match"); - for (int e = 0; e < s.size(); e++) - _boolParams[s[e]] = p[e]; + for (int e = 0; e < s.size(); e++) _boolParams[s[e]] = p[e]; - return this; - } + return this; +} - Parameters* Parameters::addArrayParam(std::string string, std::initializer_list param) { - _arrayParams[string] = std::vector(param); - return this; - } +Parameters* Parameters::addArrayParam(std::string string, + std::initializer_list param) { + _arrayParams[string] = std::vector(param); + return this; +} - Parameters* Parameters::addArrayParam(std::initializer_list strings, std::initializer_list> params) { - std::vector s(strings); - std::vector> p(params); +Parameters* Parameters::addArrayParam( + std::initializer_list strings, + std::initializer_list> params) { + std::vector s(strings); + std::vector> p(params); - if (s.size() != p.size()) - throw std::runtime_error("addIntParam: number of keys and values should match"); + if (s.size() != p.size()) + throw std::runtime_error( + "addIntParam: number of keys and values should match"); - for (int e = 0; e < s.size(); e++) - _arrayParams[s[e]] = std::vector(p[e]); + for (int e = 0; e < s.size(); e++) + _arrayParams[s[e]] = std::vector(p[e]); - return this; - } + return this; +} - bool Parameters::getBoolParam(std::string string) const { - if (_boolParams.count(string) == 0) - throw std::runtime_error("Not available boolParameter requested"); +bool Parameters::getBoolParam(std::string string) const { + if (_boolParams.count(string) == 0) + throw std::runtime_error("Not available boolParameter requested"); - return _boolParams.at(string); - } + return _boolParams.at(string); +} - std::vector Parameters::getArrayParam(std::string string) const { - if (_arrayParams.count(string) == 0) - throw std::runtime_error("Not available arrayParameter requested"); +std::vector Parameters::getArrayParam(std::string string) const { + if (_arrayParams.count(string) == 0) + throw std::runtime_error("Not available arrayParameter requested"); - return _arrayParams.at(string); - } + return _arrayParams.at(string); } +} // namespace sd diff --git a/libnd4j/include/helpers/impl/RandomLauncher.cpp b/libnd4j/include/helpers/impl/RandomLauncher.cpp index f7cdd0f3aea9..4bb9ca952c35 100644 --- a/libnd4j/include/helpers/impl/RandomLauncher.cpp +++ b/libnd4j/include/helpers/impl/RandomLauncher.cpp @@ -18,140 +18,207 @@ // @author raver119@gmail.com // -#include -#include -#include #include +#include +#include +#include //#include #include namespace sd { - void RandomLauncher::applyDropOut(sd::LaunchContext *context, sd::graph::RandomGenerator& rng, NDArray *array, double retainProb, NDArray* z) { - if (z == nullptr) - z = array; - - ExtraArguments arguments({retainProb}); - PointersManager pm(context, "applyDropOut"); - - NDArray::prepareSpecialUse({z}, {array}); - - NativeOpExecutioner::execRandom(context, random::DropOut, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), arguments.argumentsAsT(z->dataType())); - pm.synchronize(); +void RandomLauncher::applyDropOut(sd::LaunchContext* context, + sd::graph::RandomGenerator& rng, + NDArray* array, double retainProb, + NDArray* z) { + if (z == nullptr) z = array; - NDArray::registerSpecialUse({z}, {array}); - } + ExtraArguments arguments({retainProb}); + PointersManager pm(context, "applyDropOut"); - void RandomLauncher::applyInvertedDropOut(sd::LaunchContext *context, sd::graph::RandomGenerator& rng, NDArray *array, double retainProb, NDArray* z) { - if (z == nullptr) - z = array; + NDArray::prepareSpecialUse({z}, {array}); - ExtraArguments arguments({retainProb}); - PointersManager pm(context, "applyInvertedDropOut"); + NativeOpExecutioner::execRandom( + context, random::DropOut, &rng, array->buffer(), array->shapeInfo(), + array->specialBuffer(), array->specialShapeInfo(), z->buffer(), + z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), + arguments.argumentsAsT(z->dataType())); + pm.synchronize(); - NDArray::prepareSpecialUse({z}, {array}); + NDArray::registerSpecialUse({z}, {array}); +} - NativeOpExecutioner::execRandom(context, random::DropOutInverted, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), arguments.argumentsAsT(z->dataType())); - pm.synchronize(); +void RandomLauncher::applyInvertedDropOut(sd::LaunchContext* context, + sd::graph::RandomGenerator& rng, + NDArray* array, double retainProb, + NDArray* z) { + if (z == nullptr) z = array; - NDArray::registerSpecialUse({z}, {array}); - } + ExtraArguments arguments({retainProb}); + PointersManager pm(context, "applyInvertedDropOut"); - void RandomLauncher::applyAlphaDropOut(sd::LaunchContext *context, sd::graph::RandomGenerator& rng, NDArray *array, double retainProb, double alpha, double beta, double alphaPrime, NDArray* z) { - if (z == nullptr) - z = array; + NDArray::prepareSpecialUse({z}, {array}); - ExtraArguments arguments({retainProb, alpha, beta, alphaPrime}); - PointersManager pm(context, "applyAlphaDropOut"); + NativeOpExecutioner::execRandom( + context, random::DropOutInverted, &rng, array->buffer(), + array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), + z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), + arguments.argumentsAsT(z->dataType())); + pm.synchronize(); - NDArray::prepareSpecialUse({z}, {array}); + NDArray::registerSpecialUse({z}, {array}); +} - NativeOpExecutioner::execRandom(context, random::AlphaDropOut, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), arguments.argumentsAsT(z->dataType())); - pm.synchronize(); +void RandomLauncher::applyAlphaDropOut(sd::LaunchContext* context, + sd::graph::RandomGenerator& rng, + NDArray* array, double retainProb, + double alpha, double beta, + double alphaPrime, NDArray* z) { + if (z == nullptr) z = array; - NDArray::registerSpecialUse({z}, {array}); - } + ExtraArguments arguments({retainProb, alpha, beta, alphaPrime}); + PointersManager pm(context, "applyAlphaDropOut"); - void RandomLauncher::fillBernoulli(sd::LaunchContext *context, sd::graph::RandomGenerator& rng, NDArray* array, double prob) { - ExtraArguments arguments({prob}); - PointersManager pm(context, "fillBernoulli"); + NDArray::prepareSpecialUse({z}, {array}); - NDArray::prepareSpecialUse({array}, {}); + NativeOpExecutioner::execRandom( + context, random::AlphaDropOut, &rng, array->buffer(), array->shapeInfo(), + array->specialBuffer(), array->specialShapeInfo(), z->buffer(), + z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), + arguments.argumentsAsT(z->dataType())); + pm.synchronize(); - NativeOpExecutioner::execRandom(context, random::BernoulliDistribution, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), arguments.argumentsAsT(array->dataType())); - pm.synchronize(); + NDArray::registerSpecialUse({z}, {array}); +} - NDArray::registerSpecialUse({array}, {}); - } +void RandomLauncher::fillBernoulli(sd::LaunchContext* context, + sd::graph::RandomGenerator& rng, + NDArray* array, double prob) { + ExtraArguments arguments({prob}); + PointersManager pm(context, "fillBernoulli"); - void RandomLauncher::fillUniform(sd::LaunchContext *context, sd::graph::RandomGenerator& rng, NDArray* array, double from, double to) { - ExtraArguments arguments({from, to}); - PointersManager pm(context, "fillUniform"); + NDArray::prepareSpecialUse({array}, {}); - NDArray::prepareSpecialUse({array}, {}); + NativeOpExecutioner::execRandom( + context, random::BernoulliDistribution, &rng, array->buffer(), + array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), + arguments.argumentsAsT(array->dataType())); + pm.synchronize(); - NativeOpExecutioner::execRandom(context, random::UniformDistribution, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), arguments.argumentsAsT(array->dataType())); - pm.synchronize(); + NDArray::registerSpecialUse({array}, {}); +} - NDArray::registerSpecialUse({array}, {}); - } +void RandomLauncher::fillUniform(sd::LaunchContext* context, + sd::graph::RandomGenerator& rng, + NDArray* array, double from, double to) { + ExtraArguments arguments({from, to}); + PointersManager pm(context, "fillUniform"); - void RandomLauncher::fillGaussian(sd::LaunchContext *context, sd::graph::RandomGenerator& rng, NDArray* array, double mean, double stdev) { - ExtraArguments arguments({mean, stdev}); - PointersManager pm(context, "fillGaussian"); + NDArray::prepareSpecialUse({array}, {}); - NDArray::prepareSpecialUse({array}, {}); + NativeOpExecutioner::execRandom( + context, random::UniformDistribution, &rng, array->buffer(), + array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), + arguments.argumentsAsT(array->dataType())); + pm.synchronize(); - NativeOpExecutioner::execRandom(context, random::GaussianDistribution, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), arguments.argumentsAsT(array->dataType())); - pm.synchronize(); + NDArray::registerSpecialUse({array}, {}); +} - NDArray::registerSpecialUse({array}, {}); - } +void RandomLauncher::fillGaussian(sd::LaunchContext* context, + sd::graph::RandomGenerator& rng, + NDArray* array, double mean, double stdev) { + ExtraArguments arguments({mean, stdev}); + PointersManager pm(context, "fillGaussian"); - void RandomLauncher::fillExponential(sd::LaunchContext *context, sd::graph::RandomGenerator& rng, NDArray* array, double lambda) { - ExtraArguments arguments({lambda}); - PointersManager pm(context, "fillExponential"); + NDArray::prepareSpecialUse({array}, {}); - NDArray::prepareSpecialUse({array}, {}); + NativeOpExecutioner::execRandom( + context, random::GaussianDistribution, &rng, array->buffer(), + array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), + array->buffer(), array->shapeInfo(), array->specialBuffer(), + array->specialShapeInfo(), array->buffer(), array->shapeInfo(), + array->specialBuffer(), array->specialShapeInfo(), + arguments.argumentsAsT(array->dataType())); + pm.synchronize(); - NativeOpExecutioner::execRandom(context, random::ExponentialDistribution, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), arguments.argumentsAsT(array->dataType())); - pm.synchronize(); + NDArray::registerSpecialUse({array}, {}); +} - NDArray::registerSpecialUse({array}, {}); - } +void RandomLauncher::fillExponential(sd::LaunchContext* context, + sd::graph::RandomGenerator& rng, + NDArray* array, double lambda) { + ExtraArguments arguments({lambda}); + PointersManager pm(context, "fillExponential"); - void RandomLauncher::fillLogNormal(sd::LaunchContext *context, sd::graph::RandomGenerator& rng, NDArray* array, double mean, double stdev) { - ExtraArguments arguments({mean, stdev}); - PointersManager pm(context, "fillLogNormal"); + NDArray::prepareSpecialUse({array}, {}); - NDArray::prepareSpecialUse({array}, {}); + NativeOpExecutioner::execRandom( + context, random::ExponentialDistribution, &rng, array->buffer(), + array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), + arguments.argumentsAsT(array->dataType())); + pm.synchronize(); - NativeOpExecutioner::execRandom(context, random::GaussianDistribution, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), arguments.argumentsAsT(array->dataType())); - pm.synchronize(); + NDArray::registerSpecialUse({array}, {}); +} - NDArray::registerSpecialUse({array}, {}); - } +void RandomLauncher::fillLogNormal(sd::LaunchContext* context, + sd::graph::RandomGenerator& rng, + NDArray* array, double mean, double stdev) { + ExtraArguments arguments({mean, stdev}); + PointersManager pm(context, "fillLogNormal"); - void RandomLauncher::fillTruncatedNormal(sd::LaunchContext *context, sd::graph::RandomGenerator& rng, NDArray* array, double mean, double stdev) { - ExtraArguments arguments({mean, stdev}); - PointersManager pm(context, "fillTruncatedNormal"); + NDArray::prepareSpecialUse({array}, {}); - NDArray::prepareSpecialUse({array}, {}); + NativeOpExecutioner::execRandom( + context, random::GaussianDistribution, &rng, array->buffer(), + array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), + array->buffer(), array->shapeInfo(), array->specialBuffer(), + array->specialShapeInfo(), array->buffer(), array->shapeInfo(), + array->specialBuffer(), array->specialShapeInfo(), + arguments.argumentsAsT(array->dataType())); + pm.synchronize(); - NativeOpExecutioner::execRandom(context, random::TruncatedNormalDistribution, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), arguments.argumentsAsT(array->dataType())); - pm.synchronize(); + NDArray::registerSpecialUse({array}, {}); +} - NDArray::registerSpecialUse({array}, {}); - } +void RandomLauncher::fillTruncatedNormal(sd::LaunchContext* context, + sd::graph::RandomGenerator& rng, + NDArray* array, double mean, + double stdev) { + ExtraArguments arguments({mean, stdev}); + PointersManager pm(context, "fillTruncatedNormal"); + + NDArray::prepareSpecialUse({array}, {}); + + NativeOpExecutioner::execRandom( + context, random::TruncatedNormalDistribution, &rng, array->buffer(), + array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), + array->buffer(), array->shapeInfo(), array->specialBuffer(), + array->specialShapeInfo(), array->buffer(), array->shapeInfo(), + array->specialBuffer(), array->specialShapeInfo(), + arguments.argumentsAsT(array->dataType())); + pm.synchronize(); + + NDArray::registerSpecialUse({array}, {}); +} - void RandomLauncher::fillBinomial(sd::LaunchContext *context, sd::graph::RandomGenerator& rng, NDArray* array, int trials, double prob) { - ExtraArguments arguments({(double) trials, prob}); - PointersManager pm(context, "fillBinomial"); +void RandomLauncher::fillBinomial(sd::LaunchContext* context, + sd::graph::RandomGenerator& rng, + NDArray* array, int trials, double prob) { + ExtraArguments arguments({(double)trials, prob}); + PointersManager pm(context, "fillBinomial"); - NDArray::prepareSpecialUse({array}, {}); + NDArray::prepareSpecialUse({array}, {}); - NativeOpExecutioner::execRandom(context, random::BinomialDistributionEx, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), arguments.argumentsAsT(array->dataType())); - pm.synchronize(); + NativeOpExecutioner::execRandom( + context, random::BinomialDistributionEx, &rng, array->buffer(), + array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), + array->buffer(), array->shapeInfo(), array->specialBuffer(), + array->specialShapeInfo(), array->buffer(), array->shapeInfo(), + array->specialBuffer(), array->specialShapeInfo(), + arguments.argumentsAsT(array->dataType())); + pm.synchronize(); - NDArray::registerSpecialUse({array}, {}); - } + NDArray::registerSpecialUse({array}, {}); } +} // namespace sd diff --git a/libnd4j/include/helpers/impl/ShapeBuilders.cpp b/libnd4j/include/helpers/impl/ShapeBuilders.cpp index 7c0c7fed6577..2fde57f5f7e4 100644 --- a/libnd4j/include/helpers/impl/ShapeBuilders.cpp +++ b/libnd4j/include/helpers/impl/ShapeBuilders.cpp @@ -22,132 +22,155 @@ namespace sd { +Nd4jLong* ShapeBuilders::createScalarShapeInfo( + const sd::DataType dataType, sd::memory::Workspace* workspace) { + Nd4jLong* newShape; + ALLOCATE(newShape, workspace, shape::shapeInfoLength(0), Nd4jLong); + newShape[0] = 0; + newShape[1] = 0; + newShape[2] = 1; + newShape[3] = 99; + + sd::ArrayOptions::setDataType(newShape, dataType); + + return newShape; +} - Nd4jLong* ShapeBuilders::createScalarShapeInfo(const sd::DataType dataType, sd::memory::Workspace* workspace) { - Nd4jLong *newShape; - ALLOCATE(newShape, workspace, shape::shapeInfoLength(0), Nd4jLong); - newShape[0] = 0; - newShape[1] = 0; - newShape[2] = 1; - newShape[3] = 99; - - sd::ArrayOptions::setDataType(newShape, dataType); - - return newShape; - } +Nd4jLong* ShapeBuilders::createVectorShapeInfo( + const sd::DataType dataType, const Nd4jLong length, + sd::memory::Workspace* workspace) { + Nd4jLong* newShape; + ALLOCATE(newShape, workspace, shape::shapeInfoLength(1), Nd4jLong); - Nd4jLong* ShapeBuilders::createVectorShapeInfo(const sd::DataType dataType, const Nd4jLong length, sd::memory::Workspace* workspace) { - Nd4jLong *newShape; - ALLOCATE(newShape, workspace, shape::shapeInfoLength(1), Nd4jLong); + newShape[0] = 1; + newShape[1] = length; + newShape[2] = 1; + newShape[3] = 0; + newShape[4] = 1; + newShape[5] = 99; - newShape[0] = 1; - newShape[1] = length; - newShape[2] = 1; - newShape[3] = 0; - newShape[4] = 1; - newShape[5] = 99; + sd::ArrayOptions::setDataType(newShape, dataType); - sd::ArrayOptions::setDataType(newShape, dataType); + return newShape; +} - return newShape; +//////////////////////////////////////////////////////////////////////////////// +Nd4jLong* ShapeBuilders::createShapeInfo(const sd::DataType dataType, + const char order, int rank, + const Nd4jLong* shapeOnly, + memory::Workspace* workspace) { + Nd4jLong* shapeInfo = nullptr; + + if (rank == 0) { // scalar case + shapeInfo = ShapeBuilders::createScalarShapeInfo(dataType, workspace); + } else { + ALLOCATE(shapeInfo, workspace, shape::shapeInfoLength(rank), Nd4jLong); + shapeInfo[0] = rank; + bool isEmpty = false; + for (int i = 0; i < rank; ++i) { + shapeInfo[i + 1] = shapeOnly[i]; + + if (shapeOnly[i] == 0) isEmpty = true; } - //////////////////////////////////////////////////////////////////////////////// - Nd4jLong* ShapeBuilders::createShapeInfo(const sd::DataType dataType, const char order, int rank, const Nd4jLong* shapeOnly, memory::Workspace* workspace) { - Nd4jLong* shapeInfo = nullptr; - - if(rank == 0) { // scalar case - shapeInfo = ShapeBuilders::createScalarShapeInfo(dataType, workspace); - } - else { - ALLOCATE(shapeInfo, workspace, shape::shapeInfoLength(rank), Nd4jLong); - shapeInfo[0] = rank; - bool isEmpty = false; - for(int i = 0; i < rank; ++i) { - shapeInfo[i + 1] = shapeOnly[i]; - - if (shapeOnly[i] == 0) - isEmpty = true; - } - - if (!isEmpty) { - shape::updateStrides(shapeInfo, order); - } - else { - shapeInfo[shape::shapeInfoLength(rank) - 1] = order; - memset(shape::stride(shapeInfo), 0, rank * sizeof(Nd4jLong)); - ArrayOptions::setPropertyBit(shapeInfo, ARRAY_EMPTY); - } - - sd::ArrayOptions::setDataType(shapeInfo, dataType); - } - - return shapeInfo; + if (!isEmpty) { + shape::updateStrides(shapeInfo, order); + } else { + shapeInfo[shape::shapeInfoLength(rank) - 1] = order; + memset(shape::stride(shapeInfo), 0, rank * sizeof(Nd4jLong)); + ArrayOptions::setPropertyBit(shapeInfo, ARRAY_EMPTY); } - Nd4jLong* ShapeBuilders::emptyShapeInfo(const sd::DataType dataType, memory::Workspace* workspace) { - auto shapeInfo = createScalarShapeInfo(dataType, workspace); - ArrayOptions::setPropertyBit(shapeInfo, ARRAY_EMPTY); - return shapeInfo; - } + sd::ArrayOptions::setDataType(shapeInfo, dataType); + } - Nd4jLong* ShapeBuilders::emptyShapeInfo(const sd::DataType dataType, const char order, const std::vector &shape, memory::Workspace* workspace) { - auto shapeInfo = createShapeInfo(dataType, order, shape, workspace); - memset(shape::stride(shapeInfo), 0, shape.size() * sizeof(Nd4jLong)); - ArrayOptions::setPropertyBit(shapeInfo, ARRAY_EMPTY); - return shapeInfo; - } + return shapeInfo; +} -//////////////////////////////////////////////////////////////////////////////// - Nd4jLong* ShapeBuilders::createShapeInfo(const sd::DataType dataType, const char order, const std::vector& shapeOnly, memory::Workspace* workspace) { +Nd4jLong* ShapeBuilders::emptyShapeInfo(const sd::DataType dataType, + memory::Workspace* workspace) { + auto shapeInfo = createScalarShapeInfo(dataType, workspace); + ArrayOptions::setPropertyBit(shapeInfo, ARRAY_EMPTY); + return shapeInfo; +} - return ShapeBuilders::createShapeInfo(dataType, order, shapeOnly.size(), shapeOnly.data(), workspace); - } +Nd4jLong* ShapeBuilders::emptyShapeInfo(const sd::DataType dataType, + const char order, + const std::vector& shape, + memory::Workspace* workspace) { + auto shapeInfo = createShapeInfo(dataType, order, shape, workspace); + memset(shape::stride(shapeInfo), 0, shape.size() * sizeof(Nd4jLong)); + ArrayOptions::setPropertyBit(shapeInfo, ARRAY_EMPTY); + return shapeInfo; +} //////////////////////////////////////////////////////////////////////////////// - Nd4jLong* ShapeBuilders::createShapeInfo(const sd::DataType dataType, const char order, const std::initializer_list& shapeOnly, memory::Workspace* workspace) { - - return ShapeBuilders::createShapeInfo(dataType, order, std::vector(shapeOnly), workspace); - } +Nd4jLong* ShapeBuilders::createShapeInfo(const sd::DataType dataType, + const char order, + const std::vector& shapeOnly, + memory::Workspace* workspace) { + return ShapeBuilders::createShapeInfo(dataType, order, shapeOnly.size(), + shapeOnly.data(), workspace); +} //////////////////////////////////////////////////////////////////////////////// - Nd4jLong* ShapeBuilders::copyShapeInfo(const Nd4jLong* inShapeInfo, const bool copyStrides, memory::Workspace* workspace) { +Nd4jLong* ShapeBuilders::createShapeInfo( + const sd::DataType dataType, const char order, + const std::initializer_list& shapeOnly, + memory::Workspace* workspace) { + return ShapeBuilders::createShapeInfo( + dataType, order, std::vector(shapeOnly), workspace); +} - Nd4jLong *outShapeInfo = nullptr; - ALLOCATE(outShapeInfo, workspace, shape::shapeInfoLength(inShapeInfo), Nd4jLong); +//////////////////////////////////////////////////////////////////////////////// +Nd4jLong* ShapeBuilders::copyShapeInfo(const Nd4jLong* inShapeInfo, + const bool copyStrides, + memory::Workspace* workspace) { + Nd4jLong* outShapeInfo = nullptr; + ALLOCATE(outShapeInfo, workspace, shape::shapeInfoLength(inShapeInfo), + Nd4jLong); - memcpy(outShapeInfo, inShapeInfo, shape::shapeInfoByteLength(inShapeInfo)); + memcpy(outShapeInfo, inShapeInfo, shape::shapeInfoByteLength(inShapeInfo)); - if(!copyStrides) - shape::updateStrides(outShapeInfo, shape::order(outShapeInfo)); + if (!copyStrides) + shape::updateStrides(outShapeInfo, shape::order(outShapeInfo)); - return outShapeInfo; - } + return outShapeInfo; +} //////////////////////////////////////////////////////////////////////////////// - Nd4jLong* ShapeBuilders::copyShapeInfoAndType(const Nd4jLong* inShapeInfo, const DataType dtype, const bool copyStrides, memory::Workspace* workspace) { - - Nd4jLong* outShapeInfo = ShapeBuilders::copyShapeInfo(inShapeInfo, copyStrides, workspace); - ArrayOptions::setDataType(outShapeInfo, dtype); - - return outShapeInfo; - } +Nd4jLong* ShapeBuilders::copyShapeInfoAndType(const Nd4jLong* inShapeInfo, + const DataType dtype, + const bool copyStrides, + memory::Workspace* workspace) { + Nd4jLong* outShapeInfo = + ShapeBuilders::copyShapeInfo(inShapeInfo, copyStrides, workspace); + ArrayOptions::setDataType(outShapeInfo, dtype); + + return outShapeInfo; +} //////////////////////////////////////////////////////////////////////////////// - Nd4jLong* ShapeBuilders::copyShapeInfoAndType(const Nd4jLong* inShapeInfo, const Nd4jLong* shapeInfoToGetTypeFrom, const bool copyStrides, memory::Workspace* workspace) { - - return ShapeBuilders::copyShapeInfoAndType(inShapeInfo, ArrayOptions::dataType(shapeInfoToGetTypeFrom), copyStrides, workspace); - } +Nd4jLong* ShapeBuilders::copyShapeInfoAndType( + const Nd4jLong* inShapeInfo, const Nd4jLong* shapeInfoToGetTypeFrom, + const bool copyStrides, memory::Workspace* workspace) { + return ShapeBuilders::copyShapeInfoAndType( + inShapeInfo, ArrayOptions::dataType(shapeInfoToGetTypeFrom), copyStrides, + workspace); +} //////////////////////////////////////////////////////////////////////////////// -Nd4jLong* ShapeBuilders::copyShapeInfoWithoutUnites(const Nd4jLong* inShapeInfo, const int dimsSize, const int* dimsToExclude, memory::Workspace* workspace) { - - Nd4jLong *outShapeInfo = nullptr; - ALLOCATE(outShapeInfo, workspace, shape::shapeInfoLength(inShapeInfo[0] - dimsSize), Nd4jLong); +Nd4jLong* ShapeBuilders::copyShapeInfoWithoutUnites( + const Nd4jLong* inShapeInfo, const int dimsSize, const int* dimsToExclude, + memory::Workspace* workspace) { + Nd4jLong* outShapeInfo = nullptr; + ALLOCATE(outShapeInfo, workspace, + shape::shapeInfoLength(inShapeInfo[0] - dimsSize), Nd4jLong); - shape::excludeUnitiesFromShapeInfo(inShapeInfo, dimsSize, dimsToExclude, outShapeInfo); + shape::excludeUnitiesFromShapeInfo(inShapeInfo, dimsSize, dimsToExclude, + outShapeInfo); - return outShapeInfo; + return outShapeInfo; } -} \ No newline at end of file +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/helpers/impl/ShapeUtils.cpp b/libnd4j/include/helpers/impl/ShapeUtils.cpp index 2c189cff1590..4442cdb5dd3b 100644 --- a/libnd4j/include/helpers/impl/ShapeUtils.cpp +++ b/libnd4j/include/helpers/impl/ShapeUtils.cpp @@ -18,1053 +18,1213 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include +#include #include + +#include #include #include -#include #include -#include - namespace sd { ////////////////////////////////////////////////////////////////////////// -// evaluate shape for array resulting from tensorDot operation, also evaluate shapes and dimensions permutations for transposition of two input arrays -std::vector ShapeUtils::evalShapeForTensorDot(const Nd4jLong* aShapeInfo, const Nd4jLong* bShapeInfo, std::vector axesA, std::vector axesB, std::vector& permutAt, std::vector& permutBt, std::vector& shapeAt, std::vector& shapeBt) { - - int axeAsize = (int) axesA.size(); - int axeBsize = (int) axesB.size(); - int aRank = aShapeInfo[0]; - int bRank = bShapeInfo[0]; - - if(axeAsize != axeBsize) - throw std::runtime_error("ShapeUtils::evalShapeForTensorDot method: the numbers of a axes and b axes to make dot product along must have identical values !"); - if(axeAsize > aRank || axeBsize > bRank) - throw std::runtime_error("ShapeUtils::evalShapeForTensorDot method: the length of vector of a or b axes is larger than array rank !"); - - // axes validation - for (int i = 0; i < axeBsize; i++) { - if (axesA[i] < 0) - axesA[i] += aRank; - if (axesB[i] < 0) - axesB[i] += bRank; - if (aShapeInfo[axesA[i] + 1] != bShapeInfo[axesB[i] + 1]) - throw std::runtime_error("ShapeUtils::evalShapeForTensorDot method: the dimensions at given axes for both input arrays must be the same !"); - } - - // check whether axesA and axesB contain only unique numbers - std::set uniqueElems(axesA.begin(), axesA.end()); - if((int)uniqueElems.size() != axeAsize) - throw std::runtime_error("ShapeUtils::evalShapeForTensorDot method: the vector of a axes contains duplicates !"); - uniqueElems.clear(); - uniqueElems = std::set(axesB.begin(), axesB.end()); - if((int)uniqueElems.size() != axeBsize) - throw std::runtime_error("ShapeUtils::evalShapeForTensorDot method: the vector of b axes contains duplicates !"); - - std::vector list_A, list_B; - for (int i = 0; i < aRank; i++) - if (std::find(axesA.begin(), axesA.end(), i) == axesA.end()) - list_A.emplace_back(i); - for (int i = 0; i < bRank; i++) - if (std::find(axesB.begin(), axesB.end(), i) == axesB.end()) - list_B.emplace_back(i); - - permutAt = list_A; - permutAt.insert(permutAt.end(), axesA.begin(), axesA.end()); - permutBt = axesB; - permutBt.insert(permutBt.end(), list_B.begin(), list_B.end()); - - // if permut contains something like {0,1,2,..rank-1}, then there is no need to make permutation and we return empty vector in this case - uint i1, i2; - for(i1 = 0; i1 < aRank; ++i1) - if(permutAt[i1] != i1) - break; - if(i1 == aRank) - permutAt = {}; - for(i2 = 0; i2 < bRank; ++i2) - if(permutBt[i2] != i2) - break; - if(i2 == bRank) - permutBt = {}; - - Nd4jLong n2 = 1; - for (int i = 0; i < axeAsize; i++) - n2 *= aShapeInfo[axesA[i] + 1]; - shapeAt = {shape::length(aShapeInfo) / n2, n2}; - - std::vector oldShapeA; - oldShapeA.resize(list_A.size()); - for (int i = 0; i < oldShapeA.size(); ++i) - oldShapeA[i] = aShapeInfo[list_A[i] + 1]; - - - Nd4jLong n3 = 1; - for (int i = 0; i < axeBsize; i++) - n3 *= bShapeInfo[axesB[i] + 1]; - shapeBt = {n3, shape::length(bShapeInfo) / n3}; - - std::vector oldShapeB; - oldShapeB.resize(list_B.size()); - for (int i = 0; i < oldShapeB.size(); i++) - oldShapeB[i] = bShapeInfo[list_B[i] + 1]; - - std::vector aPlusB(oldShapeA); - aPlusB.insert(aPlusB.end(), oldShapeB.begin(), oldShapeB.end()); - - return aPlusB; +// evaluate shape for array resulting from tensorDot operation, also evaluate +// shapes and dimensions permutations for transposition of two input arrays +std::vector ShapeUtils::evalShapeForTensorDot( + const Nd4jLong* aShapeInfo, const Nd4jLong* bShapeInfo, + std::vector axesA, std::vector axesB, std::vector& permutAt, + std::vector& permutBt, std::vector& shapeAt, + std::vector& shapeBt) { + int axeAsize = (int)axesA.size(); + int axeBsize = (int)axesB.size(); + int aRank = aShapeInfo[0]; + int bRank = bShapeInfo[0]; + + if (axeAsize != axeBsize) + throw std::runtime_error( + "ShapeUtils::evalShapeForTensorDot method: the numbers of a axes and b " + "axes to make dot product along must have identical values !"); + if (axeAsize > aRank || axeBsize > bRank) + throw std::runtime_error( + "ShapeUtils::evalShapeForTensorDot method: the length of vector of a " + "or b axes is larger than array rank !"); + + // axes validation + for (int i = 0; i < axeBsize; i++) { + if (axesA[i] < 0) axesA[i] += aRank; + if (axesB[i] < 0) axesB[i] += bRank; + if (aShapeInfo[axesA[i] + 1] != bShapeInfo[axesB[i] + 1]) + throw std::runtime_error( + "ShapeUtils::evalShapeForTensorDot method: the dimensions at given " + "axes for both input arrays must be the same !"); + } + + // check whether axesA and axesB contain only unique numbers + std::set uniqueElems(axesA.begin(), axesA.end()); + if ((int)uniqueElems.size() != axeAsize) + throw std::runtime_error( + "ShapeUtils::evalShapeForTensorDot method: the vector of a axes " + "contains duplicates !"); + uniqueElems.clear(); + uniqueElems = std::set(axesB.begin(), axesB.end()); + if ((int)uniqueElems.size() != axeBsize) + throw std::runtime_error( + "ShapeUtils::evalShapeForTensorDot method: the vector of b axes " + "contains duplicates !"); + + std::vector list_A, list_B; + for (int i = 0; i < aRank; i++) + if (std::find(axesA.begin(), axesA.end(), i) == axesA.end()) + list_A.emplace_back(i); + for (int i = 0; i < bRank; i++) + if (std::find(axesB.begin(), axesB.end(), i) == axesB.end()) + list_B.emplace_back(i); + + permutAt = list_A; + permutAt.insert(permutAt.end(), axesA.begin(), axesA.end()); + permutBt = axesB; + permutBt.insert(permutBt.end(), list_B.begin(), list_B.end()); + + // if permut contains something like {0,1,2,..rank-1}, then there is no need + // to make permutation and we return empty vector in this case + uint i1, i2; + for (i1 = 0; i1 < aRank; ++i1) + if (permutAt[i1] != i1) break; + if (i1 == aRank) permutAt = {}; + for (i2 = 0; i2 < bRank; ++i2) + if (permutBt[i2] != i2) break; + if (i2 == bRank) permutBt = {}; + + Nd4jLong n2 = 1; + for (int i = 0; i < axeAsize; i++) n2 *= aShapeInfo[axesA[i] + 1]; + shapeAt = {shape::length(aShapeInfo) / n2, n2}; + + std::vector oldShapeA; + oldShapeA.resize(list_A.size()); + for (int i = 0; i < oldShapeA.size(); ++i) + oldShapeA[i] = aShapeInfo[list_A[i] + 1]; + + Nd4jLong n3 = 1; + for (int i = 0; i < axeBsize; i++) n3 *= bShapeInfo[axesB[i] + 1]; + shapeBt = {n3, shape::length(bShapeInfo) / n3}; + + std::vector oldShapeB; + oldShapeB.resize(list_B.size()); + for (int i = 0; i < oldShapeB.size(); i++) + oldShapeB[i] = bShapeInfo[list_B[i] + 1]; + + std::vector aPlusB(oldShapeA); + aPlusB.insert(aPlusB.end(), oldShapeB.begin(), oldShapeB.end()); + + return aPlusB; } ////////////////////////////////////////////////////////////////////////// -std::vector ShapeUtils::evalShapeForTensorDot(const NDArray* a, const NDArray* b, const std::vector& axesA, const std::vector& axesB, std::vector& permutAt, std::vector& permutBt, std::vector& shapeAt, std::vector& shapeBt) { - - return evalShapeForTensorDot(a->shapeInfo(), b->shapeInfo(), axesA, axesB, permutAt, permutBt, shapeAt, shapeBt); +std::vector ShapeUtils::evalShapeForTensorDot( + const NDArray* a, const NDArray* b, const std::vector& axesA, + const std::vector& axesB, std::vector& permutAt, + std::vector& permutBt, std::vector& shapeAt, + std::vector& shapeBt) { + return evalShapeForTensorDot(a->shapeInfo(), b->shapeInfo(), axesA, axesB, + permutAt, permutBt, shapeAt, shapeBt); } - ////////////////////////////////////////////////////////////////////////// // evaluate output shape for reduce operation when input shape is empty - const Nd4jLong* ShapeUtils::evalReduceShapeInfoEmpty(const char order, std::vector& dimsToExclude, const Nd4jLong *shapeInfo, const sd::DataType dataType, const bool keepDims, sd::memory::Workspace* workspace) { +const Nd4jLong* ShapeUtils::evalReduceShapeInfoEmpty( + const char order, const std::vector& vdimsToExclude, + const Nd4jLong* shapeInfo, const sd::DataType dataType, const bool keepDims, + sd::memory::Workspace* workspace) { + auto dimsToExclude = vdimsToExclude; + + if (dimsToExclude.size() == 0) { // return copy of input shape + Nd4jLong* outShapeInfo = ShapeBuilders::copyShapeInfoAndType( + shapeInfo, dataType, true, workspace); + ShapeDescriptor descriptor(outShapeInfo, dataType); + RELEASE(outShapeInfo, workspace); + return ConstantShapeHelper::getInstance() + .bufferForShapeInfo(descriptor) + .primary(); + } - if (dimsToExclude.size() == 0) { // return copy of input shape - Nd4jLong* outShapeInfo = ShapeBuilders::copyShapeInfoAndType(shapeInfo, dataType, true, workspace); - ShapeDescriptor descriptor(outShapeInfo, dataType); - RELEASE(outShapeInfo, workspace); - return ConstantShapeHelper::getInstance().bufferForShapeInfo(descriptor).primary(); - } + const int rank = shape::rank(shapeInfo); + Nd4jLong* outShapeInfo = nullptr; - const int rank = shape::rank(shapeInfo); - Nd4jLong* outShapeInfo = nullptr; + if (dimsToExclude.size() == + rank) { // return scalar or shape filled with unities - if (dimsToExclude.size() == rank) { // return scalar or shape filled with unities + if (!keepDims) + outShapeInfo = ShapeBuilders::createScalarShapeInfo(dataType, workspace); + else + outShapeInfo = ShapeBuilders::createShapeInfo( + dataType, order, std::vector(rank, 1), workspace); + } else { + shape::checkDimensions(rank, dimsToExclude); - if(!keepDims) - outShapeInfo = ShapeBuilders::createScalarShapeInfo(dataType, workspace); + std::vector outShape; + + if (keepDims) { + outShape.assign(shapeInfo + 1, shapeInfo + 1 + rank); + for (const auto& dim : dimsToExclude) outShape[dim] = 1; + } else { + for (uint i = 0, j = 0; i < rank; ++i) { + if (j < dimsToExclude.size() && i == dimsToExclude[j]) + ++j; else - outShapeInfo = ShapeBuilders::createShapeInfo(dataType, order, std::vector(rank, 1), workspace); + outShape.emplace_back(shapeInfo[i + 1]); + } } - else { - shape::checkDimensions(rank, dimsToExclude); + outShapeInfo = + ShapeBuilders::createShapeInfo(dataType, order, outShape, workspace); + } - std::vector outShape; - - if(keepDims) { - outShape.assign(shapeInfo + 1, shapeInfo + 1 + rank); - for(const auto& dim : dimsToExclude) - outShape[dim] = 1; - } - else { - for (uint i = 0, j = 0; i < rank; ++i) { - if(j < dimsToExclude.size() && i == dimsToExclude[j]) - ++j; - else - outShape.emplace_back(shapeInfo[i + 1]); - } - } - - outShapeInfo = ShapeBuilders::createShapeInfo(dataType, order, outShape, workspace); - } - - ShapeDescriptor descriptor(outShapeInfo, dataType); - RELEASE(outShapeInfo, workspace); - return ConstantShapeHelper::getInstance().bufferForShapeInfo(descriptor).primary(); + ShapeDescriptor descriptor(outShapeInfo, dataType); + RELEASE(outShapeInfo, workspace); + return ConstantShapeHelper::getInstance() + .bufferForShapeInfo(descriptor) + .primary(); } - const Nd4jLong* ShapeUtils::evalReduceShapeInfo(const char order, std::vector& dimsToExclude, const NDArray& arr, const bool keepDims, const bool supportOldShapes, sd::memory::Workspace* workspace) { - return evalReduceShapeInfo(order, dimsToExclude, arr, arr.dataType(), keepDims, supportOldShapes, workspace); - } +const Nd4jLong* ShapeUtils::evalReduceShapeInfo( + const char order, const std::vector& dimsToExclude, const NDArray& arr, + const bool keepDims, const bool supportOldShapes, + sd::memory::Workspace* workspace) { + return evalReduceShapeInfo(order, dimsToExclude, arr, arr.dataType(), + keepDims, supportOldShapes, workspace); +} - const Nd4jLong* ShapeUtils::evalReduceShapeInfo(const char order, std::vector& dimsToExclude, const Nd4jLong* shapeInfo, const bool keepDims, const bool supportOldShapes, sd::memory::Workspace* workspace) { - return evalReduceShapeInfo(order, dimsToExclude, shapeInfo, ArrayOptions::dataType(shapeInfo), keepDims, supportOldShapes, workspace); - } +const Nd4jLong* ShapeUtils::evalReduceShapeInfo( + const char order, const std::vector& dimsToExclude, + const Nd4jLong* shapeInfo, const bool keepDims, const bool supportOldShapes, + sd::memory::Workspace* workspace) { + return evalReduceShapeInfo(order, dimsToExclude, shapeInfo, + ArrayOptions::dataType(shapeInfo), keepDims, + supportOldShapes, workspace); +} ////////////////////////////////////////////////////////////////////////// - const Nd4jLong* ShapeUtils::evalReduceShapeInfo(const char order, std::vector& dimsToExclude, const NDArray& arr, const sd::DataType dataType, const bool keepDims, const bool supportOldShapes, sd::memory::Workspace* workspace) { - return evalReduceShapeInfo(order, dimsToExclude, arr.shapeInfo(), dataType, keepDims, supportOldShapes, workspace); - } +const Nd4jLong* ShapeUtils::evalReduceShapeInfo( + const char order, const std::vector& dimsToExclude, const NDArray& arr, + const sd::DataType dataType, const bool keepDims, + const bool supportOldShapes, sd::memory::Workspace* workspace) { + return evalReduceShapeInfo(order, dimsToExclude, arr.shapeInfo(), dataType, + keepDims, supportOldShapes, workspace); +} ////////////////////////////////////////////////////////////////////////// // evaluate shape resulting from reduce operation - const Nd4jLong* ShapeUtils::evalReduceShapeInfo(const char order, std::vector& dimsToExclude, const Nd4jLong *shapeInfo, const sd::DataType dataType, const bool keepDims, const bool supportOldShapes, sd::memory::Workspace* workspace) { - - if(ArrayOptions::arrayType(shapeInfo) == ArrayType::EMPTY) - return ShapeUtils::evalReduceShapeInfoEmpty(order, dimsToExclude, shapeInfo, dataType, keepDims, workspace); - - Nd4jLong* newShapeInfo = nullptr; - - int rank = shape::rank(const_cast(shapeInfo)); - - if (dimsToExclude.size() == 0) { // return scalar or array with len=1 in this case - - if(keepDims && rank > 1) { - ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(rank), Nd4jLong); - newShapeInfo[0] = rank; - for(int i = 0; i < rank; ++i) - newShapeInfo[i+1] = 1; - ShapeUtils::updateStridesAndType(newShapeInfo, shapeInfo, order); - ArrayOptions::setDataType(newShapeInfo, dataType); - - ShapeDescriptor descriptor(newShapeInfo, dataType); - RELEASE(newShapeInfo, workspace); - return ConstantShapeHelper::getInstance().bufferForShapeInfo(descriptor).primary(); - } - else if(supportOldShapes) { - ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(2), Nd4jLong); - shape::shapeOldScalar(dataType, newShapeInfo, 'c'); - ShapeDescriptor descriptor(newShapeInfo, dataType); - RELEASE(newShapeInfo, workspace); - return ConstantShapeHelper::getInstance().bufferForShapeInfo(descriptor).primary(); - } - else { - newShapeInfo = ShapeBuilders::createScalarShapeInfo(dataType, workspace); - ShapeDescriptor descriptor(newShapeInfo, dataType); - RELEASE(newShapeInfo, workspace); - return ConstantShapeHelper::getInstance().bufferForShapeInfo(descriptor).primary(); - } +const Nd4jLong* ShapeUtils::evalReduceShapeInfo( + const char order, const std::vector& vdimsToExclude, + const Nd4jLong* shapeInfo, const sd::DataType dataType, const bool keepDims, + const bool supportOldShapes, sd::memory::Workspace* workspace) { + auto dimsToExclude = vdimsToExclude; + + if (ArrayOptions::arrayType(shapeInfo) == ArrayType::EMPTY) + return ShapeUtils::evalReduceShapeInfoEmpty(order, dimsToExclude, shapeInfo, + dataType, keepDims, workspace); + + Nd4jLong* newShapeInfo = nullptr; + + int rank = shape::rank(const_cast(shapeInfo)); + + if (dimsToExclude.size() == + 0) { // return scalar or array with len=1 in this case + + if (keepDims && rank > 1) { + ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(rank), Nd4jLong); + newShapeInfo[0] = rank; + for (int i = 0; i < rank; ++i) newShapeInfo[i + 1] = 1; + ShapeUtils::updateStridesAndType(newShapeInfo, shapeInfo, order); + ArrayOptions::setDataType(newShapeInfo, dataType); + + ShapeDescriptor descriptor(newShapeInfo, dataType); + RELEASE(newShapeInfo, workspace); + return ConstantShapeHelper::getInstance() + .bufferForShapeInfo(descriptor) + .primary(); + } else if (supportOldShapes) { + ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(2), Nd4jLong); + shape::shapeOldScalar(dataType, newShapeInfo, 'c'); + ShapeDescriptor descriptor(newShapeInfo, dataType); + RELEASE(newShapeInfo, workspace); + return ConstantShapeHelper::getInstance() + .bufferForShapeInfo(descriptor) + .primary(); + } else { + newShapeInfo = ShapeBuilders::createScalarShapeInfo(dataType, workspace); + ShapeDescriptor descriptor(newShapeInfo, dataType); + RELEASE(newShapeInfo, workspace); + return ConstantShapeHelper::getInstance() + .bufferForShapeInfo(descriptor) + .primary(); } - - shape::checkDimensions(rank, dimsToExclude); - - int dimSize = dimsToExclude.size(); - - if(keepDims) { - - ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(rank), Nd4jLong); - newShapeInfo[0] = rank; - for(int i = 0; i < rank; ++i) - if (std::binary_search(dimsToExclude.begin(), dimsToExclude.end(), i)) // dimsToExclude is already sorted after shape::checkDimensions() has been applied - newShapeInfo[i+1] = 1; - else - newShapeInfo[i+1] = shapeInfo[i+1]; - - ShapeUtils::updateStridesAndType(newShapeInfo, shapeInfo, order); - ShapeDescriptor descriptor(newShapeInfo, dataType); - RELEASE(newShapeInfo, workspace); - return ConstantShapeHelper::getInstance().bufferForShapeInfo(descriptor).primary(); + } + + shape::checkDimensions(rank, dimsToExclude); + + int dimSize = dimsToExclude.size(); + + if (keepDims) { + ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(rank), Nd4jLong); + newShapeInfo[0] = rank; + for (int i = 0; i < rank; ++i) + if (std::binary_search(dimsToExclude.begin(), dimsToExclude.end(), + i)) // dimsToExclude is already sorted after + // shape::checkDimensions() has been applied + newShapeInfo[i + 1] = 1; + else + newShapeInfo[i + 1] = shapeInfo[i + 1]; + + ShapeUtils::updateStridesAndType(newShapeInfo, shapeInfo, order); + ShapeDescriptor descriptor(newShapeInfo, dataType); + RELEASE(newShapeInfo, workspace); + return ConstantShapeHelper::getInstance() + .bufferForShapeInfo(descriptor) + .primary(); + } + + int newRank = rank - dimSize; + if (newRank == 0 || + (dimSize == 1 && + dimsToExclude[0] == INT_MAX)) { // check whether given dimension is + // meant for the whole dimension + + if (supportOldShapes) { + ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(2), Nd4jLong); + shape::shapeOldScalar(ArrayOptions::dataType(shapeInfo), newShapeInfo, + 'c'); + ShapeDescriptor descriptor(newShapeInfo, dataType); + RELEASE(newShapeInfo, workspace); + return ConstantShapeHelper::getInstance() + .bufferForShapeInfo(descriptor) + .primary(); + } else { + newShapeInfo = ShapeBuilders::createScalarShapeInfo( + ArrayOptions::dataType(shapeInfo), workspace); + ShapeDescriptor descriptor(newShapeInfo, dataType); + RELEASE(newShapeInfo, workspace); + return ConstantShapeHelper::getInstance() + .bufferForShapeInfo(descriptor) + .primary(); } - - int newRank = rank - dimSize; - if (newRank==0 || (dimSize==1 && dimsToExclude[0]==INT_MAX)) { // check whether given dimension is meant for the whole dimension - - if(supportOldShapes) { - ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(2), Nd4jLong); - shape::shapeOldScalar(ArrayOptions::dataType(shapeInfo), newShapeInfo, 'c'); - ShapeDescriptor descriptor(newShapeInfo, dataType); - RELEASE(newShapeInfo, workspace); - return ConstantShapeHelper::getInstance().bufferForShapeInfo(descriptor).primary(); - } - else { - newShapeInfo = ShapeBuilders::createScalarShapeInfo(ArrayOptions::dataType(shapeInfo), workspace); - ShapeDescriptor descriptor(newShapeInfo, dataType); - RELEASE(newShapeInfo, workspace); - return ConstantShapeHelper::getInstance().bufferForShapeInfo(descriptor).primary(); - } - } - - ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(newRank), Nd4jLong); - newShapeInfo[0] = newRank; // set rank - int j=1; - for(int i = 0; i < rank; ++i) - if (!std::binary_search(dimsToExclude.begin(), dimsToExclude.end(), i)) // dimsToExclude is already sorted after shape::checkDimensions() has been applied - newShapeInfo[j++] = shapeInfo[i+1]; - - //ensure whether vector has proper shape for old shape type - if (newRank == 1 && supportOldShapes) { - int oldValue = newShapeInfo[1]; - RELEASE(newShapeInfo, workspace); - ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(2), Nd4jLong); // set newRank = 2 - newShapeInfo[0] = 2; - if (dimsToExclude[0] == 0) { - newShapeInfo[1] = 1; - newShapeInfo[2] = oldValue; - } - else { - newShapeInfo[1] = oldValue; - newShapeInfo[2] = 1; - } + } + + ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(newRank), Nd4jLong); + newShapeInfo[0] = newRank; // set rank + int j = 1; + for (int i = 0; i < rank; ++i) + if (!std::binary_search(dimsToExclude.begin(), dimsToExclude.end(), + i)) // dimsToExclude is already sorted after + // shape::checkDimensions() has been applied + newShapeInfo[j++] = shapeInfo[i + 1]; + + // ensure whether vector has proper shape for old shape type + if (newRank == 1 && supportOldShapes) { + int oldValue = newShapeInfo[1]; + RELEASE(newShapeInfo, workspace); + ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(2), + Nd4jLong); // set newRank = 2 + newShapeInfo[0] = 2; + if (dimsToExclude[0] == 0) { + newShapeInfo[1] = 1; + newShapeInfo[2] = oldValue; + } else { + newShapeInfo[1] = oldValue; + newShapeInfo[2] = 1; } + } - ShapeUtils::updateStridesAndType(newShapeInfo, shapeInfo, order); + ShapeUtils::updateStridesAndType(newShapeInfo, shapeInfo, order); - ShapeDescriptor descriptor(newShapeInfo, dataType); - RELEASE(newShapeInfo, workspace); - return ConstantShapeHelper::getInstance().bufferForShapeInfo(descriptor).primary(); + ShapeDescriptor descriptor(newShapeInfo, dataType); + RELEASE(newShapeInfo, workspace); + return ConstantShapeHelper::getInstance() + .bufferForShapeInfo(descriptor) + .primary(); } ////////////////////////////////////////////////////////////////////////// // evaluate shape for array which is result of repeat operation applied to arr -std::vector ShapeUtils::evalRepeatShape(int axis, const std::vector& repeats, const NDArray& arr) { +std::vector ShapeUtils::evalRepeatShape( + int axis, const std::vector& repeats, const NDArray& arr) { + if (axis < 0) axis += arr.rankOf(); - if (axis < 0) - axis += arr.rankOf(); + if (repeats.size() != 1 && repeats.size() != arr.sizeAt(axis)) + throw std::invalid_argument( + "ShapeUtils::evalRepeatShape: size of repeats vector must be 1 or " + "equal to dimension at given axis !"); - if(repeats.size() != 1 && repeats.size() != arr.sizeAt(axis)) - throw std::invalid_argument("ShapeUtils::evalRepeatShape: size of repeats vector must be 1 or equal to dimension at given axis !"); + std::vector outShape = arr.getShapeAsVector(); - std::vector outShape = arr.getShapeAsVector(); + if (repeats.size() == 1) + outShape[axis] *= repeats[0]; + else + outShape[axis] = std::accumulate(repeats.begin(), repeats.end(), 0); - if(repeats.size() == 1) - outShape[axis] *= repeats[0]; - else - outShape[axis] = std::accumulate(repeats.begin(), repeats.end(), 0); - - return outShape; + return outShape; } ////////////////////////////////////////////////////////////////////////// // evaluate shapeInfo of permuted array - const Nd4jLong* ShapeUtils::evalPermShapeInfo(const int* dimensions, const int rank, const NDArray& arr, sd::memory::Workspace* workspace, const bool setContigStrides) { - - if (!arr.nonNull()) - throw std::runtime_error("ShapeUtils::evalPermShapeInfo static method: wrong arguments: array is nullptr!"); +const Nd4jLong* ShapeUtils::evalPermShapeInfo(const int* dimensions, + const int rank, + const NDArray& arr, + sd::memory::Workspace* workspace, + const bool setContigStrides) { + if (!arr.nonNull()) + throw std::runtime_error( + "ShapeUtils::evalPermShapeInfo static method: wrong arguments: array " + "is nullptr!"); - if (rank != arr.rankOf()) - throw std::runtime_error("ShapeUtils::evalPermShapeInfo static method: wrong arguments: rank is not suitable!"); + if (rank != arr.rankOf()) + throw std::runtime_error( + "ShapeUtils::evalPermShapeInfo static method: wrong arguments: rank is " + "not suitable!"); - auto shapeInfoLength = shape::shapeInfoLength(rank); + auto shapeInfoLength = shape::shapeInfoLength(rank); - // allocate memory for new array - shapeInfo - Nd4jLong *shapeInfoNew = nullptr; - ALLOCATE(shapeInfoNew, workspace, shapeInfoLength, Nd4jLong); + // allocate memory for new array - shapeInfo + Nd4jLong* shapeInfoNew = nullptr; + ALLOCATE(shapeInfoNew, workspace, shapeInfoLength, Nd4jLong); - // copy arr _shapeInfo into new array - memcpy(shapeInfoNew, arr.shapeInfo(), shape::shapeInfoByteLength(rank)); + // copy arr _shapeInfo into new array + memcpy(shapeInfoNew, arr.shapeInfo(), shape::shapeInfoByteLength(rank)); - // perform buffer permutation - shape::doPermuteShapeInfo(shapeInfoNew, dimensions, arr.lengthOf()); + // perform buffer permutation + shape::doPermuteShapeInfo(shapeInfoNew, dimensions, arr.lengthOf()); - if(setContigStrides) - shape::updateStrides(shapeInfoNew, arr.ordering()); + if (setContigStrides) shape::updateStrides(shapeInfoNew, arr.ordering()); - ShapeDescriptor descriptor(shapeInfoNew); + ShapeDescriptor descriptor(shapeInfoNew); - RELEASE(shapeInfoNew, workspace); - - return ConstantShapeHelper::getInstance().bufferForShapeInfo(descriptor).primary(); - } + RELEASE(shapeInfoNew, workspace); - ////////////////////////////////////////////////////////////////////////// - // evaluate shapeInfo of permuted array - const Nd4jLong* ShapeUtils::evalPermShapeInfo(const Nd4jLong *dimensions, const int rank, const NDArray& arr, sd::memory::Workspace* workspace) { + return ConstantShapeHelper::getInstance() + .bufferForShapeInfo(descriptor) + .primary(); +} - std::vector dims(dimensions, dimensions + rank); - return evalPermShapeInfo(dims.data(), rank, arr, workspace); - } +////////////////////////////////////////////////////////////////////////// +// evaluate shapeInfo of permuted array +const Nd4jLong* ShapeUtils::evalPermShapeInfo( + const Nd4jLong* dimensions, const int rank, const NDArray& arr, + sd::memory::Workspace* workspace) { + std::vector dims(dimensions, dimensions + rank); + return evalPermShapeInfo(dims.data(), rank, arr, workspace); +} ////////////////////////////////////////////////////////////////////////// // evaluate shapeInfo of transposed array - const Nd4jLong* ShapeUtils::evalTranspShapeInfo(const NDArray& arr, sd::memory::Workspace* workspace, const bool setContigStrides) { - - int rank = arr.rankOf(); - std::vector dimensions(rank); - for (int i = 0; i < rank; ++i) - dimensions[i] = rank - 1 - i; - - return evalPermShapeInfo(dimensions.data(), dimensions.size(), arr, workspace, setContigStrides); - } +const Nd4jLong* ShapeUtils::evalTranspShapeInfo( + const NDArray& arr, sd::memory::Workspace* workspace, + const bool setContigStrides) { + int rank = arr.rankOf(); + std::vector dimensions(rank); + for (int i = 0; i < rank; ++i) dimensions[i] = rank - 1 - i; + + return evalPermShapeInfo(dimensions.data(), dimensions.size(), arr, workspace, + setContigStrides); +} ////////////////////////////////////////////////////////////////////////// - bool ShapeUtils::copyVectorPart(std::vector& target, std::vector& source, int rank, int offset) { - if (source.size() < offset + rank) - return false; +bool ShapeUtils::copyVectorPart(std::vector& target, + std::vector& source, int rank, + int offset) { + if (source.size() < offset + rank) return false; - for (int e = offset; e < offset + rank; e++) - target.push_back(source[e]); - - return true; - } + for (int e = offset; e < offset + rank; e++) target.push_back(source[e]); + return true; +} ////////////////////////////////////////////////////////////////////////// -// return new (shorter) sorted dimensions array without dimensions that are present in input vector -std::vector ShapeUtils::evalDimsToExclude(const int rank, const int dimsLen, const int* dimensions) { - - std::vector newDimensions; - if(dimsLen == 0) { // if input vector is empty then return whole shape range - newDimensions.resize(rank); - std::iota(newDimensions.begin(), newDimensions.end(), 0); // fill with 0, 1, ... rank-1 - } - else { - bool isAbsent; - for(int i=0; i= 0 ? dimensions[j] : dimensions[j] + rank; - if(i == dim) { - isAbsent = false; - break; - } - } - if(isAbsent) - newDimensions.emplace_back(i); +// return new (shorter) sorted dimensions array without dimensions that are +// present in input vector +std::vector ShapeUtils::evalDimsToExclude(const int rank, + const int dimsLen, + const int* dimensions) { + std::vector newDimensions; + if (dimsLen == 0) { // if input vector is empty then return whole shape range + newDimensions.resize(rank); + std::iota(newDimensions.begin(), newDimensions.end(), + 0); // fill with 0, 1, ... rank-1 + } else { + bool isAbsent; + for (int i = 0; i < rank; ++i) { + isAbsent = true; + for (int j = 0; j < dimsLen; ++j) { + int dim = dimensions[j] >= 0 ? dimensions[j] : dimensions[j] + rank; + if (i == dim) { + isAbsent = false; + break; } + } + if (isAbsent) newDimensions.emplace_back(i); } + } - return newDimensions; + return newDimensions; } ////////////////////////////////////////////////////////////////////////// -std::vector ShapeUtils::evalDimsToExclude(const int rank, const std::vector& dimensions) { - - return ShapeUtils::evalDimsToExclude(rank, dimensions.size(), dimensions.data()); +std::vector ShapeUtils::evalDimsToExclude( + const int rank, const std::vector& dimensions) { + return ShapeUtils::evalDimsToExclude(rank, dimensions.size(), + dimensions.data()); } ////////////////////////////////////////////////////////////////////////// // check whether 2 arrays have mutually broadcastable shapes // shape comparison starts from the end -bool ShapeUtils::areShapesBroadcastable(const NDArray &arr1, const NDArray &arr2) { - return areShapesBroadcastable(arr1.shapeInfo(), arr2.shapeInfo()); +bool ShapeUtils::areShapesBroadcastable(const NDArray& arr1, + const NDArray& arr2) { + return areShapesBroadcastable(arr1.shapeInfo(), arr2.shapeInfo()); } -bool ShapeUtils::areShapesBroadcastable(const Nd4jLong *shapeInfo1, const Nd4jLong *shapeInfo2) { - int minRank = shape::rank(shapeInfo1) < shape::rank(shapeInfo2) ? shape::rank(shapeInfo1) : shape::rank(shapeInfo2); +bool ShapeUtils::areShapesBroadcastable(const Nd4jLong* shapeInfo1, + const Nd4jLong* shapeInfo2) { + int minRank = shape::rank(shapeInfo1) < shape::rank(shapeInfo2) + ? shape::rank(shapeInfo1) + : shape::rank(shapeInfo2); - for (int i = -1; i >= -minRank; --i) - if (shape::sizeAt(shapeInfo1, i) != shape::sizeAt(shapeInfo2, i) && shape::sizeAt(shapeInfo1, i) != 1 && shape::sizeAt(shapeInfo2, i) != 1) - return false; + for (int i = -1; i >= -minRank; --i) + if (shape::sizeAt(shapeInfo1, i) != shape::sizeAt(shapeInfo2, i) && + shape::sizeAt(shapeInfo1, i) != 1 && shape::sizeAt(shapeInfo2, i) != 1) + return false; - return true; + return true; } - bool ShapeUtils::areShapesBroadcastable(const std::vector& shape1, const std::vector& shape2) { - - const auto rank1 = shape1.size(); - const auto rank2 = shape2.size(); - const int minRank = rank1 < rank2 ? rank1 : rank2; - - for (int i = 1; i <= minRank; ++i) - if (shape1[rank1-i] != shape2[rank2-i] && shape1[rank1-i] != 1 && shape2[rank2-i] != 1) - return false; - - return true; - } - - ////////////////////////////////////////////////////////////////////////// - // check the possibility of broadcast operation, if true then return shapeInfo of resulting array - // if evalMinMax == false the array with larger rank has to be passed as first argument - bool ShapeUtils::evalBroadcastShapeInfo(const NDArray &max, const NDArray &min, const bool evalMinMax, const Nd4jLong*& resultShapeInfo, sd::memory::Workspace* workspace) { - return evalBroadcastShapeInfo(max.shapeInfo(), min.shapeInfo(), evalMinMax, resultShapeInfo, workspace); - } - - bool ShapeUtils::evalBroadcastShapeInfo(const Nd4jLong *max, const Nd4jLong *min, const bool evalMinMax, const Nd4jLong*& resultShapeInfo, sd::memory::Workspace* workspace) { - - // check whether broadcast operation is possible for input arrays - if(!areShapesBroadcastable(max, min)) - return false; - - auto maxShapeInfo = max; //max.shapeInfo(); - auto minShapeInfo = min; //min.shapeInfo(); - - if(evalMinMax && (shape::rank(max) < shape::rank(min))) { - maxShapeInfo = min; - minShapeInfo = max; - } - - const auto maxRank = shape::rank(maxShapeInfo); - const auto minRank = shape::rank(minShapeInfo); - - // evaluate shapeInfo for resulting array - if(resultShapeInfo != nullptr) - throw std::runtime_error("std::runtime_error(ShapeUtils::evalBroadcastShapeInfo method: the input pointer on shapeInfo must be empty (=nullptr) !"); - - Nd4jLong *tmpShapeInfo = nullptr; - ALLOCATE(tmpShapeInfo, workspace, shape::shapeInfoLength(maxRank), Nd4jLong); - - // FIXME: get rid of memcpy here - memcpy(tmpShapeInfo, maxShapeInfo, shape::shapeInfoByteLength(maxRank)); - for (int i = 0; i < minRank; ++i) - if((maxShapeInfo[maxRank-i] != 0 && maxShapeInfo[maxRank-i] < minShapeInfo[minRank-i]) || minShapeInfo[minRank-i] == 0) - tmpShapeInfo[maxRank - i] = minShapeInfo[minRank-i]; - - ShapeUtils::updateStridesAndType(tmpShapeInfo, DataTypeUtils::pickPairwiseResultType(maxShapeInfo, minShapeInfo), shape::order(maxShapeInfo)); - - if (shape::isEmpty(max) || shape::isEmpty(min)) { - ArrayOptions::setPropertyBit(tmpShapeInfo, ARRAY_EMPTY); - memset(shape::stride(tmpShapeInfo), 0, shape::rank(tmpShapeInfo) * sizeof(Nd4jLong)); - } - - ShapeDescriptor descriptor(tmpShapeInfo); - RELEASE(tmpShapeInfo, workspace); - resultShapeInfo = ConstantShapeHelper::getInstance().bufferForShapeInfo(descriptor).primary(); - - return true; - } - - ////////////////////////////////////////////////////////////////////////// - // check the possibility of broadcast operation for set of arrays, if true then return resulting broadcasted shapeInfo - bool ShapeUtils::evalCommonBroadcastShapeInfo(const std::vector& arrays, Nd4jLong*& resultShapeInfo, memory::Workspace* workspace) { - - if(resultShapeInfo != nullptr) - throw std::runtime_error("ShapeUtils::evalCommonBroadcastShapeInfo method: the input pointer on shapeInfo must be empty (=nullptr) !"); - - int size = arrays.size(); - int maxRank = arrays[size - 1]->rankOf(); - - for(int i = 0; i < size - 1; ++i) { - if(arrays[i]->rankOf() > maxRank) - maxRank = arrays[i]->rankOf(); - for(int j = i + 1; j < size; ++j) - if(!areShapesBroadcastable(*arrays[i], *arrays[j])) - return false; - } - - Nd4jLong *tmpShapeInfo = nullptr; - ALLOCATE(tmpShapeInfo, workspace, shape::shapeInfoLength(maxRank), Nd4jLong); - memset(tmpShapeInfo, 0, shape::shapeInfoByteLength(maxRank)); - tmpShapeInfo[0] = maxRank; - - for(const auto& item : arrays ) { - for(int i = -1; i >= -item->rankOf(); --i) - if(tmpShapeInfo[i + 1 + maxRank] < item->sizeAt(i)) - tmpShapeInfo[i + 1 + maxRank] = item->sizeAt(i); - } - - shape::updateStrides(tmpShapeInfo, arrays[0]->ordering()); - ArrayOptions::setDataType(tmpShapeInfo, arrays[0]->dataType()); - - ShapeDescriptor descriptor(tmpShapeInfo); - RELEASE(tmpShapeInfo, workspace); - resultShapeInfo = const_cast(ConstantShapeHelper::getInstance().createShapeInfo(descriptor)); - - return true; - } - - - ////////////////////////////////////////////////////////////////////////// - // return sorted vector of dimensions common (same) for two arrays, dimensions values corresponds to array with bigger rank - // for example if arr1{2,7}, arr2{2,5,4,7} then vector = {0,3} - std::vector ShapeUtils::getDimsWithSameShape(const NDArray& arr1, const NDArray& arr2) { +bool ShapeUtils::areShapesBroadcastable(const std::vector& shape1, + const std::vector& shape2) { + const auto rank1 = shape1.size(); + const auto rank2 = shape2.size(); + const int minRank = rank1 < rank2 ? rank1 : rank2; - const NDArray *min, *max; + for (int i = 1; i <= minRank; ++i) + if (shape1[rank1 - i] != shape2[rank2 - i] && shape1[rank1 - i] != 1 && + shape2[rank2 - i] != 1) + return false; - if(arr1.rankOf() >= arr2.rankOf()) { - max = &arr1; - min = &arr2; - } - else { - max = &arr2; - min = &arr1; - } - - const int rankDiff = max->rankOf() - min->rankOf(); + return true; +} - std::vector dims; +////////////////////////////////////////////////////////////////////////// +// check the possibility of broadcast operation, if true then return shapeInfo +// of resulting array if evalMinMax == false the array with larger rank has to +// be passed as first argument +bool ShapeUtils::evalBroadcastShapeInfo(const NDArray& max, const NDArray& min, + const bool evalMinMax, + const Nd4jLong*& resultShapeInfo, + sd::memory::Workspace* workspace) { + return evalBroadcastShapeInfo(max.shapeInfo(), min.shapeInfo(), evalMinMax, + resultShapeInfo, workspace); +} - for (int i = 0; i < min->rankOf(); ++i) - if (min->sizeAt(i) == max->sizeAt(rankDiff + i)) - dims.emplace_back(rankDiff + i); +bool ShapeUtils::evalBroadcastShapeInfo(const Nd4jLong* max, + const Nd4jLong* min, + const bool evalMinMax, + const Nd4jLong*& resultShapeInfo, + sd::memory::Workspace* workspace) { + // check whether broadcast operation is possible for input arrays + if (!areShapesBroadcastable(max, min)) return false; + + auto maxShapeInfo = max; // max.shapeInfo(); + auto minShapeInfo = min; // min.shapeInfo(); + + if (evalMinMax && (shape::rank(max) < shape::rank(min))) { + maxShapeInfo = min; + minShapeInfo = max; + } + + const auto maxRank = shape::rank(maxShapeInfo); + const auto minRank = shape::rank(minShapeInfo); + + // evaluate shapeInfo for resulting array + if (resultShapeInfo != nullptr) + throw std::runtime_error( + "std::runtime_error(ShapeUtils::evalBroadcastShapeInfo method: the " + "input pointer on shapeInfo must be empty (=nullptr) !"); + + Nd4jLong* tmpShapeInfo = nullptr; + ALLOCATE(tmpShapeInfo, workspace, shape::shapeInfoLength(maxRank), Nd4jLong); + + // FIXME: get rid of memcpy here + memcpy(tmpShapeInfo, maxShapeInfo, shape::shapeInfoByteLength(maxRank)); + for (int i = 0; i < minRank; ++i) + if ((maxShapeInfo[maxRank - i] != 0 && + maxShapeInfo[maxRank - i] < minShapeInfo[minRank - i]) || + minShapeInfo[minRank - i] == 0) + tmpShapeInfo[maxRank - i] = minShapeInfo[minRank - i]; + + ShapeUtils::updateStridesAndType( + tmpShapeInfo, + DataTypeUtils::pickPairwiseResultType(maxShapeInfo, minShapeInfo), + shape::order(maxShapeInfo)); + + if (shape::isEmpty(max) || shape::isEmpty(min)) { + ArrayOptions::setPropertyBit(tmpShapeInfo, ARRAY_EMPTY); + memset(shape::stride(tmpShapeInfo), 0, + shape::rank(tmpShapeInfo) * sizeof(Nd4jLong)); + } + + ShapeDescriptor descriptor(tmpShapeInfo); + RELEASE(tmpShapeInfo, workspace); + resultShapeInfo = ConstantShapeHelper::getInstance() + .bufferForShapeInfo(descriptor) + .primary(); + + return true; +} - return dims; - } +////////////////////////////////////////////////////////////////////////// +// check the possibility of broadcast operation for set of arrays, if true then +// return resulting broadcasted shapeInfo +bool ShapeUtils::evalCommonBroadcastShapeInfo( + const std::vector& arrays, Nd4jLong*& resultShapeInfo, + memory::Workspace* workspace) { + if (resultShapeInfo != nullptr) + throw std::runtime_error( + "ShapeUtils::evalCommonBroadcastShapeInfo method: the input pointer on " + "shapeInfo must be empty (=nullptr) !"); + + int size = arrays.size(); + int maxRank = arrays[size - 1]->rankOf(); + + for (int i = 0; i < size - 1; ++i) { + if (arrays[i]->rankOf() > maxRank) maxRank = arrays[i]->rankOf(); + for (int j = i + 1; j < size; ++j) + if (!areShapesBroadcastable(*arrays[i], *arrays[j])) return false; + } + + Nd4jLong* tmpShapeInfo = nullptr; + ALLOCATE(tmpShapeInfo, workspace, shape::shapeInfoLength(maxRank), Nd4jLong); + memset(tmpShapeInfo, 0, shape::shapeInfoByteLength(maxRank)); + tmpShapeInfo[0] = maxRank; + + for (const auto& item : arrays) { + for (int i = -1; i >= -item->rankOf(); --i) + if (tmpShapeInfo[i + 1 + maxRank] < item->sizeAt(i)) + tmpShapeInfo[i + 1 + maxRank] = item->sizeAt(i); + } + + shape::updateStrides(tmpShapeInfo, arrays[0]->ordering()); + ArrayOptions::setDataType(tmpShapeInfo, arrays[0]->dataType()); + + ShapeDescriptor descriptor(tmpShapeInfo); + RELEASE(tmpShapeInfo, workspace); + resultShapeInfo = const_cast( + ConstantShapeHelper::getInstance().createShapeInfo(descriptor)); + + return true; +} - ////////////////////////////////////////////////////////////////////////// - // evaluate shapeInfo for resulting array from tile operation - const Nd4jLong* ShapeUtils::evalTileShapeInfo(const NDArray& arr, const std::vector& reps, sd::memory::Workspace* workspace) { - // check whether reps contains at least one zero (then throw exception) or whether all elements in reps are unities (then simply reshape or do nothing) - int repsSize = reps.size(); - Nd4jLong product = 1; - for(const auto& item : reps) - product *= item; - if(product == 0) - throw std::runtime_error("NDArray::tile method: one of the elements in reps array is zero !"); - - int rankOld = arr.rankOf(); - int diff = rankOld - repsSize; - - // evaluate new shapeInfo - Nd4jLong* newShapeInfo = nullptr; - if(diff < 0) { - ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(repsSize), Nd4jLong); - newShapeInfo[0] = repsSize; // set new rank - for(int i=1; i <= -diff; ++i) - newShapeInfo[i] = 1; // set unities to be new dimensions at left-hand side of newShapeInfo shape place - memcpy(newShapeInfo + 1 - diff, arr.shapeInfo() + 1, rankOld*sizeof(Nd4jLong)); // copy old dimensions to the right-hand side of newShapeInfo shape place - for(int i=1; i <= repsSize; ++i) - newShapeInfo[i] *= reps[i - 1]; // set new shape by multiplying old dimensions by corresponding numbers from reps - } - else { - ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(rankOld), Nd4jLong); - memcpy(newShapeInfo, arr.shapeInfo(), shape::shapeInfoByteLength(rankOld)); // copy all elements of _shapeInfo to newShapeInfo - for(int i=1; i <= repsSize; ++i) - newShapeInfo[rankOld + 1 - i] *= reps[repsSize - i]; // set new shape by multiplying old dimensions by corresponding numbers from reps - } - shape::updateStrides(newShapeInfo, arr.ordering()); - ArrayOptions::setDataType(newShapeInfo, arr.dataType()); +////////////////////////////////////////////////////////////////////////// +// return sorted vector of dimensions common (same) for two arrays, dimensions +// values corresponds to array with bigger rank for example if arr1{2,7}, +// arr2{2,5,4,7} then vector = {0,3} +std::vector ShapeUtils::getDimsWithSameShape(const NDArray& arr1, + const NDArray& arr2) { + const NDArray *min, *max; + + if (arr1.rankOf() >= arr2.rankOf()) { + max = &arr1; + min = &arr2; + } else { + max = &arr2; + min = &arr1; + } + + const int rankDiff = max->rankOf() - min->rankOf(); + + std::vector dims; + + for (int i = 0; i < min->rankOf(); ++i) + if (min->sizeAt(i) == max->sizeAt(rankDiff + i)) + dims.emplace_back(rankDiff + i); + + return dims; +} - ShapeDescriptor descriptor(newShapeInfo); - RELEASE(newShapeInfo, workspace); - return ConstantShapeHelper::getInstance().bufferForShapeInfo(descriptor).primary(); - } +////////////////////////////////////////////////////////////////////////// +// evaluate shapeInfo for resulting array from tile operation +const Nd4jLong* ShapeUtils::evalTileShapeInfo( + const NDArray& arr, const std::vector& reps, + sd::memory::Workspace* workspace) { + // check whether reps contains at least one zero (then throw exception) or + // whether all elements in reps are unities (then simply reshape or do + // nothing) + int repsSize = reps.size(); + Nd4jLong product = 1; + for (const auto& item : reps) product *= item; + if (product == 0) + throw std::runtime_error( + "NDArray::tile method: one of the elements in reps array is zero !"); + + int rankOld = arr.rankOf(); + int diff = rankOld - repsSize; + + // evaluate new shapeInfo + Nd4jLong* newShapeInfo = nullptr; + if (diff < 0) { + ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(repsSize), + Nd4jLong); + newShapeInfo[0] = repsSize; // set new rank + for (int i = 1; i <= -diff; ++i) + newShapeInfo[i] = 1; // set unities to be new dimensions at left-hand + // side of newShapeInfo shape place + memcpy( + newShapeInfo + 1 - diff, arr.shapeInfo() + 1, + rankOld * sizeof(Nd4jLong)); // copy old dimensions to the right-hand + // side of newShapeInfo shape place + for (int i = 1; i <= repsSize; ++i) + newShapeInfo[i] *= + reps[i - 1]; // set new shape by multiplying old dimensions by + // corresponding numbers from reps + } else { + ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(rankOld), + Nd4jLong); + memcpy(newShapeInfo, arr.shapeInfo(), + shape::shapeInfoByteLength( + rankOld)); // copy all elements of _shapeInfo to newShapeInfo + for (int i = 1; i <= repsSize; ++i) + newShapeInfo[rankOld + 1 - i] *= + reps[repsSize - i]; // set new shape by multiplying old dimensions by + // corresponding numbers from reps + } + shape::updateStrides(newShapeInfo, arr.ordering()); + ArrayOptions::setDataType(newShapeInfo, arr.dataType()); + + ShapeDescriptor descriptor(newShapeInfo); + RELEASE(newShapeInfo, workspace); + return ConstantShapeHelper::getInstance() + .bufferForShapeInfo(descriptor) + .primary(); +} - std::vector ShapeUtils::pullShapeFromShapeInfo(const Nd4jLong *shapeInfo) { - std::vector shape(shape::rank(shapeInfo)); - int shapeSize = shape.size(); +std::vector ShapeUtils::pullShapeFromShapeInfo( + const Nd4jLong* shapeInfo) { + std::vector shape(shape::rank(shapeInfo)); + int shapeSize = shape.size(); - for (int e = 0; e < shapeSize; e++) - shape[e] = shape::shapeOf(shapeInfo)[e]; + for (int e = 0; e < shapeSize; e++) shape[e] = shape::shapeOf(shapeInfo)[e]; - return shape; - } + return shape; +} - std::string ShapeUtils::shapeAsString(const NDArray* array) { - std::string result; +std::string ShapeUtils::shapeAsString(const NDArray* array) { + std::string result; - result.append("["); - for (int e = 0; e < array->rankOf(); e++) { - result += flatbuffers::NumToString(array->sizeAt(e)); - if (e < array->rankOf() - 1) - result.append(", "); - } - result.append("]"); + result.append("["); + for (int e = 0; e < array->rankOf(); e++) { + result += flatbuffers::NumToString(array->sizeAt(e)); + if (e < array->rankOf() - 1) result.append(", "); + } + result.append("]"); - return result; - } + return result; +} - std::string ShapeUtils::strideAsString(const NDArray* array) { - std::string result; - - auto shapeBuffer = array->shapeInfo(); //Nd4jLong* - int rank = (int)*shapeBuffer; - result.append("["); - for (int e = 0; e < rank; e++) { - if (e > 0) - result.append(","); - Nd4jLong stride = *(shapeBuffer + rank+1+e); - result += flatbuffers::NumToString(stride); - } - result.append("]"); +std::string ShapeUtils::shapeAsString(const NDArray& array) { + return shapeAsString(&array); +} - return result; - } +std::string ShapeUtils::strideAsString(const NDArray& array) { + return strideAsString(&array); +} - std::string ShapeUtils::shapeAsString(const std::vector& shape) { - std::string result; +std::string ShapeUtils::strideAsString(const NDArray* array) { + std::string result; - result.append("["); - for (int e = 0; e < shape.size(); e++) { - result += flatbuffers::NumToString(shape.at(e)); - if (e < shape.size() - 1) - result.append(", "); - } - result.append("]"); + auto shapeBuffer = array->shapeInfo(); // Nd4jLong* + int rank = (int)*shapeBuffer; + result.append("["); + for (int e = 0; e < rank; e++) { + if (e > 0) result.append(","); + Nd4jLong stride = *(shapeBuffer + rank + 1 + e); + result += flatbuffers::NumToString(stride); + } + result.append("]"); - return result; - } + return result; +} - std::string ShapeUtils::shapeAsString(const Nd4jLong* shapeInfo) { +std::string ShapeUtils::shapeAsString(const std::vector& shape) { + std::string result; - if(!shapeInfo) - throw std::runtime_error("ShapeUtils::shapeAsString method: input shapeInfo must not be nullptr !"); + result.append("["); + for (int e = 0; e < shape.size(); e++) { + result += flatbuffers::NumToString(shape.at(e)); + if (e < shape.size() - 1) result.append(", "); + } + result.append("]"); - std::string result; + return result; +} - result.append("["); - for (int e = 0; e < shapeInfo[0]; e++) { - result += flatbuffers::NumToString(shapeInfo[e+1]); - if (e < shapeInfo[0] - 1) - result.append(", "); - } - result.append("]"); +std::string ShapeUtils::shapeAsString(const Nd4jLong* shapeInfo) { + if (!shapeInfo) + throw std::runtime_error( + "ShapeUtils::shapeAsString method: input shapeInfo must not be nullptr " + "!"); - return result; - } + std::string result; - std::string ShapeUtils::shapeInfoAsString(const Nd4jLong* shapeInfo) { + result.append("["); + for (int e = 0; e < shapeInfo[0]; e++) { + result += flatbuffers::NumToString(shapeInfo[e + 1]); + if (e < shapeInfo[0] - 1) result.append(", "); + } + result.append("]"); - if(!shapeInfo) - throw std::runtime_error("ShapeUtils::shapeAsString method: input shapeInfo must not be nullptr !"); + return result; +} - std::string result; +std::string ShapeUtils::shapeInfoAsString(const Nd4jLong* shapeInfo) { + if (!shapeInfo) + throw std::runtime_error( + "ShapeUtils::shapeAsString method: input shapeInfo must not be nullptr " + "!"); - int len = shape::shapeInfoLength(shapeInfo[0]); + std::string result; - result.append("["); - for (int e = 0; e < len; e++) { - result += flatbuffers::NumToString(shapeInfo[e]); - if (e < len - 1) - result.append(", "); - } - result.append("]"); + int len = shape::shapeInfoLength(shapeInfo[0]); - return result; - } + result.append("["); + for (int e = 0; e < len; e++) { + result += flatbuffers::NumToString(shapeInfo[e]); + if (e < len - 1) result.append(", "); + } + result.append("]"); + return result; +} - std::string ShapeUtils::shapeAsString(const int rank, const Nd4jLong* shapeInfo) { - if(!shapeInfo) - throw std::runtime_error("ShapeUtils::shapeAsString method: input shapeInfo must not be nullptr !"); +std::string ShapeUtils::shapeAsString(const int rank, + const Nd4jLong* shapeInfo) { + if (!shapeInfo) + throw std::runtime_error( + "ShapeUtils::shapeAsString method: input shapeInfo must not be nullptr " + "!"); - std::string result; + std::string result; - result.append("["); - for (int e = 0; e < rank; e++) { - result += flatbuffers::NumToString(shapeInfo[e]); - if (e < rank - 1) - result.append(", "); - } - result.append("]"); + result.append("["); + for (int e = 0; e < rank; e++) { + result += flatbuffers::NumToString(shapeInfo[e]); + if (e < rank - 1) result.append(", "); + } + result.append("]"); - return result; - } + return result; +} ////////////////////////////////////////////////////////////////////////// std::vector ShapeUtils::shapeAsVector(const Nd4jLong* shapeInfo) { + if (!shapeInfo) + throw std::runtime_error( + "ShapeUtils::shapeAsVector method: input shapeInfo must not be nullptr " + "!"); - if(!shapeInfo) - throw std::runtime_error("ShapeUtils::shapeAsVector method: input shapeInfo must not be nullptr !"); + std::vector vector(shapeInfo[0]); - std::vector vector(shapeInfo[0]); + for (uint e = 0; e < shapeInfo[0]; e++) vector[e] = shapeInfo[e + 1]; - for (uint e = 0; e < shapeInfo[0]; e++) - vector[e] = shapeInfo[e + 1]; - - return vector; + return vector; } ////////////////////////////////////////////////////////////////////////// -// evaluate shapeInfo for diagonal array which is made using input arr elements as diagonal - const Nd4jLong* ShapeUtils::evalDiagShapeInfo(const Nd4jLong* shapeInfoConst, sd::memory::Workspace* workspace){ - auto shapeInfo = const_cast(shapeInfoConst); - - const auto rank = shape::rank(shapeInfo); - - Nd4jLong* outputShapeInfo = nullptr; - - if(shape::isVector(shapeInfo) || shape::isScalar(shapeInfo)) { - ALLOCATE(outputShapeInfo, workspace, shape::shapeInfoLength(2), Nd4jLong); - outputShapeInfo[0] = 2; - outputShapeInfo[1] = outputShapeInfo[2] = shape::length(shapeInfo); - } - else { - ALLOCATE(outputShapeInfo, workspace, shape::shapeInfoLength(2*rank), Nd4jLong); - outputShapeInfo[0] = 2*rank; - for(int i = 1; i <= rank; ++i) - outputShapeInfo[i] = outputShapeInfo[i + rank] = shapeInfo[i]; - } - - ShapeUtils::updateStridesAndType(outputShapeInfo, shapeInfo, shape::order(shapeInfo)); - - auto result = ConstantShapeHelper::getInstance().createShapeInfo(outputShapeInfo); - RELEASE(outputShapeInfo, workspace); - return result; - } +// evaluate shapeInfo for diagonal array which is made using input arr elements +// as diagonal +const Nd4jLong* ShapeUtils::evalDiagShapeInfo( + const Nd4jLong* shapeInfoConst, sd::memory::Workspace* workspace) { + auto shapeInfo = const_cast(shapeInfoConst); + + const auto rank = shape::rank(shapeInfo); + + Nd4jLong* outputShapeInfo = nullptr; + + if (shape::isVector(shapeInfo) || shape::isScalar(shapeInfo)) { + ALLOCATE(outputShapeInfo, workspace, shape::shapeInfoLength(2), Nd4jLong); + outputShapeInfo[0] = 2; + outputShapeInfo[1] = outputShapeInfo[2] = shape::length(shapeInfo); + } else { + ALLOCATE(outputShapeInfo, workspace, shape::shapeInfoLength(2 * rank), + Nd4jLong); + outputShapeInfo[0] = 2 * rank; + for (int i = 1; i <= rank; ++i) + outputShapeInfo[i] = outputShapeInfo[i + rank] = shapeInfo[i]; + } + + ShapeUtils::updateStridesAndType(outputShapeInfo, shapeInfo, + shape::order(shapeInfo)); + + auto result = + ConstantShapeHelper::getInstance().createShapeInfo(outputShapeInfo); + RELEASE(outputShapeInfo, workspace); + return result; +} -std::vector ShapeUtils::evalBroadcastBackwardAxis(const Nd4jLong *operandShapeInfo, const Nd4jLong *resultShapeInfo) { - // rRank >= oRank always !! - const auto oRank = shape::rank(operandShapeInfo); - const auto rRank = shape::rank(resultShapeInfo); - const auto diff = rRank - oRank; - std::vector axis; +std::vector ShapeUtils::evalBroadcastBackwardAxis( + const Nd4jLong* operandShapeInfo, const Nd4jLong* resultShapeInfo) { + // rRank >= oRank always !! + const auto oRank = shape::rank(operandShapeInfo); + const auto rRank = shape::rank(resultShapeInfo); + const auto diff = rRank - oRank; + std::vector axis; - for(int i = 0; i < rRank; ++i) - if(i < diff || shape::sizeAt(operandShapeInfo, i - diff) != shape::sizeAt(resultShapeInfo, i)) - axis.push_back(i); + for (int i = 0; i < rRank; ++i) + if (i < diff || shape::sizeAt(operandShapeInfo, i - diff) != + shape::sizeAt(resultShapeInfo, i)) + axis.push_back(i); - return axis; + return axis; } //////////////////////////////////////////////////////////////////////////////// - const Nd4jLong* ShapeUtils::matrixProductShape(const Nd4jLong* theFirstShape, const Nd4jLong* theSecondShape, bool shouldTranspondFirst, bool shouldTranspondSecond, sd::DataType dtype, sd::memory::Workspace* workspace) { - auto inA = theFirstShape; - auto inB = theSecondShape; - Nd4jLong *shape; - ALLOCATE(shape, workspace, shape::shapeInfoLength(2), Nd4jLong); - - Nd4jLong* tmpA = ShapeBuilders::copyShapeInfo(inA, true, workspace); - Nd4jLong* tmpB = ShapeBuilders::copyShapeInfo(inB, true, workspace); - - if (shouldTranspondFirst) - shape::transposeInplace(tmpA); - - if (shouldTranspondSecond) - shape::transposeInplace(tmpB); - - - if (shape::rank(tmpA) == 1 && shape::isMatrix(tmpB)) { - // special case here - shape[0] = 1; - shape[1] = tmpB[2]; - Nd4jLong *newShape = ShapeBuilders::createShapeInfo(dtype, 'f', 2, shape, workspace); - - RELEASE(shape, workspace); - RELEASE(tmpA, workspace); - RELEASE(tmpB, workspace); - - return newShape; - } else if (shape::isScalar(tmpA) && shape::isScalar(tmpB)) { - // just scalar vs scalar - shape[0] = 1; - shape[1] = 1; - } else if (shape::isMatrix(tmpA) && shape::isVector(tmpB)) { - // gemv case - if (shape::rank(tmpB) == 2) { - shape[0] = tmpA[1]; - shape[1] = tmpB[2]; - } else { - // we have new 1D shape here - auto newShape = ShapeBuilders::createVectorShapeInfo(dtype, tmpA[1], workspace); - - RELEASE(shape, workspace); - RELEASE(tmpA, workspace); - RELEASE(tmpB, workspace); - - return newShape; - } - } else if ((shape::isMatrix(tmpA) && shape::isMatrix(tmpB)) || - (shape::isVector(tmpA) && shape::isMatrix(tmpB)) || - (shape::isColumnVector(tmpA) && shape::isVector(tmpB))) { - // gemm case - shape[0] = tmpA[1]; - shape[1] = tmpB[2]; - } else if ((shape::isVector(tmpA) && shape::isScalar(tmpB)) || - (shape::isScalar(tmpA) && shape::isVector(tmpB))) { - // element-wise - shape[0] = 1; - shape[1] = (int) sd::math::nd4j_max(shape::length(tmpA), shape::length(tmpB)); - } else if (shape::isRowVector(tmpA) && shape::isRowVector(tmpB)) { - // dot case - shape[0] = 1; - shape[1] = 1; - } else if (shape::isRowVector(tmpA) && shape::isColumnVector(tmpB)) { - // dot case - shape[0] = 1; - shape[1] = 1; - } - - auto newShape = ConstantShapeHelper::getInstance().createShapeInfo(dtype, 'f', 2, shape); - - RELEASE(shape, workspace); - - RELEASE(tmpA, workspace); - RELEASE(tmpB, workspace); - return newShape; +const Nd4jLong* ShapeUtils::matrixProductShape( + const Nd4jLong* theFirstShape, const Nd4jLong* theSecondShape, + bool shouldTranspondFirst, bool shouldTranspondSecond, sd::DataType dtype, + sd::memory::Workspace* workspace) { + auto inA = theFirstShape; + auto inB = theSecondShape; + Nd4jLong* shape; + ALLOCATE(shape, workspace, shape::shapeInfoLength(2), Nd4jLong); + + Nd4jLong* tmpA = ShapeBuilders::copyShapeInfo(inA, true, workspace); + Nd4jLong* tmpB = ShapeBuilders::copyShapeInfo(inB, true, workspace); + + if (shouldTranspondFirst) shape::transposeInplace(tmpA); + + if (shouldTranspondSecond) shape::transposeInplace(tmpB); + + if (shape::rank(tmpA) == 1 && shape::isMatrix(tmpB)) { + // special case here + shape[0] = 1; + shape[1] = tmpB[2]; + Nd4jLong* newShape = + ShapeBuilders::createShapeInfo(dtype, 'f', 2, shape, workspace); + + RELEASE(shape, workspace); + RELEASE(tmpA, workspace); + RELEASE(tmpB, workspace); + + return newShape; + } else if (shape::isScalar(tmpA) && shape::isScalar(tmpB)) { + // just scalar vs scalar + shape[0] = 1; + shape[1] = 1; + } else if (shape::isMatrix(tmpA) && shape::isVector(tmpB)) { + // gemv case + if (shape::rank(tmpB) == 2) { + shape[0] = tmpA[1]; + shape[1] = tmpB[2]; + } else { + // we have new 1D shape here + auto newShape = + ShapeBuilders::createVectorShapeInfo(dtype, tmpA[1], workspace); + + RELEASE(shape, workspace); + RELEASE(tmpA, workspace); + RELEASE(tmpB, workspace); + + return newShape; } + } else if ((shape::isMatrix(tmpA) && shape::isMatrix(tmpB)) || + (shape::isVector(tmpA) && shape::isMatrix(tmpB)) || + (shape::isColumnVector(tmpA) && shape::isVector(tmpB))) { + // gemm case + shape[0] = tmpA[1]; + shape[1] = tmpB[2]; + } else if ((shape::isVector(tmpA) && shape::isScalar(tmpB)) || + (shape::isScalar(tmpA) && shape::isVector(tmpB))) { + // element-wise + shape[0] = 1; + shape[1] = (int)sd::math::nd4j_max(shape::length(tmpA), + shape::length(tmpB)); + } else if (shape::isRowVector(tmpA) && shape::isRowVector(tmpB)) { + // dot case + shape[0] = 1; + shape[1] = 1; + } else if (shape::isRowVector(tmpA) && shape::isColumnVector(tmpB)) { + // dot case + shape[0] = 1; + shape[1] = 1; + } + + auto newShape = + ConstantShapeHelper::getInstance().createShapeInfo(dtype, 'f', 2, shape); + + RELEASE(shape, workspace); + + RELEASE(tmpA, workspace); + RELEASE(tmpB, workspace); + return newShape; +} //////////////////////////////////////////////////////////////////////////////// -std::vector ShapeUtils::evalPermutFromTo(const std::vector& shapeFrom, const std::vector& shapeTo) { - auto rank = shapeFrom.size(); - if(rank != shapeTo.size()) - throw std::runtime_error("ShapeUtils::evalPermutFromTo static method: the input shapes are not suitable for mutual permutation !"); - - if (std::equal(begin(shapeFrom), end(shapeFrom), begin(shapeTo))) // if shapes are identical (permutation is unnecessary) then return empty vector - return std::vector(); - - std::vector permutation(rank, -2); // vector to be returned - std::vector shapeTo2(shapeTo); // make copy of const vector since we will change the content of shapeTo - - for(int i=0; i ShapeUtils::evalPermutFromTo( + const std::vector& shapeFrom, + const std::vector& shapeTo) { + auto rank = shapeFrom.size(); + if (rank != shapeTo.size()) + throw std::runtime_error( + "ShapeUtils::evalPermutFromTo static method: the input shapes are not " + "suitable for mutual permutation !"); + + if (std::equal(begin(shapeFrom), end(shapeFrom), + begin(shapeTo))) // if shapes are identical (permutation is + // unnecessary) then return empty vector + return std::vector(); + + std::vector permutation(rank, -2); // vector to be returned + std::vector shapeTo2( + shapeTo); // make copy of const vector since we will change the content + // of shapeTo + + for (int i = 0; i < rank; ++i) + for (int j = 0; j < rank; ++j) + if (shapeFrom[i] == shapeTo2[j]) { + permutation[j] = i; + shapeTo2[j] = -2; // mark coincidence as -2 in order to not account + // index of shapeTo twice + break; + } + + if (std::find(begin(permutation), end(permutation), -2) != + end(permutation)) // if -2 is still present in vector then permutation is + // impossible + throw std::runtime_error( + "ShapeUtils::evalPermutFromTo static method: the input shapes are not " + "suitable for mutual permutation !"); + + return permutation; } - //////////////////////////////////////////////////////////////////////////////// -std::vector ShapeUtils::composeShapeUsingDimsAndIdx(const std::vector& dimsAndIdx) { - auto size = dimsAndIdx.size(); - if(size % 2 != 0) - throw std::runtime_error("ShapeUtils::composeShapeUsingDimsAndIdx static method: the size of input vector must be even !"); - - size /= 2; - - std::vector shape(size); - int index; - - for(int i = 0; i < size; ++i) { - index = dimsAndIdx[i + size]; - if(index > size-1) - throw std::runtime_error("ShapeUtils::composeShapeUsingDimsAndIdx static method: input index is too large !"); - shape[index] = dimsAndIdx[i]; - } - - return shape; +std::vector ShapeUtils::composeShapeUsingDimsAndIdx( + const std::vector& dimsAndIdx) { + auto size = dimsAndIdx.size(); + if (size % 2 != 0) + throw std::runtime_error( + "ShapeUtils::composeShapeUsingDimsAndIdx static method: the size of " + "input vector must be even !"); + + size /= 2; + + std::vector shape(size); + int index; + + for (int i = 0; i < size; ++i) { + index = dimsAndIdx[i + size]; + if (index > size - 1) + throw std::runtime_error( + "ShapeUtils::composeShapeUsingDimsAndIdx static method: input index " + "is too large !"); + shape[index] = dimsAndIdx[i]; + } + + return shape; } - //////////////////////////////////////////////////////////////////////////////// -std::vector ShapeUtils::evalShapeForMatmul(const Nd4jLong* xShapeInfo, const Nd4jLong* yShapeInfo, const bool transX, const bool transY) { - - const auto xRank = xShapeInfo[0]; - const auto yRank = yShapeInfo[0]; - - const Nd4jLong x0Dim = transX ? xShapeInfo[xRank] : xShapeInfo[xRank-1]; - const Nd4jLong y0Dim = transY ? yShapeInfo[yRank] : yShapeInfo[yRank-1]; - const Nd4jLong x1Dim = transX ? xShapeInfo[xRank-1] : xShapeInfo[xRank]; - const Nd4jLong y1Dim = transY ? yShapeInfo[yRank-1] : yShapeInfo[yRank]; - - - if(xRank == 1 && yRank == 1) { // dot case, output is scalar - if(xShapeInfo[1] != yShapeInfo[1]) { - nd4j_printf("ShapeUtils::evalShapeForMatmul method: since input arrays are vectors they must have the same length, but got x length = %i, y length = %i !", xShapeInfo[1], yShapeInfo[1]); - throw std::invalid_argument(""); - } - return std::vector({}); +std::vector ShapeUtils::evalShapeForMatmul(const Nd4jLong* xShapeInfo, + const Nd4jLong* yShapeInfo, + const bool transX, + const bool transY) { + const auto xRank = xShapeInfo[0]; + const auto yRank = yShapeInfo[0]; + + const Nd4jLong x0Dim = transX ? xShapeInfo[xRank] : xShapeInfo[xRank - 1]; + const Nd4jLong y0Dim = transY ? yShapeInfo[yRank] : yShapeInfo[yRank - 1]; + const Nd4jLong x1Dim = transX ? xShapeInfo[xRank - 1] : xShapeInfo[xRank]; + const Nd4jLong y1Dim = transY ? yShapeInfo[yRank - 1] : yShapeInfo[yRank]; + + if (xRank == 1 && yRank == 1) { // dot case, output is scalar + if (xShapeInfo[1] != yShapeInfo[1]) { + nd4j_printf( + "ShapeUtils::evalShapeForMatmul method: since input arrays are " + "vectors they must have the same length, but got x length = %i, y " + "length = %i !", + xShapeInfo[1], yShapeInfo[1]); + throw std::invalid_argument(""); } - - - if(xRank == 1 && yRank == 2) { // vector x matrix, i.e. [4] x [4,5] = [5], output is vector - if(xShapeInfo[1] != y0Dim) { - nd4j_printf("ShapeUtils::evalShapeForMatmul method: input arrays have inconsistent shapes for vector-matrix product: x %s, y %s !", ShapeUtils::shapeAsString(xShapeInfo).c_str(), ShapeUtils::shapeAsString(yShapeInfo).c_str()); - throw std::invalid_argument(""); - } - return std::vector({y1Dim}); + return std::vector({}); + } + + if (xRank == 1 && + yRank == + 2) { // vector x matrix, i.e. [4] x [4,5] = [5], output is vector + if (xShapeInfo[1] != y0Dim) { + nd4j_printf( + "ShapeUtils::evalShapeForMatmul method: input arrays have " + "inconsistent shapes for vector-matrix product: x %s, y %s !", + ShapeUtils::shapeAsString(xShapeInfo).c_str(), + ShapeUtils::shapeAsString(yShapeInfo).c_str()); + throw std::invalid_argument(""); } - - - if(xRank == 2 && yRank == 1) { // matrix x vector , i.e. [4,5] x [5] = [4], output is vector - if(x1Dim != yShapeInfo[1]) { - nd4j_printf("ShapeUtils::evalShapeForMatmul method: input arrays have inconsistent shapes for vector-matrix product: x %s, y %s !", ShapeUtils::shapeAsString(xShapeInfo).c_str(), ShapeUtils::shapeAsString(yShapeInfo).c_str()); - throw std::invalid_argument(""); - } - return std::vector({x0Dim}); + return std::vector({y1Dim}); + } + + if (xRank == 2 && + yRank == + 1) { // matrix x vector , i.e. [4,5] x [5] = [4], output is vector + if (x1Dim != yShapeInfo[1]) { + nd4j_printf( + "ShapeUtils::evalShapeForMatmul method: input arrays have " + "inconsistent shapes for vector-matrix product: x %s, y %s !", + ShapeUtils::shapeAsString(xShapeInfo).c_str(), + ShapeUtils::shapeAsString(yShapeInfo).c_str()); + throw std::invalid_argument(""); } - - - // rest cases - usual 2Dx2D or batched mmul - if(xRank != yRank) { - nd4j_printf("ShapeUtils::evalShapeForMatmul static method: the ranks of arrays must be the same, but got xRank = %i and yRank = %i ! \n", xRank, yRank); - throw std::invalid_argument(""); + return std::vector({x0Dim}); + } + + // rest cases - usual 2Dx2D or batched mmul + if (xRank != yRank) { + nd4j_printf( + "ShapeUtils::evalShapeForMatmul static method: the ranks of arrays " + "must be the same, but got xRank = %i and yRank = %i ! \n", + xRank, yRank); + throw std::invalid_argument(""); + } + + if (x1Dim != y0Dim) { + nd4j_printf( + "ShapeUtils::evalShapeForMatmul static method: input shapes are " + "inconsistent: xDim %i != yDim %i \n", + x1Dim, y0Dim); + throw std::invalid_argument(""); + } + + for (int i = 0; i < xRank - 2; ++i) + if (xShapeInfo[i + 1] != yShapeInfo[i + 1]) { + nd4j_printf( + "ShapeUtils::evalShapeForMatmul static method: input shapes are " + "inconsistent: xShape = %s, yShape = %s ! \n", + ShapeUtils::shapeAsString(xShapeInfo).c_str(), + ShapeUtils::shapeAsString(yShapeInfo).c_str()); + throw std::invalid_argument(""); } - if(x1Dim != y0Dim) { - nd4j_printf("ShapeUtils::evalShapeForMatmul static method: input shapes are inconsistent: xDim %i != yDim %i \n", x1Dim, y0Dim); - throw std::invalid_argument(""); - } - - for(int i = 0; i < xRank - 2; ++i) - if(xShapeInfo[i+1] != yShapeInfo[i+1]) { - nd4j_printf("ShapeUtils::evalShapeForMatmul static method: input shapes are inconsistent: xShape = %s, yShape = %s ! \n", ShapeUtils::shapeAsString(xShapeInfo).c_str(), ShapeUtils::shapeAsString(yShapeInfo).c_str()); - throw std::invalid_argument(""); - } - - std::vector cShape(xRank); + std::vector cShape(xRank); - // copy batch part of shape (if present) - for(int i = 0; i < xRank - 2; ++i) - cShape[i] = xShapeInfo[i+1]; - // copy rest part of shape (two dims: multiplication part) - cShape[xRank-2] = x0Dim; - cShape[xRank-1] = y1Dim; + // copy batch part of shape (if present) + for (int i = 0; i < xRank - 2; ++i) cShape[i] = xShapeInfo[i + 1]; + // copy rest part of shape (two dims: multiplication part) + cShape[xRank - 2] = x0Dim; + cShape[xRank - 1] = y1Dim; - return cShape; + return cShape; } //////////////////////////////////////////////////////////////////////////////// -Nd4jLong ShapeUtils::getNumOfSubArrs(const Nd4jLong* shapeInfo, const std::vector& dimsToExclude) { +Nd4jLong ShapeUtils::getNumOfSubArrs(const Nd4jLong* shapeInfo, + const std::vector& dimsToExclude) { + Nd4jLong numOfSubArrs = 1; - Nd4jLong numOfSubArrs = 1; - - if(dimsToExclude.size() == shape::rank(shapeInfo) || dimsToExclude.size() == 0) // means there is only one sub-array and it coincides with whole array - return numOfSubArrs; + if (dimsToExclude.size() == shape::rank(shapeInfo) || + dimsToExclude.size() == 0) // means there is only one sub-array and it + // coincides with whole array + return numOfSubArrs; - for(const auto& dim : dimsToExclude) - numOfSubArrs *= shapeInfo[dim + 1]; + for (const auto& dim : dimsToExclude) numOfSubArrs *= shapeInfo[dim + 1]; - return numOfSubArrs; + return numOfSubArrs; } //////////////////////////////////////////////////////////////////////////////// -std::vector ShapeUtils::evalDimsWithoutUnities(const Nd4jLong* shapeInfo) { +std::vector ShapeUtils::evalDimsWithoutUnities( + const Nd4jLong* shapeInfo) { + std::vector result; + for (int i = 1; i <= shapeInfo[0]; ++i) + if (shapeInfo[i] != 1) result.push_back(shapeInfo[i]); - std::vector result; - for(int i = 1; i <= shapeInfo[0]; ++i) - if(shapeInfo[i] != 1) - result.push_back(shapeInfo[i]); - - return result; + return result; } //////////////////////////////////////////////////////////////////////////////// -void ShapeUtils::updateStridesAndType(Nd4jLong* dest, const Nd4jLong* source, const char order) { - - shape::updateStrides(dest, order); - ArrayOptions::copyDataType(dest, source); +void ShapeUtils::updateStridesAndType(Nd4jLong* dest, const Nd4jLong* source, + const char order) { + shape::updateStrides(dest, order); + ArrayOptions::copyDataType(dest, source); } //////////////////////////////////////////////////////////////////////////////// -void ShapeUtils::updateStridesAndType(Nd4jLong* dest, const DataType dtype, const char order) { - - shape::updateStrides(dest, order); - ArrayOptions::setDataType(dest, dtype); +void ShapeUtils::updateStridesAndType(Nd4jLong* dest, const DataType dtype, + const char order) { + shape::updateStrides(dest, order); + ArrayOptions::setDataType(dest, dtype); } //////////////////////////////////////////////////////////////////////////////// -std::vector ShapeUtils::tadAxesForSimpleBroadcast(const NDArray& max, const NDArray& min) { - - const int maxRank = max.rankOf(); - const int minRank = min.rankOf(); - const int diff = maxRank - minRank; - - Nd4jLong numOfMinTads(1), numOfMaxTads(1); - std::vector maxTadDims; - - for(int i = 0; i < minRank; ++i) { - if(min.sizeAt(i) == max.sizeAt(diff + i)) - maxTadDims.push_back(diff + i); - else { - numOfMinTads *= min.sizeAt(i); - numOfMaxTads *= max.sizeAt(i); - } +std::vector ShapeUtils::tadAxesForSimpleBroadcast(const NDArray& max, + const NDArray& min) { + const int maxRank = max.rankOf(); + const int minRank = min.rankOf(); + const int diff = maxRank - minRank; + + Nd4jLong numOfMinTads(1), numOfMaxTads(1); + std::vector maxTadDims; + + for (int i = 0; i < minRank; ++i) { + if (min.sizeAt(i) == max.sizeAt(diff + i)) + maxTadDims.push_back(diff + i); + else { + numOfMinTads *= min.sizeAt(i); + numOfMaxTads *= max.sizeAt(i); } + } - if(min.lengthOf() > max.lengthOf()) { // in this case tad is max array - for(int i = 0; i < diff; ++i) - numOfMaxTads *= max.sizeAt(i); + if (min.lengthOf() > max.lengthOf()) { // in this case tad is max array + for (int i = 0; i < diff; ++i) numOfMaxTads *= max.sizeAt(i); - return numOfMaxTads == 1 ? maxTadDims : std::vector(); - } + return numOfMaxTads == 1 ? maxTadDims : std::vector(); + } - return numOfMinTads == 1 ? maxTadDims : std::vector(); + return numOfMinTads == 1 ? maxTadDims : std::vector(); } -void ShapeUtils::copyCertainStridesFromShapeInfo(const Nd4jLong* inShapeInfo, const int nRank, const int dimsSize, const int* dims, Nd4jLong* outStrides) { - - int yRank = shape::rank(inShapeInfo); - auto yOrigStride = shape::stride(inShapeInfo); - - if (yRank == nRank) { - for (int i = 0; i < yRank; ++i) { - // x[2,3,4] * y[2,1,4] = z[2,3,4] - outStrides[i] = (1 == shape::sizeAt(inShapeInfo, i)) ? 0 : yOrigStride[i]; - } +void ShapeUtils::copyCertainStridesFromShapeInfo(const Nd4jLong* inShapeInfo, + const int nRank, + const int dimsSize, + const int* dims, + Nd4jLong* outStrides) { + int yRank = shape::rank(inShapeInfo); + auto yOrigStride = shape::stride(inShapeInfo); + + if (yRank == nRank) { + for (int i = 0; i < yRank; ++i) { + // x[2,3,4] * y[2,1,4] = z[2,3,4] + outStrides[i] = (1 == shape::sizeAt(inShapeInfo, i)) ? 0 : yOrigStride[i]; } - else { + } else { + auto dimEx = sd::ShapeUtils::evalDimsToExclude(nRank, dimsSize, dims); - auto dimEx = sd::ShapeUtils::evalDimsToExclude(nRank, dimsSize, dims); - - for (int i = 0, it = 0; i < nRank; ++i) { - auto nCount = std::count(dimEx.cbegin(), dimEx.cend(), i); - outStrides[i] = (0 == nCount) ? yOrigStride[it++] : 0; - if (it == yRank) - break; - } + for (int i = 0, it = 0; i < nRank; ++i) { + auto nCount = std::count(dimEx.cbegin(), dimEx.cend(), i); + outStrides[i] = (0 == nCount) ? yOrigStride[it++] : 0; + if (it == yRank) break; } + } } -bool ShapeUtils::areShapesEqual(const Nd4jLong* shapeInfo, const std::vector& shapeOnly) { - - if(shape::rank(shapeInfo) != shapeOnly.size()) - return false; +bool ShapeUtils::areShapesEqual(const Nd4jLong* shapeInfo, + const std::vector& shapeOnly) { + if (shape::rank(shapeInfo) != shapeOnly.size()) return false; - for(uint i = 0; i < shape::rank(shapeInfo); ++i) - if(shape::shapeOf(shapeInfo)[i] != shapeOnly[i]) - return false; + for (uint i = 0; i < shape::rank(shapeInfo); ++i) + if (shape::shapeOf(shapeInfo)[i] != shapeOnly[i]) return false; - return true; + return true; } //////////////////////////////////////////////////////////////////////////////// /* -bool ShapeUtils::isSubArrayCase(const NDArray& arr1, const NDArray& arr2, std::vector& sameDims) { +bool ShapeUtils::isSubArrayCase(const NDArray& arr1, const NDArray& arr2, +std::vector& sameDims) { if(!sameDims.empty()) sameDims.clear(); @@ -1079,7 +1239,8 @@ bool ShapeUtils::isSubArrayCase(const NDArray& arr1, const NDArray& arr2, std::v int numUnitiesInMin = 0; - for (int iMax = -1, iMin = -1; iMax >= -max->rankOf() && iMin >= -min->rankOf(); ) { + for (int iMax = -1, iMin = -1; iMax >= -max->rankOf() && iMin >= +-min->rankOf(); ) { if(max->sizeAt(iMax) == 1) { // ignore unities in shape --iMax; @@ -1104,6 +1265,4 @@ bool ShapeUtils::isSubArrayCase(const NDArray& arr1, const NDArray& arr2, std::v } */ -} - - +} // namespace sd diff --git a/libnd4j/include/helpers/impl/SimpleReadWriteLock.cpp b/libnd4j/include/helpers/impl/SimpleReadWriteLock.cpp index 52682b925fee..b6af27fbd813 100644 --- a/libnd4j/include/helpers/impl/SimpleReadWriteLock.cpp +++ b/libnd4j/include/helpers/impl/SimpleReadWriteLock.cpp @@ -20,51 +20,47 @@ #include - namespace sd { - SimpleReadWriteLock::SimpleReadWriteLock(const SimpleReadWriteLock& other) { - _read_locks.store(other._read_locks.load()); - _write_locks.store(other._write_locks.load()); - } +SimpleReadWriteLock::SimpleReadWriteLock(const SimpleReadWriteLock& other) { + _read_locks.store(other._read_locks.load()); + _write_locks.store(other._write_locks.load()); +} - SimpleReadWriteLock::SimpleReadWriteLock(){ - _read_locks.store(0); - _write_locks.store(0); - } +SimpleReadWriteLock::SimpleReadWriteLock() { + _read_locks.store(0); + _write_locks.store(0); +} - void SimpleReadWriteLock::lockRead() { - _mutex.lock(); - _read_locks++; - while(_write_locks.load() > 0) { - // just loop - } - _mutex.unlock(); - } +void SimpleReadWriteLock::lockRead() { + _mutex.lock(); + _read_locks++; + while (_write_locks.load() > 0) { + // just loop + } + _mutex.unlock(); +} - void SimpleReadWriteLock::unlockRead() { - _read_locks--; - } +void SimpleReadWriteLock::unlockRead() { _read_locks--; } - // write lock - void SimpleReadWriteLock::lockWrite() { - _mutex.lock(); - _write_locks++; - while (_read_locks.load() > 0) { - // just loop - } - _mutex.unlock(); - } +// write lock +void SimpleReadWriteLock::lockWrite() { + _mutex.lock(); + _write_locks++; + while (_read_locks.load() > 0) { + // just loop + } + _mutex.unlock(); +} - void SimpleReadWriteLock::unlockWrite() { - _write_locks--; - } +void SimpleReadWriteLock::unlockWrite() { _write_locks--; } - SimpleReadWriteLock& SimpleReadWriteLock::operator= ( const SimpleReadWriteLock &other) { - if (this == &other) return *this; +SimpleReadWriteLock& SimpleReadWriteLock::operator=( + const SimpleReadWriteLock& other) { + if (this == &other) return *this; - this->_write_locks.store(other._write_locks.load()); - this->_read_locks.store(other._read_locks.load()); + this->_write_locks.store(other._write_locks.load()); + this->_read_locks.store(other._read_locks.load()); - return *this; - } -} \ No newline at end of file + return *this; +} +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/helpers/impl/Sqrtm.cpp b/libnd4j/include/helpers/impl/Sqrtm.cpp index 5fe45656f9b8..087097a59c8e 100644 --- a/libnd4j/include/helpers/impl/Sqrtm.cpp +++ b/libnd4j/include/helpers/impl/Sqrtm.cpp @@ -265,10 +265,10 @@ void Sqrtm::calc(const NDArray& in, NDArray& out) { MmulHelper::mmul(&schur._U, &temp, &out); } -template class ND4J_EXPORT Sqrtm; -template class ND4J_EXPORT Sqrtm; -template class ND4J_EXPORT Sqrtm; -template class ND4J_EXPORT Sqrtm; +template class SD_EXPORT Sqrtm; +template class SD_EXPORT Sqrtm; +template class SD_EXPORT Sqrtm; +template class SD_EXPORT Sqrtm; } diff --git a/libnd4j/include/helpers/impl/StringUtils.cpp b/libnd4j/include/helpers/impl/StringUtils.cpp index 757def763a14..05d20c364f1d 100644 --- a/libnd4j/include/helpers/impl/StringUtils.cpp +++ b/libnd4j/include/helpers/impl/StringUtils.cpp @@ -26,15 +26,15 @@ #include namespace sd { - static FORCEINLINE bool match(const uint8_t *haystack, const uint8_t *needle, uint64_t length) { - for (int e = 0; e < length; e++) - if (haystack[e] != needle[e]) - return false; +static FORCEINLINE bool match(const uint8_t* haystack, const uint8_t* needle, + uint64_t length) { + for (int e = 0; e < length; e++) + if (haystack[e] != needle[e]) return false; - return true; - } + return true; +} - template +template std::string StringUtils::bitsToString(T value) { return std::bitset(value).to_string(); } @@ -42,131 +42,130 @@ namespace sd { template std::string StringUtils::bitsToString(int value); template std::string StringUtils::bitsToString(uint32_t value); template std::string StringUtils::bitsToString(Nd4jLong value); -template std::string StringUtils::bitsToString(uint64_t value); +template std::string StringUtils::bitsToString(uint64_t value);uint64_t StringUtils::countSubarrays(const void* vhaystack, + uint64_t haystackLength, + const void* vneedle, + uint64_t needleLength) { + auto haystack = reinterpret_cast(vhaystack); + auto needle = reinterpret_cast(vneedle); + uint64_t number = 0; - uint64_t StringUtils::countSubarrays(const void *vhaystack, uint64_t haystackLength, const void *vneedle, uint64_t needleLength) { - auto haystack = reinterpret_cast(vhaystack); - auto needle = reinterpret_cast(vneedle); + for (uint64_t e = 0; e < haystackLength - needleLength; e++) { + if (match(&haystack[e], needle, needleLength)) number++; + } - uint64_t number = 0; + return number; +} - for (uint64_t e = 0; e < haystackLength - needleLength; e++) { - if (match(&haystack[e], needle, needleLength)) - number++; - } +uint64_t StringUtils::byteLength(const NDArray& array) { + if (!array.isS()) + throw sd::datatype_exception::build( + "StringUtils::byteLength expects one of String types;", + array.dataType()); - return number; - } + auto buffer = array.bufferAsT(); + return buffer[array.lengthOf()]; +} +std::vector StringUtils::split(const std::string& haystack, + const std::string& delimiter) { + std::vector output; - uint64_t StringUtils::byteLength(const NDArray &array) { - if (!array.isS()) - throw sd::datatype_exception::build("StringUtils::byteLength expects one of String types;", array.dataType()); + std::string::size_type prev_pos = 0, pos = 0; - auto buffer = array.bufferAsT(); - return buffer[array.lengthOf()]; - } + // iterating through the haystack till the end + while ((pos = haystack.find(delimiter, pos)) != std::string::npos) { + output.emplace_back(haystack.substr(prev_pos, pos - prev_pos)); + prev_pos = ++pos; + } - std::vector StringUtils::split(const std::string &haystack, const std::string &delimiter) { - std::vector output; + output.emplace_back(haystack.substr(prev_pos, pos - prev_pos)); // Last word - std::string::size_type prev_pos = 0, pos = 0; + return output; +} - // iterating through the haystack till the end - while((pos = haystack.find(delimiter, pos)) != std::string::npos) { - output.emplace_back(haystack.substr(prev_pos, pos-prev_pos)); - prev_pos = ++pos; - } +bool StringUtils::u8StringToU16String(const std::string& u8, + std::u16string& u16) { + if (u8.empty()) return false; - output.emplace_back(haystack.substr(prev_pos, pos - prev_pos)); // Last word + u16.resize(unicode::offsetUtf8StringInUtf16(u8.data(), u8.size()) / + sizeof(char16_t)); + if (u8.size() == u16.size()) + u16.assign(u8.begin(), u8.end()); + else + return unicode::utf8to16(u8.data(), &u16[0], u8.size()); - return output; - } - - bool StringUtils::u8StringToU16String(const std::string& u8, std::u16string& u16) { - - if (u8.empty()) - return false; - - u16.resize(unicode::offsetUtf8StringInUtf16(u8.data(), u8.size()) / sizeof(char16_t)); - if (u8.size() == u16.size()) - u16.assign(u8.begin(), u8.end()); - else - return unicode::utf8to16(u8.data(), &u16[0], u8.size()); - - return true; - } + return true; +} - bool StringUtils::u8StringToU32String(const std::string& u8, std::u32string& u32) { +bool StringUtils::u8StringToU32String(const std::string& u8, + std::u32string& u32) { + if (u8.empty()) return false; - if (u8.empty()) - return false; + u32.resize(unicode::offsetUtf8StringInUtf32(u8.data(), u8.size()) / + sizeof(char32_t)); + if (u8.size() == u32.size()) + u32.assign(u8.begin(), u8.end()); + else + return unicode::utf8to32(u8.data(), &u32[0], u8.size()); - u32.resize( unicode::offsetUtf8StringInUtf32(u8.data(), u8.size()) / sizeof(char32_t) ); - if (u8.size() == u32.size()) - u32.assign(u8.begin(), u8.end()); - else - return unicode::utf8to32(u8.data(), &u32[0], u8.size()); - - return true; - } + return true; +} - bool StringUtils::u16StringToU32String(const std::u16string& u16, std::u32string& u32) { - - if (u16.empty()) - return false; - - u32.resize(unicode::offsetUtf16StringInUtf32(u16.data(), u16.size()) / sizeof(char32_t)); - if (u16.size() == u32.size()) - u32.assign(u16.begin(), u16.end()); - else - return unicode::utf16to32(u16.data(), &u32[0], u16.size()); - - return true; - } - - bool StringUtils::u16StringToU8String(const std::u16string& u16, std::string& u8) { - - if (u16.empty()) - return false; - - u8.resize(unicode::offsetUtf16StringInUtf8(u16.data(), u16.size())); - if (u16.size() == u8.size()) - u8.assign(u16.begin(), u16.end()); - else - return unicode::utf16to8(u16.data(), &u8[0], u16.size()); - - return true; - } - - bool StringUtils::u32StringToU16String(const std::u32string& u32, std::u16string& u16) { - - if (u32.empty()) - return false; - - u16.resize(unicode::offsetUtf32StringInUtf16(u32.data(), u32.size()) / sizeof(char16_t)); - if (u32.size() == u16.size()) - u16.assign(u32.begin(), u32.end()); - else - return unicode::utf32to16(u32.data(), &u16[0], u32.size()); - - return true; - } - - bool StringUtils::u32StringToU8String(const std::u32string& u32, std::string& u8) { - - if (u32.empty()) - return false; - - u8.resize(unicode::offsetUtf32StringInUtf8(u32.data(), u32.size())); - if (u32.size() == u8.size()) - u8.assign(u32.begin(), u32.end()); - else - return unicode::utf32to8(u32.data(), &u8[0], u32.size()); - - return true; - } +bool StringUtils::u16StringToU32String(const std::u16string& u16, + std::u32string& u32) { + if (u16.empty()) return false; + + u32.resize(unicode::offsetUtf16StringInUtf32(u16.data(), u16.size()) / + sizeof(char32_t)); + if (u16.size() == u32.size()) + u32.assign(u16.begin(), u16.end()); + else + return unicode::utf16to32(u16.data(), &u32[0], u16.size()); + + return true; +} + +bool StringUtils::u16StringToU8String(const std::u16string& u16, + std::string& u8) { + if (u16.empty()) return false; + + u8.resize(unicode::offsetUtf16StringInUtf8(u16.data(), u16.size())); + if (u16.size() == u8.size()) + u8.assign(u16.begin(), u16.end()); + else + return unicode::utf16to8(u16.data(), &u8[0], u16.size()); + + return true; +} + +bool StringUtils::u32StringToU16String(const std::u32string& u32, + std::u16string& u16) { + if (u32.empty()) return false; + + u16.resize(unicode::offsetUtf32StringInUtf16(u32.data(), u32.size()) / + sizeof(char16_t)); + if (u32.size() == u16.size()) + u16.assign(u32.begin(), u32.end()); + else + return unicode::utf32to16(u32.data(), &u16[0], u32.size()); + + return true; +} + +bool StringUtils::u32StringToU8String(const std::u32string& u32, + std::string& u8) { + if (u32.empty()) return false; + + u8.resize(unicode::offsetUtf32StringInUtf8(u32.data(), u32.size())); + if (u32.size() == u8.size()) + u8.assign(u32.begin(), u32.end()); + else + return unicode::utf32to8(u32.data(), &u8[0], u32.size()); + + return true; +} template std::string StringUtils::vectorToString(const std::vector &vec) { @@ -181,4 +180,4 @@ template std::string StringUtils::bitsToString(uint64_t value); template std::string StringUtils::vectorToString(const std::vector &vec); template std::string StringUtils::vectorToString(const std::vector &vec); template std::string StringUtils::vectorToString(const std::vector &vec); -} +} // namespace sd diff --git a/libnd4j/include/helpers/impl/TAD.cpp b/libnd4j/include/helpers/impl/TAD.cpp index 5d31827da27d..79261707954a 100644 --- a/libnd4j/include/helpers/impl/TAD.cpp +++ b/libnd4j/include/helpers/impl/TAD.cpp @@ -18,10 +18,7 @@ // @author Adam Gibson // - #include #include -namespace shape { - -} \ No newline at end of file +namespace shape {} \ No newline at end of file diff --git a/libnd4j/include/helpers/impl/helper_hash.cpp b/libnd4j/include/helpers/impl/helper_hash.cpp index 4fde919cddc5..f02707dac310 100644 --- a/libnd4j/include/helpers/impl/helper_hash.cpp +++ b/libnd4j/include/helpers/impl/helper_hash.cpp @@ -22,46 +22,46 @@ #include namespace sd { - namespace ops { +namespace ops { - HashHelper& HashHelper::getInstance() { - static HashHelper instance; - return instance; - } +HashHelper& HashHelper::getInstance() { + static HashHelper instance; - Nd4jLong HashHelper::getLongHash(std::string& str) { - _locker.lock(); - if (!_isInit) { - nd4j_verbose("Building HashUtil table\n",""); + return instance; +} - Nd4jLong h = 0x544B2FBACAAF1684L; - for (int i = 0; i < 256; i++) { - for (int j = 0; j < 31; j++) { - h = (((unsigned long long) h) >> 7) ^ h; - h = (h << 11) ^ h; - h = (((unsigned long long) h) >> 10) ^ h; - } - _byteTable[i] = h; - } +Nd4jLong HashHelper::getLongHash(const std::string& str) { + _locker.lock(); + if (!_isInit) { + nd4j_verbose("Building HashUtil table\n", ""); + Nd4jLong h = 0x544B2FBACAAF1684L; + for (int i = 0; i < 256; i++) { + for (int j = 0; j < 31; j++) { + h = (((unsigned long long)h) >> 7) ^ h; + h = (h << 11) ^ h; + h = (((unsigned long long)h) >> 10) ^ h; + } + _byteTable[i] = h; + } - _isInit = true; - } + _isInit = true; + } - _locker.unlock(); + _locker.unlock(); - Nd4jLong h = HSTART; - Nd4jLong hmult = HMULT; - Nd4jLong len = str.size(); - for (int i = 0; i < len; i++) { - char ch = str.at(i); - auto uch = (unsigned char) ch; - h = (h * hmult) ^ _byteTable[ch & 0xff]; - h = (h * hmult) ^ _byteTable[(uch >> 8) & 0xff]; - } + Nd4jLong h = HSTART; + Nd4jLong hmult = HMULT; + Nd4jLong len = str.size(); + for (int i = 0; i < len; i++) { + char ch = str.at(i); + auto uch = (unsigned char)ch; + h = (h * hmult) ^ _byteTable[ch & 0xff]; + h = (h * hmult) ^ _byteTable[(uch >> 8) & 0xff]; + } - return h; - } - } + return h; } +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/helpers/impl/hhSequence.cpp b/libnd4j/include/helpers/impl/hhSequence.cpp index dc038dfc8de4..6072e492324f 100644 --- a/libnd4j/include/helpers/impl/hhSequence.cpp +++ b/libnd4j/include/helpers/impl/hhSequence.cpp @@ -18,6 +18,7 @@ // Created by Yurii Shyrma on 02.01.2018 // +#include #include #include @@ -25,22 +26,20 @@ namespace sd { namespace ops { namespace helpers { - ////////////////////////////////////////////////////////////////////////// -HHsequence::HHsequence(const NDArray& vectors, const NDArray& coeffs, const char type): _vectors(vectors), _coeffs(coeffs) { - - _diagSize = sd::math::nd4j_min(_vectors.sizeAt(0), _vectors.sizeAt(1)); - _shift = 0; - _type = type; +HHsequence::HHsequence(const NDArray& vectors, const NDArray& coeffs, const char type) + : _vectors(vectors), _coeffs(coeffs) { + _diagSize = sd::math::nd4j_min(_vectors.sizeAt(0), _vectors.sizeAt(1)); + _shift = 0; + _type = type; } ////////////////////////////////////////////////////////////////////////// template void HHsequence::mulLeft_(NDArray& matrix) { - - const int rows = _vectors.sizeAt(0); - const int cols = _vectors.sizeAt(1); - const int inRows = matrix.sizeAt(0); + const int rows = _vectors.sizeAt(0); + const int cols = _vectors.sizeAt(1); + const int inRows = matrix.sizeAt(0); for(int i = _diagSize - 1; i >= 0; --i) { @@ -57,50 +56,51 @@ void HHsequence::mulLeft_(NDArray& matrix) { } } - ////////////////////////////////////////////////////////////////////////// NDArray HHsequence::getTail(const int idx) const { + int first = idx + 1 + _shift; - - int first = idx + 1 + _shift; - - if(_type == 'u') - return _vectors({first, -1, idx, idx+1}, true); - else - return _vectors({idx, idx+1, first, -1}, true); + if (_type == 'u') + return _vectors({first, -1, idx, idx + 1}, true); + else + return _vectors({idx, idx + 1, first, -1}, true); } ////////////////////////////////////////////////////////////////////////// template void HHsequence::applyTo_(NDArray& dest) { - int size = _type == 'u' ? _vectors.sizeAt(0) : _vectors.sizeAt(1); + int size = _type == 'u' ? _vectors.sizeAt(0) : _vectors.sizeAt(1); - if(dest.rankOf() != 2 || (dest.sizeAt(0) != size && dest.sizeAt(1) != size)) - dest = NDArray(dest.ordering(), {size, size}, dest.dataType(), dest.getContext()); - dest.setIdentity(); + if (dest.rankOf() != 2 || (dest.sizeAt(0) != size && dest.sizeAt(1) != size)) + dest = NDArray(dest.ordering(), {size, size}, + dest.dataType(), dest.getContext()); + dest.setIdentity(); - for(int k = _diagSize - 1; k >= 0; --k) { + for (int k = _diagSize - 1; k >= 0; --k) { + int curNum = size - k - _shift; + if (curNum < 1 || (k + 1 + _shift) >= size) continue; + auto block = dest({dest.sizeAt(0) - curNum, dest.sizeAt(0), + dest.sizeAt(1) - curNum, dest.sizeAt(1)}, + true); - int curNum = size - k - _shift; - if(curNum < 1 || (k + 1 + _shift) >= size ) - continue; - auto block = dest({dest.sizeAt(0)-curNum,dest.sizeAt(0), dest.sizeAt(1)-curNum,dest.sizeAt(1)}, true); - Householder::mulLeft(block, getTail(k), _coeffs.t(k)); - } + Householder::mulLeft(block, getTail(k), _coeffs.t(k)); + } } ////////////////////////////////////////////////////////////////////////// void HHsequence::applyTo(NDArray& dest) { - auto xType = _coeffs.dataType(); - BUILD_SINGLE_SELECTOR(xType, applyTo_, (dest), FLOAT_TYPES); + auto xType = _coeffs.dataType(); + + BUILD_SINGLE_SELECTOR(xType, applyTo_, (dest), FLOAT_TYPES); } -////////////////////////////////////////////////////////////////////////// +///////////////////////////////////////////////////////////////////////// void HHsequence::mulLeft(NDArray& matrix) { - auto xType = _coeffs.dataType(); - BUILD_SINGLE_SELECTOR(xType, mulLeft_, (matrix), FLOAT_TYPES); + auto xType = _coeffs.dataType(); + + BUILD_SINGLE_SELECTOR(xType, mulLeft_, (matrix), FLOAT_TYPES); } BUILD_SINGLE_TEMPLATE(template void HHsequence::applyTo_, (sd::NDArray &dest), FLOAT_TYPES); diff --git a/libnd4j/include/helpers/impl/householder.cpp b/libnd4j/include/helpers/impl/householder.cpp index e9572f9f666c..a149021fd965 100644 --- a/libnd4j/include/helpers/impl/householder.cpp +++ b/libnd4j/include/helpers/impl/householder.cpp @@ -202,10 +202,10 @@ void Householder::mulRight(NDArray& matrix, const NDArray& tail, const T coef } -template class ND4J_EXPORT Householder; -template class ND4J_EXPORT Householder; -template class ND4J_EXPORT Householder; -template class ND4J_EXPORT Householder; +template class SD_EXPORT Householder; +template class SD_EXPORT Householder; +template class SD_EXPORT Householder; +template class SD_EXPORT Householder; diff --git a/libnd4j/include/helpers/impl/jacobiSVD.cpp b/libnd4j/include/helpers/impl/jacobiSVD.cpp index 7fbf183b2b48..53e87d602257 100644 --- a/libnd4j/include/helpers/impl/jacobiSVD.cpp +++ b/libnd4j/include/helpers/impl/jacobiSVD.cpp @@ -21,190 +21,195 @@ #include #include #include +#include namespace sd { namespace ops { namespace helpers { - ////////////////////////////////////////////////////////////////////////// template -JacobiSVD::JacobiSVD(const NDArray& matrix, const bool calcU, const bool calcV, const bool fullUV) { +JacobiSVD::JacobiSVD(const NDArray& matrix, const bool calcU, + const bool calcV, const bool fullUV) { - if(matrix.rankOf() != 2 || matrix.isScalar()) - throw std::runtime_error("ops::helpers::JacobiSVD constructor: input array must be 2D matrix !"); + if(matrix.rankOf() != 2 || matrix.isScalar()) + throw std::runtime_error("ops::helpers::JacobiSVD constructor: input array must be 2D matrix !"); - _rows = static_cast(matrix.sizeAt(0)); - _cols = static_cast(matrix.sizeAt(1)); - _diagSize = math::nd4j_min(_rows, _cols); + _rows = static_cast(matrix.sizeAt(0)); + _cols = static_cast(matrix.sizeAt(1)); + _diagSize = math::nd4j_min(_rows, _cols); - _calcU = calcU; - _calcV = calcV; - _fullUV = fullUV; + _calcU = calcU; + _calcV = calcV; + _fullUV = fullUV; - _s = NDArray(matrix.ordering(), {_diagSize, 1}, matrix.dataType(), matrix.getContext()); + _s = NDArray(matrix.ordering(), {_diagSize, 1}, matrix.dataType(), matrix.getContext()); - if(_calcU) { - if(_fullUV) - _u = NDArray(matrix.ordering(), {_rows, _rows}, matrix.dataType(), matrix.getContext()); - else - _u = NDArray(matrix.ordering(), {_rows, _diagSize}, matrix.dataType(), matrix.getContext()); - } + if(_calcU) { + if(_fullUV) + _u = NDArray(matrix.ordering(), {_rows, _rows}, matrix.dataType(), matrix.getContext()); else - _u = NDArray(matrix.ordering(), {_rows, 1}, matrix.dataType(), matrix.getContext()); - - if(_calcV) { - if(_fullUV) - _v = NDArray(matrix.ordering(), {_cols, _cols}, matrix.dataType(), matrix.getContext()); - else - _v = NDArray(matrix.ordering(), {_cols, _diagSize}, matrix.dataType(), matrix.getContext()); - } + _u = NDArray(matrix.ordering(), {_rows, _diagSize}, matrix.dataType(), matrix.getContext()); + } + else + _u = NDArray(matrix.ordering(), {_rows, 1}, matrix.dataType(), matrix.getContext()); + + if(_calcV) { + if(_fullUV) + _v = NDArray(matrix.ordering(), {_cols, _cols}, matrix.dataType(), matrix.getContext()); else - _v = NDArray(matrix.ordering(), {_cols, 1}, matrix.dataType(), matrix.getContext()); + _v = NDArray(matrix.ordering(), {_cols, _diagSize}, matrix.dataType(), matrix.getContext()); + } + else + _v = NDArray(matrix.ordering(), {_cols, 1}, matrix.dataType(), matrix.getContext()); - _m = NDArray(matrix.ordering(), {_diagSize, _diagSize}, matrix.dataType(), matrix.getContext()); + _m = NDArray(matrix.ordering(), {_diagSize, _diagSize}, matrix.dataType(), matrix.getContext()); - evalData(matrix); + evalData(matrix); } ////////////////////////////////////////////////////////////////////////// template -void JacobiSVD::mulRotationOnLeft(const int i, const int j, NDArray& block, const NDArray& rotation) { - - if(i < j) { - - if(j+1 > block.sizeAt(0)) - throw std::runtime_error("ops::helpers::JacobiSVD mulRotationOnLeft: second arguments is out of array row range !"); - - auto temp = block({i,j+1,j-i, 0,0,0}, true, true); - temp.assign(mmul(rotation, temp)); - - // auto pTemp = block({i,j+1,j-i, 0,0,0}, true, true); - // auto temp = pTemp.dup(); - // pTemp.assign(mmul(rotation, temp)); - } - else { - - if(j+1 > block.sizeAt(0) || i+1 > block.sizeAt(0)) - throw std::runtime_error("ops::helpers::JacobiSVD mulRotationOnLeft: some or both integer arguments are out of array row range !"); - - NDArray temp(block.ordering(), {2, block.sizeAt(1)}, block.dataType(), block.getContext()); - auto row1 = block({i,i+1, 0,0}, true); - auto row2 = block({j,j+1, 0,0}, true); - auto rowTemp1 = temp({0,1, 0,0}, true); - auto rowTemp2 = temp({1,2, 0,0}, true); - rowTemp1.assign(row1); - rowTemp2.assign(row2); - temp.assign(mmul(rotation, temp)); - row1.assign(rowTemp1); - row2.assign(rowTemp2); - } +void JacobiSVD::mulRotationOnLeft(const int i, const int j, NDArray& block, + const NDArray& rotation) { + if (i < j) { + if (j + 1 > block.sizeAt(0)) + throw std::runtime_error( + "ops::helpers::JacobiSVD mulRotationOnLeft: second arguments is out " + "of array row range !"); + + auto temp = block({i,j+1,j-i, 0,0,0}, true, true); + temp.assign(mmul(rotation, temp)); + + //auto pTemp = block({i, j + 1, j - i, 0, 0, 0}, true, true); + // auto temp = pTemp.dup(); + //pTemp.assign(mmul(rotation, temp)); + } else { + if (j + 1 > block.sizeAt(0) || i + 1 > block.sizeAt(0)) + throw std::runtime_error( + "ops::helpers::JacobiSVD mulRotationOnLeft: some or both integer " + "arguments are out of array row range !"); + + NDArray temp(block.ordering(), {2, block.sizeAt(1)}, + block.dataType(), block.getContext()); + auto row1 = block({i, i + 1, 0, 0}, true); + auto row2 = block({j, j + 1, 0, 0}, true); + auto rowTemp1 = temp({0, 1, 0, 0}, true); + auto rowTemp2 = temp({1, 2, 0, 0}, true); + rowTemp1.assign(row1); + rowTemp2.assign(row2); + temp.assign(mmul(rotation, temp)); + row1.assign(rowTemp1); + row2.assign(rowTemp2); + } } ////////////////////////////////////////////////////////////////////////// template -void JacobiSVD::mulRotationOnRight(const int i, const int j, NDArray& block, const NDArray& rotation) { - - if(i < j) { - - if(j+1 > block.sizeAt(1)) - throw std::runtime_error("ops::helpers::JacobiSVD mulRotationOnRight: second argument is out of array column range !"); - - auto temp = block({0,0,0, i,j+1,j-i}, true, true); +void JacobiSVD::mulRotationOnRight(const int i, const int j, NDArray& block, + const NDArray& rotation) { + if (i < j) { + if (j + 1 > block.sizeAt(1)) + throw std::runtime_error( + "ops::helpers::JacobiSVD mulRotationOnRight: second argument is out " + "of array column range !"); + + auto temp = block({0,0,0, i,j+1,j-i}, true, true); temp.assign(mmul(temp, rotation)); - // auto pTemp = block({0,0,0, i,j+1,j-i}, true, true); - // auto temp = pTemp.dup(); - // pTemp.assign(mmul(temp, rotation)); - } - else { - - if(j+1 > block.sizeAt(1) || i+1 > block.sizeAt(1)) - throw std::runtime_error("ops::helpers::JacobiSVD mulRotationOnRight: some or both integer arguments are out of array column range !"); - - NDArray temp(block.ordering(), {block.sizeAt(0), 2}, block.dataType(), block.getContext()); - auto col1 = block({0,0, i,i+1}, true); - auto col2 = block({0,0, j,j+1}, true); - auto colTemp1 = temp({0,0, 0,1}, true); - auto colTemp2 = temp({0,0, 1,2}, true); - colTemp1.assign(col1); - colTemp2.assign(col2); - temp.assign(mmul(temp, rotation)); - col1.assign(colTemp1); - col2.assign(colTemp2); - } + //auto pTemp = block({0, 0, 0, i, j + 1, j - i}, true, true); + // auto temp = pTemp.dup(); + //pTemp.assign(mmul(temp, rotation)); + } else { + if (j + 1 > block.sizeAt(1) || i + 1 > block.sizeAt(1)) + throw std::runtime_error( + "ops::helpers::JacobiSVD mulRotationOnRight: some or both integer " + "arguments are out of array column range !"); + + NDArray temp(block.ordering(), {block.sizeAt(0), 2}, + block.dataType(), block.getContext()); + auto col1 = block({0, 0, i, i + 1}, true); + auto col2 = block({0, 0, j, j + 1}, true); + auto colTemp1 = temp({0, 0, 0, 1}, true); + auto colTemp2 = temp({0, 0, 1, 2}, true); + colTemp1.assign(col1); + colTemp2.assign(col2); + temp.assign(mmul(temp, rotation)); + col1.assign(colTemp1); + col2.assign(colTemp2); + } } ////////////////////////////////////////////////////////////////////////// template bool JacobiSVD::isBlock2x2NotDiag(NDArray& block, int p, int q, T& maxElem) { + NDArray rotation(_m.ordering(), {2, 2}, _m.dataType(), + _m.getContext()); + T n = math::nd4j_sqrt(block.t(p, p) * block.t(p, p) + + block.t(q, p) * block.t(q, p)); - NDArray rotation(_m.ordering(), {2, 2}, _m.dataType(), _m.getContext()); + const T almostZero = DataTypeUtils::min(); + const T precision = DataTypeUtils::eps(); - T n = math::nd4j_sqrt(block.t(p, p) * block.t(p, p) + block.t(q, p)*block.t(q, p)); + if (n == (T)0.f) { + block.r(p, p) = (T)0; + block.r(q, p) = (T)0; + } else { + T v = block.t(p, p) / n; - const T almostZero = DataTypeUtils::min(); - const T precision = DataTypeUtils::eps(); + rotation.r(0,0) = rotation.r(1, 1) = v; - if(n == (T)0.f) { - block.r(p, p) = (T)0; - block.r(q, p) = (T)0; - } else { - T v = block.t(p, p) / n; + v = block.t(q, p) / n; + rotation.r(0,1) = v; - rotation.r(0,0) = rotation.r(1,1) = v; + rotation.r(1, 0) = -rotation.template t(0,1); + mulRotationOnLeft(p, q, block, rotation); - v = block.t(q, p) / n; - rotation.r(0,1) = v; + if (_calcU) - rotation.r(1,0) = -rotation.template t(0,1); - mulRotationOnLeft(p, q, block, rotation); + mulRotationOnRight(p, q, _u, rotation.transpose()); + } - if(_calcU) - mulRotationOnRight(p, q, _u, rotation.transpose()); - } + maxElem = math::nd4j_max( + maxElem, math::nd4j_max(math::nd4j_abs(block.t(p, p)), + math::nd4j_abs(block.t(q, q)))); + T threshold = math::nd4j_max(almostZero, precision * maxElem); - maxElem = math::nd4j_max(maxElem, math::nd4j_max(math::nd4j_abs(block.t(p, p)), math::nd4j_abs(block.t(q, q)))); - T threshold = math::nd4j_max(almostZero, precision * maxElem); - return math::nd4j_abs(block.t(p, q)) > threshold || math::nd4j_abs(block.t(q, p)) > threshold; + return math::nd4j_abs(block.t(p, q)) > threshold || math::nd4j_abs(block.t(q, p)) > threshold; } ////////////////////////////////////////////////////////////////////////// template -bool JacobiSVD::createJacobiRotation(const T& x, const T& y, const T& z, NDArray& rotation) { - - T denom = (T)(2.f)* math::nd4j_abs(y); - - if(denom < DataTypeUtils::min()) { - - rotation.r(0,0) = rotation.r(1,1) = (T)1.f; - rotation.r(0,1) = rotation.r(1,0) = (T)0.f; - - return false; - } - else { - - T tau = (x-z)/denom; - T w = math::nd4j_sqrt(tau*tau + (T)1.f); - T t; - - if(tau > (T)0.) - t = (T)1.f / (tau + w); - else - t = (T)1.f / (tau - w); +bool JacobiSVD::createJacobiRotation(const T& x, const T& y, const T& z, + NDArray& rotation) { + T denom = (T)(2.f)* math::nd4j_abs(y); + + if (denom < DataTypeUtils::min()) { + rotation.r(0, 0) = rotation.r(1, 1) = (T)1.f; + rotation.r(0, 1) = rotation.r(1, 0) = (T)0.f; + return false; + } else { + T tau = (x - z) / denom; + T w = math::nd4j_sqrt(tau * tau + (T)1.f); + T t; + + if (tau > (T)0.) + t = (T)1.f / (tau + w); + else + t = (T)1.f / (tau - w); - T sign = t > (T)0. ? (T)1.f : (T)-1.f; + T sign = t > (T)0. ? (T)1.f : (T)-1.f; - T cos = (T)1.f / math::nd4j_sqrt(t*t + (T)1.f); - T sin = -sign * (y / math::nd4j_abs(y)) * math::nd4j_abs(t) * cos; + T cos = (T)1.f / math::nd4j_sqrt(t*t + (T) 1.f); + T sin = -sign * (y / math::nd4j_abs(y)) * math::nd4j_abs(t) * cos; - rotation.r(0,1) = sin; - rotation.r(1,0) = -sin; - rotation.r(0,0) = rotation.r(1,1) = cos; + rotation.r(0,1) = sin; + rotation.r(1, 0) = -sin; + rotation.r(0, 0) = rotation.r(1,1) = cos; - return true; - } + return true; + } } @@ -250,170 +255,165 @@ void JacobiSVD::createJacobiRotationGivens(const T& p, const T& q, NDArray& r ////////////////////////////////////////////////////////////////////////// template -void JacobiSVD::svd2x2(const NDArray& block, int p, int q, NDArray& left, NDArray& right) { - - NDArray m(block.ordering(), {2, 2}, block.dataType(), block.getContext()); - m.r(0,0) = block.t(p,p); - m.r(0,1) = block.t(p,q); - m.r(1,0) = block.t(q,p); - m.r(1,1) = block.t(q,q); - - NDArray rotation(block.ordering(), {2, 2}, block.dataType(), block.getContext()); - T t = m.t(0,0) + m.t(1,1); - T d = m.t(1,0) - m.t(0,1); - - if(math::nd4j_abs(d) < DataTypeUtils::min()) { - - rotation.r(0,0) = rotation.r(1,1) = (T)1; - rotation.r(0,1) = rotation.r(1,0) = (T)0; - } - else { - - T u = t / d; - T tmp = math::nd4j_sqrt((T)1.f + u*u); - rotation.r(0,0) = rotation.r(1,1) = u / tmp; - rotation.r(0,1) = (T)1.f / tmp; - rotation.r(1,0) = -rotation.t(0,1); - } - - m.assign(mmul(rotation, m)); - - createJacobiRotation(m.t(0,0), m.t(0,1), m.t(1,1), right); - - left.assign(mmul(rotation, right.transpose())); +void JacobiSVD::svd2x2(const NDArray& block, int p, int q, NDArray& left, + NDArray& right) { + NDArray m(block.ordering(), {2, 2}, block.dataType(), + block.getContext()); + m.r(0, 0) = block.t(p, p); + m.r(0, 1) = block.t(p, q); + m.r(1, 0) = block.t(q, p); + m.r(1, 1) = block.t(q, q); + + NDArray rotation(block.ordering(), {2, 2}, + block.dataType(), block.getContext()); + T t = m.t(0, 0) + m.t(1, 1); + T d = m.t(1, 0) - m.t(0, 1); + + if (math::nd4j_abs(d) < DataTypeUtils::min()) { + rotation.r(0, 0) = + rotation.r(1, 1) = (T)1; + rotation.r(0, 1) = + rotation.r(1, 0) = (T)0; + } else { + T u = t / d; + T tmp = math::nd4j_sqrt((T)1.f + u * u); + rotation.r(0, 0) = + rotation.r(1, 1) = u / tmp; + rotation.r(0, 1) = (T)1.f / tmp; + rotation.r(1, 0) = -rotation.t(0, 1); + } + + m.assign(mmul(rotation, m)); + + createJacobiRotation(m.t(0, 0), m.t(0, 1), m.t(1, 1), right); + + + left.assign(mmul(rotation, right.transpose())); } - ////////////////////////////////////////////////////////////////////////// template void JacobiSVD::evalData(const NDArray& matrix) { + const T precision = (T)2.f * DataTypeUtils::eps(); + const T almostZero = DataTypeUtils::min(); - const T precision = (T)2.f * DataTypeUtils::eps(); - const T almostZero = DataTypeUtils::min(); - - T scale = matrix.reduceNumber(reduce::AMax).template t(0); - if(scale== (T)0.f) - scale = (T)1.f; - - if(_rows > _cols) { + T scale = matrix.reduceNumber(reduce::AMax).template t(0); + if (scale == (T)0.f) scale = (T)1.f; - HHcolPivQR qr(matrix / scale); - _m.assign(qr._qr({0,_cols, 0,_cols})); - _m.fillAsTriangular(0., 0, 0, _m, 'l'); + if (_rows > _cols) { + HHcolPivQR qr(matrix / scale); + _m.assign(qr._qr({0, _cols, 0, _cols})); + _m.fillAsTriangular(0., 0, 0, _m, 'l'); - HHsequence hhSeg(qr._qr, qr._coeffs, 'u'); + HHsequence hhSeg(qr._qr, qr._coeffs, 'u'); - if(_fullUV) - hhSeg.applyTo(_u); - else if(_calcU) { - _u.setIdentity(); - hhSeg.mulLeft(_u); - } - - if(_calcV) - _v.assign(qr._permut); + if (_fullUV) + hhSeg.applyTo(_u); + else if (_calcU) { + _u.setIdentity(); + hhSeg.mulLeft(_u); } - else if(_rows < _cols) { - HHcolPivQR qr(matrix.transpose() / scale); - _m.assign(qr._qr({0,_rows, 0,_rows})); - _m.fillAsTriangular(0., 0, 0, _m, 'l'); - _m.transposei(); + if (_calcV) _v.assign(qr._permut); + } else if (_rows < _cols) { - HHsequence hhSeg(qr._qr, qr._coeffs, 'u'); // type = 'u' is not mistake here ! + HHcolPivQR qr(matrix.transpose() / scale); + _m.assign(qr._qr({0, _rows, 0, _rows})); + _m.fillAsTriangular(0., 0, 0, _m, 'l'); + _m.transposei(); - if(_fullUV) - hhSeg.applyTo(_v); - else if(_calcV) { - _v.setIdentity(); - hhSeg.mulLeft(_v); - } + HHsequence hhSeg(qr._qr, qr._coeffs, + 'u'); // type = 'u' is not mistake here ! - if(_calcU) - _u.assign(qr._permut); + if (_fullUV) + hhSeg.applyTo(_v); + else if (_calcV) { + _v.setIdentity(); + hhSeg.mulLeft(_v); } - else { - _m.assign(matrix({0,_diagSize, 0,_diagSize}) / scale); + if (_calcU) _u.assign(qr._permut); + } else { + _m.assign( + static_cast(matrix({0, _diagSize, 0, _diagSize})) / + scale); - if(_calcU) - _u.setIdentity(); - - if(_calcV) - _v.setIdentity(); - } - - T maxDiagElem = 0.; - for(int i = 0; i < _diagSize; ++i) { - T current = math::nd4j_abs(_m.t(i,i)); - if(maxDiagElem < current ) - maxDiagElem = current; - } + _m.assign(matrix({0,_diagSize, 0,_diagSize}) / scale); - bool stop = false; + if(_calcU) _u.setIdentity(); - while(!stop) { + if(_calcV) _v.setIdentity(); + } - stop = true; + T maxDiagElem = 0.; + for (int i = 0; i < _diagSize; ++i) { + T current = math::nd4j_abs(_m.t(i, i)); + if (maxDiagElem < current) maxDiagElem = current; + } - for(int p = 1; p < _diagSize; ++p) { + bool stop = false; - for(int q = 0; q < p; ++q) { + while (!stop) { + stop = true; - T threshold = math::nd4j_max(almostZero, precision * maxDiagElem); + for (int p = 1; p < _diagSize; ++p) { + for (int q = 0; q < p; ++q) { + T threshold = math::nd4j_max(almostZero, precision * maxDiagElem); - if(math::nd4j_abs(_m.t(p,q)) > threshold || math::nd4j_abs(_m.t(q,p)) > threshold){ + if (math::nd4j_abs(_m.t(p, q)) > threshold || + math::nd4j_abs(_m.t(q, p)) > threshold) { + stop = false; - stop = false; + // if(isBlock2x2NotDiag(_m, p, q, maxDiagElem)) + { + NDArray rotLeft( + _m.ordering(), {2, 2}, _m.dataType(), _m.getContext()); + NDArray rotRight( + _m.ordering(), {2, 2}, _m.dataType(), _m.getContext()); + svd2x2(_m, p, q, rotLeft, rotRight); - // if(isBlock2x2NotDiag(_m, p, q, maxDiagElem)) - { - NDArray rotLeft(_m.ordering(), {2, 2}, _m.dataType(), _m.getContext()); - NDArray rotRight(_m.ordering(), {2, 2}, _m.dataType(), _m.getContext()); - svd2x2(_m, p, q, rotLeft, rotRight); + mulRotationOnLeft(p, q, _m, rotLeft); - mulRotationOnLeft(p, q, _m, rotLeft); + if (_calcU) - if(_calcU) - mulRotationOnRight(p, q, _u, rotLeft.transpose()); + mulRotationOnRight(p, q, _u, rotLeft.transpose()); - mulRotationOnRight(p, q, _m, rotRight); + mulRotationOnRight(p, q, _m, rotRight); - if(_calcV) - mulRotationOnRight(p, q, _v, rotRight); + if (_calcV) mulRotationOnRight(p, q, _v, rotRight); - maxDiagElem = math::nd4j_max(maxDiagElem, math::nd4j_max(math::nd4j_abs(_m.t(p,p)), math::nd4j_abs(_m.t(q,q)))); - } - } - } + maxDiagElem = math::nd4j_max( + maxDiagElem, + math::nd4j_max(math::nd4j_abs(_m.t(p, p)), + math::nd4j_abs(_m.t(q, q)))); + } } + } } + } - for(int i = 0; i < _diagSize; ++i) { - - _s.r(i) = math::nd4j_abs(_m.t(i,i)); - - if(_calcU && _m.t(i,i) < (T)0.) { - auto temp = _u({0,0, i,i+1}, true); - temp.applyTransform(transform::Neg, temp, nullptr); - } + for (int i = 0; i < _diagSize; ++i) { + _s.r(i) = math::nd4j_abs(_m.t(i, i)); + if (_calcU && _m.t(i, i) < (T)0.) { + auto temp = _u({0, 0, i, i + 1}, true); + temp.applyTransform(transform::Neg, temp, nullptr); } + } - _s *= scale; + _s *= scale; - for(int i = 0; i < _diagSize; i++) { + for (int i = 0; i < _diagSize; i++) { + int pos = + (_s({i, -1, 0, 0}).indexReduceNumber(indexreduce::IndexMax, nullptr)) + .template e(0); + T maxSingVal = _s({i, -1, 0, 0}).reduceNumber(reduce::Max).template t(0); - int pos = (_s({i,-1, 0,0}).indexReduceNumber(indexreduce::IndexMax, nullptr)).template e(0); - T maxSingVal = _s({i,-1, 0,0}).reduceNumber(reduce::Max).template t(0); + if (maxSingVal == (T)0.) break; - if(maxSingVal == (T)0.) - break; + if (pos) { + pos += i; - if(pos) { - - pos += i; - - math::nd4j_swap(_s.r(i), _s.r(pos)); + math::nd4j_swap(_s.r(i), _s.r(pos)); if(_calcU) { auto temp1 = _u({0,0, pos,pos+1}, true); @@ -430,19 +430,11 @@ void JacobiSVD::evalData(const NDArray& matrix) { } } +template class SD_EXPORT JacobiSVD; +template class SD_EXPORT JacobiSVD; +template class SD_EXPORT JacobiSVD; +template class SD_EXPORT JacobiSVD; -template class ND4J_EXPORT JacobiSVD; -template class ND4J_EXPORT JacobiSVD; -template class ND4J_EXPORT JacobiSVD; -template class ND4J_EXPORT JacobiSVD; - - - - - - - -} -} -} - +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/helpers/impl/logger.cpp b/libnd4j/include/helpers/impl/logger.cpp index 59d8f98bc1de..07110a2b5957 100644 --- a/libnd4j/include/helpers/impl/logger.cpp +++ b/libnd4j/include/helpers/impl/logger.cpp @@ -22,48 +22,48 @@ namespace sd { - #ifdef __CUDACC__ - __host__ +__host__ #endif - void Logger::info(const char *format, ...) { - va_list args; - va_start(args, format); + void + Logger::info(const char *format, ...) { + va_list args; + va_start(args, format); - vprintf(format, args); + vprintf(format, args); - va_end(args); + va_end(args); - fflush(stdout); - } + fflush(stdout); +} #ifdef __CUDACC__ - __host__ +__host__ #endif - void Logger::printv(const char *format, const std::vector& vec) { - printf("%s: {", format); - for(int e = 0; e < vec.size(); e++) { - auto v = vec[e]; - printf("%i", v); - if (e < vec.size() - 1) - printf(", "); - } - printf("}\n"); - fflush(stdout); - } + void + Logger::printv(const char *format, const std::vector &vec) { + printf("%s: {", format); + for (int e = 0; e < vec.size(); e++) { + auto v = vec[e]; + printf("%i", v); + if (e < vec.size() - 1) printf(", "); + } + printf("}\n"); + fflush(stdout); +} - #ifdef __CUDACC__ - __host__ +#ifdef __CUDACC__ +__host__ #endif - void Logger::printv(const char *format, const std::vector& vec) { - printf("%s: {", format); - for(int e = 0; e < vec.size(); e++) { - auto v = vec[e]; - printf("%lld", (long long) v); - if (e < vec.size() - 1) - printf(", "); - } - printf("}\n"); - fflush(stdout); - } -} \ No newline at end of file + void + Logger::printv(const char *format, const std::vector &vec) { + printf("%s: {", format); + for (int e = 0; e < vec.size(); e++) { + auto v = vec[e]; + printf("%lld", (long long)v); + if (e < vec.size() - 1) printf(", "); + } + printf("}\n"); + fflush(stdout); +} +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/helpers/impl/shape.cpp b/libnd4j/include/helpers/impl/shape.cpp index 33042e0b76f2..680d775ad683 100644 --- a/libnd4j/include/helpers/impl/shape.cpp +++ b/libnd4j/include/helpers/impl/shape.cpp @@ -20,7 +20,4 @@ #include -namespace shape { - -} - +namespace shape {} diff --git a/libnd4j/include/helpers/impl/unicode.cpp b/libnd4j/include/helpers/impl/unicode.cpp index 6ebbe7c1b478..0c8593ef8f6b 100644 --- a/libnd4j/include/helpers/impl/unicode.cpp +++ b/libnd4j/include/helpers/impl/unicode.cpp @@ -23,434 +23,420 @@ namespace sd { namespace unicode { - constexpr uint32_t ONEBYTEBOUND = 0x00000080; - constexpr uint32_t TWOBYTEBOUND = 0x00000800; - constexpr uint32_t THREEBYTEBOUND = 0x00010000; - constexpr uint16_t HIGHBYTEMIN = 0xd800u; - constexpr uint16_t HIGHBYTEMAX = 0xdbffu; - constexpr uint16_t TRAILBYTEMIN = 0xdc00u; - constexpr uint16_t TRAILBYTEMAX = 0xdfffu; - constexpr uint16_t HIGHBYTEOFFSET = HIGHBYTEMIN - (0x10000 >> 10); - constexpr uint32_t BYTEOFFSET = 0x10000u - (HIGHBYTEMIN << 10) - TRAILBYTEMIN; - // Maximum valid value for a Unicode code point - constexpr uint32_t CODEPOINTMAX = 0x0010ffffu; - - template - FORCEINLINE uint8_t castToU8(const T cp) { - return static_cast(0xff & cp); - } +constexpr uint32_t ONEBYTEBOUND = 0x00000080; +constexpr uint32_t TWOBYTEBOUND = 0x00000800; +constexpr uint32_t THREEBYTEBOUND = 0x00010000; +constexpr uint16_t HIGHBYTEMIN = 0xd800u; +constexpr uint16_t HIGHBYTEMAX = 0xdbffu; +constexpr uint16_t TRAILBYTEMIN = 0xdc00u; +constexpr uint16_t TRAILBYTEMAX = 0xdfffu; +constexpr uint16_t HIGHBYTEOFFSET = HIGHBYTEMIN - (0x10000 >> 10); +constexpr uint32_t BYTEOFFSET = 0x10000u - (HIGHBYTEMIN << 10) - TRAILBYTEMIN; +// Maximum valid value for a Unicode code point +constexpr uint32_t CODEPOINTMAX = 0x0010ffffu; + +template +FORCEINLINE uint8_t castToU8(const T cp) { + return static_cast(0xff & cp); +} - template - FORCEINLINE uint16_t castToU16(const T cp) { - return static_cast(0xffff & cp); - } +template +FORCEINLINE uint16_t castToU16(const T cp) { + return static_cast(0xffff & cp); +} - template - FORCEINLINE uint32_t castToU32(const T cp) { - return static_cast(0xffffff & cp); - } +template +FORCEINLINE uint32_t castToU32(const T cp) { + return static_cast(0xffffff & cp); +} - template - FORCEINLINE bool isTrail(const T cp) { - return ((castToU8(cp) >> 6) == 0x2); - } +template +FORCEINLINE bool isTrail(const T cp) { + return ((castToU8(cp) >> 6) == 0x2); +} - template - FORCEINLINE bool isHighSurrogate(const T cp) { - return (cp & 0xfffffc00) == 0xd800; - } +template +FORCEINLINE bool isHighSurrogate(const T cp) { + return (cp & 0xfffffc00) == 0xd800; +} - template - bool isLowSurrogate(const T cp) { - return (cp & 0xfffffc00) == 0xdc00; - } +template +bool isLowSurrogate(const T cp) { + return (cp & 0xfffffc00) == 0xdc00; +} - template - FORCEINLINE bool isLeadSurrogate(const T cp) { - return (cp >= HIGHBYTEMIN && cp <= HIGHBYTEMAX); - } +template +FORCEINLINE bool isLeadSurrogate(const T cp) { + return (cp >= HIGHBYTEMIN && cp <= HIGHBYTEMAX); +} - template - FORCEINLINE bool isTrailSurrogate(const T cp) { - return (cp >= TRAILBYTEMIN && cp <= TRAILBYTEMAX); - } +template +FORCEINLINE bool isTrailSurrogate(const T cp) { + return (cp >= TRAILBYTEMIN && cp <= TRAILBYTEMAX); +} - template - FORCEINLINE bool isSurrogateU8(const T cp) { - return (cp >= HIGHBYTEMIN && cp <= TRAILBYTEMAX); - } +template +FORCEINLINE bool isSurrogateU8(const T cp) { + return (cp >= HIGHBYTEMIN && cp <= TRAILBYTEMAX); +} - template - FORCEINLINE bool isSurrogateU16(const T cp) { - return ((cp - 0xd800u) < 2048u); - } +template +FORCEINLINE bool isSurrogateU16(const T cp) { + return ((cp - 0xd800u) < 2048u); +} - template - FORCEINLINE bool isSymbolU8Valid(const T cp) { - return (cp <= CODEPOINTMAX && !isSurrogateU8(cp)); - } +template +FORCEINLINE bool isSymbolU8Valid(const T cp) { + return (cp <= CODEPOINTMAX && !isSurrogateU8(cp)); +} - template - FORCEINLINE bool isSymbolValid(const T cp) { - return (cp <= CODEPOINTMAX); - } +template +FORCEINLINE bool isSymbolValid(const T cp) { + return (cp <= CODEPOINTMAX); +} - template - FORCEINLINE uint32_t surrogateU32(const T& high, const T& low) { - return (high << 10) + low - 0x35fdc00; - } +template +FORCEINLINE uint32_t surrogateU32(const T& high, const T& low) { + return (high << 10) + low - 0x35fdc00; +} - template - Nd4jLong symbolLength(const T* it) { - uint8_t lead = castToU8(*it); - if (lead < 0x80) - return 1; - else if ((lead >> 5) == 0x6) - return 2; - else if ((lead >> 4) == 0xe) - return 3; - else if ((lead >> 3) == 0x1e) - return 4; - else - return 0; - } +template +Nd4jLong symbolLength(const T* it) { + uint8_t lead = castToU8(*it); + if (lead < 0x80) + return 1; + else if ((lead >> 5) == 0x6) + return 2; + else if ((lead >> 4) == 0xe) + return 3; + else if ((lead >> 3) == 0x1e) + return 4; + else + return 0; +} - template - Nd4jLong symbolLength32(const T* it) { - auto lead = castToU32(*it); - if (lead < ONEBYTEBOUND) - return 1; - else if (lead < TWOBYTEBOUND) - return 2; - else if (lead < THREEBYTEBOUND) - return 3; - else if (lead <= CODEPOINTMAX) - return 4; - else - return 0; - } +template +Nd4jLong symbolLength32(const T* it) { + auto lead = castToU32(*it); + if (lead < ONEBYTEBOUND) + return 1; + else if (lead < TWOBYTEBOUND) + return 2; + else if (lead < THREEBYTEBOUND) + return 3; + else if (lead <= CODEPOINTMAX) + return 4; + else + return 0; +} - template - Nd4jLong symbolLength16(const T* it) { - - uint32_t lead = castToU16(*it); - if (!isLeadSurrogate(lead)) { - if (lead < ONEBYTEBOUND) - return 1; - else if (lead < TWOBYTEBOUND) - return 2; - else if (lead < THREEBYTEBOUND) - return 3; - else - return 0; - } - else { - return 4; - } - } +template +Nd4jLong symbolLength16(const T* it) { + uint32_t lead = castToU16(*it); + if (!isLeadSurrogate(lead)) { + if (lead < ONEBYTEBOUND) + return 1; + else if (lead < TWOBYTEBOUND) + return 2; + else if (lead < THREEBYTEBOUND) + return 3; + else + return 0; + } else { + return 4; + } +} + +Nd4jLong offsetUtf8StringInUtf32(const void* start, const void* end) { + Nd4jLong count = 0; + for (auto it = static_cast(start); it != end; it++) { + auto length = symbolLength(it); + it += (length > 0) ? (length - 1) : 0; + count += 1; + } + return static_cast(count * sizeof(char32_t)); +} + +Nd4jLong offsetUtf16StringInUtf32(const void* start, const void* end) { + Nd4jLong count = 0; + for (auto it = static_cast(start); it != end;) { + auto length = symbolLength16(it); + it += (4 == length) ? 2 : 1; + count += 1; + } + return static_cast(count * sizeof(char32_t)); +} + +Nd4jLong offsetUtf8StringInUtf16(const void* start, const void* end) { + Nd4jLong count = 0; + for (auto it = static_cast(start); it != end; it++) { + auto length = symbolLength(it); + auto step = ((length > 0) ? (length - 1) : 0); + it += step; + count += (4 == length) ? 2 : 1; + } + return static_cast(count * sizeof(char16_t)); +} + +Nd4jLong offsetUtf16StringInUtf8(const void* start, const void* end) { + Nd4jLong count = 0; + for (auto it = static_cast(start); it != end;) { + auto length = symbolLength16(it); + it += (4 == length) ? 2 : 1; + count += length; + } + return static_cast(count); +} + +Nd4jLong offsetUtf32StringInUtf16(const void* start, const void* end) { + Nd4jLong count = 0; + for (auto it = static_cast(start); it != end; it++) { + auto length = symbolLength32(it); + count += (4 == length) ? 2 : 1; + ; + } + return static_cast(count * sizeof(char16_t)); +} - Nd4jLong offsetUtf8StringInUtf32(const void* start, const void* end) { - - Nd4jLong count = 0; - for (auto it = static_cast(start); it != end; it++) { - auto length = symbolLength(it); - it += (length > 0) ? (length - 1) : 0; - count += 1; - } - return static_cast(count * sizeof(char32_t)); +Nd4jLong offsetUtf32StringInUtf8(const void* start, const void* end) { + Nd4jLong count = 0; + for (auto it = static_cast(start); it != end; it++) { + count += symbolLength32(it); + } + return count; +} + +bool isStringValidU8(const void* start, const void* stop) { + for (auto it = static_cast(start); it != stop; it++) { + if (!isSymbolU8Valid(castToU8(*it))) { + return false; } - - Nd4jLong offsetUtf16StringInUtf32(const void* start, const void* end) { - - Nd4jLong count = 0; - for (auto it = static_cast(start); it != end;) { - auto length = symbolLength16(it); - it += (4 == length) ? 2 : 1; - count += 1; - } - return static_cast(count*sizeof(char32_t)); + } + return true; +} + +bool isStringValidU16(const void* start, const void* stop) { + for (auto it = static_cast(start); it != stop; it++) { + if (!isSymbolValid(castToU32(*it))) { + return false; } - - Nd4jLong offsetUtf8StringInUtf16(const void* start, const void* end) { - - Nd4jLong count = 0; - for (auto it = static_cast(start); it != end; it++) { - auto length = symbolLength(it); - auto step = ((length > 0) ? (length - 1) : 0); - it += step; - count += (4 == length) ? 2 : 1; - } - return static_cast(count*sizeof(char16_t)); + } + return true; +} + +bool isStringValidU32(const void* start, const void* stop) { + for (auto it = static_cast(start); it != stop; it++) { + if (!isSymbolValid(castToU32(*it))) { + return false; } - - Nd4jLong offsetUtf16StringInUtf8(const void* start, const void* end) { - - Nd4jLong count = 0; - for (auto it = static_cast(start); it != end;) { - auto length = symbolLength16(it); - it += (4 == length) ? 2 : 1; - count += length; - } - return static_cast(count); + } + return true; +} + +void* utf16to8Ptr(const void* start, const void* end, void* res) { + auto result = static_cast(res); + // result have to be pre-allocated + for (auto it = static_cast(start); it != end;) { + uint32_t cp = castToU16(*it++); + if (!isLeadSurrogate(cp)) { + if (cp < 0x80) { // for one byte + *(result++) = static_cast(cp); + } else if (cp < 0x800) { // for two bytes + *(result++) = static_cast((cp >> 6) | 0xc0); + *(result++) = static_cast((cp & 0x3f) | 0x80); + } else { // for three bytes + *(result++) = static_cast((cp >> 12) | 0xe0); + *(result++) = static_cast(((cp >> 6) & 0x3f) | 0x80); + *(result++) = static_cast((cp & 0x3f) | 0x80); + } + } else { + if (it != end) { + uint32_t trail_surrogate = castToU16(*it++); + if (isTrailSurrogate(trail_surrogate)) + cp = (cp << 10) + trail_surrogate + BYTEOFFSET; + } + // for four bytes + *(result++) = static_cast((cp >> 18) | 0xf0); + *(result++) = static_cast(((cp >> 12) & 0x3f) | 0x80); + *(result++) = static_cast(((cp >> 6) & 0x3f) | 0x80); + *(result++) = static_cast((cp & 0x3f) | 0x80); } - - Nd4jLong offsetUtf32StringInUtf16(const void* start, const void* end) { - - Nd4jLong count = 0; - for (auto it = static_cast(start); it != end; it++) { - auto length = symbolLength32(it); - count += (4 == length) ? 2 : 1;; - } - return static_cast(count*sizeof(char16_t)); + } + return result; +} + +void* utf8to16Ptr(const void* start, const void* end, void* res) { + auto result = static_cast(res); + // result have to be pre-allocated + for (auto it = static_cast(start); it != end;) { + auto nLength = symbolLength(it); + uint32_t cp = castToU8(*it++); + if (4 != nLength) { + if (2 == nLength) { + cp = ((cp << 6) & 0x7ff) + ((*it++) & 0x3f); + } else if (3 == nLength) { + cp = ((cp << 12) & 0xffff) + ((castToU8(*it++) << 6) & 0xfff); + cp += (*it++) & 0x3f; + } + *(result++) = static_cast(cp); + } else { + cp = ((cp << 18) & 0x1fffff) + ((castToU8(*it++) << 12) & 0x3ffff); + cp += (castToU8(*it++) << 6) & 0xfff; + cp += (*it++) & 0x3f; + // make a surrogate pair + *(result++) = static_cast((cp >> 10) + HIGHBYTEOFFSET); + *(result++) = static_cast((cp & 0x3ff) + TRAILBYTEMIN); } - - Nd4jLong offsetUtf32StringInUtf8(const void* start, const void* end) { - - Nd4jLong count = 0; - for (auto it = static_cast(start); it != end; it++) { - count += symbolLength32(it); - } - return count; + } + return result; +} + +void* utf32to8Ptr(const void* start, const void* end, void* result) { + auto res = static_cast(result); + // result have to be pre-allocated + for (auto it = static_cast(start); it != end; it++) { + if (*it < 0x80) // for one byte + *(res++) = static_cast(*it); + else if (*it < 0x800) { // for two bytes + *(res++) = static_cast((*it >> 6) | 0xc0); + *(res++) = static_cast((*it & 0x3f) | 0x80); + } else if (*it < 0x10000) { // for three bytes + *(res++) = static_cast((*it >> 12) | 0xe0); + *(res++) = static_cast(((*it >> 6) & 0x3f) | 0x80); + *(res++) = static_cast((*it & 0x3f) | 0x80); + } else { // for four bytes + *(res++) = static_cast((*it >> 18) | 0xf0); + *(res++) = static_cast(((*it >> 12) & 0x3f) | 0x80); + *(res++) = static_cast(((*it >> 6) & 0x3f) | 0x80); + *(res++) = static_cast((*it & 0x3f) | 0x80); } - - bool isStringValidU8(const void* start, const void* stop) { - for (auto it = static_cast(start); it != stop; it++) { - if (!isSymbolU8Valid( castToU8(*it) )) { - return false; - } - } - return true; + } + return result; +} + +void* utf8to32Ptr(const void* start, const void* end, void* res) { + auto result = static_cast(res); + // result have to be pre-allocated + for (auto it = static_cast(start); it != end;) { + auto nLength = symbolLength(it); + uint32_t cp = castToU8(*it++); + if (2 == nLength) { + cp = ((cp << 6) & 0x7ff) + ((*it++) & 0x3f); + } else if (3 == nLength) { + cp = ((cp << 12) & 0xffff) + ((castToU8(*it++) << 6) & 0xfff); + cp += (*it++) & 0x3f; + } else if (4 == nLength) { + cp = ((cp << 18) & 0x1fffff) + ((castToU8(*it++) << 12) & 0x3ffff); + cp += (castToU8(*it++) << 6) & 0xfff; + cp += (*it++) & 0x3f; } - - bool isStringValidU16(const void* start, const void* stop) { - for (auto it = static_cast(start); it != stop; it++) { - if (!isSymbolValid( castToU32(*it) )) { - return false; - } - } - return true; + (*result++) = cp; + } + return result; +} + +void* utf16to32Ptr(const void* start, const void* end, void* res) { + auto result = static_cast(res); + // result have to be pre-allocated + for (auto it = static_cast(start); it != end; it++) { + uint32_t cpHigh = castToU32(*it); + if (!isSurrogateU16(cpHigh)) { + *result++ = cpHigh; + } else { + it++; + uint32_t cpLow = castToU32(*it); + if (isHighSurrogate(cpHigh) && it != end && isLowSurrogate(cpLow)) { + *result++ = surrogateU32(cpHigh, cpLow); + } } - - bool isStringValidU32(const void* start, const void* stop) { - for (auto it = static_cast(start); it != stop; it++) { - if (!isSymbolValid( castToU32(*it) )) { - return false; - } - } - return true; + } + return result; +} + +void* utf32to16Ptr(const void* start, const void* end, void* res) { + auto result = static_cast(res); + // result have to be pre-allocate + for (auto it = static_cast(start); it != end; it++) { + uint32_t cpHigh = castToU32(*it); + // todo check do we need this as we have pre-validation, if yes find out how + // to check u16 + if (cpHigh < 0 || cpHigh > 0x10FFFF || + (cpHigh >= 0xD800 && cpHigh <= 0xDFFF)) { + // Invalid code point. Replace with sentinel, per Unicode standard: + *result++ = u'\uFFFD'; + } else if (cpHigh < 0x10000UL) { // In the BMP. + *result++ = static_cast(cpHigh); + } else { + *result++ = + static_cast(((cpHigh - 0x10000UL) / 0x400U) + 0xD800U); + *result++ = + static_cast(((cpHigh - 0x10000UL) % 0x400U) + 0xDC00U); } + } + return result; +} - void* utf16to8Ptr(const void* start, const void* end, void* res) { - - auto result = static_cast(res); - // result have to be pre-allocated - for (auto it = static_cast(start); it != end;) { - uint32_t cp = castToU16(*it++); - if (!isLeadSurrogate(cp)) { - if (cp < 0x80) { // for one byte - *(result++) = static_cast(cp); - } - else if (cp < 0x800) { // for two bytes - *(result++) = static_cast((cp >> 6) | 0xc0); - *(result++) = static_cast((cp & 0x3f) | 0x80); - } - else{ // for three bytes - *(result++) = static_cast((cp >> 12) | 0xe0); - *(result++) = static_cast(((cp >> 6) & 0x3f) | 0x80); - *(result++) = static_cast((cp & 0x3f) | 0x80); - } - } - else { - if (it != end) { - uint32_t trail_surrogate = castToU16(*it++); - if (isTrailSurrogate(trail_surrogate)) - cp = (cp << 10) + trail_surrogate + BYTEOFFSET; - } - // for four bytes - *(result++) = static_cast((cp >> 18) | 0xf0); - *(result++) = static_cast(((cp >> 12) & 0x3f) | 0x80); - *(result++) = static_cast(((cp >> 6) & 0x3f) | 0x80); - *(result++) = static_cast((cp & 0x3f) | 0x80); - } - } - return result; - } - - void* utf8to16Ptr(const void* start, const void* end, void* res) { - - auto result = static_cast(res); - // result have to be pre-allocated - for (auto it = static_cast(start); it != end;) { - - auto nLength = symbolLength(it); - uint32_t cp = castToU8(*it++); - if (4 != nLength) { - if (2 == nLength) { - cp = ((cp << 6) & 0x7ff) + ((*it++) & 0x3f); - } - else if (3 == nLength) { - cp = ((cp << 12) & 0xffff) + ((castToU8(*it++) << 6) & 0xfff); - cp += (*it++) & 0x3f; - } - *(result++) = static_cast(cp); - } - else { - cp = ((cp << 18) & 0x1fffff) + ((castToU8(*it++) << 12) & 0x3ffff); - cp += (castToU8(*it++) << 6) & 0xfff; - cp += (*it++) & 0x3f; - //make a surrogate pair - *(result++) = static_cast((cp >> 10) + HIGHBYTEOFFSET); - *(result++) = static_cast((cp & 0x3ff) + TRAILBYTEMIN); - } - } - return result; - } - - void* utf32to8Ptr( const void* start, const void* end, void* result) { - - auto res = static_cast(result); - // result have to be pre-allocated - for (auto it = static_cast(start); it != end; it++) { - - if (*it < 0x80) // for one byte - *(res++) = static_cast(*it); - else if (*it < 0x800) { // for two bytes - *(res++) = static_cast((*it >> 6) | 0xc0); - *(res++) = static_cast((*it & 0x3f) | 0x80); - } - else if (*it < 0x10000) { // for three bytes - *(res++) = static_cast((*it >> 12) | 0xe0); - *(res++) = static_cast(((*it >> 6) & 0x3f) | 0x80); - *(res++) = static_cast((*it & 0x3f) | 0x80); - } - else { // for four bytes - *(res++) = static_cast((*it >> 18) | 0xf0); - *(res++) = static_cast(((*it >> 12) & 0x3f) | 0x80); - *(res++) = static_cast(((*it >> 6) & 0x3f) | 0x80); - *(res++) = static_cast((*it & 0x3f) | 0x80); - } - } - return result; - } - - void* utf8to32Ptr(const void* start, const void* end, void* res) { - - auto result = static_cast(res); - // result have to be pre-allocated - for (auto it = static_cast(start); it != end;) { - - auto nLength = symbolLength(it); - uint32_t cp = castToU8(*it++); - if (2 == nLength) { - cp = ((cp << 6) & 0x7ff) + ((*it++) & 0x3f); - } - else if (3 == nLength) { - cp = ((cp << 12) & 0xffff) + ((castToU8(*it++) << 6) & 0xfff); - cp += (*it++) & 0x3f; - } - else if (4 == nLength) { - cp = ((cp << 18) & 0x1fffff) + ((castToU8(*it++) << 12) & 0x3ffff); - cp += (castToU8(*it++) << 6) & 0xfff; - cp += (*it++) & 0x3f; - } - (*result++) = cp; - } - return result; - } - - void* utf16to32Ptr(const void* start, const void* end, void* res) { - - auto result = static_cast(res); - // result have to be pre-allocated - for (auto it = static_cast(start); it != end; it++) { - - uint32_t cpHigh = castToU32(*it); - if (!isSurrogateU16(cpHigh)) { - *result++ = cpHigh; - } - else { - it++; - uint32_t cpLow = castToU32(*it); - if (isHighSurrogate(cpHigh) && it != end && isLowSurrogate(cpLow)) { - *result++ = surrogateU32(cpHigh, cpLow); - } - } - } - return result; - } - - void* utf32to16Ptr(const void* start, const void* end, void* res) { - - auto result = static_cast(res); - // result have to be pre-allocate - for (auto it = static_cast(start); it != end; it++) { - - uint32_t cpHigh = castToU32(*it); - // todo check do we need this as we have pre-validation, if yes find out how to check u16 - if (cpHigh < 0 || cpHigh > 0x10FFFF || (cpHigh >= 0xD800 && cpHigh <= 0xDFFF)) { - // Invalid code point. Replace with sentinel, per Unicode standard: - *result++ = u'\uFFFD'; - } - else if (cpHigh < 0x10000UL) { // In the BMP. - *result++ = static_cast(cpHigh); - } - else { - *result++ = static_cast(((cpHigh - 0x10000UL) / 0x400U) + 0xD800U); - *result++ = static_cast(((cpHigh - 0x10000UL) % 0x400U) + 0xDC00U); - } - } - return result; - } - - Nd4jLong offsetUtf8StringInUtf32(const void* input, uint32_t nInputSize) { - return offsetUtf8StringInUtf32(input, static_cast(input) + nInputSize); - } - - Nd4jLong offsetUtf16StringInUtf32(const void* input, uint32_t nInputSize) { - return offsetUtf16StringInUtf32(input, static_cast(input) + nInputSize); - } - - Nd4jLong offsetUtf8StringInUtf16(const void* input, uint32_t nInputSize) { - return offsetUtf8StringInUtf16(input, static_cast(input) + nInputSize); - } - - Nd4jLong offsetUtf16StringInUtf8(const void* input, uint32_t nInputSize) { - return offsetUtf16StringInUtf8(input, static_cast(input) + nInputSize); - } - - Nd4jLong offsetUtf32StringInUtf8(const void* input, uint32_t nInputSize) { - return offsetUtf32StringInUtf8(input, static_cast(input) + nInputSize); - } - - Nd4jLong offsetUtf32StringInUtf16(const void* input, const uint32_t nInputSize) { - return offsetUtf32StringInUtf16(input, static_cast(input) + nInputSize); - } - - bool utf8to16(const void* input, void* output, uint32_t nInputSize) { - return utf8to16Ptr(input, static_cast(input) + nInputSize, output); - } - - bool utf8to32(const void* input, void* output, uint32_t nInputSize) { - return utf8to32Ptr(input, static_cast(input) + nInputSize, output); - } - - bool utf16to32(const void* input, void* output, uint32_t nInputSize) { - return utf16to32Ptr(input, static_cast(input) + nInputSize, output); - } - - bool utf16to8(const void* input, void* output, uint32_t nInputSize) { - return utf16to8Ptr(input, static_cast(input) + nInputSize, output); - } - - bool utf32to16(const void* input, void* output, uint32_t nInputSize) { - return utf32to16Ptr(input, static_cast(input) + nInputSize, output); - } - - bool utf32to8(const void* input, void* output, const Nd4jLong nInputSize) { - return utf32to8Ptr(input, static_cast(input) + nInputSize, output); - } - - } +Nd4jLong offsetUtf8StringInUtf32(const void* input, uint32_t nInputSize) { + return offsetUtf8StringInUtf32( + input, static_cast(input) + nInputSize); +} +Nd4jLong offsetUtf16StringInUtf32(const void* input, uint32_t nInputSize) { + return offsetUtf16StringInUtf32( + input, static_cast(input) + nInputSize); } +Nd4jLong offsetUtf8StringInUtf16(const void* input, uint32_t nInputSize) { + return offsetUtf8StringInUtf16( + input, static_cast(input) + nInputSize); +} + +Nd4jLong offsetUtf16StringInUtf8(const void* input, uint32_t nInputSize) { + return offsetUtf16StringInUtf8( + input, static_cast(input) + nInputSize); +} + +Nd4jLong offsetUtf32StringInUtf8(const void* input, uint32_t nInputSize) { + return offsetUtf32StringInUtf8( + input, static_cast(input) + nInputSize); +} + +Nd4jLong offsetUtf32StringInUtf16(const void* input, + const uint32_t nInputSize) { + return offsetUtf32StringInUtf16( + input, static_cast(input) + nInputSize); +} + +bool utf8to16(const void* input, void* output, uint32_t nInputSize) { + return utf8to16Ptr(input, static_cast(input) + nInputSize, + output); +} + +bool utf8to32(const void* input, void* output, uint32_t nInputSize) { + return utf8to32Ptr(input, static_cast(input) + nInputSize, + output); +} + +bool utf16to32(const void* input, void* output, uint32_t nInputSize) { + return utf16to32Ptr(input, static_cast(input) + nInputSize, + output); +} + +bool utf16to8(const void* input, void* output, uint32_t nInputSize) { + return utf16to8Ptr(input, static_cast(input) + nInputSize, + output); +} + +bool utf32to16(const void* input, void* output, uint32_t nInputSize) { + return utf32to16Ptr(input, static_cast(input) + nInputSize, + output); +} + +bool utf32to8(const void* input, void* output, const Nd4jLong nInputSize) { + return utf32to8Ptr(input, static_cast(input) + nInputSize, + output); +} + +} // namespace unicode + +} // namespace sd diff --git a/libnd4j/include/helpers/jacobiSVD.h b/libnd4j/include/helpers/jacobiSVD.h index 615811e9a42e..bcfd75225a87 100644 --- a/libnd4j/include/helpers/jacobiSVD.h +++ b/libnd4j/include/helpers/jacobiSVD.h @@ -21,8 +21,8 @@ #ifndef LIBND4J_JACOBISVD_H #define LIBND4J_JACOBISVD_H -#include #include +#include namespace sd { namespace ops { @@ -30,44 +30,43 @@ namespace helpers { template class JacobiSVD { + public: + NDArray _m; + NDArray _s; // vector with singular values + NDArray _u; + NDArray _v; - public: - - NDArray _m; - NDArray _s; // vector with singular values - NDArray _u; - NDArray _v; - - int _diagSize; - int _rows; - int _cols; + int _diagSize; + int _rows; + int _cols; - // bool _transp; - bool _calcU; - bool _calcV; - bool _fullUV; + // bool _transp; + bool _calcU; + bool _calcV; + bool _fullUV; - JacobiSVD(const NDArray& matrix, const bool calcU, const bool calcV, const bool fullUV); + JacobiSVD(const NDArray& matrix, const bool calcU, const bool calcV, + const bool fullUV); - bool isBlock2x2NotDiag(NDArray& block, int p, int q, T& maxElem); + bool isBlock2x2NotDiag(NDArray& block, int p, int q, T& maxElem); - static bool createJacobiRotation(const T& x, const T& y, const T& z, NDArray& rotation); - static void createJacobiRotationGivens(const T& p, const T& q, NDArray& rotation); + static bool createJacobiRotation(const T& x, const T& y, const T& z, + NDArray& rotation); + static void createJacobiRotationGivens(const T& p, const T& q, NDArray& rotation); + static void svd2x2(const NDArray& block, int p, int q, NDArray& left, + NDArray& right); - static void svd2x2(const NDArray& block, int p, int q, NDArray& left, NDArray& right); + static void mulRotationOnLeft(const int i, const int j, NDArray& block, + const NDArray& rotation); - static void mulRotationOnLeft(const int i, const int j, NDArray& block, const NDArray& rotation); + static void mulRotationOnRight(const int i, const int j, NDArray& block, + const NDArray& rotation); - static void mulRotationOnRight(const int i, const int j, NDArray& block, const NDArray& rotation); - - void evalData(const NDArray& matrix); + void evalData(const NDArray& matrix); }; +} // namespace helpers +} // namespace ops +} // namespace sd - -} -} -} - - -#endif //LIBND4J_JACOBISVD_H +#endif // LIBND4J_JACOBISVD_H diff --git a/libnd4j/include/helpers/logger.h b/libnd4j/include/helpers/logger.h index b7ed88c1d266..903a2d9d7be5 100644 --- a/libnd4j/include/helpers/logger.h +++ b/libnd4j/include/helpers/logger.h @@ -21,22 +21,31 @@ #ifndef LIBND4J_LOGGER_H #define LIBND4J_LOGGER_H -#include -#include -#include -#include #include +#include +#include #include #include #include +#include +#include + #ifndef __CUDA_ARCH__ -#define nd4j_debug(FORMAT, ...) if (sd::Environment::getInstance().isDebug() && sd::Environment::getInstance().isVerbose()) sd::Logger::info(FORMAT, __VA_ARGS__); -#define nd4j_logger(FORMAT, ...) if (sd::Environment::getInstance().isDebug() && sd::Environment::getInstance().isVerbose()) sd::Logger::info(FORMAT, __VA_ARGS__); -#define nd4j_verbose(FORMAT, ...) if (sd::Environment::getInstance().isVerbose()) sd::Logger::info(FORMAT, __VA_ARGS__); +#define nd4j_debug(FORMAT, ...) \ + if (sd::Environment::getInstance().isDebug() && \ + sd::Environment::getInstance().isVerbose()) \ + sd::Logger::info(FORMAT, __VA_ARGS__); +#define nd4j_logger(FORMAT, ...) \ + if (sd::Environment::getInstance().isDebug() && \ + sd::Environment::getInstance().isVerbose()) \ + sd::Logger::info(FORMAT, __VA_ARGS__); +#define nd4j_verbose(FORMAT, ...) \ + if (sd::Environment::getInstance().isVerbose()) \ + sd::Logger::info(FORMAT, __VA_ARGS__); #define nd4j_printf(FORMAT, ...) sd::Logger::info(FORMAT, __VA_ARGS__); -#define nd4j_printv(FORMAT, VECTOR) sd::Logger::printv(FORMAT, VECTOR); +#define nd4j_printv(FORMAT, VECTOR) sd::Logger::printv(FORMAT, VECTOR); #else @@ -49,17 +58,15 @@ #endif namespace sd { - class ND4J_EXPORT Logger { - - public: - - static void _CUDA_H info(const char *format, ...); - - static void _CUDA_H printv(const char *format, const std::vector& vec); - static void _CUDA_H printv(const char *format, const std::vector& vec); - }; +class SD_EXPORT Logger { + public: + static void _CUDA_H info(const char *format, ...); -} + static void _CUDA_H printv(const char *format, const std::vector &vec); + static void _CUDA_H printv(const char *format, + const std::vector &vec); +}; +} // namespace sd -#endif //LIBND4J_LOGGER_H +#endif // LIBND4J_LOGGER_H diff --git a/libnd4j/include/helpers/mman.h b/libnd4j/include/helpers/mman.h index 618ee23c3960..4a09ab1ddcb4 100644 --- a/libnd4j/include/helpers/mman.h +++ b/libnd4j/include/helpers/mman.h @@ -22,16 +22,18 @@ #ifndef _SYS_MMAN_H_ #define _SYS_MMAN_H_ -#include #include #include +#include #ifndef FILE_MAP_EXECUTE -#define FILE_MAP_EXECUTE 0x0020 +#define FILE_MAP_EXECUTE 0x0020 #endif /* FILE_MAP_EXECUTE */ -#ifndef _WIN32_WINNT // Allow use of features specific to Windows XP or later. -#define _WIN32_WINNT 0x0501 // Change this to the appropriate value to target other versions of Windows. +#ifndef _WIN32_WINNT // Allow use of features specific to Windows XP or later. +#define _WIN32_WINNT \ + 0x0501 // Change this to the appropriate value to target other versions of + // Windows. #endif /* All the headers include this file. */ @@ -53,269 +55,252 @@ typedef uint32_t OffsetType; extern "C" { #endif -#define PROT_NONE 0 -#define PROT_READ 1 -#define PROT_WRITE 2 -#define PROT_EXEC 4 +#define PROT_NONE 0 +#define PROT_READ 1 +#define PROT_WRITE 2 +#define PROT_EXEC 4 -#define MAP_FILE 0 -#define MAP_SHARED 1 -#define MAP_PRIVATE 2 -#define MAP_TYPE 0xf -#define MAP_FIXED 0x10 -#define MAP_ANONYMOUS 0x20 -#define MAP_ANON MAP_ANONYMOUS +#define MAP_FILE 0 +#define MAP_SHARED 1 +#define MAP_PRIVATE 2 +#define MAP_TYPE 0xf +#define MAP_FIXED 0x10 +#define MAP_ANONYMOUS 0x20 +#define MAP_ANON MAP_ANONYMOUS -#define MAP_FAILED ((void *)-1) +#define MAP_FAILED ((void *)-1) /* Flags for msync. */ -#define MS_ASYNC 1 -#define MS_SYNC 2 -#define MS_INVALIDATE 4 - -void _mmap(Nd4jLong* result, size_t length, const char *fileName); -void* mmap(void *addr, size_t len, int prot, int flags, int fildes, OffsetType off); -int munmap(void *addr, size_t len); -int _mprotect(void *addr, size_t len, int prot); -int msync(void *addr, size_t len, int flags); -int mlock(const void *addr, size_t len); -int munlock(const void *addr, size_t len); +#define MS_ASYNC 1 +#define MS_SYNC 2 +#define MS_INVALIDATE 4 + +void _mmap(Nd4jLong *result, size_t length, const char *fileName); +void *mmap(void *addr, size_t len, int prot, int flags, int fildes, + OffsetType off); +int munmap(void *addr, size_t len); +int _mprotect(void *addr, size_t len, int prot); +int msync(void *addr, size_t len, int flags); +int mlock(const void *addr, size_t len); +int munlock(const void *addr, size_t len); #ifdef __cplusplus } #endif -static int __map_mman_error(const DWORD err, const int deferr) -{ - if (err == 0) - return 0; - //TODO: implement - return err; +static int __map_mman_error(const DWORD err, const int deferr) { + if (err == 0) return 0; + // TODO: implement + return err; } -static DWORD __map_mmap_prot_page(const int prot) -{ - DWORD protect = 0; - - if (prot == PROT_NONE) - return protect; - - if ((prot & PROT_EXEC) != 0) - { - protect = ((prot & PROT_WRITE) != 0) ? - PAGE_EXECUTE_READWRITE : PAGE_EXECUTE_READ; - } - else - { - protect = ((prot & PROT_WRITE) != 0) ? - PAGE_READWRITE : PAGE_READONLY; - } - - return protect; +static DWORD __map_mmap_prot_page(const int prot) { + DWORD protect = 0; + + if (prot == PROT_NONE) return protect; + + if ((prot & PROT_EXEC) != 0) { + protect = + ((prot & PROT_WRITE) != 0) ? PAGE_EXECUTE_READWRITE : PAGE_EXECUTE_READ; + } else { + protect = ((prot & PROT_WRITE) != 0) ? PAGE_READWRITE : PAGE_READONLY; + } + + return protect; } -static DWORD __map_mmap_prot_file(const int prot) -{ - DWORD desiredAccess = 0; +static DWORD __map_mmap_prot_file(const int prot) { + DWORD desiredAccess = 0; - if (prot == PROT_NONE) - return desiredAccess; + if (prot == PROT_NONE) return desiredAccess; - if ((prot & PROT_READ) != 0) - desiredAccess |= FILE_MAP_READ; - if ((prot & PROT_WRITE) != 0) - desiredAccess |= FILE_MAP_WRITE; - if ((prot & PROT_EXEC) != 0) - desiredAccess |= FILE_MAP_EXECUTE; + if ((prot & PROT_READ) != 0) desiredAccess |= FILE_MAP_READ; + if ((prot & PROT_WRITE) != 0) desiredAccess |= FILE_MAP_WRITE; + if ((prot & PROT_EXEC) != 0) desiredAccess |= FILE_MAP_EXECUTE; - return desiredAccess; + return desiredAccess; } -void _mmap(Nd4jLong* result, size_t length, const char *fileName) { - HANDLE fm, h; +void _mmap(Nd4jLong *result, size_t length, const char *fileName) { + HANDLE fm, h; - void * map = MAP_FAILED; - OffsetType off = 0; - int prot = PROT_READ | PROT_WRITE; + void *map = MAP_FAILED; + OffsetType off = 0; + int prot = PROT_READ | PROT_WRITE; #ifdef _MSC_VER - #pragma warning(push) - #pragma warning(disable: 4293) +#pragma warning(push) +#pragma warning(disable : 4293) #endif - const DWORD dwFileOffsetLow = (sizeof(OffsetType) <= sizeof(DWORD)) ? - (DWORD)off : (DWORD)(off & 0xFFFFFFFFL); - const DWORD dwFileOffsetHigh = (sizeof(OffsetType) <= sizeof(DWORD)) ? - (DWORD)0 : (DWORD)((off >> 32) & 0xFFFFFFFFL); - const DWORD protect = __map_mmap_prot_page(prot); - const DWORD desiredAccess = __map_mmap_prot_file(prot); + const DWORD dwFileOffsetLow = (sizeof(OffsetType) <= sizeof(DWORD)) + ? (DWORD)off + : (DWORD)(off & 0xFFFFFFFFL); + const DWORD dwFileOffsetHigh = (sizeof(OffsetType) <= sizeof(DWORD)) + ? (DWORD)0 + : (DWORD)((off >> 32) & 0xFFFFFFFFL); + const DWORD protect = __map_mmap_prot_page(prot); + const DWORD desiredAccess = __map_mmap_prot_file(prot); - const OffsetType maxSize = off + (OffsetType) length; + const OffsetType maxSize = off + (OffsetType)length; - const DWORD dwMaxSizeLow = (sizeof(OffsetType) <= sizeof(DWORD)) ? - (DWORD)maxSize : (DWORD)(maxSize & 0xFFFFFFFFL); - const DWORD dwMaxSizeHigh = (sizeof(OffsetType) <= sizeof(DWORD)) ? - (DWORD)0 : (DWORD)((maxSize >> 32) & 0xFFFFFFFFL); + const DWORD dwMaxSizeLow = (sizeof(OffsetType) <= sizeof(DWORD)) + ? (DWORD)maxSize + : (DWORD)(maxSize & 0xFFFFFFFFL); + const DWORD dwMaxSizeHigh = (sizeof(OffsetType) <= sizeof(DWORD)) + ? (DWORD)0 + : (DWORD)((maxSize >> 32) & 0xFFFFFFFFL); #ifdef _MSC_VER - #pragma warning(pop) +#pragma warning(pop) #endif - h = CreateFileA(fileName, GENERIC_READ | GENERIC_WRITE, FILE_SHARE_WRITE | FILE_SHARE_READ, nullptr, OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, nullptr); + h = CreateFileA(fileName, GENERIC_READ | GENERIC_WRITE, + FILE_SHARE_WRITE | FILE_SHARE_READ, nullptr, OPEN_EXISTING, + FILE_ATTRIBUTE_NORMAL, nullptr); - if (h == INVALID_HANDLE_VALUE) { - errno = __map_mman_error(GetLastError(), EPERM); - nd4j_printf("Error code: %i\n", (int) errno); - throw std::runtime_error("CreateFile failed"); - } + if (h == INVALID_HANDLE_VALUE) { + errno = __map_mman_error(GetLastError(), EPERM); + nd4j_printf("Error code: %i\n", (int)errno); + throw std::runtime_error("CreateFile failed"); + } - fm = CreateFileMapping(h, NULL, protect, dwMaxSizeHigh, dwMaxSizeLow, NULL); + fm = CreateFileMapping(h, NULL, protect, dwMaxSizeHigh, dwMaxSizeLow, NULL); - if (fm == NULL) - { - errno = __map_mman_error(GetLastError(), EPERM); - throw std::runtime_error("CreateFileMapping failed"); - } + if (fm == NULL) { + errno = __map_mman_error(GetLastError(), EPERM); + throw std::runtime_error("CreateFileMapping failed"); + } - map = MapViewOfFile(fm, desiredAccess, dwFileOffsetHigh, dwFileOffsetLow, length); + map = MapViewOfFile(fm, desiredAccess, dwFileOffsetHigh, dwFileOffsetLow, + length); - CloseHandle(fm); + CloseHandle(fm); - if (map == NULL) - { - errno = __map_mman_error(GetLastError(), EPERM); - throw std::runtime_error("MapViewOfFile failed"); - } + if (map == NULL) { + errno = __map_mman_error(GetLastError(), EPERM); + throw std::runtime_error("MapViewOfFile failed"); + } - result[0] = reinterpret_cast(map); - result[1] = reinterpret_cast(h); + result[0] = reinterpret_cast(map); + result[1] = reinterpret_cast(h); } -void* mmap(void *addr, size_t len, int prot, int flags, int files, OffsetType off) -{ - HANDLE fm, h; +void *mmap(void *addr, size_t len, int prot, int flags, int files, + OffsetType off) { + HANDLE fm, h; - void * map = MAP_FAILED; + void *map = MAP_FAILED; #ifdef _MSC_VER - #pragma warning(push) -#pragma warning(disable: 4293) +#pragma warning(push) +#pragma warning(disable : 4293) #endif - const DWORD dwFileOffsetLow = (sizeof(OffsetType) <= sizeof(DWORD)) ? - (DWORD)off : (DWORD)(off & 0xFFFFFFFFL); - const DWORD dwFileOffsetHigh = (sizeof(OffsetType) <= sizeof(DWORD)) ? - (DWORD)0 : (DWORD)((off >> 32) & 0xFFFFFFFFL); - const DWORD protect = __map_mmap_prot_page(prot); - const DWORD desiredAccess = __map_mmap_prot_file(prot); + const DWORD dwFileOffsetLow = (sizeof(OffsetType) <= sizeof(DWORD)) + ? (DWORD)off + : (DWORD)(off & 0xFFFFFFFFL); + const DWORD dwFileOffsetHigh = (sizeof(OffsetType) <= sizeof(DWORD)) + ? (DWORD)0 + : (DWORD)((off >> 32) & 0xFFFFFFFFL); + const DWORD protect = __map_mmap_prot_page(prot); + const DWORD desiredAccess = __map_mmap_prot_file(prot); - const OffsetType maxSize = off + (OffsetType)len; + const OffsetType maxSize = off + (OffsetType)len; - const DWORD dwMaxSizeLow = (sizeof(OffsetType) <= sizeof(DWORD)) ? - (DWORD)maxSize : (DWORD)(maxSize & 0xFFFFFFFFL); - const DWORD dwMaxSizeHigh = (sizeof(OffsetType) <= sizeof(DWORD)) ? - (DWORD)0 : (DWORD)((maxSize >> 32) & 0xFFFFFFFFL); + const DWORD dwMaxSizeLow = (sizeof(OffsetType) <= sizeof(DWORD)) + ? (DWORD)maxSize + : (DWORD)(maxSize & 0xFFFFFFFFL); + const DWORD dwMaxSizeHigh = (sizeof(OffsetType) <= sizeof(DWORD)) + ? (DWORD)0 + : (DWORD)((maxSize >> 32) & 0xFFFFFFFFL); #ifdef _MSC_VER #pragma warning(pop) #endif - errno = 0; + errno = 0; - if (len == 0 - /* Unsupported flag combinations */ - || (flags & MAP_FIXED) != 0 - /* Usupported protection combinations */ - || prot == PROT_EXEC) - { - errno = EINVAL; - return MAP_FAILED; - } + if (len == 0 + /* Unsupported flag combinations */ + || (flags & MAP_FIXED) != 0 + /* Usupported protection combinations */ + || prot == PROT_EXEC) { + errno = EINVAL; + return MAP_FAILED; + } - h = ((flags & MAP_ANONYMOUS) == 0) ? - (HANDLE)_get_osfhandle(files) : INVALID_HANDLE_VALUE; + h = ((flags & MAP_ANONYMOUS) == 0) ? (HANDLE)_get_osfhandle(files) + : INVALID_HANDLE_VALUE; + if ((flags & MAP_ANONYMOUS) == 0 && h == INVALID_HANDLE_VALUE) { + errno = EBADF; + return MAP_FAILED; + } - if ((flags & MAP_ANONYMOUS) == 0 && h == INVALID_HANDLE_VALUE) - { - errno = EBADF; - return MAP_FAILED; - } + fm = CreateFileMapping(h, NULL, protect, dwMaxSizeHigh, dwMaxSizeLow, NULL); - fm = CreateFileMapping(h, NULL, protect, dwMaxSizeHigh, dwMaxSizeLow, NULL); + if (fm == NULL) { + errno = __map_mman_error(GetLastError(), EPERM); + return MAP_FAILED; + } - if (fm == NULL) - { - errno = __map_mman_error(GetLastError(), EPERM); - return MAP_FAILED; - } + map = + MapViewOfFile(fm, desiredAccess, dwFileOffsetHigh, dwFileOffsetLow, len); - map = MapViewOfFile(fm, desiredAccess, dwFileOffsetHigh, dwFileOffsetLow, len); + CloseHandle(fm); - CloseHandle(fm); + if (map == NULL) { + errno = __map_mman_error(GetLastError(), EPERM); + return MAP_FAILED; + } - if (map == NULL) - { - errno = __map_mman_error(GetLastError(), EPERM); - return MAP_FAILED; - } - - return map; + return map; } -int munmap(void *addr, size_t len) -{ - if (UnmapViewOfFile(addr)) - return 0; +int munmap(void *addr, size_t len) { + if (UnmapViewOfFile(addr)) return 0; - errno = __map_mman_error(GetLastError(), EPERM); + errno = __map_mman_error(GetLastError(), EPERM); - return -1; + return -1; } -int _mprotect(void *addr, size_t len, int prot) -{ - DWORD newProtect = __map_mmap_prot_page(prot); - DWORD oldProtect = 0; +int _mprotect(void *addr, size_t len, int prot) { + DWORD newProtect = __map_mmap_prot_page(prot); + DWORD oldProtect = 0; - if (VirtualProtect(addr, len, newProtect, &oldProtect)) - return 0; + if (VirtualProtect(addr, len, newProtect, &oldProtect)) return 0; - errno = __map_mman_error(GetLastError(), EPERM); + errno = __map_mman_error(GetLastError(), EPERM); - return -1; + return -1; } int msync(void *addr, size_t len, int flags) { - if (FlushViewOfFile(addr, len)) - return 0; + if (FlushViewOfFile(addr, len)) return 0; - errno = __map_mman_error(GetLastError(), EPERM); + errno = __map_mman_error(GetLastError(), EPERM); - return -1; + return -1; } -int mlock(const void *addr, size_t len) -{ - if (VirtualLock((LPVOID)addr, len)) - return 0; +int mlock(const void *addr, size_t len) { + if (VirtualLock((LPVOID)addr, len)) return 0; - errno = __map_mman_error(GetLastError(), EPERM); + errno = __map_mman_error(GetLastError(), EPERM); - return -1; + return -1; } -int munlock(const void *addr, size_t len) -{ - if (VirtualUnlock((LPVOID)addr, len)) - return 0; +int munlock(const void *addr, size_t len) { + if (VirtualUnlock((LPVOID)addr, len)) return 0; - errno = __map_mman_error(GetLastError(), EPERM); + errno = __map_mman_error(GetLastError(), EPERM); - return -1; + return -1; } - -#endif //PROJECT_MMAN_H +#endif // PROJECT_MMAN_H diff --git a/libnd4j/include/helpers/shape.h b/libnd4j/include/helpers/shape.h index 719b086cb95d..9c2e9b26005b 100644 --- a/libnd4j/include/helpers/shape.h +++ b/libnd4j/include/helpers/shape.h @@ -24,20 +24,22 @@ #ifndef SHAPE_H_ #define SHAPE_H_ -#include +#include + #include +#include + +#include "../cnpy/cnpy.h" +#include "../helpers/logger.h" +#include "math/templatemath.h" #include "system/dll.h" #include "system/nd4jmalloc.h" -#include "math/templatemath.h" -#include "../helpers/logger.h" #include "system/pointercast.h" -#include "../cnpy/cnpy.h" -#include #define MAX_DIMENSION 0x7fffffff -#define MAX_NUM_THREADS 1024 +#define MAX_NUM_THREADS 1024 #define MAX_RANK 32 -#define MAX_SHAPEINFOLENGTH 2*MAX_RANK+4 +#define MAX_SHAPEINFOLENGTH 2 * MAX_RANK + 4 #define MAX_COORD 3 #define PREALLOC_SIZE 33554432 #ifdef __CUDACC__ @@ -45,16 +47,16 @@ #include #endif - #ifdef __CUDACC__ #define INLINEDEF inline #else #define INLINEDEF inline #endif -#include "system/pairwise_util.h" -#include #include +#include + +#include "system/pairwise_util.h" typedef unsigned int uint; @@ -64,103 +66,134 @@ namespace shape { * Shape information approximating * the information on an ndarray */ - struct ND4J_EXPORT ShapeInformation { - _CUDA_HD ShapeInformation(Nd4jLong* shape_ = nullptr, Nd4jLong *stride_ = nullptr, char order_ = 0, int rank_ = 0, int offset_ = 0, int elementWiseStride_ = 0) - : shape(shape_), stride(stride_), order(order_), rank(rank_), offset(offset_), elementWiseStride(elementWiseStride_) - {} - - Nd4jLong *shape; - Nd4jLong *stride; - char order; - int rank; - int offset; - int elementWiseStride; - }; +struct SD_EXPORT ShapeInformation { + _CUDA_HD ShapeInformation(Nd4jLong *shape_ = nullptr, + Nd4jLong *stride_ = nullptr, char order_ = 0, + int rank_ = 0, int offset_ = 0, + int elementWiseStride_ = 0) + : shape(shape_), + stride(stride_), + order(order_), + rank(rank_), + offset(offset_), + elementWiseStride(elementWiseStride_) {} + + Nd4jLong *shape; + Nd4jLong *stride; + char order; + int rank; + int offset; + int elementWiseStride; +}; /** * Indexing information * for bounds checking */ - struct ND4J_EXPORT CurrentIndexing { - int numElementsPerThread; - int blockStartingIndex; - int startingThreadIndex; - int endingThreadIndex; - - }; +struct SD_EXPORT CurrentIndexing { + int numElementsPerThread; + int blockStartingIndex; + int startingThreadIndex; + int endingThreadIndex; +}; +SD_EXPORT _CUDA_HD bool shapeEquals(const int shape1Rank, + const Nd4jLong *shape1, + const int shape2Rank, + const Nd4jLong *shape2); +SD_EXPORT _CUDA_HD const Nd4jLong *detachShape(const Nd4jLong *originalShape); - ND4J_EXPORT _CUDA_HD bool shapeEquals(const int shape1Rank, const Nd4jLong *shape1, const int shape2Rank, const Nd4jLong *shape2); +SD_EXPORT _CUDA_HD Nd4jLong *copyShape(Nd4jLong const *originalShape); - ND4J_EXPORT _CUDA_HD const Nd4jLong* detachShape(const Nd4jLong *originalShape); +SD_EXPORT _CUDA_HD bool shapeEquals(const Nd4jLong *shapeInfo1, + const Nd4jLong *shapeInfo2); - ND4J_EXPORT _CUDA_HD Nd4jLong* copyShape(Nd4jLong const* originalShape); +SD_EXPORT _CUDA_HD bool shapeEquals(const Nd4jLong *shapeInfo1, + const Nd4jLong *shapeInfo2, + const Nd4jLong *shapeInfo3); - ND4J_EXPORT _CUDA_HD bool shapeEquals(const Nd4jLong *shapeInfo1, const Nd4jLong *shapeInfo2); +SD_EXPORT _CUDA_HD bool strideEquals(int const shape1Rank, + Nd4jLong const *shape1, + int const shape2Rank, + Nd4jLong const *shape2); - ND4J_EXPORT _CUDA_HD bool shapeEquals(const Nd4jLong *shapeInfo1, const Nd4jLong *shapeInfo2, const Nd4jLong *shapeInfo3); +SD_EXPORT _CUDA_HD bool strideEquals(Nd4jLong const *shapeInfo1, + Nd4jLong const *shapeInfo2); - ND4J_EXPORT _CUDA_HD bool strideEquals(int const shape1Rank,Nd4jLong const* shape1,int const shape2Rank, Nd4jLong const* shape2); +SD_EXPORT _CUDA_HD bool strideEquals(Nd4jLong const *stride1, int const rank1, + Nd4jLong const *stride2, int const rank2); - ND4J_EXPORT _CUDA_HD bool strideEquals(Nd4jLong const* shapeInfo1, Nd4jLong const* shapeInfo2); +SD_EXPORT _CUDA_HD bool equalsSoft(const Nd4jLong *shapeA, + const Nd4jLong *shapeB); - ND4J_EXPORT _CUDA_HD bool strideEquals(Nd4jLong const* stride1,int const rank1, Nd4jLong const* stride2, int const rank2); +SD_EXPORT _CUDA_HD bool equalsTypesAndShapesSoft(const Nd4jLong *shapeA, + const Nd4jLong *shapeB); - ND4J_EXPORT _CUDA_HD bool equalsSoft(const Nd4jLong *shapeA, const Nd4jLong *shapeB); +SD_EXPORT _CUDA_HD bool equalsStrict(const Nd4jLong *shapeA, + const Nd4jLong *shapeB); - ND4J_EXPORT _CUDA_HD bool equalsTypesAndShapesSoft(const Nd4jLong *shapeA, const Nd4jLong *shapeB); +// returns true if ranks, shapes and strides are the same +SD_EXPORT _CUDA_HD bool haveSameShapeAndStrides(const Nd4jLong *shapeInfo1, + const Nd4jLong *shapeInfo2); +SD_EXPORT _CUDA_HD bool haveSameShapeAndStrides(const Nd4jLong *shapeInfo1, + const Nd4jLong *shapeInfo2, + const Nd4jLong *shapeInfo3); - ND4J_EXPORT _CUDA_HD bool equalsStrict(const Nd4jLong *shapeA, const Nd4jLong *shapeB); +SD_EXPORT _CUDA_HD int sizeAt(const Nd4jLong *shapeInfo, const int dim); +SD_EXPORT _CUDA_HD Nd4jLong strideAt(const Nd4jLong *shapeInfo, const int dim); - // returns true if ranks, shapes and strides are the same - ND4J_EXPORT _CUDA_HD bool haveSameShapeAndStrides(const Nd4jLong *shapeInfo1, const Nd4jLong *shapeInfo2); - ND4J_EXPORT _CUDA_HD bool haveSameShapeAndStrides(const Nd4jLong *shapeInfo1, const Nd4jLong *shapeInfo2, const Nd4jLong *shapeInfo3); - - ND4J_EXPORT _CUDA_HD int sizeAt(const Nd4jLong *shapeInfo, const int dim); - ND4J_EXPORT _CUDA_HD Nd4jLong strideAt(const Nd4jLong *shapeInfo, const int dim); - - template - ND4J_EXPORT _CUDA_HD void fill(T* buffer, T value, Nd4jLong length); - - ND4J_EXPORT _CUDA_HD void traceNew(int id); +template +SD_EXPORT _CUDA_HD void fill(T *buffer, T value, Nd4jLong length); +SD_EXPORT _CUDA_HD void traceNew(int id); - ND4J_EXPORT _CUDA_HD int tadIndexForLinear(int linearIndex, int tadLength); +SD_EXPORT _CUDA_HD int tadIndexForLinear(int linearIndex, int tadLength); - ND4J_EXPORT _CUDA_HD Nd4jLong tadLength(const Nd4jLong *shapeInfo, int *dimension, int dimensionLength); +SD_EXPORT _CUDA_HD Nd4jLong tadLength(const Nd4jLong *shapeInfo, int *dimension, + int dimensionLength); - ND4J_EXPORT _CUDA_HD bool canReshape(const int oldRank, Nd4jLong* oldShape, const int newRank, Nd4jLong* newShape, bool isFOrder); +SD_EXPORT _CUDA_HD bool canReshape(const int oldRank, Nd4jLong *oldShape, + const int newRank, Nd4jLong *newShape, + bool isFOrder); - ND4J_EXPORT _CUDA_HD bool reshapeC(const Nd4jLong* oldShapeInfo, const char newOrder, const int newRank, const Nd4jLong* newShape, Nd4jLong* newShapeInfo); - /** - * newShapeInfo contains rank, shape and order only, no strides/ews/type - */ - ND4J_EXPORT _CUDA_HD bool reshapeC(const Nd4jLong* oldShapeInfo, Nd4jLong* newShapeInfo); +SD_EXPORT _CUDA_HD bool reshapeC(const Nd4jLong *oldShapeInfo, + const char newOrder, const int newRank, + const Nd4jLong *newShape, + Nd4jLong *newShapeInfo); +/** + * newShapeInfo contains rank, shape and order only, no strides/ews/type + */ +SD_EXPORT _CUDA_HD bool reshapeC(const Nd4jLong *oldShapeInfo, + Nd4jLong *newShapeInfo); - /** - * Get the shape info buffer - * for the given rank and shape. - */ - ND4J_EXPORT _CUDA_HD Nd4jLong *shapeBuffer(int rank, sd::DataType dtype, Nd4jLong const* shape); +/** + * Get the shape info buffer + * for the given rank and shape. + */ +SD_EXPORT _CUDA_HD Nd4jLong *shapeBuffer(int rank, sd::DataType dtype, + Nd4jLong const *shape); - ND4J_EXPORT _CUDA_HD Nd4jLong *shapeBuffer(int rank, sd::DataType dtype, Nd4jLong const* shape, Nd4jLong *buffer); +SD_EXPORT _CUDA_HD Nd4jLong *shapeBuffer(int rank, sd::DataType dtype, + Nd4jLong const *shape, + Nd4jLong *buffer); - /** - * Get the shape info buffer - * for the given rank and shape. - */ - ND4J_EXPORT _CUDA_HD Nd4jLong *shapeBufferFortran(int rank, sd::DataType dtype, Nd4jLong const* shape); +/** + * Get the shape info buffer + * for the given rank and shape. + */ +SD_EXPORT _CUDA_HD Nd4jLong *shapeBufferFortran(int rank, sd::DataType dtype, + Nd4jLong const *shape); - ND4J_EXPORT _CUDA_HD Nd4jLong *shapeBufferFortran(int rank, sd::DataType dtype, Nd4jLong const* shape, Nd4jLong *output); +SD_EXPORT _CUDA_HD Nd4jLong *shapeBufferFortran(int rank, sd::DataType dtype, + Nd4jLong const *shape, + Nd4jLong *output); #ifdef __CUDACC__ - __device__ ND4J_EXPORT Nd4jLong *cuMalloc(Nd4jLong *buffer, long size); +__device__ SD_EXPORT Nd4jLong *cuMalloc(Nd4jLong *buffer, long size); #endif - - /** * Computes the standard packed array strides for a given shape. * @@ -168,9 +201,11 @@ namespace shape { * @param startNum the start number for the strides * @return the strides for a matrix of n dimensions */ - ND4J_EXPORT _CUDA_HD Nd4jLong * calcStridesFortran(Nd4jLong const* shape, int rank); +SD_EXPORT _CUDA_HD Nd4jLong *calcStridesFortran(Nd4jLong const *shape, + int rank); - ND4J_EXPORT _CUDA_HD Nd4jLong * calcStridesFortran(Nd4jLong const* shape, int rank, Nd4jLong* ret); +SD_EXPORT _CUDA_HD Nd4jLong *calcStridesFortran(Nd4jLong const *shape, int rank, + Nd4jLong *ret); /** * Computes the standard packed array strides for a given shape. @@ -180,17 +215,19 @@ namespace shape { * @return the strides for a matrix of n dimensions */ - ND4J_EXPORT _CUDA_HD Nd4jLong* calcStrides(Nd4jLong const *shape, int rank); - - ND4J_EXPORT _CUDA_HD Nd4jLong* calcStrides(Nd4jLong const *shape, int rank, Nd4jLong* ret); +SD_EXPORT _CUDA_HD Nd4jLong *calcStrides(Nd4jLong const *shape, int rank); - ND4J_EXPORT _CUDA_HD void updateStrides(Nd4jLong *shape, const char order); - ND4J_EXPORT _CUDA_HD void updateStrides(const int rank, const Nd4jLong *shapeOnly, Nd4jLong *stridesOnly, const char order); +SD_EXPORT _CUDA_HD Nd4jLong *calcStrides(Nd4jLong const *shape, int rank, + Nd4jLong *ret); +SD_EXPORT _CUDA_HD void updateStrides(Nd4jLong *shape, const char order); +SD_EXPORT _CUDA_HD void updateStrides(const int rank, const Nd4jLong *shapeOnly, + Nd4jLong *stridesOnly, const char order); -// check whether input dimensions are permuted, not permuted dimensions order have to be 0,....,rank-1 - template - ND4J_EXPORT _CUDA_HD bool isDimPermuted(const T* dimensions, const int dimSize); +// check whether input dimensions are permuted, not permuted dimensions order +// have to be 0,....,rank-1 +template +SD_EXPORT _CUDA_HD bool isDimPermuted(const T *dimensions, const int dimSize); /** * Computes the standard packed array strides for a given shape. @@ -199,9 +236,11 @@ namespace shape { * @param startNum the start number for the strides * @return the strides for a matrix of n dimensions */ - ND4J_EXPORT _CUDA_HD Nd4jLong* calcStridesFortran(Nd4jLong const *shape, int rank, int startNum); +SD_EXPORT _CUDA_HD Nd4jLong *calcStridesFortran(Nd4jLong const *shape, int rank, + int startNum); - ND4J_EXPORT _CUDA_HD Nd4jLong* calcStridesFortran(Nd4jLong const *shape, int rank, int startNum, Nd4jLong* ret); +SD_EXPORT _CUDA_HD Nd4jLong *calcStridesFortran(Nd4jLong const *shape, int rank, + int startNum, Nd4jLong *ret); /** * Computes the standard packed array strides for a given shape. @@ -210,28 +249,28 @@ namespace shape { * @param startNum the start number for the strides * @return the strides for a matrix of n dimensions */ - ND4J_EXPORT _CUDA_HD Nd4jLong* calcStrides(Nd4jLong const* shape, int rank, int startNum); +SD_EXPORT _CUDA_HD Nd4jLong *calcStrides(Nd4jLong const *shape, int rank, + int startNum); - ND4J_EXPORT _CUDA_HD Nd4jLong* calcStrides(Nd4jLong const *shape, int rank, int startNum, Nd4jLong* ret); +SD_EXPORT _CUDA_HD Nd4jLong *calcStrides(Nd4jLong const *shape, int rank, + int startNum, Nd4jLong *ret); /** * @param toCopy the shape to copy * @return a copy of the original struct */ - ND4J_EXPORT _CUDA_HD ShapeInformation *shapeCopy( ShapeInformation *toCopy); - +SD_EXPORT _CUDA_HD ShapeInformation *shapeCopy(ShapeInformation *toCopy); - ND4J_EXPORT _CUDA_HD bool strideDescendingCAscendingF(const Nd4jLong *shapeBuffer); - - ND4J_EXPORT _CUDA_HD bool isContiguous(const Nd4jLong* shapeInfo); +SD_EXPORT _CUDA_HD bool strideDescendingCAscendingF( + const Nd4jLong *shapeBuffer); +SD_EXPORT _CUDA_HD bool isContiguous(const Nd4jLong *shapeInfo); /** * copy-past from java hasDefaultStridesForShape function * check whether array is not permuted and has contiguous elements in memory */ - ND4J_EXPORT _CUDA_HD bool areStridesDefault(const Nd4jLong* shapeInfo); - +SD_EXPORT _CUDA_HD bool areStridesDefault(const Nd4jLong *shapeInfo); /** * Compute the element wise stride @@ -244,7 +283,9 @@ namespace shape { * @return 0 if there is no element wise stride the * element wise stride of reshape(1,length) otherwise */ - ND4J_EXPORT _CUDA_HD int computeElementWiseStride(int rank, Nd4jLong const* shape, Nd4jLong const* stride, int isFOrder); +SD_EXPORT _CUDA_HD int computeElementWiseStride(int rank, Nd4jLong const *shape, + Nd4jLong const *stride, + int isFOrder); /** * Compute the element wise stride @@ -257,11 +298,19 @@ namespace shape { * @return 0 if there is no element wise stride the * element wise stride of reshape(1,length) otherwise */ - ND4J_EXPORT _CUDA_HD int computeElementWiseStride(int rank, Nd4jLong const* shape, Nd4jLong const* stride, int isFOrder, Nd4jLong const* dimension, int dimensionLength); - - ND4J_EXPORT _CUDA_HD Nd4jLong *shapeInfoOnlyShapeAndStride(Nd4jLong const* shapeInfo, Nd4jLong *dimension, int dimensionLength,bool reverseCopyStride); - - ND4J_EXPORT _CUDA_HD Nd4jLong *shapeInfoOnlyShapeAndStride(const Nd4jLong *shapeInfo, Nd4jLong *dimension, int dimensionLength,bool reverseCopyStride, Nd4jLong *buffer); +SD_EXPORT _CUDA_HD int computeElementWiseStride(int rank, Nd4jLong const *shape, + Nd4jLong const *stride, + int isFOrder, + Nd4jLong const *dimension, + int dimensionLength); + +SD_EXPORT _CUDA_HD Nd4jLong *shapeInfoOnlyShapeAndStride( + Nd4jLong const *shapeInfo, Nd4jLong *dimension, int dimensionLength, + bool reverseCopyStride); + +SD_EXPORT _CUDA_HD Nd4jLong *shapeInfoOnlyShapeAndStride( + const Nd4jLong *shapeInfo, Nd4jLong *dimension, int dimensionLength, + bool reverseCopyStride, Nd4jLong *buffer); /** * * @param length @@ -269,9 +318,8 @@ namespace shape { * @param rearrange * @return */ - ND4J_EXPORT _CUDA_HD Nd4jLong *doPermuteSwap(int length, Nd4jLong *shape, int* rearrange); - - +SD_EXPORT _CUDA_HD Nd4jLong *doPermuteSwap(int length, Nd4jLong *shape, + int *rearrange); /** * In place permute swap @@ -279,40 +327,48 @@ namespace shape { * @param shape * @param rearrange */ - ND4J_EXPORT _CUDA_HD void doPermuteSwap(int length, Nd4jLong **shape, int* rearrange); +SD_EXPORT _CUDA_HD void doPermuteSwap(int length, Nd4jLong **shape, + int *rearrange); - ND4J_EXPORT _CUDA_HD Nd4jLong *permuteShapeBuffer(Nd4jLong const* shapeBuffer, int* rearrange); +SD_EXPORT _CUDA_HD Nd4jLong *permuteShapeBuffer(Nd4jLong const *shapeBuffer, + int *rearrange); - ND4J_EXPORT _CUDA_HD void permuteShapeBufferInPlace(Nd4jLong *shapeBuffer, int* rearrange, Nd4jLong *out); +SD_EXPORT _CUDA_HD void permuteShapeBufferInPlace(Nd4jLong *shapeBuffer, + int *rearrange, + Nd4jLong *out); - ND4J_EXPORT _CUDA_HD void doPermuteShapeInfo(Nd4jLong *shapeBuffer, const int *rearrange, Nd4jLong len = -1); - - /** - * Rearrange the permute indexes - * according to which dimensions are specified. - * - * For example, dimension is implicitly: - * 0,1,2 - * - * If you want to do a reduce along dimensions 0 and 1, - * you need to permute the indexes to be: - * 2,0,1 - * - * which will give us the ability to ierate along an element - * wise stride. - */ +SD_EXPORT _CUDA_HD void doPermuteShapeInfo(Nd4jLong *shapeBuffer, + const int *rearrange, + Nd4jLong len = -1); - ND4J_EXPORT _CUDA_HD Nd4jLong* createPermuteIndexes(int originalRank, int *dimension,int dimensionLength); +/** + * Rearrange the permute indexes + * according to which dimensions are specified. + * + * For example, dimension is implicitly: + * 0,1,2 + * + * If you want to do a reduce along dimensions 0 and 1, + * you need to permute the indexes to be: + * 2,0,1 + * + * which will give us the ability to ierate along an element + * wise stride. + */ - ND4J_EXPORT _CUDA_HD Nd4jLong* computeResultShape(const Nd4jLong *originalShapeBuffer, int *dimension,int dimensionLength); +SD_EXPORT _CUDA_HD Nd4jLong *createPermuteIndexes(int originalRank, + int *dimension, + int dimensionLength); - /** - * This method does inplace transpose of given shapeBuffer - * - * @param shapeBuffer - */ - ND4J_EXPORT _CUDA_HD void transposeInplace(Nd4jLong *shapeBuffer); +SD_EXPORT _CUDA_HD Nd4jLong *computeResultShape( + const Nd4jLong *originalShapeBuffer, int *dimension, int dimensionLength); +/** + * This method does inplace transpose of given shapeBuffer + * + * @param shapeBuffer + */ +SD_EXPORT _CUDA_HD void transposeInplace(Nd4jLong *shapeBuffer); /** * Get the ordering for the device @@ -322,7 +378,8 @@ namespace shape { * @param elementStride * @return */ - ND4J_EXPORT _CUDA_HD char getOrder(int length, Nd4jLong *shape, Nd4jLong *stride, int elementStride); +SD_EXPORT _CUDA_HD char getOrder(int length, Nd4jLong *shape, Nd4jLong *stride, + int elementStride); /** * Ensure that every value in the re arrange @@ -333,8 +390,9 @@ namespace shape { * @param shapeLength * @return */ - template - ND4J_EXPORT _CUDA_HD int checkArrangeArray(T *arr, int arrLength, int shapeLength); +template +SD_EXPORT _CUDA_HD int checkArrangeArray(T *arr, int arrLength, + int shapeLength); /** * Permute the shape information @@ -342,7 +400,8 @@ namespace shape { * @param rearrange the order to re arrange * @param rank the rank of the rearrange array */ - ND4J_EXPORT _CUDA_HD void permute(ShapeInformation **info, int *rearrange, int rank); +SD_EXPORT _CUDA_HD void permute(ShapeInformation **info, int *rearrange, + int rank); /** * Returns whether the @@ -350,49 +409,51 @@ namespace shape { * @param shape the shape of the array * @param rank the rank of cthe shape */ - ND4J_EXPORT _CUDA_HD int isVector(Nd4jLong const* shape, int rank); +SD_EXPORT _CUDA_HD int isVector(Nd4jLong const *shape, int rank); +/** + * When 1 dimension is the whole length of the + * array + */ +SD_EXPORT _CUDA_HD int oneDimEqualToLength(Nd4jLong *shape, int rank); - /** - * When 1 dimension is the whole length of the - * array - */ - ND4J_EXPORT _CUDA_HD int oneDimEqualToLength(Nd4jLong *shape, int rank); - - ND4J_EXPORT _CUDA_HD int oneDimEqualToLength(Nd4jLong *shapeInfo); +SD_EXPORT _CUDA_HD int oneDimEqualToLength(Nd4jLong *shapeInfo); - ND4J_EXPORT _CUDA_HD int isVector(const Nd4jLong *shapeInfo); +SD_EXPORT _CUDA_HD int isVector(const Nd4jLong *shapeInfo); - ND4J_EXPORT _CUDA_HD bool isLikeVector(Nd4jLong const* shapeInfo, int& posOfNonUnityDim); +SD_EXPORT _CUDA_HD bool isLikeVector(Nd4jLong const *shapeInfo, + int &posOfNonUnityDim); - ND4J_EXPORT _CUDA_HD bool isCommonVector(const Nd4jLong *shapeInfo, int& posOfNonUnityDim); +SD_EXPORT _CUDA_HD bool isCommonVector(const Nd4jLong *shapeInfo, + int &posOfNonUnityDim); - ND4J_EXPORT _CUDA_HD bool isRowVector(const Nd4jLong *shapeInfo); +SD_EXPORT _CUDA_HD bool isRowVector(const Nd4jLong *shapeInfo); - ND4J_EXPORT _CUDA_HD bool isColumnVector(Nd4jLong const* shapeInfo); +SD_EXPORT _CUDA_HD bool isColumnVector(Nd4jLong const *shapeInfo); - /** - * shape - input inShape is shape only, not shapeInfo - * returns number of non-unity dimensions in inShape - */ - ND4J_EXPORT _CUDA_HD int numOfNonUnitDims(const int rank, const Nd4jLong* inShape); +/** + * shape - input inShape is shape only, not shapeInfo + * returns number of non-unity dimensions in inShape + */ +SD_EXPORT _CUDA_HD int numOfNonUnitDims(const int rank, + const Nd4jLong *inShape); - /** +/** * Returns whether the * given shape is a vector or not * @param shape the shape of the array * @param rank the rank of the shape */ - ND4J_EXPORT _CUDA_HD int isMatrix(const Nd4jLong *shape, int rank); +SD_EXPORT _CUDA_HD int isMatrix(const Nd4jLong *shape, int rank); - INLINEDEF _CUDA_HD int isMatrix(const Nd4jLong *shapeInfo); +INLINEDEF _CUDA_HD int isMatrix(const Nd4jLong *shapeInfo); /** * Returns the shape portion of an information * buffer */ - ND4J_EXPORT _CUDA_HD Nd4jLong *shapeOf(Nd4jLong *shapeInfo); - ND4J_EXPORT _CUDA_HD Nd4jLong *shapeOf(const Nd4jLong *shapeInfo); +SD_EXPORT _CUDA_HD Nd4jLong *shapeOf(Nd4jLong *shapeInfo); +SD_EXPORT _CUDA_HD Nd4jLong *shapeOf(const Nd4jLong *shapeInfo); /** * Return a copy of a buffer. @@ -400,26 +461,27 @@ namespace shape { * that must be freed elsewhere. */ - template - ND4J_EXPORT _CUDA_HD T* copyOf(Nd4jLong length, T const* toCopy); +template +SD_EXPORT _CUDA_HD T *copyOf(Nd4jLong length, T const *toCopy); - template - ND4J_EXPORT _CUDA_HD T* copyOf(Nd4jLong length, T const* toCopy, T *ret); +template +SD_EXPORT _CUDA_HD T *copyOf(Nd4jLong length, T const *toCopy, T *ret); - /** +/** * Return a copy of a buffer. * This buffer allocates memory * that must be freed elsewhere. */ - template - ND4J_EXPORT _CUDA_HD void copyTo(Nd4jLong length, T const* from, T *to); - /** -* Return a copy of a buffer. -* This buffer allocates memory -* that must be freed elsewhere. -*/ - ND4J_EXPORT _CUDA_HD void copyTo(int length, Nd4jLong const* from, Nd4jLong *to, Nd4jLong *indexes); +template +SD_EXPORT _CUDA_HD void copyTo(Nd4jLong length, T const *from, T *to); +/** + * Return a copy of a buffer. + * This buffer allocates memory + * that must be freed elsewhere. + */ +SD_EXPORT _CUDA_HD void copyTo(int length, Nd4jLong const *from, Nd4jLong *to, + Nd4jLong *indexes); /** * Permute the given strides @@ -430,18 +492,20 @@ namespace shape { * and all must be filled in) * @return the rearranged array */ - //ND4J_EXPORT _CUDA_HD Nd4jLong *permutedStrides(Nd4jLong *toPermute, int shapeRank, Nd4jLong *rearrange); +// SD_EXPORT _CUDA_HD Nd4jLong *permutedStrides(Nd4jLong *toPermute, int +// shapeRank, Nd4jLong *rearrange); /** * Return the slice (shape + 1 in pointer arithmetic) * @param shape the shape to take the slice of * @return the shape array - the first entry */ - ND4J_EXPORT _CUDA_HD Nd4jLong *slice(Nd4jLong *shape); +SD_EXPORT _CUDA_HD Nd4jLong *slice(Nd4jLong *shape); - ND4J_EXPORT _CUDA_HD int slices(Nd4jLong *shapeBuffer); +SD_EXPORT _CUDA_HD int slices(Nd4jLong *shapeBuffer); - ND4J_EXPORT _CUDA_HD Nd4jLong *sliceOfShapeBuffer(Nd4jLong sliceIdx, Nd4jLong *shapeBuffer); +SD_EXPORT _CUDA_HD Nd4jLong *sliceOfShapeBuffer(Nd4jLong sliceIdx, + Nd4jLong *shapeBuffer); /** * Returns the length of the * shape information buffer: @@ -450,30 +514,30 @@ namespace shape { * info length for * @return rank * 2 + 4 */ - ND4J_EXPORT _CUDA_HD int shapeInfoLength(int rank); +SD_EXPORT _CUDA_HD int shapeInfoLength(int rank); - ND4J_EXPORT _CUDA_HD int shapeInfoLength(Nd4jLong* shapeInfo); +SD_EXPORT _CUDA_HD int shapeInfoLength(Nd4jLong *shapeInfo); - ND4J_EXPORT _CUDA_HD int shapeInfoLength(const Nd4jLong* shapeInfo); +SD_EXPORT _CUDA_HD int shapeInfoLength(const Nd4jLong *shapeInfo); - ND4J_EXPORT _CUDA_HD size_t shapeInfoByteLength(int rank); +SD_EXPORT _CUDA_HD size_t shapeInfoByteLength(int rank); - ND4J_EXPORT _CUDA_HD size_t shapeInfoByteLength(const Nd4jLong* shapeInfo); +SD_EXPORT _CUDA_HD size_t shapeInfoByteLength(const Nd4jLong *shapeInfo); - ND4J_EXPORT _CUDA_HD size_t shapeInfoByteLength(const Nd4jLong* shapeInfo); +SD_EXPORT _CUDA_HD size_t shapeInfoByteLength(const Nd4jLong *shapeInfo); /** * Returns the rank portion of * an information buffer */ - ND4J_EXPORT _CUDA_HD int rank(const Nd4jLong *shapeInfo); - ND4J_EXPORT _CUDA_HD int rank(const int *shapeInfo); - ND4J_EXPORT _CUDA_HD int rank(const unsigned int *shapeInfo); +SD_EXPORT _CUDA_HD int rank(const Nd4jLong *shapeInfo); +SD_EXPORT _CUDA_HD int rank(const int *shapeInfo); +SD_EXPORT _CUDA_HD int rank(const unsigned int *shapeInfo); - /** - * returns pointer on elementWiseStride - */ - ND4J_EXPORT _CUDA_HD Nd4jLong* ews(Nd4jLong* shapeInfo); +/** + * returns pointer on elementWiseStride + */ +SD_EXPORT _CUDA_HD Nd4jLong *ews(Nd4jLong *shapeInfo); /** * Converts a raw int buffer of the layout: @@ -485,65 +549,65 @@ namespace shape { * * where shape and stride are both straight int pointers */ - ND4J_EXPORT _CUDA_HD ShapeInformation *infoFromBuffer(Nd4jLong *buffer); +SD_EXPORT _CUDA_HD ShapeInformation *infoFromBuffer(Nd4jLong *buffer); /** * Returns the stride portion of an information * buffer */ - ND4J_EXPORT _CUDA_HD Nd4jLong *stride(Nd4jLong *buffer); +SD_EXPORT _CUDA_HD Nd4jLong *stride(Nd4jLong *buffer); - ND4J_EXPORT _CUDA_HD Nd4jLong *stride(const Nd4jLong *buffer); +SD_EXPORT _CUDA_HD Nd4jLong *stride(const Nd4jLong *buffer); /** * Compute the length of the given shape */ - ND4J_EXPORT _CUDA_HD bool isEmpty(const Nd4jLong *shapeInfo); +SD_EXPORT _CUDA_HD bool isEmpty(const Nd4jLong *shapeInfo); - ND4J_EXPORT _CUDA_HD Nd4jLong length(const Nd4jLong *shapeInfo); +SD_EXPORT _CUDA_HD Nd4jLong length(const Nd4jLong *shapeInfo); - ND4J_EXPORT _CUDA_HD Nd4jLong length(std::initializer_list& shape); +SD_EXPORT _CUDA_HD Nd4jLong length(std::initializer_list &shape); - ND4J_EXPORT _CUDA_HD Nd4jLong length(std::initializer_list& shape); +SD_EXPORT _CUDA_HD Nd4jLong length(std::initializer_list &shape); /*** * Returns the offset portion of an information buffer */ - ND4J_EXPORT _CUDA_HD Nd4jLong offset(Nd4jLong *buffer); +SD_EXPORT _CUDA_HD Nd4jLong offset(Nd4jLong *buffer); - ND4J_EXPORT _CUDA_HD Nd4jLong& extra(Nd4jLong *buffer); +SD_EXPORT _CUDA_HD Nd4jLong &extra(Nd4jLong *buffer); /** * Returns the ordering * for this shape information buffer */ - ND4J_EXPORT _CUDA_HD char order(const Nd4jLong *buffer); +SD_EXPORT _CUDA_HD char order(const Nd4jLong *buffer); /** * Returns the type */ - ND4J_EXPORT _CUDA_HD Nd4jLong type(const Nd4jLong* shapeInfo); +SD_EXPORT _CUDA_HD Nd4jLong type(const Nd4jLong *shapeInfo); /** * Returns the element wise stride for this information * buffer */ - ND4J_EXPORT _CUDA_HD Nd4jLong elementWiseStride(const Nd4jLong *shapeInfo); +SD_EXPORT _CUDA_HD Nd4jLong elementWiseStride(const Nd4jLong *shapeInfo); - - /** +/** * Returns the element wise stride for this information * buffer - * relative to a dimension and ordering for a reduction index + * relative to a dimension and ordering for a reduction index */ - ND4J_EXPORT _CUDA_HD Nd4jLong reductionIndexElementWiseStride(Nd4jLong *buffer, int *dimension, int dimensionLength); +SD_EXPORT _CUDA_HD Nd4jLong reductionIndexElementWiseStride( + Nd4jLong *buffer, int *dimension, int dimensionLength); /** * Returns whether * the given shape info buffer * represents a scalar shape */ - ND4J_EXPORT _CUDA_HD int isScalar(const Nd4jLong *info); +SD_EXPORT _CUDA_HD int isScalar(const Nd4jLong *info); /** * Returns whether @@ -551,7 +615,7 @@ namespace shape { * represents a scalar * shape or not */ - ND4J_EXPORT _CUDA_HD int isScalar(volatile ShapeInformation *info); +SD_EXPORT _CUDA_HD int isScalar(volatile ShapeInformation *info); /** * Return a copy of this array with the @@ -565,10 +629,12 @@ namespace shape { * * item */ - template - ND4J_EXPORT _CUDA_HD void removeIndex(T1 const* data, T2 const* indexes, Nd4jLong dataLength, Nd4jLong indexesLength, T1 *out); +template +SD_EXPORT _CUDA_HD void removeIndex(T1 const *data, T2 const *indexes, + Nd4jLong dataLength, Nd4jLong indexesLength, + T1 *out); - /** +/** * Return a copy of this array with the * given index omitted * @@ -581,21 +647,24 @@ namespace shape { * item */ - template - ND4J_EXPORT _CUDA_HD T1* removeIndex(T1 const* data, T2 const* indexes, Nd4jLong dataLength, Nd4jLong indexesLength); +template +SD_EXPORT _CUDA_HD T1 *removeIndex(T1 const *data, T2 const *indexes, + Nd4jLong dataLength, Nd4jLong indexesLength); - /** - * Iterate over a given set of indexes - * the begin and end indexes are 0 based. - * 1 padding is automatically assumed for the ending. - * - * For example if you want to iterate over 0 to 4 - * it will go to 4 rather than 3. - * - * indexes should be the indexes to exclude - * indexes length should be the length of indexes - */ - ND4J_EXPORT _CUDA_HD Nd4jLong* everyIndexBut(Nd4jLong const* indexes,int indexesLength,int begin,int end); +/** + * Iterate over a given set of indexes + * the begin and end indexes are 0 based. + * 1 padding is automatically assumed for the ending. + * + * For example if you want to iterate over 0 to 4 + * it will go to 4 rather than 3. + * + * indexes should be the indexes to exclude + * indexes length should be the length of indexes + */ +SD_EXPORT _CUDA_HD Nd4jLong *everyIndexBut(Nd4jLong const *indexes, + int indexesLength, int begin, + int end); /** * Computes the offset for accessing @@ -605,7 +674,7 @@ namespace shape { //#ifdef __CUDACC__ // __device__ //#endif -// ND4J_EXPORT int tadOffset(shape::ShapeInformation *xInfo, int offset); +// SD_EXPORT int tadOffset(shape::ShapeInformation *xInfo, int offset); /** * Returns a shape @@ -615,11 +684,11 @@ namespace shape { * for the shape to be returned as * @return the new shape */ - ND4J_EXPORT _CUDA_HD Nd4jLong* ensureVectorShape(Nd4jLong *shape); +SD_EXPORT _CUDA_HD Nd4jLong *ensureVectorShape(Nd4jLong *shape); - ND4J_EXPORT _CUDA_HD Nd4jLong* createScalarShapeInfo(); +SD_EXPORT _CUDA_HD Nd4jLong *createScalarShapeInfo(); - ND4J_EXPORT _CUDA_HD Nd4jLong* createScalarShapeInfo(Nd4jLong *ret); +SD_EXPORT _CUDA_HD Nd4jLong *createScalarShapeInfo(Nd4jLong *ret); /** * Generate an int buffer @@ -627,21 +696,22 @@ namespace shape { * at the specified increment * */ - template - ND4J_EXPORT _CUDA_HD T* range(int from, int to, int increment); +template +SD_EXPORT _CUDA_HD T *range(int from, int to, int increment); /** * Range between from and two with an * increment of 1 */ - template - ND4J_EXPORT _CUDA_HD T* range(int from, int to); +template +SD_EXPORT _CUDA_HD T *range(int from, int to); /** * Keep the given indexes * in the data */ - ND4J_EXPORT _CUDA_HD Nd4jLong *keep(volatile Nd4jLong *data, int const* index, int indexLength, int dataLength); +SD_EXPORT _CUDA_HD Nd4jLong *keep(volatile Nd4jLong *data, int const *index, + int indexLength, int dataLength); /** * Generate reverse copy of the data @@ -650,17 +720,18 @@ namespace shape { * @return */ - template - ND4J_EXPORT _CUDA_HD T* reverseCopy(T const* data, Nd4jLong length); +template +SD_EXPORT _CUDA_HD T *reverseCopy(T const *data, Nd4jLong length); - template - ND4J_EXPORT _CUDA_HD void reverseCopyTo(T const* from, T *to, Nd4jLong length); +template +SD_EXPORT _CUDA_HD void reverseCopyTo(T const *from, T *to, Nd4jLong length); - template - ND4J_EXPORT _CUDA_HD void reverseCopyTo(T const* from, T *to, Nd4jLong *indexes, Nd4jLong length); +template +SD_EXPORT _CUDA_HD void reverseCopyTo(T const *from, T *to, Nd4jLong *indexes, + Nd4jLong length); - template - ND4J_EXPORT _CUDA_H void convertT(T1 *from, T2 *to, Nd4jLong length); +template +SD_EXPORT _CUDA_H void convertT(T1 *from, T2 *to, Nd4jLong length); /** * * @param arr1 @@ -669,8 +740,9 @@ namespace shape { * @param arr2Length * @return */ - template - ND4J_EXPORT _CUDA_HD T* concat(T const* arr1, Nd4jLong const arr1Length, T const* arr2, Nd4jLong const arr2Length); +template +SD_EXPORT _CUDA_HD T *concat(T const *arr1, Nd4jLong const arr1Length, + T const *arr2, Nd4jLong const arr2Length); /** * @@ -680,8 +752,9 @@ namespace shape { * @param lengths * @return */ - template - ND4J_EXPORT _CUDA_HD T* concat(int const numArrays, int const numTotalElements, Nd4jLong const**arr, Nd4jLong const* lengths); +template +SD_EXPORT _CUDA_HD T *concat(int const numArrays, int const numTotalElements, + Nd4jLong const **arr, Nd4jLong const *lengths); /** * Get the length per slice of the @@ -695,7 +768,9 @@ namespace shape { * @return the length per slice of the given shape * along the given dimension */ - ND4J_EXPORT _CUDA_HD Nd4jLong lengthPerSlice(int rank, Nd4jLong const* shape, int const* dimension, int dimensionLength); +SD_EXPORT _CUDA_HD Nd4jLong lengthPerSlice(int rank, Nd4jLong const *shape, + int const *dimension, + int dimensionLength); /** * calculates the offset for a tensor @@ -704,13 +779,9 @@ namespace shape { * @param tensorShape * @return */ - ND4J_EXPORT _CUDA_HD Nd4jLong sliceOffsetForTensor(int rank, - int index, - Nd4jLong const* shape, - Nd4jLong const* tensorShape, - int tensorShapeLength, - int const *dimension, - int dimensionLength); +SD_EXPORT _CUDA_HD Nd4jLong sliceOffsetForTensor( + int rank, int index, Nd4jLong const *shape, Nd4jLong const *tensorShape, + int tensorShapeLength, int const *dimension, int dimensionLength); /** * calculates the offset for a tensor @@ -719,41 +790,41 @@ namespace shape { * @param tensorShape * @return */ - ND4J_EXPORT _CUDA_HD Nd4jLong sliceOffsetForTensor(int index,int tensorLength,int lengthPerSlice2); +SD_EXPORT _CUDA_HD Nd4jLong sliceOffsetForTensor(int index, int tensorLength, + int lengthPerSlice2); /** * Computes the tensor along dimension * offset * @param index the index to get the offset for the tad for * @param rank the rank of the shapes and strides * @param info the shape information to use for tad - * @param dimension the dimensions to use for computing the tensor along dimensions + * @param dimension the dimensions to use for computing the tensor along + * dimensions */ -// ND4J_EXPORT _CUDA_HD int offset(int index, +// SD_EXPORT _CUDA_HD int offset(int index, // int rank, // shape::ShapeInformation *info, // Nd4jLong *dimension, // int dimensionLength); - /** * Computes the number * of tensors along * a given dimension */ - ND4J_EXPORT _CUDA_HD Nd4jLong tensorsAlongDimension(int rank, - volatile int length, - volatile Nd4jLong *shape, - int *dimension, - int dimensionLength); +SD_EXPORT _CUDA_HD Nd4jLong tensorsAlongDimension(int rank, volatile int length, + volatile Nd4jLong *shape, + int *dimension, + int dimensionLength); /** * Computes the number * of tensors along * a given dimension */ - ND4J_EXPORT _CUDA_HD Nd4jLong tensorsAlongDimension(Nd4jLong *shapeInfo, int *dimension, int dimensionLength); - - +SD_EXPORT _CUDA_HD Nd4jLong tensorsAlongDimension(Nd4jLong *shapeInfo, + int *dimension, + int dimensionLength); /** * Returns the tensor along dimension @@ -763,24 +834,26 @@ namespace shape { * @param i * @return */ - ND4J_EXPORT _CUDA_HD int tadForBlockIndex(int blockSize, int blockIdx, int i); +SD_EXPORT _CUDA_HD int tadForBlockIndex(int blockSize, int blockIdx, int i); /** * Computes the number of tads per block * */ - ND4J_EXPORT _CUDA_HD int tadsPerBlock(int blockSize, int tads); +SD_EXPORT _CUDA_HD int tadsPerBlock(int blockSize, int tads); -// ND4J_EXPORT _CUDA_HD Nd4jLong *tadShapeInfo(int index, Nd4jLong *xShapeInfo, Nd4jLong *dimension, +// SD_EXPORT _CUDA_HD Nd4jLong *tadShapeInfo(int index, Nd4jLong *xShapeInfo, +// Nd4jLong *dimension, // int dimensionLength); /** * Returns a shape buffer * for the shape information metadata. */ - ND4J_EXPORT _CUDA_HD Nd4jLong *toShapeBuffer( ShapeInformation *info); +SD_EXPORT _CUDA_HD Nd4jLong *toShapeBuffer(ShapeInformation *info); - ND4J_EXPORT _CUDA_HD Nd4jLong *toShapeBuffer( ShapeInformation *info, Nd4jLong* ret); +SD_EXPORT _CUDA_HD Nd4jLong *toShapeBuffer(ShapeInformation *info, + Nd4jLong *ret); /** * Returns the number of elements per thread @@ -831,25 +904,29 @@ namespace shape { * @param numElementsPerTad the number of elements * per tad */ - ND4J_EXPORT _CUDA_HD int tadIndex(int i, int elementWiseStride, int numElementsPerTad); +SD_EXPORT _CUDA_HD int tadIndex(int i, int elementWiseStride, + int numElementsPerTad); /** * Map a tad to a * reduction index. * @param tadIndexForOriginal the original tad index for the * split up problem (eg: split is dimension 3 mapping to a 2,3 problem) - * @param tadsForReduced the number of tads for the shrunk down problem (eg: 2,3) + * @param tadsForReduced the number of tads for the shrunk down problem (eg: + * 2,3) * @param tadsForOriginal the number of tads for the smaller problem (eg: 3) */ - ND4J_EXPORT _CUDA_HD int reductionIndexForTad(int tadIndexForOriginal, int tadsForReduced, - int tadsForOriginal); +SD_EXPORT _CUDA_HD int reductionIndexForTad(int tadIndexForOriginal, + int tadsForReduced, + int tadsForOriginal); /** * Computes the number of tads * per reduce index for the * reduction tad. */ - ND4J_EXPORT _CUDA_HD int tadsPerReduceIndex(int tadsForReduce, int tadsForOriginal); +SD_EXPORT _CUDA_HD int tadsPerReduceIndex(int tadsForReduce, + int tadsForOriginal); /** * Maps a linear index to a reduction index @@ -859,259 +936,357 @@ namespace shape { * @param tadNum the number of tads for the shrunken problem * @param originalTadNum the tad number for the reduced version of the problem */ - ND4J_EXPORT _CUDA_HD int reductionIndexForLinear(int i, int elementWiseStride, int numElementsPerTad, - int tadNum, int originalTadNum); +SD_EXPORT _CUDA_HD int reductionIndexForLinear(int i, int elementWiseStride, + int numElementsPerTad, + int tadNum, int originalTadNum); /** * Returns the prod of the data * up to the given length */ - ND4J_EXPORT _CUDA_HD Nd4jLong prodLong(const Nd4jLong *data, int length); - - /** - * Returns the rear most left over item not present in - * the dimension array. This assumes that the dimension array is sorted. - * - * For example, given a dimension array of: - * 0,2 - * - * and - * - * 12,4,2,1 in data - * - * You end up with 1 (data[3]) - * since the first item won't match - * the last item of the dimension array - */ +SD_EXPORT _CUDA_HD Nd4jLong prodLong(const Nd4jLong *data, int length); -// ND4J_EXPORT _CUDA_HD int rearMostLeftOverItem(Nd4jLong *data,int length,Nd4jLong *dimension,int dimensionLength); - - /** -* Get an offset for retrieval -* from a data buffer -* based on the given -* shape stride and given indices -* @param baseOffset the offset to start from -* @param shape the shape of the array -* @param stride the stride of the array -* @param indices the indices to iterate over -* @return the double at the specified index -*/ - - ND4J_EXPORT _CUDA_HD Nd4jLong getOffset(const Nd4jLong *shapeInfo, const Nd4jLong *coords, Nd4jLong baseOffset = 0); - ND4J_EXPORT _CUDA_HD Nd4jLong getOffset(const Nd4jLong *shapeInfo, const int *coords, Nd4jLong baseOffset = 0); - ND4J_EXPORT _CUDA_HD Nd4jLong getOffset(const Nd4jLong *shapeInfo, const uint *coords, Nd4jLong baseOffset = 0); - - ND4J_EXPORT _CUDA_HD Nd4jLong* createShapeInfo(Nd4jLong *shape, Nd4jLong *stride, int rank); - - ND4J_EXPORT _CUDA_HD Nd4jLong* createShapeInfo(Nd4jLong *shape, Nd4jLong *stride, int rank, Nd4jLong *buffer); - - /** - * Convert a linear index to the corresponding coordinates - * for example if shape is {2, 4}, then index 5 corresponds to coordinates [1, 1] - */ - ND4J_EXPORT _CUDA_HD void index2coords(Nd4jLong index, const Nd4jLong *shapeInfo, Nd4jLong *coords); - ND4J_EXPORT _CUDA_HD void index2coords(Nd4jLong index, const Nd4jLong *shapeInfo, int *coords); - ND4J_EXPORT _CUDA_HD void index2coords(Nd4jLong index, const Nd4jLong *shapeInfo, uint *coords); - ND4J_EXPORT _CUDA_HD void index2coords(Nd4jLong index, const int rank, const Nd4jLong *shape, Nd4jLong *coords); - ND4J_EXPORT _CUDA_HD void index2coords(Nd4jLong index, const int rank, const Nd4jLong *shape, int *coords); - - ND4J_EXPORT _CUDA_HD void index2coordsCPU(const Nd4jLong& startIndex, const Nd4jLong& index, const Nd4jLong *shapeInfo, Nd4jLong *coords); - ND4J_EXPORT _CUDA_HD void index2coordsCPU(const Nd4jLong& startIndex, const Nd4jLong& index, const Nd4jLong *shapeInfo, int *coords); - - /** - * take into account only dimensions stored in tadDims, tadDims must be sorted in increasing order! - */ - ND4J_EXPORT _CUDA_HD void index2coords(Nd4jLong index, const Nd4jLong *shapeInfo, int *coords, const int dimsSize, const int* tadDims); - - /** - * Convert coordinates to the corresponding linear index (sequence number in other words) - * for example if shape is {2, 4} and coordinates [1, 1] then index 5 is returned - */ - ND4J_EXPORT _CUDA_HD Nd4jLong coords2index(const Nd4jLong *shapeInfo, const Nd4jLong *coords); - ND4J_EXPORT _CUDA_HD Nd4jLong coords2index(const Nd4jLong *shapeInfo, const int *coords); - ND4J_EXPORT _CUDA_HD Nd4jLong coords2index(const Nd4jLong *shapeInfo, const uint *coords); - ND4J_EXPORT _CUDA_HD Nd4jLong coords2index(const int rank, const Nd4jLong *shape, const int *coords); - /** - * take into account only dimensions stored in tadDims, tadDims must be sorted in increasing order! - */ - ND4J_EXPORT _CUDA_HD Nd4jLong coords2index(const Nd4jLong *shapeInfo, const int *coords, const int dimsSize, const int* tadDims); - - /** - * increment n-dimensional array by one iteration by changing coord appropriately - * for example we have array with shape {2, 3}: - * - if input coord = {0,1}, then output coord = {0,2} - * - if input coord = {0,2}, then output coord = {1,0} - * so the aim is to produce following subsequence of coord: {0,0}, {0,1}, {0,2}, {1,0}, {1,1}, {1,2} - */ +/** + * Returns the rear most left over item not present in + * the dimension array. This assumes that the dimension array is sorted. + * + * For example, given a dimension array of: + * 0,2 + * + * and + * + * 12,4,2,1 in data + * + * You end up with 1 (data[3]) + * since the first item won't match + * the last item of the dimension array + */ - /* calculates an array buffer offset for given "index" using following formula: offset = coord_0*stride_0 + coord_1*stride_1 + ... + coord_{rank-1}*stride_{rank-1} - */ - ND4J_EXPORT _CUDA_HD uint getIndexOffset(uint index, const uint *shapeInfo); - ND4J_EXPORT _CUDA_HD Nd4jLong getIndexOffset(Nd4jLong index, const Nd4jLong *shapeInfo); - ND4J_EXPORT _CUDA_HD Nd4jLong indexOffset(Nd4jLong index, const Nd4jLong* lShapeInfo, const uint* uShapeInfo, const bool useUnsigned); - - ND4J_EXPORT _CUDA_HD void printShapeInfo(Nd4jLong *shapeInfo); - - ND4J_EXPORT _CUDA_HD void printShapeInfoLinear(const Nd4jLong *shapeInfo); - - ND4J_EXPORT _CUDA_HD void printShapeInfoLinear(const char *msg, const Nd4jLong *shapeInfo); - - ND4J_EXPORT _CUDA_HD void printShapeInfoLinear(const char *msg, int rank, const Nd4jLong *shape, const Nd4jLong *strides); - - ND4J_EXPORT _CUDA_HD void printIntArray(const Nd4jLong *arr, const int length); - ND4J_EXPORT _CUDA_HD void printIntArray(const int *arr, const int length); - - ND4J_EXPORT _CUDA_HD void printArray(float *arr,int length); - - template - ND4J_EXPORT _CUDA_HD void printArray(T *arr,int length, const char *message); - - ND4J_EXPORT _CUDA_HD Nd4jLong* shapeBufferOfNpy(int rank, unsigned int *shape,bool fortranOrder); - - ND4J_EXPORT _CUDA_HD Nd4jLong *shapeBufferOfNpy(cnpy::NpyArray arr); - -// ND4J_EXPORT _CUDA_HD Nd4jLong *shapeBufferOfNpyBuffer(char *buffer); - - - // this function checks the consistence of dimensions with array rank (negative dimensions, too large dimensions, too big number of dimensions) - // also sort input array of dimensions, this operation is also necessary for creating TAD object - ND4J_EXPORT _CUDA_H void checkDimensions(const int rank, std::vector& dimensions); +// SD_EXPORT _CUDA_HD int rearMostLeftOverItem(Nd4jLong *data,int +// length,Nd4jLong *dimension,int dimensionLength); + +/** + * Get an offset for retrieval + * from a data buffer + * based on the given + * shape stride and given indices + * @param baseOffset the offset to start from + * @param shape the shape of the array + * @param stride the stride of the array + * @param indices the indices to iterate over + * @return the double at the specified index + */ + +SD_EXPORT _CUDA_HD Nd4jLong getOffset(const Nd4jLong *shapeInfo, + const Nd4jLong *coords, + Nd4jLong baseOffset = 0); +SD_EXPORT _CUDA_HD Nd4jLong getOffset(const Nd4jLong *shapeInfo, + const int *coords, + Nd4jLong baseOffset = 0); +SD_EXPORT _CUDA_HD Nd4jLong getOffset(const Nd4jLong *shapeInfo, + const uint *coords, + Nd4jLong baseOffset = 0); + +SD_EXPORT _CUDA_HD Nd4jLong *createShapeInfo(Nd4jLong *shape, Nd4jLong *stride, + int rank); + +SD_EXPORT _CUDA_HD Nd4jLong *createShapeInfo(Nd4jLong *shape, Nd4jLong *stride, + int rank, Nd4jLong *buffer); + +/** + * Convert a linear index to the corresponding coordinates + * for example if shape is {2, 4}, then index 5 corresponds to coordinates [1, + * 1] + */ +SD_EXPORT _CUDA_HD void index2coords(Nd4jLong index, const Nd4jLong *shapeInfo, + Nd4jLong *coords); +SD_EXPORT _CUDA_HD void index2coords(Nd4jLong index, const Nd4jLong *shapeInfo, + int *coords); +SD_EXPORT _CUDA_HD void index2coords(Nd4jLong index, const Nd4jLong *shapeInfo, + uint *coords); +SD_EXPORT _CUDA_HD void index2coords(Nd4jLong index, const int rank, + const Nd4jLong *shape, Nd4jLong *coords); +SD_EXPORT _CUDA_HD void index2coords(Nd4jLong index, const int rank, + const Nd4jLong *shape, int *coords); + +SD_EXPORT _CUDA_HD void index2coordsCPU(const Nd4jLong &startIndex, + const Nd4jLong &index, + const Nd4jLong *shapeInfo, + Nd4jLong *coords); +SD_EXPORT _CUDA_HD void index2coordsCPU(const Nd4jLong &startIndex, + const Nd4jLong &index, + const Nd4jLong *shapeInfo, int *coords); + +/** + * take into account only dimensions stored in tadDims, tadDims must be sorted + * in increasing order! + */ +SD_EXPORT _CUDA_HD void index2coords(Nd4jLong index, const Nd4jLong *shapeInfo, + int *coords, const int dimsSize, + const int *tadDims); + +/** + * Convert coordinates to the corresponding linear index (sequence number in + * other words) for example if shape is {2, 4} and coordinates [1, 1] then index + * 5 is returned + */ +SD_EXPORT _CUDA_HD Nd4jLong coords2index(const Nd4jLong *shapeInfo, + const Nd4jLong *coords); +SD_EXPORT _CUDA_HD Nd4jLong coords2index(const Nd4jLong *shapeInfo, + const int *coords); +SD_EXPORT _CUDA_HD Nd4jLong coords2index(const Nd4jLong *shapeInfo, + const uint *coords); +SD_EXPORT _CUDA_HD Nd4jLong coords2index(const int rank, const Nd4jLong *shape, + const int *coords); +/** + * take into account only dimensions stored in tadDims, tadDims must be sorted + * in increasing order! + */ +SD_EXPORT _CUDA_HD Nd4jLong coords2index(const Nd4jLong *shapeInfo, + const int *coords, const int dimsSize, + const int *tadDims); + +/** + * increment n-dimensional array by one iteration by changing coord + * appropriately for example we have array with shape {2, 3}: + * - if input coord = {0,1}, then output coord = {0,2} + * - if input coord = {0,2}, then output coord = {1,0} + * so the aim is to produce following subsequence of coord: {0,0}, {0,1}, {0,2}, + * {1,0}, {1,1}, {1,2} + */ - // function calculates linear index of array min, min is sub-array of max, index to be returned is min-array's index and corresponds to maxIdx of max array - // dimsToExclude - should be sorted in increasing order - ND4J_EXPORT _CUDA_HD Nd4jLong subArrayIndex(const Nd4jLong maxIdx, const Nd4jLong* maxShapeInfo, const Nd4jLong* minShapeInfo, const int* dimsToExclude = nullptr, const int dimsLen = -1); +/* calculates an array buffer offset for given "index" using following formula: + * offset = coord_0*stride_0 + coord_1*stride_1 + ... + + * coord_{rank-1}*stride_{rank-1} + */ +SD_EXPORT _CUDA_HD uint getIndexOffset(uint index, const uint *shapeInfo); +SD_EXPORT _CUDA_HD Nd4jLong getIndexOffset(Nd4jLong index, + const Nd4jLong *shapeInfo); +SD_EXPORT _CUDA_HD Nd4jLong indexOffset(Nd4jLong index, + const Nd4jLong *lShapeInfo, + const uint *uShapeInfo, + const bool useUnsigned); + +SD_EXPORT _CUDA_HD void printShapeInfo(Nd4jLong *shapeInfo); - // function calculates absolute offset of min array, min is sub-array of max, offset to be returned corresponds to maxIdx of max array - // dimsToExclude - should be sorted in increasing order - ND4J_EXPORT _CUDA_HD Nd4jLong subArrayOffset(const Nd4jLong maxIdx, const Nd4jLong* maxShapeInfo, const Nd4jLong* minShapeInfo, const int* dimsToExclude = nullptr, const int dimsLen = -1); +SD_EXPORT _CUDA_HD void printShapeInfoLinear(const Nd4jLong *shapeInfo); - // max array is outer for min array, min array is sub-array of max array - // function calculates the coordinates of min array (and saves them into minIdxs) given coordinates of max array (already stored in maxIdxs) - // dimsToExclude - should be sorted in increasing order - // dimsLen - length of dimsToExclude, if not set (= -1), then it is calculated as maxRank - minRank - ND4J_EXPORT _CUDA_HD void maxIndToMinInd(int* maxIdxs, int* minIdxs, const Nd4jLong* maxShapeInfo, const Nd4jLong* minShapeInfo, const int* dimsToExclude = nullptr, const int dimsLen = -1); +SD_EXPORT _CUDA_HD void printShapeInfoLinear(const char *msg, + const Nd4jLong *shapeInfo); - // calculate indexes of max-array, these output indexes correspond to one minIdx index of min-array which is sub-array of max-array - // dimsToExclude - should be sorted in increasing order - ND4J_EXPORT _CUDA_HD int outerArrayIndexes(int* maxIdxs, const Nd4jLong minIdx, const Nd4jLong* maxShapeInfo, const Nd4jLong* minShapeInfo, const int* dimsToExclude = nullptr); +SD_EXPORT _CUDA_HD void printShapeInfoLinear(const char *msg, int rank, + const Nd4jLong *shape, + const Nd4jLong *strides); - // calculate offsets of max-array, these offsets correspond to one minIdx index of min-array which is sub-array of max-array - // maxOffsets - will contain calculated offsets of max-array, buffer for maxOffsets should be allocated beforehand - // dimsToExclude - should be sorted in increasing order - // memBuff - auxiliary memory buffer (size = 2 * max_rank) for coordinates and increments storing, should be allocated beforehand - ND4J_EXPORT _CUDA_HD int outerArrayOffsets(Nd4jLong* maxOffsets, const Nd4jLong minIdx, const Nd4jLong* maxShapeInfo, const Nd4jLong* minShapeInfo, int* memBuff, const int* dimsToExclude = nullptr); +SD_EXPORT _CUDA_HD void printIntArray(const Nd4jLong *arr, const int length); +SD_EXPORT _CUDA_HD void printIntArray(const int *arr, const int length); - // calculates offsets for entities (elements or sub-arrays), shape in context of sub-array means dimensions excluded from outer array - // rank is equal to size of shape - ND4J_EXPORT void calcOffsets(const int rank, const Nd4jLong* shape, const Nd4jLong* strides, Nd4jLong* offsets, const char order = 'c'); - ND4J_EXPORT void calcOffsets(const Nd4jLong* shapeInfo, Nd4jLong* offsets, const char order = 'c'); - // ND4J_EXPORT void calcOffsets(const Nd4jLong *xShapeInfo, Nd4jLong*& xOffsets, const Nd4jLong *yShapeInfo, Nd4jLong*& yOffsets, const char order = 'c'); - // ND4J_EXPORT void calcOffsets(const Nd4jLong *xShapeInfo, Nd4jLong*& xOffsets, const Nd4jLong *yShapeInfo, Nd4jLong*& yOffsets, const Nd4jLong* zShapeInfo, Nd4jLong*& zOffsets, const char order = 'c'); - ND4J_EXPORT _CUDA_HD void shapeOldScalar(sd::DataType dtype, Nd4jLong* const buffer, const char order); +SD_EXPORT _CUDA_HD void printArray(float *arr, int length); + +template +SD_EXPORT _CUDA_HD void printArray(T *arr, int length, const char *message); + +SD_EXPORT _CUDA_HD Nd4jLong *shapeBufferOfNpy(int rank, unsigned int *shape, + bool fortranOrder); + +SD_EXPORT _CUDA_HD Nd4jLong *shapeBufferOfNpy(cnpy::NpyArray arr); + +// SD_EXPORT _CUDA_HD Nd4jLong *shapeBufferOfNpyBuffer(char *buffer); + +// this function checks the consistence of dimensions with array rank (negative +// dimensions, too large dimensions, too big number of dimensions) also sort +// input array of dimensions, this operation is also necessary for creating TAD +// object +SD_EXPORT _CUDA_H void checkDimensions(const int rank, + std::vector &dimensions); + +// function calculates linear index of array min, min is sub-array of max, index +// to be returned is min-array's index and corresponds to maxIdx of max array +// dimsToExclude - should be sorted in increasing order +SD_EXPORT _CUDA_HD Nd4jLong subArrayIndex(const Nd4jLong maxIdx, + const Nd4jLong *maxShapeInfo, + const Nd4jLong *minShapeInfo, + const int *dimsToExclude = nullptr, + const int dimsLen = -1); + +// function calculates absolute offset of min array, min is sub-array of max, +// offset to be returned corresponds to maxIdx of max array dimsToExclude - +// should be sorted in increasing order +SD_EXPORT _CUDA_HD Nd4jLong subArrayOffset(const Nd4jLong maxIdx, + const Nd4jLong *maxShapeInfo, + const Nd4jLong *minShapeInfo, + const int *dimsToExclude = nullptr, + const int dimsLen = -1); - // deduce order and element-wise stride - // if array is scalar or unit length vector then ews = 1 and order is preserved - // if array is common vector then ews = stride of non-unity dimension and order is preserved - // if strides are normal/contiguous then ews = 1 and corresponding order is set, otherwise ews = 0 and order is preserved - ND4J_EXPORT _CUDA_HD void checkStridesEwsAndOrder(Nd4jLong* shapeInfo, const char proposedOrder, const int numOfNonUnitDims, const Nd4jLong* shapeNoUnities, const Nd4jLong* stridesNoUnities); - ND4J_EXPORT _CUDA_HD void checkStridesEwsAndOrder(Nd4jLong* shapeInfo); +// max array is outer for min array, min array is sub-array of max array +// function calculates the coordinates of min array (and saves them into +// minIdxs) given coordinates of max array (already stored in maxIdxs) +// dimsToExclude - should be sorted in increasing order +// dimsLen - length of dimsToExclude, if not set (= -1), then it is calculated +// as maxRank - minRank +SD_EXPORT _CUDA_HD void maxIndToMinInd(int *maxIdxs, int *minIdxs, + const Nd4jLong *maxShapeInfo, + const Nd4jLong *minShapeInfo, + const int *dimsToExclude = nullptr, + const int dimsLen = -1); + +// calculate indexes of max-array, these output indexes correspond to one minIdx +// index of min-array which is sub-array of max-array dimsToExclude - should be +// sorted in increasing order +SD_EXPORT _CUDA_HD int outerArrayIndexes(int *maxIdxs, const Nd4jLong minIdx, + const Nd4jLong *maxShapeInfo, + const Nd4jLong *minShapeInfo, + const int *dimsToExclude = nullptr); + +// calculate offsets of max-array, these offsets correspond to one minIdx index +// of min-array which is sub-array of max-array maxOffsets - will contain +// calculated offsets of max-array, buffer for maxOffsets should be allocated +// beforehand dimsToExclude - should be sorted in increasing order memBuff - +// auxiliary memory buffer (size = 2 * max_rank) for coordinates and increments +// storing, should be allocated beforehand +SD_EXPORT _CUDA_HD int outerArrayOffsets(Nd4jLong *maxOffsets, + const Nd4jLong minIdx, + const Nd4jLong *maxShapeInfo, + const Nd4jLong *minShapeInfo, + int *memBuff, + const int *dimsToExclude = nullptr); + +// calculates offsets for entities (elements or sub-arrays), shape in context of +// sub-array means dimensions excluded from outer array rank is equal to size of +// shape +SD_EXPORT void calcOffsets(const int rank, const Nd4jLong *shape, + const Nd4jLong *strides, Nd4jLong *offsets, + const char order = 'c'); +SD_EXPORT void calcOffsets(const Nd4jLong *shapeInfo, Nd4jLong *offsets, + const char order = 'c'); +// SD_EXPORT void calcOffsets(const Nd4jLong *xShapeInfo, Nd4jLong*& xOffsets, +// const Nd4jLong *yShapeInfo, Nd4jLong*& yOffsets, const char order = 'c'); +// SD_EXPORT void calcOffsets(const Nd4jLong *xShapeInfo, Nd4jLong*& xOffsets, +// const Nd4jLong *yShapeInfo, Nd4jLong*& yOffsets, const Nd4jLong* zShapeInfo, +// Nd4jLong*& zOffsets, const char order = 'c'); +SD_EXPORT _CUDA_HD void shapeOldScalar(sd::DataType dtype, + Nd4jLong *const buffer, + const char order); + +// deduce order and element-wise stride +// if array is scalar or unit length vector then ews = 1 and order is preserved +// if array is common vector then ews = stride of non-unity dimension and order +// is preserved if strides are normal/contiguous then ews = 1 and corresponding +// order is set, otherwise ews = 0 and order is preserved +SD_EXPORT _CUDA_HD void checkStridesEwsAndOrder( + Nd4jLong *shapeInfo, const char proposedOrder, const int numOfNonUnitDims, + const Nd4jLong *shapeNoUnities, const Nd4jLong *stridesNoUnities); +SD_EXPORT _CUDA_HD void checkStridesEwsAndOrder(Nd4jLong *shapeInfo); - /** - * processes whole set of sub-arrays - * evaluates shapeInfo of sub-arrays (all sub-arrays have the same shapeInfo) and their buffer offsets (each sub-array has its own unique offset from original this-buffer) - * arguments: - * wholeShapeInfo - original shapeInfo of whole array - * numOfSubArrs - number of sub-arrays, size of subArrOffsets is equal to numOfSubArrs - * dimsSize - size of dimsToExclude, if dimsSize = array rank or dimsSize = 0 it means sub-array is whole array, copy of wholeShapeInfo and one zero offset will be returned - * dimsToExclude - MUST BE SORTED, dimensions to evaluate sub-array along, i.e. when shape is [2,3,4,5] and dimsToExclude={0,2}, then there will be 8 sub-arrays with shape [3,5] - * subArrShapeInfo - output argument, contains shapeInfo (same for all sub-arrays) - * subArrOffsets - output argument, contains successive sub-arrays offsets from original this-buffer - * keepUnitiesInShape - if false then eliminate unities from sub-array shapeInfo, for example {1,a,1,b} -> {a,b} - */ - ND4J_EXPORT _CUDA_HD void calcSubArrsShapeInfoAndOffsets(const Nd4jLong* wholeShapeInfo, const Nd4jLong numOfSubArrs, const int dimsSize, const int* dimsToExclude, Nd4jLong* subArrShapeInfo, Nd4jLong* subArrOffsets, bool keepUnitiesInShape = false); +/** + * processes whole set of sub-arrays + * evaluates shapeInfo of sub-arrays (all sub-arrays have the same shapeInfo) + * and their buffer offsets (each sub-array has its own unique offset from + * original this-buffer) arguments: wholeShapeInfo - original shapeInfo of whole + * array numOfSubArrs - number of sub-arrays, size of subArrOffsets is equal to + * numOfSubArrs dimsSize - size of dimsToExclude, if dimsSize = array rank or + * dimsSize = 0 it means sub-array is whole array, copy of wholeShapeInfo and + * one zero offset will be returned dimsToExclude - MUST BE SORTED, dimensions + * to evaluate sub-array along, i.e. when shape is [2,3,4,5] and + * dimsToExclude={0,2}, then there will be 8 sub-arrays with shape [3,5] + * subArrShapeInfo - output argument, contains shapeInfo (same for all + * sub-arrays) subArrOffsets - output argument, contains successive + * sub-arrays offsets from original this-buffer keepUnitiesInShape - if false + * then eliminate unities from sub-array shapeInfo, for example {1,a,1,b} -> + * {a,b} + */ +SD_EXPORT _CUDA_HD void calcSubArrsShapeInfoAndOffsets( + const Nd4jLong *wholeShapeInfo, const Nd4jLong numOfSubArrs, + const int dimsSize, const int *dimsToExclude, Nd4jLong *subArrShapeInfo, + Nd4jLong *subArrOffsets, bool keepUnitiesInShape = false); - /** - * processes only one sub-array, evaluates shapeInfo of sub-array and its buffer offset from original array - * arguments: - * idx - input argument, intervals of indexes which define the sub-array to point on, - * when isStrided = false then idx has form {dim0Start,dim0End, dim1Start,dim1End, ....} and length (2 * maxRank) - * when isStrided = true then idx has form {dim0Start,dim0End,dim0Stride, dim1Start,dim1End,dim1Stride, ....} and length (3 * maxRank) - * when (dimStart == dimEnd) then whole range will be used for current dimension - * maxShapeInfo - input argument, shapeInfo of original array - * minShapeInfo - output argument, shapeInfo of sub-array to be deduced - * minOffset - output argument, offset of sub-array buffer offsets from original buffer - * keepUnitiesInShape - input argument, if false then eliminate unities from sub-array shapeInfo, for example {1,a,1,b} -> {a,b} - * isStrided - input argument, if true then idx has length (3 * this->rankOf()) and contains additional stride numbers which correspond to stride between dimStart and dimEnd, - * numOfUntiesInMinShape - input argument, number of occurrences in idx when (dimEnd - dimStart) = 1 - */ - ND4J_EXPORT void calcSubArrShapeInfoAndOffset(const Nd4jLong* idx, const Nd4jLong* maxShapeInfo, Nd4jLong* minShapeInfo, Nd4jLong& minOffset, const bool keepUnitiesInShape = false, const bool isStrided = false, const int numOfUntiesInMinShape = 0); +/** + * processes only one sub-array, evaluates shapeInfo of sub-array and its buffer + * offset from original array arguments: idx - input argument, intervals of + * indexes which define the sub-array to point on, when isStrided = false then + * idx has form {dim0Start,dim0End, dim1Start,dim1End, ....} and length (2 * + * maxRank) when isStrided = true then idx has form + * {dim0Start,dim0End,dim0Stride, dim1Start,dim1End,dim1Stride, ....} and + * length (3 * maxRank) when (dimStart == dimEnd) then whole range will be used + * for current dimension maxShapeInfo - input argument, shapeInfo of original + * array minShapeInfo - output argument, shapeInfo of sub-array to be deduced + * minOffset - output argument, offset of sub-array buffer offsets from original + * buffer keepUnitiesInShape - input argument, if false then eliminate unities + * from sub-array shapeInfo, for example {1,a,1,b} -> {a,b} isStrided - input + * argument, if true then idx has length (3 * this->rankOf()) and contains + * additional stride numbers which correspond to stride between dimStart and + * dimEnd, numOfUntiesInMinShape - input argument, number of occurrences in idx + * when (dimEnd - dimStart) = 1 + */ +SD_EXPORT void calcSubArrShapeInfoAndOffset( + const Nd4jLong *idx, const Nd4jLong *maxShapeInfo, Nd4jLong *minShapeInfo, + Nd4jLong &minOffset, const bool keepUnitiesInShape = false, + const bool isStrided = false, const int numOfUntiesInMinShape = 0); - /** - * for example inShapeInfo is {3, 2,1,4, 4,4,1, 16384,1,99} - * then output shapeNoUnities will contain {2,4, 4,1} - that is only shape and strides, no rank/type/ews/order - * stridesNoUnities will point on strides in shapeNoUnities that is on {4,1} - * returns number of non-unity dimensions in inShapeInfo - * if there is no unities in inShapeInfo, then no copy procedure will be performed and shapeNoUnities/stridesNoUnities will point on corresponding places in inShapeInfo - */ - ND4J_EXPORT _CUDA_HD int excludeUnitiesFromShapeInfo(const Nd4jLong* inShapeInfo, Nd4jLong*& shapeNoUnities, Nd4jLong*& stridesNoUnities); +/** + * for example inShapeInfo is {3, 2,1,4, 4,4,1, 16384,1,99} + * then output shapeNoUnities will contain {2,4, 4,1} - that is only shape and + * strides, no rank/type/ews/order stridesNoUnities will point on strides in + * shapeNoUnities that is on {4,1} returns number of non-unity dimensions in + * inShapeInfo if there is no unities in inShapeInfo, then no copy procedure + * will be performed and shapeNoUnities/stridesNoUnities will point on + * corresponding places in inShapeInfo + */ +SD_EXPORT _CUDA_HD int excludeUnitiesFromShapeInfo(const Nd4jLong *inShapeInfo, + Nd4jLong *&shapeNoUnities, + Nd4jLong *&stridesNoUnities); - /** - * for example inShapeInfo is {3, 2,1,3,1,4, 12,12,4,4,1, 16384,1,99}, dimsToExclude = {1,3}, dimsSize = 2 - * then outShapeInfo will contain {3, 2,3,4, 12,4,1, 16384,1,99} - */ - INLINEDEF _CUDA_HD void excludeUnitiesFromShapeInfo(const Nd4jLong* inShapeInfo, const int dimsSize, const int* dimsToExclude, Nd4jLong* outShapeInfo); +/** + * for example inShapeInfo is {3, 2,1,3,1,4, 12,12,4,4,1, 16384,1,99}, + * dimsToExclude = {1,3}, dimsSize = 2 then outShapeInfo will contain {3, 2,3,4, + * 12,4,1, 16384,1,99} + */ +INLINEDEF _CUDA_HD void excludeUnitiesFromShapeInfo(const Nd4jLong *inShapeInfo, + const int dimsSize, + const int *dimsToExclude, + Nd4jLong *outShapeInfo); - /** - * get stride over contiguous axis (contiguous axis must have stride = 1) - * for example when inShapeInfo is {4, 2,5,4,3, 60,1,5,20, 16384,0,99} then output is 5 (that is smallest stride in inShapeInfo except those equal to 1) - */ - // INLINEDEF _CUDA_HD Nd4jLong strideOverContigAxis(const int axis, const Nd4jLong* inShapeInfo); - - - - - - -//END HEADERS - - - //BEGIN IMPLEMENTATIONS +/** + * get stride over contiguous axis (contiguous axis must have stride = 1) + * for example when inShapeInfo is {4, 2,5,4,3, 60,1,5,20, 16384,0,99} then + * output is 5 (that is smallest stride in inShapeInfo except those equal to 1) + */ +// INLINEDEF _CUDA_HD Nd4jLong strideOverContigAxis(const int axis, const +// Nd4jLong* inShapeInfo); +// END HEADERS +// BEGIN IMPLEMENTATIONS #ifdef __CUDACC__ - /** -* BEWARE: THIS METHOD DOES NOT CHECKS ALLOCATION BOUNDARIES -*/ +/** + * BEWARE: THIS METHOD DOES NOT CHECKS ALLOCATION BOUNDARIES + */ __device__ INLINEDEF Nd4jLong *cuMalloc(Nd4jLong *buffer, long size) { - Nd4jLong *ret = buffer; - ret += (threadIdx.x * size); - return ret; + Nd4jLong *ret = buffer; + ret += (threadIdx.x * size); + return ret; } #endif /** -* Length of a tad given -* the shape information -*/ - INLINEDEF _CUDA_HD Nd4jLong tadLength(const Nd4jLong *shapeInfo, int *dimension, int dimensionLength) { - if(dimensionLength == 1) { - return shape::shapeOf(shapeInfo)[dimension[0]]; - } - else { - Nd4jLong ret = 1; - for(int i = 0; i < shape::rank(shapeInfo); i++) { - for(int j = 0; j < dimensionLength; j++) { - if(i == dimension[j]) - ret *= shape::shapeOf(shapeInfo)[dimension[j]]; - } - } - return ret; - } + * Length of a tad given + * the shape information + */ +INLINEDEF _CUDA_HD Nd4jLong tadLength(const Nd4jLong *shapeInfo, int *dimension, + int dimensionLength) { + if (dimensionLength == 1) { + return shape::shapeOf(shapeInfo)[dimension[0]]; + } else { + Nd4jLong ret = 1; + for (int i = 0; i < shape::rank(shapeInfo); i++) { + for (int j = 0; j < dimensionLength; j++) { + if (i == dimension[j]) ret *= shape::shapeOf(shapeInfo)[dimension[j]]; + } } - - + return ret; + } +} /** * Tad element wise stride: @@ -1139,189 +1314,199 @@ __device__ INLINEDEF Nd4jLong *cuMalloc(Nd4jLong *buffer, long size) { * Again: this may not preserve ordering of the tad * but maybe used for reductions. */ - INLINEDEF _CUDA_HD int tadElementWiseStride(Nd4jLong *shapeInfo, int *dimension,int dimensionLength) { - return reductionIndexElementWiseStride(shapeInfo,dimension,dimensionLength); - } - - - INLINEDEF _CUDA_HD bool shapeEquals(const int shape1Rank, const Nd4jLong *shape1, const int shape2Rank, const Nd4jLong *shape2) { - if(shape1Rank != shape2Rank) - return false; - //rank not equals - for(int i = 0; i < shape1Rank; i++) { - if(shape1[i] != shape2[i]) - return false; - } - - return true; - } - - INLINEDEF _CUDA_HD bool shapeEquals(const Nd4jLong *shapeInfo1, const Nd4jLong *shapeInfo2) { - return shape::shapeEquals(shape::rank(shapeInfo1), shape::shapeOf(const_cast(shapeInfo1)), shape::rank(shapeInfo2), shape::shapeOf(const_cast(shapeInfo2))); - } - - INLINEDEF _CUDA_HD bool shapeEquals(const Nd4jLong *shapeInfo1, const Nd4jLong *shapeInfo2, const Nd4jLong *shapeInfo3) { - - return shape::shapeEquals(shapeInfo1, shapeInfo2) && shape::shapeEquals(shapeInfo1, shapeInfo3); - - } - - INLINEDEF _CUDA_HD bool strideEquals(int const shape1Rank, Nd4jLong const* shape1,int const shape2Rank,Nd4jLong const* shape2) { - if(shape1Rank != shape2Rank) - return false; - //rank not equals - for(int i = 0; i < shape1Rank; i++) { - if(shape1[i] != shape2[i]) - return false; - } - - return true; - } - - INLINEDEF _CUDA_HD bool strideEquals(Nd4jLong const* shapeInfo1,Nd4jLong const* shapeInfo2) { - return shape::strideEquals(shape::rank(shapeInfo1),shape::stride(shapeInfo1),shape::rank(shapeInfo2),shape::stride(shapeInfo2)); - - } - - INLINEDEF _CUDA_HD bool strideEquals(Nd4jLong const* stride1,int const rank1 , Nd4jLong const* stride2, int const rank2) { - if(rank1 != rank2) - return false; - - for(int i = 0; i < rank1; i++) { - if(stride1[i] != stride2[i]) - return false; - } - - return true; - } - - INLINEDEF _CUDA_HD Nd4jLong *computeResultShape(Nd4jLong const* originalShapeBuffer, int * dimension,int dimensionLength) { - Nd4jLong *retShape; - int retShapeLength; - if(dimensionLength == 1 && dimension[0] == 2147483647) { - retShape = new Nd4jLong[2]; - retShape[0] = 1; - retShape[1] = 1; - retShapeLength = 2; - } - else { - retShape = shape::removeIndex(shape::shapeOf(originalShapeBuffer), dimension, shape::shapeInfoLength(shape::rank(originalShapeBuffer)), dimensionLength); - retShapeLength = shape::rank(originalShapeBuffer) - dimensionLength; - } - //ensure vector is proper shape - if (retShapeLength == 1) { - if (dimension[0] == 0) { - auto newRetShape = new Nd4jLong[2]{1, retShape[0]}; - delete[] retShape; - retShape = newRetShape; - retShapeLength = 2; - } - else { - auto newRetShape = new Nd4jLong[2]{retShape[0], 1}; - delete[] retShape; - retShape = newRetShape; - retShapeLength = 2; - } - } else if (retShapeLength == 0) { - auto newRetShape = new Nd4jLong[2]{1, 1}; - delete[] retShape; - retShape = newRetShape; - retShapeLength = 2; - } - - auto ret = shape::shapeBuffer(retShapeLength, sd::ArrayOptions::dataType(originalShapeBuffer), retShape); - delete[] retShape; - - return ret; +INLINEDEF _CUDA_HD int tadElementWiseStride(Nd4jLong *shapeInfo, int *dimension, + int dimensionLength) { + return reductionIndexElementWiseStride(shapeInfo, dimension, dimensionLength); +} - } +INLINEDEF _CUDA_HD bool shapeEquals(const int shape1Rank, + const Nd4jLong *shape1, + const int shape2Rank, + const Nd4jLong *shape2) { + if (shape1Rank != shape2Rank) return false; + // rank not equals + for (int i = 0; i < shape1Rank; i++) { + if (shape1[i] != shape2[i]) return false; + } + + return true; +} - INLINEDEF _CUDA_HD Nd4jLong *shapeInfoOnlyShapeAndStride(const Nd4jLong *shapeInfo, Nd4jLong *dimension, int dimensionLength,bool reverseCopyStride, Nd4jLong *buffer) { - Nd4jLong *theShape = shape::shapeOf(shapeInfo); - Nd4jLong *theStride = shape::stride(shapeInfo); - int rank = dimensionLength == 1 ? 2 : dimensionLength; - Nd4jLong *ret = buffer; - //set the rank - ret[0] = rank; - Nd4jLong *retShape = shape::shapeOf(ret); - Nd4jLong *retStride = shape::stride(ret); - int len = rank; - - if(dimensionLength == 1) { - if(shape::isMatrix(theShape,shape::rank(shapeInfo))) { - if(dimension[0] == 0) { - Nd4jLong newStride[2] = {theStride[dimension[0]],1}; - Nd4jLong newShape[2] = {theShape[dimension[0]],1}; - retShape[0] = newShape[0]; - retShape[1] = newShape[1]; - retStride[0] = newStride[0]; - retStride[1] = newStride[1]; - } - else { - Nd4jLong newStride[2] = {theStride[dimension[0]],1}; - Nd4jLong newShape[2] = {theShape[dimension[0]],1}; - retShape[0] = newShape[0]; - retShape[1] = newShape[1]; - retStride[0] = newStride[0]; - retStride[1] = newStride[1]; - } - } - else { - Nd4jLong newStride[2] = {1,theStride[dimension[0]]}; - Nd4jLong newShape[2] = {1,theShape[dimension[0]]}; - retShape[0] = newShape[0]; - retShape[1] = newShape[1]; - retStride[0] = newStride[0]; - retStride[1] = newStride[1]; - } +INLINEDEF _CUDA_HD bool shapeEquals(const Nd4jLong *shapeInfo1, + const Nd4jLong *shapeInfo2) { + return shape::shapeEquals(shape::rank(shapeInfo1), + shape::shapeOf(const_cast(shapeInfo1)), + shape::rank(shapeInfo2), + shape::shapeOf(const_cast(shapeInfo2))); +} +INLINEDEF _CUDA_HD bool shapeEquals(const Nd4jLong *shapeInfo1, + const Nd4jLong *shapeInfo2, + const Nd4jLong *shapeInfo3) { + return shape::shapeEquals(shapeInfo1, shapeInfo2) && + shape::shapeEquals(shapeInfo1, shapeInfo3); +} +INLINEDEF _CUDA_HD bool strideEquals(int const shape1Rank, + Nd4jLong const *shape1, + int const shape2Rank, + Nd4jLong const *shape2) { + if (shape1Rank != shape2Rank) return false; + // rank not equals + for (int i = 0; i < shape1Rank; i++) { + if (shape1[i] != shape2[i]) return false; + } + + return true; +} - } - else { - Nd4jLong *newIndexes = dimension; - if(reverseCopyStride) - shape::reverseCopyTo(theStride, retStride, newIndexes, len); - else - shape::copyTo(len, theStride, retStride, newIndexes); - shape::copyTo(len, theShape, retShape, newIndexes); +INLINEDEF _CUDA_HD bool strideEquals(Nd4jLong const *shapeInfo1, + Nd4jLong const *shapeInfo2) { + return shape::strideEquals(shape::rank(shapeInfo1), shape::stride(shapeInfo1), + shape::rank(shapeInfo2), + shape::stride(shapeInfo2)); +} - } +INLINEDEF _CUDA_HD bool strideEquals(Nd4jLong const *stride1, int const rank1, + Nd4jLong const *stride2, int const rank2) { + if (rank1 != rank2) return false; + for (int i = 0; i < rank1; i++) { + if (stride1[i] != stride2[i]) return false; + } - ret[shape::shapeInfoLength(rank) - 1] = shape::order(shapeInfo); - return ret; - } + return true; +} - INLINEDEF _CUDA_HD Nd4jLong *shapeInfoOnlyShapeAndStride(const Nd4jLong *shapeInfo, Nd4jLong *dimension, int dimensionLength,bool reverseCopyStride) { - int rank = dimensionLength == 1 ? 2 : dimensionLength; +INLINEDEF _CUDA_HD Nd4jLong *computeResultShape( + Nd4jLong const *originalShapeBuffer, int *dimension, int dimensionLength) { + Nd4jLong *retShape; + int retShapeLength; + if (dimensionLength == 1 && dimension[0] == 2147483647) { + retShape = new Nd4jLong[2]; + retShape[0] = 1; + retShape[1] = 1; + retShapeLength = 2; + } else { + retShape = shape::removeIndex( + shape::shapeOf(originalShapeBuffer), dimension, + shape::shapeInfoLength(shape::rank(originalShapeBuffer)), + dimensionLength); + retShapeLength = shape::rank(originalShapeBuffer) - dimensionLength; + } + // ensure vector is proper shape + if (retShapeLength == 1) { + if (dimension[0] == 0) { + auto newRetShape = new Nd4jLong[2]{1, retShape[0]}; + delete[] retShape; + retShape = newRetShape; + retShapeLength = 2; + } else { + auto newRetShape = new Nd4jLong[2]{retShape[0], 1}; + delete[] retShape; + retShape = newRetShape; + retShapeLength = 2; + } + } else if (retShapeLength == 0) { + auto newRetShape = new Nd4jLong[2]{1, 1}; + delete[] retShape; + retShape = newRetShape; + retShapeLength = 2; + } + + auto ret = shape::shapeBuffer(retShapeLength, + sd::ArrayOptions::dataType(originalShapeBuffer), + retShape); + delete[] retShape; + + return ret; +} - traceNew(4); +INLINEDEF _CUDA_HD Nd4jLong *shapeInfoOnlyShapeAndStride( + const Nd4jLong *shapeInfo, Nd4jLong *dimension, int dimensionLength, + bool reverseCopyStride, Nd4jLong *buffer) { + Nd4jLong *theShape = shape::shapeOf(shapeInfo); + Nd4jLong *theStride = shape::stride(shapeInfo); + int rank = dimensionLength == 1 ? 2 : dimensionLength; + Nd4jLong *ret = buffer; + // set the rank + ret[0] = rank; + Nd4jLong *retShape = shape::shapeOf(ret); + Nd4jLong *retStride = shape::stride(ret); + int len = rank; + + if (dimensionLength == 1) { + if (shape::isMatrix(theShape, shape::rank(shapeInfo))) { + if (dimension[0] == 0) { + Nd4jLong newStride[2] = {theStride[dimension[0]], 1}; + Nd4jLong newShape[2] = {theShape[dimension[0]], 1}; + retShape[0] = newShape[0]; + retShape[1] = newShape[1]; + retStride[0] = newStride[0]; + retStride[1] = newStride[1]; + } else { + Nd4jLong newStride[2] = {theStride[dimension[0]], 1}; + Nd4jLong newShape[2] = {theShape[dimension[0]], 1}; + retShape[0] = newShape[0]; + retShape[1] = newShape[1]; + retStride[0] = newStride[0]; + retStride[1] = newStride[1]; + } + } else { + Nd4jLong newStride[2] = {1, theStride[dimension[0]]}; + Nd4jLong newShape[2] = {1, theShape[dimension[0]]}; + retShape[0] = newShape[0]; + retShape[1] = newShape[1]; + retStride[0] = newStride[0]; + retStride[1] = newStride[1]; + } + + } else { + Nd4jLong *newIndexes = dimension; + if (reverseCopyStride) + shape::reverseCopyTo(theStride, retStride, newIndexes, len); + else + shape::copyTo(len, theStride, retStride, newIndexes); + shape::copyTo(len, theShape, retShape, newIndexes); + } + + ret[shape::shapeInfoLength(rank) - 1] = shape::order(shapeInfo); + return ret; +} - Nd4jLong *ret = new Nd4jLong[shape::shapeInfoLength(rank)]; - return shapeInfoOnlyShapeAndStride(shapeInfo, dimension, dimensionLength, reverseCopyStride, ret); - } +INLINEDEF _CUDA_HD Nd4jLong *shapeInfoOnlyShapeAndStride( + const Nd4jLong *shapeInfo, Nd4jLong *dimension, int dimensionLength, + bool reverseCopyStride) { + int rank = dimensionLength == 1 ? 2 : dimensionLength; - INLINEDEF _CUDA_HD Nd4jLong * createShapeInfo(Nd4jLong *shape, Nd4jLong *stride, int rank) { + traceNew(4); - traceNew(5); + Nd4jLong *ret = new Nd4jLong[shape::shapeInfoLength(rank)]; + return shapeInfoOnlyShapeAndStride(shapeInfo, dimension, dimensionLength, + reverseCopyStride, ret); +} - Nd4jLong *ret = new Nd4jLong[shape::shapeInfoLength(rank)]; +INLINEDEF _CUDA_HD Nd4jLong *createShapeInfo(Nd4jLong *shape, Nd4jLong *stride, + int rank) { + traceNew(5); - return createShapeInfo(shape, stride, rank, ret); - } + Nd4jLong *ret = new Nd4jLong[shape::shapeInfoLength(rank)]; - INLINEDEF _CUDA_HD Nd4jLong * createShapeInfo(Nd4jLong *shape, Nd4jLong *stride, int rank, Nd4jLong *buffer) { - buffer[0] = rank; - Nd4jLong *retShape = shape::shapeOf(buffer); - Nd4jLong *retStride = shape::stride(buffer); - for(int i = 0;i < rank; i++) { - retShape[i] = shape[i]; - retStride[i] = stride[i]; - } + return createShapeInfo(shape, stride, rank, ret); +} - return buffer; - } +INLINEDEF _CUDA_HD Nd4jLong *createShapeInfo(Nd4jLong *shape, Nd4jLong *stride, + int rank, Nd4jLong *buffer) { + buffer[0] = rank; + Nd4jLong *retShape = shape::shapeOf(buffer); + Nd4jLong *retStride = shape::stride(buffer); + for (int i = 0; i < rank; i++) { + retShape[i] = shape[i]; + retStride[i] = stride[i]; + } + + return buffer; +} /** * Computes the standard packed array strides for a given shape. @@ -1330,50 +1515,47 @@ __device__ INLINEDEF Nd4jLong *cuMalloc(Nd4jLong *buffer, long size) { * @param startNum the start number for the strides * @return the strides for a matrix of n dimensions */ - INLINEDEF _CUDA_HD Nd4jLong * calcStridesFortran(Nd4jLong const* shape, int rank, int startNum) { - if (isVector(shape, rank)) { - - traceNew(5); - - Nd4jLong *ret = new Nd4jLong[2]; - for (int i = 0; i < 2; i++) - ret[i] = 1; - return ret; - - } +INLINEDEF _CUDA_HD Nd4jLong *calcStridesFortran(Nd4jLong const *shape, int rank, + int startNum) { + if (isVector(shape, rank)) { + traceNew(5); - int dimensions = rank; + Nd4jLong *ret = new Nd4jLong[2]; + for (int i = 0; i < 2; i++) ret[i] = 1; + return ret; + } - traceNew(6); + int dimensions = rank; - Nd4jLong *stride = new Nd4jLong[dimensions]; - Nd4jLong st = startNum; - for (int j = 0; j < rank; j++) { - stride[j] = st; - st *= shape[j]; - } + traceNew(6); - return stride; - } + Nd4jLong *stride = new Nd4jLong[dimensions]; + Nd4jLong st = startNum; + for (int j = 0; j < rank; j++) { + stride[j] = st; + st *= shape[j]; + } - INLINEDEF _CUDA_HD Nd4jLong * calcStridesFortran(Nd4jLong const* shape, int rank, int startNum, Nd4jLong *ret) { - if (isVector(shape, rank)) { - for (int i = 0; i < rank; i++) - ret[i] = 1; - return ret; + return stride; +} - } +INLINEDEF _CUDA_HD Nd4jLong *calcStridesFortran(Nd4jLong const *shape, int rank, + int startNum, Nd4jLong *ret) { + if (isVector(shape, rank)) { + for (int i = 0; i < rank; i++) ret[i] = 1; + return ret; + } - //int dimensions = rank; + // int dimensions = rank; - Nd4jLong st = startNum; - for (int j = 0; j < rank; j++) { - ret[j] = st; - st *= shape[j]; - } + Nd4jLong st = startNum; + for (int j = 0; j < rank; j++) { + ret[j] = st; + st *= shape[j]; + } - return ret; - } + return ret; +} /** * Computes the standard packed array strides for a given shape. @@ -1382,55 +1564,55 @@ __device__ INLINEDEF Nd4jLong *cuMalloc(Nd4jLong *buffer, long size) { * @param startNum the start number for the strides * @return the strides for a matrix of n dimensions */ - INLINEDEF _CUDA_HD Nd4jLong * calcStrides(Nd4jLong const *shape, int rank, int startNum) { - - traceNew(7); - - Nd4jLong *stride = new Nd4jLong[rank]; +INLINEDEF _CUDA_HD Nd4jLong *calcStrides(Nd4jLong const *shape, int rank, + int startNum) { + traceNew(7); - if (rank == 1) { - stride[0] = 1; - return stride; - } + Nd4jLong *stride = new Nd4jLong[rank]; + if (rank == 1) { + stride[0] = 1; + return stride; + } - // if (shape::isVector(shape, rank)) { - // for (int i = 0; i < 2; i++) - // stride[i] = 1; - // return stride; + // if (shape::isVector(shape, rank)) { + // for (int i = 0; i < 2; i++) + // stride[i] = 1; + // return stride; - // } + // } - Nd4jLong st = startNum; - for (int j = rank - 1; j >= 0; j--) { - stride[j] = st; - st *= shape[j]; - } + Nd4jLong st = startNum; + for (int j = rank - 1; j >= 0; j--) { + stride[j] = st; + st *= shape[j]; + } - return stride; - } + return stride; +} - INLINEDEF _CUDA_HD Nd4jLong * calcStrides(Nd4jLong const* shape, int rank, int startNum, Nd4jLong* ret) { - if (rank == 1) { - ret[0] = 1; - return ret; - } +INLINEDEF _CUDA_HD Nd4jLong *calcStrides(Nd4jLong const *shape, int rank, + int startNum, Nd4jLong *ret) { + if (rank == 1) { + ret[0] = 1; + return ret; + } - // if (shape::isVector(shape, rank)) { - // for (int i = 0; i < 2; i++) - // ret[i] = 1; - // return ret; + // if (shape::isVector(shape, rank)) { + // for (int i = 0; i < 2; i++) + // ret[i] = 1; + // return ret; - // } + // } - Nd4jLong st = startNum; - for (int j = rank - 1; j >= 0; j--) { - ret[j] = st; - st *= shape[j]; - } + Nd4jLong st = startNum; + for (int j = rank - 1; j >= 0; j--) { + ret[j] = st; + st *= shape[j]; + } - return ret; - } + return ret; +} /** * Computes the standard packed array strides for a given shape. @@ -1439,13 +1621,15 @@ __device__ INLINEDEF Nd4jLong *cuMalloc(Nd4jLong *buffer, long size) { * @param startNum the start number for the strides * @return the strides for a matrix of n dimensions */ - INLINEDEF _CUDA_HD Nd4jLong * calcStridesFortran(Nd4jLong const* shape, int rank) { - return calcStridesFortran(shape, rank, 1); - } +INLINEDEF _CUDA_HD Nd4jLong *calcStridesFortran(Nd4jLong const *shape, + int rank) { + return calcStridesFortran(shape, rank, 1); +} - INLINEDEF _CUDA_HD Nd4jLong * calcStridesFortran(Nd4jLong const* shape, int rank, Nd4jLong* ret) { - return calcStridesFortran(shape, rank, 1, ret); - } +INLINEDEF _CUDA_HD Nd4jLong *calcStridesFortran(Nd4jLong const *shape, int rank, + Nd4jLong *ret) { + return calcStridesFortran(shape, rank, 1, ret); +} /** * Computes the standard packed array strides for a given shape. @@ -1454,423 +1638,436 @@ __device__ INLINEDEF Nd4jLong *cuMalloc(Nd4jLong *buffer, long size) { * @param startNum the start number for the strides * @return the strides for a matrix of n dimensions */ - INLINEDEF _CUDA_HD Nd4jLong* calcStrides(Nd4jLong const *shape, int rank) { - return calcStrides(shape, rank, 1); - } +INLINEDEF _CUDA_HD Nd4jLong *calcStrides(Nd4jLong const *shape, int rank) { + return calcStrides(shape, rank, 1); +} - INLINEDEF _CUDA_HD Nd4jLong* calcStrides(Nd4jLong const *shape, int rank, Nd4jLong* ret) { - return calcStrides(shape, rank, 1, ret); - } +INLINEDEF _CUDA_HD Nd4jLong *calcStrides(Nd4jLong const *shape, int rank, + Nd4jLong *ret) { + return calcStrides(shape, rank, 1, ret); +} ////////////////////////////////////////////////////////////////////// - INLINEDEF _CUDA_HD void updateStrides(Nd4jLong *shapeInfo, const char order) { - int rank = shapeInfo[0]; - int doubleRank = 2*rank; - - if (rank > 0) { - if (order == 'c') { - shapeInfo[doubleRank] = 1; // set unity as last stride for c order - for (int j = 1; j < rank; ++j) { - shapeInfo[doubleRank - j] = shapeInfo[doubleRank - j + 1] * shapeInfo[rank + 1 - j]; - } - } else { - shapeInfo[rank + 1] = 1; // set unity as first stride for f order - for (int j = rank + 1; j < doubleRank; ++j) { - shapeInfo[j + 1] = shapeInfo[j] * shapeInfo[j - rank]; - } - } - } - // set last 2 elements in shapeInfo - shapeInfo[doubleRank + 2] = 1; - shapeInfo[doubleRank + 3] = (int)order; - } +INLINEDEF _CUDA_HD void updateStrides(Nd4jLong *shapeInfo, const char order) { + int rank = shapeInfo[0]; + int doubleRank = 2 * rank; + + if (rank > 0) { + if (order == 'c') { + shapeInfo[doubleRank] = 1; // set unity as last stride for c order + for (int j = 1; j < rank; ++j) { + shapeInfo[doubleRank - j] = + shapeInfo[doubleRank - j + 1] * shapeInfo[rank + 1 - j]; + } + } else { + shapeInfo[rank + 1] = 1; // set unity as first stride for f order + for (int j = rank + 1; j < doubleRank; ++j) { + shapeInfo[j + 1] = shapeInfo[j] * shapeInfo[j - rank]; + } + } + } + // set last 2 elements in shapeInfo + shapeInfo[doubleRank + 2] = 1; + shapeInfo[doubleRank + 3] = (int)order; +} ////////////////////////////////////////////////////////////////////// - INLINEDEF _CUDA_HD void updateStrides(const int rank, const Nd4jLong *shapeOnly, Nd4jLong *stridesOnly, const char order) { - - if (rank > 0) { - if (order == 'c') { - stridesOnly[rank - 1] = 1; // set unity as last stride for c order - for (int j = 1; j < rank; ++j) - stridesOnly[rank - 1 - j] = stridesOnly[rank - j] * shapeOnly[rank - j]; - } - else { - stridesOnly[0] = 1; // set unity as first stride for f order - for (int j = 1; j < rank; ++j) { - stridesOnly[j] = stridesOnly[j - 1] * shapeOnly[j - 1]; - } - } - } - } - - -// check whether input dimensions are permuted, not permuted dimensions order have to be 0,....,rank-1 - template - INLINEDEF _CUDA_HD bool isDimPermuted(const T* dimensions, const Nd4jLong dimSize ) { - for(int i=0; i dimensions[i+1]) - return true; +INLINEDEF _CUDA_HD void updateStrides(const int rank, const Nd4jLong *shapeOnly, + Nd4jLong *stridesOnly, const char order) { + if (rank > 0) { + if (order == 'c') { + stridesOnly[rank - 1] = 1; // set unity as last stride for c order + for (int j = 1; j < rank; ++j) + stridesOnly[rank - 1 - j] = stridesOnly[rank - j] * shapeOnly[rank - j]; + } else { + stridesOnly[0] = 1; // set unity as first stride for f order + for (int j = 1; j < rank; ++j) { + stridesOnly[j] = stridesOnly[j - 1] * shapeOnly[j - 1]; + } + } + } +} - return false; - } +// check whether input dimensions are permuted, not permuted dimensions order +// have to be 0,....,rank-1 +template +INLINEDEF _CUDA_HD bool isDimPermuted(const T *dimensions, + const Nd4jLong dimSize) { + for (int i = 0; i < dimSize - 1; ++i) + if (dimensions[i] > dimensions[i + 1]) return true; + return false; +} /** * @param toCopy the shape to copy * @return a copy of the original struct */ - INLINEDEF _CUDA_HD ShapeInformation *shapeCopy( ShapeInformation *toCopy) { - auto copy = new ShapeInformation; - - traceNew(8); +INLINEDEF _CUDA_HD ShapeInformation *shapeCopy(ShapeInformation *toCopy) { + auto copy = new ShapeInformation; - copy->shape = new Nd4jLong[toCopy->rank]; + traceNew(8); - memcpy(copy->shape, toCopy->shape, toCopy->rank * sizeof(Nd4jLong)); + copy->shape = new Nd4jLong[toCopy->rank]; - traceNew(9); + memcpy(copy->shape, toCopy->shape, toCopy->rank * sizeof(Nd4jLong)); - copy->stride = new Nd4jLong[toCopy->rank]; - for (int i = 0; i < toCopy->rank; i++) { - copy->stride[i] = toCopy->stride[i]; - } - copy->order = toCopy->order; - copy->rank = toCopy->rank; - copy->offset = toCopy->offset; - copy->elementWiseStride = toCopy->elementWiseStride; - return copy; - } + traceNew(9); - INLINEDEF _CUDA_HD int computeElementWiseStride(int rank, Nd4jLong const* shape, Nd4jLong const* stride, int isFOrder) { - if (rank == 0) - return 1; + copy->stride = new Nd4jLong[toCopy->rank]; + for (int i = 0; i < toCopy->rank; i++) { + copy->stride[i] = toCopy->stride[i]; + } + copy->order = toCopy->order; + copy->rank = toCopy->rank; + copy->offset = toCopy->offset; + copy->elementWiseStride = toCopy->elementWiseStride; + return copy; +} - if(shape::isVector(shape,rank)) { - return stride[rank - 1]; +INLINEDEF _CUDA_HD int computeElementWiseStride(int rank, Nd4jLong const *shape, + Nd4jLong const *stride, + int isFOrder) { + if (rank == 0) return 1; + + if (shape::isVector(shape, rank)) { + return stride[rank - 1]; + } + + else { + int oldnd; + Nd4jLong *oldDims = shape::copyOf(rank, shape); + Nd4jLong *oldStrides = shape::copyOf(rank, stride); + Nd4jLong np, op, last_stride; + Nd4jLong oldStart, oldStop, ok, newStart, newStop, nk; + + traceNew(10); + + auto newStrides = new Nd4jLong[rank]; + oldnd = 0; + // set the shape to be 1 x length + int newShapeRank = 2; + auto newShape = new Nd4jLong[newShapeRank]; + newShape[0] = 1; + newShape[1] = shape::prodLong(shape, rank); + + /* + * Remove axes with dimension 1 from the old array. They have no effect + * but would need special cases since their strides do not matter. + */ + for (oldStart = 0; oldStart < rank; oldStart++) { + if (shape[oldStart] != 1) { + oldDims[oldnd] = shape[oldStart]; + oldStrides[oldnd] = stride[oldStart]; + oldnd++; + } + } + + np = 1; + for (newStart = 0; newStart < newShapeRank; newStart++) { + np *= newShape[newStart]; + } + op = 1; + for (oldStart = 0; oldStart < oldnd; oldStart++) { + op *= oldDims[oldStart]; + } + if (np != op) { + /* different total sizes; no hope */ + delete[] newStrides; + delete[] newShape; + delete[] oldStrides; + delete[] oldDims; + return 0; + } + + if (np == 0) { + /* the current code does not handle 0-sized arrays, so give up */ + delete[] newStrides; + delete[] newShape; + delete[] oldStrides; + delete[] oldDims; + return 0; + } + + /* oldStart to oldStop and newStart to newStop give the axis ranges + * currently worked with */ + oldStart = 0; + oldStop = 1; + newStart = 0; + newStop = 1; + while (newStart < newShapeRank && oldStart < oldnd) { + np = newShape[newStart]; + op = oldDims[oldStart]; + + while (np != op) { + if (np < op) { + /* Misses trailing 1s, these are handled later */ + np *= newShape[newStop++]; + } else { + op *= oldDims[oldStop++]; } + } - else { - int oldnd; - Nd4jLong *oldDims = shape::copyOf(rank, shape); - Nd4jLong *oldStrides = shape::copyOf(rank, stride); - Nd4jLong np, op, last_stride; - Nd4jLong oldStart, oldStop, ok, newStart, newStop, nk; - - traceNew(10); - - auto newStrides = new Nd4jLong[rank]; - oldnd = 0; - //set the shape to be 1 x length - int newShapeRank = 2; - auto newShape = new Nd4jLong[newShapeRank]; - newShape[0] = 1; - newShape[1] = shape::prodLong(shape, rank); - - /* - * Remove axes with dimension 1 from the old array. They have no effect - * but would need special cases since their strides do not matter. - */ - for (oldStart = 0; oldStart < rank; oldStart++) { - if (shape[oldStart] != 1) { - oldDims[oldnd] = shape[oldStart]; - oldStrides[oldnd] = stride[oldStart]; - oldnd++; - } - } - - np = 1; - for (newStart = 0; newStart < newShapeRank; newStart++) { - np *= newShape[newStart]; - } - op = 1; - for (oldStart = 0; oldStart < oldnd; oldStart++) { - op *= oldDims[oldStart]; - } - if (np != op) { -/* different total sizes; no hope */ - delete[] newStrides; - delete[] newShape; - delete[] oldStrides; - delete[] oldDims; - return 0; - } - - if (np == 0) { -/* the current code does not handle 0-sized arrays, so give up */ - delete[] newStrides; - delete[] newShape; - delete[] oldStrides; - delete[] oldDims; - return 0; - } - -/* oldStart to oldStop and newStart to newStop give the axis ranges currently worked with */ - oldStart = 0; - oldStop = 1; - newStart = 0; - newStop = 1; - while (newStart < newShapeRank && oldStart < oldnd) { - np = newShape[newStart]; - op = oldDims[oldStart]; - - while (np != op) { - if (np < op) { -/* Misses trailing 1s, these are handled later */ - np *= newShape[newStop++]; - } else { - op *= oldDims[oldStop++]; - } - } - -/* Check whether the original axes can be combined */ - for (ok = oldStart; ok < oldStop - 1; ok++) { - if (isFOrder) { - if (oldStrides[ok + 1] != oldDims[ok] * oldStrides[ok]) { -/* not contiguous enough */ - delete[] newStrides; - delete[] newShape; - delete[] oldStrides; - delete[] oldDims; - return 0; - } - } else { -/* C order */ - if (oldStrides[ok] != oldDims[ok + 1] * oldStrides[ok + 1]) { -/* not contiguous enough */ - delete[] newStrides; - delete[] newShape; - delete[] oldStrides; - delete[] oldDims; - return 0; - } - } - } - -/* Calculate new strides for all axes currently worked with */ - if (isFOrder) { - newStrides[newStart] = oldStrides[oldStart]; - for (nk = newStart + 1; nk < newStop; nk++) { - newStrides[nk] = newStrides[nk - 1] * newShape[nk - 1]; - } - } else { -/* C order */ - newStrides[newStop - 1] = oldStrides[oldStop - 1]; - for (nk = newStop - 1; nk > newStart; nk--) { - newStrides[nk - 1] = newStrides[nk] * newShape[nk]; - } - } - newStart = newStop++; - oldStart = oldStop++; - } - -/* - * Set strides corresponding to trailing 1s of the new shape. - */ - if (newStart >= 1) { - last_stride = newStrides[newStart - 1]; - } else { - last_stride = stride[rank - 1]; - } - if (isFOrder) { - if (newStart >= 1) - last_stride *= newShape[newStart - 1]; - } - for (nk = newStart; nk < newShapeRank; nk++) { - newStrides[nk] = last_stride; - } -//returns the last element of the new stride array - int ret = last_stride; + /* Check whether the original axes can be combined */ + for (ok = oldStart; ok < oldStop - 1; ok++) { + if (isFOrder) { + if (oldStrides[ok + 1] != oldDims[ok] * oldStrides[ok]) { + /* not contiguous enough */ + delete[] newStrides; + delete[] newShape; + delete[] oldStrides; + delete[] oldDims; + return 0; + } + } else { + /* C order */ + if (oldStrides[ok] != oldDims[ok + 1] * oldStrides[ok + 1]) { + /* not contiguous enough */ delete[] newStrides; delete[] newShape; delete[] oldStrides; delete[] oldDims; - return ret; + return 0; + } } + } - + /* Calculate new strides for all axes currently worked with */ + if (isFOrder) { + newStrides[newStart] = oldStrides[oldStart]; + for (nk = newStart + 1; nk < newStop; nk++) { + newStrides[nk] = newStrides[nk - 1] * newShape[nk - 1]; + } + } else { + /* C order */ + newStrides[newStop - 1] = oldStrides[oldStop - 1]; + for (nk = newStop - 1; nk > newStart; nk--) { + newStrides[nk - 1] = newStrides[nk] * newShape[nk]; + } + } + newStart = newStop++; + oldStart = oldStop++; } - INLINEDEF _CUDA_HD int computeElementWiseStride(int rank, Nd4jLong const* shape, Nd4jLong const* stride, int isFOrder, - Nd4jLong const* dimension, int dimensionLength) { - if(dimensionLength == 1) { - return stride[dimension[0]]; - } - return 0; + /* + * Set strides corresponding to trailing 1s of the new shape. + */ + if (newStart >= 1) { + last_stride = newStrides[newStart - 1]; + } else { + last_stride = stride[rank - 1]; + } + if (isFOrder) { + if (newStart >= 1) last_stride *= newShape[newStart - 1]; + } + for (nk = newStart; nk < newShapeRank; nk++) { + newStrides[nk] = last_stride; + } + // returns the last element of the new stride array + int ret = last_stride; + delete[] newStrides; + delete[] newShape; + delete[] oldStrides; + delete[] oldDims; + return ret; + } +} - } +INLINEDEF _CUDA_HD int computeElementWiseStride(int rank, Nd4jLong const *shape, + Nd4jLong const *stride, + int isFOrder, + Nd4jLong const *dimension, + int dimensionLength) { + if (dimensionLength == 1) { + return stride[dimension[0]]; + } + return 0; +} /** * Get the shape info buffer * for the given rank and shape. */ - INLINEDEF _CUDA_HD Nd4jLong *shapeBuffer(int rank, sd::DataType dtype, Nd4jLong const* shape) { - Nd4jLong *stride = shape::calcStrides(shape, rank); - - traceNew(11); - - auto shapeInfo = new shape::ShapeInformation(); - shapeInfo->shape = const_cast(shape); - shapeInfo->stride = stride; - shapeInfo->offset = 0; - shapeInfo->rank = rank; - int elementWiseStride = shape::computeElementWiseStride(rank, shape, stride, 0); - shapeInfo->order = 'c'; - shapeInfo->elementWiseStride = elementWiseStride; - auto shapeInfoBuffer = shape::toShapeBuffer(shapeInfo); - delete[] stride; - delete shapeInfo; - sd::ArrayOptions::setDataType(shapeInfoBuffer, dtype); - return shapeInfoBuffer; - } - - /** - * This is special method, it returns ONLY 2D shapebuffer. - * - * This method is used only for SoftMax - */ - INLINEDEF _CUDA_HD Nd4jLong *shapeBuffer(int rank, sd::DataType dtype, Nd4jLong const* shape, Nd4jLong *buffer) { - Nd4jLong stride[MAX_RANK]; - shape::calcStrides(shape,rank, stride); - - - shape::ShapeInformation shapeInfo; - shapeInfo.shape = const_cast(shape); - shapeInfo.stride = stride; - shapeInfo.offset = 0; - shapeInfo.rank = rank; - auto elementWiseStride = shape::computeElementWiseStride(rank, shape, stride, 0); - - shapeInfo.order = 'c'; - shapeInfo.elementWiseStride = elementWiseStride; - shape::toShapeBuffer(&shapeInfo, buffer); - sd::ArrayOptions::setDataType(buffer, dtype); - return buffer; - } +INLINEDEF _CUDA_HD Nd4jLong *shapeBuffer(int rank, sd::DataType dtype, + Nd4jLong const *shape) { + Nd4jLong *stride = shape::calcStrides(shape, rank); + + traceNew(11); + + auto shapeInfo = new shape::ShapeInformation(); + shapeInfo->shape = const_cast(shape); + shapeInfo->stride = stride; + shapeInfo->offset = 0; + shapeInfo->rank = rank; + int elementWiseStride = + shape::computeElementWiseStride(rank, shape, stride, 0); + shapeInfo->order = 'c'; + shapeInfo->elementWiseStride = elementWiseStride; + auto shapeInfoBuffer = shape::toShapeBuffer(shapeInfo); + delete[] stride; + delete shapeInfo; + sd::ArrayOptions::setDataType(shapeInfoBuffer, dtype); + return shapeInfoBuffer; +} /** -* Get the shape info buffer -* for the given rank and shape. -*/ - INLINEDEF _CUDA_HD Nd4jLong *shapeBufferFortran(int rank, sd::DataType dtype, Nd4jLong const* shape) { - auto stride = shape::calcStridesFortran(shape,rank); - - traceNew(12); - - auto shapeInfo = new shape::ShapeInformation(); - shapeInfo->shape = const_cast(shape); - shapeInfo->stride = stride; - shapeInfo->offset = 0; - shapeInfo->rank = rank; - int elementWiseStride = shape::computeElementWiseStride(rank, shape, stride, 0); - - shapeInfo->order = 'f'; - shapeInfo->elementWiseStride = elementWiseStride; - auto shapeInfoBuffer = shape::toShapeBuffer(shapeInfo); - delete[] stride; - delete shapeInfo; - sd::ArrayOptions::setDataType(shapeInfoBuffer, dtype); - return shapeInfoBuffer; - } - - INLINEDEF _CUDA_HD Nd4jLong *shapeBufferFortran(int rank, sd::DataType dtype, Nd4jLong const *shape, Nd4jLong *output) { - Nd4jLong stride[MAX_RANK]; - shape::calcStridesFortran(shape,rank, stride); - + * This is special method, it returns ONLY 2D shapebuffer. + * + * This method is used only for SoftMax + */ +INLINEDEF _CUDA_HD Nd4jLong *shapeBuffer(int rank, sd::DataType dtype, + Nd4jLong const *shape, + Nd4jLong *buffer) { + Nd4jLong stride[MAX_RANK]; + shape::calcStrides(shape, rank, stride); + + shape::ShapeInformation shapeInfo; + shapeInfo.shape = const_cast(shape); + shapeInfo.stride = stride; + shapeInfo.offset = 0; + shapeInfo.rank = rank; + auto elementWiseStride = + shape::computeElementWiseStride(rank, shape, stride, 0); + + shapeInfo.order = 'c'; + shapeInfo.elementWiseStride = elementWiseStride; + shape::toShapeBuffer(&shapeInfo, buffer); + sd::ArrayOptions::setDataType(buffer, dtype); + return buffer; +} - shape::ShapeInformation shapeInfo; - shapeInfo.shape = const_cast(shape); - shapeInfo.stride = stride; - shapeInfo.offset = 0; - shapeInfo.rank = rank; - auto elementWiseStride = shape::computeElementWiseStride(rank, shape, stride, 0); +/** + * Get the shape info buffer + * for the given rank and shape. + */ +INLINEDEF _CUDA_HD Nd4jLong *shapeBufferFortran(int rank, sd::DataType dtype, + Nd4jLong const *shape) { + auto stride = shape::calcStridesFortran(shape, rank); + + traceNew(12); + + auto shapeInfo = new shape::ShapeInformation(); + shapeInfo->shape = const_cast(shape); + shapeInfo->stride = stride; + shapeInfo->offset = 0; + shapeInfo->rank = rank; + int elementWiseStride = + shape::computeElementWiseStride(rank, shape, stride, 0); + + shapeInfo->order = 'f'; + shapeInfo->elementWiseStride = elementWiseStride; + auto shapeInfoBuffer = shape::toShapeBuffer(shapeInfo); + delete[] stride; + delete shapeInfo; + sd::ArrayOptions::setDataType(shapeInfoBuffer, dtype); + return shapeInfoBuffer; +} - shapeInfo.order = 'f'; - shapeInfo.elementWiseStride = elementWiseStride; - shape::toShapeBuffer(&shapeInfo, output); - sd::ArrayOptions::setDataType(output, dtype); - return output; - } +INLINEDEF _CUDA_HD Nd4jLong *shapeBufferFortran(int rank, sd::DataType dtype, + Nd4jLong const *shape, + Nd4jLong *output) { + Nd4jLong stride[MAX_RANK]; + shape::calcStridesFortran(shape, rank, stride); + + shape::ShapeInformation shapeInfo; + shapeInfo.shape = const_cast(shape); + shapeInfo.stride = stride; + shapeInfo.offset = 0; + shapeInfo.rank = rank; + auto elementWiseStride = + shape::computeElementWiseStride(rank, shape, stride, 0); + + shapeInfo.order = 'f'; + shapeInfo.elementWiseStride = elementWiseStride; + shape::toShapeBuffer(&shapeInfo, output); + sd::ArrayOptions::setDataType(output, dtype); + return output; +} ////////////////////////////////////////////////////////////////////// -INLINEDEF _CUDA_HD Nd4jLong coords2index(const Nd4jLong *shapeInfo, const Nd4jLong *indices) { - - Nd4jLong index, shift = 1;; - - index = indices[shapeInfo[0] - 1]; - for(uint i = shapeInfo[0]; i > 1; --i) { - shift *= shapeInfo[i]; - index += shift * indices[i - 2]; - } - - return index; +INLINEDEF _CUDA_HD Nd4jLong coords2index(const Nd4jLong *shapeInfo, + const Nd4jLong *indices) { + Nd4jLong index, shift = 1; + ; + + index = indices[shapeInfo[0] - 1]; + for (uint i = shapeInfo[0]; i > 1; --i) { + shift *= shapeInfo[i]; + index += shift * indices[i - 2]; + } + + return index; } ////////////////////////////////////////////////////////////////////// -INLINEDEF _CUDA_HD Nd4jLong coords2index(const Nd4jLong *shapeInfo, const int *coords) { - - Nd4jLong index, shift = 1;; - - index = coords[shapeInfo[0] - 1]; - for(uint i = shapeInfo[0]; i > 1; --i) { - shift *= shapeInfo[i]; - index += shift * coords[i - 2]; - } - - return index; +INLINEDEF _CUDA_HD Nd4jLong coords2index(const Nd4jLong *shapeInfo, + const int *coords) { + Nd4jLong index, shift = 1; + ; + + index = coords[shapeInfo[0] - 1]; + for (uint i = shapeInfo[0]; i > 1; --i) { + shift *= shapeInfo[i]; + index += shift * coords[i - 2]; + } + + return index; } ////////////////////////////////////////////////////////////////////// -INLINEDEF _CUDA_HD Nd4jLong coords2index(const Nd4jLong *shapeInfo, const uint *coords) { - - Nd4jLong index, shift = 1;; - - index = coords[shapeInfo[0] - 1]; - for(uint i = shapeInfo[0]; i > 1; --i) { - shift *= shapeInfo[i]; - index += shift * coords[i - 2]; - } - - return index; +INLINEDEF _CUDA_HD Nd4jLong coords2index(const Nd4jLong *shapeInfo, + const uint *coords) { + Nd4jLong index, shift = 1; + ; + + index = coords[shapeInfo[0] - 1]; + for (uint i = shapeInfo[0]; i > 1; --i) { + shift *= shapeInfo[i]; + index += shift * coords[i - 2]; + } + + return index; } ////////////////////////////////////////////////////////////////////// -INLINEDEF _CUDA_HD Nd4jLong coords2index(const int rank, const Nd4jLong *shape, const int *indices) { - - Nd4jLong index, shift = 1;; - - index = indices[rank - 1]; - for(uint i = rank - 1; i >= 1; --i) { - shift *= shape[i]; - index += shift * indices[i - 1]; - } - - return index; +INLINEDEF _CUDA_HD Nd4jLong coords2index(const int rank, const Nd4jLong *shape, + const int *indices) { + Nd4jLong index, shift = 1; + ; + + index = indices[rank - 1]; + for (uint i = rank - 1; i >= 1; --i) { + shift *= shape[i]; + index += shift * indices[i - 1]; + } + + return index; } -INLINEDEF _CUDA_HD Nd4jLong coords2index(const Nd4jLong *shapeInfo, const int *coords, const int dimsSize, const int* tadDims) { +INLINEDEF _CUDA_HD Nd4jLong coords2index(const Nd4jLong *shapeInfo, + const int *coords, const int dimsSize, + const int *tadDims) { + Nd4jLong index, shift = 1; + ; - Nd4jLong index, shift = 1;; + index = coords[tadDims[dimsSize - 1]]; + for (uint i = dimsSize - 1; i >= 1; --i) { + shift *= shapeInfo[tadDims[i]]; + index += shift * coords[i - 1]; + } - index = coords[tadDims[dimsSize - 1]]; - for(uint i = dimsSize - 1; i >= 1; --i) { - shift *= shapeInfo[tadDims[i]]; - index += shift * coords[i - 1]; - } - - return index; + return index; } template - INLINEDEF _CUDA_HD void fill(T* buffer, T value, Nd4jLong length) { - - PRAGMA_OMP_SIMD - for (int e = 0; e < length; e++) - buffer[e] = value; - } - +INLINEDEF _CUDA_HD void fill(T *buffer, T value, Nd4jLong length) { + PRAGMA_OMP_SIMD + for (int e = 0; e < length; e++) buffer[e] = value; +} // ////////////////////////////////////////////////////////////////////// -// INLINEDEF _CUDA_HD Nd4jLong getIndexOffset(Nd4jLong index, const Nd4jLong *shapeInfo, Nd4jLong arrLen) { +// INLINEDEF _CUDA_HD Nd4jLong getIndexOffset(Nd4jLong index, const Nd4jLong +// *shapeInfo, Nd4jLong arrLen) { // const Nd4jLong ews = shapeInfo[shapeInfo[0] + shapeInfo[0] + 2]; @@ -1892,7 +2089,8 @@ template // return offset; // } -// INLINEDEF _CUDA_HD uint getIndexOffset(uint index, const uint *shapeInfo, uint arrLen) { +// INLINEDEF _CUDA_HD uint getIndexOffset(uint index, const uint *shapeInfo, +// uint arrLen) { // const uint rank = shapeInfo[0]; // const uint ews = shapeInfo[rank + rank + 2]; @@ -1916,61 +2114,58 @@ template // } ////////////////////////////////////////////////////////////////////// -INLINEDEF _CUDA_HD Nd4jLong getIndexOffset(Nd4jLong index, const Nd4jLong *shapeInfo) { - - if (shapeInfo[2 * shapeInfo[0] + 3] == 99) { - - const Nd4jLong ews = shapeInfo[2 * shapeInfo[0] + 2]; - if (ews == 1) - return index; - else if(ews > 1) - return ews * index; - } - - Nd4jLong offset = 0; - - for(uint i = shapeInfo[0]; i > 1; --i) { - offset += (index % shapeInfo[i]) * shapeInfo[i + shapeInfo[0]]; - index /= shapeInfo[i]; - } - - offset += index * shapeInfo[1 + shapeInfo[0]]; // last iteration - - return offset; +INLINEDEF _CUDA_HD Nd4jLong getIndexOffset(Nd4jLong index, + const Nd4jLong *shapeInfo) { + if (shapeInfo[2 * shapeInfo[0] + 3] == 99) { + const Nd4jLong ews = shapeInfo[2 * shapeInfo[0] + 2]; + if (ews == 1) + return index; + else if (ews > 1) + return ews * index; + } + + Nd4jLong offset = 0; + + for (uint i = shapeInfo[0]; i > 1; --i) { + offset += (index % shapeInfo[i]) * shapeInfo[i + shapeInfo[0]]; + index /= shapeInfo[i]; + } + + offset += index * shapeInfo[1 + shapeInfo[0]]; // last iteration + + return offset; } ////////////////////////////////////////////////////////////////////// INLINEDEF _CUDA_HD uint getIndexOffset(uint index, const uint *shapeInfo) { + if (shapeInfo[2 * shapeInfo[0] + 3] == 99) { + const Nd4jLong ews = shapeInfo[2 * shapeInfo[0] + 2]; + if (ews == 1) + return index; + else if (ews > 1) + return ews * index; + } - if (shapeInfo[2 * shapeInfo[0] + 3] == 99) { - - const Nd4jLong ews = shapeInfo[2 * shapeInfo[0] + 2]; - if (ews == 1) - return index; - else if(ews > 1) - return ews * index; - } - - uint offset = 0; + uint offset = 0; - for(uint i = shapeInfo[0]; i > 1; --i) { - offset += (index % shapeInfo[i]) * shapeInfo[i + shapeInfo[0]]; - index /= shapeInfo[i]; - } + for (uint i = shapeInfo[0]; i > 1; --i) { + offset += (index % shapeInfo[i]) * shapeInfo[i + shapeInfo[0]]; + index /= shapeInfo[i]; + } - offset += index * shapeInfo[1 + shapeInfo[0]]; // last iteration + offset += index * shapeInfo[1 + shapeInfo[0]]; // last iteration - return offset; + return offset; } - ////////////////////////////////////////////////////////////////////// -INLINEDEF _CUDA_HD Nd4jLong indexOffset(Nd4jLong index, const Nd4jLong* lShapeInfo, const uint* uShapeInfo, const bool useUnsigned) { - - if(useUnsigned) - return getIndexOffset(static_cast(index), uShapeInfo); +INLINEDEF _CUDA_HD Nd4jLong indexOffset(Nd4jLong index, + const Nd4jLong *lShapeInfo, + const uint *uShapeInfo, + const bool useUnsigned) { + if (useUnsigned) return getIndexOffset(static_cast(index), uShapeInfo); - return getIndexOffset(index, lShapeInfo); + return getIndexOffset(index, lShapeInfo); } /** @@ -1980,14 +2175,15 @@ INLINEDEF _CUDA_HD Nd4jLong indexOffset(Nd4jLong index, const Nd4jLong* lShapeIn * @param rearrange * @return */ - INLINEDEF _CUDA_HD Nd4jLong *doPermuteSwap(int length, Nd4jLong *shape, int *rearrange) { - traceNew(16); - Nd4jLong *ret = new Nd4jLong[length]; - for (int i = 0; i < length; i++) { - ret[i] = shape[rearrange[i]]; - } - return ret; - } +INLINEDEF _CUDA_HD Nd4jLong *doPermuteSwap(int length, Nd4jLong *shape, + int *rearrange) { + traceNew(16); + Nd4jLong *ret = new Nd4jLong[length]; + for (int i = 0; i < length; i++) { + ret[i] = shape[rearrange[i]]; + } + return ret; +} /** * @@ -1996,125 +2192,128 @@ INLINEDEF _CUDA_HD Nd4jLong indexOffset(Nd4jLong index, const Nd4jLong* lShapeIn * @param rearrange * @return */ - INLINEDEF _CUDA_HD void doPermuteSwap(int length, Nd4jLong **shape, int *rearrange) { - if(length == 1) { - return; - } - else { - Nd4jLong *shapeDeref = *shape; - if(shape::prodLong(shapeDeref,length) < 2) { - return; - } - } +INLINEDEF _CUDA_HD void doPermuteSwap(int length, Nd4jLong **shape, + int *rearrange) { + if (length == 1) { + return; + } else { + Nd4jLong *shapeDeref = *shape; + if (shape::prodLong(shapeDeref, length) < 2) { + return; + } + } + + bool inOrder = true; + for (int i = 0; i < length - 1; i++) { + inOrder = inOrder && rearrange[i] + 1 == rearrange[i + 1]; + } + + // all in order, nothing to do + if (inOrder) return; + + Nd4jLong *shapeDeref = *shape; + // we know they are just reversed, dimension length of 2 + if (length == 2) { + auto shapeFirst = shapeDeref[0]; + auto shapeSecond = shapeDeref[1]; + shapeDeref[0] = shapeSecond; + shapeDeref[1] = shapeFirst; + return; + } else if (length == 1) { + // no permute + return; + } + + auto temp = new Nd4jLong[length]; + memcpy(temp, shapeDeref, sizeof(Nd4jLong) * length); + for (int i = 0; i < length; i++) { + shapeDeref[i] = temp[rearrange[i]]; + } + + delete[] temp; +} - bool inOrder = true; - for(int i = 0; i < length - 1; i++) { - inOrder = inOrder && rearrange[i] + 1 == rearrange[i + 1]; +INLINEDEF _CUDA_HD void permuteShapeBufferInPlace(Nd4jLong *shapeBuffer, + int *rearrange, + Nd4jLong *out) { + if (shapeBuffer != out) + memcpy(out, shapeBuffer, + sizeof(Nd4jLong) * shape::shapeInfoLength(shapeBuffer)); - } + shape::doPermuteShapeInfo(out, rearrange); +} - //all in order, nothing to do - if(inOrder) - return; +INLINEDEF _CUDA_HD Nd4jLong *permuteShapeBuffer(Nd4jLong const *shapeBuffer, + int *rearrange) { + auto len = shape::shapeInfoLength(shape::rank(shapeBuffer)); + Nd4jLong *copy = shape::copyOf(len, shapeBuffer); + shape::doPermuteShapeInfo(copy, rearrange); + return copy; +} +INLINEDEF _CUDA_HD void doPermuteShapeInfo(Nd4jLong *shapeInfo, + const int *rearrange, Nd4jLong len) { + if (len == -1) // calculate array length if it is not given + len = shape::length(shapeInfo); - Nd4jLong *shapeDeref = *shape; - //we know they are just reversed, dimension length of 2 - if(length == 2) { - auto shapeFirst = shapeDeref[0]; - auto shapeSecond = shapeDeref[1]; - shapeDeref[0] = shapeSecond; - shapeDeref[1] = shapeFirst; - return; - } - else if(length == 1) { - //no permute - return; - } + // check whether shape is like {1} or {1,1} or {1,1,1,1,...} - in this case we + // don't need permute + if (len == 1) return; - auto temp = new Nd4jLong[length]; - memcpy(temp,shapeDeref,sizeof(Nd4jLong) * length); - for (int i = 0; i < length; i++) { - shapeDeref[i] = temp[rearrange[i]]; - } + const int rank = shape::rank(shapeInfo); - delete[] temp; + // check whether rearrange is like {0,1,2,3,...} - in this case we don't need + // permute as well + bool isPermutNecessary = false; + for (int i = 0; i < rank; ++i) + if (rearrange[i] != i) { + isPermutNecessary = true; + break; } + if (!isPermutNecessary) return; - INLINEDEF _CUDA_HD void permuteShapeBufferInPlace(Nd4jLong *shapeBuffer, int *rearrange, Nd4jLong *out) { - if(shapeBuffer != out) - memcpy(out,shapeBuffer,sizeof(Nd4jLong) * shape::shapeInfoLength(shapeBuffer)); - - shape::doPermuteShapeInfo(out, rearrange); - } - - INLINEDEF _CUDA_HD Nd4jLong *permuteShapeBuffer(Nd4jLong const* shapeBuffer, int* rearrange) { - auto len = shape::shapeInfoLength(shape::rank(shapeBuffer)); - Nd4jLong *copy = shape::copyOf(len, shapeBuffer); - shape::doPermuteShapeInfo(copy,rearrange); - return copy; + // check whether rearrange contains correct indexes + for (int i = 0; i < rank; ++i) + if (rearrange[i] >= rank || rearrange[i] < 0) { + printf( + "shape::doPermuteShapeInfo function failed: rearrange indexes are " + "incorrect !\n"); + return; } - INLINEDEF _CUDA_HD void doPermuteShapeInfo(Nd4jLong *shapeInfo, const int *rearrange, Nd4jLong len) { - - if(len == -1) // calculate array length if it is not given - len = shape::length(shapeInfo); - - //check whether shape is like {1} or {1,1} or {1,1,1,1,...} - in this case we don't need permute - if(len == 1) - return; - - const int rank = shape::rank(shapeInfo); - - // check whether rearrange is like {0,1,2,3,...} - in this case we don't need permute as well - bool isPermutNecessary = false; - for(int i = 0; i < rank; ++i) - if(rearrange[i] != i) { - isPermutNecessary = true; - break; - } - - if(!isPermutNecessary) - return; - - // check whether rearrange contains correct indexes - for(int i = 0; i < rank; ++i) - if(rearrange[i] >= rank || rearrange[i] < 0) { - printf("shape::doPermuteShapeInfo function failed: rearrange indexes are incorrect !\n"); - return; - } - - // if everything is ok then perform permute - auto temp = new Nd4jLong[shape::shapeInfoLength(rank) - 3]; - memcpy(temp, shapeInfo, sizeof(Nd4jLong) * (shape::shapeInfoLength(rank) - 3)); - for (int i = 0; i < rank; ++i) { - shapeInfo[i + 1] = temp[rearrange[i] + 1]; - shapeInfo[i + 1 + rank] = temp[rearrange[i] + 1 + rank]; - } - - shape::checkStridesEwsAndOrder(shapeInfo); + // if everything is ok then perform permute + auto temp = new Nd4jLong[shape::shapeInfoLength(rank) - 3]; + memcpy(temp, shapeInfo, + sizeof(Nd4jLong) * (shape::shapeInfoLength(rank) - 3)); + for (int i = 0; i < rank; ++i) { + shapeInfo[i + 1] = temp[rearrange[i] + 1]; + shapeInfo[i + 1 + rank] = temp[rearrange[i] + 1 + rank]; + } - delete[] temp; - } + shape::checkStridesEwsAndOrder(shapeInfo); + delete[] temp; +} - INLINEDEF _CUDA_HD Nd4jLong *createPermuteIndexes(int originalRank, int *dimension,int dimensionLength) { - int delta = originalRank - dimensionLength; +INLINEDEF _CUDA_HD Nd4jLong *createPermuteIndexes(int originalRank, + int *dimension, + int dimensionLength) { + int delta = originalRank - dimensionLength; - traceNew(17); + traceNew(17); - Nd4jLong *ret = new Nd4jLong[originalRank]; - for(int i = 0; i < delta; i++) { - ret[i] = i + dimensionLength; - } + Nd4jLong *ret = new Nd4jLong[originalRank]; + for (int i = 0; i < delta; i++) { + ret[i] = i + dimensionLength; + } - for(int i = delta; i < originalRank; i++) { - ret[i] = i - delta; - } + for (int i = delta; i < originalRank; i++) { + ret[i] = i - delta; + } - return ret; - } + return ret; +} /** * Get the ordering for the device @@ -2124,56 +2323,50 @@ INLINEDEF _CUDA_HD Nd4jLong indexOffset(Nd4jLong index, const Nd4jLong* lShapeIn * @param elementStride * @return */ - INLINEDEF _CUDA_HD char getOrder(int length, Nd4jLong *shape, Nd4jLong *stride, int elementStride) { - Nd4jLong sd = 1; - int dim = -1; - int i = -1; - int cContiguous = 1; - int isFortran = 1; - - for (i = length - 1; i >= 0; --i) { - dim = shape[i]; - - if (stride[i] != sd) { - cContiguous = 0; - break; - } - /* contiguous, if it got this far */ - if (dim == 0) { - break; - } - sd *= dim; - - } - - /* check if fortran contiguous */ - sd = elementStride; - for (i = 0; i < length; ++i) { - dim = shape[i]; - if (stride[i] != sd) { - isFortran = 0; - } - if (dim == 0) { - break; - } - sd *= dim; - - } - - if (isFortran && cContiguous) - return 'a'; - else if (isFortran && !cContiguous) - return 'f'; - else if (!isFortran && !cContiguous) - return 'c'; - else - return 'c'; - - } - - - - +INLINEDEF _CUDA_HD char getOrder(int length, Nd4jLong *shape, Nd4jLong *stride, + int elementStride) { + Nd4jLong sd = 1; + int dim = -1; + int i = -1; + int cContiguous = 1; + int isFortran = 1; + + for (i = length - 1; i >= 0; --i) { + dim = shape[i]; + + if (stride[i] != sd) { + cContiguous = 0; + break; + } + /* contiguous, if it got this far */ + if (dim == 0) { + break; + } + sd *= dim; + } + + /* check if fortran contiguous */ + sd = elementStride; + for (i = 0; i < length; ++i) { + dim = shape[i]; + if (stride[i] != sd) { + isFortran = 0; + } + if (dim == 0) { + break; + } + sd *= dim; + } + + if (isFortran && cContiguous) + return 'a'; + else if (isFortran && !cContiguous) + return 'f'; + else if (!isFortran && !cContiguous) + return 'c'; + else + return 'c'; +} /** * Ensure that every value in the re arrange @@ -2185,33 +2378,30 @@ INLINEDEF _CUDA_HD Nd4jLong indexOffset(Nd4jLong index, const Nd4jLong* lShapeIn * @return */ - template - INLINEDEF _CUDA_HD int checkArrangeArray(T *arr, int arrLength, int shapeLength) { - if (arrLength != shapeLength) - return -1; - for (int i = 0; i < arrLength; i++) { - if (arr[i] >= arrLength || arr[i] < 0) - return -1; - } - - for (int i = 0; i < arrLength; i++) { - for (int j = 0; j < arrLength; j++) { - if (i != j && arr[i] == arr[j]) - return -1; - } - } +template +INLINEDEF _CUDA_HD int checkArrangeArray(T *arr, int arrLength, + int shapeLength) { + if (arrLength != shapeLength) return -1; + for (int i = 0; i < arrLength; i++) { + if (arr[i] >= arrLength || arr[i] < 0) return -1; + } - return 1; + for (int i = 0; i < arrLength; i++) { + for (int j = 0; j < arrLength; j++) { + if (i != j && arr[i] == arr[j]) return -1; } + } + return 1; +} - INLINEDEF _CUDA_HD void traceNew(int id) { - //printf("new happened: [%i]\n", id); +INLINEDEF _CUDA_HD void traceNew(int id){ +// printf("new happened: [%i]\n", id); #ifndef __CUDACC__ - //fflush(stdout); +// fflush(stdout); #endif - } +} /** * Permute the shape information @@ -2219,18 +2409,16 @@ INLINEDEF _CUDA_HD Nd4jLong indexOffset(Nd4jLong index, const Nd4jLong* lShapeIn * @param rearrange the order to re arrange * @param rank the rank of the rearrange array */ - INLINEDEF _CUDA_HD void permute(ShapeInformation **info, int *rearrange, int rank) { - ShapeInformation *infoDeref = *info; - checkArrangeArray(rearrange, rank, rank); - shape::doPermuteSwap(rank, &infoDeref->shape, rearrange); - shape::doPermuteSwap(rank, &infoDeref->stride, rearrange); - char order = getOrder(rank, - infoDeref->shape, - infoDeref->stride, - infoDeref->elementWiseStride); - infoDeref->order = order; - - } +INLINEDEF _CUDA_HD + void permute(ShapeInformation **info, int *rearrange, int rank) { + ShapeInformation *infoDeref = *info; + checkArrangeArray(rearrange, rank, rank); + shape::doPermuteSwap(rank, &infoDeref->shape, rearrange); + shape::doPermuteSwap(rank, &infoDeref->stride, rearrange); + char order = getOrder(rank, infoDeref->shape, infoDeref->stride, + infoDeref->elementWiseStride); + infoDeref->order = order; +} /** * Returns whether the @@ -2238,182 +2426,176 @@ INLINEDEF _CUDA_HD Nd4jLong indexOffset(Nd4jLong index, const Nd4jLong* lShapeIn * @param shape the shape of the array * @param rank the rank of the shape */ - INLINEDEF _CUDA_HD int isVector(Nd4jLong const* shape, int rank) { - if (rank == 0) - return 0; - - if (rank == 1) - return 1; +INLINEDEF _CUDA_HD int isVector(Nd4jLong const *shape, int rank) { + if (rank == 0) return 0; - if (rank > 2) - return 0; - else if (rank <= 2) { - if (shape[0] == 1 || shape[1] == 1) - return 1; - } - return 0; - } + if (rank == 1) return 1; - INLINEDEF _CUDA_HD bool isLikeVector(Nd4jLong const* shapeInfo, int& posOfNonUnityDim) { - - int numOfNonUnity = 0; - for(int i = 1; i <= shapeInfo[0]; ++i) { - if(shapeInfo[i] != 1) { - ++numOfNonUnity; - posOfNonUnityDim = i-1; - } - } + if (rank > 2) + return 0; + else if (rank <= 2) { + if (shape[0] == 1 || shape[1] == 1) return 1; + } + return 0; +} - return numOfNonUnity == 1 && shapeInfo[0] > 2; +INLINEDEF _CUDA_HD bool isLikeVector(Nd4jLong const *shapeInfo, + int &posOfNonUnityDim) { + int numOfNonUnity = 0; + for (int i = 1; i <= shapeInfo[0]; ++i) { + if (shapeInfo[i] != 1) { + ++numOfNonUnity; + posOfNonUnityDim = i - 1; } + } - INLINEDEF _CUDA_HD bool isCommonVector(const Nd4jLong *shapeInfo, int& posOfNonUnityDim) { + return numOfNonUnity == 1 && shapeInfo[0] > 2; +} - if(rank(shapeInfo) > 0 && length(shapeInfo) == 1) { - posOfNonUnityDim = -1; - return true; - } +INLINEDEF _CUDA_HD bool isCommonVector(const Nd4jLong *shapeInfo, + int &posOfNonUnityDim) { + if (rank(shapeInfo) > 0 && length(shapeInfo) == 1) { + posOfNonUnityDim = -1; + return true; + } - int numOfNonUnity = 0; - for(int i = 1; i <= shapeInfo[0]; ++i) { - if(shapeInfo[i] != 1) { - ++numOfNonUnity; - posOfNonUnityDim = i-1; - } - } - return numOfNonUnity == 1; + int numOfNonUnity = 0; + for (int i = 1; i <= shapeInfo[0]; ++i) { + if (shapeInfo[i] != 1) { + ++numOfNonUnity; + posOfNonUnityDim = i - 1; } + } + return numOfNonUnity == 1; +} - INLINEDEF _CUDA_H Nd4jLong const* detachShape(Nd4jLong const* originalShape) { - Nd4jLong *newShape = new Nd4jLong[shape::shapeInfoLength(originalShape)]; - memcpy(newShape, originalShape, shape::shapeInfoByteLength(originalShape)); - - return newShape; - } +INLINEDEF _CUDA_H Nd4jLong const *detachShape(Nd4jLong const *originalShape) { + Nd4jLong *newShape = new Nd4jLong[shape::shapeInfoLength(originalShape)]; + memcpy(newShape, originalShape, shape::shapeInfoByteLength(originalShape)); + return newShape; +} - INLINEDEF _CUDA_H Nd4jLong* copyShape(Nd4jLong const* originalShape) { - Nd4jLong *newShape = new Nd4jLong[shape::shapeInfoLength(originalShape)]; - memcpy(newShape, originalShape, shape::shapeInfoByteLength(originalShape)); +INLINEDEF _CUDA_H Nd4jLong *copyShape(Nd4jLong const *originalShape) { + Nd4jLong *newShape = new Nd4jLong[shape::shapeInfoLength(originalShape)]; + memcpy(newShape, originalShape, shape::shapeInfoByteLength(originalShape)); - return newShape; - } + return newShape; +} - INLINEDEF _CUDA_HD int isVector(const Nd4jLong *shapeInfo) { - return isVector(shape::shapeOf(const_cast(shapeInfo)), shape::rank(shapeInfo)); - } +INLINEDEF _CUDA_HD int isVector(const Nd4jLong *shapeInfo) { + return isVector(shape::shapeOf(const_cast(shapeInfo)), + shape::rank(shapeInfo)); +} - INLINEDEF _CUDA_HD bool isRowVector(const Nd4jLong *shapeInfo) { - bool isVector = shape::isVector(shapeInfo) == 1; - bool shapeFirstOne = shape::shapeOf(const_cast(shapeInfo))[0] == 1; - return isVector && shapeFirstOne; - } +INLINEDEF _CUDA_HD bool isRowVector(const Nd4jLong *shapeInfo) { + bool isVector = shape::isVector(shapeInfo) == 1; + bool shapeFirstOne = + shape::shapeOf(const_cast(shapeInfo))[0] == 1; + return isVector && shapeFirstOne; +} - INLINEDEF _CUDA_HD bool isColumnVector(const Nd4jLong *shapeInfo) { - bool isVector = shape::isVector(shapeInfo) == 1; - bool shapeFirstOne = shape::shapeOf(shapeInfo)[0] == 1; - return isVector && !shapeFirstOne; - } +INLINEDEF _CUDA_HD bool isColumnVector(const Nd4jLong *shapeInfo) { + bool isVector = shape::isVector(shapeInfo) == 1; + bool shapeFirstOne = shape::shapeOf(shapeInfo)[0] == 1; + return isVector && !shapeFirstOne; +} ////////////////////////////////////////////////////////////////////// -INLINEDEF _CUDA_HD int numOfNonUnitDims(const int rank, const Nd4jLong* inShape) { - - int num = 0; +INLINEDEF _CUDA_HD int numOfNonUnitDims(const int rank, + const Nd4jLong *inShape) { + int num = 0; - for(uint i = 0; i < rank; ++i) - if(inShape[i] != 1) - ++num; + for (uint i = 0; i < rank; ++i) + if (inShape[i] != 1) ++num; - return num; + return num; } - INLINEDEF _CUDA_HD int oneDimEqualToLength(Nd4jLong *shape, int rank) { - for(int i = 0; i < rank; i++) { - if(shape[i] == shape::prodLong(shape,rank)) - return 1; - } +INLINEDEF _CUDA_HD int oneDimEqualToLength(Nd4jLong *shape, int rank) { + for (int i = 0; i < rank; i++) { + if (shape[i] == shape::prodLong(shape, rank)) return 1; + } - return 0; - } + return 0; +} - INLINEDEF _CUDA_HD int oneDimEqualToLength(Nd4jLong *shapeInfo) { - return oneDimEqualToLength(shape::shapeOf(shapeInfo),shape::rank(shapeInfo)); - } +INLINEDEF _CUDA_HD int oneDimEqualToLength(Nd4jLong *shapeInfo) { + return oneDimEqualToLength(shape::shapeOf(shapeInfo), shape::rank(shapeInfo)); +} /** -* Returns whether the -* given shape is a vector or not -* @param shape the shape of the array -* @param rank the rank of the shape -*/ - INLINEDEF _CUDA_HD int isMatrix(const Nd4jLong *shape, int rank) { - if (rank > 2) - return 0; - else if (rank <= 2) { - if (shape[0] == 1 || shape[1] == 1) - return 0; - } - - return 1; - } + * Returns whether the + * given shape is a vector or not + * @param shape the shape of the array + * @param rank the rank of the shape + */ +INLINEDEF _CUDA_HD int isMatrix(const Nd4jLong *shape, int rank) { + if (rank > 2) + return 0; + else if (rank <= 2) { + if (shape[0] == 1 || shape[1] == 1) return 0; + } + + return 1; +} - INLINEDEF _CUDA_HD int isMatrix(const Nd4jLong *shapeInfo) { - return isMatrix(shape::shapeOf(shapeInfo),shape::rank(shapeInfo)); - } +INLINEDEF _CUDA_HD int isMatrix(const Nd4jLong *shapeInfo) { + return isMatrix(shape::shapeOf(shapeInfo), shape::rank(shapeInfo)); +} /** * Returns the shape portion of an information * buffer */ - INLINEDEF _CUDA_HD Nd4jLong *shapeOf(Nd4jLong *shapeInfo) { - - return shapeInfo + 1; - } - - INLINEDEF _CUDA_HD Nd4jLong *shapeOf(const Nd4jLong *shapeInfo) { +INLINEDEF _CUDA_HD Nd4jLong *shapeOf(Nd4jLong *shapeInfo) { + return shapeInfo + 1; +} - return shape::shapeOf(const_cast(shapeInfo)); - } +INLINEDEF _CUDA_HD Nd4jLong *shapeOf(const Nd4jLong *shapeInfo) { + return shape::shapeOf(const_cast(shapeInfo)); +} /** * Return a copy of a buffer. * This buffer allocates memory * that must be freed elsewhere. */ - template - INLINEDEF _CUDA_HD T *copyOf(Nd4jLong length, T const* toCopy) { - traceNew(18); +template +INLINEDEF _CUDA_HD T *copyOf(Nd4jLong length, T const *toCopy) { + traceNew(18); - T *ret = new T[length]; - return copyOf(length, toCopy, ret); - } + T *ret = new T[length]; + return copyOf(length, toCopy, ret); +} - template - INLINEDEF _CUDA_HD T* copyOf(Nd4jLong length, T const* toCopy, T *ret) { - memcpy(ret, toCopy, sizeof(T)*length); - return ret; - } +template +INLINEDEF _CUDA_HD T *copyOf(Nd4jLong length, T const *toCopy, T *ret) { + memcpy(ret, toCopy, sizeof(T) * length); + return ret; +} /** -* Return a copy of a buffer. -* This buffer allocates memory -* that must be freed elsewhere. -*/ - template - INLINEDEF _CUDA_HD void copyTo(Nd4jLong length, T const* from, T *to) { - memcpy(to, from, sizeof(T)*length); - } + * Return a copy of a buffer. + * This buffer allocates memory + * that must be freed elsewhere. + */ +template +INLINEDEF _CUDA_HD void copyTo(Nd4jLong length, T const *from, T *to) { + memcpy(to, from, sizeof(T) * length); +} /** -* Return a copy of a buffer. -* This buffer allocates memory -* that must be freed elsewhere. -*/ - INLINEDEF _CUDA_HD void copyTo(int length, Nd4jLong const* from, Nd4jLong *to, Nd4jLong *indexes) { - for(int i = 0; i < length; i++) { - to[i] = from[indexes[i]]; - } - } + * Return a copy of a buffer. + * This buffer allocates memory + * that must be freed elsewhere. + */ +INLINEDEF _CUDA_HD void copyTo(int length, Nd4jLong const *from, Nd4jLong *to, + Nd4jLong *indexes) { + for (int i = 0; i < length; i++) { + to[i] = from[indexes[i]]; + } +} /** * Permute the given strides @@ -2424,93 +2606,88 @@ INLINEDEF _CUDA_HD int numOfNonUnitDims(const int rank, const Nd4jLong* inShape) * and all must be filled in) * @return the rearranged array */ - /* - INLINEDEF _CUDA_HD Nd4jLong *permutedStrides(Nd4jLong *toPermute, int shapeRank, int *rearrange) { - Nd4jLong *strideCopy = copyOf(shapeRank, toPermute); - checkArrangeArray(rearrange, shapeRank, shapeRank); - Nd4jLong *newStride = doPermuteSwap(shapeRank, strideCopy, rearrange); - delete[] strideCopy; - return newStride; - } - */ +/* + INLINEDEF _CUDA_HD Nd4jLong *permutedStrides(Nd4jLong *toPermute, int + shapeRank, int *rearrange) { Nd4jLong *strideCopy = copyOf(shapeRank, + toPermute); checkArrangeArray(rearrange, shapeRank, shapeRank); Nd4jLong + *newStride = doPermuteSwap(shapeRank, strideCopy, rearrange); delete[] + strideCopy; return newStride; + } + */ /** * Return the slice (shape + 1 in pointer arithmetic) * @param shape the shape to take the slice of * @return the shape array - the first entry */ - INLINEDEF _CUDA_HD Nd4jLong *slice(Nd4jLong *shape) { - return shape + 1; - } - - INLINEDEF _CUDA_HD int slices(Nd4jLong *shapeBuffer) { - return static_cast(shape::shapeOf(shapeBuffer)[0]); - } - - - INLINEDEF _CUDA_HD Nd4jLong *sliceOfShapeBuffer(Nd4jLong sliceIdx, Nd4jLong *shapeBuffer) { - int rank = shape::rank(shapeBuffer); - int newRank = rank - 1; - if(newRank < 2) - newRank = 2; - Nd4jLong *newShapeBuffer = new Nd4jLong[shape::shapeInfoLength(newRank)]; - newShapeBuffer[0] = newRank; - Nd4jLong *currShape = shape::shapeOf(shapeBuffer); - Nd4jLong *currStride = shape::stride(shapeBuffer); - //initialize new shape and stride by taking the shape and stride + 1 - //and adding to the shape information - //a slice is always just taking the existing shape and cutting the first index off - //of the shape and stride - Nd4jLong *newShape = shape::shapeOf(newShapeBuffer); - Nd4jLong *newStride = shape::stride(newShapeBuffer); - if(shape::isVector(shapeBuffer)) { - Nd4jLong *currShape = shape::shapeOf(shapeBuffer); - //row vector: slice index 0 is a valid index, just copy the whole thing - if(currShape[0] == 1) { - if(sliceIdx == 0) { - memcpy(newShapeBuffer,shapeBuffer,shape::shapeInfoByteLength(shape::rank(shapeBuffer))); - return newShapeBuffer; - } - } - //column vector: this will be a scalar - else { - delete[] newShapeBuffer; - Nd4jLong *scalar = shape::createScalarShapeInfo(); - int offset = shape::offset(shapeBuffer); - scalar[shape::shapeInfoLength(2) - 3] = offset + sliceIdx; - return scalar; - } - } - else if(shape::isMatrix(shapeBuffer)) { - newShape[0] = 1; - newShape[1] = currShape[1]; - newStride[0] = 1; - newStride[1] = currStride[1]; - } - else { - for(int i = 0; i < newRank; i++) { - newShape[i] = currShape[i + 1]; - newStride[i] = currStride[i + 1]; - } - } - - auto indices = new Nd4jLong[rank]; - memset((void *) indices,0,rank * sizeof(Nd4jLong)); - indices[0] = sliceIdx; - Nd4jLong offset = shape::getOffset(newShapeBuffer, indices); - newShapeBuffer[shape::shapeInfoLength(newRank) - 3] = offset; - - // set current order and ews - newShapeBuffer[2 * newRank + 2] = shape::elementWiseStride(shapeBuffer); - newShapeBuffer[2 * newRank + 3] = shape::order(shapeBuffer); - - // correct order and ews if necessary - shape::checkStridesEwsAndOrder(newShapeBuffer); +INLINEDEF _CUDA_HD Nd4jLong *slice(Nd4jLong *shape) { return shape + 1; } - delete[] indices; +INLINEDEF _CUDA_HD int slices(Nd4jLong *shapeBuffer) { + return static_cast(shape::shapeOf(shapeBuffer)[0]); +} +INLINEDEF _CUDA_HD Nd4jLong *sliceOfShapeBuffer(Nd4jLong sliceIdx, + Nd4jLong *shapeBuffer) { + int rank = shape::rank(shapeBuffer); + int newRank = rank - 1; + if (newRank < 2) newRank = 2; + Nd4jLong *newShapeBuffer = new Nd4jLong[shape::shapeInfoLength(newRank)]; + newShapeBuffer[0] = newRank; + Nd4jLong *currShape = shape::shapeOf(shapeBuffer); + Nd4jLong *currStride = shape::stride(shapeBuffer); + // initialize new shape and stride by taking the shape and stride + 1 + // and adding to the shape information + // a slice is always just taking the existing shape and cutting the first + // index off of the shape and stride + Nd4jLong *newShape = shape::shapeOf(newShapeBuffer); + Nd4jLong *newStride = shape::stride(newShapeBuffer); + if (shape::isVector(shapeBuffer)) { + Nd4jLong *currShape = shape::shapeOf(shapeBuffer); + // row vector: slice index 0 is a valid index, just copy the whole thing + if (currShape[0] == 1) { + if (sliceIdx == 0) { + memcpy(newShapeBuffer, shapeBuffer, + shape::shapeInfoByteLength(shape::rank(shapeBuffer))); return newShapeBuffer; + } } + // column vector: this will be a scalar + else { + delete[] newShapeBuffer; + Nd4jLong *scalar = shape::createScalarShapeInfo(); + int offset = shape::offset(shapeBuffer); + scalar[shape::shapeInfoLength(2) - 3] = offset + sliceIdx; + return scalar; + } + } else if (shape::isMatrix(shapeBuffer)) { + newShape[0] = 1; + newShape[1] = currShape[1]; + newStride[0] = 1; + newStride[1] = currStride[1]; + } else { + for (int i = 0; i < newRank; i++) { + newShape[i] = currShape[i + 1]; + newStride[i] = currStride[i + 1]; + } + } + + auto indices = new Nd4jLong[rank]; + memset((void *)indices, 0, rank * sizeof(Nd4jLong)); + indices[0] = sliceIdx; + Nd4jLong offset = shape::getOffset(newShapeBuffer, indices); + newShapeBuffer[shape::shapeInfoLength(newRank) - 3] = offset; + + // set current order and ews + newShapeBuffer[2 * newRank + 2] = shape::elementWiseStride(shapeBuffer); + newShapeBuffer[2 * newRank + 3] = shape::order(shapeBuffer); + + // correct order and ews if necessary + shape::checkStridesEwsAndOrder(newShapeBuffer); + + delete[] indices; + + return newShapeBuffer; +} /** * Returns the length of the @@ -2520,48 +2697,46 @@ INLINEDEF _CUDA_HD int numOfNonUnitDims(const int rank, const Nd4jLong* inShape) * info length for * @return rank * 2 + 4 */ - INLINEDEF _CUDA_HD int shapeInfoLength(int rank) { - //FIXME magic numbers - return rank * 2 + 4; - } +INLINEDEF _CUDA_HD int shapeInfoLength(int rank) { + // FIXME magic numbers + return rank * 2 + 4; +} - INLINEDEF _CUDA_HD int shapeInfoLength(Nd4jLong* shape) { - return shapeInfoLength(static_cast(shape[0])); - } +INLINEDEF _CUDA_HD int shapeInfoLength(Nd4jLong *shape) { + return shapeInfoLength(static_cast(shape[0])); +} - INLINEDEF _CUDA_HD int shapeInfoLength(const Nd4jLong* shape) { - return shapeInfoLength(static_cast(shape[0])); - } +INLINEDEF _CUDA_HD int shapeInfoLength(const Nd4jLong *shape) { + return shapeInfoLength(static_cast(shape[0])); +} - INLINEDEF _CUDA_HD size_t shapeInfoByteLength(int rank) { - //FIXME magic numbers - return (rank * 2 + 4) * sizeof(Nd4jLong); - } +INLINEDEF _CUDA_HD size_t shapeInfoByteLength(int rank) { + // FIXME magic numbers + return (rank * 2 + 4) * sizeof(Nd4jLong); +} - INLINEDEF _CUDA_HD size_t shapeInfoByteLength(const Nd4jLong* shapeInfo) { - //FIXME magic numbers - return shapeInfoByteLength((int) shapeInfo[0]); - } +INLINEDEF _CUDA_HD size_t shapeInfoByteLength(const Nd4jLong *shapeInfo) { + // FIXME magic numbers + return shapeInfoByteLength((int)shapeInfo[0]); +} /** * Returns the rank portion of * an information buffer */ - INLINEDEF _CUDA_HD int rank(const Nd4jLong *buffer) { - return static_cast(buffer[0]); - } +INLINEDEF _CUDA_HD int rank(const Nd4jLong *buffer) { + return static_cast(buffer[0]); +} - INLINEDEF _CUDA_HD int rank(const int *buffer) { - return buffer[0]; - } +INLINEDEF _CUDA_HD int rank(const int *buffer) { return buffer[0]; } - INLINEDEF _CUDA_HD int rank(const unsigned int *buffer) { - return static_cast(buffer[0]); - } +INLINEDEF _CUDA_HD int rank(const unsigned int *buffer) { + return static_cast(buffer[0]); +} - INLINEDEF _CUDA_HD Nd4jLong* ews(Nd4jLong* shapeInfo) { - return shapeInfo + 2 * shapeInfo[0] + 2; - } +INLINEDEF _CUDA_HD Nd4jLong *ews(Nd4jLong *shapeInfo) { + return shapeInfo + 2 * shapeInfo[0] + 2; +} /** * Converts a raw int buffer of the layout: @@ -2573,216 +2748,211 @@ INLINEDEF _CUDA_HD int numOfNonUnitDims(const int rank, const Nd4jLong* inShape) * * where shape and stride are both straight int pointers */ - INLINEDEF _CUDA_HD ShapeInformation *infoFromBuffer(Nd4jLong *buffer) { - - traceNew(19); - - auto info = new ShapeInformation; - auto length = shapeInfoLength(rank(buffer)); - auto rank = buffer[0]; - - //start after rank - info->shape = buffer + 1; - info->stride = buffer + (1 + rank); - info->rank = rank; - info->offset = buffer[length - 3]; - info->elementWiseStride = buffer[length - 2]; - Nd4jLong *stride = buffer + 1 + rank; - info->stride = stride; - info->order = (char) buffer[length - 1]; - return info; - } +INLINEDEF _CUDA_HD ShapeInformation *infoFromBuffer(Nd4jLong *buffer) { + traceNew(19); + + auto info = new ShapeInformation; + auto length = shapeInfoLength(rank(buffer)); + auto rank = buffer[0]; + + // start after rank + info->shape = buffer + 1; + info->stride = buffer + (1 + rank); + info->rank = rank; + info->offset = buffer[length - 3]; + info->elementWiseStride = buffer[length - 2]; + Nd4jLong *stride = buffer + 1 + rank; + info->stride = stride; + info->order = (char)buffer[length - 1]; + return info; +} /** * Returns the stride portion of an information * buffer */ - INLINEDEF _CUDA_HD Nd4jLong *stride(Nd4jLong *buffer) { - return buffer + (1 + rank(buffer)); - } - - INLINEDEF _CUDA_HD Nd4jLong *stride(const Nd4jLong *buffer) { - return stride(const_cast(buffer)); - } +INLINEDEF _CUDA_HD Nd4jLong *stride(Nd4jLong *buffer) { + return buffer + (1 + rank(buffer)); +} - INLINEDEF _CUDA_HD bool isEmpty(const Nd4jLong *shapeInfo) { - return ((shape::extra(const_cast(shapeInfo)) & ARRAY_EMPTY) == ARRAY_EMPTY); - } +INLINEDEF _CUDA_HD Nd4jLong *stride(const Nd4jLong *buffer) { + return stride(const_cast(buffer)); +} +INLINEDEF _CUDA_HD bool isEmpty(const Nd4jLong *shapeInfo) { + return ((shape::extra(const_cast(shapeInfo)) & ARRAY_EMPTY) == + ARRAY_EMPTY); +} /** * Compute the length of the given shape */ - INLINEDEF _CUDA_HD Nd4jLong length(const Nd4jLong *shapeInfo) { - - const int rank = shape::rank(shapeInfo); - - if (rank == 0) { - if (isEmpty(shapeInfo)) - return 0L; - return 1L; - } - - if (rank == 1) - return shapeInfo[1]; - - // if(shape::elementWiseStride(shapeInfo) == 1) { // contiguous - // if(shape::order(shapeInfo) == 'c') - // return shapeInfo[1] * shapeInfo[rank + 1]; // first dim * first stride - // return shapeInfo[rank] * shapeInfo[2 * rank]; // last dim * last stride - // } - - return shape::prodLong(shape::shapeOf(const_cast(shapeInfo)), rank); - } +INLINEDEF _CUDA_HD Nd4jLong length(const Nd4jLong *shapeInfo) { + const int rank = shape::rank(shapeInfo); + + if (rank == 0) { + if (isEmpty(shapeInfo)) return 0L; + return 1L; + } + + if (rank == 1) return shapeInfo[1]; + + // if(shape::elementWiseStride(shapeInfo) == 1) { // contiguous + // if(shape::order(shapeInfo) == 'c') + // return shapeInfo[1] * shapeInfo[rank + 1]; // first dim * + // first stride + // return shapeInfo[rank] * shapeInfo[2 * rank]; // last dim * last + // stride + // } + + return shape::prodLong(shape::shapeOf(const_cast(shapeInfo)), + rank); +} - INLINEDEF _CUDA_HD Nd4jLong length(std::initializer_list& shape) { - Nd4jLong ret = 1; - for (auto v : shape) { - ret *= v; - } - return ret; - } +INLINEDEF _CUDA_HD Nd4jLong length(std::initializer_list &shape) { + Nd4jLong ret = 1; + for (auto v : shape) { + ret *= v; + } + return ret; +} - INLINEDEF _CUDA_HD Nd4jLong length(std::initializer_list& shape) { - Nd4jLong ret = 1; - for (auto v : shape) { - ret *= v; - } - return ret; - } +INLINEDEF _CUDA_HD Nd4jLong length(std::initializer_list &shape) { + Nd4jLong ret = 1; + for (auto v : shape) { + ret *= v; + } + return ret; +} /*** * Returns the offset * portion of an information buffer */ - INLINEDEF _CUDA_HD Nd4jLong offset(Nd4jLong *buffer) { - return buffer[shape::shapeInfoLength(shape::rank(buffer)) - 3]; - } - - INLINEDEF _CUDA_HD Nd4jLong& extra(Nd4jLong *buffer) { - return buffer[shape::shapeInfoLength(shape::rank(buffer)) - 3]; - } +INLINEDEF _CUDA_HD Nd4jLong offset(Nd4jLong *buffer) { + return buffer[shape::shapeInfoLength(shape::rank(buffer)) - 3]; +} +INLINEDEF _CUDA_HD Nd4jLong &extra(Nd4jLong *buffer) { + return buffer[shape::shapeInfoLength(shape::rank(buffer)) - 3]; +} /** * Returns the ordering * for this shape information buffer */ - INLINEDEF _CUDA_HD char order(const Nd4jLong *buffer) { - //FIXME magic numbers - return static_cast(buffer[buffer[0] * 2 + 3]); - } +INLINEDEF _CUDA_HD char order(const Nd4jLong *buffer) { + // FIXME magic numbers + return static_cast(buffer[buffer[0] * 2 + 3]); +} /** * Returns type */ - INLINEDEF _CUDA_HD Nd4jLong type(const Nd4jLong *shapeInfo) { - return shapeInfo[2 * shapeInfo[0] + 1]; - } +INLINEDEF _CUDA_HD Nd4jLong type(const Nd4jLong *shapeInfo) { + return shapeInfo[2 * shapeInfo[0] + 1]; +} /** * Returns the element wise stride for this information * buffer */ - INLINEDEF _CUDA_HD Nd4jLong elementWiseStride(const Nd4jLong *buffer) { - return buffer[shapeInfoLength(static_cast(buffer[0])) - 2]; - } +INLINEDEF _CUDA_HD Nd4jLong elementWiseStride(const Nd4jLong *buffer) { + return buffer[shapeInfoLength(static_cast(buffer[0])) - 2]; +} /** -* Returns the element wise stride for this information -* buffer relative to a dimension and reduction index -*/ - INLINEDEF _CUDA_HD Nd4jLong reductionIndexElementWiseStride(Nd4jLong* buffer, int* dimension, int dimensionLength) { - if(dimensionLength > 1) { - if(shape::order(buffer) == 'f') { - /** - * The element wise stride belongs to a reduction index. - * When used out of order, we can get rid of the data - * dependencies and rely on using the max dimension - * specified for stride instead. - * Say we take the sum(0,1) along arr - * we can use arr.stride(1) as a representation - * along which to iterate. - */ - if(shape::shapeOf(buffer)[dimension[dimensionLength - 1]] != 1) { - //int tadElementWiseStride = shape::stride(buffer)[dimension[dimensionLength - 1]]; - //return tadElementWiseStride; - auto tadElementWiseStride = shape::stride(buffer)[dimension[0]]; - return tadElementWiseStride; - } - - return 1; - - } - else { - /** - * The element wise stride belongs to a reduction index. - * When used out of order, we can get rid of the data - * dependencies and rely on using the max dimension - * specified for stride instead. - * Say we take the sum(0,1) along arr - * we can use arr.stride(1) as a representation - * along which to iterate. - */ - if(shape::shapeOf(buffer)[dimension[dimensionLength - 1]] != 1) { - auto tadElementWiseStride = shape::stride(buffer)[dimension[dimensionLength - 1]]; - return tadElementWiseStride; - } - - return 1; - } - } - else { - if(shape::order(buffer) == 'f') { - /** - * The element wise stride belongs to a reduction index. - * When used out of order, we can get rid of the data - * dependencies and rely on using the max dimension - * specified for stride instead. - * Say we take the sum(0,1) along arr - * we can use arr.stride(1) as a representation - * along which to iterate. - */ - auto tadElementWiseStride = shape::stride(buffer)[dimension[0]]; - return tadElementWiseStride; - } - else { - /** - * The element wise stride belongs to a reduction index. - * When used out of order, we can get rid of the data - * dependencies and rely on using the max dimension - * specified for stride instead. - * Say we take the sum(0,1) along arr - * we can use arr.stride(1) as a representation - * along which to iterate. - */ - auto tadElementWiseStride = shape::stride(buffer)[dimension[dimensionLength - 1]]; - return tadElementWiseStride; - } - } - - } + * Returns the element wise stride for this information + * buffer relative to a dimension and reduction index + */ +INLINEDEF _CUDA_HD Nd4jLong reductionIndexElementWiseStride( + Nd4jLong *buffer, int *dimension, int dimensionLength) { + if (dimensionLength > 1) { + if (shape::order(buffer) == 'f') { + /** + * The element wise stride belongs to a reduction index. + * When used out of order, we can get rid of the data + * dependencies and rely on using the max dimension + * specified for stride instead. + * Say we take the sum(0,1) along arr + * we can use arr.stride(1) as a representation + * along which to iterate. + */ + if (shape::shapeOf(buffer)[dimension[dimensionLength - 1]] != 1) { + // int tadElementWiseStride = + // shape::stride(buffer)[dimension[dimensionLength - 1]]; return + // tadElementWiseStride; + auto tadElementWiseStride = shape::stride(buffer)[dimension[0]]; + return tadElementWiseStride; + } + + return 1; + + } else { + /** + * The element wise stride belongs to a reduction index. + * When used out of order, we can get rid of the data + * dependencies and rely on using the max dimension + * specified for stride instead. + * Say we take the sum(0,1) along arr + * we can use arr.stride(1) as a representation + * along which to iterate. + */ + if (shape::shapeOf(buffer)[dimension[dimensionLength - 1]] != 1) { + auto tadElementWiseStride = + shape::stride(buffer)[dimension[dimensionLength - 1]]; + return tadElementWiseStride; + } + + return 1; + } + } else { + if (shape::order(buffer) == 'f') { + /** + * The element wise stride belongs to a reduction index. + * When used out of order, we can get rid of the data + * dependencies and rely on using the max dimension + * specified for stride instead. + * Say we take the sum(0,1) along arr + * we can use arr.stride(1) as a representation + * along which to iterate. + */ + auto tadElementWiseStride = shape::stride(buffer)[dimension[0]]; + return tadElementWiseStride; + } else { + /** + * The element wise stride belongs to a reduction index. + * When used out of order, we can get rid of the data + * dependencies and rely on using the max dimension + * specified for stride instead. + * Say we take the sum(0,1) along arr + * we can use arr.stride(1) as a representation + * along which to iterate. + */ + auto tadElementWiseStride = + shape::stride(buffer)[dimension[dimensionLength - 1]]; + return tadElementWiseStride; + } + } +} /** * Returns whether * the given shape info buffer * represents a scalar shape */ - INLINEDEF _CUDA_HD int isScalar(const Nd4jLong *info) { +INLINEDEF _CUDA_HD int isScalar(const Nd4jLong *info) { + const int rank = shape::rank(info); - const int rank = shape::rank(info); + if (rank > 2) return 0; + if (rank == 0) return 1; + if (rank == 1) return shape::shapeOf(const_cast(info))[0] == 1; + if (rank == 2) + return shape::shapeOf(const_cast(info))[0] == 1 && + shape::shapeOf(const_cast(info))[1] == 1; - if(rank > 2) - return 0; - if(rank == 0) - return 1; - if(rank == 1) - return shape::shapeOf(const_cast(info))[0] == 1; - if(rank == 2) - return shape::shapeOf(const_cast(info))[0] == 1 && shape::shapeOf(const_cast(info))[1] == 1; - - return 0; - } + return 0; +} /** * Returns whether @@ -2790,19 +2960,15 @@ INLINEDEF _CUDA_HD int numOfNonUnitDims(const int rank, const Nd4jLong* inShape) * represents a scalar * shape or not */ - INLINEDEF _CUDA_HD int isScalar(volatile ShapeInformation *info) { +INLINEDEF _CUDA_HD int isScalar(volatile ShapeInformation *info) { + const int rank = info->rank; - const int rank = info->rank; + if (rank > 2) return 0; + if (rank == 1) return info->shape[0] == 1; + if (rank == 2) return info->shape[0] == 1 && info->shape[1] == 1; - if(rank > 2) - return 0; - if(rank == 1) - return info->shape[0] == 1; - if(rank == 2) - return info->shape[0] == 1 && info->shape[1] == 1; - - return 0; - } + return 0; +} /** * Return a copy of this array with the @@ -2816,28 +2982,29 @@ INLINEDEF _CUDA_HD int numOfNonUnitDims(const int rank, const Nd4jLong* inShape) * * item */ - template - INLINEDEF _CUDA_HD void removeIndex(T1 const* data, T2 const* indexes, Nd4jLong dataLength, Nd4jLong indexesLength, T1 *ret) { - - int count = 0; - int absLength = dataLength - indexesLength; - for (int i = 0; i < dataLength && count < absLength; i++) { - int contains = 0; - for (int j = 0; j < indexesLength; j++) { - if (i == indexes[j]) { - contains = 1; - break; - } - } - - if (!contains) { - ret[count] = data[i]; - count++; - } - } - } +template +INLINEDEF _CUDA_HD void removeIndex(T1 const *data, T2 const *indexes, + Nd4jLong dataLength, Nd4jLong indexesLength, + T1 *ret) { + int count = 0; + int absLength = dataLength - indexesLength; + for (int i = 0; i < dataLength && count < absLength; i++) { + int contains = 0; + for (int j = 0; j < indexesLength; j++) { + if (i == indexes[j]) { + contains = 1; + break; + } + } + + if (!contains) { + ret[count] = data[i]; + count++; + } + } +} - /** +/** * Return a copy of this array with the * given index omitted * @@ -2849,46 +3016,50 @@ INLINEDEF _CUDA_HD int numOfNonUnitDims(const int rank, const Nd4jLong* inShape) * * item */ - template - INLINEDEF _CUDA_HD T1* removeIndex(T1 const* data, T2 const* indexes, Nd4jLong dataLength, Nd4jLong indexesLength) { - auto lengthOfArr = dataLength - indexesLength; - if(lengthOfArr < 0) { - printf("Remove index call created a <= 0 length array. This was likely not intended."); - } - - auto ret = new T1[lengthOfArr]; - memset(ret,0,sizeof(T1) * lengthOfArr); - removeIndex(data, indexes, dataLength, indexesLength, ret); - return ret; - } - - INLINEDEF _CUDA_HD Nd4jLong* everyIndexBut(const Nd4jLong *indexes,int indexesLength,int begin,int end) { - int len = end - indexesLength; - - traceNew(20); - - auto ret = new Nd4jLong[len]; - int retIdx = 0; - //not here that we do 0 based indexing for end - this assumes things like: - //0 to 4 are specified - for(int i = begin; i < end ; i++) { - bool found = false; - for(int j = 0; j < indexesLength; j++) { - if(indexes[j] == i) { - found = true; - break; - } - } +template +INLINEDEF _CUDA_HD T1 *removeIndex(T1 const *data, T2 const *indexes, + Nd4jLong dataLength, + Nd4jLong indexesLength) { + auto lengthOfArr = dataLength - indexesLength; + if (lengthOfArr < 0) { + printf( + "Remove index call created a <= 0 length array. This was likely not " + "intended."); + } + + auto ret = new T1[lengthOfArr]; + memset(ret, 0, sizeof(T1) * lengthOfArr); + removeIndex(data, indexes, dataLength, indexesLength, ret); + return ret; +} - if(!found) { - ret[retIdx++] = i; - } +INLINEDEF _CUDA_HD Nd4jLong *everyIndexBut(const Nd4jLong *indexes, + int indexesLength, int begin, + int end) { + int len = end - indexesLength; - } + traceNew(20); - return ret; + auto ret = new Nd4jLong[len]; + int retIdx = 0; + // not here that we do 0 based indexing for end - this assumes things like: + // 0 to 4 are specified + for (int i = begin; i < end; i++) { + bool found = false; + for (int j = 0; j < indexesLength; j++) { + if (indexes[j] == i) { + found = true; + break; + } + } + if (!found) { + ret[retIdx++] = i; } + } + + return ret; +} /** * Computes the offset for accessing @@ -2896,8 +3067,8 @@ INLINEDEF _CUDA_HD int numOfNonUnitDims(const int rank, const Nd4jLong* inShape) * and the offset to be read. */ #ifdef __CUDACC__ - INLINEDEF __device__ int tadOffset(ShapeInformation *xInfo, int offset) { - return offset + threadIdx.x * xInfo->elementWiseStride; +INLINEDEF __device__ int tadOffset(ShapeInformation *xInfo, int offset) { + return offset + threadIdx.x * xInfo->elementWiseStride; } #endif @@ -2909,21 +3080,21 @@ INLINEDEF _CUDA_HD int numOfNonUnitDims(const int rank, const Nd4jLong* inShape) * for the shape to be returned as * @return the new shape */ - INLINEDEF _CUDA_HD Nd4jLong *ensureVectorShape(Nd4jLong *shape, int dimension) { - traceNew(21); +INLINEDEF _CUDA_HD Nd4jLong *ensureVectorShape(Nd4jLong *shape, int dimension) { + traceNew(21); - Nd4jLong *ret = new Nd4jLong[2]; + Nd4jLong *ret = new Nd4jLong[2]; - if (dimension == 0) { - ret[0] = 1; - ret[1] = shape[0]; - } else { - ret[0] = shape[0]; - ret[1] = 1; - } + if (dimension == 0) { + ret[0] = 1; + ret[1] = shape[0]; + } else { + ret[0] = shape[0]; + ret[1] = 1; + } - return ret; - } + return ret; +} /** * Returns a shape @@ -2933,99 +3104,96 @@ INLINEDEF _CUDA_HD int numOfNonUnitDims(const int rank, const Nd4jLong* inShape) * for the shape to be returned as * @return the new shape */ - INLINEDEF _CUDA_HD Nd4jLong *ensureVectorShape(Nd4jLong *shape) { - return ensureVectorShape(shape, 0); - } +INLINEDEF _CUDA_HD Nd4jLong *ensureVectorShape(Nd4jLong *shape) { + return ensureVectorShape(shape, 0); +} - /** - * This method does STRICT comparison for two shape buffers - * - * @param shape - * @return - */ - INLINEDEF _CUDA_HD bool equalsStrict(const Nd4jLong *shapeA, const Nd4jLong *shapeB) { - if (shapeA[0] != shapeB[0]) - return false; +/** + * This method does STRICT comparison for two shape buffers + * + * @param shape + * @return + */ +INLINEDEF _CUDA_HD bool equalsStrict(const Nd4jLong *shapeA, + const Nd4jLong *shapeB) { + if (shapeA[0] != shapeB[0]) return false; - if (shapeA[0] == 0) - return true; + if (shapeA[0] == 0) return true; - // we do full comparison here - int length = shape::shapeInfoLength(shapeA[0]); + // we do full comparison here + int length = shape::shapeInfoLength(shapeA[0]); - for (int e = 1; e < length; e++) - if (shapeA[e] != shapeB[e]) - return false; + for (int e = 1; e < length; e++) + if (shapeA[e] != shapeB[e]) return false; - return true; - } + return true; +} ////////////////////////////////////////////////////////////////////// -INLINEDEF _CUDA_HD bool haveSameShapeAndStrides(const Nd4jLong *shapeInfo1, const Nd4jLong *shapeInfo2) { +INLINEDEF _CUDA_HD bool haveSameShapeAndStrides(const Nd4jLong *shapeInfo1, + const Nd4jLong *shapeInfo2) { + if (shapeInfo1[0] != shapeInfo2[0]) return false; - if (shapeInfo1[0] != shapeInfo2[0]) - return false; + if (shapeInfo1[0] == 0) return true; - if (shapeInfo1[0] == 0) - return true; + for (uint e = 0; e < static_cast(shape::rank(shapeInfo1)); ++e) + if (shape::shapeOf(shapeInfo1)[e] != shape::shapeOf(shapeInfo2)[e] || + shape::stride(shapeInfo1)[e] != shape::stride(shapeInfo2)[e]) + return false; - for (uint e = 0; e < static_cast(shape::rank(shapeInfo1)); ++e) - if (shape::shapeOf(shapeInfo1)[e] != shape::shapeOf(shapeInfo2)[e] || shape::stride(shapeInfo1)[e] != shape::stride(shapeInfo2)[e]) - return false; - - return true; + return true; } ////////////////////////////////////////////////////////////////////// -INLINEDEF _CUDA_HD bool haveSameShapeAndStrides(const Nd4jLong *shapeInfo1, const Nd4jLong *shapeInfo2, const Nd4jLong *shapeInfo3) { - - return shape::haveSameShapeAndStrides(shapeInfo1, shapeInfo2) && shape::haveSameShapeAndStrides(shapeInfo1, shapeInfo3); +INLINEDEF _CUDA_HD bool haveSameShapeAndStrides(const Nd4jLong *shapeInfo1, + const Nd4jLong *shapeInfo2, + const Nd4jLong *shapeInfo3) { + return shape::haveSameShapeAndStrides(shapeInfo1, shapeInfo2) && + shape::haveSameShapeAndStrides(shapeInfo1, shapeInfo3); +} +INLINEDEF _CUDA_HD int sizeAt(const Nd4jLong *shapeInfo, const int dim) { + if (0 == rank(shapeInfo)) return 1; + if (dim >= 0) + return shapeInfo[1 + dim]; + else + return shapeInfo[1 + (rank(shapeInfo) + dim)]; } - INLINEDEF _CUDA_HD int sizeAt(const Nd4jLong *shapeInfo, const int dim) { - if (0 == rank(shapeInfo)) - return 1; - if (dim >= 0) - return shapeInfo[1+dim]; - else - return shapeInfo[1+(rank(shapeInfo) + dim)]; - } - - INLINEDEF _CUDA_HD Nd4jLong strideAt(const Nd4jLong *shapeInfo, const int dim) { - if (0 == rank(shapeInfo)) - return 1; - if (dim >= 0) - return shapeInfo[1 + rank(shapeInfo) + dim]; - else - return shapeInfo[1 + 2*rank(shapeInfo) + dim]; - } - /** - * This method does SOFT comparison for two shape buffers, we compare only rank & shapes - * - * @param shape - * @return - */ - INLINEDEF _CUDA_HD bool equalsSoft(const Nd4jLong *shapeA, const Nd4jLong *shapeB) { - if (shapeA[0] != shapeB[0]) - return false; +INLINEDEF _CUDA_HD Nd4jLong strideAt(const Nd4jLong *shapeInfo, const int dim) { + if (0 == rank(shapeInfo)) return 1; + if (dim >= 0) + return shapeInfo[1 + rank(shapeInfo) + dim]; + else + return shapeInfo[1 + 2 * rank(shapeInfo) + dim]; +} - if (shapeA[0] == 0) - return true; +/** + * This method does SOFT comparison for two shape buffers, we compare only rank + * & shapes + * + * @param shape + * @return + */ +INLINEDEF _CUDA_HD bool equalsSoft(const Nd4jLong *shapeA, + const Nd4jLong *shapeB) { + if (shapeA[0] != shapeB[0]) return false; - // we compare only shapes, and ignoring stride & ews - auto length = shapeA[0]; + if (shapeA[0] == 0) return true; - for (int e = 1; e <= length; e++) - if (shapeA[e] != shapeB[e]) - return false; + // we compare only shapes, and ignoring stride & ews + auto length = shapeA[0]; - return true; - } + for (int e = 1; e <= length; e++) + if (shapeA[e] != shapeB[e]) return false; - INLINEDEF _CUDA_HD bool equalsTypesAndShapesSoft(const Nd4jLong *shapeA, const Nd4jLong *shapeB) { + return true; +} - return equalsSoft(shapeA, shapeB) && shapeA[shapeInfoLength(shapeA) - 3] == shapeB[shapeInfoLength(shapeB) - 3]; - } +INLINEDEF _CUDA_HD bool equalsTypesAndShapesSoft(const Nd4jLong *shapeA, + const Nd4jLong *shapeB) { + return equalsSoft(shapeA, shapeB) && shapeA[shapeInfoLength(shapeA) - 3] == + shapeB[shapeInfoLength(shapeB) - 3]; +} /** * Generate an int buffer @@ -3033,36 +3201,34 @@ INLINEDEF _CUDA_HD bool haveSameShapeAndStrides(const Nd4jLong *shapeInfo1, cons * at the specified increment * */ - template - INLINEDEF _CUDA_HD T* range(int from, int to, int increment) { - int diff = sd::math::nd4j_abs(from - to); - int retLength = diff / increment; - T *ret; - - traceNew(22); - - if(diff / increment < 1) - ret = new T[1]; - else - ret = new T[diff / increment]; - if (from < to) { - int count = 0; - for (int i = from; i < to; i += increment) { - if (count >= retLength) - break; - ret[count++] = i; - } - } else if (from > to) { - int count = 0; - for (int i = from - 1; i >= to; i -= increment) { - if (count >= retLength) - break; - ret[count++] = i; - } - } - - return ret; - } +template +INLINEDEF _CUDA_HD T *range(int from, int to, int increment) { + int diff = sd::math::nd4j_abs(from - to); + int retLength = diff / increment; + T *ret; + + traceNew(22); + + if (diff / increment < 1) + ret = new T[1]; + else + ret = new T[diff / increment]; + if (from < to) { + int count = 0; + for (int i = from; i < to; i += increment) { + if (count >= retLength) break; + ret[count++] = i; + } + } else if (from > to) { + int count = 0; + for (int i = from - 1; i >= to; i -= increment) { + if (count >= retLength) break; + ret[count++] = i; + } + } + + return ret; +} /** * Generate a range @@ -3073,10 +3239,10 @@ INLINEDEF _CUDA_HD bool haveSameShapeAndStrides(const Nd4jLong *shapeInfo1, cons * @return the int array starting at from and ending at to */ - template - INLINEDEF _CUDA_HD T* range(int from, int to) { - return range(from, to, 1); - } +template +INLINEDEF _CUDA_HD T *range(int from, int to) { + return range(from, to, 1); +} /** * Keep the given indexes in the data @@ -3086,71 +3252,67 @@ INLINEDEF _CUDA_HD bool haveSameShapeAndStrides(const Nd4jLong *shapeInfo1, cons * @param dataLength * @return */ - INLINEDEF _CUDA_HD Nd4jLong *keep(volatile Nd4jLong *data, int const* index, int indexLength, int dataLength) { - - traceNew(23); - - Nd4jLong *ret = new Nd4jLong[indexLength]; - int count = 0; - for (int i = 0; i < dataLength; i++) { - int contains = 0; - for (int j = 0; j < indexLength; j++) { - if (i == index[j]) { - contains = 1; - break; - } - } - - if (contains) - ret[count++] = data[i]; - } - return ret; - } +INLINEDEF _CUDA_HD Nd4jLong *keep(volatile Nd4jLong *data, int const *index, + int indexLength, int dataLength) { + traceNew(23); + + Nd4jLong *ret = new Nd4jLong[indexLength]; + int count = 0; + for (int i = 0; i < dataLength; i++) { + int contains = 0; + for (int j = 0; j < indexLength; j++) { + if (i == index[j]) { + contains = 1; + break; + } + } + + if (contains) ret[count++] = data[i]; + } + return ret; +} /** * Generate a reverse * copy of the data */ - template - INLINEDEF _CUDA_HD T* reverseCopy(T const* data, Nd4jLong length) { - if (length < 1) - return nullptr; - - traceNew(24); - - T *copy = new T[length]; - for (Nd4jLong i = 0; i <= length / 2; i++) { - T temp = data[i]; - copy[i] = data[length - i - 1]; - copy[length - i - 1] = temp; - } - return copy; - } - - template - INLINEDEF _CUDA_HD void reverseCopyTo(T const* from, T *to, Nd4jLong length) { - if (length < 1) - return; - for (Nd4jLong i = 0; i <= length / 2; i++) { - T temp = from[i]; - to[i] = from[length - i - 1]; - to[length - i - 1] = temp; - } - } - - template - INLINEDEF _CUDA_HD void reverseCopyTo(T const* from, T *to, Nd4jLong *indexes, Nd4jLong length) { - if (length < 1) - return; +template +INLINEDEF _CUDA_HD T *reverseCopy(T const *data, Nd4jLong length) { + if (length < 1) return nullptr; + + traceNew(24); + + T *copy = new T[length]; + for (Nd4jLong i = 0; i <= length / 2; i++) { + T temp = data[i]; + copy[i] = data[length - i - 1]; + copy[length - i - 1] = temp; + } + return copy; +} - for (Nd4jLong i = 0; i <= length / 2; i++) { - T temp = from[indexes[i]]; - to[i] = from[indexes[length - i - 1]]; - to[length - i - 1] = temp; - } +template +INLINEDEF _CUDA_HD void reverseCopyTo(T const *from, T *to, Nd4jLong length) { + if (length < 1) return; + for (Nd4jLong i = 0; i <= length / 2; i++) { + T temp = from[i]; + to[i] = from[length - i - 1]; + to[length - i - 1] = temp; + } +} - } +template +INLINEDEF _CUDA_HD void reverseCopyTo(T const *from, T *to, Nd4jLong *indexes, + Nd4jLong length) { + if (length < 1) return; + + for (Nd4jLong i = 0; i <= length / 2; i++) { + T temp = from[indexes[i]]; + to[i] = from[indexes[length - i - 1]]; + to[length - i - 1] = temp; + } +} /** * @@ -3160,16 +3322,16 @@ INLINEDEF _CUDA_HD bool haveSameShapeAndStrides(const Nd4jLong *shapeInfo1, cons * @param arr2Length * @return */ - template - INLINEDEF _CUDA_HD T* concat(T const* arr1, Nd4jLong const arr1Length, T const* arr2, Nd4jLong const arr2Length) { - - traceNew(25); - - T *ret = new T[arr1Length + arr2Length]; - std::memcpy(ret, arr1, arr1Length * sizeof(T)); - std::memcpy(ret + arr1Length, arr2, arr2Length * sizeof(T)); - return ret; - } +template +INLINEDEF _CUDA_HD T *concat(T const *arr1, Nd4jLong const arr1Length, + T const *arr2, Nd4jLong const arr2Length) { + traceNew(25); + + T *ret = new T[arr1Length + arr2Length]; + std::memcpy(ret, arr1, arr1Length * sizeof(T)); + std::memcpy(ret + arr1Length, arr2, arr2Length * sizeof(T)); + return ret; +} /** * @@ -3179,20 +3341,21 @@ INLINEDEF _CUDA_HD bool haveSameShapeAndStrides(const Nd4jLong *shapeInfo1, cons * @param lengths * @return */ - template - INLINEDEF _CUDA_HD T *concat(Nd4jLong const numArrays, Nd4jLong const numTotalElements, T const **arr, Nd4jLong const *lengths) { - - T* ret = new T[numTotalElements]; - Nd4jLong count = 0; - - for (Nd4jLong i = 0; i < numArrays; i++) { - for (Nd4jLong j = 0; j < lengths[i]; j++) { - ret[count++] = arr[i][j]; - } - } +template +INLINEDEF _CUDA_HD T *concat(Nd4jLong const numArrays, + Nd4jLong const numTotalElements, T const **arr, + Nd4jLong const *lengths) { + T *ret = new T[numTotalElements]; + Nd4jLong count = 0; - return ret; + for (Nd4jLong i = 0; i < numArrays; i++) { + for (Nd4jLong j = 0; j < lengths[i]; j++) { + ret[count++] = arr[i][j]; } + } + + return ret; +} /** * Get the length per slice of the @@ -3206,22 +3369,24 @@ INLINEDEF _CUDA_HD bool haveSameShapeAndStrides(const Nd4jLong *shapeInfo1, cons * @return the length per slice of the given shape * along the given dimension */ - INLINEDEF _CUDA_HD Nd4jLong lengthPerSlice(int rank, Nd4jLong const* shape, int const* dimension, int dimensionLength) { - if(shape::isVector(shape,rank)) { - //return total length for row vectors - if(dimensionLength == 1 && shape[0] == 1) { - return shape::prodLong(shape,rank); - } - } - else if(rank == dimensionLength) - return shape::prodLong(shape,rank); - int absSelta = sd::math::nd4j_abs(rank - dimensionLength); - traceNew(27); - auto ret2 = shape::removeIndex(shape, dimension, rank, dimensionLength); - auto ret = prodLong(ret2, absSelta); - delete[] ret2; - return ret; - } +INLINEDEF _CUDA_HD Nd4jLong lengthPerSlice(int rank, Nd4jLong const *shape, + int const *dimension, + int dimensionLength) { + if (shape::isVector(shape, rank)) { + // return total length for row vectors + if (dimensionLength == 1 && shape[0] == 1) { + return shape::prodLong(shape, rank); + } + } else if (rank == dimensionLength) + return shape::prodLong(shape, rank); + int absSelta = sd::math::nd4j_abs(rank - dimensionLength); + traceNew(27); + auto ret2 = + shape::removeIndex(shape, dimension, rank, dimensionLength); + auto ret = prodLong(ret2, absSelta); + delete[] ret2; + return ret; +} /** * calculates the offset for a tensor @@ -3230,18 +3395,21 @@ INLINEDEF _CUDA_HD bool haveSameShapeAndStrides(const Nd4jLong *shapeInfo1, cons * @param tensorShape * @return */ - INLINEDEF _CUDA_HD Nd4jLong sliceOffsetForTensor(int rank, int index, Nd4jLong const* shape, Nd4jLong const* tensorShape, int tensorShapeLength, int const* dimension, int dimensionLength) { - auto tensorLength = prodLong(tensorShape, tensorShapeLength); - auto lengthPerSlice2 = lengthPerSlice(rank, shape, dimension, dimensionLength); - if (lengthPerSlice2 <= 0) { - return 0; - } - - Nd4jLong offset = index * tensorLength / lengthPerSlice2; - return offset; - } +INLINEDEF _CUDA_HD Nd4jLong sliceOffsetForTensor( + int rank, int index, Nd4jLong const *shape, Nd4jLong const *tensorShape, + int tensorShapeLength, int const *dimension, int dimensionLength) { + auto tensorLength = prodLong(tensorShape, tensorShapeLength); + auto lengthPerSlice2 = + lengthPerSlice(rank, shape, dimension, dimensionLength); + if (lengthPerSlice2 <= 0) { + return 0; + } + + Nd4jLong offset = index * tensorLength / lengthPerSlice2; + return offset; +} - /** +/** * calculates the offset for a tensor * @param index * @param arr @@ -3249,106 +3417,105 @@ INLINEDEF _CUDA_HD bool haveSameShapeAndStrides(const Nd4jLong *shapeInfo1, cons * @return */ - INLINEDEF _CUDA_HD Nd4jLong sliceOffsetForTensor(int index,int tensorLength,int lengthPerSlice2) { - Nd4jLong offset = index * tensorLength / lengthPerSlice2; - return offset; - } - +INLINEDEF _CUDA_HD Nd4jLong sliceOffsetForTensor(int index, int tensorLength, + int lengthPerSlice2) { + Nd4jLong offset = index * tensorLength / lengthPerSlice2; + return offset; +} #ifdef __CUDACC__ /** -* Computes the offset for accessing -* a global element given the shape information -* and the offset to be read. -*/ - INLINEDEF _CUDA_D int tadOffset(Nd4jLong *xInfo, int offset) { - return offset + threadIdx.x * elementWiseStride(xInfo); - + * Computes the offset for accessing + * a global element given the shape information + * and the offset to be read. + */ +INLINEDEF _CUDA_D int tadOffset(Nd4jLong *xInfo, int offset) { + return offset + threadIdx.x * elementWiseStride(xInfo); } #endif - - - - /** * Computes the number * of tensors along * a given dimension */ - INLINEDEF _CUDA_HD Nd4jLong tensorsAlongDimension(volatile int rank, volatile int length, - volatile Nd4jLong *shape, int *dimension, int dimensionLength) { - Nd4jLong *tensorShape = shape::keep(shape, dimension, dimensionLength, rank); - Nd4jLong ret = length / shape::prodLong(tensorShape, dimensionLength); - delete[] tensorShape; - return ret; - } +INLINEDEF _CUDA_HD Nd4jLong tensorsAlongDimension(volatile int rank, + volatile int length, + volatile Nd4jLong *shape, + int *dimension, + int dimensionLength) { + Nd4jLong *tensorShape = shape::keep(shape, dimension, dimensionLength, rank); + Nd4jLong ret = length / shape::prodLong(tensorShape, dimensionLength); + delete[] tensorShape; + return ret; +} /** * Computes the number * of tensors along * a given dimension */ - INLINEDEF _CUDA_HD Nd4jLong tensorsAlongDimension(Nd4jLong *shapeInfo, int *dimension, int dimensionLength) { - Nd4jLong *keepShape = shape::shapeOf(shapeInfo); - Nd4jLong *tensorShape = shape::keep(keepShape, dimension, dimensionLength, rank(shapeInfo)); - Nd4jLong ret = shape::length(shapeInfo) / shape::prodLong(tensorShape, dimensionLength); - delete[] tensorShape; - return ret; - } - - - +INLINEDEF _CUDA_HD Nd4jLong tensorsAlongDimension(Nd4jLong *shapeInfo, + int *dimension, + int dimensionLength) { + Nd4jLong *keepShape = shape::shapeOf(shapeInfo); + Nd4jLong *tensorShape = + shape::keep(keepShape, dimension, dimensionLength, rank(shapeInfo)); + Nd4jLong ret = + shape::length(shapeInfo) / shape::prodLong(tensorShape, dimensionLength); + delete[] tensorShape; + return ret; +} /** -* Get an offset for retrieval -* from a data buffer -* based on the given -* shape stride and given indices -* @param baseOffset the offset to start from -* @param shape the shape of the array -* @param stride the stride of the array -* @param indices the indices to iterate over -* @return the double at the specified index -*/ + * Get an offset for retrieval + * from a data buffer + * based on the given + * shape stride and given indices + * @param baseOffset the offset to start from + * @param shape the shape of the array + * @param stride the stride of the array + * @param indices the indices to iterate over + * @return the double at the specified index + */ ////////////////////////////////////////////////////////////////////////// -INLINEDEF _CUDA_HD Nd4jLong getOffset(const Nd4jLong *shapeInfo, const Nd4jLong *indices, Nd4jLong baseOffset) { +INLINEDEF _CUDA_HD Nd4jLong getOffset(const Nd4jLong *shapeInfo, + const Nd4jLong *indices, + Nd4jLong baseOffset) { + Nd4jLong offset = baseOffset; - Nd4jLong offset = baseOffset; + for (uint i = 1; i <= shapeInfo[0]; ++i) + if (shapeInfo[i] != 1) + offset += indices[i - 1] * shapeInfo[shapeInfo[0] + i]; - for(uint i = 1; i <= shapeInfo[0]; ++i) - if(shapeInfo[i] != 1) - offset += indices[i - 1] * shapeInfo[shapeInfo[0] + i]; - - return offset; + return offset; } ////////////////////////////////////////////////////////////////////////// -INLINEDEF _CUDA_HD Nd4jLong getOffset(const Nd4jLong *shapeInfo, const int *coords, Nd4jLong baseOffset) { - - Nd4jLong offset = baseOffset; +INLINEDEF _CUDA_HD Nd4jLong getOffset(const Nd4jLong *shapeInfo, + const int *coords, Nd4jLong baseOffset) { + Nd4jLong offset = baseOffset; - for(uint i = 1; i <= shapeInfo[0]; ++i) - if(shapeInfo[i] != 1) - offset += coords[i - 1] * shapeInfo[shapeInfo[0] + i]; + for (uint i = 1; i <= shapeInfo[0]; ++i) + if (shapeInfo[i] != 1) + offset += coords[i - 1] * shapeInfo[shapeInfo[0] + i]; - return offset; + return offset; } ////////////////////////////////////////////////////////////////////////// -INLINEDEF _CUDA_HD Nd4jLong getOffset(const Nd4jLong *shapeInfo, const uint *coords, Nd4jLong baseOffset) { +INLINEDEF _CUDA_HD Nd4jLong getOffset(const Nd4jLong *shapeInfo, + const uint *coords, Nd4jLong baseOffset) { + Nd4jLong offset = baseOffset; - Nd4jLong offset = baseOffset; + for (uint i = 1; i <= shapeInfo[0]; ++i) + if (shapeInfo[i] != 1) + offset += coords[i - 1] * shapeInfo[shapeInfo[0] + i]; - for(uint i = 1; i <= shapeInfo[0]; ++i) - if(shapeInfo[i] != 1) - offset += coords[i - 1] * shapeInfo[shapeInfo[0] + i]; - - return offset; + return offset; } - /** * Returns the tensor along dimension * for the given block index @@ -3357,194 +3524,196 @@ INLINEDEF _CUDA_HD Nd4jLong getOffset(const Nd4jLong *shapeInfo, const uint *coo * @param i * @return */ - INLINEDEF _CUDA_HD int tadForBlockIndex(int blockSize, int blockIdx, int i) { - return blockIdx + i * blockSize; - } +INLINEDEF _CUDA_HD int tadForBlockIndex(int blockSize, int blockIdx, int i) { + return blockIdx + i * blockSize; +} /** * Computes the number of tads per block * */ - INLINEDEF _CUDA_HD int tadsPerBlock(int blockSize, int tads) { - return sd::math::nd4j_ceil(tads / (double) blockSize); - } +INLINEDEF _CUDA_HD int tadsPerBlock(int blockSize, int tads) { + return sd::math::nd4j_ceil(tads / (double)blockSize); +} /** * Returns a shape buffer * for the shape information metadata. */ - INLINEDEF _CUDA_HD Nd4jLong *toShapeBuffer( ShapeInformation *info) { +INLINEDEF _CUDA_HD Nd4jLong *toShapeBuffer(ShapeInformation *info) { + traceNew(29); - traceNew(29); + auto ret = new Nd4jLong[shapeInfoLength(info->rank)]; + int count = 1; + int rank = info->rank; - auto ret = new Nd4jLong[shapeInfoLength(info->rank)]; - int count = 1; - int rank = info->rank; + ret[0] = info->rank; - ret[0] = info->rank; + for (int i = 0; i < rank; i++) { + ret[count++] = info->shape[i]; + } - for (int i = 0; i < rank; i++) { - ret[count++] = info->shape[i]; - } + for (int i = 0; i < rank; i++) { + ret[count++] = info->stride[i]; + } - for (int i = 0; i < rank; i++) { - ret[count++] = info->stride[i]; - } + ret[count++] = info->offset; + ret[count++] = info->elementWiseStride; + ret[count] = info->order; - ret[count++] = info->offset; - ret[count++] = info->elementWiseStride; - ret[count] = info->order; + return ret; +} - return ret; - } +INLINEDEF _CUDA_HD Nd4jLong *toShapeBuffer(ShapeInformation *info, + Nd4jLong *ret) { + int count = 1; + int rank = info->rank; - INLINEDEF _CUDA_HD Nd4jLong *toShapeBuffer( ShapeInformation *info, Nd4jLong* ret) { + ret[0] = info->rank; - int count = 1; - int rank = info->rank; + if (ret[0] == 0) { + ret[1] = 0; + ret[2] = 1; + ret[3] = 99; + return ret; + } - ret[0] = info->rank; + for (int i = 0; i < rank; i++) { + ret[count++] = info->shape[i]; + } - if (ret[0] == 0) { - ret[1] = 0; - ret[2] = 1; - ret[3] = 99; - return ret; - } + for (int i = 0; i < rank; i++) { + ret[count++] = info->stride[i]; + } - for (int i = 0; i < rank; i++) { - ret[count++] = info->shape[i]; - } + ret[count++] = info->offset; + ret[count++] = info->elementWiseStride; + ret[count++] = info->order; - for (int i = 0; i < rank; i++) { - ret[count++] = info->stride[i]; - } + return ret; +} - ret[count++] = info->offset; - ret[count++] = info->elementWiseStride; - ret[count++] = info->order; +INLINEDEF _CUDA_HD void printIntArray(const Nd4jLong *arr, const int length) { + for (int i = 0; i < length; i++) { + printf(" %lld ", (long long)arr[i]); + } - return ret; - } + printf("\n"); +} - INLINEDEF _CUDA_HD void printIntArray(const Nd4jLong *arr, const int length) { - for(int i = 0; i < length; i++) { - printf(" %lld ", (long long) arr[i]); - } +INLINEDEF _CUDA_HD void printIntArray(const int *arr, const int length) { + for (int i = 0; i < length; i++) { + printf(" %i ", arr[i]); + } - printf("\n"); - } + printf("\n"); +} - INLINEDEF _CUDA_HD void printIntArray(const int *arr, const int length) { - for(int i = 0; i < length; i++) { - printf(" %i ", arr[i]); - } +INLINEDEF _CUDA_HD void printShapeInfo(Nd4jLong *shapeInfo) { + int rank = shape::rank(shapeInfo); + Nd4jLong *shape = shape::shapeOf(shapeInfo); + printf("Rank %d\n", rank); + printf("Shape:\n"); + for (int i = 0; i < rank; i++) { + printf(" %lld ", (long long)shape[i]); + } - printf("\n"); - } + printf("\n"); - INLINEDEF _CUDA_HD void printShapeInfo(Nd4jLong *shapeInfo) { - int rank = shape::rank(shapeInfo); - Nd4jLong *shape = shape::shapeOf(shapeInfo); - printf("Rank %d\n",rank); - printf("Shape:\n"); - for(int i = 0; i < rank; i++) { - printf(" %lld ",(long long) shape[i]); - } + Nd4jLong *stride = shape::stride(shapeInfo); + printf("Stride:\n"); + for (int i = 0; i < rank; i++) { + printf(" %lld ", (long long)stride[i]); + } - printf("\n"); + printf("\n"); - Nd4jLong *stride = shape::stride(shapeInfo); - printf("Stride:\n"); - for(int i = 0; i < rank; i++) { - printf(" %lld ", (long long) stride[i]); - } + printf("Order %c\n", shape::order(shapeInfo)); +} - printf("\n"); +INLINEDEF _CUDA_HD void printShapeInfoLinear(const Nd4jLong *shapeInfo) { + int rank = shape::rank(shapeInfo); + int lim = shape::shapeInfoLength(rank); + printf("ShapeInfo: ["); + for (int i = 0; i < lim; i++) { + printf("%lld", (long long)shapeInfo[i]); - printf("Order %c\n",shape::order(shapeInfo)); + if (i < lim - 1) { + printf(", "); } - - INLINEDEF _CUDA_HD void printShapeInfoLinear(const Nd4jLong *shapeInfo) { - int rank = shape::rank(shapeInfo); - int lim = shape::shapeInfoLength(rank); - printf("ShapeInfo: ["); - for (int i = 0; i < lim; i++) { - printf("%lld", (long long) shapeInfo[i]); - - if (i < lim - 1) { - printf(", "); - } - } - printf("]\n"); + } + printf("]\n"); #ifndef __CUDA_ARCH__ - fflush(stdout); + fflush(stdout); #endif - } +} - INLINEDEF _CUDA_HD void printShapeInfoLinear(const char *msg, int rank, const Nd4jLong *shape, const Nd4jLong *strides) { - printf("%s : [", msg); - for (int i = 0; i < rank; i++) { - printf("%lld, ", (long long) shape[i]); - } +INLINEDEF _CUDA_HD void printShapeInfoLinear(const char *msg, int rank, + const Nd4jLong *shape, + const Nd4jLong *strides) { + printf("%s : [", msg); + for (int i = 0; i < rank; i++) { + printf("%lld, ", (long long)shape[i]); + } - for (int i = 0; i < rank; i++) { - printf("%lld", (long long) strides[i]); + for (int i = 0; i < rank; i++) { + printf("%lld", (long long)strides[i]); - if (i < rank - 1) - printf(", "); - } - printf("]\n"); + if (i < rank - 1) printf(", "); + } + printf("]\n"); #ifndef __CUDA_ARCH__ - fflush(stdout); + fflush(stdout); #endif - } +} - INLINEDEF _CUDA_HD void printShapeInfoLinear(const char *msg, const Nd4jLong *shapeInfo) { - int rank = shape::rank(shapeInfo); - int lim = shape::shapeInfoLength(rank); - printf("%s : [", msg); - for (int i = 0; i < lim; i++) { - printf("%lld", (long long) shapeInfo[i]); +INLINEDEF _CUDA_HD void printShapeInfoLinear(const char *msg, + const Nd4jLong *shapeInfo) { + int rank = shape::rank(shapeInfo); + int lim = shape::shapeInfoLength(rank); + printf("%s : [", msg); + for (int i = 0; i < lim; i++) { + printf("%lld", (long long)shapeInfo[i]); - if (i < lim - 1) { - printf(", "); - } - } - printf("]\n"); + if (i < lim - 1) { + printf(", "); + } + } + printf("]\n"); #ifndef __CUDACC__ - fflush(stdout); + fflush(stdout); #endif - } - - template - INLINEDEF _CUDA_HD void printArray(void *varr,int length, const char * message) { - auto arr = reinterpret_cast(varr); - if (message != nullptr) - printf("%s: [", message); - else - printf("Array: ["); +} - for (int i = 0; i < length; i ++) { - printf("%f", (float) arr[i]); - if (i + 1 < length) printf(", "); - } - printf("]\n"); +template +INLINEDEF _CUDA_HD void printArray(void *varr, int length, + const char *message) { + auto arr = reinterpret_cast(varr); + if (message != nullptr) + printf("%s: [", message); + else + printf("Array: ["); + + for (int i = 0; i < length; i++) { + printf("%f", (float)arr[i]); + if (i + 1 < length) printf(", "); + } + printf("]\n"); #ifndef __CUDACC__ - fflush(stdout); + fflush(stdout); #endif - } +} - INLINEDEF _CUDA_HD void printArray(float *arr,int length) { - printf("Array: ["); - for (int i = 0; i < length; i ++) { - printf("%f", arr[i]); - if (i + 1 < length) printf(", "); - } - printf("]\n"); - } +INLINEDEF _CUDA_HD void printArray(float *arr, int length) { + printf("Array: ["); + for (int i = 0; i < length; i++) { + printf("%f", arr[i]); + if (i + 1 < length) printf(", "); + } + printf("]\n"); +} /** * Given an linear index, element wise stride * and the length of each tad @@ -3554,54 +3723,55 @@ INLINEDEF _CUDA_HD Nd4jLong getOffset(const Nd4jLong *shapeInfo, const uint *coo * @param numElementsPerTad the number of elements * per tad */ - INLINEDEF _CUDA_HD int tadIndex(int i, int elementWiseStride, int numElementsPerTad) { - return i / (numElementsPerTad * elementWiseStride); - } +INLINEDEF _CUDA_HD int tadIndex(int i, int elementWiseStride, + int numElementsPerTad) { + return i / (numElementsPerTad * elementWiseStride); +} /** * Map a tad to a * reduction index. * @param tadIndexForOriginal the original tad index for the * split up problem (eg: split is dimension 3 mapping to a 2,3 problem) - * @param tadsForReduced the number of tads for the shrunk down problem (eg: 2,3) + * @param tadsForReduced the number of tads for the shrunk down problem (eg: + * 2,3) * @param tadsForOriginal the number of tads for the smaller problem (eg: 3) */ - INLINEDEF _CUDA_HD int reductionIndexForTad(int tadIndexForOriginal, int tadsForReduced, - int tadsForOriginal) { - if (tadIndexForOriginal == 0) - return 0; - return tadIndexForOriginal / (tadsForOriginal / tadsForReduced); - } - - - INLINEDEF _CUDA_HD void transposeInplace(Nd4jLong *shapeBuffer) { - int rank = shape::rank(shapeBuffer); - Nd4jLong *shape = shape::shapeOf(shapeBuffer); - Nd4jLong *strides = shape::stride(shapeBuffer); - - // swap shape - for (int e = 0; e < rank / 2; e++) { - int idx1 = rank - e - 1; - int idx2 = e; - int tmp = shape[idx2]; - shape[idx2] = shape[idx1]; - shape[idx1] = tmp; - } - - // swap strides - for (int e = 0; e < rank / 2; e++) { - int idx1 = rank - e - 1; - int idx2 = e; - int tmp = strides[idx2]; - strides[idx2] = strides[idx1]; - strides[idx1] = tmp; - } +INLINEDEF _CUDA_HD int reductionIndexForTad(int tadIndexForOriginal, + int tadsForReduced, + int tadsForOriginal) { + if (tadIndexForOriginal == 0) return 0; + return tadIndexForOriginal / (tadsForOriginal / tadsForReduced); +} - if (shape::order(shapeBuffer) == 'c') - shapeBuffer[shape::shapeInfoLength(shapeBuffer) - 1] = 102; - else - shapeBuffer[shape::shapeInfoLength(shapeBuffer) - 1] = 99; - } +INLINEDEF _CUDA_HD void transposeInplace(Nd4jLong *shapeBuffer) { + int rank = shape::rank(shapeBuffer); + Nd4jLong *shape = shape::shapeOf(shapeBuffer); + Nd4jLong *strides = shape::stride(shapeBuffer); + + // swap shape + for (int e = 0; e < rank / 2; e++) { + int idx1 = rank - e - 1; + int idx2 = e; + int tmp = shape[idx2]; + shape[idx2] = shape[idx1]; + shape[idx1] = tmp; + } + + // swap strides + for (int e = 0; e < rank / 2; e++) { + int idx1 = rank - e - 1; + int idx2 = e; + int tmp = strides[idx2]; + strides[idx2] = strides[idx1]; + strides[idx1] = tmp; + } + + if (shape::order(shapeBuffer) == 'c') + shapeBuffer[shape::shapeInfoLength(shapeBuffer) - 1] = 102; + else + shapeBuffer[shape::shapeInfoLength(shapeBuffer) - 1] = 99; +} /** * Tad index for linear @@ -3609,18 +3779,19 @@ INLINEDEF _CUDA_HD Nd4jLong getOffset(const Nd4jLong *shapeInfo, const uint *coo * @param tadLength * @return */ - INLINEDEF _CUDA_HD int tadIndexForLinear(int linearIndex, int tadLength) { - return linearIndex % tadLength; - } +INLINEDEF _CUDA_HD int tadIndexForLinear(int linearIndex, int tadLength) { + return linearIndex % tadLength; +} /** * Computes the number of tads * per reduce index for the * reduction tad. */ - INLINEDEF _CUDA_HD int tadsPerReduceIndex(int tadsForReduce, int tadsForOriginal) { - return tadsForOriginal / tadsForReduce; - } +INLINEDEF _CUDA_HD int tadsPerReduceIndex(int tadsForReduce, + int tadsForOriginal) { + return tadsForOriginal / tadsForReduce; +} /** * Maps a linear index to a reduction index @@ -3630,139 +3801,131 @@ INLINEDEF _CUDA_HD Nd4jLong getOffset(const Nd4jLong *shapeInfo, const uint *coo * @param tadNum the number of tads for the shrunken problem * @param originalTadNum the tad number for the reduced version of the problem */ - INLINEDEF _CUDA_HD int reductionIndexForLinear(int i, int elementWiseStride, int numElementsPerTad, - int tadNum, int originalTadNum) { - int tad = tadIndex(i, elementWiseStride, numElementsPerTad); - return reductionIndexForTad(tad, tadNum, originalTadNum); - } - - INLINEDEF _CUDA_HD Nd4jLong* createScalarShapeInfo() { - - traceNew(30); - - auto shape = new Nd4jLong[1]; - shape[0] = 1; - auto stride = new Nd4jLong[1]; - stride[0] = 1; - auto shapeInformation2 = new ShapeInformation(); - shapeInformation2->rank = 1; - shapeInformation2->offset = 0; - shapeInformation2->stride = stride; - shapeInformation2->shape = shape; - shapeInformation2->elementWiseStride = 1; - shapeInformation2->order = 99; - Nd4jLong *ret = shape::toShapeBuffer(shapeInformation2); - delete shapeInformation2; - delete[] shape; - delete[] stride; - return ret; - } - - INLINEDEF _CUDA_HD Nd4jLong* createScalarShapeInfo(Nd4jLong *ret) { - ret[0] = 2; - ret[1] = 1; - ret[2] = 1; - ret[3] = 1; - ret[4] = 1; - ret[5] = 0; - ret[6] = 1; - ret[7] = 99; - - return ret; - } - - -/** - * Returns the prod of the data - * up to the given length - */ - INLINEDEF _CUDA_HD Nd4jLong prodLong(const Nd4jLong *data, int length) { - Nd4jLong prod = 1; - for (int i = 0; i < length; i++) { - prod *= data[i]; - } - - return prod; - } - - INLINEDEF _CUDA_HD int rearMostLeftOverItem(Nd4jLong *data, Nd4jLong *dimension,int dimensionLength) { - Nd4jLong *stride = shape::stride(data); - //corner case: return the final item when its greater than the max, since its guaranteed to be left over - //note here that strides are interpreted in reverse for tad - //start from the front rather than the back - - int rank = shape::rank(data); - - - if(shape::order(data) == 'f') { - int dimIdx = dimensionLength - 1; - for(int i = rank - 1; i >= 0; i--) { - /** - * Needs to find an algorithm such that: - * looping backwards will find the highest dimension left - * that isn't included in the dimension index list. - * - * This can also be thought of as the last item of the first index - * of the difference between the full list of indices and - * the dimension indices. - * - * We should avoid excessive object creation by only looping backwards. - */ - if(dimension[dimIdx--] != i) { - int ret = stride[i]; - return ret; - } - } - } +INLINEDEF _CUDA_HD int reductionIndexForLinear(int i, int elementWiseStride, + int numElementsPerTad, + int tadNum, int originalTadNum) { + int tad = tadIndex(i, elementWiseStride, numElementsPerTad); + return reductionIndexForTad(tad, tadNum, originalTadNum); +} - else { - int dimIdx = dimensionLength - 1; - - for(int i = rank - 1; i >= 0; i--) { - /** - * Needs to find an algorithm such that: - * looping backwards will find the highest dimension left - * that isn't included in the dimension index list. - * - * This can also be thought of as the last item of the first index - * of the difference between the full list of indices and - * the dimension indices. - * - * We should avoid excessive object creation by only looping backwards. - */ - if(dimension[dimIdx--] != i) { - int ret = stride[i]; - return ret; - } - } - } +INLINEDEF _CUDA_HD Nd4jLong *createScalarShapeInfo() { + traceNew(30); + + auto shape = new Nd4jLong[1]; + shape[0] = 1; + auto stride = new Nd4jLong[1]; + stride[0] = 1; + auto shapeInformation2 = new ShapeInformation(); + shapeInformation2->rank = 1; + shapeInformation2->offset = 0; + shapeInformation2->stride = stride; + shapeInformation2->shape = shape; + shapeInformation2->elementWiseStride = 1; + shapeInformation2->order = 99; + Nd4jLong *ret = shape::toShapeBuffer(shapeInformation2); + delete shapeInformation2; + delete[] shape; + delete[] stride; + return ret; +} +INLINEDEF _CUDA_HD Nd4jLong *createScalarShapeInfo(Nd4jLong *ret) { + ret[0] = 2; + ret[1] = 1; + ret[2] = 1; + ret[3] = 1; + ret[4] = 1; + ret[5] = 0; + ret[6] = 1; + ret[7] = 99; + + return ret; +} +/** + * Returns the prod of the data + * up to the given length + */ +INLINEDEF _CUDA_HD Nd4jLong prodLong(const Nd4jLong *data, int length) { + Nd4jLong prod = 1; + for (int i = 0; i < length; i++) { + prod *= data[i]; + } + return prod; +} - int ret = stride[0]; +INLINEDEF _CUDA_HD int rearMostLeftOverItem(Nd4jLong *data, Nd4jLong *dimension, + int dimensionLength) { + Nd4jLong *stride = shape::stride(data); + // corner case: return the final item when its greater than the max, since its + // guaranteed to be left over note here that strides are interpreted in + // reverse for tad start from the front rather than the back + + int rank = shape::rank(data); + + if (shape::order(data) == 'f') { + int dimIdx = dimensionLength - 1; + for (int i = rank - 1; i >= 0; i--) { + /** + * Needs to find an algorithm such that: + * looping backwards will find the highest dimension left + * that isn't included in the dimension index list. + * + * This can also be thought of as the last item of the first index + * of the difference between the full list of indices and + * the dimension indices. + * + * We should avoid excessive object creation by only looping backwards. + */ + if (dimension[dimIdx--] != i) { + int ret = stride[i]; + return ret; + } + } + } + + else { + int dimIdx = dimensionLength - 1; + + for (int i = rank - 1; i >= 0; i--) { + /** + * Needs to find an algorithm such that: + * looping backwards will find the highest dimension left + * that isn't included in the dimension index list. + * + * This can also be thought of as the last item of the first index + * of the difference between the full list of indices and + * the dimension indices. + * + * We should avoid excessive object creation by only looping backwards. + */ + if (dimension[dimIdx--] != i) { + int ret = stride[i]; return ret; + } } + } + + int ret = stride[0]; + return ret; +} #ifdef __CUDACC__ - __device__ INLINEDEF void sweepShapeInfoBuffer(Nd4jLong *shapeInfoBuffer, Nd4jLong *targetBuffer) { - // we read first element, to find out length of our shapeInfoBuffer - int rank = shapeInfoBuffer[0]; - int len = shape::shapeInfoLength(rank); - for (int i = threadIdx.x; i < len; i += blockDim.x) - targetBuffer[i] = shapeInfoBuffer[i]; +__device__ INLINEDEF void sweepShapeInfoBuffer(Nd4jLong *shapeInfoBuffer, + Nd4jLong *targetBuffer) { + // we read first element, to find out length of our shapeInfoBuffer + int rank = shapeInfoBuffer[0]; + int len = shape::shapeInfoLength(rank); + for (int i = threadIdx.x; i < len; i += blockDim.x) + targetBuffer[i] = shapeInfoBuffer[i]; } #endif - - INLINEDEF _CUDA_HD Nd4jLong *shapeBufferOfNpy(cnpy::NpyArray arr) { - return shape::shapeBufferOfNpy(arr.shape.size(),(unsigned int*) arr.shape.data(),arr.fortranOrder); - } - - - - - +INLINEDEF _CUDA_HD Nd4jLong *shapeBufferOfNpy(cnpy::NpyArray arr) { + return shape::shapeBufferOfNpy( + arr.shape.size(), (unsigned int *)arr.shape.data(), arr.fortranOrder); +} // INLINEDEF _CUDA_HD Nd4jLong *shapeBufferOfNpyBuffer(char *buffer) { // unsigned Nd4jLong *shape; @@ -3774,90 +3937,84 @@ INLINEDEF _CUDA_HD Nd4jLong getOffset(const Nd4jLong *shapeInfo, const uint *coo // return ret; // } +INLINEDEF _CUDA_HD Nd4jLong *shapeBufferOfNpy(int rank, unsigned int *shape, + bool fortranOrder) { + if (fortranOrder) { + Nd4jLong *shapeBufferRet = + shape::shapeBufferFortran(rank, sd::FLOAT32, (Nd4jLong *)shape); + return shapeBufferRet; + } else { + Nd4jLong *newShape = new Nd4jLong[rank]; + for (int i = 0; i < rank; i++) { + newShape[i] = shape[i]; + } + + Nd4jLong *shapeBufferRet = shape::shapeBuffer(rank, sd::FLOAT32, newShape); + delete[] newShape; + return shapeBufferRet; + } +} - INLINEDEF _CUDA_HD Nd4jLong *shapeBufferOfNpy(int rank, unsigned int* shape,bool fortranOrder) { - if(fortranOrder) { - Nd4jLong *shapeBufferRet = shape::shapeBufferFortran(rank, sd::FLOAT32,(Nd4jLong *) shape); - return shapeBufferRet; - } - else { - Nd4jLong *newShape = new Nd4jLong[rank]; - for(int i = 0; i < rank; i++) { - newShape[i] = shape[i]; - } - - Nd4jLong *shapeBufferRet = shape::shapeBuffer(rank, sd::FLOAT32, newShape); - delete[] newShape; - return shapeBufferRet; - - } - } +INLINEDEF _CUDA_HD bool strideDescendingCAscendingF( + const Nd4jLong *shapeBuffer) { + int rank = shape::rank(shapeBuffer); + Nd4jLong *strides = shape::stride(const_cast(shapeBuffer)); + char order = shape::order(shapeBuffer); - INLINEDEF _CUDA_HD bool strideDescendingCAscendingF(const Nd4jLong *shapeBuffer) { - int rank = shape::rank(shapeBuffer); - Nd4jLong *strides = shape::stride(const_cast(shapeBuffer)); - char order = shape::order(shapeBuffer); - - if (shape::isRowVector(shapeBuffer) && strides[0] == 1 && strides[1] == 1) - return true; - - if (order == 'c') { - for (int i = 1; i < rank; i++) - if (strides[i-1] <= strides[i]) - return false; - return true; - } else if (order == 'f') { - for (int i = 1; i < rank; i++) - if (strides[i-1] >= strides[i]) - return false; - return true; - } else { - printf("Unknown order for array!\n"); - return false; - } - } + if (shape::isRowVector(shapeBuffer) && strides[0] == 1 && strides[1] == 1) + return true; - INLINEDEF _CUDA_HD bool isContiguous(const Nd4jLong* shapeInfo) { + if (order == 'c') { + for (int i = 1; i < rank; i++) + if (strides[i - 1] <= strides[i]) return false; + return true; + } else if (order == 'f') { + for (int i = 1; i < rank; i++) + if (strides[i - 1] >= strides[i]) return false; + return true; + } else { + printf("Unknown order for array!\n"); + return false; + } +} - return (order(shapeInfo) == 'c') && (elementWiseStride(shapeInfo) > 0); - } +INLINEDEF _CUDA_HD bool isContiguous(const Nd4jLong *shapeInfo) { + return (order(shapeInfo) == 'c') && (elementWiseStride(shapeInfo) > 0); +} ////////////////////////////////////////////////////////////////////////// // copy-past from java hasDefaultStridesForShape function -INLINEDEF _CUDA_HD bool areStridesDefault(const Nd4jLong* shapeInfo) { - - const int rank = shape::rank(shapeInfo); +INLINEDEF _CUDA_HD bool areStridesDefault(const Nd4jLong *shapeInfo) { + const int rank = shape::rank(shapeInfo); - if(rank == 0) - return true; - if(!strideDescendingCAscendingF(shapeInfo)) - return false; + if (rank == 0) return true; + if (!strideDescendingCAscendingF(shapeInfo)) return false; - Nd4jLong defaultShapeInfo[MAX_SHAPEINFOLENGTH]; - memcpy(defaultShapeInfo, shapeInfo, shape::shapeInfoByteLength(shapeInfo)); - shape::updateStrides(defaultShapeInfo, shape::order(shapeInfo)); + Nd4jLong defaultShapeInfo[MAX_SHAPEINFOLENGTH]; + memcpy(defaultShapeInfo, shapeInfo, shape::shapeInfoByteLength(shapeInfo)); + shape::updateStrides(defaultShapeInfo, shape::order(shapeInfo)); - bool result = true; - for(int i = rank+1; i <= 2*rank; ++i) - if(defaultShapeInfo[i] != shapeInfo[i]) { - result = false; - break; - } + bool result = true; + for (int i = rank + 1; i <= 2 * rank; ++i) + if (defaultShapeInfo[i] != shapeInfo[i]) { + result = false; + break; + } - return result; + return result; } -// INLINEDEF _CUDA_H bool reshapeC(const int oldRank, Nd4jLong* oldShape, const int newRank, Nd4jLong* newShapeOf, bool isFOrder, Nd4jLong* target) { +// INLINEDEF _CUDA_H bool reshapeC(const int oldRank, Nd4jLong* oldShape, const +// int newRank, Nd4jLong* newShapeOf, bool isFOrder, Nd4jLong* target) { // int oldnd; // Nd4jLong* olddims = shape::copyOf(oldRank, shape::shapeOf(oldShape)); -// Nd4jLong* oldstrides = shape::copyOf(oldRank, shape::stride(oldShape)); -// int np, op, last_stride; -// int oi, oj, ok, ni, nj, nk; -// Nd4jLong* newStrides = new Nd4jLong[newRank]; -// oldnd = 0; +// Nd4jLong* oldstrides = shape::copyOf(oldRank, +// shape::stride(oldShape)); int np, op, last_stride; int oi, oj, ok, +// ni, nj, nk; Nd4jLong* newStrides = new Nd4jLong[newRank]; oldnd = 0; // /* -// * Remove axes with dimension 1 from the old array. They have no effect +// * Remove axes with dimension 1 from the old array. They have no +// effect // * but would need special cases since their strides do not matter. // */ // for (oi = 0; oi < oldRank; oi++) { @@ -3894,11 +4051,8 @@ INLINEDEF _CUDA_HD bool areStridesDefault(const Nd4jLong* shapeInfo) { // return false; // } -// /* oi to oj and ni to nj give the axis ranges currently worked with */ -// oi = 0; -// oj = 1; -// ni = 0; -// nj = 1; +// /* oi to oj and ni to nj give the axis ranges currently worked with +// */ oi = 0; oj = 1; ni = 0; nj = 1; // while (ni < newRank && oi < oldnd) { // np = newShapeOf[ni]; @@ -3926,7 +4080,8 @@ INLINEDEF _CUDA_HD bool areStridesDefault(const Nd4jLong* shapeInfo) { // } // } else { // /* C order */ -// if (oldstrides[ok] != olddims[ok + 1] * oldstrides[ok + 1]) { +// if (oldstrides[ok] != olddims[ok + 1] * oldstrides[ok + +// 1]) { // /* not contiguous enough */ // delete[] olddims; // delete[] oldstrides; @@ -3977,7 +4132,8 @@ INLINEDEF _CUDA_HD bool areStridesDefault(const Nd4jLong* shapeInfo) { // target[shape::shapeInfoLength(newRank) - 3] = 0; // target[shape::shapeInfoLength(newRank) - 2] = 0; // target[shape::shapeInfoLength(newRank) - 1] = isFOrder ? 102 : 99; -// sd::ArrayOptions::setDataType(target, sd::ArrayOptions::dataType(oldShape)); +// sd::ArrayOptions::setDataType(target, +// sd::ArrayOptions::dataType(oldShape)); // delete[] olddims; // delete[] oldstrides; @@ -3987,18 +4143,26 @@ INLINEDEF _CUDA_HD bool areStridesDefault(const Nd4jLong* shapeInfo) { // } ////////////////////////////////////////////////////////////////////// -// INLINEDEF _CUDA_H bool reshapeC(const int oldRank, const Nd4jLong* oldShapeInfo, const int newRank, const Nd4jLong* newShape, Nd4jLong* newShapeInfo) { +// INLINEDEF _CUDA_H bool reshapeC(const int oldRank, const Nd4jLong* +// oldShapeInfo, const int newRank, const Nd4jLong* newShape, Nd4jLong* +// newShapeInfo) { -// // PLEASE NOTE !: reshaping not-permuted (ews=1) array in f order (except insertion/elimination of unities) will definitely cause allocation of new buffer for array elements -// // also this function takes into account identical shapes automatically, namely in that case oldShapeInfo is completely copied to newShapeInfo +// // PLEASE NOTE !: reshaping not-permuted (ews=1) array in f order +// (except insertion/elimination of unities) will definitely cause +// allocation of new buffer for array elements +// // also this function takes into account identical shapes +// automatically, namely in that case oldShapeInfo is completely copied +// to newShapeInfo // newShapeInfo[0] = newRank; // memcpy(newShapeInfo + 1, newShape, newRank * sizeof(Nd4jLong)); // Nd4jLong* newStrides = shape::stride(newShapeInfo); -// const Nd4jLong* oldShape = shape::shapeOf(const_cast(oldShapeInfo)); -// const Nd4jLong* oldStrides = shape::stride(const_cast(oldShapeInfo)); -// Nd4jLong oldStart(0), oldStop(1), newStart(0), newStop(1), newDim, oldDim; +// const Nd4jLong* oldShape = +// shape::shapeOf(const_cast(oldShapeInfo)); const Nd4jLong* +// oldStrides = shape::stride(const_cast(oldShapeInfo)); +// Nd4jLong oldStart(0), oldStop(1), newStart(0), newStop(1), newDim, +// oldDim; // while (newStart < newRank && oldStart < oldRank) { @@ -4009,13 +4173,16 @@ INLINEDEF _CUDA_HD bool areStridesDefault(const Nd4jLong* shapeInfo) { // if (newDim < oldDim) newDim *= newShape[newStop++]; // else oldDim *= oldShape[oldStop++]; -// // ------ Check whether the original axes can be combined ------ // -// for (int step = 1, i = oldStart; i < oldStop - 1; ++i) { -// if(oldShape[i] == 1) // skip unity-dimension and its stride +// // ------ Check whether the original axes can be combined ------ +// // for (int step = 1, i = oldStart; i < oldStop - 1; ++i) { +// if(oldShape[i] == 1) // skip unity-dimension +// and its stride // continue; // while((i + step) < oldRank && oldShape[i + step] == 1) -// ++step; // skip following unity-dimensions and its strides if such are present -// if((i + step) < oldRank && oldStrides[i] != oldShape[i + step] * oldStrides[i + step]) +// ++step; // skip following +// unity-dimensions and its strides if such are present +// if((i + step) < oldRank && oldStrides[i] != oldShape[i + +// step] * oldStrides[i + step]) // return false; // not contiguous enough // } @@ -4027,902 +4194,956 @@ INLINEDEF _CUDA_HD bool areStridesDefault(const Nd4jLong* shapeInfo) { // oldStart = oldStop++; // } -// // rest of strides should be unities (if there is remainder in strides space, that is newStart < newRank) -// for (int i = newStart; i < newRank; ++i) +// // rest of strides should be unities (if there is remainder in +// strides space, that is newStart < newRank) for (int i = newStart; i < +// newRank; ++i) // newStrides[i] = 1; -// newShapeInfo[2 * newRank + 3] = shape::order(oldShapeInfo); // order -// newShapeInfo[2 * newRank + 2] = shape::elementWiseStride(oldShapeInfo); // ews -// newShapeInfo[2 * newRank + 1] = shape::type(oldShapeInfo); // type +// newShapeInfo[2 * newRank + 3] = shape::order(oldShapeInfo); // order +// newShapeInfo[2 * newRank + 2] = +// shape::elementWiseStride(oldShapeInfo); // ews newShapeInfo[2 * +// newRank + 1] = shape::type(oldShapeInfo); // type // return true; // } ////////////////////////////////////////////////////////////////////// -INLINEDEF _CUDA_HD bool reshapeC(const Nd4jLong* oldShapeInfo, const char newOrder, const int newRank, const Nd4jLong* newShape, Nd4jLong* newShapeInfo) { - - // copy shape from newShape into newShapeInfo - newShapeInfo[0] = newRank; - memcpy(newShapeInfo + 1, newShape, newRank * sizeof(Nd4jLong)); - - // copy order - newShapeInfo[2 * newRank + 3] = newOrder; - - return shape::reshapeC(oldShapeInfo, newShapeInfo); +INLINEDEF _CUDA_HD bool reshapeC(const Nd4jLong *oldShapeInfo, + const char newOrder, const int newRank, + const Nd4jLong *newShape, + Nd4jLong *newShapeInfo) { + // copy shape from newShape into newShapeInfo + newShapeInfo[0] = newRank; + memcpy(newShapeInfo + 1, newShape, newRank * sizeof(Nd4jLong)); + + // copy order + newShapeInfo[2 * newRank + 3] = newOrder; + + return shape::reshapeC(oldShapeInfo, newShapeInfo); } ////////////////////////////////////////////////////////////////////// -INLINEDEF _CUDA_HD bool reshapeC(const Nd4jLong* oldShapeInfo, Nd4jLong* newShapeInfo) { - - // newShapeInfo contains rank, shape and order; but no strides, type and ews - - const int newRank = shape::rank(newShapeInfo); - - // if oldShapeInfo is scalar or vector with length=1 - if(shape::length(oldShapeInfo) == 1) { - for (uint i = 0; i < newRank; ++i) - shape::stride(newShapeInfo)[i] = 1; - newShapeInfo[2 * newRank + 1] = shape::type(oldShapeInfo); - *shape::ews(newShapeInfo) = 1; - return true; - } - - const auto oldOrder = shape::order(oldShapeInfo); - const auto newOrder = shape::order(newShapeInfo); - const auto oldEws = shape::elementWiseStride(const_cast(oldShapeInfo)); - - if(oldEws > 0 && oldOrder != newOrder) - return false; - - // *** FIRST STAGE - exclude unity dimensions from oldShapeInfo and newShapeInfo (if such are present of course), since they don't affect on strides evaluation, however they complicate code - - // FIXME - indeed we don't need to allocate so large memory amount (4*MAX_RANK), sufficient amount is (2*oldNumOfNonUnities + 2*newNumOfNonUnities) - Nd4jLong tempBuffer[4*MAX_RANK]; - Nd4jLong *oldShape = tempBuffer, *newShape = tempBuffer + 2*MAX_RANK, *oldStrides, *newStrides; - - // exclude unities from oldShapeInfo - const int oldNumOfNonUnities = shape::excludeUnitiesFromShapeInfo(oldShapeInfo, oldShape, oldStrides); - const int newNumOfNonUnities = shape::excludeUnitiesFromShapeInfo(newShapeInfo, newShape, newStrides); - - // *** SECOND STAGE - strides evaluation - - int oldStart(0), oldStop(1), newStart(0), newStop(1), newDim, oldDim; - - while (newStart < newNumOfNonUnities && oldStart < oldNumOfNonUnities) { - - newDim = newShape[newStart]; - oldDim = oldShape[oldStart]; - - while (newDim != oldDim && newDim > 0 && oldDim > 0) { - - if (newDim < oldDim) - newDim *= newShape[newStop++]; - else - oldDim *= oldShape[oldStop++]; - } - - // check c-contiguous of old axes range - for(uint i = oldStart; i < oldStop - 1; ++i) // do not check value of last stride, it doesn't matter - if(oldStrides[i] != oldShape[i + 1] * oldStrides[i + 1]) - return false; // not contiguous - - // fill newStrides in c manner - newStrides[newStop - 1] = oldStrides[oldStop - 1]; // copy last stride - for (int i = newStop - 2; i >= newStart; --i) - newStrides[i] = newStrides[i + 1] * newShape[i + 1]; - - newStart = newStop++; - oldStart = oldStop++; - } - - // fill new calculated strides into newShapeInfo, take into account possible unities in shape - for (int j = 0, i = 0; i < newRank; ++i) - shape::stride(newShapeInfo)[i] = (shape::shapeOf(newShapeInfo)[i] == 1) ? 1 : newStrides[j++]; +INLINEDEF _CUDA_HD bool reshapeC(const Nd4jLong *oldShapeInfo, + Nd4jLong *newShapeInfo) { + // newShapeInfo contains rank, shape and order; but no strides, type and ews - // set ews - if(oldEws == 0) - shape::checkStridesEwsAndOrder(newShapeInfo, newOrder, newNumOfNonUnities, newShape, newStrides); // set ews and order - else { - newShapeInfo[2 * newRank + 3] = oldOrder; // order - *shape::ews(newShapeInfo) = oldEws; // ews - } - - sd::ArrayOptions::copyDataType(newShapeInfo, oldShapeInfo); // type + const int newRank = shape::rank(newShapeInfo); + // if oldShapeInfo is scalar or vector with length=1 + if (shape::length(oldShapeInfo) == 1) { + for (uint i = 0; i < newRank; ++i) shape::stride(newShapeInfo)[i] = 1; + newShapeInfo[2 * newRank + 1] = shape::type(oldShapeInfo); + *shape::ews(newShapeInfo) = 1; return true; + } + + const auto oldOrder = shape::order(oldShapeInfo); + const auto newOrder = shape::order(newShapeInfo); + const auto oldEws = + shape::elementWiseStride(const_cast(oldShapeInfo)); + + if (oldEws > 0 && oldOrder != newOrder) return false; + + // *** FIRST STAGE - exclude unity dimensions from oldShapeInfo and + // newShapeInfo (if such are present of course), since they don't affect on + // strides evaluation, however they complicate code + + // FIXME - indeed we don't need to allocate so large memory amount + // (4*MAX_RANK), sufficient amount is (2*oldNumOfNonUnities + + // 2*newNumOfNonUnities) + Nd4jLong tempBuffer[4 * MAX_RANK]; + Nd4jLong *oldShape = tempBuffer, *newShape = tempBuffer + 2 * MAX_RANK, + *oldStrides, *newStrides; + + // exclude unities from oldShapeInfo + const int oldNumOfNonUnities = + shape::excludeUnitiesFromShapeInfo(oldShapeInfo, oldShape, oldStrides); + const int newNumOfNonUnities = + shape::excludeUnitiesFromShapeInfo(newShapeInfo, newShape, newStrides); + + // *** SECOND STAGE - strides evaluation + + int oldStart(0), oldStop(1), newStart(0), newStop(1), newDim, oldDim; + + while (newStart < newNumOfNonUnities && oldStart < oldNumOfNonUnities) { + newDim = newShape[newStart]; + oldDim = oldShape[oldStart]; + + while (newDim != oldDim && newDim > 0 && oldDim > 0) { + if (newDim < oldDim) + newDim *= newShape[newStop++]; + else + oldDim *= oldShape[oldStop++]; + } + + // check c-contiguous of old axes range + for (uint i = oldStart; i < oldStop - 1; + ++i) // do not check value of last stride, it doesn't matter + if (oldStrides[i] != oldShape[i + 1] * oldStrides[i + 1]) + return false; // not contiguous + + // fill newStrides in c manner + newStrides[newStop - 1] = oldStrides[oldStop - 1]; // copy last stride + for (int i = newStop - 2; i >= newStart; --i) + newStrides[i] = newStrides[i + 1] * newShape[i + 1]; + + newStart = newStop++; + oldStart = oldStop++; + } + + // fill new calculated strides into newShapeInfo, take into account possible + // unities in shape + for (int j = 0, i = 0; i < newRank; ++i) + shape::stride(newShapeInfo)[i] = + (shape::shapeOf(newShapeInfo)[i] == 1) ? 1 : newStrides[j++]; + + // set ews + if (oldEws == 0) + shape::checkStridesEwsAndOrder(newShapeInfo, newOrder, newNumOfNonUnities, + newShape, newStrides); // set ews and order + else { + newShapeInfo[2 * newRank + 3] = oldOrder; // order + *shape::ews(newShapeInfo) = oldEws; // ews + } + + sd::ArrayOptions::copyDataType(newShapeInfo, oldShapeInfo); // type + + return true; } - - INLINEDEF _CUDA_H bool canReshape(const int oldRank, Nd4jLong* oldShape, const int newRank, Nd4jLong* newShapeOf, bool isFOrder) { - Nd4jLong oldnd; - Nd4jLong* oldDims = shape::copyOf(oldRank, shape::shapeOf(oldShape)); - Nd4jLong* oldStrides = shape::copyOf(oldRank, shape::stride(oldShape)); - Nd4jLong np, op, last_stride; - Nd4jLong oldStart, oldStop, ok, newStart, newStop, nk; - auto newStrides = new Nd4jLong[newRank]; - oldnd = 0; - - /* - * Remove axes with dimension 1 from the old array. They have no effect - * but would need special cases since their strides do not matter. - */ - for (oldStart = 0; oldStart < oldRank; oldStart++) { - if (shape::shapeOf(oldShape)[oldStart] != 1) { - oldDims[oldnd] = shape::shapeOf(oldShape)[oldStart]; - oldStrides[oldnd] = shape::stride(oldShape)[oldStart]; - oldnd++; - } - } - - np = 1; - for (newStart = 0; newStart < newRank; newStart++) { - np *= newShapeOf[newStart]; - } - op = 1; - for (oldStart = 0; oldStart < oldnd; oldStart++) { - op *= oldDims[oldStart]; - } - if (np != op) { - /* different total sizes; no hope */ - delete[] oldDims; - delete[] oldStrides; - delete[] newStrides; - - return false; - } - - if (np == 0) { - /* the current code does not handle 0-sized arrays, so give up */ - delete[] oldDims; - delete[] oldStrides; - delete[] newStrides; - - return false; - } - - /* oldStart to oldStop and newStart to newStop give the axis ranges currently worked with */ - oldStart = 0; - oldStop = 1; - newStart = 0; - newStop = 1; - - while (newStart < newRank && oldStart < oldnd) { - np = newShapeOf[newStart]; - op = oldDims[oldStart]; - - while (np != op) { - if (np < op) { - /* Misses trailing 1s, these are handled later */ - np *= newShapeOf[newStop++]; - } else { - op *= oldDims[oldStop++]; - } - } - - /* Check whether the original axes can be combined */ - for (ok = oldStart; ok < oldStop - 1; ok++) { - if (isFOrder) { - if (oldStrides[ok + 1] != oldDims[ok] * oldStrides[ok]) { - /* not contiguous enough */ - delete[] oldDims; - delete[] oldStrides; - delete[] newStrides; - - return false; - } - } else { - /* C order */ - if (oldStrides[ok] != oldDims[ok + 1] * oldStrides[ok + 1]) { - /* not contiguous enough */ - delete[] oldDims; - delete[] oldStrides; - delete[] newStrides; - - return false; - } - } - } - - /* Calculate new strides for all axes currently worked with */ - if (isFOrder) { - newStrides[newStart] = oldStrides[oldStart]; - for (nk = newStart + 1; nk < newStop; nk++) { - newStrides[nk] = newStrides[nk - 1] * newShapeOf[nk - 1]; - } - } else { - /* C order */ - newStrides[newStop - 1] = oldStrides[oldStop - 1]; - for (nk = newStop - 1; nk > newStart; nk--) { - newStrides[nk - 1] = newStrides[nk] * newShapeOf[nk]; - } - } - newStart = newStop++; - oldStart = oldStop++; +INLINEDEF _CUDA_H bool canReshape(const int oldRank, Nd4jLong *oldShape, + const int newRank, Nd4jLong *newShapeOf, + bool isFOrder) { + Nd4jLong oldnd; + Nd4jLong *oldDims = shape::copyOf(oldRank, shape::shapeOf(oldShape)); + Nd4jLong *oldStrides = shape::copyOf(oldRank, shape::stride(oldShape)); + Nd4jLong np, op, last_stride; + Nd4jLong oldStart, oldStop, ok, newStart, newStop, nk; + auto newStrides = new Nd4jLong[newRank]; + oldnd = 0; + + /* + * Remove axes with dimension 1 from the old array. They have no effect + * but would need special cases since their strides do not matter. + */ + for (oldStart = 0; oldStart < oldRank; oldStart++) { + if (shape::shapeOf(oldShape)[oldStart] != 1) { + oldDims[oldnd] = shape::shapeOf(oldShape)[oldStart]; + oldStrides[oldnd] = shape::stride(oldShape)[oldStart]; + oldnd++; + } + } + + np = 1; + for (newStart = 0; newStart < newRank; newStart++) { + np *= newShapeOf[newStart]; + } + op = 1; + for (oldStart = 0; oldStart < oldnd; oldStart++) { + op *= oldDims[oldStart]; + } + if (np != op) { + /* different total sizes; no hope */ + delete[] oldDims; + delete[] oldStrides; + delete[] newStrides; + + return false; + } + + if (np == 0) { + /* the current code does not handle 0-sized arrays, so give up */ + delete[] oldDims; + delete[] oldStrides; + delete[] newStrides; + + return false; + } + + /* oldStart to oldStop and newStart to newStop give the axis ranges currently + * worked with */ + oldStart = 0; + oldStop = 1; + newStart = 0; + newStop = 1; + + while (newStart < newRank && oldStart < oldnd) { + np = newShapeOf[newStart]; + op = oldDims[oldStart]; + + while (np != op) { + if (np < op) { + /* Misses trailing 1s, these are handled later */ + np *= newShapeOf[newStop++]; + } else { + op *= oldDims[oldStop++]; + } + } + + /* Check whether the original axes can be combined */ + for (ok = oldStart; ok < oldStop - 1; ok++) { + if (isFOrder) { + if (oldStrides[ok + 1] != oldDims[ok] * oldStrides[ok]) { + /* not contiguous enough */ + delete[] oldDims; + delete[] oldStrides; + delete[] newStrides; + + return false; } - - delete[] oldDims; - delete[] oldStrides; - delete[] newStrides; - - return true; - } - - // this function checks the consistence of dimensions with array rank (negative dimensions, too large dimensions, too big number of dimensions) - // also it sorts input array of dimensions, this operation is also necessary for creating TAD object - INLINEDEF _CUDA_H void checkDimensions(const int rank, std::vector& dimensions) { - - int dimSize = dimensions.size(); - if(dimSize == 0) - throw std::runtime_error("shape::checkDimensions method: array of dimensions is empty!"); - // check presence of negative dimensions and if they are present transform them to positive ones -dim -> rank - |dim| - for(auto& dim : dimensions) - if(dim < 0) - dim += rank; - // sort input array of dimensions, this operation is also necessary for creating TAD object in external methods - if (dimSize > 1) { - std::sort(dimensions.begin(), dimensions.end()); - // remove duplicates if they are present - dimensions.erase(std::unique(dimensions.begin(), dimensions.end()), dimensions.end()); + } else { + /* C order */ + if (oldStrides[ok] != oldDims[ok + 1] * oldStrides[ok + 1]) { + /* not contiguous enough */ + delete[] oldDims; + delete[] oldStrides; + delete[] newStrides; + + return false; } - // check whether number of dimensions is to big (>rank) - dimSize = dimensions.size(); - if(dimSize > rank) - throw std::runtime_error("shape::checkDimensions method: number of input dimensions is too big ( > rank of array)!"); - // check if min dimension is still negative and whether max dimension is bigger then rank-1 - if(dimensions[0] < 0 || dimensions.back() > (rank-1)) - throw std::runtime_error("shape::checkDimensions method: the negative dimension is still present in input array after transform or the too big dimension is present ( > rank of array) !"); - } + } + } + + /* Calculate new strides for all axes currently worked with */ + if (isFOrder) { + newStrides[newStart] = oldStrides[oldStart]; + for (nk = newStart + 1; nk < newStop; nk++) { + newStrides[nk] = newStrides[nk - 1] * newShapeOf[nk - 1]; + } + } else { + /* C order */ + newStrides[newStop - 1] = oldStrides[oldStop - 1]; + for (nk = newStop - 1; nk > newStart; nk--) { + newStrides[nk - 1] = newStrides[nk] * newShapeOf[nk]; + } + } + newStart = newStop++; + oldStart = oldStop++; + } + + delete[] oldDims; + delete[] oldStrides; + delete[] newStrides; + + return true; +} +// this function checks the consistence of dimensions with array rank (negative +// dimensions, too large dimensions, too big number of dimensions) also it sorts +// input array of dimensions, this operation is also necessary for creating TAD +// object +INLINEDEF _CUDA_H void checkDimensions(const int rank, + std::vector &dimensions) { + int dimSize = dimensions.size(); + if (dimSize == 0) + throw std::runtime_error( + "shape::checkDimensions method: array of dimensions is empty!"); + // check presence of negative dimensions and if they are present transform + // them to positive ones -dim -> rank - |dim| + for (auto &dim : dimensions) + if (dim < 0) dim += rank; + // sort input array of dimensions, this operation is also necessary for + // creating TAD object in external methods + if (dimSize > 1) { + std::sort(dimensions.begin(), dimensions.end()); + // remove duplicates if they are present + dimensions.erase(std::unique(dimensions.begin(), dimensions.end()), + dimensions.end()); + } + // check whether number of dimensions is to big (>rank) + dimSize = dimensions.size(); + if (dimSize > rank) + throw std::runtime_error( + "shape::checkDimensions method: number of input dimensions is too big " + "( > rank of array)!"); + // check if min dimension is still negative and whether max dimension is + // bigger then rank-1 + if (dimensions[0] < 0 || dimensions.back() > (rank - 1)) + throw std::runtime_error( + "shape::checkDimensions method: the negative dimension is still " + "present in input array after transform or the too big dimension is " + "present ( > rank of array) !"); +} // max array is outer for min array, min array is sub-array of max array -// function calculates the coordinates of min array (and saves them into minIdxs) given coordinates of max array (already stored in maxIdxs) -INLINEDEF _CUDA_HD void maxIndToMinInd(int* maxIdxs, int* minIdxs, const Nd4jLong* maxShapeInfo, const Nd4jLong* minShapeInfo, const int* dimsToExclude, int dimsLen) { - - const auto maxRank = shape::rank(maxShapeInfo); - const auto minRank = shape::rank(minShapeInfo); - - // if(minRank >= maxRank) - // throw std::runtime_error("shape::maxIndToMinInd method: rank of min array should be smaller then rank of max array!"); - - if(dimsLen == -1) - dimsLen = maxRank - minRank; // if size is not given (= -1) then it is equal to ranks difference - - if(maxRank == minRank) { - - if(dimsToExclude == nullptr) { // --> means dimsToExclude == {0,1,2,...,dimsLen-1} - - for (int i = 0; i < maxRank; ++i) { - - if(i < dimsLen) - minIdxs[i] = maxIdxs[i]; - else { - if(maxIdxs[i] > minShapeInfo[i + 1]) - minIdxs[i] = maxIdxs[i] % minShapeInfo[i + 1]; - else if(maxIdxs[i] == minShapeInfo[i + 1]) - minIdxs[i] = 0; - else - minIdxs[i] = maxIdxs[i]; - } - } - } +// function calculates the coordinates of min array (and saves them into +// minIdxs) given coordinates of max array (already stored in maxIdxs) +INLINEDEF _CUDA_HD void maxIndToMinInd(int *maxIdxs, int *minIdxs, + const Nd4jLong *maxShapeInfo, + const Nd4jLong *minShapeInfo, + const int *dimsToExclude, int dimsLen) { + const auto maxRank = shape::rank(maxShapeInfo); + const auto minRank = shape::rank(minShapeInfo); + + // if(minRank >= maxRank) + // throw std::runtime_error("shape::maxIndToMinInd method: rank of min + // array should be smaller then rank of max array!"); + + if (dimsLen == -1) + dimsLen = maxRank - minRank; // if size is not given (= -1) then it is + // equal to ranks difference + + if (maxRank == minRank) { + if (dimsToExclude == + nullptr) { // --> means dimsToExclude == {0,1,2,...,dimsLen-1} + + for (int i = 0; i < maxRank; ++i) { + if (i < dimsLen) + minIdxs[i] = maxIdxs[i]; else { - - for (int i = 0, dim = 0; i < maxRank; ++i) { - - if(dim < dimsLen && dimsToExclude[dim] == i) { - minIdxs[i] = maxIdxs[i]; - ++dim; - continue; - } - - if(maxIdxs[i] > minShapeInfo[i + 1]) - minIdxs[i] = maxIdxs[i] % minShapeInfo[i + 1]; - else if(maxIdxs[i] == minShapeInfo[i + 1]) - minIdxs[i] = 0; - else - minIdxs[i] = maxIdxs[i]; - } + if (maxIdxs[i] > minShapeInfo[i + 1]) + minIdxs[i] = maxIdxs[i] % minShapeInfo[i + 1]; + else if (maxIdxs[i] == minShapeInfo[i + 1]) + minIdxs[i] = 0; + else + minIdxs[i] = maxIdxs[i]; } - } - else { - - if(dimsToExclude == nullptr) { // --> means dimsToExclude == {0,1,2,...,dimsLen-1} - - for (int i = 0; i < minRank; ++i) { - - if(maxIdxs[i + dimsLen] > minShapeInfo[i + 1]) - minIdxs[i] = maxIdxs[i + dimsLen] % minShapeInfo[i + 1]; - else if(maxIdxs[i + dimsLen] == minShapeInfo[i + 1]) - minIdxs[i] = 0; - else - minIdxs[i] = maxIdxs[i + dimsLen]; - } + } + } else { + for (int i = 0, dim = 0; i < maxRank; ++i) { + if (dim < dimsLen && dimsToExclude[dim] == i) { + minIdxs[i] = maxIdxs[i]; + ++dim; + continue; } - else { - for (int minI = 0, maxI = 0, dim = 0; maxI < maxRank; ++maxI) { - - if(dim < dimsLen && dimsToExclude[dim] == maxI) { - ++dim; - continue; - } - - if(maxIdxs[maxI] == minShapeInfo[minI + 1]) - minIdxs[minI] = 0; - else if(maxIdxs[maxI] > minShapeInfo[minI + 1]) - minIdxs[minI] = maxIdxs[maxI] % minShapeInfo[minI + 1]; - else - minIdxs[minI] = maxIdxs[maxI]; - ++minI; - } + if (maxIdxs[i] > minShapeInfo[i + 1]) + minIdxs[i] = maxIdxs[i] % minShapeInfo[i + 1]; + else if (maxIdxs[i] == minShapeInfo[i + 1]) + minIdxs[i] = 0; + else + minIdxs[i] = maxIdxs[i]; + } + } + } else { + if (dimsToExclude == + nullptr) { // --> means dimsToExclude == {0,1,2,...,dimsLen-1} + + for (int i = 0; i < minRank; ++i) { + if (maxIdxs[i + dimsLen] > minShapeInfo[i + 1]) + minIdxs[i] = maxIdxs[i + dimsLen] % minShapeInfo[i + 1]; + else if (maxIdxs[i + dimsLen] == minShapeInfo[i + 1]) + minIdxs[i] = 0; + else + minIdxs[i] = maxIdxs[i + dimsLen]; + } + } else { + for (int minI = 0, maxI = 0, dim = 0; maxI < maxRank; ++maxI) { + if (dim < dimsLen && dimsToExclude[dim] == maxI) { + ++dim; + continue; } - } -} - - ////////////////////////////////////////////////////////////////////// - INLINEDEF _CUDA_HD Nd4jLong subArrayIndex(const Nd4jLong maxIdx, const Nd4jLong* maxShapeInfo, const Nd4jLong* minShapeInfo, const int* dimsToExclude, const int dimsLen) { - int maxIdxs[MAX_RANK]; - shape::index2coords(const_cast(maxIdx), maxShapeInfo, maxIdxs); - - int minIdxs[MAX_RANK]; - maxIndToMinInd(maxIdxs, minIdxs, maxShapeInfo, minShapeInfo, dimsToExclude, dimsLen); - - return shape::coords2index(minShapeInfo, minIdxs); - } - - ////////////////////////////////////////////////////////////////////// - INLINEDEF _CUDA_HD Nd4jLong subArrayOffset(const Nd4jLong maxIdx, const Nd4jLong* maxShapeInfo, const Nd4jLong* minShapeInfo, const int* dimsToExclude, const int dimsLen) { - - int maxIdxs[MAX_RANK]; - shape::index2coords(const_cast(maxIdx), maxShapeInfo, maxIdxs); - - int minIdxs[MAX_RANK]; - maxIndToMinInd(maxIdxs, minIdxs, maxShapeInfo, minShapeInfo, dimsToExclude, dimsLen); - - return getOffset(minShapeInfo, minIdxs); + if (maxIdxs[maxI] == minShapeInfo[minI + 1]) + minIdxs[minI] = 0; + else if (maxIdxs[maxI] > minShapeInfo[minI + 1]) + minIdxs[minI] = maxIdxs[maxI] % minShapeInfo[minI + 1]; + else + minIdxs[minI] = maxIdxs[maxI]; + ++minI; + } } + } +} - ////////////////////////////////////////////////////////////////////// - INLINEDEF _CUDA_HD int outerArrayOffsets(Nd4jLong* maxOffsets, const Nd4jLong minIdx, const Nd4jLong* maxShapeInfo, const Nd4jLong* minShapeInfo, int* memBuff, const int* dimsToExclude) { - - const auto rankMin = shape::rank(minShapeInfo); - const auto rankMax = shape::rank(maxShapeInfo); - - // if(rankMin >= rankMax) - // throw std::runtime_error("shape::subArrayIndex method: rank of min array should be smaller then rank of max array!"); - - const auto diff = rankMax - rankMin; // the size of dimsToExclude is equal to diff - - int* indices = memBuff; - int* increment = memBuff + rankMax; - - int N, minI, maxI; - - // calculate min per-dim-indices which corresponds to absolute minIdx index - shape::index2coords(minIdx, minShapeInfo, indices); +////////////////////////////////////////////////////////////////////// +INLINEDEF _CUDA_HD Nd4jLong subArrayIndex(const Nd4jLong maxIdx, + const Nd4jLong *maxShapeInfo, + const Nd4jLong *minShapeInfo, + const int *dimsToExclude, + const int dimsLen) { + int maxIdxs[MAX_RANK]; + shape::index2coords(const_cast(maxIdx), maxShapeInfo, maxIdxs); + + int minIdxs[MAX_RANK]; + maxIndToMinInd(maxIdxs, minIdxs, maxShapeInfo, minShapeInfo, dimsToExclude, + dimsLen); + + return shape::coords2index(minShapeInfo, minIdxs); +} - // transform storage indices to contain per-dim max indices, purpose - memory saving - // fill increment array as well - if(dimsToExclude == nullptr) { // means dimsToExclude == {0,1,2,...,diff-1} - for(minI = rankMin - 1, maxI = rankMax-1; maxI >= diff; --maxI, --minI) { - increment[maxI] = (maxShapeInfo[maxI+1] == minShapeInfo[minI+1]) ? 0 : minShapeInfo[minI+1]; - indices[maxI] = indices[minI]; - } - for(maxI = 0; maxI < diff; ++maxI) { - increment[maxI] = 1; - indices[maxI] = 0; - } - } - else { - for(N = diff-1, minI = rankMin - 1, maxI = rankMax - 1; maxI >= 0; --maxI) { - if(N >= 0 && dimsToExclude[N] == maxI) { - increment[maxI] = 1; - indices[maxI] = 0; - --N; - } - else { - increment[maxI] = (maxShapeInfo[maxI+1] == minShapeInfo[minI+1]) ? 0 : minShapeInfo[minI+1]; - indices[maxI] = indices[minI--]; - } - } - } +////////////////////////////////////////////////////////////////////// +INLINEDEF _CUDA_HD Nd4jLong subArrayOffset(const Nd4jLong maxIdx, + const Nd4jLong *maxShapeInfo, + const Nd4jLong *minShapeInfo, + const int *dimsToExclude, + const int dimsLen) { + int maxIdxs[MAX_RANK]; + shape::index2coords(const_cast(maxIdx), maxShapeInfo, maxIdxs); + + int minIdxs[MAX_RANK]; + maxIndToMinInd(maxIdxs, minIdxs, maxShapeInfo, minShapeInfo, dimsToExclude, + dimsLen); + + return getOffset(minShapeInfo, minIdxs); +} - maxI = rankMax-1; - N = 0; - int step; +////////////////////////////////////////////////////////////////////// +INLINEDEF _CUDA_HD int outerArrayOffsets( + Nd4jLong *maxOffsets, const Nd4jLong minIdx, const Nd4jLong *maxShapeInfo, + const Nd4jLong *minShapeInfo, int *memBuff, const int *dimsToExclude) { + const auto rankMin = shape::rank(minShapeInfo); + const auto rankMax = shape::rank(maxShapeInfo); + + // if(rankMin >= rankMax) + // throw std::runtime_error("shape::subArrayIndex method: rank of min + // array should be smaller then rank of max array!"); + + const auto diff = + rankMax - rankMin; // the size of dimsToExclude is equal to diff + + int *indices = memBuff; + int *increment = memBuff + rankMax; + + int N, minI, maxI; + + // calculate min per-dim-indices which corresponds to absolute minIdx index + shape::index2coords(minIdx, minShapeInfo, indices); + + // transform storage indices to contain per-dim max indices, purpose - memory + // saving fill increment array as well + if (dimsToExclude == nullptr) { // means dimsToExclude == {0,1,2,...,diff-1} + for (minI = rankMin - 1, maxI = rankMax - 1; maxI >= diff; --maxI, --minI) { + increment[maxI] = (maxShapeInfo[maxI + 1] == minShapeInfo[minI + 1]) + ? 0 + : minShapeInfo[minI + 1]; + indices[maxI] = indices[minI]; + } + for (maxI = 0; maxI < diff; ++maxI) { + increment[maxI] = 1; + indices[maxI] = 0; + } + } else { + for (N = diff - 1, minI = rankMin - 1, maxI = rankMax - 1; maxI >= 0; + --maxI) { + if (N >= 0 && dimsToExclude[N] == maxI) { + increment[maxI] = 1; + indices[maxI] = 0; + --N; + } else { + increment[maxI] = (maxShapeInfo[maxI + 1] == minShapeInfo[minI + 1]) + ? 0 + : minShapeInfo[minI + 1]; + indices[maxI] = indices[minI--]; + } + } + } + + maxI = rankMax - 1; + N = 0; + int step; + maxOffsets[N++] = shape::getOffset(maxShapeInfo, indices); + + // nested loops - producing of absolute indices for max array + while (maxI >= 0) { + if (increment[maxI] != 0) { + indices[maxI] += increment[maxI]; + if (indices[maxI] >= maxShapeInfo[maxI + 1]) { + indices[maxI] %= + increment[maxI]; // restore initial value of indices[maxI] + step = -1; + } else { maxOffsets[N++] = shape::getOffset(maxShapeInfo, indices); - - // nested loops - producing of absolute indices for max array - while(maxI >= 0) { - - if(increment[maxI] != 0) { - - indices[maxI] += increment[maxI]; - if(indices[maxI] >= maxShapeInfo[maxI+1]) { - indices[maxI] %= increment[maxI]; // restore initial value of indices[maxI] - step = -1; - } - else { - maxOffsets[N++] = shape::getOffset(maxShapeInfo, indices); - step = rankMax - 1 - maxI; - } - } - else if(maxI == rankMax - 1) - step = -1; - - maxI += step; - } - return N; - } - - ////////////////////////////////////////////////////////////////////// - INLINEDEF _CUDA_HD int outerArrayIndexes(int* maxIdxs, const Nd4jLong minIdx, const Nd4jLong* maxShapeInfo, const Nd4jLong* minShapeInfo, const int* dimsToExclude) { - - const auto rankMin = shape::rank(minShapeInfo); - const auto rankMax = shape::rank(maxShapeInfo); - - // if(rankMin >= rankMax) - // throw std::runtime_error("shape::subArrayIndex method: rank of min array should be smaller then rank of max array!"); - // if(rankMax > MAX_RANK/2) - // throw std::runtime_error("shape::subArrayIndex method: rank of max array should be <= MAX_RANK/2 !"); - - const auto diff = rankMax - rankMin; // the size of dimsToExclude is equal to diff - - int indices[MAX_RANK], increment[MAX_RANK]; - - int N, minI, maxI; - - // calculate min per-dim-indices which corresponds to absolute minIdx index - shape::index2coords(minIdx, minShapeInfo, indices); - - // transform storage indices to contain per-dim max indices, purpose - memory saving - // fill increment array as well - if(dimsToExclude == nullptr) { // means dimsToExclude == {0,1,2,...,diff-1} - for(minI = rankMin - 1, maxI = rankMax-1; maxI >= diff; --maxI, --minI) { - increment[maxI] = (maxShapeInfo[maxI+1] == minShapeInfo[minI+1]) ? 0 : minShapeInfo[minI+1]; - indices[maxI] = indices[minI]; - } - for(maxI = 0; maxI < diff; ++maxI) { - increment[maxI] = 1; - indices[maxI] = 0; - } - } - else { - for(N = diff-1, minI = rankMin - 1, maxI = rankMax - 1; maxI >= 0; --maxI) { - if(N >= 0 && dimsToExclude[N] == maxI) { - increment[maxI] = 1; - indices[maxI] = 0; - --N; - } - else { - increment[maxI] = (maxShapeInfo[maxI+1] == minShapeInfo[minI+1]) ? 0 : minShapeInfo[minI+1]; - indices[maxI] = indices[minI--]; - } - } - } - - maxI = rankMax-1; - N = 0; - int step; - maxIdxs[N++] = shape::coords2index(maxShapeInfo, indices); - - // nested loops - producing of absolute indices for max array - while(maxI >= 0) { - - if(increment[maxI] != 0) { - - indices[maxI] += increment[maxI]; - if(indices[maxI] >= maxShapeInfo[maxI+1]) { - indices[maxI] %= increment[maxI]; // restore initial value of indices[maxI] - step = -1; - } - else { - maxIdxs[N++] = shape::coords2index(maxShapeInfo, indices); - step = rankMax - 1 - maxI; - } - } - else if(maxI == rankMax - 1) - step = -1; - - maxI += step; - } - return N; - } - - INLINEDEF _CUDA_HD void shapeOldScalar(sd::DataType dataType, Nd4jLong* const buffer, const char order) { - - buffer[0] = 2; - buffer[1] = 1; - buffer[2] = 1; - buffer[3] = 1; - buffer[4] = 1; - buffer[6] = 1; - buffer[7] = (int)order; - - sd::ArrayOptions::setDataType(buffer, dataType); - } - - template - INLINEDEF _CUDA_H void convertT(T1 *from, T2 *to, Nd4jLong length) { - for (Nd4jLong e = 0; e < length; e++) - to[e] = (T2) from[e]; - }; + step = rankMax - 1 - maxI; + } + } else if (maxI == rankMax - 1) + step = -1; + + maxI += step; + } + return N; +} ////////////////////////////////////////////////////////////////////// -INLINEDEF void calcOffsets(const Nd4jLong* shapeInfo, Nd4jLong* offsets, const char order) { - - // firstly consider simple case when ews > 0 - const Nd4jLong ews = shape::elementWiseStride(shapeInfo); - - if(ews > 0) { - - // set offset for first sub-array, it is equal to zero always - offsets[0] = 0; - - Nd4jLong e = 0; - if(order != shape::order(shapeInfo)) - for(int i = 1; i <= shape::rank(shapeInfo); ++i) - if(shapeInfo[i] != 1) - ++e; //check whether input is CommonVector - - if(order == shape::order(shapeInfo) || e == 1) { // e==1 means common vector - e = 1; - Nd4jLong len = shape::length(shapeInfo); - while(e < len) { - offsets[e] = offsets[e - 1] + ews; - e++; - } - return; - } - } +INLINEDEF _CUDA_HD int outerArrayIndexes(int *maxIdxs, const Nd4jLong minIdx, + const Nd4jLong *maxShapeInfo, + const Nd4jLong *minShapeInfo, + const int *dimsToExclude) { + const auto rankMin = shape::rank(minShapeInfo); + const auto rankMax = shape::rank(maxShapeInfo); + + // if(rankMin >= rankMax) + // throw std::runtime_error("shape::subArrayIndex method: rank of min + // array should be smaller then rank of max array!"); + // if(rankMax > MAX_RANK/2) + // throw std::runtime_error("shape::subArrayIndex method: rank of max + // array should be <= MAX_RANK/2 !"); + + const auto diff = + rankMax - rankMin; // the size of dimsToExclude is equal to diff + + int indices[MAX_RANK], increment[MAX_RANK]; + + int N, minI, maxI; + + // calculate min per-dim-indices which corresponds to absolute minIdx index + shape::index2coords(minIdx, minShapeInfo, indices); + + // transform storage indices to contain per-dim max indices, purpose - memory + // saving fill increment array as well + if (dimsToExclude == nullptr) { // means dimsToExclude == {0,1,2,...,diff-1} + for (minI = rankMin - 1, maxI = rankMax - 1; maxI >= diff; --maxI, --minI) { + increment[maxI] = (maxShapeInfo[maxI + 1] == minShapeInfo[minI + 1]) + ? 0 + : minShapeInfo[minI + 1]; + indices[maxI] = indices[minI]; + } + for (maxI = 0; maxI < diff; ++maxI) { + increment[maxI] = 1; + indices[maxI] = 0; + } + } else { + for (N = diff - 1, minI = rankMin - 1, maxI = rankMax - 1; maxI >= 0; + --maxI) { + if (N >= 0 && dimsToExclude[N] == maxI) { + increment[maxI] = 1; + indices[maxI] = 0; + --N; + } else { + increment[maxI] = (maxShapeInfo[maxI + 1] == minShapeInfo[minI + 1]) + ? 0 + : minShapeInfo[minI + 1]; + indices[maxI] = indices[minI--]; + } + } + } + + maxI = rankMax - 1; + N = 0; + int step; + maxIdxs[N++] = shape::coords2index(maxShapeInfo, indices); + + // nested loops - producing of absolute indices for max array + while (maxI >= 0) { + if (increment[maxI] != 0) { + indices[maxI] += increment[maxI]; + if (indices[maxI] >= maxShapeInfo[maxI + 1]) { + indices[maxI] %= + increment[maxI]; // restore initial value of indices[maxI] + step = -1; + } else { + maxIdxs[N++] = shape::coords2index(maxShapeInfo, indices); + step = rankMax - 1 - maxI; + } + } else if (maxI == rankMax - 1) + step = -1; + + maxI += step; + } + return N; +} - shape::calcOffsets(shape::rank(shapeInfo), shape::shapeOf(const_cast(shapeInfo)), shape::stride(const_cast(shapeInfo)), offsets, order); +INLINEDEF _CUDA_HD void shapeOldScalar(sd::DataType dataType, + Nd4jLong *const buffer, + const char order) { + buffer[0] = 2; + buffer[1] = 1; + buffer[2] = 1; + buffer[3] = 1; + buffer[4] = 1; + buffer[6] = 1; + buffer[7] = (int)order; + + sd::ArrayOptions::setDataType(buffer, dataType); } +template +INLINEDEF _CUDA_H void convertT(T1 *from, T2 *to, Nd4jLong length) { + for (Nd4jLong e = 0; e < length; e++) to[e] = (T2)from[e]; +}; + ////////////////////////////////////////////////////////////////////// -INLINEDEF void calcOffsets(const int rank, const Nd4jLong* shape, const Nd4jLong* strides, Nd4jLong* offsets, const char order) { - - // if(false) { // tests showed that this code did calculation notably slower even for big N - // Nd4jLong indexes[MAX_RANK]; - // PRAGMA_OMP_PARALLEL_FOR_ARGS(private(indexes)) - // for (Nd4jLong i = 0; i < N; ++i) { - // shape::index2coords(rank, shape, i, indexes); - // subArrOffsets[i] = 0; - // for (int j = 0; j < rank; ++j) - // if(shape[j] != 1) - // subArrOffsets[i] += indexes[j] * strides[j]; - // } - // return; - // } +INLINEDEF void calcOffsets(const Nd4jLong *shapeInfo, Nd4jLong *offsets, + const char order) { + // firstly consider simple case when ews > 0 + const Nd4jLong ews = shape::elementWiseStride(shapeInfo); + if (ews > 0) { // set offset for first sub-array, it is equal to zero always offsets[0] = 0; - Nd4jLong * idx = new Nd4jLong[rank]; - Nd4jLong* offsetPerDim = new Nd4jLong[rank]; - memset(idx, 0, sizeof(Nd4jLong) * rank); - - PRAGMA_OMP_SIMD - for (int k = 0; k < rank; ++k) - offsetPerDim[k] = (shape[k] - 1) * strides[k]; - - Nd4jLong init = 0, i = 1; - // nested loops - calculation of sub-array offsets - if(order == 'c') { - - Nd4jLong rankMinusOne = rank - 1, j = rankMinusOne; - - while(j >= 0) { - - if(shape[j] == 1) { --j; continue; } // ignore dimensions equal to unity - - if(j == rankMinusOne) { // last dimension - for(int l = 1; l < shape[j]; ++l) { - offsets[i] = offsets[i - 1] + strides[j]; - i++; - } - --j; - } - else if(idx[j] < shape[j] - 1) { - init += strides[j]; - offsets[i++] = init; - ++idx[j]; - j = rankMinusOne; - } - else { - init -= offsetPerDim[j]; - idx[j--] = 0; - } - } - } - else { + Nd4jLong e = 0; + if (order != shape::order(shapeInfo)) + for (int i = 1; i <= shape::rank(shapeInfo); ++i) + if (shapeInfo[i] != 1) ++e; // check whether input is CommonVector + + if (order == shape::order(shapeInfo) || + e == 1) { // e==1 means common vector + e = 1; + Nd4jLong len = shape::length(shapeInfo); + while (e < len) { + offsets[e] = offsets[e - 1] + ews; + e++; + } + return; + } + } + + shape::calcOffsets( + shape::rank(shapeInfo), shape::shapeOf(const_cast(shapeInfo)), + shape::stride(const_cast(shapeInfo)), offsets, order); +} - Nd4jLong j = 0; - - while(j < rank) { - - if(shape[j] == 1) { ++j; continue; } // ignore dimensions equal to unity - - if(j == 0) { // last dimension - for(int l = 1; l < shape[j]; ++l) { - offsets[i] = offsets[i - 1] + strides[j]; - i++; - } - ++j; - } - else if(idx[j] < shape[j] - 1) { - init += strides[j]; - offsets[i++] = init; - ++idx[j]; - j = 0; - } - else { - init -= offsetPerDim[j]; - idx[j++] = 0; - } +////////////////////////////////////////////////////////////////////// +INLINEDEF void calcOffsets(const int rank, const Nd4jLong *shape, + const Nd4jLong *strides, Nd4jLong *offsets, + const char order) { + // if(false) { // tests showed that this code did + // calculation notably slower even for big N + // Nd4jLong indexes[MAX_RANK]; + // PRAGMA_OMP_PARALLEL_FOR_ARGS(private(indexes)) + // for (Nd4jLong i = 0; i < N; ++i) { + // shape::index2coords(rank, shape, i, indexes); + // subArrOffsets[i] = 0; + // for (int j = 0; j < rank; ++j) + // if(shape[j] != 1) + // subArrOffsets[i] += indexes[j] * strides[j]; + // } + // return; + // } + + // set offset for first sub-array, it is equal to zero always + offsets[0] = 0; + + Nd4jLong *idx = new Nd4jLong[rank]; + Nd4jLong *offsetPerDim = new Nd4jLong[rank]; + memset(idx, 0, sizeof(Nd4jLong) * rank); + + PRAGMA_OMP_SIMD + for (int k = 0; k < rank; ++k) offsetPerDim[k] = (shape[k] - 1) * strides[k]; + + Nd4jLong init = 0, i = 1; + // nested loops - calculation of sub-array offsets + if (order == 'c') { + Nd4jLong rankMinusOne = rank - 1, j = rankMinusOne; + + while (j >= 0) { + if (shape[j] == 1) { + --j; + continue; + } // ignore dimensions equal to unity + + if (j == rankMinusOne) { // last dimension + for (int l = 1; l < shape[j]; ++l) { + offsets[i] = offsets[i - 1] + strides[j]; + i++; } - } - - delete []idx; - delete []offsetPerDim; + --j; + } else if (idx[j] < shape[j] - 1) { + init += strides[j]; + offsets[i++] = init; + ++idx[j]; + j = rankMinusOne; + } else { + init -= offsetPerDim[j]; + idx[j--] = 0; + } + } + } else { + Nd4jLong j = 0; + + while (j < rank) { + if (shape[j] == 1) { + ++j; + continue; + } // ignore dimensions equal to unity + + if (j == 0) { // last dimension + for (int l = 1; l < shape[j]; ++l) { + offsets[i] = offsets[i - 1] + strides[j]; + i++; + } + ++j; + } else if (idx[j] < shape[j] - 1) { + init += strides[j]; + offsets[i++] = init; + ++idx[j]; + j = 0; + } else { + init -= offsetPerDim[j]; + idx[j++] = 0; + } + } + } + + delete[] idx; + delete[] offsetPerDim; } ////////////////////////////////////////////////////////////////////// -INLINEDEF void _CUDA_HD checkStridesEwsAndOrder(Nd4jLong* shapeInfo) { - - // FIXME - indeed we don't need to allocate so large memory amount (2*MAX_RANK), sufficient amount is (2*oldNumOfNonUnities + 2*newNumOfNonUnities) - Nd4jLong tempBuffer[2*MAX_RANK]; - Nd4jLong *shape = tempBuffer, *strides; - - // exclude unities from shapeInfo - const int numOfNonUnities = shape::excludeUnitiesFromShapeInfo(shapeInfo, shape, strides); - - shape::checkStridesEwsAndOrder(shapeInfo, shape::order(shapeInfo), numOfNonUnities, shape, strides); +INLINEDEF void _CUDA_HD checkStridesEwsAndOrder(Nd4jLong *shapeInfo) { + // FIXME - indeed we don't need to allocate so large memory amount + // (2*MAX_RANK), sufficient amount is (2*oldNumOfNonUnities + + // 2*newNumOfNonUnities) + Nd4jLong tempBuffer[2 * MAX_RANK]; + Nd4jLong *shape = tempBuffer, *strides; + + // exclude unities from shapeInfo + const int numOfNonUnities = + shape::excludeUnitiesFromShapeInfo(shapeInfo, shape, strides); + + shape::checkStridesEwsAndOrder(shapeInfo, shape::order(shapeInfo), + numOfNonUnities, shape, strides); } ////////////////////////////////////////////////////////////////////// -INLINEDEF void _CUDA_HD checkStridesEwsAndOrder(Nd4jLong* shapeInfo, const char proposedOrder, const int numOfNonUnities, const Nd4jLong* shapeNoUnities, const Nd4jLong* stridesNoUnities) { +INLINEDEF void _CUDA_HD checkStridesEwsAndOrder( + Nd4jLong *shapeInfo, const char proposedOrder, const int numOfNonUnities, + const Nd4jLong *shapeNoUnities, const Nd4jLong *stridesNoUnities) { + const int rank = shape::rank(shapeInfo); - const int rank = shape::rank(shapeInfo); - - if(shape::length(shapeInfo) == 1) { - *shape::ews(shapeInfo) = 1; - shapeInfo[rank * 2 + 3] = (int)proposedOrder; - return; - } + if (shape::length(shapeInfo) == 1) { + *shape::ews(shapeInfo) = 1; + shapeInfo[rank * 2 + 3] = (int)proposedOrder; + return; + } - if(numOfNonUnities == 1) { // case of common vector - *shape::ews(shapeInfo) = *stridesNoUnities; - shapeInfo[rank * 2 + 3] = (int)proposedOrder; - return; - } + if (numOfNonUnities == 1) { // case of common vector + *shape::ews(shapeInfo) = *stridesNoUnities; + shapeInfo[rank * 2 + 3] = (int)proposedOrder; + return; + } - bool contiguous = true; + bool contiguous = true; - //*** check whether strides are in c contiguous order ***// - for (uint i = 0; i < numOfNonUnities - 1; ++i) { - if(stridesNoUnities[i] != shapeNoUnities[i + 1] * stridesNoUnities[i + 1]) { - contiguous = false; - break; - } + //*** check whether strides are in c contiguous order ***// + for (uint i = 0; i < numOfNonUnities - 1; ++i) { + if (stridesNoUnities[i] != + shapeNoUnities[i + 1] * stridesNoUnities[i + 1]) { + contiguous = false; + break; } + } - if(contiguous) { - - *shape::ews(shapeInfo) = stridesNoUnities[numOfNonUnities - 1]; - shapeInfo[rank * 2 + 3] = 99; - return; - } + if (contiguous) { + *shape::ews(shapeInfo) = stridesNoUnities[numOfNonUnities - 1]; + shapeInfo[rank * 2 + 3] = 99; + return; + } - contiguous = true; + contiguous = true; - //*** check whether strides are in f contiguous order ***// - for (uint i = 1; i < numOfNonUnities; ++i) { - if(stridesNoUnities[i] != shapeNoUnities[i - 1] * stridesNoUnities[i - 1]) { - contiguous = false; - break; - } + //*** check whether strides are in f contiguous order ***// + for (uint i = 1; i < numOfNonUnities; ++i) { + if (stridesNoUnities[i] != + shapeNoUnities[i - 1] * stridesNoUnities[i - 1]) { + contiguous = false; + break; } + } - if(contiguous) { - - *shape::ews(shapeInfo) = stridesNoUnities[0]; - shapeInfo[rank * 2 + 3] = 102; - return; - } + if (contiguous) { + *shape::ews(shapeInfo) = stridesNoUnities[0]; + shapeInfo[rank * 2 + 3] = 102; + return; + } - *shape::ews(shapeInfo) = 0; - shapeInfo[rank * 2 + 3] = (int)proposedOrder; + *shape::ews(shapeInfo) = 0; + shapeInfo[rank * 2 + 3] = (int)proposedOrder; } ////////////////////////////////////////////////////////////////////// -INLINEDEF _CUDA_HD void calcSubArrsShapeInfoAndOffsets(const Nd4jLong* wholeShapeInfo, const Nd4jLong numOfSubArrs, const int dimsSize, const int* dimsToExclude, Nd4jLong* subArrShapeInfo, Nd4jLong* subArrOffsets, bool keepUnitiesInShape) { - - const int rank = shape::rank(wholeShapeInfo); - - if(dimsSize == rank || dimsSize == 0) { // means there is one sub-array and it coincides with whole array, return copy of wholeShapeInfo and one zero offset in this case - memcpy(subArrShapeInfo, wholeShapeInfo, shape::shapeInfoLength(rank) * sizeof(Nd4jLong)); - *subArrOffsets = 0; - return; - } - - const int subArrRank = keepUnitiesInShape ? rank : rank - dimsSize; - - subArrShapeInfo[0] = subArrRank; // rank - sd::ArrayOptions::copyDataType(subArrShapeInfo, wholeShapeInfo); // type - subArrShapeInfo[2 * subArrRank + 3] = shape::order(wholeShapeInfo); // order - - Nd4jLong* shape = new Nd4jLong[dimsSize]; - Nd4jLong* strides = new Nd4jLong[dimsSize]; - - for(int k = subArrRank - 1, j = dimsSize - 1, i = rank - 1; i >= 0; --i) { - - if(j >= 0 && i == dimsToExclude[j]) { - - strides[j] = shape::stride(wholeShapeInfo)[i]; - shape[j--] = shape::shapeOf(wholeShapeInfo)[i]; - - if(keepUnitiesInShape) { - shape::shapeOf(subArrShapeInfo)[k] = 1; - shape::stride(subArrShapeInfo)[k--] = shape::stride(wholeShapeInfo)[i]; - } - } - else { - shape::shapeOf(subArrShapeInfo)[k] = shape::shapeOf(wholeShapeInfo)[i]; - shape::stride(subArrShapeInfo)[k--] = shape::stride(wholeShapeInfo)[i]; - } - - } - - // calculation of sub-array offsets (subArrOffsets) - shape::calcOffsets(dimsSize, shape, strides, subArrOffsets); - - // evaluate ews - shape::checkStridesEwsAndOrder(subArrShapeInfo); - - delete []strides; - delete []shape; +INLINEDEF _CUDA_HD void calcSubArrsShapeInfoAndOffsets( + const Nd4jLong *wholeShapeInfo, const Nd4jLong numOfSubArrs, + const int dimsSize, const int *dimsToExclude, Nd4jLong *subArrShapeInfo, + Nd4jLong *subArrOffsets, bool keepUnitiesInShape) { + const int rank = shape::rank(wholeShapeInfo); + + if (dimsSize == rank || + dimsSize == 0) { // means there is one sub-array and it coincides with + // whole array, return copy of wholeShapeInfo and one + // zero offset in this case + memcpy(subArrShapeInfo, wholeShapeInfo, + shape::shapeInfoLength(rank) * sizeof(Nd4jLong)); + *subArrOffsets = 0; + return; + } + + const int subArrRank = keepUnitiesInShape ? rank : rank - dimsSize; + + subArrShapeInfo[0] = subArrRank; // rank + sd::ArrayOptions::copyDataType(subArrShapeInfo, wholeShapeInfo); // type + subArrShapeInfo[2 * subArrRank + 3] = shape::order(wholeShapeInfo); // order + + Nd4jLong *shape = new Nd4jLong[dimsSize]; + Nd4jLong *strides = new Nd4jLong[dimsSize]; + + for (int k = subArrRank - 1, j = dimsSize - 1, i = rank - 1; i >= 0; --i) { + if (j >= 0 && i == dimsToExclude[j]) { + strides[j] = shape::stride(wholeShapeInfo)[i]; + shape[j--] = shape::shapeOf(wholeShapeInfo)[i]; + + if (keepUnitiesInShape) { + shape::shapeOf(subArrShapeInfo)[k] = 1; + shape::stride(subArrShapeInfo)[k--] = shape::stride(wholeShapeInfo)[i]; + } + } else { + shape::shapeOf(subArrShapeInfo)[k] = shape::shapeOf(wholeShapeInfo)[i]; + shape::stride(subArrShapeInfo)[k--] = shape::stride(wholeShapeInfo)[i]; + } + } + + // calculation of sub-array offsets (subArrOffsets) + shape::calcOffsets(dimsSize, shape, strides, subArrOffsets); + + // evaluate ews + shape::checkStridesEwsAndOrder(subArrShapeInfo); + + delete[] strides; + delete[] shape; } ////////////////////////////////////////////////////////////////////// -INLINEDEF void calcSubArrShapeInfoAndOffset(const Nd4jLong* idx, const Nd4jLong* maxShapeInfo, Nd4jLong* minShapeInfo, Nd4jLong& minOffset, const bool keepUnitiesInShape, const bool isStrided, const int numOfUntiesInMinShape) { - - const uint maxRank = shape::rank(maxShapeInfo); - minOffset = 0; - uint first, last, stride, n(isStrided ? 3 : 2); - - minShapeInfo[0] = keepUnitiesInShape ? maxRank : maxRank - numOfUntiesInMinShape; - - for (uint step = 0, j = 0, i = 0; i < maxRank; ++i, step += n) { - - if (idx[step] == idx[step + 1]) { // means whole dimension - shape::shapeOf(minShapeInfo)[j] = shape::shapeOf(maxShapeInfo)[i]; - shape::stride(minShapeInfo)[j++] = shape::stride(maxShapeInfo)[i]; - } - else { - - first = idx[step] >= 0 ? idx[step] : idx[step] + shape::sizeAt(maxShapeInfo, i) + 1; - last = idx[step + 1] >= 0 ? idx[step + 1] : idx[step + 1] + shape::sizeAt(maxShapeInfo, i) + 1; - - if(last < first) - throw("shape::calcSubArrShapeInfoAndOffset: negative range in input indexes is found!"); - - if(isStrided) { - stride = idx[step + 2]; - last /*resulting sub-array axis*/ = (last - first + stride - 1) / stride; // ceil (last - first) / stride; - } - else { - stride = 1; - last /*resulting sub-array axis*/ = last - first; - } - - minOffset += first * shape::stride(maxShapeInfo)[i]; - - if(!keepUnitiesInShape && last == 1) - continue; - - shape::shapeOf(minShapeInfo)[j] = last; - shape::stride(minShapeInfo)[j++] = last == 1 ? shape::stride(maxShapeInfo)[i] : shape::stride(maxShapeInfo)[i] * stride; - } - } - - minShapeInfo[2 * shape::rank(minShapeInfo) + 3] = shape::order(maxShapeInfo); // order - sd::ArrayOptions::copyDataType(minShapeInfo, maxShapeInfo); // type - - shape::checkStridesEwsAndOrder(minShapeInfo); +INLINEDEF void calcSubArrShapeInfoAndOffset( + const Nd4jLong *idx, const Nd4jLong *maxShapeInfo, Nd4jLong *minShapeInfo, + Nd4jLong &minOffset, const bool keepUnitiesInShape, const bool isStrided, + const int numOfUntiesInMinShape) { + const uint maxRank = shape::rank(maxShapeInfo); + minOffset = 0; + uint first, last, stride, n(isStrided ? 3 : 2); + + minShapeInfo[0] = + keepUnitiesInShape ? maxRank : maxRank - numOfUntiesInMinShape; + + for (uint step = 0, j = 0, i = 0; i < maxRank; ++i, step += n) { + if (idx[step] == idx[step + 1]) { // means whole dimension + shape::shapeOf(minShapeInfo)[j] = shape::shapeOf(maxShapeInfo)[i]; + shape::stride(minShapeInfo)[j++] = shape::stride(maxShapeInfo)[i]; + } else { + first = idx[step] >= 0 ? idx[step] + : idx[step] + shape::sizeAt(maxShapeInfo, i) + 1; + last = idx[step + 1] >= 0 + ? idx[step + 1] + : idx[step + 1] + shape::sizeAt(maxShapeInfo, i) + 1; + + if (last < first) + throw( + "shape::calcSubArrShapeInfoAndOffset: negative range in input " + "indexes is found!"); + + if (isStrided) { + stride = idx[step + 2]; + last /*resulting sub-array axis*/ = + (last - first + stride - 1) / + stride; // ceil (last - first) / stride; + } else { + stride = 1; + last /*resulting sub-array axis*/ = last - first; + } + + minOffset += first * shape::stride(maxShapeInfo)[i]; + + if (!keepUnitiesInShape && last == 1) continue; + + shape::shapeOf(minShapeInfo)[j] = last; + shape::stride(minShapeInfo)[j++] = + last == 1 ? shape::stride(maxShapeInfo)[i] + : shape::stride(maxShapeInfo)[i] * stride; + } + } + + minShapeInfo[2 * shape::rank(minShapeInfo) + 3] = + shape::order(maxShapeInfo); // order + sd::ArrayOptions::copyDataType(minShapeInfo, maxShapeInfo); // type + + shape::checkStridesEwsAndOrder(minShapeInfo); } ////////////////////////////////////////////////////////////////////// -INLINEDEF void _CUDA_HD index2coords(Nd4jLong index, const Nd4jLong *shapeInfo, Nd4jLong *coords) { - - for(uint i = shapeInfo[0]; i > 1; --i) { - coords[i - 1] = index % shapeInfo[i]; - index /= shapeInfo[i]; - } - coords[0] = index; // last iteration +INLINEDEF void _CUDA_HD index2coords(Nd4jLong index, const Nd4jLong *shapeInfo, + Nd4jLong *coords) { + for (uint i = shapeInfo[0]; i > 1; --i) { + coords[i - 1] = index % shapeInfo[i]; + index /= shapeInfo[i]; + } + coords[0] = index; // last iteration } ////////////////////////////////////////////////////////////////////// -INLINEDEF void _CUDA_HD index2coords(Nd4jLong index, const Nd4jLong *shapeInfo, int *coords) { - - for(uint i = shapeInfo[0]; i > 1; --i) { - coords[i - 1] = static_cast(index) % static_cast(shapeInfo[i]); - index /= static_cast(shapeInfo[i]); - } - coords[0] = static_cast(index); // last iteration +INLINEDEF void _CUDA_HD index2coords(Nd4jLong index, const Nd4jLong *shapeInfo, + int *coords) { + for (uint i = shapeInfo[0]; i > 1; --i) { + coords[i - 1] = static_cast(index) % static_cast(shapeInfo[i]); + index /= static_cast(shapeInfo[i]); + } + coords[0] = static_cast(index); // last iteration } ////////////////////////////////////////////////////////////////////// -INLINEDEF void _CUDA_HD index2coords(Nd4jLong index, const Nd4jLong *shapeInfo, uint *coords) { - - for(uint i = shapeInfo[0]; i > 1; --i) { - coords[i - 1] = static_cast(index) % static_cast(shapeInfo[i]); - index /= static_cast(shapeInfo[i]); - } - coords[0] = static_cast(index); // last iteration +INLINEDEF void _CUDA_HD index2coords(Nd4jLong index, const Nd4jLong *shapeInfo, + uint *coords) { + for (uint i = shapeInfo[0]; i > 1; --i) { + coords[i - 1] = static_cast(index) % static_cast(shapeInfo[i]); + index /= static_cast(shapeInfo[i]); + } + coords[0] = static_cast(index); // last iteration } ////////////////////////////////////////////////////////////////////// -INLINEDEF void _CUDA_HD index2coords(Nd4jLong index, const int rank, const Nd4jLong *shape, Nd4jLong *coords) { - - for(uint i = rank - 1; i > 0; --i) { - coords[i] = index % shape[i]; - index /= shape[i]; - } - coords[0] = index; // last iteration +INLINEDEF void _CUDA_HD index2coords(Nd4jLong index, const int rank, + const Nd4jLong *shape, Nd4jLong *coords) { + for (uint i = rank - 1; i > 0; --i) { + coords[i] = index % shape[i]; + index /= shape[i]; + } + coords[0] = index; // last iteration } ////////////////////////////////////////////////////////////////////// -INLINEDEF void _CUDA_HD index2coords(Nd4jLong index, const int rank, const Nd4jLong *shape, int *coords) { - - for(uint i = rank - 1; i > 0; --i) { - coords[i] = index % shape[i]; - index /= shape[i]; - } - coords[0] = index; // last iteration +INLINEDEF void _CUDA_HD index2coords(Nd4jLong index, const int rank, + const Nd4jLong *shape, int *coords) { + for (uint i = rank - 1; i > 0; --i) { + coords[i] = index % shape[i]; + index /= shape[i]; + } + coords[0] = index; // last iteration } ////////////////////////////////////////////////////////////////////// -INLINEDEF _CUDA_HD void index2coords(Nd4jLong index, const Nd4jLong *shapeInfo, int *coords, const int dimsSize, const int* tadDims) { - - for(uint i = dimsSize - 1; i > 0; --i) { - coords[tadDims[i]] = index % shapeInfo[1 + tadDims[i]]; - index /= shapeInfo[1 + tadDims[i]]; - } - coords[tadDims[0]] = index; // last iteration +INLINEDEF _CUDA_HD void index2coords(Nd4jLong index, const Nd4jLong *shapeInfo, + int *coords, const int dimsSize, + const int *tadDims) { + for (uint i = dimsSize - 1; i > 0; --i) { + coords[tadDims[i]] = index % shapeInfo[1 + tadDims[i]]; + index /= shapeInfo[1 + tadDims[i]]; + } + coords[tadDims[0]] = index; // last iteration } ////////////////////////////////////////////////////////////////////// -INLINEDEF _CUDA_HD void index2coordsCPU(const Nd4jLong& startIndex, const Nd4jLong& index, const Nd4jLong *shapeInfo, Nd4jLong *coords) { - - if(startIndex == index) { - shape::index2coords(index, shapeInfo, coords); - } - else { - int axis = shapeInfo[0] - 1; - while(coords[axis] == shape::sizeAt(shapeInfo, axis) - 1) - coords[axis--] = 0; - ++coords[axis]; - } +INLINEDEF _CUDA_HD void index2coordsCPU(const Nd4jLong &startIndex, + const Nd4jLong &index, + const Nd4jLong *shapeInfo, + Nd4jLong *coords) { + if (startIndex == index) { + shape::index2coords(index, shapeInfo, coords); + } else { + int axis = shapeInfo[0] - 1; + while (coords[axis] == shape::sizeAt(shapeInfo, axis) - 1) + coords[axis--] = 0; + ++coords[axis]; + } } ////////////////////////////////////////////////////////////////////// -INLINEDEF _CUDA_HD void index2coordsCPU(const Nd4jLong& startIndex, const Nd4jLong& index, const Nd4jLong *shapeInfo, int *coords) { - - if(startIndex == index) { - shape::index2coords(index, shapeInfo, coords); - } - else { - int axis = shapeInfo[0] - 1; - while(coords[axis] == shape::sizeAt(shapeInfo, axis) - 1) - coords[axis--] = 0; - ++coords[axis]; - } +INLINEDEF _CUDA_HD void index2coordsCPU(const Nd4jLong &startIndex, + const Nd4jLong &index, + const Nd4jLong *shapeInfo, + int *coords) { + if (startIndex == index) { + shape::index2coords(index, shapeInfo, coords); + } else { + int axis = shapeInfo[0] - 1; + while (coords[axis] == shape::sizeAt(shapeInfo, axis) - 1) + coords[axis--] = 0; + ++coords[axis]; + } } ////////////////////////////////////////////////////////////////////// -// INLINEDEF _CUDA_HD void calcOffsets(const Nd4jLong *xShapeInfo, Nd4jLong*& xOffsets, const Nd4jLong *yShapeInfo, Nd4jLong*& yOffsets, const Nd4jLong* zShapeInfo, Nd4jLong*& zOffsets, const char order) { +// INLINEDEF _CUDA_HD void calcOffsets(const Nd4jLong *xShapeInfo, Nd4jLong*& +// xOffsets, const Nd4jLong *yShapeInfo, Nd4jLong*& yOffsets, const Nd4jLong* +// zShapeInfo, Nd4jLong*& zOffsets, const char order) { // // we assume all array have same length // const Nd4jLong len = shape::length(xShapeInfo); @@ -4935,22 +5156,27 @@ INLINEDEF _CUDA_HD void index2coordsCPU(const Nd4jLong& startIndex, const Nd4jLo // const char yOrder = shape::order(yShapeInfo); // const char zOrder = shape::order(zShapeInfo); -// const bool shapesSame = shape::shapeEquals(xShapeInfo, yShapeInfo, zShapeInfo); +// const bool shapesSame = shape::shapeEquals(xShapeInfo, yShapeInfo, +// zShapeInfo); -// if (xEws == 1 && yEws == 1 && zEws == 1 && xOrder == yOrder && xOrder == zOrder && (xOrder == 'c' || shapesSame)) { +// if (xEws == 1 && yEws == 1 && zEws == 1 && xOrder == yOrder && xOrder == +// zOrder && (xOrder == 'c' || shapesSame)) { // xOffsets = yOffsets = zOffsets = nullptr; // } -// else if(xEws == 1 && yEws == 1 && xOrder == yOrder && (xOrder == 'c' || shape::shapeEquals(xShapeInfo, yShapeInfo))) { +// else if(xEws == 1 && yEws == 1 && xOrder == yOrder && (xOrder == 'c' || +// shape::shapeEquals(xShapeInfo, yShapeInfo))) { // xOffsets = yOffsets = nullptr; // zOffsets = new Nd4jLong[len]; // shape::calcOffsets(zShapeInfo, zOffsets, xOrder); // } -// else if(xEws == 1 && zEws == 1 && xOrder == zOrder && (xOrder == 'c' || shape::shapeEquals(xShapeInfo, zShapeInfo))) { +// else if(xEws == 1 && zEws == 1 && xOrder == zOrder && (xOrder == 'c' || +// shape::shapeEquals(xShapeInfo, zShapeInfo))) { // xOffsets = zOffsets = nullptr; // yOffsets = new Nd4jLong[len]; // shape::calcOffsets(yShapeInfo, yOffsets, xOrder); // } -// else if(yEws == 1 && zEws == 1 && yOrder == zOrder && (yOrder == 'c' || shape::shapeEquals(yShapeInfo, zShapeInfo))) { +// else if(yEws == 1 && zEws == 1 && yOrder == zOrder && (yOrder == 'c' || +// shape::shapeEquals(yShapeInfo, zShapeInfo))) { // yOffsets = zOffsets = nullptr; // xOffsets = new Nd4jLong[len]; // shape::calcOffsets(xShapeInfo, xOffsets, yOrder); @@ -5003,7 +5229,8 @@ INLINEDEF _CUDA_HD void index2coordsCPU(const Nd4jLong& startIndex, const Nd4jLo // } // } // } -// else if(shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo, zShapeInfo)) { +// else if(shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo, +// zShapeInfo)) { // xOffsets = new Nd4jLong[len]; // shape::calcOffsets(xShapeInfo, xOffsets); // yOffsets = zOffsets = xOffsets; @@ -5063,7 +5290,9 @@ INLINEDEF _CUDA_HD void index2coordsCPU(const Nd4jLong& startIndex, const Nd4jLo // } ////////////////////////////////////////////////////////////////////// -// INLINEDEF _CUDA_HD void calcOffsets(const Nd4jLong *xShapeInfo, Nd4jLong*& xOffsets, const Nd4jLong *yShapeInfo, Nd4jLong*& yOffsets, const char order) { +// INLINEDEF _CUDA_HD void calcOffsets(const Nd4jLong *xShapeInfo, Nd4jLong*& +// xOffsets, const Nd4jLong *yShapeInfo, Nd4jLong*& yOffsets, const char order) +// { // // we assume all array have same length // const Nd4jLong len = shape::length(xShapeInfo); @@ -5076,7 +5305,8 @@ INLINEDEF _CUDA_HD void index2coordsCPU(const Nd4jLong& startIndex, const Nd4jLo // const bool shapesSame = shape::shapeEquals(xShapeInfo, yShapeInfo); -// if (xEws == 1 && yEws == 1 && xOrder == yOrder && (xOrder == 'c' || shapesSame)) { +// if (xEws == 1 && yEws == 1 && xOrder == yOrder && (xOrder == 'c' || +// shapesSame)) { // xOffsets = yOffsets = nullptr; // } // else if(xEws == 1) { @@ -5112,52 +5342,56 @@ INLINEDEF _CUDA_HD void index2coordsCPU(const Nd4jLong& startIndex, const Nd4jLo // } ////////////////////////////////////////////////////////////////////// -INLINEDEF _CUDA_HD int excludeUnitiesFromShapeInfo(const Nd4jLong* inShapeInfo, Nd4jLong*& shapeNoUnities, Nd4jLong*& stridesNoUnities) { - - const int rank = shape::rank(inShapeInfo); - const int numOfNonUnities = shape::numOfNonUnitDims(rank, shape::shapeOf(inShapeInfo)); - - if(numOfNonUnities == rank) { // no unities in shape, no copy procedure - shapeNoUnities = const_cast(inShapeInfo) + 1; - stridesNoUnities = const_cast(inShapeInfo) + 1 + rank; - return numOfNonUnities; - } +INLINEDEF _CUDA_HD int excludeUnitiesFromShapeInfo( + const Nd4jLong *inShapeInfo, Nd4jLong *&shapeNoUnities, + Nd4jLong *&stridesNoUnities) { + const int rank = shape::rank(inShapeInfo); + const int numOfNonUnities = + shape::numOfNonUnitDims(rank, shape::shapeOf(inShapeInfo)); + + if (numOfNonUnities == rank) { // no unities in shape, no copy procedure + shapeNoUnities = const_cast(inShapeInfo) + 1; + stridesNoUnities = const_cast(inShapeInfo) + 1 + rank; + return numOfNonUnities; + } - for(uint j = 0, i = 0; i < rank; ++i) { - if(shape::shapeOf(inShapeInfo)[i] != 1) { - shapeNoUnities[j] = shape::shapeOf(inShapeInfo)[i]; - shapeNoUnities[numOfNonUnities + j++] = shape::stride(inShapeInfo)[i]; - } + for (uint j = 0, i = 0; i < rank; ++i) { + if (shape::shapeOf(inShapeInfo)[i] != 1) { + shapeNoUnities[j] = shape::shapeOf(inShapeInfo)[i]; + shapeNoUnities[numOfNonUnities + j++] = shape::stride(inShapeInfo)[i]; } + } - stridesNoUnities = shapeNoUnities + numOfNonUnities; + stridesNoUnities = shapeNoUnities + numOfNonUnities; - return numOfNonUnities; + return numOfNonUnities; } ////////////////////////////////////////////////////////////////////// -INLINEDEF _CUDA_HD void excludeUnitiesFromShapeInfo(const Nd4jLong* inShapeInfo, const int dimsSize, const int* dimsToExclude, Nd4jLong* outShapeInfo) { +INLINEDEF _CUDA_HD void excludeUnitiesFromShapeInfo(const Nd4jLong *inShapeInfo, + const int dimsSize, + const int *dimsToExclude, + Nd4jLong *outShapeInfo) { + outShapeInfo[0] = inShapeInfo[0] - dimsSize; - outShapeInfo[0] = inShapeInfo[0] - dimsSize; - - for(uint j = 0, k = 0, i = 0; i < inShapeInfo[0]; ++i) { - if(j < dimsSize && i == dimsToExclude[j]) { - ++j; - continue; - } - - shape::shapeOf(outShapeInfo)[k] = shape::shapeOf(inShapeInfo)[i]; - shape::stride(outShapeInfo)[k++] = shape::stride(inShapeInfo)[i]; + for (uint j = 0, k = 0, i = 0; i < inShapeInfo[0]; ++i) { + if (j < dimsSize && i == dimsToExclude[j]) { + ++j; + continue; } - sd::ArrayOptions::copyDataType(outShapeInfo, inShapeInfo); // type - *shape::ews(outShapeInfo) = shape::elementWiseStride(inShapeInfo); // ews - outShapeInfo[2 * outShapeInfo[0] + 3] = shape::order(inShapeInfo); // order -} + shape::shapeOf(outShapeInfo)[k] = shape::shapeOf(inShapeInfo)[i]; + shape::stride(outShapeInfo)[k++] = shape::stride(inShapeInfo)[i]; + } + sd::ArrayOptions::copyDataType(outShapeInfo, inShapeInfo); // type + *shape::ews(outShapeInfo) = shape::elementWiseStride(inShapeInfo); // ews + outShapeInfo[2 * outShapeInfo[0] + 3] = shape::order(inShapeInfo); // order +} ////////////////////////////////////////////////////////////////////// -// INLINEDEF _CUDA_HD Nd4jLong strideOverContigAxis(const int axis, const Nd4jLong* inShapeInfo) { +// INLINEDEF _CUDA_HD Nd4jLong strideOverContigAxis(const int axis, const +// Nd4jLong* inShapeInfo) { // Nd4jLong result = 9223372036854775807LL; @@ -5175,8 +5409,6 @@ INLINEDEF _CUDA_HD void excludeUnitiesFromShapeInfo(const Nd4jLong* inShapeInfo, // return result == 9223372036854775807LL ? 1 : result; // } - - -} +} // namespace shape #endif /* SHAPE_H_ */ diff --git a/libnd4j/include/helpers/svd.h b/libnd4j/include/helpers/svd.h index 58007bf37bf1..6729c0cb6db3 100644 --- a/libnd4j/include/helpers/svd.h +++ b/libnd4j/include/helpers/svd.h @@ -22,67 +22,77 @@ #define LIBND4J_SVD_H #include + #include "array/NDArray.h" -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { template class SVD { + public: + int _switchSize = 10; - public: - - int _switchSize = 10; + NDArray _m; + NDArray _s; + NDArray _u; + NDArray _v; - NDArray _m; - NDArray _s; - NDArray _u; - NDArray _v; - - int _diagSize; + int _diagSize; - bool _transp; - bool _calcU; - bool _calcV; - bool _fullUV; + bool _transp; + bool _calcU; + bool _calcV; + bool _fullUV; - /** - * constructor - */ - SVD(const NDArray& matrix, const int switchSize, const bool calcV, const bool calcU, const bool fullUV); + /** + * constructor + */ + SVD(const NDArray& matrix, const int switchSize, const bool calcV, + const bool calcU, const bool fullUV); - SVD(const NDArray& matrix, const int switchSize, const bool calcV, const bool calcU, const bool fullUV, const char t); + SVD(const NDArray& matrix, const int switchSize, const bool calcV, + const bool calcU, const bool fullUV, const char t); - void deflation1(int col1, int shift, int ind, int size); - - void deflation2(int col1U , int col1M, int row1W, int col1W, int ind1, int ind2, int size); - - void deflation(int col1, int col2, int ind, int row1W, int col1W, int shift); + void deflation1(int col1, int shift, int ind, int size); - // FIXME: proper T support required here - T secularEq(const T diff, const NDArray& col0, const NDArray& diag, const NDArray &permut, const NDArray& diagShifted, const T shift); + void deflation2(int col1U, int col1M, int row1W, int col1W, int ind1, + int ind2, int size); - void calcSingVals(const NDArray& col0, const NDArray& diag, const NDArray& permut, NDArray& singVals, NDArray& shifts, NDArray& mus); + void deflation(int col1, int col2, int ind, int row1W, int col1W, int shift); - void perturb(const NDArray& col0, const NDArray& diag, const NDArray& permut, const NDArray& singVals, const NDArray& shifts, const NDArray& mus, NDArray& zhat); + // FIXME: proper T support required here + T secularEq(const T diff, const NDArray& col0, const NDArray& diag, + const NDArray& permut, const NDArray& diagShifted, const T shift); - void calcSingVecs(const NDArray& zhat, const NDArray& diag, const NDArray& perm, const NDArray& singVals, const NDArray& shifts, const NDArray& mus, NDArray& U, NDArray& V); + void calcSingVals(const NDArray& col0, const NDArray& diag, + const NDArray& permut, NDArray& singVals, NDArray& shifts, + NDArray& mus); - void calcBlockSVD(int firstCol, int size, NDArray& U, NDArray& singVals, NDArray& V); + void perturb(const NDArray& col0, const NDArray& diag, const NDArray& permut, + const NDArray& singVals, const NDArray& shifts, + const NDArray& mus, NDArray& zhat); - void DivideAndConquer(int col1, int col2, int row1W, int col1W, int shift); + void calcSingVecs(const NDArray& zhat, const NDArray& diag, + const NDArray& perm, const NDArray& singVals, + const NDArray& shifts, const NDArray& mus, NDArray& U, + NDArray& V); - void exchangeUV(const HHsequence& hhU, const HHsequence& hhV, const NDArray& U, const NDArray& V); + void calcBlockSVD(int firstCol, int size, NDArray& U, NDArray& singVals, + NDArray& V); - void evalData(const NDArray& matrix); + void DivideAndConquer(int col1, int col2, int row1W, int col1W, int shift); - FORCEINLINE NDArray& getS(); - FORCEINLINE NDArray& getU(); - FORCEINLINE NDArray& getV(); + void exchangeUV(const HHsequence& hhU, const HHsequence& hhV, + const NDArray& U, const NDArray& V); -}; + void evalData(const NDArray& matrix); + FORCEINLINE NDArray& getS(); + FORCEINLINE NDArray& getU(); + FORCEINLINE NDArray& getV(); +}; ////////////////////////////////////////////////////////////////////////// template @@ -102,10 +112,8 @@ FORCEINLINE NDArray& SVD::getV() { return _v; } +} // namespace helpers +} // namespace ops +} // namespace sd - -} -} -} - -#endif //LIBND4J_SVD_H +#endif // LIBND4J_SVD_H diff --git a/libnd4j/include/helpers/threshold.h b/libnd4j/include/helpers/threshold.h index ba5304e55278..55c34e7384fe 100644 --- a/libnd4j/include/helpers/threshold.h +++ b/libnd4j/include/helpers/threshold.h @@ -23,5 +23,4 @@ #include - -#endif //LIBND4J_THRESHOLD_H +#endif // LIBND4J_THRESHOLD_H diff --git a/libnd4j/include/helpers/unicode.h b/libnd4j/include/helpers/unicode.h index 6db4841db2ba..fa5787298c47 100644 --- a/libnd4j/include/helpers/unicode.h +++ b/libnd4j/include/helpers/unicode.h @@ -26,164 +26,163 @@ namespace sd { namespace unicode { - /** - * This method calculate u16 offset based on utf8 - * @param const pointer to the utf8 string start point - * @param size of the string - * @return offset of utf16 - */ - Nd4jLong offsetUtf8StringInUtf16(const void* start, const void* end); - - /** - * This method calculate u8 offset based on utf16 - * @param const pointer to the utf16 string start point - * @param size of the string - * @return offset of utf8 - */ - Nd4jLong offsetUtf16StringInUtf8(const void* start, const void* end); - - /** - * This method calculate u32 offset based on utf16 - * @param const pointer to the utf16 string start point - * @param size of the string - * @return offset of utf32 - */ - Nd4jLong offsetUtf32StringInUtf16(const void* start, const void* end); - - /** - * This method calculate u32 offset based on utf8 - * @param const pointer to the utf16 string start point - * @param size of the string - * @return offset of utf8 - */ - Nd4jLong offsetUtf32StringInUtf8(const void* start, const void* end); - - /* - * This function check is valid charecter in u8 string - */ - bool isStringValidU8(const void* start, const void* stop); - - /* - * This function check is valid charecter in u16 string - */ - bool isStringValidU16(const void* start, const void* stop); - - /* - * This function check is valid u32 charecter in string - */ - bool isStringValidU32(const void* start, const void* stop); - - /** - * This method count offset for utf8 string in utf32 - * @param const pointer to the utf8 string start point - * @param size of the string - * @return offset - */ - Nd4jLong offsetUtf8StringInUtf32(const void* input, uint32_t nInputSize); - - /** - * This method count offset for utf8 string in utf32 - * @param const pointer to the utf8 string start point - * @param const end pointer to the utf8 string - * @return offset - */ - Nd4jLong offsetUtf8StringInUtf32(const void* input, const void* stop); - - /** - * This method count offset for utf32 based on utf16 string - * @param const pointer to the utf16 string start point - * @param size of the string - * @return offset - */ - Nd4jLong offsetUtf16StringInUtf32(const void* input, uint32_t nInputSize); - - /** - * This method calculate offset of u16 based on utf8 - * @param const pointer to the utf8 string start point - * @param size of the string - * @return offset of utf16 - */ - Nd4jLong offsetUtf8StringInUtf16(const void* input, uint32_t nInputSize); - - /** - * This method calculate offset of u8 based on utf16 - * @param const pointer to the utf16 string start point - * @param size of the string - * @return offset of utf8 - */ - Nd4jLong offsetUtf16StringInUtf8(const void* input, uint32_t nInputSize); - - /** - * This method calculate offset of u32 based on utf8 - * @param const pointer to the utf16 string start point - * @param size of the string - * @return offset of utf32 - */ - Nd4jLong offsetUtf32StringInUtf8(const void* input, uint32_t nInputSize); - - /** - * This method calculate offset of u32 based on utf16 - * @param const pointer to the utf16 string start point - * @param size of the string - * @return offset of utf32 - */ - Nd4jLong offsetUtf32StringInUtf16(const void* input, const uint32_t nInputSize); - - /** - * This method convert utf8 string to utf16 string - * @param const pointer to the utf8 string start point - * @param reference to start point to utf16 - * @param size of input utf8 string - * @return status of convertion - */ - bool utf8to16(const void* input, void* output, uint32_t nInputSize); - - /** - * This method convert utf8 string to utf32 string - * @param const pointer to the utf8 string start point - * @param reference to start point to utf32 - * @param size of input utf8 string - * @return status of convertion - */ - bool utf8to32(const void* input, void* output, uint32_t nInputSize); - - /** - * This method convert utf16 string to utf32 string - * @param const pointer to the utf16 string start point - * @param reference to start point to utf32 - * @param size of input utf16 string - * @return status of convertion - */ - bool utf16to32(const void* input, void* output, uint32_t nInputSize); - - /** - * This method convert utf16 string to utf8 string - * @param const pointer to the utf16 string start point - * @param reference to start point to utf8 - * @param size of input utf16 string - * @return status of convertion - */ - bool utf16to8(const void* input, void* output, uint32_t nInputSize); - - /** - * This method convert utf32 string to utf16 string - * @param const pointer to the utf32 string start point - * @param reference to start point to utf16 - * @param size of input utf32 string - * @return status of convertion - */ - bool utf32to16(const void* input, void* output, uint32_t nInputSize); - - /** - * This method convert utf32 string to utf8 string - * @param const pointer to the utf32 string start point - * @param reference to start point to utf8 - * @param size of input utf32 string - * @return status of convertion - */ - bool utf32to8(const void* input, void* output, const Nd4jLong nInputSize); -} -} - - -#endif //LIBND4J_UNICODE_H +/** + * This method calculate u16 offset based on utf8 + * @param const pointer to the utf8 string start point + * @param size of the string + * @return offset of utf16 + */ +Nd4jLong offsetUtf8StringInUtf16(const void* start, const void* end); + +/** + * This method calculate u8 offset based on utf16 + * @param const pointer to the utf16 string start point + * @param size of the string + * @return offset of utf8 + */ +Nd4jLong offsetUtf16StringInUtf8(const void* start, const void* end); + +/** + * This method calculate u32 offset based on utf16 + * @param const pointer to the utf16 string start point + * @param size of the string + * @return offset of utf32 + */ +Nd4jLong offsetUtf32StringInUtf16(const void* start, const void* end); + +/** + * This method calculate u32 offset based on utf8 + * @param const pointer to the utf16 string start point + * @param size of the string + * @return offset of utf8 + */ +Nd4jLong offsetUtf32StringInUtf8(const void* start, const void* end); + +/* + * This function check is valid charecter in u8 string + */ +bool isStringValidU8(const void* start, const void* stop); + +/* + * This function check is valid charecter in u16 string + */ +bool isStringValidU16(const void* start, const void* stop); + +/* + * This function check is valid u32 charecter in string + */ +bool isStringValidU32(const void* start, const void* stop); + +/** + * This method count offset for utf8 string in utf32 + * @param const pointer to the utf8 string start point + * @param size of the string + * @return offset + */ +Nd4jLong offsetUtf8StringInUtf32(const void* input, uint32_t nInputSize); + +/** + * This method count offset for utf8 string in utf32 + * @param const pointer to the utf8 string start point + * @param const end pointer to the utf8 string + * @return offset + */ +Nd4jLong offsetUtf8StringInUtf32(const void* input, const void* stop); + +/** + * This method count offset for utf32 based on utf16 string + * @param const pointer to the utf16 string start point + * @param size of the string + * @return offset + */ +Nd4jLong offsetUtf16StringInUtf32(const void* input, uint32_t nInputSize); + +/** + * This method calculate offset of u16 based on utf8 + * @param const pointer to the utf8 string start point + * @param size of the string + * @return offset of utf16 + */ +Nd4jLong offsetUtf8StringInUtf16(const void* input, uint32_t nInputSize); + +/** + * This method calculate offset of u8 based on utf16 + * @param const pointer to the utf16 string start point + * @param size of the string + * @return offset of utf8 + */ +Nd4jLong offsetUtf16StringInUtf8(const void* input, uint32_t nInputSize); + +/** + * This method calculate offset of u32 based on utf8 + * @param const pointer to the utf16 string start point + * @param size of the string + * @return offset of utf32 + */ +Nd4jLong offsetUtf32StringInUtf8(const void* input, uint32_t nInputSize); + +/** + * This method calculate offset of u32 based on utf16 + * @param const pointer to the utf16 string start point + * @param size of the string + * @return offset of utf32 + */ +Nd4jLong offsetUtf32StringInUtf16(const void* input, const uint32_t nInputSize); + +/** + * This method convert utf8 string to utf16 string + * @param const pointer to the utf8 string start point + * @param reference to start point to utf16 + * @param size of input utf8 string + * @return status of convertion + */ +bool utf8to16(const void* input, void* output, uint32_t nInputSize); + +/** + * This method convert utf8 string to utf32 string + * @param const pointer to the utf8 string start point + * @param reference to start point to utf32 + * @param size of input utf8 string + * @return status of convertion + */ +bool utf8to32(const void* input, void* output, uint32_t nInputSize); + +/** + * This method convert utf16 string to utf32 string + * @param const pointer to the utf16 string start point + * @param reference to start point to utf32 + * @param size of input utf16 string + * @return status of convertion + */ +bool utf16to32(const void* input, void* output, uint32_t nInputSize); + +/** + * This method convert utf16 string to utf8 string + * @param const pointer to the utf16 string start point + * @param reference to start point to utf8 + * @param size of input utf16 string + * @return status of convertion + */ +bool utf16to8(const void* input, void* output, uint32_t nInputSize); + +/** + * This method convert utf32 string to utf16 string + * @param const pointer to the utf32 string start point + * @param reference to start point to utf16 + * @param size of input utf32 string + * @return status of convertion + */ +bool utf32to16(const void* input, void* output, uint32_t nInputSize); + +/** + * This method convert utf32 string to utf8 string + * @param const pointer to the utf32 string start point + * @param reference to start point to utf8 + * @param size of input utf32 string + * @return status of convertion + */ +bool utf32to8(const void* input, void* output, const Nd4jLong nInputSize); +} // namespace unicode +} // namespace sd + +#endif // LIBND4J_UNICODE_H diff --git a/libnd4j/include/indexing/IndicesList.h b/libnd4j/include/indexing/IndicesList.h index a652615d5355..0ecbae29cbaa 100644 --- a/libnd4j/include/indexing/IndicesList.h +++ b/libnd4j/include/indexing/IndicesList.h @@ -22,22 +22,24 @@ #define LIBND4J_INDICESLIST_H #include + #include "NDIndex.h" namespace sd { - class ND4J_EXPORT IndicesList { - protected: - std::vector _indices; - public: - explicit IndicesList() = default; - explicit IndicesList(std::initializer_list list); - - int size(); - NDIndex* at(int idx); - void push_back(NDIndex* idx); - bool isScalar(); - - ~IndicesList(); - }; -} -#endif //LIBND4J_INDICESLIST_H +class SD_EXPORT IndicesList { + protected: + std::vector _indices; + + public: + explicit IndicesList() = default; + explicit IndicesList(std::initializer_list list); + + int size(); + NDIndex* at(int idx); + void push_back(NDIndex* idx); + bool isScalar(); + + ~IndicesList(); +}; +} // namespace sd +#endif // LIBND4J_INDICESLIST_H diff --git a/libnd4j/include/indexing/NDIndex.h b/libnd4j/include/indexing/NDIndex.h index 799da4e6ca99..00189a1aa290 100644 --- a/libnd4j/include/indexing/NDIndex.h +++ b/libnd4j/include/indexing/NDIndex.h @@ -21,54 +21,53 @@ #ifndef LIBND4J_NDINDEX_H #define LIBND4J_NDINDEX_H +#include #include + #include -#include namespace sd { - class ND4J_EXPORT NDIndex { - protected: - std::vector _indices; - Nd4jLong _stride = 1; - public: - NDIndex() = default; - ~NDIndex() = default; - - bool isAll(); - bool isPoint(); - virtual bool isInterval(); - - std::vector& getIndices(); - Nd4jLong stride(); +class SD_EXPORT NDIndex { + protected: + std::vector _indices; + Nd4jLong _stride = 1; - static NDIndex* all(); - static NDIndex* point(Nd4jLong pt); - static NDIndex* interval(Nd4jLong start, Nd4jLong end, Nd4jLong stride = 1); - }; + public: + NDIndex() = default; + ~NDIndex() = default; - class ND4J_EXPORT NDIndexAll : public NDIndex { - public: - NDIndexAll(); - virtual bool isInterval(); - ~NDIndexAll() = default; - }; + bool isAll(); + bool isPoint(); + virtual bool isInterval(); + std::vector& getIndices(); + Nd4jLong stride(); - class ND4J_EXPORT NDIndexPoint : public NDIndex { - public: - NDIndexPoint(Nd4jLong point); - virtual bool isInterval(); - ~NDIndexPoint() = default; - }; + static NDIndex* all(); + static NDIndex* point(Nd4jLong pt); + static NDIndex* interval(Nd4jLong start, Nd4jLong end, Nd4jLong stride = 1); +}; - class ND4J_EXPORT NDIndexInterval : public NDIndex { - public: - NDIndexInterval(Nd4jLong start, Nd4jLong end, Nd4jLong stride = 1); - virtual bool isInterval(); - ~NDIndexInterval() = default; - }; -} +class SD_EXPORT NDIndexAll : public NDIndex { + public: + NDIndexAll(); + virtual bool isInterval(); + ~NDIndexAll() = default; +}; +class SD_EXPORT NDIndexPoint : public NDIndex { + public: + NDIndexPoint(Nd4jLong point); + virtual bool isInterval(); + ~NDIndexPoint() = default; +}; +class SD_EXPORT NDIndexInterval : public NDIndex { + public: + NDIndexInterval(Nd4jLong start, Nd4jLong end, Nd4jLong stride = 1); + virtual bool isInterval(); + ~NDIndexInterval() = default; +}; +} // namespace sd -#endif //LIBND4J_NDINDEX_H +#endif // LIBND4J_NDINDEX_H diff --git a/libnd4j/include/indexing/impl/IndicesList.cpp b/libnd4j/include/indexing/impl/IndicesList.cpp index 5acbf57d5cf9..d3467d12a184 100644 --- a/libnd4j/include/indexing/impl/IndicesList.cpp +++ b/libnd4j/include/indexing/impl/IndicesList.cpp @@ -22,32 +22,24 @@ using namespace sd; -sd::IndicesList::IndicesList(std::initializer_list list) { - for (auto v: list) - _indices.emplace_back(v); +sd::IndicesList::IndicesList(std::initializer_list list) { + for (auto v : list) _indices.emplace_back(v); } sd::IndicesList::~IndicesList() { - for(auto v: _indices) - delete v; + for (auto v : _indices) delete v; } -int sd::IndicesList::size() { - return (int) _indices.size(); -} +int sd::IndicesList::size() { return (int)_indices.size(); } bool sd::IndicesList::isScalar() { - if (_indices.size() == 1) { - return _indices.at(0)->isPoint(); - } + if (_indices.size() == 1) { + return _indices.at(0)->isPoint(); + } - return false; + return false; } -sd::NDIndex* sd::IndicesList::at(int idx) { - return _indices.at(idx); -} +sd::NDIndex* sd::IndicesList::at(int idx) { return _indices.at(idx); } -void sd::IndicesList::push_back(NDIndex* idx) { - _indices.emplace_back(idx); -} \ No newline at end of file +void sd::IndicesList::push_back(NDIndex* idx) { _indices.emplace_back(idx); } \ No newline at end of file diff --git a/libnd4j/include/indexing/impl/NDIndex.cpp b/libnd4j/include/indexing/impl/NDIndex.cpp index 43aaf09143fc..7111ce4cd8d5 100644 --- a/libnd4j/include/indexing/impl/NDIndex.cpp +++ b/libnd4j/include/indexing/impl/NDIndex.cpp @@ -22,64 +22,45 @@ namespace sd { - bool NDIndex::isInterval() { - return false; - } +bool NDIndex::isInterval() { return false; } - Nd4jLong NDIndex::stride() { - return _stride; - } +Nd4jLong NDIndex::stride() { return _stride; } - sd::NDIndexAll::NDIndexAll() : sd::NDIndex() { - _indices.push_back(-1); - } +sd::NDIndexAll::NDIndexAll() : sd::NDIndex() { _indices.push_back(-1); } - sd::NDIndexPoint::NDIndexPoint(Nd4jLong point) : sd::NDIndex() { - this->_indices.push_back(point); - } +sd::NDIndexPoint::NDIndexPoint(Nd4jLong point) : sd::NDIndex() { + this->_indices.push_back(point); +} - bool NDIndexAll::isInterval() { - return false; - } +bool NDIndexAll::isInterval() { return false; } - bool NDIndexPoint::isInterval() { - return false; - } +bool NDIndexPoint::isInterval() { return false; } - bool NDIndexInterval::isInterval() { - return true; - } +bool NDIndexInterval::isInterval() { return true; } +sd::NDIndexInterval::NDIndexInterval(Nd4jLong start, Nd4jLong end, + Nd4jLong stride) + : sd::NDIndex() { + this->_stride = stride; + for (int e = start; e < end; e += stride) this->_indices.push_back(e); +} +bool sd::NDIndex::isAll() { + return _indices.size() == 1 && _indices.at(0) == -1; +} - sd::NDIndexInterval::NDIndexInterval(Nd4jLong start, Nd4jLong end, Nd4jLong stride) : sd::NDIndex() { - this->_stride = stride; - for (int e = start; e < end; e+= stride) - this->_indices.push_back(e); - } +bool sd::NDIndex::isPoint() { + return _indices.size() == 1 && _indices.at(0) >= 0; +} - bool sd::NDIndex::isAll() { - return _indices.size() == 1 && _indices.at(0) == -1; - } +std::vector &sd::NDIndex::getIndices() { return _indices; } - bool sd::NDIndex::isPoint() { - return _indices.size() == 1 && _indices.at(0) >= 0; - } +sd::NDIndex *sd::NDIndex::all() { return new NDIndexAll(); } - std::vector &sd::NDIndex::getIndices() { - return _indices; - } +sd::NDIndex *sd::NDIndex::point(Nd4jLong pt) { return new NDIndexPoint(pt); } - - sd::NDIndex *sd::NDIndex::all() { - return new NDIndexAll(); - } - - sd::NDIndex *sd::NDIndex::point(Nd4jLong pt) { - return new NDIndexPoint(pt); - } - - sd::NDIndex *sd::NDIndex::interval(Nd4jLong start, Nd4jLong end, Nd4jLong stride) { - return new NDIndexInterval(start, end, stride); - } -} \ No newline at end of file +sd::NDIndex *sd::NDIndex::interval(Nd4jLong start, Nd4jLong end, + Nd4jLong stride) { + return new NDIndexInterval(start, end, stride); +} +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/legacy/NativeOpExecutioner.h b/libnd4j/include/legacy/NativeOpExecutioner.h index 84ab886c4352..9e1c8f2842b6 100644 --- a/libnd4j/include/legacy/NativeOpExecutioner.h +++ b/libnd4j/include/legacy/NativeOpExecutioner.h @@ -21,659 +21,587 @@ #ifndef NATIVEOPERATIONS_NATIVEOPEXCUTIONER_H #define NATIVEOPERATIONS_NATIVEOPEXCUTIONER_H - -#include -#include +#include +#include #include #include -#include -#include +#include +#include /** * Native op executioner: * */ -class ND4J_EXPORT NativeOpExecutioner { -public: - /** - * - * @param opNum - * @param x - * @param xShapeInfo - * @param extraParams - * @param result - * @param resultShapeInfo - */ - static void execIndexReduceScalar(sd::LaunchContext *lc, - int opNum, +class SD_EXPORT NativeOpExecutioner { + public: + /** + * + * @param opNum + * @param x + * @param xShapeInfo + * @param extraParams + * @param result + * @param resultShapeInfo + */ + static void execIndexReduceScalar(sd::LaunchContext *lc, int opNum, const void *hX, const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, - void *extraParams, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo); - - /** - * - * @param opNum - * @param x - * @param xShapeInfo - * @param extraParamsVals - * @param y - * @param yShapeInfo - * @param result - * @param resultShapeInfoBuffer - * @param dimension - * @param dimensionLength - */ - static void execReduce3Scalar(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *extraParamsVals, - const void *hY, const Nd4jLong *hYShapeInfo, - const void *dY, const Nd4jLong *dYShapeInfo, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo); - - - /** - * - * @param opNum - * @param x - * @param xShapeInfo - * @param extraParamsVals - * @param y - * @param yShapeInfo - * @param result - * @param resultShapeInfo - */ - static void execReduce3(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *extraParamsVals, - const void *hY, const Nd4jLong *hYShapeInfo, - const void *dY, const Nd4jLong *dYShapeInfo, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo); - - /** - * - * @param opNum - * @param x - * @param xShapeInfo - * @param extraParamsVals - * @param y - * @param yShapeInfo - * @param result - * @param resultShapeInfoBuffer - * @param dimension - * @param dimensionLength - */ - static void execReduce3(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *extraParamsVals, - const void *hY, const Nd4jLong *hYShapeInfo, - const void *dY, const Nd4jLong *dYShapeInfo, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *xTadOnlyShapeInfo, const Nd4jLong *xTadOffsets, - const Nd4jLong *yTadOnlyShapeInfo, const Nd4jLong *yTadOffsets); - - static void execReduce3All(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *extraParamsVals, - const void *hY, const Nd4jLong *hYShapeInfo, - const void *dY, const Nd4jLong *dYShapeInfo, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *xTadShapeInfo, const Nd4jLong *xOffsets, - const Nd4jLong *yTadShapeInfo, const Nd4jLong *yOffsets); - - /** - * - * @param opNum - * @param x - * @param xShapeInfo - * @param extraParams - * @param result - * @param resultShapeInfoBuffer - * @param dimension - * @param dimensionLength - */ - static void execIndexReduce(sd::LaunchContext *lc, - int opNum, + void *extraParams, void *hZ, + const Nd4jLong *hZShapeInfo, void *dZ, + const Nd4jLong *dZShapeInfo); + + /** + * + * @param opNum + * @param x + * @param xShapeInfo + * @param extraParamsVals + * @param y + * @param yShapeInfo + * @param result + * @param resultShapeInfoBuffer + * @param dimension + * @param dimensionLength + */ + static void execReduce3Scalar(sd::LaunchContext *lc, int opNum, + const void *hX, const Nd4jLong *hXShapeInfo, + const void *dX, const Nd4jLong *dXShapeInfo, + void *extraParamsVals, const void *hY, + const Nd4jLong *hYShapeInfo, const void *dY, + const Nd4jLong *dYShapeInfo, void *hZ, + const Nd4jLong *hZShapeInfo, void *dZ, + const Nd4jLong *dZShapeInfo); + + /** + * + * @param opNum + * @param x + * @param xShapeInfo + * @param extraParamsVals + * @param y + * @param yShapeInfo + * @param result + * @param resultShapeInfo + */ + static void execReduce3(sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, + const Nd4jLong *dXShapeInfo, void *extraParamsVals, + const void *hY, const Nd4jLong *hYShapeInfo, + const void *dY, const Nd4jLong *dYShapeInfo, void *hZ, + const Nd4jLong *hZShapeInfo, void *dZ, + const Nd4jLong *dZShapeInfo); + + /** + * + * @param opNum + * @param x + * @param xShapeInfo + * @param extraParamsVals + * @param y + * @param yShapeInfo + * @param result + * @param resultShapeInfoBuffer + * @param dimension + * @param dimensionLength + */ + static void execReduce3( + sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + void *extraParamsVals, const void *hY, const Nd4jLong *hYShapeInfo, + const void *dY, const Nd4jLong *dYShapeInfo, void *hZ, + const Nd4jLong *hZShapeInfo, void *dZ, const Nd4jLong *dZShapeInfo, + int *dimension, int dimensionLength, const Nd4jLong *xTadOnlyShapeInfo, + const Nd4jLong *xTadOffsets, const Nd4jLong *yTadOnlyShapeInfo, + const Nd4jLong *yTadOffsets); + + static void execReduce3All(sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, + const Nd4jLong *dXShapeInfo, void *extraParamsVals, + const void *hY, const Nd4jLong *hYShapeInfo, + const void *dY, const Nd4jLong *dYShapeInfo, + void *hZ, const Nd4jLong *hZShapeInfo, void *dZ, + const Nd4jLong *dZShapeInfo, int *dimension, + int dimensionLength, const Nd4jLong *xTadShapeInfo, + const Nd4jLong *xOffsets, + const Nd4jLong *yTadShapeInfo, + const Nd4jLong *yOffsets); + + /** + * + * @param opNum + * @param x + * @param xShapeInfo + * @param extraParams + * @param result + * @param resultShapeInfoBuffer + * @param dimension + * @param dimensionLength + */ + static void execIndexReduce(sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, + const Nd4jLong *dXShapeInfo, void *extraParams, + void *hZ, const Nd4jLong *hZShapeInfo, void *dZ, + const Nd4jLong *dZShapeInfo, int *dimension, + int dimensionLength, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets); + + /** + * + * @param opNum + * @param x + * @param xStride + * @param result + * @param resultStride + * @param scalar + * @param extraParams + * @param n + */ + static void execScalar(sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, + const Nd4jLong *dXShapeInfo, void *hZ, + const Nd4jLong *hZShapeInfo, void *dZ, + const Nd4jLong *dZShapeInfo, const void *hScalar, + const Nd4jLong *hSscalarShapeInfo, const void *dScalar, + const Nd4jLong *dSscalarShapeInfo, void *extraParams, + bool allowParallelism = true); + + static void execScalarBool(sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, + const Nd4jLong *dXShapeInfo, void *hZ, + const Nd4jLong *hZShapeInfo, void *dZ, + const Nd4jLong *dZShapeInfo, const void *hScalar, + const Nd4jLong *hSscalarShapeInfo, + const void *dScalar, + const Nd4jLong *dSscalarShapeInfo, + void *extraParams, bool allowParallelism = true); + + static void execScalarInt(sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, + const Nd4jLong *dXShapeInfo, void *hZ, + const Nd4jLong *hZShapeInfo, void *dZ, + const Nd4jLong *dZShapeInfo, const void *hScalar, + const Nd4jLong *hSscalarShapeInfo, + const void *dScalar, + const Nd4jLong *dSscalarShapeInfo, + void *extraParams, bool allowParallelism = true); + + static void execScalar(sd::LaunchContext *lc, int opNum, void const *hX, + Nd4jLong const *hXShapeInfo, void const *dX, + Nd4jLong const *dXShapeInfo, void *extraParams, + void *hZ, Nd4jLong const *hZShapeInfo, void *dZ, + Nd4jLong const *dZShapeInfo, void const *hScalars, + Nd4jLong const *hScalarShapeInfo, void const *dScalars, + Nd4jLong const *dScalarShapeInfo, int *dimension, + int dimensionLength, Nd4jLong const *tadShapeInfo, + Nd4jLong const *tadOffsets, + Nd4jLong const *tadShapeInfoZ, + Nd4jLong const *tadOffsetsZ); + + static void execScalarBool( + sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + void *extraParams, void *hZ, const Nd4jLong *hZShapeInfo, void *dZ, + const Nd4jLong *dZShapeInfo, const void *hScalars, + const Nd4jLong *hScalarShapeInfo, const void *dScalars, + const Nd4jLong *dScalarShapeInfo, int *dimension, int dimensionLength, + const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets, + const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetsZ); + + static void execScalarInt( + sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + void *extraParams, void *hZ, const Nd4jLong *hZShapeInfo, void *dZ, + const Nd4jLong *dZShapeInfo, const void *hScalars, + const Nd4jLong *hScalarShapeInfo, const void *dScalars, + const Nd4jLong *dScalarShapeInfo, int *dimension, int dimensionLength, + const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets, + const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetsZ); + + /** + * + * @param opNum + * @param x + * @param xShapeInfo + * @param y + * @param yShapeInfo + * @param result + * @param resultShapeInfo + * @param dimension + * @param dimensionLength + */ + static void execBroadcast( + sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + const void *hY, const Nd4jLong *hYShapeInfo, const void *dY, + const Nd4jLong *dYShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo, + void *dZ, const Nd4jLong *dZShapeInfo, int *dimension, + int dimensionLength, const Nd4jLong *tadOnlyShapeInfo, + const Nd4jLong *tadOffsets, const Nd4jLong *tadOnlyShapeInfoZ, + const Nd4jLong *tadOffsetsZ); + + static void execBroadcast(sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, + const Nd4jLong *dXShapeInfo, const void *hY, + const Nd4jLong *hYShapeInfo, const void *dY, + const Nd4jLong *dYShapeInfo, void *hZ, + const Nd4jLong *hZShapeInfo, void *dZ, + const Nd4jLong *dZShapeInfo); + + static void execInverseBroadcast( + sd::LaunchContext *lc, int opNum, const void *x, + const Nd4jLong *xShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + const void *y, const Nd4jLong *yShapeInfo, const void *dY, + const Nd4jLong *dYShapeInfo, void *result, + const Nd4jLong *resultShapeInfo, void *dZ, const Nd4jLong *dZShapeInfo, + int *dimension, int dimensionLength, const Nd4jLong *tadOnlyShapeInfo, + const Nd4jLong *tadOffsets, const Nd4jLong *tadOnlyShapeInfoZ, + const Nd4jLong *tadOffsetsZ); + + static void execBroadcastBool( + sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + const void *hY, const Nd4jLong *hYShapeInfo, const void *dY, + const Nd4jLong *dYShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo, + void *dZ, const Nd4jLong *dZShapeInfo, void *extraParams, int *dimension, + int dimensionLength, const Nd4jLong *tadOnlyShapeInfo, + const Nd4jLong *tadOffsets, const Nd4jLong *tadOnlyShapeInfoZ, + const Nd4jLong *tadOffsetsZ); + + static void execBroadcastBool(sd::LaunchContext *lc, int opNum, const void *hX, const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, - void *extraParams, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets); - - /** - * - * @param opNum - * @param x - * @param xStride - * @param result - * @param resultStride - * @param scalar - * @param extraParams - * @param n - */ - static void execScalar(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - const void *hScalar, const Nd4jLong *hSscalarShapeInfo, - const void *dScalar, const Nd4jLong *dSscalarShapeInfo, - void *extraParams, - bool allowParallelism = true); - -static void execScalarBool(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - const void *hScalar, const Nd4jLong *hSscalarShapeInfo, - const void *dScalar, const Nd4jLong *dSscalarShapeInfo, - void *extraParams, - bool allowParallelism = true); - -static void execScalarInt(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - const void *hScalar, const Nd4jLong *hSscalarShapeInfo, - const void *dScalar, const Nd4jLong *dSscalarShapeInfo, - void *extraParams, + const void *hY, const Nd4jLong *hYShapeInfo, + const void *dY, const Nd4jLong *dYShapeInfo, + void *hZ, const Nd4jLong *hZShapeInfo, void *dZ, + const Nd4jLong *dZShapeInfo, void *extraParams); + + static void execInverseBroadcastBool( + sd::LaunchContext *lc, int opNum, const void *x, + const Nd4jLong *xShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + const void *y, const Nd4jLong *yShapeInfo, const void *dY, + const Nd4jLong *dYShapeInfo, void *result, + const Nd4jLong *resultShapeInfo, void *dZ, const Nd4jLong *dZShapeInfo, + void *extraParams, int *dimension, int dimensionLength, + const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets, + const Nd4jLong *tadOnlyShapeInfoZ, const Nd4jLong *tadOffsetsZ); + + static void execBroadcastInt( + sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + const void *hY, const Nd4jLong *hYShapeInfo, const void *dY, + const Nd4jLong *dYShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo, + void *dZ, const Nd4jLong *dZShapeInfo, int *dimension, + int dimensionLength, const Nd4jLong *tadOnlyShapeInfo, + const Nd4jLong *tadOffsets, const Nd4jLong *tadOnlyShapeInfoZ, + const Nd4jLong *tadOffsetsZ); + + static void execBroadcastInt(sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, + const Nd4jLong *dXShapeInfo, const void *hY, + const Nd4jLong *hYShapeInfo, const void *dY, + const Nd4jLong *dYShapeInfo, void *hZ, + const Nd4jLong *hZShapeInfo, void *dZ, + const Nd4jLong *dZShapeInfo); + + static void execInverseBroadcastInt( + sd::LaunchContext *lc, int opNum, const void *x, + const Nd4jLong *xShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + const void *y, const Nd4jLong *yShapeInfo, const void *dY, + const Nd4jLong *dYShapeInfo, void *result, + const Nd4jLong *resultShapeInfo, void *dZ, const Nd4jLong *dZShapeInfo, + int *dimension, int dimensionLength, const Nd4jLong *tadOnlyShapeInfo, + const Nd4jLong *tadOffsets, const Nd4jLong *tadOnlyShapeInfoZ, + const Nd4jLong *tadOffsetsZ); + + /** + * + * @param opNum + * @param dx + * @param xStride + * @param y + * @param yStride + * @param result + * @param resultStride + * @param extraParams + * @param n + */ + static void execPairwiseTransform(sd::LaunchContext *lc, int opNum, + const void *hX, const Nd4jLong *hXShapeInfo, + const void *dX, const Nd4jLong *dXShapeInfo, + const void *hY, const Nd4jLong *hYShapeInfo, + const void *dY, const Nd4jLong *dYShapeInfo, + void *hZ, const Nd4jLong *hZShapeInfo, + void *dZ, const Nd4jLong *dZShapeInfo, + void *extraParams); + + static void execPairwiseBoolTransform( + sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + const void *hY, const Nd4jLong *hYShapeInfo, const void *dY, + const Nd4jLong *dYShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo, + void *dZ, const Nd4jLong *dZShapeInfo, void *extraParams); + + static void execPairwiseIntTransform( + sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + const void *hY, const Nd4jLong *hYShapeInfo, const void *dY, + const Nd4jLong *dYShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo, + void *dZ, const Nd4jLong *dZShapeInfo, void *extraParams); + + /** + * + * @param opNum + * @param dx + * @param xStride + * @param result + * @param resultStride + * @param extraParams + * @param n + */ + static void execTransformFloat(sd::LaunchContext *lc, int opNum, + const void *hX, const Nd4jLong *hXShapeInfo, + const void *dX, const Nd4jLong *dXShapeInfo, + void *hZ, const Nd4jLong *hZShapeInfo, + void *dZ, const Nd4jLong *dZShapeInfo, + void *extraParams, + const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets); + + static void execTransformAny(sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, + const Nd4jLong *dXShapeInfo, void *hZ, + const Nd4jLong *hZShapeInfo, void *dZ, + const Nd4jLong *dZShapeInfo, void *extraParams, + const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets, bool allowParallelism = true); - static void execScalar(sd::LaunchContext *lc, - int opNum, - void const* hX, Nd4jLong const* hXShapeInfo, - void const* dX, Nd4jLong const* dXShapeInfo, - void *extraParams, - void *hZ, Nd4jLong const* hZShapeInfo, - void *dZ, Nd4jLong const* dZShapeInfo, - void const* hScalars, Nd4jLong const* hScalarShapeInfo, - void const* dScalars, Nd4jLong const* dScalarShapeInfo, - int *dimension, int dimensionLength, - Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, - Nd4jLong const* tadShapeInfoZ, Nd4jLong const* tadOffsetsZ); - - static void execScalarBool(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *extraParams, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - const void *hScalars, const Nd4jLong *hScalarShapeInfo, - const void *dScalars, const Nd4jLong *dScalarShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetsZ); - - static void execScalarInt(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *extraParams, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - const void *hScalars, const Nd4jLong *hScalarShapeInfo, - const void *dScalars, const Nd4jLong *dScalarShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetsZ); - - -/** - * - * @param opNum - * @param x - * @param xShapeInfo - * @param y - * @param yShapeInfo - * @param result - * @param resultShapeInfo - * @param dimension - * @param dimensionLength - */ - static void execBroadcast(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - const void *hY, const Nd4jLong *hYShapeInfo, - const void *dY, const Nd4jLong *dYShapeInfo, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *tadOnlyShapeInfoZ,const Nd4jLong *tadOffsetsZ); - - static void execBroadcast(sd::LaunchContext* lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - const void *hY, const Nd4jLong *hYShapeInfo, - const void *dY, const Nd4jLong *dYShapeInfo, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo); - - static void execInverseBroadcast(sd::LaunchContext *lc, - int opNum, - const void *x, const Nd4jLong *xShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - const void *dY, const Nd4jLong *dYShapeInfo, - void *result, const Nd4jLong *resultShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *tadOnlyShapeInfoZ, const Nd4jLong *tadOffsetsZ); - - - static void execBroadcastBool(sd::LaunchContext *lc, - int opNum, + static void execTransformStrict(sd::LaunchContext *lc, int opNum, const void *hX, const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, - const void *hY, const Nd4jLong *hYShapeInfo, - const void *dY, const Nd4jLong *dYShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo, void *dZ, const Nd4jLong *dZShapeInfo, void *extraParams, - int *dimension, int dimensionLength, - const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *tadOnlyShapeInfoZ,const Nd4jLong *tadOffsetsZ); + const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets); - static void execBroadcastBool(sd::LaunchContext* lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - const void *hY, const Nd4jLong *hYShapeInfo, - const void *dY, const Nd4jLong *dYShapeInfo, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - void *extraParams); - - static void execInverseBroadcastBool(sd::LaunchContext *lc, - int opNum, - const void *x, const Nd4jLong *xShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - const void *dY, const Nd4jLong *dYShapeInfo, - void *result, const Nd4jLong *resultShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - void *extraParams, - int *dimension, int dimensionLength, - const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *tadOnlyShapeInfoZ, const Nd4jLong *tadOffsetsZ); - - static void execBroadcastInt(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - const void *hY, const Nd4jLong *hYShapeInfo, - const void *dY, const Nd4jLong *dYShapeInfo, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *tadOnlyShapeInfoZ,const Nd4jLong *tadOffsetsZ); + static void execTransformSame(sd::LaunchContext *lc, int opNum, + const void *hX, const Nd4jLong *hXShapeInfo, + const void *dX, const Nd4jLong *dXShapeInfo, + void *hZ, const Nd4jLong *hZShapeInfo, void *dZ, + const Nd4jLong *dZShapeInfo, void *extraParams, + const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets); - static void execBroadcastInt(sd::LaunchContext* lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - const void *hY, const Nd4jLong *hYShapeInfo, - const void *dY, const Nd4jLong *dYShapeInfo, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo); - - static void execInverseBroadcastInt(sd::LaunchContext *lc, - int opNum, - const void *x, const Nd4jLong *xShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - const void *dY, const Nd4jLong *dYShapeInfo, - void *result, const Nd4jLong *resultShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *tadOnlyShapeInfoZ, const Nd4jLong *tadOffsetsZ); + static void execTransformBool(sd::LaunchContext *lc, int opNum, + const void *hX, const Nd4jLong *hXShapeInfo, + const void *dX, const Nd4jLong *dXShapeInfo, + void *hZ, const Nd4jLong *hZShapeInfo, void *dZ, + const Nd4jLong *dZShapeInfo, void *extraParams, + const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets); + /** + * + * @param opNum + * @param x + * @param xShapeInfo + * @param extraParams + * @param result + * @param resultShapeInfo + */ + static void execReduceFloat(sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, + const Nd4jLong *dXShapeInfo, void *extraParams, + void *hZ, const Nd4jLong *hZShapeInfo, void *dZ, + const Nd4jLong *dZShapeInfo, int *dimension, + int dimensionLength, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets); + + static void execReduceSame(sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, + const Nd4jLong *dXShapeInfo, void *extraParams, + void *hZ, const Nd4jLong *hZShapeInfo, void *dZ, + const Nd4jLong *dZShapeInfo, int *dimension, + int dimensionLength, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets); + + static void execReduceBool(sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, + const Nd4jLong *dXShapeInfo, void *extraParams, + void *hZ, const Nd4jLong *hZShapeInfo, void *dZ, + const Nd4jLong *dZShapeInfo, int *dimension, + int dimensionLength, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets); + + static void execReduceLong(sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, + const Nd4jLong *dXShapeInfo, void *extraParams, + void *hZ, const Nd4jLong *hZShapeInfo, void *dZ, + const Nd4jLong *dZShapeInfo, int *dimension, + int dimensionLength, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets); + + /** + * + * @param opNum + * @param x + * @param xShapeInfo + * @param extraParams + * @return + */ + static void execReduceFloatScalar(sd::LaunchContext *lc, int opNum, + const void *hX, const Nd4jLong *hXShapeInfo, + const void *dX, const Nd4jLong *dXShapeInfo, + void *extraParams, void *hZ, + const Nd4jLong *hZShapeInfo, void *dZ, + const Nd4jLong *dZShapeInfo); -/** - * - * @param opNum - * @param dx - * @param xStride - * @param y - * @param yStride - * @param result - * @param resultStride - * @param extraParams - * @param n - */ - static void execPairwiseTransform(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - const void *hY, const Nd4jLong *hYShapeInfo, - const void *dY, const Nd4jLong *dYShapeInfo, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - void *extraParams); - - static void execPairwiseBoolTransform(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - const void *hY, const Nd4jLong *hYShapeInfo, - const void *dY, const Nd4jLong *dYShapeInfo, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - void *extraParams); - - static void execPairwiseIntTransform(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - const void *hY, const Nd4jLong *hYShapeInfo, - const void *dY, const Nd4jLong *dYShapeInfo, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - void *extraParams); + static void execReduceBoolScalar(sd::LaunchContext *lc, int opNum, + const void *hX, const Nd4jLong *hXShapeInfo, + const void *dX, const Nd4jLong *dXShapeInfo, + void *extraParams, void *hZ, + const Nd4jLong *hZShapeInfo, void *dZ, + const Nd4jLong *dZShapeInfo); -/** - * - * @param opNum - * @param dx - * @param xStride - * @param result - * @param resultStride - * @param extraParams - * @param n - */ - static void execTransformFloat(sd::LaunchContext *lc, - int opNum, + static void execReduceSameScalar(sd::LaunchContext *lc, int opNum, const void *hX, const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - void *extraParams, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets); - -static void execTransformAny(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - void *extraParams, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets, - bool allowParallelism = true); - -static void execTransformStrict(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - void *extraParams, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets); - -static void execTransformSame(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - void *extraParams, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets); - -static void execTransformBool(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - void *extraParams, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets); - /** - * - * @param opNum - * @param x - * @param xShapeInfo - * @param extraParams - * @param result - * @param resultShapeInfo - */ - static void execReduceFloat(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *extraParams, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets); - - static void execReduceSame(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *extraParams, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets); - - static void execReduceBool(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *extraParams, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets); - - static void execReduceLong(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *extraParams, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets); - - /** - * - * @param opNum - * @param x - * @param xShapeInfo - * @param extraParams - * @return - */ - static void execReduceFloatScalar(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *extraParams, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo); - - static void execReduceBoolScalar(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *extraParams, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo); - - static void execReduceSameScalar(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *extraParams, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo); - - static void execReduceLongScalar(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *extraParams, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo); - - static void execReduce3TAD(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *extraParamsVals, - const void *hY, const Nd4jLong *hYShapeInfo, - const void *dY, const Nd4jLong *dYShapeInfo, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *yTadShapeInfo, const Nd4jLong *yTadOffsets); - - /** - * - * @param opNum - * @param x - * @param xShapeInfo - * @param extraParams - * @param result - * @param resultShapeInfoBuffer - * @param dimension - * @param dimensionLength - */ - static void execSummaryStats(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *extraParams, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets, - bool biasCorrected); - - /** - * - * @param opNum - * @param x - * @param xShapeInfo - * @param extraParams - * @param result - * @param resultShapeInfo - */ - static void execSummaryStats(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *extraParams, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - bool biasCorrected); - - /** - * - * @param opNum - * @param x - * @param xShapeInfo - * @param extraParams - * @param result - * @param resultShapeInfo - */ - static void execSummaryStatsScalar(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *extraParams, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - bool biasCorrected); - - - static void execRandom(sd::LaunchContext *lc, - int opNum, - Nd4jPointer state, - void *hZ, const Nd4jLong *hZShapeBuffer, - void *dZ, const Nd4jLong *dZShapeBuffer, - void *extraArguments); - - static void execRandom(sd::LaunchContext *lc, - int opNum, - Nd4jPointer state, - const void *hX, const Nd4jLong *hXShapeBuffer, - const void *dX, const Nd4jLong *dXShapeBuffer, - void *hZ, const Nd4jLong *hZShapeBuffer, - void *dZ, const Nd4jLong *dZShapeBuffer, - void *extraArguments); - - static void execRandom(sd::LaunchContext *lc, - int opNum, - Nd4jPointer state, - const void *hX, const Nd4jLong *hXShapeBuffer, - const void *dX, const Nd4jLong *dXShapeBuffer, - const void *hY, const Nd4jLong *hYShapeBuffer, - const void *dY, const Nd4jLong *dYShapeBuffer, - void *hZ, const Nd4jLong *hZShapeBuffer, - void *dZ, const Nd4jLong *dZShapeBuffer, - void *extraArguments); - - - - inline static void execSort(void *x, const Nd4jLong *xShapeInfo, bool descending) { - auto xType = sd::ArrayOptions::dataType(xShapeInfo); - - BUILD_SINGLE_SELECTOR(xType, sd::SpecialMethods, ::sortGeneric(x, xShapeInfo, descending), LIBND4J_TYPES); - } - - static void execSort(void *x, const Nd4jLong *xShapeInfo, int *dimension, int dimensionLength, const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets, bool descending) { - auto xType = sd::ArrayOptions::dataType(xShapeInfo); - - BUILD_SINGLE_SELECTOR(xType, sd::SpecialMethods, ::sortTadGeneric(x, xShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets, descending), LIBND4J_TYPES); - } - - inline static void execSortCooIndices(Nd4jLong *indices, void *values, Nd4jLong length, int rank) { - sd::sparse::SparseUtils::sortCooIndicesGeneric(indices, reinterpret_cast(values), length, rank); - } - - - inline static Nd4jLong encodeBitmap(void *dx, const Nd4jLong *xShapeInfo, Nd4jLong N, int *dz, float threshold) { - auto xType = sd::ArrayOptions::dataType(xShapeInfo); - - BUILD_SINGLE_SELECTOR(xType, return sd::SpecialMethods, ::encodeBitmapGeneric(dx, xShapeInfo, N, dz, threshold), FLOAT_TYPES); - } - - inline static void decodeBitmap(const void *dx, Nd4jLong N, void *dz, const Nd4jLong *zShapeInfo) { - auto zType = sd::ArrayOptions::dataType(zShapeInfo); - - BUILD_SINGLE_SELECTOR(zType, sd::SpecialMethods, ::decodeBitmapGeneric(dx, N, dz, zShapeInfo), FLOAT_TYPES); - } + void *extraParams, void *hZ, + const Nd4jLong *hZShapeInfo, void *dZ, + const Nd4jLong *dZShapeInfo); + static void execReduceLongScalar(sd::LaunchContext *lc, int opNum, + const void *hX, const Nd4jLong *hXShapeInfo, + const void *dX, const Nd4jLong *dXShapeInfo, + void *extraParams, void *hZ, + const Nd4jLong *hZShapeInfo, void *dZ, + const Nd4jLong *dZShapeInfo); + + static void execReduce3TAD(sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, + const Nd4jLong *dXShapeInfo, void *extraParamsVals, + const void *hY, const Nd4jLong *hYShapeInfo, + const void *dY, const Nd4jLong *dYShapeInfo, + void *hZ, const Nd4jLong *hZShapeInfo, void *dZ, + const Nd4jLong *dZShapeInfo, int *dimension, + int dimensionLength, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets, + const Nd4jLong *yTadShapeInfo, + const Nd4jLong *yTadOffsets); + + /** + * + * @param opNum + * @param x + * @param xShapeInfo + * @param extraParams + * @param result + * @param resultShapeInfoBuffer + * @param dimension + * @param dimensionLength + */ + static void execSummaryStats(sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, + const Nd4jLong *dXShapeInfo, void *extraParams, + void *hZ, const Nd4jLong *hZShapeInfo, void *dZ, + const Nd4jLong *dZShapeInfo, int *dimension, + int dimensionLength, + const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets, bool biasCorrected); + + /** + * + * @param opNum + * @param x + * @param xShapeInfo + * @param extraParams + * @param result + * @param resultShapeInfo + */ + static void execSummaryStats(sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, + const Nd4jLong *dXShapeInfo, void *extraParams, + void *hZ, const Nd4jLong *hZShapeInfo, void *dZ, + const Nd4jLong *dZShapeInfo, bool biasCorrected); + + /** + * + * @param opNum + * @param x + * @param xShapeInfo + * @param extraParams + * @param result + * @param resultShapeInfo + */ + static void execSummaryStatsScalar( + sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + void *extraParams, void *hZ, const Nd4jLong *hZShapeInfo, void *dZ, + const Nd4jLong *dZShapeInfo, bool biasCorrected); + + static void execRandom(sd::LaunchContext *lc, int opNum, Nd4jPointer state, + void *hZ, const Nd4jLong *hZShapeBuffer, void *dZ, + const Nd4jLong *dZShapeBuffer, void *extraArguments); + + static void execRandom(sd::LaunchContext *lc, int opNum, Nd4jPointer state, + const void *hX, const Nd4jLong *hXShapeBuffer, + const void *dX, const Nd4jLong *dXShapeBuffer, + void *hZ, const Nd4jLong *hZShapeBuffer, void *dZ, + const Nd4jLong *dZShapeBuffer, void *extraArguments); + + static void execRandom(sd::LaunchContext *lc, int opNum, Nd4jPointer state, + const void *hX, const Nd4jLong *hXShapeBuffer, + const void *dX, const Nd4jLong *dXShapeBuffer, + const void *hY, const Nd4jLong *hYShapeBuffer, + const void *dY, const Nd4jLong *dYShapeBuffer, + void *hZ, const Nd4jLong *hZShapeBuffer, void *dZ, + const Nd4jLong *dZShapeBuffer, void *extraArguments); + + inline static void execSort(void *x, const Nd4jLong *xShapeInfo, + bool descending) { + auto xType = sd::ArrayOptions::dataType(xShapeInfo); + + BUILD_SINGLE_SELECTOR(xType, sd::SpecialMethods, + ::sortGeneric(x, xShapeInfo, descending), + LIBND4J_TYPES); + } + + static void execSort(void *x, const Nd4jLong *xShapeInfo, int *dimension, + int dimensionLength, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets, bool descending) { + auto xType = sd::ArrayOptions::dataType(xShapeInfo); + + BUILD_SINGLE_SELECTOR( + xType, sd::SpecialMethods, + ::sortTadGeneric(x, xShapeInfo, dimension, dimensionLength, + tadShapeInfo, tadOffsets, descending), + LIBND4J_TYPES); + } + + inline static void execSortCooIndices(Nd4jLong *indices, void *values, + Nd4jLong length, int rank) { + sd::sparse::SparseUtils::sortCooIndicesGeneric( + indices, reinterpret_cast(values), length, rank); + } + + inline static Nd4jLong encodeBitmap(void *dx, const Nd4jLong *xShapeInfo, + Nd4jLong N, int *dz, float threshold) { + auto xType = sd::ArrayOptions::dataType(xShapeInfo); + + BUILD_SINGLE_SELECTOR( + xType, return sd::SpecialMethods, + ::encodeBitmapGeneric(dx, xShapeInfo, N, dz, threshold), FLOAT_TYPES); + } + + inline static void decodeBitmap(const void *dx, Nd4jLong N, void *dz, + const Nd4jLong *zShapeInfo) { + auto zType = sd::ArrayOptions::dataType(zShapeInfo); + + BUILD_SINGLE_SELECTOR(zType, sd::SpecialMethods, + ::decodeBitmapGeneric(dx, N, dz, zShapeInfo), + FLOAT_TYPES); + } }; - -#endif //NATIVEOPERATIONS_NATIVEOPEXCUTIONER_H +#endif // NATIVEOPERATIONS_NATIVEOPEXCUTIONER_H diff --git a/libnd4j/include/legacy/NativeOps.h b/libnd4j/include/legacy/NativeOps.h old mode 100755 new mode 100644 index 29c629b5a729..74f7b759aee4 --- a/libnd4j/include/legacy/NativeOps.h +++ b/libnd4j/include/legacy/NativeOps.h @@ -31,7 +31,7 @@ defined __DMC__ || \ defined __BORLANDC__ ) # define thread_local __declspec(thread) -// note that ICC (linux) and Clang are covered by __GNUC__ +// note that ICC (linux) and Clang are covered by __GNUC__ # elif defined __GNUC__ || \ defined __SUNPRO_C || \ defined __xlC__ @@ -42,17 +42,17 @@ #endif */ +#include #include #include -#include -//DO NOT REMOVE: THIS IS AN EDITOR SEMANTICS THING FOR CLION -//IT DEFINES THE EXPORT MACRO FOR THE EDITOR AND THEN -//RE ADDS THE DEFINITION VIA dll.h -#ifdef _WIN32 -#define ND4J_EXPORT __declspec(dllexport) +// DO NOT REMOVE: THIS IS AN EDITOR SEMANTICS THING FOR CLION +// IT DEFINES THE EXPORT MACRO FOR THE EDITOR AND THEN +// RE ADDS THE DEFINITION VIA dll.h +#ifdef _WIN32 +#define SD_EXPORT __declspec(dllexport) #else -#define ND4J_EXPORT +#define SD_EXPORT #endif #include @@ -64,17 +64,16 @@ bool debug = false; bool verbose = false; */ -#include -#include -#include #include +#include #include -#include +#include #include -#include -#include -#include #include +#include +#include +#include +#include #include #include @@ -86,45 +85,47 @@ extern "C" { * This function returns last error code stored, * @return non-zero if something bad happened */ -ND4J_EXPORT int lastErrorCode(); +SD_EXPORT int lastErrorCode(); /** * This function returns last error message, if last error code > 0 * @return */ -ND4J_EXPORT const char* lastErrorMessage(); +SD_EXPORT const char* lastErrorMessage(); /** * * @param p * @param len */ -ND4J_EXPORT void tryPointer(Nd4jPointer extra, Nd4jPointer p, int len); +SD_EXPORT void tryPointer(Nd4jPointer extra, Nd4jPointer p, int len); /** * * @param num */ -ND4J_EXPORT void setElementThreshold(int num); +SD_EXPORT void setElementThreshold(int num); /** * * @param num */ -ND4J_EXPORT void setTADThreshold(int num); +SD_EXPORT void setTADThreshold(int num); /** - * - * @param opNum - * @param x - * @param xShapeInfo - * @param extraParams - */ -ND4J_EXPORT void execIndexReduceScalar(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo); + * + * @param opNum + * @param x + * @param xShapeInfo + * @param extraParams + */ +SD_EXPORT void execIndexReduceScalar(Nd4jPointer* extraPointers, int opNum, + OpaqueDataBuffer* dbX, + Nd4jLong const* hXShapeInfo, + Nd4jLong const* dXShapeInfo, + void* extraParams, OpaqueDataBuffer* dbZ, + Nd4jLong const* hZShapeInfo, + Nd4jLong const* dZShapeInfo); /** * @@ -137,12 +138,12 @@ ND4J_EXPORT void execIndexReduceScalar(Nd4jPointer *extraPointers, * @param dimension * @param dimensionLength */ -ND4J_EXPORT void execIndexReduce(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo, - OpaqueDataBuffer *dbDimension, Nd4jLong const* hDimensionShape, Nd4jLong const* dDimensionShape); +SD_EXPORT void execIndexReduce( + Nd4jPointer* extraPointers, int opNum, OpaqueDataBuffer* dbX, + Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, void* extraParams, + OpaqueDataBuffer* dbZ, Nd4jLong const* hZShapeInfo, + Nd4jLong const* dZShapeInfo, OpaqueDataBuffer* dbDimension, + Nd4jLong const* hDimensionShape, Nd4jLong const* dDimensionShape); /** * @@ -156,23 +157,25 @@ ND4J_EXPORT void execIndexReduce(Nd4jPointer *extraPointers, * @param dimension * @param dimensionLength */ -ND4J_EXPORT void execBroadcast( - Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - OpaqueDataBuffer *dbY, Nd4jLong const* hYShapeInfo, Nd4jLong const* dYShapeInfo, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo, - OpaqueDataBuffer *dbDimension, Nd4jLong const* hDimensionShape, Nd4jLong const* dDimensionShape); +SD_EXPORT void execBroadcast(Nd4jPointer* extraPointers, int opNum, + OpaqueDataBuffer* dbX, Nd4jLong const* hXShapeInfo, + Nd4jLong const* dXShapeInfo, OpaqueDataBuffer* dbY, + Nd4jLong const* hYShapeInfo, + Nd4jLong const* dYShapeInfo, OpaqueDataBuffer* dbZ, + Nd4jLong const* hZShapeInfo, + Nd4jLong const* dZShapeInfo, + OpaqueDataBuffer* dbDimension, + Nd4jLong const* hDimensionShape, + Nd4jLong const* dDimensionShape); - -ND4J_EXPORT void execBroadcastBool( - Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - OpaqueDataBuffer *dbY, Nd4jLong const* hYShapeInfo, Nd4jLong const* dYShapeInfo, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbDimension, Nd4jLong const* hDimensionShape, Nd4jLong const* dDimensionShape); +SD_EXPORT void execBroadcastBool( + Nd4jPointer* extraPointers, int opNum, OpaqueDataBuffer* dbX, + Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, + OpaqueDataBuffer* dbY, Nd4jLong const* hYShapeInfo, + Nd4jLong const* dYShapeInfo, OpaqueDataBuffer* dbZ, + Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo, void* extraParams, + OpaqueDataBuffer* dbDimension, Nd4jLong const* hDimensionShape, + Nd4jLong const* dDimensionShape); /** * @@ -186,21 +189,21 @@ ND4J_EXPORT void execBroadcastBool( * @param extraParams * @param n */ -ND4J_EXPORT void execPairwiseTransform( - Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - OpaqueDataBuffer *dbY, Nd4jLong const* hYShapeInfo, Nd4jLong const* dYShapeInfo, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo, - void *extraParams); +SD_EXPORT void execPairwiseTransform( + Nd4jPointer* extraPointers, int opNum, OpaqueDataBuffer* dbX, + Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, + OpaqueDataBuffer* dbY, Nd4jLong const* hYShapeInfo, + Nd4jLong const* dYShapeInfo, OpaqueDataBuffer* dbZ, + Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo, + void* extraParams); -ND4J_EXPORT void execPairwiseTransformBool( - Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - OpaqueDataBuffer *dbY, Nd4jLong const* hYShapeInfo, Nd4jLong const* dYShapeInfo, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo, - void *extraParams); +SD_EXPORT void execPairwiseTransformBool( + Nd4jPointer* extraPointers, int opNum, OpaqueDataBuffer* dbX, + Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, + OpaqueDataBuffer* dbY, Nd4jLong const* hYShapeInfo, + Nd4jLong const* dYShapeInfo, OpaqueDataBuffer* dbZ, + Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo, + void* extraParams); /** * @@ -211,30 +214,37 @@ ND4J_EXPORT void execPairwiseTransformBool( * @param result * @param resultShapeInfo */ -ND4J_EXPORT void execReduceFloat(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo); - -ND4J_EXPORT void execReduceSame(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo); - -ND4J_EXPORT void execReduceBool(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo); - - -ND4J_EXPORT void execReduceLong(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo); +SD_EXPORT void execReduceFloat(Nd4jPointer* extraPointers, int opNum, + OpaqueDataBuffer* dbX, + Nd4jLong const* hXShapeInfo, + Nd4jLong const* dXShapeInfo, void* extraParams, + OpaqueDataBuffer* dbZ, + Nd4jLong const* hZShapeInfo, + Nd4jLong const* dZShapeInfo); + +SD_EXPORT void execReduceSame(Nd4jPointer* extraPointers, int opNum, + OpaqueDataBuffer* dbX, + Nd4jLong const* hXShapeInfo, + Nd4jLong const* dXShapeInfo, void* extraParams, + OpaqueDataBuffer* dbZ, + Nd4jLong const* hZShapeInfo, + Nd4jLong const* dZShapeInfo); + +SD_EXPORT void execReduceBool(Nd4jPointer* extraPointers, int opNum, + OpaqueDataBuffer* dbX, + Nd4jLong const* hXShapeInfo, + Nd4jLong const* dXShapeInfo, void* extraParams, + OpaqueDataBuffer* dbZ, + Nd4jLong const* hZShapeInfo, + Nd4jLong const* dZShapeInfo); + +SD_EXPORT void execReduceLong(Nd4jPointer* extraPointers, int opNum, + OpaqueDataBuffer* dbX, + Nd4jLong const* hXShapeInfo, + Nd4jLong const* dXShapeInfo, void* extraParams, + OpaqueDataBuffer* dbZ, + Nd4jLong const* hZShapeInfo, + Nd4jLong const* dZShapeInfo); /** * @@ -245,36 +255,33 @@ ND4J_EXPORT void execReduceLong(Nd4jPointer *extraPointers, * @param result * @param resultShapeInfo */ -ND4J_EXPORT void execReduceFloat2(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo, - OpaqueDataBuffer *dbDimension, Nd4jLong const* hDimensionShape, Nd4jLong const* dDimensionShape); - +SD_EXPORT void execReduceFloat2( + Nd4jPointer* extraPointers, int opNum, OpaqueDataBuffer* dbX, + Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, void* extraParams, + OpaqueDataBuffer* dbZ, Nd4jLong const* hZShapeInfo, + Nd4jLong const* dZShapeInfo, OpaqueDataBuffer* dbDimension, + Nd4jLong const* hDimensionShape, Nd4jLong const* dDimensionShape); -ND4J_EXPORT void execReduceSame2(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo, - OpaqueDataBuffer *dbDimension, Nd4jLong const* hDimensionShape, Nd4jLong const* dDimensionShape); +SD_EXPORT void execReduceSame2( + Nd4jPointer* extraPointers, int opNum, OpaqueDataBuffer* dbX, + Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, void* extraParams, + OpaqueDataBuffer* dbZ, Nd4jLong const* hZShapeInfo, + Nd4jLong const* dZShapeInfo, OpaqueDataBuffer* dbDimension, + Nd4jLong const* hDimensionShape, Nd4jLong const* dDimensionShape); +SD_EXPORT void execReduceBool2( + Nd4jPointer* extraPointers, int opNum, OpaqueDataBuffer* dbX, + Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, void* extraParams, + OpaqueDataBuffer* dbZ, Nd4jLong const* hZShapeInfo, + Nd4jLong const* dZShapeInfo, OpaqueDataBuffer* dbDimension, + Nd4jLong const* hDimensionShape, Nd4jLong const* dDimensionShape); -ND4J_EXPORT void execReduceBool2(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo, - OpaqueDataBuffer *dbDimension, Nd4jLong const* hDimensionShape, Nd4jLong const* dDimensionShape); - - -ND4J_EXPORT void execReduceLong2(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo, - OpaqueDataBuffer *dbDimension, Nd4jLong const* hDimensionShape, Nd4jLong const* dDimensionShape); +SD_EXPORT void execReduceLong2( + Nd4jPointer* extraPointers, int opNum, OpaqueDataBuffer* dbX, + Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, void* extraParams, + OpaqueDataBuffer* dbZ, Nd4jLong const* hZShapeInfo, + Nd4jLong const* dZShapeInfo, OpaqueDataBuffer* dbDimension, + Nd4jLong const* hDimensionShape, Nd4jLong const* dDimensionShape); /** * @@ -287,12 +294,13 @@ ND4J_EXPORT void execReduceLong2(Nd4jPointer *extraPointers, * @param result * @param resultShapeInfo */ -ND4J_EXPORT void execReduce3(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - void *extraParamsVals, - OpaqueDataBuffer *dbY, Nd4jLong const* hYShapeInfo, Nd4jLong const* dYShapeInfo, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo); +SD_EXPORT void execReduce3(Nd4jPointer* extraPointers, int opNum, + OpaqueDataBuffer* dbX, Nd4jLong const* hXShapeInfo, + Nd4jLong const* dXShapeInfo, void* extraParamsVals, + OpaqueDataBuffer* dbY, Nd4jLong const* hYShapeInfo, + Nd4jLong const* dYShapeInfo, OpaqueDataBuffer* dbZ, + Nd4jLong const* hZShapeInfo, + Nd4jLong const* dZShapeInfo); /** * @@ -303,12 +311,12 @@ ND4J_EXPORT void execReduce3(Nd4jPointer *extraPointers, * @param y * @param yShapeInfo */ -ND4J_EXPORT void execReduce3Scalar(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - void *extraParamsVals, - OpaqueDataBuffer *dbY, Nd4jLong const* hYShapeInfo, Nd4jLong const* dYShapeInfo, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo); +SD_EXPORT void execReduce3Scalar( + Nd4jPointer* extraPointers, int opNum, OpaqueDataBuffer* dbX, + Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, + void* extraParamsVals, OpaqueDataBuffer* dbY, Nd4jLong const* hYShapeInfo, + Nd4jLong const* dYShapeInfo, OpaqueDataBuffer* dbZ, + Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo); /** * * @param opNum @@ -322,26 +330,27 @@ ND4J_EXPORT void execReduce3Scalar(Nd4jPointer *extraPointers, * @param dimension * @param dimensionLength */ -ND4J_EXPORT void execReduce3Tad(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - void *extraParamsVals, - OpaqueDataBuffer *dbY, Nd4jLong const* hYShapeInfo, Nd4jLong const* dYShapeInfo, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo, - OpaqueDataBuffer *dbDimension, Nd4jLong const* hDimensionShape, Nd4jLong const* dDimensionShape, - Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, - Nd4jLong const* yTadOnlyShapeInfo, Nd4jLong const* yTadOffsets); - - -ND4J_EXPORT void execReduce3All(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - void *extraParamsVals, - OpaqueDataBuffer *dbY, Nd4jLong const* hYShapeInfo, Nd4jLong const* dYShapeInfo, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo, - OpaqueDataBuffer *dbDimension, Nd4jLong const* hDimensionShape, Nd4jLong const* dDimensionShape, - Nd4jLong const* xTadShapeInfo, Nd4jLong const* xOffsets, - Nd4jLong const* yTadShapeInfo, Nd4jLong const* yOffsets); +SD_EXPORT void execReduce3Tad( + Nd4jPointer* extraPointers, int opNum, OpaqueDataBuffer* dbX, + Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, + void* extraParamsVals, OpaqueDataBuffer* dbY, Nd4jLong const* hYShapeInfo, + Nd4jLong const* dYShapeInfo, OpaqueDataBuffer* dbZ, + Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo, + OpaqueDataBuffer* dbDimension, Nd4jLong const* hDimensionShape, + Nd4jLong const* dDimensionShape, Nd4jLong const* tadOnlyShapeInfo, + Nd4jLong const* tadOffsets, Nd4jLong const* yTadOnlyShapeInfo, + Nd4jLong const* yTadOffsets); + +SD_EXPORT void execReduce3All( + Nd4jPointer* extraPointers, int opNum, OpaqueDataBuffer* dbX, + Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, + void* extraParamsVals, OpaqueDataBuffer* dbY, Nd4jLong const* hYShapeInfo, + Nd4jLong const* dYShapeInfo, OpaqueDataBuffer* dbZ, + Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo, + OpaqueDataBuffer* dbDimension, Nd4jLong const* hDimensionShape, + Nd4jLong const* dDimensionShape, Nd4jLong const* xTadShapeInfo, + Nd4jLong const* xOffsets, Nd4jLong const* yTadShapeInfo, + Nd4jLong const* yOffsets); /** * @@ -354,19 +363,22 @@ ND4J_EXPORT void execReduce3All(Nd4jPointer *extraPointers, * @param extraParams * @param n */ -ND4J_EXPORT void execScalar(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo, - OpaqueDataBuffer *dbScalar, Nd4jLong const* hSscalarShapeInfo, Nd4jLong const* dSscalarShapeInfo, - void *extraParams); +SD_EXPORT void execScalar(Nd4jPointer* extraPointers, int opNum, + OpaqueDataBuffer* dbX, Nd4jLong const* hXShapeInfo, + Nd4jLong const* dXShapeInfo, OpaqueDataBuffer* dbZ, + Nd4jLong const* hZShapeInfo, + Nd4jLong const* dZShapeInfo, + OpaqueDataBuffer* dbScalar, + Nd4jLong const* hSscalarShapeInfo, + Nd4jLong const* dSscalarShapeInfo, void* extraParams); -ND4J_EXPORT void execScalarBool(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo, - OpaqueDataBuffer *dbScalar, Nd4jLong const* hSscalarShapeInfo, Nd4jLong const* dSscalarShapeInfo, - void *extraParams); +SD_EXPORT void execScalarBool( + Nd4jPointer* extraPointers, int opNum, OpaqueDataBuffer* dbX, + Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, + OpaqueDataBuffer* dbZ, Nd4jLong const* hZShapeInfo, + Nd4jLong const* dZShapeInfo, OpaqueDataBuffer* dbScalar, + Nd4jLong const* hSscalarShapeInfo, Nd4jLong const* dSscalarShapeInfo, + void* extraParams); /** * @@ -375,12 +387,11 @@ ND4J_EXPORT void execScalarBool(Nd4jPointer *extraPointers, * @param xShapeInfo * @param extraParams */ -ND4J_EXPORT void execSummaryStatsScalar(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo, - bool biasCorrected); +SD_EXPORT void execSummaryStatsScalar( + Nd4jPointer* extraPointers, int opNum, OpaqueDataBuffer* dbX, + Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, void* extraParams, + OpaqueDataBuffer* dbZ, Nd4jLong const* hZShapeInfo, + Nd4jLong const* dZShapeInfo, bool biasCorrected); /** * * @param opNum @@ -390,12 +401,11 @@ ND4J_EXPORT void execSummaryStatsScalar(Nd4jPointer *extraPointers, * @param result * @param resultShapeInfo */ -ND4J_EXPORT void execSummaryStats(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo, - bool biasCorrected); +SD_EXPORT void execSummaryStats( + Nd4jPointer* extraPointers, int opNum, OpaqueDataBuffer* dbX, + Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, void* extraParams, + OpaqueDataBuffer* dbZ, Nd4jLong const* hZShapeInfo, + Nd4jLong const* dZShapeInfo, bool biasCorrected); /** * * @param opNum @@ -407,14 +417,14 @@ ND4J_EXPORT void execSummaryStats(Nd4jPointer *extraPointers, * @param dimension * @param dimensionLength */ -ND4J_EXPORT void execSummaryStatsTad(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo, - OpaqueDataBuffer *dbDimension, Nd4jLong const* hDimensionShape, Nd4jLong const* dDimensionShape, - bool biasCorrected, - Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets); +SD_EXPORT void execSummaryStatsTad( + Nd4jPointer* extraPointers, int opNum, OpaqueDataBuffer* dbX, + Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, void* extraParams, + OpaqueDataBuffer* dbZ, Nd4jLong const* hZShapeInfo, + Nd4jLong const* dZShapeInfo, OpaqueDataBuffer* dbDimension, + Nd4jLong const* hDimensionShape, Nd4jLong const* dDimensionShape, + bool biasCorrected, Nd4jLong const* tadShapeInfo, + Nd4jLong const* tadOffsets); /** * @@ -426,35 +436,37 @@ ND4J_EXPORT void execSummaryStatsTad(Nd4jPointer *extraPointers, * @param extraParams * @param n */ -ND4J_EXPORT void execTransformFloat(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo, - void *extraParams); +SD_EXPORT void execTransformFloat( + Nd4jPointer* extraPointers, int opNum, OpaqueDataBuffer* dbX, + Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, + OpaqueDataBuffer* dbZ, Nd4jLong const* hZShapeInfo, + Nd4jLong const* dZShapeInfo, void* extraParams); -ND4J_EXPORT void execTransformSame(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo, - void *extraParams); +SD_EXPORT void execTransformSame( + Nd4jPointer* extraPointers, int opNum, OpaqueDataBuffer* dbX, + Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, + OpaqueDataBuffer* dbZ, Nd4jLong const* hZShapeInfo, + Nd4jLong const* dZShapeInfo, void* extraParams); -ND4J_EXPORT void execTransformBool(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo, - void *extraParams); +SD_EXPORT void execTransformBool( + Nd4jPointer* extraPointers, int opNum, OpaqueDataBuffer* dbX, + Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, + OpaqueDataBuffer* dbZ, Nd4jLong const* hZShapeInfo, + Nd4jLong const* dZShapeInfo, void* extraParams); -ND4J_EXPORT void execTransformAny(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo, - void *extraParams); +SD_EXPORT void execTransformAny(Nd4jPointer* extraPointers, int opNum, + OpaqueDataBuffer* dbX, + Nd4jLong const* hXShapeInfo, + Nd4jLong const* dXShapeInfo, + OpaqueDataBuffer* dbZ, + Nd4jLong const* hZShapeInfo, + Nd4jLong const* dZShapeInfo, void* extraParams); -ND4J_EXPORT void execTransformStrict(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo, - void *extraParams); +SD_EXPORT void execTransformStrict( + Nd4jPointer* extraPointers, int opNum, OpaqueDataBuffer* dbX, + Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, + OpaqueDataBuffer* dbZ, Nd4jLong const* hZShapeInfo, + Nd4jLong const* dZShapeInfo, void* extraParams); /** * @@ -469,44 +481,42 @@ ND4J_EXPORT void execTransformStrict(Nd4jPointer *extraPointers, * @param dimension * @param dimensionLength */ -ND4J_EXPORT void execScalarTad(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo, - OpaqueDataBuffer *dbScalars, Nd4jLong const* hScalarShapeInfo, Nd4jLong const* dScalarShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbDimension, Nd4jLong const* hDimensionShape, Nd4jLong const* dDimensionShape, - Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, - Nd4jLong const* tadShapeInfoZ, Nd4jLong const* tadOffsetsZ); - -ND4J_EXPORT void execScalarBoolTad(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo, - OpaqueDataBuffer *dbScalars, Nd4jLong const* hScalarShapeInfo, Nd4jLong const* dScalarShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbDimension, Nd4jLong const* hDimensionShape, Nd4jLong const* dDimensionShape, - Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, - Nd4jLong const* tadShapeInfoZ, Nd4jLong const* tadOffsetsZ); - -ND4J_EXPORT void specialConcat ( - Nd4jPointer *extraPointers, - int dimension, - int numArrays, - Nd4jPointer *data, - Nd4jPointer *inputShapeInfo, - void *result, - Nd4jLong const* resultShapeInfo, - Nd4jPointer *tadPointers, - Nd4jPointer *offsetPointers); +SD_EXPORT void execScalarTad( + Nd4jPointer* extraPointers, int opNum, OpaqueDataBuffer* dbX, + Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, + OpaqueDataBuffer* dbZ, Nd4jLong const* hZShapeInfo, + Nd4jLong const* dZShapeInfo, OpaqueDataBuffer* dbScalars, + Nd4jLong const* hScalarShapeInfo, Nd4jLong const* dScalarShapeInfo, + void* extraParams, OpaqueDataBuffer* dbDimension, + Nd4jLong const* hDimensionShape, Nd4jLong const* dDimensionShape, + Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, + Nd4jLong const* tadShapeInfoZ, Nd4jLong const* tadOffsetsZ); + +SD_EXPORT void execScalarBoolTad( + Nd4jPointer* extraPointers, int opNum, OpaqueDataBuffer* dbX, + Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, + OpaqueDataBuffer* dbZ, Nd4jLong const* hZShapeInfo, + Nd4jLong const* dZShapeInfo, OpaqueDataBuffer* dbScalars, + Nd4jLong const* hScalarShapeInfo, Nd4jLong const* dScalarShapeInfo, + void* extraParams, OpaqueDataBuffer* dbDimension, + Nd4jLong const* hDimensionShape, Nd4jLong const* dDimensionShape, + Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, + Nd4jLong const* tadShapeInfoZ, Nd4jLong const* tadOffsetsZ); + +SD_EXPORT void specialConcat(Nd4jPointer* extraPointers, int dimension, + int numArrays, Nd4jPointer* data, + Nd4jPointer* inputShapeInfo, void* result, + Nd4jLong const* resultShapeInfo, + Nd4jPointer* tadPointers, + Nd4jPointer* offsetPointers); /** * This method implementation exists only for cuda. * The other backends should have dummy method for JNI compatibility reasons. */ -ND4J_EXPORT void initializeDevicesAndFunctions(); +SD_EXPORT void initializeDevicesAndFunctions(); -ND4J_EXPORT void initializeFunctions(Nd4jPointer *functions); +SD_EXPORT void initializeFunctions(Nd4jPointer* functions); /** * This method acquires memory chunk of requested size on host side @@ -515,24 +525,26 @@ ND4J_EXPORT void initializeFunctions(Nd4jPointer *functions); * @param memorySize memory size, in bytes * @param flags optional parameter */ -ND4J_EXPORT Nd4jPointer mallocHost(Nd4jLong memorySize, int flags); +SD_EXPORT Nd4jPointer mallocHost(Nd4jLong memorySize, int flags); /** * This method acquires memory chunk of requested size on specified device * * @param pointer pointer that'll be used for allocation * @param memorySize memory size, in bytes - * @param ptrToDeviceId pointer to deviceId. For cuda that's just and int, for OpenCL that's pointer to device_id, etc + * @param ptrToDeviceId pointer to deviceId. For cuda that's just and int, for + * OpenCL that's pointer to device_id, etc * @param flags optional parameter */ -ND4J_EXPORT Nd4jPointer mallocDevice(Nd4jLong memorySize, int deviceId, int flags); +SD_EXPORT Nd4jPointer mallocDevice(Nd4jLong memorySize, int deviceId, + int flags); /** * This method releases previously allocated host memory space * * @param pointer pointer that'll be freed */ -ND4J_EXPORT int freeHost(Nd4jPointer pointer); +SD_EXPORT int freeHost(Nd4jPointer pointer); /** * This method releases previously allocated memory space on device @@ -540,52 +552,51 @@ ND4J_EXPORT int freeHost(Nd4jPointer pointer); * @param pointer pointer that'll be freed * @param ptrToDeviceId pointer to deviceId. */ -ND4J_EXPORT int freeDevice(Nd4jPointer pointer, int deviceId); +SD_EXPORT int freeDevice(Nd4jPointer pointer, int deviceId); /** * * @return */ -ND4J_EXPORT int ompGetMaxThreads(); +SD_EXPORT int ompGetMaxThreads(); /** * * @return */ -ND4J_EXPORT int ompGetNumThreads(); +SD_EXPORT int ompGetNumThreads(); /** * * @param threads */ -ND4J_EXPORT void setOmpNumThreads(int threads); +SD_EXPORT void setOmpNumThreads(int threads); /** * * @param threads */ -ND4J_EXPORT void setOmpMinThreads(int threads); +SD_EXPORT void setOmpMinThreads(int threads); - -ND4J_EXPORT bool isBlasVersionMatches(int major, int minor, int build); +SD_EXPORT bool isBlasVersionMatches(int major, int minor, int build); /** * * @return */ -ND4J_EXPORT Nd4jPointer createContext(); +SD_EXPORT Nd4jPointer createContext(); /** * * @return */ -ND4J_EXPORT Nd4jPointer createStream(); +SD_EXPORT Nd4jPointer createStream(); /** * * @return */ -ND4J_EXPORT Nd4jPointer createEvent(); +SD_EXPORT Nd4jPointer createEvent(); /** * @@ -593,89 +604,89 @@ ND4J_EXPORT Nd4jPointer createEvent(); * @param stream * @return */ -ND4J_EXPORT int registerEvent(Nd4jPointer event, Nd4jPointer stream); +SD_EXPORT int registerEvent(Nd4jPointer event, Nd4jPointer stream); /** * * @param event * @return */ -ND4J_EXPORT int destroyEvent(Nd4jPointer event); +SD_EXPORT int destroyEvent(Nd4jPointer event); /** * * @param ptrToDeviceId * @return */ -ND4J_EXPORT int setDevice(int deviceId); +SD_EXPORT int setDevice(int deviceId); /** * * @return */ -ND4J_EXPORT int getDevice(); +SD_EXPORT int getDevice(); /** * * @param stream * @return */ -ND4J_EXPORT int streamSynchronize(Nd4jPointer stream); +SD_EXPORT int streamSynchronize(Nd4jPointer stream); /** * * @param event * @return */ -ND4J_EXPORT int eventSynchronize(Nd4jPointer event); +SD_EXPORT int eventSynchronize(Nd4jPointer event); /** * * @param ptrToDeviceId * @return */ -ND4J_EXPORT Nd4jLong getDeviceFreeMemory(int deviceId); +SD_EXPORT Nd4jLong getDeviceFreeMemory(int deviceId); /** * Returns amount of free memory for current device * @return */ -ND4J_EXPORT Nd4jLong getDeviceFreeMemoryDefault(); +SD_EXPORT Nd4jLong getDeviceFreeMemoryDefault(); /** * * @param ptrToDeviceId * @return */ -ND4J_EXPORT Nd4jLong getDeviceTotalMemory(int deviceId); +SD_EXPORT Nd4jLong getDeviceTotalMemory(int deviceId); /** * * @param ptrToDeviceId * @return */ -ND4J_EXPORT int getDeviceMajor(int deviceId); +SD_EXPORT int getDeviceMajor(int deviceId); /** * This method returns amount of cached memory * @param deviceId * @return */ -ND4J_EXPORT Nd4jLong getCachedMemory(int deviceId); +SD_EXPORT Nd4jLong getCachedMemory(int deviceId); /** * * @param ptrToDeviceId * @return */ -ND4J_EXPORT int getDeviceMinor(int deviceId); +SD_EXPORT int getDeviceMinor(int deviceId); /** * * @param ptrToDeviceId * @return */ -ND4J_EXPORT const char * getDeviceName(int deviceId); +SD_EXPORT const char* getDeviceName(int deviceId); /** * @@ -686,11 +697,8 @@ ND4J_EXPORT const char * getDeviceName(int deviceId); * @param reserved * @return */ -ND4J_EXPORT int memcpySync(Nd4jPointer dst, - Nd4jPointer src, - Nd4jLong size, - int flags, - Nd4jPointer reserved); +SD_EXPORT int memcpySync(Nd4jPointer dst, Nd4jPointer src, Nd4jLong size, + int flags, Nd4jPointer reserved); /** * @@ -701,11 +709,8 @@ ND4J_EXPORT int memcpySync(Nd4jPointer dst, * @param reserved * @return */ -ND4J_EXPORT int memcpyAsync(Nd4jPointer dst, - Nd4jPointer src, - Nd4jLong size, - int flags, - Nd4jPointer reserved); +SD_EXPORT int memcpyAsync(Nd4jPointer dst, Nd4jPointer src, Nd4jLong size, + int flags, Nd4jPointer reserved); /** * @@ -716,11 +721,8 @@ ND4J_EXPORT int memcpyAsync(Nd4jPointer dst, * @param reserved * @return */ -ND4J_EXPORT int memsetSync(Nd4jPointer dst, - int value, - Nd4jLong size, - int flags, - Nd4jPointer reserved); +SD_EXPORT int memsetSync(Nd4jPointer dst, int value, Nd4jLong size, int flags, + Nd4jPointer reserved); /** * @@ -731,11 +733,8 @@ ND4J_EXPORT int memsetSync(Nd4jPointer dst, * @param reserved * @return */ -ND4J_EXPORT int memsetAsync(Nd4jPointer dst, - int value, - Nd4jLong size, - int flags, - Nd4jPointer reserved); +SD_EXPORT int memsetAsync(Nd4jPointer dst, int value, Nd4jLong size, int flags, + Nd4jPointer reserved); /** * @@ -746,41 +745,38 @@ ND4J_EXPORT int memsetAsync(Nd4jPointer dst, * @param reserved * @return */ -ND4J_EXPORT int memcpyConstantAsync(Nd4jLong dst, - Nd4jPointer src, - Nd4jLong size, - int flags, - Nd4jPointer reserved); +SD_EXPORT int memcpyConstantAsync(Nd4jLong dst, Nd4jPointer src, Nd4jLong size, + int flags, Nd4jPointer reserved); /** * * @return */ -ND4J_EXPORT Nd4jPointer getConstantSpace(); +SD_EXPORT Nd4jPointer getConstantSpace(); /** * * @return */ -ND4J_EXPORT int getAvailableDevices(); +SD_EXPORT int getAvailableDevices(); /** * * @param reallyEnable */ -ND4J_EXPORT void enableDebugMode(bool reallyEnable); +SD_EXPORT void enableDebugMode(bool reallyEnable); /** * * @param reallyEnable */ -ND4J_EXPORT void enableVerboseMode(bool reallyEnable); +SD_EXPORT void enableVerboseMode(bool reallyEnable); /** * * @param gridSize */ -ND4J_EXPORT void setGridLimit(int gridSize); +SD_EXPORT void setGridLimit(int gridSize); typedef sd::TadPack OpaqueTadPack; @@ -792,18 +788,17 @@ typedef sd::TadPack OpaqueTadPack; * @param targetBuffer * @param offsetsBuffer */ -ND4J_EXPORT OpaqueTadPack* tadOnlyShapeInfo(Nd4jLong const*xShapeInfo, - int *dimension, - int dimensionLength); +SD_EXPORT OpaqueTadPack* tadOnlyShapeInfo(Nd4jLong const* xShapeInfo, + int* dimension, int dimensionLength); -ND4J_EXPORT Nd4jLong const* getPrimaryShapeInfo(OpaqueTadPack* pack); -ND4J_EXPORT Nd4jLong const* getPrimaryOffsets(OpaqueTadPack* pack); -ND4J_EXPORT Nd4jLong const* getSpecialShapeInfo(OpaqueTadPack* pack); -ND4J_EXPORT Nd4jLong const* getSpecialOffsets(OpaqueTadPack* pack); -ND4J_EXPORT Nd4jLong getNumberOfTads(OpaqueTadPack* pack); -ND4J_EXPORT int getShapeInfoLength(OpaqueTadPack* pack); +SD_EXPORT Nd4jLong const* getPrimaryShapeInfo(OpaqueTadPack* pack); +SD_EXPORT Nd4jLong const* getPrimaryOffsets(OpaqueTadPack* pack); +SD_EXPORT Nd4jLong const* getSpecialShapeInfo(OpaqueTadPack* pack); +SD_EXPORT Nd4jLong const* getSpecialOffsets(OpaqueTadPack* pack); +SD_EXPORT Nd4jLong getNumberOfTads(OpaqueTadPack* pack); +SD_EXPORT int getShapeInfoLength(OpaqueTadPack* pack); -ND4J_EXPORT void deleteTadPack(OpaqueTadPack* ptr); +SD_EXPORT void deleteTadPack(OpaqueTadPack* ptr); /* * PullRow special op @@ -823,15 +818,14 @@ ND4J_EXPORT void deleteTadPack(OpaqueTadPack* ptr); * @param zTadShapeInfo * @param zTadOffsets */ -ND4J_EXPORT void pullRows(Nd4jPointer *extraPointers, - OpaqueDataBuffer *dbX, Nd4jLong const* xShapeInfo, Nd4jLong const* dxShapeInfo, - OpaqueDataBuffer *dbZ, Nd4jLong const* zShapeInfo, Nd4jLong const* dzShapeInfo, - Nd4jLong n, - Nd4jLong *indexes, - Nd4jLong const* tadShapeInfo, - Nd4jLong const* tadOffsets, - Nd4jLong const* zTadShapeInfo, - Nd4jLong const* zTadOffsets); +SD_EXPORT void pullRows(Nd4jPointer* extraPointers, OpaqueDataBuffer* dbX, + Nd4jLong const* xShapeInfo, Nd4jLong const* dxShapeInfo, + OpaqueDataBuffer* dbZ, Nd4jLong const* zShapeInfo, + Nd4jLong const* dzShapeInfo, Nd4jLong n, + Nd4jLong* indexes, Nd4jLong const* tadShapeInfo, + Nd4jLong const* tadOffsets, + Nd4jLong const* zTadShapeInfo, + Nd4jLong const* zTadOffsets); /** * @@ -842,24 +836,18 @@ ND4J_EXPORT void pullRows(Nd4jPointer *extraPointers, * @param length * @param propagate */ -ND4J_EXPORT void average(Nd4jPointer *extras, - Nd4jPointer *x, Nd4jLong const* xShapeInfo, - Nd4jPointer *dx, Nd4jLong const* dxShapeInfo, - void *z, Nd4jLong const* zShapeInfo, - void *dz, Nd4jLong const* dzShapeInfo, - int n, - Nd4jLong length, - bool propagate); - - -ND4J_EXPORT void accumulate(Nd4jPointer *extras, - Nd4jPointer *x, Nd4jLong const* xShapeInfo, - Nd4jPointer *dx, Nd4jLong const* dxShapeInfo, - void *z, Nd4jLong const* zShapeInfo, - void *dz, Nd4jLong const* dzShapeInfo, - int n, - Nd4jLong length); +SD_EXPORT void average(Nd4jPointer* extras, Nd4jPointer* x, + Nd4jLong const* xShapeInfo, Nd4jPointer* dx, + Nd4jLong const* dxShapeInfo, void* z, + Nd4jLong const* zShapeInfo, void* dz, + Nd4jLong const* dzShapeInfo, int n, Nd4jLong length, + bool propagate); +SD_EXPORT void accumulate(Nd4jPointer* extras, Nd4jPointer* x, + Nd4jLong const* xShapeInfo, Nd4jPointer* dx, + Nd4jLong const* dxShapeInfo, void* z, + Nd4jLong const* zShapeInfo, void* dz, + Nd4jLong const* dzShapeInfo, int n, Nd4jLong length); /** * P2P enabler @@ -868,18 +856,18 @@ ND4J_EXPORT void accumulate(Nd4jPointer *extras, * * @param enable */ -ND4J_EXPORT void enableP2P(bool enable); +SD_EXPORT void enableP2P(bool enable); /** * */ -ND4J_EXPORT void checkP2P(); +SD_EXPORT void checkP2P(); /** * * @return */ -ND4J_EXPORT bool isP2PAvailable(); +SD_EXPORT bool isP2PAvailable(); /** * Shuffle methods @@ -897,16 +885,12 @@ ND4J_EXPORT bool isP2PAvailable(); * @param tadShapeInfo * @param tadOffsets */ -ND4J_EXPORT void shuffle(Nd4jPointer *extras, - Nd4jPointer *x, Nd4jPointer *xShapeInfo, - Nd4jPointer *dx, Nd4jPointer *dxShapeInfo, - Nd4jPointer *z, Nd4jPointer *zShapeInfo, - Nd4jPointer *dz, Nd4jPointer *dzShapeInfo, - int N, - int *shuffleMap, - Nd4jPointer *tadShapeInfo, - Nd4jPointer *tadOffsets); - +SD_EXPORT void shuffle(Nd4jPointer* extras, Nd4jPointer* x, + Nd4jPointer* xShapeInfo, Nd4jPointer* dx, + Nd4jPointer* dxShapeInfo, Nd4jPointer* z, + Nd4jPointer* zShapeInfo, Nd4jPointer* dz, + Nd4jPointer* dzShapeInfo, int N, int* shuffleMap, + Nd4jPointer* tadShapeInfo, Nd4jPointer* tadOffsets); /** * Type Conversions @@ -921,14 +905,14 @@ ND4J_EXPORT void shuffle(Nd4jPointer *extras, * @param dstType * @param z */ -ND4J_EXPORT void convertTypes(Nd4jPointer *extras, int srcType, Nd4jPointer x, Nd4jLong N, int dstType, Nd4jPointer z); - +SD_EXPORT void convertTypes(Nd4jPointer* extras, int srcType, Nd4jPointer x, + Nd4jLong N, int dstType, Nd4jPointer z); /** * * @return */ -ND4J_EXPORT bool isExperimentalEnabled(); +SD_EXPORT bool isExperimentalEnabled(); /** * Aggregate @@ -949,44 +933,25 @@ ND4J_EXPORT bool isExperimentalEnabled(); * @param realArguments * @param numRealArguments */ -ND4J_EXPORT void execAggregate(Nd4jPointer *extraPointers, - int opNum, - void **arguments, - int numArguments, - Nd4jLong **shapeArguments, - int numShapeArguments, - int *indexArguments, - int numIndexArguments, - int **intArrays, - int numIntArrays, - void *realArguments, - int numRealArguments, - sd::DataType dtype); - - -ND4J_EXPORT void batchExecutor(Nd4jPointer *extraPointers, - int numAggregates, - int opNum, - int maxArgs, - int maxShapes, - int maxIntArrays, - int maxIntArraySize, - int maxIdx, - int maxReals, - void *ptrToArguments, - sd::DataType dtype); - -ND4J_EXPORT void execAggregateBatch(Nd4jPointer *extraPointers, - int numAggregates, - int opNum, - int maxArgs, - int maxShapes, - int maxIntArrays, - int maxIntArraySize, - int maxIdx, - int maxReals, - void *ptrToArguments, - sd::DataType dtype); +SD_EXPORT void execAggregate(Nd4jPointer* extraPointers, int opNum, + void** arguments, int numArguments, + Nd4jLong** shapeArguments, int numShapeArguments, + int* indexArguments, int numIndexArguments, + int** intArrays, int numIntArrays, + void* realArguments, int numRealArguments, + sd::DataType dtype); + +SD_EXPORT void batchExecutor(Nd4jPointer* extraPointers, int numAggregates, + int opNum, int maxArgs, int maxShapes, + int maxIntArrays, int maxIntArraySize, int maxIdx, + int maxReals, void* ptrToArguments, + sd::DataType dtype); + +SD_EXPORT void execAggregateBatch(Nd4jPointer* extraPointers, int numAggregates, + int opNum, int maxArgs, int maxShapes, + int maxIntArrays, int maxIntArraySize, + int maxIdx, int maxReals, + void* ptrToArguments, sd::DataType dtype); /** * Random operations @@ -1001,11 +966,10 @@ ND4J_EXPORT void execAggregateBatch(Nd4jPointer *extraPointers, * @param zShapeBuffer * @param extraArguments */ -ND4J_EXPORT void execRandom(Nd4jPointer *extraPointers, - int opNum, - Nd4jPointer state, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeBuffer, Nd4jLong const* dZShapeBuffer, - void *extraArguments); +SD_EXPORT void execRandom(Nd4jPointer* extraPointers, int opNum, + Nd4jPointer state, OpaqueDataBuffer* dbZ, + Nd4jLong const* hZShapeBuffer, + Nd4jLong const* dZShapeBuffer, void* extraArguments); /** * @@ -1020,13 +984,14 @@ ND4J_EXPORT void execRandom(Nd4jPointer *extraPointers, * @param zShapeBuffer * @param extraArguments */ -ND4J_EXPORT void execRandom3(Nd4jPointer *extraPointers, - int opNum, - Nd4jPointer state, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeBuffer, Nd4jLong const* dXShapeBuffer, - OpaqueDataBuffer *dbY, Nd4jLong const* hYShapeBuffer, Nd4jLong const* dYShapeBuffer, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeBuffer, Nd4jLong const* dZShapeBuffer, - void *extraArguments); +SD_EXPORT void execRandom3(Nd4jPointer* extraPointers, int opNum, + Nd4jPointer state, OpaqueDataBuffer* dbX, + Nd4jLong const* hXShapeBuffer, + Nd4jLong const* dXShapeBuffer, OpaqueDataBuffer* dbY, + Nd4jLong const* hYShapeBuffer, + Nd4jLong const* dYShapeBuffer, OpaqueDataBuffer* dbZ, + Nd4jLong const* hZShapeBuffer, + Nd4jLong const* dZShapeBuffer, void* extraArguments); /** * @@ -1039,13 +1004,12 @@ ND4J_EXPORT void execRandom3(Nd4jPointer *extraPointers, * @param zShapeBuffer * @param extraArguments */ -ND4J_EXPORT void execRandom2(Nd4jPointer *extraPointers, - int opNum, - Nd4jPointer state, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeBuffer, Nd4jLong const* dXShapeBuffer, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeBuffer, Nd4jLong const* dZShapeBuffer, - void *extraArguments); - +SD_EXPORT void execRandom2(Nd4jPointer* extraPointers, int opNum, + Nd4jPointer state, OpaqueDataBuffer* dbX, + Nd4jLong const* hXShapeBuffer, + Nd4jLong const* dXShapeBuffer, OpaqueDataBuffer* dbZ, + Nd4jLong const* hZShapeBuffer, + Nd4jLong const* dZShapeBuffer, void* extraArguments); /** * @@ -1055,10 +1019,8 @@ ND4J_EXPORT void execRandom2(Nd4jPointer *extraPointers, * @param ptrToBuffer * @return */ -ND4J_EXPORT Nd4jPointer initRandom(Nd4jPointer *extraPointers, - long seed, - long bufferSize, - Nd4jPointer ptrToBuffer); +SD_EXPORT Nd4jPointer initRandom(Nd4jPointer* extraPointers, long seed, + long bufferSize, Nd4jPointer ptrToBuffer); /** * @@ -1066,9 +1028,8 @@ ND4J_EXPORT Nd4jPointer initRandom(Nd4jPointer *extraPointers, * @param seed * @param ptrRandom */ -ND4J_EXPORT void refreshBuffer(Nd4jPointer *extraPointers, - long seed, - Nd4jPointer ptrRandom); +SD_EXPORT void refreshBuffer(Nd4jPointer* extraPointers, long seed, + Nd4jPointer ptrRandom); /** * @@ -1076,349 +1037,357 @@ ND4J_EXPORT void refreshBuffer(Nd4jPointer *extraPointers, * @param seed * @param ptrRandom */ -ND4J_EXPORT void reSeedBuffer(Nd4jPointer *extraPointers, - long seed, - Nd4jPointer ptrRandom); +SD_EXPORT void reSeedBuffer(Nd4jPointer* extraPointers, long seed, + Nd4jPointer ptrRandom); /** * * @param ptrRandom */ -ND4J_EXPORT void destroyRandom(Nd4jPointer ptrRandom); - +SD_EXPORT void destroyRandom(Nd4jPointer ptrRandom); } /** -* -* @param data -* @param shapeBuffer -* @param wordSize -* @param headerSize -* @return -*/ + * + * @param data + * @param shapeBuffer + * @param wordSize + * @param headerSize + * @return + */ template -static Nd4jPointer _numpyHeaderForNd4j(Nd4jPointer data,const Nd4jPointer shapeBuffer,Nd4jLong wordSize,Nd4jLong* headerSize) { - Nd4jLong const* shapeBufferCast = reinterpret_cast(shapeBuffer); - int rank = shape::rank(shapeBufferCast); - const Nd4jLong* shape = shape::shapeOf(shapeBufferCast); - unsigned int* npShape = new unsigned int[rank]; - for(int i = 0; i < rank; i++) { - npShape[i] = shape[i]; - } - - Nd4jLong length = shape::prodLong(shape,rank); - auto npHeader = cnpy::createNpyHeader(data,npShape,rank,wordSize); - char *ret = new char[npHeader.size() + 1]; - int count = 0; - for(int i = 0; i < npHeader.size(); i++) { - ret[count] = npHeader[i]; - count++; - } - - ret[count] = '\0'; +static Nd4jPointer _numpyHeaderForNd4j(Nd4jPointer data, + const Nd4jPointer shapeBuffer, + Nd4jLong wordSize, + Nd4jLong* headerSize) { + Nd4jLong const* shapeBufferCast = + reinterpret_cast(shapeBuffer); + int rank = shape::rank(shapeBufferCast); + const Nd4jLong* shape = shape::shapeOf(shapeBufferCast); + unsigned int* npShape = new unsigned int[rank]; + for (int i = 0; i < rank; i++) { + npShape[i] = shape[i]; + } + + Nd4jLong length = shape::prodLong(shape, rank); + auto npHeader = cnpy::createNpyHeader(data, npShape, rank, wordSize); + char* ret = new char[npHeader.size() + 1]; + int count = 0; + for (int i = 0; i < npHeader.size(); i++) { + ret[count] = npHeader[i]; count++; + } + + ret[count] = '\0'; + count++; - *headerSize = count; - return reinterpret_cast(ret); + *headerSize = count; + return reinterpret_cast(ret); } extern "C" { -static Nd4jPointer numpyHeaderForNd4j(Nd4jPointer data,Nd4jPointer shapeBuffer,Nd4jLong wordSize,Nd4jLong* headerSize) { - auto shapeBufferCast = reinterpret_cast(shapeBuffer); - auto type = sd::ArrayOptions::dataType(shapeBufferCast); - BUILD_SINGLE_SELECTOR(type, return _numpyHeaderForNd4j, (data, shapeBuffer, wordSize, headerSize), LIBND4J_TYPES); +static Nd4jPointer numpyHeaderForNd4j(Nd4jPointer data, Nd4jPointer shapeBuffer, + Nd4jLong wordSize, Nd4jLong* headerSize) { + auto shapeBufferCast = reinterpret_cast(shapeBuffer); + auto type = sd::ArrayOptions::dataType(shapeBufferCast); + BUILD_SINGLE_SELECTOR(type, return _numpyHeaderForNd4j, + (data, shapeBuffer, wordSize, headerSize), + LIBND4J_TYPES); } /** -* Load numpy from a header -* based on the cnpy parse from header method. -* @param data the header data to parse -* @return a pointer to a numpy cnpy:NpyArray struct -*/ + * Load numpy from a header + * based on the cnpy parse from header method. + * @param data the header data to parse + * @return a pointer to a numpy cnpy:NpyArray struct + */ static Nd4jPointer loadNpyFromHeader(Nd4jPointer data) { - char *header = reinterpret_cast(data); - - cnpy::NpyArray arr = cnpy::loadNpyFromHeader(header); - cnpy::NpyArray *ret = new cnpy::NpyArray(); - int totalLengthOfShape = 1; - for(int i = 0; i < arr.shape.size(); i++) { - totalLengthOfShape *= arr.shape[i]; - } - - ret->data = arr.data; - ret->wordSize = arr.wordSize; - ret->shape = arr.shape; - return reinterpret_cast(ret); + char* header = reinterpret_cast(data); + + cnpy::NpyArray arr = cnpy::loadNpyFromHeader(header); + cnpy::NpyArray* ret = new cnpy::NpyArray(); + int totalLengthOfShape = 1; + for (int i = 0; i < arr.shape.size(); i++) { + totalLengthOfShape *= arr.shape[i]; + } + + ret->data = arr.data; + ret->wordSize = arr.wordSize; + ret->shape = arr.shape; + return reinterpret_cast(ret); } - } /** -* Create a numpy array from an nd4j -* array -* @param data a pointer to the data -* @param shapeBuffer the shapebuffer for the nd4j array -* @param wordSize the word size (4 for float, 8 for doubles) -* @return a pointer to a numpy array -*/ + * Create a numpy array from an nd4j + * array + * @param data a pointer to the data + * @param shapeBuffer the shapebuffer for the nd4j array + * @param wordSize the word size (4 for float, 8 for doubles) + * @return a pointer to a numpy array + */ template -static Nd4jPointer _numpyFromNd4j(Nd4jPointer data,Nd4jPointer shapeBuffer,Nd4jLong wordSize) { - Nd4jLong *shapeBufferCast = reinterpret_cast(shapeBuffer); - int rank = shape::rank(shapeBufferCast); - Nd4jLong *shape = shape::shapeOf(shapeBufferCast); - unsigned int *npShape = new unsigned int[rank]; - for(int i = 0; i < rank; i++) { - npShape[i] = shape[i]; - } - - Nd4jLong length = shape::prodLong(shape,rank); - auto npHeader = cnpy::createNpyHeader(data,npShape,rank,wordSize); - char *dataChar = reinterpret_cast(data); - char *npHeaderData = npHeader.data(); - char *ret = new char[(wordSize * length) + npHeader.size()]; - char *cursorStart = ret; - std::memcpy(reinterpret_cast(ret), reinterpret_cast(npHeaderData), npHeader.size() * sizeof(Nd4jLong)); - //move to next - cursorStart += npHeader.size(); - std::memcpy(reinterpret_cast(ret), reinterpret_cast(dataChar), length * wordSize * sizeof(Nd4jLong)); - Nd4jPointer rettPointer = reinterpret_cast(ret); - return rettPointer; +static Nd4jPointer _numpyFromNd4j(Nd4jPointer data, Nd4jPointer shapeBuffer, + Nd4jLong wordSize) { + Nd4jLong* shapeBufferCast = reinterpret_cast(shapeBuffer); + int rank = shape::rank(shapeBufferCast); + Nd4jLong* shape = shape::shapeOf(shapeBufferCast); + unsigned int* npShape = new unsigned int[rank]; + for (int i = 0; i < rank; i++) { + npShape[i] = shape[i]; + } + + Nd4jLong length = shape::prodLong(shape, rank); + auto npHeader = cnpy::createNpyHeader(data, npShape, rank, wordSize); + char* dataChar = reinterpret_cast(data); + char* npHeaderData = npHeader.data(); + char* ret = new char[(wordSize * length) + npHeader.size()]; + char* cursorStart = ret; + std::memcpy(reinterpret_cast(ret), + reinterpret_cast(npHeaderData), + npHeader.size() * sizeof(Nd4jLong)); + // move to next + cursorStart += npHeader.size(); + std::memcpy(reinterpret_cast(ret), reinterpret_cast(dataChar), + length * wordSize * sizeof(Nd4jLong)); + Nd4jPointer rettPointer = reinterpret_cast(ret); + return rettPointer; } extern "C" { -static Nd4jPointer numpyFromNd4j(Nd4jPointer data,Nd4jPointer shapeBuffer,Nd4jLong wordSize) { - auto shapeBufferCast = reinterpret_cast(shapeBuffer); - auto type = sd::ArrayOptions::dataType(shapeBufferCast); - BUILD_SINGLE_SELECTOR(type, return _numpyFromNd4j, (data, shapeBuffer, wordSize), LIBND4J_TYPES); +static Nd4jPointer numpyFromNd4j(Nd4jPointer data, Nd4jPointer shapeBuffer, + Nd4jLong wordSize) { + auto shapeBufferCast = reinterpret_cast(shapeBuffer); + auto type = sd::ArrayOptions::dataType(shapeBufferCast); + BUILD_SINGLE_SELECTOR(type, return _numpyFromNd4j, + (data, shapeBuffer, wordSize), LIBND4J_TYPES); } - /** -* -* @param npyArray -* @return -*/ -ND4J_EXPORT Nd4jPointer shapeBufferForNumpy(Nd4jPointer npyArray); - + * + * @param npyArray + * @return + */ +SD_EXPORT Nd4jPointer shapeBufferForNumpy(Nd4jPointer npyArray); /** -* Get the shape buffer from a -* numpy array. -* **Warning** this allocates memory -* @param npyArray -* @return -*/ + * Get the shape buffer from a + * numpy array. + * **Warning** this allocates memory + * @param npyArray + * @return + */ static Nd4jPointer shapeBufferForNumpyHeader(Nd4jPointer npyArray) { - cnpy::NpyArray arr = cnpy::loadNpyFromHeader(reinterpret_cast(npyArray)); - auto shape = new unsigned int[arr.shape.size()]; - for(unsigned int i = 0; i < arr.shape.size(); i++) { - shape[i] = arr.shape[i]; - } - - auto shapeBuffer = shape::shapeBufferOfNpy(arr.shape.size(), shape, arr.fortranOrder); - delete[] shape; - return reinterpret_cast(shapeBuffer); + cnpy::NpyArray arr = + cnpy::loadNpyFromHeader(reinterpret_cast(npyArray)); + auto shape = new unsigned int[arr.shape.size()]; + for (unsigned int i = 0; i < arr.shape.size(); i++) { + shape[i] = arr.shape[i]; + } + + auto shapeBuffer = + shape::shapeBufferOfNpy(arr.shape.size(), shape, arr.fortranOrder); + delete[] shape; + return reinterpret_cast(shapeBuffer); } - - /** -* -* @param npyArray -* @return -*/ + * + * @param npyArray + * @return + */ static Nd4jPointer dataPointForNumpyHeader(Nd4jPointer npyArray) { - cnpy::NpyArray arr = cnpy::loadNpyFromHeader(reinterpret_cast(npyArray)); - unsigned char *dataToPrint = reinterpret_cast(arr.data); - return dataToPrint; + cnpy::NpyArray arr = + cnpy::loadNpyFromHeader(reinterpret_cast(npyArray)); + unsigned char* dataToPrint = reinterpret_cast(arr.data); + return dataToPrint; } /** -* -* @param npyArray -* @return -*/ + * + * @param npyArray + * @return + */ static Nd4jPointer dataPointForNumpyStruct(Nd4jPointer npyArrayStruct) { - cnpy::NpyArray *arrPointer = reinterpret_cast(npyArrayStruct); - unsigned char *dataToPrint = reinterpret_cast(arrPointer->data); - return reinterpret_cast(dataToPrint); + cnpy::NpyArray* arrPointer = + reinterpret_cast(npyArrayStruct); + unsigned char* dataToPrint = + reinterpret_cast(arrPointer->data); + return reinterpret_cast(dataToPrint); } /** -* -* @param npyArray -* @param fromFile -* @return -*/ + * + * @param npyArray + * @param fromFile + * @return + */ static Nd4jPointer dataPointForNumpy(Nd4jPointer npyArray) { - char *npyArrayBuffer = reinterpret_cast< char *>(npyArray); - cnpy::NpyArray arr = cnpy::loadNpyFromPointer(npyArrayBuffer); - return dataPointForNumpyStruct(reinterpret_cast(&arr)); + char* npyArrayBuffer = reinterpret_cast(npyArray); + cnpy::NpyArray arr = cnpy::loadNpyFromPointer(npyArrayBuffer); + return dataPointForNumpyStruct(reinterpret_cast(&arr)); } /** -* Load a numpy array from a file -* and return it as an Nd4jPointer -* @param path -* @return -*/ + * Load a numpy array from a file + * and return it as an Nd4jPointer + * @param path + * @return + */ static Nd4jPointer numpyFromFile(std::string path) { - char *numpyBuffer = cnpy::loadFile(path.data()); - return reinterpret_cast(numpyBuffer); + char* numpyBuffer = cnpy::loadFile(path.data()); + return reinterpret_cast(numpyBuffer); } - ////// NPZ ////// -static void* mapFromNpzFile(std::string path){ - cnpy::npz_t* mapPtr = new cnpy::npz_t(); - cnpy::npz_t map = cnpy::npzLoad(path); - mapPtr->insert(map.begin(), map.end()); - return reinterpret_cast(mapPtr); +static void* mapFromNpzFile(std::string path) { + cnpy::npz_t* mapPtr = new cnpy::npz_t(); + cnpy::npz_t map = cnpy::npzLoad(path); + mapPtr->insert(map.begin(), map.end()); + return reinterpret_cast(mapPtr); } - -static int getNumNpyArraysInMap(void *map){ - cnpy::npz_t* arrays = reinterpret_cast(map); - int n = arrays->size(); - return n; +static int getNumNpyArraysInMap(void* map) { + cnpy::npz_t* arrays = reinterpret_cast(map); + int n = arrays->size(); + return n; } -static const char* getNpyArrayNameFromMap(void *map, int index){ - cnpy::npz_t* arrays = reinterpret_cast(map); - cnpy::npz_t::iterator it = arrays->begin(); - cnpy::npz_t::iterator end = arrays->end(); - int cnt = 0; - for(; it != end; ++it, ++cnt){ - if (cnt == index){ - // FIXME: @fariz, this is a leak! +static const char* getNpyArrayNameFromMap(void* map, int index) { + cnpy::npz_t* arrays = reinterpret_cast(map); + cnpy::npz_t::iterator it = arrays->begin(); + cnpy::npz_t::iterator end = arrays->end(); + int cnt = 0; + for (; it != end; ++it, ++cnt) { + if (cnt == index) { + // FIXME: @fariz, this is a leak! #ifdef _MSC_VER - return const_cast(_strdup(it->first.c_str())); + return const_cast(_strdup(it->first.c_str())); #else - return const_cast(strdup(it->first.c_str())); + return const_cast(strdup(it->first.c_str())); #endif - } } - throw std::runtime_error("No array at index."); + } + throw std::runtime_error("No array at index."); } -static void* getNpyArrayFromMap(void *map, int index){ - cnpy::npz_t* arrays = reinterpret_cast(map); - cnpy::npz_t::iterator it = arrays->begin(); - cnpy::npz_t::iterator end = arrays->end(); - cnpy::NpyArray *arr = new cnpy::NpyArray(); - int cnt = 0; - for(; it != end; ++it, ++cnt){ - if (cnt == index){ - *arr = it->second; - return arr; - } +static void* getNpyArrayFromMap(void* map, int index) { + cnpy::npz_t* arrays = reinterpret_cast(map); + cnpy::npz_t::iterator it = arrays->begin(); + cnpy::npz_t::iterator end = arrays->end(); + cnpy::NpyArray* arr = new cnpy::NpyArray(); + int cnt = 0; + for (; it != end; ++it, ++cnt) { + if (cnt == index) { + *arr = it->second; + return arr; } - throw std::runtime_error("No array at index."); + } + throw std::runtime_error("No array at index."); } -ND4J_EXPORT int dataTypeFromNpyHeader(void *header); +SD_EXPORT int dataTypeFromNpyHeader(void* header); -static void* getNpyArrayData(void *npArray){ - cnpy::NpyArray* npyArray2 = reinterpret_cast(npArray); - return reinterpret_cast(npyArray2->data); +static void* getNpyArrayData(void* npArray) { + cnpy::NpyArray* npyArray2 = reinterpret_cast(npArray); + return reinterpret_cast(npyArray2->data); } -static int getNpyArrayRank(void *npArray){ - cnpy::NpyArray* arr = reinterpret_cast(npArray); - int rank = arr->shape.size(); - return rank; +static int getNpyArrayRank(void* npArray) { + cnpy::NpyArray* arr = reinterpret_cast(npArray); + int rank = arr->shape.size(); + return rank; } -static Nd4jLong* getNpyArrayShape(void *npArray){ - cnpy::NpyArray* arr = reinterpret_cast(npArray); - int ndim = arr->shape.size(); - Nd4jLong* shape = new Nd4jLong[ndim]; - for (int i=0; ishape.at(i); - } - return shape; +static Nd4jLong* getNpyArrayShape(void* npArray) { + cnpy::NpyArray* arr = reinterpret_cast(npArray); + int ndim = arr->shape.size(); + Nd4jLong* shape = new Nd4jLong[ndim]; + for (int i = 0; i < ndim; i++) { + shape[i] = arr->shape.at(i); + } + return shape; } -static char getNpyArrayOrder(void *npArray){ - cnpy::NpyArray* arr = reinterpret_cast(npArray); - return (arr->fortranOrder)?'f':'c'; +static char getNpyArrayOrder(void* npArray) { + cnpy::NpyArray* arr = reinterpret_cast(npArray); + return (arr->fortranOrder) ? 'f' : 'c'; } -static int getNpyArrayElemSize(void *npArray){ - cnpy::NpyArray* arr = reinterpret_cast(npArray); - return arr->wordSize; +static int getNpyArrayElemSize(void* npArray) { + cnpy::NpyArray* arr = reinterpret_cast(npArray); + return arr->wordSize; } -static void deleteNPArrayStruct(void *npArray){ - cnpy::NpyArray* arr = reinterpret_cast(npArray); - delete arr; +static void deleteNPArrayStruct(void* npArray) { + cnpy::NpyArray* arr = reinterpret_cast(npArray); + delete arr; } -static void deleteNPArrayMap(void *map){ - cnpy::npz_t* arrays = reinterpret_cast(map); - delete arrays; +static void deleteNPArrayMap(void* map) { + cnpy::npz_t* arrays = reinterpret_cast(map); + delete arrays; } ////// /** -* Get the element size for a numpy array -* @param npyArray the numpy array's address -* to get the length for -* @return -*/ + * Get the element size for a numpy array + * @param npyArray the numpy array's address + * to get the length for + * @return + */ static int elementSizeForNpyArray(Nd4jPointer npyArray) { - cnpy::NpyArray arr = cnpy::loadNpyFromPointer(reinterpret_cast(npyArray)); - cnpy::NpyArray *arrPointer = &arr; - int size = arrPointer->wordSize; - // arrPointer->destruct(); - return size; + cnpy::NpyArray arr = + cnpy::loadNpyFromPointer(reinterpret_cast(npyArray)); + cnpy::NpyArray* arrPointer = &arr; + int size = arrPointer->wordSize; + // arrPointer->destruct(); + return size; } - /** -* Get the element size for a numpy array -* @param npyArray the numpy array's address -* to get the length for -* @return -*/ + * Get the element size for a numpy array + * @param npyArray the numpy array's address + * to get the length for + * @return + */ static int elementSizeForNpyArrayHeader(Nd4jPointer npyArray) { - cnpy::NpyArray arr = cnpy::loadNpyFromHeader(reinterpret_cast(npyArray)); - cnpy::NpyArray *arrPointer = &arr; - int size = arrPointer->wordSize; - return size; + cnpy::NpyArray arr = + cnpy::loadNpyFromHeader(reinterpret_cast(npyArray)); + cnpy::NpyArray* arrPointer = &arr; + int size = arrPointer->wordSize; + return size; } - static void releaseNumpy(Nd4jPointer npyArray) { - free(reinterpret_cast(npyArray)); + free(reinterpret_cast(npyArray)); } - /** * Return the length of a shape buffer * based on the pointer * @param buffer the buffer pointer to check * @return */ -ND4J_EXPORT int lengthForShapeBufferPointer(Nd4jPointer buffer); - +SD_EXPORT int lengthForShapeBufferPointer(Nd4jPointer buffer); - /** -* The pointer to get the address for -* -* @param address the address to get the pointer -* @return the pointer for the given address -*/ +/** + * The pointer to get the address for + * + * @param address the address to get the pointer + * @return the pointer for the given address + */ -ND4J_EXPORT Nd4jPointer pointerForAddress(Nd4jLong address); +SD_EXPORT Nd4jPointer pointerForAddress(Nd4jLong address); /** - * This method takes single N-dimensional tensor, and copies its TADs to target arrays + * This method takes single N-dimensional tensor, and copies its TADs to target + * arrays * * @param x * @param xShapeInfo @@ -1426,241 +1395,300 @@ ND4J_EXPORT Nd4jPointer pointerForAddress(Nd4jLong address); * @param zShapeInfo * @return */ -ND4J_EXPORT void tear(Nd4jPointer *extraPointers, - OpaqueDataBuffer *dbX, Nd4jLong const* xShapeInfo, Nd4jLong const* dxShapeInfo, - Nd4jPointer *targets, Nd4jLong const* zShapeInfo, - Nd4jLong const* tadShapeInfo, - Nd4jLong const* tadOffsets); - -ND4J_EXPORT void sort(Nd4jPointer *extraPointers, - void *x, Nd4jLong const* xShapeInfo, - void *dx, Nd4jLong const* dxShapeInfo, - bool descending); - -ND4J_EXPORT void sortByKey(Nd4jPointer *extraPointers, - void *x, Nd4jLong const* xShapeInfo, - void *dx, Nd4jLong const* dxShapeInfo, - void *y, Nd4jLong const* yShapeInfo, - void *dy, Nd4jLong const* dyShapeInfo, - bool descending); - -ND4J_EXPORT void sortByValue(Nd4jPointer *extraPointers, - void *x, Nd4jLong const* xShapeInfo, - void *dx, Nd4jLong const* dxShapeInfo, - void *y, Nd4jLong const* yShapeInfo, - void *dy, Nd4jLong const* dyShapeInfo, - bool descending); - -ND4J_EXPORT void sortTad(Nd4jPointer *extraPointers, - void *x, Nd4jLong const* xShapeInfo, - void *dx, Nd4jLong const* dxShapeInfo, - int *dimension, - int dimensionLength, - Nd4jLong const* tadShapeInfo, - Nd4jLong const* tadOffsets, - bool descending); - -ND4J_EXPORT void sortTadByKey(Nd4jPointer *extraPointers, - void *x, Nd4jLong const* xShapeInfo, - void *dx, Nd4jLong const* dxShapeInfo, - void *y, Nd4jLong const* yShapeInfo, - void *dy, Nd4jLong const* dyShapeInfo, - int *dimension, - int dimensionLength, - bool descending); - -ND4J_EXPORT void sortTadByValue(Nd4jPointer *extraPointers, - void *x, Nd4jLong const* xShapeInfo, - void *dx, Nd4jLong const* dxShapeInfo, - void *y, Nd4jLong const* yShapeInfo, - void *dy, Nd4jLong const* dyShapeInfo, - int *dimension, - int dimensionLength, - bool descending); - +SD_EXPORT void tear(Nd4jPointer* extraPointers, OpaqueDataBuffer* dbX, + Nd4jLong const* xShapeInfo, Nd4jLong const* dxShapeInfo, + Nd4jPointer* targets, Nd4jLong const* zShapeInfo, + Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets); + +SD_EXPORT void sort(Nd4jPointer* extraPointers, void* x, + Nd4jLong const* xShapeInfo, void* dx, + Nd4jLong const* dxShapeInfo, bool descending); + +SD_EXPORT void sortByKey(Nd4jPointer* extraPointers, void* x, + Nd4jLong const* xShapeInfo, void* dx, + Nd4jLong const* dxShapeInfo, void* y, + Nd4jLong const* yShapeInfo, void* dy, + Nd4jLong const* dyShapeInfo, bool descending); + +SD_EXPORT void sortByValue(Nd4jPointer* extraPointers, void* x, + Nd4jLong const* xShapeInfo, void* dx, + Nd4jLong const* dxShapeInfo, void* y, + Nd4jLong const* yShapeInfo, void* dy, + Nd4jLong const* dyShapeInfo, bool descending); + +SD_EXPORT void sortTad(Nd4jPointer* extraPointers, void* x, + Nd4jLong const* xShapeInfo, void* dx, + Nd4jLong const* dxShapeInfo, int* dimension, + int dimensionLength, Nd4jLong const* tadShapeInfo, + Nd4jLong const* tadOffsets, bool descending); + +SD_EXPORT void sortTadByKey(Nd4jPointer* extraPointers, void* x, + Nd4jLong const* xShapeInfo, void* dx, + Nd4jLong const* dxShapeInfo, void* y, + Nd4jLong const* yShapeInfo, void* dy, + Nd4jLong const* dyShapeInfo, int* dimension, + int dimensionLength, bool descending); + +SD_EXPORT void sortTadByValue(Nd4jPointer* extraPointers, void* x, + Nd4jLong const* xShapeInfo, void* dx, + Nd4jLong const* dxShapeInfo, void* y, + Nd4jLong const* yShapeInfo, void* dy, + Nd4jLong const* dyShapeInfo, int* dimension, + int dimensionLength, bool descending); // special sort impl for sorting out COO indices and values -ND4J_EXPORT void sortCooIndices(Nd4jPointer *extraPointers, Nd4jLong *indices, void *values, Nd4jLong length, int rank); - +SD_EXPORT void sortCooIndices(Nd4jPointer* extraPointers, Nd4jLong* indices, + void* values, Nd4jLong length, int rank); -ND4J_EXPORT Nd4jLong* mmapFile(Nd4jPointer *extraPointers, const char *fileName, Nd4jLong length); +SD_EXPORT Nd4jLong* mmapFile(Nd4jPointer* extraPointers, const char* fileName, + Nd4jLong length); -ND4J_EXPORT void munmapFile(Nd4jPointer *extraPointers, Nd4jLong* ptrMap, Nd4jLong length); +SD_EXPORT void munmapFile(Nd4jPointer* extraPointers, Nd4jLong* ptrMap, + Nd4jLong length); typedef sd::graph::ResultWrapper OpaqueResultWrapper; // flatbuffers execution -ND4J_EXPORT OpaqueResultWrapper* executeFlatGraph(Nd4jPointer *extraPointers, Nd4jPointer flatBufferPointer); +SD_EXPORT OpaqueResultWrapper* executeFlatGraph(Nd4jPointer* extraPointers, + Nd4jPointer flatBufferPointer); -ND4J_EXPORT Nd4jLong getResultWrapperSize(OpaqueResultWrapper* ptr); -ND4J_EXPORT Nd4jPointer getResultWrapperPointer(OpaqueResultWrapper* ptr); +SD_EXPORT Nd4jLong getResultWrapperSize(OpaqueResultWrapper* ptr); +SD_EXPORT Nd4jPointer getResultWrapperPointer(OpaqueResultWrapper* ptr); -ND4J_EXPORT const char* getAllCustomOps(); +SD_EXPORT const char* getAllCustomOps(); -ND4J_EXPORT const char* getAllOperations(); +SD_EXPORT const char* getAllOperations(); // customOp executioner -ND4J_EXPORT int execCustomOp(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputs, Nd4jPointer* outputBuffers, Nd4jPointer* outputShapes, int numOutputs, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool* bArgs, int numBArgs, bool isInplace); -ND4J_EXPORT int execCustomOp2(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer opContext); +SD_EXPORT int execCustomOp(Nd4jPointer* extraPointers, Nd4jLong hash, + Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, + int numInputs, Nd4jPointer* outputBuffers, + Nd4jPointer* outputShapes, int numOutputs, + double* tArgs, int numTArgs, Nd4jLong* iArgs, + int numIArgs, bool* bArgs, int numBArgs, + bool isInplace); +SD_EXPORT int execCustomOp2(Nd4jPointer* extraPointers, Nd4jLong hash, + Nd4jPointer opContext); typedef sd::ShapeList OpaqueShapeList; -ND4J_EXPORT OpaqueShapeList* calculateOutputShapes(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs); -ND4J_EXPORT OpaqueShapeList* calculateOutputShapes2(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs, int *dArgs, int numDArgs); +SD_EXPORT OpaqueShapeList* calculateOutputShapes(Nd4jPointer* extraPointers, + Nd4jLong hash, + Nd4jPointer* inputShapes, + int numInputShapes, + double* tArgs, int numTArgs, + Nd4jLong* iArgs, int numIArgs); +SD_EXPORT OpaqueShapeList* calculateOutputShapes2( + Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, + Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, + Nd4jLong* iArgs, int numIArgs, bool* bArgs, int numBArgs, int* dArgs, + int numDArgs); -ND4J_EXPORT Nd4jLong getShapeListSize(OpaqueShapeList* list); -ND4J_EXPORT Nd4jLong const* getShape(OpaqueShapeList* list, Nd4jLong i); +SD_EXPORT Nd4jLong getShapeListSize(OpaqueShapeList* list); +SD_EXPORT Nd4jLong const* getShape(OpaqueShapeList* list, Nd4jLong i); -ND4J_EXPORT void deleteShapeList(Nd4jPointer shapeList); +SD_EXPORT void deleteShapeList(Nd4jPointer shapeList); -ND4J_EXPORT int registerGraph(Nd4jPointer *extraPointers, Nd4jLong graphId, Nd4jPointer flatBufferPointer); +SD_EXPORT int registerGraph(Nd4jPointer* extraPointers, Nd4jLong graphId, + Nd4jPointer flatBufferPointer); typedef sd::graph::VariablesSet OpaqueVariablesSet; typedef sd::graph::Variable OpaqueVariable; -ND4J_EXPORT OpaqueVariablesSet *executeStoredGraph(Nd4jPointer *extraPointers, Nd4jLong graphId, Nd4jPointer *inputBuffers, Nd4jPointer *inputShapes, int* inputIndices, int numInputs); - -ND4J_EXPORT Nd4jLong getVariablesSetSize(OpaqueVariablesSet* set); -ND4J_EXPORT Nd4jStatus getVariablesSetStatus(OpaqueVariablesSet* set); -ND4J_EXPORT OpaqueVariable* getVariable(OpaqueVariablesSet* set, Nd4jLong i); -ND4J_EXPORT int getVariableId(OpaqueVariable* variable); -ND4J_EXPORT int getVariableIndex(OpaqueVariable* variable); -ND4J_EXPORT const char* getVariableName(OpaqueVariable* variable); -ND4J_EXPORT Nd4jLong const* getVariableShape(OpaqueVariable* variable); -ND4J_EXPORT void* getVariableBuffer(OpaqueVariable* variable); - -ND4J_EXPORT int unregisterGraph(Nd4jPointer *extraPointers, Nd4jLong graphId); - -ND4J_EXPORT void deleteCharArray(Nd4jPointer pointer); -ND4J_EXPORT void deleteIntArray(Nd4jPointer pointer); -ND4J_EXPORT void deleteLongArray(Nd4jPointer pointer); -ND4J_EXPORT void deletePointerArray(Nd4jPointer pointer); - -ND4J_EXPORT void deleteVariablesSet(OpaqueVariablesSet* pointer); - -// GraphState creation -ND4J_EXPORT Nd4jPointer getGraphState(Nd4jLong id); - -ND4J_EXPORT void deleteGraphState(Nd4jPointer state); - -ND4J_EXPORT void deleteResultWrapper(Nd4jPointer ptr); - -ND4J_EXPORT int estimateThreshold(Nd4jPointer *extraPointers, Nd4jPointer x, Nd4jLong const* xShapeInfo, int N, float threshold); - -// this method executes op that requires scope to be present: if/while/cond/whatever -ND4J_EXPORT Nd4jStatus execCustomOpWithScope(Nd4jPointer *extraPointers, Nd4jPointer state, Nd4jLong opHash, Nd4jLong *scopes, int numScopes, Nd4jPointer *inputBuffers, Nd4jPointer *inputShapes, int numInputs, Nd4jPointer *outputBuffers, Nd4jPointer *outputShapes, int numOutputs); - -//void fillUtf8String(Nd4jPointer *extraPointers, const char **string, int numStrings, Nd4jPointer buffer); -ND4J_EXPORT Nd4jPointer createUtf8String(Nd4jPointer *extraPointers, const char *string, int length); -ND4J_EXPORT Nd4jLong getUtf8StringLength(Nd4jPointer *extraPointers, Nd4jPointer ptr); -ND4J_EXPORT char* getUtf8StringBuffer(Nd4jPointer *extraPointers, Nd4jPointer ptr); -ND4J_EXPORT void deleteUtf8String(Nd4jPointer *extraPointers, Nd4jPointer ptr); - -ND4J_EXPORT void scatterUpdate(Nd4jPointer *extraPointers, int opCode, int numOfSubArrs, - void* hX, Nd4jLong const* hXShapeInfo, Nd4jLong const* hXOffsets, - void* dX, Nd4jLong const* dXShapeInfo, Nd4jLong const* dXOffsets, - void* hY, Nd4jLong const* hYShapeInfo, Nd4jLong const* hYOffsets, - void* dY, Nd4jLong const* dYShapeInfo, Nd4jLong const* dYOffsets, - void* hIindexes, Nd4jLong const* hIndicesShapeInfo, void* dIindexes, Nd4jLong const* dIndicesShapeInfo); - -ND4J_EXPORT void inspectArray(Nd4jPointer *extraPointers, Nd4jPointer buffer, Nd4jLong *shapeInfo, Nd4jPointer specialBuffer, Nd4jLong *specialShapeInfo, Nd4jPointer debugInfo); - +SD_EXPORT OpaqueVariablesSet* executeStoredGraph( + Nd4jPointer* extraPointers, Nd4jLong graphId, Nd4jPointer* inputBuffers, + Nd4jPointer* inputShapes, int* inputIndices, int numInputs); + +SD_EXPORT Nd4jLong getVariablesSetSize(OpaqueVariablesSet* set); +SD_EXPORT Nd4jStatus getVariablesSetStatus(OpaqueVariablesSet* set); +SD_EXPORT OpaqueVariable* getVariable(OpaqueVariablesSet* set, Nd4jLong i); +SD_EXPORT int getVariableId(OpaqueVariable* variable); +SD_EXPORT int getVariableIndex(OpaqueVariable* variable); +SD_EXPORT const char* getVariableName(OpaqueVariable* variable); +SD_EXPORT Nd4jLong const* getVariableShape(OpaqueVariable* variable); +SD_EXPORT void* getVariableBuffer(OpaqueVariable* variable); + +SD_EXPORT int unregisterGraph(Nd4jPointer* extraPointers, Nd4jLong graphId); + +SD_EXPORT void deleteCharArray(Nd4jPointer pointer); +SD_EXPORT void deleteIntArray(Nd4jPointer pointer); +SD_EXPORT void deleteLongArray(Nd4jPointer pointer); +SD_EXPORT void deletePointerArray(Nd4jPointer pointer); + +SD_EXPORT void deleteVariablesSet(OpaqueVariablesSet* pointer); + +SD_EXPORT void deleteResultWrapper(Nd4jPointer ptr); + +SD_EXPORT int estimateThreshold(Nd4jPointer* extraPointers, Nd4jPointer x, + Nd4jLong const* xShapeInfo, int N, + float threshold); + +// this method executes op that requires scope to be present: +// if/while/cond/whatever +SD_EXPORT Nd4jStatus execCustomOpWithScope( + Nd4jPointer* extraPointers, Nd4jPointer state, Nd4jLong opHash, + Nd4jLong* scopes, int numScopes, Nd4jPointer* inputBuffers, + Nd4jPointer* inputShapes, int numInputs, Nd4jPointer* outputBuffers, + Nd4jPointer* outputShapes, int numOutputs); + +// void fillUtf8String(Nd4jPointer *extraPointers, const char **string, int +// numStrings, Nd4jPointer buffer); +SD_EXPORT Nd4jPointer createUtf8String(Nd4jPointer* extraPointers, + const char* string, int length); +SD_EXPORT Nd4jLong getUtf8StringLength(Nd4jPointer* extraPointers, + Nd4jPointer ptr); +SD_EXPORT char* getUtf8StringBuffer(Nd4jPointer* extraPointers, + Nd4jPointer ptr); +SD_EXPORT void deleteUtf8String(Nd4jPointer* extraPointers, Nd4jPointer ptr); + +SD_EXPORT void scatterUpdate( + Nd4jPointer* extraPointers, int opCode, int numOfSubArrs, void* hX, + Nd4jLong const* hXShapeInfo, Nd4jLong const* hXOffsets, void* dX, + Nd4jLong const* dXShapeInfo, Nd4jLong const* dXOffsets, void* hY, + Nd4jLong const* hYShapeInfo, Nd4jLong const* hYOffsets, void* dY, + Nd4jLong const* dYShapeInfo, Nd4jLong const* dYOffsets, void* hIindexes, + Nd4jLong const* hIndicesShapeInfo, void* dIindexes, + Nd4jLong const* dIndicesShapeInfo); + +SD_EXPORT void inspectArray(Nd4jPointer* extraPointers, Nd4jPointer buffer, + Nd4jLong* shapeInfo, Nd4jPointer specialBuffer, + Nd4jLong* specialShapeInfo, Nd4jPointer debugInfo); typedef sd::ConstantDataBuffer OpaqueConstantDataBuffer; typedef sd::ConstantShapeBuffer OpaqueConstantShapeBuffer; -ND4J_EXPORT OpaqueConstantShapeBuffer* shapeBuffer(int rank, Nd4jLong *shape, Nd4jLong *strides, sd::DataType dtype, char order, Nd4jLong ews, bool empty); - -ND4J_EXPORT OpaqueConstantDataBuffer* constantBufferLong(sd::DataType dtype, Nd4jLong const* data, int length); -ND4J_EXPORT OpaqueConstantDataBuffer* constantBufferDouble(sd::DataType dtype, double *data, int length); -ND4J_EXPORT OpaqueConstantDataBuffer* constantBuffer(sd::DataType dtype, sd::ConstantDescriptor *descriptor); - -ND4J_EXPORT Nd4jPointer getConstantDataBufferPrimary(OpaqueConstantDataBuffer* dbf); -ND4J_EXPORT Nd4jPointer getConstantDataBufferSpecial(OpaqueConstantDataBuffer* dbf); -ND4J_EXPORT Nd4jLong getConstantDataBufferLength(OpaqueConstantDataBuffer* dbf); - -ND4J_EXPORT Nd4jPointer getConstantShapeBufferPrimary(OpaqueConstantShapeBuffer* dbf); -ND4J_EXPORT Nd4jPointer getConstantShapeBufferSpecial(OpaqueConstantShapeBuffer* dbf); - -ND4J_EXPORT void deleteConstantShapeBuffer(OpaqueConstantShapeBuffer* ptr); -ND4J_EXPORT void deleteConstantDataBuffer(OpaqueConstantDataBuffer* ptr); +SD_EXPORT OpaqueConstantShapeBuffer* shapeBuffer(int rank, Nd4jLong* shape, + Nd4jLong* strides, + sd::DataType dtype, char order, + Nd4jLong ews, bool empty); + +SD_EXPORT OpaqueConstantDataBuffer* constantBufferLong(sd::DataType dtype, + Nd4jLong const* data, + int length); +SD_EXPORT OpaqueConstantDataBuffer* constantBufferDouble(sd::DataType dtype, + double* data, + int length); +SD_EXPORT OpaqueConstantDataBuffer* constantBuffer( + sd::DataType dtype, sd::ConstantDescriptor* descriptor); + +SD_EXPORT Nd4jPointer +getConstantDataBufferPrimary(OpaqueConstantDataBuffer* dbf); +SD_EXPORT Nd4jPointer +getConstantDataBufferSpecial(OpaqueConstantDataBuffer* dbf); +SD_EXPORT Nd4jLong getConstantDataBufferLength(OpaqueConstantDataBuffer* dbf); +SD_EXPORT Nd4jLong getConstantDataBufferSizeOf(OpaqueConstantDataBuffer* dbf); + +SD_EXPORT void deleteShapeBuffer(OpaqueConstantDataBuffer* ptr); +SD_EXPORT Nd4jPointer getConstantShapeBufferPrimary(OpaqueConstantShapeBuffer* dbf); +SD_EXPORT Nd4jPointer getConstantShapeBufferSpecial(OpaqueConstantShapeBuffer* dbf); + +SD_EXPORT void deleteConstantShapeBuffer(OpaqueConstantShapeBuffer* ptr); +SD_EXPORT void deleteConstantDataBuffer(OpaqueConstantDataBuffer* ptr); typedef sd::graph::Context OpaqueContext; typedef sd::graph::RandomGenerator OpaqueRandomGenerator; -ND4J_EXPORT OpaqueContext* createGraphContext(int nodeId); -ND4J_EXPORT OpaqueRandomGenerator* getGraphContextRandomGenerator(OpaqueContext* ptr); -ND4J_EXPORT void ctxAllowHelpers(OpaqueContext* ptr, bool reallyAllow); -ND4J_EXPORT void ctxShapeFunctionOverride(OpaqueContext* ptr, bool reallyOverride); -ND4J_EXPORT void ctxSetExecutionMode(OpaqueContext* ptr, int execMode); -ND4J_EXPORT void ctxPurge(OpaqueContext* ptr); -ND4J_EXPORT void markGraphContextInplace(OpaqueContext* ptr, bool reallyInplace); -ND4J_EXPORT void setGraphContextCudaContext(OpaqueContext* ptr, void *stream, void *reductionPointer, void *allocationPointer); -ND4J_EXPORT void setGraphContextInputArray(OpaqueContext* ptr, int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo); -ND4J_EXPORT void setGraphContextOutputArray(OpaqueContext* ptr, int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo); -ND4J_EXPORT void setGraphContextInputBuffer(OpaqueContext* ptr, int index, OpaqueDataBuffer *buffer, void *shapeInfo, void *specialShapeInfo); -ND4J_EXPORT void setGraphContextOutputBuffer(OpaqueContext* ptr, int index, OpaqueDataBuffer *buffer, void *shapeInfo, void *specialShapeInfo); -ND4J_EXPORT void setGraphContextDArguments(OpaqueContext* ptr, int *arguments, int numberOfArguments); -ND4J_EXPORT void setGraphContextTArguments(OpaqueContext* ptr, double *arguments, int numberOfArguments); -ND4J_EXPORT void setGraphContextIArguments(OpaqueContext* ptr, Nd4jLong *arguments, int numberOfArguments); -ND4J_EXPORT void setGraphContextBArguments(OpaqueContext* ptr, bool *arguments, int numberOfArguments); -ND4J_EXPORT void deleteGraphContext(OpaqueContext* ptr); - -ND4J_EXPORT OpaqueRandomGenerator* createRandomGenerator(Nd4jLong rootSeed = 0, Nd4jLong nodeSeed = 0); -ND4J_EXPORT Nd4jLong getRandomGeneratorRootState(OpaqueRandomGenerator* ptr); -ND4J_EXPORT Nd4jLong getRandomGeneratorNodeState(OpaqueRandomGenerator* ptr); -ND4J_EXPORT void setRandomGeneratorStates(OpaqueRandomGenerator* ptr, Nd4jLong rootSeed = 0, Nd4jLong nodeSeed = 0); -ND4J_EXPORT float getRandomGeneratorRelativeFloat(OpaqueRandomGenerator* ptr, Nd4jLong index); -ND4J_EXPORT double getRandomGeneratorRelativeDouble(OpaqueRandomGenerator* ptr, Nd4jLong index); -ND4J_EXPORT int getRandomGeneratorRelativeInt(OpaqueRandomGenerator* ptr, Nd4jLong index); -ND4J_EXPORT Nd4jLong getRandomGeneratorRelativeLong(OpaqueRandomGenerator* ptr, Nd4jLong index); -ND4J_EXPORT void deleteRandomGenerator(OpaqueRandomGenerator* ptr); - -ND4J_EXPORT const char* runLightBenchmarkSuit(bool printOut); -ND4J_EXPORT const char* runFullBenchmarkSuit(bool printOut); +SD_EXPORT OpaqueContext* createGraphContext(int nodeId); +SD_EXPORT OpaqueRandomGenerator* getGraphContextRandomGenerator( + OpaqueContext* ptr); +SD_EXPORT void ctxAllowHelpers(OpaqueContext* ptr, bool reallyAllow); +SD_EXPORT void ctxShapeFunctionOverride(OpaqueContext* ptr, + bool reallyOverride); +SD_EXPORT void ctxSetExecutionMode(OpaqueContext* ptr, int execMode); +SD_EXPORT void ctxPurge(OpaqueContext* ptr); +SD_EXPORT void markGraphContextInplace(OpaqueContext* ptr, bool reallyInplace); +SD_EXPORT void setGraphContextCudaContext(OpaqueContext* ptr, void* stream, + void* reductionPointer, + void* allocationPointer); +SD_EXPORT void setGraphContextInputArray(OpaqueContext* ptr, int index, + void* buffer, void* shapeInfo, + void* specialBuffer, + void* specialShapeInfo); +SD_EXPORT void setGraphContextOutputArray(OpaqueContext* ptr, int index, + void* buffer, void* shapeInfo, + void* specialBuffer, + void* specialShapeInfo); +SD_EXPORT void setGraphContextInputBuffer(OpaqueContext* ptr, int index, + OpaqueDataBuffer* buffer, + void* shapeInfo, + void* specialShapeInfo); +SD_EXPORT void setGraphContextOutputBuffer(OpaqueContext* ptr, int index, + OpaqueDataBuffer* buffer, + void* shapeInfo, + void* specialShapeInfo); +SD_EXPORT void setGraphContextDArguments(OpaqueContext* ptr, int* arguments, + int numberOfArguments); +SD_EXPORT void setGraphContextTArguments(OpaqueContext* ptr, double* arguments, + int numberOfArguments); +SD_EXPORT void setGraphContextIArguments(OpaqueContext* ptr, + Nd4jLong* arguments, + int numberOfArguments); +SD_EXPORT void setGraphContextBArguments(OpaqueContext* ptr, bool* arguments, + int numberOfArguments); +SD_EXPORT void deleteGraphContext(OpaqueContext* ptr); + +SD_EXPORT OpaqueRandomGenerator* createRandomGenerator(Nd4jLong rootSeed = 0, Nd4jLong nodeSeed = 0); +SD_EXPORT Nd4jLong getRandomGeneratorRootState(OpaqueRandomGenerator* ptr); +SD_EXPORT Nd4jLong getRandomGeneratorNodeState(OpaqueRandomGenerator* ptr); +SD_EXPORT void setRandomGeneratorStates(OpaqueRandomGenerator* ptr, Nd4jLong rootSeed = 0, Nd4jLong nodeSeed = 0); +SD_EXPORT float getRandomGeneratorRelativeFloat(OpaqueRandomGenerator* ptr, Nd4jLong index); +SD_EXPORT double getRandomGeneratorRelativeDouble(OpaqueRandomGenerator* ptr, Nd4jLong index); +SD_EXPORT int getRandomGeneratorRelativeInt(OpaqueRandomGenerator* ptr, Nd4jLong index); +SD_EXPORT Nd4jLong getRandomGeneratorRelativeLong(OpaqueRandomGenerator* ptr, Nd4jLong index); +SD_EXPORT void deleteRandomGenerator(OpaqueRandomGenerator* ptr); + +SD_EXPORT const char* runLightBenchmarkSuit(bool printOut); +SD_EXPORT const char* runFullBenchmarkSuit(bool printOut); typedef sd::LaunchContext OpaqueLaunchContext; -ND4J_EXPORT OpaqueLaunchContext* defaultLaunchContext(); -ND4J_EXPORT Nd4jPointer lcScalarPointer(OpaqueLaunchContext* lc); -ND4J_EXPORT Nd4jPointer lcReductionPointer(OpaqueLaunchContext* lc); -ND4J_EXPORT Nd4jPointer lcAllocationPointer(OpaqueLaunchContext* lc); -ND4J_EXPORT Nd4jPointer lcExecutionStream(OpaqueLaunchContext* lc); -ND4J_EXPORT Nd4jPointer lcCopyStream(OpaqueLaunchContext* lc); -ND4J_EXPORT Nd4jPointer lcBlasHandle(OpaqueLaunchContext* lc); -ND4J_EXPORT Nd4jPointer lcSolverHandle(OpaqueLaunchContext* lc); - -ND4J_EXPORT OpaqueDataBuffer* allocateDataBuffer(Nd4jLong elements, int dataType, bool allocateBoth); -ND4J_EXPORT OpaqueDataBuffer* dbAllocateDataBuffer(Nd4jLong elements, int dataType, bool allocateBoth); -ND4J_EXPORT OpaqueDataBuffer* dbCreateExternalDataBuffer(Nd4jLong elements, int dataType, Nd4jPointer primary, Nd4jPointer special); -ND4J_EXPORT OpaqueDataBuffer* dbCreateView(OpaqueDataBuffer *dataBuffer, Nd4jLong length, Nd4jLong offset); -ND4J_EXPORT Nd4jPointer dbPrimaryBuffer(OpaqueDataBuffer *dataBuffer); -ND4J_EXPORT Nd4jPointer dbSpecialBuffer(OpaqueDataBuffer *dataBuffer); -ND4J_EXPORT void dbExpandBuffer(OpaqueDataBuffer *dataBuffer, Nd4jLong elements); -ND4J_EXPORT void dbAllocatePrimaryBuffer(OpaqueDataBuffer *dataBuffer); -ND4J_EXPORT void dbAllocateSpecialBuffer(OpaqueDataBuffer *dataBuffer); -ND4J_EXPORT void dbSetPrimaryBuffer(OpaqueDataBuffer *dataBuffer, Nd4jPointer primaryBuffer, Nd4jLong numBytes); -ND4J_EXPORT void dbSetSpecialBuffer(OpaqueDataBuffer *dataBuffer, Nd4jPointer specialBuffer, Nd4jLong numBytes); -ND4J_EXPORT void dbSyncToSpecial(OpaqueDataBuffer *dataBuffer); -ND4J_EXPORT void dbSyncToPrimary(OpaqueDataBuffer *dataBuffer); -ND4J_EXPORT int dbLocality(OpaqueDataBuffer *dataBuffer); -ND4J_EXPORT int dbDeviceId(OpaqueDataBuffer *dataBuffer); -ND4J_EXPORT void dbSetDeviceId(OpaqueDataBuffer *dataBuffer, int deviceId); -ND4J_EXPORT void dbTickHostRead(OpaqueDataBuffer *dataBuffer); -ND4J_EXPORT void dbTickHostWrite(OpaqueDataBuffer *dataBuffer); -ND4J_EXPORT void dbTickDeviceRead(OpaqueDataBuffer *dataBuffer); -ND4J_EXPORT void dbTickDeviceWrite(OpaqueDataBuffer *dataBuffer); -ND4J_EXPORT void dbClose(OpaqueDataBuffer *dataBuffer); -ND4J_EXPORT void deleteDataBuffer(OpaqueDataBuffer *dataBuffer); -ND4J_EXPORT void dbExpand(OpaqueDataBuffer *dataBuffer, Nd4jLong elements); - - -ND4J_EXPORT int binaryLevel(); -ND4J_EXPORT int optimalLevel(); - -ND4J_EXPORT bool isMinimalRequirementsMet(); -ND4J_EXPORT bool isOptimalRequirementsMet(); - +SD_EXPORT OpaqueLaunchContext* defaultLaunchContext(); +SD_EXPORT Nd4jPointer lcScalarPointer(OpaqueLaunchContext* lc); +SD_EXPORT Nd4jPointer lcReductionPointer(OpaqueLaunchContext* lc); +SD_EXPORT Nd4jPointer lcAllocationPointer(OpaqueLaunchContext* lc); +SD_EXPORT Nd4jPointer lcExecutionStream(OpaqueLaunchContext* lc); +SD_EXPORT Nd4jPointer lcCopyStream(OpaqueLaunchContext* lc); +SD_EXPORT Nd4jPointer lcBlasHandle(OpaqueLaunchContext* lc); +SD_EXPORT Nd4jPointer lcSolverHandle(OpaqueLaunchContext* lc); + +SD_EXPORT OpaqueDataBuffer* allocateDataBuffer(Nd4jLong elements, int dataType, + bool allocateBoth); +SD_EXPORT OpaqueDataBuffer* dbAllocateDataBuffer(Nd4jLong elements, + int dataType, + bool allocateBoth); +SD_EXPORT OpaqueDataBuffer* dbCreateExternalDataBuffer(Nd4jLong elements, + int dataType, + Nd4jPointer primary, + Nd4jPointer special); +SD_EXPORT OpaqueDataBuffer* dbCreateView(OpaqueDataBuffer* dataBuffer, + Nd4jLong length, Nd4jLong offset); +SD_EXPORT Nd4jPointer dbPrimaryBuffer(OpaqueDataBuffer* dataBuffer); +SD_EXPORT Nd4jPointer dbSpecialBuffer(OpaqueDataBuffer* dataBuffer); +SD_EXPORT void dbExpandBuffer(OpaqueDataBuffer* dataBuffer, Nd4jLong elements); +SD_EXPORT void dbAllocatePrimaryBuffer(OpaqueDataBuffer* dataBuffer); +SD_EXPORT void dbAllocateSpecialBuffer(OpaqueDataBuffer* dataBuffer); +SD_EXPORT void dbSetPrimaryBuffer(OpaqueDataBuffer* dataBuffer, + Nd4jPointer primaryBuffer, Nd4jLong numBytes); +SD_EXPORT void dbSetSpecialBuffer(OpaqueDataBuffer* dataBuffer, + Nd4jPointer specialBuffer, Nd4jLong numBytes); +SD_EXPORT void dbSyncToSpecial(OpaqueDataBuffer* dataBuffer); +SD_EXPORT void dbSyncToPrimary(OpaqueDataBuffer* dataBuffer); +SD_EXPORT int dbLocality(OpaqueDataBuffer* dataBuffer); +SD_EXPORT int dbDeviceId(OpaqueDataBuffer* dataBuffer); +SD_EXPORT void dbSetDeviceId(OpaqueDataBuffer* dataBuffer, int deviceId); +SD_EXPORT void dbTickHostRead(OpaqueDataBuffer* dataBuffer); +SD_EXPORT void dbTickHostWrite(OpaqueDataBuffer* dataBuffer); +SD_EXPORT void dbTickDeviceRead(OpaqueDataBuffer* dataBuffer); +SD_EXPORT void dbTickDeviceWrite(OpaqueDataBuffer* dataBuffer); +SD_EXPORT void dbClose(OpaqueDataBuffer* dataBuffer); +SD_EXPORT void deleteDataBuffer(OpaqueDataBuffer* dataBuffer); +SD_EXPORT void dbExpand(OpaqueDataBuffer* dataBuffer, Nd4jLong elements); + +SD_EXPORT int binaryLevel(); +SD_EXPORT int optimalLevel(); + +SD_EXPORT bool isMinimalRequirementsMet(); +SD_EXPORT bool isOptimalRequirementsMet(); } -#endif //NATIVEOPERATIONS_NATIVEOPS_H +#endif // NATIVEOPERATIONS_NATIVEOPS_H diff --git a/libnd4j/include/legacy/cpu/NativeOpExecutioner.cpp b/libnd4j/include/legacy/cpu/NativeOpExecutioner.cpp index 6b6c51a1375a..40b347bff963 100644 --- a/libnd4j/include/legacy/cpu/NativeOpExecutioner.cpp +++ b/libnd4j/include/legacy/cpu/NativeOpExecutioner.cpp @@ -14,82 +14,69 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - -#include -#include #include "legacy/NativeOpExecutioner.h" -#include +#include +#include +#include #include - -#include +#include #include -#include - -#include #include -#include - -#include -#include -#include -#include -#include - -#include -#include -#include -#include - -#include #include +#include +#include #include -#include +#include #include +#include +#include +#include +#include +#include +#include +#include #include +#include +#include +#include #include -#include -#include +#include #include -#include -#include -#include +#include +#include #ifdef _OPENMP -#include #include +#include #endif - - - //////////////////////////////////////////////////////////////////////// /** -* -* @param opNum -* @param hX -* @param hXShapeInfo -* @param extraParams -* @param hZ -* @param hZShapeInfo -*/ -void NativeOpExecutioner::execIndexReduceScalar(sd::LaunchContext *lc, int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *extraParams, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo) { - - - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - auto hz = reinterpret_cast(hZ); - - BUILD_DOUBLE_SELECTOR(xType, zType, hz[0] = functions::indexreduce::IndexReduce, ::execScalar(opNum,hX,hXShapeInfo,extraParams), LIBND4J_TYPES, INDEXING_TYPES); + * + * @param opNum + * @param hX + * @param hXShapeInfo + * @param extraParams + * @param hZ + * @param hZShapeInfo + */ +void NativeOpExecutioner::execIndexReduceScalar( + sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + void *extraParams, void *hZ, const Nd4jLong *hZShapeInfo, void *dZ, + const Nd4jLong *dZShapeInfo) { + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + auto hz = reinterpret_cast(hZ); + + BUILD_DOUBLE_SELECTOR(xType, zType, + hz[0] = functions::indexreduce::IndexReduce, + ::execScalar(opNum, hX, hXShapeInfo, extraParams), + LIBND4J_TYPES, INDEXING_TYPES); } //////////////////////////////////////////////////////////////////////// @@ -105,22 +92,21 @@ void NativeOpExecutioner::execIndexReduceScalar(sd::LaunchContext *lc, int opNu * @param dimensionLength */ -void NativeOpExecutioner::execIndexReduce(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *extraParams, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { - - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - auto hz = reinterpret_cast(hZ); - - BUILD_DOUBLE_SELECTOR(xType, zType, functions::indexreduce::IndexReduce, ::exec(opNum, hX, hXShapeInfo, extraParams, hz, hZShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets), LIBND4J_TYPES, INDEXING_TYPES); +void NativeOpExecutioner::execIndexReduce( + sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + void *extraParams, void *hZ, const Nd4jLong *hZShapeInfo, void *dZ, + const Nd4jLong *dZShapeInfo, int *dimension, int dimensionLength, + const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + auto hz = reinterpret_cast(hZ); + + BUILD_DOUBLE_SELECTOR( + xType, zType, functions::indexreduce::IndexReduce, + ::exec(opNum, hX, hXShapeInfo, extraParams, hz, hZShapeInfo, dimension, + dimensionLength, tadShapeInfo, tadOffsets), + LIBND4J_TYPES, INDEXING_TYPES); } //////////////////////////////////////////////////////////////////////// @@ -137,557 +123,597 @@ void NativeOpExecutioner::execIndexReduce(sd::LaunchContext *lc, * @param dimensionLength */ -void NativeOpExecutioner::execBroadcast(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - const void *hY, const Nd4jLong *hYShapeInfo, - const void *dY, const Nd4jLong *dYShapeInfo, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *tadOnlyShapeInfoZ,const Nd4jLong *tadOffsetsZ) { - - - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto yType = sd::ArrayOptions::dataType(hYShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); +void NativeOpExecutioner::execBroadcast( + sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + const void *hY, const Nd4jLong *hYShapeInfo, const void *dY, + const Nd4jLong *dYShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo, + void *dZ, const Nd4jLong *dZShapeInfo, int *dimension, int dimensionLength, + const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets, + const Nd4jLong *tadOnlyShapeInfoZ, const Nd4jLong *tadOffsetsZ) { + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto yType = sd::ArrayOptions::dataType(hYShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) - return; + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) return; #ifdef __ND4J_EXPERIMENTAL__ - BUILD_PAIRWISE_SELECTOR(xType, yType, zType, functions::broadcast::Broadcast, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES, LIBND4J_TYPES); + BUILD_PAIRWISE_SELECTOR( + xType, yType, zType, functions::broadcast::Broadcast, + ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, + dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, + tadOnlyShapeInfoZ, tadOffsetsZ), + LIBND4J_TYPES, LIBND4J_TYPES); #else - auto loopKind = sd::LoopKind::deduceKindOfLoopBroadcast(hXShapeInfo, hYShapeInfo, hZShapeInfo); - - auto func = PRAGMA_THREADS_FOR { - BUILD_SINGLE_SELECTOR_THRICE(xType, functions::broadcast::Broadcast, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ, loopKind, start, stop), LIBND4J_TYPES); - }; - - Nd4jLong numTads = 0; - - switch (loopKind) { - case sd::LoopKind::BROADCAST_SCALAR_X: { - numTads = shape::length(hXShapeInfo); - } - break; - case sd::LoopKind::BROADCAST_SCALAR_Y: { - numTads = shape::length(hYShapeInfo); - } - break; - case sd::LoopKind::BROADCAST_3D: { - numTads = shape::sizeAt(hZShapeInfo, 0); - } - break; - case sd::LoopKind::BROADCAST_4D: { - numTads = shape::sizeAt(hZShapeInfo, 0) * shape::sizeAt(hZShapeInfo, 1); - } - break; - case sd::LoopKind::BROADCAST_5D: { - numTads = shape::sizeAt(hZShapeInfo, 0) * shape::sizeAt(hZShapeInfo, 1); - } - break; - default: { - auto xLen = shape::length(hXShapeInfo); - auto yLen = shape::length(hYShapeInfo); - numTads = xLen / yLen; - } + auto loopKind = sd::LoopKind::deduceKindOfLoopBroadcast( + hXShapeInfo, hYShapeInfo, hZShapeInfo); + + auto func = PRAGMA_THREADS_FOR { + BUILD_SINGLE_SELECTOR_THRICE( + xType, functions::broadcast::Broadcast, + ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, + dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, + tadOnlyShapeInfoZ, tadOffsetsZ, loopKind, start, stop), + LIBND4J_TYPES); + }; + + Nd4jLong numTads = 0; + + switch (loopKind) { + case sd::LoopKind::BROADCAST_SCALAR_X: { + numTads = shape::length(hXShapeInfo); + } break; + case sd::LoopKind::BROADCAST_SCALAR_Y: { + numTads = shape::length(hYShapeInfo); + } break; + case sd::LoopKind::BROADCAST_3D: { + numTads = shape::sizeAt(hZShapeInfo, 0); + } break; + case sd::LoopKind::BROADCAST_4D: { + numTads = shape::sizeAt(hZShapeInfo, 0) * shape::sizeAt(hZShapeInfo, 1); + } break; + case sd::LoopKind::BROADCAST_5D: { + numTads = shape::sizeAt(hZShapeInfo, 0) * shape::sizeAt(hZShapeInfo, 1); + } break; + default: { + auto xLen = shape::length(hXShapeInfo); + auto yLen = shape::length(hYShapeInfo); + numTads = xLen / yLen; } + } - samediff::Threads::parallel_tad(func, 0, numTads); + samediff::Threads::parallel_tad(func, 0, numTads); #endif } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execBroadcast(sd::LaunchContext* lc, const int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - const void *hY, const Nd4jLong *hYShapeInfo, - const void *dY, const Nd4jLong *dYShapeInfo, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo) { - - if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) - return; - - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto yType = sd::ArrayOptions::dataType(hYShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - #ifdef __ND4J_EXPERIMENTAL__ - BUILD_PAIRWISE_SELECTOR(xType, yType, zType, functions::broadcast::Broadcast, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo), LIBND4J_TYPES, LIBND4J_TYPES); - #else - BUILD_SINGLE_SELECTOR_THRICE(xType, functions::broadcast::Broadcast, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo), LIBND4J_TYPES); - #endif -} - -void NativeOpExecutioner::execInverseBroadcast(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - const void *hY, const Nd4jLong *hYShapeInfo, - const void *dY, const Nd4jLong *dYShapeInfo, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *tadOnlyShapeInfoZ,const Nd4jLong *tadOffsetsZ) { - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto yType = sd::ArrayOptions::dataType(hYShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) - return; - - if (!sd::Environment::getInstance().isExperimentalBuild()) - if ((yType != xType && yType != sd::DataType::BOOL) || xType != zType) - throw sd::datatype_exception::build("NativeOps::execBroadcast both operands must have same data type", xType, yType); +void NativeOpExecutioner::execBroadcast( + sd::LaunchContext *lc, const int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + const void *hY, const Nd4jLong *hYShapeInfo, const void *dY, + const Nd4jLong *dYShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo, + void *dZ, const Nd4jLong *dZShapeInfo) { + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) return; + + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto yType = sd::ArrayOptions::dataType(hYShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); #ifdef __ND4J_EXPERIMENTAL__ - BUILD_PAIRWISE_SELECTOR(xType, yType, zType, functions::broadcast::Broadcast, ::execInverse(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES, LIBND4J_TYPES); + BUILD_PAIRWISE_SELECTOR( + xType, yType, zType, functions::broadcast::Broadcast, + ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo), + LIBND4J_TYPES, LIBND4J_TYPES); #else - auto func = PRAGMA_THREADS_FOR { - BUILD_SINGLE_SELECTOR_THRICE(xType, functions::broadcast::Broadcast, ::execInverse(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ, start, stop), LIBND4J_TYPES); - }; + BUILD_SINGLE_SELECTOR_THRICE( + xType, functions::broadcast::Broadcast, + ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo), + LIBND4J_TYPES); +#endif +} - auto xLen = shape::length(hXShapeInfo); - auto yLen = shape::length(hYShapeInfo); - auto numTads = yLen / xLen; +void NativeOpExecutioner::execInverseBroadcast( + sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + const void *hY, const Nd4jLong *hYShapeInfo, const void *dY, + const Nd4jLong *dYShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo, + void *dZ, const Nd4jLong *dZShapeInfo, int *dimension, int dimensionLength, + const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets, + const Nd4jLong *tadOnlyShapeInfoZ, const Nd4jLong *tadOffsetsZ) { + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto yType = sd::ArrayOptions::dataType(hYShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) return; + + if (!sd::Environment::getInstance().isExperimentalBuild()) + if ((yType != xType && yType != sd::DataType::BOOL) || xType != zType) + throw sd::datatype_exception::build( + "NativeOps::execBroadcast both operands must have same data type", + xType, yType); - samediff::Threads::parallel_tad(func, 0, numTads); +#ifdef __ND4J_EXPERIMENTAL__ + BUILD_PAIRWISE_SELECTOR( + xType, yType, zType, functions::broadcast::Broadcast, + ::execInverse(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, + dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, + tadOnlyShapeInfoZ, tadOffsetsZ), + LIBND4J_TYPES, LIBND4J_TYPES); +#else + auto func = PRAGMA_THREADS_FOR { + BUILD_SINGLE_SELECTOR_THRICE( + xType, functions::broadcast::Broadcast, + ::execInverse(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, + dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, + tadOnlyShapeInfoZ, tadOffsetsZ, start, stop), + LIBND4J_TYPES); + }; + + auto xLen = shape::length(hXShapeInfo); + auto yLen = shape::length(hYShapeInfo); + auto numTads = yLen / xLen; + + samediff::Threads::parallel_tad(func, 0, numTads); #endif - } - //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execBroadcastBool(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - const void *hY, const Nd4jLong *hYShapeInfo, - const void *dY, const Nd4jLong *dYShapeInfo, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - void *extraParams, - int *dimension, int dimensionLength, - const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *tadOnlyShapeInfoZ,const Nd4jLong *tadOffsetsZ) { - - - if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) - return; - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - auto func = PRAGMA_THREADS_FOR { - BUILD_DOUBLE_SELECTOR(xType, zType, functions::broadcast::BroadcastBool, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, extraParams, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ, start, stop), LIBND4J_TYPES, BOOL_TYPES); - }; - - auto xLen = shape::length(hXShapeInfo); - auto yLen = shape::length(hYShapeInfo); - auto numTads = xLen / yLen; - - samediff::Threads::parallel_tad(func, 0, numTads); +void NativeOpExecutioner::execBroadcastBool( + sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + const void *hY, const Nd4jLong *hYShapeInfo, const void *dY, + const Nd4jLong *dYShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo, + void *dZ, const Nd4jLong *dZShapeInfo, void *extraParams, int *dimension, + int dimensionLength, const Nd4jLong *tadOnlyShapeInfo, + const Nd4jLong *tadOffsets, const Nd4jLong *tadOnlyShapeInfoZ, + const Nd4jLong *tadOffsetsZ) { + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) return; + + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + auto func = PRAGMA_THREADS_FOR { + BUILD_DOUBLE_SELECTOR( + xType, zType, functions::broadcast::BroadcastBool, + ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, + extraParams, dimension, dimensionLength, tadOnlyShapeInfo, + tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ, start, stop), + LIBND4J_TYPES, BOOL_TYPES); + }; + + auto xLen = shape::length(hXShapeInfo); + auto yLen = shape::length(hYShapeInfo); + auto numTads = xLen / yLen; + + samediff::Threads::parallel_tad(func, 0, numTads); } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execBroadcastBool(sd::LaunchContext* lc, const int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - const void *hY, const Nd4jLong *hYShapeInfo, - const void *dY, const Nd4jLong *dYShapeInfo, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - void *extraParams) { - - if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) - return; - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - BUILD_DOUBLE_SELECTOR(xType, zType, functions::broadcast::BroadcastBool, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, extraParams), LIBND4J_TYPES, BOOL_TYPES); +void NativeOpExecutioner::execBroadcastBool( + sd::LaunchContext *lc, const int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + const void *hY, const Nd4jLong *hYShapeInfo, const void *dY, + const Nd4jLong *dYShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo, + void *dZ, const Nd4jLong *dZShapeInfo, void *extraParams) { + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) return; + + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + BUILD_DOUBLE_SELECTOR(xType, zType, functions::broadcast::BroadcastBool, + ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, + hZShapeInfo, extraParams), + LIBND4J_TYPES, BOOL_TYPES); } - -void NativeOpExecutioner::execInverseBroadcastBool(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - const void *hY, const Nd4jLong *hYShapeInfo, - const void *dY, const Nd4jLong *dYShapeInfo, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - void *extraParams, - int *dimension, int dimensionLength, - const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *tadOnlyShapeInfoZ,const Nd4jLong *tadOffsetsZ) { - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto yType = sd::ArrayOptions::dataType(hYShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) - return; - - if (!sd::Environment::getInstance().isExperimentalBuild()) - if (yType != xType || sd::DataType::BOOL != zType) - throw sd::datatype_exception::build("NativeOps::execInverseBroadcastBool both operands must have same data type", xType, yType); - - auto func = PRAGMA_THREADS_FOR { - BUILD_DOUBLE_SELECTOR(xType, zType, functions::broadcast::BroadcastBool, ::execInverse(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, extraParams, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ, start, stop), LIBND4J_TYPES, BOOL_TYPES); - }; - - auto xLen = shape::length(hXShapeInfo); - auto yLen = shape::length(hYShapeInfo); - auto numTads = yLen / xLen; - - samediff::Threads::parallel_tad(func, 0, numTads); +void NativeOpExecutioner::execInverseBroadcastBool( + sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + const void *hY, const Nd4jLong *hYShapeInfo, const void *dY, + const Nd4jLong *dYShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo, + void *dZ, const Nd4jLong *dZShapeInfo, void *extraParams, int *dimension, + int dimensionLength, const Nd4jLong *tadOnlyShapeInfo, + const Nd4jLong *tadOffsets, const Nd4jLong *tadOnlyShapeInfoZ, + const Nd4jLong *tadOffsetsZ) { + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto yType = sd::ArrayOptions::dataType(hYShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) return; + + if (!sd::Environment::getInstance().isExperimentalBuild()) + if (yType != xType || sd::DataType::BOOL != zType) + throw sd::datatype_exception::build( + "NativeOps::execInverseBroadcastBool both operands must have same " + "data type", + xType, yType); + + auto func = PRAGMA_THREADS_FOR { + BUILD_DOUBLE_SELECTOR( + xType, zType, functions::broadcast::BroadcastBool, + ::execInverse(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, + extraParams, dimension, dimensionLength, tadOnlyShapeInfo, + tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ, start, stop), + LIBND4J_TYPES, BOOL_TYPES); + }; + + auto xLen = shape::length(hXShapeInfo); + auto yLen = shape::length(hYShapeInfo); + auto numTads = yLen / xLen; + + samediff::Threads::parallel_tad(func, 0, numTads); } - - //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execBroadcastInt(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - const void *hY, const Nd4jLong *hYShapeInfo, - const void *dY, const Nd4jLong *dYShapeInfo, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *tadOnlyShapeInfoZ,const Nd4jLong *tadOffsetsZ) { - - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto yType = sd::ArrayOptions::dataType(hYShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) - return; - - if (xType != yType || xType != zType) - throw sd::datatype_exception::build("NativeOpExecutioner::execBroadcastInt", zType, xType, yType); - - if (!sd::DataTypeUtils::isZ(zType)) - throw sd::datatype_exception::build("NativeOpExecutioner::execBroadcastInt requires integer data type", zType); - - auto func = PRAGMA_THREADS_FOR { - BUILD_SINGLE_SELECTOR(xType, functions::broadcast::BroadcastInt, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ, start, stop), INTEGER_TYPES); - }; - - auto xLen = shape::length(hXShapeInfo); - auto yLen = shape::length(hYShapeInfo); - auto numTads = xLen / yLen; - - samediff::Threads::parallel_tad(func, 0, numTads); +void NativeOpExecutioner::execBroadcastInt( + sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + const void *hY, const Nd4jLong *hYShapeInfo, const void *dY, + const Nd4jLong *dYShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo, + void *dZ, const Nd4jLong *dZShapeInfo, int *dimension, int dimensionLength, + const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets, + const Nd4jLong *tadOnlyShapeInfoZ, const Nd4jLong *tadOffsetsZ) { + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto yType = sd::ArrayOptions::dataType(hYShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) return; + + if (xType != yType || xType != zType) + throw sd::datatype_exception::build("NativeOpExecutioner::execBroadcastInt", + zType, xType, yType); + + if (!sd::DataTypeUtils::isZ(zType)) + throw sd::datatype_exception::build( + "NativeOpExecutioner::execBroadcastInt requires integer data type", + zType); + + auto func = PRAGMA_THREADS_FOR { + BUILD_SINGLE_SELECTOR( + xType, functions::broadcast::BroadcastInt, + ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, + dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, + tadOnlyShapeInfoZ, tadOffsetsZ, start, stop), + INTEGER_TYPES); + }; + + auto xLen = shape::length(hXShapeInfo); + auto yLen = shape::length(hYShapeInfo); + auto numTads = xLen / yLen; + + samediff::Threads::parallel_tad(func, 0, numTads); } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execBroadcastInt(sd::LaunchContext *lc, const int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - const void *hY, const Nd4jLong *hYShapeInfo, - const void *dY, const Nd4jLong *dYShapeInfo, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo) { - - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto yType = sd::ArrayOptions::dataType(hYShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) - return; - - if (xType != yType || xType != zType) - throw sd::datatype_exception::build("NativeOpExecutioner::execBroadcastInt", zType, xType, yType); - - if (!sd::DataTypeUtils::isZ(zType)) - throw sd::datatype_exception::build("NativeOpExecutioner::execBroadcastInt requires integer data type", zType); - - BUILD_SINGLE_SELECTOR(xType, functions::broadcast::BroadcastInt, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo), INTEGER_TYPES); - +void NativeOpExecutioner::execBroadcastInt( + sd::LaunchContext *lc, const int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + const void *hY, const Nd4jLong *hYShapeInfo, const void *dY, + const Nd4jLong *dYShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo, + void *dZ, const Nd4jLong *dZShapeInfo) { + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto yType = sd::ArrayOptions::dataType(hYShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) return; + + if (xType != yType || xType != zType) + throw sd::datatype_exception::build("NativeOpExecutioner::execBroadcastInt", + zType, xType, yType); + + if (!sd::DataTypeUtils::isZ(zType)) + throw sd::datatype_exception::build( + "NativeOpExecutioner::execBroadcastInt requires integer data type", + zType); + + BUILD_SINGLE_SELECTOR( + xType, functions::broadcast::BroadcastInt, + ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo), + INTEGER_TYPES); } -void NativeOpExecutioner::execInverseBroadcastInt(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - const void *hY, const Nd4jLong *hYShapeInfo, - const void *dY, const Nd4jLong *dYShapeInfo, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *tadOnlyShapeInfoZ,const Nd4jLong *tadOffsetsZ) { - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto yType = sd::ArrayOptions::dataType(hYShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) - return; - - if (xType != yType || xType != zType) - throw sd::datatype_exception::build("NativeOpExecutioner::execInverseBroadcastInt", zType, xType, yType); - - if (!sd::DataTypeUtils::isZ(zType)) - throw sd::datatype_exception::build("NativeOpExecutioner::execInverseBroadcastInt requires integer data type", zType); - - auto func = PRAGMA_THREADS_FOR { - BUILD_SINGLE_SELECTOR(xType, functions::broadcast::BroadcastInt,::execInverse(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ, start, stop), INTEGER_TYPES); - }; - - auto xLen = shape::length(hXShapeInfo); - auto yLen = shape::length(hYShapeInfo); - auto numTads = yLen / xLen; - - samediff::Threads::parallel_tad(func, 0, numTads); +void NativeOpExecutioner::execInverseBroadcastInt( + sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + const void *hY, const Nd4jLong *hYShapeInfo, const void *dY, + const Nd4jLong *dYShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo, + void *dZ, const Nd4jLong *dZShapeInfo, int *dimension, int dimensionLength, + const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets, + const Nd4jLong *tadOnlyShapeInfoZ, const Nd4jLong *tadOffsetsZ) { + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto yType = sd::ArrayOptions::dataType(hYShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) return; + + if (xType != yType || xType != zType) + throw sd::datatype_exception::build( + "NativeOpExecutioner::execInverseBroadcastInt", zType, xType, yType); + + if (!sd::DataTypeUtils::isZ(zType)) + throw sd::datatype_exception::build( + "NativeOpExecutioner::execInverseBroadcastInt requires integer data " + "type", + zType); + + auto func = PRAGMA_THREADS_FOR { + BUILD_SINGLE_SELECTOR( + xType, functions::broadcast::BroadcastInt, + ::execInverse(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, + dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, + tadOnlyShapeInfoZ, tadOffsetsZ, start, stop), + INTEGER_TYPES); + }; + + auto xLen = shape::length(hXShapeInfo); + auto yLen = shape::length(hYShapeInfo); + auto numTads = yLen / xLen; + + samediff::Threads::parallel_tad(func, 0, numTads); } //////////////////////////////////////////////////////////////////////// /** -* -* @param opNum -* @param hX -* @param xStride -* @param hY -* @param yStride -* @param hZ -* @param resultStride -* @param extraParams -* @param n -*/ -void NativeOpExecutioner::execPairwiseTransform(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - const void *hY, const Nd4jLong *hYShapeInfo, - const void *dY, const Nd4jLong *dYShapeInfo, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - void *extraParams) { - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto yType = sd::ArrayOptions::dataType(hYShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) - return; + * + * @param opNum + * @param hX + * @param xStride + * @param hY + * @param yStride + * @param hZ + * @param resultStride + * @param extraParams + * @param n + */ +void NativeOpExecutioner::execPairwiseTransform( + sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + const void *hY, const Nd4jLong *hYShapeInfo, const void *dY, + const Nd4jLong *dYShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo, + void *dZ, const Nd4jLong *dZShapeInfo, void *extraParams) { + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto yType = sd::ArrayOptions::dataType(hYShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) return; #ifdef __ND4J_EXPERIMENTAL__ - BUILD_PAIRWISE_SELECTOR(xType, yType, zType, functions::pairwise_transforms::PairWiseTransform, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, extraParams), LIBND4J_TYPES, LIBND4J_TYPES); + BUILD_PAIRWISE_SELECTOR(xType, yType, zType, + functions::pairwise_transforms::PairWiseTransform, + ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, + hZShapeInfo, extraParams), + LIBND4J_TYPES, LIBND4J_TYPES); #else - auto func = PRAGMA_THREADS_FOR { - BUILD_SINGLE_SELECTOR_THRICE(xType, functions::pairwise_transforms::PairWiseTransform, - ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, extraParams, start, stop), - LIBND4J_TYPES); - }; - - auto zLen = shape::length(hZShapeInfo); - samediff::Threads::parallel_for(func, 0, zLen, 1, sd::math::nd4j_max(1, sd::math::nd4j_min(zLen / 1024, sd::Environment::getInstance().maxMasterThreads()))); + auto func = PRAGMA_THREADS_FOR { + BUILD_SINGLE_SELECTOR_THRICE( + xType, functions::pairwise_transforms::PairWiseTransform, + ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, + extraParams, start, stop), + LIBND4J_TYPES); + }; + + auto zLen = shape::length(hZShapeInfo); + samediff::Threads::parallel_for( + func, 0, zLen, 1, + sd::math::nd4j_max( + 1, sd::math::nd4j_min( + zLen / 1024, + sd::Environment::getInstance().maxMasterThreads()))); #endif } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execPairwiseBoolTransform(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - const void *hY, const Nd4jLong *hYShapeInfo, - const void *dY, const Nd4jLong *dYShapeInfo, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - void *extraParams) { - - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto yType = sd::ArrayOptions::dataType(hYShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) - return; - - if (xType != yType) - throw sd::datatype_exception::build("NativeOpExecutioner::execPairwiseBoolTransform", xType, yType); - - if (zType != sd::DataType::BOOL) - throw sd::datatype_exception::build("NativeOpExecutioner::execPairwiseBoolTransform", sd::DataType::BOOL, zType); - - auto func = PRAGMA_THREADS_FOR { - BUILD_DOUBLE_SELECTOR(xType, zType, functions::pairwise_transforms::PairWiseBoolTransform, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, extraParams, start, stop), LIBND4J_TYPES, BOOL_TYPES); - }; - - auto zLen = shape::length(hZShapeInfo); - samediff::Threads::parallel_for(func, 0, zLen, 1, sd::math::nd4j_max(1, sd::math::nd4j_min(zLen / 1024, sd::Environment::getInstance().maxMasterThreads()))); - +void NativeOpExecutioner::execPairwiseBoolTransform( + sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + const void *hY, const Nd4jLong *hYShapeInfo, const void *dY, + const Nd4jLong *dYShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo, + void *dZ, const Nd4jLong *dZShapeInfo, void *extraParams) { + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto yType = sd::ArrayOptions::dataType(hYShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) return; + + if (xType != yType) + throw sd::datatype_exception::build( + "NativeOpExecutioner::execPairwiseBoolTransform", xType, yType); + + if (zType != sd::DataType::BOOL) + throw sd::datatype_exception::build( + "NativeOpExecutioner::execPairwiseBoolTransform", sd::DataType::BOOL, + zType); + + auto func = PRAGMA_THREADS_FOR { + BUILD_DOUBLE_SELECTOR(xType, zType, + functions::pairwise_transforms::PairWiseBoolTransform, + ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, + hZShapeInfo, extraParams, start, stop), + LIBND4J_TYPES, BOOL_TYPES); + }; + + auto zLen = shape::length(hZShapeInfo); + samediff::Threads::parallel_for( + func, 0, zLen, 1, + sd::math::nd4j_max( + 1, sd::math::nd4j_min( + zLen / 1024, + sd::Environment::getInstance().maxMasterThreads()))); } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execPairwiseIntTransform(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - const void *hY, const Nd4jLong *hYShapeInfo, - const void *dY, const Nd4jLong *dYShapeInfo, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - void *extraParams) { - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto yType = sd::ArrayOptions::dataType(hYShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) - return; - - if (xType != yType || xType != zType) - throw sd::datatype_exception::build("NativeOpExecutioner::execPairwiseIntTransform", zType, xType, yType); - - if (!sd::DataTypeUtils::isZ(zType)) - throw sd::datatype_exception::build("NativeOpExecutioner::execSPairwiseInt requires integer data type", zType); - - auto func = PRAGMA_THREADS_FOR { - BUILD_SINGLE_SELECTOR(xType, functions::pairwise_transforms::PairWiseIntTransform, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, extraParams, start, stop), INTEGER_TYPES); - }; - - auto zLen = shape::length(hZShapeInfo); - samediff::Threads::parallel_for(func, 0, zLen, 1, sd::math::nd4j_max(1, sd::math::nd4j_min(zLen / 1024, sd::Environment::getInstance().maxMasterThreads()))); - +void NativeOpExecutioner::execPairwiseIntTransform( + sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + const void *hY, const Nd4jLong *hYShapeInfo, const void *dY, + const Nd4jLong *dYShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo, + void *dZ, const Nd4jLong *dZShapeInfo, void *extraParams) { + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto yType = sd::ArrayOptions::dataType(hYShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) return; + + if (xType != yType || xType != zType) + throw sd::datatype_exception::build( + "NativeOpExecutioner::execPairwiseIntTransform", zType, xType, yType); + + if (!sd::DataTypeUtils::isZ(zType)) + throw sd::datatype_exception::build( + "NativeOpExecutioner::execSPairwiseInt requires integer data type", + zType); + + auto func = PRAGMA_THREADS_FOR { + BUILD_SINGLE_SELECTOR(xType, + functions::pairwise_transforms::PairWiseIntTransform, + ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, + hZShapeInfo, extraParams, start, stop), + INTEGER_TYPES); + }; + + auto zLen = shape::length(hZShapeInfo); + samediff::Threads::parallel_for( + func, 0, zLen, 1, + sd::math::nd4j_max( + 1, sd::math::nd4j_min( + zLen / 1024, + sd::Environment::getInstance().maxMasterThreads()))); } //////////////////////////////////////////////////////////////////////// /** -* -* @param opNum -* @param hX -* @param hXShapeInfo -* @param extraParams -* @param hZ -* @param hZShapeInfo -*/ -void NativeOpExecutioner::execReduceFloat(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *extraParams, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { - - - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - // nothing to do here if result is empty - if (shape::isEmpty(hZShapeInfo)) - return; - - auto func = PRAGMA_THREADS_FOR { - BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce::ReduceFloatFunction, ::exec(opNum, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets, start, stop), LIBND4J_TYPES, FLOAT_TYPES); - }; - - const sd::LoopKind::Kind kindOfLoop = sd::LoopKind::deduceKindOfLoopTadXZ(hXShapeInfo, hZShapeInfo, tadShapeInfo); - - samediff::Threads::parallel_tad(func, 0, shape::length(hZShapeInfo), 1, kindOfLoop == sd::LoopKind::Kind::SMALLARR2DX ? 1 : sd::Environment::getInstance().maxMasterThreads()); + * + * @param opNum + * @param hX + * @param hXShapeInfo + * @param extraParams + * @param hZ + * @param hZShapeInfo + */ +void NativeOpExecutioner::execReduceFloat( + sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + void *extraParams, void *hZ, const Nd4jLong *hZShapeInfo, void *dZ, + const Nd4jLong *dZShapeInfo, int *dimension, int dimensionLength, + const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + // nothing to do here if result is empty + if (shape::isEmpty(hZShapeInfo)) return; + + auto func = PRAGMA_THREADS_FOR { + BUILD_DOUBLE_SELECTOR( + xType, zType, functions::reduce::ReduceFloatFunction, + ::exec(opNum, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo, dimension, + dimensionLength, tadShapeInfo, tadOffsets, start, stop), + LIBND4J_TYPES, FLOAT_TYPES); + }; + + const sd::LoopKind::Kind kindOfLoop = sd::LoopKind::deduceKindOfLoopTadXZ( + hXShapeInfo, hZShapeInfo, tadShapeInfo); + + samediff::Threads::parallel_tad( + func, 0, shape::length(hZShapeInfo), 1, + kindOfLoop == sd::LoopKind::Kind::SMALLARR2DX + ? 1 + : sd::Environment::getInstance().maxMasterThreads()); } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execReduceSame(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *extraParams, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { - - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - // nothing to do here if result is empty - if (shape::isEmpty(hZShapeInfo)) - return; - - auto func = PRAGMA_THREADS_FOR { - BUILD_SINGLE_SELECTOR(xType, functions::reduce::ReduceSameFunction, ::exec(opNum, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets, start, stop), LIBND4J_TYPES); - }; - - const sd::LoopKind::Kind kindOfLoop = sd::LoopKind::deduceKindOfLoopTadXZ(hXShapeInfo, hZShapeInfo, tadShapeInfo); - - samediff::Threads::parallel_tad(func, 0, shape::length(hZShapeInfo), 1, kindOfLoop == sd::LoopKind::Kind::SMALLARR2DX ? 1 : sd::Environment::getInstance().maxMasterThreads()); +void NativeOpExecutioner::execReduceSame( + sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + void *extraParams, void *hZ, const Nd4jLong *hZShapeInfo, void *dZ, + const Nd4jLong *dZShapeInfo, int *dimension, int dimensionLength, + const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + // nothing to do here if result is empty + if (shape::isEmpty(hZShapeInfo)) return; + + auto func = PRAGMA_THREADS_FOR { + BUILD_SINGLE_SELECTOR( + xType, functions::reduce::ReduceSameFunction, + ::exec(opNum, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo, dimension, + dimensionLength, tadShapeInfo, tadOffsets, start, stop), + LIBND4J_TYPES); + }; + + const sd::LoopKind::Kind kindOfLoop = sd::LoopKind::deduceKindOfLoopTadXZ( + hXShapeInfo, hZShapeInfo, tadShapeInfo); + + samediff::Threads::parallel_tad( + func, 0, shape::length(hZShapeInfo), 1, + kindOfLoop == sd::LoopKind::Kind::SMALLARR2DX + ? 1 + : sd::Environment::getInstance().maxMasterThreads()); } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execReduceBool(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *extraParams, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { - - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - // nothing to do here if result is empty - if (shape::isEmpty(hZShapeInfo)) - return; - - auto func = PRAGMA_THREADS_FOR { - BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce::ReduceBoolFunction, ::exec(opNum, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets, start, stop), LIBND4J_TYPES, BOOL_TYPES); - }; - - const sd::LoopKind::Kind kindOfLoop = sd::LoopKind::deduceKindOfLoopTadXZ(hXShapeInfo, hZShapeInfo, tadShapeInfo); - - samediff::Threads::parallel_tad(func, 0, shape::length(hZShapeInfo), 1, kindOfLoop == sd::LoopKind::Kind::SMALLARR2DX ? 1 : sd::Environment::getInstance().maxMasterThreads()); +void NativeOpExecutioner::execReduceBool( + sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + void *extraParams, void *hZ, const Nd4jLong *hZShapeInfo, void *dZ, + const Nd4jLong *dZShapeInfo, int *dimension, int dimensionLength, + const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + // nothing to do here if result is empty + if (shape::isEmpty(hZShapeInfo)) return; + + auto func = PRAGMA_THREADS_FOR { + BUILD_DOUBLE_SELECTOR( + xType, zType, functions::reduce::ReduceBoolFunction, + ::exec(opNum, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo, dimension, + dimensionLength, tadShapeInfo, tadOffsets, start, stop), + LIBND4J_TYPES, BOOL_TYPES); + }; + + const sd::LoopKind::Kind kindOfLoop = sd::LoopKind::deduceKindOfLoopTadXZ( + hXShapeInfo, hZShapeInfo, tadShapeInfo); + + samediff::Threads::parallel_tad( + func, 0, shape::length(hZShapeInfo), 1, + kindOfLoop == sd::LoopKind::Kind::SMALLARR2DX + ? 1 + : sd::Environment::getInstance().maxMasterThreads()); } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execReduceLong(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *extraParams, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { - - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - // nothing to do here if result is empty - if (shape::isEmpty(hZShapeInfo)) - return; - - auto func = PRAGMA_THREADS_FOR { - BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce::ReduceLongFunction, ::exec(opNum, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets, start, stop), LIBND4J_TYPES, LONG_TYPES); - }; - - const sd::LoopKind::Kind kindOfLoop = sd::LoopKind::deduceKindOfLoopTadXZ(hXShapeInfo, hZShapeInfo, tadShapeInfo); - - samediff::Threads::parallel_tad(func, 0, shape::length(hZShapeInfo), 1, kindOfLoop == sd::LoopKind::Kind::SMALLARR2DX ? 1 : sd::Environment::getInstance().maxMasterThreads()); +void NativeOpExecutioner::execReduceLong( + sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + void *extraParams, void *hZ, const Nd4jLong *hZShapeInfo, void *dZ, + const Nd4jLong *dZShapeInfo, int *dimension, int dimensionLength, + const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + // nothing to do here if result is empty + if (shape::isEmpty(hZShapeInfo)) return; + + auto func = PRAGMA_THREADS_FOR { + BUILD_DOUBLE_SELECTOR( + xType, zType, functions::reduce::ReduceLongFunction, + ::exec(opNum, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo, dimension, + dimensionLength, tadShapeInfo, tadOffsets, start, stop), + LIBND4J_TYPES, LONG_TYPES); + }; + + const sd::LoopKind::Kind kindOfLoop = sd::LoopKind::deduceKindOfLoopTadXZ( + hXShapeInfo, hZShapeInfo, tadShapeInfo); + + samediff::Threads::parallel_tad( + func, 0, shape::length(hZShapeInfo), 1, + kindOfLoop == sd::LoopKind::Kind::SMALLARR2DX + ? 1 + : sd::Environment::getInstance().maxMasterThreads()); } //////////////////////////////////////////////////////////////////////// @@ -699,67 +725,64 @@ void NativeOpExecutioner::execReduceLong(sd::LaunchContext *lc, * @param extraParams * @return */ -void NativeOpExecutioner::execReduceFloatScalar(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *extraParams, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo) { - - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce::ReduceFloatFunction, ::execScalar(opNum, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo), LIBND4J_TYPES, FLOAT_TYPES); +void NativeOpExecutioner::execReduceFloatScalar( + sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + void *extraParams, void *hZ, const Nd4jLong *hZShapeInfo, void *dZ, + const Nd4jLong *dZShapeInfo) { + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + BUILD_DOUBLE_SELECTOR( + xType, zType, functions::reduce::ReduceFloatFunction, + ::execScalar(opNum, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo), + LIBND4J_TYPES, FLOAT_TYPES); } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execReduceSameScalar(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *extraParams, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo) { - - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - - BUILD_SINGLE_SELECTOR(xType, functions::reduce::ReduceSameFunction, ::execScalar(opNum, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo), LIBND4J_TYPES); +void NativeOpExecutioner::execReduceSameScalar( + sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + void *extraParams, void *hZ, const Nd4jLong *hZShapeInfo, void *dZ, + const Nd4jLong *dZShapeInfo) { + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + + BUILD_SINGLE_SELECTOR( + xType, functions::reduce::ReduceSameFunction, + ::execScalar(opNum, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo), + LIBND4J_TYPES); } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execReduceBoolScalar(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *extraParams, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo) { - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce::ReduceBoolFunction, ::execScalar(opNum, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo), LIBND4J_TYPES, BOOL_TYPES); +void NativeOpExecutioner::execReduceBoolScalar( + sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + void *extraParams, void *hZ, const Nd4jLong *hZShapeInfo, void *dZ, + const Nd4jLong *dZShapeInfo) { + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + BUILD_DOUBLE_SELECTOR( + xType, zType, functions::reduce::ReduceBoolFunction, + ::execScalar(opNum, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo), + LIBND4J_TYPES, BOOL_TYPES); } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execReduceLongScalar(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *extraParams, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo) { - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce::ReduceLongFunction, ::execScalar(opNum, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo), LIBND4J_TYPES, LONG_TYPES); +void NativeOpExecutioner::execReduceLongScalar( + sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + void *extraParams, void *hZ, const Nd4jLong *hZShapeInfo, void *dZ, + const Nd4jLong *dZShapeInfo) { + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + BUILD_DOUBLE_SELECTOR( + xType, zType, functions::reduce::ReduceLongFunction, + ::execScalar(opNum, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo), + LIBND4J_TYPES, LONG_TYPES); } - //////////////////////////////////////////////////////////////////////// /** * @@ -774,648 +797,705 @@ void NativeOpExecutioner::execReduceLongScalar(sd::LaunchContext *lc, * @param dimension * @param dimensionLength */ -void NativeOpExecutioner::execReduce3Scalar(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *extraParamsVals, - const void *hY, const Nd4jLong *hYShapeInfo, - const void *dY, const Nd4jLong *dYShapeInfo, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo) { - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - - BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce3::Reduce3, ::execScalar(opNum, hX, hXShapeInfo, extraParamsVals, hY, hYShapeInfo, hZ, hZShapeInfo), LIBND4J_TYPES, FLOAT_TYPES); +void NativeOpExecutioner::execReduce3Scalar( + sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + void *extraParamsVals, const void *hY, const Nd4jLong *hYShapeInfo, + const void *dY, const Nd4jLong *dYShapeInfo, void *hZ, + const Nd4jLong *hZShapeInfo, void *dZ, const Nd4jLong *dZShapeInfo) { + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce3::Reduce3, + ::execScalar(opNum, hX, hXShapeInfo, extraParamsVals, + hY, hYShapeInfo, hZ, hZShapeInfo), + LIBND4J_TYPES, FLOAT_TYPES); } //////////////////////////////////////////////////////////////////////// /** -* -* @param opNum -* @param hX -* @param hXShapeInfo -* @param extraParamsVals -* @param hY -* @param hYShapeInfo -* @param hZ -* @param hZShapeInfo -*/ -void NativeOpExecutioner::execReduce3(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *extraParamsVals, - const void *hY, const Nd4jLong *hYShapeInfo, - const void *dY, const Nd4jLong *dYShapeInfo, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo) { - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - //BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce3::Reduce3, ::exec(opNum, hX, hXShapeInfo, extraParamsVals, hY, hYShapeInfo, hZ, hZShapeInfo, nullptr, 0), LIBND4J_TYPES, FLOAT_TYPES); - NativeOpExecutioner::execReduce3Scalar(lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParamsVals, hY, hYShapeInfo, dY, dYShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo); + * + * @param opNum + * @param hX + * @param hXShapeInfo + * @param extraParamsVals + * @param hY + * @param hYShapeInfo + * @param hZ + * @param hZShapeInfo + */ +void NativeOpExecutioner::execReduce3( + sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + void *extraParamsVals, const void *hY, const Nd4jLong *hYShapeInfo, + const void *dY, const Nd4jLong *dYShapeInfo, void *hZ, + const Nd4jLong *hZShapeInfo, void *dZ, const Nd4jLong *dZShapeInfo) { + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + // BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce3::Reduce3, + // ::exec(opNum, hX, hXShapeInfo, extraParamsVals, hY, hYShapeInfo, hZ, + // hZShapeInfo, nullptr, 0), LIBND4J_TYPES, FLOAT_TYPES); + NativeOpExecutioner::execReduce3Scalar( + lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParamsVals, hY, + hYShapeInfo, dY, dYShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo); } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execReduce3(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *extraParamsVals, - const void *hY, const Nd4jLong *hYShapeInfo, - const void *dY, const Nd4jLong *dYShapeInfo, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *xTadOnlyShapeInfo, const Nd4jLong *xTadOffsets, - const Nd4jLong *yTadOnlyShapeInfo, const Nd4jLong *yTadOffsets) { - - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - const auto xLen = shape::length(hXShapeInfo); - const auto yLen = shape::length(hYShapeInfo); - - sd::TadPack tadPack; - - if(xLen == yLen) { - tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(hXShapeInfo, dimension, dimensionLength); - } - else if(yLen > xLen) { - tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(hYShapeInfo, dimension, dimensionLength); - } - else { - tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(hXShapeInfo, dimension, dimensionLength); - } - - auto func = PRAGMA_THREADS_FOR { - BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce3::Reduce3, ::exec(opNum, hX, hXShapeInfo, extraParamsVals, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, start, stop), LIBND4J_TYPES, FLOAT_TYPES); - }; - - samediff::Threads::parallel_tad(func, 0, tadPack.numberOfTads()); +void NativeOpExecutioner::execReduce3( + sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + void *extraParamsVals, const void *hY, const Nd4jLong *hYShapeInfo, + const void *dY, const Nd4jLong *dYShapeInfo, void *hZ, + const Nd4jLong *hZShapeInfo, void *dZ, const Nd4jLong *dZShapeInfo, + int *dimension, int dimensionLength, const Nd4jLong *xTadOnlyShapeInfo, + const Nd4jLong *xTadOffsets, const Nd4jLong *yTadOnlyShapeInfo, + const Nd4jLong *yTadOffsets) { + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + const auto xLen = shape::length(hXShapeInfo); + const auto yLen = shape::length(hYShapeInfo); + + sd::TadPack tadPack; + + if (xLen == yLen) { + tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions( + hXShapeInfo, dimension, dimensionLength); + } else if (yLen > xLen) { + tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions( + hYShapeInfo, dimension, dimensionLength); + } else { + tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions( + hXShapeInfo, dimension, dimensionLength); + } + + auto func = PRAGMA_THREADS_FOR { + BUILD_DOUBLE_SELECTOR( + xType, zType, functions::reduce3::Reduce3, + ::exec(opNum, hX, hXShapeInfo, extraParamsVals, hY, hYShapeInfo, hZ, + hZShapeInfo, dimension, dimensionLength, start, stop), + LIBND4J_TYPES, FLOAT_TYPES); + }; + + samediff::Threads::parallel_tad(func, 0, tadPack.numberOfTads()); } - //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execReduce3All(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *extraParamsVals, - const void *hY, const Nd4jLong *hYShapeInfo, - const void *dY, const Nd4jLong *dYShapeInfo, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *xTadShapeInfo, const Nd4jLong *xOffsets, - const Nd4jLong *yTadShapeInfo, const Nd4jLong *yOffsets) { - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(hXShapeInfo, dimension, dimensionLength); - - // TODO: make it 2d - auto func = PRAGMA_THREADS_FOR { - BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce3::Reduce3, ::execAll(opNum, hX, hXShapeInfo, extraParamsVals, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, xTadShapeInfo, xOffsets, yTadShapeInfo, yOffsets, start, stop), LIBND4J_TYPES, FLOAT_TYPES); - }; - - samediff::Threads::parallel_tad(func, 0, tadPack.numberOfTads()); +void NativeOpExecutioner::execReduce3All( + sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + void *extraParamsVals, const void *hY, const Nd4jLong *hYShapeInfo, + const void *dY, const Nd4jLong *dYShapeInfo, void *hZ, + const Nd4jLong *hZShapeInfo, void *dZ, const Nd4jLong *dZShapeInfo, + int *dimension, int dimensionLength, const Nd4jLong *xTadShapeInfo, + const Nd4jLong *xOffsets, const Nd4jLong *yTadShapeInfo, + const Nd4jLong *yOffsets) { + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions( + hXShapeInfo, dimension, dimensionLength); + + // TODO: make it 2d + auto func = PRAGMA_THREADS_FOR { + BUILD_DOUBLE_SELECTOR( + xType, zType, functions::reduce3::Reduce3, + ::execAll(opNum, hX, hXShapeInfo, extraParamsVals, hY, hYShapeInfo, hZ, + hZShapeInfo, dimension, dimensionLength, xTadShapeInfo, + xOffsets, yTadShapeInfo, yOffsets, start, stop), + LIBND4J_TYPES, FLOAT_TYPES); + }; + + samediff::Threads::parallel_tad(func, 0, tadPack.numberOfTads()); } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execReduce3TAD(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *extraParamsVals, - const void *hY, const Nd4jLong *hYShapeInfo, - const void *dY, const Nd4jLong *dYShapeInfo, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *yTadShapeInfo, const Nd4jLong *yTadOffsets) { - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - const auto xLen = shape::length(hXShapeInfo); - const auto yLen = shape::length(hYShapeInfo); - - sd::TadPack tadPack; - - if(xLen == yLen) { - tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(hXShapeInfo, dimension, dimensionLength); - } - else if(yLen > xLen) { - tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(hYShapeInfo, dimension, dimensionLength); - } - else { - tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(hXShapeInfo, dimension, dimensionLength); - } - - auto func = PRAGMA_THREADS_FOR { - BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce3::Reduce3, ::exec(opNum, hX, hXShapeInfo, extraParamsVals, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets, start, stop), LIBND4J_TYPES, FLOAT_TYPES); - }; - - samediff::Threads::parallel_tad(func, 0, tadPack.numberOfTads()); +void NativeOpExecutioner::execReduce3TAD( + sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + void *extraParamsVals, const void *hY, const Nd4jLong *hYShapeInfo, + const void *dY, const Nd4jLong *dYShapeInfo, void *hZ, + const Nd4jLong *hZShapeInfo, void *dZ, const Nd4jLong *dZShapeInfo, + int *dimension, int dimensionLength, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets, const Nd4jLong *yTadShapeInfo, + const Nd4jLong *yTadOffsets) { + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + const auto xLen = shape::length(hXShapeInfo); + const auto yLen = shape::length(hYShapeInfo); + + sd::TadPack tadPack; + + if (xLen == yLen) { + tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions( + hXShapeInfo, dimension, dimensionLength); + } else if (yLen > xLen) { + tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions( + hYShapeInfo, dimension, dimensionLength); + } else { + tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions( + hXShapeInfo, dimension, dimensionLength); + } + + auto func = PRAGMA_THREADS_FOR { + BUILD_DOUBLE_SELECTOR( + xType, zType, functions::reduce3::Reduce3, + ::exec(opNum, hX, hXShapeInfo, extraParamsVals, hY, hYShapeInfo, hZ, + hZShapeInfo, dimension, dimensionLength, tadShapeInfo, + tadOffsets, start, stop), + LIBND4J_TYPES, FLOAT_TYPES); + }; + + samediff::Threads::parallel_tad(func, 0, tadPack.numberOfTads()); } - //////////////////////////////////////////////////////////////////////// /** -* -* @param opNum -* @param hX -* @param xStride -* @param hZ -* @param resultStride -* @param scalar -* @param extraParams -* @param n -*/ -void NativeOpExecutioner::execScalar(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - const void *hScalar, const Nd4jLong *hScalarShapeInfo, - const void *dScalar, const Nd4jLong *dScalarShapeInfo, - void *extraParams, - bool allowParallelism) { - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto yType = sd::ArrayOptions::dataType(hScalarShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hScalarShapeInfo)) - return; + * + * @param opNum + * @param hX + * @param xStride + * @param hZ + * @param resultStride + * @param scalar + * @param extraParams + * @param n + */ +void NativeOpExecutioner::execScalar( + sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + void *hZ, const Nd4jLong *hZShapeInfo, void *dZ, + const Nd4jLong *dZShapeInfo, const void *hScalar, + const Nd4jLong *hScalarShapeInfo, const void *dScalar, + const Nd4jLong *dScalarShapeInfo, void *extraParams, + bool allowParallelism) { + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto yType = sd::ArrayOptions::dataType(hScalarShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hScalarShapeInfo)) return; #ifdef __ND4J_EXPERIMENTAL__ - BUILD_PAIRWISE_SELECTOR(xType, yType, zType, functions::scalar::ScalarTransform, ::transform(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, hScalar, extraParams), LIBND4J_TYPES, LIBND4J_TYPES); + BUILD_PAIRWISE_SELECTOR(xType, yType, zType, + functions::scalar::ScalarTransform, + ::transform(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, + hScalar, extraParams), + LIBND4J_TYPES, LIBND4J_TYPES); #else - if (xType != yType || xType != zType) - throw sd::datatype_exception::build("NativeOpExecutioner::execScalar", zType, xType, yType); - - auto func = PRAGMA_THREADS_FOR { - BUILD_SINGLE_SELECTOR_THRICE(xType, functions::scalar::ScalarTransform,::transform(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, hScalar, extraParams, start, stop), LIBND4J_TYPES); - }; - - auto zLen = shape::length(hZShapeInfo); - samediff::Threads::parallel_for(func, 0, zLen, 1, !allowParallelism ? 1 : sd::math::nd4j_max(1, sd::math::nd4j_min(zLen / 1024, sd::Environment::getInstance().maxMasterThreads()))); + if (xType != yType || xType != zType) + throw sd::datatype_exception::build("NativeOpExecutioner::execScalar", + zType, xType, yType); + + auto func = PRAGMA_THREADS_FOR { + BUILD_SINGLE_SELECTOR_THRICE( + xType, functions::scalar::ScalarTransform, + ::transform(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, hScalar, + extraParams, start, stop), + LIBND4J_TYPES); + }; + + auto zLen = shape::length(hZShapeInfo); + samediff::Threads::parallel_for( + func, 0, zLen, 1, + !allowParallelism + ? 1 + : sd::math::nd4j_max( + 1, sd::math::nd4j_min( + zLen / 1024, + sd::Environment::getInstance().maxMasterThreads()))); #endif } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execScalar(sd::LaunchContext *lc, - int opNum, - void const* hX, Nd4jLong const*hXShapeInfo, - void const* dX, Nd4jLong const*dXShapeInfo, - void *extraParams, - void *hZ, Nd4jLong const*hZShapeInfo, - void *dZ, Nd4jLong const*dZShapeInfo, - void const* hScalars, Nd4jLong const*hScalarShapeInfo, - void const* dScalars, Nd4jLong const*dScalarShapeInfo, - int *dimension, int dimensionLength, - Nd4jLong const*tadShapeInfo, Nd4jLong const*tadOffsets, - Nd4jLong const*tadShapeInfoZ, Nd4jLong const*tadOffsetsZ) { - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto yType = sd::ArrayOptions::dataType(hScalarShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hScalarShapeInfo)) - return; +void NativeOpExecutioner::execScalar( + sd::LaunchContext *lc, int opNum, void const *hX, + Nd4jLong const *hXShapeInfo, void const *dX, Nd4jLong const *dXShapeInfo, + void *extraParams, void *hZ, Nd4jLong const *hZShapeInfo, void *dZ, + Nd4jLong const *dZShapeInfo, void const *hScalars, + Nd4jLong const *hScalarShapeInfo, void const *dScalars, + Nd4jLong const *dScalarShapeInfo, int *dimension, int dimensionLength, + Nd4jLong const *tadShapeInfo, Nd4jLong const *tadOffsets, + Nd4jLong const *tadShapeInfoZ, Nd4jLong const *tadOffsetsZ) { + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto yType = sd::ArrayOptions::dataType(hScalarShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hScalarShapeInfo)) return; #ifdef __ND4J_EXPERIMENTAL__ - BUILD_PAIRWISE_SELECTOR(xType, yType, zType, functions::scalar::ScalarTransform, ::transform(opNum, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo, hScalars, dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES, LIBND4J_TYPES); + BUILD_PAIRWISE_SELECTOR( + xType, yType, zType, functions::scalar::ScalarTransform, + ::transform(opNum, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo, + hScalars, dimension, dimensionLength, tadShapeInfo, + tadOffsets, tadShapeInfoZ, tadOffsetsZ), + LIBND4J_TYPES, LIBND4J_TYPES); #else - if (xType != yType || xType != zType) - throw sd::datatype_exception::build("NativeOpExecutioner::execScalar", zType, xType, yType); - - auto func = PRAGMA_THREADS_FOR { - BUILD_SINGLE_SELECTOR_THRICE(xType, functions::scalar::ScalarTransform, ::transform(opNum, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo, hScalars, dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ, start, stop), LIBND4J_TYPES); - }; - - auto yLen = shape::length(hScalarShapeInfo); - samediff::Threads::parallel_tad(func, 0, yLen, 1, sd::math::nd4j_min(yLen, sd::Environment::getInstance().maxMasterThreads())); + if (xType != yType || xType != zType) + throw sd::datatype_exception::build("NativeOpExecutioner::execScalar", + zType, xType, yType); + + auto func = PRAGMA_THREADS_FOR { + BUILD_SINGLE_SELECTOR_THRICE( + xType, functions::scalar::ScalarTransform, + ::transform(opNum, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo, + hScalars, dimension, dimensionLength, tadShapeInfo, + tadOffsets, tadShapeInfoZ, tadOffsetsZ, start, stop), + LIBND4J_TYPES); + }; + + auto yLen = shape::length(hScalarShapeInfo); + samediff::Threads::parallel_tad( + func, 0, yLen, 1, + sd::math::nd4j_min( + yLen, sd::Environment::getInstance().maxMasterThreads())); #endif } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execScalarBool(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - const void *hScalar, const Nd4jLong *hSscalarShapeInfo, - const void *dScalar, const Nd4jLong *dSscalarShapeInfo, - void *extraParams, - bool allowParallelism) { - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto yType = sd::ArrayOptions::dataType(hSscalarShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hSscalarShapeInfo)) - return; - - if (xType != yType) - throw sd::datatype_exception::build("NativeOpExecutioner::execScalarBool", xType, yType); - - if (zType != sd::DataType::BOOL) - throw sd::datatype_exception::build("NativeOpExecutioner::execScalarBool", sd::DataType::BOOL, zType); - - auto func = PRAGMA_THREADS_FOR { - BUILD_DOUBLE_SELECTOR(xType, zType, functions::scalar::ScalarBoolTransform, ::transform(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, hScalar, extraParams, start, stop), LIBND4J_TYPES, BOOL_TYPES); - }; - - auto zLen = shape::length(hZShapeInfo); - samediff::Threads::parallel_for(func, 0, zLen, 1, !allowParallelism ? 1 : sd::math::nd4j_max(1, sd::math::nd4j_min(zLen / 1024, sd::Environment::getInstance().maxMasterThreads()))); - +void NativeOpExecutioner::execScalarBool( + sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + void *hZ, const Nd4jLong *hZShapeInfo, void *dZ, + const Nd4jLong *dZShapeInfo, const void *hScalar, + const Nd4jLong *hSscalarShapeInfo, const void *dScalar, + const Nd4jLong *dSscalarShapeInfo, void *extraParams, + bool allowParallelism) { + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto yType = sd::ArrayOptions::dataType(hSscalarShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hSscalarShapeInfo)) return; + + if (xType != yType) + throw sd::datatype_exception::build("NativeOpExecutioner::execScalarBool", + xType, yType); + + if (zType != sd::DataType::BOOL) + throw sd::datatype_exception::build("NativeOpExecutioner::execScalarBool", + sd::DataType::BOOL, zType); + + auto func = PRAGMA_THREADS_FOR { + BUILD_DOUBLE_SELECTOR(xType, zType, functions::scalar::ScalarBoolTransform, + ::transform(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, + hScalar, extraParams, start, stop), + LIBND4J_TYPES, BOOL_TYPES); + }; + + auto zLen = shape::length(hZShapeInfo); + samediff::Threads::parallel_for( + func, 0, zLen, 1, + !allowParallelism + ? 1 + : sd::math::nd4j_max( + 1, sd::math::nd4j_min( + zLen / 1024, + sd::Environment::getInstance().maxMasterThreads()))); } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execScalarBool(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *extraParams, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - const void *hScalars, const Nd4jLong *hScalarShapeInfo, - const void *dScalars, const Nd4jLong *dScalarShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetsZ) { - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto yType = sd::ArrayOptions::dataType(hScalarShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hScalarShapeInfo)) - return; - - if (xType != yType) - throw sd::datatype_exception::build("NativeOpExecutioner::execScalarBool", xType, yType); - - if (zType != sd::DataType::BOOL) - throw sd::datatype_exception::build("NativeOpExecutioner::execScalarBool", sd::DataType::BOOL, zType); - - auto func = PRAGMA_THREADS_FOR { - BUILD_DOUBLE_SELECTOR(xType, zType, functions::scalar::ScalarBoolTransform, ::transform(opNum, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo, hScalars, dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ, start, stop), LIBND4J_TYPES, BOOL_TYPES); - }; - - auto yLen = shape::length(hScalarShapeInfo); - samediff::Threads::parallel_tad(func, 0, yLen, 1, sd::math::nd4j_min(yLen, sd::Environment::getInstance().maxMasterThreads())); +void NativeOpExecutioner::execScalarBool( + sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + void *extraParams, void *hZ, const Nd4jLong *hZShapeInfo, void *dZ, + const Nd4jLong *dZShapeInfo, const void *hScalars, + const Nd4jLong *hScalarShapeInfo, const void *dScalars, + const Nd4jLong *dScalarShapeInfo, int *dimension, int dimensionLength, + const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets, + const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetsZ) { + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto yType = sd::ArrayOptions::dataType(hScalarShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hScalarShapeInfo)) return; + + if (xType != yType) + throw sd::datatype_exception::build("NativeOpExecutioner::execScalarBool", + xType, yType); + + if (zType != sd::DataType::BOOL) + throw sd::datatype_exception::build("NativeOpExecutioner::execScalarBool", + sd::DataType::BOOL, zType); + + auto func = PRAGMA_THREADS_FOR { + BUILD_DOUBLE_SELECTOR( + xType, zType, functions::scalar::ScalarBoolTransform, + ::transform(opNum, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo, + hScalars, dimension, dimensionLength, tadShapeInfo, + tadOffsets, tadShapeInfoZ, tadOffsetsZ, start, stop), + LIBND4J_TYPES, BOOL_TYPES); + }; + + auto yLen = shape::length(hScalarShapeInfo); + samediff::Threads::parallel_tad( + func, 0, yLen, 1, + sd::math::nd4j_min( + yLen, sd::Environment::getInstance().maxMasterThreads())); } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execScalarInt(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - const void *hScalar, const Nd4jLong *hSscalarShapeInfo, - const void *dScalar, const Nd4jLong *dSscalarShapeInfo, - void *extraParams, - bool allowParallelism) { - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto yType = sd::ArrayOptions::dataType(hSscalarShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hSscalarShapeInfo)) - return; - - if (xType != yType || xType != zType) - throw sd::datatype_exception::build("NativeOpExecutioner::execScalarInt", xType, yType); - - if (!sd::DataTypeUtils::isZ(zType)) - throw sd::datatype_exception::build("NativeOpExecutioner::execScalarInt", sd::DataType::INT32, zType); - - auto func = PRAGMA_THREADS_FOR { - BUILD_SINGLE_SELECTOR(xType, functions::scalar::ScalarIntTransform, ::transform(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, hScalar, extraParams, start, stop), INTEGER_TYPES); - }; - - auto zLen = shape::length(hZShapeInfo); - samediff::Threads::parallel_for(func, 0, zLen, 1, !allowParallelism ? 1 : sd::math::nd4j_max(1, sd::math::nd4j_min(zLen / 1024, sd::Environment::getInstance().maxMasterThreads()))); - +void NativeOpExecutioner::execScalarInt( + sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + void *hZ, const Nd4jLong *hZShapeInfo, void *dZ, + const Nd4jLong *dZShapeInfo, const void *hScalar, + const Nd4jLong *hSscalarShapeInfo, const void *dScalar, + const Nd4jLong *dSscalarShapeInfo, void *extraParams, + bool allowParallelism) { + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto yType = sd::ArrayOptions::dataType(hSscalarShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hSscalarShapeInfo)) return; + + if (xType != yType || xType != zType) + throw sd::datatype_exception::build("NativeOpExecutioner::execScalarInt", + xType, yType); + + if (!sd::DataTypeUtils::isZ(zType)) + throw sd::datatype_exception::build("NativeOpExecutioner::execScalarInt", + sd::DataType::INT32, zType); + + auto func = PRAGMA_THREADS_FOR { + BUILD_SINGLE_SELECTOR(xType, functions::scalar::ScalarIntTransform, + ::transform(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, + hScalar, extraParams, start, stop), + INTEGER_TYPES); + }; + + auto zLen = shape::length(hZShapeInfo); + samediff::Threads::parallel_for( + func, 0, zLen, 1, + !allowParallelism + ? 1 + : sd::math::nd4j_max( + 1, sd::math::nd4j_min( + zLen / 1024, + sd::Environment::getInstance().maxMasterThreads()))); } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execScalarInt(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *extraParams, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - const void *hScalars, const Nd4jLong *hScalarShapeInfo, - const void *dScalars, const Nd4jLong *dScalarShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetsZ) { - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto yType = sd::ArrayOptions::dataType(hScalarShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hScalarShapeInfo)) - return; - - if (xType != yType || xType != zType) - throw sd::datatype_exception::build("NativeOpExecutioner::execScalarInt", xType, yType); - - if (!sd::DataTypeUtils::isZ(zType)) - throw sd::datatype_exception::build("NativeOpExecutioner::execScalarInt requires integer data type", zType); - - auto func = PRAGMA_THREADS_FOR { - BUILD_SINGLE_SELECTOR(xType, functions::scalar::ScalarIntTransform, ::transform(opNum, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo, hScalars, dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ, start, stop), INTEGER_TYPES); - }; - - auto yLen = shape::length(hScalarShapeInfo); - samediff::Threads::parallel_tad(func, 0, yLen, 1, sd::math::nd4j_min(yLen, sd::Environment::getInstance().maxMasterThreads())); +void NativeOpExecutioner::execScalarInt( + sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + void *extraParams, void *hZ, const Nd4jLong *hZShapeInfo, void *dZ, + const Nd4jLong *dZShapeInfo, const void *hScalars, + const Nd4jLong *hScalarShapeInfo, const void *dScalars, + const Nd4jLong *dScalarShapeInfo, int *dimension, int dimensionLength, + const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets, + const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetsZ) { + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto yType = sd::ArrayOptions::dataType(hScalarShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hScalarShapeInfo)) return; + + if (xType != yType || xType != zType) + throw sd::datatype_exception::build("NativeOpExecutioner::execScalarInt", + xType, yType); + + if (!sd::DataTypeUtils::isZ(zType)) + throw sd::datatype_exception::build( + "NativeOpExecutioner::execScalarInt requires integer data type", zType); + + auto func = PRAGMA_THREADS_FOR { + BUILD_SINGLE_SELECTOR( + xType, functions::scalar::ScalarIntTransform, + ::transform(opNum, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo, + hScalars, dimension, dimensionLength, tadShapeInfo, + tadOffsets, tadShapeInfoZ, tadOffsetsZ, start, stop), + INTEGER_TYPES); + }; + + auto yLen = shape::length(hScalarShapeInfo); + samediff::Threads::parallel_tad( + func, 0, yLen, 1, + sd::math::nd4j_min( + yLen, sd::Environment::getInstance().maxMasterThreads())); } //////////////////////////////////////////////////////////////////////// /** -* -* @param opNum -* @param hX -* @param hXShapeInfo -* @param extraParams -* @param hZ -* @param hZShapeInfo -*/ -void NativeOpExecutioner::execSummaryStats(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *extraParams, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - bool biasCorrected) { - - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - BUILD_DOUBLE_SELECTOR(xType, zType, functions::summarystats::SummaryStatsReduce, ::exec(opNum, biasCorrected, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo, nullptr, 1), LIBND4J_TYPES, FLOAT_TYPES); + * + * @param opNum + * @param hX + * @param hXShapeInfo + * @param extraParams + * @param hZ + * @param hZShapeInfo + */ +void NativeOpExecutioner::execSummaryStats( + sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + void *extraParams, void *hZ, const Nd4jLong *hZShapeInfo, void *dZ, + const Nd4jLong *dZShapeInfo, bool biasCorrected) { + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + BUILD_DOUBLE_SELECTOR(xType, zType, + functions::summarystats::SummaryStatsReduce, + ::exec(opNum, biasCorrected, hX, hXShapeInfo, + extraParams, hZ, hZShapeInfo, nullptr, 1), + LIBND4J_TYPES, FLOAT_TYPES); } //////////////////////////////////////////////////////////////////////// /** -* -* @param opNum -* @param hX -* @param hXShapeInfo -* @param extraParams -* @param hZ -* @param hZShapeInfo -*/ -void NativeOpExecutioner::execSummaryStatsScalar(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *extraParams, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - bool biasCorrected) { - - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - BUILD_DOUBLE_SELECTOR(xType, zType, functions::summarystats::SummaryStatsReduce, ::execScalar(opNum, biasCorrected, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo), LIBND4J_TYPES, FLOAT_TYPES); + * + * @param opNum + * @param hX + * @param hXShapeInfo + * @param extraParams + * @param hZ + * @param hZShapeInfo + */ +void NativeOpExecutioner::execSummaryStatsScalar( + sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + void *extraParams, void *hZ, const Nd4jLong *hZShapeInfo, void *dZ, + const Nd4jLong *dZShapeInfo, bool biasCorrected) { + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + BUILD_DOUBLE_SELECTOR(xType, zType, + functions::summarystats::SummaryStatsReduce, + ::execScalar(opNum, biasCorrected, hX, hXShapeInfo, + extraParams, hZ, hZShapeInfo), + LIBND4J_TYPES, FLOAT_TYPES); } //////////////////////////////////////////////////////////////////////// /** -* -* @param opNum -* @param hX -* @param hXShapeInfo -* @param extraParams -* @param hZ -* @param hZShapeInfo -* @param dimension -* @param dimensionLength -*/ -void NativeOpExecutioner::execSummaryStats(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *extraParams, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets, - bool biasCorrected) { - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - BUILD_DOUBLE_SELECTOR(xType, zType, functions::summarystats::SummaryStatsReduce, ::exec(opNum, biasCorrected, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo, dimension, dimensionLength), LIBND4J_TYPES, FLOAT_TYPES); + * + * @param opNum + * @param hX + * @param hXShapeInfo + * @param extraParams + * @param hZ + * @param hZShapeInfo + * @param dimension + * @param dimensionLength + */ +void NativeOpExecutioner::execSummaryStats( + sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + void *extraParams, void *hZ, const Nd4jLong *hZShapeInfo, void *dZ, + const Nd4jLong *dZShapeInfo, int *dimension, int dimensionLength, + const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets, + bool biasCorrected) { + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + BUILD_DOUBLE_SELECTOR( + xType, zType, functions::summarystats::SummaryStatsReduce, + ::exec(opNum, biasCorrected, hX, hXShapeInfo, extraParams, hZ, + hZShapeInfo, dimension, dimensionLength), + LIBND4J_TYPES, FLOAT_TYPES); } - //////////////////////////////////////////////////////////////////////// /** -* -* @param opNum -* @param hX -* @param xStride -* @param hZ -* @param resultStride -* @param extraParams -* @param n -*/ -void NativeOpExecutioner::execTransformFloat(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - void *extraParams, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - if (shape::isEmpty(hXShapeInfo)) - return; - - auto func = PRAGMA_THREADS_DO { - BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformFloat, ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, extraParams, thread_id, numThreads), LIBND4J_TYPES, FLOAT_TYPES); - }; - - samediff::Threads::parallel_do(func, sd::math::nd4j_max(1, sd::math::nd4j_min(shape::length(hZShapeInfo) / 1024, sd::Environment::getInstance().maxMasterThreads()))); -} - -//////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execTransformBool(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - void *extraParams, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - if (shape::isEmpty(hXShapeInfo)) - return; - - auto func = PRAGMA_THREADS_DO { - BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformBool, ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, extraParams, thread_id, numThreads), LIBND4J_TYPES, BOOL_TYPES); - }; - - samediff::Threads::parallel_do(func, sd::math::nd4j_max(1, sd::math::nd4j_min(shape::length(hZShapeInfo) / 1024, sd::Environment::getInstance().maxMasterThreads()))); + * + * @param opNum + * @param hX + * @param xStride + * @param hZ + * @param resultStride + * @param extraParams + * @param n + */ +void NativeOpExecutioner::execTransformFloat( + sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + void *hZ, const Nd4jLong *hZShapeInfo, void *dZ, + const Nd4jLong *dZShapeInfo, void *extraParams, + const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + if (shape::isEmpty(hXShapeInfo)) return; + + auto func = PRAGMA_THREADS_DO { + BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformFloat, + ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, + extraParams, thread_id, numThreads), + LIBND4J_TYPES, FLOAT_TYPES); + }; + + samediff::Threads::parallel_do( + func, sd::math::nd4j_max( + 1, sd::math::nd4j_min( + shape::length(hZShapeInfo) / 1024, + sd::Environment::getInstance().maxMasterThreads()))); } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execTransformAny(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - void *extraParams, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets, - bool allowParallelism) { - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - if (shape::isEmpty(hXShapeInfo)) - return; - - if (opNum == sd::transform::Assign && shape::order(hXShapeInfo) == shape::order(hZShapeInfo) && shape::order(hXShapeInfo) == 'c' && xType == zType && shape::elementWiseStride(hXShapeInfo) == 1 && shape::elementWiseStride(hZShapeInfo) == 1) { - - memcpy(hZ, hX, shape::length(hXShapeInfo) * sd::DataTypeUtils::sizeOfElement(xType)); - } - else { - auto func = PRAGMA_THREADS_DO { - - BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformAny, ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, extraParams, thread_id, numThreads), LIBND4J_TYPES, LIBND4J_TYPES); - }; - - samediff::Threads::parallel_do(func, sd::math::nd4j_max(1, sd::math::nd4j_min(shape::length(hZShapeInfo) / 1024, sd::Environment::getInstance().maxMasterThreads()))); - } +void NativeOpExecutioner::execTransformBool( + sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + void *hZ, const Nd4jLong *hZShapeInfo, void *dZ, + const Nd4jLong *dZShapeInfo, void *extraParams, + const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + if (shape::isEmpty(hXShapeInfo)) return; + + auto func = PRAGMA_THREADS_DO { + BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformBool, + ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, + extraParams, thread_id, numThreads), + LIBND4J_TYPES, BOOL_TYPES); + }; + + samediff::Threads::parallel_do( + func, sd::math::nd4j_max( + 1, sd::math::nd4j_min( + shape::length(hZShapeInfo) / 1024, + sd::Environment::getInstance().maxMasterThreads()))); } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execTransformSame(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - void *extraParams, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - if (shape::isEmpty(hXShapeInfo)) - return; - +void NativeOpExecutioner::execTransformAny( + sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + void *hZ, const Nd4jLong *hZShapeInfo, void *dZ, + const Nd4jLong *dZShapeInfo, void *extraParams, + const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets, + bool allowParallelism) { + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + if (shape::isEmpty(hXShapeInfo)) return; + + if (opNum == sd::transform::Assign && + shape::order(hXShapeInfo) == shape::order(hZShapeInfo) && + shape::order(hXShapeInfo) == 'c' && xType == zType && + shape::elementWiseStride(hXShapeInfo) == 1 && + shape::elementWiseStride(hZShapeInfo) == 1) { + memcpy( + hZ, hX, + shape::length(hXShapeInfo) * sd::DataTypeUtils::sizeOfElement(xType)); + } else { auto func = PRAGMA_THREADS_DO { - BUILD_SINGLE_SELECTOR(xType, functions::transform::TransformSame, ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, extraParams, thread_id, numThreads), LIBND4J_TYPES); + BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformAny, + ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, + extraParams, thread_id, numThreads), + LIBND4J_TYPES, LIBND4J_TYPES); }; - samediff::Threads::parallel_do(func, sd::math::nd4j_max(1, sd::math::nd4j_min(shape::length(hZShapeInfo) / 1024, sd::Environment::getInstance().maxMasterThreads()))); + samediff::Threads::parallel_do( + func, sd::math::nd4j_max( + 1, sd::math::nd4j_min( + shape::length(hZShapeInfo) / 1024, + sd::Environment::getInstance().maxMasterThreads()))); + } } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execTransformStrict(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - void *extraParams, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - if (shape::isEmpty(hXShapeInfo)) - return; - - auto func = PRAGMA_THREADS_DO { - BUILD_SINGLE_SELECTOR(xType, functions::transform::TransformStrict, ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, extraParams, thread_id, numThreads), FLOAT_TYPES); - }; - - samediff::Threads::parallel_do(func, sd::math::nd4j_max(1, sd::math::nd4j_min(shape::length(hZShapeInfo) / 1024, sd::Environment::getInstance().maxMasterThreads()))); +void NativeOpExecutioner::execTransformSame( + sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + void *hZ, const Nd4jLong *hZShapeInfo, void *dZ, + const Nd4jLong *dZShapeInfo, void *extraParams, + const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + if (shape::isEmpty(hXShapeInfo)) return; + + auto func = PRAGMA_THREADS_DO { + BUILD_SINGLE_SELECTOR(xType, functions::transform::TransformSame, + ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, + extraParams, thread_id, numThreads), + LIBND4J_TYPES); + }; + + samediff::Threads::parallel_do( + func, sd::math::nd4j_max( + 1, sd::math::nd4j_min( + shape::length(hZShapeInfo) / 1024, + sd::Environment::getInstance().maxMasterThreads()))); } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execRandom(sd::LaunchContext *lc, - int opNum, - Nd4jPointer state, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - void *extraArguments) { - - - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - BUILD_SINGLE_SELECTOR(zType, functions::random::RandomFunction, ::execTransform(opNum, state, hZ, hZShapeInfo, extraArguments), FLOAT_TYPES); - - auto rng = reinterpret_cast(state); - rng->rewindH(shape::length(hZShapeInfo)); +void NativeOpExecutioner::execTransformStrict( + sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + void *hZ, const Nd4jLong *hZShapeInfo, void *dZ, + const Nd4jLong *dZShapeInfo, void *extraParams, + const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + if (shape::isEmpty(hXShapeInfo)) return; + + auto func = PRAGMA_THREADS_DO { + BUILD_SINGLE_SELECTOR(xType, functions::transform::TransformStrict, + ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, + extraParams, thread_id, numThreads), + FLOAT_TYPES); + }; + + samediff::Threads::parallel_do( + func, sd::math::nd4j_max( + 1, sd::math::nd4j_min( + shape::length(hZShapeInfo) / 1024, + sd::Environment::getInstance().maxMasterThreads()))); } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execRandom(sd::LaunchContext *lc, - int opNum, - Nd4jPointer state, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, +void NativeOpExecutioner::execRandom(sd::LaunchContext *lc, int opNum, + Nd4jPointer state, void *hZ, + const Nd4jLong *hZShapeInfo, void *dZ, + const Nd4jLong *dZShapeInfo, void *extraArguments) { + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - BUILD_SINGLE_SELECTOR(zType, functions::random::RandomFunction, ::execTransform(opNum, state, hX, hXShapeInfo, hZ, hZShapeInfo, extraArguments), FLOAT_TYPES); + BUILD_SINGLE_SELECTOR( + zType, functions::random::RandomFunction, + ::execTransform(opNum, state, hZ, hZShapeInfo, extraArguments), + FLOAT_TYPES); - auto rng = reinterpret_cast(state); - rng->rewindH(shape::length(hZShapeInfo)); + auto rng = reinterpret_cast(state); + rng->rewindH(shape::length(hZShapeInfo)); } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execRandom(sd::LaunchContext *lc, - int opNum, - Nd4jPointer state, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - const void *hY, const Nd4jLong *hYShapeInfo, - const void *dY, const Nd4jLong *dYShapeInfo, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - void *extraArguments) { - - auto xType = sd::ArrayOptions::dataType(hZShapeInfo); - - BUILD_SINGLE_SELECTOR(xType, functions::random::RandomFunction, ::execTransform(opNum, state, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, extraArguments), FLOAT_TYPES); - - auto rng = reinterpret_cast(state); - rng->rewindH(shape::length(hZShapeInfo)); +void NativeOpExecutioner::execRandom( + sd::LaunchContext *lc, int opNum, Nd4jPointer state, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + void *hZ, const Nd4jLong *hZShapeInfo, void *dZ, + const Nd4jLong *dZShapeInfo, void *extraArguments) { + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + BUILD_SINGLE_SELECTOR(zType, functions::random::RandomFunction, + ::execTransform(opNum, state, hX, hXShapeInfo, hZ, + hZShapeInfo, extraArguments), + FLOAT_TYPES); + + auto rng = reinterpret_cast(state); + rng->rewindH(shape::length(hZShapeInfo)); } - - - - - - +//////////////////////////////////////////////////////////////////////// +void NativeOpExecutioner::execRandom( + sd::LaunchContext *lc, int opNum, Nd4jPointer state, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + const void *hY, const Nd4jLong *hYShapeInfo, const void *dY, + const Nd4jLong *dYShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo, + void *dZ, const Nd4jLong *dZShapeInfo, void *extraArguments) { + auto xType = sd::ArrayOptions::dataType(hZShapeInfo); + + BUILD_SINGLE_SELECTOR( + xType, functions::random::RandomFunction, + ::execTransform(opNum, state, hX, hXShapeInfo, hY, hYShapeInfo, hZ, + hZShapeInfo, extraArguments), + FLOAT_TYPES); + + auto rng = reinterpret_cast(state); + rng->rewindH(shape::length(hZShapeInfo)); +} diff --git a/libnd4j/include/legacy/cpu/NativeOps.cpp b/libnd4j/include/legacy/cpu/NativeOps.cpp index f9e3f669c692..6ef6fd484c35 100644 --- a/libnd4j/include/legacy/cpu/NativeOps.cpp +++ b/libnd4j/include/legacy/cpu/NativeOps.cpp @@ -20,62 +20,57 @@ #define __STDC_CONSTANT_MACROS -#include -#include "legacy/NativeOpExecutioner.h" #include -#include +#include +#include #include -#include -#include -#include +#include #include #include -#include -#include -#include +#include +#include +#include #include -#include -#include - - -#include #include #include +#include +#include +#include +#include + +#include "legacy/NativeOpExecutioner.h" #ifndef _WIN32 -#include #include +#include #else -#include #include +#include #endif -#include - -#include #include - +#include +#include char *name; bool nameSet = false; - #ifdef __ND4J_EXPERIMENTAL__ bool experimentalSupport = true; #else bool experimentalSupport = false; #endif -#include -#include -#include -#include +#include #include #include -#include #include +#include +#include +#include +#include #include #include #include -#include +#include #ifdef CPU_FEATURES #include @@ -84,13 +79,11 @@ bool experimentalSupport = false; using namespace sd; void setElementThreshold(int num) { - if (num > 0) - sd::Environment::getInstance().setElementwiseThreshold(num); + if (num > 0) sd::Environment::getInstance().setElementwiseThreshold(num); } void setTADThreshold(int num) { - if (num > 0) - sd::Environment::getInstance().setTadThreshold(num); + if (num > 0) sd::Environment::getInstance().setTadThreshold(num); } /** @@ -100,17 +93,21 @@ void setTADThreshold(int num) { * @param hXShapeInfo * @param extraParams */ -void execIndexReduceScalar(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo) { - try { - NativeOpExecutioner::execIndexReduceScalar(nullptr, opNum, dbX->primary(), hXShapeInfo, dbX->special(), dXShapeInfo, extraParams, dbZ->primary(), hZShapeInfo, dbZ->special(), dZShapeInfo); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } +void execIndexReduceScalar(Nd4jPointer *extraPointers, int opNum, + OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, + const Nd4jLong *dXShapeInfo, void *extraParams, + OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, + const Nd4jLong *dZShapeInfo) { + try { + NativeOpExecutioner::execIndexReduceScalar( + nullptr, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + dXShapeInfo, extraParams, dbZ->primary(), hZShapeInfo, dbZ->special(), + dZShapeInfo); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } /** @@ -124,44 +121,36 @@ void execIndexReduceScalar(Nd4jPointer *extraPointers, * @param dimension * @param dimensionLength */ -void execIndexReduce(Nd4jPointer *extraPointers,int opNum, - OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo, - OpaqueDataBuffer *dbDimension, const Nd4jLong *hDimensionShape, const Nd4jLong *dDimensionShape) { - try { - auto dimension = reinterpret_cast(dbDimension->primary()); - int dimensionLength = static_cast(shape::length(hDimensionShape)); - - auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(hXShapeInfo, dimension, - dimensionLength); - - auto hTADShapeInfo = tadPack.primaryShapeInfo(); - auto hTADOffsets = tadPack.primaryOffsets(); - - auto hz = reinterpret_cast(dbZ->primary()); - - NativeOpExecutioner::execIndexReduce(nullptr, opNum, - dbX->primary(), - hXShapeInfo, - dbX->special(), - dXShapeInfo, - extraParams, - hz, - hZShapeInfo, - dbZ->special(), - dZShapeInfo, - dimension, - dimensionLength, - hTADShapeInfo, - hTADOffsets); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } +void execIndexReduce(Nd4jPointer *extraPointers, int opNum, + OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, + const Nd4jLong *dXShapeInfo, void *extraParams, + OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, + const Nd4jLong *dZShapeInfo, OpaqueDataBuffer *dbDimension, + const Nd4jLong *hDimensionShape, + const Nd4jLong *dDimensionShape) { + try { + auto dimension = reinterpret_cast(dbDimension->primary()); + int dimensionLength = static_cast(shape::length(hDimensionShape)); + + auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions( + hXShapeInfo, dimension, dimensionLength); + + auto hTADShapeInfo = tadPack.primaryShapeInfo(); + auto hTADOffsets = tadPack.primaryOffsets(); + + auto hz = reinterpret_cast(dbZ->primary()); + + NativeOpExecutioner::execIndexReduce( + nullptr, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + dXShapeInfo, extraParams, hz, hZShapeInfo, dbZ->special(), dZShapeInfo, + dimension, dimensionLength, hTADShapeInfo, hTADOffsets); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } - /** * * @param opNum @@ -174,83 +163,75 @@ void execIndexReduce(Nd4jPointer *extraPointers,int opNum, * @param dimension * @param dimensionLength */ -void execBroadcast(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo, - OpaqueDataBuffer *dbY, const Nd4jLong *hYShapeInfo, const Nd4jLong *dYShapeInfo, - OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo, - OpaqueDataBuffer *dbDimension, const Nd4jLong *hDimensionShape, const Nd4jLong *dDimensionShape) { - try { - auto dimension = reinterpret_cast(dbDimension->primary()); - auto dimensionLength = static_cast(shape::length(hDimensionShape)); - - auto tadPackX = sd::ConstantTadHelper::getInstance().tadForDimensions(hXShapeInfo, dimension, dimensionLength); - auto tadPackZ = sd::ConstantTadHelper::getInstance().tadForDimensions(hZShapeInfo, dimension, dimensionLength); - - auto hTADShapeInfo = tadPackX.primaryShapeInfo(); - auto hTADOffsets = tadPackX.primaryOffsets(); - auto hTADShapeInfoZ = tadPackZ.primaryShapeInfo(); - auto hTADOffsetsZ = tadPackZ.primaryOffsets(); - - NativeOpExecutioner::execBroadcast(nullptr, - opNum, - dbX->primary(), - hXShapeInfo, - dbX->special(), - dXShapeInfo, - dbY->primary(), - hYShapeInfo, - dbY->special(), - dYShapeInfo, - dbZ->primary(), hZShapeInfo, - dbZ->special(), dZShapeInfo, - dimension, - dimensionLength, hTADShapeInfo, hTADOffsets, hTADShapeInfoZ, hTADOffsetsZ); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } -} - -void execBroadcastBool(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo, - OpaqueDataBuffer *dbY, const Nd4jLong *hYShapeInfo, const Nd4jLong *dYShapeInfo, - OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbDimension, const Nd4jLong *hDimensionShape, const Nd4jLong *dDimensionShape) { - try { - auto dimension = reinterpret_cast(dbDimension->primary()); - auto dimensionLength = static_cast(shape::length(hDimensionShape)); - - auto tadPackX = sd::ConstantTadHelper::getInstance().tadForDimensions(hXShapeInfo, dimension, dimensionLength); - auto tadPackZ = sd::ConstantTadHelper::getInstance().tadForDimensions(hZShapeInfo, dimension, dimensionLength); - - auto hTADShapeInfo = tadPackX.primaryShapeInfo(); - auto hTADOffsets = tadPackX.primaryOffsets(); - auto hTADShapeInfoZ = tadPackZ.primaryShapeInfo(); - auto hTADOffsetsZ = tadPackZ.primaryOffsets(); - - NativeOpExecutioner::execBroadcastBool(nullptr, - opNum, - dbX->primary(), - hXShapeInfo, - dbX->special(), - dXShapeInfo, - dbY->primary(), - hYShapeInfo, - dbY->special(), - dYShapeInfo, - dbZ->primary(), hZShapeInfo, - dbZ->special(), dZShapeInfo, - extraParams, - dimension, - dimensionLength, hTADShapeInfo, hTADOffsets, hTADShapeInfoZ, - hTADOffsetsZ); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } +void execBroadcast(Nd4jPointer *extraPointers, int opNum, OpaqueDataBuffer *dbX, + const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbY, const Nd4jLong *hYShapeInfo, + const Nd4jLong *dYShapeInfo, OpaqueDataBuffer *dbZ, + const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbDimension, + const Nd4jLong *hDimensionShape, + const Nd4jLong *dDimensionShape) { + try { + auto dimension = reinterpret_cast(dbDimension->primary()); + auto dimensionLength = static_cast(shape::length(hDimensionShape)); + + auto tadPackX = sd::ConstantTadHelper::getInstance().tadForDimensions( + hXShapeInfo, dimension, dimensionLength); + auto tadPackZ = sd::ConstantTadHelper::getInstance().tadForDimensions( + hZShapeInfo, dimension, dimensionLength); + + auto hTADShapeInfo = tadPackX.primaryShapeInfo(); + auto hTADOffsets = tadPackX.primaryOffsets(); + auto hTADShapeInfoZ = tadPackZ.primaryShapeInfo(); + auto hTADOffsetsZ = tadPackZ.primaryOffsets(); + + NativeOpExecutioner::execBroadcast( + nullptr, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + dXShapeInfo, dbY->primary(), hYShapeInfo, dbY->special(), dYShapeInfo, + dbZ->primary(), hZShapeInfo, dbZ->special(), dZShapeInfo, dimension, + dimensionLength, hTADShapeInfo, hTADOffsets, hTADShapeInfoZ, + hTADOffsetsZ); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } +} + +void execBroadcastBool(Nd4jPointer *extraPointers, int opNum, + OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, + const Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbY, + const Nd4jLong *hYShapeInfo, const Nd4jLong *dYShapeInfo, + OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, + const Nd4jLong *dZShapeInfo, void *extraParams, + OpaqueDataBuffer *dbDimension, + const Nd4jLong *hDimensionShape, + const Nd4jLong *dDimensionShape) { + try { + auto dimension = reinterpret_cast(dbDimension->primary()); + auto dimensionLength = static_cast(shape::length(hDimensionShape)); + + auto tadPackX = sd::ConstantTadHelper::getInstance().tadForDimensions( + hXShapeInfo, dimension, dimensionLength); + auto tadPackZ = sd::ConstantTadHelper::getInstance().tadForDimensions( + hZShapeInfo, dimension, dimensionLength); + + auto hTADShapeInfo = tadPackX.primaryShapeInfo(); + auto hTADOffsets = tadPackX.primaryOffsets(); + auto hTADShapeInfoZ = tadPackZ.primaryShapeInfo(); + auto hTADOffsetsZ = tadPackZ.primaryOffsets(); + + NativeOpExecutioner::execBroadcastBool( + nullptr, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + dXShapeInfo, dbY->primary(), hYShapeInfo, dbY->special(), dYShapeInfo, + dbZ->primary(), hZShapeInfo, dbZ->special(), dZShapeInfo, extraParams, + dimension, dimensionLength, hTADShapeInfo, hTADOffsets, hTADShapeInfoZ, + hTADOffsetsZ); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } /** @@ -265,63 +246,42 @@ void execBroadcastBool(Nd4jPointer *extraPointers, * @param extraParams * @param n */ -void execPairwiseTransform( - Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo, - OpaqueDataBuffer *dbY, const Nd4jLong *hYShapeInfo, const Nd4jLong *dYShapeInfo, - OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo, - void *extraParams) { - try { - NativeOpExecutioner::execPairwiseTransform(nullptr, - opNum, - dbX->primary(), - hXShapeInfo, - dbX->special(), - dXShapeInfo, - dbY->primary(), - hYShapeInfo, - dbY->special(), - dYShapeInfo, - dbZ->primary(), - hZShapeInfo, - dbZ->special(), - dZShapeInfo, - extraParams); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } +void execPairwiseTransform(Nd4jPointer *extraPointers, int opNum, + OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, + const Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbY, + const Nd4jLong *hYShapeInfo, + const Nd4jLong *dYShapeInfo, OpaqueDataBuffer *dbZ, + const Nd4jLong *hZShapeInfo, + const Nd4jLong *dZShapeInfo, void *extraParams) { + try { + NativeOpExecutioner::execPairwiseTransform( + nullptr, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + dXShapeInfo, dbY->primary(), hYShapeInfo, dbY->special(), dYShapeInfo, + dbZ->primary(), hZShapeInfo, dbZ->special(), dZShapeInfo, extraParams); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } void execPairwiseTransformBool( - Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo, - OpaqueDataBuffer *dbY, const Nd4jLong *hYShapeInfo, const Nd4jLong *dYShapeInfo, - OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo, - void *extraParams) { - - try { - NativeOpExecutioner::execPairwiseBoolTransform(nullptr, - opNum, - dbX->primary(), - hXShapeInfo, - dbX->special(), - dXShapeInfo, - dbY->primary(), - hYShapeInfo, - dbY->special(), - dYShapeInfo, - dbZ->primary(), - hZShapeInfo, - dbZ->special(), - dZShapeInfo, - extraParams); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } + Nd4jPointer *extraPointers, int opNum, OpaqueDataBuffer *dbX, + const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbY, const Nd4jLong *hYShapeInfo, + const Nd4jLong *dYShapeInfo, OpaqueDataBuffer *dbZ, + const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo, + void *extraParams) { + try { + NativeOpExecutioner::execPairwiseBoolTransform( + nullptr, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + dXShapeInfo, dbY->primary(), hYShapeInfo, dbY->special(), dYShapeInfo, + dbZ->primary(), hZShapeInfo, dbZ->special(), dZShapeInfo, extraParams); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } /** @@ -333,102 +293,72 @@ void execPairwiseTransformBool( * @param hZ * @param hZShapeInfo */ -void execReduceFloat( - Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo) { - - try { - NativeOpExecutioner::execReduceFloatScalar(nullptr, - opNum, - dbX->primary(), - hXShapeInfo, - dbX->special(), - dXShapeInfo, - extraParams, - dbZ->primary(), - hZShapeInfo, - dbZ->special(), - dZShapeInfo); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } -} - -void execReduceSame( - Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo) { - - try { - NativeOpExecutioner::execReduceSameScalar(nullptr, - opNum, - dbX->primary(), - hXShapeInfo, - dbX->special(), - dXShapeInfo, - extraParams, - dbZ->primary(), - hZShapeInfo, - dbZ->special(), - dZShapeInfo); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } -} - -void execReduceBool( - Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo) { - try { - NativeOpExecutioner::execReduceBoolScalar(nullptr, - opNum, - dbX->primary(), - hXShapeInfo, - dbX->special(), - dXShapeInfo, - extraParams, - dbZ->primary(), - hZShapeInfo, - dbZ->special(), - dZShapeInfo); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } -} - -void execReduceLong( - Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo) { - try { - NativeOpExecutioner::execReduceLongScalar(nullptr, - opNum, - dbX->primary(), - hXShapeInfo, - dbX->special(), - dXShapeInfo, - extraParams, - dbZ->primary(), - hZShapeInfo, - dbZ->special(), - dZShapeInfo); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } +void execReduceFloat(Nd4jPointer *extraPointers, int opNum, + OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, + const Nd4jLong *dXShapeInfo, void *extraParams, + OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, + const Nd4jLong *dZShapeInfo) { + try { + NativeOpExecutioner::execReduceFloatScalar( + nullptr, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + dXShapeInfo, extraParams, dbZ->primary(), hZShapeInfo, dbZ->special(), + dZShapeInfo); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } +} + +void execReduceSame(Nd4jPointer *extraPointers, int opNum, + OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, + const Nd4jLong *dXShapeInfo, void *extraParams, + OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, + const Nd4jLong *dZShapeInfo) { + try { + NativeOpExecutioner::execReduceSameScalar( + nullptr, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + dXShapeInfo, extraParams, dbZ->primary(), hZShapeInfo, dbZ->special(), + dZShapeInfo); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } +} + +void execReduceBool(Nd4jPointer *extraPointers, int opNum, + OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, + const Nd4jLong *dXShapeInfo, void *extraParams, + OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, + const Nd4jLong *dZShapeInfo) { + try { + NativeOpExecutioner::execReduceBoolScalar( + nullptr, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + dXShapeInfo, extraParams, dbZ->primary(), hZShapeInfo, dbZ->special(), + dZShapeInfo); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } +} + +void execReduceLong(Nd4jPointer *extraPointers, int opNum, + OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, + const Nd4jLong *dXShapeInfo, void *extraParams, + OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, + const Nd4jLong *dZShapeInfo) { + try { + NativeOpExecutioner::execReduceLongScalar( + nullptr, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + dXShapeInfo, extraParams, dbZ->primary(), hZShapeInfo, dbZ->special(), + dZShapeInfo); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } /** @@ -440,146 +370,117 @@ void execReduceLong( * @param hZ * @param hZShapeInfo */ -void execReduceFloat2(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo, - OpaqueDataBuffer *dbDimension, const Nd4jLong *hDimensionShape, const Nd4jLong *dDimensionShape) { - try { - auto dimension = reinterpret_cast(dbDimension->primary()); - auto dimensionLength = static_cast(shape::length(hDimensionShape)); - - auto tadPackX = sd::ConstantTadHelper::getInstance().tadForDimensions(hXShapeInfo, dimension, dimensionLength); - - auto hTADShapeInfo = tadPackX.primaryShapeInfo(); - auto hTADOffsets = tadPackX.primaryOffsets(); - - NativeOpExecutioner::execReduceFloat(nullptr, opNum, - dbX->primary(), - hXShapeInfo, - dbX->special(), - dXShapeInfo, - extraParams, - dbZ->primary(), - hZShapeInfo, - dbZ->special(), - dZShapeInfo, - dimension, - dimensionLength, - hTADShapeInfo, - hTADOffsets); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } -} - -void execReduceBool2(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo, - OpaqueDataBuffer *dbDimension, const Nd4jLong *hDimensionShape, const Nd4jLong *dDimensionShape) { - try { - auto dimension = reinterpret_cast(dbDimension->primary()); - auto dimensionLength = static_cast(shape::length(hDimensionShape)); - - auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(hXShapeInfo, dimension, - dimensionLength); - - auto hTADShapeInfo = tadPack.primaryShapeInfo(); - auto hTADOffsets = tadPack.primaryOffsets(); - - NativeOpExecutioner::execReduceBool(nullptr, opNum, - dbX->primary(), - hXShapeInfo, - dbX->special(), - dXShapeInfo, - extraParams, - dbZ->primary(), - hZShapeInfo, - dbZ->special(), - dZShapeInfo, - dimension, - dimensionLength, - hTADShapeInfo, - hTADOffsets); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } -} - -void execReduceSame2(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo, - OpaqueDataBuffer *dbDimension, const Nd4jLong *hDimensionShape, const Nd4jLong *dDimensionShape) { - try { - auto dimension = reinterpret_cast(dbDimension->primary()); - int dimensionLength = static_cast(shape::length(hDimensionShape)); - - auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(hXShapeInfo, dimension, - dimensionLength); - - auto hTADShapeInfo = tadPack.primaryShapeInfo(); - auto hTADOffsets = tadPack.primaryOffsets(); - - NativeOpExecutioner::execReduceSame(nullptr, opNum, - dbX->primary(), - hXShapeInfo, - dbX->special(), - dXShapeInfo, - extraParams, - dbZ->primary(), - hZShapeInfo, - dbZ->special(), - dZShapeInfo, - dimension, - dimensionLength, - hTADShapeInfo, - hTADOffsets); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } -} - -void execReduceLong2(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo, - OpaqueDataBuffer *dbDimension, const Nd4jLong *hDimensionShape, const Nd4jLong *dDimensionShape) { - try { - auto dimension = reinterpret_cast(dbDimension->primary()); - int dimensionLength = static_cast(shape::length(hDimensionShape)); - - auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(hXShapeInfo, dimension, dimensionLength); - - auto hTADShapeInfo = tadPack.primaryShapeInfo(); - auto hTADOffsets = tadPack.primaryOffsets(); - - NativeOpExecutioner::execReduceLong(nullptr, opNum, - dbX->primary(), - hXShapeInfo, - dbX->special(), - dXShapeInfo, - extraParams, - dbZ->primary(), - hZShapeInfo, - dbZ->special(), - dZShapeInfo, - dimension, - dimensionLength, - hTADShapeInfo, - hTADOffsets); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } +void execReduceFloat2(Nd4jPointer *extraPointers, int opNum, + OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, + const Nd4jLong *dXShapeInfo, void *extraParams, + OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, + const Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbDimension, + const Nd4jLong *hDimensionShape, + const Nd4jLong *dDimensionShape) { + try { + auto dimension = reinterpret_cast(dbDimension->primary()); + auto dimensionLength = static_cast(shape::length(hDimensionShape)); + + auto tadPackX = sd::ConstantTadHelper::getInstance().tadForDimensions( + hXShapeInfo, dimension, dimensionLength); + + auto hTADShapeInfo = tadPackX.primaryShapeInfo(); + auto hTADOffsets = tadPackX.primaryOffsets(); + + NativeOpExecutioner::execReduceFloat( + nullptr, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + dXShapeInfo, extraParams, dbZ->primary(), hZShapeInfo, dbZ->special(), + dZShapeInfo, dimension, dimensionLength, hTADShapeInfo, hTADOffsets); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } +} + +void execReduceBool2(Nd4jPointer *extraPointers, int opNum, + OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, + const Nd4jLong *dXShapeInfo, void *extraParams, + OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, + const Nd4jLong *dZShapeInfo, OpaqueDataBuffer *dbDimension, + const Nd4jLong *hDimensionShape, + const Nd4jLong *dDimensionShape) { + try { + auto dimension = reinterpret_cast(dbDimension->primary()); + auto dimensionLength = static_cast(shape::length(hDimensionShape)); + + auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions( + hXShapeInfo, dimension, dimensionLength); + + auto hTADShapeInfo = tadPack.primaryShapeInfo(); + auto hTADOffsets = tadPack.primaryOffsets(); + + NativeOpExecutioner::execReduceBool( + nullptr, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + dXShapeInfo, extraParams, dbZ->primary(), hZShapeInfo, dbZ->special(), + dZShapeInfo, dimension, dimensionLength, hTADShapeInfo, hTADOffsets); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } +} + +void execReduceSame2(Nd4jPointer *extraPointers, int opNum, + OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, + const Nd4jLong *dXShapeInfo, void *extraParams, + OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, + const Nd4jLong *dZShapeInfo, OpaqueDataBuffer *dbDimension, + const Nd4jLong *hDimensionShape, + const Nd4jLong *dDimensionShape) { + try { + auto dimension = reinterpret_cast(dbDimension->primary()); + int dimensionLength = static_cast(shape::length(hDimensionShape)); + + auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions( + hXShapeInfo, dimension, dimensionLength); + + auto hTADShapeInfo = tadPack.primaryShapeInfo(); + auto hTADOffsets = tadPack.primaryOffsets(); + + NativeOpExecutioner::execReduceSame( + nullptr, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + dXShapeInfo, extraParams, dbZ->primary(), hZShapeInfo, dbZ->special(), + dZShapeInfo, dimension, dimensionLength, hTADShapeInfo, hTADOffsets); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } +} + +void execReduceLong2(Nd4jPointer *extraPointers, int opNum, + OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, + const Nd4jLong *dXShapeInfo, void *extraParams, + OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, + const Nd4jLong *dZShapeInfo, OpaqueDataBuffer *dbDimension, + const Nd4jLong *hDimensionShape, + const Nd4jLong *dDimensionShape) { + try { + auto dimension = reinterpret_cast(dbDimension->primary()); + int dimensionLength = static_cast(shape::length(hDimensionShape)); + + auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions( + hXShapeInfo, dimension, dimensionLength); + + auto hTADShapeInfo = tadPack.primaryShapeInfo(); + auto hTADOffsets = tadPack.primaryOffsets(); + + NativeOpExecutioner::execReduceLong( + nullptr, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + dXShapeInfo, extraParams, dbZ->primary(), hZShapeInfo, dbZ->special(), + dZShapeInfo, dimension, dimensionLength, hTADShapeInfo, hTADOffsets); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } /** @@ -593,19 +494,22 @@ void execReduceLong2(Nd4jPointer *extraPointers, * @param hZ * @param hZShapeInfo */ -void execReduce3(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbY, const Nd4jLong *hYShapeInfo, const Nd4jLong *dYShapeInfo, - OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo) { - try { - NativeOpExecutioner::execReduce3(nullptr, opNum, dbX->primary(), hXShapeInfo, dbX->special(), dXShapeInfo, extraParams, dbY->primary(), hYShapeInfo, - dbY->special(), dYShapeInfo, dbZ->primary(), hZShapeInfo, dbZ->special(), dZShapeInfo); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } +void execReduce3(Nd4jPointer *extraPointers, int opNum, OpaqueDataBuffer *dbX, + const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo, + void *extraParams, OpaqueDataBuffer *dbY, + const Nd4jLong *hYShapeInfo, const Nd4jLong *dYShapeInfo, + OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, + const Nd4jLong *dZShapeInfo) { + try { + NativeOpExecutioner::execReduce3( + nullptr, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + dXShapeInfo, extraParams, dbY->primary(), hYShapeInfo, dbY->special(), + dYShapeInfo, dbZ->primary(), hZShapeInfo, dbZ->special(), dZShapeInfo); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } /** @@ -617,18 +521,23 @@ void execReduce3(Nd4jPointer *extraPointers, * @param hY * @param hYShapeInfo */ -void execReduce3Scalar(Nd4jPointer *extraPointers,int opNum, - OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbY, const Nd4jLong *hYShapeInfo, const Nd4jLong *dYShapeInfo, - OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo) { - try { - NativeOpExecutioner::execReduce3Scalar(nullptr, opNum, dbX->primary(), hXShapeInfo, dbX->special(), dXShapeInfo, extraParams, dbY->primary(), - hYShapeInfo, dbY->special(), dYShapeInfo, dbZ->primary(), hZShapeInfo, dbZ->special(), dZShapeInfo); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } +void execReduce3Scalar(Nd4jPointer *extraPointers, int opNum, + OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, + const Nd4jLong *dXShapeInfo, void *extraParams, + OpaqueDataBuffer *dbY, const Nd4jLong *hYShapeInfo, + const Nd4jLong *dYShapeInfo, OpaqueDataBuffer *dbZ, + const Nd4jLong *hZShapeInfo, + const Nd4jLong *dZShapeInfo) { + try { + NativeOpExecutioner::execReduce3Scalar( + nullptr, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + dXShapeInfo, extraParams, dbY->primary(), hYShapeInfo, dbY->special(), + dYShapeInfo, dbZ->primary(), hZShapeInfo, dbZ->special(), dZShapeInfo); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } /** * @@ -643,46 +552,50 @@ void execReduce3Scalar(Nd4jPointer *extraPointers,int opNum, * @param dimension * @param dimensionLength */ -void execReduce3Tad(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbY, const Nd4jLong *hYShapeInfo, const Nd4jLong *dYShapeInfo, - OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo, - OpaqueDataBuffer *dbDimension, const Nd4jLong *hDimensionShape, const Nd4jLong *dDimensionShape, - const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *yTadOnlyShapeInfo, const Nd4jLong *yTadOffsets) { - try { - auto dimension = reinterpret_cast(dbDimension->primary()); - auto dimensionLength = static_cast(shape::length(hDimensionShape)); - - if (extraPointers == nullptr || extraPointers[2] == 0) { - NativeOpExecutioner::execReduce3(LaunchContext::defaultContext(), opNum, dbX->primary(), hXShapeInfo, dbX->special(), dXShapeInfo, - extraParams, dbY->primary(), hYShapeInfo, dbY->special(), dYShapeInfo, dbZ->primary(), hZShapeInfo, dbZ->special(), - dZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, - yTadOnlyShapeInfo, yTadOffsets); - } else { - // going tad-way - auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(hXShapeInfo, dimension, - dimensionLength); - - auto hTADShapeInfo = tadPack.primaryShapeInfo(); - auto hTADOffsets = tadPack.primaryOffsets(); - - NativeOpExecutioner::execReduce3TAD(LaunchContext::defaultContext(), opNum, dbX->primary(), hXShapeInfo, dbX->special(), - dXShapeInfo, extraParams, dbY->primary(), hYShapeInfo, dbY->special(), dYShapeInfo, dbZ->primary(), - hZShapeInfo, dbZ->special(), dZShapeInfo, dimension, dimensionLength, hTADShapeInfo, - hTADOffsets, nullptr, nullptr); - } - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); +void execReduce3Tad( + Nd4jPointer *extraPointers, int opNum, OpaqueDataBuffer *dbX, + const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo, void *extraParams, + OpaqueDataBuffer *dbY, const Nd4jLong *hYShapeInfo, + const Nd4jLong *dYShapeInfo, OpaqueDataBuffer *dbZ, + const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbDimension, const Nd4jLong *hDimensionShape, + const Nd4jLong *dDimensionShape, const Nd4jLong *tadOnlyShapeInfo, + const Nd4jLong *tadOffsets, const Nd4jLong *yTadOnlyShapeInfo, + const Nd4jLong *yTadOffsets) { + try { + auto dimension = reinterpret_cast(dbDimension->primary()); + auto dimensionLength = static_cast(shape::length(hDimensionShape)); + + if (extraPointers == nullptr || extraPointers[2] == 0) { + NativeOpExecutioner::execReduce3( + LaunchContext::defaultContext(), opNum, dbX->primary(), hXShapeInfo, + dbX->special(), dXShapeInfo, extraParams, dbY->primary(), hYShapeInfo, + dbY->special(), dYShapeInfo, dbZ->primary(), hZShapeInfo, + dbZ->special(), dZShapeInfo, dimension, dimensionLength, + tadOnlyShapeInfo, tadOffsets, yTadOnlyShapeInfo, yTadOffsets); + } else { + // going tad-way + auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions( + hXShapeInfo, dimension, dimensionLength); + + auto hTADShapeInfo = tadPack.primaryShapeInfo(); + auto hTADOffsets = tadPack.primaryOffsets(); + + NativeOpExecutioner::execReduce3TAD( + LaunchContext::defaultContext(), opNum, dbX->primary(), hXShapeInfo, + dbX->special(), dXShapeInfo, extraParams, dbY->primary(), hYShapeInfo, + dbY->special(), dYShapeInfo, dbZ->primary(), hZShapeInfo, + dbZ->special(), dZShapeInfo, dimension, dimensionLength, + hTADShapeInfo, hTADOffsets, nullptr, nullptr); } + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } -bool isBlasVersionMatches(int major, int minor, int build) { - return true; -} +bool isBlasVersionMatches(int major, int minor, int build) { return true; } /** * @@ -695,62 +608,43 @@ bool isBlasVersionMatches(int major, int minor, int build) { * @param extraParams * @param n */ -void execScalar( - Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo, - OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo, - OpaqueDataBuffer *dbScalar, const Nd4jLong *hScalarShapeInfo, const Nd4jLong *dScalarShapeInfo, - void *extraParams) { - try { - NativeOpExecutioner::execScalar(nullptr, - opNum, - dbX->primary(), - hXShapeInfo, - dbX->special(), - dXShapeInfo, - dbZ->primary(), - hZShapeInfo, - dbZ->special(), - dZShapeInfo, - dbScalar->primary(), - hScalarShapeInfo, - dbScalar->special(), - dScalarShapeInfo, - extraParams); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } -} - -void execScalarBool( - Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo, - OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo, - OpaqueDataBuffer *dbScalar, const Nd4jLong *hScalarShapeInfo, const Nd4jLong *dScalarShapeInfo, - void *extraParams) { - try { - NativeOpExecutioner::execScalarBool(nullptr, - opNum, - dbX->primary(), - hXShapeInfo, - dbX->special(), - dXShapeInfo, - dbZ->primary(), - hZShapeInfo, - dbZ->special(), - dZShapeInfo, - dbScalar->primary(), - hScalarShapeInfo, - dbScalar->special(), - dScalarShapeInfo, - extraParams); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } +void execScalar(Nd4jPointer *extraPointers, int opNum, OpaqueDataBuffer *dbX, + const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, + const Nd4jLong *dZShapeInfo, OpaqueDataBuffer *dbScalar, + const Nd4jLong *hScalarShapeInfo, + const Nd4jLong *dScalarShapeInfo, void *extraParams) { + try { + NativeOpExecutioner::execScalar( + nullptr, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + dXShapeInfo, dbZ->primary(), hZShapeInfo, dbZ->special(), dZShapeInfo, + dbScalar->primary(), hScalarShapeInfo, dbScalar->special(), + dScalarShapeInfo, extraParams); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } +} + +void execScalarBool(Nd4jPointer *extraPointers, int opNum, + OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, + const Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbZ, + const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbScalar, + const Nd4jLong *hScalarShapeInfo, + const Nd4jLong *dScalarShapeInfo, void *extraParams) { + try { + NativeOpExecutioner::execScalarBool( + nullptr, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + dXShapeInfo, dbZ->primary(), hZShapeInfo, dbZ->special(), dZShapeInfo, + dbScalar->primary(), hScalarShapeInfo, dbScalar->special(), + dScalarShapeInfo, extraParams); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } /** @@ -760,29 +654,21 @@ void execScalarBool( * @param hXShapeInfo * @param extraParams */ -void execSummaryStatsScalar(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo, - bool biasCorrected) { - try { - NativeOpExecutioner::execSummaryStatsScalar(nullptr, - opNum, - dbX->primary(), - hXShapeInfo, - dbX->special(), - dXShapeInfo, - extraParams, - dbZ->primary(), - hZShapeInfo, - dbZ->special(), - dZShapeInfo, - biasCorrected); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } +void execSummaryStatsScalar(Nd4jPointer *extraPointers, int opNum, + OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, + const Nd4jLong *dXShapeInfo, void *extraParams, + OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, + const Nd4jLong *dZShapeInfo, bool biasCorrected) { + try { + NativeOpExecutioner::execSummaryStatsScalar( + nullptr, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + dXShapeInfo, extraParams, dbZ->primary(), hZShapeInfo, dbZ->special(), + dZShapeInfo, biasCorrected); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } /** * @@ -793,29 +679,21 @@ void execSummaryStatsScalar(Nd4jPointer *extraPointers, * @param hZ * @param hZShapeInfo */ -void execSummaryStats(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo, - bool biasCorrected) { - try { - NativeOpExecutioner::execSummaryStats(nullptr, - opNum, - dbX->primary(), - hXShapeInfo, - dbX->special(), - dXShapeInfo, - extraParams, - dbZ->primary(), - hZShapeInfo, - dbZ->special(), - dZShapeInfo, - biasCorrected); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } +void execSummaryStats(Nd4jPointer *extraPointers, int opNum, + OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, + const Nd4jLong *dXShapeInfo, void *extraParams, + OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, + const Nd4jLong *dZShapeInfo, bool biasCorrected) { + try { + NativeOpExecutioner::execSummaryStats( + nullptr, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + dXShapeInfo, extraParams, dbZ->primary(), hZShapeInfo, dbZ->special(), + dZShapeInfo, biasCorrected); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } /** * @@ -828,39 +706,30 @@ void execSummaryStats(Nd4jPointer *extraPointers, * @param dimension * @param dimensionLength */ -void execSummaryStatsTad(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo, - OpaqueDataBuffer *dbDimension, const Nd4jLong *hDimensionShape, const Nd4jLong *dDimensionShape, - bool biasCorrected, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { - try { - auto dimension = reinterpret_cast(dbDimension->primary()); - int dimensionLength = static_cast(shape::length(hDimensionShape)); - - - NativeOpExecutioner::execSummaryStats(nullptr, - opNum, - dbX->primary(), - hXShapeInfo, - dbX->special(), - dXShapeInfo, - extraParams, - dbZ->primary(), - hZShapeInfo, - dbZ->special(), - dZShapeInfo, - dimension, - dimensionLength, - tadShapeInfo, - tadOffsets, - biasCorrected); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } +void execSummaryStatsTad(Nd4jPointer *extraPointers, int opNum, + OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, + const Nd4jLong *dXShapeInfo, void *extraParams, + OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, + const Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbDimension, + const Nd4jLong *hDimensionShape, + const Nd4jLong *dDimensionShape, bool biasCorrected, + const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets) { + try { + auto dimension = reinterpret_cast(dbDimension->primary()); + int dimensionLength = static_cast(shape::length(hDimensionShape)); + + NativeOpExecutioner::execSummaryStats( + nullptr, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + dXShapeInfo, extraParams, dbZ->primary(), hZShapeInfo, dbZ->special(), + dZShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets, + biasCorrected); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } /** @@ -873,220 +742,178 @@ void execSummaryStatsTad(Nd4jPointer *extraPointers, * @param extraParams * @param n */ -void execTransformFloat( - Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo, - OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo, - void *extraParams) { - try { - NativeOpExecutioner::execTransformFloat(nullptr, - opNum, - dbX->primary(), - hXShapeInfo, - dbX->special(), - dXShapeInfo, - dbZ->primary(), - hZShapeInfo, - dbZ->special(), - dZShapeInfo, - extraParams, - nullptr, - nullptr); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } -} - -void execTransformSame( - Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo, - OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo, - void *extraParams) { - try { - NativeOpExecutioner::execTransformSame(nullptr, - opNum, - dbX->primary(), - hXShapeInfo, - dbX->special(), - dXShapeInfo, - dbZ->primary(), - hZShapeInfo, - dbZ->special(), - dZShapeInfo, - extraParams, - nullptr, - nullptr); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } -} - -void execTransformBool( - Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo, - OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo, - void *extraParams) { - try { - NativeOpExecutioner::execTransformBool(nullptr, - opNum, - dbX->primary(), - hXShapeInfo, - dbX->special(), - dXShapeInfo, - dbZ->primary(), - hZShapeInfo, - dbZ->special(), - dZShapeInfo, - extraParams, - nullptr, - nullptr); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } -} - -void execTransformAny( - Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo, - OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo, - void *extraParams) { - try { - NativeOpExecutioner::execTransformAny(nullptr, - opNum, - dbX->primary(), - hXShapeInfo, - dbX->special(), - dXShapeInfo, - dbZ->primary(), - hZShapeInfo, - dbZ->special(), - dZShapeInfo, - extraParams, - nullptr, - nullptr); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } -} - -void execTransformStrict( - Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo, - OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo, - void *extraParams) { - try { - NativeOpExecutioner::execTransformStrict(nullptr, - opNum, - dbX->primary(), - hXShapeInfo, - dbX->special(), - dXShapeInfo, - dbZ->primary(), - hZShapeInfo, - dbZ->special(), - dZShapeInfo, - extraParams, - nullptr, - nullptr); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } -} - -void execReduce3All(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo, - void *extraParamsVals, - OpaqueDataBuffer *dbY, const Nd4jLong *hYShapeInfo, const Nd4jLong *dYShapeInfo, - OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo, - OpaqueDataBuffer *dbDimension, const Nd4jLong *hDimensionShape, const Nd4jLong *dDimensionShape, - const Nd4jLong *xTadShapeInfo, const Nd4jLong *xOffsets, - const Nd4jLong *yTadShapeInfo, const Nd4jLong *yOffsets) { - - try { - auto dimension = reinterpret_cast(dbDimension->primary()); - auto dimensionLength = static_cast(shape::length(hDimensionShape)); - - - NativeOpExecutioner::execReduce3All(nullptr, opNum, dbX->primary(), hXShapeInfo, dbX->special(), dXShapeInfo, extraParamsVals, dbY->primary(), - hYShapeInfo, dbY->special(), dYShapeInfo, dbZ->primary(), hZShapeInfo, dbZ->special(), dZShapeInfo, dimension, - dimensionLength, xTadShapeInfo, xOffsets, yTadShapeInfo, yOffsets); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } +void execTransformFloat(Nd4jPointer *extraPointers, int opNum, + OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, + const Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbZ, + const Nd4jLong *hZShapeInfo, + const Nd4jLong *dZShapeInfo, void *extraParams) { + try { + NativeOpExecutioner::execTransformFloat( + nullptr, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + dXShapeInfo, dbZ->primary(), hZShapeInfo, dbZ->special(), dZShapeInfo, + extraParams, nullptr, nullptr); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } +} + +void execTransformSame(Nd4jPointer *extraPointers, int opNum, + OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, + const Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbZ, + const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo, + void *extraParams) { + try { + NativeOpExecutioner::execTransformSame( + nullptr, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + dXShapeInfo, dbZ->primary(), hZShapeInfo, dbZ->special(), dZShapeInfo, + extraParams, nullptr, nullptr); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } +} + +void execTransformBool(Nd4jPointer *extraPointers, int opNum, + OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, + const Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbZ, + const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo, + void *extraParams) { + try { + NativeOpExecutioner::execTransformBool( + nullptr, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + dXShapeInfo, dbZ->primary(), hZShapeInfo, dbZ->special(), dZShapeInfo, + extraParams, nullptr, nullptr); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } +} + +void execTransformAny(Nd4jPointer *extraPointers, int opNum, + OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, + const Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbZ, + const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo, + void *extraParams) { + try { + NativeOpExecutioner::execTransformAny( + nullptr, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + dXShapeInfo, dbZ->primary(), hZShapeInfo, dbZ->special(), dZShapeInfo, + extraParams, nullptr, nullptr); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } +} + +void execTransformStrict(Nd4jPointer *extraPointers, int opNum, + OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, + const Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbZ, + const Nd4jLong *hZShapeInfo, + const Nd4jLong *dZShapeInfo, void *extraParams) { + try { + NativeOpExecutioner::execTransformStrict( + nullptr, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + dXShapeInfo, dbZ->primary(), hZShapeInfo, dbZ->special(), dZShapeInfo, + extraParams, nullptr, nullptr); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } +} + +void execReduce3All(Nd4jPointer *extraPointers, int opNum, + OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, + const Nd4jLong *dXShapeInfo, void *extraParamsVals, + OpaqueDataBuffer *dbY, const Nd4jLong *hYShapeInfo, + const Nd4jLong *dYShapeInfo, OpaqueDataBuffer *dbZ, + const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbDimension, + const Nd4jLong *hDimensionShape, + const Nd4jLong *dDimensionShape, + const Nd4jLong *xTadShapeInfo, const Nd4jLong *xOffsets, + const Nd4jLong *yTadShapeInfo, const Nd4jLong *yOffsets) { + try { + auto dimension = reinterpret_cast(dbDimension->primary()); + auto dimensionLength = static_cast(shape::length(hDimensionShape)); + + NativeOpExecutioner::execReduce3All( + nullptr, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + dXShapeInfo, extraParamsVals, dbY->primary(), hYShapeInfo, + dbY->special(), dYShapeInfo, dbZ->primary(), hZShapeInfo, + dbZ->special(), dZShapeInfo, dimension, dimensionLength, xTadShapeInfo, + xOffsets, yTadShapeInfo, yOffsets); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } /** - * Concatneate multi array of the same shape together - * along a particular dimension - */ -void specialConcat( - Nd4jPointer *extraPointers, - int dimension, - int numArrays, - Nd4jPointer *data, - Nd4jPointer *inputShapeInfo, - void *hZ, - Nd4jLong const* hZShapeInfo, - Nd4jPointer *tadPointers, - Nd4jPointer *offsetPointers) { - try { - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - BUILD_SINGLE_SELECTOR(zType, sd::SpecialMethods,::concatCpuGeneric(dimension, numArrays, data, inputShapeInfo, hZ, hZShapeInfo), LIBND4J_TYPES); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } + * Concatneate multi array of the same shape together + * along a particular dimension + */ +void specialConcat(Nd4jPointer *extraPointers, int dimension, int numArrays, + Nd4jPointer *data, Nd4jPointer *inputShapeInfo, void *hZ, + Nd4jLong const *hZShapeInfo, Nd4jPointer *tadPointers, + Nd4jPointer *offsetPointers) { + try { + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + BUILD_SINGLE_SELECTOR(zType, sd::SpecialMethods, + ::concatCpuGeneric(dimension, numArrays, data, + inputShapeInfo, hZ, hZShapeInfo), + LIBND4J_TYPES); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } /** * This is dummy method for JNI compatibility - * Since we'll use this from java, jni compiler would like to have method no matter what. + * Since we'll use this from java, jni compiler would like to have method no + * matter what. */ -void initializeDevicesAndFunctions() { - -} +void initializeDevicesAndFunctions() {} void initializeFunctions(Nd4jPointer *functions) { - sd::BlasHelper::getInstance().initializeFunctions(functions); + sd::BlasHelper::getInstance().initializeFunctions(functions); } /** - * This method acquires memory chunk of requested size on host side - * - * @param pointer pointer that'll be used for allocation - * @param memorySize memory size, in bytes - * @param flags optional parameter - */ + * This method acquires memory chunk of requested size on host side + * + * @param pointer pointer that'll be used for allocation + * @param memorySize memory size, in bytes + * @param flags optional parameter + */ Nd4jPointer mallocHost(Nd4jLong memorySize, int flags) { - return reinterpret_cast(new int8_t[memorySize]); + return reinterpret_cast(new int8_t[memorySize]); } /** * This method acquires memory chunk of requested size on specified device * - * PLEASE NOTE: This method is NOT supported and has NO effect in CPU-based backend. + * PLEASE NOTE: This method is NOT supported and has NO effect in CPU-based + * backend. * * @param pointer pointer that'll be used for allocation * @param memorySize memory size, in bytes - * @param ptrToDeviceId pointer to deviceId. For cuda that's just and int, for OpenCL that's pointer to device_id, etc + * @param ptrToDeviceId pointer to deviceId. For cuda that's just and int, for + * OpenCL that's pointer to device_id, etc * @param flags optional parameter */ Nd4jPointer mallocDevice(Nd4jLong memorySize, int deviceId, int flags) { - // not supported - return 0L; + // not supported + return 0L; } /** @@ -1095,1646 +922,1524 @@ Nd4jPointer mallocDevice(Nd4jLong memorySize, int deviceId, int flags) { * @param pointer pointer that'll be freed */ int freeHost(Nd4jPointer pointer) { - delete[] reinterpret_cast(pointer); - return 1L; + delete[] reinterpret_cast(pointer); + return 1L; } /** * This method releases previously allocated memory space on device * - * PLEASE NOTE: This method is NOT supported and has NO effect in CPU-based backend. + * PLEASE NOTE: This method is NOT supported and has NO effect in CPU-based + * backend. * * @param pointer pointer that'll be freed * @param ptrToDeviceId pointer to deviceId. */ int freeDevice(Nd4jPointer pointer, int deviceId) { - // not supported - return 0L; + // not supported + return 0L; } - /** * Returns the maximum number open mp threads */ -int ompGetMaxThreads() { - return omp_get_max_threads(); -} +int ompGetMaxThreads() { return omp_get_max_threads(); } /** * Returns the number open mp threads */ -int ompGetNumThreads() { - return omp_get_num_threads(); -} +int ompGetNumThreads() { return omp_get_num_threads(); } /** * Sets the number of openmp threads */ -void setOmpNumThreads(int threads) { - omp_set_num_threads(threads); +void setOmpNumThreads(int threads) { omp_set_num_threads(threads); } -} +Nd4jPointer createContext() { return 0L; } -Nd4jPointer createContext() { - return 0L; -} +Nd4jPointer createStream() { return 0L; } -Nd4jPointer createStream() { - return 0L; -} +Nd4jPointer createEvent() { return 0L; } -Nd4jPointer createEvent() { - return 0L; -} +int getDeviceMajor(int deviceId) { return 0; } -int getDeviceMajor(int deviceId ) { - return 0; -} +int getDeviceMinor(int deviceId) { return 0; } -int getDeviceMinor(int deviceId) { - return 0; -} +int registerEvent(Nd4jPointer event, Nd4jPointer stream) { return 0L; } -int registerEvent(Nd4jPointer event, Nd4jPointer stream) { - return 0L; -} +int setDevice(int deviceId) { return 0L; } -int setDevice(int deviceId) { - return 0L; -} +Nd4jLong getDeviceFreeMemory(int deviceId) { return 0L; } -Nd4jLong getDeviceFreeMemory(int deviceId) { - return 0L; -} +Nd4jLong getDeviceFreeMemoryDefault() { return 0L; } -Nd4jLong getDeviceFreeMemoryDefault() { - return 0L; -} +Nd4jLong getDeviceTotalMemory(int deviceId) { return 0L; } -Nd4jLong getDeviceTotalMemory(int deviceId) { - return 0L; +int memcpySync(Nd4jPointer dst, Nd4jPointer src, Nd4jLong size, int flags, + Nd4jPointer reserved) { + return 0L; } -int memcpySync(Nd4jPointer dst, Nd4jPointer src, Nd4jLong size, int flags, Nd4jPointer reserved) { - return 0L; +int memcpyAsync(Nd4jPointer dst, Nd4jPointer src, Nd4jLong size, int flags, + Nd4jPointer reserved) { + return 0L; } -int memcpyAsync(Nd4jPointer dst, Nd4jPointer src, Nd4jLong size, int flags, Nd4jPointer reserved) { - return 0L; +int memsetSync(Nd4jPointer dst, int value, Nd4jLong size, int flags, + Nd4jPointer reserved) { + return 0L; } -int memsetSync(Nd4jPointer dst, int value, Nd4jLong size, int flags, Nd4jPointer reserved) { - return 0L; +int memsetAsync(Nd4jPointer dst, int value, Nd4jLong size, int flags, + Nd4jPointer reserved) { + return 0L; } -int memsetAsync(Nd4jPointer dst, int value, Nd4jLong size, int flags, Nd4jPointer reserved) { - return 0L; -} +int destroyEvent(Nd4jPointer event) { return 0L; } -int destroyEvent(Nd4jPointer event) { - return 0L; -} - -int streamSynchronize(Nd4jPointer stream) { - return 0L; -} +int streamSynchronize(Nd4jPointer stream) { return 0L; } -int eventSynchronize(Nd4jPointer event) { - return 0L; -} +int eventSynchronize(Nd4jPointer event) { return 0L; } -int getAvailableDevices() { - return 0L; -} +int getAvailableDevices() { return 0L; } void enableDebugMode(bool reallyEnable) { - sd::Environment::getInstance().setDebug(reallyEnable); + sd::Environment::getInstance().setDebug(reallyEnable); } void enableVerboseMode(bool reallyEnable) { - sd::Environment::getInstance().setVerbose(reallyEnable); + sd::Environment::getInstance().setVerbose(reallyEnable); } void setGridLimit(int gridSize) { - // no-op + // no-op } -sd::TadPack* tadOnlyShapeInfo(Nd4jLong const* hXShapeInfo, int *dimension, int dimensionLength) { - auto pack = new TadPack(); - try { - *pack = sd::ConstantTadHelper::getInstance().tadForDimensions(hXShapeInfo, dimension, dimensionLength); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } +sd::TadPack *tadOnlyShapeInfo(Nd4jLong const *hXShapeInfo, int *dimension, + int dimensionLength) { + auto pack = new TadPack(); + try { + *pack = sd::ConstantTadHelper::getInstance().tadForDimensions( + hXShapeInfo, dimension, dimensionLength); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } - return pack; + return pack; } -Nd4jLong const* getPrimaryShapeInfo(sd::TadPack* pack) { - return const_cast(pack->primaryShapeInfo()); +Nd4jLong const *getPrimaryShapeInfo(sd::TadPack *pack) { + return const_cast(pack->primaryShapeInfo()); } -Nd4jLong const* getPrimaryOffsets(sd::TadPack* pack) { - return const_cast(pack->primaryOffsets()); +Nd4jLong const *getPrimaryOffsets(sd::TadPack *pack) { + return const_cast(pack->primaryOffsets()); } -Nd4jLong const* getSpecialShapeInfo(sd::TadPack* pack) { - return const_cast(pack->specialShapeInfo()); +Nd4jLong const *getSpecialShapeInfo(sd::TadPack *pack) { + return const_cast(pack->specialShapeInfo()); } -Nd4jLong const* getSpecialOffsets(sd::TadPack* pack) { - return const_cast(pack->specialOffsets()); +Nd4jLong const *getSpecialOffsets(sd::TadPack *pack) { + return const_cast(pack->specialOffsets()); } -Nd4jLong getNumberOfTads(sd::TadPack* pack) { - return pack->numberOfTads(); -} +Nd4jLong getNumberOfTads(sd::TadPack *pack) { return pack->numberOfTads(); } -int getShapeInfoLength(sd::TadPack* pack) { - return pack->shapeInfoLength(); -} +int getShapeInfoLength(sd::TadPack *pack) { return pack->shapeInfoLength(); } -int memcpyConstantAsync(Nd4jLong dst, Nd4jPointer src, Nd4jLong size, int flags, Nd4jPointer reserved) { - // no-op - return 0L; +int memcpyConstantAsync(Nd4jLong dst, Nd4jPointer src, Nd4jLong size, int flags, + Nd4jPointer reserved) { + // no-op + return 0L; } Nd4jPointer getConstantSpace() { - // no-op - return 0L; -} - -template -void pullRowsGeneric(void *vx, - Nd4jLong const* hXShapeInfo, - void *vz, - Nd4jLong const* hZShapeInfo, - const int n, - Nd4jLong const* indexes, - Nd4jLong const* tadShapeInfo, - Nd4jLong const* tadOffsets, - Nd4jLong const* zTadShapeInfo, - Nd4jLong const* zTadOffsets) { - auto hX = reinterpret_cast(vx); - auto hZ = reinterpret_cast(vz); - - const auto xEWS = shape::elementWiseStride(tadShapeInfo); - const auto zEWS = shape::elementWiseStride(zTadShapeInfo); - const auto tadLength = shape::length(tadShapeInfo); - - int elementsPerThread = n / TAD_THRESHOLD; - int _threads = sd::math::nd4j_max(1, elementsPerThread); - _threads = sd::math::nd4j_min(_threads, sd::Environment::getInstance().maxThreads()); - - auto func = PRAGMA_THREADS_FOR { - for (auto idx = start; idx < stop; idx++) { - auto xTadOffsetForBlock = tadOffsets[indexes[idx]]; - auto zTadOffsetForBlock = zTadOffsets[idx]; - - auto rX = hX + xTadOffsetForBlock; - auto rZ = hZ + zTadOffsetForBlock; - - if (xEWS == 1 && zEWS == 1) { - PRAGMA_OMP_SIMD - for (Nd4jLong i = 0; i < tadLength; i++) { - rZ[i] = rX[i]; - } - } else if (xEWS >= 1 && zEWS >= 1) { - PRAGMA_OMP_SIMD - for (Nd4jLong i = 0; i < tadLength; i++) { - rZ[i * zEWS] = rX[i * xEWS]; - } - } else { - for (Nd4jLong i = 0; i < tadLength; i++) { - auto xOffset = xTadOffsetForBlock + shape::getIndexOffset(i, tadShapeInfo); - auto zOffset = zTadOffsetForBlock + shape::getIndexOffset(i, zTadShapeInfo); - hZ[zOffset] = hX[xOffset]; - } - } - } - }; - - samediff::Threads::parallel_tad(func, 0, n, 1, _threads); -} - -void pullRows(Nd4jPointer *extraPointers, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo, - Nd4jLong n, - Nd4jLong* indexes, - Nd4jLong const* tadShapeInfo, - Nd4jLong const* tadOffsets, - Nd4jLong const* zTadShapeInfo, - Nd4jLong const* zTadOffsets) { - try { - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - - BUILD_SINGLE_SELECTOR(xType, pullRowsGeneric, (dbX->primary(), hXShapeInfo, dbZ->primary(), hZShapeInfo, n, indexes, tadShapeInfo, tadOffsets, zTadShapeInfo, zTadOffsets), LIBND4J_TYPES); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } + // no-op + return 0L; } -template -void tearGeneric(void *vx, - Nd4jLong const* hXShapeInfo, - Nd4jPointer *targets, - Nd4jLong const* hZShapeInfo, - Nd4jLong const* tadShapeInfo, - Nd4jLong const* tadOffsets) { - - auto hX = reinterpret_cast(vx); - - const auto tadLength = shape::length(tadShapeInfo); - auto tadEWS = shape::elementWiseStride(tadShapeInfo); - auto zEWS = shape::elementWiseStride(hZShapeInfo); - auto numTads = shape::length(hXShapeInfo) / tadLength; - - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - auto hZ = reinterpret_cast(targets[i]); - auto s = hX + tadOffsets[i]; - - if (zEWS == 1 && tadEWS == 1) { - PRAGMA_OMP_SIMD - for (Nd4jLong j = 0; j < tadLength; j++) { - hZ[j] = s[j]; - } - } else if (zEWS > 0 && tadEWS > 0) { - PRAGMA_OMP_SIMD - for (Nd4jLong j = 0; j < tadLength; j++) { - hZ[j * zEWS] = s[j * tadEWS]; - } - } else { - for (Nd4jLong j = 0; j < tadLength; j++) - hZ[shape::getIndexOffset(j, hZShapeInfo)] = s[shape::getIndexOffset(j, tadShapeInfo)]; - } +template +void pullRowsGeneric(void *vx, Nd4jLong const *hXShapeInfo, void *vz, + Nd4jLong const *hZShapeInfo, const int n, + Nd4jLong const *indexes, Nd4jLong const *tadShapeInfo, + Nd4jLong const *tadOffsets, Nd4jLong const *zTadShapeInfo, + Nd4jLong const *zTadOffsets) { + auto hX = reinterpret_cast(vx); + auto hZ = reinterpret_cast(vz); + + const auto xEWS = shape::elementWiseStride(tadShapeInfo); + const auto zEWS = shape::elementWiseStride(zTadShapeInfo); + const auto tadLength = shape::length(tadShapeInfo); + + int elementsPerThread = n / TAD_THRESHOLD; + int _threads = sd::math::nd4j_max(1, elementsPerThread); + _threads = sd::math::nd4j_min( + _threads, sd::Environment::getInstance().maxThreads()); + + auto func = PRAGMA_THREADS_FOR { + for (auto idx = start; idx < stop; idx++) { + auto xTadOffsetForBlock = tadOffsets[indexes[idx]]; + auto zTadOffsetForBlock = zTadOffsets[idx]; + + auto rX = hX + xTadOffsetForBlock; + auto rZ = hZ + zTadOffsetForBlock; + + if (xEWS == 1 && zEWS == 1) { + PRAGMA_OMP_SIMD + for (Nd4jLong i = 0; i < tadLength; i++) { + rZ[i] = rX[i]; + } + } else if (xEWS >= 1 && zEWS >= 1) { + PRAGMA_OMP_SIMD + for (Nd4jLong i = 0; i < tadLength; i++) { + rZ[i * zEWS] = rX[i * xEWS]; + } + } else { + for (Nd4jLong i = 0; i < tadLength; i++) { + auto xOffset = + xTadOffsetForBlock + shape::getIndexOffset(i, tadShapeInfo); + auto zOffset = + zTadOffsetForBlock + shape::getIndexOffset(i, zTadShapeInfo); + hZ[zOffset] = hX[xOffset]; } - }; + } + } + }; - samediff::Threads::parallel_tad(func,0, numTads); + samediff::Threads::parallel_tad(func, 0, n, 1, _threads); } -void tear(Nd4jPointer *extraPointers, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - Nd4jPointer *targets, - Nd4jLong const* hZShapeInfo, - Nd4jLong const* tadShapeInfo, - Nd4jLong const* tadOffsets) { - try { - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); +void pullRows(Nd4jPointer *extraPointers, OpaqueDataBuffer *dbX, + Nd4jLong const *hXShapeInfo, Nd4jLong const *dXShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong const *hZShapeInfo, + Nd4jLong const *dZShapeInfo, Nd4jLong n, Nd4jLong *indexes, + Nd4jLong const *tadShapeInfo, Nd4jLong const *tadOffsets, + Nd4jLong const *zTadShapeInfo, Nd4jLong const *zTadOffsets) { + try { + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - BUILD_SINGLE_SELECTOR(xType, tearGeneric, (dbX->primary(), hXShapeInfo, targets, hZShapeInfo, tadShapeInfo, tadOffsets), LIBND4J_TYPES); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } + BUILD_SINGLE_SELECTOR( + xType, pullRowsGeneric, + (dbX->primary(), hXShapeInfo, dbZ->primary(), hZShapeInfo, n, indexes, + tadShapeInfo, tadOffsets, zTadShapeInfo, zTadOffsets), + LIBND4J_TYPES); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } - -void average(Nd4jPointer *extras, - Nd4jPointer *hX, const Nd4jLong *hXShapeInfo, - Nd4jPointer *dX, const Nd4jLong *dXShapeInfo, - void *z, const Nd4jLong *hZShapeInfo, - void *dz, const Nd4jLong *dZShapeInfo, - int n, - Nd4jLong length, - bool propagate) { - try { - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - - BUILD_SINGLE_SELECTOR(xType, sd::SpecialMethods, ::averageGeneric(hX, z, hZShapeInfo, n, length, propagate), LIBND4J_TYPES); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); +template +void tearGeneric(void *vx, Nd4jLong const *hXShapeInfo, Nd4jPointer *targets, + Nd4jLong const *hZShapeInfo, Nd4jLong const *tadShapeInfo, + Nd4jLong const *tadOffsets) { + auto hX = reinterpret_cast(vx); + + const auto tadLength = shape::length(tadShapeInfo); + auto tadEWS = shape::elementWiseStride(tadShapeInfo); + auto zEWS = shape::elementWiseStride(hZShapeInfo); + auto numTads = shape::length(hXShapeInfo) / tadLength; + + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + auto hZ = reinterpret_cast(targets[i]); + auto s = hX + tadOffsets[i]; + + if (zEWS == 1 && tadEWS == 1) { + PRAGMA_OMP_SIMD + for (Nd4jLong j = 0; j < tadLength; j++) { + hZ[j] = s[j]; + } + } else if (zEWS > 0 && tadEWS > 0) { + PRAGMA_OMP_SIMD + for (Nd4jLong j = 0; j < tadLength; j++) { + hZ[j * zEWS] = s[j * tadEWS]; + } + } else { + for (Nd4jLong j = 0; j < tadLength; j++) + hZ[shape::getIndexOffset(j, hZShapeInfo)] = + s[shape::getIndexOffset(j, tadShapeInfo)]; + } } + }; + + samediff::Threads::parallel_tad(func, 0, numTads); } -void accumulate(Nd4jPointer *extras, - Nd4jPointer *hX, Nd4jLong const* hXShapeInfo, - Nd4jPointer *dX, Nd4jLong const* dXShapeInfo, - void *hz, Nd4jLong const* hZShapeInfo, - void *dz, Nd4jLong const* dZShapeInfo, - int n, - Nd4jLong length) { - try { - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - - BUILD_SINGLE_SELECTOR(xType, sd::SpecialMethods, ::accumulateGeneric(hX, hz, hZShapeInfo, n, length), LIBND4J_TYPES); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } +void tear(Nd4jPointer *extraPointers, OpaqueDataBuffer *dbX, + Nd4jLong const *hXShapeInfo, Nd4jLong const *dXShapeInfo, + Nd4jPointer *targets, Nd4jLong const *hZShapeInfo, + Nd4jLong const *tadShapeInfo, Nd4jLong const *tadOffsets) { + try { + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + + BUILD_SINGLE_SELECTOR(xType, tearGeneric, + (dbX->primary(), hXShapeInfo, targets, hZShapeInfo, + tadShapeInfo, tadOffsets), + LIBND4J_TYPES); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } +} + +void average(Nd4jPointer *extras, Nd4jPointer *hX, const Nd4jLong *hXShapeInfo, + Nd4jPointer *dX, const Nd4jLong *dXShapeInfo, void *z, + const Nd4jLong *hZShapeInfo, void *dz, const Nd4jLong *dZShapeInfo, + int n, Nd4jLong length, bool propagate) { + try { + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + + BUILD_SINGLE_SELECTOR( + xType, sd::SpecialMethods, + ::averageGeneric(hX, z, hZShapeInfo, n, length, propagate), + LIBND4J_TYPES); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } +} + +void accumulate(Nd4jPointer *extras, Nd4jPointer *hX, + Nd4jLong const *hXShapeInfo, Nd4jPointer *dX, + Nd4jLong const *dXShapeInfo, void *hz, + Nd4jLong const *hZShapeInfo, void *dz, + Nd4jLong const *dZShapeInfo, int n, Nd4jLong length) { + try { + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + + BUILD_SINGLE_SELECTOR(xType, sd::SpecialMethods, + ::accumulateGeneric(hX, hz, hZShapeInfo, n, length), + LIBND4J_TYPES); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } void enableP2P(bool enable) { - // no-op + // no-op } - - -void encodeThresholdP1(Nd4jPointer *extraPointers, void *hX, Nd4jLong const* hXShapeInfo, Nd4jLong N, int *dz, float threshold) { - // TODO: to be implemented +void encodeThresholdP1(Nd4jPointer *extraPointers, void *hX, + Nd4jLong const *hXShapeInfo, Nd4jLong N, int *dz, + float threshold) { + // TODO: to be implemented } - -void encodeThresholdP2Int(Nd4jPointer *extraPointers, int *hX, Nd4jLong N, int *dz) { - // TODO: to be implemented +void encodeThresholdP2Int(Nd4jPointer *extraPointers, int *hX, Nd4jLong N, + int *dz) { + // TODO: to be implemented } +void encodeThresholdP3(Nd4jPointer *extraPointers, void *hX, + Nd4jLong const *hXShapeInfo, int *offsets, Nd4jLong N, + int *dz) { + // offsets won't be used here -void encodeThresholdP3(Nd4jPointer *extraPointers, void *hX, Nd4jLong const* hXShapeInfo, int *offsets, Nd4jLong N, int *dz){ - // offsets won't be used here - - // TODO: to be implemented + // TODO: to be implemented } -void decodeThreshold(Nd4jPointer *extraPointers, void *hX, Nd4jLong N, void *dz, const Nd4jLong *hZShapeInfo){ - // TODO: to be implemented +void decodeThreshold(Nd4jPointer *extraPointers, void *hX, Nd4jLong N, void *dz, + const Nd4jLong *hZShapeInfo) { + // TODO: to be implemented } bool isP2PAvailable() { - // always TRUE for cpu backend - return true; + // always TRUE for cpu backend + return true; } void checkP2P() { - // no-op + // no-op } -void decodeBitmap(Nd4jPointer *extraPointers, void *hX, Nd4jLong N, void *dz, Nd4jLong const* hZShapeInfo) { - NativeOpExecutioner::decodeBitmap(hX, N, dz, hZShapeInfo); +void decodeBitmap(Nd4jPointer *extraPointers, void *hX, Nd4jLong N, void *dz, + Nd4jLong const *hZShapeInfo) { + NativeOpExecutioner::decodeBitmap(hX, N, dz, hZShapeInfo); } -template -void shuffleGeneric(void **hX, Nd4jLong * const*hXShapeInfo, void **dz, Nd4jLong * const* hZShapeInfo, int N, int *shuffleMap, Nd4jLong * const* tadOnlyShapeInfo, Nd4jLong * const* tadOffsets) { - - auto dX = reinterpret_cast(hX); - auto dZ = reinterpret_cast(dz); - - auto func = PRAGMA_THREADS_FOR { - for (auto f = start; f < stop; f++) { - auto hX = reinterpret_cast(dX[f]); - //auto hZ = reinterpret_cast(dZ[f]); - - auto xShapeInfo = hXShapeInfo[f]; - auto tadOffset = reinterpret_cast(tadOffsets[f]); - - - const auto tadLength = shape::length(tadOnlyShapeInfo[f]); - auto tadEWS = shape::elementWiseStride(tadOnlyShapeInfo[f]); - auto tadRank = shape::rank(tadOnlyShapeInfo[f]); - auto numTads = shape::length(hXShapeInfo[f]) / tadLength; - - auto tadShape = shape::shapeOf(tadOnlyShapeInfo[f]); - auto tadStride = shape::stride(tadOnlyShapeInfo[f]); - - if (shape::rank(xShapeInfo) == 1) { - auto xLength = shape::length(xShapeInfo); - auto ews = shape::elementWiseStride(xShapeInfo); - for (Nd4jLong r = 0; r < xLength; r++) { - auto swapIdx = shuffleMap[r]; - if (swapIdx < 0) - continue; - - sd::math::nd4j_swap(hX[r * ews], hX[swapIdx * ews]); - } - } else { - for (Nd4jLong r = 0; r < numTads; r++) { - if (shuffleMap[r] < 0) - continue; - - auto oldOffset = tadOffset[r]; - auto newOffset = tadOffset[shuffleMap[r]]; - - auto rX = hX + oldOffset; - auto rY = hX + newOffset; - - if (tadEWS == 1) { - for (Nd4jLong i = 0; i < tadLength; i++) { - sd::math::nd4j_swap(rX[i], rY[i]); - } - } else { - for (Nd4jLong i = 0; i < tadLength; i++) { - auto offset = shape::getIndexOffset(i, tadOnlyShapeInfo[f]); - sd::math::nd4j_swap(hX[offset + oldOffset], hX[offset + newOffset]); - } - } - } - } +template +void shuffleGeneric(void **hX, Nd4jLong *const *hXShapeInfo, void **dz, + Nd4jLong *const *hZShapeInfo, int N, int *shuffleMap, + Nd4jLong *const *tadOnlyShapeInfo, + Nd4jLong *const *tadOffsets) { + auto dX = reinterpret_cast(hX); + auto dZ = reinterpret_cast(dz); + + auto func = PRAGMA_THREADS_FOR { + for (auto f = start; f < stop; f++) { + auto hX = reinterpret_cast(dX[f]); + // auto hZ = reinterpret_cast(dZ[f]); + + auto xShapeInfo = hXShapeInfo[f]; + auto tadOffset = reinterpret_cast(tadOffsets[f]); + + const auto tadLength = shape::length(tadOnlyShapeInfo[f]); + auto tadEWS = shape::elementWiseStride(tadOnlyShapeInfo[f]); + auto tadRank = shape::rank(tadOnlyShapeInfo[f]); + auto numTads = shape::length(hXShapeInfo[f]) / tadLength; + + auto tadShape = shape::shapeOf(tadOnlyShapeInfo[f]); + auto tadStride = shape::stride(tadOnlyShapeInfo[f]); + + if (shape::rank(xShapeInfo) == 1) { + auto xLength = shape::length(xShapeInfo); + auto ews = shape::elementWiseStride(xShapeInfo); + for (Nd4jLong r = 0; r < xLength; r++) { + auto swapIdx = shuffleMap[r]; + if (swapIdx < 0) continue; + + sd::math::nd4j_swap(hX[r * ews], hX[swapIdx * ews]); } - }; - - samediff::Threads::parallel_tad(func, 0, N); -} - -void shuffle(Nd4jPointer *extras, - Nd4jPointer *hX, Nd4jPointer *hXShapeInfo, - Nd4jPointer *dX, Nd4jPointer *dXShapeInfo, - Nd4jPointer *hz, Nd4jPointer *hZShapeInfo, - Nd4jPointer *dz, Nd4jPointer *dZShapeInfo, - int N, - int *shuffleMap, - Nd4jPointer *tadShapeInfo, - Nd4jPointer *tadOffsets) { - try { - auto xShape = reinterpret_cast(hXShapeInfo); - auto zShape = reinterpret_cast(hZShapeInfo); - auto tadOnlyShapeInfo = reinterpret_cast(tadShapeInfo); - auto tadOffset = reinterpret_cast(tadOffsets); - - auto xType = sd::ArrayOptions::dataType(xShape[0]); - - BUILD_SINGLE_SELECTOR(xType, shuffleGeneric, - (hX, xShape, hz, zShape, N, shuffleMap, tadOnlyShapeInfo, tadOffset), LIBND4J_TYPES); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } -} - - -bool isExperimentalEnabled() { - return sd::Environment::getInstance().isExperimentalBuild(); -} - - -void setOmpMinThreads(int threads) { - // TODO: to be implemented -} - -int getDevice() { - return 0; -} - -void execScalarTad(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const*dXShapeInfo, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const*dZShapeInfo, - OpaqueDataBuffer *dbScalars, Nd4jLong const* hScalarShapeInfo, Nd4jLong const* dScalarShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbDimension, Nd4jLong const* hDimensionShape, Nd4jLong const* dDimensionShape, - Nd4jLong const*tadShapeInfo, Nd4jLong const* tadOffsets, - Nd4jLong const*tadShapeInfoZ, Nd4jLong const* tadOffsetsZ) { - try { - auto dimension = reinterpret_cast(dbDimension->primary()); - int dimensionLength = static_cast(shape::length(hDimensionShape)); - - NativeOpExecutioner::execScalar(nullptr, - opNum, - dbX->primary(), - hXShapeInfo, - dbX->special(), - dXShapeInfo, - extraParams, - dbZ->primary(), - hZShapeInfo, - dbZ->special(), - dZShapeInfo, - dbScalars->primary(), - hScalarShapeInfo, - dbScalars->special(), - dScalarShapeInfo, - dimension, - shape::length(hDimensionShape), - tadShapeInfo, - tadOffsets, - tadShapeInfoZ, - tadOffsetsZ); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } -} - -void execScalarBoolTad(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo, - OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo, - OpaqueDataBuffer *dbScalars, const Nd4jLong *hScalarShapeInfo, const Nd4jLong *dScalarShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbDimension, const Nd4jLong *hDimensionShape, const Nd4jLong *dDimensionShape, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetsZ) { - try { - auto dimension = reinterpret_cast(dbDimension->primary()); - int dimensionLength = static_cast(shape::length(hDimensionShape)); - - NativeOpExecutioner::execScalarBool(nullptr, - opNum, - dbX->primary(), - hXShapeInfo, - dbX->special(), - dXShapeInfo, - extraParams, - dbZ->primary(), - hZShapeInfo, - dbZ->special(), - dZShapeInfo, - dbScalars->primary(), - hScalarShapeInfo, - dbScalars->special(), - dScalarShapeInfo, - dimension, - dimensionLength, - tadShapeInfo, - tadOffsets, - tadShapeInfoZ, - tadOffsetsZ); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } -} - -const char * getDeviceName(int deviceId) { - try { - if (!nameSet) { - name = reinterpret_cast(malloc(256 * sizeof(char))); + } else { + for (Nd4jLong r = 0; r < numTads; r++) { + if (shuffleMap[r] < 0) continue; - CHECK_ALLOC(name, "Failed to allocate new string buffer", 256); + auto oldOffset = tadOffset[r]; + auto newOffset = tadOffset[shuffleMap[r]]; - std::memset(name, 0, 256 * sizeof(char)); - nameSet = true; + auto rX = hX + oldOffset; + auto rY = hX + newOffset; - // TODO: provide proper CPU model name here - sprintf(name, "x86-compatible CPU"); + if (tadEWS == 1) { + for (Nd4jLong i = 0; i < tadLength; i++) { + sd::math::nd4j_swap(rX[i], rY[i]); + } + } else { + for (Nd4jLong i = 0; i < tadLength; i++) { + auto offset = shape::getIndexOffset(i, tadOnlyShapeInfo[f]); + sd::math::nd4j_swap(hX[offset + oldOffset], + hX[offset + newOffset]); + } + } } - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } } + }; - - return name; + samediff::Threads::parallel_tad(func, 0, N); } +void shuffle(Nd4jPointer *extras, Nd4jPointer *hX, Nd4jPointer *hXShapeInfo, + Nd4jPointer *dX, Nd4jPointer *dXShapeInfo, Nd4jPointer *hz, + Nd4jPointer *hZShapeInfo, Nd4jPointer *dz, + Nd4jPointer *dZShapeInfo, int N, int *shuffleMap, + Nd4jPointer *tadShapeInfo, Nd4jPointer *tadOffsets) { + try { + auto xShape = reinterpret_cast(hXShapeInfo); + auto zShape = reinterpret_cast(hZShapeInfo); + auto tadOnlyShapeInfo = reinterpret_cast(tadShapeInfo); + auto tadOffset = reinterpret_cast(tadOffsets); -void execAggregate(Nd4jPointer *extraPointers,int opNum, - void **arguments, - int numArguments, - Nd4jLong **shapeArguments, - int numShapeArguments, - int *indexArguments, - int numIndexArguments, - int **intArrays, - int numIntArrays, - void *realArguments, - int numRealArguments, - sd::DataType dtype) { + auto xType = sd::ArrayOptions::dataType(xShape[0]); + BUILD_SINGLE_SELECTOR( + xType, shuffleGeneric, + (hX, xShape, hz, zShape, N, shuffleMap, tadOnlyShapeInfo, tadOffset), + LIBND4J_TYPES); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } -void batchExecutor(Nd4jPointer *extraPointers, - int numAggregates, - int opNum, - int maxArgs, - int maxShapes, - int maxIntArrays, - int maxIntArraySize, - int maxIdx, - int maxReals, - void *ptrToArguments, - sd::DataType dtype) { - +bool isExperimentalEnabled() { + return sd::Environment::getInstance().isExperimentalBuild(); } -void execAggregateBatch(Nd4jPointer *extraPointers, - int numAggregates, - int opNum, - int maxArgs, - int maxShapes, - int maxIntArrays, - int maxIntArraySize, - int maxIdx, - int maxReals, - void *ptrToArguments, - sd::DataType dtype) { - +void setOmpMinThreads(int threads) { + // TODO: to be implemented +} + +int getDevice() { return 0; } + +void execScalarTad(Nd4jPointer *extraPointers, int opNum, OpaqueDataBuffer *dbX, + Nd4jLong const *hXShapeInfo, Nd4jLong const *dXShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong const *hZShapeInfo, + Nd4jLong const *dZShapeInfo, OpaqueDataBuffer *dbScalars, + Nd4jLong const *hScalarShapeInfo, + Nd4jLong const *dScalarShapeInfo, void *extraParams, + OpaqueDataBuffer *dbDimension, + Nd4jLong const *hDimensionShape, + Nd4jLong const *dDimensionShape, + Nd4jLong const *tadShapeInfo, Nd4jLong const *tadOffsets, + Nd4jLong const *tadShapeInfoZ, Nd4jLong const *tadOffsetsZ) { + try { + auto dimension = reinterpret_cast(dbDimension->primary()); + int dimensionLength = static_cast(shape::length(hDimensionShape)); + + NativeOpExecutioner::execScalar( + nullptr, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + dXShapeInfo, extraParams, dbZ->primary(), hZShapeInfo, dbZ->special(), + dZShapeInfo, dbScalars->primary(), hScalarShapeInfo, + dbScalars->special(), dScalarShapeInfo, dimension, + shape::length(hDimensionShape), tadShapeInfo, tadOffsets, tadShapeInfoZ, + tadOffsetsZ); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } +} + +void execScalarBoolTad( + Nd4jPointer *extraPointers, int opNum, OpaqueDataBuffer *dbX, + const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, + const Nd4jLong *dZShapeInfo, OpaqueDataBuffer *dbScalars, + const Nd4jLong *hScalarShapeInfo, const Nd4jLong *dScalarShapeInfo, + void *extraParams, OpaqueDataBuffer *dbDimension, + const Nd4jLong *hDimensionShape, const Nd4jLong *dDimensionShape, + const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets, + const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetsZ) { + try { + auto dimension = reinterpret_cast(dbDimension->primary()); + int dimensionLength = static_cast(shape::length(hDimensionShape)); + + NativeOpExecutioner::execScalarBool( + nullptr, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + dXShapeInfo, extraParams, dbZ->primary(), hZShapeInfo, dbZ->special(), + dZShapeInfo, dbScalars->primary(), hScalarShapeInfo, + dbScalars->special(), dScalarShapeInfo, dimension, dimensionLength, + tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } +const char *getDeviceName(int deviceId) { + try { + if (!nameSet) { + name = reinterpret_cast(malloc(256 * sizeof(char))); -void execRandom(Nd4jPointer *extraPointers, - int opNum, - Nd4jPointer state, - OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo, - void *extraArguments) { - try { - NativeOpExecutioner::execRandom(nullptr, opNum, state, dbZ->primary(), hZShapeInfo, dbZ->special(), dZShapeInfo, extraArguments); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } -} + CHECK_ALLOC(name, "Failed to allocate new string buffer", 256); -void execRandom3(Nd4jPointer *extraPointers, - int opNum, - Nd4jPointer state, - OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo, - OpaqueDataBuffer *dbY, const Nd4jLong *hYShapeInfo, const Nd4jLong *dYShapeInfo, - OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo, - void *extraArguments) { - try { - NativeOpExecutioner::execRandom(nullptr, opNum, state, dbX->primary(), hXShapeInfo, dbX->special(), dXShapeInfo, dbY->primary(), hYShapeInfo, dbY->special(), dYShapeInfo, dbZ->primary(), hZShapeInfo, dbZ->special(), dZShapeInfo, extraArguments); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } -} + std::memset(name, 0, 256 * sizeof(char)); + nameSet = true; -void execRandom2(Nd4jPointer *extraPointers, - int opNum, - Nd4jPointer state, - OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo, - OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo, - void *extraArguments) { - try { - NativeOpExecutioner::execRandom(nullptr, opNum, state, dbX->primary(), hXShapeInfo, dbX->special(), dXShapeInfo, dbZ->primary(), hZShapeInfo, dbZ->special(), dZShapeInfo, extraArguments); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + // TODO: provide proper CPU model name here + sprintf(name, "x86-compatible CPU"); } + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } + + return name; +} + +void execAggregate(Nd4jPointer *extraPointers, int opNum, void **arguments, + int numArguments, Nd4jLong **shapeArguments, + int numShapeArguments, int *indexArguments, + int numIndexArguments, int **intArrays, int numIntArrays, + void *realArguments, int numRealArguments, + sd::DataType dtype) {} + +void batchExecutor(Nd4jPointer *extraPointers, int numAggregates, int opNum, + int maxArgs, int maxShapes, int maxIntArrays, + int maxIntArraySize, int maxIdx, int maxReals, + void *ptrToArguments, sd::DataType dtype) {} + +void execAggregateBatch(Nd4jPointer *extraPointers, int numAggregates, + int opNum, int maxArgs, int maxShapes, int maxIntArrays, + int maxIntArraySize, int maxIdx, int maxReals, + void *ptrToArguments, sd::DataType dtype) {} + +void execRandom(Nd4jPointer *extraPointers, int opNum, Nd4jPointer state, + OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, + const Nd4jLong *dZShapeInfo, void *extraArguments) { + try { + NativeOpExecutioner::execRandom(nullptr, opNum, state, dbZ->primary(), + hZShapeInfo, dbZ->special(), dZShapeInfo, + extraArguments); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } +} + +void execRandom3(Nd4jPointer *extraPointers, int opNum, Nd4jPointer state, + OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, + const Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbY, + const Nd4jLong *hYShapeInfo, const Nd4jLong *dYShapeInfo, + OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, + const Nd4jLong *dZShapeInfo, void *extraArguments) { + try { + NativeOpExecutioner::execRandom( + nullptr, opNum, state, dbX->primary(), hXShapeInfo, dbX->special(), + dXShapeInfo, dbY->primary(), hYShapeInfo, dbY->special(), dYShapeInfo, + dbZ->primary(), hZShapeInfo, dbZ->special(), dZShapeInfo, + extraArguments); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } +} + +void execRandom2(Nd4jPointer *extraPointers, int opNum, Nd4jPointer state, + OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, + const Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbZ, + const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo, + void *extraArguments) { + try { + NativeOpExecutioner::execRandom(nullptr, opNum, state, dbX->primary(), + hXShapeInfo, dbX->special(), dXShapeInfo, + dbZ->primary(), hZShapeInfo, dbZ->special(), + dZShapeInfo, extraArguments); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } -Nd4jPointer initRandom(Nd4jPointer *extraPointers, long seed, long bufferSize, Nd4jPointer ptrToBuffer) { - try { - auto generator = new graph::RandomGenerator(seed, seed); +Nd4jPointer initRandom(Nd4jPointer *extraPointers, long seed, long bufferSize, + Nd4jPointer ptrToBuffer) { + try { + auto generator = new graph::RandomGenerator(seed, seed); - return (Nd4jPointer) generator; - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + return (Nd4jPointer)generator; + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); - return nullptr; - } + return nullptr; + } } -void refreshBuffer(Nd4jPointer *extraPointers, long seed, Nd4jPointer ptrRandom) { - auto generator = reinterpret_cast (ptrRandom); +void refreshBuffer(Nd4jPointer *extraPointers, long seed, + Nd4jPointer ptrRandom) { + auto generator = reinterpret_cast(ptrRandom); - generator->setStates(seed); + generator->setStates(seed); } -void reSeedBuffer(Nd4jPointer *extraPointers, long seed, Nd4jPointer ptrRandom) { - auto generator = reinterpret_cast (ptrRandom); +void reSeedBuffer(Nd4jPointer *extraPointers, long seed, + Nd4jPointer ptrRandom) { + auto generator = reinterpret_cast(ptrRandom); - generator->setStates(seed); + generator->setStates(seed); } - void destroyRandom(Nd4jPointer ptrBuffer) { - auto buffer = reinterpret_cast(ptrBuffer); - delete buffer; + auto buffer = reinterpret_cast(ptrBuffer); + delete buffer; } - - - /** - * Return the length of a shape buffer - * based on the pointer - * @param buffer the buffer pointer to check - * @return - */ + * Return the length of a shape buffer + * based on the pointer + * @param buffer the buffer pointer to check + * @return + */ int lengthForShapeBufferPointer(Nd4jPointer buffer) { - auto shapeBuffer = reinterpret_cast(buffer); - return shape::shapeInfoLength(shape::rank(shapeBuffer)); + auto shapeBuffer = reinterpret_cast(buffer); + return shape::shapeInfoLength(shape::rank(shapeBuffer)); } - /** - * The pointer to get the address for - * - * @param address the address to get the pointer - * @return the pointer for the given address - */ + * The pointer to get the address for + * + * @param address the address to get the pointer + * @return the pointer for the given address + */ Nd4jPointer pointerForAddress(Nd4jLong address) { - return reinterpret_cast(address); -} - -void sort(Nd4jPointer *extraPointers, - void *hX, const Nd4jLong *hXShapeInfo, - void *dX, const Nd4jLong *dXShapeInfo, - bool descending) { - try { - NativeOpExecutioner::execSort(hX, hXShapeInfo, descending); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } + return reinterpret_cast(address); } -void sortTad(Nd4jPointer *extraPointers, - void *hX, const Nd4jLong *hXShapeInfo, - void *dX, const Nd4jLong *dXShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, - const Nd4jLong *tadOffsets, - bool descending) { - try { - NativeOpExecutioner::execSort(hX, hXShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets, descending); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } +void sort(Nd4jPointer *extraPointers, void *hX, const Nd4jLong *hXShapeInfo, + void *dX, const Nd4jLong *dXShapeInfo, bool descending) { + try { + NativeOpExecutioner::execSort(hX, hXShapeInfo, descending); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } +} + +void sortTad(Nd4jPointer *extraPointers, void *hX, const Nd4jLong *hXShapeInfo, + void *dX, const Nd4jLong *dXShapeInfo, int *dimension, + int dimensionLength, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets, bool descending) { + try { + NativeOpExecutioner::execSort(hX, hXShapeInfo, dimension, dimensionLength, + tadShapeInfo, tadOffsets, descending); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } -void sortCooIndices(Nd4jPointer *extraPointers, - Nd4jLong *indices, - void *values, - Nd4jLong length, - int rank) { - try { - NativeOpExecutioner::execSortCooIndices(indices, values, length, rank); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } +void sortCooIndices(Nd4jPointer *extraPointers, Nd4jLong *indices, void *values, + Nd4jLong length, int rank) { + try { + NativeOpExecutioner::execSortCooIndices(indices, values, length, rank); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } -Nd4jLong encodeBitmap(Nd4jPointer *extraPointers, void *hX, Nd4jLong const* hXShapeInfo, Nd4jLong N, int *dz, float threshold) { - return NativeOpExecutioner::encodeBitmap(hX, hXShapeInfo, N, dz, threshold); +Nd4jLong encodeBitmap(Nd4jPointer *extraPointers, void *hX, + Nd4jLong const *hXShapeInfo, Nd4jLong N, int *dz, + float threshold) { + return NativeOpExecutioner::encodeBitmap(hX, hXShapeInfo, N, dz, threshold); } - - -Nd4jLong* mmapFile(Nd4jPointer *extraPointers, const char *fileName, Nd4jLong length) { - auto hZ = new Nd4jLong[2];errno = 0; -try { +Nd4jLong *mmapFile(Nd4jPointer *extraPointers, const char *fileName, + Nd4jLong length) { + auto hZ = new Nd4jLong[3]; + errno = 0; + try { #if defined(_WIN32) || defined(_WIN64) _mmap(hZ, static_cast(length), fileName); #else - int fd = open(fileName, O_RDWR, 0);// checking for failed fopen + int fd = open(fileName, O_RDWR, 0); // checking for failed fopen if (fd < 0) { - nd4j_printf("Errno: %i\n", errno); - throw std::runtime_error("Failed to open file for MMAP"); + nd4j_printf("Errno: %i\n", errno); + throw std::runtime_error("Failed to open file for MMAP"); } void *ptr = mmap(NULL, length, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0); -// check for failed allocation - if (ptr == MAP_FAILED) - return nullptr; + // check for failed allocation + if (ptr == MAP_FAILED) return nullptr; - hZ[0] = (Nd4jLong) ptr; + hZ[0] = (Nd4jLong)ptr; hZ[1] = fd; #endif + hZ[2] = length; return hZ; -} catch (std::exception &e) { + } catch (std::exception &e) { sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); return nullptr; -} + } } void munmapFile(Nd4jPointer *extraPointers, Nd4jLong *ptrMap, Nd4jLong length) { - munmap((Nd4jPointer) ptrMap[0], length); + munmap((Nd4jPointer)ptrMap[0], length); #if defined(_WIN32) || defined(_WIN64) - CloseHandle(reinterpret_cast(ptrMap[1])); + CloseHandle(reinterpret_cast(ptrMap[1])); #else - close((int) ptrMap[1]); + close((int)ptrMap[1]); #endif - delete[] ptrMap; + delete[] ptrMap; } -sd::graph::ResultWrapper* executeFlatGraph(Nd4jPointer *extraPointers, Nd4jPointer flatBufferPointer) { - try { - return sd::graph::GraphExecutioner::executeFlatBuffer(flatBufferPointer); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - return nullptr; - } +sd::graph::ResultWrapper *executeFlatGraph(Nd4jPointer *extraPointers, + Nd4jPointer flatBufferPointer) { + return nullptr; } -Nd4jLong getResultWrapperSize(sd::graph::ResultWrapper* ptr) { - return ptr->size(); +Nd4jLong getResultWrapperSize(sd::graph::ResultWrapper *ptr) { + return ptr->size(); } -Nd4jPointer getResultWrapperPointer(sd::graph::ResultWrapper* ptr) { - return ptr->pointer(); +Nd4jPointer getResultWrapperPointer(sd::graph::ResultWrapper *ptr) { + return ptr->pointer(); } -const char* getAllCustomOps() { - return sd::ops::OpRegistrator::getInstance().getAllCustomOperations(); +const char *getAllCustomOps() { + return sd::ops::OpRegistrator::getInstance().getAllCustomOperations(); } template -FORCEINLINE int estimateThresholdGeneric(Nd4jPointer *extraPointers, Nd4jPointer hX, int N, T threshold) { - auto buffer = reinterpret_cast(hX); - int span = (N / 6) + 8; +FORCEINLINE int estimateThresholdGeneric(Nd4jPointer *extraPointers, + Nd4jPointer hX, int N, T threshold) { + auto buffer = reinterpret_cast(hX); + int span = (N / 6) + 8; - auto func = PRAGMA_REDUCE_LONG { - int64_t cnt = 0; - PRAGMA_OMP_SIMD - for (auto e = start; e < stop; e++) { - auto v = sd::math::nd4j_abs(buffer[e]); - if (v >= threshold) - cnt++; - } + auto func = PRAGMA_REDUCE_LONG { + int64_t cnt = 0; + PRAGMA_OMP_SIMD + for (auto e = start; e < stop; e++) { + auto v = sd::math::nd4j_abs(buffer[e]); + if (v >= threshold) cnt++; + } - return cnt; - }; + return cnt; + }; - return samediff::Threads::parallel_long(func, LAMBDA_AL { return _old + _new; }, 0, N); + return samediff::Threads::parallel_long( + func, LAMBDA_AL { return _old + _new; }, 0, N); } - -int estimateThreshold(Nd4jPointer *extraPointers, Nd4jPointer hX, Nd4jLong const* hXShapeInfo, int N, float threshold) { - try { - auto xType = ArrayOptions::dataType(hXShapeInfo); - BUILD_SINGLE_SELECTOR(xType, return estimateThresholdGeneric, (extraPointers, hX, N, threshold), FLOAT_TYPES); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - return 0; - } +int estimateThreshold(Nd4jPointer *extraPointers, Nd4jPointer hX, + Nd4jLong const *hXShapeInfo, int N, float threshold) { + try { + auto xType = ArrayOptions::dataType(hXShapeInfo); + BUILD_SINGLE_SELECTOR(xType, return estimateThresholdGeneric, + (extraPointers, hX, N, threshold), FLOAT_TYPES); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + return 0; + } } -Nd4jLong getShapeListSize(sd::ShapeList* list) { - return list->size(); -} +Nd4jLong getShapeListSize(sd::ShapeList *list) { return list->size(); } -Nd4jLong const* getShape(sd::ShapeList* list, Nd4jLong i) { - return const_cast(list->at(i)); +Nd4jLong const *getShape(sd::ShapeList *list, Nd4jLong i) { + return const_cast(list->at(i)); } void deleteShapeList(Nd4jPointer shapeList) { - auto list = reinterpret_cast(shapeList); + auto list = reinterpret_cast(shapeList); - //list->destroy(); - delete list; + // list->destroy(); + delete list; } -sd::ShapeList* _calculateOutputShapes(Nd4jPointer* extraPointers, sd::ops::DeclarableOp* op, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs, int *dArgs, int numDArgs) { - sd::graph::VariableSpace varSpace; - Context block(2, &varSpace); - sd::ShapeList inShapes; - - for (int e = 0; e < numIArgs; e++) - block.getIArguments()->push_back(iArgs[e]); +sd::ShapeList *_calculateOutputShapes( + Nd4jPointer *extraPointers, sd::ops::DeclarableOp *op, + Nd4jPointer *inputBuffers, Nd4jPointer *inputShapes, int numInputShapes, + double *tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, + int numBArgs, int *dArgs, int numDArgs) { + sd::graph::VariableSpace varSpace; + Context block(2, &varSpace); + sd::ShapeList inShapes; - for (int e = 0; e < numTArgs; e++) - block.getTArguments()->push_back(tArgs[e]); + for (int e = 0; e < numIArgs; e++) block.appendI(iArgs[e]); - for (int e = 0; e < numBArgs; e++) - block.getBArguments()->push_back(bArgs[e]); + for (int e = 0; e < numTArgs; e++) block.appendT(tArgs[e]); - for (int e = 0; e < numDArgs; e++) - block.getDArguments()->push_back((sd::DataType) dArgs[e]); + for (int e = 0; e < numBArgs; e++) block.appendB(bArgs[e]); - for (int e = 0; e < numInputShapes; e++) { - auto shape_ = reinterpret_cast(inputShapes[e]); + for (int e = 0; e < numDArgs; e++) block.appendD((sd::DataType)dArgs[e]); - // we shouldn't copy buffer if that's empty array - void *buffer_ = sd::ArrayOptions::arrayType(shape_) == ArrayType::EMPTY ? nullptr : inputBuffers[e]; + for (int e = 0; e < numInputShapes; e++) { + auto shape_ = reinterpret_cast(inputShapes[e]); - auto array = new sd::NDArray(buffer_, shape_, varSpace.launchContext(), false); + // we shouldn't copy buffer if that's empty array + void *buffer_ = sd::ArrayOptions::arrayType(shape_) == ArrayType::EMPTY + ? nullptr + : inputBuffers[e]; - // block should contain references to proper variable - varSpace.putVariable(1, e, array); - block.pickInput(1, e); + auto array = std::make_shared( + buffer_, shape_, LaunchContext::defaultContext(), false); - inShapes.push_back(shape_); - } + // block should contain references to proper variable + varSpace.putVariable(1, e, array); + block.pickInput(1, e); - auto status = op->validateDataTypes(block); - if (status != Status::OK()) - throw std::runtime_error("Data types validation failed"); + inShapes.push_back(shape_); + } - auto shapeList = op->calculateOutputShape(&inShapes, block); + auto status = op->validateDataTypes(block); + if (status != Status::OK()) + throw std::runtime_error("Data types validation failed"); - if (varSpace.launchContext() != nullptr) - shapeList->detach(); + auto shapeList = op->calculateOutputShape(&inShapes, block); - return shapeList; + return shapeList; } -sd::ShapeList* calculateOutputShapes2(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs, int *dArgs, int numDArgs) { - try { - auto op = sd::ops::OpRegistrator::getInstance().getOperation(hash); +sd::ShapeList *calculateOutputShapes2(Nd4jPointer *extraPointers, Nd4jLong hash, + Nd4jPointer *inputBuffers, + Nd4jPointer *inputShapes, + int numInputShapes, double *tArgs, + int numTArgs, Nd4jLong *iArgs, + int numIArgs, bool *bArgs, int numBArgs, + int *dArgs, int numDArgs) { + try { + auto op = sd::ops::OpRegistrator::getInstance().getOperation(hash); - return _calculateOutputShapes(extraPointers, op, inputBuffers, inputShapes, numInputShapes, tArgs, numTArgs, iArgs, numIArgs, bArgs, numBArgs, dArgs, numDArgs); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - return nullptr; - } + return _calculateOutputShapes( + extraPointers, op.get(), inputBuffers, inputShapes, numInputShapes, + tArgs, numTArgs, iArgs, numIArgs, bArgs, numBArgs, dArgs, numDArgs); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + return nullptr; + } } -sd::ShapeList* _calculateOutputShapes(Nd4jPointer* extraPointers, sd::ops::DeclarableOp *op, Nd4jPointer* inputShapes, int numInputShapes, double *tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs) { - Context block(1); - sd::ShapeList inShapes; - - for (int e = 0; e < numIArgs; e++) - block.getIArguments()->push_back(iArgs[e]); +sd::ShapeList *_calculateOutputShapes(Nd4jPointer *extraPointers, + sd::ops::DeclarableOp *op, + Nd4jPointer *inputShapes, + int numInputShapes, double *tArgs, + int numTArgs, Nd4jLong *iArgs, + int numIArgs) { + Context block(1); + sd::ShapeList inShapes; - for (int e = 0; e < numTArgs; e++) - block.getTArguments()->push_back(tArgs[e]); + for (int e = 0; e < numIArgs; e++) block.appendI(iArgs[e]); - for (int e = 0; e < numInputShapes; e++) - inShapes.push_back(reinterpret_cast(inputShapes[e])); + for (int e = 0; e < numTArgs; e++) block.appendT(tArgs[e]); - auto shapeList = op->calculateOutputShape(&inShapes, block); - shapeList->detach(); + for (int e = 0; e < numInputShapes; e++) + inShapes.push_back(reinterpret_cast(inputShapes[e])); - return shapeList; -} - -sd::ShapeList* calculateOutputShapes(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs) { - try { - auto op = sd::ops::OpRegistrator::getInstance().getOperation(hash); + auto shapeList = op->calculateOutputShape(&inShapes, block); + shapeList->detach(); - return _calculateOutputShapes(extraPointers, op, inputShapes, numInputShapes, tArgs, numTArgs, iArgs, numIArgs); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - return nullptr; - } + return shapeList; } -int execCustomOp2(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer opContext) { - try { - auto op = sd::ops::OpRegistrator::getInstance().getOperation(hash); - auto context = reinterpret_cast(opContext); +sd::ShapeList *calculateOutputShapes(Nd4jPointer *extraPointers, Nd4jLong hash, + Nd4jPointer *inputShapes, + int numInputShapes, double *tArgs, + int numTArgs, Nd4jLong *iArgs, + int numIArgs) { + try { + auto op = sd::ops::OpRegistrator::getInstance().getOperation(hash); - return op->execute(context); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - return 20; - } + return _calculateOutputShapes(extraPointers, op.get(), inputShapes, + numInputShapes, tArgs, numTArgs, iArgs, + numIArgs); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + return nullptr; + } } -Nd4jStatus realExec(sd::ops::DeclarableOp* op, Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputs, Nd4jPointer* outputBuffers, Nd4jPointer* outputShapes, int numOutputs, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool* bArgs, int numBArgs, bool isInplace) { - if (op == nullptr) - nd4j_printf("Can't find requested operation: [%lld]\n", hash); - - // we're using the same fake nodeId everywhere here - - std::vector inputs(numInputs); - std::vector outputs(numOutputs); - std::vector ttArgs(numTArgs); - std::vector iiArgs(numIArgs); - std::vector biArgs(numBArgs); - - // filling block now with inputs - for (int e = 0; e < numInputs; e++) { - auto shape = reinterpret_cast(inputShapes[e]); - void *buffer = sd::ArrayOptions::arrayType(shape) == ArrayType::EMPTY ? nullptr : inputBuffers[e]; - - inputs[e] = new sd::NDArray(buffer, shape); - } - - // if not inplace - transferring output arrays - - if (!isInplace) - for (int e = 0; e < numOutputs; e++) { - // we want to keep original output shape intact - auto shape = shape::copyShape(reinterpret_cast(outputShapes[e])); - void *buffer = sd::ArrayOptions::arrayType(shape) == ArrayType::EMPTY ? nullptr : outputBuffers[e]; - - // FIXME: revisit this. - bool canNullify = true; - for (int i = 0; i < numInputs; i++) { - void *ibuffer = sd::ArrayOptions::arrayType(shape) == ArrayType::EMPTY ? nullptr : inputBuffers[i]; - if (ibuffer == buffer) { - canNullify = false; - break; - } - } - - if (canNullify) - memset((uint8_t *) buffer, '\0', shape::length(shape) * DataTypeUtils::sizeOfElement(ArrayOptions::dataType(shape))); - - auto array = new sd::NDArray(buffer, shape); - outputs[e] = array; +int execCustomOp2(Nd4jPointer *extraPointers, Nd4jLong hash, + Nd4jPointer opContext) { + try { + auto op = sd::ops::OpRegistrator::getInstance().getOperation(hash); + auto context = reinterpret_cast(opContext); - // and we want to release shape copy once we're done - delete []shape; + return op->execute(context); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + return 20; + } +} + +Nd4jStatus realExec(sd::ops::DeclarableOp *op, Nd4jPointer *extraPointers, + Nd4jLong hash, Nd4jPointer *inputBuffers, + Nd4jPointer *inputShapes, int numInputs, + Nd4jPointer *outputBuffers, Nd4jPointer *outputShapes, + int numOutputs, double *tArgs, int numTArgs, + Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs, + bool isInplace) { + if (op == nullptr) + nd4j_printf("Can't find requested operation: [%lld]\n", hash); + + // we're using the same fake nodeId everywhere here + + std::vector inputs(numInputs); + std::vector outputs(numOutputs); + std::vector ttArgs(numTArgs); + std::vector iiArgs(numIArgs); + std::vector biArgs(numBArgs); + + // filling block now with inputs + for (int e = 0; e < numInputs; e++) { + auto shape = reinterpret_cast(inputShapes[e]); + void *buffer = sd::ArrayOptions::arrayType(shape) == ArrayType::EMPTY + ? nullptr + : inputBuffers[e]; + + inputs[e] = new sd::NDArray(buffer, shape); + } + + // if not inplace - transferring output arrays + + if (!isInplace) + for (int e = 0; e < numOutputs; e++) { + // we want to keep original output shape intact + auto shape = + shape::copyShape(reinterpret_cast(outputShapes[e])); + void *buffer = sd::ArrayOptions::arrayType(shape) == ArrayType::EMPTY + ? nullptr + : outputBuffers[e]; + + // FIXME: revisit this. + bool canNullify = true; + for (int i = 0; i < numInputs; i++) { + void *ibuffer = sd::ArrayOptions::arrayType(shape) == ArrayType::EMPTY + ? nullptr + : inputBuffers[i]; + if (ibuffer == buffer) { + canNullify = false; + break; } + } - for (int e = 0; e < numIArgs; e++) - iiArgs[e] = iArgs[e]; + if (canNullify) + memset((uint8_t *)buffer, '\0', + shape::length(shape) * + DataTypeUtils::sizeOfElement(ArrayOptions::dataType(shape))); + auto array = new sd::NDArray(buffer, shape); + outputs[e] = array; - for (int e = 0; e < numTArgs; e++) - ttArgs[e] = tArgs[e]; + // and we want to release shape copy once we're done + delete[] shape; + } - for (int e = 0; e < numBArgs; e++) - biArgs[e] = bArgs[e]; + for (int e = 0; e < numIArgs; e++) iiArgs[e] = iArgs[e]; - // hypothetically at this point we have everything filled - auto hZ = op->execute(inputs, outputs, ttArgs, iiArgs, biArgs, std::vector(), isInplace); - //auto hZ = op->execute(inputs, ttArgs, iiArgs, isInplace); + for (int e = 0; e < numTArgs; e++) ttArgs[e] = tArgs[e]; + for (int e = 0; e < numBArgs; e++) biArgs[e] = bArgs[e]; + // hypothetically at this point we have everything filled + auto hZ = op->execute(inputs, outputs, ttArgs, iiArgs, biArgs, + std::vector(), isInplace); + // auto hZ = op->execute(inputs, ttArgs, iiArgs, isInplace); - if (!isInplace) - for (int e = 0; e < numOutputs; e++) { - //shape::printShapeInfoLinear("JVM output shape", (int *) outputShapes[e]); - //shape::printShapeInfoLinear("C++ output shape", (int *) outputs[e]->shapeInfo()); - //outputs[e]->printIndexedBuffer("C++ raw output"); - //outputs[e]->printBuffer("C++ indexed output"); + if (!isInplace) + for (int e = 0; e < numOutputs; e++) { + // shape::printShapeInfoLinear("JVM output shape", (int *) + // outputShapes[e]); shape::printShapeInfoLinear("C++ output shape", (int + // *) outputs[e]->shapeInfo()); outputs[e]->printIndexedBuffer("C++ raw + // output"); outputs[e]->printBuffer("C++ indexed output"); - if (outputs[e]->ordering() != shape::order(reinterpret_cast(outputShapes[e]))) - outputs[e]->streamline(shape::order(reinterpret_cast(outputShapes[e]))); - } + if (outputs[e]->ordering() != + shape::order(reinterpret_cast(outputShapes[e]))) + outputs[e]->streamline( + shape::order(reinterpret_cast(outputShapes[e]))); + } - for (auto v: inputs) - delete v; + for (auto v : inputs) delete v; - for (auto v: outputs) - delete v; + for (auto v : outputs) delete v; - return hZ; + return hZ; } - -int execCustomOp(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputs, Nd4jPointer* outputBuffers, Nd4jPointer* outputShapes, int numOutputs, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool* bArgs, int numBArgs, bool isInplace) { - try { - auto op = sd::ops::OpRegistrator::getInstance().getOperation(hash); - return realExec(op, extraPointers, hash, inputBuffers, inputShapes, numInputs, outputBuffers, outputShapes, numOutputs, tArgs, numTArgs, iArgs, numIArgs, bArgs, numBArgs, isInplace); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - return 1; - } +int execCustomOp(Nd4jPointer *extraPointers, Nd4jLong hash, + Nd4jPointer *inputBuffers, Nd4jPointer *inputShapes, + int numInputs, Nd4jPointer *outputBuffers, + Nd4jPointer *outputShapes, int numOutputs, double *tArgs, + int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, + int numBArgs, bool isInplace) { + try { + auto op = sd::ops::OpRegistrator::getInstance().getOperation(hash); + return realExec(op.get(), extraPointers, hash, inputBuffers, inputShapes, + numInputs, outputBuffers, outputShapes, numOutputs, tArgs, + numTArgs, iArgs, numIArgs, bArgs, numBArgs, isInplace); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + return 1; + } } -int registerGraph(Nd4jPointer *extraPointers, Nd4jLong graphId, Nd4jPointer flatBufferPointer) { - try { - auto graph = sd::graph::GraphExecutioner::importFromFlatPointer(flatBufferPointer); +int registerGraph(Nd4jPointer *extraPointers, Nd4jLong graphId, + Nd4jPointer flatBufferPointer) { + try { + auto graph = sd::graph::Graph::fromFlatPointer(flatBufferPointer); - sd::graph::GraphHolder::getInstance().registerGraph(graphId, graph); + // sd::graph::GraphHolder::getInstance().registerGraph(graphId, graph); - return ND4J_STATUS_OK; - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - return 1; - } + return ND4J_STATUS_OK; + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + return 1; + } } -static VariablesSet* executeStoredGraphT(Nd4jPointer *extraPointers, Nd4jLong graphId, Nd4jPointer *inputBuffers, Nd4jPointer *inputShapes, int* inputIndices, int numInputs) { - auto graph = sd::graph::GraphHolder::getInstance().cloneGraph(graphId); - auto varSpace = graph->getVariableSpace(); - - std::vector handles; - - for (int e = 0; e < numInputs; e++) { - auto idx = inputIndices[e]; - - // we'll delete this array later, together with cloned VariableSpace - auto array = new sd::NDArray(inputBuffers[e], reinterpret_cast(inputShapes[e])); - handles.emplace_back(array); - - if (varSpace->hasVariable(idx)) { - auto var = varSpace->getVariable(idx); - if (var->hasNDArray()) - delete var->getNDArray(); - - var->setNDArray(array); - } else - varSpace->putVariable(idx, array); - } - - auto hZ = sd::graph::GraphExecutioner::execute(graph, varSpace); - auto varSet = new sd::graph::VariablesSet(hZ); - - if (hZ == ND4J_STATUS_OK) { - // pull back results, and provide them - auto outputs = graph->fetchOutputs(); - for (int e = 0; e < outputs->size(); e++) { - // we're only getting variable ID/Index from original grap. values will be taken from cloned workspace - std::pair varId(outputs->at(e)->id(), outputs->at(e)->index()); - - auto var = varSpace->getVariable(varId); - - varSet->push_back(var->clone()); - } - - delete outputs; - } - - delete graph; - - return varSet; +sd::graph::VariablesSet *executeStoredGraph(Nd4jPointer *extraPointers, + Nd4jLong graphId, + Nd4jPointer *inputBuffers, + Nd4jPointer *inputShapes, + int *inputIndices, int numInputs) { + return nullptr; } -sd::graph::VariablesSet* executeStoredGraph(Nd4jPointer *extraPointers, Nd4jLong graphId, Nd4jPointer *inputBuffers, Nd4jPointer *inputShapes, int* inputIndices, int numInputs) { - return nullptr; +Nd4jLong getVariablesSetSize(sd::graph::VariablesSet *set) { + return set->size(); } -Nd4jLong getVariablesSetSize(sd::graph::VariablesSet* set) { - return set->size(); +Nd4jStatus getVariablesSetStatus(sd::graph::VariablesSet *set) { + return set->status(); } -Nd4jStatus getVariablesSetStatus(sd::graph::VariablesSet* set) { - return set->status(); +sd::graph::Variable *getVariable(sd::graph::VariablesSet *set, Nd4jLong i) { + return set->at(i); } -sd::graph::Variable* getVariable(sd::graph::VariablesSet* set, Nd4jLong i) { - return set->at(i); -} +int getVariableId(sd::graph::Variable *variable) { return variable->id(); } -int getVariableId(sd::graph::Variable* variable) { - return variable->id(); +int getVariableIndex(sd::graph::Variable *variable) { + return variable->index(); } -int getVariableIndex(sd::graph::Variable* variable) { - return variable->index(); +const char *getVariableName(sd::graph::Variable *variable) { + return variable->getName().c_str(); } -const char* getVariableName(sd::graph::Variable* variable) { - return variable->getName()->c_str(); +Nd4jLong const *getVariableShape(sd::graph::Variable *variable) { + return const_cast(variable->getNDArray()->shapeInfo()); } -Nd4jLong const* getVariableShape(sd::graph::Variable* variable) { - return const_cast(variable->getNDArray()->shapeInfo()); -} - -void* getVariableBuffer(sd::graph::Variable* variable) { - return variable->getNDArray()->buffer(); +void *getVariableBuffer(sd::graph::Variable *variable) { + return variable->getNDArray()->buffer(); } int unregisterGraph(Nd4jPointer *extraPointers, Nd4jLong graphId) { - - sd::graph::GraphHolder::getInstance().dropGraphAny(graphId); - - return sd::Status::OK(); + sd::graph::GraphHolder::getInstance().forgetGraph(graphId); + return Status::OK(); } void deletePointerArray(Nd4jPointer pointer) { - auto ptr = reinterpret_cast(pointer); - delete[] ptr; + auto ptr = reinterpret_cast(pointer); + delete[] ptr; } void deleteCharArray(Nd4jPointer pointer) { - auto ptr = reinterpret_cast(pointer); - delete[] ptr; + auto ptr = reinterpret_cast(pointer); + delete[] ptr; } void deleteIntArray(Nd4jPointer pointer) { - auto ptr = reinterpret_cast(pointer); - delete[] ptr; + auto ptr = reinterpret_cast(pointer); + delete[] ptr; } void deleteLongArray(Nd4jPointer pointer) { - auto ptr = reinterpret_cast(pointer); - delete[] ptr; -} - -void deleteVariablesSet(sd::graph::VariablesSet* pointer) { - delete pointer; -} - -const char* getAllOperations() { - return sd::OpTracker::getInstance().exportOperations(); + auto ptr = reinterpret_cast(pointer); + delete[] ptr; } +void deleteVariablesSet(sd::graph::VariablesSet *pointer) { delete pointer; } -Nd4jPointer getGraphState(Nd4jLong id) { - return (Nd4jPointer) new sd::graph::GraphState(id); +const char *getAllOperations() { + return sd::OpTracker::getInstance().exportOperations(); } -void deleteGraphState(Nd4jPointer state) { - auto stateP = reinterpret_cast(state); - delete stateP; -} - -Nd4jStatus execCustomOpWithScope_(Nd4jPointer *extraPointers, sd::graph::GraphState *state, Nd4jLong opHash, Nd4jLong *scopes, int numScopes, Nd4jPointer *inputBuffers, Nd4jPointer *inputShapes, int numInputs, Nd4jPointer *outputBuffers, Nd4jPointer *outputShapes, int numOutputs) { - /** - * That's basically exec, with VariableSpace provided in GraphState: - * depending on operation (i.e. while of if), different logic executors could be used - */ - - auto graph = state->graph(); - auto varSpace = state->variableSpace(); - - // Node is dynamically created, and has nothing beyond it: only inputs and outputs - // this node has id of 0, and inputs are - Node node(OpType_LOGIC, opHash, 0); - - // mapping inputs - for (int e = 0; e < numInputs; e++) { - auto buffer = inputBuffers[e]; - auto shapeInfo = reinterpret_cast(inputShapes[e]); - - auto array = new sd::NDArray(buffer, shapeInfo, varSpace->launchContext()); - - // now we just put array to VarSpace - varSpace->putVariable(0, e, array); - node.pickInput(0, e); - } - - // mapping scopes - for (int e = 0; e < numScopes; e++) { - // we should check scope existence in GraphState/Graph - int scopeId = (int) scopes[e]; - if (!state->hasScope(scopeId)) { - // nd4j_printf("execCustomOpWithScope: referenced scope [%i] doesn't exist\n", scopeId); - return Status::THROW(); - } - node.pickInput(scopeId, 0); - } - - auto hZ = LogicExecutor::processNode(graph, &node); - if (hZ != Status::OK()) - return hZ; - - // mapping outputs - - for (int e = 0; e < numOutputs; e++) { - auto buffer = outputBuffers[e]; - auto shapeInfo = reinterpret_cast(outputShapes[e]); - - NDArray array(buffer, shapeInfo, varSpace->launchContext()); - - // now we just put array to VarSpace to the same ID - //varSpace->putVariable(0, e, array); - - auto t = varSpace->getVariable(0, e)->getNDArray(); - array.assign(t); - } - - // removing input variables - for (int e = 0; e < numInputs; e++) { - varSpace->dropVariable(0, e); - } - - - // after some bla-bla-bla we should have Graph and Node for current op - return Status::OK(); -} - -Nd4jStatus execCustomOpWithScope(Nd4jPointer *extraPointers, Nd4jPointer state, Nd4jLong opHash, Nd4jLong *scopes, int numScopes, Nd4jPointer *inputBuffers, Nd4jPointer *inputShapes, int numInputs, Nd4jPointer *outputBuffers, Nd4jPointer *outputShapes, int numOutputs) { - try { - return execCustomOpWithScope_(extraPointers, reinterpret_cast(state), opHash, scopes, numScopes, inputBuffers, inputShapes, numInputs, outputBuffers, outputShapes, numOutputs); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - return 1; - } +Nd4jStatus execCustomOpWithScope(Nd4jPointer *extraPointers, Nd4jPointer state, + Nd4jLong opHash, Nd4jLong *scopes, + int numScopes, Nd4jPointer *inputBuffers, + Nd4jPointer *inputShapes, int numInputs, + Nd4jPointer *outputBuffers, + Nd4jPointer *outputShapes, int numOutputs) { + return 0; } void deleteResultWrapper(Nd4jPointer ptr) { - // just 0 room for compiler s@!t - auto p = reinterpret_cast(ptr); - delete p; + // just 0 room for compiler s@!t + auto p = reinterpret_cast(ptr); + delete p; } /* * TypeDef: - * void convertTypes(Nd4jPointer *extras, int srcType, Nd4jPointer hX, long N, int dstType, Nd4jPointer hZ); + * void convertTypes(Nd4jPointer *extras, int srcType, Nd4jPointer hX, long + * N, int dstType, Nd4jPointer hZ); */ -void convertTypes(Nd4jPointer *extras, int srcType, Nd4jPointer hX, Nd4jLong N, int dstType, Nd4jPointer hZ) { - auto hx = reinterpret_cast(hX); - auto hz = reinterpret_cast(hZ); - - if (srcType == ND4J_FLOAT8) { - if (dstType == ND4J_FLOAT8) { - // convertGeneric(hx, N, hz); - } else if (dstType == ND4J_INT8) { - //sd::TypeCast::convertGeneric(nullptr, hx, N, hz); - } else if (dstType == ND4J_UINT8) { - //sd::TypeCast::convertGeneric(nullptr, hx, N, hz); - } else if (dstType == ND4J_FLOAT16) { - //sd::TypeCast::convertGeneric(nullptr, hx, N, hz); - } else if (dstType == ND4J_INT16) { - //sd::TypeCast::convertGeneric(nullptr, hx, N, hz); - } else if (dstType == ND4J_UINT16) { - //sd::TypeCast::convertGeneric(nullptr, hx, N, hz); - } else if (dstType == ND4J_FLOAT24) { - - } else if (dstType == ND4J_FLOAT32) { - //sd::TypeCast::convertGeneric(nullptr, hx, N, hz); - } else if (dstType == ND4J_DOUBLE) { - //sd::TypeCast::convertGeneric(nullptr, hx, N, hz); - } else { - //nd4j_printf("Unsupported types conversion: [%i] -> [%i]\n", srcType, dstType); - } - } else if (srcType == ND4J_INT8) { - if (dstType == ND4J_FLOAT8) { - //sd::TypeCast::convertGeneric(nullptr, hx, N, hz); - } else if (dstType == ND4J_INT8) { - //convertGeneric(hx, N, hz); - } else if (dstType == ND4J_UINT8) { - sd::TypeCast::convertGeneric(nullptr, hx, N, hz); - } else if (dstType == ND4J_FLOAT16) { - sd::TypeCast::convertGeneric(nullptr, hx, N, hz); - } else if (dstType == ND4J_INT16) { - sd::TypeCast::convertGeneric(nullptr, hx, N, hz); - } else if (dstType == ND4J_UINT16) { - //sd::TypeCast::convertGeneric(nullptr, hx, N, hz); - } else if (dstType == ND4J_FLOAT24) { - // TODO: eventually we might want to add it - } else if (dstType == ND4J_FLOAT32) { - sd::TypeCast::convertGeneric(nullptr, hx, N, hz); - } else if (dstType == ND4J_DOUBLE) { - sd::TypeCast::convertGeneric(nullptr, hx, N, hz); - } else { - nd4j_printf("Unsupported types conversion: [%i] -> [%i]\n", srcType, dstType); - } - } else if (srcType == ND4J_UINT8) { - if (dstType == ND4J_FLOAT8) { - // sd::TypeCast::convertGeneric(nullptr, hx, N, hz); - } else if (dstType == ND4J_INT8) { - sd::TypeCast::convertGeneric(nullptr, hx, N, hz); - } else if (dstType == ND4J_UINT8) { - sd::TypeCast::convertGeneric(nullptr, hx, N, hz); - } else if (dstType == ND4J_FLOAT16) { - sd::TypeCast::convertGeneric(nullptr, hx, N, hz); - } else if (dstType == ND4J_INT16) { - sd::TypeCast::convertGeneric(nullptr, hx, N, hz); - } else if (dstType == ND4J_UINT16) { - // sd::TypeCast::convertGeneric(nullptr, hx, N, hz); - } else if (dstType == ND4J_FLOAT24) { - // TODO: still might want to add - } else if (dstType == ND4J_FLOAT32) { - sd::TypeCast::convertGeneric(nullptr, hx, N, hz); - } else if (dstType == ND4J_DOUBLE) { - sd::TypeCast::convertGeneric(nullptr, hx, N, hz); - } else { - nd4j_printf("Unsupported types conversion: [%i] -> [%i]\n", srcType, dstType); - } - } else if (srcType == ND4J_FLOAT16) { - if (dstType == ND4J_FLOAT8) { - // sd::TypeCast::convertGeneric(nullptr, hx, N, hz); - } else if (dstType == ND4J_INT8) { - sd::TypeCast::convertGeneric(nullptr, hx, N, hz); - } else if (dstType == ND4J_UINT8) { - sd::TypeCast::convertGeneric(nullptr, hx, N, hz); - } else if (dstType == ND4J_FLOAT16) { - sd::TypeCast::convertGeneric(nullptr, hx, N, hz); - } else if (dstType == ND4J_INT16) { - sd::TypeCast::convertGeneric(nullptr, hx, N, hz); - } else if (dstType == ND4J_UINT16) { -// sd::TypeCast::convertGeneric(nullptr, hx, N, hz); - } else if (dstType == ND4J_FLOAT24) { - // TODO: .... ^^^ - } else if (dstType == ND4J_FLOAT32) { - sd::TypeCast::convertGeneric(nullptr, hx, N, hz); - } else if (dstType == ND4J_DOUBLE) { - sd::TypeCast::convertGeneric(nullptr, hx, N, hz); - } else if (dstType == ND4J_THRESHOLD) { - sd::TypeCast::convertToThreshold(nullptr, hx, N, hz); - } else { - nd4j_printf("Unsupported types conversion: [%i] -> [%i]\n", srcType, dstType); - } - } else if (srcType == ND4J_INT16) { - if (dstType == ND4J_FLOAT8) { - // sd::TypeCast::convertGeneric(nullptr, hx, N, hz); - } else if (dstType == ND4J_INT8) { - sd::TypeCast::convertGeneric(nullptr, hx, N, hz); - } else if (dstType == ND4J_UINT8) { - sd::TypeCast::convertGeneric(nullptr, hx, N, hz); - } else if (dstType == ND4J_FLOAT16) { - sd::TypeCast::convertGeneric(nullptr, hx, N, hz); - } else if (dstType == ND4J_INT16) { - //sd::TypeCast::convertGeneric(nullptr, hx, N, hz); - } else if (dstType == ND4J_UINT16) { -// sd::TypeCast::convertGeneric(nullptr, hx, N, hz); - } else if (dstType == ND4J_FLOAT24) { - // TODO... - } else if (dstType == ND4J_FLOAT32) { - sd::TypeCast::convertGeneric(nullptr, hx, N, hz); - } else if (dstType == ND4J_DOUBLE) { - sd::TypeCast::convertGeneric(nullptr, hx, N, hz); - } else { - printf("Unsupported types conversion: [%i] -> [%i]\n", srcType, dstType); - } - } else if (srcType == ND4J_FLOAT24) { - - } else if (srcType == ND4J_FLOAT32) { - if (dstType == ND4J_FLOAT8) { - // sd::TypeCast::convertGeneric(nullptr, hx, N, hz); - } else if (dstType == ND4J_INT8) { - sd::TypeCast::convertGeneric(nullptr, hx, N, hz); - } else if (dstType == ND4J_UINT8) { - sd::TypeCast::convertGeneric(nullptr, hx, N, hz); - } else if (dstType == ND4J_FLOAT16) { - sd::TypeCast::convertGeneric(nullptr, hx, N, hz); - } else if (dstType == ND4J_INT16) { - sd::TypeCast::convertGeneric(nullptr, hx, N, hz); - } else if (dstType == ND4J_UINT16) { -// sd::TypeCast::convertGeneric(nullptr, hx, N, hz); - } else if (dstType == ND4J_FLOAT24) { - - } else if (dstType == ND4J_DOUBLE) { - sd::TypeCast::convertGeneric(nullptr, hx, N, hz); - } else if (dstType == ND4J_THRESHOLD) { - sd::TypeCast::convertToThreshold(nullptr, hx, N, hz); - } else { - nd4j_printf("Unsupported types conversion: [%i] -> [%i]\n", srcType, dstType); - } - } else if (srcType == ND4J_DOUBLE) { - if (dstType == ND4J_FLOAT8) { - // sd::TypeCast::convertGeneric(nullptr, hx, N, hz); - } else if (dstType == ND4J_INT8) { - sd::TypeCast::convertGeneric(nullptr, hx, N, hz); - } else if (dstType == ND4J_UINT8) { - sd::TypeCast::convertGeneric(nullptr, hx, N, hz); - } else if (dstType == ND4J_FLOAT16) { - sd::TypeCast::convertGeneric(nullptr, hx, N, hz); - } else if (dstType == ND4J_INT16) { - sd::TypeCast::convertGeneric(nullptr, hx, N, hz); - } else if (dstType == ND4J_UINT16) { -// sd::TypeCast::convertGeneric(nullptr, hx, N, hz); - } else if (dstType == ND4J_FLOAT24) { - - } else if (dstType == ND4J_FLOAT32) { - sd::TypeCast::convertGeneric(nullptr, hx, N, hz); - } else if (dstType == ND4J_DOUBLE) { - // - } else if (dstType == ND4J_THRESHOLD) { - sd::TypeCast::convertToThreshold(nullptr, hx, N, hz); - } else { - nd4j_printf("Unsupported types conversion: [%i] -> [%i]\n", srcType, dstType); - } - } else if (srcType == ND4J_THRESHOLD) { - if (dstType == ND4J_FLOAT16) { - sd::TypeCast::convertFromThreshold(nullptr, hx, N, hz); - } else if (dstType == ND4J_FLOAT32) { - sd::TypeCast::convertFromThreshold(nullptr, hx, N, hz); - } else if (dstType == ND4J_DOUBLE) { - sd::TypeCast::convertFromThreshold(nullptr, hx, N, hz); - } else { - nd4j_printf("Unsupported types conversion: [%i] -> [%i]\n", srcType, dstType); - } +void convertTypes(Nd4jPointer *extras, int srcType, Nd4jPointer hX, Nd4jLong N, + int dstType, Nd4jPointer hZ) { + auto hx = reinterpret_cast(hX); + auto hz = reinterpret_cast(hZ); + + if (srcType == ND4J_FLOAT8) { + if (dstType == ND4J_FLOAT8) { + // convertGeneric(hx, N, hz); + } else if (dstType == ND4J_INT8) { + // sd::TypeCast::convertGeneric(nullptr, hx, N, hz); + } else if (dstType == ND4J_UINT8) { + // sd::TypeCast::convertGeneric(nullptr, hx, N, + // hz); + } else if (dstType == ND4J_FLOAT16) { + // sd::TypeCast::convertGeneric(nullptr, hx, N, hz); + } else if (dstType == ND4J_INT16) { + // sd::TypeCast::convertGeneric(nullptr, hx, N, + // hz); + } else if (dstType == ND4J_UINT16) { + // sd::TypeCast::convertGeneric(nullptr, hx, N, + // hz); + } else if (dstType == ND4J_FLOAT24) { + } else if (dstType == ND4J_FLOAT32) { + // sd::TypeCast::convertGeneric(nullptr, hx, N, hz); + } else if (dstType == ND4J_DOUBLE) { + // sd::TypeCast::convertGeneric(nullptr, hx, N, hz); + } else { + // nd4j_printf("Unsupported types conversion: [%i] -> [%i]\n", srcType, + // dstType); + } + } else if (srcType == ND4J_INT8) { + if (dstType == ND4J_FLOAT8) { + // sd::TypeCast::convertGeneric(nullptr, hx, N, hz); + } else if (dstType == ND4J_INT8) { + // convertGeneric(hx, N, hz); + } else if (dstType == ND4J_UINT8) { + sd::TypeCast::convertGeneric(nullptr, hx, N, hz); + } else if (dstType == ND4J_FLOAT16) { + sd::TypeCast::convertGeneric(nullptr, hx, N, hz); + } else if (dstType == ND4J_INT16) { + sd::TypeCast::convertGeneric(nullptr, hx, N, hz); + } else if (dstType == ND4J_UINT16) { + // sd::TypeCast::convertGeneric(nullptr, hx, N, hz); + } else if (dstType == ND4J_FLOAT24) { + // TODO: eventually we might want to add it + } else if (dstType == ND4J_FLOAT32) { + sd::TypeCast::convertGeneric(nullptr, hx, N, hz); + } else if (dstType == ND4J_DOUBLE) { + sd::TypeCast::convertGeneric(nullptr, hx, N, hz); + } else { + nd4j_printf("Unsupported types conversion: [%i] -> [%i]\n", srcType, + dstType); + } + } else if (srcType == ND4J_UINT8) { + if (dstType == ND4J_FLOAT8) { + // sd::TypeCast::convertGeneric(nullptr, hx, N, + // hz); + } else if (dstType == ND4J_INT8) { + sd::TypeCast::convertGeneric(nullptr, hx, N, hz); + } else if (dstType == ND4J_UINT8) { + sd::TypeCast::convertGeneric(nullptr, hx, N, hz); + } else if (dstType == ND4J_FLOAT16) { + sd::TypeCast::convertGeneric(nullptr, hx, N, hz); + } else if (dstType == ND4J_INT16) { + sd::TypeCast::convertGeneric(nullptr, hx, N, hz); + } else if (dstType == ND4J_UINT16) { + // sd::TypeCast::convertGeneric(nullptr, hx, N, + // hz); + } else if (dstType == ND4J_FLOAT24) { + // TODO: still might want to add + } else if (dstType == ND4J_FLOAT32) { + sd::TypeCast::convertGeneric(nullptr, hx, N, hz); + } else if (dstType == ND4J_DOUBLE) { + sd::TypeCast::convertGeneric(nullptr, hx, N, hz); + } else { + nd4j_printf("Unsupported types conversion: [%i] -> [%i]\n", srcType, + dstType); + } + } else if (srcType == ND4J_FLOAT16) { + if (dstType == ND4J_FLOAT8) { + // sd::TypeCast::convertGeneric(nullptr, hx, N, + // hz); + } else if (dstType == ND4J_INT8) { + sd::TypeCast::convertGeneric(nullptr, hx, N, hz); + } else if (dstType == ND4J_UINT8) { + sd::TypeCast::convertGeneric(nullptr, hx, N, hz); + } else if (dstType == ND4J_FLOAT16) { + sd::TypeCast::convertGeneric(nullptr, hx, N, hz); + } else if (dstType == ND4J_INT16) { + sd::TypeCast::convertGeneric(nullptr, hx, N, hz); + } else if (dstType == ND4J_UINT16) { + // sd::TypeCast::convertGeneric(nullptr, hx, + // N, hz); + } else if (dstType == ND4J_FLOAT24) { + // TODO: .... ^^^ + } else if (dstType == ND4J_FLOAT32) { + sd::TypeCast::convertGeneric(nullptr, hx, N, hz); + } else if (dstType == ND4J_DOUBLE) { + sd::TypeCast::convertGeneric(nullptr, hx, N, hz); + } else if (dstType == ND4J_THRESHOLD) { + sd::TypeCast::convertToThreshold(nullptr, hx, N, hz); + } else { + nd4j_printf("Unsupported types conversion: [%i] -> [%i]\n", srcType, + dstType); + } + } else if (srcType == ND4J_INT16) { + if (dstType == ND4J_FLOAT8) { + // sd::TypeCast::convertGeneric(nullptr, hx, N, + // hz); + } else if (dstType == ND4J_INT8) { + sd::TypeCast::convertGeneric(nullptr, hx, N, hz); + } else if (dstType == ND4J_UINT8) { + sd::TypeCast::convertGeneric(nullptr, hx, N, hz); + } else if (dstType == ND4J_FLOAT16) { + sd::TypeCast::convertGeneric(nullptr, hx, N, hz); + } else if (dstType == ND4J_INT16) { + // sd::TypeCast::convertGeneric(nullptr, hx, N, hz); + } else if (dstType == ND4J_UINT16) { + // sd::TypeCast::convertGeneric(nullptr, hx, + // N, hz); + } else if (dstType == ND4J_FLOAT24) { + // TODO... + } else if (dstType == ND4J_FLOAT32) { + sd::TypeCast::convertGeneric(nullptr, hx, N, hz); + } else if (dstType == ND4J_DOUBLE) { + sd::TypeCast::convertGeneric(nullptr, hx, N, hz); } else { - nd4j_printf("Unsupported types conversion: [%i] -> [%i]\n", srcType, dstType); + printf("Unsupported types conversion: [%i] -> [%i]\n", srcType, dstType); + } + } else if (srcType == ND4J_FLOAT24) { + } else if (srcType == ND4J_FLOAT32) { + if (dstType == ND4J_FLOAT8) { + // sd::TypeCast::convertGeneric(nullptr, hx, N, hz); + } else if (dstType == ND4J_INT8) { + sd::TypeCast::convertGeneric(nullptr, hx, N, hz); + } else if (dstType == ND4J_UINT8) { + sd::TypeCast::convertGeneric(nullptr, hx, N, hz); + } else if (dstType == ND4J_FLOAT16) { + sd::TypeCast::convertGeneric(nullptr, hx, N, hz); + } else if (dstType == ND4J_INT16) { + sd::TypeCast::convertGeneric(nullptr, hx, N, hz); + } else if (dstType == ND4J_UINT16) { + // sd::TypeCast::convertGeneric(nullptr, hx, + // N, hz); + } else if (dstType == ND4J_FLOAT24) { + } else if (dstType == ND4J_DOUBLE) { + sd::TypeCast::convertGeneric(nullptr, hx, N, hz); + } else if (dstType == ND4J_THRESHOLD) { + sd::TypeCast::convertToThreshold(nullptr, hx, N, hz); + } else { + nd4j_printf("Unsupported types conversion: [%i] -> [%i]\n", srcType, + dstType); + } + } else if (srcType == ND4J_DOUBLE) { + if (dstType == ND4J_FLOAT8) { + // sd::TypeCast::convertGeneric(nullptr, hx, N, hz); + } else if (dstType == ND4J_INT8) { + sd::TypeCast::convertGeneric(nullptr, hx, N, hz); + } else if (dstType == ND4J_UINT8) { + sd::TypeCast::convertGeneric(nullptr, hx, N, hz); + } else if (dstType == ND4J_FLOAT16) { + sd::TypeCast::convertGeneric(nullptr, hx, N, hz); + } else if (dstType == ND4J_INT16) { + sd::TypeCast::convertGeneric(nullptr, hx, N, hz); + } else if (dstType == ND4J_UINT16) { + // sd::TypeCast::convertGeneric(nullptr, hx, + // N, hz); + } else if (dstType == ND4J_FLOAT24) { + } else if (dstType == ND4J_FLOAT32) { + sd::TypeCast::convertGeneric(nullptr, hx, N, hz); + } else if (dstType == ND4J_DOUBLE) { + // + } else if (dstType == ND4J_THRESHOLD) { + sd::TypeCast::convertToThreshold(nullptr, hx, N, hz); + } else { + nd4j_printf("Unsupported types conversion: [%i] -> [%i]\n", srcType, + dstType); + } + } else if (srcType == ND4J_THRESHOLD) { + if (dstType == ND4J_FLOAT16) { + sd::TypeCast::convertFromThreshold(nullptr, hx, N, hz); + } else if (dstType == ND4J_FLOAT32) { + sd::TypeCast::convertFromThreshold(nullptr, hx, N, hz); + } else if (dstType == ND4J_DOUBLE) { + sd::TypeCast::convertFromThreshold(nullptr, hx, N, hz); + } else { + nd4j_printf("Unsupported types conversion: [%i] -> [%i]\n", srcType, + dstType); } + } else { + nd4j_printf("Unsupported types conversion: [%i] -> [%i]\n", srcType, + dstType); + } } /* -void fillUtf8String(Nd4jPointer *extraPointers, const char **strings, int numStrings, Nd4jPointer buffer) { - auto hZ = reinterpret_cast(buffer); - for (int e = 0; e < numStrings; e++) { - hZ[e] = reinterpret_cast(createUtf8String(extraPointers, strings[e])); +void fillUtf8String(Nd4jPointer *extraPointers, const char **strings, int +numStrings, Nd4jPointer buffer) { auto hZ = +reinterpret_cast(buffer); for (int e = 0; e < numStrings; e++) +{ hZ[e] = reinterpret_cast(createUtf8String(extraPointers, +strings[e])); } } */ -Nd4jPointer createUtf8String(Nd4jPointer *extraPointers, const char *string, int length) { - auto u = new sd::utf8string(string, length); - return reinterpret_cast(u); +Nd4jPointer createUtf8String(Nd4jPointer *extraPointers, const char *string, + int length) { + auto u = new sd::utf8string(string, length); + return reinterpret_cast(u); } Nd4jLong getUtf8StringLength(Nd4jPointer *extraPointers, Nd4jPointer ptr) { - return reinterpret_cast(ptr)->_length; + return reinterpret_cast(ptr)->_length; } -char* getUtf8StringBuffer(Nd4jPointer *extraPointers, Nd4jPointer ptr) { - return reinterpret_cast(ptr)->_buffer; +char *getUtf8StringBuffer(Nd4jPointer *extraPointers, Nd4jPointer ptr) { + return reinterpret_cast(ptr)->_buffer; } void deleteUtf8String(Nd4jPointer *extraPointers, Nd4jPointer ptr) { - delete(reinterpret_cast(ptr)); + delete (reinterpret_cast(ptr)); } template -static void _scatterUpdate( - Nd4jPointer *extraPointers, int opCode, int numOfSubArrs, - void* hX, const Nd4jLong* hXShapeInfo, const Nd4jLong* hXOffsets, - void* dX, const Nd4jLong* dXShapeInfo, const Nd4jLong* dXOffsets, - void* hY, const Nd4jLong* hYShapeInfo, const Nd4jLong* hYOffsets, - void* dY, const Nd4jLong* dYShapeInfo, const Nd4jLong* dYOffsets, - void* vIindexes, const Nd4jLong* hIndicesShapeInfo, void* dIindexes, const Nd4jLong* dIndicesShapeInfo) { - - auto hIindexes = reinterpret_cast(vIindexes); - auto func = PRAGMA_THREADS_DO { - for (int i = 0; i < numOfSubArrs; ++i) { - int threadIndex = thread_id; - const auto xIndex = hIindexes[i]; - const bool isOwner = xIndex < numThreads ? threadIndex == xIndex : threadIndex == xIndex % numThreads; - - if (!isOwner) - continue; - - NDArray inSubArr(reinterpret_cast(hX) + (hXOffsets[hIindexes[i]] * DataTypeUtils::sizeOf(hXShapeInfo)), hXShapeInfo); - NDArray updSubArr(reinterpret_cast(hY) + (hYOffsets[i] * DataTypeUtils::sizeOf(hXShapeInfo)), hYShapeInfo); - - if (inSubArr.lengthOf() != updSubArr.lengthOf()) { - continue; - } - - switch (opCode) { - case 0: - inSubArr.applyPairwiseTransform(pairwise::Add, updSubArr, inSubArr); - break; - case 1: - inSubArr.applyPairwiseTransform(pairwise::Subtract, updSubArr, inSubArr); - break; - case 2: - inSubArr.applyPairwiseTransform(pairwise::Multiply, updSubArr, inSubArr); - break; - case 3: - inSubArr.applyPairwiseTransform(pairwise::Divide, updSubArr, inSubArr); - break; - case 4: - inSubArr.applyPairwiseTransform(pairwise::ReverseSubtract, updSubArr, inSubArr); - break; - case 5: - inSubArr.applyPairwiseTransform(pairwise::ReverseDivide, updSubArr, inSubArr); - break; - case 6: - inSubArr.applyPairwiseTransform(pairwise::CopyPws, updSubArr, inSubArr); - break; - default: - continue; - } - } - }; - - samediff::Threads::parallel_do(func); +static void _scatterUpdate( + Nd4jPointer *extraPointers, int opCode, int numOfSubArrs, void *hX, + const Nd4jLong *hXShapeInfo, const Nd4jLong *hXOffsets, void *dX, + const Nd4jLong *dXShapeInfo, const Nd4jLong *dXOffsets, void *hY, + const Nd4jLong *hYShapeInfo, const Nd4jLong *hYOffsets, void *dY, + const Nd4jLong *dYShapeInfo, const Nd4jLong *dYOffsets, void *vIindexes, + const Nd4jLong *hIndicesShapeInfo, void *dIindexes, + const Nd4jLong *dIndicesShapeInfo) { + auto hIindexes = reinterpret_cast(vIindexes); + auto func = PRAGMA_THREADS_DO { + for (int i = 0; i < numOfSubArrs; ++i) { + int threadIndex = thread_id; + const auto xIndex = hIindexes[i]; + const bool isOwner = xIndex < numThreads + ? threadIndex == xIndex + : threadIndex == xIndex % numThreads; + + if (!isOwner) continue; + + NDArray inSubArr( + reinterpret_cast(hX) + + (hXOffsets[hIindexes[i]] * DataTypeUtils::sizeOf(hXShapeInfo)), + hXShapeInfo); + NDArray updSubArr(reinterpret_cast(hY) + + (hYOffsets[i] * DataTypeUtils::sizeOf(hXShapeInfo)), + hYShapeInfo); + + if (inSubArr.lengthOf() != updSubArr.lengthOf()) { + continue; + } + + switch (opCode) { + case 0: + inSubArr.applyPairwiseTransform(pairwise::Add, updSubArr, inSubArr); + break; + case 1: + inSubArr.applyPairwiseTransform(pairwise::Subtract, updSubArr, + inSubArr); + break; + case 2: + inSubArr.applyPairwiseTransform(pairwise::Multiply, updSubArr, + inSubArr); + break; + case 3: + inSubArr.applyPairwiseTransform(pairwise::Divide, updSubArr, + inSubArr); + break; + case 4: + inSubArr.applyPairwiseTransform(pairwise::ReverseSubtract, updSubArr, + inSubArr); + break; + case 5: + inSubArr.applyPairwiseTransform(pairwise::ReverseDivide, updSubArr, + inSubArr); + break; + case 6: + inSubArr.applyPairwiseTransform(pairwise::CopyPws, updSubArr, + inSubArr); + break; + default: + continue; + } + } + }; + + samediff::Threads::parallel_do(func); } //////////////////////////////////////////////////////////////////////// void scatterUpdate(Nd4jPointer *extraPointers, int opCode, int numOfSubArrs, - void* hX, const Nd4jLong* hXShapeInfo, const Nd4jLong* hXOffsets, - void* dX, const Nd4jLong* dXShapeInfo, const Nd4jLong* dXOffsets, - void* hY, const Nd4jLong* hYShapeInfo, const Nd4jLong* hYOffsets, - void* dY, const Nd4jLong* dYShapeInfo, const Nd4jLong* dYOffsets, - void* hIindexes, const Nd4jLong* hIndicesShapeInfo, void* dIindexes, const Nd4jLong* dIndicesShapeInfo) { - auto iType = ArrayOptions::dataType(hIndicesShapeInfo); - - try { - BUILD_SINGLE_SELECTOR(iType, _scatterUpdate, (extraPointers, opCode, numOfSubArrs, hX, hXShapeInfo, hXOffsets, dX, dXShapeInfo, dXOffsets, hY, hYShapeInfo, hYOffsets, dY, dYShapeInfo, dYOffsets, hIindexes, hIndicesShapeInfo, dIindexes, dIndicesShapeInfo), INDEXING_TYPES); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } -} - - -void inspectArray(Nd4jPointer *extraPointers, Nd4jPointer buffer, Nd4jLong *shapeInfo, Nd4jPointer specialBuffer, Nd4jLong *specialShapeInfo, Nd4jPointer debugInfo) { - try { - auto p = reinterpret_cast(debugInfo); - NDArray array(buffer, shapeInfo); - sd::DebugHelper::retrieveDebugStatistics(p, &array); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } + void *hX, const Nd4jLong *hXShapeInfo, + const Nd4jLong *hXOffsets, void *dX, + const Nd4jLong *dXShapeInfo, const Nd4jLong *dXOffsets, + void *hY, const Nd4jLong *hYShapeInfo, + const Nd4jLong *hYOffsets, void *dY, + const Nd4jLong *dYShapeInfo, const Nd4jLong *dYOffsets, + void *hIindexes, const Nd4jLong *hIndicesShapeInfo, + void *dIindexes, const Nd4jLong *dIndicesShapeInfo) { + auto iType = ArrayOptions::dataType(hIndicesShapeInfo); + + try { + BUILD_SINGLE_SELECTOR( + iType, _scatterUpdate, + (extraPointers, opCode, numOfSubArrs, hX, hXShapeInfo, hXOffsets, dX, + dXShapeInfo, dXOffsets, hY, hYShapeInfo, hYOffsets, dY, dYShapeInfo, + dYOffsets, hIindexes, hIndicesShapeInfo, dIindexes, dIndicesShapeInfo), + INDEXING_TYPES); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } +} + +void inspectArray(Nd4jPointer *extraPointers, Nd4jPointer buffer, + Nd4jLong *shapeInfo, Nd4jPointer specialBuffer, + Nd4jLong *specialShapeInfo, Nd4jPointer debugInfo) { + try { + auto p = reinterpret_cast(debugInfo); + NDArray array(buffer, shapeInfo); + sd::DebugHelper::retrieveDebugStatistics(p, &array); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } void tryPointer(Nd4jPointer extra, Nd4jPointer p, int len) { - try { - auto buf = reinterpret_cast(p); - int cnt = 0; - for (int i = 0; i < len; i++) - cnt += buf[cnt]; - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } -} - -sd::ConstantShapeBuffer* shapeBuffer(int rank, Nd4jLong *shape, Nd4jLong *strides, sd::DataType dtype, char order, Nd4jLong ews, bool empty) { - try { - auto buffer = new ConstantShapeBuffer(); - *buffer = sd::ConstantShapeHelper::getInstance().bufferForShapeInfo( - ShapeDescriptor(dtype, order, shape, strides, rank, ews, empty)); - return buffer; - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - return nullptr; - } + try { + auto buf = reinterpret_cast(p); + int cnt = 0; + for (int i = 0; i < len; i++) cnt += buf[cnt]; + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } +} + +sd::ConstantShapeBuffer *shapeBuffer(int rank, Nd4jLong *shape, + Nd4jLong *strides, sd::DataType dtype, + char order, Nd4jLong ews, bool empty) { + try { + auto buffer = new ConstantShapeBuffer(); + *buffer = sd::ConstantShapeHelper::getInstance().bufferForShapeInfo( + ShapeDescriptor(dtype, order, shape, strides, rank, ews, empty)); + return buffer; + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + return nullptr; + } } -void deleteConstantShapeBuffer(sd::ConstantShapeBuffer* ptr) { - delete ptr; -} +void deleteConstantShapeBuffer(sd::ConstantShapeBuffer *ptr) { delete ptr; } void deleteConstantDataBuffer(sd::ConstantDataBuffer* ptr) { delete ptr; } -void deleteTadPack(sd::TadPack* ptr) { - delete ptr; -} +void deleteTadPack(sd::TadPack *ptr) { delete ptr; } -sd::ConstantDataBuffer* constantBufferLong(sd::DataType dtype, const Nd4jLong *data, int length) { - return nullptr; +sd::ConstantDataBuffer *constantBufferLong(sd::DataType dtype, + const Nd4jLong *data, int length) { + return nullptr; } -sd::ConstantDataBuffer* constantBufferDouble(sd::DataType dtype, double *data, int length) { - return nullptr; +sd::ConstantDataBuffer *constantBufferDouble(sd::DataType dtype, double *data, + int length) { + return nullptr; } -sd::ConstantDataBuffer* constantBuffer(sd::DataType dtype, sd::ConstantDescriptor *descriptor) { - try { - return sd::ConstantHelper::getInstance().constantBuffer(*descriptor, dtype); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - return nullptr; - } +sd::ConstantDataBuffer *constantBuffer(sd::DataType dtype, + sd::ConstantDescriptor *descriptor) { + try { + return sd::ConstantHelper::getInstance().constantBuffer(*descriptor, + dtype); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + return nullptr; + } } Nd4jPointer getConstantShapeBufferPrimary(sd::ConstantShapeBuffer* dbf) { @@ -2745,103 +2450,115 @@ Nd4jPointer getConstantShapeBufferSpecial(sd::ConstantShapeBuffer* dbf) { return const_cast(dbf->special()); } -Nd4jPointer getConstantDataBufferPrimary(sd::ConstantDataBuffer* dbf) { - return dbf->primary(); +Nd4jPointer getConstantDataBufferPrimary(sd::ConstantDataBuffer *dbf) { + return dbf->primary(); } -Nd4jPointer getConstantDataBufferSpecial(sd::ConstantDataBuffer* dbf) { - return dbf->special(); +Nd4jPointer getConstantDataBufferSpecial(sd::ConstantDataBuffer *dbf) { + return dbf->special(); } -Nd4jLong getConstantDataBufferLength(sd::ConstantDataBuffer* dbf) { - return dbf->length(); +Nd4jLong getConstantDataBufferLength(sd::ConstantDataBuffer *dbf) { + return dbf->length(); } -Nd4jLong getConstantDataBufferSizeOf(sd::ConstantDataBuffer* dbf) { - return dbf->sizeOf(); +Nd4jLong getConstantDataBufferSizeOf(sd::ConstantDataBuffer *dbf) { + return dbf->sizeOf(); } - -sd::graph::Context* createGraphContext(int nodeId) { - try { - return new sd::graph::Context(nodeId); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - return nullptr; - } -} -sd::graph::RandomGenerator* getGraphContextRandomGenerator(sd::graph::Context* ptr) { - return &ptr->randomGenerator(); +sd::graph::Context *createGraphContext(int nodeId) { + try { + return new sd::graph::Context(nodeId); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + return nullptr; + } } -void markGraphContextInplace(sd::graph::Context* ptr, bool reallyInplace) { - ptr->markInplace(reallyInplace); +sd::graph::RandomGenerator *getGraphContextRandomGenerator( + sd::graph::Context *ptr) { + return &ptr->randomGenerator(); } -void setGraphContextCudaContext(sd::graph::Context* ptr, void *stream, void *reductionPointer, void *allocationPointer) { +void markGraphContextInplace(sd::graph::Context *ptr, bool reallyInplace) { + ptr->markInplace(reallyInplace); } -void setGraphContextInputArray(sd::graph::Context* ptr, int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo) { - ptr->setInputArray(index, buffer, shapeInfo, specialBuffer, specialShapeInfo); +void setGraphContextCudaContext(sd::graph::Context *ptr, void *stream, + void *reductionPointer, + void *allocationPointer) {} +void setGraphContextInputArray(sd::graph::Context *ptr, int index, void *buffer, + void *shapeInfo, void *specialBuffer, + void *specialShapeInfo) { + ptr->setInputArray(index, buffer, shapeInfo, specialBuffer, specialShapeInfo); } -void setGraphContextOutputArray(sd::graph::Context* ptr, int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo) { - ptr->setOutputArray(index, buffer, shapeInfo, specialBuffer, specialShapeInfo); +void setGraphContextOutputArray(sd::graph::Context *ptr, int index, + void *buffer, void *shapeInfo, + void *specialBuffer, void *specialShapeInfo) { + ptr->setOutputArray(index, buffer, shapeInfo, specialBuffer, + specialShapeInfo); } -void setGraphContextInputBuffer(OpaqueContext* ptr, int index, OpaqueDataBuffer *buffer, void *shapeInfo, void *specialShapeInfo) { - ptr->setInputArray(index, buffer, shapeInfo, specialShapeInfo); +void setGraphContextInputBuffer(OpaqueContext *ptr, int index, + OpaqueDataBuffer *buffer, void *shapeInfo, + void *specialShapeInfo) { + ptr->setInputArray(index, buffer, shapeInfo, specialShapeInfo); } -void setGraphContextOutputBuffer(OpaqueContext* ptr, int index, OpaqueDataBuffer *buffer, void *shapeInfo, void *specialShapeInfo) { - ptr->setOutputArray(index, buffer, shapeInfo, specialShapeInfo); +void setGraphContextOutputBuffer(OpaqueContext *ptr, int index, + OpaqueDataBuffer *buffer, void *shapeInfo, + void *specialShapeInfo) { + ptr->setOutputArray(index, buffer, shapeInfo, specialShapeInfo); } -void setGraphContextTArguments(sd::graph::Context* ptr, double *arguments, int numberOfArguments) { - ptr->setTArguments(arguments, numberOfArguments); +void setGraphContextTArguments(sd::graph::Context *ptr, double *arguments, + int numberOfArguments) { + ptr->setTArguments(arguments, numberOfArguments); } -void setGraphContextIArguments(sd::graph::Context* ptr, Nd4jLong *arguments, int numberOfArguments) { - ptr->setIArguments(arguments, numberOfArguments); +void setGraphContextIArguments(sd::graph::Context *ptr, Nd4jLong *arguments, + int numberOfArguments) { + ptr->setIArguments(arguments, numberOfArguments); } -void setGraphContextBArguments(sd::graph::Context* ptr, bool *arguments, int numberOfArguments) { - ptr->setBArguments(arguments, numberOfArguments); +void setGraphContextBArguments(sd::graph::Context *ptr, bool *arguments, + int numberOfArguments) { + ptr->setBArguments(arguments, numberOfArguments); } -void setGraphContextDArguments(OpaqueContext* ptr, int *arguments, int numberOfArguments) { - std::vector dtypes(numberOfArguments); - for (int e = 0; e < numberOfArguments; e++) - dtypes[e] = (sd::DataType) arguments[e]; +void setGraphContextDArguments(OpaqueContext *ptr, int *arguments, + int numberOfArguments) { + std::vector dtypes(numberOfArguments); + for (int e = 0; e < numberOfArguments; e++) + dtypes[e] = (sd::DataType)arguments[e]; - ptr->setDArguments(dtypes); + ptr->setDArguments(dtypes); } -void deleteGraphContext(sd::graph::Context* ptr) { - delete ptr; -} +void deleteGraphContext(sd::graph::Context *ptr) { delete ptr; } -void ctxAllowHelpers(OpaqueContext* ptr, bool reallyAllow) { - ptr->allowHelpers(reallyAllow); +void ctxAllowHelpers(OpaqueContext *ptr, bool reallyAllow) { + ptr->allowHelpers(reallyAllow); } -void ctxSetExecutionMode(OpaqueContext* ptr, int execMode) { - if (execMode < 0 || execMode > 2) - execMode = 0; +void ctxSetExecutionMode(OpaqueContext *ptr, int execMode) { + if (execMode < 0 || execMode > 2) execMode = 0; - ptr->setExecutionMode((samediff::ExecutionMode) execMode); + ptr->setExecutionMode((samediff::ExecutionMode)execMode); } -void ctxPurge(OpaqueContext* ptr) { - ptr->clearFastPath(); -} +void ctxPurge(OpaqueContext *ptr) { ptr->clearFastPath(); } -sd::graph::RandomGenerator* createRandomGenerator(Nd4jLong rootSeed, Nd4jLong nodeSeed) { - return new sd::graph::RandomGenerator(rootSeed, nodeSeed); +sd::graph::RandomGenerator *createRandomGenerator(Nd4jLong rootSeed, + Nd4jLong nodeSeed) { + return new sd::graph::RandomGenerator(rootSeed, nodeSeed); } -Nd4jLong getRandomGeneratorRootState(sd::graph::RandomGenerator* ptr) { - return ptr->rootState(); +Nd4jLong getRandomGeneratorRootState(sd::graph::RandomGenerator *ptr) { + return ptr->rootState(); } -Nd4jLong getRandomGeneratorNodeState(sd::graph::RandomGenerator* ptr) { - return ptr->nodeState(); +Nd4jLong getRandomGeneratorNodeState(sd::graph::RandomGenerator *ptr) { + return ptr->nodeState(); } -void setRandomGeneratorStates(sd::graph::RandomGenerator* ptr, Nd4jLong rootSeed, Nd4jLong nodeSeed) { - ptr->setStates(rootSeed, nodeSeed); +void setRandomGeneratorStates(sd::graph::RandomGenerator *ptr, + Nd4jLong rootSeed, Nd4jLong nodeSeed) { + ptr->setStates(rootSeed, nodeSeed); } float getRandomGeneratorRelativeFloat(sd::graph::RandomGenerator* ptr, Nd4jLong index) { @@ -2852,398 +2569,415 @@ double getRandomGeneratorRelativeDouble(sd::graph::RandomGenerator* ptr, Nd4jLon return ptr->relativeT(index); } -int getRandomGeneratorRelativeInt(sd::graph::RandomGenerator* ptr, Nd4jLong index) { - return ptr->relativeInt(index); -} - -Nd4jLong getRandomGeneratorRelativeLong(sd::graph::RandomGenerator* ptr, Nd4jLong index) { - return ptr->relativeLong(index); +int getRandomGeneratorRelativeInt(sd::graph::RandomGenerator *ptr, + Nd4jLong index) { + return ptr->relativeInt(index); } -void deleteRandomGenerator(sd::graph::RandomGenerator* ptr) { - delete ptr; +Nd4jLong getRandomGeneratorRelativeLong(sd::graph::RandomGenerator *ptr, + Nd4jLong index) { + return ptr->relativeLong(index); } +void deleteRandomGenerator(sd::graph::RandomGenerator *ptr) { delete ptr; } int dataTypeFromNpyHeader(void *header) { - return (int) cnpy::dataTypeFromHeader(reinterpret_cast(header)); + return (int)cnpy::dataTypeFromHeader(reinterpret_cast(header)); } Nd4jPointer shapeBufferForNumpy(Nd4jPointer npyArray) { - try { - cnpy::NpyArray arr = cnpy::loadNpyFromPointer(reinterpret_cast(npyArray)); - unsigned int shapeSize = arr.shape.size(); - std::vector shape(shapeSize); - bool _empty = false; - for (unsigned int i = 0; i < shapeSize; i++) { - shape[i] = arr.shape[i]; - - if (arr.shape[i] == 0) - _empty = true; - } - - auto dtype = cnpy::dataTypeFromHeader(reinterpret_cast(npyArray)); - - Nd4jLong *shapeBuffer; - if (shape.size() == 1 && shape[0] == 0) { - // scalar case - shapeBuffer = sd::ShapeBuilders::createScalarShapeInfo(dtype); - } else if (_empty) { - if (shapeSize > 0) - shapeBuffer = sd::ShapeBuilders::emptyShapeInfo(dtype, arr.fortranOrder ? 'f' : 'c', shape); - else - shapeBuffer = sd::ShapeBuilders::emptyShapeInfo(dtype); - } else { - shapeBuffer = sd::ShapeBuilders::createShapeInfo(dtype, arr.fortranOrder ? 'f' : 'c', shape); - } - return const_cast(sd::ConstantShapeHelper::getInstance().createFromExisting(shapeBuffer, true)); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - return nullptr; - } -} - -void sortByKey(Nd4jPointer *extraPointers, - void *x, const Nd4jLong *xShapeInfo, - void *dx, const Nd4jLong *dxShapeInfo, - void *y, const Nd4jLong *yShapeInfo, - void *dy, const Nd4jLong *dyShapeInfo, - bool descending) { - try { - auto xType = ArrayOptions::dataType(xShapeInfo); - auto yType = ArrayOptions::dataType(yShapeInfo); - - BUILD_DOUBLE_SELECTOR(xType, yType, sd::DoubleMethods, ::sortByKey(x, xShapeInfo, y, yShapeInfo, descending), LIBND4J_TYPES, LIBND4J_TYPES); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } -} - -void sortByValue(Nd4jPointer *extraPointers, - void *x, const Nd4jLong *xShapeInfo, - void *dx, const Nd4jLong *dxShapeInfo, - void *y, const Nd4jLong *yShapeInfo, - void *dy, const Nd4jLong *dyShapeInfo, - bool descending) { - try { - auto xType = ArrayOptions::dataType(xShapeInfo); - auto yType = ArrayOptions::dataType(yShapeInfo); - - BUILD_DOUBLE_SELECTOR(xType, yType, sd::DoubleMethods, ::sortByValue(x, xShapeInfo, y, yShapeInfo, descending), LIBND4J_TYPES, LIBND4J_TYPES); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + try { + cnpy::NpyArray arr = + cnpy::loadNpyFromPointer(reinterpret_cast(npyArray)); + unsigned int shapeSize = arr.shape.size(); + std::vector shape(shapeSize); + bool _empty = false; + for (unsigned int i = 0; i < shapeSize; i++) { + shape[i] = arr.shape[i]; + + if (arr.shape[i] == 0) _empty = true; + } + + auto dtype = cnpy::dataTypeFromHeader(reinterpret_cast(npyArray)); + + Nd4jLong *shapeBuffer; + if (shape.size() == 1 && shape[0] == 0) { + // scalar case + shapeBuffer = sd::ShapeBuilders::createScalarShapeInfo(dtype); + } else if (_empty) { + if (shapeSize > 0) + shapeBuffer = sd::ShapeBuilders::emptyShapeInfo( + dtype, arr.fortranOrder ? 'f' : 'c', shape); + else + shapeBuffer = sd::ShapeBuilders::emptyShapeInfo(dtype); + } else { + shapeBuffer = sd::ShapeBuilders::createShapeInfo( + dtype, arr.fortranOrder ? 'f' : 'c', shape); } + return const_cast( + sd::ConstantShapeHelper::getInstance().createFromExisting(shapeBuffer, + true)); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + return nullptr; + } } -void sortTadByKey(Nd4jPointer *extraPointers, - void *x, const Nd4jLong *xShapeInfo, - void *dx, const Nd4jLong *dxShapeInfo, - void *y, const Nd4jLong *yShapeInfo, - void *dy, const Nd4jLong *dyShapeInfo, - int *dimension, int dimensionLength, - bool descending) { - try { - auto xType = ArrayOptions::dataType(xShapeInfo); - auto yType = ArrayOptions::dataType(yShapeInfo); - - BUILD_DOUBLE_SELECTOR(xType, yType, sd::DoubleMethods, ::sortTadByKey(x, xShapeInfo, y, yShapeInfo, dimension, dimensionLength, descending), LIBND4J_TYPES, LIBND4J_TYPES); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } -} +void sortByKey(Nd4jPointer *extraPointers, void *x, const Nd4jLong *xShapeInfo, + void *dx, const Nd4jLong *dxShapeInfo, void *y, + const Nd4jLong *yShapeInfo, void *dy, + const Nd4jLong *dyShapeInfo, bool descending) { + try { + auto xType = ArrayOptions::dataType(xShapeInfo); + auto yType = ArrayOptions::dataType(yShapeInfo); -void sortTadByValue(Nd4jPointer *extraPointers, - void *x, const Nd4jLong *xShapeInfo, - void *dx, const Nd4jLong *dxShapeInfo, - void *y, const Nd4jLong *yShapeInfo, - void *dy, const Nd4jLong *dyShapeInfo, - int *dimension, int dimensionLength, - bool descending) { - try { - auto xType = ArrayOptions::dataType(xShapeInfo); - auto yType = ArrayOptions::dataType(yShapeInfo); - - BUILD_DOUBLE_SELECTOR(xType, yType, sd::DoubleMethods, ::sortTadByValue(x, xShapeInfo, y, yShapeInfo, dimension, dimensionLength, descending), LIBND4J_TYPES, LIBND4J_TYPES); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } + BUILD_DOUBLE_SELECTOR(xType, yType, sd::DoubleMethods, + ::sortByKey(x, xShapeInfo, y, yShapeInfo, descending), + LIBND4J_TYPES, LIBND4J_TYPES); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } +} + +void sortByValue(Nd4jPointer *extraPointers, void *x, + const Nd4jLong *xShapeInfo, void *dx, + const Nd4jLong *dxShapeInfo, void *y, + const Nd4jLong *yShapeInfo, void *dy, + const Nd4jLong *dyShapeInfo, bool descending) { + try { + auto xType = ArrayOptions::dataType(xShapeInfo); + auto yType = ArrayOptions::dataType(yShapeInfo); + + BUILD_DOUBLE_SELECTOR( + xType, yType, sd::DoubleMethods, + ::sortByValue(x, xShapeInfo, y, yShapeInfo, descending), LIBND4J_TYPES, + LIBND4J_TYPES); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } +} + +void sortTadByKey(Nd4jPointer *extraPointers, void *x, + const Nd4jLong *xShapeInfo, void *dx, + const Nd4jLong *dxShapeInfo, void *y, + const Nd4jLong *yShapeInfo, void *dy, + const Nd4jLong *dyShapeInfo, int *dimension, + int dimensionLength, bool descending) { + try { + auto xType = ArrayOptions::dataType(xShapeInfo); + auto yType = ArrayOptions::dataType(yShapeInfo); + + BUILD_DOUBLE_SELECTOR( + xType, yType, sd::DoubleMethods, + ::sortTadByKey(x, xShapeInfo, y, yShapeInfo, dimension, dimensionLength, + descending), + LIBND4J_TYPES, LIBND4J_TYPES); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } +} + +void sortTadByValue(Nd4jPointer *extraPointers, void *x, + const Nd4jLong *xShapeInfo, void *dx, + const Nd4jLong *dxShapeInfo, void *y, + const Nd4jLong *yShapeInfo, void *dy, + const Nd4jLong *dyShapeInfo, int *dimension, + int dimensionLength, bool descending) { + try { + auto xType = ArrayOptions::dataType(xShapeInfo); + auto yType = ArrayOptions::dataType(yShapeInfo); + + BUILD_DOUBLE_SELECTOR( + xType, yType, sd::DoubleMethods, + ::sortTadByValue(x, xShapeInfo, y, yShapeInfo, dimension, + dimensionLength, descending), + LIBND4J_TYPES, LIBND4J_TYPES); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } -const char* runLightBenchmarkSuit(bool printOut) { - try { - sd::LightBenchmarkSuit suit; - auto result = suit.runSuit(); +const char *runLightBenchmarkSuit(bool printOut) { + try { + sd::LightBenchmarkSuit suit; + auto result = suit.runSuit(); - if (printOut) - nd4j_printf("%s\n", result.data()); + if (printOut) nd4j_printf("%s\n", result.data()); - auto chars = new char[result.length() + 1]; - std::memcpy(chars, result.data(), result.length()); - chars[result.length()] = (char) 0x0; + auto chars = new char[result.length() + 1]; + std::memcpy(chars, result.data(), result.length()); + chars[result.length()] = (char)0x0; - return chars; - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - return nullptr; - } + return chars; + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + return nullptr; + } } Nd4jLong getCachedMemory(int deviceId) { - return sd::ConstantHelper::getInstance().getCachedAmount(deviceId); + return sd::ConstantHelper::getInstance().getCachedAmount(deviceId); } -const char* runFullBenchmarkSuit(bool printOut) { - try { - sd::FullBenchmarkSuit suit; - auto result = suit.runSuit(); +const char *runFullBenchmarkSuit(bool printOut) { + try { + sd::FullBenchmarkSuit suit; + auto result = suit.runSuit(); - if (printOut) - nd4j_printf("%s\n", result.data()); + if (printOut) nd4j_printf("%s\n", result.data()); - auto chars = new char[result.length() + 1]; - std::memcpy(chars, result.data(), result.length()); - chars[result.length()] = (char) 0x0; + auto chars = new char[result.length() + 1]; + std::memcpy(chars, result.data(), result.length()); + chars[result.length()] = (char)0x0; - return chars; - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - return nullptr; - } + return chars; + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + return nullptr; + } } -sd::LaunchContext* defaultLaunchContext() { - return LaunchContext::defaultContext(); +sd::LaunchContext *defaultLaunchContext() { + return LaunchContext::defaultContext(); } -Nd4jPointer lcScalarPointer(OpaqueLaunchContext* lc) { - return nullptr; -} +Nd4jPointer lcScalarPointer(OpaqueLaunchContext *lc) { return nullptr; } -Nd4jPointer lcReductionPointer(OpaqueLaunchContext* lc) { - return nullptr; -} +Nd4jPointer lcReductionPointer(OpaqueLaunchContext *lc) { return nullptr; } -Nd4jPointer lcAllocationPointer(OpaqueLaunchContext* lc) { - return nullptr; -} +Nd4jPointer lcAllocationPointer(OpaqueLaunchContext *lc) { return nullptr; } -Nd4jPointer lcExecutionStream(OpaqueLaunchContext* lc) { - return nullptr; -} +Nd4jPointer lcExecutionStream(OpaqueLaunchContext *lc) { return nullptr; } -Nd4jPointer lcCopyStream(OpaqueLaunchContext* lc) { - return nullptr; -} +Nd4jPointer lcCopyStream(OpaqueLaunchContext *lc) { return nullptr; } -Nd4jPointer lcBlasHandle(OpaqueLaunchContext* lc) { - return nullptr; -} +Nd4jPointer lcBlasHandle(OpaqueLaunchContext *lc) { return nullptr; } -Nd4jPointer lcSolverHandle(OpaqueLaunchContext* lc) { - return nullptr; -} +Nd4jPointer lcSolverHandle(OpaqueLaunchContext *lc) { return nullptr; } int lastErrorCode() { - return sd::LaunchContext::defaultContext()->errorReference()->errorCode(); + return sd::LaunchContext::defaultContext()->errorReference()->errorCode(); } -const char* lastErrorMessage() { - return sd::LaunchContext::defaultContext()->errorReference()->errorMessage(); +const char *lastErrorMessage() { + return sd::LaunchContext::defaultContext()->errorReference()->errorMessage(); } -void ctxShapeFunctionOverride(OpaqueContext* ptr, bool reallyOverride) { - ptr->setShapeFunctionOverride(reallyOverride); +void ctxShapeFunctionOverride(OpaqueContext *ptr, bool reallyOverride) { + ptr->setShapeFunctionOverride(reallyOverride); } -int binaryLevel() { +int binaryLevel() { #ifdef CPU_FEATURES #if defined(F_X64) - return 1; -#elif defined (F_AVX2) - return 2; -#elif defined (F_AVX512) - return 3; + return 1; +#elif defined(F_AVX2) + return 2; +#elif defined(F_AVX512) + return 3; #else - return 0; + return 0; #endif #else - return 0; + return 0; #endif } int optimalLevel() { #ifdef CPU_FEATURES - auto features = cpu_features::GetX86Info().features; + auto features = cpu_features::GetX86Info().features; - if (features.avx && features.avx2 && features.avx512f && features.avx512vl && features.avx512bw && features.avx512dq && features.avx512cd) - return 3; - else if (features.avx && features.avx2) - return 2; - else - return 1; + if (features.avx && features.avx2 && features.avx512f && features.avx512vl && + features.avx512bw && features.avx512dq && features.avx512cd) + return 3; + else if (features.avx && features.avx2) + return 2; + else + return 1; #else - return 0; + return 0; #endif } bool isMinimalRequirementsMet() { #ifdef CPU_FEATURES - auto features = cpu_features::GetX86Info().features; + auto features = cpu_features::GetX86Info().features; #if defined(F_X64) - return true; -#elif defined (F_AVX2) - return features.avx && features.avx2; -#elif defined (F_AVX512) - // we're optimizing for skylake-avx512 features, so we'll check those out - return features.avx && features.avx2 && features.avx512f && features.avx512vl && features.avx512bw && features.avx512dq && features.avx512cd; + return true; +#elif defined(F_AVX2) + return features.avx && features.avx2; +#elif defined(F_AVX512) + // we're optimizing for skylake-avx512 features, so we'll check those out + return features.avx && features.avx2 && features.avx512f && + features.avx512vl && features.avx512bw && features.avx512dq && + features.avx512cd; #else - return true; + return true; #endif #else - return true; + return true; #endif } bool isOptimalRequirementsMet() { #ifdef CPU_FEATURES - auto b = ::binaryLevel(); - auto o = ::optimalLevel(); + auto b = ::binaryLevel(); + auto o = ::optimalLevel(); - if (b == o) - return true; - else - return false; -#else + if (b == o) return true; + else + return false; +#else + return true; #endif } -OpaqueDataBuffer* dbAllocateDataBuffer(Nd4jLong elements, int dataType, bool allocateBoth) { - return allocateDataBuffer(elements, dataType, allocateBoth); +OpaqueDataBuffer *dbAllocateDataBuffer(Nd4jLong elements, int dataType, + bool allocateBoth) { + return allocateDataBuffer(elements, dataType, allocateBoth); } -OpaqueDataBuffer* allocateDataBuffer(Nd4jLong elements, int dataType, bool allocateBoth) { - try { - auto dtype = DataTypeUtils::fromInt(dataType); - return new sd::InteropDataBuffer(elements * DataTypeUtils::sizeOf(dtype) , dtype, allocateBoth); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - return nullptr; - } +OpaqueDataBuffer *allocateDataBuffer(Nd4jLong elements, int dataType, + bool allocateBoth) { + try { + auto dtype = DataTypeUtils::fromInt(dataType); + return new sd::InteropDataBuffer(elements * DataTypeUtils::sizeOf(dtype), + dtype, allocateBoth); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + return nullptr; + } } Nd4jPointer dbPrimaryBuffer(OpaqueDataBuffer *dataBuffer) { - return dataBuffer->primary(); + return dataBuffer->primary(); } Nd4jPointer dbSpecialBuffer(OpaqueDataBuffer *dataBuffer) { - return dataBuffer->special(); + return dataBuffer->special(); } -void deleteDataBuffer(OpaqueDataBuffer *dataBuffer) { - delete dataBuffer; -} +void deleteDataBuffer(OpaqueDataBuffer *dataBuffer) { delete dataBuffer; } -OpaqueDataBuffer* dbCreateExternalDataBuffer(Nd4jLong elements, int dataType, Nd4jPointer primary, Nd4jPointer special) { - auto buffer = dbAllocateDataBuffer(0, dataType, false); +OpaqueDataBuffer *dbCreateExternalDataBuffer(Nd4jLong elements, int dataType, + Nd4jPointer primary, + Nd4jPointer special) { + auto buffer = dbAllocateDataBuffer(0, dataType, false); - if (primary != nullptr) - buffer->setPrimary(primary, elements); + if (primary != nullptr) buffer->setPrimary(primary, elements); - if (special != nullptr) - buffer->setSpecial(special, elements); + if (special != nullptr) buffer->setSpecial(special, elements); - return buffer; + return buffer; } -void dbSetPrimaryBuffer(OpaqueDataBuffer *dataBuffer, Nd4jPointer primaryBuffer, Nd4jLong numBytes) { - dataBuffer->setPrimary(primaryBuffer, numBytes); +void dbSetPrimaryBuffer(OpaqueDataBuffer *dataBuffer, Nd4jPointer primaryBuffer, + Nd4jLong numBytes) { + dataBuffer->setPrimary(primaryBuffer, numBytes); } -void dbSetSpecialBuffer(OpaqueDataBuffer *dataBuffer, Nd4jPointer specialBuffer, Nd4jLong numBytes) { - dataBuffer->setSpecial(specialBuffer, numBytes); +void dbSetSpecialBuffer(OpaqueDataBuffer *dataBuffer, Nd4jPointer specialBuffer, + Nd4jLong numBytes) { + dataBuffer->setSpecial(specialBuffer, numBytes); } void dbAllocatePrimaryBuffer(OpaqueDataBuffer *dataBuffer) { - dataBuffer->dataBuffer()->allocatePrimary(); + dataBuffer->dataBuffer()->allocatePrimary(); } void dbAllocateSpecialBuffer(OpaqueDataBuffer *dataBuffer) { - dataBuffer->dataBuffer()->allocateSpecial(); + dataBuffer->dataBuffer()->allocateSpecial(); } void dbExpandBuffer(OpaqueDataBuffer *dataBuffer, Nd4jLong elements) { - try { - dataBuffer->dataBuffer()->expand(elements * DataTypeUtils::sizeOf(dataBuffer->dataBuffer()->getDataType())); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } + try { + dataBuffer->dataBuffer()->expand( + elements * + DataTypeUtils::sizeOf(dataBuffer->dataBuffer()->getDataType())); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } -OpaqueDataBuffer* dbCreateView(OpaqueDataBuffer *dataBuffer, Nd4jLong length, Nd4jLong offset) { - return new InteropDataBuffer(*dataBuffer, length, offset); +OpaqueDataBuffer *dbCreateView(OpaqueDataBuffer *dataBuffer, Nd4jLong length, + Nd4jLong offset) { + return new InteropDataBuffer(*dataBuffer, length, offset); } void dbSyncToSpecial(OpaqueDataBuffer *dataBuffer) { - dataBuffer->dataBuffer()->syncToSpecial(); + dataBuffer->dataBuffer()->syncToSpecial(); } void dbSyncToPrimary(OpaqueDataBuffer *dataBuffer) { - dataBuffer->dataBuffer()->syncToPrimary(nullptr); + dataBuffer->dataBuffer()->syncToPrimary(nullptr); } void dbTickHostRead(OpaqueDataBuffer *dataBuffer) { - dataBuffer->dataBuffer()->readPrimary(); + dataBuffer->dataBuffer()->readPrimary(); } void dbTickHostWrite(OpaqueDataBuffer *dataBuffer) { - dataBuffer->dataBuffer()->writePrimary(); + dataBuffer->dataBuffer()->writePrimary(); } void dbTickDeviceRead(OpaqueDataBuffer *dataBuffer) { - dataBuffer->dataBuffer()->readSpecial(); + dataBuffer->dataBuffer()->readSpecial(); } void dbTickDeviceWrite(OpaqueDataBuffer *dataBuffer) { - dataBuffer->dataBuffer()->writeSpecial(); + dataBuffer->dataBuffer()->writeSpecial(); } void dbExpand(OpaqueDataBuffer *dataBuffer, Nd4jLong elements) { - dataBuffer->expand(elements); + dataBuffer->expand(elements); } -int dbLocality(OpaqueDataBuffer *dataBuffer) { - return 0; -} +int dbLocality(OpaqueDataBuffer *dataBuffer) { return 0; } void dbSetDeviceId(OpaqueDataBuffer *dataBuffer, int deviceId) { - dataBuffer->setDeviceId(deviceId); + dataBuffer->setDeviceId(deviceId); } -int dbDeviceId(OpaqueDataBuffer *dataBuffer) { - return dataBuffer->deviceId(); -} +int dbDeviceId(OpaqueDataBuffer *dataBuffer) { return dataBuffer->deviceId(); } void dbClose(OpaqueDataBuffer *dataBuffer) { - dataBuffer->getDataBuffer()->close(); -} - -BUILD_SINGLE_TEMPLATE(template void pullRowsGeneric, (void *, Nd4jLong const*, void*, Nd4jLong const*, const int, Nd4jLong const*, Nd4jLong const*, Nd4jLong const*, Nd4jLong const*, Nd4jLong const*), LIBND4J_TYPES); -BUILD_SINGLE_TEMPLATE(template void tearGeneric, (void *, Nd4jLong const* , Nd4jPointer*, Nd4jLong const*, Nd4jLong const*, Nd4jLong const*), LIBND4J_TYPES); -BUILD_SINGLE_TEMPLATE(template void shuffleGeneric, (void**, Nd4jLong* const*, void**, Nd4jLong* const*, int, int*, Nd4jLong* const*, Nd4jLong* const*), LIBND4J_TYPES); - - + dataBuffer->getDataBuffer()->close(); +} + +BUILD_SINGLE_TEMPLATE(template void pullRowsGeneric, + (void *, Nd4jLong const *, void *, Nd4jLong const *, + const int, Nd4jLong const *, Nd4jLong const *, + Nd4jLong const *, Nd4jLong const *, Nd4jLong const *), + LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template void tearGeneric, + (void *, Nd4jLong const *, Nd4jPointer *, + Nd4jLong const *, Nd4jLong const *, Nd4jLong const *), + LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template void shuffleGeneric, + (void **, Nd4jLong *const *, void **, Nd4jLong *const *, + int, int *, Nd4jLong *const *, Nd4jLong *const *), + LIBND4J_TYPES); diff --git a/libnd4j/include/legacy/cuda/BlasVersionHelper.cu b/libnd4j/include/legacy/cuda/BlasVersionHelper.cu index 04b0e78f1990..d156ac1b9ecd 100644 --- a/libnd4j/include/legacy/cuda/BlasVersionHelper.cu +++ b/libnd4j/include/legacy/cuda/BlasVersionHelper.cu @@ -21,9 +21,9 @@ #include namespace sd { - BlasVersionHelper::BlasVersionHelper() { - _blasMajorVersion = __CUDACC_VER_MAJOR__; - _blasMinorVersion = __CUDACC_VER_MINOR__; - _blasPatchVersion = __CUDACC_VER_BUILD__; - } -} \ No newline at end of file +BlasVersionHelper::BlasVersionHelper() { + _blasMajorVersion = __CUDACC_VER_MAJOR__; + _blasMinorVersion = __CUDACC_VER_MINOR__; + _blasPatchVersion = __CUDACC_VER_BUILD__; +} +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/legacy/cuda/NativeOpExecutioner.cu b/libnd4j/include/legacy/cuda/NativeOpExecutioner.cu index 14cbf306ac26..adf96771f082 100644 --- a/libnd4j/include/legacy/cuda/NativeOpExecutioner.cu +++ b/libnd4j/include/legacy/cuda/NativeOpExecutioner.cu @@ -14,440 +14,482 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -#include -#include -#include -#include +#include #include -#include +#include +#include #include +#include +#include #include -#include +#include #include - -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include +#include +#include +#include +#include +#include #include -#include #include #include -#include -#include -#include -#include -#include -#include -#include +#include #include -#include +#include +#include +#include +#include +#include #include #include #include +#include +#include +#include +#include +#include +#include +#include +#include using namespace sd; /** -* This is utility kernel, that updates given special buffer with proper values in device memory -*/ -extern "C" __global__ void prepareShapeBuffer(int *dimension, int *maxDimension, Nd4jLong *specialPointer, int rows, sd::DataType dataType) { - Nd4jLong tid = blockIdx.x * blockDim.x + threadIdx.x; - if (tid > 0) - return; - - dimension[0] = 0; - maxDimension[0] = 1; - - specialPointer[0] = 2; - specialPointer[1] = rows; - specialPointer[2] = 1; - specialPointer[3] = 1; - specialPointer[4] = 1; - specialPointer[5] = 0; - specialPointer[6] = 1; - specialPointer[7] = 99; - - ArrayOptions::setDataType(specialPointer, dataType); - - //printf("special[0]: [%lld]\n", (long long) specialPointer[0]); - //shape::printShapeInfoLinear("prepareShapeBuffer", specialPointer); + * This is utility kernel, that updates given special buffer with proper values + * in device memory + */ +extern "C" __global__ void prepareShapeBuffer(int* dimension, int* maxDimension, + Nd4jLong* specialPointer, + int rows, sd::DataType dataType) { + Nd4jLong tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid > 0) return; + + dimension[0] = 0; + maxDimension[0] = 1; + + specialPointer[0] = 2; + specialPointer[1] = rows; + specialPointer[2] = 1; + specialPointer[3] = 1; + specialPointer[4] = 1; + specialPointer[5] = 0; + specialPointer[6] = 1; + specialPointer[7] = 99; + + ArrayOptions::setDataType(specialPointer, dataType); + + // printf("special[0]: [%lld]\n", (long long) specialPointer[0]); + // shape::printShapeInfoLinear("prepareShapeBuffer", specialPointer); } - //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execPairwiseTransform(sd::LaunchContext *lc, - int opNum, - void const* hX, Nd4jLong const* hXShapeInfo, - void const* dX, Nd4jLong const* dXShapeInfo, - void const* hY, Nd4jLong const* hYShapeInfo, - void const* dY, Nd4jLong const* dYShapeInfo, - void *hZ, Nd4jLong const* hZShapeInfo, - void *dZ, Nd4jLong const* dZShapeInfo, - void *extraParams) { - - auto stream = lc->getCudaStream(); - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto yType = sd::ArrayOptions::dataType(hYShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) - return; - - if (xType != zType && yType != zType) - throw std::runtime_error("NativeOpExecutioner::execPairwiseTransform requires Z operand to have either X or Y type"); - if (lc == nullptr) - throw std::runtime_error("NativeOpExecutioner::execPairwiseTransform: launch context cannot be nullptr !"); - if (stream == nullptr) - throw std::runtime_error("NativeOpExecutioner::execPairwiseTransform: CUDA stream cannot be nullptr !"); - - dim3 launchDims(256, 1024, 8192); +void NativeOpExecutioner::execPairwiseTransform( + sd::LaunchContext* lc, int opNum, void const* hX, + Nd4jLong const* hXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo, + void const* hY, Nd4jLong const* hYShapeInfo, void const* dY, + Nd4jLong const* dYShapeInfo, void* hZ, Nd4jLong const* hZShapeInfo, + void* dZ, Nd4jLong const* dZShapeInfo, void* extraParams) { + auto stream = lc->getCudaStream(); + + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto yType = sd::ArrayOptions::dataType(hYShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) return; + + if (xType != zType && yType != zType) + throw std::runtime_error( + "NativeOpExecutioner::execPairwiseTransform requires Z operand to have " + "either X or Y type"); + if (lc == nullptr) + throw std::runtime_error( + "NativeOpExecutioner::execPairwiseTransform: launch context cannot be " + "nullptr !"); + if (stream == nullptr) + throw std::runtime_error( + "NativeOpExecutioner::execPairwiseTransform: CUDA stream cannot be " + "nullptr !"); + + dim3 launchDims(256, 1024, 8192); #ifdef __ND4J_EXPERIMENTAL__ - BUILD_PAIRWISE_SELECTOR(xType, yType, zType, functions::pairwise_transforms::PairWiseTransform, ::executeCudaShaped(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo, extraParams), LIBND4J_TYPES, LIBND4J_TYPES) + BUILD_PAIRWISE_SELECTOR( + xType, yType, zType, functions::pairwise_transforms::PairWiseTransform, + ::executeCudaShaped(launchDims, stream, opNum, dX, dXShapeInfo, dY, + dYShapeInfo, dZ, dZShapeInfo, extraParams), + LIBND4J_TYPES, LIBND4J_TYPES) #else - BUILD_SINGLE_SELECTOR_THRICE(xType, functions::pairwise_transforms::PairWiseTransform, ::executeCudaShaped(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo, extraParams), LIBND4J_TYPES) + BUILD_SINGLE_SELECTOR_THRICE( + xType, functions::pairwise_transforms::PairWiseTransform, + ::executeCudaShaped(launchDims, stream, opNum, dX, dXShapeInfo, dY, + dYShapeInfo, dZ, dZShapeInfo, extraParams), + LIBND4J_TYPES) #endif - // TODO: remove after the release - auto res = cudaStreamSynchronize(*stream); - if (res != 0) - throw cuda_exception::build("execPairwiseTransform failed", res); + // TODO: remove after the release + auto res = cudaStreamSynchronize(*stream); + if (res != 0) + throw cuda_exception::build("execPairwiseTransform failed", res); } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execPairwiseBoolTransform( sd::LaunchContext *lc, - int opNum, - void const* hX, Nd4jLong const* hXShapeInfo, - void const* dX, Nd4jLong const* dXShapeInfo, - void const* hY, Nd4jLong const* hYShapeInfo, - void const* dY, Nd4jLong const* dYShapeInfo, - void *hZ, Nd4jLong const* hZShapeInfo, - void *dZ, Nd4jLong const* dZShapeInfo, - void *extraParams) { - - auto stream = lc->getCudaStream(); - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto yType = sd::ArrayOptions::dataType(hYShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) - return; - - if (!DataTypeUtils::isB(zType)) - throw sd::datatype_exception::build("NativeOpExecutioner::execPairwiseBoolTransform wrong Z operand data type", sd::DataType::BOOL, zType); - - if (yType != xType) - throw sd::datatype_exception::build("NativeOpExecutioner::execPairwiseBoolTransform both operands must have same data type", xType, yType); - - dim3 launchDims(256, 1024, 16384); - - BUILD_DOUBLE_SELECTOR(xType, zType, functions::pairwise_transforms::PairWiseBoolTransform, ::executeCudaShaped(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo, extraParams), LIBND4J_TYPES, BOOL_TYPES) - - // TODO: remove after the release - auto res = cudaStreamSynchronize(*stream); - if (res != 0) - throw cuda_exception::build("execPairwiseBoolTransform failed", res); +void NativeOpExecutioner::execPairwiseBoolTransform( + sd::LaunchContext* lc, int opNum, void const* hX, + Nd4jLong const* hXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo, + void const* hY, Nd4jLong const* hYShapeInfo, void const* dY, + Nd4jLong const* dYShapeInfo, void* hZ, Nd4jLong const* hZShapeInfo, + void* dZ, Nd4jLong const* dZShapeInfo, void* extraParams) { + auto stream = lc->getCudaStream(); + + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto yType = sd::ArrayOptions::dataType(hYShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) return; + + if (!DataTypeUtils::isB(zType)) + throw sd::datatype_exception::build( + "NativeOpExecutioner::execPairwiseBoolTransform wrong Z operand data " + "type", + sd::DataType::BOOL, zType); + + if (yType != xType) + throw sd::datatype_exception::build( + "NativeOpExecutioner::execPairwiseBoolTransform both operands must " + "have same data type", + xType, yType); + + dim3 launchDims(256, 1024, 16384); + + BUILD_DOUBLE_SELECTOR( + xType, zType, functions::pairwise_transforms::PairWiseBoolTransform, + ::executeCudaShaped(launchDims, stream, opNum, dX, dXShapeInfo, dY, + dYShapeInfo, dZ, dZShapeInfo, extraParams), + LIBND4J_TYPES, BOOL_TYPES) + + // TODO: remove after the release + auto res = cudaStreamSynchronize(*stream); + if (res != 0) + throw cuda_exception::build("execPairwiseBoolTransform failed", res); } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execPairwiseIntTransform( sd::LaunchContext *lc, - int opNum, - void const* hX, Nd4jLong const* hXShapeInfo, - void const* dX, Nd4jLong const* dXShapeInfo, - void const* hY, Nd4jLong const* hYShapeInfo, - void const* dY, Nd4jLong const* dYShapeInfo, - void * hZ, Nd4jLong const* hZShapeInfo, - void * dZ, Nd4jLong const* dZShapeInfo, - void *extraParams) { - - auto stream = lc->getCudaStream(); - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto yType = sd::ArrayOptions::dataType(hYShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) - return; - - if (!DataTypeUtils::isZ(zType)) - throw sd::datatype_exception::build("NativeOpExecutioner::execPairwiseIntTransform wrong Z operand data type", sd::DataType::BOOL, zType); - - if (yType != xType || zType != xType) - throw sd::datatype_exception::build("NativeOpExecutioner::execPairwiseIntTransform both operands must have same data type", xType, yType); - - dim3 launchDims(256, 1024, 16384); - - BUILD_SINGLE_SELECTOR(xType, functions::pairwise_transforms::PairWiseIntTransform, ::executeCudaShaped(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo, extraParams), INTEGER_TYPES) - - // TODO: remove after the release - auto res = cudaStreamSynchronize(*stream); - if (res != 0) - throw cuda_exception::build("execPairwiseIntTransform failed", res); +void NativeOpExecutioner::execPairwiseIntTransform( + sd::LaunchContext* lc, int opNum, void const* hX, + Nd4jLong const* hXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo, + void const* hY, Nd4jLong const* hYShapeInfo, void const* dY, + Nd4jLong const* dYShapeInfo, void* hZ, Nd4jLong const* hZShapeInfo, + void* dZ, Nd4jLong const* dZShapeInfo, void* extraParams) { + auto stream = lc->getCudaStream(); + + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto yType = sd::ArrayOptions::dataType(hYShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) return; + + if (!DataTypeUtils::isZ(zType)) + throw sd::datatype_exception::build( + "NativeOpExecutioner::execPairwiseIntTransform wrong Z operand data " + "type", + sd::DataType::BOOL, zType); + + if (yType != xType || zType != xType) + throw sd::datatype_exception::build( + "NativeOpExecutioner::execPairwiseIntTransform both operands must have " + "same data type", + xType, yType); + + dim3 launchDims(256, 1024, 16384); + + BUILD_SINGLE_SELECTOR( + xType, functions::pairwise_transforms::PairWiseIntTransform, + ::executeCudaShaped(launchDims, stream, opNum, dX, dXShapeInfo, dY, + dYShapeInfo, dZ, dZShapeInfo, extraParams), + INTEGER_TYPES) + + // TODO: remove after the release + auto res = cudaStreamSynchronize(*stream); + if (res != 0) + throw cuda_exception::build("execPairwiseIntTransform failed", res); } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execSummaryStatsScalar(sd::LaunchContext *lc, - int opNum, - void const* hX, Nd4jLong const* hXShapeInfo, - void const* dX, Nd4jLong const* dXShapeInfo, - void *extraParams, - void *hZ, Nd4jLong const* hZShapeInfo, - void *dZ, Nd4jLong const* dZShapeInfo, - bool biasCorrected) { - - auto stream = lc->getCudaStream(); - auto reductionPointer = lc->getReductionPointer(); - - dim3 launchDims = dim3(256, 256, 32768); - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - BUILD_DOUBLE_SELECTOR(xType, zType, functions::summarystats::SummaryStatsReduce, ::execSummaryStatsReduceScalar(launchDims, stream, opNum, dX, dXShapeInfo, hXShapeInfo, extraParams, dZ, dZShapeInfo, hZShapeInfo, nullptr, nullptr, biasCorrected, reductionPointer), LIBND4J_TYPES, FLOAT_TYPES); - - // TODO: remove after the release - auto res = cudaStreamSynchronize(*stream); - if (res != 0) - throw cuda_exception::build("execSummaryStatsScalar failed", res); +void NativeOpExecutioner::execSummaryStatsScalar( + sd::LaunchContext* lc, int opNum, void const* hX, + Nd4jLong const* hXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo, + void* extraParams, void* hZ, Nd4jLong const* hZShapeInfo, void* dZ, + Nd4jLong const* dZShapeInfo, bool biasCorrected) { + auto stream = lc->getCudaStream(); + auto reductionPointer = lc->getReductionPointer(); + + dim3 launchDims = dim3(256, 256, 32768); + + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + BUILD_DOUBLE_SELECTOR( + xType, zType, functions::summarystats::SummaryStatsReduce, + ::execSummaryStatsReduceScalar(launchDims, stream, opNum, dX, dXShapeInfo, + hXShapeInfo, extraParams, dZ, dZShapeInfo, + hZShapeInfo, nullptr, nullptr, + biasCorrected, reductionPointer), + LIBND4J_TYPES, FLOAT_TYPES); + + // TODO: remove after the release + auto res = cudaStreamSynchronize(*stream); + if (res != 0) + throw cuda_exception::build("execSummaryStatsScalar failed", res); } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execBroadcastBool(sd::LaunchContext *lc, - int opNum, - void const* hX, Nd4jLong const* hXShapeInfo, - void const* dX, Nd4jLong const* dXShapeInfo, - void const* hY, Nd4jLong const* hYShapeInfo, - void const* dY, Nd4jLong const* dYShapeInfo, - void *hZ, Nd4jLong const* hZShapeInfo, - void *dZ, Nd4jLong const* dZShapeInfo, - void *extraParams, - int *dimension, int dimensionLength, - Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, - Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) { - - auto stream = lc->getCudaStream(); - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto yType = sd::ArrayOptions::dataType(hYShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) - return; - - if (!DataTypeUtils::isB(zType)) - throw std::runtime_error("NativeOpExecutioner::execBroadcastBool requires Z operand to have BOOL type"); - - if (yType != xType) - throw std::runtime_error("NativeOpExecutioner::execBroadcastBool requires both X & Y operands to have same type"); - - if (sd::Environment::getInstance().isDebugAndVerbose()) - printf("F3B opNum:[%i]\n", opNum); - - dim3 launchDims(256, 256, 1024); - - BUILD_DOUBLE_SELECTOR(xType, zType, functions::broadcast::BroadcastBool, ::execBroadcast(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo, extraParams, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES, BOOL_TYPES) - - // TODO: remove after the release - auto res = cudaStreamSynchronize(*stream); - if (res != 0) - throw cuda_exception::build("execBroadcastBool failed", res); +void NativeOpExecutioner::execBroadcastBool( + sd::LaunchContext* lc, int opNum, void const* hX, + Nd4jLong const* hXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo, + void const* hY, Nd4jLong const* hYShapeInfo, void const* dY, + Nd4jLong const* dYShapeInfo, void* hZ, Nd4jLong const* hZShapeInfo, + void* dZ, Nd4jLong const* dZShapeInfo, void* extraParams, int* dimension, + int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, + Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, + Nd4jLong const* tadOffsetsZ) { + auto stream = lc->getCudaStream(); + + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto yType = sd::ArrayOptions::dataType(hYShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) return; + + if (!DataTypeUtils::isB(zType)) + throw std::runtime_error( + "NativeOpExecutioner::execBroadcastBool requires Z operand to have " + "BOOL type"); + + if (yType != xType) + throw std::runtime_error( + "NativeOpExecutioner::execBroadcastBool requires both X & Y operands " + "to have same type"); + + if (sd::Environment::getInstance().isDebugAndVerbose()) + printf("F3B opNum:[%i]\n", opNum); + + dim3 launchDims(256, 256, 1024); + + BUILD_DOUBLE_SELECTOR( + xType, zType, functions::broadcast::BroadcastBool, + ::execBroadcast(launchDims, stream, opNum, dX, dXShapeInfo, dY, + dYShapeInfo, dZ, dZShapeInfo, extraParams, dimension, + dimensionLength, tadOnlyShapeInfo, tadOffsets, + tadOnlyShapeInfoZ, tadOffsetsZ), + LIBND4J_TYPES, BOOL_TYPES) + + // TODO: remove after the release + auto res = cudaStreamSynchronize(*stream); + if (res != 0) throw cuda_exception::build("execBroadcastBool failed", res); } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execBroadcastBool(sd::LaunchContext* lc, const int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - const void *hY, const Nd4jLong *hYShapeInfo, - const void *dY, const Nd4jLong *dYShapeInfo, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - void *extraParams) { - - if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) - return; - - auto stream = lc->getCudaStream(); - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - dim3 launchDims; - - launchDims.y = MAX_NUM_THREADS / 4; // threadsPerBlock - launchDims.x = (shape::length(hZShapeInfo) + launchDims.y - 1) / launchDims.y; // blocksPerGrid - launchDims.z = 1024; // shared memory - - BUILD_DOUBLE_SELECTOR(xType, zType, functions::broadcast::BroadcastBool, ::execBroadcast(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo, extraParams), LIBND4J_TYPES, BOOL_TYPES); - - // TODO: remove after the release - auto res = cudaStreamSynchronize(*stream); - if (res != 0) - throw cuda_exception::build("execBroadcastBool failed", res); +void NativeOpExecutioner::execBroadcastBool( + sd::LaunchContext* lc, const int opNum, const void* hX, + const Nd4jLong* hXShapeInfo, const void* dX, const Nd4jLong* dXShapeInfo, + const void* hY, const Nd4jLong* hYShapeInfo, const void* dY, + const Nd4jLong* dYShapeInfo, void* hZ, const Nd4jLong* hZShapeInfo, + void* dZ, const Nd4jLong* dZShapeInfo, void* extraParams) { + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) return; + + auto stream = lc->getCudaStream(); + + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + dim3 launchDims; + + launchDims.y = MAX_NUM_THREADS / 4; // threadsPerBlock + launchDims.x = (shape::length(hZShapeInfo) + launchDims.y - 1) / + launchDims.y; // blocksPerGrid + launchDims.z = 1024; // shared memory + + BUILD_DOUBLE_SELECTOR( + xType, zType, functions::broadcast::BroadcastBool, + ::execBroadcast(launchDims, stream, opNum, dX, dXShapeInfo, dY, + dYShapeInfo, dZ, dZShapeInfo, extraParams), + LIBND4J_TYPES, BOOL_TYPES); + + // TODO: remove after the release + auto res = cudaStreamSynchronize(*stream); + if (res != 0) throw cuda_exception::build("execBroadcastBool failed", res); } - -void NativeOpExecutioner::execInverseBroadcastBool(sd::LaunchContext *lc, - int opNum, - void const* hX, Nd4jLong const* hXShapeInfo, - void const* dX, Nd4jLong const* dXShapeInfo, - void const* hY, Nd4jLong const* hYShapeInfo, - void const* dY, Nd4jLong const* dYShapeInfo, - void* hZ, Nd4jLong const* hZShapeInfo, - void *dZ, Nd4jLong const* dZShapeInfo, - void *extraParams, - int *dimension, int dimensionLength, - Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, - Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) { - auto stream = lc->getCudaStream(); - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto yType = sd::ArrayOptions::dataType(hYShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) - return; - - if (!DataTypeUtils::isB(zType)) - throw std::runtime_error("NativeOpExecutioner::execBroadcastBool requires Z operand to have BOOL type"); - - if (yType != xType) - throw std::runtime_error("NativeOpExecutioner::execBroadcastBool requires both X & Y operands to have same type"); - - dim3 launchDims(256, 256, 1024); - - BUILD_DOUBLE_SELECTOR(xType, zType, functions::broadcast::BroadcastBool, ::execInverseBroadcast(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo, extraParams, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES, BOOL_TYPES) - - // TODO: remove after the release - auto res = cudaStreamSynchronize(*stream); - if (res != 0) - throw cuda_exception::build("execInverseBroadcastBool failed", res); +void NativeOpExecutioner::execInverseBroadcastBool( + sd::LaunchContext* lc, int opNum, void const* hX, + Nd4jLong const* hXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo, + void const* hY, Nd4jLong const* hYShapeInfo, void const* dY, + Nd4jLong const* dYShapeInfo, void* hZ, Nd4jLong const* hZShapeInfo, + void* dZ, Nd4jLong const* dZShapeInfo, void* extraParams, int* dimension, + int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, + Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, + Nd4jLong const* tadOffsetsZ) { + auto stream = lc->getCudaStream(); + + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto yType = sd::ArrayOptions::dataType(hYShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) return; + + if (!DataTypeUtils::isB(zType)) + throw std::runtime_error( + "NativeOpExecutioner::execBroadcastBool requires Z operand to have " + "BOOL type"); + + if (yType != xType) + throw std::runtime_error( + "NativeOpExecutioner::execBroadcastBool requires both X & Y operands " + "to have same type"); + + dim3 launchDims(256, 256, 1024); + + BUILD_DOUBLE_SELECTOR( + xType, zType, functions::broadcast::BroadcastBool, + ::execInverseBroadcast(launchDims, stream, opNum, dX, dXShapeInfo, dY, + dYShapeInfo, dZ, dZShapeInfo, extraParams, + dimension, dimensionLength, tadOnlyShapeInfo, + tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), + LIBND4J_TYPES, BOOL_TYPES) + + // TODO: remove after the release + auto res = cudaStreamSynchronize(*stream); + if (res != 0) + throw cuda_exception::build("execInverseBroadcastBool failed", res); } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execBroadcastInt(sd::LaunchContext *lc, - int opNum, - void const* hX, Nd4jLong const* hXShapeInfo, - void const* dX, Nd4jLong const* dXShapeInfo, - void const* hY, Nd4jLong const* hYShapeInfo, - void const* dY, Nd4jLong const* dYShapeInfo, - void *hZ, Nd4jLong const* hZShapeInfo, - void *dZ, Nd4jLong const* dZShapeInfo, - int *dimension, int dimensionLength, - Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, - Nd4jLong const* tadOnlyShapeInfoZ,Nd4jLong const* tadOffsetsZ) { - - auto stream = lc->getCudaStream(); - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto yType = sd::ArrayOptions::dataType(hYShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) - return; - - if (!DataTypeUtils::isZ(zType)) - throw std::runtime_error("NativeOpExecutioner::execBroadcastInt requires Z operand to have INT type"); - - if (yType != xType || zType != xType) - throw std::runtime_error("NativeOpExecutioner::execBroadcastInt requires both X & Y operands to have same type"); - - dim3 launchDims(256, 256, 1024); - - BUILD_SINGLE_SELECTOR(xType, functions::broadcast::BroadcastInt, ::execBroadcast(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), INTEGER_TYPES) - - // TODO: remove after the release - auto res = cudaStreamSynchronize(*stream); - if (res != 0) - throw cuda_exception::build("execBroadcastBool failed", res); +void NativeOpExecutioner::execBroadcastInt( + sd::LaunchContext* lc, int opNum, void const* hX, + Nd4jLong const* hXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo, + void const* hY, Nd4jLong const* hYShapeInfo, void const* dY, + Nd4jLong const* dYShapeInfo, void* hZ, Nd4jLong const* hZShapeInfo, + void* dZ, Nd4jLong const* dZShapeInfo, int* dimension, int dimensionLength, + Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, + Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) { + auto stream = lc->getCudaStream(); + + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto yType = sd::ArrayOptions::dataType(hYShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) return; + + if (!DataTypeUtils::isZ(zType)) + throw std::runtime_error( + "NativeOpExecutioner::execBroadcastInt requires Z operand to have INT " + "type"); + + if (yType != xType || zType != xType) + throw std::runtime_error( + "NativeOpExecutioner::execBroadcastInt requires both X & Y operands to " + "have same type"); + + dim3 launchDims(256, 256, 1024); + + BUILD_SINGLE_SELECTOR( + xType, functions::broadcast::BroadcastInt, + ::execBroadcast(launchDims, stream, opNum, dX, dXShapeInfo, dY, + dYShapeInfo, dZ, dZShapeInfo, dimension, dimensionLength, + tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, + tadOffsetsZ), + INTEGER_TYPES) + + // TODO: remove after the release + auto res = cudaStreamSynchronize(*stream); + if (res != 0) throw cuda_exception::build("execBroadcastBool failed", res); } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execBroadcastInt(sd::LaunchContext* lc, const int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - const void *hY, const Nd4jLong *hYShapeInfo, - const void *dY, const Nd4jLong *dYShapeInfo, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo) { - - auto stream = lc->getCudaStream(); - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto yType = sd::ArrayOptions::dataType(hYShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) - return; - - if (!DataTypeUtils::isZ(zType)) - throw std::runtime_error("NativeOpExecutioner::execBroadcastInt requires Z operand to have INT type"); - - if (yType != xType || zType != xType) - throw std::runtime_error("NativeOpExecutioner::execBroadcastInt requires both X & Y operands to have same type"); - - dim3 launchDims; - - launchDims.y = MAX_NUM_THREADS / 4; // threadsPerBlock - launchDims.x = (shape::length(hZShapeInfo) + launchDims.y - 1) / launchDims.y; // blocksPerGrid - launchDims.z = 1024; // shared memory - - BUILD_SINGLE_SELECTOR(xType, functions::broadcast::BroadcastInt, ::execBroadcast(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo), INTEGER_TYPES) - - // TODO: remove after the release - auto res = cudaStreamSynchronize(*stream); - if (res != 0) - throw cuda_exception::build("execBroadcastBool failed", res); +void NativeOpExecutioner::execBroadcastInt( + sd::LaunchContext* lc, const int opNum, const void* hX, + const Nd4jLong* hXShapeInfo, const void* dX, const Nd4jLong* dXShapeInfo, + const void* hY, const Nd4jLong* hYShapeInfo, const void* dY, + const Nd4jLong* dYShapeInfo, void* hZ, const Nd4jLong* hZShapeInfo, + void* dZ, const Nd4jLong* dZShapeInfo) { + auto stream = lc->getCudaStream(); + + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto yType = sd::ArrayOptions::dataType(hYShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) return; + + if (!DataTypeUtils::isZ(zType)) + throw std::runtime_error( + "NativeOpExecutioner::execBroadcastInt requires Z operand to have INT " + "type"); + + if (yType != xType || zType != xType) + throw std::runtime_error( + "NativeOpExecutioner::execBroadcastInt requires both X & Y operands to " + "have same type"); + + dim3 launchDims; + + launchDims.y = MAX_NUM_THREADS / 4; // threadsPerBlock + launchDims.x = (shape::length(hZShapeInfo) + launchDims.y - 1) / + launchDims.y; // blocksPerGrid + launchDims.z = 1024; // shared memory + + BUILD_SINGLE_SELECTOR( + xType, functions::broadcast::BroadcastInt, + ::execBroadcast(launchDims, stream, opNum, dX, dXShapeInfo, dY, + dYShapeInfo, dZ, dZShapeInfo), + INTEGER_TYPES) + + // TODO: remove after the release + auto res = cudaStreamSynchronize(*stream); + if (res != 0) throw cuda_exception::build("execBroadcastBool failed", res); } -void NativeOpExecutioner::execInverseBroadcastInt(sd::LaunchContext *lc, - int opNum, - void const* hX, Nd4jLong const* hXShapeInfo, - void const* dX, Nd4jLong const* dXShapeInfo, - void const* hY, Nd4jLong const* hYShapeInfo, - void const* dY, Nd4jLong const* dYShapeInfo, - void *hZ, Nd4jLong const* hZShapeInfo, - void *dZ, Nd4jLong const* dZShapeInfo, - int *dimension, int dimensionLength, - Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, - Nd4jLong const* tadOnlyShapeInfoZ,Nd4jLong const* tadOffsetsZ) { - auto stream = lc->getCudaStream(); - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto yType = sd::ArrayOptions::dataType(hYShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) - return; - - if (!DataTypeUtils::isZ(zType)) - throw std::runtime_error("NativeOpExecutioner::execBroadcastInt requires Z operand to have INT type"); - - if (yType != xType || zType != xType) - throw std::runtime_error("NativeOpExecutioner::execBroadcastInt requires both X & Y operands to have same type"); - - if (sd::Environment::getInstance().isDebugAndVerbose()) - printf("F3BI opNum:[%i]\n", opNum); - - dim3 launchDims(256, 256, 1024); - - BUILD_SINGLE_SELECTOR(xType, functions::broadcast::BroadcastInt, ::execInverseBroadcast(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), INTEGER_TYPES) - - // TODO: remove after the release - auto res = cudaStreamSynchronize(*stream); - if (res != 0) - throw cuda_exception::build("execInverseBroadcastInt failed", res); +void NativeOpExecutioner::execInverseBroadcastInt( + sd::LaunchContext* lc, int opNum, void const* hX, + Nd4jLong const* hXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo, + void const* hY, Nd4jLong const* hYShapeInfo, void const* dY, + Nd4jLong const* dYShapeInfo, void* hZ, Nd4jLong const* hZShapeInfo, + void* dZ, Nd4jLong const* dZShapeInfo, int* dimension, int dimensionLength, + Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, + Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) { + auto stream = lc->getCudaStream(); + + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto yType = sd::ArrayOptions::dataType(hYShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) return; + + if (!DataTypeUtils::isZ(zType)) + throw std::runtime_error( + "NativeOpExecutioner::execBroadcastInt requires Z operand to have INT " + "type"); + + if (yType != xType || zType != xType) + throw std::runtime_error( + "NativeOpExecutioner::execBroadcastInt requires both X & Y operands to " + "have same type"); + + if (sd::Environment::getInstance().isDebugAndVerbose()) + printf("F3BI opNum:[%i]\n", opNum); + + dim3 launchDims(256, 256, 1024); + + BUILD_SINGLE_SELECTOR( + xType, functions::broadcast::BroadcastInt, + ::execInverseBroadcast(launchDims, stream, opNum, dX, dXShapeInfo, dY, + dYShapeInfo, dZ, dZShapeInfo, dimension, + dimensionLength, tadOnlyShapeInfo, tadOffsets, + tadOnlyShapeInfoZ, tadOffsetsZ), + INTEGER_TYPES) + + // TODO: remove after the release + auto res = cudaStreamSynchronize(*stream); + if (res != 0) + throw cuda_exception::build("execInverseBroadcastInt failed", res); } //////////////////////////////////////////////////////////////////////// @@ -463,216 +505,242 @@ void NativeOpExecutioner::execInverseBroadcastInt(sd::LaunchContext *lc, * @param dimension * @param dimensionLength */ -void NativeOpExecutioner::execBroadcast(sd::LaunchContext *lc, - int opNum, - void const* hX, Nd4jLong const* hXShapeInfo, - void const* dX, Nd4jLong const* dXShapeInfo, - void const* hY, Nd4jLong const* hYShapeInfo, - void const* dY, Nd4jLong const* dYShapeInfo, - void *hZ, Nd4jLong const* hZShapeInfo, - void *dZ, Nd4jLong const* dZShapeInfo, - int *dimension, int dimensionLength, - Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, - Nd4jLong const* tadOnlyShapeInfoZ,Nd4jLong const* tadOffsetsZ) { +void NativeOpExecutioner::execBroadcast( + sd::LaunchContext* lc, int opNum, void const* hX, + Nd4jLong const* hXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo, + void const* hY, Nd4jLong const* hYShapeInfo, void const* dY, + Nd4jLong const* dYShapeInfo, void* hZ, Nd4jLong const* hZShapeInfo, + void* dZ, Nd4jLong const* dZShapeInfo, int* dimension, int dimensionLength, + Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, + Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) { + auto stream = lc->getCudaStream(); - auto stream = lc->getCudaStream(); + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto yType = sd::ArrayOptions::dataType(hYShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto yType = sd::ArrayOptions::dataType(hYShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) return; - if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) - return; - - dim3 launchDims(256, 256, 1024); + dim3 launchDims(256, 256, 1024); #ifdef __ND4J_EXPERIMENTAL__ - BUILD_PAIRWISE_SELECTOR(xType, yType, zType, functions::broadcast::Broadcast, ::execBroadcast(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES, LIBND4J_TYPES); + BUILD_PAIRWISE_SELECTOR( + xType, yType, zType, functions::broadcast::Broadcast, + ::execBroadcast(launchDims, stream, opNum, dX, dXShapeInfo, dY, + dYShapeInfo, dZ, dZShapeInfo, dimension, dimensionLength, + tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, + tadOffsetsZ), + LIBND4J_TYPES, LIBND4J_TYPES); #else - BUILD_SINGLE_SELECTOR_THRICE(xType, functions::broadcast::Broadcast, ::execBroadcast(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR_THRICE( + xType, functions::broadcast::Broadcast, + ::execBroadcast(launchDims, stream, opNum, dX, dXShapeInfo, dY, + dYShapeInfo, dZ, dZShapeInfo, dimension, dimensionLength, + tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, + tadOffsetsZ), + LIBND4J_TYPES); #endif - // TODO: remove after the release - auto res = cudaStreamSynchronize(*stream); - if (res != 0) - throw cuda_exception::build("execBroadcast failed", res); + // TODO: remove after the release + auto res = cudaStreamSynchronize(*stream); + if (res != 0) throw cuda_exception::build("execBroadcast failed", res); } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execBroadcast(sd::LaunchContext *lc, const int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - const void *hY, const Nd4jLong *hYShapeInfo, - const void *dY, const Nd4jLong *dYShapeInfo, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo) { - - auto stream = lc->getCudaStream(); +void NativeOpExecutioner::execBroadcast( + sd::LaunchContext* lc, const int opNum, const void* hX, + const Nd4jLong* hXShapeInfo, const void* dX, const Nd4jLong* dXShapeInfo, + const void* hY, const Nd4jLong* hYShapeInfo, const void* dY, + const Nd4jLong* dYShapeInfo, void* hZ, const Nd4jLong* hZShapeInfo, + void* dZ, const Nd4jLong* dZShapeInfo) { + auto stream = lc->getCudaStream(); - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto yType = sd::ArrayOptions::dataType(hYShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto yType = sd::ArrayOptions::dataType(hYShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) - return; + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) return; - dim3 launchDims; + dim3 launchDims; - launchDims.y = MAX_NUM_THREADS / 4; // threadsPerBlock - launchDims.x = (shape::length(hZShapeInfo) + launchDims.y - 1) / launchDims.y; // blocksPerGrid - launchDims.z = 1024; // shared memory + launchDims.y = MAX_NUM_THREADS / 4; // threadsPerBlock + launchDims.x = (shape::length(hZShapeInfo) + launchDims.y - 1) / + launchDims.y; // blocksPerGrid + launchDims.z = 1024; // shared memory #ifdef __ND4J_EXPERIMENTAL__ - BUILD_PAIRWISE_SELECTOR(xType, yType, zType, functions::broadcast::Broadcast, ::execBroadcast(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo), LIBND4J_TYPES, LIBND4J_TYPES); + BUILD_PAIRWISE_SELECTOR( + xType, yType, zType, functions::broadcast::Broadcast, + ::execBroadcast(launchDims, stream, opNum, dX, dXShapeInfo, dY, + dYShapeInfo, dZ, dZShapeInfo), + LIBND4J_TYPES, LIBND4J_TYPES); #else - BUILD_SINGLE_SELECTOR_THRICE(xType, functions::broadcast::Broadcast, ::execBroadcast(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR_THRICE( + xType, functions::broadcast::Broadcast, + ::execBroadcast(launchDims, stream, opNum, dX, dXShapeInfo, dY, + dYShapeInfo, dZ, dZShapeInfo), + LIBND4J_TYPES); #endif - // TODO: remove after the release - auto res = cudaStreamSynchronize(*stream); - if (res != 0) - throw cuda_exception::build("execBroadcast failed", res); + // TODO: remove after the release + auto res = cudaStreamSynchronize(*stream); + if (res != 0) throw cuda_exception::build("execBroadcast failed", res); } -void NativeOpExecutioner::execInverseBroadcast(sd::LaunchContext *lc, - int opNum, - void const* hX, Nd4jLong const* hXShapeInfo, - void const* dX, Nd4jLong const* dXShapeInfo, - void const* hY, Nd4jLong const* hYShapeInfo, - void const* dY, Nd4jLong const* dYShapeInfo, - void *hZ, Nd4jLong const* hZShapeInfo, - void *dZ, Nd4jLong const* dZShapeInfo, - int *dimension, int dimensionLength, - Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, - Nd4jLong const* tadOnlyShapeInfoZ,Nd4jLong const* tadOffsetsZ) { - - auto stream = lc->getCudaStream(); +void NativeOpExecutioner::execInverseBroadcast( + sd::LaunchContext* lc, int opNum, void const* hX, + Nd4jLong const* hXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo, + void const* hY, Nd4jLong const* hYShapeInfo, void const* dY, + Nd4jLong const* dYShapeInfo, void* hZ, Nd4jLong const* hZShapeInfo, + void* dZ, Nd4jLong const* dZShapeInfo, int* dimension, int dimensionLength, + Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, + Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) { + auto stream = lc->getCudaStream(); - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto yType = sd::ArrayOptions::dataType(hYShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto yType = sd::ArrayOptions::dataType(hYShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) - return; + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) return; - dim3 launchDims(256, 256, 1024); + dim3 launchDims(256, 256, 1024); #ifdef __ND4J_EXPERIMENTAL__ - BUILD_PAIRWISE_SELECTOR(xType, yType, zType, functions::broadcast::Broadcast, ::execInverseBroadcast(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES, LIBND4J_TYPES); + BUILD_PAIRWISE_SELECTOR( + xType, yType, zType, functions::broadcast::Broadcast, + ::execInverseBroadcast(launchDims, stream, opNum, dX, dXShapeInfo, dY, + dYShapeInfo, dZ, dZShapeInfo, dimension, + dimensionLength, tadOnlyShapeInfo, tadOffsets, + tadOnlyShapeInfoZ, tadOffsetsZ), + LIBND4J_TYPES, LIBND4J_TYPES); #else - BUILD_SINGLE_SELECTOR_THRICE(xType, functions::broadcast::Broadcast, ::execInverseBroadcast(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR_THRICE( + xType, functions::broadcast::Broadcast, + ::execInverseBroadcast(launchDims, stream, opNum, dX, dXShapeInfo, dY, + dYShapeInfo, dZ, dZShapeInfo, dimension, + dimensionLength, tadOnlyShapeInfo, tadOffsets, + tadOnlyShapeInfoZ, tadOffsetsZ), + LIBND4J_TYPES); #endif - // TODO: remove after the release - auto res = cudaStreamSynchronize(*stream); - if (res != 0) - throw cuda_exception::build("execInverseBroadcast failed", res); + // TODO: remove after the release + auto res = cudaStreamSynchronize(*stream); + if (res != 0) throw cuda_exception::build("execInverseBroadcast failed", res); } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execReduceSame(sd::LaunchContext *lc, - int opNum, - void const* hX, Nd4jLong const* hXShapeInfo, - void const* dX, Nd4jLong const* dXShapeInfo, - void *extraParams, - void *hZ, Nd4jLong const* hZShapeInfo, - void *dZ, Nd4jLong const* dZShapeInfo, - int *dimension, int dimensionLength, - Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets) { - - auto stream = lc->getCudaStream(); - auto reductionPointer = lc->getReductionPointer(); - - if (sd::Environment::getInstance().isDebugAndVerbose()) - printf("SF7 opNum:[%i]\n", opNum); - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - auto xRank = shape::rank(hXShapeInfo); - - if (zType != xType) - throw datatype_exception::build("NativeOpExecutioner::execReduceSame requires both X & Z operands to have same type", xType, zType); - - auto numBlocks = shape::length(hZShapeInfo); - dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, 256, 8192); - - BUILD_SINGLE_SELECTOR(xType, functions::reduce::ReduceSameFunction, ::execReduceXD(launchDims, stream, opNum, xRank, dX, dXShapeInfo, hXShapeInfo, extraParams, dZ, dZShapeInfo, hZShapeInfo, dimension, dimensionLength, reductionPointer, tadShapeInfo, tadOffsets), LIBND4J_TYPES); - - // TODO: remove after the release - auto res = cudaStreamSynchronize(*stream); - if (res != 0) - throw cuda_exception::build("execReduceSame failed", res); +void NativeOpExecutioner::execReduceSame( + sd::LaunchContext* lc, int opNum, void const* hX, + Nd4jLong const* hXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo, + void* extraParams, void* hZ, Nd4jLong const* hZShapeInfo, void* dZ, + Nd4jLong const* dZShapeInfo, int* dimension, int dimensionLength, + Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets) { + auto stream = lc->getCudaStream(); + auto reductionPointer = lc->getReductionPointer(); + + if (sd::Environment::getInstance().isDebugAndVerbose()) + printf("SF7 opNum:[%i]\n", opNum); + + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + auto xRank = shape::rank(hXShapeInfo); + + if (zType != xType) + throw datatype_exception::build( + "NativeOpExecutioner::execReduceSame requires both X & Z operands to " + "have same type", + xType, zType); + + auto numBlocks = shape::length(hZShapeInfo); + dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, 256, 8192); + + BUILD_SINGLE_SELECTOR( + xType, functions::reduce::ReduceSameFunction, + ::execReduceXD(launchDims, stream, opNum, xRank, dX, dXShapeInfo, + hXShapeInfo, extraParams, dZ, dZShapeInfo, hZShapeInfo, + dimension, dimensionLength, reductionPointer, tadShapeInfo, + tadOffsets), + LIBND4J_TYPES); + + // TODO: remove after the release + auto res = cudaStreamSynchronize(*stream); + if (res != 0) throw cuda_exception::build("execReduceSame failed", res); } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execReduceLong(sd::LaunchContext *lc, - int opNum, - void const* hX, Nd4jLong const* hXShapeInfo, - void const* dX, Nd4jLong const* dXShapeInfo, - void *extraParams, - void *hZ, Nd4jLong const* hZShapeInfo, - void *dZ, Nd4jLong const* dZShapeInfo, - int *dimension,int dimensionLength, - Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets) { - - auto stream = lc->getCudaStream(); - auto reductionPointer = lc->getReductionPointer(); - - if (sd::Environment::getInstance().isDebugAndVerbose()) - printf("LF7 opNum:[%i]\n", opNum); - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - if (zType != sd::DataType::INT64) - throw datatype_exception::build("NativeOpExecutioner::execReduceLong wrong Z data type", sd::DataType::INT64, zType); - - auto xRank = shape::rank(hXShapeInfo); - auto numBlocks = shape::length(hZShapeInfo); - dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, 256, 32768); - - BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce::ReduceLongFunction, ::execReduceXD(launchDims, stream, opNum, xRank, dX, dXShapeInfo, hXShapeInfo, extraParams, dZ, dZShapeInfo, hZShapeInfo, dimension, dimensionLength, reductionPointer, tadShapeInfo, tadOffsets), LIBND4J_TYPES, LONG_TYPES); - - // TODO: remove after the release - auto res = cudaStreamSynchronize(*stream); - if (res != 0) - throw cuda_exception::build("execReduceLong failed", res); - +void NativeOpExecutioner::execReduceLong( + sd::LaunchContext* lc, int opNum, void const* hX, + Nd4jLong const* hXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo, + void* extraParams, void* hZ, Nd4jLong const* hZShapeInfo, void* dZ, + Nd4jLong const* dZShapeInfo, int* dimension, int dimensionLength, + Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets) { + auto stream = lc->getCudaStream(); + auto reductionPointer = lc->getReductionPointer(); + + if (sd::Environment::getInstance().isDebugAndVerbose()) + printf("LF7 opNum:[%i]\n", opNum); + + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + if (zType != sd::DataType::INT64) + throw datatype_exception::build( + "NativeOpExecutioner::execReduceLong wrong Z data type", + sd::DataType::INT64, zType); + + auto xRank = shape::rank(hXShapeInfo); + auto numBlocks = shape::length(hZShapeInfo); + dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, 256, 32768); + + BUILD_DOUBLE_SELECTOR( + xType, zType, functions::reduce::ReduceLongFunction, + ::execReduceXD(launchDims, stream, opNum, xRank, dX, dXShapeInfo, + hXShapeInfo, extraParams, dZ, dZShapeInfo, hZShapeInfo, + dimension, dimensionLength, reductionPointer, tadShapeInfo, + tadOffsets), + LIBND4J_TYPES, LONG_TYPES); + + // TODO: remove after the release + auto res = cudaStreamSynchronize(*stream); + if (res != 0) throw cuda_exception::build("execReduceLong failed", res); } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execReduceBool(sd::LaunchContext *lc, - int opNum, - void const* hX, Nd4jLong const* hXShapeInfo, - void const* dX, Nd4jLong const* dXShapeInfo, - void *extraParams, - void *hZ, Nd4jLong const* hZShapeInfo, - void *dZ, Nd4jLong const* dZShapeInfo, - int *dimension, int dimensionLength, - Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets) { - - auto stream = lc->getCudaStream(); - auto reductionPointer = lc->getReductionPointer(); - - if (sd::Environment::getInstance().isDebugAndVerbose()) - printf("BF7 opNum:[%i]\n", opNum); - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - if (zType != sd::DataType::BOOL) - throw std::runtime_error("NativeOpExecutioner::execReduceBool requires Z operand to have BOOL type"); - - auto xRank = shape::rank(hXShapeInfo); - auto numBlocks = shape::length(hZShapeInfo); - dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, 256, 32768); - - BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce::ReduceBoolFunction, ::execReduceXD(launchDims, stream, opNum, xRank, dX, dXShapeInfo, hXShapeInfo, extraParams, dZ, dZShapeInfo, hZShapeInfo, dimension, dimensionLength, reductionPointer, tadShapeInfo, tadOffsets), LIBND4J_TYPES, BOOL_TYPES); - - // TODO: remove after the release - auto res = cudaStreamSynchronize(*stream); - if (res != 0) - throw cuda_exception::build("execReduceBool failed", res); +void NativeOpExecutioner::execReduceBool( + sd::LaunchContext* lc, int opNum, void const* hX, + Nd4jLong const* hXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo, + void* extraParams, void* hZ, Nd4jLong const* hZShapeInfo, void* dZ, + Nd4jLong const* dZShapeInfo, int* dimension, int dimensionLength, + Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets) { + auto stream = lc->getCudaStream(); + auto reductionPointer = lc->getReductionPointer(); + + if (sd::Environment::getInstance().isDebugAndVerbose()) + printf("BF7 opNum:[%i]\n", opNum); + + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + if (zType != sd::DataType::BOOL) + throw std::runtime_error( + "NativeOpExecutioner::execReduceBool requires Z operand to have BOOL " + "type"); + + auto xRank = shape::rank(hXShapeInfo); + auto numBlocks = shape::length(hZShapeInfo); + dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, 256, 32768); + + BUILD_DOUBLE_SELECTOR( + xType, zType, functions::reduce::ReduceBoolFunction, + ::execReduceXD(launchDims, stream, opNum, xRank, dX, dXShapeInfo, + hXShapeInfo, extraParams, dZ, dZShapeInfo, hZShapeInfo, + dimension, dimensionLength, reductionPointer, tadShapeInfo, + tadOffsets), + LIBND4J_TYPES, BOOL_TYPES); + + // TODO: remove after the release + auto res = cudaStreamSynchronize(*stream); + if (res != 0) throw cuda_exception::build("execReduceBool failed", res); } //////////////////////////////////////////////////////////////////////// @@ -687,39 +755,44 @@ void NativeOpExecutioner::execReduceBool(sd::LaunchContext *lc, * @param dimension * @param dimensionLength */ -void NativeOpExecutioner::execIndexReduce(sd::LaunchContext *lc, - int opNum, - void const* hX, Nd4jLong const* hXShapeInfo, - void const* dX, Nd4jLong const* dXShapeInfo, - void *extraParams, - void *hZ, Nd4jLong const* hZShapeInfo, - void *dZ, Nd4jLong const* dZShapeInfo, - int *dimension, int dimensionLength, - Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets) { - - auto stream = lc->getCudaStream(); - auto reductionPointer = lc->getReductionPointer(); - auto allocationPointer = lc->getAllocationPointer(); - - if (sd::Environment::getInstance().isDebugAndVerbose()) - printf("F2 opNum:[%i]\n", opNum); - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - auto numBlocks = shape::length(hZShapeInfo); - dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, 256, 32768); - - if (zType != sd::DataType::INT64 && zType != sd::DataType::INT32) - throw datatype_exception::build("NativeOpExecutioner::execIndexReduce requires Z operand to have INT32/INT64 type", zType); - - auto dz = reinterpret_cast(dZ); - - BUILD_DOUBLE_SELECTOR(xType, zType, functions::indexreduce::IndexReduce, ::executeIndexReduce(launchDims, stream, opNum, dX, dXShapeInfo, shape::rank(hXShapeInfo), extraParams, dz, dZShapeInfo, shape::rank(hZShapeInfo), dimension, dimensionLength, 1, allocationPointer, reductionPointer, tadShapeInfo, tadOffsets), LIBND4J_TYPES, INDEXING_TYPES); - - // TODO: remove after the release - auto res = cudaStreamSynchronize(*stream); - if (res != 0) - throw cuda_exception::build("execIndexReduce failed", res); +void NativeOpExecutioner::execIndexReduce( + sd::LaunchContext* lc, int opNum, void const* hX, + Nd4jLong const* hXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo, + void* extraParams, void* hZ, Nd4jLong const* hZShapeInfo, void* dZ, + Nd4jLong const* dZShapeInfo, int* dimension, int dimensionLength, + Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets) { + auto stream = lc->getCudaStream(); + auto reductionPointer = lc->getReductionPointer(); + auto allocationPointer = lc->getAllocationPointer(); + + if (sd::Environment::getInstance().isDebugAndVerbose()) + printf("F2 opNum:[%i]\n", opNum); + + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + auto numBlocks = shape::length(hZShapeInfo); + dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, 256, 32768); + + if (zType != sd::DataType::INT64 && zType != sd::DataType::INT32) + throw datatype_exception::build( + "NativeOpExecutioner::execIndexReduce requires Z operand to have " + "INT32/INT64 type", + zType); + + auto dz = reinterpret_cast(dZ); + + BUILD_DOUBLE_SELECTOR( + xType, zType, functions::indexreduce::IndexReduce, + ::executeIndexReduce(launchDims, stream, opNum, dX, dXShapeInfo, + shape::rank(hXShapeInfo), extraParams, dz, + dZShapeInfo, shape::rank(hZShapeInfo), dimension, + dimensionLength, 1, allocationPointer, + reductionPointer, tadShapeInfo, tadOffsets), + LIBND4J_TYPES, INDEXING_TYPES); + + // TODO: remove after the release + auto res = cudaStreamSynchronize(*stream); + if (res != 0) throw cuda_exception::build("execIndexReduce failed", res); } //////////////////////////////////////////////////////////////////////// @@ -732,38 +805,38 @@ void NativeOpExecutioner::execIndexReduce(sd::LaunchContext *lc, * @param dZ * @param dZShapeInfo */ -void NativeOpExecutioner::execReduceFloat(sd::LaunchContext *lc, - int opNum, - void const* hX, Nd4jLong const* hXShapeInfo, - void const* dX, Nd4jLong const* dXShapeInfo, - void *extraParams, - void *hZ, Nd4jLong const* hZShapeInfo, - void *dZ, Nd4jLong const* dZShapeInfo, - int *dimension,int dimensionLength, - Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets) { - - auto stream = lc->getCudaStream(); - auto reductionPointer = lc->getReductionPointer(); - - if (sd::Environment::getInstance().isDebugAndVerbose()) - printf("F8 opNum:[%i]\n", opNum); - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - auto xRank = shape::rank(hXShapeInfo); - auto numBlocks = shape::length(hZShapeInfo); - dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, 256, 32768); - - BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce::ReduceFloatFunction, ::execReduceXD(launchDims, stream, opNum, xRank, dX, dXShapeInfo, hXShapeInfo, extraParams, dZ, dZShapeInfo, hZShapeInfo, dimension, dimensionLength, reductionPointer, tadShapeInfo, tadOffsets), LIBND4J_TYPES, FLOAT_TYPES); - - // TODO: remove after the release - auto res = cudaStreamSynchronize(*stream); - if (res != 0) - throw cuda_exception::build("execReduceFloat failed", res); +void NativeOpExecutioner::execReduceFloat( + sd::LaunchContext* lc, int opNum, void const* hX, + Nd4jLong const* hXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo, + void* extraParams, void* hZ, Nd4jLong const* hZShapeInfo, void* dZ, + Nd4jLong const* dZShapeInfo, int* dimension, int dimensionLength, + Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets) { + auto stream = lc->getCudaStream(); + auto reductionPointer = lc->getReductionPointer(); + + if (sd::Environment::getInstance().isDebugAndVerbose()) + printf("F8 opNum:[%i]\n", opNum); + + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + auto xRank = shape::rank(hXShapeInfo); + auto numBlocks = shape::length(hZShapeInfo); + dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, 256, 32768); + + BUILD_DOUBLE_SELECTOR( + xType, zType, functions::reduce::ReduceFloatFunction, + ::execReduceXD(launchDims, stream, opNum, xRank, dX, dXShapeInfo, + hXShapeInfo, extraParams, dZ, dZShapeInfo, hZShapeInfo, + dimension, dimensionLength, reductionPointer, tadShapeInfo, + tadOffsets), + LIBND4J_TYPES, FLOAT_TYPES); + + // TODO: remove after the release + auto res = cudaStreamSynchronize(*stream); + if (res != 0) throw cuda_exception::build("execReduceFloat failed", res); } - /** * * @param opNum @@ -772,947 +845,1030 @@ void NativeOpExecutioner::execReduceFloat(sd::LaunchContext *lc, * @param extraParams */ //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execIndexReduceScalar(sd::LaunchContext *lc, - int opNum, - void const* hX, Nd4jLong const* hXShapeInfo, - void const* dX, Nd4jLong const* dXShapeInfo, - void *extraParams, - void *hZ, Nd4jLong const* hZShapeInfo, - void *dZ, Nd4jLong const* dZShapeInfo){ - - if (sd::Environment::getInstance().isDebug()) - printf("F1 opNum:[%i]\n", opNum); - - auto stream = lc->getCudaStream(); - auto reductionPointer = lc->getReductionPointer(); - auto allocationPointer = lc->getAllocationPointer(); - - auto xLength = shape::length(hXShapeInfo); - auto blockWidth = 256; - auto numBlocks = CudaLaunchHelper::getReductionBlocks(xLength, blockWidth); - dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, blockWidth, 32768); - - if (sd::Environment::getInstance().isDebugAndVerbose() && launchDims.x == 1) - printf("AF1 opNum:[%i]\n", opNum); - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - // FIXME: we want Z to be one of integer types - //if (!DataTypeUtils::isZ(zType)) - // throw sd::datatype_exception("NativeOpExecutioner::execIndexReduceScalar requires Z operand to have one of integer types") - if (zType != sd::DataType::INT64 && zType != sd::DataType::INT32) - throw sd::datatype_exception::build("NativeOpExecutioner::execIndexReduceScalar requires Z operand to have INT32/INT64 data type", zType); - - auto dz = reinterpret_cast(dZ); - - BUILD_DOUBLE_SELECTOR(xType, zType, functions::indexreduce::IndexReduce, ::executeIndexReduceScalar(launchDims, stream, - opNum, - dX, dXShapeInfo, shape::rank(hXShapeInfo), - extraParams, - dz, dZShapeInfo, 0, - nullptr, 0, - 1, - allocationPointer, reductionPointer, - nullptr, nullptr), LIBND4J_TYPES, INDEXING_TYPES); - // TODO: remove after the release - auto res = cudaStreamSynchronize(*stream); - if (res != 0) - throw cuda_exception::build("execIndexReduceScalar failed", res); +void NativeOpExecutioner::execIndexReduceScalar( + sd::LaunchContext* lc, int opNum, void const* hX, + Nd4jLong const* hXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo, + void* extraParams, void* hZ, Nd4jLong const* hZShapeInfo, void* dZ, + Nd4jLong const* dZShapeInfo) { + if (sd::Environment::getInstance().isDebug()) + printf("F1 opNum:[%i]\n", opNum); + + auto stream = lc->getCudaStream(); + auto reductionPointer = lc->getReductionPointer(); + auto allocationPointer = lc->getAllocationPointer(); + + auto xLength = shape::length(hXShapeInfo); + auto blockWidth = 256; + auto numBlocks = CudaLaunchHelper::getReductionBlocks(xLength, blockWidth); + dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, blockWidth, 32768); + + if (sd::Environment::getInstance().isDebugAndVerbose() && launchDims.x == 1) + printf("AF1 opNum:[%i]\n", opNum); + + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + // FIXME: we want Z to be one of integer types + // if (!DataTypeUtils::isZ(zType)) + // throw sd::datatype_exception("NativeOpExecutioner::execIndexReduceScalar + // requires Z operand to have one of integer types") + if (zType != sd::DataType::INT64 && zType != sd::DataType::INT32) + throw sd::datatype_exception::build( + "NativeOpExecutioner::execIndexReduceScalar requires Z operand to have " + "INT32/INT64 data type", + zType); + + auto dz = reinterpret_cast(dZ); + + BUILD_DOUBLE_SELECTOR( + xType, zType, functions::indexreduce::IndexReduce, + ::executeIndexReduceScalar( + launchDims, stream, opNum, dX, dXShapeInfo, shape::rank(hXShapeInfo), + extraParams, dz, dZShapeInfo, 0, nullptr, 0, 1, allocationPointer, + reductionPointer, nullptr, nullptr), + LIBND4J_TYPES, INDEXING_TYPES); + // TODO: remove after the release + auto res = cudaStreamSynchronize(*stream); + if (res != 0) + throw cuda_exception::build("execIndexReduceScalar failed", res); } - //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execReduceFloatScalar(sd::LaunchContext *lc, - int opNum, - void const* hX, Nd4jLong const* hXShapeInfo, - void const* dX, Nd4jLong const* dXShapeInfo, - void *extraParams, - void *hZ, Nd4jLong const* hZShapeInfo, - void *dZ, Nd4jLong const* dZShapeInfo) { - - auto stream = lc->getCudaStream(); - auto reductionPointer = lc->getReductionPointer(); - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - auto xLength = shape::length(hXShapeInfo); - auto blockWidth = 256; - auto numBlocks = CudaLaunchHelper::getReductionBlocks(xLength, blockWidth); - dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, blockWidth, 32768); - - BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce::ReduceFloatFunction, ::execReduceScalar(launchDims, stream, opNum, dX,dXShapeInfo, hXShapeInfo, extraParams, dZ,dZShapeInfo, hZShapeInfo, nullptr, 0, reductionPointer, nullptr), LIBND4J_TYPES, FLOAT_TYPES); - - // TODO: remove after the release - auto res = cudaStreamSynchronize(*stream); - if (res != 0) - throw cuda_exception::build("execReduceFloatScalar failed", res); +void NativeOpExecutioner::execReduceFloatScalar( + sd::LaunchContext* lc, int opNum, void const* hX, + Nd4jLong const* hXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo, + void* extraParams, void* hZ, Nd4jLong const* hZShapeInfo, void* dZ, + Nd4jLong const* dZShapeInfo) { + auto stream = lc->getCudaStream(); + auto reductionPointer = lc->getReductionPointer(); + + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + auto xLength = shape::length(hXShapeInfo); + auto blockWidth = 256; + auto numBlocks = CudaLaunchHelper::getReductionBlocks(xLength, blockWidth); + dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, blockWidth, 32768); + + BUILD_DOUBLE_SELECTOR( + xType, zType, functions::reduce::ReduceFloatFunction, + ::execReduceScalar(launchDims, stream, opNum, dX, dXShapeInfo, + hXShapeInfo, extraParams, dZ, dZShapeInfo, hZShapeInfo, + nullptr, 0, reductionPointer, nullptr), + LIBND4J_TYPES, FLOAT_TYPES); + + // TODO: remove after the release + auto res = cudaStreamSynchronize(*stream); + if (res != 0) + throw cuda_exception::build("execReduceFloatScalar failed", res); } - //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execReduceBoolScalar(sd::LaunchContext *lc, - int opNum, - void const* hX, Nd4jLong const* hXShapeInfo, - void const* dX, Nd4jLong const* dXShapeInfo, - void *extraParams, - void *hZ, Nd4jLong const* hZShapeInfo, - void *dZ, Nd4jLong const* dZShapeInfo) { - - auto stream = lc->getCudaStream(); - auto reductionPointer = lc->getReductionPointer(); - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - if (zType != sd::DataType::BOOL) - throw std::runtime_error("NativeOpExecutioner::execReduceBoolScalar requires Z operand to have BOOL type"); - - auto xLength = shape::length(hXShapeInfo); - auto blockWidth = 256; - auto numBlocks = CudaLaunchHelper::getReductionBlocks(xLength, blockWidth); - dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, blockWidth, 32768); - - BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce::ReduceBoolFunction, ::execReduceScalar(launchDims, stream, opNum, dX, dXShapeInfo, hXShapeInfo, extraParams, dZ, dZShapeInfo, hZShapeInfo, nullptr, 0, reductionPointer, nullptr), LIBND4J_TYPES, BOOL_TYPES); - - // TODO: remove after the release - auto res = cudaStreamSynchronize(*stream); - if (res != 0) - throw cuda_exception::build("execReduceBoolScalar failed", res); +void NativeOpExecutioner::execReduceBoolScalar( + sd::LaunchContext* lc, int opNum, void const* hX, + Nd4jLong const* hXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo, + void* extraParams, void* hZ, Nd4jLong const* hZShapeInfo, void* dZ, + Nd4jLong const* dZShapeInfo) { + auto stream = lc->getCudaStream(); + auto reductionPointer = lc->getReductionPointer(); + + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + if (zType != sd::DataType::BOOL) + throw std::runtime_error( + "NativeOpExecutioner::execReduceBoolScalar requires Z operand to have " + "BOOL type"); + + auto xLength = shape::length(hXShapeInfo); + auto blockWidth = 256; + auto numBlocks = CudaLaunchHelper::getReductionBlocks(xLength, blockWidth); + dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, blockWidth, 32768); + + BUILD_DOUBLE_SELECTOR( + xType, zType, functions::reduce::ReduceBoolFunction, + ::execReduceScalar(launchDims, stream, opNum, dX, dXShapeInfo, + hXShapeInfo, extraParams, dZ, dZShapeInfo, hZShapeInfo, + nullptr, 0, reductionPointer, nullptr), + LIBND4J_TYPES, BOOL_TYPES); + + // TODO: remove after the release + auto res = cudaStreamSynchronize(*stream); + if (res != 0) throw cuda_exception::build("execReduceBoolScalar failed", res); } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execReduceSameScalar(sd::LaunchContext *lc, - int opNum, - void const* hX, Nd4jLong const* hXShapeInfo, - void const* dX, Nd4jLong const* dXShapeInfo, - void *extraParams, - void *hZ, Nd4jLong const* hZShapeInfo, - void *dZ, Nd4jLong const* dZShapeInfo) { - - auto stream = lc->getCudaStream(); - auto reductionPointer = lc->getReductionPointer(); - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - if (zType != xType) - throw datatype_exception::build("NativeOpExecutioner::execReduceSameScalar requires both X & Z operands to have same type", xType, zType); - - auto xLength = shape::length(hXShapeInfo); - auto blockWidth = 256; - auto numBlocks = CudaLaunchHelper::getReductionBlocks(xLength, blockWidth); - dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, blockWidth, 32768); - - BUILD_SINGLE_SELECTOR(xType, functions::reduce::ReduceSameFunction, ::execReduceScalar(launchDims, stream, opNum, dX, dXShapeInfo, hXShapeInfo, extraParams, dZ, dZShapeInfo, hZShapeInfo, nullptr, 0, reductionPointer, nullptr), LIBND4J_TYPES); - - // TODO: remove after the release - auto res = cudaStreamSynchronize(*stream); - if (res != 0) - throw cuda_exception::build("execReduceSameScalar failed", res); +void NativeOpExecutioner::execReduceSameScalar( + sd::LaunchContext* lc, int opNum, void const* hX, + Nd4jLong const* hXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo, + void* extraParams, void* hZ, Nd4jLong const* hZShapeInfo, void* dZ, + Nd4jLong const* dZShapeInfo) { + auto stream = lc->getCudaStream(); + auto reductionPointer = lc->getReductionPointer(); + + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + if (zType != xType) + throw datatype_exception::build( + "NativeOpExecutioner::execReduceSameScalar requires both X & Z " + "operands to have same type", + xType, zType); + + auto xLength = shape::length(hXShapeInfo); + auto blockWidth = 256; + auto numBlocks = CudaLaunchHelper::getReductionBlocks(xLength, blockWidth); + dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, blockWidth, 32768); + + BUILD_SINGLE_SELECTOR( + xType, functions::reduce::ReduceSameFunction, + ::execReduceScalar(launchDims, stream, opNum, dX, dXShapeInfo, + hXShapeInfo, extraParams, dZ, dZShapeInfo, hZShapeInfo, + nullptr, 0, reductionPointer, nullptr), + LIBND4J_TYPES); + + // TODO: remove after the release + auto res = cudaStreamSynchronize(*stream); + if (res != 0) throw cuda_exception::build("execReduceSameScalar failed", res); } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execReduceLongScalar(sd::LaunchContext *lc, - int opNum, - void const* hX, Nd4jLong const* hXShapeInfo, - void const* dX, Nd4jLong const* dXShapeInfo, - void *extraParams, - void *hZ, Nd4jLong const* hZShapeInfo, - void *dZ, Nd4jLong const* dZShapeInfo) { - - auto stream = lc->getCudaStream(); - auto reductionPointer = lc->getReductionPointer(); - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - if (zType != sd::DataType::INT64) - throw datatype_exception::build("NativeOpExecutioner::execReduceLongScalar wrong Z data type", sd::DataType::INT64, zType); - - auto xLength = shape::length(hXShapeInfo); - auto blockWidth = 256; - auto numBlocks = CudaLaunchHelper::getReductionBlocks(xLength, blockWidth); - dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, blockWidth, 32768); - - BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce::ReduceLongFunction, ::execReduceScalar(launchDims, stream, opNum, dX, dXShapeInfo, hXShapeInfo, extraParams, dZ, dZShapeInfo, hZShapeInfo, nullptr, 0, reductionPointer, nullptr), LIBND4J_TYPES, LONG_TYPES); - - // TODO: remove after the release - auto res = cudaStreamSynchronize(*stream); - if (res != 0) - throw cuda_exception::build("execReduceLongScalar failed", res); +void NativeOpExecutioner::execReduceLongScalar( + sd::LaunchContext* lc, int opNum, void const* hX, + Nd4jLong const* hXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo, + void* extraParams, void* hZ, Nd4jLong const* hZShapeInfo, void* dZ, + Nd4jLong const* dZShapeInfo) { + auto stream = lc->getCudaStream(); + auto reductionPointer = lc->getReductionPointer(); + + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + if (zType != sd::DataType::INT64) + throw datatype_exception::build( + "NativeOpExecutioner::execReduceLongScalar wrong Z data type", + sd::DataType::INT64, zType); + + auto xLength = shape::length(hXShapeInfo); + auto blockWidth = 256; + auto numBlocks = CudaLaunchHelper::getReductionBlocks(xLength, blockWidth); + dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, blockWidth, 32768); + + BUILD_DOUBLE_SELECTOR( + xType, zType, functions::reduce::ReduceLongFunction, + ::execReduceScalar(launchDims, stream, opNum, dX, dXShapeInfo, + hXShapeInfo, extraParams, dZ, dZShapeInfo, hZShapeInfo, + nullptr, 0, reductionPointer, nullptr), + LIBND4J_TYPES, LONG_TYPES); + + // TODO: remove after the release + auto res = cudaStreamSynchronize(*stream); + if (res != 0) throw cuda_exception::build("execReduceLongScalar failed", res); } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execTransformSame(sd::LaunchContext *lc, - int opNum, - void const* hX, Nd4jLong const* hXShapeInfo, - void const* dX, Nd4jLong const* dXShapeInfo, - void *hZ, Nd4jLong const* hZShapeInfo, - void *dZ, Nd4jLong const* dZShapeInfo, - void *extraParams, - Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets) { - - auto stream = lc->getCudaStream(); - - auto xRank = shape::rank(hXShapeInfo); - auto zRank = shape::rank(hZShapeInfo); - auto xType = ArrayOptions::dataType(hXShapeInfo); - auto zType = ArrayOptions::dataType(hZShapeInfo); - - if (shape::isEmpty(hXShapeInfo)) { - return; - } - - if (xType != zType) { - throw std::runtime_error("NativeOpExecutioner::execTransformSame requires X & Z to have same type"); - } - - dim3 launchDims(512, 512, 16384); - BUILD_SINGLE_SELECTOR(xType, functions::transform::TransformSame, ::executeTransformShaped(launchDims, stream, opNum, dX, dXShapeInfo, xRank, extraParams, dZ, dZShapeInfo, zRank, nullptr, nullptr, nullptr, nullptr), LIBND4J_TYPES); - - // TODO: remove after the release - auto res = cudaStreamSynchronize(*stream); - if (res != 0) - throw cuda_exception::build("execTransformSame failed", res); +void NativeOpExecutioner::execTransformSame( + sd::LaunchContext* lc, int opNum, void const* hX, + Nd4jLong const* hXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo, + void* hZ, Nd4jLong const* hZShapeInfo, void* dZ, + Nd4jLong const* dZShapeInfo, void* extraParams, + Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets) { + auto stream = lc->getCudaStream(); + + auto xRank = shape::rank(hXShapeInfo); + auto zRank = shape::rank(hZShapeInfo); + auto xType = ArrayOptions::dataType(hXShapeInfo); + auto zType = ArrayOptions::dataType(hZShapeInfo); + + if (shape::isEmpty(hXShapeInfo)) { + return; + } + + if (xType != zType) { + throw std::runtime_error( + "NativeOpExecutioner::execTransformSame requires X & Z to have same " + "type"); + } + + dim3 launchDims(512, 512, 16384); + BUILD_SINGLE_SELECTOR( + xType, functions::transform::TransformSame, + ::executeTransformShaped(launchDims, stream, opNum, dX, dXShapeInfo, + xRank, extraParams, dZ, dZShapeInfo, zRank, + nullptr, nullptr, nullptr, nullptr), + LIBND4J_TYPES); + + // TODO: remove after the release + auto res = cudaStreamSynchronize(*stream); + if (res != 0) throw cuda_exception::build("execTransformSame failed", res); } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execTransformBool(sd::LaunchContext *lc, - int opNum, - void const* hX, Nd4jLong const* hXShapeInfo, - void const* dX, Nd4jLong const* dXShapeInfo, - void *hZ, Nd4jLong const* hZShapeInfo, - void *dZ, Nd4jLong const* dZShapeInfo, - void *extraParams, - Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets) { - - auto stream = lc->getCudaStream(); - - auto xRank = shape::rank(hXShapeInfo); - auto zRank = shape::rank(hZShapeInfo); - auto xType = ArrayOptions::dataType(hXShapeInfo); - auto zType = ArrayOptions::dataType(hZShapeInfo); - - if (shape::isEmpty(hXShapeInfo)) { - return; - } - - if (!DataTypeUtils::isB(zType)) { - throw std::runtime_error("NativeOpExecutioner::execTransformBool requires Z to have same boolean type"); - } - - dim3 launchDims(512, 512, 16384); - BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformBool, ::executeTransformShaped(launchDims, stream, opNum, dX, dXShapeInfo, xRank, extraParams, dZ, dZShapeInfo, zRank, nullptr, nullptr, nullptr, nullptr), LIBND4J_TYPES, BOOL_TYPES); - - // TODO: remove after the release - auto res = cudaStreamSynchronize(*stream); - if (res != 0) - throw cuda_exception::build("execTransformBool failed", res); +void NativeOpExecutioner::execTransformBool( + sd::LaunchContext* lc, int opNum, void const* hX, + Nd4jLong const* hXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo, + void* hZ, Nd4jLong const* hZShapeInfo, void* dZ, + Nd4jLong const* dZShapeInfo, void* extraParams, + Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets) { + auto stream = lc->getCudaStream(); + + auto xRank = shape::rank(hXShapeInfo); + auto zRank = shape::rank(hZShapeInfo); + auto xType = ArrayOptions::dataType(hXShapeInfo); + auto zType = ArrayOptions::dataType(hZShapeInfo); + + if (shape::isEmpty(hXShapeInfo)) { + return; + } + + if (!DataTypeUtils::isB(zType)) { + throw std::runtime_error( + "NativeOpExecutioner::execTransformBool requires Z to have same " + "boolean type"); + } + + dim3 launchDims(512, 512, 16384); + BUILD_DOUBLE_SELECTOR( + xType, zType, functions::transform::TransformBool, + ::executeTransformShaped(launchDims, stream, opNum, dX, dXShapeInfo, + xRank, extraParams, dZ, dZShapeInfo, zRank, + nullptr, nullptr, nullptr, nullptr), + LIBND4J_TYPES, BOOL_TYPES); + + // TODO: remove after the release + auto res = cudaStreamSynchronize(*stream); + if (res != 0) throw cuda_exception::build("execTransformBool failed", res); } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execTransformAny(sd::LaunchContext *lc, - int opNum, - void const* hX, Nd4jLong const* hXShapeInfo, - void const* dX, Nd4jLong const* dXShapeInfo, - void *hZ, Nd4jLong const* hZShapeInfo, - void *dZ, Nd4jLong const* dZShapeInfo, - void *extraParams, - Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, bool allowParallelism) { - - auto stream = lc->getCudaStream(); - - auto xRank = shape::rank(hXShapeInfo); - auto zRank = shape::rank(hZShapeInfo); - auto xType = ArrayOptions::dataType(hXShapeInfo); - auto zType = ArrayOptions::dataType(hZShapeInfo); - - if (shape::isEmpty(hXShapeInfo)) - return; - - if (opNum == sd::transform::Assign && shape::order(hXShapeInfo) == shape::order(hZShapeInfo) && shape::order(hXShapeInfo) == 'c' && xType == zType && shape::elementWiseStride(hXShapeInfo) == 1 && shape::elementWiseStride(hZShapeInfo) == 1) { - cudaMemcpyAsync(dZ, dX, shape::length(hXShapeInfo) * sd::DataTypeUtils::sizeOfElement(xType), cudaMemcpyDeviceToDevice, *stream); - } - else { - - dim3 launchDims(512, 512, 2048); - BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformAny, ::executeTransformShaped(launchDims, stream, opNum, dX, dXShapeInfo, xRank, extraParams, dZ, dZShapeInfo, zRank, nullptr, nullptr, nullptr, nullptr), LIBND4J_TYPES, LIBND4J_TYPES); - } - - // TODO: remove after the release - auto res = cudaStreamSynchronize(*stream); - if (res != 0) - throw cuda_exception::build("execTransformAny failed", res); +void NativeOpExecutioner::execTransformAny( + sd::LaunchContext* lc, int opNum, void const* hX, + Nd4jLong const* hXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo, + void* hZ, Nd4jLong const* hZShapeInfo, void* dZ, + Nd4jLong const* dZShapeInfo, void* extraParams, + Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, + bool allowParallelism) { + auto stream = lc->getCudaStream(); + + auto xRank = shape::rank(hXShapeInfo); + auto zRank = shape::rank(hZShapeInfo); + auto xType = ArrayOptions::dataType(hXShapeInfo); + auto zType = ArrayOptions::dataType(hZShapeInfo); + + if (shape::isEmpty(hXShapeInfo)) return; + + if (opNum == sd::transform::Assign && + shape::order(hXShapeInfo) == shape::order(hZShapeInfo) && + shape::order(hXShapeInfo) == 'c' && xType == zType && + shape::elementWiseStride(hXShapeInfo) == 1 && + shape::elementWiseStride(hZShapeInfo) == 1) { + cudaMemcpyAsync( + dZ, dX, + shape::length(hXShapeInfo) * sd::DataTypeUtils::sizeOfElement(xType), + cudaMemcpyDeviceToDevice, *stream); + } else { + dim3 launchDims(512, 512, 2048); + BUILD_DOUBLE_SELECTOR( + xType, zType, functions::transform::TransformAny, + ::executeTransformShaped(launchDims, stream, opNum, dX, dXShapeInfo, + xRank, extraParams, dZ, dZShapeInfo, zRank, + nullptr, nullptr, nullptr, nullptr), + LIBND4J_TYPES, LIBND4J_TYPES); + } + + // TODO: remove after the release + auto res = cudaStreamSynchronize(*stream); + if (res != 0) throw cuda_exception::build("execTransformAny failed", res); } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execTransformStrict(sd::LaunchContext *lc, - int opNum, - void const* hX, Nd4jLong const* hXShapeInfo, - void const* dX, Nd4jLong const* dXShapeInfo, - void *hZ, Nd4jLong const* hZShapeInfo, - void *dZ, Nd4jLong const* dZShapeInfo, - void *extraParams, - Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets) { - - auto stream = lc->getCudaStream(); - - auto xRank = shape::rank(hXShapeInfo); - auto zRank = shape::rank(hZShapeInfo); - auto xType = ArrayOptions::dataType(hXShapeInfo); - auto zType = ArrayOptions::dataType(hZShapeInfo); - - if (shape::isEmpty(hXShapeInfo)) { - return; - } - - if (xType != zType || !DataTypeUtils::isR(xType)) { - throw datatype_exception::build("NativeOpExecutioner::execTransformStrict requires X & Z to have same floating point type", xType, zType); - } - - dim3 launchDims(512, 512, 16384); - BUILD_SINGLE_SELECTOR(xType, functions::transform::TransformStrict, ::executeTransformShaped(launchDims, stream, opNum, dX, dXShapeInfo, xRank, extraParams, dZ, dZShapeInfo, zRank, nullptr, nullptr, nullptr, nullptr), FLOAT_TYPES); - - // TODO: remove after the release - auto res = cudaStreamSynchronize(*stream); - if (res != 0) - throw cuda_exception::build("execTransformStrict failed", res); +void NativeOpExecutioner::execTransformStrict( + sd::LaunchContext* lc, int opNum, void const* hX, + Nd4jLong const* hXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo, + void* hZ, Nd4jLong const* hZShapeInfo, void* dZ, + Nd4jLong const* dZShapeInfo, void* extraParams, + Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets) { + auto stream = lc->getCudaStream(); + + auto xRank = shape::rank(hXShapeInfo); + auto zRank = shape::rank(hZShapeInfo); + auto xType = ArrayOptions::dataType(hXShapeInfo); + auto zType = ArrayOptions::dataType(hZShapeInfo); + + if (shape::isEmpty(hXShapeInfo)) { + return; + } + + if (xType != zType || !DataTypeUtils::isR(xType)) { + throw datatype_exception::build( + "NativeOpExecutioner::execTransformStrict requires X & Z to have same " + "floating point type", + xType, zType); + } + + dim3 launchDims(512, 512, 16384); + BUILD_SINGLE_SELECTOR( + xType, functions::transform::TransformStrict, + ::executeTransformShaped(launchDims, stream, opNum, dX, dXShapeInfo, + xRank, extraParams, dZ, dZShapeInfo, zRank, + nullptr, nullptr, nullptr, nullptr), + FLOAT_TYPES); + + // TODO: remove after the release + auto res = cudaStreamSynchronize(*stream); + if (res != 0) throw cuda_exception::build("execTransformStrict failed", res); } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execTransformFloat(sd::LaunchContext *lc, - int opNum, - void const* hX, Nd4jLong const* hXShapeInfo, - void const* dX, Nd4jLong const* dXShapeInfo, - void *hZ, Nd4jLong const* hZShapeInfo, - void *dZ, Nd4jLong const* dZShapeInfo, - void *extraParams, - Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets) { - - auto stream = lc->getCudaStream(); - auto reductionPointer = lc->getReductionPointer(); - - auto xRank = shape::rank(hXShapeInfo); - auto zRank = shape::rank(hZShapeInfo); - auto xType = ArrayOptions::dataType(hXShapeInfo); - auto zType = ArrayOptions::dataType(hZShapeInfo); - - if (shape::isEmpty(hXShapeInfo)) - return; - - if (!DataTypeUtils::isR(zType)) - throw datatype_exception::build("NativeOpExecutioner::execTransformFloat requires Z to have floating point type", zType); - - dim3 launchDims(512, 512, 2048); - BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformFloat, ::executeTransformShaped(launchDims, stream, opNum, dX, dXShapeInfo, xRank, extraParams, dZ, dZShapeInfo, zRank, nullptr, nullptr, nullptr, nullptr), LIBND4J_TYPES, FLOAT_TYPES); - - // TODO: remove after the release - auto res = cudaStreamSynchronize(*stream); - if (res != 0) - throw cuda_exception::build("execTransformFloat failed", res); +void NativeOpExecutioner::execTransformFloat( + sd::LaunchContext* lc, int opNum, void const* hX, + Nd4jLong const* hXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo, + void* hZ, Nd4jLong const* hZShapeInfo, void* dZ, + Nd4jLong const* dZShapeInfo, void* extraParams, + Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets) { + auto stream = lc->getCudaStream(); + auto reductionPointer = lc->getReductionPointer(); + + auto xRank = shape::rank(hXShapeInfo); + auto zRank = shape::rank(hZShapeInfo); + auto xType = ArrayOptions::dataType(hXShapeInfo); + auto zType = ArrayOptions::dataType(hZShapeInfo); + + if (shape::isEmpty(hXShapeInfo)) return; + + if (!DataTypeUtils::isR(zType)) + throw datatype_exception::build( + "NativeOpExecutioner::execTransformFloat requires Z to have floating " + "point type", + zType); + + dim3 launchDims(512, 512, 2048); + BUILD_DOUBLE_SELECTOR( + xType, zType, functions::transform::TransformFloat, + ::executeTransformShaped(launchDims, stream, opNum, dX, dXShapeInfo, + xRank, extraParams, dZ, dZShapeInfo, zRank, + nullptr, nullptr, nullptr, nullptr), + LIBND4J_TYPES, FLOAT_TYPES); + + // TODO: remove after the release + auto res = cudaStreamSynchronize(*stream); + if (res != 0) throw cuda_exception::build("execTransformFloat failed", res); } - //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execSummaryStats(sd::LaunchContext *lc, - int opNum, - void const* hX, Nd4jLong const* hXShapeInfo, - void const* dX, Nd4jLong const* dXShapeInfo, - void *extraParams, - void *hZ, Nd4jLong const* hZShapeInfo, - void *dZ, Nd4jLong const* dZShapeInfo, - bool biasCorrected) { - - auto stream = lc->getCudaStream(); - auto reductionPointer = lc->getReductionPointer(); - - dim3 launchDims = dim3(256, 256, 32768); - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - if (!DataTypeUtils::isR(zType)) - throw sd::datatype_exception::build("NativeOpExecutioner::execSummaryStats requires Z operand to have floating point data type", zType); - - BUILD_DOUBLE_SELECTOR(xType, zType, functions::summarystats::SummaryStatsReduce, ::execSummaryStatsReduce(launchDims, stream, opNum, dX, dXShapeInfo, hXShapeInfo, extraParams, dZ, dZShapeInfo, hZShapeInfo, nullptr, nullptr, biasCorrected, reductionPointer), LIBND4J_TYPES, FLOAT_TYPES); - - // TODO: remove after the release - auto res = cudaStreamSynchronize(*stream); - if (res != 0) - throw cuda_exception::build("execSummaryStats A failed", res); +void NativeOpExecutioner::execSummaryStats( + sd::LaunchContext* lc, int opNum, void const* hX, + Nd4jLong const* hXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo, + void* extraParams, void* hZ, Nd4jLong const* hZShapeInfo, void* dZ, + Nd4jLong const* dZShapeInfo, bool biasCorrected) { + auto stream = lc->getCudaStream(); + auto reductionPointer = lc->getReductionPointer(); + + dim3 launchDims = dim3(256, 256, 32768); + + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + if (!DataTypeUtils::isR(zType)) + throw sd::datatype_exception::build( + "NativeOpExecutioner::execSummaryStats requires Z operand to have " + "floating point data type", + zType); + + BUILD_DOUBLE_SELECTOR( + xType, zType, functions::summarystats::SummaryStatsReduce, + ::execSummaryStatsReduce(launchDims, stream, opNum, dX, dXShapeInfo, + hXShapeInfo, extraParams, dZ, dZShapeInfo, + hZShapeInfo, nullptr, nullptr, biasCorrected, + reductionPointer), + LIBND4J_TYPES, FLOAT_TYPES); + + // TODO: remove after the release + auto res = cudaStreamSynchronize(*stream); + if (res != 0) throw cuda_exception::build("execSummaryStats A failed", res); } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execSummaryStats(sd::LaunchContext *lc, - int opNum, - void const* hX, Nd4jLong const* hXShapeInfo, - void const* dX, Nd4jLong const* dXShapeInfo, - void *extraParams, - void *hZ, Nd4jLong const* hZShapeInfo, - void *dZ, Nd4jLong const* dZShapeInfo, - int *dimension, int dimensionLength, - Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, - bool biasCorrected) { - auto stream = lc->getCudaStream(); - auto reductionPointer = lc->getReductionPointer(); - - dim3 launchDims = dim3(256, 256, 32768); - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - if (!DataTypeUtils::isR(zType)) - throw sd::datatype_exception::build("NativeOpExecutioner::execSummaryStats requires Z operand to have floating point data type", zType); - - BUILD_DOUBLE_SELECTOR(xType, zType, functions::summarystats::SummaryStatsReduce, ::execSummaryStatsReduce(launchDims, stream, opNum, dX, dXShapeInfo, hXShapeInfo, extraParams, dZ, dZShapeInfo, hZShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets, biasCorrected, reductionPointer), LIBND4J_TYPES, FLOAT_TYPES); - - // TODO: remove after the release - auto res = cudaStreamSynchronize(*stream); - if (res != 0) - throw cuda_exception::build("execSummaryStats B failed", res); +void NativeOpExecutioner::execSummaryStats( + sd::LaunchContext* lc, int opNum, void const* hX, + Nd4jLong const* hXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo, + void* extraParams, void* hZ, Nd4jLong const* hZShapeInfo, void* dZ, + Nd4jLong const* dZShapeInfo, int* dimension, int dimensionLength, + Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, + bool biasCorrected) { + auto stream = lc->getCudaStream(); + auto reductionPointer = lc->getReductionPointer(); + + dim3 launchDims = dim3(256, 256, 32768); + + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + if (!DataTypeUtils::isR(zType)) + throw sd::datatype_exception::build( + "NativeOpExecutioner::execSummaryStats requires Z operand to have " + "floating point data type", + zType); + + BUILD_DOUBLE_SELECTOR( + xType, zType, functions::summarystats::SummaryStatsReduce, + ::execSummaryStatsReduce( + launchDims, stream, opNum, dX, dXShapeInfo, hXShapeInfo, extraParams, + dZ, dZShapeInfo, hZShapeInfo, dimension, dimensionLength, + tadShapeInfo, tadOffsets, biasCorrected, reductionPointer), + LIBND4J_TYPES, FLOAT_TYPES); + + // TODO: remove after the release + auto res = cudaStreamSynchronize(*stream); + if (res != 0) throw cuda_exception::build("execSummaryStats B failed", res); } - //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execReduce3(sd::LaunchContext *lc, - int opNum, - void const* hX, Nd4jLong const* hXShapeInfo, - void const* dX, Nd4jLong const* dXShapeInfo, - void *extraParams, - void const* hY, Nd4jLong const* hYShapeInfo, - void const* dY, Nd4jLong const* dYShapeInfo, - void *hZ, Nd4jLong const* hZShapeInfo, - void *dZ, Nd4jLong const* dZShapeInfo) { - - auto stream = lc->getCudaStream(); - auto reductionPointer = lc->getReductionPointer(); - auto allocationPointer = lc->getAllocationPointer(); - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto yType = sd::ArrayOptions::dataType(hYShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - auto blockWidth = 256; - auto numBlocks = CudaLaunchHelper::getReductionBlocks(shape::length(hXShapeInfo), blockWidth); - dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, blockWidth, 32768); - - if (xType != yType) - throw sd::datatype_exception::build("NativeOpExecutioner::execReduce3 requires Y operand to have X type", xType, yType); - - if (!DataTypeUtils::isR(zType)) - throw sd::datatype_exception::build("NativeOpExecutioner::execReduce3 requires Z operand to have floating point data type", zType); - - BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce3::Reduce3, ::execScalar(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, extraParams, dZ, dZShapeInfo, allocationPointer, reductionPointer, nullptr), LIBND4J_TYPES, FLOAT_TYPES); - - // TODO: remove after the release - auto res = cudaStreamSynchronize(*stream); - if (res != 0) - throw cuda_exception::build("execReduce3 failed", res); +void NativeOpExecutioner::execReduce3( + sd::LaunchContext* lc, int opNum, void const* hX, + Nd4jLong const* hXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo, + void* extraParams, void const* hY, Nd4jLong const* hYShapeInfo, + void const* dY, Nd4jLong const* dYShapeInfo, void* hZ, + Nd4jLong const* hZShapeInfo, void* dZ, Nd4jLong const* dZShapeInfo) { + auto stream = lc->getCudaStream(); + auto reductionPointer = lc->getReductionPointer(); + auto allocationPointer = lc->getAllocationPointer(); + + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto yType = sd::ArrayOptions::dataType(hYShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + auto blockWidth = 256; + auto numBlocks = CudaLaunchHelper::getReductionBlocks( + shape::length(hXShapeInfo), blockWidth); + dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, blockWidth, 32768); + + if (xType != yType) + throw sd::datatype_exception::build( + "NativeOpExecutioner::execReduce3 requires Y operand to have X type", + xType, yType); + + if (!DataTypeUtils::isR(zType)) + throw sd::datatype_exception::build( + "NativeOpExecutioner::execReduce3 requires Z operand to have floating " + "point data type", + zType); + + BUILD_DOUBLE_SELECTOR( + xType, zType, functions::reduce3::Reduce3, + ::execScalar(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, + extraParams, dZ, dZShapeInfo, allocationPointer, + reductionPointer, nullptr), + LIBND4J_TYPES, FLOAT_TYPES); + + // TODO: remove after the release + auto res = cudaStreamSynchronize(*stream); + if (res != 0) throw cuda_exception::build("execReduce3 failed", res); } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execReduce3(sd::LaunchContext *lc, - int opNum, - void const* hX, Nd4jLong const* hXShapeInfo, - void const* dX, Nd4jLong const* dXShapeInfo, - void *extraParams, - void const* hY, Nd4jLong const* hYShapeInfo, - void const* dY, Nd4jLong const* dYShapeInfo, - void *hZ, Nd4jLong const* hZShapeInfo, - void *dZ, Nd4jLong const* dZShapeInfo, - int *dimension, int dimensionLength, - Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, - Nd4jLong const* yTadOnlyShapeInfo, Nd4jLong const* yTadOffsets) { - - if(shape::isScalar(hZShapeInfo)) { - NativeOpExecutioner::execReduce3(lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParams, hY, hYShapeInfo, dY, dYShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo); - return; - } - - auto stream = lc->getCudaStream(); - auto allocationPointer = lc->getAllocationPointer(); - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto yType = sd::ArrayOptions::dataType(hYShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - if (xType != yType) - throw sd::datatype_exception::build("NativeOpExecutioner::execReduce3 requires Y operand to have X type", xType, yType); - - if (!DataTypeUtils::isR(zType)) - throw sd::datatype_exception::build("NativeOpExecutioner::execReduce3 requires Z operand to have floating point data type", zType); - - - auto numBlocks = shape::length(hZShapeInfo); - dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, 256, 32768); - - BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce3::Reduce3, ::exec(launchDims, stream, opNum, - dX, dXShapeInfo, - dY, dYShapeInfo, - extraParams, - dZ, dZShapeInfo, - dimension, dimensionLength, - 1, - allocationPointer, - tadOnlyShapeInfo, tadOffsets, - yTadOnlyShapeInfo, yTadOffsets), LIBND4J_TYPES, FLOAT_TYPES); - - // TODO: remove after the release - auto res = cudaStreamSynchronize(*stream); - if (res != 0) - throw cuda_exception::build("execReduce3 B failed", res); +void NativeOpExecutioner::execReduce3( + sd::LaunchContext* lc, int opNum, void const* hX, + Nd4jLong const* hXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo, + void* extraParams, void const* hY, Nd4jLong const* hYShapeInfo, + void const* dY, Nd4jLong const* dYShapeInfo, void* hZ, + Nd4jLong const* hZShapeInfo, void* dZ, Nd4jLong const* dZShapeInfo, + int* dimension, int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, + Nd4jLong const* tadOffsets, Nd4jLong const* yTadOnlyShapeInfo, + Nd4jLong const* yTadOffsets) { + if (shape::isScalar(hZShapeInfo)) { + NativeOpExecutioner::execReduce3( + lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParams, hY, + hYShapeInfo, dY, dYShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo); + return; + } + + auto stream = lc->getCudaStream(); + auto allocationPointer = lc->getAllocationPointer(); + + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto yType = sd::ArrayOptions::dataType(hYShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + if (xType != yType) + throw sd::datatype_exception::build( + "NativeOpExecutioner::execReduce3 requires Y operand to have X type", + xType, yType); + + if (!DataTypeUtils::isR(zType)) + throw sd::datatype_exception::build( + "NativeOpExecutioner::execReduce3 requires Z operand to have floating " + "point data type", + zType); + + auto numBlocks = shape::length(hZShapeInfo); + dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, 256, 32768); + + BUILD_DOUBLE_SELECTOR( + xType, zType, functions::reduce3::Reduce3, + ::exec(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, + extraParams, dZ, dZShapeInfo, dimension, dimensionLength, 1, + allocationPointer, tadOnlyShapeInfo, tadOffsets, yTadOnlyShapeInfo, + yTadOffsets), + LIBND4J_TYPES, FLOAT_TYPES); + + // TODO: remove after the release + auto res = cudaStreamSynchronize(*stream); + if (res != 0) throw cuda_exception::build("execReduce3 B failed", res); } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execReduce3Scalar(sd::LaunchContext *lc, - int opNum, - void const* hX, Nd4jLong const* hXShapeInfo, - void const* dX, Nd4jLong const* dXShapeInfo, - void *extraParams, - void const* hY, Nd4jLong const* hYShapeInfo, - void const* dY, Nd4jLong const* dYShapeInfo, - void *hZ, Nd4jLong const* hZShapeInfo, - void *dZ, Nd4jLong const* dZShapeInfo) { - - - auto stream = lc->getCudaStream(); - auto allocationPointer = lc->getAllocationPointer(); - auto reductionPointer = lc->getReductionPointer(); - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto yType = sd::ArrayOptions::dataType(hYShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - auto xLength = shape::length(hXShapeInfo); - auto blockWidth = 256; - auto numBlocks = CudaLaunchHelper::getReductionBlocks(xLength, blockWidth); - dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, blockWidth, 32768); - - if (xType != yType) - throw sd::datatype_exception::build("NativeOpExecutioner::execReduce3Scalar requires Y operand to have X type", xType, yType); - - if (!DataTypeUtils::isR(zType)) - throw sd::datatype_exception::build("NativeOpExecutioner::execReduce3Scalar requires Z operand to have floating point data type", zType); - - BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce3::Reduce3, ::execScalar(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, extraParams, dZ, dZShapeInfo, allocationPointer, reductionPointer, nullptr), LIBND4J_TYPES, FLOAT_TYPES); - - // TODO: remove after the release - auto res = cudaStreamSynchronize(*stream); - if (res != 0) - throw cuda_exception::build("execReduce3Scalar failed", res); +void NativeOpExecutioner::execReduce3Scalar( + sd::LaunchContext* lc, int opNum, void const* hX, + Nd4jLong const* hXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo, + void* extraParams, void const* hY, Nd4jLong const* hYShapeInfo, + void const* dY, Nd4jLong const* dYShapeInfo, void* hZ, + Nd4jLong const* hZShapeInfo, void* dZ, Nd4jLong const* dZShapeInfo) { + auto stream = lc->getCudaStream(); + auto allocationPointer = lc->getAllocationPointer(); + auto reductionPointer = lc->getReductionPointer(); + + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto yType = sd::ArrayOptions::dataType(hYShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + auto xLength = shape::length(hXShapeInfo); + auto blockWidth = 256; + auto numBlocks = CudaLaunchHelper::getReductionBlocks(xLength, blockWidth); + dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, blockWidth, 32768); + + if (xType != yType) + throw sd::datatype_exception::build( + "NativeOpExecutioner::execReduce3Scalar requires Y operand to have X " + "type", + xType, yType); + + if (!DataTypeUtils::isR(zType)) + throw sd::datatype_exception::build( + "NativeOpExecutioner::execReduce3Scalar requires Z operand to have " + "floating point data type", + zType); + + BUILD_DOUBLE_SELECTOR( + xType, zType, functions::reduce3::Reduce3, + ::execScalar(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, + extraParams, dZ, dZShapeInfo, allocationPointer, + reductionPointer, nullptr), + LIBND4J_TYPES, FLOAT_TYPES); + + // TODO: remove after the release + auto res = cudaStreamSynchronize(*stream); + if (res != 0) throw cuda_exception::build("execReduce3Scalar failed", res); } - //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execScalarBool(sd::LaunchContext *lc, - int opNum, - void const* hX, Nd4jLong const* hXShapeInfo, - void const* dX, Nd4jLong const* dXShapeInfo, - void *hZ, Nd4jLong const* hZShapeInfo, - void *dZ, Nd4jLong const* dZShapeInfo, - void const* hScalar, Nd4jLong const* hScalarShapeInfo, - void const* dScalar, Nd4jLong const* dScalarShapeInfo, - void *extraParams, bool allowParallelism) { - - auto stream = lc->getCudaStream(); - - dim3 launchDims = dim3(256, 512, 8192); - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto yType = sd::ArrayOptions::dataType(hScalarShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hScalarShapeInfo)) - return; - - if (xType != yType ) - throw std::runtime_error("NativeOpExecutioner::execScalarBool requires X & Y to have same type"); - - if (!DataTypeUtils::isB(zType) ) - throw std::runtime_error("NativeOpExecutioner::execScalarBool requires Z operand to have BOOL type"); - - BUILD_DOUBLE_SELECTOR(xType, zType, functions::scalar::ScalarBoolTransform, ::executeCudaShaped(launchDims, stream, opNum, dX, dXShapeInfo, dZ, dZShapeInfo, dScalar, extraParams), LIBND4J_TYPES, BOOL_TYPES); - - // TODO: remove after the release - auto res = cudaStreamSynchronize(*stream); - if (res != 0) - throw cuda_exception::build("execScalarBool failed", res); +void NativeOpExecutioner::execScalarBool( + sd::LaunchContext* lc, int opNum, void const* hX, + Nd4jLong const* hXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo, + void* hZ, Nd4jLong const* hZShapeInfo, void* dZ, + Nd4jLong const* dZShapeInfo, void const* hScalar, + Nd4jLong const* hScalarShapeInfo, void const* dScalar, + Nd4jLong const* dScalarShapeInfo, void* extraParams, + bool allowParallelism) { + auto stream = lc->getCudaStream(); + + dim3 launchDims = dim3(256, 512, 8192); + + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto yType = sd::ArrayOptions::dataType(hScalarShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hScalarShapeInfo)) return; + + if (xType != yType) + throw std::runtime_error( + "NativeOpExecutioner::execScalarBool requires X & Y to have same type"); + + if (!DataTypeUtils::isB(zType)) + throw std::runtime_error( + "NativeOpExecutioner::execScalarBool requires Z operand to have BOOL " + "type"); + + BUILD_DOUBLE_SELECTOR( + xType, zType, functions::scalar::ScalarBoolTransform, + ::executeCudaShaped(launchDims, stream, opNum, dX, dXShapeInfo, dZ, + dZShapeInfo, dScalar, extraParams), + LIBND4J_TYPES, BOOL_TYPES); + + // TODO: remove after the release + auto res = cudaStreamSynchronize(*stream); + if (res != 0) throw cuda_exception::build("execScalarBool failed", res); } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execScalarBool(sd::LaunchContext *lc, - int opNum, - void const* hX, Nd4jLong const* hXShapeInfo, - void const* dX, Nd4jLong const* dXShapeInfo, - void *extraParams, - void *hZ, Nd4jLong const* hZShapeInfo, - void *dZ, Nd4jLong const* dZShapeInfo, - void const* hScalars, Nd4jLong const* hScalarShapeInfo, - void const* dScalars, Nd4jLong const* dScalarShapeInfo, - int *dimension, int dimensionLength, - Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, - Nd4jLong const* tadShapeInfoZ, Nd4jLong const* tadOffsetsZ) { - - auto stream = lc->getCudaStream(); - - dim3 launchDims(256, 512, 8192); - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto yType = sd::ArrayOptions::dataType(hScalarShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hScalarShapeInfo)) - return; - - if (xType != yType ) - throw std::runtime_error("NativeOpExecutioner::execScalarBool requires X & Y to have same type"); - - if (!DataTypeUtils::isB(zType) ) - throw std::runtime_error("NativeOpExecutioner::execScalarBool requires Z operand to have BOOL type"); - - BUILD_DOUBLE_SELECTOR(xType, zType, functions::scalar::ScalarBoolTransform, ::executeCudaAlongDimension(launchDims, stream, opNum, dX, dXShapeInfo, dZ, dZShapeInfo, dScalars, extraParams, dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES, BOOL_TYPES); - - // TODO: remove after the release - auto res = cudaStreamSynchronize(*stream); - if (res != 0) - throw cuda_exception::build("execScalarBool B failed", res); +void NativeOpExecutioner::execScalarBool( + sd::LaunchContext* lc, int opNum, void const* hX, + Nd4jLong const* hXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo, + void* extraParams, void* hZ, Nd4jLong const* hZShapeInfo, void* dZ, + Nd4jLong const* dZShapeInfo, void const* hScalars, + Nd4jLong const* hScalarShapeInfo, void const* dScalars, + Nd4jLong const* dScalarShapeInfo, int* dimension, int dimensionLength, + Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, + Nd4jLong const* tadShapeInfoZ, Nd4jLong const* tadOffsetsZ) { + auto stream = lc->getCudaStream(); + + dim3 launchDims(256, 512, 8192); + + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto yType = sd::ArrayOptions::dataType(hScalarShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hScalarShapeInfo)) return; + + if (xType != yType) + throw std::runtime_error( + "NativeOpExecutioner::execScalarBool requires X & Y to have same type"); + + if (!DataTypeUtils::isB(zType)) + throw std::runtime_error( + "NativeOpExecutioner::execScalarBool requires Z operand to have BOOL " + "type"); + + BUILD_DOUBLE_SELECTOR( + xType, zType, functions::scalar::ScalarBoolTransform, + ::executeCudaAlongDimension(launchDims, stream, opNum, dX, dXShapeInfo, + dZ, dZShapeInfo, dScalars, extraParams, + dimension, dimensionLength, tadShapeInfo, + tadOffsets, tadShapeInfoZ, tadOffsetsZ), + LIBND4J_TYPES, BOOL_TYPES); + + // TODO: remove after the release + auto res = cudaStreamSynchronize(*stream); + if (res != 0) throw cuda_exception::build("execScalarBool B failed", res); } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execScalarInt(sd::LaunchContext *lc, - int opNum, - void const* hX, Nd4jLong const* hXShapeInfo, - void const* dX, Nd4jLong const* dXShapeInfo, - void *hZ, Nd4jLong const* hZShapeInfo, - void *dZ, Nd4jLong const* dZShapeInfo, - void const* hScalar, Nd4jLong const* hScalarShapeInfo, - void const* dScalar, Nd4jLong const* dScalarShapeInfo, - void *extraParams, bool allowParallelism) { - - auto stream = lc->getCudaStream(); - - dim3 launchDims = dim3(256, 512, 8192); - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto yType = sd::ArrayOptions::dataType(hScalarShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hScalarShapeInfo)) - return; - - if (xType != yType || zType != xType) - throw std::runtime_error("NativeOpExecutioner::execScalarInt requires X & Y to have same type"); - - if (!DataTypeUtils::isZ(zType) ) - throw std::runtime_error("NativeOpExecutioner::execScalarInt requires Z operand to have INT type"); - - BUILD_SINGLE_SELECTOR(xType, functions::scalar::ScalarIntTransform, ::executeCudaShaped(launchDims, stream, opNum, dX, dXShapeInfo, dZ, dZShapeInfo, dScalar, extraParams), INTEGER_TYPES); - - // TODO: remove after the release - auto res = cudaStreamSynchronize(*stream); - if (res != 0) - throw cuda_exception::build("execScalarInt failed", res); +void NativeOpExecutioner::execScalarInt( + sd::LaunchContext* lc, int opNum, void const* hX, + Nd4jLong const* hXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo, + void* hZ, Nd4jLong const* hZShapeInfo, void* dZ, + Nd4jLong const* dZShapeInfo, void const* hScalar, + Nd4jLong const* hScalarShapeInfo, void const* dScalar, + Nd4jLong const* dScalarShapeInfo, void* extraParams, + bool allowParallelism) { + auto stream = lc->getCudaStream(); + + dim3 launchDims = dim3(256, 512, 8192); + + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto yType = sd::ArrayOptions::dataType(hScalarShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hScalarShapeInfo)) return; + + if (xType != yType || zType != xType) + throw std::runtime_error( + "NativeOpExecutioner::execScalarInt requires X & Y to have same type"); + + if (!DataTypeUtils::isZ(zType)) + throw std::runtime_error( + "NativeOpExecutioner::execScalarInt requires Z operand to have INT " + "type"); + + BUILD_SINGLE_SELECTOR( + xType, functions::scalar::ScalarIntTransform, + ::executeCudaShaped(launchDims, stream, opNum, dX, dXShapeInfo, dZ, + dZShapeInfo, dScalar, extraParams), + INTEGER_TYPES); + + // TODO: remove after the release + auto res = cudaStreamSynchronize(*stream); + if (res != 0) throw cuda_exception::build("execScalarInt failed", res); } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execScalarInt(sd::LaunchContext *lc, - int opNum, - void const* hX, Nd4jLong const* hXShapeInfo, - void const* dX, Nd4jLong const* dXShapeInfo, - void *extraParams, - void *hZ, Nd4jLong const* hZShapeInfo, - void *dZ, Nd4jLong const* dZShapeInfo, - void const* hScalars, Nd4jLong const* hScalarShapeInfo, - void const* dScalars, Nd4jLong const* dScalarShapeInfo, - int *dimension, int dimensionLength, - Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, - Nd4jLong const* tadShapeInfoZ, Nd4jLong const* tadOffsetsZ) { - - auto stream = lc->getCudaStream(); - - dim3 launchDims(256, 512, 8192); - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto yType = sd::ArrayOptions::dataType(hScalarShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hScalarShapeInfo)) - return; - - if (xType != yType || zType != xType) - throw std::runtime_error("NativeOpExecutioner::execScalarInt requires X & Y to have same type"); - - if (!DataTypeUtils::isZ(zType) ) - throw std::runtime_error("NativeOpExecutioner::execScalarInt requires Z operand to have INT type"); - - BUILD_SINGLE_SELECTOR(xType, functions::scalar::ScalarIntTransform, ::executeCudaAlongDimension(launchDims, stream, opNum, dX, dXShapeInfo, dZ, dZShapeInfo, dScalars, extraParams, dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ), INTEGER_TYPES); - - // TODO: remove after the release - auto res = cudaStreamSynchronize(*stream); - if (res != 0) - throw cuda_exception::build("execScalarInt B failed", res); +void NativeOpExecutioner::execScalarInt( + sd::LaunchContext* lc, int opNum, void const* hX, + Nd4jLong const* hXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo, + void* extraParams, void* hZ, Nd4jLong const* hZShapeInfo, void* dZ, + Nd4jLong const* dZShapeInfo, void const* hScalars, + Nd4jLong const* hScalarShapeInfo, void const* dScalars, + Nd4jLong const* dScalarShapeInfo, int* dimension, int dimensionLength, + Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, + Nd4jLong const* tadShapeInfoZ, Nd4jLong const* tadOffsetsZ) { + auto stream = lc->getCudaStream(); + + dim3 launchDims(256, 512, 8192); + + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto yType = sd::ArrayOptions::dataType(hScalarShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hScalarShapeInfo)) return; + + if (xType != yType || zType != xType) + throw std::runtime_error( + "NativeOpExecutioner::execScalarInt requires X & Y to have same type"); + + if (!DataTypeUtils::isZ(zType)) + throw std::runtime_error( + "NativeOpExecutioner::execScalarInt requires Z operand to have INT " + "type"); + + BUILD_SINGLE_SELECTOR( + xType, functions::scalar::ScalarIntTransform, + ::executeCudaAlongDimension(launchDims, stream, opNum, dX, dXShapeInfo, + dZ, dZShapeInfo, dScalars, extraParams, + dimension, dimensionLength, tadShapeInfo, + tadOffsets, tadShapeInfoZ, tadOffsetsZ), + INTEGER_TYPES); + + // TODO: remove after the release + auto res = cudaStreamSynchronize(*stream); + if (res != 0) throw cuda_exception::build("execScalarInt B failed", res); } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execScalar(sd::LaunchContext *lc, - int opNum, - void const* hX, Nd4jLong const* hXShapeInfo, - void const* dX, Nd4jLong const* dXShapeInfo, - void* hZ, Nd4jLong const* hZShapeInfo, - void* dZ, Nd4jLong const* dZShapeInfo, - void const* hScalar, Nd4jLong const* hScalarShapeInfo, - void const* dScalar, Nd4jLong const* dScalarShapeInfo, - void *extraParams, bool allowParallelism) { - - auto stream = lc->getCudaStream(); - - dim3 launchDims(256, 512, 8192); +void NativeOpExecutioner::execScalar( + sd::LaunchContext* lc, int opNum, void const* hX, + Nd4jLong const* hXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo, + void* hZ, Nd4jLong const* hZShapeInfo, void* dZ, + Nd4jLong const* dZShapeInfo, void const* hScalar, + Nd4jLong const* hScalarShapeInfo, void const* dScalar, + Nd4jLong const* dScalarShapeInfo, void* extraParams, + bool allowParallelism) { + auto stream = lc->getCudaStream(); - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto yType = sd::ArrayOptions::dataType(hScalarShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + dim3 launchDims(256, 512, 8192); - if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hScalarShapeInfo)) - return; + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto yType = sd::ArrayOptions::dataType(hScalarShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hScalarShapeInfo)) return; #ifdef __ND4J_EXPERIMENTAL__ - BUILD_PAIRWISE_SELECTOR(xType, yType, zType, functions::scalar::ScalarTransform, ::executeCudaShaped(launchDims, stream, opNum, dX, dXShapeInfo, hXShapeInfo, dZ, dZShapeInfo, hZShapeInfo, dScalar, extraParams), LIBND4J_TYPES, LIBND4J_TYPES); + BUILD_PAIRWISE_SELECTOR( + xType, yType, zType, functions::scalar::ScalarTransform, + ::executeCudaShaped(launchDims, stream, opNum, dX, dXShapeInfo, + hXShapeInfo, dZ, dZShapeInfo, hZShapeInfo, dScalar, + extraParams), + LIBND4J_TYPES, LIBND4J_TYPES); #else - BUILD_SINGLE_SELECTOR_THRICE(xType, functions::scalar::ScalarTransform, ::executeCudaShaped(launchDims, stream, opNum, dX, dXShapeInfo, hXShapeInfo, dZ, dZShapeInfo, hZShapeInfo, dScalar, extraParams), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR_THRICE( + xType, functions::scalar::ScalarTransform, + ::executeCudaShaped(launchDims, stream, opNum, dX, dXShapeInfo, + hXShapeInfo, dZ, dZShapeInfo, hZShapeInfo, dScalar, + extraParams), + LIBND4J_TYPES); #endif - // TODO: remove after the release - auto res = cudaStreamSynchronize(*stream); - if (res != 0) - throw cuda_exception::build("execScalar failed", res); + // TODO: remove after the release + auto res = cudaStreamSynchronize(*stream); + if (res != 0) throw cuda_exception::build("execScalar failed", res); } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execScalar(sd::LaunchContext *lc, - int opNum, - void const* hX, Nd4jLong const* hXShapeInfo, - void const* dX, Nd4jLong const* dXShapeInfo, - void *extraParams, - void *hZ, Nd4jLong const* hZShapeInfo, - void *dZ, Nd4jLong const* dZShapeInfo, - void const* hScalars, Nd4jLong const* hScalarShapeInfo, - void const* dScalars, Nd4jLong const* dScalarShapeInfo, - int *dimension, int dimensionLength, - Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, - Nd4jLong const* tadShapeInfoZ, Nd4jLong const* tadOffsetsZ) { - - auto stream = lc->getCudaStream(); - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto yType = sd::ArrayOptions::dataType(hScalarShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hScalarShapeInfo)) - return; - - dim3 launchDims(256, 256, 16384); +void NativeOpExecutioner::execScalar( + sd::LaunchContext* lc, int opNum, void const* hX, + Nd4jLong const* hXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo, + void* extraParams, void* hZ, Nd4jLong const* hZShapeInfo, void* dZ, + Nd4jLong const* dZShapeInfo, void const* hScalars, + Nd4jLong const* hScalarShapeInfo, void const* dScalars, + Nd4jLong const* dScalarShapeInfo, int* dimension, int dimensionLength, + Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, + Nd4jLong const* tadShapeInfoZ, Nd4jLong const* tadOffsetsZ) { + auto stream = lc->getCudaStream(); + + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto yType = sd::ArrayOptions::dataType(hScalarShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hScalarShapeInfo)) return; + + dim3 launchDims(256, 256, 16384); #ifdef __ND4J_EXPERIMENTAL__ - BUILD_PAIRWISE_SELECTOR(xType, yType, zType, functions::scalar::ScalarTransform, ::executeCudaAlongDimension(launchDims, stream, opNum, dX, dXShapeInfo, dZ, dZShapeInfo, dScalars, extraParams, dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES, LIBND4J_TYPES); + BUILD_PAIRWISE_SELECTOR( + xType, yType, zType, functions::scalar::ScalarTransform, + ::executeCudaAlongDimension(launchDims, stream, opNum, dX, dXShapeInfo, + dZ, dZShapeInfo, dScalars, extraParams, + dimension, dimensionLength, tadShapeInfo, + tadOffsets, tadShapeInfoZ, tadOffsetsZ), + LIBND4J_TYPES, LIBND4J_TYPES); #else - BUILD_SINGLE_SELECTOR_THRICE(xType, functions::scalar::ScalarTransform, ::executeCudaAlongDimension(launchDims, stream, opNum, dX, dXShapeInfo, dZ, dZShapeInfo, dScalars, extraParams, dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR_THRICE( + xType, functions::scalar::ScalarTransform, + ::executeCudaAlongDimension(launchDims, stream, opNum, dX, dXShapeInfo, + dZ, dZShapeInfo, dScalars, extraParams, + dimension, dimensionLength, tadShapeInfo, + tadOffsets, tadShapeInfoZ, tadOffsetsZ), + LIBND4J_TYPES); #endif - // TODO: remove after the release - auto res = cudaStreamSynchronize(*stream); - if (res != 0) - throw cuda_exception::build("execScalar B failed", res); + // TODO: remove after the release + auto res = cudaStreamSynchronize(*stream); + if (res != 0) throw cuda_exception::build("execScalar B failed", res); } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execRandom(sd::LaunchContext *lc, - int opNum, - Nd4jPointer stateHost, - void *hZ, Nd4jLong const* hZShapeInfo, - void *dZ, Nd4jLong const* dZShapeInfo, - void *extraArguments) { - - auto stream = lc->getCudaStream(); - auto sizeOf = sizeof(sd::graph::RandomGenerator); - Nd4jPointer stateDevice; - - cudaError_t res = cudaMalloc(reinterpret_cast(&stateDevice), sizeOf); - checkCudaErrors(cudaStreamSynchronize(*stream)); - checkCudaErrors(cudaMemcpyAsync(stateDevice, stateHost, sizeOf, cudaMemcpyHostToDevice, *stream)); - - dim3 launchDims = dim3(512, 512, 32768); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - auto rng = reinterpret_cast(stateHost); - - // functions::random::RandomFunction::executeCudaSingle(launchDims, extraPointers, opNum, stateHost, dZ, dZShapeInfo, extraArguments), - BUILD_SINGLE_SELECTOR(zType, functions::random::RandomFunction, ::executeCudaSingle(launchDims, stream, opNum, stateDevice, dZ, dZShapeInfo, extraArguments), FLOAT_TYPES); - - res = cudaStreamSynchronize(*stream); - if (res != 0) - throw cuda_exception::build("execRandom X failed", res); - - cudaFree(stateDevice); - - rng->rewindH(shape::length(hZShapeInfo)); +void NativeOpExecutioner::execRandom(sd::LaunchContext* lc, int opNum, + Nd4jPointer stateHost, void* hZ, + Nd4jLong const* hZShapeInfo, void* dZ, + Nd4jLong const* dZShapeInfo, + void* extraArguments) { + auto stream = lc->getCudaStream(); + auto sizeOf = sizeof(sd::graph::RandomGenerator); + Nd4jPointer stateDevice; + + cudaError_t res = cudaMalloc(reinterpret_cast(&stateDevice), sizeOf); + checkCudaErrors(cudaStreamSynchronize(*stream)); + checkCudaErrors(cudaMemcpyAsync(stateDevice, stateHost, sizeOf, + cudaMemcpyHostToDevice, *stream)); + + dim3 launchDims = dim3(512, 512, 32768); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + auto rng = reinterpret_cast(stateHost); + + // functions::random::RandomFunction::executeCudaSingle(launchDims, + // extraPointers, opNum, stateHost, dZ, dZShapeInfo, extraArguments), + BUILD_SINGLE_SELECTOR( + zType, functions::random::RandomFunction, + ::executeCudaSingle(launchDims, stream, opNum, stateDevice, dZ, + dZShapeInfo, extraArguments), + FLOAT_TYPES); + + res = cudaStreamSynchronize(*stream); + if (res != 0) throw cuda_exception::build("execRandom X failed", res); + + cudaFree(stateDevice); + + rng->rewindH(shape::length(hZShapeInfo)); } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execRandom(sd::LaunchContext *lc, - int opNum, - Nd4jPointer stateHost, - void const* hX, Nd4jLong const* hXShapeInfo, - void const* dX, Nd4jLong const* dXShapeInfo, - void *hZ, Nd4jLong const* hZShapeInfo, - void *dZ, Nd4jLong const* dZShapeInfo, - void *extraArguments) { - - auto stream = lc->getCudaStream(); - - auto sizeOf = sizeof(sd::graph::RandomGenerator); - Nd4jPointer stateDevice; - - cudaError_t res = cudaMalloc(reinterpret_cast(&stateDevice), sizeOf); - checkCudaErrors(cudaStreamSynchronize(*stream)); - checkCudaErrors(cudaMemcpyAsync(stateDevice, stateHost, sizeOf, cudaMemcpyHostToDevice, *stream)); - - auto rng = reinterpret_cast(stateHost); - - dim3 launchDims = dim3(512, 512, 32768); - auto xType = sd::ArrayOptions::dataType(hZShapeInfo); - // functions::random::RandomFunction::executeCudaDouble(launchDims, extraPointers, opNum, stateHost, dX, dXShapeInfo, dZ, dZShapeInfo, extraArguments); - BUILD_SINGLE_SELECTOR(xType, functions::random::RandomFunction, ::executeCudaDouble(launchDims, stream, opNum, stateDevice, dX, dXShapeInfo, dZ, dZShapeInfo, extraArguments), FLOAT_TYPES); - - res = cudaStreamSynchronize(*stream); - if (res != 0) - throw cuda_exception::build("execRandom XY failed", res); - - cudaFree(stateDevice); - - rng->rewindH(shape::length(hZShapeInfo)); +void NativeOpExecutioner::execRandom( + sd::LaunchContext* lc, int opNum, Nd4jPointer stateHost, void const* hX, + Nd4jLong const* hXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo, + void* hZ, Nd4jLong const* hZShapeInfo, void* dZ, + Nd4jLong const* dZShapeInfo, void* extraArguments) { + auto stream = lc->getCudaStream(); + + auto sizeOf = sizeof(sd::graph::RandomGenerator); + Nd4jPointer stateDevice; + + cudaError_t res = cudaMalloc(reinterpret_cast(&stateDevice), sizeOf); + checkCudaErrors(cudaStreamSynchronize(*stream)); + checkCudaErrors(cudaMemcpyAsync(stateDevice, stateHost, sizeOf, + cudaMemcpyHostToDevice, *stream)); + + auto rng = reinterpret_cast(stateHost); + + dim3 launchDims = dim3(512, 512, 32768); + auto xType = sd::ArrayOptions::dataType(hZShapeInfo); + // functions::random::RandomFunction::executeCudaDouble(launchDims, + // extraPointers, opNum, stateHost, dX, dXShapeInfo, dZ, dZShapeInfo, + // extraArguments); + BUILD_SINGLE_SELECTOR( + xType, functions::random::RandomFunction, + ::executeCudaDouble(launchDims, stream, opNum, stateDevice, dX, + dXShapeInfo, dZ, dZShapeInfo, extraArguments), + FLOAT_TYPES); + + res = cudaStreamSynchronize(*stream); + if (res != 0) throw cuda_exception::build("execRandom XY failed", res); + + cudaFree(stateDevice); + + rng->rewindH(shape::length(hZShapeInfo)); } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execRandom(sd::LaunchContext *lc, - int opNum, - Nd4jPointer stateHost, - void const* hX, Nd4jLong const* hXShapeInfo, - void const* dX, Nd4jLong const* dXShapeInfo, - void const* hY, Nd4jLong const* hYShapeInfo, - void const* dY, Nd4jLong const* dYShapeInfo, - void *hZ, Nd4jLong const* hZShapeInfo, - void *dZ, Nd4jLong const* dZShapeInfo, - void *extraArguments) { - - auto stream = lc->getCudaStream(); - auto sizeOf = sizeof(sd::graph::RandomGenerator); - Nd4jPointer stateDevice; - - cudaError_t res = cudaMalloc(reinterpret_cast(&stateDevice), sizeOf); - checkCudaErrors(cudaStreamSynchronize(*stream)); - checkCudaErrors(cudaMemcpyAsync(stateDevice, stateHost, sizeOf, cudaMemcpyHostToDevice, *stream)); - - auto rng = reinterpret_cast(stateHost); - - dim3 launchDims = dim3(512, 512, 32768); - auto xType = sd::ArrayOptions::dataType(hZShapeInfo); - // functions::random::RandomFunction::executeCudaTriple(launchDims, extraPointers, opNum, stateHost, dX, dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo, extraArguments); - BUILD_SINGLE_SELECTOR(xType, functions::random::RandomFunction, ::executeCudaTriple(launchDims, stream, opNum, stateDevice, dX, dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo, extraArguments), FLOAT_TYPES); - - res = cudaStreamSynchronize(*stream); - if (res != 0) - throw cuda_exception::build("execRandom XYZ failed", res); - - cudaFree(stateDevice); - - rng->rewindH(shape::length(hZShapeInfo)); +void NativeOpExecutioner::execRandom( + sd::LaunchContext* lc, int opNum, Nd4jPointer stateHost, void const* hX, + Nd4jLong const* hXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo, + void const* hY, Nd4jLong const* hYShapeInfo, void const* dY, + Nd4jLong const* dYShapeInfo, void* hZ, Nd4jLong const* hZShapeInfo, + void* dZ, Nd4jLong const* dZShapeInfo, void* extraArguments) { + auto stream = lc->getCudaStream(); + auto sizeOf = sizeof(sd::graph::RandomGenerator); + Nd4jPointer stateDevice; + + cudaError_t res = cudaMalloc(reinterpret_cast(&stateDevice), sizeOf); + checkCudaErrors(cudaStreamSynchronize(*stream)); + checkCudaErrors(cudaMemcpyAsync(stateDevice, stateHost, sizeOf, + cudaMemcpyHostToDevice, *stream)); + + auto rng = reinterpret_cast(stateHost); + + dim3 launchDims = dim3(512, 512, 32768); + auto xType = sd::ArrayOptions::dataType(hZShapeInfo); + // functions::random::RandomFunction::executeCudaTriple(launchDims, + // extraPointers, opNum, stateHost, dX, dXShapeInfo, dY, dYShapeInfo, dZ, + // dZShapeInfo, extraArguments); + BUILD_SINGLE_SELECTOR( + xType, functions::random::RandomFunction, + ::executeCudaTriple(launchDims, stream, opNum, stateDevice, dX, + dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo, + extraArguments), + FLOAT_TYPES); + + res = cudaStreamSynchronize(*stream); + if (res != 0) throw cuda_exception::build("execRandom XYZ failed", res); + + cudaFree(stateDevice); + + rng->rewindH(shape::length(hZShapeInfo)); } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execReduce3All(sd::LaunchContext *lc, - int opNum, - void const* hX, Nd4jLong const* hXShapeInfo, - void const* dX, Nd4jLong const* dXShapeInfo, - void *extraParamsVals, - void const* hY, Nd4jLong const* hYShapeInfo, - void const* dY, Nd4jLong const* dYShapeInfo, - void *hZ, Nd4jLong const* hZShapeInfo, - void *dZ, Nd4jLong const* dZShapeInfo, - int *dimension, int dimensionLength, - Nd4jLong const* xTadShapeInfo, Nd4jLong const* xOffsets, - Nd4jLong const* yTadShapeInfo, Nd4jLong const* yOffsets) { - - auto stream = lc->getCudaStream(); - auto allocationPointer = lc->getAllocationPointer(); - auto reductionPointer = lc->getReductionPointer(); - - if (sd::Environment::getInstance().isDebugAndVerbose()) - printf("D119 opNum:[%i]\n", opNum); - - dim3 launchDims(shape::length(hZShapeInfo), 256, 32768); - - if (sd::Environment::getInstance().isVerbose() && launchDims.x == 1) - printf("AD119 opNum:[%i]\n", opNum); - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto yType = sd::ArrayOptions::dataType(hYShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - if (yType != xType) - throw sd::datatype_exception::build("NativeOpExecutioner::execReduce3All both operands must have same data type", xType, yType); - - BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce3::Reduce3, ::execAll(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, extraParamsVals, dZ, dZShapeInfo, dimension, dimensionLength, 1, allocationPointer, xTadShapeInfo, xOffsets, yTadShapeInfo, yOffsets), LIBND4J_TYPES, FLOAT_TYPES); - - // TODO: remove after the release - auto res = cudaStreamSynchronize(*stream); - if (res != 0) - throw cuda_exception::build("execReduce3All failed", res); +void NativeOpExecutioner::execReduce3All( + sd::LaunchContext* lc, int opNum, void const* hX, + Nd4jLong const* hXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo, + void* extraParamsVals, void const* hY, Nd4jLong const* hYShapeInfo, + void const* dY, Nd4jLong const* dYShapeInfo, void* hZ, + Nd4jLong const* hZShapeInfo, void* dZ, Nd4jLong const* dZShapeInfo, + int* dimension, int dimensionLength, Nd4jLong const* xTadShapeInfo, + Nd4jLong const* xOffsets, Nd4jLong const* yTadShapeInfo, + Nd4jLong const* yOffsets) { + auto stream = lc->getCudaStream(); + auto allocationPointer = lc->getAllocationPointer(); + auto reductionPointer = lc->getReductionPointer(); + + if (sd::Environment::getInstance().isDebugAndVerbose()) + printf("D119 opNum:[%i]\n", opNum); + + dim3 launchDims(shape::length(hZShapeInfo), 256, 32768); + + if (sd::Environment::getInstance().isVerbose() && launchDims.x == 1) + printf("AD119 opNum:[%i]\n", opNum); + + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto yType = sd::ArrayOptions::dataType(hYShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + if (yType != xType) + throw sd::datatype_exception::build( + "NativeOpExecutioner::execReduce3All both operands must have same data " + "type", + xType, yType); + + BUILD_DOUBLE_SELECTOR( + xType, zType, functions::reduce3::Reduce3, + ::execAll(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, + extraParamsVals, dZ, dZShapeInfo, dimension, dimensionLength, 1, + allocationPointer, xTadShapeInfo, xOffsets, yTadShapeInfo, + yOffsets), + LIBND4J_TYPES, FLOAT_TYPES); + + // TODO: remove after the release + auto res = cudaStreamSynchronize(*stream); + if (res != 0) throw cuda_exception::build("execReduce3All failed", res); } - //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execReduce3TAD(sd::LaunchContext *lc, - int opNum, - void const* hX, Nd4jLong const* hXShapeInfo, - void const* dX, Nd4jLong const* dXShapeInfo, - void *extraParams, - void const* hY, Nd4jLong const* hYShapeInfo, - void const* dY, Nd4jLong const* dYShapeInfo, - void *hZ, Nd4jLong const* hZShapeInfo, - void *dZ, Nd4jLong const* dZShapeInfo, - int *dimension, int dimensionLength, - Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, - Nd4jLong const* yTadShapeInfo, Nd4jLong const* yTadOffsets) { - - if(shape::isScalar(hZShapeInfo)) { - NativeOpExecutioner::execReduce3(lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParams, hY, hYShapeInfo, dY, dYShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo); - return; - } - - auto stream = lc->getCudaStream(); - auto allocationPointer = lc->getAllocationPointer(); - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto yType = sd::ArrayOptions::dataType(hYShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - if (xType != yType) - throw sd::datatype_exception::build("NativeOpExecutioner::execReduce3TAD requires Y operand to have X type", xType, yType); - - if (!DataTypeUtils::isR(zType)) - throw sd::datatype_exception::build("NativeOpExecutioner::execReduce3TAD requires Z operand to have floating point data type", zType); - - auto numBlocks = shape::length(hZShapeInfo); - dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, 256, 32768); - - BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce3::Reduce3, ::exec(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, extraParams, dZ, dZShapeInfo, dimension, dimensionLength, 1, allocationPointer, tadShapeInfo, tadOffsets, yTadShapeInfo, yTadOffsets), LIBND4J_TYPES, FLOAT_TYPES); - - // TODO: remove after the release - auto res = cudaStreamSynchronize(*stream); - if (res != 0) - throw cuda_exception::build("execReduce3TAD failed", res); +void NativeOpExecutioner::execReduce3TAD( + sd::LaunchContext* lc, int opNum, void const* hX, + Nd4jLong const* hXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo, + void* extraParams, void const* hY, Nd4jLong const* hYShapeInfo, + void const* dY, Nd4jLong const* dYShapeInfo, void* hZ, + Nd4jLong const* hZShapeInfo, void* dZ, Nd4jLong const* dZShapeInfo, + int* dimension, int dimensionLength, Nd4jLong const* tadShapeInfo, + Nd4jLong const* tadOffsets, Nd4jLong const* yTadShapeInfo, + Nd4jLong const* yTadOffsets) { + if (shape::isScalar(hZShapeInfo)) { + NativeOpExecutioner::execReduce3( + lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParams, hY, + hYShapeInfo, dY, dYShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo); + return; + } + + auto stream = lc->getCudaStream(); + auto allocationPointer = lc->getAllocationPointer(); + + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto yType = sd::ArrayOptions::dataType(hYShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + if (xType != yType) + throw sd::datatype_exception::build( + "NativeOpExecutioner::execReduce3TAD requires Y operand to have X type", + xType, yType); + + if (!DataTypeUtils::isR(zType)) + throw sd::datatype_exception::build( + "NativeOpExecutioner::execReduce3TAD requires Z operand to have " + "floating point data type", + zType); + + auto numBlocks = shape::length(hZShapeInfo); + dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, 256, 32768); + + BUILD_DOUBLE_SELECTOR( + xType, zType, functions::reduce3::Reduce3, + ::exec(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, + extraParams, dZ, dZShapeInfo, dimension, dimensionLength, 1, + allocationPointer, tadShapeInfo, tadOffsets, yTadShapeInfo, + yTadOffsets), + LIBND4J_TYPES, FLOAT_TYPES); + + // TODO: remove after the release + auto res = cudaStreamSynchronize(*stream); + if (res != 0) throw cuda_exception::build("execReduce3TAD failed", res); } - diff --git a/libnd4j/include/legacy/cuda/NativeOps.cu b/libnd4j/include/legacy/cuda/NativeOps.cu old mode 100755 new mode 100644 index 1ccc2c7d515a..da7da37d721a --- a/libnd4j/include/legacy/cuda/NativeOps.cu +++ b/libnd4j/include/legacy/cuda/NativeOps.cu @@ -14,32 +14,25 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include #include #include -#include - -#include - - -#include #include #include #include -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include +#include #include -#include - +#include +#include //#include @@ -47,6 +40,16 @@ #include #include +// this section is for MMAP +#ifndef _WIN32 +#include +#include +#include +#else +#include +#include +#endif + using namespace sd; #include @@ -69,268 +72,295 @@ int minThreads = 32; __constant__ char deviceConstantMemory[49152]; - // this method just does type conversion in fancy way int getDeviceId(Nd4jPointer ptrToDeviceId) { - return (int)(Nd4jLong)ptrToDeviceId; + return (int)(Nd4jLong)ptrToDeviceId; } /* * Basic CUDA constants here: number of blocks per MP */ int getDeviceBlockThreshold(int deviceId) { - int ccMinor = deviceProperties[deviceId].minor; - int ccMajor = deviceProperties[deviceId].major; + int ccMinor = deviceProperties[deviceId].minor; + int ccMajor = deviceProperties[deviceId].major; - int blockThreshold = 8; + int blockThreshold = 8; - if (ccMajor >= 5) - blockThreshold = 32; - else if (ccMajor == 3) - blockThreshold = 16; - else if (ccMajor < 3) - blockThreshold = 8; + if (ccMajor >= 5) + blockThreshold = 32; + else if (ccMajor == 3) + blockThreshold = 16; + else if (ccMajor < 3) + blockThreshold = 8; - return blockThreshold; + return blockThreshold; } - /* - * This message returns shared memory threshold value. default overflow ratio is 0.3 + * This message returns shared memory threshold value. default overflow ratio is + * 0.3 */ int getDeviceSharedThreshold(int deviceId) { - int ccMinor = deviceProperties[deviceId].minor; - int ccMajor = deviceProperties[deviceId].major; + int ccMinor = deviceProperties[deviceId].minor; + int ccMajor = deviceProperties[deviceId].major; - // please note threshold isn't multiple of 32, and that's NOT a mistake + // please note threshold isn't multiple of 32, and that's NOT a mistake - int shmemThreshold; - if (ccMajor == 6 && ccMinor == 0) - shmemThreshold = 65536; - else if (ccMajor == 6 && ccMinor == 1) - shmemThreshold = 49152; - else if (ccMajor == 5 && ccMinor == 2) - shmemThreshold = 98304; - else if (ccMajor == 5) - shmemThreshold = 65536; - else if (ccMajor == 3 && ccMinor == 7) - shmemThreshold = 114688; - else shmemThreshold = 49152; + int shmemThreshold; + if (ccMajor == 6 && ccMinor == 0) + shmemThreshold = 65536; + else if (ccMajor == 6 && ccMinor == 1) + shmemThreshold = 49152; + else if (ccMajor == 5 && ccMinor == 2) + shmemThreshold = 98304; + else if (ccMajor == 5) + shmemThreshold = 65536; + else if (ccMajor == 3 && ccMinor == 7) + shmemThreshold = 114688; + else + shmemThreshold = 49152; - return shmemThreshold / 0.3; + return shmemThreshold / 0.3; } - - -sd::buffer::Buffer * createScalarBuffer(cudaStream_t stream) { - auto scalarShapeInfo = shape::createScalarShapeInfo(); - auto buff = sd::buffer::createBuffer(scalarShapeInfo,shape::shapeInfoLength(2), stream); - sd::buffer::copyDataToGpu(&buff, stream); - return buff; +sd::buffer::Buffer *createScalarBuffer(cudaStream_t stream) { + auto scalarShapeInfo = shape::createScalarShapeInfo(); + auto buff = sd::buffer::createBuffer(scalarShapeInfo, + shape::shapeInfoLength(2), stream); + sd::buffer::copyDataToGpu(&buff, stream); + return buff; } - class ScalarShapeInformation { -private: - sd::buffer::Buffer *scalarDimension; - sd::buffer::Buffer *scalarShapeInfo; -// std::thread::id threadId; - -public: - ScalarShapeInformation(cudaStream_t stream) { - auto scalarDimensionBuff = reinterpret_cast(malloc(sizeof(Nd4jLong))); - - CHECK_ALLOC(scalarDimensionBuff, "Failed to allocate ShapeInfoBuffer", sizeof(Nd4jLong)); - - scalarDimensionBuff[0] = MAX_DIMENSION; - scalarDimension = sd::buffer::createBuffer(scalarDimensionBuff,1, stream); - scalarShapeInfo = createScalarBuffer(stream); -// threadId = std::this_thread::get_id(); + private: + sd::buffer::Buffer *scalarDimension; + sd::buffer::Buffer *scalarShapeInfo; + // std::thread::id threadId; - } - ~ScalarShapeInformation() { - sd::buffer::freeBuffer(&scalarShapeInfo); - sd::buffer::freeBuffer(&scalarDimension); - } + public: + ScalarShapeInformation(cudaStream_t stream) { + auto scalarDimensionBuff = + reinterpret_cast(malloc(sizeof(Nd4jLong))); + CHECK_ALLOC(scalarDimensionBuff, "Failed to allocate ShapeInfoBuffer", + sizeof(Nd4jLong)); - Nd4jLong *getShapeInfoHostPointer() { - return scalarShapeInfo->data; - } + scalarDimensionBuff[0] = MAX_DIMENSION; + scalarDimension = sd::buffer::createBuffer(scalarDimensionBuff, 1, stream); + scalarShapeInfo = createScalarBuffer(stream); + // threadId = std::this_thread::get_id(); + } + ~ScalarShapeInformation() { + sd::buffer::freeBuffer(&scalarShapeInfo); + sd::buffer::freeBuffer(&scalarDimension); + } - Nd4jLong * getShapeInfoGpuPointer() { - return scalarShapeInfo->gData; - } + Nd4jLong *getShapeInfoHostPointer() { return scalarShapeInfo->data; } - Nd4jLong * getDimensionHostPointer() { - return scalarDimension->data; - } + Nd4jLong *getShapeInfoGpuPointer() { return scalarShapeInfo->gData; } - Nd4jLong * getDimensionGpuPointer() { - return scalarDimension->gData; - } + Nd4jLong *getDimensionHostPointer() { return scalarDimension->data; } + Nd4jLong *getDimensionGpuPointer() { return scalarDimension->gData; } }; - - - - template class ScalarInfo { - sd::buffer::Buffer *scalarData; - ScalarShapeInformation *shapeInfo; - T finalResult; - cudaStream_t streamRef; -public: - ScalarInfo(cudaStream_t stream) { - T *scalarResult = reinterpret_cast(malloc(sizeof(T))); - - CHECK_ALLOC(scalarResult, "Failed to allocate new scalar buffer", sizeof(T)); - - shapeInfo = new ScalarShapeInformation(stream); - scalarData = sd::buffer::createBuffer(scalarResult,1, stream); - streamRef = stream; - sd::buffer::copyDataToGpu(&scalarData, stream); - } - - T getFinalResultFromDevice() { - sd::buffer::copyDataFromGpu(&scalarData, streamRef); - return scalarData->data[0]; - } - - /** - * Get the device shape information - * representing a scalar - */ - Nd4jLong *getDeviceShapeInfo() { - return shapeInfo->getShapeInfoGpuPointer(); - } - - /** - * Get the dZ pointers - */ - T *getDevicePointer() { - return scalarData->gData; - } - - /** - * Get the infinite dimension device pointer - */ - Nd4jLong *getDimensionDevicePointer() { - return shapeInfo->getDimensionGpuPointer(); - } - - ~ScalarInfo() { - sd::buffer::freeBuffer(&scalarData); - delete shapeInfo; - } + sd::buffer::Buffer *scalarData; + ScalarShapeInformation *shapeInfo; + T finalResult; + cudaStream_t streamRef; + + public: + ScalarInfo(cudaStream_t stream) { + T *scalarResult = reinterpret_cast(malloc(sizeof(T))); + + CHECK_ALLOC(scalarResult, "Failed to allocate new scalar buffer", + sizeof(T)); + + shapeInfo = new ScalarShapeInformation(stream); + scalarData = sd::buffer::createBuffer(scalarResult, 1, stream); + streamRef = stream; + sd::buffer::copyDataToGpu(&scalarData, stream); + } + + T getFinalResultFromDevice() { + sd::buffer::copyDataFromGpu(&scalarData, streamRef); + return scalarData->data[0]; + } + + /** + * Get the device shape information + * representing a scalar + */ + Nd4jLong *getDeviceShapeInfo() { return shapeInfo->getShapeInfoGpuPointer(); } + + /** + * Get the dZ pointers + */ + T *getDevicePointer() { return scalarData->gData; } + + /** + * Get the infinite dimension device pointer + */ + Nd4jLong *getDimensionDevicePointer() { + return shapeInfo->getDimensionGpuPointer(); + } + + ~ScalarInfo() { + sd::buffer::freeBuffer(&scalarData); + delete shapeInfo; + } }; -void execPairwiseTransform( Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - OpaqueDataBuffer *dbY, Nd4jLong const* hYShapeInfo, Nd4jLong const* dYShapeInfo, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo, - void *extraParams) { - try { - InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbY}); - - LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execPairwiseTransform(&lc, opNum, dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hXShapeInfo).special(), - dbY->primary(), hYShapeInfo, dbY->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hYShapeInfo).special(), - dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hZShapeInfo).special(), extraParams); - - InteropDataBuffer::registerSpecialUse({dbZ}, {dbX, dbY}); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } +void execPairwiseTransform(Nd4jPointer *extraPointers, int opNum, + OpaqueDataBuffer *dbX, Nd4jLong const *hXShapeInfo, + Nd4jLong const *dXShapeInfo, OpaqueDataBuffer *dbY, + Nd4jLong const *hYShapeInfo, + Nd4jLong const *dYShapeInfo, OpaqueDataBuffer *dbZ, + Nd4jLong const *hZShapeInfo, + Nd4jLong const *dZShapeInfo, void *extraParams) { + try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbY}); + + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], + extraPointers[3]); + NativeOpExecutioner::execPairwiseTransform( + &lc, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + ConstantShapeHelper::getInstance() + .bufferForShapeInfo(hXShapeInfo) + .special(), + dbY->primary(), hYShapeInfo, dbY->special(), + ConstantShapeHelper::getInstance() + .bufferForShapeInfo(hYShapeInfo) + .special(), + dbZ->primary(), hZShapeInfo, dbZ->special(), + ConstantShapeHelper::getInstance() + .bufferForShapeInfo(hZShapeInfo) + .special(), + extraParams); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX, dbY}); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } //////////////////////////////////////////////////////////////////////// -void execPairwiseTransformBool(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - OpaqueDataBuffer *dbY, Nd4jLong const* hYShapeInfo, Nd4jLong const* dYShapeInfo, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo, - void *extraParams) { - try { - InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbY}); - - LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execPairwiseBoolTransform(&lc, opNum, - dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hXShapeInfo).special(), - dbY->primary(), hYShapeInfo, dbY->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hYShapeInfo).special(), - dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hZShapeInfo).special(), - extraParams); - - InteropDataBuffer::registerSpecialUse({dbZ}, {dbX, dbY}); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } +void execPairwiseTransformBool( + Nd4jPointer *extraPointers, int opNum, OpaqueDataBuffer *dbX, + Nd4jLong const *hXShapeInfo, Nd4jLong const *dXShapeInfo, + OpaqueDataBuffer *dbY, Nd4jLong const *hYShapeInfo, + Nd4jLong const *dYShapeInfo, OpaqueDataBuffer *dbZ, + Nd4jLong const *hZShapeInfo, Nd4jLong const *dZShapeInfo, + void *extraParams) { + try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbY}); + + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], + extraPointers[3]); + NativeOpExecutioner::execPairwiseBoolTransform( + &lc, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + ConstantShapeHelper::getInstance() + .bufferForShapeInfo(hXShapeInfo) + .special(), + dbY->primary(), hYShapeInfo, dbY->special(), + ConstantShapeHelper::getInstance() + .bufferForShapeInfo(hYShapeInfo) + .special(), + dbZ->primary(), hZShapeInfo, dbZ->special(), + ConstantShapeHelper::getInstance() + .bufferForShapeInfo(hZShapeInfo) + .special(), + extraParams); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX, dbY}); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } //////////////////////////////////////////////////////////////////////// -void execSummaryStatsScalar(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo, - bool biasCorrected) { - try { - InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); - - LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execSummaryStatsScalar(&lc, opNum, - dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hXShapeInfo).special(), - extraParams, - dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hZShapeInfo).special(), - biasCorrected); - - InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } +void execSummaryStatsScalar(Nd4jPointer *extraPointers, int opNum, + OpaqueDataBuffer *dbX, Nd4jLong const *hXShapeInfo, + Nd4jLong const *dXShapeInfo, void *extraParams, + OpaqueDataBuffer *dbZ, Nd4jLong const *hZShapeInfo, + Nd4jLong const *dZShapeInfo, bool biasCorrected) { + try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); + + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], + extraPointers[3]); + NativeOpExecutioner::execSummaryStatsScalar( + &lc, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + ConstantShapeHelper::getInstance() + .bufferForShapeInfo(hXShapeInfo) + .special(), + extraParams, dbZ->primary(), hZShapeInfo, dbZ->special(), + ConstantShapeHelper::getInstance() + .bufferForShapeInfo(hZShapeInfo) + .special(), + biasCorrected); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } //////////////////////////////////////////////////////////////////////// -void execBroadcastBool(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - OpaqueDataBuffer *dbY, Nd4jLong const* hYShapeInfo, Nd4jLong const* dYShapeInfo, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbDimension, Nd4jLong const* hDimensionShape, Nd4jLong const* dDimensionShape) { - try { - InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbY}); - InteropDataBuffer::preparePrimaryUse({}, {dbDimension}); - - auto dimension = reinterpret_cast(dbDimension->primary()); - int dimensionLength = static_cast(shape::length(hDimensionShape)); - - auto hTADShapeInfo = reinterpret_cast(extraPointers[9]); - auto tadOnlyShapeInfo = reinterpret_cast(extraPointers[10]); - auto tadOffsets = reinterpret_cast(extraPointers[11]); - auto tadOnlyShapeInfoZ = reinterpret_cast(extraPointers[12]); - auto tadOffsetsZ = reinterpret_cast(extraPointers[13]); - - LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execBroadcastBool(&lc, opNum, - dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hXShapeInfo).special(), - dbY->primary(), hYShapeInfo, dbY->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hYShapeInfo).special(), - dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hZShapeInfo).special(), - extraParams, - dimension, dimensionLength, - tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ); - - InteropDataBuffer::registerSpecialUse({dbZ}, {dbX, dbY}); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } +void execBroadcastBool(Nd4jPointer *extraPointers, int opNum, + OpaqueDataBuffer *dbX, Nd4jLong const *hXShapeInfo, + Nd4jLong const *dXShapeInfo, OpaqueDataBuffer *dbY, + Nd4jLong const *hYShapeInfo, Nd4jLong const *dYShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong const *hZShapeInfo, + Nd4jLong const *dZShapeInfo, void *extraParams, + OpaqueDataBuffer *dbDimension, + Nd4jLong const *hDimensionShape, + Nd4jLong const *dDimensionShape) { + try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbY}); + InteropDataBuffer::preparePrimaryUse({}, {dbDimension}); + + auto dimension = reinterpret_cast(dbDimension->primary()); + int dimensionLength = static_cast(shape::length(hDimensionShape)); + + auto hTADShapeInfo = reinterpret_cast(extraPointers[9]); + auto tadOnlyShapeInfo = reinterpret_cast(extraPointers[10]); + auto tadOffsets = reinterpret_cast(extraPointers[11]); + auto tadOnlyShapeInfoZ = reinterpret_cast(extraPointers[12]); + auto tadOffsetsZ = reinterpret_cast(extraPointers[13]); + + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], + extraPointers[3]); + NativeOpExecutioner::execBroadcastBool( + &lc, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + ConstantShapeHelper::getInstance() + .bufferForShapeInfo(hXShapeInfo) + .special(), + dbY->primary(), hYShapeInfo, dbY->special(), + ConstantShapeHelper::getInstance() + .bufferForShapeInfo(hYShapeInfo) + .special(), + dbZ->primary(), hZShapeInfo, dbZ->special(), + ConstantShapeHelper::getInstance() + .bufferForShapeInfo(hZShapeInfo) + .special(), + extraParams, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, + tadOnlyShapeInfoZ, tadOffsetsZ); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX, dbY}); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } /** @@ -345,47 +375,58 @@ void execBroadcastBool(Nd4jPointer *extraPointers, * @param dimension * @param dimensionLength */ -void execBroadcast( - Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - OpaqueDataBuffer *dbY, Nd4jLong const* hYShapeInfo, Nd4jLong const* dYShapeInfo, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo, - OpaqueDataBuffer *dbDimension, Nd4jLong const* hDimensionShape, Nd4jLong const* dDimensionShape) { - try { - InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbY}); - InteropDataBuffer::preparePrimaryUse({}, {dbDimension}); - - auto dimension = reinterpret_cast(dbDimension->primary()); - int dimensionLength = static_cast(shape::length(hDimensionShape)); - - cudaStream_t *stream = reinterpret_cast(extraPointers[1]); - - auto hTADShapeInfo = reinterpret_cast(extraPointers[9]); - auto tadOnlyShapeInfo = reinterpret_cast(extraPointers[10]); - auto tadOffsets = reinterpret_cast(extraPointers[11]); - auto tadOnlyShapeInfoZ = reinterpret_cast(extraPointers[12]); - auto tadOffsetsZ = reinterpret_cast(extraPointers[13]); - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto yType = sd::ArrayOptions::dataType(hYShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execBroadcast(&lc, opNum, - dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hXShapeInfo).special(), - dbY->primary(), hYShapeInfo, dbY->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hYShapeInfo).special(), - dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hZShapeInfo).special(), - dimension, dimensionLength, - tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ); - - InteropDataBuffer::registerSpecialUse({dbZ}, {dbX, dbY}); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } -} +void execBroadcast(Nd4jPointer *extraPointers, int opNum, OpaqueDataBuffer *dbX, + Nd4jLong const *hXShapeInfo, Nd4jLong const *dXShapeInfo, + OpaqueDataBuffer *dbY, Nd4jLong const *hYShapeInfo, + Nd4jLong const *dYShapeInfo, OpaqueDataBuffer *dbZ, + Nd4jLong const *hZShapeInfo, Nd4jLong const *dZShapeInfo, + OpaqueDataBuffer *dbDimension, + Nd4jLong const *hDimensionShape, + Nd4jLong const *dDimensionShape) { + try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbY}); + InteropDataBuffer::preparePrimaryUse({}, {dbDimension}); + + auto dimension = reinterpret_cast(dbDimension->primary()); + int dimensionLength = static_cast(shape::length(hDimensionShape)); + + cudaStream_t *stream = reinterpret_cast(extraPointers[1]); + auto hTADShapeInfo = reinterpret_cast(extraPointers[9]); + auto tadOnlyShapeInfo = reinterpret_cast(extraPointers[10]); + auto tadOffsets = reinterpret_cast(extraPointers[11]); + auto tadOnlyShapeInfoZ = reinterpret_cast(extraPointers[12]); + auto tadOffsetsZ = reinterpret_cast(extraPointers[13]); + + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto yType = sd::ArrayOptions::dataType(hYShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], + extraPointers[3]); + NativeOpExecutioner::execBroadcast( + &lc, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + ConstantShapeHelper::getInstance() + .bufferForShapeInfo(hXShapeInfo) + .special(), + dbY->primary(), hYShapeInfo, dbY->special(), + ConstantShapeHelper::getInstance() + .bufferForShapeInfo(hYShapeInfo) + .special(), + dbZ->primary(), hZShapeInfo, dbZ->special(), + ConstantShapeHelper::getInstance() + .bufferForShapeInfo(hZShapeInfo) + .special(), + dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, + tadOnlyShapeInfoZ, tadOffsetsZ); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX, dbY}); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } +} /** * @@ -397,230 +438,283 @@ void execBroadcast( * @param dZShapeInfo */ //////////////////////////////////////////////////////////////////////// -void execReduceFloat(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo) { - try { - InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); - - LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execReduceFloatScalar(&lc, opNum, - dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hXShapeInfo).special(), - extraParams, - dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hZShapeInfo).special()); - - InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } +void execReduceFloat(Nd4jPointer *extraPointers, int opNum, + OpaqueDataBuffer *dbX, Nd4jLong const *hXShapeInfo, + Nd4jLong const *dXShapeInfo, void *extraParams, + OpaqueDataBuffer *dbZ, Nd4jLong const *hZShapeInfo, + Nd4jLong const *dZShapeInfo) { + try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); + + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], + extraPointers[3]); + NativeOpExecutioner::execReduceFloatScalar( + &lc, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + ConstantShapeHelper::getInstance() + .bufferForShapeInfo(hXShapeInfo) + .special(), + extraParams, dbZ->primary(), hZShapeInfo, dbZ->special(), + ConstantShapeHelper::getInstance() + .bufferForShapeInfo(hZShapeInfo) + .special()); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } //////////////////////////////////////////////////////////////////////// -void execReduceSame(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo) { - try { - InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); - - LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execReduceSameScalar(&lc, opNum, - dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hXShapeInfo).special(), - extraParams, - dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hZShapeInfo).special()); - - InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } +void execReduceSame(Nd4jPointer *extraPointers, int opNum, + OpaqueDataBuffer *dbX, Nd4jLong const *hXShapeInfo, + Nd4jLong const *dXShapeInfo, void *extraParams, + OpaqueDataBuffer *dbZ, Nd4jLong const *hZShapeInfo, + Nd4jLong const *dZShapeInfo) { + try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); + + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], + extraPointers[3]); + NativeOpExecutioner::execReduceSameScalar( + &lc, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + ConstantShapeHelper::getInstance() + .bufferForShapeInfo(hXShapeInfo) + .special(), + extraParams, dbZ->primary(), hZShapeInfo, dbZ->special(), + ConstantShapeHelper::getInstance() + .bufferForShapeInfo(hZShapeInfo) + .special()); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } //////////////////////////////////////////////////////////////////////// -void execReduceSame2(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const*hXShapeInfo, Nd4jLong const*dXShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbZ, Nd4jLong const*hZShapeInfo, Nd4jLong const*dZShapeInfo, - OpaqueDataBuffer *dbDimension, Nd4jLong const*hDimensionShape, Nd4jLong const*dDimensionShape) { - try { - InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); - InteropDataBuffer::preparePrimaryUse({}, {dbDimension}); - - auto dimension = reinterpret_cast(dbDimension->primary()); - int dimensionLength = static_cast(shape::length(hDimensionShape)); - - auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(hXShapeInfo, - dimension, - shape::length(hDimensionShape)); - - LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execReduceSame(&lc, opNum, - dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hXShapeInfo).special(), - extraParams, - dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hZShapeInfo).special(), - dimension, dimensionLength, - tadPack.specialShapeInfo(), tadPack.specialOffsets()); - - InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } +void execReduceSame2(Nd4jPointer *extraPointers, int opNum, + OpaqueDataBuffer *dbX, Nd4jLong const *hXShapeInfo, + Nd4jLong const *dXShapeInfo, void *extraParams, + OpaqueDataBuffer *dbZ, Nd4jLong const *hZShapeInfo, + Nd4jLong const *dZShapeInfo, OpaqueDataBuffer *dbDimension, + Nd4jLong const *hDimensionShape, + Nd4jLong const *dDimensionShape) { + try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); + InteropDataBuffer::preparePrimaryUse({}, {dbDimension}); + + auto dimension = reinterpret_cast(dbDimension->primary()); + int dimensionLength = static_cast(shape::length(hDimensionShape)); + + auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions( + hXShapeInfo, dimension, shape::length(hDimensionShape)); + + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], + extraPointers[3]); + NativeOpExecutioner::execReduceSame( + &lc, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + ConstantShapeHelper::getInstance() + .bufferForShapeInfo(hXShapeInfo) + .special(), + extraParams, dbZ->primary(), hZShapeInfo, dbZ->special(), + ConstantShapeHelper::getInstance() + .bufferForShapeInfo(hZShapeInfo) + .special(), + dimension, dimensionLength, tadPack.specialShapeInfo(), + tadPack.specialOffsets()); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } //////////////////////////////////////////////////////////////////////// -void execReduceLong2(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const*hXShapeInfo, Nd4jLong const*dXShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbZ, Nd4jLong const*hZShapeInfo, Nd4jLong const*dZShapeInfo, - OpaqueDataBuffer *dbDimension, Nd4jLong const*hDimensionShape, Nd4jLong const*dDimensionShape) { - try { - InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); - InteropDataBuffer::preparePrimaryUse({}, {dbDimension}); - - auto dimension = reinterpret_cast(dbDimension->primary()); - int dimensionLength = static_cast(shape::length(hDimensionShape)); - - auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(hXShapeInfo, - dimension, - shape::length(hDimensionShape)); - - LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execReduceLong(&lc, opNum, - dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hXShapeInfo).special(), - extraParams, - dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hZShapeInfo).special(), - dimension, dimensionLength, - tadPack.specialShapeInfo(), tadPack.specialOffsets()); - - InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } +void execReduceLong2(Nd4jPointer *extraPointers, int opNum, + OpaqueDataBuffer *dbX, Nd4jLong const *hXShapeInfo, + Nd4jLong const *dXShapeInfo, void *extraParams, + OpaqueDataBuffer *dbZ, Nd4jLong const *hZShapeInfo, + Nd4jLong const *dZShapeInfo, OpaqueDataBuffer *dbDimension, + Nd4jLong const *hDimensionShape, + Nd4jLong const *dDimensionShape) { + try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); + InteropDataBuffer::preparePrimaryUse({}, {dbDimension}); + + auto dimension = reinterpret_cast(dbDimension->primary()); + int dimensionLength = static_cast(shape::length(hDimensionShape)); + + auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions( + hXShapeInfo, dimension, shape::length(hDimensionShape)); + + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], + extraPointers[3]); + NativeOpExecutioner::execReduceLong( + &lc, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + ConstantShapeHelper::getInstance() + .bufferForShapeInfo(hXShapeInfo) + .special(), + extraParams, dbZ->primary(), hZShapeInfo, dbZ->special(), + ConstantShapeHelper::getInstance() + .bufferForShapeInfo(hZShapeInfo) + .special(), + dimension, dimensionLength, tadPack.specialShapeInfo(), + tadPack.specialOffsets()); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } //////////////////////////////////////////////////////////////////////// -void execReduceLong(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const*hXShapeInfo, Nd4jLong const*dXShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbZ, Nd4jLong const*hZShapeInfo, Nd4jLong const*dZShapeInfo) { - try { - InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); - - auto stream = reinterpret_cast(extraPointers[1]); - auto hTADShapeInfo = reinterpret_cast(extraPointers[9]); - auto dTADShapeInfo = reinterpret_cast(extraPointers[10]); - - auto reductionPointer = reinterpret_cast(extraPointers[4]); - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - if (zType != sd::DataType::INT64) - throw datatype_exception::build("execReduceLong wrong Z data type", sd::DataType::INT64, zType); - - auto xLength = shape::length(hXShapeInfo); - auto blockWidth = 256; - auto numBlocks = CudaLaunchHelper::getReductionBlocks(xLength, blockWidth); - dim3 launchDims(numBlocks, blockWidth, 32768); - - BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce::ReduceLongFunction, - ::execReduceScalar(launchDims, stream, opNum, - dbX->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hXShapeInfo).special(), hXShapeInfo, - extraParams, - dbZ->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hZShapeInfo).special(), hXShapeInfo, - nullptr, 0, reductionPointer, dTADShapeInfo), LIBND4J_TYPES, LONG_TYPES); - - sd::DebugHelper::checkErrorCode(stream, "execReduceLong(...) failed"); - - InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } +void execReduceLong(Nd4jPointer *extraPointers, int opNum, + OpaqueDataBuffer *dbX, Nd4jLong const *hXShapeInfo, + Nd4jLong const *dXShapeInfo, void *extraParams, + OpaqueDataBuffer *dbZ, Nd4jLong const *hZShapeInfo, + Nd4jLong const *dZShapeInfo) { + try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); + + auto stream = reinterpret_cast(extraPointers[1]); + auto hTADShapeInfo = reinterpret_cast(extraPointers[9]); + auto dTADShapeInfo = reinterpret_cast(extraPointers[10]); + + auto reductionPointer = reinterpret_cast(extraPointers[4]); + + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + if (zType != sd::DataType::INT64) + throw datatype_exception::build("execReduceLong wrong Z data type", + sd::DataType::INT64, zType); + + auto xLength = shape::length(hXShapeInfo); + auto blockWidth = 256; + auto numBlocks = CudaLaunchHelper::getReductionBlocks(xLength, blockWidth); + dim3 launchDims(numBlocks, blockWidth, 32768); + + BUILD_DOUBLE_SELECTOR( + xType, zType, functions::reduce::ReduceLongFunction, + ::execReduceScalar(launchDims, stream, opNum, dbX->special(), + ConstantShapeHelper::getInstance() + .bufferForShapeInfo(hXShapeInfo) + .special(), + hXShapeInfo, extraParams, dbZ->special(), + ConstantShapeHelper::getInstance() + .bufferForShapeInfo(hZShapeInfo) + .special(), + hXShapeInfo, nullptr, 0, reductionPointer, + dTADShapeInfo), + LIBND4J_TYPES, LONG_TYPES); + + sd::DebugHelper::checkErrorCode(stream, "execReduceLong(...) failed"); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } //////////////////////////////////////////////////////////////////////// -void execReduceBool2(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const*hXShapeInfo, Nd4jLong const*dXShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbZ, Nd4jLong const*hZShapeInfo, Nd4jLong const*dZShapeInfo, - OpaqueDataBuffer *dbDimension, Nd4jLong const*hDimensionShape, Nd4jLong const*dDimensionShape) { - try { - InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); - InteropDataBuffer::preparePrimaryUse({}, {dbDimension}); - - auto dimension = reinterpret_cast(dbDimension->primary()); - int dimensionLength = static_cast(shape::length(hDimensionShape)); - - auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(hXShapeInfo, - dimension, - shape::length(hDimensionShape)); - - LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execReduceBool(&lc, opNum, - dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hXShapeInfo).special(), - extraParams, - dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hZShapeInfo).special(), - dimension, dimensionLength, - tadPack.specialShapeInfo(), tadPack.specialOffsets()); - - InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } +void execReduceBool2(Nd4jPointer *extraPointers, int opNum, + OpaqueDataBuffer *dbX, Nd4jLong const *hXShapeInfo, + Nd4jLong const *dXShapeInfo, void *extraParams, + OpaqueDataBuffer *dbZ, Nd4jLong const *hZShapeInfo, + Nd4jLong const *dZShapeInfo, OpaqueDataBuffer *dbDimension, + Nd4jLong const *hDimensionShape, + Nd4jLong const *dDimensionShape) { + try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); + InteropDataBuffer::preparePrimaryUse({}, {dbDimension}); + + auto dimension = reinterpret_cast(dbDimension->primary()); + int dimensionLength = static_cast(shape::length(hDimensionShape)); + + auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions( + hXShapeInfo, dimension, shape::length(hDimensionShape)); + + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], + extraPointers[3]); + NativeOpExecutioner::execReduceBool( + &lc, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + ConstantShapeHelper::getInstance() + .bufferForShapeInfo(hXShapeInfo) + .special(), + extraParams, dbZ->primary(), hZShapeInfo, dbZ->special(), + ConstantShapeHelper::getInstance() + .bufferForShapeInfo(hZShapeInfo) + .special(), + dimension, dimensionLength, tadPack.specialShapeInfo(), + tadPack.specialOffsets()); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } //////////////////////////////////////////////////////////////////////// -void execReduceBool(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo) { - try { - InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); - - auto stream = reinterpret_cast(extraPointers[1]); - auto hTADShapeInfo = reinterpret_cast(extraPointers[9]); - auto dTADShapeInfo = reinterpret_cast(extraPointers[10]); - - auto reductionPointer = reinterpret_cast(extraPointers[4]); - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - if (zType != sd::DataType::BOOL) - throw std::runtime_error("execReduceBool requires Z operand to have BOOL type"); - - auto xLength = shape::length(hXShapeInfo); - auto blockWidth = 256; - auto numBlocks = CudaLaunchHelper::getReductionBlocks(xLength, blockWidth); - dim3 launchDims(numBlocks, blockWidth, 32768); - - BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce::ReduceBoolFunction, - ::execReduceScalar(launchDims, stream, opNum, - dbX->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hXShapeInfo).special(), hXShapeInfo, - extraParams, - dbZ->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hZShapeInfo).special(), hZShapeInfo, - nullptr, 0, reductionPointer, dTADShapeInfo), LIBND4J_TYPES, BOOL_TYPES); - - sd::DebugHelper::checkErrorCode(stream, "execReduceBool(...) failed"); - - InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } +void execReduceBool(Nd4jPointer *extraPointers, int opNum, + OpaqueDataBuffer *dbX, Nd4jLong const *hXShapeInfo, + Nd4jLong const *dXShapeInfo, void *extraParams, + OpaqueDataBuffer *dbZ, Nd4jLong const *hZShapeInfo, + Nd4jLong const *dZShapeInfo) { + try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); + + auto stream = reinterpret_cast(extraPointers[1]); + auto hTADShapeInfo = reinterpret_cast(extraPointers[9]); + auto dTADShapeInfo = reinterpret_cast(extraPointers[10]); + + auto reductionPointer = reinterpret_cast(extraPointers[4]); + + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + if (zType != sd::DataType::BOOL) + throw std::runtime_error( + "execReduceBool requires Z operand to have BOOL type"); + + auto xLength = shape::length(hXShapeInfo); + auto blockWidth = 256; + auto numBlocks = CudaLaunchHelper::getReductionBlocks(xLength, blockWidth); + dim3 launchDims(numBlocks, blockWidth, 32768); + + BUILD_DOUBLE_SELECTOR( + xType, zType, functions::reduce::ReduceBoolFunction, + ::execReduceScalar(launchDims, stream, opNum, dbX->special(), + ConstantShapeHelper::getInstance() + .bufferForShapeInfo(hXShapeInfo) + .special(), + hXShapeInfo, extraParams, dbZ->special(), + ConstantShapeHelper::getInstance() + .bufferForShapeInfo(hZShapeInfo) + .special(), + hZShapeInfo, nullptr, 0, reductionPointer, + dTADShapeInfo), + LIBND4J_TYPES, BOOL_TYPES); + + sd::DebugHelper::checkErrorCode(stream, "execReduceBool(...) failed"); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } /** @@ -635,36 +729,43 @@ void execReduceBool(Nd4jPointer *extraPointers, * @param dimensionLength */ //////////////////////////////////////////////////////////////////////// -void execIndexReduce(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo, - OpaqueDataBuffer *dbDimension, Nd4jLong const* hDimensionShape, Nd4jLong const* dDimensionShape) { - try { - InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); - InteropDataBuffer::preparePrimaryUse({}, {dbDimension}); - - auto dimension = reinterpret_cast(dbDimension->primary()); - int dimensionLength = static_cast(shape::length(hDimensionShape)); - - auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(hXShapeInfo, - dimension, - shape::length(hDimensionShape)); - - LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execIndexReduce(&lc, opNum, - dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hXShapeInfo).special(), - extraParams, - dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hZShapeInfo).special(), - (int *) dbDimension->special(), dimensionLength, - tadPack.specialShapeInfo(), tadPack.specialOffsets()); - - InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } +void execIndexReduce(Nd4jPointer *extraPointers, int opNum, + OpaqueDataBuffer *dbX, Nd4jLong const *hXShapeInfo, + Nd4jLong const *dXShapeInfo, void *extraParams, + OpaqueDataBuffer *dbZ, Nd4jLong const *hZShapeInfo, + Nd4jLong const *dZShapeInfo, OpaqueDataBuffer *dbDimension, + Nd4jLong const *hDimensionShape, + Nd4jLong const *dDimensionShape) { + try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); + InteropDataBuffer::preparePrimaryUse({}, {dbDimension}); + + auto dimension = reinterpret_cast(dbDimension->primary()); + int dimensionLength = static_cast(shape::length(hDimensionShape)); + + auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions( + hXShapeInfo, dimension, shape::length(hDimensionShape)); + + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], + extraPointers[3]); + NativeOpExecutioner::execIndexReduce( + &lc, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + ConstantShapeHelper::getInstance() + .bufferForShapeInfo(hXShapeInfo) + .special(), + extraParams, dbZ->primary(), hZShapeInfo, dbZ->special(), + ConstantShapeHelper::getInstance() + .bufferForShapeInfo(hZShapeInfo) + .special(), + (int *)dbDimension->special(), dimensionLength, + tadPack.specialShapeInfo(), tadPack.specialOffsets()); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } /** @@ -677,36 +778,44 @@ void execIndexReduce(Nd4jPointer *extraPointers, * @param dZShapeInfo */ //////////////////////////////////////////////////////////////////////// -void execReduceFloat2(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo, - OpaqueDataBuffer *dbDimension, Nd4jLong const* hDimensionShape, Nd4jLong const* dDimensionShape) { - try { - InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); - InteropDataBuffer::preparePrimaryUse({}, {dbDimension}); - - auto dimension = reinterpret_cast(dbDimension->primary()); - int dimensionLength = static_cast(shape::length(hDimensionShape)); - - auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(hXShapeInfo, - dimension, - shape::length(hDimensionShape)); - - LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execReduceFloat(&lc, opNum, - dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hXShapeInfo).special(), - extraParams, - dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hZShapeInfo).special(), - dimension, dimensionLength, - tadPack.specialShapeInfo(), tadPack.specialOffsets()); - - InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } +void execReduceFloat2(Nd4jPointer *extraPointers, int opNum, + OpaqueDataBuffer *dbX, Nd4jLong const *hXShapeInfo, + Nd4jLong const *dXShapeInfo, void *extraParams, + OpaqueDataBuffer *dbZ, Nd4jLong const *hZShapeInfo, + Nd4jLong const *dZShapeInfo, + OpaqueDataBuffer *dbDimension, + Nd4jLong const *hDimensionShape, + Nd4jLong const *dDimensionShape) { + try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); + InteropDataBuffer::preparePrimaryUse({}, {dbDimension}); + + auto dimension = reinterpret_cast(dbDimension->primary()); + int dimensionLength = static_cast(shape::length(hDimensionShape)); + + auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions( + hXShapeInfo, dimension, shape::length(hDimensionShape)); + + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], + extraPointers[3]); + NativeOpExecutioner::execReduceFloat( + &lc, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + ConstantShapeHelper::getInstance() + .bufferForShapeInfo(hXShapeInfo) + .special(), + extraParams, dbZ->primary(), hZShapeInfo, dbZ->special(), + ConstantShapeHelper::getInstance() + .bufferForShapeInfo(hZShapeInfo) + .special(), + dimension, dimensionLength, tadPack.specialShapeInfo(), + tadPack.specialOffsets()); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } /** @@ -717,287 +826,331 @@ void execReduceFloat2(Nd4jPointer *extraPointers, * @param extraParams */ //////////////////////////////////////////////////////////////////////// -void execIndexReduceScalar( - Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo){ - try { - InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); - - LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execIndexReduceScalar(&lc, opNum, - dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hXShapeInfo).special(), - extraParams, - dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hZShapeInfo).special()); - - InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } +void execIndexReduceScalar(Nd4jPointer *extraPointers, int opNum, + OpaqueDataBuffer *dbX, Nd4jLong const *hXShapeInfo, + Nd4jLong const *dXShapeInfo, void *extraParams, + OpaqueDataBuffer *dbZ, Nd4jLong const *hZShapeInfo, + Nd4jLong const *dZShapeInfo) { + try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); + + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], + extraPointers[3]); + NativeOpExecutioner::execIndexReduceScalar( + &lc, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + ConstantShapeHelper::getInstance() + .bufferForShapeInfo(hXShapeInfo) + .special(), + extraParams, dbZ->primary(), hZShapeInfo, dbZ->special(), + ConstantShapeHelper::getInstance() + .bufferForShapeInfo(hZShapeInfo) + .special()); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } //////////////////////////////////////////////////////////////////////// -void execTransformSame(Nd4jPointer *extraPointers,int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo, - void *extraParams) { - try { - InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); - - auto tadShapeInfo = reinterpret_cast(extraPointers != nullptr ? extraPointers[0] : nullptr); - auto tadOffsets = reinterpret_cast(extraPointers != nullptr ? extraPointers[1] : nullptr); - - LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execTransformSame(&lc, opNum, - dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hXShapeInfo).special(), - dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hZShapeInfo).special(), - extraParams, - tadShapeInfo, tadOffsets); - - InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } +void execTransformSame(Nd4jPointer *extraPointers, int opNum, + OpaqueDataBuffer *dbX, Nd4jLong const *hXShapeInfo, + Nd4jLong const *dXShapeInfo, OpaqueDataBuffer *dbZ, + Nd4jLong const *hZShapeInfo, Nd4jLong const *dZShapeInfo, + void *extraParams) { + try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); + + auto tadShapeInfo = reinterpret_cast( + extraPointers != nullptr ? extraPointers[0] : nullptr); + auto tadOffsets = reinterpret_cast( + extraPointers != nullptr ? extraPointers[1] : nullptr); + + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], + extraPointers[3]); + NativeOpExecutioner::execTransformSame( + &lc, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + ConstantShapeHelper::getInstance() + .bufferForShapeInfo(hXShapeInfo) + .special(), + dbZ->primary(), hZShapeInfo, dbZ->special(), + ConstantShapeHelper::getInstance() + .bufferForShapeInfo(hZShapeInfo) + .special(), + extraParams, tadShapeInfo, tadOffsets); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } //////////////////////////////////////////////////////////////////////// -void execTransformBool(Nd4jPointer *extraPointers,int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo, - void *extraParams) { - try { - InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); - - auto tadShapeInfo = reinterpret_cast(extraPointers != nullptr ? extraPointers[0] : nullptr); - auto tadOffsets = reinterpret_cast(extraPointers != nullptr ? extraPointers[1] : nullptr); - - LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execTransformBool(&lc, opNum, - dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hXShapeInfo).special(), - dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hZShapeInfo).special(), - extraParams, - tadShapeInfo, tadOffsets); - - InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } +void execTransformBool(Nd4jPointer *extraPointers, int opNum, + OpaqueDataBuffer *dbX, Nd4jLong const *hXShapeInfo, + Nd4jLong const *dXShapeInfo, OpaqueDataBuffer *dbZ, + Nd4jLong const *hZShapeInfo, Nd4jLong const *dZShapeInfo, + void *extraParams) { + try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); + + auto tadShapeInfo = reinterpret_cast( + extraPointers != nullptr ? extraPointers[0] : nullptr); + auto tadOffsets = reinterpret_cast( + extraPointers != nullptr ? extraPointers[1] : nullptr); + + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], + extraPointers[3]); + NativeOpExecutioner::execTransformBool( + &lc, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + ConstantShapeHelper::getInstance() + .bufferForShapeInfo(hXShapeInfo) + .special(), + dbZ->primary(), hZShapeInfo, dbZ->special(), + ConstantShapeHelper::getInstance() + .bufferForShapeInfo(hZShapeInfo) + .special(), + extraParams, tadShapeInfo, tadOffsets); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } //////////////////////////////////////////////////////////////////////// -void execTransformAny(Nd4jPointer *extraPointers,int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo, - void *extraParams) { - try { - InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); - - auto stream = reinterpret_cast(extraPointers[1]); - auto streamSpecial = reinterpret_cast(extraPointers[4]); - LaunchContext lc(stream, streamSpecial, extraPointers[5], extraPointers[3], - reinterpret_cast(extraPointers[6])); - - NativeOpExecutioner::execTransformAny(&lc, opNum, - dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hXShapeInfo).special(), - dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hZShapeInfo).special(), - extraParams, - nullptr, nullptr); - - InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } +void execTransformAny(Nd4jPointer *extraPointers, int opNum, + OpaqueDataBuffer *dbX, Nd4jLong const *hXShapeInfo, + Nd4jLong const *dXShapeInfo, OpaqueDataBuffer *dbZ, + Nd4jLong const *hZShapeInfo, Nd4jLong const *dZShapeInfo, + void *extraParams) { + try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); + + auto stream = reinterpret_cast(extraPointers[1]); + auto streamSpecial = reinterpret_cast(extraPointers[4]); + LaunchContext lc(stream, streamSpecial, extraPointers[5], extraPointers[3], + reinterpret_cast(extraPointers[6])); + + NativeOpExecutioner::execTransformAny( + &lc, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + ConstantShapeHelper::getInstance() + .bufferForShapeInfo(hXShapeInfo) + .special(), + dbZ->primary(), hZShapeInfo, dbZ->special(), + ConstantShapeHelper::getInstance() + .bufferForShapeInfo(hZShapeInfo) + .special(), + extraParams, nullptr, nullptr); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } //////////////////////////////////////////////////////////////////////// -void execTransformStrict(Nd4jPointer *extraPointers,int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo, - void *extraParams) { - try { - InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); - - auto tadShapeInfo = reinterpret_cast(extraPointers != nullptr ? extraPointers[10] : nullptr); - auto tadOffsets = reinterpret_cast(extraPointers != nullptr ? extraPointers[11] : nullptr); - - LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execTransformStrict(&lc, opNum, - dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hXShapeInfo).special(), - dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hZShapeInfo).special(), - extraParams, - tadShapeInfo, tadOffsets); - - InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } +void execTransformStrict(Nd4jPointer *extraPointers, int opNum, + OpaqueDataBuffer *dbX, Nd4jLong const *hXShapeInfo, + Nd4jLong const *dXShapeInfo, OpaqueDataBuffer *dbZ, + Nd4jLong const *hZShapeInfo, + Nd4jLong const *dZShapeInfo, void *extraParams) { + try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); + + auto tadShapeInfo = reinterpret_cast( + extraPointers != nullptr ? extraPointers[10] : nullptr); + auto tadOffsets = reinterpret_cast( + extraPointers != nullptr ? extraPointers[11] : nullptr); + + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], + extraPointers[3]); + NativeOpExecutioner::execTransformStrict( + &lc, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + ConstantShapeHelper::getInstance() + .bufferForShapeInfo(hXShapeInfo) + .special(), + dbZ->primary(), hZShapeInfo, dbZ->special(), + ConstantShapeHelper::getInstance() + .bufferForShapeInfo(hZShapeInfo) + .special(), + extraParams, tadShapeInfo, tadOffsets); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } //////////////////////////////////////////////////////////////////////// -void execTransformFloat(Nd4jPointer *extraPointers,int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo, - void *extraParams) { - try { - InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); - - auto tadShapeInfo = reinterpret_cast(extraPointers != nullptr ? extraPointers[10] : nullptr); - auto tadOffsets = reinterpret_cast(extraPointers != nullptr ? extraPointers[11] : nullptr); - - LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execTransformFloat(&lc, opNum, - dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hXShapeInfo).special(), - dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hZShapeInfo).special(), - extraParams, - tadShapeInfo, tadOffsets); - - InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } +void execTransformFloat(Nd4jPointer *extraPointers, int opNum, + OpaqueDataBuffer *dbX, Nd4jLong const *hXShapeInfo, + Nd4jLong const *dXShapeInfo, OpaqueDataBuffer *dbZ, + Nd4jLong const *hZShapeInfo, + Nd4jLong const *dZShapeInfo, void *extraParams) { + try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); + + auto tadShapeInfo = reinterpret_cast( + extraPointers != nullptr ? extraPointers[10] : nullptr); + auto tadOffsets = reinterpret_cast( + extraPointers != nullptr ? extraPointers[11] : nullptr); + + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], + extraPointers[3]); + NativeOpExecutioner::execTransformFloat( + &lc, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + ConstantShapeHelper::getInstance() + .bufferForShapeInfo(hXShapeInfo) + .special(), + dbZ->primary(), hZShapeInfo, dbZ->special(), + ConstantShapeHelper::getInstance() + .bufferForShapeInfo(hZShapeInfo) + .special(), + extraParams, tadShapeInfo, tadOffsets); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } void checkP2P() { - int curDevice = 0; - - cudaGetDevice(&curDevice); + int curDevice = 0; - int devCnt = 0; - cudaGetDeviceCount(&devCnt); + cudaGetDevice(&curDevice); - if (curDevice < 0 && curDevice > devCnt) - curDevice = 0; + int devCnt = 0; + cudaGetDeviceCount(&devCnt); - bool tempSupport = true; + if (curDevice < 0 && curDevice > devCnt) curDevice = 0; - if (devCnt > 1) { - for (int dX = 0; dX < devCnt; dX++) { + bool tempSupport = true; - for (int dY = 0; dY < devCnt; dY++) { - if (dX == dY) - continue; + if (devCnt > 1) { + for (int dX = 0; dX < devCnt; dX++) { + for (int dY = 0; dY < devCnt; dY++) { + if (dX == dY) continue; - int canAccess = 0; - cudaSetDevice(dX); + int canAccess = 0; + cudaSetDevice(dX); - cudaDeviceCanAccessPeer(&canAccess, dX , dY); + cudaDeviceCanAccessPeer(&canAccess, dX, dY); - if (!canAccess) { - tempSupport = false; - break; - } - } - } + if (!canAccess) { + tempSupport = false; + break; + } + } + } - supportedP2P = tempSupport; + supportedP2P = tempSupport; - cudaSetDevice(curDevice); - } else { - // if we have only 1 device - we say that we support P2P, since all data will be on 1 device - supportedP2P = true; - } + cudaSetDevice(curDevice); + } else { + // if we have only 1 device - we say that we support P2P, since all data + // will be on 1 device + supportedP2P = true; + } } void enableP2P(bool enable) { - if (enable == allowedP2P) - return; + if (enable == allowedP2P) return; - int curDevice = 0; + int curDevice = 0; - cudaGetDevice(&curDevice); - - int devCnt = 0; - cudaGetDeviceCount(&devCnt); + cudaGetDevice(&curDevice); - if (curDevice < 0 && curDevice > devCnt) - curDevice = 0; + int devCnt = 0; + cudaGetDeviceCount(&devCnt); - if (devCnt > 1) { - for (int dX = 0; dX < devCnt; dX++) { + if (curDevice < 0 && curDevice > devCnt) curDevice = 0; - for (int dY = 0; dY < devCnt; dY++) { - if (dX == dY) - continue; + if (devCnt > 1) { + for (int dX = 0; dX < devCnt; dX++) { + for (int dY = 0; dY < devCnt; dY++) { + if (dX == dY) continue; - int canAccess = 0; - cudaSetDevice(dX); + int canAccess = 0; + cudaSetDevice(dX); - cudaDeviceCanAccessPeer(&canAccess, dX , dY); + cudaDeviceCanAccessPeer(&canAccess, dX, dY); - if (canAccess) { - if (enable) { - cudaDeviceEnablePeerAccess(dY, 0); - } else { - cudaDeviceDisablePeerAccess(dY); - } - } else { - if (sd::Environment::getInstance().isVerbose()) printf("Peer access [%i] -> [%i] isn't possible\n", dX, dY); - } - } + if (canAccess) { + if (enable) { + cudaDeviceEnablePeerAccess(dY, 0); + } else { + cudaDeviceDisablePeerAccess(dY); + } + } else { + if (sd::Environment::getInstance().isVerbose()) + printf("Peer access [%i] -> [%i] isn't possible\n", dX, dY); } - - cudaSetDevice(curDevice); + } } - allowedP2P = enable; - cudaSetDevice(curDevice); -} + } -bool isP2PAvailable() { - return supportedP2P; + allowedP2P = enable; + + cudaSetDevice(curDevice); } +bool isP2PAvailable() { return supportedP2P; } void initializeDevicesAndFunctions() { - try { - int devCnt = 0; - cudaGetDeviceCount(&devCnt); - deviceProperties = new cudaDeviceProp[devCnt]; - for (int i = 0; i < devCnt; i++) { - cudaSetDevice(i); - cudaGetDeviceProperties(&deviceProperties[i], i); - - cudaDeviceSetLimit(cudaLimitStackSize, 4096); - } + try { + int devCnt = 0; + cudaGetDeviceCount(&devCnt); + deviceProperties = new cudaDeviceProp[devCnt]; + for (int i = 0; i < devCnt; i++) { + cudaSetDevice(i); + cudaGetDeviceProperties(&deviceProperties[i], i); - cudaSetDevice(0); + cudaDeviceSetLimit(cudaLimitStackSize, 4096); + } - checkP2P(); + cudaSetDevice(0); - // enabling p2p gpu access if it's supported - if (supportedP2P && devCnt > 1) - enableP2P(allowedP2P); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } + checkP2P(); + + // enabling p2p gpu access if it's supported + if (supportedP2P && devCnt > 1) enableP2P(allowedP2P); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } void initializeFunctions(Nd4jPointer *functions) { - sd::BlasHelper::getInstance().initializeDeviceFunctions(functions); - /* - cublasSgemv = (CublasSgemv)functions[0]; - cublasDgemv = (CublasDgemv)functions[1]; - cublasHgemm = (CublasHgemm)functions[2]; - cublasSgemm = (CublasSgemm)functions[3]; - cublasDgemm = (CublasDgemm)functions[4]; - cublasSgemmEx = (CublasSgemmEx)functions[5]; - cublasHgemmBatched = (CublasHgemmBatched)functions[6]; - cublasSgemmBatched = (CublasSgemmBatched)functions[7]; - cublasDgemmBatched = (CublasDgemmBatched)functions[8]; - */ + sd::BlasHelper::getInstance().initializeDeviceFunctions(functions); + /* + cublasSgemv = (CublasSgemv)functions[0]; +cublasDgemv = (CublasDgemv)functions[1]; +cublasHgemm = (CublasHgemm)functions[2]; +cublasSgemm = (CublasSgemm)functions[3]; +cublasDgemm = (CublasDgemm)functions[4]; +cublasSgemmEx = (CublasSgemmEx)functions[5]; +cublasHgemmBatched = (CublasHgemmBatched)functions[6]; +cublasSgemmBatched = (CublasSgemmBatched)functions[7]; +cublasDgemmBatched = (CublasDgemmBatched)functions[8]; + */ } - /** * This method acquires memory chunk of requested size on host side * @@ -1006,15 +1159,17 @@ void initializeFunctions(Nd4jPointer *functions) { * @param flags optional parameter */ Nd4jPointer mallocHost(Nd4jLong memorySize, int flags) { - Nd4jPointer pointer; - // cudaHostAllocMapped |cudaHostAllocPortable - auto res = cudaHostAlloc(reinterpret_cast(&pointer), memorySize + 8, cudaHostAllocDefault); - if (res != 0) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(res); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage("cudaHostAlloc failed"); - } + Nd4jPointer pointer; + // cudaHostAllocMapped |cudaHostAllocPortable + auto res = cudaHostAlloc(reinterpret_cast(&pointer), memorySize + 8, + cudaHostAllocDefault); + if (res != 0) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(res); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + "cudaHostAlloc failed"); + } - return reinterpret_cast(pointer); + return reinterpret_cast(pointer); } /** @@ -1022,18 +1177,20 @@ Nd4jPointer mallocHost(Nd4jLong memorySize, int flags) { * * @param pointer pointer that'll be used for allocation * @param memorySize memory size, in bytes - * @param ptrToDeviceId pointer to deviceId. For cuda that's just and int, for OpenCL that's pointer to device_id, etc + * @param ptrToDeviceId pointer to deviceId. For cuda that's just and int, for + * OpenCL that's pointer to device_id, etc * @param flags optional parameter */ Nd4jPointer mallocDevice(Nd4jLong memorySize, int deviceId, int flags) { - Nd4jPointer pointer; - auto res = cudaMalloc(reinterpret_cast(&pointer), memorySize + 8); - if (res != 0) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(res); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage("cudaMalloc failed"); - } + Nd4jPointer pointer; + auto res = cudaMalloc(reinterpret_cast(&pointer), memorySize + 8); + if (res != 0) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(res); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + "cudaMalloc failed"); + } - return reinterpret_cast(pointer); + return reinterpret_cast(pointer); } /** @@ -1042,13 +1199,14 @@ Nd4jPointer mallocDevice(Nd4jLong memorySize, int deviceId, int flags) { * @param pointer pointer that'll be freed */ int freeHost(Nd4jPointer pointer) { - auto res = cudaFreeHost(reinterpret_cast(pointer)); - if (res != 0) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(res); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage("cudaFreeHost failed"); - } + auto res = cudaFreeHost(reinterpret_cast(pointer)); + if (res != 0) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(res); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + "cudaFreeHost failed"); + } - return 1L; + return 1L; } /** @@ -1058,2386 +1216,2526 @@ int freeHost(Nd4jPointer pointer) { * @param ptrToDeviceId pointer to deviceId. */ int freeDevice(Nd4jPointer pointer, int deviceId) { - auto res = cudaFree(reinterpret_cast(pointer)); + auto res = cudaFree(reinterpret_cast(pointer)); - // we're intentionally skipping - if (res != 0 && res != 1) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(res); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage("cudaFree failed"); - } + // we're intentionally skipping + if (res != 0 && res != 1) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(res); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + "cudaFree failed"); + } - return res == 0 ? 1L : 0L; + return res == 0 ? 1L : 0L; } - -Nd4jPointer createContext() { - return 0L; -} +Nd4jPointer createContext() { return 0L; } Nd4jPointer createStream() { + auto stream = new cudaStream_t(); + auto dZ = cudaStreamCreate(stream); + if (dZ != 0) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(dZ); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + "cudaStreamCreate failed"); + } - auto stream = new cudaStream_t(); - auto dZ = cudaStreamCreate(stream); - if (dZ != 0) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(dZ); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage("cudaStreamCreate failed"); - } - - return stream; + return stream; } Nd4jPointer createEvent() { - Nd4jPointer nativeEvent= (Nd4jPointer) malloc(sizeof(cudaEvent_t)); + Nd4jPointer nativeEvent = (Nd4jPointer)malloc(sizeof(cudaEvent_t)); - CHECK_ALLOC(nativeEvent, "Failed to allocate new CUDA event buffer", sizeof(cudaEvent_t)); + CHECK_ALLOC(nativeEvent, "Failed to allocate new CUDA event buffer", + sizeof(cudaEvent_t)); - auto dZ = cudaEventCreateWithFlags(reinterpret_cast(&nativeEvent), cudaEventDisableTiming); - if (dZ != 0) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(dZ); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage("cudaEventCreateWithFlags failed"); - } + auto dZ = cudaEventCreateWithFlags( + reinterpret_cast(&nativeEvent), cudaEventDisableTiming); + if (dZ != 0) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(dZ); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + "cudaEventCreateWithFlags failed"); + } - return nativeEvent; + return nativeEvent; } int registerEvent(Nd4jPointer event, Nd4jPointer stream) { - auto pEvent = reinterpret_cast(&event); - auto pStream = reinterpret_cast(stream); + auto pEvent = reinterpret_cast(&event); + auto pStream = reinterpret_cast(stream); - auto dZ = cudaEventRecord(*pEvent, *pStream); - if (dZ != 0) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(dZ); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage("cudaEventRecord failed"); - } + auto dZ = cudaEventRecord(*pEvent, *pStream); + if (dZ != 0) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(dZ); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + "cudaEventRecord failed"); + } - return 1; + return 1; } int setDevice(int deviceId) { - AffinityManager::setCurrentDevice(deviceId); - return 1; + AffinityManager::setCurrentDevice(deviceId); + return 1; } Nd4jLong getDeviceFreeMemoryDefault() { - size_t memFree = 0; - size_t memTotal = 0; + size_t memFree = 0; + size_t memTotal = 0; - cudaMemGetInfo(&memFree, &memTotal); + cudaMemGetInfo(&memFree, &memTotal); - return (Nd4jLong) memFree; + return (Nd4jLong)memFree; } Nd4jLong getDeviceFreeMemory(int device) { - int orig = -1; + int orig = -1; - cudaGetDevice(&orig); + cudaGetDevice(&orig); - if (device >= 0 && device != orig) { - cudaSetDevice(device); - } + if (device >= 0 && device != orig) { + cudaSetDevice(device); + } - size_t memFree = 0; - size_t memTotal = 0; + size_t memFree = 0; + size_t memTotal = 0; - cudaMemGetInfo(&memFree, &memTotal); + cudaMemGetInfo(&memFree, &memTotal); - if (device >= 0 && device != orig) { - cudaSetDevice(orig); - } + if (device >= 0 && device != orig) { + cudaSetDevice(orig); + } - return (Nd4jLong) memFree; + return (Nd4jLong)memFree; } Nd4jLong getDeviceTotalMemory(int device) { - int orig = -1; - - cudaGetDevice(&orig); - - if (device >= 0 && device != orig) { - cudaSetDevice(device); - } - size_t memFree = 0; - size_t memTotal = 0; - - cudaMemGetInfo(&memFree, &memTotal); - - if (device >= 0 && device != orig) { - cudaSetDevice(orig); - } + int orig = -1; + + cudaGetDevice(&orig); + + if (device >= 0 && device != orig) { + cudaSetDevice(device); + } + size_t memFree = 0; + size_t memTotal = 0; + + cudaMemGetInfo(&memFree, &memTotal); + + if (device >= 0 && device != orig) { + cudaSetDevice(orig); + } + + return (Nd4jLong)memTotal; +} + +int memcpySync(Nd4jPointer dst, Nd4jPointer src, Nd4jLong size, int flags, + Nd4jPointer reserved) { + cudaMemcpyKind kind; + + switch (flags) { + case 0: { + kind = cudaMemcpyHostToHost; + } break; + case 1: { + kind = cudaMemcpyHostToDevice; + } break; + case 2: { + kind = cudaMemcpyDeviceToHost; + } break; + case 3: { + kind = cudaMemcpyDeviceToDevice; + } break; + default: { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + "UNDEFNED MEMCPY"); + return 0; + } + } + + auto dZ = cudaMemcpy(reinterpret_cast(dst), + const_cast(reinterpret_cast(src)), + static_cast(size), kind); + if (dZ != 0) { + printf("Failed on [%p] -> [%p], size: [%i], direction: [%i], dZ: [%i]\n", + src, dst, size, flags, static_cast(dZ)); + fflush(stdout); + fflush(stderr); + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(dZ); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + "cudaMemcpy failed"); + return 0; + } + + return 1; +} + +int memcpyAsync(Nd4jPointer dst, Nd4jPointer src, Nd4jLong size, int flags, + Nd4jPointer reserved) { + auto pStream = reinterpret_cast(reserved); + + cudaMemcpyKind kind; + + // sd::DebugHelper::checkErrorCode(pStream, "Preliminary sync failed"); + + switch (flags) { + case 0: { + kind = cudaMemcpyHostToHost; + } break; + case 1: { + kind = cudaMemcpyHostToDevice; + } break; + case 2: { + kind = cudaMemcpyDeviceToHost; + } break; + case 3: { + kind = cudaMemcpyDeviceToDevice; + } break; + default: { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + "UNDEFNED MEMCPY"); + return 0; + } + } + + auto dZ = + cudaMemcpyAsync(reinterpret_cast(dst), + const_cast(reinterpret_cast(src)), + static_cast(size), kind, *pStream); + // auto dZ = cudaMemcpy(reinterpret_cast(dst), const_cast(reinterpret_cast(src)), static_cast(size), kind); + if (dZ != 0) { + printf("Failed on [%p] -> [%p], size: [%i], direction: [%i], dZ: [%i]\n", + src, dst, size, flags, static_cast(dZ)); + fflush(stdout); + fflush(stderr); + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(dZ); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + "cudaMemcpyAsync failed"); + return 0; + } - return (Nd4jLong) memTotal; + return 1; } -int memcpySync(Nd4jPointer dst, Nd4jPointer src, Nd4jLong size, int flags, Nd4jPointer reserved) { - cudaMemcpyKind kind; +int memsetSync(Nd4jPointer dst, int value, Nd4jLong size, int flags, + Nd4jPointer reserved) { + auto dZ = cudaMemset(reinterpret_cast(dst), value, + static_cast(size)); + if (dZ != 0) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(dZ); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + "cudaMemset failed"); + } - switch (flags) { - case 0: { - kind = cudaMemcpyHostToHost; - } - break; - case 1: { - kind = cudaMemcpyHostToDevice; - } - break; - case 2: { - kind = cudaMemcpyDeviceToHost; - } - break; - case 3: { - kind = cudaMemcpyDeviceToDevice; - } - break; - default: { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage("UNDEFNED MEMCPY"); - return 0; - } - } + return 1; +} - auto dZ = cudaMemcpy(reinterpret_cast(dst), const_cast(reinterpret_cast(src)), static_cast(size), kind); - if (dZ != 0) { - printf("Failed on [%p] -> [%p], size: [%i], direction: [%i], dZ: [%i]\n", src, dst, size, flags, static_cast(dZ)); - fflush(stdout); - fflush(stderr); - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(dZ); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage("cudaMemcpy failed"); - return 0; - } +int memsetAsync(Nd4jPointer dst, int value, Nd4jLong size, int flags, + Nd4jPointer reserved) { + auto pStream = reinterpret_cast(reserved); - return 1; -} + auto dZ = cudaMemsetAsync(reinterpret_cast(dst), value, + static_cast(size), *pStream); + if (dZ != 0) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(dZ); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + "cudaMemsetAsync failed"); + } -int memcpyAsync(Nd4jPointer dst, Nd4jPointer src, Nd4jLong size, int flags, Nd4jPointer reserved) { - auto pStream = reinterpret_cast(reserved); - - cudaMemcpyKind kind; - - //sd::DebugHelper::checkErrorCode(pStream, "Preliminary sync failed"); - - switch (flags) { - case 0: { - kind = cudaMemcpyHostToHost; - } - break; - case 1: { - kind = cudaMemcpyHostToDevice; - } - break; - case 2: { - kind = cudaMemcpyDeviceToHost; - } - break; - case 3: { - kind = cudaMemcpyDeviceToDevice; - } - break; - default: { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage("UNDEFNED MEMCPY"); - return 0; - } - } - - auto dZ = cudaMemcpyAsync(reinterpret_cast(dst), const_cast(reinterpret_cast(src)), static_cast(size), kind, *pStream); - //auto dZ = cudaMemcpy(reinterpret_cast(dst), const_cast(reinterpret_cast(src)), static_cast(size), kind); - if (dZ != 0) { - printf("Failed on [%p] -> [%p], size: [%i], direction: [%i], dZ: [%i]\n", src, dst, size, flags, static_cast(dZ)); - fflush(stdout); - fflush(stderr); - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(dZ); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage("cudaMemcpyAsync failed"); - return 0; - } - - return 1; -} - -int memsetSync(Nd4jPointer dst, int value, Nd4jLong size, int flags, Nd4jPointer reserved) { - auto dZ = cudaMemset(reinterpret_cast(dst), value, static_cast(size)); - if (dZ != 0) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(dZ); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage("cudaMemset failed"); - } - - return 1; -} - -int memsetAsync(Nd4jPointer dst, int value, Nd4jLong size, int flags, Nd4jPointer reserved) { - auto pStream = reinterpret_cast(reserved); - - auto dZ = cudaMemsetAsync(reinterpret_cast(dst), value, static_cast(size), *pStream); - if (dZ != 0) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(dZ); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage("cudaMemsetAsync failed"); - } - - return 1; + return 1; } int destroyEvent(Nd4jPointer event) { - auto pEvent = reinterpret_cast(&event); - auto dZ = cudaEventDestroy(*pEvent); - if (dZ != 0) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(dZ); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage("cudaEventDestroy failed"); - } + auto pEvent = reinterpret_cast(&event); + auto dZ = cudaEventDestroy(*pEvent); + if (dZ != 0) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(dZ); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + "cudaEventDestroy failed"); + } - return 1; + return 1; } int streamSynchronize(Nd4jPointer stream) { - auto pStream = reinterpret_cast(stream); + auto pStream = reinterpret_cast(stream); - auto dZ = cudaStreamSynchronize(*pStream); - if (dZ != 0) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(dZ); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage("cudaStreamSynchronize failed"); - } + auto dZ = cudaStreamSynchronize(*pStream); + if (dZ != 0) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(dZ); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + "cudaStreamSynchronize failed"); + } - return 1L; + return 1L; } int eventSynchronize(Nd4jPointer event) { - auto pEvent = reinterpret_cast(&event); + auto pEvent = reinterpret_cast(&event); - auto dZ = cudaEventSynchronize(*pEvent); - if (dZ != 0) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(dZ); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage("cudaEventSynchronize failed"); - } + auto dZ = cudaEventSynchronize(*pEvent); + if (dZ != 0) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(dZ); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + "cudaEventSynchronize failed"); + } - return 1L; + return 1L; } int getAvailableDevices() { - int devCnt = 0; - cudaGetDeviceCount(&devCnt); - return devCnt; + int devCnt = 0; + cudaGetDeviceCount(&devCnt); + return devCnt; } void enableDebugMode(bool reallyEnable) { - sd::Environment::getInstance().setDebug(reallyEnable); + sd::Environment::getInstance().setDebug(reallyEnable); } void setGridLimit(int gridSize) { - if (gridSize > 8192) - gridSize = 8192; - if (gridSize < 1) - gridSize = 1; - blockLimit = gridSize; + if (gridSize > 8192) gridSize = 8192; + if (gridSize < 1) gridSize = 1; + blockLimit = gridSize; } -int ompGetMaxThreads() { - return maxThreads; -} +int ompGetMaxThreads() { return maxThreads; } -int ompGetNumThreads() { - return maxThreads; -} +int ompGetNumThreads() { return maxThreads; } void setOmpNumThreads(int threads) { - if (threads > 1024) - threads = 1024; - if (threads < 32) - threads = 32; - maxThreads = threads; + if (threads > 1024) threads = 1024; + if (threads < 32) threads = 32; + maxThreads = threads; } void enableVerboseMode(bool reallyEnable) { - sd::Environment::getInstance().setVerbose(reallyEnable); + sd::Environment::getInstance().setVerbose(reallyEnable); } -int getDeviceMajor(int device) { - return deviceProperties[device].major; -} +int getDeviceMajor(int device) { return deviceProperties[device].major; } -int getDeviceMinor(int device) { - return deviceProperties[device].minor; -} +int getDeviceMinor(int device) { return deviceProperties[device].minor; } +const char *getDeviceName(int device) { return deviceProperties[device].name; } -const char * getDeviceName(int device) { - return deviceProperties[device].name; +void specialConcat(Nd4jPointer *extraPointers, int dimension, int numArrays, + Nd4jPointer *data, Nd4jPointer *inputShapeInfo, void *dZ, + Nd4jLong const *dZShapeInfo, Nd4jPointer *tadPointers, + Nd4jPointer *offsetPointers) { + try { + BUILD_SINGLE_SELECTOR(ArrayOptions::dataType(dZShapeInfo), + sd::SpecialMethods, + ::concatCpuGeneric(dimension, numArrays, data, + inputShapeInfo, dZ, dZShapeInfo), + LIBND4J_TYPES); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } -void specialConcat( - Nd4jPointer *extraPointers, - int dimension, - int numArrays, - Nd4jPointer *data, - Nd4jPointer *inputShapeInfo, - void *dZ, - Nd4jLong const* dZShapeInfo, Nd4jPointer *tadPointers, Nd4jPointer *offsetPointers) { - try { - BUILD_SINGLE_SELECTOR(ArrayOptions::dataType(dZShapeInfo), sd::SpecialMethods, - ::concatCpuGeneric(dimension, numArrays, data, inputShapeInfo, dZ, dZShapeInfo), - LIBND4J_TYPES); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } -} - - /** * This method saves */ -sd::TadPack* tadOnlyShapeInfo(Nd4jLong const* dXShapeInfo, int *dimension, int dimensionLength) { - try { - auto pack = new TadPack(); - *pack = sd::ConstantTadHelper::getInstance().tadForDimensions(dXShapeInfo, dimension, dimensionLength); - return pack; - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - return nullptr; +sd::TadPack *tadOnlyShapeInfo(Nd4jLong const *dXShapeInfo, int *dimension, + int dimensionLength) { + try { + auto pack = new TadPack(); + *pack = sd::ConstantTadHelper::getInstance().tadForDimensions( + dXShapeInfo, dimension, dimensionLength); + return pack; + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + return nullptr; + } +} + +Nd4jLong const *getPrimaryShapeInfo(sd::TadPack *pack) { + return pack->primaryShapeInfo(); +} +Nd4jLong const *getPrimaryOffsets(sd::TadPack *pack) { + return pack->primaryOffsets(); +} +Nd4jLong const *getSpecialShapeInfo(sd::TadPack *pack) { + return pack->specialShapeInfo(); +} +Nd4jLong const *getSpecialOffsets(sd::TadPack *pack) { + return pack->specialOffsets(); +} +Nd4jLong getNumberOfTads(sd::TadPack *pack) { return pack->numberOfTads(); } +int getShapeInfoLength(sd::TadPack *pack) { return pack->shapeInfoLength(); } + +int memcpyConstantAsync(Nd4jLong dst, Nd4jPointer src, Nd4jLong size, int flags, + Nd4jPointer reserved) { + cudaStream_t *pStream = reinterpret_cast(reserved); + + cudaMemcpyKind kind; + + DEBUG_KERNEL(pStream, -1); + + switch (flags) { + case 0: { + kind = cudaMemcpyHostToHost; + } break; + case 1: { + kind = cudaMemcpyHostToDevice; + } break; + case 2: { + kind = cudaMemcpyDeviceToHost; } -} - -Nd4jLong const* getPrimaryShapeInfo(sd::TadPack* pack) { - return pack->primaryShapeInfo(); -} -Nd4jLong const* getPrimaryOffsets(sd::TadPack* pack) { - return pack->primaryOffsets(); -} -Nd4jLong const* getSpecialShapeInfo(sd::TadPack* pack) { - return pack->specialShapeInfo(); -} -Nd4jLong const* getSpecialOffsets(sd::TadPack* pack) { - return pack->specialOffsets(); -} -Nd4jLong getNumberOfTads(sd::TadPack* pack) { - return pack->numberOfTads(); -} -int getShapeInfoLength(sd::TadPack* pack) { - return pack->shapeInfoLength(); -} - -int memcpyConstantAsync(Nd4jLong dst, Nd4jPointer src, Nd4jLong size, int flags, Nd4jPointer reserved) { - cudaStream_t *pStream = reinterpret_cast(reserved); - - cudaMemcpyKind kind; - - DEBUG_KERNEL(pStream, -1); - - switch (flags) { - case 0: { - kind = cudaMemcpyHostToHost; - } - break; - case 1: { - kind = cudaMemcpyHostToDevice; - } - break; - case 2: { - kind = cudaMemcpyDeviceToHost; - } - case 3: { - kind = cudaMemcpyDeviceToDevice; - } - break; - } - auto dZ = cudaMemcpyToSymbolAsync(deviceConstantMemory, const_cast(src), size, dst, kind, *pStream); - if (dZ != 0) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(dZ); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage("cudaMemcpyToSymbolAsync failed"); - } - - return 1; + case 3: { + kind = cudaMemcpyDeviceToDevice; + } break; + } + auto dZ = cudaMemcpyToSymbolAsync(deviceConstantMemory, + const_cast(src), size, dst, + kind, *pStream); + if (dZ != 0) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(dZ); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + "cudaMemcpyToSymbolAsync failed"); + } + + return 1; } Nd4jPointer getConstantSpace() { - Nd4jPointer dConstAddr; - cudaError_t dZ = cudaGetSymbolAddress(reinterpret_cast(&dConstAddr), deviceConstantMemory); - - if (dZ != 0) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(dZ); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage("cudaGetSymbolAddress failed"); - } - - return dConstAddr; -} - -void pullRows(Nd4jPointer *extraPointers, - OpaqueDataBuffer *dbX, Nd4jLong const* xShapeInfo, Nd4jLong const* dXShapeInfo, - OpaqueDataBuffer *dbZ, Nd4jLong const* zShapeInfo, Nd4jLong const* dZShapeInfo, - Nd4jLong n, - Nd4jLong *indexes, - Nd4jLong const* tadShapeInfo, - Nd4jLong const* tadOffsets, - Nd4jLong const* zTadShapeInfo, - Nd4jLong const* zTadOffsets) { - try { - InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); - - cudaStream_t *stream = reinterpret_cast(extraPointers[1]); - dim3 launchDims(64, 256, 1024); - auto xType = sd::ArrayOptions::dataType(xShapeInfo); - BUILD_SINGLE_SELECTOR(xType, pullRowsKernelGeneric, - (launchDims, stream, dbX->special(), dbZ->special(), n, indexes, tadShapeInfo, tadOffsets, zTadShapeInfo, zTadOffsets), - LIBND4J_TYPES); - - DEBUG_KERNEL(stream, -1); - - InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } -} - + Nd4jPointer dConstAddr; + cudaError_t dZ = cudaGetSymbolAddress(reinterpret_cast(&dConstAddr), + deviceConstantMemory); -void average(Nd4jPointer *extras, - Nd4jPointer *x, Nd4jLong const* xShapeInfo, - Nd4jPointer *dx, Nd4jLong const* dXShapeInfo, - void *z, Nd4jLong const* zShapeInfo, - void *dz, Nd4jLong const* dzShapeInfo, - int n, - Nd4jLong length, - bool propagate) { - try { - cudaStream_t *stream = reinterpret_cast(extras[1]); - int mode = getDeviceId(extras[3]); + if (dZ != 0) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(dZ); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + "cudaGetSymbolAddress failed"); + } - auto dX = reinterpret_cast(dx); + return dConstAddr; +} - if (sd::Environment::getInstance().isDebugAndVerbose()) - printf("averageFloat called\n"); +void pullRows(Nd4jPointer *extraPointers, OpaqueDataBuffer *dbX, + Nd4jLong const *xShapeInfo, Nd4jLong const *dXShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong const *zShapeInfo, + Nd4jLong const *dZShapeInfo, Nd4jLong n, Nd4jLong *indexes, + Nd4jLong const *tadShapeInfo, Nd4jLong const *tadOffsets, + Nd4jLong const *zTadShapeInfo, Nd4jLong const *zTadOffsets) { + try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); - auto xType = sd::ArrayOptions::dataType(xShapeInfo); - // launching on gpu - if (mode == 0) { - dim3 launchDims(256, 256, 4096); - BUILD_SINGLE_SELECTOR(xType, averagingKernelGeneric, (launchDims, stream, dX, dz, n, length, propagate), - LIBND4J_TYPES); - sd::DebugHelper::checkErrorCode(stream, "AverageFloat(...) failed"); - } else { - // launching on host memory - BUILD_SINGLE_SELECTOR(xType, sd::SpecialMethods, ::averageGeneric(x, z, zShapeInfo, n, length, propagate), - LIBND4J_TYPES); - } - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + cudaStream_t *stream = reinterpret_cast(extraPointers[1]); + dim3 launchDims(64, 256, 1024); + auto xType = sd::ArrayOptions::dataType(xShapeInfo); + BUILD_SINGLE_SELECTOR( + xType, pullRowsKernelGeneric, + (launchDims, stream, dbX->special(), dbZ->special(), n, indexes, + tadShapeInfo, tadOffsets, zTadShapeInfo, zTadOffsets), + LIBND4J_TYPES); + + DEBUG_KERNEL(stream, -1); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } +} + +void average(Nd4jPointer *extras, Nd4jPointer *x, Nd4jLong const *xShapeInfo, + Nd4jPointer *dx, Nd4jLong const *dXShapeInfo, void *z, + Nd4jLong const *zShapeInfo, void *dz, Nd4jLong const *dzShapeInfo, + int n, Nd4jLong length, bool propagate) { + try { + cudaStream_t *stream = reinterpret_cast(extras[1]); + int mode = getDeviceId(extras[3]); + + auto dX = reinterpret_cast(dx); + + if (sd::Environment::getInstance().isDebugAndVerbose()) + printf("averageFloat called\n"); + + auto xType = sd::ArrayOptions::dataType(xShapeInfo); + // launching on gpu + if (mode == 0) { + dim3 launchDims(256, 256, 4096); + BUILD_SINGLE_SELECTOR(xType, averagingKernelGeneric, + (launchDims, stream, dX, dz, n, length, propagate), + LIBND4J_TYPES); + sd::DebugHelper::checkErrorCode(stream, "AverageFloat(...) failed"); + } else { + // launching on host memory + BUILD_SINGLE_SELECTOR( + xType, sd::SpecialMethods, + ::averageGeneric(x, z, zShapeInfo, n, length, propagate), + LIBND4J_TYPES); } -} + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } +} + +void accumulate(Nd4jPointer *extras, Nd4jPointer *x, Nd4jLong const *xShapeInfo, + Nd4jPointer *dx, Nd4jLong const *dXShapeInfo, void *z, + Nd4jLong const *zShapeInfo, void *dz, + Nd4jLong const *dzShapeInfo, int n, Nd4jLong length) { + try { + auto stream = reinterpret_cast(extras[1]); + int mode = getDeviceId(extras[3]); -void accumulate(Nd4jPointer *extras, - Nd4jPointer *x, Nd4jLong const* xShapeInfo, - Nd4jPointer *dx, Nd4jLong const* dXShapeInfo, - void *z, Nd4jLong const* zShapeInfo, - void *dz, Nd4jLong const* dzShapeInfo, - int n, - Nd4jLong length) { - try { - auto stream = reinterpret_cast(extras[1]); - int mode = getDeviceId(extras[3]); - - auto dX = reinterpret_cast(dx); - - if (sd::Environment::getInstance().isDebugAndVerbose()) - printf("accumulateFloat called\n"); - auto xType = sd::ArrayOptions::dataType(xShapeInfo); - - // launching on gpu - if (mode == 0) { - dim3 launchDims(n, 256, 16384); - BUILD_SINGLE_SELECTOR(xType, accumulateKernelGeneric, (launchDims, stream, dX, dz, n, length), - LIBND4J_TYPES); - sd::DebugHelper::checkErrorCode(stream, "AccumulateFloat(...) failed"); - } else { - // launching on host memory - BUILD_SINGLE_SELECTOR(xType, sd::SpecialMethods, ::accumulateGeneric(x, z, zShapeInfo, n, length), - LIBND4J_TYPES); - } - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } -} + auto dX = reinterpret_cast(dx); + if (sd::Environment::getInstance().isDebugAndVerbose()) + printf("accumulateFloat called\n"); + auto xType = sd::ArrayOptions::dataType(xShapeInfo); -void shuffle(Nd4jPointer *extras, - Nd4jPointer *x, Nd4jPointer *xShapeInfo, - Nd4jPointer *dx, Nd4jPointer *dXShapeInfo, - Nd4jPointer *z, Nd4jPointer *zShapeInfo, - Nd4jPointer *dz, Nd4jPointer *dZShapeInfo, - int N, - int *shuffleMap, - Nd4jPointer *tadShapeInfo, - Nd4jPointer *tadOffsets) { - try { - cudaStream_t *stream = reinterpret_cast(extras[1]); - - auto dX = reinterpret_cast(dx); - auto dZ = reinterpret_cast(dz); - auto xShape = reinterpret_cast(xShapeInfo); - auto dxShape = reinterpret_cast(dXShapeInfo); - auto tadOnlyShapeInfo = reinterpret_cast(tadShapeInfo); - auto tadOffset = reinterpret_cast(tadOffsets); - - auto xType = sd::ArrayOptions::dataType(xShape[0]); - dim3 launchDims(256, 512, 8192); - BUILD_SINGLE_SELECTOR(xType, shuffleKernelGeneric, - (launchDims, stream, dX, dxShape, dZ, N, shuffleMap, tadOnlyShapeInfo, tadOffset), - LIBND4J_TYPES); - - sd::DebugHelper::checkErrorCode(stream, "shuffle(...) failed"); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + // launching on gpu + if (mode == 0) { + dim3 launchDims(n, 256, 16384); + BUILD_SINGLE_SELECTOR(xType, accumulateKernelGeneric, + (launchDims, stream, dX, dz, n, length), + LIBND4J_TYPES); + sd::DebugHelper::checkErrorCode(stream, "AccumulateFloat(...) failed"); + } else { + // launching on host memory + BUILD_SINGLE_SELECTOR(xType, sd::SpecialMethods, + ::accumulateGeneric(x, z, zShapeInfo, n, length), + LIBND4J_TYPES); } + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } +} + +void shuffle(Nd4jPointer *extras, Nd4jPointer *x, Nd4jPointer *xShapeInfo, + Nd4jPointer *dx, Nd4jPointer *dXShapeInfo, Nd4jPointer *z, + Nd4jPointer *zShapeInfo, Nd4jPointer *dz, Nd4jPointer *dZShapeInfo, + int N, int *shuffleMap, Nd4jPointer *tadShapeInfo, + Nd4jPointer *tadOffsets) { + try { + cudaStream_t *stream = reinterpret_cast(extras[1]); + + auto dX = reinterpret_cast(dx); + auto dZ = reinterpret_cast(dz); + auto xShape = reinterpret_cast(xShapeInfo); + auto dxShape = reinterpret_cast(dXShapeInfo); + auto tadOnlyShapeInfo = reinterpret_cast(tadShapeInfo); + auto tadOffset = reinterpret_cast(tadOffsets); + + auto xType = sd::ArrayOptions::dataType(xShape[0]); + dim3 launchDims(256, 512, 8192); + BUILD_SINGLE_SELECTOR(xType, shuffleKernelGeneric, + (launchDims, stream, dX, dxShape, dZ, N, shuffleMap, + tadOnlyShapeInfo, tadOffset), + LIBND4J_TYPES); + + sd::DebugHelper::checkErrorCode(stream, "shuffle(...) failed"); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } bool isExperimentalEnabled() { - return sd::Environment::getInstance().isExperimentalBuild(); + return sd::Environment::getInstance().isExperimentalBuild(); } void setOmpMinThreads(int threads) { - minThreads = sd::math::nd4j_max(32, threads); - minThreads = sd::math::nd4j_min(maxThreads, minThreads); + minThreads = sd::math::nd4j_max(32, threads); + minThreads = sd::math::nd4j_min(maxThreads, minThreads); } -int getDevice() { - return sd::AffinityManager::currentDeviceId(); -} +int getDevice() { return sd::AffinityManager::currentDeviceId(); } void setElementThreshold(int num) { - // this is no-op for CUDA + // this is no-op for CUDA } void setTADThreshold(int num) { - // this is no-op for CUDA + // this is no-op for CUDA } //////////////////////////////////////////////////////////////////////// -void execSummaryStats(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo, - bool biasCorrected) { - try { - InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); - - LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execSummaryStats(&lc, opNum, - dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hXShapeInfo).special(), - extraParams, - dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hZShapeInfo).special(), - biasCorrected); - - InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } +void execSummaryStats(Nd4jPointer *extraPointers, int opNum, + OpaqueDataBuffer *dbX, Nd4jLong const *hXShapeInfo, + Nd4jLong const *dXShapeInfo, void *extraParams, + OpaqueDataBuffer *dbZ, Nd4jLong const *hZShapeInfo, + Nd4jLong const *dZShapeInfo, bool biasCorrected) { + try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); + + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], + extraPointers[3]); + NativeOpExecutioner::execSummaryStats( + &lc, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + ConstantShapeHelper::getInstance() + .bufferForShapeInfo(hXShapeInfo) + .special(), + extraParams, dbZ->primary(), hZShapeInfo, dbZ->special(), + ConstantShapeHelper::getInstance() + .bufferForShapeInfo(hZShapeInfo) + .special(), + biasCorrected); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } //////////////////////////////////////////////////////////////////////// -void execSummaryStatsTad(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo, - OpaqueDataBuffer *dbDimension, Nd4jLong const* hDimensionShape, Nd4jLong const* dDimensionShape, - bool biasCorrected, - Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets) { - try { - InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbDimension}); - InteropDataBuffer::preparePrimaryUse({}, {dbDimension}); - - auto dimension = reinterpret_cast(dbDimension->primary()); - int dimensionLength = static_cast(shape::length(hDimensionShape)); - - LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execSummaryStats(&lc, opNum, - dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hXShapeInfo).special(), - extraParams, - dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hZShapeInfo).special(), - reinterpret_cast(dbDimension->special()), dimensionLength, - tadShapeInfo, tadOffsets, - biasCorrected); - - InteropDataBuffer::registerSpecialUse({dbZ}, {dbX, dbDimension}); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } +void execSummaryStatsTad(Nd4jPointer *extraPointers, int opNum, + OpaqueDataBuffer *dbX, Nd4jLong const *hXShapeInfo, + Nd4jLong const *dXShapeInfo, void *extraParams, + OpaqueDataBuffer *dbZ, Nd4jLong const *hZShapeInfo, + Nd4jLong const *dZShapeInfo, + OpaqueDataBuffer *dbDimension, + Nd4jLong const *hDimensionShape, + Nd4jLong const *dDimensionShape, bool biasCorrected, + Nd4jLong const *tadShapeInfo, + Nd4jLong const *tadOffsets) { + try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbDimension}); + InteropDataBuffer::preparePrimaryUse({}, {dbDimension}); + + auto dimension = reinterpret_cast(dbDimension->primary()); + int dimensionLength = static_cast(shape::length(hDimensionShape)); + + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], + extraPointers[3]); + NativeOpExecutioner::execSummaryStats( + &lc, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + ConstantShapeHelper::getInstance() + .bufferForShapeInfo(hXShapeInfo) + .special(), + extraParams, dbZ->primary(), hZShapeInfo, dbZ->special(), + ConstantShapeHelper::getInstance() + .bufferForShapeInfo(hZShapeInfo) + .special(), + reinterpret_cast(dbDimension->special()), dimensionLength, + tadShapeInfo, tadOffsets, biasCorrected); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX, dbDimension}); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } //////////////////////////////////////////////////////////////////////// -void execReduce3(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbY, Nd4jLong const* hYShapeInfo, Nd4jLong const* dYShapeInfo, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo) { - try { - InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbY}); - - LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execReduce3(&lc, opNum, - dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hXShapeInfo).special(), - extraParams, - dbY->primary(), hYShapeInfo, dbY->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hYShapeInfo).special(), - dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hZShapeInfo).special()); - - InteropDataBuffer::registerSpecialUse({dbZ}, {dbX, dbY}); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } +void execReduce3(Nd4jPointer *extraPointers, int opNum, OpaqueDataBuffer *dbX, + Nd4jLong const *hXShapeInfo, Nd4jLong const *dXShapeInfo, + void *extraParams, OpaqueDataBuffer *dbY, + Nd4jLong const *hYShapeInfo, Nd4jLong const *dYShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong const *hZShapeInfo, + Nd4jLong const *dZShapeInfo) { + try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbY}); + + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], + extraPointers[3]); + NativeOpExecutioner::execReduce3( + &lc, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + ConstantShapeHelper::getInstance() + .bufferForShapeInfo(hXShapeInfo) + .special(), + extraParams, dbY->primary(), hYShapeInfo, dbY->special(), + ConstantShapeHelper::getInstance() + .bufferForShapeInfo(hYShapeInfo) + .special(), + dbZ->primary(), hZShapeInfo, dbZ->special(), + ConstantShapeHelper::getInstance() + .bufferForShapeInfo(hZShapeInfo) + .special()); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX, dbY}); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } //////////////////////////////////////////////////////////////////////// -void execReduce3Tad(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbY, Nd4jLong const* hYShapeInfo, Nd4jLong const* dYShapeInfo, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo, - OpaqueDataBuffer *dbDimension, Nd4jLong const* hDimensionShape, Nd4jLong const* dDimensionShape, - Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, - Nd4jLong const* yTadOnlyShapeInfo, Nd4jLong const* yTadOffsets) { - try { - InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbY}); - InteropDataBuffer::preparePrimaryUse({}, {dbDimension}); - - auto dimension = reinterpret_cast(dbDimension->primary()); - int dimensionLength = static_cast(shape::length(hDimensionShape)); - - auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(hXShapeInfo, - dimension, - shape::length(hDimensionShape)); - auto tadLength = shape::length(tadPack.primaryShapeInfo()); - auto yLength = shape::length(hYShapeInfo); - auto xLength = shape::length(hXShapeInfo); - - LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - - if (tadLength == yLength || tadLength == xLength) { - // nd4j_printf("== way\n",""); - NativeOpExecutioner::execReduce3(&lc, opNum, - dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hXShapeInfo).special(), - extraParams, - dbY->primary(), hYShapeInfo, dbY->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hYShapeInfo).special(), - dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hZShapeInfo).special(), - dimension, dimensionLength, - tadOnlyShapeInfo, tadOffsets, yTadOnlyShapeInfo, yTadOffsets); - } else - NativeOpExecutioner::execReduce3TAD(&lc, opNum, - dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hXShapeInfo).special(), - extraParams, - dbY->primary(), hYShapeInfo, dbY->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hYShapeInfo).special(), - dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hZShapeInfo).special(), - dimension, dimensionLength, - tadOnlyShapeInfo, yTadOffsets, yTadOnlyShapeInfo, yTadOffsets); - - InteropDataBuffer::registerSpecialUse({dbZ}, {dbX, dbY}); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } +void execReduce3Tad( + Nd4jPointer *extraPointers, int opNum, OpaqueDataBuffer *dbX, + Nd4jLong const *hXShapeInfo, Nd4jLong const *dXShapeInfo, void *extraParams, + OpaqueDataBuffer *dbY, Nd4jLong const *hYShapeInfo, + Nd4jLong const *dYShapeInfo, OpaqueDataBuffer *dbZ, + Nd4jLong const *hZShapeInfo, Nd4jLong const *dZShapeInfo, + OpaqueDataBuffer *dbDimension, Nd4jLong const *hDimensionShape, + Nd4jLong const *dDimensionShape, Nd4jLong const *tadOnlyShapeInfo, + Nd4jLong const *tadOffsets, Nd4jLong const *yTadOnlyShapeInfo, + Nd4jLong const *yTadOffsets) { + try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbY}); + InteropDataBuffer::preparePrimaryUse({}, {dbDimension}); + + auto dimension = reinterpret_cast(dbDimension->primary()); + int dimensionLength = static_cast(shape::length(hDimensionShape)); + + auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions( + hXShapeInfo, dimension, shape::length(hDimensionShape)); + auto tadLength = shape::length(tadPack.primaryShapeInfo()); + auto yLength = shape::length(hYShapeInfo); + auto xLength = shape::length(hXShapeInfo); + + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], + extraPointers[3]); + + if (tadLength == yLength || tadLength == xLength) { + // nd4j_printf("== way\n",""); + NativeOpExecutioner::execReduce3( + &lc, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + ConstantShapeHelper::getInstance() + .bufferForShapeInfo(hXShapeInfo) + .special(), + extraParams, dbY->primary(), hYShapeInfo, dbY->special(), + ConstantShapeHelper::getInstance() + .bufferForShapeInfo(hYShapeInfo) + .special(), + dbZ->primary(), hZShapeInfo, dbZ->special(), + ConstantShapeHelper::getInstance() + .bufferForShapeInfo(hZShapeInfo) + .special(), + dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, + yTadOnlyShapeInfo, yTadOffsets); + } else + NativeOpExecutioner::execReduce3TAD( + &lc, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + ConstantShapeHelper::getInstance() + .bufferForShapeInfo(hXShapeInfo) + .special(), + extraParams, dbY->primary(), hYShapeInfo, dbY->special(), + ConstantShapeHelper::getInstance() + .bufferForShapeInfo(hYShapeInfo) + .special(), + dbZ->primary(), hZShapeInfo, dbZ->special(), + ConstantShapeHelper::getInstance() + .bufferForShapeInfo(hZShapeInfo) + .special(), + dimension, dimensionLength, tadOnlyShapeInfo, yTadOffsets, + yTadOnlyShapeInfo, yTadOffsets); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX, dbY}); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } //////////////////////////////////////////////////////////////////////// -void execReduce3Scalar(Nd4jPointer *extraPointers,int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbY, Nd4jLong const* hYShapeInfo, Nd4jLong const* dYShapeInfo, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo) { - try { - InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbY}); - - LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execReduce3Scalar(&lc, opNum, - dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hXShapeInfo).special(), - extraParams, - dbY->primary(), hYShapeInfo, dbY->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hYShapeInfo).special(), - dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hZShapeInfo).special()); - - InteropDataBuffer::registerSpecialUse({dbZ}, {dbX, dbY}); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } +void execReduce3Scalar(Nd4jPointer *extraPointers, int opNum, + OpaqueDataBuffer *dbX, Nd4jLong const *hXShapeInfo, + Nd4jLong const *dXShapeInfo, void *extraParams, + OpaqueDataBuffer *dbY, Nd4jLong const *hYShapeInfo, + Nd4jLong const *dYShapeInfo, OpaqueDataBuffer *dbZ, + Nd4jLong const *hZShapeInfo, + Nd4jLong const *dZShapeInfo) { + try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbY}); + + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], + extraPointers[3]); + NativeOpExecutioner::execReduce3Scalar( + &lc, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + ConstantShapeHelper::getInstance() + .bufferForShapeInfo(hXShapeInfo) + .special(), + extraParams, dbY->primary(), hYShapeInfo, dbY->special(), + ConstantShapeHelper::getInstance() + .bufferForShapeInfo(hYShapeInfo) + .special(), + dbZ->primary(), hZShapeInfo, dbZ->special(), + ConstantShapeHelper::getInstance() + .bufferForShapeInfo(hZShapeInfo) + .special()); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX, dbY}); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } //////////////////////////////////////////////////////////////////////// -void execScalarBool(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo, - OpaqueDataBuffer *dbScalar, Nd4jLong const* hScalarShapeInfo, Nd4jLong const* dScalarShapeInfo, - void *extraParams) { - try { - InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbScalar}); - - LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execScalarBool(&lc, opNum, - dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hXShapeInfo).special(), - dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hZShapeInfo).special(), - dbScalar->primary(), hScalarShapeInfo, dbScalar->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hScalarShapeInfo).special(), - extraParams); - - InteropDataBuffer::registerSpecialUse({dbZ}, {dbX, dbScalar}); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } +void execScalarBool(Nd4jPointer *extraPointers, int opNum, + OpaqueDataBuffer *dbX, Nd4jLong const *hXShapeInfo, + Nd4jLong const *dXShapeInfo, OpaqueDataBuffer *dbZ, + Nd4jLong const *hZShapeInfo, Nd4jLong const *dZShapeInfo, + OpaqueDataBuffer *dbScalar, + Nd4jLong const *hScalarShapeInfo, + Nd4jLong const *dScalarShapeInfo, void *extraParams) { + try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbScalar}); + + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], + extraPointers[3]); + NativeOpExecutioner::execScalarBool( + &lc, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + ConstantShapeHelper::getInstance() + .bufferForShapeInfo(hXShapeInfo) + .special(), + dbZ->primary(), hZShapeInfo, dbZ->special(), + ConstantShapeHelper::getInstance() + .bufferForShapeInfo(hZShapeInfo) + .special(), + dbScalar->primary(), hScalarShapeInfo, dbScalar->special(), + ConstantShapeHelper::getInstance() + .bufferForShapeInfo(hScalarShapeInfo) + .special(), + extraParams); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX, dbScalar}); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } //////////////////////////////////////////////////////////////////////// -void execScalarBoolTad(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo, - OpaqueDataBuffer *dbScalars, Nd4jLong const* hScalarShapeInfo, Nd4jLong const* dScalarShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbDimension, Nd4jLong const* hDimensionShape, Nd4jLong const* dDimensionShape, - Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, - Nd4jLong const* tadShapeInfoZ, Nd4jLong const* tadOffsetsZ) { - try { - InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbScalars}); - InteropDataBuffer::preparePrimaryUse({}, {dbDimension}); - - auto dimension = reinterpret_cast(dbDimension->primary()); - int dimensionLength = static_cast(shape::length(hDimensionShape)); - - LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execScalarBool(&lc, opNum, - dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hXShapeInfo).special(), - extraParams, - dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hZShapeInfo).special(), - dbScalars->primary(), hScalarShapeInfo, dbScalars->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hScalarShapeInfo).special(), - dimension, dimensionLength, - tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ); - - InteropDataBuffer::registerSpecialUse({dbZ}, {dbX, dbScalars}); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } +void execScalarBoolTad( + Nd4jPointer *extraPointers, int opNum, OpaqueDataBuffer *dbX, + Nd4jLong const *hXShapeInfo, Nd4jLong const *dXShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong const *hZShapeInfo, + Nd4jLong const *dZShapeInfo, OpaqueDataBuffer *dbScalars, + Nd4jLong const *hScalarShapeInfo, Nd4jLong const *dScalarShapeInfo, + void *extraParams, OpaqueDataBuffer *dbDimension, + Nd4jLong const *hDimensionShape, Nd4jLong const *dDimensionShape, + Nd4jLong const *tadShapeInfo, Nd4jLong const *tadOffsets, + Nd4jLong const *tadShapeInfoZ, Nd4jLong const *tadOffsetsZ) { + try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbScalars}); + InteropDataBuffer::preparePrimaryUse({}, {dbDimension}); + + auto dimension = reinterpret_cast(dbDimension->primary()); + int dimensionLength = static_cast(shape::length(hDimensionShape)); + + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], + extraPointers[3]); + NativeOpExecutioner::execScalarBool( + &lc, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + ConstantShapeHelper::getInstance() + .bufferForShapeInfo(hXShapeInfo) + .special(), + extraParams, dbZ->primary(), hZShapeInfo, dbZ->special(), + ConstantShapeHelper::getInstance() + .bufferForShapeInfo(hZShapeInfo) + .special(), + dbScalars->primary(), hScalarShapeInfo, dbScalars->special(), + ConstantShapeHelper::getInstance() + .bufferForShapeInfo(hScalarShapeInfo) + .special(), + dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, + tadOffsetsZ); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX, dbScalars}); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } //////////////////////////////////////////////////////////////////////// -void execScalar(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo, - OpaqueDataBuffer *dbScalar, Nd4jLong const* hScalarShapeInfo, Nd4jLong const* dScalarShapeInfo, - void *extraParams) { - try { - InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbScalar}); - - LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execScalar(&lc, opNum, - dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hXShapeInfo).special(), - dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hZShapeInfo).special(), - dbScalar->primary(), hScalarShapeInfo, dbScalar->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hScalarShapeInfo).special(), - extraParams); - - InteropDataBuffer::registerSpecialUse({dbZ}, {dbX, dbScalar}); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } +void execScalar(Nd4jPointer *extraPointers, int opNum, OpaqueDataBuffer *dbX, + Nd4jLong const *hXShapeInfo, Nd4jLong const *dXShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong const *hZShapeInfo, + Nd4jLong const *dZShapeInfo, OpaqueDataBuffer *dbScalar, + Nd4jLong const *hScalarShapeInfo, + Nd4jLong const *dScalarShapeInfo, void *extraParams) { + try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbScalar}); + + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], + extraPointers[3]); + NativeOpExecutioner::execScalar( + &lc, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + ConstantShapeHelper::getInstance() + .bufferForShapeInfo(hXShapeInfo) + .special(), + dbZ->primary(), hZShapeInfo, dbZ->special(), + ConstantShapeHelper::getInstance() + .bufferForShapeInfo(hZShapeInfo) + .special(), + dbScalar->primary(), hScalarShapeInfo, dbScalar->special(), + ConstantShapeHelper::getInstance() + .bufferForShapeInfo(hScalarShapeInfo) + .special(), + extraParams); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX, dbScalar}); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } //////////////////////////////////////////////////////////////////////// -void execScalarTad(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo, - OpaqueDataBuffer *dbScalars, Nd4jLong const* hScalarShapeInfo, Nd4jLong const* dScalarShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbDimension, Nd4jLong const* hDimensionShape, Nd4jLong const* dDimensionShape, - Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, - Nd4jLong const* tadShapeInfoZ, Nd4jLong const* tadOffsetsZ) { - try { - InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbScalars}); - InteropDataBuffer::preparePrimaryUse({}, {dbDimension}); - - auto dimension = reinterpret_cast(dbDimension->primary()); - int dimensionLength = static_cast(shape::length(hDimensionShape)); +void execScalarTad(Nd4jPointer *extraPointers, int opNum, OpaqueDataBuffer *dbX, + Nd4jLong const *hXShapeInfo, Nd4jLong const *dXShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong const *hZShapeInfo, + Nd4jLong const *dZShapeInfo, OpaqueDataBuffer *dbScalars, + Nd4jLong const *hScalarShapeInfo, + Nd4jLong const *dScalarShapeInfo, void *extraParams, + OpaqueDataBuffer *dbDimension, + Nd4jLong const *hDimensionShape, + Nd4jLong const *dDimensionShape, + Nd4jLong const *tadShapeInfo, Nd4jLong const *tadOffsets, + Nd4jLong const *tadShapeInfoZ, Nd4jLong const *tadOffsetsZ) { + try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbScalars}); + InteropDataBuffer::preparePrimaryUse({}, {dbDimension}); + + auto dimension = reinterpret_cast(dbDimension->primary()); + int dimensionLength = static_cast(shape::length(hDimensionShape)); - cudaStream_t *stream = reinterpret_cast(extraPointers[1]); + cudaStream_t *stream = reinterpret_cast(extraPointers[1]); - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto yType = sd::ArrayOptions::dataType(hScalarShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto yType = sd::ArrayOptions::dataType(hScalarShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - if (yType != xType && yType != sd::DataType::BOOL && !isExperimentalEnabled()) - throw sd::datatype_exception::build("execScalar both operands must have same data type", xType, yType); + if (yType != xType && yType != sd::DataType::BOOL && + !isExperimentalEnabled()) + throw sd::datatype_exception::build( + "execScalar both operands must have same data type", xType, yType); - dim3 launchDims(256, 256, 16384); + dim3 launchDims(256, 256, 16384); #ifdef __ND4J_EXPERIMENTAL__ - BUILD_PAIRWISE_SELECTOR(xType, yType, zType, functions::scalar::ScalarTransform, ::executeCudaAlongDimension(launchDims, stream, opNum, dX, dXShapeInfo, dZ, dZShapeInfo, dScalars, extraParams, dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES, LIBND4J_TYPES); + BUILD_PAIRWISE_SELECTOR( + xType, yType, zType, functions::scalar::ScalarTransform, + ::executeCudaAlongDimension(launchDims, stream, opNum, dX, dXShapeInfo, + dZ, dZShapeInfo, dScalars, extraParams, + dimension, dimensionLength, tadShapeInfo, + tadOffsets, tadShapeInfoZ, tadOffsetsZ), + LIBND4J_TYPES, LIBND4J_TYPES); #else - BUILD_SINGLE_SELECTOR_THRICE(xType, functions::scalar::ScalarTransform, ::executeCudaAlongDimension(launchDims, stream, opNum, dbX->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hXShapeInfo).special(), dbZ->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hZShapeInfo).special(), dbScalars->special(), extraParams, dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR_THRICE( + xType, functions::scalar::ScalarTransform, + ::executeCudaAlongDimension(launchDims, stream, opNum, dbX->special(), + ConstantShapeHelper::getInstance() + .bufferForShapeInfo(hXShapeInfo) + .special(), + dbZ->special(), + ConstantShapeHelper::getInstance() + .bufferForShapeInfo(hZShapeInfo) + .special(), + dbScalars->special(), extraParams, + dimension, dimensionLength, tadShapeInfo, + tadOffsets, tadShapeInfoZ, tadOffsetsZ), + LIBND4J_TYPES); #endif - DEBUG_KERNEL(stream, opNum); - - InteropDataBuffer::registerSpecialUse({dbZ}, {dbX, dbScalars}); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } -} - -void execAggregate(Nd4jPointer *extraPointers, - int opNum, - void **arguments, - int numArguments, - Nd4jLong **shapes, - int numShapes, - int *indexArguments, - int numIndexArguments, - int **intArrays, - int numIntArrays, - void *realArguments, - int numRealArguments, - sd::DataType dtype) { + DEBUG_KERNEL(stream, opNum); + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX, dbScalars}); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } -void batchExecutor(Nd4jPointer *extraPointers, - int numAggregates, - int opNum, - int maxArgs, - int maxShapes, - int maxIntArrays, - int maxIntArraySize, - int maxIdx, - int maxReals, - void *ptrToArguments, - sd::DataType dtype) { -} +void execAggregate(Nd4jPointer *extraPointers, int opNum, void **arguments, + int numArguments, Nd4jLong **shapes, int numShapes, + int *indexArguments, int numIndexArguments, int **intArrays, + int numIntArrays, void *realArguments, int numRealArguments, + sd::DataType dtype) {} -void execAggregateBatch(Nd4jPointer *extraPointers, - int numAggregates, int opNum, - int maxArgs, int maxShapes, - int maxIntArrays, int maxIntArraySize, - int maxIdx, int maxReals, - void *ptrToArguments, sd::DataType dtype) { +void batchExecutor(Nd4jPointer *extraPointers, int numAggregates, int opNum, + int maxArgs, int maxShapes, int maxIntArrays, + int maxIntArraySize, int maxIdx, int maxReals, + void *ptrToArguments, sd::DataType dtype) {} -} +void execAggregateBatch(Nd4jPointer *extraPointers, int numAggregates, + int opNum, int maxArgs, int maxShapes, int maxIntArrays, + int maxIntArraySize, int maxIdx, int maxReals, + void *ptrToArguments, sd::DataType dtype) {} //////////////////////////////////////////////////////////////////////// -void execRandom(Nd4jPointer *extraPointers, - int opNum, - Nd4jPointer stateHost, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo, - void *extraArguments) { - try { - InteropDataBuffer::prepareSpecialUse({dbZ}, {}); - - LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execRandom(&lc, opNum, stateHost, - dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hZShapeInfo).special(), - extraArguments); - - InteropDataBuffer::registerSpecialUse({dbZ}, {}); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } +void execRandom(Nd4jPointer *extraPointers, int opNum, Nd4jPointer stateHost, + OpaqueDataBuffer *dbZ, Nd4jLong const *hZShapeInfo, + Nd4jLong const *dZShapeInfo, void *extraArguments) { + try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {}); + + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], + extraPointers[3]); + NativeOpExecutioner::execRandom(&lc, opNum, stateHost, dbZ->primary(), + hZShapeInfo, dbZ->special(), + ConstantShapeHelper::getInstance() + .bufferForShapeInfo(hZShapeInfo) + .special(), + extraArguments); + + InteropDataBuffer::registerSpecialUse({dbZ}, {}); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } //////////////////////////////////////////////////////////////////////// void execRandom2(Nd4jPointer *extraPointers, int opNum, Nd4jPointer stateHost, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo, - void *extraArguments) { - try { - InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); - - LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execRandom(&lc, opNum, stateHost, - dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hXShapeInfo).special(), - dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hZShapeInfo).special(), - extraArguments); - - InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } + OpaqueDataBuffer *dbX, Nd4jLong const *hXShapeInfo, + Nd4jLong const *dXShapeInfo, OpaqueDataBuffer *dbZ, + Nd4jLong const *hZShapeInfo, Nd4jLong const *dZShapeInfo, + void *extraArguments) { + try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); + + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], + extraPointers[3]); + NativeOpExecutioner::execRandom(&lc, opNum, stateHost, dbX->primary(), + hXShapeInfo, dbX->special(), + ConstantShapeHelper::getInstance() + .bufferForShapeInfo(hXShapeInfo) + .special(), + dbZ->primary(), hZShapeInfo, dbZ->special(), + ConstantShapeHelper::getInstance() + .bufferForShapeInfo(hZShapeInfo) + .special(), + extraArguments); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } //////////////////////////////////////////////////////////////////////// void execRandom3(Nd4jPointer *extraPointers, int opNum, Nd4jPointer stateHost, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - OpaqueDataBuffer *dbY, Nd4jLong const* hYShapeInfo, Nd4jLong const* dYShapeInfo, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo, - void *extraArguments) { - try { - InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbY}); - - LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execRandom(&lc, opNum, stateHost, - dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hXShapeInfo).special(), - dbY->primary(), hYShapeInfo, dbY->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hYShapeInfo).special(), - dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hZShapeInfo).special(), - extraArguments); - - InteropDataBuffer::registerSpecialUse({dbZ}, {dbX, dbY}); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } + OpaqueDataBuffer *dbX, Nd4jLong const *hXShapeInfo, + Nd4jLong const *dXShapeInfo, OpaqueDataBuffer *dbY, + Nd4jLong const *hYShapeInfo, Nd4jLong const *dYShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong const *hZShapeInfo, + Nd4jLong const *dZShapeInfo, void *extraArguments) { + try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbY}); + + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], + extraPointers[3]); + NativeOpExecutioner::execRandom(&lc, opNum, stateHost, dbX->primary(), + hXShapeInfo, dbX->special(), + ConstantShapeHelper::getInstance() + .bufferForShapeInfo(hXShapeInfo) + .special(), + dbY->primary(), hYShapeInfo, dbY->special(), + ConstantShapeHelper::getInstance() + .bufferForShapeInfo(hYShapeInfo) + .special(), + dbZ->primary(), hZShapeInfo, dbZ->special(), + ConstantShapeHelper::getInstance() + .bufferForShapeInfo(hZShapeInfo) + .special(), + extraArguments); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX, dbY}); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } +} + +Nd4jPointer initRandom(Nd4jPointer *extraPointers, long seed, long bufferSize, + Nd4jPointer ptrToBuffer) { + unsigned long long *ptrHost = + reinterpret_cast(extraPointers[0]); + cudaStream_t *stream = reinterpret_cast(extraPointers[1]); + + // we don't synchronize at random initialization, it's safe to go unsync here + // cudaStreamSynchronize(*stream); + + auto ptrDev = reinterpret_cast(ptrToBuffer); + auto buffer = new sd::random::RandomBuffer( + seed, bufferSize, reinterpret_cast(ptrHost), + reinterpret_cast(ptrDev)); + buffer->propagateToDevice(buffer, *stream); + + sd::DebugHelper::checkErrorCode(stream, "initRandom(...) failed A"); + + // we generate sequence in the host memory + sd::random::Xoroshiro128 generator(buffer); + generator.refreshBuffer(); + + // and copy it to gpu + cudaMemcpyAsync(ptrDev, ptrHost, bufferSize * 8, cudaMemcpyHostToDevice, + *stream); + sd::DebugHelper::checkErrorCode(stream, "initRandom(...) failed B"); + + return buffer; } - -Nd4jPointer initRandom(Nd4jPointer *extraPointers, long seed, long bufferSize, Nd4jPointer ptrToBuffer) { - - unsigned long long *ptrHost = reinterpret_cast(extraPointers[0]); - cudaStream_t *stream = reinterpret_cast(extraPointers[1]); - - // we don't synchronize at random initialization, it's safe to go unsync here - // cudaStreamSynchronize(*stream); - - auto ptrDev = reinterpret_cast(ptrToBuffer); - auto buffer = new sd::random::RandomBuffer(seed, bufferSize, reinterpret_cast(ptrHost), reinterpret_cast(ptrDev)); - buffer->propagateToDevice(buffer, *stream); - - sd::DebugHelper::checkErrorCode(stream, "initRandom(...) failed A"); - - // we generate sequence in the host memory - sd::random::Xoroshiro128 generator(buffer); - generator.refreshBuffer(); - - // and copy it to gpu - cudaMemcpyAsync(ptrDev, ptrHost, bufferSize * 8, cudaMemcpyHostToDevice, *stream); - sd::DebugHelper::checkErrorCode(stream, "initRandom(...) failed B"); - - return buffer; -} - - void destroyRandom(Nd4jPointer ptrBuffer) { + sd::random::RandomBuffer *buffer = + reinterpret_cast(ptrBuffer); - sd::random::RandomBuffer *buffer = reinterpret_cast (ptrBuffer); + // FIXME: it's bad thing, but we can't know in advance, which stream(s) where + // using this generator in practice + cudaDeviceSynchronize(); - // FIXME: it's bad thing, but we can't know in advance, which stream(s) where using this generator in practice - cudaDeviceSynchronize(); - - delete buffer; + delete buffer; } -void refreshBuffer(Nd4jPointer *extraPointers, long seed, Nd4jPointer ptrRandom) { - - sd::random::RandomBuffer *buffer = reinterpret_cast (ptrRandom); +void refreshBuffer(Nd4jPointer *extraPointers, long seed, + Nd4jPointer ptrRandom) { + sd::random::RandomBuffer *buffer = + reinterpret_cast(ptrRandom); - unsigned long long *ptrHost = reinterpret_cast(extraPointers[0]); - cudaStream_t *stream = reinterpret_cast(extraPointers[1]); - cudaStreamSynchronize(*stream); + unsigned long long *ptrHost = + reinterpret_cast(extraPointers[0]); + cudaStream_t *stream = reinterpret_cast(extraPointers[1]); + cudaStreamSynchronize(*stream); - uint64_t *ptrDev = buffer->getDeviceBuffer(); + uint64_t *ptrDev = buffer->getDeviceBuffer(); - // update rng state - buffer->setSeed(seed); - buffer->setOffset(0); - buffer->propagateToDevice(buffer, *stream); + // update rng state + buffer->setSeed(seed); + buffer->setOffset(0); + buffer->propagateToDevice(buffer, *stream); - // refresh buffer on host size - sd::random::Xoroshiro128 generator(buffer); - generator.refreshBuffer(); + // refresh buffer on host size + sd::random::Xoroshiro128 generator(buffer); + generator.refreshBuffer(); - // copy back to gpu - cudaMemcpyAsync(ptrDev, ptrHost, buffer->getSize() * 8, cudaMemcpyHostToDevice, *stream); + // copy back to gpu + cudaMemcpyAsync(ptrDev, ptrHost, buffer->getSize() * 8, + cudaMemcpyHostToDevice, *stream); } -void reSeedBuffer(Nd4jPointer *extraPointers, long seed, Nd4jPointer ptrRandom) { +void reSeedBuffer(Nd4jPointer *extraPointers, long seed, + Nd4jPointer ptrRandom) { + sd::random::RandomBuffer *buffer = + reinterpret_cast(ptrRandom); - sd::random::RandomBuffer *buffer = reinterpret_cast (ptrRandom); + cudaStream_t *stream = reinterpret_cast(extraPointers[1]); + cudaStreamSynchronize(*stream); - cudaStream_t *stream = reinterpret_cast(extraPointers[1]); - cudaStreamSynchronize(*stream); - - // update rng state - buffer->reSeed(seed); - buffer->setOffset(0); - buffer->propagateToDevice(buffer, *stream); + // update rng state + buffer->reSeed(seed); + buffer->setOffset(0); + buffer->propagateToDevice(buffer, *stream); } - - /** - * Return the length of a shape buffer - * based on the pointer - * @param buffer the buffer pointer to check - * @return - */ + * Return the length of a shape buffer + * based on the pointer + * @param buffer the buffer pointer to check + * @return + */ int lengthForShapeBufferPointer(Nd4jPointer buffer) { - auto shapeBuffer = reinterpret_cast(buffer); - return shape::shapeInfoLength(shape::rank(shapeBuffer)); + auto shapeBuffer = reinterpret_cast(buffer); + return shape::shapeInfoLength(shape::rank(shapeBuffer)); } - /** - * The pointer to get the address for - * - * @param address the address to get the pointer - * @return the pointer for the given address - */ + * The pointer to get the address for + * + * @param address the address to get the pointer + * @return the pointer for the given address + */ Nd4jPointer pointerForAddress(Nd4jLong address) { - return reinterpret_cast(address); -} - -void tear(Nd4jPointer *extras, - OpaqueDataBuffer *dbX, Nd4jLong const* xShapeInfo, Nd4jLong const* dXShapeInfo, - Nd4jPointer *targets, - Nd4jLong const* zShapeInfo, - Nd4jLong const* tadShapeInfo, - Nd4jLong const* tadOffsets) { - try { - InteropDataBuffer::prepareSpecialUse({}, {dbX}); - - cudaStream_t *stream = reinterpret_cast(extras[1]); - dim3 launchDims(512, 512, 512); - auto xType = sd::ArrayOptions::dataType(xShapeInfo); - BUILD_SINGLE_SELECTOR(xType, tearKernelGeneric, - (launchDims, stream, dbX->special(), dXShapeInfo, targets, zShapeInfo, tadShapeInfo, tadOffsets), - LIBND4J_TYPES); - - sd::DebugHelper::checkErrorCode(stream, "tearFloat(...) failed"); - - InteropDataBuffer::registerSpecialUse({}, {dbX}); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } -} - - -void prescanArrayRecursive(Nd4jPointer *extras, int *dZ, int *dX, int numElements, int level) { - - auto stream = reinterpret_cast(extras[1]); - auto g_scanBlockSums = reinterpret_cast(extras[2]); - - int blockSize = 512; // max size of the thread blocks - int numBlocks = sd::math::nd4j_max(1, static_cast(ceil(static_cast(numElements) / (2.f * blockSize)))); - int numThreads; - - if (numBlocks > 1) - numThreads = blockSize; - else if (sd::isPowerOfTwo(numElements)) - numThreads = numElements / 2; - else - numThreads = sd::floorPow2(numElements); - - int numEltsPerBlock = numThreads * 2; - - // if this is a non-power-of-2 array, the last block will be non-full - // compute the smallest power of 2 able to compute its scan. - int numEltsLastBlock = - numElements - (numBlocks-1) * numEltsPerBlock; - int numThreadsLastBlock = sd::math::nd4j_max(1, numEltsLastBlock / 2); - int np2LastBlock = 0; - int sharedMemLastBlock = 0; - - if (numEltsLastBlock != numEltsPerBlock) { - np2LastBlock = 1; - - if(!isPowerOfTwo(numEltsLastBlock)) - numThreadsLastBlock = floorPow2(numEltsLastBlock); - - unsigned int extraSpace = (2 * numThreadsLastBlock) / NUM_BANKS; - sharedMemLastBlock = sizeof(int) * (2 * numThreadsLastBlock + extraSpace); + return reinterpret_cast(address); +} + +void tear(Nd4jPointer *extras, OpaqueDataBuffer *dbX, + Nd4jLong const *xShapeInfo, Nd4jLong const *dXShapeInfo, + Nd4jPointer *targets, Nd4jLong const *zShapeInfo, + Nd4jLong const *tadShapeInfo, Nd4jLong const *tadOffsets) { + try { + InteropDataBuffer::prepareSpecialUse({}, {dbX}); + + cudaStream_t *stream = reinterpret_cast(extras[1]); + dim3 launchDims(512, 512, 512); + auto xType = sd::ArrayOptions::dataType(xShapeInfo); + BUILD_SINGLE_SELECTOR(xType, tearKernelGeneric, + (launchDims, stream, dbX->special(), dXShapeInfo, + targets, zShapeInfo, tadShapeInfo, tadOffsets), + LIBND4J_TYPES); + + sd::DebugHelper::checkErrorCode(stream, "tearFloat(...) failed"); + + InteropDataBuffer::registerSpecialUse({}, {dbX}); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } +} + +void prescanArrayRecursive(Nd4jPointer *extras, int *dZ, int *dX, + int numElements, int level) { + auto stream = reinterpret_cast(extras[1]); + auto g_scanBlockSums = reinterpret_cast(extras[2]); + + int blockSize = 512; // max size of the thread blocks + int numBlocks = sd::math::nd4j_max( + 1, static_cast( + ceil(static_cast(numElements) / (2.f * blockSize)))); + int numThreads; + + if (numBlocks > 1) + numThreads = blockSize; + else if (sd::isPowerOfTwo(numElements)) + numThreads = numElements / 2; + else + numThreads = sd::floorPow2(numElements); + + int numEltsPerBlock = numThreads * 2; + + // if this is a non-power-of-2 array, the last block will be non-full + // compute the smallest power of 2 able to compute its scan. + int numEltsLastBlock = numElements - (numBlocks - 1) * numEltsPerBlock; + int numThreadsLastBlock = sd::math::nd4j_max(1, numEltsLastBlock / 2); + int np2LastBlock = 0; + int sharedMemLastBlock = 0; + + if (numEltsLastBlock != numEltsPerBlock) { + np2LastBlock = 1; + + if (!isPowerOfTwo(numEltsLastBlock)) + numThreadsLastBlock = floorPow2(numEltsLastBlock); + + unsigned int extraSpace = (2 * numThreadsLastBlock) / NUM_BANKS; + sharedMemLastBlock = sizeof(int) * (2 * numThreadsLastBlock + extraSpace); + } + + // padding space is used to avoid shared memory bank conflicts + int extraSpace = numEltsPerBlock / NUM_BANKS; + int sharedMemSize = sizeof(int) * (numEltsPerBlock + extraSpace); + + // setup execution parameters + // if NP2, we process the last block separately + dim3 grid(max(1, numBlocks - np2LastBlock), 1, 1); + dim3 threads(numThreads, 1, 1); + dim3 gridOnes(1, 1, 1); + dim3 threadsOnes(numThreadsLastBlock, 1, 1); + + if (sharedMemSize < 2048) sharedMemSize = 2048; + + if (sharedMemLastBlock < 2048) sharedMemLastBlock = 2048; + + // execute the scan + if (numBlocks > 1) { + sd::prescanLauncher(grid, threads, sharedMemSize, stream, dZ, + dX, g_scanBlockSums[level], numThreads * 2, + 0, 0); + if (np2LastBlock) { + sd::prescanLauncher(gridOnes, threadsOnes, sharedMemLastBlock, + stream, dZ, dX, g_scanBlockSums[level], + numEltsLastBlock, numBlocks - 1, + numElements - numEltsLastBlock); } - // padding space is used to avoid shared memory bank conflicts - int extraSpace = numEltsPerBlock / NUM_BANKS; - int sharedMemSize = sizeof(int) * (numEltsPerBlock + extraSpace); - - // setup execution parameters - // if NP2, we process the last block separately - dim3 grid(max(1, numBlocks - np2LastBlock), 1, 1); - dim3 threads(numThreads, 1, 1); - dim3 gridOnes(1, 1, 1); - dim3 threadsOnes(numThreadsLastBlock, 1, 1); - - if (sharedMemSize < 2048) - sharedMemSize = 2048; - - if (sharedMemLastBlock < 2048) - sharedMemLastBlock = 2048; - - // execute the scan - if (numBlocks > 1) { - sd::prescanLauncher(grid, threads, sharedMemSize, stream, dZ, dX, g_scanBlockSums[level], numThreads * 2, 0, 0); - if (np2LastBlock) { - sd::prescanLauncher(gridOnes, threadsOnes, sharedMemLastBlock, stream, dZ, dX, g_scanBlockSums[level], numEltsLastBlock, numBlocks - 1, numElements - numEltsLastBlock); - } - - // After scanning all the sub-blocks, we are mostly done. But now we - // need to take all of the last values of the sub-blocks and scan those. - // This will give us a new value that must be sdded to each block to - // get the final results. - // recursive (CPU) call - prescanArrayRecursive(extras, g_scanBlockSums[level], g_scanBlockSums[level], numBlocks, level+1); - - sd::uniformAdd<<>>(dZ, g_scanBlockSums[level], numElements - numEltsLastBlock, 0, 0); - - if (np2LastBlock) { - sd::uniformAdd<<<1, numThreadsLastBlock, 1024, *stream>>>(dZ, g_scanBlockSums[level], numEltsLastBlock, numBlocks - 1, numElements - numEltsLastBlock); - } - } else if (isPowerOfTwo(numElements)) { - sd::prescanLauncher(grid, threads, sharedMemSize, stream, dZ, dX, 0, numThreads * 2, 0, 0); - } else { - sd::prescanLauncher(grid, threads, sharedMemSize, stream, dZ, dX, 0, numElements, 0, 0); + // After scanning all the sub-blocks, we are mostly done. But now we + // need to take all of the last values of the sub-blocks and scan those. + // This will give us a new value that must be sdded to each block to + // get the final results. + // recursive (CPU) call + prescanArrayRecursive(extras, g_scanBlockSums[level], + g_scanBlockSums[level], numBlocks, level + 1); + + sd::uniformAdd<<>>( + dZ, g_scanBlockSums[level], numElements - numEltsLastBlock, 0, 0); + + if (np2LastBlock) { + sd::uniformAdd<<<1, numThreadsLastBlock, 1024, *stream>>>( + dZ, g_scanBlockSums[level], numEltsLastBlock, numBlocks - 1, + numElements - numEltsLastBlock); } + } else if (isPowerOfTwo(numElements)) { + sd::prescanLauncher(grid, threads, sharedMemSize, stream, dZ, + dX, 0, numThreads * 2, 0, 0); + } else { + sd::prescanLauncher(grid, threads, sharedMemSize, stream, dZ, + dX, 0, numElements, 0, 0); + } - sd::DebugHelper::checkErrorCode(stream, "prescanArray(...) failed"); + sd::DebugHelper::checkErrorCode(stream, "prescanArray(...) failed"); } //////////////////////////////////////////////////////////////////////// -void execReduce3All(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - void *extraParamsVals, - OpaqueDataBuffer *dbY, Nd4jLong const* hYShapeInfo, Nd4jLong const* dYShapeInfo, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo, - OpaqueDataBuffer *dbDimension, Nd4jLong const* hDimensionShape, Nd4jLong const* dDimensionShape, - Nd4jLong const* xTadShapeInfo, Nd4jLong const* xOffsets, - Nd4jLong const* yTadShapeInfo, Nd4jLong const* yOffsets) { - try { - InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbY, dbDimension}); - InteropDataBuffer::preparePrimaryUse({}, {dbDimension}); - - auto dimension = reinterpret_cast(dbDimension->primary()); - int dimensionLength = static_cast(shape::length(hDimensionShape)); - - LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execReduce3All(&lc, opNum, - dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hXShapeInfo).special(), - extraParamsVals, - dbY->primary(), hYShapeInfo, dbY->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hYShapeInfo).special(), - dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hZShapeInfo).special(), - reinterpret_cast(dbDimension->special()), dimensionLength, - xTadShapeInfo, xOffsets, yTadShapeInfo, yOffsets); - - InteropDataBuffer::registerSpecialUse({dbZ}, {dbX, dbY}); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } -} - - -void sort(Nd4jPointer *extraPointers, - void *x, Nd4jLong const* xShapeInfo, - void *dX, Nd4jLong const* dXShapeInfo, - bool descending) { - try { - cudaStream_t *stream = reinterpret_cast(extraPointers[1]); - - auto xLength = shape::length(xShapeInfo); - auto xEWS = shape::elementWiseStride(xShapeInfo); - auto xType = sd::ArrayOptions::dataType(xShapeInfo); - - - // check if xLength is a power of 2, and use bitonic sort, if that's the case - if ((xLength != 0) && ((xLength & (xLength - 1)) == 0) && (xLength <= 1024 * 1024 * 10)) { - int numThreads = sd::math::nd4j_min(512, xLength); - int numBlocks = xLength / numThreads; - if (xLength % numThreads > 0 || numBlocks == 0) - numBlocks++; - - dim3 launchDims(numBlocks, numThreads, 32768); +void execReduce3All(Nd4jPointer *extraPointers, int opNum, + OpaqueDataBuffer *dbX, Nd4jLong const *hXShapeInfo, + Nd4jLong const *dXShapeInfo, void *extraParamsVals, + OpaqueDataBuffer *dbY, Nd4jLong const *hYShapeInfo, + Nd4jLong const *dYShapeInfo, OpaqueDataBuffer *dbZ, + Nd4jLong const *hZShapeInfo, Nd4jLong const *dZShapeInfo, + OpaqueDataBuffer *dbDimension, + Nd4jLong const *hDimensionShape, + Nd4jLong const *dDimensionShape, + Nd4jLong const *xTadShapeInfo, Nd4jLong const *xOffsets, + Nd4jLong const *yTadShapeInfo, Nd4jLong const *yOffsets) { + try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbY, dbDimension}); + InteropDataBuffer::preparePrimaryUse({}, {dbDimension}); + + auto dimension = reinterpret_cast(dbDimension->primary()); + int dimensionLength = static_cast(shape::length(hDimensionShape)); + + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], + extraPointers[3]); + NativeOpExecutioner::execReduce3All( + &lc, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + ConstantShapeHelper::getInstance() + .bufferForShapeInfo(hXShapeInfo) + .special(), + extraParamsVals, dbY->primary(), hYShapeInfo, dbY->special(), + ConstantShapeHelper::getInstance() + .bufferForShapeInfo(hYShapeInfo) + .special(), + dbZ->primary(), hZShapeInfo, dbZ->special(), + ConstantShapeHelper::getInstance() + .bufferForShapeInfo(hZShapeInfo) + .special(), + reinterpret_cast(dbDimension->special()), dimensionLength, + xTadShapeInfo, xOffsets, yTadShapeInfo, yOffsets); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX, dbY}); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } +} + +void sort(Nd4jPointer *extraPointers, void *x, Nd4jLong const *xShapeInfo, + void *dX, Nd4jLong const *dXShapeInfo, bool descending) { + try { + cudaStream_t *stream = reinterpret_cast(extraPointers[1]); - for (int k = 2; k <= xLength; k = 2 * k) { - for (int j = k >> 1; j > 0; j = j >> 1) { - BUILD_SINGLE_SELECTOR(xType, bitonicSortStepGeneric, - (launchDims, stream, dX, dXShapeInfo, j, k, xLength, descending), - LIBND4J_TYPES); - } - } - } else { - int numThreads = sd::math::nd4j_min(512, xLength); - int numBlocks = xLength / numThreads; - if (xLength % numThreads > 0 || numBlocks == 0) - numBlocks++; - - numBlocks = sd::math::nd4j_min(512, numBlocks); - dim3 launchDims(numBlocks, numThreads, 32768); - - int max = 2, dg = 0; - while (max < xLength) { - max <<= 1; - dg++; - } - max <<= 1; - - for (int window = 2; window < max; window <<= 1) { - int n = window; - int rev = 0; - do { - int half = n >> 1; - BUILD_SINGLE_SELECTOR(xType, bitonicArbitraryStepGeneric, - (launchDims, stream, dX, dXShapeInfo, n, xLength, rev, descending), - LIBND4J_TYPES); - n >>= 1; - rev = 1; - } while (n > 1); - } + auto xLength = shape::length(xShapeInfo); + auto xEWS = shape::elementWiseStride(xShapeInfo); + auto xType = sd::ArrayOptions::dataType(xShapeInfo); + + // check if xLength is a power of 2, and use bitonic sort, if that's the + // case + if ((xLength != 0) && ((xLength & (xLength - 1)) == 0) && + (xLength <= 1024 * 1024 * 10)) { + int numThreads = sd::math::nd4j_min(512, xLength); + int numBlocks = xLength / numThreads; + if (xLength % numThreads > 0 || numBlocks == 0) numBlocks++; + + dim3 launchDims(numBlocks, numThreads, 32768); + + for (int k = 2; k <= xLength; k = 2 * k) { + for (int j = k >> 1; j > 0; j = j >> 1) { + BUILD_SINGLE_SELECTOR( + xType, bitonicSortStepGeneric, + (launchDims, stream, dX, dXShapeInfo, j, k, xLength, descending), + LIBND4J_TYPES); } - - sd::DebugHelper::checkErrorCode(stream, "sort(...) failed"); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } + } else { + int numThreads = sd::math::nd4j_min(512, xLength); + int numBlocks = xLength / numThreads; + if (xLength % numThreads > 0 || numBlocks == 0) numBlocks++; + + numBlocks = sd::math::nd4j_min(512, numBlocks); + dim3 launchDims(numBlocks, numThreads, 32768); + + int max = 2, dg = 0; + while (max < xLength) { + max <<= 1; + dg++; + } + max <<= 1; + + for (int window = 2; window < max; window <<= 1) { + int n = window; + int rev = 0; + do { + int half = n >> 1; + BUILD_SINGLE_SELECTOR(xType, bitonicArbitraryStepGeneric, + (launchDims, stream, dX, dXShapeInfo, n, + xLength, rev, descending), + LIBND4J_TYPES); + n >>= 1; + rev = 1; + } while (n > 1); + } } -} - - -void sortByKey(Nd4jPointer *extraPointers, - void *x, Nd4jLong const* xShapeInfo, - void *dX, Nd4jLong const* dXShapeInfo, - void *y, Nd4jLong const* yShapeInfo, - void *dy, Nd4jLong const* dyShapeInfo, - bool descending) { - try { - auto stream = reinterpret_cast(extraPointers[1]); - - auto xLength = shape::length(xShapeInfo); - auto yLength = shape::length(yShapeInfo); - auto xEWS = shape::elementWiseStride(xShapeInfo); - auto xType = sd::ArrayOptions::dataType(xShapeInfo); - auto yType = sd::ArrayOptions::dataType(yShapeInfo); - - if (shape::isEmpty(xShapeInfo) || shape::isEmpty(yShapeInfo)) - return; - - if (xLength != yLength) - throw std::runtime_error("sortByKey: keys and values must have the same size"); - - // check if xLength is a power of 2, and use bitonic sort, if that's the case - if ((xLength != 0) && ((xLength & (xLength - 1)) == 0) && (xLength <= 1024 * 1024 * 10)) { - int numThreads = sd::math::nd4j_min(512, xLength); - int numBlocks = xLength / numThreads; - if (xLength % numThreads > 0 || numBlocks == 0) - numBlocks++; - - dim3 launchDims(numBlocks, numThreads, 32768); - - for (int k = 2; k <= xLength; k = 2 * k) { - for (int j = k >> 1; j > 0; j = j >> 1) { - BUILD_DOUBLE_SELECTOR(xType, yType, bitonicSortStepGenericKey, - (launchDims, stream, dX, dXShapeInfo, dy, dyShapeInfo, j, k, xLength, descending), - LIBND4J_TYPES, LIBND4J_TYPES); - } - } - } else { - int numThreads = sd::math::nd4j_min(512, xLength); - int numBlocks = xLength / numThreads; - if (xLength % numThreads > 0 || numBlocks == 0) - numBlocks++; - - numBlocks = sd::math::nd4j_min(512, numBlocks); - dim3 launchDims(numBlocks, numThreads, 32768); - - int max = 2, dg = 0; - while (max < xLength) { - max <<= 1; - dg++; - } - max <<= 1; - - for (int window = 2; window < max; window <<= 1) { - int n = window; - int rev = 0; - do { - int half = n >> 1; - BUILD_DOUBLE_SELECTOR(xType, yType, bitonicArbitraryStepGenericKey, - (launchDims, stream, dX, dXShapeInfo, dy, dyShapeInfo, n, xLength, rev, descending), - LIBND4J_TYPES, LIBND4J_TYPES); - n >>= 1; - rev = 1; - } while (n > 1); - } + sd::DebugHelper::checkErrorCode(stream, "sort(...) failed"); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } +} + +void sortByKey(Nd4jPointer *extraPointers, void *x, Nd4jLong const *xShapeInfo, + void *dX, Nd4jLong const *dXShapeInfo, void *y, + Nd4jLong const *yShapeInfo, void *dy, + Nd4jLong const *dyShapeInfo, bool descending) { + try { + auto stream = reinterpret_cast(extraPointers[1]); + + auto xLength = shape::length(xShapeInfo); + auto yLength = shape::length(yShapeInfo); + auto xEWS = shape::elementWiseStride(xShapeInfo); + auto xType = sd::ArrayOptions::dataType(xShapeInfo); + auto yType = sd::ArrayOptions::dataType(yShapeInfo); + + if (shape::isEmpty(xShapeInfo) || shape::isEmpty(yShapeInfo)) return; + + if (xLength != yLength) + throw std::runtime_error( + "sortByKey: keys and values must have the same size"); + + // check if xLength is a power of 2, and use bitonic sort, if that's the + // case + if ((xLength != 0) && ((xLength & (xLength - 1)) == 0) && + (xLength <= 1024 * 1024 * 10)) { + int numThreads = sd::math::nd4j_min(512, xLength); + int numBlocks = xLength / numThreads; + if (xLength % numThreads > 0 || numBlocks == 0) numBlocks++; + + dim3 launchDims(numBlocks, numThreads, 32768); + + for (int k = 2; k <= xLength; k = 2 * k) { + for (int j = k >> 1; j > 0; j = j >> 1) { + BUILD_DOUBLE_SELECTOR(xType, yType, bitonicSortStepGenericKey, + (launchDims, stream, dX, dXShapeInfo, dy, + dyShapeInfo, j, k, xLength, descending), + LIBND4J_TYPES, LIBND4J_TYPES); } - - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } + } else { + int numThreads = sd::math::nd4j_min(512, xLength); + int numBlocks = xLength / numThreads; + if (xLength % numThreads > 0 || numBlocks == 0) numBlocks++; + + numBlocks = sd::math::nd4j_min(512, numBlocks); + dim3 launchDims(numBlocks, numThreads, 32768); + + int max = 2, dg = 0; + while (max < xLength) { + max <<= 1; + dg++; + } + max <<= 1; + + for (int window = 2; window < max; window <<= 1) { + int n = window; + int rev = 0; + do { + int half = n >> 1; + BUILD_DOUBLE_SELECTOR(xType, yType, bitonicArbitraryStepGenericKey, + (launchDims, stream, dX, dXShapeInfo, dy, + dyShapeInfo, n, xLength, rev, descending), + LIBND4J_TYPES, LIBND4J_TYPES); + n >>= 1; + rev = 1; + } while (n > 1); + } } -} -void sortByValue(Nd4jPointer *extraPointers, - void *x, Nd4jLong const* xShapeInfo, - void *dX, Nd4jLong const* dXShapeInfo, - void *y, Nd4jLong const* yShapeInfo, - void *dy, Nd4jLong const* dyShapeInfo, - bool descending) { - try { - auto stream = reinterpret_cast(extraPointers[1]); - - auto xLength = shape::length(xShapeInfo); - auto yLength = shape::length(yShapeInfo); - auto xEWS = shape::elementWiseStride(xShapeInfo); - auto xType = sd::ArrayOptions::dataType(yShapeInfo); - auto yType = sd::ArrayOptions::dataType(xShapeInfo); - - if (shape::isEmpty(xShapeInfo) || shape::isEmpty(yShapeInfo)) - return; - - if (xLength != yLength) - throw std::runtime_error("sortByValue: keys and values must have the same size"); - - - // check if xLength is a power of 2, and use bitonic sort, if that's the case - if ((xLength != 0) && ((xLength & (xLength - 1)) == 0) && (xLength <= 1024 * 1024 * 10)) { - int numThreads = sd::math::nd4j_min(512, xLength); - int numBlocks = xLength / numThreads; - if (xLength % numThreads > 0 || numBlocks == 0) - numBlocks++; - - dim3 launchDims(numBlocks, numThreads, 32768); - - for (int k = 2; k <= xLength; k = 2 * k) { - for (int j = k >> 1; j > 0; j = j >> 1) { - BUILD_DOUBLE_SELECTOR(xType, yType, bitonicSortStepGenericKey, - (launchDims, stream, dy, dyShapeInfo, dX, dXShapeInfo, j, k, xLength, descending), - LIBND4J_TYPES, LIBND4J_TYPES); - } - } - } else { - int numThreads = sd::math::nd4j_min(512, xLength); - int numBlocks = xLength / numThreads; - if (xLength % numThreads > 0 || numBlocks == 0) - numBlocks++; - - numBlocks = sd::math::nd4j_min(512, numBlocks); - dim3 launchDims(numBlocks, numThreads, 32768); - - int max = 2, dg = 0; - while (max < xLength) { - max <<= 1; - dg++; - } - max <<= 1; - - for (int window = 2; window < max; window <<= 1) { - int n = window; - int rev = 0; - do { - int half = n >> 1; - BUILD_DOUBLE_SELECTOR(xType, yType, bitonicArbitraryStepGenericKey, - (launchDims, stream, dy, dyShapeInfo, dX, dXShapeInfo, n, xLength, rev, descending), - LIBND4J_TYPES, LIBND4J_TYPES); - n >>= 1; - rev = 1; - } while (n > 1); - } + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } +} + +void sortByValue(Nd4jPointer *extraPointers, void *x, + Nd4jLong const *xShapeInfo, void *dX, + Nd4jLong const *dXShapeInfo, void *y, + Nd4jLong const *yShapeInfo, void *dy, + Nd4jLong const *dyShapeInfo, bool descending) { + try { + auto stream = reinterpret_cast(extraPointers[1]); + + auto xLength = shape::length(xShapeInfo); + auto yLength = shape::length(yShapeInfo); + auto xEWS = shape::elementWiseStride(xShapeInfo); + auto xType = sd::ArrayOptions::dataType(yShapeInfo); + auto yType = sd::ArrayOptions::dataType(xShapeInfo); + + if (shape::isEmpty(xShapeInfo) || shape::isEmpty(yShapeInfo)) return; + + if (xLength != yLength) + throw std::runtime_error( + "sortByValue: keys and values must have the same size"); + + // check if xLength is a power of 2, and use bitonic sort, if that's the + // case + if ((xLength != 0) && ((xLength & (xLength - 1)) == 0) && + (xLength <= 1024 * 1024 * 10)) { + int numThreads = sd::math::nd4j_min(512, xLength); + int numBlocks = xLength / numThreads; + if (xLength % numThreads > 0 || numBlocks == 0) numBlocks++; + + dim3 launchDims(numBlocks, numThreads, 32768); + + for (int k = 2; k <= xLength; k = 2 * k) { + for (int j = k >> 1; j > 0; j = j >> 1) { + BUILD_DOUBLE_SELECTOR(xType, yType, bitonicSortStepGenericKey, + (launchDims, stream, dy, dyShapeInfo, dX, + dXShapeInfo, j, k, xLength, descending), + LIBND4J_TYPES, LIBND4J_TYPES); } - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } -} - - - -void sortTadByKey(Nd4jPointer *extraPointers, - void *x, Nd4jLong const* xShapeInfo, - void *dX, Nd4jLong const* dXShapeInfo, - void *y, Nd4jLong const* yShapeInfo, - void *dy, Nd4jLong const* dyShapeInfo, - int *dimension, - int dimensionLength, - bool descending) { - try { - auto stream = reinterpret_cast(extraPointers[1]); - auto context = extraPointers[0] == 0 ? LaunchContext::defaultContext() - : reinterpret_cast(extraPointers[0]); - auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(xShapeInfo, dimension, dimensionLength); - dim3 launchDims((int) tadPack.numberOfTads(), 256, 2048); - auto xType = sd::ArrayOptions::dataType(xShapeInfo); - auto yType = sd::ArrayOptions::dataType(yShapeInfo); - BUILD_DOUBLE_SELECTOR(xType, yType, oesTadGenericKey, - (launchDims, stream, dX, dXShapeInfo, dy, dyShapeInfo, nullptr, dimensionLength, tadPack.platformShapeInfo(), tadPack.platformOffsets(), descending), - LIBND4J_TYPES, LIBND4J_TYPES); - - sd::DebugHelper::checkErrorCode(stream, "sortTadKey(...) failed"); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } + } else { + int numThreads = sd::math::nd4j_min(512, xLength); + int numBlocks = xLength / numThreads; + if (xLength % numThreads > 0 || numBlocks == 0) numBlocks++; + + numBlocks = sd::math::nd4j_min(512, numBlocks); + dim3 launchDims(numBlocks, numThreads, 32768); + + int max = 2, dg = 0; + while (max < xLength) { + max <<= 1; + dg++; + } + max <<= 1; + + for (int window = 2; window < max; window <<= 1) { + int n = window; + int rev = 0; + do { + int half = n >> 1; + BUILD_DOUBLE_SELECTOR(xType, yType, bitonicArbitraryStepGenericKey, + (launchDims, stream, dy, dyShapeInfo, dX, + dXShapeInfo, n, xLength, rev, descending), + LIBND4J_TYPES, LIBND4J_TYPES); + n >>= 1; + rev = 1; + } while (n > 1); + } } -} - -void sortTadByValue(Nd4jPointer *extraPointers, - void *x, Nd4jLong const* xShapeInfo, - void *dX, Nd4jLong const* dXShapeInfo, - void *y, Nd4jLong const* yShapeInfo, - void *dy, Nd4jLong const* dyShapeInfo, - int *dimension, - int dimensionLength, - bool descending) { - try { - auto stream = reinterpret_cast(extraPointers[1]); - auto context = extraPointers[0] == 0 ? LaunchContext::defaultContext() - : reinterpret_cast(extraPointers[0]); - auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(xShapeInfo, dimension, dimensionLength); - dim3 launchDims((int) tadPack.numberOfTads(), 256, 2048); - auto xType = sd::ArrayOptions::dataType(yShapeInfo); - auto yType = sd::ArrayOptions::dataType(xShapeInfo); - - BUILD_DOUBLE_SELECTOR(xType, yType, oesTadGenericKey, - (launchDims, stream, dy, dyShapeInfo, dX, dXShapeInfo, nullptr, dimensionLength, tadPack.platformShapeInfo(), tadPack.platformOffsets(), descending), - LIBND4J_TYPES, LIBND4J_TYPES); - - sd::DebugHelper::checkErrorCode(stream, "sortTadValue(...) failed"); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } +} + +void sortTadByKey(Nd4jPointer *extraPointers, void *x, + Nd4jLong const *xShapeInfo, void *dX, + Nd4jLong const *dXShapeInfo, void *y, + Nd4jLong const *yShapeInfo, void *dy, + Nd4jLong const *dyShapeInfo, int *dimension, + int dimensionLength, bool descending) { + try { + auto stream = reinterpret_cast(extraPointers[1]); + auto context = extraPointers[0] == 0 + ? LaunchContext::defaultContext() + : reinterpret_cast(extraPointers[0]); + auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions( + xShapeInfo, dimension, dimensionLength); + dim3 launchDims((int)tadPack.numberOfTads(), 256, 2048); + auto xType = sd::ArrayOptions::dataType(xShapeInfo); + auto yType = sd::ArrayOptions::dataType(yShapeInfo); + BUILD_DOUBLE_SELECTOR( + xType, yType, oesTadGenericKey, + (launchDims, stream, dX, dXShapeInfo, dy, dyShapeInfo, nullptr, + dimensionLength, tadPack.platformShapeInfo(), + tadPack.platformOffsets(), descending), + LIBND4J_TYPES, LIBND4J_TYPES); + + sd::DebugHelper::checkErrorCode(stream, "sortTadKey(...) failed"); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } +} + +void sortTadByValue(Nd4jPointer *extraPointers, void *x, + Nd4jLong const *xShapeInfo, void *dX, + Nd4jLong const *dXShapeInfo, void *y, + Nd4jLong const *yShapeInfo, void *dy, + Nd4jLong const *dyShapeInfo, int *dimension, + int dimensionLength, bool descending) { + try { + auto stream = reinterpret_cast(extraPointers[1]); + auto context = extraPointers[0] == 0 + ? LaunchContext::defaultContext() + : reinterpret_cast(extraPointers[0]); + auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions( + xShapeInfo, dimension, dimensionLength); + dim3 launchDims((int)tadPack.numberOfTads(), 256, 2048); + auto xType = sd::ArrayOptions::dataType(yShapeInfo); + auto yType = sd::ArrayOptions::dataType(xShapeInfo); + + BUILD_DOUBLE_SELECTOR( + xType, yType, oesTadGenericKey, + (launchDims, stream, dy, dyShapeInfo, dX, dXShapeInfo, nullptr, + dimensionLength, tadPack.platformShapeInfo(), + tadPack.platformOffsets(), descending), + LIBND4J_TYPES, LIBND4J_TYPES); + + sd::DebugHelper::checkErrorCode(stream, "sortTadValue(...) failed"); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } +} + +void sortTad(Nd4jPointer *extraPointers, void *x, Nd4jLong const *xShapeInfo, + void *dX, Nd4jLong const *dXShapeInfo, int *dimension, + int dimensionLength, Nd4jLong const *tadShapeInfo, + Nd4jLong const *tadOffsets, bool descending) { + try { + // to be implemented + auto stream = reinterpret_cast(extraPointers[1]); + auto context = extraPointers[0] == 0 + ? LaunchContext::defaultContext() + : reinterpret_cast(extraPointers[0]); + auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions( + xShapeInfo, dimension, dimensionLength); + dim3 launchDims((int)tadPack.numberOfTads(), 512, 33768); + auto xType = sd::ArrayOptions::dataType(xShapeInfo); + BUILD_SINGLE_SELECTOR( + xType, oesTadGeneric, + (launchDims, stream, dX, dXShapeInfo, nullptr, dimensionLength, + tadShapeInfo, tadOffsets, descending), + LIBND4J_TYPES); + + sd::DebugHelper::checkErrorCode(stream, "sortTad(...) failed"); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } +} + +void sortCooIndices(Nd4jPointer *extraPointers, Nd4jLong *indices, void *values, + Nd4jLong length, int rank) { + throw std::runtime_error("sortCooIndices:: Not implemented yet"); +} + +Nd4jLong *mmapFile(Nd4jPointer *extraPointers, const char *fileName, + Nd4jLong length) { + auto hZ = new Nd4jLong[3]; + errno = 0; + try { +#if defined(_WIN32) || defined(_WIN64) + _mmap(hZ, static_cast(length), fileName); +#else + int fd = open(fileName, O_RDWR, 0); // checking for failed fopen + if (fd < 0) { + nd4j_printf("Errno: %i\n", errno); + throw std::runtime_error("Failed to open file for MMAP"); } -} + void *ptr = mmap(NULL, length, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0); + // check for failed allocation + if (ptr == MAP_FAILED) return nullptr; -void sortTad(Nd4jPointer *extraPointers, - void *x, Nd4jLong const* xShapeInfo, - void *dX, Nd4jLong const* dXShapeInfo, - int *dimension, - int dimensionLength, - Nd4jLong const* tadShapeInfo, - Nd4jLong const* tadOffsets, - bool descending) { - try { - // to be implemented - auto stream = reinterpret_cast(extraPointers[1]); - auto context = extraPointers[0] == 0 ? LaunchContext::defaultContext() - : reinterpret_cast(extraPointers[0]); - auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(xShapeInfo, dimension, dimensionLength); - dim3 launchDims((int) tadPack.numberOfTads(), 512, 33768); - auto xType = sd::ArrayOptions::dataType(xShapeInfo); - BUILD_SINGLE_SELECTOR(xType, oesTadGeneric, - (launchDims, stream, dX, dXShapeInfo, nullptr, dimensionLength, tadShapeInfo, tadOffsets, descending), - LIBND4J_TYPES); - - sd::DebugHelper::checkErrorCode(stream, "sortTad(...) failed"); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } -} + hZ[0] = (Nd4jLong)ptr; + hZ[1] = fd; -void sortCooIndices(Nd4jPointer *extraPointers, Nd4jLong *indices, void *values, Nd4jLong length, int rank) { - throw std::runtime_error("sortCooIndices:: Not implemented yet"); -} +#endif + hZ[2] = length; -Nd4jLong* mmapFile(Nd4jPointer *extraPointers, const char *fileName, Nd4jLong length) { - return nullptr; + return hZ; + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + return nullptr; + } } -void munmapFile(Nd4jPointer *extraPointers, Nd4jLong* ptrMap, Nd4jLong length) { +void munmapFile(Nd4jPointer *extraPointers, Nd4jLong *ptrMap, Nd4jLong length) { + munmap((Nd4jPointer)ptrMap[0], length); +#if defined(_WIN32) || defined(_WIN64) + CloseHandle(reinterpret_cast(ptrMap[1])); +#else + close((int)ptrMap[1]); +#endif + delete[] ptrMap; } - -sd::graph::ResultWrapper* executeFlatGraph(Nd4jPointer *extraPointers, Nd4jPointer flatBufferPointer) { - try { - return sd::graph::GraphExecutioner::executeFlatBuffer(flatBufferPointer); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - return nullptr; - } +sd::graph::ResultWrapper *executeFlatGraph(Nd4jPointer *extraPointers, + Nd4jPointer flatBufferPointer) { + try { + //return sd::graph::GraphExecutioner::executeFlatBuffer(flatBufferPointer); + return nullptr; + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + return nullptr; + } } -Nd4jLong getResultWrapperSize(sd::graph::ResultWrapper* ptr) { - return ptr->size(); +Nd4jLong getResultWrapperSize(sd::graph::ResultWrapper *ptr) { + return ptr->size(); } -Nd4jPointer getResultWrapperPointer(sd::graph::ResultWrapper* ptr) { - return ptr->pointer(); +Nd4jPointer getResultWrapperPointer(sd::graph::ResultWrapper *ptr) { + return ptr->pointer(); } - -const char* getAllCustomOps() { - return sd::ops::OpRegistrator::getInstance().getAllCustomOperations(); +const char *getAllCustomOps() { + return sd::ops::OpRegistrator::getInstance().getAllCustomOperations(); } +sd::ShapeList *_calculateOutputShapes( + Nd4jPointer *extraPointers, std::shared_ptr &op, + Nd4jPointer *inputBuffers, Nd4jPointer *inputShapes, int numInputShapes, + double *tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, + int numBArgs, int *dArgs, int numDArgs) { + sd::graph::VariableSpace varSpace; + Context block(2, &varSpace); + sd::ShapeList inShapes; -sd::ShapeList* _calculateOutputShapes(Nd4jPointer* extraPointers, sd::ops::DeclarableOp* op, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs, int *dArgs, int numDArgs) { - sd::graph::VariableSpace varSpace; - Context block(2, &varSpace); - sd::ShapeList inShapes; + for (int e = 0; e < numIArgs; e++) block.appendI(iArgs[e]); - for (int e = 0; e < numIArgs; e++) - block.getIArguments()->push_back(iArgs[e]); + for (int e = 0; e < numTArgs; e++) block.appendT(tArgs[e]); - for (int e = 0; e < numTArgs; e++) - block.getTArguments()->push_back(tArgs[e]); + for (int e = 0; e < numBArgs; e++) block.appendB(bArgs[e]); - for (int e = 0; e < numBArgs; e++) - block.getBArguments()->push_back(bArgs[e]); + for (int e = 0; e < numDArgs; e++) + block.appendD((sd::DataType) dArgs[e]); - for (int e = 0; e < numDArgs; e++) - block.getDArguments()->push_back((sd::DataType) dArgs[e]); + for (int e = 0; e < numInputShapes; e++) { + auto shape_ = reinterpret_cast(inputShapes[e]); - for (int e = 0; e < numInputShapes; e++) { - auto shape_ = reinterpret_cast(inputShapes[e]); + // we shouldn't copy buffer if that's empty array + void *buffer_ = sd::ArrayOptions::arrayType(shape_) == ArrayType::EMPTY + ? nullptr + : inputBuffers[e]; + void *bufferD_ = sd::ArrayOptions::arrayType(shape_) == ArrayType::EMPTY + ? nullptr + : inputBuffers[e + numInputShapes]; - // we shouldn't copy buffer if that's empty array - void *buffer_ = sd::ArrayOptions::arrayType(shape_) == ArrayType::EMPTY ? nullptr : inputBuffers[e]; - void *bufferD_ = sd::ArrayOptions::arrayType(shape_) == ArrayType::EMPTY ? nullptr : inputBuffers[e + numInputShapes]; + sd::NDArray array(buffer_, bufferD_, shape_); - auto array = new sd::NDArray(buffer_, bufferD_, shape_); + // block should contain references to proper variable + varSpace.putVariable(1, e, array); + block.pickInput(1, e); - // block should contain references to proper variable - varSpace.putVariable(1, e, array); - block.pickInput(1, e); + inShapes.push_back(shape_); + } - inShapes.push_back(shape_); - } + auto shapeList = op->calculateOutputShape(&inShapes, block); - auto shapeList = op->calculateOutputShape(&inShapes, block); - - if (varSpace.launchContext()->getWorkspace() != nullptr) - shapeList->detach(); - - return shapeList; + return shapeList; } -sd::ShapeList* calculateOutputShapes2(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs, int *dArgs, int numDArgs) { - try { - auto op = sd::ops::OpRegistrator::getInstance().getOperation(hash); +sd::ShapeList *calculateOutputShapes2(Nd4jPointer *extraPointers, Nd4jLong hash, + Nd4jPointer *inputBuffers, + Nd4jPointer *inputShapes, + int numInputShapes, double *tArgs, + int numTArgs, Nd4jLong *iArgs, + int numIArgs, bool *bArgs, int numBArgs, + int *dArgs, int numDArgs) { + try { + auto op = sd::ops::OpRegistrator::getInstance().getOperation(hash); - return _calculateOutputShapes(extraPointers, op, inputBuffers, inputShapes, numInputShapes, tArgs, numTArgs, - iArgs, numIArgs, bArgs, numBArgs, dArgs, numDArgs); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - return nullptr; - } + return _calculateOutputShapes(extraPointers, op, inputBuffers, inputShapes, + numInputShapes, tArgs, numTArgs, iArgs, + numIArgs, bArgs, numBArgs, dArgs, numDArgs); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + return nullptr; + } } -sd::ShapeList* _calculateOutputShapes(Nd4jPointer* extraPointers, sd::ops::DeclarableOp* op, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs) { - Context block(1); - sd::ShapeList inShapes; +sd::ShapeList *_calculateOutputShapes(Nd4jPointer *extraPointers, + std::shared_ptr &op, + Nd4jPointer *inputShapes, + int numInputShapes, double *tArgs, + int numTArgs, Nd4jLong *iArgs, + int numIArgs) { + Context block(1); + sd::ShapeList inShapes; - for (int e = 0; e < numIArgs; e++) - block.getIArguments()->push_back(iArgs[e]); + for (int e = 0; e < numIArgs; e++) block.appendI(iArgs[e]); - for (int e = 0; e < numTArgs; e++) - block.getTArguments()->push_back(tArgs[e]); + for (int e = 0; e < numTArgs; e++) block.appendT(tArgs[e]); - for (int e = 0; e < numInputShapes; e++) - inShapes.push_back(reinterpret_cast(inputShapes[e])); + for (int e = 0; e < numInputShapes; e++) + inShapes.push_back(reinterpret_cast(inputShapes[e])); - auto shapeList = op->calculateOutputShape(&inShapes, block); + auto shapeList = op->calculateOutputShape(&inShapes, block); - return shapeList; + return shapeList; } -sd::ShapeList* calculateOutputShapes(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs) { - try { - auto op = sd::ops::OpRegistrator::getInstance().getOperation(hash); +sd::ShapeList *calculateOutputShapes(Nd4jPointer *extraPointers, Nd4jLong hash, + Nd4jPointer *inputShapes, + int numInputShapes, double *tArgs, + int numTArgs, Nd4jLong *iArgs, + int numIArgs) { + try { + auto op = sd::ops::OpRegistrator::getInstance().getOperation(hash); - return _calculateOutputShapes(extraPointers, op, inputShapes, numInputShapes, tArgs, numTArgs, iArgs, numIArgs); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - return nullptr; - } + return _calculateOutputShapes(extraPointers, op, inputShapes, + numInputShapes, tArgs, numTArgs, iArgs, + numIArgs); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + return nullptr; + } } -Nd4jLong getShapeListSize(sd::ShapeList* list) { - return list->size(); -} +Nd4jLong getShapeListSize(sd::ShapeList *list) { return list->size(); } -Nd4jLong const* getShape(sd::ShapeList* list, Nd4jLong i) { - return list->at(i); +Nd4jLong const *getShape(sd::ShapeList *list, Nd4jLong i) { + return list->at(i); } -static FORCEINLINE Nd4jStatus realExec(sd::ops::DeclarableOp* op, Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputs, Nd4jPointer* outputBuffers, Nd4jPointer* outputShapes, int numOutputs, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool* bArgs, int numBArgs, bool isInplace) { - if (op == nullptr) - nd4j_printf("Can't find requested operation: [%lld]\n", hash); - - // we're using the same fake nodeId everywhere here +static FORCEINLINE Nd4jStatus realExec( + std::shared_ptr &op, Nd4jPointer *extraPointers, Nd4jLong hash, + Nd4jPointer *inputBuffers, Nd4jPointer *inputShapes, int numInputs, + Nd4jPointer *outputBuffers, Nd4jPointer *outputShapes, int numOutputs, + double *tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, + bool *bArgs, int numBArgs, bool isInplace) { + if (op == nullptr) + nd4j_printf("Can't find requested operation: [%lld]\n", hash); + + // we're using the same fake nodeId everywhere here - std::vector inputs(numInputs); - std::vector outputs(numOutputs); - std::vector ttArgs(numTArgs); - std::vector bbArgs(numBArgs); - std::vector iiArgs(numIArgs); + std::vector inputs(numInputs); + std::vector outputs(numOutputs); + std::vector ttArgs(numTArgs); + std::vector bbArgs(numBArgs); + std::vector iiArgs(numIArgs); - // filling block now with inputs - for (int e = 0; e < numInputs; e++) { - auto shape = reinterpret_cast(inputShapes[e]); - void *buffer = sd::ArrayOptions::arrayType(shape) == ArrayType::EMPTY ? nullptr : inputBuffers[e]; - void *bufferD = sd::ArrayOptions::arrayType(shape) == ArrayType::EMPTY ? nullptr : inputBuffers[e + numInputs]; + // filling block now with inputs + for (int e = 0; e < numInputs; e++) { + auto shape = reinterpret_cast(inputShapes[e]); + void *buffer = sd::ArrayOptions::arrayType(shape) == ArrayType::EMPTY + ? nullptr + : inputBuffers[e]; + void *bufferD = sd::ArrayOptions::arrayType(shape) == ArrayType::EMPTY + ? nullptr + : inputBuffers[e + numInputs]; - inputs[e] = new sd::NDArray(buffer, bufferD, shape); - } - - // if not inplace - transferring output arrays - - if (!isInplace) - for (int e = 0; e < numOutputs; e++) { - // we want to keep original output shape intact - auto shape = shape::copyShape(reinterpret_cast(outputShapes[e])); - void *buffer = sd::ArrayOptions::arrayType(shape) == ArrayType::EMPTY ? nullptr : outputBuffers[e]; - void *bufferD = sd::ArrayOptions::arrayType(shape) == ArrayType::EMPTY ? nullptr : outputBuffers[e + numOutputs]; - - // FIXME: revisit this. - bool canNullify = true; - for (int i = 0; i < numInputs; i++) { - void *ibuffer = sd::ArrayOptions::arrayType(shape) == ArrayType::EMPTY ? nullptr : inputBuffers[i]; - if (ibuffer == buffer) { - canNullify = false; - break; - } - } - - if (canNullify && buffer != nullptr) - memset((uint8_t *) buffer, '\0', shape::length(shape) * DataTypeUtils::sizeOfElement(ArrayOptions::dataType(shape))); - - auto array = new sd::NDArray(buffer, bufferD, shape); - outputs[e] = array; - } - - for (int e = 0; e < numIArgs; e++) - iiArgs[e] = iArgs[e]; + inputs[e] = new sd::NDArray(buffer, bufferD, shape); + } + + // if not inplace - transferring output arrays + + if (!isInplace) + for (int e = 0; e < numOutputs; e++) { + // we want to keep original output shape intact + auto shape = + shape::copyShape(reinterpret_cast(outputShapes[e])); + void *buffer = sd::ArrayOptions::arrayType(shape) == ArrayType::EMPTY + ? nullptr + : outputBuffers[e]; + void *bufferD = sd::ArrayOptions::arrayType(shape) == ArrayType::EMPTY + ? nullptr + : outputBuffers[e + numOutputs]; + + // FIXME: revisit this. + bool canNullify = true; + for (int i = 0; i < numInputs; i++) { + void *ibuffer = sd::ArrayOptions::arrayType(shape) == ArrayType::EMPTY + ? nullptr + : inputBuffers[i]; + if (ibuffer == buffer) { + canNullify = false; + break; + } + } - for (int e = 0; e < numTArgs; e++) - ttArgs[e] = tArgs[e]; + if (canNullify && buffer != nullptr) + memset((uint8_t *)buffer, '\0', + shape::length(shape) * + DataTypeUtils::sizeOfElement(ArrayOptions::dataType(shape))); - for (int e = 0; e < numBArgs; e++) - bbArgs[e] = bArgs[e]; + auto array = new sd::NDArray(buffer, bufferD, shape); + outputs[e] = array; + } + for (int e = 0; e < numIArgs; e++) iiArgs[e] = iArgs[e]; - // hypothetically at this point we have everything filled - auto dZ = op->execute(inputs, outputs, ttArgs, iiArgs, bbArgs, std::vector(), isInplace); - //auto dZ = op->execute(inputs, ttArgs, iiArgs, isInplace); + for (int e = 0; e < numTArgs; e++) ttArgs[e] = tArgs[e]; + for (int e = 0; e < numBArgs; e++) bbArgs[e] = bArgs[e]; - if (!isInplace) - for (int e = 0; e < numOutputs; e++) { - //shape::printShapeInfoLinear("JVM output shape", (int *) outputShapes[e]); - //shape::printShapeInfoLinear("C++ output shape", (int *) outputs[e]->shapeInfo()); - //outputs[e]->printIndexedBuffer("C++ raw output"); - //outputs[e]->printBuffer("C++ indexed output"); + // hypothetically at this point we have everything filled + auto dZ = op->execute(inputs, outputs, ttArgs, iiArgs, bbArgs, + std::vector(), isInplace); + // auto dZ = op->execute(inputs, ttArgs, iiArgs, isInplace); - if (outputs[e]->ordering() != shape::order(reinterpret_cast(outputShapes[e]))) - outputs[e]->streamline(shape::order(reinterpret_cast(outputShapes[e]))); - } + if (!isInplace) + for (int e = 0; e < numOutputs; e++) { + // shape::printShapeInfoLinear("JVM output shape", (int *) + // outputShapes[e]); shape::printShapeInfoLinear("C++ output shape", (int + // *) outputs[e]->shapeInfo()); outputs[e]->printIndexedBuffer("C++ raw + // output"); outputs[e]->printBuffer("C++ indexed output"); + + if (outputs[e]->ordering() != + shape::order(reinterpret_cast(outputShapes[e]))) + outputs[e]->streamline( + shape::order(reinterpret_cast(outputShapes[e]))); + } - for (auto v: inputs) - delete v; + for (auto v : inputs) delete v; - for (auto v: outputs) - delete v; + for (auto v : outputs) delete v; - return Status::OK(); + return Status::OK(); } +int execCustomOp(Nd4jPointer *extraPointers, Nd4jLong hash, + Nd4jPointer *inputBuffers, Nd4jPointer *inputShapes, + int numInputs, Nd4jPointer *outputBuffers, + Nd4jPointer *outputShapes, int numOutputs, double *tArgs, + int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, + int numBArgs, bool isInplace) { + try { + auto op = sd::ops::OpRegistrator::getInstance().getOperation(hash); -int execCustomOp(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputs, Nd4jPointer* outputBuffers, Nd4jPointer* outputShapes, int numOutputs, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool* bArgs, int numBArgs, bool isInplace) { - try { - auto op = sd::ops::OpRegistrator::getInstance().getOperation(hash); - - return realExec(op, extraPointers, hash, inputBuffers, inputShapes, numInputs, outputBuffers, outputShapes, - numOutputs, tArgs, numTArgs, iArgs, numIArgs, bArgs, numBArgs, isInplace); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - return 1; - } + return realExec(op, extraPointers, hash, inputBuffers, inputShapes, + numInputs, outputBuffers, outputShapes, numOutputs, tArgs, + numTArgs, iArgs, numIArgs, bArgs, numBArgs, isInplace); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + return 1; + } } -int execCustomOp2(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer opContext) { - try { - auto op = sd::ops::OpRegistrator::getInstance().getOperation(hash); - auto context = reinterpret_cast(opContext); - - auto result = op->execute(context); +int execCustomOp2(Nd4jPointer *extraPointers, Nd4jLong hash, + Nd4jPointer opContext) { + try { + auto op = sd::ops::OpRegistrator::getInstance().getOperation(hash); + auto context = reinterpret_cast(opContext); - auto res = cudaStreamSynchronize(*context->launchContext()->getCudaStream()); - if (res != 0) - throw sd::cuda_exception::build("customOp execution failed", res); + auto result = op->execute(context); - for (auto v:context->fastpath_in()) { - if (!v->isEmpty()) - v->syncToDevice(); - } + auto res = + cudaStreamSynchronize(*context->launchContext()->getCudaStream()); + if (res != 0) + throw sd::cuda_exception::build("customOp execution failed", res); - for (auto v:context->fastpath_out()) { - if (!v->isEmpty()) - v->syncToDevice(); - } - - return result; - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - return 1; + for (auto v : context->fastpath_in()) { + if (!v->isEmpty()) v->syncToDevice(); } -} - -int registerGraph(Nd4jPointer *extraPointers, Nd4jLong graphId, Nd4jPointer flatBufferPointer) { - try { - auto graph = sd::graph::GraphExecutioner::importFromFlatPointer(flatBufferPointer); - sd::graph::GraphHolder::getInstance().registerGraph(graphId, graph); - - return ND4J_STATUS_OK; - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - return 1; + for (auto v : context->fastpath_out()) { + if (!v->isEmpty()) v->syncToDevice(); } -} - - -static VariablesSet* executeStoredGraphT(Nd4jPointer *extraPointers, Nd4jLong graphId, Nd4jPointer *inputBuffers, Nd4jPointer *inputShapes, int* inputIndices, int numInputs) { - auto graph = sd::graph::GraphHolder::getInstance().pullGraph(graphId); - auto varSpace = graph->getVariableSpace()->clone(); - - std::vector handles; - - for (int e = 0; e < numInputs; e++) { - auto idx = inputIndices[e]; - // we'll delete this array later, together with cloned VariableSpace - auto array = new sd::NDArray(inputBuffers[e], reinterpret_cast(inputShapes[e])); - handles.emplace_back(array); - - if (varSpace->hasVariable(idx)) { - auto var = varSpace->getVariable(idx); - if (var->hasNDArray()) - delete var->getNDArray(); - - var->setNDArray(array); - } else - varSpace->putVariable(idx, array); - } - - auto dZ = sd::graph::GraphExecutioner::execute(graph, varSpace); - auto varSet = new sd::graph::VariablesSet(dZ); - - if (dZ == ND4J_STATUS_OK) { - // pull back results, and provide them - auto outputs = graph->fetchOutputs(); - for (int e = 0; e < outputs->size(); e++) { - // we're only getting variable ID/Index from original grap. values will be taken from cloned workspace - std::pair varId(outputs->at(e)->id(), outputs->at(e)->index()); - - auto var = varSpace->getVariable(varId); - - varSet->push_back(var->clone()); - } + return result; + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + return 1; + } +} - delete outputs; - } +int registerGraph(Nd4jPointer *extraPointers, Nd4jLong graphId, + Nd4jPointer flatBufferPointer) { + try { + auto graph = sd::graph::Graph::fromFlatPointer(flatBufferPointer); - delete varSpace; + sd::graph::GraphHolder::getInstance().registerGraph(graphId, graph); - return varSet; + return ND4J_STATUS_OK; + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + return 1; + } } -VariablesSet* executeStoredGraph(Nd4jPointer *extraPointers, Nd4jLong graphId, Nd4jPointer *inputBuffers, Nd4jPointer *inputShapes, int* inputIndices, int numInputs) { - try { - return executeStoredGraphT(extraPointers, graphId, inputBuffers, inputShapes, inputIndices, numInputs); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - return nullptr; - } +VariablesSet *executeStoredGraph(Nd4jPointer *extraPointers, Nd4jLong graphId, + Nd4jPointer *inputBuffers, + Nd4jPointer *inputShapes, int *inputIndices, + int numInputs) { + try { + throw std::runtime_error("Not implemented yet"); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + return nullptr; + } } -Nd4jLong getVariablesSetSize(sd::graph::VariablesSet* set) { - return set->size(); +Nd4jLong getVariablesSetSize(sd::graph::VariablesSet *set) { + return set->size(); } -Nd4jStatus getVariablesSetStatus(sd::graph::VariablesSet* set) { - return set->status(); +Nd4jStatus getVariablesSetStatus(sd::graph::VariablesSet *set) { + return set->status(); } -sd::graph::Variable* getVariable(sd::graph::VariablesSet* set, Nd4jLong i) { - return set->at(i); +sd::graph::Variable *getVariable(sd::graph::VariablesSet *set, Nd4jLong i) { + return set->at(i); } -int getVariableId(sd::graph::Variable* variable) { - return variable->id(); -} +int getVariableId(sd::graph::Variable *variable) { return variable->id(); } -int getVariableIndex(sd::graph::Variable* variable) { - return variable->index(); +int getVariableIndex(sd::graph::Variable *variable) { + return variable->index(); } -const char* getVariableName(sd::graph::Variable* variable) { - return variable->getName()->c_str(); +const char *getVariableName(sd::graph::Variable *variable) { + return variable->name().c_str(); } -Nd4jLong const* getVariableShape(sd::graph::Variable* variable) { - return variable->getNDArray()->shapeInfo(); +Nd4jLong const *getVariableShape(sd::graph::Variable *variable) { + return variable->getNDArray()->shapeInfo(); } -void* getVariableBuffer(sd::graph::Variable* variable) { - return variable->getNDArray()->buffer(); +void *getVariableBuffer(sd::graph::Variable *variable) { + return variable->getNDArray()->buffer(); } int unregisterGraph(Nd4jPointer *extraPointers, Nd4jLong graphId) { - try { - sd::graph::GraphHolder::getInstance().dropGraphAny(graphId); - - return ND4J_STATUS_OK; - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - return 1; - } + try { + sd::graph::GraphHolder::getInstance().forgetGraph(graphId); + + return Status::OK(); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + return 1; + } } void deletePointerArray(Nd4jPointer pointer) { - Nd4jPointer *ptr = reinterpret_cast(pointer); - delete[] ptr; + Nd4jPointer *ptr = reinterpret_cast(pointer); + delete[] ptr; } void deleteCharArray(Nd4jPointer pointer) { - auto ptr = reinterpret_cast(pointer); - delete[] ptr; + auto ptr = reinterpret_cast(pointer); + delete[] ptr; } void deleteIntArray(Nd4jPointer pointer) { - auto ptr = reinterpret_cast(pointer); - delete[] ptr; + auto ptr = reinterpret_cast(pointer); + delete[] ptr; } void deleteLongArray(Nd4jPointer pointer) { - auto ptr = reinterpret_cast(pointer); - delete[] ptr; + auto ptr = reinterpret_cast(pointer); + delete[] ptr; } -void deleteVariablesSet(sd::graph::VariablesSet* pointer) { - delete pointer; -} +void deleteVariablesSet(sd::graph::VariablesSet *pointer) { delete pointer; } void deleteShapeList(Nd4jPointer shapeList) { - sd::ShapeList* list = reinterpret_cast(shapeList); + sd::ShapeList *list = reinterpret_cast(shapeList); - //list->destroy(); - delete list; + // list->destroy(); + delete list; } -const char* getAllOperations() { - return sd::OpTracker::getInstance().exportOperations(); +const char *getAllOperations() { + return sd::OpTracker::getInstance().exportOperations(); } -Nd4jPointer getGraphState(Nd4jLong id) { - return (Nd4jPointer) new sd::graph::GraphState(id); -} -void deleteGraphState(Nd4jPointer state) { - auto stateP = reinterpret_cast(state); - delete stateP; -} - - -Nd4jStatus execCustomOpWithScope(Nd4jPointer *extraPointers, sd::graph::GraphState *state, Nd4jLong opHash, Nd4jLong *scopes, int numScopes, Nd4jPointer *inputBuffers, Nd4jPointer *inputShapes, int numInputs, Nd4jPointer *outputBuffers, Nd4jPointer *outputShapes, int numOutputs) { - /** - * That's basically exec, with VariableSpace provided in GraphState: - * depending on operation (i.e. while of if), different logic executors could be used - */ - - auto graph = state->graph(); - auto varSpace = state->variableSpace(); - - // Node is dynamically created, and has nothing beyond it: only inputs and outputs - // this node has id of 0, and inputs are - Node node(OpType_LOGIC, opHash, 0); - - // mapping inputs - for (int e = 0; e < numInputs; e++) { - auto buffer = inputBuffers[e]; - auto shapeInfo = reinterpret_cast(inputShapes[e]); - - auto array = new sd::NDArray(buffer, shapeInfo, varSpace->launchContext()); - - // now we just put array to VarSpace - varSpace->putVariable(0, e, array); - node.pickInput(0, e); - } - - // mapping scopes - for (int e = 0; e < numScopes; e++) { - // we should check scope existence in GraphState/Graph - int scopeId = (int) scopes[e]; - if (!state->hasScope(scopeId)) { - // nd4j_printf("execCustomOpWithScope: referenced scope [%i] doesn't exist\n", scopeId); - return Status::THROW(); - } - node.pickInput(scopeId, 0); - } - - auto dZ = LogicExecutor::processNode(graph, &node); - if (dZ != Status::OK()) - return dZ; - - // mapping outputs - - for (int e = 0; e < numOutputs; e++) { - auto buffer = outputBuffers[e]; - auto shapeInfo = reinterpret_cast(outputShapes[e]); - - NDArray array(buffer, shapeInfo, varSpace->launchContext()); - - // now we just put array to VarSpace to the same ID - //varSpace->putVariable(0, e, array); - - auto t = varSpace->getVariable(0, e)->getNDArray(); - array.assign(t); - } - - // removing input variables - for (int e = 0; e < numInputs; e++) { - varSpace->dropVariable(0, e); - } - - // after some bla-bla-bla we should have Graph and Node for current op - return Status::OK(); -} - - -Nd4jStatus execCustomOpWithScope(Nd4jPointer *extraPointers, Nd4jPointer state, Nd4jLong opHash, Nd4jLong *scopes, int numScopes, Nd4jPointer *inputBuffers, Nd4jPointer *inputShapes, int numInputs, Nd4jPointer *outputBuffers, Nd4jPointer *outputShapes, int numOutputs) { - try { - return execCustomOpWithScope(extraPointers, reinterpret_cast(state), opHash, scopes, - numScopes, inputBuffers, inputShapes, numInputs, outputBuffers, outputShapes, - numOutputs); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - return 1; - } +Nd4jStatus execCustomOpWithScope(Nd4jPointer *extraPointers, Nd4jPointer state, + Nd4jLong opHash, Nd4jLong *scopes, + int numScopes, Nd4jPointer *inputBuffers, + Nd4jPointer *inputShapes, int numInputs, + Nd4jPointer *outputBuffers, + Nd4jPointer *outputShapes, int numOutputs) { + return 0; } void deleteResultWrapper(Nd4jPointer ptr) { - // just 0 room for compiler s@!t - auto p = reinterpret_cast(ptr); - delete p; + // just 0 room for compiler s@!t + auto p = reinterpret_cast(ptr); + delete p; } -int estimateThreshold(Nd4jPointer *extraPointers, Nd4jPointer dX, Nd4jLong const* dXShapeInfo, int N, float threshold) { - throw std::runtime_error("estimateThreshold: Not implemented yet"); +int estimateThreshold(Nd4jPointer *extraPointers, Nd4jPointer dX, + Nd4jLong const *dXShapeInfo, int N, float threshold) { + throw std::runtime_error("estimateThreshold: Not implemented yet"); } /* * TypeDef: - * void convertTypes(Nd4jPointer *extras, int srcType, Nd4jPointer dX, long N, int dstType, Nd4jPointer dZ); + * void convertTypes(Nd4jPointer *extras, int srcType, Nd4jPointer dX, long + * N, int dstType, Nd4jPointer dZ); */ -void convertTypes(Nd4jPointer *extras, int srcType, Nd4jPointer dX, Nd4jLong N, int dstType, Nd4jPointer dZ) { - try { - auto dx = reinterpret_cast(dX); - auto dz = reinterpret_cast(dZ); - - if (srcType == ND4J_FLOAT8) { - if (dstType == ND4J_FLOAT8) { - // convertKernel(extras, dx, N, dz); - } else if (dstType == ND4J_INT8) { - //sd::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_UINT8) { - //sd::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_FLOAT16) { - //sd::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_INT16) { - //sd::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_UINT16) { - //sd::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_FLOAT24) { - - } else if (dstType == ND4J_FLOAT32) { - //sd::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_DOUBLE) { - //sd::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else { - nd4j_printf("Unsupported types conversion: [%i] -> [%i]\n", srcType, dstType); - } - } else if (srcType == ND4J_INT8) { - if (dstType == ND4J_FLOAT8) { - //sd::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_INT8) { - //convertKernel(extras, dx, N, dz); - } else if (dstType == ND4J_UINT8) { - sd::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_FLOAT16) { - sd::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_INT16) { - sd::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_UINT16) { - sd::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_FLOAT24) { - // TODO: eventually we might want to add it - } else if (dstType == ND4J_FLOAT32) { - sd::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_DOUBLE) { - sd::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else { - nd4j_printf("Unsupported types conversion: [%i] -> [%i]\n", srcType, dstType); - } - } else if (srcType == ND4J_UINT8) { - if (dstType == ND4J_FLOAT8) { - //sd::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_INT8) { - sd::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_UINT8) { - sd::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_FLOAT16) { - sd::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_INT16) { - sd::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_UINT16) { - sd::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_FLOAT24) { - // TODO: still might want to add - } else if (dstType == ND4J_FLOAT32) { - sd::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_DOUBLE) { - sd::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else { - nd4j_printf("Unsupported types conversion: [%i] -> [%i]\n", srcType, dstType); - } - } else if (srcType == ND4J_FLOAT16) { - if (dstType == ND4J_FLOAT8) { - //sd::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_INT8) { - sd::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_UINT8) { - sd::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_FLOAT16) { - sd::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_INT16) { - sd::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_UINT16) { - sd::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_FLOAT24) { - // TODO: .... ^^^ - } else if (dstType == ND4J_FLOAT32) { - sd::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_DOUBLE) { - sd::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_THRESHOLD) { - //sd::convertToThreshold(nullptr, dx, N, dz); - } else { - nd4j_printf("Unsupported types conversion: [%i] -> [%i]\n", srcType, dstType); - } - } else if (srcType == ND4J_INT16) { - if (dstType == ND4J_FLOAT8) { - //sd::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_INT8) { - sd::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_UINT8) { - sd::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_FLOAT16) { - sd::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_INT16) { - sd::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_UINT16) { - sd::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_FLOAT24) { - // TODO... - } else if (dstType == ND4J_FLOAT32) { - sd::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_DOUBLE) { - sd::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else { - printf("Unsupported types conversion: [%i] -> [%i]\n", srcType, dstType); - } - } else if (srcType == ND4J_FLOAT24) { - - } else if (srcType == ND4J_FLOAT32) { - if (dstType == ND4J_FLOAT8) { - //sd::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_INT8) { - sd::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_UINT8) { - sd::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_FLOAT16) { - sd::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_INT16) { - sd::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_UINT16) { - sd::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_FLOAT24) { - - } else if (dstType == ND4J_DOUBLE) { - sd::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_THRESHOLD) { - //sd::convertToThreshold(nullptr, dx, N, dz); - } else { - nd4j_printf("Unsupported types conversion: [%i] -> [%i]\n", srcType, dstType); - } - } else if (srcType == ND4J_DOUBLE) { - if (dstType == ND4J_FLOAT8) { - //sd::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_INT8) { - sd::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_UINT8) { - sd::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_FLOAT16) { - sd::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_INT16) { - sd::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_UINT16) { - sd::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_FLOAT24) { - - } else if (dstType == ND4J_FLOAT32) { - sd::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_DOUBLE) { - // - } else if (dstType == ND4J_THRESHOLD) { - //sd::convertToThreshold(nullptr, dx, N, dz); - } else { - nd4j_printf("Unsupported types conversion: [%i] -> [%i]\n", srcType, dstType); - } - } else if (srcType == ND4J_THRESHOLD) { - if (dstType == ND4J_FLOAT16) { - //sd::convertFromThreshold(nullptr, dx, N, dz); - } else if (dstType == ND4J_FLOAT32) { - //sd::convertFromThreshold(nullptr, dx, N, dz); - } else if (dstType == ND4J_DOUBLE) { - //sd::convertFromThreshold(nullptr, dx, N, dz); - } else { - nd4j_printf("Unsupported types conversion: [%i] -> [%i]\n", srcType, dstType); - } - } else { - nd4j_printf("Unsupported types conversion: [%i] -> [%i]\n", srcType, dstType); - } - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); +void convertTypes(Nd4jPointer *extras, int srcType, Nd4jPointer dX, Nd4jLong N, + int dstType, Nd4jPointer dZ) { + try { + auto dx = reinterpret_cast(dX); + auto dz = reinterpret_cast(dZ); + + if (srcType == ND4J_FLOAT8) { + if (dstType == ND4J_FLOAT8) { + // convertKernel(extras, dx, N, dz); + } else if (dstType == ND4J_INT8) { + // sd::TypeCast::convertGenericCuda(extras, dx, N, + // dz); + } else if (dstType == ND4J_UINT8) { + // sd::TypeCast::convertGenericCuda(extras, dx, + // N, dz); + } else if (dstType == ND4J_FLOAT16) { + // sd::TypeCast::convertGenericCuda(extras, dx, N, + // dz); + } else if (dstType == ND4J_INT16) { + // sd::TypeCast::convertGenericCuda(extras, dx, + // N, dz); + } else if (dstType == ND4J_UINT16) { + // sd::TypeCast::convertGenericCuda(extras, dx, + // N, dz); + } else if (dstType == ND4J_FLOAT24) { + } else if (dstType == ND4J_FLOAT32) { + // sd::TypeCast::convertGenericCuda(extras, dx, N, + // dz); + } else if (dstType == ND4J_DOUBLE) { + // sd::TypeCast::convertGenericCuda(extras, dx, N, + // dz); + } else { + nd4j_printf("Unsupported types conversion: [%i] -> [%i]\n", srcType, + dstType); + } + } else if (srcType == ND4J_INT8) { + if (dstType == ND4J_FLOAT8) { + // sd::TypeCast::convertGenericCuda(extras, dx, N, + // dz); + } else if (dstType == ND4J_INT8) { + // convertKernel(extras, dx, N, dz); + } else if (dstType == ND4J_UINT8) { + sd::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_FLOAT16) { + sd::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_INT16) { + sd::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_UINT16) { + sd::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_FLOAT24) { + // TODO: eventually we might want to add it + } else if (dstType == ND4J_FLOAT32) { + sd::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_DOUBLE) { + sd::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else { + nd4j_printf("Unsupported types conversion: [%i] -> [%i]\n", srcType, + dstType); + } + } else if (srcType == ND4J_UINT8) { + if (dstType == ND4J_FLOAT8) { + // sd::TypeCast::convertGenericCuda(extras, dx, N, + // dz); + } else if (dstType == ND4J_INT8) { + sd::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_UINT8) { + sd::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_FLOAT16) { + sd::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_INT16) { + sd::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_UINT16) { + sd::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_FLOAT24) { + // TODO: still might want to add + } else if (dstType == ND4J_FLOAT32) { + sd::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_DOUBLE) { + sd::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else { + nd4j_printf("Unsupported types conversion: [%i] -> [%i]\n", srcType, + dstType); + } + } else if (srcType == ND4J_FLOAT16) { + if (dstType == ND4J_FLOAT8) { + // sd::TypeCast::convertGenericCuda(extras, dx, N, + // dz); + } else if (dstType == ND4J_INT8) { + sd::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_UINT8) { + sd::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_FLOAT16) { + sd::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_INT16) { + sd::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_UINT16) { + sd::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_FLOAT24) { + // TODO: .... ^^^ + } else if (dstType == ND4J_FLOAT32) { + sd::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_DOUBLE) { + sd::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_THRESHOLD) { + // sd::convertToThreshold(nullptr, dx, N, dz); + } else { + nd4j_printf("Unsupported types conversion: [%i] -> [%i]\n", srcType, + dstType); + } + } else if (srcType == ND4J_INT16) { + if (dstType == ND4J_FLOAT8) { + // sd::TypeCast::convertGenericCuda(extras, dx, N, + // dz); + } else if (dstType == ND4J_INT8) { + sd::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_UINT8) { + sd::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_FLOAT16) { + sd::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_INT16) { + sd::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_UINT16) { + sd::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_FLOAT24) { + // TODO... + } else if (dstType == ND4J_FLOAT32) { + sd::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_DOUBLE) { + sd::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else { + printf("Unsupported types conversion: [%i] -> [%i]\n", srcType, + dstType); + } + } else if (srcType == ND4J_FLOAT24) { + } else if (srcType == ND4J_FLOAT32) { + if (dstType == ND4J_FLOAT8) { + // sd::TypeCast::convertGenericCuda(extras, dx, N, + // dz); + } else if (dstType == ND4J_INT8) { + sd::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_UINT8) { + sd::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_FLOAT16) { + sd::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_INT16) { + sd::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_UINT16) { + sd::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_FLOAT24) { + } else if (dstType == ND4J_DOUBLE) { + sd::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_THRESHOLD) { + // sd::convertToThreshold(nullptr, dx, N, dz); + } else { + nd4j_printf("Unsupported types conversion: [%i] -> [%i]\n", srcType, + dstType); + } + } else if (srcType == ND4J_DOUBLE) { + if (dstType == ND4J_FLOAT8) { + // sd::TypeCast::convertGenericCuda(extras, dx, N, + // dz); + } else if (dstType == ND4J_INT8) { + sd::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_UINT8) { + sd::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_FLOAT16) { + sd::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_INT16) { + sd::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_UINT16) { + sd::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_FLOAT24) { + } else if (dstType == ND4J_FLOAT32) { + sd::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_DOUBLE) { + // + } else if (dstType == ND4J_THRESHOLD) { + // sd::convertToThreshold(nullptr, dx, N, dz); + } else { + nd4j_printf("Unsupported types conversion: [%i] -> [%i]\n", srcType, + dstType); + } + } else if (srcType == ND4J_THRESHOLD) { + if (dstType == ND4J_FLOAT16) { + // sd::convertFromThreshold(nullptr, dx, N, dz); + } else if (dstType == ND4J_FLOAT32) { + // sd::convertFromThreshold(nullptr, dx, N, dz); + } else if (dstType == ND4J_DOUBLE) { + // sd::convertFromThreshold(nullptr, dx, N, dz); + } else { + nd4j_printf("Unsupported types conversion: [%i] -> [%i]\n", srcType, + dstType); + } + } else { + nd4j_printf("Unsupported types conversion: [%i] -> [%i]\n", srcType, + dstType); } + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } -Nd4jPointer createUtf8String(Nd4jPointer *extraPointers, const char *string, int length) { - auto u = new sd::utf8string(string, length); - return reinterpret_cast(u); +Nd4jPointer createUtf8String(Nd4jPointer *extraPointers, const char *string, + int length) { + auto u = new sd::utf8string(string, length); + return reinterpret_cast(u); } Nd4jLong getUtf8StringLength(Nd4jPointer *extraPointers, Nd4jPointer ptr) { - return reinterpret_cast(ptr)->_length; + return reinterpret_cast(ptr)->_length; } -char* getUtf8StringBuffer(Nd4jPointer *extraPointers, Nd4jPointer ptr) { - return reinterpret_cast(ptr)->_buffer; +char *getUtf8StringBuffer(Nd4jPointer *extraPointers, Nd4jPointer ptr) { + return reinterpret_cast(ptr)->_buffer; } void deleteUtf8String(Nd4jPointer *extraPointers, Nd4jPointer ptr) { - delete(reinterpret_cast(ptr)); + delete (reinterpret_cast(ptr)); } /////////////////////////////////////////////////////////////////// -template -__global__ static void scatterUpdateCuda(const int opCode, const int numOfSubArrs, - void* vx, const Nd4jLong* xShapeInfo, const Nd4jLong *xOffsets, - void* vy, const Nd4jLong *yShapeInfo, const Nd4jLong *yOffsets, - const void* vindexes) { - - __shared__ T *x, *y; - __shared__ Nd4jLong arrLenX, arrLenY; - auto indexes = reinterpret_cast(vindexes); - - for (int e = 0; e < numOfSubArrs; e++ ) { - - const auto xIndex = indexes[e]; - const bool isOwner = xIndex < gridDim.x ? blockIdx.x == xIndex : blockIdx.x == xIndex % gridDim.x; - - if (!isOwner) - continue; +template +__global__ static void scatterUpdateCuda(const int opCode, + const int numOfSubArrs, void *vx, + const Nd4jLong *xShapeInfo, + const Nd4jLong *xOffsets, void *vy, + const Nd4jLong *yShapeInfo, + const Nd4jLong *yOffsets, + const void *vindexes) { + __shared__ T *x, *y; + __shared__ Nd4jLong arrLenX, arrLenY; + auto indexes = reinterpret_cast(vindexes); + + for (int e = 0; e < numOfSubArrs; e++) { + const auto xIndex = indexes[e]; + const bool isOwner = xIndex < gridDim.x ? blockIdx.x == xIndex + : blockIdx.x == xIndex % gridDim.x; + + if (!isOwner) continue; + + if (threadIdx.x == 0) { + x = reinterpret_cast(vx) + xOffsets[xIndex]; + y = reinterpret_cast(vy) + yOffsets[e]; + arrLenX = shape::length(xShapeInfo); + arrLenY = shape::length(yShapeInfo); + } + __syncthreads(); - if (threadIdx.x == 0) { - x = reinterpret_cast(vx) + xOffsets[xIndex]; - y = reinterpret_cast(vy) + yOffsets[e]; - arrLenX = shape::length(xShapeInfo); - arrLenY = shape::length(yShapeInfo); - } - __syncthreads(); - - if (arrLenX != arrLenY) - return; - - for (Nd4jLong i = threadIdx.x; i < arrLenX; i += blockDim.x) { - - const auto xOffset = shape::getIndexOffset(i, xShapeInfo); - const auto yOffset = shape::getIndexOffset(i, yShapeInfo); - - switch (opCode) { - case 0: - x[xOffset] += y[yOffset]; - break; - case 1: - x[xOffset] -= y[yOffset]; - break; - case 2: - x[xOffset] *= y[yOffset]; - break; - case 3: - x[xOffset] /= y[yOffset]; - break; - case 4: - x[xOffset] = y[yOffset] - x[xOffset]; - break; - case 5: - x[xOffset] = y[yOffset] / x[xOffset]; - break; - case 6: - x[xOffset] = y[yOffset]; - break; - default: - continue; - } - } - __syncthreads(); + if (arrLenX != arrLenY) return; + + for (Nd4jLong i = threadIdx.x; i < arrLenX; i += blockDim.x) { + const auto xOffset = shape::getIndexOffset(i, xShapeInfo); + const auto yOffset = shape::getIndexOffset(i, yShapeInfo); + + switch (opCode) { + case 0: + x[xOffset] += y[yOffset]; + break; + case 1: + x[xOffset] -= y[yOffset]; + break; + case 2: + x[xOffset] *= y[yOffset]; + break; + case 3: + x[xOffset] /= y[yOffset]; + break; + case 4: + x[xOffset] = y[yOffset] - x[xOffset]; + break; + case 5: + x[xOffset] = y[yOffset] / x[xOffset]; + break; + case 6: + x[xOffset] = y[yOffset]; + break; + default: + continue; + } } + __syncthreads(); + } } -template -__host__ static void scatterUpdateCudaLauncher(const cudaStream_t* stream, const int opCode, const int numOfSubArrs, void* vx, const Nd4jLong const* xShapeInfo, const Nd4jLong* xOffsets, void* vy, const Nd4jLong *yShapeInfo, const Nd4jLong *yOffsets, const void* indexes) { - - scatterUpdateCuda<<<512, 256, MAX_NUM_THREADS, *stream>>>(opCode, numOfSubArrs, vx, xShapeInfo, xOffsets, vy, yShapeInfo, yOffsets, indexes); +template +__host__ static void scatterUpdateCudaLauncher( + const cudaStream_t *stream, const int opCode, const int numOfSubArrs, + void *vx, const Nd4jLong const *xShapeInfo, const Nd4jLong *xOffsets, + void *vy, const Nd4jLong *yShapeInfo, const Nd4jLong *yOffsets, + const void *indexes) { + scatterUpdateCuda<<<512, 256, MAX_NUM_THREADS, *stream>>>( + opCode, numOfSubArrs, vx, xShapeInfo, xOffsets, vy, yShapeInfo, yOffsets, + indexes); } - ////////////////////////////////////////////////////////////////////////// void scatterUpdate(Nd4jPointer *extraPointers, int opCode, int numOfSubArrs, - void* hX, Nd4jLong const* hXShapeInfo, Nd4jLong const* hXOffsets, - void* dX, Nd4jLong const* dXShapeInfo, Nd4jLong const* dXOffsets, - void* hY, Nd4jLong const* hYShapeInfo, Nd4jLong const* hYOffsets, - void* dY, Nd4jLong const* dYShapeInfo, Nd4jLong const* dYOffsets, - void* hIindexes, Nd4jLong const* hIndicesShapeInfo, void* dIindexes, Nd4jLong const* dIndicesShapeInfo) { - try { - auto stream = reinterpret_cast(extraPointers[1]); - - auto type = ArrayOptions::dataType(hXShapeInfo); - auto iType = ArrayOptions::dataType(hIndicesShapeInfo); - - BUILD_DOUBLE_SELECTOR(type, iType, scatterUpdateCudaLauncher, - (stream, opCode, numOfSubArrs, dX, dXShapeInfo, dXOffsets, dY, dYShapeInfo, dYOffsets, dIindexes), - LIBND4J_TYPES, INDEXING_TYPES); - - sd::DebugHelper::checkErrorCode(stream, "scatterUpdate(...) failed"); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } -} - -void inspectArray(Nd4jPointer *extraPointers, Nd4jPointer buffer, Nd4jLong *shapeInfo, Nd4jPointer specialBuffer, Nd4jLong *specialShapeInfo, Nd4jPointer debugInfo) { - try { - LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - auto p = reinterpret_cast(debugInfo); - NDArray array(buffer, specialBuffer, shapeInfo, &lc); - sd::DebugHelper::retrieveDebugStatistics(p, &array); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } -} - -void __global__ tryPointerKernel(void* p, int len) { - auto buf = reinterpret_cast(p); - auto tid = threadIdx.x + blockIdx.x * blockDim.x; - __shared__ int b; - if (tid < len) - atomicAdd(&b, buf[tid]); - - __syncthreads(); - - if (threadIdx.x ==0 && blockIdx.x == 0) - printf("Pointer check complete: %i\n", b); + void *hX, Nd4jLong const *hXShapeInfo, + Nd4jLong const *hXOffsets, void *dX, + Nd4jLong const *dXShapeInfo, Nd4jLong const *dXOffsets, + void *hY, Nd4jLong const *hYShapeInfo, + Nd4jLong const *hYOffsets, void *dY, + Nd4jLong const *dYShapeInfo, Nd4jLong const *dYOffsets, + void *hIindexes, Nd4jLong const *hIndicesShapeInfo, + void *dIindexes, Nd4jLong const *dIndicesShapeInfo) { + try { + auto stream = reinterpret_cast(extraPointers[1]); + + auto type = ArrayOptions::dataType(hXShapeInfo); + auto iType = ArrayOptions::dataType(hIndicesShapeInfo); + + BUILD_DOUBLE_SELECTOR(type, iType, scatterUpdateCudaLauncher, + (stream, opCode, numOfSubArrs, dX, dXShapeInfo, + dXOffsets, dY, dYShapeInfo, dYOffsets, dIindexes), + LIBND4J_TYPES, INDEXING_TYPES); + + sd::DebugHelper::checkErrorCode(stream, "scatterUpdate(...) failed"); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } +} + +void inspectArray(Nd4jPointer *extraPointers, Nd4jPointer buffer, + Nd4jLong *shapeInfo, Nd4jPointer specialBuffer, + Nd4jLong *specialShapeInfo, Nd4jPointer debugInfo) { + try { + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], + extraPointers[3]); + auto p = reinterpret_cast(debugInfo); + NDArray array(buffer, specialBuffer, shapeInfo, &lc); + sd::DebugHelper::retrieveDebugStatistics(p, &array); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } +} + +void __global__ tryPointerKernel(void *p, int len) { + auto buf = reinterpret_cast(p); + auto tid = threadIdx.x + blockIdx.x * blockDim.x; + __shared__ int b; + if (tid < len) atomicAdd(&b, buf[tid]); + + __syncthreads(); + + if (threadIdx.x == 0 && blockIdx.x == 0) + printf("Pointer check complete: %i\n", b); } void tryPointer(Nd4jPointer extra, Nd4jPointer p, int len) { - try { - cudaStream_t stream; - cudaStreamCreate(&stream); + try { + cudaStream_t stream; + cudaStreamCreate(&stream); - tryPointerKernel <<< 256, 512, len + 64, stream>>> (p, len); - auto e = cudaStreamSynchronize(stream); + tryPointerKernel<<<256, 512, len + 64, stream>>>(p, len); + auto e = cudaStreamSynchronize(stream); - if (e != 0) - throw sd::cuda_exception::build("tryPointer failed", e); + if (e != 0) throw sd::cuda_exception::build("tryPointer failed", e); - cudaStreamDestroy(stream); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } + cudaStreamDestroy(stream); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } int dataTypeFromNpyHeader(void *header) { - return (int) cnpy::dataTypeFromHeader(reinterpret_cast(header)); -} - -OpaqueConstantShapeBuffer* shapeBuffer(int rank, Nd4jLong *shape, Nd4jLong *strides, sd::DataType dtype, char order, Nd4jLong ews, bool empty) { - try { - auto buffer = new ConstantShapeBuffer(); - *buffer = sd::ConstantShapeHelper::getInstance().bufferForShapeInfo( - ShapeDescriptor(dtype, order, shape, strides, rank, ews, empty)); - return buffer; - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - return nullptr; - } + return (int)cnpy::dataTypeFromHeader(reinterpret_cast(header)); } -void deleteConstantShapeBuffer(OpaqueConstantShapeBuffer* ptr) { - delete ptr; +OpaqueConstantShapeBuffer *shapeBuffer(int rank, Nd4jLong *shape, + Nd4jLong *strides, sd::DataType dtype, + char order, Nd4jLong ews, bool empty) { + try { + auto buffer = new ConstantShapeBuffer(); + *buffer = sd::ConstantShapeHelper::getInstance().bufferForShapeInfo( + ShapeDescriptor(dtype, order, shape, strides, rank, ews, empty)); + return buffer; + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + return nullptr; + } } +void deleteConstantShapeBuffer(OpaqueConstantShapeBuffer *ptr) { delete ptr; } + void deleteConstantDataBuffer(OpaqueConstantDataBuffer* ptr) { delete ptr; } -void deleteTadPack(sd::TadPack* ptr) { - delete ptr; -} +void deleteTadPack(sd::TadPack *ptr) { delete ptr; } bool isBlasVersionMatches(int major, int minor, int build) { - auto result = major == Environment::getInstance()._blasMajorVersion && minor == Environment::getInstance()._blasMinorVersion && build == Environment::getInstance()._blasPatchVersion; + auto result = major == Environment::getInstance()._blasMajorVersion && + minor == Environment::getInstance()._blasMinorVersion && + build == Environment::getInstance()._blasPatchVersion; - if (!result) { - nd4j_printf("CUDA/cuBLAS version mismatch. Expected: %i.%i.%i but got %i.%i.%i instead\n", Environment::getInstance()._blasMajorVersion, Environment::getInstance()._blasMinorVersion, Environment::getInstance()._blasPatchVersion, major, minor, build); - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(152); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage("CUDA/cuBLAS version mismatch"); - } + if (!result) { + nd4j_printf( + "CUDA/cuBLAS version mismatch. Expected: %i.%i.%i but got %i.%i.%i " + "instead\n", + Environment::getInstance()._blasMajorVersion, + Environment::getInstance()._blasMinorVersion, + Environment::getInstance()._blasPatchVersion, major, minor, build); + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(152); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + "CUDA/cuBLAS version mismatch"); + } - return result; + return result; } -sd::ConstantDataBuffer* constantBufferLong(sd::DataType dtype, Nd4jLong const* data, int length) { - return sd::ConstantHelper::getInstance().constantBuffer(ConstantDescriptor(data, length), dtype); +sd::ConstantDataBuffer *constantBufferLong(sd::DataType dtype, + Nd4jLong const *data, int length) { + return sd::ConstantHelper::getInstance().constantBuffer( + ConstantDescriptor(data, length), dtype); } -sd::ConstantDataBuffer* constantBufferDouble(sd::DataType dtype, double *data, int length) { - return sd::ConstantHelper::getInstance().constantBuffer(ConstantDescriptor(data, length), dtype); +sd::ConstantDataBuffer *constantBufferDouble(sd::DataType dtype, double *data, + int length) { + return sd::ConstantHelper::getInstance().constantBuffer( + ConstantDescriptor(data, length), dtype); } -sd::ConstantDataBuffer* constantBuffer(sd::DataType dtype, sd::ConstantDescriptor *descriptor) { - return sd::ConstantHelper::getInstance().constantBuffer(*descriptor, dtype); +sd::ConstantDataBuffer *constantBuffer(sd::DataType dtype, + sd::ConstantDescriptor *descriptor) { + return sd::ConstantHelper::getInstance().constantBuffer(*descriptor, dtype); } - -Nd4jPointer getConstantDataBufferPrimary(sd::ConstantDataBuffer* dbf) { - return dbf->primary(); +Nd4jPointer getConstantDataBufferPrimary(sd::ConstantDataBuffer *dbf) { + return dbf->primary(); } -Nd4jPointer getConstantDataBufferSpecial(sd::ConstantDataBuffer* dbf) { - return dbf->special(); +Nd4jPointer getConstantDataBufferSpecial(sd::ConstantDataBuffer *dbf) { + return dbf->special(); } -Nd4jLong getConstantDataBufferLength(sd::ConstantDataBuffer* dbf) { - return dbf->length(); +Nd4jLong getConstantDataBufferLength(sd::ConstantDataBuffer *dbf) { + return dbf->length(); } -Nd4jLong getConstantDataBufferSizeOf(sd::ConstantDataBuffer* dbf) { - return dbf->sizeOf(); +Nd4jLong getConstantDataBufferSizeOf(sd::ConstantDataBuffer *dbf) { + return dbf->sizeOf(); } Nd4jPointer getConstantShapeBufferPrimary(sd::ConstantShapeBuffer* dbf) { @@ -3446,85 +3744,99 @@ Nd4jPointer getConstantShapeBufferPrimary(sd::ConstantShapeBuffer* dbf) { Nd4jPointer getConstantShapeBufferSpecial(sd::ConstantShapeBuffer* dbf) { return const_cast(dbf->special()); +}sd::graph::Context *createGraphContext(int nodeId) { + return new sd::graph::Context(nodeId); } -sd::graph::Context* createGraphContext(int nodeId) { - return new sd::graph::Context(nodeId); -} - -sd::graph::RandomGenerator* getGraphContextRandomGenerator(sd::graph::Context* ptr) { - return &ptr->randomGenerator(); +sd::graph::RandomGenerator *getGraphContextRandomGenerator( + sd::graph::Context *ptr) { + return &ptr->randomGenerator(); } -void markGraphContextInplace(sd::graph::Context* ptr, bool reallyInplace) { - ptr->markInplace(reallyInplace); +void markGraphContextInplace(sd::graph::Context *ptr, bool reallyInplace) { + ptr->markInplace(reallyInplace); } -void setGraphContextCudaContext(sd::graph::Context* ptr, void *stream, void *reductionPointer, void *allocationPointer) { - ptr->setCudaContext(stream, reductionPointer, allocationPointer); +void setGraphContextCudaContext(sd::graph::Context *ptr, void *stream, + void *reductionPointer, + void *allocationPointer) { + ptr->setCudaContext(stream, reductionPointer, allocationPointer); } -void setGraphContextInputArray(sd::graph::Context* ptr, int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo) { - ptr->setInputArray(index, buffer, shapeInfo, specialBuffer, specialShapeInfo); +void setGraphContextInputArray(sd::graph::Context *ptr, int index, void *buffer, + void *shapeInfo, void *specialBuffer, + void *specialShapeInfo) { + ptr->setInputArray(index, buffer, shapeInfo, specialBuffer, specialShapeInfo); } -void setGraphContextOutputArray(sd::graph::Context* ptr, int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo) { - ptr->setOutputArray(index, buffer, shapeInfo, specialBuffer, specialShapeInfo); +void setGraphContextOutputArray(sd::graph::Context *ptr, int index, + void *buffer, void *shapeInfo, + void *specialBuffer, void *specialShapeInfo) { + ptr->setOutputArray(index, buffer, shapeInfo, specialBuffer, + specialShapeInfo); } -void setGraphContextInputBuffer(OpaqueContext* ptr, int index, OpaqueDataBuffer *buffer, void *shapeInfo, void *specialShapeInfo) { - ptr->setInputArray(index, buffer, shapeInfo, specialShapeInfo); +void setGraphContextInputBuffer(OpaqueContext *ptr, int index, + OpaqueDataBuffer *buffer, void *shapeInfo, + void *specialShapeInfo) { + ptr->setInputArray(index, buffer, shapeInfo, specialShapeInfo); } -void setGraphContextOutputBuffer(OpaqueContext* ptr, int index, OpaqueDataBuffer *buffer, void *shapeInfo, void *specialShapeInfo) { - ptr->setOutputArray(index, buffer, shapeInfo, specialShapeInfo); +void setGraphContextOutputBuffer(OpaqueContext *ptr, int index, + OpaqueDataBuffer *buffer, void *shapeInfo, + void *specialShapeInfo) { + ptr->setOutputArray(index, buffer, shapeInfo, specialShapeInfo); } -void setGraphContextTArguments(sd::graph::Context* ptr, double *arguments, int numberOfArguments) { - ptr->setTArguments(arguments, numberOfArguments); +void setGraphContextTArguments(sd::graph::Context *ptr, double *arguments, + int numberOfArguments) { + ptr->setTArguments(arguments, numberOfArguments); } -void setGraphContextIArguments(sd::graph::Context* ptr, Nd4jLong *arguments, int numberOfArguments) { - ptr->setIArguments(arguments, numberOfArguments); +void setGraphContextIArguments(sd::graph::Context *ptr, Nd4jLong *arguments, + int numberOfArguments) { + ptr->setIArguments(arguments, numberOfArguments); } -void setGraphContextBArguments(sd::graph::Context* ptr, bool *arguments, int numberOfArguments) { - ptr->setBArguments(arguments, numberOfArguments); +void setGraphContextBArguments(sd::graph::Context *ptr, bool *arguments, + int numberOfArguments) { + ptr->setBArguments(arguments, numberOfArguments); } -void setGraphContextDArguments(OpaqueContext* ptr, int *arguments, int numberOfArguments) { - std::vector dtypes(numberOfArguments); - for (int e = 0; e < numberOfArguments; e++) - dtypes[e] = (sd::DataType) arguments[e]; +void setGraphContextDArguments(OpaqueContext *ptr, int *arguments, + int numberOfArguments) { + std::vector dtypes(numberOfArguments); + for (int e = 0; e < numberOfArguments; e++) + dtypes[e] = (sd::DataType)arguments[e]; - ptr->setDArguments(dtypes); + ptr->setDArguments(dtypes); } -void deleteGraphContext(sd::graph::Context* ptr) { - delete ptr; -} +void deleteGraphContext(sd::graph::Context *ptr) { delete ptr; } - -sd::graph::RandomGenerator* createRandomGenerator(Nd4jLong rootSeed, Nd4jLong nodeSeed) { - try { - return new sd::graph::RandomGenerator(rootSeed, nodeSeed); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - return nullptr; - } +sd::graph::RandomGenerator *createRandomGenerator(Nd4jLong rootSeed, + Nd4jLong nodeSeed) { + try { + return new sd::graph::RandomGenerator(rootSeed, nodeSeed); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + return nullptr; + } } -Nd4jLong getRandomGeneratorRootState(sd::graph::RandomGenerator* ptr) { - return ptr->rootState(); +Nd4jLong getRandomGeneratorRootState(sd::graph::RandomGenerator *ptr) { + return ptr->rootState(); } -Nd4jLong getRandomGeneratorNodeState(sd::graph::RandomGenerator* ptr) { - return ptr->nodeState(); +Nd4jLong getRandomGeneratorNodeState(sd::graph::RandomGenerator *ptr) { + return ptr->nodeState(); } -void setRandomGeneratorStates(sd::graph::RandomGenerator* ptr, Nd4jLong rootSeed, Nd4jLong nodeSeed) { - ptr->setStates(rootSeed, nodeSeed); +void setRandomGeneratorStates(sd::graph::RandomGenerator *ptr, + Nd4jLong rootSeed, Nd4jLong nodeSeed) { + ptr->setStates(rootSeed, nodeSeed); } float getRandomGeneratorRelativeFloat(sd::graph::RandomGenerator* ptr, Nd4jLong index) { @@ -3535,289 +3847,289 @@ double getRandomGeneratorRelativeDouble(sd::graph::RandomGenerator* ptr, Nd4jLon return ptr->relativeT(index); } -int getRandomGeneratorRelativeInt(sd::graph::RandomGenerator* ptr, Nd4jLong index) { - return ptr->relativeInt(index); +int getRandomGeneratorRelativeInt(sd::graph::RandomGenerator *ptr, + Nd4jLong index) { + return ptr->relativeInt(index); } -Nd4jLong getRandomGeneratorRelativeLong(sd::graph::RandomGenerator* ptr, Nd4jLong index) { - return ptr->relativeLong(index); -} - -void deleteRandomGenerator(sd::graph::RandomGenerator* ptr) { - delete ptr; +Nd4jLong getRandomGeneratorRelativeLong(sd::graph::RandomGenerator *ptr, + Nd4jLong index) { + return ptr->relativeLong(index); } +void deleteRandomGenerator(sd::graph::RandomGenerator *ptr) { delete ptr; } Nd4jPointer shapeBufferForNumpy(Nd4jPointer npyArray) { - try { - cnpy::NpyArray arr = cnpy::loadNpyFromPointer(reinterpret_cast(npyArray)); - unsigned int shapeSize = arr.shape.size(); - std::vector shape(shapeSize); - bool _empty = false; - for (unsigned int i = 0; i < shapeSize; i++) { - shape[i] = arr.shape[i]; - - if (arr.shape[i] == 0) - _empty = true; - } - - auto dtype = cnpy::dataTypeFromHeader(reinterpret_cast(npyArray)); - - Nd4jLong *shapeBuffer; - if (shape.size() == 1 && shape[0] == 0) { - // scalar case - shapeBuffer = sd::ShapeBuilders::createScalarShapeInfo(dtype); - } else if (_empty) { - if (shapeSize > 0) - shapeBuffer = sd::ShapeBuilders::emptyShapeInfo(dtype, arr.fortranOrder ? 'f' : 'c', shape); - else - shapeBuffer = sd::ShapeBuilders::emptyShapeInfo(dtype); - } else { - shapeBuffer = sd::ShapeBuilders::createShapeInfo(dtype, arr.fortranOrder ? 'f' : 'c', shape); - } - return (Nd4jPointer)(sd::ConstantShapeHelper::getInstance().createFromExisting(shapeBuffer, true)); // TO DO: this can lead to unpleasant crash sometimes - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - return nullptr; + try { + cnpy::NpyArray arr = + cnpy::loadNpyFromPointer(reinterpret_cast(npyArray)); + unsigned int shapeSize = arr.shape.size(); + std::vector shape(shapeSize); + bool _empty = false; + for (unsigned int i = 0; i < shapeSize; i++) { + shape[i] = arr.shape[i]; + + if (arr.shape[i] == 0) _empty = true; } -} -const char* runLightBenchmarkSuit(bool printOut) { - try { - sd::LightBenchmarkSuit suit; - auto result = suit.runSuit(); - - if (printOut) - nd4j_printf("%s\n", result.data()); - - auto chars = new char[result.length() + 1]; - std::memcpy(chars, result.data(), result.length()); - chars[result.length()] = (char) 0x0; - - return chars; - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - return nullptr; - } -} - -const char* runFullBenchmarkSuit(bool printOut) { - try { - sd::FullBenchmarkSuit suit; - auto result = suit.runSuit(); - - if (printOut) - nd4j_printf("%s\n", result.data()); - - auto chars = new char[result.length() + 1]; - std::memcpy(chars, result.data(), result.length()); - chars[result.length()] = (char) 0x0; - - return chars; - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - return nullptr; + auto dtype = cnpy::dataTypeFromHeader(reinterpret_cast(npyArray)); + + Nd4jLong *shapeBuffer; + if (shape.size() == 1 && shape[0] == 0) { + // scalar case + shapeBuffer = sd::ShapeBuilders::createScalarShapeInfo(dtype); + } else if (_empty) { + if (shapeSize > 0) + shapeBuffer = sd::ShapeBuilders::emptyShapeInfo( + dtype, arr.fortranOrder ? 'f' : 'c', shape); + else + shapeBuffer = sd::ShapeBuilders::emptyShapeInfo(dtype); + } else { + shapeBuffer = sd::ShapeBuilders::createShapeInfo( + dtype, arr.fortranOrder ? 'f' : 'c', shape); } + return (Nd4jPointer)( + sd::ConstantShapeHelper::getInstance().createFromExisting( + shapeBuffer, + true)); // TO DO: this can lead to unpleasant crash sometimes + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + return nullptr; + } +} + +const char *runLightBenchmarkSuit(bool printOut) { + try { + sd::LightBenchmarkSuit suit; + auto result = suit.runSuit(); + + if (printOut) nd4j_printf("%s\n", result.data()); + + auto chars = new char[result.length() + 1]; + std::memcpy(chars, result.data(), result.length()); + chars[result.length()] = (char)0x0; + + return chars; + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + return nullptr; + } +} + +const char *runFullBenchmarkSuit(bool printOut) { + try { + sd::FullBenchmarkSuit suit; + auto result = suit.runSuit(); + + if (printOut) nd4j_printf("%s\n", result.data()); + + auto chars = new char[result.length() + 1]; + std::memcpy(chars, result.data(), result.length()); + chars[result.length()] = (char)0x0; + + return chars; + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + return nullptr; + } } Nd4jLong getCachedMemory(int deviceId) { - return sd::ConstantHelper::getInstance().getCachedAmount(deviceId); + return sd::ConstantHelper::getInstance().getCachedAmount(deviceId); } -sd::LaunchContext* defaultLaunchContext() { - return LaunchContext::defaultContext(); +sd::LaunchContext *defaultLaunchContext() { + return LaunchContext::defaultContext(); } -Nd4jPointer lcScalarPointer(OpaqueLaunchContext* lc) { - return lc->getScalarPointer(); +Nd4jPointer lcScalarPointer(OpaqueLaunchContext *lc) { + return lc->getScalarPointer(); } -Nd4jPointer lcReductionPointer(OpaqueLaunchContext* lc) { - return lc->getReductionPointer(); +Nd4jPointer lcReductionPointer(OpaqueLaunchContext *lc) { + return lc->getReductionPointer(); } -Nd4jPointer lcAllocationPointer(OpaqueLaunchContext* lc) { - return lc->getAllocationPointer(); +Nd4jPointer lcAllocationPointer(OpaqueLaunchContext *lc) { + return lc->getAllocationPointer(); } -Nd4jPointer lcExecutionStream(OpaqueLaunchContext* lc) { - return lc->getCudaStream(); +Nd4jPointer lcExecutionStream(OpaqueLaunchContext *lc) { + return lc->getCudaStream(); } -Nd4jPointer lcCopyStream(OpaqueLaunchContext* lc) { - return lc->getCudaSpecialStream(); +Nd4jPointer lcCopyStream(OpaqueLaunchContext *lc) { + return lc->getCudaSpecialStream(); } -Nd4jPointer lcBlasHandle(OpaqueLaunchContext* lc) { - return lc->getCublasHandle(); +Nd4jPointer lcBlasHandle(OpaqueLaunchContext *lc) { + return lc->getCublasHandle(); } -Nd4jPointer lcSolverHandle(OpaqueLaunchContext* lc) { - return lc->getCusolverHandle(); +Nd4jPointer lcSolverHandle(OpaqueLaunchContext *lc) { + return lc->getCusolverHandle(); } int lastErrorCode() { - return sd::LaunchContext::defaultContext()->errorReference()->errorCode(); + return sd::LaunchContext::defaultContext()->errorReference()->errorCode(); } -const char* lastErrorMessage() { - return sd::LaunchContext::defaultContext()->errorReference()->errorMessage(); +const char *lastErrorMessage() { + return sd::LaunchContext::defaultContext()->errorReference()->errorMessage(); } -void ctxShapeFunctionOverride(OpaqueContext* ptr, bool reallyOverride) { - ptr->setShapeFunctionOverride(reallyOverride); +void ctxShapeFunctionOverride(OpaqueContext *ptr, bool reallyOverride) { + ptr->setShapeFunctionOverride(reallyOverride); } -void ctxPurge(OpaqueContext* ptr) { - ptr->clearFastPath(); -} +void ctxPurge(OpaqueContext *ptr) { ptr->clearFastPath(); } -int binaryLevel() { - return 0; -} +int binaryLevel() { return 0; } -int optimalLevel() { - return 0; -} +int optimalLevel() { return 0; } -bool isMinimalRequirementsMet() { - return true; -} +bool isMinimalRequirementsMet() { return true; } -bool isOptimalRequirementsMet() { - return true; -} +bool isOptimalRequirementsMet() { return true; } -void ctxAllowHelpers(OpaqueContext* ptr, bool reallyAllow) { - ptr->allowHelpers(reallyAllow); +void ctxAllowHelpers(OpaqueContext *ptr, bool reallyAllow) { + ptr->allowHelpers(reallyAllow); } -void ctxSetExecutionMode(OpaqueContext* ptr, int execMode) { - if (execMode < 0 || execMode > 2) - execMode = 0; +void ctxSetExecutionMode(OpaqueContext *ptr, int execMode) { + if (execMode < 0 || execMode > 2) execMode = 0; - ptr->setExecutionMode((samediff::ExecutionMode) execMode); + ptr->setExecutionMode((samediff::ExecutionMode)execMode); } -OpaqueDataBuffer* dbCreateExternalDataBuffer(Nd4jLong elements, int dataType, Nd4jPointer primary, Nd4jPointer special) { - auto buffer = dbAllocateDataBuffer(0, dataType, false); +OpaqueDataBuffer *dbCreateExternalDataBuffer(Nd4jLong elements, int dataType, + Nd4jPointer primary, + Nd4jPointer special) { + auto buffer = dbAllocateDataBuffer(0, dataType, false); - if (primary != nullptr) - buffer->setPrimary(primary, elements); + if (primary != nullptr) buffer->setPrimary(primary, elements); - if (special != nullptr) - buffer->setSpecial(special, elements); + if (special != nullptr) buffer->setSpecial(special, elements); - return buffer; + return buffer; } -OpaqueDataBuffer* dbAllocateDataBuffer(Nd4jLong elements, int dataType, bool allocateBoth) { - return allocateDataBuffer(elements, dataType, allocateBoth); +OpaqueDataBuffer *dbAllocateDataBuffer(Nd4jLong elements, int dataType, + bool allocateBoth) { + return allocateDataBuffer(elements, dataType, allocateBoth); } -OpaqueDataBuffer* allocateDataBuffer(Nd4jLong elements, int dataType, bool allocateBoth) { - try { - auto dtype = DataTypeUtils::fromInt(dataType); - return new sd::InteropDataBuffer(elements * DataTypeUtils::sizeOf(dtype), dtype, allocateBoth); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - return nullptr; - } +OpaqueDataBuffer *allocateDataBuffer(Nd4jLong elements, int dataType, + bool allocateBoth) { + try { + auto dtype = DataTypeUtils::fromInt(dataType); + return new sd::InteropDataBuffer(elements * DataTypeUtils::sizeOf(dtype), + dtype, allocateBoth); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + return nullptr; + } } Nd4jPointer dbPrimaryBuffer(OpaqueDataBuffer *dataBuffer) { - return dataBuffer->primary(); + return dataBuffer->primary(); } Nd4jPointer dbSpecialBuffer(OpaqueDataBuffer *dataBuffer) { - return dataBuffer->special(); + return dataBuffer->special(); } -void deleteDataBuffer(OpaqueDataBuffer *dataBuffer) { - delete dataBuffer; -} +void deleteDataBuffer(OpaqueDataBuffer *dataBuffer) { delete dataBuffer; } -void dbSetPrimaryBuffer(OpaqueDataBuffer *dataBuffer, Nd4jPointer primaryBuffer, Nd4jLong numBytes) { - dataBuffer->setPrimary(primaryBuffer, numBytes); +void dbSetPrimaryBuffer(OpaqueDataBuffer *dataBuffer, Nd4jPointer primaryBuffer, + Nd4jLong numBytes) { + dataBuffer->setPrimary(primaryBuffer, numBytes); } -void dbSetSpecialBuffer(OpaqueDataBuffer *dataBuffer, Nd4jPointer specialBuffer, Nd4jLong numBytes) { - dataBuffer->setSpecial(specialBuffer, numBytes); +void dbSetSpecialBuffer(OpaqueDataBuffer *dataBuffer, Nd4jPointer specialBuffer, + Nd4jLong numBytes) { + dataBuffer->setSpecial(specialBuffer, numBytes); } void dbAllocatePrimaryBuffer(OpaqueDataBuffer *dataBuffer) { - dataBuffer->dataBuffer()->allocatePrimary(); + dataBuffer->dataBuffer()->allocatePrimary(); } void dbAllocateSpecialBuffer(OpaqueDataBuffer *dataBuffer) { - dataBuffer->dataBuffer()->allocateSpecial(); + dataBuffer->dataBuffer()->allocateSpecial(); } void dbExpandBuffer(OpaqueDataBuffer *dataBuffer, Nd4jLong elements) { - try { - dataBuffer->dataBuffer()->expand(elements * DataTypeUtils::sizeOf(dataBuffer->dataBuffer()->getDataType())); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } + try { + dataBuffer->dataBuffer()->expand( + elements * + DataTypeUtils::sizeOf(dataBuffer->dataBuffer()->getDataType())); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } -OpaqueDataBuffer* dbCreateView(OpaqueDataBuffer *dataBuffer, Nd4jLong length, Nd4jLong offset) { - return new InteropDataBuffer(*dataBuffer, length, offset); +OpaqueDataBuffer *dbCreateView(OpaqueDataBuffer *dataBuffer, Nd4jLong length, + Nd4jLong offset) { + return new InteropDataBuffer(*dataBuffer, length, offset); } void dbSyncToSpecial(OpaqueDataBuffer *dataBuffer) { - dataBuffer->dataBuffer()->syncToSpecial(); + dataBuffer->dataBuffer()->syncToSpecial(); } void dbSyncToPrimary(OpaqueDataBuffer *dataBuffer) { - dataBuffer->dataBuffer()->syncToPrimary(nullptr); + dataBuffer->dataBuffer()->syncToPrimary(nullptr); } void dbTickHostRead(OpaqueDataBuffer *dataBuffer) { - dataBuffer->dataBuffer()->readPrimary(); + dataBuffer->dataBuffer()->readPrimary(); } void dbTickHostWrite(OpaqueDataBuffer *dataBuffer) { - dataBuffer->dataBuffer()->writePrimary(); + dataBuffer->dataBuffer()->writePrimary(); } void dbTickDeviceRead(OpaqueDataBuffer *dataBuffer) { - dataBuffer->dataBuffer()->readSpecial(); + dataBuffer->dataBuffer()->readSpecial(); } void dbTickDeviceWrite(OpaqueDataBuffer *dataBuffer) { - dataBuffer->dataBuffer()->writeSpecial(); + dataBuffer->dataBuffer()->writeSpecial(); } void dbExpand(OpaqueDataBuffer *dataBuffer, Nd4jLong elements) { - dataBuffer->expand(elements); + dataBuffer->expand(elements); } void dbClose(OpaqueDataBuffer *dataBuffer) { - dataBuffer->getDataBuffer()->close(); + dataBuffer->getDataBuffer()->close(); } -int dbDeviceId(OpaqueDataBuffer *dataBuffer) { - return dataBuffer->deviceId(); -} +int dbDeviceId(OpaqueDataBuffer *dataBuffer) { return dataBuffer->deviceId(); } void dbSetDeviceId(OpaqueDataBuffer *dataBuffer, int deviceId) { - dataBuffer->setDeviceId(deviceId); + dataBuffer->setDeviceId(deviceId); } int dbLocality(OpaqueDataBuffer *dataBuffer) { - auto p = dataBuffer->dataBuffer()->isPrimaryActual(); - auto d = dataBuffer->dataBuffer()->isSpecialActual(); - - if (p && d) - return 0; - else if (p) - return -1; - else - return 1; + auto p = dataBuffer->dataBuffer()->isPrimaryActual(); + auto d = dataBuffer->dataBuffer()->isSpecialActual(); + + if (p && d) + return 0; + else if (p) + return -1; + else + return 1; } \ No newline at end of file diff --git a/libnd4j/include/legacy/impl/Environment.cpp b/libnd4j/include/legacy/impl/Environment.cpp index 38d7e82ed475..9659f4c1854e 100644 --- a/libnd4j/include/legacy/impl/Environment.cpp +++ b/libnd4j/include/legacy/impl/Environment.cpp @@ -18,16 +18,18 @@ // Created by raver119 on 06.10.2017. // -#include -#include -#include -#include #include "system/Environment.h" + #include -#include #include #include +#include +#include +#include +#include +#include + #ifdef _OPENMP #include @@ -43,342 +45,318 @@ namespace sd { - sd::Environment::Environment() { - _tadThreshold.store(1); - _elementThreshold.store(1024); - _verbose.store(false); - _debug.store(false); - _profile.store(false); - _precBoost.store(false); - _leaks.store(false); - _dataType.store(sd::DataType::FLOAT32); - _maxThreads = std::thread::hardware_concurrency(); - _maxMasterThreads = _maxThreads.load(); +sd::Environment::Environment() { + _tadThreshold.store(1); + _elementThreshold.store(1024); + _verbose.store(false); + _debug.store(false); + _profile.store(false); + _precBoost.store(false); + _leaks.store(false); + _dataType.store(sd::DataType::FLOAT32); + _maxThreads = std::thread::hardware_concurrency(); + _maxMasterThreads = _maxThreads.load(); #ifndef ANDROID - const char* omp_threads = std::getenv("OMP_NUM_THREADS"); - if (omp_threads != nullptr) { - try { - std::string omp(omp_threads); - int val = std::stoi(omp); - _maxThreads.store(val); - _maxMasterThreads.store(val); - } catch (std::invalid_argument &e) { - // just do nothing - } catch (std::out_of_range &e) { - // still do nothing - } - } - - /** - * Defines size of thread pool used for parallelism - */ - const char* max_threads = std::getenv("SD_MAX_THREADS"); - if (max_threads != nullptr) { - try { - std::string t(max_threads); - int val = std::stoi(t); - _maxThreads.store(val); - } catch (std::invalid_argument &e) { - // just do nothing - } catch (std::out_of_range &e) { - // still do nothing - } - } - - /** - * Defines max number of threads usable at once - */ - const char* max_master_threads = std::getenv("SD_MASTER_THREADS"); - if (max_master_threads != nullptr) { - try { - std::string t(max_master_threads); - int val = std::stoi(t); - _maxMasterThreads.store(val); - } catch (std::invalid_argument &e) { - // just do nothing - } catch (std::out_of_range &e) { - // still do nothing - } - } - - if (_maxMasterThreads.load() > _maxThreads.load()) { - nd4j_printf("Warning! MAX_MASTER_THREADS > MAX_THREADS, tuning them down to match each other\n",""); - _maxMasterThreads.store(_maxThreads.load()); - } - - /** - * If this env var is defined - we'll disallow use of platform-specific helpers (mkldnn, cudnn, etc) - */ - const char* forbid_helpers = std::getenv("SD_FORBID_HELPERS"); - if (max_master_threads != nullptr) { - _allowHelpers = false; - } - - /** - * This var defines max amount of host memory library can allocate - */ - const char* max_primary_memory = std::getenv("SD_MAX_PRIMARY_BYTES"); - if (max_primary_memory != nullptr) { - try { - std::string t(max_primary_memory); - auto val = std::stol(t); - _maxTotalPrimaryMemory.store(val); - } catch (std::invalid_argument &e) { - // just do nothing - } catch (std::out_of_range &e) { - // still do nothing - } - } - - /** - * This var defines max amount of special (i.e. device) memory library can allocate on all devices combined - */ - const char* max_special_memory = std::getenv("SD_MAX_SPECIAL_BYTES"); - if (max_special_memory != nullptr) { - try { - std::string t(max_special_memory); - auto val = std::stol(t); - _maxTotalSpecialMemory.store(val); - } catch (std::invalid_argument &e) { - // just do nothing - } catch (std::out_of_range &e) { - // still do nothing - } - } - - /** - * This var defines max amount of special (i.e. device) memory library can allocate on all devices combined - */ - const char* max_device_memory = std::getenv("SD_MAX_DEVICE_BYTES"); - if (max_device_memory != nullptr) { - try { - std::string t(max_device_memory); - auto val = std::stol(t); - _maxDeviceMemory.store(val); - } catch (std::invalid_argument &e) { - // just do nothing - } catch (std::out_of_range &e) { - // still do nothing - } - } - - const char* blas_fallback = std::getenv("SD_BLAS_FALLBACK"); - if (blas_fallback != nullptr) { - _blasFallback = true; - } + const char *omp_threads = std::getenv("OMP_NUM_THREADS"); + if (omp_threads != nullptr) { + try { + std::string omp(omp_threads); + int val = std::stoi(omp); + _maxThreads.store(val); + _maxMasterThreads.store(val); + } catch (std::invalid_argument &e) { + // just do nothing + } catch (std::out_of_range &e) { + // still do nothing + } + } + + /** + * Defines size of thread pool used for parallelism + */ + const char *max_threads = std::getenv("SD_MAX_THREADS"); + if (max_threads != nullptr) { + try { + std::string t(max_threads); + int val = std::stoi(t); + _maxThreads.store(val); + } catch (std::invalid_argument &e) { + // just do nothing + } catch (std::out_of_range &e) { + // still do nothing + } + } + + /** + * Defines max number of threads usable at once + */ + const char *max_master_threads = std::getenv("SD_MASTER_THREADS"); + if (max_master_threads != nullptr) { + try { + std::string t(max_master_threads); + int val = std::stoi(t); + _maxMasterThreads.store(val); + } catch (std::invalid_argument &e) { + // just do nothing + } catch (std::out_of_range &e) { + // still do nothing + } + } + + if (_maxMasterThreads.load() > _maxThreads.load()) { + nd4j_printf( + "Warning! MAX_MASTER_THREADS > MAX_THREADS, tuning them down to match " + "each other\n", + ""); + _maxMasterThreads.store(_maxThreads.load()); + } + + /** + * If this env var is defined - we'll disallow use of platform-specific + * helpers (mkldnn, cudnn, etc) + */ + const char *forbid_helpers = std::getenv("SD_FORBID_HELPERS"); + if (max_master_threads != nullptr) { + _allowHelpers = false; + } + + /** + * This var defines max amount of host memory library can allocate + */ + const char *max_primary_memory = std::getenv("SD_MAX_PRIMARY_BYTES"); + if (max_primary_memory != nullptr) { + try { + std::string t(max_primary_memory); + auto val = std::stol(t); + _maxTotalPrimaryMemory.store(val); + } catch (std::invalid_argument &e) { + // just do nothing + } catch (std::out_of_range &e) { + // still do nothing + } + } + + /** + * This var defines max amount of special (i.e. device) memory library can + * allocate on all devices combined + */ + const char *max_special_memory = std::getenv("SD_MAX_SPECIAL_BYTES"); + if (max_special_memory != nullptr) { + try { + std::string t(max_special_memory); + auto val = std::stol(t); + _maxTotalSpecialMemory.store(val); + } catch (std::invalid_argument &e) { + // just do nothing + } catch (std::out_of_range &e) { + // still do nothing + } + } + + /** + * This var defines max amount of special (i.e. device) memory library can + * allocate on all devices combined + */ + const char *max_device_memory = std::getenv("SD_MAX_DEVICE_BYTES"); + if (max_device_memory != nullptr) { + try { + std::string t(max_device_memory); + auto val = std::stol(t); + _maxDeviceMemory.store(val); + } catch (std::invalid_argument &e) { + // just do nothing + } catch (std::out_of_range &e) { + // still do nothing + } + } + + const char *blas_fallback = std::getenv("SD_BLAS_FALLBACK"); + if (blas_fallback != nullptr) { + _blasFallback = true; + } #endif #ifdef __CUDABLAS__ - int devCnt = 0; - cudaGetDeviceCount(&devCnt); - auto devProperties = new cudaDeviceProp[devCnt]; - for (int i = 0; i < devCnt; i++) { - cudaSetDevice(i); - cudaGetDeviceProperties(&devProperties[i], i); - - //cudaDeviceSetLimit(cudaLimitStackSize, 4096); - Pair p(devProperties[i].major, devProperties[i].minor); - _capabilities.emplace_back(p); - } - - BlasVersionHelper ver; - _blasMajorVersion = ver._blasMajorVersion; - _blasMinorVersion = ver._blasMinorVersion; - _blasPatchVersion = ver._blasPatchVersion; - - cudaSetDevice(0); - delete[] devProperties; + int devCnt = 0; + cudaGetDeviceCount(&devCnt); + auto devProperties = new cudaDeviceProp[devCnt]; + for (int i = 0; i < devCnt; i++) { + cudaSetDevice(i); + cudaGetDeviceProperties(&devProperties[i], i); + + // cudaDeviceSetLimit(cudaLimitStackSize, 4096); + Pair p(devProperties[i].major, devProperties[i].minor); + _capabilities.emplace_back(p); + } + + BlasVersionHelper ver; + _blasMajorVersion = ver._blasMajorVersion; + _blasMinorVersion = ver._blasMinorVersion; + _blasPatchVersion = ver._blasPatchVersion; + + cudaSetDevice(0); + delete[] devProperties; #else #endif - } +} - bool sd::Environment::blasFallback() { - return _blasFallback; - } +bool sd::Environment::blasFallback() { return _blasFallback; } - sd::Environment::~Environment() { - // - } +sd::Environment::~Environment() { + // +} - void Environment::setMaxPrimaryMemory(uint64_t maxBytes) { - _maxTotalPrimaryMemory = maxBytes; - } +void Environment::setMaxPrimaryMemory(uint64_t maxBytes) { + _maxTotalPrimaryMemory = maxBytes; +} - void Environment::setMaxSpecialyMemory(uint64_t maxBytes) { - _maxTotalSpecialMemory = maxBytes; - } +void Environment::setMaxSpecialyMemory(uint64_t maxBytes) { + _maxTotalSpecialMemory = maxBytes; +} - void Environment::setMaxDeviceMemory(uint64_t maxBytes) { - _maxDeviceMemory = maxBytes; - } +void Environment::setMaxDeviceMemory(uint64_t maxBytes) { + _maxDeviceMemory = maxBytes; +} - Environment& Environment::getInstance() { - static Environment instance; - return instance; - } +Environment &Environment::getInstance() { + static Environment instance; - bool Environment::isVerbose() { - return _verbose.load(); - } + return instance; +} - bool Environment::isExperimentalBuild() { - return _experimental; - } +bool Environment::isVerbose() { return _verbose.load(); } - sd::DataType Environment::defaultFloatDataType() { - return _dataType.load(); - } +bool Environment::isExperimentalBuild() { return _experimental; } - std::vector& Environment::capabilities() { - return _capabilities; - } +sd::DataType Environment::defaultFloatDataType() { return _dataType.load(); } - void Environment::setDefaultFloatDataType(sd::DataType dtype) { - if (dtype != sd::DataType::FLOAT32 && dtype != sd::DataType::DOUBLE && dtype != sd::DataType::FLOAT8 && dtype != sd::DataType::HALF) - throw std::runtime_error("Default Float data type must be one of [FLOAT8, FLOAT16, FLOAT32, DOUBLE]"); +std::vector &Environment::capabilities() { return _capabilities; } - _dataType.store(dtype); - } +void Environment::setDefaultFloatDataType(sd::DataType dtype) { + if (dtype != sd::DataType::FLOAT32 && dtype != sd::DataType::DOUBLE && + dtype != sd::DataType::FLOAT8 && dtype != sd::DataType::HALF) + throw std::runtime_error( + "Default Float data type must be one of [FLOAT8, FLOAT16, FLOAT32, " + "DOUBLE]"); - void Environment::setVerbose(bool reallyVerbose) { - _verbose = reallyVerbose; - } + _dataType.store(dtype); +} - bool Environment::isDebug() { - return _debug.load(); - } +void Environment::setVerbose(bool reallyVerbose) { _verbose = reallyVerbose; } - bool Environment::isProfiling() { - return _profile.load(); - } +bool Environment::isDebug() { return _debug.load(); } - bool Environment::isDetectingLeaks() { - return _leaks.load(); - } +bool Environment::isProfiling() { return _profile.load(); } - void Environment::setLeaksDetector(bool reallyDetect) { - _leaks.store(reallyDetect); - } +bool Environment::isDetectingLeaks() { return _leaks.load(); } - void Environment::setProfiling(bool reallyProfile) { - _profile.store(reallyProfile); - } +void Environment::setLeaksDetector(bool reallyDetect) { + _leaks.store(reallyDetect); +} - bool Environment::isDebugAndVerbose() { - return this->isDebug() && this->isVerbose(); - } +void Environment::setProfiling(bool reallyProfile) { + _profile.store(reallyProfile); +} - void Environment::setDebug(bool reallyDebug) { - _debug = reallyDebug; - } +bool Environment::isDebugAndVerbose() { + return this->isDebug() && this->isVerbose(); +} - int Environment::tadThreshold() { - return _tadThreshold.load(); - } +void Environment::setDebug(bool reallyDebug) { _debug = reallyDebug; } - void Environment::setTadThreshold(int threshold) { - _tadThreshold = threshold; - } +int Environment::tadThreshold() { return _tadThreshold.load(); } - int Environment::elementwiseThreshold() { - return _elementThreshold.load(); - } +void Environment::setTadThreshold(int threshold) { _tadThreshold = threshold; } - void Environment::setElementwiseThreshold(int threshold) { - _elementThreshold = threshold; - } +int Environment::elementwiseThreshold() { return _elementThreshold.load(); } - int Environment::maxThreads() { - return _maxThreads.load(); - } +void Environment::setElementwiseThreshold(int threshold) { + _elementThreshold = threshold; +} - int Environment::maxMasterThreads() { - return _maxMasterThreads.load(); - } +int Environment::maxThreads() { return _maxThreads.load(); } - void Environment::setMaxThreads(int max) { - // FIXME: not possible at this moment, since maxThreads is limited by number of threads in pool. however we can allocate more threads if we want - //_maxThreads.store(max); - } +int Environment::maxMasterThreads() { return _maxMasterThreads.load(); } - void Environment::setMaxMasterThreads(int max) { - if (max > maxThreads()) { - max = maxThreads(); - } +void Environment::setMaxThreads(int max) { + // FIXME: not possible at this moment, since maxThreads is limited by number + // of threads in pool. however we can allocate more threads if we want + //_maxThreads.store(max); +} - if (max < 1) - return; +void Environment::setMaxMasterThreads(int max) { + if (max > maxThreads()) { + max = maxThreads(); + } - _maxMasterThreads = max; - } + if (max < 1) return; - bool Environment::precisionBoostAllowed() { - return _precBoost.load(); - } + _maxMasterThreads = max; +} - void Environment::allowPrecisionBoost(bool reallyAllow) { - _precBoost.store(reallyAllow); - } +bool Environment::precisionBoostAllowed() { return _precBoost.load(); } - bool Environment::isCPU() { +void Environment::allowPrecisionBoost(bool reallyAllow) { + _precBoost.store(reallyAllow); +} + +bool Environment::isCPU() { #ifdef __CUDABLAS__ - return false; + return false; #else - return true; + return true; #endif - } +} - int Environment::blasMajorVersion(){ - return _blasMajorVersion; - } +int Environment::blasMajorVersion() { return _blasMajorVersion; } - int Environment::blasMinorVersion(){ - return _blasMinorVersion; - } +int Environment::blasMinorVersion() { return _blasMinorVersion; } - int Environment::blasPatchVersion(){ - return _blasPatchVersion; - } +int Environment::blasPatchVersion() { return _blasPatchVersion; } - bool Environment::helpersAllowed() { - return _allowHelpers.load(); - } +bool Environment::helpersAllowed() { return _allowHelpers.load(); } - void Environment::allowHelpers(bool reallyAllow) { - _allowHelpers.store(reallyAllow); - } +void Environment::allowHelpers(bool reallyAllow) { + _allowHelpers.store(reallyAllow); +} - void Environment::setGroupLimit(int group, Nd4jLong numBytes) { - sd::memory::MemoryCounter::getInstance().setGroupLimit((sd::memory::MemoryType) group, numBytes); - } +void Environment::setGroupLimit(int group, Nd4jLong numBytes) { + sd::memory::MemoryCounter::getInstance().setGroupLimit( + (sd::memory::MemoryType)group, numBytes); +} - void Environment::setDeviceLimit(int deviceId, Nd4jLong numBytes) { - sd::memory::MemoryCounter::getInstance().setDeviceLimit(deviceId, numBytes); - } +void Environment::setDeviceLimit(int deviceId, Nd4jLong numBytes) { + sd::memory::MemoryCounter::getInstance().setDeviceLimit(deviceId, numBytes); +} - Nd4jLong Environment::getGroupLimit(int group) { - return sd::memory::MemoryCounter::getInstance().groupLimit((sd::memory::MemoryType) group); - } +Nd4jLong Environment::getGroupLimit(int group) { + return sd::memory::MemoryCounter::getInstance().groupLimit( + (sd::memory::MemoryType)group); +} - Nd4jLong Environment::getDeviceLimit(int deviceId) { - return sd::memory::MemoryCounter::getInstance().deviceLimit(deviceId); - } +Nd4jLong Environment::getDeviceLimit(int deviceId) { + return sd::memory::MemoryCounter::getInstance().deviceLimit(deviceId); +} - Nd4jLong Environment::getGroupCounter(int group) { - return sd::memory::MemoryCounter::getInstance().allocatedGroup((sd::memory::MemoryType) group); - } +Nd4jLong Environment::getGroupCounter(int group) { + return sd::memory::MemoryCounter::getInstance().allocatedGroup( + (sd::memory::MemoryType)group); +} - Nd4jLong Environment::getDeviceCounter(int deviceId) { - return sd::memory::MemoryCounter::getInstance().allocatedDevice(deviceId); - } +Nd4jLong Environment::getDeviceCounter(int deviceId) { + return sd::memory::MemoryCounter::getInstance().allocatedDevice(deviceId); +} - uint64_t Environment::maxPrimaryMemory() { - return _maxTotalPrimaryMemory.load(); - } +uint64_t Environment::maxPrimaryMemory() { + return _maxTotalPrimaryMemory.load(); +} - uint64_t Environment::maxSpecialMemory() { - return _maxTotalSpecialMemory.load(); - } +uint64_t Environment::maxSpecialMemory() { + return _maxTotalSpecialMemory.load(); } + + + +} // namespace sd diff --git a/libnd4j/include/legacy/impl/cnpy.cpp b/libnd4j/include/legacy/impl/cnpy.cpp index ee4fa36b0f81..49595563ae79 100644 --- a/libnd4j/include/legacy/impl/cnpy.cpp +++ b/libnd4j/include/legacy/impl/cnpy.cpp @@ -22,25 +22,25 @@ * THE SOFTWARE. ******************************************************************************/ -//Copyright (C) 2011 Carl Rogers -//Released under MIT License -//license available in LICENSE file, or at http://www.opensource.org/licenses/mit-license.php +// Copyright (C) 2011 Carl Rogers +// Released under MIT License +// license available in LICENSE file, or at +// http://www.opensource.org/licenses/mit-license.php -#include -#include #include +#include #include - +#include /** * * @return */ char cnpy::BigEndianTest() { - unsigned char x[] = {1,0}; - short y = *(short*) x; - return y == 1 ? '<' : '>'; + unsigned char x[] = {1, 0}; + short y = *(short *)x; + return y == 1 ? '<' : '>'; } /** @@ -49,121 +49,145 @@ char cnpy::BigEndianTest() { * @return */ char cnpy::mapType(const std::type_info &t) { - if(t == typeid(float) ) return 'f'; - if(t == typeid(double) ) return 'f'; - if(t == typeid(long double) ) return 'f'; - - if(t == typeid(int) ) return 'i'; - if(t == typeid(char) ) return 'i'; - if(t == typeid(short) ) return 'i'; - if(t == typeid(long) ) return 'i'; - if(t == typeid(long long) ) return 'i'; - - if(t == typeid(unsigned char) ) return 'u'; - if(t == typeid(unsigned short) ) return 'u'; - if(t == typeid(unsigned long) ) return 'u'; - if(t == typeid(unsigned long long) ) return 'u'; - if(t == typeid(unsigned int) ) return 'u'; - - if(t == typeid(bool) ) return 'b'; - - if(t == typeid(std::complex) ) return 'c'; - if(t == typeid(std::complex) ) return 'c'; - if(t == typeid(std::complex) ) return 'c'; - - else return '?'; + if (t == typeid(float)) return 'f'; + if (t == typeid(double)) return 'f'; + if (t == typeid(long double)) return 'f'; + + if (t == typeid(int)) return 'i'; + if (t == typeid(char)) return 'i'; + if (t == typeid(short)) return 'i'; + if (t == typeid(long)) return 'i'; + if (t == typeid(long long)) return 'i'; + + if (t == typeid(unsigned char)) return 'u'; + if (t == typeid(unsigned short)) return 'u'; + if (t == typeid(unsigned long)) return 'u'; + if (t == typeid(unsigned long long)) return 'u'; + if (t == typeid(unsigned int)) return 'u'; + + if (t == typeid(bool)) return 'b'; + + if (t == typeid(std::complex)) return 'c'; + if (t == typeid(std::complex)) return 'c'; + if (t == typeid(std::complex)) + return 'c'; + + else + return '?'; } template char cnpy::mapType() { - if(std::is_same::value) return 'f'; - if(std::is_same::value) return 'f'; - if(std::is_same::value) return 'f'; - if(std::is_same::value) return 'f'; - - if(std::is_same::value) return 'i'; - if(std::is_same::value) return 'i'; - if(std::is_same::value) return 'i'; - if(std::is_same::value) return 'i'; - if(std::is_same::value) return 'i'; - if(std::is_same::value) return 'i'; - if(std::is_same::value) return 'i'; - - if(std::is_same::value) return 'u'; - if(std::is_same::value) return 'u'; - if(std::is_same::value) return 'u'; - if(std::is_same::value) return 'u'; - if(std::is_same::value) return 'u'; - - if(std::is_same::value) return 'b'; - - if(std::is_same, T>::value) return 'c'; - if(std::is_same, T>::value) return 'c'; - if(std::is_same, T>::value) return 'c'; - - else return '?'; + if (std::is_same::value) return 'f'; + if (std::is_same::value) return 'f'; + if (std::is_same::value) return 'f'; + if (std::is_same::value) return 'f'; + + if (std::is_same::value) return 'i'; + if (std::is_same::value) return 'i'; + if (std::is_same::value) return 'i'; + if (std::is_same::value) return 'i'; + if (std::is_same::value) return 'i'; + if (std::is_same::value) return 'i'; + if (std::is_same::value) return 'i'; + + if (std::is_same::value) return 'u'; + if (std::is_same::value) return 'u'; + if (std::is_same::value) return 'u'; + if (std::is_same::value) return 'u'; + if (std::is_same::value) return 'u'; + + if (std::is_same::value) return 'b'; + + if (std::is_same, T>::value) return 'c'; + if (std::is_same, T>::value) return 'c'; + if (std::is_same, T>::value) + return 'c'; + + else + return '?'; } sd::DataType cnpy::dataTypeFromHeader(char *data) { - - // indices for type & data size - const int st = 10; - const int ti = 22; - const int si = 23; - - // read first char to make sure it looks like a header - if (data == nullptr || data[st] != '{') - throw std::runtime_error("cnpy::dataTypeFromHeader() - provided pointer doesn't look like a pointer to numpy header"); - - const auto t = data[ti]; - const auto s = data[si]; - - switch (t) { - case 'b': - return sd::DataType::BOOL; - case 'i': - switch (s) { - case '1': return sd::DataType::INT8; - case '2': return sd::DataType::INT16; - case '4': return sd::DataType::INT32; - case '8': return sd::DataType::INT64; - default: - throw std::runtime_error("Only data sizes of [1, 2, 4, 8] are supported for Integer data types import"); - } - case 'f': - switch (s) { - case '1': return sd::DataType::FLOAT8; - case '2': return sd::DataType::HALF; - case '4': return sd::DataType::FLOAT32; - case '8': return sd::DataType::DOUBLE; - default: - throw std::runtime_error("Only data sizes of [1, 2, 4, 8] are supported for Float data types import"); - } - case 'u': - switch (s) { - case '1': return sd::DataType::UINT8; - case '2': return sd::DataType::UINT16; - case '4': return sd::DataType::UINT32; - case '8': return sd::DataType::UINT64; - default: - throw std::runtime_error("Only data sizes of [1, 2, 4, 8] are supported for Unsigned data types import"); - } - case 'c': - throw std::runtime_error("Import of complex data types isn't supported yet"); + // indices for type & data size + const int st = 10; + const int ti = 22; + const int si = 23; + + // read first char to make sure it looks like a header + if (data == nullptr || data[st] != '{') + throw std::runtime_error( + "cnpy::dataTypeFromHeader() - provided pointer doesn't look like a " + "pointer to numpy header"); + + const auto t = data[ti]; + const auto s = data[si]; + + switch (t) { + case 'b': + return sd::DataType::BOOL; + case 'i': + switch (s) { + case '1': + return sd::DataType::INT8; + case '2': + return sd::DataType::INT16; + case '4': + return sd::DataType::INT32; + case '8': + return sd::DataType::INT64; default: - throw std::runtime_error("Unknown type marker"); - } + throw std::runtime_error( + "Only data sizes of [1, 2, 4, 8] are supported for Integer data " + "types import"); + } + case 'f': + switch (s) { + case '1': + return sd::DataType::FLOAT8; + case '2': + return sd::DataType::HALF; + case '4': + return sd::DataType::FLOAT32; + case '8': + return sd::DataType::DOUBLE; + default: + throw std::runtime_error( + "Only data sizes of [1, 2, 4, 8] are supported for Float data " + "types import"); + } + case 'u': + switch (s) { + case '1': + return sd::DataType::UINT8; + case '2': + return sd::DataType::UINT16; + case '4': + return sd::DataType::UINT32; + case '8': + return sd::DataType::UINT64; + default: + throw std::runtime_error( + "Only data sizes of [1, 2, 4, 8] are supported for Unsigned data " + "types import"); + } + case 'c': + throw std::runtime_error( + "Import of complex data types isn't supported yet"); + default: + throw std::runtime_error("Unknown type marker"); + } } template -std::vector& operator+=(std::vector& lhs, const T rhs) { - //write in little endian - for(char byte = 0; byte < sizeof(T); byte++) { - char val = *((char*)&rhs+byte); - lhs.push_back(val); - } - - return lhs; +std::vector &operator+=(std::vector &lhs, const T rhs) { + // write in little endian + for (char byte = 0; byte < sizeof(T); byte++) { + char val = *((char *)&rhs + byte); + lhs.push_back(val); + } + + return lhs; } /** @@ -172,10 +196,10 @@ std::vector& operator+=(std::vector& lhs, const T rhs) { * @param rhs * @return */ -template<> -std::vector& operator+=(std::vector& lhs, const std::string rhs) { - lhs.insert(lhs.end(),rhs.begin(),rhs.end()); - return lhs; +template <> +std::vector &operator+=(std::vector &lhs, const std::string rhs) { + lhs.insert(lhs.end(), rhs.begin(), rhs.end()); + return lhs; } /** @@ -184,15 +208,15 @@ std::vector& operator+=(std::vector& lhs, const std::string rhs) { * @param rhs * @return */ -template<> -std::vector& operator+=(std::vector& lhs, const char* rhs) { - //write in little endian - size_t len = strlen(rhs); - lhs.reserve(len); - for(size_t byte = 0; byte < len; byte++) { - lhs.push_back(rhs[byte]); - } - return lhs; +template <> +std::vector &operator+=(std::vector &lhs, const char *rhs) { + // write in little endian + size_t len = strlen(rhs); + lhs.reserve(len); + for (size_t byte = 0; byte < len; byte++) { + lhs.push_back(rhs[byte]); + } + return lhs; } /** @@ -200,87 +224,80 @@ std::vector& operator+=(std::vector& lhs, const char* rhs) { * @param path * @return */ -char* cnpy::loadFile(const char *path) { - char* buffer = 0; - long length; - FILE * f = fopen (path, "rb"); //was "rb" - - if (f) { - fseek (f, 0, SEEK_END); - length = ftell (f); - fseek (f, 0, SEEK_SET); - buffer = (char*) malloc ((length+ 1) * sizeof(char)); - - // just getting rid of compiler warning - Nd4jLong fps = 0; - - if (buffer) { - fps += fread (buffer, sizeof(char), length, f); - } - - fclose (f); +char *cnpy::loadFile(const char *path) { + char *buffer = 0; + long length; + FILE *f = fopen(path, "rb"); // was "rb" + + if (f) { + fseek(f, 0, SEEK_END); + length = ftell(f); + fseek(f, 0, SEEK_SET); + buffer = (char *)malloc((length + 1) * sizeof(char)); + + // just getting rid of compiler warning + Nd4jLong fps = 0; + + if (buffer) { + fps += fread(buffer, sizeof(char), length, f); } - buffer[length] = '\0'; - return buffer; -} + fclose(f); + } + buffer[length] = '\0'; + return buffer; +} /** -* Parse the numpy header from -* the given file -* based on the pointers passed in -* @param fp the file to parse from -* @param wordSize the size of -* the individual elements -* @param shape -* @param ndims -* @param fortranOrder -*/ -void cnpy::parseNpyHeaderStr(std::string header, - unsigned int &wordSize, - unsigned int *&shape, - unsigned int &ndims, + * Parse the numpy header from + * the given file + * based on the pointers passed in + * @param fp the file to parse from + * @param wordSize the size of + * the individual elements + * @param shape + * @param ndims + * @param fortranOrder + */ +void cnpy::parseNpyHeaderStr(std::string header, unsigned int &wordSize, + unsigned int *&shape, unsigned int &ndims, bool &fortranOrder) { - - - int loc1, loc2; - - - - //fortran order - loc1 = header.find("fortran_order") + 16; - fortranOrder = (header.substr(loc1,5) == "True" ? true : false); - //shape - loc1 = header.find("("); - loc2 = header.find(")"); - std::string str_shape = header.substr(loc1 + 1,loc2 - loc1 - 1); - if(str_shape[str_shape.size() - 1] == ',') ndims = 1; - else ndims = std::count(str_shape.begin(),str_shape.end(),',')+1; - - shape = new unsigned int[ndims]; - for(unsigned int i = 0; i < ndims; i++) { - loc1 = str_shape.find(","); - shape[i] = atoi(str_shape.substr(0,loc1).c_str()); - str_shape = str_shape.substr(loc1 + 1); - } - - - - //endian, word size, data type - //byte order code | stands for not applicable. - //not sure when this applies except for byte array - loc1 = header.find("descr") + 9; - bool littleEndian = (header[loc1] == '<' || header[loc1] == '|' ? true : false); - assert(littleEndian); - - //char type = header[loc1+1]; - //assert(type == map_type(T)); - - std::string str_ws = header.substr(loc1 + 2); - loc2 = str_ws.find("'"); - wordSize = atoi(str_ws.substr(0,loc2).c_str()); - + int loc1, loc2; + + // fortran order + loc1 = header.find("fortran_order") + 16; + fortranOrder = (header.substr(loc1, 5) == "True" ? true : false); + // shape + loc1 = header.find("("); + loc2 = header.find(")"); + std::string str_shape = header.substr(loc1 + 1, loc2 - loc1 - 1); + if (str_shape[str_shape.size() - 1] == ',') + ndims = 1; + else + ndims = std::count(str_shape.begin(), str_shape.end(), ',') + 1; + + shape = new unsigned int[ndims]; + for (unsigned int i = 0; i < ndims; i++) { + loc1 = str_shape.find(","); + shape[i] = atoi(str_shape.substr(0, loc1).c_str()); + str_shape = str_shape.substr(loc1 + 1); + } + + // endian, word size, data type + // byte order code | stands for not applicable. + // not sure when this applies except for byte array + loc1 = header.find("descr") + 9; + bool littleEndian = + (header[loc1] == '<' || header[loc1] == '|' ? true : false); + assert(littleEndian); + + // char type = header[loc1+1]; + // assert(type == map_type(T)); + + std::string str_ws = header.substr(loc1 + 2); + loc2 = str_ws.find("'"); + wordSize = atoi(str_ws.substr(0, loc2).c_str()); } /** @@ -294,26 +311,17 @@ void cnpy::parseNpyHeaderStr(std::string header, * @param ndims the number of dimensions for the array * @param fortranOrder */ -void cnpy::parseNpyHeader(FILE *fp, - unsigned int &wordSize, - unsigned int *&shape, - unsigned int &ndims, +void cnpy::parseNpyHeader(FILE *fp, unsigned int &wordSize, + unsigned int *&shape, unsigned int &ndims, bool &fortranOrder) { - char buffer[256]; - size_t res = fread(buffer,sizeof(char),11,fp); - if(res != 11) - throw std::runtime_error("parse_npy_header: failed fread"); - std::string header = fgets(buffer,256,fp); - assert(header[header.size() - 1] == '\n'); - cnpy::parseNpyHeaderStr(header, - wordSize, - shape, - ndims, - fortranOrder); + char buffer[256]; + size_t res = fread(buffer, sizeof(char), 11, fp); + if (res != 11) throw std::runtime_error("parse_npy_header: failed fread"); + std::string header = fgets(buffer, 256, fp); + assert(header[header.size() - 1] == '\n'); + cnpy::parseNpyHeaderStr(header, wordSize, shape, ndims, fortranOrder); } - - /** * * @param fp @@ -321,30 +329,27 @@ void cnpy::parseNpyHeader(FILE *fp, * @param global_header_size * @param global_header_offset */ -void cnpy::parseZipFooter(FILE* fp, - unsigned short& nrecs, - unsigned int& global_header_size, - unsigned int& global_header_offset) { - - std::vector footer(22); - fseek(fp, -22, SEEK_END); - size_t res = fread(&footer[0],sizeof(char),22,fp); - if(res != 22) - throw std::runtime_error("parse_zip_footer: failed fread"); - - unsigned short disk_no, disk_start, nrecs_on_disk, comment_len; - disk_no = *(unsigned short*) &footer[4]; - disk_start = *(unsigned short*) &footer[6]; - nrecs_on_disk = *(unsigned short*) &footer[8]; - nrecs = *(unsigned short*) &footer[10]; - global_header_size = *(unsigned int*) &footer[12]; - global_header_offset = *(unsigned int*) &footer[16]; - comment_len = *(unsigned short*) &footer[20]; - - assert(disk_no == 0); - assert(disk_start == 0); - assert(nrecs_on_disk == nrecs); - assert(comment_len == 0); +void cnpy::parseZipFooter(FILE *fp, unsigned short &nrecs, + unsigned int &global_header_size, + unsigned int &global_header_offset) { + std::vector footer(22); + fseek(fp, -22, SEEK_END); + size_t res = fread(&footer[0], sizeof(char), 22, fp); + if (res != 22) throw std::runtime_error("parse_zip_footer: failed fread"); + + unsigned short disk_no, disk_start, nrecs_on_disk, comment_len; + disk_no = *(unsigned short *)&footer[4]; + disk_start = *(unsigned short *)&footer[6]; + nrecs_on_disk = *(unsigned short *)&footer[8]; + nrecs = *(unsigned short *)&footer[10]; + global_header_size = *(unsigned int *)&footer[12]; + global_header_offset = *(unsigned int *)&footer[16]; + comment_len = *(unsigned short *)&footer[20]; + + assert(disk_no == 0); + assert(disk_start == 0); + assert(nrecs_on_disk == nrecs); + assert(comment_len == 0); } /** @@ -353,88 +358,85 @@ void cnpy::parseZipFooter(FILE* fp, * @return the loaded array */ cnpy::NpyArray cnpy::loadNpyFromFile(FILE *fp) { - unsigned int *shape; - unsigned int ndims, wordSize; - bool fortranOrder; - cnpy::parseNpyHeader(fp,wordSize,shape,ndims,fortranOrder); - unsigned long long size = 1; //long long so no overflow when multiplying by word_size - for(unsigned int i = 0;i < ndims;i++) size *= shape[i]; - - cnpy::NpyArray arr; - arr.wordSize = wordSize; - arr.shape = std::vector(shape,shape + ndims); - arr.data = new char[size * wordSize]; - arr.fortranOrder = fortranOrder; - size_t nread = fread(arr.data,wordSize,size,fp); - if(nread != size) - throw std::runtime_error("load_the_npy_file: failed fread"); - return arr; + unsigned int *shape; + unsigned int ndims, wordSize; + bool fortranOrder; + cnpy::parseNpyHeader(fp, wordSize, shape, ndims, fortranOrder); + unsigned long long size = + 1; // long long so no overflow when multiplying by word_size + for (unsigned int i = 0; i < ndims; i++) size *= shape[i]; + + cnpy::NpyArray arr; + arr.wordSize = wordSize; + arr.shape = std::vector(shape, shape + ndims); + arr.data = new char[size * wordSize]; + arr.fortranOrder = fortranOrder; + size_t nread = fread(arr.data, wordSize, size, fp); + if (nread != size) + throw std::runtime_error("load_the_npy_file: failed fread"); + return arr; } - - /** - * - * @param data - * @return - */ -cnpy::NpyArray cnpy::loadNpyFromPointer(char *data) { - //move the pointer forward by 11 imitating - //the seek in loading directly from a file - return cnpy::loadNpyFromHeader(data); + * + * @param data + * @return + */ +cnpy::NpyArray cnpy::loadNpyFromPointer(char *data) { + // move the pointer forward by 11 imitating + // the seek in loading directly from a file + return cnpy::loadNpyFromHeader(data); } /** -* -* @param data -* @return -*/ + * + * @param data + * @return + */ cnpy::NpyArray cnpy::loadNpyFromHeader(char *data) { - // check for magic header - if (data == nullptr) - throw std::runtime_error("NULL pointer doesn't look like a NumPy header"); - - if (data[0] == (char) 0x93) { - std::vector exp({(char) 0x93, 'N', 'U', 'M', 'P', 'Y', (char) 0x01}); - std::vector hdr(data, data+7); - if (hdr != exp) - throw std::runtime_error("Pointer doesn't look like a NumPy header"); - } else - throw std::runtime_error("Pointer doesn't look like a NumPy header"); - - //move passed magic - data += 11; - unsigned int *shape; - unsigned int ndims, wordSize; - bool fortranOrder; - cnpy::parseNpyHeaderStr(std::string(data), - wordSize, - shape, - ndims, - fortranOrder); - //the "real" data starts after the \n - char currChar = data[0]; - int count = 0; - while(currChar != '\n') { - data++; - currChar = data[0]; - count++; - } - - //move pass the \n + // check for magic header + if (data == nullptr) + throw std::runtime_error("NULL pointer doesn't look like a NumPy header"); + + if (data[0] == (char)0x93) { + std::vector exp({(char)0x93, 'N', 'U', 'M', 'P', 'Y', (char)0x01}); + std::vector hdr(data, data + 7); + if (hdr != exp) + throw std::runtime_error("Pointer doesn't look like a NumPy header"); + } else + throw std::runtime_error("Pointer doesn't look like a NumPy header"); + + // move passed magic + data += 11; + unsigned int *shape; + unsigned int ndims, wordSize; + bool fortranOrder; + cnpy::parseNpyHeaderStr(std::string(data), wordSize, shape, ndims, + fortranOrder); + // the "real" data starts after the \n + char currChar = data[0]; + int count = 0; + while (currChar != '\n') { data++; + currChar = data[0]; count++; - - unsigned long long size = 1; //long long so no overflow when multiplying by word_size - for(unsigned int i = 0; i < ndims; i++) size *= shape[i]; - char *cursor = data; - cnpy::NpyArray arr; - arr.wordSize = wordSize; - arr.shape = std::vector(shape,shape + ndims); - delete[] shape; - arr.data = cursor; - arr.fortranOrder = fortranOrder; - return arr; + } + + // move pass the \n + data++; + count++; + + unsigned long long size = + 1; // long long so no overflow when multiplying by word_size + for (unsigned int i = 0; i < ndims; i++) size *= shape[i]; + char *cursor = data; + cnpy::NpyArray arr; + arr.wordSize = wordSize; + arr.shape = std::vector(shape, shape + ndims); + delete[] shape; + arr.data = cursor; + arr.fortranOrder = fortranOrder; + return arr; } /** @@ -443,41 +445,39 @@ cnpy::NpyArray cnpy::loadNpyFromHeader(char *data) { * @return the arrays */ -cnpy::npz_t cnpy::npzLoad(FILE* fp){ - cnpy::npz_t arrays; - - while(1) { - std::vector local_header(30); - size_t headerres = fread(&local_header[0],sizeof(char),30,fp); - if(headerres != 30) - throw std::runtime_error("npz_load: failed fread"); - - //if we've reached the global header, stop reading - if(local_header[2] != 0x03 || local_header[3] != 0x04) break; - - //read in the variable name - unsigned short name_len = *(unsigned short*) &local_header[26]; - std::string varname(name_len,' '); - size_t vname_res = fread(&varname[0],sizeof(char),name_len,fp); - if(vname_res != name_len) - throw std::runtime_error("npz_load: failed fread"); - - //erase the lagging .npy - for (int e = 0; e < 4; e++) - varname.pop_back(); - - //read in the extra field - unsigned short extra_field_len = *(unsigned short*) &local_header[28]; - if(extra_field_len > 0) { - std::vector buff(extra_field_len); - size_t efield_res = fread(&buff[0],sizeof(char),extra_field_len,fp); - if(efield_res != extra_field_len) - throw std::runtime_error("npz_load: failed fread"); - } - - arrays[varname] = loadNpyFromFile(fp); +cnpy::npz_t cnpy::npzLoad(FILE *fp) { + cnpy::npz_t arrays; + + while (1) { + std::vector local_header(30); + size_t headerres = fread(&local_header[0], sizeof(char), 30, fp); + if (headerres != 30) throw std::runtime_error("npz_load: failed fread"); + + // if we've reached the global header, stop reading + if (local_header[2] != 0x03 || local_header[3] != 0x04) break; + + // read in the variable name + unsigned short name_len = *(unsigned short *)&local_header[26]; + std::string varname(name_len, ' '); + size_t vname_res = fread(&varname[0], sizeof(char), name_len, fp); + if (vname_res != name_len) + throw std::runtime_error("npz_load: failed fread"); + + // erase the lagging .npy + for (int e = 0; e < 4; e++) varname.pop_back(); + + // read in the extra field + unsigned short extra_field_len = *(unsigned short *)&local_header[28]; + if (extra_field_len > 0) { + std::vector buff(extra_field_len); + size_t efield_res = fread(&buff[0], sizeof(char), extra_field_len, fp); + if (efield_res != extra_field_len) + throw std::runtime_error("npz_load: failed fread"); } - return arrays; + + arrays[varname] = loadNpyFromFile(fp); + } + return arrays; } /** @@ -486,45 +486,42 @@ cnpy::npz_t cnpy::npzLoad(FILE* fp){ * @return the arrays */ cnpy::npz_t cnpy::npzLoad(std::string fname) { - FILE* fp = fopen(fname.c_str(),"rb"); - - if(!fp) printf("npz_load: Error! Unable to open file %s!\n",fname.c_str()); - assert(fp); - cnpy::npz_t arrays; - while(1) { - std::vector local_header(30); - size_t headerres = fread(&local_header[0],sizeof(char),30,fp); - if(headerres != 30) - throw std::runtime_error("npz_load: failed fread"); - - //if we've reached the global header, stop reading - if(local_header[2] != 0x03 || local_header[3] != 0x04) break; - - //read in the variable name - unsigned short name_len = *(unsigned short*) &local_header[26]; - std::string varname(name_len,' '); - size_t vname_res = fread(&varname[0],sizeof(char),name_len,fp); - if(vname_res != name_len) - throw std::runtime_error("npz_load: failed fread"); - - //erase the lagging .npy - for (int e = 0; e < 4; e++) - varname.pop_back(); - - //read in the extra field - unsigned short extra_field_len = *(unsigned short*) &local_header[28]; - if(extra_field_len > 0) { - std::vector buff(extra_field_len); - size_t efield_res = fread(&buff[0],sizeof(char),extra_field_len,fp); - if(efield_res != extra_field_len) - throw std::runtime_error("npz_load: failed fread"); - } - - arrays[varname] = loadNpyFromFile(fp); + FILE *fp = fopen(fname.c_str(), "rb"); + + if (!fp) printf("npz_load: Error! Unable to open file %s!\n", fname.c_str()); + assert(fp); + cnpy::npz_t arrays; + while (1) { + std::vector local_header(30); + size_t headerres = fread(&local_header[0], sizeof(char), 30, fp); + if (headerres != 30) throw std::runtime_error("npz_load: failed fread"); + + // if we've reached the global header, stop reading + if (local_header[2] != 0x03 || local_header[3] != 0x04) break; + + // read in the variable name + unsigned short name_len = *(unsigned short *)&local_header[26]; + std::string varname(name_len, ' '); + size_t vname_res = fread(&varname[0], sizeof(char), name_len, fp); + if (vname_res != name_len) + throw std::runtime_error("npz_load: failed fread"); + + // erase the lagging .npy + for (int e = 0; e < 4; e++) varname.pop_back(); + + // read in the extra field + unsigned short extra_field_len = *(unsigned short *)&local_header[28]; + if (extra_field_len > 0) { + std::vector buff(extra_field_len); + size_t efield_res = fread(&buff[0], sizeof(char), extra_field_len, fp); + if (efield_res != extra_field_len) + throw std::runtime_error("npz_load: failed fread"); } - fclose(fp); - return arrays; + arrays[varname] = loadNpyFromFile(fp); + } + fclose(fp); + return arrays; } /** @@ -534,149 +531,138 @@ cnpy::npz_t cnpy::npzLoad(std::string fname) { * @return */ cnpy::NpyArray cnpy::npzLoad(std::string fname, std::string varname) { - FILE *fp = fopen(fname.c_str(),"rb"); - - if(!fp) { - printf("npz_load: Error! Unable to open file %s!\n",fname.c_str()); + FILE *fp = fopen(fname.c_str(), "rb"); + + if (!fp) { + printf("npz_load: Error! Unable to open file %s!\n", fname.c_str()); + } + + while (1) { + std::vector local_header(30); + size_t header_res = fread(&local_header[0], sizeof(char), 30, fp); + if (header_res != 30) throw std::runtime_error("npz_load: failed fread"); + + // if we've reached the global header, stop reading + if (local_header[2] != 0x03 || local_header[3] != 0x04) break; + + // read in the variable name + unsigned short name_len = *(unsigned short *)&local_header[26]; + std::string vname(name_len, ' '); + size_t vname_res = fread(&vname[0], sizeof(char), name_len, fp); + if (vname_res != name_len) + throw std::runtime_error("npz_load: failed fread"); + + // erase the lagging .npy + for (int e = 0; e < 4; e++) varname.pop_back(); + + // read in the extra field + unsigned short extra_field_len = *(unsigned short *)&local_header[28]; + fseek(fp, extra_field_len, SEEK_CUR); // skip past the extra field + + if (vname == varname) { + NpyArray array = cnpy::loadNpyFromFile(fp); + fclose(fp); + return array; + } else { + // skip past the data + unsigned int size = *(unsigned int *)&local_header[22]; + fseek(fp, size, SEEK_CUR); } + } - while(1) { - std::vector local_header(30); - size_t header_res = fread(&local_header[0],sizeof(char),30,fp); - if(header_res != 30) - throw std::runtime_error("npz_load: failed fread"); - - //if we've reached the global header, stop reading - if(local_header[2] != 0x03 || local_header[3] != 0x04) break; - - //read in the variable name - unsigned short name_len = *(unsigned short*) &local_header[26]; - std::string vname(name_len,' '); - size_t vname_res = fread(&vname[0],sizeof(char),name_len,fp); - if(vname_res != name_len) - throw std::runtime_error("npz_load: failed fread"); - - //erase the lagging .npy - for (int e = 0; e < 4; e++) - varname.pop_back(); - - //read in the extra field - unsigned short extra_field_len = *(unsigned short*) &local_header[28]; - fseek(fp,extra_field_len,SEEK_CUR); //skip past the extra field - - if(vname == varname) { - NpyArray array = cnpy::loadNpyFromFile(fp); - fclose(fp); - return array; - } - else { - //skip past the data - unsigned int size = *(unsigned int*) &local_header[22]; - fseek(fp,size,SEEK_CUR); - } - } - - fclose(fp); - printf("npz_load: Error! Variable name %s not found in %s!\n",varname.c_str(),fname.c_str()); - throw std::runtime_error("Variable wasn't found in file"); + fclose(fp); + printf("npz_load: Error! Variable name %s not found in %s!\n", + varname.c_str(), fname.c_str()); + throw std::runtime_error("Variable wasn't found in file"); } - - - /** * Load a numpy array from the given file * @param fname the fully qualified path for the file * @return the NpArray for this file */ cnpy::NpyArray cnpy::npyLoad(std::string fname) { - FILE* fp = fopen(fname.c_str(), "rb"); + FILE *fp = fopen(fname.c_str(), "rb"); - if(!fp) { - printf("npy_load: Error! Unable to open file %s!\n",fname.c_str()); - } + if (!fp) { + printf("npy_load: Error! Unable to open file %s!\n", fname.c_str()); + } - NpyArray arr = cnpy::loadNpyFromFile(fp); + NpyArray arr = cnpy::loadNpyFromFile(fp); - fclose(fp); - return arr; + fclose(fp); + return arr; } - /** - * Save the numpy array - * @tparam T - * @param fname the file - * @param data the data for the ndarray - * @param shape the shape of the ndarray - * @param ndims the number of dimensions - * for the ndarray - * @param mode the mode for writing - */ -template -void cnpy::npy_save(std::string fname, - const T* data, - const unsigned int* shape, - const unsigned int ndims, - std::string mode) { - - FILE* fp = NULL; - - if(mode == "a") - fp = fopen(fname.c_str(),"r+b"); - - if(fp) { - //file exists. we need to append to it. read the header, modify the array size - unsigned int word_size, tmp_dims; - unsigned int* tmp_shape = 0; - bool fortran_order; - parseNpyHeader(fp, - word_size, - tmp_shape, - tmp_dims, - fortran_order); - - assert(!fortran_order); - - if(word_size != sizeof(T)) { - std::cout<<"libnpy error: " << fname<< " has word size " << word_size<<" but npy_save appending data sized " << sizeof(T) <<"\n"; - assert( word_size == sizeof(T) ); - } - - if(tmp_dims != ndims) { - std::cout<<"libnpy error: npy_save attempting to append misdimensioned data to "< header = createNpyHeader(data,tmp_shape,ndims); - fwrite(&header[0],sizeof(char),header.size(),fp); - fseek(fp,0,SEEK_END); - - delete[] tmp_shape; + * Save the numpy array + * @tparam T + * @param fname the file + * @param data the data for the ndarray + * @param shape the shape of the ndarray + * @param ndims the number of dimensions + * for the ndarray + * @param mode the mode for writing + */ +template +void cnpy::npy_save(std::string fname, const T *data, const unsigned int *shape, + const unsigned int ndims, std::string mode) { + FILE *fp = NULL; + + if (mode == "a") fp = fopen(fname.c_str(), "r+b"); + + if (fp) { + // file exists. we need to append to it. read the header, modify the array + // size + unsigned int word_size, tmp_dims; + unsigned int *tmp_shape = 0; + bool fortran_order; + parseNpyHeader(fp, word_size, tmp_shape, tmp_dims, fortran_order); + + assert(!fortran_order); + + if (word_size != sizeof(T)) { + std::cout << "libnpy error: " << fname << " has word size " << word_size + << " but npy_save appending data sized " << sizeof(T) << "\n"; + assert(word_size == sizeof(T)); } - else { - fp = fopen(fname.c_str(),"wb"); - std::vector header = createNpyHeader(data,shape,ndims); - fwrite(&header[0],sizeof(char),header.size(),fp); + + if (tmp_dims != ndims) { + std::cout << "libnpy error: npy_save attempting to append misdimensioned " + "data to " + << fname << "\n"; + assert(tmp_dims == ndims); } - unsigned long long nels = 1; - for(int i = 0;i < ndims;i++) nels *= shape[i]; + for (int i = 1; i < ndims; i++) { + if (shape[i] != tmp_shape[i]) { + std::cout + << "libnpy error: npy_save attempting to append misshaped data to " + << fname << "\n"; + assert(shape[i] == tmp_shape[i]); + } + } - fwrite(data,sizeof(T),nels,fp); - fclose(fp); -} + tmp_shape[0] += shape[0]; + fseek(fp, 0, SEEK_SET); + std::vector header = createNpyHeader(data, tmp_shape, ndims); + fwrite(&header[0], sizeof(char), header.size(), fp); + fseek(fp, 0, SEEK_END); + + delete[] tmp_shape; + } else { + fp = fopen(fname.c_str(), "wb"); + std::vector header = createNpyHeader(data, shape, ndims); + fwrite(&header[0], sizeof(char), header.size(), fp); + } + + unsigned long long nels = 1; + for (int i = 0; i < ndims; i++) nels *= shape[i]; + + fwrite(data, sizeof(T), nels, fp); + fclose(fp); +} /** * @@ -686,49 +672,58 @@ void cnpy::npy_save(std::string fname, * @param ndims * @return */ -template +template std::vector cnpy::createNpyHeader(const void *vdata, - const unsigned int *shape, - const unsigned int ndims, - unsigned int wordSize) { - - auto data = reinterpret_cast(vdata); - - std::vector dict; - dict += "{'descr': '"; - dict += sizeof(T) > 1 ? BigEndianTest() : '|'; - dict += mapType(); - dict += tostring(wordSize); - dict += "', 'fortran_order': False, 'shape': ("; - if (ndims > 0) { - dict += tostring(shape[0]); - for (int i = 1; i < ndims; i++) { - dict += ", "; - dict += tostring(shape[i]); - } - - if (ndims == 1) - dict += ","; + const unsigned int *shape, + const unsigned int ndims, + unsigned int wordSize) { + auto data = reinterpret_cast(vdata); + + std::vector dict; + dict += "{'descr': '"; + dict += sizeof(T) > 1 ? BigEndianTest() : '|'; + dict += mapType(); + dict += tostring(wordSize); + dict += "', 'fortran_order': False, 'shape': ("; + if (ndims > 0) { + dict += tostring(shape[0]); + for (int i = 1; i < ndims; i++) { + dict += ", "; + dict += tostring(shape[i]); } - // 0D case still requires close - dict += "), }"; - - //pad with spaces so that preamble+dict is modulo 16 bytes. preamble is 10 bytes. dict needs to end with \n - int remainder = 64 - (10 + dict.size()) % 64; - dict.insert(dict.end(),remainder,' '); - dict.back() = '\n'; - - std::vector header; - header += (char) 0x93; - header += "NUMPY"; - header += (char) 0x01; //major version of numpy format - header += (char) 0x00; //minor version of numpy format - header += (unsigned short) dict.size(); - header.insert(header.end(),dict.begin(),dict.end()); - - return header; + + if (ndims == 1) dict += ","; + } + // 0D case still requires close + dict += "), }"; + + // pad with spaces so that preamble+dict is modulo 16 bytes. preamble is 10 + // bytes. dict needs to end with \n + int remainder = 64 - (10 + dict.size()) % 64; + dict.insert(dict.end(), remainder, ' '); + dict.back() = '\n'; + + std::vector header; + header += (char)0x93; + header += "NUMPY"; + header += (char)0x01; // major version of numpy format + header += (char)0x00; // minor version of numpy format + header += (unsigned short)dict.size(); + header.insert(header.end(), dict.begin(), dict.end()); + + return header; } -BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT std::vector cnpy::createNpyHeader, (const void *data, const unsigned int *shape, const unsigned int ndims, unsigned int wordSize), LIBND4J_TYPES); -//template ND4J_EXPORT std::vector cnpy::createNpyHeader(const void *data, const unsigned int *shape, const unsigned int ndims, unsigned int wordSize); -template ND4J_EXPORT void cnpy::npy_save(std::string fname, const float* data, const unsigned int* shape, const unsigned int ndims, std::string mode); +BUILD_SINGLE_TEMPLATE( + template SD_EXPORT std::vector cnpy::createNpyHeader, + (const void *data, const unsigned int *shape, const unsigned int ndims, + unsigned int wordSize), + LIBND4J_TYPES); +// template SD_EXPORT std::vector cnpy::createNpyHeader(const void +// *data, const unsigned int *shape, const unsigned int ndims, unsigned int +// wordSize); +template SD_EXPORT void cnpy::npy_save(std::string fname, + const float *data, + const unsigned int *shape, + const unsigned int ndims, + std::string mode); diff --git a/libnd4j/include/loops/BroadcastPairwiseConverter.h b/libnd4j/include/loops/BroadcastPairwiseConverter.h index acb7e8d64035..fed97b469572 100644 --- a/libnd4j/include/loops/BroadcastPairwiseConverter.h +++ b/libnd4j/include/loops/BroadcastPairwiseConverter.h @@ -18,79 +18,127 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 09.04.2019 // -#ifndef DEV_TESTS_BROADCASTPAIRWISECONVERTER_H -#define DEV_TESTS_BROADCASTPAIRWISECONVERTER_H +#ifndef SD_BROADCASTPAIRWISECONVERTER_H +#define SD_BROADCASTPAIRWISECONVERTER_H #include +#include + #include namespace sd { ////////////////////////////////////////////////////////////////////////// inline pairwise::Ops fromBroadcastToPairwise(broadcast::Ops op) { - switch (op) { - case broadcast::Add: return pairwise::Add; - case broadcast::Subtract: return pairwise::Subtract; - case broadcast::Multiply: return pairwise::Multiply; - case broadcast::Divide: return pairwise::Divide; - case broadcast::ReverseDivide: return pairwise::ReverseDivide; - case broadcast::ReverseSubtract: return pairwise::ReverseSubtract; - case broadcast::CopyPws: return pairwise::CopyPws; - case broadcast::Pow: return pairwise::Pow; - case broadcast::MinPairwise: return pairwise::MinPairwise; - case broadcast::MaxPairwise: return pairwise::MaxPairwise; - case broadcast::AMinPairwise: return pairwise::AMinPairwise; - case broadcast::AMaxPairwise: return pairwise::AMaxPairwise; - case broadcast::SquaredSubtract: return pairwise::SquaredSubtract; - case broadcast::FloorMod: return pairwise::FloorMod; - case broadcast::FloorDiv: return pairwise::FloorDiv; - case broadcast::ReverseMod: return pairwise::ReverseMod; - case broadcast::SafeDivide: return pairwise::SafeDivide; - case broadcast::Mod: return pairwise::Mod; - case broadcast::TruncateDiv: return pairwise::TruncateDiv; - case broadcast::Atan2: return pairwise::Atan2; - case broadcast::LogicalOr: return pairwise::LogicalOr; - case broadcast::LogicalXor: return pairwise::LogicalXor; - case broadcast::LogicalNot: return pairwise::LogicalNot; - case broadcast::LogicalAnd: return pairwise::LogicalAnd; - case broadcast::PowDerivative: return pairwise::PowDerivative; - default: - throw std::runtime_error("fromBroadcastToPairwise: Not convertible operation"); - } + switch (op) { + case broadcast::Add: + return pairwise::Add; + case broadcast::Subtract: + return pairwise::Subtract; + case broadcast::Multiply: + return pairwise::Multiply; + case broadcast::Divide: + return pairwise::Divide; + case broadcast::ReverseDivide: + return pairwise::ReverseDivide; + case broadcast::ReverseSubtract: + return pairwise::ReverseSubtract; + case broadcast::CopyPws: + return pairwise::CopyPws; + case broadcast::Pow: + return pairwise::Pow; + case broadcast::MinPairwise: + return pairwise::MinPairwise; + case broadcast::MaxPairwise: + return pairwise::MaxPairwise; + case broadcast::AMinPairwise: + return pairwise::AMinPairwise; + case broadcast::AMaxPairwise: + return pairwise::AMaxPairwise; + case broadcast::SquaredSubtract: + return pairwise::SquaredSubtract; + case broadcast::FloorMod: + return pairwise::FloorMod; + case broadcast::FloorDiv: + return pairwise::FloorDiv; + case broadcast::ReverseMod: + return pairwise::ReverseMod; + case broadcast::SafeDivide: + return pairwise::SafeDivide; + case broadcast::Mod: + return pairwise::Mod; + case broadcast::TruncateDiv: + return pairwise::TruncateDiv; + case broadcast::Atan2: + return pairwise::Atan2; + case broadcast::LogicalOr: + return pairwise::LogicalOr; + case broadcast::LogicalXor: + return pairwise::LogicalXor; + case broadcast::LogicalNot: + return pairwise::LogicalNot; + case broadcast::LogicalAnd: + return pairwise::LogicalAnd; + case broadcast::PowDerivative: + return pairwise::PowDerivative; + default: + throw std::runtime_error( + "fromBroadcastToPairwise: Not convertible operation"); + } } ////////////////////////////////////////////////////////////////////////// inline pairwise::BoolOps fromBroadcastToPairwiseBool(broadcast::BoolOps op) { - switch (op) { - case broadcast::EqualTo: return pairwise::EqualTo; - case broadcast::GreaterThan: return pairwise::GreaterThan; - case broadcast::LessThan: return pairwise::LessThan; - case broadcast::Epsilon: return pairwise::Epsilon; - case broadcast::GreaterThanOrEqual: return pairwise::GreaterThanOrEqual; - case broadcast::LessThanOrEqual: return pairwise::LessThanOrEqual; - case broadcast::NotEqualTo: return pairwise::NotEqualTo; - case broadcast::And: return pairwise::And; - case broadcast::Or: return pairwise::Or; - case broadcast::Xor: return pairwise::Xor; - case broadcast::Not: return pairwise::Not; - default: - throw std::runtime_error("fromBroadcastToPairwiseBool: Not convertible operation"); - } + switch (op) { + case broadcast::EqualTo: + return pairwise::EqualTo; + case broadcast::GreaterThan: + return pairwise::GreaterThan; + case broadcast::LessThan: + return pairwise::LessThan; + case broadcast::Epsilon: + return pairwise::Epsilon; + case broadcast::GreaterThanOrEqual: + return pairwise::GreaterThanOrEqual; + case broadcast::LessThanOrEqual: + return pairwise::LessThanOrEqual; + case broadcast::NotEqualTo: + return pairwise::NotEqualTo; + case broadcast::And: + return pairwise::And; + case broadcast::Or: + return pairwise::Or; + case broadcast::Xor: + return pairwise::Xor; + case broadcast::Not: + return pairwise::Not; + default: + throw std::runtime_error( + "fromBroadcastToPairwiseBool: Not convertible operation"); + } } - inline pairwise::IntOps fromBroadcastToPairwiseInt(broadcast::IntOps op) { - switch (op) { - case broadcast::IntOps::IntAnd: return pairwise::IntOps::IntAnd; - case broadcast::IntOps::IntOr: return pairwise::IntOps::IntOr; - case broadcast::IntOps::IntXor: return pairwise::IntOps::IntXor; - case broadcast::IntOps::ShiftLeft: return pairwise::IntOps::ShiftLeft; - case broadcast::IntOps::ShiftRight: return pairwise::IntOps::ShiftRight; - case broadcast::IntOps::CyclicShiftLeft: return pairwise::IntOps::CyclicShiftLeft; - case broadcast::IntOps::CyclicShiftRight: return pairwise::IntOps::CyclicShiftRight; - default: - throw std::runtime_error("fromBroadcastToPairwiseInt: Not convertible operation"); - } - } +inline pairwise::IntOps fromBroadcastToPairwiseInt(broadcast::IntOps op) { + switch (op) { + case broadcast::IntOps::IntAnd: + return pairwise::IntOps::IntAnd; + case broadcast::IntOps::IntOr: + return pairwise::IntOps::IntOr; + case broadcast::IntOps::IntXor: + return pairwise::IntOps::IntXor; + case broadcast::IntOps::ShiftLeft: + return pairwise::IntOps::ShiftLeft; + case broadcast::IntOps::ShiftRight: + return pairwise::IntOps::ShiftRight; + case broadcast::IntOps::CyclicShiftLeft: + return pairwise::IntOps::CyclicShiftLeft; + case broadcast::IntOps::CyclicShiftRight: + return pairwise::IntOps::CyclicShiftRight; + default: + throw std::runtime_error( + "fromBroadcastToPairwiseInt: Not convertible operation"); + } } +} // namespace sd -#endif //DEV_TESTS_BROADCASTPAIRWISECONVERTER_H \ No newline at end of file +#endif // SD_BROADCASTPAIRWISECONVERTER_H \ No newline at end of file diff --git a/libnd4j/include/loops/BroadcastScalarConverter.h b/libnd4j/include/loops/BroadcastScalarConverter.h index f4d536f332ce..e7730f242672 100644 --- a/libnd4j/include/loops/BroadcastScalarConverter.h +++ b/libnd4j/include/loops/BroadcastScalarConverter.h @@ -17,42 +17,55 @@ /** * @author raver119@gmail.com */ -#ifndef DEV_TESTS_BROADCASTSCALARCONVERTER_H -#define DEV_TESTS_BROADCASTSCALARCONVERTER_H +#ifndef SD_BROADCASTSCALARCONVERTER_H +#define SD_BROADCASTSCALARCONVERTER_H #include #include + #include namespace sd { - inline bool isConvertibleToScalar(broadcast::Ops op) { - int opNum = (int) op; - - if (opNum <= 17) - return true; - - return false; - } - - inline scalar::Ops convertToScalar(broadcast::Ops op) { - switch (op) { - case broadcast::Add: return scalar::Add; - case broadcast::Subtract: return scalar::Subtract; - case broadcast::Multiply: return scalar::Multiply; - case broadcast::Divide: return scalar::Divide; - case broadcast::ReverseDivide: return scalar::ReverseDivide; - case broadcast::ReverseSubtract: return scalar::ReverseSubtract; - case broadcast::CopyPws: return scalar::CopyPws; - case broadcast::Pow: return scalar::Pow; - case broadcast::MinPairwise: return scalar::MinPairwise; - case broadcast::MaxPairwise: return scalar::MaxPairwise; - case broadcast::AMinPairwise: return scalar::AMinPairwise; - case broadcast::AMaxPairwise: return scalar::AMaxPairwise; - case broadcast::SquaredSubtract: return scalar::SquaredSubtract; - default: - throw std::runtime_error("Not convertible operation"); - } - } +inline bool isConvertibleToScalar(broadcast::Ops op) { + int opNum = (int)op; + + if (opNum <= 17) return true; + + return false; +} + +inline scalar::Ops convertToScalar(broadcast::Ops op) { + switch (op) { + case broadcast::Add: + return scalar::Add; + case broadcast::Subtract: + return scalar::Subtract; + case broadcast::Multiply: + return scalar::Multiply; + case broadcast::Divide: + return scalar::Divide; + case broadcast::ReverseDivide: + return scalar::ReverseDivide; + case broadcast::ReverseSubtract: + return scalar::ReverseSubtract; + case broadcast::CopyPws: + return scalar::CopyPws; + case broadcast::Pow: + return scalar::Pow; + case broadcast::MinPairwise: + return scalar::MinPairwise; + case broadcast::MaxPairwise: + return scalar::MaxPairwise; + case broadcast::AMinPairwise: + return scalar::AMinPairwise; + case broadcast::AMaxPairwise: + return scalar::AMaxPairwise; + case broadcast::SquaredSubtract: + return scalar::SquaredSubtract; + default: + throw std::runtime_error("Not convertible operation"); + } } +} // namespace sd -#endif //DEV_TESTS_BROADCASTSCALARCONVERTER_H +#endif // SD_BROADCASTSCALARCONVERTER_H diff --git a/libnd4j/include/loops/ReduceType.h b/libnd4j/include/loops/ReduceType.h index 501b7229e3ff..3b90585cf60a 100644 --- a/libnd4j/include/loops/ReduceType.h +++ b/libnd4j/include/loops/ReduceType.h @@ -18,19 +18,11 @@ * @author raver119@gmail.com */ -#ifndef DEV_TESTS_REDUCETYPE_H -#define DEV_TESTS_REDUCETYPE_H +#ifndef SD_REDUCETYPE_H +#define SD_REDUCETYPE_H namespace functions { - enum ReduceType { - SUM, - PRODUCT, - MAX, - MIN, - ASUM, - AMAX, - AMIN - }; +enum ReduceType { SUM, PRODUCT, MAX, MIN, ASUM, AMAX, AMIN }; } -#endif //DEV_TESTS_REDUCETYPE_H +#endif // SD_REDUCETYPE_H diff --git a/libnd4j/include/loops/broadcasting.h b/libnd4j/include/loops/broadcasting.h old mode 100755 new mode 100644 index 4f05f0c6e749..c6e3bc7cedc2 --- a/libnd4j/include/loops/broadcasting.h +++ b/libnd4j/include/loops/broadcasting.h @@ -23,13 +23,13 @@ #ifndef BROADCASTING_H_ #define BROADCASTING_H_ -#include +#include #include #include -#include #include +#include #include -#include +#include #ifdef __CUDACC__ #include @@ -39,160 +39,150 @@ #include #endif -#include #include +#include #include "legacy_ops.h" namespace functions { - namespace broadcast { +namespace broadcast { /** * Broadcast operation * for broadcasting a smaller tensor * along long a bigger one. */ - template - class Broadcast { - public: - +template +class Broadcast { + public: #ifdef __CUDABLAS__ - template - static __device__ void transformCuda(const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *result, const Nd4jLong *resultShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *tadOnlyShapeInfoZ, const Nd4jLong *tadOffsetsZ); - - template - static __device__ void transformCuda(const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *z, const Nd4jLong *zShapeInfo); - - template - static __host__ void intermediateBroadcast(dim3 launchDims, cudaStream_t *stream, - const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *result, const Nd4jLong *resultShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *tadOnlyShapeInfoZ, const Nd4jLong *tadOffsetsZ); - - template - static __host__ void intermediateBroadcast(dim3 launchDims, cudaStream_t *stream, - const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *z, const Nd4jLong *zShapeInfo); - - static __host__ void execBroadcast(dim3 launchDims, cudaStream_t *stream, - int opNum, - const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *result, const Nd4jLong *resultShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *tadOnlyShapeInfoZ, const Nd4jLong *tadOffsetsZ); - - static __host__ void execBroadcast(dim3 launchDims, cudaStream_t *stream, - int opNum, - const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *z, const Nd4jLong *zShapeInfo); - - - template - static __device__ void transformInverseCuda(const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *result, const Nd4jLong *resultShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *tadOnlyShapeInfoZ, const Nd4jLong *tadOffsetsZ); - - template - static __host__ void intermediateInverseBroadcast(dim3 launchDims, cudaStream_t *stream, - const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *result, const Nd4jLong *resultShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *tadOnlyShapeInfoZ, const Nd4jLong *tadOffsetsZ); - - static __host__ void execInverseBroadcast(dim3 launchDims, cudaStream_t *stream, - int opNum, - const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *result, const Nd4jLong *resultShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *tadOnlyShapeInfoZ, const Nd4jLong *tadOffsetsZ); - + template + static __device__ void transformCuda( + const void *x, const Nd4jLong *xShapeInfo, const void *y, + const Nd4jLong *yShapeInfo, void *result, const Nd4jLong *resultShapeInfo, + int *dimension, int dimensionLength, const Nd4jLong *tadOnlyShapeInfo, + const Nd4jLong *tadOffsets, const Nd4jLong *tadOnlyShapeInfoZ, + const Nd4jLong *tadOffsetsZ); + + template + static __device__ void transformCuda(const void *x, + const Nd4jLong *xShapeInfo, + const void *y, + const Nd4jLong *yShapeInfo, void *z, + const Nd4jLong *zShapeInfo); + + template + static __host__ void intermediateBroadcast( + dim3 launchDims, cudaStream_t *stream, const void *x, + const Nd4jLong *xShapeInfo, const void *y, const Nd4jLong *yShapeInfo, + void *result, const Nd4jLong *resultShapeInfo, int *dimension, + int dimensionLength, const Nd4jLong *tadOnlyShapeInfo, + const Nd4jLong *tadOffsets, const Nd4jLong *tadOnlyShapeInfoZ, + const Nd4jLong *tadOffsetsZ); + + template + static __host__ void intermediateBroadcast( + dim3 launchDims, cudaStream_t *stream, const void *x, + const Nd4jLong *xShapeInfo, const void *y, const Nd4jLong *yShapeInfo, + void *z, const Nd4jLong *zShapeInfo); + + static __host__ void execBroadcast( + dim3 launchDims, cudaStream_t *stream, int opNum, const void *x, + const Nd4jLong *xShapeInfo, const void *y, const Nd4jLong *yShapeInfo, + void *result, const Nd4jLong *resultShapeInfo, int *dimension, + int dimensionLength, const Nd4jLong *tadOnlyShapeInfo, + const Nd4jLong *tadOffsets, const Nd4jLong *tadOnlyShapeInfoZ, + const Nd4jLong *tadOffsetsZ); + + static __host__ void execBroadcast(dim3 launchDims, cudaStream_t *stream, + int opNum, const void *x, + const Nd4jLong *xShapeInfo, const void *y, + const Nd4jLong *yShapeInfo, void *z, + const Nd4jLong *zShapeInfo); + + template + static __device__ void transformInverseCuda( + const void *x, const Nd4jLong *xShapeInfo, const void *y, + const Nd4jLong *yShapeInfo, void *result, const Nd4jLong *resultShapeInfo, + int *dimension, int dimensionLength, const Nd4jLong *tadOnlyShapeInfo, + const Nd4jLong *tadOffsets, const Nd4jLong *tadOnlyShapeInfoZ, + const Nd4jLong *tadOffsetsZ); + + template + static __host__ void intermediateInverseBroadcast( + dim3 launchDims, cudaStream_t *stream, const void *x, + const Nd4jLong *xShapeInfo, const void *y, const Nd4jLong *yShapeInfo, + void *result, const Nd4jLong *resultShapeInfo, int *dimension, + int dimensionLength, const Nd4jLong *tadOnlyShapeInfo, + const Nd4jLong *tadOffsets, const Nd4jLong *tadOnlyShapeInfoZ, + const Nd4jLong *tadOffsetsZ); + + static __host__ void execInverseBroadcast( + dim3 launchDims, cudaStream_t *stream, int opNum, const void *x, + const Nd4jLong *xShapeInfo, const void *y, const Nd4jLong *yShapeInfo, + void *result, const Nd4jLong *resultShapeInfo, int *dimension, + int dimensionLength, const Nd4jLong *tadOnlyShapeInfo, + const Nd4jLong *tadOffsets, const Nd4jLong *tadOnlyShapeInfoZ, + const Nd4jLong *tadOffsetsZ); #else - static void execInverse(int opNum, - const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *result, const Nd4jLong *resultShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset, - const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetZ, - uint64_t start, uint64_t stop); - - static void exec(int opNum, - const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *result, const Nd4jLong *resultShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset, - const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetZ, - sd::LoopKind::Kind loopKind, - uint64_t start, uint64_t stop); - - /** - * CPU execution - * @param x the input - * @param xShapeInfo the x shape information - * @param y the y data - * @param yShapeInfo the y shape information - * @param result the result - * @param resultShapeInfo the result shape information - * @param dimension the dimension to broadcast along long - * @param dimensionLength the length of the dimension buffer - */ - template - static void exec(const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *result, const Nd4jLong *resultShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset, - const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetZ, - sd::LoopKind::Kind loopKind, - uint64_t start, uint64_t stop); - - template - static void execInverse(const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *result, const Nd4jLong *resultShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset, - const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetZ, - uint64_t start, uint64_t stop); - - static void exec(int opNum, - const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *z, const Nd4jLong *zShapeInfo); - - template - static void exec(const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *z, const Nd4jLong *zShapeInfo); + static void execInverse( + int opNum, const void *x, const Nd4jLong *xShapeInfo, const void *y, + const Nd4jLong *yShapeInfo, void *result, const Nd4jLong *resultShapeInfo, + int *dimension, int dimensionLength, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffset, const Nd4jLong *tadShapeInfoZ, + const Nd4jLong *tadOffsetZ, uint64_t start, uint64_t stop); + + static void exec(int opNum, const void *x, const Nd4jLong *xShapeInfo, + const void *y, const Nd4jLong *yShapeInfo, void *result, + const Nd4jLong *resultShapeInfo, int *dimension, + int dimensionLength, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffset, const Nd4jLong *tadShapeInfoZ, + const Nd4jLong *tadOffsetZ, sd::LoopKind::Kind loopKind, + uint64_t start, uint64_t stop); + + /** + * CPU execution + * @param x the input + * @param xShapeInfo the x shape information + * @param y the y data + * @param yShapeInfo the y shape information + * @param result the result + * @param resultShapeInfo the result shape information + * @param dimension the dimension to broadcast along long + * @param dimensionLength the length of the dimension buffer + */ + template + static void exec(const void *x, const Nd4jLong *xShapeInfo, const void *y, + const Nd4jLong *yShapeInfo, void *result, + const Nd4jLong *resultShapeInfo, int *dimension, + int dimensionLength, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffset, const Nd4jLong *tadShapeInfoZ, + const Nd4jLong *tadOffsetZ, sd::LoopKind::Kind loopKind, + uint64_t start, uint64_t stop); + + template + static void execInverse( + const void *x, const Nd4jLong *xShapeInfo, const void *y, + const Nd4jLong *yShapeInfo, void *result, const Nd4jLong *resultShapeInfo, + int *dimension, int dimensionLength, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffset, const Nd4jLong *tadShapeInfoZ, + const Nd4jLong *tadOffsetZ, uint64_t start, uint64_t stop); + + static void exec(int opNum, const void *x, const Nd4jLong *xShapeInfo, + const void *y, const Nd4jLong *yShapeInfo, void *z, + const Nd4jLong *zShapeInfo); + + template + static void exec(const void *x, const Nd4jLong *xShapeInfo, const void *y, + const Nd4jLong *yShapeInfo, void *z, + const Nd4jLong *zShapeInfo); #endif - }; - } -} +}; +} // namespace broadcast +} // namespace functions #endif /* BROADCASTING_H_ */ diff --git a/libnd4j/include/loops/broadcasting_bool.h b/libnd4j/include/loops/broadcasting_bool.h index 400269c02e3a..fc1a6663d074 100644 --- a/libnd4j/include/loops/broadcasting_bool.h +++ b/libnd4j/include/loops/broadcasting_bool.h @@ -23,13 +23,13 @@ #ifndef BROADCASTING_BOOL_H_ #define BROADCASTING_BOOL_H_ -#include +#include #include #include -#include #include +#include #include -#include +#include #ifdef __CUDACC__ #include @@ -44,148 +44,151 @@ #include "legacy_ops.h" namespace functions { - namespace broadcast { +namespace broadcast { /** * Broadcast operation * for broadcasting a smaller tensor * along long a bigger one. */ - template - class BroadcastBool { - public: - +template +class BroadcastBool { + public: #ifdef __CUDACC__ - template - static __device__ void transformCuda(const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *result, const Nd4jLong *resultShapeInfo, - void *extraParams, - int *dimension, int dimensionLength, - const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *tadOnlyShapeInfoZ, const Nd4jLong *tadOffsetsZ); - - template - static __device__ void transformCuda(const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *z, const Nd4jLong *zShapeInfo, - void *extraParams); - - template - static __host__ void intermediateBroadcast(dim3 launchDims, cudaStream_t *stream, void const* x, Nd4jLong const* xShapeInfo, void const* y, Nd4jLong const* yShapeInfo, void *result, Nd4jLong const* resultShapeInfo, void *extraParams, int *dimension, int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ); - - template - static __host__ void intermediateBroadcast(dim3 launchDims, cudaStream_t *stream, - const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *z, const Nd4jLong *zShapeInfo, - void *extraParams); - - static __host__ void execBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void const* x, Nd4jLong const* xShapeInfo, void const* y, Nd4jLong const* yShapeInfo, void *result, Nd4jLong const* resultShapeInfo, void *extraParams, int *dimension, int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ); - - static __host__ void execBroadcast(dim3 launchDims, cudaStream_t *stream, const int opNum, - const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *z, const Nd4jLong *zShapeInfo, - void *extraParams); - - template - static __device__ void transformInverseCuda(const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *result, const Nd4jLong *resultShapeInfo, - void *extraParams, - int *dimension, int dimensionLength, - const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *tadOnlyShapeInfoZ, const Nd4jLong *tadOffsetsZ); - - template - static __host__ void intermediateInverseBroadcast(dim3 launchDims, cudaStream_t *stream, - const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *result, const Nd4jLong *resultShapeInfo, - void *extraParams, - int *dimension, int dimensionLength, - const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *tadOnlyShapeInfoZ, const Nd4jLong *tadOffsetsZ); - - static __host__ void execInverseBroadcast(dim3 launchDims, cudaStream_t *stream, - int opNum, - const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *result, const Nd4jLong *resultShapeInfo, - void *extraParams, - int *dimension, int dimensionLength, - const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *tadOnlyShapeInfoZ, const Nd4jLong *tadOffsetsZ); + template + static __device__ void transformCuda( + const void *x, const Nd4jLong *xShapeInfo, const void *y, + const Nd4jLong *yShapeInfo, void *result, const Nd4jLong *resultShapeInfo, + void *extraParams, int *dimension, int dimensionLength, + const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets, + const Nd4jLong *tadOnlyShapeInfoZ, const Nd4jLong *tadOffsetsZ); + + template + static __device__ void transformCuda(const void *x, + const Nd4jLong *xShapeInfo, + const void *y, + const Nd4jLong *yShapeInfo, void *z, + const Nd4jLong *zShapeInfo, + void *extraParams); + + template + static __host__ void intermediateBroadcast( + dim3 launchDims, cudaStream_t *stream, void const *x, + Nd4jLong const *xShapeInfo, void const *y, Nd4jLong const *yShapeInfo, + void *result, Nd4jLong const *resultShapeInfo, void *extraParams, + int *dimension, int dimensionLength, Nd4jLong const *tadOnlyShapeInfo, + Nd4jLong const *tadOffsets, Nd4jLong const *tadOnlyShapeInfoZ, + Nd4jLong const *tadOffsetsZ); + + template + static __host__ void intermediateBroadcast( + dim3 launchDims, cudaStream_t *stream, const void *x, + const Nd4jLong *xShapeInfo, const void *y, const Nd4jLong *yShapeInfo, + void *z, const Nd4jLong *zShapeInfo, void *extraParams); + + static __host__ void execBroadcast( + dim3 launchDims, cudaStream_t *stream, int opNum, void const *x, + Nd4jLong const *xShapeInfo, void const *y, Nd4jLong const *yShapeInfo, + void *result, Nd4jLong const *resultShapeInfo, void *extraParams, + int *dimension, int dimensionLength, Nd4jLong const *tadOnlyShapeInfo, + Nd4jLong const *tadOffsets, Nd4jLong const *tadOnlyShapeInfoZ, + Nd4jLong const *tadOffsetsZ); + + static __host__ void execBroadcast(dim3 launchDims, cudaStream_t *stream, + const int opNum, const void *x, + const Nd4jLong *xShapeInfo, const void *y, + const Nd4jLong *yShapeInfo, void *z, + const Nd4jLong *zShapeInfo, + void *extraParams); + + template + static __device__ void transformInverseCuda( + const void *x, const Nd4jLong *xShapeInfo, const void *y, + const Nd4jLong *yShapeInfo, void *result, const Nd4jLong *resultShapeInfo, + void *extraParams, int *dimension, int dimensionLength, + const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets, + const Nd4jLong *tadOnlyShapeInfoZ, const Nd4jLong *tadOffsetsZ); + + template + static __host__ void intermediateInverseBroadcast( + dim3 launchDims, cudaStream_t *stream, const void *x, + const Nd4jLong *xShapeInfo, const void *y, const Nd4jLong *yShapeInfo, + void *result, const Nd4jLong *resultShapeInfo, void *extraParams, + int *dimension, int dimensionLength, const Nd4jLong *tadOnlyShapeInfo, + const Nd4jLong *tadOffsets, const Nd4jLong *tadOnlyShapeInfoZ, + const Nd4jLong *tadOffsetsZ); + + static __host__ void execInverseBroadcast( + dim3 launchDims, cudaStream_t *stream, int opNum, const void *x, + const Nd4jLong *xShapeInfo, const void *y, const Nd4jLong *yShapeInfo, + void *result, const Nd4jLong *resultShapeInfo, void *extraParams, + int *dimension, int dimensionLength, const Nd4jLong *tadOnlyShapeInfo, + const Nd4jLong *tadOffsets, const Nd4jLong *tadOnlyShapeInfoZ, + const Nd4jLong *tadOffsetsZ); #else - static void exec(int opNum, - const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *result, const Nd4jLong *resultShapeInfo, - void *extraParams, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset, - const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetZ, - uint64_t start, uint64_t stop); - - static void exec(int opNum, - const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *z, const Nd4jLong *zShapeInfo, - void *extraParams); - - static void execInverse(int opNum, - const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *result, const Nd4jLong *resultShapeInfo, - void *extraParams, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset, - const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetZ, - uint64_t start, uint64_t stop); - - /** - * CPU execution - * @param x the input - * @param xShapeInfo the x shape information - * @param y the y data - * @param yShapeInfo the y shape information - * @param result the result - * @param resultShapeInfo the result shape information - * @param dimension the dimension to broadcast along long - * @param dimensionLength the length of the dimension buffer - */ - template - static void exec(const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *result, const Nd4jLong *resultShapeInfo, - void *extraParams, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset, - const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetZ, - uint64_t start, uint64_t stop); - - template - static void exec(const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *z, const Nd4jLong *zShapeInfo, - void *extraParams); - - template - static void execInverse(const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *result, const Nd4jLong *resultShapeInfo, - void *extraParams, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset, - const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetZ, - uint64_t start, uint64_t stop); + static void exec(int opNum, const void *x, const Nd4jLong *xShapeInfo, + const void *y, const Nd4jLong *yShapeInfo, void *result, + const Nd4jLong *resultShapeInfo, void *extraParams, + int *dimension, int dimensionLength, + const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset, + const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetZ, + uint64_t start, uint64_t stop); + + static void exec(int opNum, const void *x, const Nd4jLong *xShapeInfo, + const void *y, const Nd4jLong *yShapeInfo, void *z, + const Nd4jLong *zShapeInfo, void *extraParams); + + static void execInverse(int opNum, const void *x, const Nd4jLong *xShapeInfo, + const void *y, const Nd4jLong *yShapeInfo, + void *result, const Nd4jLong *resultShapeInfo, + void *extraParams, int *dimension, + int dimensionLength, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffset, + const Nd4jLong *tadShapeInfoZ, + const Nd4jLong *tadOffsetZ, uint64_t start, + uint64_t stop); + + /** + * CPU execution + * @param x the input + * @param xShapeInfo the x shape information + * @param y the y data + * @param yShapeInfo the y shape information + * @param result the result + * @param resultShapeInfo the result shape information + * @param dimension the dimension to broadcast along long + * @param dimensionLength the length of the dimension buffer + */ + template + static void exec(const void *x, const Nd4jLong *xShapeInfo, const void *y, + const Nd4jLong *yShapeInfo, void *result, + const Nd4jLong *resultShapeInfo, void *extraParams, + int *dimension, int dimensionLength, + const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset, + const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetZ, + uint64_t start, uint64_t stop); + + template + static void exec(const void *x, const Nd4jLong *xShapeInfo, const void *y, + const Nd4jLong *yShapeInfo, void *z, + const Nd4jLong *zShapeInfo, void *extraParams); + + template + static void execInverse(const void *x, const Nd4jLong *xShapeInfo, + const void *y, const Nd4jLong *yShapeInfo, + void *result, const Nd4jLong *resultShapeInfo, + void *extraParams, int *dimension, + int dimensionLength, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffset, + const Nd4jLong *tadShapeInfoZ, + const Nd4jLong *tadOffsetZ, uint64_t start, + uint64_t stop); #endif - }; - } -} +}; +} // namespace broadcast +} // namespace functions #endif /* BROADCASTING_H_ */ diff --git a/libnd4j/include/loops/broadcasting_int.h b/libnd4j/include/loops/broadcasting_int.h index 386fbd3f74ec..f8c377fce1df 100644 --- a/libnd4j/include/loops/broadcasting_int.h +++ b/libnd4j/include/loops/broadcasting_int.h @@ -23,13 +23,13 @@ #ifndef BROADCASTING_INT_H_ #define BROADCASTING_INT_H_ -#include +#include #include #include -#include #include +#include #include -#include +#include #ifdef __CUDACC__ #include @@ -44,148 +44,141 @@ #include "legacy_ops.h" namespace functions { - namespace broadcast { +namespace broadcast { /** * Broadcast operation * for broadcasting a smaller tensor * along long a bigger one. */ - template - class BroadcastInt { - public: - +template +class BroadcastInt { + public: #ifdef __CUDACC__ - template - static __device__ void transformCuda(const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *result, const Nd4jLong *resultShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *tadOnlyShapeInfoZ, const Nd4jLong *tadOffsetsZ); - - template - static __device__ void transformCuda(const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *z, const Nd4jLong *zShapeInfo); - - template - static __host__ void intermediateBroadcast(dim3 launchDims, cudaStream_t *stream, - const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *result, const Nd4jLong *resultShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *tadOnlyShapeInfoZ, const Nd4jLong *tadOffsetsZ); - - template - static __host__ void intermediateBroadcast(dim3 launchDims, cudaStream_t *stream, - const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *z, const Nd4jLong *zShapeInfo); - - static __host__ void execBroadcast(dim3 launchDims, cudaStream_t *stream, - int opNum, - const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *result, const Nd4jLong *resultShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *tadOnlyShapeInfoZ, const Nd4jLong *tadOffsetsZ); - - static __host__ void execBroadcast(dim3 launchDims, cudaStream_t *stream, const int opNum, - const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *z, const Nd4jLong *zShapeInfo); - - template - static __device__ void transformInverseCuda(const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *result, const Nd4jLong *resultShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *tadOnlyShapeInfoZ, const Nd4jLong *tadOffsetsZ); - - template - static __host__ void intermediateInverseBroadcast(dim3 launchDims, cudaStream_t *stream, - const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *result, const Nd4jLong *resultShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *tadOnlyShapeInfoZ, const Nd4jLong *tadOffsetsZ); - - static __host__ void execInverseBroadcast(dim3 launchDims, cudaStream_t *stream, - int opNum, - const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *result, const Nd4jLong *resultShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *tadOnlyShapeInfoZ, const Nd4jLong *tadOffsetsZ); + template + static __device__ void transformCuda( + const void *x, const Nd4jLong *xShapeInfo, const void *y, + const Nd4jLong *yShapeInfo, void *result, const Nd4jLong *resultShapeInfo, + int *dimension, int dimensionLength, const Nd4jLong *tadOnlyShapeInfo, + const Nd4jLong *tadOffsets, const Nd4jLong *tadOnlyShapeInfoZ, + const Nd4jLong *tadOffsetsZ); + + template + static __device__ void transformCuda(const void *x, + const Nd4jLong *xShapeInfo, + const void *y, + const Nd4jLong *yShapeInfo, void *z, + const Nd4jLong *zShapeInfo); + + template + static __host__ void intermediateBroadcast( + dim3 launchDims, cudaStream_t *stream, const void *x, + const Nd4jLong *xShapeInfo, const void *y, const Nd4jLong *yShapeInfo, + void *result, const Nd4jLong *resultShapeInfo, int *dimension, + int dimensionLength, const Nd4jLong *tadOnlyShapeInfo, + const Nd4jLong *tadOffsets, const Nd4jLong *tadOnlyShapeInfoZ, + const Nd4jLong *tadOffsetsZ); + + template + static __host__ void intermediateBroadcast( + dim3 launchDims, cudaStream_t *stream, const void *x, + const Nd4jLong *xShapeInfo, const void *y, const Nd4jLong *yShapeInfo, + void *z, const Nd4jLong *zShapeInfo); + + static __host__ void execBroadcast( + dim3 launchDims, cudaStream_t *stream, int opNum, const void *x, + const Nd4jLong *xShapeInfo, const void *y, const Nd4jLong *yShapeInfo, + void *result, const Nd4jLong *resultShapeInfo, int *dimension, + int dimensionLength, const Nd4jLong *tadOnlyShapeInfo, + const Nd4jLong *tadOffsets, const Nd4jLong *tadOnlyShapeInfoZ, + const Nd4jLong *tadOffsetsZ); + + static __host__ void execBroadcast(dim3 launchDims, cudaStream_t *stream, + const int opNum, const void *x, + const Nd4jLong *xShapeInfo, const void *y, + const Nd4jLong *yShapeInfo, void *z, + const Nd4jLong *zShapeInfo); + + template + static __device__ void transformInverseCuda( + const void *x, const Nd4jLong *xShapeInfo, const void *y, + const Nd4jLong *yShapeInfo, void *result, const Nd4jLong *resultShapeInfo, + int *dimension, int dimensionLength, const Nd4jLong *tadOnlyShapeInfo, + const Nd4jLong *tadOffsets, const Nd4jLong *tadOnlyShapeInfoZ, + const Nd4jLong *tadOffsetsZ); + + template + static __host__ void intermediateInverseBroadcast( + dim3 launchDims, cudaStream_t *stream, const void *x, + const Nd4jLong *xShapeInfo, const void *y, const Nd4jLong *yShapeInfo, + void *result, const Nd4jLong *resultShapeInfo, int *dimension, + int dimensionLength, const Nd4jLong *tadOnlyShapeInfo, + const Nd4jLong *tadOffsets, const Nd4jLong *tadOnlyShapeInfoZ, + const Nd4jLong *tadOffsetsZ); + + static __host__ void execInverseBroadcast( + dim3 launchDims, cudaStream_t *stream, int opNum, const void *x, + const Nd4jLong *xShapeInfo, const void *y, const Nd4jLong *yShapeInfo, + void *result, const Nd4jLong *resultShapeInfo, int *dimension, + int dimensionLength, const Nd4jLong *tadOnlyShapeInfo, + const Nd4jLong *tadOffsets, const Nd4jLong *tadOnlyShapeInfoZ, + const Nd4jLong *tadOffsetsZ); #else - static void exec(int opNum, - const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *result, const Nd4jLong *resultShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset, - const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetZ, - uint64_t start, uint64_t stop); - - static void exec(int opNum, - const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *z, const Nd4jLong *zShapeInfo); - - static void execInverse(int opNum, - const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *result, const Nd4jLong *resultShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset, - const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetZ, - uint64_t start, uint64_t stop); - - /** - * CPU execution - * @param x the input - * @param xShapeInfo the x shape information - * @param y the y data - * @param yShapeInfo the y shape information - * @param result the result - * @param resultShapeInfo the result shape information - * @param dimension the dimension to broadcast along long - * @param dimensionLength the length of the dimension buffer - */ - template - static void exec(const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *result, const Nd4jLong *resultShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset, - const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetZ, - uint64_t start, uint64_t stop); - - template - static void exec(const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *z, const Nd4jLong *zShapeInfo); - - template - static void execInverse(const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *result, const Nd4jLong *resultShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset, - const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetZ, - uint64_t start, uint64_t stop); + static void exec(int opNum, const void *x, const Nd4jLong *xShapeInfo, + const void *y, const Nd4jLong *yShapeInfo, void *result, + const Nd4jLong *resultShapeInfo, int *dimension, + int dimensionLength, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffset, const Nd4jLong *tadShapeInfoZ, + const Nd4jLong *tadOffsetZ, uint64_t start, uint64_t stop); + + static void exec(int opNum, const void *x, const Nd4jLong *xShapeInfo, + const void *y, const Nd4jLong *yShapeInfo, void *z, + const Nd4jLong *zShapeInfo); + + static void execInverse( + int opNum, const void *x, const Nd4jLong *xShapeInfo, const void *y, + const Nd4jLong *yShapeInfo, void *result, const Nd4jLong *resultShapeInfo, + int *dimension, int dimensionLength, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffset, const Nd4jLong *tadShapeInfoZ, + const Nd4jLong *tadOffsetZ, uint64_t start, uint64_t stop); + + /** + * CPU execution + * @param x the input + * @param xShapeInfo the x shape information + * @param y the y data + * @param yShapeInfo the y shape information + * @param result the result + * @param resultShapeInfo the result shape information + * @param dimension the dimension to broadcast along long + * @param dimensionLength the length of the dimension buffer + */ + template + static void exec(const void *x, const Nd4jLong *xShapeInfo, const void *y, + const Nd4jLong *yShapeInfo, void *result, + const Nd4jLong *resultShapeInfo, int *dimension, + int dimensionLength, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffset, const Nd4jLong *tadShapeInfoZ, + const Nd4jLong *tadOffsetZ, uint64_t start, uint64_t stop); + + template + static void exec(const void *x, const Nd4jLong *xShapeInfo, const void *y, + const Nd4jLong *yShapeInfo, void *z, + const Nd4jLong *zShapeInfo); + + template + static void execInverse( + const void *x, const Nd4jLong *xShapeInfo, const void *y, + const Nd4jLong *yShapeInfo, void *result, const Nd4jLong *resultShapeInfo, + int *dimension, int dimensionLength, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffset, const Nd4jLong *tadShapeInfoZ, + const Nd4jLong *tadOffsetZ, uint64_t start, uint64_t stop); #endif - }; - } -} +}; +} // namespace broadcast +} // namespace functions #endif /* BROADCASTING_H_ */ diff --git a/libnd4j/include/loops/cpu/broadcasting.hpp b/libnd4j/include/loops/cpu/broadcasting.hpp index 4c59de0ecea0..d3a23a8dcd36 100644 --- a/libnd4j/include/loops/cpu/broadcasting.hpp +++ b/libnd4j/include/loops/cpu/broadcasting.hpp @@ -18,830 +18,890 @@ // @author raver119@gmail.com // -#include +#include +#include +#include +#include #include #include +#include #include -#include -#include -#include -#include using namespace simdOps; namespace functions { namespace broadcast { - template - void Broadcast::execInverse(const int opNum, - const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *z, const Nd4jLong *zShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *xTadShapeInfo, const Nd4jLong *xTadOffset, - const Nd4jLong *zTadShapeInfo, const Nd4jLong *zTadOffset, - uint64_t start, uint64_t stop) { - DISPATCH_BY_OPNUM_TTT(execInverse, PARAMS(x, - xShapeInfo, - y, - yShapeInfo, - z, - zShapeInfo, - dimension, - dimensionLength, - xTadShapeInfo, - xTadOffset, - zTadShapeInfo, - zTadOffset, start, stop), BROADCAST_OPS); - } - - template - void Broadcast::exec(const int opNum, - const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *z, const Nd4jLong *zShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *xTadShapeInfo, const Nd4jLong *xTadOffset, - const Nd4jLong *zTadShapeInfo, const Nd4jLong *zTadOffset, - sd::LoopKind::Kind loopKind, - uint64_t start, uint64_t stop) { - DISPATCH_BY_OPNUM_TTT(exec, PARAMS(x, - xShapeInfo, - y, - yShapeInfo, - z, - zShapeInfo, - dimension, - dimensionLength, - xTadShapeInfo, - xTadOffset, - zTadShapeInfo, - zTadOffset, loopKind, start, stop), BROADCAST_OPS); - } +template +void Broadcast::execInverse( + const int opNum, const void *x, const Nd4jLong *xShapeInfo, const void *y, + const Nd4jLong *yShapeInfo, void *z, const Nd4jLong *zShapeInfo, + int *dimension, int dimensionLength, const Nd4jLong *xTadShapeInfo, + const Nd4jLong *xTadOffset, const Nd4jLong *zTadShapeInfo, + const Nd4jLong *zTadOffset, uint64_t start, uint64_t stop) { + DISPATCH_BY_OPNUM_TTT( + execInverse, + PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, + dimensionLength, xTadShapeInfo, xTadOffset, zTadShapeInfo, + zTadOffset, start, stop), + BROADCAST_OPS); +} +template +void Broadcast::exec( + const int opNum, const void *x, const Nd4jLong *xShapeInfo, const void *y, + const Nd4jLong *yShapeInfo, void *z, const Nd4jLong *zShapeInfo, + int *dimension, int dimensionLength, const Nd4jLong *xTadShapeInfo, + const Nd4jLong *xTadOffset, const Nd4jLong *zTadShapeInfo, + const Nd4jLong *zTadOffset, sd::LoopKind::Kind loopKind, uint64_t start, + uint64_t stop) { + DISPATCH_BY_OPNUM_TTT( + exec, + PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, + dimensionLength, xTadShapeInfo, xTadOffset, zTadShapeInfo, + zTadOffset, loopKind, start, stop), + BROADCAST_OPS); +} - template - template - void Broadcast::exec(const void *vx, const Nd4jLong *xShapeInfo, - const void *vy, const Nd4jLong *yShapeInfo, - void *vz, const Nd4jLong *zShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *xTadShapeInfo, const Nd4jLong *xTadOffset, - const Nd4jLong *zTadShapeInfo, const Nd4jLong *zTadOffset, - sd::LoopKind::Kind loopKind, - uint64_t start, uint64_t stop) { - - auto x = reinterpret_cast(vx); - auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); - - //decompose in to several sub tads after - //moving all dimensions (in sorted order) - //to the back. - //permuted version of the x shape info for setting up the tad problem - auto xTadShapeShapeInfo = xTadShapeInfo; - auto tadOffsets = xTadOffset; - - if (xTadShapeInfo == nullptr || tadOffsets == nullptr) { - auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(xShapeInfo, dimension, dimensionLength); - - xTadShapeShapeInfo = tadPack.primaryShapeInfo(); - tadOffsets = tadPack.primaryOffsets(); - } - - //int *resultStride = shape::stride(xTadShapeShapeInfo); - unsigned int tadLength = shape::length(xTadShapeShapeInfo);//shape::length(xTadShapeShapeInfo); - unsigned int tads = shape::length(xShapeInfo) / tadLength; - - if (zTadShapeInfo == nullptr) { - zTadShapeInfo = xTadShapeShapeInfo; - zTadOffset = tadOffsets; - } - - auto lenZ = shape::length(zTadShapeInfo); - auto lenY = shape::length(yShapeInfo); - - auto xEws = shape::elementWiseStride(xTadShapeShapeInfo); - auto yEws = shape::elementWiseStride(yShapeInfo); - auto zEws = shape::elementWiseStride(zTadShapeInfo); - - - const sd::LoopKind::Kind kindOfLoop = - (loopKind == sd::LoopKind::BROADCAST_SCALAR_X || - loopKind == sd::LoopKind::BROADCAST_SCALAR_Y || - loopKind == sd::LoopKind::BROADCAST_3D || - loopKind == sd::LoopKind::BROADCAST_4D || - loopKind == sd::LoopKind::BROADCAST_5D) - ? loopKind : sd::LoopKind::deduceKindOfLoopXYZ(xTadShapeShapeInfo, yShapeInfo, zTadShapeInfo); - - if (kindOfLoop == sd::LoopKind::EWS1) { - for (auto i = start; i < stop; i++) { - auto oX = x + tadOffsets[i]; - auto oZ = z + zTadOffset[i]; - - PRAGMA_OMP_SIMD - for (unsigned int f = 0; f < tadLength; f++) - oZ[f] = OpType::op(oX[f], y[f]); - } - } - else if(kindOfLoop == sd::LoopKind::EWSNONZERO){ - for (auto i = start; i < stop; i++) { - auto oX = x + tadOffsets[i]; - auto oZ = z + zTadOffset[i]; - - PRAGMA_OMP_SIMD - for (unsigned int f = 0; f < tadLength; f++) - oZ[f * zEws] = OpType::op(oX[f * xEws], y[f * yEws]); - } - } else if(kindOfLoop == sd::LoopKind::BROADCAST_SCALAR_X){ - // this loop effectively turns broadcast into series of scalar ops - auto loopLength = yShapeInfo[shape::rank(yShapeInfo)]; - - for (auto i = start; i < stop; i++) { - auto oY = y + (i * loopLength); - auto oZ = z + (i * loopLength); - - const auto oX = x[i]; - - PRAGMA_OMP_SIMD - for (Nd4jLong f = 0; f < loopLength; f++) - oZ[f] = OpType::op(oX, oY[f]); - } - } else if(kindOfLoop == sd::LoopKind::BROADCAST_SCALAR_Y){ - // this loop effectively turns broadcast into series of scalar ops - auto loopLength = xShapeInfo[shape::rank(xShapeInfo)]; - - for (auto i = start; i < stop; i++) { - auto oX = x + (i * loopLength); - auto oZ = z + (i * loopLength); - - const auto oY = y[i]; - - PRAGMA_OMP_SIMD - for (Nd4jLong f = 0; f < loopLength; f++) - oZ[f] = OpType::op(oX[f], oY); - } - } - else if (kindOfLoop == sd::LoopKind::BROADCAST_3D) { - - int xRank = shape::rank(xShapeInfo); - int yRank = shape::rank(yShapeInfo); - - auto xStrides = shape::stride(xShapeInfo); - auto zStrides = shape::stride(zShapeInfo); - - Nd4jLong yStrides[3] = { 0,0,0 }; - sd::ShapeUtils::copyCertainStridesFromShapeInfo(yShapeInfo, xRank, dimensionLength, dimension, yStrides); - - uint64_t nSize1 = shape::sizeAt(zShapeInfo, 1); - uint64_t nSize2 = shape::sizeAt(zShapeInfo, 2); - - for (auto index0 = start; index0 < stop; index0++) { - - PRAGMA_OMP_SIMD - for (uint64_t index1 = 0; index1 < nSize1; index1++) { - for (uint64_t index2 = 0; index2 < nSize2; index2++) { - auto rX = x + (xStrides[0] * index0 + xStrides[1] * index1 + xStrides[2] * index2); - auto rY = y + (yStrides[0] * index0 + yStrides[1] * index1 + yStrides[2] * index2); - auto rZ = z + (zStrides[0] * index0 + zStrides[1] * index1 + zStrides[2] * index2); - *rZ = OpType::op(*rX, *rY); - } - } - - } - - } - else if (kindOfLoop == sd::LoopKind::BROADCAST_4D) { - - int xRank = shape::rank(xShapeInfo); - int yRank = shape::rank(yShapeInfo); - - auto xStrides = shape::stride(xShapeInfo); - auto zStrides = shape::stride(zShapeInfo); - - Nd4jLong yStrides[4] = { 0,0,0,0 }; - sd::ShapeUtils::copyCertainStridesFromShapeInfo(yShapeInfo, xRank, dimensionLength, dimension, yStrides); - - uint64_t nSize1 = shape::sizeAt(zShapeInfo, 1); - uint64_t nSize2 = shape::sizeAt(zShapeInfo, 2); - uint64_t nSize3 = shape::sizeAt(zShapeInfo, 3); - - for (auto i = start; i < stop; i++) { - - uint64_t index0 = i / nSize1; - uint64_t index1 = i % nSize1; - - PRAGMA_OMP_SIMD - for (uint64_t index2 = 0; index2 < nSize2; index2++) { - for (uint64_t index3 = 0; index3 < nSize3; index3++) { - auto rX = x + (xStrides[0] * index0 + xStrides[1] * index1 + xStrides[2] * index2 + xStrides[3] * index3); - auto rY = y + (yStrides[0] * index0 + yStrides[1] * index1 + yStrides[2] * index2 + yStrides[3] * index3); - auto rZ = z + (zStrides[0] * index0 + zStrides[1] * index1 + zStrides[2] * index2 + zStrides[3] * index3); - *rZ = OpType::op(*rX, *rY); - } - } - } - - } - else if (kindOfLoop == sd::LoopKind::BROADCAST_5D) { - - int xRank = shape::rank(xShapeInfo); - int yRank = shape::rank(yShapeInfo); - - auto xStrides = shape::stride(xShapeInfo); - auto zStrides = shape::stride(zShapeInfo); - - Nd4jLong yStrides[5] = { 0,0,0,0,0 }; - sd::ShapeUtils::copyCertainStridesFromShapeInfo(yShapeInfo, xRank, dimensionLength, dimension, yStrides); - - uint32_t nSize1 = shape::sizeAt(zShapeInfo, 1); - uint32_t nSize2 = shape::sizeAt(zShapeInfo, 2); - uint32_t nSize3 = shape::sizeAt(zShapeInfo, 3); - uint32_t nSize4 = shape::sizeAt(zShapeInfo, 4); - - for (auto i = start; i < stop; i++) { - - uint32_t index0 = i / nSize1; - uint32_t index1 = i % nSize1; - - PRAGMA_OMP_SIMD - for (uint32_t index2 = 0; index2 < nSize2; index2++) { - for (uint32_t index3 = 0; index3 < nSize3; index3++) { - for (uint32_t index4 = 0; index4 < nSize4; index4++) { - auto rX = x + (xStrides[0] * index0 + xStrides[1] * index1 + xStrides[2] * index2 + xStrides[3] * index3 + xStrides[4] * index4); - auto rY = y + (yStrides[0] * index0 + yStrides[1] * index1 + yStrides[2] * index2 + yStrides[3] * index3 + yStrides[4] * index4); - auto rZ = z + (zStrides[0] * index0 + zStrides[1] * index1 + zStrides[2] * index2 + zStrides[3] * index3 + zStrides[4] * index4); - - *rZ = OpType::op(*rX, *rY); - } - } - } - } - - } - else if(shape::haveSameShapeAndStrides(xTadShapeShapeInfo, yShapeInfo) && shape::haveSameShapeAndStrides(xTadShapeShapeInfo, zTadShapeInfo)) { - uint tadShapeShapeInfoCast[MAX_RANK]; - bool canCastX = sd::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, tadShapeShapeInfoCast); - - for (auto i = start; i < stop; i++) { - auto oX = x + tadOffsets[i]; - auto oZ = z + zTadOffset[i]; - - PRAGMA_OMP_SIMD - for (unsigned int f = 0; f < tadLength; f++) { - auto offset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX); - oZ[offset] = OpType::op(oX[offset], y[offset]); - } - } - } - else if(shape::haveSameShapeAndStrides(xTadShapeShapeInfo, yShapeInfo)) { - uint tadShapeShapeInfoCast[MAX_RANK]; - uint tadShapeInfoZCast[MAX_RANK]; - bool canCastX = sd::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, tadShapeShapeInfoCast); - bool canCastZ = sd::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast); - - - for (auto i = start; i < stop; i++) { - auto oZ = z + zTadOffset[i]; - auto oX = x + tadOffsets[i]; - - PRAGMA_OMP_SIMD - for (unsigned int f = 0; f < tadLength; f++) { - auto offset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX); - auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ); - oZ[zOffset] = OpType::op(oX[offset], y[offset]); - } - } - } - else if(shape::haveSameShapeAndStrides(xTadShapeShapeInfo, zTadShapeInfo)) { - uint tadShapeShapeInfoCast[MAX_RANK]; - uint yShapeInfoCast[MAX_RANK]; - bool canCastX = sd::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, tadShapeShapeInfoCast); - bool canCastY = sd::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); - - for (auto i = start; i < stop; i++) { - auto oZ = z + zTadOffset[i]; - auto oX = x + tadOffsets[i]; - - PRAGMA_OMP_SIMD - for (unsigned int f = 0; f < tadLength; f++) { - auto offset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX); - auto yOffset = shape::indexOffset(f, yShapeInfo, yShapeInfoCast, canCastY); - oZ[offset] = OpType::op(oX[offset], y[yOffset]); - } - } - } - else if(shape::haveSameShapeAndStrides(yShapeInfo, zTadShapeInfo)) { - uint tadShapeShapeInfoCast[MAX_RANK]; - uint yShapeInfoCast[MAX_RANK]; - bool canCastX = sd::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, tadShapeShapeInfoCast); - bool canCastY = sd::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); - - for (auto i = start; i < stop; i++) { - auto oZ = z + zTadOffset[i]; - auto oX = x + tadOffsets[i]; - - PRAGMA_OMP_SIMD - for (unsigned int f = 0; f < tadLength; f++) { - auto xOffset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX); - auto offset = shape::indexOffset(f, yShapeInfo, yShapeInfoCast, canCastY); - oZ[offset] = OpType::op(oX[xOffset], y[offset]); - } - } - } - else { - uint tadShapeShapeInfoCast[MAX_RANK]; - uint tadShapeInfoZCast[MAX_RANK]; - uint yShapeInfoCast[MAX_RANK]; - bool canCastX = sd::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, tadShapeShapeInfoCast); - bool canCastY = sd::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); - bool canCastZ = sd::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast); - - for (auto i = start; i < stop; i++) { - auto oZ = z + zTadOffset[i]; - auto oX = x + tadOffsets[i]; - - PRAGMA_OMP_SIMD - for (unsigned int f = 0; f < tadLength; f++) { - auto xOffset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX); - auto yOffset = shape::indexOffset(f, yShapeInfo, yShapeInfoCast, canCastY); - auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ); - oZ[zOffset] = OpType::op(oX[xOffset], y[yOffset]); - } - } - } - } +template +template +void Broadcast::exec( + const void *vx, const Nd4jLong *xShapeInfo, const void *vy, + const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo, + int *dimension, int dimensionLength, const Nd4jLong *xTadShapeInfo, + const Nd4jLong *xTadOffset, const Nd4jLong *zTadShapeInfo, + const Nd4jLong *zTadOffset, sd::LoopKind::Kind loopKind, uint64_t start, + uint64_t stop) { + auto x = reinterpret_cast(vx); + auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + + // decompose in to several sub tads after + // moving all dimensions (in sorted order) + // to the back. + // permuted version of the x shape info for setting up the tad problem + auto xTadShapeShapeInfo = xTadShapeInfo; + auto tadOffsets = xTadOffset; + + if (xTadShapeInfo == nullptr || tadOffsets == nullptr) { + auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions( + xShapeInfo, dimension, dimensionLength); + + xTadShapeShapeInfo = tadPack.primaryShapeInfo(); + tadOffsets = tadPack.primaryOffsets(); + } + + // int *resultStride = shape::stride(xTadShapeShapeInfo); + unsigned int tadLength = + shape::length(xTadShapeShapeInfo); // shape::length(xTadShapeShapeInfo); + unsigned int tads = shape::length(xShapeInfo) / tadLength; + + if (zTadShapeInfo == nullptr) { + zTadShapeInfo = xTadShapeShapeInfo; + zTadOffset = tadOffsets; + } + + auto lenZ = shape::length(zTadShapeInfo); + auto lenY = shape::length(yShapeInfo); + + auto xEws = shape::elementWiseStride(xTadShapeShapeInfo); + auto yEws = shape::elementWiseStride(yShapeInfo); + auto zEws = shape::elementWiseStride(zTadShapeInfo); + + const sd::LoopKind::Kind kindOfLoop = + (loopKind == sd::LoopKind::BROADCAST_SCALAR_X || + loopKind == sd::LoopKind::BROADCAST_SCALAR_Y || + loopKind == sd::LoopKind::BROADCAST_3D || + loopKind == sd::LoopKind::BROADCAST_4D || + loopKind == sd::LoopKind::BROADCAST_5D) + ? loopKind + : sd::LoopKind::deduceKindOfLoopXYZ(xTadShapeShapeInfo, yShapeInfo, + zTadShapeInfo); + + if (kindOfLoop == sd::LoopKind::EWS1) { + for (auto i = start; i < stop; i++) { + auto oX = x + tadOffsets[i]; + auto oZ = z + zTadOffset[i]; + + PRAGMA_OMP_SIMD + for (unsigned int f = 0; f < tadLength; f++) + oZ[f] = OpType::op(oX[f], y[f]); + } + } else if (kindOfLoop == sd::LoopKind::EWSNONZERO) { + for (auto i = start; i < stop; i++) { + auto oX = x + tadOffsets[i]; + auto oZ = z + zTadOffset[i]; + + PRAGMA_OMP_SIMD + for (unsigned int f = 0; f < tadLength; f++) + oZ[f * zEws] = OpType::op(oX[f * xEws], y[f * yEws]); + } + } else if (kindOfLoop == sd::LoopKind::BROADCAST_SCALAR_X) { + // this loop effectively turns broadcast into series of scalar ops + auto loopLength = yShapeInfo[shape::rank(yShapeInfo)]; + for (auto i = start; i < stop; i++) { + auto oY = y + (i * loopLength); + auto oZ = z + (i * loopLength); + const auto oX = x[i]; - template - template - void Broadcast::execInverse(const void *vx, const Nd4jLong *xShapeInfo, - const void *vy, const Nd4jLong *yShapeInfo, - void *vz, const Nd4jLong *zShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *yTadShapeInfo, const Nd4jLong *yTadOffset, - const Nd4jLong *zTadShapeInfo, const Nd4jLong *zTadOffset, - uint64_t start, uint64_t stop) { - - auto x = reinterpret_cast(vx); - auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); - - //decompose in to several sub tads after - //moving all dimensions (in sorted order) - //to the back. - //permuted version of the x shape info for setting up the tad problem - auto yTadShapeShapeInfo = yTadShapeInfo; - auto tadOffsets = yTadOffset; - - if (yTadShapeInfo == nullptr || tadOffsets == nullptr) { - auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(yShapeInfo, dimension, dimensionLength); - - yTadShapeShapeInfo = tadPack.primaryShapeInfo(); - tadOffsets = tadPack.primaryOffsets(); - } - - //int *resultStride = shape::stride(yTadShapeShapeInfo); - unsigned int tadLength = shape::length(yTadShapeShapeInfo); - unsigned int tads = shape::length(yShapeInfo) / tadLength; - - if (zTadShapeInfo == nullptr) { - zTadShapeInfo = yTadShapeShapeInfo; - zTadOffset = tadOffsets; - } - - auto lenZ = shape::length(zTadShapeInfo); - auto lenX = shape::length(xShapeInfo); - - int tadsPerThread = tads / TAD_THRESHOLD; - int threads = sd::math::nd4j_max(1, tadsPerThread); - threads = sd::math::nd4j_min(threads, sd::Environment::getInstance().maxThreads()); - - auto yEws = shape::elementWiseStride(yTadShapeShapeInfo); - auto xEws = shape::elementWiseStride(xShapeInfo); - auto zEws = shape::elementWiseStride(zTadShapeInfo); - - const sd::LoopKind::Kind kindOfLoop = sd::LoopKind::deduceKindOfLoopXYZ(yTadShapeShapeInfo, xShapeInfo, zTadShapeInfo); - - if(kindOfLoop == sd::LoopKind::EWS1) { - for (auto i = start; i < stop; i++) { - auto oY = y + tadOffsets[i]; - auto oZ = z + zTadOffset[i]; - - PRAGMA_OMP_SIMD - for (unsigned int f = 0; f < tadLength; f++) - oZ[f] = OpType::op(x[f], oY[f]); - } - } - else if(kindOfLoop == sd::LoopKind::EWSNONZERO) { - for (auto i = start; i < stop; i++) { - auto oY = y + tadOffsets[i]; - auto oZ = z + zTadOffset[i]; - - PRAGMA_OMP_SIMD - for (unsigned int f = 0; f < tadLength; f++) - oZ[f * zEws] = OpType::op(x[f * xEws], oY[f * yEws]); - }; - } - else if(shape::haveSameShapeAndStrides(yTadShapeShapeInfo, xShapeInfo) && shape::haveSameShapeAndStrides(yTadShapeShapeInfo, zTadShapeInfo)) { - uint tadShapeShapeInfoCast[MAX_RANK]; - bool canCastY = sd::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast); - - for (auto i = start; i < stop; i++) { - auto oY = x + tadOffsets[i]; - auto oZ = z + zTadOffset[i]; - - PRAGMA_OMP_SIMD - for (unsigned int f = 0; f < tadLength; f++) { - auto offset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY); - oZ[offset] = OpType::op(x[offset], oY[offset]); - } - }; - } - else if(shape::haveSameShapeAndStrides(yTadShapeShapeInfo, xShapeInfo)) { - uint tadShapeShapeInfoCast[MAX_RANK]; - uint tadShapeInfoZCast[MAX_RANK]; - bool canCastY = sd::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast); - bool canCastZ = sd::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast); - - for (auto i = start; i < stop; i++) { - auto oZ = z + zTadOffset[i]; - auto oY = y + tadOffsets[i]; - - PRAGMA_OMP_SIMD - for (unsigned int f = 0; f < tadLength; f++) { - auto offset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY); - auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ); - oZ[zOffset] = OpType::op(x[offset], oY[offset]); - } - }; - } - else if(shape::haveSameShapeAndStrides(yTadShapeShapeInfo, zTadShapeInfo)) { - uint tadShapeShapeInfoCast[MAX_RANK]; - uint xShapeInfoCast[MAX_RANK]; - bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); - bool canCastY = sd::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast); - - for (auto i = start; i < stop; i++) { - auto oZ = z + zTadOffset[i]; - auto oY = y + tadOffsets[i]; - - PRAGMA_OMP_SIMD - for (unsigned int f = 0; f < tadLength; f++) { - auto offset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY); - auto xOffset = shape::indexOffset(f, yShapeInfo, xShapeInfoCast, canCastX); - oZ[offset] = OpType::op(x[xOffset], oY[offset]); - } - }; - } - else if(shape::haveSameShapeAndStrides(xShapeInfo, zTadShapeInfo)) { - uint tadShapeShapeInfoCast[MAX_RANK]; - uint xShapeInfoCast[MAX_RANK]; - bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); - bool canCastY = sd::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast); - - for (auto i = start; i < stop; i++) { - auto oZ = z + zTadOffset[i]; - auto oY = y + tadOffsets[i]; - - PRAGMA_OMP_SIMD - for (unsigned int f = 0; f < tadLength; f++) { - auto yOffset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY); - auto offset = shape::indexOffset(f, xShapeInfo, xShapeInfoCast, canCastX); - oZ[offset] = OpType::op(x[offset], oY[yOffset]); - } - }; - } - else { - uint tadShapeShapeInfoCast[MAX_RANK]; - uint tadShapeInfoZCast[MAX_RANK]; - uint xShapeInfoCast[MAX_RANK]; - bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); - bool canCastY = sd::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast); - bool canCastZ = sd::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast); - - for (auto i = start; i < stop; i++) { - auto oZ = z + zTadOffset[i]; - auto oY = y + tadOffsets[i]; - - PRAGMA_OMP_SIMD - for (unsigned int f = 0; f < tadLength; f++) { - auto xOffset = shape::indexOffset(f, xShapeInfo, xShapeInfoCast, canCastX); - auto yOffset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY); - auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ); - oZ[zOffset] = OpType::op(x[xOffset], oY[yOffset]); - } - }; - } - } + PRAGMA_OMP_SIMD + for (Nd4jLong f = 0; f < loopLength; f++) oZ[f] = OpType::op(oX, oY[f]); + } + } else if (kindOfLoop == sd::LoopKind::BROADCAST_SCALAR_Y) { + // this loop effectively turns broadcast into series of scalar ops + auto loopLength = xShapeInfo[shape::rank(xShapeInfo)]; + for (auto i = start; i < stop; i++) { + auto oX = x + (i * loopLength); + auto oZ = z + (i * loopLength); -//////////////////////////////////////////////////////////////////////// -template - void Broadcast::exec(const int opNum, const void *x, const Nd4jLong *xShapeInfo, const void *y, const Nd4jLong *yShapeInfo, void *z, const Nd4jLong *zShapeInfo) { + const auto oY = y[i]; - DISPATCH_BY_OPNUM_TTT(exec, PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo), BROADCAST_OPS); -} + PRAGMA_OMP_SIMD + for (Nd4jLong f = 0; f < loopLength; f++) oZ[f] = OpType::op(oX[f], oY); + } + } else if (kindOfLoop == sd::LoopKind::BROADCAST_3D) { + int xRank = shape::rank(xShapeInfo); + int yRank = shape::rank(yShapeInfo); + + auto xStrides = shape::stride(xShapeInfo); + auto zStrides = shape::stride(zShapeInfo); + + Nd4jLong yStrides[3] = {0, 0, 0}; + sd::ShapeUtils::copyCertainStridesFromShapeInfo( + yShapeInfo, xRank, dimensionLength, dimension, yStrides); + + uint64_t nSize1 = shape::sizeAt(zShapeInfo, 1); + uint64_t nSize2 = shape::sizeAt(zShapeInfo, 2); + + for (auto index0 = start; index0 < stop; index0++) { + PRAGMA_OMP_SIMD + for (uint64_t index1 = 0; index1 < nSize1; index1++) { + for (uint64_t index2 = 0; index2 < nSize2; index2++) { + auto rX = x + (xStrides[0] * index0 + xStrides[1] * index1 + + xStrides[2] * index2); + auto rY = y + (yStrides[0] * index0 + yStrides[1] * index1 + + yStrides[2] * index2); + auto rZ = z + (zStrides[0] * index0 + zStrides[1] * index1 + + zStrides[2] * index2); + *rZ = OpType::op(*rX, *rY); + } + } + } -//////////////////////////////////////////////////////////////////////// -template -static void execRank1(const X *x, const Nd4jLong *xShapeInfo, const Y *y, const Nd4jLong *yShapeInfo, Z* z, const Nd4jLong *zShapeInfo) { + } else if (kindOfLoop == sd::LoopKind::BROADCAST_4D) { + int xRank = shape::rank(xShapeInfo); + int yRank = shape::rank(yShapeInfo); + + auto xStrides = shape::stride(xShapeInfo); + auto zStrides = shape::stride(zShapeInfo); + + Nd4jLong yStrides[4] = {0, 0, 0, 0}; + sd::ShapeUtils::copyCertainStridesFromShapeInfo( + yShapeInfo, xRank, dimensionLength, dimension, yStrides); + + uint64_t nSize1 = shape::sizeAt(zShapeInfo, 1); + uint64_t nSize2 = shape::sizeAt(zShapeInfo, 2); + uint64_t nSize3 = shape::sizeAt(zShapeInfo, 3); + + for (auto i = start; i < stop; i++) { + uint64_t index0 = i / nSize1; + uint64_t index1 = i % nSize1; + + PRAGMA_OMP_SIMD + for (uint64_t index2 = 0; index2 < nSize2; index2++) { + for (uint64_t index3 = 0; index3 < nSize3; index3++) { + auto rX = x + (xStrides[0] * index0 + xStrides[1] * index1 + + xStrides[2] * index2 + xStrides[3] * index3); + auto rY = y + (yStrides[0] * index0 + yStrides[1] * index1 + + yStrides[2] * index2 + yStrides[3] * index3); + auto rZ = z + (zStrides[0] * index0 + zStrides[1] * index1 + + zStrides[2] * index2 + zStrides[3] * index3); + *rZ = OpType::op(*rX, *rY); + } + } + } - uint zAxis0 = shape::sizeAt(zShapeInfo, 0); - Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, 0); - Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, 0); - Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, 0); + } else if (kindOfLoop == sd::LoopKind::BROADCAST_5D) { + int xRank = shape::rank(xShapeInfo); + int yRank = shape::rank(yShapeInfo); + + auto xStrides = shape::stride(xShapeInfo); + auto zStrides = shape::stride(zShapeInfo); + + Nd4jLong yStrides[5] = {0, 0, 0, 0, 0}; + sd::ShapeUtils::copyCertainStridesFromShapeInfo( + yShapeInfo, xRank, dimensionLength, dimension, yStrides); + + uint32_t nSize1 = shape::sizeAt(zShapeInfo, 1); + uint32_t nSize2 = shape::sizeAt(zShapeInfo, 2); + uint32_t nSize3 = shape::sizeAt(zShapeInfo, 3); + uint32_t nSize4 = shape::sizeAt(zShapeInfo, 4); + + for (auto i = start; i < stop; i++) { + uint32_t index0 = i / nSize1; + uint32_t index1 = i % nSize1; + + PRAGMA_OMP_SIMD + for (uint32_t index2 = 0; index2 < nSize2; index2++) { + for (uint32_t index3 = 0; index3 < nSize3; index3++) { + for (uint32_t index4 = 0; index4 < nSize4; index4++) { + auto rX = x + (xStrides[0] * index0 + xStrides[1] * index1 + + xStrides[2] * index2 + xStrides[3] * index3 + + xStrides[4] * index4); + auto rY = y + (yStrides[0] * index0 + yStrides[1] * index1 + + yStrides[2] * index2 + yStrides[3] * index3 + + yStrides[4] * index4); + auto rZ = z + (zStrides[0] * index0 + zStrides[1] * index1 + + zStrides[2] * index2 + zStrides[3] * index3 + + zStrides[4] * index4); + + *rZ = OpType::op(*rX, *rY); + } + } + } + } - auto func = PRAGMA_THREADS_FOR{ + } else if (shape::haveSameShapeAndStrides(xTadShapeShapeInfo, yShapeInfo) && + shape::haveSameShapeAndStrides(xTadShapeShapeInfo, + zTadShapeInfo)) { + uint tadShapeShapeInfoCast[MAX_RANK]; + bool canCastX = sd::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, + tadShapeShapeInfoCast); + + for (auto i = start; i < stop; i++) { + auto oX = x + tadOffsets[i]; + auto oZ = z + zTadOffset[i]; + + PRAGMA_OMP_SIMD + for (unsigned int f = 0; f < tadLength; f++) { + auto offset = shape::indexOffset(f, xTadShapeShapeInfo, + tadShapeShapeInfoCast, canCastX); + oZ[offset] = OpType::op(oX[offset], y[offset]); + } + } + } else if (shape::haveSameShapeAndStrides(xTadShapeShapeInfo, yShapeInfo)) { + uint tadShapeShapeInfoCast[MAX_RANK]; + uint tadShapeInfoZCast[MAX_RANK]; + bool canCastX = sd::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, + tadShapeShapeInfoCast); + bool canCastZ = + sd::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast); + + for (auto i = start; i < stop; i++) { + auto oZ = z + zTadOffset[i]; + auto oX = x + tadOffsets[i]; + + PRAGMA_OMP_SIMD + for (unsigned int f = 0; f < tadLength; f++) { + auto offset = shape::indexOffset(f, xTadShapeShapeInfo, + tadShapeShapeInfoCast, canCastX); + auto zOffset = + shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ); + oZ[zOffset] = OpType::op(oX[offset], y[offset]); + } + } + } else if (shape::haveSameShapeAndStrides(xTadShapeShapeInfo, + zTadShapeInfo)) { + uint tadShapeShapeInfoCast[MAX_RANK]; + uint yShapeInfoCast[MAX_RANK]; + bool canCastX = sd::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, + tadShapeShapeInfoCast); + bool canCastY = + sd::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); + + for (auto i = start; i < stop; i++) { + auto oZ = z + zTadOffset[i]; + auto oX = x + tadOffsets[i]; + + PRAGMA_OMP_SIMD + for (unsigned int f = 0; f < tadLength; f++) { + auto offset = shape::indexOffset(f, xTadShapeShapeInfo, + tadShapeShapeInfoCast, canCastX); + auto yOffset = + shape::indexOffset(f, yShapeInfo, yShapeInfoCast, canCastY); + oZ[offset] = OpType::op(oX[offset], y[yOffset]); + } + } + } else if (shape::haveSameShapeAndStrides(yShapeInfo, zTadShapeInfo)) { + uint tadShapeShapeInfoCast[MAX_RANK]; + uint yShapeInfoCast[MAX_RANK]; + bool canCastX = sd::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, + tadShapeShapeInfoCast); + bool canCastY = + sd::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); + + for (auto i = start; i < stop; i++) { + auto oZ = z + zTadOffset[i]; + auto oX = x + tadOffsets[i]; + + PRAGMA_OMP_SIMD + for (unsigned int f = 0; f < tadLength; f++) { + auto xOffset = shape::indexOffset(f, xTadShapeShapeInfo, + tadShapeShapeInfoCast, canCastX); + auto offset = + shape::indexOffset(f, yShapeInfo, yShapeInfoCast, canCastY); + oZ[offset] = OpType::op(oX[xOffset], y[offset]); + } + } + } else { + uint tadShapeShapeInfoCast[MAX_RANK]; + uint tadShapeInfoZCast[MAX_RANK]; + uint yShapeInfoCast[MAX_RANK]; + bool canCastX = sd::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, + tadShapeShapeInfoCast); + bool canCastY = + sd::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); + bool canCastZ = + sd::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast); + + for (auto i = start; i < stop; i++) { + auto oZ = z + zTadOffset[i]; + auto oX = x + tadOffsets[i]; + + PRAGMA_OMP_SIMD + for (unsigned int f = 0; f < tadLength; f++) { + auto xOffset = shape::indexOffset(f, xTadShapeShapeInfo, + tadShapeShapeInfoCast, canCastX); + auto yOffset = + shape::indexOffset(f, yShapeInfo, yShapeInfoCast, canCastY); + auto zOffset = + shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ); + oZ[zOffset] = OpType::op(oX[xOffset], y[yOffset]); + } + } + } +} - if(zStrd0 == 1 && xStrd0 == 1 && yStrd0 == 0) { - for (auto i0 = start; i0 < stop; ++i0) - z[i0] = OpType::op(x[i0], *y); - } - else if(zStrd0 == 1 && xStrd0 == 0 && yStrd0 == 1) { - for (auto i0 = start; i0 < stop; ++i0) - z[i0] = OpType::op(*x, y[i0]); - } - else if(zStrd0 == 1 && xStrd0 == 1 && yStrd0 == 1) { - for (auto i0 = start; i0 < stop; ++i0) - z[i0] = OpType::op(x[i0], y[i0]); - } - else { - for (auto i0 = start; i0 < stop; ++i0) - z[i0 * zStrd0] = OpType::op(x[i0 * xStrd0], y[i0 * yStrd0]); - } +template +template +void Broadcast::execInverse( + const void *vx, const Nd4jLong *xShapeInfo, const void *vy, + const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo, + int *dimension, int dimensionLength, const Nd4jLong *yTadShapeInfo, + const Nd4jLong *yTadOffset, const Nd4jLong *zTadShapeInfo, + const Nd4jLong *zTadOffset, uint64_t start, uint64_t stop) { + auto x = reinterpret_cast(vx); + auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + + // decompose in to several sub tads after + // moving all dimensions (in sorted order) + // to the back. + // permuted version of the x shape info for setting up the tad problem + auto yTadShapeShapeInfo = yTadShapeInfo; + auto tadOffsets = yTadOffset; + + if (yTadShapeInfo == nullptr || tadOffsets == nullptr) { + auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions( + yShapeInfo, dimension, dimensionLength); + + yTadShapeShapeInfo = tadPack.primaryShapeInfo(); + tadOffsets = tadPack.primaryOffsets(); + } + + // int *resultStride = shape::stride(yTadShapeShapeInfo); + unsigned int tadLength = shape::length(yTadShapeShapeInfo); + unsigned int tads = shape::length(yShapeInfo) / tadLength; + + if (zTadShapeInfo == nullptr) { + zTadShapeInfo = yTadShapeShapeInfo; + zTadOffset = tadOffsets; + } + + auto lenZ = shape::length(zTadShapeInfo); + auto lenX = shape::length(xShapeInfo); + + int tadsPerThread = tads / TAD_THRESHOLD; + int threads = sd::math::nd4j_max(1, tadsPerThread); + threads = sd::math::nd4j_min( + threads, sd::Environment::getInstance().maxThreads()); + + auto yEws = shape::elementWiseStride(yTadShapeShapeInfo); + auto xEws = shape::elementWiseStride(xShapeInfo); + auto zEws = shape::elementWiseStride(zTadShapeInfo); + + const sd::LoopKind::Kind kindOfLoop = sd::LoopKind::deduceKindOfLoopXYZ( + yTadShapeShapeInfo, xShapeInfo, zTadShapeInfo); + + if (kindOfLoop == sd::LoopKind::EWS1) { + for (auto i = start; i < stop; i++) { + auto oY = y + tadOffsets[i]; + auto oZ = z + zTadOffset[i]; + + PRAGMA_OMP_SIMD + for (unsigned int f = 0; f < tadLength; f++) + oZ[f] = OpType::op(x[f], oY[f]); + } + } else if (kindOfLoop == sd::LoopKind::EWSNONZERO) { + for (auto i = start; i < stop; i++) { + auto oY = y + tadOffsets[i]; + auto oZ = z + zTadOffset[i]; + + PRAGMA_OMP_SIMD + for (unsigned int f = 0; f < tadLength; f++) + oZ[f * zEws] = OpType::op(x[f * xEws], oY[f * yEws]); + }; + } else if (shape::haveSameShapeAndStrides(yTadShapeShapeInfo, xShapeInfo) && + shape::haveSameShapeAndStrides(yTadShapeShapeInfo, + zTadShapeInfo)) { + uint tadShapeShapeInfoCast[MAX_RANK]; + bool canCastY = sd::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, + tadShapeShapeInfoCast); + + for (auto i = start; i < stop; i++) { + auto oY = x + tadOffsets[i]; + auto oZ = z + zTadOffset[i]; + + PRAGMA_OMP_SIMD + for (unsigned int f = 0; f < tadLength; f++) { + auto offset = shape::indexOffset(f, yTadShapeShapeInfo, + tadShapeShapeInfoCast, canCastY); + oZ[offset] = OpType::op(x[offset], oY[offset]); + } }; - samediff::Threads::parallel_tad(func, 0, zAxis0); + } else if (shape::haveSameShapeAndStrides(yTadShapeShapeInfo, xShapeInfo)) { + uint tadShapeShapeInfoCast[MAX_RANK]; + uint tadShapeInfoZCast[MAX_RANK]; + bool canCastY = sd::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, + tadShapeShapeInfoCast); + bool canCastZ = + sd::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast); + + for (auto i = start; i < stop; i++) { + auto oZ = z + zTadOffset[i]; + auto oY = y + tadOffsets[i]; + + PRAGMA_OMP_SIMD + for (unsigned int f = 0; f < tadLength; f++) { + auto offset = shape::indexOffset(f, yTadShapeShapeInfo, + tadShapeShapeInfoCast, canCastY); + auto zOffset = + shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ); + oZ[zOffset] = OpType::op(x[offset], oY[offset]); + } + }; + } else if (shape::haveSameShapeAndStrides(yTadShapeShapeInfo, + zTadShapeInfo)) { + uint tadShapeShapeInfoCast[MAX_RANK]; + uint xShapeInfoCast[MAX_RANK]; + bool canCastX = + sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + bool canCastY = sd::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, + tadShapeShapeInfoCast); + + for (auto i = start; i < stop; i++) { + auto oZ = z + zTadOffset[i]; + auto oY = y + tadOffsets[i]; + + PRAGMA_OMP_SIMD + for (unsigned int f = 0; f < tadLength; f++) { + auto offset = shape::indexOffset(f, yTadShapeShapeInfo, + tadShapeShapeInfoCast, canCastY); + auto xOffset = + shape::indexOffset(f, yShapeInfo, xShapeInfoCast, canCastX); + oZ[offset] = OpType::op(x[xOffset], oY[offset]); + } + }; + } else if (shape::haveSameShapeAndStrides(xShapeInfo, zTadShapeInfo)) { + uint tadShapeShapeInfoCast[MAX_RANK]; + uint xShapeInfoCast[MAX_RANK]; + bool canCastX = + sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + bool canCastY = sd::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, + tadShapeShapeInfoCast); + + for (auto i = start; i < stop; i++) { + auto oZ = z + zTadOffset[i]; + auto oY = y + tadOffsets[i]; + + PRAGMA_OMP_SIMD + for (unsigned int f = 0; f < tadLength; f++) { + auto yOffset = shape::indexOffset(f, yTadShapeShapeInfo, + tadShapeShapeInfoCast, canCastY); + auto offset = + shape::indexOffset(f, xShapeInfo, xShapeInfoCast, canCastX); + oZ[offset] = OpType::op(x[offset], oY[yOffset]); + } + }; + } else { + uint tadShapeShapeInfoCast[MAX_RANK]; + uint tadShapeInfoZCast[MAX_RANK]; + uint xShapeInfoCast[MAX_RANK]; + bool canCastX = + sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + bool canCastY = sd::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, + tadShapeShapeInfoCast); + bool canCastZ = + sd::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast); + + for (auto i = start; i < stop; i++) { + auto oZ = z + zTadOffset[i]; + auto oY = y + tadOffsets[i]; + + PRAGMA_OMP_SIMD + for (unsigned int f = 0; f < tadLength; f++) { + auto xOffset = + shape::indexOffset(f, xShapeInfo, xShapeInfoCast, canCastX); + auto yOffset = shape::indexOffset(f, yTadShapeShapeInfo, + tadShapeShapeInfoCast, canCastY); + auto zOffset = + shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ); + oZ[zOffset] = OpType::op(x[xOffset], oY[yOffset]); + } + }; + } } //////////////////////////////////////////////////////////////////////// -template -static void execRank2(const X *x, const Nd4jLong *xShapeInfo, const Y *y, const Nd4jLong *yShapeInfo, Z* z, const Nd4jLong *zShapeInfo) { - - uint zAxis0 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1); - Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1); - Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1); - Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1); - - uint zAxis1 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0); - Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0); - Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0); - Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0); - - auto func = PRAGMA_THREADS_FOR{ - - for (auto i0 = start; i0 < stop; ++i0) { - - auto x0 = x + i0 * xStrd0; - auto y0 = y + i0 * yStrd0; - auto z0 = z + i0 * zStrd0; - - if(zStrd1 == 1 && xStrd1 == 1 && yStrd1 == 0) - for (uint i1 = 0; i1 < zAxis1; ++i1) - z0[i1] = OpType::op(x0[i1], *y0); - else if(zStrd1 == 1 && xStrd1 == 0 && yStrd1 == 1) - for (uint i1 = 0; i1 < zAxis1; ++i1) - z0[i1] = OpType::op(*x0, y0[i1]); - else if(zStrd1 == 1 && xStrd1 == 1 && yStrd1 == 1) - for (uint i1 = 0; i1 < zAxis1; ++i1) - z0[i1] = OpType::op(x0[i1], y0[i1]); - else - for (uint i1 = 0; i1 < zAxis1; ++i1) - z0[i1 * zStrd1] = OpType::op(x0[i1 * xStrd1], y0[i1 * yStrd1]); - } - }; +template +void Broadcast::exec(const int opNum, const void *x, + const Nd4jLong *xShapeInfo, const void *y, + const Nd4jLong *yShapeInfo, void *z, + const Nd4jLong *zShapeInfo) { + DISPATCH_BY_OPNUM_TTT( + exec, PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo), BROADCAST_OPS); +} - samediff::Threads::parallel_tad(func, 0, zAxis0); +//////////////////////////////////////////////////////////////////////// +template +static void execRank1(const X *x, const Nd4jLong *xShapeInfo, const Y *y, + const Nd4jLong *yShapeInfo, Z *z, + const Nd4jLong *zShapeInfo) { + uint zAxis0 = shape::sizeAt(zShapeInfo, 0); + Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, 0); + Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, 0); + Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, 0); + + auto func = PRAGMA_THREADS_FOR { + if (zStrd0 == 1 && xStrd0 == 1 && yStrd0 == 0) { + for (auto i0 = start; i0 < stop; ++i0) z[i0] = OpType::op(x[i0], *y); + } else if (zStrd0 == 1 && xStrd0 == 0 && yStrd0 == 1) { + for (auto i0 = start; i0 < stop; ++i0) z[i0] = OpType::op(*x, y[i0]); + } else if (zStrd0 == 1 && xStrd0 == 1 && yStrd0 == 1) { + for (auto i0 = start; i0 < stop; ++i0) z[i0] = OpType::op(x[i0], y[i0]); + } else { + for (auto i0 = start; i0 < stop; ++i0) + z[i0 * zStrd0] = OpType::op(x[i0 * xStrd0], y[i0 * yStrd0]); + } + }; + samediff::Threads::parallel_tad(func, 0, zAxis0); } //////////////////////////////////////////////////////////////////////// -template -static void execRank3(const X *x, const Nd4jLong *xShapeInfo, const Y *y, const Nd4jLong *yShapeInfo, Z* z, const Nd4jLong *zShapeInfo) { - - uint zAxis0 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2); - Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2); - Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2); - Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2); - - uint zAxis1 = shape::sizeAt(zShapeInfo, 1); - Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, 1); - Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, 1); - Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, 1); - - uint zAxis2 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0); - Nd4jLong xStrd2 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0); - Nd4jLong yStrd2 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0); - Nd4jLong zStrd2 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0); - - auto func = PRAGMA_THREADS_FOR_2D { - - for (auto i0 = start_x; i0 < stop_x; ++i0) { - for (auto i1 = start_y; i1 < stop_y; ++i1) { - - auto x1 = x + i0 * xStrd0 + i1 * xStrd1; - auto y1 = y + i0 * yStrd0 + i1 * yStrd1; - auto z1 = z + i0 * zStrd0 + i1 * zStrd1; - - if(zStrd2 == 1 && xStrd2 == 1 && yStrd2 == 0) - for (uint i2 = 0; i2 < zAxis2; ++i2) - z1[i2] = OpType::op(x1[i2], *y1); - else if(zStrd2 == 1 && xStrd2 == 0 && yStrd2 == 1) - for (uint i2 = 0; i2 < zAxis2; ++i2) - z1[i2] = OpType::op(*x1, y1[i2]); - else if(zStrd2 == 1 && xStrd2 == 1 && yStrd2 == 1) - for (uint i2 = 0; i2 < zAxis2; ++i2) - z1[i2] = OpType::op(x1[i2], y1[i2]); - else - for (uint i2 = 0; i2 < zAxis2; ++i2) - z1[i2 * zStrd2] = OpType::op(x1[i2 * xStrd2], y1[i2 * yStrd2]); - } - } - }; +template +static void execRank2(const X *x, const Nd4jLong *xShapeInfo, const Y *y, + const Nd4jLong *yShapeInfo, Z *z, + const Nd4jLong *zShapeInfo) { + uint zAxis0 = + shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1); + Nd4jLong xStrd0 = + shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1); + Nd4jLong yStrd0 = + shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1); + Nd4jLong zStrd0 = + shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1); + + uint zAxis1 = + shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0); + Nd4jLong xStrd1 = + shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0); + Nd4jLong yStrd1 = + shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0); + Nd4jLong zStrd1 = + shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0); + + auto func = PRAGMA_THREADS_FOR { + for (auto i0 = start; i0 < stop; ++i0) { + auto x0 = x + i0 * xStrd0; + auto y0 = y + i0 * yStrd0; + auto z0 = z + i0 * zStrd0; + + if (zStrd1 == 1 && xStrd1 == 1 && yStrd1 == 0) + for (uint i1 = 0; i1 < zAxis1; ++i1) z0[i1] = OpType::op(x0[i1], *y0); + else if (zStrd1 == 1 && xStrd1 == 0 && yStrd1 == 1) + for (uint i1 = 0; i1 < zAxis1; ++i1) z0[i1] = OpType::op(*x0, y0[i1]); + else if (zStrd1 == 1 && xStrd1 == 1 && yStrd1 == 1) + for (uint i1 = 0; i1 < zAxis1; ++i1) + z0[i1] = OpType::op(x0[i1], y0[i1]); + else + for (uint i1 = 0; i1 < zAxis1; ++i1) + z0[i1 * zStrd1] = OpType::op(x0[i1 * xStrd1], y0[i1 * yStrd1]); + } + }; - samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1); + samediff::Threads::parallel_tad(func, 0, zAxis0); } //////////////////////////////////////////////////////////////////////// -template -static void execRank4(const X *x, const Nd4jLong *xShapeInfo, const Y *y, const Nd4jLong *yShapeInfo, Z* z, const Nd4jLong *zShapeInfo) { - - uint zAxis0 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3); - Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3); - Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3); - Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3); - - uint zAxis1 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2); - Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2); - Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2); - Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2); - - uint zAxis2 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1); - Nd4jLong xStrd2 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1); - Nd4jLong yStrd2 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1); - Nd4jLong zStrd2 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1); - - uint zAxis3 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0); - Nd4jLong xStrd3 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0); - Nd4jLong yStrd3 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0); - Nd4jLong zStrd3 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0); - - auto func = PRAGMA_THREADS_FOR_3D { - - for (auto i0 = start_x; i0 < stop_x; ++i0) { - for (auto i1 = start_y; i1 < stop_y; ++i1) { - for (auto i2 = start_z; i2 < stop_z; ++i2) { - - auto x2 = x + i0 * xStrd0 + i1 * xStrd1 + i2 * xStrd2; - auto y2 = y + i0 * yStrd0 + i1 * yStrd1 + i2 * yStrd2; - auto z2 = z + i0 * zStrd0 + i1 * zStrd1 + i2 * zStrd2; - - if(zStrd3 == 1 && xStrd3 == 1 && yStrd3 == 0) - for (uint i3 = 0; i3 < zAxis3; ++i3) - z2[i3] = OpType::op(x2[i3], *y2); - else if(zStrd3 == 1 && xStrd3 == 0 && yStrd3 == 1) - for (uint i3 = 0; i3 < zAxis3; ++i3) - z2[i3] = OpType::op(*x2, y2[i3]); - else if(zStrd3 == 1 && xStrd3 == 1 && yStrd3 == 1) - for (uint i3 = 0; i3 < zAxis3; ++i3) - z2[i3] = OpType::op(x2[i3], y2[i3]); - else - for (uint i3 = 0; i3 < zAxis3; ++i3) - z2[i3 * zStrd3] = OpType::op(x2[i3 * xStrd3], y2[i3 * yStrd3]); - } - } - } - }; +template +static void execRank3(const X *x, const Nd4jLong *xShapeInfo, const Y *y, + const Nd4jLong *yShapeInfo, Z *z, + const Nd4jLong *zShapeInfo) { + uint zAxis0 = + shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2); + Nd4jLong xStrd0 = + shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2); + Nd4jLong yStrd0 = + shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2); + Nd4jLong zStrd0 = + shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2); + + uint zAxis1 = shape::sizeAt(zShapeInfo, 1); + Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, 1); + Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, 1); + Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, 1); + + uint zAxis2 = + shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0); + Nd4jLong xStrd2 = + shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0); + Nd4jLong yStrd2 = + shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0); + Nd4jLong zStrd2 = + shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0); + + auto func = PRAGMA_THREADS_FOR_2D { + for (auto i0 = start_x; i0 < stop_x; ++i0) { + for (auto i1 = start_y; i1 < stop_y; ++i1) { + auto x1 = x + i0 * xStrd0 + i1 * xStrd1; + auto y1 = y + i0 * yStrd0 + i1 * yStrd1; + auto z1 = z + i0 * zStrd0 + i1 * zStrd1; + + if (zStrd2 == 1 && xStrd2 == 1 && yStrd2 == 0) + for (uint i2 = 0; i2 < zAxis2; ++i2) z1[i2] = OpType::op(x1[i2], *y1); + else if (zStrd2 == 1 && xStrd2 == 0 && yStrd2 == 1) + for (uint i2 = 0; i2 < zAxis2; ++i2) z1[i2] = OpType::op(*x1, y1[i2]); + else if (zStrd2 == 1 && xStrd2 == 1 && yStrd2 == 1) + for (uint i2 = 0; i2 < zAxis2; ++i2) + z1[i2] = OpType::op(x1[i2], y1[i2]); + else + for (uint i2 = 0; i2 < zAxis2; ++i2) + z1[i2 * zStrd2] = OpType::op(x1[i2 * xStrd2], y1[i2 * yStrd2]); + } + } + }; - samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1, 0,zAxis2,1); + samediff::Threads::parallel_for(func, 0, zAxis0, 1, 0, zAxis1, 1); } //////////////////////////////////////////////////////////////////////// -template -static void execRank5(const X *x, const Nd4jLong *xShapeInfo, const Y *y, const Nd4jLong *yShapeInfo, Z* z, const Nd4jLong *zShapeInfo) { - - uint zAxis0 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4); - Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4); - Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4); - Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4); - - uint zAxis1 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3); - Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3); - Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3); - Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3); - - uint zAxis2 = shape::sizeAt(zShapeInfo, 2); - Nd4jLong xStrd2 = shape::strideAt(xShapeInfo, 2); - Nd4jLong yStrd2 = shape::strideAt(yShapeInfo, 2); - Nd4jLong zStrd2 = shape::strideAt(zShapeInfo, 2); - - uint zAxis3 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1); - Nd4jLong xStrd3 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1); - Nd4jLong yStrd3 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1); - Nd4jLong zStrd3 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1); - - uint zAxis4 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0); - Nd4jLong xStrd4 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0); - Nd4jLong yStrd4 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0); - Nd4jLong zStrd4 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0); - - auto func = PRAGMA_THREADS_FOR_3D { - - for (auto i0 = start_x; i0 < stop_x; ++i0) { - for (auto i1 = start_y; i1 < stop_y; ++i1) { - for (auto i2 = start_z; i2 < stop_z; ++i2) { - for (uint i3 = 0; i3 < zAxis3; ++i3) { - - auto x3 = x + i0 * xStrd0 + i1 * xStrd1 + i2 * xStrd2 + i3 * xStrd3; - auto y3 = y + i0 * yStrd0 + i1 * yStrd1 + i2 * yStrd2 + i3 * yStrd3; - auto z3 = z + i0 * zStrd0 + i1 * zStrd1 + i2 * zStrd2 + i3 * zStrd3; - - if(zStrd4 == 1 && xStrd4 == 1 && yStrd4 == 0) - for (uint i4 = 0; i4 < zAxis4; ++i4) - z3[i4] = OpType::op(x3[i4], *y3); - else if(zStrd4 == 1 && xStrd4 == 0 && yStrd4 == 1) - for (uint i4 = 0; i4 < zAxis4; ++i4) - z3[i4] = OpType::op(*x3, y3[i4]); - else if(zStrd4 == 1 && xStrd4 == 1 && yStrd4 == 1) - for (uint i4 = 0; i4 < zAxis4; ++i4) - z3[i4] = OpType::op(x3[i4], y3[i4]); - else - for (uint i4 = 0; i4 < zAxis4; ++i4) - z3[i4 * zStrd4] = OpType::op(x3[i4 * xStrd4], y3[i4 * yStrd4]); - } - } - } +template +static void execRank4(const X *x, const Nd4jLong *xShapeInfo, const Y *y, + const Nd4jLong *yShapeInfo, Z *z, + const Nd4jLong *zShapeInfo) { + uint zAxis0 = + shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3); + Nd4jLong xStrd0 = + shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3); + Nd4jLong yStrd0 = + shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3); + Nd4jLong zStrd0 = + shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3); + + uint zAxis1 = + shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2); + Nd4jLong xStrd1 = + shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2); + Nd4jLong yStrd1 = + shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2); + Nd4jLong zStrd1 = + shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2); + + uint zAxis2 = + shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1); + Nd4jLong xStrd2 = + shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1); + Nd4jLong yStrd2 = + shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1); + Nd4jLong zStrd2 = + shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1); + + uint zAxis3 = + shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0); + Nd4jLong xStrd3 = + shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0); + Nd4jLong yStrd3 = + shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0); + Nd4jLong zStrd3 = + shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0); + + auto func = PRAGMA_THREADS_FOR_3D { + for (auto i0 = start_x; i0 < stop_x; ++i0) { + for (auto i1 = start_y; i1 < stop_y; ++i1) { + for (auto i2 = start_z; i2 < stop_z; ++i2) { + auto x2 = x + i0 * xStrd0 + i1 * xStrd1 + i2 * xStrd2; + auto y2 = y + i0 * yStrd0 + i1 * yStrd1 + i2 * yStrd2; + auto z2 = z + i0 * zStrd0 + i1 * zStrd1 + i2 * zStrd2; + + if (zStrd3 == 1 && xStrd3 == 1 && yStrd3 == 0) + for (uint i3 = 0; i3 < zAxis3; ++i3) + z2[i3] = OpType::op(x2[i3], *y2); + else if (zStrd3 == 1 && xStrd3 == 0 && yStrd3 == 1) + for (uint i3 = 0; i3 < zAxis3; ++i3) + z2[i3] = OpType::op(*x2, y2[i3]); + else if (zStrd3 == 1 && xStrd3 == 1 && yStrd3 == 1) + for (uint i3 = 0; i3 < zAxis3; ++i3) + z2[i3] = OpType::op(x2[i3], y2[i3]); + else + for (uint i3 = 0; i3 < zAxis3; ++i3) + z2[i3 * zStrd3] = OpType::op(x2[i3 * xStrd3], y2[i3 * yStrd3]); } - }; + } + } + }; - samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1, 0,zAxis2,1); + samediff::Threads::parallel_for(func, 0, zAxis0, 1, 0, zAxis1, 1, 0, zAxis2, + 1); } //////////////////////////////////////////////////////////////////////// -template -static void execDefault(const X *x, const Nd4jLong *xShapeInfo, const Y *y, const Nd4jLong *yShapeInfo, Z* z, const Nd4jLong *zShapeInfo) { - - const bool xzSameOffsets = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo); - const bool yzSameOffsets = shape::haveSameShapeAndStrides(yShapeInfo, zShapeInfo); - - const int rank = shape::rank(zShapeInfo); // xRank = yRank = zRank - - auto func = PRAGMA_THREADS_FOR{ - - int xCoords[MAX_RANK], yCoords[MAX_RANK], zCoords[MAX_RANK]; - - for (auto i = start; i < stop; ++i) { - - shape::index2coordsCPU(start, i, zShapeInfo, zCoords); - - for (uint j = 0; j < rank; ++j) { - xCoords[j] = shape::sizeAt(xShapeInfo, j) == 1 ? 0 : zCoords[j]; - yCoords[j] = shape::sizeAt(yShapeInfo, j) == 1 ? 0 : zCoords[j]; - } - - const auto zOffset = shape::getOffset(zShapeInfo, zCoords); - const auto xOffset = xzSameOffsets ? zOffset : shape::getOffset(xShapeInfo, xCoords); - const auto yOffset = yzSameOffsets ? zOffset : shape::getOffset(yShapeInfo, yCoords); - - z[zOffset] = OpType::op(x[xOffset], y[yOffset]); +template +static void execRank5(const X *x, const Nd4jLong *xShapeInfo, const Y *y, + const Nd4jLong *yShapeInfo, Z *z, + const Nd4jLong *zShapeInfo) { + uint zAxis0 = + shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4); + Nd4jLong xStrd0 = + shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4); + Nd4jLong yStrd0 = + shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4); + Nd4jLong zStrd0 = + shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4); + + uint zAxis1 = + shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3); + Nd4jLong xStrd1 = + shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3); + Nd4jLong yStrd1 = + shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3); + Nd4jLong zStrd1 = + shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3); + + uint zAxis2 = shape::sizeAt(zShapeInfo, 2); + Nd4jLong xStrd2 = shape::strideAt(xShapeInfo, 2); + Nd4jLong yStrd2 = shape::strideAt(yShapeInfo, 2); + Nd4jLong zStrd2 = shape::strideAt(zShapeInfo, 2); + + uint zAxis3 = + shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1); + Nd4jLong xStrd3 = + shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1); + Nd4jLong yStrd3 = + shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1); + Nd4jLong zStrd3 = + shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1); + + uint zAxis4 = + shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0); + Nd4jLong xStrd4 = + shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0); + Nd4jLong yStrd4 = + shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0); + Nd4jLong zStrd4 = + shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0); + + auto func = PRAGMA_THREADS_FOR_3D { + for (auto i0 = start_x; i0 < stop_x; ++i0) { + for (auto i1 = start_y; i1 < stop_y; ++i1) { + for (auto i2 = start_z; i2 < stop_z; ++i2) { + for (uint i3 = 0; i3 < zAxis3; ++i3) { + auto x3 = x + i0 * xStrd0 + i1 * xStrd1 + i2 * xStrd2 + i3 * xStrd3; + auto y3 = y + i0 * yStrd0 + i1 * yStrd1 + i2 * yStrd2 + i3 * yStrd3; + auto z3 = z + i0 * zStrd0 + i1 * zStrd1 + i2 * zStrd2 + i3 * zStrd3; + + if (zStrd4 == 1 && xStrd4 == 1 && yStrd4 == 0) + for (uint i4 = 0; i4 < zAxis4; ++i4) + z3[i4] = OpType::op(x3[i4], *y3); + else if (zStrd4 == 1 && xStrd4 == 0 && yStrd4 == 1) + for (uint i4 = 0; i4 < zAxis4; ++i4) + z3[i4] = OpType::op(*x3, y3[i4]); + else if (zStrd4 == 1 && xStrd4 == 1 && yStrd4 == 1) + for (uint i4 = 0; i4 < zAxis4; ++i4) + z3[i4] = OpType::op(x3[i4], y3[i4]); + else + for (uint i4 = 0; i4 < zAxis4; ++i4) + z3[i4 * zStrd4] = OpType::op(x3[i4 * xStrd4], y3[i4 * yStrd4]); + } } - }; + } + } + }; - samediff::Threads::parallel_for(func, 0, shape::length(zShapeInfo)); + samediff::Threads::parallel_for(func, 0, zAxis0, 1, 0, zAxis1, 1, 0, zAxis2, + 1); } //////////////////////////////////////////////////////////////////////// -template -template -void Broadcast::exec(const void *vx, const Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo) { - - const X* x = reinterpret_cast(vx); - const Y* y = reinterpret_cast(vy); - Z* z = reinterpret_cast(vz); - - const int rank = shape::rank(zShapeInfo); // xRank = yRank = zRank - - switch (rank) { - - case 1: - execRank1(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo); - break; - case 2: - execRank2(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo); - break; - case 3: - execRank3(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo); - break; - case 4: - execRank4(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo); - break; - case 5: - execRank5(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo); - break; - default: - execDefault(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo); +template +static void execDefault(const X *x, const Nd4jLong *xShapeInfo, const Y *y, + const Nd4jLong *yShapeInfo, Z *z, + const Nd4jLong *zShapeInfo) { + const bool xzSameOffsets = + shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo); + const bool yzSameOffsets = + shape::haveSameShapeAndStrides(yShapeInfo, zShapeInfo); + + const int rank = shape::rank(zShapeInfo); // xRank = yRank = zRank + + auto func = PRAGMA_THREADS_FOR { + int xCoords[MAX_RANK], yCoords[MAX_RANK], zCoords[MAX_RANK]; + + for (auto i = start; i < stop; ++i) { + shape::index2coordsCPU(start, i, zShapeInfo, zCoords); + + for (uint j = 0; j < rank; ++j) { + xCoords[j] = shape::sizeAt(xShapeInfo, j) == 1 ? 0 : zCoords[j]; + yCoords[j] = shape::sizeAt(yShapeInfo, j) == 1 ? 0 : zCoords[j]; + } + + const auto zOffset = shape::getOffset(zShapeInfo, zCoords); + const auto xOffset = + xzSameOffsets ? zOffset : shape::getOffset(xShapeInfo, xCoords); + const auto yOffset = + yzSameOffsets ? zOffset : shape::getOffset(yShapeInfo, yCoords); + + z[zOffset] = OpType::op(x[xOffset], y[yOffset]); } + }; + + samediff::Threads::parallel_for(func, 0, shape::length(zShapeInfo)); } +//////////////////////////////////////////////////////////////////////// +template +template +void Broadcast::exec(const void *vx, const Nd4jLong *xShapeInfo, + const void *vy, const Nd4jLong *yShapeInfo, + void *vz, const Nd4jLong *zShapeInfo) { + const X *x = reinterpret_cast(vx); + const Y *y = reinterpret_cast(vy); + Z *z = reinterpret_cast(vz); + + const int rank = shape::rank(zShapeInfo); // xRank = yRank = zRank + + switch (rank) { + case 1: + execRank1(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo); + break; + case 2: + execRank2(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo); + break; + case 3: + execRank3(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo); + break; + case 4: + execRank4(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo); + break; + case 5: + execRank5(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo); + break; + default: + execDefault(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo); + } } -} \ No newline at end of file + +} // namespace broadcast +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/broadcasting_bool.hpp b/libnd4j/include/loops/cpu/broadcasting_bool.hpp index a1593512455d..6d25aadecbb5 100644 --- a/libnd4j/include/loops/cpu/broadcasting_bool.hpp +++ b/libnd4j/include/loops/cpu/broadcasting_bool.hpp @@ -18,718 +18,787 @@ // @author raver119@gmail.com // -#include +#include +#include +#include #include #include +#include #include -#include -#include -#include using namespace simdOps; namespace functions { namespace broadcast { - template - void BroadcastBool::exec(const int opNum, - const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *z, const Nd4jLong *zShapeInfo, - void *extraParams, - int *dimension, int dimensionLength, - const Nd4jLong *xTadShapeInfo, const Nd4jLong *xTadOffset, - const Nd4jLong *zTadShapeInfo, const Nd4jLong *zTadOffset, - uint64_t start, uint64_t stop) { - DISPATCH_BY_OPNUM_TT(exec, PARAMS(x, - xShapeInfo, - y, - yShapeInfo, - z, - zShapeInfo, - extraParams, - dimension, - dimensionLength, - xTadShapeInfo, - xTadOffset, - zTadShapeInfo, - zTadOffset, start, stop), BROADCAST_BOOL_OPS); - } +template +void BroadcastBool::exec( + const int opNum, const void *x, const Nd4jLong *xShapeInfo, const void *y, + const Nd4jLong *yShapeInfo, void *z, const Nd4jLong *zShapeInfo, + void *extraParams, int *dimension, int dimensionLength, + const Nd4jLong *xTadShapeInfo, const Nd4jLong *xTadOffset, + const Nd4jLong *zTadShapeInfo, const Nd4jLong *zTadOffset, uint64_t start, + uint64_t stop) { + DISPATCH_BY_OPNUM_TT( + exec, + PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams, + dimension, dimensionLength, xTadShapeInfo, xTadOffset, + zTadShapeInfo, zTadOffset, start, stop), + BROADCAST_BOOL_OPS); +} - template - void BroadcastBool::exec(const int opNum, - const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *z, const Nd4jLong *zShapeInfo, - void* extraParams) { +template +void BroadcastBool::exec(const int opNum, const void *x, + const Nd4jLong *xShapeInfo, const void *y, + const Nd4jLong *yShapeInfo, void *z, + const Nd4jLong *zShapeInfo, void *extraParams) { + DISPATCH_BY_OPNUM_TT( + exec, PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams), + BROADCAST_BOOL_OPS); +} - DISPATCH_BY_OPNUM_TT(exec, PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams), BROADCAST_BOOL_OPS); - } +template +void BroadcastBool::execInverse( + const int opNum, const void *x, const Nd4jLong *xShapeInfo, const void *y, + const Nd4jLong *yShapeInfo, void *z, const Nd4jLong *zShapeInfo, + void *extraParams, int *dimension, int dimensionLength, + const Nd4jLong *xTadShapeInfo, const Nd4jLong *xTadOffset, + const Nd4jLong *zTadShapeInfo, const Nd4jLong *zTadOffset, uint64_t start, + uint64_t stop) { + DISPATCH_BY_OPNUM_TT( + execInverse, + PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams, + dimension, dimensionLength, xTadShapeInfo, xTadOffset, + zTadShapeInfo, zTadOffset, start, stop), + BROADCAST_BOOL_OPS); +} - template - void BroadcastBool::execInverse(const int opNum, - const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *z, const Nd4jLong *zShapeInfo, - void *extraParams, - int *dimension, int dimensionLength, - const Nd4jLong *xTadShapeInfo, const Nd4jLong *xTadOffset, - const Nd4jLong *zTadShapeInfo, const Nd4jLong *zTadOffset, - uint64_t start, uint64_t stop) { - DISPATCH_BY_OPNUM_TT(execInverse, PARAMS(x, - xShapeInfo, - y, - yShapeInfo, - z, - zShapeInfo, - extraParams, - dimension, - dimensionLength, - xTadShapeInfo, - xTadOffset, - zTadShapeInfo, - zTadOffset, start, stop), BROADCAST_BOOL_OPS); - } +template +template +void BroadcastBool::exec( + const void *vx, const Nd4jLong *xShapeInfo, const void *vy, + const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo, + void *vextraParams, int *dimension, int dimensionLength, + const Nd4jLong *xTadShapeInfo, const Nd4jLong *xTadOffset, + const Nd4jLong *zTadShapeInfo, const Nd4jLong *zTadOffset, uint64_t start, + uint64_t stop) { + auto x = reinterpret_cast(vx); + auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + auto extraParams = reinterpret_cast(vextraParams); + + // decompose in to several sub tads after + // moving all dimensions (in sorted order) + // to the back. + // permuted version of the x shape info for setting up the tad problem + auto xTadShapeShapeInfo = xTadShapeInfo; + auto tadOffsets = xTadOffset; + + if (xTadShapeInfo == nullptr || tadOffsets == nullptr) { + auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions( + xShapeInfo, dimension, dimensionLength); + + xTadShapeShapeInfo = const_cast(tadPack.primaryShapeInfo()); + tadOffsets = const_cast(tadPack.primaryOffsets()); + } + + // int *resultStride = shape::stride(xTadShapeShapeInfo); + unsigned int tadLength = + shape::length(xTadShapeShapeInfo); // shape::length(xTadShapeShapeInfo); + unsigned int tads = shape::length(xShapeInfo) / tadLength; + + if (zTadShapeInfo == nullptr) { + zTadShapeInfo = xTadShapeShapeInfo; + zTadOffset = tadOffsets; + } + + auto lenZ = shape::length(zTadShapeInfo); + auto lenY = shape::length(yShapeInfo); + + int tadsPerThread = tads / TAD_THRESHOLD; + int threads = sd::math::nd4j_max(1, tadsPerThread); + threads = sd::math::nd4j_min( + threads, sd::Environment::getInstance().maxThreads()); + + auto xEws = shape::elementWiseStride(xTadShapeShapeInfo); + auto yEws = shape::elementWiseStride(yShapeInfo); + auto zEws = shape::elementWiseStride(zTadShapeInfo); + + const sd::LoopKind::Kind kindOfLoop = sd::LoopKind::deduceKindOfLoopXYZ( + xTadShapeShapeInfo, yShapeInfo, zTadShapeInfo); + + if (kindOfLoop == sd::LoopKind::EWS1) { + for (auto i = start; i < stop; i++) { + auto oX = x + tadOffsets[i]; + auto oZ = z + zTadOffset[i]; + + PRAGMA_OMP_SIMD + for (unsigned int f = 0; f < tadLength; f++) + oZ[f] = OpType::op(oX[f], y[f], extraParams); + } + } else if (kindOfLoop == sd::LoopKind::EWSNONZERO) { + for (auto i = start; i < stop; i++) { + auto oX = x + tadOffsets[i]; + auto oZ = z + zTadOffset[i]; + + PRAGMA_OMP_SIMD + for (unsigned int f = 0; f < tadLength; f++) + oZ[f * zEws] = OpType::op(oX[f * xEws], y[f * yEws], extraParams); + }; + } else if (shape::haveSameShapeAndStrides(xTadShapeShapeInfo, yShapeInfo) && + shape::haveSameShapeAndStrides(xTadShapeShapeInfo, + zTadShapeInfo)) { + uint tadShapeShapeInfoCast[MAX_RANK]; + bool canCastX = sd::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, + tadShapeShapeInfoCast); + + for (auto i = start; i < stop; i++) { + auto oZ = z + zTadOffset[i]; + auto oX = x + tadOffsets[i]; + + PRAGMA_OMP_SIMD + for (unsigned int f = 0; f < tadLength; f++) { + auto offset = shape::indexOffset(f, xTadShapeShapeInfo, + tadShapeShapeInfoCast, canCastX); + oZ[offset] = OpType::op(oX[offset], y[offset], extraParams); + } + }; + } else if (shape::haveSameShapeAndStrides(xTadShapeShapeInfo, yShapeInfo)) { + uint tadShapeShapeInfoCast[MAX_RANK]; + uint tadShapeInfoZCast[MAX_RANK]; + bool canCastX = sd::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, + tadShapeShapeInfoCast); + bool canCastZ = + sd::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast); + + for (auto i = start; i < stop; i++) { + auto oZ = z + zTadOffset[i]; + auto oX = x + tadOffsets[i]; + + PRAGMA_OMP_SIMD + for (unsigned int f = 0; f < tadLength; f++) { + auto offset = shape::indexOffset(f, xTadShapeShapeInfo, + tadShapeShapeInfoCast, canCastX); + auto zOffset = + shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ); + oZ[zOffset] = OpType::op(oX[offset], y[offset], extraParams); + } + }; + } else if (shape::haveSameShapeAndStrides(xTadShapeShapeInfo, + zTadShapeInfo)) { + uint tadShapeShapeInfoCast[MAX_RANK]; + uint yShapeInfoCast[MAX_RANK]; + bool canCastX = sd::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, + tadShapeShapeInfoCast); + bool canCastY = + sd::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); + + for (auto i = start; i < stop; i++) { + auto oZ = z + zTadOffset[i]; + auto oX = x + tadOffsets[i]; + + PRAGMA_OMP_SIMD + for (unsigned int f = 0; f < tadLength; f++) { + auto offset = shape::indexOffset(f, xTadShapeShapeInfo, + tadShapeShapeInfoCast, canCastX); + auto yOffset = + shape::indexOffset(f, yShapeInfo, yShapeInfoCast, canCastY); + oZ[offset] = OpType::op(oX[offset], y[yOffset], extraParams); + } + }; - template - template - void BroadcastBool::exec(const void *vx, const Nd4jLong *xShapeInfo, - const void *vy, const Nd4jLong *yShapeInfo, - void *vz, const Nd4jLong *zShapeInfo, - void *vextraParams, - int *dimension, int dimensionLength, - const Nd4jLong *xTadShapeInfo, const Nd4jLong *xTadOffset, - const Nd4jLong *zTadShapeInfo, const Nd4jLong *zTadOffset, - uint64_t start, uint64_t stop) { - - auto x = reinterpret_cast(vx); - auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); - auto extraParams = reinterpret_cast(vextraParams); - - //decompose in to several sub tads after - //moving all dimensions (in sorted order) - //to the back. - //permuted version of the x shape info for setting up the tad problem - auto xTadShapeShapeInfo = xTadShapeInfo; - auto tadOffsets = xTadOffset; - - if (xTadShapeInfo == nullptr || tadOffsets == nullptr) { - auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(xShapeInfo, dimension, dimensionLength); - - xTadShapeShapeInfo = const_cast(tadPack.primaryShapeInfo()); - tadOffsets = const_cast(tadPack.primaryOffsets()); - } - - //int *resultStride = shape::stride(xTadShapeShapeInfo); - unsigned int tadLength = shape::length(xTadShapeShapeInfo);//shape::length(xTadShapeShapeInfo); - unsigned int tads = shape::length(xShapeInfo) / tadLength; - - if (zTadShapeInfo == nullptr) { - zTadShapeInfo = xTadShapeShapeInfo; - zTadOffset = tadOffsets; - } - - auto lenZ = shape::length(zTadShapeInfo); - auto lenY = shape::length(yShapeInfo); - - int tadsPerThread = tads / TAD_THRESHOLD; - int threads = sd::math::nd4j_max(1, tadsPerThread); - threads = sd::math::nd4j_min(threads, sd::Environment::getInstance().maxThreads()); - - auto xEws = shape::elementWiseStride(xTadShapeShapeInfo); - auto yEws = shape::elementWiseStride(yShapeInfo); - auto zEws = shape::elementWiseStride(zTadShapeInfo); - - const sd::LoopKind::Kind kindOfLoop = sd::LoopKind::deduceKindOfLoopXYZ(xTadShapeShapeInfo, yShapeInfo, zTadShapeInfo); - - if (kindOfLoop == sd::LoopKind::EWS1) { - for (auto i = start; i < stop; i++) { - auto oX = x + tadOffsets[i]; - auto oZ = z + zTadOffset[i]; - - PRAGMA_OMP_SIMD - for (unsigned int f = 0; f < tadLength; f++) - oZ[f] = OpType::op(oX[f], y[f], extraParams); - } - } - else if(kindOfLoop == sd::LoopKind::EWSNONZERO) { - for (auto i = start; i < stop; i ++) { - auto oX = x + tadOffsets[i]; - auto oZ = z + zTadOffset[i]; - - PRAGMA_OMP_SIMD - for (unsigned int f = 0; f < tadLength; f++) - oZ[f * zEws] = OpType::op(oX[f * xEws], y[f * yEws], extraParams); - }; - } - else if(shape::haveSameShapeAndStrides(xTadShapeShapeInfo, yShapeInfo) && shape::haveSameShapeAndStrides(xTadShapeShapeInfo, zTadShapeInfo)) { - uint tadShapeShapeInfoCast[MAX_RANK]; - bool canCastX = sd::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, tadShapeShapeInfoCast); - - for (auto i = start; i < stop; i ++) { - auto oZ = z + zTadOffset[i]; - auto oX = x + tadOffsets[i]; - - PRAGMA_OMP_SIMD - for (unsigned int f = 0; f < tadLength; f++) { - auto offset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX); - oZ[offset] = OpType::op(oX[offset], y[offset], extraParams); - } - }; - } - else if(shape::haveSameShapeAndStrides(xTadShapeShapeInfo, yShapeInfo)) { - uint tadShapeShapeInfoCast[MAX_RANK]; - uint tadShapeInfoZCast[MAX_RANK]; - bool canCastX = sd::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, tadShapeShapeInfoCast); - bool canCastZ = sd::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast); - - for (auto i = start; i < stop; i ++) { - auto oZ = z + zTadOffset[i]; - auto oX = x + tadOffsets[i]; - - PRAGMA_OMP_SIMD - for (unsigned int f = 0; f < tadLength; f++) { - auto offset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX); - auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ); - oZ[zOffset] = OpType::op(oX[offset], y[offset], extraParams); - } - }; - } - else if(shape::haveSameShapeAndStrides(xTadShapeShapeInfo, zTadShapeInfo)) { - uint tadShapeShapeInfoCast[MAX_RANK]; - uint yShapeInfoCast[MAX_RANK]; - bool canCastX = sd::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, tadShapeShapeInfoCast); - bool canCastY = sd::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); - - for (auto i = start; i < stop; i ++) { - auto oZ = z + zTadOffset[i]; - auto oX = x + tadOffsets[i]; - - PRAGMA_OMP_SIMD - for (unsigned int f = 0; f < tadLength; f++) { - auto offset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX); - auto yOffset = shape::indexOffset(f, yShapeInfo, yShapeInfoCast, canCastY); - oZ[offset] = OpType::op(oX[offset], y[yOffset], extraParams); - } - }; - - } - else if(shape::haveSameShapeAndStrides(yShapeInfo, zTadShapeInfo)) { - uint tadShapeShapeInfoCast[MAX_RANK]; - uint yShapeInfoCast[MAX_RANK]; - bool canCastX = sd::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, tadShapeShapeInfoCast); - bool canCastY = sd::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); - - for (auto i = start; i < stop; i ++) { - auto oZ = z + zTadOffset[i]; - auto oX = x + tadOffsets[i]; - - PRAGMA_OMP_SIMD - for (unsigned int f = 0; f < tadLength; f++) { - auto xOffset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX); - auto offset = shape::indexOffset(f, yShapeInfo, yShapeInfoCast, canCastY); - oZ[offset] = OpType::op(oX[xOffset], y[offset], extraParams); - } - }; - } - else { - uint tadShapeShapeInfoCast[MAX_RANK]; - uint tadShapeInfoZCast[MAX_RANK]; - uint yShapeInfoCast[MAX_RANK]; - bool canCastX = sd::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, tadShapeShapeInfoCast); - bool canCastY = sd::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); - bool canCastZ = sd::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast); - - for (auto i = start; i < stop; i ++) { - auto oZ = z + zTadOffset[i]; - auto oX = x + tadOffsets[i]; - - PRAGMA_OMP_SIMD - for (unsigned int f = 0; f < tadLength; f++) { - auto xOffset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX); - auto yOffset = shape::indexOffset(f, yShapeInfo, yShapeInfoCast, canCastY); - auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ); - oZ[zOffset] = OpType::op(oX[xOffset], y[yOffset], extraParams); - } - }; - } - } + } else if (shape::haveSameShapeAndStrides(yShapeInfo, zTadShapeInfo)) { + uint tadShapeShapeInfoCast[MAX_RANK]; + uint yShapeInfoCast[MAX_RANK]; + bool canCastX = sd::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, + tadShapeShapeInfoCast); + bool canCastY = + sd::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); + + for (auto i = start; i < stop; i++) { + auto oZ = z + zTadOffset[i]; + auto oX = x + tadOffsets[i]; + + PRAGMA_OMP_SIMD + for (unsigned int f = 0; f < tadLength; f++) { + auto xOffset = shape::indexOffset(f, xTadShapeShapeInfo, + tadShapeShapeInfoCast, canCastX); + auto offset = + shape::indexOffset(f, yShapeInfo, yShapeInfoCast, canCastY); + oZ[offset] = OpType::op(oX[xOffset], y[offset], extraParams); + } + }; + } else { + uint tadShapeShapeInfoCast[MAX_RANK]; + uint tadShapeInfoZCast[MAX_RANK]; + uint yShapeInfoCast[MAX_RANK]; + bool canCastX = sd::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, + tadShapeShapeInfoCast); + bool canCastY = + sd::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); + bool canCastZ = + sd::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast); + + for (auto i = start; i < stop; i++) { + auto oZ = z + zTadOffset[i]; + auto oX = x + tadOffsets[i]; + + PRAGMA_OMP_SIMD + for (unsigned int f = 0; f < tadLength; f++) { + auto xOffset = shape::indexOffset(f, xTadShapeShapeInfo, + tadShapeShapeInfoCast, canCastX); + auto yOffset = + shape::indexOffset(f, yShapeInfo, yShapeInfoCast, canCastY); + auto zOffset = + shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ); + oZ[zOffset] = OpType::op(oX[xOffset], y[yOffset], extraParams); + } + }; + } +} - template - template - void BroadcastBool::execInverse(const void *vx, const Nd4jLong *xShapeInfo, - const void *vy, const Nd4jLong *yShapeInfo, - void *vz, const Nd4jLong *zShapeInfo, - void *vextraParams, - int *dimension, int dimensionLength, - const Nd4jLong *yTadShapeInfo, const Nd4jLong *yTadOffset, - const Nd4jLong *zTadShapeInfo, const Nd4jLong *zTadOffset, - uint64_t start, uint64_t stop) { - - auto x = reinterpret_cast(vx); - auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); - auto extraParams = reinterpret_cast(vextraParams); - - //decompose in to several sub tads after - //moving all dimensions (in sorted order) - //to the back. - //permuted version of the x shape info for setting up the tad problem - auto yTadShapeShapeInfo = yTadShapeInfo; - auto tadOffsets = yTadOffset; - - if (yTadShapeInfo == nullptr || tadOffsets == nullptr) { - auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(yShapeInfo, dimension, dimensionLength); - - yTadShapeShapeInfo = const_cast(tadPack.primaryShapeInfo()); - tadOffsets = const_cast(tadPack.primaryOffsets()); - } - - //int *resultStride = shape::stride(yTadShapeShapeInfo); - unsigned int tadLength = shape::length(yTadShapeShapeInfo); - unsigned int tads = shape::length(yShapeInfo) / tadLength; - - if (zTadShapeInfo == nullptr) { - zTadShapeInfo = yTadShapeShapeInfo; - zTadOffset = tadOffsets; - } - - auto lenZ = shape::length(zTadShapeInfo); - auto lenX = shape::length(xShapeInfo); - - int tadsPerThread = tads / TAD_THRESHOLD; - int threads = sd::math::nd4j_max(1, tadsPerThread); - threads = sd::math::nd4j_min(threads, sd::Environment::getInstance().maxThreads()); - - auto yEws = shape::elementWiseStride(yTadShapeShapeInfo); - auto xEws = shape::elementWiseStride(xShapeInfo); - auto zEws = shape::elementWiseStride(zTadShapeInfo); - - const sd::LoopKind::Kind kindOfLoop = sd::LoopKind::deduceKindOfLoopXYZ(yTadShapeShapeInfo, xShapeInfo, zTadShapeInfo); - - if (kindOfLoop == sd::LoopKind::EWS1) { - for (auto i = start; i < stop; i ++) { - auto oY = y + tadOffsets[i]; - auto oZ = z + zTadOffset[i]; - - PRAGMA_OMP_SIMD - for (unsigned int f = 0; f < tadLength; f++) - oZ[f] = OpType::op(x[f], oY[f], extraParams); - } - } - else if(kindOfLoop == sd::LoopKind::EWSNONZERO) { - for (auto i = start; i < stop; i ++) { - auto oY = y + tadOffsets[i]; - auto oZ = z + zTadOffset[i]; - - PRAGMA_OMP_SIMD - for (uint f = 0; f < tadLength; f++) - oZ[f * zEws] = OpType::op(x[f * xEws], oY[f * yEws], extraParams); - } - } - else if(shape::haveSameShapeAndStrides(yTadShapeShapeInfo, xShapeInfo) && shape::haveSameShapeAndStrides(yTadShapeShapeInfo, zTadShapeInfo)) { - - uint tadShapeShapeInfoCast[MAX_RANK]; - bool canCastY = sd::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast); - - for (auto i = start; i < stop; i ++) { - auto oY = y + tadOffsets[i]; - auto oZ = z + zTadOffset[i]; - - PRAGMA_OMP_SIMD - for (unsigned int f = 0; f < tadLength; f++) { - auto offset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY); - oZ[offset] = OpType::op(x[offset], oY[offset], extraParams); - } - } - } - else if(shape::haveSameShapeAndStrides(yTadShapeShapeInfo, xShapeInfo)) { - - uint tadShapeShapeInfoCast[MAX_RANK]; - uint tadShapeInfoZCast[MAX_RANK]; - bool canCastY = sd::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast); - bool canCastZ = sd::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast); - - for (auto i = start; i < stop; i ++) { - auto oZ = z + zTadOffset[i]; - auto oY = y + tadOffsets[i]; - - PRAGMA_OMP_SIMD - for (unsigned int f = 0; f < tadLength; f++) { - auto offset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY); - auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ); - oZ[zOffset] = OpType::op(x[offset], oY[offset], extraParams); - } - } - } - else if(shape::haveSameShapeAndStrides(yTadShapeShapeInfo, zTadShapeInfo)) { - - uint tadShapeShapeInfoCast[MAX_RANK]; - uint xShapeInfoCast[MAX_RANK]; - bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); - bool canCastY = sd::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast); - - for (auto i = start; i < stop; i ++) { - auto oZ = z + zTadOffset[i]; - auto oY = y + tadOffsets[i]; - - PRAGMA_OMP_SIMD - for (unsigned int f = 0; f < tadLength; f++) { - auto offset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY); - auto xOffset = shape::indexOffset(f, xShapeInfo, xShapeInfoCast, canCastX); - oZ[offset] = OpType::op(x[xOffset], oY[offset], extraParams); - } - } - } - else if(shape::haveSameShapeAndStrides(xShapeInfo, zTadShapeInfo)) { - - uint tadShapeShapeInfoCast[MAX_RANK]; - uint xShapeInfoCast[MAX_RANK]; - bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); - bool canCastY = sd::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast); - - for (auto i = start; i < stop; i ++) { - auto oZ = z + zTadOffset[i]; - auto oY = y + tadOffsets[i]; - - PRAGMA_OMP_SIMD - for (unsigned int f = 0; f < tadLength; f++) { - auto yOffset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY); - auto offset = shape::indexOffset(f, xShapeInfo, xShapeInfoCast, canCastX); - oZ[offset] = OpType::op(x[offset], oY[yOffset], extraParams); - } - } - } - else { - - uint xShapeInfoCast[MAX_RANK]; - uint tadShapeShapeInfoCast[MAX_RANK]; - uint tadShapeInfoZCast[MAX_RANK]; - bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); - bool canCastY = sd::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast); - bool canCastZ = sd::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast); - - for (auto i = start; i < stop; i ++) { - auto oZ = z + zTadOffset[i]; - auto oY = y + tadOffsets[i]; - - PRAGMA_OMP_SIMD - for (unsigned int f = 0; f < tadLength; f++) { - auto xOffset = shape::indexOffset(f, xShapeInfo, xShapeInfoCast, canCastX); - auto yOffset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY); - auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ); - oZ[zOffset] = OpType::op(x[xOffset], oY[yOffset], extraParams); - } - } - } - } +template +template +void BroadcastBool::execInverse( + const void *vx, const Nd4jLong *xShapeInfo, const void *vy, + const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo, + void *vextraParams, int *dimension, int dimensionLength, + const Nd4jLong *yTadShapeInfo, const Nd4jLong *yTadOffset, + const Nd4jLong *zTadShapeInfo, const Nd4jLong *zTadOffset, uint64_t start, + uint64_t stop) { + auto x = reinterpret_cast(vx); + auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + auto extraParams = reinterpret_cast(vextraParams); + + // decompose in to several sub tads after + // moving all dimensions (in sorted order) + // to the back. + // permuted version of the x shape info for setting up the tad problem + auto yTadShapeShapeInfo = yTadShapeInfo; + auto tadOffsets = yTadOffset; + + if (yTadShapeInfo == nullptr || tadOffsets == nullptr) { + auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions( + yShapeInfo, dimension, dimensionLength); + + yTadShapeShapeInfo = const_cast(tadPack.primaryShapeInfo()); + tadOffsets = const_cast(tadPack.primaryOffsets()); + } + + // int *resultStride = shape::stride(yTadShapeShapeInfo); + unsigned int tadLength = shape::length(yTadShapeShapeInfo); + unsigned int tads = shape::length(yShapeInfo) / tadLength; + + if (zTadShapeInfo == nullptr) { + zTadShapeInfo = yTadShapeShapeInfo; + zTadOffset = tadOffsets; + } + + auto lenZ = shape::length(zTadShapeInfo); + auto lenX = shape::length(xShapeInfo); + + int tadsPerThread = tads / TAD_THRESHOLD; + int threads = sd::math::nd4j_max(1, tadsPerThread); + threads = sd::math::nd4j_min( + threads, sd::Environment::getInstance().maxThreads()); + + auto yEws = shape::elementWiseStride(yTadShapeShapeInfo); + auto xEws = shape::elementWiseStride(xShapeInfo); + auto zEws = shape::elementWiseStride(zTadShapeInfo); + + const sd::LoopKind::Kind kindOfLoop = sd::LoopKind::deduceKindOfLoopXYZ( + yTadShapeShapeInfo, xShapeInfo, zTadShapeInfo); + + if (kindOfLoop == sd::LoopKind::EWS1) { + for (auto i = start; i < stop; i++) { + auto oY = y + tadOffsets[i]; + auto oZ = z + zTadOffset[i]; + + PRAGMA_OMP_SIMD + for (unsigned int f = 0; f < tadLength; f++) + oZ[f] = OpType::op(x[f], oY[f], extraParams); + } + } else if (kindOfLoop == sd::LoopKind::EWSNONZERO) { + for (auto i = start; i < stop; i++) { + auto oY = y + tadOffsets[i]; + auto oZ = z + zTadOffset[i]; + + PRAGMA_OMP_SIMD + for (uint f = 0; f < tadLength; f++) + oZ[f * zEws] = OpType::op(x[f * xEws], oY[f * yEws], extraParams); + } + } else if (shape::haveSameShapeAndStrides(yTadShapeShapeInfo, xShapeInfo) && + shape::haveSameShapeAndStrides(yTadShapeShapeInfo, + zTadShapeInfo)) { + uint tadShapeShapeInfoCast[MAX_RANK]; + bool canCastY = sd::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, + tadShapeShapeInfoCast); + + for (auto i = start; i < stop; i++) { + auto oY = y + tadOffsets[i]; + auto oZ = z + zTadOffset[i]; + + PRAGMA_OMP_SIMD + for (unsigned int f = 0; f < tadLength; f++) { + auto offset = shape::indexOffset(f, yTadShapeShapeInfo, + tadShapeShapeInfoCast, canCastY); + oZ[offset] = OpType::op(x[offset], oY[offset], extraParams); + } + } + } else if (shape::haveSameShapeAndStrides(yTadShapeShapeInfo, xShapeInfo)) { + uint tadShapeShapeInfoCast[MAX_RANK]; + uint tadShapeInfoZCast[MAX_RANK]; + bool canCastY = sd::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, + tadShapeShapeInfoCast); + bool canCastZ = + sd::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast); + + for (auto i = start; i < stop; i++) { + auto oZ = z + zTadOffset[i]; + auto oY = y + tadOffsets[i]; + + PRAGMA_OMP_SIMD + for (unsigned int f = 0; f < tadLength; f++) { + auto offset = shape::indexOffset(f, yTadShapeShapeInfo, + tadShapeShapeInfoCast, canCastY); + auto zOffset = + shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ); + oZ[zOffset] = OpType::op(x[offset], oY[offset], extraParams); + } + } + } else if (shape::haveSameShapeAndStrides(yTadShapeShapeInfo, + zTadShapeInfo)) { + uint tadShapeShapeInfoCast[MAX_RANK]; + uint xShapeInfoCast[MAX_RANK]; + bool canCastX = + sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + bool canCastY = sd::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, + tadShapeShapeInfoCast); + + for (auto i = start; i < stop; i++) { + auto oZ = z + zTadOffset[i]; + auto oY = y + tadOffsets[i]; + + PRAGMA_OMP_SIMD + for (unsigned int f = 0; f < tadLength; f++) { + auto offset = shape::indexOffset(f, yTadShapeShapeInfo, + tadShapeShapeInfoCast, canCastY); + auto xOffset = + shape::indexOffset(f, xShapeInfo, xShapeInfoCast, canCastX); + oZ[offset] = OpType::op(x[xOffset], oY[offset], extraParams); + } + } + } else if (shape::haveSameShapeAndStrides(xShapeInfo, zTadShapeInfo)) { + uint tadShapeShapeInfoCast[MAX_RANK]; + uint xShapeInfoCast[MAX_RANK]; + bool canCastX = + sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + bool canCastY = sd::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, + tadShapeShapeInfoCast); + + for (auto i = start; i < stop; i++) { + auto oZ = z + zTadOffset[i]; + auto oY = y + tadOffsets[i]; + + PRAGMA_OMP_SIMD + for (unsigned int f = 0; f < tadLength; f++) { + auto yOffset = shape::indexOffset(f, yTadShapeShapeInfo, + tadShapeShapeInfoCast, canCastY); + auto offset = + shape::indexOffset(f, xShapeInfo, xShapeInfoCast, canCastX); + oZ[offset] = OpType::op(x[offset], oY[yOffset], extraParams); + } + } + } else { + uint xShapeInfoCast[MAX_RANK]; + uint tadShapeShapeInfoCast[MAX_RANK]; + uint tadShapeInfoZCast[MAX_RANK]; + bool canCastX = + sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + bool canCastY = sd::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, + tadShapeShapeInfoCast); + bool canCastZ = + sd::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast); + + for (auto i = start; i < stop; i++) { + auto oZ = z + zTadOffset[i]; + auto oY = y + tadOffsets[i]; + + PRAGMA_OMP_SIMD + for (unsigned int f = 0; f < tadLength; f++) { + auto xOffset = + shape::indexOffset(f, xShapeInfo, xShapeInfoCast, canCastX); + auto yOffset = shape::indexOffset(f, yTadShapeShapeInfo, + tadShapeShapeInfoCast, canCastY); + auto zOffset = + shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ); + oZ[zOffset] = OpType::op(x[xOffset], oY[yOffset], extraParams); + } + } + } +} //////////////////////////////////////////////////////////////////////// template -static void execRank1(const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, Z* z, const Nd4jLong *zShapeInfo, X* extraParams) { - - uint zAxis0 = shape::sizeAt(zShapeInfo, 0); - Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, 0); - Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, 0); - Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, 0); - - auto func = PRAGMA_THREADS_FOR{ - - if(zStrd0 == 1 && xStrd0 == 1 && yStrd0 == 0) { - for (auto i0 = start; i0 < stop; ++i0) - z[i0] = OpType::op(x[i0], *y, extraParams); - } - else if(zStrd0 == 1 && xStrd0 == 0 && yStrd0 == 1) { - for (auto i0 = start; i0 < stop; ++i0) - z[i0] = OpType::op(*x, y[i0], extraParams); - } - else if(zStrd0 == 1 && xStrd0 == 1 && yStrd0 == 1) { - for (auto i0 = start; i0 < stop; ++i0) - z[i0] = OpType::op(x[i0], y[i0], extraParams); - } - else { - for (auto i0 = start; i0 < stop; ++i0) - z[i0 * zStrd0] = OpType::op(x[i0 * xStrd0], y[i0 * yStrd0], extraParams); - } - }; - samediff::Threads::parallel_tad(func, 0, zAxis0); +static void execRank1(const X *x, const Nd4jLong *xShapeInfo, const X *y, + const Nd4jLong *yShapeInfo, Z *z, + const Nd4jLong *zShapeInfo, X *extraParams) { + uint zAxis0 = shape::sizeAt(zShapeInfo, 0); + Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, 0); + Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, 0); + Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, 0); + + auto func = PRAGMA_THREADS_FOR { + if (zStrd0 == 1 && xStrd0 == 1 && yStrd0 == 0) { + for (auto i0 = start; i0 < stop; ++i0) + z[i0] = OpType::op(x[i0], *y, extraParams); + } else if (zStrd0 == 1 && xStrd0 == 0 && yStrd0 == 1) { + for (auto i0 = start; i0 < stop; ++i0) + z[i0] = OpType::op(*x, y[i0], extraParams); + } else if (zStrd0 == 1 && xStrd0 == 1 && yStrd0 == 1) { + for (auto i0 = start; i0 < stop; ++i0) + z[i0] = OpType::op(x[i0], y[i0], extraParams); + } else { + for (auto i0 = start; i0 < stop; ++i0) + z[i0 * zStrd0] = + OpType::op(x[i0 * xStrd0], y[i0 * yStrd0], extraParams); + } + }; + samediff::Threads::parallel_tad(func, 0, zAxis0); } //////////////////////////////////////////////////////////////////////// template -static void execRank2(const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, Z* z, const Nd4jLong *zShapeInfo, X* extraParams) { - - uint zAxis0 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1); - Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1); - Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1); - Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1); - - uint zAxis1 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0); - Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0); - Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0); - Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0); - - auto func = PRAGMA_THREADS_FOR{ - - for (auto i0 = start; i0 < stop; ++i0) { - - auto x0 = x + i0 * xStrd0; - auto y0 = y + i0 * yStrd0; - auto z0 = z + i0 * zStrd0; - - if(zStrd1 == 1 && xStrd1 == 1 && yStrd1 == 0) - for (uint i1 = 0; i1 < zAxis1; ++i1) - z0[i1] = OpType::op(x0[i1], *y0, extraParams); - else if(zStrd1 == 1 && xStrd1 == 0 && yStrd1 == 1) - for (uint i1 = 0; i1 < zAxis1; ++i1) - z0[i1] = OpType::op(*x0, y0[i1], extraParams); - else if(zStrd1 == 1 && xStrd1 == 1 && yStrd1 == 1) - for (uint i1 = 0; i1 < zAxis1; ++i1) - z0[i1] = OpType::op(x0[i1], y0[i1], extraParams); - else - for (uint i1 = 0; i1 < zAxis1; ++i1) - z0[i1 * zStrd1] = OpType::op(x0[i1 * xStrd1], y0[i1 * yStrd1], extraParams); - } - }; +static void execRank2(const X *x, const Nd4jLong *xShapeInfo, const X *y, + const Nd4jLong *yShapeInfo, Z *z, + const Nd4jLong *zShapeInfo, X *extraParams) { + uint zAxis0 = + shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1); + Nd4jLong xStrd0 = + shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1); + Nd4jLong yStrd0 = + shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1); + Nd4jLong zStrd0 = + shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1); + + uint zAxis1 = + shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0); + Nd4jLong xStrd1 = + shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0); + Nd4jLong yStrd1 = + shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0); + Nd4jLong zStrd1 = + shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0); + + auto func = PRAGMA_THREADS_FOR { + for (auto i0 = start; i0 < stop; ++i0) { + auto x0 = x + i0 * xStrd0; + auto y0 = y + i0 * yStrd0; + auto z0 = z + i0 * zStrd0; + + if (zStrd1 == 1 && xStrd1 == 1 && yStrd1 == 0) + for (uint i1 = 0; i1 < zAxis1; ++i1) + z0[i1] = OpType::op(x0[i1], *y0, extraParams); + else if (zStrd1 == 1 && xStrd1 == 0 && yStrd1 == 1) + for (uint i1 = 0; i1 < zAxis1; ++i1) + z0[i1] = OpType::op(*x0, y0[i1], extraParams); + else if (zStrd1 == 1 && xStrd1 == 1 && yStrd1 == 1) + for (uint i1 = 0; i1 < zAxis1; ++i1) + z0[i1] = OpType::op(x0[i1], y0[i1], extraParams); + else + for (uint i1 = 0; i1 < zAxis1; ++i1) + z0[i1 * zStrd1] = + OpType::op(x0[i1 * xStrd1], y0[i1 * yStrd1], extraParams); + } + }; - samediff::Threads::parallel_tad(func, 0, zAxis0); + samediff::Threads::parallel_tad(func, 0, zAxis0); } //////////////////////////////////////////////////////////////////////// template -static void execRank3(const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, Z* z, const Nd4jLong *zShapeInfo, X* extraParams) { - - uint zAxis0 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2); - Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2); - Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2); - Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2); - - uint zAxis1 = shape::sizeAt(zShapeInfo, 1); - Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, 1); - Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, 1); - Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, 1); - - uint zAxis2 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0); - Nd4jLong xStrd2 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0); - Nd4jLong yStrd2 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0); - Nd4jLong zStrd2 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0); - - auto func = PRAGMA_THREADS_FOR_2D { - - for (auto i0 = start_x; i0 < stop_x; ++i0) { - for (auto i1 = start_y; i1 < stop_y; ++i1) { - - auto x1 = x + i0 * xStrd0 + i1 * xStrd1; - auto y1 = y + i0 * yStrd0 + i1 * yStrd1; - auto z1 = z + i0 * zStrd0 + i1 * zStrd1; - - if(zStrd2 == 1 && xStrd2 == 1 && yStrd2 == 0) - for (uint i2 = 0; i2 < zAxis2; ++i2) - z1[i2] = OpType::op(x1[i2], *y1, extraParams); - else if(zStrd2 == 1 && xStrd2 == 0 && yStrd2 == 1) - for (uint i2 = 0; i2 < zAxis2; ++i2) - z1[i2] = OpType::op(*x1, y1[i2], extraParams); - else if(zStrd2 == 1 && xStrd2 == 1 && yStrd2 == 1) - for (uint i2 = 0; i2 < zAxis2; ++i2) - z1[i2] = OpType::op(x1[i2], y1[i2], extraParams); - else - for (uint i2 = 0; i2 < zAxis2; ++i2) - z1[i2 * zStrd2] = OpType::op(x1[i2 * xStrd2], y1[i2 * yStrd2], extraParams); - } - } - }; +static void execRank3(const X *x, const Nd4jLong *xShapeInfo, const X *y, + const Nd4jLong *yShapeInfo, Z *z, + const Nd4jLong *zShapeInfo, X *extraParams) { + uint zAxis0 = + shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2); + Nd4jLong xStrd0 = + shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2); + Nd4jLong yStrd0 = + shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2); + Nd4jLong zStrd0 = + shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2); + + uint zAxis1 = shape::sizeAt(zShapeInfo, 1); + Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, 1); + Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, 1); + Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, 1); + + uint zAxis2 = + shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0); + Nd4jLong xStrd2 = + shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0); + Nd4jLong yStrd2 = + shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0); + Nd4jLong zStrd2 = + shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0); + + auto func = PRAGMA_THREADS_FOR_2D { + for (auto i0 = start_x; i0 < stop_x; ++i0) { + for (auto i1 = start_y; i1 < stop_y; ++i1) { + auto x1 = x + i0 * xStrd0 + i1 * xStrd1; + auto y1 = y + i0 * yStrd0 + i1 * yStrd1; + auto z1 = z + i0 * zStrd0 + i1 * zStrd1; + + if (zStrd2 == 1 && xStrd2 == 1 && yStrd2 == 0) + for (uint i2 = 0; i2 < zAxis2; ++i2) + z1[i2] = OpType::op(x1[i2], *y1, extraParams); + else if (zStrd2 == 1 && xStrd2 == 0 && yStrd2 == 1) + for (uint i2 = 0; i2 < zAxis2; ++i2) + z1[i2] = OpType::op(*x1, y1[i2], extraParams); + else if (zStrd2 == 1 && xStrd2 == 1 && yStrd2 == 1) + for (uint i2 = 0; i2 < zAxis2; ++i2) + z1[i2] = OpType::op(x1[i2], y1[i2], extraParams); + else + for (uint i2 = 0; i2 < zAxis2; ++i2) + z1[i2 * zStrd2] = + OpType::op(x1[i2 * xStrd2], y1[i2 * yStrd2], extraParams); + } + } + }; - samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1); + samediff::Threads::parallel_for(func, 0, zAxis0, 1, 0, zAxis1, 1); } //////////////////////////////////////////////////////////////////////// template -static void execRank4(const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, Z* z, const Nd4jLong *zShapeInfo, X* extraParams) { - - uint zAxis0 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3); - Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3); - Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3); - Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3); - - uint zAxis1 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2); - Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2); - Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2); - Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2); - - uint zAxis2 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1); - Nd4jLong xStrd2 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1); - Nd4jLong yStrd2 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1); - Nd4jLong zStrd2 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1); - - uint zAxis3 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0); - Nd4jLong xStrd3 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0); - Nd4jLong yStrd3 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0); - Nd4jLong zStrd3 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0); - - auto func = PRAGMA_THREADS_FOR_3D { - - for (auto i0 = start_x; i0 < stop_x; ++i0) { - for (auto i1 = start_y; i1 < stop_y; ++i1) { - for (auto i2 = start_z; i2 < stop_z; ++i2) { - - auto x2 = x + i0 * xStrd0 + i1 * xStrd1 + i2 * xStrd2; - auto y2 = y + i0 * yStrd0 + i1 * yStrd1 + i2 * yStrd2; - auto z2 = z + i0 * zStrd0 + i1 * zStrd1 + i2 * zStrd2; - - if(zStrd3 == 1 && xStrd3 == 1 && yStrd3 == 0) - for (uint i3 = 0; i3 < zAxis3; ++i3) - z2[i3] = OpType::op(x2[i3], *y2, extraParams); - else if(zStrd3 == 1 && xStrd3 == 0 && yStrd3 == 1) - for (uint i3 = 0; i3 < zAxis3; ++i3) - z2[i3] = OpType::op(*x2, y2[i3], extraParams); - else if(zStrd3 == 1 && xStrd3 == 1 && yStrd3 == 1) - for (uint i3 = 0; i3 < zAxis3; ++i3) - z2[i3] = OpType::op(x2[i3], y2[i3], extraParams); - else - for (uint i3 = 0; i3 < zAxis3; ++i3) - z2[i3 * zStrd3] = OpType::op(x2[i3 * xStrd3], y2[i3 * yStrd3], extraParams); - } - } +static void execRank4(const X *x, const Nd4jLong *xShapeInfo, const X *y, + const Nd4jLong *yShapeInfo, Z *z, + const Nd4jLong *zShapeInfo, X *extraParams) { + uint zAxis0 = + shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3); + Nd4jLong xStrd0 = + shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3); + Nd4jLong yStrd0 = + shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3); + Nd4jLong zStrd0 = + shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3); + + uint zAxis1 = + shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2); + Nd4jLong xStrd1 = + shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2); + Nd4jLong yStrd1 = + shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2); + Nd4jLong zStrd1 = + shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2); + + uint zAxis2 = + shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1); + Nd4jLong xStrd2 = + shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1); + Nd4jLong yStrd2 = + shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1); + Nd4jLong zStrd2 = + shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1); + + uint zAxis3 = + shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0); + Nd4jLong xStrd3 = + shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0); + Nd4jLong yStrd3 = + shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0); + Nd4jLong zStrd3 = + shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0); + + auto func = PRAGMA_THREADS_FOR_3D { + for (auto i0 = start_x; i0 < stop_x; ++i0) { + for (auto i1 = start_y; i1 < stop_y; ++i1) { + for (auto i2 = start_z; i2 < stop_z; ++i2) { + auto x2 = x + i0 * xStrd0 + i1 * xStrd1 + i2 * xStrd2; + auto y2 = y + i0 * yStrd0 + i1 * yStrd1 + i2 * yStrd2; + auto z2 = z + i0 * zStrd0 + i1 * zStrd1 + i2 * zStrd2; + + if (zStrd3 == 1 && xStrd3 == 1 && yStrd3 == 0) + for (uint i3 = 0; i3 < zAxis3; ++i3) + z2[i3] = OpType::op(x2[i3], *y2, extraParams); + else if (zStrd3 == 1 && xStrd3 == 0 && yStrd3 == 1) + for (uint i3 = 0; i3 < zAxis3; ++i3) + z2[i3] = OpType::op(*x2, y2[i3], extraParams); + else if (zStrd3 == 1 && xStrd3 == 1 && yStrd3 == 1) + for (uint i3 = 0; i3 < zAxis3; ++i3) + z2[i3] = OpType::op(x2[i3], y2[i3], extraParams); + else + for (uint i3 = 0; i3 < zAxis3; ++i3) + z2[i3 * zStrd3] = + OpType::op(x2[i3 * xStrd3], y2[i3 * yStrd3], extraParams); } - }; + } + } + }; - samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1, 0,zAxis2,1); + samediff::Threads::parallel_for(func, 0, zAxis0, 1, 0, zAxis1, 1, 0, zAxis2, + 1); } //////////////////////////////////////////////////////////////////////// template -static void execRank5(const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, Z* z, const Nd4jLong *zShapeInfo, X* extraParams) { - - uint zAxis0 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4); - Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4); - Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4); - Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4); - - uint zAxis1 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3); - Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3); - Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3); - Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3); - - uint zAxis2 = shape::sizeAt(zShapeInfo, 2); - Nd4jLong xStrd2 = shape::strideAt(xShapeInfo, 2); - Nd4jLong yStrd2 = shape::strideAt(yShapeInfo, 2); - Nd4jLong zStrd2 = shape::strideAt(zShapeInfo, 2); - - uint zAxis3 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1); - Nd4jLong xStrd3 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1); - Nd4jLong yStrd3 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1); - Nd4jLong zStrd3 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1); - - uint zAxis4 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0); - Nd4jLong xStrd4 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0); - Nd4jLong yStrd4 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0); - Nd4jLong zStrd4 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0); - - auto func = PRAGMA_THREADS_FOR_3D { - - for (auto i0 = start_x; i0 < stop_x; ++i0) { - for (auto i1 = start_y; i1 < stop_y; ++i1) { - for (auto i2 = start_z; i2 < stop_z; ++i2) { - for (uint i3 = 0; i3 < zAxis3; ++i3) { - - auto x3 = x + i0 * xStrd0 + i1 * xStrd1 + i2 * xStrd2 + i3 * xStrd3; - auto y3 = y + i0 * yStrd0 + i1 * yStrd1 + i2 * yStrd2 + i3 * yStrd3; - auto z3 = z + i0 * zStrd0 + i1 * zStrd1 + i2 * zStrd2 + i3 * zStrd3; - - if(zStrd4 == 1 && xStrd4 == 1 && yStrd4 == 0) - for (uint i4 = 0; i4 < zAxis4; ++i4) - z3[i4] = OpType::op(x3[i4], *y3, extraParams); - else if(zStrd4 == 1 && xStrd4 == 0 && yStrd4 == 1) - for (uint i4 = 0; i4 < zAxis4; ++i4) - z3[i4] = OpType::op(*x3, y3[i4], extraParams); - else if(zStrd4 == 1 && xStrd4 == 1 && yStrd4 == 1) - for (uint i4 = 0; i4 < zAxis4; ++i4) - z3[i4] = OpType::op(x3[i4], y3[i4], extraParams); - else - for (uint i4 = 0; i4 < zAxis4; ++i4) - z3[i4 * zStrd4] = OpType::op(x3[i4 * xStrd4], y3[i4 * yStrd4], extraParams); - } - } - } +static void execRank5(const X *x, const Nd4jLong *xShapeInfo, const X *y, + const Nd4jLong *yShapeInfo, Z *z, + const Nd4jLong *zShapeInfo, X *extraParams) { + uint zAxis0 = + shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4); + Nd4jLong xStrd0 = + shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4); + Nd4jLong yStrd0 = + shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4); + Nd4jLong zStrd0 = + shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4); + + uint zAxis1 = + shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3); + Nd4jLong xStrd1 = + shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3); + Nd4jLong yStrd1 = + shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3); + Nd4jLong zStrd1 = + shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3); + + uint zAxis2 = shape::sizeAt(zShapeInfo, 2); + Nd4jLong xStrd2 = shape::strideAt(xShapeInfo, 2); + Nd4jLong yStrd2 = shape::strideAt(yShapeInfo, 2); + Nd4jLong zStrd2 = shape::strideAt(zShapeInfo, 2); + + uint zAxis3 = + shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1); + Nd4jLong xStrd3 = + shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1); + Nd4jLong yStrd3 = + shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1); + Nd4jLong zStrd3 = + shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1); + + uint zAxis4 = + shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0); + Nd4jLong xStrd4 = + shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0); + Nd4jLong yStrd4 = + shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0); + Nd4jLong zStrd4 = + shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0); + + auto func = PRAGMA_THREADS_FOR_3D { + for (auto i0 = start_x; i0 < stop_x; ++i0) { + for (auto i1 = start_y; i1 < stop_y; ++i1) { + for (auto i2 = start_z; i2 < stop_z; ++i2) { + for (uint i3 = 0; i3 < zAxis3; ++i3) { + auto x3 = x + i0 * xStrd0 + i1 * xStrd1 + i2 * xStrd2 + i3 * xStrd3; + auto y3 = y + i0 * yStrd0 + i1 * yStrd1 + i2 * yStrd2 + i3 * yStrd3; + auto z3 = z + i0 * zStrd0 + i1 * zStrd1 + i2 * zStrd2 + i3 * zStrd3; + + if (zStrd4 == 1 && xStrd4 == 1 && yStrd4 == 0) + for (uint i4 = 0; i4 < zAxis4; ++i4) + z3[i4] = OpType::op(x3[i4], *y3, extraParams); + else if (zStrd4 == 1 && xStrd4 == 0 && yStrd4 == 1) + for (uint i4 = 0; i4 < zAxis4; ++i4) + z3[i4] = OpType::op(*x3, y3[i4], extraParams); + else if (zStrd4 == 1 && xStrd4 == 1 && yStrd4 == 1) + for (uint i4 = 0; i4 < zAxis4; ++i4) + z3[i4] = OpType::op(x3[i4], y3[i4], extraParams); + else + for (uint i4 = 0; i4 < zAxis4; ++i4) + z3[i4 * zStrd4] = + OpType::op(x3[i4 * xStrd4], y3[i4 * yStrd4], extraParams); + } } - }; + } + } + }; - samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1, 0,zAxis2,1); + samediff::Threads::parallel_for(func, 0, zAxis0, 1, 0, zAxis1, 1, 0, zAxis2, + 1); } //////////////////////////////////////////////////////////////////////// template -static void execDefault(const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, Z* z, const Nd4jLong *zShapeInfo, X* extraParams) { - - const bool xzSameOffsets = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo); - const bool yzSameOffsets = shape::haveSameShapeAndStrides(yShapeInfo, zShapeInfo); - - const int rank = shape::rank(zShapeInfo); // xRank = yRank = zRank - - auto func = PRAGMA_THREADS_FOR{ - - int xCoords[MAX_RANK], yCoords[MAX_RANK], zCoords[MAX_RANK]; - - for (auto i = start; i < stop; ++i) { - - shape::index2coordsCPU(start, i, zShapeInfo, zCoords); - - for (uint j = 0; j < rank; ++j) { - xCoords[j] = shape::sizeAt(xShapeInfo, j) == 1 ? 0 : zCoords[j]; - yCoords[j] = shape::sizeAt(yShapeInfo, j) == 1 ? 0 : zCoords[j]; - } - - const auto zOffset = shape::getOffset(zShapeInfo, zCoords); - const auto xOffset = xzSameOffsets ? zOffset : shape::getOffset(xShapeInfo, xCoords); - const auto yOffset = yzSameOffsets ? zOffset : shape::getOffset(yShapeInfo, yCoords); - - z[zOffset] = OpType::op(x[xOffset], y[yOffset], extraParams); - } - }; +static void execDefault(const X *x, const Nd4jLong *xShapeInfo, const X *y, + const Nd4jLong *yShapeInfo, Z *z, + const Nd4jLong *zShapeInfo, X *extraParams) { + const bool xzSameOffsets = + shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo); + const bool yzSameOffsets = + shape::haveSameShapeAndStrides(yShapeInfo, zShapeInfo); + + const int rank = shape::rank(zShapeInfo); // xRank = yRank = zRank + + auto func = PRAGMA_THREADS_FOR { + int xCoords[MAX_RANK], yCoords[MAX_RANK], zCoords[MAX_RANK]; + + for (auto i = start; i < stop; ++i) { + shape::index2coordsCPU(start, i, zShapeInfo, zCoords); + + for (uint j = 0; j < rank; ++j) { + xCoords[j] = shape::sizeAt(xShapeInfo, j) == 1 ? 0 : zCoords[j]; + yCoords[j] = shape::sizeAt(yShapeInfo, j) == 1 ? 0 : zCoords[j]; + } + + const auto zOffset = shape::getOffset(zShapeInfo, zCoords); + const auto xOffset = + xzSameOffsets ? zOffset : shape::getOffset(xShapeInfo, xCoords); + const auto yOffset = + yzSameOffsets ? zOffset : shape::getOffset(yShapeInfo, yCoords); + + z[zOffset] = OpType::op(x[xOffset], y[yOffset], extraParams); + } + }; - samediff::Threads::parallel_for(func, 0, shape::length(zShapeInfo)); + samediff::Threads::parallel_for(func, 0, shape::length(zShapeInfo)); } //////////////////////////////////////////////////////////////////////// template -template +template void BroadcastBool::exec(const void *vx, const Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo, - void *vz, const Nd4jLong *zShapeInfo, - void *vextraParams) { - - const X* x = reinterpret_cast(vx); - const X* y = reinterpret_cast(vy); - Z* z = reinterpret_cast(vz); - - X* extraParams = reinterpret_cast(vextraParams); - - const int rank = shape::rank(zShapeInfo); // xRank = yRank = zRank - - switch (rank) { - - case 1: - execRank1(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams); - break; - case 2: - execRank2(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams); - break; - case 3: - execRank3(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams); - break; - case 4: - execRank4(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams); - break; - case 5: - execRank5(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams); - break; - default: - execDefault(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams); - } + void *vz, const Nd4jLong *zShapeInfo, + void *vextraParams) { + const X *x = reinterpret_cast(vx); + const X *y = reinterpret_cast(vy); + Z *z = reinterpret_cast(vz); + + X *extraParams = reinterpret_cast(vextraParams); + + const int rank = shape::rank(zShapeInfo); // xRank = yRank = zRank + + switch (rank) { + case 1: + execRank1(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, + extraParams); + break; + case 2: + execRank2(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, + extraParams); + break; + case 3: + execRank3(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, + extraParams); + break; + case 4: + execRank4(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, + extraParams); + break; + case 5: + execRank5(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, + extraParams); + break; + default: + execDefault(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, + extraParams); + } } - //BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT BroadcastBool, , LIBND4J_TYPES, BOOL_TYPES); +// BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT BroadcastBool, , +// LIBND4J_TYPES, BOOL_TYPES); - -} -} \ No newline at end of file +} // namespace broadcast +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/broadcasting_int.hpp b/libnd4j/include/loops/cpu/broadcasting_int.hpp index 39b251594ecf..a70e6252c2a6 100644 --- a/libnd4j/include/loops/cpu/broadcasting_int.hpp +++ b/libnd4j/include/loops/cpu/broadcasting_int.hpp @@ -18,700 +18,759 @@ // @author raver119@gmail.com // -#include +#include +#include +#include #include #include +#include #include -#include -#include -#include using namespace simdOps; namespace functions { - namespace broadcast { - - template - void BroadcastInt::exec(const int opNum, - const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *z, const Nd4jLong *zShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *xTadShapeInfo, const Nd4jLong *xTadOffset, - const Nd4jLong *zTadShapeInfo, const Nd4jLong *zTadOffset, - uint64_t start, uint64_t stop) { - DISPATCH_BY_OPNUM_T(exec, PARAMS(x, - xShapeInfo, - y, - yShapeInfo, - z, - zShapeInfo, - dimension, - dimensionLength, - xTadShapeInfo, - xTadOffset, - zTadShapeInfo, - zTadOffset, start, stop), BROADCAST_INT_OPS); - } - - template - void BroadcastInt::exec(const int opNum, - const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *z, const Nd4jLong *zShapeInfo) { +namespace broadcast { - DISPATCH_BY_OPNUM_T(exec, PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo), BROADCAST_INT_OPS); - } +template +void BroadcastInt::exec( + const int opNum, const void *x, const Nd4jLong *xShapeInfo, const void *y, + const Nd4jLong *yShapeInfo, void *z, const Nd4jLong *zShapeInfo, + int *dimension, int dimensionLength, const Nd4jLong *xTadShapeInfo, + const Nd4jLong *xTadOffset, const Nd4jLong *zTadShapeInfo, + const Nd4jLong *zTadOffset, uint64_t start, uint64_t stop) { + DISPATCH_BY_OPNUM_T( + exec, + PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, + dimensionLength, xTadShapeInfo, xTadOffset, zTadShapeInfo, + zTadOffset, start, stop), + BROADCAST_INT_OPS); +} - template - void BroadcastInt::execInverse(const int opNum, - const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *z, const Nd4jLong *zShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *xTadShapeInfo, const Nd4jLong *xTadOffset, - const Nd4jLong *zTadShapeInfo, const Nd4jLong *zTadOffset, - uint64_t start, uint64_t stop) { - DISPATCH_BY_OPNUM_T(execInverse, PARAMS(x, - xShapeInfo, - y, - yShapeInfo, - z, - zShapeInfo, - dimension, - dimensionLength, - xTadShapeInfo, - xTadOffset, - zTadShapeInfo, - zTadOffset, start, stop), BROADCAST_INT_OPS); - } +template +void BroadcastInt::exec(const int opNum, const void *x, + const Nd4jLong *xShapeInfo, const void *y, + const Nd4jLong *yShapeInfo, void *z, + const Nd4jLong *zShapeInfo) { + DISPATCH_BY_OPNUM_T(exec, PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo), + BROADCAST_INT_OPS); +} - template - template - void BroadcastInt::exec(const void *vx, const Nd4jLong *xShapeInfo, - const void *vy, const Nd4jLong *yShapeInfo, - void *vz, const Nd4jLong *zShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *xTadShapeInfo, const Nd4jLong *xTadOffset, - const Nd4jLong *zTadShapeInfo, const Nd4jLong *zTadOffset, - uint64_t start, uint64_t stop) { - - auto x = reinterpret_cast(vx); - auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); - - //decompose in to several sub tads after - //moving all dimensions (in sorted order) - //to the back. - //permuted version of the x shape info for setting up the tad problem - auto xTadShapeShapeInfo = xTadShapeInfo; - auto tadOffsets = xTadOffset; - - if (xTadShapeInfo == nullptr || tadOffsets == nullptr) { - auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(xShapeInfo, dimension, dimensionLength); - - xTadShapeShapeInfo = const_cast(tadPack.primaryShapeInfo()); - tadOffsets = const_cast(tadPack.primaryOffsets()); - } - - //int *resultStride = shape::stride(xTadShapeShapeInfo); - unsigned int tadLength = shape::length(xTadShapeShapeInfo);//shape::length(xTadShapeShapeInfo); - unsigned int tads = shape::length(xShapeInfo) / tadLength; - - if (zTadShapeInfo == nullptr) { - zTadShapeInfo = xTadShapeShapeInfo; - zTadOffset = tadOffsets; - } - - auto lenZ = shape::length(zTadShapeInfo); - auto lenY = shape::length(yShapeInfo); - - int tadsPerThread = tads / TAD_THRESHOLD; - int threads = sd::math::nd4j_max(1, tadsPerThread); - threads = sd::math::nd4j_min(threads, sd::Environment::getInstance().maxThreads()); - - auto xEws = shape::elementWiseStride(xTadShapeShapeInfo); - auto yEws = shape::elementWiseStride(yShapeInfo); - auto zEws = shape::elementWiseStride(zTadShapeInfo); - - const sd::LoopKind::Kind kindOfLoop = sd::LoopKind::deduceKindOfLoopXYZ(xTadShapeShapeInfo, yShapeInfo, zTadShapeInfo); - - if (kindOfLoop == sd::LoopKind::EWS1) { - for (auto i = start; i < stop; i ++) { - auto oX = x + tadOffsets[i]; - auto oZ = z + zTadOffset[i]; - - PRAGMA_OMP_SIMD - for (unsigned int f = 0; f < tadLength; f++) - oZ[f] = OpType::op(oX[f], y[f]); - }; - } - else if(kindOfLoop == sd::LoopKind::EWSNONZERO) { - for (auto i = start; i < stop; i ++) { - auto oX = x + tadOffsets[i]; - auto oZ = z + zTadOffset[i]; - - PRAGMA_OMP_SIMD - for (unsigned int f = 0; f < tadLength; f++) - oZ[f * zEws] = OpType::op(oX[f * xEws], y[f * yEws]); - }; - } - else if(shape::haveSameShapeAndStrides(xTadShapeShapeInfo, yShapeInfo) && shape::haveSameShapeAndStrides(xTadShapeShapeInfo, zTadShapeInfo)) { - uint tadShapeShapeInfoCast[MAX_RANK]; - bool canCastX = sd::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, tadShapeShapeInfoCast); - - for (auto i = start; i < stop; i ++) { - auto oZ = z + zTadOffset[i]; - auto oX = x + tadOffsets[i]; - - PRAGMA_OMP_SIMD - for (unsigned int f = 0; f < tadLength; f++) { - auto offset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX); - oZ[offset] = OpType::op(oX[offset], y[offset]); - } - }; - } - else if(shape::haveSameShapeAndStrides(xTadShapeShapeInfo, yShapeInfo)) { - uint tadShapeShapeInfoCast[MAX_RANK]; - uint tadShapeInfoZCast[MAX_RANK]; - bool canCastX = sd::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, tadShapeShapeInfoCast); - bool canCastZ = sd::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast); - - for (auto i = start; i < stop; i ++) { - auto oZ = z + zTadOffset[i]; - auto oX = x + tadOffsets[i]; - - PRAGMA_OMP_SIMD - for (unsigned int f = 0; f < tadLength; f++) { - auto offset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX); - auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ); - oZ[zOffset] = OpType::op(oX[offset], y[offset]); - } - }; - } - else if(shape::haveSameShapeAndStrides(xTadShapeShapeInfo, zTadShapeInfo)) { - uint tadShapeShapeInfoCast[MAX_RANK]; - uint yShapeInfoCast[MAX_RANK]; - bool canCastX = sd::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, tadShapeShapeInfoCast); - bool canCastY = sd::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); - - for (auto i = start; i < stop; i ++) { - auto oZ = z + zTadOffset[i]; - auto oX = x + tadOffsets[i]; - - PRAGMA_OMP_SIMD - for (unsigned int f = 0; f < tadLength; f++) { - auto offset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX); - auto yOffset = shape::indexOffset(f, yShapeInfo, yShapeInfoCast, canCastY); - oZ[offset] = OpType::op(oX[offset], y[yOffset]); - } - }; - } - else if(shape::haveSameShapeAndStrides(yShapeInfo, zTadShapeInfo)) { - uint tadShapeShapeInfoCast[MAX_RANK]; - uint yShapeInfoCast[MAX_RANK]; - bool canCastX = sd::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, tadShapeShapeInfoCast); - bool canCastY = sd::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); - - for (auto i = start; i < stop; i ++) { - auto oZ = z + zTadOffset[i]; - auto oX = x + tadOffsets[i]; - - PRAGMA_OMP_SIMD - for (unsigned int f = 0; f < tadLength; f++) { - auto xOffset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX); - auto offset = shape::indexOffset(f, yShapeInfo, yShapeInfoCast, canCastY); - oZ[offset] = OpType::op(oX[xOffset], y[offset]); - } - }; - } - else { - uint tadShapeShapeInfoCast[MAX_RANK]; - uint tadShapeInfoZCast[MAX_RANK]; - uint yShapeInfoCast[MAX_RANK]; - bool canCastX = sd::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, tadShapeShapeInfoCast); - bool canCastY = sd::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); - bool canCastZ = sd::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast); - - for (auto i = start; i < stop; i ++) { - auto oZ = z + zTadOffset[i]; - auto oX = x + tadOffsets[i]; - - PRAGMA_OMP_SIMD - for (unsigned int f = 0; f < tadLength; f++) { - auto xOffset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX); - auto yOffset = shape::indexOffset(f, yShapeInfo, yShapeInfoCast, canCastY); - auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ); - oZ[zOffset] = OpType::op(oX[xOffset], y[yOffset]); - } - }; - } - } +template +void BroadcastInt::execInverse( + const int opNum, const void *x, const Nd4jLong *xShapeInfo, const void *y, + const Nd4jLong *yShapeInfo, void *z, const Nd4jLong *zShapeInfo, + int *dimension, int dimensionLength, const Nd4jLong *xTadShapeInfo, + const Nd4jLong *xTadOffset, const Nd4jLong *zTadShapeInfo, + const Nd4jLong *zTadOffset, uint64_t start, uint64_t stop) { + DISPATCH_BY_OPNUM_T( + execInverse, + PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, + dimensionLength, xTadShapeInfo, xTadOffset, zTadShapeInfo, + zTadOffset, start, stop), + BROADCAST_INT_OPS); +} +template +template +void BroadcastInt::exec(const void *vx, const Nd4jLong *xShapeInfo, + const void *vy, const Nd4jLong *yShapeInfo, void *vz, + const Nd4jLong *zShapeInfo, int *dimension, + int dimensionLength, const Nd4jLong *xTadShapeInfo, + const Nd4jLong *xTadOffset, + const Nd4jLong *zTadShapeInfo, + const Nd4jLong *zTadOffset, uint64_t start, + uint64_t stop) { + auto x = reinterpret_cast(vx); + auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + + // decompose in to several sub tads after + // moving all dimensions (in sorted order) + // to the back. + // permuted version of the x shape info for setting up the tad problem + auto xTadShapeShapeInfo = xTadShapeInfo; + auto tadOffsets = xTadOffset; + + if (xTadShapeInfo == nullptr || tadOffsets == nullptr) { + auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions( + xShapeInfo, dimension, dimensionLength); + + xTadShapeShapeInfo = const_cast(tadPack.primaryShapeInfo()); + tadOffsets = const_cast(tadPack.primaryOffsets()); + } + + // int *resultStride = shape::stride(xTadShapeShapeInfo); + unsigned int tadLength = + shape::length(xTadShapeShapeInfo); // shape::length(xTadShapeShapeInfo); + unsigned int tads = shape::length(xShapeInfo) / tadLength; + + if (zTadShapeInfo == nullptr) { + zTadShapeInfo = xTadShapeShapeInfo; + zTadOffset = tadOffsets; + } + + auto lenZ = shape::length(zTadShapeInfo); + auto lenY = shape::length(yShapeInfo); + + int tadsPerThread = tads / TAD_THRESHOLD; + int threads = sd::math::nd4j_max(1, tadsPerThread); + threads = sd::math::nd4j_min( + threads, sd::Environment::getInstance().maxThreads()); + + auto xEws = shape::elementWiseStride(xTadShapeShapeInfo); + auto yEws = shape::elementWiseStride(yShapeInfo); + auto zEws = shape::elementWiseStride(zTadShapeInfo); + + const sd::LoopKind::Kind kindOfLoop = sd::LoopKind::deduceKindOfLoopXYZ( + xTadShapeShapeInfo, yShapeInfo, zTadShapeInfo); + + if (kindOfLoop == sd::LoopKind::EWS1) { + for (auto i = start; i < stop; i++) { + auto oX = x + tadOffsets[i]; + auto oZ = z + zTadOffset[i]; + + PRAGMA_OMP_SIMD + for (unsigned int f = 0; f < tadLength; f++) + oZ[f] = OpType::op(oX[f], y[f]); + }; + } else if (kindOfLoop == sd::LoopKind::EWSNONZERO) { + for (auto i = start; i < stop; i++) { + auto oX = x + tadOffsets[i]; + auto oZ = z + zTadOffset[i]; + + PRAGMA_OMP_SIMD + for (unsigned int f = 0; f < tadLength; f++) + oZ[f * zEws] = OpType::op(oX[f * xEws], y[f * yEws]); + }; + } else if (shape::haveSameShapeAndStrides(xTadShapeShapeInfo, yShapeInfo) && + shape::haveSameShapeAndStrides(xTadShapeShapeInfo, + zTadShapeInfo)) { + uint tadShapeShapeInfoCast[MAX_RANK]; + bool canCastX = sd::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, + tadShapeShapeInfoCast); + + for (auto i = start; i < stop; i++) { + auto oZ = z + zTadOffset[i]; + auto oX = x + tadOffsets[i]; + + PRAGMA_OMP_SIMD + for (unsigned int f = 0; f < tadLength; f++) { + auto offset = shape::indexOffset(f, xTadShapeShapeInfo, + tadShapeShapeInfoCast, canCastX); + oZ[offset] = OpType::op(oX[offset], y[offset]); + } + }; + } else if (shape::haveSameShapeAndStrides(xTadShapeShapeInfo, yShapeInfo)) { + uint tadShapeShapeInfoCast[MAX_RANK]; + uint tadShapeInfoZCast[MAX_RANK]; + bool canCastX = sd::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, + tadShapeShapeInfoCast); + bool canCastZ = + sd::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast); + + for (auto i = start; i < stop; i++) { + auto oZ = z + zTadOffset[i]; + auto oX = x + tadOffsets[i]; + + PRAGMA_OMP_SIMD + for (unsigned int f = 0; f < tadLength; f++) { + auto offset = shape::indexOffset(f, xTadShapeShapeInfo, + tadShapeShapeInfoCast, canCastX); + auto zOffset = + shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ); + oZ[zOffset] = OpType::op(oX[offset], y[offset]); + } + }; + } else if (shape::haveSameShapeAndStrides(xTadShapeShapeInfo, + zTadShapeInfo)) { + uint tadShapeShapeInfoCast[MAX_RANK]; + uint yShapeInfoCast[MAX_RANK]; + bool canCastX = sd::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, + tadShapeShapeInfoCast); + bool canCastY = + sd::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); + + for (auto i = start; i < stop; i++) { + auto oZ = z + zTadOffset[i]; + auto oX = x + tadOffsets[i]; + + PRAGMA_OMP_SIMD + for (unsigned int f = 0; f < tadLength; f++) { + auto offset = shape::indexOffset(f, xTadShapeShapeInfo, + tadShapeShapeInfoCast, canCastX); + auto yOffset = + shape::indexOffset(f, yShapeInfo, yShapeInfoCast, canCastY); + oZ[offset] = OpType::op(oX[offset], y[yOffset]); + } + }; + } else if (shape::haveSameShapeAndStrides(yShapeInfo, zTadShapeInfo)) { + uint tadShapeShapeInfoCast[MAX_RANK]; + uint yShapeInfoCast[MAX_RANK]; + bool canCastX = sd::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, + tadShapeShapeInfoCast); + bool canCastY = + sd::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); + + for (auto i = start; i < stop; i++) { + auto oZ = z + zTadOffset[i]; + auto oX = x + tadOffsets[i]; + + PRAGMA_OMP_SIMD + for (unsigned int f = 0; f < tadLength; f++) { + auto xOffset = shape::indexOffset(f, xTadShapeShapeInfo, + tadShapeShapeInfoCast, canCastX); + auto offset = + shape::indexOffset(f, yShapeInfo, yShapeInfoCast, canCastY); + oZ[offset] = OpType::op(oX[xOffset], y[offset]); + } + }; + } else { + uint tadShapeShapeInfoCast[MAX_RANK]; + uint tadShapeInfoZCast[MAX_RANK]; + uint yShapeInfoCast[MAX_RANK]; + bool canCastX = sd::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, + tadShapeShapeInfoCast); + bool canCastY = + sd::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); + bool canCastZ = + sd::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast); + + for (auto i = start; i < stop; i++) { + auto oZ = z + zTadOffset[i]; + auto oX = x + tadOffsets[i]; + + PRAGMA_OMP_SIMD + for (unsigned int f = 0; f < tadLength; f++) { + auto xOffset = shape::indexOffset(f, xTadShapeShapeInfo, + tadShapeShapeInfoCast, canCastX); + auto yOffset = + shape::indexOffset(f, yShapeInfo, yShapeInfoCast, canCastY); + auto zOffset = + shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ); + oZ[zOffset] = OpType::op(oX[xOffset], y[yOffset]); + } + }; + } +} - template - template - void BroadcastInt::execInverse(const void *vx, const Nd4jLong *xShapeInfo, - const void *vy, const Nd4jLong *yShapeInfo, - void *vz, const Nd4jLong *zShapeInfo, - int *dimension, const int dimensionLength, - const Nd4jLong *yTadShapeInfo, const Nd4jLong *yTadOffset, - const Nd4jLong *zTadShapeInfo, const Nd4jLong *zTadOffset, - uint64_t start, uint64_t stop) { - - auto x = reinterpret_cast(vx); - auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); - - //decompose in to several sub tads after - //moving all dimensions (in sorted order) - //to the back. - //permuted version of the x shape info for setting up the tad problem - auto yTadShapeShapeInfo = yTadShapeInfo; - auto tadOffsets = yTadOffset; - - if (yTadShapeInfo == nullptr || tadOffsets == nullptr) { - auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(yShapeInfo, dimension, dimensionLength); - - yTadShapeShapeInfo = const_cast(tadPack.primaryShapeInfo()); - tadOffsets = const_cast(tadPack.primaryOffsets()); - } - - //int *resultStride = shape::stride(yTadShapeShapeInfo); - unsigned int tadLength = shape::length(yTadShapeShapeInfo); - unsigned int tads = shape::length(yShapeInfo) / tadLength; - - if (zTadShapeInfo == nullptr) { - zTadShapeInfo = yTadShapeShapeInfo; - zTadOffset = tadOffsets; - } - - auto lenZ = shape::length(zTadShapeInfo); - auto lenX = shape::length(xShapeInfo); - - int tadsPerThread = tads / TAD_THRESHOLD; - int threads = sd::math::nd4j_max(1, tadsPerThread); - threads = sd::math::nd4j_min(threads, sd::Environment::getInstance().maxThreads()); - - auto yEws = shape::elementWiseStride(yTadShapeShapeInfo); - auto xEws = shape::elementWiseStride(xShapeInfo); - auto zEws = shape::elementWiseStride(zTadShapeInfo); - - const sd::LoopKind::Kind kindOfLoop = sd::LoopKind::deduceKindOfLoopXYZ(yTadShapeShapeInfo, xShapeInfo, zTadShapeInfo); - - if (kindOfLoop == sd::LoopKind::EWS1) { - for (auto i = start; i < stop; i ++) { - auto oY = y + tadOffsets[i]; - auto oZ = z + zTadOffset[i]; - - PRAGMA_OMP_SIMD - for (unsigned int f = 0; f < tadLength; f++) - oZ[f] = OpType::op(x[f], oY[f]); - }; - } - else if(kindOfLoop == sd::LoopKind::EWSNONZERO) { - for (auto i = start; i < stop; i ++) { - auto oY = y + tadOffsets[i]; - auto oZ = z + zTadOffset[i]; - - PRAGMA_OMP_SIMD - for (uint f = 0; f < tadLength; f++) - oZ[f * zEws] = OpType::op(x[f * xEws], oY[f * yEws]); - }; - } - else if(shape::haveSameShapeAndStrides(yTadShapeShapeInfo, xShapeInfo) && shape::haveSameShapeAndStrides(yTadShapeShapeInfo, zTadShapeInfo)) { - uint tadShapeShapeInfoCast[MAX_RANK]; - bool canCastY = sd::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast); - - for (auto i = start; i < stop; i ++) { - auto oY = y + tadOffsets[i]; - auto oZ = z + zTadOffset[i]; - - PRAGMA_OMP_SIMD - for (uint f = 0; f < tadLength; f++) { - auto offset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY); - oZ[offset] = OpType::op(x[offset], oY[offset]); - } - }; - } - else if(shape::haveSameShapeAndStrides(yTadShapeShapeInfo, xShapeInfo)) { - - uint tadShapeShapeInfoCast[MAX_RANK]; - uint tadShapeInfoZCast[MAX_RANK]; - bool canCastY = sd::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast); - bool canCastZ = sd::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast); - - for (auto i = start; i < stop; i ++) { - auto oZ = z + zTadOffset[i]; - auto oY = y + tadOffsets[i]; - - for (uint f = 0; f < tadLength; f++) { - auto offset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY); - auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ); - oZ[zOffset] = OpType::op(x[offset], oY[offset]); - } - }; - } - else if(shape::haveSameShapeAndStrides(yTadShapeShapeInfo, zTadShapeInfo)) { - uint tadShapeShapeInfoCast[MAX_RANK]; - uint xShapeInfoCast[MAX_RANK]; - bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); - bool canCastY = sd::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast); - - for (auto i = start; i < stop; i ++) { - auto oZ = z + zTadOffset[i]; - auto oY = y + tadOffsets[i]; - - PRAGMA_OMP_SIMD - for (uint f = 0; f < tadLength; f++) { - auto offset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY); - auto xOffset = shape::indexOffset(f, xShapeInfo, xShapeInfoCast, canCastX); - oZ[offset] = OpType::op(x[xOffset], oY[offset]); - } - }; - } - else if(shape::haveSameShapeAndStrides(xShapeInfo, zTadShapeInfo)) { - uint tadShapeShapeInfoCast[MAX_RANK]; - uint xShapeInfoCast[MAX_RANK]; - bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); - bool canCastY = sd::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast); - - for (auto i = start; i < stop; i ++) { - auto oZ = z + zTadOffset[i]; - auto oY = y + tadOffsets[i]; - - PRAGMA_OMP_SIMD - for (uint f = 0; f < tadLength; f++) { - auto yOffset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY); - auto offset = shape::indexOffset(f, xShapeInfo, xShapeInfoCast, canCastX); - oZ[offset] = OpType::op(x[offset], oY[yOffset]); - } - }; - } - else { - uint xShapeInfoCast[MAX_RANK]; - uint tadShapeShapeInfoCast[MAX_RANK]; - uint tadShapeInfoZCast[MAX_RANK]; - bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); - bool canCastY = sd::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast); - bool canCastZ = sd::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast); - - for (auto i = start; i < stop; i ++) { - auto oZ = z + zTadOffset[i]; - auto oY = y + tadOffsets[i]; - - PRAGMA_OMP_SIMD - for (uint f = 0; f < tadLength; f++) { - auto xOffset = shape::indexOffset(f, xShapeInfo, xShapeInfoCast, canCastX); - auto yOffset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY); - auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ); - oZ[zOffset] = OpType::op(x[xOffset], oY[yOffset]); - } - }; - } - } +template +template +void BroadcastInt::execInverse( + const void *vx, const Nd4jLong *xShapeInfo, const void *vy, + const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo, + int *dimension, const int dimensionLength, const Nd4jLong *yTadShapeInfo, + const Nd4jLong *yTadOffset, const Nd4jLong *zTadShapeInfo, + const Nd4jLong *zTadOffset, uint64_t start, uint64_t stop) { + auto x = reinterpret_cast(vx); + auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + + // decompose in to several sub tads after + // moving all dimensions (in sorted order) + // to the back. + // permuted version of the x shape info for setting up the tad problem + auto yTadShapeShapeInfo = yTadShapeInfo; + auto tadOffsets = yTadOffset; + + if (yTadShapeInfo == nullptr || tadOffsets == nullptr) { + auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions( + yShapeInfo, dimension, dimensionLength); + + yTadShapeShapeInfo = const_cast(tadPack.primaryShapeInfo()); + tadOffsets = const_cast(tadPack.primaryOffsets()); + } + + // int *resultStride = shape::stride(yTadShapeShapeInfo); + unsigned int tadLength = shape::length(yTadShapeShapeInfo); + unsigned int tads = shape::length(yShapeInfo) / tadLength; + + if (zTadShapeInfo == nullptr) { + zTadShapeInfo = yTadShapeShapeInfo; + zTadOffset = tadOffsets; + } + + auto lenZ = shape::length(zTadShapeInfo); + auto lenX = shape::length(xShapeInfo); + + int tadsPerThread = tads / TAD_THRESHOLD; + int threads = sd::math::nd4j_max(1, tadsPerThread); + threads = sd::math::nd4j_min( + threads, sd::Environment::getInstance().maxThreads()); + + auto yEws = shape::elementWiseStride(yTadShapeShapeInfo); + auto xEws = shape::elementWiseStride(xShapeInfo); + auto zEws = shape::elementWiseStride(zTadShapeInfo); + + const sd::LoopKind::Kind kindOfLoop = sd::LoopKind::deduceKindOfLoopXYZ( + yTadShapeShapeInfo, xShapeInfo, zTadShapeInfo); + + if (kindOfLoop == sd::LoopKind::EWS1) { + for (auto i = start; i < stop; i++) { + auto oY = y + tadOffsets[i]; + auto oZ = z + zTadOffset[i]; + + PRAGMA_OMP_SIMD + for (unsigned int f = 0; f < tadLength; f++) + oZ[f] = OpType::op(x[f], oY[f]); + }; + } else if (kindOfLoop == sd::LoopKind::EWSNONZERO) { + for (auto i = start; i < stop; i++) { + auto oY = y + tadOffsets[i]; + auto oZ = z + zTadOffset[i]; + + PRAGMA_OMP_SIMD + for (uint f = 0; f < tadLength; f++) + oZ[f * zEws] = OpType::op(x[f * xEws], oY[f * yEws]); + }; + } else if (shape::haveSameShapeAndStrides(yTadShapeShapeInfo, xShapeInfo) && + shape::haveSameShapeAndStrides(yTadShapeShapeInfo, + zTadShapeInfo)) { + uint tadShapeShapeInfoCast[MAX_RANK]; + bool canCastY = sd::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, + tadShapeShapeInfoCast); + + for (auto i = start; i < stop; i++) { + auto oY = y + tadOffsets[i]; + auto oZ = z + zTadOffset[i]; + + PRAGMA_OMP_SIMD + for (uint f = 0; f < tadLength; f++) { + auto offset = shape::indexOffset(f, yTadShapeShapeInfo, + tadShapeShapeInfoCast, canCastY); + oZ[offset] = OpType::op(x[offset], oY[offset]); + } + }; + } else if (shape::haveSameShapeAndStrides(yTadShapeShapeInfo, xShapeInfo)) { + uint tadShapeShapeInfoCast[MAX_RANK]; + uint tadShapeInfoZCast[MAX_RANK]; + bool canCastY = sd::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, + tadShapeShapeInfoCast); + bool canCastZ = + sd::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast); + + for (auto i = start; i < stop; i++) { + auto oZ = z + zTadOffset[i]; + auto oY = y + tadOffsets[i]; + + for (uint f = 0; f < tadLength; f++) { + auto offset = shape::indexOffset(f, yTadShapeShapeInfo, + tadShapeShapeInfoCast, canCastY); + auto zOffset = + shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ); + oZ[zOffset] = OpType::op(x[offset], oY[offset]); + } + }; + } else if (shape::haveSameShapeAndStrides(yTadShapeShapeInfo, + zTadShapeInfo)) { + uint tadShapeShapeInfoCast[MAX_RANK]; + uint xShapeInfoCast[MAX_RANK]; + bool canCastX = + sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + bool canCastY = sd::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, + tadShapeShapeInfoCast); + + for (auto i = start; i < stop; i++) { + auto oZ = z + zTadOffset[i]; + auto oY = y + tadOffsets[i]; + + PRAGMA_OMP_SIMD + for (uint f = 0; f < tadLength; f++) { + auto offset = shape::indexOffset(f, yTadShapeShapeInfo, + tadShapeShapeInfoCast, canCastY); + auto xOffset = + shape::indexOffset(f, xShapeInfo, xShapeInfoCast, canCastX); + oZ[offset] = OpType::op(x[xOffset], oY[offset]); + } + }; + } else if (shape::haveSameShapeAndStrides(xShapeInfo, zTadShapeInfo)) { + uint tadShapeShapeInfoCast[MAX_RANK]; + uint xShapeInfoCast[MAX_RANK]; + bool canCastX = + sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + bool canCastY = sd::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, + tadShapeShapeInfoCast); + + for (auto i = start; i < stop; i++) { + auto oZ = z + zTadOffset[i]; + auto oY = y + tadOffsets[i]; + + PRAGMA_OMP_SIMD + for (uint f = 0; f < tadLength; f++) { + auto yOffset = shape::indexOffset(f, yTadShapeShapeInfo, + tadShapeShapeInfoCast, canCastY); + auto offset = + shape::indexOffset(f, xShapeInfo, xShapeInfoCast, canCastX); + oZ[offset] = OpType::op(x[offset], oY[yOffset]); + } + }; + } else { + uint xShapeInfoCast[MAX_RANK]; + uint tadShapeShapeInfoCast[MAX_RANK]; + uint tadShapeInfoZCast[MAX_RANK]; + bool canCastX = + sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + bool canCastY = sd::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, + tadShapeShapeInfoCast); + bool canCastZ = + sd::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast); + + for (auto i = start; i < stop; i++) { + auto oZ = z + zTadOffset[i]; + auto oY = y + tadOffsets[i]; + + PRAGMA_OMP_SIMD + for (uint f = 0; f < tadLength; f++) { + auto xOffset = + shape::indexOffset(f, xShapeInfo, xShapeInfoCast, canCastX); + auto yOffset = shape::indexOffset(f, yTadShapeShapeInfo, + tadShapeShapeInfoCast, canCastY); + auto zOffset = + shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ); + oZ[zOffset] = OpType::op(x[xOffset], oY[yOffset]); + } + }; + } +} //////////////////////////////////////////////////////////////////////// template -static void execRank1(const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, X* z, const Nd4jLong *zShapeInfo) { - - uint zAxis0 = shape::sizeAt(zShapeInfo, 0); - Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, 0); - Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, 0); - Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, 0); - - auto func = PRAGMA_THREADS_FOR{ - - if(zStrd0 == 1 && xStrd0 == 1 && yStrd0 == 0) { - for (auto i0 = start; i0 < stop; ++i0) - z[i0] = OpType::op(x[i0], *y); - } - else if(zStrd0 == 1 && xStrd0 == 0 && yStrd0 == 1) { - for (auto i0 = start; i0 < stop; ++i0) - z[i0] = OpType::op(*x, y[i0]); - } - else if(zStrd0 == 1 && xStrd0 == 1 && yStrd0 == 1) { - for (auto i0 = start; i0 < stop; ++i0) - z[i0] = OpType::op(x[i0], y[i0]); - } - else { - for (auto i0 = start; i0 < stop; ++i0) - z[i0 * zStrd0] = OpType::op(x[i0 * xStrd0], y[i0 * yStrd0]); - } - }; - samediff::Threads::parallel_tad(func, 0, zAxis0); +static void execRank1(const X *x, const Nd4jLong *xShapeInfo, const X *y, + const Nd4jLong *yShapeInfo, X *z, + const Nd4jLong *zShapeInfo) { + uint zAxis0 = shape::sizeAt(zShapeInfo, 0); + Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, 0); + Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, 0); + Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, 0); + + auto func = PRAGMA_THREADS_FOR { + if (zStrd0 == 1 && xStrd0 == 1 && yStrd0 == 0) { + for (auto i0 = start; i0 < stop; ++i0) z[i0] = OpType::op(x[i0], *y); + } else if (zStrd0 == 1 && xStrd0 == 0 && yStrd0 == 1) { + for (auto i0 = start; i0 < stop; ++i0) z[i0] = OpType::op(*x, y[i0]); + } else if (zStrd0 == 1 && xStrd0 == 1 && yStrd0 == 1) { + for (auto i0 = start; i0 < stop; ++i0) z[i0] = OpType::op(x[i0], y[i0]); + } else { + for (auto i0 = start; i0 < stop; ++i0) + z[i0 * zStrd0] = OpType::op(x[i0 * xStrd0], y[i0 * yStrd0]); + } + }; + samediff::Threads::parallel_tad(func, 0, zAxis0); } //////////////////////////////////////////////////////////////////////// template -static void execRank2(const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, X* z, const Nd4jLong *zShapeInfo) { - - uint zAxis0 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1); - Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1); - Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1); - Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1); - - uint zAxis1 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0); - Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0); - Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0); - Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0); - - auto func = PRAGMA_THREADS_FOR{ - - for (auto i0 = start; i0 < stop; ++i0) { - - auto x0 = x + i0 * xStrd0; - auto y0 = y + i0 * yStrd0; - auto z0 = z + i0 * zStrd0; - - if(zStrd1 == 1 && xStrd1 == 1 && yStrd1 == 0) - for (uint i1 = 0; i1 < zAxis1; ++i1) - z0[i1] = OpType::op(x0[i1], *y0); - else if(zStrd1 == 1 && xStrd1 == 0 && yStrd1 == 1) - for (uint i1 = 0; i1 < zAxis1; ++i1) - z0[i1] = OpType::op(*x0, y0[i1]); - else if(zStrd1 == 1 && xStrd1 == 1 && yStrd1 == 1) - for (uint i1 = 0; i1 < zAxis1; ++i1) - z0[i1] = OpType::op(x0[i1], y0[i1]); - else - for (uint i1 = 0; i1 < zAxis1; ++i1) - z0[i1 * zStrd1] = OpType::op(x0[i1 * xStrd1], y0[i1 * yStrd1]); - } - }; +static void execRank2(const X *x, const Nd4jLong *xShapeInfo, const X *y, + const Nd4jLong *yShapeInfo, X *z, + const Nd4jLong *zShapeInfo) { + uint zAxis0 = + shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1); + Nd4jLong xStrd0 = + shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1); + Nd4jLong yStrd0 = + shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1); + Nd4jLong zStrd0 = + shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1); + + uint zAxis1 = + shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0); + Nd4jLong xStrd1 = + shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0); + Nd4jLong yStrd1 = + shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0); + Nd4jLong zStrd1 = + shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0); + + auto func = PRAGMA_THREADS_FOR { + for (auto i0 = start; i0 < stop; ++i0) { + auto x0 = x + i0 * xStrd0; + auto y0 = y + i0 * yStrd0; + auto z0 = z + i0 * zStrd0; + + if (zStrd1 == 1 && xStrd1 == 1 && yStrd1 == 0) + for (uint i1 = 0; i1 < zAxis1; ++i1) z0[i1] = OpType::op(x0[i1], *y0); + else if (zStrd1 == 1 && xStrd1 == 0 && yStrd1 == 1) + for (uint i1 = 0; i1 < zAxis1; ++i1) z0[i1] = OpType::op(*x0, y0[i1]); + else if (zStrd1 == 1 && xStrd1 == 1 && yStrd1 == 1) + for (uint i1 = 0; i1 < zAxis1; ++i1) + z0[i1] = OpType::op(x0[i1], y0[i1]); + else + for (uint i1 = 0; i1 < zAxis1; ++i1) + z0[i1 * zStrd1] = OpType::op(x0[i1 * xStrd1], y0[i1 * yStrd1]); + } + }; - samediff::Threads::parallel_tad(func, 0, zAxis0); + samediff::Threads::parallel_tad(func, 0, zAxis0); } //////////////////////////////////////////////////////////////////////// template -static void execRank3(const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, X* z, const Nd4jLong *zShapeInfo) { - - uint zAxis0 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2); - Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2); - Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2); - Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2); - - uint zAxis1 = shape::sizeAt(zShapeInfo, 1); - Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, 1); - Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, 1); - Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, 1); - - uint zAxis2 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0); - Nd4jLong xStrd2 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0); - Nd4jLong yStrd2 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0); - Nd4jLong zStrd2 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0); - - auto func = PRAGMA_THREADS_FOR_2D { - - for (auto i0 = start_x; i0 < stop_x; ++i0) { - for (auto i1 = start_y; i1 < stop_y; ++i1) { - - auto x1 = x + i0 * xStrd0 + i1 * xStrd1; - auto y1 = y + i0 * yStrd0 + i1 * yStrd1; - auto z1 = z + i0 * zStrd0 + i1 * zStrd1; - - if(zStrd2 == 1 && xStrd2 == 1 && yStrd2 == 0) - for (uint i2 = 0; i2 < zAxis2; ++i2) - z1[i2] = OpType::op(x1[i2], *y1); - else if(zStrd2 == 1 && xStrd2 == 0 && yStrd2 == 1) - for (uint i2 = 0; i2 < zAxis2; ++i2) - z1[i2] = OpType::op(*x1, y1[i2]); - else if(zStrd2 == 1 && xStrd2 == 1 && yStrd2 == 1) - for (uint i2 = 0; i2 < zAxis2; ++i2) - z1[i2] = OpType::op(x1[i2], y1[i2]); - else - for (uint i2 = 0; i2 < zAxis2; ++i2) - z1[i2 * zStrd2] = OpType::op(x1[i2 * xStrd2], y1[i2 * yStrd2]); - } - } - }; +static void execRank3(const X *x, const Nd4jLong *xShapeInfo, const X *y, + const Nd4jLong *yShapeInfo, X *z, + const Nd4jLong *zShapeInfo) { + uint zAxis0 = + shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2); + Nd4jLong xStrd0 = + shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2); + Nd4jLong yStrd0 = + shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2); + Nd4jLong zStrd0 = + shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2); + + uint zAxis1 = shape::sizeAt(zShapeInfo, 1); + Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, 1); + Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, 1); + Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, 1); + + uint zAxis2 = + shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0); + Nd4jLong xStrd2 = + shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0); + Nd4jLong yStrd2 = + shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0); + Nd4jLong zStrd2 = + shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0); + + auto func = PRAGMA_THREADS_FOR_2D { + for (auto i0 = start_x; i0 < stop_x; ++i0) { + for (auto i1 = start_y; i1 < stop_y; ++i1) { + auto x1 = x + i0 * xStrd0 + i1 * xStrd1; + auto y1 = y + i0 * yStrd0 + i1 * yStrd1; + auto z1 = z + i0 * zStrd0 + i1 * zStrd1; + + if (zStrd2 == 1 && xStrd2 == 1 && yStrd2 == 0) + for (uint i2 = 0; i2 < zAxis2; ++i2) z1[i2] = OpType::op(x1[i2], *y1); + else if (zStrd2 == 1 && xStrd2 == 0 && yStrd2 == 1) + for (uint i2 = 0; i2 < zAxis2; ++i2) z1[i2] = OpType::op(*x1, y1[i2]); + else if (zStrd2 == 1 && xStrd2 == 1 && yStrd2 == 1) + for (uint i2 = 0; i2 < zAxis2; ++i2) + z1[i2] = OpType::op(x1[i2], y1[i2]); + else + for (uint i2 = 0; i2 < zAxis2; ++i2) + z1[i2 * zStrd2] = OpType::op(x1[i2 * xStrd2], y1[i2 * yStrd2]); + } + } + }; - samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1); + samediff::Threads::parallel_for(func, 0, zAxis0, 1, 0, zAxis1, 1); } //////////////////////////////////////////////////////////////////////// template -static void execRank4(const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, X* z, const Nd4jLong *zShapeInfo) { - - uint zAxis0 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3); - Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3); - Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3); - Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3); - - uint zAxis1 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2); - Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2); - Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2); - Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2); - - uint zAxis2 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1); - Nd4jLong xStrd2 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1); - Nd4jLong yStrd2 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1); - Nd4jLong zStrd2 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1); - - uint zAxis3 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0); - Nd4jLong xStrd3 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0); - Nd4jLong yStrd3 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0); - Nd4jLong zStrd3 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0); - - auto func = PRAGMA_THREADS_FOR_3D { - - for (auto i0 = start_x; i0 < stop_x; ++i0) { - for (auto i1 = start_y; i1 < stop_y; ++i1) { - for (auto i2 = start_z; i2 < stop_z; ++i2) { - - auto x2 = x + i0 * xStrd0 + i1 * xStrd1 + i2 * xStrd2; - auto y2 = y + i0 * yStrd0 + i1 * yStrd1 + i2 * yStrd2; - auto z2 = z + i0 * zStrd0 + i1 * zStrd1 + i2 * zStrd2; - - if(zStrd3 == 1 && xStrd3 == 1 && yStrd3 == 0) - for (uint i3 = 0; i3 < zAxis3; ++i3) - z2[i3] = OpType::op(x2[i3], *y2); - else if(zStrd3 == 1 && xStrd3 == 0 && yStrd3 == 1) - for (uint i3 = 0; i3 < zAxis3; ++i3) - z2[i3] = OpType::op(*x2, y2[i3]); - else if(zStrd3 == 1 && xStrd3 == 1 && yStrd3 == 1) - for (uint i3 = 0; i3 < zAxis3; ++i3) - z2[i3] = OpType::op(x2[i3], y2[i3]); - else - for (uint i3 = 0; i3 < zAxis3; ++i3) - z2[i3 * zStrd3] = OpType::op(x2[i3 * xStrd3], y2[i3 * yStrd3]); - } - } +static void execRank4(const X *x, const Nd4jLong *xShapeInfo, const X *y, + const Nd4jLong *yShapeInfo, X *z, + const Nd4jLong *zShapeInfo) { + uint zAxis0 = + shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3); + Nd4jLong xStrd0 = + shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3); + Nd4jLong yStrd0 = + shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3); + Nd4jLong zStrd0 = + shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3); + + uint zAxis1 = + shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2); + Nd4jLong xStrd1 = + shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2); + Nd4jLong yStrd1 = + shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2); + Nd4jLong zStrd1 = + shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2); + + uint zAxis2 = + shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1); + Nd4jLong xStrd2 = + shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1); + Nd4jLong yStrd2 = + shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1); + Nd4jLong zStrd2 = + shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1); + + uint zAxis3 = + shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0); + Nd4jLong xStrd3 = + shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0); + Nd4jLong yStrd3 = + shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0); + Nd4jLong zStrd3 = + shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0); + + auto func = PRAGMA_THREADS_FOR_3D { + for (auto i0 = start_x; i0 < stop_x; ++i0) { + for (auto i1 = start_y; i1 < stop_y; ++i1) { + for (auto i2 = start_z; i2 < stop_z; ++i2) { + auto x2 = x + i0 * xStrd0 + i1 * xStrd1 + i2 * xStrd2; + auto y2 = y + i0 * yStrd0 + i1 * yStrd1 + i2 * yStrd2; + auto z2 = z + i0 * zStrd0 + i1 * zStrd1 + i2 * zStrd2; + + if (zStrd3 == 1 && xStrd3 == 1 && yStrd3 == 0) + for (uint i3 = 0; i3 < zAxis3; ++i3) + z2[i3] = OpType::op(x2[i3], *y2); + else if (zStrd3 == 1 && xStrd3 == 0 && yStrd3 == 1) + for (uint i3 = 0; i3 < zAxis3; ++i3) + z2[i3] = OpType::op(*x2, y2[i3]); + else if (zStrd3 == 1 && xStrd3 == 1 && yStrd3 == 1) + for (uint i3 = 0; i3 < zAxis3; ++i3) + z2[i3] = OpType::op(x2[i3], y2[i3]); + else + for (uint i3 = 0; i3 < zAxis3; ++i3) + z2[i3 * zStrd3] = OpType::op(x2[i3 * xStrd3], y2[i3 * yStrd3]); } - }; + } + } + }; - samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1, 0,zAxis2,1); + samediff::Threads::parallel_for(func, 0, zAxis0, 1, 0, zAxis1, 1, 0, zAxis2, + 1); } //////////////////////////////////////////////////////////////////////// template -static void execRank5(const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, X* z, const Nd4jLong *zShapeInfo) { - - uint zAxis0 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4); - Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4); - Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4); - Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4); - - uint zAxis1 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3); - Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3); - Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3); - Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3); - - uint zAxis2 = shape::sizeAt(zShapeInfo, 2); - Nd4jLong xStrd2 = shape::strideAt(xShapeInfo, 2); - Nd4jLong yStrd2 = shape::strideAt(yShapeInfo, 2); - Nd4jLong zStrd2 = shape::strideAt(zShapeInfo, 2); - - uint zAxis3 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1); - Nd4jLong xStrd3 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1); - Nd4jLong yStrd3 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1); - Nd4jLong zStrd3 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1); - - uint zAxis4 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0); - Nd4jLong xStrd4 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0); - Nd4jLong yStrd4 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0); - Nd4jLong zStrd4 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0); - - auto func = PRAGMA_THREADS_FOR_3D { - - for (auto i0 = start_x; i0 < stop_x; ++i0) { - for (auto i1 = start_y; i1 < stop_y; ++i1) { - for (auto i2 = start_z; i2 < stop_z; ++i2) { - for (uint i3 = 0; i3 < zAxis3; ++i3) { - - auto x3 = x + i0 * xStrd0 + i1 * xStrd1 + i2 * xStrd2 + i3 * xStrd3; - auto y3 = y + i0 * yStrd0 + i1 * yStrd1 + i2 * yStrd2 + i3 * yStrd3; - auto z3 = z + i0 * zStrd0 + i1 * zStrd1 + i2 * zStrd2 + i3 * zStrd3; - - if(zStrd4 == 1 && xStrd4 == 1 && yStrd4 == 0) - for (uint i4 = 0; i4 < zAxis4; ++i4) - z3[i4] = OpType::op(x3[i4], *y3); - else if(zStrd4 == 1 && xStrd4 == 0 && yStrd4 == 1) - for (uint i4 = 0; i4 < zAxis4; ++i4) - z3[i4] = OpType::op(*x3, y3[i4]); - else if(zStrd4 == 1 && xStrd4 == 1 && yStrd4 == 1) - for (uint i4 = 0; i4 < zAxis4; ++i4) - z3[i4] = OpType::op(x3[i4], y3[i4]); - else - for (uint i4 = 0; i4 < zAxis4; ++i4) - z3[i4 * zStrd4] = OpType::op(x3[i4 * xStrd4], y3[i4 * yStrd4]); - } - } - } +static void execRank5(const X *x, const Nd4jLong *xShapeInfo, const X *y, + const Nd4jLong *yShapeInfo, X *z, + const Nd4jLong *zShapeInfo) { + uint zAxis0 = + shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4); + Nd4jLong xStrd0 = + shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4); + Nd4jLong yStrd0 = + shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4); + Nd4jLong zStrd0 = + shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4); + + uint zAxis1 = + shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3); + Nd4jLong xStrd1 = + shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3); + Nd4jLong yStrd1 = + shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3); + Nd4jLong zStrd1 = + shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3); + + uint zAxis2 = shape::sizeAt(zShapeInfo, 2); + Nd4jLong xStrd2 = shape::strideAt(xShapeInfo, 2); + Nd4jLong yStrd2 = shape::strideAt(yShapeInfo, 2); + Nd4jLong zStrd2 = shape::strideAt(zShapeInfo, 2); + + uint zAxis3 = + shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1); + Nd4jLong xStrd3 = + shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1); + Nd4jLong yStrd3 = + shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1); + Nd4jLong zStrd3 = + shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1); + + uint zAxis4 = + shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0); + Nd4jLong xStrd4 = + shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0); + Nd4jLong yStrd4 = + shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0); + Nd4jLong zStrd4 = + shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0); + + auto func = PRAGMA_THREADS_FOR_3D { + for (auto i0 = start_x; i0 < stop_x; ++i0) { + for (auto i1 = start_y; i1 < stop_y; ++i1) { + for (auto i2 = start_z; i2 < stop_z; ++i2) { + for (uint i3 = 0; i3 < zAxis3; ++i3) { + auto x3 = x + i0 * xStrd0 + i1 * xStrd1 + i2 * xStrd2 + i3 * xStrd3; + auto y3 = y + i0 * yStrd0 + i1 * yStrd1 + i2 * yStrd2 + i3 * yStrd3; + auto z3 = z + i0 * zStrd0 + i1 * zStrd1 + i2 * zStrd2 + i3 * zStrd3; + + if (zStrd4 == 1 && xStrd4 == 1 && yStrd4 == 0) + for (uint i4 = 0; i4 < zAxis4; ++i4) + z3[i4] = OpType::op(x3[i4], *y3); + else if (zStrd4 == 1 && xStrd4 == 0 && yStrd4 == 1) + for (uint i4 = 0; i4 < zAxis4; ++i4) + z3[i4] = OpType::op(*x3, y3[i4]); + else if (zStrd4 == 1 && xStrd4 == 1 && yStrd4 == 1) + for (uint i4 = 0; i4 < zAxis4; ++i4) + z3[i4] = OpType::op(x3[i4], y3[i4]); + else + for (uint i4 = 0; i4 < zAxis4; ++i4) + z3[i4 * zStrd4] = OpType::op(x3[i4 * xStrd4], y3[i4 * yStrd4]); + } } - }; + } + } + }; - samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1, 0,zAxis2,1); + samediff::Threads::parallel_for(func, 0, zAxis0, 1, 0, zAxis1, 1, 0, zAxis2, + 1); } //////////////////////////////////////////////////////////////////////// template -static void execDefault(const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, X* z, const Nd4jLong *zShapeInfo) { - - const bool xzSameOffsets = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo); - const bool yzSameOffsets = shape::haveSameShapeAndStrides(yShapeInfo, zShapeInfo); - - const int rank = shape::rank(zShapeInfo); // xRank = yRank = zRank - - auto func = PRAGMA_THREADS_FOR{ - - int xCoords[MAX_RANK], yCoords[MAX_RANK], zCoords[MAX_RANK]; - - for (auto i = start; i < stop; ++i) { - - shape::index2coordsCPU(start, i, zShapeInfo, zCoords); - - for (uint j = 0; j < rank; ++j) { - xCoords[j] = shape::sizeAt(xShapeInfo, j) == 1 ? 0 : zCoords[j]; - yCoords[j] = shape::sizeAt(yShapeInfo, j) == 1 ? 0 : zCoords[j]; - } - - const auto zOffset = shape::getOffset(zShapeInfo, zCoords); - const auto xOffset = xzSameOffsets ? zOffset : shape::getOffset(xShapeInfo, xCoords); - const auto yOffset = yzSameOffsets ? zOffset : shape::getOffset(yShapeInfo, yCoords); - - z[zOffset] = OpType::op(x[xOffset], y[yOffset]); - } - }; +static void execDefault(const X *x, const Nd4jLong *xShapeInfo, const X *y, + const Nd4jLong *yShapeInfo, X *z, + const Nd4jLong *zShapeInfo) { + const bool xzSameOffsets = + shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo); + const bool yzSameOffsets = + shape::haveSameShapeAndStrides(yShapeInfo, zShapeInfo); + + const int rank = shape::rank(zShapeInfo); // xRank = yRank = zRank + + auto func = PRAGMA_THREADS_FOR { + int xCoords[MAX_RANK], yCoords[MAX_RANK], zCoords[MAX_RANK]; + + for (auto i = start; i < stop; ++i) { + shape::index2coordsCPU(start, i, zShapeInfo, zCoords); + + for (uint j = 0; j < rank; ++j) { + xCoords[j] = shape::sizeAt(xShapeInfo, j) == 1 ? 0 : zCoords[j]; + yCoords[j] = shape::sizeAt(yShapeInfo, j) == 1 ? 0 : zCoords[j]; + } + + const auto zOffset = shape::getOffset(zShapeInfo, zCoords); + const auto xOffset = + xzSameOffsets ? zOffset : shape::getOffset(xShapeInfo, xCoords); + const auto yOffset = + yzSameOffsets ? zOffset : shape::getOffset(yShapeInfo, yCoords); + + z[zOffset] = OpType::op(x[xOffset], y[yOffset]); + } + }; - samediff::Threads::parallel_for(func, 0, shape::length(zShapeInfo)); + samediff::Threads::parallel_for(func, 0, shape::length(zShapeInfo)); } //////////////////////////////////////////////////////////////////////// template -template +template void BroadcastInt::exec(const void *vx, const Nd4jLong *xShapeInfo, - const void *vy, const Nd4jLong *yShapeInfo, - void *vz, const Nd4jLong *zShapeInfo) { - - const X* x = reinterpret_cast(vx); - const X* y = reinterpret_cast(vy); - X* z = reinterpret_cast(vz); - - const int rank = shape::rank(zShapeInfo); // xRank = yRank = zRank - - switch (rank) { - - case 1: - execRank1(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo); - break; - case 2: - execRank2(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo); - break; - case 3: - execRank3(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo); - break; - case 4: - execRank4(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo); - break; - case 5: - execRank5(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo); - break; - default: - execDefault(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo); - } + const void *vy, const Nd4jLong *yShapeInfo, void *vz, + const Nd4jLong *zShapeInfo) { + const X *x = reinterpret_cast(vx); + const X *y = reinterpret_cast(vy); + X *z = reinterpret_cast(vz); + + const int rank = shape::rank(zShapeInfo); // xRank = yRank = zRank + + switch (rank) { + case 1: + execRank1(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo); + break; + case 2: + execRank2(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo); + break; + case 3: + execRank3(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo); + break; + case 4: + execRank4(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo); + break; + case 5: + execRank5(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo); + break; + default: + execDefault(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo); + } } -//BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT BroadcastInt, , INTEGER_TYPES); -} -} \ No newline at end of file +// BUILD_SINGLE_TEMPLATE(template class SD_EXPORT BroadcastInt, , +// INTEGER_TYPES); +} // namespace broadcast +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/broadcast_bool_p.cpp.in b/libnd4j/include/loops/cpu/compilation_units/broadcast_bool_p.cpp.in index b3c60462beea..7aa9e5824b52 100644 --- a/libnd4j/include/loops/cpu/compilation_units/broadcast_bool_p.cpp.in +++ b/libnd4j/include/loops/cpu/compilation_units/broadcast_bool_p.cpp.in @@ -23,6 +23,6 @@ namespace functions { namespace broadcast { - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT BroadcastBool, , LIBND4J_TYPES_@FL_TYPE_INDEX@, BOOL_TYPES); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT BroadcastBool, , LIBND4J_TYPES_@FL_TYPE_INDEX@, BOOL_TYPES); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/broadcast_int_p.cpp.in b/libnd4j/include/loops/cpu/compilation_units/broadcast_int_p.cpp.in index a36c1a0b2518..d52ecccd232c 100644 --- a/libnd4j/include/loops/cpu/compilation_units/broadcast_int_p.cpp.in +++ b/libnd4j/include/loops/cpu/compilation_units/broadcast_int_p.cpp.in @@ -23,6 +23,6 @@ namespace functions { namespace broadcast { - BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT BroadcastInt, , INTEGER_TYPES_@FL_TYPE_INDEX@); + BUILD_SINGLE_TEMPLATE(template class SD_EXPORT BroadcastInt, , INTEGER_TYPES_@FL_TYPE_INDEX@); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/broadcast_p.cpp.in b/libnd4j/include/loops/cpu/compilation_units/broadcast_p.cpp.in index 1dbb4aac4006..e14963df2009 100644 --- a/libnd4j/include/loops/cpu/compilation_units/broadcast_p.cpp.in +++ b/libnd4j/include/loops/cpu/compilation_units/broadcast_p.cpp.in @@ -22,6 +22,6 @@ #cmakedefine PAIRWISE_TYPE_GEN namespace functions { namespace broadcast { - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT Broadcast, , PAIRWISE_TYPES_@FL_TYPE_INDEX@); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , PAIRWISE_TYPES_@FL_TYPE_INDEX@); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/indexreduce_int32.cpp.in b/libnd4j/include/loops/cpu/compilation_units/indexreduce_int32.cpp.in index 97402d38eedb..ad4df30fa112 100644 --- a/libnd4j/include/loops/cpu/compilation_units/indexreduce_int32.cpp.in +++ b/libnd4j/include/loops/cpu/compilation_units/indexreduce_int32.cpp.in @@ -23,6 +23,6 @@ #cmakedefine LIBND4J_TYPE_GEN namespace functions { namespace indexreduce { - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT IndexReduce, , LIBND4J_TYPES_@FL_TYPE_INDEX@, (sd::DataType::INT32, int32_t)); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT IndexReduce, , LIBND4J_TYPES_@FL_TYPE_INDEX@, (sd::DataType::INT32, int32_t)); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/indexreduce_int64.cpp.in b/libnd4j/include/loops/cpu/compilation_units/indexreduce_int64.cpp.in index 30fa30749bd9..7e4ed33957d8 100644 --- a/libnd4j/include/loops/cpu/compilation_units/indexreduce_int64.cpp.in +++ b/libnd4j/include/loops/cpu/compilation_units/indexreduce_int64.cpp.in @@ -23,6 +23,6 @@ #cmakedefine LIBND4J_TYPE_GEN namespace functions { namespace indexreduce { - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT IndexReduce, , LIBND4J_TYPES_@FL_TYPE_INDEX@, (sd::DataType::INT64, Nd4jLong)); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT IndexReduce, , LIBND4J_TYPES_@FL_TYPE_INDEX@, (sd::DataType::INT64, Nd4jLong)); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/pairwise_p.cpp.in b/libnd4j/include/loops/cpu/compilation_units/pairwise_p.cpp.in index bbf809de8761..997350a3aaf4 100644 --- a/libnd4j/include/loops/cpu/compilation_units/pairwise_p.cpp.in +++ b/libnd4j/include/loops/cpu/compilation_units/pairwise_p.cpp.in @@ -22,6 +22,6 @@ #cmakedefine PAIRWISE_TYPE_GEN namespace functions { namespace pairwise_transforms { - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT PairWiseTransform, , PAIRWISE_TYPES_@FL_TYPE_INDEX@); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , PAIRWISE_TYPES_@FL_TYPE_INDEX@); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/random.cpp.in b/libnd4j/include/loops/cpu/compilation_units/random.cpp.in index 921532ac881b..e6e7dac16b57 100644 --- a/libnd4j/include/loops/cpu/compilation_units/random.cpp.in +++ b/libnd4j/include/loops/cpu/compilation_units/random.cpp.in @@ -22,6 +22,6 @@ #cmakedefine FLOAT_TYPE_GEN namespace functions { namespace random { - BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT RandomFunction, , FLOAT_TYPES_@FL_TYPE_INDEX@); + BUILD_SINGLE_TEMPLATE(template class SD_EXPORT RandomFunction, , FLOAT_TYPES_@FL_TYPE_INDEX@); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/reduce3_bfloat16.cpp.in b/libnd4j/include/loops/cpu/compilation_units/reduce3_bfloat16.cpp.in index 68616c3f9b74..cd2398e5b3d4 100644 --- a/libnd4j/include/loops/cpu/compilation_units/reduce3_bfloat16.cpp.in +++ b/libnd4j/include/loops/cpu/compilation_units/reduce3_bfloat16.cpp.in @@ -23,6 +23,6 @@ #cmakedefine LIBND4J_TYPE_GEN namespace functions { namespace reduce3 { - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT Reduce3, , LIBND4J_TYPES_@FL_TYPE_INDEX@, FLOAT_TYPES_3); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_@FL_TYPE_INDEX@, FLOAT_TYPES_3); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/reduce3_double.cpp.in b/libnd4j/include/loops/cpu/compilation_units/reduce3_double.cpp.in index 5c722838d1d4..a29a6952a97b 100644 --- a/libnd4j/include/loops/cpu/compilation_units/reduce3_double.cpp.in +++ b/libnd4j/include/loops/cpu/compilation_units/reduce3_double.cpp.in @@ -23,6 +23,6 @@ #cmakedefine LIBND4J_TYPE_GEN namespace functions { namespace reduce3 { - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT Reduce3, , LIBND4J_TYPES_@FL_TYPE_INDEX@, FLOAT_TYPES_2); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_@FL_TYPE_INDEX@, FLOAT_TYPES_2); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/reduce3_float.cpp.in b/libnd4j/include/loops/cpu/compilation_units/reduce3_float.cpp.in index ee127c2d9b84..c0a62ad94425 100644 --- a/libnd4j/include/loops/cpu/compilation_units/reduce3_float.cpp.in +++ b/libnd4j/include/loops/cpu/compilation_units/reduce3_float.cpp.in @@ -23,6 +23,6 @@ #cmakedefine LIBND4J_TYPE_GEN namespace functions { namespace reduce3 { - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT Reduce3, , LIBND4J_TYPES_@FL_TYPE_INDEX@, FLOAT_TYPES_1); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_@FL_TYPE_INDEX@, FLOAT_TYPES_1); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/reduce3_float16.cpp.in b/libnd4j/include/loops/cpu/compilation_units/reduce3_float16.cpp.in index 65c2b563ad83..3699c541d2dd 100644 --- a/libnd4j/include/loops/cpu/compilation_units/reduce3_float16.cpp.in +++ b/libnd4j/include/loops/cpu/compilation_units/reduce3_float16.cpp.in @@ -23,6 +23,6 @@ #cmakedefine LIBND4J_TYPE_GEN namespace functions { namespace reduce3 { - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT Reduce3, , LIBND4J_TYPES_@FL_TYPE_INDEX@, FLOAT_TYPES_0); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_@FL_TYPE_INDEX@, FLOAT_TYPES_0); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/reduce_float.cpp.in b/libnd4j/include/loops/cpu/compilation_units/reduce_float.cpp.in index 3837c7810b4d..25cefff53ee3 100644 --- a/libnd4j/include/loops/cpu/compilation_units/reduce_float.cpp.in +++ b/libnd4j/include/loops/cpu/compilation_units/reduce_float.cpp.in @@ -23,6 +23,6 @@ #cmakedefine FLOAT_TYPE_GEN namespace functions { namespace reduce { - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT ReduceFloatFunction, , LIBND4J_TYPES, FLOAT_TYPES_@FL_TYPE_INDEX@); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT ReduceFloatFunction, , LIBND4J_TYPES, FLOAT_TYPES_@FL_TYPE_INDEX@); } } diff --git a/libnd4j/include/loops/cpu/compilation_units/scalar_p.cpp.in b/libnd4j/include/loops/cpu/compilation_units/scalar_p.cpp.in index dc024170d8b7..768665099a76 100644 --- a/libnd4j/include/loops/cpu/compilation_units/scalar_p.cpp.in +++ b/libnd4j/include/loops/cpu/compilation_units/scalar_p.cpp.in @@ -22,6 +22,6 @@ #cmakedefine PAIRWISE_TYPE_GEN namespace functions { namespace scalar { - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT ScalarTransform, , PAIRWISE_TYPES_@FL_TYPE_INDEX@); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT ScalarTransform, , PAIRWISE_TYPES_@FL_TYPE_INDEX@); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/indexreduce.hpp b/libnd4j/include/loops/cpu/indexreduce.hpp index d46dd89d787a..5551eb19d099 100644 --- a/libnd4j/include/loops/cpu/indexreduce.hpp +++ b/libnd4j/include/loops/cpu/indexreduce.hpp @@ -18,138 +18,148 @@ // Created by raver on 4/9/2018. // +#include +#include +#include #include +#include #include -#include #include -#include -#include -#include using namespace simdOps; -namespace functions { +namespace functions { namespace indexreduce { //////////////////////////////////////////////////////////////////////// template -Nd4jLong IndexReduce::execScalar( const int opNum, const void *x, const Nd4jLong *xShapeInfo, void *extraParams) { - RETURNING_DISPATCH_BY_OPNUM_TT(execScalar, PARAMS(x, xShapeInfo, extraParams), INDEX_REDUCE_OPS); +Nd4jLong IndexReduce::execScalar(const int opNum, const void *x, + const Nd4jLong *xShapeInfo, + void *extraParams) { + RETURNING_DISPATCH_BY_OPNUM_TT(execScalar, PARAMS(x, xShapeInfo, extraParams), + INDEX_REDUCE_OPS); } //////////////////////////////////////////////////////////////////////// template -void IndexReduce::exec(const int opNum, - const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *z, const Nd4jLong *zShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset) { - DISPATCH_BY_OPNUM_TT(exec, PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffset), INDEX_REDUCE_OPS); +void IndexReduce::exec(const int opNum, const void *x, + const Nd4jLong *xShapeInfo, void *extraParams, + void *z, const Nd4jLong *zShapeInfo, + int *dimension, int dimensionLength, + const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffset) { + DISPATCH_BY_OPNUM_TT( + exec, + PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, + dimensionLength, tadShapeInfo, tadOffset), + INDEX_REDUCE_OPS); } //////////////////////////////////////////////////////////////////////// template -template -Nd4jLong IndexReduce::execScalar(const void *vx, const Nd4jLong *xShapeInfo, void *vextraParams) { - - auto x = reinterpret_cast(vx); - auto extraParams = reinterpret_cast(vextraParams); - - //T startingVal = OpType::startingValue(x); - auto startingIndex = OpType::startingIndexValue(x); - auto len = shape::length(xShapeInfo); - auto xEws = shape::elementWiseStride(xShapeInfo); - sd::OmpLaunchHelper info(len); - - uint xShapeInfoCast[MAX_RANK]; - bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); - int maxThreads = sd::math::nd4j_min(64, sd::Environment::getInstance().maxThreads()); - IndexValue intermediatery[64]; - for (int e = 0; e < maxThreads; e++) - intermediatery[e].index = -1; +template +Nd4jLong IndexReduce::execScalar(const void *vx, + const Nd4jLong *xShapeInfo, + void *vextraParams) { + auto x = reinterpret_cast(vx); + auto extraParams = reinterpret_cast(vextraParams); + + // T startingVal = OpType::startingValue(x); + auto startingIndex = OpType::startingIndexValue(x); + auto len = shape::length(xShapeInfo); + auto xEws = shape::elementWiseStride(xShapeInfo); + sd::OmpLaunchHelper info(len); + + uint xShapeInfoCast[MAX_RANK]; + bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + int maxThreads = + sd::math::nd4j_min(64, sd::Environment::getInstance().maxThreads()); + IndexValue intermediatery[64]; + for (int e = 0; e < maxThreads; e++) intermediatery[e].index = -1; + + if (xEws == 1 && shape::order(xShapeInfo) == 'c') { + auto func = PRAGMA_THREADS_FOR { + intermediatery[thread_id] = OpType::startingIndexValue(x); + + for (auto i = start; i < stop; i++) { + IndexValue curr(x[i], i); + intermediatery[thread_id] = + OpType::update(intermediatery[thread_id], curr, extraParams); + } + }; + + maxThreads = samediff::Threads::parallel_for(func, 0, len, 1, maxThreads); - if (xEws == 1 && shape::order(xShapeInfo) == 'c') { - auto func = PRAGMA_THREADS_FOR { - intermediatery[thread_id] = OpType::startingIndexValue(x); - - for (auto i = start; i < stop; i++) { - IndexValue curr(x[i], i); - intermediatery[thread_id] = OpType::update(intermediatery[thread_id], curr, extraParams); - } - }; - - maxThreads = samediff::Threads::parallel_for(func, 0, len, 1, maxThreads); - - for (int e = 0; e < maxThreads; e++) - startingIndex = OpType::update(startingIndex, intermediatery[e], extraParams); + for (int e = 0; e < maxThreads; e++) + startingIndex = + OpType::update(startingIndex, intermediatery[e], extraParams); - } else { - auto func = PRAGMA_THREADS_FOR { - intermediatery[thread_id] = OpType::startingIndexValue(x); + } else { + auto func = PRAGMA_THREADS_FOR { + intermediatery[thread_id] = OpType::startingIndexValue(x); - for (auto i = start; i < stop; i++) { - auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); - IndexValue curr(x[offset], i); - intermediatery[thread_id] = OpType::update(intermediatery[thread_id], curr, extraParams); - } - }; + for (auto i = start; i < stop; i++) { + auto offset = + shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); + IndexValue curr(x[offset], i); + intermediatery[thread_id] = + OpType::update(intermediatery[thread_id], curr, extraParams); + } + }; - maxThreads = samediff::Threads::parallel_for(func, 0, len, 1, maxThreads); + maxThreads = samediff::Threads::parallel_for(func, 0, len, 1, maxThreads); - for (int e = 0; e < maxThreads; e++) - startingIndex = OpType::update(startingIndex, intermediatery[e], extraParams); - } - return startingIndex.index; + for (int e = 0; e < maxThreads; e++) + startingIndex = + OpType::update(startingIndex, intermediatery[e], extraParams); + } + return startingIndex.index; } - //////////////////////////////////////////////////////////////////////// template -template +template void IndexReduce::exec(const void *vx, const Nd4jLong *xShapeInfo, - void *vextraParams, - void *vz, const Nd4jLong *zShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset) { + void *vextraParams, void *vz, + const Nd4jLong *zShapeInfo, int *dimension, + int dimensionLength, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffset) { + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + auto extraParams = reinterpret_cast(vextraParams); - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - auto extraParams = reinterpret_cast(vextraParams); + const Nd4jLong zLen = shape::length(zShapeInfo); - const Nd4jLong zLen = shape::length(zShapeInfo); + if (sd::ArrayOptions::arrayType(xShapeInfo) == sd::ArrayType::EMPTY) { + if (sd::ArrayOptions::arrayType(zShapeInfo) == sd::ArrayType::EMPTY) return; + const auto indexValue = OpType::startingIndexValue(x); - if(sd::ArrayOptions::arrayType(xShapeInfo) == sd::ArrayType::EMPTY) { - if(sd::ArrayOptions::arrayType(zShapeInfo) == sd::ArrayType::EMPTY) - return; - const auto indexValue = OpType::startingIndexValue(x); + for (Nd4jLong i = 0; i < zLen; i++) z[i] = (Z)indexValue.index; - for (Nd4jLong i = 0; i < zLen; i++) - z[i] = (Z) indexValue.index; + return; + } - return; - } + if (shape::isScalar(zShapeInfo)) { + z[0] = (Z)execScalar(x, xShapeInfo, extraParams); + return; + } - if(shape::isScalar(zShapeInfo)) { - z[0] = (Z) execScalar(x,xShapeInfo,extraParams); - return; - } + auto tadOnlyShapeInfo = tadShapeInfo; + auto tadOffsets = tadOffset; - auto tadOnlyShapeInfo = tadShapeInfo; - auto tadOffsets = tadOffset; + if (tadOnlyShapeInfo == nullptr || tadOffsets == nullptr) { + if (dimensionLength < 1) return; - if (tadOnlyShapeInfo == nullptr || tadOffsets == nullptr) { - if (dimensionLength < 1) - return; + auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions( + xShapeInfo, dimension, dimensionLength); - auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(xShapeInfo, dimension, dimensionLength); + tadOnlyShapeInfo = tadPack.primaryShapeInfo(); + tadOffsets = tadPack.primaryOffsets(); + } - tadOnlyShapeInfo = tadPack.primaryShapeInfo(); - tadOffsets = tadPack.primaryOffsets(); - } - - sd::IndexReductionLoops::template loopIndexReduce(x, xShapeInfo, z, zShapeInfo, tadOnlyShapeInfo, tadOffsets, extraParams); + sd::IndexReductionLoops::template loopIndexReduce( + x, xShapeInfo, z, zShapeInfo, tadOnlyShapeInfo, tadOffsets, extraParams); } -} -} \ No newline at end of file +} // namespace indexreduce +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/pairwise.hpp b/libnd4j/include/loops/cpu/pairwise.hpp index 45fe46e8f1c6..ceb3a196d099 100644 --- a/libnd4j/include/loops/cpu/pairwise.hpp +++ b/libnd4j/include/loops/cpu/pairwise.hpp @@ -18,210 +18,211 @@ // Created by remote on 2018-09-20. // -#include -#include -#include +#include #include -#include +#include #include +#include +#include +#include #include -#include -#include +#include using namespace simdOps; namespace functions { - namespace pairwise_transforms { - - template - void PairWiseTransform::exec(const int opNum, - const void *x, Nd4jLong xEws, - const void *y, Nd4jLong yEws, - void *z, Nd4jLong zEws, - void *extraParams, - Nd4jLong n, - const uint64_t start,const uint64_t stop) { - DISPATCH_BY_OPNUM_TTT(exec, PARAMS(x, - xEws, - y, - yEws, - z, - zEws, - extraParams, - n, start, stop), PAIRWISE_TRANSFORM_OPS); - }; - - - - template - template - void PairWiseTransform::exec(const void *vx, Nd4jLong xEws, - const void *vy, Nd4jLong yEws, - void *vz, Nd4jLong zEws, - void *vextraParams, - const Nd4jLong n, - const uint64_t start, - const uint64_t stop) { - - auto x = reinterpret_cast(vx); - auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); - auto extraParams = reinterpret_cast(vextraParams); - - if (xEws == 1 && yEws == 1 && zEws == 1) { - PRAGMA_OMP_SIMD - for (auto i = start; i < stop; i++) - z[i] = OpType::op(x[i], y[i], extraParams); - } - else { - PRAGMA_OMP_SIMD - for (auto i = start; i < stop; i++) - z[i*zEws] = OpType::op(x[i*xEws], y[i*yEws], extraParams); - } - } - - template - void PairWiseTransform::exec(const int opNum, - const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *z, const Nd4jLong *zShapeInfo, - void *extraParams, - const uint64_t start, const uint64_t stop) { - DISPATCH_BY_OPNUM_TTT(exec, PARAMS(x, - xShapeInfo, - y, - yShapeInfo, - z, - zShapeInfo, - extraParams, start, stop), - PAIRWISE_TRANSFORM_OPS); - }; - - - template - template - void PairWiseTransform::exec(const void *vx, const Nd4jLong* xShapeInfo, - const void *vy, const Nd4jLong* yShapeInfo, - void *vz, const Nd4jLong* zShapeInfo, - void *vextraParams, - const uint64_t start, const uint64_t stop) { - - auto x = reinterpret_cast(vx); - auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); - auto extraParams = reinterpret_cast(vextraParams); - - auto n = shape::length(xShapeInfo); - auto xEws = shape::elementWiseStride(xShapeInfo); - auto yEws = shape::elementWiseStride(yShapeInfo); - auto zEws = shape::elementWiseStride(zShapeInfo); - - - if (shape::isScalar(yShapeInfo)) { - - uint xShapeInfoCast[MAX_RANK]; - const bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); - - if(shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) { - PRAGMA_OMP_SIMD - for(auto i = start; i < stop; i++) { - auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); - z[offset] = OpType::op(x[offset], y[0], extraParams); - }; - } - else { - uint zShapeInfoCast[MAX_RANK]; - const bool canCastZ = sd::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast); - - PRAGMA_OMP_SIMD - for(auto i = start; i < stop; i++) { - auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); - auto zOffset = shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ); - z[zOffset] = OpType::op(x[xOffset], y[0], extraParams); - }; - } - return; - } - - - - const sd::LoopKind::Kind kindOfLoop = sd::LoopKind::deduceKindOfLoopXYZ(xShapeInfo, yShapeInfo, zShapeInfo); - const bool sameShapesXY = shape::shapeEquals(xShapeInfo, yShapeInfo); - - if ((kindOfLoop == sd::LoopKind::EWS1 || kindOfLoop == sd::LoopKind::EWSNONZERO) && sameShapesXY) { - exec(x, xEws, y, yEws, z, zEws, extraParams, n, start, stop); - } - else if ((kindOfLoop == sd::LoopKind::EWS1 || kindOfLoop == sd::LoopKind::EWSNONZERO) && !sameShapesXY) { //not same shape - exec(x, xEws, y, yEws, z, zEws, extraParams, shape::length(yShapeInfo), start, stop); - } - else { - - if(shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo) && shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) { - uint xShapeInfoCast[MAX_RANK]; - bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); - - PRAGMA_OMP_SIMD - for (auto i = start; i < stop; i++) { - auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); - z[offset] = OpType::op(x[offset], y[offset], extraParams); - } - } - else if(shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo)) { - uint xShapeInfoCast[MAX_RANK]; - uint zShapeInfoCast[MAX_RANK]; - bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); - bool canCastZ = sd::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast); - - PRAGMA_OMP_SIMD - for (auto i = start; i < stop; i++) { - auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); - auto zOffset = shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ); - z[zOffset] = OpType::op(x[offset], y[offset], extraParams); - }; - } - else if(shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) { - uint xShapeInfoCast[MAX_RANK]; - uint yShapeInfoCast[MAX_RANK]; - bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); - bool canCastY = sd::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); - - PRAGMA_OMP_SIMD - for (auto i = start; i < stop; i++) { - auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); - auto yOffset = shape::indexOffset(i, yShapeInfo, yShapeInfoCast, canCastY); - z[offset] = OpType::op(x[offset], y[yOffset], extraParams); - }; - } - else if(shape::haveSameShapeAndStrides(yShapeInfo, zShapeInfo)) { - uint xShapeInfoCast[MAX_RANK]; - uint yShapeInfoCast[MAX_RANK]; - bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); - bool canCastY = sd::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); - - PRAGMA_OMP_SIMD - for (auto i = start; i < stop; i++) { - auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); - auto offset = shape::indexOffset(i, yShapeInfo, yShapeInfoCast, canCastY); - z[offset] = OpType::op(x[xOffset], y[offset], extraParams); - }; - } - else { - uint xShapeInfoCast[MAX_RANK]; - uint yShapeInfoCast[MAX_RANK]; - uint zShapeInfoCast[MAX_RANK]; - bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); - bool canCastY = sd::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); - bool canCastZ = sd::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast); - - PRAGMA_OMP_SIMD - for (auto i = start; i < stop; i++) { - auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); - auto yOffset = shape::indexOffset(i, yShapeInfo, yShapeInfoCast, canCastY); - auto zOffset = shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ); - z[zOffset] = OpType::op(x[xOffset], y[yOffset], extraParams); - }; - } - } - } +namespace pairwise_transforms { + +template +void PairWiseTransform::exec(const int opNum, const void *x, + Nd4jLong xEws, const void *y, + Nd4jLong yEws, void *z, Nd4jLong zEws, + void *extraParams, Nd4jLong n, + const uint64_t start, + const uint64_t stop) { + DISPATCH_BY_OPNUM_TTT( + exec, PARAMS(x, xEws, y, yEws, z, zEws, extraParams, n, start, stop), + PAIRWISE_TRANSFORM_OPS); +}; + +template +template +void PairWiseTransform::exec(const void *vx, Nd4jLong xEws, + const void *vy, Nd4jLong yEws, void *vz, + Nd4jLong zEws, void *vextraParams, + const Nd4jLong n, const uint64_t start, + const uint64_t stop) { + auto x = reinterpret_cast(vx); + auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + auto extraParams = reinterpret_cast(vextraParams); + + if (xEws == 1 && yEws == 1 && zEws == 1) { + PRAGMA_OMP_SIMD + for (auto i = start; i < stop; i++) + z[i] = OpType::op(x[i], y[i], extraParams); + } else { + PRAGMA_OMP_SIMD + for (auto i = start; i < stop; i++) + z[i * zEws] = OpType::op(x[i * xEws], y[i * yEws], extraParams); + } +} + +template +void PairWiseTransform::exec(const int opNum, const void *x, + const Nd4jLong *xShapeInfo, const void *y, + const Nd4jLong *yShapeInfo, void *z, + const Nd4jLong *zShapeInfo, + void *extraParams, const uint64_t start, + const uint64_t stop) { + DISPATCH_BY_OPNUM_TTT(exec, + PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, + extraParams, start, stop), + PAIRWISE_TRANSFORM_OPS); +}; + +template +template +void PairWiseTransform::exec( + const void *vx, const Nd4jLong *xShapeInfo, const void *vy, + const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo, + void *vextraParams, const uint64_t start, const uint64_t stop) { + auto x = reinterpret_cast(vx); + auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + auto extraParams = reinterpret_cast(vextraParams); + + auto n = shape::length(xShapeInfo); + auto xEws = shape::elementWiseStride(xShapeInfo); + auto yEws = shape::elementWiseStride(yShapeInfo); + auto zEws = shape::elementWiseStride(zShapeInfo); + + if (shape::isScalar(yShapeInfo)) { + uint xShapeInfoCast[MAX_RANK]; + const bool canCastX = + sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + + if (shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) { + PRAGMA_OMP_SIMD + for (auto i = start; i < stop; i++) { + auto offset = + shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); + z[offset] = OpType::op(x[offset], y[0], extraParams); + }; + } else { + uint zShapeInfoCast[MAX_RANK]; + const bool canCastZ = + sd::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast); + + PRAGMA_OMP_SIMD + for (auto i = start; i < stop; i++) { + auto xOffset = + shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); + auto zOffset = + shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ); + z[zOffset] = OpType::op(x[xOffset], y[0], extraParams); + }; + } + return; + } + + const sd::LoopKind::Kind kindOfLoop = + sd::LoopKind::deduceKindOfLoopXYZ(xShapeInfo, yShapeInfo, zShapeInfo); + const bool sameShapesXY = shape::shapeEquals(xShapeInfo, yShapeInfo); + + if ((kindOfLoop == sd::LoopKind::EWS1 || + kindOfLoop == sd::LoopKind::EWSNONZERO) && + sameShapesXY) { + exec(x, xEws, y, yEws, z, zEws, extraParams, n, start, stop); + } else if ((kindOfLoop == sd::LoopKind::EWS1 || + kindOfLoop == sd::LoopKind::EWSNONZERO) && + !sameShapesXY) { // not same shape + exec(x, xEws, y, yEws, z, zEws, extraParams, + shape::length(yShapeInfo), start, stop); + } else { + if (shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo) && + shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) { + uint xShapeInfoCast[MAX_RANK]; + bool canCastX = + sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + + PRAGMA_OMP_SIMD + for (auto i = start; i < stop; i++) { + auto offset = + shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); + z[offset] = OpType::op(x[offset], y[offset], extraParams); + } + } else if (shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo)) { + uint xShapeInfoCast[MAX_RANK]; + uint zShapeInfoCast[MAX_RANK]; + bool canCastX = + sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + bool canCastZ = + sd::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast); + + PRAGMA_OMP_SIMD + for (auto i = start; i < stop; i++) { + auto offset = + shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); + auto zOffset = + shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ); + z[zOffset] = OpType::op(x[offset], y[offset], extraParams); + }; + } else if (shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) { + uint xShapeInfoCast[MAX_RANK]; + uint yShapeInfoCast[MAX_RANK]; + bool canCastX = + sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + bool canCastY = + sd::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); + + PRAGMA_OMP_SIMD + for (auto i = start; i < stop; i++) { + auto offset = + shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); + auto yOffset = + shape::indexOffset(i, yShapeInfo, yShapeInfoCast, canCastY); + z[offset] = OpType::op(x[offset], y[yOffset], extraParams); + }; + } else if (shape::haveSameShapeAndStrides(yShapeInfo, zShapeInfo)) { + uint xShapeInfoCast[MAX_RANK]; + uint yShapeInfoCast[MAX_RANK]; + bool canCastX = + sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + bool canCastY = + sd::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); + + PRAGMA_OMP_SIMD + for (auto i = start; i < stop; i++) { + auto xOffset = + shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); + auto offset = + shape::indexOffset(i, yShapeInfo, yShapeInfoCast, canCastY); + z[offset] = OpType::op(x[xOffset], y[offset], extraParams); + }; + } else { + uint xShapeInfoCast[MAX_RANK]; + uint yShapeInfoCast[MAX_RANK]; + uint zShapeInfoCast[MAX_RANK]; + bool canCastX = + sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + bool canCastY = + sd::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); + bool canCastZ = + sd::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast); + + PRAGMA_OMP_SIMD + for (auto i = start; i < stop; i++) { + auto xOffset = + shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); + auto yOffset = + shape::indexOffset(i, yShapeInfo, yShapeInfoCast, canCastY); + auto zOffset = + shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ); + z[zOffset] = OpType::op(x[xOffset], y[yOffset], extraParams); + }; } + } } +} // namespace pairwise_transforms +} // namespace functions diff --git a/libnd4j/include/loops/cpu/pairwise_bool.cpp b/libnd4j/include/loops/cpu/pairwise_bool.cpp index dfcdf6bfacc8..b7c0cc2559a9 100644 --- a/libnd4j/include/loops/cpu/pairwise_bool.cpp +++ b/libnd4j/include/loops/cpu/pairwise_bool.cpp @@ -18,204 +18,208 @@ // Created by remote on 2018-09-20. // -#include -#include +#include #include #include -#include +#include +#include using namespace simdOps; namespace functions { - namespace pairwise_transforms { - - template - void PairWiseBoolTransform::exec(const int opNum, - const void *x, Nd4jLong xEws, - const void *y, Nd4jLong yEws, - void *z, Nd4jLong zEws, - void *extraParams, - Nd4jLong n, - const uint64_t start, const uint64_t stop) { - DISPATCH_BY_OPNUM_TT(exec, PARAMS(x, - xEws, - y, - yEws, - z, - zEws, - extraParams, - n, start, stop), PAIRWISE_BOOL_OPS); - }; - - - - template - template - void PairWiseBoolTransform::exec(const void *vx, Nd4jLong xEws, - const void *vy, Nd4jLong yEws, - void *vz, Nd4jLong zEws, - void *vextraParams, - const Nd4jLong n, - const uint64_t start, const uint64_t stop) { - - auto x = reinterpret_cast(vx); - auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); - auto extraParams = reinterpret_cast(vextraParams); - - if (xEws == 1 && yEws == 1 && zEws == 1) { - PRAGMA_OMP_SIMD - for (auto i = start; i < stop; i++) - z[i] = OpType::op(x[i], y[i], extraParams); - } - else { - PRAGMA_OMP_SIMD - for (auto i = start; i < stop; i++) - z[i*zEws] = OpType::op(x[i*xEws], y[i*yEws], extraParams); - } - } - - template - void PairWiseBoolTransform::exec(const int opNum, - const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *z, const Nd4jLong *zShapeInfo, - void *extraParams, - const uint64_t start,const uint64_t stop) { - DISPATCH_BY_OPNUM_TT(exec, PARAMS(x, - xShapeInfo, - y, - yShapeInfo, - z, - zShapeInfo, - extraParams, start, stop), - PAIRWISE_BOOL_OPS); - }; - - - template - template - void PairWiseBoolTransform::exec(const void *vx, const Nd4jLong* xShapeInfo, - const void *vy, const Nd4jLong* yShapeInfo, - void *vz, const Nd4jLong* zShapeInfo, - void *vextraParams, - const uint64_t start, const uint64_t stop) { - - auto x = reinterpret_cast(vx); - auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); - auto extraParams = reinterpret_cast(vextraParams); - - auto n = shape::length(xShapeInfo); - auto xEws = shape::elementWiseStride(xShapeInfo); - auto yEws = shape::elementWiseStride(yShapeInfo); - auto zEws = shape::elementWiseStride(zShapeInfo); - - if (shape::isScalar(yShapeInfo)) { - - uint xShapeInfoCast[MAX_RANK]; - const bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); - - if(shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) { - - PRAGMA_OMP_SIMD - for(auto i = start; i < stop; i++) { - auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); - z[offset] = OpType::op(x[offset], y[0], extraParams); - }; - } - else { - uint zShapeInfoCast[MAX_RANK]; - const bool canCastZ = sd::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast); - - PRAGMA_OMP_SIMD - for(auto i = start; i < stop; i++) { - auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); - auto zOffset = shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ); - z[zOffset] = OpType::op(x[xOffset], y[0], extraParams); - }; - } - return; - } - - const sd::LoopKind::Kind kindOfLoop = sd::LoopKind::deduceKindOfLoopXYZ(xShapeInfo, yShapeInfo, zShapeInfo); - const bool sameShapesXY = shape::shapeEquals(xShapeInfo, yShapeInfo); - - if ((kindOfLoop == sd::LoopKind::EWS1 || kindOfLoop == sd::LoopKind::EWSNONZERO) && sameShapesXY) { - exec(x, xEws, y, yEws, z, zEws, extraParams, n, start, stop); - } - else if ((kindOfLoop == sd::LoopKind::EWS1 || kindOfLoop == sd::LoopKind::EWSNONZERO) && !sameShapesXY) { //not same shape - exec(x, xEws, y, yEws, z, zEws, extraParams, shape::length(yShapeInfo), start, stop); - } - else { - if(shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo) && shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) { - uint xShapeInfoCast[MAX_RANK]; - const bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); - - PRAGMA_OMP_SIMD - for (auto i = start; i < stop; i++) { - auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); - z[offset] = OpType::op(x[offset], y[offset], extraParams); - }; - } - else if(shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo)) { - uint xShapeInfoCast[MAX_RANK]; - uint zShapeInfoCast[MAX_RANK]; - const bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); - const bool canCastZ = sd::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast); - - PRAGMA_OMP_SIMD - for (auto i = start; i < stop; i++) { - auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); - auto zOffset = shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ); - z[zOffset] = OpType::op(x[offset], y[offset], extraParams); - }; - } - else if(shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) { - uint xShapeInfoCast[MAX_RANK]; - uint yShapeInfoCast[MAX_RANK]; - const bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); - const bool canCastY = sd::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); - - PRAGMA_OMP_SIMD - for (auto i = start; i < stop; i++) { - auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); - auto yOffset = shape::indexOffset(i, yShapeInfo, yShapeInfoCast, canCastY); - z[offset] = OpType::op(x[offset], y[yOffset], extraParams); - }; - } - else if(shape::haveSameShapeAndStrides(yShapeInfo, zShapeInfo)) { - uint xShapeInfoCast[MAX_RANK]; - uint yShapeInfoCast[MAX_RANK]; - const bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); - const bool canCastY = sd::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); - - PRAGMA_OMP_SIMD - for (auto i = start; i < stop; i++) { - auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); - auto offset = shape::indexOffset(i, yShapeInfo, yShapeInfoCast, canCastY); - z[offset] = OpType::op(x[xOffset], y[offset], extraParams); - }; - } - else { - uint xShapeInfoCast[MAX_RANK]; - uint yShapeInfoCast[MAX_RANK]; - uint zShapeInfoCast[MAX_RANK]; - const bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); - const bool canCastY = sd::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); - const bool canCastZ = sd::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast); - - PRAGMA_OMP_SIMD - for (auto i = start; i < stop; i++) { - auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); - auto yOffset = shape::indexOffset(i, yShapeInfo, yShapeInfoCast, canCastY); - auto zOffset = shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ); - z[zOffset] = OpType::op(x[xOffset], y[yOffset], extraParams); - }; - } - } - } - - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT PairWiseBoolTransform, , LIBND4J_TYPES, BOOL_TYPES); +namespace pairwise_transforms { + +template +void PairWiseBoolTransform::exec(const int opNum, const void *x, + Nd4jLong xEws, const void *y, + Nd4jLong yEws, void *z, Nd4jLong zEws, + void *extraParams, Nd4jLong n, + const uint64_t start, + const uint64_t stop) { + DISPATCH_BY_OPNUM_TT( + exec, PARAMS(x, xEws, y, yEws, z, zEws, extraParams, n, start, stop), + PAIRWISE_BOOL_OPS); +}; + +template +template +void PairWiseBoolTransform::exec(const void *vx, Nd4jLong xEws, + const void *vy, Nd4jLong yEws, void *vz, + Nd4jLong zEws, void *vextraParams, + const Nd4jLong n, const uint64_t start, + const uint64_t stop) { + auto x = reinterpret_cast(vx); + auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + auto extraParams = reinterpret_cast(vextraParams); + + if (xEws == 1 && yEws == 1 && zEws == 1) { + PRAGMA_OMP_SIMD + for (auto i = start; i < stop; i++) + z[i] = OpType::op(x[i], y[i], extraParams); + } else { + PRAGMA_OMP_SIMD + for (auto i = start; i < stop; i++) + z[i * zEws] = OpType::op(x[i * xEws], y[i * yEws], extraParams); + } +} + +template +void PairWiseBoolTransform::exec( + const int opNum, const void *x, const Nd4jLong *xShapeInfo, const void *y, + const Nd4jLong *yShapeInfo, void *z, const Nd4jLong *zShapeInfo, + void *extraParams, const uint64_t start, const uint64_t stop) { + DISPATCH_BY_OPNUM_TT(exec, + PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, + extraParams, start, stop), + PAIRWISE_BOOL_OPS); +}; + +template +template +void PairWiseBoolTransform::exec( + const void *vx, const Nd4jLong *xShapeInfo, const void *vy, + const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo, + void *vextraParams, const uint64_t start, const uint64_t stop) { + auto x = reinterpret_cast(vx); + auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + auto extraParams = reinterpret_cast(vextraParams); + + auto n = shape::length(xShapeInfo); + auto xEws = shape::elementWiseStride(xShapeInfo); + auto yEws = shape::elementWiseStride(yShapeInfo); + auto zEws = shape::elementWiseStride(zShapeInfo); + + if (shape::isScalar(yShapeInfo)) { + uint xShapeInfoCast[MAX_RANK]; + const bool canCastX = + sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + + if (shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) { + PRAGMA_OMP_SIMD + for (auto i = start; i < stop; i++) { + auto offset = + shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); + z[offset] = OpType::op(x[offset], y[0], extraParams); + }; + } else { + uint zShapeInfoCast[MAX_RANK]; + const bool canCastZ = + sd::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast); + + PRAGMA_OMP_SIMD + for (auto i = start; i < stop; i++) { + auto xOffset = + shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); + auto zOffset = + shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ); + z[zOffset] = OpType::op(x[xOffset], y[0], extraParams); + }; } + return; + } + + const sd::LoopKind::Kind kindOfLoop = + sd::LoopKind::deduceKindOfLoopXYZ(xShapeInfo, yShapeInfo, zShapeInfo); + const bool sameShapesXY = shape::shapeEquals(xShapeInfo, yShapeInfo); + + if ((kindOfLoop == sd::LoopKind::EWS1 || + kindOfLoop == sd::LoopKind::EWSNONZERO) && + sameShapesXY) { + exec(x, xEws, y, yEws, z, zEws, extraParams, n, start, stop); + } else if ((kindOfLoop == sd::LoopKind::EWS1 || + kindOfLoop == sd::LoopKind::EWSNONZERO) && + !sameShapesXY) { // not same shape + exec(x, xEws, y, yEws, z, zEws, extraParams, + shape::length(yShapeInfo), start, stop); + } else { + if (shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo) && + shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) { + uint xShapeInfoCast[MAX_RANK]; + const bool canCastX = + sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + + PRAGMA_OMP_SIMD + for (auto i = start; i < stop; i++) { + auto offset = + shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); + z[offset] = OpType::op(x[offset], y[offset], extraParams); + }; + } else if (shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo)) { + uint xShapeInfoCast[MAX_RANK]; + uint zShapeInfoCast[MAX_RANK]; + const bool canCastX = + sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + const bool canCastZ = + sd::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast); + + PRAGMA_OMP_SIMD + for (auto i = start; i < stop; i++) { + auto offset = + shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); + auto zOffset = + shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ); + z[zOffset] = OpType::op(x[offset], y[offset], extraParams); + }; + } else if (shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) { + uint xShapeInfoCast[MAX_RANK]; + uint yShapeInfoCast[MAX_RANK]; + const bool canCastX = + sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + const bool canCastY = + sd::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); + + PRAGMA_OMP_SIMD + for (auto i = start; i < stop; i++) { + auto offset = + shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); + auto yOffset = + shape::indexOffset(i, yShapeInfo, yShapeInfoCast, canCastY); + z[offset] = OpType::op(x[offset], y[yOffset], extraParams); + }; + } else if (shape::haveSameShapeAndStrides(yShapeInfo, zShapeInfo)) { + uint xShapeInfoCast[MAX_RANK]; + uint yShapeInfoCast[MAX_RANK]; + const bool canCastX = + sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + const bool canCastY = + sd::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); + + PRAGMA_OMP_SIMD + for (auto i = start; i < stop; i++) { + auto xOffset = + shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); + auto offset = + shape::indexOffset(i, yShapeInfo, yShapeInfoCast, canCastY); + z[offset] = OpType::op(x[xOffset], y[offset], extraParams); + }; + } else { + uint xShapeInfoCast[MAX_RANK]; + uint yShapeInfoCast[MAX_RANK]; + uint zShapeInfoCast[MAX_RANK]; + const bool canCastX = + sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + const bool canCastY = + sd::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); + const bool canCastZ = + sd::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast); + + PRAGMA_OMP_SIMD + for (auto i = start; i < stop; i++) { + auto xOffset = + shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); + auto yOffset = + shape::indexOffset(i, yShapeInfo, yShapeInfoCast, canCastY); + auto zOffset = + shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ); + z[zOffset] = OpType::op(x[xOffset], y[yOffset], extraParams); + }; + } + } } + +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT PairWiseBoolTransform, , + LIBND4J_TYPES, BOOL_TYPES); +} // namespace pairwise_transforms +} // namespace functions diff --git a/libnd4j/include/loops/cpu/pairwise_int.cpp b/libnd4j/include/loops/cpu/pairwise_int.cpp index b822166110f4..e77592103a01 100644 --- a/libnd4j/include/loops/cpu/pairwise_int.cpp +++ b/libnd4j/include/loops/cpu/pairwise_int.cpp @@ -18,205 +18,210 @@ // @author raver119@gmail.com // -#include -#include +#include #include #include -#include +#include +#include using namespace simdOps; namespace functions { - namespace pairwise_transforms { - - template - void PairWiseIntTransform::exec(const int opNum, - const void *x, Nd4jLong xEws, - const void *y, Nd4jLong yEws, - void *z, Nd4jLong zEws, - void *extraParams, - Nd4jLong n, - const uint64_t start, const uint64_t stop) { - DISPATCH_BY_OPNUM_T(exec, PARAMS(x, - xEws, - y, - yEws, - z, - zEws, - extraParams, - n, start, stop), PAIRWISE_INT_OPS); - }; - - - - template - template - void PairWiseIntTransform::exec(const void *vx, Nd4jLong xEws, - const void *vy, Nd4jLong yEws, - void *vz, Nd4jLong zEws, - void *vextraParams, - const Nd4jLong n, - const uint64_t start, - const uint64_t stop) { - auto x = reinterpret_cast(vx); - auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); - auto extraParams = reinterpret_cast(vextraParams); - - if (xEws == 1 && yEws == 1 && zEws == 1) { - PRAGMA_OMP_SIMD - for (auto i = start; i < stop; i++) - z[i] = OpType::op(x[i], y[i], extraParams); - } - else { - PRAGMA_OMP_SIMD - for (auto i = start; i < stop; i++) - z[i*zEws] = OpType::op(x[i*xEws], y[i*yEws], extraParams); - } - } - - template - void PairWiseIntTransform::exec(const int opNum, - const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *z, const Nd4jLong *zShapeInfo, - void *extraParams, - const uint64_t start, const uint64_t stop) { - DISPATCH_BY_OPNUM_T(exec, PARAMS(x, - xShapeInfo, - y, - yShapeInfo, - z, - zShapeInfo, - extraParams, start, stop), - PAIRWISE_INT_OPS); - }; - - - template - template - void PairWiseIntTransform::exec(const void *vx, const Nd4jLong* xShapeInfo, - const void *vy, const Nd4jLong* yShapeInfo, - void *vz, const Nd4jLong* zShapeInfo, - void *vextraParams, - const uint64_t start, - const uint64_t stop) { - - auto x = reinterpret_cast(vx); - auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); - auto extraParams = reinterpret_cast(vextraParams); - - auto n = shape::length(xShapeInfo); - auto xEws = shape::elementWiseStride(xShapeInfo); - auto yEws = shape::elementWiseStride(yShapeInfo); - auto zEws = shape::elementWiseStride(zShapeInfo); - - if (shape::isScalar(yShapeInfo)) { - - uint xShapeInfoCast[MAX_RANK]; - const bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); - - if(shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) { - PRAGMA_OMP_SIMD - for(auto i = start; i < stop; i++) { - auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); - z[offset] = OpType::op(x[offset], y[0], extraParams); - }; - } - else { - uint zShapeInfoCast[MAX_RANK]; - const bool canCastZ = sd::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast); - - PRAGMA_OMP_SIMD - for(auto i = start; i < stop; i++) { - auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); - auto zOffset = shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ); - z[zOffset] = OpType::op(x[xOffset], y[0], extraParams); - }; - } - return; - } - - const sd::LoopKind::Kind kindOfLoop = sd::LoopKind::deduceKindOfLoopXYZ(xShapeInfo, yShapeInfo, zShapeInfo); - const bool sameShapesXY = shape::shapeEquals(xShapeInfo, yShapeInfo); - - if ((kindOfLoop == sd::LoopKind::EWS1 || kindOfLoop == sd::LoopKind::EWSNONZERO) && sameShapesXY) { - exec(x, xEws, y, yEws, z, zEws, extraParams, n, start, stop); - } - else if ((kindOfLoop == sd::LoopKind::EWS1 || kindOfLoop == sd::LoopKind::EWSNONZERO) && !sameShapesXY) { //not same shape - exec(x, xEws, y, yEws, z, zEws, extraParams, shape::length(yShapeInfo), start, stop); - } - else { - - if(shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo) && shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) { - uint xShapeInfoCast[MAX_RANK]; - const bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); - - PRAGMA_OMP_SIMD - for (auto i = start; i < stop; i++) { - auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); - z[offset] = OpType::op(x[offset], y[offset], extraParams); - }; - } - else if(shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo)) { - uint xShapeInfoCast[MAX_RANK]; - uint zShapeInfoCast[MAX_RANK]; - const bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); - const bool canCastZ = sd::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast); - - PRAGMA_OMP_SIMD - for (auto i = start; i < stop; i++) { - auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); - auto zOffset = shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ); - z[zOffset] = OpType::op(x[offset], y[offset], extraParams); - }; - } - else if(shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) { - uint xShapeInfoCast[MAX_RANK]; - uint yShapeInfoCast[MAX_RANK]; - const bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); - const bool canCastY = sd::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); - - PRAGMA_OMP_SIMD - for (auto i = start; i < stop; i++) { - auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); - auto yOffset = shape::indexOffset(i, yShapeInfo, yShapeInfoCast, canCastY); - z[offset] = OpType::op(x[offset], y[yOffset], extraParams); - }; - } - else if(shape::haveSameShapeAndStrides(yShapeInfo, zShapeInfo)) { - uint xShapeInfoCast[MAX_RANK]; - uint yShapeInfoCast[MAX_RANK]; - const bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); - const bool canCastY = sd::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); - - PRAGMA_OMP_SIMD - for (auto i = start; i < stop; i++) { - auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); - auto offset = shape::indexOffset(i, yShapeInfo, yShapeInfoCast, canCastY); - z[offset] = OpType::op(x[xOffset], y[offset], extraParams); - }; - } - else { - uint xShapeInfoCast[MAX_RANK]; - uint yShapeInfoCast[MAX_RANK]; - uint zShapeInfoCast[MAX_RANK]; - const bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); - const bool canCastY = sd::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); - const bool canCastZ = sd::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast); - - PRAGMA_OMP_SIMD - for (auto i = start; i < stop; i++) { - auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); - auto yOffset = shape::indexOffset(i, yShapeInfo, yShapeInfoCast, canCastY); - auto zOffset = shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ); - z[zOffset] = OpType::op(x[xOffset], y[yOffset], extraParams); - }; - } - } - } - - BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT PairWiseIntTransform, , INTEGER_TYPES); +namespace pairwise_transforms { + +template +void PairWiseIntTransform::exec(const int opNum, const void *x, + Nd4jLong xEws, const void *y, Nd4jLong yEws, + void *z, Nd4jLong zEws, void *extraParams, + Nd4jLong n, const uint64_t start, + const uint64_t stop) { + DISPATCH_BY_OPNUM_T( + exec, PARAMS(x, xEws, y, yEws, z, zEws, extraParams, n, start, stop), + PAIRWISE_INT_OPS); +}; + +template +template +void PairWiseIntTransform::exec(const void *vx, Nd4jLong xEws, + const void *vy, Nd4jLong yEws, void *vz, + Nd4jLong zEws, void *vextraParams, + const Nd4jLong n, const uint64_t start, + const uint64_t stop) { + auto x = reinterpret_cast(vx); + auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + auto extraParams = reinterpret_cast(vextraParams); + + if (xEws == 1 && yEws == 1 && zEws == 1) { + PRAGMA_OMP_SIMD + for (auto i = start; i < stop; i++) + z[i] = OpType::op(x[i], y[i], extraParams); + } else { + PRAGMA_OMP_SIMD + for (auto i = start; i < stop; i++) + z[i * zEws] = OpType::op(x[i * xEws], y[i * yEws], extraParams); + } +} + +template +void PairWiseIntTransform::exec(const int opNum, const void *x, + const Nd4jLong *xShapeInfo, const void *y, + const Nd4jLong *yShapeInfo, void *z, + const Nd4jLong *zShapeInfo, + void *extraParams, const uint64_t start, + const uint64_t stop) { + DISPATCH_BY_OPNUM_T(exec, + PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, + extraParams, start, stop), + PAIRWISE_INT_OPS); +}; + +template +template +void PairWiseIntTransform::exec(const void *vx, const Nd4jLong *xShapeInfo, + const void *vy, const Nd4jLong *yShapeInfo, + void *vz, const Nd4jLong *zShapeInfo, + void *vextraParams, const uint64_t start, + const uint64_t stop) { + auto x = reinterpret_cast(vx); + auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + auto extraParams = reinterpret_cast(vextraParams); + + auto n = shape::length(xShapeInfo); + auto xEws = shape::elementWiseStride(xShapeInfo); + auto yEws = shape::elementWiseStride(yShapeInfo); + auto zEws = shape::elementWiseStride(zShapeInfo); + + if (shape::isScalar(yShapeInfo)) { + uint xShapeInfoCast[MAX_RANK]; + const bool canCastX = + sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + + if (shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) { + PRAGMA_OMP_SIMD + for (auto i = start; i < stop; i++) { + auto offset = + shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); + z[offset] = OpType::op(x[offset], y[0], extraParams); + }; + } else { + uint zShapeInfoCast[MAX_RANK]; + const bool canCastZ = + sd::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast); + + PRAGMA_OMP_SIMD + for (auto i = start; i < stop; i++) { + auto xOffset = + shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); + auto zOffset = + shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ); + z[zOffset] = OpType::op(x[xOffset], y[0], extraParams); + }; } + return; + } + + const sd::LoopKind::Kind kindOfLoop = + sd::LoopKind::deduceKindOfLoopXYZ(xShapeInfo, yShapeInfo, zShapeInfo); + const bool sameShapesXY = shape::shapeEquals(xShapeInfo, yShapeInfo); + + if ((kindOfLoop == sd::LoopKind::EWS1 || + kindOfLoop == sd::LoopKind::EWSNONZERO) && + sameShapesXY) { + exec(x, xEws, y, yEws, z, zEws, extraParams, n, start, stop); + } else if ((kindOfLoop == sd::LoopKind::EWS1 || + kindOfLoop == sd::LoopKind::EWSNONZERO) && + !sameShapesXY) { // not same shape + exec(x, xEws, y, yEws, z, zEws, extraParams, + shape::length(yShapeInfo), start, stop); + } else { + if (shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo) && + shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) { + uint xShapeInfoCast[MAX_RANK]; + const bool canCastX = + sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + + PRAGMA_OMP_SIMD + for (auto i = start; i < stop; i++) { + auto offset = + shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); + z[offset] = OpType::op(x[offset], y[offset], extraParams); + }; + } else if (shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo)) { + uint xShapeInfoCast[MAX_RANK]; + uint zShapeInfoCast[MAX_RANK]; + const bool canCastX = + sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + const bool canCastZ = + sd::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast); + + PRAGMA_OMP_SIMD + for (auto i = start; i < stop; i++) { + auto offset = + shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); + auto zOffset = + shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ); + z[zOffset] = OpType::op(x[offset], y[offset], extraParams); + }; + } else if (shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) { + uint xShapeInfoCast[MAX_RANK]; + uint yShapeInfoCast[MAX_RANK]; + const bool canCastX = + sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + const bool canCastY = + sd::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); + + PRAGMA_OMP_SIMD + for (auto i = start; i < stop; i++) { + auto offset = + shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); + auto yOffset = + shape::indexOffset(i, yShapeInfo, yShapeInfoCast, canCastY); + z[offset] = OpType::op(x[offset], y[yOffset], extraParams); + }; + } else if (shape::haveSameShapeAndStrides(yShapeInfo, zShapeInfo)) { + uint xShapeInfoCast[MAX_RANK]; + uint yShapeInfoCast[MAX_RANK]; + const bool canCastX = + sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + const bool canCastY = + sd::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); + + PRAGMA_OMP_SIMD + for (auto i = start; i < stop; i++) { + auto xOffset = + shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); + auto offset = + shape::indexOffset(i, yShapeInfo, yShapeInfoCast, canCastY); + z[offset] = OpType::op(x[xOffset], y[offset], extraParams); + }; + } else { + uint xShapeInfoCast[MAX_RANK]; + uint yShapeInfoCast[MAX_RANK]; + uint zShapeInfoCast[MAX_RANK]; + const bool canCastX = + sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + const bool canCastY = + sd::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); + const bool canCastZ = + sd::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast); + + PRAGMA_OMP_SIMD + for (auto i = start; i < stop; i++) { + auto xOffset = + shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); + auto yOffset = + shape::indexOffset(i, yShapeInfo, yShapeInfoCast, canCastY); + auto zOffset = + shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ); + z[zOffset] = OpType::op(x[xOffset], y[yOffset], extraParams); + }; + } + } } + +BUILD_SINGLE_TEMPLATE(template class SD_EXPORT PairWiseIntTransform, , + INTEGER_TYPES); +} // namespace pairwise_transforms +} // namespace functions diff --git a/libnd4j/include/loops/cpu/random.hpp b/libnd4j/include/loops/cpu/random.hpp index ea1dc9e7645e..6e80368f6f17 100644 --- a/libnd4j/include/loops/cpu/random.hpp +++ b/libnd4j/include/loops/cpu/random.hpp @@ -19,263 +19,298 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include -#include -#include #include +#include +#include +#include using namespace randomOps; namespace functions { - namespace random { - - - template - template - void RandomFunction::execTransform(Nd4jPointer state, - const void *vx, const Nd4jLong *xShapeInfo, - const void *vy, const Nd4jLong *yShapeInfo, - void *vz, const Nd4jLong *zShapeInfo, - void *vextraArguments) { - - auto x = reinterpret_cast(vx); - auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); - auto extraArguments = reinterpret_cast(vextraArguments); - - if (OpClass::requiresSpecial) { - OpClass::specialOp(state, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraArguments); - return; - } - - auto length = shape::length(zShapeInfo); - - sd::graph::RandomGenerator* rng = reinterpret_cast(state); - - if(shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo) && shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) { - - - if(shape::elementWiseStride(zShapeInfo) == 1 && shape::elementWiseStride(xShapeInfo) == 1 && shape::elementWiseStride(yShapeInfo) == 1 && - shape::order(xShapeInfo) == shape::order(zShapeInfo) && shape::order(zShapeInfo) == shape::order(yShapeInfo) ){ - - auto func = PRAGMA_THREADS_FOR { - PRAGMA_OMP_SIMD - for (auto i = start; i < stop; i++) { - z[i] = OpClass::op(x[i], y[i], i, length, rng, extraArguments); - } - }; - samediff::Threads::parallel_for(func, 0, length, 1); - } - else{ - uint xShapeInfoCast[MAX_RANK]; - const bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); - - auto func = PRAGMA_THREADS_FOR { - PRAGMA_OMP_SIMD - for (auto i = start; i < stop; i++) { - auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); - z[offset] = OpClass::op(x[offset], y[offset], i, length, rng, extraArguments); - } - }; - - samediff::Threads::parallel_for(func, 0, length, 1); - } - } - else if (shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo)) { - - uint xShapeInfoCast[MAX_RANK]; - uint zShapeInfoCast[MAX_RANK]; - const bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); - const bool canCastZ = sd::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast); - - auto func = PRAGMA_THREADS_FOR { - PRAGMA_OMP_SIMD - for (auto i = start; i < stop; i++) { - auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); - auto zOffset = shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ); - z[zOffset] = OpClass::op(x[offset], y[offset], i, length, rng, extraArguments); - } - }; - - samediff::Threads::parallel_for(func, 0, length, 1); - } - else if (shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) { - - uint xShapeInfoCast[MAX_RANK]; - uint yShapeInfoCast[MAX_RANK]; - const bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); - const bool canCastY = sd::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); - - auto func = PRAGMA_THREADS_FOR { - PRAGMA_OMP_SIMD - for (auto i = start; i < stop; i++) { - auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); - auto yOffset = shape::indexOffset(i, yShapeInfo, yShapeInfoCast, canCastY); - z[offset] = OpClass::op(x[offset], y[yOffset], i, length, rng, extraArguments); - } - }; - - samediff::Threads::parallel_for(func, 0, length, 1); - } - else if (shape::haveSameShapeAndStrides(yShapeInfo, zShapeInfo)) { - - uint xShapeInfoCast[MAX_RANK]; - uint yShapeInfoCast[MAX_RANK]; - const bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); - const bool canCastY = sd::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); - - auto func = PRAGMA_THREADS_FOR { - PRAGMA_OMP_SIMD - for (auto i = start; i < stop; i++) { - auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); - auto offset = shape::indexOffset(i, yShapeInfo, yShapeInfoCast, canCastY); - z[offset] = OpClass::op(x[xOffset], y[offset], i, length, rng, extraArguments); - } - }; - - samediff::Threads::parallel_for(func, 0, length, 1); - } - else { - - uint xShapeInfoCast[MAX_RANK]; - uint yShapeInfoCast[MAX_RANK]; - uint zShapeInfoCast[MAX_RANK]; - const bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); - const bool canCastY = sd::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); - const bool canCastZ = sd::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast); - - auto func = PRAGMA_THREADS_FOR { - PRAGMA_OMP_SIMD - for (auto i = start; i < stop; i++) { - auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); - auto yOffset = shape::indexOffset(i, yShapeInfo, yShapeInfoCast, canCastY); - auto zOffset = shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ); - z[zOffset] = OpClass::op(x[xOffset], y[yOffset], i, length, rng, extraArguments); - } - }; - - samediff::Threads::parallel_for(func, 0, length, 1); - } - }; - - - - template - template - void RandomFunction::execTransform(Nd4jPointer state, - const void *vx, const Nd4jLong *xShapeInfo, - void *vz, const Nd4jLong *zShapeInfo, - void *vextraArguments) { - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - auto extraArguments = reinterpret_cast(vextraArguments); - - auto length = shape::length(zShapeInfo); - - uint xShapeInfoCast[MAX_RANK]; - const bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); - - sd::graph::RandomGenerator* rng = reinterpret_cast(state); - - if(shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) { - - if(shape::elementWiseStride(zShapeInfo) == 1 && shape::elementWiseStride(xShapeInfo) == 1 && shape::order(xShapeInfo) == shape::order(zShapeInfo)){ - - auto func = PRAGMA_THREADS_FOR { - PRAGMA_OMP_SIMD - for (auto i = start; i < stop; i++) { - z[i] = OpClass::op(x[i], i, length, rng, extraArguments); - } - }; - samediff::Threads::parallel_for(func, 0, length, 1); - } - else{ - auto func = PRAGMA_THREADS_FOR { - PRAGMA_OMP_SIMD - for (auto i = start; i < stop; i++) { - auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); - z[offset] = OpClass::op(x[offset], i, length, rng, extraArguments); - } - }; - - samediff::Threads::parallel_for(func, 0, length, 1); - } - } - else { - - uint zShapeInfoCast[MAX_RANK]; - const bool canCastZ = sd::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast); - - auto func = PRAGMA_THREADS_FOR { - PRAGMA_OMP_SIMD - for (auto i = start; i < stop; i++) { - auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); - auto zOffset = shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ); - z[zOffset] = OpClass::op(x[xOffset], i, length, rng, extraArguments); - } - }; - - samediff::Threads::parallel_for(func, 0, length, 1); - } - } - - - template - template - void RandomFunction::execTransform(Nd4jPointer state, void *vz, const Nd4jLong *zShapeInfo, void *vextraArguments) { - - auto z = reinterpret_cast(vz); - auto extraArguments = reinterpret_cast(vextraArguments); - - auto length = shape::length(zShapeInfo); - - sd::graph::RandomGenerator* rng = reinterpret_cast(state); - - if(shape::elementWiseStride(zShapeInfo) == 1){ - - auto func = PRAGMA_THREADS_FOR { - PRAGMA_OMP_SIMD - for (auto i = start; i < stop; i++) { - z[i] = OpClass::op( i, length, rng, extraArguments); - } - }; - - samediff::Threads::parallel_for(func, 0, length, 1); - } - else{ - sd::OmpLaunchHelper info(length); - - uint zShapeInfoCast[MAX_RANK]; - const bool canCastZ = sd::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast); - - auto func = PRAGMA_THREADS_FOR { - PRAGMA_OMP_SIMD - for (auto i = start; i < stop; i++) { - auto offset = shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ); - z[offset] = OpClass::op(i, length, rng, extraArguments); - } - }; - - samediff::Threads::parallel_for(func, 0, length, 1); - } +namespace random { + +template +template +void RandomFunction::execTransform(Nd4jPointer state, const void *vx, + const Nd4jLong *xShapeInfo, + const void *vy, + const Nd4jLong *yShapeInfo, void *vz, + const Nd4jLong *zShapeInfo, + void *vextraArguments) { + auto x = reinterpret_cast(vx); + auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + auto extraArguments = reinterpret_cast(vextraArguments); + + if (OpClass::requiresSpecial) { + OpClass::specialOp(state, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, + extraArguments); + return; + } + + auto length = shape::length(zShapeInfo); + + sd::graph::RandomGenerator *rng = + reinterpret_cast(state); + + if (shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo) && + shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) { + if (shape::elementWiseStride(zShapeInfo) == 1 && + shape::elementWiseStride(xShapeInfo) == 1 && + shape::elementWiseStride(yShapeInfo) == 1 && + shape::order(xShapeInfo) == shape::order(zShapeInfo) && + shape::order(zShapeInfo) == shape::order(yShapeInfo)) { + auto func = PRAGMA_THREADS_FOR { + PRAGMA_OMP_SIMD + for (auto i = start; i < stop; i++) { + z[i] = OpClass::op(x[i], y[i], i, length, rng, extraArguments); } - - template - void RandomFunction::execTransform(int opNum, Nd4jPointer state, const void *x, const Nd4jLong *xShapeInfo, void *z, const Nd4jLong *zShapeInfo, void *extraArguments) { - DISPATCH_BY_OPNUM_T(execTransform, PARAMS(state, x, xShapeInfo, z, zShapeInfo, extraArguments), RANDOM_OPS) + }; + samediff::Threads::parallel_for(func, 0, length, 1); + } else { + uint xShapeInfoCast[MAX_RANK]; + const bool canCastX = + sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + + auto func = PRAGMA_THREADS_FOR { + PRAGMA_OMP_SIMD + for (auto i = start; i < stop; i++) { + auto offset = + shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); + z[offset] = + OpClass::op(x[offset], y[offset], i, length, rng, extraArguments); } + }; - template - void RandomFunction::execTransform(int opNum, Nd4jPointer state, const void *x, const Nd4jLong *xShapeInfo, const void *y, const Nd4jLong *yShapeInfo, void *z, const Nd4jLong *zShapeInfo, void *extraArguments) { - DISPATCH_BY_OPNUM_T(execTransform, PARAMS(state, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraArguments), RANDOM_OPS) + samediff::Threads::parallel_for(func, 0, length, 1); + } + } else if (shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo)) { + uint xShapeInfoCast[MAX_RANK]; + uint zShapeInfoCast[MAX_RANK]; + const bool canCastX = + sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + const bool canCastZ = + sd::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast); + + auto func = PRAGMA_THREADS_FOR { + PRAGMA_OMP_SIMD + for (auto i = start; i < stop; i++) { + auto offset = + shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); + auto zOffset = + shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ); + z[zOffset] = + OpClass::op(x[offset], y[offset], i, length, rng, extraArguments); + } + }; + + samediff::Threads::parallel_for(func, 0, length, 1); + } else if (shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) { + uint xShapeInfoCast[MAX_RANK]; + uint yShapeInfoCast[MAX_RANK]; + const bool canCastX = + sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + const bool canCastY = + sd::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); + + auto func = PRAGMA_THREADS_FOR { + PRAGMA_OMP_SIMD + for (auto i = start; i < stop; i++) { + auto offset = + shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); + auto yOffset = + shape::indexOffset(i, yShapeInfo, yShapeInfoCast, canCastY); + z[offset] = + OpClass::op(x[offset], y[yOffset], i, length, rng, extraArguments); + } + }; + + samediff::Threads::parallel_for(func, 0, length, 1); + } else if (shape::haveSameShapeAndStrides(yShapeInfo, zShapeInfo)) { + uint xShapeInfoCast[MAX_RANK]; + uint yShapeInfoCast[MAX_RANK]; + const bool canCastX = + sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + const bool canCastY = + sd::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); + + auto func = PRAGMA_THREADS_FOR { + PRAGMA_OMP_SIMD + for (auto i = start; i < stop; i++) { + auto xOffset = + shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); + auto offset = + shape::indexOffset(i, yShapeInfo, yShapeInfoCast, canCastY); + z[offset] = + OpClass::op(x[xOffset], y[offset], i, length, rng, extraArguments); + } + }; + + samediff::Threads::parallel_for(func, 0, length, 1); + } else { + uint xShapeInfoCast[MAX_RANK]; + uint yShapeInfoCast[MAX_RANK]; + uint zShapeInfoCast[MAX_RANK]; + const bool canCastX = + sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + const bool canCastY = + sd::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); + const bool canCastZ = + sd::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast); + + auto func = PRAGMA_THREADS_FOR { + PRAGMA_OMP_SIMD + for (auto i = start; i < stop; i++) { + auto xOffset = + shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); + auto yOffset = + shape::indexOffset(i, yShapeInfo, yShapeInfoCast, canCastY); + auto zOffset = + shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ); + z[zOffset] = + OpClass::op(x[xOffset], y[yOffset], i, length, rng, extraArguments); + } + }; + + samediff::Threads::parallel_for(func, 0, length, 1); + } +}; + +template +template +void RandomFunction::execTransform(Nd4jPointer state, const void *vx, + const Nd4jLong *xShapeInfo, void *vz, + const Nd4jLong *zShapeInfo, + void *vextraArguments) { + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + auto extraArguments = reinterpret_cast(vextraArguments); + + auto length = shape::length(zShapeInfo); + + uint xShapeInfoCast[MAX_RANK]; + const bool canCastX = + sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + + sd::graph::RandomGenerator *rng = + reinterpret_cast(state); + + if (shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) { + if (shape::elementWiseStride(zShapeInfo) == 1 && + shape::elementWiseStride(xShapeInfo) == 1 && + shape::order(xShapeInfo) == shape::order(zShapeInfo)) { + auto func = PRAGMA_THREADS_FOR { + PRAGMA_OMP_SIMD + for (auto i = start; i < stop; i++) { + z[i] = OpClass::op(x[i], i, length, rng, extraArguments); } - - template - void RandomFunction::execTransform(int opNum, Nd4jPointer state, void *z, const Nd4jLong *zShapeInfo, void *extraArguments) { - DISPATCH_BY_OPNUM_T(execTransform, PARAMS(state, z, zShapeInfo, extraArguments), RANDOM_OPS) + }; + samediff::Threads::parallel_for(func, 0, length, 1); + } else { + auto func = PRAGMA_THREADS_FOR { + PRAGMA_OMP_SIMD + for (auto i = start; i < stop; i++) { + auto offset = + shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); + z[offset] = OpClass::op(x[offset], i, length, rng, extraArguments); } + }; - - //BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT RandomFunction, , FLOAT_TYPES); + samediff::Threads::parallel_for(func, 0, length, 1); } -} \ No newline at end of file + } else { + uint zShapeInfoCast[MAX_RANK]; + const bool canCastZ = + sd::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast); + + auto func = PRAGMA_THREADS_FOR { + PRAGMA_OMP_SIMD + for (auto i = start; i < stop; i++) { + auto xOffset = + shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); + auto zOffset = + shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ); + z[zOffset] = OpClass::op(x[xOffset], i, length, rng, extraArguments); + } + }; + + samediff::Threads::parallel_for(func, 0, length, 1); + } +} + +template +template +void RandomFunction::execTransform(Nd4jPointer state, void *vz, + const Nd4jLong *zShapeInfo, + void *vextraArguments) { + auto z = reinterpret_cast(vz); + auto extraArguments = reinterpret_cast(vextraArguments); + + auto length = shape::length(zShapeInfo); + + sd::graph::RandomGenerator *rng = + reinterpret_cast(state); + + if (shape::elementWiseStride(zShapeInfo) == 1) { + auto func = PRAGMA_THREADS_FOR { + PRAGMA_OMP_SIMD + for (auto i = start; i < stop; i++) { + z[i] = OpClass::op(i, length, rng, extraArguments); + } + }; + + samediff::Threads::parallel_for(func, 0, length, 1); + } else { + sd::OmpLaunchHelper info(length); + + uint zShapeInfoCast[MAX_RANK]; + const bool canCastZ = + sd::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast); + + auto func = PRAGMA_THREADS_FOR { + PRAGMA_OMP_SIMD + for (auto i = start; i < stop; i++) { + auto offset = + shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ); + z[offset] = OpClass::op(i, length, rng, extraArguments); + } + }; + + samediff::Threads::parallel_for(func, 0, length, 1); + } +} + +template +void RandomFunction::execTransform(int opNum, Nd4jPointer state, + const void *x, const Nd4jLong *xShapeInfo, + void *z, const Nd4jLong *zShapeInfo, + void *extraArguments) { + DISPATCH_BY_OPNUM_T( + execTransform, + PARAMS(state, x, xShapeInfo, z, zShapeInfo, extraArguments), RANDOM_OPS) +} + +template +void RandomFunction::execTransform(int opNum, Nd4jPointer state, + const void *x, const Nd4jLong *xShapeInfo, + const void *y, const Nd4jLong *yShapeInfo, + void *z, const Nd4jLong *zShapeInfo, + void *extraArguments) { + DISPATCH_BY_OPNUM_T(execTransform, + PARAMS(state, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, + extraArguments), + RANDOM_OPS) +} + +template +void RandomFunction::execTransform(int opNum, Nd4jPointer state, void *z, + const Nd4jLong *zShapeInfo, + void *extraArguments) { + DISPATCH_BY_OPNUM_T(execTransform, + PARAMS(state, z, zShapeInfo, extraArguments), RANDOM_OPS) +} + +// BUILD_SINGLE_TEMPLATE(template class SD_EXPORT RandomFunction, , +// FLOAT_TYPES); +} // namespace random +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/reduce/reduce_bool.cpp b/libnd4j/include/loops/cpu/reduce/reduce_bool.cpp index 94e1567051cc..3ddf0dad0917 100644 --- a/libnd4j/include/loops/cpu/reduce/reduce_bool.cpp +++ b/libnd4j/include/loops/cpu/reduce/reduce_bool.cpp @@ -19,208 +19,232 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include -#include -#include -#include -#include -#include #include +#include +#include +#include +#include +#include +#include using namespace simdOps; namespace functions { - namespace reduce { - template - template - void _CUDA_H ReduceBoolFunction::execScalar(const void *vx, const Nd4jLong *xShapeInfo, - void *vextraParams, - void *vz, const Nd4jLong *zShapeInfo) { - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - auto extraParams = reinterpret_cast(vextraParams); - - const Nd4jLong length = shape::length(xShapeInfo); - auto xEws = shape::elementWiseStride(xShapeInfo); - - if (shape::isEmpty(xShapeInfo)) { - z[0] = OpType::startingValue(x); - return; - } - - if(sd::ArrayOptions::arrayType(xShapeInfo) == sd::ArrayType::EMPTY) { - if(sd::ArrayOptions::arrayType(zShapeInfo) == sd::ArrayType::EMPTY) - return; - const auto startingVal = OpType::startingValue(x); - - for (Nd4jLong i = 0; i < length; i++) - z[i] = startingVal; - return; - } - - if (xEws >= 1) { - z[0] = execScalar(x, xEws, length, extraParams); - } - else { - auto startingValue = OpType::startingValue(x); - uint xShapeInfoCast[MAX_RANK]; - const bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); - - for (Nd4jLong i = 0; i < length; i++) - startingValue = OpType::update(startingValue, OpType::op(x[shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX)], extraParams), extraParams); - - z[0] = OpType::postProcess(startingValue, length, extraParams); - } - } - - - template - template - Z _CUDA_H ReduceBoolFunction::execScalar(const void *vx, const Nd4jLong *xShapeInfo, void *vextraParams) { - - auto x = reinterpret_cast(vx); - auto extraParams = reinterpret_cast(vextraParams); - - const Nd4jLong length = shape::length(xShapeInfo); - auto xEws = shape::elementWiseStride(xShapeInfo); - - if (xEws >= 1) { - return execScalar(x, xEws, length, extraParams); - } - else { - auto startingValue = OpType::startingValue(x); - uint xShapeInfoCast[MAX_RANK]; - bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); - - for (Nd4jLong i = 0; i < length; i++) - startingValue = OpType::update(startingValue, OpType::op(x[shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX)], extraParams), extraParams); - - return OpType::postProcess(startingValue, length, extraParams); - } - } - - template - Y ReduceBoolFunction::execScalar(const int opNum, - const void *x, const Nd4jLong *xShapeInfo, - void *extraParams) { - RETURNING_DISPATCH_BY_OPNUM_TT(execScalar, PARAMS(x, xShapeInfo, extraParams), REDUCE_BOOL_OPS); - } - - template - void ReduceBoolFunction::execScalar(const int opNum, - const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *z, const Nd4jLong *zShapeInfo) { - DISPATCH_BY_OPNUM_TT(execScalar, PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo), REDUCE_BOOL_OPS); - } - - template - void ReduceBoolFunction::exec(const int opNum, - const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *z, const Nd4jLong *zShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset, - int64_t start, int64_t stop) { - DISPATCH_BY_OPNUM_TT(exec, PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffset, start, stop), REDUCE_BOOL_OPS); - } - - template - template - void _CUDA_H ReduceBoolFunction::exec(const void *vx, const Nd4jLong *xShapeInfo, - void *vextraParams, - void *vresult, const Nd4jLong *zShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset, int64_t start, int64_t stop) { - - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vresult); - auto extraParams = reinterpret_cast(vextraParams); - - auto resultLength = shape::length(zShapeInfo); - - if(sd::ArrayOptions::arrayType(xShapeInfo) == sd::ArrayType::EMPTY) { - if(sd::ArrayOptions::arrayType(zShapeInfo) == sd::ArrayType::EMPTY) - return; - const auto startingVal = OpType::startingValue(x); - - for (Nd4jLong i = 0; i < resultLength; i++) - z[i] = startingVal; - return; - } - - //pre squeezed: this is for keeping the pointer to the original - //shape information for tad offset - //the squeezed information doesn't render the right strides for - //tad offset - // || tad.wholeThing - if (resultLength == 1 || dimension == nullptr || dimensionLength == shape::rank(xShapeInfo)) { - z[0] = execScalar(x, xShapeInfo, extraParams); - return; - } - - auto tadOnlyShapeInfo = tadShapeInfo; - auto tadOffsets = tadOffset; - - if (tadOnlyShapeInfo == nullptr || tadOffsets == nullptr) { - if (dimensionLength < 1) - return; - - auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(xShapeInfo, dimension, dimensionLength); - tadOnlyShapeInfo = tadPack.primaryShapeInfo(); - tadOffsets = tadPack.primaryOffsets(); - } +namespace reduce { +template +template +void _CUDA_H ReduceBoolFunction::execScalar(const void *vx, + const Nd4jLong *xShapeInfo, + void *vextraParams, void *vz, + const Nd4jLong *zShapeInfo) { + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + auto extraParams = reinterpret_cast(vextraParams); + + const Nd4jLong length = shape::length(xShapeInfo); + auto xEws = shape::elementWiseStride(xShapeInfo); + + if (shape::isEmpty(xShapeInfo)) { + z[0] = OpType::startingValue(x); + return; + } + + if (sd::ArrayOptions::arrayType(xShapeInfo) == sd::ArrayType::EMPTY) { + if (sd::ArrayOptions::arrayType(zShapeInfo) == sd::ArrayType::EMPTY) return; + const auto startingVal = OpType::startingValue(x); + + for (Nd4jLong i = 0; i < length; i++) z[i] = startingVal; + return; + } + + if (xEws >= 1) { + z[0] = execScalar(x, xEws, length, extraParams); + } else { + auto startingValue = OpType::startingValue(x); + uint xShapeInfoCast[MAX_RANK]; + const bool canCastX = + sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + + for (Nd4jLong i = 0; i < length; i++) + startingValue = OpType::update( + startingValue, + OpType::op( + x[shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX)], + extraParams), + extraParams); + + z[0] = OpType::postProcess(startingValue, length, extraParams); + } +} + +template +template +Z _CUDA_H ReduceBoolFunction::execScalar(const void *vx, + const Nd4jLong *xShapeInfo, + void *vextraParams) { + auto x = reinterpret_cast(vx); + auto extraParams = reinterpret_cast(vextraParams); + + const Nd4jLong length = shape::length(xShapeInfo); + auto xEws = shape::elementWiseStride(xShapeInfo); + + if (xEws >= 1) { + return execScalar(x, xEws, length, extraParams); + } else { + auto startingValue = OpType::startingValue(x); + uint xShapeInfoCast[MAX_RANK]; + bool canCastX = + sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + + for (Nd4jLong i = 0; i < length; i++) + startingValue = OpType::update( + startingValue, + OpType::op( + x[shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX)], + extraParams), + extraParams); + + return OpType::postProcess(startingValue, length, extraParams); + } +} + +template +Y ReduceBoolFunction::execScalar(const int opNum, const void *x, + const Nd4jLong *xShapeInfo, + void *extraParams) { + RETURNING_DISPATCH_BY_OPNUM_TT(execScalar, PARAMS(x, xShapeInfo, extraParams), + REDUCE_BOOL_OPS); +} + +template +void ReduceBoolFunction::execScalar(const int opNum, const void *x, + const Nd4jLong *xShapeInfo, + void *extraParams, void *z, + const Nd4jLong *zShapeInfo) { + DISPATCH_BY_OPNUM_TT(execScalar, + PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo), + REDUCE_BOOL_OPS); +} + +template +void ReduceBoolFunction::exec( + const int opNum, const void *x, const Nd4jLong *xShapeInfo, + void *extraParams, void *z, const Nd4jLong *zShapeInfo, int *dimension, + int dimensionLength, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffset, int64_t start, int64_t stop) { + DISPATCH_BY_OPNUM_TT( + exec, + PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, + dimensionLength, tadShapeInfo, tadOffset, start, stop), + REDUCE_BOOL_OPS); +} + +template +template +void _CUDA_H ReduceBoolFunction::exec( + const void *vx, const Nd4jLong *xShapeInfo, void *vextraParams, + void *vresult, const Nd4jLong *zShapeInfo, int *dimension, + int dimensionLength, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffset, int64_t start, int64_t stop) { + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vresult); + auto extraParams = reinterpret_cast(vextraParams); + + auto resultLength = shape::length(zShapeInfo); + + if (sd::ArrayOptions::arrayType(xShapeInfo) == sd::ArrayType::EMPTY) { + if (sd::ArrayOptions::arrayType(zShapeInfo) == sd::ArrayType::EMPTY) return; + const auto startingVal = OpType::startingValue(x); + + for (Nd4jLong i = 0; i < resultLength; i++) z[i] = startingVal; + return; + } + + // pre squeezed: this is for keeping the pointer to the original + // shape information for tad offset + // the squeezed information doesn't render the right strides for + // tad offset + // || tad.wholeThing + if (resultLength == 1 || dimension == nullptr || + dimensionLength == shape::rank(xShapeInfo)) { + z[0] = execScalar(x, xShapeInfo, extraParams); + return; + } + + auto tadOnlyShapeInfo = tadShapeInfo; + auto tadOffsets = tadOffset; + + if (tadOnlyShapeInfo == nullptr || tadOffsets == nullptr) { + if (dimensionLength < 1) return; + + auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions( + xShapeInfo, dimension, dimensionLength); + tadOnlyShapeInfo = tadPack.primaryShapeInfo(); + tadOffsets = tadPack.primaryOffsets(); + } #ifdef INLINE_LOOPS - sd::ReductionLoops::template loopReduce(x, xShapeInfo, z, zShapeInfo, tadOnlyShapeInfo, tadOffsets, extraParams, start, stop); + sd::ReductionLoops::template loopReduce( + x, xShapeInfo, z, zShapeInfo, tadOnlyShapeInfo, tadOffsets, extraParams, + start, stop); #else - sd::ReductionBoolLoops::template innerloopReduce(x, xShapeInfo, z, zShapeInfo, tadOnlyShapeInfo, tadOffsets, extraParams, start, stop); + sd::ReductionBoolLoops::template innerloopReduce( + x, xShapeInfo, z, zShapeInfo, tadOnlyShapeInfo, tadOffsets, extraParams, + start, stop); #endif - } - - - template - template - void _CUDA_H ReduceBoolFunction::exec(const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *vresult, const Nd4jLong *resultShapeInfo) { - auto z = reinterpret_cast(vresult); - z[0] = execScalar(x, xShapeInfo, extraParams); - } - - template - template - Z _CUDA_H ReduceBoolFunction::execScalar(const void *vx, Nd4jLong xEws, Nd4jLong length, void *vextraParams) { - auto x = reinterpret_cast(vx); - auto extraParams = reinterpret_cast(vextraParams); - int maxThreads = sd::math::nd4j_min(64, sd::Environment::getInstance().maxThreads()); - Z intermediate[64]; - - PRAGMA_OMP_SIMD - for (auto e = 0; e < maxThreads; e++) - intermediate[e] = OpType::startingValue(x); - - auto func = PRAGMA_THREADS_FOR { - if (xEws == 1) { - for (auto i = start; i < stop; i++) - intermediate[thread_id] = OpType::update(intermediate[thread_id], OpType::op(x[i], extraParams), extraParams); - } else { - for (auto i = start; i < stop; i++) - intermediate[thread_id] = OpType::update(intermediate[thread_id], OpType::op(x[i * xEws], extraParams), extraParams); - } - }; - - maxThreads = samediff::Threads::parallel_for(func, 0, length, 1, maxThreads); - - // merge results - for (int e = 1; e < maxThreads; e++) - intermediate[0] = OpType::update(intermediate[0], intermediate[e], extraParams); - - // return result - return OpType::postProcess(intermediate[0], length, extraParams); - } - - - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT ReduceBoolFunction, , LIBND4J_TYPES, BOOL_TYPES); +} + +template +template +void _CUDA_H ReduceBoolFunction::exec(const void *x, + const Nd4jLong *xShapeInfo, + void *extraParams, void *vresult, + const Nd4jLong *resultShapeInfo) { + auto z = reinterpret_cast(vresult); + z[0] = execScalar(x, xShapeInfo, extraParams); +} + +template +template +Z _CUDA_H ReduceBoolFunction::execScalar(const void *vx, Nd4jLong xEws, + Nd4jLong length, + void *vextraParams) { + auto x = reinterpret_cast(vx); + auto extraParams = reinterpret_cast(vextraParams); + int maxThreads = + sd::math::nd4j_min(64, sd::Environment::getInstance().maxThreads()); + Z intermediate[64]; + + PRAGMA_OMP_SIMD + for (auto e = 0; e < maxThreads; e++) + intermediate[e] = OpType::startingValue(x); + + auto func = PRAGMA_THREADS_FOR { + if (xEws == 1) { + for (auto i = start; i < stop; i++) + intermediate[thread_id] = + OpType::update(intermediate[thread_id], + OpType::op(x[i], extraParams), extraParams); + } else { + for (auto i = start; i < stop; i++) + intermediate[thread_id] = + OpType::update(intermediate[thread_id], + OpType::op(x[i * xEws], extraParams), extraParams); } -} \ No newline at end of file + }; + + maxThreads = samediff::Threads::parallel_for(func, 0, length, 1, maxThreads); + + // merge results + for (int e = 1; e < maxThreads; e++) + intermediate[0] = + OpType::update(intermediate[0], intermediate[e], extraParams); + + // return result + return OpType::postProcess(intermediate[0], length, extraParams); +} + +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT ReduceBoolFunction, , + LIBND4J_TYPES, BOOL_TYPES); +} // namespace reduce +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/reduce/reduce_float.hpp b/libnd4j/include/loops/cpu/reduce/reduce_float.hpp index 6be93b1c4223..97059ea4f19d 100644 --- a/libnd4j/include/loops/cpu/reduce/reduce_float.hpp +++ b/libnd4j/include/loops/cpu/reduce/reduce_float.hpp @@ -19,241 +19,261 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include -#include -#include -#include -#include -#include #include +#include +#include +#include +#include +#include +#include using namespace simdOps; namespace functions { - namespace reduce { - template - template - void _CUDA_H ReduceFloatFunction::execScalar(const void *vx, const Nd4jLong *xShapeInfo, - void *vextraParams, - void *vz, const Nd4jLong *zShapeInfo) { - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - auto extraParams = reinterpret_cast(vextraParams); - - const Nd4jLong length = shape::length(xShapeInfo); - auto xEws = shape::elementWiseStride(xShapeInfo); - - if (shape::isEmpty(xShapeInfo)) { - if (std::is_same>::value) { - z[0] = sd::DataTypeUtils::nanOrZero(); - } else { - z[0] = OpType::startingValue(x); - } - return; - } - - if(sd::ArrayOptions::arrayType(xShapeInfo) == sd::ArrayType::EMPTY) { - if(sd::ArrayOptions::arrayType(zShapeInfo) == sd::ArrayType::EMPTY) - return; - const auto startingVal = OpType::startingValue(x); - - for (Nd4jLong i = 0; i < length; i++) - z[i] = startingVal; - - return; - } - - if (xEws > 0) { - z[0] = execScalar(x, xEws, length, extraParams); - } - else { - auto startingValue = OpType::startingValue(x); - uint xShapeInfoCast[MAX_RANK]; - const bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); - int maxThreads = sd::math::nd4j_min(64, sd::Environment::getInstance().maxThreads()); - Z intermediate[64]; - - PRAGMA_OMP_SIMD - for (auto e = 0; e < maxThreads; e++) - intermediate[e] = OpType::startingValue(x); - - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) - intermediate[thread_id] = OpType::update(intermediate[thread_id], OpType::op(x[shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX)], extraParams), extraParams); - }; - - maxThreads = samediff::Threads::parallel_for(func, 0, length, 1, maxThreads); - - // merge results - for (int e = 1; e < maxThreads; e++) - intermediate[0] = OpType::update(intermediate[0], intermediate[e], extraParams); - - // write out results - z[0] = OpType::postProcess(intermediate[0], length, extraParams); - } - } - - - template - template - Z _CUDA_H ReduceFloatFunction::execScalar(const void *vx, const Nd4jLong *xShapeInfo, void *vextraParams) { - auto x = reinterpret_cast(vx); - auto extraParams = reinterpret_cast(vextraParams); - - const Nd4jLong length = shape::length(xShapeInfo); - int xEws = shape::elementWiseStride(xShapeInfo); - - if (xEws > 0) { - return execScalar(x, xEws, length, extraParams); - } - else { - auto startingValue = OpType::startingValue(x); - uint xShapeInfoCast[MAX_RANK]; - bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); - - for (Nd4jLong i = 0; i < length; i++) - startingValue = OpType::update(startingValue, OpType::op(x[shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX)], extraParams), extraParams); - - return OpType::postProcess(startingValue, length, extraParams); - } - } - - template - Y ReduceFloatFunction::execScalar(const int opNum, - const void *x, const Nd4jLong *xShapeInfo, - void *extraParams) { - RETURNING_DISPATCH_BY_OPNUM_TT(execScalar, PARAMS(x, xShapeInfo, extraParams), REDUCE_FLOAT_OPS); - } - - template - void ReduceFloatFunction::execScalar(const int opNum, - const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *z, const Nd4jLong *zShapeInfo) { - DISPATCH_BY_OPNUM_TT(execScalar, PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo), REDUCE_FLOAT_OPS); - } - - template - void ReduceFloatFunction::exec(const int opNum, - const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *z, const Nd4jLong *zShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset, - int64_t start, int64_t stop) { - DISPATCH_BY_OPNUM_TT(exec, PARAMS(x, - xShapeInfo, - extraParams, - z, - zShapeInfo, - dimension, - dimensionLength, - tadShapeInfo, - tadOffset, start, stop), - REDUCE_FLOAT_OPS); - } - - template - template - void _CUDA_H ReduceFloatFunction::exec(const void *vx, const Nd4jLong *xShapeInfo, - void *vextraParams, - void *vresult, const Nd4jLong *zShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset, - int64_t start, int64_t stop) { - - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vresult); - auto extraParams = reinterpret_cast(vextraParams); - - auto resultLength = shape::length(zShapeInfo); - - if(sd::ArrayOptions::arrayType(xShapeInfo) == sd::ArrayType::EMPTY) { - if(sd::ArrayOptions::arrayType(zShapeInfo) == sd::ArrayType::EMPTY) - return; - const auto startingVal = std::is_same>::value ? sd::DataTypeUtils::nanOrZero() : static_cast(OpType::startingValue(x)); - - for (Nd4jLong i = 0; i < resultLength; i++) - z[i] = startingVal; - return; - } - - //pre squeezed: this is for keeping the pointer to the original - //shape information for tad offset - //the squeezed information doesn't render the right strides for - //tad offset - // || tad.wholeThing - if (resultLength == 1 || dimension == nullptr || dimensionLength == shape::rank(xShapeInfo)) { - z[0] = execScalar(x, xShapeInfo, extraParams); - return; - } - - if (OpType::requiresSpecialAccumulation) { - OpType::execSpecial(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffset); - return; - } - - auto tadOnlyShapeInfo = tadShapeInfo; - auto tadOffsets = tadOffset; - - if (tadOnlyShapeInfo == nullptr || tadOffsets == nullptr) { - if (dimensionLength < 0) - return; - - auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(xShapeInfo, dimension, dimensionLength); - tadOnlyShapeInfo = tadPack.primaryShapeInfo(); - tadOffsets = tadPack.primaryOffsets(); - } +namespace reduce { +template +template +void _CUDA_H ReduceFloatFunction::execScalar(const void *vx, + const Nd4jLong *xShapeInfo, + void *vextraParams, void *vz, + const Nd4jLong *zShapeInfo) { + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + auto extraParams = reinterpret_cast(vextraParams); + + const Nd4jLong length = shape::length(xShapeInfo); + auto xEws = shape::elementWiseStride(xShapeInfo); + + if (shape::isEmpty(xShapeInfo)) { + if (std::is_same>::value) { + z[0] = sd::DataTypeUtils::nanOrZero(); + } else { + z[0] = OpType::startingValue(x); + } + return; + } + + if (sd::ArrayOptions::arrayType(xShapeInfo) == sd::ArrayType::EMPTY) { + if (sd::ArrayOptions::arrayType(zShapeInfo) == sd::ArrayType::EMPTY) return; + const auto startingVal = OpType::startingValue(x); + + for (Nd4jLong i = 0; i < length; i++) z[i] = startingVal; + + return; + } + + if (xEws > 0) { + z[0] = execScalar(x, xEws, length, extraParams); + } else { + auto startingValue = OpType::startingValue(x); + uint xShapeInfoCast[MAX_RANK]; + const bool canCastX = + sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + int maxThreads = sd::math::nd4j_min( + 64, sd::Environment::getInstance().maxThreads()); + Z intermediate[64]; + + PRAGMA_OMP_SIMD + for (auto e = 0; e < maxThreads; e++) + intermediate[e] = OpType::startingValue(x); + + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) + intermediate[thread_id] = OpType::update( + intermediate[thread_id], + OpType::op( + x[shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX)], + extraParams), + extraParams); + }; + + maxThreads = + samediff::Threads::parallel_for(func, 0, length, 1, maxThreads); + + // merge results + for (int e = 1; e < maxThreads; e++) + intermediate[0] = + OpType::update(intermediate[0], intermediate[e], extraParams); + + // write out results + z[0] = OpType::postProcess(intermediate[0], length, extraParams); + } +} + +template +template +Z _CUDA_H ReduceFloatFunction::execScalar(const void *vx, + const Nd4jLong *xShapeInfo, + void *vextraParams) { + auto x = reinterpret_cast(vx); + auto extraParams = reinterpret_cast(vextraParams); + + const Nd4jLong length = shape::length(xShapeInfo); + int xEws = shape::elementWiseStride(xShapeInfo); + + if (xEws > 0) { + return execScalar(x, xEws, length, extraParams); + } else { + auto startingValue = OpType::startingValue(x); + uint xShapeInfoCast[MAX_RANK]; + bool canCastX = + sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + + for (Nd4jLong i = 0; i < length; i++) + startingValue = OpType::update( + startingValue, + OpType::op( + x[shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX)], + extraParams), + extraParams); + + return OpType::postProcess(startingValue, length, extraParams); + } +} + +template +Y ReduceFloatFunction::execScalar(const int opNum, const void *x, + const Nd4jLong *xShapeInfo, + void *extraParams) { + RETURNING_DISPATCH_BY_OPNUM_TT(execScalar, PARAMS(x, xShapeInfo, extraParams), + REDUCE_FLOAT_OPS); +} + +template +void ReduceFloatFunction::execScalar(const int opNum, const void *x, + const Nd4jLong *xShapeInfo, + void *extraParams, void *z, + const Nd4jLong *zShapeInfo) { + DISPATCH_BY_OPNUM_TT(execScalar, + PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo), + REDUCE_FLOAT_OPS); +} + +template +void ReduceFloatFunction::exec( + const int opNum, const void *x, const Nd4jLong *xShapeInfo, + void *extraParams, void *z, const Nd4jLong *zShapeInfo, int *dimension, + int dimensionLength, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffset, int64_t start, int64_t stop) { + DISPATCH_BY_OPNUM_TT( + exec, + PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, + dimensionLength, tadShapeInfo, tadOffset, start, stop), + REDUCE_FLOAT_OPS); +} + +template +template +void _CUDA_H ReduceFloatFunction::exec( + const void *vx, const Nd4jLong *xShapeInfo, void *vextraParams, + void *vresult, const Nd4jLong *zShapeInfo, int *dimension, + int dimensionLength, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffset, int64_t start, int64_t stop) { + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vresult); + auto extraParams = reinterpret_cast(vextraParams); + + auto resultLength = shape::length(zShapeInfo); + + if (sd::ArrayOptions::arrayType(xShapeInfo) == sd::ArrayType::EMPTY) { + if (sd::ArrayOptions::arrayType(zShapeInfo) == sd::ArrayType::EMPTY) return; + const auto startingVal = std::is_same>::value + ? sd::DataTypeUtils::nanOrZero() + : static_cast(OpType::startingValue(x)); + + for (Nd4jLong i = 0; i < resultLength; i++) z[i] = startingVal; + return; + } + + // pre squeezed: this is for keeping the pointer to the original + // shape information for tad offset + // the squeezed information doesn't render the right strides for + // tad offset + // || tad.wholeThing + if (resultLength == 1 || dimension == nullptr || + dimensionLength == shape::rank(xShapeInfo)) { + z[0] = execScalar(x, xShapeInfo, extraParams); + return; + } + + if (OpType::requiresSpecialAccumulation) { + OpType::execSpecial(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, + dimensionLength, tadShapeInfo, tadOffset); + return; + } + + auto tadOnlyShapeInfo = tadShapeInfo; + auto tadOffsets = tadOffset; + + if (tadOnlyShapeInfo == nullptr || tadOffsets == nullptr) { + if (dimensionLength < 0) return; + + auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions( + xShapeInfo, dimension, dimensionLength); + tadOnlyShapeInfo = tadPack.primaryShapeInfo(); + tadOffsets = tadPack.primaryOffsets(); + } #ifdef INLINE_LOOPS - sd::ReductionLoops::template loopReduce(x, xShapeInfo, z, zShapeInfo, tadOnlyShapeInfo, tadOffsets, extraParams, start, stop); + sd::ReductionLoops::template loopReduce( + x, xShapeInfo, z, zShapeInfo, tadOnlyShapeInfo, tadOffsets, extraParams, + start, stop); #else - sd::ReductionFloatLoops::template innerloopReduce(x, xShapeInfo, z, zShapeInfo, tadOnlyShapeInfo, tadOffsets, extraParams, start, stop); + sd::ReductionFloatLoops::template innerloopReduce( + x, xShapeInfo, z, zShapeInfo, tadOnlyShapeInfo, tadOffsets, extraParams, + start, stop); #endif - } - - - template - template - void _CUDA_H ReduceFloatFunction::exec(const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *vresult, const Nd4jLong *resultShapeInfo) { - // FIXME: wtf??? - auto z = reinterpret_cast(vresult); - z[0] = execScalar(x, xShapeInfo, extraParams); - } - - template - template - Z _CUDA_H ReduceFloatFunction::execScalar(const void *vx, Nd4jLong xEws, Nd4jLong length, void *vextraParams) { - - auto x = reinterpret_cast(vx); - auto extraParams = reinterpret_cast(vextraParams); - int maxThreads = sd::math::nd4j_min(64, sd::Environment::getInstance().maxThreads()); - Z intermediate[64]; - - PRAGMA_OMP_SIMD - for (auto e = 0; e < maxThreads; e++) - intermediate[e] = OpType::startingValue(x); - - auto func = PRAGMA_THREADS_FOR { - if (xEws == 1) { - for (auto i = start; i < stop; i++) - intermediate[thread_id] = OpType::update(intermediate[thread_id], OpType::op(x[i], extraParams), extraParams); - } else { - for (auto i = start; i < stop; i++) - intermediate[thread_id] = OpType::update(intermediate[thread_id], OpType::op(x[i * xEws], extraParams), extraParams); - } - }; - - maxThreads = samediff::Threads::parallel_for(func, 0, length, 1, maxThreads); - - // merge results - for (int e = 1; e < maxThreads; e++) - intermediate[0] = OpType::update(intermediate[0], intermediate[e], extraParams); - - // return result - return OpType::postProcess(intermediate[0], length, extraParams); - } +} + +template +template +void _CUDA_H ReduceFloatFunction::exec(const void *x, + const Nd4jLong *xShapeInfo, + void *extraParams, void *vresult, + const Nd4jLong *resultShapeInfo) { + // FIXME: wtf??? + auto z = reinterpret_cast(vresult); + z[0] = execScalar(x, xShapeInfo, extraParams); +} + +template +template +Z _CUDA_H ReduceFloatFunction::execScalar(const void *vx, Nd4jLong xEws, + Nd4jLong length, + void *vextraParams) { + auto x = reinterpret_cast(vx); + auto extraParams = reinterpret_cast(vextraParams); + int maxThreads = + sd::math::nd4j_min(64, sd::Environment::getInstance().maxThreads()); + Z intermediate[64]; + + PRAGMA_OMP_SIMD + for (auto e = 0; e < maxThreads; e++) + intermediate[e] = OpType::startingValue(x); + + auto func = PRAGMA_THREADS_FOR { + if (xEws == 1) { + for (auto i = start; i < stop; i++) + intermediate[thread_id] = + OpType::update(intermediate[thread_id], + OpType::op(x[i], extraParams), extraParams); + } else { + for (auto i = start; i < stop; i++) + intermediate[thread_id] = + OpType::update(intermediate[thread_id], + OpType::op(x[i * xEws], extraParams), extraParams); } -} \ No newline at end of file + }; + + maxThreads = samediff::Threads::parallel_for(func, 0, length, 1, maxThreads); + + // merge results + for (int e = 1; e < maxThreads; e++) + intermediate[0] = + OpType::update(intermediate[0], intermediate[e], extraParams); + + // return result + return OpType::postProcess(intermediate[0], length, extraParams); +} +} // namespace reduce +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/reduce/reduce_long.cpp b/libnd4j/include/loops/cpu/reduce/reduce_long.cpp index a4fae322817d..eb1064ffd92a 100644 --- a/libnd4j/include/loops/cpu/reduce/reduce_long.cpp +++ b/libnd4j/include/loops/cpu/reduce/reduce_long.cpp @@ -19,230 +19,256 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include -#include -#include -#include -#include -#include #include +#include +#include +#include +#include +#include +#include using namespace simdOps; namespace functions { - namespace reduce { - template - template - void _CUDA_H ReduceLongFunction::execScalar(const void *vx, const Nd4jLong *xShapeInfo, - void *vextraParams, - void *vz, const Nd4jLong *zShapeInfo) { - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - auto extraParams = reinterpret_cast(vextraParams); - - const Nd4jLong length = shape::length(xShapeInfo); - auto xEws = shape::elementWiseStride(xShapeInfo); - - if (shape::isEmpty(xShapeInfo)) { - z[0] = OpType::startingValue(x); - return; - } - - if(sd::ArrayOptions::arrayType(xShapeInfo) == sd::ArrayType::EMPTY) { - if(sd::ArrayOptions::arrayType(zShapeInfo) == sd::ArrayType::EMPTY) - return; - const auto startingVal = OpType::startingValue(x); - - for (Nd4jLong i = 0; i < length; i++) - z[i] = startingVal; - return; - } - - if (xEws >= 1) { - z[0] = execScalar(x, xEws, length, extraParams); - } - else { - auto startingValue = OpType::startingValue(x); - uint xShapeInfoCast[MAX_RANK]; - const bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); - int maxThreads = sd::math::nd4j_min(64, sd::Environment::getInstance().maxThreads()); - Z intermediate[64]; - - PRAGMA_OMP_SIMD - for (auto e = 0; e < maxThreads; e++) - intermediate[e] = OpType::startingValue(x); - - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) - intermediate[thread_id] = OpType::update(intermediate[thread_id], OpType::op(x[shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX)], extraParams), extraParams); - }; - - maxThreads = samediff::Threads::parallel_for(func, 0, length, 1, maxThreads); - - // merge results - for (int e = 1; e < maxThreads; e++) - intermediate[0] = OpType::update(intermediate[0], intermediate[e], extraParams); - - // write out results - z[0] = OpType::postProcess(intermediate[0], length, extraParams); - } - } - - - template - template - Z _CUDA_H ReduceLongFunction::execScalar(const void *vx, const Nd4jLong *xShapeInfo, void *vextraParams) { - auto x = reinterpret_cast(vx); - auto extraParams = reinterpret_cast(vextraParams); - - const Nd4jLong length = shape::length(xShapeInfo); - auto xEws = shape::elementWiseStride(xShapeInfo); - - if (xEws >= 1) { - return execScalar(x, xEws, length, extraParams); - } - else { - auto startingValue = OpType::startingValue(x); - uint xShapeInfoCast[MAX_RANK]; - bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); - - for (Nd4jLong i = 0; i < length; i++) - startingValue = OpType::update(startingValue, OpType::op(x[shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX)], extraParams), extraParams); - - return OpType::postProcess(startingValue, length, extraParams); - } - } - - - template - Y ReduceLongFunction::execScalar(const int opNum, - const void *x, const Nd4jLong *xShapeInfo, - void *extraParams) { - RETURNING_DISPATCH_BY_OPNUM_TT(execScalar, PARAMS(x, xShapeInfo, extraParams), REDUCE_LONG_OPS); - } - - template - void ReduceLongFunction::execScalar(const int opNum, - const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *z, const Nd4jLong *zShapeInfo) { - DISPATCH_BY_OPNUM_TT(execScalar, PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo), REDUCE_LONG_OPS); - } - - template - void ReduceLongFunction::exec(const int opNum, - const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *z, const Nd4jLong *zShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset, - int64_t start, int64_t stop) { - DISPATCH_BY_OPNUM_TT(exec, PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffset, start, stop), REDUCE_LONG_OPS); - } - - template - template - void _CUDA_H ReduceLongFunction::exec(const void *vx, const Nd4jLong *xShapeInfo, - void *vextraParams, - void *vresult, const Nd4jLong *zShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset, - int64_t start, int64_t stop) { - - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vresult); - auto extraParams = reinterpret_cast(vextraParams); - - auto resultLength = shape::length(zShapeInfo); - - if(sd::ArrayOptions::arrayType(xShapeInfo) == sd::ArrayType::EMPTY) { - if(sd::ArrayOptions::arrayType(zShapeInfo) == sd::ArrayType::EMPTY) - return; - const auto startingVal = OpType::startingValue(x); - - for (Nd4jLong i = 0; i < resultLength; i++) - z[i] = startingVal; - return; - } - - //pre squeezed: this is for keeping the pointer to the original - //shape information for tad offset - //the squeezed information doesn't render the right strides for - //tad offset - // || tad.wholeThing - if (resultLength == 1 || dimension == nullptr || dimensionLength == shape::rank(xShapeInfo)) { - z[0] = execScalar(x, xShapeInfo, extraParams); - return; - } - - if (OpType::requiresSpecialAccumulation) { - OpType::execSpecial(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffset); - return; - } - - auto tadOnlyShapeInfo = tadShapeInfo; - auto tadOffsets = tadOffset; - - if (tadOnlyShapeInfo == nullptr || tadOffsets == nullptr) { - if (dimensionLength < 1) - return; - - auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(xShapeInfo, dimension, dimensionLength); - tadOnlyShapeInfo = tadPack.primaryShapeInfo(); - tadOffsets = tadPack.primaryOffsets(); - } +namespace reduce { +template +template +void _CUDA_H ReduceLongFunction::execScalar(const void *vx, + const Nd4jLong *xShapeInfo, + void *vextraParams, void *vz, + const Nd4jLong *zShapeInfo) { + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + auto extraParams = reinterpret_cast(vextraParams); + + const Nd4jLong length = shape::length(xShapeInfo); + auto xEws = shape::elementWiseStride(xShapeInfo); + + if (shape::isEmpty(xShapeInfo)) { + z[0] = OpType::startingValue(x); + return; + } + + if (sd::ArrayOptions::arrayType(xShapeInfo) == sd::ArrayType::EMPTY) { + if (sd::ArrayOptions::arrayType(zShapeInfo) == sd::ArrayType::EMPTY) return; + const auto startingVal = OpType::startingValue(x); + + for (Nd4jLong i = 0; i < length; i++) z[i] = startingVal; + return; + } + + if (xEws >= 1) { + z[0] = execScalar(x, xEws, length, extraParams); + } else { + auto startingValue = OpType::startingValue(x); + uint xShapeInfoCast[MAX_RANK]; + const bool canCastX = + sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + int maxThreads = sd::math::nd4j_min( + 64, sd::Environment::getInstance().maxThreads()); + Z intermediate[64]; + + PRAGMA_OMP_SIMD + for (auto e = 0; e < maxThreads; e++) + intermediate[e] = OpType::startingValue(x); + + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) + intermediate[thread_id] = OpType::update( + intermediate[thread_id], + OpType::op( + x[shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX)], + extraParams), + extraParams); + }; + + maxThreads = + samediff::Threads::parallel_for(func, 0, length, 1, maxThreads); + + // merge results + for (int e = 1; e < maxThreads; e++) + intermediate[0] = + OpType::update(intermediate[0], intermediate[e], extraParams); + + // write out results + z[0] = OpType::postProcess(intermediate[0], length, extraParams); + } +} + +template +template +Z _CUDA_H ReduceLongFunction::execScalar(const void *vx, + const Nd4jLong *xShapeInfo, + void *vextraParams) { + auto x = reinterpret_cast(vx); + auto extraParams = reinterpret_cast(vextraParams); + + const Nd4jLong length = shape::length(xShapeInfo); + auto xEws = shape::elementWiseStride(xShapeInfo); + + if (xEws >= 1) { + return execScalar(x, xEws, length, extraParams); + } else { + auto startingValue = OpType::startingValue(x); + uint xShapeInfoCast[MAX_RANK]; + bool canCastX = + sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + + for (Nd4jLong i = 0; i < length; i++) + startingValue = OpType::update( + startingValue, + OpType::op( + x[shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX)], + extraParams), + extraParams); + + return OpType::postProcess(startingValue, length, extraParams); + } +} + +template +Y ReduceLongFunction::execScalar(const int opNum, const void *x, + const Nd4jLong *xShapeInfo, + void *extraParams) { + RETURNING_DISPATCH_BY_OPNUM_TT(execScalar, PARAMS(x, xShapeInfo, extraParams), + REDUCE_LONG_OPS); +} + +template +void ReduceLongFunction::execScalar(const int opNum, const void *x, + const Nd4jLong *xShapeInfo, + void *extraParams, void *z, + const Nd4jLong *zShapeInfo) { + DISPATCH_BY_OPNUM_TT(execScalar, + PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo), + REDUCE_LONG_OPS); +} + +template +void ReduceLongFunction::exec( + const int opNum, const void *x, const Nd4jLong *xShapeInfo, + void *extraParams, void *z, const Nd4jLong *zShapeInfo, int *dimension, + int dimensionLength, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffset, int64_t start, int64_t stop) { + DISPATCH_BY_OPNUM_TT( + exec, + PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, + dimensionLength, tadShapeInfo, tadOffset, start, stop), + REDUCE_LONG_OPS); +} + +template +template +void _CUDA_H ReduceLongFunction::exec( + const void *vx, const Nd4jLong *xShapeInfo, void *vextraParams, + void *vresult, const Nd4jLong *zShapeInfo, int *dimension, + int dimensionLength, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffset, int64_t start, int64_t stop) { + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vresult); + auto extraParams = reinterpret_cast(vextraParams); + + auto resultLength = shape::length(zShapeInfo); + + if (sd::ArrayOptions::arrayType(xShapeInfo) == sd::ArrayType::EMPTY) { + if (sd::ArrayOptions::arrayType(zShapeInfo) == sd::ArrayType::EMPTY) return; + const auto startingVal = OpType::startingValue(x); + + for (Nd4jLong i = 0; i < resultLength; i++) z[i] = startingVal; + return; + } + + // pre squeezed: this is for keeping the pointer to the original + // shape information for tad offset + // the squeezed information doesn't render the right strides for + // tad offset + // || tad.wholeThing + if (resultLength == 1 || dimension == nullptr || + dimensionLength == shape::rank(xShapeInfo)) { + z[0] = execScalar(x, xShapeInfo, extraParams); + return; + } + + if (OpType::requiresSpecialAccumulation) { + OpType::execSpecial(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, + dimensionLength, tadShapeInfo, tadOffset); + return; + } + + auto tadOnlyShapeInfo = tadShapeInfo; + auto tadOffsets = tadOffset; + + if (tadOnlyShapeInfo == nullptr || tadOffsets == nullptr) { + if (dimensionLength < 1) return; + + auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions( + xShapeInfo, dimension, dimensionLength); + tadOnlyShapeInfo = tadPack.primaryShapeInfo(); + tadOffsets = tadPack.primaryOffsets(); + } #ifdef INLINE_LOOPS - sd::ReductionLoops::template loopReduce(x, xShapeInfo, z, zShapeInfo, tadOnlyShapeInfo, tadOffsets, extraParams, start, stop); + sd::ReductionLoops::template loopReduce( + x, xShapeInfo, z, zShapeInfo, tadOnlyShapeInfo, tadOffsets, extraParams, + start, stop); #else - sd::ReductionLongLoops::template innerloopReduce(x, xShapeInfo, z, zShapeInfo, tadOnlyShapeInfo, tadOffsets, extraParams, start, stop); + sd::ReductionLongLoops::template innerloopReduce( + x, xShapeInfo, z, zShapeInfo, tadOnlyShapeInfo, tadOffsets, extraParams, + start, stop); #endif - } - - - template - template - void _CUDA_H ReduceLongFunction::exec(const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *vresult, const Nd4jLong *resultShapeInfo) { - auto z = reinterpret_cast(vresult); - z[0] = execScalar(x, xShapeInfo, extraParams); - } - - template - template - Z _CUDA_H ReduceLongFunction::execScalar(const void *vx, Nd4jLong xEws, Nd4jLong length, void *vextraParams) { - - auto x = reinterpret_cast(vx); - auto extraParams = reinterpret_cast(vextraParams); - int maxThreads = sd::math::nd4j_min(64, sd::Environment::getInstance().maxThreads()); - Z intermediate[64]; - - PRAGMA_OMP_SIMD - for (auto e = 0; e < maxThreads; e++) - intermediate[e] = OpType::startingValue(x); - - auto func = PRAGMA_THREADS_FOR { - if (xEws == 1) { - for (auto i = start; i < stop; i++) - intermediate[thread_id] = OpType::update(intermediate[thread_id], OpType::op(x[i], extraParams), extraParams); - } else { - for (auto i = start; i < stop; i++) - intermediate[thread_id] = OpType::update(intermediate[thread_id], OpType::op(x[i * xEws], extraParams), extraParams); - } - }; - - maxThreads = samediff::Threads::parallel_for(func, 0, length, 1, maxThreads); +} + +template +template +void _CUDA_H ReduceLongFunction::exec(const void *x, + const Nd4jLong *xShapeInfo, + void *extraParams, void *vresult, + const Nd4jLong *resultShapeInfo) { + auto z = reinterpret_cast(vresult); + z[0] = execScalar(x, xShapeInfo, extraParams); +} + +template +template +Z _CUDA_H ReduceLongFunction::execScalar(const void *vx, Nd4jLong xEws, + Nd4jLong length, + void *vextraParams) { + auto x = reinterpret_cast(vx); + auto extraParams = reinterpret_cast(vextraParams); + int maxThreads = + sd::math::nd4j_min(64, sd::Environment::getInstance().maxThreads()); + Z intermediate[64]; + + PRAGMA_OMP_SIMD + for (auto e = 0; e < maxThreads; e++) + intermediate[e] = OpType::startingValue(x); + + auto func = PRAGMA_THREADS_FOR { + if (xEws == 1) { + for (auto i = start; i < stop; i++) + intermediate[thread_id] = + OpType::update(intermediate[thread_id], + OpType::op(x[i], extraParams), extraParams); + } else { + for (auto i = start; i < stop; i++) + intermediate[thread_id] = + OpType::update(intermediate[thread_id], + OpType::op(x[i * xEws], extraParams), extraParams); + } + }; - // merge results - for (int e = 1; e < maxThreads; e++) - intermediate[0] = OpType::update(intermediate[0], intermediate[e], extraParams); + maxThreads = samediff::Threads::parallel_for(func, 0, length, 1, maxThreads); - // return result - return OpType::postProcess(intermediate[0], length, extraParams); - } + // merge results + for (int e = 1; e < maxThreads; e++) + intermediate[0] = + OpType::update(intermediate[0], intermediate[e], extraParams); + // return result + return OpType::postProcess(intermediate[0], length, extraParams); +} - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT ReduceLongFunction, , LIBND4J_TYPES, LONG_TYPES); - } -} \ No newline at end of file +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT ReduceLongFunction, , + LIBND4J_TYPES, LONG_TYPES); +} // namespace reduce +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/reduce/reduce_same.cpp b/libnd4j/include/loops/cpu/reduce/reduce_same.cpp index 10607fb6d3f8..0ded1b2e9a10 100644 --- a/libnd4j/include/loops/cpu/reduce/reduce_same.cpp +++ b/libnd4j/include/loops/cpu/reduce/reduce_same.cpp @@ -19,239 +19,261 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include -#include -#include -#include +#include +#include #include +#include +#include +#include +#include + #include -#include -#include using namespace simdOps; namespace functions { - namespace reduce { - template - template - void _CUDA_H ReduceSameFunction::execScalar(const void *vx, const Nd4jLong *xShapeInfo, - void *vextraParams, - void *vz, const Nd4jLong *zShapeInfo) { - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - auto extraParams = reinterpret_cast(vextraParams); - - const auto length = shape::length(xShapeInfo); - const auto xEws = shape::elementWiseStride(xShapeInfo); - const int rank = shape::rank(xShapeInfo); - - if (shape::isEmpty(xShapeInfo)) { - z[0] = OpType::startingValue(x); - return; - } - - if(sd::ArrayOptions::arrayType(xShapeInfo) == sd::ArrayType::EMPTY) { - if(sd::ArrayOptions::arrayType(zShapeInfo) == sd::ArrayType::EMPTY) - return; - const auto startingVal = OpType::startingValue(x); - - for (Nd4jLong i = 0; i < length; i++) - z[i] = startingVal; - return; - } - - if (xEws >= 1) { - z[0] = execScalar(x, xEws, length, extraParams); - } - else { - auto startingValue = OpType::startingValue(x); - uint xShapeInfoCast[MAX_RANK]; - const bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); - int maxThreads = sd::math::nd4j_min(64, sd::Environment::getInstance().maxThreads()); - X intermediate[64]; - - PRAGMA_OMP_SIMD - for (auto e = 0; e < maxThreads; e++) - intermediate[e] = OpType::startingValue(x); - - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) - intermediate[thread_id] = OpType::update(intermediate[thread_id], OpType::op(x[shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX)], extraParams), extraParams); - }; - - maxThreads = samediff::Threads::parallel_for(func, 0, length, 1, maxThreads); - - // merge results - for (int e = 1; e < maxThreads; e++) - intermediate[0] = OpType::update(intermediate[0], intermediate[e], extraParams); - - // write out results - z[0] = OpType::postProcess(intermediate[0], length, extraParams); - } - } - - - template - template - X _CUDA_H ReduceSameFunction::execScalar(const void *vx, const Nd4jLong *xShapeInfo, void *vextraParams) { - auto x = reinterpret_cast(vx); - auto extraParams = reinterpret_cast(vextraParams); - - const Nd4jLong length = shape::length(xShapeInfo); - const auto xEws = shape::elementWiseStride(xShapeInfo); - - if (xEws >= 1) { - return execScalar(x, xEws, length, extraParams); - } else { - auto startingValue = OpType::startingValue(x); - uint xShapeInfoCast[MAX_RANK]; - bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); - - for (Nd4jLong i = 0; i < length; i++) - startingValue = OpType::update(startingValue, OpType::op(x[shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX)], extraParams), extraParams); - - return OpType::postProcess(startingValue, length, extraParams); - } - } - - template - X ReduceSameFunction::execScalar(const int opNum, - const void *x, const Nd4jLong *xShapeInfo, - void *extraParams) { - RETURNING_DISPATCH_BY_OPNUM_T(execScalar, PARAMS(x, xShapeInfo, extraParams), REDUCE_SAME_OPS); - } - - template - void ReduceSameFunction::execScalar(const int opNum, - const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *z, const Nd4jLong *zShapeInfo) { - DISPATCH_BY_OPNUM_T(execScalar, PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo), REDUCE_SAME_OPS); - } - - template - void ReduceSameFunction::exec(const int opNum, - const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *z, const Nd4jLong *zShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset, - int64_t start, int64_t stop) { - DISPATCH_BY_OPNUM_T(exec, PARAMS(x, - xShapeInfo, - extraParams, - z, - zShapeInfo, - dimension, - dimensionLength, - tadShapeInfo, - tadOffset, start, stop), - REDUCE_SAME_OPS); - } - - template - template - void _CUDA_H ReduceSameFunction::exec(const void *vx, const Nd4jLong *xShapeInfo, - void *vextraParams, - void *vz, const Nd4jLong *zShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset, - int64_t start, int64_t stop) { - - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - auto extraParams = reinterpret_cast(vextraParams); - - auto zLength = shape::length(zShapeInfo); - - if(sd::ArrayOptions::arrayType(xShapeInfo) == sd::ArrayType::EMPTY) { - if(sd::ArrayOptions::arrayType(zShapeInfo) == sd::ArrayType::EMPTY) - return; - const auto startingVal = OpType::startingValue(x); - - for (Nd4jLong i = 0; i < zLength; i++) - z[i] = startingVal; - return; - } - - //pre squeezed: this is for keeping the pointer to the original - //shape information for tad offset - //the squeezed information doesn't render the right strides for - //tad offset - // || tad.wholeThing - if (zLength == 1 || dimension == nullptr || dimensionLength == shape::rank(xShapeInfo)) { - z[0] = execScalar(x, xShapeInfo, extraParams); - return; - } - - if (OpType::requiresSpecialAccumulation) { - OpType::execSpecial(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffset); - return; - } - - auto tadOnlyShapeInfo = tadShapeInfo; - auto tadOffsets = tadOffset; - - if (tadOnlyShapeInfo == nullptr || tadOffsets == nullptr) { - if (dimensionLength < 1) - return; - - auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(xShapeInfo, dimension, dimensionLength); - tadOnlyShapeInfo = tadPack.primaryShapeInfo(); - tadOffsets = tadPack.primaryOffsets(); - } +namespace reduce { +template +template +void _CUDA_H ReduceSameFunction::execScalar(const void *vx, + const Nd4jLong *xShapeInfo, + void *vextraParams, void *vz, + const Nd4jLong *zShapeInfo) { + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + auto extraParams = reinterpret_cast(vextraParams); + + const auto length = shape::length(xShapeInfo); + const auto xEws = shape::elementWiseStride(xShapeInfo); + const int rank = shape::rank(xShapeInfo); + + if (shape::isEmpty(xShapeInfo)) { + z[0] = OpType::startingValue(x); + return; + } + + if (sd::ArrayOptions::arrayType(xShapeInfo) == sd::ArrayType::EMPTY) { + if (sd::ArrayOptions::arrayType(zShapeInfo) == sd::ArrayType::EMPTY) return; + const auto startingVal = OpType::startingValue(x); + + for (Nd4jLong i = 0; i < length; i++) z[i] = startingVal; + return; + } + + if (xEws >= 1) { + z[0] = execScalar(x, xEws, length, extraParams); + } else { + auto startingValue = OpType::startingValue(x); + uint xShapeInfoCast[MAX_RANK]; + const bool canCastX = + sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + int maxThreads = sd::math::nd4j_min( + 64, sd::Environment::getInstance().maxThreads()); + X intermediate[64]; + + PRAGMA_OMP_SIMD + for (auto e = 0; e < maxThreads; e++) + intermediate[e] = OpType::startingValue(x); + + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) + intermediate[thread_id] = OpType::update( + intermediate[thread_id], + OpType::op( + x[shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX)], + extraParams), + extraParams); + }; + + maxThreads = + samediff::Threads::parallel_for(func, 0, length, 1, maxThreads); + + // merge results + for (int e = 1; e < maxThreads; e++) + intermediate[0] = + OpType::update(intermediate[0], intermediate[e], extraParams); + + // write out results + z[0] = OpType::postProcess(intermediate[0], length, extraParams); + } +} + +template +template +X _CUDA_H ReduceSameFunction::execScalar(const void *vx, + const Nd4jLong *xShapeInfo, + void *vextraParams) { + auto x = reinterpret_cast(vx); + auto extraParams = reinterpret_cast(vextraParams); + + const Nd4jLong length = shape::length(xShapeInfo); + const auto xEws = shape::elementWiseStride(xShapeInfo); + + if (xEws >= 1) { + return execScalar(x, xEws, length, extraParams); + } else { + auto startingValue = OpType::startingValue(x); + uint xShapeInfoCast[MAX_RANK]; + bool canCastX = + sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + + for (Nd4jLong i = 0; i < length; i++) + startingValue = OpType::update( + startingValue, + OpType::op( + x[shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX)], + extraParams), + extraParams); + + return OpType::postProcess(startingValue, length, extraParams); + } +} + +template +X ReduceSameFunction::execScalar(const int opNum, const void *x, + const Nd4jLong *xShapeInfo, + void *extraParams) { + RETURNING_DISPATCH_BY_OPNUM_T(execScalar, PARAMS(x, xShapeInfo, extraParams), + REDUCE_SAME_OPS); +} + +template +void ReduceSameFunction::execScalar(const int opNum, const void *x, + const Nd4jLong *xShapeInfo, + void *extraParams, void *z, + const Nd4jLong *zShapeInfo) { + DISPATCH_BY_OPNUM_T(execScalar, + PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo), + REDUCE_SAME_OPS); +} + +template +void ReduceSameFunction::exec(const int opNum, const void *x, + const Nd4jLong *xShapeInfo, void *extraParams, + void *z, const Nd4jLong *zShapeInfo, + int *dimension, int dimensionLength, + const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffset, int64_t start, + int64_t stop) { + DISPATCH_BY_OPNUM_T( + exec, + PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, + dimensionLength, tadShapeInfo, tadOffset, start, stop), + REDUCE_SAME_OPS); +} + +template +template +void _CUDA_H ReduceSameFunction::exec( + const void *vx, const Nd4jLong *xShapeInfo, void *vextraParams, void *vz, + const Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, + const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset, int64_t start, + int64_t stop) { + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + auto extraParams = reinterpret_cast(vextraParams); + + auto zLength = shape::length(zShapeInfo); + + if (sd::ArrayOptions::arrayType(xShapeInfo) == sd::ArrayType::EMPTY) { + if (sd::ArrayOptions::arrayType(zShapeInfo) == sd::ArrayType::EMPTY) return; + const auto startingVal = OpType::startingValue(x); + + for (Nd4jLong i = 0; i < zLength; i++) z[i] = startingVal; + return; + } + + // pre squeezed: this is for keeping the pointer to the original + // shape information for tad offset + // the squeezed information doesn't render the right strides for + // tad offset + // || tad.wholeThing + if (zLength == 1 || dimension == nullptr || + dimensionLength == shape::rank(xShapeInfo)) { + z[0] = execScalar(x, xShapeInfo, extraParams); + return; + } + + if (OpType::requiresSpecialAccumulation) { + OpType::execSpecial(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, + dimensionLength, tadShapeInfo, tadOffset); + return; + } + + auto tadOnlyShapeInfo = tadShapeInfo; + auto tadOffsets = tadOffset; + + if (tadOnlyShapeInfo == nullptr || tadOffsets == nullptr) { + if (dimensionLength < 1) return; + + auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions( + xShapeInfo, dimension, dimensionLength); + tadOnlyShapeInfo = tadPack.primaryShapeInfo(); + tadOffsets = tadPack.primaryOffsets(); + } #ifdef INLINE_LOOPS - sd::ReductionLoops::template loopReduce(x, xShapeInfo, z, zShapeInfo, tadOnlyShapeInfo, tadOffsets, extraParams, start, stop); + sd::ReductionLoops::template loopReduce( + x, xShapeInfo, z, zShapeInfo, tadOnlyShapeInfo, tadOffsets, extraParams, + start, stop); #else - sd::ReductionSameLoops::template innerloopReduce(x, xShapeInfo, z, zShapeInfo, tadOnlyShapeInfo, tadOffsets, extraParams, start, stop); + sd::ReductionSameLoops::template innerloopReduce( + x, xShapeInfo, z, zShapeInfo, tadOnlyShapeInfo, tadOffsets, extraParams, + start, stop); #endif - } - - - template - template - void _CUDA_H ReduceSameFunction::exec(const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *vz, const Nd4jLong *zShapeInfo) { - auto z = reinterpret_cast(vz); - z[0] = execScalar(x, xShapeInfo, extraParams); - } - - template - template - X _CUDA_H ReduceSameFunction::execScalar(const void *vx, Nd4jLong xEws, Nd4jLong length, void *vextraParams) { - - auto x = reinterpret_cast(vx); - auto extraParams = reinterpret_cast(vextraParams); - int maxThreads = sd::math::nd4j_min(64, sd::Environment::getInstance().maxThreads()); - X intermediate[64]; - - PRAGMA_OMP_SIMD - for (auto e = 0; e < maxThreads; e++) - intermediate[e] = OpType::startingValue(x); - - auto func = PRAGMA_THREADS_FOR { - if (xEws == 1) { - for (auto i = start; i < stop; i++) - intermediate[thread_id] = OpType::update(intermediate[thread_id], OpType::op(x[i], extraParams), extraParams); - } else { - for (auto i = start; i < stop; i++) - intermediate[thread_id] = OpType::update(intermediate[thread_id], OpType::op(x[i * xEws], extraParams), extraParams); - } - }; - - maxThreads = samediff::Threads::parallel_for(func, 0, length, 1, maxThreads); +} + +template +template +void _CUDA_H ReduceSameFunction::exec(const void *x, + const Nd4jLong *xShapeInfo, + void *extraParams, void *vz, + const Nd4jLong *zShapeInfo) { + auto z = reinterpret_cast(vz); + z[0] = execScalar(x, xShapeInfo, extraParams); +} + +template +template +X _CUDA_H ReduceSameFunction::execScalar(const void *vx, Nd4jLong xEws, + Nd4jLong length, + void *vextraParams) { + auto x = reinterpret_cast(vx); + auto extraParams = reinterpret_cast(vextraParams); + int maxThreads = + sd::math::nd4j_min(64, sd::Environment::getInstance().maxThreads()); + X intermediate[64]; + + PRAGMA_OMP_SIMD + for (auto e = 0; e < maxThreads; e++) + intermediate[e] = OpType::startingValue(x); + + auto func = PRAGMA_THREADS_FOR { + if (xEws == 1) { + for (auto i = start; i < stop; i++) + intermediate[thread_id] = + OpType::update(intermediate[thread_id], + OpType::op(x[i], extraParams), extraParams); + } else { + for (auto i = start; i < stop; i++) + intermediate[thread_id] = + OpType::update(intermediate[thread_id], + OpType::op(x[i * xEws], extraParams), extraParams); + } + }; - // merge results - for (int e = 1; e < maxThreads; e++) - intermediate[0] = OpType::update(intermediate[0], intermediate[e], extraParams); + maxThreads = samediff::Threads::parallel_for(func, 0, length, 1, maxThreads); - // return result - return OpType::postProcess(intermediate[0], length, extraParams); - } + // merge results + for (int e = 1; e < maxThreads; e++) + intermediate[0] = + OpType::update(intermediate[0], intermediate[e], extraParams); + // return result + return OpType::postProcess(intermediate[0], length, extraParams); +} - BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT ReduceSameFunction, , LIBND4J_TYPES); - } -} \ No newline at end of file +BUILD_SINGLE_TEMPLATE(template class SD_EXPORT ReduceSameFunction, , + LIBND4J_TYPES); +} // namespace reduce +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/reduce3.hpp b/libnd4j/include/loops/cpu/reduce3.hpp index a19c7c1a15e4..f808ac1ef2b1 100644 --- a/libnd4j/include/loops/cpu/reduce3.hpp +++ b/libnd4j/include/loops/cpu/reduce3.hpp @@ -17,248 +17,276 @@ // @author raver119@gmail.com // @author Yurii Shyrma (iuriish@yahoo.com), created on 19.11.2018 - -#include -#include -#include -#include +#include #include #include -#include +#include +#include +#include +#include using namespace simdOps; namespace functions { -namespace reduce3 { +namespace reduce3 { ////////////////////////////////////////////////////////////////////////// template -template -void Reduce3::execScalar(const void *vx, const Nd4jLong *xShapeInfo, - void *vextraParams, - const void *vy, const Nd4jLong *yShapeInfo, - void *vz, const Nd4jLong *zShapeInfo) { - - auto x = reinterpret_cast(vx); - auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); - auto extraParams = reinterpret_cast(vextraParams); - - auto length = shape::length(xShapeInfo); - auto xEws = shape::elementWiseStride(xShapeInfo); - auto yEws = shape::elementWiseStride(yShapeInfo); - - if(sd::ArrayOptions::arrayType(xShapeInfo) == sd::ArrayType::EMPTY || sd::ArrayOptions::arrayType(yShapeInfo) == sd::ArrayType::EMPTY) { - if(sd::ArrayOptions::arrayType(zShapeInfo) == sd::ArrayType::EMPTY) - return; - const auto startingVal = OpType::startingValue(x); - - for (Nd4jLong i = 0; i < length; i++) - z[i] = startingVal; - - return; - } - - Z extraParamsVals[3] = {(Z) 0.0f, (Z) 0.0f, (Z) 0.0f}; - - uint xShapeInfoCast[MAX_RANK]; - const bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); - - Z startingVal = OpType::startingValue(x); - int maxThreads = sd::math::nd4j_min(64, sd::Environment::getInstance().maxThreads()); - Z intermediate[64]; - Z extraParamsLocal[3 * 64]; - +template +void Reduce3::execScalar(const void *vx, const Nd4jLong *xShapeInfo, + void *vextraParams, const void *vy, + const Nd4jLong *yShapeInfo, void *vz, + const Nd4jLong *zShapeInfo) { + auto x = reinterpret_cast(vx); + auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + auto extraParams = reinterpret_cast(vextraParams); + + auto length = shape::length(xShapeInfo); + auto xEws = shape::elementWiseStride(xShapeInfo); + auto yEws = shape::elementWiseStride(yShapeInfo); + + if (sd::ArrayOptions::arrayType(xShapeInfo) == sd::ArrayType::EMPTY || + sd::ArrayOptions::arrayType(yShapeInfo) == sd::ArrayType::EMPTY) { + if (sd::ArrayOptions::arrayType(zShapeInfo) == sd::ArrayType::EMPTY) return; + const auto startingVal = OpType::startingValue(x); + + for (Nd4jLong i = 0; i < length; i++) z[i] = startingVal; + + return; + } + + Z extraParamsVals[3] = {(Z)0.0f, (Z)0.0f, (Z)0.0f}; + + uint xShapeInfoCast[MAX_RANK]; + const bool canCastX = + sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + + Z startingVal = OpType::startingValue(x); + int maxThreads = + sd::math::nd4j_min(64, sd::Environment::getInstance().maxThreads()); + Z intermediate[64]; + Z extraParamsLocal[3 * 64]; + + PRAGMA_OMP_SIMD + for (int e = 0; e < maxThreads; e++) intermediate[e] = startingVal; + + memset(extraParamsLocal, 0, 3 * 64 * sizeof(Z)); + if (extraParams != nullptr) { PRAGMA_OMP_SIMD - for (int e = 0; e < maxThreads; e++) - intermediate[e] = startingVal; - - memset(extraParamsLocal, 0, 3 * 64 * sizeof(Z)); - if (extraParams != nullptr) { - PRAGMA_OMP_SIMD - // mostly for future reference - for (int e = 0; e < maxThreads; e++) { - extraParamsLocal[3 * e] = extraParams[0]; - extraParamsLocal[3 * e + 1] = extraParams[1]; - extraParamsLocal[3 * e + 2] = extraParams[2]; - } + // mostly for future reference + for (int e = 0; e < maxThreads; e++) { + extraParamsLocal[3 * e] = extraParams[0]; + extraParamsLocal[3 * e + 1] = extraParams[1]; + extraParamsLocal[3 * e + 2] = extraParams[2]; } - - sd::LoopKind::Kind kindOfLoop = sd::LoopKind::deduceKindOfLoopXZ(xShapeInfo, yShapeInfo); - - if (kindOfLoop == sd::LoopKind::EWS1) { - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - intermediate[thread_id] = OpType::update(intermediate[thread_id], OpType::op(x[i], y[i], extraParamsLocal + 3 * thread_id), extraParamsLocal + 3 * thread_id); - } - }; - - maxThreads = samediff::Threads::parallel_for(func, 0, length, 1, maxThreads); - - } else if(shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo)) { - - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); - intermediate[thread_id] = OpType::update(intermediate[thread_id], OpType::op(x[offset], y[offset], extraParamsLocal + 3 * thread_id), extraParamsLocal + 3 * thread_id); - } - }; - - maxThreads = samediff::Threads::parallel_for(func, 0, length, 1, maxThreads); - } else { - uint yShapeInfoCast[MAX_RANK]; - const bool canCastY = sd::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); - - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); - auto yOffset = shape::indexOffset(i, yShapeInfo, yShapeInfoCast, canCastY); - intermediate[thread_id] = OpType::update(intermediate[thread_id], OpType::op(x[xOffset], y[yOffset], extraParamsLocal + 3 * thread_id), extraParamsLocal + 3 * thread_id); - } - }; - - maxThreads = samediff::Threads::parallel_for(func, 0, length, 1, maxThreads); - } - - // merge step - for (int e = 0; e < maxThreads; e++) - OpType::aggregateExtraParams(extraParamsVals, extraParamsLocal + 3 * e); - - for (int e = 0; e < maxThreads; e++) - startingVal = OpType::update(startingVal, intermediate[e], extraParamsVals); - - // writing out result - z[0] = OpType::postProcess(startingVal, length, extraParamsVals); + } + + sd::LoopKind::Kind kindOfLoop = + sd::LoopKind::deduceKindOfLoopXZ(xShapeInfo, yShapeInfo); + + if (kindOfLoop == sd::LoopKind::EWS1) { + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + intermediate[thread_id] = OpType::update( + intermediate[thread_id], + OpType::op(x[i], y[i], extraParamsLocal + 3 * thread_id), + extraParamsLocal + 3 * thread_id); + } + }; + + maxThreads = + samediff::Threads::parallel_for(func, 0, length, 1, maxThreads); + + } else if (shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo)) { + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + auto offset = + shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); + intermediate[thread_id] = OpType::update( + intermediate[thread_id], + OpType::op(x[offset], y[offset], extraParamsLocal + 3 * thread_id), + extraParamsLocal + 3 * thread_id); + } + }; + + maxThreads = + samediff::Threads::parallel_for(func, 0, length, 1, maxThreads); + } else { + uint yShapeInfoCast[MAX_RANK]; + const bool canCastY = + sd::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); + + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + auto xOffset = + shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); + auto yOffset = + shape::indexOffset(i, yShapeInfo, yShapeInfoCast, canCastY); + intermediate[thread_id] = + OpType::update(intermediate[thread_id], + OpType::op(x[xOffset], y[yOffset], + extraParamsLocal + 3 * thread_id), + extraParamsLocal + 3 * thread_id); + } + }; + + maxThreads = + samediff::Threads::parallel_for(func, 0, length, 1, maxThreads); + } + + // merge step + for (int e = 0; e < maxThreads; e++) + OpType::aggregateExtraParams(extraParamsVals, extraParamsLocal + 3 * e); + + for (int e = 0; e < maxThreads; e++) + startingVal = OpType::update(startingVal, intermediate[e], extraParamsVals); + + // writing out result + z[0] = OpType::postProcess(startingVal, length, extraParamsVals); } ////////////////////////////////////////////////////////////////////////// template -void Reduce3::execScalar(const int opNum, - const void *vx, const Nd4jLong *xShapeInfo, - void *extraParamsVals, - const void *vy, const Nd4jLong *yShapeInfo, - void *vz, const Nd4jLong *zShapeInfo) { - - DISPATCH_BY_OPNUM_TT(execScalar, PARAMS(vx, xShapeInfo, extraParamsVals, vy, yShapeInfo, vz, zShapeInfo), REDUCE3_OPS); +void Reduce3::execScalar(const int opNum, const void *vx, + const Nd4jLong *xShapeInfo, + void *extraParamsVals, const void *vy, + const Nd4jLong *yShapeInfo, void *vz, + const Nd4jLong *zShapeInfo) { + DISPATCH_BY_OPNUM_TT( + execScalar, + PARAMS(vx, xShapeInfo, extraParamsVals, vy, yShapeInfo, vz, zShapeInfo), + REDUCE3_OPS); } - ////////////////////////////////////////////////////////////////////////// template -template -void Reduce3::exec(const void *vx, const Nd4jLong *xShapeInfo, - void *vextraParams, - const void *vy, const Nd4jLong *yShapeInfo, - void *vz, const Nd4jLong *zShapeInfo, - int *dimension, int dimensionLength, - int64_t start, int64_t stop) { - - auto x = reinterpret_cast(vx); - auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); - auto extraParams = reinterpret_cast(vextraParams); - - if(shape::isScalar(zShapeInfo)) { - execScalar(vx, xShapeInfo, vextraParams, vy, yShapeInfo, vz, zShapeInfo); - return; - } +template +void Reduce3::exec(const void *vx, const Nd4jLong *xShapeInfo, + void *vextraParams, const void *vy, + const Nd4jLong *yShapeInfo, void *vz, + const Nd4jLong *zShapeInfo, int *dimension, + int dimensionLength, int64_t start, int64_t stop) { + auto x = reinterpret_cast(vx); + auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + auto extraParams = reinterpret_cast(vextraParams); + + if (shape::isScalar(zShapeInfo)) { + execScalar(vx, xShapeInfo, vextraParams, vy, yShapeInfo, vz, + zShapeInfo); + return; + } #ifdef INLINE_LOOPS - sd::Reduction3Loops::template loopReduce3(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, extraParams, start, stop); + sd::Reduction3Loops::template loopReduce3( + x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, + extraParams, start, stop); #else - sd::Reduction3Loops::template innerloopReduce3(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, extraParams, start, stop); + sd::Reduction3Loops::template innerloopReduce3( + x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, + extraParams, start, stop); #endif } ////////////////////////////////////////////////////////////////////////// template -template -void Reduce3::exec(const void *vx, const Nd4jLong *xShapeInfo, - void *vextraParams, - const void *vy, const Nd4jLong *yShapeInfo, - void *vz, const Nd4jLong *zShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets, - int64_t start, int64_t stop) { - - auto x = reinterpret_cast(vx); - auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); - auto extraParams = reinterpret_cast(vextraParams); +template +void Reduce3::exec(const void *vx, const Nd4jLong *xShapeInfo, + void *vextraParams, const void *vy, + const Nd4jLong *yShapeInfo, void *vz, + const Nd4jLong *zShapeInfo, int *dimension, + int dimensionLength, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets, int64_t start, + int64_t stop) { + auto x = reinterpret_cast(vx); + auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + auto extraParams = reinterpret_cast(vextraParams); #ifdef INLINE_LOOPS - sd::Reduction3Loops::template loopReduce3(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, extraParams, start, stop); + sd::Reduction3Loops::template loopReduce3( + x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, + extraParams, start, stop); #else - sd::Reduction3Loops::template innerloopReduce3(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, extraParams, start, stop); + sd::Reduction3Loops::template innerloopReduce3( + x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, + extraParams, start, stop); #endif } - ////////////////////////////////////////////////////////////////////////// template -template -void Reduce3:: execAll(const void *vx, const Nd4jLong *xShapeInfo, - void *vextraParams, - const void *vy, const Nd4jLong *yShapeInfo, - void *vz, const Nd4jLong *zShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *xTadShapeInfo, const Nd4jLong *xOffsets, - const Nd4jLong *yTadShapeInfo, const Nd4jLong *yOffsets, - int64_t start, int64_t stop) { - - auto x = reinterpret_cast(vx); - auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); - auto extraParams = reinterpret_cast(vextraParams); +template +void Reduce3::execAll(const void *vx, const Nd4jLong *xShapeInfo, + void *vextraParams, const void *vy, + const Nd4jLong *yShapeInfo, void *vz, + const Nd4jLong *zShapeInfo, int *dimension, + int dimensionLength, const Nd4jLong *xTadShapeInfo, + const Nd4jLong *xOffsets, + const Nd4jLong *yTadShapeInfo, + const Nd4jLong *yOffsets, int64_t start, + int64_t stop) { + auto x = reinterpret_cast(vx); + auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + auto extraParams = reinterpret_cast(vextraParams); #ifdef INLINE_LOOPS - sd::Reduction3Loops::template loopReduce3All(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, xTadShapeInfo, xOffsets, yTadShapeInfo, yOffsets, extraParams, start, stop); + sd::Reduction3Loops::template loopReduce3All( + x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, xTadShapeInfo, xOffsets, + yTadShapeInfo, yOffsets, extraParams, start, stop); #else - sd::Reduction3Loops::template innerloopReduce3All(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, xTadShapeInfo, xOffsets, yTadShapeInfo, yOffsets, extraParams, start, stop); + sd::Reduction3Loops::template innerloopReduce3All( + x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, xTadShapeInfo, xOffsets, + yTadShapeInfo, yOffsets, extraParams, start, stop); #endif } ////////////////////////////////////////////////////////////////////////// template -void Reduce3::exec(const int opNum, - const void *vx, const Nd4jLong *xShapeInfo, - void *extraParamsVals, - const void *vy, const Nd4jLong *yShapeInfo, - void *vz, const Nd4jLong *zShapeInfo, - int *dimension, int dimensionLength, - int64_t start, int64_t stop) { - - DISPATCH_BY_OPNUM_TT(exec, PARAMS(vx, xShapeInfo, extraParamsVals, vy, yShapeInfo, vz, zShapeInfo, dimension, dimensionLength, start, stop), REDUCE3_OPS); +void Reduce3::exec(const int opNum, const void *vx, + const Nd4jLong *xShapeInfo, void *extraParamsVals, + const void *vy, const Nd4jLong *yShapeInfo, void *vz, + const Nd4jLong *zShapeInfo, int *dimension, + int dimensionLength, int64_t start, int64_t stop) { + DISPATCH_BY_OPNUM_TT( + exec, + PARAMS(vx, xShapeInfo, extraParamsVals, vy, yShapeInfo, vz, zShapeInfo, + dimension, dimensionLength, start, stop), + REDUCE3_OPS); } - ////////////////////////////////////////////////////////////////////////// template -void Reduce3::exec(const int opNum, - const void *vx, const Nd4jLong *xShapeInfo, - void *extraParamsVals, - const void *vy, const Nd4jLong *yShapeInfo, - void *vz, const Nd4jLong *zShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets, - int64_t start, int64_t stop) { - - DISPATCH_BY_OPNUM_TT(exec, PARAMS(vx,xShapeInfo,extraParamsVals,vy, yShapeInfo,vz,zShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets, start, stop), REDUCE3_OPS); +void Reduce3::exec(const int opNum, const void *vx, + const Nd4jLong *xShapeInfo, void *extraParamsVals, + const void *vy, const Nd4jLong *yShapeInfo, void *vz, + const Nd4jLong *zShapeInfo, int *dimension, + int dimensionLength, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets, int64_t start, + int64_t stop) { + DISPATCH_BY_OPNUM_TT( + exec, + PARAMS(vx, xShapeInfo, extraParamsVals, vy, yShapeInfo, vz, zShapeInfo, + dimension, dimensionLength, tadShapeInfo, tadOffsets, start, stop), + REDUCE3_OPS); } - ////////////////////////////////////////////////////////////////////////// template -void Reduce3::execAll(const int opNum, - const void *vx, const Nd4jLong *xShapeInfo, - void *extraParamsVals, - const void *vy, const Nd4jLong *yShapeInfo, - void *vz, const Nd4jLong *zShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *xTadShapeInfo, const Nd4jLong *xOffsets, - const Nd4jLong *yTadShapeInfo, const Nd4jLong *yOffsets, - int64_t start, int64_t stop) { - - DISPATCH_BY_OPNUM_TT(execAll, PARAMS(vx, xShapeInfo, extraParamsVals, vy, yShapeInfo, vz, zShapeInfo, dimension, dimensionLength, xTadShapeInfo, xOffsets, yTadShapeInfo, yOffsets, start, stop), REDUCE3_OPS); +void Reduce3::execAll( + const int opNum, const void *vx, const Nd4jLong *xShapeInfo, + void *extraParamsVals, const void *vy, const Nd4jLong *yShapeInfo, void *vz, + const Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, + const Nd4jLong *xTadShapeInfo, const Nd4jLong *xOffsets, + const Nd4jLong *yTadShapeInfo, const Nd4jLong *yOffsets, int64_t start, + int64_t stop) { + DISPATCH_BY_OPNUM_TT( + execAll, + PARAMS(vx, xShapeInfo, extraParamsVals, vy, yShapeInfo, vz, zShapeInfo, + dimension, dimensionLength, xTadShapeInfo, xOffsets, yTadShapeInfo, + yOffsets, start, stop), + REDUCE3_OPS); } -} -} \ No newline at end of file +} // namespace reduce3 +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/scalar.hpp b/libnd4j/include/loops/cpu/scalar.hpp index f539f387f182..6becfde6dd5a 100644 --- a/libnd4j/include/loops/cpu/scalar.hpp +++ b/libnd4j/include/loops/cpu/scalar.hpp @@ -18,193 +18,198 @@ // Created by raver119 on 08.10.2017. // -#include "../scalar.h" +#include +#include #include #include -#include -#include + #include "../legacy_ops.h" +#include "../scalar.h" using namespace simdOps; namespace functions { -namespace scalar { - +namespace scalar { //////////////////////////////////////////////////////////////////////// -template -template -void ScalarTransform::transform(const void *vx, const Nd4jLong *xShapeInfo, - void *vextraParams, - void *vz, const Nd4jLong *zShapeInfo, - const void *vscalars, - int *dimension, int dimensionLength, - const Nd4jLong *xTadShapeInfo, const Nd4jLong *xTadOffsets, - const Nd4jLong *zTadShapeInfo, const Nd4jLong *zTadOffsets, - const uint64_t start, const uint64_t stop) { - - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - auto scalars = reinterpret_cast(vscalars); - auto extraParams = reinterpret_cast(vextraParams); - - if (zTadShapeInfo == nullptr) { - zTadShapeInfo = xTadShapeInfo; - zTadOffsets = xTadOffsets; - } - - const int xTadEws = shape::elementWiseStride(xTadShapeInfo); - const int zTadEws = shape::elementWiseStride(zTadShapeInfo); - const int tadLength = shape::tadLength(xShapeInfo, dimension, dimensionLength); - const int numTads = shape::length(xShapeInfo) / tadLength; - - sd::LoopKind::Kind kindOfLoop = sd::LoopKind::deduceKindOfLoopXZ(xTadShapeInfo, zTadShapeInfo); - - if (kindOfLoop != sd::LoopKind::EWS1 && kindOfLoop != sd::LoopKind::EWSNONZERO) { - printf("ScalarTransform::transform: super-bad loop visited. Shouldn't ever happen\n"); - return; - } - - int num_threads = sd::math::nd4j_min(numTads, sd::Environment::getInstance().maxThreads()); - - if (kindOfLoop == sd::LoopKind::EWS1) { - for (auto r = start; r < stop; r++) { - auto oZ = z + zTadOffsets[r]; - auto oX = x + xTadOffsets[r]; - - PRAGMA_OMP_SIMD - for (int f = 0; f < tadLength; f++) - oZ[f] = OpType::op(oX[f], scalars[r], extraParams); - }; - } - else { - for (auto r = start; r < stop; r++) { - auto oZ = z + zTadOffsets[r]; - auto oX = x + xTadOffsets[r]; - - PRAGMA_OMP_SIMD - for (int f = 0; f < tadLength; f++) - oZ[f * zTadEws] = OpType::op(oX[f * xTadEws], scalars[r], extraParams); - }; - } +template +template +void ScalarTransform::transform( + const void *vx, const Nd4jLong *xShapeInfo, void *vextraParams, void *vz, + const Nd4jLong *zShapeInfo, const void *vscalars, int *dimension, + int dimensionLength, const Nd4jLong *xTadShapeInfo, + const Nd4jLong *xTadOffsets, const Nd4jLong *zTadShapeInfo, + const Nd4jLong *zTadOffsets, const uint64_t start, const uint64_t stop) { + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + auto scalars = reinterpret_cast(vscalars); + auto extraParams = reinterpret_cast(vextraParams); + + if (zTadShapeInfo == nullptr) { + zTadShapeInfo = xTadShapeInfo; + zTadOffsets = xTadOffsets; + } + + const int xTadEws = shape::elementWiseStride(xTadShapeInfo); + const int zTadEws = shape::elementWiseStride(zTadShapeInfo); + const int tadLength = + shape::tadLength(xShapeInfo, dimension, dimensionLength); + const int numTads = shape::length(xShapeInfo) / tadLength; + + sd::LoopKind::Kind kindOfLoop = + sd::LoopKind::deduceKindOfLoopXZ(xTadShapeInfo, zTadShapeInfo); + + if (kindOfLoop != sd::LoopKind::EWS1 && + kindOfLoop != sd::LoopKind::EWSNONZERO) { + printf( + "ScalarTransform::transform: super-bad loop visited. Shouldn't " + "ever happen\n"); + return; + } + + int num_threads = sd::math::nd4j_min( + numTads, sd::Environment::getInstance().maxThreads()); + + if (kindOfLoop == sd::LoopKind::EWS1) { + for (auto r = start; r < stop; r++) { + auto oZ = z + zTadOffsets[r]; + auto oX = x + xTadOffsets[r]; + + PRAGMA_OMP_SIMD + for (int f = 0; f < tadLength; f++) + oZ[f] = OpType::op(oX[f], scalars[r], extraParams); + }; + } else { + for (auto r = start; r < stop; r++) { + auto oZ = z + zTadOffsets[r]; + auto oX = x + xTadOffsets[r]; + + PRAGMA_OMP_SIMD + for (int f = 0; f < tadLength; f++) + oZ[f * zTadEws] = OpType::op(oX[f * xTadEws], scalars[r], extraParams); + }; + } } //////////////////////////////////////////////////////////////////////// -template -void ScalarTransform::transform(int opNum, - const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *z, const Nd4jLong *zShapeInfo, - const void *scalars, - int *dimension, int dimensionLength, - const Nd4jLong *xTadShapeInfo, const Nd4jLong *xTadOffsets, - const Nd4jLong *zTadShapeInfo, const Nd4jLong *zTadOffsets, - const uint64_t start, const uint64_t stop) { - - DISPATCH_BY_OPNUM_TTT(transform, PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo, scalars, dimension, dimensionLength, xTadShapeInfo, xTadOffsets, zTadShapeInfo, zTadOffsets, start, stop), SCALAR_OPS); +template +void ScalarTransform::transform( + int opNum, const void *x, const Nd4jLong *xShapeInfo, void *extraParams, + void *z, const Nd4jLong *zShapeInfo, const void *scalars, int *dimension, + int dimensionLength, const Nd4jLong *xTadShapeInfo, + const Nd4jLong *xTadOffsets, const Nd4jLong *zTadShapeInfo, + const Nd4jLong *zTadOffsets, const uint64_t start, const uint64_t stop) { + DISPATCH_BY_OPNUM_TTT( + transform, + PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo, scalars, dimension, + dimensionLength, xTadShapeInfo, xTadOffsets, zTadShapeInfo, + zTadOffsets, start, stop), + SCALAR_OPS); } //////////////////////////////////////////////////////////////////////// -template -void ScalarTransform::transform(const int opNum, - const void *x, Nd4jLong xStride, - void *z, Nd4jLong zStride, - const void *scalar, - void *extraParams, - const uint64_t n, - const uint64_t start, const uint64_t stop) { - - DISPATCH_BY_OPNUM_TTT(transform, PARAMS(x, xStride, z, zStride, scalar, extraParams, n, start, stop), SCALAR_OPS); +template +void ScalarTransform::transform(const int opNum, const void *x, + Nd4jLong xStride, void *z, + Nd4jLong zStride, const void *scalar, + void *extraParams, const uint64_t n, + const uint64_t start, + const uint64_t stop) { + DISPATCH_BY_OPNUM_TTT( + transform, + PARAMS(x, xStride, z, zStride, scalar, extraParams, n, start, stop), + SCALAR_OPS); } //////////////////////////////////////////////////////////////////////// -template -void ScalarTransform::transform(const int opNum, - const void *x, const Nd4jLong *xShapeInfo, - void *z, const Nd4jLong *zShapeInfo, - const void *scalar, - void *extraParams, - const uint64_t start, const uint64_t stop) { - - DISPATCH_BY_OPNUM_TTT(transform, PARAMS(x, xShapeInfo, z, zShapeInfo, scalar, extraParams, start, stop), SCALAR_OPS); +template +void ScalarTransform::transform(const int opNum, const void *x, + const Nd4jLong *xShapeInfo, void *z, + const Nd4jLong *zShapeInfo, + const void *scalar, void *extraParams, + const uint64_t start, + const uint64_t stop) { + DISPATCH_BY_OPNUM_TTT( + transform, + PARAMS(x, xShapeInfo, z, zShapeInfo, scalar, extraParams, start, stop), + SCALAR_OPS); } //////////////////////////////////////////////////////////////////////// -template -template -void ScalarTransform::transform(const void *vx, const Nd4jLong *xShapeInfo, - void *vz, const Nd4jLong *zShapeInfo, - const void *vscalar, - void *vextraParams, - const uint64_t start, const uint64_t stop) { - - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - auto scalar = reinterpret_cast(vscalar)[0]; - auto extraParams = reinterpret_cast(vextraParams); - - const auto len = shape::length(xShapeInfo); - const auto xEws = shape::elementWiseStride(xShapeInfo); - const auto zEws = shape::elementWiseStride(zShapeInfo); - - sd::LoopKind::Kind kindOfLoop = sd::LoopKind::deduceKindOfLoopXZ(xShapeInfo, zShapeInfo); - - if (kindOfLoop == sd::LoopKind::EWS1 || kindOfLoop == sd::LoopKind::EWSNONZERO) { - transform(x, xEws, z, zEws, vscalar, extraParams, len, start, stop); - } - else { - - uint xShapeInfoCast[MAX_RANK]; - const bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); - - if(shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) { - PRAGMA_OMP_SIMD - for (auto i = start; i < stop; i++) { - auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); - z[offset] = OpType::op(x[offset], scalar, extraParams); - }; - } - else { - uint zShapeInfoCast[MAX_RANK]; - const bool canCastZ = sd::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast); - - PRAGMA_OMP_SIMD - for (auto i = start; i < stop; i++) { - auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); - auto zOffset = shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ); - z[zOffset] = OpType::op(x[xOffset], scalar, extraParams); - }; - } +template +template +void ScalarTransform::transform( + const void *vx, const Nd4jLong *xShapeInfo, void *vz, + const Nd4jLong *zShapeInfo, const void *vscalar, void *vextraParams, + const uint64_t start, const uint64_t stop) { + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + auto scalar = reinterpret_cast(vscalar)[0]; + auto extraParams = reinterpret_cast(vextraParams); + + const auto len = shape::length(xShapeInfo); + const auto xEws = shape::elementWiseStride(xShapeInfo); + const auto zEws = shape::elementWiseStride(zShapeInfo); + + sd::LoopKind::Kind kindOfLoop = + sd::LoopKind::deduceKindOfLoopXZ(xShapeInfo, zShapeInfo); + + if (kindOfLoop == sd::LoopKind::EWS1 || + kindOfLoop == sd::LoopKind::EWSNONZERO) { + transform(x, xEws, z, zEws, vscalar, extraParams, len, start, stop); + } else { + uint xShapeInfoCast[MAX_RANK]; + const bool canCastX = + sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + + if (shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) { + PRAGMA_OMP_SIMD + for (auto i = start; i < stop; i++) { + auto offset = + shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); + z[offset] = OpType::op(x[offset], scalar, extraParams); + }; + } else { + uint zShapeInfoCast[MAX_RANK]; + const bool canCastZ = + sd::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast); + + PRAGMA_OMP_SIMD + for (auto i = start; i < stop; i++) { + auto xOffset = + shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); + auto zOffset = + shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ); + z[zOffset] = OpType::op(x[xOffset], scalar, extraParams); + }; } + } } //////////////////////////////////////////////////////////////////////// -template -template +template +template void ScalarTransform::transform(const void *vx, Nd4jLong xEws, void *vz, Nd4jLong zEws, const void *vscalar, - void *vextraParams, - const uint64_t len, const uint64_t start, const uint64_t stop) { - - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - auto scalar = reinterpret_cast(vscalar)[0]; - auto extraParams = reinterpret_cast(vextraParams); - - if (xEws == 1 && zEws == 1) { - PRAGMA_OMP_SIMD - for (auto i = start; i < stop; i++) - z[i] = OpType::op(x[i], scalar, extraParams); - } - else { - PRAGMA_OMP_SIMD - for (auto i = start; i < stop; i++) - z[i * zEws] = OpType::op(x[i * xEws], scalar, extraParams); - } + void *vextraParams, const uint64_t len, + const uint64_t start, + const uint64_t stop) { + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + auto scalar = reinterpret_cast(vscalar)[0]; + auto extraParams = reinterpret_cast(vextraParams); + + if (xEws == 1 && zEws == 1) { + PRAGMA_OMP_SIMD + for (auto i = start; i < stop; i++) + z[i] = OpType::op(x[i], scalar, extraParams); + } else { + PRAGMA_OMP_SIMD + for (auto i = start; i < stop; i++) + z[i * zEws] = OpType::op(x[i * xEws], scalar, extraParams); + } } - - -} -} +} // namespace scalar +} // namespace functions diff --git a/libnd4j/include/loops/cpu/scalar_bool.cpp b/libnd4j/include/loops/cpu/scalar_bool.cpp index 63182bdc3b54..397eb879cea2 100644 --- a/libnd4j/include/loops/cpu/scalar_bool.cpp +++ b/libnd4j/include/loops/cpu/scalar_bool.cpp @@ -19,187 +19,193 @@ // #include "../scalar_bool.h" + +#include +#include #include #include -#include -#include #include "../legacy_ops.h" using namespace simdOps; namespace functions { - namespace scalar { - - - template - template - void ScalarBoolTransform::transform(const void *vx, const Nd4jLong *xShapeInfo, - void *vextraParams, - void *vz, const Nd4jLong *zShapeInfo, - const void *vscalars, - int *dimension, int dimensionLength, - const Nd4jLong *xTadShapeInfo, const Nd4jLong *xTadOffsets, - const Nd4jLong *zTadShapeInfo, const Nd4jLong *zTadOffsets, - const uint64_t start, const uint64_t stop) { - - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - auto scalars = reinterpret_cast(vscalars); - auto extraParams = reinterpret_cast(vextraParams); - - if (zTadShapeInfo == nullptr) { - zTadShapeInfo = xTadShapeInfo; - zTadOffsets = xTadOffsets; - } - - // tad preparation - const int xTadEws = shape::elementWiseStride(xTadShapeInfo); - const int zTadEws = shape::elementWiseStride(zTadShapeInfo); - const int tadLength = shape::tadLength(xShapeInfo, dimension, dimensionLength); - const int numTads = shape::length(xShapeInfo) / tadLength; - - sd::LoopKind::Kind kindOfLoop = sd::LoopKind::deduceKindOfLoopXZ(xTadShapeInfo, zTadShapeInfo); - - if (kindOfLoop != sd::LoopKind::EWS1 && kindOfLoop != sd::LoopKind::EWSNONZERO) { - printf("ScalarBoolTransform::transform: super-bad loop visited. Shouldn't ever happen\n"); - return; - } - - int num_threads = sd::math::nd4j_min(numTads, sd::Environment::getInstance().maxThreads()); - - if (kindOfLoop == sd::LoopKind::EWS1) { - for (auto r = start; r < stop; r++) { - auto oZ = z + zTadOffsets[r]; - auto oX = x + xTadOffsets[r]; - - PRAGMA_OMP_SIMD - for (int f = 0; f < tadLength; f++) - oZ[f] = OpType::op(oX[f], scalars[r], extraParams); - }; - } - else { - for (auto r = start; r < stop; r++) { - auto oZ = z + zTadOffsets[r]; - auto oX = x + xTadOffsets[r]; - - PRAGMA_OMP_SIMD - for (int f = 0; f < tadLength; f++) - oZ[f * zTadEws] = OpType::op(oX[f * xTadEws], scalars[r], extraParams); - }; - } - } - - template - void ScalarBoolTransform::transform(int opNum, - const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *z, const Nd4jLong *zShapeInfo, - const void *scalars, - int *dimension, int dimensionLength, - const Nd4jLong *xTadShapeInfo, const Nd4jLong *xTadOffsets, - const Nd4jLong *zTadShapeInfo, const Nd4jLong *zTadOffsets, - const uint64_t start, const uint64_t stop) { - DISPATCH_BY_OPNUM_TT(transform, PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo, scalars, dimension, dimensionLength, xTadShapeInfo, xTadOffsets, zTadShapeInfo, zTadOffsets, start, stop), SCALAR_BOOL_OPS); - } - - - template - void ScalarBoolTransform::transform(const int opNum, - const void *x, Nd4jLong xEws, - void *z, Nd4jLong zEws, - const void *scalar, - void *extraParams, - const uint64_t n, - const uint64_t start, const uint64_t stop) { - DISPATCH_BY_OPNUM_TT(transform, PARAMS(x, xEws, z, zEws, scalar, extraParams, n, start, stop), SCALAR_BOOL_OPS); - } - - template - void ScalarBoolTransform::transform(const int opNum, - const void *x, const Nd4jLong *xShapeInfo, - void *z, const Nd4jLong *zShapeInfo, - const void *scalar, - void *extraParams, - const uint64_t start, const uint64_t stop) { - DISPATCH_BY_OPNUM_TT(transform, PARAMS(x, xShapeInfo, z, zShapeInfo, scalar, extraParams, start, stop), SCALAR_BOOL_OPS); - } - - template - template - void ScalarBoolTransform::transform(const void *vx, const Nd4jLong *xShapeInfo, - void *vz, const Nd4jLong *zShapeInfo, - const void *vscalar, - void *vextraParams, - const uint64_t start, const uint64_t stop) { - - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - auto scalar = reinterpret_cast(vscalar)[0]; - auto extraParams = reinterpret_cast(vextraParams); - - auto xEws = shape::elementWiseStride(xShapeInfo); - auto zEws = shape::elementWiseStride(zShapeInfo); - auto len = shape::length(xShapeInfo); - - sd::LoopKind::Kind kindOfLoop = sd::LoopKind::deduceKindOfLoopXZ(xShapeInfo, zShapeInfo); - - if (kindOfLoop == sd::LoopKind::EWS1 || kindOfLoop == sd::LoopKind::EWSNONZERO) { - transform(x, xEws, z, zEws, vscalar, extraParams, len, start, stop); - return; - } - - uint xShapeInfoCast[MAX_RANK]; - const bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); - - if(shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) { - PRAGMA_OMP_SIMD - for (auto i = start; i < stop; i++) { - auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); - z[offset] = OpType::op(x[offset], scalar, extraParams); - }; - } - else { - uint zShapeInfoCast[MAX_RANK]; - const bool canCastZ = sd::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast); - - PRAGMA_OMP_SIMD - for (auto i = start; i < stop; i++) { - auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); - auto zOffset = shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ); - z[zOffset] = OpType::op(x[xOffset], scalar, extraParams); - }; - } - } - - - template - template - void ScalarBoolTransform::transform(const void *vx, Nd4jLong xEws, - void *vz, Nd4jLong zEws, - const void *vscalar, - void *vextraParams, - const uint64_t len, - const uint64_t start, const uint64_t stop) { - - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - auto scalar = reinterpret_cast(vscalar)[0]; - auto extraParams = reinterpret_cast(vextraParams); - - if (xEws == 1 && zEws == 1) { - PRAGMA_OMP_SIMD - for (auto i = start; i < stop; i++) - z[i] = OpType::op(x[i], scalar, extraParams); - } - else { - PRAGMA_OMP_SIMD - for (auto i = start; i < stop; i++) - z[i * zEws] = OpType::op(x[i * xEws], scalar, extraParams); - } - } - - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT ScalarBoolTransform, , LIBND4J_TYPES, BOOL_TYPES); +namespace scalar { + +template +template +void ScalarBoolTransform::transform( + const void *vx, const Nd4jLong *xShapeInfo, void *vextraParams, void *vz, + const Nd4jLong *zShapeInfo, const void *vscalars, int *dimension, + int dimensionLength, const Nd4jLong *xTadShapeInfo, + const Nd4jLong *xTadOffsets, const Nd4jLong *zTadShapeInfo, + const Nd4jLong *zTadOffsets, const uint64_t start, const uint64_t stop) { + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + auto scalars = reinterpret_cast(vscalars); + auto extraParams = reinterpret_cast(vextraParams); + + if (zTadShapeInfo == nullptr) { + zTadShapeInfo = xTadShapeInfo; + zTadOffsets = xTadOffsets; + } + + // tad preparation + const int xTadEws = shape::elementWiseStride(xTadShapeInfo); + const int zTadEws = shape::elementWiseStride(zTadShapeInfo); + const int tadLength = + shape::tadLength(xShapeInfo, dimension, dimensionLength); + const int numTads = shape::length(xShapeInfo) / tadLength; + + sd::LoopKind::Kind kindOfLoop = + sd::LoopKind::deduceKindOfLoopXZ(xTadShapeInfo, zTadShapeInfo); + + if (kindOfLoop != sd::LoopKind::EWS1 && + kindOfLoop != sd::LoopKind::EWSNONZERO) { + printf( + "ScalarBoolTransform::transform: super-bad loop visited. " + "Shouldn't ever happen\n"); + return; + } + + int num_threads = sd::math::nd4j_min( + numTads, sd::Environment::getInstance().maxThreads()); + + if (kindOfLoop == sd::LoopKind::EWS1) { + for (auto r = start; r < stop; r++) { + auto oZ = z + zTadOffsets[r]; + auto oX = x + xTadOffsets[r]; + + PRAGMA_OMP_SIMD + for (int f = 0; f < tadLength; f++) + oZ[f] = OpType::op(oX[f], scalars[r], extraParams); + }; + } else { + for (auto r = start; r < stop; r++) { + auto oZ = z + zTadOffsets[r]; + auto oX = x + xTadOffsets[r]; + + PRAGMA_OMP_SIMD + for (int f = 0; f < tadLength; f++) + oZ[f * zTadEws] = OpType::op(oX[f * xTadEws], scalars[r], extraParams); + }; + } +} +template +void ScalarBoolTransform::transform( + int opNum, const void *x, const Nd4jLong *xShapeInfo, void *extraParams, + void *z, const Nd4jLong *zShapeInfo, const void *scalars, int *dimension, + int dimensionLength, const Nd4jLong *xTadShapeInfo, + const Nd4jLong *xTadOffsets, const Nd4jLong *zTadShapeInfo, + const Nd4jLong *zTadOffsets, const uint64_t start, const uint64_t stop) { + DISPATCH_BY_OPNUM_TT( + transform, + PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo, scalars, dimension, + dimensionLength, xTadShapeInfo, xTadOffsets, zTadShapeInfo, + zTadOffsets, start, stop), + SCALAR_BOOL_OPS); } + +template +void ScalarBoolTransform::transform(const int opNum, const void *x, + Nd4jLong xEws, void *z, Nd4jLong zEws, + const void *scalar, void *extraParams, + const uint64_t n, + const uint64_t start, + const uint64_t stop) { + DISPATCH_BY_OPNUM_TT( + transform, PARAMS(x, xEws, z, zEws, scalar, extraParams, n, start, stop), + SCALAR_BOOL_OPS); } + +template +void ScalarBoolTransform::transform(const int opNum, const void *x, + const Nd4jLong *xShapeInfo, void *z, + const Nd4jLong *zShapeInfo, + const void *scalar, void *extraParams, + const uint64_t start, + const uint64_t stop) { + DISPATCH_BY_OPNUM_TT( + transform, + PARAMS(x, xShapeInfo, z, zShapeInfo, scalar, extraParams, start, stop), + SCALAR_BOOL_OPS); +} + +template +template +void ScalarBoolTransform::transform( + const void *vx, const Nd4jLong *xShapeInfo, void *vz, + const Nd4jLong *zShapeInfo, const void *vscalar, void *vextraParams, + const uint64_t start, const uint64_t stop) { + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + auto scalar = reinterpret_cast(vscalar)[0]; + auto extraParams = reinterpret_cast(vextraParams); + + auto xEws = shape::elementWiseStride(xShapeInfo); + auto zEws = shape::elementWiseStride(zShapeInfo); + auto len = shape::length(xShapeInfo); + + sd::LoopKind::Kind kindOfLoop = + sd::LoopKind::deduceKindOfLoopXZ(xShapeInfo, zShapeInfo); + + if (kindOfLoop == sd::LoopKind::EWS1 || + kindOfLoop == sd::LoopKind::EWSNONZERO) { + transform(x, xEws, z, zEws, vscalar, extraParams, len, start, stop); + return; + } + + uint xShapeInfoCast[MAX_RANK]; + const bool canCastX = + sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + + if (shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) { + PRAGMA_OMP_SIMD + for (auto i = start; i < stop; i++) { + auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); + z[offset] = OpType::op(x[offset], scalar, extraParams); + }; + } else { + uint zShapeInfoCast[MAX_RANK]; + const bool canCastZ = + sd::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast); + + PRAGMA_OMP_SIMD + for (auto i = start; i < stop; i++) { + auto xOffset = + shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); + auto zOffset = + shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ); + z[zOffset] = OpType::op(x[xOffset], scalar, extraParams); + }; + } +} + +template +template +void ScalarBoolTransform::transform( + const void *vx, Nd4jLong xEws, void *vz, Nd4jLong zEws, const void *vscalar, + void *vextraParams, const uint64_t len, const uint64_t start, + const uint64_t stop) { + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + auto scalar = reinterpret_cast(vscalar)[0]; + auto extraParams = reinterpret_cast(vextraParams); + + if (xEws == 1 && zEws == 1) { + PRAGMA_OMP_SIMD + for (auto i = start; i < stop; i++) + z[i] = OpType::op(x[i], scalar, extraParams); + } else { + PRAGMA_OMP_SIMD + for (auto i = start; i < stop; i++) + z[i * zEws] = OpType::op(x[i * xEws], scalar, extraParams); + } +} + +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT ScalarBoolTransform, , + LIBND4J_TYPES, BOOL_TYPES); + +} // namespace scalar +} // namespace functions diff --git a/libnd4j/include/loops/cpu/scalar_int.cpp b/libnd4j/include/loops/cpu/scalar_int.cpp index adf53e7f6790..141e69026643 100644 --- a/libnd4j/include/loops/cpu/scalar_int.cpp +++ b/libnd4j/include/loops/cpu/scalar_int.cpp @@ -19,185 +19,195 @@ // #include "../scalar_int.h" + +#include +#include #include #include -#include -#include #include "../legacy_ops.h" using namespace simdOps; namespace functions { - namespace scalar { - - - template - template - void ScalarIntTransform::transform(const void *vx, const Nd4jLong *xShapeInfo, - void *vextraParams, - void *vz, const Nd4jLong *zShapeInfo, - const void *vscalars, - int *dimension, int dimensionLength, - const Nd4jLong *xTadShapeInfo, const Nd4jLong *xTadOffsets, - const Nd4jLong *zTadShapeInfo, const Nd4jLong *zTadOffsets, - const uint64_t start, const uint64_t stop) { - - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - auto scalars = reinterpret_cast(vscalars); - auto extraParams = reinterpret_cast(vextraParams); - - if (zTadShapeInfo == nullptr) { - zTadShapeInfo = xTadShapeInfo; - zTadOffsets = xTadOffsets; - } - - // tad preparation - const int xTadEws = shape::elementWiseStride(xTadShapeInfo); - const int zTadEws = shape::elementWiseStride(zTadShapeInfo); - const int tadLength = shape::tadLength(xShapeInfo, dimension, dimensionLength); - const int numTads = shape::length(xShapeInfo) / tadLength; - - sd::LoopKind::Kind kindOfLoop = sd::LoopKind::deduceKindOfLoopXZ(xTadShapeInfo, zTadShapeInfo); - - if (kindOfLoop != sd::LoopKind::EWS1 && kindOfLoop != sd::LoopKind::EWSNONZERO) { - printf("ScalarIntTransform::transform: super-bad loop visited. Shouldn't ever happen\n"); - return; - } - - int num_threads = sd::math::nd4j_min(numTads, sd::Environment::getInstance().maxThreads()); - - if (kindOfLoop == sd::LoopKind::EWS1) { - for (auto r = start; r < stop; r++) { - auto oZ = z + zTadOffsets[r]; - auto oX = x + xTadOffsets[r]; - - PRAGMA_OMP_SIMD - for (int f = 0; f < tadLength; f++) - oZ[f] = OpType::op(oX[f], scalars[r], extraParams); - }; - } - else { - for (auto r = start; r < stop; r++) { - auto oZ = z + zTadOffsets[r]; - auto oX = x + xTadOffsets[r]; - - PRAGMA_OMP_SIMD - for (int f = 0; f < tadLength; f++) - oZ[f * zTadEws] = OpType::op(oX[f * xTadEws], scalars[r], extraParams); - }; - } - } - - template - void ScalarIntTransform::transform(int opNum, - const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *z, const Nd4jLong *zShapeInfo, - const void *scalars, - int *dimension, int dimensionLength, - const Nd4jLong *xTadShapeInfo, const Nd4jLong *xTadOffsets, - const Nd4jLong *zTadShapeInfo, const Nd4jLong *zTadOffsets, - const uint64_t start, const uint64_t stop) { - - DISPATCH_BY_OPNUM_T(transform, PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo, scalars, dimension, dimensionLength, xTadShapeInfo, xTadOffsets, zTadShapeInfo, zTadOffsets, start, stop), SCALAR_INT_OPS); - } - - - template - void ScalarIntTransform::transform(const int opNum, - const void *x, Nd4jLong xEws, - void *z, Nd4jLong zEws, - const void *scalar, - void *extraParams, - const uint64_t n, - const uint64_t start, const uint64_t stop) { - DISPATCH_BY_OPNUM_T(transform, PARAMS(x, xEws, z, zEws, scalar, extraParams, n, start, stop), SCALAR_INT_OPS); - } - - template - void ScalarIntTransform::transform(const int opNum, - const void *x, const Nd4jLong *xShapeInfo, - void *z, const Nd4jLong *zShapeInfo, - const void *scalar, - void *extraParams, - const uint64_t start, const uint64_t stop) { - DISPATCH_BY_OPNUM_T(transform, PARAMS(x, xShapeInfo, z, zShapeInfo, scalar, extraParams, start, stop), SCALAR_INT_OPS); - } - - template - template - void ScalarIntTransform::transform(const void *vx, const Nd4jLong *xShapeInfo, - void *vz, const Nd4jLong *zShapeInfo, - const void *vscalar, void *vextraParams, - const uint64_t start, const uint64_t stop) { - - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - auto scalar = reinterpret_cast(vscalar)[0]; - auto extraParams = reinterpret_cast(vextraParams); - - auto xEws = shape::elementWiseStride(xShapeInfo); - auto zEws = shape::elementWiseStride(zShapeInfo); - auto len = shape::length(xShapeInfo); - - sd::LoopKind::Kind kindOfLoop = sd::LoopKind::deduceKindOfLoopXZ(xShapeInfo, zShapeInfo); - - if (kindOfLoop == sd::LoopKind::EWS1 || kindOfLoop == sd::LoopKind::EWSNONZERO) { - transform(x, xEws, z, zEws, vscalar, extraParams, len, start, stop); - return; - } - - uint xShapeInfoCast[MAX_RANK]; - const bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); - - if(shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) { - PRAGMA_OMP_SIMD - for (auto i = start; i < stop; i++) { - auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); - z[offset] = OpType::op(x[offset], scalar, extraParams); - }; - } - else { - uint zShapeInfoCast[MAX_RANK]; - const bool canCastZ = sd::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast); - - PRAGMA_OMP_SIMD - for (auto i = start; i < stop; i++) { - auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); - auto zOffset = shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ); - z[zOffset] = OpType::op(x[xOffset], scalar, extraParams); - }; - } - } - - - template - template - void ScalarIntTransform::transform(const void *vx, Nd4jLong xEws, - void *vz, Nd4jLong zEws, - const void *vscalar, - void *vextraParams, - const uint64_t len, const uint64_t start, const uint64_t stop) { - - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - auto scalar = reinterpret_cast(vscalar)[0]; - auto extraParams = reinterpret_cast(vextraParams); - - if (scalar < (sizeof(X) * 8)) { - if (xEws == 1 && zEws == 1) { - for (auto i = start; i < stop; i++) - z[i] = OpType::op(x[i], scalar, extraParams); - } else { - for (auto i = start; i < stop; i++) - z[i * zEws] = OpType::op(x[i * xEws], scalar, extraParams); - } - } - } - - BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT ScalarIntTransform, , INTEGER_TYPES); +namespace scalar { + +template +template +void ScalarIntTransform::transform( + const void *vx, const Nd4jLong *xShapeInfo, void *vextraParams, void *vz, + const Nd4jLong *zShapeInfo, const void *vscalars, int *dimension, + int dimensionLength, const Nd4jLong *xTadShapeInfo, + const Nd4jLong *xTadOffsets, const Nd4jLong *zTadShapeInfo, + const Nd4jLong *zTadOffsets, const uint64_t start, const uint64_t stop) { + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + auto scalars = reinterpret_cast(vscalars); + auto extraParams = reinterpret_cast(vextraParams); + + if (zTadShapeInfo == nullptr) { + zTadShapeInfo = xTadShapeInfo; + zTadOffsets = xTadOffsets; + } + + // tad preparation + const int xTadEws = shape::elementWiseStride(xTadShapeInfo); + const int zTadEws = shape::elementWiseStride(zTadShapeInfo); + const int tadLength = + shape::tadLength(xShapeInfo, dimension, dimensionLength); + const int numTads = shape::length(xShapeInfo) / tadLength; + + sd::LoopKind::Kind kindOfLoop = + sd::LoopKind::deduceKindOfLoopXZ(xTadShapeInfo, zTadShapeInfo); + + if (kindOfLoop != sd::LoopKind::EWS1 && + kindOfLoop != sd::LoopKind::EWSNONZERO) { + printf( + "ScalarIntTransform::transform: super-bad loop visited. Shouldn't " + "ever happen\n"); + return; + } + + int num_threads = sd::math::nd4j_min( + numTads, sd::Environment::getInstance().maxThreads()); + + if (kindOfLoop == sd::LoopKind::EWS1) { + for (auto r = start; r < stop; r++) { + auto oZ = z + zTadOffsets[r]; + auto oX = x + xTadOffsets[r]; + + PRAGMA_OMP_SIMD + for (int f = 0; f < tadLength; f++) + oZ[f] = OpType::op(oX[f], scalars[r], extraParams); + }; + } else { + for (auto r = start; r < stop; r++) { + auto oZ = z + zTadOffsets[r]; + auto oX = x + xTadOffsets[r]; + + PRAGMA_OMP_SIMD + for (int f = 0; f < tadLength; f++) + oZ[f * zTadEws] = OpType::op(oX[f * xTadEws], scalars[r], extraParams); + }; + } +} +template +void ScalarIntTransform::transform( + int opNum, const void *x, const Nd4jLong *xShapeInfo, void *extraParams, + void *z, const Nd4jLong *zShapeInfo, const void *scalars, int *dimension, + int dimensionLength, const Nd4jLong *xTadShapeInfo, + const Nd4jLong *xTadOffsets, const Nd4jLong *zTadShapeInfo, + const Nd4jLong *zTadOffsets, const uint64_t start, const uint64_t stop) { + DISPATCH_BY_OPNUM_T( + transform, + PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo, scalars, dimension, + dimensionLength, xTadShapeInfo, xTadOffsets, zTadShapeInfo, + zTadOffsets, start, stop), + SCALAR_INT_OPS); } + +template +void ScalarIntTransform::transform(const int opNum, const void *x, + Nd4jLong xEws, void *z, Nd4jLong zEws, + const void *scalar, void *extraParams, + const uint64_t n, const uint64_t start, + const uint64_t stop) { + DISPATCH_BY_OPNUM_T( + transform, PARAMS(x, xEws, z, zEws, scalar, extraParams, n, start, stop), + SCALAR_INT_OPS); } + +template +void ScalarIntTransform::transform(const int opNum, const void *x, + const Nd4jLong *xShapeInfo, void *z, + const Nd4jLong *zShapeInfo, + const void *scalar, void *extraParams, + const uint64_t start, + const uint64_t stop) { + DISPATCH_BY_OPNUM_T( + transform, + PARAMS(x, xShapeInfo, z, zShapeInfo, scalar, extraParams, start, stop), + SCALAR_INT_OPS); +} + +template +template +void ScalarIntTransform::transform(const void *vx, + const Nd4jLong *xShapeInfo, void *vz, + const Nd4jLong *zShapeInfo, + const void *vscalar, void *vextraParams, + const uint64_t start, + const uint64_t stop) { + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + auto scalar = reinterpret_cast(vscalar)[0]; + auto extraParams = reinterpret_cast(vextraParams); + + auto xEws = shape::elementWiseStride(xShapeInfo); + auto zEws = shape::elementWiseStride(zShapeInfo); + auto len = shape::length(xShapeInfo); + + sd::LoopKind::Kind kindOfLoop = + sd::LoopKind::deduceKindOfLoopXZ(xShapeInfo, zShapeInfo); + + if (kindOfLoop == sd::LoopKind::EWS1 || + kindOfLoop == sd::LoopKind::EWSNONZERO) { + transform(x, xEws, z, zEws, vscalar, extraParams, len, start, stop); + return; + } + + uint xShapeInfoCast[MAX_RANK]; + const bool canCastX = + sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + + if (shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) { + PRAGMA_OMP_SIMD + for (auto i = start; i < stop; i++) { + auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); + z[offset] = OpType::op(x[offset], scalar, extraParams); + }; + } else { + uint zShapeInfoCast[MAX_RANK]; + const bool canCastZ = + sd::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast); + + PRAGMA_OMP_SIMD + for (auto i = start; i < stop; i++) { + auto xOffset = + shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); + auto zOffset = + shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ); + z[zOffset] = OpType::op(x[xOffset], scalar, extraParams); + }; + } +} + +template +template +void ScalarIntTransform::transform(const void *vx, Nd4jLong xEws, void *vz, + Nd4jLong zEws, const void *vscalar, + void *vextraParams, const uint64_t len, + const uint64_t start, + const uint64_t stop) { + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + auto scalar = reinterpret_cast(vscalar)[0]; + auto extraParams = reinterpret_cast(vextraParams); + + if (scalar < (sizeof(X) * 8)) { + if (xEws == 1 && zEws == 1) { + for (auto i = start; i < stop; i++) + z[i] = OpType::op(x[i], scalar, extraParams); + } else { + for (auto i = start; i < stop; i++) + z[i * zEws] = OpType::op(x[i * xEws], scalar, extraParams); + } + } +} + +BUILD_SINGLE_TEMPLATE(template class SD_EXPORT ScalarIntTransform, , + INTEGER_TYPES); + +} // namespace scalar +} // namespace functions diff --git a/libnd4j/include/loops/cpu/summarystatsreduce.cpp b/libnd4j/include/loops/cpu/summarystatsreduce.cpp index 63993d853178..72257a80d906 100644 --- a/libnd4j/include/loops/cpu/summarystatsreduce.cpp +++ b/libnd4j/include/loops/cpu/summarystatsreduce.cpp @@ -18,168 +18,180 @@ // Created by raver119 on 18.12.17. // -#include -#include -#include -#include -#include -#include #include +#include +#include +#include +#include +#include +#include using namespace simdOps; namespace functions { - namespace summarystats { - - - template - Y SummaryStatsReduce::execScalar(const int opNum, - const bool biasCorrected, - const void *x, const Nd4jLong *xShapeInfo, - void *extraParams) { - RETURNING_DISPATCH_BY_OPNUM_TT(execScalar, PARAMS(biasCorrected, x, xShapeInfo, extraParams), SUMMARY_STATS_OPS); +namespace summarystats { + +template +Y SummaryStatsReduce::execScalar(const int opNum, + const bool biasCorrected, const void *x, + const Nd4jLong *xShapeInfo, + void *extraParams) { + RETURNING_DISPATCH_BY_OPNUM_TT( + execScalar, PARAMS(biasCorrected, x, xShapeInfo, extraParams), + SUMMARY_STATS_OPS); +} + +template +void SummaryStatsReduce::execScalar(const int opNum, + const bool biasCorrected, + const void *x, + const Nd4jLong *xShapeInfo, + void *extraParams, void *z, + const Nd4jLong *zShapeInfo) { + DISPATCH_BY_OPNUM_TT( + execScalar, + PARAMS(biasCorrected, x, xShapeInfo, extraParams, z, zShapeInfo), + SUMMARY_STATS_OPS); +} + +template +void SummaryStatsReduce::exec(const int opNum, const bool biasCorrected, + const void *x, const Nd4jLong *xShapeInfo, + void *extraParams, void *z, + const Nd4jLong *zShapeInfo, int *dimension, + int dimensionLength) { + DISPATCH_BY_OPNUM_TT(exec, + PARAMS(biasCorrected, x, xShapeInfo, extraParams, z, + zShapeInfo, dimension, dimensionLength), + SUMMARY_STATS_OPS); +} + +template +template +void SummaryStatsReduce::execScalar(const bool biasCorrected, + const void *vx, + const Nd4jLong *xShapeInfo, + void *vextraParams, void *vz, + const Nd4jLong *zShapeInfo) { + auto z = reinterpret_cast(vz); + z[0] = execScalar(biasCorrected, vx, xShapeInfo, vextraParams); +} + +template +template +Z SummaryStatsReduce::execScalar(const bool biasCorrected, const void *vx, + const Nd4jLong *xShapeInfo, + void *vextraParams) { + auto x = reinterpret_cast(vx); + auto extraParams = reinterpret_cast(vextraParams); + + SummaryStatsData startingIndex; + startingIndex.initialize(); + auto length = shape::length(xShapeInfo); + + uint xShapeInfoCast[MAX_RANK]; + const bool canCast = + sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + + for (Nd4jLong i = 0; i < length; i++) { + auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCast); + + SummaryStatsData curr; + curr.initWithValue(x[xOffset]); + startingIndex = update(startingIndex, curr, extraParams); + } + + return OpType::getValue(biasCorrected, startingIndex); +} + +template +template +void SummaryStatsReduce::exec(const bool biasCorrected, const void *vx, + const Nd4jLong *xShapeInfo, + void *vextraParams, void *vz, + const Nd4jLong *zShapeInfo, int *dimension, + int dimensionLength) { + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + auto extraParams = reinterpret_cast(vextraParams); + auto resultLength = shape::length(zShapeInfo); + + if (sd::ArrayOptions::arrayType(xShapeInfo) == sd::ArrayType::EMPTY) { + if (sd::ArrayOptions::arrayType(zShapeInfo) == sd::ArrayType::EMPTY) return; + SummaryStatsData comp; + comp.initWithValue(x[0]); + + for (Nd4jLong i = 0; i < resultLength; i++) + z[i] = OpType::getValue(biasCorrected, comp); + return; + } + + if (shape::isScalar(zShapeInfo)) { + z[0] = execScalar(biasCorrected, x, xShapeInfo, extraParams); + return; + } + + // no-op + if (dimensionLength < 1) return; + + auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions( + xShapeInfo, dimension, dimensionLength); + + // pre squeezed: this is for keeping the pointer to the original + // shape information for tad offset + // the squeezed information doesn't render the right strides for + // tad offset + if (resultLength == 1 || dimensionLength == shape::rank(xShapeInfo) || + tadPack.numberOfTads() == 1) { + z[0] = execScalar(biasCorrected, x, xShapeInfo, extraParams); + return; + } + + auto tadShapeShapeInfo = tadPack.primaryShapeInfo(); + auto tadLength = shape::length(tadPack.primaryShapeInfo()); + auto tadEWS = shape::elementWiseStride(tadPack.primaryShapeInfo()); + auto tadOrder = shape::order(tadPack.primaryShapeInfo()); + + uint tadShapeShapeInfoCast[MAX_RANK]; + const bool canCast = tadEWS == 1 && tadOrder == 'c' + ? false + : sd::DataTypeUtils::castShapeInfo( + tadShapeShapeInfo, tadShapeShapeInfoCast); + + auto func = PRAGMA_THREADS_FOR { + for (auto r = start; r < stop; r++) { + auto tadOffsetForBlock = tadPack.primaryOffsets()[r]; + auto tx = x + tadOffsetForBlock; + SummaryStatsData comp; + comp.initWithValue(tx[0]); + + if (tadEWS == 1 && tadOrder == 'c') { + for (Nd4jLong i = 1; i < tadLength; i++) { + SummaryStatsData indexVal2; + indexVal2.initWithValue(tx[i]); + + comp = update(comp, OpType::op(indexVal2, extraParams), extraParams); } + } else { + for (Nd4jLong i = 1; i < tadLength; i++) { + auto xOffset = shape::indexOffset(i, tadShapeShapeInfo, + tadShapeShapeInfoCast, canCast); - template - void SummaryStatsReduce::execScalar(const int opNum, - const bool biasCorrected, - const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *z, const Nd4jLong *zShapeInfo) { - DISPATCH_BY_OPNUM_TT(execScalar, PARAMS(biasCorrected, x, xShapeInfo, extraParams, z, zShapeInfo), SUMMARY_STATS_OPS); - } - - template - void SummaryStatsReduce::exec(const int opNum, - const bool biasCorrected, - const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *z, const Nd4jLong *zShapeInfo, - int *dimension, int dimensionLength) { - DISPATCH_BY_OPNUM_TT(exec, PARAMS(biasCorrected, x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength), SUMMARY_STATS_OPS); - } - - template - template - void SummaryStatsReduce::execScalar(const bool biasCorrected, - const void *vx, const Nd4jLong *xShapeInfo, - void *vextraParams, - void *vz, const Nd4jLong *zShapeInfo) { - auto z = reinterpret_cast(vz); - z[0] = execScalar(biasCorrected, vx, xShapeInfo, vextraParams); - } - - template - template - Z SummaryStatsReduce::execScalar(const bool biasCorrected, const void *vx, const Nd4jLong *xShapeInfo, void *vextraParams) { - - auto x = reinterpret_cast(vx); - auto extraParams = reinterpret_cast(vextraParams); + SummaryStatsData indexVal2; + indexVal2.initWithValue(tx[xOffset]); - SummaryStatsData startingIndex; - startingIndex.initialize(); - auto length = shape::length(xShapeInfo); - - uint xShapeInfoCast[MAX_RANK]; - const bool canCast = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); - - for (Nd4jLong i = 0; i < length; i++) { - auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCast); - - SummaryStatsData curr; - curr.initWithValue(x[xOffset]); - startingIndex = update(startingIndex, curr, extraParams); - } - - return OpType::getValue(biasCorrected, startingIndex); + comp = update(comp, OpType::op(indexVal2, extraParams), extraParams); } + } - template - template - void SummaryStatsReduce::exec(const bool biasCorrected, - const void *vx, const Nd4jLong *xShapeInfo, - void *vextraParams, - void *vz, const Nd4jLong *zShapeInfo, - int *dimension, int dimensionLength) { - - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - auto extraParams = reinterpret_cast(vextraParams); - auto resultLength = shape::length(zShapeInfo); - - if(sd::ArrayOptions::arrayType(xShapeInfo) == sd::ArrayType::EMPTY) { - if(sd::ArrayOptions::arrayType(zShapeInfo) == sd::ArrayType::EMPTY) - return; - SummaryStatsData comp; - comp.initWithValue(x[0]); - - for (Nd4jLong i = 0; i < resultLength; i++) - z[i] = OpType::getValue(biasCorrected, comp); - return; - } - - if (shape::isScalar(zShapeInfo)) { - z[0] = execScalar(biasCorrected, x, xShapeInfo, extraParams); - return; - } - - //no-op - if (dimensionLength < 1) - return; - - auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(xShapeInfo, dimension, dimensionLength); - - //pre squeezed: this is for keeping the pointer to the original - //shape information for tad offset - //the squeezed information doesn't render the right strides for - //tad offset - if (resultLength == 1 || dimensionLength == shape::rank(xShapeInfo) || tadPack.numberOfTads() == 1) { - z[0] = execScalar(biasCorrected, x, xShapeInfo, extraParams); - return; - } - - auto tadShapeShapeInfo = tadPack.primaryShapeInfo(); - auto tadLength = shape::length(tadPack.primaryShapeInfo()); - auto tadEWS = shape::elementWiseStride(tadPack.primaryShapeInfo()); - auto tadOrder = shape::order(tadPack.primaryShapeInfo()); - - uint tadShapeShapeInfoCast[MAX_RANK]; - const bool canCast = tadEWS == 1 && tadOrder == 'c' ? false : sd::DataTypeUtils::castShapeInfo(tadShapeShapeInfo, tadShapeShapeInfoCast); - - auto func = PRAGMA_THREADS_FOR { - for (auto r = start; r < stop; r++) { - - auto tadOffsetForBlock = tadPack.primaryOffsets()[r]; - auto tx = x + tadOffsetForBlock; - SummaryStatsData comp; - comp.initWithValue(tx[0]); - - if (tadEWS == 1 && tadOrder == 'c') { - for (Nd4jLong i = 1; i < tadLength; i++) { - SummaryStatsData indexVal2; - indexVal2.initWithValue(tx[i]); - - comp = update(comp, OpType::op(indexVal2, extraParams), extraParams); - } - } else { - for (Nd4jLong i = 1; i < tadLength; i++) { - auto xOffset = shape::indexOffset(i, tadShapeShapeInfo, tadShapeShapeInfoCast, canCast); - - SummaryStatsData indexVal2; - indexVal2.initWithValue(tx[xOffset]); - - comp = update(comp, OpType::op(indexVal2, extraParams), extraParams); - } - } - - z[r] = OpType::getValue(biasCorrected, comp); - } - }; - - samediff::Threads::parallel_tad(func, 0, resultLength, 1); - } + z[r] = OpType::getValue(biasCorrected, comp); + } + }; + samediff::Threads::parallel_tad(func, 0, resultLength, 1); +} - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT SummaryStatsReduce, , LIBND4J_TYPES, FLOAT_TYPES); - } -} \ No newline at end of file +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT SummaryStatsReduce, , + LIBND4J_TYPES, FLOAT_TYPES); +} // namespace summarystats +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/transform/transform_any.cpp b/libnd4j/include/loops/cpu/transform/transform_any.cpp index 6a8c07094c35..2f2646d51637 100644 --- a/libnd4j/include/loops/cpu/transform/transform_any.cpp +++ b/libnd4j/include/loops/cpu/transform/transform_any.cpp @@ -18,43 +18,45 @@ // @author raver119@gmail.com // -#include #include -#include -#include #include +#include +#include +#include using namespace simdOps; namespace functions { - namespace transform { - - template - void TransformAny::exec(int opNum, - const void *x, const Nd4jLong *xShapeInfo, - void *z, const Nd4jLong *zShapeInfo, - void *extraParams, - uint64_t threadId, uint64_t numThreads) { - DISPATCH_BY_OPNUM_TT(exec, PARAMS(x, xShapeInfo, z, zShapeInfo, extraParams, threadId, numThreads), TRANSFORM_ANY_OPS); - } +namespace transform { + +template +void TransformAny::exec(int opNum, const void *x, + const Nd4jLong *xShapeInfo, void *z, + const Nd4jLong *zShapeInfo, void *extraParams, + uint64_t threadId, uint64_t numThreads) { + DISPATCH_BY_OPNUM_TT( + exec, + PARAMS(x, xShapeInfo, z, zShapeInfo, extraParams, threadId, numThreads), + TRANSFORM_ANY_OPS); +} ///////////////////////////////////////////////////////////////////// template -template -void _CUDA_H TransformAny::exec(const void *vx, const Nd4jLong *xShapeInfo, - void *vz, const Nd4jLong *zShapeInfo, - void *vextraParams, - uint64_t threadId, uint64_t numThreads) { - - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - auto extraParams = reinterpret_cast(vextraParams); - - sd::TransformLoops::template loopTransform(x, xShapeInfo, z, zShapeInfo, extraParams, threadId, numThreads); +template +void _CUDA_H TransformAny::exec(const void *vx, + const Nd4jLong *xShapeInfo, void *vz, + const Nd4jLong *zShapeInfo, + void *vextraParams, uint64_t threadId, + uint64_t numThreads) { + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + auto extraParams = reinterpret_cast(vextraParams); + + sd::TransformLoops::template loopTransform( + x, xShapeInfo, z, zShapeInfo, extraParams, threadId, numThreads); } - - -BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT TransformAny, , LIBND4J_TYPES, LIBND4J_TYPES); -} -} \ No newline at end of file +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT TransformAny, , LIBND4J_TYPES, + LIBND4J_TYPES); +} // namespace transform +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/transform/transform_bool.cpp b/libnd4j/include/loops/cpu/transform/transform_bool.cpp index 5e88a15c3094..6d0170b3cbc7 100644 --- a/libnd4j/include/loops/cpu/transform/transform_bool.cpp +++ b/libnd4j/include/loops/cpu/transform/transform_bool.cpp @@ -18,40 +18,44 @@ // @author raver119@gmail.com // -#include #include -#include -#include #include +#include +#include +#include using namespace simdOps; namespace functions { - namespace transform { - - template - void TransformBool::exec(int opNum, - const void *x, const Nd4jLong *xShapeInfo, - void *z, const Nd4jLong *zShapeInfo, - void *extraParams, - uint64_t threadId, uint64_t numThreads) { - DISPATCH_BY_OPNUM_TT(exec, PARAMS(x, xShapeInfo, z, zShapeInfo, extraParams, threadId, numThreads), TRANSFORM_BOOL_OPS); - } - - template - template - void _CUDA_H TransformBool::exec(const void *vx, const Nd4jLong *xShapeInfo, - void *vz, const Nd4jLong *zShapeInfo, - void *vextraParams, - uint64_t threadId, uint64_t numThreads) { - - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - auto extraParams = reinterpret_cast(vextraParams); - - sd::TransformLoops::template loopTransform(x, xShapeInfo, z, zShapeInfo, extraParams, threadId, numThreads); - } - - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT TransformBool, , LIBND4J_TYPES, BOOL_TYPES); - } -} \ No newline at end of file +namespace transform { + +template +void TransformBool::exec(int opNum, const void *x, + const Nd4jLong *xShapeInfo, void *z, + const Nd4jLong *zShapeInfo, void *extraParams, + uint64_t threadId, uint64_t numThreads) { + DISPATCH_BY_OPNUM_TT( + exec, + PARAMS(x, xShapeInfo, z, zShapeInfo, extraParams, threadId, numThreads), + TRANSFORM_BOOL_OPS); +} + +template +template +void _CUDA_H TransformBool::exec(const void *vx, + const Nd4jLong *xShapeInfo, void *vz, + const Nd4jLong *zShapeInfo, + void *vextraParams, uint64_t threadId, + uint64_t numThreads) { + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + auto extraParams = reinterpret_cast(vextraParams); + + sd::TransformLoops::template loopTransform( + x, xShapeInfo, z, zShapeInfo, extraParams, threadId, numThreads); +} + +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT TransformBool, , LIBND4J_TYPES, + BOOL_TYPES); +} // namespace transform +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/transform/transform_float.cpp b/libnd4j/include/loops/cpu/transform/transform_float.cpp index fd37391c209e..bda579240de6 100644 --- a/libnd4j/include/loops/cpu/transform/transform_float.cpp +++ b/libnd4j/include/loops/cpu/transform/transform_float.cpp @@ -18,39 +18,43 @@ // @author raver119@gmail.com // -#include #include -#include -#include #include +#include +#include +#include using namespace simdOps; namespace functions { - namespace transform { - template - void TransformFloat::exec(int opNum, - const void *x, const Nd4jLong *xShapeInfo, - void *z, const Nd4jLong *zShapeInfo, - void *extraParams, - uint64_t threadId, uint64_t numThreads) { - DISPATCH_BY_OPNUM_TT(exec, PARAMS(x, xShapeInfo, z, zShapeInfo, extraParams, threadId, numThreads), TRANSFORM_FLOAT_OPS); - } - - template - template - void _CUDA_H TransformFloat::exec(const void *vx, const Nd4jLong *xShapeInfo, - void *vz, const Nd4jLong *zShapeInfo, - void *vextraParams, - uint64_t threadId, uint64_t numThreads) { - - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - auto extraParams = reinterpret_cast(vextraParams); - - sd::TransformLoops::template loopTransform(x, xShapeInfo, z, zShapeInfo, extraParams, threadId, numThreads); - } - - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT TransformFloat, , LIBND4J_TYPES, FLOAT_TYPES); - } -} \ No newline at end of file +namespace transform { +template +void TransformFloat::exec(int opNum, const void *x, + const Nd4jLong *xShapeInfo, void *z, + const Nd4jLong *zShapeInfo, void *extraParams, + uint64_t threadId, uint64_t numThreads) { + DISPATCH_BY_OPNUM_TT( + exec, + PARAMS(x, xShapeInfo, z, zShapeInfo, extraParams, threadId, numThreads), + TRANSFORM_FLOAT_OPS); +} + +template +template +void _CUDA_H TransformFloat::exec(const void *vx, + const Nd4jLong *xShapeInfo, void *vz, + const Nd4jLong *zShapeInfo, + void *vextraParams, uint64_t threadId, + uint64_t numThreads) { + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + auto extraParams = reinterpret_cast(vextraParams); + + sd::TransformLoops::template loopTransform( + x, xShapeInfo, z, zShapeInfo, extraParams, threadId, numThreads); +} + +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT TransformFloat, , LIBND4J_TYPES, + FLOAT_TYPES); +} // namespace transform +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/transform/transform_same.cpp b/libnd4j/include/loops/cpu/transform/transform_same.cpp index d2793d9c0857..4530d71b0c6f 100644 --- a/libnd4j/include/loops/cpu/transform/transform_same.cpp +++ b/libnd4j/include/loops/cpu/transform/transform_same.cpp @@ -18,41 +18,42 @@ // @author raver119@gmail.com // -#include #include -#include -#include #include +#include +#include +#include using namespace simdOps; namespace functions { - namespace transform { - - template - void TransformSame::exec(int opNum, - const void *x, const Nd4jLong *xShapeInfo, - void *z, const Nd4jLong *zShapeInfo, - void *extraParams, - uint64_t threadId, uint64_t numThreads) { - DISPATCH_BY_OPNUM_T(exec, PARAMS(x, xShapeInfo, z, zShapeInfo, extraParams, threadId, numThreads), TRANSFORM_SAME_OPS); - } - - template - template - void _CUDA_H TransformSame::exec(const void *vx, const Nd4jLong *xShapeInfo, - void *vz, const Nd4jLong *zShapeInfo, - void *vextraParams, - uint64_t threadId, uint64_t numThreads) { - - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - auto extraParams = reinterpret_cast(vextraParams); - - - sd::TransformLoops::template loopTransform(x, xShapeInfo, z, zShapeInfo, extraParams, threadId, numThreads); - } - - BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT TransformSame, , LIBND4J_TYPES); - } -} \ No newline at end of file +namespace transform { + +template +void TransformSame::exec(int opNum, const void *x, + const Nd4jLong *xShapeInfo, void *z, + const Nd4jLong *zShapeInfo, void *extraParams, + uint64_t threadId, uint64_t numThreads) { + DISPATCH_BY_OPNUM_T( + exec, + PARAMS(x, xShapeInfo, z, zShapeInfo, extraParams, threadId, numThreads), + TRANSFORM_SAME_OPS); +} + +template +template +void _CUDA_H TransformSame::exec(const void *vx, const Nd4jLong *xShapeInfo, + void *vz, const Nd4jLong *zShapeInfo, + void *vextraParams, uint64_t threadId, + uint64_t numThreads) { + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + auto extraParams = reinterpret_cast(vextraParams); + + sd::TransformLoops::template loopTransform( + x, xShapeInfo, z, zShapeInfo, extraParams, threadId, numThreads); +} + +BUILD_SINGLE_TEMPLATE(template class SD_EXPORT TransformSame, , LIBND4J_TYPES); +} // namespace transform +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/transform/transform_strict.cpp b/libnd4j/include/loops/cpu/transform/transform_strict.cpp index 54a24d0e3ecb..a5070f7a566e 100644 --- a/libnd4j/include/loops/cpu/transform/transform_strict.cpp +++ b/libnd4j/include/loops/cpu/transform/transform_strict.cpp @@ -18,41 +18,43 @@ // @author raver119@gmail.com // -#include #include -#include -#include #include +#include +#include +#include using namespace simdOps; namespace functions { - namespace transform { - - template - void TransformStrict::exec(int opNum, - const void *x, const Nd4jLong *xShapeInfo, - void *z, +namespace transform { + +template +void TransformStrict::exec(int opNum, const void *x, + const Nd4jLong *xShapeInfo, void *z, + const Nd4jLong *zShapeInfo, void *extraParams, + uint64_t threadId, uint64_t numThreads) { + DISPATCH_BY_OPNUM_T( + exec, + PARAMS(x, xShapeInfo, z, zShapeInfo, extraParams, threadId, numThreads), + TRANSFORM_STRICT_OPS); +} + +template +template +void _CUDA_H TransformStrict::exec(const void *vx, + const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo, - void *extraParams, - uint64_t threadId, uint64_t numThreads) { - DISPATCH_BY_OPNUM_T(exec, PARAMS(x, xShapeInfo, z, zShapeInfo, extraParams, threadId, numThreads), TRANSFORM_STRICT_OPS); - } - - template - template - void _CUDA_H TransformStrict::exec(const void *vx, const Nd4jLong *xShapeInfo, - void *vz, const Nd4jLong *zShapeInfo, - void *vextraParams, - uint64_t threadId, uint64_t numThreads) { - - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - auto extraParams = reinterpret_cast(vextraParams); - - sd::TransformLoops::template loopTransform(x, xShapeInfo, z, zShapeInfo, extraParams, threadId, numThreads); - } - - BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT TransformStrict, , FLOAT_TYPES); - } -} \ No newline at end of file + void *vextraParams, uint64_t threadId, + uint64_t numThreads) { + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + auto extraParams = reinterpret_cast(vextraParams); + + sd::TransformLoops::template loopTransform( + x, xShapeInfo, z, zShapeInfo, extraParams, threadId, numThreads); +} + +BUILD_SINGLE_TEMPLATE(template class SD_EXPORT TransformStrict, , FLOAT_TYPES); +} // namespace transform +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/broadcasting.chpp b/libnd4j/include/loops/cuda/broadcasting.chpp index 4b5c7833fe59..146b7f930b8c 100644 --- a/libnd4j/include/loops/cuda/broadcasting.chpp +++ b/libnd4j/include/loops/cuda/broadcasting.chpp @@ -18,297 +18,336 @@ // @author raver119@gmail.com // -#include -#include -#include -#include -#include #include #include -#include -#include #include +#include +#include #include +#include +#include +#include + +#include +#include using namespace simdOps; -template +template static __global__ void broadcastSimple( - void const* x, - Nd4jLong const* xShapeInfo, - void const* y, - Nd4jLong const* yShapeInfo, - void *z, - Nd4jLong const* zShapeInfo, - int *dimension, - int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) { - - functions::broadcast::Broadcast::template transformCuda(x,xShapeInfo,y,yShapeInfo,z,zShapeInfo,dimension,dimensionLength,tadOnlyShapeInfo,tadOffsets,tadOnlyShapeInfoZ,tadOffsetsZ); + void const* x, Nd4jLong const* xShapeInfo, void const* y, + Nd4jLong const* yShapeInfo, void* z, Nd4jLong const* zShapeInfo, + int* dimension, int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, + Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, + Nd4jLong const* tadOffsetsZ) { + functions::broadcast::Broadcast::template transformCuda( + x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, + tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ); } -template -static __global__ void broadcastSimple(const void const* x, const Nd4jLong const* xShapeInfo, - const void const* y, const Nd4jLong const* yShapeInfo, - void *z, const Nd4jLong const* zShapeInfo ) { +template +static __global__ void broadcastSimple(const void const* x, + const Nd4jLong const* xShapeInfo, + const void const* y, + const Nd4jLong const* yShapeInfo, + void* z, + const Nd4jLong const* zShapeInfo) { + functions::broadcast::Broadcast::template transformCuda( + x, xShapeInfo, y, yShapeInfo, z, zShapeInfo); +} - functions::broadcast::Broadcast::template transformCuda(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo); +template +static __global__ void broadcastInverseSimple( + void const* x, Nd4jLong const* xShapeInfo, void const* y, + Nd4jLong const* yShapeInfo, void* z, Nd4jLong const* zShapeInfo, + int* dimension, int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, + Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, + Nd4jLong const* tadOffsetsZ) { + functions::broadcast::Broadcast::template transformInverseCuda< + OpClass>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, + dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, + tadOffsetsZ); } +namespace functions { +namespace broadcast { -template -static __global__ void broadcastInverseSimple( - void const* x, - Nd4jLong const* xShapeInfo, - void const* y, - Nd4jLong const* yShapeInfo, - void *z, - Nd4jLong const* zShapeInfo, - int *dimension, - int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) { - - functions::broadcast::Broadcast::template transformInverseCuda(x,xShapeInfo,y,yShapeInfo,z,zShapeInfo,dimension,dimensionLength,tadOnlyShapeInfo,tadOffsets,tadOnlyShapeInfoZ,tadOffsetsZ); +static Nd4jLong __device__ __noinline__ +getIndexOffset(Nd4jLong index, const Nd4jLong* shapeInfo) { + return shape::getIndexOffset(index, shapeInfo); +} + +static Nd4jLong __device__ __noinline__ length(const Nd4jLong* shapeInfo) { + return shape::length(shapeInfo); } +template +template +__host__ void Broadcast::intermediateBroadcast( + dim3 launchDims, cudaStream_t* stream, void const* x, + Nd4jLong const* xShapeInfo, void const* y, Nd4jLong const* yShapeInfo, + void* z, Nd4jLong const* zShapeInfo, int* dimension, int dimensionLength, + Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, + Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) { + broadcastSimple + <<>>( + x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, + dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, + tadOffsetsZ); +} -namespace functions { - namespace broadcast { - - static Nd4jLong __device__ __noinline__ getIndexOffset(Nd4jLong index, const Nd4jLong *shapeInfo) { - return shape::getIndexOffset(index, shapeInfo); - } - - static Nd4jLong __device__ __noinline__ length(const Nd4jLong *shapeInfo) { - return shape::length(shapeInfo); - } - - template - template - __host__ void Broadcast::intermediateBroadcast(dim3 launchDims, cudaStream_t *stream, void const* x, Nd4jLong const* xShapeInfo, void const* y, Nd4jLong const* yShapeInfo, void* z, Nd4jLong const* zShapeInfo, int *dimension, int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) { - broadcastSimple<<>>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ); - } - - template - template - __host__ void Broadcast::intermediateBroadcast(dim3 launchDims, cudaStream_t *stream, const void *x, const Nd4jLong *xShapeInfo, const void *y, const Nd4jLong *yShapeInfo, void *z, const Nd4jLong *zShapeInfo) { - broadcastSimple<<>>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo); - } - - template - __host__ void Broadcast::execBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void const* x, Nd4jLong const* xShapeInfo, void const* y, Nd4jLong const* yShapeInfo, void *z, Nd4jLong const* zShapeInfo, int *dimension, int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) { - DISPATCH_BY_OPNUM_TTT(intermediateBroadcast, PARAMS(launchDims, stream, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), OPS_A(BROADCAST_OPS)) - - DEBUG_KERNEL(stream, opNum); - } - - template - __host__ void Broadcast::execBroadcast(dim3 launchDims, cudaStream_t *stream, const int opNum, const void *x, const Nd4jLong *xShapeInfo, const void *y, const Nd4jLong *yShapeInfo, void *z, const Nd4jLong const* zShapeInfo) { - DISPATCH_BY_OPNUM_TTT(intermediateBroadcast, PARAMS(launchDims, stream, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo), OPS_A(BROADCAST_OPS)) - - DEBUG_KERNEL(stream, opNum); - } - - template - template - __host__ void Broadcast::intermediateInverseBroadcast(dim3 launchDims, cudaStream_t *stream, void const* x, Nd4jLong const* xShapeInfo, void const* y, Nd4jLong const* yShapeInfo, void *z, Nd4jLong const* zShapeInfo, int *dimension, int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) { - broadcastInverseSimple<<>>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ); - } - - template - __host__ void Broadcast::execInverseBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void const* x, Nd4jLong const* xShapeInfo, void const* y, Nd4jLong const* yShapeInfo, void *z, Nd4jLong const* zShapeInfo, int *dimension, int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) { - DISPATCH_BY_OPNUM_TTT(intermediateInverseBroadcast, PARAMS(launchDims, stream, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), OPS_A(BROADCAST_OPS)) - - DEBUG_KERNEL(stream, opNum); - } - - template - template - __device__ void Broadcast::transformInverseCuda( - void const* vx, Nd4jLong const* xShapeInfo, - void const* vy, Nd4jLong const* yShapeInfo, - void* vz, Nd4jLong const* zShapeInfo, - int *dimension, int dimensionLength, - Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) { - - if (tadOnlyShapeInfoZ == nullptr) { - tadOnlyShapeInfoZ = tadOnlyShapeInfo; - tadOffsetsZ = tadOffsets; - } - - auto x = reinterpret_cast(vx); - auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); - - //decompose in to several sub tads after - //moving all dimensions (in sorted order) - //to the back. - //permuted version of the x shape info for setting up the tad problem - __shared__ Nd4jLong tadLength; - __shared__ Nd4jLong tadEWS; - __shared__ int numTads; - __shared__ Nd4jLong xEWS; - __shared__ Nd4jLong zEWS; - - - if (threadIdx.x == 0) { - tadLength = length(tadOnlyShapeInfo); - tadEWS = shape::elementWiseStride(tadOnlyShapeInfo); - numTads = length(yShapeInfo) / tadLength; - xEWS = shape::elementWiseStride(xShapeInfo); - zEWS = shape::elementWiseStride(tadOnlyShapeInfoZ); - } - __syncthreads(); - - auto xOrder = shape::order(xShapeInfo); - auto yOrder = shape::order(tadOnlyShapeInfo); - auto zOrder = shape::order(tadOnlyShapeInfoZ); - - for (int r = blockIdx.x; r < numTads; r += gridDim.x) { - - - auto rY = y + tadOffsets[r]; - auto rZ = z + tadOffsetsZ[r]; - - - if(tadEWS > 0 && zEWS > 0 && xEWS > 0 && dimensionLength == 1 && xOrder == yOrder && xOrder == zOrder) { - for (int i = threadIdx.x; i < tadLength; i+= blockDim.x) - rZ[i * zEWS] = OpType::op(x[i * xEWS], rY[i * tadEWS]); - } - else { - // it is expected that x and z tads and y array all have the same length - for (Nd4jLong i = threadIdx.x; i < tadLength; i+= blockDim.x) { - auto xOffset = getIndexOffset(i, xShapeInfo); - auto yOffset = getIndexOffset(i, tadOnlyShapeInfo); - auto zOffset = getIndexOffset(i, tadOnlyShapeInfoZ); - rZ[zOffset] = OpType::op(x[xOffset], rY[yOffset]); - } - } - } - } - - - template - template - __device__ void Broadcast::transformCuda( - void const* vx, Nd4jLong const* xShapeInfo, - void const* vy, Nd4jLong const* yShapeInfo, - void *vz, Nd4jLong const* zShapeInfo, - int *dimension, int dimensionLength, - Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) { - - if (tadOnlyShapeInfoZ == nullptr) { - tadOnlyShapeInfoZ = tadOnlyShapeInfo; - tadOffsetsZ = tadOffsets; - } - - auto x = reinterpret_cast(vx); - auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); - - //decompose in to several sub tads after - //moving all dimensions (in sorted order) - //to the back. - //permuted version of the x shape info for setting up the tad problem - __shared__ Nd4jLong tadLength; - __shared__ Nd4jLong tadEWS; - __shared__ int numTads; - __shared__ Nd4jLong yEWS; - __shared__ Nd4jLong zEWS; - - if (threadIdx.x == 0) { - tadLength = length(tadOnlyShapeInfo); - tadEWS = shape::elementWiseStride(tadOnlyShapeInfo); - numTads = length(xShapeInfo) / tadLength; - yEWS = shape::elementWiseStride(yShapeInfo); - zEWS = shape::elementWiseStride(tadOnlyShapeInfoZ); - } - __syncthreads(); - - auto xOrder = shape::order(tadOnlyShapeInfo); - auto yOrder = shape::order(yShapeInfo); - auto zOrder = shape::order(tadOnlyShapeInfoZ); - - for (int r = blockIdx.x; r < numTads; r += gridDim.x) { - - auto rX = x + tadOffsets[r]; - auto rZ = z + tadOffsetsZ[r]; - - - if(tadEWS > 0 && zEWS > 0 && yEWS > 0 && xOrder == yOrder && xOrder == zOrder) { - for (int i = threadIdx.x; i < tadLength; i+= blockDim.x) - rZ[i * zEWS] = OpType::op(rX[i * tadEWS], y[i * yEWS]); - } - else { - // it is expected that x and z tads and y array all have the same length - for (Nd4jLong i = threadIdx.x; i < tadLength; i+= blockDim.x) { - - auto xOffset = getIndexOffset(i, tadOnlyShapeInfo); - auto yOffset = getIndexOffset(i, yShapeInfo); - auto zOffset = getIndexOffset(i, tadOnlyShapeInfoZ); - rZ[zOffset] = OpType::op(rX[xOffset], y[yOffset]); - } - } - } - } +template +template +__host__ void Broadcast::intermediateBroadcast( + dim3 launchDims, cudaStream_t* stream, const void* x, + const Nd4jLong* xShapeInfo, const void* y, const Nd4jLong* yShapeInfo, + void* z, const Nd4jLong* zShapeInfo) { + broadcastSimple + <<>>( + x, xShapeInfo, y, yShapeInfo, z, zShapeInfo); +} -//////////////////////////////////////////////////////////////////////// -template -template -__device__ void Broadcast::transformCuda( - const void *vx, const Nd4jLong *xShapeInfo, - const void *vy, const Nd4jLong *yShapeInfo, - void *vz, const Nd4jLong *zShapeInfo) { +template +__host__ void Broadcast::execBroadcast( + dim3 launchDims, cudaStream_t* stream, int opNum, void const* x, + Nd4jLong const* xShapeInfo, void const* y, Nd4jLong const* yShapeInfo, + void* z, Nd4jLong const* zShapeInfo, int* dimension, int dimensionLength, + Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, + Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) { + DISPATCH_BY_OPNUM_TTT( + intermediateBroadcast, + PARAMS(launchDims, stream, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, + dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, + tadOnlyShapeInfoZ, tadOffsetsZ), + OPS_A(BROADCAST_OPS)) + + DEBUG_KERNEL(stream, opNum); +} - const X* x = reinterpret_cast(vx); - const Y* y = reinterpret_cast(vy); - Z* z = reinterpret_cast(vz); +template +__host__ void Broadcast::execBroadcast( + dim3 launchDims, cudaStream_t* stream, const int opNum, const void* x, + const Nd4jLong* xShapeInfo, const void* y, const Nd4jLong* yShapeInfo, + void* z, const Nd4jLong const* zShapeInfo) { + DISPATCH_BY_OPNUM_TTT( + intermediateBroadcast, + PARAMS(launchDims, stream, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo), + OPS_A(BROADCAST_OPS)) + + DEBUG_KERNEL(stream, opNum); +} - __shared__ Nd4jLong zLen; - __shared__ int rank; - __shared__ bool xzSameOffsets, yzSameOffsets; +template +template +__host__ void Broadcast::intermediateInverseBroadcast( + dim3 launchDims, cudaStream_t* stream, void const* x, + Nd4jLong const* xShapeInfo, void const* y, Nd4jLong const* yShapeInfo, + void* z, Nd4jLong const* zShapeInfo, int* dimension, int dimensionLength, + Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, + Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) { + broadcastInverseSimple + <<>>( + x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, + dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, + tadOffsetsZ); +} - if (threadIdx.x == 0) { +template +__host__ void Broadcast::execInverseBroadcast( + dim3 launchDims, cudaStream_t* stream, int opNum, void const* x, + Nd4jLong const* xShapeInfo, void const* y, Nd4jLong const* yShapeInfo, + void* z, Nd4jLong const* zShapeInfo, int* dimension, int dimensionLength, + Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, + Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) { + DISPATCH_BY_OPNUM_TTT( + intermediateInverseBroadcast, + PARAMS(launchDims, stream, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, + dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, + tadOnlyShapeInfoZ, tadOffsetsZ), + OPS_A(BROADCAST_OPS)) + + DEBUG_KERNEL(stream, opNum); +} - zLen = shape::length(zShapeInfo); - rank = shape::rank(zShapeInfo); +template +template +__device__ void Broadcast::transformInverseCuda( + void const* vx, Nd4jLong const* xShapeInfo, void const* vy, + Nd4jLong const* yShapeInfo, void* vz, Nd4jLong const* zShapeInfo, + int* dimension, int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, + Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, + Nd4jLong const* tadOffsetsZ) { + if (tadOnlyShapeInfoZ == nullptr) { + tadOnlyShapeInfoZ = tadOnlyShapeInfo; + tadOffsetsZ = tadOffsets; + } + + auto x = reinterpret_cast(vx); + auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + + // decompose in to several sub tads after + // moving all dimensions (in sorted order) + // to the back. + // permuted version of the x shape info for setting up the tad problem + __shared__ Nd4jLong tadLength; + __shared__ Nd4jLong tadEWS; + __shared__ int numTads; + __shared__ Nd4jLong xEWS; + __shared__ Nd4jLong zEWS; + + if (threadIdx.x == 0) { + tadLength = length(tadOnlyShapeInfo); + tadEWS = shape::elementWiseStride(tadOnlyShapeInfo); + numTads = length(yShapeInfo) / tadLength; + xEWS = shape::elementWiseStride(xShapeInfo); + zEWS = shape::elementWiseStride(tadOnlyShapeInfoZ); + } + __syncthreads(); + + auto xOrder = shape::order(xShapeInfo); + auto yOrder = shape::order(tadOnlyShapeInfo); + auto zOrder = shape::order(tadOnlyShapeInfoZ); + + for (int r = blockIdx.x; r < numTads; r += gridDim.x) { + auto rY = y + tadOffsets[r]; + auto rZ = z + tadOffsetsZ[r]; + + if (tadEWS > 0 && zEWS > 0 && xEWS > 0 && dimensionLength == 1 && + xOrder == yOrder && xOrder == zOrder) { + for (int i = threadIdx.x; i < tadLength; i += blockDim.x) + rZ[i * zEWS] = OpType::op(x[i * xEWS], rY[i * tadEWS]); + } else { + // it is expected that x and z tads and y array all have the same length + for (Nd4jLong i = threadIdx.x; i < tadLength; i += blockDim.x) { + auto xOffset = getIndexOffset(i, xShapeInfo); + auto yOffset = getIndexOffset(i, tadOnlyShapeInfo); + auto zOffset = getIndexOffset(i, tadOnlyShapeInfoZ); + rZ[zOffset] = OpType::op(x[xOffset], rY[yOffset]); + } + } + } +} - xzSameOffsets = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo); - yzSameOffsets = shape::haveSameShapeAndStrides(yShapeInfo, zShapeInfo); +template +template +__device__ void Broadcast::transformCuda( + void const* vx, Nd4jLong const* xShapeInfo, void const* vy, + Nd4jLong const* yShapeInfo, void* vz, Nd4jLong const* zShapeInfo, + int* dimension, int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, + Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, + Nd4jLong const* tadOffsetsZ) { + if (tadOnlyShapeInfoZ == nullptr) { + tadOnlyShapeInfoZ = tadOnlyShapeInfo; + tadOffsetsZ = tadOffsets; + } + + auto x = reinterpret_cast(vx); + auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + + // decompose in to several sub tads after + // moving all dimensions (in sorted order) + // to the back. + // permuted version of the x shape info for setting up the tad problem + __shared__ Nd4jLong tadLength; + __shared__ Nd4jLong tadEWS; + __shared__ int numTads; + __shared__ Nd4jLong yEWS; + __shared__ Nd4jLong zEWS; + + if (threadIdx.x == 0) { + tadLength = length(tadOnlyShapeInfo); + tadEWS = shape::elementWiseStride(tadOnlyShapeInfo); + numTads = length(xShapeInfo) / tadLength; + yEWS = shape::elementWiseStride(yShapeInfo); + zEWS = shape::elementWiseStride(tadOnlyShapeInfoZ); + } + __syncthreads(); + + auto xOrder = shape::order(tadOnlyShapeInfo); + auto yOrder = shape::order(yShapeInfo); + auto zOrder = shape::order(tadOnlyShapeInfoZ); + + for (int r = blockIdx.x; r < numTads; r += gridDim.x) { + auto rX = x + tadOffsets[r]; + auto rZ = z + tadOffsetsZ[r]; + + if (tadEWS > 0 && zEWS > 0 && yEWS > 0 && xOrder == yOrder && + xOrder == zOrder) { + for (int i = threadIdx.x; i < tadLength; i += blockDim.x) + rZ[i * zEWS] = OpType::op(rX[i * tadEWS], y[i * yEWS]); + } else { + // it is expected that x and z tads and y array all have the same length + for (Nd4jLong i = threadIdx.x; i < tadLength; i += blockDim.x) { + auto xOffset = getIndexOffset(i, tadOnlyShapeInfo); + auto yOffset = getIndexOffset(i, yShapeInfo); + auto zOffset = getIndexOffset(i, tadOnlyShapeInfoZ); + rZ[zOffset] = OpType::op(rX[xOffset], y[yOffset]); + } } - __syncthreads(); + } +} +//////////////////////////////////////////////////////////////////////// +template +template +__device__ void Broadcast::transformCuda( + const void* vx, const Nd4jLong* xShapeInfo, const void* vy, + const Nd4jLong* yShapeInfo, void* vz, const Nd4jLong* zShapeInfo) { + const X* x = reinterpret_cast(vx); + const Y* y = reinterpret_cast(vy); + Z* z = reinterpret_cast(vz); - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + __shared__ Nd4jLong zLen; + __shared__ int rank; + __shared__ bool xzSameOffsets, yzSameOffsets; - int xCoords[MAX_RANK], yCoords[MAX_RANK], zCoords[MAX_RANK]; + if (threadIdx.x == 0) { + zLen = shape::length(zShapeInfo); + rank = shape::rank(zShapeInfo); - for (int i = tid; i < zLen; i += blockDim.x * gridDim.x) { + xzSameOffsets = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo); + yzSameOffsets = shape::haveSameShapeAndStrides(yShapeInfo, zShapeInfo); + } + __syncthreads(); - shape::index2coords(i, zShapeInfo, zCoords); + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - for (uint j = 0; j < rank; ++j) { - xCoords[j] = shape::sizeAt(xShapeInfo, j) == 1 ? 0 : zCoords[j]; - yCoords[j] = shape::sizeAt(yShapeInfo, j) == 1 ? 0 : zCoords[j]; - } + int xCoords[MAX_RANK], yCoords[MAX_RANK], zCoords[MAX_RANK]; - const auto zOffset = shape::getOffset(zShapeInfo, zCoords); - const auto xOffset = xzSameOffsets ? zOffset : shape::getOffset(xShapeInfo, xCoords); - const auto yOffset = yzSameOffsets ? zOffset : shape::getOffset(yShapeInfo, yCoords); + for (int i = tid; i < zLen; i += blockDim.x * gridDim.x) { + shape::index2coords(i, zShapeInfo, zCoords); - z[zOffset] = OpType::op(x[xOffset], y[yOffset]); + for (uint j = 0; j < rank; ++j) { + xCoords[j] = shape::sizeAt(xShapeInfo, j) == 1 ? 0 : zCoords[j]; + yCoords[j] = shape::sizeAt(yShapeInfo, j) == 1 ? 0 : zCoords[j]; } + + const auto zOffset = shape::getOffset(zShapeInfo, zCoords); + const auto xOffset = + xzSameOffsets ? zOffset : shape::getOffset(xShapeInfo, xCoords); + const auto yOffset = + yzSameOffsets ? zOffset : shape::getOffset(yShapeInfo, yCoords); + + z[zOffset] = OpType::op(x[xOffset], y[yOffset]); + } } /* - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT Broadcast, , PAIRWISE_TYPES_0); - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT Broadcast, , PAIRWISE_TYPES_1); - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT Broadcast, , PAIRWISE_TYPES_2); - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT Broadcast, , PAIRWISE_TYPES_3); - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT Broadcast, , PAIRWISE_TYPES_4); - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT Broadcast, , PAIRWISE_TYPES_5); - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT Broadcast, , PAIRWISE_TYPES_6); - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT Broadcast, , PAIRWISE_TYPES_7); - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT Broadcast, , PAIRWISE_TYPES_8); - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT Broadcast, , PAIRWISE_TYPES_9); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , + PAIRWISE_TYPES_0); BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT + Broadcast, , PAIRWISE_TYPES_1); BUILD_PAIRWISE_TEMPLATE(template class + SD_EXPORT Broadcast, , PAIRWISE_TYPES_2); BUILD_PAIRWISE_TEMPLATE(template + class SD_EXPORT Broadcast, , PAIRWISE_TYPES_3); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , + PAIRWISE_TYPES_4); BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT + Broadcast, , PAIRWISE_TYPES_5); BUILD_PAIRWISE_TEMPLATE(template class + SD_EXPORT Broadcast, , PAIRWISE_TYPES_6); BUILD_PAIRWISE_TEMPLATE(template + class SD_EXPORT Broadcast, , PAIRWISE_TYPES_7); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , + PAIRWISE_TYPES_8); BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT + Broadcast, , PAIRWISE_TYPES_9); */ - } -} \ No newline at end of file +} // namespace broadcast +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/broadcasting.cu b/libnd4j/include/loops/cuda/broadcasting.cu index 55c882c3f81d..dc40f2b3b7b3 100644 --- a/libnd4j/include/loops/cuda/broadcasting.cu +++ b/libnd4j/include/loops/cuda/broadcasting.cu @@ -18,20 +18,19 @@ // @author raver119@gmail.com // -#include -#include -#include -#include -#include #include #include -#include -#include #include +#include +#include #include +#include +#include +#include -namespace functions { - namespace broadcast { +#include +#include - } -} \ No newline at end of file +namespace functions { +namespace broadcast {} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/broadcasting_bool.cu b/libnd4j/include/loops/cuda/broadcasting_bool.cu index bed00a20f38e..2f6a08b4b9c4 100644 --- a/libnd4j/include/loops/cuda/broadcasting_bool.cu +++ b/libnd4j/include/loops/cuda/broadcasting_bool.cu @@ -18,303 +18,337 @@ // @author raver119@gmail.com // -#include +#include +#include +#include #include #include -#include #include -#include -#include -#include +#include +#include + #include -#include +#include using namespace simdOps; ////////////////////////////////////////////////////////////////////////// -template +template static __global__ void broadcastBoolSimple( - void const* x, - Nd4jLong const* xShapeInfo, - void const* y, - Nd4jLong const* yShapeInfo, - void *z, - Nd4jLong const* zShapeInfo, - void *extraParams, - int *dimension, - int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) { - - functions::broadcast::BroadcastBool::template transformCuda(x,xShapeInfo,y,yShapeInfo,z,zShapeInfo, extraParams, dimension,dimensionLength,tadOnlyShapeInfo,tadOffsets,tadOnlyShapeInfoZ,tadOffsetsZ); + void const* x, Nd4jLong const* xShapeInfo, void const* y, + Nd4jLong const* yShapeInfo, void* z, Nd4jLong const* zShapeInfo, + void* extraParams, int* dimension, int dimensionLength, + Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, + Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) { + functions::broadcast::BroadcastBool::template transformCuda( + x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams, dimension, + dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, + tadOffsetsZ); } ////////////////////////////////////////////////////////////////////////// -template -static __global__ void broadcastBoolSimple(const void const* x, const Nd4jLong const* xShapeInfo, - const void const* y, const Nd4jLong const* yShapeInfo, - void *z, const Nd4jLong const* zShapeInfo, - void *extraParams) { - - functions::broadcast::BroadcastBool::template transformCuda(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams); +template +static __global__ void broadcastBoolSimple( + const void const* x, const Nd4jLong const* xShapeInfo, const void const* y, + const Nd4jLong const* yShapeInfo, void* z, const Nd4jLong const* zShapeInfo, + void* extraParams) { + functions::broadcast::BroadcastBool::template transformCuda( + x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams); } ////////////////////////////////////////////////////////////////////////// -template +template static __global__ void broadcastBoolInverseSimple( - void const* x, - Nd4jLong const* xShapeInfo, - void const* y, - Nd4jLong const* yShapeInfo, - void *z, - Nd4jLong const* zShapeInfo, - void *extraParams, - int *dimension, - int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) { - - functions::broadcast::BroadcastBool::template transformInverseCuda(x,xShapeInfo,y,yShapeInfo,z,zShapeInfo,extraParams,dimension,dimensionLength,tadOnlyShapeInfo,tadOffsets,tadOnlyShapeInfoZ,tadOffsetsZ); + void const* x, Nd4jLong const* xShapeInfo, void const* y, + Nd4jLong const* yShapeInfo, void* z, Nd4jLong const* zShapeInfo, + void* extraParams, int* dimension, int dimensionLength, + Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, + Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) { + functions::broadcast::BroadcastBool::template transformInverseCuda< + OpClass>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams, + dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, + tadOnlyShapeInfoZ, tadOffsetsZ); } namespace functions { namespace broadcast { ////////////////////////////////////////////////////////////////////////// -template +template template -__host__ void BroadcastBool::intermediateBroadcast(dim3 launchDims, cudaStream_t *stream, void const* x, Nd4jLong const* xShapeInfo, void const* y, Nd4jLong const* yShapeInfo, void* z, Nd4jLong const* zShapeInfo, void *extraParams, int *dimension, int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) { - broadcastBoolSimple<<>>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ); - sd::DebugHelper::checkErrorCode(stream, "intermediateBroadcastBool(...) failed"); +__host__ void BroadcastBool::intermediateBroadcast( + dim3 launchDims, cudaStream_t* stream, void const* x, + Nd4jLong const* xShapeInfo, void const* y, Nd4jLong const* yShapeInfo, + void* z, Nd4jLong const* zShapeInfo, void* extraParams, int* dimension, + int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, + Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, + Nd4jLong const* tadOffsetsZ) { + broadcastBoolSimple + <<>>( + x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams, dimension, + dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, + tadOffsetsZ); + sd::DebugHelper::checkErrorCode(stream, + "intermediateBroadcastBool(...) failed"); } ////////////////////////////////////////////////////////////////////////// -template +template template -__host__ void BroadcastBool::intermediateBroadcast(dim3 launchDims, cudaStream_t *stream, - const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *z, const Nd4jLong *zShapeInfo, - void *extraParams) { - - broadcastBoolSimple<<>>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams); - sd::DebugHelper::checkErrorCode(stream, "intermediateBroadcastBool(...) failed"); +__host__ void BroadcastBool::intermediateBroadcast( + dim3 launchDims, cudaStream_t* stream, const void* x, + const Nd4jLong* xShapeInfo, const void* y, const Nd4jLong* yShapeInfo, + void* z, const Nd4jLong* zShapeInfo, void* extraParams) { + broadcastBoolSimple + <<>>( + x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams); + sd::DebugHelper::checkErrorCode(stream, + "intermediateBroadcastBool(...) failed"); } ////////////////////////////////////////////////////////////////////////// -template -__host__ void BroadcastBool::execBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void const* x, Nd4jLong const* xShapeInfo, void const* y, Nd4jLong const* yShapeInfo, void *z, Nd4jLong const* zShapeInfo, void *extraParams, int *dimension, int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) { - - DISPATCH_BY_OPNUM_TT(intermediateBroadcast, PARAMS(launchDims, stream, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), OPS_A(BROADCAST_BOOL_OPS)) - DEBUG_KERNEL(stream, opNum); +template +__host__ void BroadcastBool::execBroadcast( + dim3 launchDims, cudaStream_t* stream, int opNum, void const* x, + Nd4jLong const* xShapeInfo, void const* y, Nd4jLong const* yShapeInfo, + void* z, Nd4jLong const* zShapeInfo, void* extraParams, int* dimension, + int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, + Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, + Nd4jLong const* tadOffsetsZ) { + DISPATCH_BY_OPNUM_TT( + intermediateBroadcast, + PARAMS(launchDims, stream, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, + extraParams, dimension, dimensionLength, tadOnlyShapeInfo, + tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), + OPS_A(BROADCAST_BOOL_OPS)) + DEBUG_KERNEL(stream, opNum); } ////////////////////////////////////////////////////////////////////////// -template -__host__ void BroadcastBool::execBroadcast(dim3 launchDims, cudaStream_t *stream, const int opNum, - const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *z, const Nd4jLong *zShapeInfo, - void *extraParams) { - - DISPATCH_BY_OPNUM_TT(intermediateBroadcast, PARAMS(launchDims, stream, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams), OPS_A(BROADCAST_BOOL_OPS)) - DEBUG_KERNEL(stream, opNum); +template +__host__ void BroadcastBool::execBroadcast( + dim3 launchDims, cudaStream_t* stream, const int opNum, const void* x, + const Nd4jLong* xShapeInfo, const void* y, const Nd4jLong* yShapeInfo, + void* z, const Nd4jLong* zShapeInfo, void* extraParams) { + DISPATCH_BY_OPNUM_TT(intermediateBroadcast, + PARAMS(launchDims, stream, x, xShapeInfo, y, yShapeInfo, + z, zShapeInfo, extraParams), + OPS_A(BROADCAST_BOOL_OPS)) + DEBUG_KERNEL(stream, opNum); } ////////////////////////////////////////////////////////////////////////// - template - template - __host__ void BroadcastBool::intermediateInverseBroadcast(dim3 launchDims, cudaStream_t *stream, void const* x, Nd4jLong const* xShapeInfo, void const* y, Nd4jLong const* yShapeInfo, void *z, Nd4jLong const* zShapeInfo, void *extraParams, int *dimension, int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) { - broadcastBoolInverseSimple<<>>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ); - sd::DebugHelper::checkErrorCode(stream, "intermediateBroadcastBool(...) failed"); - } - -////////////////////////////////////////////////////////////////////////// - template - __host__ void BroadcastBool::execInverseBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void const* x, Nd4jLong const* xShapeInfo, void const* y, Nd4jLong const* yShapeInfo, void *z, Nd4jLong const* zShapeInfo, void *extraParams, int *dimension, int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) { - DISPATCH_BY_OPNUM_TT(intermediateInverseBroadcast, PARAMS(launchDims, stream, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), OPS_A(BROADCAST_BOOL_OPS)) - - DEBUG_KERNEL(stream, opNum); - } +template +template +__host__ void BroadcastBool::intermediateInverseBroadcast( + dim3 launchDims, cudaStream_t* stream, void const* x, + Nd4jLong const* xShapeInfo, void const* y, Nd4jLong const* yShapeInfo, + void* z, Nd4jLong const* zShapeInfo, void* extraParams, int* dimension, + int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, + Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, + Nd4jLong const* tadOffsetsZ) { + broadcastBoolInverseSimple + <<>>( + x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams, dimension, + dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, + tadOffsetsZ); + sd::DebugHelper::checkErrorCode(stream, + "intermediateBroadcastBool(...) failed"); +} ////////////////////////////////////////////////////////////////////////// - template - template - __device__ void BroadcastBool::transformInverseCuda( - void const* vx, Nd4jLong const* xShapeInfo, - void const* vy, Nd4jLong const* yShapeInfo, - void *vz, Nd4jLong const* zShapeInfo, - void *vextraParams, - int *dimension, int dimensionLength, - Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) { - - if (tadOnlyShapeInfoZ == nullptr) { - tadOnlyShapeInfoZ = tadOnlyShapeInfo; - tadOffsetsZ = tadOffsets; - } - - auto x = reinterpret_cast(vx); - auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); - auto extraParams = reinterpret_cast(vextraParams); - - //decompose in to several sub tads after - //moving all dimensions (in sorted order) - //to the back. - //permuted version of the x shape info for setting up the tad problem - __shared__ Nd4jLong tadLength; - __shared__ Nd4jLong tadEWS; - __shared__ int numTads; - __shared__ Nd4jLong xEWS; - __shared__ Nd4jLong zEWS; - - if (threadIdx.x == 0) { - tadLength = shape::length(tadOnlyShapeInfo);//shape::tadLength(xShapeInfo, dimension, dimensionLength); - tadEWS = shape::elementWiseStride(tadOnlyShapeInfo); - numTads = shape::length(yShapeInfo) / tadLength; - xEWS = shape::elementWiseStride(xShapeInfo); - zEWS = shape::elementWiseStride(tadOnlyShapeInfoZ); - } - __syncthreads(); - - for (int r = blockIdx.x; r < numTads; r += gridDim.x) { - auto rZ = z + tadOffsetsZ[r]; - auto rY = y + tadOffsets[r]; - - if(tadEWS > 0 && zEWS > 0 && xEWS > 0 && dimensionLength == 1) { - - for (int i = threadIdx.x; i < tadLength; i+= blockDim.x) - rZ[i * zEWS] = OpType::op(x[i * xEWS], rY[i * tadEWS], extraParams); - } - else { - // it is expected that x and z tads and y array all have the same length - for (Nd4jLong i = threadIdx.x; i < tadLength; i+= blockDim.x) { - auto xOffset = shape::getIndexOffset(i, xShapeInfo); - auto yOffset = shape::getIndexOffset(i, tadOnlyShapeInfo); - auto zOffset = shape::getIndexOffset(i, tadOnlyShapeInfoZ); - - rZ[zOffset] = OpType::op(x[xOffset], rY[yOffset], extraParams); - } - } - } - } +template +__host__ void BroadcastBool::execInverseBroadcast( + dim3 launchDims, cudaStream_t* stream, int opNum, void const* x, + Nd4jLong const* xShapeInfo, void const* y, Nd4jLong const* yShapeInfo, + void* z, Nd4jLong const* zShapeInfo, void* extraParams, int* dimension, + int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, + Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, + Nd4jLong const* tadOffsetsZ) { + DISPATCH_BY_OPNUM_TT( + intermediateInverseBroadcast, + PARAMS(launchDims, stream, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, + extraParams, dimension, dimensionLength, tadOnlyShapeInfo, + tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), + OPS_A(BROADCAST_BOOL_OPS)) + + DEBUG_KERNEL(stream, opNum); +} ////////////////////////////////////////////////////////////////////////// - template - template - __device__ void BroadcastBool::transformCuda( - void const* vx, Nd4jLong const* xShapeInfo, - void const* vy, Nd4jLong const* yShapeInfo, - void *vz, Nd4jLong const* zShapeInfo, - void *vextraParams, - int *dimension, int dimensionLength, - Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) { - - if (tadOnlyShapeInfoZ == nullptr) { - tadOnlyShapeInfoZ = tadOnlyShapeInfo; - tadOffsetsZ = tadOffsets; - } - - auto x = reinterpret_cast(vx); - auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); - auto extraParams = reinterpret_cast(vextraParams); - - //decompose in to several sub tads after - //moving all dimensions (in sorted order) - //to the back. - //permuted version of the x shape info for setting up the tad problem - __shared__ Nd4jLong tadLength; - __shared__ Nd4jLong tadEWS; - __shared__ int numTads; - __shared__ Nd4jLong yEWS; - __shared__ Nd4jLong zEWS; - - if (threadIdx.x == 0) { - tadLength = shape::length(tadOnlyShapeInfo);//shape::tadLength(xShapeInfo, dimension, dimensionLength); - tadEWS = shape::elementWiseStride(tadOnlyShapeInfo); - numTads = shape::length(xShapeInfo) / tadLength; - yEWS = shape::elementWiseStride(yShapeInfo); - zEWS = shape::elementWiseStride(tadOnlyShapeInfoZ); - } - __syncthreads(); - - __shared__ Z *rZ; - __shared__ X const* rX; - - for (int r = blockIdx.x; r < numTads; r += gridDim.x) { - - if (threadIdx.x == 0) { - rZ = z + tadOffsetsZ[r]; - rX = x + tadOffsets[r]; - } - __syncthreads(); - - - if(tadEWS > 0 && zEWS > 0 && yEWS > 0 && dimensionLength == 1) { - - for (int i = threadIdx.x; i < tadLength; i+= blockDim.x) - rZ[i * zEWS] = OpType::op(rX[i * tadEWS], y[i * yEWS], extraParams); - } - else { - // it is expected that x and z tads and y array all have the same length - for (Nd4jLong i = threadIdx.x; i < tadLength; i+= blockDim.x) { - auto xOffset = shape::getIndexOffset(i, tadOnlyShapeInfo); - auto yOffset = shape::getIndexOffset(i, yShapeInfo); - auto zOffset = shape::getIndexOffset(i, tadOnlyShapeInfoZ); - - rZ[zOffset] = OpType::op(rX[xOffset], y[yOffset], extraParams); - } - } - } - } +template +template +__device__ void BroadcastBool::transformInverseCuda( + void const* vx, Nd4jLong const* xShapeInfo, void const* vy, + Nd4jLong const* yShapeInfo, void* vz, Nd4jLong const* zShapeInfo, + void* vextraParams, int* dimension, int dimensionLength, + Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, + Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) { + if (tadOnlyShapeInfoZ == nullptr) { + tadOnlyShapeInfoZ = tadOnlyShapeInfo; + tadOffsetsZ = tadOffsets; + } + + auto x = reinterpret_cast(vx); + auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + auto extraParams = reinterpret_cast(vextraParams); + + // decompose in to several sub tads after + // moving all dimensions (in sorted order) + // to the back. + // permuted version of the x shape info for setting up the tad problem + __shared__ Nd4jLong tadLength; + __shared__ Nd4jLong tadEWS; + __shared__ int numTads; + __shared__ Nd4jLong xEWS; + __shared__ Nd4jLong zEWS; + + if (threadIdx.x == 0) { + tadLength = + shape::length(tadOnlyShapeInfo); // shape::tadLength(xShapeInfo, + // dimension, dimensionLength); + tadEWS = shape::elementWiseStride(tadOnlyShapeInfo); + numTads = shape::length(yShapeInfo) / tadLength; + xEWS = shape::elementWiseStride(xShapeInfo); + zEWS = shape::elementWiseStride(tadOnlyShapeInfoZ); + } + __syncthreads(); + + for (int r = blockIdx.x; r < numTads; r += gridDim.x) { + auto rZ = z + tadOffsetsZ[r]; + auto rY = y + tadOffsets[r]; + + if (tadEWS > 0 && zEWS > 0 && xEWS > 0 && dimensionLength == 1) { + for (int i = threadIdx.x; i < tadLength; i += blockDim.x) + rZ[i * zEWS] = OpType::op(x[i * xEWS], rY[i * tadEWS], extraParams); + } else { + // it is expected that x and z tads and y array all have the same length + for (Nd4jLong i = threadIdx.x; i < tadLength; i += blockDim.x) { + auto xOffset = shape::getIndexOffset(i, xShapeInfo); + auto yOffset = shape::getIndexOffset(i, tadOnlyShapeInfo); + auto zOffset = shape::getIndexOffset(i, tadOnlyShapeInfoZ); + + rZ[zOffset] = OpType::op(x[xOffset], rY[yOffset], extraParams); + } + } + } +} ////////////////////////////////////////////////////////////////////////// -template +template template -__device__ void BroadcastBool::transformCuda(const void *vx, const Nd4jLong *xShapeInfo, - const void *vy, const Nd4jLong *yShapeInfo, - void *vz, const Nd4jLong *zShapeInfo, - void *vextraParams) { - - const X* x = reinterpret_cast(vx); - const X* y = reinterpret_cast(vy); - Z* z = reinterpret_cast(vz); - - auto extraParams = reinterpret_cast(vextraParams); - - __shared__ Nd4jLong zLen; - __shared__ int rank; - __shared__ bool xzSameOffsets, yzSameOffsets; - +__device__ void BroadcastBool::transformCuda( + void const* vx, Nd4jLong const* xShapeInfo, void const* vy, + Nd4jLong const* yShapeInfo, void* vz, Nd4jLong const* zShapeInfo, + void* vextraParams, int* dimension, int dimensionLength, + Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, + Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) { + if (tadOnlyShapeInfoZ == nullptr) { + tadOnlyShapeInfoZ = tadOnlyShapeInfo; + tadOffsetsZ = tadOffsets; + } + + auto x = reinterpret_cast(vx); + auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + auto extraParams = reinterpret_cast(vextraParams); + + // decompose in to several sub tads after + // moving all dimensions (in sorted order) + // to the back. + // permuted version of the x shape info for setting up the tad problem + __shared__ Nd4jLong tadLength; + __shared__ Nd4jLong tadEWS; + __shared__ int numTads; + __shared__ Nd4jLong yEWS; + __shared__ Nd4jLong zEWS; + + if (threadIdx.x == 0) { + tadLength = + shape::length(tadOnlyShapeInfo); // shape::tadLength(xShapeInfo, + // dimension, dimensionLength); + tadEWS = shape::elementWiseStride(tadOnlyShapeInfo); + numTads = shape::length(xShapeInfo) / tadLength; + yEWS = shape::elementWiseStride(yShapeInfo); + zEWS = shape::elementWiseStride(tadOnlyShapeInfoZ); + } + __syncthreads(); + + __shared__ Z* rZ; + __shared__ X const* rX; + + for (int r = blockIdx.x; r < numTads; r += gridDim.x) { if (threadIdx.x == 0) { - - zLen = shape::length(zShapeInfo); - rank = shape::rank(zShapeInfo); - - xzSameOffsets = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo); - yzSameOffsets = shape::haveSameShapeAndStrides(yShapeInfo, zShapeInfo); + rZ = z + tadOffsetsZ[r]; + rX = x + tadOffsets[r]; } __syncthreads(); + if (tadEWS > 0 && zEWS > 0 && yEWS > 0 && dimensionLength == 1) { + for (int i = threadIdx.x; i < tadLength; i += blockDim.x) + rZ[i * zEWS] = OpType::op(rX[i * tadEWS], y[i * yEWS], extraParams); + } else { + // it is expected that x and z tads and y array all have the same length + for (Nd4jLong i = threadIdx.x; i < tadLength; i += blockDim.x) { + auto xOffset = shape::getIndexOffset(i, tadOnlyShapeInfo); + auto yOffset = shape::getIndexOffset(i, yShapeInfo); + auto zOffset = shape::getIndexOffset(i, tadOnlyShapeInfoZ); + + rZ[zOffset] = OpType::op(rX[xOffset], y[yOffset], extraParams); + } + } + } +} - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; +////////////////////////////////////////////////////////////////////////// +template +template +__device__ void BroadcastBool::transformCuda( + const void* vx, const Nd4jLong* xShapeInfo, const void* vy, + const Nd4jLong* yShapeInfo, void* vz, const Nd4jLong* zShapeInfo, + void* vextraParams) { + const X* x = reinterpret_cast(vx); + const X* y = reinterpret_cast(vy); + Z* z = reinterpret_cast(vz); - int xCoords[MAX_RANK], yCoords[MAX_RANK], zCoords[MAX_RANK]; + auto extraParams = reinterpret_cast(vextraParams); - for (int i = tid; i < zLen; i += blockDim.x * gridDim.x) { + __shared__ Nd4jLong zLen; + __shared__ int rank; + __shared__ bool xzSameOffsets, yzSameOffsets; - shape::index2coords(i, zShapeInfo, zCoords); + if (threadIdx.x == 0) { + zLen = shape::length(zShapeInfo); + rank = shape::rank(zShapeInfo); - for (uint j = 0; j < rank; ++j) { - xCoords[j] = shape::sizeAt(xShapeInfo, j) == 1 ? 0 : zCoords[j]; - yCoords[j] = shape::sizeAt(yShapeInfo, j) == 1 ? 0 : zCoords[j]; - } + xzSameOffsets = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo); + yzSameOffsets = shape::haveSameShapeAndStrides(yShapeInfo, zShapeInfo); + } + __syncthreads(); - const auto zOffset = shape::getOffset(zShapeInfo, zCoords); - const auto xOffset = xzSameOffsets ? zOffset : shape::getOffset(xShapeInfo, xCoords); - const auto yOffset = yzSameOffsets ? zOffset : shape::getOffset(yShapeInfo, yCoords); + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - z[zOffset] = OpType::op(x[xOffset], y[yOffset], extraParams); + int xCoords[MAX_RANK], yCoords[MAX_RANK], zCoords[MAX_RANK]; + + for (int i = tid; i < zLen; i += blockDim.x * gridDim.x) { + shape::index2coords(i, zShapeInfo, zCoords); + + for (uint j = 0; j < rank; ++j) { + xCoords[j] = shape::sizeAt(xShapeInfo, j) == 1 ? 0 : zCoords[j]; + yCoords[j] = shape::sizeAt(yShapeInfo, j) == 1 ? 0 : zCoords[j]; } -} + const auto zOffset = shape::getOffset(zShapeInfo, zCoords); + const auto xOffset = + xzSameOffsets ? zOffset : shape::getOffset(xShapeInfo, xCoords); + const auto yOffset = + yzSameOffsets ? zOffset : shape::getOffset(yShapeInfo, yCoords); -BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT BroadcastBool, , LIBND4J_TYPES, BOOL_TYPES); + z[zOffset] = OpType::op(x[xOffset], y[yOffset], extraParams); + } } -} \ No newline at end of file + +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT BroadcastBool, , LIBND4J_TYPES, + BOOL_TYPES); +} // namespace broadcast +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/broadcasting_int.cu b/libnd4j/include/loops/cuda/broadcasting_int.cu index 37cbf3eba59f..a7e92732e9c1 100644 --- a/libnd4j/include/loops/cuda/broadcasting_int.cu +++ b/libnd4j/include/loops/cuda/broadcasting_int.cu @@ -18,283 +18,319 @@ // @author raver119@gmail.com // -#include +#include +#include +#include #include #include -#include #include -#include -#include -#include +#include +#include + #include -#include +#include using namespace simdOps; ////////////////////////////////////////////////////////////////////////// -template +template static __global__ void broadcastIntSimple( - void const* x, - Nd4jLong const* xShapeInfo, - void const* y, - Nd4jLong const* yShapeInfo, - void *z, - Nd4jLong const* zShapeInfo, - int *dimension, - int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) { - - functions::broadcast::BroadcastInt::template transformCuda(x,xShapeInfo,y,yShapeInfo,z,zShapeInfo,dimension,dimensionLength,tadOnlyShapeInfo,tadOffsets,tadOnlyShapeInfoZ,tadOffsetsZ); + void const* x, Nd4jLong const* xShapeInfo, void const* y, + Nd4jLong const* yShapeInfo, void* z, Nd4jLong const* zShapeInfo, + int* dimension, int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, + Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, + Nd4jLong const* tadOffsetsZ) { + functions::broadcast::BroadcastInt::template transformCuda( + x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, + tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ); } ////////////////////////////////////////////////////////////////////////// -template -static __global__ void broadcastIntSimple(const void *x, const Nd4jLong const* xShapeInfo, - const void *y, const Nd4jLong const* yShapeInfo, - void *z, const Nd4jLong const* zShapeInfo) { - - functions::broadcast::BroadcastInt::template transformCuda(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo); +template +static __global__ void broadcastIntSimple(const void* x, + const Nd4jLong const* xShapeInfo, + const void* y, + const Nd4jLong const* yShapeInfo, + void* z, + const Nd4jLong const* zShapeInfo) { + functions::broadcast::BroadcastInt::template transformCuda( + x, xShapeInfo, y, yShapeInfo, z, zShapeInfo); } ////////////////////////////////////////////////////////////////////////// -template +template static __global__ void broadcastBoolInverseSimple( - void const* x, - Nd4jLong const* xShapeInfo, - void const* y, - Nd4jLong const* yShapeInfo, - void *z, - Nd4jLong const* zShapeInfo, - int *dimension, - int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) { - - functions::broadcast::BroadcastInt::template transformInverseCuda(x,xShapeInfo,y,yShapeInfo,z,zShapeInfo,dimension,dimensionLength,tadOnlyShapeInfo,tadOffsets,tadOnlyShapeInfoZ,tadOffsetsZ); + void const* x, Nd4jLong const* xShapeInfo, void const* y, + Nd4jLong const* yShapeInfo, void* z, Nd4jLong const* zShapeInfo, + int* dimension, int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, + Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, + Nd4jLong const* tadOffsetsZ) { + functions::broadcast::BroadcastInt::template transformInverseCuda( + x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, + tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ); } namespace functions { namespace broadcast { ////////////////////////////////////////////////////////////////////////// -template +template template -__host__ void BroadcastInt::intermediateBroadcast(dim3 launchDims, cudaStream_t *stream, void const* x, Nd4jLong const* xShapeInfo, void const* y, Nd4jLong const* yShapeInfo, void *z, Nd4jLong const* zShapeInfo, int *dimension, int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) { - broadcastIntSimple<<>>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ); +__host__ void BroadcastInt::intermediateBroadcast( + dim3 launchDims, cudaStream_t* stream, void const* x, + Nd4jLong const* xShapeInfo, void const* y, Nd4jLong const* yShapeInfo, + void* z, Nd4jLong const* zShapeInfo, int* dimension, int dimensionLength, + Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, + Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) { + broadcastIntSimple + <<>>( + x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, + dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, + tadOffsetsZ); } ////////////////////////////////////////////////////////////////////////// -template +template template -__host__ void BroadcastInt::intermediateBroadcast(dim3 launchDims, cudaStream_t *stream, - const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *z, const Nd4jLong *zShapeInfo) { - - broadcastIntSimple<<>>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo); +__host__ void BroadcastInt::intermediateBroadcast( + dim3 launchDims, cudaStream_t* stream, const void* x, + const Nd4jLong* xShapeInfo, const void* y, const Nd4jLong* yShapeInfo, + void* z, const Nd4jLong* zShapeInfo) { + broadcastIntSimple + <<>>( + x, xShapeInfo, y, yShapeInfo, z, zShapeInfo); } ////////////////////////////////////////////////////////////////////////// -template -__host__ void BroadcastInt::execBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void const* x, Nd4jLong const* xShapeInfo, void const* y, Nd4jLong const* yShapeInfo, void *z, Nd4jLong const* zShapeInfo, int *dimension, int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) { - DISPATCH_BY_OPNUM_T(intermediateBroadcast, PARAMS(launchDims, stream, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), OPS_A(BROADCAST_INT_OPS)) +template +__host__ void BroadcastInt::execBroadcast( + dim3 launchDims, cudaStream_t* stream, int opNum, void const* x, + Nd4jLong const* xShapeInfo, void const* y, Nd4jLong const* yShapeInfo, + void* z, Nd4jLong const* zShapeInfo, int* dimension, int dimensionLength, + Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, + Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) { + DISPATCH_BY_OPNUM_T( + intermediateBroadcast, + PARAMS(launchDims, stream, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, + dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, + tadOnlyShapeInfoZ, tadOffsetsZ), + OPS_A(BROADCAST_INT_OPS)) } ////////////////////////////////////////////////////////////////////////// -template -__host__ void BroadcastInt::execBroadcast(dim3 launchDims, cudaStream_t *stream, const int opNum, - const void *x, const Nd4jLong const* xShapeInfo, - const void *y, const Nd4jLong const* yShapeInfo, - void *z, const Nd4jLong const* zShapeInfo) { - - DISPATCH_BY_OPNUM_T(intermediateBroadcast, PARAMS(launchDims, stream, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo), OPS_A(BROADCAST_INT_OPS)) +template +__host__ void BroadcastInt::execBroadcast( + dim3 launchDims, cudaStream_t* stream, const int opNum, const void* x, + const Nd4jLong const* xShapeInfo, const void* y, + const Nd4jLong const* yShapeInfo, void* z, + const Nd4jLong const* zShapeInfo) { + DISPATCH_BY_OPNUM_T( + intermediateBroadcast, + PARAMS(launchDims, stream, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo), + OPS_A(BROADCAST_INT_OPS)) } ////////////////////////////////////////////////////////////////////////// - template - template - __host__ void BroadcastInt::intermediateInverseBroadcast(dim3 launchDims, cudaStream_t *stream, void const* x, Nd4jLong const* xShapeInfo, void const* y, Nd4jLong const* yShapeInfo, void *z, Nd4jLong const* zShapeInfo, int *dimension, int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) { - broadcastBoolInverseSimple<<>>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ); - } - -////////////////////////////////////////////////////////////////////////// - template - __host__ void BroadcastInt::execInverseBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void const* x, Nd4jLong const* xShapeInfo, void const* y, Nd4jLong const* yShapeInfo, void *z, Nd4jLong const* zShapeInfo, int *dimension, int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) { - DISPATCH_BY_OPNUM_T(intermediateInverseBroadcast, PARAMS(launchDims, stream, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), OPS_A(BROADCAST_INT_OPS)) - } +template +template +__host__ void BroadcastInt::intermediateInverseBroadcast( + dim3 launchDims, cudaStream_t* stream, void const* x, + Nd4jLong const* xShapeInfo, void const* y, Nd4jLong const* yShapeInfo, + void* z, Nd4jLong const* zShapeInfo, int* dimension, int dimensionLength, + Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, + Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) { + broadcastBoolInverseSimple + <<>>( + x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, + dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, + tadOffsetsZ); +} ////////////////////////////////////////////////////////////////////////// - template - template - __device__ void BroadcastInt::transformInverseCuda( - void const* vx, Nd4jLong const* xShapeInfo, - void const* vy, Nd4jLong const* yShapeInfo, - void *vz, Nd4jLong const* zShapeInfo, - int *dimension, int dimensionLength, - Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) { - - if (tadOnlyShapeInfoZ == nullptr) { - tadOnlyShapeInfoZ = tadOnlyShapeInfo; - tadOffsetsZ = tadOffsets; - } - - auto x = reinterpret_cast(vx); - auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); - - //decompose in to several sub tads after - //moving all dimensions (in sorted order) - //to the back. - //permuted version of the x shape info for setting up the tad problem - __shared__ Nd4jLong tadLength; - __shared__ Nd4jLong tadEWS; - __shared__ int numTads; - __shared__ Nd4jLong xEWS; - __shared__ Nd4jLong zEWS; - - if (threadIdx.x == 0) { - tadLength = shape::length(tadOnlyShapeInfo);//shape::tadLength(xShapeInfo, dimension, dimensionLength); - tadEWS = shape::elementWiseStride(tadOnlyShapeInfo); - numTads = shape::length(yShapeInfo) / tadLength; - xEWS = shape::elementWiseStride(xShapeInfo); - zEWS = shape::elementWiseStride(tadOnlyShapeInfoZ); - } - __syncthreads(); - - for (int r = blockIdx.x; r < numTads; r += gridDim.x) { - auto rZ = z + tadOffsetsZ[r]; - auto rY = y + tadOffsets[r]; - - if(tadEWS > 0 && zEWS > 0 && xEWS > 0 && dimensionLength == 1) { - - for (int i = threadIdx.x; i < tadLength; i+= blockDim.x) - rZ[i * zEWS] = OpType::op(x[i * xEWS], rY[i * tadEWS]); - } - else { - // it is expected that x and z tads and y array all have the same length - for (Nd4jLong i = threadIdx.x; i < tadLength; i+= blockDim.x) { - auto xOffset = shape::getIndexOffset(i, xShapeInfo); - auto yOffset = shape::getIndexOffset(i, tadOnlyShapeInfo); - auto zOffset = shape::getIndexOffset(i, tadOnlyShapeInfoZ); - - rZ[zOffset] = OpType::op(x[xOffset], rY[yOffset]); - } - } - } - } +template +__host__ void BroadcastInt::execInverseBroadcast( + dim3 launchDims, cudaStream_t* stream, int opNum, void const* x, + Nd4jLong const* xShapeInfo, void const* y, Nd4jLong const* yShapeInfo, + void* z, Nd4jLong const* zShapeInfo, int* dimension, int dimensionLength, + Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, + Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) { + DISPATCH_BY_OPNUM_T( + intermediateInverseBroadcast, + PARAMS(launchDims, stream, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, + dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, + tadOnlyShapeInfoZ, tadOffsetsZ), + OPS_A(BROADCAST_INT_OPS)) +} ////////////////////////////////////////////////////////////////////////// - template - template - __device__ void BroadcastInt::transformCuda( - void const* vx, Nd4jLong const* xShapeInfo, - void const* vy, Nd4jLong const* yShapeInfo, - void *vz, Nd4jLong const* zShapeInfo, - int *dimension, int dimensionLength, - Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) { - - if (tadOnlyShapeInfoZ == nullptr) { - tadOnlyShapeInfoZ = tadOnlyShapeInfo; - tadOffsetsZ = tadOffsets; - } - - auto x = reinterpret_cast(vx); - auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); - - //decompose in to several sub tads after - //moving all dimensions (in sorted order) - //to the back. - //permuted version of the x shape info for setting up the tad problem - __shared__ Nd4jLong tadLength; - __shared__ Nd4jLong tadEWS; - __shared__ int numTads; - __shared__ Nd4jLong yEWS; - __shared__ Nd4jLong zEWS; - - if (threadIdx.x == 0) { - tadLength = shape::length(tadOnlyShapeInfo);//shape::tadLength(xShapeInfo, dimension, dimensionLength); - tadEWS = shape::elementWiseStride(tadOnlyShapeInfo); - numTads = shape::length(xShapeInfo) / tadLength; - yEWS = shape::elementWiseStride(yShapeInfo); - zEWS = shape::elementWiseStride(tadOnlyShapeInfoZ); - } - __syncthreads(); - - __shared__ X *rZ; - __shared__ X const* rX; - - for (int r = blockIdx.x; r < numTads; r += gridDim.x) { - - if (threadIdx.x == 0) { - rZ = z + tadOffsetsZ[r]; - rX = x + tadOffsets[r]; - } - __syncthreads(); - - - if(tadEWS > 0 && zEWS > 0 && yEWS > 0 && dimensionLength == 1) { - - for (int i = threadIdx.x; i < tadLength; i+= blockDim.x) - rZ[i * zEWS] = OpType::op(rX[i * tadEWS], y[i * yEWS]); - } - else { - // it is expected that x and z tads and y array all have the same length - for (Nd4jLong i = threadIdx.x; i < tadLength; i+= blockDim.x) { - auto xOffset = shape::getIndexOffset(i, tadOnlyShapeInfo); - auto yOffset = shape::getIndexOffset(i, yShapeInfo); - auto zOffset = shape::getIndexOffset(i, tadOnlyShapeInfoZ); - - rZ[zOffset] = OpType::op(rX[xOffset], y[yOffset]); - } - } - } - } +template +template +__device__ void BroadcastInt::transformInverseCuda( + void const* vx, Nd4jLong const* xShapeInfo, void const* vy, + Nd4jLong const* yShapeInfo, void* vz, Nd4jLong const* zShapeInfo, + int* dimension, int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, + Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, + Nd4jLong const* tadOffsetsZ) { + if (tadOnlyShapeInfoZ == nullptr) { + tadOnlyShapeInfoZ = tadOnlyShapeInfo; + tadOffsetsZ = tadOffsets; + } + + auto x = reinterpret_cast(vx); + auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + + // decompose in to several sub tads after + // moving all dimensions (in sorted order) + // to the back. + // permuted version of the x shape info for setting up the tad problem + __shared__ Nd4jLong tadLength; + __shared__ Nd4jLong tadEWS; + __shared__ int numTads; + __shared__ Nd4jLong xEWS; + __shared__ Nd4jLong zEWS; + + if (threadIdx.x == 0) { + tadLength = + shape::length(tadOnlyShapeInfo); // shape::tadLength(xShapeInfo, + // dimension, dimensionLength); + tadEWS = shape::elementWiseStride(tadOnlyShapeInfo); + numTads = shape::length(yShapeInfo) / tadLength; + xEWS = shape::elementWiseStride(xShapeInfo); + zEWS = shape::elementWiseStride(tadOnlyShapeInfoZ); + } + __syncthreads(); + + for (int r = blockIdx.x; r < numTads; r += gridDim.x) { + auto rZ = z + tadOffsetsZ[r]; + auto rY = y + tadOffsets[r]; + + if (tadEWS > 0 && zEWS > 0 && xEWS > 0 && dimensionLength == 1) { + for (int i = threadIdx.x; i < tadLength; i += blockDim.x) + rZ[i * zEWS] = OpType::op(x[i * xEWS], rY[i * tadEWS]); + } else { + // it is expected that x and z tads and y array all have the same length + for (Nd4jLong i = threadIdx.x; i < tadLength; i += blockDim.x) { + auto xOffset = shape::getIndexOffset(i, xShapeInfo); + auto yOffset = shape::getIndexOffset(i, tadOnlyShapeInfo); + auto zOffset = shape::getIndexOffset(i, tadOnlyShapeInfoZ); + + rZ[zOffset] = OpType::op(x[xOffset], rY[yOffset]); + } + } + } +} ////////////////////////////////////////////////////////////////////////// -template +template template -__device__ void BroadcastInt::transformCuda(const void *vx, const Nd4jLong const* xShapeInfo, - const void *vy, const Nd4jLong const* yShapeInfo, - void *vz, const Nd4jLong const* zShapeInfo) { - - const X* x = reinterpret_cast(vx); - const X* y = reinterpret_cast(vy); - X* z = reinterpret_cast(vz); - - __shared__ Nd4jLong zLen; - __shared__ int rank; - __shared__ bool xzSameOffsets, yzSameOffsets; - +__device__ void BroadcastInt::transformCuda( + void const* vx, Nd4jLong const* xShapeInfo, void const* vy, + Nd4jLong const* yShapeInfo, void* vz, Nd4jLong const* zShapeInfo, + int* dimension, int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, + Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, + Nd4jLong const* tadOffsetsZ) { + if (tadOnlyShapeInfoZ == nullptr) { + tadOnlyShapeInfoZ = tadOnlyShapeInfo; + tadOffsetsZ = tadOffsets; + } + + auto x = reinterpret_cast(vx); + auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + + // decompose in to several sub tads after + // moving all dimensions (in sorted order) + // to the back. + // permuted version of the x shape info for setting up the tad problem + __shared__ Nd4jLong tadLength; + __shared__ Nd4jLong tadEWS; + __shared__ int numTads; + __shared__ Nd4jLong yEWS; + __shared__ Nd4jLong zEWS; + + if (threadIdx.x == 0) { + tadLength = + shape::length(tadOnlyShapeInfo); // shape::tadLength(xShapeInfo, + // dimension, dimensionLength); + tadEWS = shape::elementWiseStride(tadOnlyShapeInfo); + numTads = shape::length(xShapeInfo) / tadLength; + yEWS = shape::elementWiseStride(yShapeInfo); + zEWS = shape::elementWiseStride(tadOnlyShapeInfoZ); + } + __syncthreads(); + + __shared__ X* rZ; + __shared__ X const* rX; + + for (int r = blockIdx.x; r < numTads; r += gridDim.x) { if (threadIdx.x == 0) { - - zLen = shape::length(zShapeInfo); - rank = shape::rank(zShapeInfo); - - xzSameOffsets = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo); - yzSameOffsets = shape::haveSameShapeAndStrides(yShapeInfo, zShapeInfo); + rZ = z + tadOffsetsZ[r]; + rX = x + tadOffsets[r]; } __syncthreads(); + if (tadEWS > 0 && zEWS > 0 && yEWS > 0 && dimensionLength == 1) { + for (int i = threadIdx.x; i < tadLength; i += blockDim.x) + rZ[i * zEWS] = OpType::op(rX[i * tadEWS], y[i * yEWS]); + } else { + // it is expected that x and z tads and y array all have the same length + for (Nd4jLong i = threadIdx.x; i < tadLength; i += blockDim.x) { + auto xOffset = shape::getIndexOffset(i, tadOnlyShapeInfo); + auto yOffset = shape::getIndexOffset(i, yShapeInfo); + auto zOffset = shape::getIndexOffset(i, tadOnlyShapeInfoZ); + + rZ[zOffset] = OpType::op(rX[xOffset], y[yOffset]); + } + } + } +} - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; +////////////////////////////////////////////////////////////////////////// +template +template +__device__ void BroadcastInt::transformCuda( + const void* vx, const Nd4jLong const* xShapeInfo, const void* vy, + const Nd4jLong const* yShapeInfo, void* vz, + const Nd4jLong const* zShapeInfo) { + const X* x = reinterpret_cast(vx); + const X* y = reinterpret_cast(vy); + X* z = reinterpret_cast(vz); + + __shared__ Nd4jLong zLen; + __shared__ int rank; + __shared__ bool xzSameOffsets, yzSameOffsets; - int xCoords[MAX_RANK], yCoords[MAX_RANK], zCoords[MAX_RANK]; + if (threadIdx.x == 0) { + zLen = shape::length(zShapeInfo); + rank = shape::rank(zShapeInfo); - for (int i = tid; i < zLen; i += blockDim.x * gridDim.x) { + xzSameOffsets = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo); + yzSameOffsets = shape::haveSameShapeAndStrides(yShapeInfo, zShapeInfo); + } + __syncthreads(); - shape::index2coords(i, zShapeInfo, zCoords); + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - for (uint j = 0; j < rank; ++j) { - xCoords[j] = shape::sizeAt(xShapeInfo, j) == 1 ? 0 : zCoords[j]; - yCoords[j] = shape::sizeAt(yShapeInfo, j) == 1 ? 0 : zCoords[j]; - } + int xCoords[MAX_RANK], yCoords[MAX_RANK], zCoords[MAX_RANK]; - const auto zOffset = shape::getOffset(zShapeInfo, zCoords); - const auto xOffset = xzSameOffsets ? zOffset : shape::getOffset(xShapeInfo, xCoords); - const auto yOffset = yzSameOffsets ? zOffset : shape::getOffset(yShapeInfo, yCoords); + for (int i = tid; i < zLen; i += blockDim.x * gridDim.x) { + shape::index2coords(i, zShapeInfo, zCoords); - z[zOffset] = OpType::op(x[xOffset], y[yOffset]); + for (uint j = 0; j < rank; ++j) { + xCoords[j] = shape::sizeAt(xShapeInfo, j) == 1 ? 0 : zCoords[j]; + yCoords[j] = shape::sizeAt(yShapeInfo, j) == 1 ? 0 : zCoords[j]; } -} + const auto zOffset = shape::getOffset(zShapeInfo, zCoords); + const auto xOffset = + xzSameOffsets ? zOffset : shape::getOffset(xShapeInfo, xCoords); + const auto yOffset = + yzSameOffsets ? zOffset : shape::getOffset(yShapeInfo, yCoords); -BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT BroadcastInt, , INTEGER_TYPES); + z[zOffset] = OpType::op(x[xOffset], y[yOffset]); + } } -} \ No newline at end of file + +BUILD_SINGLE_TEMPLATE(template class SD_EXPORT BroadcastInt, , INTEGER_TYPES); +} // namespace broadcast +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/compilation_units/broadcasting.cu.in b/libnd4j/include/loops/cuda/compilation_units/broadcasting.cu.in index 6349dcfc98b8..cd93088da9e3 100644 --- a/libnd4j/include/loops/cuda/compilation_units/broadcasting.cu.in +++ b/libnd4j/include/loops/cuda/compilation_units/broadcasting.cu.in @@ -22,6 +22,6 @@ #cmakedefine PAIRWISE_TYPE_GEN namespace functions { namespace broadcast { - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT Broadcast, , PAIRWISE_TYPES_@FL_TYPE_INDEX@); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , PAIRWISE_TYPES_@FL_TYPE_INDEX@); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/compilation_units/pairwise.cu.in b/libnd4j/include/loops/cuda/compilation_units/pairwise.cu.in index 312ed7416242..2903b2f2c2a4 100644 --- a/libnd4j/include/loops/cuda/compilation_units/pairwise.cu.in +++ b/libnd4j/include/loops/cuda/compilation_units/pairwise.cu.in @@ -22,6 +22,6 @@ #cmakedefine PAIRWISE_TYPE_GEN namespace functions { namespace pairwise_transforms { - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT PairWiseTransform, , PAIRWISE_TYPES_@FL_TYPE_INDEX@); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , PAIRWISE_TYPES_@FL_TYPE_INDEX@); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/compilation_units/reduce3.cu.in b/libnd4j/include/loops/cuda/compilation_units/reduce3.cu.in index dd74728369a9..92ce78a83da6 100644 --- a/libnd4j/include/loops/cuda/compilation_units/reduce3.cu.in +++ b/libnd4j/include/loops/cuda/compilation_units/reduce3.cu.in @@ -22,6 +22,6 @@ #cmakedefine FLOAT_TYPE_GEN namespace functions { namespace reduce3 { - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT Reduce3, , LIBND4J_TYPES, FLOAT_TYPES_@FL_TYPE_INDEX@); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES, FLOAT_TYPES_@FL_TYPE_INDEX@); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/compilation_units/reduce_float.cu.in b/libnd4j/include/loops/cuda/compilation_units/reduce_float.cu.in index 34c2bf8caadc..30d1c80a01be 100644 --- a/libnd4j/include/loops/cuda/compilation_units/reduce_float.cu.in +++ b/libnd4j/include/loops/cuda/compilation_units/reduce_float.cu.in @@ -22,6 +22,6 @@ #cmakedefine FLOAT_TYPE_GEN namespace functions { namespace reduce { - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT ReduceFloatFunction, , LIBND4J_TYPES, FLOAT_TYPES_@FL_TYPE_INDEX@); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT ReduceFloatFunction, , LIBND4J_TYPES, FLOAT_TYPES_@FL_TYPE_INDEX@); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/compilation_units/scalar.cu.in b/libnd4j/include/loops/cuda/compilation_units/scalar.cu.in index 15608bdd1e65..4d5fbf05d9f4 100644 --- a/libnd4j/include/loops/cuda/compilation_units/scalar.cu.in +++ b/libnd4j/include/loops/cuda/compilation_units/scalar.cu.in @@ -22,6 +22,6 @@ #cmakedefine PAIRWISE_TYPE_GEN namespace functions { namespace scalar { - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT ScalarTransform, , PAIRWISE_TYPES_@FL_TYPE_INDEX@); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT ScalarTransform, , PAIRWISE_TYPES_@FL_TYPE_INDEX@); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/indexreduce.cu b/libnd4j/include/loops/cuda/indexreduce.cu index dbe03a9bf4e7..43c6d3c232cf 100644 --- a/libnd4j/include/loops/cuda/indexreduce.cu +++ b/libnd4j/include/loops/cuda/indexreduce.cu @@ -18,354 +18,344 @@ // Created by raver on 4/9/2018. // +#include #include -#include "../indexreduce.h" #include -#include #include +#include "../indexreduce.h" #include "../legacy_ops.h" using namespace simdOps; - template -static __global__ void simpleIndexReduceGeneric(const int op, - void const* dx, - Nd4jLong const* xShapeInfo, int xRank, - void *extraParams, - void *result, - Nd4jLong const* zShapeInfo, int zRank, - int *dimension, - int dimensionLength, - int postProcessOrNot, int *allocationBuffer, void *reductionBuffer, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets) { - - functions::indexreduce::IndexReduce::transform(op,dx,xShapeInfo,extraParams,result,zShapeInfo,dimension,dimensionLength,postProcessOrNot,allocationBuffer,reductionBuffer,tadOnlyShapeInfo,tadOffsets); +static __global__ void simpleIndexReduceGeneric( + const int op, void const *dx, Nd4jLong const *xShapeInfo, int xRank, + void *extraParams, void *result, Nd4jLong const *zShapeInfo, int zRank, + int *dimension, int dimensionLength, int postProcessOrNot, + int *allocationBuffer, void *reductionBuffer, + Nd4jLong const *tadOnlyShapeInfo, Nd4jLong const *tadOffsets) { + functions::indexreduce::IndexReduce::transform( + op, dx, xShapeInfo, extraParams, result, zShapeInfo, dimension, + dimensionLength, postProcessOrNot, allocationBuffer, reductionBuffer, + tadOnlyShapeInfo, tadOffsets); } namespace functions { - namespace indexreduce { - - template - _CUDA_H void IndexReduce::executeIndexReduceScalar(dim3 launchDims, cudaStream_t *stream, - const int opNum, - void const* dx, Nd4jLong const* xShapeInfo, - int xRank, - void *extraParams, - void *result, Nd4jLong const* zShapeInfo, - int zRank, - int *dimension, int dimensionLength, - int postProcessOrNot, - int *allocationBuffer, void *reductionBuffer, - Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets) { - - simpleIndexReduceGeneric<<>>(opNum, - dx, xShapeInfo, xRank, - extraParams, - result, zShapeInfo, 0, - nullptr, 0, - 1, - allocationBuffer, reductionBuffer, - tadOnlyShapeInfo, tadOffsets); - } +namespace indexreduce { - template - _CUDA_H void IndexReduce::executeIndexReduce(dim3 launchDims, cudaStream_t *stream, const int opNum, void const* dx, Nd4jLong const* xShapeInfo, int xRank, void *extraParams, void *result, Nd4jLong const* zShapeInfo, int zRank, int *dimension, int dimensionLength, int postProcessOrNot, int *allocationBuffer, void *reductionBuffer, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets) { - simpleIndexReduceGeneric<<>>( - opNum, - dx, - xShapeInfo, xRank, - extraParams, - result, - zShapeInfo, zRank, - dimension, - dimensionLength, - 1, allocationBuffer, reductionBuffer, tadOnlyShapeInfo, tadOffsets); - } +template +_CUDA_H void IndexReduce::executeIndexReduceScalar( + dim3 launchDims, cudaStream_t *stream, const int opNum, void const *dx, + Nd4jLong const *xShapeInfo, int xRank, void *extraParams, void *result, + Nd4jLong const *zShapeInfo, int zRank, int *dimension, int dimensionLength, + int postProcessOrNot, int *allocationBuffer, void *reductionBuffer, + Nd4jLong const *tadOnlyShapeInfo, Nd4jLong const *tadOffsets) { + simpleIndexReduceGeneric + <<>>( + opNum, dx, xShapeInfo, xRank, extraParams, result, zShapeInfo, 0, + nullptr, 0, 1, allocationBuffer, reductionBuffer, tadOnlyShapeInfo, + tadOffsets); +} - // This is the un-specialized struct. Note that we prevent instantiation of this - // struct by putting an undefined symbol in the function body so it won't compile. - template - struct SharedIndexValue { - // Ensure that we won't compile any un-specialized types - __device__ T * getPointer() { - extern __device__ void error(void); - error(); - return 0; - } - }; +template +_CUDA_H void IndexReduce::executeIndexReduce( + dim3 launchDims, cudaStream_t *stream, const int opNum, void const *dx, + Nd4jLong const *xShapeInfo, int xRank, void *extraParams, void *result, + Nd4jLong const *zShapeInfo, int zRank, int *dimension, int dimensionLength, + int postProcessOrNot, int *allocationBuffer, void *reductionBuffer, + Nd4jLong const *tadOnlyShapeInfo, Nd4jLong const *tadOffsets) { + simpleIndexReduceGeneric + <<>>( + opNum, dx, xShapeInfo, xRank, extraParams, result, zShapeInfo, zRank, + dimension, dimensionLength, 1, allocationBuffer, reductionBuffer, + tadOnlyShapeInfo, tadOffsets); +} + +// This is the un-specialized struct. Note that we prevent instantiation of +// this struct by putting an undefined symbol in the function body so it won't +// compile. +template +struct SharedIndexValue { + // Ensure that we won't compile any un-specialized types + __device__ T *getPointer() { + extern __device__ void error(void); + error(); + return 0; + } +}; // Following are the specializations for the following types. -// int, uint, char, uchar, short, ushort, long long, ulong long, bool, float, and double -// One could also specialize it for user-defined types. - - template<> - struct SharedIndexValue { - __device__ IndexValue * getPointer() { - extern __shared__ IndexValue s_int2[]; - return s_int2; - } - }; +// int, uint, char, uchar, short, ushort, long long, ulong long, bool, float, +// and double One could also specialize it for user-defined types. + +template <> +struct SharedIndexValue { + __device__ IndexValue *getPointer() { + extern __shared__ IndexValue s_int2[]; + return s_int2; + } +}; // Following are the specializations for the following types. -// int, uint, char, uchar, short, ushort, long long, ulong long, bool, float, and double -// One could also specialize it for user-defined types. - - template<> - struct SharedIndexValue { - __device__ IndexValue * getPointer() { - extern __shared__ IndexValue s_int6[]; - return s_int6; - } - }; - - template - template - __device__ void IndexReduce::aggregatePartials(IndexValue **sPartialsRef, Nd4jLong tid, Nd4jLong numElements, void *vextraParams) { - // start the shared memory loop on the next power of 2 less - // than the block size. If block size is not a power of 2, - // accumulate the intermediate sums in the remainder range. - auto extraParams = static_cast(vextraParams); - IndexValue *sPartials = *sPartialsRef; - Nd4jLong floorPow2 = blockDim.x; - - if (floorPow2 & (floorPow2 - 1)) { - while ( floorPow2 & (floorPow2 - 1) ) { - floorPow2 &= floorPow2 - 1; - } - - if (tid >= floorPow2) { - IndexValue prev = sPartials[tid - floorPow2]; - IndexValue curr = sPartials[tid]; - sPartials[tid - floorPow2] = OpType::update(prev,curr,extraParams); - } - __syncthreads(); - } - - for (int activeThreads = floorPow2 >> 1;activeThreads; activeThreads >>= 1) { - if (tid < activeThreads && tid + activeThreads < numElements) { - IndexValue curr = sPartials[tid]; - IndexValue next = sPartials[tid + activeThreads]; - sPartials[tid] = OpType::update(curr,next,extraParams); - } - __syncthreads(); - } - } +// int, uint, char, uchar, short, ushort, long long, ulong long, bool, float, +// and double One could also specialize it for user-defined types. + +template <> +struct SharedIndexValue { + __device__ IndexValue *getPointer() { + extern __shared__ IndexValue s_int6[]; + return s_int6; + } +}; - template - __device__ void IndexReduce::transform( - const int opNum, - void const* x, - Nd4jLong const* xShapeInfo, - void *extraParams, - void *result, - Nd4jLong const* zShapeInfo, - int *dimension, - int dimensionLength, - int postProcessOrNot, - int *allocationBuffer, - void *reductionBuffer, - Nd4jLong const* tadShapeInfo, - Nd4jLong const* tadOffset) { - DISPATCH_BY_OPNUM_TT(transform, PARAMS(x, xShapeInfo, extraParams, result, zShapeInfo, dimension, dimensionLength, postProcessOrNot, allocationBuffer, reductionBuffer, tadShapeInfo, tadOffset), INDEX_REDUCE_OPS); +template +template +__device__ void IndexReduce::aggregatePartials( + IndexValue **sPartialsRef, Nd4jLong tid, Nd4jLong numElements, + void *vextraParams) { + // start the shared memory loop on the next power of 2 less + // than the block size. If block size is not a power of 2, + // accumulate the intermediate sums in the remainder range. + auto extraParams = static_cast(vextraParams); + IndexValue *sPartials = *sPartialsRef; + Nd4jLong floorPow2 = blockDim.x; + + if (floorPow2 & (floorPow2 - 1)) { + while (floorPow2 & (floorPow2 - 1)) { + floorPow2 &= floorPow2 - 1; + } + + if (tid >= floorPow2) { + IndexValue prev = sPartials[tid - floorPow2]; + IndexValue curr = sPartials[tid]; + sPartials[tid - floorPow2] = OpType::update(prev, curr, extraParams); + } + __syncthreads(); + } + + for (int activeThreads = floorPow2 >> 1; activeThreads; activeThreads >>= 1) { + if (tid < activeThreads && tid + activeThreads < numElements) { + IndexValue curr = sPartials[tid]; + IndexValue next = sPartials[tid + activeThreads]; + sPartials[tid] = OpType::update(curr, next, extraParams); + } + __syncthreads(); + } +} + +template +__device__ void IndexReduce::transform( + const int opNum, void const *x, Nd4jLong const *xShapeInfo, + void *extraParams, void *result, Nd4jLong const *zShapeInfo, int *dimension, + int dimensionLength, int postProcessOrNot, int *allocationBuffer, + void *reductionBuffer, Nd4jLong const *tadShapeInfo, + Nd4jLong const *tadOffset) { + DISPATCH_BY_OPNUM_TT( + transform, + PARAMS(x, xShapeInfo, extraParams, result, zShapeInfo, dimension, + dimensionLength, postProcessOrNot, allocationBuffer, + reductionBuffer, tadShapeInfo, tadOffset), + INDEX_REDUCE_OPS); +} + +template +template +__device__ void IndexReduce::transform( + void const *vdx, Nd4jLong const *xShapeInfo, void *vextraParams, void *vz, + Nd4jLong const *zShapeInfo, int *dimension, int dimensionLength, + int postProcessOrNot, int *allocationBuffer, void *vreductionBuffer, + Nd4jLong const *tadOnlyShapeInfo, Nd4jLong const *tadOffsets) { + /**int + * Gpu information for the problem + */ + auto dx = reinterpret_cast(vdx); + auto z = reinterpret_cast(vz); + auto extraParams = static_cast(vextraParams); + auto reductionBuffer = static_cast(vreductionBuffer); + auto order = shape::order(xShapeInfo); + int tid = blockIdx.x * blockDim.x + threadIdx.x; + __shared__ volatile bool resultScalar; + + // shared memory space for storing intermediate results + __shared__ IndexValue *sPartials; + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sPartials = reinterpret_cast *>(shmem); + } + __syncthreads(); + + sPartials[threadIdx.x] = OpType::startingIndexValue(dx); + + // length for the tad + __shared__ volatile Nd4jLong xLength; + + __shared__ volatile Nd4jLong zLen; + + // only compute the tad indexes once + IndexValue reduction = OpType::startingIndexValue(dx); + + if (threadIdx.x == 0) { + if (zShapeInfo != nullptr) + zLen = shape::length(zShapeInfo); + else + zLen = 1; + + + if (zLen == 1 ) + resultScalar = true; else + resultScalar = false; + + xLength = shape::length(xShapeInfo); + } + __syncthreads(); + + if (sd::ArrayOptions::arrayType(xShapeInfo) == sd::ArrayType::EMPTY) { + if (sd::ArrayOptions::arrayType(zShapeInfo) == sd::ArrayType::EMPTY) return; + + for (uint i = blockIdx.x * blockDim.x + threadIdx.x; i < zLen; + i += gridDim.x * blockDim.x) + z[i] = (Z)reduction.index; + + return; + } + + if (!resultScalar) { + __shared__ Nd4jLong tadLength; + __shared__ int tadEWS; + __shared__ int numTads; + + if (threadIdx.x == 0) { + tadLength = shape::length(tadOnlyShapeInfo); + tadEWS = shape::elementWiseStride(tadOnlyShapeInfo); + numTads = shape::length(xShapeInfo) / tadLength; + } + __syncthreads(); + + if (dimensionLength > 1 || tadEWS < 1) { + for (int r = blockIdx.x; r < numTads; r += gridDim.x) { + auto tadOffsetForBlock = tadOffsets[r]; + sPartials[threadIdx.x] = OpType::startingIndexValue(dx); + + for (int i = threadIdx.x; i < tadLength; i += blockDim.x) { + auto xOffset = + tadOffsetForBlock + shape::getIndexOffset(i, tadOnlyShapeInfo); + IndexValue comp{dx[xOffset], i}; + sPartials[threadIdx.x] = + OpType::update(sPartials[threadIdx.x], comp, extraParams); } + __syncthreads(); + aggregatePartials( + &sPartials, threadIdx.x, + sd::math::nd4j_min(blockDim.x, tadLength), extraParams); - template - template - __device__ void IndexReduce::transform(void const* vdx, Nd4jLong const* xShapeInfo, - void *vextraParams, - void* vz, Nd4jLong const* zShapeInfo, - int *dimension, int dimensionLength, - int postProcessOrNot, - int *allocationBuffer, void *vreductionBuffer, - Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets){ - /**int - * Gpu information for the problem - */ - auto dx = reinterpret_cast(vdx); - auto z = reinterpret_cast(vz); - auto extraParams = static_cast(vextraParams); - auto reductionBuffer = static_cast(vreductionBuffer); - auto order = shape::order(xShapeInfo); - int tid = blockIdx.x * blockDim.x + threadIdx.x; - __shared__ volatile bool resultScalar; - - //shared memory space for storing intermediate results - __shared__ IndexValue* sPartials; - if(threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - sPartials = reinterpret_cast*>(shmem); - } - __syncthreads(); - - sPartials[threadIdx.x] = OpType::startingIndexValue(dx); - - //length for the tad - __shared__ volatile Nd4jLong xLength; - - __shared__ volatile Nd4jLong zLen; - - - //only compute the tad indexes once - IndexValue reduction = OpType::startingIndexValue(dx); - - if (threadIdx.x == 0) { - if (zShapeInfo != nullptr) - zLen = shape::length(zShapeInfo); - else zLen = 1; - - if (zLen == 1) - resultScalar = true; - else - resultScalar = false; - - xLength = shape::length(xShapeInfo); - } - __syncthreads(); - - if(sd::ArrayOptions::arrayType(xShapeInfo) == sd::ArrayType::EMPTY) { - - if(sd::ArrayOptions::arrayType(zShapeInfo) == sd::ArrayType::EMPTY) - return; - - for (uint i = blockIdx.x * blockDim.x + threadIdx.x; i < zLen; i += gridDim.x * blockDim.x) - z[i] = (Z) reduction.index; - - return; - } - - if (!resultScalar) { - - __shared__ Nd4jLong tadLength; - __shared__ int tadEWS; - __shared__ int numTads; - - if (threadIdx.x == 0) { - tadLength = shape::length(tadOnlyShapeInfo); - tadEWS = shape::elementWiseStride(tadOnlyShapeInfo); - numTads = shape::length(xShapeInfo) / tadLength; - } - __syncthreads(); - - if (dimensionLength > 1 || tadEWS < 1) { - - for (int r = blockIdx.x; r < numTads; r += gridDim.x) { - - auto tadOffsetForBlock = tadOffsets[r]; - sPartials[threadIdx.x] = OpType::startingIndexValue(dx); - - for(int i = threadIdx.x;i < tadLength; i += blockDim.x) { - auto xOffset = tadOffsetForBlock + shape::getIndexOffset(i, tadOnlyShapeInfo); - IndexValue comp {dx[xOffset], i}; - sPartials[threadIdx.x] = OpType::update(sPartials[threadIdx.x], comp, extraParams); - } - - __syncthreads(); - aggregatePartials(&sPartials, threadIdx.x, sd::math::nd4j_min(blockDim.x, tadLength),extraParams); - - __syncthreads(); - if (threadIdx.x == 0) { - z[r] = (Z) sPartials[threadIdx.x].index; - } - __syncthreads(); - } - } else { - - for(int i = blockIdx.x; i < numTads; i+= gridDim.x) { - Nd4jLong tadOffsetForBlock = tadOffsets[i]; - - sPartials[threadIdx.x] = OpType::startingIndexValue(dx); - - for (int x = threadIdx.x; x < tadLength; x+= blockDim.x) { - IndexValue comp {dx[tadOffsetForBlock + x * tadEWS], x}; - sPartials[threadIdx.x] = OpType::update(sPartials[threadIdx.x], comp, extraParams); - } - - __syncthreads(); - aggregatePartials(&sPartials, threadIdx.x, sd::math::nd4j_min(blockDim.x, tadLength),extraParams); - - __syncthreads(); - if (threadIdx.x == 0) { - z[i] = (Z) sPartials[threadIdx.x].index; //postProcess(sPartials[0],tadLength ,extraParams); - } - __syncthreads(); - } - } - } else { - auto n = shape::length(xShapeInfo); - auto xElementWiseStride = shape::elementWiseStride(xShapeInfo); - - if(xElementWiseStride >= 1 && order == 'c') { - for(Nd4jLong i = tid;i < n; i += (blockDim.x * gridDim.x)) { - IndexValue indexVal = {dx[i * xElementWiseStride], i}; - reduction = OpType::update(reduction, indexVal, extraParams); - } - } else { - - for(Nd4jLong i = tid;i < n; i += blockDim.x * gridDim.x) { - auto offset = shape::getIndexOffset(i, xShapeInfo); - IndexValue indexVal = {dx[offset], i}; - reduction = OpType::update(reduction, indexVal, extraParams); - } - } - - - sPartials[threadIdx.x] = reduction; - __syncthreads(); - - aggregatePartials(&sPartials, threadIdx.x, sd::math::nd4j_min(blockDim.x, (int) n),extraParams); - __syncthreads(); - - if (gridDim.x > 1) { - __shared__ bool amLast; - unsigned int *tc = (unsigned int *) reductionBuffer; - tid = threadIdx.x; - if (threadIdx.x == 0) { - auto pBuffer = reinterpret_cast *>(reductionBuffer); - pBuffer[blockIdx.x] = {sPartials[0].value, sPartials[0].index}; - } - __threadfence(); - __syncthreads(); - - if (tid==0) { - unsigned int ticket = atomicInc(&tc[16384], gridDim.x); - amLast = (ticket == gridDim.x-1); - } - - __syncthreads(); - - if (amLast) { - tc[16384] = 0; - IndexValue *pBuffer = (IndexValue *) reductionBuffer; - - sPartials[threadIdx.x] = OpType::startingIndexValue(dx); - - for (Nd4jLong i = threadIdx.x; i < gridDim.x; i += blockDim.x) { - sPartials[threadIdx.x] = OpType::update(sPartials[threadIdx.x], pBuffer[i], extraParams); - } - - __syncthreads(); - aggregatePartials(&sPartials, threadIdx.x, sd::math::nd4j_min(gridDim.x, blockDim.x),extraParams); - - __syncthreads(); - if (tid == 0) { - z[0] = (Z) sPartials[0].index; - } - } - } else { - if (tid == 0) { - auto tc = reinterpret_cast(reductionBuffer); - tc[16384] = 0; - z[0] = (Z) sPartials[0].index; - } - } - - } + __syncthreads(); + if (threadIdx.x == 0) { + z[r] = (Z)sPartials[threadIdx.x].index; } + __syncthreads(); + } + } else { + for (int i = blockIdx.x; i < numTads; i += gridDim.x) { + Nd4jLong tadOffsetForBlock = tadOffsets[i]; + + sPartials[threadIdx.x] = OpType::startingIndexValue(dx); + + for (int x = threadIdx.x; x < tadLength; x += blockDim.x) { + IndexValue comp{dx[tadOffsetForBlock + x * tadEWS], x}; + sPartials[threadIdx.x] = + OpType::update(sPartials[threadIdx.x], comp, extraParams); + } + + __syncthreads(); + aggregatePartials( + &sPartials, threadIdx.x, + sd::math::nd4j_min(blockDim.x, tadLength), extraParams); - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT IndexReduce, , LIBND4J_TYPES, INDEXING_TYPES); + __syncthreads(); + if (threadIdx.x == 0) { + z[i] = + (Z)sPartials[threadIdx.x] + .index; // postProcess(sPartials[0],tadLength ,extraParams); + } + __syncthreads(); + } } -} + } else { + auto n = shape::length(xShapeInfo); + auto xElementWiseStride = shape::elementWiseStride(xShapeInfo); + + if (xElementWiseStride >= 1 && order == 'c') { + for (Nd4jLong i = tid; i < n; i += (blockDim.x * gridDim.x)) { + IndexValue indexVal = {dx[i * xElementWiseStride], i}; + reduction = OpType::update(reduction, indexVal, extraParams); + } + } else { + for (Nd4jLong i = tid; i < n; i += blockDim.x * gridDim.x) { + auto offset = shape::getIndexOffset(i, xShapeInfo); + IndexValue indexVal = {dx[offset], i}; + reduction = OpType::update(reduction, indexVal, extraParams); + } + } + + sPartials[threadIdx.x] = reduction; + __syncthreads(); + + aggregatePartials(&sPartials, threadIdx.x, + sd::math::nd4j_min(blockDim.x, (int)n), + extraParams); + __syncthreads(); + + if (gridDim.x > 1) { + __shared__ bool amLast; + unsigned int *tc = (unsigned int *)reductionBuffer; + tid = threadIdx.x; + if (threadIdx.x == 0) { + auto pBuffer = reinterpret_cast *>(reductionBuffer); + pBuffer[blockIdx.x] = {sPartials[0].value, sPartials[0].index}; + } + __threadfence(); + __syncthreads(); + + if (tid == 0) { + unsigned int ticket = atomicInc(&tc[16384], gridDim.x); + amLast = (ticket == gridDim.x - 1); + } + + __syncthreads(); + + if (amLast) { + tc[16384] = 0; + IndexValue *pBuffer = (IndexValue *)reductionBuffer; + + sPartials[threadIdx.x] = OpType::startingIndexValue(dx); + + for (Nd4jLong i = threadIdx.x; i < gridDim.x; i += blockDim.x) { + sPartials[threadIdx.x] = + OpType::update(sPartials[threadIdx.x], pBuffer[i], extraParams); + } + __syncthreads(); + aggregatePartials( + &sPartials, threadIdx.x, + sd::math::nd4j_min(gridDim.x, blockDim.x), extraParams); + __syncthreads(); + if (tid == 0) { + z[0] = (Z)sPartials[0].index; + } + } + } else { + if (tid == 0) { + auto tc = reinterpret_cast(reductionBuffer); + tc[16384] = 0; + z[0] = (Z)sPartials[0].index; + } + } + } +} +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT IndexReduce, , LIBND4J_TYPES, + INDEXING_TYPES); +} // namespace indexreduce +} // namespace functions diff --git a/libnd4j/include/loops/cuda/inplace_loops/reduce_same_inplace.h b/libnd4j/include/loops/cuda/inplace_loops/reduce_same_inplace.h index 1989cadc5c53..45350d761bf6 100644 --- a/libnd4j/include/loops/cuda/inplace_loops/reduce_same_inplace.h +++ b/libnd4j/include/loops/cuda/inplace_loops/reduce_same_inplace.h @@ -18,156 +18,171 @@ // @author raver119@gmail.com // +#ifndef SD_REDUCE_SAME_LOOPS_H +#define SD_REDUCE_SAME_LOOPS_H -#ifndef DEV_TESTS_REDUCE_SAME_LOOPS_H -#define DEV_TESTS_REDUCE_SAME_LOOPS_H - +#include #include -#include #include -#include +#include using namespace simdOps; namespace functions { - namespace reduce { - template - class ReduceSameInplace { - public: - static FORCEINLINE void _CUDA_D execScalarCudaLegacy(int opNum, void *vx, Nd4jLong *xShapeInfo, void *vextraParams, void *vz, Nd4jLong *zShapeInfo, void *vreductionBuffer, Nd4jLong *tadOnlyShapeInfo); - - template - static FORCEINLINE void _CUDA_D execScalarCuda(void *vx, Nd4jLong *xShapeInfo, void *vextraParams, void *vz, Nd4jLong *zShapeInfo, void *vreductionBuffer, Nd4jLong *tadOnlyShapeInfo); - - template - static FORCEINLINE void _CUDA_D aggregatePartials(void *vsPartials, Nd4jLong tid, Nd4jLong numItems, void *vextraParams); - }; +namespace reduce { +template +class ReduceSameInplace { + public: + static FORCEINLINE void _CUDA_D execScalarCudaLegacy( + int opNum, void *vx, Nd4jLong *xShapeInfo, void *vextraParams, void *vz, + Nd4jLong *zShapeInfo, void *vreductionBuffer, Nd4jLong *tadOnlyShapeInfo); + + template + static FORCEINLINE void _CUDA_D execScalarCuda(void *vx, Nd4jLong *xShapeInfo, + void *vextraParams, void *vz, + Nd4jLong *zShapeInfo, + void *vreductionBuffer, + Nd4jLong *tadOnlyShapeInfo); + + template + static FORCEINLINE void _CUDA_D aggregatePartials(void *vsPartials, + Nd4jLong tid, + Nd4jLong numItems, + void *vextraParams); +}; + +template +template +__device__ void ReduceSameInplace::aggregatePartials(void *vsPartials, + Nd4jLong tid, + Nd4jLong numItems, + void *vextraParams) { + // start the shared memory loop on the next power of 2 less + // than the block size. If block size is not a power of 2, + // accumulate the intermediate sums in the remainder range. + + auto sPartials = static_cast(vsPartials); + auto extraParams = static_cast(vextraParams); + + Nd4jLong floorPow2 = numItems; + + if (floorPow2 & (floorPow2 - 1)) { + while (floorPow2 & (floorPow2 - 1)) floorPow2 &= floorPow2 - 1; + + if (tid >= floorPow2) + sPartials[tid - floorPow2] = OpType::update(sPartials[tid - floorPow2], + sPartials[tid], extraParams); + + __syncthreads(); + } + + for (Nd4jLong activeThreads = floorPow2 >> 1; activeThreads; + activeThreads >>= 1) { + if (tid < activeThreads && tid + activeThreads < numItems) + sPartials[tid] = OpType::update( + sPartials[tid], sPartials[tid + activeThreads], extraParams); + + __syncthreads(); + } +} - template - template - __device__ void ReduceSameInplace::aggregatePartials(void *vsPartials, Nd4jLong tid, Nd4jLong numItems, void *vextraParams) { +template +FORCEINLINE void _CUDA_D ReduceSameInplace::execScalarCudaLegacy( + int opNum, void *vx, Nd4jLong *xShapeInfo, void *vextraParams, void *vz, + Nd4jLong *zShapeInfo, void *vreductionBuffer, Nd4jLong *tadOnlyShapeInfo) { + DISPATCH_BY_OPNUM_T(execScalarCuda, + PARAMS(vx, xShapeInfo, vextraParams, vz, zShapeInfo, + vreductionBuffer, tadOnlyShapeInfo), + REDUCE_SAME_OPS); +} - // start the shared memory loop on the next power of 2 less - // than the block size. If block size is not a power of 2, - // accumulate the intermediate sums in the remainder range. +template +template +FORCEINLINE void _CUDA_D ReduceSameInplace::execScalarCuda( + void *vx, Nd4jLong *xShapeInfo, void *vextraParams, void *vz, + Nd4jLong *zShapeInfo, void *vreductionBuffer, Nd4jLong *tadOnlyShapeInfo) { + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + auto extraParams = reinterpret_cast(vextraParams); + auto reductionBuffer = reinterpret_cast(vreductionBuffer); + + int xEws = shape::elementWiseStride(xShapeInfo); + auto len = shape::length(xShapeInfo); + auto tid = blockDim.x * blockIdx.x + threadIdx.x; + + // shared memory space for storing intermediate results + __shared__ X *sPartials; + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sPartials = reinterpret_cast(shmem); + } + __syncthreads(); + sPartials[threadIdx.x] = OpType::startingValue(x); + + if (xEws > 0) + for (int i = tid; i < len; i += (blockDim.x * gridDim.x)) + sPartials[threadIdx.x] = + OpType::update(sPartials[threadIdx.x], + OpType::op(x[i * xEws], extraParams), extraParams); + else + for (int i = tid; i < len; i += blockDim.x * gridDim.x) + sPartials[threadIdx.x] = OpType::update( + sPartials[threadIdx.x], + OpType::op(x[shape::getIndexOffset(i, xShapeInfo)], extraParams), + extraParams); + + __syncthreads(); + aggregatePartials(sPartials, threadIdx.x, + sd::math::nd4j_min(blockDim.x, len), + extraParams); + __syncthreads(); + + if (gridDim.x > 1) { + unsigned int *tc = (unsigned int *)reductionBuffer; + __shared__ bool amLast; + + tid = threadIdx.x; + if (threadIdx.x == 0) + reductionBuffer[blockIdx.x] = + sPartials[0]; // this->postProcess(sPartials[0],len,extraParams); + + __threadfence(); + __syncthreads(); + + if (threadIdx.x == 0) { + unsigned int ticket = atomicInc(&tc[16384], gridDim.x); + amLast = (ticket == gridDim.x - 1); + } - auto sPartials = static_cast(vsPartials); - auto extraParams = static_cast(vextraParams); + __syncthreads(); - Nd4jLong floorPow2 = numItems; + if (amLast) { + tc[16384] = 0; + sPartials[threadIdx.x] = OpType::startingValue(x); - if (floorPow2 & (floorPow2 - 1)) { - - while (floorPow2 & (floorPow2 - 1)) - floorPow2 &= floorPow2 - 1; + for (int i = threadIdx.x; i < gridDim.x; i += blockDim.x) + sPartials[threadIdx.x] = OpType::update( + sPartials[threadIdx.x], reductionBuffer[i], extraParams); - if (tid >= floorPow2) - sPartials[tid - floorPow2] = OpType::update(sPartials[tid - floorPow2], sPartials[tid], extraParams); - - __syncthreads(); - } - - for (Nd4jLong activeThreads = floorPow2 >> 1; activeThreads; activeThreads >>= 1) { - if (tid < activeThreads && tid + activeThreads < numItems) - sPartials[tid] = OpType::update(sPartials[tid], sPartials[tid + activeThreads], extraParams); - - __syncthreads(); - } - } - - template - FORCEINLINE void _CUDA_D ReduceSameInplace::execScalarCudaLegacy(int opNum, void *vx, Nd4jLong *xShapeInfo, - void *vextraParams, - void *vz, Nd4jLong *zShapeInfo, - void *vreductionBuffer, - Nd4jLong *tadOnlyShapeInfo) { - DISPATCH_BY_OPNUM_T(execScalarCuda, PARAMS(vx, xShapeInfo, vextraParams, vz, zShapeInfo, vreductionBuffer, tadOnlyShapeInfo), REDUCE_SAME_OPS); - } - - template - template - FORCEINLINE void _CUDA_D ReduceSameInplace::execScalarCuda(void *vx, Nd4jLong *xShapeInfo, - void *vextraParams, - void *vz, Nd4jLong *zShapeInfo, - void *vreductionBuffer, - Nd4jLong *tadOnlyShapeInfo) { - - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - auto extraParams = reinterpret_cast(vextraParams); - auto reductionBuffer = reinterpret_cast(vreductionBuffer); - - int xEws = shape::elementWiseStride(xShapeInfo); - auto len = shape::length(xShapeInfo); - auto tid = blockDim.x * blockIdx.x + threadIdx.x; - - //shared memory space for storing intermediate results - __shared__ X* sPartials; - if(threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - sPartials = reinterpret_cast(shmem); - } - __syncthreads(); - sPartials[threadIdx.x] = OpType::startingValue(x); - - if (xEws > 0) - for (int i = tid; i < len; i += (blockDim.x * gridDim.x)) - sPartials[threadIdx.x] = OpType::update(sPartials[threadIdx.x], OpType::op(x[i * xEws], extraParams), extraParams); - else - for (int i = tid; i < len; i += blockDim.x * gridDim.x) - sPartials[threadIdx.x] = OpType::update(sPartials[threadIdx.x], OpType::op(x[shape::getIndexOffset(i, xShapeInfo)], extraParams), extraParams); - - __syncthreads(); - aggregatePartials(sPartials, threadIdx.x, sd::math::nd4j_min(blockDim.x, len), extraParams); - __syncthreads(); - - - if (gridDim.x > 1) { - - unsigned int *tc = (unsigned int *)reductionBuffer; - __shared__ bool amLast; - - tid = threadIdx.x; - if (threadIdx.x == 0) - reductionBuffer[blockIdx.x] = sPartials[0];//this->postProcess(sPartials[0],len,extraParams); - - __threadfence(); - __syncthreads(); - - if (threadIdx.x == 0) { - unsigned int ticket = atomicInc(&tc[16384], gridDim.x); - amLast = (ticket == gridDim.x - 1); - } - - __syncthreads(); - - if (amLast) { - - tc[16384] = 0; - sPartials[threadIdx.x] = OpType::startingValue(x); - - for (int i = threadIdx.x; i < gridDim.x; i += blockDim.x) - sPartials[threadIdx.x] = OpType::update(sPartials[threadIdx.x], reductionBuffer[i], extraParams); - - __syncthreads(); - aggregatePartials(sPartials, threadIdx.x, sd::math::nd4j_min(gridDim.x, blockDim.x), extraParams); - __syncthreads(); - - if (threadIdx.x == 0) { - z[0] = OpType::postProcess(sPartials[0], len, extraParams); - } - } - } - else { - - if (threadIdx.x == 0) { - unsigned int *tc = (unsigned *)reductionBuffer; - tc[16384] = 0; - z[0] = OpType::postProcess(sPartials[0], len, extraParams); - } - } - } + __syncthreads(); + aggregatePartials(sPartials, threadIdx.x, + sd::math::nd4j_min(gridDim.x, blockDim.x), + extraParams); + __syncthreads(); + + if (threadIdx.x == 0) { + z[0] = OpType::postProcess(sPartials[0], len, extraParams); + } + } + } else { + if (threadIdx.x == 0) { + unsigned int *tc = (unsigned *)reductionBuffer; + tc[16384] = 0; + z[0] = OpType::postProcess(sPartials[0], len, extraParams); } + } } +} // namespace reduce +} // namespace functions -#endif //DEV_TESTS_REDUCE_SAME_LOOPS_H +#endif // SD_REDUCE_SAME_LOOPS_H diff --git a/libnd4j/include/loops/cuda/inplace_loops/scalar_inplace.h b/libnd4j/include/loops/cuda/inplace_loops/scalar_inplace.h index df1a87ba896e..2cafbbca6240 100644 --- a/libnd4j/include/loops/cuda/inplace_loops/scalar_inplace.h +++ b/libnd4j/include/loops/cuda/inplace_loops/scalar_inplace.h @@ -18,65 +18,66 @@ // @author raver119@gmail.com // +#ifndef SD_SCALAR_INPLACE_H +#define SD_SCALAR_INPLACE_H -#ifndef DEV_TESTS_SCALAR_INPLACE_H -#define DEV_TESTS_SCALAR_INPLACE_H - +#include #include -#include #include -#include +#include using namespace simdOps; namespace functions { - namespace scalar { - template - class ScalarInplace { - public: - static FORCEINLINE _CUDA_D void transformCudaLegacy(int opNum, void* vscalar, void *vy, Nd4jLong *yShapeInfo, void *vparams, void *vz, Nd4jLong *zShapeInfo, int *allocationBuffer); - - template - static FORCEINLINE _CUDA_D void transformCuda(void* vscalar, void *vy, Nd4jLong *yShapeInfo, void *vparams, void *vz, Nd4jLong *zShapeInfo, int *allocationBuffer); - }; - - template - FORCEINLINE _CUDA_D void ScalarInplace::transformCudaLegacy(int opNum, void* vscalar, - void *vy, Nd4jLong *yShapeInfo, - void *vparams, - void *vz, Nd4jLong *zShapeInfo, - int *allocationBuffer) { - - DISPATCH_BY_OPNUM_TTT(transformCuda, PARAMS(vscalar, vy, yShapeInfo, vparams, vz, zShapeInfo, allocationBuffer), SCALAR_OPS); - } - - template - template - FORCEINLINE _CUDA_D void ScalarInplace::transformCuda(void* vscalar, - void *vy, Nd4jLong *yShapeInfo, - void *vparams, - void *vz, Nd4jLong *zShapeInfo, - int *allocationBuffer) { - - auto scalar = reinterpret_cast(vscalar)[0]; - auto y = reinterpret_cast(vy); - auto params = reinterpret_cast(vparams); - auto z = reinterpret_cast(vz); - - int totalThreads = gridDim.x * blockDim.x; - int tid = blockIdx.x * blockDim.x + threadIdx.x; - - __shared__ Nd4jLong length; - if(threadIdx.x == 0) - length = shape::length(yShapeInfo); - __syncthreads(); - +namespace scalar { +template +class ScalarInplace { + public: + static FORCEINLINE _CUDA_D void transformCudaLegacy( + int opNum, void *vscalar, void *vy, Nd4jLong *yShapeInfo, void *vparams, + void *vz, Nd4jLong *zShapeInfo, int *allocationBuffer); + + template + static FORCEINLINE _CUDA_D void transformCuda(void *vscalar, void *vy, + Nd4jLong *yShapeInfo, + void *vparams, void *vz, + Nd4jLong *zShapeInfo, + int *allocationBuffer); +}; + +template +FORCEINLINE _CUDA_D void ScalarInplace::transformCudaLegacy( + int opNum, void *vscalar, void *vy, Nd4jLong *yShapeInfo, void *vparams, + void *vz, Nd4jLong *zShapeInfo, int *allocationBuffer) { + DISPATCH_BY_OPNUM_TTT(transformCuda, + PARAMS(vscalar, vy, yShapeInfo, vparams, vz, zShapeInfo, + allocationBuffer), + SCALAR_OPS); +} - for (Nd4jLong i = tid; i < length; i+= totalThreads) { - z[shape::getIndexOffset(i, zShapeInfo)] = OpType::op(y[shape::getIndexOffset(i, yShapeInfo)], scalar, params); - } - } - } +template +template +FORCEINLINE _CUDA_D void ScalarInplace::transformCuda( + void *vscalar, void *vy, Nd4jLong *yShapeInfo, void *vparams, void *vz, + Nd4jLong *zShapeInfo, int *allocationBuffer) { + auto scalar = reinterpret_cast(vscalar)[0]; + auto y = reinterpret_cast(vy); + auto params = reinterpret_cast(vparams); + auto z = reinterpret_cast(vz); + + int totalThreads = gridDim.x * blockDim.x; + int tid = blockIdx.x * blockDim.x + threadIdx.x; + + __shared__ Nd4jLong length; + if (threadIdx.x == 0) length = shape::length(yShapeInfo); + __syncthreads(); + + for (Nd4jLong i = tid; i < length; i += totalThreads) { + z[shape::getIndexOffset(i, zShapeInfo)] = + OpType::op(y[shape::getIndexOffset(i, yShapeInfo)], scalar, params); + } } +} // namespace scalar +} // namespace functions -#endif //DEV_TESTS_SCALAR_INPLACE_H +#endif // SD_SCALAR_INPLACE_H diff --git a/libnd4j/include/loops/cuda/inplace_loops/transform_strict_inplace.h b/libnd4j/include/loops/cuda/inplace_loops/transform_strict_inplace.h index c4b94fca5651..78cbcb40d897 100644 --- a/libnd4j/include/loops/cuda/inplace_loops/transform_strict_inplace.h +++ b/libnd4j/include/loops/cuda/inplace_loops/transform_strict_inplace.h @@ -18,82 +18,77 @@ // @author raver119@gmail.com // -#ifndef DEV_TESTS_TRANSFORM_FLOAT_INPLACE_H -#define DEV_TESTS_TRANSFORM_FLOAT_INPLACE_H +#ifndef SD_TRANSFORM_FLOAT_INPLACE_H +#define SD_TRANSFORM_FLOAT_INPLACE_H +#include #include -#include #include -#include +#include using namespace simdOps; -#define LOCAL_TRANSFORM_STRICT_OPS \ - (23, Exp), \ - (24, Log) +#define LOCAL_TRANSFORM_STRICT_OPS (23, Exp), (24, Log) namespace functions { - namespace transform { - template - class TransformStrictInplace { - public: - static FORCEINLINE _CUDA_D void transformCudaLegacy(int opNum, void *dy, Nd4jLong *shapeInfo, void *params, void *result, Nd4jLong *zShapeInfo, int *allocationPointer, void *reductionPointer, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets); - - template - static FORCEINLINE _CUDA_D void transformCuda(void *vdy, Nd4jLong *shapeInfo, void *vparams, void *vresult, Nd4jLong *zShapeInfo, int *allocationPointer, void *vreductionPointer, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets); - }; - - template - template - FORCEINLINE _CUDA_D void TransformStrictInplace::transformCuda( - void *vdy, - Nd4jLong *shapeInfo, - void *vparams, - void *vresult, - Nd4jLong *zShapeInfo, - int *allocationPointer, void *vreductionPointer, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) { - - auto dy = static_cast(vdy); - auto result = static_cast(vresult); - auto params = static_cast(vparams); - auto reductionPointer = static_cast(vreductionPointer); - - auto xOrder = shape::order(shapeInfo); - auto zOrder = shape::order(zShapeInfo); - - auto xEws = shape::elementWiseStride(shapeInfo); - auto zEws = shape::elementWiseStride(zShapeInfo); - auto tid = blockIdx.x * blockDim.x + threadIdx.x; - - __shared__ Nd4jLong length; - if(threadIdx.x == 0) - length = shape::length(shapeInfo); - __syncthreads(); - - - for (Nd4jLong i = tid; i < length; i+= gridDim.x * blockDim.x) { - auto xOffset2 = shape::getIndexOffset(i, shapeInfo); - auto zOffset2 = shape::getIndexOffset(i, zShapeInfo); - result[zOffset2] = OpType::op(dy[xOffset2], params); - } - } +namespace transform { +template +class TransformStrictInplace { + public: + static FORCEINLINE _CUDA_D void transformCudaLegacy( + int opNum, void *dy, Nd4jLong *shapeInfo, void *params, void *result, + Nd4jLong *zShapeInfo, int *allocationPointer, void *reductionPointer, + Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets); + + template + static FORCEINLINE _CUDA_D void transformCuda( + void *vdy, Nd4jLong *shapeInfo, void *vparams, void *vresult, + Nd4jLong *zShapeInfo, int *allocationPointer, void *vreductionPointer, + Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets); +}; + +template +template +FORCEINLINE _CUDA_D void TransformStrictInplace::transformCuda( + void *vdy, Nd4jLong *shapeInfo, void *vparams, void *vresult, + Nd4jLong *zShapeInfo, int *allocationPointer, void *vreductionPointer, + Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) { + auto dy = static_cast(vdy); + auto result = static_cast(vresult); + auto params = static_cast(vparams); + auto reductionPointer = static_cast(vreductionPointer); + + auto xOrder = shape::order(shapeInfo); + auto zOrder = shape::order(zShapeInfo); + + auto xEws = shape::elementWiseStride(shapeInfo); + auto zEws = shape::elementWiseStride(zShapeInfo); + auto tid = blockIdx.x * blockDim.x + threadIdx.x; + + __shared__ Nd4jLong length; + if (threadIdx.x == 0) length = shape::length(shapeInfo); + __syncthreads(); + + for (Nd4jLong i = tid; i < length; i += gridDim.x * blockDim.x) { + auto xOffset2 = shape::getIndexOffset(i, shapeInfo); + auto zOffset2 = shape::getIndexOffset(i, zShapeInfo); + result[zOffset2] = OpType::op(dy[xOffset2], params); + } +} - template - FORCEINLINE _CUDA_D void TransformStrictInplace::transformCudaLegacy( - int opNum, - void *dy, - Nd4jLong *shapeInfo, - void *params, - void *result, - Nd4jLong *zShapeInfo, - int *allocationPointer, - void *reductionPointer, - Nd4jLong *tadShapeInfo, - Nd4jLong *tadOffsets) { - DISPATCH_BY_OPNUM_T(transformCuda, PARAMS(dy, shapeInfo, params, result, zShapeInfo, allocationPointer, reductionPointer, tadShapeInfo, tadOffsets), LOCAL_TRANSFORM_STRICT_OPS); - } - } +template +FORCEINLINE _CUDA_D void TransformStrictInplace::transformCudaLegacy( + int opNum, void *dy, Nd4jLong *shapeInfo, void *params, void *result, + Nd4jLong *zShapeInfo, int *allocationPointer, void *reductionPointer, + Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) { + DISPATCH_BY_OPNUM_T( + transformCuda, + PARAMS(dy, shapeInfo, params, result, zShapeInfo, allocationPointer, + reductionPointer, tadShapeInfo, tadOffsets), + LOCAL_TRANSFORM_STRICT_OPS); } +} // namespace transform +} // namespace functions #undef LOCAL_TRANSFORM_STRICT_OPS -#endif //DEV_TESTS_TRANSFORM_FLOAT_INPLACE_H +#endif // SD_TRANSFORM_FLOAT_INPLACE_H diff --git a/libnd4j/include/loops/cuda/legacy/transform.legacy b/libnd4j/include/loops/cuda/legacy/transform.legacy index 88a4ceb168cb..691913cd875c 100644 --- a/libnd4j/include/loops/cuda/legacy/transform.legacy +++ b/libnd4j/include/loops/cuda/legacy/transform.legacy @@ -297,9 +297,9 @@ namespace functions { } - //template class ND4J_EXPORT Transform; - //template class ND4J_EXPORT Transform; - //template class ND4J_EXPORT Transform; + //template class SD_EXPORT Transform; + //template class SD_EXPORT Transform; + //template class SD_EXPORT Transform; BUILD_CALL_1(template __device__ void Transform::transformCuda, float, (float*, Nd4jLong*, float*, float*,Nd4jLong*, int*,float*, UnifiedSharedMemory*, Nd4jLong*, Nd4jLong*), TRANSFORM_OPS) BUILD_CALL_1(template __device__ void Transform::transformCuda, float16, (float16*, Nd4jLong*, float16*, float16*,Nd4jLong*, int*, float16*, UnifiedSharedMemory*, Nd4jLong*, Nd4jLong*), TRANSFORM_OPS) diff --git a/libnd4j/include/loops/cuda/pairwise.chpp b/libnd4j/include/loops/cuda/pairwise.chpp index ee2c01695a73..92c8f6607cc2 100644 --- a/libnd4j/include/loops/cuda/pairwise.chpp +++ b/libnd4j/include/loops/cuda/pairwise.chpp @@ -20,104 +20,109 @@ #ifndef PAIRWISE_CU #define PAIRWISE_CU - #include "../pairwise_transform.h" - using namespace simdOps; //////////////////////////////////////////////////////////////////////////////// template -__global__ static void pairwiseSimpleShaped(void const* vx, Nd4jLong const* xShapeInfo, - void const* vy, Nd4jLong const* yShapeInfo, - void *vz, Nd4jLong const* zShapeInfo, - void *vextraParams) { - - auto x = reinterpret_cast(vx); - auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); - auto extraParams = reinterpret_cast(vextraParams); - - int tid = blockIdx.x * blockDim.x + threadIdx.x; - - __shared__ int xEws; - __shared__ int yEws; - __shared__ int zEws; - __shared__ char xOrder; - __shared__ char yOrder; - __shared__ char zOrder; - __shared__ Nd4jLong len; - - if (threadIdx.x == 0) { - xEws = shape::elementWiseStride(xShapeInfo); - yEws = shape::elementWiseStride(yShapeInfo); - zEws = shape::elementWiseStride(zShapeInfo); - xOrder = shape::order(xShapeInfo); - yOrder = shape::order(yShapeInfo); - zOrder = shape::order(zShapeInfo); - len = shape::length(xShapeInfo); - } - __syncthreads(); - - - if (xEws >= 1 && yEws >= 1 && zEws >= 1 && xOrder == yOrder && xOrder == zOrder) { - for (Nd4jLong i = tid; i < len; i += gridDim.x * blockDim.x) { - z[i * zEws] = OpType::op(x[i * xEws], y[i * yEws], extraParams); - } - } - else if (vx == vz) { - for (Nd4jLong i = tid; i < len; i += gridDim.x * blockDim.x) { - auto xOffset = shape::getIndexOffset(i, xShapeInfo); - auto yOffset = shape::getIndexOffset(i, yShapeInfo); - - z[xOffset] = OpType::op(x[xOffset], y[yOffset], extraParams); - } - } - else { - for (Nd4jLong i = tid; i < len; i += gridDim.x * blockDim.x) { - auto xOffset = shape::getIndexOffset(i, xShapeInfo); - auto yOffset = shape::getIndexOffset(i, yShapeInfo); - auto zOffset = shape::getIndexOffset(i, zShapeInfo); - - z[zOffset] = OpType::op(x[xOffset], y[yOffset], extraParams); - } - } +__global__ static void pairwiseSimpleShaped( + void const* vx, Nd4jLong const* xShapeInfo, void const* vy, + Nd4jLong const* yShapeInfo, void* vz, Nd4jLong const* zShapeInfo, + void* vextraParams) { + auto x = reinterpret_cast(vx); + auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + auto extraParams = reinterpret_cast(vextraParams); + + int tid = blockIdx.x * blockDim.x + threadIdx.x; + + __shared__ int xEws; + __shared__ int yEws; + __shared__ int zEws; + __shared__ char xOrder; + __shared__ char yOrder; + __shared__ char zOrder; + __shared__ Nd4jLong len; + + if (threadIdx.x == 0) { + xEws = shape::elementWiseStride(xShapeInfo); + yEws = shape::elementWiseStride(yShapeInfo); + zEws = shape::elementWiseStride(zShapeInfo); + xOrder = shape::order(xShapeInfo); + yOrder = shape::order(yShapeInfo); + zOrder = shape::order(zShapeInfo); + len = shape::length(xShapeInfo); + } + __syncthreads(); + + if (xEws >= 1 && yEws >= 1 && zEws >= 1 && xOrder == yOrder && + xOrder == zOrder) { + for (Nd4jLong i = tid; i < len; i += gridDim.x * blockDim.x) { + z[i * zEws] = OpType::op(x[i * xEws], y[i * yEws], extraParams); + } + } else if (vx == vz) { + for (Nd4jLong i = tid; i < len; i += gridDim.x * blockDim.x) { + auto xOffset = shape::getIndexOffset(i, xShapeInfo); + auto yOffset = shape::getIndexOffset(i, yShapeInfo); + + z[xOffset] = OpType::op(x[xOffset], y[yOffset], extraParams); + } + } else { + for (Nd4jLong i = tid; i < len; i += gridDim.x * blockDim.x) { + auto xOffset = shape::getIndexOffset(i, xShapeInfo); + auto yOffset = shape::getIndexOffset(i, yShapeInfo); + auto zOffset = shape::getIndexOffset(i, zShapeInfo); + + z[zOffset] = OpType::op(x[xOffset], y[yOffset], extraParams); + } + } } -namespace functions { +namespace functions { namespace pairwise_transforms { //////////////////////////////////////////////////////////////////////////////// -template -template -void __host__ PairWiseTransform::intermediateShaped(dim3& launchDims, cudaStream_t *stream, - void const* vx, Nd4jLong const* xShapeInfo, - void const* vy, Nd4jLong const* yShapeInfo, - void *vz, Nd4jLong const* zShapeInfo, - void *vextraParams){ - - pairwiseSimpleShaped<<>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, vextraParams); +template +template +void __host__ PairWiseTransform::intermediateShaped( + dim3& launchDims, cudaStream_t* stream, void const* vx, + Nd4jLong const* xShapeInfo, void const* vy, Nd4jLong const* yShapeInfo, + void* vz, Nd4jLong const* zShapeInfo, void* vextraParams) { + pairwiseSimpleShaped + <<>>( + vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, vextraParams); } //////////////////////////////////////////////////////////////////////////////// -template -void __host__ PairWiseTransform::executeCudaShaped(dim3& launchDims, cudaStream_t *stream, int opNum, void const* vx, Nd4jLong const* xShapeInfo, void const* vy, Nd4jLong const* yShapeInfo, void *vz, Nd4jLong const* zShapeInfo, void* vextraParams) { - DISPATCH_BY_OPNUM_TTT(intermediateShaped, PARAMS(launchDims, stream, vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, vextraParams), PAIRWISE_TRANSFORM_OPS); +template +void __host__ PairWiseTransform::executeCudaShaped( + dim3& launchDims, cudaStream_t* stream, int opNum, void const* vx, + Nd4jLong const* xShapeInfo, void const* vy, Nd4jLong const* yShapeInfo, + void* vz, Nd4jLong const* zShapeInfo, void* vextraParams) { + DISPATCH_BY_OPNUM_TTT(intermediateShaped, + PARAMS(launchDims, stream, vx, xShapeInfo, vy, + yShapeInfo, vz, zShapeInfo, vextraParams), + PAIRWISE_TRANSFORM_OPS); } /* - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT PairWiseTransform, , PAIRWISE_TYPES_0); - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT PairWiseTransform, , PAIRWISE_TYPES_1); - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT PairWiseTransform, , PAIRWISE_TYPES_2); - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT PairWiseTransform, , PAIRWISE_TYPES_3); - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT PairWiseTransform, , PAIRWISE_TYPES_4); - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT PairWiseTransform, , PAIRWISE_TYPES_5); - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT PairWiseTransform, , PAIRWISE_TYPES_6); - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT PairWiseTransform, , PAIRWISE_TYPES_7); - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT PairWiseTransform, , PAIRWISE_TYPES_8); - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT PairWiseTransform, , PAIRWISE_TYPES_9); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , + PAIRWISE_TYPES_0); BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT + PairWiseTransform, , PAIRWISE_TYPES_1); BUILD_PAIRWISE_TEMPLATE(template + class SD_EXPORT PairWiseTransform, , PAIRWISE_TYPES_2); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , + PAIRWISE_TYPES_3); BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT + PairWiseTransform, , PAIRWISE_TYPES_4); BUILD_PAIRWISE_TEMPLATE(template + class SD_EXPORT PairWiseTransform, , PAIRWISE_TYPES_5); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , + PAIRWISE_TYPES_6); BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT + PairWiseTransform, , PAIRWISE_TYPES_7); BUILD_PAIRWISE_TEMPLATE(template + class SD_EXPORT PairWiseTransform, , PAIRWISE_TYPES_8); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , + PAIRWISE_TYPES_9); */ -} -} +} // namespace pairwise_transforms +} // namespace functions -#endif // PAIRWISE_CU \ No newline at end of file +#endif // PAIRWISE_CU \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/pairwise.cu b/libnd4j/include/loops/cuda/pairwise.cu index 4833d32d0927..869200ba901e 100644 --- a/libnd4j/include/loops/cuda/pairwise.cu +++ b/libnd4j/include/loops/cuda/pairwise.cu @@ -21,7 +21,5 @@ #include "../pairwise_transform.h" namespace functions { - namespace pairwise_transforms { - - } -} \ No newline at end of file +namespace pairwise_transforms {} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/pairwise_bool.cu b/libnd4j/include/loops/cuda/pairwise_bool.cu index 29cc90f2c8dc..c52ec95b611f 100644 --- a/libnd4j/include/loops/cuda/pairwise_bool.cu +++ b/libnd4j/include/loops/cuda/pairwise_bool.cu @@ -20,98 +20,98 @@ #ifndef PAIRWISE_BOOL_CU #define PAIRWISE_BOOL_CU - #include "../pairwise_bool.h" - using namespace simdOps; //////////////////////////////////////////////////////////////////////////////// template -__global__ static void pairwiseSimpleShaped(void const* vx, Nd4jLong const* xShapeInfo, - void const* vy, Nd4jLong const* yShapeInfo, - void *vz, Nd4jLong const* zShapeInfo, - void *vextraParams) { - - auto x = reinterpret_cast(vx); - auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); - auto extraParams = reinterpret_cast(vextraParams); - - int tid = blockIdx.x * blockDim.x + threadIdx.x; - - __shared__ int xEws; - __shared__ int yEws; - __shared__ int zEws; - __shared__ char xOrder; - __shared__ char yOrder; - __shared__ char zOrder; - __shared__ Nd4jLong len; - - if (threadIdx.x == 0) { - xEws = shape::elementWiseStride(xShapeInfo); - yEws = shape::elementWiseStride(yShapeInfo); - zEws = shape::elementWiseStride(zShapeInfo); - xOrder = shape::order(xShapeInfo); - yOrder = shape::order(yShapeInfo); - zOrder = shape::order(zShapeInfo); - len = shape::length(xShapeInfo); - } - __syncthreads(); - - - if (xEws >= 1 && yEws >= 1 && zEws >= 1 && xOrder == yOrder && xOrder == zOrder) { - for (Nd4jLong i = tid; i < len; i += gridDim.x * blockDim.x) { - z[i * zEws] = OpType::op(x[i * xEws], y[i * yEws], extraParams); - } - } - else if (vx == vz) { - for (Nd4jLong i = tid; i < len; i += gridDim.x * blockDim.x) { - auto xOffset = shape::getIndexOffset(i, xShapeInfo); - auto yOffset = shape::getIndexOffset(i, yShapeInfo); - - z[xOffset] = OpType::op(x[xOffset], y[yOffset], extraParams); - } - } - else { - for (Nd4jLong i = tid; i < len; i += gridDim.x * blockDim.x) { - auto xOffset = shape::getIndexOffset(i, xShapeInfo); - auto yOffset = shape::getIndexOffset(i, yShapeInfo); - auto zOffset = shape::getIndexOffset(i, zShapeInfo); - - z[zOffset] = OpType::op(x[xOffset], y[yOffset], extraParams); - } - } +__global__ static void pairwiseSimpleShaped( + void const* vx, Nd4jLong const* xShapeInfo, void const* vy, + Nd4jLong const* yShapeInfo, void* vz, Nd4jLong const* zShapeInfo, + void* vextraParams) { + auto x = reinterpret_cast(vx); + auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + auto extraParams = reinterpret_cast(vextraParams); + + int tid = blockIdx.x * blockDim.x + threadIdx.x; + + __shared__ int xEws; + __shared__ int yEws; + __shared__ int zEws; + __shared__ char xOrder; + __shared__ char yOrder; + __shared__ char zOrder; + __shared__ Nd4jLong len; + + if (threadIdx.x == 0) { + xEws = shape::elementWiseStride(xShapeInfo); + yEws = shape::elementWiseStride(yShapeInfo); + zEws = shape::elementWiseStride(zShapeInfo); + xOrder = shape::order(xShapeInfo); + yOrder = shape::order(yShapeInfo); + zOrder = shape::order(zShapeInfo); + len = shape::length(xShapeInfo); + } + __syncthreads(); + + if (xEws >= 1 && yEws >= 1 && zEws >= 1 && xOrder == yOrder && + xOrder == zOrder) { + for (Nd4jLong i = tid; i < len; i += gridDim.x * blockDim.x) { + z[i * zEws] = OpType::op(x[i * xEws], y[i * yEws], extraParams); + } + } else if (vx == vz) { + for (Nd4jLong i = tid; i < len; i += gridDim.x * blockDim.x) { + auto xOffset = shape::getIndexOffset(i, xShapeInfo); + auto yOffset = shape::getIndexOffset(i, yShapeInfo); + + z[xOffset] = OpType::op(x[xOffset], y[yOffset], extraParams); + } + } else { + for (Nd4jLong i = tid; i < len; i += gridDim.x * blockDim.x) { + auto xOffset = shape::getIndexOffset(i, xShapeInfo); + auto yOffset = shape::getIndexOffset(i, yShapeInfo); + auto zOffset = shape::getIndexOffset(i, zShapeInfo); + + z[zOffset] = OpType::op(x[xOffset], y[yOffset], extraParams); + } + } } - -namespace functions { +namespace functions { namespace pairwise_transforms { //////////////////////////////////////////////////////////////////////////////// -template -template -void _CUDA_H PairWiseBoolTransform::intermediateShaped(dim3& launchDims, cudaStream_t *stream, - void const* vx, Nd4jLong const* xShapeInfo, - void const* vy, Nd4jLong const* yShapeInfo, - void *vz, Nd4jLong const* zShapeInfo, - void *vextraParams){ - - pairwiseSimpleShaped<<>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, vextraParams); +template +template +void _CUDA_H PairWiseBoolTransform::intermediateShaped( + dim3& launchDims, cudaStream_t* stream, void const* vx, + Nd4jLong const* xShapeInfo, void const* vy, Nd4jLong const* yShapeInfo, + void* vz, Nd4jLong const* zShapeInfo, void* vextraParams) { + pairwiseSimpleShaped + <<>>( + vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, vextraParams); } - //////////////////////////////////////////////////////////////////////////////// -template -void PairWiseBoolTransform::executeCudaShaped(dim3& launchDims, cudaStream_t *stream, int opNum, void const* vx, Nd4jLong const* xShapeInfo, void const* vy, Nd4jLong const* yShapeInfo, void *vz, Nd4jLong const* zShapeInfo, void *vextraParams) { - auto xType = sd::DataTypeUtils::fromT(); - auto yType = sd::DataTypeUtils::fromT(); - - DISPATCH_BY_OPNUM_TT(intermediateShaped, PARAMS(launchDims, stream, vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, vextraParams), PAIRWISE_BOOL_OPS); +template +void PairWiseBoolTransform::executeCudaShaped( + dim3& launchDims, cudaStream_t* stream, int opNum, void const* vx, + Nd4jLong const* xShapeInfo, void const* vy, Nd4jLong const* yShapeInfo, + void* vz, Nd4jLong const* zShapeInfo, void* vextraParams) { + auto xType = sd::DataTypeUtils::fromT(); + auto yType = sd::DataTypeUtils::fromT(); + + DISPATCH_BY_OPNUM_TT(intermediateShaped, + PARAMS(launchDims, stream, vx, xShapeInfo, vy, + yShapeInfo, vz, zShapeInfo, vextraParams), + PAIRWISE_BOOL_OPS); } - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT PairWiseBoolTransform, , LIBND4J_TYPES, BOOL_TYPES); -} -} +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT PairWiseBoolTransform, , + LIBND4J_TYPES, BOOL_TYPES); +} // namespace pairwise_transforms +} // namespace functions -#endif // PAIRWISE_BOOL_CU \ No newline at end of file +#endif // PAIRWISE_BOOL_CU \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/pairwise_int.cu b/libnd4j/include/loops/cuda/pairwise_int.cu index 740995cee728..039a5e682f51 100644 --- a/libnd4j/include/loops/cuda/pairwise_int.cu +++ b/libnd4j/include/loops/cuda/pairwise_int.cu @@ -20,97 +20,97 @@ #ifndef PAIRWISE_INT_CU #define PAIRWISE_INT_CU - #include "../pairwise_int.h" - using namespace simdOps; //////////////////////////////////////////////////////////////////////////////// template -__global__ static void pairwiseSimpleShaped(void const* vx, Nd4jLong const* xShapeInfo, - void const* vy, Nd4jLong const* yShapeInfo, - void *vz, Nd4jLong const* zShapeInfo, - void *vextraParams) { - - auto x = reinterpret_cast(vx); - auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); - auto extraParams = reinterpret_cast(vextraParams); - - int tid = blockIdx.x * blockDim.x + threadIdx.x; - - __shared__ int xEws; - __shared__ int yEws; - __shared__ int zEws; - __shared__ char xOrder; - __shared__ char yOrder; - __shared__ char zOrder; - __shared__ Nd4jLong len; - - if (threadIdx.x == 0) { - xEws = shape::elementWiseStride(xShapeInfo); - yEws = shape::elementWiseStride(yShapeInfo); - zEws = shape::elementWiseStride(zShapeInfo); - xOrder = shape::order(xShapeInfo); - yOrder = shape::order(yShapeInfo); - zOrder = shape::order(zShapeInfo); - len = shape::length(xShapeInfo); - } - __syncthreads(); - - - if (xEws >= 1 && yEws >= 1 && zEws >= 1 && xOrder == yOrder && xOrder == zOrder) { - for (Nd4jLong i = tid; i < len; i += gridDim.x * blockDim.x) { - z[i * zEws] = OpType::op(x[i * xEws], y[i * yEws], extraParams); - } - } - else if (vx == vz) { - for (Nd4jLong i = tid; i < len; i += gridDim.x * blockDim.x) { - auto xOffset = shape::getIndexOffset(i, xShapeInfo); - auto yOffset = shape::getIndexOffset(i, yShapeInfo); - - z[xOffset] = OpType::op(x[xOffset], y[yOffset], extraParams); - } - } - else { - for (Nd4jLong i = tid; i < len; i += gridDim.x * blockDim.x) { - auto xOffset = shape::getIndexOffset(i, xShapeInfo); - auto yOffset = shape::getIndexOffset(i, yShapeInfo); - auto zOffset = shape::getIndexOffset(i, zShapeInfo); - - z[zOffset] = OpType::op(x[xOffset], y[yOffset], extraParams); - } - } +__global__ static void pairwiseSimpleShaped( + void const* vx, Nd4jLong const* xShapeInfo, void const* vy, + Nd4jLong const* yShapeInfo, void* vz, Nd4jLong const* zShapeInfo, + void* vextraParams) { + auto x = reinterpret_cast(vx); + auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + auto extraParams = reinterpret_cast(vextraParams); + + int tid = blockIdx.x * blockDim.x + threadIdx.x; + + __shared__ int xEws; + __shared__ int yEws; + __shared__ int zEws; + __shared__ char xOrder; + __shared__ char yOrder; + __shared__ char zOrder; + __shared__ Nd4jLong len; + + if (threadIdx.x == 0) { + xEws = shape::elementWiseStride(xShapeInfo); + yEws = shape::elementWiseStride(yShapeInfo); + zEws = shape::elementWiseStride(zShapeInfo); + xOrder = shape::order(xShapeInfo); + yOrder = shape::order(yShapeInfo); + zOrder = shape::order(zShapeInfo); + len = shape::length(xShapeInfo); + } + __syncthreads(); + + if (xEws >= 1 && yEws >= 1 && zEws >= 1 && xOrder == yOrder && + xOrder == zOrder) { + for (Nd4jLong i = tid; i < len; i += gridDim.x * blockDim.x) { + z[i * zEws] = OpType::op(x[i * xEws], y[i * yEws], extraParams); + } + } else if (vx == vz) { + for (Nd4jLong i = tid; i < len; i += gridDim.x * blockDim.x) { + auto xOffset = shape::getIndexOffset(i, xShapeInfo); + auto yOffset = shape::getIndexOffset(i, yShapeInfo); + + z[xOffset] = OpType::op(x[xOffset], y[yOffset], extraParams); + } + } else { + for (Nd4jLong i = tid; i < len; i += gridDim.x * blockDim.x) { + auto xOffset = shape::getIndexOffset(i, xShapeInfo); + auto yOffset = shape::getIndexOffset(i, yShapeInfo); + auto zOffset = shape::getIndexOffset(i, zShapeInfo); + + z[zOffset] = OpType::op(x[xOffset], y[yOffset], extraParams); + } + } } - -namespace functions { +namespace functions { namespace pairwise_transforms { //////////////////////////////////////////////////////////////////////////////// -template -template -void _CUDA_H PairWiseIntTransform::intermediateShaped(dim3& launchDims, cudaStream_t *stream, - void const* vx, Nd4jLong const* xShapeInfo, - void const* vy, Nd4jLong const* yShapeInfo, - void *vz, Nd4jLong const* zShapeInfo, - void *vextraParams){ - - pairwiseSimpleShaped<<>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, vextraParams); +template +template +void _CUDA_H PairWiseIntTransform::intermediateShaped( + dim3& launchDims, cudaStream_t* stream, void const* vx, + Nd4jLong const* xShapeInfo, void const* vy, Nd4jLong const* yShapeInfo, + void* vz, Nd4jLong const* zShapeInfo, void* vextraParams) { + pairwiseSimpleShaped + <<>>( + vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, vextraParams); } - //////////////////////////////////////////////////////////////////////////////// -template -void PairWiseIntTransform::executeCudaShaped(dim3& launchDims, cudaStream_t *stream, int opNum, void const* vx, Nd4jLong const* xShapeInfo, void const* vy, Nd4jLong const* yShapeInfo, void *vz, Nd4jLong const* zShapeInfo, void *vextraParams) { - auto xType = sd::DataTypeUtils::fromT(); - - DISPATCH_BY_OPNUM_T(intermediateShaped, PARAMS(launchDims, stream, vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, vextraParams), PAIRWISE_INT_OPS); +template +void PairWiseIntTransform::executeCudaShaped( + dim3& launchDims, cudaStream_t* stream, int opNum, void const* vx, + Nd4jLong const* xShapeInfo, void const* vy, Nd4jLong const* yShapeInfo, + void* vz, Nd4jLong const* zShapeInfo, void* vextraParams) { + auto xType = sd::DataTypeUtils::fromT(); + + DISPATCH_BY_OPNUM_T(intermediateShaped, + PARAMS(launchDims, stream, vx, xShapeInfo, vy, yShapeInfo, + vz, zShapeInfo, vextraParams), + PAIRWISE_INT_OPS); } - BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT PairWiseIntTransform, , INTEGER_TYPES); -} -} +BUILD_SINGLE_TEMPLATE(template class SD_EXPORT PairWiseIntTransform, , + INTEGER_TYPES); +} // namespace pairwise_transforms +} // namespace functions -#endif // PAIRWISE_INT_CU \ No newline at end of file +#endif // PAIRWISE_INT_CU \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/random.cu b/libnd4j/include/loops/cuda/random.cu index 755763293561..472c06510877 100644 --- a/libnd4j/include/loops/cuda/random.cu +++ b/libnd4j/include/loops/cuda/random.cu @@ -18,429 +18,524 @@ // @author raver119@gmail.com // -#include -#include -#include #include #include #include +#include #include +#include +#include using namespace randomOps; template -static inline __device__ void randomSingleGeneric( - Nd4jPointer state, - void *z, - Nd4jLong const* zShapeBuffer, - void *extraArguments) { - - - functions::random::RandomFunction::template execTransformCuda( - state, - z, - zShapeBuffer, - extraArguments); +static inline __device__ void randomSingleGeneric(Nd4jPointer state, void* z, + Nd4jLong const* zShapeBuffer, + void* extraArguments) { + functions::random::RandomFunction::template execTransformCuda( + state, z, zShapeBuffer, extraArguments); } template static inline __device__ void randomDoubleGeneric( - Nd4jPointer state, - void const* x, - Nd4jLong const* xShapeBuffer, - void *z, - Nd4jLong const* zShapeBuffer, - void *extraArguments) { - - - functions::random::RandomFunction::template execTransformCuda( - state, - x, - xShapeBuffer, - z, - zShapeBuffer, - extraArguments); + Nd4jPointer state, void const* x, Nd4jLong const* xShapeBuffer, void* z, + Nd4jLong const* zShapeBuffer, void* extraArguments) { + functions::random::RandomFunction::template execTransformCuda( + state, x, xShapeBuffer, z, zShapeBuffer, extraArguments); } - template static inline __device__ void randomTripleGeneric( - Nd4jPointer state, - void const* x, - Nd4jLong const* xShapeBuffer, - void const* y, - Nd4jLong const* yShapeBuffer, - void *z, - Nd4jLong const* zShapeBuffer, - void *extraArguments) { - - - functions::random::RandomFunction::template execTransformCuda( - state, - x, - xShapeBuffer, - y, - yShapeBuffer, - z, - zShapeBuffer, - extraArguments); + Nd4jPointer state, void const* x, Nd4jLong const* xShapeBuffer, + void const* y, Nd4jLong const* yShapeBuffer, void* z, + Nd4jLong const* zShapeBuffer, void* extraArguments) { + functions::random::RandomFunction::template execTransformCuda( + state, x, xShapeBuffer, y, yShapeBuffer, z, zShapeBuffer, extraArguments); } - #ifndef __CLION_IDE__ // here we generate kernels for target operations -DISPATCH_KERNEL_SIMPLE(randomSingle_, randomSingleGeneric, float, INPUT(Nd4jPointer state, void *z, Nd4jLong const* zShapeBuffer, void *extraArguments), PARAMS(state, z, zShapeBuffer, extraArguments), OPS_A(RANDOM_OPS)) -DISPATCH_KERNEL_SIMPLE(randomSingle_, randomSingleGeneric, double, INPUT(Nd4jPointer state, void *z, Nd4jLong const* zShapeBuffer, void *extraArguments), PARAMS(state, z, zShapeBuffer, extraArguments), OPS_A(RANDOM_OPS)) -DISPATCH_KERNEL_SIMPLE(randomSingle_, randomSingleGeneric, float16, INPUT(Nd4jPointer state, void *z, Nd4jLong const* zShapeBuffer, void *extraArguments), PARAMS(state, z, zShapeBuffer, extraArguments), OPS_A(RANDOM_OPS)) -DISPATCH_KERNEL_SIMPLE(randomSingle_, randomSingleGeneric, bfloat16, INPUT(Nd4jPointer state, void *z, Nd4jLong const* zShapeBuffer, void *extraArguments), PARAMS(state, z, zShapeBuffer, extraArguments), OPS_A(RANDOM_OPS)) - -DISPATCH_KERNEL_SIMPLE(randomDouble_, randomDoubleGeneric, float, INPUT(Nd4jPointer state, void const* x, Nd4jLong const* xShapeBuffer, void *z, Nd4jLong const* zShapeBuffer, void *extraArguments), PARAMS(state, x, xShapeBuffer, z, zShapeBuffer, extraArguments), OPS_A(RANDOM_OPS)) -DISPATCH_KERNEL_SIMPLE(randomDouble_, randomDoubleGeneric, double, INPUT(Nd4jPointer state, void const* x, Nd4jLong const* xShapeBuffer, void *z, Nd4jLong const* zShapeBuffer, void *extraArguments), PARAMS(state, x, xShapeBuffer, z, zShapeBuffer, extraArguments), OPS_A(RANDOM_OPS)) -DISPATCH_KERNEL_SIMPLE(randomDouble_, randomDoubleGeneric, float16, INPUT(Nd4jPointer state, void const* x, Nd4jLong const* xShapeBuffer, void *z, Nd4jLong const* zShapeBuffer, void *extraArguments), PARAMS(state, x, xShapeBuffer, z, zShapeBuffer, extraArguments), OPS_A(RANDOM_OPS)) -DISPATCH_KERNEL_SIMPLE(randomDouble_, randomDoubleGeneric, bfloat16, INPUT(Nd4jPointer state, void const* x, Nd4jLong const* xShapeBuffer, void *z, Nd4jLong const* zShapeBuffer, void *extraArguments), PARAMS(state, x, xShapeBuffer, z, zShapeBuffer, extraArguments), OPS_A(RANDOM_OPS)) - -DISPATCH_KERNEL_SIMPLE(randomTriple_, randomTripleGeneric, float, INPUT(Nd4jPointer state, void const* x, Nd4jLong const* xShapeBuffer, void const* y, Nd4jLong const* yShapeBuffer, void *z, Nd4jLong const* zShapeBuffer, void *extraArguments), PARAMS(state, x, xShapeBuffer, y, yShapeBuffer, z, zShapeBuffer, extraArguments), OPS_A(RANDOM_OPS)) -DISPATCH_KERNEL_SIMPLE(randomTriple_, randomTripleGeneric, double, INPUT(Nd4jPointer state, void const* x, Nd4jLong const* xShapeBuffer, void const* y, Nd4jLong const* yShapeBuffer, void *z, Nd4jLong const* zShapeBuffer, void *extraArguments), PARAMS(state, x, xShapeBuffer, y, yShapeBuffer, z, zShapeBuffer, extraArguments), OPS_A(RANDOM_OPS)) -DISPATCH_KERNEL_SIMPLE(randomTriple_, randomTripleGeneric, float16, INPUT(Nd4jPointer state, void const* x, Nd4jLong const* xShapeBuffer, void const* y, Nd4jLong const* yShapeBuffer, void *z, Nd4jLong const* zShapeBuffer, void *extraArguments), PARAMS(state, x, xShapeBuffer, y, yShapeBuffer, z, zShapeBuffer, extraArguments), OPS_A(RANDOM_OPS)) -DISPATCH_KERNEL_SIMPLE(randomTriple_, randomTripleGeneric, bfloat16, INPUT(Nd4jPointer state, void const* x, Nd4jLong const* xShapeBuffer, void const* y, Nd4jLong const* yShapeBuffer, void *z, Nd4jLong const* zShapeBuffer, void *extraArguments), PARAMS(state, x, xShapeBuffer, y, yShapeBuffer, z, zShapeBuffer, extraArguments), OPS_A(RANDOM_OPS)) +DISPATCH_KERNEL_SIMPLE(randomSingle_, randomSingleGeneric, float, + INPUT(Nd4jPointer state, void* z, + Nd4jLong const* zShapeBuffer, + void* extraArguments), + PARAMS(state, z, zShapeBuffer, extraArguments), + OPS_A(RANDOM_OPS)) +DISPATCH_KERNEL_SIMPLE(randomSingle_, randomSingleGeneric, double, + INPUT(Nd4jPointer state, void* z, + Nd4jLong const* zShapeBuffer, + void* extraArguments), + PARAMS(state, z, zShapeBuffer, extraArguments), + OPS_A(RANDOM_OPS)) +DISPATCH_KERNEL_SIMPLE(randomSingle_, randomSingleGeneric, float16, + INPUT(Nd4jPointer state, void* z, + Nd4jLong const* zShapeBuffer, + void* extraArguments), + PARAMS(state, z, zShapeBuffer, extraArguments), + OPS_A(RANDOM_OPS)) +DISPATCH_KERNEL_SIMPLE(randomSingle_, randomSingleGeneric, bfloat16, + INPUT(Nd4jPointer state, void* z, + Nd4jLong const* zShapeBuffer, + void* extraArguments), + PARAMS(state, z, zShapeBuffer, extraArguments), + OPS_A(RANDOM_OPS)) + +DISPATCH_KERNEL_SIMPLE( + randomDouble_, randomDoubleGeneric, float, + INPUT(Nd4jPointer state, void const* x, Nd4jLong const* xShapeBuffer, + void* z, Nd4jLong const* zShapeBuffer, void* extraArguments), + PARAMS(state, x, xShapeBuffer, z, zShapeBuffer, extraArguments), + OPS_A(RANDOM_OPS)) +DISPATCH_KERNEL_SIMPLE( + randomDouble_, randomDoubleGeneric, double, + INPUT(Nd4jPointer state, void const* x, Nd4jLong const* xShapeBuffer, + void* z, Nd4jLong const* zShapeBuffer, void* extraArguments), + PARAMS(state, x, xShapeBuffer, z, zShapeBuffer, extraArguments), + OPS_A(RANDOM_OPS)) +DISPATCH_KERNEL_SIMPLE( + randomDouble_, randomDoubleGeneric, float16, + INPUT(Nd4jPointer state, void const* x, Nd4jLong const* xShapeBuffer, + void* z, Nd4jLong const* zShapeBuffer, void* extraArguments), + PARAMS(state, x, xShapeBuffer, z, zShapeBuffer, extraArguments), + OPS_A(RANDOM_OPS)) +DISPATCH_KERNEL_SIMPLE( + randomDouble_, randomDoubleGeneric, bfloat16, + INPUT(Nd4jPointer state, void const* x, Nd4jLong const* xShapeBuffer, + void* z, Nd4jLong const* zShapeBuffer, void* extraArguments), + PARAMS(state, x, xShapeBuffer, z, zShapeBuffer, extraArguments), + OPS_A(RANDOM_OPS)) + +DISPATCH_KERNEL_SIMPLE(randomTriple_, randomTripleGeneric, float, + INPUT(Nd4jPointer state, void const* x, + Nd4jLong const* xShapeBuffer, void const* y, + Nd4jLong const* yShapeBuffer, void* z, + Nd4jLong const* zShapeBuffer, + void* extraArguments), + PARAMS(state, x, xShapeBuffer, y, yShapeBuffer, z, + zShapeBuffer, extraArguments), + OPS_A(RANDOM_OPS)) +DISPATCH_KERNEL_SIMPLE(randomTriple_, randomTripleGeneric, double, + INPUT(Nd4jPointer state, void const* x, + Nd4jLong const* xShapeBuffer, void const* y, + Nd4jLong const* yShapeBuffer, void* z, + Nd4jLong const* zShapeBuffer, + void* extraArguments), + PARAMS(state, x, xShapeBuffer, y, yShapeBuffer, z, + zShapeBuffer, extraArguments), + OPS_A(RANDOM_OPS)) +DISPATCH_KERNEL_SIMPLE(randomTriple_, randomTripleGeneric, float16, + INPUT(Nd4jPointer state, void const* x, + Nd4jLong const* xShapeBuffer, void const* y, + Nd4jLong const* yShapeBuffer, void* z, + Nd4jLong const* zShapeBuffer, + void* extraArguments), + PARAMS(state, x, xShapeBuffer, y, yShapeBuffer, z, + zShapeBuffer, extraArguments), + OPS_A(RANDOM_OPS)) +DISPATCH_KERNEL_SIMPLE(randomTriple_, randomTripleGeneric, bfloat16, + INPUT(Nd4jPointer state, void const* x, + Nd4jLong const* xShapeBuffer, void const* y, + Nd4jLong const* yShapeBuffer, void* z, + Nd4jLong const* zShapeBuffer, + void* extraArguments), + PARAMS(state, x, xShapeBuffer, y, yShapeBuffer, z, + zShapeBuffer, extraArguments), + OPS_A(RANDOM_OPS)) #endif namespace functions { - namespace random { - template - template - void _CUDA_D RandomFunction::execTransformCuda(Nd4jPointer state, void const* vx, Nd4jLong const* xShapeBuffer, void const* vy, Nd4jLong const* yShapeBuffer, void *vz, Nd4jLong const* zShapeBuffer, void *vextraArguments) { - - auto x = reinterpret_cast(vx); - auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); - auto extraArguments = reinterpret_cast(vextraArguments); - - if (OpClass::requiresSpecial) { - OpClass::specialOpCuda(state, x, xShapeBuffer, y, yShapeBuffer, z, zShapeBuffer, extraArguments); - return; - } else { - - __shared__ Nd4jLong length; - __shared__ int xEWS; - __shared__ int yEWS; - __shared__ int zEWS; - __shared__ char xOrder; - __shared__ char yOrder; - __shared__ char zOrder; - - __shared__ sd::graph::RandomGenerator *buffer; - __shared__ unsigned char *cB; - __shared__ unsigned char *dB; - sd::graph::RandomGenerator *devBuffer; - if (threadIdx.x == 0) { - length = shape::length(zShapeBuffer); - xEWS = shape::elementWiseStride(xShapeBuffer); - yEWS = shape::elementWiseStride(yShapeBuffer); - zEWS = shape::elementWiseStride(zShapeBuffer); - xOrder = shape::order(xShapeBuffer); - yOrder = shape::order(yShapeBuffer); - zOrder = shape::order(zShapeBuffer); - - extern __shared__ unsigned char shmem[]; - buffer = (sd::graph::RandomGenerator *) shmem; - cB = shmem; - devBuffer = reinterpret_cast (state); - dB = reinterpret_cast (state); - } - __syncthreads(); - - // using this loop instead of memcpy - for (int e = threadIdx.x; e < sizeof(sd::graph::RandomGenerator); e+= blockDim.x) - cB[e] = dB[e]; - - __syncthreads(); - - - int tid = blockIdx.x * blockDim.x + threadIdx.x; - - if (xEWS >= 1 && yEWS >= 1 && zEWS >= 1 && xOrder == yOrder && xOrder == zOrder) { - for (Nd4jLong e = tid; e < length; e += blockDim.x * gridDim.x) { - z[e * zEWS] = OpClass::op(x[e * xEWS], y[e * yEWS], e, length, buffer, extraArguments); - } - } else { - for (Nd4jLong i = tid; i < length; i += blockDim.x * gridDim.x) { - - auto xOffset2 = shape::getIndexOffset(i, xShapeBuffer); - auto yOffset2 = shape::getIndexOffset(i, yShapeBuffer); - auto zOffset2 = shape::getIndexOffset(i, zShapeBuffer); - - z[zOffset2] = OpClass::op(x[xOffset2], y[yOffset2], i, length, buffer, extraArguments); - } - } - } - }; - - - template - template - void _CUDA_D RandomFunction::execTransformCuda(Nd4jPointer state, void const* vx, Nd4jLong const* xShapeBuffer, void* vz, Nd4jLong const* zShapeBuffer, void *vextraArguments) { - - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - auto extraArguments = reinterpret_cast(vextraArguments); - - __shared__ Nd4jLong length; - __shared__ int xEWS; - __shared__ int zEWS; - __shared__ char xOrder; - __shared__ char zOrder; - - __shared__ sd::graph::RandomGenerator *buffer; - __shared__ unsigned char *cB; - __shared__ unsigned char *dB; - __shared__ sd::graph::RandomGenerator *devBuffer; - - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - buffer = (sd::graph::RandomGenerator *) shmem; - cB = shmem; - devBuffer = reinterpret_cast (state); - dB = reinterpret_cast (state); - - length = shape::length(zShapeBuffer); - xEWS = shape::elementWiseStride(xShapeBuffer); - zEWS = shape::elementWiseStride(zShapeBuffer); - xOrder = shape::order(xShapeBuffer); - zOrder = shape::order(zShapeBuffer); - } - __syncthreads(); - - // using this loop instead of memcpy - for (int e = threadIdx.x; e < sizeof(sd::graph::RandomGenerator); e+= blockDim.x) - cB[e] = dB[e]; - - __syncthreads(); - - - if (xEWS >= 1 && zEWS >= 1 && xOrder == zOrder) { - for (Nd4jLong e = blockIdx.x * blockDim.x + threadIdx.x; e < length; e += blockDim.x * gridDim.x) { - z[e * zEWS] = OpClass::op(x[e * xEWS], e, length, buffer, extraArguments); - } - } else { - - for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < length; i += blockDim.x * gridDim.x) { - - auto xOffset2 = shape::getIndexOffset(i, xShapeBuffer); - auto zOffset2 = shape::getIndexOffset(i, zShapeBuffer); - - z[zOffset2] = OpClass::op(x[xOffset2], i, length, buffer, extraArguments); - } - } - } - - - template - template - void _CUDA_D RandomFunction::execTransformCuda(Nd4jPointer state, void *vz, Nd4jLong const* zShapeBuffer, void *vextraArguments) { - - auto z = reinterpret_cast(vz); - auto extraArguments = reinterpret_cast(vextraArguments); - - __shared__ Nd4jLong length; - __shared__ Nd4jLong ews; - __shared__ sd::graph::RandomGenerator *buffer; - __shared__ unsigned char *cB; - __shared__ unsigned char *dB; - __shared__ sd::graph::RandomGenerator *devBuffer; - - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - buffer = (sd::graph::RandomGenerator *) shmem; - cB = shmem; - devBuffer = reinterpret_cast (state); - dB = reinterpret_cast (state); - length = shape::length(zShapeBuffer); - ews = shape::elementWiseStride(zShapeBuffer); - } - __syncthreads(); - - // using this loop instead of memcpy - for (int e = threadIdx.x; e < sizeof(sd::graph::RandomGenerator); e+= blockDim.x) - cB[e] = dB[e]; - - __syncthreads(); - - int tid = blockIdx.x * blockDim.x + threadIdx.x; - - if (ews > 0) { - for (Nd4jLong i = tid; i < length; i += blockDim.x * gridDim.x) { - z[i * ews] = OpClass::op(i, length, buffer, extraArguments); - } - } else { - - for (Nd4jLong i = tid; i < length; i += blockDim.x * gridDim.x) { - auto zOffset2 = shape::getIndexOffset(i, zShapeBuffer); - z[zOffset2] = OpClass::op(i, length, buffer, extraArguments); - } - } - } - - template <> - _CUDA_H void RandomFunction::executeCudaSingle(dim3& launchDims, cudaStream_t* stream, int opNum, Nd4jPointer stateHost, void *vz, Nd4jLong const* zShapeBuffer, void *vextraArguments) { - - auto z = reinterpret_cast(vz); - auto extraArguments = reinterpret_cast(vextraArguments); - - // this macro builds bunch of IF/ELSE selectors for kernel launch - DISPATCH_SIMPLE(randomSingle, float, PARAMS(stateHost, z, zShapeBuffer, extraArguments), OPS_A(RANDOM_OPS)) - - DEBUG_KERNEL(stream, opNum); - } - - template <> - _CUDA_H void RandomFunction::executeCudaSingle(dim3& launchDims, cudaStream_t* stream, int opNum, Nd4jPointer stateHost, void *vz, Nd4jLong const* zShapeBuffer, void *vextraArguments) { - - auto z = reinterpret_cast(vz); - auto extraArguments = reinterpret_cast(vextraArguments); - - // this macro builds bunch of IF/ELSE selectors for kernel launch - DISPATCH_SIMPLE(randomSingle, float16, PARAMS(stateHost, z, zShapeBuffer, extraArguments), OPS_A(RANDOM_OPS)) - - DEBUG_KERNEL(stream, opNum); - } - - template <> - _CUDA_H void RandomFunction::executeCudaSingle(dim3& launchDims, cudaStream_t* stream, int opNum, Nd4jPointer stateHost, void *vz, Nd4jLong const* zShapeBuffer, void *vextraArguments) { - - auto z = reinterpret_cast(vz); - auto extraArguments = reinterpret_cast(vextraArguments); - - // this macro builds bunch of IF/ELSE selectors for kernel launch - DISPATCH_SIMPLE(randomSingle, bfloat16, PARAMS(stateHost, z, zShapeBuffer, extraArguments), OPS_A(RANDOM_OPS)) - - DEBUG_KERNEL(stream, opNum); - } - - template <> - _CUDA_H void RandomFunction::executeCudaSingle(dim3& launchDims, cudaStream_t *stream, int opNum, Nd4jPointer stateHost, void *vz, Nd4jLong const* zShapeBuffer, void *vextraArguments) { - - auto z = reinterpret_cast(vz); - auto extraArguments = reinterpret_cast(vextraArguments); - - // this macro builds bunch of IF/ELSE selectors for kernel launch - DISPATCH_SIMPLE(randomSingle, double, PARAMS(stateHost, z, zShapeBuffer, extraArguments), OPS_A(RANDOM_OPS)) - - DEBUG_KERNEL(stream, opNum); - } - - template <> - _CUDA_H void RandomFunction::executeCudaDouble(dim3& launchDims, cudaStream_t* stream, int opNum, Nd4jPointer stateHost, void const* vx, Nd4jLong const* xShapeBuffer, void *vz, Nd4jLong const* zShapeBuffer, void *vextraArguments) { - - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - auto extraArguments = reinterpret_cast(vextraArguments); - - // this macro builds bunch of IF/ELSE selectors for kernel launch - DISPATCH_SIMPLE(randomDouble, float, PARAMS(stateHost, x, xShapeBuffer, z, zShapeBuffer, extraArguments), OPS_A(RANDOM_OPS)) - - DEBUG_KERNEL(stream, opNum); - } - - - template <> - _CUDA_H void RandomFunction::executeCudaDouble(dim3& launchDims, cudaStream_t* stream, int opNum, Nd4jPointer stateHost, void const* vx, Nd4jLong const* xShapeBuffer, void *vz, Nd4jLong const* zShapeBuffer, void *vextraArguments) { - - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - auto extraArguments = reinterpret_cast(vextraArguments); - - // this macro builds bunch of IF/ELSE selectors for kernel launch - DISPATCH_SIMPLE(randomDouble, float16, PARAMS(stateHost, x, xShapeBuffer, z, zShapeBuffer, extraArguments), OPS_A(RANDOM_OPS)) - - DEBUG_KERNEL(stream, opNum); - } - - template <> - _CUDA_H void RandomFunction::executeCudaDouble(dim3& launchDims, cudaStream_t* stream, int opNum, Nd4jPointer stateHost, void const* vx, Nd4jLong const* xShapeBuffer, void *vz, Nd4jLong const* zShapeBuffer, void *vextraArguments) { - - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - auto extraArguments = reinterpret_cast(vextraArguments); - - // this macro builds bunch of IF/ELSE selectors for kernel launch - DISPATCH_SIMPLE(randomDouble, bfloat16, PARAMS(stateHost, x, xShapeBuffer, z, zShapeBuffer, extraArguments), OPS_A(RANDOM_OPS)) - - DEBUG_KERNEL(stream, opNum); - } - - template <> - _CUDA_H void RandomFunction::executeCudaDouble(dim3& launchDims, cudaStream_t* stream, int opNum, Nd4jPointer stateHost, void const* vx, Nd4jLong const* xShapeBuffer, void *vz, Nd4jLong const* zShapeBuffer, void *vextraArguments) { +namespace random { +template +template +void _CUDA_D RandomFunction::execTransformCuda( + Nd4jPointer state, void const* vx, Nd4jLong const* xShapeBuffer, + void const* vy, Nd4jLong const* yShapeBuffer, void* vz, + Nd4jLong const* zShapeBuffer, void* vextraArguments) { + auto x = reinterpret_cast(vx); + auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + auto extraArguments = reinterpret_cast(vextraArguments); + + if (OpClass::requiresSpecial) { + OpClass::specialOpCuda(state, x, xShapeBuffer, y, yShapeBuffer, z, + zShapeBuffer, extraArguments); + return; + } else { + __shared__ Nd4jLong length; + __shared__ int xEWS; + __shared__ int yEWS; + __shared__ int zEWS; + __shared__ char xOrder; + __shared__ char yOrder; + __shared__ char zOrder; + + __shared__ sd::graph::RandomGenerator* buffer; + __shared__ unsigned char* cB; + __shared__ unsigned char* dB; + sd::graph::RandomGenerator* devBuffer; + if (threadIdx.x == 0) { + length = shape::length(zShapeBuffer); + xEWS = shape::elementWiseStride(xShapeBuffer); + yEWS = shape::elementWiseStride(yShapeBuffer); + zEWS = shape::elementWiseStride(zShapeBuffer); + xOrder = shape::order(xShapeBuffer); + yOrder = shape::order(yShapeBuffer); + zOrder = shape::order(zShapeBuffer); + + extern __shared__ unsigned char shmem[]; + buffer = (sd::graph::RandomGenerator*)shmem; + cB = shmem; + devBuffer = reinterpret_cast(state); + dB = reinterpret_cast(state); + } + __syncthreads(); + + // using this loop instead of memcpy + for (int e = threadIdx.x; e < sizeof(sd::graph::RandomGenerator); + e += blockDim.x) + cB[e] = dB[e]; + + __syncthreads(); + + int tid = blockIdx.x * blockDim.x + threadIdx.x; + + if (xEWS >= 1 && yEWS >= 1 && zEWS >= 1 && xOrder == yOrder && + xOrder == zOrder) { + for (Nd4jLong e = tid; e < length; e += blockDim.x * gridDim.x) { + z[e * zEWS] = OpClass::op(x[e * xEWS], y[e * yEWS], e, length, buffer, + extraArguments); + } + } else { + for (Nd4jLong i = tid; i < length; i += blockDim.x * gridDim.x) { + auto xOffset2 = shape::getIndexOffset(i, xShapeBuffer); + auto yOffset2 = shape::getIndexOffset(i, yShapeBuffer); + auto zOffset2 = shape::getIndexOffset(i, zShapeBuffer); + + z[zOffset2] = OpClass::op(x[xOffset2], y[yOffset2], i, length, buffer, + extraArguments); + } + } + } +}; + +template +template +void _CUDA_D RandomFunction::execTransformCuda( + Nd4jPointer state, void const* vx, Nd4jLong const* xShapeBuffer, void* vz, + Nd4jLong const* zShapeBuffer, void* vextraArguments) { + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + auto extraArguments = reinterpret_cast(vextraArguments); + + __shared__ Nd4jLong length; + __shared__ int xEWS; + __shared__ int zEWS; + __shared__ char xOrder; + __shared__ char zOrder; + + __shared__ sd::graph::RandomGenerator* buffer; + __shared__ unsigned char* cB; + __shared__ unsigned char* dB; + __shared__ sd::graph::RandomGenerator* devBuffer; + + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + buffer = (sd::graph::RandomGenerator*)shmem; + cB = shmem; + devBuffer = reinterpret_cast(state); + dB = reinterpret_cast(state); + + length = shape::length(zShapeBuffer); + xEWS = shape::elementWiseStride(xShapeBuffer); + zEWS = shape::elementWiseStride(zShapeBuffer); + xOrder = shape::order(xShapeBuffer); + zOrder = shape::order(zShapeBuffer); + } + __syncthreads(); + + // using this loop instead of memcpy + for (int e = threadIdx.x; e < sizeof(sd::graph::RandomGenerator); + e += blockDim.x) + cB[e] = dB[e]; + + __syncthreads(); + + if (xEWS >= 1 && zEWS >= 1 && xOrder == zOrder) { + for (Nd4jLong e = blockIdx.x * blockDim.x + threadIdx.x; e < length; + e += blockDim.x * gridDim.x) { + z[e * zEWS] = OpClass::op(x[e * xEWS], e, length, buffer, extraArguments); + } + } else { + for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < length; + i += blockDim.x * gridDim.x) { + auto xOffset2 = shape::getIndexOffset(i, xShapeBuffer); + auto zOffset2 = shape::getIndexOffset(i, zShapeBuffer); - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - auto extraArguments = reinterpret_cast(vextraArguments); + z[zOffset2] = OpClass::op(x[xOffset2], i, length, buffer, extraArguments); + } + } +} - // this macro builds bunch of IF/ELSE selectors for kernel launch - DISPATCH_SIMPLE(randomDouble, double, PARAMS(stateHost, x, xShapeBuffer, z, zShapeBuffer, extraArguments), OPS_A(RANDOM_OPS)) +template +template +void _CUDA_D RandomFunction::execTransformCuda(Nd4jPointer state, void* vz, + Nd4jLong const* zShapeBuffer, + void* vextraArguments) { + auto z = reinterpret_cast(vz); + auto extraArguments = reinterpret_cast(vextraArguments); + + __shared__ Nd4jLong length; + __shared__ Nd4jLong ews; + __shared__ sd::graph::RandomGenerator* buffer; + __shared__ unsigned char* cB; + __shared__ unsigned char* dB; + __shared__ sd::graph::RandomGenerator* devBuffer; + + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + buffer = (sd::graph::RandomGenerator*)shmem; + cB = shmem; + devBuffer = reinterpret_cast(state); + dB = reinterpret_cast(state); + length = shape::length(zShapeBuffer); + ews = shape::elementWiseStride(zShapeBuffer); + } + __syncthreads(); + + // using this loop instead of memcpy + for (int e = threadIdx.x; e < sizeof(sd::graph::RandomGenerator); + e += blockDim.x) + cB[e] = dB[e]; + + __syncthreads(); + + int tid = blockIdx.x * blockDim.x + threadIdx.x; + + if (ews > 0) { + for (Nd4jLong i = tid; i < length; i += blockDim.x * gridDim.x) { + z[i * ews] = OpClass::op(i, length, buffer, extraArguments); + } + } else { + for (Nd4jLong i = tid; i < length; i += blockDim.x * gridDim.x) { + auto zOffset2 = shape::getIndexOffset(i, zShapeBuffer); + z[zOffset2] = OpClass::op(i, length, buffer, extraArguments); + } + } +} - DEBUG_KERNEL(stream, opNum); - } +template <> +_CUDA_H void RandomFunction::executeCudaSingle( + dim3& launchDims, cudaStream_t* stream, int opNum, Nd4jPointer stateHost, + void* vz, Nd4jLong const* zShapeBuffer, void* vextraArguments) { + auto z = reinterpret_cast(vz); + auto extraArguments = reinterpret_cast(vextraArguments); - template <> - _CUDA_H void RandomFunction::executeCudaTriple(dim3& launchDims, cudaStream_t* stream, int opNum, Nd4jPointer stateHost, void const* vx, Nd4jLong const* xShapeBuffer, void const* vy, Nd4jLong const* yShapeBuffer, void *vz, Nd4jLong const* zShapeBuffer, void *vextraArguments) { + // this macro builds bunch of IF/ELSE selectors for kernel launch + DISPATCH_SIMPLE(randomSingle, float, + PARAMS(stateHost, z, zShapeBuffer, extraArguments), + OPS_A(RANDOM_OPS)) - auto x = reinterpret_cast(vx); - auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); - auto extraArguments = reinterpret_cast(vextraArguments); + DEBUG_KERNEL(stream, opNum); +} - // this macro builds bunch of IF/ELSE selectors for kernel launch - DISPATCH_SIMPLE(randomTriple, float, PARAMS(stateHost, x, xShapeBuffer, y, yShapeBuffer, z, zShapeBuffer, extraArguments), OPS_A(RANDOM_OPS)) +template <> +_CUDA_H void RandomFunction::executeCudaSingle( + dim3& launchDims, cudaStream_t* stream, int opNum, Nd4jPointer stateHost, + void* vz, Nd4jLong const* zShapeBuffer, void* vextraArguments) { + auto z = reinterpret_cast(vz); + auto extraArguments = reinterpret_cast(vextraArguments); - DEBUG_KERNEL(stream, opNum); - } + // this macro builds bunch of IF/ELSE selectors for kernel launch + DISPATCH_SIMPLE(randomSingle, float16, + PARAMS(stateHost, z, zShapeBuffer, extraArguments), + OPS_A(RANDOM_OPS)) - template <> - _CUDA_H void RandomFunction::executeCudaTriple(dim3& launchDims, cudaStream_t* stream, int opNum, Nd4jPointer stateHost, void const* vx, Nd4jLong const* xShapeBuffer, void const* vy, Nd4jLong const* yShapeBuffer, void *vz, Nd4jLong const* zShapeBuffer, void *vextraArguments) { + DEBUG_KERNEL(stream, opNum); +} - auto x = reinterpret_cast(vx); - auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); - auto extraArguments = reinterpret_cast(vextraArguments); +template <> +_CUDA_H void RandomFunction::executeCudaSingle( + dim3& launchDims, cudaStream_t* stream, int opNum, Nd4jPointer stateHost, + void* vz, Nd4jLong const* zShapeBuffer, void* vextraArguments) { + auto z = reinterpret_cast(vz); + auto extraArguments = reinterpret_cast(vextraArguments); - // this macro builds bunch of IF/ELSE selectors for kernel launch - DISPATCH_SIMPLE(randomTriple, float16, PARAMS(stateHost, x, xShapeBuffer, y, yShapeBuffer, z, zShapeBuffer, extraArguments), OPS_A(RANDOM_OPS)) + // this macro builds bunch of IF/ELSE selectors for kernel launch + DISPATCH_SIMPLE(randomSingle, bfloat16, + PARAMS(stateHost, z, zShapeBuffer, extraArguments), + OPS_A(RANDOM_OPS)) - DEBUG_KERNEL(stream, opNum); - } + DEBUG_KERNEL(stream, opNum); +} - template <> - _CUDA_H void RandomFunction::executeCudaTriple(dim3& launchDims, cudaStream_t* stream, int opNum, Nd4jPointer stateHost, void const* vx, Nd4jLong const* xShapeBuffer, void const* vy, Nd4jLong const* yShapeBuffer, void *vz, Nd4jLong const* zShapeBuffer, void *vextraArguments) { +template <> +_CUDA_H void RandomFunction::executeCudaSingle( + dim3& launchDims, cudaStream_t* stream, int opNum, Nd4jPointer stateHost, + void* vz, Nd4jLong const* zShapeBuffer, void* vextraArguments) { + auto z = reinterpret_cast(vz); + auto extraArguments = reinterpret_cast(vextraArguments); - auto x = reinterpret_cast(vx); - auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); - auto extraArguments = reinterpret_cast(vextraArguments); + // this macro builds bunch of IF/ELSE selectors for kernel launch + DISPATCH_SIMPLE(randomSingle, double, + PARAMS(stateHost, z, zShapeBuffer, extraArguments), + OPS_A(RANDOM_OPS)) - // this macro builds bunch of IF/ELSE selectors for kernel launch - DISPATCH_SIMPLE(randomTriple, bfloat16, PARAMS(stateHost, x, xShapeBuffer, y, yShapeBuffer, z, zShapeBuffer, extraArguments), OPS_A(RANDOM_OPS)) + DEBUG_KERNEL(stream, opNum); +} - DEBUG_KERNEL(stream, opNum); - } +template <> +_CUDA_H void RandomFunction::executeCudaDouble( + dim3& launchDims, cudaStream_t* stream, int opNum, Nd4jPointer stateHost, + void const* vx, Nd4jLong const* xShapeBuffer, void* vz, + Nd4jLong const* zShapeBuffer, void* vextraArguments) { + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + auto extraArguments = reinterpret_cast(vextraArguments); + + // this macro builds bunch of IF/ELSE selectors for kernel launch + DISPATCH_SIMPLE( + randomDouble, float, + PARAMS(stateHost, x, xShapeBuffer, z, zShapeBuffer, extraArguments), + OPS_A(RANDOM_OPS)) + + DEBUG_KERNEL(stream, opNum); +} +template <> +_CUDA_H void RandomFunction::executeCudaDouble( + dim3& launchDims, cudaStream_t* stream, int opNum, Nd4jPointer stateHost, + void const* vx, Nd4jLong const* xShapeBuffer, void* vz, + Nd4jLong const* zShapeBuffer, void* vextraArguments) { + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + auto extraArguments = reinterpret_cast(vextraArguments); + + // this macro builds bunch of IF/ELSE selectors for kernel launch + DISPATCH_SIMPLE( + randomDouble, float16, + PARAMS(stateHost, x, xShapeBuffer, z, zShapeBuffer, extraArguments), + OPS_A(RANDOM_OPS)) + + DEBUG_KERNEL(stream, opNum); +} +template <> +_CUDA_H void RandomFunction::executeCudaDouble( + dim3& launchDims, cudaStream_t* stream, int opNum, Nd4jPointer stateHost, + void const* vx, Nd4jLong const* xShapeBuffer, void* vz, + Nd4jLong const* zShapeBuffer, void* vextraArguments) { + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + auto extraArguments = reinterpret_cast(vextraArguments); + + // this macro builds bunch of IF/ELSE selectors for kernel launch + DISPATCH_SIMPLE( + randomDouble, bfloat16, + PARAMS(stateHost, x, xShapeBuffer, z, zShapeBuffer, extraArguments), + OPS_A(RANDOM_OPS)) + + DEBUG_KERNEL(stream, opNum); +} - template <> - _CUDA_H void RandomFunction::executeCudaTriple(dim3& launchDims, cudaStream_t* stream, int opNum, Nd4jPointer stateHost, void const* vx, Nd4jLong const* xShapeBuffer, void const* vy, Nd4jLong const* yShapeBuffer, void *vz, Nd4jLong const* zShapeBuffer, void *vextraArguments) { +template <> +_CUDA_H void RandomFunction::executeCudaDouble( + dim3& launchDims, cudaStream_t* stream, int opNum, Nd4jPointer stateHost, + void const* vx, Nd4jLong const* xShapeBuffer, void* vz, + Nd4jLong const* zShapeBuffer, void* vextraArguments) { + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + auto extraArguments = reinterpret_cast(vextraArguments); + + // this macro builds bunch of IF/ELSE selectors for kernel launch + DISPATCH_SIMPLE( + randomDouble, double, + PARAMS(stateHost, x, xShapeBuffer, z, zShapeBuffer, extraArguments), + OPS_A(RANDOM_OPS)) + + DEBUG_KERNEL(stream, opNum); +} - auto x = reinterpret_cast(vx); - auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); - auto extraArguments = reinterpret_cast(vextraArguments); +template <> +_CUDA_H void RandomFunction::executeCudaTriple( + dim3& launchDims, cudaStream_t* stream, int opNum, Nd4jPointer stateHost, + void const* vx, Nd4jLong const* xShapeBuffer, void const* vy, + Nd4jLong const* yShapeBuffer, void* vz, Nd4jLong const* zShapeBuffer, + void* vextraArguments) { + auto x = reinterpret_cast(vx); + auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + auto extraArguments = reinterpret_cast(vextraArguments); + + // this macro builds bunch of IF/ELSE selectors for kernel launch + DISPATCH_SIMPLE(randomTriple, float, + PARAMS(stateHost, x, xShapeBuffer, y, yShapeBuffer, z, + zShapeBuffer, extraArguments), + OPS_A(RANDOM_OPS)) + + DEBUG_KERNEL(stream, opNum); +} - // this macro builds bunch of IF/ELSE selectors for kernel launch - DISPATCH_SIMPLE(randomTriple, double, PARAMS(stateHost, x, xShapeBuffer, y, yShapeBuffer, z, zShapeBuffer, extraArguments), OPS_A(RANDOM_OPS)) +template <> +_CUDA_H void RandomFunction::executeCudaTriple( + dim3& launchDims, cudaStream_t* stream, int opNum, Nd4jPointer stateHost, + void const* vx, Nd4jLong const* xShapeBuffer, void const* vy, + Nd4jLong const* yShapeBuffer, void* vz, Nd4jLong const* zShapeBuffer, + void* vextraArguments) { + auto x = reinterpret_cast(vx); + auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + auto extraArguments = reinterpret_cast(vextraArguments); + + // this macro builds bunch of IF/ELSE selectors for kernel launch + DISPATCH_SIMPLE(randomTriple, float16, + PARAMS(stateHost, x, xShapeBuffer, y, yShapeBuffer, z, + zShapeBuffer, extraArguments), + OPS_A(RANDOM_OPS)) + + DEBUG_KERNEL(stream, opNum); +} - DEBUG_KERNEL(stream, opNum); - } +template <> +_CUDA_H void RandomFunction::executeCudaTriple( + dim3& launchDims, cudaStream_t* stream, int opNum, Nd4jPointer stateHost, + void const* vx, Nd4jLong const* xShapeBuffer, void const* vy, + Nd4jLong const* yShapeBuffer, void* vz, Nd4jLong const* zShapeBuffer, + void* vextraArguments) { + auto x = reinterpret_cast(vx); + auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + auto extraArguments = reinterpret_cast(vextraArguments); + + // this macro builds bunch of IF/ELSE selectors for kernel launch + DISPATCH_SIMPLE(randomTriple, bfloat16, + PARAMS(stateHost, x, xShapeBuffer, y, yShapeBuffer, z, + zShapeBuffer, extraArguments), + OPS_A(RANDOM_OPS)) + + DEBUG_KERNEL(stream, opNum); +} - BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT RandomFunction, , FLOAT_TYPES); - } +template <> +_CUDA_H void RandomFunction::executeCudaTriple( + dim3& launchDims, cudaStream_t* stream, int opNum, Nd4jPointer stateHost, + void const* vx, Nd4jLong const* xShapeBuffer, void const* vy, + Nd4jLong const* yShapeBuffer, void* vz, Nd4jLong const* zShapeBuffer, + void* vextraArguments) { + auto x = reinterpret_cast(vx); + auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + auto extraArguments = reinterpret_cast(vextraArguments); + + // this macro builds bunch of IF/ELSE selectors for kernel launch + DISPATCH_SIMPLE(randomTriple, double, + PARAMS(stateHost, x, xShapeBuffer, y, yShapeBuffer, z, + zShapeBuffer, extraArguments), + OPS_A(RANDOM_OPS)) + + DEBUG_KERNEL(stream, opNum); } + +BUILD_SINGLE_TEMPLATE(template class SD_EXPORT RandomFunction, , FLOAT_TYPES); +} // namespace random +} // namespace functions diff --git a/libnd4j/include/loops/cuda/reduce/reduce_bool.cu b/libnd4j/include/loops/cuda/reduce/reduce_bool.cu index de854416db30..75d63b19c52f 100644 --- a/libnd4j/include/loops/cuda/reduce/reduce_bool.cu +++ b/libnd4j/include/loops/cuda/reduce/reduce_bool.cu @@ -19,327 +19,352 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include -#include -#include -#include -#include -#include #include +#include +#include +#include +#include #include - +#include +#include using namespace simdOps; //////////////////////////////////////////////////////////////////////// template __global__ void simpleReduce(const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *z, const Nd4jLong *zShapeInfo, - int *dimension, int dimensionLength, - void *reductionBuffer, - const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets) { - - functions::reduce::ReduceBoolFunction::template transformCudaXD(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength, reductionBuffer, tadOnlyShapeInfo, tadOffsets); + void *extraParams, void *z, + const Nd4jLong *zShapeInfo, int *dimension, + int dimensionLength, void *reductionBuffer, + const Nd4jLong *tadOnlyShapeInfo, + const Nd4jLong *tadOffsets) { + functions::reduce::ReduceBoolFunction::template transformCudaXD( + x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength, + reductionBuffer, tadOnlyShapeInfo, tadOffsets); } //////////////////////////////////////////////////////////////////////// template __global__ void simpleScalar(const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *z, const Nd4jLong *zShapeInfo, - int *dimension, int dimensionLength, - void *reductionBuffer, + void *extraParams, void *z, + const Nd4jLong *zShapeInfo, int *dimension, + int dimensionLength, void *reductionBuffer, const Nd4jLong *tadOnlyShapeInfo) { - - functions::reduce::ReduceBoolFunction::template execScalarCuda(x, xShapeInfo, extraParams, z, zShapeInfo, reductionBuffer, tadOnlyShapeInfo); + functions::reduce::ReduceBoolFunction::template execScalarCuda( + x, xShapeInfo, extraParams, z, zShapeInfo, reductionBuffer, + tadOnlyShapeInfo); } - namespace functions { -namespace reduce { +namespace reduce { //////////////////////////////////////////////////////////////////////// template template -__device__ void ReduceBoolFunction::aggregatePartials(void *vsPartials, Nd4jLong tid, Nd4jLong numItems, void *vextraParams) { +__device__ void ReduceBoolFunction::aggregatePartials( + void *vsPartials, Nd4jLong tid, Nd4jLong numItems, void *vextraParams) { + // start the shared memory loop on the next power of 2 less + // than the block size. If block size is not a power of 2, + // accumulate the intermediate sums in the remainder range. - // start the shared memory loop on the next power of 2 less - // than the block size. If block size is not a power of 2, - // accumulate the intermediate sums in the remainder range. + auto sPartials = reinterpret_cast(vsPartials); + auto extraParams = reinterpret_cast(vextraParams); - auto sPartials = reinterpret_cast(vsPartials); - auto extraParams = reinterpret_cast(vextraParams); + Nd4jLong floorPow2 = numItems; - Nd4jLong floorPow2 = numItems; + if (floorPow2 & (floorPow2 - 1)) { + while (floorPow2 & (floorPow2 - 1)) floorPow2 &= floorPow2 - 1; - if (floorPow2 & (floorPow2 - 1)) { + if (tid >= floorPow2) + sPartials[tid - floorPow2] = OpType::update(sPartials[tid - floorPow2], + sPartials[tid], extraParams); - while (floorPow2 & (floorPow2 - 1)) - floorPow2 &= floorPow2 - 1; - - if (tid >= floorPow2) - sPartials[tid - floorPow2] = OpType::update(sPartials[tid - floorPow2], sPartials[tid], extraParams); - - __syncthreads(); - } + __syncthreads(); + } - for (Nd4jLong activeThreads = floorPow2 >> 1; activeThreads; activeThreads >>= 1) { - if (tid < activeThreads && tid + activeThreads < numItems) - sPartials[tid] = OpType::update(sPartials[tid], sPartials[tid + activeThreads], extraParams); + for (Nd4jLong activeThreads = floorPow2 >> 1; activeThreads; + activeThreads >>= 1) { + if (tid < activeThreads && tid + activeThreads < numItems) + sPartials[tid] = OpType::update( + sPartials[tid], sPartials[tid + activeThreads], extraParams); - __syncthreads(); - } + __syncthreads(); + } } //////////////////////////////////////////////////////////////////////// template template -__device__ void ReduceBoolFunction::transformCudaXD(const void *vx, const Nd4jLong *xShapeInfo, - void *vextraParams, - void *vz, const Nd4jLong *zShapeInfo, - int *dimension, int dimensionLength, - void *vreductionBuffer, - const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets) { - - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - auto extraParams = reinterpret_cast(vextraParams); - auto reductionBuffer = reinterpret_cast(vreductionBuffer); - - //shared memory space for storing intermediate results - __shared__ Z* sPartials; - __shared__ int tadLength, numTads; - __shared__ bool isPlainOutput; - - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - sPartials = reinterpret_cast(shmem); - - isPlainOutput = shape::order(zShapeInfo) == 'c' && shape::elementWiseStride(zShapeInfo) == 1; - - tadLength = shape::length(tadOnlyShapeInfo); //tadLength(xShapeInfo, dimension, dimensionLength); - numTads = shape::length(xShapeInfo) / tadLength; +__device__ void ReduceBoolFunction::transformCudaXD( + const void *vx, const Nd4jLong *xShapeInfo, void *vextraParams, void *vz, + const Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, + void *vreductionBuffer, const Nd4jLong *tadOnlyShapeInfo, + const Nd4jLong *tadOffsets) { + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + auto extraParams = reinterpret_cast(vextraParams); + auto reductionBuffer = reinterpret_cast(vreductionBuffer); + + // shared memory space for storing intermediate results + __shared__ Z *sPartials; + __shared__ int tadLength, numTads; + __shared__ bool isPlainOutput; + + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sPartials = reinterpret_cast(shmem); + + isPlainOutput = shape::order(zShapeInfo) == 'c' && + shape::elementWiseStride(zShapeInfo) == 1; + + tadLength = + shape::length(tadOnlyShapeInfo); // tadLength(xShapeInfo, + // dimension, dimensionLength); + numTads = shape::length(xShapeInfo) / tadLength; + } + __syncthreads(); + + for (int r = blockIdx.x; r < numTads; r += gridDim.x) { + Nd4jLong tadOffsetForBlock = tadOffsets[r]; + sPartials[threadIdx.x] = OpType::startingValue(x + tadOffsetForBlock); + + for (int i = threadIdx.x; i < tadLength; i += blockDim.x) { + auto xOffset = + tadOffsetForBlock + shape::getIndexOffset(i, tadOnlyShapeInfo); + sPartials[threadIdx.x] = + OpType::update(sPartials[threadIdx.x], + OpType::op(x[xOffset], extraParams), extraParams); } __syncthreads(); - for (int r = blockIdx.x; r < numTads; r += gridDim.x) { - - Nd4jLong tadOffsetForBlock = tadOffsets[r]; - sPartials[threadIdx.x] = OpType::startingValue(x + tadOffsetForBlock); - - for (int i = threadIdx.x; i < tadLength; i += blockDim.x) { + // aggregate. do NOT reduce for elements > tadLength + aggregatePartials(sPartials, threadIdx.x, + sd::math::nd4j_min(blockDim.x, tadLength), + extraParams); - auto xOffset = tadOffsetForBlock + shape::getIndexOffset(i, tadOnlyShapeInfo); - sPartials[threadIdx.x] = OpType::update(sPartials[threadIdx.x], OpType::op(x[xOffset], extraParams), extraParams); - } - __syncthreads(); - - // aggregate. do NOT reduce for elements > tadLength - aggregatePartials(sPartials, threadIdx.x, sd::math::nd4j_min(blockDim.x, tadLength), extraParams); - - __syncthreads(); + __syncthreads(); - if (threadIdx.x == 0) - z[isPlainOutput ? r : shape::getIndexOffset(r, zShapeInfo)] = OpType::postProcess(sPartials[threadIdx.x], tadLength, extraParams); - } + if (threadIdx.x == 0) + z[isPlainOutput ? r : shape::getIndexOffset(r, zShapeInfo)] = + OpType::postProcess(sPartials[threadIdx.x], tadLength, extraParams); + } } //////////////////////////////////////////////////////////////////////// template template -__device__ void ReduceBoolFunction::execScalarCuda(const void *vx, const Nd4jLong *xShapeInfo, - void *vextraParams, - void *vz, const Nd4jLong *zShapeInfo, - void *vreductionBuffer, - const Nd4jLong *tadOnlyShapeInfo) { - - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - auto extraParams = reinterpret_cast(vextraParams); - auto reductionBuffer = reinterpret_cast(vreductionBuffer); - - auto tid = blockDim.x * blockIdx.x + threadIdx.x; - - //shared memory space for storing intermediate results - __shared__ Z* sPartials; - __shared__ Nd4jLong xEws; - __shared__ Nd4jLong len; - - if(threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - sPartials = reinterpret_cast(shmem); - xEws = shape::elementWiseStride(xShapeInfo); - len = shape::length(xShapeInfo); - } +__device__ void ReduceBoolFunction::execScalarCuda( + const void *vx, const Nd4jLong *xShapeInfo, void *vextraParams, void *vz, + const Nd4jLong *zShapeInfo, void *vreductionBuffer, + const Nd4jLong *tadOnlyShapeInfo) { + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + auto extraParams = reinterpret_cast(vextraParams); + auto reductionBuffer = reinterpret_cast(vreductionBuffer); + + auto tid = blockDim.x * blockIdx.x + threadIdx.x; + + // shared memory space for storing intermediate results + __shared__ Z *sPartials; + __shared__ Nd4jLong xEws; + __shared__ Nd4jLong len; + + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sPartials = reinterpret_cast(shmem); + xEws = shape::elementWiseStride(xShapeInfo); + len = shape::length(xShapeInfo); + } + __syncthreads(); + + sPartials[threadIdx.x] = OpType::startingValue(x); + + if (xEws > 0) + for (int i = tid; i < len; i += (blockDim.x * gridDim.x)) + sPartials[threadIdx.x] = + OpType::update(sPartials[threadIdx.x], + OpType::op(x[i * xEws], extraParams), extraParams); + else + for (int i = tid; i < len; i += blockDim.x * gridDim.x) + sPartials[threadIdx.x] = OpType::update( + sPartials[threadIdx.x], + OpType::op(x[shape::getIndexOffset(i, xShapeInfo)], extraParams), + extraParams); + + __syncthreads(); + aggregatePartials(sPartials, threadIdx.x, + sd::math::nd4j_min(blockDim.x, len), + extraParams); + __syncthreads(); + + if (gridDim.x > 1) { + unsigned int *tc = (unsigned int *)reductionBuffer; + __shared__ bool amLast; + + tid = threadIdx.x; + if (threadIdx.x == 0) + reductionBuffer[blockIdx.x] = + sPartials[0]; // this->postProcess(sPartials[0],len,extraParams); + + __threadfence(); __syncthreads(); - sPartials[threadIdx.x] = OpType::startingValue(x); - - if (xEws > 0) - for (int i = tid; i < len; i += (blockDim.x * gridDim.x)) - sPartials[threadIdx.x] = OpType::update(sPartials[threadIdx.x], OpType::op(x[i * xEws], extraParams), extraParams); - else - for (int i = tid; i < len; i += blockDim.x * gridDim.x) - sPartials[threadIdx.x] = OpType::update(sPartials[threadIdx.x], OpType::op(x[shape::getIndexOffset(i, xShapeInfo)], extraParams), extraParams); + if (threadIdx.x == 0) { + unsigned int ticket = atomicInc(&tc[16384], gridDim.x); + amLast = (ticket == gridDim.x - 1); + } __syncthreads(); - aggregatePartials(sPartials, threadIdx.x, sd::math::nd4j_min(blockDim.x, len), extraParams); - __syncthreads(); - - if (gridDim.x > 1) { - - unsigned int *tc = (unsigned int *)reductionBuffer; - __shared__ bool amLast; - - tid = threadIdx.x; - if (threadIdx.x == 0) - reductionBuffer[blockIdx.x] = sPartials[0];//this->postProcess(sPartials[0],len,extraParams); - __threadfence(); - __syncthreads(); + if (amLast) { + tc[16384] = 0; + sPartials[threadIdx.x] = OpType::startingValue(x); - if (threadIdx.x == 0) { - unsigned int ticket = atomicInc(&tc[16384], gridDim.x); - amLast = (ticket == gridDim.x - 1); - } + for (int i = threadIdx.x; i < gridDim.x; i += blockDim.x) + sPartials[threadIdx.x] = OpType::update( + sPartials[threadIdx.x], reductionBuffer[i], extraParams); - __syncthreads(); + __syncthreads(); + aggregatePartials(sPartials, threadIdx.x, + sd::math::nd4j_min(gridDim.x, blockDim.x), + extraParams); + __syncthreads(); - if (amLast) { - - tc[16384] = 0; - sPartials[threadIdx.x] = OpType::startingValue(x); - - for (int i = threadIdx.x; i < gridDim.x; i += blockDim.x) - sPartials[threadIdx.x] = OpType::update(sPartials[threadIdx.x], reductionBuffer[i], extraParams); - - __syncthreads(); - aggregatePartials(sPartials, threadIdx.x, sd::math::nd4j_min(gridDim.x, blockDim.x), extraParams); - __syncthreads(); - - if (threadIdx.x == 0) { - z[0] = OpType::postProcess(sPartials[0], len, extraParams); - } - } + if (threadIdx.x == 0) { + z[0] = OpType::postProcess(sPartials[0], len, extraParams); + } } - else { - - if (threadIdx.x == 0) { - unsigned int *tc = (unsigned *)reductionBuffer; - tc[16384] = 0; - z[0] = OpType::postProcess(sPartials[0], len, extraParams); - } + } else { + if (threadIdx.x == 0) { + unsigned int *tc = (unsigned *)reductionBuffer; + tc[16384] = 0; + z[0] = OpType::postProcess(sPartials[0], len, extraParams); } + } } //////////////////////////////////////////////////////////////////////// template -template -__host__ void ReduceBoolFunction::intermediateXD(dim3 launchDims, cudaStream_t *stream, - const void *x, const Nd4jLong *xShapeInfo, const Nd4jLong *hXShapeInfo, - void *extraParams, - void *z, const Nd4jLong *zShapeInfo, const Nd4jLong *hZShapeInfo, - int *dimension, int dimensionLength, - void *reductionPointer, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { - if(shape::isEmpty(hXShapeInfo)) { - - if(shape::isEmpty(hZShapeInfo)) - return; - - const auto startingVal = static_cast(OpType::startingValue(reinterpret_cast(x))); - - auto res = cudaMemcpyAsync(sd::LaunchContext::defaultContext()->getScalarPointer(), &startingVal, sizeof(Z), cudaMemcpyHostToDevice, *stream); - if (res != 0) - throw sd::cuda_exception::build("ReduceBoolFunction::intermediateXD: failed to copy temporary scalar", res); - - auto ptr = sd::LaunchContext::defaultContext()->getScalarPointer(); - - // scalar assign - functions::scalar::ScalarTransform::executeCudaShaped(launchDims, stream, 14, z, zShapeInfo, hZShapeInfo, z, zShapeInfo, hZShapeInfo, ptr, nullptr); - sd::DebugHelper::checkErrorCode(stream, "reduceBoolDim empty(...) failed"); - } - else { - simpleReduce<<>>(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength, reductionPointer, tadShapeInfo, tadOffsets); - sd::DebugHelper::checkErrorCode(stream, "reduceBoolDim(...) failed"); - } +template +__host__ void ReduceBoolFunction::intermediateXD( + dim3 launchDims, cudaStream_t *stream, const void *x, + const Nd4jLong *xShapeInfo, const Nd4jLong *hXShapeInfo, void *extraParams, + void *z, const Nd4jLong *zShapeInfo, const Nd4jLong *hZShapeInfo, + int *dimension, int dimensionLength, void *reductionPointer, + const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { + + + if (shape::isEmpty(hXShapeInfo)) { + if (shape::isEmpty(hZShapeInfo)) return; + + const auto startingVal = + static_cast(OpType::startingValue(reinterpret_cast(x))); + + auto res = cudaMemcpyAsync( + sd::LaunchContext::defaultContext()->getScalarPointer(), &startingVal, + sizeof(Z), cudaMemcpyHostToDevice, *stream); + if (res != 0) + throw sd::cuda_exception::build( + "ReduceBoolFunction::intermediateXD: failed to copy temporary " + "scalar", + res); + + auto ptr = sd::LaunchContext::defaultContext()->getScalarPointer(); + + // scalar assign + functions::scalar::ScalarTransform::executeCudaShaped( + launchDims, stream, 14, z, zShapeInfo, hZShapeInfo, z, zShapeInfo, + hZShapeInfo, ptr, nullptr); + sd::DebugHelper::checkErrorCode(stream, "reduceBoolDim empty(...) failed"); + } else { + simpleReduce + <<>>( + x, xShapeInfo, extraParams, z, zShapeInfo, dimension, + dimensionLength, reductionPointer, tadShapeInfo, tadOffsets); + sd::DebugHelper::checkErrorCode(stream, "reduceBoolDim(...) failed"); + } } //////////////////////////////////////////////////////////////////////// template -template -__host__ void ReduceBoolFunction::intermediateScalar(dim3 launchDims, cudaStream_t *stream, - const void *x, const Nd4jLong *xShapeInfo, const Nd4jLong *hXShapeInfo, - void *extraParams, - void *z, const Nd4jLong *zShapeInfo, const Nd4jLong *hZShapeInfo, - int *dimension, int dimensionLength, - void *reductionBuffer, - const Nd4jLong *tadOnlyShapeInfo) { - - if (shape::isEmpty(hXShapeInfo)) { - - if (shape::isEmpty(hZShapeInfo)) - return; - - const auto startingVal = static_cast(OpType::startingValue(reinterpret_cast(x))); - - auto res = cudaMemcpyAsync(z, &startingVal, sizeof(Z), cudaMemcpyHostToDevice, *stream); - if (res != 0) - throw sd::cuda_exception::build("ReduceBoolFunction::intermediateScalar: failed to copy resulting scalar", res); - - sd::DebugHelper::checkErrorCode(stream, "reduceBoolScalar empty(...) failed"); - - } - else { - simpleScalar<<>>(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength, reductionBuffer, tadOnlyShapeInfo); - sd::DebugHelper::checkErrorCode(stream, "reduceBoolScalar(...) failed"); - } +template +__host__ void ReduceBoolFunction::intermediateScalar( + dim3 launchDims, cudaStream_t *stream, const void *x, + const Nd4jLong *xShapeInfo, const Nd4jLong *hXShapeInfo, void *extraParams, + void *z, const Nd4jLong *zShapeInfo, const Nd4jLong *hZShapeInfo, + int *dimension, int dimensionLength, void *reductionBuffer, + const Nd4jLong *tadOnlyShapeInfo) { + if (shape::isEmpty(hXShapeInfo)) { + if (shape::isEmpty(hZShapeInfo)) return; + + const auto startingVal = + static_cast(OpType::startingValue(reinterpret_cast(x))); + + auto res = cudaMemcpyAsync(z, &startingVal, sizeof(Z), + cudaMemcpyHostToDevice, *stream); + if (res != 0) + throw sd::cuda_exception::build( + "ReduceBoolFunction::intermediateScalar: failed to copy " + "resulting scalar", + res); + + sd::DebugHelper::checkErrorCode(stream, + "reduceBoolScalar empty(...) failed"); + + } else { + simpleScalar + <<>>( + x, xShapeInfo, extraParams, z, zShapeInfo, dimension, + dimensionLength, reductionBuffer, tadOnlyShapeInfo); + sd::DebugHelper::checkErrorCode(stream, "reduceBoolScalar(...) failed"); + } } //////////////////////////////////////////////////////////////////////// template -_CUDA_H void ReduceBoolFunction::execReduceScalar(dim3 launchDims, cudaStream_t *stream, - const int opNum, - const void *x, const Nd4jLong *xShapeInfo, const Nd4jLong *hXShapeInfo, - void *extraParams, - void *z, const Nd4jLong *zShapeInfo, const Nd4jLong *hZShapeInfo, - int *dimension, int dimensionLength, - void *reductionBuffer, - const Nd4jLong *tadOnlyShapeInfo) { - - DISPATCH_BY_OPNUM_TT(intermediateScalar, PARAMS(launchDims, stream, x, xShapeInfo, hXShapeInfo, extraParams, z, zShapeInfo, hZShapeInfo, dimension, dimensionLength, reductionBuffer, tadOnlyShapeInfo), OPS_A(REDUCE_BOOL_OPS)); - sd::DebugHelper::checkErrorCode(stream, "execReduceScalarFloat(...) failed"); +_CUDA_H void ReduceBoolFunction::execReduceScalar( + dim3 launchDims, cudaStream_t *stream, const int opNum, const void *x, + const Nd4jLong *xShapeInfo, const Nd4jLong *hXShapeInfo, void *extraParams, + void *z, const Nd4jLong *zShapeInfo, const Nd4jLong *hZShapeInfo, + int *dimension, int dimensionLength, void *reductionBuffer, + const Nd4jLong *tadOnlyShapeInfo) { + DISPATCH_BY_OPNUM_TT( + intermediateScalar, + PARAMS(launchDims, stream, x, xShapeInfo, hXShapeInfo, extraParams, z, + zShapeInfo, hZShapeInfo, dimension, dimensionLength, + reductionBuffer, tadOnlyShapeInfo), + OPS_A(REDUCE_BOOL_OPS)); + sd::DebugHelper::checkErrorCode(stream, "execReduceScalarFloat(...) failed"); } //////////////////////////////////////////////////////////////////////// template -_CUDA_H void ReduceBoolFunction::execReduceXD(dim3 launchDims, cudaStream_t *stream, - const int opNum, - const int rank, const void *x, const Nd4jLong *xShapeInfo, const Nd4jLong *hXShapeInfo, - void *extraParams, - void *z, const Nd4jLong *zShapeInfo, const Nd4jLong *hZShapeInfo, - int *dimension, int dimensionLength, - void *reductionPointer, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { - - DISPATCH_BY_OPNUM_TT(intermediateXD, PARAMS(launchDims, stream, x, xShapeInfo, hXShapeInfo, extraParams, z, zShapeInfo, hZShapeInfo, dimension, dimensionLength, reductionPointer, tadShapeInfo, tadOffsets), OPS_A(REDUCE_BOOL_OPS)); - DEBUG_KERNEL(stream, opNum); +_CUDA_H void ReduceBoolFunction::execReduceXD( + dim3 launchDims, cudaStream_t *stream, const int opNum, const int rank, + const void *x, const Nd4jLong *xShapeInfo, const Nd4jLong *hXShapeInfo, + void *extraParams, void *z, const Nd4jLong *zShapeInfo, + const Nd4jLong *hZShapeInfo, int *dimension, int dimensionLength, + void *reductionPointer, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets) { + DISPATCH_BY_OPNUM_TT( + intermediateXD, + PARAMS(launchDims, stream, x, xShapeInfo, hXShapeInfo, extraParams, z, + zShapeInfo, hZShapeInfo, dimension, dimensionLength, + reductionPointer, tadShapeInfo, tadOffsets), + OPS_A(REDUCE_BOOL_OPS)); + DEBUG_KERNEL(stream, opNum); } //////////////////////////////////////////////////////////////////////// template __device__ void initializeShared(X *extraParams, X **sPartials, int sMemSize) { - int sPartialsLength = sMemSize / sizeof(X); - X *sPartialsDeref = (X *) *sPartials; - for (int i = 0; i < sPartialsLength; i++) - sPartialsDeref[i] = extraParams[0]; - + int sPartialsLength = sMemSize / sizeof(X); + X *sPartialsDeref = (X *)*sPartials; + for (int i = 0; i < sPartialsLength; i++) sPartialsDeref[i] = extraParams[0]; } +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT ReduceBoolFunction, , + LIBND4J_TYPES, BOOL_TYPES); -BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT ReduceBoolFunction, , LIBND4J_TYPES, BOOL_TYPES); - -} -} - +} // namespace reduce +} // namespace functions diff --git a/libnd4j/include/loops/cuda/reduce/reduce_float.chpp b/libnd4j/include/loops/cuda/reduce/reduce_float.chpp index 71f5d03dac57..5391c65623bb 100644 --- a/libnd4j/include/loops/cuda/reduce/reduce_float.chpp +++ b/libnd4j/include/loops/cuda/reduce/reduce_float.chpp @@ -18,321 +18,346 @@ // @author raver119@gmail.com // -#include +#include +#include #include -#include +#include +#include +#include #include #include -#include -#include -#include #include - -#include -#include +#include +#include using namespace simdOps; //////////////////////////////////////////////////////////////////////// template __global__ void simpleReduce(const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *z, const Nd4jLong *zShapeInfo, - int *dimension, int dimensionLength, - void *reductionBuffer, - const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets) { - - functions::reduce::ReduceFloatFunction::template transformCudaXD(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength, reductionBuffer, tadOnlyShapeInfo, tadOffsets); + void *extraParams, void *z, + const Nd4jLong *zShapeInfo, int *dimension, + int dimensionLength, void *reductionBuffer, + const Nd4jLong *tadOnlyShapeInfo, + const Nd4jLong *tadOffsets) { + functions::reduce::ReduceFloatFunction::template transformCudaXD< + OpType>(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, + dimensionLength, reductionBuffer, tadOnlyShapeInfo, tadOffsets); } //////////////////////////////////////////////////////////////////////// template __global__ void simpleScalar(const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *z, const Nd4jLong *zShapeInfo, - int *dimension, int dimensionLength, - void *reductionBuffer, + void *extraParams, void *z, + const Nd4jLong *zShapeInfo, int *dimension, + int dimensionLength, void *reductionBuffer, const Nd4jLong *tadOnlyShapeInfo) { - - functions::reduce::ReduceFloatFunction::template execScalarCuda(x, xShapeInfo, extraParams, z, zShapeInfo, reductionBuffer, tadOnlyShapeInfo); + functions::reduce::ReduceFloatFunction::template execScalarCuda( + x, xShapeInfo, extraParams, z, zShapeInfo, reductionBuffer, + tadOnlyShapeInfo); } namespace functions { -namespace reduce { +namespace reduce { //////////////////////////////////////////////////////////////////////// template template -__device__ void ReduceFloatFunction::aggregatePartials(void *vsPartials, Nd4jLong tid, Nd4jLong numItems, void *vextraParams) { +__device__ void ReduceFloatFunction::aggregatePartials( + void *vsPartials, Nd4jLong tid, Nd4jLong numItems, void *vextraParams) { + // start the shared memory loop on the next power of 2 less + // than the block size. If block size is not a power of 2, + // accumulate the intermediate sums in the remainder range. - // start the shared memory loop on the next power of 2 less - // than the block size. If block size is not a power of 2, - // accumulate the intermediate sums in the remainder range. + auto sPartials = reinterpret_cast(vsPartials); + auto extraParams = reinterpret_cast(vextraParams); - auto sPartials = reinterpret_cast(vsPartials); - auto extraParams = reinterpret_cast(vextraParams); + Nd4jLong floorPow2 = numItems; - Nd4jLong floorPow2 = numItems; + if (floorPow2 & (floorPow2 - 1)) { + while (floorPow2 & (floorPow2 - 1)) floorPow2 &= floorPow2 - 1; - if (floorPow2 & (floorPow2 - 1)) { + if (tid >= floorPow2) + sPartials[tid - floorPow2] = OpType::update(sPartials[tid - floorPow2], + sPartials[tid], extraParams); - while (floorPow2 & (floorPow2 - 1)) - floorPow2 &= floorPow2 - 1; - - if (tid >= floorPow2) - sPartials[tid - floorPow2] = OpType::update(sPartials[tid - floorPow2], sPartials[tid], extraParams); - - __syncthreads(); - } + __syncthreads(); + } - for (Nd4jLong activeThreads = floorPow2 >> 1; activeThreads; activeThreads >>= 1) { - if (tid < activeThreads && tid + activeThreads < numItems) - sPartials[tid] = OpType::update(sPartials[tid], sPartials[tid + activeThreads], extraParams); + for (Nd4jLong activeThreads = floorPow2 >> 1; activeThreads; + activeThreads >>= 1) { + if (tid < activeThreads && tid + activeThreads < numItems) + sPartials[tid] = OpType::update( + sPartials[tid], sPartials[tid + activeThreads], extraParams); - __syncthreads(); - } + __syncthreads(); + } } //////////////////////////////////////////////////////////////////////// template template -__device__ void ReduceFloatFunction::transformCudaXD(const void *vx, const Nd4jLong *xShapeInfo, - void *vextraParams, - void *vz, const Nd4jLong *zShapeInfo, - int *dimension, int dimensionLength, - void *vreductionBuffer, - const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets) { - - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - auto extraParams = reinterpret_cast(vextraParams); - auto reductionBuffer = reinterpret_cast(vreductionBuffer); - - //shared memory space for storing intermediate results - __shared__ Z* sPartials; - __shared__ int tadLength, numTads; - __shared__ bool isPlainOutput; - - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - sPartials = reinterpret_cast(shmem); - - isPlainOutput = shape::order(zShapeInfo) == 'c' && shape::elementWiseStride(zShapeInfo) == 1; - - tadLength = shape::length(tadOnlyShapeInfo); - numTads = shape::length(xShapeInfo) / tadLength; +__device__ void ReduceFloatFunction::transformCudaXD( + const void *vx, const Nd4jLong *xShapeInfo, void *vextraParams, void *vz, + const Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, + void *vreductionBuffer, const Nd4jLong *tadOnlyShapeInfo, + const Nd4jLong *tadOffsets) { + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + auto extraParams = reinterpret_cast(vextraParams); + auto reductionBuffer = reinterpret_cast(vreductionBuffer); + + // shared memory space for storing intermediate results + __shared__ Z *sPartials; + __shared__ int tadLength, numTads; + __shared__ bool isPlainOutput; + + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sPartials = reinterpret_cast(shmem); + + isPlainOutput = shape::order(zShapeInfo) == 'c' && + shape::elementWiseStride(zShapeInfo) == 1; + + tadLength = shape::length(tadOnlyShapeInfo); + numTads = shape::length(xShapeInfo) / tadLength; + } + __syncthreads(); + + for (int r = blockIdx.x; r < numTads; r += gridDim.x) { + auto tadOffsetForBlock = tadOffsets[r]; + sPartials[threadIdx.x] = OpType::startingValue(x + tadOffsetForBlock); + + for (int i = threadIdx.x; i < tadLength; i += blockDim.x) { + auto xOffset = + tadOffsetForBlock + shape::getIndexOffset(i, tadOnlyShapeInfo); + sPartials[threadIdx.x] = + OpType::update(sPartials[threadIdx.x], + OpType::op(x[xOffset], extraParams), extraParams); } __syncthreads(); - for (int r = blockIdx.x; r < numTads; r += gridDim.x) { - - auto tadOffsetForBlock = tadOffsets[r]; - sPartials[threadIdx.x] = OpType::startingValue(x + tadOffsetForBlock); - - for (int i = threadIdx.x; i < tadLength; i += blockDim.x) { - auto xOffset = tadOffsetForBlock + shape::getIndexOffset(i, tadOnlyShapeInfo); - sPartials[threadIdx.x] = OpType::update(sPartials[threadIdx.x], OpType::op(x[xOffset], extraParams), extraParams); - } - __syncthreads(); - - // aggregate. do NOT reduce for elements > tadLength - aggregatePartials(sPartials, threadIdx.x, sd::math::nd4j_min(blockDim.x, tadLength), extraParams); - __syncthreads(); + // aggregate. do NOT reduce for elements > tadLength + aggregatePartials(sPartials, threadIdx.x, + sd::math::nd4j_min(blockDim.x, tadLength), + extraParams); + __syncthreads(); - if (threadIdx.x == 0) - z[isPlainOutput ? r : shape::getIndexOffset(r, zShapeInfo)] = OpType::postProcess(sPartials[threadIdx.x], tadLength, extraParams); - } + if (threadIdx.x == 0) + z[isPlainOutput ? r : shape::getIndexOffset(r, zShapeInfo)] = + OpType::postProcess(sPartials[threadIdx.x], tadLength, extraParams); + } } //////////////////////////////////////////////////////////////////////// template template -__device__ void ReduceFloatFunction::execScalarCuda(const void *vx, const Nd4jLong *xShapeInfo, - void *vextraParams, - void *vz, const Nd4jLong *zShapeInfo, - void *vreductionBuffer, - const Nd4jLong *tadOnlyShapeInfo) { - - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - auto extraParams = reinterpret_cast(vextraParams); - auto reductionBuffer = reinterpret_cast(vreductionBuffer); - - auto tid = blockDim.x * blockIdx.x + threadIdx.x; - - //shared memory space for storing intermediate results - __shared__ Z* sPartials; - __shared__ Nd4jLong xEws; - __shared__ Nd4jLong len; - - if(threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - sPartials = reinterpret_cast(shmem); - xEws = shape::elementWiseStride(xShapeInfo); - len = shape::length(xShapeInfo); - } +__device__ void ReduceFloatFunction::execScalarCuda( + const void *vx, const Nd4jLong *xShapeInfo, void *vextraParams, void *vz, + const Nd4jLong *zShapeInfo, void *vreductionBuffer, + const Nd4jLong *tadOnlyShapeInfo) { + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + auto extraParams = reinterpret_cast(vextraParams); + auto reductionBuffer = reinterpret_cast(vreductionBuffer); + + auto tid = blockDim.x * blockIdx.x + threadIdx.x; + + // shared memory space for storing intermediate results + __shared__ Z *sPartials; + __shared__ Nd4jLong xEws; + __shared__ Nd4jLong len; + + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sPartials = reinterpret_cast(shmem); + xEws = shape::elementWiseStride(xShapeInfo); + len = shape::length(xShapeInfo); + } + __syncthreads(); + + sPartials[threadIdx.x] = OpType::startingValue(x); + + if (xEws > 0) + for (int i = tid; i < len; i += (blockDim.x * gridDim.x)) + sPartials[threadIdx.x] = + OpType::update(sPartials[threadIdx.x], + OpType::op(x[i * xEws], extraParams), extraParams); + else + for (int i = tid; i < len; i += blockDim.x * gridDim.x) + sPartials[threadIdx.x] = OpType::update( + sPartials[threadIdx.x], + OpType::op(x[shape::getIndexOffset(i, xShapeInfo)], extraParams), + extraParams); + + __syncthreads(); + aggregatePartials(sPartials, threadIdx.x, + sd::math::nd4j_min(blockDim.x, len), + extraParams); + __syncthreads(); + + if (gridDim.x > 1) { + unsigned int *tc = (unsigned int *)reductionBuffer; + __shared__ bool amLast; + + tid = threadIdx.x; + if (threadIdx.x == 0) + reductionBuffer[blockIdx.x] = + sPartials[0]; // this->postProcess(sPartials[0],len,extraParams); + + __threadfence(); __syncthreads(); - sPartials[threadIdx.x] = OpType::startingValue(x); - - if (xEws > 0) - for (int i = tid; i < len; i += (blockDim.x * gridDim.x)) - sPartials[threadIdx.x] = OpType::update(sPartials[threadIdx.x], OpType::op(x[i * xEws], extraParams), extraParams); - else - for (int i = tid; i < len; i += blockDim.x * gridDim.x) - sPartials[threadIdx.x] = OpType::update(sPartials[threadIdx.x], OpType::op(x[shape::getIndexOffset(i, xShapeInfo)], extraParams), extraParams); + if (threadIdx.x == 0) { + unsigned int ticket = atomicInc(&tc[16384], gridDim.x); + amLast = (ticket == gridDim.x - 1); + } __syncthreads(); - aggregatePartials(sPartials, threadIdx.x, sd::math::nd4j_min(blockDim.x, len), extraParams); - __syncthreads(); - - if (gridDim.x > 1) { - - unsigned int *tc = (unsigned int *)reductionBuffer; - __shared__ bool amLast; - - tid = threadIdx.x; - if (threadIdx.x == 0) - reductionBuffer[blockIdx.x] = sPartials[0];//this->postProcess(sPartials[0],len,extraParams); - - __threadfence(); - __syncthreads(); - - if (threadIdx.x == 0) { - unsigned int ticket = atomicInc(&tc[16384], gridDim.x); - amLast = (ticket == gridDim.x - 1); - } - __syncthreads(); + if (amLast) { + tc[16384] = 0; + sPartials[threadIdx.x] = OpType::startingValue(x); - if (amLast) { + for (int i = threadIdx.x; i < gridDim.x; i += blockDim.x) + sPartials[threadIdx.x] = OpType::update( + sPartials[threadIdx.x], reductionBuffer[i], extraParams); - tc[16384] = 0; - sPartials[threadIdx.x] = OpType::startingValue(x); + __syncthreads(); + aggregatePartials(sPartials, threadIdx.x, + sd::math::nd4j_min(gridDim.x, blockDim.x), + extraParams); + __syncthreads(); - for (int i = threadIdx.x; i < gridDim.x; i += blockDim.x) - sPartials[threadIdx.x] = OpType::update(sPartials[threadIdx.x], reductionBuffer[i], extraParams); - - __syncthreads(); - aggregatePartials(sPartials, threadIdx.x, sd::math::nd4j_min(gridDim.x, blockDim.x), extraParams); - __syncthreads(); - - if (threadIdx.x == 0) { - z[0] = OpType::postProcess(sPartials[0], len, extraParams); - } - } + if (threadIdx.x == 0) { + z[0] = OpType::postProcess(sPartials[0], len, extraParams); + } } - else { - - if (threadIdx.x == 0) { - unsigned int *tc = (unsigned *)reductionBuffer; - tc[16384] = 0; - z[0] = OpType::postProcess(sPartials[0], len, extraParams); - } + } else { + if (threadIdx.x == 0) { + unsigned int *tc = (unsigned *)reductionBuffer; + tc[16384] = 0; + z[0] = OpType::postProcess(sPartials[0], len, extraParams); } + } } //////////////////////////////////////////////////////////////////////// template -template -__host__ void ReduceFloatFunction::intermediateXD(dim3 launchDims, cudaStream_t *stream, - const void *x, const Nd4jLong *xShape, const Nd4jLong *hXShapeInfo, - void *extraParams, - void *z, const Nd4jLong *zShape, const Nd4jLong *hZShapeInfo, - int *dimension, int dimensionLength, - void *reductionPointer, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { - - if(shape::isEmpty(hXShapeInfo)) { - - if(shape::isEmpty(hZShapeInfo)) - return; - - const auto startingVal = std::is_same>::value ? sd::DataTypeUtils::nanOrZero() : static_cast(OpType::startingValue(reinterpret_cast(x))); - auto res = cudaMemcpyAsync(sd::LaunchContext::defaultContext()->getScalarPointer(), &startingVal, sizeof(Z), cudaMemcpyHostToDevice, *stream); - if (res != 0) - throw sd::cuda_exception::build("ReduceFloatFunction::intermediateXD: failed to copy temporary scalar", res); - - auto ptr = sd::LaunchContext::defaultContext()->getScalarPointer(); - - // scalar assign - functions::scalar::ScalarTransform::executeCudaShaped(launchDims, stream, 14, z, zShape, hZShapeInfo, z, zShape, hZShapeInfo, ptr, nullptr); - } - else { - simpleReduce<<>>(x, xShape, extraParams, z, zShape, dimension, dimensionLength, reductionPointer, tadShapeInfo, tadOffsets); - } +template +__host__ void ReduceFloatFunction::intermediateXD( + dim3 launchDims, cudaStream_t *stream, const void *x, + const Nd4jLong *xShape, const Nd4jLong *hXShapeInfo, void *extraParams, + void *z, const Nd4jLong *zShape, const Nd4jLong *hZShapeInfo, + int *dimension, int dimensionLength, void *reductionPointer, + const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { + if (shape::isEmpty(hXShapeInfo)) { + if (shape::isEmpty(hZShapeInfo)) return; + + const auto startingVal = std::is_same>::value + ? sd::DataTypeUtils::nanOrZero() + : static_cast(OpType::startingValue( + reinterpret_cast(x))); + auto res = cudaMemcpyAsync( + sd::LaunchContext::defaultContext()->getScalarPointer(), &startingVal, + sizeof(Z), cudaMemcpyHostToDevice, *stream); + if (res != 0) + throw sd::cuda_exception::build( + "ReduceFloatFunction::intermediateXD: failed to copy temporary " + "scalar", + res); + + auto ptr = sd::LaunchContext::defaultContext()->getScalarPointer(); + + // scalar assign + functions::scalar::ScalarTransform::executeCudaShaped( + launchDims, stream, 14, z, zShape, hZShapeInfo, z, zShape, hZShapeInfo, + ptr, nullptr); + } else { + simpleReduce + <<>>( + x, xShape, extraParams, z, zShape, dimension, dimensionLength, + reductionPointer, tadShapeInfo, tadOffsets); + } } //////////////////////////////////////////////////////////////////////// template -template -__host__ void ReduceFloatFunction::intermediateScalar(dim3 launchDims, cudaStream_t *stream, - const void *x, const Nd4jLong *xShapeInfo, const Nd4jLong *hXShapeInfo, - void *extraParams, - void *z, const Nd4jLong *zShapeInfo, const Nd4jLong *hZShapeInfo, - int *dimension, int dimensionLength, - void *reductionBuffer, - const Nd4jLong *tadOnlyShapeInfo) { - - if (shape::isEmpty(hXShapeInfo)) { - - if (shape::isEmpty(hZShapeInfo)) - return; - - const auto startingVal = std::is_same>::value ? sd::DataTypeUtils::nanOrZero() : static_cast(OpType::startingValue(reinterpret_cast(x))); - - auto res = cudaMemcpyAsync(z, &startingVal, sizeof(Z), cudaMemcpyHostToDevice, *stream); - if (res != 0) - throw sd::cuda_exception::build("ReduceFloatFunction::intermediateScalar: failed to copy resulting scalar", res); - } - else { - simpleScalar <<>>(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength, reductionBuffer, tadOnlyShapeInfo); - } +template +__host__ void ReduceFloatFunction::intermediateScalar( + dim3 launchDims, cudaStream_t *stream, const void *x, + const Nd4jLong *xShapeInfo, const Nd4jLong *hXShapeInfo, void *extraParams, + void *z, const Nd4jLong *zShapeInfo, const Nd4jLong *hZShapeInfo, + int *dimension, int dimensionLength, void *reductionBuffer, + const Nd4jLong *tadOnlyShapeInfo) { + if (shape::isEmpty(hXShapeInfo)) { + if (shape::isEmpty(hZShapeInfo)) return; + + const auto startingVal = std::is_same>::value + ? sd::DataTypeUtils::nanOrZero() + : static_cast(OpType::startingValue( + reinterpret_cast(x))); + + auto res = cudaMemcpyAsync(z, &startingVal, sizeof(Z), + cudaMemcpyHostToDevice, *stream); + if (res != 0) + throw sd::cuda_exception::build( + "ReduceFloatFunction::intermediateScalar: failed to copy " + "resulting scalar", + res); + } else { + simpleScalar + <<>>( + x, xShapeInfo, extraParams, z, zShapeInfo, dimension, + dimensionLength, reductionBuffer, tadOnlyShapeInfo); + } } //////////////////////////////////////////////////////////////////////// template -_CUDA_H void ReduceFloatFunction::execReduceScalar(dim3 launchDims, cudaStream_t *stream, - const int opNum, - const void *x, const Nd4jLong *xShapeInfo, const Nd4jLong *hXShapeInfo, - void *extraParams, - void *z, const Nd4jLong *zShapeInfo, const Nd4jLong *hZShapeInfo, - int *dimension, int dimensionLength, - void *reductionBuffer, - const Nd4jLong *tadOnlyShapeInfo) { - - DISPATCH_BY_OPNUM_TT(intermediateScalar, PARAMS(launchDims, stream, x, xShapeInfo, hXShapeInfo, extraParams, z, zShapeInfo, hZShapeInfo, dimension, dimensionLength, reductionBuffer, tadOnlyShapeInfo), OPS_A(REDUCE_FLOAT_OPS)); - sd::DebugHelper::checkErrorCode(stream, "execReduceScalarFloat(...) failed"); +_CUDA_H void ReduceFloatFunction::execReduceScalar( + dim3 launchDims, cudaStream_t *stream, const int opNum, const void *x, + const Nd4jLong *xShapeInfo, const Nd4jLong *hXShapeInfo, void *extraParams, + void *z, const Nd4jLong *zShapeInfo, const Nd4jLong *hZShapeInfo, + int *dimension, int dimensionLength, void *reductionBuffer, + const Nd4jLong *tadOnlyShapeInfo) { + DISPATCH_BY_OPNUM_TT( + intermediateScalar, + PARAMS(launchDims, stream, x, xShapeInfo, hXShapeInfo, extraParams, z, + zShapeInfo, hZShapeInfo, dimension, dimensionLength, + reductionBuffer, tadOnlyShapeInfo), + OPS_A(REDUCE_FLOAT_OPS)); + sd::DebugHelper::checkErrorCode(stream, "execReduceScalarFloat(...) failed"); } //////////////////////////////////////////////////////////////////////// template -_CUDA_H void ReduceFloatFunction::execReduceXD(dim3 launchDims, cudaStream_t *stream, - const int opNum, - const int rank, const void *x, const Nd4jLong *xShape, const Nd4jLong *hXShapeInfo, - void *extraParams, - void *z, const Nd4jLong *zShape, const Nd4jLong *hZShapeInfo, - int *dimension, int dimensionLength, - void *reductionPointer, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { - - DISPATCH_BY_OPNUM_TT(intermediateXD, PARAMS(launchDims, stream, x, xShape, hXShapeInfo, extraParams, z, zShape, hZShapeInfo, dimension, dimensionLength, reductionPointer, tadShapeInfo, tadOffsets), OPS_A(REDUCE_FLOAT_OPS)); - DEBUG_KERNEL(stream, opNum); +_CUDA_H void ReduceFloatFunction::execReduceXD( + dim3 launchDims, cudaStream_t *stream, const int opNum, const int rank, + const void *x, const Nd4jLong *xShape, const Nd4jLong *hXShapeInfo, + void *extraParams, void *z, const Nd4jLong *zShape, + const Nd4jLong *hZShapeInfo, int *dimension, int dimensionLength, + void *reductionPointer, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets) { + DISPATCH_BY_OPNUM_TT( + intermediateXD, + PARAMS(launchDims, stream, x, xShape, hXShapeInfo, extraParams, z, zShape, + hZShapeInfo, dimension, dimensionLength, reductionPointer, + tadShapeInfo, tadOffsets), + OPS_A(REDUCE_FLOAT_OPS)); + DEBUG_KERNEL(stream, opNum); } //////////////////////////////////////////////////////////////////////// template __device__ void initializeShared(X *extraParams, X **sPartials, int sMemSize) { - int sPartialsLength = sMemSize / sizeof(X); - X *sPartialsDeref = (X *) *sPartials; - for (int i = 0; i < sPartialsLength; i++) - sPartialsDeref[i] = extraParams[0]; - + int sPartialsLength = sMemSize / sizeof(X); + X *sPartialsDeref = (X *)*sPartials; + for (int i = 0; i < sPartialsLength; i++) sPartialsDeref[i] = extraParams[0]; } +// BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT ReduceFloatFunction, , +// LIBND4J_TYPES, FLOAT_TYPES); -//BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT ReduceFloatFunction, , LIBND4J_TYPES, FLOAT_TYPES); - -} -} - +} // namespace reduce +} // namespace functions diff --git a/libnd4j/include/loops/cuda/reduce/reduce_long.cu b/libnd4j/include/loops/cuda/reduce/reduce_long.cu index 1beac5330858..c13ddfd304b4 100644 --- a/libnd4j/include/loops/cuda/reduce/reduce_long.cu +++ b/libnd4j/include/loops/cuda/reduce/reduce_long.cu @@ -19,342 +19,365 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include -#include -#include -#include -#include -#include #include +#include +#include +#include +#include #include - +#include +#include using namespace simdOps; //////////////////////////////////////////////////////////////////////// template -__device__ void reduceSimpleGeneric(const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *z, const Nd4jLong *zShapeInfo, - int *dimension, int dimensionLength, - void *reductionBuffer, - const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets) { - - functions::reduce::ReduceLongFunction::template transformCudaXD(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength, reductionBuffer, tadOnlyShapeInfo, tadOffsets); +__device__ void reduceSimpleGeneric(const void *x, const Nd4jLong *xShapeInfo, + void *extraParams, void *z, + const Nd4jLong *zShapeInfo, int *dimension, + int dimensionLength, void *reductionBuffer, + const Nd4jLong *tadOnlyShapeInfo, + const Nd4jLong *tadOffsets) { + functions::reduce::ReduceLongFunction::template transformCudaXD( + x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength, + reductionBuffer, tadOnlyShapeInfo, tadOffsets); } //////////////////////////////////////////////////////////////////////// template __device__ void reduceScalarGeneric(const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *z, const Nd4jLong *zShapeInfo, - int *dimension, int dimensionLength, - void *reductionBuffer, + void *extraParams, void *z, + const Nd4jLong *zShapeInfo, int *dimension, + int dimensionLength, void *reductionBuffer, const Nd4jLong *tadOnlyShapeInfo) { - - functions::reduce::ReduceLongFunction::template execScalarCuda(x, xShapeInfo, extraParams, z, zShapeInfo, reductionBuffer, tadOnlyShapeInfo); + functions::reduce::ReduceLongFunction::template execScalarCuda( + x, xShapeInfo, extraParams, z, zShapeInfo, reductionBuffer, + tadOnlyShapeInfo); } //////////////////////////////////////////////////////////////////////// template __global__ void simpleReduce(const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *z, const Nd4jLong *zShapeInfo, - int *dimension, int dimensionLength, - void *reductionBuffer, - const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets) { - - reduceSimpleGeneric(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength, reductionBuffer, tadOnlyShapeInfo, tadOffsets); + void *extraParams, void *z, + const Nd4jLong *zShapeInfo, int *dimension, + int dimensionLength, void *reductionBuffer, + const Nd4jLong *tadOnlyShapeInfo, + const Nd4jLong *tadOffsets) { + reduceSimpleGeneric(x, xShapeInfo, extraParams, z, zShapeInfo, + dimension, dimensionLength, reductionBuffer, + tadOnlyShapeInfo, tadOffsets); } //////////////////////////////////////////////////////////////////////// template __global__ void simpleScalar(const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *z, const Nd4jLong *zShapeInfo, - int *dimension, int dimensionLength, - void *reductionBuffer, - const Nd4jLong *tadOnlyShapeInfo) { - - reduceScalarGeneric(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength, reductionBuffer, tadOnlyShapeInfo); + void *extraParams, void *z, + const Nd4jLong *zShapeInfo, int *dimension, + int dimensionLength, void *reductionBuffer, + const Nd4jLong *tadOnlyShapeInfo) { + reduceScalarGeneric(x, xShapeInfo, extraParams, z, zShapeInfo, + dimension, dimensionLength, reductionBuffer, + tadOnlyShapeInfo); } namespace functions { -namespace reduce { +namespace reduce { //////////////////////////////////////////////////////////////////////// template template -__device__ void ReduceLongFunction::aggregatePartials(void *vsPartials, Nd4jLong tid, Nd4jLong numItems, void *vextraParams) { - - // start the shared memory loop on the next power of 2 less - // than the block size. If block size is not a power of 2, - // accumulate the intermediate sums in the remainder range. +__device__ void ReduceLongFunction::aggregatePartials( + void *vsPartials, Nd4jLong tid, Nd4jLong numItems, void *vextraParams) { + // start the shared memory loop on the next power of 2 less + // than the block size. If block size is not a power of 2, + // accumulate the intermediate sums in the remainder range. - auto sPartials = reinterpret_cast(vsPartials); - auto extraParams = reinterpret_cast(vextraParams); + auto sPartials = reinterpret_cast(vsPartials); + auto extraParams = reinterpret_cast(vextraParams); - Nd4jLong floorPow2 = numItems; + Nd4jLong floorPow2 = numItems; - if (floorPow2 & (floorPow2 - 1)) { + if (floorPow2 & (floorPow2 - 1)) { + while (floorPow2 & (floorPow2 - 1)) floorPow2 &= floorPow2 - 1; - while (floorPow2 & (floorPow2 - 1)) - floorPow2 &= floorPow2 - 1; + if (tid >= floorPow2) + sPartials[tid - floorPow2] = OpType::update(sPartials[tid - floorPow2], + sPartials[tid], extraParams); - if (tid >= floorPow2) - sPartials[tid - floorPow2] = OpType::update(sPartials[tid - floorPow2], sPartials[tid], extraParams); - - __syncthreads(); - } + __syncthreads(); + } - for (Nd4jLong activeThreads = floorPow2 >> 1; activeThreads; activeThreads >>= 1) { - if (tid < activeThreads && tid + activeThreads < numItems) - sPartials[tid] = OpType::update(sPartials[tid], sPartials[tid + activeThreads], extraParams); + for (Nd4jLong activeThreads = floorPow2 >> 1; activeThreads; + activeThreads >>= 1) { + if (tid < activeThreads && tid + activeThreads < numItems) + sPartials[tid] = OpType::update( + sPartials[tid], sPartials[tid + activeThreads], extraParams); - __syncthreads(); - } + __syncthreads(); + } } //////////////////////////////////////////////////////////////////////// template template -__device__ void ReduceLongFunction::transformCudaXD(const void *vx, const Nd4jLong *xShapeInfo, - void *vextraParams, - void *vz, const Nd4jLong *zShapeInfo, - int *dimension, int dimensionLength, - void *vreductionBuffer, - const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets) { - - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - auto extraParams = reinterpret_cast(vextraParams); - auto reductionBuffer = reinterpret_cast(vreductionBuffer); - - //shared memory space for storing intermediate results - __shared__ Z* sPartials; - __shared__ int tadLength, numTads; - __shared__ bool isPlainOutput; - - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - sPartials = reinterpret_cast(shmem); - - isPlainOutput = shape::order(zShapeInfo) == 'c' && shape::elementWiseStride(zShapeInfo) == 1; - - tadLength = shape::length(tadOnlyShapeInfo); - numTads = shape::length(xShapeInfo) / tadLength; +__device__ void ReduceLongFunction::transformCudaXD( + const void *vx, const Nd4jLong *xShapeInfo, void *vextraParams, void *vz, + const Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, + void *vreductionBuffer, const Nd4jLong *tadOnlyShapeInfo, + const Nd4jLong *tadOffsets) { + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + auto extraParams = reinterpret_cast(vextraParams); + auto reductionBuffer = reinterpret_cast(vreductionBuffer); + + // shared memory space for storing intermediate results + __shared__ Z *sPartials; + __shared__ int tadLength, numTads; + __shared__ bool isPlainOutput; + + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sPartials = reinterpret_cast(shmem); + + isPlainOutput = shape::order(zShapeInfo) == 'c' && + shape::elementWiseStride(zShapeInfo) == 1; + + tadLength = shape::length(tadOnlyShapeInfo); + numTads = shape::length(xShapeInfo) / tadLength; + } + __syncthreads(); + + for (int r = blockIdx.x; r < numTads; r += gridDim.x) { + Nd4jLong tadOffsetForBlock = tadOffsets[r]; + sPartials[threadIdx.x] = OpType::startingValue(x + tadOffsetForBlock); + + for (int i = threadIdx.x; i < tadLength; i += blockDim.x) { + auto xOffset = + tadOffsetForBlock + shape::getIndexOffset(i, tadOnlyShapeInfo); + sPartials[threadIdx.x] = + OpType::update(sPartials[threadIdx.x], + OpType::op(x[xOffset], extraParams), extraParams); } __syncthreads(); - for (int r = blockIdx.x; r < numTads; r += gridDim.x) { - - Nd4jLong tadOffsetForBlock = tadOffsets[r]; - sPartials[threadIdx.x] = OpType::startingValue(x + tadOffsetForBlock); - - for (int i = threadIdx.x; i < tadLength; i += blockDim.x) { - auto xOffset = tadOffsetForBlock + shape::getIndexOffset(i, tadOnlyShapeInfo); - sPartials[threadIdx.x] = OpType::update(sPartials[threadIdx.x], OpType::op(x[xOffset], extraParams), extraParams); - } - __syncthreads(); - - // aggregate. do NOT reduce for elements > tadLength - aggregatePartials(sPartials, threadIdx.x, sd::math::nd4j_min(blockDim.x, tadLength), extraParams); - __syncthreads(); + // aggregate. do NOT reduce for elements > tadLength + aggregatePartials(sPartials, threadIdx.x, + sd::math::nd4j_min(blockDim.x, tadLength), + extraParams); + __syncthreads(); - if (threadIdx.x == 0) - z[isPlainOutput ? r : shape::getIndexOffset(r, zShapeInfo)] = OpType::postProcess(sPartials[threadIdx.x], tadLength, extraParams); - } + if (threadIdx.x == 0) + z[isPlainOutput ? r : shape::getIndexOffset(r, zShapeInfo)] = + OpType::postProcess(sPartials[threadIdx.x], tadLength, extraParams); + } } //////////////////////////////////////////////////////////////////////// template template -__device__ void ReduceLongFunction::execScalarCuda(const void *vx, const Nd4jLong *xShapeInfo, - void *vextraParams, - void *vz, const Nd4jLong *zShapeInfo, - void *vreductionBuffer, - const Nd4jLong *tadOnlyShapeInfo) { - - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - auto extraParams = reinterpret_cast(vextraParams); - auto reductionBuffer = reinterpret_cast(vreductionBuffer); - - auto tid = blockDim.x * blockIdx.x + threadIdx.x; - - //shared memory space for storing intermediate results - __shared__ Z* sPartials; - __shared__ Nd4jLong xEws; - __shared__ Nd4jLong len; - - if(threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - sPartials = reinterpret_cast(shmem); - xEws = shape::elementWiseStride(xShapeInfo); - len = shape::length(xShapeInfo); - } +__device__ void ReduceLongFunction::execScalarCuda( + const void *vx, const Nd4jLong *xShapeInfo, void *vextraParams, void *vz, + const Nd4jLong *zShapeInfo, void *vreductionBuffer, + const Nd4jLong *tadOnlyShapeInfo) { + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + auto extraParams = reinterpret_cast(vextraParams); + auto reductionBuffer = reinterpret_cast(vreductionBuffer); + + auto tid = blockDim.x * blockIdx.x + threadIdx.x; + + // shared memory space for storing intermediate results + __shared__ Z *sPartials; + __shared__ Nd4jLong xEws; + __shared__ Nd4jLong len; + + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sPartials = reinterpret_cast(shmem); + xEws = shape::elementWiseStride(xShapeInfo); + len = shape::length(xShapeInfo); + } + __syncthreads(); + + sPartials[threadIdx.x] = OpType::startingValue(x); + + if (xEws > 0) + for (int i = tid; i < len; i += (blockDim.x * gridDim.x)) + sPartials[threadIdx.x] = + OpType::update(sPartials[threadIdx.x], + OpType::op(x[i * xEws], extraParams), extraParams); + else + for (int i = tid; i < len; i += blockDim.x * gridDim.x) + sPartials[threadIdx.x] = OpType::update( + sPartials[threadIdx.x], + OpType::op(x[shape::getIndexOffset(i, xShapeInfo)], extraParams), + extraParams); + + __syncthreads(); + aggregatePartials(sPartials, threadIdx.x, + sd::math::nd4j_min(blockDim.x, len), + extraParams); + __syncthreads(); + + if (gridDim.x > 1) { + auto tc = reinterpret_cast(reductionBuffer); + __shared__ bool amLast; + + tid = threadIdx.x; + if (threadIdx.x == 0) + reductionBuffer[blockIdx.x] = + sPartials[0]; // this->postProcess(sPartials[0],len,extraParams); + + __threadfence(); __syncthreads(); - sPartials[threadIdx.x] = OpType::startingValue(x); - - if (xEws > 0) - for (int i = tid; i < len; i += (blockDim.x * gridDim.x)) - sPartials[threadIdx.x] = OpType::update(sPartials[threadIdx.x], OpType::op(x[i * xEws], extraParams), extraParams); - else - for (int i = tid; i < len; i += blockDim.x * gridDim.x) - sPartials[threadIdx.x] = OpType::update(sPartials[threadIdx.x], OpType::op(x[shape::getIndexOffset(i, xShapeInfo)], extraParams), extraParams); + if (threadIdx.x == 0) { + unsigned int ticket = atomicInc(&tc[16384], gridDim.x); + amLast = (ticket == gridDim.x - 1); + } __syncthreads(); - aggregatePartials(sPartials, threadIdx.x, sd::math::nd4j_min(blockDim.x, len), extraParams); - __syncthreads(); - - if (gridDim.x > 1) { - - auto tc = reinterpret_cast(reductionBuffer); - __shared__ bool amLast; - - tid = threadIdx.x; - if (threadIdx.x == 0) - reductionBuffer[blockIdx.x] = sPartials[0];//this->postProcess(sPartials[0],len,extraParams); - - __threadfence(); - __syncthreads(); - - if (threadIdx.x == 0) { - unsigned int ticket = atomicInc(&tc[16384], gridDim.x); - amLast = (ticket == gridDim.x - 1); - } - __syncthreads(); + if (amLast) { + tc[16384] = 0; + sPartials[threadIdx.x] = OpType::startingValue(x); - if (amLast) { - tc[16384] = 0; - sPartials[threadIdx.x] = OpType::startingValue(x); + for (int i = threadIdx.x; i < gridDim.x; i += blockDim.x) + sPartials[threadIdx.x] = OpType::update( + sPartials[threadIdx.x], reductionBuffer[i], extraParams); - for (int i = threadIdx.x; i < gridDim.x; i += blockDim.x) - sPartials[threadIdx.x] = OpType::update(sPartials[threadIdx.x], reductionBuffer[i], extraParams); + __syncthreads(); + aggregatePartials(sPartials, threadIdx.x, + sd::math::nd4j_min(gridDim.x, blockDim.x), + extraParams); + __syncthreads(); - __syncthreads(); - aggregatePartials(sPartials, threadIdx.x, sd::math::nd4j_min(gridDim.x, blockDim.x), extraParams); - __syncthreads(); - - if (threadIdx.x == 0) { - z[0] = OpType::postProcess(sPartials[0], len, extraParams); - } - } + if (threadIdx.x == 0) { + z[0] = OpType::postProcess(sPartials[0], len, extraParams); + } } - else { - - if (threadIdx.x == 0) { - auto tc = reinterpret_cast(reductionBuffer); - tc[16384] = 0; - z[0] = OpType::postProcess(sPartials[0], len, extraParams); - } + } else { + if (threadIdx.x == 0) { + auto tc = reinterpret_cast(reductionBuffer); + tc[16384] = 0; + z[0] = OpType::postProcess(sPartials[0], len, extraParams); } + } } //////////////////////////////////////////////////////////////////////// template -template -__host__ void ReduceLongFunction::intermediateXD(dim3 launchDims, cudaStream_t *stream, - const void *x, const Nd4jLong *xShapeInfo, const Nd4jLong *hXShapeInfo, - void *extraParams, - void *z, const Nd4jLong *zShapeInfo, const Nd4jLong *hZShapeInfo, - int *dimension, int dimensionLength, - void *reductionPointer, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { - - if(shape::isEmpty(hXShapeInfo)) { - - if(shape::isEmpty(hZShapeInfo)) - return; - - const auto startingVal = static_cast(OpType::startingValue(reinterpret_cast(x))); - - auto res = cudaMemcpyAsync(sd::LaunchContext::defaultContext()->getScalarPointer(), &startingVal, sizeof(Z), cudaMemcpyHostToDevice, *stream); - if (res != 0) - throw sd::cuda_exception::build("ReduceLongFunction::intermediateXD: failed to copy temporary scalar", res); - - auto ptr = sd::LaunchContext::defaultContext()->getScalarPointer(); - - // scalar assign - functions::scalar::ScalarTransform::executeCudaShaped(launchDims, stream, 14, z, zShapeInfo, hXShapeInfo, z, zShapeInfo, hZShapeInfo, ptr, nullptr); - } - else { - simpleReduce<<>>(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength, reductionPointer, tadShapeInfo, tadOffsets); - } +template +__host__ void ReduceLongFunction::intermediateXD( + dim3 launchDims, cudaStream_t *stream, const void *x, + const Nd4jLong *xShapeInfo, const Nd4jLong *hXShapeInfo, void *extraParams, + void *z, const Nd4jLong *zShapeInfo, const Nd4jLong *hZShapeInfo, + int *dimension, int dimensionLength, void *reductionPointer, + const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { + if (shape::isEmpty(hXShapeInfo)) { + if (shape::isEmpty(hZShapeInfo)) return; + + const auto startingVal = + static_cast(OpType::startingValue(reinterpret_cast(x))); + + auto res = cudaMemcpyAsync( + sd::LaunchContext::defaultContext()->getScalarPointer(), &startingVal, + sizeof(Z), cudaMemcpyHostToDevice, *stream); + if (res != 0) + throw sd::cuda_exception::build( + "ReduceLongFunction::intermediateXD: failed to copy temporary " + "scalar", + res); + + auto ptr = sd::LaunchContext::defaultContext()->getScalarPointer(); + + // scalar assign + functions::scalar::ScalarTransform::executeCudaShaped( + launchDims, stream, 14, z, zShapeInfo, hXShapeInfo, z, zShapeInfo, + hZShapeInfo, ptr, nullptr); + } else { + simpleReduce + <<>>( + x, xShapeInfo, extraParams, z, zShapeInfo, dimension, + dimensionLength, reductionPointer, tadShapeInfo, tadOffsets); + } } //////////////////////////////////////////////////////////////////////// template -template -__host__ void ReduceLongFunction::intermediateScalar(dim3 launchDims, cudaStream_t *stream, - const void *x, const Nd4jLong *xShapeInfo, const Nd4jLong *hXShapeInfo, - void *extraParams, - void *z, const Nd4jLong *zShapeInfo, const Nd4jLong *hZShapeInfo, - int *dimension, int dimensionLength, - void *reductionBuffer, - const Nd4jLong *tadOnlyShapeInfo) { - - if (shape::isEmpty(hXShapeInfo)) { - - if (shape::isEmpty(hZShapeInfo)) - return; - - const auto startingVal = static_cast(OpType::startingValue(reinterpret_cast(x))); - - auto res = cudaMemcpyAsync(z, &startingVal, sizeof(Z), cudaMemcpyHostToDevice, *stream); - if (res != 0) - throw sd::cuda_exception::build("ReduceLongFunction::intermediateScalar: failed to copy resulting scalar", res); - } - else { - simpleScalar<<>>(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength, reductionBuffer, tadOnlyShapeInfo); - } +template +__host__ void ReduceLongFunction::intermediateScalar( + dim3 launchDims, cudaStream_t *stream, const void *x, + const Nd4jLong *xShapeInfo, const Nd4jLong *hXShapeInfo, void *extraParams, + void *z, const Nd4jLong *zShapeInfo, const Nd4jLong *hZShapeInfo, + int *dimension, int dimensionLength, void *reductionBuffer, + const Nd4jLong *tadOnlyShapeInfo) { + if (shape::isEmpty(hXShapeInfo)) { + if (shape::isEmpty(hZShapeInfo)) return; + + const auto startingVal = + static_cast(OpType::startingValue(reinterpret_cast(x))); + + auto res = cudaMemcpyAsync(z, &startingVal, sizeof(Z), + cudaMemcpyHostToDevice, *stream); + if (res != 0) + throw sd::cuda_exception::build( + "ReduceLongFunction::intermediateScalar: failed to copy " + "resulting scalar", + res); + } else { + simpleScalar + <<>>( + x, xShapeInfo, extraParams, z, zShapeInfo, dimension, + dimensionLength, reductionBuffer, tadOnlyShapeInfo); + } } //////////////////////////////////////////////////////////////////////// template -_CUDA_H void ReduceLongFunction::execReduceScalar(dim3 launchDims, cudaStream_t *stream, - const int opNum, - const void *x, const Nd4jLong *xShapeInfo, const Nd4jLong* hXShapeInfo, - void *extraParams, - void *z, const Nd4jLong *zShapeInfo, const Nd4jLong* hZShapeInfo, - int *dimension, int dimensionLength, - void *reductionBuffer, - const Nd4jLong *tadOnlyShapeInfo) { - - DISPATCH_BY_OPNUM_TT(intermediateScalar, PARAMS(launchDims, stream, x, xShapeInfo, hXShapeInfo, extraParams, z, zShapeInfo, hZShapeInfo, dimension, dimensionLength, reductionBuffer, tadOnlyShapeInfo), OPS_A(REDUCE_LONG_OPS)); - sd::DebugHelper::checkErrorCode(stream, "execReduceScalarFloat(...) failed"); +_CUDA_H void ReduceLongFunction::execReduceScalar( + dim3 launchDims, cudaStream_t *stream, const int opNum, const void *x, + const Nd4jLong *xShapeInfo, const Nd4jLong *hXShapeInfo, void *extraParams, + void *z, const Nd4jLong *zShapeInfo, const Nd4jLong *hZShapeInfo, + int *dimension, int dimensionLength, void *reductionBuffer, + const Nd4jLong *tadOnlyShapeInfo) { + DISPATCH_BY_OPNUM_TT( + intermediateScalar, + PARAMS(launchDims, stream, x, xShapeInfo, hXShapeInfo, extraParams, z, + zShapeInfo, hZShapeInfo, dimension, dimensionLength, + reductionBuffer, tadOnlyShapeInfo), + OPS_A(REDUCE_LONG_OPS)); + sd::DebugHelper::checkErrorCode(stream, "execReduceScalarFloat(...) failed"); } //////////////////////////////////////////////////////////////////////// template -_CUDA_H void ReduceLongFunction::execReduceXD(dim3 launchDims, cudaStream_t *stream, - const int opNum, - int rank, const void *x, const Nd4jLong *xShapeInfo, const Nd4jLong* hXShapeInfo, - void *extraParams, - void *z, const Nd4jLong *zShapeInfo, const Nd4jLong* hZShapeInfo, - int *dimension, int dimensionLength, - void *reductionPointer, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { - - DISPATCH_BY_OPNUM_TT(intermediateXD, PARAMS(launchDims, stream, x, xShapeInfo, hXShapeInfo, extraParams, z, zShapeInfo, hZShapeInfo, dimension, dimensionLength, reductionPointer, tadShapeInfo, tadOffsets), OPS_A(REDUCE_LONG_OPS)); - DEBUG_KERNEL(stream, opNum); +_CUDA_H void ReduceLongFunction::execReduceXD( + dim3 launchDims, cudaStream_t *stream, const int opNum, int rank, + const void *x, const Nd4jLong *xShapeInfo, const Nd4jLong *hXShapeInfo, + void *extraParams, void *z, const Nd4jLong *zShapeInfo, + const Nd4jLong *hZShapeInfo, int *dimension, int dimensionLength, + void *reductionPointer, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets) { + DISPATCH_BY_OPNUM_TT( + intermediateXD, + PARAMS(launchDims, stream, x, xShapeInfo, hXShapeInfo, extraParams, z, + zShapeInfo, hZShapeInfo, dimension, dimensionLength, + reductionPointer, tadShapeInfo, tadOffsets), + OPS_A(REDUCE_LONG_OPS)); + DEBUG_KERNEL(stream, opNum); } //////////////////////////////////////////////////////////////////////// template __device__ void initializeShared(X *extraParams, X **sPartials, int sMemSize) { - int sPartialsLength = sMemSize / sizeof(X); - X *sPartialsDeref = (X *) *sPartials; - for (int i = 0; i < sPartialsLength; i++) - sPartialsDeref[i] = extraParams[0]; - + int sPartialsLength = sMemSize / sizeof(X); + X *sPartialsDeref = (X *)*sPartials; + for (int i = 0; i < sPartialsLength; i++) sPartialsDeref[i] = extraParams[0]; } +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT ReduceLongFunction, , + LIBND4J_TYPES, LONG_TYPES); -BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT ReduceLongFunction, , LIBND4J_TYPES, LONG_TYPES); - -} -} - +} // namespace reduce +} // namespace functions diff --git a/libnd4j/include/loops/cuda/reduce/reduce_same.cu b/libnd4j/include/loops/cuda/reduce/reduce_same.cu index c1947314ea6c..cee846ee80c3 100644 --- a/libnd4j/include/loops/cuda/reduce/reduce_same.cu +++ b/libnd4j/include/loops/cuda/reduce/reduce_same.cu @@ -19,310 +19,364 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include -#include -#include -#include -#include -#include #include +#include +#include +#include +#include #include - +#include +#include using namespace simdOps; - //////////////////////////////////////////////////////////////////////// template __global__ void simpleReduce(void const* x, Nd4jLong const* xShapeInfo, - void *extraParams, - void *z, Nd4jLong const* zShapeInfo, - int *dimension, int dimensionLength, - void *reductionBuffer, - Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets) { - - functions::reduce::ReduceSameFunction::template transformCudaXD(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength, reductionBuffer, tadOnlyShapeInfo, tadOffsets); + void* extraParams, void* z, + Nd4jLong const* zShapeInfo, int* dimension, + int dimensionLength, void* reductionBuffer, + Nd4jLong const* tadOnlyShapeInfo, + Nd4jLong const* tadOffsets) { + functions::reduce::ReduceSameFunction::template transformCudaXD( + x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength, + reductionBuffer, tadOnlyShapeInfo, tadOffsets); } //////////////////////////////////////////////////////////////////////// template __global__ void simpleScalar(void const* x, Nd4jLong const* xShapeInfo, - void *extraParams, - void *z, Nd4jLong const* zShapeInfo, - int *dimension, int dimensionLength, - void *reductionBuffer, Nd4jLong const* tadOnlyShapeInfo) { - - functions::reduce::ReduceSameFunction::template execScalarCuda(x, xShapeInfo, extraParams, z, zShapeInfo, reductionBuffer, tadOnlyShapeInfo); + void* extraParams, void* z, + Nd4jLong const* zShapeInfo, int* dimension, + int dimensionLength, void* reductionBuffer, + Nd4jLong const* tadOnlyShapeInfo) { + functions::reduce::ReduceSameFunction::template execScalarCuda( + x, xShapeInfo, extraParams, z, zShapeInfo, reductionBuffer, + tadOnlyShapeInfo); } - namespace functions { -namespace reduce { +namespace reduce { //////////////////////////////////////////////////////////////////////// template template -__device__ void ReduceSameFunction::aggregatePartials(void *vsPartials, Nd4jLong tid, Nd4jLong numItems, void *vextraParams) { +__device__ void ReduceSameFunction::aggregatePartials(void* vsPartials, + Nd4jLong tid, + Nd4jLong numItems, + void* vextraParams) { + // start the shared memory loop on the next power of 2 less + // than the block size. If block size is not a power of 2, + // accumulate the intermediate sums in the remainder range. - // start the shared memory loop on the next power of 2 less - // than the block size. If block size is not a power of 2, - // accumulate the intermediate sums in the remainder range. + auto sPartials = static_cast(vsPartials); + auto extraParams = static_cast(vextraParams); - auto sPartials = static_cast(vsPartials); - auto extraParams = static_cast(vextraParams); + Nd4jLong floorPow2 = numItems; - Nd4jLong floorPow2 = numItems; + if (floorPow2 & (floorPow2 - 1)) { + while (floorPow2 & (floorPow2 - 1)) floorPow2 &= floorPow2 - 1; - if (floorPow2 & (floorPow2 - 1)) { + if (tid >= floorPow2) + sPartials[tid - floorPow2] = OpType::update(sPartials[tid - floorPow2], + sPartials[tid], extraParams); - while (floorPow2 & (floorPow2 - 1)) - floorPow2 &= floorPow2 - 1; - - if (tid >= floorPow2) - sPartials[tid - floorPow2] = OpType::update(sPartials[tid - floorPow2], sPartials[tid], extraParams); - - __syncthreads(); - } + __syncthreads(); + } - for (Nd4jLong activeThreads = floorPow2 >> 1; activeThreads; activeThreads >>= 1) { - if (tid < activeThreads && tid + activeThreads < numItems) - sPartials[tid] = OpType::update(sPartials[tid], sPartials[tid + activeThreads], extraParams); + for (Nd4jLong activeThreads = floorPow2 >> 1; activeThreads; + activeThreads >>= 1) { + if (tid < activeThreads && tid + activeThreads < numItems) + sPartials[tid] = OpType::update( + sPartials[tid], sPartials[tid + activeThreads], extraParams); - __syncthreads(); - } + __syncthreads(); + } } //////////////////////////////////////////////////////////////////////// template template -__device__ void ReduceSameFunction::transformCudaXD( void const* vx, Nd4jLong const* xShapeInfo, - void *vextraParams, - void *vz, Nd4jLong const* zShapeInfo, - int *dimension, int dimensionLength, - void *vreductionBuffer, - Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets) { - - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - auto extraParams = reinterpret_cast(vextraParams); - auto reductionBuffer = reinterpret_cast(vreductionBuffer); - - if (OpType::requiresSpecialAccumulation) { - OpType::execSpecialCuda(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength, reductionBuffer, tadOnlyShapeInfo, tadOffsets); - return; - } - - //shared memory space for storing intermediate results - __shared__ X* sPartials; - - __shared__ int tadLength, tadRank, numTads; - __shared__ Nd4jLong *tadShape, *tadStride; - __shared__ bool isPlainOutput; - - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - sPartials = reinterpret_cast(shmem); - - isPlainOutput = shape::order(zShapeInfo) == 'c' && shape::elementWiseStride(zShapeInfo) == 1; - - tadLength = shape::length(tadOnlyShapeInfo); - tadRank = shape::rank(tadOnlyShapeInfo); - numTads = shape::length(xShapeInfo) / tadLength; - tadShape = shape::shapeOf(tadOnlyShapeInfo); - tadStride = shape::stride(tadOnlyShapeInfo); +__device__ void ReduceSameFunction::transformCudaXD( + void const* vx, Nd4jLong const* xShapeInfo, void* vextraParams, void* vz, + Nd4jLong const* zShapeInfo, int* dimension, int dimensionLength, + void* vreductionBuffer, Nd4jLong const* tadOnlyShapeInfo, + Nd4jLong const* tadOffsets) { + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + auto extraParams = reinterpret_cast(vextraParams); + auto reductionBuffer = reinterpret_cast(vreductionBuffer); + + if (OpType::requiresSpecialAccumulation) { + OpType::execSpecialCuda(x, xShapeInfo, extraParams, z, zShapeInfo, + dimension, dimensionLength, reductionBuffer, + tadOnlyShapeInfo, tadOffsets); + return; + } + + // shared memory space for storing intermediate results + __shared__ X* sPartials; + + __shared__ int tadLength, tadRank, numTads; + __shared__ Nd4jLong *tadShape, *tadStride; + __shared__ bool isPlainOutput; + + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sPartials = reinterpret_cast(shmem); + + isPlainOutput = shape::order(zShapeInfo) == 'c' && + shape::elementWiseStride(zShapeInfo) == 1; + + tadLength = shape::length(tadOnlyShapeInfo); + tadRank = shape::rank(tadOnlyShapeInfo); + numTads = shape::length(xShapeInfo) / tadLength; + tadShape = shape::shapeOf(tadOnlyShapeInfo); + tadStride = shape::stride(tadOnlyShapeInfo); + } + __syncthreads(); + + for (int r = blockIdx.x; r < numTads; r += gridDim.x) { + Nd4jLong tadOffsetForBlock = tadOffsets[r]; + sPartials[threadIdx.x] = OpType::startingValue(x + tadOffsetForBlock); + + for (int i = threadIdx.x; i < tadLength; i += blockDim.x) { + auto xOffset = + tadOffsetForBlock + shape::getIndexOffset(i, tadOnlyShapeInfo); + sPartials[threadIdx.x] = + OpType::update(sPartials[threadIdx.x], + OpType::op(x[xOffset], extraParams), extraParams); } __syncthreads(); - for (int r = blockIdx.x; r < numTads; r += gridDim.x) { - - Nd4jLong tadOffsetForBlock = tadOffsets[r]; - sPartials[threadIdx.x] = OpType::startingValue(x + tadOffsetForBlock); - - for (int i = threadIdx.x; i < tadLength; i += blockDim.x) { - auto xOffset = tadOffsetForBlock + shape::getIndexOffset(i, tadOnlyShapeInfo); - sPartials[threadIdx.x] = OpType::update(sPartials[threadIdx.x], OpType::op(x[xOffset], extraParams), extraParams); - } - __syncthreads(); - - // aggregate. do NOT reduce for elements > tadLength - aggregatePartials(sPartials, threadIdx.x, sd::math::nd4j_min(blockDim.x, tadLength), extraParams); - __syncthreads(); + // aggregate. do NOT reduce for elements > tadLength + aggregatePartials(sPartials, threadIdx.x, + sd::math::nd4j_min(blockDim.x, tadLength), + extraParams); + __syncthreads(); - if (threadIdx.x == 0) - z[isPlainOutput ? r : shape::getIndexOffset(r, zShapeInfo)] = OpType::postProcess(sPartials[threadIdx.x], tadLength, extraParams); - } + if (threadIdx.x == 0) + z[isPlainOutput ? r : shape::getIndexOffset(r, zShapeInfo)] = + OpType::postProcess(sPartials[threadIdx.x], tadLength, extraParams); + } } //////////////////////////////////////////////////////////////////////// template -__device__ void ReduceSameFunction::execScalarCudaLegacy(int opNum, void const* vx, Nd4jLong const* xShapeInfo, - void *vextraParams, - void *vz, Nd4jLong const* zShapeInfo, - void *vreductionBuffer, - Nd4jLong const* tadOnlyShapeInfo) { - DISPATCH_BY_OPNUM_T(execScalarCuda, PARAMS(vx, xShapeInfo, vextraParams, vz, zShapeInfo, vreductionBuffer, tadOnlyShapeInfo), REDUCE_SAME_OPS); +__device__ void ReduceSameFunction::execScalarCudaLegacy( + int opNum, void const* vx, Nd4jLong const* xShapeInfo, void* vextraParams, + void* vz, Nd4jLong const* zShapeInfo, void* vreductionBuffer, + Nd4jLong const* tadOnlyShapeInfo) { + DISPATCH_BY_OPNUM_T(execScalarCuda, + PARAMS(vx, xShapeInfo, vextraParams, vz, zShapeInfo, + vreductionBuffer, tadOnlyShapeInfo), + REDUCE_SAME_OPS); } //////////////////////////////////////////////////////////////////////// template template -__device__ void ReduceSameFunction::execScalarCuda(void const* vx, Nd4jLong const* xShapeInfo, - void *vextraParams, - void * vz, Nd4jLong const* zShapeInfo, - void *vreductionBuffer, - Nd4jLong const* tadOnlyShapeInfo) { - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - auto extraParams = reinterpret_cast(vextraParams); - auto reductionBuffer = reinterpret_cast(vreductionBuffer); - - auto tid = blockDim.x * blockIdx.x + threadIdx.x; - - //shared memory space for storing intermediate results - __shared__ X* sPartials; - __shared__ Nd4jLong xEws; - __shared__ Nd4jLong len; - - if(threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - sPartials = reinterpret_cast(shmem); - xEws = shape::elementWiseStride(xShapeInfo); - len = shape::length(xShapeInfo); - } +__device__ void ReduceSameFunction::execScalarCuda( + void const* vx, Nd4jLong const* xShapeInfo, void* vextraParams, void* vz, + Nd4jLong const* zShapeInfo, void* vreductionBuffer, + Nd4jLong const* tadOnlyShapeInfo) { + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + auto extraParams = reinterpret_cast(vextraParams); + auto reductionBuffer = reinterpret_cast(vreductionBuffer); + + auto tid = blockDim.x * blockIdx.x + threadIdx.x; + + // shared memory space for storing intermediate results + __shared__ X* sPartials; + __shared__ Nd4jLong xEws; + __shared__ Nd4jLong len; + + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sPartials = reinterpret_cast(shmem); + xEws = shape::elementWiseStride(xShapeInfo); + len = shape::length(xShapeInfo); + } + __syncthreads(); + sPartials[threadIdx.x] = OpType::startingValue(x); + + if (xEws > 0) + for (int i = tid; i < len; i += (blockDim.x * gridDim.x)) + sPartials[threadIdx.x] = + OpType::update(sPartials[threadIdx.x], + OpType::op(x[i * xEws], extraParams), extraParams); + else + for (int i = tid; i < len; i += blockDim.x * gridDim.x) + sPartials[threadIdx.x] = OpType::update( + sPartials[threadIdx.x], + OpType::op(x[shape::getIndexOffset(i, xShapeInfo)], extraParams), + extraParams); + + __syncthreads(); + aggregatePartials(sPartials, threadIdx.x, + sd::math::nd4j_min(blockDim.x, len), + extraParams); + __syncthreads(); + + if (gridDim.x > 1) { + unsigned int* tc = (unsigned int*)reductionBuffer; + __shared__ bool amLast; + + tid = threadIdx.x; + if (threadIdx.x == 0) + reductionBuffer[blockIdx.x] = + sPartials[0]; // this->postProcess(sPartials[0],len,extraParams); + + __threadfence(); __syncthreads(); - sPartials[threadIdx.x] = OpType::startingValue(x); - if (xEws > 0) - for (int i = tid; i < len; i += (blockDim.x * gridDim.x)) - sPartials[threadIdx.x] = OpType::update(sPartials[threadIdx.x], OpType::op(x[i * xEws], extraParams), extraParams); - else - for (int i = tid; i < len; i += blockDim.x * gridDim.x) - sPartials[threadIdx.x] = OpType::update(sPartials[threadIdx.x], OpType::op(x[shape::getIndexOffset(i, xShapeInfo)], extraParams), extraParams); + if (threadIdx.x == 0) { + unsigned int ticket = atomicInc(&tc[16384], gridDim.x); + amLast = (ticket == gridDim.x - 1); + } __syncthreads(); - aggregatePartials(sPartials, threadIdx.x, sd::math::nd4j_min(blockDim.x, len), extraParams); - __syncthreads(); - - if (gridDim.x > 1) { - - unsigned int *tc = (unsigned int *)reductionBuffer; - __shared__ bool amLast; - - tid = threadIdx.x; - if (threadIdx.x == 0) - reductionBuffer[blockIdx.x] = sPartials[0];//this->postProcess(sPartials[0],len,extraParams); - - __threadfence(); - __syncthreads(); - if (threadIdx.x == 0) { - unsigned int ticket = atomicInc(&tc[16384], gridDim.x); - amLast = (ticket == gridDim.x - 1); - } + if (amLast) { + tc[16384] = 0; + sPartials[threadIdx.x] = OpType::startingValue(x); - __syncthreads(); + for (int i = threadIdx.x; i < gridDim.x; i += blockDim.x) + sPartials[threadIdx.x] = OpType::update( + sPartials[threadIdx.x], reductionBuffer[i], extraParams); - if (amLast) { - tc[16384] = 0; - sPartials[threadIdx.x] = OpType::startingValue(x); + __syncthreads(); + aggregatePartials(sPartials, threadIdx.x, + sd::math::nd4j_min(gridDim.x, blockDim.x), + extraParams); + __syncthreads(); - for (int i = threadIdx.x; i < gridDim.x; i += blockDim.x) - sPartials[threadIdx.x] = OpType::update(sPartials[threadIdx.x], reductionBuffer[i], extraParams); - - __syncthreads(); - aggregatePartials(sPartials, threadIdx.x, sd::math::nd4j_min(gridDim.x, blockDim.x), extraParams); - __syncthreads(); - - if (threadIdx.x == 0) { - z[0] = OpType::postProcess(sPartials[0], len, extraParams); - } - } + if (threadIdx.x == 0) { + z[0] = OpType::postProcess(sPartials[0], len, extraParams); + } } - else { - - if (threadIdx.x == 0) { - auto tc = reinterpret_cast(reductionBuffer); - tc[16384] = 0; - z[0] = OpType::postProcess(sPartials[0], len, extraParams); - } + } else { + if (threadIdx.x == 0) { + auto tc = reinterpret_cast(reductionBuffer); + tc[16384] = 0; + z[0] = OpType::postProcess(sPartials[0], len, extraParams); } + } } //////////////////////////////////////////////////////////////////////// template -template -__host__ void ReduceSameFunction::intermediateXD(dim3 launchDims, cudaStream_t *stream, void const* x, Nd4jLong const* xShapeInfo, Nd4jLong const* hXShapeInfo, void *extraParams, void *z, Nd4jLong const* zShapeInfo, Nd4jLong const* hZShapeInfo, int *dimension, int dimensionLength, void *reductionPointer, Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets) { - - if(shape::isEmpty(hXShapeInfo)) { - - if(shape::isEmpty(hZShapeInfo)) - return; - - const auto startingVal = static_cast(OpType::startingValue(reinterpret_cast(x))); - - auto res = cudaMemcpyAsync(sd::LaunchContext::defaultContext()->getScalarPointer(), &startingVal, sizeof(X), cudaMemcpyHostToDevice, *stream); - if (res != 0) - throw sd::cuda_exception::build("ReduceSameFunction::intermediateXD: failed to copy temporary scalar", res); - - auto ptr = sd::LaunchContext::defaultContext()->getScalarPointer(); - - // scalar assign - functions::scalar::ScalarTransform::executeCudaShaped(launchDims, stream, 14, z, zShapeInfo, hXShapeInfo, z, zShapeInfo, hZShapeInfo, ptr, nullptr); - } - else { - simpleReduce<<>>(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength, reductionPointer, tadShapeInfo, tadOffsets); - } +template +__host__ void ReduceSameFunction::intermediateXD( + dim3 launchDims, cudaStream_t* stream, void const* x, + Nd4jLong const* xShapeInfo, Nd4jLong const* hXShapeInfo, void* extraParams, + void* z, Nd4jLong const* zShapeInfo, Nd4jLong const* hZShapeInfo, + int* dimension, int dimensionLength, void* reductionPointer, + Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets) { + if (shape::isEmpty(hXShapeInfo)) { + if (shape::isEmpty(hZShapeInfo)) return; + + const auto startingVal = + static_cast(OpType::startingValue(reinterpret_cast(x))); + + auto res = cudaMemcpyAsync( + sd::LaunchContext::defaultContext()->getScalarPointer(), &startingVal, + sizeof(X), cudaMemcpyHostToDevice, *stream); + if (res != 0) + throw sd::cuda_exception::build( + "ReduceSameFunction::intermediateXD: failed to copy temporary " + "scalar", + res); + + auto ptr = sd::LaunchContext::defaultContext()->getScalarPointer(); + + // scalar assign + functions::scalar::ScalarTransform::executeCudaShaped( + launchDims, stream, 14, z, zShapeInfo, hXShapeInfo, z, zShapeInfo, + hZShapeInfo, ptr, nullptr); + } else { + simpleReduce + <<>>( + x, xShapeInfo, extraParams, z, zShapeInfo, dimension, + dimensionLength, reductionPointer, tadShapeInfo, tadOffsets); + } } //////////////////////////////////////////////////////////////////////// template -template -__host__ void ReduceSameFunction::intermediateScalar(dim3 launchDims, cudaStream_t *stream, void const* x, Nd4jLong const* xShapeInfo, Nd4jLong const* hXShapeInfo, void *extraParams, void *z, Nd4jLong const* zShapeInfo, Nd4jLong const* hZShapeInfo, int *dimension, int dimensionLength, void *reductionBuffer, Nd4jLong const* tadOnlyShapeInfo) { - - if (shape::isEmpty(hXShapeInfo)) { - - if (shape::isEmpty(hZShapeInfo)) - return; - - const auto startingVal = static_cast(OpType::startingValue(reinterpret_cast(x))); - - auto res = cudaMemcpyAsync(z, &startingVal, sizeof(X), cudaMemcpyHostToDevice, *stream); - if (res != 0) - throw sd::cuda_exception::build("ReduceSameFunction::intermediateScalar: failed to copy resulting scalar", res); - } - else { - simpleScalar<<>>(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength, reductionBuffer, tadOnlyShapeInfo); - } +template +__host__ void ReduceSameFunction::intermediateScalar( + dim3 launchDims, cudaStream_t* stream, void const* x, + Nd4jLong const* xShapeInfo, Nd4jLong const* hXShapeInfo, void* extraParams, + void* z, Nd4jLong const* zShapeInfo, Nd4jLong const* hZShapeInfo, + int* dimension, int dimensionLength, void* reductionBuffer, + Nd4jLong const* tadOnlyShapeInfo) { + if (shape::isEmpty(hXShapeInfo)) { + if (shape::isEmpty(hZShapeInfo)) return; + + const auto startingVal = + static_cast(OpType::startingValue(reinterpret_cast(x))); + + auto res = cudaMemcpyAsync(z, &startingVal, sizeof(X), + cudaMemcpyHostToDevice, *stream); + if (res != 0) + throw sd::cuda_exception::build( + "ReduceSameFunction::intermediateScalar: failed to copy resulting " + "scalar", + res); + } else { + simpleScalar + <<>>( + x, xShapeInfo, extraParams, z, zShapeInfo, dimension, + dimensionLength, reductionBuffer, tadOnlyShapeInfo); + } } //////////////////////////////////////////////////////////////////////// template -_CUDA_H void ReduceSameFunction::execReduceScalar(dim3 launchDims, cudaStream_t *stream, int opNum, void const* x, Nd4jLong const* xShapeInfo, Nd4jLong const* hXShapeInfo, void *extraParams, void *z, Nd4jLong const* zShapeInfo, Nd4jLong const* hZShapeInfo, int *dimension, int dimensionLength, void *reductionBuffer, Nd4jLong const* tadOnlyShapeInfo) { - - DISPATCH_BY_OPNUM_T(intermediateScalar, PARAMS(launchDims, stream, x, xShapeInfo, hXShapeInfo, extraParams, z, zShapeInfo, hZShapeInfo, dimension, dimensionLength, reductionBuffer, tadOnlyShapeInfo), REDUCE_SAME_OPS); - sd::DebugHelper::checkErrorCode(stream, "execReduceScalarSame(...) failed"); +_CUDA_H void ReduceSameFunction::execReduceScalar( + dim3 launchDims, cudaStream_t* stream, int opNum, void const* x, + Nd4jLong const* xShapeInfo, Nd4jLong const* hXShapeInfo, void* extraParams, + void* z, Nd4jLong const* zShapeInfo, Nd4jLong const* hZShapeInfo, + int* dimension, int dimensionLength, void* reductionBuffer, + Nd4jLong const* tadOnlyShapeInfo) { + DISPATCH_BY_OPNUM_T( + intermediateScalar, + PARAMS(launchDims, stream, x, xShapeInfo, hXShapeInfo, extraParams, z, + zShapeInfo, hZShapeInfo, dimension, dimensionLength, + reductionBuffer, tadOnlyShapeInfo), + REDUCE_SAME_OPS); + sd::DebugHelper::checkErrorCode(stream, "execReduceScalarSame(...) failed"); } //////////////////////////////////////////////////////////////////////// template -_CUDA_H void ReduceSameFunction::execReduceXD(dim3 launchDims, cudaStream_t *stream, int opNum, int rank, void const* x, Nd4jLong const* xShapeInfo, Nd4jLong const* hXShapeInfo, void *extraParams, void *z, Nd4jLong const* zShapeInfo, Nd4jLong const* hZShapeInfo, int *dimension, int dimensionLength, void *reductionPointer, Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets) { - - DISPATCH_BY_OPNUM_T(intermediateXD, PARAMS(launchDims, stream, x, xShapeInfo, hXShapeInfo, extraParams, z, zShapeInfo, hZShapeInfo, dimension, dimensionLength, reductionPointer, tadShapeInfo, tadOffsets), REDUCE_SAME_OPS); - DEBUG_KERNEL(stream, opNum); +_CUDA_H void ReduceSameFunction::execReduceXD( + dim3 launchDims, cudaStream_t* stream, int opNum, int rank, void const* x, + Nd4jLong const* xShapeInfo, Nd4jLong const* hXShapeInfo, void* extraParams, + void* z, Nd4jLong const* zShapeInfo, Nd4jLong const* hZShapeInfo, + int* dimension, int dimensionLength, void* reductionPointer, + Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets) { + DISPATCH_BY_OPNUM_T( + intermediateXD, + PARAMS(launchDims, stream, x, xShapeInfo, hXShapeInfo, extraParams, z, + zShapeInfo, hZShapeInfo, dimension, dimensionLength, + reductionPointer, tadShapeInfo, tadOffsets), + REDUCE_SAME_OPS); + DEBUG_KERNEL(stream, opNum); } //////////////////////////////////////////////////////////////////////// template -__device__ void initializeShared(X *extraParams, X **sPartials, int sMemSize) { - int sPartialsLength = sMemSize / sizeof(X); - X *sPartialsDeref = (X *) *sPartials; - for (int i = 0; i < sPartialsLength; i++) - sPartialsDeref[i] = extraParams[0]; - +__device__ void initializeShared(X* extraParams, X** sPartials, int sMemSize) { + int sPartialsLength = sMemSize / sizeof(X); + X* sPartialsDeref = (X*)*sPartials; + for (int i = 0; i < sPartialsLength; i++) sPartialsDeref[i] = extraParams[0]; } +BUILD_SINGLE_TEMPLATE(template class SD_EXPORT ReduceSameFunction, , + LIBND4J_TYPES); -BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT ReduceSameFunction, , LIBND4J_TYPES); - -} -} \ No newline at end of file +} // namespace reduce +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/reduce3.chpp b/libnd4j/include/loops/cuda/reduce3.chpp index 2a301b81714d..1549a0aeefbe 100644 --- a/libnd4j/include/loops/cuda/reduce3.chpp +++ b/libnd4j/include/loops/cuda/reduce3.chpp @@ -17,546 +17,548 @@ // @author raver119@gmail.com // @author Yurii Shyrma (iuriish@yahoo.com), created on 19.11.2018 - -#include -#include #include -#include +#include #include +#include +#include using namespace simdOps; namespace functions { -namespace reduce3 { +namespace reduce3 { //////////////////////////////////////////////////////////////////////// template -__global__ void execScalarGeneric(const int opNum, - void const* vx, Nd4jLong const* xShapeInfo, - void const* vy, Nd4jLong const* yShapeInfo, - void *extraParams, - void *vz, Nd4jLong const* zShapeInfo, - int* allocationPointer, - void *reductionBuffer, - Nd4jLong const* tadOnlyShapeInfo) { - - Reduce3::execScalarCuda(opNum, vx, xShapeInfo, vy, yShapeInfo, extraParams, vz, zShapeInfo, allocationPointer, reductionBuffer, tadOnlyShapeInfo); +__global__ void execScalarGeneric(const int opNum, void const* vx, + Nd4jLong const* xShapeInfo, void const* vy, + Nd4jLong const* yShapeInfo, void* extraParams, + void* vz, Nd4jLong const* zShapeInfo, + int* allocationPointer, void* reductionBuffer, + Nd4jLong const* tadOnlyShapeInfo) { + Reduce3::execScalarCuda(opNum, vx, xShapeInfo, vy, yShapeInfo, + extraParams, vz, zShapeInfo, allocationPointer, + reductionBuffer, tadOnlyShapeInfo); } template -__global__ void execAllGeneric(const int opNum, - void const* vx, Nd4jLong const* xShapeInfo, - void const* vy, Nd4jLong const* yShapeInfo, - void *extraParams, - void *vz, Nd4jLong const* zShapeInfo, - int *dimension, int dimensionLength, - int postProcessOrNot, - int *allocationPointer, - Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, - Nd4jLong const* yTadOnlyShapeInfo, Nd4jLong const* yTadOffsets) { - - Reduce3::execAllCuda(opNum, vx, xShapeInfo, vy, yShapeInfo, extraParams, vz, zShapeInfo, dimension, dimensionLength, postProcessOrNot, allocationPointer, tadOnlyShapeInfo, tadOffsets, yTadOnlyShapeInfo, yTadOffsets); +__global__ void execAllGeneric( + const int opNum, void const* vx, Nd4jLong const* xShapeInfo, void const* vy, + Nd4jLong const* yShapeInfo, void* extraParams, void* vz, + Nd4jLong const* zShapeInfo, int* dimension, int dimensionLength, + int postProcessOrNot, int* allocationPointer, + Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, + Nd4jLong const* yTadOnlyShapeInfo, Nd4jLong const* yTadOffsets) { + Reduce3::execAllCuda( + opNum, vx, xShapeInfo, vy, yShapeInfo, extraParams, vz, zShapeInfo, + dimension, dimensionLength, postProcessOrNot, allocationPointer, + tadOnlyShapeInfo, tadOffsets, yTadOnlyShapeInfo, yTadOffsets); } - //////////////////////////////////////////////////////////////////////// template -__global__ void execGeneric(const int opNum, - void const* vx, Nd4jLong const* xShapeInfo, - void const* vy, Nd4jLong const* yShapeInfo, - void *extraParams, - void *vz, Nd4jLong const* zShapeInfo, - int *dimension, int dimensionLength, - int postProcessOrNot, - int *allocationPointer, - Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, - Nd4jLong const* yTadOnlyShapeInfo, Nd4jLong const* yTadOffsets) { - - Reduce3::execCuda(opNum, vx, xShapeInfo, vy, yShapeInfo, extraParams, vz, zShapeInfo, dimension, dimensionLength, postProcessOrNot, allocationPointer, tadOnlyShapeInfo, tadOffsets, yTadOnlyShapeInfo, yTadOffsets); +__global__ void execGeneric( + const int opNum, void const* vx, Nd4jLong const* xShapeInfo, void const* vy, + Nd4jLong const* yShapeInfo, void* extraParams, void* vz, + Nd4jLong const* zShapeInfo, int* dimension, int dimensionLength, + int postProcessOrNot, int* allocationPointer, + Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, + Nd4jLong const* yTadOnlyShapeInfo, Nd4jLong const* yTadOffsets) { + Reduce3::execCuda(opNum, vx, xShapeInfo, vy, yShapeInfo, extraParams, + vz, zShapeInfo, dimension, dimensionLength, + postProcessOrNot, allocationPointer, tadOnlyShapeInfo, + tadOffsets, yTadOnlyShapeInfo, yTadOffsets); } - ////////////////////////////////////////////////////////////////////////// template template -__device__ void Reduce3::aggregatePartials(void* vsPartials, Nd4jLong tid, Nd4jLong numItems, void *vextraParams) { +__device__ void Reduce3::aggregatePartials(void* vsPartials, Nd4jLong tid, + Nd4jLong numItems, + void* vextraParams) { + // start the shared memory loop on the next power of 2 less + // than the block size. If block size is not a power of 2, + // accumulate the intermediate sums in the remainder range. - // start the shared memory loop on the next power of 2 less - // than the block size. If block size is not a power of 2, - // accumulate the intermediate sums in the remainder range. + auto sPartials = reinterpret_cast(vsPartials); + auto extraParams = reinterpret_cast(vextraParams); + Nd4jLong floorPow2 = numItems; - auto sPartials = reinterpret_cast(vsPartials); - auto extraParams = reinterpret_cast(vextraParams); - Nd4jLong floorPow2 = numItems; + if (floorPow2 & (floorPow2 - 1)) { + while (floorPow2 & (floorPow2 - 1)) floorPow2 &= floorPow2 - 1; - if (floorPow2 & (floorPow2 - 1)) { + if (tid >= floorPow2) + sPartials[tid - floorPow2] = OpType::update(sPartials[tid - floorPow2], + sPartials[tid], extraParams); - while(floorPow2 & (floorPow2 - 1)) - floorPow2 &= floorPow2 - 1; - - if (tid >= floorPow2) - sPartials[tid - floorPow2] = OpType::update(sPartials[tid - floorPow2], sPartials[tid], extraParams); - - __syncthreads(); - } + __syncthreads(); + } - for (Nd4jLong activeThreads = floorPow2 >> 1; activeThreads; activeThreads >>= 1) { - if (tid < activeThreads) - sPartials[tid] = OpType::update(sPartials[tid], sPartials[tid + activeThreads], extraParams); + for (Nd4jLong activeThreads = floorPow2 >> 1; activeThreads; + activeThreads >>= 1) { + if (tid < activeThreads) + sPartials[tid] = OpType::update( + sPartials[tid], sPartials[tid + activeThreads], extraParams); - __syncthreads(); - } + __syncthreads(); + } } ////////////////////////////////////////////////////////////////////////// template -template -__device__ void Reduce3::execScalarCuda( void const* vx, Nd4jLong const* xShapeInfo, - void const* vy, Nd4jLong const* yShapeInfo, - void *extraParams, - void* vz, Nd4jLong const* zShapeInfo, - int *allocationPointer, void *reductionBuffer, Nd4jLong const* tadOnlyShapeInfo) { - - auto x = reinterpret_cast(vx); - auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); - - __shared__ Z extraZ[3]; - __shared__ Z* sPartials; - - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - sPartials = reinterpret_cast(shmem); - - extraZ[0] = (Z) 0.0f; - extraZ[1] = (Z) 0.0f; - - if (extraParams != nullptr) - extraZ[2] = static_cast(extraParams)[2]; - else - extraZ[2] = (Z) 0.0f; - } - __syncthreads(); - - sPartials[threadIdx.x] = OpType::startingValue(x); - Nd4jLong length = shape::length(xShapeInfo); - int xEws = shape::elementWiseStride(xShapeInfo); - int yEws = shape::elementWiseStride(yShapeInfo); - int tid = blockIdx.x * blockDim.x + threadIdx.x; - char xOrder = shape::order(xShapeInfo); - char yOrder = shape::order(yShapeInfo); - - if(xOrder == yOrder && (xEws > 0 && yEws > 0) && shape::strideDescendingCAscendingF(xShapeInfo) && shape::strideDescendingCAscendingF(yShapeInfo)) { - - if (xEws == 1 && yEws == 1) { - for(Nd4jLong i = tid; i < length; i+= gridDim.x * blockDim.x) - sPartials[threadIdx.x] = OpType::update(sPartials[threadIdx.x], OpType::opAtomic(x[i], y[i], extraZ), extraZ); - } - else { - for(Nd4jLong i = tid; i < length; i+= gridDim.x * blockDim.x) - sPartials[threadIdx.x] = OpType::update(sPartials[threadIdx.x], OpType::opAtomic(x[i * xEws], y[i * yEws], extraZ), extraZ); - } +template +__device__ void Reduce3::execScalarCuda( + void const* vx, Nd4jLong const* xShapeInfo, void const* vy, + Nd4jLong const* yShapeInfo, void* extraParams, void* vz, + Nd4jLong const* zShapeInfo, int* allocationPointer, void* reductionBuffer, + Nd4jLong const* tadOnlyShapeInfo) { + auto x = reinterpret_cast(vx); + auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + + __shared__ Z extraZ[3]; + __shared__ Z* sPartials; + + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sPartials = reinterpret_cast(shmem); + + extraZ[0] = (Z)0.0f; + extraZ[1] = (Z)0.0f; + + if (extraParams != nullptr) + extraZ[2] = static_cast(extraParams)[2]; + else + extraZ[2] = (Z)0.0f; + } + __syncthreads(); + + sPartials[threadIdx.x] = OpType::startingValue(x); + Nd4jLong length = shape::length(xShapeInfo); + int xEws = shape::elementWiseStride(xShapeInfo); + int yEws = shape::elementWiseStride(yShapeInfo); + int tid = blockIdx.x * blockDim.x + threadIdx.x; + char xOrder = shape::order(xShapeInfo); + char yOrder = shape::order(yShapeInfo); + + if (xOrder == yOrder && (xEws > 0 && yEws > 0) && + shape::strideDescendingCAscendingF(xShapeInfo) && + shape::strideDescendingCAscendingF(yShapeInfo)) { + if (xEws == 1 && yEws == 1) { + for (Nd4jLong i = tid; i < length; i += gridDim.x * blockDim.x) + sPartials[threadIdx.x] = + OpType::update(sPartials[threadIdx.x], + OpType::opAtomic(x[i], y[i], extraZ), extraZ); + } else { + for (Nd4jLong i = tid; i < length; i += gridDim.x * blockDim.x) + sPartials[threadIdx.x] = OpType::update( + sPartials[threadIdx.x], + OpType::opAtomic(x[i * xEws], y[i * yEws], extraZ), extraZ); } - else { - sPartials[threadIdx.x] = OpType::startingValue(x); - auto threadCount = gridDim.x * blockDim.x; - for(Nd4jLong i = tid; i < length; i += threadCount) { - auto xOffset = shape::getIndexOffset(i, xShapeInfo); - auto yOffset = shape::getIndexOffset(i, yShapeInfo); - sPartials[threadIdx.x] = OpType::update(sPartials[threadIdx.x], OpType::opAtomic(x[xOffset], y[yOffset], extraZ), extraZ); - } + } else { + sPartials[threadIdx.x] = OpType::startingValue(x); + auto threadCount = gridDim.x * blockDim.x; + for (Nd4jLong i = tid; i < length; i += threadCount) { + auto xOffset = shape::getIndexOffset(i, xShapeInfo); + auto yOffset = shape::getIndexOffset(i, yShapeInfo); + sPartials[threadIdx.x] = OpType::update( + sPartials[threadIdx.x], + OpType::opAtomic(x[xOffset], y[yOffset], extraZ), extraZ); + } + } + + __syncthreads(); + aggregatePartials(reinterpret_cast(sPartials), threadIdx.x, + sd::math::nd4j_min(blockDim.x, length), + extraZ); + __syncthreads(); + + if (gridDim.x > 1) { + auto tc = reinterpret_cast(reductionBuffer); + __shared__ bool amLast; + int rank = shape::rank(xShapeInfo); + tid = threadIdx.x; + Z* extraBuffer = (Z*)allocationPointer; + if (threadIdx.x == 0) { + reinterpret_cast(reductionBuffer)[blockIdx.x] = sPartials[0]; + extraBuffer[blockIdx.x] = extraZ[0]; + extraBuffer[gridDim.x + blockIdx.x] = extraZ[1]; } - __syncthreads(); - aggregatePartials(reinterpret_cast(sPartials), threadIdx.x, sd::math::nd4j_min(blockDim.x, length), extraZ); + __threadfence(); __syncthreads(); - if (gridDim.x > 1) { - - auto tc = reinterpret_cast(reductionBuffer); - __shared__ bool amLast; - int rank = shape::rank(xShapeInfo); - tid = threadIdx.x; - Z *extraBuffer = (Z *) allocationPointer; - if (threadIdx.x == 0) { - reinterpret_cast(reductionBuffer)[blockIdx.x] = sPartials[0]; - extraBuffer[blockIdx.x] = extraZ[0]; - extraBuffer[gridDim.x + blockIdx.x] = extraZ[1]; - } + if (threadIdx.x == 0) { + unsigned int ticket = atomicInc(&tc[16384], gridDim.x); + amLast = (ticket == gridDim.x - 1); + } - __threadfence(); - __syncthreads(); + sPartials[tid] = OpType::startingValue(x); + __syncthreads(); - if (threadIdx.x == 0) { - unsigned int ticket = atomicInc(&tc[16384], gridDim.x); - amLast = (ticket == gridDim.x - 1); + if (amLast) { + tc[16384] = 0; + sPartials[threadIdx.x] = OpType::startingValue(x); + + // TODO: later probably replace this. Right now we need extraZ sync for + // CosineSimilarity ONLY + if (tid == 0 && extraZ[0] != static_cast(0) && + extraZ[1] != static_cast(0)) { + extraZ[0] = 0.0; + extraZ[1] = 0.0; + for (int i = 0; i < gridDim.x; i++) { + extraZ[0] += extraBuffer[i]; + extraZ[1] += extraBuffer[gridDim.x + i]; } + } - sPartials[tid] = OpType::startingValue(x); - __syncthreads(); - - if (amLast) { - - tc[16384] = 0; - sPartials[threadIdx.x] = OpType::startingValue(x); - - // TODO: later probably replace this. Right now we need extraZ sync for CosineSimilarity ONLY - if (tid == 0 && extraZ[0] != static_cast(0) && extraZ[1] != static_cast(0)) { - extraZ[0] = 0.0; - extraZ[1] = 0.0; - for (int i = 0; i < gridDim.x; i++) { - extraZ[0] += extraBuffer[i]; - extraZ[1] += extraBuffer[gridDim.x + i]; - } - } + for (Nd4jLong i = threadIdx.x; i < gridDim.x; i += blockDim.x) + sPartials[threadIdx.x] = + OpType::update(sPartials[threadIdx.x], + static_cast(reductionBuffer)[i], extraZ); - for (Nd4jLong i = threadIdx.x; i < gridDim.x; i += blockDim.x) - sPartials[threadIdx.x] = OpType::update(sPartials[threadIdx.x], static_cast(reductionBuffer)[i], extraZ); + __syncthreads(); + aggregatePartials(reinterpret_cast(sPartials), threadIdx.x, + sd::math::nd4j_min(gridDim.x, blockDim.x), + extraZ); + __syncthreads(); - __syncthreads(); - aggregatePartials(reinterpret_cast(sPartials), threadIdx.x, sd::math::nd4j_min(gridDim.x, blockDim.x), extraZ); - __syncthreads(); - - if (threadIdx.x == 0) - z[0] = OpType::postProcess(sPartials[0], length, extraZ); - } + if (threadIdx.x == 0) + z[0] = OpType::postProcess(sPartials[0], length, extraZ); } - else { - - if (tid == 0) { - auto tc = reinterpret_cast(reductionBuffer); - tc[16384] = 0; - z[0] = OpType::postProcess(sPartials[0], length, extraZ); - //printf("Z: [%f]\n", (float) z[0]); - } + } else { + if (tid == 0) { + auto tc = reinterpret_cast(reductionBuffer); + tc[16384] = 0; + z[0] = OpType::postProcess(sPartials[0], length, extraZ); + // printf("Z: [%f]\n", (float) z[0]); } + } } ////////////////////////////////////////////////////////////////////////// template -template -__device__ void Reduce3::transformAll( void const* vx, Nd4jLong const* xShapeInfo, - void const* vy, Nd4jLong const* yShapeInfo, - void *extraParams, - void *vz, Nd4jLong const* zShapeInfo, - int *dimension, int dimensionLength, - int postProcessOrNot, - int *allocationPointer, - Nd4jLong const* xTadShapeInfo, Nd4jLong const* xOffsets, - Nd4jLong const* yTadShapeInfo, Nd4jLong const* yOffsets) { - - auto dx = reinterpret_cast(vx); - auto dy = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); - - // initialize partials first - __shared__ Z* sPartials; - if(threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - sPartials = reinterpret_cast(shmem); +template +__device__ void Reduce3::transformAll( + void const* vx, Nd4jLong const* xShapeInfo, void const* vy, + Nd4jLong const* yShapeInfo, void* extraParams, void* vz, + Nd4jLong const* zShapeInfo, int* dimension, int dimensionLength, + int postProcessOrNot, int* allocationPointer, Nd4jLong const* xTadShapeInfo, + Nd4jLong const* xOffsets, Nd4jLong const* yTadShapeInfo, + Nd4jLong const* yOffsets) { + auto dx = reinterpret_cast(vx); + auto dy = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + + // initialize partials first + __shared__ Z* sPartials; + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sPartials = reinterpret_cast(shmem); + } + __syncthreads(); + + Z startingVal = OpType::startingValue(dx); + sPartials[threadIdx.x] = startingVal; + X* tempX = reinterpret_cast(sPartials) + blockDim.x; + + const int maxBlock = blockDim.x; + + __shared__ Z extraZ[OpType::extraParamsLen > 0 ? OpType::extraParamsLen : 1]; + + __shared__ int xTadLength; + __shared__ int yTadLength; + + __shared__ int xTads; + __shared__ int yTads; + + // reading initial data + if (threadIdx.x == 0) { + xTadLength = shape::length(xTadShapeInfo); + yTadLength = shape::length(yTadShapeInfo); + + xTads = shape::length(xShapeInfo) / xTadLength; + yTads = shape::length(yShapeInfo) / yTadLength; + } + __syncthreads(); + + int limit = xTadLength / maxBlock; + if (xTadLength % maxBlock > 0) limit++; + + for (int r = blockIdx.x; r < xTads; r += blockDim.x * gridDim.x) { + auto x = dx + xOffsets[r]; + + if (threadIdx.x < xTadLength && threadIdx.x < maxBlock) { + auto x0 = shape::getIndexOffset(threadIdx.x, xTadShapeInfo); + tempX[threadIdx.x] = x[x0]; } __syncthreads(); - Z startingVal = OpType::startingValue(dx); - sPartials[threadIdx.x] = startingVal; - X *tempX = reinterpret_cast(sPartials) + blockDim.x; - - const int maxBlock = blockDim.x; - - __shared__ Z extraZ[OpType::extraParamsLen > 0 ? OpType::extraParamsLen : 1]; - - __shared__ int xTadLength; - __shared__ int yTadLength; - - __shared__ int xTads; - __shared__ int yTads; - - //reading initial data - if (threadIdx.x == 0) { - xTadLength = shape::length(xTadShapeInfo); - yTadLength = shape::length(yTadShapeInfo); - - xTads = shape::length(xShapeInfo) / xTadLength; - yTads = shape::length(yShapeInfo) / yTadLength; - } - __syncthreads(); - - int limit = xTadLength / maxBlock; - if (xTadLength % maxBlock > 0) - limit++; - - for (int r = blockIdx.x; r < xTads; r += blockDim.x * gridDim.x) { - - auto x = dx + xOffsets[r]; - - if (threadIdx.x < xTadLength && threadIdx.x < maxBlock) { - auto x0 = shape::getIndexOffset(threadIdx.x, xTadShapeInfo); + for (int g = 0; g < yTads; g++) { + auto y = dy + yOffsets[g]; + int ri = (r * yTads) + g; + + sPartials[threadIdx.x] = startingVal; + if (OpType::extraParamsLen > 0 && threadIdx.x < OpType::extraParamsLen) + extraZ[threadIdx.x] = startingVal; + __syncthreads(); + + // we might have data too large for single cache block, rendering cache + // useless though :( + for (int t = 0; t < limit; t++) { + // we reset tempX IF we have >1 tiles + if (t >= 1 || (limit > 1 && g > 0)) + if (threadIdx.x + (t * maxBlock) < xTadLength) { + auto x0 = shape::getIndexOffset(threadIdx.x + (t * maxBlock), + xTadShapeInfo); tempX[threadIdx.x] = x[x0]; + } + + for (int f = threadIdx.x + (t * maxBlock); + f < xTadLength && f < threadIdx.x + ((t + 1) * maxBlock); + f += blockDim.x * gridDim.x) { + auto y0 = shape::getIndexOffset(f, yTadShapeInfo); + sPartials[threadIdx.x] = OpType::update( + sPartials[threadIdx.x], + OpType::opAtomic(tempX[threadIdx.x], y[y0], extraZ), extraZ); } - __syncthreads(); - - for (int g = 0; g < yTads; g++) { - - auto y = dy + yOffsets[g]; - int ri = (r * yTads) + g; - - sPartials[threadIdx.x] = startingVal; - if (OpType::extraParamsLen > 0 && threadIdx.x < OpType::extraParamsLen) - extraZ[threadIdx.x] = startingVal; - __syncthreads(); - // we might have data too large for single cache block, rendering cache useless though :( - for (int t = 0; t < limit; t++) { - - // we reset tempX IF we have >1 tiles - if (t >= 1 || (limit > 1 && g > 0)) - if (threadIdx.x + (t * maxBlock) < xTadLength) { - auto x0 = shape::getIndexOffset(threadIdx.x + (t * maxBlock), xTadShapeInfo); - tempX[threadIdx.x] = x[x0]; - } - - for (int f = threadIdx.x + (t * maxBlock); f < xTadLength && f < threadIdx.x + ((t + 1) * maxBlock); f += blockDim.x * gridDim.x) { - auto y0 = shape::getIndexOffset(f, yTadShapeInfo); - sPartials[threadIdx.x] = OpType::update(sPartials[threadIdx.x], OpType::opAtomic(tempX[threadIdx.x], y[y0], extraZ), extraZ); - } - - // we MUST step through this block altogether - __syncthreads(); - } + // we MUST step through this block altogether + __syncthreads(); + } - aggregatePartials(reinterpret_cast(sPartials), threadIdx.x, sd::math::nd4j_min(blockDim.x, xTadLength), extraZ); - __syncthreads(); + aggregatePartials(reinterpret_cast(sPartials), threadIdx.x, + sd::math::nd4j_min(blockDim.x, xTadLength), + extraZ); + __syncthreads(); - if (threadIdx.x == 0) { - z[ri] = OpType::postProcess(sPartials[threadIdx.x], xTadLength, extraZ); - } + if (threadIdx.x == 0) { + z[ri] = OpType::postProcess(sPartials[threadIdx.x], xTadLength, extraZ); + } - __syncthreads(); - } - } + __syncthreads(); + } + } } ////////////////////////////////////////////////////////////////////////// template -template -__device__ void Reduce3::transform(void const* vx, Nd4jLong const* xShapeInfo, - void const* vy, Nd4jLong const* yShapeInfo, - void *extraParams, - void *vz, Nd4jLong const* zShapeInfo, - int *dimension, int dimensionLength, - int postProcessOrNot, - int *allocationPointer, - Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, - Nd4jLong const* yTadOnlyShapeInfo, Nd4jLong const* yTadOffsets) { - - // FIXME - if(shape::isScalar(zShapeInfo)) - return; - - if (yTadOnlyShapeInfo == nullptr) { - yTadOnlyShapeInfo = yShapeInfo; // execReduce3TAD case - } - - auto x = reinterpret_cast(vx); - auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); - - Z startingVal = OpType::startingValue(x); - - __shared__ Z extraZ[OpType::extraParamsLen > 0 ? OpType::extraParamsLen : 1]; - - __shared__ Z* sPartials; - __shared__ int tadLen; - __shared__ Nd4jLong zLen; - __shared__ Nd4jLong xTadEws; - __shared__ Nd4jLong yTadEws; - __shared__ Nd4jLong yTadNum; - __shared__ char xTadOrder; - __shared__ char yTadOrder; - - if(threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - sPartials = reinterpret_cast(shmem); - - tadLen = shape::length(tadOnlyShapeInfo); - zLen = shape::length(zShapeInfo); - xTadEws = shape::elementWiseStride(tadOnlyShapeInfo); - yTadEws = shape::elementWiseStride(yTadOnlyShapeInfo); - yTadNum = shape::length(yShapeInfo) / tadLen; - xTadOrder = shape::order(tadOnlyShapeInfo); - yTadOrder = shape::order(yTadOnlyShapeInfo); - } - __syncthreads(); - - sPartials[threadIdx.x] = startingVal; - - if(xTadEws >= 1 && yTadEws >= 1 && xTadOrder == yTadOrder) { - - for(int i = blockIdx.x; i < zLen; i+= gridDim.x) { - - Nd4jLong xOffset = tadOffsets[i]; - Nd4jLong yOffset = yTadNum == 1 ? 0 : yTadOffsets[i]; - - if (OpType::extraParamsLen > 0 && threadIdx.x < OpType::extraParamsLen) - extraZ[threadIdx.x] = startingVal; - - __syncthreads(); - - for (int j = threadIdx.x; j < tadLen; j += blockDim.x) { - - Nd4jLong xOffset2 = xOffset + j*xTadEws; - Nd4jLong yOffset2 = yOffset + j*yTadEws; - sPartials[threadIdx.x] = j < blockDim.x ? OpType::opAtomic(x[xOffset2], y[yOffset2], extraZ) : OpType::update(sPartials[threadIdx.x], OpType::opAtomic(x[xOffset2], y[yOffset2], extraZ), extraZ); - } - - __syncthreads(); - aggregatePartials(reinterpret_cast(sPartials), threadIdx.x, sd::math::nd4j_min(blockDim.x, tadLen), extraZ); - __syncthreads(); - - if (threadIdx.x == 0) - z[i] = OpType::postProcess(sPartials[threadIdx.x], tadLen, extraZ); - - __syncthreads(); - } +template +__device__ void Reduce3::transform( + void const* vx, Nd4jLong const* xShapeInfo, void const* vy, + Nd4jLong const* yShapeInfo, void* extraParams, void* vz, + Nd4jLong const* zShapeInfo, int* dimension, int dimensionLength, + int postProcessOrNot, int* allocationPointer, + Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, + Nd4jLong const* yTadOnlyShapeInfo, Nd4jLong const* yTadOffsets) { + // FIXME + if (shape::isScalar(zShapeInfo)) return; + + if (yTadOnlyShapeInfo == nullptr) { + yTadOnlyShapeInfo = yShapeInfo; // execReduce3TAD case + } + + auto x = reinterpret_cast(vx); + auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + + Z startingVal = OpType::startingValue(x); + + __shared__ Z extraZ[OpType::extraParamsLen > 0 ? OpType::extraParamsLen : 1]; + + __shared__ Z* sPartials; + __shared__ int tadLen; + __shared__ Nd4jLong zLen; + __shared__ Nd4jLong xTadEws; + __shared__ Nd4jLong yTadEws; + __shared__ Nd4jLong yTadNum; + __shared__ char xTadOrder; + __shared__ char yTadOrder; + + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sPartials = reinterpret_cast(shmem); + + tadLen = shape::length(tadOnlyShapeInfo); + zLen = shape::length(zShapeInfo); + xTadEws = shape::elementWiseStride(tadOnlyShapeInfo); + yTadEws = shape::elementWiseStride(yTadOnlyShapeInfo); + yTadNum = shape::length(yShapeInfo) / tadLen; + xTadOrder = shape::order(tadOnlyShapeInfo); + yTadOrder = shape::order(yTadOnlyShapeInfo); + } + __syncthreads(); + + sPartials[threadIdx.x] = startingVal; + + if (xTadEws >= 1 && yTadEws >= 1 && xTadOrder == yTadOrder) { + for (int i = blockIdx.x; i < zLen; i += gridDim.x) { + Nd4jLong xOffset = tadOffsets[i]; + Nd4jLong yOffset = yTadNum == 1 ? 0 : yTadOffsets[i]; + + if (OpType::extraParamsLen > 0 && threadIdx.x < OpType::extraParamsLen) + extraZ[threadIdx.x] = startingVal; + + __syncthreads(); + + for (int j = threadIdx.x; j < tadLen; j += blockDim.x) { + Nd4jLong xOffset2 = xOffset + j * xTadEws; + Nd4jLong yOffset2 = yOffset + j * yTadEws; + sPartials[threadIdx.x] = + j < blockDim.x + ? OpType::opAtomic(x[xOffset2], y[yOffset2], extraZ) + : OpType::update( + sPartials[threadIdx.x], + OpType::opAtomic(x[xOffset2], y[yOffset2], extraZ), + extraZ); + } + + __syncthreads(); + aggregatePartials(reinterpret_cast(sPartials), threadIdx.x, + sd::math::nd4j_min(blockDim.x, tadLen), + extraZ); + __syncthreads(); + + if (threadIdx.x == 0) + z[i] = OpType::postProcess(sPartials[threadIdx.x], tadLen, extraZ); + + __syncthreads(); } - else { - - for(int i = blockIdx.x; i < zLen; i += gridDim.x) { - - Nd4jLong xOffset = tadOffsets[i]; - Nd4jLong yOffset = yTadNum == 1 ? 0 : yTadOffsets[i]; - - if (OpType::extraParamsLen > 0 && threadIdx.x < OpType::extraParamsLen) - extraZ[threadIdx.x] = startingVal; - - __syncthreads(); - - for (int j = threadIdx.x; j < tadLen; j += blockDim.x) { - - Nd4jLong xOffset2 = xOffset + shape::getIndexOffset(j, tadOnlyShapeInfo); - Nd4jLong yOffset2 = yOffset + shape::getIndexOffset(j, yTadOnlyShapeInfo); - sPartials[threadIdx.x] = j < blockDim.x ? OpType::opAtomic(x[xOffset2], y[yOffset2], extraZ) : OpType::update(sPartials[threadIdx.x], OpType::opAtomic(x[xOffset2], y[yOffset2], extraZ), extraZ); - - } - - __syncthreads(); - aggregatePartials(reinterpret_cast(sPartials), threadIdx.x, sd::math::nd4j_min(blockDim.x, tadLen), extraZ); - __syncthreads(); - - if (threadIdx.x == 0) - z[i] = OpType::postProcess(sPartials[threadIdx.x], tadLen, extraZ); - - __syncthreads(); - } + } else { + for (int i = blockIdx.x; i < zLen; i += gridDim.x) { + Nd4jLong xOffset = tadOffsets[i]; + Nd4jLong yOffset = yTadNum == 1 ? 0 : yTadOffsets[i]; + + if (OpType::extraParamsLen > 0 && threadIdx.x < OpType::extraParamsLen) + extraZ[threadIdx.x] = startingVal; + + __syncthreads(); + + for (int j = threadIdx.x; j < tadLen; j += blockDim.x) { + Nd4jLong xOffset2 = + xOffset + shape::getIndexOffset(j, tadOnlyShapeInfo); + Nd4jLong yOffset2 = + yOffset + shape::getIndexOffset(j, yTadOnlyShapeInfo); + sPartials[threadIdx.x] = + j < blockDim.x + ? OpType::opAtomic(x[xOffset2], y[yOffset2], extraZ) + : OpType::update( + sPartials[threadIdx.x], + OpType::opAtomic(x[xOffset2], y[yOffset2], extraZ), + extraZ); + } + + __syncthreads(); + aggregatePartials(reinterpret_cast(sPartials), threadIdx.x, + sd::math::nd4j_min(blockDim.x, tadLen), + extraZ); + __syncthreads(); + + if (threadIdx.x == 0) + z[i] = OpType::postProcess(sPartials[threadIdx.x], tadLen, extraZ); + + __syncthreads(); } + } } ////////////////////////////////////////////////////////////////////////// template -__device__ void Reduce3::execCuda(const int opNum, - void const* vx, Nd4jLong const* xShapeInfo, - void const* vy, Nd4jLong const* yShapeInfo, - void *extraParams, - void *vz, Nd4jLong const* zShapeInfo, - int *dimension, int dimensionLength, - int postProcessOrNot, - int *allocationPointer, - Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, - Nd4jLong const* yTadOnlyShapeInfo, Nd4jLong const* yTadOffsets) { - - DISPATCH_BY_OPNUM_TT(transform, PARAMS(vx, xShapeInfo, vy, yShapeInfo, extraParams, vz, zShapeInfo, dimension, dimensionLength, postProcessOrNot, allocationPointer, tadOnlyShapeInfo, tadOffsets, yTadOnlyShapeInfo, yTadOffsets), REDUCE3_OPS); +__device__ void Reduce3::execCuda( + const int opNum, void const* vx, Nd4jLong const* xShapeInfo, void const* vy, + Nd4jLong const* yShapeInfo, void* extraParams, void* vz, + Nd4jLong const* zShapeInfo, int* dimension, int dimensionLength, + int postProcessOrNot, int* allocationPointer, + Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, + Nd4jLong const* yTadOnlyShapeInfo, Nd4jLong const* yTadOffsets) { + DISPATCH_BY_OPNUM_TT( + transform, + PARAMS(vx, xShapeInfo, vy, yShapeInfo, extraParams, vz, zShapeInfo, + dimension, dimensionLength, postProcessOrNot, allocationPointer, + tadOnlyShapeInfo, tadOffsets, yTadOnlyShapeInfo, yTadOffsets), + REDUCE3_OPS); } - - ////////////////////////////////////////////////////////////////////////// template -__device__ void Reduce3::execAllCuda( const int opNum, - void const* vx, Nd4jLong const* xShapeInfo, - void const* vy, Nd4jLong const* yShapeInfo, - void *extraParams, - void *vz, Nd4jLong const* zShapeInfo, - int *dimension, int dimensionLength, - int postProcessOrNot, - int *allocationPointer, - Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, - Nd4jLong const* yTadOnlyShapeInfo, Nd4jLong const* yTadOffsets) { - - DISPATCH_BY_OPNUM_TT(transformAll, PARAMS(vx, xShapeInfo, vy, yShapeInfo, extraParams, vz, zShapeInfo, dimension, dimensionLength, postProcessOrNot, allocationPointer, tadOnlyShapeInfo, tadOffsets, yTadOnlyShapeInfo, yTadOffsets), REDUCE3_OPS); +__device__ void Reduce3::execAllCuda( + const int opNum, void const* vx, Nd4jLong const* xShapeInfo, void const* vy, + Nd4jLong const* yShapeInfo, void* extraParams, void* vz, + Nd4jLong const* zShapeInfo, int* dimension, int dimensionLength, + int postProcessOrNot, int* allocationPointer, + Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, + Nd4jLong const* yTadOnlyShapeInfo, Nd4jLong const* yTadOffsets) { + DISPATCH_BY_OPNUM_TT( + transformAll, + PARAMS(vx, xShapeInfo, vy, yShapeInfo, extraParams, vz, zShapeInfo, + dimension, dimensionLength, postProcessOrNot, allocationPointer, + tadOnlyShapeInfo, tadOffsets, yTadOnlyShapeInfo, yTadOffsets), + REDUCE3_OPS); } - ////////////////////////////////////////////////////////////////////////// template -__device__ void Reduce3::execScalarCuda(const int opNum, - void const* vx, Nd4jLong const* xShapeInfo, - void const* vy, Nd4jLong const* yShapeInfo, - void *extraParams, - void *vz, Nd4jLong const* zShapeInfo, - int * allocationPointer, void *reductionBuffer, - Nd4jLong const* tadOnlyShapeInfo) { - - DISPATCH_BY_OPNUM_TT(execScalarCuda, PARAMS(vx, xShapeInfo, vy, yShapeInfo, extraParams, vz, zShapeInfo, allocationPointer, reductionBuffer, tadOnlyShapeInfo), REDUCE3_OPS); +__device__ void Reduce3::execScalarCuda( + const int opNum, void const* vx, Nd4jLong const* xShapeInfo, void const* vy, + Nd4jLong const* yShapeInfo, void* extraParams, void* vz, + Nd4jLong const* zShapeInfo, int* allocationPointer, void* reductionBuffer, + Nd4jLong const* tadOnlyShapeInfo) { + DISPATCH_BY_OPNUM_TT( + execScalarCuda, + PARAMS(vx, xShapeInfo, vy, yShapeInfo, extraParams, vz, zShapeInfo, + allocationPointer, reductionBuffer, tadOnlyShapeInfo), + REDUCE3_OPS); } - //////////////////////////////////////////////////////////////////////// template -__host__ void Reduce3::exec(dim3 launchDims, cudaStream_t *stream, - int opNum, - void const* vx, Nd4jLong const* xShapeInfo, - void const* vy, Nd4jLong const* yShapeInfo, - void *extraParams, - void *vz, Nd4jLong const* zShapeInfo, - int *dimension, int dimensionLength, - int postProcessOrNot, - int *allocationPointer, - Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, - Nd4jLong const* yTadOnlyShapeInfo, Nd4jLong const* yTadOffsets) { - - execGeneric<<>>(opNum, vx, xShapeInfo, vy, yShapeInfo, extraParams, vz, zShapeInfo, dimension, dimensionLength, postProcessOrNot, allocationPointer, tadOnlyShapeInfo, tadOffsets, yTadOnlyShapeInfo, yTadOffsets); - sd::DebugHelper::checkErrorCode(stream, "reduce3exec(...) failed"); +__host__ void Reduce3::exec( + dim3 launchDims, cudaStream_t* stream, int opNum, void const* vx, + Nd4jLong const* xShapeInfo, void const* vy, Nd4jLong const* yShapeInfo, + void* extraParams, void* vz, Nd4jLong const* zShapeInfo, int* dimension, + int dimensionLength, int postProcessOrNot, int* allocationPointer, + Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, + Nd4jLong const* yTadOnlyShapeInfo, Nd4jLong const* yTadOffsets) { + execGeneric<<>>( + opNum, vx, xShapeInfo, vy, yShapeInfo, extraParams, vz, zShapeInfo, + dimension, dimensionLength, postProcessOrNot, allocationPointer, + tadOnlyShapeInfo, tadOffsets, yTadOnlyShapeInfo, yTadOffsets); + sd::DebugHelper::checkErrorCode(stream, "reduce3exec(...) failed"); } //////////////////////////////////////////////////////////////////////// - template - __host__ void Reduce3::execAll(dim3 launchDims, cudaStream_t *stream, - int opNum, - void const* vx, Nd4jLong const* xShapeInfo, - void const* vy, Nd4jLong const* yShapeInfo, - void *extraParams, - void *vz, Nd4jLong const* zShapeInfo, - int *dimension, int dimensionLength, - int postProcessOrNot, - int *allocationPointer, - Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, - Nd4jLong const* yTadOnlyShapeInfo, Nd4jLong const* yTadOffsets) { - - execAllGeneric<<>>(opNum, vx, xShapeInfo, vy, yShapeInfo, extraParams, vz, zShapeInfo, dimension, dimensionLength, postProcessOrNot, allocationPointer, tadOnlyShapeInfo, tadOffsets, yTadOnlyShapeInfo, yTadOffsets); - sd::DebugHelper::checkErrorCode(stream, "execAllGeneric(...) failed"); - } +template +__host__ void Reduce3::execAll( + dim3 launchDims, cudaStream_t* stream, int opNum, void const* vx, + Nd4jLong const* xShapeInfo, void const* vy, Nd4jLong const* yShapeInfo, + void* extraParams, void* vz, Nd4jLong const* zShapeInfo, int* dimension, + int dimensionLength, int postProcessOrNot, int* allocationPointer, + Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, + Nd4jLong const* yTadOnlyShapeInfo, Nd4jLong const* yTadOffsets) { + execAllGeneric<<>>( + opNum, vx, xShapeInfo, vy, yShapeInfo, extraParams, vz, zShapeInfo, + dimension, dimensionLength, postProcessOrNot, allocationPointer, + tadOnlyShapeInfo, tadOffsets, yTadOnlyShapeInfo, yTadOffsets); + sd::DebugHelper::checkErrorCode(stream, "execAllGeneric(...) failed"); +} //////////////////////////////////////////////////////////////////////// template -__host__ void Reduce3::execScalar(dim3 launchDims, cudaStream_t *stream, - int opNum, - void const* vx, Nd4jLong const* xShapeInfo, - void const* vy, Nd4jLong const* yShapeInfo, - void *extraParams, - void *vz, Nd4jLong const* zShapeInfo, - int* allocationPointer, - void *reductionBuffer, - Nd4jLong const* tadOnlyShapeInfo) { - - execScalarGeneric<<>>(opNum, vx, xShapeInfo, vy, yShapeInfo, extraParams, vz, zShapeInfo, allocationPointer, reductionBuffer, tadOnlyShapeInfo); - sd::DebugHelper::checkErrorCode(stream, "execScalarGeneric(...) failed"); +__host__ void Reduce3::execScalar( + dim3 launchDims, cudaStream_t* stream, int opNum, void const* vx, + Nd4jLong const* xShapeInfo, void const* vy, Nd4jLong const* yShapeInfo, + void* extraParams, void* vz, Nd4jLong const* zShapeInfo, + int* allocationPointer, void* reductionBuffer, + Nd4jLong const* tadOnlyShapeInfo) { + execScalarGeneric + <<>>( + opNum, vx, xShapeInfo, vy, yShapeInfo, extraParams, vz, zShapeInfo, + allocationPointer, reductionBuffer, tadOnlyShapeInfo); + sd::DebugHelper::checkErrorCode(stream, "execScalarGeneric(...) failed"); } +// BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES, +// FLOAT_TYPES); - - - - //BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT Reduce3, , LIBND4J_TYPES, FLOAT_TYPES); - -} -} \ No newline at end of file +} // namespace reduce3 +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/reduce3.cu b/libnd4j/include/loops/cuda/reduce3.cu index c1d63e8dd245..cf8f8c3481bd 100644 --- a/libnd4j/include/loops/cuda/reduce3.cu +++ b/libnd4j/include/loops/cuda/reduce3.cu @@ -18,16 +18,12 @@ // @author raver119@gmail.com // - -#include -#include #include -#include +#include #include +#include +#include namespace functions { - namespace reduce3 { - - - } -} \ No newline at end of file +namespace reduce3 {} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/scalar.chpp b/libnd4j/include/loops/cuda/scalar.chpp index 93b76f910d34..6dd31f9266f4 100644 --- a/libnd4j/include/loops/cuda/scalar.chpp +++ b/libnd4j/include/loops/cuda/scalar.chpp @@ -21,152 +21,184 @@ #ifndef SCALAR_CU #define SCALAR_CU -#include "loops/scalar.h" #include #include -#include #include +#include #include +#include "loops/scalar.h" + using namespace simdOps; //////////////////////////////////////////////////////////////////////////////// template -__global__ static void scalarSimpleShaped(void const* vx, void const* vscalar, Nd4jLong const* xShapeInfo, void *vparams, void *vz, Nd4jLong const* zShapeInfo, int *allocationBuffer) { - - auto scalar = reinterpret_cast(vscalar)[0]; - auto x = reinterpret_cast(vx); - auto params = reinterpret_cast(vparams); - auto z = reinterpret_cast(vz); - - int totalThreads = gridDim.x * blockDim.x; - int tid = blockIdx.x * blockDim.x + threadIdx.x; - - __shared__ Nd4jLong length; - if(threadIdx.x == 0) { - length = shape::length(xShapeInfo); +__global__ static void scalarSimpleShaped(void const* vx, void const* vscalar, + Nd4jLong const* xShapeInfo, + void* vparams, void* vz, + Nd4jLong const* zShapeInfo, + int* allocationBuffer) { + auto scalar = reinterpret_cast(vscalar)[0]; + auto x = reinterpret_cast(vx); + auto params = reinterpret_cast(vparams); + auto z = reinterpret_cast(vz); + + int totalThreads = gridDim.x * blockDim.x; + int tid = blockIdx.x * blockDim.x + threadIdx.x; + + __shared__ Nd4jLong length; + if (threadIdx.x == 0) { + length = shape::length(xShapeInfo); + } + __syncthreads(); + + auto xEws = shape::elementWiseStride(xShapeInfo); + auto zEws = shape::elementWiseStride(zShapeInfo); + + auto xOrder = shape::order(xShapeInfo); + auto zOrder = shape::order(zShapeInfo); + + if (xEws >= 1 && zEws >= 1 && xOrder == zOrder) { + for (Nd4jLong i = tid; i < length; i += totalThreads) { + z[i * zEws] = OpType::op(x[i * xEws], scalar, params); } - __syncthreads(); - - auto xEws = shape::elementWiseStride(xShapeInfo); - auto zEws = shape::elementWiseStride(zShapeInfo); - - auto xOrder = shape::order(xShapeInfo); - auto zOrder = shape::order(zShapeInfo); - - - if (xEws >= 1 && zEws >= 1 && xOrder == zOrder) { - for (Nd4jLong i = tid; i < length; i += totalThreads) { - z[i * zEws] = OpType::op(x[i * xEws], scalar, params); - } - } else { - for (Nd4jLong i = tid; i < length; i += totalThreads) { - z[shape::getIndexOffset(i, zShapeInfo)] = OpType::op(x[shape::getIndexOffset(i, xShapeInfo)], scalar, params); - } + } else { + for (Nd4jLong i = tid; i < length; i += totalThreads) { + z[shape::getIndexOffset(i, zShapeInfo)] = + OpType::op(x[shape::getIndexOffset(i, xShapeInfo)], scalar, params); } - + } } //////////////////////////////////////////////////////////////////////////////// template -__global__ static void scalarAlongDimension(void const* vx, Nd4jLong const* xShapeInfo, - void* vextraParams, - void* vz, Nd4jLong const* zShapeInfo, - void const* vscalars, - int *dimension, int dimensionLength, - Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, - Nd4jLong const* tadShapeInfoZ, Nd4jLong const* tadOffsetsZ) { - - auto x = reinterpret_cast(vx); - auto extraParams = reinterpret_cast(vextraParams); - auto z = reinterpret_cast(vz); - auto scalars = reinterpret_cast(vscalars); - - if (tadShapeInfoZ == nullptr) { - tadShapeInfoZ = tadShapeInfo; - tadOffsetsZ = tadOffsets; +__global__ static void scalarAlongDimension( + void const* vx, Nd4jLong const* xShapeInfo, void* vextraParams, void* vz, + Nd4jLong const* zShapeInfo, void const* vscalars, int* dimension, + int dimensionLength, Nd4jLong const* tadShapeInfo, + Nd4jLong const* tadOffsets, Nd4jLong const* tadShapeInfoZ, + Nd4jLong const* tadOffsetsZ) { + auto x = reinterpret_cast(vx); + auto extraParams = reinterpret_cast(vextraParams); + auto z = reinterpret_cast(vz); + auto scalars = reinterpret_cast(vscalars); + + if (tadShapeInfoZ == nullptr) { + tadShapeInfoZ = tadShapeInfo; + tadOffsetsZ = tadOffsets; + } + + // tad preparation + auto tadEws = shape::elementWiseStride(tadShapeInfo); + auto zEws = shape::elementWiseStride(tadShapeInfoZ); + auto tadLength = shape::length(tadShapeInfo); // shape::tadLength(xShapeInfo, + // dimension, dimensionLength); + auto numTads = shape::length(xShapeInfo) / tadLength; + + if (tadEws > 0 && zEws > 0 && + shape::order(tadShapeInfo) == shape::order(zShapeInfo)) { + // main loop, rolling over tads + for (int r = blockIdx.x; r < numTads; r += gridDim.x) { + Z* oZ = z + tadOffsetsZ[r]; + auto oX = x + tadOffsets[r]; + + auto s = scalars[r]; + + for (int f = threadIdx.x; f < tadLength; f += blockDim.x) + oZ[f * zEws] = OpType::op(oX[f * tadEws], s, extraParams); } + } else { + // main loop, rolling over tads + for (int r = blockIdx.x; r < numTads; r += gridDim.x) { + Z* oZ = z + tadOffsetsZ[r]; + auto oX = x + tadOffsets[r]; - // tad preparation - auto tadEws = shape::elementWiseStride(tadShapeInfo); - auto zEws = shape::elementWiseStride(tadShapeInfoZ); - auto tadLength = shape::length(tadShapeInfo);//shape::tadLength(xShapeInfo, dimension, dimensionLength); - auto numTads =shape::length(xShapeInfo) / tadLength; - - if (tadEws > 0 && zEws > 0 && shape::order(tadShapeInfo) == shape::order(zShapeInfo)) { - - // main loop, rolling over tads - for (int r = blockIdx.x; r < numTads; r += gridDim.x) { - Z *oZ = z + tadOffsetsZ[r]; - auto oX = x + tadOffsets[r]; + auto s = scalars[r]; - auto s = scalars[r]; - - for (int f = threadIdx.x; f < tadLength; f += blockDim.x) - oZ[f * zEws] = OpType::op(oX[f * tadEws], s, extraParams); - } - } else { - // main loop, rolling over tads - for (int r = blockIdx.x; r < numTads; r += gridDim.x) { - Z *oZ = z + tadOffsetsZ[r]; - auto oX = x + tadOffsets[r]; - - auto s = scalars[r]; - - for (int f = threadIdx.x; f < tadLength; f += blockDim.x) - oZ[shape::getIndexOffset(f, tadShapeInfoZ)] = OpType::op(oX[shape::getIndexOffset(f, tadShapeInfo)], s, extraParams); - } + for (int f = threadIdx.x; f < tadLength; f += blockDim.x) + oZ[shape::getIndexOffset(f, tadShapeInfoZ)] = OpType::op( + oX[shape::getIndexOffset(f, tadShapeInfo)], s, extraParams); } + } } - namespace functions { -namespace scalar { +namespace scalar { //////////////////////////////////////////////////////////////////////////////// -template -template -void _CUDA_H ScalarTransform::intermediateShaped(dim3& launchDims, cudaStream_t *stream, void const* vx, Nd4jLong const* xShapeInfo, Nd4jLong const* hxShapeInfo, void *vz, Nd4jLong const* zShapeInfo, Nd4jLong const* hzShapeInfo, void const* vscalar, void *vextraParams, int *allocPointer){ - - auto xEws = shape::elementWiseStride(hxShapeInfo); - auto xOrder = shape::order(hxShapeInfo); - - auto zEws = shape::elementWiseStride(hzShapeInfo); - auto zOrder = shape::order(hzShapeInfo); - - auto length = shape::length(hxShapeInfo); - - scalarSimpleShaped<<>>(vx, vscalar, xShapeInfo, vextraParams, vz, zShapeInfo, allocPointer); - sd::DebugHelper::checkErrorCode(stream, "scalarSimpleShapedA(...) failed"); +template +template +void _CUDA_H ScalarTransform::intermediateShaped( + dim3& launchDims, cudaStream_t* stream, void const* vx, + Nd4jLong const* xShapeInfo, Nd4jLong const* hxShapeInfo, void* vz, + Nd4jLong const* zShapeInfo, Nd4jLong const* hzShapeInfo, + void const* vscalar, void* vextraParams, int* allocPointer) { + auto xEws = shape::elementWiseStride(hxShapeInfo); + auto xOrder = shape::order(hxShapeInfo); + + auto zEws = shape::elementWiseStride(hzShapeInfo); + auto zOrder = shape::order(hzShapeInfo); + + auto length = shape::length(hxShapeInfo); + + scalarSimpleShaped + <<>>( + vx, vscalar, xShapeInfo, vextraParams, vz, zShapeInfo, allocPointer); + sd::DebugHelper::checkErrorCode(stream, "scalarSimpleShapedA(...) failed"); } //////////////////////////////////////////////////////////////////////////////// -template -template -void _CUDA_H ScalarTransform::intermediateAlongDimension(dim3& launchDims, cudaStream_t *stream, void const* x, Nd4jLong const* xShapeInfo, void *z, Nd4jLong const* zShapeInfo, void const* scalars, void *extraParams, int *dimension, int dimensionLength, Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadShapeInfoZ, Nd4jLong const* tadOffsetsZ) { - scalarAlongDimension<<>>(x, xShapeInfo, extraParams, z, zShapeInfo, scalars, dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ); - sd::DebugHelper::checkErrorCode(stream, "scalarAlongDimA(...) failed"); +template +template +void _CUDA_H ScalarTransform::intermediateAlongDimension( + dim3& launchDims, cudaStream_t* stream, void const* x, + Nd4jLong const* xShapeInfo, void* z, Nd4jLong const* zShapeInfo, + void const* scalars, void* extraParams, int* dimension, int dimensionLength, + Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, + Nd4jLong const* tadShapeInfoZ, Nd4jLong const* tadOffsetsZ) { + scalarAlongDimension + <<>>( + x, xShapeInfo, extraParams, z, zShapeInfo, scalars, dimension, + dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, + tadOffsetsZ); + sd::DebugHelper::checkErrorCode(stream, "scalarAlongDimA(...) failed"); } //////////////////////////////////////////////////////////////////////////////// -template -void ScalarTransform::executeCudaShaped(dim3& launchDims, cudaStream_t *stream, int opNum, void const* vx, Nd4jLong const* xShapeInfo, Nd4jLong const* hxShapeInfo, void *vz, Nd4jLong const* zShapeInfo, Nd4jLong const* hzShapeInfo, void const* vscalar, void *vextraParams) { - - if (sd::Environment::getInstance().isDebugAndVerbose()) - printf("H14 opNum:[%i]\n", opNum); - - DISPATCH_BY_OPNUM_TTT(intermediateShaped, PARAMS(launchDims, stream, vx, xShapeInfo, hxShapeInfo, vz, zShapeInfo, hzShapeInfo, vscalar, vextraParams, nullptr), SCALAR_OPS); +template +void ScalarTransform::executeCudaShaped( + dim3& launchDims, cudaStream_t* stream, int opNum, void const* vx, + Nd4jLong const* xShapeInfo, Nd4jLong const* hxShapeInfo, void* vz, + Nd4jLong const* zShapeInfo, Nd4jLong const* hzShapeInfo, + void const* vscalar, void* vextraParams) { + if (sd::Environment::getInstance().isDebugAndVerbose()) + printf("H14 opNum:[%i]\n", opNum); + + DISPATCH_BY_OPNUM_TTT( + intermediateShaped, + PARAMS(launchDims, stream, vx, xShapeInfo, hxShapeInfo, vz, zShapeInfo, + hzShapeInfo, vscalar, vextraParams, nullptr), + SCALAR_OPS); } //////////////////////////////////////////////////////////////////////////////// -template -void ScalarTransform::executeCudaAlongDimension(dim3& launchDims, cudaStream_t *stream, int opNum, void const* vx, Nd4jLong const* xShapeInfo, void *vz, Nd4jLong const* zShapeInfo, void const* vscalars, void *vextraParams, int *dimension, int dimensionLength, Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadShapeInfoZ, Nd4jLong const* tadOffsetsZ) { - DISPATCH_BY_OPNUM_TTT(intermediateAlongDimension, PARAMS(launchDims, stream, vx, xShapeInfo, vz, zShapeInfo, vscalars, vextraParams, dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ), SCALAR_OPS); +template +void ScalarTransform::executeCudaAlongDimension( + dim3& launchDims, cudaStream_t* stream, int opNum, void const* vx, + Nd4jLong const* xShapeInfo, void* vz, Nd4jLong const* zShapeInfo, + void const* vscalars, void* vextraParams, int* dimension, + int dimensionLength, Nd4jLong const* tadShapeInfo, + Nd4jLong const* tadOffsets, Nd4jLong const* tadShapeInfoZ, + Nd4jLong const* tadOffsetsZ) { + DISPATCH_BY_OPNUM_TTT( + intermediateAlongDimension, + PARAMS(launchDims, stream, vx, xShapeInfo, vz, zShapeInfo, vscalars, + vextraParams, dimension, dimensionLength, tadShapeInfo, tadOffsets, + tadShapeInfoZ, tadOffsetsZ), + SCALAR_OPS); } -} -} - - +} // namespace scalar +} // namespace functions -#endif // SCALAR_CU \ No newline at end of file +#endif // SCALAR_CU \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/scalar.cu b/libnd4j/include/loops/cuda/scalar.cu index 26c3e5cb871b..c470fb46289c 100644 --- a/libnd4j/include/loops/cuda/scalar.cu +++ b/libnd4j/include/loops/cuda/scalar.cu @@ -18,15 +18,14 @@ // @author raver119@gmail.com // -#include "loops/scalar.h" #include #include -#include #include +#include #include -namespace functions { - namespace scalar { +#include "loops/scalar.h" - } -} \ No newline at end of file +namespace functions { +namespace scalar {} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/scalar_bool.cu b/libnd4j/include/loops/cuda/scalar_bool.cu index 0976e60ad9f6..ebdbbb213ced 100644 --- a/libnd4j/include/loops/cuda/scalar_bool.cu +++ b/libnd4j/include/loops/cuda/scalar_bool.cu @@ -19,218 +19,222 @@ // @author raver119@gmail.com // -#include "../scalar_bool.h" #include #include #include "../legacy_ops.h" +#include "../scalar_bool.h" using namespace simdOps; //////////////////////////////////////////////////////////////////////// template -__global__ void scalarAlongDimension(void const* x, Nd4jLong const* xShapeInfo, - void *extraParams, - void *z, Nd4jLong const* zShapeInfo, - void const* scalars, - int *dimension, int dimensionLength, - Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, - Nd4jLong const* tadShapeInfoZ, Nd4jLong const* tadOffsetsZ) { - - functions::scalar::ScalarBoolTransform::template transformCuda(x, xShapeInfo, extraParams, z, zShapeInfo, scalars, dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ); +__global__ void scalarAlongDimension( + void const* x, Nd4jLong const* xShapeInfo, void* extraParams, void* z, + Nd4jLong const* zShapeInfo, void const* scalars, int* dimension, + int dimensionLength, Nd4jLong const* tadShapeInfo, + Nd4jLong const* tadOffsets, Nd4jLong const* tadShapeInfoZ, + Nd4jLong const* tadOffsetsZ) { + functions::scalar::ScalarBoolTransform::template transformCuda( + x, xShapeInfo, extraParams, z, zShapeInfo, scalars, dimension, + dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ); } - //////////////////////////////////////////////////////////////////////// template -__global__ void scalarSimpleShaped(void const* x, void const* y, Nd4jLong const* xShapeInfo, void *params, void *z, Nd4jLong const* zShapeInfo, int *allocationBuffer) { - - functions::scalar::ScalarBoolTransform::template transformCuda(y, x, xShapeInfo, params, z, zShapeInfo, allocationBuffer); +__global__ void scalarSimpleShaped(void const* x, void const* y, + Nd4jLong const* xShapeInfo, void* params, + void* z, Nd4jLong const* zShapeInfo, + int* allocationBuffer) { + functions::scalar::ScalarBoolTransform::template transformCuda( + y, x, xShapeInfo, params, z, zShapeInfo, allocationBuffer); } - - - - // *********************************************************************// // *********************************************************************// namespace functions { -namespace scalar { +namespace scalar { //////////////////////////////////////////////////////////////////////// -template -template -__device__ void ScalarBoolTransform::transformCuda(void const* vscalar, - void const* vy, Nd4jLong const* yShapeInfo, - void *vparams, - void *vz, Nd4jLong const* zShapeInfo, - int *allocationBuffer) { - auto scalar = reinterpret_cast(vscalar)[0]; - auto y = reinterpret_cast(vy); - auto params = reinterpret_cast(vparams); - auto z = reinterpret_cast(vz); - - auto yRank = shape::rank(yShapeInfo); - auto yEWS = shape::elementWiseStride(yShapeInfo); - auto yShape = shape::shapeOf(yShapeInfo); - auto yStride = shape::stride(yShapeInfo); - - auto zRank = shape::rank(zShapeInfo); - auto zEWS = shape::elementWiseStride(zShapeInfo); - auto zShape = shape::shapeOf(zShapeInfo); - auto zStride = shape::stride(zShapeInfo); - - int totalThreads = gridDim.x * blockDim.x; - int tid = blockIdx.x * blockDim.x + threadIdx.x; - - __shared__ int len; - if(threadIdx.x == 0) - len = shape::length(yShapeInfo); - __syncthreads(); - - if(yEWS >= 1 && zEWS >= 1 && shape::order(yShapeInfo) == shape::order(zShapeInfo)) { - transformCuda(len, vscalar, vy, yEWS, vparams, vz, zEWS, allocationBuffer); - } - else { - for (Nd4jLong i = tid; i < len; i+= totalThreads) - z[shape::getIndexOffset(i, zShapeInfo)] = OpType::op(y[shape::getIndexOffset(i, yShapeInfo)], scalar, params); - } +template +template +__device__ void ScalarBoolTransform::transformCuda( + void const* vscalar, void const* vy, Nd4jLong const* yShapeInfo, + void* vparams, void* vz, Nd4jLong const* zShapeInfo, + int* allocationBuffer) { + auto scalar = reinterpret_cast(vscalar)[0]; + auto y = reinterpret_cast(vy); + auto params = reinterpret_cast(vparams); + auto z = reinterpret_cast(vz); + + auto yRank = shape::rank(yShapeInfo); + auto yEWS = shape::elementWiseStride(yShapeInfo); + auto yShape = shape::shapeOf(yShapeInfo); + auto yStride = shape::stride(yShapeInfo); + + auto zRank = shape::rank(zShapeInfo); + auto zEWS = shape::elementWiseStride(zShapeInfo); + auto zShape = shape::shapeOf(zShapeInfo); + auto zStride = shape::stride(zShapeInfo); + + int totalThreads = gridDim.x * blockDim.x; + int tid = blockIdx.x * blockDim.x + threadIdx.x; + + __shared__ int len; + if (threadIdx.x == 0) len = shape::length(yShapeInfo); + __syncthreads(); + + if (yEWS >= 1 && zEWS >= 1 && + shape::order(yShapeInfo) == shape::order(zShapeInfo)) { + transformCuda(len, vscalar, vy, yEWS, vparams, vz, zEWS, + allocationBuffer); + } else { + for (Nd4jLong i = tid; i < len; i += totalThreads) + z[shape::getIndexOffset(i, zShapeInfo)] = + OpType::op(y[shape::getIndexOffset(i, yShapeInfo)], scalar, params); + } } //////////////////////////////////////////////////////////////////////// -template -template -__device__ void ScalarBoolTransform::transformCuda(Nd4jLong len, - void const* vx, - void const* vy, Nd4jLong yEWS, - void *vparams, - void *vz, Nd4jLong zEWS, - int *allocationBuffer) { - - auto x = reinterpret_cast(vx)[0]; - auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); - auto params = reinterpret_cast(vparams); - - int totalThreads = gridDim.x * blockDim.x; - int tid = blockIdx.x * blockDim.x + threadIdx.x; - - Nd4jLong i = tid; - if(yEWS == 1 && zEWS == 1) { - for (; i < len; i += totalThreads) - z[i] = OpType::op(y[i], x, params); - } - else { - for (; i < len; i += totalThreads) - z[i * zEWS] = OpType::op(y[i * yEWS], x, params); - } +template +template +__device__ void ScalarBoolTransform::transformCuda( + Nd4jLong len, void const* vx, void const* vy, Nd4jLong yEWS, void* vparams, + void* vz, Nd4jLong zEWS, int* allocationBuffer) { + auto x = reinterpret_cast(vx)[0]; + auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + auto params = reinterpret_cast(vparams); + + int totalThreads = gridDim.x * blockDim.x; + int tid = blockIdx.x * blockDim.x + threadIdx.x; + + Nd4jLong i = tid; + if (yEWS == 1 && zEWS == 1) { + for (; i < len; i += totalThreads) z[i] = OpType::op(y[i], x, params); + } else { + for (; i < len; i += totalThreads) + z[i * zEWS] = OpType::op(y[i * yEWS], x, params); + } } - //////////////////////////////////////////////////////////////////////// -template -template -__device__ void ScalarBoolTransform::transformCuda(void const* vx, Nd4jLong const* xShapeInfo, - void *vextraParams, - void *vz, Nd4jLong const* zShapeInfo, - void const* vscalars, - int *dimension, int dimensionLength, - Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, - Nd4jLong const* tadShapeInfoZ, Nd4jLong const* tadOffsetsZ) { - auto x = reinterpret_cast(vx); - auto scalars = reinterpret_cast(vscalars); - auto z = reinterpret_cast(vz); - auto extraParams = reinterpret_cast(vextraParams); - - if (tadShapeInfoZ == nullptr) { - tadShapeInfoZ = tadShapeInfo; - tadOffsetsZ = tadOffsets; +template +template +__device__ void ScalarBoolTransform::transformCuda( + void const* vx, Nd4jLong const* xShapeInfo, void* vextraParams, void* vz, + Nd4jLong const* zShapeInfo, void const* vscalars, int* dimension, + int dimensionLength, Nd4jLong const* tadShapeInfo, + Nd4jLong const* tadOffsets, Nd4jLong const* tadShapeInfoZ, + Nd4jLong const* tadOffsetsZ) { + auto x = reinterpret_cast(vx); + auto scalars = reinterpret_cast(vscalars); + auto z = reinterpret_cast(vz); + auto extraParams = reinterpret_cast(vextraParams); + + if (tadShapeInfoZ == nullptr) { + tadShapeInfoZ = tadShapeInfo; + tadOffsetsZ = tadOffsets; + } + + // tad preparation + auto tadEws = shape::elementWiseStride(tadShapeInfo); + auto zEws = shape::elementWiseStride(tadShapeInfoZ); + auto tadLength = shape::length(tadShapeInfo); // shape::tadLength(xShapeInfo, + // dimension, dimensionLength); + auto numTads = shape::length(xShapeInfo) / tadLength; + + if (tadEws > 0 && zEws > 0 && + shape::order(tadShapeInfo) == shape::order(zShapeInfo)) { + // main loop, rolling over tads + for (int r = blockIdx.x; r < numTads; r += gridDim.x) { + Z* oZ = z + tadOffsetsZ[r]; + auto oX = x + tadOffsets[r]; + + auto s = scalars[r]; + + for (int f = threadIdx.x; f < tadLength; f += blockDim.x) + oZ[f * zEws] = OpType::op(oX[f * tadEws], s, extraParams); } + } else { + // main loop, rolling over tads + for (int r = blockIdx.x; r < numTads; r += gridDim.x) { + Z* oZ = z + tadOffsetsZ[r]; + auto oX = x + tadOffsets[r]; - // tad preparation - auto tadEws = shape::elementWiseStride(tadShapeInfo); - auto zEws = shape::elementWiseStride(tadShapeInfoZ); - auto tadLength = shape::length(tadShapeInfo);//shape::tadLength(xShapeInfo, dimension, dimensionLength); - auto numTads =shape::length(xShapeInfo) / tadLength; - - if (tadEws > 0 && zEws > 0 && shape::order(tadShapeInfo) == shape::order(zShapeInfo)) { - - // main loop, rolling over tads - for (int r = blockIdx.x; r < numTads; r += gridDim.x) { - Z *oZ = z + tadOffsetsZ[r]; - auto oX = x + tadOffsets[r]; - - auto s = scalars[r]; - - for (int f = threadIdx.x; f < tadLength; f += blockDim.x) - oZ[f * zEws] = OpType::op(oX[f * tadEws], s, extraParams); - } - } else { - // main loop, rolling over tads - for (int r = blockIdx.x; r < numTads; r += gridDim.x) { - Z *oZ = z + tadOffsetsZ[r]; - auto oX = x + tadOffsets[r]; + auto s = scalars[r]; - auto s = scalars[r]; - - for (int f = threadIdx.x; f < tadLength; f += blockDim.x) - oZ[shape::getIndexOffset(f, tadShapeInfoZ)] = OpType::op(oX[shape::getIndexOffset(f, tadShapeInfo)], s, extraParams); - } + for (int f = threadIdx.x; f < tadLength; f += blockDim.x) + oZ[shape::getIndexOffset(f, tadShapeInfoZ)] = OpType::op( + oX[shape::getIndexOffset(f, tadShapeInfo)], s, extraParams); } + } } - //////////////////////////////////////////////////////////////////////// -template +template template -_CUDA_H void ScalarBoolTransform::intermediateAlongDimension(dim3& launchDims, cudaStream_t *stream, - void const* x, Nd4jLong const* xShapeInfo, - void *z, Nd4jLong const* zShapeInfo, - void const* scalars, - void *extraParams, - int *dimension, int dimensionLength, - Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, - Nd4jLong const* tadShapeInfoZ, Nd4jLong const* tadOffsetsZ) { - - scalarAlongDimension<<>>(x, xShapeInfo, extraParams, z, zShapeInfo, scalars, dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ); - sd::DebugHelper::checkErrorCode(stream, "scalarAlongDim(...) failed"); +_CUDA_H void ScalarBoolTransform::intermediateAlongDimension( + dim3& launchDims, cudaStream_t* stream, void const* x, + Nd4jLong const* xShapeInfo, void* z, Nd4jLong const* zShapeInfo, + void const* scalars, void* extraParams, int* dimension, int dimensionLength, + Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, + Nd4jLong const* tadShapeInfoZ, Nd4jLong const* tadOffsetsZ) { + scalarAlongDimension + <<>>( + x, xShapeInfo, extraParams, z, zShapeInfo, scalars, dimension, + dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, + tadOffsetsZ); + sd::DebugHelper::checkErrorCode(stream, "scalarAlongDim(...) failed"); } //////////////////////////////////////////////////////////////////////// -template -template -void _CUDA_H ScalarBoolTransform::intermediateShaped(dim3& launchDims, cudaStream_t *stream, - void const* vx, Nd4jLong const* xShapeInfo, - void *vz, Nd4jLong const* zShapeInfo, - void const* vscalar, - void *vextraParams, int *allocPointer){ - - scalarSimpleShaped<<>>(vx, vscalar, xShapeInfo, vextraParams, vz, zShapeInfo, allocPointer); - sd::DebugHelper::checkErrorCode(stream, "scalarSimpleShaped(...) failed"); +template +template +void _CUDA_H ScalarBoolTransform::intermediateShaped( + dim3& launchDims, cudaStream_t* stream, void const* vx, + Nd4jLong const* xShapeInfo, void* vz, Nd4jLong const* zShapeInfo, + void const* vscalar, void* vextraParams, int* allocPointer) { + scalarSimpleShaped + <<>>( + vx, vscalar, xShapeInfo, vextraParams, vz, zShapeInfo, allocPointer); + sd::DebugHelper::checkErrorCode(stream, "scalarSimpleShaped(...) failed"); } //////////////////////////////////////////////////////////////////////// -template -void ScalarBoolTransform::executeCudaShaped(dim3& launchDims, cudaStream_t *stream, - int opNum, - void const* vx, Nd4jLong const* xShapeInfo, - void *vz, Nd4jLong const* zShapeInfo, - void const* vscalar, - void const* vextraParams) { - - if (sd::Environment::getInstance().isDebugAndVerbose()) - printf("H14 opNum:[%i]\n", opNum); - - DISPATCH_BY_OPNUM_TT(intermediateShaped, PARAMS(launchDims, stream, vx, xShapeInfo, vz, zShapeInfo, vscalar, const_cast(vextraParams), nullptr), SCALAR_BOOL_OPS); +template +void ScalarBoolTransform::executeCudaShaped( + dim3& launchDims, cudaStream_t* stream, int opNum, void const* vx, + Nd4jLong const* xShapeInfo, void* vz, Nd4jLong const* zShapeInfo, + void const* vscalar, void const* vextraParams) { + if (sd::Environment::getInstance().isDebugAndVerbose()) + printf("H14 opNum:[%i]\n", opNum); + + DISPATCH_BY_OPNUM_TT( + intermediateShaped, + PARAMS(launchDims, stream, vx, xShapeInfo, vz, zShapeInfo, vscalar, + const_cast(vextraParams), nullptr), + SCALAR_BOOL_OPS); } //////////////////////////////////////////////////////////////////////// -template -void ScalarBoolTransform::executeCudaAlongDimension(dim3& launchDims, cudaStream_t *stream, int opNum, void const* vx, Nd4jLong const* xShapeInfo, void *vz, Nd4jLong const* zShapeInfo, void const* vscalars, void *vextraParams, int *dimension, int dimensionLength, Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadShapeInfoZ, Nd4jLong const* tadOffsetsZ) { - DISPATCH_BY_OPNUM_TT(intermediateAlongDimension, PARAMS(launchDims, stream, vx, xShapeInfo, vz, zShapeInfo, vscalars, vextraParams, dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ), SCALAR_BOOL_OPS); -} - - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT ScalarBoolTransform, , LIBND4J_TYPES, BOOL_TYPES); -} +template +void ScalarBoolTransform::executeCudaAlongDimension( + dim3& launchDims, cudaStream_t* stream, int opNum, void const* vx, + Nd4jLong const* xShapeInfo, void* vz, Nd4jLong const* zShapeInfo, + void const* vscalars, void* vextraParams, int* dimension, + int dimensionLength, Nd4jLong const* tadShapeInfo, + Nd4jLong const* tadOffsets, Nd4jLong const* tadShapeInfoZ, + Nd4jLong const* tadOffsetsZ) { + DISPATCH_BY_OPNUM_TT( + intermediateAlongDimension, + PARAMS(launchDims, stream, vx, xShapeInfo, vz, zShapeInfo, vscalars, + vextraParams, dimension, dimensionLength, tadShapeInfo, tadOffsets, + tadShapeInfoZ, tadOffsetsZ), + SCALAR_BOOL_OPS); } +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT ScalarBoolTransform, , + LIBND4J_TYPES, BOOL_TYPES); +} // namespace scalar +} // namespace functions diff --git a/libnd4j/include/loops/cuda/scalar_int.cu b/libnd4j/include/loops/cuda/scalar_int.cu index b8cac0846551..1f604713481e 100644 --- a/libnd4j/include/loops/cuda/scalar_int.cu +++ b/libnd4j/include/loops/cuda/scalar_int.cu @@ -19,217 +19,222 @@ // @author raver119@gmail.com // -#include "../scalar_int.h" #include #include #include "../legacy_ops.h" +#include "../scalar_int.h" using namespace simdOps; //////////////////////////////////////////////////////////////////////// template -__global__ void scalarAlongDimension(void const* x, Nd4jLong const* xShapeInfo, - void *extraParams, - void *z, Nd4jLong const* zShapeInfo, - void const* scalars, - int *dimension, int dimensionLength, - Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, - Nd4jLong const* tadShapeInfoZ, Nd4jLong const* tadOffsetsZ) { - - functions::scalar::ScalarIntTransform::template transformCuda(x, xShapeInfo, extraParams, z, zShapeInfo, scalars, dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ); +__global__ void scalarAlongDimension( + void const* x, Nd4jLong const* xShapeInfo, void* extraParams, void* z, + Nd4jLong const* zShapeInfo, void const* scalars, int* dimension, + int dimensionLength, Nd4jLong const* tadShapeInfo, + Nd4jLong const* tadOffsets, Nd4jLong const* tadShapeInfoZ, + Nd4jLong const* tadOffsetsZ) { + functions::scalar::ScalarIntTransform::template transformCuda( + x, xShapeInfo, extraParams, z, zShapeInfo, scalars, dimension, + dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ); } - //////////////////////////////////////////////////////////////////////// template -__global__ void scalarSimpleShaped(void const* x, void const* y, Nd4jLong const* xShapeInfo, void *params, void *z, Nd4jLong const* zShapeInfo, int *allocationBuffer) { - - functions::scalar::ScalarIntTransform::template transformCuda(y, x, xShapeInfo, params, z, zShapeInfo, allocationBuffer); +__global__ void scalarSimpleShaped(void const* x, void const* y, + Nd4jLong const* xShapeInfo, void* params, + void* z, Nd4jLong const* zShapeInfo, + int* allocationBuffer) { + functions::scalar::ScalarIntTransform::template transformCuda( + y, x, xShapeInfo, params, z, zShapeInfo, allocationBuffer); } - - - - // *********************************************************************// // *********************************************************************// namespace functions { -namespace scalar { +namespace scalar { //////////////////////////////////////////////////////////////////////// -template -template -__device__ void ScalarIntTransform::transformCuda(void const* vscalar, - void const* vy, Nd4jLong const* yShapeInfo, - void *vparams, - void *vz, Nd4jLong const* zShapeInfo, - int *allocationBuffer) { - auto scalar = reinterpret_cast(vscalar)[0]; - auto y = reinterpret_cast(vy); - auto params = reinterpret_cast(vparams); - auto z = reinterpret_cast(vz); - - auto yRank = shape::rank(yShapeInfo); - auto yEWS = shape::elementWiseStride(yShapeInfo); - auto yShape = shape::shapeOf(yShapeInfo); - auto yStride = shape::stride(yShapeInfo); - - auto zRank = shape::rank(zShapeInfo); - auto zEWS = shape::elementWiseStride(zShapeInfo); - auto zShape = shape::shapeOf(zShapeInfo); - auto zStride = shape::stride(zShapeInfo); - - int totalThreads = gridDim.x * blockDim.x; - int tid = blockIdx.x * blockDim.x + threadIdx.x; - - __shared__ int len; - if(threadIdx.x == 0) - len = shape::length(yShapeInfo); - __syncthreads(); - - if(yEWS >= 1 && zEWS >= 1 && shape::order(yShapeInfo) == shape::order(zShapeInfo)) { - transformCuda(len, vscalar, vy, yEWS, vparams, vz, zEWS, allocationBuffer); - } - else { - for (Nd4jLong i = tid; i < len; i+= totalThreads) - z[shape::getIndexOffset(i, zShapeInfo)] = OpType::op(y[shape::getIndexOffset(i, yShapeInfo)], scalar, params); - } +template +template +__device__ void ScalarIntTransform::transformCuda(void const* vscalar, + void const* vy, + Nd4jLong const* yShapeInfo, + void* vparams, void* vz, + Nd4jLong const* zShapeInfo, + int* allocationBuffer) { + auto scalar = reinterpret_cast(vscalar)[0]; + auto y = reinterpret_cast(vy); + auto params = reinterpret_cast(vparams); + auto z = reinterpret_cast(vz); + + auto yRank = shape::rank(yShapeInfo); + auto yEWS = shape::elementWiseStride(yShapeInfo); + auto yShape = shape::shapeOf(yShapeInfo); + auto yStride = shape::stride(yShapeInfo); + + auto zRank = shape::rank(zShapeInfo); + auto zEWS = shape::elementWiseStride(zShapeInfo); + auto zShape = shape::shapeOf(zShapeInfo); + auto zStride = shape::stride(zShapeInfo); + + int totalThreads = gridDim.x * blockDim.x; + int tid = blockIdx.x * blockDim.x + threadIdx.x; + + __shared__ int len; + if (threadIdx.x == 0) len = shape::length(yShapeInfo); + __syncthreads(); + + if (yEWS >= 1 && zEWS >= 1 && + shape::order(yShapeInfo) == shape::order(zShapeInfo)) { + transformCuda(len, vscalar, vy, yEWS, vparams, vz, zEWS, + allocationBuffer); + } else { + for (Nd4jLong i = tid; i < len; i += totalThreads) + z[shape::getIndexOffset(i, zShapeInfo)] = + OpType::op(y[shape::getIndexOffset(i, yShapeInfo)], scalar, params); + } } //////////////////////////////////////////////////////////////////////// -template -template -__device__ void ScalarIntTransform::transformCuda(Nd4jLong len, - void const* vx, - void const* vy, Nd4jLong yEWS, - void *vparams, - void *vz, Nd4jLong zEWS, - int *allocationBuffer) { - - auto x = reinterpret_cast(vx)[0]; - auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); - auto params = reinterpret_cast(vparams); - - int totalThreads = gridDim.x * blockDim.x; - int tid = blockIdx.x * blockDim.x + threadIdx.x; - - Nd4jLong i = tid; - if(yEWS == 1 && zEWS == 1) { - for (; i < len; i += totalThreads) - z[i] = OpType::op(y[i], x, params); - } - else { - for (; i < len; i += totalThreads) - z[i * zEWS] = OpType::op(y[i * yEWS], x, params); - } +template +template +__device__ void ScalarIntTransform::transformCuda( + Nd4jLong len, void const* vx, void const* vy, Nd4jLong yEWS, void* vparams, + void* vz, Nd4jLong zEWS, int* allocationBuffer) { + auto x = reinterpret_cast(vx)[0]; + auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + auto params = reinterpret_cast(vparams); + + int totalThreads = gridDim.x * blockDim.x; + int tid = blockIdx.x * blockDim.x + threadIdx.x; + + Nd4jLong i = tid; + if (yEWS == 1 && zEWS == 1) { + for (; i < len; i += totalThreads) z[i] = OpType::op(y[i], x, params); + } else { + for (; i < len; i += totalThreads) + z[i * zEWS] = OpType::op(y[i * yEWS], x, params); + } } - //////////////////////////////////////////////////////////////////////// -template -template -__device__ void ScalarIntTransform::transformCuda(void const* vx, Nd4jLong const* xShapeInfo, - void *vextraParams, - void *vz, Nd4jLong const* zShapeInfo, - void const* vscalars, - int *dimension, int dimensionLength, - Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, - Nd4jLong const* tadShapeInfoZ, Nd4jLong const* tadOffsetsZ) { - auto x = reinterpret_cast(vx); - auto scalars = reinterpret_cast(vscalars); - auto z = reinterpret_cast(vz); - auto extraParams = reinterpret_cast(vextraParams); - - if (tadShapeInfoZ == nullptr) { - tadShapeInfoZ = tadShapeInfo; - tadOffsetsZ = tadOffsets; +template +template +__device__ void ScalarIntTransform::transformCuda( + void const* vx, Nd4jLong const* xShapeInfo, void* vextraParams, void* vz, + Nd4jLong const* zShapeInfo, void const* vscalars, int* dimension, + int dimensionLength, Nd4jLong const* tadShapeInfo, + Nd4jLong const* tadOffsets, Nd4jLong const* tadShapeInfoZ, + Nd4jLong const* tadOffsetsZ) { + auto x = reinterpret_cast(vx); + auto scalars = reinterpret_cast(vscalars); + auto z = reinterpret_cast(vz); + auto extraParams = reinterpret_cast(vextraParams); + + if (tadShapeInfoZ == nullptr) { + tadShapeInfoZ = tadShapeInfo; + tadOffsetsZ = tadOffsets; + } + + // tad preparation + auto tadEws = shape::elementWiseStride(tadShapeInfo); + auto zEws = shape::elementWiseStride(tadShapeInfoZ); + auto tadLength = shape::length(tadShapeInfo); // shape::tadLength(xShapeInfo, + // dimension, dimensionLength); + auto numTads = shape::length(xShapeInfo) / tadLength; + + if (tadEws > 0 && zEws > 0 && + shape::order(tadShapeInfo) == shape::order(zShapeInfo)) { + // main loop, rolling over tads + for (int r = blockIdx.x; r < numTads; r += gridDim.x) { + X* oZ = z + tadOffsetsZ[r]; + auto oX = x + tadOffsets[r]; + + auto s = scalars[r]; + + for (int f = threadIdx.x; f < tadLength; f += blockDim.x) + oZ[f * zEws] = OpType::op(oX[f * tadEws], s, extraParams); } + } else { + // main loop, rolling over tads + for (int r = blockIdx.x; r < numTads; r += gridDim.x) { + X* oZ = z + tadOffsetsZ[r]; + auto oX = x + tadOffsets[r]; - // tad preparation - auto tadEws = shape::elementWiseStride(tadShapeInfo); - auto zEws = shape::elementWiseStride(tadShapeInfoZ); - auto tadLength = shape::length(tadShapeInfo);//shape::tadLength(xShapeInfo, dimension, dimensionLength); - auto numTads =shape::length(xShapeInfo) / tadLength; - - if (tadEws > 0 && zEws > 0 && shape::order(tadShapeInfo) == shape::order(zShapeInfo)) { - - // main loop, rolling over tads - for (int r = blockIdx.x; r < numTads; r += gridDim.x) { - X *oZ = z + tadOffsetsZ[r]; - auto oX = x + tadOffsets[r]; - - auto s = scalars[r]; - - for (int f = threadIdx.x; f < tadLength; f += blockDim.x) - oZ[f * zEws] = OpType::op(oX[f * tadEws], s, extraParams); - } - } else { - // main loop, rolling over tads - for (int r = blockIdx.x; r < numTads; r += gridDim.x) { - X *oZ = z + tadOffsetsZ[r]; - auto oX = x + tadOffsets[r]; + auto s = scalars[r]; - auto s = scalars[r]; - - for (int f = threadIdx.x; f < tadLength; f += blockDim.x) - oZ[shape::getIndexOffset(f, tadShapeInfoZ)] = OpType::op(oX[shape::getIndexOffset(f, tadShapeInfo)], s, extraParams); - } + for (int f = threadIdx.x; f < tadLength; f += blockDim.x) + oZ[shape::getIndexOffset(f, tadShapeInfoZ)] = OpType::op( + oX[shape::getIndexOffset(f, tadShapeInfo)], s, extraParams); } + } } - //////////////////////////////////////////////////////////////////////// -template +template template -_CUDA_H void ScalarIntTransform::intermediateAlongDimension(dim3& launchDims, cudaStream_t *stream, - void const* x, Nd4jLong const* xShapeInfo, - void *z, Nd4jLong const* zShapeInfo, - void const* scalars, - void *extraParams, - int *dimension, int dimensionLength, - Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, - Nd4jLong const* tadShapeInfoZ, Nd4jLong const* tadOffsetsZ) { - - scalarAlongDimension<<>>(x, xShapeInfo, extraParams, z, zShapeInfo, scalars, dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ); +_CUDA_H void ScalarIntTransform::intermediateAlongDimension( + dim3& launchDims, cudaStream_t* stream, void const* x, + Nd4jLong const* xShapeInfo, void* z, Nd4jLong const* zShapeInfo, + void const* scalars, void* extraParams, int* dimension, int dimensionLength, + Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, + Nd4jLong const* tadShapeInfoZ, Nd4jLong const* tadOffsetsZ) { + scalarAlongDimension + <<>>( + x, xShapeInfo, extraParams, z, zShapeInfo, scalars, dimension, + dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, + tadOffsetsZ); } //////////////////////////////////////////////////////////////////////// -template -template -void _CUDA_H ScalarIntTransform::intermediateShaped(dim3& launchDims, cudaStream_t *stream, - void const* vx, Nd4jLong const* xShapeInfo, - void *vz, Nd4jLong const* zShapeInfo, - void const* vscalar, - void *vextraParams, int *allocPointer){ - - scalarSimpleShaped<<>>(vx, vscalar, xShapeInfo, vextraParams, vz, zShapeInfo, allocPointer); +template +template +void _CUDA_H ScalarIntTransform::intermediateShaped( + dim3& launchDims, cudaStream_t* stream, void const* vx, + Nd4jLong const* xShapeInfo, void* vz, Nd4jLong const* zShapeInfo, + void const* vscalar, void* vextraParams, int* allocPointer) { + scalarSimpleShaped + <<>>( + vx, vscalar, xShapeInfo, vextraParams, vz, zShapeInfo, allocPointer); } //////////////////////////////////////////////////////////////////////// -template -void ScalarIntTransform::executeCudaShaped(dim3& launchDims, cudaStream_t *stream, - int opNum, - void const* vx, Nd4jLong const* xShapeInfo, - void *vz, Nd4jLong const* zShapeInfo, - void const* vscalar, - void* vextraParams) { - - if (sd::Environment::getInstance().isDebugAndVerbose()) - printf("H14 opNum:[%i]\n", opNum); - - DISPATCH_BY_OPNUM_T(intermediateShaped, PARAMS(launchDims, stream, vx, xShapeInfo, vz, zShapeInfo, vscalar, vextraParams, nullptr), SCALAR_INT_OPS); +template +void ScalarIntTransform::executeCudaShaped( + dim3& launchDims, cudaStream_t* stream, int opNum, void const* vx, + Nd4jLong const* xShapeInfo, void* vz, Nd4jLong const* zShapeInfo, + void const* vscalar, void* vextraParams) { + if (sd::Environment::getInstance().isDebugAndVerbose()) + printf("H14 opNum:[%i]\n", opNum); + + DISPATCH_BY_OPNUM_T(intermediateShaped, + PARAMS(launchDims, stream, vx, xShapeInfo, vz, zShapeInfo, + vscalar, vextraParams, nullptr), + SCALAR_INT_OPS); } //////////////////////////////////////////////////////////////////////// -template -void ScalarIntTransform::executeCudaAlongDimension(dim3& launchDims, cudaStream_t *stream, int opNum, void const* vx, Nd4jLong const* xShapeInfo, void *vz, Nd4jLong const* zShapeInfo, void const* vscalars, void *vextraParams, int *dimension, int dimensionLength, Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadShapeInfoZ, Nd4jLong const* tadOffsetsZ) { - DISPATCH_BY_OPNUM_T(intermediateAlongDimension, PARAMS(launchDims, stream, vx, xShapeInfo, vz, zShapeInfo, vscalars, vextraParams, dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ), SCALAR_INT_OPS); +template +void ScalarIntTransform::executeCudaAlongDimension( + dim3& launchDims, cudaStream_t* stream, int opNum, void const* vx, + Nd4jLong const* xShapeInfo, void* vz, Nd4jLong const* zShapeInfo, + void const* vscalars, void* vextraParams, int* dimension, + int dimensionLength, Nd4jLong const* tadShapeInfo, + Nd4jLong const* tadOffsets, Nd4jLong const* tadShapeInfoZ, + Nd4jLong const* tadOffsetsZ) { + DISPATCH_BY_OPNUM_T( + intermediateAlongDimension, + PARAMS(launchDims, stream, vx, xShapeInfo, vz, zShapeInfo, vscalars, + vextraParams, dimension, dimensionLength, tadShapeInfo, tadOffsets, + tadShapeInfoZ, tadOffsetsZ), + SCALAR_INT_OPS); } - BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT ScalarIntTransform, , INTEGER_TYPES); - -} -} +BUILD_SINGLE_TEMPLATE(template class SD_EXPORT ScalarIntTransform, , + INTEGER_TYPES); +} // namespace scalar +} // namespace functions diff --git a/libnd4j/include/loops/cuda/specials/accumulateKernel.cu b/libnd4j/include/loops/cuda/specials/accumulateKernel.cu index 6d6dd42a4f10..13851d358d2a 100644 --- a/libnd4j/include/loops/cuda/specials/accumulateKernel.cu +++ b/libnd4j/include/loops/cuda/specials/accumulateKernel.cu @@ -33,58 +33,62 @@ namespace sd { * @param n * @param length */ - template - __device__ void accumulateKernel(void **vx, void *vz, int n, const Nd4jLong length) { +template +__device__ void accumulateKernel(void **vx, void *vz, int n, + const Nd4jLong length) { + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); + __shared__ T *shmem; - __shared__ - T *shmem; + if (threadIdx.x == 0) { + extern __shared__ unsigned char sharedmem[]; + shmem = (T *)sharedmem; + } + __syncthreads(); - if (threadIdx.x == 0) { - extern __shared__ unsigned char sharedmem[]; - shmem = (T *) sharedmem; - } - __syncthreads(); + for (int r = blockDim.x * blockIdx.x; r < length; + r += blockDim.x * gridDim.x) { + shmem[threadIdx.x] = 0.0f; - for (int r = blockDim.x * blockIdx.x; r < length; r += blockDim.x * gridDim.x) { - shmem[threadIdx.x] = 0.0f; + Nd4jLong baseIdx = r; - Nd4jLong baseIdx = r; + // aggregation step, we roll over all arrays + for (int ar = 0; ar < n; ar++) { + T *cdata = (T *)x[ar]; + cdata += baseIdx; - // aggregation step, we roll over all arrays - for (int ar = 0; ar < n; ar++) { - T *cdata = (T *) x[ar]; - cdata += baseIdx; - - if (baseIdx + threadIdx.x < length) - shmem[threadIdx.x] += cdata[threadIdx.x]; - } - - T *wdata = z + baseIdx; - - // saving accumulated values - if (baseIdx + threadIdx.x < length) - wdata[threadIdx.x] = shmem[threadIdx.x]; - } + if (baseIdx + threadIdx.x < length) + shmem[threadIdx.x] += cdata[threadIdx.x]; } -/////////////////////////////////////////////////////////////////////// - template - __global__ void execAccumulateKernel(void **vx, void *vz, int n, const Nd4jLong length) { + T *wdata = z + baseIdx; - accumulateKernel(vx, vz, n, length); - } + // saving accumulated values + if (baseIdx + threadIdx.x < length) wdata[threadIdx.x] = shmem[threadIdx.x]; + } +} /////////////////////////////////////////////////////////////////////// - template - __host__ void - accumulateKernelGeneric(dim3 &launchDims, cudaStream_t *stream, void **vx, void *vz, int n, const Nd4jLong length) { - - execAccumulateKernel<<< launchDims.x, launchDims.y, launchDims.z, *stream>>> (vx, vz, n, length); - sd::DebugHelper::checkErrorCode(stream, "accumulate(...) failed"); - } +template +__global__ void execAccumulateKernel(void **vx, void *vz, int n, + const Nd4jLong length) { + accumulateKernel(vx, vz, n, length); +} - BUILD_SINGLE_TEMPLATE(template void ND4J_EXPORT accumulateKernelGeneric, (dim3 & launchDims, cudaStream_t * stream, void * *vx, void * vz, int n, const Nd4jLong length), LIBND4J_TYPES); -} \ No newline at end of file +/////////////////////////////////////////////////////////////////////// +template +__host__ void accumulateKernelGeneric(dim3 &launchDims, cudaStream_t *stream, + void **vx, void *vz, int n, + const Nd4jLong length) { + execAccumulateKernel + <<>>(vx, vz, n, + length); + sd::DebugHelper::checkErrorCode(stream, "accumulate(...) failed"); +} + +BUILD_SINGLE_TEMPLATE(template void SD_EXPORT accumulateKernelGeneric, + (dim3 & launchDims, cudaStream_t *stream, void **vx, + void *vz, int n, const Nd4jLong length), + LIBND4J_TYPES); +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/specials/averagingKernel.cu b/libnd4j/include/loops/cuda/specials/averagingKernel.cu index 798b273cfdc4..3e0d1e6f155d 100644 --- a/libnd4j/include/loops/cuda/specials/averagingKernel.cu +++ b/libnd4j/include/loops/cuda/specials/averagingKernel.cu @@ -24,81 +24,79 @@ namespace sd { /////////////////////////////////////////////////////////////////////// - template - __device__ void averagingKernel(void **vdx, void *vdz, int n, Nd4jLong length, bool propagate) { - - auto dx = reinterpret_cast(vdx); - auto dz = reinterpret_cast(vdz); - - __shared__ - T *shmem; - - if (threadIdx.x == 0) { - extern __shared__ unsigned char sharedmem[]; - shmem = (T *) sharedmem; - } - __syncthreads(); - - - // each block cycles over it's own part of arrays - for (int r = blockDim.x * blockIdx.x; r < length; r += blockDim.x * gridDim.x) { - shmem[threadIdx.x] = (T) 0.0f; - - Nd4jLong baseIdx = r; - - // aggregation step, we roll over all arrays - for (int ar = 0; ar < n; ar++) { - T *cdata = (T *) dx[ar]; - cdata += baseIdx; - - if (baseIdx + threadIdx.x < length) - shmem[threadIdx.x] += cdata[threadIdx.x]; - } - - - // average data in shared memory - if (baseIdx + threadIdx.x < length) - shmem[threadIdx.x] /= n; - - // div step & write out step - if (dz != nullptr) { - T *wdata = dz + baseIdx; - - if (baseIdx + threadIdx.x < length) { - wdata[threadIdx.x] = shmem[threadIdx.x]; - } - } - - // propagate averaged data to all arrays - if (propagate) - for (int ar = 0; ar < n; ar++) { - T *cdata = (T *) dx[ar]; - cdata += baseIdx; - - if (baseIdx + threadIdx.x < length) - cdata[threadIdx.x] = shmem[threadIdx.x]; - } - } +template +__device__ void averagingKernel(void **vdx, void *vdz, int n, Nd4jLong length, + bool propagate) { + auto dx = reinterpret_cast(vdx); + auto dz = reinterpret_cast(vdz); + + __shared__ T *shmem; + + if (threadIdx.x == 0) { + extern __shared__ unsigned char sharedmem[]; + shmem = (T *)sharedmem; + } + __syncthreads(); + + // each block cycles over it's own part of arrays + for (int r = blockDim.x * blockIdx.x; r < length; + r += blockDim.x * gridDim.x) { + shmem[threadIdx.x] = (T)0.0f; + + Nd4jLong baseIdx = r; + + // aggregation step, we roll over all arrays + for (int ar = 0; ar < n; ar++) { + T *cdata = (T *)dx[ar]; + cdata += baseIdx; + + if (baseIdx + threadIdx.x < length) + shmem[threadIdx.x] += cdata[threadIdx.x]; } + // average data in shared memory + if (baseIdx + threadIdx.x < length) shmem[threadIdx.x] /= n; -/////////////////////////////////////////////////////////////////////// - template - __global__ void execAveragingKernel(void **vdx, void *vdz, int n, Nd4jLong length, bool propagate) { + // div step & write out step + if (dz != nullptr) { + T *wdata = dz + baseIdx; - averagingKernel(vdx, vdz, n, length, propagate); + if (baseIdx + threadIdx.x < length) { + wdata[threadIdx.x] = shmem[threadIdx.x]; + } } + // propagate averaged data to all arrays + if (propagate) + for (int ar = 0; ar < n; ar++) { + T *cdata = (T *)dx[ar]; + cdata += baseIdx; -/////////////////////////////////////////////////////////////////////// - template - __host__ void - averagingKernelGeneric(dim3 &launchDims, cudaStream_t *stream, void **vdx, void *vdz, int n, Nd4jLong length, - bool propagate) { + if (baseIdx + threadIdx.x < length) + cdata[threadIdx.x] = shmem[threadIdx.x]; + } + } +} - execAveragingKernel<<< launchDims.x, launchDims.y, launchDims.z, *stream>>>(vdx, vdz, n, length, propagate); - sd::DebugHelper::checkErrorCode(stream, "averaging(...) failed"); - } +/////////////////////////////////////////////////////////////////////// +template +__global__ void execAveragingKernel(void **vdx, void *vdz, int n, + Nd4jLong length, bool propagate) { + averagingKernel(vdx, vdz, n, length, propagate); +} - BUILD_SINGLE_TEMPLATE(template void ND4J_EXPORT averagingKernelGeneric, (dim3 & launchDims, cudaStream_t * stream, void * *vdx, void * vdz, int n, Nd4jLong length, bool propagate), LIBND4J_TYPES); -} \ No newline at end of file +/////////////////////////////////////////////////////////////////////// +template +__host__ void averagingKernelGeneric(dim3 &launchDims, cudaStream_t *stream, + void **vdx, void *vdz, int n, + Nd4jLong length, bool propagate) { + execAveragingKernel<<>>( + vdx, vdz, n, length, propagate); + sd::DebugHelper::checkErrorCode(stream, "averaging(...) failed"); +} + +BUILD_SINGLE_TEMPLATE(template void SD_EXPORT averagingKernelGeneric, + (dim3 & launchDims, cudaStream_t *stream, void **vdx, + void *vdz, int n, Nd4jLong length, bool propagate), + LIBND4J_TYPES); +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/specials/bitonicArbitraryStep.cu b/libnd4j/include/loops/cuda/specials/bitonicArbitraryStep.cu index 999a0994257e..103c76aa6f99 100644 --- a/libnd4j/include/loops/cuda/specials/bitonicArbitraryStep.cu +++ b/libnd4j/include/loops/cuda/specials/bitonicArbitraryStep.cu @@ -23,168 +23,194 @@ ////////////////////////////////////////////////////////////////////////// template -__global__ void bitonicArbitraryStepKernelKey(void *vx, Nd4jLong const* xShapeInfo, void *vy, Nd4jLong const* yShapeInfo, int window, int length, int reverse, bool descending) { - auto x = static_cast(vx); - auto y = static_cast(vy); - - int tid = threadIdx.x + blockDim.x * blockIdx.x; - int half = window>>1; - - __shared__ Nd4jLong xLength; - if (threadIdx.x == 0) { - xLength = shape::length(xShapeInfo); - } - __syncthreads(); - - //for (int i = 0; i < length; i+= window) - /* - if window == 4; - iterations will be: 0; 4; 8; 12; 16; 20 - if gridDim = 3; - on first iteration we'll have: 0; 4; 8; - on second iteration we'll have: 0 + (3 * 4) = 12; 4 + (3 * 4) = 16; 8 + (3 * 4) = 20 - */ - int firstPosition; - int firstStep; - int secondPosition; - int secondStep; - - int WARP_SIZE = 32; - int numWarps = (gridDim.x * blockDim.x) / 32; - int warpId = tid / WARP_SIZE; - int warpIdx = tid % WARP_SIZE; - - if (half >= 128) { - firstPosition = blockIdx.x * window; - firstStep = gridDim.x * window; - - secondPosition = threadIdx.x; - secondStep = blockDim.x; - } else if (half >= 32) { - firstPosition = warpId * window; - firstStep = numWarps * window; - - secondPosition = warpIdx; - secondStep = WARP_SIZE; - } else { - firstPosition = tid * window; - firstStep = blockDim.x * gridDim.x * window; - - secondPosition = 0; - secondStep = 1; - } - - - for (int i = firstPosition; i < length; i += firstStep) { - for (int j = secondPosition; j < half; j += secondStep) { - int it = (reverse) ? i + j + half : i + window - j - 1; - int ij = i+j; - if (it < length && ij < length ) { - int posIT = shape::getIndexOffset(it, xShapeInfo); - int posIJ = shape::getIndexOffset(ij, xShapeInfo); - - X v0 = x[posIJ]; - X v1 = x[posIT]; - - if(!descending == (v0 > v1)) { - x[posIJ] = v1; - x[posIT] = v0; - - Y ytemp = y[posIJ]; - y[posIJ] = y[posIT]; - y[posIT] = ytemp; - } - } +__global__ void bitonicArbitraryStepKernelKey( + void *vx, Nd4jLong const *xShapeInfo, void *vy, Nd4jLong const *yShapeInfo, + int window, int length, int reverse, bool descending) { + auto x = static_cast(vx); + auto y = static_cast(vy); + + int tid = threadIdx.x + blockDim.x * blockIdx.x; + int half = window >> 1; + + __shared__ Nd4jLong xLength; + if (threadIdx.x == 0) { + xLength = shape::length(xShapeInfo); + } + __syncthreads(); + + // for (int i = 0; i < length; i+= window) + /* + if window == 4; + iterations will be: 0; 4; 8; 12; 16; 20 + if gridDim = 3; + on first iteration we'll have: 0; 4; 8; + on second iteration we'll have: 0 + (3 * 4) = 12; 4 + (3 * 4) = 16; 8 + + (3 * 4) = 20 + */ + int firstPosition; + int firstStep; + int secondPosition; + int secondStep; + + int WARP_SIZE = 32; + int numWarps = (gridDim.x * blockDim.x) / 32; + int warpId = tid / WARP_SIZE; + int warpIdx = tid % WARP_SIZE; + + if (half >= 128) { + firstPosition = blockIdx.x * window; + firstStep = gridDim.x * window; + + secondPosition = threadIdx.x; + secondStep = blockDim.x; + } else if (half >= 32) { + firstPosition = warpId * window; + firstStep = numWarps * window; + + secondPosition = warpIdx; + secondStep = WARP_SIZE; + } else { + firstPosition = tid * window; + firstStep = blockDim.x * gridDim.x * window; + + secondPosition = 0; + secondStep = 1; + } + + for (int i = firstPosition; i < length; i += firstStep) { + for (int j = secondPosition; j < half; j += secondStep) { + int it = (reverse) ? i + j + half : i + window - j - 1; + int ij = i + j; + if (it < length && ij < length) { + int posIT = shape::getIndexOffset(it, xShapeInfo); + int posIJ = shape::getIndexOffset(ij, xShapeInfo); + + X v0 = x[posIJ]; + X v1 = x[posIT]; + + if (!descending == (v0 > v1)) { + x[posIJ] = v1; + x[posIT] = v0; + + Y ytemp = y[posIJ]; + y[posIJ] = y[posIT]; + y[posIT] = ytemp; } + } } + } } ////////////////////////////////////////////////////////////////////////// -template -__global__ void execBitonicArbitraryStepKernel(void *vx, Nd4jLong const* xShapeInfo, int window, int length, int reverse, bool descending) { - auto x = static_cast(vx); - - int tid = threadIdx.x + blockDim.x * blockIdx.x; - int half = window>>1; - - __shared__ T *shmem; - __shared__ Nd4jLong xLength; - if (threadIdx.x == 0) { - extern __shared__ unsigned char shrd[]; - shmem = (T *) shrd; - xLength = shape::length(xShapeInfo); - } - __syncthreads(); - - //for (int i = 0; i < length; i+= window) - /* - if window == 4; - iterations will be: 0; 4; 8; 12; 16; 20 - if gridDim = 3; - on first iteration we'll have: 0; 4; 8; - on second iteration we'll have: 0 + (3 * 4) = 12; 4 + (3 * 4) = 16; 8 + (3 * 4) = 20 - */ - int firstPosition; - int firstStep; - int secondPosition; - int secondStep; - - int WARP_SIZE = 32; - int numWarps = (gridDim.x * blockDim.x) / 32; - int warpId = tid / WARP_SIZE; - int warpIdx = tid % WARP_SIZE; - - if (half >= 128) { - firstPosition = blockIdx.x * window; - firstStep = gridDim.x * window; - - secondPosition = threadIdx.x; - secondStep = blockDim.x; - } else if (half >= 32) { - firstPosition = warpId * window; - firstStep = numWarps * window; - - secondPosition = warpIdx; - secondStep = WARP_SIZE; - } else { - firstPosition = tid * window; - firstStep = blockDim.x * gridDim.x * window; - - secondPosition = 0; - secondStep = 1; - } - - - for (int i = firstPosition; i < length; i += firstStep) { - for (int j = secondPosition; j < half; j += secondStep) { - int it = (reverse) ? i + j + half : i + window - j - 1; - int ij = i+j; - if (it < length && ij < length ) { - int posIT = shape::getIndexOffset(it, xShapeInfo); - int posIJ = shape::getIndexOffset(ij, xShapeInfo); - - shmem[threadIdx.x] = x[posIJ]; - shmem[threadIdx.x + blockDim.x] = x[posIT]; - - if(!descending == (shmem[threadIdx.x] > shmem[threadIdx.x + blockDim.x])) { - x[posIJ] = shmem[threadIdx.x + blockDim.x]; - x[posIT] = shmem[threadIdx.x]; - } - } +template +__global__ void execBitonicArbitraryStepKernel(void *vx, + Nd4jLong const *xShapeInfo, + int window, int length, + int reverse, bool descending) { + auto x = static_cast(vx); + + int tid = threadIdx.x + blockDim.x * blockIdx.x; + int half = window >> 1; + + __shared__ T *shmem; + __shared__ Nd4jLong xLength; + if (threadIdx.x == 0) { + extern __shared__ unsigned char shrd[]; + shmem = (T *)shrd; + xLength = shape::length(xShapeInfo); + } + __syncthreads(); + + // for (int i = 0; i < length; i+= window) + /* + if window == 4; + iterations will be: 0; 4; 8; 12; 16; 20 + if gridDim = 3; + on first iteration we'll have: 0; 4; 8; + on second iteration we'll have: 0 + (3 * 4) = 12; 4 + (3 * 4) = 16; 8 + + (3 * 4) = 20 + */ + int firstPosition; + int firstStep; + int secondPosition; + int secondStep; + + int WARP_SIZE = 32; + int numWarps = (gridDim.x * blockDim.x) / 32; + int warpId = tid / WARP_SIZE; + int warpIdx = tid % WARP_SIZE; + + if (half >= 128) { + firstPosition = blockIdx.x * window; + firstStep = gridDim.x * window; + + secondPosition = threadIdx.x; + secondStep = blockDim.x; + } else if (half >= 32) { + firstPosition = warpId * window; + firstStep = numWarps * window; + + secondPosition = warpIdx; + secondStep = WARP_SIZE; + } else { + firstPosition = tid * window; + firstStep = blockDim.x * gridDim.x * window; + + secondPosition = 0; + secondStep = 1; + } + + for (int i = firstPosition; i < length; i += firstStep) { + for (int j = secondPosition; j < half; j += secondStep) { + int it = (reverse) ? i + j + half : i + window - j - 1; + int ij = i + j; + if (it < length && ij < length) { + int posIT = shape::getIndexOffset(it, xShapeInfo); + int posIJ = shape::getIndexOffset(ij, xShapeInfo); + + shmem[threadIdx.x] = x[posIJ]; + shmem[threadIdx.x + blockDim.x] = x[posIT]; + + if (!descending == + (shmem[threadIdx.x] > shmem[threadIdx.x + blockDim.x])) { + x[posIJ] = shmem[threadIdx.x + blockDim.x]; + x[posIT] = shmem[threadIdx.x]; } + } } + } } ////////////////////////////////////////////////////////////////////////// -template -__host__ void bitonicArbitraryStepGeneric(dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong const* xShapeInfo, int window, int length, int reverse, bool descending) { - execBitonicArbitraryStepKernel<<>>(vx, xShapeInfo, window, length, reverse, descending); +template +__host__ void bitonicArbitraryStepGeneric(dim3 &launchDims, + cudaStream_t *stream, void *vx, + Nd4jLong const *xShapeInfo, + int window, int length, int reverse, + bool descending) { + execBitonicArbitraryStepKernel + <<>>( + vx, xShapeInfo, window, length, reverse, descending); } template -__host__ void bitonicArbitraryStepGenericKey(dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong const* xShapeInfo, void *vy, Nd4jLong const* yShapeInfo, int window, int length, int reverse, bool descending) { - bitonicArbitraryStepKernelKey<<>>(vx, xShapeInfo, vy, yShapeInfo, window, length, reverse, descending); +__host__ void bitonicArbitraryStepGenericKey( + dim3 &launchDims, cudaStream_t *stream, void *vx, + Nd4jLong const *xShapeInfo, void *vy, Nd4jLong const *yShapeInfo, + int window, int length, int reverse, bool descending) { + bitonicArbitraryStepKernelKey + <<>>( + vx, xShapeInfo, vy, yShapeInfo, window, length, reverse, descending); } -BUILD_SINGLE_TEMPLATE(template void ND4J_EXPORT bitonicArbitraryStepGeneric, (dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong const* xShapeInfo, int window, int length, int reverse, bool descending), LIBND4J_TYPES); -BUILD_DOUBLE_TEMPLATE(template void ND4J_EXPORT bitonicArbitraryStepGenericKey, (dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong const* xShapeInfo, void *vy, Nd4jLong const* yShapeInfo, int window, int length, int reverse, bool descending), LIBND4J_TYPES, LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template void SD_EXPORT bitonicArbitraryStepGeneric, + (dim3 & launchDims, cudaStream_t *stream, void *vx, + Nd4jLong const *xShapeInfo, int window, int length, + int reverse, bool descending), + LIBND4J_TYPES); +BUILD_DOUBLE_TEMPLATE(template void SD_EXPORT bitonicArbitraryStepGenericKey, + (dim3 & launchDims, cudaStream_t *stream, void *vx, + Nd4jLong const *xShapeInfo, void *vy, + Nd4jLong const *yShapeInfo, int window, int length, + int reverse, bool descending), + LIBND4J_TYPES, LIBND4J_TYPES); diff --git a/libnd4j/include/loops/cuda/specials/bitonicSortStep.cu b/libnd4j/include/loops/cuda/specials/bitonicSortStep.cu index 679e44d1f941..e76a972d2076 100644 --- a/libnd4j/include/loops/cuda/specials/bitonicSortStep.cu +++ b/libnd4j/include/loops/cuda/specials/bitonicSortStep.cu @@ -21,120 +21,135 @@ #include - ////////////////////////////////////////////////////////////////////////// template -__global__ void bitonicSortStepKernelKey(void *vx, Nd4jLong const* xShapeInfo, void *vy, Nd4jLong const* yShapeInfo, int j, int k, int length, bool descending) { - - auto x = static_cast(vx); - auto y = static_cast(vy); - - unsigned int i, ixj; /* Sorting partners: i and ixj */ - i = threadIdx.x + blockDim.x * blockIdx.x; - - __shared__ Nd4jLong xLength; - if (threadIdx.x == 0) - xLength = shape::length(xShapeInfo); - - __syncthreads(); - - - if (i >= length) - return; - - ixj = i^j; - - /* The threads with the lowest ids sort the array. */ - if ((ixj)>i) { - int posI = shape::getIndexOffset(i, xShapeInfo); - int posIXJ = shape::getIndexOffset(ixj, xShapeInfo); - - if ((i&k)==0) { - /* Sort ascending */ - if (!descending == (x[posI]>x[posIXJ])) { - /* exchange(i,ixj); */ - X temp = x[posI]; - x[posI] = x[posIXJ]; - x[posIXJ] = temp; - - Y ytemp = y[posI]; - y[posI] = y[posIXJ]; - y[posIXJ] = ytemp; - } - } else if ((i&k)!=0) { - /* Sort descending */ - if (!descending == (x[posI](vx); + auto y = static_cast(vy); + + unsigned int i, ixj; /* Sorting partners: i and ixj */ + i = threadIdx.x + blockDim.x * blockIdx.x; + + __shared__ Nd4jLong xLength; + if (threadIdx.x == 0) xLength = shape::length(xShapeInfo); + + __syncthreads(); + + if (i >= length) return; + + ixj = i ^ j; + + /* The threads with the lowest ids sort the array. */ + if ((ixj) > i) { + int posI = shape::getIndexOffset(i, xShapeInfo); + int posIXJ = shape::getIndexOffset(ixj, xShapeInfo); + + if ((i & k) == 0) { + /* Sort ascending */ + if (!descending == (x[posI] > x[posIXJ])) { + /* exchange(i,ixj); */ + X temp = x[posI]; + x[posI] = x[posIXJ]; + x[posIXJ] = temp; + + Y ytemp = y[posI]; + y[posI] = y[posIXJ]; + y[posIXJ] = ytemp; + } + } else if ((i & k) != 0) { + /* Sort descending */ + if (!descending == (x[posI] < x[posIXJ])) { + /* exchange(i,ixj); */ + X temp = x[posI]; + x[posI] = x[posIXJ]; + x[posIXJ] = temp; + + Y ytemp = y[posI]; + y[posI] = y[posIXJ]; + y[posIXJ] = ytemp; + } } + } } ////////////////////////////////////////////////////////////////////////// -template -__global__ void bitonicSortStepKernel(void *vx, Nd4jLong const* xShapeInfo, int j, int k, int length, bool descending) { - - auto x = static_cast(vx); - - unsigned int i, ixj; /* Sorting partners: i and ixj */ - i = threadIdx.x + blockDim.x * blockIdx.x; - - __shared__ Nd4jLong xLength; - if (threadIdx.x == 0) - xLength = shape::length(xShapeInfo); - - __syncthreads(); - - - if (i >= length) - return; - - ixj = i^j; - - /* The threads with the lowest ids sort the array. */ - if ((ixj)>i) { - int posI = shape::getIndexOffset(i, xShapeInfo); - int posIXJ = shape::getIndexOffset(ixj, xShapeInfo); - - if ((i&k)==0) { - /* Sort ascending */ - if (!descending == (x[posI]>x[posIXJ])) { - /* exchange(i,ixj); */ - T temp = x[posI]; - x[posI] = x[posIXJ]; - x[posIXJ] = temp; - } - } else if ((i&k)!=0) { - /* Sort descending */ - if (!descending == (x[posI] +__global__ void bitonicSortStepKernel(void *vx, Nd4jLong const *xShapeInfo, + int j, int k, int length, + bool descending) { + auto x = static_cast(vx); + + unsigned int i, ixj; /* Sorting partners: i and ixj */ + i = threadIdx.x + blockDim.x * blockIdx.x; + + __shared__ Nd4jLong xLength; + if (threadIdx.x == 0) xLength = shape::length(xShapeInfo); + + __syncthreads(); + + if (i >= length) return; + + ixj = i ^ j; + + /* The threads with the lowest ids sort the array. */ + if ((ixj) > i) { + int posI = shape::getIndexOffset(i, xShapeInfo); + int posIXJ = shape::getIndexOffset(ixj, xShapeInfo); + + if ((i & k) == 0) { + /* Sort ascending */ + if (!descending == (x[posI] > x[posIXJ])) { + /* exchange(i,ixj); */ + T temp = x[posI]; + x[posI] = x[posIXJ]; + x[posIXJ] = temp; + } + } else if ((i & k) != 0) { + /* Sort descending */ + if (!descending == (x[posI] < x[posIXJ])) { + /* exchange(i,ixj); */ + T temp = x[posI]; + x[posI] = x[posIXJ]; + x[posIXJ] = temp; + } } + } } ////////////////////////////////////////////////////////////////////////// -template -__host__ void bitonicSortStepGeneric(dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong const* xShapeInfo, int j, int k, int length, bool descending) { - bitonicSortStepKernel<<>>(vx, xShapeInfo, j, k, length, descending); +template +__host__ void bitonicSortStepGeneric(dim3 &launchDims, cudaStream_t *stream, + void *vx, Nd4jLong const *xShapeInfo, + int j, int k, int length, + bool descending) { + bitonicSortStepKernel + <<>>( + vx, xShapeInfo, j, k, length, descending); } ////////////////////////////////////////////////////////////////////////// template -__host__ void bitonicSortStepGenericKey(dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong const* xShapeInfo, void *vy, Nd4jLong const* yShapeInfo, int j, int k, int length, bool descending) { - bitonicSortStepKernelKey<<>>(vx, xShapeInfo, vy, yShapeInfo, j, k, length, descending); +__host__ void bitonicSortStepGenericKey(dim3 &launchDims, cudaStream_t *stream, + void *vx, Nd4jLong const *xShapeInfo, + void *vy, Nd4jLong const *yShapeInfo, + int j, int k, int length, + bool descending) { + bitonicSortStepKernelKey + <<>>( + vx, xShapeInfo, vy, yShapeInfo, j, k, length, descending); } - -BUILD_SINGLE_TEMPLATE(template void ND4J_EXPORT bitonicSortStepGeneric, (dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong const* xShapeInfo, int j, int k, int length, bool descending), LIBND4J_TYPES); -BUILD_DOUBLE_TEMPLATE(template void ND4J_EXPORT bitonicSortStepGenericKey, (dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong const* xShapeInfo, void *vy, Nd4jLong const* yShapeInfo, int j, int k, int length, bool descending), LIBND4J_TYPES, LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template void SD_EXPORT bitonicSortStepGeneric, + (dim3 & launchDims, cudaStream_t *stream, void *vx, + Nd4jLong const *xShapeInfo, int j, int k, int length, + bool descending), + LIBND4J_TYPES); +BUILD_DOUBLE_TEMPLATE(template void SD_EXPORT bitonicSortStepGenericKey, + (dim3 & launchDims, cudaStream_t *stream, void *vx, + Nd4jLong const *xShapeInfo, void *vy, + Nd4jLong const *yShapeInfo, int j, int k, int length, + bool descending), + LIBND4J_TYPES, LIBND4J_TYPES); diff --git a/libnd4j/include/loops/cuda/specials/concatKernel.cu b/libnd4j/include/loops/cuda/specials/concatKernel.cu index a4a849e49ec2..cbd9e258425e 100644 --- a/libnd4j/include/loops/cuda/specials/concatKernel.cu +++ b/libnd4j/include/loops/cuda/specials/concatKernel.cu @@ -23,248 +23,248 @@ namespace sd { /////////////////////////////////////////////////////////////////////// - template - __device__ void concatKernel(int numArrays, - Nd4jPointer *data, Nd4jPointer *inputShapeInfos, - void *vz, Nd4jLong *resultShapeInfo, - Nd4jPointer *tadPointers, Nd4jPointer *offsetPointers, - Nd4jLong *zTadShape, Nd4jLong *zOffsets) { - - int tid = threadIdx.x + blockIdx.x * blockDim.x; - - int zRank = shape::rank(resultShapeInfo); - - auto result = reinterpret_cast(vz); - auto dataT = reinterpret_cast(data); - auto shapeInfoPointers = reinterpret_cast(inputShapeInfos); - auto tadShapes = reinterpret_cast(tadPointers); - auto tadOffsets = reinterpret_cast(offsetPointers); - - //if (threadIdx.x == 0 && blockIdx.x == 0) { - // shape::printShapeInfoLinear("zTadShape", zTadShape); - //} - - //__shared__ int tDim[1]; - __shared__ int baseIdx; - - __shared__ int yLength; - __shared__ char yOrder; - __shared__ int yEWS; - - char zOrder = shape::order(resultShapeInfo); - - int zEWS = shape::elementWiseStride(resultShapeInfo); - int tadEWS = shape::elementWiseStride(zTadShape); - int zLength = shape::length(resultShapeInfo); - - __shared__ int arrOffset; - __shared__ int numTads; - - - if (shape::isVector(resultShapeInfo)) { - //if (threadIdx.x == 0 && blockIdx.x == 0) - // printf("Vector here\n"); - - if (zEWS >= 1) { - for (int r = blockIdx.x; r < numArrays; r += gridDim.x) { - if(shape::isVector(shapeInfoPointers[r]) || shape::order(shapeInfoPointers[r]) == shape::order(resultShapeInfo)) { - yLength = shape::length(shapeInfoPointers[r]); - yEWS = shape::elementWiseStride(shapeInfoPointers[r]); - // FIXME: this is bad - __shared__ int baseIdx; - if (threadIdx.x == 0) { - baseIdx = 0; - for (int f = 0; f < r; f++) { - baseIdx += shape::length(shapeInfoPointers[f]); - } - } - __syncthreads(); - for (int i = threadIdx.x; i < yLength && baseIdx + i < zLength; i += blockDim.x) { - result[baseIdx + i * zEWS] = dataT[r][i * yEWS]; - } - __syncthreads(); - } else { - if (tid == 0) - printf("Non-matched order for vector\n"); - } - } - } else { - if (tid == 0) - printf("Vector Non-1 zEWS\n"); +template +__device__ void concatKernel(int numArrays, Nd4jPointer *data, + Nd4jPointer *inputShapeInfos, void *vz, + Nd4jLong *resultShapeInfo, + Nd4jPointer *tadPointers, + Nd4jPointer *offsetPointers, Nd4jLong *zTadShape, + Nd4jLong *zOffsets) { + int tid = threadIdx.x + blockIdx.x * blockDim.x; + + int zRank = shape::rank(resultShapeInfo); + + auto result = reinterpret_cast(vz); + auto dataT = reinterpret_cast(data); + auto shapeInfoPointers = reinterpret_cast(inputShapeInfos); + auto tadShapes = reinterpret_cast(tadPointers); + auto tadOffsets = reinterpret_cast(offsetPointers); + + // if (threadIdx.x == 0 && blockIdx.x == 0) { + // shape::printShapeInfoLinear("zTadShape", zTadShape); + //} + + //__shared__ int tDim[1]; + __shared__ int baseIdx; + + __shared__ int yLength; + __shared__ char yOrder; + __shared__ int yEWS; + + char zOrder = shape::order(resultShapeInfo); + + int zEWS = shape::elementWiseStride(resultShapeInfo); + int tadEWS = shape::elementWiseStride(zTadShape); + int zLength = shape::length(resultShapeInfo); + + __shared__ int arrOffset; + __shared__ int numTads; + + if (shape::isVector(resultShapeInfo)) { + // if (threadIdx.x == 0 && blockIdx.x == 0) + // printf("Vector here\n"); + + if (zEWS >= 1) { + for (int r = blockIdx.x; r < numArrays; r += gridDim.x) { + if (shape::isVector(shapeInfoPointers[r]) || + shape::order(shapeInfoPointers[r]) == + shape::order(resultShapeInfo)) { + yLength = shape::length(shapeInfoPointers[r]); + yEWS = shape::elementWiseStride(shapeInfoPointers[r]); + // FIXME: this is bad + __shared__ int baseIdx; + if (threadIdx.x == 0) { + baseIdx = 0; + for (int f = 0; f < r; f++) { + baseIdx += shape::length(shapeInfoPointers[f]); } - return; + } + __syncthreads(); + for (int i = threadIdx.x; i < yLength && baseIdx + i < zLength; + i += blockDim.x) { + result[baseIdx + i * zEWS] = dataT[r][i * yEWS]; + } + __syncthreads(); + } else { + if (tid == 0) printf("Non-matched order for vector\n"); } + } + } else { + if (tid == 0) printf("Vector Non-1 zEWS\n"); + } + return; + } + + bool _vec = shape::isVector(resultShapeInfo); + + // TODO: to be pulled into separate kernel. matrix concatenation + for (int r = 0; r < numArrays; r++) { + auto currentShape = shapeInfoPointers[r]; + auto currentData = dataT[r]; + auto currentTad = tadShapes[r]; + auto currentOffsets = tadOffsets[r]; + + if (threadIdx.x == 0) { + yLength = shape::length(currentTad); + yOrder = shape::order(currentTad); + yEWS = shape::elementWiseStride(currentTad); + numTads = shape::length(currentShape) / yLength; + + arrOffset = 0; + for (int f = 0; f < r; f++) { + arrOffset += shape::length(tadShapes[f]); + } + + // if (threadIdx.x == 0 && blockIdx.x == 0) { + // shape::printShapeInfoLinear("currentTad", currentTad); + //} + } + __syncthreads(); + if (yLength == 1 && _vec) { + // if (threadIdx.x == 0 && blockIdx.x == 0) + // printf("Branch 0\n"); - bool _vec = shape::isVector(resultShapeInfo); + // edge case, each thread will handle it's own tad then + for (int j = tid; j < numTads; j += blockDim.x * gridDim.x) { + Nd4jLong inputOffset = currentOffsets[j]; + Nd4jLong resultOffset = zOffsets[j]; + T *dataTAD = currentData + inputOffset; + T *resultTAD = result + resultOffset; - // TODO: to be pulled into separate kernel. matrix concatenation - for (int r = 0; r < numArrays; r ++) { + int sub[MAX_RANK]; - auto currentShape = shapeInfoPointers[r]; - auto currentData = dataT[r]; - auto currentTad = tadShapes[r]; - auto currentOffsets = tadOffsets[r]; + shape::index2coords(arrOffset, zTadShape, sub); + Nd4jLong baseOffset = shape::getOffset(zTadShape, sub); - if (threadIdx.x == 0) { - yLength = shape::length(currentTad); - yOrder = shape::order(currentTad); - yEWS = shape::elementWiseStride(currentTad); - numTads = shape::length(currentShape) / yLength; - - arrOffset = 0; - for (int f = 0; f < r; f++) { - arrOffset += shape::length(tadShapes[f]); - } - - //if (threadIdx.x == 0 && blockIdx.x == 0) { - // shape::printShapeInfoLinear("currentTad", currentTad); - //} - } - __syncthreads(); + resultTAD += baseOffset; + + auto yRank = shape::rank(currentTad); + auto tadRank = shape::rank(zTadShape); + + shape::index2coords(0, currentTad, sub); - if (yLength == 1 && _vec) { - //if (threadIdx.x == 0 && blockIdx.x == 0) - // printf("Branch 0\n"); + auto yOffset = shape::getOffset(currentTad, sub); + resultOffset = shape::getOffset(zTadShape, sub); - // edge case, each thread will handle it's own tad then - for (int j = tid; j < numTads; j += blockDim.x * gridDim.x) { - Nd4jLong inputOffset = currentOffsets[j]; - Nd4jLong resultOffset = zOffsets[j]; + resultTAD[resultOffset] = dataTAD[yOffset]; + } + } else { + // if (threadIdx.x == 0 && blockIdx.x == 0) + // printf("Branch 1\n"); - T *dataTAD = currentData + inputOffset; - T *resultTAD = result + resultOffset; + for (int j = blockIdx.x; j < numTads; j += gridDim.x) { + auto inputOffset = currentOffsets[j]; + auto resultOffset = zOffsets[j]; - int sub[MAX_RANK]; + auto dataTAD = currentData + inputOffset; + auto resultTAD = result + resultOffset; - shape::index2coords(arrOffset, zTadShape, sub); + int sub[MAX_RANK]; - Nd4jLong baseOffset = shape::getOffset(zTadShape, sub); + shape::index2coords(arrOffset, zTadShape, sub); + Nd4jLong baseOffset = shape::getOffset(zTadShape, sub); - resultTAD += baseOffset; + resultTAD += baseOffset; - auto yRank = shape::rank(currentTad); - auto tadRank = shape::rank(zTadShape); + if (zOrder == yOrder && yEWS > 0 && tadEWS > 0) { + // if (threadIdx.x == 0 && blockIdx.x == 0) + // printf("Branch A\n"); - shape::index2coords(0, currentTad, sub); + for (int i = threadIdx.x; i < yLength; i += blockDim.x) { + resultTAD[i * tadEWS] = dataTAD[i * yEWS]; + } + } else { + if (tadEWS > 0 && + shape::order(resultShapeInfo) == shape::order(currentTad)) { + // if (threadIdx.x == 0 && blockIdx.x == 0) + // printf("Branch B\n"); - auto yOffset = shape::getOffset(currentTad, sub); - resultOffset = shape::getOffset(zTadShape, sub); + if (threadIdx.x == 0) { + baseIdx = 0; + for (int f = 0; f < r; f++) { + baseIdx += shape::length(shapeInfoPointers[f]); + } + // printf("R: %i; baseIdx: %i;\n", baseIdx); + } + __syncthreads(); - resultTAD[resultOffset] = dataTAD[yOffset]; - } + if (numTads == 1) { + for (int k = threadIdx.x; k < yLength; k += blockDim.x) { + resultTAD[baseIdx + k * tadEWS] = dataTAD[k]; + } } else { - //if (threadIdx.x == 0 && blockIdx.x == 0) - // printf("Branch 1\n"); - - for (int j = blockIdx.x; j < numTads; j += gridDim.x) { - auto inputOffset = currentOffsets[j]; - auto resultOffset = zOffsets[j]; - - auto dataTAD = currentData + inputOffset; - auto resultTAD = result + resultOffset; - - int sub[MAX_RANK]; - - shape::index2coords(arrOffset, zTadShape, sub); - Nd4jLong baseOffset = shape::getOffset(zTadShape, sub); - - resultTAD += baseOffset; - - if (zOrder == yOrder && yEWS > 0 && tadEWS > 0) { - //if (threadIdx.x == 0 && blockIdx.x == 0) - // printf("Branch A\n"); - - for (int i = threadIdx.x; i < yLength; i += blockDim.x) { - resultTAD[i * tadEWS] = dataTAD[i * yEWS]; - } - } else { - if(tadEWS > 0 && shape::order(resultShapeInfo) == shape::order(currentTad)) { - //if (threadIdx.x == 0 && blockIdx.x == 0) - // printf("Branch B\n"); - - if (threadIdx.x == 0) { - baseIdx = 0; - for (int f = 0; f < r; f++) { - baseIdx += shape::length(shapeInfoPointers[f]); - } - //printf("R: %i; baseIdx: %i;\n", baseIdx); - } - __syncthreads(); - - if (numTads == 1) { - for(int k = threadIdx.x; k < yLength; k+= blockDim.x) { - resultTAD[baseIdx + k * tadEWS] = dataTAD[k]; - } - } else { - int yIdx[MAX_RANK]; - auto yRank = shape::rank(currentTad); - - for (int i = threadIdx.x; i < yLength; i+= blockDim.x) { - shape::index2coords(i, currentTad, yIdx); - auto yOffset = shape::getOffset(currentTad, yIdx); - - resultTAD[baseIdx + i * tadEWS] = dataTAD[yOffset]; - } - } - __syncthreads(); - } else { - //if (threadIdx.x == 0 && blockIdx.x == 0) - // printf("Branch C; yLength: %i;\n", yLength); - - int zIdx[MAX_RANK]; - int yIdx[MAX_RANK]; - auto yRank = shape::rank(currentTad); - auto tadRank = shape::rank(zTadShape); - - for (int i = threadIdx.x; i < yLength; i+= blockDim.x) { - shape::index2coords(i, currentTad, yIdx); - shape::index2coords(i, zTadShape, zIdx); - - auto yOffset = shape::getOffset(currentTad, yIdx); - auto resultOffset = shape::getOffset(zTadShape, zIdx); - - resultTAD[resultOffset] = dataTAD[yOffset]; - } - } - } - __syncthreads(); - } + int yIdx[MAX_RANK]; + auto yRank = shape::rank(currentTad); + + for (int i = threadIdx.x; i < yLength; i += blockDim.x) { + shape::index2coords(i, currentTad, yIdx); + auto yOffset = shape::getOffset(currentTad, yIdx); + + resultTAD[baseIdx + i * tadEWS] = dataTAD[yOffset]; + } } __syncthreads(); + } else { + // if (threadIdx.x == 0 && blockIdx.x == 0) + // printf("Branch C; yLength: %i;\n", yLength); + + int zIdx[MAX_RANK]; + int yIdx[MAX_RANK]; + auto yRank = shape::rank(currentTad); + auto tadRank = shape::rank(zTadShape); + + for (int i = threadIdx.x; i < yLength; i += blockDim.x) { + shape::index2coords(i, currentTad, yIdx); + shape::index2coords(i, zTadShape, zIdx); + + auto yOffset = shape::getOffset(currentTad, yIdx); + auto resultOffset = shape::getOffset(zTadShape, zIdx); + + resultTAD[resultOffset] = dataTAD[yOffset]; + } + } } + __syncthreads(); + } } + __syncthreads(); + } +} /////////////////////////////////////////////////////////////////////// - template - __global__ void execConcatKernel(int numArrays, - Nd4jPointer *data, Nd4jPointer *inputShapeInfos, - void *vz, Nd4jLong *zShapeInfo, - Nd4jPointer *tadPointers, Nd4jPointer *offsetPointers, - Nd4jLong *zTadShape, - Nd4jLong *zOffsets) { - - concatKernel(numArrays, data, inputShapeInfos, vz, zShapeInfo, tadPointers, offsetPointers, zTadShape, - zOffsets); - } - +template +__global__ void execConcatKernel(int numArrays, Nd4jPointer *data, + Nd4jPointer *inputShapeInfos, void *vz, + Nd4jLong *zShapeInfo, Nd4jPointer *tadPointers, + Nd4jPointer *offsetPointers, + Nd4jLong *zTadShape, Nd4jLong *zOffsets) { + concatKernel(numArrays, data, inputShapeInfos, vz, zShapeInfo, tadPointers, + offsetPointers, zTadShape, zOffsets); +} /////////////////////////////////////////////////////////////////////// - template - __host__ void concatKernelGeneric(dim3 &launchDims, cudaStream_t *stream, - int numArrays, - Nd4jPointer *data, Nd4jPointer *inputShapeInfos, - void *vz, Nd4jLong *zShapeInfo, - Nd4jPointer *tadPointers, Nd4jPointer *offsetPointers, - Nd4jLong *zTadShape, - Nd4jLong *zOffsets) { - - - execConcatKernel<<>>(numArrays, data, inputShapeInfos, vz, zShapeInfo, tadPointers, offsetPointers, zTadShape, zOffsets); - sd::DebugHelper::checkErrorCode(stream, "concatGenericLegacy(...) failed"); - } - - BUILD_SINGLE_TEMPLATE(template void ND4J_EXPORT concatKernelGeneric, (dim3 & launchDims, cudaStream_t * stream, int numArrays, Nd4jPointer * data, Nd4jPointer * inputShapeInfos, void * vz, Nd4jLong *zShapeInfo, Nd4jPointer * tadPointers, Nd4jPointer * offsetPointers, Nd4jLong * zTadShape, Nd4jLong * zOffsets), LIBND4J_TYPES); -} \ No newline at end of file +template +__host__ void concatKernelGeneric(dim3 &launchDims, cudaStream_t *stream, + int numArrays, Nd4jPointer *data, + Nd4jPointer *inputShapeInfos, void *vz, + Nd4jLong *zShapeInfo, + Nd4jPointer *tadPointers, + Nd4jPointer *offsetPointers, + Nd4jLong *zTadShape, Nd4jLong *zOffsets) { + execConcatKernel<<>>( + numArrays, data, inputShapeInfos, vz, zShapeInfo, tadPointers, + offsetPointers, zTadShape, zOffsets); + sd::DebugHelper::checkErrorCode(stream, "concatGenericLegacy(...) failed"); +} + +BUILD_SINGLE_TEMPLATE(template void SD_EXPORT concatKernelGeneric, + (dim3 & launchDims, cudaStream_t *stream, int numArrays, + Nd4jPointer *data, Nd4jPointer *inputShapeInfos, + void *vz, Nd4jLong *zShapeInfo, Nd4jPointer *tadPointers, + Nd4jPointer *offsetPointers, Nd4jLong *zTadShape, + Nd4jLong *zOffsets), + LIBND4J_TYPES); +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/specials/concatKernelHStack.cu b/libnd4j/include/loops/cuda/specials/concatKernelHStack.cu index 8ef9dfd24838..8fcf292558f9 100644 --- a/libnd4j/include/loops/cuda/specials/concatKernelHStack.cu +++ b/libnd4j/include/loops/cuda/specials/concatKernelHStack.cu @@ -24,72 +24,74 @@ namespace sd { /////////////////////////////////////////////////////////////////////// - template - __device__ void concatKernelHStack(int numArrays, - Nd4jPointer *data, Nd4jPointer *inputShapeInfos, - void *vz, Nd4jLong *zShapeInfo) { - - // we expect all data coming in as vectors, and z as 2D matrix - // the only significant difference here is the fact that input lengths might be different - auto z = reinterpret_cast(vz); - auto inputShapes = (Nd4jLong **) inputShapeInfos; - T **input = (T **) data; - - __shared__ int inputEWS; - __shared__ int resultEWS; - __shared__ int inputLength; - - if (threadIdx.x == 0) { - resultEWS = shape::elementWiseStride(zShapeInfo); - } - __syncthreads(); - - for (int r = blockIdx.x; r < numArrays; r += gridDim.x) { - - __shared__ int baseIdx; - if (threadIdx.x == 0) { - baseIdx = 0; - for (int f = 0; f < r; f++) { - baseIdx += shape::length(inputShapes[f]); - } - } - __syncthreads(); - - - T *inputData = (T *) input[r]; +template +__device__ void concatKernelHStack(int numArrays, Nd4jPointer *data, + Nd4jPointer *inputShapeInfos, void *vz, + Nd4jLong *zShapeInfo) { + // we expect all data coming in as vectors, and z as 2D matrix + // the only significant difference here is the fact that input lengths might + // be different + auto z = reinterpret_cast(vz); + auto inputShapes = (Nd4jLong **)inputShapeInfos; + T **input = (T **)data; + + __shared__ int inputEWS; + __shared__ int resultEWS; + __shared__ int inputLength; + + if (threadIdx.x == 0) { + resultEWS = shape::elementWiseStride(zShapeInfo); + } + __syncthreads(); + + for (int r = blockIdx.x; r < numArrays; r += gridDim.x) { + __shared__ int baseIdx; + if (threadIdx.x == 0) { + baseIdx = 0; + for (int f = 0; f < r; f++) { + baseIdx += shape::length(inputShapes[f]); + } + } + __syncthreads(); - if (threadIdx.x == 0) { - inputEWS = shape::elementWiseStride(inputShapes[r]); - inputLength = shape::length(inputShapes[r]); - } - __syncthreads(); + T *inputData = (T *)input[r]; - for (int i = threadIdx.x; i < inputLength; i += blockDim.x) { - z[baseIdx + i * resultEWS] = inputData[i * inputEWS]; - } - __syncthreads(); - } + if (threadIdx.x == 0) { + inputEWS = shape::elementWiseStride(inputShapes[r]); + inputLength = shape::length(inputShapes[r]); } + __syncthreads(); -/////////////////////////////////////////////////////////////////////// - template - __global__ void execConcatKernelHStack(int numArrays, - Nd4jPointer *data, Nd4jPointer *inputShapeInfos, - void *vz, Nd4jLong *zShapeInfo) { - - concatKernelHStack(numArrays, data, inputShapeInfos, vz, zShapeInfo); + for (int i = threadIdx.x; i < inputLength; i += blockDim.x) { + z[baseIdx + i * resultEWS] = inputData[i * inputEWS]; } + __syncthreads(); + } +} /////////////////////////////////////////////////////////////////////// - template - __host__ void concatKernelHStackGeneric(dim3 &launchDims, cudaStream_t *stream, - int numArrays, - Nd4jPointer *data, Nd4jPointer *inputShapeInfos, - void *vz, Nd4jLong *zShapeInfo) { +template +__global__ void execConcatKernelHStack(int numArrays, Nd4jPointer *data, + Nd4jPointer *inputShapeInfos, void *vz, + Nd4jLong *zShapeInfo) { + concatKernelHStack(numArrays, data, inputShapeInfos, vz, zShapeInfo); +} - execConcatKernelHStack<<>>(numArrays, data, inputShapeInfos, vz, zShapeInfo); - sd::DebugHelper::checkErrorCode(stream, "concatHStack(...) failed"); - } - - BUILD_SINGLE_TEMPLATE(template void ND4J_EXPORT concatKernelHStackGeneric, (dim3 & launchDims, cudaStream_t * stream, int numArrays, Nd4jPointer * data, Nd4jPointer * inputShapeInfos, void * vz, Nd4jLong * zShapeInfo), LIBND4J_TYPES); -} \ No newline at end of file +/////////////////////////////////////////////////////////////////////// +template +__host__ void concatKernelHStackGeneric(dim3 &launchDims, cudaStream_t *stream, + int numArrays, Nd4jPointer *data, + Nd4jPointer *inputShapeInfos, void *vz, + Nd4jLong *zShapeInfo) { + execConcatKernelHStack + <<>>( + numArrays, data, inputShapeInfos, vz, zShapeInfo); + sd::DebugHelper::checkErrorCode(stream, "concatHStack(...) failed"); +} + +BUILD_SINGLE_TEMPLATE(template void SD_EXPORT concatKernelHStackGeneric, + (dim3 & launchDims, cudaStream_t *stream, int numArrays, + Nd4jPointer *data, Nd4jPointer *inputShapeInfos, + void *vz, Nd4jLong *zShapeInfo), + LIBND4J_TYPES); +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/specials/concatKernelScalar.cu b/libnd4j/include/loops/cuda/specials/concatKernelScalar.cu index 6614480f25ec..10f1663f3ce6 100644 --- a/libnd4j/include/loops/cuda/specials/concatKernelScalar.cu +++ b/libnd4j/include/loops/cuda/specials/concatKernelScalar.cu @@ -24,32 +24,36 @@ namespace sd { /////////////////////////////////////////////////////////////////////// - template - __device__ void concatKernelScalar(int numArrays, Nd4jPointer *data, void *vz) { +template +__device__ void concatKernelScalar(int numArrays, Nd4jPointer *data, void *vz) { + auto z = static_cast(vz); + Nd4jLong tid = blockIdx.x * blockDim.x + threadIdx.x; + auto input = reinterpret_cast(data); - auto z = static_cast(vz); - Nd4jLong tid = blockIdx.x * blockDim.x + threadIdx.x; - auto input = reinterpret_cast(data); - - for (int i = tid; i < numArrays; i += blockDim.x * gridDim.x) - z[i] = input[i][0]; - } + for (int i = tid; i < numArrays; i += blockDim.x * gridDim.x) + z[i] = input[i][0]; +} /////////////////////////////////////////////////////////////////////// - template - __global__ void execConcatKernelScalar(int numArrays, Nd4jPointer *data, void *vz) { - - concatKernelScalar(numArrays, data, vz); - } +template +__global__ void execConcatKernelScalar(int numArrays, Nd4jPointer *data, + void *vz) { + concatKernelScalar(numArrays, data, vz); +} /////////////////////////////////////////////////////////////////////// - template - __host__ void - concatKernelScalarGeneric(dim3 &launchDims, cudaStream_t *stream, int numArrays, Nd4jPointer *data, void *vz) { - - execConcatKernelScalar<<>>(numArrays, data, vz); - sd::DebugHelper::checkErrorCode(stream, "concatScalar(...) failed"); - } - - BUILD_SINGLE_TEMPLATE(template void ND4J_EXPORT concatKernelScalarGeneric, (dim3 & launchDims, cudaStream_t * stream, int numArrays, Nd4jPointer * data, void * vz), LIBND4J_TYPES); -} \ No newline at end of file +template +__host__ void concatKernelScalarGeneric(dim3 &launchDims, cudaStream_t *stream, + int numArrays, Nd4jPointer *data, + void *vz) { + execConcatKernelScalar + <<>>(numArrays, data, + vz); + sd::DebugHelper::checkErrorCode(stream, "concatScalar(...) failed"); +} + +BUILD_SINGLE_TEMPLATE(template void SD_EXPORT concatKernelScalarGeneric, + (dim3 & launchDims, cudaStream_t *stream, int numArrays, + Nd4jPointer *data, void *vz), + LIBND4J_TYPES); +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/specials/concatKernelVStack.cu b/libnd4j/include/loops/cuda/specials/concatKernelVStack.cu index f95bad413dd1..2d4cbee4ba03 100644 --- a/libnd4j/include/loops/cuda/specials/concatKernelVStack.cu +++ b/libnd4j/include/loops/cuda/specials/concatKernelVStack.cu @@ -24,62 +24,64 @@ namespace sd { /////////////////////////////////////////////////////////////////////// - template - __device__ void concatKernelVStack(int numArrays, - Nd4jPointer *data, Nd4jPointer *inputShapeInfos, - void *vz, Nd4jLong *zShapeInfo) { - - /* - this is special case for concat: we group bunch of vectors into 2D matrix - also: we expect each inputShapeInfo to have EWS, be a vector, and have equal size - */ - auto z = static_cast(vz); - - auto inputShapes = (Nd4jLong **) inputShapeInfos; - T **input = (T **) data; - - __shared__ int inputEWS; - __shared__ int resultEWS; - __shared__ int inputLength; - - if (threadIdx.x == 0) { - inputLength = shape::length(inputShapes[0]); - inputEWS = shape::elementWiseStride(inputShapes[0]); - resultEWS = shape::elementWiseStride(zShapeInfo); - } - __syncthreads(); - - for (int r = blockIdx.x; r < numArrays; r += gridDim.x) { - - int zOffset = r * inputLength * resultEWS; - T *inputData = (T *) input[r]; - - for (int i = threadIdx.x; i < inputLength; i += blockDim.x) { - z[zOffset + i * resultEWS] = inputData[i * inputEWS]; - } - } +template +__device__ void concatKernelVStack(int numArrays, Nd4jPointer *data, + Nd4jPointer *inputShapeInfos, void *vz, + Nd4jLong *zShapeInfo) { + /* + this is special case for concat: we group bunch of vectors into 2D matrix + also: we expect each inputShapeInfo to have EWS, be a vector, and have equal + size + */ + auto z = static_cast(vz); + + auto inputShapes = (Nd4jLong **)inputShapeInfos; + T **input = (T **)data; + + __shared__ int inputEWS; + __shared__ int resultEWS; + __shared__ int inputLength; + + if (threadIdx.x == 0) { + inputLength = shape::length(inputShapes[0]); + inputEWS = shape::elementWiseStride(inputShapes[0]); + resultEWS = shape::elementWiseStride(zShapeInfo); + } + __syncthreads(); + + for (int r = blockIdx.x; r < numArrays; r += gridDim.x) { + int zOffset = r * inputLength * resultEWS; + T *inputData = (T *)input[r]; + + for (int i = threadIdx.x; i < inputLength; i += blockDim.x) { + z[zOffset + i * resultEWS] = inputData[i * inputEWS]; } + } +} /////////////////////////////////////////////////////////////////////// - template - __global__ void execConcatKernelVStack(int numArrays, - Nd4jPointer *data, Nd4jPointer *inputShapeInfos, - void *vz, Nd4jLong *zShapeInfo) { - - concatKernelVStack(numArrays, data, inputShapeInfos, vz, zShapeInfo); - } - +template +__global__ void execConcatKernelVStack(int numArrays, Nd4jPointer *data, + Nd4jPointer *inputShapeInfos, void *vz, + Nd4jLong *zShapeInfo) { + concatKernelVStack(numArrays, data, inputShapeInfos, vz, zShapeInfo); +} /////////////////////////////////////////////////////////////////////// - template - __host__ void concatKernelVStackGeneric(dim3 &launchDims, cudaStream_t *stream, - int numArrays, - Nd4jPointer *data, Nd4jPointer *inputShapeInfos, - void *vz, Nd4jLong *zShapeInfo) { - - execConcatKernelVStack<<>>(numArrays, data, inputShapeInfos, vz, zShapeInfo); - sd::DebugHelper::checkErrorCode(stream, "concatVStack(...) failed"); - } - - BUILD_SINGLE_TEMPLATE(template void ND4J_EXPORT concatKernelVStackGeneric, (dim3 & launchDims, cudaStream_t * stream, int numArrays, Nd4jPointer * data, Nd4jPointer * inputShapeInfos, void * vz, Nd4jLong *zShapeInfo), LIBND4J_TYPES); -} \ No newline at end of file +template +__host__ void concatKernelVStackGeneric(dim3 &launchDims, cudaStream_t *stream, + int numArrays, Nd4jPointer *data, + Nd4jPointer *inputShapeInfos, void *vz, + Nd4jLong *zShapeInfo) { + execConcatKernelVStack + <<>>( + numArrays, data, inputShapeInfos, vz, zShapeInfo); + sd::DebugHelper::checkErrorCode(stream, "concatVStack(...) failed"); +} + +BUILD_SINGLE_TEMPLATE(template void SD_EXPORT concatKernelVStackGeneric, + (dim3 & launchDims, cudaStream_t *stream, int numArrays, + Nd4jPointer *data, Nd4jPointer *inputShapeInfos, + void *vz, Nd4jLong *zShapeInfo), + LIBND4J_TYPES); +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/specials/convertHalfs.cu b/libnd4j/include/loops/cuda/specials/convertHalfs.cu index dec1705a42d3..4c659c062553 100644 --- a/libnd4j/include/loops/cuda/specials/convertHalfs.cu +++ b/libnd4j/include/loops/cuda/specials/convertHalfs.cu @@ -24,23 +24,26 @@ namespace sd { /////////////////////////////////////////////////////////////////////// - template - __global__ void execConvertHalfs(half *dx, Nd4jLong n, void *dz) { - auto z = reinterpret_cast(dz); - int tid = threadIdx.x + blockIdx.x * blockDim.x; - - for (Nd4jLong i = tid; i < n; i += blockDim.x * gridDim.x) - z[i] = static_cast(__half2float(dx[i])); - } +template +__global__ void execConvertHalfs(half *dx, Nd4jLong n, void *dz) { + auto z = reinterpret_cast(dz); + int tid = threadIdx.x + blockIdx.x * blockDim.x; + for (Nd4jLong i = tid; i < n; i += blockDim.x * gridDim.x) + z[i] = static_cast(__half2float(dx[i])); +} /////////////////////////////////////////////////////////////////////// - template - __host__ void convertHalfsToGeneric(dim3 &launchDims, cudaStream_t *stream, half *dx, Nd4jLong n, void *dz) { - - execConvertHalfs<<>>(dx, n, dz); - sd::DebugHelper::checkErrorCode(stream, "convertHalfsToGeneric(...) failed"); - } - - BUILD_SINGLE_TEMPLATE(template void ND4J_EXPORT convertHalfsToGeneric, (dim3 & launchDims, cudaStream_t * stream, half * dx, Nd4jLong n, void * dz), LIBND4J_TYPES); -} \ No newline at end of file +template +__host__ void convertHalfsToGeneric(dim3 &launchDims, cudaStream_t *stream, + half *dx, Nd4jLong n, void *dz) { + execConvertHalfs + <<>>(dx, n, dz); + sd::DebugHelper::checkErrorCode(stream, "convertHalfsToGeneric(...) failed"); +} + +BUILD_SINGLE_TEMPLATE(template void SD_EXPORT convertHalfsToGeneric, + (dim3 & launchDims, cudaStream_t *stream, half *dx, + Nd4jLong n, void *dz), + LIBND4J_TYPES); +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/specials/convertToHalf.cu b/libnd4j/include/loops/cuda/specials/convertToHalf.cu index d86982d03585..fc6b8801e15f 100644 --- a/libnd4j/include/loops/cuda/specials/convertToHalf.cu +++ b/libnd4j/include/loops/cuda/specials/convertToHalf.cu @@ -24,22 +24,27 @@ namespace sd { //////////////////////////////////////////////////////////////////////// - template - __global__ void execConvertToHalf(void *dx, Nd4jLong n, half *dz) { - auto x = reinterpret_cast(dx); - int tid = threadIdx.x + blockIdx.x * blockDim.x; +template +__global__ void execConvertToHalf(void *dx, Nd4jLong n, half *dz) { + auto x = reinterpret_cast(dx); + int tid = threadIdx.x + blockIdx.x * blockDim.x; - for (Nd4jLong i = tid; i < n; i += blockDim.x * gridDim.x) - dz[i] = __float2half(static_cast(x[i])); - } + for (Nd4jLong i = tid; i < n; i += blockDim.x * gridDim.x) + dz[i] = __float2half(static_cast(x[i])); +} //////////////////////////////////////////////////////////////////////// - template - __host__ void convertToHalfGeneric(dim3 &launchDims, cudaStream_t *stream, void *dx, Nd4jLong n, half *dz) { - execConvertToHalf<<>>(dx, n, dz); - sd::DebugHelper::checkErrorCode(stream, "convertToHalfs(...) failed"); - } - - BUILD_SINGLE_TEMPLATE(template void ND4J_EXPORT convertToHalfGeneric, (dim3 & launchDims, cudaStream_t * stream, void * dx, Nd4jLong n, half * dz), LIBND4J_TYPES); - -} \ No newline at end of file +template +__host__ void convertToHalfGeneric(dim3 &launchDims, cudaStream_t *stream, + void *dx, Nd4jLong n, half *dz) { + execConvertToHalf + <<>>(dx, n, dz); + sd::DebugHelper::checkErrorCode(stream, "convertToHalfs(...) failed"); +} + +BUILD_SINGLE_TEMPLATE(template void SD_EXPORT convertToHalfGeneric, + (dim3 & launchDims, cudaStream_t *stream, void *dx, + Nd4jLong n, half *dz), + LIBND4J_TYPES); + +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/specials/fillDimensionalIsMax.cu b/libnd4j/include/loops/cuda/specials/fillDimensionalIsMax.cu index 409f84cc69ed..3b278e67b7f2 100644 --- a/libnd4j/include/loops/cuda/specials/fillDimensionalIsMax.cu +++ b/libnd4j/include/loops/cuda/specials/fillDimensionalIsMax.cu @@ -23,72 +23,78 @@ namespace sd { +//////////////////////////////////////////////////////////////////////// +template +__device__ void fillDimensionalIsMax(const void *vdX, void *vdZ, + const Nd4jLong *zShapeInfo, + const Nd4jLong *tadOnlyShapeInfo, + int *dimension, int dimensionLength, + const Nd4jLong *tadOffsets) { + auto dX = reinterpret_cast(vdX); + auto dZ = reinterpret_cast(vdZ); + + __shared__ int tadLength; + __shared__ int tadEWS; + __shared__ int numTads; + + if (threadIdx.x == 0) { + tadLength = + shape::length(tadOnlyShapeInfo); // shape::tadLength(zShapeInfo, + // dimension, dimensionLength); + tadEWS = shape::elementWiseStride(tadOnlyShapeInfo); + numTads = shape::length(zShapeInfo) / tadLength; + } + __syncthreads(); + + for (int r = blockIdx.x; r < numTads; r += gridDim.x) { + auto tadOffsetForBlock = tadOffsets[r]; + auto highestElement = dX[r]; + + if (dimensionLength > 1 || tadEWS < 1) { + for (Nd4jLong e = threadIdx.x; e < tadLength; e += blockDim.x) { + auto xOffset = + tadOffsetForBlock + shape::getIndexOffset(e, tadOnlyShapeInfo); + dZ[xOffset] = (e == highestElement ? (T)1 : (T)0); + } + } else { + for (Nd4jLong e = threadIdx.x; e < tadLength; e += blockDim.x) { + // so, we just set dZ[e] for each TAD. Sure, e should be replaced with + auto idx = tadOffsetForBlock + (e * tadEWS); + dZ[idx] = (e == highestElement ? (T)1 : (T)0); + } + } + } +} //////////////////////////////////////////////////////////////////////// - template - __device__ void fillDimensionalIsMax(const void *vdX, - void *vdZ, const Nd4jLong *zShapeInfo, +template +__global__ void execfillDimensionalIsMax(const void *dX, void *dZ, + const Nd4jLong *zShapeInfo, const Nd4jLong *tadOnlyShapeInfo, int *dimension, int dimensionLength, const Nd4jLong *tadOffsets) { - - auto dX = reinterpret_cast(vdX); - auto dZ = reinterpret_cast(vdZ); - - __shared__ int tadLength; - __shared__ int tadEWS; - __shared__ int numTads; - - if (threadIdx.x == 0) { - tadLength = shape::length(tadOnlyShapeInfo);//shape::tadLength(zShapeInfo, dimension, dimensionLength); - tadEWS = shape::elementWiseStride(tadOnlyShapeInfo); - numTads = shape::length(zShapeInfo) / tadLength; - } - __syncthreads(); - - for (int r = blockIdx.x; r < numTads; r += gridDim.x) { - auto tadOffsetForBlock = tadOffsets[r]; - auto highestElement = dX[r]; - - if (dimensionLength > 1 || tadEWS < 1) { - - for (Nd4jLong e = threadIdx.x; e < tadLength; e += blockDim.x) { - auto xOffset = tadOffsetForBlock + shape::getIndexOffset(e, tadOnlyShapeInfo); - dZ[xOffset] = (e == highestElement ? (T) 1 : (T) 0); - } - } else { - for (Nd4jLong e = threadIdx.x; e < tadLength; e += blockDim.x) { - // so, we just set dZ[e] for each TAD. Sure, e should be replaced with - auto idx = tadOffsetForBlock + (e * tadEWS); - dZ[idx] = (e == highestElement ? (T) 1 : (T) 0); - } - } - } - } - - -//////////////////////////////////////////////////////////////////////// - template - __global__ void execfillDimensionalIsMax(const void *dX, - void *dZ, const Nd4jLong *zShapeInfo, - const Nd4jLong *tadOnlyShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadOffsets) { - - fillDimensionalIsMax(dX, dZ, zShapeInfo, tadOnlyShapeInfo, dimension, dimensionLength, tadOffsets); - } + fillDimensionalIsMax(dX, dZ, zShapeInfo, tadOnlyShapeInfo, dimension, + dimensionLength, tadOffsets); +} //////////////////////////////////////////////////////////////////////// - template - __host__ void fillDimensionalIsMaxGeneric(dim3 &launchDims, cudaStream_t *stream, - const void *dX, - void *dZ, const Nd4jLong *zShapeInfo, - const Nd4jLong *tadOnlyShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadOffsets) { - - execfillDimensionalIsMax<<>>(dX, dZ, zShapeInfo, tadOnlyShapeInfo, dimension, dimensionLength, tadOffsets); - sd::DebugHelper::checkErrorCode(stream, "fillDimensionalIsMax(...) failed"); - } - BUILD_SINGLE_TEMPLATE(template void ND4J_EXPORT fillDimensionalIsMaxGeneric, (dim3& launchDims, cudaStream_t *stream, const void *dX, void *dZ, const Nd4jLong *zShapeInfo, const Nd4jLong *tadOnlyShapeInfo, int *dimension, int dimensionLength, const Nd4jLong *tadOffsets), LIBND4J_TYPES); -} \ No newline at end of file +template +__host__ void fillDimensionalIsMaxGeneric(dim3 &launchDims, + cudaStream_t *stream, const void *dX, + void *dZ, const Nd4jLong *zShapeInfo, + const Nd4jLong *tadOnlyShapeInfo, + int *dimension, int dimensionLength, + const Nd4jLong *tadOffsets) { + execfillDimensionalIsMax + <<>>( + dX, dZ, zShapeInfo, tadOnlyShapeInfo, dimension, dimensionLength, + tadOffsets); + sd::DebugHelper::checkErrorCode(stream, "fillDimensionalIsMax(...) failed"); +} +BUILD_SINGLE_TEMPLATE(template void SD_EXPORT fillDimensionalIsMaxGeneric, + (dim3 & launchDims, cudaStream_t *stream, const void *dX, + void *dZ, const Nd4jLong *zShapeInfo, + const Nd4jLong *tadOnlyShapeInfo, int *dimension, + int dimensionLength, const Nd4jLong *tadOffsets), + LIBND4J_TYPES); +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/specials/fillIsMax.cu b/libnd4j/include/loops/cuda/specials/fillIsMax.cu index 00997b0220ad..909b5b6f87c5 100644 --- a/libnd4j/include/loops/cuda/specials/fillIsMax.cu +++ b/libnd4j/include/loops/cuda/specials/fillIsMax.cu @@ -24,22 +24,28 @@ namespace sd { //////////////////////////////////////////////////////////////////////// - template - __global__ void execFillIsMax(void *vdZ, const Nd4jLong *xShapeInfo, Nd4jLong length, long idx) { - auto dz = reinterpret_cast(vdZ); - int tid = blockIdx.x * blockDim.x + threadIdx.x; +template +__global__ void execFillIsMax(void *vdZ, const Nd4jLong *xShapeInfo, + Nd4jLong length, long idx) { + auto dz = reinterpret_cast(vdZ); + int tid = blockIdx.x * blockDim.x + threadIdx.x; - for (Nd4jLong i = tid; i < length; i += blockDim.x * gridDim.x) - dz[shape::getIndexOffset(i, xShapeInfo)] = (i == idx ? (T) 1 : (T) 0); - } + for (Nd4jLong i = tid; i < length; i += blockDim.x * gridDim.x) + dz[shape::getIndexOffset(i, xShapeInfo)] = (i == idx ? (T)1 : (T)0); +} //////////////////////////////////////////////////////////////////////// - template - __host__ void fillIsMaxGeneric(dim3 &launchDims, cudaStream_t *stream, void *dx, const Nd4jLong *xShapeInfo, Nd4jLong length, long idx) { - execFillIsMax<<>>(dx, xShapeInfo, length, idx); - sd::DebugHelper::checkErrorCode(stream, "fillIsMax(...) failed"); - } - - - BUILD_SINGLE_TEMPLATE(template void ND4J_EXPORT fillIsMaxGeneric, (dim3& launchDims, cudaStream_t *stream, void* dz, const Nd4jLong *zShapeInfo, Nd4jLong length, long idx), LIBND4J_TYPES); -} \ No newline at end of file +template +__host__ void fillIsMaxGeneric(dim3 &launchDims, cudaStream_t *stream, void *dx, + const Nd4jLong *xShapeInfo, Nd4jLong length, + long idx) { + execFillIsMax<<>>( + dx, xShapeInfo, length, idx); + sd::DebugHelper::checkErrorCode(stream, "fillIsMax(...) failed"); +} + +BUILD_SINGLE_TEMPLATE(template void SD_EXPORT fillIsMaxGeneric, + (dim3 & launchDims, cudaStream_t *stream, void *dz, + const Nd4jLong *zShapeInfo, Nd4jLong length, long idx), + LIBND4J_TYPES); +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/specials/flatten.cu b/libnd4j/include/loops/cuda/specials/flatten.cu index b0bbf58e12cb..0babe05b4d81 100644 --- a/libnd4j/include/loops/cuda/specials/flatten.cu +++ b/libnd4j/include/loops/cuda/specials/flatten.cu @@ -26,46 +26,44 @@ namespace sd { //////////////////////////////////////////////////////////////////////// template -__global__ void flattenKernel( - Nd4jPointer *extraPointers, - int dOffset, - char order, - void *vz, Nd4jLong *zShapeInfo, - void *vy, Nd4jLong *yShapeInfo) { +__global__ void flattenKernel(Nd4jPointer *extraPointers, int dOffset, + char order, void *vz, Nd4jLong *zShapeInfo, + void *vy, Nd4jLong *yShapeInfo) { + auto z = reinterpret_cast(vz); + auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); - auto y = reinterpret_cast(vy); + __shared__ Nd4jLong lenY, yOrder, zEWS, yEWS; - __shared__ Nd4jLong lenY, yOrder, zEWS, yEWS; + if (threadIdx.x == 0) { + yEWS = shape::elementWiseStride(yShapeInfo); + zEWS = shape::elementWiseStride(zShapeInfo); + lenY = shape::length(yShapeInfo); + } + __syncthreads(); - if (threadIdx.x == 0) { + Nd4jLong tid = blockIdx.x * blockDim.x + threadIdx.x; - yEWS = shape::elementWiseStride(yShapeInfo); - zEWS = shape::elementWiseStride(zShapeInfo); - lenY = shape::length(yShapeInfo); - } - __syncthreads(); - - Nd4jLong tid = blockIdx.x * blockDim.x + threadIdx.x; - - for(auto i = tid; i < lenY; i += gridDim.x * blockDim.x) - z[i * zEWS + dOffset] = y[ops::helpers::getIndexOffsetOrdered(i, yShapeInfo, order)]; + for (auto i = tid; i < lenY; i += gridDim.x * blockDim.x) + z[i * zEWS + dOffset] = + y[ops::helpers::getIndexOffsetOrdered(i, yShapeInfo, order)]; } //////////////////////////////////////////////////////////////////////// template -__host__ void flattenKernelGeneric(dim3& launchDims, cudaStream_t *stream, - Nd4jPointer *extraPointers, - int dOffset, - char order, - void *vz, Nd4jLong *zShapeInfo, - void *vy, Nd4jLong *yShapeInfo) { - - flattenKernel<<>>(extraPointers, dOffset, order, vz, zShapeInfo, vy, yShapeInfo); - sd::DebugHelper::checkErrorCode(stream, "flattenGeneric(...) failed"); +__host__ void flattenKernelGeneric(dim3 &launchDims, cudaStream_t *stream, + Nd4jPointer *extraPointers, int dOffset, + char order, void *vz, Nd4jLong *zShapeInfo, + void *vy, Nd4jLong *yShapeInfo) { + flattenKernel<<>>( + extraPointers, dOffset, order, vz, zShapeInfo, vy, yShapeInfo); + sd::DebugHelper::checkErrorCode(stream, "flattenGeneric(...) failed"); } -BUILD_SINGLE_TEMPLATE(template void ND4J_EXPORT flattenKernelGeneric, (dim3& launchDims, cudaStream_t *stream, Nd4jPointer *extraPointers, int dOffset, char order, void *vz, Nd4jLong *zShapeInfo, void *vy, Nd4jLong *yShapeInfo), LIBND4J_TYPES); - +BUILD_SINGLE_TEMPLATE(template void SD_EXPORT flattenKernelGeneric, + (dim3 & launchDims, cudaStream_t *stream, + Nd4jPointer *extraPointers, int dOffset, char order, + void *vz, Nd4jLong *zShapeInfo, void *vy, + Nd4jLong *yShapeInfo), + LIBND4J_TYPES); -} \ No newline at end of file +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/specials/oesTad.cu b/libnd4j/include/loops/cuda/specials/oesTad.cu index 6f08e23ad032..38d0e915e99a 100644 --- a/libnd4j/include/loops/cuda/specials/oesTad.cu +++ b/libnd4j/include/loops/cuda/specials/oesTad.cu @@ -22,184 +22,195 @@ ////////////////////////////////////////////////////////////////////////// template -__global__ void execOesTadKernelKey(void *vx, Nd4jLong const* xShapeInfo, - void *vy, Nd4jLong const* yShapeInfo, - int *dimension, int dimensionLength, - Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, - bool descending) { - - auto x = static_cast(vx); - auto y = static_cast(vy); - - __shared__ int xLength; - __shared__ int xTadLength; - __shared__ int numTads; - if (threadIdx.x == 0) { - xLength = shape::length(xShapeInfo); - xTadLength = shape::length(tadShapeInfo); - numTads = xLength / xTadLength; - } - __syncthreads(); - - for (int r = blockIdx.x; r < numTads; r += gridDim.x) { - auto dx = x + tadOffsets[r]; - auto dy = y + tadOffsets[r]; - - // this is general loop, we go uncached - int iterations = xTadLength; - - for (int i = 0; i < iterations; i++) { - - if (i % 2 == 0) { - for (int tid = threadIdx.x; tid < xTadLength; tid += blockDim.x) { - auto top = 2 * tid + 1; - if (top < xTadLength) { - auto t0 = shape::getIndexOffset(top - 1, tadShapeInfo); - auto t1 = shape::getIndexOffset(top, tadShapeInfo); - - if (!descending == (dx[t0] > dx[t1])) { - X dt0 = dx[t0]; - dx[t0] = dx[t1]; - dx[t1] = dt0; - - Y dy0 = dy[t0]; - dy[t0] = dy[t1]; - dy[t1] = dy0; - } - } - } - } else { - for (int tid = threadIdx.x; tid < xTadLength; tid += blockDim.x) { - auto top = 2 * tid + 2; - if (top < xTadLength) { - auto t0 = shape::getIndexOffset(top - 1, tadShapeInfo); - auto t1 = shape::getIndexOffset(top, tadShapeInfo); - - if (!descending == (dx[t0] > dx[t1])) { - X dt0 = dx[t0]; - dx[t0] = dx[t1]; - dx[t1] = dt0; - - Y dy0 = dy[t0]; - dy[t0] = dy[t1]; - dy[t1] = dy0; - } - } - } +__global__ void execOesTadKernelKey(void *vx, Nd4jLong const *xShapeInfo, + void *vy, Nd4jLong const *yShapeInfo, + int *dimension, int dimensionLength, + Nd4jLong const *tadShapeInfo, + Nd4jLong const *tadOffsets, + bool descending) { + auto x = static_cast(vx); + auto y = static_cast(vy); + + __shared__ int xLength; + __shared__ int xTadLength; + __shared__ int numTads; + if (threadIdx.x == 0) { + xLength = shape::length(xShapeInfo); + xTadLength = shape::length(tadShapeInfo); + numTads = xLength / xTadLength; + } + __syncthreads(); + + for (int r = blockIdx.x; r < numTads; r += gridDim.x) { + auto dx = x + tadOffsets[r]; + auto dy = y + tadOffsets[r]; + + // this is general loop, we go uncached + int iterations = xTadLength; + + for (int i = 0; i < iterations; i++) { + if (i % 2 == 0) { + for (int tid = threadIdx.x; tid < xTadLength; tid += blockDim.x) { + auto top = 2 * tid + 1; + if (top < xTadLength) { + auto t0 = shape::getIndexOffset(top - 1, tadShapeInfo); + auto t1 = shape::getIndexOffset(top, tadShapeInfo); + + if (!descending == (dx[t0] > dx[t1])) { + X dt0 = dx[t0]; + dx[t0] = dx[t1]; + dx[t1] = dt0; + + Y dy0 = dy[t0]; + dy[t0] = dy[t1]; + dy[t1] = dy0; + } + } + } + } else { + for (int tid = threadIdx.x; tid < xTadLength; tid += blockDim.x) { + auto top = 2 * tid + 2; + if (top < xTadLength) { + auto t0 = shape::getIndexOffset(top - 1, tadShapeInfo); + auto t1 = shape::getIndexOffset(top, tadShapeInfo); + + if (!descending == (dx[t0] > dx[t1])) { + X dt0 = dx[t0]; + dx[t0] = dx[t1]; + dx[t1] = dt0; + + Y dy0 = dy[t0]; + dy[t0] = dy[t1]; + dy[t1] = dy0; } - __syncthreads(); + } } + } + __syncthreads(); } + } } ////////////////////////////////////////////////////////////////////////// -template -__global__ void execOesTadKernel(void *vx, Nd4jLong const* xShapeInfo, - int *dimension, int dimensionLength, - Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, - bool descending) { - - auto x = static_cast(vx); - const int sharedSize = 32768; - - __shared__ int xLength; - __shared__ int xTadLength; - __shared__ int numTads; - __shared__ T *shmem; - __shared__ bool cached; - if (threadIdx.x == 0) { - xLength = shape::length(xShapeInfo); - xTadLength = shape::length(tadShapeInfo); - numTads = xLength / xTadLength; - - extern __shared__ unsigned char shrd[]; - shmem = (T *) shrd; - - cached = xTadLength <= (sharedSize / sizeof(T)); +template +__global__ void execOesTadKernel(void *vx, Nd4jLong const *xShapeInfo, + int *dimension, int dimensionLength, + Nd4jLong const *tadShapeInfo, + Nd4jLong const *tadOffsets, bool descending) { + auto x = static_cast(vx); + const int sharedSize = 32768; + + __shared__ int xLength; + __shared__ int xTadLength; + __shared__ int numTads; + __shared__ T *shmem; + __shared__ bool cached; + if (threadIdx.x == 0) { + xLength = shape::length(xShapeInfo); + xTadLength = shape::length(tadShapeInfo); + numTads = xLength / xTadLength; + + extern __shared__ unsigned char shrd[]; + shmem = (T *)shrd; + + cached = xTadLength <= (sharedSize / sizeof(T)); + } + __syncthreads(); + + for (int r = blockIdx.x; r < numTads; r += gridDim.x) { + auto dx = x + tadOffsets[r]; + + // this is general loop, we go uncached + int iterations = xTadLength; + if (cached) { + for (int tid = threadIdx.x; tid < xTadLength; tid += blockDim.x) { + auto t0 = shape::getIndexOffset(tid, tadShapeInfo); + shmem[tid] = dx[t0]; + } + + __syncthreads(); + dx = shmem; } - __syncthreads(); - - for (int r = blockIdx.x; r < numTads; r += gridDim.x) { - auto dx = x + tadOffsets[r]; - // this is general loop, we go uncached - int iterations = xTadLength; - if (cached) { - for (int tid = threadIdx.x; tid < xTadLength; tid += blockDim.x) { - auto t0 = shape::getIndexOffset(tid, tadShapeInfo); - shmem[tid] = dx[t0]; + for (int i = 0; i < iterations; i++) { + if (i % 2 == 0) { + for (int tid = threadIdx.x; tid < xTadLength; tid += blockDim.x) { + auto top = 2 * tid + 1; + if (top < xTadLength) { + auto t0 = + cached ? top - 1 : shape::getIndexOffset(top - 1, tadShapeInfo); + auto t1 = cached ? top : shape::getIndexOffset(top, tadShapeInfo); + + if (!descending == (dx[t0] > dx[t1])) { + T dt0 = dx[t0]; + dx[t0] = dx[t1]; + dx[t1] = dt0; } - - __syncthreads(); - dx = shmem; + } } - - for (int i = 0; i < iterations; i++) { - - if (i % 2 == 0) { - for (int tid = threadIdx.x; tid < xTadLength; tid += blockDim.x) { - auto top = 2 * tid + 1; - if (top < xTadLength) { - auto t0 = cached ? top - 1 : shape::getIndexOffset(top - 1, tadShapeInfo); - auto t1 = cached ? top : shape::getIndexOffset(top, tadShapeInfo); - - if (!descending == (dx[t0] > dx[t1])) { - T dt0 = dx[t0]; - dx[t0] = dx[t1]; - dx[t1] = dt0; - } - } - } - } else { - for (int tid = threadIdx.x; tid < xTadLength; tid += blockDim.x) { - auto top = 2 * tid + 2; - if (top < xTadLength) { - auto t0 = cached ? top - 1 : shape::getIndexOffset(top - 1, tadShapeInfo); - auto t1 = cached ? top : shape::getIndexOffset(top, tadShapeInfo); - - if (!descending == (dx[t0] > dx[t1])) { - T dt0 = dx[t0]; - dx[t0] = dx[t1]; - dx[t1] = dt0; - } - } - } + } else { + for (int tid = threadIdx.x; tid < xTadLength; tid += blockDim.x) { + auto top = 2 * tid + 2; + if (top < xTadLength) { + auto t0 = + cached ? top - 1 : shape::getIndexOffset(top - 1, tadShapeInfo); + auto t1 = cached ? top : shape::getIndexOffset(top, tadShapeInfo); + + if (!descending == (dx[t0] > dx[t1])) { + T dt0 = dx[t0]; + dx[t0] = dx[t1]; + dx[t1] = dt0; } - __syncthreads(); + } } + } + __syncthreads(); + } - - if (cached) { - dx = x + tadOffsets[r]; - for (int tid = threadIdx.x; tid < xTadLength; tid += blockDim.x) { - auto t0 = shape::getIndexOffset(tid, tadShapeInfo); - dx[t0] = shmem[tid]; - } - } + if (cached) { + dx = x + tadOffsets[r]; + for (int tid = threadIdx.x; tid < xTadLength; tid += blockDim.x) { + auto t0 = shape::getIndexOffset(tid, tadShapeInfo); + dx[t0] = shmem[tid]; + } } + } } ////////////////////////////////////////////////////////////////////////// -template -__host__ void oesTadGeneric(dim3 &launchDims, cudaStream_t *stream, - void *vx, Nd4jLong const* xShapeInfo, - int *dimension, int dimensionLength, - Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, - bool descending) { - - execOesTadKernel<<>>(vx, xShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets, descending); +template +__host__ void oesTadGeneric(dim3 &launchDims, cudaStream_t *stream, void *vx, + Nd4jLong const *xShapeInfo, int *dimension, + int dimensionLength, Nd4jLong const *tadShapeInfo, + Nd4jLong const *tadOffsets, bool descending) { + execOesTadKernel<<>>( + vx, xShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets, + descending); } template -__host__ void oesTadGenericKey(dim3 &launchDims, cudaStream_t *stream, - void *vx, Nd4jLong const* xShapeInfo, - void *vy, Nd4jLong const* yShapeInfo, - int *dimension, int dimensionLength, - Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, - bool descending) { - - execOesTadKernelKey<<>>(vx, xShapeInfo, vy, yShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets, descending); +__host__ void oesTadGenericKey(dim3 &launchDims, cudaStream_t *stream, void *vx, + Nd4jLong const *xShapeInfo, void *vy, + Nd4jLong const *yShapeInfo, int *dimension, + int dimensionLength, + Nd4jLong const *tadShapeInfo, + Nd4jLong const *tadOffsets, bool descending) { + execOesTadKernelKey + <<>>( + vx, xShapeInfo, vy, yShapeInfo, dimension, dimensionLength, + tadShapeInfo, tadOffsets, descending); } -BUILD_SINGLE_TEMPLATE(template void ND4J_EXPORT oesTadGeneric, (dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong const* xShapeInfo, int *dimension, int dimensionLength, Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, bool descending), LIBND4J_TYPES); -BUILD_DOUBLE_TEMPLATE(template void ND4J_EXPORT oesTadGenericKey, (dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong const* xShapeInfo, void *vy, Nd4jLong const* yShapeInfo, int *dimension, int dimensionLength, Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, bool descending), LIBND4J_TYPES, LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template void SD_EXPORT oesTadGeneric, + (dim3 & launchDims, cudaStream_t *stream, void *vx, + Nd4jLong const *xShapeInfo, int *dimension, + int dimensionLength, Nd4jLong const *tadShapeInfo, + Nd4jLong const *tadOffsets, bool descending), + LIBND4J_TYPES); +BUILD_DOUBLE_TEMPLATE(template void SD_EXPORT oesTadGenericKey, + (dim3 & launchDims, cudaStream_t *stream, void *vx, + Nd4jLong const *xShapeInfo, void *vy, + Nd4jLong const *yShapeInfo, int *dimension, + int dimensionLength, Nd4jLong const *tadShapeInfo, + Nd4jLong const *tadOffsets, bool descending), + LIBND4J_TYPES, LIBND4J_TYPES); diff --git a/libnd4j/include/loops/cuda/specials/pullRowsKernel.cu b/libnd4j/include/loops/cuda/specials/pullRowsKernel.cu index 69d103e6735c..36e2431ca19e 100644 --- a/libnd4j/include/loops/cuda/specials/pullRowsKernel.cu +++ b/libnd4j/include/loops/cuda/specials/pullRowsKernel.cu @@ -24,69 +24,70 @@ namespace sd { /////////////////////////////////////////////////////////////////////// - template - __device__ void pullRowsKernel(void *vx, - void *vz, - Nd4jLong len, - Nd4jLong *indexes, - Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, - Nd4jLong const* zTadShapeInfo, Nd4jLong const* zTadOffsets) { - - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - auto xEWS = shape::elementWiseStride(tadShapeInfo); - auto zEWS = shape::elementWiseStride(zTadShapeInfo); - auto tadLength = shape::length(tadShapeInfo); - - if (xEWS >= 1 && zEWS >= 1) { - for (int idx = blockIdx.x; idx < len; idx += gridDim.x) { - T *rX = x + tadOffsets[indexes[idx]]; - T *rZ = z + zTadOffsets[idx]; +template +__device__ void pullRowsKernel(void *vx, void *vz, Nd4jLong len, + Nd4jLong *indexes, Nd4jLong const *tadShapeInfo, + Nd4jLong const *tadOffsets, + Nd4jLong const *zTadShapeInfo, + Nd4jLong const *zTadOffsets) { + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + auto xEWS = shape::elementWiseStride(tadShapeInfo); + auto zEWS = shape::elementWiseStride(zTadShapeInfo); + auto tadLength = shape::length(tadShapeInfo); - for (int i = threadIdx.x; i < tadLength; i += blockDim.x) { - rZ[i * zEWS] = rX[i * xEWS]; - } - } - } else { - for (int idx = blockIdx.x; idx < len; idx += gridDim.x) { - T *rX = x + tadOffsets[indexes[idx]]; - T *rZ = z + zTadOffsets[idx]; + if (xEWS >= 1 && zEWS >= 1) { + for (int idx = blockIdx.x; idx < len; idx += gridDim.x) { + T *rX = x + tadOffsets[indexes[idx]]; + T *rZ = z + zTadOffsets[idx]; - for (int i = threadIdx.x; i < tadLength; i += blockDim.x) { - auto xOffset = shape::getIndexOffset(i, tadShapeInfo); - auto zOffset = shape::getIndexOffset(i, zTadShapeInfo); - rZ[zOffset] = rX[xOffset]; - } - } - } + for (int i = threadIdx.x; i < tadLength; i += blockDim.x) { + rZ[i * zEWS] = rX[i * xEWS]; + } } + } else { + for (int idx = blockIdx.x; idx < len; idx += gridDim.x) { + T *rX = x + tadOffsets[indexes[idx]]; + T *rZ = z + zTadOffsets[idx]; -/////////////////////////////////////////////////////////////////////// - template - __global__ void execPullRowsKernel(void *vx, - void *vz, - Nd4jLong len, - Nd4jLong *indexes, - Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, - Nd4jLong const* zTadShapeInfo, Nd4jLong const* zTadOffsets) { - - pullRowsKernel(vx, vz, len, indexes, tadShapeInfo, tadOffsets, zTadShapeInfo, zTadOffsets); + for (int i = threadIdx.x; i < tadLength; i += blockDim.x) { + auto xOffset = shape::getIndexOffset(i, tadShapeInfo); + auto zOffset = shape::getIndexOffset(i, zTadShapeInfo); + rZ[zOffset] = rX[xOffset]; + } } + } +} /////////////////////////////////////////////////////////////////////// - template - __host__ void pullRowsKernelGeneric(dim3 &launchDims, cudaStream_t *stream, - void *vx, - void *vz, - Nd4jLong len, - Nd4jLong *indexes, - Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, - Nd4jLong const* zTadShapeInfo, Nd4jLong const* zTadOffsets) { - - execPullRowsKernel<<>>(vx, vz, len, indexes, tadShapeInfo, tadOffsets, zTadShapeInfo, zTadOffsets); - sd::DebugHelper::checkErrorCode(stream, "pullRows(...) failed"); - } +template +__global__ void execPullRowsKernel(void *vx, void *vz, Nd4jLong len, + Nd4jLong *indexes, + Nd4jLong const *tadShapeInfo, + Nd4jLong const *tadOffsets, + Nd4jLong const *zTadShapeInfo, + Nd4jLong const *zTadOffsets) { + pullRowsKernel(vx, vz, len, indexes, tadShapeInfo, tadOffsets, + zTadShapeInfo, zTadOffsets); +} - BUILD_SINGLE_TEMPLATE(template void ND4J_EXPORT pullRowsKernelGeneric, (dim3 & launchDims, cudaStream_t * stream, void * vx, void * vz, Nd4jLong len, Nd4jLong * indexes, Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* zTadShapeInfo, Nd4jLong const* zTadOffsets), LIBND4J_TYPES); +/////////////////////////////////////////////////////////////////////// +template +__host__ void pullRowsKernelGeneric( + dim3 &launchDims, cudaStream_t *stream, void *vx, void *vz, Nd4jLong len, + Nd4jLong *indexes, Nd4jLong const *tadShapeInfo, Nd4jLong const *tadOffsets, + Nd4jLong const *zTadShapeInfo, Nd4jLong const *zTadOffsets) { + execPullRowsKernel<<>>( + vx, vz, len, indexes, tadShapeInfo, tadOffsets, zTadShapeInfo, + zTadOffsets); + sd::DebugHelper::checkErrorCode(stream, "pullRows(...) failed"); } +BUILD_SINGLE_TEMPLATE(template void SD_EXPORT pullRowsKernelGeneric, + (dim3 & launchDims, cudaStream_t *stream, void *vx, + void *vz, Nd4jLong len, Nd4jLong *indexes, + Nd4jLong const *tadShapeInfo, Nd4jLong const *tadOffsets, + Nd4jLong const *zTadShapeInfo, + Nd4jLong const *zTadOffsets), + LIBND4J_TYPES); +} // namespace sd diff --git a/libnd4j/include/loops/cuda/specials/setDiagonalKernel.cu b/libnd4j/include/loops/cuda/specials/setDiagonalKernel.cu index bb063180c45a..32bf9b1df988 100644 --- a/libnd4j/include/loops/cuda/specials/setDiagonalKernel.cu +++ b/libnd4j/include/loops/cuda/specials/setDiagonalKernel.cu @@ -18,8 +18,8 @@ // @author GS , created on 21.01.2019 // -#include #include +#include namespace sd { //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// @@ -27,93 +27,151 @@ namespace sd { // buffer - input buffer // shape - input shape // value - given value -// diagonal - given upper diagonal (acceptable negative values also, 0 - the main diagonal) -// row, cols - height and width of given matrix (MxN, rows = M, cols = N) +// diagonal - given upper diagonal (acceptable negative values also, 0 - the +// main diagonal) row, cols - height and width of given matrix (MxN, rows = M, +// cols = N) // - template - static __global__ void setDiagValueUpperKernel(void* buffer, Nd4jLong* shape, T value, int diagonal, Nd4jLong rows, - Nd4jLong cols) { - - __shared__ Nd4jLong rank; - __shared__ T* array; +template +static __global__ void setDiagValueUpperKernel(void* buffer, Nd4jLong* shape, + T value, int diagonal, + Nd4jLong rows, Nd4jLong cols) { + __shared__ Nd4jLong rank; + __shared__ T* array; - if (0 == threadIdx.x) { - rank = shape::rank(shape); - array = reinterpret_cast(buffer); - } - __syncthreads(); + if (0 == threadIdx.x) { + rank = shape::rank(shape); + array = reinterpret_cast(buffer); + } + __syncthreads(); - for (Nd4jLong i = blockIdx.x; i < rows; i += gridDim.x) { - for (int j = threadIdx.x; j < cols; j += blockDim.x) { - Nd4jLong coords[2] = {i, j}; - Nd4jLong xOffset = shape::getOffset(shape, coords); - if (i + diagonal <= j) - array[xOffset] = value; - } - } + for (Nd4jLong i = blockIdx.x; i < rows; i += gridDim.x) { + for (int j = threadIdx.x; j < cols; j += blockDim.x) { + Nd4jLong coords[2] = {i, j}; + Nd4jLong xOffset = shape::getOffset(shape, coords); + if (i + diagonal <= j) array[xOffset] = value; } + } +} //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// // set up given value to lower given diagonal // buffer - input buffer // shape - input shape // value - given value -// diagonal - given lower diagonal (acceptable negative values also, 0 - the main diagonal) -// row, cols - height and width of given matrix (MxN, rows = M, cols = N) +// diagonal - given lower diagonal (acceptable negative values also, 0 - the +// main diagonal) row, cols - height and width of given matrix (MxN, rows = M, +// cols = N) // - template - static __global__ void setDiagValueLowerKernel(void* buffer, Nd4jLong* shape, T value, int diagonal, Nd4jLong rows, Nd4jLong cols) { - Nd4jLong rank = shape::rank(shape); - int totalThreads = blockDim.x; - for (Nd4jLong i = blockIdx.x; i < rows; i += gridDim.x) { - for (int j = threadIdx.x; j < cols; j += totalThreads) { - Nd4jLong coords[2] = {i, j}; - auto xOffset = shape::getOffset(shape, coords); - if (i + diagonal >= j) - *(reinterpret_cast(buffer) + xOffset) = value; - } - } +template +static __global__ void setDiagValueLowerKernel(void* buffer, Nd4jLong* shape, + T value, int diagonal, + Nd4jLong rows, Nd4jLong cols) { + Nd4jLong rank = shape::rank(shape); + int totalThreads = blockDim.x; + for (Nd4jLong i = blockIdx.x; i < rows; i += gridDim.x) { + for (int j = threadIdx.x; j < cols; j += totalThreads) { + Nd4jLong coords[2] = {i, j}; + auto xOffset = shape::getOffset(shape, coords); + if (i + diagonal >= j) *(reinterpret_cast(buffer) + xOffset) = value; } + } +} //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - template __global__ void setDiagValueLowerKernel(void* buffer, Nd4jLong* shape, double value, int diagonal, Nd4jLong rows, Nd4jLong cols); - template __global__ void setDiagValueUpperKernel(void* buffer, Nd4jLong* shape, double value, int diagonal, Nd4jLong rows, Nd4jLong cols); - template __global__ void setDiagValueLowerKernel(void* buffer, Nd4jLong* shape, float value, int diagonal, Nd4jLong rows, Nd4jLong cols); - template __global__ void setDiagValueUpperKernel(void* buffer, Nd4jLong* shape, float value, int diagonal, Nd4jLong rows, Nd4jLong cols); - template __global__ void setDiagValueLowerKernel(void* buffer, Nd4jLong* shape, int value, int diagonal, Nd4jLong rows, Nd4jLong cols); - template __global__ void setDiagValueUpperKernel(void* buffer, Nd4jLong* shape, int value, int diagonal, Nd4jLong rows, Nd4jLong cols); - template __global__ void setDiagValueLowerKernel(void* buffer, Nd4jLong* shape, float16 value, int diagonal, Nd4jLong rows, Nd4jLong cols); - template __global__ void setDiagValueUpperKernel(void* buffer, Nd4jLong* shape, float16 value, int diagonal, Nd4jLong rows, Nd4jLong cols); - template __global__ void setDiagValueLowerKernel(void* buffer, Nd4jLong* shape, bfloat16 value, int diagonal, Nd4jLong rows, Nd4jLong cols); - template __global__ void setDiagValueUpperKernel(void* buffer, Nd4jLong* shape, bfloat16 value, int diagonal, Nd4jLong rows, Nd4jLong cols); - template __global__ void setDiagValueLowerKernel(void* buffer, Nd4jLong* shape, Nd4jLong value, int diagonal, Nd4jLong rows, Nd4jLong cols); - template __global__ void setDiagValueUpperKernel(void* buffer, Nd4jLong* shape, Nd4jLong value, int diagonal, Nd4jLong rows, Nd4jLong cols); - template __global__ void setDiagValueLowerKernel(void* buffer, Nd4jLong* shape, int16_t value, int diagonal, Nd4jLong rows, Nd4jLong cols); - template __global__ void setDiagValueUpperKernel(void* buffer, Nd4jLong* shape, int16_t value, int diagonal, Nd4jLong rows, Nd4jLong cols); - template __global__ void setDiagValueLowerKernel(void* buffer, Nd4jLong* shape, uint8_t value, int diagonal, Nd4jLong rows, Nd4jLong cols); - template __global__ void setDiagValueUpperKernel(void* buffer, Nd4jLong* shape, uint8_t value, int diagonal, Nd4jLong rows, Nd4jLong cols); - template __global__ void setDiagValueLowerKernel(void* buffer, Nd4jLong* shape, int8_t value, int diagonal, Nd4jLong rows, Nd4jLong cols); - template __global__ void setDiagValueUpperKernel(void* buffer, Nd4jLong* shape, int8_t value, int diagonal, Nd4jLong rows, Nd4jLong cols); - template __global__ void setDiagValueLowerKernel(void* buffer, Nd4jLong* shape, bool value, int diagonal, Nd4jLong rows, Nd4jLong cols); - template __global__ void setDiagValueUpperKernel(void* buffer, Nd4jLong* shape, bool value, int diagonal, Nd4jLong rows, Nd4jLong cols); +template __global__ void setDiagValueLowerKernel(void* buffer, Nd4jLong* shape, + double value, int diagonal, + Nd4jLong rows, Nd4jLong cols); +template __global__ void setDiagValueUpperKernel(void* buffer, Nd4jLong* shape, + double value, int diagonal, + Nd4jLong rows, Nd4jLong cols); +template __global__ void setDiagValueLowerKernel(void* buffer, Nd4jLong* shape, + float value, int diagonal, + Nd4jLong rows, Nd4jLong cols); +template __global__ void setDiagValueUpperKernel(void* buffer, Nd4jLong* shape, + float value, int diagonal, + Nd4jLong rows, Nd4jLong cols); +template __global__ void setDiagValueLowerKernel(void* buffer, Nd4jLong* shape, + int value, int diagonal, + Nd4jLong rows, Nd4jLong cols); +template __global__ void setDiagValueUpperKernel(void* buffer, Nd4jLong* shape, + int value, int diagonal, + Nd4jLong rows, Nd4jLong cols); +template __global__ void setDiagValueLowerKernel(void* buffer, Nd4jLong* shape, + float16 value, int diagonal, + Nd4jLong rows, Nd4jLong cols); +template __global__ void setDiagValueUpperKernel(void* buffer, Nd4jLong* shape, + float16 value, int diagonal, + Nd4jLong rows, Nd4jLong cols); +template __global__ void setDiagValueLowerKernel(void* buffer, Nd4jLong* shape, + bfloat16 value, int diagonal, + Nd4jLong rows, Nd4jLong cols); +template __global__ void setDiagValueUpperKernel(void* buffer, Nd4jLong* shape, + bfloat16 value, int diagonal, + Nd4jLong rows, Nd4jLong cols); +template __global__ void setDiagValueLowerKernel(void* buffer, Nd4jLong* shape, + Nd4jLong value, int diagonal, + Nd4jLong rows, Nd4jLong cols); +template __global__ void setDiagValueUpperKernel(void* buffer, Nd4jLong* shape, + Nd4jLong value, int diagonal, + Nd4jLong rows, Nd4jLong cols); +template __global__ void setDiagValueLowerKernel(void* buffer, Nd4jLong* shape, + int16_t value, int diagonal, + Nd4jLong rows, Nd4jLong cols); +template __global__ void setDiagValueUpperKernel(void* buffer, Nd4jLong* shape, + int16_t value, int diagonal, + Nd4jLong rows, Nd4jLong cols); +template __global__ void setDiagValueLowerKernel(void* buffer, Nd4jLong* shape, + uint8_t value, int diagonal, + Nd4jLong rows, Nd4jLong cols); +template __global__ void setDiagValueUpperKernel(void* buffer, Nd4jLong* shape, + uint8_t value, int diagonal, + Nd4jLong rows, Nd4jLong cols); +template __global__ void setDiagValueLowerKernel(void* buffer, Nd4jLong* shape, + int8_t value, int diagonal, + Nd4jLong rows, Nd4jLong cols); +template __global__ void setDiagValueUpperKernel(void* buffer, Nd4jLong* shape, + int8_t value, int diagonal, + Nd4jLong rows, Nd4jLong cols); +template __global__ void setDiagValueLowerKernel(void* buffer, Nd4jLong* shape, + bool value, int diagonal, + Nd4jLong rows, Nd4jLong cols); +template __global__ void setDiagValueUpperKernel(void* buffer, Nd4jLong* shape, + bool value, int diagonal, + Nd4jLong rows, Nd4jLong cols); //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - template - static void setDiagonalValueUpper(void* buffer, Nd4jLong* shape, NDArray const& value, int diagonal, Nd4jLong rows, Nd4jLong cols, cudaStream_t& stream) { - dim3 launchDims(256, 512, 8192); - setDiagValueUpperKernel<<>>(buffer, shape, value.e(0), diagonal, rows, cols); - } +template +static void setDiagonalValueUpper(void* buffer, Nd4jLong* shape, + NDArray const& value, int diagonal, + Nd4jLong rows, Nd4jLong cols, + cudaStream_t& stream) { + dim3 launchDims(256, 512, 8192); + setDiagValueUpperKernel + <<>>( + buffer, shape, value.e(0), diagonal, rows, cols); +} //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - template - static void setDiagonalValueLower(void* buffer, Nd4jLong* shape, NDArray const& value, int diagonal, Nd4jLong rows, Nd4jLong cols, cudaStream_t& stream) { - dim3 launchDims(256, 512, 8192); - setDiagValueLowerKernel<<>>(buffer, shape, value.e(0), diagonal, rows, cols); - } +template +static void setDiagonalValueLower(void* buffer, Nd4jLong* shape, + NDArray const& value, int diagonal, + Nd4jLong rows, Nd4jLong cols, + cudaStream_t& stream) { + dim3 launchDims(256, 512, 8192); + setDiagValueLowerKernel + <<>>( + buffer, shape, value.e(0), diagonal, rows, cols); +} - BUILD_SINGLE_TEMPLATE(template void setDiagonalValueUpper, (void* buffer, Nd4jLong* shape, NDArray const& value, - int diagonal, Nd4jLong rows, Nd4jLong cols, cudaStream_t& stream), LIBND4J_TYPES); - BUILD_SINGLE_TEMPLATE(template void setDiagonalValueLower, (void* buffer, Nd4jLong* shape, NDArray const& value, - int diagonal, Nd4jLong rows, Nd4jLong cols, cudaStream_t& stream), LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template void setDiagonalValueUpper, + (void* buffer, Nd4jLong* shape, NDArray const& value, + int diagonal, Nd4jLong rows, Nd4jLong cols, + cudaStream_t& stream), + LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template void setDiagonalValueLower, + (void* buffer, Nd4jLong* shape, NDArray const& value, + int diagonal, Nd4jLong rows, Nd4jLong cols, + cudaStream_t& stream), + LIBND4J_TYPES); //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// -} \ No newline at end of file +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/specials/shuffleKernel.cu b/libnd4j/include/loops/cuda/specials/shuffleKernel.cu index db63c2af728c..e30ce8210e9f 100644 --- a/libnd4j/include/loops/cuda/specials/shuffleKernel.cu +++ b/libnd4j/include/loops/cuda/specials/shuffleKernel.cu @@ -24,102 +24,103 @@ namespace sd { //////////////////////////////////////////////////////////////////////// - template - __global__ void execShuffleKernel(void **vdX, Nd4jLong **dxShapeInfo, - void **vdZ, - int N, - int *shuffleMap, - Nd4jLong **tadOnlyShapeInfo, Nd4jLong **tadOffsets) { - - // we assume that shuffle map for each X contains pair TAD Y - auto dX = reinterpret_cast(vdX); - auto dZ = reinterpret_cast(vdZ); - - __shared__ int tadLength; - __shared__ int xRank; - __shared__ int tadEWS; - __shared__ int numTads; - __shared__ Nd4jLong* xShapeInfo; - __shared__ Nd4jLong xLength; - - for (int f = 0; f < N; f++) { - auto x = reinterpret_cast(dX[f]); - auto z = reinterpret_cast(dZ[f]); - - if (threadIdx.x == 0) { - tadLength = shape::length(tadOnlyShapeInfo[f]); - tadEWS = shape::elementWiseStride(tadOnlyShapeInfo[f]); - xShapeInfo = dxShapeInfo[f]; - xRank = shape::rank(xShapeInfo); - xLength = shape::length(xShapeInfo); - numTads = xLength / tadLength; +template +__global__ void execShuffleKernel(void **vdX, Nd4jLong **dxShapeInfo, + void **vdZ, int N, int *shuffleMap, + Nd4jLong **tadOnlyShapeInfo, + Nd4jLong **tadOffsets) { + // we assume that shuffle map for each X contains pair TAD Y + auto dX = reinterpret_cast(vdX); + auto dZ = reinterpret_cast(vdZ); + + __shared__ int tadLength; + __shared__ int xRank; + __shared__ int tadEWS; + __shared__ int numTads; + __shared__ Nd4jLong *xShapeInfo; + __shared__ Nd4jLong xLength; + + for (int f = 0; f < N; f++) { + auto x = reinterpret_cast(dX[f]); + auto z = reinterpret_cast(dZ[f]); + + if (threadIdx.x == 0) { + tadLength = shape::length(tadOnlyShapeInfo[f]); + tadEWS = shape::elementWiseStride(tadOnlyShapeInfo[f]); + xShapeInfo = dxShapeInfo[f]; + xRank = shape::rank(xShapeInfo); + xLength = shape::length(xShapeInfo); + numTads = xLength / tadLength; + } + __syncthreads(); + + if (xRank == 1) { + int tid = threadIdx.x + blockIdx.x * blockDim.x; + for (int r = tid; r < xLength; r += gridDim.x * blockDim.x) { + auto swapIndex = shuffleMap[r]; + if (swapIndex >= 0 && swapIndex < xLength) { + int idx = r * tadEWS; + int swap = swapIndex * tadEWS; + T oldX = x[idx]; + x[idx] = x[swap]; + x[swap] = oldX; + } + } + } else { + // we roll over the pairs of TADs, thus limit is numTads / 2 + for (uint r = blockIdx.x; r < numTads; r += gridDim.x) { + if (shuffleMap[r] >= 0) { + auto oldOffset = tadOffsets[f][r]; + auto newOffset = tadOffsets[f][shuffleMap[r]]; + + auto rX = x + oldOffset; + auto rY = x + newOffset; + + auto zX = z + oldOffset; + auto zY = z + newOffset; + + // so we're going to change TAD[oldOffset] with TAD[newOffset] + if (tadEWS == 1) { + for (Nd4jLong i = threadIdx.x; i < tadLength; i += blockDim.x) { + T oldX = rX[i]; + rX[i] = rY[i]; + zY[i] = oldX; } - __syncthreads(); - - if (xRank == 1) { - int tid = threadIdx.x + blockIdx.x * blockDim.x; - for (int r = tid; r < xLength; r += gridDim.x * blockDim.x) { - auto swapIndex = shuffleMap[r]; - if (swapIndex >= 0 && swapIndex < xLength) { - int idx = r * tadEWS; - int swap = swapIndex * tadEWS; - T oldX = x[idx]; - x[idx] = x[swap]; - x[swap] = oldX; - } - } - } else { - // we roll over the pairs of TADs, thus limit is numTads / 2 - for (uint r = blockIdx.x; r < numTads; r += gridDim.x) { - if (shuffleMap[r] >= 0) { - auto oldOffset = tadOffsets[f][r]; - auto newOffset = tadOffsets[f][shuffleMap[r]]; - auto rX = x + oldOffset; - auto rY = x + newOffset; + } else { + for (Nd4jLong i = threadIdx.x; i < tadLength; i += blockDim.x) { + auto xOffset = shape::getIndexOffset(i, tadOnlyShapeInfo[f]); + auto yOffset = newOffset + xOffset; + xOffset += oldOffset; - auto zX = z + oldOffset; - auto zY = z + newOffset; - - // so we're going to change TAD[oldOffset] with TAD[newOffset] - if (tadEWS == 1) { - for (Nd4jLong i = threadIdx.x; i < tadLength; i += blockDim.x) { - T oldX = rX[i]; - rX[i] = rY[i]; - zY[i] = oldX; - } - - } else { - for (Nd4jLong i = threadIdx.x; i < tadLength; i += blockDim.x) { - - auto xOffset = shape::getIndexOffset(i, tadOnlyShapeInfo[f]); - auto yOffset = newOffset + xOffset; - xOffset += oldOffset; - - T oldX = x[xOffset]; - z[xOffset] = x[yOffset]; - z[yOffset] = oldX; - } - } - } - } + T oldX = x[xOffset]; + z[xOffset] = x[yOffset]; + z[yOffset] = oldX; } - __syncthreads(); + } } + } } + __syncthreads(); + } +} //////////////////////////////////////////////////////////////////////// - template - __host__ void shuffleKernelGeneric(dim3 &launchDims, cudaStream_t *stream, - void **vdX, Nd4jLong **xShapeInfo, - void **vdZ, - int N, - int *shuffleMap, - Nd4jLong **tadOnlyShapeInfo, Nd4jLong **tadOffsets) { - - execShuffleKernel<<>>(vdX, xShapeInfo, vdZ, N, shuffleMap, tadOnlyShapeInfo, tadOffsets); - sd::DebugHelper::checkErrorCode(stream, "shuffleGeneric(...) failed"); - } - - BUILD_SINGLE_TEMPLATE(template void ND4J_EXPORT shuffleKernelGeneric, (dim3 & launchDims, cudaStream_t * stream, void * *vdX, Nd4jLong * *xShapeInfo, void **vdZ, int N, int * shuffleMap, Nd4jLong * *tadOnlyShapeInfo, Nd4jLong * *tadOffsets), LIBND4J_TYPES); -} \ No newline at end of file +template +__host__ void shuffleKernelGeneric(dim3 &launchDims, cudaStream_t *stream, + void **vdX, Nd4jLong **xShapeInfo, + void **vdZ, int N, int *shuffleMap, + Nd4jLong **tadOnlyShapeInfo, + Nd4jLong **tadOffsets) { + execShuffleKernel<<>>( + vdX, xShapeInfo, vdZ, N, shuffleMap, tadOnlyShapeInfo, tadOffsets); + sd::DebugHelper::checkErrorCode(stream, "shuffleGeneric(...) failed"); +} + +BUILD_SINGLE_TEMPLATE(template void SD_EXPORT shuffleKernelGeneric, + (dim3 & launchDims, cudaStream_t *stream, void **vdX, + Nd4jLong **xShapeInfo, void **vdZ, int N, + int *shuffleMap, Nd4jLong **tadOnlyShapeInfo, + Nd4jLong **tadOffsets), + LIBND4J_TYPES); +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/specials/swapUnsafeKernel.cu b/libnd4j/include/loops/cuda/specials/swapUnsafeKernel.cu index 6d2bcadf54f6..529a10668df0 100644 --- a/libnd4j/include/loops/cuda/specials/swapUnsafeKernel.cu +++ b/libnd4j/include/loops/cuda/specials/swapUnsafeKernel.cu @@ -22,55 +22,64 @@ namespace sd { - //////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - // kernel to swap two NDArrays vals as linear sequences - // input - theSecondBuffer/Shape from input NDArray - // output - theFirstBuffer/Shape from input NDArray - template - static __global__ void swapUnsafeKernel(void* theFirstBuffer, Nd4jLong const* theFirstShape, void* theSecondBuffer, Nd4jLong const* theSecondShape) { - auto tid = blockIdx.x * blockDim.x + threadIdx.x; - int totalThreads = gridDim.x * blockDim.x; +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +// kernel to swap two NDArrays vals as linear sequences +// input - theSecondBuffer/Shape from input NDArray +// output - theFirstBuffer/Shape from input NDArray +template +static __global__ void swapUnsafeKernel(void* theFirstBuffer, + Nd4jLong const* theFirstShape, + void* theSecondBuffer, + Nd4jLong const* theSecondShape) { + auto tid = blockIdx.x * blockDim.x + threadIdx.x; + int totalThreads = gridDim.x * blockDim.x; - __shared__ Nd4jLong resultLength, xEws, yEws; + __shared__ Nd4jLong resultLength, xEws, yEws; __shared__ bool sameOffsets, sameOrders; - __shared__ T* input; - __shared__ T* output; - - if (0 == threadIdx.x) { - resultLength = shape::length(theFirstShape); - input = reinterpret_cast(theSecondBuffer); - output = reinterpret_cast(theFirstBuffer); - - sameOffsets = shape::haveSameShapeAndStrides(theFirstShape, theSecondShape); + __shared__ T* input; + __shared__ T* output; + if (0 == threadIdx.x) { + resultLength = shape::length(theFirstShape); + input = reinterpret_cast(theSecondBuffer); + output = reinterpret_cast(theFirstBuffer);sameOffsets = shape::haveSameShapeAndStrides(theFirstShape, theSecondShape); sameOrders = shape::order(theFirstShape) == shape::order(theSecondShape); xEws = shape::elementWiseStride(theFirstShape); yEws = shape::elementWiseStride(theSecondShape); - } - __syncthreads(); + } + __syncthreads(); - for (int i = tid; i < resultLength; i += totalThreads) { - if(sameOrders && xEws > 0 && yEws > 0) { + for (int i = tid; i < resultLength; i += totalThreads) { + if(sameOrders && xEws > 0 && yEws > 0) { sd::math::nd4j_swap(output[i*xEws], input[i*yEws]); } else if(sameOffsets) { - const auto offset = shape::getIndexOffset(i, theFirstShape); + const auto offset = shape::getIndexOffset(i,theFirstShape) ; sd::math::nd4j_swap(output[offset], input[offset]); - } - else{ - const auto xOffset = shape::getIndexOffset(i, theFirstShape); - const auto yOffset = shape::getIndexOffset(i, theSecondShape); - sd::math::nd4j_swap(output[xOffset], input[yOffset]); - } - } } + else{ + const auto xOffset = shape::getIndexOffset(i, theFirstShape); + const auto yOffset = shape::getIndexOffset(i, theSecondShape); + sd::math::nd4j_swap(output[xOffset], input[yOffset]); + } +}} - BUILD_SINGLE_TEMPLATE(template __global__ void swapUnsafeKernel, (void* theFirstBuffer, Nd4jLong const* theFirstShape, void* theSecondBuffer, Nd4jLong const* theSecondShape), LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template __global__ void swapUnsafeKernel, + (void* theFirstBuffer, Nd4jLong const* theFirstShape, + void* theSecondBuffer, Nd4jLong const* theSecondShape), + LIBND4J_TYPES); - template - void templatedSwapUnsafe(void* theFirstBuffer, Nd4jLong const* theFirstShape, void* theSecondBuffer, Nd4jLong const* theSecondShape, cudaStream_t* theStream) { - swapUnsafeKernel<<<256, 512, 8192, *theStream>>>(theFirstBuffer, theFirstShape, theSecondBuffer, theSecondShape); - } - BUILD_SINGLE_TEMPLATE(template void templatedSwapUnsafe, (void* theFirstBuffer, Nd4jLong const* theFirstShape, void* theSecondBuffer, Nd4jLong const* theSecondShape, cudaStream_t* theStream), LIBND4J_TYPES); +template +void templatedSwapUnsafe(void* theFirstBuffer, Nd4jLong const* theFirstShape, + void* theSecondBuffer, Nd4jLong const* theSecondShape, + cudaStream_t* theStream) { + swapUnsafeKernel<<<256, 512, 8192, *theStream>>>( + theFirstBuffer, theFirstShape, theSecondBuffer, theSecondShape); +} +BUILD_SINGLE_TEMPLATE(template void templatedSwapUnsafe, + (void* theFirstBuffer, Nd4jLong const* theFirstShape, + void* theSecondBuffer, Nd4jLong const* theSecondShape, + cudaStream_t* theStream), + LIBND4J_TYPES); -} \ No newline at end of file +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/specials/tearKernel.cu b/libnd4j/include/loops/cuda/specials/tearKernel.cu index e1d70e6b5084..b85b86f97ec7 100644 --- a/libnd4j/include/loops/cuda/specials/tearKernel.cu +++ b/libnd4j/include/loops/cuda/specials/tearKernel.cu @@ -24,72 +24,75 @@ namespace sd { //////////////////////////////////////////////////////////////////////// - template - __device__ void - tearKernel(void *vx, Nd4jLong const* xShapeInfo, Nd4jPointer *targets, Nd4jLong const* zShapeInfo, Nd4jLong const* tadShapeInfo, - Nd4jLong const* tadOffsets) { - - - - __shared__ Nd4jLong tadLength; - __shared__ int tadEWS; - __shared__ int zEWS; -// __shared__ int tadRank; - __shared__ Nd4jLong numTads; -// __shared__ int zRank; -// __shared__ Nd4jLong *tadShape; -// __shared__ Nd4jLong *tadStride; -// __shared__ Nd4jLong const* zShape; -// __shared__ Nd4jLong const* zStride; - __shared__ T* x; - if (threadIdx.x == 0) { - tadLength = shape::length(tadShapeInfo); - tadEWS = shape::elementWiseStride(tadShapeInfo); - zEWS = shape::elementWiseStride(zShapeInfo); - numTads = shape::length(xShapeInfo) / tadLength; - x = static_cast(vx); - } - __syncthreads(); - - for (Nd4jLong r = blockIdx.x; r < numTads; r += gridDim.x) { - T *z = (T *) targets[r]; - T *s = x + tadOffsets[r]; - - if (zEWS > 0 && tadEWS > 0) { - for (Nd4jLong i = threadIdx.x; i < tadLength; i += blockDim.x) - z[i * zEWS] = s[i * tadEWS]; - } else { - - for (Nd4jLong j = threadIdx.x; j < tadLength; j += blockDim.x) { - auto xOffset = shape::getIndexOffset(j, tadShapeInfo); - auto zOffset = shape::getIndexOffset(j, zShapeInfo); - - z[zOffset] = s[xOffset]; - } - } - } +template +__device__ void tearKernel(void* vx, Nd4jLong const* xShapeInfo, + Nd4jPointer* targets, Nd4jLong const* zShapeInfo, + Nd4jLong const* tadShapeInfo, + Nd4jLong const* tadOffsets) { + __shared__ Nd4jLong tadLength; + __shared__ int tadEWS; + __shared__ int zEWS; + // __shared__ int tadRank; + __shared__ Nd4jLong numTads; + // __shared__ int zRank; + // __shared__ Nd4jLong *tadShape; + // __shared__ Nd4jLong *tadStride; + // __shared__ Nd4jLong const* zShape; + // __shared__ Nd4jLong const* zStride; + __shared__ T* x; + if (threadIdx.x == 0) { + tadLength = shape::length(tadShapeInfo); + tadEWS = shape::elementWiseStride(tadShapeInfo); + zEWS = shape::elementWiseStride(zShapeInfo); + numTads = shape::length(xShapeInfo) / tadLength; + x = static_cast(vx); + } + __syncthreads(); + + for (Nd4jLong r = blockIdx.x; r < numTads; r += gridDim.x) { + T* z = (T*)targets[r]; + T* s = x + tadOffsets[r]; + + if (zEWS > 0 && tadEWS > 0) { + for (Nd4jLong i = threadIdx.x; i < tadLength; i += blockDim.x) + z[i * zEWS] = s[i * tadEWS]; + } else { + for (Nd4jLong j = threadIdx.x; j < tadLength; j += blockDim.x) { + auto xOffset = shape::getIndexOffset(j, tadShapeInfo); + auto zOffset = shape::getIndexOffset(j, zShapeInfo); + + z[zOffset] = s[xOffset]; + } } - + } +} //////////////////////////////////////////////////////////////////////// - template - __global__ void - execTearKernel(void *vx, Nd4jLong const* xShapeInfo, Nd4jPointer *targets, Nd4jLong const* zShapeInfo, Nd4jLong const* tadShapeInfo, - Nd4jLong const* tadOffsets) { - - tearKernel(vx, xShapeInfo, targets, zShapeInfo, tadShapeInfo, tadOffsets); - } +template +__global__ void execTearKernel(void* vx, Nd4jLong const* xShapeInfo, + Nd4jPointer* targets, Nd4jLong const* zShapeInfo, + Nd4jLong const* tadShapeInfo, + Nd4jLong const* tadOffsets) { + tearKernel(vx, xShapeInfo, targets, zShapeInfo, tadShapeInfo, tadOffsets); +} //////////////////////////////////////////////////////////////////////// - template - __host__ void tearKernelGeneric(dim3 &launchDims, cudaStream_t *stream, - void *vx, Nd4jLong const* xShapeInfo, - Nd4jPointer *targets, Nd4jLong const* zShapeInfo, - Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets) { - - execTearKernel<<>>(vx, xShapeInfo, targets, zShapeInfo, tadShapeInfo, tadOffsets); - sd::DebugHelper::checkErrorCode(stream, "tear(...) failed"); - } - - BUILD_SINGLE_TEMPLATE(template void ND4J_EXPORT tearKernelGeneric, (dim3 & launchDims, cudaStream_t * stream, void * vx, Nd4jLong const* xShapeInfo, Nd4jPointer *targets, Nd4jLong const* zShapeInfo, Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets), LIBND4J_TYPES); -} \ No newline at end of file +template +__host__ void tearKernelGeneric(dim3& launchDims, cudaStream_t* stream, + void* vx, Nd4jLong const* xShapeInfo, + Nd4jPointer* targets, + Nd4jLong const* zShapeInfo, + Nd4jLong const* tadShapeInfo, + Nd4jLong const* tadOffsets) { + execTearKernel<<>>( + vx, xShapeInfo, targets, zShapeInfo, tadShapeInfo, tadOffsets); + sd::DebugHelper::checkErrorCode(stream, "tear(...) failed"); +} + +BUILD_SINGLE_TEMPLATE(template void SD_EXPORT tearKernelGeneric, + (dim3 & launchDims, cudaStream_t* stream, void* vx, + Nd4jLong const* xShapeInfo, Nd4jPointer* targets, + Nd4jLong const* zShapeInfo, Nd4jLong const* tadShapeInfo, + Nd4jLong const* tadOffsets), + LIBND4J_TYPES); +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/specials/tileKernel.cu b/libnd4j/include/loops/cuda/specials/tileKernel.cu index 3a2684579b59..abed61912405 100644 --- a/libnd4j/include/loops/cuda/specials/tileKernel.cu +++ b/libnd4j/include/loops/cuda/specials/tileKernel.cu @@ -21,91 +21,126 @@ #include namespace sd { - static Nd4jLong __device__ __noinline__ getIndexOffset_(Nd4jLong index, Nd4jLong const* shapeInfo) { - return shape::getIndexOffset(index, shapeInfo); - } - - static Nd4jLong __device__ __noinline__ subArrayOffset(Nd4jLong index, Nd4jLong const* shapeInfoA, Nd4jLong const* shapeInfoB) { - return shape::subArrayOffset(index, shapeInfoA, shapeInfoB); - } +static Nd4jLong __device__ __noinline__ +getIndexOffset_(Nd4jLong index, Nd4jLong const* shapeInfo) { + return shape::getIndexOffset(index, shapeInfo); +} +static Nd4jLong __device__ __noinline__ subArrayOffset( + Nd4jLong index, Nd4jLong const* shapeInfoA, Nd4jLong const* shapeInfoB) { + return shape::subArrayOffset(index, shapeInfoA, shapeInfoB); +} //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// // tileKernel: // input: (inputBuffer and inputShape) - NDArray buffer and shape to tile // output: (outputBuffer and outputShape) - NDArray to tile input // resultLength - length for output array - template - static __global__ void - tileKernel(void const *inputBuffer, Nd4jLong const* inputShape, void *outputBuffer, Nd4jLong const* outputShape, - Nd4jLong resultLength) { -//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// -// Original code to transform in cuda-based - auto tid = blockIdx.x * blockDim.x + threadIdx.x; // copy linear sequence of elements, so one-level threading - int totalThreads = gridDim.x * blockDim.x; - if (shape::order(outputShape) == 'c') { // ews == 1 always here - for (int i = tid; i < resultLength; i += totalThreads) { - auto yOffset = subArrayOffset(i, outputShape, inputShape); - *(reinterpret_cast(outputBuffer) + i) = *(reinterpret_cast(inputBuffer) + yOffset); - } - } else { - for (int i = tid; i < resultLength; i += totalThreads) { - auto xOffset = getIndexOffset_(i, outputShape); - auto yOffset = subArrayOffset(i, outputShape, inputShape); - *(reinterpret_cast(outputBuffer) + xOffset) = *(reinterpret_cast(inputBuffer) + yOffset); - } - } - +template +static __global__ void tileKernel(void const* inputBuffer, + Nd4jLong const* inputShape, + void* outputBuffer, + Nd4jLong const* outputShape, + Nd4jLong resultLength) { + //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + // Original code to transform in cuda-based + auto tid = + blockIdx.x * blockDim.x + + threadIdx.x; // copy linear sequence of elements, so one-level threading + int totalThreads = gridDim.x * blockDim.x; + if (shape::order(outputShape) == 'c') { // ews == 1 always here + for (int i = tid; i < resultLength; i += totalThreads) { + auto yOffset = subArrayOffset(i, outputShape, inputShape); + *(reinterpret_cast(outputBuffer) + i) = + *(reinterpret_cast(inputBuffer) + yOffset); } - - BUILD_SINGLE_TEMPLATE(template __global__ void tileKernel,(void const* inputBuffer, Nd4jLong const* inputShape, void* outputBuffer, Nd4jLong const* outputShape, Nd4jLong resultLength), LIBND4J_TYPES); - -//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - template - void tileKernelH(void const *inputBuffer, Nd4jLong const* inputShape, void *outputBuffer, Nd4jLong const* outputShape, Nd4jLong resultLength, cudaStream_t *stream) { - dim3 launchDims(256, 512, 8192); - tileKernel << < launchDims.x, launchDims.y, launchDims.z, *stream>>>(inputBuffer, inputShape, outputBuffer, outputShape, resultLength); + } else { + for (int i = tid; i < resultLength; i += totalThreads) { + auto xOffset = getIndexOffset_(i, outputShape); + auto yOffset = subArrayOffset(i, outputShape, inputShape); + *(reinterpret_cast(outputBuffer) + xOffset) = + *(reinterpret_cast(inputBuffer) + yOffset); } + } +} - BUILD_SINGLE_TEMPLATE(template void tileKernelH, (void const* inputBuffer, Nd4jLong const* inputShape, void* outputBuffer, Nd4jLong const* outputShape, Nd4jLong resultLength, cudaStream_t *stream), LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template __global__ void tileKernel, + (void const* inputBuffer, Nd4jLong const* inputShape, + void* outputBuffer, Nd4jLong const* outputShape, + Nd4jLong resultLength), + LIBND4J_TYPES); //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// -// enhancement for tileKernel to different input and output data types: X - output type, Y - input type - template - static __global__ void - tileKernelDouble(void const *inputBuffer, Nd4jLong const* inputShape, void *outputBuffer, Nd4jLong const* outputShape, Nd4jLong resultLength, Nd4jLong ews) { - char ordering = shape::order(outputShape); - auto tid = blockIdx.x * blockDim.x + threadIdx.x; - int totalThreads = gridDim.x * blockDim.x; +template +void tileKernelH(void const* inputBuffer, Nd4jLong const* inputShape, + void* outputBuffer, Nd4jLong const* outputShape, + Nd4jLong resultLength, cudaStream_t* stream) { + dim3 launchDims(256, 512, 8192); + tileKernel<<>>( + inputBuffer, inputShape, outputBuffer, outputShape, resultLength); +} + +BUILD_SINGLE_TEMPLATE(template void tileKernelH, + (void const* inputBuffer, Nd4jLong const* inputShape, + void* outputBuffer, Nd4jLong const* outputShape, + Nd4jLong resultLength, cudaStream_t* stream), + LIBND4J_TYPES); - if (ordering == 'c' && ews == 1) { // ews == 1 always here - for (int i = tid; i < resultLength; i += totalThreads) { - auto yOffset = subArrayOffset(i, outputShape, inputShape); - *(reinterpret_cast(outputBuffer) + i) = static_cast(*(reinterpret_cast(inputBuffer) + yOffset)); - } - } else if (ordering == 'c' && ews > 1) { - for (int i = tid; i < resultLength; i += totalThreads) { - auto yOffset = subArrayOffset(i, outputShape, inputShape); - *(reinterpret_cast(outputBuffer) + i * ews) = static_cast(*(reinterpret_cast(inputBuffer) + yOffset)); - } - } else { - - for (int i = tid; i < resultLength; i += totalThreads) { - - auto xOffset = getIndexOffset_(i, outputShape); - auto yOffset = subArrayOffset(i, outputShape, inputShape); - *(reinterpret_cast(outputBuffer) + xOffset) = static_cast(*(reinterpret_cast(inputBuffer) + yOffset)); - } - } +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +// enhancement for tileKernel to different input and output data types: X - +// output type, Y - input type +template +static __global__ void tileKernelDouble(void const* inputBuffer, + Nd4jLong const* inputShape, + void* outputBuffer, + Nd4jLong const* outputShape, + Nd4jLong resultLength, Nd4jLong ews) { + char ordering = shape::order(outputShape); + auto tid = blockIdx.x * blockDim.x + threadIdx.x; + int totalThreads = gridDim.x * blockDim.x; + + if (ordering == 'c' && ews == 1) { // ews == 1 always here + for (int i = tid; i < resultLength; i += totalThreads) { + auto yOffset = subArrayOffset(i, outputShape, inputShape); + *(reinterpret_cast(outputBuffer) + i) = + static_cast(*(reinterpret_cast(inputBuffer) + yOffset)); } - - BUILD_SINGLE_TEMPLATE_TWICE(template __global__ void tileKernelDouble, (void const* inputBuffer, Nd4jLong const* inputShape, void* outputBuffer, Nd4jLong const* outputShape, Nd4jLong resultLength, Nd4jLong ews), LIBND4J_TYPES); - - template - void tileKernelHH(void const *inputBuffer, Nd4jLong const* inputShape, void *outputBuffer, Nd4jLong const* outputShape, Nd4jLong resultLength, Nd4jLong ews, cudaStream_t *stream) { - dim3 launchDims(256, 512, 8192); - tileKernelDouble<<>>(inputBuffer, inputShape, outputBuffer, outputShape, resultLength, ews); + } else if (ordering == 'c' && ews > 1) { + for (int i = tid; i < resultLength; i += totalThreads) { + auto yOffset = subArrayOffset(i, outputShape, inputShape); + *(reinterpret_cast(outputBuffer) + i * ews) = + static_cast(*(reinterpret_cast(inputBuffer) + yOffset)); } - - BUILD_SINGLE_TEMPLATE_TWICE(template void tileKernelHH, (void const* inputBuffer, Nd4jLong const* inputShape, void* outputBuffer, Nd4jLong const* outputShape, Nd4jLong resultLength, Nd4jLong ews, cudaStream_t *stream),LIBND4J_TYPES); -} \ No newline at end of file + } else { + for (int i = tid; i < resultLength; i += totalThreads) { + auto xOffset = getIndexOffset_(i, outputShape); + auto yOffset = subArrayOffset(i, outputShape, inputShape); + *(reinterpret_cast(outputBuffer) + xOffset) = + static_cast(*(reinterpret_cast(inputBuffer) + yOffset)); + } + } +} + +BUILD_SINGLE_TEMPLATE_TWICE(template __global__ void tileKernelDouble, + (void const* inputBuffer, + Nd4jLong const* inputShape, void* outputBuffer, + Nd4jLong const* outputShape, Nd4jLong resultLength, + Nd4jLong ews), + LIBND4J_TYPES); + +template +void tileKernelHH(void const* inputBuffer, Nd4jLong const* inputShape, + void* outputBuffer, Nd4jLong const* outputShape, + Nd4jLong resultLength, Nd4jLong ews, cudaStream_t* stream) { + dim3 launchDims(256, 512, 8192); + tileKernelDouble<<>>( + inputBuffer, inputShape, outputBuffer, outputShape, resultLength, ews); +} + +BUILD_SINGLE_TEMPLATE_TWICE(template void tileKernelHH, + (void const* inputBuffer, + Nd4jLong const* inputShape, void* outputBuffer, + Nd4jLong const* outputShape, Nd4jLong resultLength, + Nd4jLong ews, cudaStream_t* stream), + LIBND4J_TYPES); +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/summarystatsreduce.cu b/libnd4j/include/loops/cuda/summarystatsreduce.cu index 521ac5b065c4..8dd289e81990 100644 --- a/libnd4j/include/loops/cuda/summarystatsreduce.cu +++ b/libnd4j/include/loops/cuda/summarystatsreduce.cu @@ -18,402 +18,415 @@ // @author raver119@gmail.com // - -#include -#include -#include -#include -#include -#include -#include -#include -#include #include #include #include +#include +#include +#include #include +#include +#include +#include +#include +#include +#include using namespace simdOps; namespace functions { - namespace summarystats { +namespace summarystats { template -void _CUDA_G summaryStatsReduceT(int op, void const* dx, Nd4jLong const* xShapeInfo, int xRank, void *extraParams, void *z, Nd4jLong const* zShapeInfo, int zRank, int *dimension, int dimensionLength, int postProcessOrNot,bool biasCorrected,int *allocationBuffer, void *reductionBuffer, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets) { - - functions::summarystats::SummaryStatsReduce::transform(op,dx,xShapeInfo,extraParams,z,zShapeInfo,dimension,dimensionLength,biasCorrected,allocationBuffer,reductionBuffer,tadOnlyShapeInfo,tadOffsets); +void _CUDA_G summaryStatsReduceT( + int op, void const* dx, Nd4jLong const* xShapeInfo, int xRank, + void* extraParams, void* z, Nd4jLong const* zShapeInfo, int zRank, + int* dimension, int dimensionLength, int postProcessOrNot, + bool biasCorrected, int* allocationBuffer, void* reductionBuffer, + Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets) { + functions::summarystats::SummaryStatsReduce::transform( + op, dx, xShapeInfo, extraParams, z, zShapeInfo, dimension, + dimensionLength, biasCorrected, allocationBuffer, reductionBuffer, + tadOnlyShapeInfo, tadOffsets); } - /** - * - * @param sPartialsRef - * @param tid - * @param extraParams - */ - template - template - _CUDA_D void SummaryStatsReduce::aggregatePartials(SummaryStatsData **sPartialsRef, Nd4jLong tid, Nd4jLong numElements, void *vextraParams) { - // start the shared memory loop on the next power of 2 less - // than the block size. If block size is not a power of 2, - // accumulate the intermediate sums in the remainder range. - auto extraParams = static_cast(vextraParams); - SummaryStatsData *sPartials = *sPartialsRef; - Nd4jLong floorPow2 = blockDim.x; - - if (floorPow2 & (floorPow2 - 1)) { - while (floorPow2 & (floorPow2 - 1)) { - floorPow2 &= floorPow2 - 1; - } - - if (tid >= floorPow2) { - SummaryStatsData prev = sPartials[tid - floorPow2]; - SummaryStatsData curr = sPartials[tid]; - sPartials[tid - floorPow2] = update(prev, curr, extraParams); - } - __syncthreads(); - } - - for (Nd4jLong activeThreads = floorPow2 >> 1; activeThreads; activeThreads >>= 1) { - if (tid < activeThreads && tid + activeThreads < numElements) { - SummaryStatsData curr = sPartials[tid]; - SummaryStatsData next = sPartials[tid + activeThreads]; - sPartials[tid] = update(curr, next, extraParams); - } - __syncthreads(); - } - }; - - /** - * @param n n is the number of - * elements to loop through - * @param dx the data to operate on - * @param xVectorInfo the meta data for the vector: - * 0 is the offset - * 1 is the increment/stride - * 2 is the real length of the buffer (n and dx.length won't always be the same) - * 3 is the element wise stride for the buffer - * 4 is the number of elements it takes to get to the next row/column/tensor - * @param gpuInformation - * 0 is the block size - * 1 is the grid size - * 2 is the shared memory size - * @param problemDefinition - * 0 is the number of elements per vector - * 1 is the number of vectors - */ - template - template - _CUDA_D void SummaryStatsReduce::transform(void const* vx, Nd4jLong const* xShapeInfo, - void *vextraParams, - void *vz, Nd4jLong const* zShapeInfo, - int *dimension, int dimensionLength, - int postProcessOrNot, - int *allocationBuffer, void *vreductionBuffer, - Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets) { - - auto dx = static_cast(vx); - auto z = static_cast(vz); - auto extraParams = static_cast(vextraParams); - auto reductionBuffer = static_cast(vreductionBuffer); - - int tid = blockIdx.x * blockDim.x + threadIdx.x; - __shared__ volatile int resultScalar; - - __shared__ int xElementWiseStride; - - int numElements = blockDim.x; - //shared memory space for storing intermediate results - __shared__ SummaryStatsData *sPartials; - if(threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - sPartials = reinterpret_cast*>(shmem); - } - __syncthreads(); - - Z startingVal = startingValue(dx); - - SummaryStatsData val; - val.initWithValue(startingVal); - val.n = 0; - sPartials[threadIdx.x] = val; - - - //length for the tad - __shared__ volatile int xLength; - - __shared__ volatile int resultLength; - - - SummaryStatsData reduction; - reduction.initWithValue(0.0); - reduction.n = 0; - if (threadIdx.x == 0) { - if (zShapeInfo != nullptr) - resultLength = shape::length(zShapeInfo); - else resultLength = 1; - - - if (dimensionLength == 1) { - if (resultLength == 1 && (dimension == nullptr || dimension[0] == MAX_DIMENSION)) - resultScalar = 1; - else - resultScalar = 0; - } - else - resultScalar = 0; - - if (resultLength == 1) - resultScalar = 1; - - auto xStride = shape::stride(xShapeInfo); - auto xOrder = shape::order(xShapeInfo); - - if (dimension != nullptr && (dimension[0] != MAX_DIMENSION && dimensionLength == 1)) { - xElementWiseStride = xStride[dimension[0]]; - } - else { - xElementWiseStride = shape::elementWiseStride(xShapeInfo); - } - - - xLength = shape::length(xShapeInfo); - - - } - __syncthreads(); - if (!resultScalar) { - - __shared__ int tadLength; - __shared__ int tadEWS; - __shared__ int numTads; - - if (threadIdx.x == 0) { - tadLength = shape::length(tadOnlyShapeInfo);//shape::tadLength(xShapeInfo, dimension, dimensionLength); - tadEWS = shape::elementWiseStride(tadOnlyShapeInfo); - numTads = shape::length(xShapeInfo) / tadLength; - } - __syncthreads(); - - if (tadEWS == 0) { - - for (int r = blockIdx.x; r < numTads; r += gridDim.x) { - auto tadOffsetForBlock = tadOffsets[r]; - - val.initWithValue(startingVal); - val.n = 0; - sPartials[threadIdx.x] = val; - - for (int i = threadIdx.x; i < tadLength; i += blockDim.x) { - auto xOffset = tadOffsetForBlock + shape::getIndexOffset(i, tadOnlyShapeInfo); - SummaryStatsData indexVal2; - indexVal2.initWithValue(dx[xOffset]); - - sPartials[threadIdx.x] = update(sPartials[threadIdx.x], OpType::op(indexVal2, extraParams), extraParams); - } - __syncthreads(); - aggregatePartials(&sPartials, threadIdx.x, sd::math::nd4j_min(blockDim.x, tadLength), extraParams); - - __syncthreads(); - if (threadIdx.x == 0) { - z[r] = OpType::getValue(postProcessOrNot, sPartials[threadIdx.x]); - } - __syncthreads(); - } - } - else { - - for (int i = blockIdx.x; i < numTads; i += gridDim.x) { - auto tadOffsetForBlock = tadOffsets[i]; - - val.initWithValue(startingVal); - val.n = 0; - sPartials[threadIdx.x] = val; - - for (int x = threadIdx.x; x < tadLength; x += blockDim.x) { - auto indexX = tadOffsetForBlock + x * tadEWS; - SummaryStatsData indexVal2; - indexVal2.initWithValue(dx[indexX]); - sPartials[threadIdx.x] = update(sPartials[threadIdx.x], OpType::op(indexVal2, extraParams), extraParams); - } - - __syncthreads(); - aggregatePartials(&sPartials, threadIdx.x, sd::math::nd4j_min(blockDim.x, tadLength), extraParams); - - __syncthreads(); - if (threadIdx.x == 0) { - z[i] = OpType::getValue(postProcessOrNot, sPartials[threadIdx.x]); //postProcess(sPartials[0],tadLength ,extraParams); - } - } - } - } - else if (resultScalar) { - __shared__ int n; - if (threadIdx.x == 0) { - xElementWiseStride = shape::elementWiseStride(xShapeInfo); - n = shape::length(xShapeInfo); - } - __syncthreads(); - - if (xElementWiseStride >= 1) { - for (Nd4jLong i = tid; i < n; i += (blockDim.x * gridDim.x)) { - SummaryStatsData indexVal2; - indexVal2.initWithValue(dx[i * xElementWiseStride]); - reduction = update(reduction, indexVal2, extraParams); - } - } - else { - - for (Nd4jLong i = tid; i < n; i += blockDim.x * gridDim.x) { - - auto offset = shape::getIndexOffset(i, xShapeInfo); - SummaryStatsData indexVal2; - indexVal2.initWithValue(dx[offset]); - reduction = update(reduction, indexVal2, extraParams); - } - } - sPartials[threadIdx.x] = reduction; - - __syncthreads(); - aggregatePartials(&sPartials, threadIdx.x, blockDim.x, extraParams); - __syncthreads(); - - if (gridDim.x > 1) { - __shared__ bool amLast; - unsigned int *tc = (unsigned int *)reductionBuffer; - tid = threadIdx.x; - if (threadIdx.x == 0) { - SummaryStatsData *pBuffer = (SummaryStatsData*) reductionBuffer; - pBuffer[blockIdx.x] = sPartials[0]; - } - __threadfence(); - __syncthreads(); - - if (tid == 0) { - unsigned int ticket = atomicInc(&tc[16384], gridDim.x); - amLast = (ticket == gridDim.x - 1); - } - - __syncthreads(); - - if (amLast) { - tc[16384] = 0; - SummaryStatsData* pBuffer = (SummaryStatsData*) reductionBuffer; - - Z startingVal = startingValue(dx); - - SummaryStatsData val; - val.initWithValue(startingVal); - val.n = 0; - sPartials[threadIdx.x] = val; - - for (int i = threadIdx.x; i < gridDim.x; i += blockDim.x) { - sPartials[threadIdx.x] = update(sPartials[threadIdx.x], pBuffer[i], extraParams); - } - - __syncthreads(); - aggregatePartials(&sPartials, threadIdx.x, gridDim.x, extraParams); - __syncthreads(); - - if (tid == 0) { - z[0] = OpType::getValue(postProcessOrNot, sPartials[0]); - } - } - } - else { - if (tid == 0) { - unsigned int *tc = (unsigned *)reductionBuffer; - tc[16384] = 0; - z[0] = z[0] = OpType::getValue(postProcessOrNot, sPartials[0]); - } - } - } - }; - - - template - _CUDA_D void SummaryStatsReduce::transform(const int opNum, void const* dx, Nd4jLong const* xShapeInfo, void *extraParams, void *z, Nd4jLong const* zShapeInfo, int *dimension, int dimensionLength, int postProcessOrNot, int *allocationBuffer, void *reductionBuffer, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets) { - DISPATCH_BY_OPNUM_TT(transform, PARAMS(dx, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength, postProcessOrNot, allocationBuffer, reductionBuffer, tadOnlyShapeInfo, tadOffsets), SUMMARY_STATS_OPS); - }; - - - template - _CUDA_H void SummaryStatsReduce::execSummaryStatsReduceScalar(dim3& launchDims, cudaStream_t *stream, int opNum, void const* vx, Nd4jLong const* xShapeInfo, Nd4jLong const* hxShapeInfo, void *vextraParams, void *vz, Nd4jLong const* zShapeInfo, Nd4jLong const* hzShapeInfo, Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, bool biasCorrected, void *reductionBuffer) { - - auto x = static_cast(vx); - auto extraParams = static_cast(vextraParams); - auto z = reinterpret_cast(vz); - auto reductionPointerA = reinterpret_cast(reductionBuffer); - - if (sd::Environment::getInstance().isDebugAndVerbose()) - printf("D16 opNum:[%i]\n", opNum); - - summaryStatsReduceT<<>>( - opNum, - x, - xShapeInfo, shape::rank(hxShapeInfo), - extraParams, - z, - zShapeInfo, shape::rank(hzShapeInfo), - nullptr, - 1, - 1,biasCorrected, nullptr, reductionPointerA, tadShapeInfo, tadOffsets); - - // this is blocking method since method should return scalar - sd::DebugHelper::checkErrorCode(stream, "execSSReduceScalar(...) failed"); - } +/** + * + * @param sPartialsRef + * @param tid + * @param extraParams + */ +template +template +_CUDA_D void SummaryStatsReduce::aggregatePartials( + SummaryStatsData** sPartialsRef, Nd4jLong tid, Nd4jLong numElements, + void* vextraParams) { + // start the shared memory loop on the next power of 2 less + // than the block size. If block size is not a power of 2, + // accumulate the intermediate sums in the remainder range. + auto extraParams = static_cast(vextraParams); + SummaryStatsData* sPartials = *sPartialsRef; + Nd4jLong floorPow2 = blockDim.x; + + if (floorPow2 & (floorPow2 - 1)) { + while (floorPow2 & (floorPow2 - 1)) { + floorPow2 &= floorPow2 - 1; + } - template - _CUDA_H void SummaryStatsReduce::execSummaryStatsReduce(dim3& launchDims, cudaStream_t *stream, int opNum, void const* vx, Nd4jLong const* xShapeInfo, Nd4jLong const* hxShapeInfo, void *vextraParams, void *vz, Nd4jLong const* zShapeInfo, Nd4jLong const* hzShapeInfo, Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, bool biasCorrected, void *reductionBuffer) { + if (tid >= floorPow2) { + SummaryStatsData prev = sPartials[tid - floorPow2]; + SummaryStatsData curr = sPartials[tid]; + sPartials[tid - floorPow2] = update(prev, curr, extraParams); + } + __syncthreads(); + } + + for (Nd4jLong activeThreads = floorPow2 >> 1; activeThreads; + activeThreads >>= 1) { + if (tid < activeThreads && tid + activeThreads < numElements) { + SummaryStatsData curr = sPartials[tid]; + SummaryStatsData next = sPartials[tid + activeThreads]; + sPartials[tid] = update(curr, next, extraParams); + } + __syncthreads(); + } +}; + +/** + * @param n n is the number of + * elements to loop through + * @param dx the data to operate on + * @param xVectorInfo the meta data for the vector: + * 0 is the offset + * 1 is the increment/stride + * 2 is the real length of the buffer (n and + * dx.length won't always be the same) 3 is the element wise stride for the + * buffer 4 is the number of elements it takes to get to the next + * row/column/tensor + * @param gpuInformation + * 0 is the block size + * 1 is the grid size + * 2 is the shared memory size + * @param problemDefinition + * 0 is the number of elements per vector + * 1 is the number of vectors + */ +template +template +_CUDA_D void SummaryStatsReduce::transform( + void const* vx, Nd4jLong const* xShapeInfo, void* vextraParams, void* vz, + Nd4jLong const* zShapeInfo, int* dimension, int dimensionLength, + int postProcessOrNot, int* allocationBuffer, void* vreductionBuffer, + Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets) { + auto dx = static_cast(vx); + auto z = static_cast(vz); + auto extraParams = static_cast(vextraParams); + auto reductionBuffer = static_cast(vreductionBuffer); + + int tid = blockIdx.x * blockDim.x + threadIdx.x; + __shared__ volatile int resultScalar; + + __shared__ int xElementWiseStride; + + int numElements = blockDim.x; + // shared memory space for storing intermediate results + __shared__ SummaryStatsData* sPartials; + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sPartials = reinterpret_cast*>(shmem); + } + __syncthreads(); + + Z startingVal = startingValue(dx); + + SummaryStatsData val; + val.initWithValue(startingVal); + val.n = 0; + sPartials[threadIdx.x] = val; + + // length for the tad + __shared__ volatile int xLength; + + __shared__ volatile int resultLength; + + SummaryStatsData reduction; + reduction.initWithValue(0.0); + reduction.n = 0; + if (threadIdx.x == 0) { + if (zShapeInfo != nullptr) + resultLength = shape::length(zShapeInfo); + else + resultLength = 1; + + if (dimensionLength == 1) { + if (resultLength == 1 && + (dimension == nullptr || dimension[0] == MAX_DIMENSION)) + resultScalar = 1; + else + resultScalar = 0; + } else + resultScalar = 0; + + if (resultLength == 1) resultScalar = 1; + + auto xStride = shape::stride(xShapeInfo); + auto xOrder = shape::order(xShapeInfo); + + if (dimension != nullptr && + (dimension[0] != MAX_DIMENSION && dimensionLength == 1)) { + xElementWiseStride = xStride[dimension[0]]; + } else { + xElementWiseStride = shape::elementWiseStride(xShapeInfo); + } - auto x = static_cast(vx); - auto z = static_cast(vz); - auto extraParams = static_cast(vextraParams); + xLength = shape::length(xShapeInfo); + } + __syncthreads(); + if (!resultScalar) { + __shared__ int tadLength; + __shared__ int tadEWS; + __shared__ int numTads; + + if (threadIdx.x == 0) { + tadLength = + shape::length(tadOnlyShapeInfo); // shape::tadLength(xShapeInfo, + // dimension, dimensionLength); + tadEWS = shape::elementWiseStride(tadOnlyShapeInfo); + numTads = shape::length(xShapeInfo) / tadLength; + } + __syncthreads(); - if (sd::Environment::getInstance().isDebugAndVerbose()) - printf("F17 opNum:[%i]\n", opNum); + if (tadEWS == 0) { + for (int r = blockIdx.x; r < numTads; r += gridDim.x) { + auto tadOffsetForBlock = tadOffsets[r]; - auto reductionPointerA = reinterpret_cast(reductionBuffer); + val.initWithValue(startingVal); + val.n = 0; + sPartials[threadIdx.x] = val; - summaryStatsReduceT<<>>( - opNum, - x, - xShapeInfo, shape::rank(hxShapeInfo), - extraParams, - z, - zShapeInfo, shape::rank(hzShapeInfo), - nullptr, - 1, - 1,biasCorrected, nullptr, reductionPointerA, tadShapeInfo, tadOffsets); + for (int i = threadIdx.x; i < tadLength; i += blockDim.x) { + auto xOffset = + tadOffsetForBlock + shape::getIndexOffset(i, tadOnlyShapeInfo); + SummaryStatsData indexVal2; + indexVal2.initWithValue(dx[xOffset]); - DEBUG_KERNEL(stream, opNum); + sPartials[threadIdx.x] = + update(sPartials[threadIdx.x], OpType::op(indexVal2, extraParams), + extraParams); + } + __syncthreads(); + aggregatePartials( + &sPartials, threadIdx.x, + sd::math::nd4j_min(blockDim.x, tadLength), extraParams); + + __syncthreads(); + if (threadIdx.x == 0) { + z[r] = OpType::getValue(postProcessOrNot, sPartials[threadIdx.x]); + } + __syncthreads(); + } + } else { + for (int i = blockIdx.x; i < numTads; i += gridDim.x) { + auto tadOffsetForBlock = tadOffsets[i]; + + val.initWithValue(startingVal); + val.n = 0; + sPartials[threadIdx.x] = val; + + for (int x = threadIdx.x; x < tadLength; x += blockDim.x) { + auto indexX = tadOffsetForBlock + x * tadEWS; + SummaryStatsData indexVal2; + indexVal2.initWithValue(dx[indexX]); + sPartials[threadIdx.x] = + update(sPartials[threadIdx.x], OpType::op(indexVal2, extraParams), + extraParams); } + __syncthreads(); + aggregatePartials( + &sPartials, threadIdx.x, + sd::math::nd4j_min(blockDim.x, tadLength), extraParams); + + __syncthreads(); + if (threadIdx.x == 0) { + z[i] = OpType::getValue( + postProcessOrNot, + sPartials[threadIdx.x]); // postProcess(sPartials[0],tadLength + // ,extraParams); + } + } + } + } else if (resultScalar) { + __shared__ int n; + if (threadIdx.x == 0) { + xElementWiseStride = shape::elementWiseStride(xShapeInfo); + n = shape::length(xShapeInfo); + } + __syncthreads(); + + if (xElementWiseStride >= 1) { + for (Nd4jLong i = tid; i < n; i += (blockDim.x * gridDim.x)) { + SummaryStatsData indexVal2; + indexVal2.initWithValue(dx[i * xElementWiseStride]); + reduction = update(reduction, indexVal2, extraParams); + } + } else { + for (Nd4jLong i = tid; i < n; i += blockDim.x * gridDim.x) { + auto offset = shape::getIndexOffset(i, xShapeInfo); + SummaryStatsData indexVal2; + indexVal2.initWithValue(dx[offset]); + reduction = update(reduction, indexVal2, extraParams); + } + } + sPartials[threadIdx.x] = reduction; + + __syncthreads(); + aggregatePartials(&sPartials, threadIdx.x, blockDim.x, extraParams); + __syncthreads(); + + if (gridDim.x > 1) { + __shared__ bool amLast; + unsigned int* tc = (unsigned int*)reductionBuffer; + tid = threadIdx.x; + if (threadIdx.x == 0) { + SummaryStatsData* pBuffer = (SummaryStatsData*)reductionBuffer; + pBuffer[blockIdx.x] = sPartials[0]; + } + __threadfence(); + __syncthreads(); + + if (tid == 0) { + unsigned int ticket = atomicInc(&tc[16384], gridDim.x); + amLast = (ticket == gridDim.x - 1); + } + + __syncthreads(); + + if (amLast) { + tc[16384] = 0; + SummaryStatsData* pBuffer = (SummaryStatsData*)reductionBuffer; + + Z startingVal = startingValue(dx); + + SummaryStatsData val; + val.initWithValue(startingVal); + val.n = 0; + sPartials[threadIdx.x] = val; + + for (int i = threadIdx.x; i < gridDim.x; i += blockDim.x) { + sPartials[threadIdx.x] = + update(sPartials[threadIdx.x], pBuffer[i], extraParams); + } - template - _CUDA_H void SummaryStatsReduce::execSummaryStatsReduce(dim3& launchDims, cudaStream_t *stream, int opNum, void const* vx, Nd4jLong const* xShapeInfo, Nd4jLong const* hxShapeInfo, void *vextraParams, void *vz, Nd4jLong const* zShapeInfo, Nd4jLong const* hzShapeInfo, int *dimension, int dimensionLength, Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, bool biasCorrected, void *reductionBuffer) { - - auto x = static_cast(vx); - auto z = static_cast(vz); - auto extraParams = static_cast(vextraParams); + __syncthreads(); + aggregatePartials(&sPartials, threadIdx.x, gridDim.x, + extraParams); + __syncthreads(); - if (sd::Environment::getInstance().isDebugAndVerbose()) - printf("D18 opNum:[%i]\n", opNum); + if (tid == 0) { + z[0] = OpType::getValue(postProcessOrNot, sPartials[0]); + } + } + } else { + if (tid == 0) { + unsigned int* tc = (unsigned*)reductionBuffer; + tc[16384] = 0; + z[0] = z[0] = OpType::getValue(postProcessOrNot, sPartials[0]); + } + } + } +}; + +template +_CUDA_D void SummaryStatsReduce::transform( + const int opNum, void const* dx, Nd4jLong const* xShapeInfo, + void* extraParams, void* z, Nd4jLong const* zShapeInfo, int* dimension, + int dimensionLength, int postProcessOrNot, int* allocationBuffer, + void* reductionBuffer, Nd4jLong const* tadOnlyShapeInfo, + Nd4jLong const* tadOffsets) { + DISPATCH_BY_OPNUM_TT( + transform, + PARAMS(dx, xShapeInfo, extraParams, z, zShapeInfo, dimension, + dimensionLength, postProcessOrNot, allocationBuffer, + reductionBuffer, tadOnlyShapeInfo, tadOffsets), + SUMMARY_STATS_OPS); +}; - summaryStatsReduceT<<>>( - opNum, - x, - xShapeInfo, shape::rank(hxShapeInfo), - extraParams, - z, - zShapeInfo, shape::rank(hzShapeInfo), - dimension, - dimensionLength, - 1, biasCorrected, nullptr, reinterpret_cast(reductionBuffer), tadShapeInfo, tadOffsets); +template +_CUDA_H void SummaryStatsReduce::execSummaryStatsReduceScalar( + dim3& launchDims, cudaStream_t* stream, int opNum, void const* vx, + Nd4jLong const* xShapeInfo, Nd4jLong const* hxShapeInfo, void* vextraParams, + void* vz, Nd4jLong const* zShapeInfo, Nd4jLong const* hzShapeInfo, + Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, + bool biasCorrected, void* reductionBuffer) { + auto x = static_cast(vx); + auto extraParams = static_cast(vextraParams); + auto z = reinterpret_cast(vz); + auto reductionPointerA = reinterpret_cast(reductionBuffer); + + if (sd::Environment::getInstance().isDebugAndVerbose()) + printf("D16 opNum:[%i]\n", opNum); + + summaryStatsReduceT + <<>>( + opNum, x, xShapeInfo, shape::rank(hxShapeInfo), extraParams, z, + zShapeInfo, shape::rank(hzShapeInfo), nullptr, 1, 1, biasCorrected, + nullptr, reductionPointerA, tadShapeInfo, tadOffsets); + + // this is blocking method since method should return scalar + sd::DebugHelper::checkErrorCode(stream, "execSSReduceScalar(...) failed"); +} - DEBUG_KERNEL(stream, opNum); - } +template +_CUDA_H void SummaryStatsReduce::execSummaryStatsReduce( + dim3& launchDims, cudaStream_t* stream, int opNum, void const* vx, + Nd4jLong const* xShapeInfo, Nd4jLong const* hxShapeInfo, void* vextraParams, + void* vz, Nd4jLong const* zShapeInfo, Nd4jLong const* hzShapeInfo, + Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, + bool biasCorrected, void* reductionBuffer) { + auto x = static_cast(vx); + auto z = static_cast(vz); + auto extraParams = static_cast(vextraParams); + + if (sd::Environment::getInstance().isDebugAndVerbose()) + printf("F17 opNum:[%i]\n", opNum); + + auto reductionPointerA = reinterpret_cast(reductionBuffer); + + summaryStatsReduceT + <<>>( + opNum, x, xShapeInfo, shape::rank(hxShapeInfo), extraParams, z, + zShapeInfo, shape::rank(hzShapeInfo), nullptr, 1, 1, biasCorrected, + nullptr, reductionPointerA, tadShapeInfo, tadOffsets); + + DEBUG_KERNEL(stream, opNum); +} +template +_CUDA_H void SummaryStatsReduce::execSummaryStatsReduce( + dim3& launchDims, cudaStream_t* stream, int opNum, void const* vx, + Nd4jLong const* xShapeInfo, Nd4jLong const* hxShapeInfo, void* vextraParams, + void* vz, Nd4jLong const* zShapeInfo, Nd4jLong const* hzShapeInfo, + int* dimension, int dimensionLength, Nd4jLong const* tadShapeInfo, + Nd4jLong const* tadOffsets, bool biasCorrected, void* reductionBuffer) { + auto x = static_cast(vx); + auto z = static_cast(vz); + auto extraParams = static_cast(vextraParams); + + if (sd::Environment::getInstance().isDebugAndVerbose()) + printf("D18 opNum:[%i]\n", opNum); + + summaryStatsReduceT + <<>>( + opNum, x, xShapeInfo, shape::rank(hxShapeInfo), extraParams, z, + zShapeInfo, shape::rank(hzShapeInfo), dimension, dimensionLength, 1, + biasCorrected, nullptr, reinterpret_cast(reductionBuffer), + tadShapeInfo, tadOffsets); + + DEBUG_KERNEL(stream, opNum); +} - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT SummaryStatsReduce, , LIBND4J_TYPES, FLOAT_TYPES); - } -} \ No newline at end of file +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT SummaryStatsReduce, , + LIBND4J_TYPES, FLOAT_TYPES); +} // namespace summarystats +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/transform/transform_any.cu b/libnd4j/include/loops/cuda/transform/transform_any.cu index 8b00b28fe925..7fcb39b04b04 100644 --- a/libnd4j/include/loops/cuda/transform/transform_any.cu +++ b/libnd4j/include/loops/cuda/transform/transform_any.cu @@ -18,118 +18,112 @@ // @author raver119@gmail.com // -#include +#include +#include #include -#include +#include #include - -#include -#include +#include using namespace simdOps; - template -__global__ void transformAnySimple( - const void *x, const Nd4jLong *xShapeInfo, int xRank, - void *params, - void *z, const Nd4jLong *zShapeInfo, int zRank, - int *allocationPointer, void *reductionPointer, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { - - functions::transform::TransformAny::template transformCuda(x,xShapeInfo,params,z,zShapeInfo,allocationPointer,reductionPointer,tadShapeInfo, tadOffsets); +__global__ void transformAnySimple(const void *x, const Nd4jLong *xShapeInfo, + int xRank, void *params, void *z, + const Nd4jLong *zShapeInfo, int zRank, + int *allocationPointer, + void *reductionPointer, + const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets) { + functions::transform::TransformAny::template transformCuda( + x, xShapeInfo, params, z, zShapeInfo, allocationPointer, reductionPointer, + tadShapeInfo, tadOffsets); } - namespace functions { - namespace transform { - - template - _CUDA_H void TransformAny::executeTransformShaped( - dim3 launchDims, cudaStream_t *stream, - const int opNum, - const void *x, const Nd4jLong *xShape, int xRank, - void *extraParams, - void *z, const Nd4jLong *zShape, int zRank, - int *allocationPointer, void *reductionPointer, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { - DISPATCH_BY_OPNUM_TT(intermediateShaped, PARAMS(launchDims, stream, x, xShape, xRank, extraParams, z, zShape, zRank, allocationPointer, reductionPointer, tadShapeInfo, tadOffsets), TRANSFORM_ANY_OPS); - - DEBUG_KERNEL(stream, opNum); - } - - - template - template - __device__ void TransformAny::transformCuda( - const void *vx, const Nd4jLong *xShapeInfo, - void *vparams, - void *vz, const Nd4jLong *zShapeInfo, - int *allocationPointer, void *vreductionPointer, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { - - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - auto params = reinterpret_cast(vparams); - auto reductionPointer = reinterpret_cast(vreductionPointer); - - __shared__ Nd4jLong xEws; - __shared__ Nd4jLong zEws; - __shared__ char xOrder; - __shared__ char zOrder; - __shared__ Nd4jLong length; - - if (threadIdx.x == 0) { - - xEws = shape::elementWiseStride(xShapeInfo); - zEws = shape::elementWiseStride(zShapeInfo); - xOrder = shape::order(xShapeInfo); - zOrder = shape::order(zShapeInfo); - length = shape::length(xShapeInfo); - } - __syncthreads(); - - auto tid = blockIdx.x * blockDim.x + threadIdx.x; - int totalThreads = gridDim.x * blockDim.x; - - if(xEws > 0 && zEws > 0 && xOrder == zOrder && xOrder == 'c') { - - for (int i = tid; i < length; i += totalThreads) - z[i * zEws] = OpType::op(x[i * xEws], params); - } - else { - if(vx == vz) { - for (Nd4jLong i = tid; i < length; i+= totalThreads) { - auto xOffset = shape::getIndexOffset(i, xShapeInfo); - z[xOffset] = OpType::op(x[xOffset], params); - } - } - else { - for (Nd4jLong i = tid; i < length; i+= totalThreads) { - auto xOffset = shape::getIndexOffset(i, xShapeInfo); - auto zOffset = shape::getIndexOffset(i, zShapeInfo); - z[zOffset] = OpType::op(x[xOffset], params); - } - } - } - }; - - - template - template - _CUDA_H void TransformAny::intermediateShaped( - dim3 launchDims, cudaStream_t *stream, - const void *x, const Nd4jLong *xShape, int xRank, - void *extraParams, - void *z, const Nd4jLong *zShape, int zRank, - int *allocationPointer, void *reductionPointer, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { - - transformAnySimple<<>>(x, xShape, xRank, extraParams, z, zShape, zRank, allocationPointer, reductionPointer, tadShapeInfo, tadOffsets); - - sd::DebugHelper::checkErrorCode(stream, "transformAny(...) failed"); - } +namespace transform { + +template +_CUDA_H void TransformAny::executeTransformShaped( + dim3 launchDims, cudaStream_t *stream, const int opNum, const void *x, + const Nd4jLong *xShape, int xRank, void *extraParams, void *z, + const Nd4jLong *zShape, int zRank, int *allocationPointer, + void *reductionPointer, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets) { + DISPATCH_BY_OPNUM_TT(intermediateShaped, + PARAMS(launchDims, stream, x, xShape, xRank, extraParams, + z, zShape, zRank, allocationPointer, + reductionPointer, tadShapeInfo, tadOffsets), + TRANSFORM_ANY_OPS); + + DEBUG_KERNEL(stream, opNum); +} - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT TransformAny, , LIBND4J_TYPES, LIBND4J_TYPES); +template +template +__device__ void TransformAny::transformCuda( + const void *vx, const Nd4jLong *xShapeInfo, void *vparams, void *vz, + const Nd4jLong *zShapeInfo, int *allocationPointer, void *vreductionPointer, + const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + auto params = reinterpret_cast(vparams); + auto reductionPointer = reinterpret_cast(vreductionPointer); + + __shared__ Nd4jLong xEws; + __shared__ Nd4jLong zEws; + __shared__ char xOrder; + __shared__ char zOrder; + __shared__ Nd4jLong length; + + if (threadIdx.x == 0) { + xEws = shape::elementWiseStride(xShapeInfo); + zEws = shape::elementWiseStride(zShapeInfo); + xOrder = shape::order(xShapeInfo); + zOrder = shape::order(zShapeInfo); + length = shape::length(xShapeInfo); + } + __syncthreads(); + + auto tid = blockIdx.x * blockDim.x + threadIdx.x; + int totalThreads = gridDim.x * blockDim.x; + + if (xEws > 0 && zEws > 0 && xOrder == zOrder && xOrder == 'c') { + for (int i = tid; i < length; i += totalThreads) + z[i * zEws] = OpType::op(x[i * xEws], params); + } else { + if (vx == vz) { + for (Nd4jLong i = tid; i < length; i += totalThreads) { + auto xOffset = shape::getIndexOffset(i, xShapeInfo); + z[xOffset] = OpType::op(x[xOffset], params); + } + } else { + for (Nd4jLong i = tid; i < length; i += totalThreads) { + auto xOffset = shape::getIndexOffset(i, xShapeInfo); + auto zOffset = shape::getIndexOffset(i, zShapeInfo); + z[zOffset] = OpType::op(x[xOffset], params); + } } + } +}; + +template +template +_CUDA_H void TransformAny::intermediateShaped( + dim3 launchDims, cudaStream_t *stream, const void *x, + const Nd4jLong *xShape, int xRank, void *extraParams, void *z, + const Nd4jLong *zShape, int zRank, int *allocationPointer, + void *reductionPointer, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets) { + transformAnySimple + <<>>( + x, xShape, xRank, extraParams, z, zShape, zRank, allocationPointer, + reductionPointer, tadShapeInfo, tadOffsets); + + sd::DebugHelper::checkErrorCode(stream, "transformAny(...) failed"); } + +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT TransformAny, , LIBND4J_TYPES, + LIBND4J_TYPES); +} // namespace transform +} // namespace functions diff --git a/libnd4j/include/loops/cuda/transform/transform_bool.cu b/libnd4j/include/loops/cuda/transform/transform_bool.cu index f9526d296893..791a48097a6b 100644 --- a/libnd4j/include/loops/cuda/transform/transform_bool.cu +++ b/libnd4j/include/loops/cuda/transform/transform_bool.cu @@ -18,123 +18,118 @@ // @author raver119@gmail.com // -#include +#include +#include #include -#include +#include #include - -#include -#include +#include using namespace simdOps; - template -__global__ void transformBoolSimple( - const void *x, const Nd4jLong *xShapeInfo, int xRank, - void *params, - void *z, const Nd4jLong *zShapeInfo, int zRank, - int *allocationPointer, void *reductionPointer, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { - - functions::transform::TransformBool::template transformCuda(x,xShapeInfo,params,z,zShapeInfo,allocationPointer,reductionPointer,tadShapeInfo, tadOffsets); +__global__ void transformBoolSimple(const void *x, const Nd4jLong *xShapeInfo, + int xRank, void *params, void *z, + const Nd4jLong *zShapeInfo, int zRank, + int *allocationPointer, + void *reductionPointer, + const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets) { + functions::transform::TransformBool::template transformCuda( + x, xShapeInfo, params, z, zShapeInfo, allocationPointer, reductionPointer, + tadShapeInfo, tadOffsets); } - namespace functions { - namespace transform { - - template - _CUDA_H void TransformBool::executeTransformShaped( - dim3 launchDims, cudaStream_t *stream, - const int opNum, - const void *x, const Nd4jLong *xShape, int xRank, - void *extraParams, - void *z, const Nd4jLong *zShape, int zRank, - int *allocationPointer, void *reductionPointer, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { - - DISPATCH_BY_OPNUM_TT(intermediateShaped, PARAMS(launchDims, stream, x, xShape, xRank, extraParams, z, zShape, zRank, allocationPointer, reductionPointer, tadShapeInfo, tadOffsets), TRANSFORM_BOOL_OPS); +namespace transform { + +template +_CUDA_H void TransformBool::executeTransformShaped( + dim3 launchDims, cudaStream_t *stream, const int opNum, const void *x, + const Nd4jLong *xShape, int xRank, void *extraParams, void *z, + const Nd4jLong *zShape, int zRank, int *allocationPointer, + void *reductionPointer, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets) { + DISPATCH_BY_OPNUM_TT(intermediateShaped, + PARAMS(launchDims, stream, x, xShape, xRank, extraParams, + z, zShape, zRank, allocationPointer, + reductionPointer, tadShapeInfo, tadOffsets), + TRANSFORM_BOOL_OPS); + + DEBUG_KERNEL(stream, opNum); +} - DEBUG_KERNEL(stream, opNum); +template +template +__device__ void TransformBool::transformCuda( + const void *vx, const Nd4jLong *xShapeInfo, void *vparams, void *vz, + const Nd4jLong *zShapeInfo, int *allocationPointer, void *vreductionPointer, + const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { + auto x = static_cast(vx); + auto z = static_cast(vz); + auto params = static_cast(vparams); + auto reductionPointer = static_cast(vreductionPointer); + + if (OpType::requiresSpecial) { + OpType::execSpecialCuda(x, xShapeInfo, z, zShapeInfo, params, + allocationPointer, reductionPointer, tadShapeInfo, + tadOffsets); + return; + } else { + __shared__ Nd4jLong xEws; + __shared__ Nd4jLong zEws; + __shared__ char xOrder; + __shared__ char zOrder; + __shared__ Nd4jLong length; + + if (threadIdx.x == 0) { + xEws = shape::elementWiseStride(xShapeInfo); + zEws = shape::elementWiseStride(zShapeInfo); + xOrder = shape::order(xShapeInfo); + zOrder = shape::order(zShapeInfo); + length = shape::length(xShapeInfo); + } + __syncthreads(); + + auto tid = blockIdx.x * blockDim.x + threadIdx.x; + int totalThreads = gridDim.x * blockDim.x; + + if (xEws > 0 && zEws > 0 && xOrder == zOrder && xOrder == 'c') { + for (int i = tid; i < length; i += totalThreads) + z[i * zEws] = OpType::op(x[i * xEws], params); + } else { + if (vx == vz) { + for (Nd4jLong i = tid; i < length; i += totalThreads) { + auto xOffset = shape::getIndexOffset(i, xShapeInfo); + z[xOffset] = OpType::op(x[xOffset], params); } - - - template - template - __device__ void TransformBool::transformCuda( - const void *vx, const Nd4jLong *xShapeInfo, - void *vparams, - void *vz, const Nd4jLong *zShapeInfo, - int *allocationPointer, void *vreductionPointer, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { - - auto x = static_cast(vx); - auto z = static_cast(vz); - auto params = static_cast(vparams); - auto reductionPointer = static_cast(vreductionPointer); - - if(OpType::requiresSpecial) { - OpType::execSpecialCuda(x,xShapeInfo,z,zShapeInfo,params, allocationPointer, reductionPointer, tadShapeInfo, tadOffsets); - return; - } - else { - __shared__ Nd4jLong xEws; - __shared__ Nd4jLong zEws; - __shared__ char xOrder; - __shared__ char zOrder; - __shared__ Nd4jLong length; - - if (threadIdx.x == 0) { - - xEws = shape::elementWiseStride(xShapeInfo); - zEws = shape::elementWiseStride(zShapeInfo); - xOrder = shape::order(xShapeInfo); - zOrder = shape::order(zShapeInfo); - length = shape::length(xShapeInfo); - } - __syncthreads(); - - auto tid = blockIdx.x * blockDim.x + threadIdx.x; - int totalThreads = gridDim.x * blockDim.x; - - if(xEws > 0 && zEws > 0 && xOrder == zOrder && xOrder == 'c') { - - for (int i = tid; i < length; i += totalThreads) - z[i * zEws] = OpType::op(x[i * xEws], params); - } - else { - if(vx == vz) { - for (Nd4jLong i = tid; i < length; i+= totalThreads) { - auto xOffset = shape::getIndexOffset(i, xShapeInfo); - z[xOffset] = OpType::op(x[xOffset], params); - } - } - else { - for (Nd4jLong i = tid; i < length; i+= totalThreads) { - auto xOffset = shape::getIndexOffset(i, xShapeInfo); - auto zOffset = shape::getIndexOffset(i, zShapeInfo); - z[zOffset] = OpType::op(x[xOffset], params); - } - } - } - } - }; - - - template - template - _CUDA_H void TransformBool::intermediateShaped( - dim3 launchDims, cudaStream_t *stream, - const void *x, const Nd4jLong *xShape, int xRank, - void *extraParams, - void *z, const Nd4jLong *zShape, int zRank, - int *allocationPointer, void *reductionPointer, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { - transformBoolSimple<<>>(x, xShape, xRank, extraParams, z, zShape, zRank, allocationPointer, reductionPointer, tadShapeInfo, tadOffsets); - sd::DebugHelper::checkErrorCode(stream, "transformBool(...) failed"); - } - - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT TransformBool, , LIBND4J_TYPES, BOOL_TYPES); + } else { + for (Nd4jLong i = tid; i < length; i += totalThreads) { + auto xOffset = shape::getIndexOffset(i, xShapeInfo); + auto zOffset = shape::getIndexOffset(i, zShapeInfo); + z[zOffset] = OpType::op(x[xOffset], params); + } + } } + } +}; + +template +template +_CUDA_H void TransformBool::intermediateShaped( + dim3 launchDims, cudaStream_t *stream, const void *x, + const Nd4jLong *xShape, int xRank, void *extraParams, void *z, + const Nd4jLong *zShape, int zRank, int *allocationPointer, + void *reductionPointer, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets) { + transformBoolSimple + <<>>( + x, xShape, xRank, extraParams, z, zShape, zRank, allocationPointer, + reductionPointer, tadShapeInfo, tadOffsets); + sd::DebugHelper::checkErrorCode(stream, "transformBool(...) failed"); } + +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT TransformBool, , LIBND4J_TYPES, + BOOL_TYPES); +} // namespace transform +} // namespace functions diff --git a/libnd4j/include/loops/cuda/transform/transform_float.cu b/libnd4j/include/loops/cuda/transform/transform_float.cu index 6b6889009a50..85e564798b4f 100644 --- a/libnd4j/include/loops/cuda/transform/transform_float.cu +++ b/libnd4j/include/loops/cuda/transform/transform_float.cu @@ -18,130 +18,132 @@ // @author raver119@gmail.com // -#include +#include +#include #include -#include +#include #include - -#include -#include +#include using namespace simdOps; template -__global__ void transformFloatSimple(const void *x, const Nd4jLong *xShapeInfo, int xRank, - void *params, - void *z, const Nd4jLong *zShapeInfo, int zRank, - int *allocationPointer, - void *reductionPointer, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { - - functions::transform::TransformFloat::template transformCuda( - x, xShapeInfo, - params, - z, zShapeInfo, - allocationPointer, reductionPointer, - tadShapeInfo, tadOffsets); +__global__ void transformFloatSimple(const void *x, const Nd4jLong *xShapeInfo, + int xRank, void *params, void *z, + const Nd4jLong *zShapeInfo, int zRank, + int *allocationPointer, + void *reductionPointer, + const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets) { + functions::transform::TransformFloat::template transformCuda( + x, xShapeInfo, params, z, zShapeInfo, allocationPointer, reductionPointer, + tadShapeInfo, tadOffsets); } - namespace functions { - namespace transform { - - template - _CUDA_H void TransformFloat::executeTransformShaped(dim3 launchDims, cudaStream_t *stream, int opNum, const void *x, const Nd4jLong *xShape, int xRank, void *extraParams, void *z, const Nd4jLong *zShape, int zRank, int *allocationPointer, void *reductionPointer, const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { - DISPATCH_BY_OPNUM_TT(intermediateShaped, PARAMS(launchDims, stream, x, xShape, xRank, extraParams, z, zShape, zRank, allocationPointer, reductionPointer, tadShapeInfo, tadOffsets), TRANSFORM_FLOAT_OPS); +namespace transform { + +template +_CUDA_H void TransformFloat::executeTransformShaped( + dim3 launchDims, cudaStream_t *stream, int opNum, const void *x, + const Nd4jLong *xShape, int xRank, void *extraParams, void *z, + const Nd4jLong *zShape, int zRank, int *allocationPointer, + void *reductionPointer, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets) { + DISPATCH_BY_OPNUM_TT(intermediateShaped, + PARAMS(launchDims, stream, x, xShape, xRank, extraParams, + z, zShape, zRank, allocationPointer, + reductionPointer, tadShapeInfo, tadOffsets), + TRANSFORM_FLOAT_OPS); + + DEBUG_KERNEL(stream, opNum); +} - DEBUG_KERNEL(stream, opNum); +template +template +__device__ void TransformFloat::transformCuda( + const void *vx, const Nd4jLong *xShapeInfo, void *vparams, void *vz, + const Nd4jLong *zShapeInfo, int *allocationPointer, void *vreductionPointer, + const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + auto params = reinterpret_cast(vparams); + auto reductionPointer = reinterpret_cast(vreductionPointer); + + if (OpType::requiresSpecial) { + OpType::execSpecialCuda(x, xShapeInfo, z, zShapeInfo, params, + allocationPointer, reductionPointer, tadShapeInfo, + tadOffsets); + return; + } else { + __shared__ Nd4jLong xEws; + __shared__ Nd4jLong zEws; + __shared__ char xOrder; + __shared__ char zOrder; + __shared__ Nd4jLong length; + + if (threadIdx.x == 0) { + xEws = shape::elementWiseStride(xShapeInfo); + zEws = shape::elementWiseStride(zShapeInfo); + xOrder = shape::order(xShapeInfo); + zOrder = shape::order(zShapeInfo); + length = shape::length(xShapeInfo); + } + __syncthreads(); + + auto tid = blockIdx.x * blockDim.x + threadIdx.x; + int totalThreads = gridDim.x * blockDim.x; + + if (xEws > 0 && zEws > 0 && xOrder == zOrder && xOrder == 'c') { + for (Nd4jLong i = tid; i < length; i += totalThreads) + z[i * zEws] = OpType::op(x[i * xEws], params); + } else { + if (vx == vz) { + for (Nd4jLong i = tid; i < length; i += totalThreads) { + auto xOffset = shape::getIndexOffset(i, xShapeInfo); + z[xOffset] = OpType::op(x[xOffset], params); } - - - template - template - __device__ void TransformFloat::transformCuda(const void *vx, const Nd4jLong *xShapeInfo, - void *vparams, - void *vz, const Nd4jLong *zShapeInfo, - int *allocationPointer, void *vreductionPointer, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { - - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - auto params = reinterpret_cast(vparams); - auto reductionPointer = reinterpret_cast(vreductionPointer); - - if(OpType::requiresSpecial) { - OpType::execSpecialCuda(x,xShapeInfo,z,zShapeInfo,params, allocationPointer, reductionPointer, tadShapeInfo, tadOffsets); - return; - } - else { - - __shared__ Nd4jLong xEws; - __shared__ Nd4jLong zEws; - __shared__ char xOrder; - __shared__ char zOrder; - __shared__ Nd4jLong length; - - if (threadIdx.x == 0) { - - xEws = shape::elementWiseStride(xShapeInfo); - zEws = shape::elementWiseStride(zShapeInfo); - xOrder = shape::order(xShapeInfo); - zOrder = shape::order(zShapeInfo); - length = shape::length(xShapeInfo); - } - __syncthreads(); - - auto tid = blockIdx.x * blockDim.x + threadIdx.x; - int totalThreads = gridDim.x * blockDim.x; - - if(xEws > 0 && zEws > 0 && xOrder == zOrder && xOrder == 'c') { - - for (Nd4jLong i = tid; i < length; i += totalThreads) - z[i * zEws] = OpType::op(x[i * xEws], params); - } - else { - if(vx == vz) { - for (Nd4jLong i = tid; i < length; i+= totalThreads) { - auto xOffset = shape::getIndexOffset(i, xShapeInfo); - z[xOffset] = OpType::op(x[xOffset], params); - } - } - else { - for (Nd4jLong i = tid; i < length; i+= totalThreads) { - auto xOffset = shape::getIndexOffset(i, xShapeInfo); - auto zOffset = shape::getIndexOffset(i, zShapeInfo); - z[zOffset] = OpType::op(x[xOffset], params); - } - } - } - } - }; - - template - __device__ void TransformFloat::transformCudaLegacy( - const int opNum, - const void *x, const Nd4jLong *xShapeInfo, - void *params, - void *z, const Nd4jLong *zShapeInfo, - int *allocationPointer, void *reductionPointer, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { - DISPATCH_BY_OPNUM_TT(transformCuda, PARAMS(x, xShapeInfo, params, z, zShapeInfo, allocationPointer, reductionPointer, tadShapeInfo, tadOffsets), TRANSFORM_FLOAT_OPS); + } else { + for (Nd4jLong i = tid; i < length; i += totalThreads) { + auto xOffset = shape::getIndexOffset(i, xShapeInfo); + auto zOffset = shape::getIndexOffset(i, zShapeInfo); + z[zOffset] = OpType::op(x[xOffset], params); } - - template - template - _CUDA_H void TransformFloat::intermediateShaped( - dim3 launchDims, cudaStream_t *stream, - const void *x, const Nd4jLong *xShape, int xRank, - void *extraParams, - void *z, const Nd4jLong *zShape, int zRank, - int *allocationPointer, void *reductionPointer, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { - transformFloatSimple<<>>(x, xShape, xRank, extraParams, z, zShape, zRank, allocationPointer, reductionPointer, tadShapeInfo, tadOffsets); - - sd::DebugHelper::checkErrorCode(stream, "transformFloat(...) failed"); - } - - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT TransformFloat, , LIBND4J_TYPES, FLOAT_TYPES); + } } + } +}; + +template +__device__ void TransformFloat::transformCudaLegacy( + const int opNum, const void *x, const Nd4jLong *xShapeInfo, void *params, + void *z, const Nd4jLong *zShapeInfo, int *allocationPointer, + void *reductionPointer, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets) { + DISPATCH_BY_OPNUM_TT( + transformCuda, + PARAMS(x, xShapeInfo, params, z, zShapeInfo, allocationPointer, + reductionPointer, tadShapeInfo, tadOffsets), + TRANSFORM_FLOAT_OPS); +} + +template +template +_CUDA_H void TransformFloat::intermediateShaped( + dim3 launchDims, cudaStream_t *stream, const void *x, + const Nd4jLong *xShape, int xRank, void *extraParams, void *z, + const Nd4jLong *zShape, int zRank, int *allocationPointer, + void *reductionPointer, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets) { + transformFloatSimple + <<>>( + x, xShape, xRank, extraParams, z, zShape, zRank, allocationPointer, + reductionPointer, tadShapeInfo, tadOffsets); + + sd::DebugHelper::checkErrorCode(stream, "transformFloat(...) failed"); } + +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT TransformFloat, , LIBND4J_TYPES, + FLOAT_TYPES); +} // namespace transform +} // namespace functions diff --git a/libnd4j/include/loops/cuda/transform/transform_same.cu b/libnd4j/include/loops/cuda/transform/transform_same.cu index b03146da90b6..91036d08e680 100644 --- a/libnd4j/include/loops/cuda/transform/transform_same.cu +++ b/libnd4j/include/loops/cuda/transform/transform_same.cu @@ -18,112 +18,117 @@ // @author raver119@gmail.com // -#include +#include +#include #include -#include +#include #include - -#include -#include +#include using namespace simdOps; template -__global__ void transformSameSimple(const void *x, const Nd4jLong *xShapeInfo, int xRank, - void *params, - void *z, const Nd4jLong *zShapeInfo, int zRank, - int *allocationPointer, - void *reductionPointer, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { - - functions::transform::TransformSame::template transformCuda(x,xShapeInfo,params,z,zShapeInfo,allocationPointer,reductionPointer, tadShapeInfo, tadOffsets); +__global__ void transformSameSimple(const void *x, const Nd4jLong *xShapeInfo, + int xRank, void *params, void *z, + const Nd4jLong *zShapeInfo, int zRank, + int *allocationPointer, + void *reductionPointer, + const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets) { + functions::transform::TransformSame::template transformCuda( + x, xShapeInfo, params, z, zShapeInfo, allocationPointer, reductionPointer, + tadShapeInfo, tadOffsets); } - namespace functions { - namespace transform { - - template - _CUDA_H void TransformSame::executeTransformShaped(dim3 launchDims, cudaStream_t *stream, - const int opNum, - const void *x, const Nd4jLong *xShape, int xRank, - void *extraParams, - void *z, const Nd4jLong *zShape, int zRank, - int *allocationPointer, void *reductionPointer, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { - DISPATCH_BY_OPNUM_T(intermediateShaped, PARAMS(launchDims, stream, x, xShape, xRank, extraParams, z, zShape, zRank, allocationPointer, reductionPointer, tadShapeInfo, tadOffsets), TRANSFORM_SAME_OPS); +namespace transform { + +template +_CUDA_H void TransformSame::executeTransformShaped( + dim3 launchDims, cudaStream_t *stream, const int opNum, const void *x, + const Nd4jLong *xShape, int xRank, void *extraParams, void *z, + const Nd4jLong *zShape, int zRank, int *allocationPointer, + void *reductionPointer, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets) { + DISPATCH_BY_OPNUM_T(intermediateShaped, + PARAMS(launchDims, stream, x, xShape, xRank, extraParams, + z, zShape, zRank, allocationPointer, + reductionPointer, tadShapeInfo, tadOffsets), + TRANSFORM_SAME_OPS); + + DEBUG_KERNEL(stream, opNum); +} - DEBUG_KERNEL(stream, opNum); +template +template +__device__ void TransformSame::transformCuda( + const void *vx, const Nd4jLong *xShapeInfo, void *vparams, void *vz, + const Nd4jLong *zShapeInfo, int *allocationPointer, void *vreductionPointer, + const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { + auto x = static_cast(vx); + auto z = static_cast(vz); + auto params = static_cast(vparams); + auto reductionPointer = static_cast(vreductionPointer); + + if (OpType::requiresSpecial) { + OpType::execSpecialCuda(x, xShapeInfo, z, zShapeInfo, params, + allocationPointer, reductionPointer, tadShapeInfo, + tadOffsets); + return; + } else { + __shared__ Nd4jLong xEws; + __shared__ Nd4jLong zEws; + __shared__ char xOrder; + __shared__ char zOrder; + __shared__ Nd4jLong length; + + if (threadIdx.x == 0) { + xEws = shape::elementWiseStride(xShapeInfo); + zEws = shape::elementWiseStride(zShapeInfo); + xOrder = shape::order(xShapeInfo); + zOrder = shape::order(zShapeInfo); + length = shape::length(xShapeInfo); + } + __syncthreads(); + + auto tid = blockIdx.x * blockDim.x + threadIdx.x; + int totalThreads = gridDim.x * blockDim.x; + + if (xEws > 0 && zEws > 0 && xOrder == zOrder && xOrder == 'c') { + for (int i = tid; i < length; i += totalThreads) + z[i * zEws] = OpType::op(x[i * xEws], params); + } else { + if (vx == vz) { + for (Nd4jLong i = tid; i < length; i += totalThreads) { + auto xOffset = shape::getIndexOffset(i, xShapeInfo); + z[xOffset] = OpType::op(x[xOffset], params); } - - - template - template - __device__ void TransformSame::transformCuda(const void *vx, const Nd4jLong *xShapeInfo, - void *vparams, - void *vz, const Nd4jLong *zShapeInfo, - int *allocationPointer, void *vreductionPointer, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { - - auto x = static_cast(vx); - auto z = static_cast(vz); - auto params = static_cast(vparams); - auto reductionPointer = static_cast(vreductionPointer); - - if(OpType::requiresSpecial) { - OpType::execSpecialCuda(x,xShapeInfo,z,zShapeInfo,params, allocationPointer, reductionPointer, tadShapeInfo, tadOffsets); - return; - } else { - __shared__ Nd4jLong xEws; - __shared__ Nd4jLong zEws; - __shared__ char xOrder; - __shared__ char zOrder; - __shared__ Nd4jLong length; - - if (threadIdx.x == 0) { - - xEws = shape::elementWiseStride(xShapeInfo); - zEws = shape::elementWiseStride(zShapeInfo); - xOrder = shape::order(xShapeInfo); - zOrder = shape::order(zShapeInfo); - length = shape::length(xShapeInfo); - } - __syncthreads(); - - auto tid = blockIdx.x * blockDim.x + threadIdx.x; - int totalThreads = gridDim.x * blockDim.x; - - if(xEws > 0 && zEws > 0 && xOrder == zOrder && xOrder == 'c') { - - for (int i = tid; i < length; i += totalThreads) - z[i * zEws] = OpType::op(x[i * xEws], params); - } - else { - if(vx == vz) { - for (Nd4jLong i = tid; i < length; i+= totalThreads) { - auto xOffset = shape::getIndexOffset(i, xShapeInfo); - z[xOffset] = OpType::op(x[xOffset], params); - } - } - else { - for (Nd4jLong i = tid; i < length; i+= totalThreads) { - auto xOffset = shape::getIndexOffset(i, xShapeInfo); - auto zOffset = shape::getIndexOffset(i, zShapeInfo); - z[zOffset] = OpType::op(x[xOffset], params); - } - } - } - } - }; - - - template - template - _CUDA_H void TransformSame::intermediateShaped(dim3 launchDims, cudaStream_t *stream, const void *x, const Nd4jLong *xShape, int xRank, void *extraParams, void *z, const Nd4jLong *zShape, int zRank, int *allocationPointer, void *reductionPointer, const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { - transformSameSimple<<>>(x, xShape, xRank, extraParams, z, zShape, zRank, allocationPointer, reductionPointer, tadShapeInfo, tadOffsets); - sd::DebugHelper::checkErrorCode(stream, "transformSame(...) failed"); - } - - BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT TransformSame, , LIBND4J_TYPES); + } else { + for (Nd4jLong i = tid; i < length; i += totalThreads) { + auto xOffset = shape::getIndexOffset(i, xShapeInfo); + auto zOffset = shape::getIndexOffset(i, zShapeInfo); + z[zOffset] = OpType::op(x[xOffset], params); + } + } } + } +}; + +template +template +_CUDA_H void TransformSame::intermediateShaped( + dim3 launchDims, cudaStream_t *stream, const void *x, + const Nd4jLong *xShape, int xRank, void *extraParams, void *z, + const Nd4jLong *zShape, int zRank, int *allocationPointer, + void *reductionPointer, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets) { + transformSameSimple + <<>>( + x, xShape, xRank, extraParams, z, zShape, zRank, allocationPointer, + reductionPointer, tadShapeInfo, tadOffsets); + sd::DebugHelper::checkErrorCode(stream, "transformSame(...) failed"); } + +BUILD_SINGLE_TEMPLATE(template class SD_EXPORT TransformSame, , LIBND4J_TYPES); +} // namespace transform +} // namespace functions diff --git a/libnd4j/include/loops/cuda/transform/transform_strict.cu b/libnd4j/include/loops/cuda/transform/transform_strict.cu index f36b50c29843..f3f714ec3391 100644 --- a/libnd4j/include/loops/cuda/transform/transform_strict.cu +++ b/libnd4j/include/loops/cuda/transform/transform_strict.cu @@ -18,119 +18,117 @@ // @author raver119@gmail.com // -#include +#include +#include #include -#include +#include #include - -#include -#include +#include using namespace simdOps; template -__global__ void transformStrictSimple(const void *x, const Nd4jLong *xShapeInfo, int xRank, - void *params, - void *z, const Nd4jLong *zShapeInfo, int zRank, - int *allocationPointer, - void *reductionPointer, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { - - functions::transform::TransformStrict::template transformCuda(x,xShapeInfo,params,z,zShapeInfo,allocationPointer,reductionPointer,tadShapeInfo, tadOffsets); +__global__ void transformStrictSimple(const void *x, const Nd4jLong *xShapeInfo, + int xRank, void *params, void *z, + const Nd4jLong *zShapeInfo, int zRank, + int *allocationPointer, + void *reductionPointer, + const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets) { + functions::transform::TransformStrict::template transformCuda( + x, xShapeInfo, params, z, zShapeInfo, allocationPointer, reductionPointer, + tadShapeInfo, tadOffsets); } - namespace functions { - namespace transform { - - template - _CUDA_H void TransformStrict::executeTransformShaped(dim3 launchDims, cudaStream_t *stream, - const int opNum, - const void *x, const Nd4jLong *xShape, int xRank, - void *extraParams, - void *z, const Nd4jLong *zShape, int zRank, - int *allocationPointer, void *reductionPointer, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { - DISPATCH_BY_OPNUM_T(intermediateShaped, PARAMS(launchDims, stream, x, xShape, xRank, extraParams, z, zShape, zRank, allocationPointer, reductionPointer, tadShapeInfo, tadOffsets), TRANSFORM_STRICT_OPS); +namespace transform { + +template +_CUDA_H void TransformStrict::executeTransformShaped( + dim3 launchDims, cudaStream_t *stream, const int opNum, const void *x, + const Nd4jLong *xShape, int xRank, void *extraParams, void *z, + const Nd4jLong *zShape, int zRank, int *allocationPointer, + void *reductionPointer, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets) { + DISPATCH_BY_OPNUM_T(intermediateShaped, + PARAMS(launchDims, stream, x, xShape, xRank, extraParams, + z, zShape, zRank, allocationPointer, + reductionPointer, tadShapeInfo, tadOffsets), + TRANSFORM_STRICT_OPS); + + DEBUG_KERNEL(stream, opNum); +} - DEBUG_KERNEL(stream, opNum); +template +template +__device__ void TransformStrict::transformCuda( + const void *vx, const Nd4jLong *xShapeInfo, void *vparams, void *vz, + const Nd4jLong *zShapeInfo, int *allocationPointer, void *vreductionPointer, + const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { + auto x = static_cast(vx); + auto z = static_cast(vz); + auto params = static_cast(vparams); + auto reductionPointer = static_cast(vreductionPointer); + + if (OpType::requiresSpecial) { + OpType::execSpecialCuda(x, xShapeInfo, z, zShapeInfo, params, + allocationPointer, reductionPointer, tadShapeInfo, + tadOffsets); + return; + } else { + __shared__ Nd4jLong xEws; + __shared__ Nd4jLong zEws; + __shared__ char xOrder; + __shared__ char zOrder; + __shared__ Nd4jLong length; + + if (threadIdx.x == 0) { + xEws = shape::elementWiseStride(xShapeInfo); + zEws = shape::elementWiseStride(zShapeInfo); + xOrder = shape::order(xShapeInfo); + zOrder = shape::order(zShapeInfo); + length = shape::length(xShapeInfo); + } + __syncthreads(); + + auto tid = blockIdx.x * blockDim.x + threadIdx.x; + int totalThreads = gridDim.x * blockDim.x; + + if (xEws > 0 && zEws > 0 && xOrder == zOrder && xOrder == 'c') { + for (int i = tid; i < length; i += totalThreads) + z[i * zEws] = OpType::op(x[i * xEws], params); + } else { + if (vx == vz) { + for (Nd4jLong i = tid; i < length; i += totalThreads) { + auto xOffset = shape::getIndexOffset(i, xShapeInfo); + z[xOffset] = OpType::op(x[xOffset], params); } - - - template - template - __device__ void TransformStrict::transformCuda(const void *vx, const Nd4jLong *xShapeInfo, - void *vparams, - void *vz, const Nd4jLong *zShapeInfo, - int *allocationPointer, void *vreductionPointer, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { - - auto x = static_cast(vx); - auto z = static_cast(vz); - auto params = static_cast(vparams); - auto reductionPointer = static_cast(vreductionPointer); - - - if(OpType::requiresSpecial) { - OpType::execSpecialCuda(x,xShapeInfo,z,zShapeInfo,params, allocationPointer, reductionPointer, tadShapeInfo, tadOffsets); - return; - } - else { - __shared__ Nd4jLong xEws; - __shared__ Nd4jLong zEws; - __shared__ char xOrder; - __shared__ char zOrder; - __shared__ Nd4jLong length; - - if (threadIdx.x == 0) { - - xEws = shape::elementWiseStride(xShapeInfo); - zEws = shape::elementWiseStride(zShapeInfo); - xOrder = shape::order(xShapeInfo); - zOrder = shape::order(zShapeInfo); - length = shape::length(xShapeInfo); - } - __syncthreads(); - - auto tid = blockIdx.x * blockDim.x + threadIdx.x; - int totalThreads = gridDim.x * blockDim.x; - - if(xEws > 0 && zEws > 0 && xOrder == zOrder && xOrder == 'c') { - - for (int i = tid; i < length; i += totalThreads) - z[i * zEws] = OpType::op(x[i * xEws], params); - } - else { - if(vx == vz) { - for (Nd4jLong i = tid; i < length; i+= totalThreads) { - auto xOffset = shape::getIndexOffset(i, xShapeInfo); - z[xOffset] = OpType::op(x[xOffset], params); - } - } - else { - for (Nd4jLong i = tid; i < length; i+= totalThreads) { - auto xOffset = shape::getIndexOffset(i, xShapeInfo); - auto zOffset = shape::getIndexOffset(i, zShapeInfo); - z[zOffset] = OpType::op(x[xOffset], params); - } - } - } - } - }; - - template - template - _CUDA_H void TransformStrict::intermediateShaped(dim3 launchDims, cudaStream_t *stream, - const void *x, const Nd4jLong *xShape, int xRank, - void *extraParams, - void *z, const Nd4jLong *zShape, int zRank, - int *allocationPointer, void *reductionPointer, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { - - transformStrictSimple<<>>(x, xShape, xRank, extraParams, z, zShape, zRank, allocationPointer, reductionPointer, tadShapeInfo, tadOffsets); - sd::DebugHelper::checkErrorCode(stream, "transformStrict(...) failed"); - } - - BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT TransformStrict, , FLOAT_TYPES); + } else { + for (Nd4jLong i = tid; i < length; i += totalThreads) { + auto xOffset = shape::getIndexOffset(i, xShapeInfo); + auto zOffset = shape::getIndexOffset(i, zShapeInfo); + z[zOffset] = OpType::op(x[xOffset], params); + } + } } + } +}; + +template +template +_CUDA_H void TransformStrict::intermediateShaped( + dim3 launchDims, cudaStream_t *stream, const void *x, + const Nd4jLong *xShape, int xRank, void *extraParams, void *z, + const Nd4jLong *zShape, int zRank, int *allocationPointer, + void *reductionPointer, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets) { + transformStrictSimple + <<>>( + x, xShape, xRank, extraParams, z, zShape, zRank, allocationPointer, + reductionPointer, tadShapeInfo, tadOffsets); + sd::DebugHelper::checkErrorCode(stream, "transformStrict(...) failed"); } + +BUILD_SINGLE_TEMPLATE(template class SD_EXPORT TransformStrict, , FLOAT_TYPES); +} // namespace transform +} // namespace functions diff --git a/libnd4j/include/loops/cuda/type_conversions.cu b/libnd4j/include/loops/cuda/type_conversions.cu index d0dee4f0d5ad..3252adce9e8d 100644 --- a/libnd4j/include/loops/cuda/type_conversions.cu +++ b/libnd4j/include/loops/cuda/type_conversions.cu @@ -18,523 +18,578 @@ // // +#include #include #include -#include namespace sd { - template - void TypeCast::convertGenericCuda(Nd4jPointer *extras, void *dx, Nd4jLong N, void *dz) { - auto stream = reinterpret_cast(&extras[1]); - - sd::convertKernel<<<256, 1024, 1024, *stream>>>(dx, N, dz); - sd::DebugHelper::checkErrorCode(stream, "convertGeneric(...) failed"); - }; - - - template - __device__ void convertKernelGeneric(S *x, Nd4jLong N, T *z) { - int tid = blockIdx.x * blockDim.x + threadIdx.x; - - for (Nd4jLong i = tid; i < N; i+= blockDim.x * gridDim.x) { - // despite it's stupid, it simplifies conversion to bottom dtypes - // FIXME: get rid of through-float though - z[i] = static_cast(static_cast(x[i])); - } - }; - - -// Define this to more rigorously avoid bank conflicts, even at the lower (root) levels of the tree -//#define ZERO_BANK_CONFLICTS +template +void TypeCast::convertGenericCuda(Nd4jPointer *extras, void *dx, Nd4jLong N, + void *dz) { + auto stream = reinterpret_cast(&extras[1]); + + sd::convertKernel<<<256, 1024, 1024, *stream>>>(dx, N, dz); + sd::DebugHelper::checkErrorCode(stream, "convertGeneric(...) failed"); +}; + +template +__device__ void convertKernelGeneric(S *x, Nd4jLong N, T *z) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + + for (Nd4jLong i = tid; i < N; i += blockDim.x * gridDim.x) { + // despite it's stupid, it simplifies conversion to bottom dtypes + // FIXME: get rid of through-float though + z[i] = static_cast(static_cast(x[i])); + } +}; + + // Define this to more rigorously avoid bank conflicts, even at the lower + // (root) levels of the tree + //#define ZERO_BANK_CONFLICTS #ifdef ZERO_BANK_CONFLICTS -#define CONFLICT_FREE_OFFSET(index) ((index) >> LOG_NUM_BANKS + (index) >> (2 * LOG_NUM_BANKS)) +#define CONFLICT_FREE_OFFSET(index) \ + ((index) >> LOG_NUM_BANKS + (index) >> (2 * LOG_NUM_BANKS)) #else #define CONFLICT_FREE_OFFSET(index) ((index) >> LOG_NUM_BANKS) #endif #ifdef CHECK_BANK_CONFLICTS -#define TEMP(index) CUT_BANK_CHECKER(temp, index) +#define TEMP(index) CUT_BANK_CHECKER(temp, index) #else -#define TEMP(index) temp[index] +#define TEMP(index) temp[index] #endif +template +__device__ void loadSharedChunkFromMem(int *s_data, const int *g_idata, int n, + int baseIndex, int &ai, int &bi, + int &mem_ai, int &mem_bi, + int &bankOffsetA, int &bankOffsetB) { + int thid = threadIdx.x; + mem_ai = baseIndex + threadIdx.x; + mem_bi = mem_ai + blockDim.x; + + ai = thid; + bi = thid + blockDim.x; + + // compute spacing to avoid bank conflicts + bankOffsetA = CONFLICT_FREE_OFFSET(ai); + bankOffsetB = CONFLICT_FREE_OFFSET(bi); + + // Cache the computational window in shared memory + // pad values beyond n with zeros + s_data[ai + bankOffsetA] = g_idata[mem_ai]; + + if (isNP2) { // compile-time decision + s_data[bi + bankOffsetB] = (bi < n) ? g_idata[mem_bi] : 0; + } else { + s_data[bi + bankOffsetB] = g_idata[mem_bi]; + } +} +template +__device__ void storeSharedChunkToMem(int *g_odata, int *s_data, int n, int ai, + int bi, int mem_ai, int mem_bi, + int bankOffsetA, int bankOffsetB) { + __syncthreads(); + + // write results to global memory + g_odata[mem_ai] = s_data[ai + bankOffsetA]; + if (isNP2) { // compile-time decision + if (bi < n) g_odata[mem_bi] = s_data[bi + bankOffsetB]; + } else { + g_odata[mem_bi] = s_data[bi + bankOffsetB]; + } +} - - - template - __device__ void loadSharedChunkFromMem(int *s_data, const int *g_idata, int n, int baseIndex, int& ai, int& bi, int& mem_ai, int& mem_bi, int& bankOffsetA, int& bankOffsetB) { - int thid = threadIdx.x; - mem_ai = baseIndex + threadIdx.x; - mem_bi = mem_ai + blockDim.x; - - ai = thid; - bi = thid + blockDim.x; - - // compute spacing to avoid bank conflicts - bankOffsetA = CONFLICT_FREE_OFFSET(ai); - bankOffsetB = CONFLICT_FREE_OFFSET(bi); - - // Cache the computational window in shared memory - // pad values beyond n with zeros - s_data[ai + bankOffsetA] = g_idata[mem_ai]; - - if (isNP2) { // compile-time decision - s_data[bi + bankOffsetB] = (bi < n) ? g_idata[mem_bi] : 0; - } else { - s_data[bi + bankOffsetB] = g_idata[mem_bi]; - } - } - - template - __device__ void storeSharedChunkToMem(int* g_odata, int* s_data, int n, int ai, int bi, int mem_ai, int mem_bi, int bankOffsetA, int bankOffsetB) { - __syncthreads(); - - // write results to global memory - g_odata[mem_ai] = s_data[ai + bankOffsetA]; - if (isNP2) { // compile-time decision - if (bi < n) - g_odata[mem_bi] = s_data[bi + bankOffsetB]; - } else { - g_odata[mem_bi] = s_data[bi + bankOffsetB]; - } +template +__device__ void clearLastElement(int *s_data, int *g_blockSums, + int blockIndex) { + if (threadIdx.x == 0) { + int index = (blockDim.x << 1) - 1; + index += CONFLICT_FREE_OFFSET(index); + + if (storeSum) { // compile-time decision + // write this block's total sum to the corresponding index in the + // blockSums array + g_blockSums[blockIndex] = s_data[index]; } - template - __device__ void clearLastElement(int* s_data, int *g_blockSums, int blockIndex) { - if (threadIdx.x == 0) - { - int index = (blockDim.x << 1) - 1; - index += CONFLICT_FREE_OFFSET(index); - - if (storeSum) { // compile-time decision - // write this block's total sum to the corresponding index in the blockSums array - g_blockSums[blockIndex] = s_data[index]; - } - - // zero the last element in the scan so it will propagate back to the front - s_data[index] = 0; - } - } - - - - __device__ unsigned int buildSum(int *s_data) { - unsigned int thid = threadIdx.x; - unsigned int stride = 1; - - // build the sum in place up the tree - for (int d = blockDim.x; d > 0; d >>= 1) { - __syncthreads(); + // zero the last element in the scan so it will propagate back to the front + s_data[index] = 0; + } +} - if (thid < d) { - int i = __mul24(__mul24(2, stride), thid); - int ai = i + stride - 1; - int bi = ai + stride; +__device__ unsigned int buildSum(int *s_data) { + unsigned int thid = threadIdx.x; + unsigned int stride = 1; - ai += CONFLICT_FREE_OFFSET(ai); - bi += CONFLICT_FREE_OFFSET(bi); + // build the sum in place up the tree + for (int d = blockDim.x; d > 0; d >>= 1) { + __syncthreads(); - s_data[bi] += s_data[ai]; - } + if (thid < d) { + int i = __mul24(__mul24(2, stride), thid); + int ai = i + stride - 1; + int bi = ai + stride; - stride *= 2; - } + ai += CONFLICT_FREE_OFFSET(ai); + bi += CONFLICT_FREE_OFFSET(bi); - return stride; + s_data[bi] += s_data[ai]; } - __device__ void scanRootToLeaves(int *s_data, unsigned int stride) { - unsigned int thid = threadIdx.x; + stride *= 2; + } - // traverse down the tree building the scan in place - for (int d = 1; d <= blockDim.x; d *= 2) { - stride >>= 1; + return stride; +} - __syncthreads(); +__device__ void scanRootToLeaves(int *s_data, unsigned int stride) { + unsigned int thid = threadIdx.x; - if (thid < d) { - int i = __mul24(__mul24(2, stride), thid); - int ai = i + stride - 1; - int bi = ai + stride; + // traverse down the tree building the scan in place + for (int d = 1; d <= blockDim.x; d *= 2) { + stride >>= 1; - ai += CONFLICT_FREE_OFFSET(ai); - bi += CONFLICT_FREE_OFFSET(bi); + __syncthreads(); - float t = s_data[ai]; - s_data[ai] = s_data[bi]; - s_data[bi] += t; - } - } - } + if (thid < d) { + int i = __mul24(__mul24(2, stride), thid); + int ai = i + stride - 1; + int bi = ai + stride; - template - __device__ void prescanBlock(int *data, int blockIndex, int *blockSums) { - int stride = buildSum(data); // build the sum in place up the tree - clearLastElement(data, blockSums, - (blockIndex == 0) ? blockIdx.x : blockIndex); - scanRootToLeaves(data, stride); // traverse down tree to build the scan - } + ai += CONFLICT_FREE_OFFSET(ai); + bi += CONFLICT_FREE_OFFSET(bi); + float t = s_data[ai]; + s_data[ai] = s_data[bi]; + s_data[bi] += t; + } + } +} - template - __global__ void prescan(int *g_odata, const int *g_idata, int *g_blockSums, int n, int blockIndex, int baseIndex) { - int ai, bi, mem_ai, mem_bi, bankOffsetA, bankOffsetB; - extern __shared__ int s_data[]; +template +__device__ void prescanBlock(int *data, int blockIndex, int *blockSums) { + int stride = buildSum(data); // build the sum in place up the tree + clearLastElement(data, blockSums, + (blockIndex == 0) ? blockIdx.x : blockIndex); + scanRootToLeaves(data, stride); // traverse down tree to build the scan +} - // load data into shared memory - loadSharedChunkFromMem(reinterpret_cast(s_data), g_idata, n, (baseIndex == 0) ? __mul24(blockIdx.x, (blockDim.x << 1)):baseIndex, ai, bi, mem_ai, mem_bi, bankOffsetA, bankOffsetB); +template +__global__ void prescan(int *g_odata, const int *g_idata, int *g_blockSums, + int n, int blockIndex, int baseIndex) { + int ai, bi, mem_ai, mem_bi, bankOffsetA, bankOffsetB; + extern __shared__ int s_data[]; - // scan the data in each block - prescanBlock(s_data, blockIndex, g_blockSums); + // load data into shared memory + loadSharedChunkFromMem( + reinterpret_cast(s_data), g_idata, n, + (baseIndex == 0) ? __mul24(blockIdx.x, (blockDim.x << 1)) : baseIndex, ai, + bi, mem_ai, mem_bi, bankOffsetA, bankOffsetB); - // write results to device memory - storeSharedChunkToMem(g_odata, s_data, n, ai, bi, mem_ai, mem_bi, bankOffsetA, bankOffsetB); - } + // scan the data in each block + prescanBlock(s_data, blockIndex, g_blockSums); + // write results to device memory + storeSharedChunkToMem(g_odata, s_data, n, ai, bi, mem_ai, mem_bi, + bankOffsetA, bankOffsetB); +} - __global__ void uniformAdd(int *g_data, int *uniforms, int n, int blockOffset, int baseIndex) { - __shared__ float uni; - if (threadIdx.x == 0) - uni = uniforms[blockIdx.x + blockOffset]; +__global__ void uniformAdd(int *g_data, int *uniforms, int n, int blockOffset, + int baseIndex) { + __shared__ float uni; + if (threadIdx.x == 0) uni = uniforms[blockIdx.x + blockOffset]; - unsigned int address = __mul24(blockIdx.x, (blockDim.x << 1)) + baseIndex + threadIdx.x; + unsigned int address = + __mul24(blockIdx.x, (blockDim.x << 1)) + baseIndex + threadIdx.x; - __syncthreads(); + __syncthreads(); - // note two adds per thread - g_data[address] += uni; - g_data[address + blockDim.x] += (threadIdx.x + blockDim.x < n) * uni; - } + // note two adds per thread + g_data[address] += uni; + g_data[address + blockDim.x] += (threadIdx.x + blockDim.x < n) * uni; +} /* * This kernel does prefix sum in parallel, to calculate offsets for each block */ - template - __device__ inline void encoderKernelP2Generic(void *dx, Nd4jLong n, void *dz) { - // TODO: to be remove - } +template +__device__ inline void encoderKernelP2Generic(void *dx, Nd4jLong n, void *dz) { + // TODO: to be remove +} ////////////////////////////////////////////////////////////////////////// /* - * PLEASE NOTE: This kernel doesn't allow loop for data. Basically: grid will be huge. + * PLEASE NOTE: This kernel doesn't allow loop for data. Basically: grid will be + * huge. */ -template -__global__ static void execEncoderKernelP1(const void *dx, Nd4jLong N, void *dz, float threshold) { - auto x = reinterpret_cast (dx); - auto z = reinterpret_cast (dz); - - //basically, for phase One we want do calculation: how many eligible values we have, and which blocks will be holding data - Nd4jLong tid = blockIdx.x * blockDim.x + threadIdx.x; - - int pass = tid < N && sd::math::nd4j_abs(x[tid]) >= static_cast(threshold) ? 1 : 0; - int bp=__syncthreads_count(pass); - - if (threadIdx.x == 0) { - // saving out per-block passes - z[blockIdx.x+1] = bp; - - // saving out sum - atomicAdd(&z[0], bp); - } +template +__global__ static void execEncoderKernelP1(const void *dx, Nd4jLong N, void *dz, + float threshold) { + auto x = reinterpret_cast(dx); + auto z = reinterpret_cast(dz); + + // basically, for phase One we want do calculation: how many eligible values + // we have, and which blocks will be holding data + Nd4jLong tid = blockIdx.x * blockDim.x + threadIdx.x; + + int pass = + tid < N && sd::math::nd4j_abs(x[tid]) >= static_cast(threshold) ? 1 + : 0; + int bp = __syncthreads_count(pass); + + if (threadIdx.x == 0) { + // saving out per-block passes + z[blockIdx.x + 1] = bp; + + // saving out sum + atomicAdd(&z[0], bp); + } } ////////////////////////////////////////////////////////////////////////// -template -__host__ void encoderKernelP1Generic(dim3 &launchDims, cudaStream_t *stream, const void *dx, Nd4jLong N, void *dz, float threshold) { - - execEncoderKernelP1<<>>(dx, N, dz, threshold); - sd::DebugHelper::checkErrorCode(stream, "encoderP1(...) failed"); +template +__host__ void encoderKernelP1Generic(dim3 &launchDims, cudaStream_t *stream, + const void *dx, Nd4jLong N, void *dz, + float threshold) { + execEncoderKernelP1<<>>( + dx, N, dz, threshold); + sd::DebugHelper::checkErrorCode(stream, "encoderP1(...) failed"); } -BUILD_SINGLE_TEMPLATE(template void ND4J_EXPORT encoderKernelP1Generic, (dim3 &launchDims, cudaStream_t *stream, const void *dx, Nd4jLong N, void *dz, float threshold), FLOAT_TYPES); +BUILD_SINGLE_TEMPLATE(template void SD_EXPORT encoderKernelP1Generic, + (dim3 & launchDims, cudaStream_t *stream, const void *dx, + Nd4jLong N, void *dz, float threshold), + FLOAT_TYPES); ////////////////////////////////////////////////////////////////////////// /* - * PLEASE NOTE: This kernel doesn't allow loop for data. Basically: grid will be huge. + * PLEASE NOTE: This kernel doesn't allow loop for data. Basically: grid will be + * huge. * - * Based on: https://github.com/knotman90/cuStreamComp <-- efficient CUDA stream compaction algorithm + * Based on: https://github.com/knotman90/cuStreamComp <-- efficient CUDA stream + * compaction algorithm */ -template -__global__ static void execEncoderKernelP3(void *dx, int *offsets, Nd4jLong N, void *dz) { - auto x = reinterpret_cast (dx); - auto z = reinterpret_cast (dz); - - auto tid = blockIdx.x * blockDim.x + threadIdx.x; - extern __shared__ int warpTotals[]; - - // fetch block offset only once - __shared__ float threshold; - __shared__ FloatBits fb; - __shared__ int bo; - __shared__ int limit; - if (threadIdx.x == 0) { - limit = z[0]; - fb.i_ = z[2]; - threshold = fb.f_; - bo = offsets[blockIdx.x]; - } - __syncthreads(); - - // out-of-limit threads do not play here - auto value = tid < N ? x[tid] : (T) 0.f; - - // out-of-limit threads just declare they have no changes - auto pred = tid >= N ? 0 : sd::math::nd4j_abs(value) >= static_cast(threshold) ? 1 : 0; - auto w_i = threadIdx.x / warpSize; // warp index (or, warp number) - index of the Warp within TOTAL_WARPS - auto t_i = threadIdx.x % warpSize; // thread index within a warp - unsigned int t_m = INT_MAX >> (warpSize - t_i - 1); //thread mask (ERROR IN THE PAPER minus one is required) - - int b = __ballot_sync(t_m, pred); // balres = number whose ith bit isone if the ith's thread pred is true masked up to the current index in warp - auto t_u = __popc(b); // popc count the number of bit one. simply count the number predicated true BEFORE MY INDEX - - if (t_i == warpSize - 1) - warpTotals[w_i] = t_u + pred; - - __syncthreads(); - - - int w_i_u = 0; - for (int j = 0; j <= 5; j++) { - unsigned int b_j = __ballot_sync(t_m, warpTotals[t_i] & pow2i(j)); //# of the ones in the j'th digit of the warp offsets - w_i_u += (__popc(b_j) << j); - } - - // we just ignore all results coming from non-0 threads - if (w_i == 0 && t_i < blockDim.x / warpSize) - warpTotals[t_i] = w_i_u; - - __syncthreads(); - - - // pred is always false if we're out-of-limits - if (pred) { - int idx = t_u + warpTotals[w_i] + bo + 4; - if (idx < limit + 4) { - z[idx] = value > static_cast(0.0f) ? tid + 1 : -(tid + 1); - x[tid] = value > static_cast(0.0f) ? x[tid] - threshold : x[tid] + threshold; - } - } +template +__global__ static void execEncoderKernelP3(void *dx, int *offsets, Nd4jLong N, + void *dz) { + auto x = reinterpret_cast(dx); + auto z = reinterpret_cast(dz); + + auto tid = blockIdx.x * blockDim.x + threadIdx.x; + extern __shared__ int warpTotals[]; + + // fetch block offset only once + __shared__ float threshold; + __shared__ FloatBits fb; + __shared__ int bo; + __shared__ int limit; + if (threadIdx.x == 0) { + limit = z[0]; + fb.i_ = z[2]; + threshold = fb.f_; + bo = offsets[blockIdx.x]; + } + __syncthreads(); + + // out-of-limit threads do not play here + auto value = tid < N ? x[tid] : (T)0.f; + + // out-of-limit threads just declare they have no changes + auto pred = + tid >= N + ? 0 + : sd::math::nd4j_abs(value) >= static_cast(threshold) ? 1 : 0; + auto w_i = threadIdx.x / warpSize; // warp index (or, warp number) - index of + // the Warp within TOTAL_WARPS + auto t_i = threadIdx.x % warpSize; // thread index within a warp + unsigned int t_m = + INT_MAX >> (warpSize - t_i - + 1); // thread mask (ERROR IN THE PAPER minus one is required) + + int b = __ballot_sync( + t_m, pred); // balres = number whose ith bit isone if the ith's thread + // pred is true masked up to the current index in warp + auto t_u = __popc(b); // popc count the number of bit one. simply count the + // number predicated true BEFORE MY INDEX + + if (t_i == warpSize - 1) warpTotals[w_i] = t_u + pred; + + __syncthreads(); + + int w_i_u = 0; + for (int j = 0; j <= 5; j++) { + unsigned int b_j = __ballot_sync( + t_m, + warpTotals[t_i] & + pow2i(j)); //# of the ones in the j'th digit of the warp offsets + w_i_u += (__popc(b_j) << j); + } + + // we just ignore all results coming from non-0 threads + if (w_i == 0 && t_i < blockDim.x / warpSize) warpTotals[t_i] = w_i_u; + + __syncthreads(); + + // pred is always false if we're out-of-limits + if (pred) { + int idx = t_u + warpTotals[w_i] + bo + 4; + if (idx < limit + 4) { + z[idx] = value > static_cast(0.0f) ? tid + 1 : -(tid + 1); + x[tid] = value > static_cast(0.0f) ? x[tid] - threshold + : x[tid] + threshold; + } + } } ////////////////////////////////////////////////////////////////////////// -template -__host__ void encoderKernelP3Generic(dim3 &launchDims, cudaStream_t *stream, void *dx, int *offsets, Nd4jLong N, void *dz) { - execEncoderKernelP3<<>>(dx, offsets, N, dz); - sd::DebugHelper::checkErrorCode(stream, "encoderP3(...) failed"); +template +__host__ void encoderKernelP3Generic(dim3 &launchDims, cudaStream_t *stream, + void *dx, int *offsets, Nd4jLong N, + void *dz) { + execEncoderKernelP3<<>>( + dx, offsets, N, dz); + sd::DebugHelper::checkErrorCode(stream, "encoderP3(...) failed"); } -BUILD_SINGLE_TEMPLATE(template void ND4J_EXPORT encoderKernelP3Generic, (dim3 &launchDims, cudaStream_t *stream, void *dx, int *offsets, Nd4jLong N, void *dz), FLOAT_TYPES); +BUILD_SINGLE_TEMPLATE(template void SD_EXPORT encoderKernelP3Generic, + (dim3 & launchDims, cudaStream_t *stream, void *dx, + int *offsets, Nd4jLong N, void *dz), + FLOAT_TYPES); ////////////////////////////////////////////////////////////////////////// /* * This kernel handles decode from sparse threshold array, to dense array * * PLEASE NOTE: Z is expected to be memset to 0 -*/ -template + */ +template __global__ static void execDecoderKernel(const void *dx, Nd4jLong N, void *dz) { - auto x = reinterpret_cast (dx); - auto z = reinterpret_cast (dz); - - int tid = blockIdx.x * blockDim.x + threadIdx.x; - __shared__ float threshold; - __shared__ int limit; - - __shared__ FloatBits fb; - if (threadIdx.x == 0) { - limit = x[0]; - fb.i_ = x[2]; - threshold = fb.f_; - } - __syncthreads(); - - for (int e = tid; e < limit; e += blockDim.x * gridDim.x) { - int el = x[e+4]; - int ael = sd::math::nd4j_abs(el) - 1; + auto x = reinterpret_cast(dx); + auto z = reinterpret_cast(dz); + + int tid = blockIdx.x * blockDim.x + threadIdx.x; + __shared__ float threshold; + __shared__ int limit; + + __shared__ FloatBits fb; + if (threadIdx.x == 0) { + limit = x[0]; + fb.i_ = x[2]; + threshold = fb.f_; + } + __syncthreads(); + + for (int e = tid; e < limit; e += blockDim.x * gridDim.x) { + int el = x[e + 4]; + int ael = sd::math::nd4j_abs(el) - 1; + + // TODO: investigate, if += would work better here, as in "decoded + // accumulation" + z[ael] += el > 0 ? threshold : -threshold; + } +} - // TODO: investigate, if += would work better here, as in "decoded accumulation" - z[ael] += el > 0 ? threshold : -threshold; - } +////////////////////////////////////////////////////////////////////////// +template +__host__ void decoderKernelGeneric(dim3 &launchDims, cudaStream_t *stream, + const void *dx, Nd4jLong N, void *dz) { + execDecoderKernel + <<>>(dx, N, dz); + sd::DebugHelper::checkErrorCode(stream, "execDecoder(...) failed"); } +BUILD_SINGLE_TEMPLATE(template void SD_EXPORT decoderKernelGeneric, + (dim3 & launchDims, cudaStream_t *stream, const void *dx, + Nd4jLong N, void *dz), + FLOAT_TYPES); ////////////////////////////////////////////////////////////////////////// -template -__host__ void decoderKernelGeneric(dim3 &launchDims, cudaStream_t *stream, const void *dx, Nd4jLong N, void *dz) { +template +__global__ static void execCudaEncodeBitmapKernel(void *vdx, Nd4jLong N, + int *dz, int *scalar, + int *reductionBuffer, + float threshold) { + auto dx = reinterpret_cast(vdx); + int tid = blockIdx.x * blockDim.x + threadIdx.x; + + T off(0.0f); + __shared__ int counter; + __shared__ int *shmem; + __shared__ T *vals; + if (threadIdx.x == 0) { + extern __shared__ char mem[]; + shmem = reinterpret_cast(mem); + vals = reinterpret_cast(shmem + blockDim.x); + counter = 0; + } + __syncthreads(); + + Nd4jLong loopRemainder = N % (blockDim.x * gridDim.x); + Nd4jLong loopLimit = N + (blockDim.x * gridDim.x - loopRemainder); + + for (Nd4jLong i = tid; i < loopLimit; i += blockDim.x * gridDim.x) { + // all threads in block reading stuff + T val = i < N ? dx[i] : off; + T abs = sd::math::nd4j_abs(val); + + int byteId = i / 16 + 4; + int bitId = i % 16; + + shmem[threadIdx.x] = 0; + vals[threadIdx.x] = val; + + if (abs >= static_cast(threshold) && i < N) { + shmem[threadIdx.x] = 1 << (bitId); + atomicAdd(&counter, 1); + if (val < static_cast(0.0f)) { + shmem[threadIdx.x] |= 1 << (bitId + 16); + vals[threadIdx.x] += static_cast(threshold); + } else { + vals[threadIdx.x] -= static_cast(threshold); + } + } else if (abs >= static_cast(threshold) / static_cast(2.0f) && + val < static_cast(0.0f) && i < N) { + atomicAdd(&counter, 1); + shmem[threadIdx.x] = 1 << (bitId + 16); + + vals[threadIdx.x] += static_cast(threshold) / static_cast(2.0f); + } + __syncthreads(); - execDecoderKernel<<>>(dx, N, dz); - sd::DebugHelper::checkErrorCode(stream, "execDecoder(...) failed"); -} -BUILD_SINGLE_TEMPLATE(template void ND4J_EXPORT decoderKernelGeneric, (dim3 &launchDims, cudaStream_t *stream, const void *dx, Nd4jLong N, void *dz), FLOAT_TYPES); + if (threadIdx.x % 16 == 0 && i < N) { + int byte = 0; + for (int e = 0; e < 16; e++) { + if (i + e >= N) continue; + byte |= shmem[threadIdx.x + e]; + } + dz[byteId] = byte; + } + __syncthreads(); -////////////////////////////////////////////////////////////////////////// -template -__global__ static void execCudaEncodeBitmapKernel(void *vdx, Nd4jLong N, int *dz, int *scalar, int *reductionBuffer, float threshold) { - auto dx = reinterpret_cast(vdx); - int tid = blockIdx.x * blockDim.x + threadIdx.x; - - T off(0.0f); - __shared__ int counter; - __shared__ int *shmem; - __shared__ T *vals; - if (threadIdx.x == 0){ - extern __shared__ char mem[]; - shmem = reinterpret_cast(mem); - vals = reinterpret_cast(shmem + blockDim.x); - counter = 0; - } - __syncthreads(); - - Nd4jLong loopRemainder = N % (blockDim.x * gridDim.x); - Nd4jLong loopLimit = N + (blockDim.x * gridDim.x - loopRemainder); - - for (Nd4jLong i = tid; i < loopLimit; i += blockDim.x * gridDim.x) { - // all threads in block reading stuff - T val = i < N ? dx[i] : off; - T abs = sd::math::nd4j_abs(val); - - int byteId = i / 16 + 4; - int bitId = i % 16; - - shmem[threadIdx.x] = 0; - vals[threadIdx.x] = val; - - if (abs >= static_cast(threshold) && i < N) { - shmem[threadIdx.x] = 1 << (bitId); - atomicAdd(&counter, 1); - if (val < static_cast(0.0f)) { - shmem[threadIdx.x] |= 1 << (bitId + 16); - vals[threadIdx.x] += static_cast(threshold); - } else { - vals[threadIdx.x] -= static_cast(threshold); - } - } else if (abs >= static_cast(threshold) / static_cast(2.0f) && val < static_cast(0.0f) && i < N) { - atomicAdd(&counter, 1); - shmem[threadIdx.x] = 1 << (bitId + 16); - - vals[threadIdx.x] += static_cast(threshold) / static_cast(2.0f); - } - __syncthreads(); - - if (threadIdx.x % 16 == 0 && i < N) { - int byte = 0; - for (int e = 0; e < 16; e++) { - if (i + e >= N) - continue; - - byte |= shmem[threadIdx.x + e]; - } - dz[byteId] = byte; - } - __syncthreads(); - - if (i < N) - dx[i] = vals[threadIdx.x]; - } - __syncthreads(); + if (i < N) dx[i] = vals[threadIdx.x]; + } + __syncthreads(); - if (threadIdx.x == 0) { - atomicAdd(scalar, counter); - } + if (threadIdx.x == 0) { + atomicAdd(scalar, counter); + } } ////////////////////////////////////////////////////////////////////////// -template -__host__ void cudaEncodeBitmapGeneric(dim3 &launchDims, cudaStream_t *stream, void *vdx, Nd4jLong N, int *dz, int *scalar, int *reductionBuffer, float threshold) { - - execCudaEncodeBitmapKernel<<>>(vdx, N, dz, scalar, reductionBuffer, threshold); - sd::DebugHelper::checkErrorCode(stream, "encodeBitmap(...) failed"); +template +__host__ void cudaEncodeBitmapGeneric(dim3 &launchDims, cudaStream_t *stream, + void *vdx, Nd4jLong N, int *dz, + int *scalar, int *reductionBuffer, + float threshold) { + execCudaEncodeBitmapKernel + <<>>( + vdx, N, dz, scalar, reductionBuffer, threshold); + sd::DebugHelper::checkErrorCode(stream, "encodeBitmap(...) failed"); } -BUILD_SINGLE_TEMPLATE(template void ND4J_EXPORT cudaEncodeBitmapGeneric, (dim3 &launchDims, cudaStream_t *stream, void *vdx, Nd4jLong N, int *dz, int *scalar, int *reductionBuffer, float threshold), FLOAT_TYPES); - +BUILD_SINGLE_TEMPLATE(template void SD_EXPORT cudaEncodeBitmapGeneric, + (dim3 & launchDims, cudaStream_t *stream, void *vdx, + Nd4jLong N, int *dz, int *scalar, int *reductionBuffer, + float threshold), + FLOAT_TYPES); ////////////////////////////////////////////////////////////////////////// -template -__global__ static void execCudaDecodeBitmapKernel(const void *dx, Nd4jLong N, void *vdz) { - auto dz = static_cast(vdz); - - int tid = blockIdx.x * blockDim.x + threadIdx.x; - __shared__ T *shmem; - __shared__ FloatBits fb; - __shared__ float threshold; - __shared__ const int *x; - if (threadIdx.x == 0){ - extern __shared__ char mem[]; - shmem = reinterpret_cast(mem); - x = reinterpret_cast(dx); - fb.i_ = x[2]; - threshold = fb.f_; +template +__global__ static void execCudaDecodeBitmapKernel(const void *dx, Nd4jLong N, + void *vdz) { + auto dz = static_cast(vdz); + + int tid = blockIdx.x * blockDim.x + threadIdx.x; + __shared__ T *shmem; + __shared__ FloatBits fb; + __shared__ float threshold; + __shared__ const int *x; + if (threadIdx.x == 0) { + extern __shared__ char mem[]; + shmem = reinterpret_cast(mem); + x = reinterpret_cast(dx); + fb.i_ = x[2]; + threshold = fb.f_; + } + __syncthreads(); + + int lim = N / 16 + 5; + for (int i = tid; i < N; i += blockDim.x * gridDim.x) { + int byteId = i / 16 + 4; + // printf("I: [%i]; byteId: [%i]\n", i, byteId); + + shmem[threadIdx.x] = dz[i]; + __syncthreads(); + + if (threadIdx.x % 16 == 0) { + int byte = x[byteId]; + + for (int e = 0; e < 16; e++) { + if (i + e >= N) continue; + + int bitId = (i + e) % 16; + + bool hasBit = (byte & 1 << (bitId)) != 0; + bool hasSign = (byte & 1 << (bitId + 16)) != 0; + + if (hasBit) { + if (hasSign) + shmem[threadIdx.x + bitId] -= threshold; + else + shmem[threadIdx.x + bitId] += threshold; + } else if (hasSign) { + shmem[threadIdx.x + bitId] -= threshold / 2; } - __syncthreads(); - - int lim = N / 16 + 5; - for (int i = tid; i < N; i += blockDim.x * gridDim.x) { - int byteId = i / 16 + 4; -// printf("I: [%i]; byteId: [%i]\n", i, byteId); - - shmem[threadIdx.x] = dz[i]; - __syncthreads(); - - if (threadIdx.x % 16 == 0) { - int byte = x[byteId]; - - for (int e = 0; e < 16; e++) { - if (i + e >= N) - continue; - - int bitId = (i + e) % 16; - - bool hasBit = (byte & 1 << (bitId) ) != 0; - bool hasSign = (byte & 1 << (bitId + 16) ) != 0; - - if (hasBit) { - if (hasSign) - shmem[threadIdx.x + bitId] -= threshold; - else - shmem[threadIdx.x + bitId] += threshold; - } else if (hasSign) { - shmem[threadIdx.x + bitId] -= threshold / 2; - } - } - } - __syncthreads(); + } + } + __syncthreads(); - dz[i] = shmem[threadIdx.x]; - } + dz[i] = shmem[threadIdx.x]; + } } ////////////////////////////////////////////////////////////////////////// -template -__host__ void cudaDecodeBitmapGeneric(dim3 &launchDims, cudaStream_t *stream, const void *dx, Nd4jLong N, void *vdz) { - - execCudaDecodeBitmapKernel<<>>(dx, N, vdz); - sd::DebugHelper::checkErrorCode(stream, "cudeDecodeBitmap(...) failed"); +template +__host__ void cudaDecodeBitmapGeneric(dim3 &launchDims, cudaStream_t *stream, + const void *dx, Nd4jLong N, void *vdz) { + execCudaDecodeBitmapKernel + <<>>(dx, N, vdz); + sd::DebugHelper::checkErrorCode(stream, "cudeDecodeBitmap(...) failed"); +} +BUILD_SINGLE_TEMPLATE(template void SD_EXPORT cudaDecodeBitmapGeneric, + (dim3 & launchDims, cudaStream_t *stream, const void *dx, + Nd4jLong N, void *vdz), + FLOAT_TYPES); + +template +__host__ void prescanLauncher(dim3 &blocks, dim3 &threads, int shmem, + cudaStream_t *stream, int *g_odata, + const int *g_idata, int *g_blockSums, int n, + int blockIndex, int baseIndex) { + //printf("Prescan grid: <%i/%i/%i>; threads: <%i/%i/%i>; shareMemSize: %i\n", blocks.x, blocks.y, blocks.z, threads.x, threads.y, threads.z, shmem); + shmem = sd::math::nd4j_max(shmem, 16384);prescan<<>>( + g_odata, g_idata, g_blockSums, n, blockIndex, baseIndex); + +}; + +template +__global__ void convertKernel(void *dx, Nd4jLong N, void *dz) { + auto x = reinterpret_cast(dx); + auto z = reinterpret_cast(dz); + + sd::convertKernelGeneric(x, N, z); } -BUILD_SINGLE_TEMPLATE(template void ND4J_EXPORT cudaDecodeBitmapGeneric, (dim3 &launchDims, cudaStream_t *stream, const void *dx, Nd4jLong N, void *vdz), FLOAT_TYPES); - - - template - __host__ void prescanLauncher(dim3 &blocks, dim3 &threads, int shmem, cudaStream_t *stream, int *g_odata, const int *g_idata, int *g_blockSums, int n, int blockIndex, int baseIndex) { - //printf("Prescan grid: <%i/%i/%i>; threads: <%i/%i/%i>; shareMemSize: %i\n", blocks.x, blocks.y, blocks.z, threads.x, threads.y, threads.z, shmem); - shmem = sd::math::nd4j_max(shmem, 16384); - prescan<<>>(g_odata, g_idata, g_blockSums, n, blockIndex, baseIndex); - }; - - template - __global__ void convertKernel(void *dx, Nd4jLong N, void *dz) { - auto x = reinterpret_cast(dx); - auto z = reinterpret_cast(dz); - - sd::convertKernelGeneric(x, N, z); - } - -#define LIBND4J_BOOLS_LOCAL \ - (randomName0, 0), \ - (randomName1, 1) +#define LIBND4J_BOOLS_LOCAL (randomName0, 0), (randomName1, 1) - BUILD_DOUBLE_TEMPLATE(template void TypeCast::convertGenericCuda, (Nd4jPointer * extras, void *dx, Nd4jLong N, void *dz), LIBND4J_TYPES_EXTENDED, LIBND4J_TYPES_EXTENDED); - BUILD_DOUBLE_TEMPLATE(template void prescanLauncher, (dim3 &blocks, dim3 &threads, int shmem, cudaStream_t *stream, int *g_odata, const int *g_idata, int *g_blockSums, int n, int blockIndex, int baseIndex), LIBND4J_BOOLS_LOCAL, LIBND4J_BOOLS_LOCAL); +BUILD_DOUBLE_TEMPLATE(template void TypeCast::convertGenericCuda, + (Nd4jPointer * extras, void *dx, Nd4jLong N, void *dz), + LIBND4J_TYPES_EXTENDED, LIBND4J_TYPES_EXTENDED); +BUILD_DOUBLE_TEMPLATE(template void prescanLauncher, + (dim3 & blocks, dim3 &threads, int shmem, + cudaStream_t *stream, int *g_odata, const int *g_idata, + int *g_blockSums, int n, int blockIndex, int baseIndex), + LIBND4J_BOOLS_LOCAL, LIBND4J_BOOLS_LOCAL); #undef LIBND4J_BOOLS_LOCAL -} \ No newline at end of file +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/loops/impl/type_conversions.cpp b/libnd4j/include/loops/impl/type_conversions.cpp index a2f302d25194..b503e3045dfc 100644 --- a/libnd4j/include/loops/impl/type_conversions.cpp +++ b/libnd4j/include/loops/impl/type_conversions.cpp @@ -18,222 +18,251 @@ // Created by raver on 6/12/2018. // -#include -#include -#include -#include #include +#include +#include +#include +#include namespace sd { - template - _CUDA_H void TypeCast::convertFromQuantized(Nd4jPointer *extras, void *dx, Nd4jLong N, void *dz) { - // - auto z = reinterpret_cast(dz); - - auto fx = reinterpret_cast(dx); - auto amin = sd::math::nd4j_abs(fx[0]); - auto amax = sd::math::nd4j_abs(fx[1]); - +template +_CUDA_H void TypeCast::convertFromQuantized(Nd4jPointer *extras, void *dx, + Nd4jLong N, void *dz) { + // + auto z = reinterpret_cast(dz); - auto x = reinterpret_cast(dx) + 8; + auto fx = reinterpret_cast(dx); + auto amin = sd::math::nd4j_abs(fx[0]); + auto amax = sd::math::nd4j_abs(fx[1]); + auto x = reinterpret_cast(dx) + 8; - for (Nd4jLong e = 0; e < N; e++) { - z[e] = static_cast(static_cast(x[e]) / static_cast(DataTypeUtils::max()) * sd::math::nd4j_max(amin, amax)); - } - } - - template - _CUDA_H void TypeCast::convertToQuantized(Nd4jPointer *extras, void *dx, Nd4jLong N, void *dz) { - // find min/max first + for (Nd4jLong e = 0; e < N; e++) { + z[e] = static_cast(static_cast(x[e]) / + static_cast(DataTypeUtils::max()) * + sd::math::nd4j_max(amin, amax)); + } +} - auto x = reinterpret_cast(dx); - auto z = reinterpret_cast(dz); +template +_CUDA_H void TypeCast::convertToQuantized(Nd4jPointer *extras, void *dx, + Nd4jLong N, void *dz) { + // find min/max first - T mn = DataTypeUtils::max(); - T mx = -DataTypeUtils::max(); + auto x = reinterpret_cast(dx); + auto z = reinterpret_cast(dz); - for (Nd4jLong e = 0; e < N; e++) { - T v = x[e]; - if (v < mn) - mn = v; + T mn = DataTypeUtils::max(); + T mx = -DataTypeUtils::max(); - if (v > mx) - mx = v; - } + for (Nd4jLong e = 0; e < N; e++) { + T v = x[e]; + if (v < mn) mn = v; - // we shift by 2 fp32 elements - auto rz = z + 8; + if (v > mx) mx = v; + } - // - auto fz = reinterpret_cast(z); + // we shift by 2 fp32 elements + auto rz = z + 8; - float max = static_cast(mx); - float min = static_cast(mn); + // + auto fz = reinterpret_cast(z); - int max_byte = static_cast(DataTypeUtils::max()); - fz[0] = min; - fz[1] = max; + float max = static_cast(mx); + float min = static_cast(mn); - auto amax = sd::math::nd4j_abs(max); - auto amin = sd::math::nd4j_abs(min); + int max_byte = static_cast(DataTypeUtils::max()); + fz[0] = min; + fz[1] = max; - // now we actually apply quantization - auto func = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) { - rz[e] = static_cast(sd::math::nd4j_round( 1.0f * static_cast(x[e]) / sd::math::nd4j_max(amax, amin) * max_byte)); - } - }; + auto amax = sd::math::nd4j_abs(max); + auto amin = sd::math::nd4j_abs(min); - samediff::Threads::parallel_for(func, 0, N); + // now we actually apply quantization + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + rz[e] = static_cast(sd::math::nd4j_round( + 1.0f * static_cast(x[e]) / + sd::math::nd4j_max(amax, amin) * max_byte)); } + }; + + samediff::Threads::parallel_for(func, 0, N); +} - template - void TypeCast::convertToThreshold(Nd4jPointer * extras, void *dx, Nd4jLong N, void *dz) { - // we suppose that first 4 bytes are integer, second 4 bytes are float - // integer: enc length - // integer: dec length - // float: threshold - FloatBits fb; - auto x = reinterpret_cast(dx); - auto z = reinterpret_cast(dz); - int limit = z[0]; - fb.i_ = z[2]; - float threshold = fb.f_; - - // TODO: int limit is sad thing here, 2B elements limitation - auto l = static_cast(N); - z[1] = l; +template +void TypeCast::convertToThreshold(Nd4jPointer *extras, void *dx, Nd4jLong N, + void *dz) { + // we suppose that first 4 bytes are integer, second 4 bytes are float + // integer: enc length + // integer: dec length + // float: threshold + FloatBits fb; + auto x = reinterpret_cast(dx); + auto z = reinterpret_cast(dz); + int limit = z[0]; + fb.i_ = z[2]; + float threshold = fb.f_; + + // TODO: int limit is sad thing here, 2B elements limitation + auto l = static_cast(N); + z[1] = l; #ifdef _OPENMP - int threads = OmpLaunchHelper::betterThreads(N); - auto span = OmpLaunchHelper::betterSpan(N, threads); + int threads = OmpLaunchHelper::betterThreads(N); + auto span = OmpLaunchHelper::betterSpan(N, threads); #else - int threads = 1; - auto span = N; + int threads = 1; + auto span = N; #endif + T tt = static_cast(threshold); + T mtt = -tt; + + // we use 3 as offset, since first 12 bytes are occupied with header + int flimit = limit + 4; + volatile int cnt = 4; + volatile bool flag = false; + PRAGMA_OMP_PARALLEL_THREADS(threads) { + int tid = omp_get_thread_num(); + int start = span * tid; + int stop = span * (tid + 1); + if (stop > l) stop = l; + + for (int e = start; e < stop; e++) { + bool flag_load; + PRAGMA_OMP_ATOMIC_ARGS(read) + flag_load = flag; + if (flag_load) break; + + T cUpd = x[e]; + if (cUpd >= tt) { + int idx; + PRAGMA_OMP_ATOMIC_ARGS(capture) + idx = cnt++; + + if (idx >= flimit) { + PRAGMA_OMP_ATOMIC_ARGS(write) + flag = true; + break; + } - T tt = static_cast(threshold); - T mtt = -tt; - - // we use 3 as offset, since first 12 bytes are occupied with header - int flimit = limit + 4; - volatile int cnt = 4; - volatile bool flag = false; - PRAGMA_OMP_PARALLEL_THREADS(threads) - { - int tid = omp_get_thread_num(); - int start = span * tid; - int stop = span * (tid + 1); - if (stop > l) - stop = l; - - for (int e = start; e < stop; e++) { - bool flag_load; -PRAGMA_OMP_ATOMIC_ARGS(read) - flag_load = flag; - if (flag_load) - break; - - T cUpd = x[e]; - if (cUpd >= tt) { - int idx; -PRAGMA_OMP_ATOMIC_ARGS(capture) - idx = cnt++; - - if (idx >= flimit) { -PRAGMA_OMP_ATOMIC_ARGS(write) - flag = true; - break; - } - - z[idx] = e + 1; - x[e] -= tt; - } else if (cUpd <= mtt) { - int idx; -PRAGMA_OMP_ATOMIC_ARGS(capture) - idx = cnt++; - - if (idx >= flimit) { -PRAGMA_OMP_ATOMIC_ARGS(write) - flag = true; - break; - } - - - z[idx] = -e - 1; - x[e] += tt; - } - } + z[idx] = e + 1; + x[e] -= tt; + } else if (cUpd <= mtt) { + int idx; + PRAGMA_OMP_ATOMIC_ARGS(capture) + idx = cnt++; + + if (idx >= flimit) { + PRAGMA_OMP_ATOMIC_ARGS(write) + flag = true; + break; } + + z[idx] = -e - 1; + x[e] += tt; + } } + } +} - template - void TypeCast::convertFromThreshold(Nd4jPointer * extras, const void *dx, Nd4jLong N, void *dz) { - FloatBits fb; - auto z = reinterpret_cast(dz); - auto x = reinterpret_cast(dx); - int limit = x[0]; - fb.i_ = x[2]; - float threshold = fb.f_; - - // we use 3 as offset, since first 12 bytes are occupied with header - int flimit = limit + 4; - - auto func = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) { - int el = x[e]; - int ael = sd::math::nd4j_abs(el) - 1; - z[ael] += el > 0 ? static_cast(threshold) : static_cast(-threshold); - } - }; - - samediff::Threads::parallel_for(func, 4, flimit); +template +void TypeCast::convertFromThreshold(Nd4jPointer *extras, const void *dx, + Nd4jLong N, void *dz) { + FloatBits fb; + auto z = reinterpret_cast(dz); + auto x = reinterpret_cast(dx); + int limit = x[0]; + fb.i_ = x[2]; + float threshold = fb.f_; + + // we use 3 as offset, since first 12 bytes are occupied with header + int flimit = limit + 4; + + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + int el = x[e]; + int ael = sd::math::nd4j_abs(el) - 1; + z[ael] += el > 0 ? static_cast(threshold) : static_cast(-threshold); } + }; + + samediff::Threads::parallel_for(func, 4, flimit); +} - /** - * This is cpu version, so leave it here as inline, to avoid templates instantiation - * - * @tparam S - * @tparam T - * @param dx - * @param N - * @param dz - */ - template - void TypeCast::convertGeneric(Nd4jPointer * extras, void *dx, Nd4jLong N, void *dz) { - auto x = reinterpret_cast(dx); - auto z = reinterpret_cast(dz); - - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - z[i] = static_cast(static_cast(x[i])); - } - }; - samediff::Threads::parallel_for(func, 0, N); - }; - - template void TypeCast::convertFromThreshold(Nd4jPointer * extras, const void *dx, Nd4jLong N, void *dz); - template void TypeCast::convertFromThreshold(Nd4jPointer * extras, const void *dx, Nd4jLong N, void *dz); - template void TypeCast::convertFromThreshold(Nd4jPointer * extras, const void *dx, Nd4jLong N, void *dz); - template void TypeCast::convertFromThreshold(Nd4jPointer * extras, const void *dx, Nd4jLong N, void *dz); - - template void TypeCast::convertToThreshold(Nd4jPointer * extras, void *dx, Nd4jLong N, void *dz); - template void TypeCast::convertToThreshold(Nd4jPointer * extras, void *dx, Nd4jLong N, void *dz); - template void TypeCast::convertToThreshold(Nd4jPointer * extras, void *dx, Nd4jLong N, void *dz); - template void TypeCast::convertToThreshold(Nd4jPointer * extras, void *dx, Nd4jLong N, void *dz); - - template void TypeCast::convertFromQuantized(Nd4jPointer * extras, void *dx, Nd4jLong N, void *dz); - template void TypeCast::convertFromQuantized(Nd4jPointer * extras, void *dx, Nd4jLong N, void *dz); - template void TypeCast::convertFromQuantized(Nd4jPointer * extras, void *dx, Nd4jLong N, void *dz); - - template void TypeCast::convertToQuantized(Nd4jPointer * extras, void *dx, Nd4jLong N, void *dz); - template void TypeCast::convertToQuantized(Nd4jPointer * extras, void *dx, Nd4jLong N, void *dz); - template void TypeCast::convertToQuantized(Nd4jPointer * extras, void *dx, Nd4jLong N, void *dz); +/** + * This is cpu version, so leave it here as inline, to avoid templates + * instantiation + * + * @tparam S + * @tparam T + * @param dx + * @param N + * @param dz + */ +template +void TypeCast::convertGeneric(Nd4jPointer *extras, void *dx, Nd4jLong N, + void *dz) { + auto x = reinterpret_cast(dx); + auto z = reinterpret_cast(dz); + + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + z[i] = static_cast(static_cast(x[i])); + } + }; + samediff::Threads::parallel_for(func, 0, N); +}; + +template void TypeCast::convertFromThreshold(Nd4jPointer *extras, + const void *dx, Nd4jLong N, + void *dz); +template void TypeCast::convertFromThreshold(Nd4jPointer *extras, + const void *dx, Nd4jLong N, + void *dz); +template void TypeCast::convertFromThreshold(Nd4jPointer *extras, + const void *dx, + Nd4jLong N, void *dz); +template void TypeCast::convertFromThreshold(Nd4jPointer *extras, + const void *dx, + Nd4jLong N, void *dz); + +template void TypeCast::convertToThreshold(Nd4jPointer *extras, + void *dx, Nd4jLong N, + void *dz); +template void TypeCast::convertToThreshold(Nd4jPointer *extras, void *dx, + Nd4jLong N, void *dz); +template void TypeCast::convertToThreshold(Nd4jPointer *extras, + void *dx, Nd4jLong N, + void *dz); +template void TypeCast::convertToThreshold(Nd4jPointer *extras, + void *dx, Nd4jLong N, + void *dz); + +template void TypeCast::convertFromQuantized(Nd4jPointer *extras, + void *dx, Nd4jLong N, + void *dz); +template void TypeCast::convertFromQuantized(Nd4jPointer *extras, + void *dx, Nd4jLong N, + void *dz); +template void TypeCast::convertFromQuantized(Nd4jPointer *extras, + void *dx, Nd4jLong N, + void *dz); + +template void TypeCast::convertToQuantized(Nd4jPointer *extras, + void *dx, Nd4jLong N, + void *dz); +template void TypeCast::convertToQuantized(Nd4jPointer *extras, void *dx, + Nd4jLong N, void *dz); +template void TypeCast::convertToQuantized(Nd4jPointer *extras, + void *dx, Nd4jLong N, + void *dz); #ifndef __CLION_IDE__ - BUILD_DOUBLE_TEMPLATE(template void TypeCast::convertGeneric, (Nd4jPointer * extras, void *dx, Nd4jLong N, void *dz), LIBND4J_TYPES, LIBND4J_TYPES) +BUILD_DOUBLE_TEMPLATE(template void TypeCast::convertGeneric, + (Nd4jPointer * extras, void *dx, Nd4jLong N, void *dz), + LIBND4J_TYPES, LIBND4J_TYPES) #endif -} +} // namespace sd diff --git a/libnd4j/include/loops/indexreduce.h b/libnd4j/include/loops/indexreduce.h old mode 100755 new mode 100644 index 2e8bc33d20af..e7a14460b2fa --- a/libnd4j/include/loops/indexreduce.h +++ b/libnd4j/include/loops/indexreduce.h @@ -26,103 +26,92 @@ #ifdef _OPENMP #include #endif -#include +#include +#include #include +#include #include -#include -#include #ifdef __CUDACC__ #include #include #endif - #include - -#include "system/pairwise_util.h" - #include "legacy_ops.h" +#include "system/pairwise_util.h" namespace functions { - namespace indexreduce { +namespace indexreduce { - template - class IndexReduce { - public: +template +class IndexReduce { + public: #ifdef __CUDABLAS__ - static __device__ void transform(int opNum, - const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *result, const Nd4jLong *resultShapeInfo, - int *dimension,int dimensionLength, - int postProcessOrNot, - int *allocationBuffer, void *reductionBuffer, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset); - - template - static __device__ void aggregatePartials(IndexValue **sPartialsRef, Nd4jLong tid, Nd4jLong numElements, void *extraParams); - - - template - static __device__ void transform(const void *dx, const Nd4jLong *xShapeInfo, - void *extraParams, - void *result, const Nd4jLong *resultShapeInfo, - int *dimension, int dimensionLength, - int postProcessOrNot, - int *allocationBuffer, void *reductionBuffer, - const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets); - - - static _CUDA_H void executeIndexReduceScalar(dim3 launchDims, cudaStream_t *stream, - int op, - const void *dx, const Nd4jLong *xShapeInfo, - int xRank, - void *extraParams, - void *result, const Nd4jLong *resultShapeInfo, - int zRank, - int *dimension, int dimensionLength, - int postProcessOrNot, - int *allocationBuffer, void *reductionBuffer, - const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets); - - static _CUDA_H void executeIndexReduce(dim3 launchDims, cudaStream_t *stream, - int op, - const void *dx, const Nd4jLong *xShapeInfo, - int xRank, - void *extraParams, - void *result, const Nd4jLong *resultShapeInfo, - int zRank, - int *dimension, int dimensionLength, - int postProcessOrNot, - int *allocationBuffer, void *reductionBuffer, - const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets); + static __device__ void transform( + int opNum, const void *x, const Nd4jLong *xShapeInfo, void *extraParams, + void *result, const Nd4jLong *resultShapeInfo, int *dimension, + int dimensionLength, int postProcessOrNot, int *allocationBuffer, + void *reductionBuffer, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffset); + + template + static __device__ void aggregatePartials(IndexValue **sPartialsRef, + Nd4jLong tid, Nd4jLong numElements, + void *extraParams); + + template + static __device__ void transform(const void *dx, const Nd4jLong *xShapeInfo, + void *extraParams, void *result, + const Nd4jLong *resultShapeInfo, + int *dimension, int dimensionLength, + int postProcessOrNot, int *allocationBuffer, + void *reductionBuffer, + const Nd4jLong *tadOnlyShapeInfo, + const Nd4jLong *tadOffsets); + + static _CUDA_H void executeIndexReduceScalar( + dim3 launchDims, cudaStream_t *stream, int op, const void *dx, + const Nd4jLong *xShapeInfo, int xRank, void *extraParams, void *result, + const Nd4jLong *resultShapeInfo, int zRank, int *dimension, + int dimensionLength, int postProcessOrNot, int *allocationBuffer, + void *reductionBuffer, const Nd4jLong *tadOnlyShapeInfo, + const Nd4jLong *tadOffsets); + + static _CUDA_H void executeIndexReduce( + dim3 launchDims, cudaStream_t *stream, int op, const void *dx, + const Nd4jLong *xShapeInfo, int xRank, void *extraParams, void *result, + const Nd4jLong *resultShapeInfo, int zRank, int *dimension, + int dimensionLength, int postProcessOrNot, int *allocationBuffer, + void *reductionBuffer, const Nd4jLong *tadOnlyShapeInfo, + const Nd4jLong *tadOffsets); #else - static Nd4jLong execScalar(int opNum, const void *x, const Nd4jLong *xShapeInfo, void *extraParams); - - static void exec(int opNum, - const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *result, const Nd4jLong *resultShapeInfoBuffer, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset); - - template - static _CUDA_H Nd4jLong execScalar(const void *x, const Nd4jLong *xShapeInfo, void *extraParams); - - template - static _CUDA_H void exec(const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *result, const Nd4jLong *resultShapeInfoBuffer, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset); + static Nd4jLong execScalar(int opNum, const void *x, + const Nd4jLong *xShapeInfo, void *extraParams); + + static void exec(int opNum, const void *x, const Nd4jLong *xShapeInfo, + void *extraParams, void *result, + const Nd4jLong *resultShapeInfoBuffer, int *dimension, + int dimensionLength, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffset); + + template + static _CUDA_H Nd4jLong execScalar(const void *x, const Nd4jLong *xShapeInfo, + void *extraParams); + + template + static _CUDA_H void exec(const void *x, const Nd4jLong *xShapeInfo, + void *extraParams, void *result, + const Nd4jLong *resultShapeInfoBuffer, + int *dimension, int dimensionLength, + const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffset); #endif - }; - } -} +}; +} // namespace indexreduce +} // namespace functions #endif /* INDEXREDUCE_H_ */ - diff --git a/libnd4j/include/loops/legacy_ops.h b/libnd4j/include/loops/legacy_ops.h index 001f8806c6db..c54d53bd63b7 100644 --- a/libnd4j/include/loops/legacy_ops.h +++ b/libnd4j/include/loops/legacy_ops.h @@ -21,392 +21,145 @@ #ifndef PROJECT_LEGACY_OPS_H #define PROJECT_LEGACY_OPS_H -#define AGGREGATE_OPS \ - (0, aggregateOps::HierarchicSoftmax) ,\ - (1, aggregateOps::Dot) ,\ - (2, aggregateOps::Axpy) ,\ - (3, aggregateOps::SkipGram) ,\ - (4, aggregateOps::CBOW) ,\ - (5, aggregateOps::GEMM) - -#define BROADCAST_INT_OPS \ - (0, ShiftLeft), \ - (1, ShiftRight), \ - (2, CyclicShiftLeft), \ - (3, CyclicShiftRight), \ - (4, IntAnd), \ - (5, IntOr), \ - (6, IntXor) - - -#define BROADCAST_BOOL_OPS \ - (0, EqualTo),\ - (1, GreaterThan),\ - (2, LessThan),\ - (3, Epsilon),\ - (4, GreaterThanOrEqual),\ - (5, MatchCondition) ,\ - (6, NotEqualTo),\ - (7, And),\ - (8, Or),\ - (9, Xor) ,\ - (10, Not) ,\ - (11, LessThanOrEqual) - -#define BROADCAST_OPS \ - (0, Add), \ - (1, Subtract), \ - (2, Multiply), \ - (3, Divide), \ - (4, ReverseDivide), \ - (5, ReverseSubtract), \ - (6, CopyPws), \ - (7, Pow), \ - (13, MinPairwise) ,\ - (14, MaxPairwise) ,\ - (15, AMinPairwise) ,\ - (16, AMaxPairwise) ,\ - (17, SquaredSubtract),\ - (18, FloorMod),\ - (19, FloorDiv),\ - (20, ReverseMod),\ - (21, SafeDivide),\ - (22, Mod) ,\ - (23, TruncateDiv), \ - (26, Atan2) ,\ - (27, LogicalOr) ,\ - (28, LogicalXor) ,\ - (29, LogicalNot) ,\ - (30, LogicalAnd), \ - (31, DivideNoNan), \ - (32, IGamma), \ - (33, IGammac),\ - (34, PowDerivative) +#define AGGREGATE_OPS \ + (0, aggregateOps::HierarchicSoftmax), (1, aggregateOps::Dot), \ + (2, aggregateOps::Axpy), (3, aggregateOps::SkipGram), \ + (4, aggregateOps::CBOW), (5, aggregateOps::GEMM) + +#define BROADCAST_INT_OPS \ + (0, ShiftLeft), (1, ShiftRight), (2, CyclicShiftLeft), \ + (3, CyclicShiftRight), (4, IntAnd), (5, IntOr), (6, IntXor) + +#define BROADCAST_BOOL_OPS \ + (0, EqualTo), (1, GreaterThan), (2, LessThan), (3, Epsilon), \ + (4, GreaterThanOrEqual), (5, MatchCondition), (6, NotEqualTo), (7, And), \ + (8, Or), (9, Xor), (10, Not), (11, LessThanOrEqual) + +#define BROADCAST_OPS \ + (0, Add), (1, Subtract), (2, Multiply), (3, Divide), (4, ReverseDivide), \ + (5, ReverseSubtract), (6, CopyPws), (7, Pow), (13, MinPairwise), \ + (14, MaxPairwise), (15, AMinPairwise), (16, AMaxPairwise), \ + (17, SquaredSubtract), (18, FloorMod), (19, FloorDiv), (20, ReverseMod), \ + (21, SafeDivide), (22, Mod), (23, TruncateDiv), (26, Atan2), \ + (27, LogicalOr), (28, LogicalXor), (29, LogicalNot), (30, LogicalAnd), \ + (31, DivideNoNan), (32, IGamma), (33, IGammac), (34, PowDerivative) // these ops return same data type as input -#define TRANSFORM_SAME_OPS \ - (0, Abs), \ - (1, Sign), \ - (2, Ones), \ - (3, Neg), \ - (4, Round), \ - (5, TimesOneMinus), \ - (6, Cube), \ - (7, OneMinus), \ - (11, Reciprocal), \ - (12, Square), \ - (13, CompareAndSetTransform) ,\ - (15, Identity), \ - (17, Ceiling), \ - (18, Floor), \ - (19, ClipByValue) ,\ - (21, Copy) +#define TRANSFORM_SAME_OPS \ + (0, Abs), (1, Sign), (2, Ones), (3, Neg), (4, Round), (5, TimesOneMinus), \ + (6, Cube), (7, OneMinus), (11, Reciprocal), (12, Square), \ + (13, CompareAndSetTransform), (15, Identity), (17, Ceiling), \ + (18, Floor), (19, ClipByValue), (21, Copy) -#define TRANSFORM_ANY_OPS \ - (0, Assign) +#define TRANSFORM_ANY_OPS (0, Assign) // these ops return bool -#define TRANSFORM_BOOL_OPS \ - (1, IsInf), \ - (2, IsNan), \ - (3, IsFinite), \ - (4, IsInfOrNan), \ - (5, MatchConditionBool), \ - (6, IsPositive) , \ - (7, Not), \ - (8, IsNegative) - - -#define TRANSFORM_STRICT_OPS \ - (2, ScaledTanh), \ - (3, Affine), \ - (4, TanhDerivative), \ - (5, HardTanhDerivative), \ - (6, SigmoidDerivative), \ - (7, SoftSignDerivative), \ - (8, TanDerivative) ,\ - (9, SELUDerivative) ,\ - (10, HardSigmoidDerivative) ,\ - (11, RationalTanhDerivative) ,\ - (12, RectifiedTanhDerivative) ,\ - (13, SwishDerivative) ,\ - (14, ACoshDerivative) ,\ - (15, ASinhDerivative) ,\ - (16, SinhDerivative), \ - (17, LogSigmoidDerivative) ,\ - (18, SpecialDerivative), \ - (19, Stabilize), \ - (20, StabilizeFP16) ,\ - (21, CubeDerivative) ,\ - (22, Cosine), \ - (23, Exp), \ - (24, Log), \ - (25, SetRange), \ - (26, Sigmoid), \ - (27, Sin), \ - (28, SoftPlus), \ - (29, Tanh), \ - (30, ACos), \ - (31, ASin), \ - (32, ATan), \ - (33, HardTanh), \ - (34, SoftSign), \ - (36, HardSigmoid), \ - (37, RationalTanh) ,\ - (38, RectifiedTanh) ,\ - (39, Sinh) ,\ - (40, Cosh) ,\ - (41, Tan) ,\ - (42, SELU) ,\ - (43, Swish) ,\ - (44, Log1p), \ - (45, Erf), \ - (46, ACosh), \ - (47, ASinh), \ - (48, Rint), \ - (49, LogSigmoid), \ - (50, Erfc) ,\ - (51, Expm1), \ - (52, ATanh) ,\ - (53, GELU) ,\ - (54, GELUDerivative), \ - (55, PreciseGELU) ,\ - (56, PreciseGELUDerivative), \ - (57, Mish),\ - (58, MishDerivative) +#define TRANSFORM_BOOL_OPS \ + (1, IsInf), (2, IsNan), (3, IsFinite), (4, IsInfOrNan), \ + (5, MatchConditionBool), (6, IsPositive), (7, Not), (8, IsNegative) + +#define TRANSFORM_STRICT_OPS \ + (2, ScaledTanh), (3, Affine), (4, TanhDerivative), (5, HardTanhDerivative), \ + (6, SigmoidDerivative), (7, SoftSignDerivative), (8, TanDerivative), \ + (9, SELUDerivative), (10, HardSigmoidDerivative), \ + (11, RationalTanhDerivative), (12, RectifiedTanhDerivative), \ + (13, SwishDerivative), (14, ACoshDerivative), (15, ASinhDerivative), \ + (16, SinhDerivative), (17, LogSigmoidDerivative), \ + (18, SpecialDerivative), (19, Stabilize), (20, StabilizeFP16), \ + (21, CubeDerivative), (22, Cosine), (23, Exp), (24, Log), \ + (25, SetRange), (26, Sigmoid), (27, Sin), (28, SoftPlus), (29, Tanh), \ + (30, ACos), (31, ASin), (32, ATan), (33, HardTanh), (34, SoftSign), \ + (36, HardSigmoid), (37, RationalTanh), (38, RectifiedTanh), (39, Sinh), \ + (40, Cosh), (41, Tan), (42, SELU), (43, Swish), (44, Log1p), (45, Erf), \ + (46, ACosh), (47, ASinh), (48, Rint), (49, LogSigmoid), (50, Erfc), \ + (51, Expm1), (52, ATanh), (53, GELU), (54, GELUDerivative), \ + (55, PreciseGELU), (56, PreciseGELUDerivative), (57, Mish), \ + (58, MishDerivative) // these ops return one of FLOAT data types -#define TRANSFORM_FLOAT_OPS \ - (1, Sqrt), \ - (3, RSqrt) - +#define TRANSFORM_FLOAT_OPS (1, Sqrt), (3, RSqrt) #define SUMMARY_STATS_OPS \ - (0, SummaryStatsVariance), \ - (1, SummaryStatsStandardDeviation) - -#define SCALAR_INT_OPS \ - (0, ShiftLeft) ,\ - (1, ShiftRight), \ - (2, CyclicShiftLeft), \ - (3, CyclicShiftRight), \ - (4, IntAnd), \ - (5, IntOr), \ - (6, IntXor) - -#define SCALAR_BOOL_OPS \ - (0, EqualTo),\ - (1, GreaterThan),\ - (2, LessThan),\ - (3, Epsilon),\ - (4, GreaterThanOrEqual),\ - (5, MatchCondition) ,\ - (6, NotEqualTo),\ - (7, And),\ - (8, Or),\ - (9, Xor) ,\ - (10, Not) ,\ - (11, LessThanOrEqual) - -#define SCALAR_OPS \ - (0, Add),\ - (1, Subtract),\ - (2, Multiply),\ - (3, Divide),\ - (4, ReverseDivide),\ - (5, ReverseSubtract),\ - (6, MaxPairwise),\ - (7, ELU), \ - (8, ELUDerivative), \ - (13, MinPairwise),\ - (14, CopyPws),\ - (15, Mod),\ - (16, ReverseMod),\ - (17, Remainder),\ - (18, FMod) ,\ - (19, TruncateDiv) ,\ - (20, FloorDiv) ,\ - (21, FloorMod), \ - (22, SquaredSubtract),\ - (23, SafeDivide), \ - (24, AMaxPairwise), \ - (25, AMinPairwise), \ - (26, Atan2) ,\ - (27, LogicalOr) ,\ - (28, LogicalXor) ,\ - (29, LogicalNot) ,\ - (30, LogicalAnd) ,\ - (31, Pow) ,\ - (32, PowDerivative) ,\ - (33, CompareAndSet) ,\ - (34, SXELogitsSmoother), \ - (35, LeakyRELU), \ - (36, LeakyRELUDerivative), \ - (37, ReplaceNans) ,\ - (38, LogX) ,\ - (39, RELU), \ - (40, RELU6), \ - (41, Step), \ - (42, LstmClip), \ - (43, TruncateMod) ,\ - (44, SquaredReverseSubtract) ,\ - (45, ReversePow), \ - (46, DivideNoNan), \ - (47, IGamma), \ - (48, IGammac), \ - (49, RELUDerivative) - - - - - -#define REDUCE3_OPS \ - (0, ManhattanDistance), \ - (1, EuclideanDistance), \ - (2, CosineSimilarity), \ - (3, Dot), \ - (4, EqualsWithEps) ,\ - (5, CosineDistance) ,\ - (6, JaccardDistance) ,\ - (7, SimpleHammingDistance) - -#define REDUCE_LONG_OPS \ - (0, CountNonZero), \ - (1, CountZero), \ - (2, MatchCondition) - -#define REDUCE_BOOL_OPS \ - (0, Any) ,\ - (1, All), \ - (2, IsFinite), \ - (3, IsInfOrNan), \ - (4, IsNan), \ - (5, IsInf), \ - (6, IsPositive), \ - (7, IsNegative) - -#define REDUCE_SAME_OPS \ - (0, Sum), \ - (1, Max), \ - (2, Min), \ - (3, Prod), \ - (4, ASum), \ - (5, AMax) ,\ - (6, AMin) ,\ - (7, ReduceSameBenchmarkOp) - -#define REDUCE_FLOAT_OPS \ - (0, Mean), \ - (1, AMean) ,\ - (2, Norm1), \ - (3, Norm2), \ - (4, NormMax), \ - (5, NormFrobenius), \ - (6, NormP), \ - (7, SquaredNorm) ,\ - (8, Entropy) ,\ - (9, LogEntropy) ,\ - (10, ShannonEntropy) ,\ - (12, ReduceFloatBenchmarkOp) - - - - -#define RANDOM_OPS \ - (0, UniformDistribution) ,\ - (1, DropOut) ,\ - (2, DropOutInverted) ,\ - (3, ProbablisticMerge) ,\ - (4, Linspace) ,\ - (5, Choice) ,\ - (6, GaussianDistribution) ,\ - (7, BernoulliDistribution) ,\ - (8, BinomialDistribution),\ - (9, BinomialDistributionEx),\ - (10, LogNormalDistribution) ,\ - (11, TruncatedNormalDistribution) ,\ - (12, AlphaDropOut),\ - (13, ExponentialDistribution),\ - (14, ExponentialDistributionInv), \ - (15, PoissonDistribution), \ - (16, GammaDistribution) - -#define PAIRWISE_INT_OPS \ - (0, ShiftLeft), \ - (1, ShiftRight), \ - (2, CyclicShiftLeft), \ - (3, CyclicShiftRight), \ - (4, IntAnd), \ - (5, IntOr), \ - (6, IntXor) - -#define PAIRWISE_BOOL_OPS \ - (0, EqualTo),\ - (1, GreaterThan),\ - (2, LessThan),\ - (3, Epsilon),\ - (4, GreaterThanOrEqual),\ - (5, MatchCondition) ,\ - (6, NotEqualTo),\ - (7, And),\ - (8, Or),\ - (9, Xor) ,\ - (10, Not) ,\ - (11, LessThanOrEqual) - -#define PAIRWISE_TRANSFORM_OPS \ - (0, Add),\ - (1, CopyPws),\ - (2, Divide),\ - (3, Multiply),\ - (4, Pow),\ - (5, ReverseSubtract),\ - (6, Subtract),\ - (7, MaxPairwise),\ - (8, MinPairwise),\ - (9, Copy2) ,\ - (10, Axpy),\ - (11, ReverseDivide),\ - (12, CompareAndSet),\ - (13, CompareAndReplace),\ - (14, Remainder),\ - (15, FMod),\ - (16, Atan2) ,\ - (17, TruncateDiv),\ - (18, FloorDiv), \ - (19, FloorMod) ,\ - (20, SquaredSubtract) ,\ - (21, ReverseMod),\ - (22, SafeDivide), \ - (23, Mod) ,\ - (24, RelativeError) ,\ - (25, BinaryRelativeError) ,\ - (26, BinaryMinimumAbsoluteRelativeError) ,\ - (27, LogicalOr) ,\ - (28, LogicalXor) ,\ - (29, LogicalNot) ,\ - (30, LogicalAnd) ,\ - (31, PowDerivative), \ - (32, LogPoissonLoss), \ - (33, LogPoissonLossFull) , \ - (34, AMaxPairwise), \ - (35, AMinPairwise) ,\ - (36, TruncateMod), \ - (37, ReplaceNans), \ - (38, DivideNoNan), \ - (39, IGamma), \ - (40, IGammac) - - - -#define INDEX_REDUCE_OPS \ - (0, IndexMax), \ - (1, IndexMin), \ - (2, IndexAbsoluteMax), \ - (3, IndexAbsoluteMin) , \ - (4, FirstIndex) , \ - (5, LastIndex) - - - -#endif //PROJECT_LEGACY_OPS_H + (0, SummaryStatsVariance), (1, SummaryStatsStandardDeviation) + +#define SCALAR_INT_OPS \ + (0, ShiftLeft), (1, ShiftRight), (2, CyclicShiftLeft), \ + (3, CyclicShiftRight), (4, IntAnd), (5, IntOr), (6, IntXor) + +#define SCALAR_BOOL_OPS \ + (0, EqualTo), (1, GreaterThan), (2, LessThan), (3, Epsilon), \ + (4, GreaterThanOrEqual), (5, MatchCondition), (6, NotEqualTo), (7, And), \ + (8, Or), (9, Xor), (10, Not), (11, LessThanOrEqual) + +#define SCALAR_OPS \ + (0, Add), (1, Subtract), (2, Multiply), (3, Divide), (4, ReverseDivide), \ + (5, ReverseSubtract), (6, MaxPairwise), (7, ELU), (8, ELUDerivative), \ + (13, MinPairwise), (14, CopyPws), (15, Mod), (16, ReverseMod), \ + (17, Remainder), (18, FMod), (19, TruncateDiv), (20, FloorDiv), \ + (21, FloorMod), (22, SquaredSubtract), (23, SafeDivide), \ + (24, AMaxPairwise), (25, AMinPairwise), (26, Atan2), (27, LogicalOr), \ + (28, LogicalXor), (29, LogicalNot), (30, LogicalAnd), (31, Pow), \ + (32, PowDerivative), (33, CompareAndSet), (34, SXELogitsSmoother), \ + (35, LeakyRELU), (36, LeakyRELUDerivative), (37, ReplaceNans), \ + (38, LogX), (39, RELU), (40, RELU6), (41, Step), (42, LstmClip), \ + (43, TruncateMod), (44, SquaredReverseSubtract), (45, ReversePow), \ + (46, DivideNoNan), (47, IGamma), (48, IGammac), (49, RELUDerivative) + +#define REDUCE3_OPS \ + (0, ManhattanDistance), (1, EuclideanDistance), (2, CosineSimilarity), \ + (3, Dot), (4, EqualsWithEps), (5, CosineDistance), (6, JaccardDistance), \ + (7, SimpleHammingDistance) + +#define REDUCE_LONG_OPS (0, CountNonZero), (1, CountZero), (2, MatchCondition) + +#define REDUCE_BOOL_OPS \ + (0, Any), (1, All), (2, IsFinite), (3, IsInfOrNan), (4, IsNan), (5, IsInf), \ + (6, IsPositive), (7, IsNegative) + +#define REDUCE_SAME_OPS \ + (0, Sum), (1, Max), (2, Min), (3, Prod), (4, ASum), (5, AMax), (6, AMin), \ + (7, ReduceSameBenchmarkOp) + +#define REDUCE_FLOAT_OPS \ + (0, Mean), (1, AMean), (2, Norm1), (3, Norm2), (4, NormMax), \ + (5, NormFrobenius), (6, NormP), (7, SquaredNorm), (8, Entropy), \ + (9, LogEntropy), (10, ShannonEntropy), (12, ReduceFloatBenchmarkOp) + +#define RANDOM_OPS \ + (0, UniformDistribution), (1, DropOut), (2, DropOutInverted), \ + (3, ProbablisticMerge), (4, Linspace), (5, Choice), \ + (6, GaussianDistribution), (7, BernoulliDistribution), \ + (8, BinomialDistribution), (9, BinomialDistributionEx), \ + (10, LogNormalDistribution), (11, TruncatedNormalDistribution), \ + (12, AlphaDropOut), (13, ExponentialDistribution), \ + (14, ExponentialDistributionInv), (15, PoissonDistribution), \ + (16, GammaDistribution) + +#define PAIRWISE_INT_OPS \ + (0, ShiftLeft), (1, ShiftRight), (2, CyclicShiftLeft), \ + (3, CyclicShiftRight), (4, IntAnd), (5, IntOr), (6, IntXor) + +#define PAIRWISE_BOOL_OPS \ + (0, EqualTo), (1, GreaterThan), (2, LessThan), (3, Epsilon), \ + (4, GreaterThanOrEqual), (5, MatchCondition), (6, NotEqualTo), (7, And), \ + (8, Or), (9, Xor), (10, Not), (11, LessThanOrEqual) + +#define PAIRWISE_TRANSFORM_OPS \ + (0, Add), (1, CopyPws), (2, Divide), (3, Multiply), (4, Pow), \ + (5, ReverseSubtract), (6, Subtract), (7, MaxPairwise), (8, MinPairwise), \ + (9, Copy2), (10, Axpy), (11, ReverseDivide), (12, CompareAndSet), \ + (13, CompareAndReplace), (14, Remainder), (15, FMod), (16, Atan2), \ + (17, TruncateDiv), (18, FloorDiv), (19, FloorMod), \ + (20, SquaredSubtract), (21, ReverseMod), (22, SafeDivide), (23, Mod), \ + (24, RelativeError), (25, BinaryRelativeError), \ + (26, BinaryMinimumAbsoluteRelativeError), (27, LogicalOr), \ + (28, LogicalXor), (29, LogicalNot), (30, LogicalAnd), \ + (31, PowDerivative), (32, LogPoissonLoss), (33, LogPoissonLossFull), \ + (34, AMaxPairwise), (35, AMinPairwise), (36, TruncateMod), \ + (37, ReplaceNans), (38, DivideNoNan), (39, IGamma), (40, IGammac) + +#define INDEX_REDUCE_OPS \ + (0, IndexMax), (1, IndexMin), (2, IndexAbsoluteMax), (3, IndexAbsoluteMin), \ + (4, FirstIndex), (5, LastIndex) + +#endif // PROJECT_LEGACY_OPS_H diff --git a/libnd4j/include/loops/pairwise_bool.h b/libnd4j/include/loops/pairwise_bool.h index 9cc8f220cbc8..1c87a28ed923 100644 --- a/libnd4j/include/loops/pairwise_bool.h +++ b/libnd4j/include/loops/pairwise_bool.h @@ -26,87 +26,72 @@ #ifdef _OPENMP #include #endif -#include +#include #include -#include -#include -#include +#include #include +#include +#include #include -#include +#include #ifdef __CUDACC__ #include #include #endif - #include "legacy_ops.h" using namespace simdOps; namespace functions { - namespace pairwise_transforms { +namespace pairwise_transforms { /** * Transforms involving 2 arrays */ - template - class PairWiseBoolTransform { - public: - +template +class PairWiseBoolTransform { + public: #ifdef __CUDACC__ - template - static __host__ void intermediateShaped(dim3& launchDims, cudaStream_t *stream, - const void *vx, const Nd4jLong *xShapeInfo, - const void *vy, const Nd4jLong *yShapeInfo, - void *vz, const Nd4jLong *zShapeInfo, - void *vextraParams); - - static __host__ void executeCudaShaped(dim3& launchDims, cudaStream_t *stream, - int opNum, - const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *z, const Nd4jLong *zShapeInfo, - void *extraParams); + template + static __host__ void intermediateShaped( + dim3 &launchDims, cudaStream_t *stream, const void *vx, + const Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo, + void *vz, const Nd4jLong *zShapeInfo, void *vextraParams); + static __host__ void executeCudaShaped( + dim3 &launchDims, cudaStream_t *stream, int opNum, const void *x, + const Nd4jLong *xShapeInfo, const void *y, const Nd4jLong *yShapeInfo, + void *z, const Nd4jLong *zShapeInfo, void *extraParams); #else - static void exec(int opNum, - const void *dx, const Nd4jLong *xShapeBuffer, - const void *y, const Nd4jLong *yShapeBuffer, - void *result, const Nd4jLong *resultShapeBuffer, - void *extraParams, - uint64_t start, uint64_t stop); - - static void exec(int opNum, - const void *dx, Nd4jLong xStride, - const void *y, Nd4jLong yStride, - void *result, Nd4jLong resultStride, - void *extraParams, - Nd4jLong n, - uint64_t start, uint64_t stop); - - - template - static void exec(const void *vx, const Nd4jLong* xShapeBuffer, - const void *vy, const Nd4jLong* yShapeBuffer, - void *vresult, const Nd4jLong* resultShapeBuffer, - void *vextraParams, - uint64_t start, uint64_t stop); - - template - static void exec(const void *vx, Nd4jLong xStride, - const void *vy, Nd4jLong yStride, - void *vresult, Nd4jLong resultStride, - void *vextraParams, - Nd4jLong n, - uint64_t start, uint64_t stop); + static void exec(int opNum, const void *dx, const Nd4jLong *xShapeBuffer, + const void *y, const Nd4jLong *yShapeBuffer, void *result, + const Nd4jLong *resultShapeBuffer, void *extraParams, + uint64_t start, uint64_t stop); + + static void exec(int opNum, const void *dx, Nd4jLong xStride, const void *y, + Nd4jLong yStride, void *result, Nd4jLong resultStride, + void *extraParams, Nd4jLong n, uint64_t start, + uint64_t stop); + + template + static void exec(const void *vx, const Nd4jLong *xShapeBuffer, const void *vy, + const Nd4jLong *yShapeBuffer, void *vresult, + const Nd4jLong *resultShapeBuffer, void *vextraParams, + uint64_t start, uint64_t stop); + + template + static void exec(const void *vx, Nd4jLong xStride, const void *vy, + Nd4jLong yStride, void *vresult, Nd4jLong resultStride, + void *vextraParams, Nd4jLong n, uint64_t start, + uint64_t stop); #endif - }; - } -} +}; +} // namespace pairwise_transforms +} // namespace functions #endif /* PAIRWISE_TRANSFORM_H_ */ diff --git a/libnd4j/include/loops/pairwise_int.h b/libnd4j/include/loops/pairwise_int.h index 64deebc04723..d80c32989dcb 100644 --- a/libnd4j/include/loops/pairwise_int.h +++ b/libnd4j/include/loops/pairwise_int.h @@ -26,88 +26,72 @@ #ifdef _OPENMP #include #endif -#include +#include #include -#include -#include -#include +#include #include +#include +#include #include -#include +#include #ifdef __CUDACC__ #include #include #endif - - #include "legacy_ops.h" using namespace simdOps; namespace functions { - namespace pairwise_transforms { +namespace pairwise_transforms { /** * Transforms involving 2 arrays */ - template - class PairWiseIntTransform { - public: - +template +class PairWiseIntTransform { + public: #ifdef __CUDACC__ - template - static __host__ void intermediateShaped(dim3& launchDims, cudaStream_t *stream, - const void *vx, const Nd4jLong *xShapeInfo, - const void *vy, const Nd4jLong *yShapeInfo, - void *vz, const Nd4jLong *zShapeInfo, - void *vextraParams); - - static __host__ void executeCudaShaped(dim3& launchDims, cudaStream_t *stream, - int opNum, - const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *z, const Nd4jLong *zShapeInfo, - void *extraParams); + template + static __host__ void intermediateShaped( + dim3 &launchDims, cudaStream_t *stream, const void *vx, + const Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo, + void *vz, const Nd4jLong *zShapeInfo, void *vextraParams); + static __host__ void executeCudaShaped( + dim3 &launchDims, cudaStream_t *stream, int opNum, const void *x, + const Nd4jLong *xShapeInfo, const void *y, const Nd4jLong *yShapeInfo, + void *z, const Nd4jLong *zShapeInfo, void *extraParams); #else - static void exec(int opNum, - const void *dx, const Nd4jLong *xShapeBuffer, - const void *y, const Nd4jLong *yShapeBuffer, - void *result, const Nd4jLong *resultShapeBuffer, - void *extraParams, - uint64_t start, uint64_t stop); - - static void exec(int opNum, - const void *dx, Nd4jLong xStride, - const void *y, Nd4jLong yStride, - void *result, Nd4jLong resultStride, - void *extraParams, - Nd4jLong n, - uint64_t start, uint64_t stop); - - - template - static void exec(const void *vx, const Nd4jLong* xShapeBuffer, - const void *vy, const Nd4jLong* yShapeBuffer, - void *vresult, const Nd4jLong* resultShapeBuffer, - void *vextraParams, - uint64_t start,uint64_t stop); - - template - static void exec(const void *vx, Nd4jLong xStride, - const void *vy, Nd4jLong yStride, - void *vresult, Nd4jLong resultStride, - void *vextraParams, - Nd4jLong n, - uint64_t start, uint64_t stop); + static void exec(int opNum, const void *dx, const Nd4jLong *xShapeBuffer, + const void *y, const Nd4jLong *yShapeBuffer, void *result, + const Nd4jLong *resultShapeBuffer, void *extraParams, + uint64_t start, uint64_t stop); + + static void exec(int opNum, const void *dx, Nd4jLong xStride, const void *y, + Nd4jLong yStride, void *result, Nd4jLong resultStride, + void *extraParams, Nd4jLong n, uint64_t start, + uint64_t stop); + + template + static void exec(const void *vx, const Nd4jLong *xShapeBuffer, const void *vy, + const Nd4jLong *yShapeBuffer, void *vresult, + const Nd4jLong *resultShapeBuffer, void *vextraParams, + uint64_t start, uint64_t stop); + + template + static void exec(const void *vx, Nd4jLong xStride, const void *vy, + Nd4jLong yStride, void *vresult, Nd4jLong resultStride, + void *vextraParams, Nd4jLong n, uint64_t start, + uint64_t stop); #endif - }; - } -} +}; +} // namespace pairwise_transforms +} // namespace functions #endif /* PAIRWISE_TRANSFORM_H_ */ diff --git a/libnd4j/include/loops/pairwise_transform.h b/libnd4j/include/loops/pairwise_transform.h old mode 100755 new mode 100644 index b3b514df6235..008cb553ac0b --- a/libnd4j/include/loops/pairwise_transform.h +++ b/libnd4j/include/loops/pairwise_transform.h @@ -27,13 +27,14 @@ #include #endif -#include -#include +#include #include +#include +#include #include #include + #include "legacy_ops.h" -#include #ifdef __CUDACC__ #include @@ -41,68 +42,53 @@ #include #endif - namespace functions { - namespace pairwise_transforms { +namespace pairwise_transforms { /** * Transforms involving 2 arrays */ - template - class PairWiseTransform { - public: - +template +class PairWiseTransform { + public: #ifdef __CUDABLAS__ - template - static __host__ void intermediateShaped(dim3& launchDims, cudaStream_t *stream, - const void *vx, const Nd4jLong *xShapeInfo, - const void *vy, const Nd4jLong *yShapeInfo, - void *vz, const Nd4jLong *zShapeInfo, - void *vextraParams); + template + static __host__ void intermediateShaped( + dim3 &launchDims, cudaStream_t *stream, const void *vx, + const Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo, + void *vz, const Nd4jLong *zShapeInfo, void *vextraParams); - static __host__ void executeCudaShaped(dim3& launchDims, cudaStream_t *stream, - int opNum, - const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *z, const Nd4jLong *zShapeInfo, - void *extraParams); + static __host__ void executeCudaShaped( + dim3 &launchDims, cudaStream_t *stream, int opNum, const void *x, + const Nd4jLong *xShapeInfo, const void *y, const Nd4jLong *yShapeInfo, + void *z, const Nd4jLong *zShapeInfo, void *extraParams); #endif - public: - - static void exec(int opNum, - const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *z, const Nd4jLong *zShapeInfo, - void *extraParams, - uint64_t start, uint64_t stop); - - static void exec(int opNum, - const void *x, Nd4jLong xStride, - const void *y, Nd4jLong yStride, - void *z, Nd4jLong resultStride, - void *extraParams, - Nd4jLong len, - uint64_t start, uint64_t stop); - - - template - static void exec(const void *vx, const Nd4jLong* xShapeInfo, - const void *vy, const Nd4jLong* yShapeInfo, - void *vresult, const Nd4jLong* zShapeInfo, - void *vextraParams, - uint64_t start, uint64_t stop); - - template - static void exec(const void *vx, Nd4jLong xStride, - const void *vy, Nd4jLong yStride, - void *vresult, Nd4jLong resultStride, - void *vextraParams, - Nd4jLong len, - uint64_t start, uint64_t stop); - }; - } -} + public: + static void exec(int opNum, const void *x, const Nd4jLong *xShapeInfo, + const void *y, const Nd4jLong *yShapeInfo, void *z, + const Nd4jLong *zShapeInfo, void *extraParams, + uint64_t start, uint64_t stop); + + static void exec(int opNum, const void *x, Nd4jLong xStride, const void *y, + Nd4jLong yStride, void *z, Nd4jLong resultStride, + void *extraParams, Nd4jLong len, uint64_t start, + uint64_t stop); + + template + static void exec(const void *vx, const Nd4jLong *xShapeInfo, const void *vy, + const Nd4jLong *yShapeInfo, void *vresult, + const Nd4jLong *zShapeInfo, void *vextraParams, + uint64_t start, uint64_t stop); + + template + static void exec(const void *vx, Nd4jLong xStride, const void *vy, + Nd4jLong yStride, void *vresult, Nd4jLong resultStride, + void *vextraParams, Nd4jLong len, uint64_t start, + uint64_t stop); +}; +} // namespace pairwise_transforms +} // namespace functions #endif /* PAIRWISE_TRANSFORM_H_ */ diff --git a/libnd4j/include/loops/random.h b/libnd4j/include/loops/random.h index 9b35f472fd25..6c46fcd18def 100644 --- a/libnd4j/include/loops/random.h +++ b/libnd4j/include/loops/random.h @@ -21,81 +21,84 @@ #ifndef LIBND4J_RANDOM_H #define LIBND4J_RANDOM_H - - -#include #include +#include +#include #include #include -#include - - namespace functions { - namespace random { - - template - class RandomFunction { - public: +namespace random { +template +class RandomFunction { + public: #ifdef __CUDABLAS__ - template - static _CUDA_D void execTransformCuda(Nd4jPointer state, - const void *x, const Nd4jLong *xShapeBuffer, - const void *y, const Nd4jLong *yShapeBuffer, - void *z, const Nd4jLong *zShapeBuffer, - void *extraArguments); - - template - static _CUDA_D void execTransformCuda(Nd4jPointer state, - const void *x, const Nd4jLong *xShapeBuffer, - void *z, const Nd4jLong *zShapeBuffer, - void *extraArguments); - - template - static _CUDA_D void execTransformCuda(Nd4jPointer state, void *z, const Nd4jLong *zShapeBuffer, void *extraArguments); - - - static _CUDA_H void executeCudaSingle(dim3& launchDims, cudaStream_t* stream, - int opNum, - Nd4jPointer stateHost, - void *z, const Nd4jLong *zShapeBuffer, - void *extraArguments); - - - static _CUDA_H void executeCudaDouble(dim3& launchDims, cudaStream_t* stream, - int opNum, - Nd4jPointer stateHost, - const void *x, const Nd4jLong *xShapeBuffer, - void *z, const Nd4jLong *zShapeBuffer, - void *extraArguments); - - - static _CUDA_H void executeCudaTriple(dim3& launchDims, cudaStream_t* stream, - int opNum, - Nd4jPointer stateHost, - const void *x, const Nd4jLong *xShapeBuffer, - const void *y, const Nd4jLong *yShapeBuffer, - void *z, const Nd4jLong* zShapeBuffer, - void *extraArguments); + template + static _CUDA_D void execTransformCuda(Nd4jPointer state, const void *x, + const Nd4jLong *xShapeBuffer, + const void *y, + const Nd4jLong *yShapeBuffer, void *z, + const Nd4jLong *zShapeBuffer, + void *extraArguments); + + template + static _CUDA_D void execTransformCuda(Nd4jPointer state, const void *x, + const Nd4jLong *xShapeBuffer, void *z, + const Nd4jLong *zShapeBuffer, + void *extraArguments); + + template + static _CUDA_D void execTransformCuda(Nd4jPointer state, void *z, + const Nd4jLong *zShapeBuffer, + void *extraArguments); + + static _CUDA_H void executeCudaSingle(dim3 &launchDims, cudaStream_t *stream, + int opNum, Nd4jPointer stateHost, + void *z, const Nd4jLong *zShapeBuffer, + void *extraArguments); + + static _CUDA_H void executeCudaDouble(dim3 &launchDims, cudaStream_t *stream, + int opNum, Nd4jPointer stateHost, + const void *x, + const Nd4jLong *xShapeBuffer, void *z, + const Nd4jLong *zShapeBuffer, + void *extraArguments); + + static _CUDA_H void executeCudaTriple( + dim3 &launchDims, cudaStream_t *stream, int opNum, Nd4jPointer stateHost, + const void *x, const Nd4jLong *xShapeBuffer, const void *y, + const Nd4jLong *yShapeBuffer, void *z, const Nd4jLong *zShapeBuffer, + void *extraArguments); #else - template - static void execTransform(Nd4jPointer state, const void *x, const Nd4jLong *xShapeBuffer, const void *y, const Nd4jLong *yShapeBuffer, void *z, const Nd4jLong *zShapeBuffer, void *extraArguments); - - template - static void execTransform(Nd4jPointer state, const void *x, const Nd4jLong *xShapeBuffer, void *z, const Nd4jLong *zShapeBuffer, void *extraArguments); - - template - static void execTransform(Nd4jPointer state, void *z, const Nd4jLong *zShapeBuffer, void *extraArguments); - - static void execTransform(int opNum, Nd4jPointer state, const void *x, const Nd4jLong *xShapeBuffer, void *z, const Nd4jLong *zShapeBuffer, void *extraArguments); - static void execTransform(int opNum, Nd4jPointer state, const void *x, const Nd4jLong *xShapeBuffer, const void *y, const Nd4jLong *yShapeBuffer, void *z, const Nd4jLong *zShapeBuffer, void *extraArguments); - static void execTransform(int opNum, Nd4jPointer state, void *z, const Nd4jLong *zShapeBuffer, void *extraArguments); + template + static void execTransform(Nd4jPointer state, const void *x, + const Nd4jLong *xShapeBuffer, const void *y, + const Nd4jLong *yShapeBuffer, void *z, + const Nd4jLong *zShapeBuffer, void *extraArguments); + + template + static void execTransform(Nd4jPointer state, const void *x, + const Nd4jLong *xShapeBuffer, void *z, + const Nd4jLong *zShapeBuffer, void *extraArguments); + + template + static void execTransform(Nd4jPointer state, void *z, + const Nd4jLong *zShapeBuffer, void *extraArguments); + + static void execTransform(int opNum, Nd4jPointer state, const void *x, + const Nd4jLong *xShapeBuffer, void *z, + const Nd4jLong *zShapeBuffer, void *extraArguments); + static void execTransform(int opNum, Nd4jPointer state, const void *x, + const Nd4jLong *xShapeBuffer, const void *y, + const Nd4jLong *yShapeBuffer, void *z, + const Nd4jLong *zShapeBuffer, void *extraArguments); + static void execTransform(int opNum, Nd4jPointer state, void *z, + const Nd4jLong *zShapeBuffer, void *extraArguments); #endif - }; - } -} - +}; +} // namespace random +} // namespace functions -#endif //LIBND4J_RANDOM_H +#endif // LIBND4J_RANDOM_H diff --git a/libnd4j/include/loops/reduce3.h b/libnd4j/include/loops/reduce3.h old mode 100755 new mode 100644 index f2496f1fe443..1a5608846782 --- a/libnd4j/include/loops/reduce3.h +++ b/libnd4j/include/loops/reduce3.h @@ -30,242 +30,193 @@ #ifdef _OPENMP #include #endif -#include -#include -#include +#include +#include #include +#include #include +#include #include -#include -#include +#include #ifdef __CUDACC__ #include #include #endif - #include "legacy_ops.h" using namespace simdOps; namespace functions { -namespace reduce3 { +namespace reduce3 { /** * Reduce involving * 2 arrays */ -template +template class Reduce3 { - - public: - + public: #ifdef __CUDACC__ - virtual __device__ - inline Y opAtomic(X d1, X d2, Y *extraParamsRef) = 0; - - /** - * Aggregate shared memory - * @param sPartialsRef - * @param tid - * @param extraParams - */ - template - static __device__ void aggregatePartials(void* sPartials, Nd4jLong tid, Nd4jLong numItems, void *extraParams); - - template - static __device__ void execScalarCuda(const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *extraParams, - void *z, const Nd4jLong *zShapeInfo, - int *allocationPointer, void *reductionBuffer, - const Nd4jLong *tadOnlyShapeInfo); - - template - static __device__ void transformAll(const void *vx, const Nd4jLong *xShapeInfo, - const void *vy, const Nd4jLong *yShapeInfo, - void *extraParams, - void *vz, const Nd4jLong *zShapeInfo, - int *dimension, int dimensionLength, - int postProcessOrNot, - int *allocationPointer, - const Nd4jLong *xTadShapeInfo, const Nd4jLong *xOffsets, - const Nd4jLong *yTadShapeInfo, const Nd4jLong *yOffsets); - - /** - Perform a reduction - @param n the number of elements - @param xOffset the starting offset - @param dx the data to perform the reduction on - @param incx the increment on which to perform the reduction - @param extraParams extra parameters used for calculations - @param result where to store the result of the reduction - */ - template - static __device__ void transform(const void *vx, const Nd4jLong *xShapeInfo, - const void *vy, const Nd4jLong *yShapeInfo, - void *extraParams, - void *vz, const Nd4jLong *zShapeInfo, - int *dimension, int dimensionLength, - int postProcessOrNot, - int *allocationPointer, - const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *yTadOnlyShapeInfo, const Nd4jLong *yTadOffsets); - - - static __device__ void execCuda(int opNum, - const void *vx, const Nd4jLong *xShapeInfo, - const void *vy, const Nd4jLong *yShapeInfo, - void *extraParams, - void *vz, const Nd4jLong *zShapeInfo, - int *dimension, int dimensionLength, - int postProcessOrNot, - int *allocationPointer, - const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *yTadOnlyShapeInfo, const Nd4jLong *yTadOffsets); - - - static __device__ void execAllCuda(int opNum, - const void *vx, const Nd4jLong *xShapeInfo, - const void *vy, const Nd4jLong *yShapeInfo, - void *extraParams, - void *vz, const Nd4jLong *zShapeInfo, - int *dimension, int dimensionLength, - int postProcessOrNot, - int *allocationPointer, - const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *yTadOnlyShapeInfo, const Nd4jLong *yTadOffsets); - - - static __device__ void execScalarCuda(int opNum, - const void *vx, const Nd4jLong *xShapeInfo, - const void *vy, const Nd4jLong *yShapeInfo, - void *extraParams, - void *vz, const Nd4jLong *zShapeInfo, - int * allocationPointer, void *reductionBuffer, - const Nd4jLong *tadOnlyShapeInfo); - - - static __host__ void exec(dim3 launchDims, cudaStream_t *stream, - int opNum, - const void *vx, const Nd4jLong *xShapeInfo, - const void *vy, const Nd4jLong *yShapeInfo, - void *extraParams, + virtual __device__ inline Y opAtomic(X d1, X d2, Y *extraParamsRef) = 0; + + /** + * Aggregate shared memory + * @param sPartialsRef + * @param tid + * @param extraParams + */ + template + static __device__ void aggregatePartials(void *sPartials, Nd4jLong tid, + Nd4jLong numItems, + void *extraParams); + + template + static __device__ void execScalarCuda( + const void *x, const Nd4jLong *xShapeInfo, const void *y, + const Nd4jLong *yShapeInfo, void *extraParams, void *z, + const Nd4jLong *zShapeInfo, int *allocationPointer, void *reductionBuffer, + const Nd4jLong *tadOnlyShapeInfo); + + template + static __device__ void transformAll( + const void *vx, const Nd4jLong *xShapeInfo, const void *vy, + const Nd4jLong *yShapeInfo, void *extraParams, void *vz, + const Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, + int postProcessOrNot, int *allocationPointer, + const Nd4jLong *xTadShapeInfo, const Nd4jLong *xOffsets, + const Nd4jLong *yTadShapeInfo, const Nd4jLong *yOffsets); + + /** +Perform a reduction +@param n the number of elements +@param xOffset the starting offset +@param dx the data to perform the reduction on +@param incx the increment on which to perform the reduction +@param extraParams extra parameters used for calculations +@param result where to store the result of the reduction +*/ + template + static __device__ void transform( + const void *vx, const Nd4jLong *xShapeInfo, const void *vy, + const Nd4jLong *yShapeInfo, void *extraParams, void *vz, + const Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, + int postProcessOrNot, int *allocationPointer, + const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets, + const Nd4jLong *yTadOnlyShapeInfo, const Nd4jLong *yTadOffsets); + + static __device__ void execCuda( + int opNum, const void *vx, const Nd4jLong *xShapeInfo, const void *vy, + const Nd4jLong *yShapeInfo, void *extraParams, void *vz, + const Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, + int postProcessOrNot, int *allocationPointer, + const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets, + const Nd4jLong *yTadOnlyShapeInfo, const Nd4jLong *yTadOffsets); + + static __device__ void execAllCuda( + int opNum, const void *vx, const Nd4jLong *xShapeInfo, const void *vy, + const Nd4jLong *yShapeInfo, void *extraParams, void *vz, + const Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, + int postProcessOrNot, int *allocationPointer, + const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets, + const Nd4jLong *yTadOnlyShapeInfo, const Nd4jLong *yTadOffsets); + + static __device__ void execScalarCuda( + int opNum, const void *vx, const Nd4jLong *xShapeInfo, const void *vy, + const Nd4jLong *yShapeInfo, void *extraParams, void *vz, + const Nd4jLong *zShapeInfo, int *allocationPointer, void *reductionBuffer, + const Nd4jLong *tadOnlyShapeInfo); + + static __host__ void exec( + dim3 launchDims, cudaStream_t *stream, int opNum, const void *vx, + const Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo, + void *extraParams, void *vz, const Nd4jLong *zShapeInfo, int *dimension, + int dimensionLength, int postProcessOrNot, int *allocationPointer, + const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets, + const Nd4jLong *yTadOnlyShapeInfo, const Nd4jLong *yTadOffsets); + + static __host__ void execAll( + dim3 launchDims, cudaStream_t *stream, int opNum, const void *vx, + const Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo, + void *extraParams, void *vz, const Nd4jLong *zShapeInfo, int *dimension, + int dimensionLength, int postProcessOrNot, int *allocationPointer, + const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets, + const Nd4jLong *yTadOnlyShapeInfo, const Nd4jLong *yTadOffsets); + + static __host__ void execScalar(dim3 launchDims, cudaStream_t *stream, + int opNum, const void *vx, + const Nd4jLong *xShapeInfo, const void *vy, + const Nd4jLong *yShapeInfo, void *extraParams, void *vz, const Nd4jLong *zShapeInfo, - int *dimension, int dimensionLength, - int postProcessOrNot, - int *allocationPointer, - const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *yTadOnlyShapeInfo, const Nd4jLong *yTadOffsets); - - static __host__ void execAll(dim3 launchDims, cudaStream_t *stream, - int opNum, - const void *vx, const Nd4jLong *xShapeInfo, - const void *vy, const Nd4jLong *yShapeInfo, - void *extraParams, - void *vz, const Nd4jLong *zShapeInfo, - int *dimension, int dimensionLength, - int postProcessOrNot, - int *allocationPointer, - const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *yTadOnlyShapeInfo, const Nd4jLong *yTadOffsets); - - static __host__ void execScalar(dim3 launchDims, cudaStream_t *stream, - int opNum, - const void *vx, const Nd4jLong *xShapeInfo, - const void *vy, const Nd4jLong *yShapeInfo, - void *extraParams, - void *vz, const Nd4jLong *zShapeInfo, - int* allocationPointer, void *reductionBuffer, - const Nd4jLong *tadOnlyShapeInfo); + int *allocationPointer, void *reductionBuffer, + const Nd4jLong *tadOnlyShapeInfo); #else - template - static void execScalar(const void *vx, const Nd4jLong *xShapeInfo, - void *vextraParams, - const void *vy, const Nd4jLong *yShapeInfo, - void *vz, const Nd4jLong *zShapeInfo); - - - static void execScalar(int opNum, - const void *x, const Nd4jLong *xShapeInfo, - void *extraParamsVals, - const void *y, const Nd4jLong *yShapeInfo, - void *z, const Nd4jLong *zShapeInfo); - - - template - static void exec(const void *vx, const Nd4jLong *xShapeInfo, - void *vextraParams, - const void *vy, const Nd4jLong *yShapeInfo, - void *vz, const Nd4jLong *zShapeInfo, - int *dimension, int dimensionLength, - int64_t start, int64_t stop); - - - template - static void exec(const void *vx, const Nd4jLong *xShapeInfo, - void *vextraParams, - const void *vy, const Nd4jLong *yShapeInfo, - void *vz, const Nd4jLong *zShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets, - int64_t start, int64_t stop); - - - template - static void execAll(const void *vx, const Nd4jLong *xShapeInfo, - void *vextraParams, - const void *vy, const Nd4jLong *yShapeInfo, - void *vz, const Nd4jLong *zShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *xTadShapeInfo, const Nd4jLong *xOffsets, - const Nd4jLong *yTadShapeInfo, const Nd4jLong *yOffsets, - int64_t start, int64_t stop); - - - static void exec(int opNum, - const void *vx, const Nd4jLong *xShapeInfo, - void *extraParamsVals, - const void *vy, const Nd4jLong *yShapeInfo, - void *vz, const Nd4jLong *zShapeInfo, - int *dimension, int dimensionLength, - int64_t start, int64_t stop); - - - static void exec(int opNum, - const void *vx, const Nd4jLong *xShapeInfo, - void *extraParamsVals, - const void *vy, const Nd4jLong *yShapeInfo, - void *vz, const Nd4jLong *zShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets, - int64_t start, int64_t stop); - - - static void execAll(int opNum, - const void *vx, const Nd4jLong *xShapeInfo, - void *extraParamsVals, - const void *vy, const Nd4jLong *yShapeInfo, - void *vz, const Nd4jLong *zShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *xTadShapeInfo, const Nd4jLong *xOffsets, - const Nd4jLong *yTadShapeInfo, const Nd4jLong *yOffsets, - int64_t start, int64_t stop); + template + static void execScalar(const void *vx, const Nd4jLong *xShapeInfo, + void *vextraParams, const void *vy, + const Nd4jLong *yShapeInfo, void *vz, + const Nd4jLong *zShapeInfo); + + static void execScalar(int opNum, const void *x, const Nd4jLong *xShapeInfo, + void *extraParamsVals, const void *y, + const Nd4jLong *yShapeInfo, void *z, + const Nd4jLong *zShapeInfo); + + template + static void exec(const void *vx, const Nd4jLong *xShapeInfo, + void *vextraParams, const void *vy, + const Nd4jLong *yShapeInfo, void *vz, + const Nd4jLong *zShapeInfo, int *dimension, + int dimensionLength, int64_t start, int64_t stop); + + template + static void exec(const void *vx, const Nd4jLong *xShapeInfo, + void *vextraParams, const void *vy, + const Nd4jLong *yShapeInfo, void *vz, + const Nd4jLong *zShapeInfo, int *dimension, + int dimensionLength, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets, int64_t start, int64_t stop); + + template + static void execAll(const void *vx, const Nd4jLong *xShapeInfo, + void *vextraParams, const void *vy, + const Nd4jLong *yShapeInfo, void *vz, + const Nd4jLong *zShapeInfo, int *dimension, + int dimensionLength, const Nd4jLong *xTadShapeInfo, + const Nd4jLong *xOffsets, const Nd4jLong *yTadShapeInfo, + const Nd4jLong *yOffsets, int64_t start, int64_t stop); + + static void exec(int opNum, const void *vx, const Nd4jLong *xShapeInfo, + void *extraParamsVals, const void *vy, + const Nd4jLong *yShapeInfo, void *vz, + const Nd4jLong *zShapeInfo, int *dimension, + int dimensionLength, int64_t start, int64_t stop); + + static void exec(int opNum, const void *vx, const Nd4jLong *xShapeInfo, + void *extraParamsVals, const void *vy, + const Nd4jLong *yShapeInfo, void *vz, + const Nd4jLong *zShapeInfo, int *dimension, + int dimensionLength, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets, int64_t start, int64_t stop); + + static void execAll(int opNum, const void *vx, const Nd4jLong *xShapeInfo, + void *extraParamsVals, const void *vy, + const Nd4jLong *yShapeInfo, void *vz, + const Nd4jLong *zShapeInfo, int *dimension, + int dimensionLength, const Nd4jLong *xTadShapeInfo, + const Nd4jLong *xOffsets, const Nd4jLong *yTadShapeInfo, + const Nd4jLong *yOffsets, int64_t start, int64_t stop); #endif }; - - -} -} +} // namespace reduce3 +} // namespace functions #ifdef __CUDACC__ #endif - - #endif /* REDUCE3_H_ */ diff --git a/libnd4j/include/loops/reduce_bool.h b/libnd4j/include/loops/reduce_bool.h index a74d53033f5f..9faa41ab93dd 100644 --- a/libnd4j/include/loops/reduce_bool.h +++ b/libnd4j/include/loops/reduce_bool.h @@ -14,21 +14,20 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - #ifndef REDUCE_BOOL_H #define REDUCE_BOOL_H #include //#include -#include #include +#include #ifdef _OPENMP #include #endif #include -#include -#include #include +#include #include +#include #pragma once #ifdef __CUDACC__ @@ -36,12 +35,11 @@ #include #endif - #include "legacy_ops.h" -//an op for the kernel +// an op for the kernel namespace functions { - namespace reduce { +namespace reduce { /** * A reduce function @@ -50,128 +48,156 @@ namespace functions { * via aggregating member * elements. */ - template - class ReduceBoolFunction { - public: +template +class ReduceBoolFunction { + public: #ifdef __CUDACC__ - template - static __device__ void aggregatePartials(void *sPartials, Nd4jLong tid, Nd4jLong numItems, void *extraParams); - - template - static __device__ void execScalarCuda(const void *vx, const Nd4jLong *xShapeInfo, void *extraParams, void *vz, const Nd4jLong *zShapeInfo, void *reductionBuffer, const Nd4jLong *tadOnlyShapeInfo); - - template - static __device__ void transformCudaXD(const void *vx, const Nd4jLong *xShapeInfo, void *extraParams, void *vz, const Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, void *reductionBuffer, const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets); - - template - static __host__ void intermediateScalar(dim3 launchDims, cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, const Nd4jLong *hXShapeInfo, void *extraParams, void *vz, const Nd4jLong *zShapeInfo, const Nd4jLong *hZShapeInfo, int *dimension, int dimensionLength, void *reductionBuffer, const Nd4jLong *tadOnlyShapeInfo); - - template - static __host__ void intermediateXD(dim3 launchDims, cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, const Nd4jLong *hXShapeInfo, void *extraParams, void *vz, const Nd4jLong *zShapeInfo, const Nd4jLong *hZShapeInfo, int *dimension, int dimensionLength, void *reductionPointer, const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets); - - static __host__ void execReduceScalar(dim3 launchDims, cudaStream_t *stream, int opNum, const void *vx, const Nd4jLong *xShapeInfo, const Nd4jLong* hXShapeInfo, void *extraParams, void *vz, const Nd4jLong *zShapeInfo, const Nd4jLong* hZShapeInfo, int *dimension, int dimensionLength, void *reductionBuffer, const Nd4jLong *tadOnlyShapeInfo); - - static __host__ void execReduceXD(dim3 launchDims, cudaStream_t *stream, int opNum, int rank, const void *vx, const Nd4jLong *xShapeInfo, const Nd4jLong* hXShapeInfo, void *extraParams, void *vz, const Nd4jLong *zShapeInfo, const Nd4jLong* hZShapeInfo, int *dimension, int dimensionLength, void *reductionPointer, const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets); + template + static __device__ void aggregatePartials(void *sPartials, Nd4jLong tid, + Nd4jLong numItems, + void *extraParams); + + template + static __device__ void execScalarCuda(const void *vx, + const Nd4jLong *xShapeInfo, + void *extraParams, void *vz, + const Nd4jLong *zShapeInfo, + void *reductionBuffer, + const Nd4jLong *tadOnlyShapeInfo); + + template + static __device__ void transformCudaXD( + const void *vx, const Nd4jLong *xShapeInfo, void *extraParams, void *vz, + const Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, + void *reductionBuffer, const Nd4jLong *tadOnlyShapeInfo, + const Nd4jLong *tadOffsets); + + template + static __host__ void intermediateScalar( + dim3 launchDims, cudaStream_t *stream, const void *vx, + const Nd4jLong *xShapeInfo, const Nd4jLong *hXShapeInfo, + void *extraParams, void *vz, const Nd4jLong *zShapeInfo, + const Nd4jLong *hZShapeInfo, int *dimension, int dimensionLength, + void *reductionBuffer, const Nd4jLong *tadOnlyShapeInfo); + + template + static __host__ void intermediateXD( + dim3 launchDims, cudaStream_t *stream, const void *vx, + const Nd4jLong *xShapeInfo, const Nd4jLong *hXShapeInfo, + void *extraParams, void *vz, const Nd4jLong *zShapeInfo, + const Nd4jLong *hZShapeInfo, int *dimension, int dimensionLength, + void *reductionPointer, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets); + + static __host__ void execReduceScalar( + dim3 launchDims, cudaStream_t *stream, int opNum, const void *vx, + const Nd4jLong *xShapeInfo, const Nd4jLong *hXShapeInfo, + void *extraParams, void *vz, const Nd4jLong *zShapeInfo, + const Nd4jLong *hZShapeInfo, int *dimension, int dimensionLength, + void *reductionBuffer, const Nd4jLong *tadOnlyShapeInfo); + + static __host__ void execReduceXD( + dim3 launchDims, cudaStream_t *stream, int opNum, int rank, + const void *vx, const Nd4jLong *xShapeInfo, const Nd4jLong *hXShapeInfo, + void *extraParams, void *vz, const Nd4jLong *zShapeInfo, + const Nd4jLong *hZShapeInfo, int *dimension, int dimensionLength, + void *reductionPointer, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets); #else - /** - * Reduce down to 1 number - * @param x the input - * @param xShapeInfo the shape information - * for the input - * @param extraParams the extra params - * @return - */ - template - static _CUDA_H Z execScalar(const void *x, const Nd4jLong *xShapeInfo, void *extraParams); - - template - static _CUDA_H void execScalar(const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *z, const Nd4jLong *zShapeInfo); - - - static Z execScalar(int opNum, const void *x, const Nd4jLong *xShapeInfo, void *extraParams); - - static void execScalar(int opNum, - const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *z, const Nd4jLong *zShapeInfo); - - static void exec(int opNum, - const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *result, const Nd4jLong *resultShapeInfoBuffer, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset, - int64_t start, int64_t stop); - - /** - * Execute on the cpu - * @param x the input data - * @param xShapeInfo the shape information for x - * @param extraParams the extra parameters - * @param result the result buffer - * @param resultShapeInfoBuffer the shape information - * @param dimension the dimension to perform - * the reduce along long - * @param dimensionLength the length of the dimension buffer - */ - - - template - static void _CUDA_H exec(const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *result, const Nd4jLong *resultShapeInfoBuffer, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset, - int64_t start, int64_t stop); - - /** - * CPU implementation - * @param x the input data - * @param xShapeInfo the shape information for - * the input data - * @param extraParams the extra parameters for the problem - * @param result the result buffer - * @param resultShapeInfo the shape information - */ - template - static void _CUDA_H exec(const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *result, const Nd4jLong *resultShapeInfo); - - - - /** - * Reduce down to 1 number - * @param x the input - * @param xShapeInfo the shape information - * for the input - * @param extraParams the extra params - * @return - */ - template - static Z _CUDA_H execScalar(const void *x, Nd4jLong xElementWiseStride, Nd4jLong length, void *extraParams); + /** + * Reduce down to 1 number + * @param x the input + * @param xShapeInfo the shape information + * for the input + * @param extraParams the extra params + * @return + */ + template + static _CUDA_H Z execScalar(const void *x, const Nd4jLong *xShapeInfo, + void *extraParams); + + template + static _CUDA_H void execScalar(const void *x, const Nd4jLong *xShapeInfo, + void *extraParams, void *z, + const Nd4jLong *zShapeInfo); + + static Z execScalar(int opNum, const void *x, const Nd4jLong *xShapeInfo, + void *extraParams); + + static void execScalar(int opNum, const void *x, const Nd4jLong *xShapeInfo, + void *extraParams, void *z, + const Nd4jLong *zShapeInfo); + + static void exec(int opNum, const void *x, const Nd4jLong *xShapeInfo, + void *extraParams, void *result, + const Nd4jLong *resultShapeInfoBuffer, int *dimension, + int dimensionLength, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffset, int64_t start, int64_t stop); + + /** + * Execute on the cpu + * @param x the input data + * @param xShapeInfo the shape information for x + * @param extraParams the extra parameters + * @param result the result buffer + * @param resultShapeInfoBuffer the shape information + * @param dimension the dimension to perform + * the reduce along long + * @param dimensionLength the length of the dimension buffer + */ + + template + static void _CUDA_H exec(const void *x, const Nd4jLong *xShapeInfo, + void *extraParams, void *result, + const Nd4jLong *resultShapeInfoBuffer, + int *dimension, int dimensionLength, + const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffset, int64_t start, + int64_t stop); + + /** + * CPU implementation + * @param x the input data + * @param xShapeInfo the shape information for + * the input data + * @param extraParams the extra parameters for the problem + * @param result the result buffer + * @param resultShapeInfo the shape information + */ + template + static void _CUDA_H exec(const void *x, const Nd4jLong *xShapeInfo, + void *extraParams, void *result, + const Nd4jLong *resultShapeInfo); + + /** + * Reduce down to 1 number + * @param x the input + * @param xShapeInfo the shape information + * for the input + * @param extraParams the extra params + * @return + */ + template + static Z _CUDA_H execScalar(const void *x, Nd4jLong xElementWiseStride, + Nd4jLong length, void *extraParams); #endif - }; - +}; #ifdef __CUDACC__ - /** - * - * @param extraParams - * @param sPartials - * @param sMemSize - */ - template - __device__ void initializeShared(T *extraParams, T **sPartials, int sMemSize); +/** + * + * @param extraParams + * @param sPartials + * @param sMemSize + */ +template +__device__ void initializeShared(T *extraParams, T **sPartials, int sMemSize); #endif - } +} // namespace reduce -} +} // namespace functions #endif - diff --git a/libnd4j/include/loops/reduce_float.h b/libnd4j/include/loops/reduce_float.h index c78082f8ed57..4437650db788 100644 --- a/libnd4j/include/loops/reduce_float.h +++ b/libnd4j/include/loops/reduce_float.h @@ -14,21 +14,20 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - #ifndef REDUCE_FLOAT_H #define REDUCE_FLOAT_H #include //#include -#include #include +#include #ifdef _OPENMP #include #endif #include -#include -#include #include +#include #include +#include #pragma once #ifdef __CUDACC__ @@ -36,12 +35,11 @@ #include #endif - #include "legacy_ops.h" -//an op for the kernel +// an op for the kernel namespace functions { - namespace reduce { +namespace reduce { /** * A reduce function @@ -50,132 +48,156 @@ namespace functions { * via aggregating member * elements. */ - template - class ReduceFloatFunction { - - public: - -#ifdef __CUDACC__ - template - static __device__ void aggregatePartials(void *sPartials, Nd4jLong tid, Nd4jLong numItems, void *extraParams); - - template - static __device__ void execScalarCuda(const void *vx, const Nd4jLong *xShapeInfo, void *extraParams, void *vz, const Nd4jLong *zShapeInfo, void *reductionBuffer, const Nd4jLong *tadOnlyShapeInfo); - - template - static __device__ void transformCudaXD(const void *vx, const Nd4jLong *xShapeInfo, void *extraParams, void *vz, const Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, void *reductionBuffer, const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets); - - template - static __host__ void intermediateScalar(dim3 launchDims, cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, const Nd4jLong *hXShapeInfo, void *extraParams, void *vz, const Nd4jLong *zShapeInfo, const Nd4jLong *hZShapeInfo, int *dimension, int dimensionLength, void *reductionBuffer, const Nd4jLong *tadOnlyShapeInfo); - - template - static __host__ void intermediateXD(dim3 launchDims, cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, const Nd4jLong *hXShapeInfo, void *extraParams, void *vz, const Nd4jLong *zShape, const Nd4jLong *hZShapeInfo, int *dimension, int dimensionLength, void *reductionPointer, const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets); - - static __host__ void execReduceScalar(dim3 launchDims, cudaStream_t *stream, int opNum, const void *vx, const Nd4jLong *xShapeInfo, const Nd4jLong *hXShapeInfo, void *extraParams, void *vz, const Nd4jLong *zShapeInfo, const Nd4jLong *hZShapeInfo, int *dimension, int dimensionLength, void *reductionBuffer, const Nd4jLong *tadOnlyShapeInfo); - - static __host__ void execReduceXD(dim3 launchDims, cudaStream_t *stream, int opNum, int rank, const void *vx, const Nd4jLong *xShapeInfo, const Nd4jLong *hXShapeInfo, void *extraParams, void *vz, const Nd4jLong *zShape, const Nd4jLong *hZShapeInfo, int *dimension, int dimensionLength, void *reductionPointer, const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets); +template +class ReduceFloatFunction { + public: +#ifdef __CUDACC__ + template + static __device__ void aggregatePartials(void *sPartials, Nd4jLong tid, + Nd4jLong numItems, + void *extraParams); + + template + static __device__ void execScalarCuda(const void *vx, + const Nd4jLong *xShapeInfo, + void *extraParams, void *vz, + const Nd4jLong *zShapeInfo, + void *reductionBuffer, + const Nd4jLong *tadOnlyShapeInfo); + + template + static __device__ void transformCudaXD( + const void *vx, const Nd4jLong *xShapeInfo, void *extraParams, void *vz, + const Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, + void *reductionBuffer, const Nd4jLong *tadOnlyShapeInfo, + const Nd4jLong *tadOffsets); + + template + static __host__ void intermediateScalar( + dim3 launchDims, cudaStream_t *stream, const void *vx, + const Nd4jLong *xShapeInfo, const Nd4jLong *hXShapeInfo, + void *extraParams, void *vz, const Nd4jLong *zShapeInfo, + const Nd4jLong *hZShapeInfo, int *dimension, int dimensionLength, + void *reductionBuffer, const Nd4jLong *tadOnlyShapeInfo); + + template + static __host__ void intermediateXD( + dim3 launchDims, cudaStream_t *stream, const void *vx, + const Nd4jLong *xShapeInfo, const Nd4jLong *hXShapeInfo, + void *extraParams, void *vz, const Nd4jLong *zShape, + const Nd4jLong *hZShapeInfo, int *dimension, int dimensionLength, + void *reductionPointer, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets); + + static __host__ void execReduceScalar( + dim3 launchDims, cudaStream_t *stream, int opNum, const void *vx, + const Nd4jLong *xShapeInfo, const Nd4jLong *hXShapeInfo, + void *extraParams, void *vz, const Nd4jLong *zShapeInfo, + const Nd4jLong *hZShapeInfo, int *dimension, int dimensionLength, + void *reductionBuffer, const Nd4jLong *tadOnlyShapeInfo); + + static __host__ void execReduceXD( + dim3 launchDims, cudaStream_t *stream, int opNum, int rank, + const void *vx, const Nd4jLong *xShapeInfo, const Nd4jLong *hXShapeInfo, + void *extraParams, void *vz, const Nd4jLong *zShape, + const Nd4jLong *hZShapeInfo, int *dimension, int dimensionLength, + void *reductionPointer, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets); #else - /** - * Reduce down to 1 number - * @param vx the input - * @param xShapeInfo the shape information - * for the input - * @param extraParams the extra params - * @return - */ - template - static _CUDA_H Z execScalar(const void *vx, const Nd4jLong *xShapeInfo, void *extraParams); - - template - static _CUDA_H void execScalar(const void *vx, const Nd4jLong *xShapeInfo, - void *extraParams, - void *vz, const Nd4jLong *zShapeInfo); - - - static Z execScalar(int opNum, - const void *vx, const Nd4jLong *xShapeInfo, - void *extraParams); - - static void execScalar(int opNum, - const void *vx, const Nd4jLong *xShapeInfo, - void *extraParams, - void *vz, const Nd4jLong *zShapeInfo); - - static void exec(int opNum, - const void *vx, const Nd4jLong *xShapeInfo, - void *extraParams, - void *vz, const Nd4jLong *resultShapeInfoBuffer, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset, - int64_t start, int64_t stop); - - /** - * Execute on the cpu - * @param vx the input data - * @param xShapeInfo the shape information for vx - * @param extraParams the extra parameters - * @param vz the vz buffer - * @param resultShapeInfoBuffer the shape information - * @param dimension the dimension to perform - * the reduce along long - * @param dimensionLength the length of the dimension buffer - */ - - - template - static void _CUDA_H exec(const void *vx, const Nd4jLong *xShapeInfo, - void *extraParams, - void *vz, const Nd4jLong *resultShapeInfoBuffer, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset, - int64_t start, int64_t stop); - - /** - * CPU implementation - * @param vx the input data - * @param xShapeInfo the shape information for - * the input data - * @param extraParams the extra parameters for the problem - * @param vz the vz buffer - * @param zShapeInfo the shape information - */ - template - static void _CUDA_H exec(const void *vx, const Nd4jLong *xShapeInfo, - void *extraParams, - void *vz, const Nd4jLong *zShapeInfo); - - - - /** - * Reduce down to 1 number - * @param vx the input - * @param xShapeInfo the shape information - * for the input - * @param extraParams the extra params - * @return - */ - template - static Z _CUDA_H execScalar(const void *vx, Nd4jLong xElementWiseStride, Nd4jLong length, void *extraParams); + /** + * Reduce down to 1 number + * @param vx the input + * @param xShapeInfo the shape information + * for the input + * @param extraParams the extra params + * @return + */ + template + static _CUDA_H Z execScalar(const void *vx, const Nd4jLong *xShapeInfo, + void *extraParams); + + template + static _CUDA_H void execScalar(const void *vx, const Nd4jLong *xShapeInfo, + void *extraParams, void *vz, + const Nd4jLong *zShapeInfo); + + static Z execScalar(int opNum, const void *vx, const Nd4jLong *xShapeInfo, + void *extraParams); + + static void execScalar(int opNum, const void *vx, const Nd4jLong *xShapeInfo, + void *extraParams, void *vz, + const Nd4jLong *zShapeInfo); + + static void exec(int opNum, const void *vx, const Nd4jLong *xShapeInfo, + void *extraParams, void *vz, + const Nd4jLong *resultShapeInfoBuffer, int *dimension, + int dimensionLength, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffset, int64_t start, int64_t stop); + + /** + * Execute on the cpu + * @param vx the input data + * @param xShapeInfo the shape information for vx + * @param extraParams the extra parameters + * @param vz the vz buffer + * @param resultShapeInfoBuffer the shape information + * @param dimension the dimension to perform + * the reduce along long + * @param dimensionLength the length of the dimension buffer + */ + + template + static void _CUDA_H exec(const void *vx, const Nd4jLong *xShapeInfo, + void *extraParams, void *vz, + const Nd4jLong *resultShapeInfoBuffer, + int *dimension, int dimensionLength, + const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffset, int64_t start, + int64_t stop); + + /** + * CPU implementation + * @param vx the input data + * @param xShapeInfo the shape information for + * the input data + * @param extraParams the extra parameters for the problem + * @param vz the vz buffer + * @param zShapeInfo the shape information + */ + template + static void _CUDA_H exec(const void *vx, const Nd4jLong *xShapeInfo, + void *extraParams, void *vz, + const Nd4jLong *zShapeInfo); + + /** + * Reduce down to 1 number + * @param vx the input + * @param xShapeInfo the shape information + * for the input + * @param extraParams the extra params + * @return + */ + template + static Z _CUDA_H execScalar(const void *vx, Nd4jLong xElementWiseStride, + Nd4jLong length, void *extraParams); #endif - }; - +}; #ifdef __CUDACC__ - /** - * - * @param extraParams - * @param sPartials - * @param sMemSize - */ - template - __device__ void initializeShared(T *extraParams, T **sPartials, int sMemSize); +/** + * + * @param extraParams + * @param sPartials + * @param sMemSize + */ +template +__device__ void initializeShared(T *extraParams, T **sPartials, int sMemSize); #endif - } +} // namespace reduce -} +} // namespace functions #endif - diff --git a/libnd4j/include/loops/reduce_long.h b/libnd4j/include/loops/reduce_long.h index 45ede29850dc..193a6df0ebaa 100644 --- a/libnd4j/include/loops/reduce_long.h +++ b/libnd4j/include/loops/reduce_long.h @@ -14,21 +14,20 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - #ifndef REDUCE_LONG_H #define REDUCE_LONG_H #include //#include -#include #include +#include #ifdef _OPENMP #include #endif #include -#include -#include #include +#include #include +#include #pragma once #ifdef __CUDACC__ @@ -38,9 +37,9 @@ #include "legacy_ops.h" -//an op for the kernel +// an op for the kernel namespace functions { - namespace reduce { +namespace reduce { /** * A reduce function @@ -49,132 +48,157 @@ namespace functions { * via aggregating member * elements. */ - template - class ReduceLongFunction { - public: +template +class ReduceLongFunction { + public: #ifdef __CUDACC__ - template - static __device__ void aggregatePartials(void *sPartials, Nd4jLong tid, Nd4jLong numItems, void *extraParams); - - template - static __device__ void execScalarCuda(const void *vx, const Nd4jLong *xShapeInfo, void *extraParams, void *vz, const Nd4jLong *zShapeInfo, void *reductionBuffer, const Nd4jLong *tadOnlyShapeInfo); - - template - static __device__ void transformCudaXD(const void *vx, const Nd4jLong *xShapeInfo, void *extraParams, void *vz, const Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, void *reductionBuffer, const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets); - - template - static __host__ void intermediateScalar(dim3 launchDims, cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, const Nd4jLong *hXShapeInfo, void *extraParams, void *vz, const Nd4jLong *zShapeInfo, const Nd4jLong *hZShapeInfo, int *dimension, int dimensionLength, void *reductionBuffer, const Nd4jLong *tadOnlyShapeInfo); - - template - static __host__ void intermediateXD(dim3 launchDims, cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, const Nd4jLong *hXShapeInfo, void *extraParams, void *vz, const Nd4jLong *zShapeInfo, const Nd4jLong *hZShapeInfo, int *dimension, int dimensionLength, void *reductionPointer, const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets); - - static __host__ void execReduceScalar(dim3 launchDims, cudaStream_t *stream, int opNum, const void *vx, const Nd4jLong *xShapeInfo, const Nd4jLong* hXShapeInfo, void *extraParams, void *vz, const Nd4jLong *zShapeInfo, const Nd4jLong* hZShapeInfo, int *dimension, int dimensionLength, void *reductionBuffer, const Nd4jLong *tadOnlyShapeInfo); - - static __host__ void execReduceXD(dim3 launchDims, cudaStream_t *stream, int opNum, int rank, const void *vx, const Nd4jLong *xShapeInfo, const Nd4jLong* hXShapeInfo, void *extraParams, void *vz, const Nd4jLong *zShapeInfo, const Nd4jLong* hZShapeInfo, int *dimension, int dimensionLength, void *reductionPointer, const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets); + template + static __device__ void aggregatePartials(void *sPartials, Nd4jLong tid, + Nd4jLong numItems, + void *extraParams); + + template + static __device__ void execScalarCuda(const void *vx, + const Nd4jLong *xShapeInfo, + void *extraParams, void *vz, + const Nd4jLong *zShapeInfo, + void *reductionBuffer, + const Nd4jLong *tadOnlyShapeInfo); + + template + static __device__ void transformCudaXD( + const void *vx, const Nd4jLong *xShapeInfo, void *extraParams, void *vz, + const Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, + void *reductionBuffer, const Nd4jLong *tadOnlyShapeInfo, + const Nd4jLong *tadOffsets); + + template + static __host__ void intermediateScalar( + dim3 launchDims, cudaStream_t *stream, const void *vx, + const Nd4jLong *xShapeInfo, const Nd4jLong *hXShapeInfo, + void *extraParams, void *vz, const Nd4jLong *zShapeInfo, + const Nd4jLong *hZShapeInfo, int *dimension, int dimensionLength, + void *reductionBuffer, const Nd4jLong *tadOnlyShapeInfo); + + template + static __host__ void intermediateXD( + dim3 launchDims, cudaStream_t *stream, const void *vx, + const Nd4jLong *xShapeInfo, const Nd4jLong *hXShapeInfo, + void *extraParams, void *vz, const Nd4jLong *zShapeInfo, + const Nd4jLong *hZShapeInfo, int *dimension, int dimensionLength, + void *reductionPointer, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets); + + static __host__ void execReduceScalar( + dim3 launchDims, cudaStream_t *stream, int opNum, const void *vx, + const Nd4jLong *xShapeInfo, const Nd4jLong *hXShapeInfo, + void *extraParams, void *vz, const Nd4jLong *zShapeInfo, + const Nd4jLong *hZShapeInfo, int *dimension, int dimensionLength, + void *reductionBuffer, const Nd4jLong *tadOnlyShapeInfo); + + static __host__ void execReduceXD( + dim3 launchDims, cudaStream_t *stream, int opNum, int rank, + const void *vx, const Nd4jLong *xShapeInfo, const Nd4jLong *hXShapeInfo, + void *extraParams, void *vz, const Nd4jLong *zShapeInfo, + const Nd4jLong *hZShapeInfo, int *dimension, int dimensionLength, + void *reductionPointer, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets); #else - /** - * Reduce down to 1 number - * @param x the input - * @param xShapeInfo the shape information - * for the input - * @param extraParams the extra params - * @return - */ - template - static _CUDA_H Z execScalar(const void *x, const Nd4jLong *xShapeInfo, void *extraParams); - - template - static _CUDA_H void execScalar(const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *z, const Nd4jLong *zShapeInfo); - - - static Z execScalar(int opNum, - const void *x, const Nd4jLong *xShapeInfo, - void *extraParams); - - static void execScalar(int opNum, - const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *z, const Nd4jLong *zShapeInfo); - - static void exec(int opNum, - const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *result, const Nd4jLong *resultShapeInfoBuffer, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset, - int64_t start, int64_t stop); - - /** - * Execute on the cpu - * @param x the input data - * @param xShapeInfo the shape information for x - * @param extraParams the extra parameters - * @param result the result buffer - * @param resultShapeInfoBuffer the shape information - * @param dimension the dimension to perform - * the reduce along long - * @param dimensionLength the length of the dimension buffer - */ - - - template - static void _CUDA_H exec(const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *result, const Nd4jLong *resultShapeInfoBuffer, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset, - int64_t start, int64_t stop); - - /** - * CPU implementation - * @param x the input data - * @param xShapeInfo the shape information for - * the input data - * @param extraParams the extra parameters for the problem - * @param result the result buffer - * @param resultShapeInfo the shape information - */ - template - static void _CUDA_H exec(const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *result, const Nd4jLong *resultShapeInfo); - - - - /** - * Reduce down to 1 number - * @param x the input - * @param xShapeInfo the shape information - * for the input - * @param extraParams the extra params - * @return - */ - template - static Z _CUDA_H execScalar(const void *x, Nd4jLong xElementWiseStride, - Nd4jLong length, - void *extraParams); + /** + * Reduce down to 1 number + * @param x the input + * @param xShapeInfo the shape information + * for the input + * @param extraParams the extra params + * @return + */ + template + static _CUDA_H Z execScalar(const void *x, const Nd4jLong *xShapeInfo, + void *extraParams); + + template + static _CUDA_H void execScalar(const void *x, const Nd4jLong *xShapeInfo, + void *extraParams, void *z, + const Nd4jLong *zShapeInfo); + + static Z execScalar(int opNum, const void *x, const Nd4jLong *xShapeInfo, + void *extraParams); + + static void execScalar(int opNum, const void *x, const Nd4jLong *xShapeInfo, + void *extraParams, void *z, + const Nd4jLong *zShapeInfo); + + static void exec(int opNum, const void *x, const Nd4jLong *xShapeInfo, + void *extraParams, void *result, + const Nd4jLong *resultShapeInfoBuffer, int *dimension, + int dimensionLength, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffset, int64_t start, int64_t stop); + + /** + * Execute on the cpu + * @param x the input data + * @param xShapeInfo the shape information for x + * @param extraParams the extra parameters + * @param result the result buffer + * @param resultShapeInfoBuffer the shape information + * @param dimension the dimension to perform + * the reduce along long + * @param dimensionLength the length of the dimension buffer + */ + + template + static void _CUDA_H exec(const void *x, const Nd4jLong *xShapeInfo, + void *extraParams, void *result, + const Nd4jLong *resultShapeInfoBuffer, + int *dimension, int dimensionLength, + const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffset, int64_t start, + int64_t stop); + + /** + * CPU implementation + * @param x the input data + * @param xShapeInfo the shape information for + * the input data + * @param extraParams the extra parameters for the problem + * @param result the result buffer + * @param resultShapeInfo the shape information + */ + template + static void _CUDA_H exec(const void *x, const Nd4jLong *xShapeInfo, + void *extraParams, void *result, + const Nd4jLong *resultShapeInfo); + + /** + * Reduce down to 1 number + * @param x the input + * @param xShapeInfo the shape information + * for the input + * @param extraParams the extra params + * @return + */ + template + static Z _CUDA_H execScalar(const void *x, Nd4jLong xElementWiseStride, + Nd4jLong length, void *extraParams); #endif - }; +}; #ifdef __CUDACC__ - /** - * - * @param extraParams - * @param sPartials - * @param sMemSize - */ - template - __device__ void initializeShared(T *extraParams, T **sPartials, int sMemSize); +/** + * + * @param extraParams + * @param sPartials + * @param sMemSize + */ +template +__device__ void initializeShared(T *extraParams, T **sPartials, int sMemSize); #endif - } +} // namespace reduce -} +} // namespace functions #endif - diff --git a/libnd4j/include/loops/reduce_same.h b/libnd4j/include/loops/reduce_same.h index 5f3622f39b60..a1dd79280661 100644 --- a/libnd4j/include/loops/reduce_same.h +++ b/libnd4j/include/loops/reduce_same.h @@ -14,21 +14,20 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - #ifndef REDUCE_SAME_H #define REDUCE_SAME_H #include //#include -#include #include +#include #ifdef _OPENMP #include #endif #include -#include -#include #include +#include #include +#include #pragma once #ifdef __CUDACC__ @@ -38,9 +37,9 @@ #include "legacy_ops.h" -//an op for the kernel +// an op for the kernel namespace functions { - namespace reduce { +namespace reduce { /** * A reduce function @@ -49,136 +48,165 @@ namespace functions { * via aggregating member * elements. */ - template - class ReduceSameFunction { - public: +template +class ReduceSameFunction { + public: #ifdef __CUDACC__ - template - static __device__ void aggregatePartials(void *sPartials, Nd4jLong tid, Nd4jLong numItems, void *extraParams); - - template - static __device__ void execScalarCuda( void const* vx, Nd4jLong const *xShapeInfo, void *extraParams, void *vz, Nd4jLong const* zShapeInfo, void *reductionBuffer, Nd4jLong const* tadOnlyShapeInfo); - - static __device__ void execScalarCudaLegacy(int opNum, void const* vx, Nd4jLong const* xShapeInfo, void *extraParams, void *vz, Nd4jLong const* zShapeInfo, void *reductionBuffer, Nd4jLong const* tadOnlyShapeInfo); - - template - static __device__ void transformCudaXD( void const* vx, Nd4jLong const* xShapeInfo, void *extraParams, void *vz, Nd4jLong const* zShapeInfo, int *dimension, int dimensionLength, void *reductionBuffer, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets); - - template - static __host__ void intermediateScalar(dim3 launchDims, cudaStream_t *stream, void const* vx, Nd4jLong const* xShapeInfo, Nd4jLong const* hXShapeInfo, void *extraParams, void *vz, Nd4jLong const* zShapeInfo, Nd4jLong const* hZShapeInfo, int *dimension, int dimensionLength, void *reductionBuffer, Nd4jLong const* tadOnlyShapeInfo); - - template - static __host__ void intermediateXD(dim3 launchDims, cudaStream_t *stream, void const* vx, Nd4jLong const* xShapeInfo, Nd4jLong const* hXShapeInfo, void *extraParams, void *vz, Nd4jLong const* zShapeInfo, Nd4jLong const* hZShapeInfo, int *dimension, int dimensionLength, void *reductionPointer, Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets); - - static __host__ void execReduceScalar(dim3 launchDims, cudaStream_t *stream, int opNum, void const* vx, Nd4jLong const* xShapeInfo, Nd4jLong const* hXShapeInfo, void *extraParams, void *vz, Nd4jLong const* zShapeInfo, Nd4jLong const* hZShapeInfo, int *dimension, int dimensionLength, void *reductionBuffer, Nd4jLong const* tadOnlyShapeInfo); - - static __host__ void execReduceXD(dim3 launchDims, cudaStream_t *stream, int opNum, int rank, void const* vx, Nd4jLong const* xShapeInfo, Nd4jLong const* hXShapeInfo, void *extraParams, void *vz, Nd4jLong const* zShapeInfo, Nd4jLong const* hZShapeInfo, int *dimension, int dimensionLength, void *reductionPointer, Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets); + template + static __device__ void aggregatePartials(void* sPartials, Nd4jLong tid, + Nd4jLong numItems, + void* extraParams); + + template + static __device__ void execScalarCuda(void const* vx, + Nd4jLong const* xShapeInfo, + void* extraParams, void* vz, + Nd4jLong const* zShapeInfo, + void* reductionBuffer, + Nd4jLong const* tadOnlyShapeInfo); + + static __device__ void execScalarCudaLegacy(int opNum, void const* vx, + Nd4jLong const* xShapeInfo, + void* extraParams, void* vz, + Nd4jLong const* zShapeInfo, + void* reductionBuffer, + Nd4jLong const* tadOnlyShapeInfo); + + template + static __device__ void transformCudaXD( + void const* vx, Nd4jLong const* xShapeInfo, void* extraParams, void* vz, + Nd4jLong const* zShapeInfo, int* dimension, int dimensionLength, + void* reductionBuffer, Nd4jLong const* tadOnlyShapeInfo, + Nd4jLong const* tadOffsets); + + template + static __host__ void intermediateScalar( + dim3 launchDims, cudaStream_t* stream, void const* vx, + Nd4jLong const* xShapeInfo, Nd4jLong const* hXShapeInfo, + void* extraParams, void* vz, Nd4jLong const* zShapeInfo, + Nd4jLong const* hZShapeInfo, int* dimension, int dimensionLength, + void* reductionBuffer, Nd4jLong const* tadOnlyShapeInfo); + + template + static __host__ void intermediateXD( + dim3 launchDims, cudaStream_t* stream, void const* vx, + Nd4jLong const* xShapeInfo, Nd4jLong const* hXShapeInfo, + void* extraParams, void* vz, Nd4jLong const* zShapeInfo, + Nd4jLong const* hZShapeInfo, int* dimension, int dimensionLength, + void* reductionPointer, Nd4jLong const* tadShapeInfo, + Nd4jLong const* tadOffsets); + + static __host__ void execReduceScalar( + dim3 launchDims, cudaStream_t* stream, int opNum, void const* vx, + Nd4jLong const* xShapeInfo, Nd4jLong const* hXShapeInfo, + void* extraParams, void* vz, Nd4jLong const* zShapeInfo, + Nd4jLong const* hZShapeInfo, int* dimension, int dimensionLength, + void* reductionBuffer, Nd4jLong const* tadOnlyShapeInfo); + + static __host__ void execReduceXD( + dim3 launchDims, cudaStream_t* stream, int opNum, int rank, + void const* vx, Nd4jLong const* xShapeInfo, Nd4jLong const* hXShapeInfo, + void* extraParams, void* vz, Nd4jLong const* zShapeInfo, + Nd4jLong const* hZShapeInfo, int* dimension, int dimensionLength, + void* reductionPointer, Nd4jLong const* tadShapeInfo, + Nd4jLong const* tadOffsets); #else - /** - * Reduce down to 1 number - * @param x the input - * @param xShapeInfo the shape information - * for the input - * @param extraParams the extra params - * @return - */ - template - static _CUDA_H X execScalar(const void *x, const Nd4jLong *xShapeInfo, - void *extraParams); - - template - static _CUDA_H void execScalar(const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *z, const Nd4jLong *zShapeInfo); - - - static X execScalar(int opNum, - const void *x, const Nd4jLong *xShapeInfo, - void *extraParams); - - static void execScalar(int opNum, - const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *z, const Nd4jLong *zShapeInfo); - - static void exec(int opNum, - const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *result, const Nd4jLong *resultShapeInfoBuffer, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset, - int64_t start, int64_t stop); - - /** - * Execute on the cpu - * @param x the input data - * @param xShapeInfo the shape information for x - * @param extraParams the extra parameters - * @param result the result buffer - * @param resultShapeInfoBuffer the shape information - * @param dimension the dimension to perform - * the reduce along long - * @param dimensionLength the length of the dimension buffer - */ - - - template - static void _CUDA_H exec(const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *result, const Nd4jLong *resultShapeInfoBuffer, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset, - int64_t start, int64_t stop); - - /** - * CPU implementation - * @param x the input data - * @param xShapeInfo the shape information for - * the input data - * @param extraParams the extra parameters for the problem - * @param result the result buffer - * @param resultShapeInfo the shape information - */ - template - static void _CUDA_H exec(const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *result, const Nd4jLong *resultShapeInfo); - - - - /** - * Reduce down to 1 number - * @param x the input - * @param xShapeInfo the shape information - * for the input - * @param extraParams the extra params - * @return - */ - template - static X _CUDA_H execScalar(const void *x, Nd4jLong xElementWiseStride, - Nd4jLong length, - void *extraParams); + /** + * Reduce down to 1 number + * @param x the input + * @param xShapeInfo the shape information + * for the input + * @param extraParams the extra params + * @return + */ + template + static _CUDA_H X execScalar(const void *x, const Nd4jLong *xShapeInfo, + void *extraParams); + + template + static _CUDA_H void execScalar(const void *x, const Nd4jLong *xShapeInfo, + void *extraParams, void *z, + const Nd4jLong *zShapeInfo); + + static X execScalar(int opNum, const void *x, const Nd4jLong *xShapeInfo, + void *extraParams); + + static void execScalar(int opNum, const void *x, const Nd4jLong *xShapeInfo, + void *extraParams, void *z, + const Nd4jLong *zShapeInfo); + + static void exec(int opNum, const void *x, const Nd4jLong *xShapeInfo, + void *extraParams, void *result, + const Nd4jLong *resultShapeInfoBuffer, int *dimension, + int dimensionLength, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffset, int64_t start, int64_t stop); + + /** + * Execute on the cpu + * @param x the input data + * @param xShapeInfo the shape information for x + * @param extraParams the extra parameters + * @param result the result buffer + * @param resultShapeInfoBuffer the shape information + * @param dimension the dimension to perform + * the reduce along long + * @param dimensionLength the length of the dimension buffer + */ + + template + static void _CUDA_H exec(const void *x, const Nd4jLong *xShapeInfo, + void *extraParams, void *result, + const Nd4jLong *resultShapeInfoBuffer, + int *dimension, int dimensionLength, + const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffset, int64_t start, + int64_t stop); + + /** + * CPU implementation + * @param x the input data + * @param xShapeInfo the shape information for + * the input data + * @param extraParams the extra parameters for the problem + * @param result the result buffer + * @param resultShapeInfo the shape information + */ + template + static void _CUDA_H exec(const void *x, const Nd4jLong *xShapeInfo, + void *extraParams, void *result, + const Nd4jLong *resultShapeInfo); + + /** + * Reduce down to 1 number + * @param x the input + * @param xShapeInfo the shape information + * for the input + * @param extraParams the extra params + * @return + */ + template + static X _CUDA_H execScalar(const void *x, Nd4jLong xElementWiseStride, + Nd4jLong length, void *extraParams); #endif - }; +}; #ifdef __CUDACC__ - /** - * - * @param extraParams - * @param sPartials - * @param sMemSize - */ - template - __device__ void initializeShared(T *extraParams, T **sPartials, int sMemSize); +/** + * + * @param extraParams + * @param sPartials + * @param sMemSize + */ +template +__device__ void initializeShared(T* extraParams, T** sPartials, int sMemSize); #endif - } +} // namespace reduce -} +} // namespace functions #endif - diff --git a/libnd4j/include/loops/scalar.h b/libnd4j/include/loops/scalar.h old mode 100755 new mode 100644 index f7333d57de6f..c50b8c4764f2 --- a/libnd4j/include/loops/scalar.h +++ b/libnd4j/include/loops/scalar.h @@ -23,9 +23,9 @@ #ifndef SCALAR_H_ #define SCALAR_H_ +#include #include #include -#include #ifdef __JNI__ #include @@ -33,6 +33,7 @@ #include #include #include + #include "helpers/logger.h" #ifdef __CUDACC__ @@ -44,142 +45,117 @@ #include "legacy_ops.h" namespace functions { - namespace scalar { +namespace scalar { /** * Apply a scalar * operation to an array */ - template - class ScalarTransform { - - public: - +template +class ScalarTransform { + public: #ifdef __CUDACC__ - template - __host__ - static void intermediateShaped(dim3& launchDims, cudaStream_t *stream, - const void *vx, const Nd4jLong *xShapeInfo, const Nd4jLong *hxShapeInfo, - void *vz, const Nd4jLong *zShapeInfo, const Nd4jLong *hzShapeInfo, - const void* vscalar, - void *vextraParams, - int *allocPointer); - - template - __host__ - static void intermediateAlongDimension(dim3& launchDims, cudaStream_t *stream, - const void *x, const Nd4jLong *xShapeInfo, - void *z, const Nd4jLong *zShapeInfo, - const void *scalars, - void *extraParams, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetsZ); - - __host__ - static void executeCudaShaped(dim3& launchDims, cudaStream_t *stream, - int opNum, - const void *x, const Nd4jLong *xShapeInfo, const Nd4jLong *hxShapeInfo, - void *result, const Nd4jLong *resultShapeInfo, const Nd4jLong *hzShapeInfo, - const void* scalar, - void *extraParams); - - __host__ - static void executeCudaAlongDimension(dim3& launchDims, cudaStream_t *stream, - int opNum, - const void *x, const Nd4jLong *xShapeInfo, - void *z, const Nd4jLong *zShapeInfo, - const void *scalars, - void *extraParams, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetsZ); + template + __host__ static void intermediateShaped( + dim3 &launchDims, cudaStream_t *stream, const void *vx, + const Nd4jLong *xShapeInfo, const Nd4jLong *hxShapeInfo, void *vz, + const Nd4jLong *zShapeInfo, const Nd4jLong *hzShapeInfo, + const void *vscalar, void *vextraParams, int *allocPointer); + + template + __host__ static void intermediateAlongDimension( + dim3 &launchDims, cudaStream_t *stream, const void *x, + const Nd4jLong *xShapeInfo, void *z, const Nd4jLong *zShapeInfo, + const void *scalars, void *extraParams, int *dimension, + int dimensionLength, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets, const Nd4jLong *tadShapeInfoZ, + const Nd4jLong *tadOffsetsZ); + + __host__ static void executeCudaShaped( + dim3 &launchDims, cudaStream_t *stream, int opNum, const void *x, + const Nd4jLong *xShapeInfo, const Nd4jLong *hxShapeInfo, void *result, + const Nd4jLong *resultShapeInfo, const Nd4jLong *hzShapeInfo, + const void *scalar, void *extraParams); + + __host__ static void executeCudaAlongDimension( + dim3 &launchDims, cudaStream_t *stream, int opNum, const void *x, + const Nd4jLong *xShapeInfo, void *z, const Nd4jLong *zShapeInfo, + const void *scalars, void *extraParams, int *dimension, + int dimensionLength, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets, const Nd4jLong *tadShapeInfoZ, + const Nd4jLong *tadOffsetsZ); #else - template - static void transform(const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *z, const Nd4jLong *zShapeInfo, - const void *scalars, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetsZ, - uint64_t start, uint64_t stop); - - static void transform(int opNum, - const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *z, const Nd4jLong *zShapeInfo, - const void *scalars, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetsZ, - uint64_t start, uint64_t stop); - - static void transform(int opNum, - const void *x, const Nd4jLong *xShapeInfo, - void *result, const Nd4jLong *resultShapeInfo, - const void *scalar, - void *extraParams, - uint64_t start, uint64_t stop); - - static void transform(int opNum, - const void *x, Nd4jLong xStride, - void *result, Nd4jLong resultStride, - const void *scalar, - void *extraParams, - uint64_t len, uint64_t start, uint64_t stop); - - - - - /* - * ScalarOp along dimension - */ - - - /** - * CPU implementation of scalar operation - * @param x the input - * @param xStride the stride for the input - * @param result the result buffer - * @param resultStride the stride for the result - * @param scalar the scalar to apply - * @param extraParams the extra parameters where - * neccssary - * @param len the number of elements to loop over - */ - - template - static void transform(const void *x, const Nd4jLong *xShapeInfo, - void *result, const Nd4jLong *resultShapeInfo, - const void *scalar, - void *extraParams, - uint64_t start, uint64_t stop); - - - /** - * CPU implementation of scalar operation - * @param x the input - * @param xStride the stride for the input - * @param result the result buffer - * @param resultStride the stride for the result - * @param scalar the scalar to apply - * @param extraParams the extra parameters where - * neccssary - * @param len the number of elements to loop over - */ - - template - static void transform(const void *x, Nd4jLong xStride, - void *result, Nd4jLong resultStride, - const void *scalar, - void *extraParams, - uint64_t len, uint64_t start, uint64_t stop); + template + static void transform(const void *x, const Nd4jLong *xShapeInfo, + void *extraParams, void *z, const Nd4jLong *zShapeInfo, + const void *scalars, int *dimension, + int dimensionLength, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets, + const Nd4jLong *tadShapeInfoZ, + const Nd4jLong *tadOffsetsZ, uint64_t start, + uint64_t stop); + + static void transform(int opNum, const void *x, const Nd4jLong *xShapeInfo, + void *extraParams, void *z, const Nd4jLong *zShapeInfo, + const void *scalars, int *dimension, + int dimensionLength, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets, + const Nd4jLong *tadShapeInfoZ, + const Nd4jLong *tadOffsetsZ, uint64_t start, + uint64_t stop); + + static void transform(int opNum, const void *x, const Nd4jLong *xShapeInfo, + void *result, const Nd4jLong *resultShapeInfo, + const void *scalar, void *extraParams, uint64_t start, + uint64_t stop); + + static void transform(int opNum, const void *x, Nd4jLong xStride, + void *result, Nd4jLong resultStride, const void *scalar, + void *extraParams, uint64_t len, uint64_t start, + uint64_t stop); + + /* + * ScalarOp along dimension + */ + + /** + * CPU implementation of scalar operation + * @param x the input + * @param xStride the stride for the input + * @param result the result buffer + * @param resultStride the stride for the result + * @param scalar the scalar to apply + * @param extraParams the extra parameters where + * neccssary + * @param len the number of elements to loop over + */ + + template + static void transform(const void *x, const Nd4jLong *xShapeInfo, void *result, + const Nd4jLong *resultShapeInfo, const void *scalar, + void *extraParams, uint64_t start, uint64_t stop); + + /** + * CPU implementation of scalar operation + * @param x the input + * @param xStride the stride for the input + * @param result the result buffer + * @param resultStride the stride for the result + * @param scalar the scalar to apply + * @param extraParams the extra parameters where + * neccssary + * @param len the number of elements to loop over + */ + + template + static void transform(const void *x, Nd4jLong xStride, void *result, + Nd4jLong resultStride, const void *scalar, + void *extraParams, uint64_t len, uint64_t start, + uint64_t stop); #endif - }; - } -} - +}; +} // namespace scalar +} // namespace functions #endif /* SCALAR_H_ */ diff --git a/libnd4j/include/loops/scalar_bool.h b/libnd4j/include/loops/scalar_bool.h index 4992df5a16b7..e1e20973081c 100644 --- a/libnd4j/include/loops/scalar_bool.h +++ b/libnd4j/include/loops/scalar_bool.h @@ -28,12 +28,13 @@ #ifdef __JNI__ #include #endif +#include +#include #include #include #include + #include "helpers/logger.h" -#include -#include #ifdef __CUDACC__ #include @@ -44,170 +45,139 @@ #include "legacy_ops.h" namespace functions { - namespace scalar { +namespace scalar { /** * Apply a scalar * operation to an array */ - template - class ScalarBoolTransform { - - public: - +template +class ScalarBoolTransform { + public: #ifdef __CUDACC__ - - template - __device__ - static void transformCuda(const void* scalar, - const void *vy, const Nd4jLong *shapeInfo, - void *vparams, - void *vresult, const Nd4jLong *resultShapeInfo, - int *allocationBuffer); - - template - __device__ - static void transformCuda(Nd4jLong n, - const void* vx, const void *vy, Nd4jLong yEWS, - void *vparams, - void *vz, Nd4jLong zEWS, - int *allocationBuffer); - - template - __device__ - static void transformCuda(const void *vx, const Nd4jLong *xShapeInfo, - void *vextraParams, - void *vz, const Nd4jLong *zShapeInfo, - const void *vscalars, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetsZ); - - template - __host__ - static void intermediateAlongDimension(dim3& launchDims, cudaStream_t *stream, - const void *x, const Nd4jLong *xShapeInfo, - void *z, const Nd4jLong *zShapeInfo, - const void *scalars, - void *extraParams, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetsZ); - - template - __host__ - static void intermediateShaped(dim3& launchDims, cudaStream_t *stream, - const void *vx, const Nd4jLong *xShapeInfo, - void *vz, const Nd4jLong *zShapeInfo, - const void* vscalar, - void *vextraParams, - int *allocPointer); - - __host__ - static void executeCudaShaped(dim3& launchDims, cudaStream_t *stream, - int opNum, - const void *x, const Nd4jLong *xShapeInfo, - void *result, const Nd4jLong *resultShapeInfo, - const void* scalar, - const void *extraParams); - - __host__ - static void executeCudaAlongDimension(dim3& launchDims, cudaStream_t *stream, - int opNum, - const void *x, const Nd4jLong *xShapeInfo, - void *z, const Nd4jLong *zShapeInfo, - const void *scalars, - void *extraParams, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetsZ); + + template + __device__ static void transformCuda(const void *scalar, const void *vy, + const Nd4jLong *shapeInfo, void *vparams, + void *vresult, + const Nd4jLong *resultShapeInfo, + int *allocationBuffer); + + template + __device__ static void transformCuda(Nd4jLong n, const void *vx, + const void *vy, Nd4jLong yEWS, + void *vparams, void *vz, Nd4jLong zEWS, + int *allocationBuffer); + + template + __device__ static void transformCuda( + const void *vx, const Nd4jLong *xShapeInfo, void *vextraParams, void *vz, + const Nd4jLong *zShapeInfo, const void *vscalars, int *dimension, + int dimensionLength, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets, const Nd4jLong *tadShapeInfoZ, + const Nd4jLong *tadOffsetsZ); + + template + __host__ static void intermediateAlongDimension( + dim3 &launchDims, cudaStream_t *stream, const void *x, + const Nd4jLong *xShapeInfo, void *z, const Nd4jLong *zShapeInfo, + const void *scalars, void *extraParams, int *dimension, + int dimensionLength, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets, const Nd4jLong *tadShapeInfoZ, + const Nd4jLong *tadOffsetsZ); + + template + __host__ static void intermediateShaped( + dim3 &launchDims, cudaStream_t *stream, const void *vx, + const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo, + const void *vscalar, void *vextraParams, int *allocPointer); + + __host__ static void executeCudaShaped( + dim3 &launchDims, cudaStream_t *stream, int opNum, const void *x, + const Nd4jLong *xShapeInfo, void *result, const Nd4jLong *resultShapeInfo, + const void *scalar, const void *extraParams); + + __host__ static void executeCudaAlongDimension( + dim3 &launchDims, cudaStream_t *stream, int opNum, const void *x, + const Nd4jLong *xShapeInfo, void *z, const Nd4jLong *zShapeInfo, + const void *scalars, void *extraParams, int *dimension, + int dimensionLength, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets, const Nd4jLong *tadShapeInfoZ, + const Nd4jLong *tadOffsetsZ); /* #include "cuda/scalar_temp.cu" */ #else - template - static void transform(const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *z, const Nd4jLong *zShapeInfo, - const void *scalars, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetsZ, - uint64_t start, uint64_t stop); - - static void transform(int opNum, - const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *z, const Nd4jLong *zShapeInfo, - const void *scalars, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetsZ, - uint64_t start, uint64_t stop); - - static void transform(int opNum, - const void *x, const Nd4jLong *xShapeInfo, - void *result, const Nd4jLong *resultShapeInfo, - const void *scalar, - void *extraParams, - uint64_t start, uint64_t stop); - - static void transform(int opNum, - const void *x, Nd4jLong xStride, - void *result, Nd4jLong resultStride, - const void *scalar, - void *extraParams, - uint64_t n, uint64_t start, uint64_t stop); - - - - - /* - * ScalarOp along dimension - */ - - - /** - * CPU implementation of scalar operation - * @param x the input - * @param xStride the stride for the input - * @param result the result buffer - * @param resultStride the stride for the result - * @param scalar the scalar to apply - * @param extraParams the extra parameters where - * neccssary - * @param n the number of elements to loop over - */ - - template - static void transform(const void *x, const Nd4jLong *xShapeInfo, - void *result, const Nd4jLong *resultShapeInfo, - const void *scalar, - void *extraParams, - uint64_t start, uint64_t stop); - - - /** - * CPU implementation of scalar operation - * @param x the input - * @param xStride the stride for the input - * @param result the result buffer - * @param resultStride the stride for the result - * @param scalar the scalar to apply - * @param extraParams the extra parameters where - * neccssary - * @param n the number of elements to loop over - */ - - template - static void transform(const void *x, Nd4jLong xStride, - void *result, Nd4jLong resultStride, - const void *scalar, void *extraParams, - uint64_t n, uint64_t start, uint64_t stop); + template + static void transform(const void *x, const Nd4jLong *xShapeInfo, + void *extraParams, void *z, const Nd4jLong *zShapeInfo, + const void *scalars, int *dimension, + int dimensionLength, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets, + const Nd4jLong *tadShapeInfoZ, + const Nd4jLong *tadOffsetsZ, uint64_t start, + uint64_t stop); + + static void transform(int opNum, const void *x, const Nd4jLong *xShapeInfo, + void *extraParams, void *z, const Nd4jLong *zShapeInfo, + const void *scalars, int *dimension, + int dimensionLength, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets, + const Nd4jLong *tadShapeInfoZ, + const Nd4jLong *tadOffsetsZ, uint64_t start, + uint64_t stop); + + static void transform(int opNum, const void *x, const Nd4jLong *xShapeInfo, + void *result, const Nd4jLong *resultShapeInfo, + const void *scalar, void *extraParams, uint64_t start, + uint64_t stop); + + static void transform(int opNum, const void *x, Nd4jLong xStride, + void *result, Nd4jLong resultStride, const void *scalar, + void *extraParams, uint64_t n, uint64_t start, + uint64_t stop); + + /* + * ScalarOp along dimension + */ + + /** + * CPU implementation of scalar operation + * @param x the input + * @param xStride the stride for the input + * @param result the result buffer + * @param resultStride the stride for the result + * @param scalar the scalar to apply + * @param extraParams the extra parameters where + * neccssary + * @param n the number of elements to loop over + */ + + template + static void transform(const void *x, const Nd4jLong *xShapeInfo, void *result, + const Nd4jLong *resultShapeInfo, const void *scalar, + void *extraParams, uint64_t start, uint64_t stop); + + /** + * CPU implementation of scalar operation + * @param x the input + * @param xStride the stride for the input + * @param result the result buffer + * @param resultStride the stride for the result + * @param scalar the scalar to apply + * @param extraParams the extra parameters where + * neccssary + * @param n the number of elements to loop over + */ + + template + static void transform(const void *x, Nd4jLong xStride, void *result, + Nd4jLong resultStride, const void *scalar, + void *extraParams, uint64_t n, uint64_t start, + uint64_t stop); #endif - }; - } -} - +}; +} // namespace scalar +} // namespace functions #endif /* SCALAR_H_ */ diff --git a/libnd4j/include/loops/scalar_int.h b/libnd4j/include/loops/scalar_int.h index c3a53199efcb..c7a39da3be96 100644 --- a/libnd4j/include/loops/scalar_int.h +++ b/libnd4j/include/loops/scalar_int.h @@ -28,12 +28,13 @@ #ifdef __JNI__ #include #endif +#include +#include #include #include #include + #include "helpers/logger.h" -#include -#include #ifdef __CUDACC__ #include @@ -44,169 +45,138 @@ #include "legacy_ops.h" namespace functions { - namespace scalar { +namespace scalar { /** * Apply a scalar * operation to an array */ - template - class ScalarIntTransform { - - public: - +template +class ScalarIntTransform { + public: #ifdef __CUDACC__ - - template - __device__ - static void transformCuda(const void* scalar, - const void *vy, const Nd4jLong *shapeInfo, - void *vparams, - void *vresult, const Nd4jLong *resultShapeInfo, - int *allocationBuffer); - - template - __device__ - static void transformCuda(Nd4jLong n, - const void* vx, const void *vy, Nd4jLong yEWS, - void *vparams, - void *vz, Nd4jLong zEWS, - int *allocationBuffer); - - template - __device__ - static void transformCuda(const void *vx, const Nd4jLong *xShapeInfo, - void *vextraParams, - void *vz, const Nd4jLong *zShapeInfo, - const void *vscalars, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetsZ); - - template - __host__ - static void intermediateAlongDimension(dim3& launchDims, cudaStream_t *stream, - const void *x, const Nd4jLong *xShapeInfo, - void *z, const Nd4jLong *zShapeInfo, - const void *scalars, - void *extraParams, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetsZ); - - template - __host__ - static void intermediateShaped(dim3& launchDims, cudaStream_t *stream, - const void *vx, const Nd4jLong *xShapeInfo, - void *vz, const Nd4jLong *zShapeInfo, - const void* vscalar, - void *vextraParams, - int *allocPointer); - - __host__ - static void executeCudaShaped(dim3& launchDims, cudaStream_t *stream, - int opNum, - const void *x, const Nd4jLong *xShapeInfo, - void *result, const Nd4jLong *resultShapeInfo, - const void* scalar, - void *extraParams); - - __host__ - static void executeCudaAlongDimension(dim3& launchDims, cudaStream_t *stream, - int opNum, - const void *x, const Nd4jLong *xShapeInfo, - void *z, const Nd4jLong *zShapeInfo, - const void *scalars, - void *extraParams, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetsZ); + + template + __device__ static void transformCuda(const void *scalar, const void *vy, + const Nd4jLong *shapeInfo, void *vparams, + void *vresult, + const Nd4jLong *resultShapeInfo, + int *allocationBuffer); + + template + __device__ static void transformCuda(Nd4jLong n, const void *vx, + const void *vy, Nd4jLong yEWS, + void *vparams, void *vz, Nd4jLong zEWS, + int *allocationBuffer); + + template + __device__ static void transformCuda( + const void *vx, const Nd4jLong *xShapeInfo, void *vextraParams, void *vz, + const Nd4jLong *zShapeInfo, const void *vscalars, int *dimension, + int dimensionLength, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets, const Nd4jLong *tadShapeInfoZ, + const Nd4jLong *tadOffsetsZ); + + template + __host__ static void intermediateAlongDimension( + dim3 &launchDims, cudaStream_t *stream, const void *x, + const Nd4jLong *xShapeInfo, void *z, const Nd4jLong *zShapeInfo, + const void *scalars, void *extraParams, int *dimension, + int dimensionLength, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets, const Nd4jLong *tadShapeInfoZ, + const Nd4jLong *tadOffsetsZ); + + template + __host__ static void intermediateShaped( + dim3 &launchDims, cudaStream_t *stream, const void *vx, + const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo, + const void *vscalar, void *vextraParams, int *allocPointer); + + __host__ static void executeCudaShaped(dim3 &launchDims, cudaStream_t *stream, + int opNum, const void *x, + const Nd4jLong *xShapeInfo, + void *result, + const Nd4jLong *resultShapeInfo, + const void *scalar, void *extraParams); + + __host__ static void executeCudaAlongDimension( + dim3 &launchDims, cudaStream_t *stream, int opNum, const void *x, + const Nd4jLong *xShapeInfo, void *z, const Nd4jLong *zShapeInfo, + const void *scalars, void *extraParams, int *dimension, + int dimensionLength, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets, const Nd4jLong *tadShapeInfoZ, + const Nd4jLong *tadOffsetsZ); #else - template - static void transform(const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *z, const Nd4jLong *zShapeInfo, - const void *scalars, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetsZ, - uint64_t start, uint64_t stop); - - static void transform(int opNum, - const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *z, const Nd4jLong *zShapeInfo, - const void *scalars, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetsZ, - uint64_t start, uint64_t stop); - - static void transform(int opNum, - const void *x, const Nd4jLong *xShapeInfo, - void *result, const Nd4jLong *resultShapeInfo, - const void *scalar, - void *extraParams, - uint64_t start, - uint64_t stop); - - static void transform(int opNum, - const void *x, Nd4jLong xStride, - void *result, Nd4jLong resultStride, - const void *scalar, - void *extraParams, - uint64_t n, uint64_t start, uint64_t stop); - - - - - /* - * ScalarOp along dimension - */ - - - /** - * CPU implementation of scalar operation - * @param x the input - * @param xStride the stride for the input - * @param result the result buffer - * @param resultStride the stride for the result - * @param scalar the scalar to apply - * @param extraParams the extra parameters where - * neccssary - * @param n the number of elements to loop over - */ - - template - static void transform(const void *x, const Nd4jLong *xShapeInfo, - void *result, const Nd4jLong *resultShapeInfo, - const void *scalar, - void *extraParams, - uint64_t start, uint64_t stop); - - - /** - * CPU implementation of scalar operation - * @param x the input - * @param xStride the stride for the input - * @param result the result buffer - * @param resultStride the stride for the result - * @param scalar the scalar to apply - * @param extraParams the extra parameters where - * neccssary - * @param n the number of elements to loop over - */ - - template - static void transform(const void *x, Nd4jLong xStride, - void *result, Nd4jLong resultStride, - const void *scalar, - void *extraParams, - uint64_t n, uint64_t start, uint64_t stop); + template + static void transform(const void *x, const Nd4jLong *xShapeInfo, + void *extraParams, void *z, const Nd4jLong *zShapeInfo, + const void *scalars, int *dimension, + int dimensionLength, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets, + const Nd4jLong *tadShapeInfoZ, + const Nd4jLong *tadOffsetsZ, uint64_t start, + uint64_t stop); + + static void transform(int opNum, const void *x, const Nd4jLong *xShapeInfo, + void *extraParams, void *z, const Nd4jLong *zShapeInfo, + const void *scalars, int *dimension, + int dimensionLength, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets, + const Nd4jLong *tadShapeInfoZ, + const Nd4jLong *tadOffsetsZ, uint64_t start, + uint64_t stop); + + static void transform(int opNum, const void *x, const Nd4jLong *xShapeInfo, + void *result, const Nd4jLong *resultShapeInfo, + const void *scalar, void *extraParams, uint64_t start, + uint64_t stop); + + static void transform(int opNum, const void *x, Nd4jLong xStride, + void *result, Nd4jLong resultStride, const void *scalar, + void *extraParams, uint64_t n, uint64_t start, + uint64_t stop); + + /* + * ScalarOp along dimension + */ + + /** + * CPU implementation of scalar operation + * @param x the input + * @param xStride the stride for the input + * @param result the result buffer + * @param resultStride the stride for the result + * @param scalar the scalar to apply + * @param extraParams the extra parameters where + * neccssary + * @param n the number of elements to loop over + */ + + template + static void transform(const void *x, const Nd4jLong *xShapeInfo, void *result, + const Nd4jLong *resultShapeInfo, const void *scalar, + void *extraParams, uint64_t start, uint64_t stop); + + /** + * CPU implementation of scalar operation + * @param x the input + * @param xStride the stride for the input + * @param result the result buffer + * @param resultStride the stride for the result + * @param scalar the scalar to apply + * @param extraParams the extra parameters where + * neccssary + * @param n the number of elements to loop over + */ + + template + static void transform(const void *x, Nd4jLong xStride, void *result, + Nd4jLong resultStride, const void *scalar, + void *extraParams, uint64_t n, uint64_t start, + uint64_t stop); #endif - }; - } -} - +}; +} // namespace scalar +} // namespace functions #endif /* SCALAR_H_ */ diff --git a/libnd4j/include/loops/special_kernels.h b/libnd4j/include/loops/special_kernels.h index 209d35120781..b84366cfe4b0 100644 --- a/libnd4j/include/loops/special_kernels.h +++ b/libnd4j/include/loops/special_kernels.h @@ -21,84 +21,134 @@ #ifndef LIBND4J_SPECIAL_KERNELS_H #define LIBND4J_SPECIAL_KERNELS_H -#include - -#include -#include -#include #include -#include -#include #include +#include +#include #include #include +#include +#include +#include +#include namespace sd { - template - _CUDA_H void fillIsMaxGeneric(dim3 &launchDims, cudaStream_t *stream, void *dx, const Nd4jLong *xShapeInfo, Nd4jLong length, long idx); - - template - _CUDA_H void fillDimensionalIsMaxGeneric(dim3 &launchDims, cudaStream_t *stream, const void *dX, void *dZ, const Nd4jLong *zShapeInfo, const Nd4jLong *tadOnlyShapeInfo, int *dimension, int dimensionLength, const Nd4jLong *tadOffsets); - - template - _CUDA_H void convertToHalfGeneric(dim3 &launchDims, cudaStream_t *stream, void *dx, Nd4jLong n, half *dz); - - template - _CUDA_H void tearKernelGeneric(dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong const* xShapeInfo, Nd4jPointer *targets, - Nd4jLong const* zShapeInfo, Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets); - - template - _CUDA_H void shuffleKernelGeneric(dim3 &launchDims, cudaStream_t *stream, void **vdX, Nd4jLong **xShapeInfo, void **vdZ, int N, - int *shuffleMap, Nd4jLong** tadOnlyShapeInfo, Nd4jLong** tadOffsets); - - template - _CUDA_H void convertHalfsToGeneric(dim3 &launchDims, cudaStream_t *stream, half *dx, Nd4jLong n, void *dz); - - template - _CUDA_H void concatKernelVStackGeneric(dim3 &launchDims, cudaStream_t *stream, int numArrays, Nd4jPointer *data, - Nd4jPointer *inputShapeInfos, void *vz, Nd4jLong const* zShapeInfo); - - template - _CUDA_H void concatKernelScalarGeneric(dim3 &launchDims, cudaStream_t *stream, int numArrays, Nd4jPointer *data, void *vresult); - - template - _CUDA_H void concatKernelHStackGeneric(dim3 &launchDims, cudaStream_t *stream, int numArrays, Nd4jPointer *data, - Nd4jPointer *inputShapeInfos, void *vresult, Nd4jLong const* resultShapeInfo); - - template - _CUDA_H void concatKernelGeneric(dim3 &launchDims, cudaStream_t *stream, int numArrays, Nd4jPointer *data, - Nd4jPointer *inputShapeInfos, void *vresult, Nd4jLong const* resultShapeInfo, - Nd4jPointer *tadPointers, Nd4jPointer *offsetPointers, Nd4jLong const* zTadShape, Nd4jLong const* zOffsets); - - template - _CUDA_H void pullRowsKernelGeneric(dim3 &launchDims, cudaStream_t *stream, void *vx, void *vz, Nd4jLong n, Nd4jLong *indexes, - Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* zTadShapeInfo, Nd4jLong const* zTadOffsets); - - template - _CUDA_H void averagingKernelGeneric(dim3 &launchDims, cudaStream_t *stream, void **vdx, void *vdz, int n, Nd4jLong length, bool propagate); - - template - _CUDA_H void accumulateKernelGeneric(dim3 &launchDims, cudaStream_t *stream, void **vx, void *vz, int n, const Nd4jLong length); - - template - _CUDA_H void flattenKernelGeneric(dim3& launchDims, cudaStream_t *stream, Nd4jPointer *extraPointers, int dOffset, char order, void *vz, Nd4jLong *zShapeInfo, void *vy, Nd4jLong *yShapeInfo); - - template - _CUDA_H void tileKernelH(void const* inputBuffer, Nd4jLong const* inputShape, void* outputBuffer, Nd4jLong const* outputShape, Nd4jLong resultLength, cudaStream_t *stream); - template - _CUDA_H void tileKernelHH(void const* inputBuffer, Nd4jLong const* inputShape, void* outputBuffer, Nd4jLong const* outputShape, Nd4jLong resultLength, Nd4jLong ews, cudaStream_t *stream); - - class NDArray; - template - _CUDA_H void setDiagonalValueUpper(void* buffer, Nd4jLong const* shape, NDArray const& value, int diagonal, Nd4jLong rows, Nd4jLong cols, cudaStream_t& stream); - - template - _CUDA_H void setDiagonalValueLower(void* buffer, Nd4jLong const* shape, NDArray const& value, int diagonal, Nd4jLong rows, Nd4jLong cols, cudaStream_t& stream); - - template - _CUDA_H void templatedSwapUnsafe(void* theFirstBuffer, Nd4jLong const* theFirstShape, void* theSecondBuffer, Nd4jLong const* theSecondShape, cudaStream_t* theStream); - -} +template +_CUDA_H void fillIsMaxGeneric(dim3 &launchDims, cudaStream_t *stream, void *dx, + const Nd4jLong *xShapeInfo, Nd4jLong length, + long idx); + +template +_CUDA_H void fillDimensionalIsMaxGeneric(dim3 &launchDims, cudaStream_t *stream, + const void *dX, void *dZ, + const Nd4jLong *zShapeInfo, + const Nd4jLong *tadOnlyShapeInfo, + int *dimension, int dimensionLength, + const Nd4jLong *tadOffsets); + +template +_CUDA_H void convertToHalfGeneric(dim3 &launchDims, cudaStream_t *stream, + void *dx, Nd4jLong n, half *dz); + +template +_CUDA_H void tearKernelGeneric(dim3 &launchDims, cudaStream_t *stream, void *vx, + Nd4jLong const *xShapeInfo, Nd4jPointer *targets, + Nd4jLong const *zShapeInfo, + Nd4jLong const *tadShapeInfo, + Nd4jLong const *tadOffsets); + +template +_CUDA_H void shuffleKernelGeneric(dim3 &launchDims, cudaStream_t *stream, + void **vdX, Nd4jLong **xShapeInfo, void **vdZ, + int N, int *shuffleMap, + Nd4jLong **tadOnlyShapeInfo, + Nd4jLong **tadOffsets); + +template +_CUDA_H void convertHalfsToGeneric(dim3 &launchDims, cudaStream_t *stream, + half *dx, Nd4jLong n, void *dz); + +template +_CUDA_H void concatKernelVStackGeneric(dim3 &launchDims, cudaStream_t *stream, + int numArrays, Nd4jPointer *data, + Nd4jPointer *inputShapeInfos, void *vz, + Nd4jLong const *zShapeInfo); + +template +_CUDA_H void concatKernelScalarGeneric(dim3 &launchDims, cudaStream_t *stream, + int numArrays, Nd4jPointer *data, + void *vresult); + +template +_CUDA_H void concatKernelHStackGeneric(dim3 &launchDims, cudaStream_t *stream, + int numArrays, Nd4jPointer *data, + Nd4jPointer *inputShapeInfos, + void *vresult, + Nd4jLong const *resultShapeInfo); + +template +_CUDA_H void concatKernelGeneric(dim3 &launchDims, cudaStream_t *stream, + int numArrays, Nd4jPointer *data, + Nd4jPointer *inputShapeInfos, void *vresult, + Nd4jLong const *resultShapeInfo, + Nd4jPointer *tadPointers, + Nd4jPointer *offsetPointers, + Nd4jLong const *zTadShape, + Nd4jLong const *zOffsets); + +template +_CUDA_H void pullRowsKernelGeneric( + dim3 &launchDims, cudaStream_t *stream, void *vx, void *vz, Nd4jLong n, + Nd4jLong *indexes, Nd4jLong const *tadShapeInfo, Nd4jLong const *tadOffsets, + Nd4jLong const *zTadShapeInfo, Nd4jLong const *zTadOffsets); + +template +_CUDA_H void averagingKernelGeneric(dim3 &launchDims, cudaStream_t *stream, + void **vdx, void *vdz, int n, + Nd4jLong length, bool propagate); + +template +_CUDA_H void accumulateKernelGeneric(dim3 &launchDims, cudaStream_t *stream, + void **vx, void *vz, int n, + const Nd4jLong length); + +template +_CUDA_H void flattenKernelGeneric(dim3 &launchDims, cudaStream_t *stream, + Nd4jPointer *extraPointers, int dOffset, + char order, void *vz, Nd4jLong *zShapeInfo, + void *vy, Nd4jLong *yShapeInfo); + +template +_CUDA_H void tileKernelH(void const *inputBuffer, Nd4jLong const *inputShape, + void *outputBuffer, Nd4jLong const *outputShape, + Nd4jLong resultLength, cudaStream_t *stream); +template +_CUDA_H void tileKernelHH(void const *inputBuffer, Nd4jLong const *inputShape, + void *outputBuffer, Nd4jLong const *outputShape, + Nd4jLong resultLength, Nd4jLong ews, + cudaStream_t *stream); + +class NDArray; +template +_CUDA_H void setDiagonalValueUpper(void *buffer, Nd4jLong const *shape, + NDArray const &value, int diagonal, + Nd4jLong rows, Nd4jLong cols, + cudaStream_t &stream); + +template +_CUDA_H void setDiagonalValueLower(void *buffer, Nd4jLong const *shape, + NDArray const &value, int diagonal, + Nd4jLong rows, Nd4jLong cols, + cudaStream_t &stream); + +template +_CUDA_H void templatedSwapUnsafe(void *theFirstBuffer, + Nd4jLong const *theFirstShape, + void *theSecondBuffer, + Nd4jLong const *theSecondShape, + cudaStream_t *theStream); + +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/loops/summarystatsreduce.h b/libnd4j/include/loops/summarystatsreduce.h old mode 100755 new mode 100644 index 1ab06a11bd89..c2601df33283 --- a/libnd4j/include/loops/summarystatsreduce.h +++ b/libnd4j/include/loops/summarystatsreduce.h @@ -23,15 +23,14 @@ #ifndef SUMMARYSTATSREDUCE_H_ #define SUMMARYSTATSREDUCE_H_ +#include #include #include - -#include #ifdef __CUDACC__ #include #include -#define host_and_device inline __host__ __device__ +#define host_and_device inline __host__ __device__ #else #define host_and_device inline #endif @@ -46,288 +45,274 @@ #include "legacy_ops.h" namespace functions { - namespace summarystats { - - // This example computes several statistical properties of a data - // series in a single reduction. The algorithm is described in detail here: - // http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm - // - // Thanks to Joseph Rhoads for contributing this example - - - // structure used to accumulate the moments and other - // statistical properties encountered so far. - template - class SummaryStatsData { - - public: - double n; - double min; - double max; - double mean; - double M2; - double M3; - double M4; - double bias; - - _CUDA_HD SummaryStatsData() { - initialize(); - } - - // initialize to the identity element - - _CUDA_HD void initialize() { - n = mean = M2 = M3 = M4 = bias = 0; - } - - _CUDA_HD void initWithValue(X val) { - n = 1; - min = val; - max = val; - mean = val; - M2 = 0; - M3 = 0; - M4 = 0; - bias = 0; - } - - _CUDA_HD void setValues(SummaryStatsData *target) { - n = target->n; - min = target->min; - max = target->max; - mean = target->mean; - M2 = target->M2; - M3 = target->M3; - M4 = target->M4; - bias = target->bias; - } - - _CUDA_HD double variance() { - if (n <= 1.0) - return 0.0; - return M2 / (n); - } - - _CUDA_HD double varianceBiasCorrected() { - if (this->n <= 1.0) { - return 0.0; - } - - return M2 / (n - 1.0); - } - - - _CUDA_HD double variance_n() { - if (n <= 1.0) - return 0.0; - return M2 / n; - } - - _CUDA_HD double skewness() { return M2 > 0.0 ? sd::math::nd4j_sqrt(n) * M3 / sd::math::nd4j_pow(M2, 1.5) : 0.0; } - - _CUDA_HD double kurtosis() { return M2 > 0.0 ? n * M4 / (M2 * M2) : 0; } - - _CUDA_HD double getM2() { - return M2; - } - - _CUDA_HD void setM2(X m2) { - M2 = m2; - } - - _CUDA_HD double getM3() { - return M3; - } - - _CUDA_HD void setM3(X m3) { - M3 = m3; - } - - _CUDA_HD double getM4() { - return M4; - } - - _CUDA_HD void setM4(X m4) { - M4 = m4; - } - - _CUDA_HD double getMax() { - return max; - } - - _CUDA_HD void setMax(X max) { - this->max = max; - } - - _CUDA_HD double getMean() { - return mean; - } - - _CUDA_HD void setMean(X mean) { - this->mean = mean; - } - - _CUDA_HD double getMin() { - return min; - } - - _CUDA_HD void setMin(X min) { - this->min = min; - } - - _CUDA_HD double getN() { - return n; - } - - _CUDA_HD void setN(X n) { - this->n = n; - } - }; +namespace summarystats { + +// This example computes several statistical properties of a data +// series in a single reduction. The algorithm is described in detail here: +// http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm +// +// Thanks to Joseph Rhoads for contributing this example + +// structure used to accumulate the moments and other +// statistical properties encountered so far. +template +class SummaryStatsData { + public: + double n; + double min; + double max; + double mean; + double M2; + double M3; + double M4; + double bias; + + _CUDA_HD SummaryStatsData() { initialize(); } + + // initialize to the identity element + + _CUDA_HD void initialize() { n = mean = M2 = M3 = M4 = bias = 0; } + + _CUDA_HD void initWithValue(X val) { + n = 1; + min = val; + max = val; + mean = val; + M2 = 0; + M3 = 0; + M4 = 0; + bias = 0; + } + + _CUDA_HD void setValues(SummaryStatsData* target) { + n = target->n; + min = target->min; + max = target->max; + mean = target->mean; + M2 = target->M2; + M3 = target->M3; + M4 = target->M4; + bias = target->bias; + } + + _CUDA_HD double variance() { + if (n <= 1.0) return 0.0; + return M2 / (n); + } + + _CUDA_HD double varianceBiasCorrected() { + if (this->n <= 1.0) { + return 0.0; + } -#ifdef __CUDACC__ - // This is the un-specialized struct. Note that we prevent instantiation of this -// struct by putting an undefined symbol in the function body so it won't compile. - template - struct SharedSummaryStatsData { - // Ensure that we won't compile any un-specialized types - __device__ T * getPointer() { - extern __device__ void error(void); - error(); - return 0; - } - }; - - // Following are the specializations for the following types. - // int, uint, char, uchar, short, ushort, long long, ulong long, bool, float, and double - // One could also specialize it for user-defined types. - - template<> - struct SharedSummaryStatsData { - __device__ SummaryStatsData * getPointer() { - extern __shared__ SummaryStatsData s_int2[]; - return s_int2; - } - }; - // Following are the specializations for the following types. - // int, uint, char, uchar, short, ushort, long long, ulong long, bool, float, and double - // One could also specialize it for user-defined types. - - template<> - struct SharedSummaryStatsData { - __device__ SummaryStatsData * getPointer() { - extern __shared__ SummaryStatsData s_int6[]; - return s_int6; - } - }; -#endif + return M2 / (n - 1.0); + } - /** - * Standard deviation or variance 1 pass - */ - template - class SummaryStatsReduce { - public: - //calculate an update of the reduce operation - _CUDA_HD static SummaryStatsData update(SummaryStatsData x, SummaryStatsData y, - void* extraParams) { - if ((long) x.n == 0 && (long) y.n > 0) - return y; - else if ((long) x.n > 0 && (long) y.n == 0) - return x; - SummaryStatsData vz; - double n = x.n + y.n; - double n2 = n * n; - double n3 = n2 * n; - - - double delta = y.mean - x.mean; - double delta2 = delta * delta; - double delta3 = delta2 * delta; - double delta4 = delta3 * delta; - - //Basic number of samples (n), min, and max - vz.n = n; - vz.min = sd::math::nd4j_min(x.min, y.min); - vz.max = sd::math::nd4j_max(x.max, y.max); - double meanD = x.mean + delta * y.n / n; - vz.mean = meanD; - double M2D = x.M2 + y.M2; - M2D += delta2 * x.n * y.n / n; - vz.M2 = M2D; - vz.M3 = x.M3 + y.M3; - vz.M3 += delta3 * x.n * y.n * (x.n - y.n) / n2; - vz.M3 += 3.0 * delta * (x.n * y.M2 - y.n * x.M2) / n; - - vz.M4 = x.M4 + y.M4; - vz.M4 += delta4 * x.n * y.n * (x.n * x.n - x.n * y.n + y.n * y.n) / n3; - vz.M4 += 6.0 * delta2 * (x.n * x.n * y.M2 + y.n * y.n * x.M2) / n2; - vz.M4 += 4.0 * delta * (x.n * y.M3 - y.n * x.M3) / n; - - return vz; - } + _CUDA_HD double variance_n() { + if (n <= 1.0) return 0.0; + return M2 / n; + } + _CUDA_HD double skewness() { + return M2 > 0.0 ? sd::math::nd4j_sqrt(n) * M3 / + sd::math::nd4j_pow(M2, 1.5) + : 0.0; + } + _CUDA_HD double kurtosis() { return M2 > 0.0 ? n * M4 / (M2 * M2) : 0; } -#ifdef __CUDACC__ + _CUDA_HD double getM2() { return M2; } - static inline _CUDA_D Z startingValue(X const* input) { - return static_cast(0); - } + _CUDA_HD void setM2(X m2) { M2 = m2; } - template - static _CUDA_D void aggregatePartials(SummaryStatsData **sPartialsRef, Nd4jLong tid, Nd4jLong numElements, void *extraParams); + _CUDA_HD double getM3() { return M3; } + _CUDA_HD void setM3(X m3) { M3 = m3; } - template - static _CUDA_D void transform(void const* dx, Nd4jLong const* xShapeInfo, void *extraParams, void *vz, Nd4jLong const* zShapeInfo, int *dimension, int dimensionLength, int postProcessOrNot, int *allocationBuffer, void *reductionBuffer, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets); + _CUDA_HD double getM4() { return M4; } - static _CUDA_D void transform(const int opNum, void const* dx, Nd4jLong const* xShapeInfo, void *extraParams, void *vz, Nd4jLong const* zShapeInfo, int *dimension, int dimensionLength, int postProcessOrNot, int *allocationBuffer, void *reductionBuffer, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets); + _CUDA_HD void setM4(X m4) { M4 = m4; } - static _CUDA_H void execSummaryStatsReduceScalar(dim3& launchDims, cudaStream_t *stream, int opNum, void const* x, Nd4jLong const* xShapeInfo, Nd4jLong const* hxShapeInfo, void *extraParams, void *vz, Nd4jLong const* zShapeInfo, Nd4jLong const* hzShapeInfo, Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, bool biasCorrected, void *reductionBuffer); - static _CUDA_H void execSummaryStatsReduce(dim3& launchDims, cudaStream_t *stream, int opNum, void const* x, Nd4jLong const* xShapeInfo, Nd4jLong const* hxShapeInfo, void *extraParams, void *vz, Nd4jLong const* zShapeInfo, Nd4jLong const* hzShapeInfo, Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, bool biasCorrected, void *reductionBuffer); - static _CUDA_H void execSummaryStatsReduce(dim3& launchDims, cudaStream_t *stream, int opNum, void const* x, Nd4jLong const* xShapeInfo, Nd4jLong const* hxShapeInfo, void *extraParams, void *vz, Nd4jLong const* zShapeInfo, Nd4jLong const* hzShapeInfo, int *dimension, int dimensionLength, Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, bool biasCorrected, void *reductionBuffer); -#else + _CUDA_HD double getMax() { return max; } + + _CUDA_HD void setMax(X max) { this->max = max; } + + _CUDA_HD double getMean() { return mean; } + + _CUDA_HD void setMean(X mean) { this->mean = mean; } + + _CUDA_HD double getMin() { return min; } + + _CUDA_HD void setMin(X min) { this->min = min; } + + _CUDA_HD double getN() { return n; } + + _CUDA_HD void setN(X n) { this->n = n; } +}; - static Z execScalar(int opNum, - bool biasCorrected, - const void *x, const Nd4jLong *xShapeInfo, - void *extraParams); - - static void execScalar(int opNum, - bool biasCorrected, - const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *vz, const Nd4jLong *resultShapeInfoBuffer); - - static void exec(int opNum, - bool biasCorrected, - const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *vz, const Nd4jLong *resultShapeInfoBuffer, - int *dimension, int dimensionLength); - - template - static Z execScalar(bool biasCorrected, - const void *x, const Nd4jLong *xShapeInfo, - void *extraParams); - - template - static void execScalar(bool biasCorrected, - const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *vz, const Nd4jLong *resultShapeInfoBuffer); - - - template - static void exec(bool biasCorrected, - const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *vz, const Nd4jLong *resultShapeInfoBuffer, - int *dimension, int dimensionLength); +#ifdef __CUDACC__ +// This is the un-specialized struct. Note that we prevent instantiation of +// this +// struct by putting an undefined symbol in the function body so it won't +// compile. +template +struct SharedSummaryStatsData { + // Ensure that we won't compile any un-specialized types + __device__ T* getPointer() { + extern __device__ void error(void); + error(); + return 0; + } +}; + +// Following are the specializations for the following types. +// int, uint, char, uchar, short, ushort, long long, ulong long, bool, float, +// and double One could also specialize it for user-defined types. + +template <> +struct SharedSummaryStatsData { + __device__ SummaryStatsData* getPointer() { + extern __shared__ SummaryStatsData s_int2[]; + return s_int2; + } +}; +// Following are the specializations for the following types. +// int, uint, char, uchar, short, ushort, long long, ulong long, bool, float, +// and double One could also specialize it for user-defined types. + +template <> +struct SharedSummaryStatsData { + __device__ SummaryStatsData* getPointer() { + extern __shared__ SummaryStatsData s_int6[]; + return s_int6; + } +}; #endif - }; - } -} +/** + * Standard deviation or variance 1 pass + */ +template +class SummaryStatsReduce { + public: + // calculate an update of the reduce operation + _CUDA_HD static SummaryStatsData update(SummaryStatsData x, + SummaryStatsData y, + void* extraParams) { + if ((long)x.n == 0 && (long)y.n > 0) + return y; + else if ((long)x.n > 0 && (long)y.n == 0) + return x; + SummaryStatsData vz; + double n = x.n + y.n; + double n2 = n * n; + double n3 = n2 * n; + + double delta = y.mean - x.mean; + double delta2 = delta * delta; + double delta3 = delta2 * delta; + double delta4 = delta3 * delta; + + // Basic number of samples (n), min, and max + vz.n = n; + vz.min = sd::math::nd4j_min(x.min, y.min); + vz.max = sd::math::nd4j_max(x.max, y.max); + double meanD = x.mean + delta * y.n / n; + vz.mean = meanD; + double M2D = x.M2 + y.M2; + M2D += delta2 * x.n * y.n / n; + vz.M2 = M2D; + vz.M3 = x.M3 + y.M3; + vz.M3 += delta3 * x.n * y.n * (x.n - y.n) / n2; + vz.M3 += 3.0 * delta * (x.n * y.M2 - y.n * x.M2) / n; + + vz.M4 = x.M4 + y.M4; + vz.M4 += delta4 * x.n * y.n * (x.n * x.n - x.n * y.n + y.n * y.n) / n3; + vz.M4 += 6.0 * delta2 * (x.n * x.n * y.M2 + y.n * y.n * x.M2) / n2; + vz.M4 += 4.0 * delta * (x.n * y.M3 - y.n * x.M3) / n; + + return vz; + } + +#ifdef __CUDACC__ + + static inline _CUDA_D Z startingValue(X const* input) { + return static_cast(0); + } + + template + static _CUDA_D void aggregatePartials(SummaryStatsData** sPartialsRef, + Nd4jLong tid, Nd4jLong numElements, + void* extraParams); + + template + static _CUDA_D void transform(void const* dx, Nd4jLong const* xShapeInfo, + void* extraParams, void* vz, + Nd4jLong const* zShapeInfo, int* dimension, + int dimensionLength, int postProcessOrNot, + int* allocationBuffer, void* reductionBuffer, + Nd4jLong const* tadOnlyShapeInfo, + Nd4jLong const* tadOffsets); + + static _CUDA_D void transform(const int opNum, void const* dx, + Nd4jLong const* xShapeInfo, void* extraParams, + void* vz, Nd4jLong const* zShapeInfo, + int* dimension, int dimensionLength, + int postProcessOrNot, int* allocationBuffer, + void* reductionBuffer, + Nd4jLong const* tadOnlyShapeInfo, + Nd4jLong const* tadOffsets); + + static _CUDA_H void execSummaryStatsReduceScalar( + dim3& launchDims, cudaStream_t* stream, int opNum, void const* x, + Nd4jLong const* xShapeInfo, Nd4jLong const* hxShapeInfo, + void* extraParams, void* vz, Nd4jLong const* zShapeInfo, + Nd4jLong const* hzShapeInfo, Nd4jLong const* tadShapeInfo, + Nd4jLong const* tadOffsets, bool biasCorrected, void* reductionBuffer); + static _CUDA_H void execSummaryStatsReduce( + dim3& launchDims, cudaStream_t* stream, int opNum, void const* x, + Nd4jLong const* xShapeInfo, Nd4jLong const* hxShapeInfo, + void* extraParams, void* vz, Nd4jLong const* zShapeInfo, + Nd4jLong const* hzShapeInfo, Nd4jLong const* tadShapeInfo, + Nd4jLong const* tadOffsets, bool biasCorrected, void* reductionBuffer); + static _CUDA_H void execSummaryStatsReduce( + dim3& launchDims, cudaStream_t* stream, int opNum, void const* x, + Nd4jLong const* xShapeInfo, Nd4jLong const* hxShapeInfo, + void* extraParams, void* vz, Nd4jLong const* zShapeInfo, + Nd4jLong const* hzShapeInfo, int* dimension, int dimensionLength, + Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, + bool biasCorrected, void* reductionBuffer); +#else + + static Z execScalar(int opNum, bool biasCorrected, const void *x, + const Nd4jLong *xShapeInfo, void *extraParams); + + static void execScalar(int opNum, bool biasCorrected, const void *x, + const Nd4jLong *xShapeInfo, void *extraParams, + void *vz, const Nd4jLong *resultShapeInfoBuffer); + + static void exec(int opNum, bool biasCorrected, const void *x, + const Nd4jLong *xShapeInfo, void *extraParams, void *vz, + const Nd4jLong *resultShapeInfoBuffer, int *dimension, + int dimensionLength); + + template + static Z execScalar(bool biasCorrected, const void *x, + const Nd4jLong *xShapeInfo, void *extraParams); + + template + static void execScalar(bool biasCorrected, const void *x, + const Nd4jLong *xShapeInfo, void *extraParams, + void *vz, const Nd4jLong *resultShapeInfoBuffer); + + template + static void exec(bool biasCorrected, const void *x, + const Nd4jLong *xShapeInfo, void *extraParams, void *vz, + const Nd4jLong *resultShapeInfoBuffer, int *dimension, + int dimensionLength); +#endif +}; +} // namespace summarystats +} // namespace functions #endif /* SUMMARYSTATSREDUCE_H_ */ diff --git a/libnd4j/include/loops/transform_any.h b/libnd4j/include/loops/transform_any.h index 751328b89840..635c956a1038 100644 --- a/libnd4j/include/loops/transform_any.h +++ b/libnd4j/include/loops/transform_any.h @@ -24,15 +24,16 @@ #ifndef TRANSFORM_ANY_H_ #define TRANSFORM_ANY_H_ -#include #include #include +#include + #ifdef _OPENMP #include #endif -#include #include +#include //#include //#include @@ -46,56 +47,54 @@ #include "legacy_ops.h" - namespace functions { namespace transform { -template +template class TransformAny { - public: - + public: #ifdef __CUDACC__ - template - static __device__ void transformCuda(const void *vx, const Nd4jLong *xShapeInfo, - void *params, - void *vz, const Nd4jLong *zShapeInfo, - int *allocationPointer, void *reductionPointer, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets); - - template - static _CUDA_H void intermediateShaped(dim3 launchDims, cudaStream_t *stream, - const void *x, const Nd4jLong *xShape, int xRank, - void *extraParams, - void *z, const Nd4jLong *zShape, int zRank, - int *allocationPointer, void *reductionPointer, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets); - - static _CUDA_H void executeTransformShaped(dim3 launchDims, cudaStream_t *stream, - int opNum, - const void *x, const Nd4jLong *xShape, int xRank, - void *extraParams, - void *z, const Nd4jLong *zShape, int zRank, - int *allocationPointer, void *reductionPointer, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets); + template + static __device__ void transformCuda(const void *vx, + const Nd4jLong *xShapeInfo, void *params, + void *vz, const Nd4jLong *zShapeInfo, + int *allocationPointer, + void *reductionPointer, + const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets); + + template + static _CUDA_H void intermediateShaped(dim3 launchDims, cudaStream_t *stream, + const void *x, const Nd4jLong *xShape, + int xRank, void *extraParams, void *z, + const Nd4jLong *zShape, int zRank, + int *allocationPointer, + void *reductionPointer, + const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets); + + static _CUDA_H void executeTransformShaped( + dim3 launchDims, cudaStream_t *stream, int opNum, const void *x, + const Nd4jLong *xShape, int xRank, void *extraParams, void *z, + const Nd4jLong *zShape, int zRank, int *allocationPointer, + void *reductionPointer, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets); #else - static void exec(int opNum, - const void *dx, const Nd4jLong *xShapeInfo, - void *vz, const Nd4jLong *zShapeInfo, - void *extraParams, - uint64_t threadId, uint64_t numThreads); - - template - static ND4J_EXPORT void exec(const void *dx, const Nd4jLong *xShapeInfo, - void *vz, const Nd4jLong *zShapeInfo, - void *extraParams, - uint64_t threadId, uint64_t numThreads); + static void exec(int opNum, const void *dx, const Nd4jLong *xShapeInfo, + void *vz, const Nd4jLong *zShapeInfo, void *extraParams, + uint64_t threadId, uint64_t numThreads); + + template + static SD_EXPORT void exec(const void *dx, const Nd4jLong *xShapeInfo, + void *vz, const Nd4jLong *zShapeInfo, + void *extraParams, uint64_t threadId, + uint64_t numThreads); #endif }; -} -} - +} // namespace transform +} // namespace functions #endif /* TRANSFORM_H_ */ diff --git a/libnd4j/include/loops/transform_bool.h b/libnd4j/include/loops/transform_bool.h index 5553c164fa5d..421e1695af70 100644 --- a/libnd4j/include/loops/transform_bool.h +++ b/libnd4j/include/loops/transform_bool.h @@ -24,15 +24,16 @@ #ifndef TRANSFORM_BOOL_H_ #define TRANSFORM_BOOL_H_ -#include #include #include +#include + #ifdef _OPENMP #include #endif -#include #include +#include //#include //#include @@ -46,55 +47,51 @@ #include "legacy_ops.h" - namespace functions { - namespace transform { - - template - class TransformBool { - public: +namespace transform { +template +class TransformBool { + public: #ifdef __CUDACC__ - template - static __device__ void transformCuda(const void *dy, const Nd4jLong *shapeInfo, - void *params, - void *result, const Nd4jLong *resultShapeInfo, - int *allocationPointer, void *reductionPointer, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets); - - template - static _CUDA_H void intermediateShaped(dim3 launchDims, cudaStream_t *stream, - const void *x, const Nd4jLong *xShape, int xRank, - void *extraParams, - void *z, const Nd4jLong *zShape, int zRank, - int *allocationPointer, void *reductionPointer, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets); - - static _CUDA_H void executeTransformShaped(dim3 launchDims, cudaStream_t *stream, - int opNum, - const void *x, const Nd4jLong *xShape, int xRank, - void *extraParams, - void *z, const Nd4jLong *zShape, int zRank, - int *allocationPointer, void *reductionPointer, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets); + template + static __device__ void transformCuda( + const void *dy, const Nd4jLong *shapeInfo, void *params, void *result, + const Nd4jLong *resultShapeInfo, int *allocationPointer, + void *reductionPointer, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets); + + template + static _CUDA_H void intermediateShaped(dim3 launchDims, cudaStream_t *stream, + const void *x, const Nd4jLong *xShape, + int xRank, void *extraParams, void *z, + const Nd4jLong *zShape, int zRank, + int *allocationPointer, + void *reductionPointer, + const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets); + + static _CUDA_H void executeTransformShaped( + dim3 launchDims, cudaStream_t *stream, int opNum, const void *x, + const Nd4jLong *xShape, int xRank, void *extraParams, void *z, + const Nd4jLong *zShape, int zRank, int *allocationPointer, + void *reductionPointer, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets); #else - static void exec(int opNum, - const void *dx, const Nd4jLong *xShapeInfo, - void *result, const Nd4jLong *resultShapeInfo, - void *extraParams, - uint64_t threadId, uint64_t numThreads); + static void exec(int opNum, const void *dx, const Nd4jLong *xShapeInfo, + void *result, const Nd4jLong *resultShapeInfo, + void *extraParams, uint64_t threadId, uint64_t numThreads); - template - static ND4J_EXPORT void exec(const void *dx, const Nd4jLong *xShapeInfo, - void *result, const Nd4jLong *resultShapeInfo, - void *extraParams, - uint64_t threadId, uint64_t numThreads); + template + static SD_EXPORT void exec(const void *dx, const Nd4jLong *xShapeInfo, + void *result, const Nd4jLong *resultShapeInfo, + void *extraParams, uint64_t threadId, + uint64_t numThreads); #endif - }; - } -} - +}; +} // namespace transform +} // namespace functions #endif /* TRANSFORM_H_ */ diff --git a/libnd4j/include/loops/transform_float.h b/libnd4j/include/loops/transform_float.h index 4264278ba5e2..b31ded716dcf 100644 --- a/libnd4j/include/loops/transform_float.h +++ b/libnd4j/include/loops/transform_float.h @@ -24,15 +24,16 @@ #ifndef TRANSFORM_FLOAT_H_ #define TRANSFORM_FLOAT_H_ -#include #include #include +#include + #ifdef _OPENMP #include #endif -#include #include +#include //#include //#include @@ -46,70 +47,64 @@ #include "legacy_ops.h" - namespace functions { - namespace transform { - - template - class TransformFloat { - public: +namespace transform { +template +class TransformFloat { + public: #ifdef __CUDACC__ - template - static __device__ void transformCuda(const void *dy, const Nd4jLong *shapeInfo, - void *params, - void *result, const Nd4jLong *resultShapeInfo, - int *allocationPointer, void *reductionPointer, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets); - - static __device__ void transformCudaLegacy(int opNum, - const void *dy, const Nd4jLong *shapeInfo, - void *params, - void *result, const Nd4jLong *resultShapeInfo, - int *allocationPointer, void *reductionPointer, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets); - - template - static __device__ void transformCuda(Nd4jLong n, - const void *dy, Nd4jLong incy, - void *params, - void *result, Nd4jLong resultStride, - int *allocationPointer, void *reductionPointer); - - - template - static _CUDA_H void intermediateShaped(dim3 launchDims, cudaStream_t *stream, - const void *x, const Nd4jLong *xShape, int xRank, - void *extraParams, - void *z, const Nd4jLong *zShape, int zRank, - int *allocationPointer, void *reductionPointer, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets); - - static _CUDA_H void executeTransformShaped(dim3 launchDims, cudaStream_t *stream, - int opNum, - const void *x, const Nd4jLong *xShape, int xRank, - void *extraParams, - void *z, const Nd4jLong *zShape, int zRank, - int *allocationPointer, void *reductionPointer, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets); + template + static __device__ void transformCuda( + const void *dy, const Nd4jLong *shapeInfo, void *params, void *result, + const Nd4jLong *resultShapeInfo, int *allocationPointer, + void *reductionPointer, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets); + + static __device__ void transformCudaLegacy( + int opNum, const void *dy, const Nd4jLong *shapeInfo, void *params, + void *result, const Nd4jLong *resultShapeInfo, int *allocationPointer, + void *reductionPointer, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets); + + template + static __device__ void transformCuda(Nd4jLong n, const void *dy, + Nd4jLong incy, void *params, + void *result, Nd4jLong resultStride, + int *allocationPointer, + void *reductionPointer); + + template + static _CUDA_H void intermediateShaped(dim3 launchDims, cudaStream_t *stream, + const void *x, const Nd4jLong *xShape, + int xRank, void *extraParams, void *z, + const Nd4jLong *zShape, int zRank, + int *allocationPointer, + void *reductionPointer, + const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets); + + static _CUDA_H void executeTransformShaped( + dim3 launchDims, cudaStream_t *stream, int opNum, const void *x, + const Nd4jLong *xShape, int xRank, void *extraParams, void *z, + const Nd4jLong *zShape, int zRank, int *allocationPointer, + void *reductionPointer, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets); #else - static void exec(int opNum, - const void *dx, const Nd4jLong *xShapeInfo, + static void exec(int opNum, const void *dx, const Nd4jLong *xShapeInfo, + void *result, const Nd4jLong *resultShapeInfo, + void *extraParams, uint64_t threadId, uint64_t numThreads); + + template + static SD_EXPORT void exec(const void *dx, const Nd4jLong *xShapeInfo, void *result, const Nd4jLong *resultShapeInfo, - void *extraParams, - uint64_t threadId, uint64_t numThreads); - - template - static ND4J_EXPORT void exec(const void *dx, const Nd4jLong *xShapeInfo, - void *result, const Nd4jLong *resultShapeInfo, - void *extraParams, - uint64_t threadId, uint64_t numThreads); + void *extraParams, uint64_t threadId, + uint64_t numThreads); #endif - }; - } -} - +}; +} // namespace transform +} // namespace functions #endif /* TRANSFORM_H_ */ diff --git a/libnd4j/include/loops/transform_same.h b/libnd4j/include/loops/transform_same.h index cb069ecc9121..e9c73293331f 100644 --- a/libnd4j/include/loops/transform_same.h +++ b/libnd4j/include/loops/transform_same.h @@ -24,10 +24,11 @@ #ifndef TRANSFORM_SAME_H_ #define TRANSFORM_SAME_H_ -#include #include #include +#include + #ifdef _OPENMP #include #endif @@ -46,57 +47,51 @@ #include "legacy_ops.h" - namespace functions { - namespace transform { - - template - class TransformSame { - public: +namespace transform { +template +class TransformSame { + public: #ifdef __CUDACC__ - template - static __device__ void transformCuda(const void *dy, const Nd4jLong *shapeInfo, - void *params, - void *result, const Nd4jLong *resultShapeInfo, - int *allocationPointer, void *reductionPointer, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets); - - - template - static _CUDA_H void intermediateShaped(dim3 launchDims, cudaStream_t *stream, - const void *x, const Nd4jLong *xShape, int xRank, - void *extraParams, - void *z, const Nd4jLong *zShape, int zRank, - int *allocationPointer, void *reductionPointer, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets); - - static _CUDA_H void executeTransformShaped(dim3 launchDims, cudaStream_t *stream, - int opNum, - const void *x, const Nd4jLong *xShape, int xRank, - void *extraParams, - void *z, const Nd4jLong *zShape, int zRank, - int *allocationPointer, void *reductionPointer, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets); - + template + static __device__ void transformCuda( + const void *dy, const Nd4jLong *shapeInfo, void *params, void *result, + const Nd4jLong *resultShapeInfo, int *allocationPointer, + void *reductionPointer, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets); + + template + static _CUDA_H void intermediateShaped(dim3 launchDims, cudaStream_t *stream, + const void *x, const Nd4jLong *xShape, + int xRank, void *extraParams, void *z, + const Nd4jLong *zShape, int zRank, + int *allocationPointer, + void *reductionPointer, + const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets); + + static _CUDA_H void executeTransformShaped( + dim3 launchDims, cudaStream_t *stream, int opNum, const void *x, + const Nd4jLong *xShape, int xRank, void *extraParams, void *z, + const Nd4jLong *zShape, int zRank, int *allocationPointer, + void *reductionPointer, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets); #else - static void exec(int opNum, - const void *dx, const Nd4jLong *xShapeInfo, - void *result, const Nd4jLong *resultShapeInfo, - void *extraParams, - uint64_t threadId, uint64_t numThreads); + static void exec(int opNum, const void *dx, const Nd4jLong *xShapeInfo, + void *result, const Nd4jLong *resultShapeInfo, + void *extraParams, uint64_t threadId, uint64_t numThreads); - template - static ND4J_EXPORT void exec(const void *dx, const Nd4jLong *xShapeInfo, - void *result, const Nd4jLong *resultShapeInfo, - void *extraParams, - uint64_t threadId, uint64_t numThreads); + template + static SD_EXPORT void exec(const void *dx, const Nd4jLong *xShapeInfo, + void *result, const Nd4jLong *resultShapeInfo, + void *extraParams, uint64_t threadId, + uint64_t numThreads); #endif - }; - } -} - +}; +} // namespace transform +} // namespace functions #endif /* TRANSFORM_H_ */ diff --git a/libnd4j/include/loops/transform_strict.h b/libnd4j/include/loops/transform_strict.h index 903f4e9df371..b5da49ff839e 100644 --- a/libnd4j/include/loops/transform_strict.h +++ b/libnd4j/include/loops/transform_strict.h @@ -24,15 +24,16 @@ #ifndef TRANSFORM_STRICT_H_ #define TRANSFORM_STRICT_H_ -#include #include #include +#include + #ifdef _OPENMP #include #endif -#include #include +#include //#include //#include @@ -46,60 +47,53 @@ #include "legacy_ops.h" - namespace functions { - namespace transform { - - template - class TransformStrict { - public: +namespace transform { +template +class TransformStrict { + public: #ifdef __CUDACC__ - template - static __device__ void transformCuda(const void *dy, const Nd4jLong *shapeInfo, - void *params, - void *result, const Nd4jLong *resultShapeInfo, - int *allocationPointer, void *reductionPointer, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets); - - - template - static _CUDA_H void intermediateShaped(dim3 launchDims, cudaStream_t *stream, - const void *x, const Nd4jLong *xShape, int xRank, - void *extraParams, - void *z, const Nd4jLong *zShape, int zRank, - int *allocationPointer, void *reductionPointer, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets); - - static _CUDA_H void executeTransformShaped(dim3 launchDims, cudaStream_t *stream, - int opNum, - const void *x, const Nd4jLong *xShape, int xRank, - void *extraParams, - void *z, const Nd4jLong *zShape, int zRank, - int *allocationPointer, void *reductionPointer, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets); + template + static __device__ void transformCuda( + const void *dy, const Nd4jLong *shapeInfo, void *params, void *result, + const Nd4jLong *resultShapeInfo, int *allocationPointer, + void *reductionPointer, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets); + + template + static _CUDA_H void intermediateShaped(dim3 launchDims, cudaStream_t *stream, + const void *x, const Nd4jLong *xShape, + int xRank, void *extraParams, void *z, + const Nd4jLong *zShape, int zRank, + int *allocationPointer, + void *reductionPointer, + const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets); + + static _CUDA_H void executeTransformShaped( + dim3 launchDims, cudaStream_t *stream, int opNum, const void *x, + const Nd4jLong *xShape, int xRank, void *extraParams, void *z, + const Nd4jLong *zShape, int zRank, int *allocationPointer, + void *reductionPointer, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets); #else + static void exec(int opNum, const void *dx, const Nd4jLong *xShapeInfo, + void *result, const Nd4jLong *resultShapeInfo, + void *extraParams, uint64_t threadId, uint64_t numThreads); - - static void exec(int opNum, - const void *dx, const Nd4jLong *xShapeInfo, - void *result, const Nd4jLong *resultShapeInfo, - void *extraParams, - uint64_t threadId, uint64_t numThreads); - - template - static ND4J_EXPORT void exec(const void *dx, const Nd4jLong *xShapeInfo, - void *result, const Nd4jLong *resultShapeInfo, - void *extraParams, - uint64_t threadId, uint64_t numThreads); + template + static SD_EXPORT void exec(const void *dx, const Nd4jLong *xShapeInfo, + void *result, const Nd4jLong *resultShapeInfo, + void *extraParams, uint64_t threadId, + uint64_t numThreads); #endif - }; - } -} - +}; +} // namespace transform +} // namespace functions #endif /* TRANSFORM_H_ */ diff --git a/libnd4j/include/loops/type_conversions.h b/libnd4j/include/loops/type_conversions.h index b56921435b16..300accd81f3c 100644 --- a/libnd4j/include/loops/type_conversions.h +++ b/libnd4j/include/loops/type_conversions.h @@ -15,8 +15,8 @@ ******************************************************************************/ /* - * This set of methods provides dataType conversions in all possible directions supported: - * FP8, FP16, FLOAT, DOUBLE, INT8, UINT8, UINT16, + * This set of methods provides dataType conversions in all possible directions + * supported: FP8, FP16, FLOAT, DOUBLE, INT8, UINT8, UINT16, * * @author raver119@gmail.com */ @@ -33,118 +33,126 @@ #define ND4J_FLOAT32 6 #define ND4J_DOUBLE 7 #define ND4J_THRESHOLD 8 -#define ND4J_FLOAT24 119 // not supported after all. might want to add support later. +#define ND4J_FLOAT24 \ + 119 // not supported after all. might want to add support later. -#include #include +#include +#include #include #include -#include -#include #include +#include #include -#include +#include #define NUM_BANKS 32 #define LOG_NUM_BANKS 4 - namespace sd { - typedef union { - float f_; - int i_; - } FloatBits; - - - class TypeCast { - - public: - template - static _CUDA_H void convertGeneric(Nd4jPointer * extras, void *dx, Nd4jLong N, void *dz); +typedef union { + float f_; + int i_; +} FloatBits; - template - static _CUDA_H void convertToThreshold(Nd4jPointer * extras, void *dx, Nd4jLong N, void *dz); +class TypeCast { + public: + template + static _CUDA_H void convertGeneric(Nd4jPointer *extras, void *dx, Nd4jLong N, + void *dz); - template - static _CUDA_H void convertFromThreshold(Nd4jPointer * extras, const void *dx, Nd4jLong N, void *dz); + template + static _CUDA_H void convertToThreshold(Nd4jPointer *extras, void *dx, + Nd4jLong N, void *dz); - FORCEINLINE static _CUDA_H Nd4jLong estimateQuantizedSize(Nd4jLong rawSize) { - if (rawSize <= 0) - throw std::runtime_error("Input size for quantization can't be <= 0"); + template + static _CUDA_H void convertFromThreshold(Nd4jPointer *extras, const void *dx, + Nd4jLong N, void *dz); - // 2 fp32 values for max/min, and rawSize number of BYTES - return 8 + rawSize; - } + FORCEINLINE static _CUDA_H Nd4jLong estimateQuantizedSize(Nd4jLong rawSize) { + if (rawSize <= 0) + throw std::runtime_error("Input size for quantization can't be <= 0"); + // 2 fp32 values for max/min, and rawSize number of BYTES + return 8 + rawSize; + } - template - static _CUDA_H void convertToQuantized(Nd4jPointer *extras, void *dx, Nd4jLong N, void *dz); + template + static _CUDA_H void convertToQuantized(Nd4jPointer *extras, void *dx, + Nd4jLong N, void *dz); - template - static _CUDA_H void convertFromQuantized(Nd4jPointer *extras, void *dx, Nd4jLong N, void *dz); - - #ifdef __CUDACC__ - template - static _CUDA_H void convertGenericCuda(Nd4jPointer * extras, void *dx, Nd4jLong N, void *dz); - #endif - }; + template + static _CUDA_H void convertFromQuantized(Nd4jPointer *extras, void *dx, + Nd4jLong N, void *dz); +#ifdef __CUDACC__ + template + static _CUDA_H void convertGenericCuda(Nd4jPointer *extras, void *dx, + Nd4jLong N, void *dz); +#endif +}; - FORCEINLINE _CUDA_HD bool isPowerOfTwo(int n) { - return ((n&(n-1))==0) ; - } +FORCEINLINE _CUDA_HD bool isPowerOfTwo(int n) { return ((n & (n - 1)) == 0); } - FORCEINLINE _CUDA_HD int floorPow2(int n) { +FORCEINLINE _CUDA_HD int floorPow2(int n) { #ifdef WIN32 - // method 2 - return 1 << static_cast(logb(static_cast(n))); + // method 2 + return 1 << static_cast(logb(static_cast(n))); #else - // method 1 - // float nf = (float)n; - // return 1 << (((*(int*)&nf) >> 23) - 127); - int exp; - frexp(static_cast(n), &exp); - return 1 << (exp - 1); + // method 1 + // float nf = (float)n; + // return 1 << (((*(int*)&nf) >> 23) - 127); + int exp; + frexp(static_cast(n), &exp); + return 1 << (exp - 1); #endif - } +} #ifdef __CUDACC__ - __device__ __inline__ int pow2i (int e){ - return 1< - __host__ void encoderKernelP1Generic(dim3 &launchDims, cudaStream_t *stream, const void *dx, Nd4jLong N, void *dz, float threshold); - - - template - __host__ void encoderKernelP3Generic(dim3 &launchDims, cudaStream_t *stream, void *dx, int *offsets, Nd4jLong N, void *dz); - - - template - __host__ void decoderKernelGeneric(dim3 &launchDims, cudaStream_t *stream, const void *dx, Nd4jLong N, void *dz); - - template - __host__ void cudaEncodeBitmapGeneric(dim3 &launchDims, cudaStream_t *stream, void *vdx, Nd4jLong N, int *dz, int *scalar, int *reductionBuffer, float threshold); - - - template - __host__ void cudaDecodeBitmapGeneric(dim3 &launchDims, cudaStream_t *stream, const void *dx, Nd4jLong N, void *vdz); - - __global__ void uniformAdd(int *g_data, int *uniforms, int n, int blockOffset, int baseIndex); - - template - __global__ void prescan(int *g_odata, const int *g_idata, int *g_blockSums, int n, int blockIndex, int baseIndex); - - - template - __host__ void prescanLauncher(dim3 &blocks, dim3 &threads, int shmem, cudaStream_t *stream, int *g_odata, const int *g_idata, int *g_blockSums, int n, int blockIndex, int baseIndex); - - template - __global__ void convertKernel(void *dx, Nd4jLong N, void *dz); +__device__ __inline__ int pow2i(int e) { return 1 << e; } + +template +__host__ void encoderKernelP1Generic(dim3 &launchDims, cudaStream_t *stream, + const void *dx, Nd4jLong N, void *dz, + float threshold); + +template +__host__ void encoderKernelP3Generic(dim3 &launchDims, cudaStream_t *stream, + void *dx, int *offsets, Nd4jLong N, + void *dz); + +template +__host__ void decoderKernelGeneric(dim3 &launchDims, cudaStream_t *stream, + const void *dx, Nd4jLong N, void *dz); + +template +__host__ void cudaEncodeBitmapGeneric(dim3 &launchDims, cudaStream_t *stream, + void *vdx, Nd4jLong N, int *dz, + int *scalar, int *reductionBuffer, + float threshold); + +template +__host__ void cudaDecodeBitmapGeneric(dim3 &launchDims, cudaStream_t *stream, + const void *dx, Nd4jLong N, void *vdz); + +__global__ void uniformAdd(int *g_data, int *uniforms, int n, int blockOffset, + int baseIndex); + +template +__global__ void prescan(int *g_odata, const int *g_idata, int *g_blockSums, + int n, int blockIndex, int baseIndex); + +template +__host__ void prescanLauncher(dim3 &blocks, dim3 &threads, int shmem, + cudaStream_t *stream, int *g_odata, + const int *g_idata, int *g_blockSums, int n, + int blockIndex, int baseIndex); + +template +__global__ void convertKernel(void *dx, Nd4jLong N, void *dz); #endif -} +} // namespace sd -#endif //LIBND4J_TYPE_CONVERSIONS_H +#endif // LIBND4J_TYPE_CONVERSIONS_H diff --git a/libnd4j/include/math/platformmath.h b/libnd4j/include/math/platformmath.h index e4990cc87338..f12ca2295106 100644 --- a/libnd4j/include/math/platformmath.h +++ b/libnd4j/include/math/platformmath.h @@ -22,60 +22,54 @@ #define LIBND4J_PLATFORM_MATH_H #include -#include #include #include +#include + #ifdef __CUDACC__ -#include #include +#include union BPAIR { - struct { - bfloat16 H; - bfloat16 L; - } B; - int W; + struct { + bfloat16 H; + bfloat16 L; + } B; + int W; - __host__ __device__ - BPAIR() {}; + __host__ __device__ BPAIR(){}; - __host__ __device__ - ~BPAIR() {}; + __host__ __device__ ~BPAIR(){}; }; #define math_def __host__ __device__ #ifdef CUDA_8 typedef union { - struct { - half H; - half L; - } B; - int W; + struct { + half H; + half L; + } B; + int W; } PAIR; #else -struct HALFS{ - half H; - half L; +struct HALFS { + half H; + half L; - __host__ __device__ - HALFS() {}; + __host__ __device__ HALFS(){}; - __host__ __device__ - ~HALFS() {}; - }; + __host__ __device__ ~HALFS(){}; +}; union PAIR { - HALFS B; - int W; - - __host__ __device__ - PAIR() {}; + HALFS B; + int W; - __host__ __device__ - ~PAIR(){} + __host__ __device__ PAIR(){}; + __host__ __device__ ~PAIR() {} }; -#endif // cuda_9 +#endif // cuda_9 #else #define math_def @@ -83,792 +77,799 @@ union PAIR { #endif - namespace sd { - namespace math { - template - math_def FORCEINLINE T p_exp(T value); +namespace math { +template +math_def FORCEINLINE T p_exp(T value); - template - math_def FORCEINLINE T p_log(T value); +template +math_def FORCEINLINE T p_log(T value); - template - math_def FORCEINLINE T p_floor(T value); +template +math_def FORCEINLINE T p_floor(T value); - template - math_def FORCEINLINE T p_ceil(T value); +template +math_def FORCEINLINE T p_ceil(T value); - template - math_def FORCEINLINE T p_round(T value); +template +math_def FORCEINLINE T p_round(T value); - template - math_def FORCEINLINE T p_cos(T value); +template +math_def FORCEINLINE T p_cos(T value); - template - math_def FORCEINLINE T p_cosh(T value); +template +math_def FORCEINLINE T p_cosh(T value); - template - math_def FORCEINLINE T p_acos(T value); +template +math_def FORCEINLINE T p_acos(T value); - template - math_def FORCEINLINE T p_acosh(T value); +template +math_def FORCEINLINE T p_acosh(T value); - template - math_def FORCEINLINE T p_sin(T value); +template +math_def FORCEINLINE T p_sin(T value); - template - math_def FORCEINLINE T p_sinh(T value); +template +math_def FORCEINLINE T p_sinh(T value); - template - math_def FORCEINLINE T p_asin(T value); +template +math_def FORCEINLINE T p_asin(T value); - template - math_def FORCEINLINE T p_sqrt(T value); +template +math_def FORCEINLINE T p_sqrt(T value); - template - math_def FORCEINLINE T p_tanh(T value); +template +math_def FORCEINLINE T p_tanh(T value); - template - math_def FORCEINLINE T p_erf(T value); +template +math_def FORCEINLINE T p_erf(T value); - template - math_def FORCEINLINE T p_erfc(T value); +template +math_def FORCEINLINE T p_erfc(T value); - template - math_def FORCEINLINE T p_atan(T value); +template +math_def FORCEINLINE T p_atan(T value); - template - math_def FORCEINLINE T p_tan(T value); +template +math_def FORCEINLINE T p_tan(T value); - template - math_def FORCEINLINE T p_atanh(T value); +template +math_def FORCEINLINE T p_atanh(T value); - template - math_def FORCEINLINE T p_rint(T value); +template +math_def FORCEINLINE T p_rint(T value); - template - math_def FORCEINLINE T p_rotl(T value, T shift); +template +math_def FORCEINLINE T p_rotl(T value, T shift); - template - math_def FORCEINLINE T p_rotr(T value, T shift); +template +math_def FORCEINLINE T p_rotr(T value, T shift); - template - math_def FORCEINLINE T p_remainder(T val1, T val2); +template +math_def FORCEINLINE T p_remainder(T val1, T val2); - template - math_def FORCEINLINE T p_fmod(T val1, T val2); +template +math_def FORCEINLINE T p_fmod(T val1, T val2); - template - math_def FORCEINLINE T p_pow(T value, T power); +template +math_def FORCEINLINE T p_pow(T value, T power); - template - math_def FORCEINLINE T p_atan2(T val1, T val2); +template +math_def FORCEINLINE T p_atan2(T val1, T val2); ////// - template <> - math_def FORCEINLINE float p_exp(float value) { - return expf(value); - } +template <> +math_def FORCEINLINE float p_exp(float value) { + return expf(value); +} - template <> - math_def FORCEINLINE float16 p_exp(float16 val) { +template <> +math_def FORCEINLINE float16 p_exp(float16 val) { #ifdef NATIVE_HALFS - return hexp(val.data); + return hexp(val.data); #else - return static_cast(expf((float) val)); + return static_cast(expf((float)val)); #endif - } +} - template <> - math_def FORCEINLINE bfloat16 p_exp(bfloat16 val) { - return static_cast(expf((float) val)); - } +template <> +math_def FORCEINLINE bfloat16 p_exp(bfloat16 val) { + return static_cast(expf((float)val)); +} - template <> - math_def FORCEINLINE double p_exp(double value) { - return exp(value); - } +template <> +math_def FORCEINLINE double p_exp(double value) { + return exp(value); +} - template - math_def FORCEINLINE T p_exp(T value) { - return static_cast(expf(static_cast(value))); - } +template +math_def FORCEINLINE T p_exp(T value) { + return static_cast(expf(static_cast(value))); +} ///////// - template <> - math_def FORCEINLINE float16 p_pow(float16 value, float16 power) { - return static_cast(powf(static_cast(value), static_cast(power))); - } - - template <> - math_def FORCEINLINE bfloat16 p_pow(bfloat16 value, bfloat16 power) { - return static_cast(powf(static_cast(value), static_cast(power))); - } - - template <> - math_def FORCEINLINE float p_pow(float value, float power) { - return powf(value, power); - } - - template <> - math_def FORCEINLINE double p_pow(double value, double power) { - return pow(value, power); - } - - template - math_def FORCEINLINE T p_pow(T value, T power) { - return static_cast(powf(static_cast(value), static_cast(power))); - } +template <> +math_def FORCEINLINE float16 p_pow(float16 value, float16 power) { + return static_cast( + powf(static_cast(value), static_cast(power))); +} + +template <> +math_def FORCEINLINE bfloat16 p_pow(bfloat16 value, bfloat16 power) { + return static_cast( + powf(static_cast(value), static_cast(power))); +} + +template <> +math_def FORCEINLINE float p_pow(float value, float power) { + return powf(value, power); +} + +template <> +math_def FORCEINLINE double p_pow(double value, double power) { + return pow(value, power); +} + +template +math_def FORCEINLINE T p_pow(T value, T power) { + return static_cast( + powf(static_cast(value), static_cast(power))); +} ///////// - template <> - math_def FORCEINLINE float16 p_fmod(float16 value, float16 power) { - return static_cast(fmodf(static_cast(value), static_cast(power))); - } +template <> +math_def FORCEINLINE float16 p_fmod(float16 value, float16 power) { + return static_cast( + fmodf(static_cast(value), static_cast(power))); +} - template <> - math_def FORCEINLINE bfloat16 p_fmod(bfloat16 value, bfloat16 power) { - return static_cast(fmodf(static_cast(value), static_cast(power))); - } +template <> +math_def FORCEINLINE bfloat16 p_fmod(bfloat16 value, bfloat16 power) { + return static_cast( + fmodf(static_cast(value), static_cast(power))); +} - template <> - math_def FORCEINLINE float p_fmod(float value, float power) { - return fmodf(value, power); - } +template <> +math_def FORCEINLINE float p_fmod(float value, float power) { + return fmodf(value, power); +} - template <> - math_def FORCEINLINE double p_fmod(double value, double power) { - return fmod(value, power); - } +template <> +math_def FORCEINLINE double p_fmod(double value, double power) { + return fmod(value, power); +} - template - math_def FORCEINLINE T p_fmod(T value, T power) { - return static_cast(fmodf(static_cast(value), static_cast(power))); - } +template +math_def FORCEINLINE T p_fmod(T value, T power) { + return static_cast( + fmodf(static_cast(value), static_cast(power))); +} ///////// - template <> - math_def FORCEINLINE float16 p_atan2(float16 value, float16 power) { - return static_cast(atan2f(static_cast(value), static_cast(power))); - } +template <> +math_def FORCEINLINE float16 p_atan2(float16 value, float16 power) { + return static_cast( + atan2f(static_cast(value), static_cast(power))); +} - template <> - math_def FORCEINLINE float p_atan2(float value, float power) { - return atan2f(value, power); - } +template <> +math_def FORCEINLINE float p_atan2(float value, float power) { + return atan2f(value, power); +} - template <> - math_def FORCEINLINE double p_atan2(double value, double power) { - return atan2(value, power); - } +template <> +math_def FORCEINLINE double p_atan2(double value, double power) { + return atan2(value, power); +} - template - math_def FORCEINLINE T p_atan2(T value, T power) { - return static_cast(atan2f(static_cast(value), static_cast(power))); - } +template +math_def FORCEINLINE T p_atan2(T value, T power) { + return static_cast( + atan2f(static_cast(value), static_cast(power))); +} ///////// - template <> - math_def FORCEINLINE float16 p_remainder(float16 value, float16 power) { - return static_cast(remainderf(static_cast(value), static_cast(power))); - } - - template <> - math_def FORCEINLINE float p_remainder(float value, float power) { - return remainderf(value, power); - } - - template <> - math_def FORCEINLINE double p_remainder(double value, double power) { - return remainder(value, power); - } - - template - math_def FORCEINLINE T p_remainder(T value, T power) { - return static_cast(remainderf(static_cast(value), static_cast(power))); - } +template <> +math_def FORCEINLINE float16 p_remainder(float16 value, float16 power) { + return static_cast( + remainderf(static_cast(value), static_cast(power))); +} + +template <> +math_def FORCEINLINE float p_remainder(float value, float power) { + return remainderf(value, power); +} + +template <> +math_def FORCEINLINE double p_remainder(double value, double power) { + return remainder(value, power); +} + +template +math_def FORCEINLINE T p_remainder(T value, T power) { + return static_cast( + remainderf(static_cast(value), static_cast(power))); +} ///////// - template <> - math_def FORCEINLINE float p_log(float value) { - return logf(value); - } +template <> +math_def FORCEINLINE float p_log(float value) { + return logf(value); +} - template <> - math_def FORCEINLINE float16 p_log(float16 val) { +template <> +math_def FORCEINLINE float16 p_log(float16 val) { #ifdef NATIVE_HALFS - return hlog(val.data); + return hlog(val.data); #else - return static_cast(logf((float) val)); + return static_cast(logf((float)val)); #endif - } +} - template <> - math_def FORCEINLINE double p_log(double value) { - return log(value); - } +template <> +math_def FORCEINLINE double p_log(double value) { + return log(value); +} - template - math_def FORCEINLINE T p_log(T value) { - return static_cast(logf(static_cast(value))); - } +template +math_def FORCEINLINE T p_log(T value) { + return static_cast(logf(static_cast(value))); +} ///////// - template <> - math_def FORCEINLINE float p_floor(float value) { - return floorf(value); - } +template <> +math_def FORCEINLINE float p_floor(float value) { + return floorf(value); +} - template <> - math_def FORCEINLINE float16 p_floor(float16 val) { +template <> +math_def FORCEINLINE float16 p_floor(float16 val) { #ifdef NATIVE_HALFS - return hfloor(val.data); + return hfloor(val.data); #else - return static_cast(floorf((float) val)); + return static_cast(floorf((float)val)); #endif - } +} - template <> - math_def FORCEINLINE bfloat16 p_floor(bfloat16 value) { - return static_cast(floorf((float)value)); - } +template <> +math_def FORCEINLINE bfloat16 p_floor(bfloat16 value) { + return static_cast(floorf((float)value)); +} - template <> - math_def FORCEINLINE double p_floor(double value) { - return floor(value); - } +template <> +math_def FORCEINLINE double p_floor(double value) { + return floor(value); +} - template - math_def FORCEINLINE T p_floor(T value) { - return value; - } +template +math_def FORCEINLINE T p_floor(T value) { + return value; +} ///////// - template <> - math_def FORCEINLINE float p_ceil(float value) { - return ceilf(value); - } +template <> +math_def FORCEINLINE float p_ceil(float value) { + return ceilf(value); +} - template <> - math_def FORCEINLINE float16 p_ceil(float16 val) { +template <> +math_def FORCEINLINE float16 p_ceil(float16 val) { #ifdef NATIVE_HALFS - return hceil(val.data); + return hceil(val.data); #else - return static_cast(ceilf((float) val)); + return static_cast(ceilf((float)val)); #endif - } +} - template <> - math_def FORCEINLINE bfloat16 p_ceil(bfloat16 value) { - return static_cast(ceilf((float)value)); - } +template <> +math_def FORCEINLINE bfloat16 p_ceil(bfloat16 value) { + return static_cast(ceilf((float)value)); +} - template <> - math_def FORCEINLINE double p_ceil(double value) { - return ceil(value); - } +template <> +math_def FORCEINLINE double p_ceil(double value) { + return ceil(value); +} - template - math_def FORCEINLINE T p_ceil(T value) { - return value; - } +template +math_def FORCEINLINE T p_ceil(T value) { + return value; +} ///////// - template <> - math_def FORCEINLINE float p_round(float value) { - return roundf(value); - } - - template <> - math_def FORCEINLINE float16 p_round(float16 val) { - return static_cast(roundf((float) val)); - } +template <> +math_def FORCEINLINE float p_round(float value) { + return roundf(value); +} - template <> - math_def FORCEINLINE bfloat16 p_round(bfloat16 value) { - return static_cast(roundf((float)value)); - } +template <> +math_def FORCEINLINE float16 p_round(float16 val) { + return static_cast(roundf((float)val)); +} +template <> +math_def FORCEINLINE bfloat16 p_round(bfloat16 value) { + return static_cast(roundf((float)value)); +} - template <> - math_def FORCEINLINE double p_round(double value) { - return round(value); - } +template <> +math_def FORCEINLINE double p_round(double value) { + return round(value); +} - template - math_def FORCEINLINE T p_round(T value) { - return value; - } +template +math_def FORCEINLINE T p_round(T value) { + return value; +} ///////// - template <> - math_def FORCEINLINE float p_rint(float value) { - return rintf(value); - } +template <> +math_def FORCEINLINE float p_rint(float value) { + return rintf(value); +} - template <> - math_def FORCEINLINE float16 p_rint(float16 val) { +template <> +math_def FORCEINLINE float16 p_rint(float16 val) { #ifdef NATIVE_HALFS - return hrint(val.data); + return hrint(val.data); #else - return static_cast(rintf((float) val)); + return static_cast(rintf((float)val)); #endif - } +} - template <> - math_def FORCEINLINE bfloat16 p_rint(bfloat16 val) { - return static_cast(rintf((float) val)); - } +template <> +math_def FORCEINLINE bfloat16 p_rint(bfloat16 val) { + return static_cast(rintf((float)val)); +} - template <> - math_def FORCEINLINE double p_rint(double value) { - return rint(value); - } +template <> +math_def FORCEINLINE double p_rint(double value) { + return rint(value); +} - template - math_def FORCEINLINE T p_rint(T value) { - return value; - } +template +math_def FORCEINLINE T p_rint(T value) { + return value; +} ///////// - template <> - math_def FORCEINLINE float p_cos(float value) { - return cosf(value); - } +template <> +math_def FORCEINLINE float p_cos(float value) { + return cosf(value); +} - template <> - math_def FORCEINLINE float16 p_cos(float16 val) { +template <> +math_def FORCEINLINE float16 p_cos(float16 val) { #ifdef NATIVE_HALFS - return hcos(val.data); + return hcos(val.data); #else - return static_cast(cosf((float) val)); + return static_cast(cosf((float)val)); #endif - } +} - template <> - math_def FORCEINLINE bfloat16 p_cos(bfloat16 val) { - return static_cast(cosf((float) val)); - } +template <> +math_def FORCEINLINE bfloat16 p_cos(bfloat16 val) { + return static_cast(cosf((float)val)); +} - template <> - math_def FORCEINLINE double p_cos(double value) { - return cos(value); - } +template <> +math_def FORCEINLINE double p_cos(double value) { + return cos(value); +} ///////// - template <> - math_def FORCEINLINE float p_sin(float value) { - return sinf(value); - } +template <> +math_def FORCEINLINE float p_sin(float value) { + return sinf(value); +} - template <> - math_def FORCEINLINE float16 p_sin(float16 val) { +template <> +math_def FORCEINLINE float16 p_sin(float16 val) { #ifdef NATIVE_HALFS - return hsin(val.data); + return hsin(val.data); #else - return static_cast(sinf((float) val)); + return static_cast(sinf((float)val)); #endif - } +} - template <> - math_def FORCEINLINE bfloat16 p_sin(bfloat16 val) { - return static_cast(sinf((float) val)); - } +template <> +math_def FORCEINLINE bfloat16 p_sin(bfloat16 val) { + return static_cast(sinf((float)val)); +} - template <> - math_def FORCEINLINE double p_sin(double value) { - return sin(value); - } +template <> +math_def FORCEINLINE double p_sin(double value) { + return sin(value); +} ///////// - template <> - math_def FORCEINLINE float p_sqrt(float value) { - return sqrtf(value); - } +template <> +math_def FORCEINLINE float p_sqrt(float value) { + return sqrtf(value); +} - template <> - math_def FORCEINLINE float16 p_sqrt(float16 val) { +template <> +math_def FORCEINLINE float16 p_sqrt(float16 val) { #ifdef NATIVE_HALFS - return hsqrt(val.data); + return hsqrt(val.data); #else - return static_cast(sqrtf((float) val)); + return static_cast(sqrtf((float)val)); #endif - } - template <> - math_def FORCEINLINE bfloat16 p_sqrt(bfloat16 val) { - return static_cast(sqrtf((float) val)); - } +} +template <> +math_def FORCEINLINE bfloat16 p_sqrt(bfloat16 val) { + return static_cast(sqrtf((float)val)); +} - template <> - math_def FORCEINLINE double p_sqrt(double value) { - return sqrt(value); - } +template <> +math_def FORCEINLINE double p_sqrt(double value) { + return sqrt(value); +} ///////// - template <> - math_def FORCEINLINE float p_tanh(float value) { - return tanhf(value); - } +template <> +math_def FORCEINLINE float p_tanh(float value) { + return tanhf(value); +} - template <> - math_def FORCEINLINE float16 p_tanh(float16 val) { - return static_cast(tanhf((float) val)); - } +template <> +math_def FORCEINLINE float16 p_tanh(float16 val) { + return static_cast(tanhf((float)val)); +} - template <> - math_def FORCEINLINE bfloat16 p_tanh(bfloat16 val) { - return static_cast(tanhf((float) val)); - } +template <> +math_def FORCEINLINE bfloat16 p_tanh(bfloat16 val) { + return static_cast(tanhf((float)val)); +} - template <> - math_def FORCEINLINE double p_tanh(double value) { - return tanh(value); - } +template <> +math_def FORCEINLINE double p_tanh(double value) { + return tanh(value); +} ///////// - template <> - math_def FORCEINLINE float p_erf(float value) { - return erff(value); - } +template <> +math_def FORCEINLINE float p_erf(float value) { + return erff(value); +} - template <> - math_def FORCEINLINE float16 p_erf(float16 val) { - return static_cast(erff((float) val)); - } +template <> +math_def FORCEINLINE float16 p_erf(float16 val) { + return static_cast(erff((float)val)); +} - template <> - math_def FORCEINLINE bfloat16 p_erf(bfloat16 val) { - return static_cast(erff((float) val)); - } +template <> +math_def FORCEINLINE bfloat16 p_erf(bfloat16 val) { + return static_cast(erff((float)val)); +} - template <> - math_def FORCEINLINE double p_erf(double value) { - return erf(value); - } +template <> +math_def FORCEINLINE double p_erf(double value) { + return erf(value); +} ///////// - template <> - math_def FORCEINLINE float p_erfc(float value) { - return erfcf(value); - } +template <> +math_def FORCEINLINE float p_erfc(float value) { + return erfcf(value); +} - template <> - math_def FORCEINLINE float16 p_erfc(float16 val) { - return static_cast(erfcf((float) val)); - } +template <> +math_def FORCEINLINE float16 p_erfc(float16 val) { + return static_cast(erfcf((float)val)); +} - template <> - math_def FORCEINLINE bfloat16 p_erfc(bfloat16 val) { - return static_cast(erfcf((float) val)); - } +template <> +math_def FORCEINLINE bfloat16 p_erfc(bfloat16 val) { + return static_cast(erfcf((float)val)); +} - template <> - math_def FORCEINLINE double p_erfc(double value) { - return erfc(value); - } +template <> +math_def FORCEINLINE double p_erfc(double value) { + return erfc(value); +} ///////// - template <> - math_def FORCEINLINE float p_acos(float value) { - return acosf(value); - } +template <> +math_def FORCEINLINE float p_acos(float value) { + return acosf(value); +} - template <> - math_def FORCEINLINE float16 p_acos(float16 val) { - return static_cast(acosf((float) val)); - } +template <> +math_def FORCEINLINE float16 p_acos(float16 val) { + return static_cast(acosf((float)val)); +} - template <> - math_def FORCEINLINE bfloat16 p_acos(bfloat16 val) { - return static_cast(acosf((float) val)); - } +template <> +math_def FORCEINLINE bfloat16 p_acos(bfloat16 val) { + return static_cast(acosf((float)val)); +} - template <> - math_def FORCEINLINE double p_acos(double value) { - return acos(value); - } +template <> +math_def FORCEINLINE double p_acos(double value) { + return acos(value); +} ///////// - template <> - math_def FORCEINLINE float p_sinh(float value) { - return sinhf(value); - } +template <> +math_def FORCEINLINE float p_sinh(float value) { + return sinhf(value); +} - template <> - math_def FORCEINLINE float16 p_sinh(float16 val) { - return static_cast(sinhf((float) val)); - } +template <> +math_def FORCEINLINE float16 p_sinh(float16 val) { + return static_cast(sinhf((float)val)); +} - template <> - math_def FORCEINLINE bfloat16 p_sinh(bfloat16 val) { - return static_cast(sinhf((float) val)); - } +template <> +math_def FORCEINLINE bfloat16 p_sinh(bfloat16 val) { + return static_cast(sinhf((float)val)); +} - template <> - math_def FORCEINLINE double p_sinh(double value) { - return sinh(value); - } +template <> +math_def FORCEINLINE double p_sinh(double value) { + return sinh(value); +} ///////// - template <> - math_def FORCEINLINE float p_acosh(float value) { - return acoshf(value); - } +template <> +math_def FORCEINLINE float p_acosh(float value) { + return acoshf(value); +} - template <> - math_def FORCEINLINE float16 p_acosh(float16 val) { - return static_cast(acoshf((float) val)); - } +template <> +math_def FORCEINLINE float16 p_acosh(float16 val) { + return static_cast(acoshf((float)val)); +} - template <> - math_def FORCEINLINE bfloat16 p_acosh(bfloat16 val) { - return static_cast(acoshf((float) val)); - } +template <> +math_def FORCEINLINE bfloat16 p_acosh(bfloat16 val) { + return static_cast(acoshf((float)val)); +} - template <> - math_def FORCEINLINE double p_acosh(double value) { - return acosh(value); - } +template <> +math_def FORCEINLINE double p_acosh(double value) { + return acosh(value); +} ///////// - template <> - math_def FORCEINLINE float p_cosh(float value) { - return coshf(value); - } - - template <> - math_def FORCEINLINE float16 p_cosh(float16 val) { - return static_cast(coshf((float) val)); - } +template <> +math_def FORCEINLINE float p_cosh(float value) { + return coshf(value); +} - template <> - math_def FORCEINLINE bfloat16 p_cosh(bfloat16 val) { - return static_cast(coshf((float) val)); - } +template <> +math_def FORCEINLINE float16 p_cosh(float16 val) { + return static_cast(coshf((float)val)); +} - template <> - math_def FORCEINLINE double p_cosh(double value) { - return cosh(value); - } +template <> +math_def FORCEINLINE bfloat16 p_cosh(bfloat16 val) { + return static_cast(coshf((float)val)); +} +template <> +math_def FORCEINLINE double p_cosh(double value) { + return cosh(value); +} ///////// - template <> - math_def FORCEINLINE float p_asin(float value) { - return asinf(value); - } +template <> +math_def FORCEINLINE float p_asin(float value) { + return asinf(value); +} - template <> - math_def FORCEINLINE float16 p_asin(float16 val) { - return static_cast(asinf((float) val)); - } +template <> +math_def FORCEINLINE float16 p_asin(float16 val) { + return static_cast(asinf((float)val)); +} - template <> - math_def FORCEINLINE bfloat16 p_asin(bfloat16 val) { - return static_cast(asinf((float) val)); - } +template <> +math_def FORCEINLINE bfloat16 p_asin(bfloat16 val) { + return static_cast(asinf((float)val)); +} - template <> - math_def FORCEINLINE double p_asin(double value) { - return asin(value); - } +template <> +math_def FORCEINLINE double p_asin(double value) { + return asin(value); +} ///////// - template <> - math_def FORCEINLINE float p_atan(float value) { - return atanf(value); - } - - template <> - math_def FORCEINLINE float16 p_atan(float16 val) { - return static_cast(atanf((float) val)); - } +template <> +math_def FORCEINLINE float p_atan(float value) { + return atanf(value); +} - template <> - math_def FORCEINLINE bfloat16 p_atan(bfloat16 val) { - return static_cast(atanf((float) val)); - } +template <> +math_def FORCEINLINE float16 p_atan(float16 val) { + return static_cast(atanf((float)val)); +} - template <> - math_def FORCEINLINE double p_atan(double value) { - return atan(value); - } +template <> +math_def FORCEINLINE bfloat16 p_atan(bfloat16 val) { + return static_cast(atanf((float)val)); +} +template <> +math_def FORCEINLINE double p_atan(double value) { + return atan(value); +} ///////// - template <> - math_def FORCEINLINE float p_tan(float value) { - return tanf(value); - } +template <> +math_def FORCEINLINE float p_tan(float value) { + return tanf(value); +} - template <> - math_def FORCEINLINE float16 p_tan(float16 val) { - return static_cast(tanf((float) val)); - } +template <> +math_def FORCEINLINE float16 p_tan(float16 val) { + return static_cast(tanf((float)val)); +} - template <> - math_def FORCEINLINE bfloat16 p_tan(bfloat16 val) { - return static_cast(tanf((float) val)); - } +template <> +math_def FORCEINLINE bfloat16 p_tan(bfloat16 val) { + return static_cast(tanf((float)val)); +} - template <> - math_def FORCEINLINE double p_tan(double value) { - return tan(value); - } +template <> +math_def FORCEINLINE double p_tan(double value) { + return tan(value); +} ///////// - template <> - math_def FORCEINLINE float p_atanh(float value) { - return atanhf(value); - } +template <> +math_def FORCEINLINE float p_atanh(float value) { + return atanhf(value); +} - template <> - math_def FORCEINLINE float16 p_atanh(float16 val) { - return static_cast(atanhf((float) val)); - } +template <> +math_def FORCEINLINE float16 p_atanh(float16 val) { + return static_cast(atanhf((float)val)); +} - template <> - math_def FORCEINLINE bfloat16 p_atanh(bfloat16 val) { - return static_cast(atanhf((float) val)); - } +template <> +math_def FORCEINLINE bfloat16 p_atanh(bfloat16 val) { + return static_cast(atanhf((float)val)); +} - template <> - math_def FORCEINLINE double p_atanh(double value) { - return atanh(value); - } +template <> +math_def FORCEINLINE double p_atanh(double value) { + return atanh(value); +} ///////// - template - math_def FORCEINLINE T _rotate_left(T value, T shift); - - template - math_def FORCEINLINE T _rotate_right(T value, T shift); - - template <> - math_def FORCEINLINE int8_t _rotate_left(int8_t value, int8_t shift) { - return value << shift | value >> (8 - shift); - } - - template <> - math_def FORCEINLINE int8_t _rotate_right(int8_t value, int8_t shift) { - return value >> shift | value << (8 - shift); - } - - template <> - math_def FORCEINLINE uint8_t _rotate_left(uint8_t value, uint8_t shift) { - return value << shift | value >> (8 - shift); - } - - template <> - math_def FORCEINLINE uint8_t _rotate_right(uint8_t value, uint8_t shift) { - return value >> shift | value << (8 - shift); - } - - template <> - math_def FORCEINLINE int16_t _rotate_left(int16_t value, int16_t shift) { - return value << shift | value >> (16 - shift); - } - - template <> - math_def FORCEINLINE int16_t _rotate_right(int16_t value, int16_t shift) { - return value >> shift | value << (16 - shift); - } - - template <> - math_def FORCEINLINE uint16_t _rotate_left(uint16_t value, uint16_t shift) { - return value << shift | value >> (16 - shift); - } - - template <> - math_def FORCEINLINE uint16_t _rotate_right(uint16_t value, uint16_t shift) { - return value >> shift | value << (16 - shift); - } - - template <> - math_def FORCEINLINE int _rotate_left(int value, int shift) { - return value << shift | value >> (32 - shift); - } - - template <> - math_def FORCEINLINE int _rotate_right(int value, int shift) { - return value >> shift | value << (32 - shift); - } - - template <> - math_def FORCEINLINE uint32_t _rotate_left(uint32_t value, uint32_t shift) { - return value << shift | value >> (32 - shift); - } - - template <> - math_def FORCEINLINE uint32_t _rotate_right(uint32_t value, uint32_t shift) { - return value >> shift | value << (32 - shift); - } - - template <> - math_def FORCEINLINE Nd4jLong _rotate_left(Nd4jLong value, Nd4jLong shift) { - return value << shift | value >> (64 - shift); - } - - template <> - math_def FORCEINLINE Nd4jLong _rotate_right(Nd4jLong value, Nd4jLong shift) { - return value >> shift | value << (64 - shift); - } - - template <> - math_def FORCEINLINE uint64_t _rotate_left(uint64_t value, uint64_t shift) { +template +math_def FORCEINLINE T _rotate_left(T value, T shift); + +template +math_def FORCEINLINE T _rotate_right(T value, T shift); + +template <> +math_def FORCEINLINE int8_t _rotate_left(int8_t value, int8_t shift) { + return value << shift | value >> (8 - shift); +} + +template <> +math_def FORCEINLINE int8_t _rotate_right(int8_t value, int8_t shift) { + return value >> shift | value << (8 - shift); +} + +template <> +math_def FORCEINLINE uint8_t _rotate_left(uint8_t value, uint8_t shift) { + return value << shift | value >> (8 - shift); +} + +template <> +math_def FORCEINLINE uint8_t _rotate_right(uint8_t value, uint8_t shift) { + return value >> shift | value << (8 - shift); +} + +template <> +math_def FORCEINLINE int16_t _rotate_left(int16_t value, int16_t shift) { + return value << shift | value >> (16 - shift); +} + +template <> +math_def FORCEINLINE int16_t _rotate_right(int16_t value, int16_t shift) { + return value >> shift | value << (16 - shift); +} + +template <> +math_def FORCEINLINE uint16_t _rotate_left(uint16_t value, uint16_t shift) { + return value << shift | value >> (16 - shift); +} + +template <> +math_def FORCEINLINE uint16_t _rotate_right(uint16_t value, uint16_t shift) { + return value >> shift | value << (16 - shift); +} + +template <> +math_def FORCEINLINE int _rotate_left(int value, int shift) { + return value << shift | value >> (32 - shift); +} + +template <> +math_def FORCEINLINE int _rotate_right(int value, int shift) { + return value >> shift | value << (32 - shift); +} + +template <> +math_def FORCEINLINE uint32_t _rotate_left(uint32_t value, uint32_t shift) { + return value << shift | value >> (32 - shift); +} + +template <> +math_def FORCEINLINE uint32_t _rotate_right(uint32_t value, uint32_t shift) { + return value >> shift | value << (32 - shift); +} + +template <> +math_def FORCEINLINE Nd4jLong _rotate_left(Nd4jLong value, Nd4jLong shift) { + return value << shift | value >> (64 - shift); +} + +template <> +math_def FORCEINLINE Nd4jLong _rotate_right(Nd4jLong value, Nd4jLong shift) { + return value >> shift | value << (64 - shift); +} + +template <> +math_def FORCEINLINE uint64_t _rotate_left(uint64_t value, uint64_t shift) { #ifdef SD_ARM_BUILD - // TODO: eventually remove this once gcc fixes the bug - Nd4jLong val = _rotate_left(*reinterpret_cast(&value), *reinterpret_cast(&shift)); - return *reinterpret_cast(&val); + // TODO: eventually remove this once gcc fixes the bug + Nd4jLong val = _rotate_left(*reinterpret_cast(&value), + *reinterpret_cast(&shift)); + return *reinterpret_cast(&val); #else - return value << shift | value >> (64 - shift); + return value << shift | value >> (64 - shift); #endif - } +} - template <> - math_def FORCEINLINE uint64_t _rotate_right(uint64_t value, uint64_t shift) { +template <> +math_def FORCEINLINE uint64_t _rotate_right(uint64_t value, uint64_t shift) { #ifdef SD_ARM_BUILD - // TODO: eventually remove this once gcc fixes the bug - Nd4jLong val = _rotate_right(*reinterpret_cast(&value), *reinterpret_cast(&shift)); - return *reinterpret_cast(&val); + // TODO: eventually remove this once gcc fixes the bug + Nd4jLong val = _rotate_right(*reinterpret_cast(&value), + *reinterpret_cast(&shift)); + return *reinterpret_cast(&val); #else - return value >> shift | value << (64 - shift); + return value >> shift | value << (64 - shift); #endif - } - +} - template - math_def FORCEINLINE T p_rotl(T value, T shift) { - return _rotate_left(value, shift); - } +template +math_def FORCEINLINE T p_rotl(T value, T shift) { + return _rotate_left(value, shift); +} - template - math_def FORCEINLINE T p_rotr(T value, T shift) { - return _rotate_right(value, shift); - } - } +template +math_def FORCEINLINE T p_rotr(T value, T shift) { + return _rotate_right(value, shift); } +} // namespace math +} // namespace sd -#endif //DEV_TESTS_PLATFORM_MATH_H +#endif // SD_PLATFORM_MATH_H diff --git a/libnd4j/include/math/templatemath.h b/libnd4j/include/math/templatemath.h index c220231d8627..f676b62d95e7 100644 --- a/libnd4j/include/math/templatemath.h +++ b/libnd4j/include/math/templatemath.h @@ -25,10 +25,10 @@ #ifndef TEMPLATEMATH_H_ #define TEMPLATEMATH_H_ +#include +#include #include #include -#include -#include #define BFLOAT16_MAX_VALUE 32737. #define HALF_MAX_VALUE 65504. @@ -49,932 +49,910 @@ namespace sd { #endif - namespace math { - template - math_def inline T nd4j_abs(T value); +namespace math { +template +math_def inline T nd4j_abs(T value); - template - math_def inline void nd4j_swap(T &val1, T &val2); +template +math_def inline void nd4j_swap(T& val1, T& val2); - template - math_def inline T nd4j_max(T val1, T val2); +template +math_def inline T nd4j_max(T val1, T val2); - template - math_def inline T nd4j_min(T val1, T val2); +template +math_def inline T nd4j_min(T val1, T val2); - template - math_def inline bool nd4j_eq(T val1, T val2, double eps); +template +math_def inline bool nd4j_eq(T val1, T val2, double eps); - template - math_def inline Z nd4j_re(T val1, T val2); +template +math_def inline Z nd4j_re(T val1, T val2); - template - math_def inline Z nd4j_rint(T val1); +template +math_def inline Z nd4j_rint(T val1); - template - math_def inline Z nd4j_copysign(T val1, T val2); +template +math_def inline Z nd4j_copysign(T val1, T val2); - template - math_def inline Z nd4j_softplus(T val); +template +math_def inline Z nd4j_softplus(T val); - template - math_def inline T nd4j_rotl(T val, T shift); +template +math_def inline T nd4j_rotl(T val, T shift); - template - math_def inline T nd4j_rotr(T val, T shift); +template +math_def inline T nd4j_rotr(T val, T shift); //#ifndef __CUDACC__ - template - math_def inline Z nd4j_dot(X *x, Y *y, int length); +template +math_def inline Z nd4j_dot(X* x, Y* y, int length); //#endif - template - math_def inline Z nd4j_ceil(T val1); +template +math_def inline Z nd4j_ceil(T val1); - template - math_def inline bool nd4j_isnan(T val1); +template +math_def inline bool nd4j_isnan(T val1); - template - math_def inline bool nd4j_isinf(T val1); +template +math_def inline bool nd4j_isinf(T val1); - template - math_def inline bool nd4j_isfin(T val1); +template +math_def inline bool nd4j_isfin(T val1); - template - math_def inline Z nd4j_cos(T val); +template +math_def inline Z nd4j_cos(T val); - template - math_def inline Z nd4j_cosh(T val); +template +math_def inline Z nd4j_cosh(T val); - template - math_def inline Z nd4j_exp(X val); +template +math_def inline Z nd4j_exp(X val); - template - math_def inline Z nd4j_floor(T val); +template +math_def inline Z nd4j_floor(T val); - template - math_def inline Z nd4j_log(X val); +template +math_def inline Z nd4j_log(X val); - template - math_def inline Z nd4j_pow(X val, Y val2); +template +math_def inline Z nd4j_pow(X val, Y val2); - template - math_def inline Z nd4j_round(T val); +template +math_def inline Z nd4j_round(T val); - template - math_def inline Z nd4j_remainder(X num, Y denom); +template +math_def inline Z nd4j_remainder(X num, Y denom); - template - math_def inline Z nd4j_fmod(X num, Y denom); +template +math_def inline Z nd4j_fmod(X num, Y denom); - template - math_def inline Z nd4j_erf(T num); +template +math_def inline Z nd4j_erf(T num); - template - math_def inline Z nd4j_erfc(T num); +template +math_def inline Z nd4j_erfc(T num); - math_def inline int32_t floatToRawIntBits(float d) { - union { - float f; - int32_t i; - } tmp; - tmp.f = d; - return tmp.i; - } +math_def inline int32_t floatToRawIntBits(float d) { + union { + float f; + int32_t i; + } tmp; + tmp.f = d; + return tmp.i; +} - math_def inline float intBitsToFloat(int32_t i) { - union { - float f; - int32_t i; - } tmp; - tmp.i = i; - return tmp.f; - } +math_def inline float intBitsToFloat(int32_t i) { + union { + float f; + int32_t i; + } tmp; + tmp.i = i; + return tmp.f; +} - math_def inline float mulsignf(float x, float y) { - return intBitsToFloat(floatToRawIntBits(x) ^ (floatToRawIntBits(y) & (1 << 31))); - } +math_def inline float mulsignf(float x, float y) { + return intBitsToFloat(floatToRawIntBits(x) ^ + (floatToRawIntBits(y) & (1 << 31))); +} - math_def inline float copysignfk(float x, float y) { - return intBitsToFloat((floatToRawIntBits(x) & ~(1 << 31)) ^ (floatToRawIntBits(y) & (1 << 31))); - } +math_def inline float copysignfk(float x, float y) { + return intBitsToFloat((floatToRawIntBits(x) & ~(1 << 31)) ^ + (floatToRawIntBits(y) & (1 << 31))); +} - template - math_def inline Z nd4j_sigmoid(T val) { - return (Z) 1.0f / ((Z) 1.0f + nd4j_exp(-val)); - } - - template - math_def inline Z nd4j_elu(T val, T alpha) { - if (val >= (T) 0.f) - return val; - return static_cast(alpha) * (nd4j_exp(val) - static_cast(1.0f)); - } - - template - math_def inline Z nd4j_leakyrelu(T val,T alpha) { - if (val < (T) 0.0f) - return alpha * val; - else - return val; - } - - template - math_def inline Z nd4j_eluderivative(T val, T alpha) { - if (val >= static_cast(0.0f)) - return static_cast(1.0f); - return static_cast(alpha) * nd4j_exp(val); - //return val >= 0.0 ? 1.0 : nd4j_exp(val); - } - - template - math_def inline Z nd4j_sin(T val); - - template - math_def inline Z nd4j_sinh(T val); - - template - math_def inline Z nd4j_softplus(T val) { - return nd4j_log((Z) 1.0f + nd4j_exp(val)); - } - - template - math_def inline Z nd4j_softsign(T val) { - return val / ((T) 1.0f + sd::math::nd4j_abs(val)); - } - - template - math_def inline Z nd4j_sqrt(X val); - - template - math_def inline Z nd4j_tanh(X val); - - template - math_def inline Z nd4j_tan(T val); - - template - math_def inline Z nd4j_atan2(X val1, X val2); - - template - math_def inline Z nd4j_atan2(X val1, X val2) { - return p_atan2(static_cast(val1), static_cast(val2)); - } - - - template - math_def inline Z nd4j_tan(T tval) { - return p_tan(static_cast(tval)); - } +template +math_def inline Z nd4j_sigmoid(T val) { + return (Z)1.0f / ((Z)1.0f + nd4j_exp(-val)); +} - template - math_def inline Z nd4j_tanhderivative(T val) { - Z tanh = nd4j_tanh(val); - return (Z) 1.0f - tanh * tanh; - } - template - math_def inline T nd4j_sigmoidderivative(T val) { - Z sigmoid = nd4j_sigmoid(val); - return sigmoid * ((Z) 1.0f - sigmoid); - } - - template - math_def inline T nd4j_softsignderivative(T val) { - T y = (T) 1.0f + nd4j_abs(val); - return (Z) 1.0f / (y * y); - } - - template - math_def inline T nd4j_sgn(T val) { - return val < (T) 0.0f ? (Z) -1.0f : val > (T) 0.0f ? (Z) 1.0f : (Z) 0.0f; - } +template +math_def inline Z nd4j_elu(T val, T alpha) { + if (val >= (T)0.f) return val; + return static_cast(alpha) * (nd4j_exp(val) - static_cast(1.0f)); +} - template - math_def inline Z nd4j_sign(T val) { - return nd4j_sgn(val); - } +template +math_def inline Z nd4j_leakyrelu(T val, T alpha) { + if (val < (T)0.0f) + return alpha * val; + else + return val; +} - template - math_def inline Z nd4j_signum(T val) { - return nd4j_sgn(val); - } +template +math_def inline Z nd4j_eluderivative(T val, T alpha) { + if (val >= static_cast(0.0f)) return static_cast(1.0f); + return static_cast(alpha) * nd4j_exp(val); + // return val >= 0.0 ? 1.0 : nd4j_exp(val); +} + +template +math_def inline Z nd4j_sin(T val); + +template +math_def inline Z nd4j_sinh(T val); + +template +math_def inline Z nd4j_softplus(T val) { + return nd4j_log((Z)1.0f + nd4j_exp(val)); +} + +template +math_def inline Z nd4j_softsign(T val) { + return val / ((T)1.0f + sd::math::nd4j_abs(val)); +} + +template +math_def inline Z nd4j_sqrt(X val); + +template +math_def inline Z nd4j_tanh(X val); + +template +math_def inline Z nd4j_tan(T val); + +template +math_def inline Z nd4j_atan2(X val1, X val2); + +template +math_def inline Z nd4j_atan2(X val1, X val2) { + return p_atan2(static_cast(val1), static_cast(val2)); +} - template - math_def inline Z nd4j_gamma(X a); +template +math_def inline Z nd4j_tan(T tval) { + return p_tan(static_cast(tval)); +} + +template +math_def inline Z nd4j_tanhderivative(T val) { + Z tanh = nd4j_tanh(val); + return (Z)1.0f - tanh * tanh; +} +template +math_def inline T nd4j_sigmoidderivative(T val) { + Z sigmoid = nd4j_sigmoid(val); + return sigmoid * ((Z)1.0f - sigmoid); +} + +template +math_def inline T nd4j_softsignderivative(T val) { + T y = (T)1.0f + nd4j_abs(val); + return (Z)1.0f / (y * y); +} - template - math_def inline Z nd4j_lgamma(X x); +template +math_def inline T nd4j_sgn(T val) { + return val < (T)0.0f ? (Z)-1.0f : val > (T)0.0f ? (Z)1.0f : (Z)0.0f; +} + +template +math_def inline Z nd4j_sign(T val) { + return nd4j_sgn(val); +} + +template +math_def inline Z nd4j_signum(T val) { + return nd4j_sgn(val); +} + +template +math_def inline Z nd4j_gamma(X a); + +template +math_def inline Z nd4j_lgamma(X x); //#ifndef __CUDACC__ /* template<> - math_def inline float16 nd4j_dot(float16 *x, float16 *y, int length) { - float16 dot = (float16) 0.0f; + math_def inline float16 nd4j_dot(float16 *x, float16 *y, int + length) { float16 dot = (float16) 0.0f; - // TODO: since we can't use simd on unions, we might use something else here. - for(int e = 0; e < length; e++) { - dot += x[e] * y[e]; + // TODO: since we can't use simd on unions, we might use something + else here. for(int e = 0; e < length; e++) { dot += x[e] * y[e]; } return dot; } */ - template - math_def inline Z nd4j_dot(X *x, Y *y, int length) { - Z dot = (Z)0.0f; +template +math_def inline Z nd4j_dot(X* x, Y* y, int length) { + Z dot = (Z)0.0f; - for(int e = 0; e < length; e++) { - dot += static_cast(x[e]) * static_cast(y[e]); - } + for (int e = 0; e < length; e++) { + dot += static_cast(x[e]) * static_cast(y[e]); + } - return dot; - } + return dot; +} //#endif - template - math_def inline Z nd4j_acos(T val); +template +math_def inline Z nd4j_acos(T val); - template - math_def inline Z nd4j_sech(T val); +template +math_def inline Z nd4j_sech(T val); - template - math_def inline Z nd4j_acosh(T val); +template +math_def inline Z nd4j_acosh(T val); - template - math_def inline Z nd4j_asin(T val); +template +math_def inline Z nd4j_asin(T val); - template - math_def inline Z nd4j_asinh(T val); - - template - math_def inline Z nd4j_asinh(T val) { - //Math.log(Math.sqrt(Math.pow(x, 2) + 1) + x) - return nd4j_log(nd4j_sqrt(nd4j_pow(val, (T) 2) + (Z) 1.f) + (Z) val); - } +template +math_def inline Z nd4j_asinh(T val); - template - math_def inline Z nd4j_atan(T val); +template +math_def inline Z nd4j_asinh(T val) { + // Math.log(Math.sqrt(Math.pow(x, 2) + 1) + x) + return nd4j_log(nd4j_sqrt(nd4j_pow(val, (T)2) + (Z)1.f) + + (Z)val); +} - template - math_def inline Z nd4j_atanh(T val); +template +math_def inline Z nd4j_atan(T val); +template +math_def inline Z nd4j_atanh(T val); - template<> - math_def inline float16 nd4j_abs(float16 value) { +template <> +math_def inline float16 nd4j_abs(float16 value) { #ifdef NATIVE_HALFS - if (value < (float16) 0.f) { - return float16(__hneg(value.data)); - } else - return value; + if (value < (float16)0.f) { + return float16(__hneg(value.data)); + } else + return value; #else - return (float16) fabsf((float) value); + return (float16)fabsf((float)value); #endif - } - template<> - math_def inline bfloat16 nd4j_abs(bfloat16 value) { - return (bfloat16) fabsf((float) value); - } - template<> - math_def inline float nd4j_abs(float value) { - return fabsf(value); - } - - template<> - math_def inline double nd4j_abs(double value) { - return fabs(value); - } - - template<> - math_def inline int nd4j_abs(int value) { - return abs(value); - } - - template<> - math_def inline Nd4jLong nd4j_abs(Nd4jLong value) { - return llabs(value); - } - - template<> - math_def inline bool nd4j_abs(bool value) { - return value; - } - - template<> - math_def inline uint8_t nd4j_abs(uint8_t value) { - return value; - } - - template<> - math_def inline uint16_t nd4j_abs(uint16_t value) { - return value; - } - - template<> - math_def inline uint32_t nd4j_abs(uint32_t value) { - return value; - } - - template<> - math_def inline Nd4jULong nd4j_abs(Nd4jULong value) { - return value; - } - - template<> - math_def inline int8_t nd4j_abs(int8_t value) { - return value < 0 ? -value : value; - } - - template<> - math_def inline int16_t nd4j_abs(int16_t value) { - return value < 0 ? -value : value; - } - - - template<> - math_def inline bool nd4j_isnan(float16 value) { - return *(value.data.getXP()) == 0x7fffU; - } - - template<> - math_def inline bool nd4j_isnan(bfloat16 value) { - return value == bfloat16::nan(); //0x7fffU; - } - - template<> - math_def inline bool nd4j_isnan(float value) { - return value != value; - } - - template<> - math_def inline bool nd4j_isnan(double value) { - return value != value; - } - - template<> - math_def inline bool nd4j_isnan(int value) { - return false; - } - - template<> - math_def inline bool nd4j_isnan(uint32_t value) { - return false; - } - - template<> - math_def inline bool nd4j_isnan(uint16_t value) { - return false; - } - - template<> - math_def inline bool nd4j_isnan(uint8_t value) { - return false; - } - - template<> - math_def inline bool nd4j_isnan(int16_t value) { - return false; - } - - template<> - math_def inline bool nd4j_isnan(int8_t value) { - return false; - } - - template<> - math_def inline bool nd4j_isnan(bool value) { - return false; - } - - template<> - math_def inline bool nd4j_isnan(Nd4jLong value) { - return false; - } - - template<> - math_def inline bool nd4j_isnan(Nd4jULong value) { - return false; - } - - template<> - math_def inline bool nd4j_isinf(float16 value) { - return value < (float16) -HALF_MAX_VALUE || value > (float16) HALF_MAX_VALUE; - } - - template<> - math_def inline bool nd4j_isinf(bfloat16 value) { - return value < (bfloat16) -BFLOAT16_MAX_VALUE || value > (bfloat16) BFLOAT16_MAX_VALUE; - } - - template<> - math_def inline bool nd4j_isinf(float value) { +} +template <> +math_def inline bfloat16 nd4j_abs(bfloat16 value) { + return (bfloat16)fabsf((float)value); +} +template <> +math_def inline float nd4j_abs(float value) { + return fabsf(value); +} + +template <> +math_def inline double nd4j_abs(double value) { + return fabs(value); +} + +template <> +math_def inline int nd4j_abs(int value) { + return abs(value); +} + +template <> +math_def inline Nd4jLong nd4j_abs(Nd4jLong value) { + return llabs(value); +} + +template <> +math_def inline bool nd4j_abs(bool value) { + return value; +} + +template <> +math_def inline uint8_t nd4j_abs(uint8_t value) { + return value; +} + +template <> +math_def inline uint16_t nd4j_abs(uint16_t value) { + return value; +} + +template <> +math_def inline uint32_t nd4j_abs(uint32_t value) { + return value; +} + +template <> +math_def inline Nd4jULong nd4j_abs(Nd4jULong value) { + return value; +} + +template <> +math_def inline int8_t nd4j_abs(int8_t value) { + return value < 0 ? -value : value; +} + +template <> +math_def inline int16_t nd4j_abs(int16_t value) { + return value < 0 ? -value : value; +} + +template <> +math_def inline bool nd4j_isnan(float16 value) { + return *(value.data.getXP()) == 0x7fffU; +} + +template <> +math_def inline bool nd4j_isnan(bfloat16 value) { + return value == bfloat16::nan(); // 0x7fffU; +} + +template <> +math_def inline bool nd4j_isnan(float value) { + return value != value; +} + +template <> +math_def inline bool nd4j_isnan(double value) { + return value != value; +} + +template <> +math_def inline bool nd4j_isnan(int value) { + return false; +} + +template <> +math_def inline bool nd4j_isnan(uint32_t value) { + return false; +} + +template <> +math_def inline bool nd4j_isnan(uint16_t value) { + return false; +} + +template <> +math_def inline bool nd4j_isnan(uint8_t value) { + return false; +} + +template <> +math_def inline bool nd4j_isnan(int16_t value) { + return false; +} + +template <> +math_def inline bool nd4j_isnan(int8_t value) { + return false; +} + +template <> +math_def inline bool nd4j_isnan(bool value) { + return false; +} + +template <> +math_def inline bool nd4j_isnan(Nd4jLong value) { + return false; +} + +template <> +math_def inline bool nd4j_isnan(Nd4jULong value) { + return false; +} + +template <> +math_def inline bool nd4j_isinf(float16 value) { + return value < (float16)-HALF_MAX_VALUE || value > (float16)HALF_MAX_VALUE; +} + +template <> +math_def inline bool nd4j_isinf(bfloat16 value) { + return value < (bfloat16)-BFLOAT16_MAX_VALUE || + value > (bfloat16)BFLOAT16_MAX_VALUE; +} + +template <> +math_def inline bool nd4j_isinf(float value) { #ifdef __CUDACC__ - return isinf(value); + return isinf(value); #else - return std::isinf(value); + return std::isinf(value); #endif - //return value < -FLOAT_MAX_VALUE || value > FLOAT_MAX_VALUE; - } + // return value < -FLOAT_MAX_VALUE || value > FLOAT_MAX_VALUE; +} - template<> - math_def inline bool nd4j_isinf(double value) { +template <> +math_def inline bool nd4j_isinf(double value) { #ifdef __CUDACC__ - return isinf(value); + return isinf(value); #else - return std::isinf(value); + return std::isinf(value); #endif - //return value < -DOUBLE_MAX_VALUE || value > DOUBLE_MAX_VALUE; - } - - template<> - math_def inline bool nd4j_isinf(int value) { - return false; - } - - template<> - math_def inline bool nd4j_isinf(uint32_t value) { - return false; - } - - template<> - math_def inline bool nd4j_isinf(uint16_t value) { - return false; - } - - template<> - math_def inline bool nd4j_isinf(uint8_t value) { - return false; - } - - template<> - math_def inline bool nd4j_isinf(int16_t value) { - return false; - } - - template<> - math_def inline bool nd4j_isinf(int8_t value) { - return false; - } - - template<> - math_def inline bool nd4j_isinf(bool value) { - return false; - } - - template<> - math_def inline bool nd4j_isinf(Nd4jLong value) { - return false; - } - - template<> - math_def inline bool nd4j_isinf(Nd4jULong value) { - return false; - } - - template - math_def inline bool nd4j_isfin(T value) { - return !nd4j_isnan(value) && !nd4j_isinf(value); - } - - template<> - math_def inline float16 nd4j_copysign(float16 val1, float16 val2) { - return (float16) copysignf((float) val1, (float) val2); - } - - template<> - math_def inline float nd4j_copysign(float val1, float val2) { - return copysignf(val1, val2); - } - - template<> - math_def inline double nd4j_copysign(double val1, double val2) { - return copysign(val1, val2); - } - - template<> - math_def inline int nd4j_copysign(int val1, int val2) { - if (val2 < 0) return -(nd4j_abs(val1)); - else return nd4j_abs(val1); - } - - template<> - math_def inline Nd4jLong nd4j_copysign(Nd4jLong val1, Nd4jLong val2) { - if (val2 < 0) return -(nd4j_abs(val1)); - else return nd4j_abs(val1); - } - - template<> - math_def inline bool nd4j_max(bool val1, bool val2) { - return (val1 || val2) ? true : false; - } - - template - math_def inline T nd4j_max(T val1, T val2) { - return val1 > val2 ? val1 : val2; - } - - template<> - math_def inline bool nd4j_min(bool val1, bool val2) { - return (val1 && val2) ? true : false; - } - - template - math_def inline T nd4j_min(T val1, T val2) { - return val1 < val2 ? val1 : val2; - } - - template - math_def inline bool nd4j_eq(T d1, T d2, double eps) { - if (sd::math::nd4j_isinf(d1) && sd::math::nd4j_isinf(d2)) { - if (d1 > 0 && d2 > 0) - return true; - else if (d1 < 0 && d2 < 0) - return true; - else - return false; - } - - auto diff = static_cast(sd::math::nd4j_abs(d1 - d2)); - - - // works well except in the range of very large numbers - if (diff <= eps) - return true; - - // Knuth approach - // works well except in the range of very small numbers - if (diff <= sd::math::nd4j_max(sd::math::nd4j_abs(static_cast(d1)), sd::math::nd4j_abs(static_cast(d2))) * eps) - return true; - - return false; - } - - template - math_def inline Z nd4j_ceil(X val) { - return static_cast(p_ceil(val)); - } - - template - math_def inline Z nd4j_round(X val) { - return static_cast(p_round(val)); - } + // return value < -DOUBLE_MAX_VALUE || value > DOUBLE_MAX_VALUE; +} - template - math_def inline Z nd4j_asin(X val) { - return p_asin(static_cast(val)); - } +template <> +math_def inline bool nd4j_isinf(int value) { + return false; +} - template - math_def inline Z nd4j_atan(X val) { - return p_atan(static_cast(val)); - } +template <> +math_def inline bool nd4j_isinf(uint32_t value) { + return false; +} - template - math_def inline Z nd4j_atanh(X val) { - return p_atanh(static_cast(val)); - } +template <> +math_def inline bool nd4j_isinf(uint16_t value) { + return false; +} - template - math_def inline Z nd4j_cosh(X val) { - return p_cosh(static_cast(val)); - } +template <> +math_def inline bool nd4j_isinf(uint8_t value) { + return false; +} - template - math_def inline Z nd4j_rint(X val) { - return p_rint(val); - } +template <> +math_def inline bool nd4j_isinf(int16_t value) { + return false; +} - template - math_def inline Z nd4j_sinh(X val) { - return p_sinh(static_cast(val)); - } +template <> +math_def inline bool nd4j_isinf(int8_t value) { + return false; +} - template - math_def inline Z nd4j_acos(X val) { - return p_acos(static_cast(val)); - } +template <> +math_def inline bool nd4j_isinf(bool value) { + return false; +} - template - math_def inline Z nd4j_sech(X val) { - return static_cast(1) / nd4j_cosh(val); - } +template <> +math_def inline bool nd4j_isinf(Nd4jLong value) { + return false; +} - template - math_def inline Z nd4j_acosh(X val) { - return p_acosh(static_cast(val)); - } +template <> +math_def inline bool nd4j_isinf(Nd4jULong value) { + return false; +} - template - math_def inline Z nd4j_cos(X val) { - return p_cos(static_cast(val)); - } +template +math_def inline bool nd4j_isfin(T value) { + return !nd4j_isnan(value) && !nd4j_isinf(value); +} - template - math_def inline Z nd4j_exp(X val) { - return p_exp(val); - } +template <> +math_def inline float16 nd4j_copysign(float16 val1, float16 val2) { + return (float16)copysignf((float)val1, (float)val2); +} - template - math_def inline Z nd4j_floor(X val) { - return static_cast(p_floor(val)); - } - - template - math_def inline Z nd4j_log(X val) { - return static_cast(p_log(val)); - } - - /** - * This func is special case - it must return floating point value, and optionally Y arg can be floating point argument - * @tparam X - * @tparam Y - * @tparam Z - * @param val - * @param val2 - * @return - */ - template <> - math_def inline float nd4j_pow(float val, float val2) { - return p_pow(val, val2); - } +template <> +math_def inline float nd4j_copysign(float val1, float val2) { + return copysignf(val1, val2); +} - template - math_def inline Z nd4j_pow(X val, Y val2) { - return p_pow(static_cast(val), static_cast(val2)); - } - - /** - * LogGamma(a) - float point extension of ln(n!) - **/ - template - math_def inline Z nd4j_lgamma(X x) { -// if (x <= X(0.0)) -// { -// std::stringstream os; -// os << "Logarithm of Gamma has sence only for positive values, but " << x << " was given."; -// throw std::invalid_argument( os.str() ); -// } - - if (x < X(12.0)) { - return nd4j_log(nd4j_gamma(x)); - } +template <> +math_def inline double nd4j_copysign(double val1, double val2) { + return copysign(val1, val2); +} - // Abramowitz and Stegun 6.1.41 - // Asymptotic series should be good to at least 11 or 12 figures - // For error analysis, see Whittiker and Watson - // A Course in Modern Analysis (1927), page 252 - - static const double c[8] = { - 1.0/12.0, - -1.0/360.0, - 1.0/1260.0, - -1.0/1680.0, - 1.0/1188.0, - -691.0/360360.0, - 1.0/156.0, - -3617.0/122400.0 - }; - - double z = Z(1.0 / Z(x * x)); - double sum = c[7]; - - for (int i = 6; i >= 0; i--) { - sum *= z; - sum += c[i]; - } +template <> +math_def inline int nd4j_copysign(int val1, int val2) { + if (val2 < 0) + return -(nd4j_abs(val1)); + else + return nd4j_abs(val1); +} - double series = sum / Z(x); +template <> +math_def inline Nd4jLong nd4j_copysign(Nd4jLong val1, Nd4jLong val2) { + if (val2 < 0) + return -(nd4j_abs(val1)); + else + return nd4j_abs(val1); +} - static const double halfLogTwoPi = 0.91893853320467274178032973640562; +template <> +math_def inline bool nd4j_max(bool val1, bool val2) { + return (val1 || val2) ? true : false; +} - return Z((double(x) - 0.5) * nd4j_log(x) - double(x) + halfLogTwoPi + series); - } +template +math_def inline T nd4j_max(T val1, T val2) { + return val1 > val2 ? val1 : val2; +} +template <> +math_def inline bool nd4j_min(bool val1, bool val2) { + return (val1 && val2) ? true : false; +} +template +math_def inline T nd4j_min(T val1, T val2) { + return val1 < val2 ? val1 : val2; +} - template - math_def inline T nd4j_re(T val1, T val2) { - if (val1 == (T) 0.0f && val2 == (T) 0.0f) - return (T) 0.0f; +template +math_def inline bool nd4j_eq(T d1, T d2, double eps) { + if (sd::math::nd4j_isinf(d1) && sd::math::nd4j_isinf(d2)) { + if (d1 > 0 && d2 > 0) + return true; + else if (d1 < 0 && d2 < 0) + return true; + else + return false; + } + + auto diff = static_cast(sd::math::nd4j_abs(d1 - d2)); + + // works well except in the range of very large numbers + if (diff <= eps) return true; + + // Knuth approach + // works well except in the range of very small numbers + if (diff <= sd::math::nd4j_max( + sd::math::nd4j_abs(static_cast(d1)), + sd::math::nd4j_abs(static_cast(d2))) * + eps) + return true; + + return false; +} - return nd4j_abs(val1 - val2) / (nd4j_abs(val1) + nd4j_abs(val2)); - } +template +math_def inline Z nd4j_ceil(X val) { + return static_cast(p_ceil(val)); +} +template +math_def inline Z nd4j_round(X val) { + return static_cast(p_round(val)); +} - template - math_def inline Z nd4j_remainder(X val, Y val2) { - return p_remainder(static_cast(val), static_cast(val2)); - } +template +math_def inline Z nd4j_asin(X val) { + return p_asin(static_cast(val)); +} + +template +math_def inline Z nd4j_atan(X val) { + return p_atan(static_cast(val)); +} - template - math_def inline Z nd4j_fmod(X val, Y val2) { - return p_fmod(static_cast(val), static_cast(val2)); - } +template +math_def inline Z nd4j_atanh(X val) { + return p_atanh(static_cast(val)); +} +template +math_def inline Z nd4j_cosh(X val) { + return p_cosh(static_cast(val)); +} - template - math_def inline Z nd4j_sin(X val) { - return p_sin(static_cast(val)); - } +template +math_def inline Z nd4j_rint(X val) { + return p_rint(val); +} +template +math_def inline Z nd4j_sinh(X val) { + return p_sinh(static_cast(val)); +} - template - math_def inline Z nd4j_sqrt(X val) { - return p_sqrt(static_cast(val)); - } +template +math_def inline Z nd4j_acos(X val) { + return p_acos(static_cast(val)); +} +template +math_def inline Z nd4j_sech(X val) { + return static_cast(1) / nd4j_cosh(val); +} - template - math_def inline X neg_tanh(X val) { - X o = static_cast(1.0f); - X t = static_cast(2.0f); - X e = static_cast(M_E); +template +math_def inline Z nd4j_acosh(X val) { + return p_acosh(static_cast(val)); +} - auto p = sd::math::nd4j_pow(e, val * t); - return (p - o)/ (p + o); - } +template +math_def inline Z nd4j_cos(X val) { + return p_cos(static_cast(val)); +} - template - math_def inline X pos_tanh(X val) { - X o = static_cast(1.0f); - X t = static_cast(-2.0f); - X e = static_cast(M_E); +template +math_def inline Z nd4j_exp(X val) { + return p_exp(val); +} - auto p = sd::math::nd4j_pow(e, val * t); - return (o - p) / (o + p); - } +template +math_def inline Z nd4j_floor(X val) { + return static_cast(p_floor(val)); +} +template +math_def inline Z nd4j_log(X val) { + return static_cast(p_log(val)); +} - math_def inline float neu_tanh(float val, float sign) { - float e(M_E); - float av = sign * val; - auto p = sd::math::nd4j_pow(e, -av * 2.f); - return (1 - p) / (1 + p); - } +/** + * This func is special case - it must return floating point value, and + * optionally Y arg can be floating point argument + * @tparam X + * @tparam Y + * @tparam Z + * @param val + * @param val2 + * @return + */ +template <> +math_def inline float nd4j_pow(float val, float val2) { + return p_pow(val, val2); +} - template <> - math_def inline float nd4j_tanh(float val) { - float sign = copysignfk(1.0f, val); - return sign * neu_tanh(val, sign); - } +template +math_def inline Z nd4j_pow(X val, Y val2) { + return p_pow(static_cast(val), static_cast(val2)); +} +/** + * LogGamma(a) - float point extension of ln(n!) + **/ +template +math_def inline Z nd4j_lgamma(X x) { + // if (x <= X(0.0)) + // { + // std::stringstream os; + // os << "Logarithm of Gamma has sence only for positive + // values, but " << x << " was given."; throw + // std::invalid_argument( os.str() ); + // } + + if (x < X(12.0)) { + return nd4j_log(nd4j_gamma(x)); + } + + // Abramowitz and Stegun 6.1.41 + // Asymptotic series should be good to at least 11 or 12 figures + // For error analysis, see Whittiker and Watson + // A Course in Modern Analysis (1927), page 252 + + static const double c[8] = { + 1.0 / 12.0, -1.0 / 360.0, 1.0 / 1260.0, -1.0 / 1680.0, + 1.0 / 1188.0, -691.0 / 360360.0, 1.0 / 156.0, -3617.0 / 122400.0}; + + double z = Z(1.0 / Z(x * x)); + double sum = c[7]; + + for (int i = 6; i >= 0; i--) { + sum *= z; + sum += c[i]; + } + + double series = sum / Z(x); + + static const double halfLogTwoPi = 0.91893853320467274178032973640562; + + return Z((double(x) - 0.5) * nd4j_log(x) - double(x) + + halfLogTwoPi + series); +} - template - math_def inline Z nd4j_tanh(X val) { - return val <= 0 ? neg_tanh(val) : pos_tanh(val); - } +template +math_def inline T nd4j_re(T val1, T val2) { + if (val1 == (T)0.0f && val2 == (T)0.0f) return (T)0.0f; - template - math_def inline T nd4j_rotl(T val, T shift) { - return p_rotl(val, shift); - } + return nd4j_abs(val1 - val2) / (nd4j_abs(val1) + nd4j_abs(val2)); +} - template - math_def inline T nd4j_rotr(T val, T shift) { - return p_rotr(val, shift); - } +template +math_def inline Z nd4j_remainder(X val, Y val2) { + return p_remainder(static_cast(val), static_cast(val2)); +} - template - math_def inline Z nd4j_erf(X val) { - return p_erf(static_cast(val)); - } +template +math_def inline Z nd4j_fmod(X val, Y val2) { + return p_fmod(static_cast(val), static_cast(val2)); +} +template +math_def inline Z nd4j_sin(X val) { + return p_sin(static_cast(val)); +} - template - math_def inline Z nd4j_erfc(X val) { - return p_erfc(static_cast(val)); - } +template +math_def inline Z nd4j_sqrt(X val) { + return p_sqrt(static_cast(val)); +} - template - math_def inline void nd4j_swap(T &val1, T &val2) { - T temp = val1; val1=val2; val2=temp; - }; - - template - math_def inline Z nd4j_gamma(X a) { -// nd4j_lgamma(a); -// return (Z)std::tgamma(a); - // Split the function domain into three intervals: - // (0, 0.001), [0.001, 12), and (12, infinity) - - /////////////////////////////////////////////////////////////////////////// - // First interval: (0, 0.001) - // - // For small a, 1/Gamma(a) has power series a + gamma a^2 - ... - // So in this range, 1/Gamma(a) = a + gamma a^2 with error on the order of a^3. - // The relative error over this interval is less than 6e-7. - - const double eulerGamma = 0.577215664901532860606512090; // Euler's gamma constant - - if (a < X(0.001)) - return Z(1.0 / ((double)a * (1.0 + eulerGamma * (double)a))); - - /////////////////////////////////////////////////////////////////////////// - // Second interval: [0.001, 12) - - if (a < X(12.0)) { - // The algorithm directly approximates gamma over (1,2) and uses - // reduction identities to reduce other arguments to this interval. - - double y = (double)a; - int n = 0; - bool argWasLessThanOne = y < 1.0; - - // Add or subtract integers as necessary to bring y into (1,2) - // Will correct for this below - if (argWasLessThanOne) { - y += 1.0; - } - else { - n = static_cast(floor(y)) - 1; // will use n later - y -= n; - } - - // numerator coefficients for approximation over the interval (1,2) - static const double p[] = { - -1.71618513886549492533811E+0, - 2.47656508055759199108314E+1, - -3.79804256470945635097577E+2, - 6.29331155312818442661052E+2, - 8.66966202790413211295064E+2, - -3.14512729688483675254357E+4, - -3.61444134186911729807069E+4, - 6.64561438202405440627855E+4 - }; - - // denominator coefficients for approximation over the interval (1,2) - static const double q[] = { - -3.08402300119738975254353E+1, - 3.15350626979604161529144E+2, - -1.01515636749021914166146E+3, - -3.10777167157231109440444E+3, - 2.25381184209801510330112E+4, - 4.75584627752788110767815E+3, - -1.34659959864969306392456E+5, - -1.15132259675553483497211E+5 - }; - - double num = 0.0; - double den = 1.0; - - - double z = y - 1; - for (auto i = 0; i < 8; i++) { - num = (num + p[i]) * z; - den = den * z + q[i]; - } - double result = num / den + 1.0; - - // Apply correction if argument was not initially in (1,2) - if (argWasLessThanOne) { - // Use identity gamma(z) = gamma(z+1)/z - // The variable "result" now holds gamma of the original y + 1 - // Thus we use y-1 to get back the orginal y. - result /= (y - 1.0); - } - else { - // Use the identity gamma(z+n) = z*(z+1)* ... *(z+n-1)*gamma(z) - for (auto i = 0; i < n; i++) - result *= y++; - } - - return Z(result); - } +template +math_def inline X neg_tanh(X val) { + X o = static_cast(1.0f); + X t = static_cast(2.0f); + X e = static_cast(M_E); - /////////////////////////////////////////////////////////////////////////// - // Third interval: [12, infinity) + auto p = sd::math::nd4j_pow(e, val * t); + return (p - o) / (p + o); +} - if (a > 171.624) { - // Correct answer too large to display. Force +infinity. - return Z(DOUBLE_MAX_VALUE); -// return DataTypeUtils::infOrMax(); - } +template +math_def inline X pos_tanh(X val) { + X o = static_cast(1.0f); + X t = static_cast(-2.0f); + X e = static_cast(M_E); - return sd::math::nd4j_exp(sd::math::nd4j_lgamma(a)); - } + auto p = sd::math::nd4j_pow(e, val * t); + return (o - p) / (o + p); +} - template - math_def inline Z nd4j_igamma(X a, Y x) { - Z aim = nd4j_pow(x, a) / (nd4j_exp(x) * nd4j_gamma(a)); - auto sum = Z(0.); - auto denom = Z(1.); - if (a <= X(0.000001)) - //throw std::runtime_error("Cannot calculate gamma for a zero val."); - return Z(0); - - for (int i = 0; Z(1./denom) > Z(1.0e-12); i++) { - denom *= (a + i); - sum += nd4j_pow(x, i) / denom; - } - return aim * sum; - } +math_def inline float neu_tanh(float val, float sign) { + float e(M_E); + float av = sign * val; + auto p = sd::math::nd4j_pow(e, -av * 2.f); + return (1 - p) / (1 + p); +} - template - math_def inline Z nd4j_igammac(X a, Y x) { - return Z(1.) - nd4j_igamma(a, x); - } +template <> +math_def inline float nd4j_tanh(float val) { + float sign = copysignfk(1.0f, val); + return sign * neu_tanh(val, sign); +} + +template +math_def inline Z nd4j_tanh(X val) { + return val <= 0 ? neg_tanh(val) : pos_tanh(val); +} + +template +math_def inline T nd4j_rotl(T val, T shift) { + return p_rotl(val, shift); +} + +template +math_def inline T nd4j_rotr(T val, T shift) { + return p_rotr(val, shift); +} + +template +math_def inline Z nd4j_erf(X val) { + return p_erf(static_cast(val)); +} + +template +math_def inline Z nd4j_erfc(X val) { + return p_erfc(static_cast(val)); +} + +template +math_def inline void nd4j_swap(T& val1, T& val2) { + T temp = val1; + val1 = val2; + val2 = temp; +}; + +template +math_def inline Z nd4j_gamma(X a) { + // nd4j_lgamma(a); + // return (Z)std::tgamma(a); + // Split the function domain into three intervals: + // (0, 0.001), [0.001, 12), and (12, infinity) + + /////////////////////////////////////////////////////////////////////////// + // First interval: (0, 0.001) + // + // For small a, 1/Gamma(a) has power series a + gamma a^2 - ... + // So in this range, 1/Gamma(a) = a + gamma a^2 with error on the order of + // a^3. The relative error over this interval is less than 6e-7. + + const double eulerGamma = + 0.577215664901532860606512090; // Euler's gamma constant + + if (a < X(0.001)) + return Z(1.0 / ((double)a * (1.0 + eulerGamma * (double)a))); + + /////////////////////////////////////////////////////////////////////////// + // Second interval: [0.001, 12) + + if (a < X(12.0)) { + // The algorithm directly approximates gamma over (1,2) and uses + // reduction identities to reduce other arguments to this interval. + + double y = (double)a; + int n = 0; + bool argWasLessThanOne = y < 1.0; + + // Add or subtract integers as necessary to bring y into (1,2) + // Will correct for this below + if (argWasLessThanOne) { + y += 1.0; + } else { + n = static_cast(floor(y)) - 1; // will use n later + y -= n; + } + + // numerator coefficients for approximation over the interval (1,2) + static const double p[] = { + -1.71618513886549492533811E+0, 2.47656508055759199108314E+1, + -3.79804256470945635097577E+2, 6.29331155312818442661052E+2, + 8.66966202790413211295064E+2, -3.14512729688483675254357E+4, + -3.61444134186911729807069E+4, 6.64561438202405440627855E+4}; + + // denominator coefficients for approximation over the interval (1,2) + static const double q[] = { + -3.08402300119738975254353E+1, 3.15350626979604161529144E+2, + -1.01515636749021914166146E+3, -3.10777167157231109440444E+3, + 2.25381184209801510330112E+4, 4.75584627752788110767815E+3, + -1.34659959864969306392456E+5, -1.15132259675553483497211E+5}; + + double num = 0.0; + double den = 1.0; + + double z = y - 1; + for (auto i = 0; i < 8; i++) { + num = (num + p[i]) * z; + den = den * z + q[i]; + } + double result = num / den + 1.0; + + // Apply correction if argument was not initially in (1,2) + if (argWasLessThanOne) { + // Use identity gamma(z) = gamma(z+1)/z + // The variable "result" now holds gamma of the original y + 1 + // Thus we use y-1 to get back the orginal y. + result /= (y - 1.0); + } else { + // Use the identity gamma(z+n) = z*(z+1)* ... *(z+n-1)*gamma(z) + for (auto i = 0; i < n; i++) result *= y++; + } + + return Z(result); + } + + /////////////////////////////////////////////////////////////////////////// + // Third interval: [12, infinity) + + if (a > 171.624) { + // Correct answer too large to display. Force +infinity. + return Z(DOUBLE_MAX_VALUE); + // return DataTypeUtils::infOrMax(); + } + + return sd::math::nd4j_exp(sd::math::nd4j_lgamma(a)); +} + +template +math_def inline Z nd4j_igamma(X a, Y x) { + Z aim = nd4j_pow(x, a) / (nd4j_exp(x) * nd4j_gamma(a)); + auto sum = Z(0.); + auto denom = Z(1.); + if (a <= X(0.000001)) + // throw std::runtime_error("Cannot calculate gamma for a zero val."); + return Z(0); + + for (int i = 0; Z(1. / denom) > Z(1.0e-12); i++) { + denom *= (a + i); + sum += nd4j_pow(x, i) / denom; + } + return aim * sum; +} + +template +math_def inline Z nd4j_igammac(X a, Y x) { + return Z(1.) - nd4j_igamma(a, x); +} #ifdef __CUDACC__ - namespace atomics { +namespace atomics { template inline __device__ T nd4j_atomicAdd(T* address, T val); @@ -991,712 +969,748 @@ template inline __device__ T nd4j_atomicMax(T* address, T val); template <> -inline __device__ int32_t nd4j_atomicMin(int32_t* address, int32_t val) { - return atomicMin(address, val); +inline __device__ int32_t nd4j_atomicMin(int32_t* address, + int32_t val) { + return atomicMin(address, val); } template <> -inline __device__ uint32_t nd4j_atomicMin(uint32_t* address, uint32_t val) { - return atomicMin(address, val); +inline __device__ uint32_t nd4j_atomicMin(uint32_t* address, + uint32_t val) { + return atomicMin(address, val); } template <> -inline __device__ float nd4j_atomicMin(float* address, float val) { - int* address_as_ull = (int*)address; - int old = __float_as_int(val), assumed; - do { - assumed = old; - old = atomicCAS(address_as_ull, assumed, __float_as_int(math::nd4j_min(val, __int_as_float(assumed)))); - } while (assumed != old); - return __int_as_float(old); +inline __device__ float nd4j_atomicMin(float* address, float val) { + int* address_as_ull = (int*)address; + int old = __float_as_int(val), assumed; + do { + assumed = old; + old = + atomicCAS(address_as_ull, assumed, + __float_as_int(math::nd4j_min(val, __int_as_float(assumed)))); + } while (assumed != old); + return __int_as_float(old); } template <> -inline __device__ double nd4j_atomicMin(double* address, double val) { - unsigned long long int* address_as_ull = (unsigned long long int*)address; - unsigned long long int old = __double_as_longlong(val), assumed; - do { - assumed = old; - old = atomicCAS(address_as_ull, assumed, __double_as_longlong(math::nd4j_min(val, __longlong_as_double(assumed)))); - } while (assumed != old); - return __longlong_as_double(old); +inline __device__ double nd4j_atomicMin(double* address, double val) { + unsigned long long int* address_as_ull = (unsigned long long int*)address; + unsigned long long int old = __double_as_longlong(val), assumed; + do { + assumed = old; + old = atomicCAS(address_as_ull, assumed, + __double_as_longlong( + math::nd4j_min(val, __longlong_as_double(assumed)))); + } while (assumed != old); + return __longlong_as_double(old); } template <> -inline __device__ uint64_t nd4j_atomicMin(uint64_t* address, uint64_t val) { +inline __device__ uint64_t nd4j_atomicMin(uint64_t* address, + uint64_t val) { #if __CUDA_ARCH__ >= 350 - return atomicMin((unsigned long long*)address, (unsigned long long)val); + return atomicMin((unsigned long long*)address, (unsigned long long)val); #else - unsigned long long int* address_as_ull = (unsigned long long int*)address; - unsigned long long int old = __double_as_longlong(val), assumed; - do { - assumed = old; - old = atomicCAS(address_as_ull, assumed, math::nd4j_min((unsigned long long)val, assumed)); - } while (assumed != old); - return old; + unsigned long long int* address_as_ull = (unsigned long long int*)address; + unsigned long long int old = __double_as_longlong(val), assumed; + do { + assumed = old; + old = atomicCAS(address_as_ull, assumed, + math::nd4j_min((unsigned long long)val, assumed)); + } while (assumed != old); + return old; #endif } template <> -inline __device__ Nd4jLong nd4j_atomicMin(Nd4jLong* address, Nd4jLong val) { - - #if __CUDA_ARCH__ >= 350 - return atomicMin((unsigned long long*)address, (unsigned long long)val); - #else - unsigned long long int* address_as_ull = (unsigned long long int*)address; - unsigned long long int old = (unsigned long long)val, assumed; - do { - assumed = old; - old = atomicCAS(address_as_ull, assumed, math::nd4j_min(val, (Nd4jLong)assumed)); - } while (assumed != old); - return old; +inline __device__ Nd4jLong nd4j_atomicMin(Nd4jLong* address, + Nd4jLong val) { +#if __CUDA_ARCH__ >= 350 + return atomicMin((unsigned long long*)address, (unsigned long long)val); +#else + unsigned long long int* address_as_ull = (unsigned long long int*)address; + unsigned long long int old = (unsigned long long)val, assumed; + do { + assumed = old; + old = atomicCAS(address_as_ull, assumed, + math::nd4j_min(val, (Nd4jLong)assumed)); + } while (assumed != old); + return old; #endif - } template <> -inline __device__ int16_t nd4j_atomicMin(int16_t* address, int16_t val) { - int32_t temp = *address; - *address = atomicMin(&temp, (int)val); - return *address; +inline __device__ int16_t nd4j_atomicMin(int16_t* address, + int16_t val) { + int32_t temp = *address; + *address = atomicMin(&temp, (int)val); + return *address; } template <> -inline __device__ bfloat16 nd4j_atomicMin(bfloat16* address, bfloat16 val) { - return bfloat16(nd4j_atomicMin(&address->_data, val._data)); +inline __device__ bfloat16 nd4j_atomicMin(bfloat16* address, + bfloat16 val) { + return bfloat16(nd4j_atomicMin(&address->_data, val._data)); } template <> -inline __device__ float16 nd4j_atomicMin(float16* address, float16 val) { - return float16(nd4j_atomicMin(reinterpret_cast(&address->data), (int16_t)val.data)); +inline __device__ float16 nd4j_atomicMin(float16* address, + float16 val) { + return float16(nd4j_atomicMin( + reinterpret_cast(&address->data), (int16_t)val.data)); } template <> -inline __device__ int32_t nd4j_atomicMax(int32_t* address, int32_t val) { - return atomicMax(address, val); +inline __device__ int32_t nd4j_atomicMax(int32_t* address, + int32_t val) { + return atomicMax(address, val); } template <> -inline __device__ uint32_t nd4j_atomicMax(uint32_t* address, uint32_t val) { - return atomicMax(address, val); +inline __device__ uint32_t nd4j_atomicMax(uint32_t* address, + uint32_t val) { + return atomicMax(address, val); } template <> -inline __device__ double nd4j_atomicMax(double* address, double val) { - unsigned long long int* address_as_ull = (unsigned long long int*)address; - unsigned long long int old = __double_as_longlong(val), assumed; - do { - assumed = old; - old = atomicCAS(address_as_ull, assumed, __double_as_longlong(math::nd4j_max(val, __longlong_as_double(assumed)))); - } while (assumed != old); - return __longlong_as_double(old); +inline __device__ double nd4j_atomicMax(double* address, double val) { + unsigned long long int* address_as_ull = (unsigned long long int*)address; + unsigned long long int old = __double_as_longlong(val), assumed; + do { + assumed = old; + old = atomicCAS(address_as_ull, assumed, + __double_as_longlong( + math::nd4j_max(val, __longlong_as_double(assumed)))); + } while (assumed != old); + return __longlong_as_double(old); } template <> -inline __device__ float nd4j_atomicMax(float* address, float val) { - int* address_as_ull = (int*)address; - int old = __float_as_int(val), assumed; - do { - assumed = old; - old = atomicCAS(address_as_ull, assumed, __float_as_int(math::nd4j_max(val, __int_as_float(assumed)))); - } while (assumed != old); - return __int_as_float(old); +inline __device__ float nd4j_atomicMax(float* address, float val) { + int* address_as_ull = (int*)address; + int old = __float_as_int(val), assumed; + do { + assumed = old; + old = + atomicCAS(address_as_ull, assumed, + __float_as_int(math::nd4j_max(val, __int_as_float(assumed)))); + } while (assumed != old); + return __int_as_float(old); } template <> -inline __device__ uint8_t nd4j_atomicMin(uint8_t* address, uint8_t val) { - uint32_t temp = *address; - *address = atomicMin(&temp, (uint32_t)val); - return *address; +inline __device__ uint8_t nd4j_atomicMin(uint8_t* address, + uint8_t val) { + uint32_t temp = *address; + *address = atomicMin(&temp, (uint32_t)val); + return *address; } template <> -inline __device__ int8_t nd4j_atomicMin(int8_t* address, int8_t val) { - int32_t temp = *address; - *address = atomicMin(&temp, (int)val); - return *address; +inline __device__ int8_t nd4j_atomicMin(int8_t* address, int8_t val) { + int32_t temp = *address; + *address = atomicMin(&temp, (int)val); + return *address; } template <> -inline __device__ uint16_t nd4j_atomicMin(uint16_t* address, uint16_t val) { - uint32_t temp = *address; - *address = atomicMin(&temp, (uint32_t)val); - return *address; +inline __device__ uint16_t nd4j_atomicMin(uint16_t* address, + uint16_t val) { + uint32_t temp = *address; + *address = atomicMin(&temp, (uint32_t)val); + return *address; } template <> -inline __device__ uint8_t nd4j_atomicMax(uint8_t* address, uint8_t val) { - uint32_t temp = *address; - *address = atomicMax(&temp, (uint32_t)val); - return *address; +inline __device__ uint8_t nd4j_atomicMax(uint8_t* address, + uint8_t val) { + uint32_t temp = *address; + *address = atomicMax(&temp, (uint32_t)val); + return *address; } template <> -inline __device__ int8_t nd4j_atomicMax(int8_t* address, int8_t val) { - int32_t temp = *address; - *address = atomicMax(&temp, (int)val); - return *address; +inline __device__ int8_t nd4j_atomicMax(int8_t* address, int8_t val) { + int32_t temp = *address; + *address = atomicMax(&temp, (int)val); + return *address; } template <> -inline __device__ uint16_t nd4j_atomicMax(uint16_t* address, uint16_t val) { - uint32_t temp = *address; - *address = atomicMax(&temp, (uint32_t)val); - return *address; +inline __device__ uint16_t nd4j_atomicMax(uint16_t* address, + uint16_t val) { + uint32_t temp = *address; + *address = atomicMax(&temp, (uint32_t)val); + return *address; } template <> -inline __device__ int16_t nd4j_atomicMax(int16_t* address, int16_t val) { - int32_t temp = *address; - *address = atomicMax(&temp, (int32_t)val); - return *address; +inline __device__ int16_t nd4j_atomicMax(int16_t* address, + int16_t val) { + int32_t temp = *address; + *address = atomicMax(&temp, (int32_t)val); + return *address; } template <> -inline __device__ float16 nd4j_atomicMax(float16* address, float16 val) { - auto address_as_ull = (int*) address; - - long addr = (long) address; - bool misaligned = addr & 0x3; +inline __device__ float16 nd4j_atomicMax(float16* address, + float16 val) { + auto address_as_ull = (int*)address; - if (misaligned) - address_as_ull = (int *) (address - 1); + long addr = (long)address; + bool misaligned = addr & 0x3; - PAIR old, assumed, fresh; + if (misaligned) address_as_ull = (int*)(address - 1); - old.W = *address_as_ull; - do { + PAIR old, assumed, fresh; - if (!misaligned) { - float16 res = nd4j_max((float16) old.B.H, val); - fresh.B.H = res.data; - fresh.B.L = old.B.L; - } else { - float16 res = nd4j_max((float16) old.B.L, val); - fresh.B.L = res.data; - fresh.B.H = old.B.H; - } + old.W = *address_as_ull; + do { + if (!misaligned) { + float16 res = nd4j_max((float16)old.B.H, val); + fresh.B.H = res.data; + fresh.B.L = old.B.L; + } else { + float16 res = nd4j_max((float16)old.B.L, val); + fresh.B.L = res.data; + fresh.B.H = old.B.H; + } - assumed.W = old.W; - old.W = atomicCAS(address_as_ull, assumed.W, fresh.W); - } while (assumed.W != old.W); + assumed.W = old.W; + old.W = atomicCAS(address_as_ull, assumed.W, fresh.W); + } while (assumed.W != old.W); - if (!misaligned) return old.B.H; - else return old.B.L; + if (!misaligned) + return old.B.H; + else + return old.B.L; } template <> -inline __device__ bfloat16 nd4j_atomicMax(bfloat16* address, bfloat16 val) { - auto address_as_ull = (int*) address; +inline __device__ bfloat16 nd4j_atomicMax(bfloat16* address, + bfloat16 val) { + auto address_as_ull = (int*)address; - long addr = (long)(address); - bool misaligned = addr & 0x3; + long addr = (long)(address); + bool misaligned = addr & 0x3; - if (misaligned) - address_as_ull = (int *) (address - 1); + if (misaligned) address_as_ull = (int*)(address - 1); - BPAIR old, assumed, fresh; + BPAIR old, assumed, fresh; - old.W = *address_as_ull; - do { - - if (!misaligned) { - bfloat16 res = nd4j_max(old.B.H, val); - fresh.B.H = res; - fresh.B.L = old.B.L; - } else { - bfloat16 res = nd4j_max(old.B.L, val); - fresh.B.L = res; - fresh.B.H = old.B.H; - } + old.W = *address_as_ull; + do { + if (!misaligned) { + bfloat16 res = nd4j_max(old.B.H, val); + fresh.B.H = res; + fresh.B.L = old.B.L; + } else { + bfloat16 res = nd4j_max(old.B.L, val); + fresh.B.L = res; + fresh.B.H = old.B.H; + } - assumed.W = old.W; - old.W = atomicCAS(address_as_ull, assumed.W, fresh.W); - } while (assumed.W != old.W); + assumed.W = old.W; + old.W = atomicCAS(address_as_ull, assumed.W, fresh.W); + } while (assumed.W != old.W); - if (!misaligned) return old.B.H; - else return old.B.L; + if (!misaligned) + return old.B.H; + else + return old.B.L; } template <> -inline __device__ uint64_t nd4j_atomicMax(uint64_t* address, uint64_t val) { +inline __device__ uint64_t nd4j_atomicMax(uint64_t* address, + uint64_t val) { #if __CUDA_ARCH__ >= 350 - return atomicMax((unsigned long long*)address, (unsigned long long)val); + return atomicMax((unsigned long long*)address, (unsigned long long)val); #else - unsigned long long int* address_as_ull = (unsigned long long int*)address; - unsigned long long int old = __double_as_longlong(val), assumed; - do { - assumed = old; - old = atomicCAS(address_as_ull, assumed, math::nd4j_max((unsigned long long)val, assumed)); - } while (assumed != old); - return old; + unsigned long long int* address_as_ull = (unsigned long long int*)address; + unsigned long long int old = __double_as_longlong(val), assumed; + do { + assumed = old; + old = atomicCAS(address_as_ull, assumed, + math::nd4j_max((unsigned long long)val, assumed)); + } while (assumed != old); + return old; #endif } template <> -inline __device__ Nd4jLong nd4j_atomicMax(Nd4jLong* address, Nd4jLong val) { - unsigned long long int* address_as_ull = (unsigned long long int *) address; - - //return (Nd4jLong) atomicAdd(address_as_ull, (unsigned long long int) val); - unsigned long long int old = *address_as_ull, assumed; - do { - assumed = old; - old = atomicCAS(address_as_ull, assumed, (unsigned long long)nd4j_max(val, (Nd4jLong)assumed)); - } while (assumed != old); - return old; +inline __device__ Nd4jLong nd4j_atomicMax(Nd4jLong* address, + Nd4jLong val) { + unsigned long long int* address_as_ull = (unsigned long long int*)address; + + // return (Nd4jLong) atomicAdd(address_as_ull, (unsigned long long int) val); + unsigned long long int old = *address_as_ull, assumed; + do { + assumed = old; + old = atomicCAS(address_as_ull, assumed, + (unsigned long long)nd4j_max(val, (Nd4jLong)assumed)); + } while (assumed != old); + return old; } - template <> -inline __device__ double nd4j_atomicAdd(double* address, double val) { - unsigned long long int* address_as_ull = - (unsigned long long int *) address; - unsigned long long int old = *address_as_ull, assumed; - do { - assumed = old; - old = atomicCAS(address_as_ull, assumed,__double_as_longlong(val + - __longlong_as_double(assumed))); - } while (assumed != old); - return __longlong_as_double(old); +inline __device__ double nd4j_atomicAdd(double* address, double val) { + unsigned long long int* address_as_ull = (unsigned long long int*)address; + unsigned long long int old = *address_as_ull, assumed; + do { + assumed = old; + old = atomicCAS(address_as_ull, assumed, + __double_as_longlong(val + __longlong_as_double(assumed))); + } while (assumed != old); + return __longlong_as_double(old); } template <> -inline __device__ Nd4jLong nd4j_atomicAdd(Nd4jLong* address, Nd4jLong val) { - unsigned long long int* address_as_ull = (unsigned long long int *) address; - - //return (Nd4jLong) atomicAdd(address_as_ull, (unsigned long long int) val); - unsigned long long int old = *address_as_ull, assumed; - do { - assumed = old; - old = atomicCAS(address_as_ull, assumed, val + assumed); - } while (assumed != old); - return old; +inline __device__ Nd4jLong nd4j_atomicAdd(Nd4jLong* address, + Nd4jLong val) { + unsigned long long int* address_as_ull = (unsigned long long int*)address; + + // return (Nd4jLong) atomicAdd(address_as_ull, (unsigned long long int) val); + unsigned long long int old = *address_as_ull, assumed; + do { + assumed = old; + old = atomicCAS(address_as_ull, assumed, val + assumed); + } while (assumed != old); + return old; } template <> -inline __device__ long nd4j_atomicAdd(long* address, long val) { - unsigned long long* address_as_ull = (unsigned long long int *) address; - -// return atomicAdd(address, val); - unsigned long int old = *address_as_ull, assumed; - do { - assumed = old; - old = atomicCAS(address_as_ull, assumed, val + assumed); - } while (assumed != old); - return old; +inline __device__ long nd4j_atomicAdd(long* address, long val) { + unsigned long long* address_as_ull = (unsigned long long int*)address; + + // return atomicAdd(address, val); + unsigned long int old = *address_as_ull, assumed; + do { + assumed = old; + old = atomicCAS(address_as_ull, assumed, val + assumed); + } while (assumed != old); + return old; } template <> -inline __device__ uint32_t nd4j_atomicAdd(uint32_t* address, uint32_t val) { - return atomicAdd(address, val); +inline __device__ uint32_t nd4j_atomicAdd(uint32_t* address, + uint32_t val) { + return atomicAdd(address, val); } template <> -inline __device__ uint64_t nd4j_atomicAdd(uint64_t* address, uint64_t val) { -// unsigned long long* address_as_ull = (unsigned long long int *) address; -// -//// return atomicAdd(address, val); -// unsigned long int old = *address_as_ull, assumed; -// do { -// assumed = old; -// old = atomicCAS(address_as_ull, assumed, val + assumed); -// } while (assumed != old); -// return old; - return (uint64_t)atomicAdd((unsigned long long*)address, (unsigned long long)val); +inline __device__ uint64_t nd4j_atomicAdd(uint64_t* address, + uint64_t val) { + // unsigned long long* address_as_ull = (unsigned long long int *) address; + // + //// return atomicAdd(address, val); + // unsigned long int old = *address_as_ull, assumed; + // do { + // assumed = old; + // old = atomicCAS(address_as_ull, assumed, val + assumed); + // } while (assumed != old); + // return old; + return (uint64_t)atomicAdd((unsigned long long*)address, + (unsigned long long)val); } template <> -inline __device__ float16 nd4j_atomicAdd(float16* address, float16 val) { +inline __device__ float16 nd4j_atomicAdd(float16* address, + float16 val) { #if __CUDA_ARCH__ >= 700 && defined(CUDA_10) - atomicAdd(reinterpret_cast<__half*>(address), val.data); + atomicAdd(reinterpret_cast<__half*>(address), val.data); #else - auto address_as_ull = (int*) address; + auto address_as_ull = (int*)address; - long addr = (long) address; - bool misaligned = addr & 0x3; + long addr = (long)address; + bool misaligned = addr & 0x3; - if (misaligned) - address_as_ull = (int *) (address - 1); + if (misaligned) address_as_ull = (int*)(address - 1); - PAIR old, assumed, fresh; + PAIR old, assumed, fresh; - old.W = *address_as_ull; - do { - - if (!misaligned) { - float16 res = ((float16) old.B.H) + val; - fresh.B.H = res.data; - fresh.B.L = old.B.L; - } else { - float16 res = ((float16) old.B.L) + val; - fresh.B.L = res.data; - fresh.B.H = old.B.H; - } + old.W = *address_as_ull; + do { + if (!misaligned) { + float16 res = ((float16)old.B.H) + val; + fresh.B.H = res.data; + fresh.B.L = old.B.L; + } else { + float16 res = ((float16)old.B.L) + val; + fresh.B.L = res.data; + fresh.B.H = old.B.H; + } - assumed.W = old.W; - old.W = atomicCAS(address_as_ull, assumed.W, fresh.W); - } while (assumed.W != old.W); + assumed.W = old.W; + old.W = atomicCAS(address_as_ull, assumed.W, fresh.W); + } while (assumed.W != old.W); - if (!misaligned) return old.B.H; - else return old.B.L; + if (!misaligned) + return old.B.H; + else + return old.B.L; #endif } template <> -inline __device__ bfloat16 nd4j_atomicAdd(bfloat16* address, bfloat16 val) { - auto address_as_ull = (int*) address; - - auto addr = (long)(address); - bool misaligned = addr & 0x3; +inline __device__ bfloat16 nd4j_atomicAdd(bfloat16* address, + bfloat16 val) { + auto address_as_ull = (int*)address; - if (misaligned) - address_as_ull = (int *) (address - 1); + auto addr = (long)(address); + bool misaligned = addr & 0x3; - BPAIR old, assumed, fresh; + if (misaligned) address_as_ull = (int*)(address - 1); - old.W = *address_as_ull; - do { + BPAIR old, assumed, fresh; - if (!misaligned) { - bfloat16 res = old.B.H + val; - fresh.B.H = res; - fresh.B.L = old.B.L; - } else { - bfloat16 res = old.B.L + val; - fresh.B.L = res; - fresh.B.H = old.B.H; - } + old.W = *address_as_ull; + do { + if (!misaligned) { + bfloat16 res = old.B.H + val; + fresh.B.H = res; + fresh.B.L = old.B.L; + } else { + bfloat16 res = old.B.L + val; + fresh.B.L = res; + fresh.B.H = old.B.H; + } - assumed.W = old.W; - old.W = atomicCAS(address_as_ull, assumed.W, fresh.W); - } while (assumed.W != old.W); + assumed.W = old.W; + old.W = atomicCAS(address_as_ull, assumed.W, fresh.W); + } while (assumed.W != old.W); - if (!misaligned) return old.B.H; - else return old.B.L; + if (!misaligned) + return old.B.H; + else + return old.B.L; } template static inline __device__ T internal_16bit_atomicAdd(T* address, T val) { - size_t shift = ((size_t)address & 2); - int *base_address = (int *)((char*)address - shift); - - union I16PAIR { - struct { - T H; - T L; - } B; - int W; + size_t shift = ((size_t)address & 2); + int* base_address = (int*)((char*)address - shift); - __host__ __device__ - I16PAIR() {}; + union I16PAIR { + struct { + T H; + T L; + } B; + int W; - __host__ __device__ - ~I16PAIR() {}; - }; + __host__ __device__ I16PAIR(){}; - I16PAIR pairNew, pairOld, pairAssumed; + __host__ __device__ ~I16PAIR(){}; + }; - if (reinterpret_cast(address) == base_address) { - pairOld.B.L = val; - do { + I16PAIR pairNew, pairOld, pairAssumed; - pairNew.B.L = pairOld.B.L; - pairNew.B.H = pairOld.B.H + val; - pairAssumed.W = pairOld.W; - - pairOld.W = atomicCAS(base_address, pairAssumed.W, pairNew.W); - } while (pairAssumed.W != pairOld.W); + if (reinterpret_cast(address) == base_address) { + pairOld.B.L = val; + do { + pairNew.B.L = pairOld.B.L; + pairNew.B.H = pairOld.B.H + val; + pairAssumed.W = pairOld.W; - return (T) pairOld.B.H; - } else { - pairOld.B.H = val; - do { + pairOld.W = atomicCAS(base_address, pairAssumed.W, pairNew.W); + } while (pairAssumed.W != pairOld.W); - pairNew.B.H = pairOld.B.H; - pairNew.B.L = pairOld.B.L + val; - pairAssumed.W = pairOld.W; - pairOld.W = atomicCAS(base_address, pairAssumed.W, pairNew.W); + return (T)pairOld.B.H; + } else { + pairOld.B.H = val; + do { + pairNew.B.H = pairOld.B.H; + pairNew.B.L = pairOld.B.L + val; + pairAssumed.W = pairOld.W; + pairOld.W = atomicCAS(base_address, pairAssumed.W, pairNew.W); - } while (pairAssumed.W != pairOld.W); - - return (T) pairOld.B.L; - } + } while (pairAssumed.W != pairOld.W); + return (T)pairOld.B.L; + } } template <> -inline __device__ int16_t nd4j_atomicAdd(int16_t* address, int16_t val) { - return internal_16bit_atomicAdd(address, val); +inline __device__ int16_t nd4j_atomicAdd(int16_t* address, + int16_t val) { + return internal_16bit_atomicAdd(address, val); } template <> -inline __device__ uint16_t nd4j_atomicAdd(uint16_t* address, uint16_t val) { - return internal_16bit_atomicAdd(address, val); +inline __device__ uint16_t nd4j_atomicAdd(uint16_t* address, + uint16_t val) { + return internal_16bit_atomicAdd(address, val); } template <> -inline __device__ int8_t nd4j_atomicAdd(int8_t* address, int8_t val) { - int res = *address; - atomicAdd(&res, (int)val); - *address = res; - return *address; +inline __device__ int8_t nd4j_atomicAdd(int8_t* address, int8_t val) { + int res = *address; + atomicAdd(&res, (int)val); + *address = res; + return *address; } template <> -inline __device__ uint8_t nd4j_atomicAdd(uint8_t* address, uint8_t val) { - int res = *address; - atomicAdd(&res, (int)val); - *address = res; - return *address; +inline __device__ uint8_t nd4j_atomicAdd(uint8_t* address, + uint8_t val) { + int res = *address; + atomicAdd(&res, (int)val); + *address = res; + return *address; } template <> -inline __device__ bool nd4j_atomicAdd(bool* address, bool val) { - *address += (val); - return *address; +inline __device__ bool nd4j_atomicAdd(bool* address, bool val) { + *address += (val); + return *address; } template <> -inline __device__ double nd4j_atomicSub(double* address, double val) { - return nd4j_atomicAdd(address, -val); +inline __device__ double nd4j_atomicSub(double* address, double val) { + return nd4j_atomicAdd(address, -val); } template <> -inline __device__ double nd4j_atomicMul(double* address, double val) { - unsigned long long int* address_as_ull = - (unsigned long long int*) address; - unsigned long long int old = *address_as_ull, assumed; - do { - assumed = old; - old = atomicCAS(address_as_ull, assumed,__double_as_longlong(val * - __longlong_as_double(assumed))); - } while (assumed != old); - return __longlong_as_double(old); +inline __device__ double nd4j_atomicMul(double* address, double val) { + unsigned long long int* address_as_ull = (unsigned long long int*)address; + unsigned long long int old = *address_as_ull, assumed; + do { + assumed = old; + old = atomicCAS(address_as_ull, assumed, + __double_as_longlong(val * __longlong_as_double(assumed))); + } while (assumed != old); + return __longlong_as_double(old); } template <> -inline __device__ double nd4j_atomicDiv(double* address, double val) { - return nd4j_atomicMul(address, 1./val); +inline __device__ double nd4j_atomicDiv(double* address, double val) { + return nd4j_atomicMul(address, 1. / val); } template <> -inline __device__ float nd4j_atomicAdd(float* address, float val) { - return atomicAdd(address,val); +inline __device__ float nd4j_atomicAdd(float* address, float val) { + return atomicAdd(address, val); } -//template <> -//inline __device__ int nd4j_atomicAdd(int* address, int val) { +// template <> +// inline __device__ int nd4j_atomicAdd(int* address, int val) { // return atomicAdd(address, val); //} template <> -inline __device__ int32_t nd4j_atomicAdd(int32_t* address, int32_t val) { - return (int32_t)atomicAdd((int*)address, (int)val); +inline __device__ int32_t nd4j_atomicAdd(int32_t* address, + int32_t val) { + return (int32_t)atomicAdd((int*)address, (int)val); } - template <> inline __device__ float nd4j_atomicSub(float* address, float val) { - return nd4j_atomicAdd(address, -val); + return nd4j_atomicAdd(address, -val); } template <> -inline __device__ float16 nd4j_atomicSub(float16* address, float16 val) { - return nd4j_atomicAdd(address, -val); +inline __device__ float16 nd4j_atomicSub(float16* address, + float16 val) { + return nd4j_atomicAdd(address, -val); } template <> -inline __device__ bfloat16 nd4j_atomicSub(bfloat16* address, bfloat16 val) { - return nd4j_atomicAdd(address, -val); +inline __device__ bfloat16 nd4j_atomicSub(bfloat16* address, + bfloat16 val) { + return nd4j_atomicAdd(address, -val); } template <> inline __device__ float nd4j_atomicMul(float* address, float val) { - int* address_as_ull = - ( int*)address; - int old = *address_as_ull, assumed; - do { - assumed = old; - old = atomicCAS(address_as_ull, assumed, __float_as_int(val * - __int_as_float(assumed))); - } while (assumed != old); - return __int_as_float(old); + int* address_as_ull = (int*)address; + int old = *address_as_ull, assumed; + do { + assumed = old; + old = atomicCAS(address_as_ull, assumed, + __float_as_int(val * __int_as_float(assumed))); + } while (assumed != old); + return __int_as_float(old); } template <> inline __device__ int8_t nd4j_atomicMul(int8_t* address, int8_t val) { - unsigned int *base_address = (unsigned int *)((size_t)address & ~3); - unsigned int selectors[] = {0x3214, 0x3240, 0x3410, 0x4210}; - unsigned int sel = selectors[(size_t)address & 3]; - unsigned int old, assumed, mul, new_; + unsigned int* base_address = (unsigned int*)((size_t)address & ~3); + unsigned int selectors[] = {0x3214, 0x3240, 0x3410, 0x4210}; + unsigned int sel = selectors[(size_t)address & 3]; + unsigned int old, assumed, mul, new_; - old = *base_address; + old = *base_address; - do { + do { + assumed = old; + mul = val * (int8_t)__byte_perm(old, 0, ((size_t)address & 3) | 0x4440); + new_ = __byte_perm(old, mul, sel); - assumed = old; - mul = val * (int8_t)__byte_perm(old, 0, ((size_t)address & 3) | 0x4440); - new_ = __byte_perm(old, mul, sel); + if (new_ == old) break; - if (new_ == old) - break; - - old = atomicCAS(base_address, assumed, new_); - } while (assumed != old); - return (int8_t)old; + old = atomicCAS(base_address, assumed, new_); + } while (assumed != old); + return (int8_t)old; } template <> -inline __device__ unsigned char nd4j_atomicMul(unsigned char* address, unsigned char val) { - unsigned int *base_address = (unsigned int *)((size_t)address & ~3); - unsigned int selectors[] = {0x3214, 0x3240, 0x3410, 0x4210}; - unsigned int sel = selectors[(size_t)address & 3]; - unsigned int old, assumed, mul, new_; - - old = *base_address; +inline __device__ unsigned char nd4j_atomicMul( + unsigned char* address, unsigned char val) { + unsigned int* base_address = (unsigned int*)((size_t)address & ~3); + unsigned int selectors[] = {0x3214, 0x3240, 0x3410, 0x4210}; + unsigned int sel = selectors[(size_t)address & 3]; + unsigned int old, assumed, mul, new_; - do { + old = *base_address; - assumed = old; - mul = val * (uint8_t)__byte_perm(old, 0, ((size_t)address & 3) | 0x4440); - new_ = __byte_perm(old, mul, sel); + do { + assumed = old; + mul = val * (uint8_t)__byte_perm(old, 0, ((size_t)address & 3) | 0x4440); + new_ = __byte_perm(old, mul, sel); - if (new_ == old) - break; + if (new_ == old) break; - old = atomicCAS(base_address, assumed, new_); - } while (assumed != old); - return (uint8_t)old; + old = atomicCAS(base_address, assumed, new_); + } while (assumed != old); + return (uint8_t)old; } template static inline __device__ T internal_16bit_atomicMul(T* address, T val) { - size_t shift = ((size_t)address & 2); - int *base_address = (int *)((char*)address - shift); + size_t shift = ((size_t)address & 2); + int* base_address = (int*)((char*)address - shift); - union I16PAIR { - struct { - T H; - T L; - } B; - int W; + union I16PAIR { + struct { + T H; + T L; + } B; + int W; - __host__ __device__ - I16PAIR() {}; + __host__ __device__ I16PAIR(){}; - __host__ __device__ - ~I16PAIR() {}; - }; + __host__ __device__ ~I16PAIR(){}; + }; - I16PAIR pairNew, pairOld, pairAssumed; + I16PAIR pairNew, pairOld, pairAssumed; - if (reinterpret_cast(address) == base_address) { - pairOld.B.L = val; - do { - pairNew.B.L = pairOld.B.L; - pairNew.B.H = pairOld.B.H * val; - pairAssumed.W = pairOld.W; + if (reinterpret_cast(address) == base_address) { + pairOld.B.L = val; + do { + pairNew.B.L = pairOld.B.L; + pairNew.B.H = pairOld.B.H * val; + pairAssumed.W = pairOld.W; - pairOld.W = atomicCAS(base_address, pairAssumed.W, pairNew.W); - } while (pairAssumed.W != pairOld.W); + pairOld.W = atomicCAS(base_address, pairAssumed.W, pairNew.W); + } while (pairAssumed.W != pairOld.W); - return (T) pairOld.B.H; - } else { - pairOld.B.H = val; - do { - pairNew.B.H = pairOld.B.H; - pairNew.B.L = pairOld.B.L * val; - pairAssumed.W = pairOld.W; - pairOld.W = atomicCAS(base_address, pairAssumed.W, pairNew.W); + return (T)pairOld.B.H; + } else { + pairOld.B.H = val; + do { + pairNew.B.H = pairOld.B.H; + pairNew.B.L = pairOld.B.L * val; + pairAssumed.W = pairOld.W; + pairOld.W = atomicCAS(base_address, pairAssumed.W, pairNew.W); - } while (pairAssumed.W != pairOld.W); + } while (pairAssumed.W != pairOld.W); - return (T) pairOld.B.L; - } + return (T)pairOld.B.L; + } } template <> -inline __device__ int16_t nd4j_atomicMul(int16_t* address, int16_t val) { - return internal_16bit_atomicMul(address, val); +inline __device__ int16_t nd4j_atomicMul(int16_t* address, + int16_t val) { + return internal_16bit_atomicMul(address, val); } template <> -inline __device__ uint16_t nd4j_atomicMul(uint16_t* address, uint16_t val) { - return internal_16bit_atomicMul(address, val); +inline __device__ uint16_t nd4j_atomicMul(uint16_t* address, + uint16_t val) { + return internal_16bit_atomicMul(address, val); } template <> inline __device__ int nd4j_atomicMul(int* address, int val) { - int* res_address = address; - int old = *res_address, assumed; - do { - assumed = old; - old = atomicCAS(res_address, assumed, val * assumed); - } while (assumed != old); - return old; + int* res_address = address; + int old = *res_address, assumed; + do { + assumed = old; + old = atomicCAS(res_address, assumed, val * assumed); + } while (assumed != old); + return old; } template <> -inline __device__ unsigned int nd4j_atomicMul(unsigned int* address, unsigned int val) { - unsigned int* res_address = address; - unsigned int old = *res_address, assumed; - do { - assumed = old; - old = atomicCAS(res_address, assumed, val * assumed); - } while (assumed != old); - return old; +inline __device__ unsigned int nd4j_atomicMul( + unsigned int* address, unsigned int val) { + unsigned int* res_address = address; + unsigned int old = *res_address, assumed; + do { + assumed = old; + old = atomicCAS(res_address, assumed, val * assumed); + } while (assumed != old); + return old; } template <> -inline __device__ int64_t nd4j_atomicMul(int64_t* address, int64_t val) { - unsigned long long int* res_address = (unsigned long long int*)address; - unsigned long long int old = *res_address, assumed; - do { - assumed = old; - old = atomicCAS(res_address, assumed, val * assumed); - } while (assumed != old); - return (int64_t)old; +inline __device__ int64_t nd4j_atomicMul(int64_t* address, + int64_t val) { + unsigned long long int* res_address = (unsigned long long int*)address; + unsigned long long int old = *res_address, assumed; + do { + assumed = old; + old = atomicCAS(res_address, assumed, val * assumed); + } while (assumed != old); + return (int64_t)old; } template <> -inline __device__ uint64_t nd4j_atomicMul(uint64_t* address, uint64_t val) { - unsigned long long int* res_address = (unsigned long long int*)address; - unsigned long long int old = *res_address, assumed; - do { - assumed = old; - old = atomicCAS(res_address, assumed, val * assumed); - } while (assumed != old); - return (uint64_t)old; +inline __device__ uint64_t nd4j_atomicMul(uint64_t* address, + uint64_t val) { + unsigned long long int* res_address = (unsigned long long int*)address; + unsigned long long int old = *res_address, assumed; + do { + assumed = old; + old = atomicCAS(res_address, assumed, val * assumed); + } while (assumed != old); + return (uint64_t)old; } #if !defined(_WIN32) && !defined(_WIN64) template <> -inline __device__ Nd4jLong nd4j_atomicMul(Nd4jLong* address, Nd4jLong val) { - unsigned long long int* res_address = (unsigned long long*)address; - unsigned long long int old = *res_address, assumed; - do { - assumed = old; - old = atomicCAS(res_address, assumed, val * assumed); - } while (assumed != old); - return (Nd4jLong)old; +inline __device__ Nd4jLong nd4j_atomicMul(Nd4jLong* address, + Nd4jLong val) { + unsigned long long int* res_address = (unsigned long long*)address; + unsigned long long int old = *res_address, assumed; + do { + assumed = old; + old = atomicCAS(res_address, assumed, val * assumed); + } while (assumed != old); + return (Nd4jLong)old; } #endif template <> -inline __device__ bfloat16 nd4j_atomicMul(bfloat16* address, bfloat16 val) { - return internal_16bit_atomicMul(address, val); +inline __device__ bfloat16 nd4j_atomicMul(bfloat16* address, + bfloat16 val) { + return internal_16bit_atomicMul(address, val); } template <> -inline __device__ float16 nd4j_atomicMul(float16* address, float16 val) { - return internal_16bit_atomicMul(address, val); +inline __device__ float16 nd4j_atomicMul(float16* address, + float16 val) { + return internal_16bit_atomicMul(address, val); } template <> inline __device__ float nd4j_atomicDiv(float* address, float val) { - return nd4j_atomicMul(address, 1.f / val); + return nd4j_atomicMul(address, 1.f / val); } template <> -inline __device__ float16 nd4j_atomicDiv(float16* address, float16 val) { - return internal_16bit_atomicMul(address, (float16) 1.f / val); +inline __device__ float16 nd4j_atomicDiv(float16* address, + float16 val) { + return internal_16bit_atomicMul(address, (float16)1.f / val); } template <> -inline __device__ bfloat16 nd4j_atomicDiv(bfloat16* address, bfloat16 val) { - return internal_16bit_atomicMul(address, (bfloat16) 1 / val); -} +inline __device__ bfloat16 nd4j_atomicDiv(bfloat16* address, + bfloat16 val) { + return internal_16bit_atomicMul(address, (bfloat16)1 / val); } +} // namespace atomics #endif - } -} +} // namespace math +} // namespace sd #ifdef _OPENMP @@ -1704,39 +1718,60 @@ inline __device__ bfloat16 nd4j_atomicDiv(bfloat16* address, bfloat16 #define MAX_FLOAT 1e37 #endif -#pragma omp declare reduction(maxTF : float,double,float16,bfloat16 : \ - omp_out = sd::math::nd4j_max(omp_in, omp_out) )\ - initializer (omp_priv=-MAX_FLOAT) - -#pragma omp declare reduction(minTF : float,double,float16,bfloat16 : \ - omp_out = sd::math::nd4j_min(omp_in, omp_out) )\ - initializer (omp_priv=MAX_FLOAT) - -#pragma omp declare reduction(maxT : float,double,float16,bfloat16,int,Nd4jLong,Nd4jULong,int8_t,uint8_t,bool,int16_t,uint16_t,uint32_t : \ - omp_out = sd::math::nd4j_max(omp_in, omp_out) )\ - initializer (omp_priv=0) - -#pragma omp declare reduction(minT : float,double,float16,bfloat16,int,Nd4jLong,Nd4jULong,int8_t,uint8_t,bool,int16_t,uint16_t,uint32_t : \ - omp_out = sd::math::nd4j_min(omp_in, omp_out) )\ - initializer (omp_priv=0) - -#pragma omp declare reduction(amaxT : float,double,float16,bfloat16,int,Nd4jLong,Nd4jULong,int8_t,uint8_t,bool,int16_t,uint16_t,uint32_t : \ - omp_out = sd::math::nd4j_max(sd::math::nd4j_abs(omp_in), sd::math::nd4j_abs(omp_out)) ) - -#pragma omp declare reduction(aminT : float,double,float16,bfloat16,int,Nd4jLong,Nd4jULong,int8_t,uint8_t,bool,int16_t,uint16_t,uint32_t : \ - omp_out = sd::math::nd4j_min(sd::math::nd4j_abs(omp_in), sd::math::nd4j_abs(omp_out)) ) - -#pragma omp declare reduction(asumT : float,double,float16,bfloat16,int,Nd4jLong,Nd4jULong,int8_t,uint8_t,bool,int16_t,uint16_t,uint32_t : \ - omp_out = sd::math::nd4j_abs(omp_in) + sd::math::nd4j_abs(omp_out))\ - initializer (omp_priv=0) - -#pragma omp declare reduction(sumT : float,double,float16,bfloat16,int,Nd4jLong,Nd4jULong,int8_t,uint8_t,bool,int16_t,uint16_t,uint32_t : \ - omp_out = omp_in + omp_out)\ - initializer (omp_priv=0) - -#pragma omp declare reduction(prodT : float,double,float16,bfloat16,int,Nd4jLong,Nd4jULong,int8_t,uint8_t,bool,int16_t,uint16_t,uint32_t : \ - omp_out = omp_in * omp_out)\ - initializer (omp_priv=1) +#pragma omp declare reduction(maxTF \ + : float, double, float16, bfloat16 \ + : omp_out = sd::math::nd4j_max(omp_in, omp_out)) \ + initializer(omp_priv = -MAX_FLOAT) + +#pragma omp declare reduction(minTF \ + : float, double, float16, bfloat16 \ + : omp_out = sd::math::nd4j_min(omp_in, omp_out)) \ + initializer(omp_priv = MAX_FLOAT) + +#pragma omp declare reduction( \ + maxT \ + : float, double, float16, bfloat16, int, Nd4jLong, Nd4jULong, int8_t, \ + uint8_t, bool, int16_t, uint16_t, uint32_t \ + : omp_out = sd::math::nd4j_max(omp_in, omp_out)) initializer(omp_priv = 0) + +#pragma omp declare reduction( \ + minT \ + : float, double, float16, bfloat16, int, Nd4jLong, Nd4jULong, int8_t, \ + uint8_t, bool, int16_t, uint16_t, uint32_t \ + : omp_out = sd::math::nd4j_min(omp_in, omp_out)) initializer(omp_priv = 0) + +#pragma omp declare reduction( \ + amaxT \ + : float, double, float16, bfloat16, int, Nd4jLong, Nd4jULong, int8_t, \ + uint8_t, bool, int16_t, uint16_t, uint32_t \ + : omp_out = sd::math::nd4j_max(sd::math::nd4j_abs(omp_in), \ + sd::math::nd4j_abs(omp_out))) + +#pragma omp declare reduction( \ + aminT \ + : float, double, float16, bfloat16, int, Nd4jLong, Nd4jULong, int8_t, \ + uint8_t, bool, int16_t, uint16_t, uint32_t \ + : omp_out = sd::math::nd4j_min(sd::math::nd4j_abs(omp_in), \ + sd::math::nd4j_abs(omp_out))) + +#pragma omp declare reduction( \ + asumT \ + : float, double, float16, bfloat16, int, Nd4jLong, Nd4jULong, int8_t, \ + uint8_t, bool, int16_t, uint16_t, uint32_t \ + : omp_out = sd::math::nd4j_abs(omp_in) + sd::math::nd4j_abs(omp_out)) \ + initializer(omp_priv = 0) + +#pragma omp declare reduction( \ + sumT \ + : float, double, float16, bfloat16, int, Nd4jLong, Nd4jULong, int8_t, \ + uint8_t, bool, int16_t, uint16_t, uint32_t \ + : omp_out = omp_in + omp_out) initializer(omp_priv = 0) + +#pragma omp declare reduction( \ + prodT \ + : float, double, float16, bfloat16, int, Nd4jLong, Nd4jULong, int8_t, \ + uint8_t, bool, int16_t, uint16_t, uint32_t \ + : omp_out = omp_in * omp_out) initializer(omp_priv = 1) #endif diff --git a/libnd4j/include/memory/AllocationEntry.h b/libnd4j/include/memory/AllocationEntry.h index 815a5c992f2b..13ec15e96c0e 100644 --- a/libnd4j/include/memory/AllocationEntry.h +++ b/libnd4j/include/memory/AllocationEntry.h @@ -18,33 +18,34 @@ // Created by raver119 on 07.05.19. // -#ifndef DEV_TESTS_ALLOCATIONENTRY_H -#define DEV_TESTS_ALLOCATIONENTRY_H +#ifndef SD_ALLOCATIONENTRY_H +#define SD_ALLOCATIONENTRY_H +#include #include + #include -#include namespace sd { - namespace memory { - class AllocationEntry { - private: - MemoryType _memoryType; - Nd4jLong _pointer; - Nd4jLong _numBytes; - std::string _stack; - public: - AllocationEntry() = default; - AllocationEntry(MemoryType type, Nd4jLong ptr, Nd4jLong numBytes, std::string &stack); - ~AllocationEntry() = default; - - - Nd4jLong numBytes(); - std::string stackTrace(); - MemoryType memoryType(); - }; - } -} - - -#endif //DEV_TESTS_ALLOCATIONENTRY_H +namespace memory { +class AllocationEntry { + private: + MemoryType _memoryType; + Nd4jLong _pointer; + Nd4jLong _numBytes; + std::string _stack; + + public: + AllocationEntry() = default; + AllocationEntry(MemoryType type, Nd4jLong ptr, Nd4jLong numBytes, + std::string &stack); + ~AllocationEntry() = default; + + Nd4jLong numBytes(); + std::string stackTrace(); + MemoryType memoryType(); +}; +} // namespace memory +} // namespace sd + +#endif // SD_ALLOCATIONENTRY_H diff --git a/libnd4j/include/memory/ColdZoneManager.h b/libnd4j/include/memory/ColdZoneManager.h new file mode 100644 index 000000000000..cb4b3f1ab936 --- /dev/null +++ b/libnd4j/include/memory/ColdZoneManager.h @@ -0,0 +1,53 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#ifndef SD_COLDZONEMANAGER_H +#define SD_COLDZONEMANAGER_H + +#include + +namespace sd { +namespace memory { +class ColdZoneManager : public ZoneManager { + public: + /** + * This constructor is used to initialize ZoneManager with existing + * FlatBuffers file + * @param filename - full path to existing file (i.e. FlatBuffers file) + */ + explicit ColdZoneManager(const char *filename); + + ColdZoneManager() = default; + ~ColdZoneManager() = default; + + MemoryZone zone() const override; + + uint64_t available() const override; + + uint64_t used() const override; + + MemoryDescriptor allocate(uint64_t numBytes) override; + + void release(MemoryDescriptor &descriptor) override; +}; +} // namespace memory +} // namespace sd + +#endif // SD_COLDZONEMANAGER_H diff --git a/libnd4j/include/memory/ExternalWorkspace.h b/libnd4j/include/memory/ExternalWorkspace.h index 772afc6082ae..25fee4075ca0 100644 --- a/libnd4j/include/memory/ExternalWorkspace.h +++ b/libnd4j/include/memory/ExternalWorkspace.h @@ -21,31 +21,33 @@ #ifndef LIBND4J_EXTERNALWORKSPACE_H #define LIBND4J_EXTERNALWORKSPACE_H -#include #include +#include namespace sd { - namespace memory { - class ND4J_EXPORT ExternalWorkspace { - private: - void *_ptrH = nullptr; - void *_ptrD = nullptr; - - Nd4jLong _sizeH = 0L; - Nd4jLong _sizeD = 0L; - public: - ExternalWorkspace() = default; - ~ExternalWorkspace() = default; - - ExternalWorkspace(Nd4jPointer ptrH, Nd4jLong sizeH, Nd4jPointer ptrD, Nd4jLong sizeD); - - void *pointerHost(); - void *pointerDevice(); - - Nd4jLong sizeHost(); - Nd4jLong sizeDevice(); - }; - } -} +namespace memory { +class SD_EXPORT ExternalWorkspace { + private: + void *_ptrH = nullptr; + void *_ptrD = nullptr; + + Nd4jLong _sizeH = 0L; + Nd4jLong _sizeD = 0L; + + public: + ExternalWorkspace() = default; + ~ExternalWorkspace() = default; + + ExternalWorkspace(Nd4jPointer ptrH, Nd4jLong sizeH, Nd4jPointer ptrD, + Nd4jLong sizeD); + + void *pointerHost(); + void *pointerDevice(); + + Nd4jLong sizeHost(); + Nd4jLong sizeDevice(); +}; +} // namespace memory +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/memory/GraphMemoryManager.h b/libnd4j/include/memory/GraphMemoryManager.h new file mode 100644 index 000000000000..98a85985996f --- /dev/null +++ b/libnd4j/include/memory/GraphMemoryManager.h @@ -0,0 +1,69 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#ifndef SD_GRAPHMEMORYMANAGER_H +#define SD_GRAPHMEMORYMANAGER_H + +#include +#include +#include +#include +#include +#include +#include + +using namespace sd::memory; + +namespace sd { +namespace graph { +class GraphMemoryManager { + protected: + std::map _zones; + + mutable std::vector> _attached; + public: + GraphMemoryManager(); + ~GraphMemoryManager(); + + /** + * This method does allocation (probably) and returns structure that describes + * it + * @param numBytes - number of bytes to be allocated + * @param zone - memory zone for allocation + * @return + */ + virtual MemoryDescriptor allocate(size_t numBytes, MemoryZone zone); + + /** + * This method releases (probably) memory chunk described by given descriptor + * @param descriptor + */ + virtual void release(MemoryDescriptor &descriptor); + + /** + * This method allows to store reference to certain memory regions and keep them alive as long as Graph is alive + * @param ptr + */ + virtual void track(const std::shared_ptr &ptr) const; +}; +} // namespace graph +} // namespace sd + +#endif // SD_GRAPHMEMORYMANAGER_H diff --git a/libnd4j/include/memory/HotRamZoneManager.h b/libnd4j/include/memory/HotRamZoneManager.h new file mode 100644 index 000000000000..f733e0f4a8ec --- /dev/null +++ b/libnd4j/include/memory/HotRamZoneManager.h @@ -0,0 +1,40 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#ifndef SD_HOTRAMZONEMANAGER_H +#define SD_HOTRAMZONEMANAGER_H + +#include + +namespace sd { +namespace memory { +class HotRamZoneManager : public HotZoneManager { + public: + HotRamZoneManager() = default; + ~HotRamZoneManager() = default; + + MemoryDescriptor allocate(uint64_t numBytes) override; + + void release(MemoryDescriptor &descriptor) override; +}; +} // namespace memory +} // namespace sd + +#endif // SD_HOTRAMZONEMANAGER_H diff --git a/libnd4j/include/memory/HotZoneManager.h b/libnd4j/include/memory/HotZoneManager.h new file mode 100644 index 000000000000..5499e488fc59 --- /dev/null +++ b/libnd4j/include/memory/HotZoneManager.h @@ -0,0 +1,52 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#ifndef SD_HOTZONEMANAGER_H +#define SD_HOTZONEMANAGER_H + +#include + +#include + +namespace sd { +namespace memory { +class SD_EXPORT HotZoneManager : public ZoneManager { + protected: + std::atomic _used = {0}; + std::atomic _available = {0}; + + public: + HotZoneManager() = default; + ~HotZoneManager() = default; + + MemoryZone zone() const override; + + uint64_t available() const override; + + uint64_t used() const override; + + virtual MemoryDescriptor allocate(uint64_t numBytes) override = 0; + + virtual void release(MemoryDescriptor &descriptor) override = 0; +}; +} // namespace memory +} // namespace sd + +#endif // SD_HOTZONEMANAGER_H diff --git a/libnd4j/include/memory/MemoryCounter.h b/libnd4j/include/memory/MemoryCounter.h index 160c243798fe..9675f62534ee 100644 --- a/libnd4j/include/memory/MemoryCounter.h +++ b/libnd4j/include/memory/MemoryCounter.h @@ -21,124 +21,130 @@ #ifndef SD_MEMORYCOUNTER_H #define SD_MEMORYCOUNTER_H -#include +#include #include +#include + #include -#include #include namespace sd { - namespace memory { - /** - * This class provides simple per-device counter - */ - class ND4J_EXPORT MemoryCounter { - private: - // used for synchronization - std::mutex _locker; - - // per-device counters - std::map _deviceCounters; - - // TODO: change this wrt heterogenous stuff on next iteration - // per-group counters - std::map _groupCounters; - - // per-device limits - std::map _deviceLimits; - - // per-group limits - std::map _groupLimits; - - MemoryCounter(); - ~MemoryCounter() = default; - - public: - static MemoryCounter & getInstance(); - - /** - * This method checks if allocation of numBytes won't break through per-group or per-device limit - * @param numBytes - * @return TRUE if allocated ammount will keep us below limit, FALSE otherwise - */ - bool validate(Nd4jLong numBytes); - - /** - * This method checks if allocation of numBytes won't break through per-device limit - * @param deviceId - * @param numBytes - * @return TRUE if allocated ammount will keep us below limit, FALSE otherwise - */ - bool validateDevice(int deviceId, Nd4jLong numBytes); - - /** - * This method checks if allocation of numBytes won't break through per-group limit - * @param deviceId - * @param numBytes - * @return TRUE if allocated ammount will keep us below limit, FALSE otherwise - */ - bool validateGroup(sd::memory::MemoryType group, Nd4jLong numBytes); - - /** - * This method adds specified number of bytes to specified counter - * @param deviceId - * @param numBytes - */ - void countIn(int deviceId, Nd4jLong numBytes); - void countIn(sd::memory::MemoryType group, Nd4jLong numBytes); - - /** - * This method subtracts specified number of bytes from specified counter - * @param deviceId - * @param numBytes - */ - void countOut(int deviceId, Nd4jLong numBytes); - void countOut(sd::memory::MemoryType group, Nd4jLong numBytes); - - /** - * This method returns amount of memory allocated on specified device - * @param deviceId - * @return - */ - Nd4jLong allocatedDevice(int deviceId); - - /** - * This method returns amount of memory allocated in specified group of devices - * @param group - * @return - */ - Nd4jLong allocatedGroup(sd::memory::MemoryType group); - - /** - * This method allows to set per-device memory limits - * @param deviceId - * @param numBytes - */ - void setDeviceLimit(int deviceId, Nd4jLong numBytes); - - /** - * This method returns current device limit in bytes - * @param deviceId - * @return - */ - Nd4jLong deviceLimit(int deviceId); - - /** - * This method allows to set per-group memory limits - * @param group - * @param numBytes - */ - void setGroupLimit(sd::memory::MemoryType group, Nd4jLong numBytes); - - /** - * This method returns current group limit in bytes - * @param group - * @return - */ - Nd4jLong groupLimit(sd::memory::MemoryType group); - }; - } -} - - -#endif //SD_MEMORYCOUNTER_H +namespace memory { +/** + * This class provides simple per-device counter + */ +class SD_EXPORT MemoryCounter { + private: + + + // used for synchronization + std::mutex _locker; + + // per-device counters + std::map _deviceCounters; + + // TODO: change this wrt heterogenous stuff on next iteration + // per-group counters + std::map _groupCounters; + + // per-device limits + std::map _deviceLimits; + + // per-group limits + std::map _groupLimits; + + MemoryCounter(); + ~MemoryCounter() = default; + + public: + static MemoryCounter& getInstance(); + + /** + * This method checks if allocation of numBytes won't break through per-group + * or per-device limit + * @param numBytes + * @return TRUE if allocated ammount will keep us below limit, FALSE otherwise + */ + bool validate(Nd4jLong numBytes); + + /** + * This method checks if allocation of numBytes won't break through per-device + * limit + * @param deviceId + * @param numBytes + * @return TRUE if allocated ammount will keep us below limit, FALSE otherwise + */ + bool validateDevice(int deviceId, Nd4jLong numBytes); + + /** + * This method checks if allocation of numBytes won't break through per-group + * limit + * @param deviceId + * @param numBytes + * @return TRUE if allocated ammount will keep us below limit, FALSE otherwise + */ + bool validateGroup(sd::memory::MemoryType group, Nd4jLong numBytes); + + /** + * This method adds specified number of bytes to specified counter + * @param deviceId + * @param numBytes + */ + void countIn(int deviceId, Nd4jLong numBytes); + void countIn(sd::memory::MemoryType group, Nd4jLong numBytes); + + /** + * This method subtracts specified number of bytes from specified counter + * @param deviceId + * @param numBytes + */ + void countOut(int deviceId, Nd4jLong numBytes); + void countOut(sd::memory::MemoryType group, Nd4jLong numBytes); + + /** + * This method returns amount of memory allocated on specified device + * @param deviceId + * @return + */ + Nd4jLong allocatedDevice(int deviceId); + + /** + * This method returns amount of memory allocated in specified group of + * devices + * @param group + * @return + */ + Nd4jLong allocatedGroup(sd::memory::MemoryType group); + + /** + * This method allows to set per-device memory limits + * @param deviceId + * @param numBytes + */ + void setDeviceLimit(int deviceId, Nd4jLong numBytes); + + /** + * This method returns current device limit in bytes + * @param deviceId + * @return + */ + Nd4jLong deviceLimit(int deviceId); + + /** + * This method allows to set per-group memory limits + * @param group + * @param numBytes + */ + void setGroupLimit(sd::memory::MemoryType group, Nd4jLong numBytes); + + /** + * This method returns current group limit in bytes + * @param group + * @return + */ + Nd4jLong groupLimit(sd::memory::MemoryType group); +}; +} // namespace memory +} // namespace sd + +#endif // SD_MEMORYCOUNTER_H diff --git a/libnd4j/include/memory/MemoryDescriptor.h b/libnd4j/include/memory/MemoryDescriptor.h new file mode 100644 index 000000000000..f0aad28ef183 --- /dev/null +++ b/libnd4j/include/memory/MemoryDescriptor.h @@ -0,0 +1,58 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#ifndef SD_MEMORYDESCRIPTOR_H +#define SD_MEMORYDESCRIPTOR_H + +#include +#include + +#include + +namespace sd { +namespace memory { +class SD_EXPORT MemoryDescriptor { + private: + void* _ptr; + MemoryZone _zone; + uint64_t _bytes; + + public: + MemoryDescriptor(void* ptr, MemoryZone zone, uint64_t bytes); + ~MemoryDescriptor() = default; + + MemoryDescriptor(const MemoryDescriptor& other) noexcept; + + MemoryDescriptor& operator=(const MemoryDescriptor& other) noexcept; + + // move constructor + MemoryDescriptor(MemoryDescriptor&& other) noexcept; + + // move assignment operator + MemoryDescriptor& operator=(MemoryDescriptor&& other) noexcept; + + void* address() const; + MemoryZone zone() const; + uint64_t bytes() const; +}; +} // namespace memory +} // namespace sd + +#endif // SD_MEMORYDESCRIPTOR_H diff --git a/libnd4j/include/memory/MemoryRegistrator.h b/libnd4j/include/memory/MemoryRegistrator.h index 70afafb42b51..3f775a140a87 100644 --- a/libnd4j/include/memory/MemoryRegistrator.h +++ b/libnd4j/include/memory/MemoryRegistrator.h @@ -18,50 +18,54 @@ // Created by raver119 on 12.09.17. // -#ifndef LIBND4J_MEMORYREGISTRATOR_H -#define LIBND4J_MEMORYREGISTRATOR_H +#ifndef SD_MEMORYREGISTRATOR_H +#define SD_MEMORYREGISTRATOR_H -#include "Workspace.h" +#include #include -#include + #include #include -#include +#include + +#include "Workspace.h" namespace sd { - namespace memory { - class ND4J_EXPORT MemoryRegistrator { - protected: - Workspace* _workspace; - MAP_IMPL _footprint; - std::mutex _lock; +namespace memory { +class SD_EXPORT MemoryRegistrator { + protected: + + Workspace* _workspace; + MAP_IMPL _footprint; + std::mutex _lock; + + MemoryRegistrator(); + ~MemoryRegistrator() = default; - MemoryRegistrator(); - ~MemoryRegistrator() = default; - public: - static MemoryRegistrator& getInstance(); - bool hasWorkspaceAttached(); - Workspace* getWorkspace(); - void attachWorkspace(Workspace* workspace); - void forgetWorkspace(); + public: + static MemoryRegistrator& getInstance(); + bool hasWorkspaceAttached(); + Workspace* getWorkspace(); + void attachWorkspace(Workspace* workspace); + void forgetWorkspace(); - /** - * This method allows you to set memory requirements for given graph - */ - void setGraphMemoryFootprint(Nd4jLong hash, Nd4jLong bytes); + /** + * This method allows you to set memory requirements for given graph + */ + void setGraphMemoryFootprint(Nd4jLong hash, Nd4jLong bytes); - /** - * This method allows you to set memory requirements for given graph, ONLY if - * new amount of bytes is greater then current one - */ - void setGraphMemoryFootprintIfGreater(Nd4jLong hash, Nd4jLong bytes); + /** + * This method allows you to set memory requirements for given graph, ONLY if + * new amount of bytes is greater then current one + */ + void setGraphMemoryFootprintIfGreater(Nd4jLong hash, Nd4jLong bytes); - /** - * This method returns memory requirements for given graph - */ - Nd4jLong getGraphMemoryFootprint(Nd4jLong hash); - }; - } -} + /** + * This method returns memory requirements for given graph + */ + Nd4jLong getGraphMemoryFootprint(Nd4jLong hash); +}; +} // namespace memory +} // namespace sd -#endif //LIBND4J_MEMORYREGISTRATOR_H +#endif // SD_MEMORYREGISTRATOR_H diff --git a/libnd4j/include/memory/MemoryReport.h b/libnd4j/include/memory/MemoryReport.h index 647886ab54dd..aca8bf7efd8f 100644 --- a/libnd4j/include/memory/MemoryReport.h +++ b/libnd4j/include/memory/MemoryReport.h @@ -21,36 +21,34 @@ #ifndef LIBND4J_MEMORYREPORT_H #define LIBND4J_MEMORYREPORT_H -#include #include +#include namespace sd { - namespace memory { - class ND4J_EXPORT MemoryReport { - private: - Nd4jLong _vm = 0; - Nd4jLong _rss = 0; - - public: - MemoryReport() = default; - ~MemoryReport() = default; - - bool operator < (const MemoryReport& other) const; - bool operator <= (const MemoryReport& other) const; - bool operator > (const MemoryReport& other) const; - bool operator >= (const MemoryReport& other) const; - bool operator == (const MemoryReport& other) const; - bool operator != (const MemoryReport& other) const; - - Nd4jLong getVM() const; - void setVM(Nd4jLong vm); - - Nd4jLong getRSS() const; - void setRSS(Nd4jLong rss); - }; - } -} - - - -#endif //LIBND4J_MEMORYREPORT_H +namespace memory { +class SD_EXPORT MemoryReport { + private: + Nd4jLong _vm = 0; + Nd4jLong _rss = 0; + + public: + MemoryReport() = default; + ~MemoryReport() = default; + + bool operator<(const MemoryReport& other) const; + bool operator<=(const MemoryReport& other) const; + bool operator>(const MemoryReport& other) const; + bool operator>=(const MemoryReport& other) const; + bool operator==(const MemoryReport& other) const; + bool operator!=(const MemoryReport& other) const; + + Nd4jLong getVM() const; + void setVM(Nd4jLong vm); + + Nd4jLong getRSS() const; + void setRSS(Nd4jLong rss); +}; +} // namespace memory +} // namespace sd + +#endif // LIBND4J_MEMORYREPORT_H diff --git a/libnd4j/include/memory/MemoryTracker.h b/libnd4j/include/memory/MemoryTracker.h index dd99905bd190..05e107e7b631 100644 --- a/libnd4j/include/memory/MemoryTracker.h +++ b/libnd4j/include/memory/MemoryTracker.h @@ -18,40 +18,45 @@ // Created by raver119 on 07.05.19. // -#ifndef DEV_TESTS_MEMORYTRACKER_H -#define DEV_TESTS_MEMORYTRACKER_H +#ifndef SD_MEMORYTRACKER_H +#define SD_MEMORYTRACKER_H -#include -#include +#include #include + +#include #include +#include + #include "AllocationEntry.h" -#include namespace sd { - namespace memory { - /** - * This class is used for tracking memory allocation wrt their allocation points in code - */ - class ND4J_EXPORT MemoryTracker { - private: - std::map _allocations; - std::map _released; - std::mutex _locker; - - MemoryTracker(); - ~MemoryTracker() = default; - public: - static MemoryTracker& getInstance(); - - void countIn(MemoryType type, Nd4jPointer ptr, Nd4jLong numBytes); - void countOut(Nd4jPointer ptr); - - void summarize(); - void reset(); - }; - } -} - - -#endif //DEV_TESTS_MEMORYTRACKER_H +namespace memory { +/** + * This class is used for tracking memory allocation wrt their allocation points + * in code + */ +class SD_EXPORT MemoryTracker { + private: + + std::map _allocations; + std::map _released; + std::mutex _locker; + + MemoryTracker(); + ~MemoryTracker() = default; + + public: + static MemoryTracker& getInstance(); + + void countIn(MemoryType type, Nd4jPointer ptr, Nd4jLong numBytes); + void countOut(Nd4jPointer ptr); + + void summarize(); + void reset(); +}; + +} // namespace memory +} // namespace sd + +#endif // SD_MEMORYTRACKER_H diff --git a/libnd4j/include/memory/MemoryType.h b/libnd4j/include/memory/MemoryType.h index 113d8d16d383..6cf1091cdee4 100644 --- a/libnd4j/include/memory/MemoryType.h +++ b/libnd4j/include/memory/MemoryType.h @@ -2,16 +2,16 @@ // Created by raver119 on 07.05.19. // -#ifndef DEV_TESTS_MEMORYTYPE_H -#define DEV_TESTS_MEMORYTYPE_H +#ifndef SD_MEMORYTYPE_H +#define SD_MEMORYTYPE_H namespace sd { - namespace memory { - enum MemoryType { - HOST = 0, - DEVICE = 10, - }; - } +namespace memory { +enum MemoryType { + HOST = 0, + DEVICE = 10, +}; } +} // namespace sd -#endif //DEV_TESTS_MEMORYTYPE_H +#endif // SD_MEMORYTYPE_H diff --git a/libnd4j/include/memory/MemoryUtils.h b/libnd4j/include/memory/MemoryUtils.h index 027008238535..a3e749144219 100644 --- a/libnd4j/include/memory/MemoryUtils.h +++ b/libnd4j/include/memory/MemoryUtils.h @@ -21,18 +21,17 @@ #ifndef LIBND4J_MEMORYUTILS_H #define LIBND4J_MEMORYUTILS_H -#include "MemoryReport.h" #include -namespace sd { - namespace memory { - class ND4J_EXPORT MemoryUtils { - public: - static bool retrieveMemoryStatistics(MemoryReport& report); - }; - } -} - +#include "MemoryReport.h" +namespace sd { +namespace memory { +class SD_EXPORT MemoryUtils { + public: + static bool retrieveMemoryStatistics(MemoryReport& report); +}; +} // namespace memory +} // namespace sd -#endif //LIBND4J_MEMORYUTILS_H +#endif // LIBND4J_MEMORYUTILS_H diff --git a/libnd4j/include/memory/MemoryZone.h b/libnd4j/include/memory/MemoryZone.h new file mode 100644 index 000000000000..666f6c9e6a51 --- /dev/null +++ b/libnd4j/include/memory/MemoryZone.h @@ -0,0 +1,34 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#ifndef SD_MEMORYZONE_H +#define SD_MEMORYZONE_H + +namespace sd { +namespace memory { +enum MemoryZone { + COLD = 0, + WARM = 10, + HOT = 20, +}; +} +} // namespace sd + +#endif // SD_MEMORYZONE_H diff --git a/libnd4j/include/memory/MmapDeallocator.h b/libnd4j/include/memory/MmapDeallocator.h new file mode 100644 index 000000000000..6e2429c9db68 --- /dev/null +++ b/libnd4j/include/memory/MmapDeallocator.h @@ -0,0 +1,41 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#ifndef SD_MEMORY_MMAPDEALLOCATOR_H_ +#define SD_MEMORY_MMAPDEALLOCATOR_H_ + +#include + +namespace sd { +namespace memory { + class MmapDeallocator : public sd::PointerDeallocator { + private: + public: + MmapDeallocator() = default; + ~MmapDeallocator() = default; + + void release(void *ptr) override; + }; + +} +} + + +#endif //SD_MEMORY_MMAPDEALLOCATOR_H_ diff --git a/libnd4j/include/memory/WarmZoneManager.h b/libnd4j/include/memory/WarmZoneManager.h new file mode 100644 index 000000000000..05700805b169 --- /dev/null +++ b/libnd4j/include/memory/WarmZoneManager.h @@ -0,0 +1,37 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#ifndef SD_WARMZONEMANAGER_H +#define SD_WARMZONEMANAGER_H + +#include + +namespace sd { +namespace memory { +class SD_EXPORT WarmZoneManager : public ZoneManager { + protected: + public: + WarmZoneManager() = default; + ~WarmZoneManager() = default; +}; +} // namespace memory +} // namespace sd + +#endif // SD_WARMZONEMANAGER_H diff --git a/libnd4j/include/memory/Workspace.h b/libnd4j/include/memory/Workspace.h index c97f6a178978..098856ba8876 100644 --- a/libnd4j/include/memory/Workspace.h +++ b/libnd4j/include/memory/Workspace.h @@ -24,85 +24,88 @@ #ifndef LIBND4J_WORKSPACE_H #define LIBND4J_WORKSPACE_H -#include -#include -#include +#include +#include #include #include #include -#include -#include + +#include +#include +#include namespace sd { - namespace memory { +namespace memory { + +class SD_EXPORT Workspace { + protected: + char* _ptrHost = nullptr; + char* _ptrDevice = nullptr; - class ND4J_EXPORT Workspace { - protected: - char* _ptrHost = nullptr; - char* _ptrDevice = nullptr; + bool _allocatedHost = false; + bool _allocatedDevice = false; - bool _allocatedHost = false; - bool _allocatedDevice = false; + std::atomic _offset; + std::atomic _offsetSecondary; - std::atomic _offset; - std::atomic _offsetSecondary; + Nd4jLong _initialSize = 0L; + Nd4jLong _initialSizeSecondary = 0L; - Nd4jLong _initialSize = 0L; - Nd4jLong _initialSizeSecondary = 0L; + Nd4jLong _currentSize = 0L; + Nd4jLong _currentSizeSecondary = 0L; - Nd4jLong _currentSize = 0L; - Nd4jLong _currentSizeSecondary = 0L; + std::mutex _mutexAllocation; + std::mutex _mutexSpills; - std::mutex _mutexAllocation; - std::mutex _mutexSpills; + bool _externalized = false; - bool _externalized = false; + std::vector _spills; + std::vector _spillsSecondary; - std::vector _spills; - std::vector _spillsSecondary; + std::atomic _spillsSize; + std::atomic _cycleAllocations; - std::atomic _spillsSize; - std::atomic _cycleAllocations; + std::atomic _spillsSizeSecondary; + std::atomic _cycleAllocationsSecondary; - std::atomic _spillsSizeSecondary; - std::atomic _cycleAllocationsSecondary; + void init(Nd4jLong primaryBytes, Nd4jLong secondaryBytes = 0L); + void freeSpills(); - void init(Nd4jLong primaryBytes, Nd4jLong secondaryBytes = 0L); - void freeSpills(); - public: - explicit Workspace(ExternalWorkspace *external); - Workspace(Nd4jLong initialSize = 0L, Nd4jLong secondaryBytes = 0L); - ~Workspace(); + public: + explicit Workspace(ExternalWorkspace* external); + Workspace(Nd4jLong initialSize = 0L, Nd4jLong secondaryBytes = 0L); + ~Workspace(); - Nd4jLong getAllocatedSize(); - Nd4jLong getCurrentSize(); - Nd4jLong getCurrentOffset(); - Nd4jLong getSpilledSize(); - Nd4jLong getUsedSize(); + Nd4jLong getAllocatedSize(); + Nd4jLong getCurrentSize(); + Nd4jLong getCurrentOffset(); + Nd4jLong getSpilledSize(); + Nd4jLong getUsedSize(); - Nd4jLong getAllocatedSecondarySize(); - Nd4jLong getCurrentSecondarySize(); - Nd4jLong getCurrentSecondaryOffset(); - Nd4jLong getSpilledSecondarySize(); - Nd4jLong getUsedSecondarySize(); + Nd4jLong getAllocatedSecondarySize(); + Nd4jLong getCurrentSecondarySize(); + Nd4jLong getCurrentSecondaryOffset(); + Nd4jLong getSpilledSecondarySize(); + Nd4jLong getUsedSecondarySize(); - void expandBy(Nd4jLong primaryBytes, Nd4jLong secondaryBytes = 0L); - void expandTo(Nd4jLong primaryBytes, Nd4jLong secondaryBytes = 0L); + void expandBy(Nd4jLong primaryBytes, Nd4jLong secondaryBytes = 0L); + void expandTo(Nd4jLong primaryBytes, Nd4jLong secondaryBytes = 0L); -// bool resizeSupported(); + // bool resizeSupported(); - void* allocateBytes(Nd4jLong numBytes); - void* allocateBytes(MemoryType type, Nd4jLong numBytes); + void* allocateBytes(Nd4jLong numBytes); + void* allocateBytes(MemoryType type, Nd4jLong numBytes); - void scopeIn(); - void scopeOut(); + void scopeIn(); + void scopeOut(); - /* - * This method creates NEW workspace of the same memory size and returns pointer to it - */ - Workspace* clone(); - }; - } -} + /* + * This method creates NEW workspace of the same memory size and returns + * pointer to it + */ + Workspace* clone(); +}; +} // namespace memory +} // namespace sd -#endif //LIBND4J_WORKSPACE_H +#endif // LIBND4J_WORKSPACE_H diff --git a/libnd4j/include/memory/ZoneManager.h b/libnd4j/include/memory/ZoneManager.h new file mode 100644 index 000000000000..67d79fa3488c --- /dev/null +++ b/libnd4j/include/memory/ZoneManager.h @@ -0,0 +1,80 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#ifndef SD_ZONEMANAGER_H +#define SD_ZONEMANAGER_H + +#include +#include +#include + +#include +#include + +namespace sd { +namespace memory { +/** + * Abstract class that defines common methods for zone managers + */ +class SD_EXPORT ZoneManager { + protected: + std::mutex _lock; + + public: + ZoneManager() = default; + + virtual ~ZoneManager() = default; + + /** + * This method returns id of the current zone served by this manager instance + * @return MemoryZone enum + */ + virtual MemoryZone zone() const = 0; + + /** + * This method returns amount of memory available in this zone + * @return number of bytes + */ + virtual uint64_t available() const = 0; + + /** + * This method returns amount of memory currently used in this zone + * @return number of bytes + */ + virtual uint64_t used() const = 0; + + /** + * This method allocates (probably) some memory chunk, and returns you pointer + * to it. + * @param numBytes + * @return + */ + virtual MemoryDescriptor allocate(uint64_t numBytes) = 0; + + /** + * This method releases (probably) memory described by given MemoryDescriptor + * @param descriptor + */ + virtual void release(MemoryDescriptor &descriptor) = 0; +}; +} // namespace memory +} // namespace sd + +#endif // SD_ZONEMANAGER_H diff --git a/libnd4j/include/memory/cpu/Workspace.cpp b/libnd4j/include/memory/cpu/Workspace.cpp index ae60f1eeac30..a25195c44f16 100644 --- a/libnd4j/include/memory/cpu/Workspace.cpp +++ b/libnd4j/include/memory/cpu/Workspace.cpp @@ -20,199 +20,181 @@ // @author raver119@gmail.com // - -#include -#include -#include -#include #include "../Workspace.h" + #include #include -#include +#include +#include +#include +#include +#include namespace sd { - namespace memory { - Workspace::Workspace(ExternalWorkspace *external) { - if (external->sizeHost() > 0) { - _ptrHost = (char *) external->pointerHost(); - _ptrDevice = (char *) external->pointerDevice(); - - _initialSize = external->sizeHost(); - _currentSize = external->sizeHost(); - _offset = 0L; - _offsetSecondary = 0L; - this->_cycleAllocations = 0; - this->_spillsSize = 0; - - _externalized = true; - } - }; - - Workspace::Workspace(Nd4jLong initialSize, Nd4jLong secondaryBytes) { - if (initialSize > 0) { - this->_ptrHost = (char *) malloc(initialSize); - - CHECK_ALLOC(this->_ptrHost, "Failed to allocate new workspace", initialSize); - - memset(this->_ptrHost, 0, initialSize); - this->_allocatedHost = true; - } else - this->_allocatedHost = false; - - this->_initialSize = initialSize; - this->_currentSize = initialSize; - this->_currentSizeSecondary = 0; - this->_spillsSizeSecondary = 0; - this->_offset = 0; - this->_offsetSecondary = 0; - this->_cycleAllocations = 0; - this->_spillsSize = 0; - } - - void Workspace::init(Nd4jLong bytes, Nd4jLong secondaryBytes) { - if (this->_currentSize < bytes) { - if (this->_allocatedHost && !_externalized) - free((void *)this->_ptrHost); - - this->_ptrHost =(char *) malloc(bytes); +namespace memory { +Workspace::Workspace(ExternalWorkspace *external) { + if (external->sizeHost() > 0) { + _ptrHost = (char *)external->pointerHost(); + _ptrDevice = (char *)external->pointerDevice(); + + _initialSize = external->sizeHost(); + _currentSize = external->sizeHost(); + _offset = 0L; + _offsetSecondary = 0L; + this->_cycleAllocations = 0; + this->_spillsSize = 0; + + _externalized = true; + } +}; + +Workspace::Workspace(Nd4jLong initialSize, Nd4jLong secondaryBytes) { + if (initialSize > 0) { + this->_ptrHost = (char *)malloc(initialSize); + + CHECK_ALLOC(this->_ptrHost, "Failed to allocate new workspace", + initialSize); + + memset(this->_ptrHost, 0, initialSize); + this->_allocatedHost = true; + } else + this->_allocatedHost = false; + + this->_initialSize = initialSize; + this->_currentSize = initialSize; + this->_currentSizeSecondary = 0; + this->_spillsSizeSecondary = 0; + this->_offset = 0; + this->_offsetSecondary = 0; + this->_cycleAllocations = 0; + this->_spillsSize = 0; +} - CHECK_ALLOC(this->_ptrHost, "Failed to allocate new workspace", bytes); +void Workspace::init(Nd4jLong bytes, Nd4jLong secondaryBytes) { + if (this->_currentSize < bytes) { + if (this->_allocatedHost && !_externalized) free((void *)this->_ptrHost); - memset(this->_ptrHost, 0, bytes); - this->_currentSize = bytes; - this->_allocatedHost = true; - } - } + this->_ptrHost = (char *)malloc(bytes); - void Workspace::expandBy(Nd4jLong numBytes, Nd4jLong secondaryBytes) { - this->init(_currentSize + numBytes, _currentSizeSecondary + secondaryBytes); - } + CHECK_ALLOC(this->_ptrHost, "Failed to allocate new workspace", bytes); - void Workspace::expandTo(Nd4jLong numBytes, Nd4jLong secondaryBytes) { - this->init(numBytes, secondaryBytes); - } + memset(this->_ptrHost, 0, bytes); + this->_currentSize = bytes; + this->_allocatedHost = true; + } +} - void Workspace::freeSpills() { - _spillsSize = 0; +void Workspace::expandBy(Nd4jLong numBytes, Nd4jLong secondaryBytes) { + this->init(_currentSize + numBytes, _currentSizeSecondary + secondaryBytes); +} - if (_spills.size() < 1) - return; +void Workspace::expandTo(Nd4jLong numBytes, Nd4jLong secondaryBytes) { + this->init(numBytes, secondaryBytes); +} - for (auto v:_spills) - free(v); +void Workspace::freeSpills() { + _spillsSize = 0; - _spills.clear(); - } + if (_spills.size() < 1) return; - Workspace::~Workspace() { - if (this->_allocatedHost && !_externalized) - free((void *)this->_ptrHost); + for (auto v : _spills) free(v); - freeSpills(); - } + _spills.clear(); +} - Nd4jLong Workspace::getUsedSize() { - return getCurrentOffset(); - } +Workspace::~Workspace() { + if (this->_allocatedHost && !_externalized) free((void *)this->_ptrHost); - Nd4jLong Workspace::getCurrentSize() { - return _currentSize; - } + freeSpills(); +} - Nd4jLong Workspace::getCurrentOffset() { - return _offset.load(); - } +Nd4jLong Workspace::getUsedSize() { return getCurrentOffset(); } +Nd4jLong Workspace::getCurrentSize() { return _currentSize; } - void* Workspace::allocateBytes(Nd4jLong numBytes) { - if (numBytes < 1) - throw allocation_exception::build("Number of bytes for allocation should be positive", numBytes); +Nd4jLong Workspace::getCurrentOffset() { return _offset.load(); } +void *Workspace::allocateBytes(Nd4jLong numBytes) { + if (numBytes < 1) + throw allocation_exception::build( + "Number of bytes for allocation should be positive", numBytes); - //numBytes += 32; - void* result = nullptr; - this->_cycleAllocations += numBytes; - this->_mutexAllocation.lock(); + // numBytes += 32; + void *result = nullptr; + this->_cycleAllocations += numBytes; + this->_mutexAllocation.lock(); - if (_offset.load() + numBytes > _currentSize) { - nd4j_debug("Allocating %lld bytes in spills\n", numBytes); - this->_mutexAllocation.unlock(); + if (_offset.load() + numBytes > _currentSize) { + nd4j_debug("Allocating %lld bytes in spills\n", numBytes); + this->_mutexAllocation.unlock(); - void *p = malloc(numBytes); + void *p = malloc(numBytes); - CHECK_ALLOC(p, "Failed to allocate new workspace", numBytes); + CHECK_ALLOC(p, "Failed to allocate new workspace", numBytes); - _mutexSpills.lock(); - _spills.push_back(p); - _mutexSpills.unlock(); + _mutexSpills.lock(); + _spills.push_back(p); + _mutexSpills.unlock(); - _spillsSize += numBytes; + _spillsSize += numBytes; - return p; - } + return p; + } - result = (void *)(_ptrHost + _offset.load()); - _offset += numBytes; - //memset(result, 0, (int) numBytes); + result = (void *)(_ptrHost + _offset.load()); + _offset += numBytes; + // memset(result, 0, (int) numBytes); - nd4j_debug("Allocating %lld bytes from workspace; Current PTR: %p; Current offset: %lld\n", numBytes, result, _offset.load()); + nd4j_debug( + "Allocating %lld bytes from workspace; Current PTR: %p; Current offset: " + "%lld\n", + numBytes, result, _offset.load()); - this->_mutexAllocation.unlock(); + this->_mutexAllocation.unlock(); - return result; - } + return result; +} - Nd4jLong Workspace::getAllocatedSize() { - return getCurrentSize() + getSpilledSize(); - } +Nd4jLong Workspace::getAllocatedSize() { + return getCurrentSize() + getSpilledSize(); +} - void Workspace::scopeIn() { - freeSpills(); - init(_cycleAllocations.load()); - _cycleAllocations = 0; - } +void Workspace::scopeIn() { + freeSpills(); + init(_cycleAllocations.load()); + _cycleAllocations = 0; +} - void Workspace::scopeOut() { - _offset = 0; - _offsetSecondary = 0; - } +void Workspace::scopeOut() { + _offset = 0; + _offsetSecondary = 0; +} - Nd4jLong Workspace::getSpilledSize() { - return _spillsSize.load(); - } +Nd4jLong Workspace::getSpilledSize() { return _spillsSize.load(); } - void* Workspace::allocateBytes(sd::memory::MemoryType type, Nd4jLong numBytes) { - if (type == DEVICE) - throw std::runtime_error("CPU backend doesn't have device memory"); +void *Workspace::allocateBytes(sd::memory::MemoryType type, Nd4jLong numBytes) { + if (type == DEVICE) + throw std::runtime_error("CPU backend doesn't have device memory"); - return this->allocateBytes(numBytes); - } + return this->allocateBytes(numBytes); +} - Nd4jLong Workspace::getAllocatedSecondarySize() { - return 0L; - } +Nd4jLong Workspace::getAllocatedSecondarySize() { return 0L; } - Nd4jLong Workspace::getCurrentSecondarySize() { - return 0L; - } +Nd4jLong Workspace::getCurrentSecondarySize() { return 0L; } - Nd4jLong Workspace::getCurrentSecondaryOffset() { - return 0L; - } +Nd4jLong Workspace::getCurrentSecondaryOffset() { return 0L; } - Nd4jLong Workspace::getSpilledSecondarySize() { - return 0L; - } +Nd4jLong Workspace::getSpilledSecondarySize() { return 0L; } - Nd4jLong Workspace::getUsedSecondarySize() { - return 0L; - } +Nd4jLong Workspace::getUsedSecondarySize() { return 0L; } - Workspace* Workspace::clone() { - // for clone we take whatever is higher: current allocated size, or allocated size of current loop - return new Workspace(sd::math::nd4j_max(this->getCurrentSize(), this->_cycleAllocations.load())); - } - } +Workspace *Workspace::clone() { + // for clone we take whatever is higher: current allocated size, or allocated + // size of current loop + return new Workspace(sd::math::nd4j_max( + this->getCurrentSize(), this->_cycleAllocations.load())); } - +} // namespace memory +} // namespace sd diff --git a/libnd4j/include/memory/cuda/Workspace.cu b/libnd4j/include/memory/cuda/Workspace.cu index 9d228615689f..2d57b43c5ce3 100644 --- a/libnd4j/include/memory/cuda/Workspace.cu +++ b/libnd4j/include/memory/cuda/Workspace.cu @@ -20,278 +20,273 @@ // @author raver119@gmail.com // -#include -#include -#include -#include -#include "../Workspace.h" +#include +#include +#include #include #include +#include +#include +#include + +#include #include -#include -#include -#include +#include "../Workspace.h" namespace sd { - namespace memory { - Workspace::Workspace(ExternalWorkspace *external) { - if (external->sizeHost() > 0) { - _ptrHost = (char *) external->pointerHost(); - _ptrDevice = (char *) external->pointerDevice(); - - _initialSize = external->sizeDevice(); - _currentSize = external->sizeDevice(); - _initialSizeSecondary = external->sizeHost(); - _currentSizeSecondary = external->sizeHost(); - _offset = 0L; - _offsetSecondary = 0L; - this->_cycleAllocations = 0; - this->_cycleAllocationsSecondary = 0; - this->_spillsSize = 0; - this->_spillsSizeSecondary = 0; - - _externalized = true; - } - } - - Workspace::Workspace(Nd4jLong primarySize, Nd4jLong secondarySize) { - if (secondarySize > 0) { - auto res = cudaHostAlloc(reinterpret_cast(&_ptrHost), secondarySize, cudaHostAllocDefault); - if (res != 0) - throw cuda_exception::build("Can't allocate [HOST] memory", res); - - cudaMemset(this->_ptrHost, 0, secondarySize); - this->_allocatedHost = true; - } else - this->_allocatedHost = false; - - if (primarySize > 0) { - auto res = cudaMalloc(reinterpret_cast(&_ptrDevice), primarySize); - if (res != 0) - throw cuda_exception::build("Can't allocate [DEVICE] memory", res); - - cudaMemset(this->_ptrDevice, 0, primarySize); - this->_allocatedDevice = true; - } else - this->_allocatedDevice = false; - - this->_initialSize = primarySize; - this->_initialSizeSecondary = secondarySize; - this->_currentSize = primarySize; - this->_currentSizeSecondary = secondarySize; - this->_offset = 0; - this->_offsetSecondary = 0; - this->_cycleAllocations = 0; - this->_spillsSize = 0; - this->_spillsSizeSecondary = 0; - } - - void Workspace::init(Nd4jLong primaryBytes, Nd4jLong secondaryBytes) { - if (this->_currentSize < primaryBytes) { - if (this->_allocatedDevice && !_externalized) - cudaFree((void *)this->_ptrDevice); - - auto res = cudaMalloc(reinterpret_cast(&_ptrDevice), secondaryBytes); - if (res != 0) - throw cuda_exception::build("Can't allocate [DEVICE] memory", res); - - cudaMemset(this->_ptrDevice, 0, primaryBytes); - this->_currentSize = primaryBytes; - this->_allocatedDevice = true; - } - - if (this->_currentSizeSecondary < secondaryBytes) { - if (this->_allocatedHost && !_externalized) - cudaFreeHost((void *)this->_ptrHost); - - auto res = cudaHostAlloc(reinterpret_cast(&_ptrHost), secondaryBytes, cudaHostAllocDefault); - if (res != 0) - throw cuda_exception::build("Can't allocate [HOST] memory", res); - - - cudaMemset(this->_ptrHost, 0, secondaryBytes); - this->_currentSizeSecondary = secondaryBytes; - this->_allocatedHost = true; - } - } - - void Workspace::expandBy(Nd4jLong numBytes, Nd4jLong secondaryBytes) { - this->init(_currentSize + numBytes, _currentSizeSecondary + secondaryBytes); - } - - void Workspace::expandTo(Nd4jLong numBytes, Nd4jLong secondaryBytes) { - this->init(numBytes, secondaryBytes); - } - - void Workspace::freeSpills() { - _spillsSize = 0; - _spillsSizeSecondary = 0; - - for (auto v:_spills) - cudaFree(v); - - for (auto v:_spillsSecondary) - cudaFreeHost(v); - - _spills.clear(); - _spillsSecondary.clear(); - } - - Workspace::~Workspace() { - if (this->_allocatedHost && !_externalized) - cudaFreeHost((void *)this->_ptrHost); - - if (this->_allocatedDevice && !_externalized) - cudaFree((void *)this->_ptrDevice); - - freeSpills(); - } - - Nd4jLong Workspace::getUsedSize() { - return getCurrentOffset(); - } +namespace memory { +Workspace::Workspace(ExternalWorkspace *external) { + if (external->sizeHost() > 0) { + _ptrHost = (char *)external->pointerHost(); + _ptrDevice = (char *)external->pointerDevice(); + + _initialSize = external->sizeDevice(); + _currentSize = external->sizeDevice(); + _initialSizeSecondary = external->sizeHost(); + _currentSizeSecondary = external->sizeHost(); + _offset = 0L; + _offsetSecondary = 0L; + this->_cycleAllocations = 0; + this->_cycleAllocationsSecondary = 0; + this->_spillsSize = 0; + this->_spillsSizeSecondary = 0; + + _externalized = true; + } +} + +Workspace::Workspace(Nd4jLong primarySize, Nd4jLong secondarySize) { + if (secondarySize > 0) { + auto res = cudaHostAlloc(reinterpret_cast(&_ptrHost), + secondarySize, cudaHostAllocDefault); + if (res != 0) + throw cuda_exception::build("Can't allocate [HOST] memory", res); + + cudaMemset(this->_ptrHost, 0, secondarySize); + this->_allocatedHost = true; + } else + this->_allocatedHost = false; + + if (primarySize > 0) { + auto res = cudaMalloc(reinterpret_cast(&_ptrDevice), primarySize); + if (res != 0) + throw cuda_exception::build("Can't allocate [DEVICE] memory", res); + + cudaMemset(this->_ptrDevice, 0, primarySize); + this->_allocatedDevice = true; + } else + this->_allocatedDevice = false; + + this->_initialSize = primarySize; + this->_initialSizeSecondary = secondarySize; + this->_currentSize = primarySize; + this->_currentSizeSecondary = secondarySize; + this->_offset = 0; + this->_offsetSecondary = 0; + this->_cycleAllocations = 0; + this->_spillsSize = 0; + this->_spillsSizeSecondary = 0; +} - Nd4jLong Workspace::getCurrentSize() { - return _currentSize; - } +void Workspace::init(Nd4jLong primaryBytes, Nd4jLong secondaryBytes) { + if (this->_currentSize < primaryBytes) { + if (this->_allocatedDevice && !_externalized) + cudaFree((void *)this->_ptrDevice); + + auto res = + cudaMalloc(reinterpret_cast(&_ptrDevice), secondaryBytes); + if (res != 0) + throw cuda_exception::build("Can't allocate [DEVICE] memory", res); + + cudaMemset(this->_ptrDevice, 0, primaryBytes); + this->_currentSize = primaryBytes; + this->_allocatedDevice = true; + } + + if (this->_currentSizeSecondary < secondaryBytes) { + if (this->_allocatedHost && !_externalized) + cudaFreeHost((void *)this->_ptrHost); + + auto res = cudaHostAlloc(reinterpret_cast(&_ptrHost), + secondaryBytes, cudaHostAllocDefault); + if (res != 0) + throw cuda_exception::build("Can't allocate [HOST] memory", res); + + cudaMemset(this->_ptrHost, 0, secondaryBytes); + this->_currentSizeSecondary = secondaryBytes; + this->_allocatedHost = true; + } +} + +void Workspace::expandBy(Nd4jLong numBytes, Nd4jLong secondaryBytes) { + this->init(_currentSize + numBytes, _currentSizeSecondary + secondaryBytes); +} + +void Workspace::expandTo(Nd4jLong numBytes, Nd4jLong secondaryBytes) { + this->init(numBytes, secondaryBytes); +} + +void Workspace::freeSpills() { + _spillsSize = 0; + _spillsSizeSecondary = 0; + + for (auto v : _spills) cudaFree(v); + + for (auto v : _spillsSecondary) cudaFreeHost(v); + + _spills.clear(); + _spillsSecondary.clear(); +} - Nd4jLong Workspace::getCurrentOffset() { - return _offset.load(); - } +Workspace::~Workspace() { + if (this->_allocatedHost && !_externalized) + cudaFreeHost((void *)this->_ptrHost); + if (this->_allocatedDevice && !_externalized) + cudaFree((void *)this->_ptrDevice); - void* Workspace::allocateBytes(Nd4jLong numBytes) { - return allocateBytes(sd::memory::MemoryType::HOST, numBytes); - } + freeSpills(); +} - Nd4jLong Workspace::getAllocatedSize() { - return getCurrentSize() + getSpilledSize(); - } +Nd4jLong Workspace::getUsedSize() { return getCurrentOffset(); } - void Workspace::scopeIn() { - freeSpills(); - init(_cycleAllocations.load()); - _cycleAllocations = 0; - } +Nd4jLong Workspace::getCurrentSize() { return _currentSize; } - void Workspace::scopeOut() { - _offset = 0; - } +Nd4jLong Workspace::getCurrentOffset() { return _offset.load(); } - Nd4jLong Workspace::getSpilledSize() { - return _spillsSize.load(); - } +void *Workspace::allocateBytes(Nd4jLong numBytes) { + return allocateBytes(sd::memory::MemoryType::HOST, numBytes); +} - void* Workspace::allocateBytes(sd::memory::MemoryType type, Nd4jLong numBytes) { - switch (type) { - case HOST: { - if (numBytes < 1) - throw allocation_exception::build("Number of [HOST] bytes for allocation should be positive", numBytes); +Nd4jLong Workspace::getAllocatedSize() { + return getCurrentSize() + getSpilledSize(); +} + +void Workspace::scopeIn() { + freeSpills(); + init(_cycleAllocations.load()); + _cycleAllocations = 0; +} +void Workspace::scopeOut() { _offset = 0; } - //numBytes += 32; - void* result = nullptr; - this->_cycleAllocationsSecondary += numBytes; - this->_mutexAllocation.lock(); +Nd4jLong Workspace::getSpilledSize() { return _spillsSize.load(); } - if (_offsetSecondary.load() + numBytes > _currentSizeSecondary) { - nd4j_debug("Allocating %lld [HOST] bytes in spills\n", numBytes); - this->_mutexAllocation.unlock(); +void *Workspace::allocateBytes(sd::memory::MemoryType type, Nd4jLong numBytes) { + switch (type) { + case HOST: { + if (numBytes < 1) + throw allocation_exception::build( + "Number of [HOST] bytes for allocation should be positive", + numBytes); - Nd4jPointer p; - auto res = cudaHostAlloc(reinterpret_cast(&p), numBytes, cudaHostAllocDefault); - if (res != 0) - throw cuda_exception::build("Can't allocate [HOST] memory", res); + // numBytes += 32; + void *result = nullptr; + this->_cycleAllocationsSecondary += numBytes; + this->_mutexAllocation.lock(); - _mutexSpills.lock(); - _spillsSecondary.push_back(p); - _mutexSpills.unlock(); + if (_offsetSecondary.load() + numBytes > _currentSizeSecondary) { + nd4j_debug("Allocating %lld [HOST] bytes in spills\n", numBytes); + this->_mutexAllocation.unlock(); - _spillsSizeSecondary += numBytes; + Nd4jPointer p; + auto res = cudaHostAlloc(reinterpret_cast(&p), numBytes, + cudaHostAllocDefault); + if (res != 0) + throw cuda_exception::build("Can't allocate [HOST] memory", res); - return p; - } + _mutexSpills.lock(); + _spillsSecondary.push_back(p); + _mutexSpills.unlock(); - result = (void *)(_ptrHost + _offsetSecondary.load()); - _offsetSecondary += numBytes; - //memset(result, 0, (int) numBytes); + _spillsSizeSecondary += numBytes; - nd4j_debug("Allocating %lld bytes from [HOST] workspace; Current PTR: %p; Current offset: %lld\n", numBytes, result, _offset.load()); + return p; + } - this->_mutexAllocation.unlock(); + result = (void *)(_ptrHost + _offsetSecondary.load()); + _offsetSecondary += numBytes; + // memset(result, 0, (int) numBytes); - return result; - } - break; - case DEVICE: { - if (numBytes < 1) - throw allocation_exception::build("Number of [DEVICE] bytes for allocation should be positive", numBytes); + nd4j_debug( + "Allocating %lld bytes from [HOST] workspace; Current PTR: %p; " + "Current offset: %lld\n", + numBytes, result, _offset.load()); + this->_mutexAllocation.unlock(); - //numBytes += 32; - void* result = nullptr; - this->_cycleAllocations += numBytes; - this->_mutexAllocation.lock(); + return result; + } break; + case DEVICE: { + if (numBytes < 1) + throw allocation_exception::build( + "Number of [DEVICE] bytes for allocation should be positive", + numBytes); - if (_offset.load() + numBytes > _currentSize) { - nd4j_debug("Allocating %lld [DEVICE] bytes in spills\n", numBytes); - this->_mutexAllocation.unlock(); + // numBytes += 32; + void *result = nullptr; + this->_cycleAllocations += numBytes; + this->_mutexAllocation.lock(); - Nd4jPointer p; - auto res = cudaMalloc(reinterpret_cast(&p), numBytes); - if (res != 0) - throw cuda_exception::build("Can't allocate [DEVICE] memory", res); + if (_offset.load() + numBytes > _currentSize) { + nd4j_debug("Allocating %lld [DEVICE] bytes in spills\n", numBytes); + this->_mutexAllocation.unlock(); - _mutexSpills.lock(); - _spills.push_back(p); - _mutexSpills.unlock(); + Nd4jPointer p; + auto res = cudaMalloc(reinterpret_cast(&p), numBytes); + if (res != 0) + throw cuda_exception::build("Can't allocate [DEVICE] memory", res); - _spillsSize += numBytes; + _mutexSpills.lock(); + _spills.push_back(p); + _mutexSpills.unlock(); - return p; - } + _spillsSize += numBytes; - result = (void *)(_ptrDevice + _offset.load()); - _offset += numBytes; - //memset(result, 0, (int) numBytes); + return p; + } - nd4j_debug("Allocating %lld bytes from [DEVICE] workspace; Current PTR: %p; Current offset: %lld\n", numBytes, result, _offset.load()); + result = (void *)(_ptrDevice + _offset.load()); + _offset += numBytes; + // memset(result, 0, (int) numBytes); - this->_mutexAllocation.unlock(); + nd4j_debug( + "Allocating %lld bytes from [DEVICE] workspace; Current PTR: %p; " + "Current offset: %lld\n", + numBytes, result, _offset.load()); - return result; - } - break; - default: - throw std::runtime_error("Unknown MemoryType was passed in"); - } - } + this->_mutexAllocation.unlock(); - Workspace* Workspace::clone() { - // for clone we take whatever is higher: current allocated size, or allocated size of current loop - return new Workspace(sd::math::nd4j_max(this->getCurrentSize(), this->_cycleAllocations.load())); - } + return result; + } break; + default: + throw std::runtime_error("Unknown MemoryType was passed in"); + } +} - Nd4jLong Workspace::getAllocatedSecondarySize() { - return getCurrentSecondarySize() + getSpilledSecondarySize(); - } +Workspace *Workspace::clone() { + // for clone we take whatever is higher: current allocated size, or allocated + // size of current loop + return new Workspace(sd::math::nd4j_max( + this->getCurrentSize(), this->_cycleAllocations.load())); +} - Nd4jLong Workspace::getCurrentSecondarySize() { - return _currentSizeSecondary; - } +Nd4jLong Workspace::getAllocatedSecondarySize() { + return getCurrentSecondarySize() + getSpilledSecondarySize(); +} - Nd4jLong Workspace::getCurrentSecondaryOffset() { - return _offsetSecondary.load(); - } +Nd4jLong Workspace::getCurrentSecondarySize() { return _currentSizeSecondary; } - Nd4jLong Workspace::getSpilledSecondarySize() { - return _spillsSizeSecondary; - } +Nd4jLong Workspace::getCurrentSecondaryOffset() { + return _offsetSecondary.load(); +} - Nd4jLong Workspace::getUsedSecondarySize() { - return getCurrentSecondaryOffset(); - } +Nd4jLong Workspace::getSpilledSecondarySize() { return _spillsSizeSecondary; } - } +Nd4jLong Workspace::getUsedSecondarySize() { + return getCurrentSecondaryOffset(); } + +} // namespace memory +} // namespace sd diff --git a/libnd4j/include/memory/impl/AllocationEntry.cpp b/libnd4j/include/memory/impl/AllocationEntry.cpp index 6b4d85bb1040..0a1598144b97 100644 --- a/libnd4j/include/memory/impl/AllocationEntry.cpp +++ b/libnd4j/include/memory/impl/AllocationEntry.cpp @@ -21,24 +21,19 @@ #include namespace sd { - namespace memory { - AllocationEntry::AllocationEntry(MemoryType type, Nd4jLong ptr, Nd4jLong numBytes, std::string &stack) { - _pointer = ptr; - _numBytes = numBytes; - _stack = stack; - _memoryType = type; - } +namespace memory { +AllocationEntry::AllocationEntry(MemoryType type, Nd4jLong ptr, + Nd4jLong numBytes, std::string &stack) { + _pointer = ptr; + _numBytes = numBytes; + _stack = stack; + _memoryType = type; +} - std::string AllocationEntry::stackTrace() { - return _stack; - } +std::string AllocationEntry::stackTrace() { return _stack; } - Nd4jLong AllocationEntry::numBytes() { - return _numBytes; - } +Nd4jLong AllocationEntry::numBytes() { return _numBytes; } - MemoryType AllocationEntry::memoryType() { - return _memoryType; - } - } -} \ No newline at end of file +MemoryType AllocationEntry::memoryType() { return _memoryType; } +} // namespace memory +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/memory/impl/ColdZoneManager.cpp b/libnd4j/include/memory/impl/ColdZoneManager.cpp new file mode 100644 index 000000000000..25c00ef27036 --- /dev/null +++ b/libnd4j/include/memory/impl/ColdZoneManager.cpp @@ -0,0 +1,43 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include + +namespace sd { +namespace memory { +ColdZoneManager::ColdZoneManager(const char *filename) { + // +} + +MemoryZone ColdZoneManager::zone() const { return COLD; } + +uint64_t ColdZoneManager::available() const { return 0; } + +uint64_t ColdZoneManager::used() const { return 0; } + +MemoryDescriptor ColdZoneManager::allocate(uint64_t numBytes) { + return MemoryDescriptor(nullptr, COLD, numBytes); +} + +void ColdZoneManager::release(MemoryDescriptor &descriptor) { + // +} +} // namespace memory +} // namespace sd diff --git a/libnd4j/include/memory/impl/ExternalWorkspace.cpp b/libnd4j/include/memory/impl/ExternalWorkspace.cpp index c4feb181dddb..e2a81b812b2d 100644 --- a/libnd4j/include/memory/impl/ExternalWorkspace.cpp +++ b/libnd4j/include/memory/impl/ExternalWorkspace.cpp @@ -21,29 +21,22 @@ #include namespace sd { - namespace memory { - ExternalWorkspace::ExternalWorkspace(Nd4jPointer ptrH, Nd4jLong sizeH, Nd4jPointer ptrD, Nd4jLong sizeD) { - _ptrH = ptrH; - _sizeH = sizeH; - - _ptrD = ptrD; - _sizeD = sizeD; - }; - - void* ExternalWorkspace::pointerHost() { - return _ptrH; - } - - void* ExternalWorkspace::pointerDevice() { - return _ptrD; - } - - Nd4jLong ExternalWorkspace::sizeHost() { - return _sizeH; - } - - Nd4jLong ExternalWorkspace::sizeDevice() { - return _sizeD; - } - } -} \ No newline at end of file +namespace memory { +ExternalWorkspace::ExternalWorkspace(Nd4jPointer ptrH, Nd4jLong sizeH, + Nd4jPointer ptrD, Nd4jLong sizeD) { + _ptrH = ptrH; + _sizeH = sizeH; + + _ptrD = ptrD; + _sizeD = sizeD; +}; + +void* ExternalWorkspace::pointerHost() { return _ptrH; } + +void* ExternalWorkspace::pointerDevice() { return _ptrD; } + +Nd4jLong ExternalWorkspace::sizeHost() { return _sizeH; } + +Nd4jLong ExternalWorkspace::sizeDevice() { return _sizeD; } +} // namespace memory +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/memory/impl/GraphMemoryManager.cpp b/libnd4j/include/memory/impl/GraphMemoryManager.cpp new file mode 100644 index 000000000000..44294a59f02d --- /dev/null +++ b/libnd4j/include/memory/impl/GraphMemoryManager.cpp @@ -0,0 +1,56 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include +#include +#include + +namespace sd { +namespace graph { +GraphMemoryManager::GraphMemoryManager() { + // first of all we initialize all memory managers + // CPU backend only has two: HOT and COLD + + _zones[MemoryZone::HOT] = new memory::HotRamZoneManager(); + _zones[MemoryZone::COLD] = new memory::ColdZoneManager(); +} + +GraphMemoryManager::~GraphMemoryManager() { + delete _zones[MemoryZone::HOT]; + delete _zones[MemoryZone::COLD]; +} + +MemoryDescriptor GraphMemoryManager::allocate(size_t numBytes, + MemoryZone zone) { + if (zone == MemoryZone::WARM) zone = MemoryZone::HOT; + + return _zones[zone]->allocate(numBytes); +} + +void GraphMemoryManager::release(MemoryDescriptor &descriptor) { + _zones[descriptor.zone()]->release(descriptor); +} + +void GraphMemoryManager::track(const std::shared_ptr &ptr) const { + _attached.emplace_back(ptr); +} + +} // namespace graph +} // namespace sd diff --git a/libnd4j/include/memory/impl/HotRamZoneManager.cpp b/libnd4j/include/memory/impl/HotRamZoneManager.cpp new file mode 100644 index 000000000000..15066bd09bb1 --- /dev/null +++ b/libnd4j/include/memory/impl/HotRamZoneManager.cpp @@ -0,0 +1,38 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include + +namespace sd { +namespace memory { +MemoryDescriptor HotRamZoneManager::allocate(uint64_t numBytes) { + _used += numBytes; + + auto ptr = new int8_t[numBytes]; + return MemoryDescriptor(ptr, zone(), numBytes); +} + +void HotRamZoneManager::release(MemoryDescriptor &descriptor) { + _used -= descriptor.bytes(); + + delete[](reinterpret_cast(descriptor.address())); +} +} // namespace memory +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/memory/impl/HotZoneManager.cpp b/libnd4j/include/memory/impl/HotZoneManager.cpp new file mode 100644 index 000000000000..73db24a94eb3 --- /dev/null +++ b/libnd4j/include/memory/impl/HotZoneManager.cpp @@ -0,0 +1,31 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include + +namespace sd { +namespace memory { +MemoryZone HotZoneManager::zone() const { return HOT; } + +uint64_t HotZoneManager::available() const { return _available; } + +uint64_t HotZoneManager::used() const { return _used; } +} // namespace memory +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/memory/impl/MemoryCounter.cpp b/libnd4j/include/memory/impl/MemoryCounter.cpp index 287b1989773b..00f072e241ec 100644 --- a/libnd4j/include/memory/impl/MemoryCounter.cpp +++ b/libnd4j/include/memory/impl/MemoryCounter.cpp @@ -19,111 +19,116 @@ // #include "../MemoryCounter.h" + #include -#include #include +#include namespace sd { - namespace memory { - - MemoryCounter::MemoryCounter() { - auto numDevices = sd::AffinityManager::numberOfDevices(); - - // setting default 0s - for (int e = 0; e < numDevices; e++) { - _deviceLimits[e] = 0; - _deviceCounters[e] = 0; - } - - // setting initial values for limits - _groupLimits[sd::memory::MemoryType::HOST] = sd::Environment::getInstance().maxPrimaryMemory(); - _groupLimits[sd::memory::MemoryType::DEVICE] = sd::Environment::getInstance().maxSpecialMemory(); - - // setting initial counter values - _groupCounters[sd::memory::MemoryType::HOST] = 0; - _groupCounters[sd::memory::MemoryType::DEVICE] = 0; - } - - MemoryCounter& MemoryCounter::getInstance() { - static MemoryCounter instance; - return instance; - } - - void MemoryCounter::countIn(int deviceId, Nd4jLong numBytes) { - std::lock_guard lock(_locker); - _deviceCounters[deviceId] += numBytes; - } - - void MemoryCounter::countIn(sd::memory::MemoryType group, Nd4jLong numBytes) { - std::lock_guard lock(_locker); - _groupCounters[group] += numBytes; - } - - void MemoryCounter::countOut(int deviceId, Nd4jLong numBytes) { - std::lock_guard lock(_locker); - _deviceCounters[deviceId] -= numBytes; - } - - void MemoryCounter::countOut(sd::memory::MemoryType group, Nd4jLong numBytes) { - std::lock_guard lock(_locker); - _groupCounters[group] -= numBytes; - } - - bool MemoryCounter::validate(Nd4jLong numBytes) { - auto deviceId = sd::AffinityManager::currentDeviceId(); - return validateDevice(deviceId, numBytes); - } - - bool MemoryCounter::validateDevice(int deviceId, Nd4jLong numBytes) { - std::lock_guard lock(_locker); - auto dLimit = _deviceLimits[deviceId]; - if (dLimit <= 0) - return true; - - auto dAlloc = _deviceCounters[deviceId]; - - return numBytes + dAlloc <= dLimit; - } - - bool MemoryCounter::validateGroup(sd::memory::MemoryType group, Nd4jLong numBytes) { - std::lock_guard lock(_locker); - auto gLimit = _groupLimits[group]; - if (gLimit <= 0) - return true; - - auto gAlloc = _groupCounters[group]; - - return numBytes + gAlloc <= gLimit; - } - - Nd4jLong MemoryCounter::allocatedDevice(int deviceId) { - std::lock_guard lock(_locker); - return _deviceCounters[deviceId]; - } - - Nd4jLong MemoryCounter::allocatedGroup(sd::memory::MemoryType group) { - std::lock_guard lock(_locker); - return _groupCounters[group]; - } - - void MemoryCounter::setDeviceLimit(int deviceId, Nd4jLong numBytes) { - std::lock_guard lock(_locker); - _deviceLimits[deviceId] = numBytes; - } - - void MemoryCounter::setGroupLimit(sd::memory::MemoryType group, Nd4jLong numBytes) { - std::lock_guard lock(_locker); - _groupLimits[group] = numBytes; - } - - Nd4jLong MemoryCounter::deviceLimit(int deviceId) { - std::lock_guard lock(_locker); - return _deviceLimits[deviceId]; - } - - Nd4jLong MemoryCounter::groupLimit(sd::memory::MemoryType group) { - std::lock_guard lock(_locker); - return _groupLimits[group]; - } - } -} \ No newline at end of file +namespace memory { + +MemoryCounter::MemoryCounter() { + auto numDevices = sd::AffinityManager::numberOfDevices(); + + // setting default 0s + for (int e = 0; e < numDevices; e++) { + _deviceLimits[e] = 0; + _deviceCounters[e] = 0; + } + + // setting initial values for limits + _groupLimits[sd::memory::MemoryType::HOST] = + sd::Environment::getInstance().maxPrimaryMemory(); + _groupLimits[sd::memory::MemoryType::DEVICE] = + sd::Environment::getInstance().maxSpecialMemory(); + + // setting initial counter values + _groupCounters[sd::memory::MemoryType::HOST] = 0; + _groupCounters[sd::memory::MemoryType::DEVICE] = 0; +} + +MemoryCounter& MemoryCounter::getInstance() { + static MemoryCounter instance; + + return instance; +} + +void MemoryCounter::countIn(int deviceId, Nd4jLong numBytes) { + std::lock_guard lock(_locker); + _deviceCounters[deviceId] += numBytes; +} + +void MemoryCounter::countIn(sd::memory::MemoryType group, Nd4jLong numBytes) { + std::lock_guard lock(_locker); + _groupCounters[group] += numBytes; +} + +void MemoryCounter::countOut(int deviceId, Nd4jLong numBytes) { + std::lock_guard lock(_locker); + _deviceCounters[deviceId] -= numBytes; +} + +void MemoryCounter::countOut(sd::memory::MemoryType group, Nd4jLong numBytes) { + std::lock_guard lock(_locker); + _groupCounters[group] -= numBytes; +} + +bool MemoryCounter::validate(Nd4jLong numBytes) { + auto deviceId = sd::AffinityManager::currentDeviceId(); + return validateDevice(deviceId, numBytes); +} + +bool MemoryCounter::validateDevice(int deviceId, Nd4jLong numBytes) { + std::lock_guard lock(_locker); + auto dLimit = _deviceLimits[deviceId]; + if (dLimit <= 0) return true; + + auto dAlloc = _deviceCounters[deviceId]; + + return numBytes + dAlloc <= dLimit; +} + +bool MemoryCounter::validateGroup(sd::memory::MemoryType group, + Nd4jLong numBytes) { + std::lock_guard lock(_locker); + auto gLimit = _groupLimits[group]; + if (gLimit <= 0) return true; + + auto gAlloc = _groupCounters[group]; + + return numBytes + gAlloc <= gLimit; +} + +Nd4jLong MemoryCounter::allocatedDevice(int deviceId) { + std::lock_guard lock(_locker); + return _deviceCounters[deviceId]; +} + +Nd4jLong MemoryCounter::allocatedGroup(sd::memory::MemoryType group) { + std::lock_guard lock(_locker); + return _groupCounters[group]; +} + +void MemoryCounter::setDeviceLimit(int deviceId, Nd4jLong numBytes) { + std::lock_guard lock(_locker); + _deviceLimits[deviceId] = numBytes; +} + +void MemoryCounter::setGroupLimit(sd::memory::MemoryType group, + Nd4jLong numBytes) { + std::lock_guard lock(_locker); + _groupLimits[group] = numBytes; +} + +Nd4jLong MemoryCounter::deviceLimit(int deviceId) { + std::lock_guard lock(_locker); + return _deviceLimits[deviceId]; +} + +Nd4jLong MemoryCounter::groupLimit(sd::memory::MemoryType group) { + std::lock_guard lock(_locker); + return _groupLimits[group]; +} + +} // namespace memory +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/memory/impl/MemoryDescriptor.cpp b/libnd4j/include/memory/impl/MemoryDescriptor.cpp new file mode 100644 index 000000000000..ea0c0e0b3c2a --- /dev/null +++ b/libnd4j/include/memory/impl/MemoryDescriptor.cpp @@ -0,0 +1,68 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include + +namespace sd { +namespace memory { +MemoryDescriptor::MemoryDescriptor(void *ptr, MemoryZone zone, uint64_t bytes) + : _ptr(ptr), _zone(zone), _bytes(bytes) { + // +} + +MemoryDescriptor::MemoryDescriptor(const MemoryDescriptor &other) noexcept + : _ptr(other._ptr), _zone(other._zone), _bytes(other._bytes) { + // +} + +MemoryDescriptor &MemoryDescriptor::operator=( + const MemoryDescriptor &other) noexcept { + if (this == &other) return *this; + + _ptr = other._ptr; + _zone = other._zone; + _bytes = other._bytes; + + return *this; +} + +MemoryDescriptor::MemoryDescriptor(MemoryDescriptor &&other) noexcept + : _ptr(other._ptr), _zone(other._zone), _bytes(other._bytes) { + // +} + +MemoryDescriptor &MemoryDescriptor::operator=( + MemoryDescriptor &&other) noexcept { + if (this == &other) return *this; + + _ptr = other._ptr; + _zone = other._zone; + _bytes = other._bytes; + + return *this; +} + +void *MemoryDescriptor::address() const { return _ptr; } + +MemoryZone MemoryDescriptor::zone() const { return _zone; } + +uint64_t MemoryDescriptor::bytes() const { return _bytes; } +} // namespace memory +} // namespace sd diff --git a/libnd4j/include/memory/impl/MemoryRegistrator.cpp b/libnd4j/include/memory/impl/MemoryRegistrator.cpp index 0ac2bf0cb155..de7a7119624d 100644 --- a/libnd4j/include/memory/impl/MemoryRegistrator.cpp +++ b/libnd4j/include/memory/impl/MemoryRegistrator.cpp @@ -21,65 +21,58 @@ #include namespace sd { - namespace memory { - - MemoryRegistrator::MemoryRegistrator() { - _workspace = nullptr; - }; - - MemoryRegistrator& MemoryRegistrator::getInstance() { - static MemoryRegistrator instance; - return instance; - } - - bool MemoryRegistrator::hasWorkspaceAttached() { - return _workspace != nullptr; - } - - Workspace* MemoryRegistrator::getWorkspace() { - return _workspace; - } - - void MemoryRegistrator::attachWorkspace(Workspace* workspace) { - _workspace = workspace; - } - - void MemoryRegistrator::forgetWorkspace() { - _workspace = nullptr; - } - - void MemoryRegistrator::setGraphMemoryFootprint(Nd4jLong hash, Nd4jLong bytes) { - _lock.lock(); - - _footprint[hash] = bytes; - - _lock.unlock(); - } - - void MemoryRegistrator::setGraphMemoryFootprintIfGreater(Nd4jLong hash, Nd4jLong bytes) { - _lock.lock(); - - if (_footprint.count(hash) == 0) - _footprint[hash] = bytes; - else { - Nd4jLong cv = _footprint[hash]; - if (bytes > cv) - _footprint[hash] = bytes; - } - - _lock.unlock(); - } - - Nd4jLong MemoryRegistrator::getGraphMemoryFootprint(Nd4jLong hash) { - _lock.lock(); - - Nd4jLong result = 0L; - if (_footprint.count(hash) > 0) - result = _footprint[hash]; - - _lock.unlock(); - - return result; - } - } -} \ No newline at end of file +namespace memory { + +MemoryRegistrator::MemoryRegistrator() { _workspace = nullptr; }; + +MemoryRegistrator& MemoryRegistrator::getInstance() { + static MemoryRegistrator instance; + + return instance; +} + +bool MemoryRegistrator::hasWorkspaceAttached() { return _workspace != nullptr; } + +Workspace* MemoryRegistrator::getWorkspace() { return _workspace; } + +void MemoryRegistrator::attachWorkspace(Workspace* workspace) { + _workspace = workspace; +} + +void MemoryRegistrator::forgetWorkspace() { _workspace = nullptr; } + +void MemoryRegistrator::setGraphMemoryFootprint(Nd4jLong hash, Nd4jLong bytes) { + _lock.lock(); + + _footprint[hash] = bytes; + + _lock.unlock(); +} + +void MemoryRegistrator::setGraphMemoryFootprintIfGreater(Nd4jLong hash, + Nd4jLong bytes) { + _lock.lock(); + + if (_footprint.count(hash) == 0) + _footprint[hash] = bytes; + else { + Nd4jLong cv = _footprint[hash]; + if (bytes > cv) _footprint[hash] = bytes; + } + + _lock.unlock(); +} + +Nd4jLong MemoryRegistrator::getGraphMemoryFootprint(Nd4jLong hash) { + _lock.lock(); + + Nd4jLong result = 0L; + if (_footprint.count(hash) > 0) result = _footprint[hash]; + + _lock.unlock(); + + return result; +} + +} // namespace memory +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/memory/impl/MemoryReport.cpp b/libnd4j/include/memory/impl/MemoryReport.cpp index 0c623b0cea4f..a56c9e1d1362 100644 --- a/libnd4j/include/memory/impl/MemoryReport.cpp +++ b/libnd4j/include/memory/impl/MemoryReport.cpp @@ -20,42 +20,42 @@ #include "memory/MemoryReport.h" -bool sd::memory::MemoryReport::operator<(const sd::memory::MemoryReport &other) const { - return this->_rss < other._rss; +bool sd::memory::MemoryReport::operator<( + const sd::memory::MemoryReport &other) const { + return this->_rss < other._rss; } -bool sd::memory::MemoryReport::operator>(const sd::memory::MemoryReport &other) const { - return this->_rss > other._rss; +bool sd::memory::MemoryReport::operator>( + const sd::memory::MemoryReport &other) const { + return this->_rss > other._rss; } -bool sd::memory::MemoryReport::operator==(const sd::memory::MemoryReport &other) const { - return this->_rss == other._rss; +bool sd::memory::MemoryReport::operator==( + const sd::memory::MemoryReport &other) const { + return this->_rss == other._rss; } -bool sd::memory::MemoryReport::operator!=(const sd::memory::MemoryReport &other) const { - return this->_rss != other._rss; +bool sd::memory::MemoryReport::operator!=( + const sd::memory::MemoryReport &other) const { + return this->_rss != other._rss; } -bool sd::memory::MemoryReport::operator<=(const sd::memory::MemoryReport &other) const { - return this->_rss <= other._rss; +bool sd::memory::MemoryReport::operator<=( + const sd::memory::MemoryReport &other) const { + return this->_rss <= other._rss; } -bool sd::memory::MemoryReport::operator>=(const sd::memory::MemoryReport &other) const { - return this->_rss >= other._rss; +bool sd::memory::MemoryReport::operator>=( + const sd::memory::MemoryReport &other) const { + return this->_rss >= other._rss; } -Nd4jLong sd::memory::MemoryReport::getVM() const { - return _vm; -} +Nd4jLong sd::memory::MemoryReport::getVM() const { return _vm; } -void sd::memory::MemoryReport::setVM(Nd4jLong _vm) { - MemoryReport::_vm = _vm; -} +void sd::memory::MemoryReport::setVM(Nd4jLong _vm) { MemoryReport::_vm = _vm; } -Nd4jLong sd::memory::MemoryReport::getRSS() const { - return _rss; -} +Nd4jLong sd::memory::MemoryReport::getRSS() const { return _rss; } void sd::memory::MemoryReport::setRSS(Nd4jLong _rss) { - MemoryReport::_rss = _rss; + MemoryReport::_rss = _rss; } diff --git a/libnd4j/include/memory/impl/MemoryTracker.cpp b/libnd4j/include/memory/impl/MemoryTracker.cpp index cf2b975cf188..cb35467b94f5 100644 --- a/libnd4j/include/memory/impl/MemoryTracker.cpp +++ b/libnd4j/include/memory/impl/MemoryTracker.cpp @@ -18,157 +18,164 @@ // Created by raver119 on 07.05.19. // -#include -#include #include - - +#include #include -#if defined(__GNUC__) && !defined(__MINGW64__) && !defined(SD_ANDROID_BUILD) && !defined(SD_IOS_BUILD) && !defined(SD_APPLE_BUILD) +#include + +#if defined(__GNUC__) && !defined(__MINGW64__) && \ + !defined(SD_ANDROID_BUILD) && !defined(SD_IOS_BUILD) && \ + !defined(SD_APPLE_BUILD) -#include -#include #include +#include +#include #endif namespace sd { - namespace memory { - - MemoryTracker::MemoryTracker() { - // - } - - MemoryTracker& MemoryTracker::getInstance() { - static MemoryTracker instance; - return instance; - } - -#if defined(__GNUC__) && !defined(__MINGW64__) && !defined(SD_ANDROID_BUILD) && !defined(SD_IOS_BUILD) && !defined(SD_APPLE_BUILD) - std::string demangle(char *message) { - char *mangled_name = 0, *offset_begin = 0, *offset_end = 0; - - // find parantheses and +address offset surrounding mangled name - for (char *p = message; *p; ++p) - { - if (*p == '(') - { - mangled_name = p; - } - else if (*p == '+') - { - offset_begin = p; - } - else if (*p == ')') - { - offset_end = p; - break; - } - } - - // if the line could be processed, attempt to demangle the symbol - if (mangled_name && offset_begin && offset_end && mangled_name < offset_begin) { - *mangled_name++ = '\0'; - *offset_begin++ = '\0'; - *offset_end++ = '\0'; - - int status; - char * real_name = abi::__cxa_demangle(mangled_name, 0, 0, &status); - - // if demangling is successful, output the demangled function name - if (status == 0) { - std::string result(real_name); - free(real_name); - return result; - } else { - // otherwise, output the mangled function name - std::string result (message); - free(real_name); - return result; - } - } - - // safe return - return std::string(""); - } +namespace memory { -#endif +MemoryTracker::MemoryTracker() { + // +} - void MemoryTracker::countIn(MemoryType type, Nd4jPointer ptr, Nd4jLong numBytes) { -#if defined(__GNUC__) && !defined(__MINGW64__) && !defined(SD_ANDROID_BUILD) && !defined(SD_IOS_BUILD) && !defined(SD_APPLE_BUILD) - if (Environment::getInstance().isDetectingLeaks()) { - auto lptr = reinterpret_cast(ptr); +MemoryTracker &MemoryTracker::getInstance() { + static MemoryTracker instance; - _locker.lock(); + return instance; +} - void *array[50]; - size_t size; - char **messages; - size = backtrace(array, 50); +#if defined(__GNUC__) && !defined(__MINGW64__) && \ + !defined(SD_ANDROID_BUILD) && !defined(SD_IOS_BUILD) && \ + !defined(SD_APPLE_BUILD) +std::string demangle(char *message) { + char *mangled_name = 0, *offset_begin = 0, *offset_end = 0; + + // find parantheses and +address offset surrounding mangled name + for (char *p = message; *p; ++p) { + if (*p == '(') { + mangled_name = p; + } else if (*p == '+') { + offset_begin = p; + } else if (*p == ')') { + offset_end = p; + break; + } + } + + // if the line could be processed, attempt to demangle the symbol + if (mangled_name && offset_begin && offset_end && + mangled_name < offset_begin) { + *mangled_name++ = '\0'; + *offset_begin++ = '\0'; + *offset_end++ = '\0'; + + int status; + char *real_name = abi::__cxa_demangle(mangled_name, 0, 0, &status); + + // if demangling is successful, output the demangled function name + if (status == 0) { + std::string result(real_name); + free(real_name); + return result; + } else { + // otherwise, output the mangled function name + std::string result(message); + free(real_name); + return result; + } + } - std::string stack(""); - messages = backtrace_symbols(array, size); - for (int i = 1; i < size && messages != NULL; ++i) { - stack += demangle(messages[i]) + "\n"; - } + // safe return + return std::string(""); +} - free(messages); +#endif - if (stack.find("ConstantTad") != std::string::npos || - stack.find("ConstantShape") != std::string::npos) { - _locker.unlock(); - return; - } +void MemoryTracker::countIn(MemoryType type, Nd4jPointer ptr, + Nd4jLong numBytes) { +#if defined(__GNUC__) && !defined(__MINGW64__) && \ + !defined(SD_ANDROID_BUILD) && !defined(SD_IOS_BUILD) && \ + !defined(SD_APPLE_BUILD) + if (Environment::getInstance().isDetectingLeaks()) { + auto lptr = reinterpret_cast(ptr); + + _locker.lock(); + + void *array[50]; + size_t size; + char **messages; + size = backtrace(array, 50); + + std::string stack(""); + messages = backtrace_symbols(array, size); + for (int i = 1; i < size && messages != NULL; ++i) { + stack += demangle(messages[i]) + "\n"; + } - std::pair pair(lptr, AllocationEntry(type, lptr, numBytes, stack)); - _allocations.insert(pair); + free(messages); - _locker.unlock(); - } -#endif - } + if (stack.find("ConstantTad") != std::string::npos || + stack.find("ConstantShape") != std::string::npos) { + _locker.unlock(); + return; + } - void MemoryTracker::countOut(Nd4jPointer ptr) { -#if defined(__GNUC__) && !defined(__MINGW64__) && !defined(SD_ANDROID_BUILD) && !defined(SD_IOS_BUILD) && !defined(SD_APPLE_BUILD) - if (Environment::getInstance().isDetectingLeaks()) { - auto lptr = reinterpret_cast(ptr); + std::pair pair( + lptr, AllocationEntry(type, lptr, numBytes, stack)); + _allocations.insert(pair); - _locker.lock(); - if (_released.count(lptr) > 0) { - //throw std::runtime_error("Double free!"); - } + _locker.unlock(); + } +#endif +} - if (_allocations.count(lptr) > 0) { - //auto entry = _allocations[lptr]; - //std::string stack("new stack"); - //std::pair pair(lptr, entry); - //_released.insert(pair); +void MemoryTracker::countOut(Nd4jPointer ptr) { +#if defined(__GNUC__) && !defined(__MINGW64__) && \ + !defined(SD_ANDROID_BUILD) && !defined(SD_IOS_BUILD) && \ + !defined(SD_APPLE_BUILD) + if (Environment::getInstance().isDetectingLeaks()) { + auto lptr = reinterpret_cast(ptr); + _locker.lock(); + if (_released.count(lptr) > 0) { + // throw std::runtime_error("Double free!"); + } - _allocations.erase(lptr); - } + if (_allocations.count(lptr) > 0) { + // auto entry = _allocations[lptr]; + // std::string stack("new stack"); + // std::pair pair(lptr, entry); + //_released.insert(pair); - _locker.unlock(); - } + _allocations.erase(lptr); + } + + _locker.unlock(); + } #endif - } +} - void MemoryTracker::summarize() { - if (!_allocations.empty()) { - nd4j_printf("\n%i leaked allocations\n", (int) _allocations.size()); +void MemoryTracker::summarize() { + if (!_allocations.empty()) { + nd4j_printf("\n%i leaked allocations\n", (int)_allocations.size()); - for (auto &v: _allocations) { - nd4j_printf("Leak of %i [%s] bytes\n%s\n\n", (int) v.second.numBytes(), v.second.memoryType() == MemoryType::HOST ? "HOST" : "DEVICE", v.second.stackTrace().c_str()); - } + for (auto &v : _allocations) { + nd4j_printf("Leak of %i [%s] bytes\n%s\n\n", (int)v.second.numBytes(), + v.second.memoryType() == MemoryType::HOST ? "HOST" : "DEVICE", + v.second.stackTrace().c_str()); + } - throw std::runtime_error("Non-released allocations found"); - } - } + throw std::runtime_error("Non-released allocations found"); + } +} - void MemoryTracker::reset() { - _allocations.clear(); - _released.clear(); - } - } +void MemoryTracker::reset() { + _allocations.clear(); + _released.clear(); } + +} // namespace memory +} // namespace sd diff --git a/libnd4j/include/memory/impl/MemoryUtils.cpp b/libnd4j/include/memory/impl/MemoryUtils.cpp index 8500a044e341..0c576ba81057 100644 --- a/libnd4j/include/memory/impl/MemoryUtils.cpp +++ b/libnd4j/include/memory/impl/MemoryUtils.cpp @@ -19,78 +19,81 @@ // #include "memory/MemoryUtils.h" + #include #if defined(__APPLE__) -#include +#include #include #elif defined(_WIN32) || defined(__WIN32__) || defined(WIN32) #else // linux -#include #include +#include #include + #include #endif - -bool sd::memory::MemoryUtils::retrieveMemoryStatistics(sd::memory::MemoryReport &report) { +bool sd::memory::MemoryUtils::retrieveMemoryStatistics( + sd::memory::MemoryReport &report) { #if defined(__APPLE__) - nd4j_debug("APPLE route\n", ""); -/* - struct task_basic_info t_info; - mach_msg_type_number_t t_info_count = TASK_BASIC_INFO_COUNT; + nd4j_debug("APPLE route\n", ""); + /* + struct task_basic_info t_info; + mach_msg_type_number_t t_info_count = TASK_BASIC_INFO_COUNT; - if (KERN_SUCCESS != task_info(mach_task_self(), TASK_BASIC_INFO, (task_info_t)&t_info, &t_info_count)) - return false; + if (KERN_SUCCESS != task_info(mach_task_self(), TASK_BASIC_INFO, + (task_info_t)&t_info, &t_info_count)) return false; - report.setVM(t_info.resident_size); - report.setRSS(t_info.resident_size); + report.setVM(t_info.resident_size); + report.setRSS(t_info.resident_size); - nd4j_debug("RSS: %lld; VM: %lld;\n", report.getRSS(), report.getVM()); -*/ - struct rusage _usage; + nd4j_debug("RSS: %lld; VM: %lld;\n", report.getRSS(), report.getVM()); + */ + struct rusage _usage; - auto res = getrusage(RUSAGE_SELF, &_usage); + auto res = getrusage(RUSAGE_SELF, &_usage); - report.setRSS(_usage.ru_maxrss); + report.setRSS(_usage.ru_maxrss); - nd4j_debug("Usage: %lld; %lld; %lld; %lld;\n", _usage.ru_ixrss, _usage.ru_idrss, _usage.ru_isrss, _usage.ru_maxrss); + nd4j_debug("Usage: %lld; %lld; %lld; %lld;\n", _usage.ru_ixrss, + _usage.ru_idrss, _usage.ru_isrss, _usage.ru_maxrss); - return true; + return true; #elif defined(_WIN32) || defined(__WIN32__) || defined(WIN32) - nd4j_debug("WIN32 route\n", ""); - + nd4j_debug("WIN32 route\n", ""); #else - nd4j_debug("LINUX route\n", ""); - int fd = open("/proc/self/statm", O_RDONLY, 0); - if (fd >= 0) { - char line[256]; - char* s; - int n; - lseek(fd, 0, SEEK_SET); - if ((n = read(fd, line, sizeof(line))) > 0 && (s = (char*)memchr(line, ' ', n)) != NULL) { - report.setRSS((Nd4jLong)(atoll(s + 1) * getpagesize())); - } - close(fd); + nd4j_debug("LINUX route\n", ""); + int fd = open("/proc/self/statm", O_RDONLY, 0); + if (fd >= 0) { + char line[256]; + char* s; + int n; + lseek(fd, 0, SEEK_SET); + if ((n = read(fd, line, sizeof(line))) > 0 && + (s = (char*)memchr(line, ' ', n)) != NULL) { + report.setRSS((Nd4jLong)(atoll(s + 1) * getpagesize())); } + close(fd); + } - /* - struct rusage _usage; - - auto res = getrusage(RUSAGE_SELF, &_usage); + /* + struct rusage _usage; - report.setRSS(_usage.ru_maxrss); + auto res = getrusage(RUSAGE_SELF, &_usage); - //nd4j_printf("Usage: %lld; %lld; %lld; %lld;\n", _usage.ru_ixrss, _usage.ru_idrss, _usage.ru_isrss, _usage.ru_maxrss); - */ + report.setRSS(_usage.ru_maxrss); + //nd4j_printf("Usage: %lld; %lld; %lld; %lld;\n", _usage.ru_ixrss, + _usage.ru_idrss, _usage.ru_isrss, _usage.ru_maxrss); + */ - return true; + return true; #endif - return false; + return false; } diff --git a/libnd4j/include/memory/impl/MmapDeallocator.cpp b/libnd4j/include/memory/impl/MmapDeallocator.cpp new file mode 100644 index 000000000000..b4df0ada8009 --- /dev/null +++ b/libnd4j/include/memory/impl/MmapDeallocator.cpp @@ -0,0 +1,33 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include +#include + +namespace sd { +namespace memory { + +void MmapDeallocator::release(void *ptr) { + auto ref = reinterpret_cast(ptr); + ::munmapFile(nullptr, ref, ref[2]); +} + +} // namespace memory +} // namespace sd diff --git a/libnd4j/include/memory/impl/WarmZoneManager.cpp b/libnd4j/include/memory/impl/WarmZoneManager.cpp new file mode 100644 index 000000000000..d6969fae099f --- /dev/null +++ b/libnd4j/include/memory/impl/WarmZoneManager.cpp @@ -0,0 +1,21 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include diff --git a/libnd4j/include/memory/impl/ZoneManager.cpp b/libnd4j/include/memory/impl/ZoneManager.cpp new file mode 100644 index 000000000000..4b90acddc2ad --- /dev/null +++ b/libnd4j/include/memory/impl/ZoneManager.cpp @@ -0,0 +1,21 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include diff --git a/libnd4j/include/ops/BroadcastBoolOpsTuple.h b/libnd4j/include/ops/BroadcastBoolOpsTuple.h index 188186b4ceb4..2a10326a84c2 100644 --- a/libnd4j/include/ops/BroadcastBoolOpsTuple.h +++ b/libnd4j/include/ops/BroadcastBoolOpsTuple.h @@ -18,33 +18,35 @@ // @author raver119@gmail.com // -#ifndef DEV_TESTS_BROADCASTBOOLOPSTUPLE_H -#define DEV_TESTS_BROADCASTBOOLOPSTUPLE_H +#ifndef SD_BROADCASTBOOLOPSTUPLE_H +#define SD_BROADCASTBOOLOPSTUPLE_H -#include #include +#include namespace sd { - class ND4J_EXPORT BroadcastBoolOpsTuple { - private: - - public: - sd::scalar::BoolOps s; - sd::pairwise::BoolOps p; - sd::broadcast::BoolOps b; - - BroadcastBoolOpsTuple() = default; - ~BroadcastBoolOpsTuple() = default; - - BroadcastBoolOpsTuple(sd::scalar::BoolOps scalar, sd::pairwise::BoolOps pairwise, sd::broadcast::BoolOps broadcast) { - s = scalar; - p = pairwise; - b = broadcast; - } - - static BroadcastBoolOpsTuple custom(sd::scalar::BoolOps scalar, sd::pairwise::BoolOps pairwise, sd::broadcast::BoolOps broadcast); - }; -} - - -#endif //DEV_TESTS_BROADCASTOPSTUPLE_H +class SD_EXPORT BroadcastBoolOpsTuple { + private: + public: + sd::scalar::BoolOps s; + sd::pairwise::BoolOps p; + sd::broadcast::BoolOps b; + + BroadcastBoolOpsTuple() = default; + ~BroadcastBoolOpsTuple() = default; + + BroadcastBoolOpsTuple(sd::scalar::BoolOps scalar, + sd::pairwise::BoolOps pairwise, + sd::broadcast::BoolOps broadcast) { + s = scalar; + p = pairwise; + b = broadcast; + } + + static BroadcastBoolOpsTuple custom(sd::scalar::BoolOps scalar, + sd::pairwise::BoolOps pairwise, + sd::broadcast::BoolOps broadcast); +}; +} // namespace sd + +#endif // SD_BROADCASTOPSTUPLE_H diff --git a/libnd4j/include/ops/BroadcastIntOpsTuple.h b/libnd4j/include/ops/BroadcastIntOpsTuple.h index 258719004aba..9d0fbdf81cce 100644 --- a/libnd4j/include/ops/BroadcastIntOpsTuple.h +++ b/libnd4j/include/ops/BroadcastIntOpsTuple.h @@ -18,33 +18,34 @@ // @author raver119@gmail.com // -#ifndef DEV_TESTS_BROADCASTINTOPSTUPLE_H -#define DEV_TESTS_BROADCASTINTOPSTUPLE_H +#ifndef SD_BROADCASTINTOPSTUPLE_H +#define SD_BROADCASTINTOPSTUPLE_H -#include #include +#include namespace sd { - class ND4J_EXPORT BroadcastIntOpsTuple { - private: - - public: - sd::scalar::IntOps s; - sd::pairwise::IntOps p; - sd::broadcast::IntOps b; - - BroadcastIntOpsTuple() = default; - ~BroadcastIntOpsTuple() = default; - - BroadcastIntOpsTuple(sd::scalar::IntOps scalar, sd::pairwise::IntOps pairwise, sd::broadcast::IntOps broadcast) { - s = scalar; - p = pairwise; - b = broadcast; - } - - static BroadcastIntOpsTuple custom(sd::scalar::IntOps scalar, sd::pairwise::IntOps pairwise, sd::broadcast::IntOps broadcast); - }; -} - - -#endif //DEV_TESTS_BROADCASTOPSTUPLE_H +class SD_EXPORT BroadcastIntOpsTuple { + private: + public: + sd::scalar::IntOps s; + sd::pairwise::IntOps p; + sd::broadcast::IntOps b; + + BroadcastIntOpsTuple() = default; + ~BroadcastIntOpsTuple() = default; + + BroadcastIntOpsTuple(sd::scalar::IntOps scalar, sd::pairwise::IntOps pairwise, + sd::broadcast::IntOps broadcast) { + s = scalar; + p = pairwise; + b = broadcast; + } + + static BroadcastIntOpsTuple custom(sd::scalar::IntOps scalar, + sd::pairwise::IntOps pairwise, + sd::broadcast::IntOps broadcast); +}; +} // namespace sd + +#endif // SD_BROADCASTOPSTUPLE_H diff --git a/libnd4j/include/ops/BroadcastOpsTuple.h b/libnd4j/include/ops/BroadcastOpsTuple.h index 34e2c603995d..81b470076c28 100644 --- a/libnd4j/include/ops/BroadcastOpsTuple.h +++ b/libnd4j/include/ops/BroadcastOpsTuple.h @@ -18,45 +18,46 @@ // @author raver119@gmail.com // -#ifndef DEV_TESTS_BROADCASTOPSTUPLE_H -#define DEV_TESTS_BROADCASTOPSTUPLE_H +#ifndef SD_BROADCASTOPSTUPLE_H +#define SD_BROADCASTOPSTUPLE_H -#include #include +#include namespace sd { - class ND4J_EXPORT BroadcastOpsTuple { - private: - - public: - sd::scalar::Ops s; - sd::pairwise::Ops p; - sd::broadcast::Ops b; - - BroadcastOpsTuple() = default; - ~BroadcastOpsTuple() = default; - - BroadcastOpsTuple(sd::scalar::Ops scalar, sd::pairwise::Ops pairwise, sd::broadcast::Ops broadcast) { - s = scalar; - p = pairwise; - b = broadcast; - } - - static BroadcastOpsTuple custom(sd::scalar::Ops scalar, sd::pairwise::Ops pairwise, sd::broadcast::Ops broadcast); - - static BroadcastOpsTuple Add(); - static BroadcastOpsTuple Assign(); - static BroadcastOpsTuple Divide(); - static BroadcastOpsTuple DivideNoNan(); - static BroadcastOpsTuple Multiply(); - static BroadcastOpsTuple Subtract(); - static BroadcastOpsTuple IGamma(); - static BroadcastOpsTuple IGammac(); - - static BroadcastOpsTuple Pow(); - static BroadcastOpsTuple PowDerivative(); - }; -} - - -#endif //DEV_TESTS_BROADCASTOPSTUPLE_H +class SD_EXPORT BroadcastOpsTuple { + private: + public: + sd::scalar::Ops s; + sd::pairwise::Ops p; + sd::broadcast::Ops b; + + BroadcastOpsTuple() = default; + ~BroadcastOpsTuple() = default; + + BroadcastOpsTuple(sd::scalar::Ops scalar, sd::pairwise::Ops pairwise, + sd::broadcast::Ops broadcast) { + s = scalar; + p = pairwise; + b = broadcast; + } + + static BroadcastOpsTuple custom(sd::scalar::Ops scalar, + sd::pairwise::Ops pairwise, + sd::broadcast::Ops broadcast); + + static BroadcastOpsTuple Add(); + static BroadcastOpsTuple Assign(); + static BroadcastOpsTuple Divide(); + static BroadcastOpsTuple DivideNoNan(); + static BroadcastOpsTuple Multiply(); + static BroadcastOpsTuple Subtract(); + static BroadcastOpsTuple IGamma(); + static BroadcastOpsTuple IGammac(); + + static BroadcastOpsTuple Pow(); + static BroadcastOpsTuple PowDerivative(); +}; +} // namespace sd + +#endif // SD_BROADCASTOPSTUPLE_H diff --git a/libnd4j/include/ops/InputType.h b/libnd4j/include/ops/InputType.h index 4deff4900807..94e62fb432de 100644 --- a/libnd4j/include/ops/InputType.h +++ b/libnd4j/include/ops/InputType.h @@ -22,15 +22,15 @@ #define ND4J_INPUTTYPE_H namespace sd { - namespace ops { - enum InputType { - InputType_BOOLEAN = 0, - InputType_NUMERIC = 1, - InputType_STRINGULAR = 2, - InputType_NUMERIC_SET = 3, - InputType_STRINGULAR_SET = 4, - }; - } +namespace ops { +enum InputType { + InputType_BOOLEAN = 0, + InputType_NUMERIC = 1, + InputType_STRINGULAR = 2, + InputType_NUMERIC_SET = 3, + InputType_STRINGULAR_SET = 4, +}; } +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/BooleanOp.h b/libnd4j/include/ops/declarable/BooleanOp.h index b04ca8ecab54..f0a9f11ab9fe 100644 --- a/libnd4j/include/ops/declarable/BooleanOp.h +++ b/libnd4j/include/ops/declarable/BooleanOp.h @@ -22,30 +22,31 @@ #define LIBND4J_BOOLEANOP_H #include -#include "OpDescriptor.h" + #include "DeclarableOp.h" +#include "OpDescriptor.h" namespace sd { - namespace ops { - class ND4J_EXPORT BooleanOp : public DeclarableOp { - protected: - OpDescriptor * _descriptor; - - bool prepareOutputs(Context& block); - Nd4jStatus validateAndExecute(Context& block) override = 0; - public: - BooleanOp(const char *name, int numInputs, bool scalar); +namespace ops { +class SD_EXPORT BooleanOp : public DeclarableOp { + protected: + OpDescriptor* _descriptor; - bool verify(const std::vector& args); - bool verify(sd::graph::Context& block); + bool prepareOutputs(Context& block); + Nd4jStatus validateAndExecute(Context& block) override = 0; - Nd4jStatus execute(Context* block) override; + public: + BooleanOp(const char* name, int numInputs, bool scalar); - ShapeList *calculateOutputShape(ShapeList *inputShape, sd::graph::Context& block) override; - }; - } -} + bool verify(const std::vector& args); + bool verify(sd::graph::Context& block); + Nd4jStatus execute(Context* block) override; + ShapeList* calculateOutputShape(ShapeList* inputShape, + sd::graph::Context& block) override; +}; +} // namespace ops +} // namespace sd -#endif //LIBND4J_BOOLEANOP_H \ No newline at end of file +#endif // LIBND4J_BOOLEANOP_H \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/BroadcastableBoolOp.h b/libnd4j/include/ops/declarable/BroadcastableBoolOp.h index c48650294315..2455667f7414 100644 --- a/libnd4j/include/ops/declarable/BroadcastableBoolOp.h +++ b/libnd4j/include/ops/declarable/BroadcastableBoolOp.h @@ -22,22 +22,24 @@ #define SD_BROADCASTABLEBOOLOP_H #include -#include "OpDescriptor.h" -#include "DeclarableOp.h" + #include "DeclarableCustomOp.h" +#include "DeclarableOp.h" +#include "OpDescriptor.h" namespace sd { - namespace ops { - class ND4J_EXPORT BroadcastableBoolOp : public DeclarableCustomOp{ - protected: - Nd4jStatus validateAndExecute(Context& block) override = 0; - public: - BroadcastableBoolOp(const char *name, int numTArgs, int numIArgs); +namespace ops { +class SD_EXPORT BroadcastableBoolOp : public DeclarableCustomOp { + protected: + Nd4jStatus validateAndExecute(Context &block) override = 0; - ShapeList *calculateOutputShape(ShapeList *inputShape, sd::graph::Context& block) override; - }; - } -} + public: + BroadcastableBoolOp(const char *name, int numTArgs, int numIArgs); + ShapeList *calculateOutputShape(ShapeList *inputShape, + sd::graph::Context &block) override; +}; +} // namespace ops +} // namespace sd -#endif //SD_BROADCASTABLEBOOLOP_H +#endif // SD_BROADCASTABLEBOOLOP_H diff --git a/libnd4j/include/ops/declarable/BroadcastableOp.h b/libnd4j/include/ops/declarable/BroadcastableOp.h index 9bc7561283e8..9b517e160449 100644 --- a/libnd4j/include/ops/declarable/BroadcastableOp.h +++ b/libnd4j/include/ops/declarable/BroadcastableOp.h @@ -22,22 +22,24 @@ #define LIBND4J_BROADCASTABLEOP_H #include -#include "OpDescriptor.h" -#include "DeclarableOp.h" + #include "DeclarableCustomOp.h" +#include "DeclarableOp.h" +#include "OpDescriptor.h" namespace sd { - namespace ops { - class ND4J_EXPORT BroadcastableOp : public DeclarableCustomOp{ - protected: - Nd4jStatus validateAndExecute(Context& block) override = 0; - public: - BroadcastableOp(const char *name, int numTArgs, int numIArgs); +namespace ops { +class SD_EXPORT BroadcastableOp : public DeclarableCustomOp { + protected: + Nd4jStatus validateAndExecute(Context &block) override = 0; - ShapeList *calculateOutputShape(ShapeList *inputShape, sd::graph::Context& block) override; - }; - } -} + public: + BroadcastableOp(const char *name, int numTArgs, int numIArgs); + ShapeList *calculateOutputShape(ShapeList *inputShape, + sd::graph::Context &block) override; +}; +} // namespace ops +} // namespace sd -#endif //LIBND4J_BROADCASTABLEOP_H +#endif // LIBND4J_BROADCASTABLEOP_H diff --git a/libnd4j/include/ops/declarable/CustomOperations.h b/libnd4j/include/ops/declarable/CustomOperations.h index 8aa612c7bfe5..dccb3e96d02c 100644 --- a/libnd4j/include/ops/declarable/CustomOperations.h +++ b/libnd4j/include/ops/declarable/CustomOperations.h @@ -21,67 +21,66 @@ #ifndef LIBND4J_CUSTOMOPERATIONS_H #define LIBND4J_CUSTOMOPERATIONS_H +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include #include +#include +#include #include #include -#include +#include #include +#include +#include +#include +#include #include -#include -#include -#include -#include -#include +#include #include #include -#include -#include -#include -#include -#include -#include -#include +#include +#include +#include +#include #include -#include -#include -#include -#include +#include +#include +#include #include +#include #include -#include -#include -#include -#include -#include -#include -#include -#include -#include - namespace sd { - struct ND4J_EXPORT _loader { - _loader(); - }; - - namespace ops { +struct SD_EXPORT _loader { + _loader(); +}; - // logic ops - DECLARE_DIVERGENT_OP(Switch, 2, 2, true); - DECLARE_LOGIC_OP(While); - DECLARE_LOGIC_OP(Scope); - DECLARE_LOGIC_OP(Conditional); - DECLARE_LOGIC_OP(Return); +namespace ops { +// logic ops +DECLARE_DIVERGENT_OP(Switch, 2, 2, true); +DECLARE_LOGIC_OP(While); +DECLARE_LOGIC_OP(Scope); +DECLARE_LOGIC_OP(Conditional); +DECLARE_LOGIC_OP(Return); - /** - * This operations exposes given arguments as it's own outputs, but does it only once. - * Subsequent calls will be served directly by this op. - * - * PLEASE NOTE: This operation is internal graph operation, and shouldn't be used directly usually. - */ - DECLARE_CUSTOM_OP(expose, -1, -1, true, 0, 0); - } -} +/** + * This operations exposes given arguments as it's own outputs, but does it only + * once. Subsequent calls will be served directly by this op. + * + * PLEASE NOTE: This operation is internal graph operation, and shouldn't be + * used directly usually. + */ +DECLARE_CUSTOM_OP(expose, -1, -1, true, 0, 0); +} // namespace ops +} // namespace sd -#endif //LIBND4J_CUSTOMOPERATIONS_H +#endif // LIBND4J_CUSTOMOPERATIONS_H diff --git a/libnd4j/include/ops/declarable/DeclarableCustomOp.h b/libnd4j/include/ops/declarable/DeclarableCustomOp.h index 4aa133a4bba8..163c519bad8c 100644 --- a/libnd4j/include/ops/declarable/DeclarableCustomOp.h +++ b/libnd4j/include/ops/declarable/DeclarableCustomOp.h @@ -24,19 +24,22 @@ #include namespace sd { - namespace ops { - class ND4J_EXPORT DeclarableCustomOp : public sd::ops::DeclarableOp { - protected: - /** - * This method executes this Op - */ - Nd4jStatus validateAndExecute(Context& block) override = 0; - public: - DeclarableCustomOp(int numInputs, int numOutputs, const char *opName, bool allowsInplace, int tArgs, int iArgs); +namespace ops { +class SD_EXPORT DeclarableCustomOp : public sd::ops::DeclarableOp { + protected: + /** + * This method executes this Op + */ + Nd4jStatus validateAndExecute(Context& block) override = 0; - ShapeList* calculateOutputShape(ShapeList* inputShapes, sd::graph::Context& block) override = 0; - }; - } -} + public: + DeclarableCustomOp(int numInputs, int numOutputs, const char* opName, + bool allowsInplace, int tArgs, int iArgs); -#endif //LIBND4J_DECLARABLECUSTOMOP_H + ShapeList* calculateOutputShape(ShapeList* inputShapes, + sd::graph::Context& block) override = 0; +}; +} // namespace ops +} // namespace sd + +#endif // LIBND4J_DECLARABLECUSTOMOP_H diff --git a/libnd4j/include/ops/declarable/DeclarableListOp.h b/libnd4j/include/ops/declarable/DeclarableListOp.h index cc77ee17b271..aab3613e381d 100644 --- a/libnd4j/include/ops/declarable/DeclarableListOp.h +++ b/libnd4j/include/ops/declarable/DeclarableListOp.h @@ -23,34 +23,36 @@ #include #include -#include #include +#include using namespace sd::graph; namespace sd { - namespace ops { - class ND4J_EXPORT DeclarableListOp : public sd::ops::DeclarableOp { - protected: - Nd4jStatus validateAndExecute(Context& block) override = 0; - - sd::NDArray* getZ(Context& block, int inputId) ; - void setupResult(NDArray* array, Context& block); - void setupResultList(NDArrayList* arrayList, Context& block); - - public: - DeclarableListOp(int numInputs, int numOutputs, const char* opName, int tArgs, int iArgs); - - - Nd4jStatus execute(Context* block) override; - - - ResultSet execute(NDArrayList* list, std::initializer_list inputs, std::initializer_list tArgs, std::initializer_list iArgs); - ResultSet execute(NDArrayList* list, std::vector& inputs, std::vector& tArgs, std::vector& iArgs); - - ShapeList* calculateOutputShape(ShapeList* inputShape, sd::graph::Context& block) override; - }; - } -} +namespace ops { +class SD_EXPORT DeclarableListOp : public sd::ops::DeclarableOp { + protected: + Nd4jStatus validateAndExecute(Context& block) override = 0; + + sd::NDArray* getZ(Context& block, int inputId); + void setupResult(const NDArray& array, Context& block); + void setupResultList(const NDArrayList& arrayList, Context& block); + + public: + DeclarableListOp(int numInputs, int numOutputs, const char* opName, int tArgs, + int iArgs); + + Nd4jStatus execute(Context* block) override; + + ResultSet execute(const NDArrayList& list, + const std::vector& inputs, + const std::vector& tArgs = {}, + const std::vector& iArgs = {}); + + ShapeList* calculateOutputShape(ShapeList* inputShape, + sd::graph::Context& block) override; +}; +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/DeclarableOp.h b/libnd4j/include/ops/declarable/DeclarableOp.h index 3cce3b8e4fd8..e78b211fdcc8 100644 --- a/libnd4j/include/ops/declarable/DeclarableOp.h +++ b/libnd4j/include/ops/declarable/DeclarableOp.h @@ -21,18 +21,20 @@ #ifndef LIBND4J_DECLARABLE_OPS_H #define LIBND4J_DECLARABLE_OPS_H -#include -#include -#include #include -#include -#include "OpDescriptor.h" -#include -#include #include +#include +#include #include -#include +#include #include +#include +#include +#include + +#include + +#include "OpDescriptor.h" //#include #include @@ -42,181 +44,215 @@ using namespace sd::graph; namespace sd { - namespace ops { - - Nd4jStatus ND4J_EXPORT conditionHelper(const char *file, int line, int condition, int argNumber, const char *format, ...); - - - template - Nd4jStatus resultHelper(T status, const char *func, const char *file, int line) { - if (status) { - // TODO: fill out error codes here - fprintf(stderr, "Validation error at %s:%d code=%d(%s) \"%s\" \n", file, line, - static_cast(status), "", func); - - return ND4J_STATUS_BAD_INPUT; - } - - return ND4J_STATUS_OK; - } - - /** - * This class is the basic building block of Graph Operations. Any CustomOp out there is built on top of this "abstract" class. - * - */ - class ND4J_EXPORT DeclarableOp { - private: - std::mutex _registrator; - bool _registered = false; - std::string _name; - protected: - OpDescriptor *_descriptor; - NDArray *_scalar = nullptr; - - virtual void registerTypes(); - - /** - * This method executes this Op, and defined for most of individual ops separately - */ - virtual Nd4jStatus validateAndExecute(Context& block) = 0; - - /** - * This method ensures that target variable has enough space for op execution - * - * TODO: we want workspaces support right here - */ - bool allocateResult(Context& block, std::initializer_list& shape, char order = 'c'); - bool allocateResult(Context& block, Nd4jLong* shape); - - /** - * This method overwrites existen NDArray or NDArrayList in VariableSpace - * - * PLEASE NOTE: This method is dangerous. - * - * @param block - * @param numOutput - * @param array - */ - void overwriteResult(Context& block, int outputIdx, NDArray* array); - void overwriteResult(Context& block, int outputIdx, NDArrayList* list); - - /* - * This method attaches array to specific Variable, identified by node ID and outputNumber (which is output index for multi-output operations) - */ - void storeResult(Context &block, int outputNumber, NDArray& array); - void storeResult(Context &block, int outputNumber, NDArray* array); - sd::NDArray* getZ(Context& block, int inputId = 0); - sd::NDArray* getNullifiedZ(Context& block, int inputId = 0); - - /** - * This method pre-allocates NDArrays for Op output, in case they are not available at op execution time - */ - int prepareOutputs(Context& block); - - virtual samediff::EmptyHandling emptyHandling(); - public: - // for special cases, like BooleanOps - DeclarableOp(); - DeclarableOp(const char *name, int numInputs, bool scalar); - - // regular constructors - DeclarableOp(int numInputs, int numOutputs, const char *opName, bool allowsInplace); - DeclarableOp(int numInputs, int numOutputs, const char *opName, bool allowsInplace, bool divergent); - DeclarableOp(int numInputs, int numOutputs, const char *opName, bool allowsInplace, int tArgs, int iArgs); - - // for LogicalOps - DeclarableOp(const char *name, bool isLogical); - - // default testructor - virtual ~DeclarableOp(); - - // this method returns OpDescriptor, describing this Op instance - OpDescriptor *getOpDescriptor(); - - virtual Nd4jStatus validateDataTypes(Context& block); - - /** - * This method should be available in each implemented Op, and should return Op output shape(s), for a given input shape(s) - */ - virtual ShapeList* calculateOutputShape(ShapeList* inputShape, sd::graph::Context& block) = 0; - - /** - * Returns opName - * - * @return - */ - std::string *getOpName(); - - /** - * Returns opHash - */ - Nd4jLong getOpHash(); - - /** - * This method sets arguments for op - */ -// void setArguments(); - - /** - * This method returns pointer to results - */ -// void getResults(); - - /** - * This method executes given Op - * - * @param block - * @return 0 if OK, error code otherwise - */ - virtual Nd4jStatus execute(Context* block); - - Nd4jStatus execute(const std::vector &inputs, const std::vector &outputs); - - template ::value>> - Nd4jStatus execute(const std::vector &inputs, const std::vector &outputs, std::initializer_list tArgs); - - Nd4jStatus execute(const std::vector &inputs, const std::vector &outputs, const std::vector &tArgs, const std::vector &iArgs, const std::vector &bArgs = std::vector(), const std::vector &dArgs = std::vector(), bool isInplace = false); - - sd::ResultSet evaluate(const std::vector &inputs); - - template ::value>> - sd::ResultSet evaluate(const std::vector &inputs, std::initializer_list args); - - sd::ResultSet evaluate(const std::vector &inputs, const std::vector &tArgs, const std::vector &iArgs, const std::vector &bArgs = std::vector(), const std::vector &dArgs = std::vector(), bool isInplace = false); - - Nd4jStatus execute(sd::graph::RandomGenerator& rng, const std::vector& inputs, const std::vector& outputs, const std::vector& tArgs, const std::vector& iArgs, const std::vector& bArgs, const std::vector &dArgs = std::vector(), bool isInplace = false, sd::DataType type = sd::DataType::FLOAT32); - - sd::ResultSet execute(const sd::OpArgsHolder& holder, bool isInplace = false); +namespace ops { +Nd4jStatus SD_EXPORT conditionHelper(const char* file, int line, int condition, + int argNumber, const char* format, ...); - // There methods provide various validation options - Nd4jStatus validateNonEmptyInput(Context& block); +template +Nd4jStatus resultHelper(T status, const char* func, const char* file, + int line) { + if (status) { + // TODO: fill out error codes here + fprintf(stderr, "Validation error at %s:%d code=%d(%s) \"%s\" \n", file, + line, static_cast(status), "", func); - // this method checks if all input arrays have equal lengths - Nd4jStatus validateInputLengthMatch(Context& block); + return ND4J_STATUS_BAD_INPUT; + } - // this method checks if all input arrays have the same shapes (orders/strides are NOT checked) - Nd4jStatus validateInputDimensionsMatch(Context& block); - - // this method check if all input arrays have the same orders - Nd4jStatus validateOrdersMatch(Context& block); - - // this method checks if all input arrays are 2D - Nd4jStatus validateInput2D(Context& block); - - // this method checks if all input arrays are 3D - Nd4jStatus validateInput3D(Context& block); - - // this method checks if all input arrays are 4D - Nd4jStatus validateInput4D(Context& block); - - // this method checks if all input arrays are ND - Nd4jStatus validateInputDimensions(Context& block, int rank); - - // this method checks if number of available arguments matches op expectations - Nd4jStatus validateArguments(Context& block); - }; - } + return ND4J_STATUS_OK; } -#endif //LIBND4J_DECLARABLE_OPS_H +/** + * This class is the basic building block of Graph Operations. Any CustomOp out + * there is built on top of this "abstract" class. + * + */ +class SD_EXPORT DeclarableOp { + private: + std::mutex _registrator; + bool _registered = false; + std::string _name; + + protected: + OpDescriptor* _descriptor; + NDArray* _scalar = nullptr; + + virtual void registerTypes(); + + /** + * This method executes this Op, and defined for most of individual ops + * separately + */ + virtual Nd4jStatus validateAndExecute(Context& block) = 0; + + /** + * This method ensures that target variable has enough space for op execution + * + * TODO: we want workspaces support right here + */ + bool allocateResult(Context& block, std::initializer_list& shape, + char order = 'c'); + bool allocateResult(Context& block, Nd4jLong* shape); + + /** + * This method overwrites existen NDArray or NDArrayList in VariableSpace + * + * PLEASE NOTE: This method is dangerous. + * + * @param block + * @param numOutput + * @param array + */ + void overwriteResult(Context& block, int outputIdx, NDArray* array); + void overwriteResult(Context& block, int outputIdx, NDArrayList* list); + + /* + * This method attaches array to specific Variable, identified by node ID and + * outputNumber (which is output index for multi-output operations) + */ + void storeResult(Context& block, int outputNumber, NDArray& array); + void storeResult(Context& block, int outputNumber, NDArray* array); + sd::NDArray* getZ(Context& block, int inputId = 0); + sd::NDArray* getNullifiedZ(Context& block, int inputId = 0); + + virtual samediff::EmptyHandling emptyHandling(); + + public: + // for special cases, like BooleanOps + DeclarableOp(); + DeclarableOp(const char* name, int numInputs, bool scalar); + + // regular constructors + DeclarableOp(int numInputs, int numOutputs, const char* opName, + bool allowsInplace); + DeclarableOp(int numInputs, int numOutputs, const char* opName, + bool allowsInplace, bool divergent); + DeclarableOp(int numInputs, int numOutputs, const char* opName, + bool allowsInplace, int tArgs, int iArgs); + + // for LogicalOps + DeclarableOp(const char* name, bool isLogical); + + // default testructor + virtual ~DeclarableOp(); + + // this method returns OpDescriptor, describing this Op instance + OpDescriptor* getOpDescriptor(); + + virtual Nd4jStatus validateDataTypes(Context& block); + + /** + * This method should be available in each implemented Op, and should return + * Op output shape(s), for a given input shape(s) + */ + virtual ShapeList* calculateOutputShape(ShapeList* inputShape, + sd::graph::Context& block) = 0; + + /** + * Returns opName + * + * @return + */ + const std::string& getOpName() const; + + /** + * Returns opHash + */ + Nd4jLong getOpHash() const; + + /** + * This method sets arguments for op + */ + // void setArguments(); + + /** + * This method returns pointer to results + */ + // void getResults(); + + /** + * This method executes given Op + * + * @param block + * @return 0 if OK, error code otherwise + */ + virtual Nd4jStatus execute(Context* block); + + Nd4jStatus execute(const std::vector& inputs, + const std::vector& outputs); + + template ::value>> + Nd4jStatus execute(const std::vector& inputs, + const std::vector& outputs, + std::initializer_list tArgs); + + Nd4jStatus execute( + const std::vector& inputs, const std::vector& outputs, + const std::vector& tArgs, const std::vector& iArgs, + const std::vector& bArgs = std::vector(), + const std::vector& dArgs = std::vector(), + bool isInplace = false); + + sd::ResultSet evaluate(const std::vector& inputs); + + template ::value>> + sd::ResultSet evaluate(const std::vector& inputs, + std::initializer_list args); + + sd::ResultSet evaluate( + const std::vector& inputs, const std::vector& tArgs, + const std::vector& iArgs, + const std::vector& bArgs = std::vector(), + const std::vector& dArgs = std::vector(), + bool isInplace = false); + + Nd4jStatus execute( + sd::graph::RandomGenerator& rng, const std::vector& inputs, + const std::vector& outputs, const std::vector& tArgs, + const std::vector& iArgs, const std::vector& bArgs, + const std::vector& dArgs = std::vector(), + bool isInplace = false, sd::DataType type = sd::DataType::FLOAT32); + + sd::ResultSet execute(const sd::OpArgsHolder& holder, bool isInplace = false); + + // There methods provide various validation options + Nd4jStatus validateNonEmptyInput(Context& block); + + // this method checks if all input arrays have equal lengths + Nd4jStatus validateInputLengthMatch(Context& block); + + // this method checks if all input arrays have the same shapes (orders/strides + // are NOT checked) + Nd4jStatus validateInputDimensionsMatch(Context& block); + + // this method check if all input arrays have the same orders + Nd4jStatus validateOrdersMatch(Context& block); + + // this method checks if all input arrays are 2D + Nd4jStatus validateInput2D(Context& block); + + // this method checks if all input arrays are 3D + Nd4jStatus validateInput3D(Context& block); + + // this method checks if all input arrays are 4D + Nd4jStatus validateInput4D(Context& block); + + // this method checks if all input arrays are ND + Nd4jStatus validateInputDimensions(Context& block, int rank); + + // this method checks if number of available arguments matches op expectations + Nd4jStatus validateArguments(Context& block); + + /** + * This method pre-allocates NDArrays for Op output, in case they are not + * available at op execution time + */ + int prepareOutputs(Context& block); +}; +} // namespace ops +} // namespace sd + +#endif // LIBND4J_DECLARABLE_OPS_H diff --git a/libnd4j/include/ops/declarable/DeclarableReductionOp.h b/libnd4j/include/ops/declarable/DeclarableReductionOp.h index 11f4ec410b94..bd93d94848c3 100644 --- a/libnd4j/include/ops/declarable/DeclarableReductionOp.h +++ b/libnd4j/include/ops/declarable/DeclarableReductionOp.h @@ -24,19 +24,22 @@ #include namespace sd { - namespace ops { - class ND4J_EXPORT DeclarableReductionOp : public sd::ops::DeclarableOp { - protected: - /** - * This method executes this Op - */ - Nd4jStatus validateAndExecute(Context& block) override = 0; - public: - DeclarableReductionOp(int numInputs, int numOutputs, const char *opName, bool allowsInplace, int tArgs, int iArgs); +namespace ops { +class SD_EXPORT DeclarableReductionOp : public sd::ops::DeclarableOp { + protected: + /** + * This method executes this Op + */ + Nd4jStatus validateAndExecute(Context& block) override = 0; - ShapeList* calculateOutputShape(ShapeList* inputShape, sd::graph::Context& block) override; - }; - } -} + public: + DeclarableReductionOp(int numInputs, int numOutputs, const char* opName, + bool allowsInplace, int tArgs, int iArgs); -#endif //LIBND4J_DECLARABLE_REDUCTION_OP_H + ShapeList* calculateOutputShape(ShapeList* inputShape, + sd::graph::Context& block) override; +}; +} // namespace ops +} // namespace sd + +#endif // LIBND4J_DECLARABLE_REDUCTION_OP_H diff --git a/libnd4j/include/ops/declarable/EmptyHandling.h b/libnd4j/include/ops/declarable/EmptyHandling.h index c25fea498201..810faa279e3c 100644 --- a/libnd4j/include/ops/declarable/EmptyHandling.h +++ b/libnd4j/include/ops/declarable/EmptyHandling.h @@ -22,11 +22,7 @@ #define SAMEDIFF_EMPTYHANDLING_H namespace samediff { - enum EmptyHandling { - EMPTY_SKIP = 1, - EMPTY_EXCEPTION = 2, - EMPTY_EXECUTE = 3 - }; +enum EmptyHandling { EMPTY_SKIP = 1, EMPTY_EXCEPTION = 2, EMPTY_EXECUTE = 3 }; } -#endif //SAMEDIFF_EMPTYHANDLING_H +#endif // SAMEDIFF_EMPTYHANDLING_H diff --git a/libnd4j/include/ops/declarable/LegacyBroadcastBoolOp.h b/libnd4j/include/ops/declarable/LegacyBroadcastBoolOp.h index 67787ca4b1ad..54c737c1295f 100644 --- a/libnd4j/include/ops/declarable/LegacyBroadcastBoolOp.h +++ b/libnd4j/include/ops/declarable/LegacyBroadcastBoolOp.h @@ -24,22 +24,23 @@ #include namespace sd { - namespace ops { - /** - * This class provides wrapper for broadcast operations. - */ - class ND4J_EXPORT LegacyBroadcastBoolOp : public LegacyOp { - protected: - Nd4jStatus validateAndExecute(Context& block) override ; - public: - LegacyBroadcastBoolOp(); - LegacyBroadcastBoolOp(int opNum); +namespace ops { +/** + * This class provides wrapper for broadcast operations. + */ +class SD_EXPORT LegacyBroadcastBoolOp : public LegacyOp { + protected: + Nd4jStatus validateAndExecute(Context& block) override; - ShapeList* calculateOutputShape(ShapeList* inputShape, sd::graph::Context& block) override; - LegacyOp* clone() override; - }; - } -} + public: + LegacyBroadcastBoolOp(); + LegacyBroadcastBoolOp(int opNum); + ShapeList* calculateOutputShape(ShapeList* inputShape, + sd::graph::Context& block) override; + LegacyOp* clone() override; +}; +} // namespace ops +} // namespace sd -#endif //LIBND4J_LEGACYBROADCASTOP_H +#endif // LIBND4J_LEGACYBROADCASTOP_H diff --git a/libnd4j/include/ops/declarable/LegacyBroadcastOp.h b/libnd4j/include/ops/declarable/LegacyBroadcastOp.h index 755277397d0f..718ff2dffe8b 100644 --- a/libnd4j/include/ops/declarable/LegacyBroadcastOp.h +++ b/libnd4j/include/ops/declarable/LegacyBroadcastOp.h @@ -24,22 +24,23 @@ #include namespace sd { - namespace ops { - /** - * This class provides wrapper for broadcast operations. - */ - class ND4J_EXPORT LegacyBroadcastOp : public LegacyOp { - protected: - Nd4jStatus validateAndExecute(Context& block) override; - public: - LegacyBroadcastOp(); - LegacyBroadcastOp(int opNum); +namespace ops { +/** + * This class provides wrapper for broadcast operations. + */ +class SD_EXPORT LegacyBroadcastOp : public LegacyOp { + protected: + Nd4jStatus validateAndExecute(Context& block) override; - ShapeList* calculateOutputShape(ShapeList* inputShape, sd::graph::Context& block) override; - LegacyOp* clone() override; - }; - } -} + public: + LegacyBroadcastOp(); + LegacyBroadcastOp(int opNum); + ShapeList* calculateOutputShape(ShapeList* inputShape, + sd::graph::Context& block) override; + LegacyOp* clone() override; +}; +} // namespace ops +} // namespace sd -#endif //LIBND4J_LEGACYBROADCASTOP_H +#endif // LIBND4J_LEGACYBROADCASTOP_H diff --git a/libnd4j/include/ops/declarable/LegacyIndexReduceOp.h b/libnd4j/include/ops/declarable/LegacyIndexReduceOp.h index fae0c5e8fd3a..9076a2acd6f4 100644 --- a/libnd4j/include/ops/declarable/LegacyIndexReduceOp.h +++ b/libnd4j/include/ops/declarable/LegacyIndexReduceOp.h @@ -24,23 +24,26 @@ #include namespace sd { - namespace ops { - /** - * This class provides wrapper for IndexAccumulation operations. i.e. IndexMax or IndexAbsoluteMin etc - * - * TODO: eventually we want this op class to return long long instead of T - */ - class ND4J_EXPORT LegacyIndexReduceOp : public LegacyOp { - protected: - Nd4jStatus validateAndExecute(Context& block) override; - public: - LegacyIndexReduceOp(); - LegacyIndexReduceOp(int opNum); +namespace ops { +/** + * This class provides wrapper for IndexAccumulation operations. i.e. IndexMax + * or IndexAbsoluteMin etc + * + * TODO: eventually we want this op class to return long long instead of T + */ +class SD_EXPORT LegacyIndexReduceOp : public LegacyOp { + protected: + Nd4jStatus validateAndExecute(Context& block) override; + + public: + LegacyIndexReduceOp(); + LegacyIndexReduceOp(int opNum); - ShapeList* calculateOutputShape(ShapeList* inputShape, sd::graph::Context& block) override; - LegacyOp* clone() override; - }; - } -} + ShapeList* calculateOutputShape(ShapeList* inputShape, + sd::graph::Context& block) override; + LegacyOp* clone() override; +}; +} // namespace ops +} // namespace sd -#endif //LIBND4J_LEGACYINDEXREDUCEOP_H +#endif // LIBND4J_LEGACYINDEXREDUCEOP_H diff --git a/libnd4j/include/ops/declarable/LegacyOp.h b/libnd4j/include/ops/declarable/LegacyOp.h index 0dfd91a429fe..e8cf41e3d847 100644 --- a/libnd4j/include/ops/declarable/LegacyOp.h +++ b/libnd4j/include/ops/declarable/LegacyOp.h @@ -21,38 +21,50 @@ #ifndef LIBND4J_LEGACYOP_H #define LIBND4J_LEGACYOP_H -#include #include +#include namespace sd { - namespace ops { - - /** - * This class is root abstraction for legacy XYZ ops wrappers. - * All wrappers for specific op groups (i.e. LegacyTransformOp for Transform ops) are inheriting this class. - * - * - */ - class ND4J_EXPORT LegacyOp : public DeclarableOp { - protected: - // this field is mainly for debugging - // it defines, which legacy op should be invoked on a given data - int _opNum = -1; - int _numInputs = 0; - - // All Op classes provide own specific implementation for this method - Nd4jStatus validateAndExecute(Context& block) override = 0; - public: - LegacyOp(int numInputs); - LegacyOp(int numInputs, int opNum); - ~LegacyOp() = default; - - // All Op classes provide own specific implementation for this method - ShapeList* calculateOutputShape(ShapeList* inputShape, sd::graph::Context& block) override = 0; - virtual LegacyOp* clone() = 0; - }; - } -} - - -#endif //LIBND4J_LEGACYOP_H +namespace ops { + +/** + * This class is root abstraction for legacy XYZ ops wrappers. + * All wrappers for specific op groups (i.e. LegacyTransformOp for Transform + * ops) are inheriting this class. + * + * + */ +class SD_EXPORT LegacyOp : public DeclarableOp { + protected: + // this field is mainly for debugging + // it defines, which legacy op should be invoked on a given data + int _opNum = -1; + int _numInputs = 0; + + // All Op classes provide own specific implementation for this method + Nd4jStatus validateAndExecute(Context& block) override = 0; + + public: + LegacyOp(int numInputs); + LegacyOp(int numInputs, int opNum); + ~LegacyOp() = default; + + LegacyOp(const LegacyOp& other) noexcept; + + LegacyOp& operator=(const LegacyOp& other) noexcept; + + // move constructor + LegacyOp(LegacyOp&& other) noexcept; + + // move assignment operator + LegacyOp& operator=(LegacyOp&& other) noexcept; + + // All Op classes provide own specific implementation for this method + ShapeList* calculateOutputShape(ShapeList* inputShape, + sd::graph::Context& block) override = 0; + virtual LegacyOp* clone() = 0; +}; +} // namespace ops +} // namespace sd + +#endif // LIBND4J_LEGACYOP_H diff --git a/libnd4j/include/ops/declarable/LegacyPairwiseTransformBoolOp.h b/libnd4j/include/ops/declarable/LegacyPairwiseTransformBoolOp.h index 16a482811244..74194c82a79b 100644 --- a/libnd4j/include/ops/declarable/LegacyPairwiseTransformBoolOp.h +++ b/libnd4j/include/ops/declarable/LegacyPairwiseTransformBoolOp.h @@ -24,22 +24,23 @@ #include namespace sd { - namespace ops { - /** - * This class provides wrapper for Pairwise transform operations - */ - class ND4J_EXPORT LegacyPairwiseTransformBoolOp: public LegacyOp { - protected: - Nd4jStatus validateAndExecute(Context& block) override; - public: - LegacyPairwiseTransformBoolOp(); - LegacyPairwiseTransformBoolOp(int opNum); +namespace ops { +/** + * This class provides wrapper for Pairwise transform operations + */ +class SD_EXPORT LegacyPairwiseTransformBoolOp : public LegacyOp { + protected: + Nd4jStatus validateAndExecute(Context& block) override; - ShapeList* calculateOutputShape(ShapeList* inputShape, sd::graph::Context& block) override; - LegacyOp* clone() override; - }; - } -} + public: + LegacyPairwiseTransformBoolOp(); + LegacyPairwiseTransformBoolOp(int opNum); + ShapeList* calculateOutputShape(ShapeList* inputShape, + sd::graph::Context& block) override; + LegacyOp* clone() override; +}; +} // namespace ops +} // namespace sd -#endif //LIBND4J_LEGACYPAIRWISETRANSFORMOP_H +#endif // LIBND4J_LEGACYPAIRWISETRANSFORMOP_H diff --git a/libnd4j/include/ops/declarable/LegacyPairwiseTransformOp.h b/libnd4j/include/ops/declarable/LegacyPairwiseTransformOp.h index 81bbdc71556e..f82dca1f1454 100644 --- a/libnd4j/include/ops/declarable/LegacyPairwiseTransformOp.h +++ b/libnd4j/include/ops/declarable/LegacyPairwiseTransformOp.h @@ -24,22 +24,23 @@ #include namespace sd { - namespace ops { - /** - * This class provides wrapper for Pairwise transform operations - */ - class ND4J_EXPORT LegacyPairwiseTransformOp: public LegacyOp { - protected: - Nd4jStatus validateAndExecute(Context& block) override; - public: - LegacyPairwiseTransformOp(); - LegacyPairwiseTransformOp(int opNum); +namespace ops { +/** + * This class provides wrapper for Pairwise transform operations + */ +class SD_EXPORT LegacyPairwiseTransformOp : public LegacyOp { + protected: + Nd4jStatus validateAndExecute(Context& block) override; - ShapeList* calculateOutputShape(ShapeList* inputShape, sd::graph::Context& block) override; - LegacyOp* clone() override; - }; - } -} + public: + LegacyPairwiseTransformOp(); + LegacyPairwiseTransformOp(int opNum); + ShapeList* calculateOutputShape(ShapeList* inputShape, + sd::graph::Context& block) override; + LegacyOp* clone() override; +}; +} // namespace ops +} // namespace sd -#endif //LIBND4J_LEGACYPAIRWISETRANSFORMOP_H +#endif // LIBND4J_LEGACYPAIRWISETRANSFORMOP_H diff --git a/libnd4j/include/ops/declarable/LegacyRandomOp.h b/libnd4j/include/ops/declarable/LegacyRandomOp.h index c0bab879d661..f7b6d8d51c8b 100644 --- a/libnd4j/include/ops/declarable/LegacyRandomOp.h +++ b/libnd4j/include/ops/declarable/LegacyRandomOp.h @@ -21,37 +21,42 @@ #ifndef LIBND4J_LEGACYRANDOMOP_H #define LIBND4J_LEGACYRANDOMOP_H - #include #include namespace sd { - namespace ops { - /** - * This class provides wrapper for Random operations (i.e. linspace or Uniform) - */ - class ND4J_EXPORT LegacyRandomOp : public LegacyOp { - protected: - Nd4jStatus validateAndExecute(Context& block) override; - public: - LegacyRandomOp(); - LegacyRandomOp(int opNum); - ~LegacyRandomOp() = default; - - template - Nd4jStatus validateAndExecute_(Context &block); - - sd::ResultSet execute(sd::graph::RandomGenerator& rng, std::initializer_list inputs, std::initializer_list tArgs, std::initializer_list iArgs, bool isInplace = false); - sd::ResultSet execute(sd::graph::RandomGenerator& rng, std::vector& inputs, std::vector& tArgs, std::vector& iArgs, bool isInplace = false); - - Nd4jStatus execute(Context* block) override; - - Nd4jStatus validateDataTypes(Context& block) override; - ShapeList* calculateOutputShape(ShapeList* inputShape, sd::graph::Context& block) override; - LegacyOp* clone() override; - }; - } -} - - -#endif //LIBND4J_LEGACYTRANSFORMOP_H +namespace ops { +/** + * This class provides wrapper for Random operations (i.e. linspace or + * Uniform) + */ +class SD_EXPORT LegacyRandomOp : public LegacyOp { + protected: + Nd4jStatus validateAndExecute(Context& block) override; + + public: + LegacyRandomOp(); + LegacyRandomOp(int opNum); + ~LegacyRandomOp() = default; + + template + Nd4jStatus validateAndExecute_(Context& block); + + sd::ResultSet execute(sd::graph::RandomGenerator& rng, + const std::vector& inputs, + const std::vector& tArgs = {}, + const std::vector& iArgs = {}, + const std::vector& dArgs = {}, + bool isInplace = false); + + Nd4jStatus execute(Context* block) override; + + Nd4jStatus validateDataTypes(Context& block) override; + ShapeList* calculateOutputShape(ShapeList* inputShape, + sd::graph::Context& block) override; + LegacyOp* clone() override; +}; +} // namespace ops +} // namespace sd + +#endif // LIBND4J_LEGACYTRANSFORMOP_H diff --git a/libnd4j/include/ops/declarable/LegacyReduce3Op.h b/libnd4j/include/ops/declarable/LegacyReduce3Op.h index b0a06bd94e81..565e31f3b7ab 100644 --- a/libnd4j/include/ops/declarable/LegacyReduce3Op.h +++ b/libnd4j/include/ops/declarable/LegacyReduce3Op.h @@ -24,22 +24,24 @@ #include namespace sd { - namespace ops { - /** - * This class provides wrapper for Reduce3 operations (i.e. dot, cosineDistance etc) - */ - class ND4J_EXPORT LegacyReduce3Op : public LegacyOp { - protected: - Nd4jStatus validateAndExecute(Context& block) override; - public: - LegacyReduce3Op(); - LegacyReduce3Op(int opNum); +namespace ops { +/** + * This class provides wrapper for Reduce3 operations (i.e. dot, + * cosineDistance etc) + */ +class SD_EXPORT LegacyReduce3Op : public LegacyOp { + protected: + Nd4jStatus validateAndExecute(Context& block) override; - ShapeList* calculateOutputShape(ShapeList* inputShape, sd::graph::Context& block) override; - LegacyOp* clone() override; - }; - } -} + public: + LegacyReduce3Op(); + LegacyReduce3Op(int opNum); + ShapeList* calculateOutputShape(ShapeList* inputShape, + sd::graph::Context& block) override; + LegacyOp* clone() override; +}; +} // namespace ops +} // namespace sd -#endif //LIBND4J_LEGACYREDUCE3OP_H +#endif // LIBND4J_LEGACYREDUCE3OP_H diff --git a/libnd4j/include/ops/declarable/LegacyReduceBoolOp.h b/libnd4j/include/ops/declarable/LegacyReduceBoolOp.h index 11cd52146874..64b7bf626559 100644 --- a/libnd4j/include/ops/declarable/LegacyReduceBoolOp.h +++ b/libnd4j/include/ops/declarable/LegacyReduceBoolOp.h @@ -24,19 +24,20 @@ #include namespace sd { - namespace ops { - class ND4J_EXPORT LegacyReduceBoolOp : public LegacyOp { - protected: - Nd4jStatus validateAndExecute(Context& block) override; - public: - LegacyReduceBoolOp(); - LegacyReduceBoolOp(int opNum); +namespace ops { +class SD_EXPORT LegacyReduceBoolOp : public LegacyOp { + protected: + Nd4jStatus validateAndExecute(Context& block) override; - ShapeList* calculateOutputShape(ShapeList* inputShape, sd::graph::Context& block) override; - LegacyOp* clone() override; - }; - } -} + public: + LegacyReduceBoolOp(); + LegacyReduceBoolOp(int opNum); + ShapeList* calculateOutputShape(ShapeList* inputShape, + sd::graph::Context& block) override; + LegacyOp* clone() override; +}; +} // namespace ops +} // namespace sd -#endif //LIBND4J_LEGACYREDUCEOP_H +#endif // LIBND4J_LEGACYREDUCEOP_H diff --git a/libnd4j/include/ops/declarable/LegacyReduceFloatOp.h b/libnd4j/include/ops/declarable/LegacyReduceFloatOp.h index ed36a04fe20c..cafc3b685625 100644 --- a/libnd4j/include/ops/declarable/LegacyReduceFloatOp.h +++ b/libnd4j/include/ops/declarable/LegacyReduceFloatOp.h @@ -24,19 +24,20 @@ #include namespace sd { - namespace ops { - class ND4J_EXPORT LegacyReduceFloatOp : public LegacyOp { - protected: - Nd4jStatus validateAndExecute(Context& block) override; - public: - LegacyReduceFloatOp(); - LegacyReduceFloatOp(int opNum); +namespace ops { +class SD_EXPORT LegacyReduceFloatOp : public LegacyOp { + protected: + Nd4jStatus validateAndExecute(Context& block) override; - ShapeList* calculateOutputShape(ShapeList* inputShape, sd::graph::Context& block) override; - LegacyOp* clone() override; - }; - } -} + public: + LegacyReduceFloatOp(); + LegacyReduceFloatOp(int opNum); + ShapeList* calculateOutputShape(ShapeList* inputShape, + sd::graph::Context& block) override; + LegacyOp* clone() override; +}; +} // namespace ops +} // namespace sd -#endif //LIBND4J_LEGACYREDUCEOP_H +#endif // LIBND4J_LEGACYREDUCEOP_H diff --git a/libnd4j/include/ops/declarable/LegacyReduceLongOp.h b/libnd4j/include/ops/declarable/LegacyReduceLongOp.h index 4f23a9717f53..d91f0ca99114 100644 --- a/libnd4j/include/ops/declarable/LegacyReduceLongOp.h +++ b/libnd4j/include/ops/declarable/LegacyReduceLongOp.h @@ -24,19 +24,20 @@ #include namespace sd { - namespace ops { - class ND4J_EXPORT LegacyReduceLongOp : public LegacyOp { - protected: - Nd4jStatus validateAndExecute(Context& block) override; - public: - LegacyReduceLongOp(); - LegacyReduceLongOp(int opNum); +namespace ops { +class SD_EXPORT LegacyReduceLongOp : public LegacyOp { + protected: + Nd4jStatus validateAndExecute(Context& block) override; - ShapeList* calculateOutputShape(ShapeList* inputShape, sd::graph::Context& block) override; - LegacyOp* clone() override; - }; - } -} + public: + LegacyReduceLongOp(); + LegacyReduceLongOp(int opNum); + ShapeList* calculateOutputShape(ShapeList* inputShape, + sd::graph::Context& block) override; + LegacyOp* clone() override; +}; +} // namespace ops +} // namespace sd -#endif //LIBND4J_LEGACYREDUCEOP_H +#endif // LIBND4J_LEGACYREDUCEOP_H diff --git a/libnd4j/include/ops/declarable/LegacyReduceOp.h b/libnd4j/include/ops/declarable/LegacyReduceOp.h index 3e289fe258ed..39d27fa4fcee 100644 --- a/libnd4j/include/ops/declarable/LegacyReduceOp.h +++ b/libnd4j/include/ops/declarable/LegacyReduceOp.h @@ -25,23 +25,22 @@ /* namespace sd { namespace ops { - class ND4J_EXPORT LegacyReduceOp : public LegacyOp { + class SD_EXPORT LegacyReduceOp : public LegacyOp { protected: Nd4jStatus validateAndExecute(Context& block); public: LegacyReduceOp(); LegacyReduceOp(int opNum); - ShapeList* calculateOutputShape(ShapeList* inputShape, sd::graph::Context& block); - virtual LegacyOp* clone(); + ShapeList* calculateOutputShape(ShapeList* inputShape, +sd::graph::Context& block); virtual LegacyOp* clone(); }; } } */ -#include -#include #include +#include #include +#include - -#endif //LIBND4J_LEGACYREDUCEOP_H +#endif // LIBND4J_LEGACYREDUCEOP_H diff --git a/libnd4j/include/ops/declarable/LegacyReduceSameOp.h b/libnd4j/include/ops/declarable/LegacyReduceSameOp.h index 86cc06a0ecbe..eb09a3a39ab4 100644 --- a/libnd4j/include/ops/declarable/LegacyReduceSameOp.h +++ b/libnd4j/include/ops/declarable/LegacyReduceSameOp.h @@ -24,19 +24,20 @@ #include namespace sd { - namespace ops { - class ND4J_EXPORT LegacyReduceSameOp: public LegacyOp { - protected: - Nd4jStatus validateAndExecute(Context& block) override; - public: - LegacyReduceSameOp(); - LegacyReduceSameOp(int opNum); +namespace ops { +class SD_EXPORT LegacyReduceSameOp : public LegacyOp { + protected: + Nd4jStatus validateAndExecute(Context& block) override; - ShapeList* calculateOutputShape(ShapeList* inputShape, sd::graph::Context& block) override; - LegacyOp* clone() override; - }; - } -} + public: + LegacyReduceSameOp(); + LegacyReduceSameOp(int opNum); + ShapeList* calculateOutputShape(ShapeList* inputShape, + sd::graph::Context& block) override; + LegacyOp* clone() override; +}; +} // namespace ops +} // namespace sd -#endif //LIBND4J_LEGACYREDUCEOP_H +#endif // LIBND4J_LEGACYREDUCEOP_H diff --git a/libnd4j/include/ops/declarable/LegacyScalarBoolOp.h b/libnd4j/include/ops/declarable/LegacyScalarBoolOp.h index 0d52eee9d35b..8fe33aeac7a9 100644 --- a/libnd4j/include/ops/declarable/LegacyScalarBoolOp.h +++ b/libnd4j/include/ops/declarable/LegacyScalarBoolOp.h @@ -24,24 +24,25 @@ #include namespace sd { - namespace ops { - /** - * This class provides wrapper for scalar transform operations, i.e. a + b = c, where either a or b is scalar primitive and other operand is NDArray - */ - class ND4J_EXPORT LegacyScalarBoolOp : public LegacyOp { - protected: - Nd4jStatus validateAndExecute(Context& block) override; - - public: - LegacyScalarBoolOp(); - LegacyScalarBoolOp(int opNum); - LegacyScalarBoolOp(int opNum, NDArray &scalar); - - ShapeList* calculateOutputShape(ShapeList* inputShape, sd::graph::Context& block) override; - LegacyOp* clone() override; - }; - } -} - - -#endif //LIBND4J_LEGACYSCALAROP_H +namespace ops { +/** + * This class provides wrapper for scalar transform operations, i.e. a + b = + * c, where either a or b is scalar primitive and other operand is NDArray + */ +class SD_EXPORT LegacyScalarBoolOp : public LegacyOp { + protected: + Nd4jStatus validateAndExecute(Context& block) override; + + public: + LegacyScalarBoolOp(); + LegacyScalarBoolOp(int opNum); + LegacyScalarBoolOp(int opNum, NDArray& scalar); + + ShapeList* calculateOutputShape(ShapeList* inputShape, + sd::graph::Context& block) override; + LegacyOp* clone() override; +}; +} // namespace ops +} // namespace sd + +#endif // LIBND4J_LEGACYSCALAROP_H diff --git a/libnd4j/include/ops/declarable/LegacyScalarOp.h b/libnd4j/include/ops/declarable/LegacyScalarOp.h index 9f2a1a23a35b..c1b5d2d01f5a 100644 --- a/libnd4j/include/ops/declarable/LegacyScalarOp.h +++ b/libnd4j/include/ops/declarable/LegacyScalarOp.h @@ -24,24 +24,25 @@ #include namespace sd { - namespace ops { - /** - * This class provides wrapper for scalar transform operations, i.e. a + b = c, where either a or b is scalar primitive and other operand is NDArray - */ - class ND4J_EXPORT LegacyScalarOp : public LegacyOp { - protected: - Nd4jStatus validateAndExecute(Context& block) override; - - public: - LegacyScalarOp(); - LegacyScalarOp(int opNum); - LegacyScalarOp(int opNum, NDArray &scalar); - - ShapeList* calculateOutputShape(ShapeList* inputShape, sd::graph::Context& block) override; - LegacyOp* clone() override; - }; - } -} - - -#endif //LIBND4J_LEGACYSCALAROP_H +namespace ops { +/** + * This class provides wrapper for scalar transform operations, i.e. a + b = + * c, where either a or b is scalar primitive and other operand is NDArray + */ +class SD_EXPORT LegacyScalarOp : public LegacyOp { + protected: + Nd4jStatus validateAndExecute(Context& block) override; + + public: + LegacyScalarOp(); + LegacyScalarOp(int opNum); + LegacyScalarOp(int opNum, NDArray& scalar); + + ShapeList* calculateOutputShape(ShapeList* inputShape, + sd::graph::Context& block) override; + LegacyOp* clone() override; +}; +} // namespace ops +} // namespace sd + +#endif // LIBND4J_LEGACYSCALAROP_H diff --git a/libnd4j/include/ops/declarable/LegacyStatsOp.h b/libnd4j/include/ops/declarable/LegacyStatsOp.h index 74520b9ddf03..53b71414c82f 100644 --- a/libnd4j/include/ops/declarable/LegacyStatsOp.h +++ b/libnd4j/include/ops/declarable/LegacyStatsOp.h @@ -24,22 +24,24 @@ #include namespace sd { - namespace ops { - /** - * This class provides wrapper for SummaryStats operations: Variance and Standard Deviation - */ - class ND4J_EXPORT LegacyStatsOp : public LegacyOp { - protected: - Nd4jStatus validateAndExecute(Context &block) override; - public: - LegacyStatsOp(); - LegacyStatsOp(int opNum); +namespace ops { +/** + * This class provides wrapper for SummaryStats operations: Variance and + * Standard Deviation + */ +class SD_EXPORT LegacyStatsOp : public LegacyOp { + protected: + Nd4jStatus validateAndExecute(Context& block) override; - ShapeList* calculateOutputShape(ShapeList* inputShape, sd::graph::Context &block) override; - LegacyOp* clone() override; - }; - } -} + public: + LegacyStatsOp(); + LegacyStatsOp(int opNum); + ShapeList* calculateOutputShape(ShapeList* inputShape, + sd::graph::Context& block) override; + LegacyOp* clone() override; +}; +} // namespace ops +} // namespace sd -#endif //LIBND4J_LEGACYSTATSOP_H +#endif // LIBND4J_LEGACYSTATSOP_H diff --git a/libnd4j/include/ops/declarable/LegacyTransformAnyOp.h b/libnd4j/include/ops/declarable/LegacyTransformAnyOp.h index f98ccd4c85e7..b5ac326acf2f 100644 --- a/libnd4j/include/ops/declarable/LegacyTransformAnyOp.h +++ b/libnd4j/include/ops/declarable/LegacyTransformAnyOp.h @@ -21,26 +21,26 @@ #ifndef LIBND4J__LEGACY_TRANSFORM_ANY_OP__H #define LIBND4J__LEGACY_TRANSFORM_ANY_OP__H - #include namespace sd { - namespace ops { - /** - * This class provides wrapper for Transform operations (i.e. Pow or OneMinus) - */ - class ND4J_EXPORT LegacyTransformAnyOp : public LegacyOp { - protected: - Nd4jStatus validateAndExecute(Context &block) override; - public: - LegacyTransformAnyOp(); - LegacyTransformAnyOp(int opNum); - - ShapeList* calculateOutputShape(ShapeList* inputShape, sd::graph::Context &block) override; - LegacyOp* clone() override; - }; - } -} - - -#endif //LIBND4J__LEGACY_TRANSFORM_FLOAT_OP__H +namespace ops { +/** + * This class provides wrapper for Transform operations (i.e. Pow or OneMinus) + */ +class SD_EXPORT LegacyTransformAnyOp : public LegacyOp { + protected: + Nd4jStatus validateAndExecute(Context& block) override; + + public: + LegacyTransformAnyOp(); + LegacyTransformAnyOp(int opNum); + + ShapeList* calculateOutputShape(ShapeList* inputShape, + sd::graph::Context& block) override; + LegacyOp* clone() override; +}; +} // namespace ops +} // namespace sd + +#endif // LIBND4J__LEGACY_TRANSFORM_FLOAT_OP__H diff --git a/libnd4j/include/ops/declarable/LegacyTransformBoolOp.h b/libnd4j/include/ops/declarable/LegacyTransformBoolOp.h index d64dd4b019c0..56c01ae32aac 100644 --- a/libnd4j/include/ops/declarable/LegacyTransformBoolOp.h +++ b/libnd4j/include/ops/declarable/LegacyTransformBoolOp.h @@ -22,26 +22,26 @@ #ifndef LIBND4J__LEGACY_TRANSFORM_BOOL_OP__H #define LIBND4J__LEGACY_TRANSFORM_BOOL_OP__H - #include namespace sd { - namespace ops { - /** - * This class provides wrapper for Transform operations (i.e. Pow or OneMinus) - */ - class ND4J_EXPORT LegacyTransformBoolOp : public LegacyOp { - protected: - Nd4jStatus validateAndExecute(Context &block) override; - public: - LegacyTransformBoolOp(); - LegacyTransformBoolOp(int opNum); +namespace ops { +/** + * This class provides wrapper for Transform operations (i.e. Pow or OneMinus) + */ +class SD_EXPORT LegacyTransformBoolOp : public LegacyOp { + protected: + Nd4jStatus validateAndExecute(Context& block) override; - ShapeList* calculateOutputShape(ShapeList* inputShape, sd::graph::Context &block) override; - LegacyOp* clone() override; - }; - } -} + public: + LegacyTransformBoolOp(); + LegacyTransformBoolOp(int opNum); + ShapeList* calculateOutputShape(ShapeList* inputShape, + sd::graph::Context& block) override; + LegacyOp* clone() override; +}; +} // namespace ops +} // namespace sd -#endif //LIBND4J__LEGACY_TRANSFORM_SAME_OP__H +#endif // LIBND4J__LEGACY_TRANSFORM_SAME_OP__H diff --git a/libnd4j/include/ops/declarable/LegacyTransformFloatOp.h b/libnd4j/include/ops/declarable/LegacyTransformFloatOp.h index 37bd0edce8b4..4f01611bca7e 100644 --- a/libnd4j/include/ops/declarable/LegacyTransformFloatOp.h +++ b/libnd4j/include/ops/declarable/LegacyTransformFloatOp.h @@ -21,26 +21,26 @@ #ifndef LIBND4J__LEGACY_TRANSFORM_FLOAT_OP__H #define LIBND4J__LEGACY_TRANSFORM_FLOAT_OP__H - #include namespace sd { - namespace ops { - /** - * This class provides wrapper for Transform operations (i.e. Pow or OneMinus) - */ - class ND4J_EXPORT LegacyTransformFloatOp : public LegacyOp { - protected: - Nd4jStatus validateAndExecute(Context &block) override; - public: - LegacyTransformFloatOp(); - LegacyTransformFloatOp(int opNum); - - ShapeList* calculateOutputShape(ShapeList* inputShape, sd::graph::Context &block) override; - LegacyOp* clone() override; - }; - } -} - - -#endif //LIBND4J__LEGACY_TRANSFORM_FLOAT_OP__H +namespace ops { +/** + * This class provides wrapper for Transform operations (i.e. Pow or OneMinus) + */ +class SD_EXPORT LegacyTransformFloatOp : public LegacyOp { + protected: + Nd4jStatus validateAndExecute(Context& block) override; + + public: + LegacyTransformFloatOp(); + LegacyTransformFloatOp(int opNum); + + ShapeList* calculateOutputShape(ShapeList* inputShape, + sd::graph::Context& block) override; + LegacyOp* clone() override; +}; +} // namespace ops +} // namespace sd + +#endif // LIBND4J__LEGACY_TRANSFORM_FLOAT_OP__H diff --git a/libnd4j/include/ops/declarable/LegacyTransformOp.h b/libnd4j/include/ops/declarable/LegacyTransformOp.h index 7eb265bcbaa8..1ba819aea0e3 100644 --- a/libnd4j/include/ops/declarable/LegacyTransformOp.h +++ b/libnd4j/include/ops/declarable/LegacyTransformOp.h @@ -21,32 +21,32 @@ #ifndef LIBND4J__LEGACY_TRANSFORM_OP__H #define LIBND4J__LEGACY_TRANSFORM_OP__H - //#include #ifdef ONLY_SAME_TRANSFORM namespace sd { - namespace ops { - /** - * This class provides wrapper for Transform operations (i.e. Pow or OneMinus) - */ - class ND4J_EXPORT LegacyTransformOp : public LegacyOp { - protected: - Nd4jStatus validateAndExecute(Context &block); - public: - LegacyTransformOp(); - LegacyTransformOp(int opNum); - - ShapeList* calculateOutputShape(ShapeList* inputShape, sd::graph::Context &block); - virtual LegacyOp* clone(); - }; - } -} +namespace ops { +/** + * This class provides wrapper for Transform operations (i.e. Pow or OneMinus) + */ +class SD_EXPORT LegacyTransformOp : public LegacyOp { + protected: + Nd4jStatus validateAndExecute(Context& block); + + public: + LegacyTransformOp(); + LegacyTransformOp(int opNum); + + ShapeList* calculateOutputShape(ShapeList* inputShape, + sd::graph::Context& block); + virtual LegacyOp* clone(); +}; +} // namespace ops +} // namespace sd #endif +#include #include #include -#include #include - -#endif //LIBND4J__LEGACY_TRANSFORM_OP__H +#endif // LIBND4J__LEGACY_TRANSFORM_OP__H diff --git a/libnd4j/include/ops/declarable/LegacyTransformSameOp.h b/libnd4j/include/ops/declarable/LegacyTransformSameOp.h index 4d9312dafa1b..e4494b205be1 100644 --- a/libnd4j/include/ops/declarable/LegacyTransformSameOp.h +++ b/libnd4j/include/ops/declarable/LegacyTransformSameOp.h @@ -22,26 +22,26 @@ #ifndef LIBND4J__LEGACY_TRANSFORM_SAME_OP__H #define LIBND4J__LEGACY_TRANSFORM_SAME_OP__H - #include namespace sd { - namespace ops { - /** - * This class provides wrapper for Transform operations (i.e. Pow or OneMinus) - */ - class ND4J_EXPORT LegacyTransformSameOp : public LegacyOp { - protected: - Nd4jStatus validateAndExecute(Context &block) override; - public: - LegacyTransformSameOp(); - LegacyTransformSameOp(int opNum); +namespace ops { +/** + * This class provides wrapper for Transform operations (i.e. Pow or OneMinus) + */ +class SD_EXPORT LegacyTransformSameOp : public LegacyOp { + protected: + Nd4jStatus validateAndExecute(Context& block) override; - ShapeList* calculateOutputShape(ShapeList* inputShape, sd::graph::Context &block) override; - LegacyOp* clone() override; - }; - } -} + public: + LegacyTransformSameOp(); + LegacyTransformSameOp(int opNum); + ShapeList* calculateOutputShape(ShapeList* inputShape, + sd::graph::Context& block) override; + LegacyOp* clone() override; +}; +} // namespace ops +} // namespace sd -#endif //LIBND4J__LEGACY_TRANSFORM_SAME_OP__H +#endif // LIBND4J__LEGACY_TRANSFORM_SAME_OP__H diff --git a/libnd4j/include/ops/declarable/LegacyTransformStrictOp.h b/libnd4j/include/ops/declarable/LegacyTransformStrictOp.h index ee48c02b7b57..f66839f80313 100644 --- a/libnd4j/include/ops/declarable/LegacyTransformStrictOp.h +++ b/libnd4j/include/ops/declarable/LegacyTransformStrictOp.h @@ -22,26 +22,26 @@ #ifndef LIBND4J__LEGACY_TRANSFORM_STRICT_OP__H #define LIBND4J__LEGACY_TRANSFORM_STRICT_OP__H - #include namespace sd { - namespace ops { - /** - * This class provides wrapper for Transform operations (i.e. Pow or OneMinus) - */ - class ND4J_EXPORT LegacyTransformStrictOp : public LegacyOp { - protected: - Nd4jStatus validateAndExecute(Context &block) override; - public: - LegacyTransformStrictOp(); - LegacyTransformStrictOp(int opNum); +namespace ops { +/** + * This class provides wrapper for Transform operations (i.e. Pow or OneMinus) + */ +class SD_EXPORT LegacyTransformStrictOp : public LegacyOp { + protected: + Nd4jStatus validateAndExecute(Context& block) override; - ShapeList* calculateOutputShape(ShapeList* inputShape, sd::graph::Context &block) override; - LegacyOp* clone() override; - }; - } -} + public: + LegacyTransformStrictOp(); + LegacyTransformStrictOp(int opNum); + ShapeList* calculateOutputShape(ShapeList* inputShape, + sd::graph::Context& block) override; + LegacyOp* clone() override; +}; +} // namespace ops +} // namespace sd -#endif //LIBND4J__LEGACY_TRANSFORM_SAME_OP__H +#endif // LIBND4J__LEGACY_TRANSFORM_SAME_OP__H diff --git a/libnd4j/include/ops/declarable/LogicOp.h b/libnd4j/include/ops/declarable/LogicOp.h index d3ad59af29cf..03f071a95f1d 100644 --- a/libnd4j/include/ops/declarable/LogicOp.h +++ b/libnd4j/include/ops/declarable/LogicOp.h @@ -24,24 +24,27 @@ #include "DeclarableOp.h" namespace sd { - namespace ops { - - /** - * Logic ops are unique snowflakes in any Graph. They dramatically change Graph Execution process, by introducing loops, conditions, etc. - * - * Their code is the part of GraphExecutioner logic. But we still want them to be expressed via Graph - * @tparam T - */ - class ND4J_EXPORT LogicOp : public DeclarableOp { - protected: - Nd4jStatus validateAndExecute(sd::graph::Context& block) override; - public: - LogicOp(const char *name); - - ShapeList* calculateOutputShape(ShapeList* inputShape, sd::graph::Context &block) override; - }; - } -} - - -#endif //LIBND4J_LOGICOP_H +namespace ops { + +/** + * Logic ops are unique snowflakes in any Graph. They dramatically change Graph + * Execution process, by introducing loops, conditions, etc. + * + * Their code is the part of GraphExecutioner logic. But we still want them to + * be expressed via Graph + * @tparam T + */ +class SD_EXPORT LogicOp : public DeclarableOp { + protected: + Nd4jStatus validateAndExecute(sd::graph::Context& block) override; + + public: + LogicOp(const char* name); + + ShapeList* calculateOutputShape(ShapeList* inputShape, + sd::graph::Context& block) override; +}; +} // namespace ops +} // namespace sd + +#endif // LIBND4J_LOGICOP_H diff --git a/libnd4j/include/ops/declarable/OpDescriptor.h b/libnd4j/include/ops/declarable/OpDescriptor.h index 3feff5916735..9c99518750e9 100644 --- a/libnd4j/include/ops/declarable/OpDescriptor.h +++ b/libnd4j/include/ops/declarable/OpDescriptor.h @@ -21,169 +21,179 @@ #ifndef LIBND4J_OPDESCRIPTOR_H #define LIBND4J_OPDESCRIPTOR_H -#include -#include -#include +#include +#include #include #include -#include -#include -namespace sd { - namespace ops { - - /** - * This class is very basic info holder for ops. bean/pojo pretty much. - * - */ - class ND4J_EXPORT OpDescriptor { - protected: - // opNum for legacy XYZ ops - int _opNum = 0; - - // opName for CustomOp - std::string _opName; +#include +#include +#include - // hash is used for ops lookup in OpRegistrator - Nd4jLong _hash = -1; +namespace sd { +namespace ops { - // minimal required/expected number of inputs/outpus for this given op - int _numInputs = 1; - int _numOutputs = 1; +/** + * This class is very basic info holder for ops. bean/pojo pretty much. + * + */ +class SD_EXPORT OpDescriptor { + protected: + // opNum for legacy XYZ ops + int _opNum = 0; - // enum for ops. deprecated. will be removed - sd::graph::OpClass _opClass; + // opName for CustomOp + std::string _opName; - // special flag for divergent ops - ops that CAN and WILL modify graph behavior. Literally: IF, CASE. - bool _divergent = false; + // hash is used for ops lookup in OpRegistrator + Nd4jLong _hash = -1; - // flag, if this given op allows in-place execution - bool _allowsInplace = true; + // minimal required/expected number of inputs/outpus for this given op + int _numInputs = 1; + int _numOutputs = 1; - // minimal required number of T-type arguments. - // -1 as value means: not limited, variable number of arguments - int _tArgs = 0; + // enum for ops. deprecated. will be removed + sd::graph::OpClass _opClass; - // minimal required number of Integer-type arguments. - // -1 as value means: not limited, variable number of arguments - int _iArgs = 0; + // special flag for divergent ops - ops that CAN and WILL modify graph + // behavior. Literally: IF, CASE. + bool _divergent = false; - // field for BooleanOps - bool _scalar = false; + // flag, if this given op allows in-place execution + bool _allowsInplace = true; - // field for LogicOps - bool _logic = false; + // minimal required number of T-type arguments. + // -1 as value means: not limited, variable number of arguments + int _tArgs = 0; - // default InputType is numeric - InputType _inputType = InputType_NUMERIC; + // minimal required number of Integer-type arguments. + // -1 as value means: not limited, variable number of arguments + int _iArgs = 0; + // field for BooleanOps + bool _scalar = false; - bool _sameMode = false; - std::vector _allowedIns; - std::vector _allowedOuts; + // field for LogicOps + bool _logic = false; - // optional per-input configuration - MAP_IMPL> _outputTypes; - MAP_IMPL> _inputTypes; + // default InputType is numeric + InputType _inputType = InputType_NUMERIC; + bool _sameMode = false; + std::vector _allowedIns; + std::vector _allowedOuts; - // field for ops that allow data type override at runtime - bool _dtypeOverride = false; + // optional per-input configuration + MAP_IMPL> _outputTypes; + MAP_IMPL> _inputTypes; - bool checkDataTypesMatch(sd::DataType needle, std::vector &haystack) const; - public: - // default constructor - OpDescriptor(int numInputs, int numOutputs, std::string opName, bool allowsInplace); + // field for ops that allow data type override at runtime + bool _dtypeOverride = false; - // constructor for boolean ops - OpDescriptor(int numInputs, std::string opName, bool isScalar); - OpDescriptor(int numInputs, const char* opName, bool isScalar); + bool checkDataTypesMatch(sd::DataType needle, + std::vector& haystack) const; - // default constructor - OpDescriptor(int numInputs, int numOutputs, const char *opName, bool allowsInplace); + public: + // default constructor + OpDescriptor(int numInputs, int numOutputs, std::string opName, + bool allowsInplace); - // constructor for configurable op - OpDescriptor(int numInputs, int numOutputs, const char *opName, bool allowsInplace, int tArgs, int iArgs); + // constructor for boolean ops + OpDescriptor(int numInputs, std::string opName, bool isScalar); + OpDescriptor(int numInputs, const char* opName, bool isScalar); - // constructor for non-configurable divergent op - OpDescriptor(int numInputs, int numOutputs, std::string opName, bool allowsInplace, bool divergent); + // default constructor + OpDescriptor(int numInputs, int numOutputs, const char* opName, + bool allowsInplace); - // constructor for non-configurable divergent op - OpDescriptor(int numInputs, int numOutputs, const char *opName, bool allowsInplace, bool divergent); + // constructor for configurable op + OpDescriptor(int numInputs, int numOutputs, const char* opName, + bool allowsInplace, int tArgs, int iArgs); - // constructor for configurable divergent op - OpDescriptor(int numInputs, int numOutputs, const char *opName, bool allowsInplace, bool divergent, int tArgs, int iArgs); + // constructor for non-configurable divergent op + OpDescriptor(int numInputs, int numOutputs, std::string opName, + bool allowsInplace, bool divergent); - // constructor for logical ops (while, scope, etc) - OpDescriptor(const char * opName, bool isLogic); + // constructor for non-configurable divergent op + OpDescriptor(int numInputs, int numOutputs, const char* opName, + bool allowsInplace, bool divergent); - bool operator==(const OpDescriptor& other) const; + // constructor for configurable divergent op + OpDescriptor(int numInputs, int numOutputs, const char* opName, + bool allowsInplace, bool divergent, int tArgs, int iArgs); - // default destructor - ~OpDescriptor(); + // constructor for logical ops (while, scope, etc) + OpDescriptor(const char* opName, bool isLogic); - // this method returns minimal expected number of T arguments - int getNumberOfTArgs(); + bool operator==(const OpDescriptor& other) const; - // this method returns minimal expected number of Integer arguments - int getNumberOfIArgs(); + // default destructor + ~OpDescriptor(); - // this method returns minimal expected number of inputs - int getNumberOfInputs(); + // this method returns minimal expected number of T arguments + int getNumberOfTArgs(); - // this method returns hash code for this operation - Nd4jLong getHash(); + // this method returns minimal expected number of Integer arguments + int getNumberOfIArgs(); - // this method returns minimal expected number of outputs - int getNumberOfOutputs(); + // this method returns minimal expected number of inputs + int getNumberOfInputs(); - // this method returns opName (can be empty) - std::string *getOpName(); + // this method returns hash code for this operation + Nd4jLong getHash(); - // returns TRUE if this op is divergent. FALSE otherwise - bool isDivergent(); + // this method returns minimal expected number of outputs + int getNumberOfOutputs(); - // returns TRUE if this op allows in-place execution - bool allowsInplace(); + // this method returns opName (can be empty) + std::string* getOpName(); - // this method allows you to enable/disable inplace call for a given op - void allowInplace(bool reallyAllow); + // returns TRUE if this op is divergent. FALSE otherwise + bool isDivergent(); - // this method returns opNum (applicable for legacy XYZ ops only) - int getOpNum(); + // returns TRUE if this op allows in-place execution + bool allowsInplace(); - // this method allows to set specifc opNum - void setOpNum(int opNum); + // this method allows you to enable/disable inplace call for a given op + void allowInplace(bool reallyAllow); - void setHash(Nd4jLong hash); + // this method returns opNum (applicable for legacy XYZ ops only) + int getOpNum(); - InputType inputType(); + // this method allows to set specifc opNum + void setOpNum(int opNum); + void setHash(Nd4jLong hash); + InputType inputType(); - OpDescriptor* setInputType(InputType type); - OpDescriptor* setAllowedInputTypes(const std::initializer_list &dtype); - OpDescriptor* setAllowedOutputTypes(const std::initializer_list &dtype); - OpDescriptor* setAllowedInputTypes(int index, const std::vector &dtype); - OpDescriptor* setAllowedOutputTypes(int index, const std::vector &dtype); - OpDescriptor* setAllowedInputTypes(int index, sd::DataType dtype); - OpDescriptor* setAllowedOutputTypes(int index, sd::DataType dtype); - OpDescriptor* setAllowedInputTypes(sd::DataType dtype); - OpDescriptor* setAllowedOutputTypes(sd::DataType dtype); - OpDescriptor* allowOverride(bool reallyAllow); - OpDescriptor* setSameMode(bool reallySame); - OpDescriptor* setInputType(int idx, sd::DataType dtype); - OpDescriptor* setOutputType(int idx, sd::DataType dtype); + OpDescriptor* setInputType(InputType type); + OpDescriptor* setAllowedInputTypes( + const std::initializer_list& dtype); + OpDescriptor* setAllowedOutputTypes( + const std::initializer_list& dtype); + OpDescriptor* setAllowedInputTypes(int index, + const std::vector& dtype); + OpDescriptor* setAllowedOutputTypes(int index, + const std::vector& dtype); + OpDescriptor* setAllowedInputTypes(int index, sd::DataType dtype); + OpDescriptor* setAllowedOutputTypes(int index, sd::DataType dtype); + OpDescriptor* setAllowedInputTypes(sd::DataType dtype); + OpDescriptor* setAllowedOutputTypes(sd::DataType dtype); + OpDescriptor* allowOverride(bool reallyAllow); + OpDescriptor* setSameMode(bool reallySame); + OpDescriptor* setInputType(int idx, sd::DataType dtype); + OpDescriptor* setOutputType(int idx, sd::DataType dtype); - std::vector getOutputTypesForOutput(int index); + std::vector getOutputTypesForOutput(int index); - bool checkInputMatch(int index, sd::DataType dataType); - bool checkOutputMatch(int index, sd::DataType dataType); - bool isSameMode(); + bool checkInputMatch(int index, sd::DataType dataType); + bool checkOutputMatch(int index, sd::DataType dataType); + bool isSameMode(); - bool isInherit(int index); - }; - } -} + bool isInherit(int index); +}; +} // namespace ops +} // namespace sd -#endif //LIBND4J_OPDESCRIPTOR_H +#endif // LIBND4J_OPDESCRIPTOR_H diff --git a/libnd4j/include/ops/declarable/OpRegistrator.h b/libnd4j/include/ops/declarable/OpRegistrator.h index a4967d877209..55c1b48a917e 100644 --- a/libnd4j/include/ops/declarable/OpRegistrator.h +++ b/libnd4j/include/ops/declarable/OpRegistrator.h @@ -21,132 +21,142 @@ #ifndef LIBND4J_OPREGISTRATOR_H #define LIBND4J_OPREGISTRATOR_H -#include -#include -#include -#include +#include #include #include -#include +#include + +#include +#include +#include // handlers part -#include #include +#include #ifndef __JAVACPP_HACK__ namespace std { - template <> - class hash> { - public: - size_t operator()(const std::pair& k) const; - }; - - template <> - class hash> { - public: - size_t operator()(const std::pair& k) const; - }; +template <> +class hash> { + public: + size_t operator()(const std::pair& k) const; }; -#endif +template <> +class hash> { + public: + size_t operator()(const std::pair& k) const; +}; +}; // namespace std +#endif namespace sd { - namespace ops { - /** - * This class provides runtime ops lookup, based on opName or opHash. - * To build lookup directory we use *_OP_IMPL macro, which puts static structs at compile time in .cpp files, - * so once binary is executed, static objects are initialized automatically, and we get list of all ops - * available at runtime via this singleton. - * - */ - class ND4J_EXPORT OpRegistrator { - private: - static OpRegistrator* _INSTANCE; - OpRegistrator() { - nd4j_debug("OpRegistrator started\n",""); - +namespace ops { +/** + * This class provides runtime ops lookup, based on opName or opHash. + * To build lookup directory we use *_OP_IMPL macro, which puts static structs + * at compile time in .cpp files, so once binary is executed, static objects are + * initialized automatically, and we get list of all ops available at runtime + * via this singleton. + * + */ +class SD_EXPORT OpRegistrator { + private: + static OpRegistrator* _INSTANCE; + OpRegistrator() { + nd4j_debug("OpRegistrator started\n", ""); + + /* #ifndef _RELEASE - std::signal(SIGSEGV, &OpRegistrator::sigSegVHandler); - std::signal(SIGINT, &OpRegistrator::sigIntHandler); - std::signal(SIGABRT, &OpRegistrator::sigIntHandler); - std::signal(SIGFPE, &OpRegistrator::sigIntHandler); - std::signal(SIGILL, &OpRegistrator::sigIntHandler); - std::signal(SIGTERM, &OpRegistrator::sigIntHandler); - atexit(&OpRegistrator::exitHandler); + std::signal(SIGSEGV, &OpRegistrator::sigSegVHandler); + std::signal(SIGINT, &OpRegistrator::sigIntHandler); + std::signal(SIGABRT, &OpRegistrator::sigIntHandler); + std::signal(SIGFPE, &OpRegistrator::sigIntHandler); + std::signal(SIGILL, &OpRegistrator::sigIntHandler); + std::signal(SIGTERM, &OpRegistrator::sigIntHandler); + atexit(&OpRegistrator::exitHandler); #endif - }; + */ + }; - MAP_IMPL _msvc; + MAP_IMPL _msvc; - // pointers to our operations - MAP_IMPL _declarablesLD; - MAP_IMPL _declarablesD; - std::vector _uniqueD; + // pointers to our operations + MAP_IMPL> _declarablesLD; + MAP_IMPL> _declarablesD; - // pointers to platform-specific helpers - MAP_IMPL, sd::ops::platforms::PlatformHelper*> _helpersLH; - MAP_IMPL, sd::ops::platforms::PlatformHelper*> _helpersH; - std::vector _uniqueH; + // pointers to platform-specific helpers + MAP_IMPL, + sd::ops::platforms::PlatformHelper*> + _helpersLH; + MAP_IMPL, + sd::ops::platforms::PlatformHelper*> + _helpersH; + std::vector _uniqueH; - std::mutex _locker; - std::string _opsList; - bool isInit = false; - public: - ~OpRegistrator(); + std::mutex _locker; + std::string _opsList; + bool isInit = false; - static OpRegistrator& getInstance(); + public: + ~OpRegistrator(); - static void exitHandler(); - static void sigIntHandler(int sig); - static void sigSegVHandler(int sig); + static OpRegistrator& getInstance(); - void updateMSVC(Nd4jLong newHash, std::string& oldName); + static void exitHandler(); + static void sigIntHandler(int sig); + static void sigSegVHandler(int sig); - template - std::string local_to_string(T value); - const char * getAllCustomOperations(); + void updateMSVC(Nd4jLong newHash, std::string& oldName); - /** - * This method registers operation in our registry, so we can use them later - * - * @param op - */ - bool registerOperation(const char* name, sd::ops::DeclarableOp* op); - bool registerOperation(sd::ops::DeclarableOp *op); + template + std::string local_to_string(T value); + const char* getAllCustomOperations(); - void registerHelper(sd::ops::platforms::PlatformHelper* op); + /** + * This method registers operation in our registry, so we can use them later + * + * @param op + */ + bool registerOperation(const std::string& opName, + std::shared_ptr op); + bool registerOperation(std::shared_ptr op); - bool hasHelper(Nd4jLong hash, samediff::Engine engine); + void registerHelper(sd::ops::platforms::PlatformHelper* op); - sd::ops::DeclarableOp* getOperation(const char *name); - sd::ops::DeclarableOp* getOperation(Nd4jLong hash); - sd::ops::DeclarableOp* getOperation(std::string &name); + bool hasHelper(Nd4jLong hash, samediff::Engine engine); - sd::ops::platforms::PlatformHelper* getPlatformHelper(Nd4jLong hash, samediff::Engine engine); + std::shared_ptr getOperation(Nd4jLong hash); + std::shared_ptr getOperation(const std::string& name); - std::vector getAllHashes(); + bool hasOperation(const std::string& opName) const; + bool hasOperation(const Nd4jLong opName) const; - int numberOfOperations(); - }; + sd::ops::platforms::PlatformHelper* getPlatformHelper( + Nd4jLong hash, samediff::Engine engine); + std::vector getAllHashes(); - /* - * These structs are used to "register" our ops in OpRegistrator. - */ - template - struct __registrator{ - __registrator(); - }; + int numberOfOperations(); +}; - template - struct __registratorSynonym { - __registratorSynonym(const char *name, const char *oname); - }; +/* + * These structs are used to "register" our ops in OpRegistrator. + */ +template +struct __registrator { + __registrator(); +}; + +template +struct __registratorSynonym { + __registratorSynonym(const char* name, const char* oname); +}; - } -} +} // namespace ops +} // namespace sd -#endif //LIBND4J_OPREGISTRATOR_H +#endif // LIBND4J_OPREGISTRATOR_H diff --git a/libnd4j/include/ops/declarable/OpTuple.h b/libnd4j/include/ops/declarable/OpTuple.h index 7458ef3d0bfb..537fe7d8b0a4 100644 --- a/libnd4j/include/ops/declarable/OpTuple.h +++ b/libnd4j/include/ops/declarable/OpTuple.h @@ -21,31 +21,33 @@ #ifndef LIBND4J_OPTUPLE_H #define LIBND4J_OPTUPLE_H -#include -#include #include +#include +#include + namespace sd { - namespace ops { - class ND4J_EXPORT OpTuple { - public: - std::string _opName; - std::vector _inputs; - std::vector _outputs; - std::vector _tArgs; - std::vector _iArgs; - - OpTuple(const char *opName); - OpTuple(const char *opName, std::initializer_list&& inputs, std::initializer_list&& tArgs, std::initializer_list&& iArgs); - ~OpTuple(); - - OpTuple* addInput(sd::NDArray *array); - OpTuple* addOutput(sd::NDArray *array); - OpTuple* setTArgs(std::initializer_list tArgs); - OpTuple* setIArgs(std::initializer_list iArgs); - }; - } -} - - -#endif //LIBND4J_OPTUPLE_H +namespace ops { +class SD_EXPORT OpTuple { + public: + std::string _opName; + std::vector _inputs; + std::vector _outputs; + std::vector _tArgs; + std::vector _iArgs; + + OpTuple(const char* opName); + OpTuple(const char* opName, std::initializer_list&& inputs, + std::initializer_list&& tArgs, + std::initializer_list&& iArgs); + ~OpTuple(); + + OpTuple* addInput(sd::NDArray* array); + OpTuple* addOutput(sd::NDArray* array); + OpTuple* setTArgs(std::initializer_list tArgs); + OpTuple* setIArgs(std::initializer_list iArgs); +}; +} // namespace ops +} // namespace sd + +#endif // LIBND4J_OPTUPLE_H diff --git a/libnd4j/include/ops/declarable/PlatformHelper.h b/libnd4j/include/ops/declarable/PlatformHelper.h index e0231ad9addd..3adad1a3b1fa 100644 --- a/libnd4j/include/ops/declarable/PlatformHelper.h +++ b/libnd4j/include/ops/declarable/PlatformHelper.h @@ -21,75 +21,79 @@ #ifndef SD_PLATFORMHELPER_H #define SD_PLATFORMHELPER_H -#include #include #include -#include -#include +#include #include +#include + +#include namespace sd { - namespace ops { - namespace platforms { - /** - * This abstract class defines methods used by platform-specific helpers implementations - */ - class ND4J_EXPORT PlatformHelper { - protected: - // target engine for this impl - samediff::Engine _engine; - - // name of the operation this helper is built for - std::string _name; - - // hash of the operation this helper is built for - Nd4jLong _hash; - public: - PlatformHelper(const char *name, samediff::Engine engine); - - ~PlatformHelper() = default; - - std::string name(); - - samediff::Engine engine(); - - Nd4jLong hash(); - - /** - * This method checks, if given helper can be used with given input/output/configuration options - * - * @param context - * @return - */ - virtual bool isUsable(graph::Context &context) = 0; - - /** - * This method invokes helper. Typically this method replaces actual op execution - * - * @param context - * @return - */ - virtual Nd4jStatus invokeHelper(graph::Context &context) = 0; - - /** - * Helper method, needed for compatibility with DeclarableOp macros - * @param ctx - * @param inputId - * @return - */ - sd::NDArray* getZ(graph::Context &ctx, int inputId); - - /** - * Helper method, needed for compatibility with DeclarableOp macros - * @param ctx - * @param inputId - * @return - */ - sd::NDArray* getNullifiedZ(graph::Context &ctx, int inputId); - }; - } - } -} - - -#endif //SD_PLATFORMHELPER_H +namespace ops { +namespace platforms { +/** + * This abstract class defines methods used by platform-specific helpers + * implementations + */ +class SD_EXPORT PlatformHelper { + protected: + // target engine for this impl + samediff::Engine _engine; + + // name of the operation this helper is built for + std::string _name; + + // hash of the operation this helper is built for + Nd4jLong _hash; + + public: + PlatformHelper(const char *name, samediff::Engine engine); + + ~PlatformHelper() = default; + + std::string name(); + + samediff::Engine engine(); + + Nd4jLong hash(); + + /** + * This method checks, if given helper can be used with given + * input/output/configuration options + * + * @param context + * @return + */ + virtual bool isUsable(graph::Context &context) = 0; + + /** + * This method invokes helper. Typically this method replaces actual op + * execution + * + * @param context + * @return + */ + virtual Nd4jStatus invokeHelper(graph::Context &context) = 0; + + /** + * Helper method, needed for compatibility with DeclarableOp macros + * @param ctx + * @param inputId + * @return + */ + sd::NDArray *getZ(graph::Context &ctx, int inputId); + + /** + * Helper method, needed for compatibility with DeclarableOp macros + * @param ctx + * @param inputId + * @return + */ + sd::NDArray *getNullifiedZ(graph::Context &ctx, int inputId); +}; +} // namespace platforms +} // namespace ops +} // namespace sd + +#endif // SD_PLATFORMHELPER_H diff --git a/libnd4j/include/ops/declarable/generic/CustomOperations.cpp b/libnd4j/include/ops/declarable/generic/CustomOperations.cpp index c13430ce3abf..60c3e47ee296 100644 --- a/libnd4j/include/ops/declarable/generic/CustomOperations.cpp +++ b/libnd4j/include/ops/declarable/generic/CustomOperations.cpp @@ -15,37 +15,38 @@ ******************************************************************************/ // -// This is special snowflake. This file builds bindings for ops availability tests +// This is special snowflake. This file builds bindings for ops availability +// tests // // @author raver119@gmail.com // -#include #include +#include #include namespace sd { - _loader::_loader() { - // - OpTracker::getInstance(); +_loader::_loader() { + // + OpTracker::getInstance(); -//#ifndef __CLION_IDE__ - BUILD_TRACKER(OpType_TRANSFORM_SAME, TRANSFORM_FLOAT_OPS); - BUILD_TRACKER(OpType_TRANSFORM_SAME, TRANSFORM_SAME_OPS); - BUILD_TRACKER(OpType_TRANSFORM_SAME, TRANSFORM_BOOL_OPS); - BUILD_TRACKER(OpType_BROADCAST, BROADCAST_OPS); - BUILD_TRACKER(OpType_PAIRWISE, PAIRWISE_TRANSFORM_OPS); - BUILD_TRACKER(OpType_RANDOM, RANDOM_OPS); - BUILD_TRACKER(OpType_REDUCE_FLOAT, REDUCE_FLOAT_OPS); - BUILD_TRACKER(OpType_REDUCE_SAME, REDUCE_SAME_OPS); - BUILD_TRACKER(OpType_REDUCE_BOOL, REDUCE_BOOL_OPS); - BUILD_TRACKER(OpType_REDUCE_3, REDUCE3_OPS); - BUILD_TRACKER(OpType_INDEX_REDUCE, INDEX_REDUCE_OPS); - BUILD_TRACKER(OpType_SCALAR, SCALAR_OPS); - BUILD_TRACKER(OpType_SUMMARYSTATS, SUMMARY_STATS_OPS); -//#endif - }; + //#ifndef __CLION_IDE__ + BUILD_TRACKER(OpType_TRANSFORM_SAME, TRANSFORM_FLOAT_OPS); + BUILD_TRACKER(OpType_TRANSFORM_SAME, TRANSFORM_SAME_OPS); + BUILD_TRACKER(OpType_TRANSFORM_SAME, TRANSFORM_BOOL_OPS); + BUILD_TRACKER(OpType_BROADCAST, BROADCAST_OPS); + BUILD_TRACKER(OpType_PAIRWISE, PAIRWISE_TRANSFORM_OPS); + BUILD_TRACKER(OpType_RANDOM, RANDOM_OPS); + BUILD_TRACKER(OpType_REDUCE_FLOAT, REDUCE_FLOAT_OPS); + BUILD_TRACKER(OpType_REDUCE_SAME, REDUCE_SAME_OPS); + BUILD_TRACKER(OpType_REDUCE_BOOL, REDUCE_BOOL_OPS); + BUILD_TRACKER(OpType_REDUCE_3, REDUCE3_OPS); + BUILD_TRACKER(OpType_INDEX_REDUCE, INDEX_REDUCE_OPS); + BUILD_TRACKER(OpType_SCALAR, SCALAR_OPS); + BUILD_TRACKER(OpType_SUMMARYSTATS, SUMMARY_STATS_OPS); + //#endif +}; - static sd::_loader loader; -} \ No newline at end of file +static sd::_loader loader; +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/bitwise/bits_hamming_distance.cpp b/libnd4j/include/ops/declarable/generic/bitwise/bits_hamming_distance.cpp index 693ebf7c6fc7..cb529f572b73 100644 --- a/libnd4j/include/ops/declarable/generic/bitwise/bits_hamming_distance.cpp +++ b/libnd4j/include/ops/declarable/generic/bitwise/bits_hamming_distance.cpp @@ -22,35 +22,40 @@ #if NOT_EXCLUDED(OP_bits_hamming_distance) #include -#include #include +#include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(bits_hamming_distance, 2, 1, true, 0, 0) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto output = OUTPUT_NULLIFIED(0); - - REQUIRE_TRUE(x->lengthOf() == y->lengthOf(), 0, "bits_hamming_distance: both arguments must have the same length"); - REQUIRE_TRUE(x->dataType() == y->dataType(), 0, "bits_hamming_distance: both arguments must have the same data type"); - - helpers::hamming(block.launchContext(), *x, *y, *output); - - return Status::OK(); - } - - DECLARE_SHAPE_FN(bits_hamming_distance) { - return SHAPELIST(ConstantShapeHelper::getInstance().scalarShapeInfo(sd::DataType::INT64)); - } - - DECLARE_TYPES(bits_hamming_distance) { - getOpDescriptor() - ->setAllowedInputTypes(0, {ALL_INTS}) - ->setAllowedInputTypes(1, {ALL_INTS}) - ->setAllowedOutputTypes(0, {ALL_INDICES}); - } - } +namespace ops { +CUSTOM_OP_IMPL(bits_hamming_distance, 2, 1, true, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto output = OUTPUT_NULLIFIED(0); + + REQUIRE_TRUE( + x->lengthOf() == y->lengthOf(), 0, + "bits_hamming_distance: both arguments must have the same length"); + REQUIRE_TRUE( + x->dataType() == y->dataType(), 0, + "bits_hamming_distance: both arguments must have the same data type"); + + helpers::hamming(block.launchContext(), *x, *y, *output); + + return Status::OK(); +} + +DECLARE_SHAPE_FN(bits_hamming_distance) { + return SHAPELIST( + ConstantShapeHelper::getInstance().scalarShapeInfo(sd::DataType::INT64)); +} + +DECLARE_TYPES(bits_hamming_distance) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_INTS}) + ->setAllowedInputTypes(1, {ALL_INTS}) + ->setAllowedOutputTypes(0, {ALL_INDICES}); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/bitwise/bitwise_and.cpp b/libnd4j/include/ops/declarable/generic/bitwise/bitwise_and.cpp index 1e951c1d9580..8fc387c848f0 100644 --- a/libnd4j/include/ops/declarable/generic/bitwise/bitwise_and.cpp +++ b/libnd4j/include/ops/declarable/generic/bitwise/bitwise_and.cpp @@ -26,25 +26,26 @@ #include namespace sd { - namespace ops { - BROADCASTABLE_OP_IMPL(bitwise_and, 0, 0) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); +namespace ops { +BROADCASTABLE_OP_IMPL(bitwise_and, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); - BROADCAST_CHECK_EMPTY(x,y,z); + BROADCAST_CHECK_EMPTY(x, y, z); - x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::IntOps::IntAnd, pairwise::IntOps::IntAnd, broadcast::IntOps::IntAnd), *y, *z, false); + x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::IntOps::IntAnd, + pairwise::IntOps::IntAnd, + broadcast::IntOps::IntAnd), + *y, *z, false); - return Status::OK(); - } + return Status::OK(); +} - DECLARE_TYPES(bitwise_and) { - getOpDescriptor() - ->setAllowedInputTypes({ALL_INTS}) - ->setSameMode(true); - } - } +DECLARE_TYPES(bitwise_and) { + getOpDescriptor()->setAllowedInputTypes({ALL_INTS})->setSameMode(true); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/bitwise/bitwise_or.cpp b/libnd4j/include/ops/declarable/generic/bitwise/bitwise_or.cpp index cd20a8434f9e..642cf4dcc2d7 100644 --- a/libnd4j/include/ops/declarable/generic/bitwise/bitwise_or.cpp +++ b/libnd4j/include/ops/declarable/generic/bitwise/bitwise_or.cpp @@ -26,25 +26,26 @@ #include namespace sd { - namespace ops { - BROADCASTABLE_OP_IMPL(bitwise_or, 0, 0) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); +namespace ops { +BROADCASTABLE_OP_IMPL(bitwise_or, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); - BROADCAST_CHECK_EMPTY(x,y,z); + BROADCAST_CHECK_EMPTY(x, y, z); - x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::IntOps::IntOr, pairwise::IntOps::IntOr, broadcast::IntOps::IntOr), *y, *z, false); + x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::IntOps::IntOr, + pairwise::IntOps::IntOr, + broadcast::IntOps::IntOr), + *y, *z, false); - return Status::OK(); - } + return Status::OK(); +} - DECLARE_TYPES(bitwise_or) { - getOpDescriptor() - ->setAllowedInputTypes({ALL_INTS}) - ->setSameMode(true); - } - } +DECLARE_TYPES(bitwise_or) { + getOpDescriptor()->setAllowedInputTypes({ALL_INTS})->setSameMode(true); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/bitwise/bitwise_xor.cpp b/libnd4j/include/ops/declarable/generic/bitwise/bitwise_xor.cpp index 0af9fe759163..241ae8c22459 100644 --- a/libnd4j/include/ops/declarable/generic/bitwise/bitwise_xor.cpp +++ b/libnd4j/include/ops/declarable/generic/bitwise/bitwise_xor.cpp @@ -26,25 +26,26 @@ #include namespace sd { - namespace ops { - BROADCASTABLE_OP_IMPL(bitwise_xor, 0, 0) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); +namespace ops { +BROADCASTABLE_OP_IMPL(bitwise_xor, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); - BROADCAST_CHECK_EMPTY(x,y,z); + BROADCAST_CHECK_EMPTY(x, y, z); - x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::IntOps::IntXor, pairwise::IntOps::IntXor, broadcast::IntOps::IntXor), *y, *z, false); + x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::IntOps::IntXor, + pairwise::IntOps::IntXor, + broadcast::IntOps::IntXor), + *y, *z, false); - return Status::OK(); - } + return Status::OK(); +} - DECLARE_TYPES(bitwise_xor) { - getOpDescriptor() - ->setAllowedInputTypes({ALL_INTS}) - ->setSameMode(true); - } - } +DECLARE_TYPES(bitwise_xor) { + getOpDescriptor()->setAllowedInputTypes({ALL_INTS})->setSameMode(true); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/bitwise/cyclic_rshift.cpp b/libnd4j/include/ops/declarable/generic/bitwise/cyclic_rshift.cpp index cc0c4827b23d..5ffa1f1989df 100644 --- a/libnd4j/include/ops/declarable/generic/bitwise/cyclic_rshift.cpp +++ b/libnd4j/include/ops/declarable/generic/bitwise/cyclic_rshift.cpp @@ -26,25 +26,27 @@ #include namespace sd { - namespace ops { - BROADCASTABLE_OP_IMPL(cyclic_rshift_bits, 0, 0) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); +namespace ops { +BROADCASTABLE_OP_IMPL(cyclic_rshift_bits, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); - BROADCAST_CHECK_EMPTY(x,y,z); + BROADCAST_CHECK_EMPTY(x, y, z); - x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::CyclicShiftRight, pairwise::CyclicShiftRight, broadcast::CyclicShiftRight), *y, *z, false); + x->applyTrueBroadcast( + BroadcastIntOpsTuple::custom(scalar::CyclicShiftRight, + pairwise::CyclicShiftRight, + broadcast::CyclicShiftRight), + *y, *z, false); - return Status::OK(); - } + return Status::OK(); +} - DECLARE_TYPES(cyclic_rshift_bits) { - getOpDescriptor() - ->setAllowedInputTypes({ALL_INTS}) - ->setSameMode(true); - } - } +DECLARE_TYPES(cyclic_rshift_bits) { + getOpDescriptor()->setAllowedInputTypes({ALL_INTS})->setSameMode(true); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/bitwise/cyclic_shift.cpp b/libnd4j/include/ops/declarable/generic/bitwise/cyclic_shift.cpp index f2b36a6d8346..0578b87557db 100644 --- a/libnd4j/include/ops/declarable/generic/bitwise/cyclic_shift.cpp +++ b/libnd4j/include/ops/declarable/generic/bitwise/cyclic_shift.cpp @@ -26,25 +26,26 @@ #include namespace sd { - namespace ops { - BROADCASTABLE_OP_IMPL(cyclic_shift_bits, 0, 0) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); +namespace ops { +BROADCASTABLE_OP_IMPL(cyclic_shift_bits, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); - BROADCAST_CHECK_EMPTY(x,y,z); + BROADCAST_CHECK_EMPTY(x, y, z); - x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::CyclicShiftLeft, pairwise::CyclicShiftLeft, broadcast::CyclicShiftLeft), *y, *z, false); + x->applyTrueBroadcast(BroadcastIntOpsTuple::custom( + scalar::CyclicShiftLeft, pairwise::CyclicShiftLeft, + broadcast::CyclicShiftLeft), + *y, *z, false); - return Status::OK(); - } + return Status::OK(); +} - DECLARE_TYPES(cyclic_shift_bits) { - getOpDescriptor() - ->setAllowedInputTypes({ALL_INTS}) - ->setSameMode(true); - } - } +DECLARE_TYPES(cyclic_shift_bits) { + getOpDescriptor()->setAllowedInputTypes({ALL_INTS})->setSameMode(true); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/bitwise/rshift.cpp b/libnd4j/include/ops/declarable/generic/bitwise/rshift.cpp index 8b44d2a6f415..bde55383ba1e 100644 --- a/libnd4j/include/ops/declarable/generic/bitwise/rshift.cpp +++ b/libnd4j/include/ops/declarable/generic/bitwise/rshift.cpp @@ -26,25 +26,26 @@ #include namespace sd { - namespace ops { - BROADCASTABLE_OP_IMPL(rshift_bits, 0, 0) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); +namespace ops { +BROADCASTABLE_OP_IMPL(rshift_bits, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); - BROADCAST_CHECK_EMPTY(x,y,z); + BROADCAST_CHECK_EMPTY(x, y, z); - x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::ShiftRight, pairwise::ShiftRight, broadcast::ShiftRight), *y, *z, false); + x->applyTrueBroadcast( + BroadcastIntOpsTuple::custom(scalar::ShiftRight, pairwise::ShiftRight, + broadcast::ShiftRight), + *y, *z, false); - return Status::OK(); - } + return Status::OK(); +} - DECLARE_TYPES(rshift_bits) { - getOpDescriptor() - ->setAllowedInputTypes({ALL_INTS}) - ->setSameMode(true); - } - } +DECLARE_TYPES(rshift_bits) { + getOpDescriptor()->setAllowedInputTypes({ALL_INTS})->setSameMode(true); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/bitwise/shift.cpp b/libnd4j/include/ops/declarable/generic/bitwise/shift.cpp index 7d0647e1b7bf..47ff4ee5b294 100644 --- a/libnd4j/include/ops/declarable/generic/bitwise/shift.cpp +++ b/libnd4j/include/ops/declarable/generic/bitwise/shift.cpp @@ -26,25 +26,26 @@ #include namespace sd { - namespace ops { - BROADCASTABLE_OP_IMPL(shift_bits, 0, 0) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); +namespace ops { +BROADCASTABLE_OP_IMPL(shift_bits, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); - BROADCAST_CHECK_EMPTY(x,y,z); + BROADCAST_CHECK_EMPTY(x, y, z); - x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::ShiftLeft, pairwise::ShiftLeft, broadcast::ShiftLeft), *y, *z, false); + x->applyTrueBroadcast( + BroadcastIntOpsTuple::custom(scalar::ShiftLeft, pairwise::ShiftLeft, + broadcast::ShiftLeft), + *y, *z, false); - return Status::OK(); - } + return Status::OK(); +} - DECLARE_TYPES(shift_bits) { - getOpDescriptor() - ->setAllowedInputTypes({ALL_INTS}) - ->setSameMode(true); - } - } +DECLARE_TYPES(shift_bits) { + getOpDescriptor()->setAllowedInputTypes({ALL_INTS})->setSameMode(true); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/bitwise/toggle_bits.cpp b/libnd4j/include/ops/declarable/generic/bitwise/toggle_bits.cpp index 0ba6fbcc7d52..ffe81cafa457 100644 --- a/libnd4j/include/ops/declarable/generic/bitwise/toggle_bits.cpp +++ b/libnd4j/include/ops/declarable/generic/bitwise/toggle_bits.cpp @@ -26,28 +26,30 @@ #include namespace sd { - namespace ops { - OP_IMPL(toggle_bits, -1, -1, true) { - - for (int i = 0; i < block.width(); i++) { - auto x = INPUT_VARIABLE(i); - auto z = OUTPUT_VARIABLE(i); - - REQUIRE_TRUE(x->dataType() == z->dataType(), 0, "Toggle bits requires input and output to have same type"); - REQUIRE_TRUE(x->isZ(),0, "Toggle bits requires input and output to be integer type (int8, int16, int32, int64)"); - - helpers::__toggle_bits(block.launchContext(), *x, *z); - } - return Status::OK(); - } - - DECLARE_TYPES(toggle_bits) { - getOpDescriptor() - ->setAllowedInputTypes({ALL_INTS}) - ->setAllowedOutputTypes({ALL_INTS}) - ->setSameMode(false); - } - } +namespace ops { +OP_IMPL(toggle_bits, -1, -1, true) { + for (int i = 0; i < block.width(); i++) { + auto x = INPUT_VARIABLE(i); + auto z = OUTPUT_VARIABLE(i); + + REQUIRE_TRUE(x->dataType() == z->dataType(), 0, + "Toggle bits requires input and output to have same type"); + REQUIRE_TRUE(x->isZ(), 0, + "Toggle bits requires input and output to be integer type " + "(int8, int16, int32, int64)"); + + helpers::__toggle_bits(block.launchContext(), *x, *z); + } + return Status::OK(); } +DECLARE_TYPES(toggle_bits) { + getOpDescriptor() + ->setAllowedInputTypes({ALL_INTS}) + ->setAllowedOutputTypes({ALL_INTS}) + ->setSameMode(false); +} +} // namespace ops +} // namespace sd + #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/blas/axpy.cpp b/libnd4j/include/ops/declarable/generic/blas/axpy.cpp index 7c115059991a..cfee30389f5b 100644 --- a/libnd4j/include/ops/declarable/generic/blas/axpy.cpp +++ b/libnd4j/include/ops/declarable/generic/blas/axpy.cpp @@ -24,38 +24,41 @@ #include namespace sd { - namespace ops { - CONFIGURABLE_OP_IMPL(axpy, 2, 1, false, -2, 0) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); - - REQUIRE_TRUE(x->isSameShape(y),0, "Axpy: both arguments should have the same shape"); - REQUIRE_TRUE(x->dataType() == y->dataType() && x->dataType() == z->dataType(), 0, "Axpy: all arguments must have the same data type"); - - double a = 1.0; - - if (block.width() > 2) { - auto alpha = INPUT_VARIABLE(2); - REQUIRE_TRUE(alpha->isScalar(), 0, "Axpy: alpha argument should be scalar or TArg"); - } else if (block.getTArguments()->size() > 0) { - a = T_ARG(0); - } - - ExtraArguments arguments({a}); - - y->applyPairwiseTransform(pairwise::Axpy, *x, *z, &arguments); - - return ND4J_STATUS_OK; - } - - DECLARE_TYPES(axpy) { - getOpDescriptor() - ->setAllowedInputTypes(0, {ALL_FLOATS}) - ->setAllowedInputTypes(1, {ALL_FLOATS}) - ->setAllowedOutputTypes(0, {ALL_FLOATS}); - } - } +namespace ops { +CONFIGURABLE_OP_IMPL(axpy, 2, 1, false, -2, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); + + REQUIRE_TRUE(x->isSameShape(y), 0, + "Axpy: both arguments should have the same shape"); + REQUIRE_TRUE(x->dataType() == y->dataType() && x->dataType() == z->dataType(), + 0, "Axpy: all arguments must have the same data type"); + + double a = 1.0; + + if (block.width() > 2) { + auto alpha = INPUT_VARIABLE(2); + REQUIRE_TRUE(alpha->isScalar(), 0, + "Axpy: alpha argument should be scalar or TArg"); + } else if (block.numT() > 0) { + a = T_ARG(0); + } + + ExtraArguments arguments({a}); + + y->applyPairwiseTransform(pairwise::Axpy, *x, *z, &arguments); + + return Status::OK(); +} + +DECLARE_TYPES(axpy) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_FLOATS}) + ->setAllowedInputTypes(1, {ALL_FLOATS}) + ->setAllowedOutputTypes(0, {ALL_FLOATS}); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/blas/batched_gemm.cpp b/libnd4j/include/ops/declarable/generic/blas/batched_gemm.cpp index 79227e2ba75b..da63c146e363 100644 --- a/libnd4j/include/ops/declarable/generic/blas/batched_gemm.cpp +++ b/libnd4j/include/ops/declarable/generic/blas/batched_gemm.cpp @@ -25,115 +25,135 @@ #include namespace sd { -namespace ops { +namespace ops { CUSTOM_OP_IMPL(batched_gemm, -1, -1, false, 0, 9) { - - int transA = INT_ARG(0); - int transB = INT_ARG(1); - int M = INT_ARG(2); - int N = INT_ARG(3); - int K = INT_ARG(4); - int ldA = INT_ARG(5); - int ldB = INT_ARG(6); - int ldC = INT_ARG(7); - int batchSize = INT_ARG(8); - - if (transA == 0) - transA = 111; - - if (transB == 0) - transB = 111; - - if (transA == 1) - transA = 112; - - if (transB == 1) - transB = 112; - - // basically A+B and 2 arrays of alpha and beta - int expectedWidth = batchSize * 2 + 2; - - REQUIRE_TRUE((transA == 111 || transA == 112) && (transB == 111 || transB == 112), 0, "BatchedGemm: valid values for transA and transB are: 0/1 or 111/112, for NoTrans/Trans respectively") - REQUIRE_TRUE(M > 0 && N > 0 && K > 0 && ldA > 0 && ldB > 0 && ldC > 0 && batchSize > 0, 0, ""); - REQUIRE_TRUE(block.width() == expectedWidth, 0, "BatchedGemm: expected number of input arrays is %i, but got %i instead", expectedWidth, block.width()); - - auto alpha = INPUT_VARIABLE(0); - auto beta = INPUT_VARIABLE(1); - - std::vector vA(batchSize); - std::vector vB(batchSize); - std::vector vC(batchSize); - - auto firstType = INPUT_VARIABLE(0)->dataType(); - for(int e = 0; e < batchSize; e++) { - vA[e] = INPUT_VARIABLE(e+2); - vB[e] = INPUT_VARIABLE(e+2+batchSize); - vC[e] = OUTPUT_VARIABLE(e); - - - REQUIRE_TRUE(firstType == vC[e]->dataType(), 0, "BatchedGemm: all inputs and outputs must have same data type"); - - REQUIRE_TRUE(vA[e]->rankOf() == 2, 0, "BatchedGemm: batch %i, rank of A should be equal to 2", e); - REQUIRE_TRUE(vB[e]->rankOf() == 2, 0, "BatchedGemm: batch %i, rank of B should be equal to 2", e); - REQUIRE_TRUE(vC[e]->rankOf() == 2, 0, "BatchedGemm: batch %i, rank of C should be equal to 2", e); - - REQUIRE_TRUE(M == vA[e]->sizeAt(0), 0, "BatchedGemm: batch %i, number of A.rows() should be equal to M", e); - REQUIRE_TRUE(N == vB[e]->sizeAt(1), 0, "BatchedGemm: batch %i, number of B.columns() should be equal to N", e); - REQUIRE_TRUE(K == vA[e]->sizeAt(1) && K == vB[e]->sizeAt(0), 0, "BatchedGemm: batch %i, number of A.columns() and B.rows() should be equal to K", e); - }; - - REQUIRE_TRUE(vA.size() == vB.size() && vA.size() == vC.size() && vA.size() == batchSize, 0, "BatchedGemm: mismatched numbers of A, B, C for unknown reason"); - - sd::ops::helpers::bgemm(vA, vB, vC, alpha, beta, transA, transB, M, N, K, ldA, ldB, ldC); - - return Status::OK(); + int transA = INT_ARG(0); + int transB = INT_ARG(1); + int M = INT_ARG(2); + int N = INT_ARG(3); + int K = INT_ARG(4); + int ldA = INT_ARG(5); + int ldB = INT_ARG(6); + int ldC = INT_ARG(7); + int batchSize = INT_ARG(8); + + if (transA == 0) transA = 111; + + if (transB == 0) transB = 111; + + if (transA == 1) transA = 112; + + if (transB == 1) transB = 112; + + // basically A+B and 2 arrays of alpha and beta + int expectedWidth = batchSize * 2 + 2; + + REQUIRE_TRUE( + (transA == 111 || transA == 112) && (transB == 111 || transB == 112), 0, + "BatchedGemm: valid values for transA and transB are: 0/1 or 111/112, " + "for NoTrans/Trans respectively") + REQUIRE_TRUE( + M > 0 && N > 0 && K > 0 && ldA > 0 && ldB > 0 && ldC > 0 && batchSize > 0, + 0, ""); + REQUIRE_TRUE( + block.width() == expectedWidth, 0, + "BatchedGemm: expected number of input arrays is %i, but got %i instead", + expectedWidth, block.width()); + + auto alpha = INPUT_VARIABLE(0); + auto beta = INPUT_VARIABLE(1); + + std::vector vA(batchSize); + std::vector vB(batchSize); + std::vector vC(batchSize); + + auto firstType = INPUT_VARIABLE(0)->dataType(); + for (int e = 0; e < batchSize; e++) { + vA[e] = INPUT_VARIABLE(e + 2); + vB[e] = INPUT_VARIABLE(e + 2 + batchSize); + vC[e] = OUTPUT_VARIABLE(e); + + REQUIRE_TRUE( + firstType == vC[e]->dataType(), 0, + "BatchedGemm: all inputs and outputs must have same data type"); + + REQUIRE_TRUE(vA[e]->rankOf() == 2, 0, + "BatchedGemm: batch %i, rank of A should be equal to 2", e); + REQUIRE_TRUE(vB[e]->rankOf() == 2, 0, + "BatchedGemm: batch %i, rank of B should be equal to 2", e); + REQUIRE_TRUE(vC[e]->rankOf() == 2, 0, + "BatchedGemm: batch %i, rank of C should be equal to 2", e); + + REQUIRE_TRUE( + M == vA[e]->sizeAt(0), 0, + "BatchedGemm: batch %i, number of A.rows() should be equal to M", e); + REQUIRE_TRUE( + N == vB[e]->sizeAt(1), 0, + "BatchedGemm: batch %i, number of B.columns() should be equal to N", e); + REQUIRE_TRUE(K == vA[e]->sizeAt(1) && K == vB[e]->sizeAt(0), 0, + "BatchedGemm: batch %i, number of A.columns() and B.rows() " + "should be equal to K", + e); + }; + + REQUIRE_TRUE(vA.size() == vB.size() && vA.size() == vC.size() && + vA.size() == batchSize, + 0, + "BatchedGemm: mismatched numbers of A, B, C for unknown reason"); + + sd::ops::helpers::bgemm(vA, vB, vC, alpha, beta, transA, transB, M, N, K, ldA, + ldB, ldC); + + return Status::OK(); }; - DECLARE_SHAPE_FN(batched_gemm) { - int transA = INT_ARG(0); - int transB = INT_ARG(1); - int M = INT_ARG(2); - int N = INT_ARG(3); - int K = INT_ARG(4); - int ldA = INT_ARG(5); - int ldB = INT_ARG(6); - int ldC = INT_ARG(7); - int batchSize = INT_ARG(8); - - auto firstType = ArrayOptions::dataType(inputShape->at(0)); - for (int e = 1; e < block.width(); e++) { - REQUIRE_TRUE(firstType == ArrayOptions::dataType(inputShape->at(1)), 0, "BatchedGemm: all inputs must have same data type"); - } - - auto shapeList = SHAPELIST(); - - if (!(M > 0 && N > 0 && K > 0 && ldA > 0 && ldB > 0 && ldC > 0 && batchSize > 0)) { - shapeList->push_back(ConstantShapeHelper::getInstance().createShapeInfo(ArrayOptions::dataType(inputShape->at(0)), 'c', {1, 1})); - return shapeList; - } - + int transA = INT_ARG(0); + int transB = INT_ARG(1); + int M = INT_ARG(2); + int N = INT_ARG(3); + int K = INT_ARG(4); + int ldA = INT_ARG(5); + int ldB = INT_ARG(6); + int ldC = INT_ARG(7); + int batchSize = INT_ARG(8); + + auto firstType = ArrayOptions::dataType(inputShape->at(0)); + for (int e = 1; e < block.width(); e++) { + REQUIRE_TRUE(firstType == ArrayOptions::dataType(inputShape->at(1)), 0, + "BatchedGemm: all inputs must have same data type"); + } + + auto shapeList = SHAPELIST(); + + if (!(M > 0 && N > 0 && K > 0 && ldA > 0 && ldB > 0 && ldC > 0 && + batchSize > 0)) { + shapeList->push_back(ConstantShapeHelper::getInstance().createShapeInfo( + ArrayOptions::dataType(inputShape->at(0)), 'c', {1, 1})); + return shapeList; + } - std::vector shape({M, N}); + std::vector shape({M, N}); - for (int e = 0; e < batchSize; e++) { - auto newShape = ConstantShapeHelper::getInstance().createShapeInfo(ArrayOptions::dataType(inputShape->at(0)), 'f', shape); - shapeList->push_back(newShape); - } + for (int e = 0; e < batchSize; e++) { + auto newShape = ConstantShapeHelper::getInstance().createShapeInfo( + ArrayOptions::dataType(inputShape->at(0)), 'f', shape); + shapeList->push_back(newShape); + } - return shapeList; + return shapeList; } DECLARE_TYPES(batched_gemm) { - getOpDescriptor() - ->setAllowedInputTypes({ALL_FLOATS}) -// ->setAllowedInputTypes(1, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}) - ->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes({ALL_FLOATS}) + // ->setAllowedInputTypes(1, {DataType::FLOAT32, + // DataType ::DOUBLE, DataType::HALF}) + ->setAllowedOutputTypes({ALL_FLOATS}); } - -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/blas/matmul.cpp b/libnd4j/include/ops/declarable/generic/blas/matmul.cpp index f8ee952a85c4..d237cf901128 100644 --- a/libnd4j/include/ops/declarable/generic/blas/matmul.cpp +++ b/libnd4j/include/ops/declarable/generic/blas/matmul.cpp @@ -23,67 +23,100 @@ #include #if NOT_EXCLUDED(OP_matmul) -#include #include +#include namespace sd { - namespace ops { +namespace ops { ////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(matmul, 2, 1, false, 0, -2) { - - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); - - int iSize = (int) block.getIArguments()->size(); - int transX = iSize > 0 ? INT_ARG(0) : 0; - int transY = iSize > 1 ? INT_ARG(1) : 0; - const int transZ = iSize > 2 ? INT_ARG(2) : 0; - // optional use alpha nad beta - iSize = (int)block.getTArguments()->size(); - double alpha = iSize > 0 ? T_ARG(0) : 1.0; - double beta = iSize > 1 ? T_ARG(1) : 0.0; - - const int xRank = x->rankOf(); - const int yRank = y->rankOf(); - const int zRank = z->rankOf(); - - if (transZ) { - x = INPUT_VARIABLE(1); - y = INPUT_VARIABLE(0); - bool temp = transX; - transX = !transY; - transY = !temp; - } - - const int xLastDim = transX ? -2 : -1; - const int yLastDim = transY ? -2 : -1; - const int xLastButOneDim = transX ? -1 : -2; - const int yLastButOneDim = transY ? -1 : -2; - - // ******* input validation ******* // - REQUIRE_TRUE(xRank > 0 && yRank > 0, 0, "MATMUL OP: input arrays must have rank bigger than 0 (should not be scalars), but got instead: x rank = %i, y rank = %i !", xRank, yRank); - - if (xRank == 1 && yRank == 1) { // dot case, output is scalar (or vector with length = 1) - REQUIRE_TRUE(x->lengthOf() == y->lengthOf(), 0, "MATMUL OP: since input arrays are vectors they must have the same length, but got x length = %i, y length = %i !", x->lengthOf(), y->lengthOf()); - } else if (xRank == 1 && yRank == 2) { // vector x matrix, i.e. [4] x [4,5] = [5], output is vector - REQUIRE_TRUE(x->lengthOf() == y->sizeAt(yLastButOneDim), 0, "MATMUL OP: input arrays have inconsistent shapes for vector-matrix product: x %s, y %s !", ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str()); - } else if (xRank == 2 && yRank == 1) { // matrix x vector , i.e. [4,5] x [5] = [4], output is vector - REQUIRE_TRUE(x->sizeAt(xLastDim) == y->lengthOf(), 0, "MATMUL OP: input arrays have inconsistent shapes for matrix-vector product: x %s, y %s !", ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str()); - } else { - REQUIRE_TRUE(xRank == yRank && yRank == zRank, 0, "MATMUL OP: input and output arrays must have the same rank, but got instead: x rank = %i, y rank = %i, z rank = %i !", xRank, yRank, zRank); - REQUIRE_TRUE(x->sizeAt(xLastDim) == y->sizeAt(yLastButOneDim) && x->sizeAt(xLastButOneDim) == z->sizeAt(-2) && y->sizeAt(yLastDim) == z->sizeAt(-1), 0, "MATMUL OP: input/output arrays have inconsistent shapes for matrix product: x %s, y %s, z %s !", ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str(), ShapeUtils::shapeAsString(z).c_str()); - - if (xRank > 2) // outer dims must be the same - for (int i = 0; i < xRank - 2; ++i) - REQUIRE_TRUE(x->sizeAt(i) == y->sizeAt(i) && y->sizeAt(i) == z->sizeAt(i), 0, "MATMUL OP: input/output arrays have inconsistent shapes for matrix product: x %s, y %s, z %s !", ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str(), ShapeUtils::shapeAsString(z).c_str()); - } - // ******* end of input validation ******* // - - MmulHelper::matmul(x, y, z, transX, transY, alpha, beta); - - return Status::OK(); + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); + + int iSize = (int)block.numI(); + int transX = iSize > 0 ? INT_ARG(0) : 0; + int transY = iSize > 1 ? INT_ARG(1) : 0; + const int transZ = iSize > 2 ? INT_ARG(2) : 0; + // optional use alpha nad beta + iSize = (int)block.numT(); + double alpha = iSize > 0 ? T_ARG(0) : 1.0; + double beta = iSize > 1 ? T_ARG(1) : 0.0; + + const int xRank = x->rankOf(); + const int yRank = y->rankOf(); + const int zRank = z->rankOf(); + + if (transZ) { + x = INPUT_VARIABLE(1); + y = INPUT_VARIABLE(0); + bool temp = transX; + transX = !transY; + transY = !temp; + } + + const int xLastDim = transX ? -2 : -1; + const int yLastDim = transY ? -2 : -1; + const int xLastButOneDim = transX ? -1 : -2; + const int yLastButOneDim = transY ? -1 : -2; + + // ******* input validation ******* // + REQUIRE_TRUE(xRank > 0 && yRank > 0, 0, + "MATMUL OP: input arrays must have rank bigger than 0 (should " + "not be scalars), but got instead: x rank = %i, y rank = %i !", + xRank, yRank); + + if (xRank == 1 && + yRank == 1) { // dot case, output is scalar (or vector with length = 1) + REQUIRE_TRUE(x->lengthOf() == y->lengthOf(), 0, + "MATMUL OP: since input arrays are vectors they must have the " + "same length, but got x length = %i, y length = %i !", + x->lengthOf(), y->lengthOf()); + } else if (xRank == 1 && yRank == 2) { // vector x matrix, i.e. [4] x [4,5] = + // [5], output is vector + REQUIRE_TRUE(x->lengthOf() == y->sizeAt(yLastButOneDim), 0, + "MATMUL OP: input arrays have inconsistent shapes for " + "vector-matrix product: x %s, y %s !", + ShapeUtils::shapeAsString(x).c_str(), + ShapeUtils::shapeAsString(y).c_str()); + } else if (xRank == 2 && yRank == 1) { // matrix x vector , i.e. [4,5] x [5] + // = [4], output is vector + REQUIRE_TRUE(x->sizeAt(xLastDim) == y->lengthOf(), 0, + "MATMUL OP: input arrays have inconsistent shapes for " + "matrix-vector product: x %s, y %s !", + ShapeUtils::shapeAsString(x).c_str(), + ShapeUtils::shapeAsString(y).c_str()); + } else { + REQUIRE_TRUE(xRank == yRank && yRank == zRank, 0, + "MATMUL OP: input and output arrays must have the same rank, " + "but got instead: x rank = %i, y rank = %i, z rank = %i !", + xRank, yRank, zRank); + REQUIRE_TRUE(x->sizeAt(xLastDim) == y->sizeAt(yLastButOneDim) && + x->sizeAt(xLastButOneDim) == z->sizeAt(-2) && + y->sizeAt(yLastDim) == z->sizeAt(-1), + 0, + "MATMUL OP: input/output arrays have inconsistent shapes for " + "matrix product: x %s, y %s, z %s !", + ShapeUtils::shapeAsString(x).c_str(), + ShapeUtils::shapeAsString(y).c_str(), + ShapeUtils::shapeAsString(z).c_str()); + + if (xRank > 2) // outer dims must be the same + for (int i = 0; i < xRank - 2; ++i) + REQUIRE_TRUE( + x->sizeAt(i) == y->sizeAt(i) && y->sizeAt(i) == z->sizeAt(i), 0, + "MATMUL OP: input/output arrays have inconsistent shapes for " + "matrix product: x %s, y %s, z %s !", + ShapeUtils::shapeAsString(x).c_str(), + ShapeUtils::shapeAsString(y).c_str(), + ShapeUtils::shapeAsString(z).c_str()); + } + // ******* end of input validation ******* // + + MmulHelper::matmul(x, y, z, transX, transY, alpha, beta); + + return Status::OK(); } DECLARE_SYN(mMul, matmul); @@ -98,111 +131,113 @@ DECLARE_SYN(dot, matmul); ////////////////////////////////////////////////////////////////////// DECLARE_SHAPE_FN(matmul) { - - auto xShapeInfo = inputShape->at(0); - auto yShapeInfo = inputShape->at(1); - - const int iSize = (int) block.getIArguments()->size(); - int transX = iSize > 0 ? INT_ARG(0) : 0; - int transY = iSize > 1 ? INT_ARG(1) : 0; - const int transZ = iSize > 2 ? INT_ARG(2) : 0; - - REQUIRE_TRUE(xShapeInfo[0] > 0 && yShapeInfo[0] > 0, 0, - "MATMUL OP: input arrays must have rank bigger than 0 (should not be scalars), but got instead: x rank = %i, y rank = %i !", - xShapeInfo[0], yShapeInfo[0]); - - if (transZ) { - xShapeInfo = inputShape->at(1); - yShapeInfo = inputShape->at(0); - bool temp = transX; - transX = !transY; - transY = !temp; - } - - auto zShapeOnly = ShapeUtils::evalShapeForMatmul(xShapeInfo, yShapeInfo, transX, transY); - - auto dtypeX = ArrayOptions::dataType(xShapeInfo); - auto dtypeY = ArrayOptions::dataType(yShapeInfo); - - auto xOrder = shape::order(xShapeInfo); - auto yOrder = shape::order(yShapeInfo); - auto zOrder = xOrder == 'c' && yOrder == 'c' ? 'c' : 'f'; - - // we just pick the higher data type out of X and Y - auto dtypeZ = dtypeX > dtypeY ? dtypeX : dtypeY; - - auto newShape = ConstantShapeHelper::getInstance().createShapeInfo(dtypeZ, zOrder, zShapeOnly); - return SHAPELIST(newShape); + auto xShapeInfo = inputShape->at(0); + auto yShapeInfo = inputShape->at(1); + + const int iSize = (int)block.numI(); + int transX = iSize > 0 ? INT_ARG(0) : 0; + int transY = iSize > 1 ? INT_ARG(1) : 0; + const int transZ = iSize > 2 ? INT_ARG(2) : 0; + + REQUIRE_TRUE(xShapeInfo[0] > 0 && yShapeInfo[0] > 0, 0, + "MATMUL OP: input arrays must have rank bigger than 0 (should " + "not be scalars), but got instead: x rank = %i, y rank = %i !", + xShapeInfo[0], yShapeInfo[0]); + + if (transZ) { + xShapeInfo = inputShape->at(1); + yShapeInfo = inputShape->at(0); + bool temp = transX; + transX = !transY; + transY = !temp; + } + + auto zShapeOnly = + ShapeUtils::evalShapeForMatmul(xShapeInfo, yShapeInfo, transX, transY); + + auto dtypeX = ArrayOptions::dataType(xShapeInfo); + auto dtypeY = ArrayOptions::dataType(yShapeInfo); + + auto xOrder = shape::order(xShapeInfo); + auto yOrder = shape::order(yShapeInfo); + auto zOrder = xOrder == 'c' && yOrder == 'c' ? 'c' : 'f'; + + // we just pick the higher data type out of X and Y + auto dtypeZ = dtypeX > dtypeY ? dtypeX : dtypeY; + + auto newShape = ConstantShapeHelper::getInstance().createShapeInfo( + dtypeZ, zOrder, zShapeOnly); + return SHAPELIST(newShape); } ////////////////////////////////////////////////////////////////////// DECLARE_TYPES(matmul) { - getOpDescriptor() - ->setAllowedInputTypes(0, {ALL_FLOATS, ALL_INTS}) - ->setAllowedInputTypes(1, {ALL_FLOATS, ALL_INTS}) - ->setAllowedOutputTypes(0, {ALL_FLOATS, ALL_INTS}); + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_FLOATS, ALL_INTS}) + ->setAllowedInputTypes(1, {ALL_FLOATS, ALL_INTS}) + ->setAllowedOutputTypes(0, {ALL_FLOATS, ALL_INTS}); } ////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(matmul_bp, 3, 2, false, 0, -2) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto eps = INPUT_VARIABLE(2); - auto dldx = OUTPUT_VARIABLE(0); - auto dldy = OUTPUT_VARIABLE(1); - - int iSize = (int) block.getIArguments()->size(); - int transX = iSize > 0 ? INT_ARG(0) : 0; - int transY = iSize > 1 ? INT_ARG(1) : 0; - const int transZ = iSize > 2 ? INT_ARG(2) : 0; - - // optional use alpha nad beta - iSize = (int)block.getTArguments()->size(); - - double alpha = iSize > 0 ? T_ARG(0) : 1.0; - double beta = iSize > 1 ? T_ARG(1) : 0.0; - -/* -In: x=[a,b], y=[b,c] -tX tY tZ x y z dz dLdx dLdy -F F F [a,b] [b,c] [a,c] [a,c] [a,c]*[b,c]T = [a,b] x*yT [a,b]T*[a,c] = [b,c] xT*y -T F F [b,a] [b,c] [a,c] [a,c] ([a,c]*[b,c]T)T = [b,a] (x*yT)T [b,a]*[a,c] = [b,c] x*y -F T F [a,b] [c,b] [a,c] [a,c] ([a,c]*[c,b]) = [a,b] x*y [a,b]T*[a,c] = [b,c] ->T xT*y -T T F [b,a] [c,b] [a,c] [a,c] ([a,c]*[c,b])T = [b,a] (x*y)T [b,a]*[a,c] = [b,c] ->T x*y -F F T [a,b] [b,c] [c,a] [c,a] -*/ - - - sd::ops::matmul op; - op.execute({eps, y}, {dldx}, {alpha, beta}, {transZ, !transY, transX}, {}); - op.execute({x, eps}, {dldy}, {alpha, beta}, {!transX, transZ, transY}, {}); - - return Status::OK(); + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto eps = INPUT_VARIABLE(2); + auto dldx = OUTPUT_VARIABLE(0); + auto dldy = OUTPUT_VARIABLE(1); + + int iSize = (int)block.numI(); + int transX = iSize > 0 ? INT_ARG(0) : 0; + int transY = iSize > 1 ? INT_ARG(1) : 0; + const int transZ = iSize > 2 ? INT_ARG(2) : 0; + + // optional use alpha nad beta + iSize = (int)block.numT(); + + double alpha = iSize > 0 ? T_ARG(0) : 1.0; + double beta = iSize > 1 ? T_ARG(1) : 0.0; + + /* + In: x=[a,b], y=[b,c] + tX tY tZ x y z dz dLdx dLdy F F F [a,b] + [b,c] [a,c] [a,c] [a,c]*[b,c]T = [a,b] x*yT [a,b]T*[a,c] = + [b,c] xT*y T F F [b,a] [b,c] [a,c] [a,c] ([a,c]*[b,c]T)T = + [b,a] (x*yT)T [b,a]*[a,c] = [b,c] x*y F T F [a,b] [c,b] + [a,c] [a,c] ([a,c]*[c,b]) = [a,b] x*y [a,b]T*[a,c] = + [b,c] ->T xT*y T T F [b,a] [c,b] [a,c] [a,c] ([a,c]*[c,b])T = + [b,a] (x*y)T [b,a]*[a,c] = [b,c] ->T x*y F F T [a,b] [b,c] + [c,a] [c,a] + */ + + sd::ops::matmul op; + op.execute({eps, y}, {dldx}, {alpha, beta}, {transZ, !transY, transX}, {}); + op.execute({x, eps}, {dldy}, {alpha, beta}, {!transX, transZ, transY}, {}); + + return Status::OK(); } ////////////////////////////////////////////////////////////////////// DECLARE_SHAPE_FN(matmul_bp) { - Nd4jLong *xShapeInfo; - Nd4jLong *yShapeInfo; + Nd4jLong *xShapeInfo; + Nd4jLong *yShapeInfo; - COPY_SHAPE(inputShape->at(0), xShapeInfo); - COPY_SHAPE(inputShape->at(1), yShapeInfo); + COPY_SHAPE(inputShape->at(0), xShapeInfo); + COPY_SHAPE(inputShape->at(1), yShapeInfo); - return SHAPELIST(CONSTANT(xShapeInfo), CONSTANT(yShapeInfo)); + return SHAPELIST(CONSTANT(xShapeInfo), CONSTANT(yShapeInfo)); } ////////////////////////////////////////////////////////////////////// DECLARE_TYPES(matmul_bp) { - getOpDescriptor() - ->setAllowedInputTypes(0, {ALL_FLOATS}) - ->setAllowedInputTypes(1, {ALL_FLOATS}) - ->setAllowedInputTypes(2, {ALL_FLOATS}) - ->setAllowedOutputTypes(0, {ALL_FLOATS}) - ->setAllowedOutputTypes(1, {ALL_FLOATS}); -} - -} + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_FLOATS}) + ->setAllowedInputTypes(1, {ALL_FLOATS}) + ->setAllowedInputTypes(2, {ALL_FLOATS}) + ->setAllowedOutputTypes(0, {ALL_FLOATS}) + ->setAllowedOutputTypes(1, {ALL_FLOATS}); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/blas/tensormmul.cpp b/libnd4j/include/ops/declarable/generic/blas/tensormmul.cpp index 0ae64b8cd3ec..a349038d2261 100644 --- a/libnd4j/include/ops/declarable/generic/blas/tensormmul.cpp +++ b/libnd4j/include/ops/declarable/generic/blas/tensormmul.cpp @@ -21,174 +21,188 @@ #include #if NOT_EXCLUDED(OP_tensormmul) -#include +#include #include #include -#include +#include namespace sd { -namespace ops { +namespace ops { //////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(tensormmul, 2, 1, false, 0, -1) { + auto a = INPUT_VARIABLE(0); + auto b = INPUT_VARIABLE(1); - auto a = INPUT_VARIABLE(0); - auto b = INPUT_VARIABLE(1); - - auto c = OUTPUT_VARIABLE(0); + auto c = OUTPUT_VARIABLE(0); - REQUIRE_TRUE(a->dataType() == b->dataType(), 0, "tensormmul: A, B and C data types must be the same"); + REQUIRE_TRUE(a->dataType() == b->dataType(), 0, + "tensormmul: A, B and C data types must be the same"); - // building axes - int axe0_size = INT_ARG(0); - int axe1_size = INT_ARG(axe0_size+1); - std::vector axes_0(axe0_size), axes_1(axe1_size); - for (int e = 0; e < axe0_size; e++) - axes_0[e] = (int)INT_ARG(e + 1); + // building axes + int axe0_size = INT_ARG(0); + int axe1_size = INT_ARG(axe0_size + 1); + std::vector axes_0(axe0_size), axes_1(axe1_size); + for (int e = 0; e < axe0_size; e++) axes_0[e] = (int)INT_ARG(e + 1); - for (int e = 0; e < axe1_size; e++) - axes_1[e] = (int)INT_ARG(e + axe0_size + 2); + for (int e = 0; e < axe1_size; e++) + axes_1[e] = (int)INT_ARG(e + axe0_size + 2); - nd4j_verbose("axe0: %i; axe1: %i;\n", axes_0.size(), axes_1.size()); + nd4j_verbose("axe0: %i; axe1: %i;\n", axes_0.size(), axes_1.size()); - MmulHelper::tensorDot(a, b, c, axes_0, axes_1); - return Status::OK(); + MmulHelper::tensorDot(a, b, c, axes_0, axes_1); + return Status::OK(); } DECLARE_SYN(tensordot, tensormmul); //////////////////////////////////////////////////////////////////////// DECLARE_SHAPE_FN(tensormmul) { - - auto aShapeInfo = inputShape->at(0); - auto bShapeInfo = inputShape->at(1); - - REQUIRE_TRUE(ArrayOptions::dataType(aShapeInfo) == ArrayOptions::dataType(bShapeInfo), 0, "tensormmul: A and B data types must be the same"); - - // building axes - int axe0_size = INT_ARG(0); - int axe1_size = INT_ARG(axe0_size+1); - std::vector axes_0(axe0_size), axes_1(axe1_size); - for (int e = 0; e < axe0_size; e++) - axes_0[e] = (int) INT_ARG(e+1); - - for (int e = 0; e < axe1_size; e++) - axes_1[e] = (int) INT_ARG(e + axe0_size + 2); - - // evaluate shapes - std::vector permutAt, permutBt; - std::vector shapeAt, shapeBt; - auto outShape = sd::ShapeUtils::evalShapeForTensorDot(aShapeInfo, bShapeInfo, axes_0, axes_1, permutAt, permutBt, shapeAt, shapeBt); - - return SHAPELIST(ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(ArrayOptions::dataType(aShapeInfo), 'c', outShape))); + auto aShapeInfo = inputShape->at(0); + auto bShapeInfo = inputShape->at(1); + + REQUIRE_TRUE( + ArrayOptions::dataType(aShapeInfo) == ArrayOptions::dataType(bShapeInfo), + 0, "tensormmul: A and B data types must be the same"); + + // building axes + int axe0_size = INT_ARG(0); + int axe1_size = INT_ARG(axe0_size + 1); + std::vector axes_0(axe0_size), axes_1(axe1_size); + for (int e = 0; e < axe0_size; e++) axes_0[e] = (int)INT_ARG(e + 1); + + for (int e = 0; e < axe1_size; e++) + axes_1[e] = (int)INT_ARG(e + axe0_size + 2); + + // evaluate shapes + std::vector permutAt, permutBt; + std::vector shapeAt, shapeBt; + auto outShape = sd::ShapeUtils::evalShapeForTensorDot( + aShapeInfo, bShapeInfo, axes_0, axes_1, permutAt, permutBt, shapeAt, + shapeBt); + + return SHAPELIST(ConstantShapeHelper::getInstance().createShapeInfo( + ShapeDescriptor(ArrayOptions::dataType(aShapeInfo), 'c', outShape))); } //////////////////////////////////////////////////////////////////////// DECLARE_TYPES(tensormmul) { - getOpDescriptor() - ->setAllowedInputTypes(0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}) - ->setAllowedInputTypes(1, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}) - ->setAllowedOutputTypes(0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}); + getOpDescriptor() + ->setAllowedInputTypes( + 0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}) + ->setAllowedInputTypes( + 1, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}) + ->setAllowedOutputTypes( + 0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}); } //////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(tensormmul_bp, 3, 2, false, 0, -1) { + auto A = INPUT_VARIABLE(0); + auto B = INPUT_VARIABLE(1); - auto A = INPUT_VARIABLE(0); - auto B = INPUT_VARIABLE(1); - - auto dLdC = INPUT_VARIABLE(2); + auto dLdC = INPUT_VARIABLE(2); - auto dLdA = OUTPUT_VARIABLE(0); - auto dLdB = OUTPUT_VARIABLE(1); + auto dLdA = OUTPUT_VARIABLE(0); + auto dLdB = OUTPUT_VARIABLE(1); - REQUIRE_TRUE( (A->dataType() == B->dataType() && (dLdC->dataType() == A->dataType())), 0, "tensormmul_bp: A, B and dLdC data types must be the same"); + REQUIRE_TRUE( + (A->dataType() == B->dataType() && (dLdC->dataType() == A->dataType())), + 0, "tensormmul_bp: A, B and dLdC data types must be the same"); - int axe0Size = INT_ARG(0); - int axe1Size = INT_ARG(axe0Size + 1); + int axe0Size = INT_ARG(0); + int axe1Size = INT_ARG(axe0Size + 1); - auto Arank = A->rankOf(); - auto Brank = B->rankOf(); - auto dLdCrank = dLdC->rankOf(); + auto Arank = A->rankOf(); + auto Brank = B->rankOf(); + auto dLdCrank = dLdC->rankOf(); - REQUIRE_TRUE((Arank >= axe0Size), 0, "tensormmul_bp: A rank must be the higher or same as input axes 0"); + REQUIRE_TRUE( + (Arank >= axe0Size), 0, + "tensormmul_bp: A rank must be the higher or same as input axes 0"); - REQUIRE_TRUE((Brank >= axe1Size), 0, "tensormmul_bp: B rank must be the higher or same as input axes 1"); + REQUIRE_TRUE( + (Brank >= axe1Size), 0, + "tensormmul_bp: B rank must be the higher or same as input axes 1"); - // building axes - std::vector axes0(axe0Size), axes1(axe1Size); - for (uint e = 0; e < axe0Size; e++) - axes0[e] = (int)INT_ARG(e + 1); - for (uint e = 0; e < axe1Size; e++) - axes1[e] = (int)INT_ARG(e + axe0Size + 2); + // building axes + std::vector axes0(axe0Size), axes1(axe1Size); + for (uint e = 0; e < axe0Size; e++) axes0[e] = (int)INT_ARG(e + 1); + for (uint e = 0; e < axe1Size; e++) axes1[e] = (int)INT_ARG(e + axe0Size + 2); - std::vector permutAt, permutBt; - std::vector shapeAt, shapeBt; + std::vector permutAt, permutBt; + std::vector shapeAt, shapeBt; - ShapeUtils::evalShapeForTensorDot(A, B, axes0, axes1, permutAt, permutBt, shapeAt, shapeBt); + ShapeUtils::evalShapeForTensorDot(A, B, axes0, axes1, permutAt, permutBt, + shapeAt, shapeBt); - // special case for scalar value - if (dLdC->isScalar()) { + // special case for scalar value + if (dLdC->isScalar()) { + dLdA->assign((*dLdC) * *B); + dLdB->assign((*dLdC) * *A); - dLdA->assign((*dLdC) * *B); - dLdB->assign((*dLdC) * *A); - - return Status::OK(); - } + return Status::OK(); + } - std::vector axesA = ShapeUtils::evalDimsToExclude(Arank, axes0); - std::vector axesB = ShapeUtils::evalDimsToExclude(Brank, axes1); + std::vector axesA = ShapeUtils::evalDimsToExclude(Arank, axes0); + std::vector axesB = ShapeUtils::evalDimsToExclude(Brank, axes1); - // rank always have to be divided by 2 - std::vector axesAdLdC, axesBdLdC; - if (dLdCrank > 1) { - axesAdLdC.resize(dLdCrank / 2); - std::iota(axesAdLdC.begin(), axesAdLdC.end(), 0); - axesBdLdC = ShapeUtils::evalDimsToExclude(dLdCrank, axesAdLdC); - } - else { - axesAdLdC.push_back(0); - axesBdLdC.push_back(0); - } + // rank always have to be divided by 2 + std::vector axesAdLdC, axesBdLdC; + if (dLdCrank > 1) { + axesAdLdC.resize(dLdCrank / 2); + std::iota(axesAdLdC.begin(), axesAdLdC.end(), 0); + axesBdLdC = ShapeUtils::evalDimsToExclude(dLdCrank, axesAdLdC); + } else { + axesAdLdC.push_back(0); + axesBdLdC.push_back(0); + } - // calculate dLdA - MmulHelper::tensorDot(dLdC, B, dLdA, axesBdLdC, axesB, permutAt); + // calculate dLdA + MmulHelper::tensorDot(dLdC, B, dLdA, axesBdLdC, axesB, permutAt); - // calculate dLdB - MmulHelper::tensorDot(A, dLdC, dLdB, axesA, axesAdLdC, permutBt); + // calculate dLdB + MmulHelper::tensorDot(A, dLdC, dLdB, axesA, axesAdLdC, permutBt); - return Status::OK(); + return Status::OK(); } //////////////////////////////////////////////////////////////////////// DECLARE_SHAPE_FN(tensormmul_bp) { + auto aShapeInfo = inputShape->at(0); + auto bShapeInfo = inputShape->at(1); + auto dLShapeInfo = inputShape->at(2); - auto aShapeInfo = inputShape->at(0); - auto bShapeInfo = inputShape->at(1); - auto dLShapeInfo = inputShape->at(2); - - REQUIRE_TRUE((ArrayOptions::dataType(aShapeInfo) == ArrayOptions::dataType(bShapeInfo) && - (ArrayOptions::dataType(dLShapeInfo) == ArrayOptions::dataType(aShapeInfo))), 0, "tensormmul_bp: A, B and dLdC data types must be the same"); + REQUIRE_TRUE((ArrayOptions::dataType(aShapeInfo) == + ArrayOptions::dataType(bShapeInfo) && + (ArrayOptions::dataType(dLShapeInfo) == + ArrayOptions::dataType(aShapeInfo))), + 0, "tensormmul_bp: A, B and dLdC data types must be the same"); - Nd4jLong* dLdAShapeInfo = nullptr; - Nd4jLong* dLdBShapeInfo = nullptr; + Nd4jLong* dLdAShapeInfo = nullptr; + Nd4jLong* dLdBShapeInfo = nullptr; - COPY_SHAPE(aShapeInfo, dLdAShapeInfo); - COPY_SHAPE(bShapeInfo, dLdBShapeInfo); + COPY_SHAPE(aShapeInfo, dLdAShapeInfo); + COPY_SHAPE(bShapeInfo, dLdBShapeInfo); - return SHAPELIST(CONSTANT(dLdAShapeInfo), CONSTANT(dLdBShapeInfo)); + return SHAPELIST(CONSTANT(dLdAShapeInfo), CONSTANT(dLdBShapeInfo)); } //////////////////////////////////////////////////////////////////////// DECLARE_TYPES(tensormmul_bp) { - getOpDescriptor() - ->setAllowedInputTypes(0, { DataType::FLOAT32, DataType::DOUBLE, DataType::HALF }) // maybe better ALL_FLOATS - ->setAllowedInputTypes(1, { DataType::FLOAT32, DataType::DOUBLE, DataType::HALF }) - ->setAllowedInputTypes(2, { DataType::FLOAT32, DataType::DOUBLE, DataType::HALF }) - ->setAllowedOutputTypes(0, { DataType::FLOAT32, DataType::DOUBLE, DataType::HALF }) - ->setAllowedOutputTypes(1, { DataType::FLOAT32, DataType::DOUBLE, DataType::HALF }); -} -} + getOpDescriptor() + ->setAllowedInputTypes(0, {DataType::FLOAT32, DataType::DOUBLE, + DataType::HALF}) // maybe better ALL_FLOATS + ->setAllowedInputTypes( + 1, {DataType::FLOAT32, DataType::DOUBLE, DataType::HALF}) + ->setAllowedInputTypes( + 2, {DataType::FLOAT32, DataType::DOUBLE, DataType::HALF}) + ->setAllowedOutputTypes( + 0, {DataType::FLOAT32, DataType::DOUBLE, DataType::HALF}) + ->setAllowedOutputTypes( + 1, {DataType::FLOAT32, DataType::DOUBLE, DataType::HALF}); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/boolean/boolean_not.cpp b/libnd4j/include/ops/declarable/generic/boolean/boolean_not.cpp index 56047f16c9ec..5cefb3eddecc 100644 --- a/libnd4j/include/ops/declarable/generic/boolean/boolean_not.cpp +++ b/libnd4j/include/ops/declarable/generic/boolean/boolean_not.cpp @@ -24,22 +24,22 @@ #include namespace sd { - namespace ops { - OP_IMPL(boolean_not, 1, 1,true) { - auto x = INPUT_VARIABLE(0); - auto z = OUTPUT_VARIABLE(0); - - x->applyTransform(transform::Not, *z); - - return Status::OK(); - } - - DECLARE_TYPES(boolean_not) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::BOOL) - ->setAllowedOutputTypes(0, DataType::BOOL); - } - } +namespace ops { +OP_IMPL(boolean_not, 1, 1, true) { + auto x = INPUT_VARIABLE(0); + auto z = OUTPUT_VARIABLE(0); + + x->applyTransform(transform::Not, *z); + + return Status::OK(); +} + +DECLARE_TYPES(boolean_not) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::BOOL) + ->setAllowedOutputTypes(0, DataType::BOOL); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/boolean/choose.cpp b/libnd4j/include/ops/declarable/generic/boolean/choose.cpp index a28d8230b50a..0058e0acd71d 100644 --- a/libnd4j/include/ops/declarable/generic/boolean/choose.cpp +++ b/libnd4j/include/ops/declarable/generic/boolean/choose.cpp @@ -26,74 +26,75 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(choose, -1, 2, false, -2, -1) { - - int mode = INT_ARG(0); - auto result = OUTPUT_VARIABLE(0); - auto numResults = OUTPUT_VARIABLE(1); - - if (block.width() > 1) { - auto arg = INPUT_VARIABLE(0); - auto comp = INPUT_VARIABLE(1); - - helpers::chooseFunctorArray(block.launchContext(), arg, comp, mode, result, numResults); - - }//scalar case - else { - double scalar = T_ARG(0); - auto arg = INPUT_VARIABLE(0); - helpers::chooseFunctorScalar(block.launchContext(), arg, scalar, mode, result, numResults); - } - - - return Status::OK(); - } - - DECLARE_TYPES(choose) { - getOpDescriptor() - ->setAllowedInputTypes(0, {ALL_FLOATS}) - ->setAllowedInputTypes(1, {ALL_FLOATS}) - ->setAllowedOutputTypes(0, {ALL_FLOATS}) - ->setAllowedOutputTypes(1, {ALL_INTS}); - } - - DECLARE_SHAPE_FN(choose) { - Nd4jLong const* shape; - int rank; - int mode = INT_ARG(0); - auto numResults = NDArrayFactory::create(0L); - if(block.width() > 1) { - auto first = INPUT_VARIABLE(0); - auto second = INPUT_VARIABLE(1); - if(first->lengthOf() > second->lengthOf()) { - shape = first->shapeInfo(); - rank = first->rankOf(); - } - else { - shape = second->shapeInfo(); - rank = second->rankOf(); - } - - helpers::chooseFunctorArray(block.launchContext(), first, second, mode, nullptr, &numResults); - } - else { - auto first = INPUT_VARIABLE(0); - shape = first->shapeInfo(); - rank = first->rankOf(); - double scalar = T_ARG(0); - - helpers::chooseFunctorScalar(block.launchContext(), first, scalar, mode, nullptr, &numResults); - } - - auto newShape = ConstantShapeHelper::getInstance().vectorShapeInfo(numResults.e(0), ArrayOptions::dataType(inputShape->at(0))); - - auto shapeScalar = ConstantShapeHelper::getInstance().scalarShapeInfo(sd::DataType::INT64); - return SHAPELIST(newShape, shapeScalar); - } +namespace ops { +CUSTOM_OP_IMPL(choose, -1, 2, false, -2, -1) { + int mode = INT_ARG(0); + auto result = OUTPUT_VARIABLE(0); + auto numResults = OUTPUT_VARIABLE(1); + + if (block.width() > 1) { + auto arg = INPUT_VARIABLE(0); + auto comp = INPUT_VARIABLE(1); + + helpers::chooseFunctorArray(block.launchContext(), arg, comp, mode, result, + numResults); + + } // scalar case + else { + double scalar = T_ARG(0); + auto arg = INPUT_VARIABLE(0); + helpers::chooseFunctorScalar(block.launchContext(), arg, scalar, mode, + result, numResults); + } + + return Status::OK(); +} +DECLARE_TYPES(choose) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_FLOATS}) + ->setAllowedInputTypes(1, {ALL_FLOATS}) + ->setAllowedOutputTypes(0, {ALL_FLOATS}) + ->setAllowedOutputTypes(1, {ALL_INTS}); +} +DECLARE_SHAPE_FN(choose) { + Nd4jLong const* shape; + int rank; + int mode = INT_ARG(0); + auto numResults = NDArrayFactory::create(0L); + if (block.width() > 1) { + auto first = INPUT_VARIABLE(0); + auto second = INPUT_VARIABLE(1); + if (first->lengthOf() > second->lengthOf()) { + shape = first->shapeInfo(); + rank = first->rankOf(); + } else { + shape = second->shapeInfo(); + rank = second->rankOf(); } + + helpers::chooseFunctorArray(block.launchContext(), first, second, mode, + nullptr, &numResults); + } else { + auto first = INPUT_VARIABLE(0); + shape = first->shapeInfo(); + rank = first->rankOf(); + double scalar = T_ARG(0); + + helpers::chooseFunctorScalar(block.launchContext(), first, scalar, mode, + nullptr, &numResults); + } + + auto newShape = ConstantShapeHelper::getInstance().vectorShapeInfo( + numResults.e(0), ArrayOptions::dataType(inputShape->at(0))); + + auto shapeScalar = + ConstantShapeHelper::getInstance().scalarShapeInfo(sd::DataType::INT64); + return SHAPELIST(newShape, shapeScalar); } +} // namespace ops +} // namespace sd + #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/boolean/eq_scalar.cpp b/libnd4j/include/ops/declarable/generic/boolean/eq_scalar.cpp index c0623ebb70a0..4c733cd8a00d 100644 --- a/libnd4j/include/ops/declarable/generic/boolean/eq_scalar.cpp +++ b/libnd4j/include/ops/declarable/generic/boolean/eq_scalar.cpp @@ -24,26 +24,26 @@ #include namespace sd { - namespace ops { - BOOLEAN_OP_IMPL(eq_scalar, 2, true) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); +namespace ops { +BOOLEAN_OP_IMPL(eq_scalar, 2, true) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); - if (x->e(0) == y->e(0)) - return ND4J_STATUS_TRUE; - else - return ND4J_STATUS_FALSE; - } - DECLARE_SYN(Equals, eq_scalar); - //DECLARE_SYN(equals, eq_scalar); + if (x->e(0) == y->e(0)) + return ND4J_STATUS_TRUE; + else + return ND4J_STATUS_FALSE; +} +DECLARE_SYN(Equals, eq_scalar); +// DECLARE_SYN(equals, eq_scalar); - DECLARE_TYPES(eq_scalar) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, DataType::ANY) - ->setAllowedOutputTypes(0, DataType::BOOL); - } - } +DECLARE_TYPES(eq_scalar) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes(1, DataType::ANY) + ->setAllowedOutputTypes(0, DataType::BOOL); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/boolean/gt_scalar.cpp b/libnd4j/include/ops/declarable/generic/boolean/gt_scalar.cpp index d40c501d4f93..8d97d33acf75 100644 --- a/libnd4j/include/ops/declarable/generic/boolean/gt_scalar.cpp +++ b/libnd4j/include/ops/declarable/generic/boolean/gt_scalar.cpp @@ -24,26 +24,26 @@ #include namespace sd { - namespace ops { - BOOLEAN_OP_IMPL(gt_scalar, 2, true) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); +namespace ops { +BOOLEAN_OP_IMPL(gt_scalar, 2, true) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); - if (x->e(0) > y->e(0)) - return ND4J_STATUS_TRUE; - else - return ND4J_STATUS_FALSE; - } - //DECLARE_SYN(Greater, gt_scalar); - //DECLARE_SYN(greater, gt_scalar); + if (x->e(0) > y->e(0)) + return ND4J_STATUS_TRUE; + else + return ND4J_STATUS_FALSE; +} +// DECLARE_SYN(Greater, gt_scalar); +// DECLARE_SYN(greater, gt_scalar); - DECLARE_TYPES(gt_scalar) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, DataType::ANY) - ->setAllowedOutputTypes(0, DataType::BOOL); - } - } +DECLARE_TYPES(gt_scalar) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes(1, DataType::ANY) + ->setAllowedOutputTypes(0, DataType::BOOL); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/boolean/gte_scalar.cpp b/libnd4j/include/ops/declarable/generic/boolean/gte_scalar.cpp index d555f5d24437..ce854e4bbc42 100644 --- a/libnd4j/include/ops/declarable/generic/boolean/gte_scalar.cpp +++ b/libnd4j/include/ops/declarable/generic/boolean/gte_scalar.cpp @@ -24,26 +24,26 @@ #include namespace sd { - namespace ops { - BOOLEAN_OP_IMPL(gte_scalar, 2, true) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); +namespace ops { +BOOLEAN_OP_IMPL(gte_scalar, 2, true) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); - if (x->e(0) >= y->e(0)) - return ND4J_STATUS_TRUE; - else - return ND4J_STATUS_FALSE; - } - DECLARE_SYN(GreaterOrEquals, gte_scalar); - DECLARE_SYN(greaterOrEquals, gte_scalar); + if (x->e(0) >= y->e(0)) + return ND4J_STATUS_TRUE; + else + return ND4J_STATUS_FALSE; +} +DECLARE_SYN(GreaterOrEquals, gte_scalar); +DECLARE_SYN(greaterOrEquals, gte_scalar); - DECLARE_TYPES(gte_scalar) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, DataType::ANY) - ->setAllowedOutputTypes(0, DataType::BOOL); - } - } +DECLARE_TYPES(gte_scalar) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes(1, DataType::ANY) + ->setAllowedOutputTypes(0, DataType::BOOL); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/boolean/is_non_decreasing.cpp b/libnd4j/include/ops/declarable/generic/boolean/is_non_decreasing.cpp index 4dd8ca605b12..69dcc3b44b17 100644 --- a/libnd4j/include/ops/declarable/generic/boolean/is_non_decreasing.cpp +++ b/libnd4j/include/ops/declarable/generic/boolean/is_non_decreasing.cpp @@ -25,30 +25,30 @@ #include namespace sd { - namespace ops { - BOOLEAN_OP_IMPL(is_non_decreasing, 1, true) { - auto input = INPUT_VARIABLE(0); - - // in case of empty input there's nothing to do - if (input->isEmpty()) - return ND4J_STATUS_TRUE; - - bool isNonDecreasing = true; - - sd::ops::helpers::compare_elem(block.launchContext(), input, false, isNonDecreasing); - - if (isNonDecreasing) - return ND4J_STATUS_TRUE; - else - return ND4J_STATUS_FALSE; - } - - DECLARE_TYPES(is_non_decreasing) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedOutputTypes(0, DataType::BOOL); - } - } +namespace ops { +BOOLEAN_OP_IMPL(is_non_decreasing, 1, true) { + auto input = INPUT_VARIABLE(0); + + // in case of empty input there's nothing to do + if (input->isEmpty()) return ND4J_STATUS_TRUE; + + bool isNonDecreasing = true; + + sd::ops::helpers::compare_elem(block.launchContext(), input, false, + isNonDecreasing); + + if (isNonDecreasing) + return ND4J_STATUS_TRUE; + else + return ND4J_STATUS_FALSE; +} + +DECLARE_TYPES(is_non_decreasing) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedOutputTypes(0, DataType::BOOL); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/boolean/is_numeric_tensor.cpp b/libnd4j/include/ops/declarable/generic/boolean/is_numeric_tensor.cpp index 184b7b0a6da9..585bd6cb7ffe 100644 --- a/libnd4j/include/ops/declarable/generic/boolean/is_numeric_tensor.cpp +++ b/libnd4j/include/ops/declarable/generic/boolean/is_numeric_tensor.cpp @@ -25,20 +25,19 @@ #include namespace sd { - namespace ops { - BOOLEAN_OP_IMPL(is_numeric_tensor, 1, true) { +namespace ops { +BOOLEAN_OP_IMPL(is_numeric_tensor, 1, true) { + auto input = INPUT_VARIABLE(0); - auto input = INPUT_VARIABLE(0); - - return input->isR() || input->isZ() ? ND4J_STATUS_TRUE : ND4J_STATUS_FALSE; - } + return input->isR() || input->isZ() ? ND4J_STATUS_TRUE : ND4J_STATUS_FALSE; +} - DECLARE_TYPES(is_numeric_tensor) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedOutputTypes(0, DataType::BOOL); - } - } +DECLARE_TYPES(is_numeric_tensor) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedOutputTypes(0, DataType::BOOL); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/boolean/is_strictly_increasing.cpp b/libnd4j/include/ops/declarable/generic/boolean/is_strictly_increasing.cpp index 0c434cf571bd..580175ef2ba2 100644 --- a/libnd4j/include/ops/declarable/generic/boolean/is_strictly_increasing.cpp +++ b/libnd4j/include/ops/declarable/generic/boolean/is_strictly_increasing.cpp @@ -25,30 +25,30 @@ #include namespace sd { - namespace ops { - BOOLEAN_OP_IMPL(is_strictly_increasing, 1, true) { - auto input = INPUT_VARIABLE(0); - - // in case of empty input there's nothing to do - if (input->isEmpty()) - return ND4J_STATUS_TRUE; - - bool isStrictlyIncreasing = true; - - sd::ops::helpers::compare_elem(block.launchContext(), input, true, isStrictlyIncreasing); - - if (isStrictlyIncreasing) - return ND4J_STATUS_TRUE; - else - return ND4J_STATUS_FALSE; - } - - DECLARE_TYPES(is_strictly_increasing) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedOutputTypes(0, DataType::BOOL); - } - } +namespace ops { +BOOLEAN_OP_IMPL(is_strictly_increasing, 1, true) { + auto input = INPUT_VARIABLE(0); + + // in case of empty input there's nothing to do + if (input->isEmpty()) return ND4J_STATUS_TRUE; + + bool isStrictlyIncreasing = true; + + sd::ops::helpers::compare_elem(block.launchContext(), input, true, + isStrictlyIncreasing); + + if (isStrictlyIncreasing) + return ND4J_STATUS_TRUE; + else + return ND4J_STATUS_FALSE; +} + +DECLARE_TYPES(is_strictly_increasing) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedOutputTypes(0, DataType::BOOL); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/boolean/lt_scalar.cpp b/libnd4j/include/ops/declarable/generic/boolean/lt_scalar.cpp index 1c4f7ab2759b..a4dc91797fde 100644 --- a/libnd4j/include/ops/declarable/generic/boolean/lt_scalar.cpp +++ b/libnd4j/include/ops/declarable/generic/boolean/lt_scalar.cpp @@ -24,26 +24,26 @@ #include namespace sd { - namespace ops { - BOOLEAN_OP_IMPL(lt_scalar, 2, true) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); +namespace ops { +BOOLEAN_OP_IMPL(lt_scalar, 2, true) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); - if (x->e(0) < y->e(0)) - return ND4J_STATUS_TRUE; - else - return ND4J_STATUS_FALSE; - } - //DECLARE_SYN(Less, lt_scalar); - //DECLARE_SYN(less, lt_scalar); + if (x->e(0) < y->e(0)) + return ND4J_STATUS_TRUE; + else + return ND4J_STATUS_FALSE; +} +// DECLARE_SYN(Less, lt_scalar); +// DECLARE_SYN(less, lt_scalar); - DECLARE_TYPES(lt_scalar) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, DataType::ANY) - ->setAllowedOutputTypes(0, DataType::BOOL); - } - } +DECLARE_TYPES(lt_scalar) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes(1, DataType::ANY) + ->setAllowedOutputTypes(0, DataType::BOOL); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/boolean/lte_scalar.cpp b/libnd4j/include/ops/declarable/generic/boolean/lte_scalar.cpp index 07a72cfedfdc..7fe63df45aa1 100644 --- a/libnd4j/include/ops/declarable/generic/boolean/lte_scalar.cpp +++ b/libnd4j/include/ops/declarable/generic/boolean/lte_scalar.cpp @@ -24,26 +24,26 @@ #include namespace sd { - namespace ops { - BOOLEAN_OP_IMPL(lte_scalar, 2, true) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); +namespace ops { +BOOLEAN_OP_IMPL(lte_scalar, 2, true) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); - if (x->e(0) <= y->e(0)) - return ND4J_STATUS_TRUE; - else - return ND4J_STATUS_FALSE; - } - DECLARE_SYN(LessOrEquals, lte_scalar); - DECLARE_SYN(lessorequals, lte_scalar); + if (x->e(0) <= y->e(0)) + return ND4J_STATUS_TRUE; + else + return ND4J_STATUS_FALSE; +} +DECLARE_SYN(LessOrEquals, lte_scalar); +DECLARE_SYN(lessorequals, lte_scalar); - DECLARE_TYPES(lte_scalar) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, DataType::ANY) - ->setAllowedOutputTypes(0, DataType::BOOL); - } - } +DECLARE_TYPES(lte_scalar) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes(1, DataType::ANY) + ->setAllowedOutputTypes(0, DataType::BOOL); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/boolean/neq_scalar.cpp b/libnd4j/include/ops/declarable/generic/boolean/neq_scalar.cpp index 1c05b9fc13e4..6e9d8056a7f0 100644 --- a/libnd4j/include/ops/declarable/generic/boolean/neq_scalar.cpp +++ b/libnd4j/include/ops/declarable/generic/boolean/neq_scalar.cpp @@ -24,26 +24,26 @@ #include namespace sd { - namespace ops { - BOOLEAN_OP_IMPL(neq_scalar, 2, true) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); +namespace ops { +BOOLEAN_OP_IMPL(neq_scalar, 2, true) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); - if (x->e(0) != y->e(0)) - return ND4J_STATUS_TRUE; - else - return ND4J_STATUS_FALSE; - } - DECLARE_SYN(NotEquals, neq_scalar); - DECLARE_SYN(notequals, neq_scalar); + if (x->e(0) != y->e(0)) + return ND4J_STATUS_TRUE; + else + return ND4J_STATUS_FALSE; +} +DECLARE_SYN(NotEquals, neq_scalar); +DECLARE_SYN(notequals, neq_scalar); - DECLARE_TYPES(neq_scalar) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, DataType::ANY) - ->setAllowedOutputTypes(0, DataType::BOOL); - } - } +DECLARE_TYPES(neq_scalar) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes(1, DataType::ANY) + ->setAllowedOutputTypes(0, DataType::BOOL); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/boolean/select.cpp b/libnd4j/include/ops/declarable/generic/boolean/select.cpp index e8e257258b3c..9088278d719c 100644 --- a/libnd4j/include/ops/declarable/generic/boolean/select.cpp +++ b/libnd4j/include/ops/declarable/generic/boolean/select.cpp @@ -25,81 +25,88 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(select, 3, 1, false, 0, 0) { - auto cond = INPUT_VARIABLE(0); - auto x = INPUT_VARIABLE(1); - auto y = INPUT_VARIABLE(2); - - REQUIRE_TRUE(x->isSameShape(y), 0, "Select: X and Y shape should be equal"); - if (x->isScalar()) { - REQUIRE_TRUE(cond->isScalar(), 0, - "Select: Condition should gave either equal shape to X/Y first dimension or to be scalar"); - - auto z = OUTPUT_VARIABLE(0); - - if (y->isR()) { - auto v = !cond->e(0)? y->e(0) : x->e(0); - z->p(0, v); - } else { - auto v = !cond->e(0)? y->e(0) : x->e(0); - z->p(0, v); - } - } else { - bool same = cond->isSameShape(x); - REQUIRE_TRUE(cond->isScalar() || cond->lengthOf() == x->sizeAt(0) || same, 0, "Select: Condition should gave either equal shape to X/Y first dimension or to be scalar"); - if (same) { - auto z = OUTPUT_VARIABLE(0); - - for (int e = 0; e < cond->lengthOf(); e++) { - if (y->isR()) { - auto r = !cond->e(e) ? y->e(e) : x->e(e); - z->p(e, r); - } else { - auto r = !cond->e(e) ? y->e(e) : x->e(e); - z->p(e, r); - } - } - } else { - REQUIRE_TRUE(cond->lengthOf() == x->sizeAt(0), 0, "Condition length should be equal to the dim0 of x/y to act as TAD-mask, but got %d instead", cond->lengthOf()); - - auto z = OUTPUT_VARIABLE(0); - - auto dims = ShapeUtils::evalDimsToExclude(x->rankOf(), {0}); - auto tadsX = x->allTensorsAlongDimension(dims); - auto tadsY = y->allTensorsAlongDimension(dims); - auto tadsZ = z->allTensorsAlongDimension(dims); - - for (int e = 0; e < tadsX.size(); e++) { - if (!cond->e(e)) { - tadsZ.at(e)->assign(tadsY.at(e)); - } else { - tadsZ.at(e)->assign(tadsX.at(e)); - } - } - } - } - - return Status::OK(); +namespace ops { +CUSTOM_OP_IMPL(select, 3, 1, false, 0, 0) { + auto cond = INPUT_VARIABLE(0); + auto x = INPUT_VARIABLE(1); + auto y = INPUT_VARIABLE(2); + + REQUIRE_TRUE(x->isSameShape(y), 0, "Select: X and Y shape should be equal"); + if (x->isScalar()) { + REQUIRE_TRUE(cond->isScalar(), 0, + "Select: Condition should gave either equal shape to X/Y " + "first dimension or to be scalar"); + + auto z = OUTPUT_VARIABLE(0); + + if (y->isR()) { + auto v = !cond->e(0) ? y->e(0) : x->e(0); + z->p(0, v); + } else { + auto v = !cond->e(0) ? y->e(0) : x->e(0); + z->p(0, v); + } + } else { + bool same = cond->isSameShape(x); + REQUIRE_TRUE(cond->isScalar() || cond->lengthOf() == x->sizeAt(0) || same, + 0, + "Select: Condition should gave either equal shape to X/Y " + "first dimension or to be scalar"); + if (same) { + auto z = OUTPUT_VARIABLE(0); + + for (int e = 0; e < cond->lengthOf(); e++) { + if (y->isR()) { + auto r = !cond->e(e) ? y->e(e) : x->e(e); + z->p(e, r); + } else { + auto r = !cond->e(e) ? y->e(e) : x->e(e); + z->p(e, r); + } + } + } else { + REQUIRE_TRUE(cond->lengthOf() == x->sizeAt(0), 0, + "Condition length should be equal to the dim0 of x/y to act " + "as TAD-mask, but got %d instead", + cond->lengthOf()); + + auto z = OUTPUT_VARIABLE(0); + + auto dims = ShapeUtils::evalDimsToExclude(x->rankOf(), {0}); + auto tadsX = x->allTensorsAlongDimension(dims); + auto tadsY = y->allTensorsAlongDimension(dims); + auto tadsZ = z->allTensorsAlongDimension(dims); + + for (int e = 0; e < tadsX.size(); e++) { + if (!cond->e(e)) { + tadsZ.at(e).assign(tadsY.at(e)); + } else { + tadsZ.at(e).assign(tadsX.at(e)); } + } + } + } - DECLARE_SHAPE_FN(select) { - auto inShape = inputShape->at(1); + return Status::OK(); +} - Nd4jLong *newshape; - COPY_SHAPE(inShape, newshape); +DECLARE_SHAPE_FN(select) { + auto inShape = inputShape->at(1); - return SHAPELIST(CONSTANT(newshape)); - } + Nd4jLong *newshape; + COPY_SHAPE(inShape, newshape); - DECLARE_TYPES(select) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::BOOL) - ->setAllowedInputTypes(1, DataType::ANY) - ->setAllowedInputTypes(2, DataType::ANY) - ->setAllowedOutputTypes(1, DataType::INHERIT); - } - } + return SHAPELIST(CONSTANT(newshape)); +} + +DECLARE_TYPES(select) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::BOOL) + ->setAllowedInputTypes(1, DataType::ANY) + ->setAllowedInputTypes(2, DataType::ANY) + ->setAllowedOutputTypes(1, DataType::INHERIT); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/boolean/where.cpp b/libnd4j/include/ops/declarable/generic/boolean/where.cpp index a72de2ee0f2d..61765eb5e24b 100644 --- a/libnd4j/include/ops/declarable/generic/boolean/where.cpp +++ b/libnd4j/include/ops/declarable/generic/boolean/where.cpp @@ -26,112 +26,122 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(Where, 1, 1, false, 0, 0) { - auto condition = INPUT_VARIABLE(0); - auto z = OUTPUT_VARIABLE(0); - if (z->isEmpty()) - return Status::OK(); - - if (block.width() == 3) { - auto x = INPUT_VARIABLE(1); - auto y = INPUT_VARIABLE(2); - - REQUIRE_TRUE(x->isSameShape(y), 0, "X and Y must have equal shapes"); - - // if cond matches x/y shape - we have per-element mask - if (condition->isSameShape(x)) { - // FIXME: for perf it might be better to issue memcpy here, and fill only mismatched values from either X or Y - for (int e = 0; e < condition->lengthOf(); e++) { - if (y->isR()) { - auto r = !condition->e(e) ? y->e(e) : x->e(e); - z->p(e, r); - } else { - auto r = !condition->e(e) ? y->e(e) : x->e(e); - z->p(e, r); - } - } - } else { - REQUIRE_TRUE(condition->lengthOf() == x->sizeAt(0), 0, "Condition length should be equal to the dim0 of x/y to act as TAD-mask, but got %d instead", condition->lengthOf()); - - auto dims = ShapeUtils::evalDimsToExclude(x->rankOf(), {0}); - auto tadsX = x->allTensorsAlongDimension(dims); - auto tadsY = y->allTensorsAlongDimension(dims); - auto tadsZ = z->allTensorsAlongDimension(dims); - - for (int e = 0; e < tadsX.size(); e++) { - if (!condition->e(e)) { - tadsZ.at(e)->assign(tadsY.at(e)); - } else { - tadsZ.at(e)->assign(tadsX.at(e)); - } - } - } - } else { - // in this case we return 2D matrix, which basically contains coordinates fo true - REQUIRE_TRUE(block.width() == 1, 0, "Where op takes either 1 or 3 operands, But got %d operands instead", block.width()); - auto output = OUTPUT_VARIABLE(0); - - int width = condition->rankOf(); - if (z->isEmpty()) - return ND4J_STATUS_OK; - - std::vector dims = ShapeUtils::evalDimsToExclude(width, {0}); - - helpers::_where(block.launchContext(), *condition, *output, block.workspace()); - } - return Status::OK(); +namespace ops { +CUSTOM_OP_IMPL(Where, 1, 1, false, 0, 0) { + auto condition = INPUT_VARIABLE(0); + auto z = OUTPUT_VARIABLE(0); + if (z->isEmpty()) return Status::OK(); + + if (block.width() == 3) { + auto x = INPUT_VARIABLE(1); + auto y = INPUT_VARIABLE(2); + + REQUIRE_TRUE(x->isSameShape(y), 0, "X and Y must have equal shapes"); + + // if cond matches x/y shape - we have per-element mask + if (condition->isSameShape(x)) { + // FIXME: for perf it might be better to issue memcpy here, and fill only + // mismatched values from either X or Y + for (int e = 0; e < condition->lengthOf(); e++) { + if (y->isR()) { + auto r = !condition->e(e) ? y->e(e) : x->e(e); + z->p(e, r); + } else { + auto r = + !condition->e(e) ? y->e(e) : x->e(e); + z->p(e, r); } - - DECLARE_SHAPE_FN(Where) { - if (block.width() == 3) { - auto inShape = inputShape->at(1); - Nd4jLong *newshape; - COPY_SHAPE(inShape, newshape); - - return SHAPELIST(CONSTANT(newshape)); - } else { - // FIXME: we can't estimate result here in this case - // output shape is the 2D tensor num_true x rankOf (inShape) - auto condition = INPUT_VARIABLE(0); - auto inShape = inputShape->at(0); - Nd4jLong numOfTrue = 0; //condition->reduceNumber(reduce::CountNonZero, nullptr).e(0); - for (Nd4jLong i = 0; i < condition->lengthOf(); i++) - if (condition->e(i)) numOfTrue++; - - Nd4jLong const* theNewShape; - if (numOfTrue > 0) { - Nd4jLong* newShape; - ALLOCATE(newShape, block.getWorkspace(), shape::shapeInfoLength(2), Nd4jLong); - - newShape[0] = 2; - newShape[1] = numOfTrue; - newShape[2] = shape::rank(inShape); - newShape[3] = 1; - newShape[4] = 1; - newShape[5] = 0; - newShape[6] = 1; - newShape[7] = 99; - ShapeUtils::updateStridesAndType(newShape, sd::DataType::INT64, 'c'); - - theNewShape = CONSTANT(newShape); - } - else { - theNewShape = ConstantShapeHelper::getInstance().emptyShapeInfo(sd::DataType::INT64); - } - - return SHAPELIST(theNewShape); - } + } + } else { + REQUIRE_TRUE(condition->lengthOf() == x->sizeAt(0), 0, + "Condition length should be equal to the dim0 of x/y to act " + "as TAD-mask, but got %d instead", + condition->lengthOf()); + + auto dims = ShapeUtils::evalDimsToExclude(x->rankOf(), {0}); + auto tadsX = x->allTensorsAlongDimension(dims); + auto tadsY = y->allTensorsAlongDimension(dims); + auto tadsZ = z->allTensorsAlongDimension(dims); + + for (int e = 0; e < tadsX.size(); e++) { + if (!condition->e(e)) { + tadsZ.at(e).assign(tadsY.at(e)); + } else { + tadsZ.at(e).assign(tadsX.at(e)); } + } + } + } else { + // in this case we return 2D matrix, which basically contains coordinates fo + // true + REQUIRE_TRUE( + block.width() == 1, 0, + "Where op takes either 1 or 3 operands, But got %d operands instead", + block.width()); + auto output = OUTPUT_VARIABLE(0); + + int width = condition->rankOf(); + if (z->isEmpty()) return ND4J_STATUS_OK; + + std::vector dims = ShapeUtils::evalDimsToExclude(width, {0}); + + helpers::_where(block.launchContext(), *condition, *output, + block.workspace()); + } + return Status::OK(); +} - DECLARE_TYPES(Where) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) // bool - ->setAllowedInputTypes(1, DataType::ANY) - ->setAllowedInputTypes(2, DataType::ANY) - ->setAllowedOutputTypes(0, {ALL_INTS, ALL_FLOATS}); - } +DECLARE_SHAPE_FN(Where) { + if (block.width() == 3) { + auto inShape = inputShape->at(1); + Nd4jLong* newshape; + COPY_SHAPE(inShape, newshape); + + return SHAPELIST(CONSTANT(newshape)); + } else { + // FIXME: we can't estimate result here in this case + // output shape is the 2D tensor num_true x rankOf (inShape) + auto condition = INPUT_VARIABLE(0); + auto inShape = inputShape->at(0); + Nd4jLong numOfTrue = 0; // condition->reduceNumber(reduce::CountNonZero, + // nullptr).e(0); + for (Nd4jLong i = 0; i < condition->lengthOf(); i++) + if (condition->e(i)) numOfTrue++; + + Nd4jLong const* theNewShape; + if (numOfTrue > 0) { + Nd4jLong* newShape; + ALLOCATE(newShape, block.workspace(), shape::shapeInfoLength(2), + Nd4jLong); + + newShape[0] = 2; + newShape[1] = numOfTrue; + newShape[2] = shape::rank(inShape); + newShape[3] = 1; + newShape[4] = 1; + newShape[5] = 0; + newShape[6] = 1; + newShape[7] = 99; + ShapeUtils::updateStridesAndType(newShape, sd::DataType::INT64, 'c'); + + theNewShape = CONSTANT(newShape); + } else { + theNewShape = ConstantShapeHelper::getInstance().emptyShapeInfo( + sd::DataType::INT64); } + + return SHAPELIST(theNewShape); + } +} + +DECLARE_TYPES(Where) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) // bool + ->setAllowedInputTypes(1, DataType::ANY) + ->setAllowedInputTypes(2, DataType::ANY) + ->setAllowedOutputTypes(0, {ALL_INTS, ALL_FLOATS}); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/boolean/where_np.cpp b/libnd4j/include/ops/declarable/generic/boolean/where_np.cpp index 23284b2f9859..ca08421b79d8 100644 --- a/libnd4j/include/ops/declarable/generic/boolean/where_np.cpp +++ b/libnd4j/include/ops/declarable/generic/boolean/where_np.cpp @@ -18,8 +18,8 @@ // @author Adam Gibson // -#include #include +#include #if NOT_EXCLUDED(OP_where_np) @@ -27,136 +27,142 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(where_np, -1, 1, false, 0, 0) { - auto condition = INPUT_VARIABLE(0); - - if (block.width() == 3) { - auto x = INPUT_VARIABLE(1); - auto y = INPUT_VARIABLE(2); - - auto z = OUTPUT_VARIABLE(0); - int numMatches = 0; - // if cond matches x/y shape - we have per-element mask - if (condition->isSameShape(x)) { - // FIXME: for perf it might be better to issue memcpy here, and fill only mismatched values from either X or Y - if(y->isScalar()) { - if (y->isR()) { - for (int e = 0; e < condition->lengthOf(); e++) { - auto r = condition->e(e) ? y->e(0) - : x->e(e); - z->p(e, r); - } - } else { - for (int e = 0; e < condition->lengthOf(); e++) { - auto r = condition->e(e) ? y->e(0) - : x->e(e); - z->p(e, r); - } - } - } - else { - if (y->isR()) { - for (int e = 0; e < condition->lengthOf(); e++) { - if (condition->e(e)) { - auto r = y->e(numMatches); - z->p(e, r); - numMatches++; - } else { - auto r = x->e(e); - z->p(e, r); - } - } - } else { - for (int e = 0; e < condition->lengthOf(); e++) { - if (condition->e(e)) { - auto r = y->e(numMatches); - z->p(e, r); - numMatches++; - } else { - auto r = x->e(e); - z->p(e, r); - } - } - } - } - } - else { - REQUIRE_TRUE(condition->lengthOf() == x->sizeAt(0), 0, "Condition length should be equal to the dim0 of x/y to act as TAD-mask, but got %d instead", condition->lengthOf()); - - auto dims = ShapeUtils::evalDimsToExclude(x->rankOf(), {0}); - auto tadsX = x->allTensorsAlongDimension(dims); - auto tadsY = y->allTensorsAlongDimension(dims); - auto tadsZ = z->allTensorsAlongDimension(dims); - - for (int e = 0; e < tadsX.size(); e++) { - if (!condition->e(e)) - tadsZ.at(e)->assign(tadsY.at(e)); - else - tadsZ.at(e)->assign(tadsX.at(e)); - } - } +namespace ops { +CUSTOM_OP_IMPL(where_np, -1, 1, false, 0, 0) { + auto condition = INPUT_VARIABLE(0); + + if (block.width() == 3) { + auto x = INPUT_VARIABLE(1); + auto y = INPUT_VARIABLE(2); + + auto z = OUTPUT_VARIABLE(0); + int numMatches = 0; + // if cond matches x/y shape - we have per-element mask + if (condition->isSameShape(x)) { + // FIXME: for perf it might be better to issue memcpy here, and fill only + // mismatched values from either X or Y + if (y->isScalar()) { + if (y->isR()) { + for (int e = 0; e < condition->lengthOf(); e++) { + auto r = condition->e(e) ? y->e(0) : x->e(e); + z->p(e, r); + } + } else { + for (int e = 0; e < condition->lengthOf(); e++) { + auto r = + condition->e(e) ? y->e(0) : x->e(e); + z->p(e, r); + } + } + } else { + if (y->isR()) { + for (int e = 0; e < condition->lengthOf(); e++) { + if (condition->e(e)) { + auto r = y->e(numMatches); + z->p(e, r); + numMatches++; } else { - // in this case we return 2D matrix, which basically contains coordinates fo true - - REQUIRE_TRUE(block.width() == 1, 0, "Where op takes either 1 or 3 operands, But got %d operands instead", block.width()); -// if (output->isEmpty()) - Nd4jLong width = condition->rankOf(); - - sd::ops::Where op; - auto res(op.evaluate({condition})); - REQUIRE_OK(res.status()); - NDArray* whereTrue = res.at(0); - - if (whereTrue->isEmpty()) - return ND4J_STATUS_OK; - for (Nd4jLong outNext = 0; outNext < width; ++outNext) { - auto output = OUTPUT_VARIABLE(outNext); - for (Nd4jLong e = 0; e < output->lengthOf(); ++e) { - output->p(e, whereTrue->e(e, outNext)); - } - } + auto r = x->e(e); + z->p(e, r); } - - return Status::OK(); - } - - DECLARE_SHAPE_FN(where_np) { - auto shapes = SHAPELIST(); - Nd4jLong *newShape; - if (block.width() == 3) { - auto inShape = inputShape->at(1); - COPY_SHAPE(inShape, newShape); - - shapes->push_back(CONSTANT(newShape)); + } + } else { + for (int e = 0; e < condition->lengthOf(); e++) { + if (condition->e(e)) { + auto r = y->e(numMatches); + z->p(e, r); + numMatches++; } else { - auto condition = INPUT_VARIABLE(0); - - Nd4jLong numOfTrue = 0LL; //condition->reduceNumber(reduce::CountNonZero).e(0); - for (Nd4jLong i = 0; i < condition->lengthOf(); ++i) - if (condition->e(i)) numOfTrue++; - - // output shape - a tuple of rank(inShape) 1D tensors with numOfTrue len - if (numOfTrue) { - for (Nd4jLong e = 0; e < condition->rankOf(); ++e) { - shapes->push_back(ConstantShapeHelper::getInstance().vectorShapeInfo(numOfTrue, sd::DataType::INT64)); - } - } - else { - shapes->push_back(ConstantShapeHelper::getInstance().emptyShapeInfo(sd::DataType::INT64)); - } + auto r = x->e(e); + z->p(e, r); } - return shapes; + } } + } + } else { + REQUIRE_TRUE(condition->lengthOf() == x->sizeAt(0), 0, + "Condition length should be equal to the dim0 of x/y to act " + "as TAD-mask, but got %d instead", + condition->lengthOf()); + + auto dims = ShapeUtils::evalDimsToExclude(x->rankOf(), {0}); + auto tadsX = x->allTensorsAlongDimension(dims); + auto tadsY = y->allTensorsAlongDimension(dims); + auto tadsZ = z->allTensorsAlongDimension(dims); + + for (int e = 0; e < tadsX.size(); e++) { + if (!condition->e(e)) + tadsZ.at(e).assign(tadsY.at(e)); + else + tadsZ.at(e).assign(tadsX.at(e)); + } + } + } else { + // in this case we return 2D matrix, which basically contains coordinates fo + // true + + REQUIRE_TRUE( + block.width() == 1, 0, + "Where op takes either 1 or 3 operands, But got %d operands instead", + block.width()); + // if (output->isEmpty()) + Nd4jLong width = condition->rankOf(); + + sd::ops::Where op; + auto res(op.evaluate({condition})); + REQUIRE_OK(res.status()); + auto& whereTrue = res.at(0); + + if (whereTrue.isEmpty()) return ND4J_STATUS_OK; + for (Nd4jLong outNext = 0; outNext < width; ++outNext) { + auto output = OUTPUT_VARIABLE(outNext); + for (Nd4jLong e = 0; e < output->lengthOf(); ++e) { + output->p(e, whereTrue.e(e, outNext)); + } + } + } - DECLARE_TYPES(where_np) { - getOpDescriptor() - ->setAllowedInputTypes(0, sd::DataType::BOOL) - ->setAllowedInputTypes(1, sd::DataType::ANY) - ->setAllowedInputTypes(2, sd::DataType::ANY) - ->setAllowedOutputTypes( {ALL_FLOATS, ALL_INTS}); - } + return Status::OK(); +} + +DECLARE_SHAPE_FN(where_np) { + auto shapes = SHAPELIST(); + Nd4jLong* newShape; + if (block.width() == 3) { + auto inShape = inputShape->at(1); + COPY_SHAPE(inShape, newShape); + + shapes->push_back(CONSTANT(newShape)); + } else { + auto condition = INPUT_VARIABLE(0); + + Nd4jLong numOfTrue = + 0LL; // condition->reduceNumber(reduce::CountNonZero).e(0); + for (Nd4jLong i = 0; i < condition->lengthOf(); ++i) + if (condition->e(i)) numOfTrue++; + + // output shape - a tuple of rank(inShape) 1D tensors with numOfTrue len + if (numOfTrue) { + for (Nd4jLong e = 0; e < condition->rankOf(); ++e) { + shapes->push_back(ConstantShapeHelper::getInstance().vectorShapeInfo( + numOfTrue, sd::DataType::INT64)); + } + } else { + shapes->push_back(ConstantShapeHelper::getInstance().emptyShapeInfo( + sd::DataType::INT64)); } + } + return shapes; +} + +DECLARE_TYPES(where_np) { + getOpDescriptor() + ->setAllowedInputTypes(0, sd::DataType::BOOL) + ->setAllowedInputTypes(1, sd::DataType::ANY) + ->setAllowedInputTypes(2, sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS, ALL_INTS}); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/add.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/add.cpp index 936addea5644..35588f743300 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/add.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/add.cpp @@ -21,97 +21,98 @@ #include #if NOT_EXCLUDED(OP_add) -#include #include +#include namespace sd { - namespace ops { - BROADCASTABLE_OP_IMPL(add, 0, 0) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); - - BROADCAST_CHECK_EMPTY(x,y,z); - - auto tZ = BroadcastHelper::broadcastApply(sd::BroadcastOpsTuple::Add(), x, y, z); - if (tZ == nullptr) - return ND4J_STATUS_KERNEL_FAILURE; - else if (tZ != z) - throw std::runtime_error("add: result was replaced"); - - - return Status::OK(); - } - - DECLARE_TYPES(add) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, DataType::ANY) - ->setAllowedOutputTypes(DataType::ANY); - } - - DECLARE_TYPES(add_bp) { - getOpDescriptor() - ->setAllowedInputTypes(DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } - - - CUSTOM_OP_IMPL(add_bp, 3, 2, false, 0, 0) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto epsNext = INPUT_VARIABLE(2); - - auto gradX = OUTPUT_VARIABLE(0); - auto gradY = OUTPUT_VARIABLE(1); - - if (x->isSameShape(y)) { - // PWT case case - gradY->assign(epsNext); - gradX->assign(epsNext); - } else if (y->isScalar()) { - // scalar case - auto tmp = epsNext->reduceNumber(sd::reduce::Sum); - gradY->assign(tmp); - gradX->assign(epsNext); - } else { - // broadcast case - auto axisX = ShapeUtils::evalBroadcastBackwardAxis(x->shapeInfo(), epsNext->shapeInfo()); - auto axisY = ShapeUtils::evalBroadcastBackwardAxis(y->shapeInfo(), epsNext->shapeInfo()); - - if (axisX.size() > 0) { - auto sum = epsNext->reduceAlongDimension(sd::reduce::Sum, axisX); - gradX->assign(sum); - } else - gradX->assign(epsNext); - - if (axisY.size() > 0) { - auto sum = epsNext->reduceAlongDimension(sd::reduce::Sum, axisY); - gradY->assign(sum); - } else - gradY->assign(epsNext); - } - - return Status::OK(); - } - - DECLARE_SHAPE_FN(add_bp) { - auto x = inputShape->at(0); - auto y = inputShape->at(1); - auto e = inputShape->at(2); - - // eps always has shape of x - // grad always has shape of y - - Nd4jLong *shapeE; - Nd4jLong *shapeG; - - COPY_SHAPE(x, shapeE); - COPY_SHAPE(y, shapeG); - - return SHAPELIST(CONSTANT(shapeE), CONSTANT(shapeG)); - } - } +namespace ops { +BROADCASTABLE_OP_IMPL(add, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); + + BROADCAST_CHECK_EMPTY(x, y, z); + + auto tZ = + BroadcastHelper::broadcastApply(sd::BroadcastOpsTuple::Add(), x, y, z); + if (tZ == nullptr) + return ND4J_STATUS_KERNEL_FAILURE; + else if (tZ != z) + throw std::runtime_error("add: result was replaced"); + + return Status::OK(); +} + +DECLARE_TYPES(add) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes(1, DataType::ANY) + ->setAllowedOutputTypes(DataType::ANY); +} + +DECLARE_TYPES(add_bp) { + getOpDescriptor() + ->setAllowedInputTypes(DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} + +CUSTOM_OP_IMPL(add_bp, 3, 2, false, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto epsNext = INPUT_VARIABLE(2); + + auto gradX = OUTPUT_VARIABLE(0); + auto gradY = OUTPUT_VARIABLE(1); + + if (x->isSameShape(y)) { + // PWT case case + gradY->assign(epsNext); + gradX->assign(epsNext); + } else if (y->isScalar()) { + // scalar case + auto tmp = epsNext->reduceNumber(sd::reduce::Sum); + gradY->assign(tmp); + gradX->assign(epsNext); + } else { + // broadcast case + auto axisX = ShapeUtils::evalBroadcastBackwardAxis(x->shapeInfo(), + epsNext->shapeInfo()); + auto axisY = ShapeUtils::evalBroadcastBackwardAxis(y->shapeInfo(), + epsNext->shapeInfo()); + + if (axisX.size() > 0) { + auto sum = epsNext->reduceAlongDimension(sd::reduce::Sum, axisX); + gradX->assign(sum); + } else + gradX->assign(epsNext); + + if (axisY.size() > 0) { + auto sum = epsNext->reduceAlongDimension(sd::reduce::Sum, axisY); + gradY->assign(sum); + } else + gradY->assign(epsNext); + } + + return Status::OK(); +} + +DECLARE_SHAPE_FN(add_bp) { + auto x = inputShape->at(0); + auto y = inputShape->at(1); + auto e = inputShape->at(2); + + // eps always has shape of x + // grad always has shape of y + + Nd4jLong *shapeE; + Nd4jLong *shapeG; + + COPY_SHAPE(x, shapeE); + COPY_SHAPE(y, shapeG); + + return SHAPELIST(CONSTANT(shapeE), CONSTANT(shapeG)); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/assign.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/assign.cpp index aeaa5d128535..7d0d5d87b3f1 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/assign.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/assign.cpp @@ -21,91 +21,93 @@ #include #if NOT_EXCLUDED(OP_assign) -#include #include +#include namespace sd { - namespace ops { - BROADCASTABLE_OP_IMPL(assign, 0, 0) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); - - BROADCAST_CHECK_EMPTY(x,y,z); - - auto tZ = BroadcastHelper::broadcastApply(sd::BroadcastOpsTuple::Assign(), x, y, z); - if (tZ == nullptr) - return ND4J_STATUS_KERNEL_FAILURE; - else if (tZ != z) { - OVERWRITE_RESULT(tZ); - } - - return ND4J_STATUS_OK; - } - DECLARE_SYN(set, assign); - DECLARE_SYN(copy, assign); - - DECLARE_TYPES(assign) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, DataType::ANY) - ->setAllowedOutputTypes(0, DataType::INHERIT); - } - - DECLARE_TYPES(assign_bp) { - getOpDescriptor() - ->setAllowedInputTypes(DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } - - CUSTOM_OP_IMPL(assign_bp, 3, 2, false, 0, 0) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto epsNext = INPUT_VARIABLE(2); - - auto gradX = OUTPUT_VARIABLE(0); - auto gradY = OUTPUT_VARIABLE(1); - - gradX->assign(0.0f); - - if (x->isSameShape(y)) { - gradY->assign(epsNext); - } else if (y->isScalar()) { - auto sum = epsNext->reduceNumber(sd::reduce::Sum); - gradY->assign(sum); - } else { - // broadcastable - auto axisY = ShapeUtils::evalBroadcastBackwardAxis(y->shapeInfo(), epsNext->shapeInfo()); - - if (axisY.size() > 0) { - auto sum = epsNext->reduceAlongDimension(sd::reduce::Sum, axisY); - gradY->assign(sum); - } else - gradY->assign(epsNext); - } - - return Status::OK(); - } - - DECLARE_SHAPE_FN(assign_bp) { - auto x = inputShape->at(0); - auto y = inputShape->at(1); - auto e = inputShape->at(2); - - // eps always has shape of x - // grad always has shape of y - - Nd4jLong *shapeE; - Nd4jLong *shapeG; - - COPY_SHAPE(x, shapeE); - COPY_SHAPE(y, shapeG); - - auto shapeList = SHAPELIST(CONSTANT(shapeE), CONSTANT(shapeG)); - - return shapeList; - } - } +namespace ops { +BROADCASTABLE_OP_IMPL(assign, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); + + BROADCAST_CHECK_EMPTY(x, y, z); + + auto tZ = + BroadcastHelper::broadcastApply(sd::BroadcastOpsTuple::Assign(), x, y, z); + if (tZ == nullptr) + return ND4J_STATUS_KERNEL_FAILURE; + else if (tZ != z) { + OVERWRITE_RESULT(tZ); + } + + return ND4J_STATUS_OK; +} +DECLARE_SYN(set, assign); +DECLARE_SYN(copy, assign); + +DECLARE_TYPES(assign) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes(1, DataType::ANY) + ->setAllowedOutputTypes(0, DataType::INHERIT); +} + +DECLARE_TYPES(assign_bp) { + getOpDescriptor() + ->setAllowedInputTypes(DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} + +CUSTOM_OP_IMPL(assign_bp, 3, 2, false, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto epsNext = INPUT_VARIABLE(2); + + auto gradX = OUTPUT_VARIABLE(0); + auto gradY = OUTPUT_VARIABLE(1); + + gradX->assign(0.0f); + + if (x->isSameShape(y)) { + gradY->assign(epsNext); + } else if (y->isScalar()) { + auto sum = epsNext->reduceNumber(sd::reduce::Sum); + gradY->assign(sum); + } else { + // broadcastable + auto axisY = ShapeUtils::evalBroadcastBackwardAxis(y->shapeInfo(), + epsNext->shapeInfo()); + + if (axisY.size() > 0) { + auto sum = epsNext->reduceAlongDimension(sd::reduce::Sum, axisY); + gradY->assign(sum); + } else + gradY->assign(epsNext); + } + + return Status::OK(); +} + +DECLARE_SHAPE_FN(assign_bp) { + auto x = inputShape->at(0); + auto y = inputShape->at(1); + auto e = inputShape->at(2); + + // eps always has shape of x + // grad always has shape of y + + Nd4jLong *shapeE; + Nd4jLong *shapeG; + + COPY_SHAPE(x, shapeE); + COPY_SHAPE(y, shapeG); + + auto shapeList = SHAPELIST(CONSTANT(shapeE), CONSTANT(shapeG)); + + return shapeList; } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/atan2.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/atan2.cpp index ed60f59250d7..cf20c04e01c9 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/atan2.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/atan2.cpp @@ -28,33 +28,35 @@ namespace sd { namespace ops { BROADCASTABLE_OP_IMPL(tf_atan2, 0, 0) { + auto y = INPUT_VARIABLE(0); + auto x = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(0); - auto x = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); + BROADCAST_CHECK_EMPTY(x, y, z); - BROADCAST_CHECK_EMPTY(x,y,z); + // auto tZ = BroadcastHelper::template broadcastApply>(y, + // x, z); + x->applyTrueBroadcast(sd::BroadcastOpsTuple::custom( + scalar::Atan2, pairwise::Atan2, broadcast::Atan2), + *y, *z, true); - // auto tZ = BroadcastHelper::template broadcastApply>(y, x, z); - x->applyTrueBroadcast(sd::BroadcastOpsTuple::custom(scalar::Atan2, pairwise::Atan2, broadcast::Atan2), *y, *z, true); + // if (tZ == nullptr) + // return ND4J_STATUS_KERNEL_FAILURE; + // else if (tZ != z) { + // OVERWRITE_RESULT(tZ); + // } - // if (tZ == nullptr) - // return ND4J_STATUS_KERNEL_FAILURE; - // else if (tZ != z) { - // OVERWRITE_RESULT(tZ); - // } - - return Status::OK(); + return Status::OK(); } - DECLARE_TYPES(tf_atan2) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, DataType::ANY) - ->setAllowedOutputTypes(0, DataType::INHERIT); - } - -} +DECLARE_TYPES(tf_atan2) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes(1, DataType::ANY) + ->setAllowedOutputTypes(0, DataType::INHERIT); } +} // namespace ops +} // namespace sd + #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/boolean_and.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/boolean_and.cpp index 32593ecf62f8..bfdd6818f202 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/boolean_and.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/boolean_and.cpp @@ -24,30 +24,33 @@ #include namespace sd { - namespace ops { - BROADCASTABLE_OP_IMPL(boolean_and, 0, 0) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); - - BROADCAST_CHECK_EMPTY(x,y,z); - - auto tZ = BroadcastHelper::broadcastApply(BroadcastOpsTuple::custom(scalar::LogicalAnd, pairwise::LogicalAnd, broadcast::LogicalAnd), x, y, z); - if (tZ == nullptr) - return ND4J_STATUS_KERNEL_FAILURE; - else if (tZ != z) - throw std::runtime_error("boolean_and: result was overwritten"); - - return Status::OK(); - } - - DECLARE_TYPES(boolean_and) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, DataType::ANY) - ->setAllowedOutputTypes(0, DataType::INHERIT); - } - } +namespace ops { +BROADCASTABLE_OP_IMPL(boolean_and, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); + + BROADCAST_CHECK_EMPTY(x, y, z); + + auto tZ = BroadcastHelper::broadcastApply( + BroadcastOpsTuple::custom(scalar::LogicalAnd, pairwise::LogicalAnd, + broadcast::LogicalAnd), + x, y, z); + if (tZ == nullptr) + return ND4J_STATUS_KERNEL_FAILURE; + else if (tZ != z) + throw std::runtime_error("boolean_and: result was overwritten"); + + return Status::OK(); } +DECLARE_TYPES(boolean_and) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes(1, DataType::ANY) + ->setAllowedOutputTypes(0, DataType::INHERIT); +} +} // namespace ops +} // namespace sd + #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/boolean_or.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/boolean_or.cpp index 1dbb69f30846..f34f0c60aa27 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/boolean_or.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/boolean_or.cpp @@ -24,30 +24,33 @@ #include namespace sd { - namespace ops { - BROADCASTABLE_OP_IMPL(boolean_or, 0, 0) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); - - BROADCAST_CHECK_EMPTY(x,y,z); - - auto tZ = BroadcastHelper::broadcastApply(BroadcastOpsTuple::custom(scalar::LogicalOr, pairwise::LogicalOr, broadcast::LogicalOr), x, y, z); - if (tZ == nullptr) - return ND4J_STATUS_KERNEL_FAILURE; - else if (tZ != z) - throw std::runtime_error("boolean_and: result was overwritten"); - - return Status::OK(); - } - - DECLARE_TYPES(boolean_or) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, DataType::ANY) - ->setAllowedOutputTypes(0, DataType::INHERIT); - } - } +namespace ops { +BROADCASTABLE_OP_IMPL(boolean_or, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); + + BROADCAST_CHECK_EMPTY(x, y, z); + + auto tZ = BroadcastHelper::broadcastApply( + BroadcastOpsTuple::custom(scalar::LogicalOr, pairwise::LogicalOr, + broadcast::LogicalOr), + x, y, z); + if (tZ == nullptr) + return ND4J_STATUS_KERNEL_FAILURE; + else if (tZ != z) + throw std::runtime_error("boolean_and: result was overwritten"); + + return Status::OK(); } +DECLARE_TYPES(boolean_or) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes(1, DataType::ANY) + ->setAllowedOutputTypes(0, DataType::INHERIT); +} +} // namespace ops +} // namespace sd + #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/boolean_xor.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/boolean_xor.cpp index 8f242fbda84d..dc49a4c7eabd 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/boolean_xor.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/boolean_xor.cpp @@ -24,30 +24,33 @@ #include namespace sd { - namespace ops { - BROADCASTABLE_OP_IMPL(boolean_xor, 0, 0) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); - - BROADCAST_CHECK_EMPTY(x,y,z); - - auto tZ = BroadcastHelper::broadcastApply(BroadcastOpsTuple::custom(scalar::LogicalXor, pairwise::LogicalXor, broadcast::LogicalXor), x, y, z); - if (tZ == nullptr) - return ND4J_STATUS_KERNEL_FAILURE; - else if (tZ != z) - throw std::runtime_error("boolean_xor: result was overwritten"); - - return Status::OK(); - } - - DECLARE_TYPES(boolean_xor) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, DataType::ANY) - ->setAllowedOutputTypes(0, DataType::INHERIT); - } - } +namespace ops { +BROADCASTABLE_OP_IMPL(boolean_xor, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); + + BROADCAST_CHECK_EMPTY(x, y, z); + + auto tZ = BroadcastHelper::broadcastApply( + BroadcastOpsTuple::custom(scalar::LogicalXor, pairwise::LogicalXor, + broadcast::LogicalXor), + x, y, z); + if (tZ == nullptr) + return ND4J_STATUS_KERNEL_FAILURE; + else if (tZ != z) + throw std::runtime_error("boolean_xor: result was overwritten"); + + return Status::OK(); } +DECLARE_TYPES(boolean_xor) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes(1, DataType::ANY) + ->setAllowedOutputTypes(0, DataType::INHERIT); +} +} // namespace ops +} // namespace sd + #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/divide.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/divide.cpp index cd907de366e0..0b192967f58e 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/divide.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/divide.cpp @@ -21,125 +21,128 @@ #include #if NOT_EXCLUDED(OP_divide) -#include #include +#include namespace sd { - namespace ops { - BROADCASTABLE_OP_IMPL(divide, 0, 0) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); - - BROADCAST_CHECK_EMPTY(x,y,z); - - REQUIRE_TRUE(!y->isB(), 0, "DIVIDE OP: you can't divide by bool array!"); - auto tZ = BroadcastHelper::broadcastApply(BroadcastOpsTuple::Divide(), x, y, z); - if (tZ == nullptr) - return ND4J_STATUS_KERNEL_FAILURE; - else if (tZ != z) { - OVERWRITE_RESULT(tZ); - } - - return Status::OK(); - } - DECLARE_SYN(Div, divide); - - DECLARE_TYPES(divide) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, DataType::ANY) - ->setAllowedOutputTypes(0, DataType::INHERIT); - } - - DECLARE_TYPES(divide_bp) { - getOpDescriptor() - ->setAllowedInputTypes(DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } - - CUSTOM_OP_IMPL(divide_bp, 3, 2, false, 0, 0) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto epsNext = INPUT_VARIABLE(2); - - auto gradX = OUTPUT_VARIABLE(0); - auto gradY = OUTPUT_VARIABLE(1); - -/* - auto lambdaY = LAMBDA_TTT(_e, _x, _y) { - return _e * -_x / (_y * _y); - }; -*/ - - if (x->isSameShape(y)) { - // PWT case case - - // X gradient - //epsNext->applyPairwiseLambda(y, lambdaX, gradX); - gradX->assign((*epsNext) / (*y)); - // Y gradient - //epsNext->applyTriplewiseLambda(x, y, lambdaY, gradY); - gradY->assign((*epsNext) * (*x) / ((*y) * (*y))); - gradY->applyTransform(transform::Neg, *gradY); - - } else if (y->isScalar()) { - // scalar case - - auto tmp = epsNext->reduceNumber(reduce::Sum); - auto tmpX = x->reduceNumber(reduce::Sum); - //tmpX.printBuffer("SumX"); - //tmp.printBuffer("Sum Eps"); - gradY->assign(tmp * tmpX / ((*y) * (*y))); - gradY->applyTransform(transform::Neg, *gradY); - - //epsNext->applyLambda(lambdaS, *gradX); - epsNext->applyScalarArr(scalar::Divide, *y, *gradX); - } else { - // broadcast case - - auto preX = *epsNext / *y; - - NDArray negX(*x); - x->applyTransform(transform::Neg, negX); - auto preY = *epsNext * negX / ((*y) * (*y)); - - auto axisX = ShapeUtils::evalBroadcastBackwardAxis(x->shapeInfo(), epsNext->shapeInfo()); - auto axisY = ShapeUtils::evalBroadcastBackwardAxis(y->shapeInfo(), epsNext->shapeInfo()); - - if (axisX.size() > 0) { - auto sum = preX.reduceAlongDimension(reduce::Sum, axisX); - gradX->assign(sum); - } else - gradX->assign(preX); - - if (axisY.size() > 0) { - auto sum = preY.reduceAlongDimension(reduce::Sum, axisY); - gradY->assign(sum); - } else - gradY->assign(preY); - } - - return Status::OK(); - } - - DECLARE_SHAPE_FN(divide_bp) { - auto x = inputShape->at(0); - auto y = inputShape->at(1); - auto e = inputShape->at(2); - - // eps always has shape of x - // grad always has shape of y - - Nd4jLong *shapeE; - Nd4jLong *shapeG; - - COPY_SHAPE(x, shapeE); - COPY_SHAPE(y, shapeG); - - return SHAPELIST(CONSTANT(shapeE), CONSTANT(shapeG)); - } - } +namespace ops { +BROADCASTABLE_OP_IMPL(divide, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); + + BROADCAST_CHECK_EMPTY(x, y, z); + + REQUIRE_TRUE(!y->isB(), 0, "DIVIDE OP: you can't divide by bool array!"); + auto tZ = + BroadcastHelper::broadcastApply(BroadcastOpsTuple::Divide(), x, y, z); + if (tZ == nullptr) + return ND4J_STATUS_KERNEL_FAILURE; + else if (tZ != z) { + OVERWRITE_RESULT(tZ); + } + + return Status::OK(); +} +DECLARE_SYN(Div, divide); + +DECLARE_TYPES(divide) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes(1, DataType::ANY) + ->setAllowedOutputTypes(0, DataType::INHERIT); +} + +DECLARE_TYPES(divide_bp) { + getOpDescriptor() + ->setAllowedInputTypes(DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} + +CUSTOM_OP_IMPL(divide_bp, 3, 2, false, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto epsNext = INPUT_VARIABLE(2); + + auto gradX = OUTPUT_VARIABLE(0); + auto gradY = OUTPUT_VARIABLE(1); + + /* + auto lambdaY = LAMBDA_TTT(_e, _x, _y) { + return _e * -_x / (_y * _y); + }; + */ + + if (x->isSameShape(y)) { + // PWT case case + + // X gradient + // epsNext->applyPairwiseLambda(y, lambdaX, gradX); + gradX->assign((*epsNext) / (*y)); + // Y gradient + // epsNext->applyTriplewiseLambda(x, y, lambdaY, gradY); + gradY->assign((*epsNext) * (*x) / ((*y) * (*y))); + gradY->applyTransform(transform::Neg, *gradY); + + } else if (y->isScalar()) { + // scalar case + + auto tmp = epsNext->reduceNumber(reduce::Sum); + auto tmpX = x->reduceNumber(reduce::Sum); + // tmpX.printBuffer("SumX"); + // tmp.printBuffer("Sum Eps"); + gradY->assign(tmp * tmpX / ((*y) * (*y))); + gradY->applyTransform(transform::Neg, *gradY); + + // epsNext->applyLambda(lambdaS, *gradX); + epsNext->applyScalarArr(scalar::Divide, *y, *gradX); + } else { + // broadcast case + + auto preX = *epsNext / *y; + + NDArray negX(*x); + x->applyTransform(transform::Neg, negX); + auto preY = *epsNext * negX / ((*y) * (*y)); + + auto axisX = ShapeUtils::evalBroadcastBackwardAxis(x->shapeInfo(), + epsNext->shapeInfo()); + auto axisY = ShapeUtils::evalBroadcastBackwardAxis(y->shapeInfo(), + epsNext->shapeInfo()); + + if (axisX.size() > 0) { + auto sum = preX.reduceAlongDimension(reduce::Sum, axisX); + gradX->assign(sum); + } else + gradX->assign(preX); + + if (axisY.size() > 0) { + auto sum = preY.reduceAlongDimension(reduce::Sum, axisY); + gradY->assign(sum); + } else + gradY->assign(preY); + } + + return Status::OK(); +} + +DECLARE_SHAPE_FN(divide_bp) { + auto x = inputShape->at(0); + auto y = inputShape->at(1); + auto e = inputShape->at(2); + + // eps always has shape of x + // grad always has shape of y + + Nd4jLong *shapeE; + Nd4jLong *shapeG; + + COPY_SHAPE(x, shapeE); + COPY_SHAPE(y, shapeG); + + return SHAPELIST(CONSTANT(shapeE), CONSTANT(shapeG)); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/divide_no_nan.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/divide_no_nan.cpp index 9ef300e1c4fa..f20a7b43061c 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/divide_no_nan.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/divide_no_nan.cpp @@ -21,37 +21,39 @@ #include #if NOT_EXCLUDED(OP_divide_no_nan) -#include #include +#include namespace sd { - namespace ops { - BROADCASTABLE_OP_IMPL(divide_no_nan, 0, 0) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); - - BROADCAST_CHECK_EMPTY(x,y,z); - - REQUIRE_TRUE(!y->isB(), 0, "DIVIDE_NO_NAN OP: you can't divide by bool array!"); - auto tZ = BroadcastHelper::broadcastApply(BroadcastOpsTuple::DivideNoNan(), x, y, z); - if (tZ == nullptr) - return ND4J_STATUS_KERNEL_FAILURE; - else if (tZ != z) { - OVERWRITE_RESULT(tZ); - } - - return Status::OK(); - } - DECLARE_SYN(Div, divide); - - DECLARE_TYPES(divide_no_nan) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, DataType::ANY) - ->setAllowedOutputTypes(0, DataType::INHERIT); - } - } +namespace ops { +BROADCASTABLE_OP_IMPL(divide_no_nan, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); + + BROADCAST_CHECK_EMPTY(x, y, z); + + REQUIRE_TRUE(!y->isB(), 0, + "DIVIDE_NO_NAN OP: you can't divide by bool array!"); + auto tZ = BroadcastHelper::broadcastApply(BroadcastOpsTuple::DivideNoNan(), x, + y, z); + if (tZ == nullptr) + return ND4J_STATUS_KERNEL_FAILURE; + else if (tZ != z) { + OVERWRITE_RESULT(tZ); + } + + return Status::OK(); +} +DECLARE_SYN(Div, divide); + +DECLARE_TYPES(divide_no_nan) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes(1, DataType::ANY) + ->setAllowedOutputTypes(0, DataType::INHERIT); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/equals.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/equals.cpp index 5d4aaef5e44a..2ad612b04ac6 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/equals.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/equals.cpp @@ -18,35 +18,38 @@ // @author raver119@gmail.com // -#include #include +#include namespace sd { - namespace ops { - BROADCASTABLE_BOOL_OP_IMPL(equals, 0, 0) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); - - BROADCAST_CHECK_EMPTY(x,y,z); - - auto tZ = BroadcastHelper::broadcastApply(BroadcastBoolOpsTuple::custom(scalar::EqualTo, pairwise::EqualTo, broadcast::EqualTo), x, y, z); - if (tZ == nullptr) - return ND4J_STATUS_KERNEL_FAILURE; - else if (tZ != z) { - OVERWRITE_RESULT(tZ); - } - - return Status::OK(); - } - - DECLARE_SYN(equal, equals); - - DECLARE_TYPES(equals) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, DataType::ANY) - ->setAllowedOutputTypes(0, DataType::BOOL); - } - } -} \ No newline at end of file +namespace ops { +BROADCASTABLE_BOOL_OP_IMPL(equals, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); + + BROADCAST_CHECK_EMPTY(x, y, z); + + auto tZ = BroadcastHelper::broadcastApply( + BroadcastBoolOpsTuple::custom(scalar::EqualTo, pairwise::EqualTo, + broadcast::EqualTo), + x, y, z); + if (tZ == nullptr) + return ND4J_STATUS_KERNEL_FAILURE; + else if (tZ != z) { + OVERWRITE_RESULT(tZ); + } + + return Status::OK(); +} + +DECLARE_SYN(equal, equals); + +DECLARE_TYPES(equals) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes(1, DataType::ANY) + ->setAllowedOutputTypes(0, DataType::BOOL); +} +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/floordiv.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/floordiv.cpp index d0a59bcc196d..b0384f9c5054 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/floordiv.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/floordiv.cpp @@ -21,75 +21,77 @@ #include #if NOT_EXCLUDED(OP_floordiv) -#include #include +#include namespace sd { - namespace ops { - BROADCASTABLE_OP_IMPL(floordiv, 0, 0) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); - - BROADCAST_CHECK_EMPTY(x,y,z); - - REQUIRE_TRUE(!y->isB(), 0, "FLOORDIV OP: you can't divide by bool array!"); - auto tZ = BroadcastHelper::broadcastApply(BroadcastOpsTuple::custom(scalar::FloorDiv, pairwise::FloorDiv, broadcast::FloorDiv), x, y, z); - if (tZ == nullptr) - return ND4J_STATUS_KERNEL_FAILURE; - else if (tZ != z) { - OVERWRITE_RESULT(tZ); - } - - return Status::OK(); - } - - - DECLARE_TYPES(floordiv) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, DataType::ANY) - ->setAllowedOutputTypes(0, DataType::INHERIT); - } - - DECLARE_TYPES(floordiv_bp) { - getOpDescriptor() - ->setAllowedInputTypes(DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } - - CUSTOM_OP_IMPL(floordiv_bp, 3, 2, false, 0, 0) { - // PLEASE NOTE: we're just passing eps down the line here - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto epsNext = INPUT_VARIABLE(2); - - auto gradX = OUTPUT_VARIABLE(0); - auto gradY = OUTPUT_VARIABLE(1); - - gradY->assign(0.0f); - gradX->assign(0.0f); - - return Status::OK(); - } - - DECLARE_SHAPE_FN(floordiv_bp) { - auto x = inputShape->at(0); - auto y = inputShape->at(1); - auto e = inputShape->at(2); - - // eps always has shape of x - // grad always has shape of y - - Nd4jLong *shapeE; - Nd4jLong *shapeG; - - COPY_SHAPE(x, shapeE); - COPY_SHAPE(y, shapeG); - - return SHAPELIST(CONSTANT(shapeE), CONSTANT(shapeG)); - } - } +namespace ops { +BROADCASTABLE_OP_IMPL(floordiv, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); + + BROADCAST_CHECK_EMPTY(x, y, z); + + REQUIRE_TRUE(!y->isB(), 0, "FLOORDIV OP: you can't divide by bool array!"); + auto tZ = BroadcastHelper::broadcastApply( + BroadcastOpsTuple::custom(scalar::FloorDiv, pairwise::FloorDiv, + broadcast::FloorDiv), + x, y, z); + if (tZ == nullptr) + return ND4J_STATUS_KERNEL_FAILURE; + else if (tZ != z) { + OVERWRITE_RESULT(tZ); + } + + return Status::OK(); +} + +DECLARE_TYPES(floordiv) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes(1, DataType::ANY) + ->setAllowedOutputTypes(0, DataType::INHERIT); +} + +DECLARE_TYPES(floordiv_bp) { + getOpDescriptor() + ->setAllowedInputTypes(DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} + +CUSTOM_OP_IMPL(floordiv_bp, 3, 2, false, 0, 0) { + // PLEASE NOTE: we're just passing eps down the line here + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto epsNext = INPUT_VARIABLE(2); + + auto gradX = OUTPUT_VARIABLE(0); + auto gradY = OUTPUT_VARIABLE(1); + + gradY->assign(0.0f); + gradX->assign(0.0f); + + return Status::OK(); +} + +DECLARE_SHAPE_FN(floordiv_bp) { + auto x = inputShape->at(0); + auto y = inputShape->at(1); + auto e = inputShape->at(2); + + // eps always has shape of x + // grad always has shape of y + + Nd4jLong *shapeE; + Nd4jLong *shapeG; + + COPY_SHAPE(x, shapeE); + COPY_SHAPE(y, shapeG); + + return SHAPELIST(CONSTANT(shapeE), CONSTANT(shapeG)); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/floormod.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/floormod.cpp index fac2099055c8..744b44ee7db6 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/floormod.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/floormod.cpp @@ -14,93 +14,93 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - // - // @author raver119@gmail.com - // modified by sgazeos@gmail.com with backprop implementation. - // +// +// @author raver119@gmail.com +// modified by sgazeos@gmail.com with backprop implementation. +// #include #if NOT_EXCLUDED(OP_floormod) -#include #include +#include namespace sd { - namespace ops { - BROADCASTABLE_OP_IMPL(floormod, 0, 0) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); - - BROADCAST_CHECK_EMPTY(x, y, z); - - REQUIRE_TRUE(!y->isB(), 0, "FLOORMOD OP: you can't divide by bool array!"); - auto tZ = BroadcastHelper::broadcastApply(BROADCAST(FloorMod), x, y, z); - if (tZ == nullptr) - return ND4J_STATUS_KERNEL_FAILURE; - else if (tZ != z) { - OVERWRITE_RESULT(tZ); - } - - return Status::OK(); - } - - DECLARE_TYPES(floormod) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, DataType::ANY) - ->setAllowedOutputTypes(0, DataType::INHERIT); - } - - DECLARE_TYPES(floormod_bp) { - getOpDescriptor() - ->setAllowedInputTypes(DataType::ANY) - ->setAllowedOutputTypes({ ALL_FLOATS }); - } - - CUSTOM_OP_IMPL(floormod_bp, 3, 2, false, 0, 0) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto epsNext = INPUT_VARIABLE(2); - - auto gradX = OUTPUT_VARIABLE(0); - auto gradY = OUTPUT_VARIABLE(1); - gradX->assign(epsNext); - - NDArray temp(*epsNext); - BroadcastHelper::broadcastApply(BROADCAST(FloorMod), x, y, &temp); - - if (gradY->rankOf() == gradX->rankOf()) - epsNext->applyPairwiseTransform(pairwise::Multiply, temp, *gradY); - else // epsNext is greater than gradY - { - std::vector dims(epsNext->rankOf() * 2); - Nd4jLong gap = epsNext->rankOf() - gradY->rankOf(); - for (Nd4jLong d = 0; d < gap; d++) { - dims[d * 2 + 1] = 1; - } - auto tempIn((temp)(dims)); - (*epsNext)(dims).applyPairwiseTransform(pairwise::Multiply, tempIn, *gradY); - } - return Status::OK(); - } - - DECLARE_SHAPE_FN(floormod_bp) { - auto x = inputShape->at(0); - auto y = inputShape->at(1); - auto e = inputShape->at(2); - - // eps always has shape of x - // grad always has shape of y - - Nd4jLong* shapeE; - Nd4jLong* shapeG; - - COPY_SHAPE(x, shapeE); - COPY_SHAPE(y, shapeG); - - return SHAPELIST(CONSTANT(shapeE), CONSTANT(shapeG)); - } +namespace ops { +BROADCASTABLE_OP_IMPL(floormod, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); + + BROADCAST_CHECK_EMPTY(x, y, z); + + REQUIRE_TRUE(!y->isB(), 0, "FLOORMOD OP: you can't divide by bool array!"); + auto tZ = BroadcastHelper::broadcastApply(BROADCAST(FloorMod), x, y, z); + if (tZ == nullptr) + return ND4J_STATUS_KERNEL_FAILURE; + else if (tZ != z) { + OVERWRITE_RESULT(tZ); + } + + return Status::OK(); +} + +DECLARE_TYPES(floormod) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes(1, DataType::ANY) + ->setAllowedOutputTypes(0, DataType::INHERIT); +} + +DECLARE_TYPES(floormod_bp) { + getOpDescriptor() + ->setAllowedInputTypes(DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} + +CUSTOM_OP_IMPL(floormod_bp, 3, 2, false, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto epsNext = INPUT_VARIABLE(2); + + auto gradX = OUTPUT_VARIABLE(0); + auto gradY = OUTPUT_VARIABLE(1); + gradX->assign(epsNext); + + auto temp = epsNext->dup(); + BroadcastHelper::broadcastApply(BROADCAST(FloorMod), x, y, &temp); + + if (gradY->rankOf() == gradX->rankOf()) + epsNext->applyPairwiseTransform(pairwise::Multiply, temp, *gradY); + else // epsNext is greater than gradY + { + std::vector dims(epsNext->rankOf() * 2); + Nd4jLong gap = epsNext->rankOf() - gradY->rankOf(); + for (Nd4jLong d = 0; d < gap; d++) { + dims[d * 2 + 1] = 1; } + auto tempIn((temp)(dims)); + (*epsNext)(dims).applyPairwiseTransform(pairwise::Multiply, tempIn, *gradY); + } + return Status::OK(); +} + +DECLARE_SHAPE_FN(floormod_bp) { + auto x = inputShape->at(0); + auto y = inputShape->at(1); + auto e = inputShape->at(2); + + // eps always has shape of x + // grad always has shape of y + + Nd4jLong* shapeE; + Nd4jLong* shapeG; + + COPY_SHAPE(x, shapeE); + COPY_SHAPE(y, shapeG); + + return SHAPELIST(CONSTANT(shapeE), CONSTANT(shapeG)); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/greater.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/greater.cpp index 084453dc8dea..56b5cbaf7f6d 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/greater.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/greater.cpp @@ -22,29 +22,30 @@ #include namespace sd { - namespace ops { - BROADCASTABLE_BOOL_OP_IMPL(greater, 0, 0) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); +namespace ops { +BROADCASTABLE_BOOL_OP_IMPL(greater, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); - BROADCAST_CHECK_EMPTY(x,y,z); + BROADCAST_CHECK_EMPTY(x, y, z); - auto tZ = BroadcastHelper::broadcastApply(BROADCAST_BOOL(GreaterThan), x, y, z); - if (tZ == nullptr) - return ND4J_STATUS_KERNEL_FAILURE; - else if (tZ != z) { - OVERWRITE_RESULT(tZ); - } + auto tZ = + BroadcastHelper::broadcastApply(BROADCAST_BOOL(GreaterThan), x, y, z); + if (tZ == nullptr) + return ND4J_STATUS_KERNEL_FAILURE; + else if (tZ != z) { + OVERWRITE_RESULT(tZ); + } - return Status::OK(); - } + return Status::OK(); +} - DECLARE_TYPES(greater) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, DataType::ANY) - ->setAllowedOutputTypes(0, DataType::BOOL); - } - } -} \ No newline at end of file +DECLARE_TYPES(greater) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes(1, DataType::ANY) + ->setAllowedOutputTypes(0, DataType::BOOL); +} +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/greater_equal.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/greater_equal.cpp index 5f448585eb4f..571017000603 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/greater_equal.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/greater_equal.cpp @@ -21,29 +21,30 @@ #include namespace sd { - namespace ops { - BROADCASTABLE_BOOL_OP_IMPL(greater_equal, 0, 0) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); +namespace ops { +BROADCASTABLE_BOOL_OP_IMPL(greater_equal, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); - BROADCAST_CHECK_EMPTY(x,y,z); + BROADCAST_CHECK_EMPTY(x, y, z); - auto tZ = BroadcastHelper::broadcastApply(BROADCAST_BOOL(GreaterThanOrEqual), x, y, z); - if (tZ == nullptr) - return ND4J_STATUS_KERNEL_FAILURE; - else if (tZ != z) { - OVERWRITE_RESULT(tZ); - } + auto tZ = BroadcastHelper::broadcastApply(BROADCAST_BOOL(GreaterThanOrEqual), + x, y, z); + if (tZ == nullptr) + return ND4J_STATUS_KERNEL_FAILURE; + else if (tZ != z) { + OVERWRITE_RESULT(tZ); + } - return Status::OK(); - } + return Status::OK(); +} - DECLARE_TYPES(greater_equal) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, DataType::ANY) - ->setAllowedOutputTypes(0, DataType::BOOL); - } - } -} \ No newline at end of file +DECLARE_TYPES(greater_equal) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes(1, DataType::ANY) + ->setAllowedOutputTypes(0, DataType::BOOL); +} +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/igamma.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/igamma.cpp index 9fa07424c04a..3f93ddb4d292 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/igamma.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/igamma.cpp @@ -21,39 +21,41 @@ #include #if NOT_EXCLUDED(OP_igamma) -#include #include +#include namespace sd { - namespace ops { - BROADCASTABLE_OP_IMPL(igamma, 0, 0) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); - - BROADCAST_CHECK_EMPTY(x,y,z); - - //REQUIRE_TRUE(!y->isB(), 0, "Pairwise OP: you can't divide by bool array!"); - -// auto tZ = BroadcastHelper::broadcastApply({scalar::IGamma, pairwise::IGamma, broadcast::IGamma}, x, y, z); - auto tZ = BroadcastHelper::broadcastApply(BroadcastOpsTuple::IGamma(), x, y, z); - - if (tZ == nullptr) - return ND4J_STATUS_KERNEL_FAILURE; - else if (tZ != z) { - OVERWRITE_RESULT(tZ); - } - - return Status::OK(); - } - - DECLARE_TYPES(igamma) { - getOpDescriptor() - ->setAllowedInputTypes(0, {ALL_FLOATS}) - ->setAllowedInputTypes(1, {ALL_FLOATS}) - ->setAllowedOutputTypes(0, {ALL_FLOATS}); - } - } +namespace ops { +BROADCASTABLE_OP_IMPL(igamma, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); + + BROADCAST_CHECK_EMPTY(x, y, z); + + // REQUIRE_TRUE(!y->isB(), 0, "Pairwise OP: you can't divide by bool array!"); + + // auto tZ = BroadcastHelper::broadcastApply({scalar::IGamma, + // pairwise::IGamma, broadcast::IGamma}, x, y, z); + auto tZ = + BroadcastHelper::broadcastApply(BroadcastOpsTuple::IGamma(), x, y, z); + + if (tZ == nullptr) + return ND4J_STATUS_KERNEL_FAILURE; + else if (tZ != z) { + OVERWRITE_RESULT(tZ); + } + + return Status::OK(); +} + +DECLARE_TYPES(igamma) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_FLOATS}) + ->setAllowedInputTypes(1, {ALL_FLOATS}) + ->setAllowedOutputTypes(0, {ALL_FLOATS}); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/igammac.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/igammac.cpp index deeacd4ef59c..e4d9682d0a5b 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/igammac.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/igammac.cpp @@ -21,38 +21,40 @@ #include #if NOT_EXCLUDED(OP_igammac) -#include #include +#include namespace sd { - namespace ops { - BROADCASTABLE_OP_IMPL(igammac, 0, 0) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); - - BROADCAST_CHECK_EMPTY(x,y,z); - - //REQUIRE_TRUE(!y->isB(), 0, "Pairwise OP: you can't divide by bool array!"); - -// auto tZ = BroadcastHelper::broadcastApply({scalar::IGammac, pairwise::IGammac, broadcast::IGammac}, x, y, z); - auto tZ = BroadcastHelper::broadcastApply(BroadcastOpsTuple::IGammac(), x, y, z); - if (tZ == nullptr) - return ND4J_STATUS_KERNEL_FAILURE; - else if (tZ != z) { - OVERWRITE_RESULT(tZ); - } - - return Status::OK(); - } - - DECLARE_TYPES(igammac) { - getOpDescriptor() - ->setAllowedInputTypes(0, {ALL_FLOATS}) - ->setAllowedInputTypes(1, {ALL_FLOATS}) - ->setAllowedOutputTypes(0, {ALL_FLOATS}); - } - } +namespace ops { +BROADCASTABLE_OP_IMPL(igammac, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); + + BROADCAST_CHECK_EMPTY(x, y, z); + + // REQUIRE_TRUE(!y->isB(), 0, "Pairwise OP: you can't divide by bool array!"); + + // auto tZ = BroadcastHelper::broadcastApply({scalar::IGammac, + // pairwise::IGammac, broadcast::IGammac}, x, y, z); + auto tZ = + BroadcastHelper::broadcastApply(BroadcastOpsTuple::IGammac(), x, y, z); + if (tZ == nullptr) + return ND4J_STATUS_KERNEL_FAILURE; + else if (tZ != z) { + OVERWRITE_RESULT(tZ); + } + + return Status::OK(); +} + +DECLARE_TYPES(igammac) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_FLOATS}) + ->setAllowedInputTypes(1, {ALL_FLOATS}) + ->setAllowedOutputTypes(0, {ALL_FLOATS}); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/less.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/less.cpp index 5d9c73f1b2d2..2f0aec719cfe 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/less.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/less.cpp @@ -21,29 +21,29 @@ #include namespace sd { - namespace ops { - BROADCASTABLE_BOOL_OP_IMPL(less, 0, 0) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); +namespace ops { +BROADCASTABLE_BOOL_OP_IMPL(less, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); - BROADCAST_CHECK_EMPTY(x,y,z); + BROADCAST_CHECK_EMPTY(x, y, z); - auto tZ = BroadcastHelper::broadcastApply(BROADCAST_BOOL(LessThan), x, y, z); - if (tZ == nullptr) - return ND4J_STATUS_KERNEL_FAILURE; - else if (tZ != z) { - OVERWRITE_RESULT(tZ); - } + auto tZ = BroadcastHelper::broadcastApply(BROADCAST_BOOL(LessThan), x, y, z); + if (tZ == nullptr) + return ND4J_STATUS_KERNEL_FAILURE; + else if (tZ != z) { + OVERWRITE_RESULT(tZ); + } - return Status::OK(); - } + return Status::OK(); +} - DECLARE_TYPES(less) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, DataType::ANY) - ->setAllowedOutputTypes(0, DataType::BOOL); - } - } -} \ No newline at end of file +DECLARE_TYPES(less) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes(1, DataType::ANY) + ->setAllowedOutputTypes(0, DataType::BOOL); +} +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/less_equal.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/less_equal.cpp index a0f0a0366029..1a5779a6f1f5 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/less_equal.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/less_equal.cpp @@ -21,29 +21,30 @@ #include namespace sd { - namespace ops { - BROADCASTABLE_BOOL_OP_IMPL(less_equal, 0, 0) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); +namespace ops { +BROADCASTABLE_BOOL_OP_IMPL(less_equal, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); - BROADCAST_CHECK_EMPTY(x,y,z); + BROADCAST_CHECK_EMPTY(x, y, z); - auto tZ = BroadcastHelper::broadcastApply(BROADCAST_BOOL(LessThanOrEqual), x, y, z); - if (tZ == nullptr) - return ND4J_STATUS_KERNEL_FAILURE; - else if (tZ != z) { - OVERWRITE_RESULT(tZ); - } + auto tZ = + BroadcastHelper::broadcastApply(BROADCAST_BOOL(LessThanOrEqual), x, y, z); + if (tZ == nullptr) + return ND4J_STATUS_KERNEL_FAILURE; + else if (tZ != z) { + OVERWRITE_RESULT(tZ); + } - return Status::OK(); - } + return Status::OK(); +} - DECLARE_TYPES(less_equal) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, DataType::ANY) - ->setAllowedOutputTypes(0, DataType::BOOL); - } - } -} \ No newline at end of file +DECLARE_TYPES(less_equal) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes(1, DataType::ANY) + ->setAllowedOutputTypes(0, DataType::BOOL); +} +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/maximum.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/maximum.cpp index dfb6b3d66954..d5634035ed93 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/maximum.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/maximum.cpp @@ -21,70 +21,70 @@ #include #if NOT_EXCLUDED(OP_maximum) -#include #include +#include #include namespace sd { - namespace ops { - BROADCASTABLE_OP_IMPL(maximum, 0, 0) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); - - BROADCAST_CHECK_EMPTY(x,y,z); - - auto tZ = BroadcastHelper::broadcastApply(BROADCAST(MaxPairwise), x, y, z); - if (tZ == nullptr) - return ND4J_STATUS_KERNEL_FAILURE; - else if (tZ != z) { - OVERWRITE_RESULT(tZ); - } - - return Status::OK(); - } - - DECLARE_TYPES(maximum) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, DataType::ANY) - ->setAllowedOutputTypes(0, DataType::INHERIT); - } - - DECLARE_TYPES(maximum_bp) { - getOpDescriptor() - ->setAllowedInputTypes(DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } - - CUSTOM_OP_IMPL(maximum_bp, 3, 2, false, 0, 0) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto epsNext = INPUT_VARIABLE(2); - - auto gradX = OUTPUT_VARIABLE(0); - auto gradY = OUTPUT_VARIABLE(1); - - helpers::maximumBPFunctor(block.launchContext(), x, y, epsNext, gradX, gradY); - return Status::OK(); - } - - DECLARE_SHAPE_FN(maximum_bp) { - auto x = inputShape->at(0); - auto y = inputShape->at(1); - auto e = inputShape->at(2); - - // eps always has shape of x - // grad always has shape of y - - Nd4jLong *shapeE; - Nd4jLong *shapeG; - - COPY_SHAPE(x, shapeE); - COPY_SHAPE(y, shapeG); - - return SHAPELIST(CONSTANT(shapeE), CONSTANT(shapeG)); - } - } +namespace ops { +BROADCASTABLE_OP_IMPL(maximum, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); + + BROADCAST_CHECK_EMPTY(x, y, z); + + auto tZ = BroadcastHelper::broadcastApply(BROADCAST(MaxPairwise), x, y, z); + if (tZ == nullptr) + return ND4J_STATUS_KERNEL_FAILURE; + else if (tZ != z) { + OVERWRITE_RESULT(tZ); + } + + return Status::OK(); +} + +DECLARE_TYPES(maximum) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes(1, DataType::ANY) + ->setAllowedOutputTypes(0, DataType::INHERIT); +} + +DECLARE_TYPES(maximum_bp) { + getOpDescriptor() + ->setAllowedInputTypes(DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} + +CUSTOM_OP_IMPL(maximum_bp, 3, 2, false, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto epsNext = INPUT_VARIABLE(2); + + auto gradX = OUTPUT_VARIABLE(0); + auto gradY = OUTPUT_VARIABLE(1); + + helpers::maximumBPFunctor(block.launchContext(), x, y, epsNext, gradX, gradY); + return Status::OK(); +} + +DECLARE_SHAPE_FN(maximum_bp) { + auto x = inputShape->at(0); + auto y = inputShape->at(1); + auto e = inputShape->at(2); + + // eps always has shape of x + // grad always has shape of y + + Nd4jLong *shapeE; + Nd4jLong *shapeG; + + COPY_SHAPE(x, shapeE); + COPY_SHAPE(y, shapeG); + + return SHAPELIST(CONSTANT(shapeE), CONSTANT(shapeG)); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/meshgrid.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/meshgrid.cpp index b07f50202c08..c3e5ca77d36e 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/meshgrid.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/meshgrid.cpp @@ -22,72 +22,72 @@ #if NOT_EXCLUDED(OP_meshgrid) #include -#include +#include + #include namespace sd { -namespace ops { +namespace ops { -CUSTOM_OP_IMPL(meshgrid, -1, -1, false, 0, 0) { - - int rank = block.width(); +CUSTOM_OP_IMPL(meshgrid, -1, -1, false, 0, 0) { + int rank = block.width(); - if(rank == 1) { - OUTPUT_VARIABLE(0)->assign(INPUT_VARIABLE(0)); - return Status::OK(); - } + if (rank == 1) { + OUTPUT_VARIABLE(0)->assign(INPUT_VARIABLE(0)); + return Status::OK(); + } - bool swapFirst2Dims = block.getIArguments()->size() > 0 ? (bool)INT_ARG(0) : true; + bool swapFirst2Dims = block.numI() > 0 ? (bool)INT_ARG(0) : true; - std::vector inArrs(rank); - std::vector outArrs(rank); + std::vector inArrs(rank); + std::vector outArrs(rank); - for(int i = 0; i < rank; ++i) { - inArrs[i] = INPUT_VARIABLE(i); - outArrs[i] = OUTPUT_VARIABLE(i); - } + for (int i = 0; i < rank; ++i) { + inArrs[i] = INPUT_VARIABLE(i); + outArrs[i] = OUTPUT_VARIABLE(i); + } - helpers::meshgrid(block.launchContext(), inArrs, outArrs, swapFirst2Dims); + helpers::meshgrid(block.launchContext(), inArrs, outArrs, swapFirst2Dims); - return Status::OK(); + return Status::OK(); } - DECLARE_TYPES(meshgrid) { - getOpDescriptor() - ->setAllowedInputTypes(DataType::ANY) - ->setAllowedOutputTypes(DataType::INHERIT) - ->setSameMode(true); - } - +DECLARE_TYPES(meshgrid) { + getOpDescriptor() + ->setAllowedInputTypes(DataType::ANY) + ->setAllowedOutputTypes(DataType::INHERIT) + ->setSameMode(true); +} DECLARE_SHAPE_FN(meshgrid) { - bool swapFirst2Dims = block.getIArguments()->size() > 0 ? (bool)INT_ARG(0) : true; - - int rank = block.width(); - Nd4jLong* outShapeInfo = nullptr; - ALLOCATE(outShapeInfo, block.getWorkspace(), shape::shapeInfoLength(rank), Nd4jLong); - outShapeInfo[0] = rank; - for(int i = 1; i <= rank; ++i) - outShapeInfo[i] = (Nd4jLong)shape::length(inputShape->at(i - 1)); - - if(swapFirst2Dims && rank > 1) - math::nd4j_swap(outShapeInfo[1], outShapeInfo[2]); - - auto in = inputShape->at(0); - ShapeUtils::updateStridesAndType(outShapeInfo, in, shape::order(in)); - - auto shapes = SHAPELIST(); - auto resultShape = CONSTANT(outShapeInfo); + bool swapFirst2Dims = block.numI() > 0 ? (bool)INT_ARG(0) : true; + + int rank = block.width(); + Nd4jLong* outShapeInfo = nullptr; + ALLOCATE(outShapeInfo, block.workspace(), shape::shapeInfoLength(rank), + Nd4jLong); + outShapeInfo[0] = rank; + for (int i = 1; i <= rank; ++i) + outShapeInfo[i] = (Nd4jLong)shape::length(inputShape->at(i - 1)); + + if (swapFirst2Dims && rank > 1) + math::nd4j_swap(outShapeInfo[1], outShapeInfo[2]); + + auto in = inputShape->at(0); + ShapeUtils::updateStridesAndType(outShapeInfo, in, shape::order(in)); + + auto shapes = SHAPELIST(); + auto resultShape = CONSTANT(outShapeInfo); + shapes->push_back(resultShape); + + for (int i = 2; i <= rank; ++i) { shapes->push_back(resultShape); + } - for(int i = 2; i <= rank; ++i) { - shapes->push_back(resultShape); - } - - return shapes; + return shapes; } -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/minimum.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/minimum.cpp index ef8645d1dbb2..854922fc1c22 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/minimum.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/minimum.cpp @@ -26,66 +26,65 @@ #include namespace sd { - namespace ops { - BROADCASTABLE_OP_IMPL(minimum, 0, 0) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); - - BROADCAST_CHECK_EMPTY(x,y,z); - - auto tZ = BroadcastHelper::broadcastApply(BROADCAST(MinPairwise),x, y, z); - if (tZ == nullptr) - return ND4J_STATUS_KERNEL_FAILURE; - else if (tZ != z) { - OVERWRITE_RESULT(tZ); - } - - return ND4J_STATUS_OK; - } - - DECLARE_TYPES(minimum) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, DataType::ANY) - ->setAllowedOutputTypes(0, DataType::INHERIT); - } - - DECLARE_TYPES(minimum_bp) { - getOpDescriptor() - ->setAllowedInputTypes(DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } - - CUSTOM_OP_IMPL(minimum_bp, 3, 2, false, 0, 0) { - - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto epsNext = INPUT_VARIABLE(2); - - auto gradX = OUTPUT_VARIABLE(0); - auto gradY = OUTPUT_VARIABLE(1); - helpers::minimumBPFunctor(block.launchContext(), x, y, epsNext, gradX, gradY); - return Status::OK(); - } - - DECLARE_SHAPE_FN(minimum_bp) { - auto x = inputShape->at(0); - auto y = inputShape->at(1); - auto e = inputShape->at(2); - - // eps always has shape of x - // grad always has shape of y - - Nd4jLong *shapeE; - Nd4jLong *shapeG; - - COPY_SHAPE(x, shapeE); - COPY_SHAPE(y, shapeG); - - return SHAPELIST(CONSTANT(shapeE), CONSTANT(shapeG)); - } - } +namespace ops { +BROADCASTABLE_OP_IMPL(minimum, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); + + BROADCAST_CHECK_EMPTY(x, y, z); + + auto tZ = BroadcastHelper::broadcastApply(BROADCAST(MinPairwise), x, y, z); + if (tZ == nullptr) + return ND4J_STATUS_KERNEL_FAILURE; + else if (tZ != z) { + OVERWRITE_RESULT(tZ); + } + + return ND4J_STATUS_OK; } +DECLARE_TYPES(minimum) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes(1, DataType::ANY) + ->setAllowedOutputTypes(0, DataType::INHERIT); +} + +DECLARE_TYPES(minimum_bp) { + getOpDescriptor() + ->setAllowedInputTypes(DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} + +CUSTOM_OP_IMPL(minimum_bp, 3, 2, false, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto epsNext = INPUT_VARIABLE(2); + + auto gradX = OUTPUT_VARIABLE(0); + auto gradY = OUTPUT_VARIABLE(1); + helpers::minimumBPFunctor(block.launchContext(), x, y, epsNext, gradX, gradY); + return Status::OK(); +} + +DECLARE_SHAPE_FN(minimum_bp) { + auto x = inputShape->at(0); + auto y = inputShape->at(1); + auto e = inputShape->at(2); + + // eps always has shape of x + // grad always has shape of y + + Nd4jLong *shapeE; + Nd4jLong *shapeG; + + COPY_SHAPE(x, shapeE); + COPY_SHAPE(y, shapeG); + + return SHAPELIST(CONSTANT(shapeE), CONSTANT(shapeG)); +} +} // namespace ops +} // namespace sd + #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/mod.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/mod.cpp index 95c710d170b6..dc88f27e4032 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/mod.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/mod.cpp @@ -21,75 +21,75 @@ #include #if NOT_EXCLUDED(OP_mod) -#include #include +#include namespace sd { - namespace ops { - BROADCASTABLE_OP_IMPL(mod, 0, 0) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); - - BROADCAST_CHECK_EMPTY(x,y,z); - - auto tZ = BroadcastHelper::broadcastApply(BROADCAST(Mod), x, y, z); - if (tZ == nullptr) - return ND4J_STATUS_KERNEL_FAILURE; - else if (tZ != z) { - OVERWRITE_RESULT(tZ); - } - - return Status::OK(); - } - - DECLARE_TYPES(mod) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, DataType::ANY) - ->setAllowedOutputTypes(0, DataType::INHERIT); - } - - DECLARE_TYPES(mod_bp) { - getOpDescriptor() - ->setAllowedInputTypes(DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } - - CUSTOM_OP_IMPL(mod_bp, 3, 2, false, 0, 0) { - // PLEASE NOTE: we're just passing eps down the line here - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto epsNext = INPUT_VARIABLE(2); - - auto gradX = OUTPUT_VARIABLE(0); - auto gradY = OUTPUT_VARIABLE(1); - - gradY->assign(0.0f); - gradX->assign(0.0f); - - return Status::OK(); - } - - DECLARE_SHAPE_FN(mod_bp) { - auto x = inputShape->at(0); - auto y = inputShape->at(1); - auto e = inputShape->at(2); - - // eps always has shape of x - // grad always has shape of y - - Nd4jLong *shapeE; - Nd4jLong *shapeG; - - COPY_SHAPE(x, shapeE); - COPY_SHAPE(y, shapeG); - - auto shapeList = SHAPELIST(CONSTANT(shapeE), CONSTANT(shapeG)); - - return shapeList; - } - } +namespace ops { +BROADCASTABLE_OP_IMPL(mod, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); + + BROADCAST_CHECK_EMPTY(x, y, z); + + auto tZ = BroadcastHelper::broadcastApply(BROADCAST(Mod), x, y, z); + if (tZ == nullptr) + return ND4J_STATUS_KERNEL_FAILURE; + else if (tZ != z) { + OVERWRITE_RESULT(tZ); + } + + return Status::OK(); +} + +DECLARE_TYPES(mod) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes(1, DataType::ANY) + ->setAllowedOutputTypes(0, DataType::INHERIT); +} + +DECLARE_TYPES(mod_bp) { + getOpDescriptor() + ->setAllowedInputTypes(DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} + +CUSTOM_OP_IMPL(mod_bp, 3, 2, false, 0, 0) { + // PLEASE NOTE: we're just passing eps down the line here + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto epsNext = INPUT_VARIABLE(2); + + auto gradX = OUTPUT_VARIABLE(0); + auto gradY = OUTPUT_VARIABLE(1); + + gradY->assign(0.0f); + gradX->assign(0.0f); + + return Status::OK(); +} + +DECLARE_SHAPE_FN(mod_bp) { + auto x = inputShape->at(0); + auto y = inputShape->at(1); + auto e = inputShape->at(2); + + // eps always has shape of x + // grad always has shape of y + + Nd4jLong *shapeE; + Nd4jLong *shapeG; + + COPY_SHAPE(x, shapeE); + COPY_SHAPE(y, shapeG); + + auto shapeList = SHAPELIST(CONSTANT(shapeE), CONSTANT(shapeG)); + + return shapeList; } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/multiply.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/multiply.cpp index b7635c664f22..c44adc830a38 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/multiply.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/multiply.cpp @@ -27,123 +27,132 @@ namespace sd { namespace ops { - BROADCASTABLE_OP_IMPL(multiply, 0, 0) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); - - BROADCAST_CHECK_EMPTY(x,y,z); - - const Nd4jLong* zShapeInfo = nullptr; - const bool areShapesBroadcastable = ShapeUtils::evalBroadcastShapeInfo(x->shapeInfo(), y->shapeInfo(), true, zShapeInfo, block.getWorkspace()); - REQUIRE_TRUE(areShapesBroadcastable, 0, "MULTIPLY OP: the shapes of x %s and y %s are not suitable for broadcast !", ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str()); - - auto tZ = BroadcastHelper::broadcastApply(sd::BroadcastOpsTuple::Multiply(), x, y, z); - if (tZ == nullptr) - return ND4J_STATUS_KERNEL_FAILURE; - else if (tZ != z) - throw std::runtime_error("multiply: result was replaced"); - - return Status::OK(); - } - DECLARE_SYN(Mul, multiply); - - DECLARE_TYPES(multiply) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, DataType::ANY) - ->setAllowedOutputTypes(0, DataType::INHERIT); - } - - DECLARE_TYPES(multiply_bp) { - getOpDescriptor() - ->setAllowedInputTypes(DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } +BROADCASTABLE_OP_IMPL(multiply, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); + + BROADCAST_CHECK_EMPTY(x, y, z); + + const Nd4jLong* zShapeInfo = nullptr; + const bool areShapesBroadcastable = ShapeUtils::evalBroadcastShapeInfo( + x->shapeInfo(), y->shapeInfo(), true, zShapeInfo, block.workspace()); + REQUIRE_TRUE(areShapesBroadcastable, 0, + "MULTIPLY OP: the shapes of x %s and y %s are not suitable for " + "broadcast !", + ShapeUtils::shapeAsString(x).c_str(), + ShapeUtils::shapeAsString(y).c_str()); + + auto tZ = BroadcastHelper::broadcastApply(sd::BroadcastOpsTuple::Multiply(), + x, y, z); + if (tZ == nullptr) + return ND4J_STATUS_KERNEL_FAILURE; + else if (tZ != z) + throw std::runtime_error("multiply: result was replaced"); + + return Status::OK(); +} +DECLARE_SYN(Mul, multiply); + +DECLARE_TYPES(multiply) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes(1, DataType::ANY) + ->setAllowedOutputTypes(0, DataType::INHERIT); +} + +DECLARE_TYPES(multiply_bp) { + getOpDescriptor() + ->setAllowedInputTypes(DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} /////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(multiply_bp, 3, 2, false, 0, 0) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto dLdz = INPUT_VARIABLE(2); - - auto dLdx = OUTPUT_VARIABLE(0); - auto dLdy = OUTPUT_VARIABLE(1); - - const Nd4jLong* dLdzShapeInfo = nullptr; - const bool areShapesBroadcastable = ShapeUtils::evalBroadcastShapeInfo(x->shapeInfo(), y->shapeInfo(), true, dLdzShapeInfo, block.getWorkspace()); - REQUIRE_TRUE(areShapesBroadcastable, 0, "MULTIPLY_BP OP: the shapes of x %s and y %s are not suitable for broadcast !", ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str()); - REQUIRE_TRUE(shape::equalsSoft(dLdz->shapeInfo(), dLdzShapeInfo), 0, "MULTIPLY_BP OP: wrong shape of next epsilon array (dLdOut), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(dLdzShapeInfo).c_str(), ShapeUtils::shapeAsString(dLdz).c_str()); - - const Nd4jLong xLen = x->lengthOf(); - const Nd4jLong yLen = y->lengthOf(); - - if(x->isScalar() && y->isScalar()) { // both are scalars - y->applyPairwiseTransform(pairwise::Multiply, *dLdz, *dLdx); - x->applyPairwiseTransform(pairwise::Multiply, *dLdz, *dLdy); - //dLdx->assign((*y) * (*dLdz)); - //dLdy->assign((*x) * (*dLdz)); - - } - else if(x->isScalar()) { // x is scalar and y is not - dLdx->assign((*y * *dLdz).reduceNumber(reduce::Sum)); - dLdz->applyScalarArr(scalar::Multiply, *x, *dLdy); - //dLdz->applyTrueBroadcast(broadcast::Multiply, x, dLdy, true); - } - else if(y->isScalar()) { // y is scalar and x is not - dLdy->assign((*x * *dLdz).reduceNumber(reduce::Sum)); - dLdz->applyScalarArr(scalar::Multiply, *y, *dLdx); - } - else if(x->isSameShape(y)) { - x->applyPairwiseTransform(pairwise::Multiply, *dLdz, *dLdy); - y->applyPairwiseTransform(pairwise::Multiply, *dLdz, *dLdx); - } - else if (x->isSameShape(dLdz)) { - - auto yTiled = NDArray(dLdz, false, block.launchContext()); - y->tile(yTiled); - std::vector axesForY = ShapeUtils::evalBroadcastBackwardAxis(y->shapeInfo(), dLdz->shapeInfo()); - - dLdy->assign( (*x * *dLdz).reduceAlongDimension(reduce::Sum, axesForY) ); - yTiled.applyPairwiseTransform(pairwise::Multiply, *dLdz, *dLdx); - } - else if (y->isSameShape(dLdz)) { - - auto xTiled = NDArray(dLdz, false, block.launchContext()); - x->tile(xTiled); - std::vector axesForX = ShapeUtils::evalBroadcastBackwardAxis(x->shapeInfo(), dLdz->shapeInfo()); - - dLdx->assign( (*y * *dLdz).reduceAlongDimension(reduce::Sum, axesForX) ); - xTiled.applyPairwiseTransform(pairwise::Multiply, *dLdz, *dLdy); - } - else { - - auto xTiled = NDArray(dLdz, false, block.launchContext()); - auto yTiled = NDArray(dLdz, false, block.launchContext()); - x->tile(xTiled); - y->tile(yTiled); - std::vector axesForX = ShapeUtils::evalBroadcastBackwardAxis(x->shapeInfo(), dLdz->shapeInfo()); - std::vector axesForY = ShapeUtils::evalBroadcastBackwardAxis(y->shapeInfo(), dLdz->shapeInfo()); - - dLdx->assign( (*y * *dLdz).reduceAlongDimension(reduce::Sum, axesForX) ); - dLdy->assign( (*x * *dLdz).reduceAlongDimension(reduce::Sum, axesForY) ); - } - - return Status::OK(); + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto dLdz = INPUT_VARIABLE(2); + + auto dLdx = OUTPUT_VARIABLE(0); + auto dLdy = OUTPUT_VARIABLE(1); + + const Nd4jLong* dLdzShapeInfo = nullptr; + const bool areShapesBroadcastable = ShapeUtils::evalBroadcastShapeInfo( + x->shapeInfo(), y->shapeInfo(), true, dLdzShapeInfo, block.workspace()); + REQUIRE_TRUE(areShapesBroadcastable, 0, + "MULTIPLY_BP OP: the shapes of x %s and y %s are not suitable " + "for broadcast !", + ShapeUtils::shapeAsString(x).c_str(), + ShapeUtils::shapeAsString(y).c_str()); + REQUIRE_TRUE(shape::equalsSoft(dLdz->shapeInfo(), dLdzShapeInfo), 0, + "MULTIPLY_BP OP: wrong shape of next epsilon array (dLdOut), " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(dLdzShapeInfo).c_str(), + ShapeUtils::shapeAsString(dLdz).c_str()); + + const Nd4jLong xLen = x->lengthOf(); + const Nd4jLong yLen = y->lengthOf(); + + if (x->isScalar() && y->isScalar()) { // both are scalars + y->applyPairwiseTransform(pairwise::Multiply, *dLdz, *dLdx); + x->applyPairwiseTransform(pairwise::Multiply, *dLdz, *dLdy); + // dLdx->assign((*y) * (*dLdz)); + // dLdy->assign((*x) * (*dLdz)); + + } else if (x->isScalar()) { // x is scalar and y is not + dLdx->assign((*y * *dLdz).reduceNumber(reduce::Sum)); + dLdz->applyScalarArr(scalar::Multiply, *x, *dLdy); + // dLdz->applyTrueBroadcast(broadcast::Multiply, x, dLdy, true); + } else if (y->isScalar()) { // y is scalar and x is not + dLdy->assign((*x * *dLdz).reduceNumber(reduce::Sum)); + dLdz->applyScalarArr(scalar::Multiply, *y, *dLdx); + } else if (x->isSameShape(y)) { + x->applyPairwiseTransform(pairwise::Multiply, *dLdz, *dLdy); + y->applyPairwiseTransform(pairwise::Multiply, *dLdz, *dLdx); + } else if (x->isSameShape(dLdz)) { + auto yTiled = NDArray(dLdz, false, block.launchContext()); + y->tile(yTiled); + std::vector axesForY = ShapeUtils::evalBroadcastBackwardAxis( + y->shapeInfo(), dLdz->shapeInfo()); + + dLdy->assign((*x * *dLdz).reduceAlongDimension(reduce::Sum, axesForY)); + yTiled.applyPairwiseTransform(pairwise::Multiply, *dLdz, *dLdx); + } else if (y->isSameShape(dLdz)) { + auto xTiled = NDArray(dLdz, false, block.launchContext()); + x->tile(xTiled); + std::vector axesForX = ShapeUtils::evalBroadcastBackwardAxis( + x->shapeInfo(), dLdz->shapeInfo()); + + dLdx->assign((*y * *dLdz).reduceAlongDimension(reduce::Sum, axesForX)); + xTiled.applyPairwiseTransform(pairwise::Multiply, *dLdz, *dLdy); + } else { + auto xTiled = NDArray(dLdz, false, block.launchContext()); + auto yTiled = NDArray(dLdz, false, block.launchContext()); + x->tile(xTiled); + y->tile(yTiled); + std::vector axesForX = ShapeUtils::evalBroadcastBackwardAxis( + x->shapeInfo(), dLdz->shapeInfo()); + std::vector axesForY = ShapeUtils::evalBroadcastBackwardAxis( + y->shapeInfo(), dLdz->shapeInfo()); + + dLdx->assign((*y * *dLdz).reduceAlongDimension(reduce::Sum, axesForX)); + dLdy->assign((*x * *dLdz).reduceAlongDimension(reduce::Sum, axesForY)); + } + + return Status::OK(); } DECLARE_SHAPE_FN(multiply_bp) { + auto xShapeInfo = inputShape->at(0); + auto yShapeInfo = inputShape->at(1); - auto xShapeInfo = inputShape->at(0); - auto yShapeInfo = inputShape->at(1); - - Nd4jLong *dLdxShapeInfo = nullptr; - Nd4jLong *dLdyShapeInfo = nullptr; + Nd4jLong* dLdxShapeInfo = nullptr; + Nd4jLong* dLdyShapeInfo = nullptr; - COPY_SHAPE(xShapeInfo, dLdxShapeInfo); - COPY_SHAPE(yShapeInfo, dLdyShapeInfo); + COPY_SHAPE(xShapeInfo, dLdxShapeInfo); + COPY_SHAPE(yShapeInfo, dLdyShapeInfo); - return SHAPELIST(CONSTANT(dLdxShapeInfo), CONSTANT(dLdyShapeInfo)); + return SHAPELIST(CONSTANT(dLdxShapeInfo), CONSTANT(dLdyShapeInfo)); } /* CUSTOM_OP_IMPL(multiply_bp, 3, 2, false, 0, 0) { @@ -194,20 +203,20 @@ DECLARE_SHAPE_FN(multiply_bp) { preX->tileToShape(targetShape); preY->tileToShape(targetShape); - auto axisX = ShapeUtils::evalBroadcastBackwardAxis(x->shapeInfo(), epsNext->shapeInfo()); - auto axisY = ShapeUtils::evalBroadcastBackwardAxis(y->shapeInfo(), epsNext->shapeInfo()); + auto axisX = + ShapeUtils::evalBroadcastBackwardAxis(x->shapeInfo(), epsNext->shapeInfo()); + auto axisY = + ShapeUtils::evalBroadcastBackwardAxis(y->shapeInfo(), epsNext->shapeInfo()); if (axisX.size() > 0) { - auto sum = preX->template reduceAlongDimension>(axisX); - gradX->assign(sum); - delete sum; + auto sum = preX->template + reduceAlongDimension>(axisX); gradX->assign(sum); delete sum; } else gradX->assign(preX); if (axisY.size() > 0) { - auto sum = preY->template reduceAlongDimension>(axisY); - gradY->assign(sum); - delete sum; + auto sum = preY->template + reduceAlongDimension>(axisY); gradY->assign(sum); delete sum; } else gradY->assign(preY); @@ -220,7 +229,7 @@ DECLARE_SHAPE_FN(multiply_bp) { } */ -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/not_equals.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/not_equals.cpp index 9e2609f9d9c1..befc126347e6 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/not_equals.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/not_equals.cpp @@ -21,29 +21,30 @@ #include namespace sd { - namespace ops { - BROADCASTABLE_BOOL_OP_IMPL(not_equals, 0, 0) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); +namespace ops { +BROADCASTABLE_BOOL_OP_IMPL(not_equals, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); - BROADCAST_CHECK_EMPTY(x,y,z); + BROADCAST_CHECK_EMPTY(x, y, z); - auto tZ = BroadcastHelper::broadcastApply(BROADCAST_BOOL(NotEqualTo), x, y, z); - if (tZ == nullptr) - return ND4J_STATUS_KERNEL_FAILURE; - else if (tZ != z) { - OVERWRITE_RESULT(tZ); - } + auto tZ = + BroadcastHelper::broadcastApply(BROADCAST_BOOL(NotEqualTo), x, y, z); + if (tZ == nullptr) + return ND4J_STATUS_KERNEL_FAILURE; + else if (tZ != z) { + OVERWRITE_RESULT(tZ); + } - return Status::OK(); - } + return Status::OK(); +} - DECLARE_TYPES(not_equals) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, DataType::ANY) - ->setAllowedOutputTypes(0, DataType::BOOL); - } - } -} \ No newline at end of file +DECLARE_TYPES(not_equals) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes(1, DataType::ANY) + ->setAllowedOutputTypes(0, DataType::BOOL); +} +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/percentile.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/percentile.cpp index f5fbd4b185a7..6cb552d1be26 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/percentile.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/percentile.cpp @@ -24,68 +24,98 @@ #include #include - namespace sd { -namespace ops { - -CUSTOM_OP_IMPL(percentile, 1, 1, false, 1, -2) { - auto input = INPUT_VARIABLE(0); // tensor with rank > 0 - auto output = OUTPUT_VARIABLE(0); // [bS, oD, oH, oW, iC] (NDHWC) or [bS, iC, oD, oH, oW] (NCDHW) - - const auto q = T_ARG(0); // percentile - const int interpolation = block.getTArguments()->size() > 1 ? T_ARG(1) : 2.; // 0-"lower", 1-"higher", 2-"nearest"(default) - const int keepDims = block.getTArguments()->size() > 2 ? T_ARG(2) : 0.; // false is default - - const int axisArrRank = block.getIArguments()->size(); - const int inputArrRank = input->rankOf(); - - REQUIRE_TRUE(inputArrRank > 0, 0, "PERCENTILE OP: rank of input array must be positive (>0), but got %i instead !", inputArrRank); - REQUIRE_TRUE(0.f <= q && q <= 100.f, 0, "PERCENTILE OP: percentile parameter must be within [0, 100] range, but got %f instead !", q); - REQUIRE_TRUE(interpolation == 0 || interpolation == 1 || interpolation == 2, 0, "PERCENTILE OP: the correct values for interpolation parameter are 0, 1, 2, but got %i instead !", interpolation); - REQUIRE_TRUE(axisArrRank <= inputArrRank, 0, "PERCENTILE OP: the rank of axis array must be <= rank of input array, but got %i and %i correspondingly !", axisArrRank, inputArrRank); - - for(int i = 0; i < axisArrRank; ++i) { - int dim = INT_ARG(i) >= 0 ? INT_ARG(i) : INT_ARG(i) + inputArrRank; - REQUIRE_TRUE(dim < inputArrRank, 0, "PERCENTILE OP: element (dimension) of axis array at position %i is >= rank of input array (%i >= %i), which is unacceptable !", i, dim, inputArrRank); - } - - std::vector axises = *block.getIArguments(); - helpers::percentile(block.launchContext(), *input, *output, axises, q, interpolation); - - return Status::OK(); +namespace ops { + +CUSTOM_OP_IMPL(percentile, 1, 1, false, 1, -2) { + auto input = INPUT_VARIABLE(0); // tensor with rank > 0 + auto output = OUTPUT_VARIABLE( + 0); // [bS, oD, oH, oW, iC] (NDHWC) or [bS, iC, oD, oH, oW] (NCDHW) + + const auto q = T_ARG(0); // percentile + const int interpolation = + block.numT() > 1 ? T_ARG(1) + : 2.; // 0-"lower", 1-"higher", 2-"nearest"(default) + const int keepDims = block.numT() > 2 ? T_ARG(2) : 0.; // false is default + + const int axisArrRank = block.numI(); + const int inputArrRank = input->rankOf(); + + REQUIRE_TRUE(inputArrRank > 0, 0, + "PERCENTILE OP: rank of input array must be positive (>0), but " + "got %i instead !", + inputArrRank); + REQUIRE_TRUE(0.f <= q && q <= 100.f, 0, + "PERCENTILE OP: percentile parameter must be within [0, 100] " + "range, but got %f instead !", + q); + REQUIRE_TRUE(interpolation == 0 || interpolation == 1 || interpolation == 2, + 0, + "PERCENTILE OP: the correct values for interpolation parameter " + "are 0, 1, 2, but got %i instead !", + interpolation); + REQUIRE_TRUE(axisArrRank <= inputArrRank, 0, + "PERCENTILE OP: the rank of axis array must be <= rank of input " + "array, but got %i and %i correspondingly !", + axisArrRank, inputArrRank); + + for (int i = 0; i < axisArrRank; ++i) { + int dim = INT_ARG(i) >= 0 ? INT_ARG(i) : INT_ARG(i) + inputArrRank; + REQUIRE_TRUE( + dim < inputArrRank, 0, + "PERCENTILE OP: element (dimension) of axis array at position %i is >= " + "rank of input array (%i >= %i), which is unacceptable !", + i, dim, inputArrRank); + } + + auto axises = block.getIArguments(); + helpers::percentile(block.launchContext(), *input, *output, axises, q, + interpolation); + + return Status::OK(); } - DECLARE_TYPES(percentile) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedOutputTypes(0, DataType::INHERIT) - ->setSameMode(true); - } - +DECLARE_TYPES(percentile) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedOutputTypes(0, DataType::INHERIT) + ->setSameMode(true); +} DECLARE_SHAPE_FN(percentile) { - auto inputShapeInfo = inputShape->at(0); - const int keepDims = block.getTArguments()->size() > 2 ? T_ARG(2) : 0.; // false is default - - const int axisArrRank = block.getIArguments()->size(); - const int inputArrRank = inputShapeInfo[0]; - - REQUIRE_TRUE(inputArrRank > 0, 0, "PERCENTILE OP: rank of input array must be positive (>0), but got %i instead !", inputArrRank); - REQUIRE_TRUE(axisArrRank <= inputArrRank, 0, "PERCENTILE OP: the rank of axis array must be <= rank of input array, but got %i and %i correspondingly !", axisArrRank, inputArrRank); - - for(int i = 0; i < axisArrRank; ++i) { - int dim = INT_ARG(i) >= 0 ? INT_ARG(i) : INT_ARG(i) + inputArrRank; - REQUIRE_TRUE(dim < inputArrRank, 0, "PERCENTILE OP: element (dimension) of axis array at position %i is >= rank of input array (%i >= %i), which is unacceptable !", i, dim, inputArrRank); - } - - std::vector axises = *block.getIArguments(); - auto outputShapeInfo = ShapeUtils::evalReduceShapeInfo(shape::order(inputShapeInfo), axises, inputShapeInfo, keepDims, false, block.getWorkspace()); - - return SHAPELIST(outputShapeInfo); + auto inputShapeInfo = inputShape->at(0); + const int keepDims = block.numT() > 2 ? T_ARG(2) : 0.; // false is default + + const int axisArrRank = block.numI(); + const int inputArrRank = inputShapeInfo[0]; + + REQUIRE_TRUE(inputArrRank > 0, 0, + "PERCENTILE OP: rank of input array must be positive (>0), but " + "got %i instead !", + inputArrRank); + REQUIRE_TRUE(axisArrRank <= inputArrRank, 0, + "PERCENTILE OP: the rank of axis array must be <= rank of input " + "array, but got %i and %i correspondingly !", + axisArrRank, inputArrRank); + + for (int i = 0; i < axisArrRank; ++i) { + int dim = INT_ARG(i) >= 0 ? INT_ARG(i) : INT_ARG(i) + inputArrRank; + REQUIRE_TRUE( + dim < inputArrRank, 0, + "PERCENTILE OP: element (dimension) of axis array at position %i is >= " + "rank of input array (%i >= %i), which is unacceptable !", + i, dim, inputArrRank); + } + + auto axises = block.getIArguments(); + auto outputShapeInfo = ShapeUtils::evalReduceShapeInfo( + shape::order(inputShapeInfo), axises, inputShapeInfo, keepDims, false, + block.workspace()); + + return SHAPELIST(outputShapeInfo); } - -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/pow.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/pow.cpp index 8ceb61e18fd0..85283fac93a7 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/pow.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/pow.cpp @@ -22,106 +22,112 @@ #include #if NOT_EXCLUDED(OP_Pow) -#include #include +#include namespace sd { namespace ops { - BROADCASTABLE_OP_IMPL(Pow, 0, 0) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); - - BROADCAST_CHECK_EMPTY(x,y,z); - - //REQUIRE_TRUE(!y->isB(), 0, "Pairwise OP: you can't divide by bool array!"); - - auto tZ = BroadcastHelper::broadcastApply({scalar::Pow, pairwise::Pow, broadcast::Pow}, x, y, z); - if (tZ == nullptr) - return ND4J_STATUS_KERNEL_FAILURE; - else if (tZ != z) { - OVERWRITE_RESULT(tZ); - } - - return Status::OK(); - } - - DECLARE_TYPES(Pow) { - getOpDescriptor() - ->setAllowedInputTypes(0, {ALL_FLOATS, ALL_INTS}) - ->setAllowedInputTypes(1, {ALL_FLOATS, ALL_INTS}) - ->setAllowedOutputTypes(0, {ALL_FLOATS, ALL_INTS}); - } - - CUSTOM_OP_IMPL(Pow_bp, 3, 2, false, 0, 0) { - - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto dLdz = INPUT_VARIABLE(2); - - auto dLdx = OUTPUT_VARIABLE(0); - auto dLdy = OUTPUT_VARIABLE(1); - - const Nd4jLong* dLdzShapeInfo = nullptr; - const bool areShapesBroadcastable = ShapeUtils::evalBroadcastShapeInfo(x->shapeInfo(), y->shapeInfo(), true, dLdzShapeInfo, block.getWorkspace()); - REQUIRE_TRUE(areShapesBroadcastable, 0, "POW_BP OP: the shapes of x %s" - " and y %s are not suitable for broadcast !", - ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str()); - REQUIRE_TRUE(shape::equalsSoft(dLdz->shapeInfo(), dLdzShapeInfo), 0, +BROADCASTABLE_OP_IMPL(Pow, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); + + BROADCAST_CHECK_EMPTY(x, y, z); + + // REQUIRE_TRUE(!y->isB(), 0, "Pairwise OP: you can't divide by bool array!"); + + auto tZ = BroadcastHelper::broadcastApply( + {scalar::Pow, pairwise::Pow, broadcast::Pow}, x, y, z); + if (tZ == nullptr) + return ND4J_STATUS_KERNEL_FAILURE; + else if (tZ != z) { + OVERWRITE_RESULT(tZ); + } + + return Status::OK(); +} + +DECLARE_TYPES(Pow) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_FLOATS, ALL_INTS}) + ->setAllowedInputTypes(1, {ALL_FLOATS, ALL_INTS}) + ->setAllowedOutputTypes(0, {ALL_FLOATS, ALL_INTS}); +} + +CUSTOM_OP_IMPL(Pow_bp, 3, 2, false, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto dLdz = INPUT_VARIABLE(2); + + auto dLdx = OUTPUT_VARIABLE(0); + auto dLdy = OUTPUT_VARIABLE(1); + + const Nd4jLong* dLdzShapeInfo = nullptr; + const bool areShapesBroadcastable = ShapeUtils::evalBroadcastShapeInfo( + x->shapeInfo(), y->shapeInfo(), true, dLdzShapeInfo, block.workspace()); + REQUIRE_TRUE(areShapesBroadcastable, 0, + "POW_BP OP: the shapes of x %s" + " and y %s are not suitable for broadcast !", + ShapeUtils::shapeAsString(x).c_str(), + ShapeUtils::shapeAsString(y).c_str()); + REQUIRE_TRUE(shape::equalsSoft(dLdz->shapeInfo(), dLdzShapeInfo), 0, "POW_BP OP: wrong shape of next epsilon array (dLdOut)," - " expected is %s, but got %s instead !", - ShapeUtils::shapeAsString(dLdzShapeInfo).c_str(), ShapeUtils::shapeAsString(dLdz).c_str()); - - // dL/dy = x^y * log(x) * dL/dz - auto temp = x->applyTrueBroadcast(BroadcastOpsTuple::Pow(), *y); // a = x^y - x->applyTransform(transform::Log, *dLdx); // b = log(x) - dLdx->applyScalar(sd::scalar::ReplaceNans, 0, *dLdx); - temp *= *dLdx; // c = b*a - temp *= *dLdz; // dL/dy = c * dL/dz - if (dLdy->isSameShape(*dLdz)) { - dLdy->assign(temp); - } - else { - std::vector axesForY = ShapeUtils::evalBroadcastBackwardAxis(y->shapeInfo(), dLdz->shapeInfo()); - dLdy->assign(temp.reduceAlongDimension(reduce::Sum, axesForY)); // dL/dy = sum(c * dL/dz) - } - - // dL/dx = y*x^(y-1) * dL/dz - x->applyTrueBroadcast(BroadcastOpsTuple::PowDerivative(), *y, temp); // a = y*x^(y-1) - temp *= *dLdz; // dLdx = a*dL/dz - - if (dLdx->isSameShape(*dLdz)) { - dLdx->assign(temp); // dLdx = a*dL/dz - } - else { - std::vector axesForX = ShapeUtils::evalBroadcastBackwardAxis(x->shapeInfo(), dLdz->shapeInfo()); - dLdx->assign(temp.reduceAlongDimension(reduce::Sum, axesForX)); // dLdx = a*dL/dz - } - - return Status::OK(); - } - - DECLARE_SHAPE_FN(Pow_bp) { - - auto xShapeInfo = inputShape->at(0); - auto yShapeInfo = inputShape->at(1); - - Nd4jLong* dLdxShapeInfo = nullptr; - Nd4jLong* dLdyShapeInfo = nullptr; - - COPY_SHAPE(xShapeInfo, dLdxShapeInfo); - COPY_SHAPE(yShapeInfo, dLdyShapeInfo); - - return SHAPELIST(CONSTANT(dLdxShapeInfo), CONSTANT(dLdyShapeInfo)); - } - - DECLARE_TYPES(Pow_bp) { - getOpDescriptor() - ->setAllowedInputTypes({ ALL_FLOATS, ALL_INTS }) - ->setAllowedOutputTypes({ ALL_FLOATS }); - } + " expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(dLdzShapeInfo).c_str(), + ShapeUtils::shapeAsString(dLdz).c_str()); + + // dL/dy = x^y * log(x) * dL/dz + auto temp = x->applyTrueBroadcast(BroadcastOpsTuple::Pow(), *y); // a = x^y + x->applyTransform(transform::Log, *dLdx); // b = log(x) + dLdx->applyScalar(sd::scalar::ReplaceNans, 0, *dLdx); + temp *= *dLdx; // c = b*a + temp *= *dLdz; // dL/dy = c * dL/dz + if (dLdy->isSameShape(*dLdz)) { + dLdy->assign(temp); + } else { + std::vector axesForY = ShapeUtils::evalBroadcastBackwardAxis( + y->shapeInfo(), dLdz->shapeInfo()); + dLdy->assign(temp.reduceAlongDimension( + reduce::Sum, axesForY)); // dL/dy = sum(c * dL/dz) + } + + // dL/dx = y*x^(y-1) * dL/dz + x->applyTrueBroadcast(BroadcastOpsTuple::PowDerivative(), *y, + temp); // a = y*x^(y-1) + temp *= *dLdz; // dLdx = a*dL/dz + if (dLdx->isSameShape(*dLdz)) { + dLdx->assign(temp); // dLdx = a*dL/dz + } else { + std::vector axesForX = ShapeUtils::evalBroadcastBackwardAxis( + x->shapeInfo(), dLdz->shapeInfo()); + dLdx->assign( + temp.reduceAlongDimension(reduce::Sum, axesForX)); // dLdx = a*dL/dz + } + + return Status::OK(); } + +DECLARE_SHAPE_FN(Pow_bp) { + auto xShapeInfo = inputShape->at(0); + auto yShapeInfo = inputShape->at(1); + + Nd4jLong* dLdxShapeInfo = nullptr; + Nd4jLong* dLdyShapeInfo = nullptr; + + COPY_SHAPE(xShapeInfo, dLdxShapeInfo); + COPY_SHAPE(yShapeInfo, dLdyShapeInfo); + + return SHAPELIST(CONSTANT(dLdxShapeInfo), CONSTANT(dLdyShapeInfo)); } +DECLARE_TYPES(Pow_bp) { + getOpDescriptor() + ->setAllowedInputTypes({ALL_FLOATS, ALL_INTS}) + ->setAllowedOutputTypes({ALL_FLOATS}); +} + +} // namespace ops +} // namespace sd + #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/realdiv.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/realdiv.cpp index 3691ffb55e6f..e8ca145ea3ed 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/realdiv.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/realdiv.cpp @@ -21,119 +21,122 @@ #include #if NOT_EXCLUDED(OP_realdiv) -#include #include +#include namespace sd { - namespace ops { - BROADCASTABLE_OP_IMPL(realdiv, 0, 0) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); - - BROADCAST_CHECK_EMPTY(x,y,z); - - auto tZ = BroadcastHelper::broadcastApply(BroadcastOpsTuple::Divide(), x, y, z); - if (tZ == nullptr) - return ND4J_STATUS_KERNEL_FAILURE; - else if (tZ != z) { - OVERWRITE_RESULT(tZ); - } - - return Status::OK(); - } - DECLARE_SYN(RealDiv, realdiv); - - DECLARE_TYPES(realdiv) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, DataType::ANY) - ->setAllowedOutputTypes(0, {DataType::FLOAT32, DataType::HALF, DataType::DOUBLE}); - } - - DECLARE_TYPES(realdiv_bp) { - getOpDescriptor() - ->setAllowedInputTypes(DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } +namespace ops { +BROADCASTABLE_OP_IMPL(realdiv, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); + + BROADCAST_CHECK_EMPTY(x, y, z); + + auto tZ = + BroadcastHelper::broadcastApply(BroadcastOpsTuple::Divide(), x, y, z); + if (tZ == nullptr) + return ND4J_STATUS_KERNEL_FAILURE; + else if (tZ != z) { + OVERWRITE_RESULT(tZ); + } + + return Status::OK(); +} +DECLARE_SYN(RealDiv, realdiv); + +DECLARE_TYPES(realdiv) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes(1, DataType::ANY) + ->setAllowedOutputTypes( + 0, {DataType::FLOAT32, DataType::HALF, DataType::DOUBLE}); +} - CUSTOM_OP_IMPL(realdiv_bp, 3, 2, false, 0, 0) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto epsNext = INPUT_VARIABLE(2); +DECLARE_TYPES(realdiv_bp) { + getOpDescriptor() + ->setAllowedInputTypes(DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} - auto gradX = OUTPUT_VARIABLE(0); - auto gradY = OUTPUT_VARIABLE(1); +CUSTOM_OP_IMPL(realdiv_bp, 3, 2, false, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto epsNext = INPUT_VARIABLE(2); + auto gradX = OUTPUT_VARIABLE(0); + auto gradY = OUTPUT_VARIABLE(1); - if (x->isSameShape(y)) { - // PWT case case + if (x->isSameShape(y)) { + // PWT case case - // X gradient - //epsNext->applyPairwiseLambda(y, lambdaX, gradX); - epsNext->applyPairwiseTransform(pairwise::Divide, *y, *gradX); + // X gradient + // epsNext->applyPairwiseLambda(y, lambdaX, gradX); + epsNext->applyPairwiseTransform(pairwise::Divide, *y, *gradX); - // Y gradient - //epsNext->applyTriplewiseLambda(x, y, lambdaY, gradY); + // Y gradient + // epsNext->applyTriplewiseLambda(x, y, lambdaY, gradY); - gradY->assign((*epsNext) * -(*x) / ((*y) * (*y))); + gradY->assign((*epsNext) * -(*x) / ((*y) * (*y))); - } else if (y->isScalar()) { - // scalar case + } else if (y->isScalar()) { + // scalar case - auto tmp = epsNext->reduceNumber(reduce::Sum); - auto tmpX = x->reduceNumber(reduce::Sum); - gradY->assign(tmp * -tmpX / ((*y) * (*y))); + auto tmp = epsNext->reduceNumber(reduce::Sum); + auto tmpX = x->reduceNumber(reduce::Sum); + gradY->assign(tmp * -tmpX / ((*y) * (*y))); - //epsNext->applyLambda(lambdaS, gradX); - epsNext->applyScalarArr(scalar::Divide, *y, *gradX); - } else { - // broadcast case + // epsNext->applyLambda(lambdaS, gradX); + epsNext->applyScalarArr(scalar::Divide, *y, *gradX); + } else { + // broadcast case - auto preX = *epsNext / *y; + auto preX = *epsNext / *y; - NDArray negX(*x); - x->applyTransform(transform::Neg, negX); - auto preY = *epsNext * negX / ((*y) * (*y)); + NDArray negX(*x); + x->applyTransform(transform::Neg, negX); + auto preY = *epsNext * negX / ((*y) * (*y)); - auto axisX = ShapeUtils::evalBroadcastBackwardAxis(x->shapeInfo(), epsNext->shapeInfo()); - auto axisY = ShapeUtils::evalBroadcastBackwardAxis(y->shapeInfo(), epsNext->shapeInfo()); + auto axisX = ShapeUtils::evalBroadcastBackwardAxis(x->shapeInfo(), + epsNext->shapeInfo()); + auto axisY = ShapeUtils::evalBroadcastBackwardAxis(y->shapeInfo(), + epsNext->shapeInfo()); - if (axisX.size() > 0) { - auto sum = preX.reduceAlongDimension(reduce::Sum, axisX); - gradX->assign(sum); - } else - gradX->assign(preX); + if (axisX.size() > 0) { + auto sum = preX.reduceAlongDimension(reduce::Sum, axisX); + gradX->assign(sum); + } else + gradX->assign(preX); - if (axisY.size() > 0) { - auto sum = preY.reduceAlongDimension(reduce::Sum, axisY); - gradY->assign(sum); - } else - gradY->assign(preY); - } + if (axisY.size() > 0) { + auto sum = preY.reduceAlongDimension(reduce::Sum, axisY); + gradY->assign(sum); + } else + gradY->assign(preY); + } - return Status::OK(); - } + return Status::OK(); +} - DECLARE_SHAPE_FN(realdiv_bp) { - auto x = inputShape->at(0); - auto y = inputShape->at(1); - auto e = inputShape->at(2); +DECLARE_SHAPE_FN(realdiv_bp) { + auto x = inputShape->at(0); + auto y = inputShape->at(1); + auto e = inputShape->at(2); - // eps always has shape of x - // grad always has shape of y + // eps always has shape of x + // grad always has shape of y - Nd4jLong *shapeE; - Nd4jLong *shapeG; + Nd4jLong *shapeE; + Nd4jLong *shapeG; - COPY_SHAPE(x, shapeE); - COPY_SHAPE(y, shapeG); + COPY_SHAPE(x, shapeE); + COPY_SHAPE(y, shapeG); - auto shapeList = SHAPELIST(CONSTANT(shapeE), CONSTANT(shapeG)); + auto shapeList = SHAPELIST(CONSTANT(shapeE), CONSTANT(shapeG)); - return shapeList; - } - } + return shapeList; } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/reverse_divide.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/reverse_divide.cpp index 0b6ea7d2a00a..6497a9568298 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/reverse_divide.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/reverse_divide.cpp @@ -21,108 +21,111 @@ #include #if NOT_EXCLUDED(OP_reversedivide) -#include #include +#include namespace sd { - namespace ops { - BROADCASTABLE_OP_IMPL(reversedivide, 0, 0) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); - - BROADCAST_CHECK_EMPTY(x,y,z); - - REQUIRE_TRUE(!x->isB(), 0, "REVERSEDIVIDE OP: you can't divide by bool array!"); - x->applyTrueBroadcast(BROADCAST(ReverseDivide), *y, *z, true); - - return Status::OK(); - } - DECLARE_SYN(RDiv, reversedivide); - - DECLARE_TYPES(reversedivide) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, DataType::ANY) - ->setAllowedOutputTypes(0, DataType::INHERIT); - } - - DECLARE_TYPES(reversedivide_bp) { - getOpDescriptor() - ->setAllowedInputTypes(DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } - - CUSTOM_OP_IMPL(reversedivide_bp, 3, 2, false, 0, 0) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto epsNext = INPUT_VARIABLE(2); - - auto gradX = OUTPUT_VARIABLE(0); - auto gradY = OUTPUT_VARIABLE(1); - - if (x->isSameShape(y)) { - // PWT case case - - // X gradient - //epsNext->applyTriplewiseLambda(x, y, lambdaX, gradX); - gradX->assign((*epsNext) * (*y) / ((*x) * (*x))); - gradX->applyTransform(transform::Neg, *gradX); - // Y gradient - //epsNext->applyPairwiseLambda(x, lambdaY, gradY); - gradY->assign((*epsNext) / (*x)); - } else if (y->isScalar()) { - // scalar case - auto tmp = epsNext->reduceNumber(reduce::Sum); - auto tmpX = x->reduceNumber(reduce::Sum); - gradY->assign(tmp / tmpX); - - gradX->assign((*epsNext) * (*y) / ((*x) * (*x))); - gradX->applyTransform(transform::Neg, *gradX); - } else { - // broadcast case - - auto preY = (*epsNext) / (*x); - - auto preX = *epsNext * (*y) / ((*x) * (*x)); - preX.applyTransform(transform::Neg, preX); - - auto axisX = ShapeUtils::evalBroadcastBackwardAxis(x->shapeInfo(), epsNext->shapeInfo()); - auto axisY = ShapeUtils::evalBroadcastBackwardAxis(y->shapeInfo(), epsNext->shapeInfo()); - - if (axisX.size() > 0) { - auto sum = preX.reduceAlongDimension(reduce::Sum, axisX); - gradX->assign(sum); - } else - gradX->assign(preX); - - if (axisY.size() > 0) { - auto sum = preY.reduceAlongDimension(reduce::Sum, axisY); - gradY->assign(sum); - } else - gradY->assign(preY); - } - - return Status::OK(); - } - - DECLARE_SHAPE_FN(reversedivide_bp) { - auto x = inputShape->at(0); - auto y = inputShape->at(1); - auto e = inputShape->at(2); - - // eps always has shape of x - // grad always has shape of y - - Nd4jLong *shapeE; - Nd4jLong *shapeG; - - COPY_SHAPE(x, shapeE); - COPY_SHAPE(y, shapeG); - - return SHAPELIST(CONSTANT(shapeE), CONSTANT(shapeG)); - } - } +namespace ops { +BROADCASTABLE_OP_IMPL(reversedivide, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); + + BROADCAST_CHECK_EMPTY(x, y, z); + + REQUIRE_TRUE(!x->isB(), 0, + "REVERSEDIVIDE OP: you can't divide by bool array!"); + x->applyTrueBroadcast(BROADCAST(ReverseDivide), *y, *z, true); + + return Status::OK(); +} +DECLARE_SYN(RDiv, reversedivide); + +DECLARE_TYPES(reversedivide) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes(1, DataType::ANY) + ->setAllowedOutputTypes(0, DataType::INHERIT); +} + +DECLARE_TYPES(reversedivide_bp) { + getOpDescriptor() + ->setAllowedInputTypes(DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} + +CUSTOM_OP_IMPL(reversedivide_bp, 3, 2, false, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto epsNext = INPUT_VARIABLE(2); + + auto gradX = OUTPUT_VARIABLE(0); + auto gradY = OUTPUT_VARIABLE(1); + + if (x->isSameShape(y)) { + // PWT case case + + // X gradient + // epsNext->applyTriplewiseLambda(x, y, lambdaX, gradX); + gradX->assign((*epsNext) * (*y) / ((*x) * (*x))); + gradX->applyTransform(transform::Neg, *gradX); + // Y gradient + // epsNext->applyPairwiseLambda(x, lambdaY, gradY); + gradY->assign((*epsNext) / (*x)); + } else if (y->isScalar()) { + // scalar case + auto tmp = epsNext->reduceNumber(reduce::Sum); + auto tmpX = x->reduceNumber(reduce::Sum); + gradY->assign(tmp / tmpX); + + gradX->assign((*epsNext) * (*y) / ((*x) * (*x))); + gradX->applyTransform(transform::Neg, *gradX); + } else { + // broadcast case + + auto preY = (*epsNext) / (*x); + + auto preX = *epsNext * (*y) / ((*x) * (*x)); + preX.applyTransform(transform::Neg, preX); + + auto axisX = ShapeUtils::evalBroadcastBackwardAxis(x->shapeInfo(), + epsNext->shapeInfo()); + auto axisY = ShapeUtils::evalBroadcastBackwardAxis(y->shapeInfo(), + epsNext->shapeInfo()); + + if (axisX.size() > 0) { + auto sum = preX.reduceAlongDimension(reduce::Sum, axisX); + gradX->assign(sum); + } else + gradX->assign(preX); + + if (axisY.size() > 0) { + auto sum = preY.reduceAlongDimension(reduce::Sum, axisY); + gradY->assign(sum); + } else + gradY->assign(preY); + } + + return Status::OK(); +} + +DECLARE_SHAPE_FN(reversedivide_bp) { + auto x = inputShape->at(0); + auto y = inputShape->at(1); + auto e = inputShape->at(2); + + // eps always has shape of x + // grad always has shape of y + + Nd4jLong *shapeE; + Nd4jLong *shapeG; + + COPY_SHAPE(x, shapeE); + COPY_SHAPE(y, shapeG); + + return SHAPELIST(CONSTANT(shapeE), CONSTANT(shapeG)); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/reverse_mod.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/reverse_mod.cpp index bb25fada6cbf..4203ab9f0b1b 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/reverse_mod.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/reverse_mod.cpp @@ -21,76 +21,75 @@ #include #if NOT_EXCLUDED(OP_reversemod) -#include #include +#include namespace sd { - namespace ops { - BROADCASTABLE_OP_IMPL(reversemod, 0, 0) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); - - BROADCAST_CHECK_EMPTY(x,y,z); - - auto tZ = BroadcastHelper::broadcastApply(BROADCAST(ReverseMod), x, y, z); - if (tZ == nullptr) - return ND4J_STATUS_KERNEL_FAILURE; - else if (tZ != z) { - OVERWRITE_RESULT(tZ); - } - - return Status::OK(); - } - - DECLARE_TYPES(reversemod) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, DataType::ANY) - ->setAllowedOutputTypes(0, DataType::INHERIT); - } +namespace ops { +BROADCASTABLE_OP_IMPL(reversemod, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); + + BROADCAST_CHECK_EMPTY(x, y, z); + + auto tZ = BroadcastHelper::broadcastApply(BROADCAST(ReverseMod), x, y, z); + if (tZ == nullptr) + return ND4J_STATUS_KERNEL_FAILURE; + else if (tZ != z) { + OVERWRITE_RESULT(tZ); + } + + return Status::OK(); +} - DECLARE_TYPES(reversemod_bp) { - getOpDescriptor() - ->setAllowedInputTypes(DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } +DECLARE_TYPES(reversemod) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes(1, DataType::ANY) + ->setAllowedOutputTypes(0, DataType::INHERIT); +} +DECLARE_TYPES(reversemod_bp) { + getOpDescriptor() + ->setAllowedInputTypes(DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} - CUSTOM_OP_IMPL(reversemod_bp, 3, 2, false, 0, 0) { - // PLEASE NOTE: we're just passing eps down the line here - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto epsNext = INPUT_VARIABLE(2); +CUSTOM_OP_IMPL(reversemod_bp, 3, 2, false, 0, 0) { + // PLEASE NOTE: we're just passing eps down the line here + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto epsNext = INPUT_VARIABLE(2); - auto gradX = OUTPUT_VARIABLE(0); - auto gradY = OUTPUT_VARIABLE(1); + auto gradX = OUTPUT_VARIABLE(0); + auto gradY = OUTPUT_VARIABLE(1); - gradY->assign(0.0f); - gradX->assign(0.0f); + gradY->assign(0.0f); + gradX->assign(0.0f); - return Status::OK(); - } + return Status::OK(); +} - DECLARE_SHAPE_FN(reversemod_bp) { - auto x = inputShape->at(0); - auto y = inputShape->at(1); - auto e = inputShape->at(2); +DECLARE_SHAPE_FN(reversemod_bp) { + auto x = inputShape->at(0); + auto y = inputShape->at(1); + auto e = inputShape->at(2); - // eps always has shape of x - // grad always has shape of y + // eps always has shape of x + // grad always has shape of y - Nd4jLong *shapeE; - Nd4jLong *shapeG; + Nd4jLong *shapeE; + Nd4jLong *shapeG; - COPY_SHAPE(x, shapeE); - COPY_SHAPE(y, shapeG); + COPY_SHAPE(x, shapeE); + COPY_SHAPE(y, shapeG); - auto shapeList = SHAPELIST(CONSTANT(shapeE), CONSTANT(shapeG)); + auto shapeList = SHAPELIST(CONSTANT(shapeE), CONSTANT(shapeG)); - return shapeList; - } - } + return shapeList; } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/reverse_subtract.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/reverse_subtract.cpp index 5d33c7cea09d..c3cbe232c95d 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/reverse_subtract.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/reverse_subtract.cpp @@ -21,101 +21,104 @@ #include #if NOT_EXCLUDED(OP_reversesubtract) -#include #include +#include namespace sd { - namespace ops { - BROADCASTABLE_OP_IMPL(reversesubtract, 0, 0) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); - - BROADCAST_CHECK_EMPTY(x,y,z); - - auto tZ = BroadcastHelper::broadcastApply(BROADCAST(ReverseSubtract), x, y, z); - if (tZ == nullptr) - return ND4J_STATUS_KERNEL_FAILURE; - else if (tZ != z) { - OVERWRITE_RESULT(tZ); - } - - return Status::OK(); - } - DECLARE_SYN(RSub, reversesubtract); - - DECLARE_TYPES(reversesubtract) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, DataType::ANY) - ->setAllowedOutputTypes(0, DataType::INHERIT); - } - - CUSTOM_OP_IMPL(reversesubtract_bp, 3, 2, false, 0, 0) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto epsNext = INPUT_VARIABLE(2); - - auto gradX = OUTPUT_VARIABLE(0); - auto gradY = OUTPUT_VARIABLE(1); - - if (x->isSameShape(y)) { - // PWT case case - epsNext->applyTransform(transform::Neg, *gradX); - gradY->assign(epsNext); - } else if (y->isScalar()) { - // scalar case - auto tmp = epsNext->reduceNumber(reduce::Sum); - gradY->assign(tmp); - epsNext->applyTransform(transform::Neg, *gradX); - } else { - // broadcastable - auto axisX = ShapeUtils::evalBroadcastBackwardAxis(x->shapeInfo(), epsNext->shapeInfo()); - auto axisY = ShapeUtils::evalBroadcastBackwardAxis(y->shapeInfo(), epsNext->shapeInfo()); - - if (axisX.size() > 0) { - auto sum = epsNext->reduceAlongDimension(reduce::Sum, axisX); - sum.applyTransform(transform::Neg, *gradX); - } else { - epsNext->applyTransform(transform::Neg, *gradX); - } - - if (axisY.size() > 0) { - auto sum = epsNext->reduceAlongDimension(reduce::Sum, axisY); - gradY->assign(sum); - } else { - gradY->assign(epsNext); - } - } - - return Status::OK(); - } - - DECLARE_SHAPE_FN(reversesubtract_bp) { - auto x = inputShape->at(0); - auto y = inputShape->at(1); - auto e = inputShape->at(2); - - // eps always has shape of x - // grad always has shape of y - - Nd4jLong *shapeE; - Nd4jLong *shapeG; - - COPY_SHAPE(x, shapeE); - COPY_SHAPE(y, shapeG); - - auto shapeList = SHAPELIST(CONSTANT(shapeE), CONSTANT(shapeG)); - - return shapeList; - } - - DECLARE_TYPES(reversesubtract_bp) { - getOpDescriptor() - ->setAllowedInputTypes(DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } +namespace ops { +BROADCASTABLE_OP_IMPL(reversesubtract, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); + + BROADCAST_CHECK_EMPTY(x, y, z); + + auto tZ = + BroadcastHelper::broadcastApply(BROADCAST(ReverseSubtract), x, y, z); + if (tZ == nullptr) + return ND4J_STATUS_KERNEL_FAILURE; + else if (tZ != z) { + OVERWRITE_RESULT(tZ); + } + + return Status::OK(); +} +DECLARE_SYN(RSub, reversesubtract); + +DECLARE_TYPES(reversesubtract) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes(1, DataType::ANY) + ->setAllowedOutputTypes(0, DataType::INHERIT); +} + +CUSTOM_OP_IMPL(reversesubtract_bp, 3, 2, false, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto epsNext = INPUT_VARIABLE(2); + + auto gradX = OUTPUT_VARIABLE(0); + auto gradY = OUTPUT_VARIABLE(1); + + if (x->isSameShape(y)) { + // PWT case case + epsNext->applyTransform(transform::Neg, *gradX); + gradY->assign(epsNext); + } else if (y->isScalar()) { + // scalar case + auto tmp = epsNext->reduceNumber(reduce::Sum); + gradY->assign(tmp); + epsNext->applyTransform(transform::Neg, *gradX); + } else { + // broadcastable + auto axisX = ShapeUtils::evalBroadcastBackwardAxis(x->shapeInfo(), + epsNext->shapeInfo()); + auto axisY = ShapeUtils::evalBroadcastBackwardAxis(y->shapeInfo(), + epsNext->shapeInfo()); + + if (axisX.size() > 0) { + auto sum = epsNext->reduceAlongDimension(reduce::Sum, axisX); + sum.applyTransform(transform::Neg, *gradX); + } else { + epsNext->applyTransform(transform::Neg, *gradX); + } + + if (axisY.size() > 0) { + auto sum = epsNext->reduceAlongDimension(reduce::Sum, axisY); + gradY->assign(sum); + } else { + gradY->assign(epsNext); } + } + + return Status::OK(); +} + +DECLARE_SHAPE_FN(reversesubtract_bp) { + auto x = inputShape->at(0); + auto y = inputShape->at(1); + auto e = inputShape->at(2); + + // eps always has shape of x + // grad always has shape of y + + Nd4jLong *shapeE; + Nd4jLong *shapeG; + + COPY_SHAPE(x, shapeE); + COPY_SHAPE(y, shapeG); + + auto shapeList = SHAPELIST(CONSTANT(shapeE), CONSTANT(shapeG)); + + return shapeList; +} + +DECLARE_TYPES(reversesubtract_bp) { + getOpDescriptor() + ->setAllowedInputTypes(DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/squared_subtract.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/squared_subtract.cpp index 6f5482512227..0b619d707a35 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/squared_subtract.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/squared_subtract.cpp @@ -21,137 +21,138 @@ #include #if NOT_EXCLUDED(OP_squaredsubtract) -#include #include +#include namespace sd { - namespace ops { - BROADCASTABLE_OP_IMPL(squaredsubtract, 0, 0) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); - - BROADCAST_CHECK_EMPTY(x,y,z); - - auto tZ = BroadcastHelper::broadcastApply(BROADCAST(SquaredSubtract), x, y, z); - if (tZ == nullptr) - return ND4J_STATUS_KERNEL_FAILURE; - else if (tZ != z) { - OVERWRITE_RESULT(tZ); - } - - return Status::OK(); - } - DECLARE_SYN(squareddifference, squaredsubtract); - - DECLARE_TYPES(squaredsubtract) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, DataType::ANY) - ->setAllowedOutputTypes(0, DataType::INHERIT); - } - - CUSTOM_OP_IMPL(squaredsubtract_bp, 3, 2, false, 0, 0) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto epsNext = INPUT_VARIABLE(2); - - auto gradX = OUTPUT_VARIABLE(0); - auto gradY = OUTPUT_VARIABLE(1); - - /* - auto lambdaX = LAMBDA_TTT(_e, _x, _y) { - return _e * (T) 2.0 * (_x - _y) ; - }; - - auto lambdaY = LAMBDA_TTT(_e, _x, _y) { - return _e * (T) 2.0 * (_y - _x); - }; - */ - - auto ts = NDArrayFactory::create(x->dataType(), 2, block.launchContext()); - - - if (x->isSameShape(y)) { - // PWT case case - - // X gradient - //epsNext->applyTriplewiseLambda(x, y, lambdaX, gradX); - gradX->assign((*epsNext) * ts * ((*x) - (*y))); - - // Y gradient - //epsNext->applyTriplewiseLambda(x, y, lambdaY, gradY); - gradY->assign((*epsNext) * ts * ((*y) - (*x))); - - } else if (y->isScalar()) { - // scalar case - auto tmpX = x->reduceNumber(reduce::Sum); - gradY->assign(tmpX); - - //epsNext->applyPairwiseLambda(x, lambdaS, gradX); - gradX->assign((*epsNext) * ts * ((*x) - (*y))); - } else { - // broadcast case - - auto preX = x->dup(); - auto preY = y->dup(); - - auto targetShape = epsNext->getShapeAsVector(); - - preX.tileToShape(targetShape, preX); - preY.tileToShape(targetShape, preY); - - - //epsNext->applyTriplewiseLambda(x, y, lambdaX, preX); - //epsNext->applyTriplewiseLambda(x, y, lambdaY, preY); - auto resX = (*epsNext) * ts * ((*x) - (*y)); - preX.assign(resX); - auto resY = (*epsNext) * ts * ((*y) - (*x)); - preY.assign(resY); - - auto axisX = ShapeUtils::evalBroadcastBackwardAxis(x->shapeInfo(), epsNext->shapeInfo()); - auto axisY = ShapeUtils::evalBroadcastBackwardAxis(y->shapeInfo(), epsNext->shapeInfo()); - - if (axisX.size() > 0) { - auto sum = preX.reduceAlongDimension(reduce::Sum, axisX); - gradX->assign(sum); - } else - gradX->assign(preX); - - if (axisY.size() > 0) { - auto sum = preY.reduceAlongDimension(reduce::Sum, axisY); - gradY->assign(sum); - } else - gradY->assign(preY); - } +namespace ops { +BROADCASTABLE_OP_IMPL(squaredsubtract, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); + + BROADCAST_CHECK_EMPTY(x, y, z); + + auto tZ = + BroadcastHelper::broadcastApply(BROADCAST(SquaredSubtract), x, y, z); + if (tZ == nullptr) + return ND4J_STATUS_KERNEL_FAILURE; + else if (tZ != z) { + OVERWRITE_RESULT(tZ); + } + + return Status::OK(); +} +DECLARE_SYN(squareddifference, squaredsubtract); - return Status::OK(); - } +DECLARE_TYPES(squaredsubtract) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes(1, DataType::ANY) + ->setAllowedOutputTypes(0, DataType::INHERIT); +} - DECLARE_SHAPE_FN(squaredsubtract_bp) { - auto x = inputShape->at(0); - auto y = inputShape->at(1); - auto e = inputShape->at(2); +CUSTOM_OP_IMPL(squaredsubtract_bp, 3, 2, false, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto epsNext = INPUT_VARIABLE(2); + + auto gradX = OUTPUT_VARIABLE(0); + auto gradY = OUTPUT_VARIABLE(1); + + /* + auto lambdaX = LAMBDA_TTT(_e, _x, _y) { + return _e * (T) 2.0 * (_x - _y) ; + }; + + auto lambdaY = LAMBDA_TTT(_e, _x, _y) { + return _e * (T) 2.0 * (_y - _x); + }; + */ + + auto ts = NDArrayFactory::create(x->dataType(), 2, block.launchContext()); + + if (x->isSameShape(y)) { + // PWT case case + + // X gradient + // epsNext->applyTriplewiseLambda(x, y, lambdaX, gradX); + gradX->assign((*epsNext) * ts * ((*x) - (*y))); + + // Y gradient + // epsNext->applyTriplewiseLambda(x, y, lambdaY, gradY); + gradY->assign((*epsNext) * ts * ((*y) - (*x))); + + } else if (y->isScalar()) { + // scalar case + auto tmpX = x->reduceNumber(reduce::Sum); + gradY->assign(tmpX); + + // epsNext->applyPairwiseLambda(x, lambdaS, gradX); + gradX->assign((*epsNext) * ts * ((*x) - (*y))); + } else { + // broadcast case + + auto preX = x->dup(); + auto preY = y->dup(); + + auto targetShape = epsNext->getShapeAsVector(); + + preX.tileToShape(targetShape, preX); + preY.tileToShape(targetShape, preY); + + // epsNext->applyTriplewiseLambda(x, y, lambdaX, preX); + // epsNext->applyTriplewiseLambda(x, y, lambdaY, preY); + auto resX = (*epsNext) * ts * ((*x) - (*y)); + preX.assign(resX); + auto resY = (*epsNext) * ts * ((*y) - (*x)); + preY.assign(resY); + + auto axisX = ShapeUtils::evalBroadcastBackwardAxis(x->shapeInfo(), + epsNext->shapeInfo()); + auto axisY = ShapeUtils::evalBroadcastBackwardAxis(y->shapeInfo(), + epsNext->shapeInfo()); + + if (axisX.size() > 0) { + auto sum = preX.reduceAlongDimension(reduce::Sum, axisX); + gradX->assign(sum); + } else + gradX->assign(preX); + + if (axisY.size() > 0) { + auto sum = preY.reduceAlongDimension(reduce::Sum, axisY); + gradY->assign(sum); + } else + gradY->assign(preY); + } + + return Status::OK(); +} - // eps always has shape of x - // grad always has shape of y +DECLARE_SHAPE_FN(squaredsubtract_bp) { + auto x = inputShape->at(0); + auto y = inputShape->at(1); + auto e = inputShape->at(2); - Nd4jLong *shapeE; - Nd4jLong *shapeG; + // eps always has shape of x + // grad always has shape of y - COPY_SHAPE(x, shapeE); - COPY_SHAPE(y, shapeG); + Nd4jLong *shapeE; + Nd4jLong *shapeG; - return SHAPELIST(CONSTANT(shapeE), CONSTANT(shapeG)); - } + COPY_SHAPE(x, shapeE); + COPY_SHAPE(y, shapeG); - DECLARE_TYPES(squaredsubtract_bp) { - getOpDescriptor() - ->setAllowedInputTypes(DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } + return SHAPELIST(CONSTANT(shapeE), CONSTANT(shapeG)); +} - } +DECLARE_TYPES(squaredsubtract_bp) { + getOpDescriptor() + ->setAllowedInputTypes(DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } +} // namespace ops +} // namespace sd + #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/subtract.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/subtract.cpp index fac1c5dfaed9..1b0d45440794 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/subtract.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/subtract.cpp @@ -21,102 +21,104 @@ #include #if NOT_EXCLUDED(OP_subtract) -#include #include +#include namespace sd { - namespace ops { - BROADCASTABLE_OP_IMPL(subtract, 0, 0) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); - - BROADCAST_CHECK_EMPTY(x,y,z); - - auto tZ = BroadcastHelper::broadcastApply(BroadcastOpsTuple::Subtract(), x, y, z); - if (tZ == nullptr) - return ND4J_STATUS_KERNEL_FAILURE; - else if (tZ != z) { - OVERWRITE_RESULT(tZ); - } - - return Status::OK(); - } - DECLARE_SYN(Sub, subtract); - DECLARE_SYN(sub, subtract); - - DECLARE_TYPES(subtract) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, DataType::ANY) - ->setAllowedOutputTypes(0, DataType::INHERIT); - } - - CUSTOM_OP_IMPL(subtract_bp, 3, 2, false, 0, 0) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto epsNext = INPUT_VARIABLE(2); - - auto gradX = OUTPUT_VARIABLE(0); - auto gradY = OUTPUT_VARIABLE(1); - - if (x->isSameShape(y)) { - // PWT case case - epsNext->applyTransform(transform::Neg, *gradY); - gradX->assign(epsNext); - } else if (y->isScalar()) { - // scalar case - auto tmp = epsNext->reduceNumber(reduce::Sum); - gradY->assign(-tmp); - gradX->assign(epsNext); - } else { - // broadcastable - auto axisX = ShapeUtils::evalBroadcastBackwardAxis(x->shapeInfo(), epsNext->shapeInfo()); - auto axisY = ShapeUtils::evalBroadcastBackwardAxis(y->shapeInfo(), epsNext->shapeInfo()); - - if (axisX.size() > 0) { - auto sum = epsNext->reduceAlongDimension(reduce::Sum, axisX); - gradX->assign(sum); - } else - gradX->assign(epsNext); - - if (axisY.size() > 0) { - auto sum = epsNext->reduceAlongDimension(reduce::Sum, axisY); - sum.applyTransform(transform::Neg, *gradY); - } else { - epsNext->applyTransform(transform::Neg, *gradY); - } - } - - return Status::OK(); - } - - DECLARE_TYPES(subtract_bp) { - getOpDescriptor() - ->setAllowedInputTypes(DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } - - - DECLARE_SHAPE_FN(subtract_bp) { - auto x = inputShape->at(0); - auto y = inputShape->at(1); - auto e = inputShape->at(2); - - // eps always has shape of x - // grad always has shape of y - - Nd4jLong *shapeE; - Nd4jLong *shapeG; - - COPY_SHAPE(x, shapeE); - COPY_SHAPE(y, shapeG); - - auto shapeList = SHAPELIST(CONSTANT(shapeE), CONSTANT(shapeG)); - - return shapeList; - } +namespace ops { +BROADCASTABLE_OP_IMPL(subtract, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); + + BROADCAST_CHECK_EMPTY(x, y, z); + + auto tZ = + BroadcastHelper::broadcastApply(BroadcastOpsTuple::Subtract(), x, y, z); + if (tZ == nullptr) + return ND4J_STATUS_KERNEL_FAILURE; + else if (tZ != z) { + OVERWRITE_RESULT(tZ); + } + + return Status::OK(); +} +DECLARE_SYN(Sub, subtract); +DECLARE_SYN(sub, subtract); + +DECLARE_TYPES(subtract) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes(1, DataType::ANY) + ->setAllowedOutputTypes(0, DataType::INHERIT); +} + +CUSTOM_OP_IMPL(subtract_bp, 3, 2, false, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto epsNext = INPUT_VARIABLE(2); + + auto gradX = OUTPUT_VARIABLE(0); + auto gradY = OUTPUT_VARIABLE(1); + + if (x->isSameShape(y)) { + // PWT case case + epsNext->applyTransform(transform::Neg, *gradY); + gradX->assign(epsNext); + } else if (y->isScalar()) { + // scalar case + auto tmp = epsNext->reduceNumber(reduce::Sum); + gradY->assign(-tmp); + gradX->assign(epsNext); + } else { + // broadcastable + auto axisX = ShapeUtils::evalBroadcastBackwardAxis(x->shapeInfo(), + epsNext->shapeInfo()); + auto axisY = ShapeUtils::evalBroadcastBackwardAxis(y->shapeInfo(), + epsNext->shapeInfo()); + + if (axisX.size() > 0) { + auto sum = epsNext->reduceAlongDimension(reduce::Sum, axisX); + gradX->assign(sum); + } else + gradX->assign(epsNext); + + if (axisY.size() > 0) { + auto sum = epsNext->reduceAlongDimension(reduce::Sum, axisY); + sum.applyTransform(transform::Neg, *gradY); + } else { + epsNext->applyTransform(transform::Neg, *gradY); } + } + + return Status::OK(); +} + +DECLARE_TYPES(subtract_bp) { + getOpDescriptor() + ->setAllowedInputTypes(DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} + +DECLARE_SHAPE_FN(subtract_bp) { + auto x = inputShape->at(0); + auto y = inputShape->at(1); + auto e = inputShape->at(2); + + // eps always has shape of x + // grad always has shape of y + + Nd4jLong *shapeE; + Nd4jLong *shapeG; + + COPY_SHAPE(x, shapeE); + COPY_SHAPE(y, shapeG); + + auto shapeList = SHAPELIST(CONSTANT(shapeE), CONSTANT(shapeG)); + + return shapeList; } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/truncatediv.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/truncatediv.cpp index 60900a5d9644..dbd86e537210 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/truncatediv.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/truncatediv.cpp @@ -22,29 +22,29 @@ #include namespace sd { - namespace ops { - BROADCASTABLE_OP_IMPL(truncatediv, 0, 0) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); +namespace ops { +BROADCASTABLE_OP_IMPL(truncatediv, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); - BROADCAST_CHECK_EMPTY(x,y,z); + BROADCAST_CHECK_EMPTY(x, y, z); - auto tZ = BroadcastHelper::broadcastApply(BROADCAST(TruncateDiv), x, y, z); - if (tZ == nullptr) - return ND4J_STATUS_KERNEL_FAILURE; - else if (tZ != z) { - OVERWRITE_RESULT(tZ); - } + auto tZ = BroadcastHelper::broadcastApply(BROADCAST(TruncateDiv), x, y, z); + if (tZ == nullptr) + return ND4J_STATUS_KERNEL_FAILURE; + else if (tZ != z) { + OVERWRITE_RESULT(tZ); + } - return Status::OK(); - } + return Status::OK(); +} - DECLARE_TYPES(truncatediv) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, DataType::ANY) - ->setAllowedOutputTypes(0, DataType::INHERIT); - } - } -} \ No newline at end of file +DECLARE_TYPES(truncatediv) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes(1, DataType::ANY) + ->setAllowedOutputTypes(0, DataType::INHERIT); +} +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/compat/compat_sparse_to_dense.cpp b/libnd4j/include/ops/declarable/generic/compat/compat_sparse_to_dense.cpp index a2dcd6b1471b..384eba6e19da 100644 --- a/libnd4j/include/ops/declarable/generic/compat/compat_sparse_to_dense.cpp +++ b/libnd4j/include/ops/declarable/generic/compat/compat_sparse_to_dense.cpp @@ -25,49 +25,52 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(compat_sparse_to_dense, 4, 1, false, 0, 0) { - auto indices = INPUT_VARIABLE(0); - auto shape = INPUT_VARIABLE(1); - auto values = INPUT_VARIABLE(2); - NDArray *def = nullptr; +namespace ops { +CUSTOM_OP_IMPL(compat_sparse_to_dense, 4, 1, false, 0, 0) { + auto indices = INPUT_VARIABLE(0); + auto shape = INPUT_VARIABLE(1); + auto values = INPUT_VARIABLE(2); + NDArray *def = nullptr; - auto output = OUTPUT_NULLIFIED(0); + auto output = OUTPUT_NULLIFIED(0); - if (block.width() > 3) - def = INPUT_VARIABLE(3); + if (block.width() > 3) def = INPUT_VARIABLE(3); - sd::ops::helpers::compat_sparse_to_dense(*values, *indices, def, *output); + sd::ops::helpers::compat_sparse_to_dense(*values, *indices, def, *output); - return Status::OK(); - }; + return Status::OK(); +}; - DECLARE_SHAPE_FN(compat_sparse_to_dense) { - auto indices = INPUT_VARIABLE(0); - auto shape = INPUT_VARIABLE(1); - auto values = INPUT_VARIABLE(2); +DECLARE_SHAPE_FN(compat_sparse_to_dense) { + auto indices = INPUT_VARIABLE(0); + auto shape = INPUT_VARIABLE(1); + auto values = INPUT_VARIABLE(2); - if (block.width() > 3) { - auto def = INPUT_VARIABLE(3); + if (block.width() > 3) { + auto def = INPUT_VARIABLE(3); - REQUIRE_TRUE(def->dataType() == values->dataType() && def->isScalar(), 0, "compat_sparse_to_dense: default value must be a scalar of the same data type as actual values") - }; + REQUIRE_TRUE(def->dataType() == values->dataType() && def->isScalar(), 0, + "compat_sparse_to_dense: default value must be a scalar of " + "the same data type as actual values") + }; - auto dtype = values->dataType(); + auto dtype = values->dataType(); - // basically output shape is defined by the type of input, and desired shape input - return SHAPELIST(ConstantShapeHelper::getInstance().createShapeInfo(dtype, 'c', shape->getBufferAsVector())); - } + // basically output shape is defined by the type of input, and desired shape + // input + return SHAPELIST(ConstantShapeHelper::getInstance().createShapeInfo( + dtype, 'c', shape->getBufferAsVector())); +} - DECLARE_TYPES(compat_sparse_to_dense) { - getOpDescriptor() - ->setAllowedInputTypes(0, {ALL_INTS}) // indices - ->setAllowedInputTypes(1, {ALL_INTS}) // shape - ->setAllowedInputTypes(2,sd::DataType::ANY) // sparse values - ->setAllowedInputTypes(3,sd::DataType::ANY) // default value - ->setAllowedOutputTypes(sd::DataType::ANY); - } - } +DECLARE_TYPES(compat_sparse_to_dense) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_INTS}) // indices + ->setAllowedInputTypes(1, {ALL_INTS}) // shape + ->setAllowedInputTypes(2, sd::DataType::ANY) // sparse values + ->setAllowedInputTypes(3, sd::DataType::ANY) // default value + ->setAllowedOutputTypes(sd::DataType::ANY); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/compat/compat_string_split.cpp b/libnd4j/include/ops/declarable/generic/compat/compat_string_split.cpp index 00965217872b..6133522d7c5f 100644 --- a/libnd4j/include/ops/declarable/generic/compat/compat_string_split.cpp +++ b/libnd4j/include/ops/declarable/generic/compat/compat_string_split.cpp @@ -14,126 +14,134 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - // - // @author raver119@gmail.com - // +// +// @author raver119@gmail.com +// #include #if NOT_EXCLUDED(OP_split_string) -#include #include +#include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(compat_string_split, 2, 2, false, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto delim = INPUT_VARIABLE(1); +namespace ops { +CUSTOM_OP_IMPL(compat_string_split, 2, 2, false, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto delim = INPUT_VARIABLE(1); + + auto indices = OUTPUT_NULLIFIED(0); + auto values = OUTPUT_VARIABLE(1); + + auto d = delim->e(0); + + input->syncToHost(); + delim->syncToHost(); + + // output rank N+1 wrt input rank + std::vector icoords(input->rankOf()); + + // getting buffer lengths + // FIXME: it'll be bigger, since it'll include delimiters, + auto outputLength = StringUtils::byteLength(*input); - auto indices = OUTPUT_NULLIFIED(0); - auto values = OUTPUT_VARIABLE(1); - - auto d = delim->e(0); - - input->syncToHost(); - delim->syncToHost(); - - // output rank N+1 wrt input rank - std::vector icoords(input->rankOf()); - - // getting buffer lengths - // FIXME: it'll be bigger, since it'll include delimiters, - auto outputLength = StringUtils::byteLength(*input); - - uint64_t ss = 0L; - Nd4jLong ic = 0L; - // loop through each string within tensor - for (auto e = 0L; e < input->lengthOf(); e++) { - // now we should map substring to indices - auto s = input->e(e); - - // getting base index - shape::index2coordsCPU(0, e, input->shapeInfo(), icoords.data()); - - // getting number of substrings - auto cnt = StringUtils::countSubarrays(s.c_str(), s.length(), d.c_str(), d.length()) + 1; - - // filling output indices - for (uint64_t f = 0; f < cnt; f++) { - for (auto v : icoords) - indices->p(ic++, v); - - // last index - indices->p(ic++, f); - } - - ss += cnt; - } - - // process strings now - std::vector strings; - for (auto e = 0L; e < input->lengthOf(); e++) { - auto split = StringUtils::split(input->e(e), d); - - for (const auto& s : split) - strings.emplace_back(s); - } - - // now once we have all strings in single vector time to fill - auto tmp = NDArrayFactory::string({ (Nd4jLong)strings.size() }, strings, input->dataType(), block.launchContext()); - auto blen = StringUtils::byteLength(tmp) + ShapeUtils::stringBufferHeaderRequirements(strings.size()); - - // for CUDA mostly - values->dataBuffer()->allocatePrimary(); - values->dataBuffer()->expand(blen); - memcpy(values->buffer(), tmp.buffer(), blen); - values->tickWriteHost(); - - // special case, for future use - indices->syncToDevice(); - values->syncToDevice(); - - // we have to tick buffers - values->dataBuffer()->writePrimary(); - values->dataBuffer()->readSpecial(); - - return Status::OK(); - }; - - DECLARE_SHAPE_FN(compat_string_split) { - auto input = INPUT_VARIABLE(0); - auto delim = INPUT_VARIABLE(1); - - auto d = delim->e(0); - - // count number of delimiter substrings in all strings within input tensor - uint64_t cnt = 0; - for (auto e = 0L; e < input->lengthOf(); e++) { - // FIXME: bad, not UTF-compatible - auto s = input->e(e); - - // each substring we see in haystack, splits string in two parts. so we should add 1 to the number of subarrays - cnt += StringUtils::countSubarrays(s.c_str(), s.length(), d.c_str(), d.length()) + 1; - } - - // shape calculations - // virtual tensor rank will be N+1, for N rank input array, where data will be located at the biggest dimension - // values tensor is going to be vector always - // indices tensor is going to be vector with length equal to values.length * output rank - - auto valuesShape = ConstantShapeHelper::getInstance().vectorShapeInfo(cnt, sd::DataType::UTF8); - auto indicesShape = ConstantShapeHelper::getInstance().vectorShapeInfo(cnt * (input->rankOf() + 1), sd::DataType::INT64); - - return SHAPELIST(indicesShape, valuesShape); - } + uint64_t ss = 0L; + Nd4jLong ic = 0L; + // loop through each string within tensor + for (auto e = 0L; e < input->lengthOf(); e++) { + // now we should map substring to indices + auto s = input->e(e); - DECLARE_TYPES(compat_string_split) { - getOpDescriptor() - ->setAllowedInputTypes({ ALL_STRINGS }) - ->setAllowedOutputTypes(0, { ALL_INDICES }) - ->setAllowedOutputTypes(1, { ALL_STRINGS }); - } + // getting base index + shape::index2coordsCPU(0, e, input->shapeInfo(), icoords.data()); + + // getting number of substrings + auto cnt = StringUtils::countSubarrays(s.c_str(), s.length(), d.c_str(), + d.length()) + + 1; + + // filling output indices + for (uint64_t f = 0; f < cnt; f++) { + for (auto v : icoords) indices->p(ic++, v); + + // last index + indices->p(ic++, f); } + + ss += cnt; + } + + // process strings now + std::vector strings; + for (auto e = 0L; e < input->lengthOf(); e++) { + auto split = StringUtils::split(input->e(e), d); + + for (const auto& s : split) strings.emplace_back(s); + } + + // now once we have all strings in single vector time to fill + auto tmp = NDArrayFactory::string({(Nd4jLong)strings.size()}, strings, + input->dataType(), block.launchContext()); + auto blen = StringUtils::byteLength(tmp) + + ShapeUtils::stringBufferHeaderRequirements(strings.size()); + + // for CUDA mostly + values->dataBuffer()->allocatePrimary(); + values->dataBuffer()->expand(blen); + memcpy(values->buffer(), tmp.buffer(), blen); + values->tickWriteHost(); + + // special case, for future use + indices->syncToDevice(); + values->syncToDevice(); + + // we have to tick buffers + values->dataBuffer()->writePrimary(); + values->dataBuffer()->readSpecial(); + + return Status::OK(); +}; + +DECLARE_SHAPE_FN(compat_string_split) { + auto input = INPUT_VARIABLE(0); + auto delim = INPUT_VARIABLE(1); + + auto d = delim->e(0); + + // count number of delimiter substrings in all strings within input tensor + uint64_t cnt = 0; + for (auto e = 0L; e < input->lengthOf(); e++) { + // FIXME: bad, not UTF-compatible + auto s = input->e(e); + + // each substring we see in haystack, splits string in two parts. so we + // should add 1 to the number of subarrays + cnt += StringUtils::countSubarrays(s.c_str(), s.length(), d.c_str(), + d.length()) + + 1; + } + + // shape calculations + // virtual tensor rank will be N+1, for N rank input array, where data will be + // located at the biggest dimension values tensor is going to be vector always + // indices tensor is going to be vector with length equal to values.length * + // output rank + + auto valuesShape = ConstantShapeHelper::getInstance().vectorShapeInfo( + cnt, sd::DataType::UTF8); + auto indicesShape = ConstantShapeHelper::getInstance().vectorShapeInfo( + cnt * (input->rankOf() + 1), sd::DataType::INT64); + + return SHAPELIST(indicesShape, valuesShape); +} + +DECLARE_TYPES(compat_string_split) { + getOpDescriptor() + ->setAllowedInputTypes({ALL_STRINGS}) + ->setAllowedOutputTypes(0, {ALL_INDICES}) + ->setAllowedOutputTypes(1, {ALL_STRINGS}); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/compression/bitmap.cpp b/libnd4j/include/ops/declarable/generic/compression/bitmap.cpp index 7e89ce2c039c..59719786f72e 100644 --- a/libnd4j/include/ops/declarable/generic/compression/bitmap.cpp +++ b/libnd4j/include/ops/declarable/generic/compression/bitmap.cpp @@ -18,75 +18,78 @@ // @author George A. Shulinok // -#include #include #include +#include #if NOT_EXCLUDED(OP_decode_bitmap) namespace sd { - namespace ops { - CUSTOM_OP_IMPL(decode_bitmap, 2, 1, true, 0, 0) { - const auto encoded = INPUT_VARIABLE(1); - auto updates = OUTPUT_VARIABLE(0); +namespace ops { +CUSTOM_OP_IMPL(decode_bitmap, 2, 1, true, 0, 0) { + const auto encoded = INPUT_VARIABLE(1); + auto updates = OUTPUT_VARIABLE(0); - helpers::decodeBitmap(block.launchContext(), encoded, updates); - return Status::OK(); - } + helpers::decodeBitmap(block.launchContext(), encoded, updates); + return Status::OK(); +} - DECLARE_SHAPE_FN(decode_bitmap) { - auto weights = INPUT_VARIABLE(0); +DECLARE_SHAPE_FN(decode_bitmap) { + auto weights = INPUT_VARIABLE(0); - return SHAPELIST(weights->shapeInfo()); - } + return SHAPELIST(weights->shapeInfo()); +} - DECLARE_TYPES(decode_bitmap) { - getOpDescriptor() - ->setAllowedInputTypes(0, {ALL_FLOATS}) - ->setAllowedInputTypes(1, DataType::INT32) - ->setAllowedOutputTypes({ALL_FLOATS}); - } - } +DECLARE_TYPES(decode_bitmap) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_FLOATS}) + ->setAllowedInputTypes(1, DataType::INT32) + ->setAllowedOutputTypes({ALL_FLOATS}); } +} // namespace ops +} // namespace sd #endif #if NOT_EXCLUDED(OP_encode_bitmap) namespace sd { - namespace ops { - CUSTOM_OP_IMPL(encode_bitmap, 1, 3, true, 1, 0) { - auto input = INPUT_VARIABLE(0); - auto encoded = OUTPUT_NULLIFIED(1); - auto counter = OUTPUT_NULLIFIED(2); +namespace ops { +CUSTOM_OP_IMPL(encode_bitmap, 1, 3, true, 1, 0) { + auto input = INPUT_VARIABLE(0); + auto encoded = OUTPUT_NULLIFIED(1); + auto counter = OUTPUT_NULLIFIED(2); - float threshold = T_ARG(0); + float threshold = T_ARG(0); - encoded->p(0, (int) input->lengthOf()); - encoded->p(1, (int) input->lengthOf()); - encoded->p(2, reinterpret_cast(&threshold)[0]); - encoded->p(3, 1); // flag for BITMAP_ENCODING + encoded->p(0, (int)input->lengthOf()); + encoded->p(1, (int)input->lengthOf()); + encoded->p(2, reinterpret_cast(&threshold)[0]); + encoded->p(3, 1); // flag for BITMAP_ENCODING - auto result = helpers::encodeBitmap(block.launchContext(), input, encoded, threshold); - counter->p(0, result); - counter->syncToDevice(); + auto result = + helpers::encodeBitmap(block.launchContext(), input, encoded, threshold); + counter->p(0, result); + counter->syncToDevice(); - return Status::OK(); - } + return Status::OK(); +} - DECLARE_SHAPE_FN(encode_bitmap) { - auto input = inputShape->at(0); +DECLARE_SHAPE_FN(encode_bitmap) { + auto input = inputShape->at(0); - auto outputLength = shape::length(input) / 16 + 5; - auto encodedShape = ConstantShapeHelper::getInstance().vectorShapeInfo(outputLength, DataType::INT32); - auto encodedCounter = ConstantShapeHelper::getInstance().scalarShapeInfo(DataType::INT32); - return SHAPELIST(input, encodedShape, encodedCounter); - } + auto outputLength = shape::length(input) / 16 + 5; + auto encodedShape = ConstantShapeHelper::getInstance().vectorShapeInfo( + outputLength, DataType::INT32); + auto encodedCounter = + ConstantShapeHelper::getInstance().scalarShapeInfo(DataType::INT32); + return SHAPELIST(input, encodedShape, encodedCounter); +} - DECLARE_TYPES(encode_bitmap) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes(0, {ALL_FLOATS}) - ->setAllowedInputTypes(1, DataType::INT32) - ->setAllowedInputTypes(2, DataType::INT32); - } - } +DECLARE_TYPES(encode_bitmap) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes(0, {ALL_FLOATS}) + ->setAllowedInputTypes(1, DataType::INT32) + ->setAllowedInputTypes(2, DataType::INT32); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/compression/threshold.cpp b/libnd4j/include/ops/declarable/generic/compression/threshold.cpp index 83836bb8fda5..270b077589c5 100644 --- a/libnd4j/include/ops/declarable/generic/compression/threshold.cpp +++ b/libnd4j/include/ops/declarable/generic/compression/threshold.cpp @@ -18,87 +18,99 @@ // @author raver119@gmail.com // -#include #include #include +#include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(encode_threshold, 1, 2, true, 1, 0) { - auto x = INPUT_VARIABLE(0); - auto updated = OUTPUT_VARIABLE(0); - auto encoded = OUTPUT_NULLIFIED(1); - - float threshold = T_ARG(0); - - REQUIRE_TRUE(x->lengthOf() <= DataTypeUtils::max(), 0, "encode_threshold: gradients array must have length <= MAX_INT"); - REQUIRE_TRUE(encoded->lengthOf() >= 4, 0, "encode_threshold: array for encoded updates can't have less than 4 elements"); -// REQUIRE_TRUE(x->platformBuffer() == updated->platformBuffer(), 0, "encode_threshold: gradients array must be the same at input and output"); - - // filling header bytes - encoded->p(0, encoded->lengthOf() - 4); - encoded->p(1, (int) x->lengthOf()); - encoded->p(2, reinterpret_cast(&threshold)[0]); - encoded->p(3, 0); // flag for FLEXIBLE_ENCODING - - // if there's no updates to process - just skip execution - if (encoded->lengthOf() == 4) - return Status::OK(); - - helpers::thresholdEncode(*x, *encoded, threshold); - - return Status::OK(); - } - - DECLARE_SHAPE_FN(encode_threshold) { - auto x = INPUT_VARIABLE(0); - // we have limit option here - int boundary = block.numI() > 0 ? I_ARG(0) : DataTypeUtils::max(); - float threshold = T_ARG(0); - - REQUIRE_TRUE(boundary >= 0, 0, "encode_threshold: boundary must be positive"); - REQUIRE_TRUE(x->lengthOf() <= DataTypeUtils::max(), 0, "encode_threshold: gradients array must have length <= MAX_INT"); - - // we must calculate number of elements that >= threshold - auto elements = sd::math::nd4j_min(helpers::thresholdEstimate(*x, threshold), boundary); - if (elements < 2) - elements = 0; - - // result array must have 4 additional int elements for header - return SHAPELIST(x->shapeInfo(), sd::ConstantShapeHelper::getInstance().vectorShapeInfo(elements + 4, DataType::INT32)); - } - - DECLARE_TYPES(encode_threshold) { - getOpDescriptor() - ->setAllowedInputTypes(0, {ALL_FLOATS}) - ->setAllowedOutputTypes(0, {ALL_FLOATS}) - ->setAllowedOutputTypes(1, DataType::INT32); - } - - CUSTOM_OP_IMPL(decode_threshold, 2, 1, true, 0, 0) { - auto weights = INPUT_VARIABLE(0); - auto encoded = INPUT_VARIABLE(1); - auto updates = OUTPUT_VARIABLE(0); - - REQUIRE_TRUE(encoded->lengthOf() >= 4, 0, "decode_threshold: encoded array can't have length < 4"); - REQUIRE_TRUE(updates->lengthOf() == encoded->e(1), 0, "decode_threshold: updates array must have length equal to [%i]", encoded->e(1)); - REQUIRE_TRUE(encoded->e(3) == 0, 0, "decode_threshold: encoded array doesn't look like threshold-encoded"); - - helpers::thresholdDecode(*encoded, *updates); - - return Status::OK(); - } - - DECLARE_SHAPE_FN(decode_threshold) { - auto weights = inputShape->at(0); - return SHAPELIST(weights); - } - - DECLARE_TYPES(decode_threshold) { - getOpDescriptor() - ->setAllowedInputTypes(0, {ALL_FLOATS}) - ->setAllowedInputTypes(1, DataType::INT32) - ->setAllowedOutputTypes(0,{ALL_FLOATS}); - } - } -} \ No newline at end of file +namespace ops { +CUSTOM_OP_IMPL(encode_threshold, 1, 2, true, 1, 0) { + auto x = INPUT_VARIABLE(0); + auto updated = OUTPUT_VARIABLE(0); + auto encoded = OUTPUT_NULLIFIED(1); + + float threshold = T_ARG(0); + + REQUIRE_TRUE(x->lengthOf() <= DataTypeUtils::max(), 0, + "encode_threshold: gradients array must have length <= MAX_INT"); + REQUIRE_TRUE(encoded->lengthOf() >= 4, 0, + "encode_threshold: array for encoded updates can't have less " + "than 4 elements"); + // REQUIRE_TRUE(x->platformBuffer() == updated->platformBuffer(), + // 0, "encode_threshold: gradients array must be the same at input + // and output"); + + // filling header bytes + encoded->p(0, encoded->lengthOf() - 4); + encoded->p(1, (int)x->lengthOf()); + encoded->p(2, reinterpret_cast(&threshold)[0]); + encoded->p(3, 0); // flag for FLEXIBLE_ENCODING + + // if there's no updates to process - just skip execution + if (encoded->lengthOf() == 4) return Status::OK(); + + helpers::thresholdEncode(*x, *encoded, threshold); + + return Status::OK(); +} + +DECLARE_SHAPE_FN(encode_threshold) { + auto x = INPUT_VARIABLE(0); + // we have limit option here + int boundary = block.numI() > 0 ? I_ARG(0) : DataTypeUtils::max(); + float threshold = T_ARG(0); + + REQUIRE_TRUE(boundary >= 0, 0, "encode_threshold: boundary must be positive"); + REQUIRE_TRUE(x->lengthOf() <= DataTypeUtils::max(), 0, + "encode_threshold: gradients array must have length <= MAX_INT"); + + // we must calculate number of elements that >= threshold + auto elements = sd::math::nd4j_min( + helpers::thresholdEstimate(*x, threshold), boundary); + if (elements < 2) elements = 0; + + // result array must have 4 additional int elements for header + return SHAPELIST(x->shapeInfo(), + sd::ConstantShapeHelper::getInstance().vectorShapeInfo( + elements + 4, DataType::INT32)); +} + +DECLARE_TYPES(encode_threshold) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_FLOATS}) + ->setAllowedOutputTypes(0, {ALL_FLOATS}) + ->setAllowedOutputTypes(1, DataType::INT32); +} + +CUSTOM_OP_IMPL(decode_threshold, 2, 1, true, 0, 0) { + auto weights = INPUT_VARIABLE(0); + auto encoded = INPUT_VARIABLE(1); + auto updates = OUTPUT_VARIABLE(0); + + REQUIRE_TRUE(encoded->lengthOf() >= 4, 0, + "decode_threshold: encoded array can't have length < 4"); + REQUIRE_TRUE(updates->lengthOf() == encoded->e(1), 0, + "decode_threshold: updates array must have length equal to [%i]", + encoded->e(1)); + REQUIRE_TRUE( + encoded->e(3) == 0, 0, + "decode_threshold: encoded array doesn't look like threshold-encoded"); + + helpers::thresholdDecode(*encoded, *updates); + + return Status::OK(); +} + +DECLARE_SHAPE_FN(decode_threshold) { + auto weights = inputShape->at(0); + return SHAPELIST(weights); +} + +DECLARE_TYPES(decode_threshold) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_FLOATS}) + ->setAllowedInputTypes(1, DataType::INT32) + ->setAllowedOutputTypes(0, {ALL_FLOATS}); +} +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/datatypes/bitcast.cpp b/libnd4j/include/ops/declarable/generic/datatypes/bitcast.cpp index 294406cb8695..8a819dc087fd 100644 --- a/libnd4j/include/ops/declarable/generic/datatypes/bitcast.cpp +++ b/libnd4j/include/ops/declarable/generic/datatypes/bitcast.cpp @@ -25,79 +25,90 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(bitcast, 1, 1, false, 0, 1) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - // when empty - nothing to do - DataType newType = DataTypeUtils::fromInt(INT_ARG(0)); - DataType oldType = input->dataType(); - // correct output shape to conform with output data type - auto inputSize = DataTypeUtils::sizeOf(oldType); - auto outputSize = DataTypeUtils::sizeOf(newType); - auto lastSize = outputSize / inputSize; - if (inputSize < outputSize) { - REQUIRE_TRUE(input->sizeAt(-1) == lastSize, 0, - "BITCAST: %llu > %llu. So last dimension should be %i, but %i given.", inputSize, - outputSize, lastSize, input->sizeAt(-1)); - } - if(input->isEmpty()){ - REQUIRE_TRUE(output->isEmpty(), 0, "BITCAST: If input is empty, output array must also be empty."); - return Status::OK(); - } +namespace ops { +CUSTOM_OP_IMPL(bitcast, 1, 1, false, 0, 1) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + // when empty - nothing to do + DataType newType = DataTypeUtils::fromInt(INT_ARG(0)); + DataType oldType = input->dataType(); + // correct output shape to conform with output data type + auto inputSize = DataTypeUtils::sizeOf(oldType); + auto outputSize = DataTypeUtils::sizeOf(newType); + auto lastSize = outputSize / inputSize; + if (inputSize < outputSize) { + REQUIRE_TRUE( + input->sizeAt(-1) == lastSize, 0, + "BITCAST: %llu > %llu. So last dimension should be %i, but %i given.", + inputSize, outputSize, lastSize, input->sizeAt(-1)); + } + if (input->isEmpty()) { + REQUIRE_TRUE( + output->isEmpty(), 0, + "BITCAST: If input is empty, output array must also be empty."); + return Status::OK(); + } - // just memcpy data - DataBuffer::memcpy(*output->dataBuffer(), *input->dataBuffer()); + // just memcpy data + DataBuffer::memcpy(*output->dataBuffer(), *input->dataBuffer()); - return Status::OK(); - } - DECLARE_SYN(BitCast, bitcast); + return Status::OK(); +} +DECLARE_SYN(BitCast, bitcast); - DECLARE_SHAPE_FN(bitcast) { - auto inShape = inputShape->at(0); - auto inputRank = shape::rank(inShape); - auto it = INT_ARG(0); - DataType newType = DataTypeUtils::fromInt(it); - DataType oldType = ArrayOptions::dataType(inShape); - // correct output shape to conform with output data type - auto inputSize = DataTypeUtils::sizeOf(oldType); - auto outputSize = DataTypeUtils::sizeOf(newType); +DECLARE_SHAPE_FN(bitcast) { + auto inShape = inputShape->at(0); + auto inputRank = shape::rank(inShape); + auto it = INT_ARG(0); + DataType newType = DataTypeUtils::fromInt(it); + DataType oldType = ArrayOptions::dataType(inShape); + // correct output shape to conform with output data type + auto inputSize = DataTypeUtils::sizeOf(oldType); + auto outputSize = DataTypeUtils::sizeOf(newType); - if (shape::length(inShape) == 0) - return SHAPELIST(ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(inShape, newType))); + if (shape::length(inShape) == 0) + return SHAPELIST(ConstantShapeHelper::getInstance().createShapeInfo( + ShapeDescriptor(inShape, newType))); - if (inputSize == outputSize) { - // only type should be changed - return SHAPELIST(ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(inShape, newType))); - } - else if (inputSize > outputSize) { - // range of output increased by 1 with inputSize / outputSize as last dimension - std::vector shapeOf(inputRank + 1); - int i; - for (i = 0; i < inputRank; ++i) { - shapeOf[i] = inShape[i + 1]; - } - shapeOf[i] = inputSize / outputSize; - auto outputShape = ConstantShapeHelper::getInstance().createShapeInfo(newType, shape::order(inShape), shapeOf); - return SHAPELIST(outputShape); - } - REQUIRE_TRUE(shape::sizeAt(inShape, -1) == outputSize / inputSize, 0, "BITCAST: %llu > %llu. So last dimension should be %i, but %i given.", inputSize, outputSize, outputSize / inputSize, shape::sizeAt(inShape, -1)); - std::vector shapeOf(inputRank - 1); + if (inputSize == outputSize) { + // only type should be changed + return SHAPELIST(ConstantShapeHelper::getInstance().createShapeInfo( + ShapeDescriptor(inShape, newType))); + } else if (inputSize > outputSize) { + // range of output increased by 1 with inputSize / outputSize as last + // dimension + std::vector shapeOf(inputRank + 1); + int i; + for (i = 0; i < inputRank; ++i) { + shapeOf[i] = inShape[i + 1]; + } + shapeOf[i] = inputSize / outputSize; + auto outputShape = ConstantShapeHelper::getInstance().createShapeInfo( + newType, shape::order(inShape), shapeOf); + return SHAPELIST(outputShape); + } + REQUIRE_TRUE( + shape::sizeAt(inShape, -1) == outputSize / inputSize, 0, + "BITCAST: %llu > %llu. So last dimension should be %i, but %i given.", + inputSize, outputSize, outputSize / inputSize, + shape::sizeAt(inShape, -1)); + std::vector shapeOf(inputRank - 1); - for (auto i = 0; i < shapeOf.size(); ++i) { - shapeOf[i] = inShape[i + 1]; - } + for (auto i = 0; i < shapeOf.size(); ++i) { + shapeOf[i] = inShape[i + 1]; + } - auto outputShape = ConstantShapeHelper::getInstance().createShapeInfo(newType, shape::order(inShape), shapeOf); - return SHAPELIST(outputShape); - } + auto outputShape = ConstantShapeHelper::getInstance().createShapeInfo( + newType, shape::order(inShape), shapeOf); + return SHAPELIST(outputShape); +} - DECLARE_TYPES(bitcast) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes(sd::DataType::ANY); - } - } +DECLARE_TYPES(bitcast) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes(sd::DataType::ANY); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/datatypes/cast.cpp b/libnd4j/include/ops/declarable/generic/datatypes/cast.cpp index ff071f7a9cde..8ddec422e5b0 100644 --- a/libnd4j/include/ops/declarable/generic/datatypes/cast.cpp +++ b/libnd4j/include/ops/declarable/generic/datatypes/cast.cpp @@ -25,39 +25,40 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(cast, 1, 1, false, 0, 1) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - if(input->isEmpty()){ - REQUIRE_TRUE(output->isEmpty(), 0, "If input is empty, output array must also be empty"); - return Status::OK(); - } - - if (!block.isInplace()) - output->assign(input); - - STORE_RESULT(output); - return Status::OK(); - } - DECLARE_SYN(Cast, cast); - - DECLARE_SHAPE_FN(cast) { - auto inShape = inputShape->at(0); - - auto it = INT_ARG(0); - DataType newType = DataTypeUtils::fromInt(it); - - return SHAPELIST(ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(inShape, newType))); - } - - DECLARE_TYPES(cast) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes(sd::DataType::ANY); - } - } +namespace ops { +CUSTOM_OP_IMPL(cast, 1, 1, false, 0, 1) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + if (input->isEmpty()) { + REQUIRE_TRUE(output->isEmpty(), 0, + "If input is empty, output array must also be empty"); + return Status::OK(); + } + + if (!block.isInplace()) output->assign(input); + + STORE_RESULT(output); + return Status::OK(); +} +DECLARE_SYN(Cast, cast); + +DECLARE_SHAPE_FN(cast) { + auto inShape = inputShape->at(0); + + auto it = INT_ARG(0); + DataType newType = DataTypeUtils::fromInt(it); + + return SHAPELIST(ConstantShapeHelper::getInstance().createShapeInfo( + ShapeDescriptor(inShape, newType))); +} + +DECLARE_TYPES(cast) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes(sd::DataType::ANY); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/datatypes/to_double.cpp b/libnd4j/include/ops/declarable/generic/datatypes/to_double.cpp index 4eae77a5a285..42bbe029c7fe 100644 --- a/libnd4j/include/ops/declarable/generic/datatypes/to_double.cpp +++ b/libnd4j/include/ops/declarable/generic/datatypes/to_double.cpp @@ -24,31 +24,31 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(to_double, 1, 1, true, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); +namespace ops { +CUSTOM_OP_IMPL(to_double, 1, 1, true, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - if (!block.isInplace()) - output->assign(input); + if (!block.isInplace()) output->assign(input); - STORE_RESULT(output); + STORE_RESULT(output); - return Status::OK(); - } - - DECLARE_TYPES(to_double) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes(sd::DataType::DOUBLE); - } + return Status::OK(); +} - DECLARE_SHAPE_FN(to_double) { - auto outShape = ShapeBuilders::copyShapeInfoAndType(inputShape->at(0), DataType::DOUBLE, true, block.workspace()); - return SHAPELIST(CONSTANT(outShape)); - } +DECLARE_TYPES(to_double) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes(sd::DataType::DOUBLE); +} - } +DECLARE_SHAPE_FN(to_double) { + auto outShape = ShapeBuilders::copyShapeInfoAndType( + inputShape->at(0), DataType::DOUBLE, true, block.workspace()); + return SHAPELIST(CONSTANT(outShape)); } +} // namespace ops +} // namespace sd + #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/datatypes/to_float16.cpp b/libnd4j/include/ops/declarable/generic/datatypes/to_float16.cpp index aa8ceb045483..bdf81af3f456 100644 --- a/libnd4j/include/ops/declarable/generic/datatypes/to_float16.cpp +++ b/libnd4j/include/ops/declarable/generic/datatypes/to_float16.cpp @@ -24,31 +24,31 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(to_float16, 1, 1, true, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); +namespace ops { +CUSTOM_OP_IMPL(to_float16, 1, 1, true, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - if (!block.isInplace()) - output->assign(input); + if (!block.isInplace()) output->assign(input); - STORE_RESULT(output); + STORE_RESULT(output); - return Status::OK(); - } - - DECLARE_TYPES(to_float16) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes(sd::DataType::HALF); - } + return Status::OK(); +} - DECLARE_SHAPE_FN(to_float16) { - auto outShape = ShapeBuilders::copyShapeInfoAndType(inputShape->at(0), DataType::HALF, true, block.workspace()); - return SHAPELIST(CONSTANT(outShape)); - } +DECLARE_TYPES(to_float16) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes(sd::DataType::HALF); +} - } +DECLARE_SHAPE_FN(to_float16) { + auto outShape = ShapeBuilders::copyShapeInfoAndType( + inputShape->at(0), DataType::HALF, true, block.workspace()); + return SHAPELIST(CONSTANT(outShape)); } +} // namespace ops +} // namespace sd + #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/datatypes/to_float32.cpp b/libnd4j/include/ops/declarable/generic/datatypes/to_float32.cpp index 23a924f9cb85..2616cce0c1dc 100644 --- a/libnd4j/include/ops/declarable/generic/datatypes/to_float32.cpp +++ b/libnd4j/include/ops/declarable/generic/datatypes/to_float32.cpp @@ -24,31 +24,31 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(to_float32, 1, 1, true, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); +namespace ops { +CUSTOM_OP_IMPL(to_float32, 1, 1, true, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - if (!block.isInplace()) - output->assign(input); + if (!block.isInplace()) output->assign(input); - STORE_RESULT(output); + STORE_RESULT(output); - return Status::OK(); - } - - DECLARE_TYPES(to_float32) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes(sd::DataType::FLOAT32); - } + return Status::OK(); +} - DECLARE_SHAPE_FN(to_float32) { - auto outShape = ShapeBuilders::copyShapeInfoAndType(inputShape->at(0), DataType::FLOAT32, true, block.workspace()); - return SHAPELIST(CONSTANT(outShape)); - } +DECLARE_TYPES(to_float32) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes(sd::DataType::FLOAT32); +} - } +DECLARE_SHAPE_FN(to_float32) { + auto outShape = ShapeBuilders::copyShapeInfoAndType( + inputShape->at(0), DataType::FLOAT32, true, block.workspace()); + return SHAPELIST(CONSTANT(outShape)); } +} // namespace ops +} // namespace sd + #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/datatypes/to_int32.cpp b/libnd4j/include/ops/declarable/generic/datatypes/to_int32.cpp index c28fa6049f1a..565481bb10d7 100644 --- a/libnd4j/include/ops/declarable/generic/datatypes/to_int32.cpp +++ b/libnd4j/include/ops/declarable/generic/datatypes/to_int32.cpp @@ -24,30 +24,30 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(to_int32, 1, 1, true, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - if (!block.isInplace()) - output->assign(input); - - STORE_RESULT(output); - - return Status::OK(); - } - - DECLARE_TYPES(to_int32) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes(sd::DataType::INT32); - } - DECLARE_SHAPE_FN(to_int32) { - auto outShape = ShapeBuilders::copyShapeInfoAndType(inputShape->at(0), DataType::INT32, true, block.workspace()); - return SHAPELIST(CONSTANT(outShape)); - } - - } +namespace ops { +CUSTOM_OP_IMPL(to_int32, 1, 1, true, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + if (!block.isInplace()) output->assign(input); + + STORE_RESULT(output); + + return Status::OK(); } +DECLARE_TYPES(to_int32) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes(sd::DataType::INT32); +} +DECLARE_SHAPE_FN(to_int32) { + auto outShape = ShapeBuilders::copyShapeInfoAndType( + inputShape->at(0), DataType::INT32, true, block.workspace()); + return SHAPELIST(CONSTANT(outShape)); +} + +} // namespace ops +} // namespace sd + #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/datatypes/to_int64.cpp b/libnd4j/include/ops/declarable/generic/datatypes/to_int64.cpp index cb994ccfe5e0..b785d32c4871 100644 --- a/libnd4j/include/ops/declarable/generic/datatypes/to_int64.cpp +++ b/libnd4j/include/ops/declarable/generic/datatypes/to_int64.cpp @@ -24,30 +24,30 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(to_int64, 1, 1, true, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - if (!block.isInplace()) - output->assign(input); - - STORE_RESULT(output); - - return Status::OK(); - } - - DECLARE_TYPES(to_int64) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes(sd::DataType::INT64); - } - DECLARE_SHAPE_FN(to_int64) { - auto outShape = ShapeBuilders::copyShapeInfoAndType(inputShape->at(0), DataType::INT64, true, block.workspace()); - return SHAPELIST(CONSTANT(outShape)); - } - - } +namespace ops { +CUSTOM_OP_IMPL(to_int64, 1, 1, true, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + if (!block.isInplace()) output->assign(input); + + STORE_RESULT(output); + + return Status::OK(); } +DECLARE_TYPES(to_int64) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes(sd::DataType::INT64); +} +DECLARE_SHAPE_FN(to_int64) { + auto outShape = ShapeBuilders::copyShapeInfoAndType( + inputShape->at(0), DataType::INT64, true, block.workspace()); + return SHAPELIST(CONSTANT(outShape)); +} + +} // namespace ops +} // namespace sd + #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/datatypes/to_uint32.cpp b/libnd4j/include/ops/declarable/generic/datatypes/to_uint32.cpp index f62d9cd9b87b..bfff0bdb64dd 100644 --- a/libnd4j/include/ops/declarable/generic/datatypes/to_uint32.cpp +++ b/libnd4j/include/ops/declarable/generic/datatypes/to_uint32.cpp @@ -24,30 +24,30 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(to_uint32, 1, 1, true, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - if (!block.isInplace()) - output->assign(input); - - STORE_RESULT(output); - - return Status::OK(); - } - - DECLARE_TYPES(to_uint32) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes(sd::DataType::INT32); - } - DECLARE_SHAPE_FN(to_uint32) { - auto outShape = ShapeBuilders::copyShapeInfoAndType(inputShape->at(0), DataType::UINT32, true, block.workspace()); - return SHAPELIST(CONSTANT(outShape)); - } - - } +namespace ops { +CUSTOM_OP_IMPL(to_uint32, 1, 1, true, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + if (!block.isInplace()) output->assign(input); + + STORE_RESULT(output); + + return Status::OK(); } +DECLARE_TYPES(to_uint32) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes(sd::DataType::INT32); +} +DECLARE_SHAPE_FN(to_uint32) { + auto outShape = ShapeBuilders::copyShapeInfoAndType( + inputShape->at(0), DataType::UINT32, true, block.workspace()); + return SHAPELIST(CONSTANT(outShape)); +} + +} // namespace ops +} // namespace sd + #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/datatypes/to_uint64.cpp b/libnd4j/include/ops/declarable/generic/datatypes/to_uint64.cpp index dc337ea1bb73..2e3c02a80613 100644 --- a/libnd4j/include/ops/declarable/generic/datatypes/to_uint64.cpp +++ b/libnd4j/include/ops/declarable/generic/datatypes/to_uint64.cpp @@ -24,29 +24,29 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(to_uint64, 1, 1, true, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - if (!block.isInplace()) - output->assign(input); - - STORE_RESULT(output); - - return Status::OK(); - } - - DECLARE_TYPES(to_uint64) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes(sd::DataType::INT8); - } - DECLARE_SHAPE_FN(to_uint64) { - auto outShape = ShapeBuilders::copyShapeInfoAndType(inputShape->at(0), DataType::UINT64, true, block.workspace()); - return SHAPELIST(CONSTANT(outShape)); - } - } +namespace ops { +CUSTOM_OP_IMPL(to_uint64, 1, 1, true, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + if (!block.isInplace()) output->assign(input); + + STORE_RESULT(output); + + return Status::OK(); +} + +DECLARE_TYPES(to_uint64) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes(sd::DataType::INT8); +} +DECLARE_SHAPE_FN(to_uint64) { + auto outShape = ShapeBuilders::copyShapeInfoAndType( + inputShape->at(0), DataType::UINT64, true, block.workspace()); + return SHAPELIST(CONSTANT(outShape)); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/flow/flow_control_ops.cpp b/libnd4j/include/ops/declarable/generic/flow/flow_control_ops.cpp index 108660c7b7db..6b7ddd3f9cbb 100644 --- a/libnd4j/include/ops/declarable/generic/flow/flow_control_ops.cpp +++ b/libnd4j/include/ops/declarable/generic/flow/flow_control_ops.cpp @@ -18,68 +18,68 @@ // Created by raver119 on 13.10.2017. // - #include #include namespace sd { - namespace ops { - /** - * This operation is, basically IF statement - * - * arg_0 is our "signal" - * arg_1 is condition that will determine transition - */ - // TODO: make this op a placeholder too - DIVERGENT_OP_IMPL(Switch, 2, 2, true) { - auto input = INPUT_VARIABLE(0); - auto condition = INPUT_VARIABLE(1); - - // we'll store signal to both ends - //STORE_2_RESULTS(*input, *input); - - // but we'll ensure only one node is active, and other is disabled - if (condition->e(0) == 0) { - block.setBranch(0); - this->storeResult(block, 0, new NDArray(input->dup())); - } else { - block.setBranch(1); - this->storeResult(block, 1, new NDArray(input->dup())); - } +namespace ops { +/** + * This operation is, basically IF statement + * + * arg_0 is our "signal" + * arg_1 is condition that will determine transition + */ +// TODO: make this op a placeholder too +DIVERGENT_OP_IMPL(Switch, 2, 2, true) { + auto input = INPUT_VARIABLE(0); + auto condition = INPUT_VARIABLE(1); - return Status::OK(); - } - DECLARE_SYN(switch, Switch); - DECLARE_SYN(if, Switch); + // we'll store signal to both ends + // STORE_2_RESULTS(*input, *input); + /* + // but we'll ensure only one node is active, and other is disabled + if (condition->e(0) == 0) { + block.setBranch(0); + this->storeResult(block, 0, new NDArray(input->dup())); + } else { + block.setBranch(1); + this->storeResult(block, 1, new NDArray(input->dup())); + } + return Status::OK(); + */ - /** - * This op is a placeholder. - * Actual WHILE implementation is in GraphExecutioner - */ - LOGIC_OP_IMPL(While); - DECLARE_SYN(while, While); + throw std::runtime_error("Switch - Not implemented yet"); +} +DECLARE_SYN(switch, Switch); +DECLARE_SYN(if, Switch); - /** - * This op is a placeholder. - * Actual Scope implementation is in Graph and GraphExecutioner - */ - LOGIC_OP_IMPL(Scope); - DECLARE_SYN(scope, Scope); +/** + * This op is a placeholder. + * Actual WHILE implementation is in GraphExecutioner + */ +LOGIC_OP_IMPL(While); +DECLARE_SYN(while, While); - /** - * This op is a placeholder. - * Actual Conditional implementation is in Graph and GraphExecutioner - */ - LOGIC_OP_IMPL(Conditional); - DECLARE_SYN(cond, Conditional); +/** + * This op is a placeholder. + * Actual Scope implementation is in Graph and GraphExecutioner + */ +LOGIC_OP_IMPL(Scope); +DECLARE_SYN(scope, Scope); +/** + * This op is a placeholder. + * Actual Conditional implementation is in Graph and GraphExecutioner + */ +LOGIC_OP_IMPL(Conditional); +DECLARE_SYN(cond, Conditional); - /** - * This op is a placeholder - * Actual implementation is in LogicReturn class - */ - LOGIC_OP_IMPL(Return); - DECLARE_SYN(return, Return); - } -} \ No newline at end of file +/** + * This op is a placeholder + * Actual implementation is in LogicReturn class + */ +LOGIC_OP_IMPL(Return); +DECLARE_SYN(return, Return); +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/grad/broadcast_gradient_args.cpp b/libnd4j/include/ops/declarable/generic/grad/broadcast_gradient_args.cpp index e4dbcd6d549d..0b74bdeebe58 100644 --- a/libnd4j/include/ops/declarable/generic/grad/broadcast_gradient_args.cpp +++ b/libnd4j/include/ops/declarable/generic/grad/broadcast_gradient_args.cpp @@ -24,23 +24,21 @@ #include namespace sd { - namespace ops { - /** - * PLEASE NOTE: This op is disabled atm, and reserved for future releases. - */ - OP_IMPL(broadcastgradientargs, 2, 2, true) { - - nd4j_printf("BroadcastGradientArgs: Not implemented yet\n", ""); +namespace ops { +/** + * PLEASE NOTE: This op is disabled atm, and reserved for future releases. + */ +OP_IMPL(broadcastgradientargs, 2, 2, true) { + nd4j_printf("BroadcastGradientArgs: Not implemented yet\n", ""); - return ND4J_STATUS_KERNEL_FAILURE; - } - DECLARE_SYN(BroadcastGradientArgs, broadcastgradientargs); + return ND4J_STATUS_KERNEL_FAILURE; +} +DECLARE_SYN(BroadcastGradientArgs, broadcastgradientargs); - DECLARE_TYPES(broadcastgradientargs) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY); - } - } +DECLARE_TYPES(broadcastgradientargs) { + getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/helpers/BroadcastHelper.h b/libnd4j/include/ops/declarable/generic/helpers/BroadcastHelper.h index af7f2d8d73c3..e7e9f758969a 100644 --- a/libnd4j/include/ops/declarable/generic/helpers/BroadcastHelper.h +++ b/libnd4j/include/ops/declarable/generic/helpers/BroadcastHelper.h @@ -22,124 +22,140 @@ #define LIBND4J_BROADCAST_HELPER_H #include +#include #include -#include #include -#include +#include namespace sd { - namespace ops { - class BroadcastHelper { - public: - static FORCEINLINE NDArray* broadcastApply(sd::BroadcastOpsTuple op, NDArray* x, NDArray* y, NDArray* z, ExtraArguments *extraArgs = nullptr) { - - if(x->isEmpty() || y->isEmpty()) { - if(!z->isEmpty()) - throw std::runtime_error("BroadcastHelper::broadcastApply: when some of input arrays (or both) is empty, output array must be empty as well !"); - return z; - } - - std::unique_ptr ptr; - if (!Environment::getInstance().isExperimentalBuild()) { - if (y->dataType() != x->dataType()) { - y = new NDArray(y->cast(x->dataType())); - std::unique_ptr ptr2(y); - ptr.swap(ptr2); - } - } - - if (!x->isScalar() && !y->isScalar() && x->isSameShape(y)) { - x->applyPairwiseTransform(op.p, *y, *z); - } else if (!x->isScalar() && y->isScalar()) { - x->applyScalarArr(op.s, const_cast(*y), *z); - } else if (x->isScalar() && !y->isScalar()) { - if (z->isSameShape(y)) { - if (op.s == scalar::Add || op.s == scalar::Multiply ) { - y->applyScalarArr(op.s, *x, *z); - } else if (op.s == scalar::SquaredSubtract) { - y->applyScalarArr(scalar::SquaredReverseSubtract, *x, *z); - } else if (op.s == scalar::Subtract) { - y->applyScalarArr(scalar::ReverseSubtract, *x, *z); - } else if (op.s == scalar::Divide) { - y->applyScalarArr(scalar::ReverseDivide, *x, *z); - } else if (op.s == scalar::Pow) { - y->applyScalarArr(scalar::ReversePow, *x, *z); - } else if (op.s == scalar::ReverseSubtract) { - y->applyScalarArr(scalar::Subtract, *x, *z); - } else if (op.s == scalar::ReverseDivide) { - y->applyScalarArr(scalar::Divide, *x, *z); - } else if (op.s == scalar::MaxPairwise || op.s == scalar::MinPairwise || op.s == scalar::AMaxPairwise || op.s == scalar::AMinPairwise) { - y->applyScalarArr(op.s, *x, *z); - } else if (op.s == scalar::CopyPws) { - z->assign(y); - } else { - z->assign(x); - z->applyPairwiseTransform(op.p, *y, extraArgs); - } - return z; - } else { - auto v = y->getShapeAsVector(); - auto tZ = NDArrayFactory::valueOf(v, y, y->ordering()); - tZ->applyPairwiseTransform(op.p, *y, extraArgs); - return tZ; - } - } else if (x->isScalar() && y->isScalar()) { // x->isScalar() && y->isScalar() - x->applyScalarArr(op.s, const_cast(*y), *z); - } else if (ShapeUtils::areShapesBroadcastable(*x, *y)) { - x->applyTrueBroadcast(op, *y, *z, true, extraArgs); - return z; - } else { - auto sx = ShapeUtils::shapeAsString(x); - auto sy = ShapeUtils::shapeAsString(y); - nd4j_printf("Broadcast: shapes should be equal, or broadcastable. But got %s vs %s instead\n", sx.c_str(), sy.c_str()); - return nullptr; - } +namespace ops { +class BroadcastHelper { + public: + static FORCEINLINE NDArray* broadcastApply( + sd::BroadcastOpsTuple op, NDArray* x, NDArray* y, NDArray* z, + ExtraArguments* extraArgs = nullptr) { + if (x->isEmpty() || y->isEmpty()) { + if (!z->isEmpty()) + throw std::runtime_error( + "BroadcastHelper::broadcastApply: when some of input arrays (or " + "both) is empty, output array must be empty as well !"); + return z; + } - return z; - } + std::unique_ptr ptr; + if (!Environment::getInstance().isExperimentalBuild()) { + if (y->dataType() != x->dataType()) { + y = new NDArray(y->cast(x->dataType())); + std::unique_ptr ptr2(y); + ptr.swap(ptr2); + } + } - static FORCEINLINE NDArray* broadcastApply(sd::BroadcastBoolOpsTuple op, NDArray* x, NDArray* y, NDArray* z, ExtraArguments *extraArgs = nullptr) { + if (!x->isScalar() && !y->isScalar() && x->isSameShape(y)) { + x->applyPairwiseTransform(op.p, *y, *z); + } else if (!x->isScalar() && y->isScalar()) { + x->applyScalarArr(op.s, const_cast(*y), *z); + } else if (x->isScalar() && !y->isScalar()) { + if (z->isSameShape(y)) { + if (op.s == scalar::Add || op.s == scalar::Multiply) { + y->applyScalarArr(op.s, *x, *z); + } else if (op.s == scalar::SquaredSubtract) { + y->applyScalarArr(scalar::SquaredReverseSubtract, *x, *z); + } else if (op.s == scalar::Subtract) { + y->applyScalarArr(scalar::ReverseSubtract, *x, *z); + } else if (op.s == scalar::Divide) { + y->applyScalarArr(scalar::ReverseDivide, *x, *z); + } else if (op.s == scalar::Pow) { + y->applyScalarArr(scalar::ReversePow, *x, *z); + } else if (op.s == scalar::ReverseSubtract) { + y->applyScalarArr(scalar::Subtract, *x, *z); + } else if (op.s == scalar::ReverseDivide) { + y->applyScalarArr(scalar::Divide, *x, *z); + } else if (op.s == scalar::MaxPairwise || op.s == scalar::MinPairwise || + op.s == scalar::AMaxPairwise || + op.s == scalar::AMinPairwise) { + y->applyScalarArr(op.s, *x, *z); + } else if (op.s == scalar::CopyPws) { + z->assign(y); + } else { + z->assign(x); + z->applyPairwiseTransform(op.p, *y, extraArgs); + } + return z; + } else { + auto v = y->getShapeAsVector(); + auto tZ = NDArrayFactory::valueOf(v, y, y->ordering()); + tZ->applyPairwiseTransform(op.p, *y, extraArgs); + return tZ; + } + } else if (x->isScalar() && + y->isScalar()) { // x->isScalar() && y->isScalar() + x->applyScalarArr(op.s, const_cast(*y), *z); + } else if (ShapeUtils::areShapesBroadcastable(*x, *y)) { + x->applyTrueBroadcast(op, *y, *z, true, extraArgs); + return z; + } else { + auto sx = ShapeUtils::shapeAsString(x); + auto sy = ShapeUtils::shapeAsString(y); + nd4j_printf( + "Broadcast: shapes should be equal, or broadcastable. But got %s vs " + "%s instead\n", + sx.c_str(), sy.c_str()); + return nullptr; + } - if(x->isEmpty() || y->isEmpty()) { - if(!z->isEmpty()) - throw std::runtime_error("BroadcastHelper::broadcastApply: when some of input arrays (or both) is empty, output array must be empty as well !"); - return z; - } + return z; + } - if (!x->isScalar() && !y->isScalar() && x->isSameShape(y)) { - x->applyPairwiseTransform(op.p, *y, *z); - } else if (ShapeUtils::areShapesBroadcastable(*x, *y)) { - x->applyTrueBroadcast(op, *y, *z, true, extraArgs); - return z; - } else if (!x->isScalar() && y->isScalar()) { - x->applyScalarArr(op.s, const_cast(*y), *z); - } else if (x->isScalar() && !y->isScalar()) { - if (z->isSameShape(y)) { - //z->assign(x); - x->applyPairwiseTransform(op.p, *y, *z, extraArgs); - return z; - } else { - auto v = y->getShapeAsVector(); - auto tZ = NDArrayFactory::valueOf(v, y, y->ordering()); - //tZ->applyPairwiseTransform(op.p, *y, extraArgs); - return tZ; - } - } else if (x->isScalar() && y->isScalar()) { // x->isScalar() && y->isScalar() - x->applyScalarArr(op.s, const_cast(*y), *z); - } else if (ShapeUtils::areShapesBroadcastable(*x, *y)) { - x->applyTrueBroadcast(op, *y, *z, true, extraArgs); - return z; - } else { - auto sx = ShapeUtils::shapeAsString(x); - auto sy = ShapeUtils::shapeAsString(y); - nd4j_printf("Broadcast: shapes should be equal, or broadcastable. But got %s vs %s instead\n", sx.c_str(), sy.c_str()); - return nullptr; - } + static FORCEINLINE NDArray* broadcastApply( + sd::BroadcastBoolOpsTuple op, NDArray* x, NDArray* y, NDArray* z, + ExtraArguments* extraArgs = nullptr) { + if (x->isEmpty() || y->isEmpty()) { + if (!z->isEmpty()) + throw std::runtime_error( + "BroadcastHelper::broadcastApply: when some of input arrays (or " + "both) is empty, output array must be empty as well !"); + return z; + } - return z; - } - }; + if (!x->isScalar() && !y->isScalar() && x->isSameShape(y)) { + x->applyPairwiseTransform(op.p, *y, *z); + } else if (ShapeUtils::areShapesBroadcastable(*x, *y)) { + x->applyTrueBroadcast(op, *y, *z, true, extraArgs); + return z; + } else if (!x->isScalar() && y->isScalar()) { + x->applyScalarArr(op.s, const_cast(*y), *z); + } else if (x->isScalar() && !y->isScalar()) { + if (z->isSameShape(y)) { + // z->assign(x); + x->applyPairwiseTransform(op.p, *y, *z, extraArgs); + return z; + } else { + auto v = y->getShapeAsVector(); + auto tZ = NDArrayFactory::valueOf(v, y, y->ordering()); + // tZ->applyPairwiseTransform(op.p, *y, extraArgs); + return tZ; + } + } else if (x->isScalar() && + y->isScalar()) { // x->isScalar() && y->isScalar() + x->applyScalarArr(op.s, const_cast(*y), *z); + } else if (ShapeUtils::areShapesBroadcastable(*x, *y)) { + x->applyTrueBroadcast(op, *y, *z, true, extraArgs); + return z; + } else { + auto sx = ShapeUtils::shapeAsString(x); + auto sy = ShapeUtils::shapeAsString(y); + nd4j_printf( + "Broadcast: shapes should be equal, or broadcastable. But got %s vs " + "%s instead\n", + sx.c_str(), sy.c_str()); + return nullptr; } -} + + return z; + } +}; +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/helpers/ScatterHelper.h b/libnd4j/include/ops/declarable/generic/helpers/ScatterHelper.h index 4d464a7456ce..525d8334931c 100644 --- a/libnd4j/include/ops/declarable/generic/helpers/ScatterHelper.h +++ b/libnd4j/include/ops/declarable/generic/helpers/ScatterHelper.h @@ -22,18 +22,15 @@ #ifndef LIBND4J_SCATTERHELPER_H #define LIBND4J_SCATTERHELPER_H -#include -#include #include -#include #include +#include +#include +#include namespace sd { -namespace ops { - - -} -} +namespace ops {} +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/images/adjust_contrast.cpp b/libnd4j/include/ops/declarable/generic/images/adjust_contrast.cpp index 796dbb80ba36..d79f721e95e4 100644 --- a/libnd4j/include/ops/declarable/generic/images/adjust_contrast.cpp +++ b/libnd4j/include/ops/declarable/generic/images/adjust_contrast.cpp @@ -21,110 +21,117 @@ #include #if NOT_EXCLUDED(OP_adjust_contrast) -#include #include +#include namespace sd { namespace ops { //////////////////////////////////////////////////////////////////// CONFIGURABLE_OP_IMPL(adjust_contrast, 1, 1, true, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - // just skip op if input is empty - if (input->isEmpty()) - return Status::OK(); + // just skip op if input is empty + if (input->isEmpty()) return Status::OK(); - REQUIRE_TRUE(block.numT() > 0 || block.width() > 1, 0, "ADJUST_CONTRAST: Scale factor required"); - REQUIRE_TRUE(input->rankOf() > 2, 0, "ADJUST_CONTRAST: op expects rank of input array to be >= 3, but got %i instead", input->rankOf()); -// REQUIRE_TRUE(input->sizeAt(-1) == 3, 0, "ADJUST_CONTRAST: operation expects image with 3 channels (R, G, B), but got %i instead", input->sizeAt(-1)); + REQUIRE_TRUE(block.numT() > 0 || block.width() > 1, 0, + "ADJUST_CONTRAST: Scale factor required"); + REQUIRE_TRUE(input->rankOf() > 2, 0, + "ADJUST_CONTRAST: op expects rank of input array to be >= 3, " + "but got %i instead", + input->rankOf()); + // REQUIRE_TRUE(input->sizeAt(-1) == 3, 0, "ADJUST_CONTRAST: operation + // expects image with 3 channels (R, G, B), but got %i instead", + // input->sizeAt(-1)); - NDArray* factor = nullptr; + NDArray* factor = nullptr; - if(block.width() > 1) - factor = INPUT_VARIABLE(1); - else { - factor = new NDArray(output->dataType(), block.launchContext()); - factor->p(0, T_ARG(0)); - } + if (block.width() > 1) + factor = INPUT_VARIABLE(1); + else { + factor = new NDArray(output->dataType(), block.launchContext()); + factor->p(0, T_ARG(0)); + } - // fill up axes vector first - std::vector axes(input->rankOf() - 1); - for (auto i = 0; i < axes.size(); ++i) - axes[i] = i; + // fill up axes vector first + std::vector axes(input->rankOf() - 1); + for (auto i = 0; i < axes.size(); ++i) axes[i] = i; - // mean as reduction for last dimension set - auto mean = input->reduceAlongDimension(reduce::Mean, axes); + // mean as reduction for last dimension set + auto mean = input->reduceAlongDimension(reduce::Mean, axes); - // this is contrast calculation - output->assign((*input - mean) * (*factor) + mean); + // this is contrast calculation + output->assign((*input - mean) * (*factor) + mean); - if(block.width() == 1) - delete factor; + if (block.width() == 1) delete factor; - return Status::OK(); + return Status::OK(); } DECLARE_TYPES(adjust_contrast) { - getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}) - ->setSameMode(true); + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}) + ->setSameMode(true); } //////////////////////////////////////////////////////////////////// CONFIGURABLE_OP_IMPL(adjust_contrast_v2, 1, 1, true, 0, 0) { - - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - // just skip op if input is empty - if (input->isEmpty()) - return Status::OK(); - - REQUIRE_TRUE(input->rankOf() > 2, 0, "ADJUST_CONTRAST_V2: op expects rank of input array to be >= 3, but got %i instead", input->rankOf()); -// REQUIRE_TRUE(input->sizeAt(-1) == 3, 0, "ADJUST_CONTRAST_V2: operation expects image with 3 channels (R, G, B), but got %i instead", input->sizeAt(-1)); - REQUIRE_TRUE(block.numT() > 0 || block.width() > 1, 0, "ADJUST_CONTRAST_V2: Scale factor required"); - - NDArray* factor = nullptr; - auto size = input->sizeAt(-2) * input->sizeAt(-3); - auto channels = input->sizeAt(-1); - auto batch = input->lengthOf() / (size * channels); - auto input3D = input->reshape(input->ordering(), {batch, size, channels}); - auto output3D = input->reshape(input->ordering(), {batch, size, channels}); - - if(block.width() > 1) - factor = INPUT_VARIABLE(1); - else { - factor = new NDArray(output->dataType(), block.launchContext()); - factor->p(0, T_ARG(0)); - } - - std::vector axes({1}); // dim 1 of pseudoresult - -// mean as reduction for last dimension set over size (dim 1) of result3D - auto mean = input3D.reduceAlongDimension(reduce::Mean, axes); - - // result as (x - mean) * factor + mean - auto temp = input3D.ulike(); - input3D.applyBroadcast(broadcast::Subtract, {0, 2}, mean, temp); - temp.applyScalarArr(scalar::Multiply, *factor, temp); - temp.applyBroadcast(broadcast::Add, {0, 2}, mean, output3D); - output->assign(output3D); - if(block.width() == 1) - delete factor; - - return Status::OK(); + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + // just skip op if input is empty + if (input->isEmpty()) return Status::OK(); + + REQUIRE_TRUE(input->rankOf() > 2, 0, + "ADJUST_CONTRAST_V2: op expects rank of input array to be >= 3, " + "but got %i instead", + input->rankOf()); + // REQUIRE_TRUE(input->sizeAt(-1) == 3, 0, "ADJUST_CONTRAST_V2: operation + // expects image with 3 channels (R, G, B), but got %i instead", + // input->sizeAt(-1)); + REQUIRE_TRUE(block.numT() > 0 || block.width() > 1, 0, + "ADJUST_CONTRAST_V2: Scale factor required"); + + NDArray* factor = nullptr; + auto size = input->sizeAt(-2) * input->sizeAt(-3); + auto channels = input->sizeAt(-1); + auto batch = input->lengthOf() / (size * channels); + auto input3D = input->reshape(input->ordering(), {batch, size, channels}); + auto output3D = input->reshape(input->ordering(), {batch, size, channels}); + + if (block.width() > 1) + factor = INPUT_VARIABLE(1); + else { + factor = new NDArray(output->dataType(), block.launchContext()); + factor->p(0, T_ARG(0)); + } + + std::vector axes({1}); // dim 1 of pseudoresult + + // mean as reduction for last dimension set over size (dim 1) of result3D + auto mean = input3D.reduceAlongDimension(reduce::Mean, axes); + + // result as (x - mean) * factor + mean + auto temp = input3D.ulike(); + input3D.applyBroadcast(broadcast::Subtract, {0, 2}, mean, temp); + temp.applyScalarArr(scalar::Multiply, *factor, temp); + temp.applyBroadcast(broadcast::Add, {0, 2}, mean, output3D); + output->assign(output3D); + if (block.width() == 1) delete factor; + + return Status::OK(); } DECLARE_TYPES(adjust_contrast_v2) { - getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}) - ->setSameMode(true); + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}) + ->setSameMode(true); } -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/images/adjust_hue.cpp b/libnd4j/include/ops/declarable/generic/images/adjust_hue.cpp index 436fae28d75c..8bf5d94ba052 100644 --- a/libnd4j/include/ops/declarable/generic/images/adjust_hue.cpp +++ b/libnd4j/include/ops/declarable/generic/images/adjust_hue.cpp @@ -28,55 +28,62 @@ namespace sd { namespace ops { - CONFIGURABLE_OP_IMPL(adjust_hue, 1, 1, true, 0, 0) { - - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - // just skip op if input is empty - if (input->isEmpty()) - return Status::OK(); - - const int rank = input->rankOf(); - const int arg_size = block.getIArguments()->size(); - const int dimC = arg_size > 0 ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) : rank - 1; - - REQUIRE_TRUE(block.numT() > 0 || block.width() > 1, 0, "ADJUST_HUE: delta factor is required !"); - REQUIRE_TRUE(rank >= 3, 0, "ADJUST_HUE: op expects rank of input array to be >= 3, but got %i instead", rank); - if (arg_size > 0) { - REQUIRE_TRUE(dimC >= 0 && dimC < rank, 0, "Index of the Channel dimension out of range: %i not in [%i,%i) ", INT_ARG(0), -rank, rank); - } - REQUIRE_TRUE(input->sizeAt(dimC) == 3, 0, "ADJUST_HUE: operation expects image with 3 channels (R, G, B), but got %i instead", input->sizeAt(dimC)); - - NDArray* delta = nullptr; - - if(block.width() > 1) - delta = INPUT_VARIABLE(1); - else { - delta = new NDArray(output->dataType(), block.launchContext()); - delta->p(0, T_ARG(0)); - } - - REQUIRE_TRUE(-1. <= delta->e(0) && delta->e(0) <= 1., 0, "ADJUST_HUE: parameter delta must be within [-1, 1] interval, but got %f instead", delta); - - helpers::adjustHue(block.launchContext(), input, delta, output, dimC); - - if(block.width() == 1) - delete delta; - - return Status::OK(); + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + // just skip op if input is empty + if (input->isEmpty()) return Status::OK(); + + const int rank = input->rankOf(); + const int arg_size = block.numI(); + const int dimC = arg_size > 0 + ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) + : rank - 1; + + REQUIRE_TRUE(block.numT() > 0 || block.width() > 1, 0, + "ADJUST_HUE: delta factor is required !"); + REQUIRE_TRUE(rank >= 3, 0, + "ADJUST_HUE: op expects rank of input array to be >= 3, but got " + "%i instead", + rank); + if (arg_size > 0) { + REQUIRE_TRUE( + dimC >= 0 && dimC < rank, 0, + "Index of the Channel dimension out of range: %i not in [%i,%i) ", + INT_ARG(0), -rank, rank); + } + REQUIRE_TRUE(input->sizeAt(dimC) == 3, 0, + "ADJUST_HUE: operation expects image with 3 channels (R, G, B), " + "but got %i instead", + input->sizeAt(dimC)); + + NDArray* delta = nullptr; + + if (block.width() > 1) + delta = INPUT_VARIABLE(1); + else { + delta = new NDArray(output->dataType(), block.launchContext()); + delta->p(0, T_ARG(0)); + } + + REQUIRE_TRUE(-1. <= delta->e(0) && delta->e(0) <= 1., 0, + "ADJUST_HUE: parameter delta must be within [-1, 1] interval, " + "but got %f instead", + delta); + + helpers::adjustHue(block.launchContext(), input, delta, output, dimC); + + if (block.width() == 1) delete delta; + + return Status::OK(); } DECLARE_TYPES(adjust_hue) { - getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY) - ->setSameMode(true); + getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setSameMode(true); } - - - -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/images/adjust_saturation.cpp b/libnd4j/include/ops/declarable/generic/images/adjust_saturation.cpp index 5be1699f4a6b..0079d5970675 100644 --- a/libnd4j/include/ops/declarable/generic/images/adjust_saturation.cpp +++ b/libnd4j/include/ops/declarable/generic/images/adjust_saturation.cpp @@ -21,58 +21,64 @@ #include #if NOT_EXCLUDED(OP_adjust_saturation) +#include #include #include -#include namespace sd { namespace ops { CONFIGURABLE_OP_IMPL(adjust_saturation, 1, 1, true, 0, 0) { - - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - // just skip op if input is empty - if (input->isEmpty()) - return Status::OK(); - - const int rank = input->rankOf(); - const int arg_size = block.getIArguments()->size(); - const int dimC = arg_size > 0 ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) : rank - 1; - - REQUIRE_TRUE(rank >= 3, 0, "ADJUST_SATURATION: op expects rank of input array to be >= 3, but got %i instead", rank); - if (arg_size > 0) { - REQUIRE_TRUE(dimC >= 0 && dimC < rank, 0, "Index of the Channel dimension out of range: %i not in [%i,%i) ", INT_ARG(0), -rank, rank); - } - REQUIRE_TRUE(input->sizeAt(dimC) == 3, 0, "ADJUST_SATURATION: operation expects image with 3 channels (R, G, B), but got %i instead", input->sizeAt(dimC)); - REQUIRE_TRUE(block.numT() > 0 || block.width() > 1, 0, "ADJUST_SATURATION: scale factor is required !"); - - NDArray* factor = nullptr; - - if(block.width() > 1) - factor = INPUT_VARIABLE(1); - else { - factor = new NDArray(output->dataType(), block.launchContext()); - factor->p(0, T_ARG(0)); - } - - helpers::adjustSaturation(block.launchContext(), input, factor, output, dimC); - - if(block.width() == 1) - delete factor; - - return Status::OK(); + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + // just skip op if input is empty + if (input->isEmpty()) return Status::OK(); + + const int rank = input->rankOf(); + const int arg_size = block.numI(); + const int dimC = arg_size > 0 + ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) + : rank - 1; + + REQUIRE_TRUE(rank >= 3, 0, + "ADJUST_SATURATION: op expects rank of input array to be >= 3, " + "but got %i instead", + rank); + if (arg_size > 0) { + REQUIRE_TRUE( + dimC >= 0 && dimC < rank, 0, + "Index of the Channel dimension out of range: %i not in [%i,%i) ", + INT_ARG(0), -rank, rank); + } + REQUIRE_TRUE(input->sizeAt(dimC) == 3, 0, + "ADJUST_SATURATION: operation expects image with 3 channels (R, " + "G, B), but got %i instead", + input->sizeAt(dimC)); + REQUIRE_TRUE(block.numT() > 0 || block.width() > 1, 0, + "ADJUST_SATURATION: scale factor is required !"); + + NDArray* factor = nullptr; + + if (block.width() > 1) + factor = INPUT_VARIABLE(1); + else { + factor = new NDArray(output->dataType(), block.launchContext()); + factor->p(0, T_ARG(0)); + } + + helpers::adjustSaturation(block.launchContext(), input, factor, output, dimC); + + if (block.width() == 1) delete factor; + + return Status::OK(); } DECLARE_TYPES(adjust_saturation) { - getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY) - ->setSameMode(true); + getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setSameMode(true); } - - -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/images/crop_and_resize.cpp b/libnd4j/include/ops/declarable/generic/images/crop_and_resize.cpp index 3c101070de68..6a245d590507 100644 --- a/libnd4j/include/ops/declarable/generic/images/crop_and_resize.cpp +++ b/libnd4j/include/ops/declarable/generic/images/crop_and_resize.cpp @@ -26,70 +26,80 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(crop_and_resize, 4, 1, false, 0, 0) { - - auto image = INPUT_VARIABLE(0); - auto boxes = INPUT_VARIABLE(1); - auto boxIndexes = INPUT_VARIABLE(2); - - auto output = OUTPUT_VARIABLE(0); - int width; - int height; - int method = 0; // bilinear - double extrapolationVal = 0.; - - auto newImageSize = INPUT_VARIABLE(3); - REQUIRE_TRUE(output->dataType() == image->dataType(), 0, "crop_and_resize: Source images and output should have the same data type."); - REQUIRE_TRUE(newImageSize->lengthOf() == 2, 0, "crop_and_resize: Resize params is a pair of values, not %i.", newImageSize->lengthOf()); - //REQUIRE_TRUE(block.numI() <= 1, 0, "crop_and_resize: Resize params already given by the second param. Int params are expensive."); - //width = int(newImageSize->getScalar(0)); - //height = int(newImageSize->getScalar(1)); - if (block.numI() == 1) { - method = INT_ARG(0); - } - - if (block.numT() == 1) { - extrapolationVal = T_ARG(0); - } - - helpers::cropAndResizeFunctor(block.launchContext(), image, boxes, boxIndexes, newImageSize, method, extrapolationVal, output); - return ND4J_STATUS_OK; - } - - DECLARE_SHAPE_FN(crop_and_resize) { - auto in = inputShape->at(0); - auto boxShape = inputShape->at(1); - - Nd4jLong outputShape[4]; - - int width; - int height; - auto newImageSize = INPUT_VARIABLE(3); - REQUIRE_TRUE(newImageSize->lengthOf() == 2, 0, "crop_and_resize: Resize params is a pair of values, not %i.", newImageSize->lengthOf()); - //REQUIRE_TRUE(block.numI() <= 1, 0, "crop_and_resize: Resize params already given by the second param. Int params are expensive."); - width = newImageSize->e(0); - height = newImageSize->e(1); - - outputShape[0] = boxShape[1]; - outputShape[1] = width; - outputShape[2] = height; - outputShape[3] = in[4]; - - return SHAPELIST(ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(ArrayOptions::dataType(in), shape::order(in), outputShape, 4))); - } - - DECLARE_TYPES(crop_and_resize) { - getOpDescriptor() - ->setAllowedInputTypes(0, {ALL_INTS, ALL_FLOATS}) -// ->setAllowedInputTypes(1, {ALL_FLOATS}) - ->setAllowedInputTypes(1, {ALL_INTS, ALL_FLOATS}) - ->setAllowedInputTypes(2, {ALL_INTS}) - ->setAllowedInputTypes(3, {ALL_INTS}) - ->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS}); // as TF -// ->setAllowedOutputTypes({ALL_FLOATS}); - } - } +namespace ops { + +CUSTOM_OP_IMPL(crop_and_resize, 4, 1, false, 0, 0) { + auto image = INPUT_VARIABLE(0); + auto boxes = INPUT_VARIABLE(1); + auto boxIndexes = INPUT_VARIABLE(2); + + auto output = OUTPUT_VARIABLE(0); + int width; + int height; + int method = 0; // bilinear + double extrapolationVal = 0.; + + auto newImageSize = INPUT_VARIABLE(3); + REQUIRE_TRUE(output->dataType() == image->dataType(), 0, + "crop_and_resize: Source images and output should have the same " + "data type."); + REQUIRE_TRUE(newImageSize->lengthOf() == 2, 0, + "crop_and_resize: Resize params is a pair of values, not %i.", + newImageSize->lengthOf()); + // REQUIRE_TRUE(block.numI() <= 1, 0, "crop_and_resize: Resize params already + // given by the second param. Int params are expensive."); width = + // int(newImageSize->getScalar(0)); height = int(newImageSize->getScalar(1)); + if (block.numI() == 1) { + method = INT_ARG(0); + } + + if (block.numT() == 1) { + extrapolationVal = T_ARG(0); + } + + helpers::cropAndResizeFunctor(block.launchContext(), image, boxes, boxIndexes, + newImageSize, method, extrapolationVal, output); + return ND4J_STATUS_OK; +} + +DECLARE_SHAPE_FN(crop_and_resize) { + auto in = inputShape->at(0); + auto boxShape = inputShape->at(1); + + Nd4jLong outputShape[4]; + + int width; + int height; + auto newImageSize = INPUT_VARIABLE(3); + REQUIRE_TRUE(newImageSize->lengthOf() == 2, 0, + "crop_and_resize: Resize params is a pair of values, not %i.", + newImageSize->lengthOf()); + // REQUIRE_TRUE(block.numI() <= 1, 0, "crop_and_resize: Resize params already + // given by the second param. Int params are expensive."); + width = newImageSize->e(0); + height = newImageSize->e(1); + + outputShape[0] = boxShape[1]; + outputShape[1] = width; + outputShape[2] = height; + outputShape[3] = in[4]; + + return SHAPELIST( + ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor( + ArrayOptions::dataType(in), shape::order(in), outputShape, 4))); +} + +DECLARE_TYPES(crop_and_resize) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_INTS, ALL_FLOATS}) + // ->setAllowedInputTypes(1, {ALL_FLOATS}) + ->setAllowedInputTypes(1, {ALL_INTS, ALL_FLOATS}) + ->setAllowedInputTypes(2, {ALL_INTS}) + ->setAllowedInputTypes(3, {ALL_INTS}) + ->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS}); // as TF + // ->setAllowedOutputTypes({ALL_FLOATS}); } +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/generic/images/draw_bounding_boxes.cpp b/libnd4j/include/ops/declarable/generic/images/draw_bounding_boxes.cpp index d143bdcf81d9..45927669440b 100644 --- a/libnd4j/include/ops/declarable/generic/images/draw_bounding_boxes.cpp +++ b/libnd4j/include/ops/declarable/generic/images/draw_bounding_boxes.cpp @@ -24,48 +24,56 @@ #include #include namespace sd { - namespace ops { - OP_IMPL(draw_bounding_boxes, 3, 1, true) { +namespace ops { +OP_IMPL(draw_bounding_boxes, 3, 1, true) { + auto images = INPUT_VARIABLE(0); + auto boxes = INPUT_VARIABLE(1); - auto images = INPUT_VARIABLE(0); - auto boxes = INPUT_VARIABLE(1); + auto colors = (NDArray*)nullptr; + if (block.width() > + 2) // TF v.1.x ommits color set for boxes, and use color 1.0 for fill up + colors = INPUT_VARIABLE(2); // but v.2.y require color set - auto colors = (NDArray*) nullptr; - if (block.width() > 2) // TF v.1.x ommits color set for boxes, and use color 1.0 for fill up - colors = INPUT_VARIABLE(2); // but v.2.y require color set - - auto output = OUTPUT_VARIABLE(0); - REQUIRE_TRUE(images->dataType() == output->dataType(), 0, "draw_bounding_boxes: Input and Output types " - "should be equals, but %d and %d occured.", - (int)images->dataType(), (int)output->dataType()); - REQUIRE_TRUE(images->rankOf() == 4, 0, "draw_bounding_boxes: Images input should be 4D tensor, but %i occured.", - images->rankOf()); - REQUIRE_TRUE(boxes->rankOf() == 3, 0, "draw_bounding_boxes: Boxes should be 3D tensor, but %i occured.", - boxes->rankOf()); - if (colors) { - - REQUIRE_TRUE(colors->rankOf() == 2, 0, "draw_bounding_boxes: Color set should be 2D matrix, but %i occured.", - colors->rankOf()); - REQUIRE_TRUE(colors->sizeAt(1) >= images->sizeAt(3), 0, "draw_bounding_boxes: Color set last dim " - "should be not less than images depth, but " - "%lld and %lld occured.", - colors->sizeAt(1), images->sizeAt(3)); - } - REQUIRE_TRUE(boxes->sizeAt(0) == images->sizeAt(0), 0, "draw_bounding_boxes: Batches for images and boxes " - "should be the same, but %lld and %lld occured.", - images->sizeAt(0), boxes->sizeAt(0)); - helpers::drawBoundingBoxesFunctor(block.launchContext(), images, boxes, colors, output); - return ND4J_STATUS_OK; - } + auto output = OUTPUT_VARIABLE(0); + REQUIRE_TRUE(images->dataType() == output->dataType(), 0, + "draw_bounding_boxes: Input and Output types " + "should be equals, but %d and %d occured.", + (int)images->dataType(), (int)output->dataType()); + REQUIRE_TRUE( + images->rankOf() == 4, 0, + "draw_bounding_boxes: Images input should be 4D tensor, but %i occured.", + images->rankOf()); + REQUIRE_TRUE( + boxes->rankOf() == 3, 0, + "draw_bounding_boxes: Boxes should be 3D tensor, but %i occured.", + boxes->rankOf()); + if (colors) { + REQUIRE_TRUE( + colors->rankOf() == 2, 0, + "draw_bounding_boxes: Color set should be 2D matrix, but %i occured.", + colors->rankOf()); + REQUIRE_TRUE(colors->sizeAt(1) >= images->sizeAt(3), 0, + "draw_bounding_boxes: Color set last dim " + "should be not less than images depth, but " + "%lld and %lld occured.", + colors->sizeAt(1), images->sizeAt(3)); + } + REQUIRE_TRUE(boxes->sizeAt(0) == images->sizeAt(0), 0, + "draw_bounding_boxes: Batches for images and boxes " + "should be the same, but %lld and %lld occured.", + images->sizeAt(0), boxes->sizeAt(0)); + helpers::drawBoundingBoxesFunctor(block.launchContext(), images, boxes, colors, output); + return Status::OK(); +} - DECLARE_TYPES(draw_bounding_boxes) { - getOpDescriptor() - ->setAllowedInputTypes(0, {HALF, FLOAT32})// TF allows HALF and FLOAT32 only - ->setAllowedInputTypes(1, {FLOAT32}) // as TF - ->setAllowedInputTypes(2, {FLOAT32}) // as TF - ->setAllowedOutputTypes({HALF, FLOAT32}); // TF allows HALF and FLOAT32 only - } - } +DECLARE_TYPES(draw_bounding_boxes) { + getOpDescriptor() + ->setAllowedInputTypes(0, {HALF, FLOAT32}) // TF allows HALF and FLOAT32 only + ->setAllowedInputTypes(1, sd::DataType::FLOAT32) // as TF + ->setAllowedInputTypes(2, sd::DataType::FLOAT32) // as TF + ->setAllowedOutputTypes({HALF, FLOAT32}); // TF allows HALF and FLOAT32 only } +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/generic/images/extract_image_patches.cpp b/libnd4j/include/ops/declarable/generic/images/extract_image_patches.cpp index 1bcb8ef36b81..209fbcb946cf 100644 --- a/libnd4j/include/ops/declarable/generic/images/extract_image_patches.cpp +++ b/libnd4j/include/ops/declarable/generic/images/extract_image_patches.cpp @@ -22,76 +22,78 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(extract_image_patches, 1, 1, false, 0, 7) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - int ksizeRows = INT_ARG(0); - int ksizeCols = INT_ARG(1); - int kstrideRows = INT_ARG(2); - int kstrideCols = INT_ARG(3); - int krateRows = INT_ARG(4); - int krateCols = INT_ARG(5); - bool isSame = INT_ARG(6) != 0; - - REQUIRE_TRUE(input->rankOf() == 4, 0, "extract_image_patches: The rank of input array should be 4, but %i is given", input->rankOf()); - // - if (output->isSameShape(input)) - output->assign(input); - else { - output->nullify(); - helpers::extractPatches(block.launchContext(), input, output, ksizeRows, ksizeCols, kstrideRows, kstrideCols, krateRows, krateCols, isSame); - } - return Status::OK(); - } - - DECLARE_TYPES(extract_image_patches) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setSameMode(true); - } - - DECLARE_SHAPE_FN(extract_image_patches) { - - auto in = inputShape->at(0); - int outRank = shape::rank(in); - Nd4jLong *outputShape = nullptr; - - int ksizeRowsEffective = INT_ARG(0) + (INT_ARG(0) - 1) * (INT_ARG(4) - 1); - int ksizeColsEffective = INT_ARG(1) + (INT_ARG(1) - 1) * (INT_ARG(5) - 1); - - auto batchSizeDim = shape::sizeAt(in, 0); - auto inputRowsDim = shape::sizeAt(in, 1); - auto inputColsDim = shape::sizeAt(in, 2); - auto outputDepthDim = shape::sizeAt(in, 3) * INT_ARG(0) * INT_ARG(1); // last dim * ksizeRows * ksizeCols - - auto inputRowSize = inputRowsDim; //shape::sizeAt(in, inputRowsDim); - auto inputColSize = inputColsDim; //shape::sizeAt(in, inputColsDim); - Nd4jLong outRowSize; - Nd4jLong outColSize; - if (INT_ARG(6) == 0) { - // Padding is "VALID": - outRowSize = (inputRowSize - ksizeRowsEffective + INT_ARG(2)) / INT_ARG(2); - outColSize = (inputColSize - ksizeColsEffective + INT_ARG(3)) / INT_ARG(3); - } else { - // Padding is "SAME": - outRowSize = (inputRowSize + INT_ARG(2) - 1) / INT_ARG(2); - outColSize = (inputColSize + INT_ARG(3) - 1) / INT_ARG(3); - } - - - ALLOCATE(outputShape, block.getWorkspace(), shape::shapeInfoLength(outRank), Nd4jLong); - - outputShape[0] = outRank; - outputShape[1] = batchSizeDim; - outputShape[2] = outRowSize; - outputShape[3] = outColSize; - outputShape[4] = outputDepthDim; - - - ShapeUtils::updateStridesAndType(outputShape, in, shape::order(in)); - - return SHAPELIST(CONSTANT(outputShape)); - } - } -} \ No newline at end of file +namespace ops { +CUSTOM_OP_IMPL(extract_image_patches, 1, 1, false, 0, 7) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + int ksizeRows = INT_ARG(0); + int ksizeCols = INT_ARG(1); + int kstrideRows = INT_ARG(2); + int kstrideCols = INT_ARG(3); + int krateRows = INT_ARG(4); + int krateCols = INT_ARG(5); + bool isSame = INT_ARG(6) != 0; + + REQUIRE_TRUE(input->rankOf() == 4, 0, + "extract_image_patches: The rank of input array should be 4, " + "but %i is given", + input->rankOf()); + // + if (output->isSameShape(input)) + output->assign(input); + else { + output->nullify(); + helpers::extractPatches(block.launchContext(), input, output, ksizeRows, + ksizeCols, kstrideRows, kstrideCols, krateRows, + krateCols, isSame); + } + return Status::OK(); +} + +DECLARE_TYPES(extract_image_patches) { + getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setSameMode(true); +} + +DECLARE_SHAPE_FN(extract_image_patches) { + auto in = inputShape->at(0); + int outRank = shape::rank(in); + Nd4jLong *outputShape = nullptr; + + int ksizeRowsEffective = INT_ARG(0) + (INT_ARG(0) - 1) * (INT_ARG(4) - 1); + int ksizeColsEffective = INT_ARG(1) + (INT_ARG(1) - 1) * (INT_ARG(5) - 1); + + auto batchSizeDim = shape::sizeAt(in, 0); + auto inputRowsDim = shape::sizeAt(in, 1); + auto inputColsDim = shape::sizeAt(in, 2); + auto outputDepthDim = shape::sizeAt(in, 3) * INT_ARG(0) * + INT_ARG(1); // last dim * ksizeRows * ksizeCols + + auto inputRowSize = inputRowsDim; // shape::sizeAt(in, inputRowsDim); + auto inputColSize = inputColsDim; // shape::sizeAt(in, inputColsDim); + Nd4jLong outRowSize; + Nd4jLong outColSize; + if (INT_ARG(6) == 0) { + // Padding is "VALID": + outRowSize = (inputRowSize - ksizeRowsEffective + INT_ARG(2)) / INT_ARG(2); + outColSize = (inputColSize - ksizeColsEffective + INT_ARG(3)) / INT_ARG(3); + } else { + // Padding is "SAME": + outRowSize = (inputRowSize + INT_ARG(2) - 1) / INT_ARG(2); + outColSize = (inputColSize + INT_ARG(3) - 1) / INT_ARG(3); + } + + ALLOCATE(outputShape, block.workspace(), shape::shapeInfoLength(outRank), + Nd4jLong); + + outputShape[0] = outRank; + outputShape[1] = batchSizeDim; + outputShape[2] = outRowSize; + outputShape[3] = outColSize; + outputShape[4] = outputDepthDim; + + ShapeUtils::updateStridesAndType(outputShape, in, shape::order(in)); + + return SHAPELIST(CONSTANT(outputShape)); +} +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/images/hsvToRgb.cpp b/libnd4j/include/ops/declarable/generic/images/hsvToRgb.cpp index d5211e498241..be4cc399dbd2 100644 --- a/libnd4j/include/ops/declarable/generic/images/hsvToRgb.cpp +++ b/libnd4j/include/ops/declarable/generic/images/hsvToRgb.cpp @@ -16,44 +16,49 @@ // // @author AbdelRauf (rauf@konduit.ai) -// +// -#include -#include -#include #include +#include +#include +#include namespace sd { namespace ops { CONFIGURABLE_OP_IMPL(hsv_to_rgb, 1, 1, true, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); + if (input->isEmpty()) return Status::OK(); - if (input->isEmpty()) - return Status::OK(); + const int rank = input->rankOf(); + const int argSize = block.numI(); + const int dimC = argSize > 0 + ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) + : rank - 1; - const int rank = input->rankOf(); - const int argSize = block.getIArguments()->size(); - const int dimC = argSize > 0 ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) : rank - 1; + REQUIRE_TRUE(rank >= 1, 0, + "HSVtoRGB: Fails to meet the rank requirement: %i >= 1 ", rank); + if (argSize > 0) { + REQUIRE_TRUE( + dimC >= 0 && dimC < rank, 0, + "Index of the Channel dimension out of range: %i not in [%i,%i) ", + INT_ARG(0), -rank, rank); + } + REQUIRE_TRUE( + input->sizeAt(dimC) == 3, 0, + "HSVtoRGB: operation expects 3 channels (H, S, V), but got %i instead", + input->sizeAt(dimC)); - REQUIRE_TRUE(rank >= 1, 0, "HSVtoRGB: Fails to meet the rank requirement: %i >= 1 ", rank); - if (argSize > 0) { - REQUIRE_TRUE(dimC >= 0 && dimC < rank, 0, "Index of the Channel dimension out of range: %i not in [%i,%i) ", INT_ARG(0), -rank, rank); - } - REQUIRE_TRUE(input->sizeAt(dimC) == 3, 0, "HSVtoRGB: operation expects 3 channels (H, S, V), but got %i instead", input->sizeAt(dimC)); + helpers::transformHsvRgb(block.launchContext(), input, output, dimC); - helpers::transformHsvRgb(block.launchContext(), input, output, dimC); - - return Status::OK(); + return Status::OK(); } DECLARE_TYPES(hsv_to_rgb) { - getOpDescriptor()->setAllowedInputTypes({ ALL_FLOATS }) - ->setSameMode(true); + getOpDescriptor()->setAllowedInputTypes({ALL_FLOATS})->setSameMode(true); } - -} -} +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/generic/images/image_resize.cpp b/libnd4j/include/ops/declarable/generic/images/image_resize.cpp index 8e6e29d3a830..b96d5bc49144 100644 --- a/libnd4j/include/ops/declarable/generic/images/image_resize.cpp +++ b/libnd4j/include/ops/declarable/generic/images/image_resize.cpp @@ -25,11 +25,10 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(image_resize, 2, 1, false, 0, 0) { - - auto image = INPUT_VARIABLE(0); - auto size = INPUT_VARIABLE(1); +namespace ops { +CUSTOM_OP_IMPL(image_resize, 2, 1, false, 0, 0) { + auto image = INPUT_VARIABLE(0); + auto size = INPUT_VARIABLE(1); auto output = OUTPUT_VARIABLE(0); @@ -43,10 +42,10 @@ namespace sd { antialias = B_ARG(1); } - auto method = helpers::ImageResizeMethods::kResizeBilinear; - if (block.numI() == 1) { - method = (helpers::ImageResizeMethods)INT_ARG(0); - } + auto method = helpers::ImageResizeMethods::kResizeBilinear; + if (block.numI() == 1) { + method = (helpers::ImageResizeMethods)INT_ARG(0); + } REQUIRE_TRUE(method == helpers::ImageResizeMethods::kResizeNearest || output->dataType() == DataType::FLOAT32, 0, "image_resize: Output data type should be FLOAT32 for this method %i", (int)method ); REQUIRE_TRUE(method >= helpers::ImageResizeMethods::kResizeFirst && method <= helpers::ImageResizeMethods::kResizeLast, 0, "image_resize: Resize method should be between %i and %i, but %i was given.", (int)helpers::ImageResizeMethods::kResizeFirst, (int)helpers::ImageResizeMethods::kResizeLast, (int)method); auto inRank = image->rankOf(); @@ -54,13 +53,14 @@ namespace sd { auto source = inRank == 4?image->reshape(image->ordering(), {image->sizeAt(0), image->sizeAt(1), image->sizeAt(2), image->sizeAt(3)}):image->reshape(image->ordering(), {1, image->sizeAt(0), image->sizeAt(1), image->sizeAt(2)}); auto target = inRank == 4?output->reshape(output->ordering(), {output->sizeAt(0), output->sizeAt(1), output->sizeAt(2), output->sizeAt(3)}, false) : output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)}, false); - return helpers::resizeFunctor(block.launchContext(), image, width, height, method, antialias, output); - } + return helpers::resizeFunctor(block.launchContext(), image, width, height, + method, antialias, output); +} - DECLARE_SHAPE_FN(image_resize) { - auto in = inputShape->at(0); +DECLARE_SHAPE_FN(image_resize) { + auto in = inputShape->at(0); - Nd4jLong* outputShape; + Nd4jLong* outputShape; auto method = helpers::ImageResizeMethods::kResizeBilinear; if (block.numI() == 1) { method = (helpers::ImageResizeMethods)INT_ARG(0); @@ -84,16 +84,17 @@ namespace sd { dtype = ArrayOptions::dataType(in); auto shape = ConstantShapeHelper::getInstance().createShapeInfo(dtype, 'c', shape::rank(in) == 4?std::vector{in[1], height, width, in[4]}:std::vector{ height, width, in[4]}); - return SHAPELIST(shape); - } - DECLARE_TYPES(image_resize) { - getOpDescriptor() - ->setAllowedInputTypes(0, {ALL_INTS, ALL_FLOATS}) - ->setAllowedInputTypes(1, {ALL_INTS}) - ->setAllowedOutputTypes({ALL_FLOATS, ALL_INTS}); - } - } + return SHAPELIST(shape); } +DECLARE_TYPES(image_resize) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_INTS,ALL_FLOATS}) + ->setAllowedInputTypes(1, {ALL_INTS}) + ->setAllowedOutputTypes({ALL_FLOATS, ALL_INTS}); +} + +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/images/resize_area.cpp b/libnd4j/include/ops/declarable/generic/images/resize_area.cpp index 4ae03cc256af..f0289ac4ed56 100644 --- a/libnd4j/include/ops/declarable/generic/images/resize_area.cpp +++ b/libnd4j/include/ops/declarable/generic/images/resize_area.cpp @@ -25,98 +25,134 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(resize_area, 1, 1, false, 0, -2) { - - auto image = INPUT_VARIABLE(0); - int width; - int height; - - if (block.width() == 2) { - auto size = INPUT_VARIABLE(1); // integer vector with shape {2} and content (new_height, new_width) - REQUIRE_TRUE(size->rankOf() == 1, size->lengthOf() == 2, 0, "resize_area: Resize params is a pair of values, not %i.", size->lengthOf()); - size->syncToHost(); - width = size->e(1); - height = size->e(0); - } - else { - REQUIRE_TRUE(block.numI() == 2, 0, "resize_area: Resize params already given by the second param. Int params are expensive."); - width = INT_ARG(1); - height = INT_ARG(0); - } - - auto output = OUTPUT_VARIABLE(0); - if (output->isEmpty()) return Status::OK(); - auto inRank = image->rankOf(); - - REQUIRE_TRUE(inRank == 3 || inRank == 4, 0, "resize_area: Source tensor should have rank 4, but %i given.", inRank); - REQUIRE_TRUE(output->rankOf() == inRank, 0, "resize_area: Source tensor and output should have the same rank, but %i and %i given.", inRank, output->rankOf()); - REQUIRE_TRUE(width > 0 , 0, "resize_area: picture width should be positive 32 bit integer, but %i given", width); - REQUIRE_TRUE(height > 0 , 0, "resize_area: picture height should be positive 32 bit integer, but %i given", height); - REQUIRE_TRUE(image->lengthOf() > 0, 0, "resize_area: Only non-zero images allowed to processing."); - - auto alignCorners = false; - if (block.numB() > 0) { - alignCorners = B_ARG(0); - } - - auto source = inRank == 4?image->reshape(image->ordering(), {image->sizeAt(0), image->sizeAt(1), image->sizeAt(2), image->sizeAt(3)}):image->reshape(image->ordering(), {1, image->sizeAt(0), image->sizeAt(1), image->sizeAt(2)}); - auto target = inRank == 4?output->reshape(output->ordering(), {output->sizeAt(0), output->sizeAt(1), output->sizeAt(2), output->sizeAt(3)}, false) : output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)}, false); - - return helpers::resizeAreaFunctor(block.launchContext(), &source, width, height, alignCorners, &target); - } - - DECLARE_SHAPE_FN(resize_area) { - auto shapeList = SHAPELIST(); - auto in = inputShape->at(0); - - Nd4jLong* outputShape; - auto inRank = shape::rank(in); - int width; - int height; - if (block.width() == 2) { - auto newImageSize = INPUT_VARIABLE(1); - REQUIRE_TRUE(newImageSize->lengthOf() == 2, 0, - "resize_area: Resize params is a pair of values, not %i.", newImageSize->lengthOf()); - REQUIRE_TRUE(block.numI() <= 1, 0, - "resize_area: Resize params already given by the second param. Int params are expensive."); - width = newImageSize->e(1); - height = newImageSize->e(0); - } - else { - REQUIRE_TRUE(block.numI() == 2, 0, "resize_area: Resize params ommited as pair ints nor int tensor."); - width = INT_ARG(1); - height = INT_ARG(0); - } - - REQUIRE_TRUE(inRank == 4 || inRank == 3, 0, "resize_area: Source tensor should have rank 4, but %i given.", inRank); - - ALLOCATE(outputShape, block.getWorkspace(), shape::shapeInfoLength(inRank), Nd4jLong); - outputShape[0] = inRank; - if (inRank == 4) { - outputShape[1] = in[1]; - outputShape[2] = height; - outputShape[3] = width; - outputShape[4] = in[4]; - } - else { - outputShape[1] = height; - outputShape[2] = width; - outputShape[3] = in[3]; - } - ShapeUtils::updateStridesAndType(outputShape, DataType::FLOAT32, shape::order(in)); - - shapeList->push_back(CONSTANT(outputShape)); - return shapeList; - } - DECLARE_TYPES(resize_area) { - getOpDescriptor() - ->setAllowedInputTypes(0, {ALL_FLOATS, ALL_INTS}) - ->setAllowedInputTypes(1, DataType::INT32) - ->setAllowedOutputTypes({DataType::FLOAT32}); - } - - } +namespace ops { +CUSTOM_OP_IMPL(resize_area, 1, 1, false, 0, -2) { + auto image = INPUT_VARIABLE(0); + int width; + int height; + + if (block.width() == 2) { + auto size = INPUT_VARIABLE(1); // integer vector with shape {2} and content + // (new_height, new_width) + REQUIRE_TRUE(size->rankOf() == 1, size->lengthOf() == 2, 0, + "resize_area: Resize params is a pair of values, not %i.", + size->lengthOf()); + size->syncToHost(); + width = size->e(1); + height = size->e(0); + } else { + REQUIRE_TRUE(block.numI() == 2, 0, + "resize_area: Resize params already given by the second " + "param. Int params are expensive."); + width = INT_ARG(1); + height = INT_ARG(0); + } + + auto output = OUTPUT_VARIABLE(0); + if (output->isEmpty()) return Status::OK(); + auto inRank = image->rankOf(); + + REQUIRE_TRUE(inRank == 3 || inRank == 4, 0, + "resize_area: Source tensor should have rank 4, but %i given.", + inRank); + REQUIRE_TRUE(output->rankOf() == inRank, 0, + "resize_area: Source tensor and output should have the same " + "rank, but %i and %i given.", + inRank, output->rankOf()); + REQUIRE_TRUE(width > 0, 0, + "resize_area: picture width should be positive 32 bit integer, " + "but %i given", + width); + REQUIRE_TRUE(height > 0, 0, + "resize_area: picture height should be positive 32 bit integer, " + "but %i given", + height); + REQUIRE_TRUE(image->lengthOf() > 0, 0, + "resize_area: Only non-zero images allowed to processing."); + + auto alignCorners = false; + if (block.numB() > 0) { + alignCorners = B_ARG(0); + } + + auto source = inRank == 4 + ? image->reshape(image->ordering(), + {image->sizeAt(0), image->sizeAt(1), + image->sizeAt(2), image->sizeAt(3)}) + : image->reshape(image->ordering(), + {1, image->sizeAt(0), image->sizeAt(1), + image->sizeAt(2)}); + auto target = inRank == 4 + ? output->reshape(output->ordering(), + {output->sizeAt(0), output->sizeAt(1), + output->sizeAt(2), output->sizeAt(3)}, + false) + : output->reshape(output->ordering(), + {1, output->sizeAt(0), output->sizeAt(1), + output->sizeAt(2)}, + false); + + return helpers::resizeAreaFunctor(block.launchContext(), &source, width, + height, alignCorners, &target); } +DECLARE_SHAPE_FN(resize_area) { + auto shapeList = SHAPELIST(); + auto in = inputShape->at(0); + + Nd4jLong* outputShape; + auto inRank = shape::rank(in); + int width; + int height; + if (block.width() == 2) { + auto newImageSize = INPUT_VARIABLE(1); + REQUIRE_TRUE(newImageSize->lengthOf() == 2, 0, + "resize_area: Resize params is a pair of values, not %i.", + newImageSize->lengthOf()); + REQUIRE_TRUE(block.numI() <= 1, 0, + "resize_area: Resize params already given by the second " + "param. Int params are expensive."); + width = newImageSize->e(1); + height = newImageSize->e(0); + } else { + REQUIRE_TRUE( + block.numI() == 2, 0, + "resize_area: Resize params ommited as pair ints nor int tensor."); + width = INT_ARG(1); + height = INT_ARG(0); + } + + REQUIRE_TRUE(inRank == 4 || inRank == 3, 0, + "resize_area: Source tensor should have rank 4, but %i given.", + inRank); + + ALLOCATE(outputShape, block.workspace(), shape::shapeInfoLength(inRank), + Nd4jLong); + outputShape[0] = inRank; + if (inRank == 4) { + outputShape[1] = in[1]; + outputShape[2] = height; + outputShape[3] = width; + outputShape[4] = in[4]; + } else { + outputShape[1] = height; + outputShape[2] = width; + outputShape[3] = in[3]; + } + ShapeUtils::updateStridesAndType(outputShape, DataType::FLOAT32, + shape::order(in)); + + shapeList->push_back(CONSTANT(outputShape)); + return shapeList; +} +DECLARE_TYPES(resize_area) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_FLOATS, ALL_INTS}) + ->setAllowedInputTypes(1, DataType::INT32) + ->setAllowedOutputTypes({DataType::FLOAT32}); +} + +} // namespace ops +} // namespace sd + #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/images/resize_bicubic.cpp b/libnd4j/include/ops/declarable/generic/images/resize_bicubic.cpp index a867a2147421..e24f8d20ed4c 100644 --- a/libnd4j/include/ops/declarable/generic/images/resize_bicubic.cpp +++ b/libnd4j/include/ops/declarable/generic/images/resize_bicubic.cpp @@ -25,90 +25,136 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(resize_bicubic, 2, 1, false, 0, 0) { +namespace ops { +CUSTOM_OP_IMPL(resize_bicubic, 2, 1, false, 0, 0) { + auto image = INPUT_VARIABLE(0); + auto size = INPUT_VARIABLE( + 1); // integer vector with shape {2} and content (new_height, new_width) + size->syncToHost(); + auto output = OUTPUT_VARIABLE(0); + int width; + int height; + auto inRank = image->rankOf(); + if (output->isEmpty()) return Status::OK(); - auto image = INPUT_VARIABLE(0); - auto size = INPUT_VARIABLE(1); // integer vector with shape {2} and content (new_height, new_width) - size->syncToHost(); - auto output = OUTPUT_VARIABLE(0); - int width; - int height; - auto inRank = image->rankOf(); - if (output->isEmpty()) return Status::OK(); + REQUIRE_TRUE( + inRank == 3 || inRank == 4, 0, + "resize_bicubic: Source tensor should have rank 4, but %i given.", + inRank); + REQUIRE_TRUE(output->rankOf() == inRank, 0, + "resize_bicubic: Source tensor and output should have the same " + "rank, but %i and %i given.", + inRank, output->rankOf()); + REQUIRE_TRUE(size->rankOf() == 1, size->lengthOf() == 2, 0, + "resize_bicubic: Resize params is a pair of values, not %i.", + size->lengthOf()); + REQUIRE_TRUE(block.numI() <= 1, 0, + "resize_bicubic: Resize params already given by the second " + "param. Int params are expensive."); + width = size->e(1); + height = size->e(0); + REQUIRE_TRUE(width > 0, 0, + "resize_bicubic: picture width should be positive 32 bit " + "integer, but %i given", + width); + REQUIRE_TRUE(height > 0, 0, + "resize_bicubic: picture height should be positive 32 bit " + "integer, but %i given", + height); + // REQUIRE_TRUE(image->sizeAt(1) > 3 && image->sizeAt(2) > 3, 0, + // "resize_cubic: To use bicubic algorithm need at least 16 pixels as + // source."); + REQUIRE_TRUE(width > 3 && height > 3, 0, + "resize_bicubic: To use bicubic algorithm need at least 16 " + "pixels as target."); + REQUIRE_TRUE(image->lengthOf() > 0, 0, + "resize_bicubic: Only non-zero images allowed to processing."); + // auto method = 1; //kResizeBilinear; + // if (block.numI() == 1) { + // method = INT_ARG(0); + // } + auto alignCorners = false; + auto halfPixelAlign = false; + if (block.numB() > 0) { + alignCorners = block.getBArguments().at(0); + if (block.numB() > 1) halfPixelAlign = block.getBArguments().at(1); + } + REQUIRE_TRUE(!halfPixelAlign || (halfPixelAlign && !alignCorners), 0, + "resize_bicubic: `half_pixel_centers' should be false or true " + "only when `align_corners' is false"); - REQUIRE_TRUE(inRank == 3 || inRank == 4, 0, "resize_bicubic: Source tensor should have rank 4, but %i given.", inRank); - REQUIRE_TRUE(output->rankOf() == inRank, 0, "resize_bicubic: Source tensor and output should have the same rank, but %i and %i given.", inRank, output->rankOf()); - REQUIRE_TRUE(size->rankOf() == 1, size->lengthOf() == 2, 0, "resize_bicubic: Resize params is a pair of values, not %i.", size->lengthOf()); - REQUIRE_TRUE(block.numI() <= 1, 0, "resize_bicubic: Resize params already given by the second param. Int params are expensive."); - width = size->e(1); - height = size->e(0); - REQUIRE_TRUE(width > 0 , 0, "resize_bicubic: picture width should be positive 32 bit integer, but %i given", width); - REQUIRE_TRUE(height > 0 , 0, "resize_bicubic: picture height should be positive 32 bit integer, but %i given", height); - //REQUIRE_TRUE(image->sizeAt(1) > 3 && image->sizeAt(2) > 3, 0, "resize_cubic: To use bicubic algorithm need at least 16 pixels as source."); - REQUIRE_TRUE(width > 3 && height > 3, 0, "resize_bicubic: To use bicubic algorithm need at least 16 pixels as target."); - REQUIRE_TRUE(image->lengthOf() > 0, 0, "resize_bicubic: Only non-zero images allowed to processing."); -// auto method = 1; //kResizeBilinear; -// if (block.numI() == 1) { -// method = INT_ARG(0); -// } - auto alignCorners = false; - auto halfPixelAlign = false; - if (block.numB() > 0) { - alignCorners = block.getBArguments()->at(0); - if (block.numB()> 1) - halfPixelAlign = block.getBArguments()->at(1); - } - REQUIRE_TRUE(!halfPixelAlign || (halfPixelAlign && !alignCorners), 0, "resize_bicubic: `half_pixel_centers' should be false or true only when `align_corners' is false"); + auto source = inRank == 4 + ? image->reshape(image->ordering(), + {image->sizeAt(0), image->sizeAt(1), + image->sizeAt(2), image->sizeAt(3)}) + : image->reshape(image->ordering(), + {1, image->sizeAt(0), image->sizeAt(1), + image->sizeAt(2)}); + auto target = inRank == 4 + ? output->reshape(output->ordering(), + {output->sizeAt(0), output->sizeAt(1), + output->sizeAt(2), output->sizeAt(3)}, + false) + : output->reshape(output->ordering(), + {1, output->sizeAt(0), output->sizeAt(1), + output->sizeAt(2)}, + false); - auto source = inRank == 4?image->reshape(image->ordering(), {image->sizeAt(0), image->sizeAt(1), image->sizeAt(2), image->sizeAt(3)}):image->reshape(image->ordering(), {1, image->sizeAt(0), image->sizeAt(1), image->sizeAt(2)}); - auto target = inRank == 4?output->reshape(output->ordering(), {output->sizeAt(0), output->sizeAt(1), output->sizeAt(2), output->sizeAt(3)}, false) : output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)}, false); - - return helpers::resizeBicubicFunctorA(block.launchContext(), &source, width, height, alignCorners, halfPixelAlign, &target); - } - - DECLARE_SHAPE_FN(resize_bicubic) { - auto shapeList = SHAPELIST(); - auto in = inputShape->at(0); + return helpers::resizeBicubicFunctorA(block.launchContext(), &source, width, + height, alignCorners, halfPixelAlign, + &target); +} - Nd4jLong* outputShape; - auto inRank = shape::rank(in); - int width; - int height; - auto newImageSize = INPUT_VARIABLE(1); - REQUIRE_TRUE(newImageSize->lengthOf() == 2, 0, "resize_bicubic: Resize params is a pair of values, not %i.", newImageSize->lengthOf()); - REQUIRE_TRUE(block.numI() <= 1, 0, "resize_bicubic: Resize params already given by the second param. Int params are expensive."); - width = newImageSize->e(0); - height = newImageSize->e(1); +DECLARE_SHAPE_FN(resize_bicubic) { + auto shapeList = SHAPELIST(); + auto in = inputShape->at(0); - REQUIRE_TRUE(inRank == 4 || inRank == 3, 0, "resize_bicubic: Source tensor should have rank 4, but %i given.", inRank); + Nd4jLong* outputShape; + auto inRank = shape::rank(in); + int width; + int height; + auto newImageSize = INPUT_VARIABLE(1); + REQUIRE_TRUE(newImageSize->lengthOf() == 2, 0, + "resize_bicubic: Resize params is a pair of values, not %i.", + newImageSize->lengthOf()); + REQUIRE_TRUE(block.numI() <= 1, 0, + "resize_bicubic: Resize params already given by the second " + "param. Int params are expensive."); + width = newImageSize->e(0); + height = newImageSize->e(1); - ALLOCATE(outputShape, block.getWorkspace(), shape::shapeInfoLength(inRank), Nd4jLong); - outputShape[0] = inRank; - if (inRank == 4) { - outputShape[1] = in[1]; - outputShape[2] = width; - outputShape[3] = height; - outputShape[4] = in[4]; - } - else { - outputShape[1] = width; - outputShape[2] = height; - outputShape[3] = in[3]; - } - ShapeUtils::updateStridesAndType(outputShape, DataType::FLOAT32, shape::order(in)); + REQUIRE_TRUE( + inRank == 4 || inRank == 3, 0, + "resize_bicubic: Source tensor should have rank 4, but %i given.", + inRank); - shapeList->push_back(CONSTANT(outputShape)); - return shapeList; - } - DECLARE_TYPES(resize_bicubic) { - getOpDescriptor() - ->setAllowedInputTypes(0, {ALL_FLOATS, ALL_INTS}) - ->setAllowedInputTypes(1, DataType::INT32) - ->setAllowedOutputTypes({DataType::FLOAT32}); - } + ALLOCATE(outputShape, block.workspace(), shape::shapeInfoLength(inRank), + Nd4jLong); + outputShape[0] = inRank; + if (inRank == 4) { + outputShape[1] = in[1]; + outputShape[2] = width; + outputShape[3] = height; + outputShape[4] = in[4]; + } else { + outputShape[1] = width; + outputShape[2] = height; + outputShape[3] = in[3]; + } + ShapeUtils::updateStridesAndType(outputShape, DataType::FLOAT32, + shape::order(in)); - } + shapeList->push_back(CONSTANT(outputShape)); + return shapeList; } +DECLARE_TYPES(resize_bicubic) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_FLOATS, ALL_INTS}) + ->setAllowedInputTypes(1, DataType::INT32) + ->setAllowedOutputTypes({DataType::FLOAT32}); +} + +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/images/resize_linear.cpp b/libnd4j/include/ops/declarable/generic/images/resize_linear.cpp index 6d72bf889728..fb1a40a4baf9 100644 --- a/libnd4j/include/ops/declarable/generic/images/resize_linear.cpp +++ b/libnd4j/include/ops/declarable/generic/images/resize_linear.cpp @@ -26,105 +26,135 @@ #include #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(resize_bilinear, 1, 1, false, 0, -2) { - - NDArray* image = INPUT_VARIABLE(0); - NDArray* output = OUTPUT_VARIABLE(0); - int width; - int height; - bool alignCorners = false; // - default value - auto inRank = image->rankOf(); - if (output->isEmpty()) return Status::OK(); - - REQUIRE_TRUE( inRank == 4 || inRank == 3, 0, "resize_bilinear: input image should be 4D " - "tensor, but input has rank %i", - image->rankOf()); - REQUIRE_TRUE(inRank == output->rankOf(), 0, "resize_bilinear: Input and output ranks should be equals, but %i and %i occured.", inRank, output->rankOf()); - - auto source = inRank == 4?image->reshape(image->ordering(), {image->sizeAt(0), image->sizeAt(1), image->sizeAt(2), image->sizeAt(3)}):image->reshape(image->ordering(), {1, image->sizeAt(0), image->sizeAt(1), image->sizeAt(2)}); - auto target = inRank == 4?output->reshape(output->ordering(), {output->sizeAt(0), output->sizeAt(1), output->sizeAt(2), output->sizeAt(3)}, false) : output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)}, false); - - if (block.width() > 1) { - auto newImageSize = INPUT_VARIABLE(1); - REQUIRE_TRUE(newImageSize->lengthOf() == 2, 0, "resize_bilinear: Resize params is a pair of values, not %i.", newImageSize->lengthOf()); - REQUIRE_TRUE(block.numI() <= 1, 0, "resize_bilinear: Resize params already given by the second param. Int params are expensive."); - height = newImageSize->e(0); - width = newImageSize->e(1); - } - else { - REQUIRE_TRUE(block.numI() > 1, 0, "resize_bilinear: Neither resize width nor height are provided."); - height = INT_ARG(0); - width = INT_ARG(1); - } - - if (block.numB() > 0) - alignCorners = B_ARG(0); - bool halfPixelCenter = false; - - if (block.numB() > 1) - halfPixelCenter = B_ARG(1); - - REQUIRE_TRUE(!halfPixelCenter || (halfPixelCenter && !alignCorners), 0, "resize_bilinear: `half_pixel_centers' should be false or true only when `align_corners' is false"); - - return helpers::resizeBilinearFunctor(block.launchContext(), inRank==4?image:&source, width, height, alignCorners, halfPixelCenter, inRank == 4 ? output : &target); - } - - DECLARE_SHAPE_FN(resize_bilinear) { - auto shapeList = SHAPELIST(); - auto in = inputShape->at(0); - - Nd4jLong* outputShape; - auto inRank = shape::rank(in); - REQUIRE_TRUE(inRank == 4 || inRank == 3, 0, "resize_bilinear: input image should be 4D " - "tensor, but input has rank %i", - inRank); - - int width; - int height; - if (block.width() > 1) { - auto newImageSize = INPUT_VARIABLE(1); - REQUIRE_TRUE(newImageSize->lengthOf() == 2, 0, "resize_bilinear: Resize params is a pair of values, not %i.", newImageSize->lengthOf()); - REQUIRE_TRUE(block.numI() <= 1, 0, "resize_bilinear: Resize params already given by the second param. Int params are expensive."); - width = newImageSize->e(0); - height = newImageSize->e(1); - } - else { - REQUIRE_TRUE(block.numI() == 2, 0, "resize_bilinear: Neither resize width nor height are provided."); - width = INT_ARG(0); - height = INT_ARG(1); - } - - ALLOCATE(outputShape, block.getWorkspace(), shape::shapeInfoLength(inRank), Nd4jLong); - outputShape[0] = inRank; - if (inRank == 4) { - outputShape[1] = in[1]; - outputShape[2] = width; - outputShape[3] = height; - outputShape[4] = in[4]; - } - else { // input shape is 3D, so result also should be 3D - outputShape[1] = width; - outputShape[2] = height; - outputShape[3] = in[3]; - } - if (DataTypeUtils::isR(ArrayOptions::dataType(in))) { - ShapeUtils::updateStridesAndType(outputShape, in, shape::order(in)); - } - else { - ShapeUtils::updateStridesAndType(outputShape, DataType::FLOAT32, shape::order(in)); - } - - shapeList->push_back(CONSTANT(outputShape)); - return shapeList; - } - DECLARE_TYPES(resize_bilinear) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } - - } +namespace ops { +CUSTOM_OP_IMPL(resize_bilinear, 1, 1, false, 0, -2) { + NDArray* image = INPUT_VARIABLE(0); + NDArray* output = OUTPUT_VARIABLE(0); + int width; + int height; + bool alignCorners = false; // - default value + auto inRank = image->rankOf(); + if (output->isEmpty()) return Status::OK(); + + REQUIRE_TRUE(inRank == 4 || inRank == 3, 0, + "resize_bilinear: input image should be 4D " + "tensor, but input has rank %i", + image->rankOf()); + REQUIRE_TRUE(inRank == output->rankOf(), 0, + "resize_bilinear: Input and output ranks should be equals, but " + "%i and %i occured.", + inRank, output->rankOf()); + + auto source = inRank == 4 + ? image->reshape(image->ordering(), + {image->sizeAt(0), image->sizeAt(1), + image->sizeAt(2), image->sizeAt(3)}) + : image->reshape(image->ordering(), + {1, image->sizeAt(0), image->sizeAt(1), + image->sizeAt(2)}); + auto target = inRank == 4 + ? output->reshape(output->ordering(), + {output->sizeAt(0), output->sizeAt(1), + output->sizeAt(2), output->sizeAt(3)}, + false) + : output->reshape(output->ordering(), + {1, output->sizeAt(0), output->sizeAt(1), + output->sizeAt(2)}, + false); + + if (block.width() > 1) { + auto newImageSize = INPUT_VARIABLE(1); + REQUIRE_TRUE(newImageSize->lengthOf() == 2, 0, + "resize_bilinear: Resize params is a pair of values, not %i.", + newImageSize->lengthOf()); + REQUIRE_TRUE(block.numI() <= 1, 0, + "resize_bilinear: Resize params already given by the second " + "param. Int params are expensive."); + height = newImageSize->e(0); + width = newImageSize->e(1); + } else { + REQUIRE_TRUE( + block.numI() > 1, 0, + "resize_bilinear: Neither resize width nor height are provided."); + height = INT_ARG(0); + width = INT_ARG(1); + } + + if (block.numB() > 0) alignCorners = B_ARG(0); + bool halfPixelCenter = false; + + if (block.numB() > 1) halfPixelCenter = B_ARG(1); + + REQUIRE_TRUE(!halfPixelCenter || (halfPixelCenter && !alignCorners), 0, + "resize_bilinear: `half_pixel_centers' should be false or true " + "only when `align_corners' is false"); + + return helpers::resizeBilinearFunctor( + block.launchContext(), inRank == 4 ? image : &source, width, height, + alignCorners, halfPixelCenter, inRank == 4 ? output : &target); } +DECLARE_SHAPE_FN(resize_bilinear) { + auto shapeList = SHAPELIST(); + auto in = inputShape->at(0); + + Nd4jLong* outputShape; + auto inRank = shape::rank(in); + REQUIRE_TRUE(inRank == 4 || inRank == 3, 0, + "resize_bilinear: input image should be 4D " + "tensor, but input has rank %i", + inRank); + + int width; + int height; + if (block.width() > 1) { + auto newImageSize = INPUT_VARIABLE(1); + REQUIRE_TRUE(newImageSize->lengthOf() == 2, 0, + "resize_bilinear: Resize params is a pair of values, not %i.", + newImageSize->lengthOf()); + REQUIRE_TRUE(block.numI() <= 1, 0, + "resize_bilinear: Resize params already given by the second " + "param. Int params are expensive."); + width = newImageSize->e(0); + height = newImageSize->e(1); + } else { + REQUIRE_TRUE( + block.numI() == 2, 0, + "resize_bilinear: Neither resize width nor height are provided."); + width = INT_ARG(0); + height = INT_ARG(1); + } + + ALLOCATE(outputShape, block.workspace(), shape::shapeInfoLength(inRank), + Nd4jLong); + outputShape[0] = inRank; + if (inRank == 4) { + outputShape[1] = in[1]; + outputShape[2] = width; + outputShape[3] = height; + outputShape[4] = in[4]; + } else { // input shape is 3D, so result also should be 3D + outputShape[1] = width; + outputShape[2] = height; + outputShape[3] = in[3]; + } + if (DataTypeUtils::isR(ArrayOptions::dataType(in))) { + ShapeUtils::updateStridesAndType(outputShape, in, shape::order(in)); + } else { + ShapeUtils::updateStridesAndType(outputShape, DataType::FLOAT32, + shape::order(in)); + } + + shapeList->push_back(CONSTANT(outputShape)); + return shapeList; +} +DECLARE_TYPES(resize_bilinear) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} + +} // namespace ops +} // namespace sd + #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/images/resize_neighbor.cpp b/libnd4j/include/ops/declarable/generic/images/resize_neighbor.cpp index 3454fb897e6e..80e912a2509e 100644 --- a/libnd4j/include/ops/declarable/generic/images/resize_neighbor.cpp +++ b/libnd4j/include/ops/declarable/generic/images/resize_neighbor.cpp @@ -27,97 +27,137 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(resize_nearest_neighbor, 1, 1, false, 0, -2) { +namespace ops { +CUSTOM_OP_IMPL(resize_nearest_neighbor, 1, 1, false, 0, -2) { + auto image = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + auto inRank = image->rankOf(); + int width; + int height; + bool alignCorners = false; // - default value + if (output->isEmpty()) return Status::OK(); + if (block.width() > 1) { + auto newImageSize = INPUT_VARIABLE(1); + REQUIRE_TRUE( + newImageSize->lengthOf() == 2, 0, + "resize_nearest_neighbor: Resize params is a pair of values, not %i.", + newImageSize->lengthOf()); + REQUIRE_TRUE(block.numI() <= 1, 0, + "resize_nearest_neighbor: Resize params already given by the " + "second param. Int params are expensive."); + height = newImageSize->e(0); + width = newImageSize->e(1); + } else { + REQUIRE_TRUE(block.numI() == 2, 0, + "resize_nearest_neighbor: Neither resize width nor height are " + "provided."); + height = INT_ARG(0); + width = INT_ARG(1); + } + if (block.numB() > 0) alignCorners = B_ARG(0); + bool halfPixelCenter = false; - auto image = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - auto inRank = image->rankOf(); - int width; - int height; - bool alignCorners = false; // - default value - if (output->isEmpty()) return Status::OK(); - if (block.width() > 1) { - auto newImageSize = INPUT_VARIABLE(1); - REQUIRE_TRUE(newImageSize->lengthOf() == 2, 0, "resize_nearest_neighbor: Resize params is a pair of values, not %i.", newImageSize->lengthOf()); - REQUIRE_TRUE(block.numI() <= 1, 0, "resize_nearest_neighbor: Resize params already given by the second param. Int params are expensive."); - height = newImageSize->e(0); - width = newImageSize->e(1); - } - else { - REQUIRE_TRUE(block.numI() == 2, 0, "resize_nearest_neighbor: Neither resize width nor height are provided."); - height = INT_ARG(0); - width = INT_ARG(1); - } - if (block.numB() > 0) - alignCorners = B_ARG(0); - bool halfPixelCenter = false; + if (block.numB() > 1) halfPixelCenter = B_ARG(1); + REQUIRE_TRUE( + width <= (1 << 24) || height <= (1 << 24), 0, + "resize_nearest_neighbor: the image resize should be limited to 2^24 " + "pixels both for height and width, but %d and %d were given.", + height, width); + REQUIRE_TRUE(inRank == 4 || inRank == 3, 0, + "resize_nearest_neighbor: Input should be 4D tensor, but rank " + "%i occured"); + REQUIRE_TRUE(inRank == output->rankOf(), 0, + "resize_nearest_neighbor: Input and output ranks should be " + "equals, but %i and %i occured.", + inRank, output->rankOf()); + REQUIRE_TRUE(image->dataType() == output->dataType(), 0, + "resize_nearest_neighbor: Input and output types should be the " + "same, but `%s' occured instead.", + DataTypeUtils::asString(output->dataType()).c_str()); + REQUIRE_TRUE(!halfPixelCenter || (halfPixelCenter && !alignCorners), 0, + "resize_nearest_neighbor: `half_pixel_centers' should be false " + "or true only when `align_corners' is false"); + REQUIRE_TRUE(((alignCorners && height > 2) || (height > 0)) && + ((alignCorners && width > 1) || (width > 0)), + 0, + "resize_nearest_neighbor: Wrong input or output size to resize " + "(width = %d, height = %d)", + width, height); - if (block.numB() > 1) - halfPixelCenter = B_ARG(1); - REQUIRE_TRUE(width <= (1 << 24) || height <= (1 << 24), 0, "resize_nearest_neighbor: the image resize should be limited to 2^24 pixels both for height and width, but %d and %d were given.", height, width); - REQUIRE_TRUE(inRank == 4 || inRank == 3, 0, "resize_nearest_neighbor: Input should be 4D tensor, but rank %i occured"); - REQUIRE_TRUE(inRank == output->rankOf(), 0, "resize_nearest_neighbor: Input and output ranks should be equals, but %i and %i occured.", inRank, output->rankOf()); - REQUIRE_TRUE(image->dataType() == output->dataType(), 0, "resize_nearest_neighbor: Input and output types should be the same, but `%s' occured instead.", DataTypeUtils::asString(output->dataType()).c_str()); - REQUIRE_TRUE(!halfPixelCenter || (halfPixelCenter && !alignCorners), 0, "resize_nearest_neighbor: `half_pixel_centers' should be false or true only when `align_corners' is false"); - REQUIRE_TRUE(((alignCorners && height > 2) || (height > 0)) && ((alignCorners && width > 1) || (width > 0)), 0, "resize_nearest_neighbor: Wrong input or output size to resize (width = %d, height = %d)", width, height); + auto source = inRank == 4 + ? *image + : image->reshape(image->ordering(), + {1, image->sizeAt(0), image->sizeAt(1), + image->sizeAt(2)}); + auto target = inRank == 4 + ? *output + : output->reshape(output->ordering(), + {1, output->sizeAt(0), output->sizeAt(1), + output->sizeAt(2)}, + false); - auto source = inRank == 4?*image:image->reshape(image->ordering(), {1, image->sizeAt(0), image->sizeAt(1), image->sizeAt(2)}); - auto target = inRank == 4 ? *output : output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)}, false); - - return helpers::resizeNeighborFunctor(block.launchContext(), inRank==4?image:&source, width, height, alignCorners, halfPixelCenter, inRank == 4 ? output : &target); - } - - DECLARE_SHAPE_FN(resize_nearest_neighbor) { - auto shapeList = SHAPELIST(); - auto in = inputShape->at(0); - auto inRank = shape::rank(in); - Nd4jLong* outputShape; + return helpers::resizeNeighborFunctor( + block.launchContext(), inRank == 4 ? image : &source, width, height, + alignCorners, halfPixelCenter, inRank == 4 ? output : &target); +} - REQUIRE_TRUE(inRank == 4 || inRank == 3, 0, "resize_nearest_neighbor: input image should be 4D " - "tensor, but input has rank %i", - inRank); +DECLARE_SHAPE_FN(resize_nearest_neighbor) { + auto shapeList = SHAPELIST(); + auto in = inputShape->at(0); + auto inRank = shape::rank(in); + Nd4jLong* outputShape; - int width; - int height; - if (block.width() > 1) { - auto newImageSize = INPUT_VARIABLE(1); - REQUIRE_TRUE(newImageSize->lengthOf() == 2, 0, "resize_nearest_neighbor: Resize params is a pair of values, not %i.", newImageSize->lengthOf()); - REQUIRE_TRUE(block.numI() <= 1, 0, "resize_nearest_neighbor: Resize params already given by the second param. Int params are expensive."); - width = newImageSize->e(0); - height = newImageSize->e(1); - } - else { - REQUIRE_TRUE(block.numI() <= 3, 0, "resize_nearest_neighbor: Neither resize width nor height are provided."); - width = INT_ARG(0); - height = INT_ARG(1); - } + REQUIRE_TRUE(inRank == 4 || inRank == 3, 0, + "resize_nearest_neighbor: input image should be 4D " + "tensor, but input has rank %i", + inRank); - ALLOCATE(outputShape, block.getWorkspace(), shape::shapeInfoLength(inRank), Nd4jLong); - outputShape[0] = inRank; - if (inRank == 4) { - outputShape[1] = in[1]; - outputShape[2] = width; - outputShape[3] = height; - outputShape[4] = in[4]; - } - else { // input shape is 3D, so result also should be 3D - outputShape[1] = width; - outputShape[2] = height; - outputShape[3] = in[3]; - } - ShapeUtils::updateStridesAndType(outputShape, in, shape::order(in)); + int width; + int height; + if (block.width() > 1) { + auto newImageSize = INPUT_VARIABLE(1); + REQUIRE_TRUE( + newImageSize->lengthOf() == 2, 0, + "resize_nearest_neighbor: Resize params is a pair of values, not %i.", + newImageSize->lengthOf()); + REQUIRE_TRUE(block.numI() <= 1, 0, + "resize_nearest_neighbor: Resize params already given by the " + "second param. Int params are expensive."); + width = newImageSize->e(0); + height = newImageSize->e(1); + } else { + REQUIRE_TRUE(block.numI() <= 3, 0, + "resize_nearest_neighbor: Neither resize width nor height are " + "provided."); + width = INT_ARG(0); + height = INT_ARG(1); + } - shapeList->push_back(CONSTANT(outputShape)); - return shapeList; - } - DECLARE_TYPES(resize_nearest_neighbor) { - getOpDescriptor() - ->setAllowedInputTypes({ALL_INTS, ALL_FLOATS}) - ->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS}); - } + ALLOCATE(outputShape, block.workspace(), shape::shapeInfoLength(inRank), + Nd4jLong); + outputShape[0] = inRank; + if (inRank == 4) { + outputShape[1] = in[1]; + outputShape[2] = width; + outputShape[3] = height; + outputShape[4] = in[4]; + } else { // input shape is 3D, so result also should be 3D + outputShape[1] = width; + outputShape[2] = height; + outputShape[3] = in[3]; + } + ShapeUtils::updateStridesAndType(outputShape, in, shape::order(in)); - } + shapeList->push_back(CONSTANT(outputShape)); + return shapeList; } +DECLARE_TYPES(resize_nearest_neighbor) { + getOpDescriptor() + ->setAllowedInputTypes({ALL_INTS, ALL_FLOATS}) + ->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS}); +} + +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/images/rgbToGrs.cpp b/libnd4j/include/ops/declarable/generic/images/rgbToGrs.cpp index a6d80365c9e2..2ac7635c13f0 100644 --- a/libnd4j/include/ops/declarable/generic/images/rgbToGrs.cpp +++ b/libnd4j/include/ops/declarable/generic/images/rgbToGrs.cpp @@ -18,57 +18,77 @@ // @author Oleh Semeniv (oleg.semeniv@gmail.com) // -#include -#include -#include #include +#include +#include +#include namespace sd { namespace ops { CUSTOM_OP_IMPL(rgb_to_grs, 1, 1, false, 0, 0) { + const auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - const auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - const int inRank = input->rankOf(); - const int argSize = block.getIArguments()->size(); - const int dimC = argSize > 0 ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + inRank) : inRank - 1; + const int inRank = input->rankOf(); + const int argSize = block.numI(); + const int dimC = argSize > 0 + ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + inRank) + : inRank - 1; - REQUIRE_TRUE(inRank >= 1, 0, "RGBtoGrayScale: Fails to meet the inRank requirement: %i >= 1 ", inRank); - if (argSize > 0) { - REQUIRE_TRUE(dimC >= 0 && dimC < inRank, 0, "Index of the Channel dimension out of range: %i not in [%i,%i) ", INT_ARG(0), -inRank, inRank); - } - REQUIRE_TRUE(input->sizeAt(dimC) == 3, 0, "RGBGrayScale: operation expects 3 channels (R, G, B) in last dimention, but received %i instead", input->sizeAt(dimC)); + REQUIRE_TRUE(inRank >= 1, 0, + "RGBtoGrayScale: Fails to meet the inRank requirement: %i >= 1 ", + inRank); + if (argSize > 0) { + REQUIRE_TRUE( + dimC >= 0 && dimC < inRank, 0, + "Index of the Channel dimension out of range: %i not in [%i,%i) ", + INT_ARG(0), -inRank, inRank); + } + REQUIRE_TRUE(input->sizeAt(dimC) == 3, 0, + "RGBGrayScale: operation expects 3 channels (R, G, B) in last " + "dimention, but received %i instead", + input->sizeAt(dimC)); - helpers::transformRgbGrs(block.launchContext(), *input, *output, dimC); - return Status::OK(); + helpers::transformRgbGrs(block.launchContext(), *input, *output, dimC); + return Status::OK(); } DECLARE_TYPES(rgb_to_grs) { - getOpDescriptor()->setAllowedInputTypes( {ALL_INTS, ALL_FLOATS} ) - ->setSameMode(true); + getOpDescriptor() + ->setAllowedInputTypes({ALL_INTS, ALL_FLOATS}) + ->setSameMode(true); } DECLARE_SHAPE_FN(rgb_to_grs) { + const auto input = INPUT_VARIABLE(0); + const int inRank = input->rankOf(); - const auto input = INPUT_VARIABLE(0); - const int inRank = input->rankOf(); - - const int argSize = block.getIArguments()->size(); - const int dimC = argSize > 0 ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + inRank) : inRank - 1; + const int argSize = block.numI(); + const int dimC = argSize > 0 + ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + inRank) + : inRank - 1; - REQUIRE_TRUE(inRank >= 1, 0, "RGBtoGrayScale: Fails to meet the inRank requirement: %i >= 1 ", inRank); - if (argSize > 0) { - REQUIRE_TRUE(dimC >= 0 && dimC < inRank, 0, "Index of the Channel dimension out of range: %i not in [%i,%i) ", INT_ARG(0), -inRank, inRank); - } - REQUIRE_TRUE(input->sizeAt(dimC) == 3, 0, "RGBtoGrayScale: operation expects 3 channels (R, B, G) in last dimention, but received %i", dimC); + REQUIRE_TRUE(inRank >= 1, 0, + "RGBtoGrayScale: Fails to meet the inRank requirement: %i >= 1 ", + inRank); + if (argSize > 0) { + REQUIRE_TRUE( + dimC >= 0 && dimC < inRank, 0, + "Index of the Channel dimension out of range: %i not in [%i,%i) ", + INT_ARG(0), -inRank, inRank); + } + REQUIRE_TRUE(input->sizeAt(dimC) == 3, 0, + "RGBtoGrayScale: operation expects 3 channels (R, B, G) in last " + "dimention, but received %i", + dimC); - auto nShape = input->getShapeAsVector(); - nShape[dimC] = 1; + auto nShape = input->getShapeAsVector(); + nShape[dimC] = 1; - return SHAPELIST(ConstantShapeHelper::getInstance().createShapeInfo(input->dataType(), input->ordering(), nShape)); + return SHAPELIST(ConstantShapeHelper::getInstance().createShapeInfo( + input->dataType(), input->ordering(), nShape)); } -} -} +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/generic/images/rgbToHsv.cpp b/libnd4j/include/ops/declarable/generic/images/rgbToHsv.cpp index ac5a27c667d8..ae5931cff989 100644 --- a/libnd4j/include/ops/declarable/generic/images/rgbToHsv.cpp +++ b/libnd4j/include/ops/declarable/generic/images/rgbToHsv.cpp @@ -18,45 +18,47 @@ // @author AbdelRauf (rauf@konduit.ai) // - - -#include -#include -#include #include +#include +#include +#include namespace sd { namespace ops { CONFIGURABLE_OP_IMPL(rgb_to_hsv, 1, 1, true, 0, 0) { - - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - if (input->isEmpty()) - return Status::OK(); - - const int rank = input->rankOf(); - const int argSize = block.getIArguments()->size(); - const int dimC = argSize > 0 ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) : rank - 1; - - REQUIRE_TRUE(rank >= 1, 0, "RGBtoHSV: Fails to meet the rank requirement: %i >= 1 ", rank); - if (argSize > 0) { - REQUIRE_TRUE(dimC >= 0 && dimC < rank, 0, "Index of the Channel dimension out of range: %i not in [%i,%i) ", INT_ARG(0), -rank, rank); - } - REQUIRE_TRUE(input->sizeAt(dimC) == 3, 0, "RGBtoHSV: operation expects 3 channels (R, G, B), but got %i instead", input->sizeAt(dimC)); - - helpers::transformRgbHsv(block.launchContext(), input, output, dimC); - - return Status::OK(); + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + if (input->isEmpty()) return Status::OK(); + + const int rank = input->rankOf(); + const int argSize = block.numI(); + const int dimC = argSize > 0 + ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) + : rank - 1; + + REQUIRE_TRUE(rank >= 1, 0, + "RGBtoHSV: Fails to meet the rank requirement: %i >= 1 ", rank); + if (argSize > 0) { + REQUIRE_TRUE( + dimC >= 0 && dimC < rank, 0, + "Index of the Channel dimension out of range: %i not in [%i,%i) ", + INT_ARG(0), -rank, rank); + } + REQUIRE_TRUE( + input->sizeAt(dimC) == 3, 0, + "RGBtoHSV: operation expects 3 channels (R, G, B), but got %i instead", + input->sizeAt(dimC)); + + helpers::transformRgbHsv(block.launchContext(), input, output, dimC); + + return Status::OK(); } - DECLARE_TYPES(rgb_to_hsv) { - getOpDescriptor()->setAllowedInputTypes({ ALL_FLOATS }) - ->setSameMode(true); + getOpDescriptor()->setAllowedInputTypes({ALL_FLOATS})->setSameMode(true); } - -} -} +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/generic/images/rgbToYiq.cpp b/libnd4j/include/ops/declarable/generic/images/rgbToYiq.cpp index 40c936e4f995..aaccd9720772 100644 --- a/libnd4j/include/ops/declarable/generic/images/rgbToYiq.cpp +++ b/libnd4j/include/ops/declarable/generic/images/rgbToYiq.cpp @@ -14,47 +14,50 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - // - // @author AbdelRauf (rauf@konduit.ai) - // +// +// @author AbdelRauf (rauf@konduit.ai) +// -#include -#include -#include #include +#include +#include +#include namespace sd { - namespace ops { - - - - CONFIGURABLE_OP_IMPL(rgb_to_yiq, 1, 1, true, 0, 0) { - - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - if (input->isEmpty()) - return Status::OK(); - - const int rank = input->rankOf(); - const int arg_size = block.getIArguments()->size(); - const int dimC = arg_size > 0 ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) : rank - 1; - - REQUIRE_TRUE(rank >= 1, 0, "RGBtoYIQ: Fails to meet the rank requirement: %i >= 1 ", rank); - if (arg_size > 0) { - REQUIRE_TRUE(dimC >= 0 && dimC < rank, 0, "Index of the Channel dimension out of range: %i not in [%i,%i) ", INT_ARG(0), -rank, rank); - } - REQUIRE_TRUE(input->sizeAt(dimC) == 3, 0, "RGBtoYIQ: operation expects 3 channels (R, G, B), but got %i instead", input->sizeAt(dimC)); - - helpers::transformRgbYiq(block.launchContext(), input, output, dimC); - - return Status::OK(); - } - +namespace ops { + +CONFIGURABLE_OP_IMPL(rgb_to_yiq, 1, 1, true, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + if (input->isEmpty()) return Status::OK(); + + const int rank = input->rankOf(); + const int arg_size = block.numI(); + const int dimC = arg_size > 0 + ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) + : rank - 1; + + REQUIRE_TRUE(rank >= 1, 0, + "RGBtoYIQ: Fails to meet the rank requirement: %i >= 1 ", rank); + if (arg_size > 0) { + REQUIRE_TRUE( + dimC >= 0 && dimC < rank, 0, + "Index of the Channel dimension out of range: %i not in [%i,%i) ", + INT_ARG(0), -rank, rank); + } + REQUIRE_TRUE( + input->sizeAt(dimC) == 3, 0, + "RGBtoYIQ: operation expects 3 channels (R, G, B), but got %i instead", + input->sizeAt(dimC)); + + helpers::transformRgbYiq(block.launchContext(), input, output, dimC); + + return Status::OK(); +} - DECLARE_TYPES(rgb_to_yiq) { - getOpDescriptor()->setAllowedInputTypes({ ALL_FLOATS }) - ->setSameMode(true); - } - } +DECLARE_TYPES(rgb_to_yiq) { + getOpDescriptor()->setAllowedInputTypes({ALL_FLOATS})->setSameMode(true); } +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/generic/images/rgbToYuv.cpp b/libnd4j/include/ops/declarable/generic/images/rgbToYuv.cpp index b52b5a8a6b6c..cb9418d6d0fa 100644 --- a/libnd4j/include/ops/declarable/generic/images/rgbToYuv.cpp +++ b/libnd4j/include/ops/declarable/generic/images/rgbToYuv.cpp @@ -18,44 +18,48 @@ // @author Oleh Semeniv (oleg.semeniv@gmail.com) // - - -#include -#include -#include #include +#include +#include +#include namespace sd { namespace ops { CONFIGURABLE_OP_IMPL(rgb_to_yuv, 1, 1, true, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - // just skip op if input is empty - if (input->isEmpty()) - return Status::OK(); + // just skip op if input is empty + if (input->isEmpty()) return Status::OK(); - const int rank = input->rankOf(); - const int argSize = block.getIArguments()->size(); - const int dimC = argSize > 0 ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) : rank - 1; + const int rank = input->rankOf(); + const int argSize = block.numI(); + const int dimC = argSize > 0 + ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) + : rank - 1; - REQUIRE_TRUE(rank >= 1, 0, "RGBtoYUV: Fails to meet the rank requirement: %i >= 1 ", rank); - if (argSize > 0) { - REQUIRE_TRUE(dimC >= 0 && dimC < rank, 0, "Index of the Channel dimension out of range: %i not in [%i,%i) ", INT_ARG(0), -rank, rank); - } - REQUIRE_TRUE(input->sizeAt(dimC) == 3, 0, "RGBtoYUV: operation expects 3 channels (R, G, B), but got %i instead", input->sizeAt(dimC)); + REQUIRE_TRUE(rank >= 1, 0, + "RGBtoYUV: Fails to meet the rank requirement: %i >= 1 ", rank); + if (argSize > 0) { + REQUIRE_TRUE( + dimC >= 0 && dimC < rank, 0, + "Index of the Channel dimension out of range: %i not in [%i,%i) ", + INT_ARG(0), -rank, rank); + } + REQUIRE_TRUE( + input->sizeAt(dimC) == 3, 0, + "RGBtoYUV: operation expects 3 channels (R, G, B), but got %i instead", + input->sizeAt(dimC)); - helpers::transformRgbYuv(block.launchContext(), *input, *output, dimC); + helpers::transformRgbYuv(block.launchContext(), *input, *output, dimC); - return Status::OK(); + return Status::OK(); } DECLARE_TYPES(rgb_to_yuv) { - getOpDescriptor()->setAllowedInputTypes({ ALL_FLOATS }) - ->setSameMode(true); + getOpDescriptor()->setAllowedInputTypes({ALL_FLOATS})->setSameMode(true); } -} -} +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/generic/images/yiqToRgb.cpp b/libnd4j/include/ops/declarable/generic/images/yiqToRgb.cpp index e339fb74b678..14841f089460 100644 --- a/libnd4j/include/ops/declarable/generic/images/yiqToRgb.cpp +++ b/libnd4j/include/ops/declarable/generic/images/yiqToRgb.cpp @@ -16,46 +16,49 @@ // // @author AbdelRauf (rauf@konduit.ai) -// +// -#include -#include -#include #include +#include +#include +#include namespace sd { - namespace ops { +namespace ops { +CONFIGURABLE_OP_IMPL(yiq_to_rgb, 1, 1, true, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + if (input->isEmpty()) return Status::OK(); - CONFIGURABLE_OP_IMPL(yiq_to_rgb, 1, 1, true, 0, 0) { + const int rank = input->rankOf(); + const int arg_size = block.numI(); + const int dimC = arg_size > 0 + ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) + : rank - 1; - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); + REQUIRE_TRUE(rank >= 1, 0, + "YIQtoRGB: Fails to meet the rank requirement: %i >= 1 ", rank); + if (arg_size > 0) { + REQUIRE_TRUE( + dimC >= 0 && dimC < rank, 0, + "Index of the Channel dimension out of range: %i not in [%i,%i) ", + INT_ARG(0), -rank, rank); + } + REQUIRE_TRUE( + input->sizeAt(dimC) == 3, 0, + "YIQtoRGB: operation expects 3 channels (Y, I, Q), but got %i instead", + input->sizeAt(dimC)); - if (input->isEmpty()) - return Status::OK(); + helpers::transformYiqRgb(block.launchContext(), input, output, dimC); - const int rank = input->rankOf(); - const int arg_size = block.getIArguments()->size(); - const int dimC = arg_size > 0 ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) : rank - 1; - - REQUIRE_TRUE(rank >= 1, 0, "YIQtoRGB: Fails to meet the rank requirement: %i >= 1 ", rank); - if (arg_size > 0) { - REQUIRE_TRUE(dimC >= 0 && dimC < rank, 0, "Index of the Channel dimension out of range: %i not in [%i,%i) ", INT_ARG(0), -rank, rank); - } - REQUIRE_TRUE(input->sizeAt(dimC) == 3, 0, "YIQtoRGB: operation expects 3 channels (Y, I, Q), but got %i instead", input->sizeAt(dimC)); - - helpers::transformYiqRgb(block.launchContext(), input, output, dimC); - - return Status::OK(); - } - + return Status::OK(); +} - DECLARE_TYPES(yiq_to_rgb) { - getOpDescriptor()->setAllowedInputTypes({ ALL_FLOATS }) - ->setSameMode(true); - } - - } +DECLARE_TYPES(yiq_to_rgb) { + getOpDescriptor()->setAllowedInputTypes({ALL_FLOATS})->setSameMode(true); } + +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/generic/images/yuvToRgb.cpp b/libnd4j/include/ops/declarable/generic/images/yuvToRgb.cpp index 48d4e379a9eb..a1e3b62d3f36 100644 --- a/libnd4j/include/ops/declarable/generic/images/yuvToRgb.cpp +++ b/libnd4j/include/ops/declarable/generic/images/yuvToRgb.cpp @@ -16,45 +16,50 @@ // // @author Oleh Semeniv (oleg.semeniv@gmail.com) -// +// -#include -#include -#include #include +#include +#include +#include namespace sd { namespace ops { CONFIGURABLE_OP_IMPL(yuv_to_rgb, 1, 1, true, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); + // just skip op if input is empty + if (input->isEmpty()) return Status::OK(); - // just skip op if input is empty - if (input->isEmpty()) - return Status::OK(); + const int rank = input->rankOf(); + const int argSize = block.numI(); + const int dimC = argSize > 0 + ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) + : rank - 1; - const int rank = input->rankOf(); - const int argSize = block.getIArguments()->size(); - const int dimC = argSize > 0 ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) : rank - 1; + REQUIRE_TRUE(rank >= 1, 0, + "YUVtoRGB: Fails to meet the rank requirement: %i >= 1 ", rank); + if (argSize > 0) { + REQUIRE_TRUE( + dimC >= 0 && dimC < rank, 0, + "Index of the Channel dimension out of range: %i not in [%i,%i) ", + INT_ARG(0), -rank, rank); + } + REQUIRE_TRUE( + input->sizeAt(dimC) == 3, 0, + "YUVtoRGB: operation expects 3 channels (Y, U, V), but got %i instead", + input->sizeAt(dimC)); - REQUIRE_TRUE(rank >= 1, 0, "YUVtoRGB: Fails to meet the rank requirement: %i >= 1 ", rank); - if (argSize > 0) { - REQUIRE_TRUE(dimC >= 0 && dimC < rank, 0, "Index of the Channel dimension out of range: %i not in [%i,%i) ", INT_ARG(0), -rank, rank); - } - REQUIRE_TRUE(input->sizeAt(dimC) == 3, 0, "YUVtoRGB: operation expects 3 channels (Y, U, V), but got %i instead", input->sizeAt(dimC)); + helpers::transformYuvRgb(block.launchContext(), *input, *output, dimC); - helpers::transformYuvRgb(block.launchContext(), *input, *output, dimC); - - return Status::OK(); + return Status::OK(); } DECLARE_TYPES(yuv_to_rgb) { - getOpDescriptor()->setAllowedInputTypes({ ALL_FLOATS }) - ->setSameMode(true); + getOpDescriptor()->setAllowedInputTypes({ALL_FLOATS})->setSameMode(true); } - -} -} +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/generic/kernels/knn_mindistance.cpp b/libnd4j/include/ops/declarable/generic/kernels/knn_mindistance.cpp index 334014ee7c66..4f49e458459c 100644 --- a/libnd4j/include/ops/declarable/generic/kernels/knn_mindistance.cpp +++ b/libnd4j/include/ops/declarable/generic/kernels/knn_mindistance.cpp @@ -25,35 +25,42 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(knn_mindistance, 3, 1, false, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto lowest = INPUT_VARIABLE(1); - auto highest = INPUT_VARIABLE(2); +namespace ops { - auto output = OUTPUT_VARIABLE(0); +CUSTOM_OP_IMPL(knn_mindistance, 3, 1, false, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto lowest = INPUT_VARIABLE(1); + auto highest = INPUT_VARIABLE(2); - REQUIRE_TRUE(input->lengthOf() == lowest->lengthOf() && input->lengthOf() == highest->lengthOf(), 0, "knn_mindistance: all input arrays must have same length"); - REQUIRE_TRUE(input->dataType() == lowest->dataType() && input->dataType() == highest->dataType() && input->dataType() == output->dataType(), 0, "knn_mindistance: all inputs must have the same data type"); + auto output = OUTPUT_VARIABLE(0); - helpers::knn_mindistance(*input, *lowest, *highest, *output); + REQUIRE_TRUE(input->lengthOf() == lowest->lengthOf() && + input->lengthOf() == highest->lengthOf(), + 0, "knn_mindistance: all input arrays must have same length"); + REQUIRE_TRUE(input->dataType() == lowest->dataType() && + input->dataType() == highest->dataType() && + input->dataType() == output->dataType(), + 0, "knn_mindistance: all inputs must have the same data type"); - return Status::OK(); - } + helpers::knn_mindistance(*input, *lowest, *highest, *output); - DECLARE_SHAPE_FN(knn_mindistance) { - auto input = inputShape->at(0); + return Status::OK(); +} + +DECLARE_SHAPE_FN(knn_mindistance) { + auto input = inputShape->at(0); - // always return scalar here - return SHAPELIST(ConstantShapeHelper::getInstance().scalarShapeInfo(ArrayOptions::dataType(input))); - } + // always return scalar here + return SHAPELIST(ConstantShapeHelper::getInstance().scalarShapeInfo( + ArrayOptions::dataType(input))); +} - DECLARE_TYPES(knn_mindistance) { - getOpDescriptor() - ->setAllowedInputTypes({ALL_FLOATS}) - ->setAllowedOutputTypes({ALL_FLOATS}); - } - } +DECLARE_TYPES(knn_mindistance) { + getOpDescriptor() + ->setAllowedInputTypes({ALL_FLOATS}) + ->setAllowedOutputTypes({ALL_FLOATS}); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/linalg/betaInc.cpp b/libnd4j/include/ops/declarable/generic/linalg/betaInc.cpp index 1850f10a1651..92af4e719673 100644 --- a/libnd4j/include/ops/declarable/generic/linalg/betaInc.cpp +++ b/libnd4j/include/ops/declarable/generic/linalg/betaInc.cpp @@ -25,48 +25,55 @@ #include namespace sd { -namespace ops { +namespace ops { - DECLARE_TYPES(betainc) { - getOpDescriptor() - ->setAllowedInputTypes({ALL_FLOATS}) - ->setSameMode(true); - } +DECLARE_TYPES(betainc) { + getOpDescriptor()->setAllowedInputTypes({ALL_FLOATS})->setSameMode(true); +} CONFIGURABLE_OP_IMPL(betainc, 3, 1, false, 0, 0) { - auto a = INPUT_VARIABLE(0); - auto b = INPUT_VARIABLE(1); - auto x = INPUT_VARIABLE(2); - - // just skip op if input is empty - if (x->isEmpty()) { - *x = DataTypeUtils::nanOrZero(); - return Status::OK(); - } - - auto output = OUTPUT_VARIABLE(0); - - REQUIRE_TRUE(a->isSameShape(b) && a->isSameShape(x), 0, "CONFIGURABLE_OP betainc: all three input arrays must have the same shapes, bit got a=%s, b=%s and x=%s instead !", ShapeUtils::shapeAsString(a).c_str(), ShapeUtils::shapeAsString(b).c_str(), ShapeUtils::shapeAsString(x).c_str()); - - Nd4jLong arrLen = a->lengthOf(); - - // FIXME: this stuff should be single op call. No sense rolling over couple of arrays twice - for(Nd4jLong i = 0; i < arrLen; ++i ) { - REQUIRE_TRUE(a->e(i) > 0.f, 0, "BETAINC op: arrays a array must contain only elements > 0 !"); - REQUIRE_TRUE(b->e(i) > 0.f, 0, "BETAINC op: arrays b array must contain only elements > 0 !"); - REQUIRE_TRUE(0.f <= x->e(i) && x->e(i) <= 1.f, 0, "BETAINC op: all elements of x array must be within [0, 1] range!"); - } - - helpers::betaInc(block.launchContext(), *a, *b, *x, *output); + auto a = INPUT_VARIABLE(0); + auto b = INPUT_VARIABLE(1); + auto x = INPUT_VARIABLE(2); + // just skip op if input is empty + if (x->isEmpty()) { + *x = DataTypeUtils::nanOrZero(); return Status::OK(); + } + + auto output = OUTPUT_VARIABLE(0); + + REQUIRE_TRUE(a->isSameShape(b) && a->isSameShape(x), 0, + "CONFIGURABLE_OP betainc: all three input arrays must have the " + "same shapes, bit got a=%s, b=%s and x=%s instead !", + ShapeUtils::shapeAsString(a).c_str(), + ShapeUtils::shapeAsString(b).c_str(), + ShapeUtils::shapeAsString(x).c_str()); + + Nd4jLong arrLen = a->lengthOf(); + + // FIXME: this stuff should be single op call. No sense rolling over couple of + // arrays twice + for (Nd4jLong i = 0; i < arrLen; ++i) { + REQUIRE_TRUE(a->e(i) > 0.f, 0, + "BETAINC op: arrays a array must contain only elements > 0 !"); + REQUIRE_TRUE(b->e(i) > 0.f, 0, + "BETAINC op: arrays b array must contain only elements > 0 !"); + REQUIRE_TRUE( + 0.f <= x->e(i) && x->e(i) <= 1.f, 0, + "BETAINC op: all elements of x array must be within [0, 1] range!"); + } + + helpers::betaInc(block.launchContext(), *a, *b, *x, *output); + + return Status::OK(); } DECLARE_SYN(BetaInc, betainc); DECLARE_SYN(betaInc, betainc); - -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/linalg/cholesky.cpp b/libnd4j/include/ops/declarable/generic/linalg/cholesky.cpp index dfc3830ca6ec..12b24e23a71e 100644 --- a/libnd4j/include/ops/declarable/generic/linalg/cholesky.cpp +++ b/libnd4j/include/ops/declarable/generic/linalg/cholesky.cpp @@ -24,23 +24,32 @@ #include #include namespace sd { - namespace ops { - OP_IMPL(cholesky, 1, 1, true) { - NDArray* input = INPUT_VARIABLE(0); - NDArray* output = OUTPUT_VARIABLE(0); +namespace ops { +OP_IMPL(cholesky, 1, 1, true) { + NDArray* input = INPUT_VARIABLE(0); + NDArray* output = OUTPUT_VARIABLE(0); - REQUIRE_TRUE(input->rankOf() >=2, 0, "cholesky: The rank of input array should not less than 2, but %i is given", input->rankOf()); - REQUIRE_TRUE(input->sizeAt(-1) == input->sizeAt(-2), 0, "cholesky: The last two dimmensions should be equal, but %i and %i are given", input->sizeAt(-1), input->sizeAt(-2)); - REQUIRE_TRUE(helpers::checkCholeskyInput(block.launchContext(), input), 0, "cholesky: The input tensor should be positive-defined and symmetric."); - return helpers::cholesky(block.launchContext(), input, output, block.isInplace()); - } - DECLARE_TYPES(cholesky) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } - - } + REQUIRE_TRUE(input->rankOf() >= 2, 0, + "cholesky: The rank of input array should not less than 2, but " + "%i is given", + input->rankOf()); + REQUIRE_TRUE(input->sizeAt(-1) == input->sizeAt(-2), 0, + "cholesky: The last two dimmensions should be equal, but %i and " + "%i are given", + input->sizeAt(-1), input->sizeAt(-2)); + REQUIRE_TRUE( + helpers::checkCholeskyInput(block.launchContext(), input), 0, + "cholesky: The input tensor should be positive-defined and symmetric."); + return helpers::cholesky(block.launchContext(), input, output, + block.isInplace()); +} +DECLARE_TYPES(cholesky) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } +} // namespace ops +} // namespace sd + #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/linalg/cross.cpp b/libnd4j/include/ops/declarable/generic/linalg/cross.cpp index 0d701cf7191e..daf2718f43c8 100644 --- a/libnd4j/include/ops/declarable/generic/linalg/cross.cpp +++ b/libnd4j/include/ops/declarable/generic/linalg/cross.cpp @@ -26,35 +26,38 @@ namespace sd { namespace ops { - DECLARE_TYPES(cross) { - getOpDescriptor() - ->setAllowedInputTypes({ALL_INTS, ALL_FLOATS}) - ->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS}) - ->setSameMode(true); - } - - OP_IMPL(cross, 2, 1, false) { - auto a = INPUT_VARIABLE(0); - auto b = INPUT_VARIABLE(1); - - REQUIRE_TRUE(a->lengthOf() == b->lengthOf(), 0, "Cross: A and B lengths should match"); - REQUIRE_TRUE(a->rankOf() >= 1 && b->rankOf() >= 1, 0, "Cross: A and B should have rank >= 1"); - - // TODO: we might want to lift this restriction - REQUIRE_TRUE(a->isSameShape(b),0, "Cross: A and B should have equal shape"); - REQUIRE_TRUE(a->sizeAt(-1) == 3, 0, "Cross: outer dimension of A and B should be equal to 3"); - - auto o = OUTPUT_VARIABLE(0); - - if (a->lengthOf() == 3) { - helpers::cross(block.launchContext(), a, b, o); - } else { - helpers::crossBatched(block.launchContext(), a, b, o); - } - - return Status::OK(); - } +DECLARE_TYPES(cross) { + getOpDescriptor() + ->setAllowedInputTypes({ALL_INTS, ALL_FLOATS}) + ->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS}) + ->setSameMode(true); } + +OP_IMPL(cross, 2, 1, false) { + auto a = INPUT_VARIABLE(0); + auto b = INPUT_VARIABLE(1); + + REQUIRE_TRUE(a->lengthOf() == b->lengthOf(), 0, + "Cross: A and B lengths should match"); + REQUIRE_TRUE(a->rankOf() >= 1 && b->rankOf() >= 1, 0, + "Cross: A and B should have rank >= 1"); + + // TODO: we might want to lift this restriction + REQUIRE_TRUE(a->isSameShape(b), 0, "Cross: A and B should have equal shape"); + REQUIRE_TRUE(a->sizeAt(-1) == 3, 0, + "Cross: outer dimension of A and B should be equal to 3"); + + auto o = OUTPUT_VARIABLE(0); + + if (a->lengthOf() == 3) { + helpers::cross(block.launchContext(), a, b, o); + } else { + helpers::crossBatched(block.launchContext(), a, b, o); + } + + return Status::OK(); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/linalg/diag.cpp b/libnd4j/include/ops/declarable/generic/linalg/diag.cpp index d67ca057b2c5..35b5ff1c972b 100644 --- a/libnd4j/include/ops/declarable/generic/linalg/diag.cpp +++ b/libnd4j/include/ops/declarable/generic/linalg/diag.cpp @@ -25,42 +25,42 @@ #include namespace sd { -namespace ops { +namespace ops { -////////////////////////////////////////////////////////////////////////// +////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(diag, 1, 1, false, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - // input validation - REQUIRE_TRUE(input->rankOf() <= 3, 0, "CUSTOM_OP diag: rank of input array must be <= 3 !, but got %i instead", input->rankOf()); + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - // TODO: still not sure if we really want this - output->assign(0); + // input validation + REQUIRE_TRUE( + input->rankOf() <= 3, 0, + "CUSTOM_OP diag: rank of input array must be <= 3 !, but got %i instead", + input->rankOf()); - helpers::diagFunctor(block.launchContext(), input, output); - - return Status::OK(); + // TODO: still not sure if we really want this + output->assign(0); + + helpers::diagFunctor(block.launchContext(), input, output); + + return Status::OK(); } DECLARE_SYN(MatrixDiag, diag); +DECLARE_TYPES(diag) { + getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setSameMode(true); +} - DECLARE_TYPES(diag) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setSameMode(true); - } - -////////////////////////////////////////////////////////////////////////// +////////////////////////////////////////////////////////////////////////// DECLARE_SHAPE_FN(diag) { - const Nd4jLong* inputShapeInfo = inputShape->at(0); + const Nd4jLong* inputShapeInfo = inputShape->at(0); - return SHAPELIST(ShapeUtils::evalDiagShapeInfo(inputShapeInfo, block.workspace())); + return SHAPELIST( + ShapeUtils::evalDiagShapeInfo(inputShapeInfo, block.workspace())); } - -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/linalg/diagPart.cpp b/libnd4j/include/ops/declarable/generic/linalg/diagPart.cpp index 6562a02a8a47..e926a405cb69 100644 --- a/libnd4j/include/ops/declarable/generic/linalg/diagPart.cpp +++ b/libnd4j/include/ops/declarable/generic/linalg/diagPart.cpp @@ -25,58 +25,69 @@ #include namespace sd { -namespace ops { - - CUSTOM_OP_IMPL(diag_part, 1, 1, false, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); +namespace ops { - const int inRank = input->rankOf(); - - // input validation - REQUIRE_TRUE(inRank == 2 || inRank == 4 || inRank == 6, 0, "DIAG_PART op: input array must have rank among following three possible values: 2, 4, 6, but got %i instead !", inRank); - for(int i = 0; i < inRank-1; ++i) - REQUIRE_TRUE(input->sizeAt(i) == input->sizeAt(i+1), 0, "DIAG_PART op: wrong shape of input array %s ! All dimensions must be equal !", ShapeUtils::shapeAsString(input).c_str()); +CUSTOM_OP_IMPL(diag_part, 1, 1, false, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - helpers::diagPartFunctor(block.launchContext(), input, output); + const int inRank = input->rankOf(); - return Status::OK(); - } - DECLARE_SYN(DiagPart, diag_part); + // input validation + REQUIRE_TRUE(inRank == 2 || inRank == 4 || inRank == 6, 0, + "DIAG_PART op: input array must have rank among following three " + "possible values: 2, 4, 6, but got %i instead !", + inRank); + for (int i = 0; i < inRank - 1; ++i) + REQUIRE_TRUE(input->sizeAt(i) == input->sizeAt(i + 1), 0, + "DIAG_PART op: wrong shape of input array %s ! All dimensions " + "must be equal !", + ShapeUtils::shapeAsString(input).c_str()); - DECLARE_TYPES(diag_part) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setSameMode(true); - } + helpers::diagPartFunctor(block.launchContext(), input, output); - DECLARE_SHAPE_FN(diag_part) { - auto inputShapeInfo = inputShape->at(0); + return Status::OK(); +} +DECLARE_SYN(DiagPart, diag_part); - const int inRank = inputShapeInfo[0]; +DECLARE_TYPES(diag_part) { + getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setSameMode(true); +} - // input validation - REQUIRE_TRUE(inRank == 2 || inRank == 4 || inRank == 6, 0, "DIAG_PART op: input array must have rank among following three possible values: 2, 4, 6, but got %i instead !", inRank); - for(int i = 1; i < inRank; ++i) - REQUIRE_TRUE(inputShapeInfo[i] == inputShapeInfo[i+1], 0, "DIAG_PART op: wrong shape of input array %s ! All dimensions must be equal !", ShapeUtils::shapeAsString(inputShapeInfo).c_str()); +DECLARE_SHAPE_FN(diag_part) { + auto inputShapeInfo = inputShape->at(0); - Nd4jLong* outShapeInfo = nullptr; + const int inRank = inputShapeInfo[0]; - int outRank = inRank/2; + // input validation + REQUIRE_TRUE(inRank == 2 || inRank == 4 || inRank == 6, 0, + "DIAG_PART op: input array must have rank among following three " + "possible values: 2, 4, 6, but got %i instead !", + inRank); + for (int i = 1; i < inRank; ++i) + REQUIRE_TRUE(inputShapeInfo[i] == inputShapeInfo[i + 1], 0, + "DIAG_PART op: wrong shape of input array %s ! All dimensions " + "must be equal !", + ShapeUtils::shapeAsString(inputShapeInfo).c_str()); - ALLOCATE(outShapeInfo, block.getWorkspace(), shape::shapeInfoLength(outRank), Nd4jLong); - - outShapeInfo[0] = outRank; - for(int i = 1; i <= outRank; ++i) - outShapeInfo[i] = inputShapeInfo[i]; + Nd4jLong* outShapeInfo = nullptr; - ShapeUtils::updateStridesAndType(outShapeInfo, inputShapeInfo, shape::order(inputShapeInfo)); + int outRank = inRank / 2; - return SHAPELIST(ConstantShapeHelper::getInstance().createFromExisting(outShapeInfo, block.workspace())); - } + ALLOCATE(outShapeInfo, block.workspace(), shape::shapeInfoLength(outRank), + Nd4jLong); + outShapeInfo[0] = outRank; + for (int i = 1; i <= outRank; ++i) outShapeInfo[i] = inputShapeInfo[i]; + ShapeUtils::updateStridesAndType(outShapeInfo, inputShapeInfo, + shape::order(inputShapeInfo)); + + return SHAPELIST(ConstantShapeHelper::getInstance().createFromExisting( + outShapeInfo, block.workspace())); } -} + +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/linalg/digamma.cpp b/libnd4j/include/ops/declarable/generic/linalg/digamma.cpp index 17afcc10b170..5f627a9cdfdd 100644 --- a/libnd4j/include/ops/declarable/generic/linalg/digamma.cpp +++ b/libnd4j/include/ops/declarable/generic/linalg/digamma.cpp @@ -26,25 +26,24 @@ #include namespace sd { -namespace ops { +namespace ops { CONFIGURABLE_OP_IMPL(digamma, 1, 1, false, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto z = OUTPUT_VARIABLE(0); - auto x = INPUT_VARIABLE(0); - auto z = OUTPUT_VARIABLE(0); + helpers::diGamma(block.launchContext(), *x, *z); - helpers::diGamma(block.launchContext(), *x, *z); - - return Status::OK(); + return Status::OK(); } DECLARE_TYPES(digamma) { - getOpDescriptor() - ->setAllowedInputTypes({ALL_FLOATS, ALL_INTS}) - ->setSameMode(true); + getOpDescriptor() + ->setAllowedInputTypes({ALL_FLOATS, ALL_INTS}) + ->setSameMode(true); } -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/linalg/eye.cpp b/libnd4j/include/ops/declarable/generic/linalg/eye.cpp index 4bf33961489a..20174a6de112 100644 --- a/libnd4j/include/ops/declarable/generic/linalg/eye.cpp +++ b/libnd4j/include/ops/declarable/generic/linalg/eye.cpp @@ -21,92 +21,91 @@ #if NOT_EXCLUDED(OP_eye) #include -#include +#include namespace sd { namespace ops { +CUSTOM_OP_IMPL(eye, -2, 1, false, -2, -2) { + helpers::eye(block.launchContext(), *OUTPUT_VARIABLE(0)); - CUSTOM_OP_IMPL(eye, -2, 1, false, -2, -2) { - - helpers::eye(block.launchContext(), *OUTPUT_VARIABLE(0)); - - return Status::OK(); - } - - DECLARE_TYPES(eye) { - getOpDescriptor()->setAllowedInputTypes(0, {ALL_INTS}); - getOpDescriptor()->setAllowedInputTypes(1, {DataType::INT32, DataType::INT64}); - getOpDescriptor()->setAllowedOutputTypes(0, {ALL_FLOATS}); - } - - DECLARE_SHAPE_FN(eye) { - - std::vector params; - - sd::DataType dtype = block.getTArguments()->empty() ? sd::DataType::FLOAT32 : sd::DataTypeUtils::fromInt(T_ARG(0)); - - if(block.width() == 0) { - params = *block.getIArguments(); - } - else { - for (int i = 0; i < block.width(); i++) { - auto input = INPUT_VARIABLE(i); - REQUIRE_TRUE(input->rankOf() == 1, 0, "Inputs to eye should be 1D"); - - for (int e = 0; e < input->lengthOf(); e++) - params.emplace_back(input->e(e)); - } - } - - - REQUIRE_TRUE(params.size() > 0, 0, "Size is not provided for eye op."); - - const bool ordered = (params[0] == -99 || params[0] == -102); // -99 :'c', -102 : 'f' - if (!ordered) - params.insert(params.begin(), -99); - - REQUIRE_TRUE(params.size() > 1, 0, "Size is not provided for eye op."); - - Nd4jLong* outShapeInfo(nullptr); - - const int size = params.size(); + return Status::OK(); +} - switch(size) { +DECLARE_TYPES(eye) { + getOpDescriptor()->setAllowedInputTypes(0, {ALL_INTS}); + getOpDescriptor()->setAllowedInputTypes(1, + {DataType::INT32, DataType::INT64}); + getOpDescriptor()->setAllowedOutputTypes(0, {ALL_FLOATS}); +} - case 2: - ALLOCATE(outShapeInfo, block.getWorkspace(), shape::shapeInfoLength(2), Nd4jLong); - outShapeInfo[0] = 2; - outShapeInfo[1] = params[1]; - outShapeInfo[2] = params[1]; - break; +DECLARE_SHAPE_FN(eye) { + std::vector params; - case 3: - ALLOCATE(outShapeInfo, block.getWorkspace(), shape::shapeInfoLength(2), Nd4jLong); - outShapeInfo[0] = 2; - outShapeInfo[1] = params[1]; - outShapeInfo[2] = params[2]; - break; + sd::DataType dtype = block.getTArguments().empty() + ? sd::DataType::FLOAT32 + : sd::DataTypeUtils::fromInt(T_ARG(0)); - default: - int rank = size-1; - ALLOCATE(outShapeInfo, block.getWorkspace(), shape::shapeInfoLength(rank), Nd4jLong); - outShapeInfo[0] = rank; - outShapeInfo[rank-1] = params[1]; - outShapeInfo[rank] = params[2]; - for(int i = 1; i < rank-1; ++i) - outShapeInfo[i] = params[i+2]; - break; - } + if (block.width() == 0) { + params = block.getIArguments(); + } else { + for (int i = 0; i < block.width(); i++) { + auto input = INPUT_VARIABLE(i); + REQUIRE_TRUE(input->rankOf() == 1, 0, "Inputs to eye should be 1D"); - shape::updateStrides(outShapeInfo, static_cast(-params[0])); - auto result = ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(outShapeInfo, dtype)); - RELEASE(outShapeInfo, block.getWorkspace()); - return SHAPELIST(result); + for (int e = 0; e < input->lengthOf(); e++) + params.emplace_back(input->e(e)); } - - -} + } + + REQUIRE_TRUE(params.size() > 0, 0, "Size is not provided for eye op."); + + const bool ordered = + (params[0] == -99 || params[0] == -102); // -99 :'c', -102 : 'f' + if (!ordered) params.insert(params.begin(), -99); + + REQUIRE_TRUE(params.size() > 1, 0, "Size is not provided for eye op."); + + Nd4jLong* outShapeInfo(nullptr); + + const int size = params.size(); + + switch (size) { + case 2: + ALLOCATE(outShapeInfo, block.workspace(), shape::shapeInfoLength(2), + Nd4jLong); + outShapeInfo[0] = 2; + outShapeInfo[1] = params[1]; + outShapeInfo[2] = params[1]; + break; + + case 3: + ALLOCATE(outShapeInfo, block.workspace(), shape::shapeInfoLength(2), + Nd4jLong); + outShapeInfo[0] = 2; + outShapeInfo[1] = params[1]; + outShapeInfo[2] = params[2]; + break; + + default: + int rank = size - 1; + ALLOCATE(outShapeInfo, block.workspace(), shape::shapeInfoLength(rank), + Nd4jLong); + outShapeInfo[0] = rank; + outShapeInfo[rank - 1] = params[1]; + outShapeInfo[rank] = params[2]; + for (int i = 1; i < rank - 1; ++i) outShapeInfo[i] = params[i + 2]; + break; + } + + shape::updateStrides(outShapeInfo, static_cast(-params[0])); + auto result = ConstantShapeHelper::getInstance().createShapeInfo( + ShapeDescriptor(outShapeInfo, dtype)); + RELEASE(outShapeInfo, block.workspace()); + return SHAPELIST(result); } +} // namespace ops +} // namespace sd + #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/linalg/lgamma.cpp b/libnd4j/include/ops/declarable/generic/linalg/lgamma.cpp index c39f8b55da3e..b01ecba3947a 100644 --- a/libnd4j/include/ops/declarable/generic/linalg/lgamma.cpp +++ b/libnd4j/include/ops/declarable/generic/linalg/lgamma.cpp @@ -26,25 +26,24 @@ #include namespace sd { -namespace ops { +namespace ops { OP_IMPL(lgamma, 1, 1, true) { + auto x = INPUT_VARIABLE(0); + auto z = OUTPUT_VARIABLE(0); - auto x = INPUT_VARIABLE(0); - auto z = OUTPUT_VARIABLE(0); + helpers::lgamma(block.launchContext(), *x, *z); - helpers::lgamma(block.launchContext(), *x, *z); - - return Status::OK(); + return Status::OK(); } DECLARE_TYPES(lgamma) { - getOpDescriptor() - ->setAllowedInputTypes({ALL_FLOATS}) // as TF says - ->setSameMode(true); + getOpDescriptor() + ->setAllowedInputTypes({ALL_FLOATS}) // as TF says + ->setSameMode(true); } -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/linalg/log1p.cpp b/libnd4j/include/ops/declarable/generic/linalg/log1p.cpp index 797ca8b2a487..1e99fc6377bb 100644 --- a/libnd4j/include/ops/declarable/generic/linalg/log1p.cpp +++ b/libnd4j/include/ops/declarable/generic/linalg/log1p.cpp @@ -24,25 +24,25 @@ #include namespace sd { - namespace ops { - OP_IMPL(Log1p, 1, 1, true) { - auto x = INPUT_VARIABLE(0); - auto z = OUTPUT_VARIABLE(0); +namespace ops { +OP_IMPL(Log1p, 1, 1, true) { + auto x = INPUT_VARIABLE(0); + auto z = OUTPUT_VARIABLE(0); - x->applyTransform(transform::Log1p, *z); + x->applyTransform(transform::Log1p, *z); - STORE_RESULT(z); + STORE_RESULT(z); - return Status::OK(); - } - DECLARE_SYN(log1p, Log1p); - } + return Status::OK(); +} +DECLARE_SYN(log1p, Log1p); +} // namespace ops - DECLARE_TYPES(Log1p) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } +DECLARE_TYPES(Log1p) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/linalg/lstsq.cpp b/libnd4j/include/ops/declarable/generic/linalg/lstsq.cpp index 5078ff6f12c9..2be09ce27d79 100644 --- a/libnd4j/include/ops/declarable/generic/linalg/lstsq.cpp +++ b/libnd4j/include/ops/declarable/generic/linalg/lstsq.cpp @@ -25,112 +25,144 @@ #include namespace sd { - namespace ops { - - CUSTOM_OP_IMPL(lstsq, 2, 1, false, 0, 0) { - auto a = INPUT_VARIABLE(0); - auto b = INPUT_VARIABLE(1); - auto z = OUTPUT_NULLIFIED(0); - bool fastFlag = true; - double l2_factor = 0.; - if (block.numB() > 0) { - fastFlag = B_ARG(0); - } - if (block.numT() > 0) { - l2_factor = T_ARG(0); - } - REQUIRE_TRUE(a->rankOf() >=2, 0, "lstsq: The rank of input left tensor should not be less than 2, but %i is given", a->rankOf()); - REQUIRE_TRUE(b->rankOf() >=2, 0, "lstsq: The rank of input right tensor should not be less than 2, but %i is given", b->rankOf()); - -// REQUIRE_TRUE(a->sizeAt(-1) == a->sizeAt(-2), 0, "lstsq: The last two dimmensions should be equal, but %i and %i are given", a->sizeAt(-1), a->sizeAt(-2)); - REQUIRE_TRUE(a->sizeAt(-2) == b->sizeAt(-2), 0, "lstsq: The last dimmension of left part should be equal to prelast of right part, but %i and %i are given", a->sizeAt(-1), b->sizeAt(-2)); - //REQUIRE_TRUE(l2_factor == 0., 0, "lstsq: Implementation of operation is not finished for factor difference from 0."); - if (a->isEmpty() || b->isEmpty() || z->isEmpty()) - return Status::OK(); - - auto res = helpers::leastSquaresSolveFunctor(block.launchContext(), a, b, l2_factor, fastFlag, z); - - return res; - } - - CUSTOM_OP_IMPL(solve_ls, 2, 1, false, 0, 0) { - auto a = INPUT_VARIABLE(0); - auto b = INPUT_VARIABLE(1); - auto z = OUTPUT_NULLIFIED(0); - bool fastFlag = true; - double l2_factor = 0.; - if (block.numB() > 0) { - fastFlag = B_ARG(0); - } - if (block.numT() > 0) { - l2_factor = T_ARG(0); - } - REQUIRE_TRUE(a->rankOf() >=2, 0, "lstsq: The rank of input left tensor should not be less than 2, but %i is given", a->rankOf()); - REQUIRE_TRUE(b->rankOf() >=2, 0, "lstsq: The rank of input right tensor should not be less than 2, but %i is given", b->rankOf()); - -// REQUIRE_TRUE(a->sizeAt(-1) == a->sizeAt(-2), 0, "lstsq: The last two dimmensions should be equal, but %i and %i are given", a->sizeAt(-1), a->sizeAt(-2)); - REQUIRE_TRUE(a->sizeAt(-2) == b->sizeAt(-2), 0, "lstsq: The last dimmension of left part should be equal to prelast of right part, but %i and %i are given", a->sizeAt(-1), b->sizeAt(-2)); - //REQUIRE_TRUE(l2_factor == 0., 0, "lstsq: Implementation of operation is not finished for factor difference from 0."); - auto res = Status::OK(); - if (a->isEmpty() || b->isEmpty() || z->isEmpty()) - return res; - - res = helpers::leastSquaresSolveFunctor(block.launchContext(), a, b, l2_factor, fastFlag, z); - - return res; - } - - DECLARE_SYN(MatrixSolveLs, lstsq); - - DECLARE_SHAPE_FN(lstsq) { - auto in0 = inputShape->at(0); - auto in1 = inputShape->at(1); - auto shapeOf = ShapeUtils::shapeAsVector(in1); - auto rank = shapeOf.size(); - shapeOf[rank - 2] = shape::sizeAt(in0, -1); - - if (shape::isEmpty(in0) || shape::isEmpty(in1)) { - shapeOf[rank - 1] = 0; // set output shape to empty - } - auto resShape = ConstantShapeHelper::getInstance().createShapeInfo(ArrayOptions::dataType(in0), shape::order(in1), shapeOf);//ShapeBuilders::copyShapeInfoAndType(in1, in0, true, block.workspace()); - if (shapeOf[rank - 1] == 0) { -// ArrayOptions::setPropertyBit(resShape, ARRAY_EMPTY); - resShape = ConstantShapeHelper::getInstance().emptyShapeInfo(ArrayOptions::dataType(in0)); - } - return SHAPELIST(resShape); - } - - DECLARE_TYPES(lstsq) { - getOpDescriptor() - ->setAllowedInputTypes({ALL_FLOATS}) - ->setAllowedOutputTypes({ALL_FLOATS}) - ->setSameMode(false); - } - DECLARE_SHAPE_FN(solve_ls) { - auto in0 = inputShape->at(0); - auto in1 = inputShape->at(1); - auto shapeOf = ShapeUtils::shapeAsVector(in1); - auto rank = shapeOf.size(); - shapeOf[rank - 2] = shape::sizeAt(in0, -1); - - if (shape::isEmpty(in0) || shape::isEmpty(in1)) { - shapeOf[rank - 1] = 0; // set output shape to empty - } - auto resShape = ConstantShapeHelper::getInstance().createShapeInfo(ArrayOptions::dataType(in0), shape::order(in1), shapeOf);//ShapeBuilders::copyShapeInfoAndType(in1, in0, true, block.workspace()); - if (shapeOf[rank - 1] == 0) { - resShape = ConstantShapeHelper::getInstance().emptyShapeInfo(ArrayOptions::dataType(in1)); -// ArrayOptions::setPropertyBit(resShape, ARRAY_EMPTY); - } - return SHAPELIST(resShape); - } - - DECLARE_TYPES(solve_ls) { - getOpDescriptor() - ->setAllowedInputTypes({ALL_FLOATS}) - ->setAllowedOutputTypes({ALL_FLOATS}) - ->setSameMode(false); - } - } +namespace ops { + +CUSTOM_OP_IMPL(lstsq, 2, 1, false, 0, 0) { + auto a = INPUT_VARIABLE(0); + auto b = INPUT_VARIABLE(1); + auto z = OUTPUT_NULLIFIED(0); + bool fastFlag = true; + double l2_factor = 0.; + if (block.numB() > 0) { + fastFlag = B_ARG(0); + } + if (block.numT() > 0) { + l2_factor = T_ARG(0); + } + REQUIRE_TRUE(a->rankOf() >= 2, 0, + "lstsq: The rank of input left tensor should not be less than " + "2, but %i is given", + a->rankOf()); + REQUIRE_TRUE(b->rankOf() >= 2, 0, + "lstsq: The rank of input right tensor should not be less than " + "2, but %i is given", + b->rankOf()); + + // REQUIRE_TRUE(a->sizeAt(-1) == a->sizeAt(-2), 0, "lstsq: The last + // two dimmensions should be equal, but %i and %i are given", + // a->sizeAt(-1), a->sizeAt(-2)); + REQUIRE_TRUE(a->sizeAt(-2) == b->sizeAt(-2), 0, + "lstsq: The last dimmension of left part should be equal to " + "prelast of right part, but %i and %i are given", + a->sizeAt(-1), b->sizeAt(-2)); + // REQUIRE_TRUE(l2_factor == 0., 0, "lstsq: Implementation of operation is not + // finished for factor difference from 0."); + if (a->isEmpty() || b->isEmpty() || z->isEmpty()) return Status::OK(); + + auto res = helpers::leastSquaresSolveFunctor(block.launchContext(), a, b, + l2_factor, fastFlag, z); + + return res; } +CUSTOM_OP_IMPL(solve_ls, 2, 1, false, 0, 0) { + auto a = INPUT_VARIABLE(0); + auto b = INPUT_VARIABLE(1); + auto z = OUTPUT_NULLIFIED(0); + bool fastFlag = true; + double l2_factor = 0.; + if (block.numB() > 0) { + fastFlag = B_ARG(0); + } + if (block.numT() > 0) { + l2_factor = T_ARG(0); + } + REQUIRE_TRUE(a->rankOf() >= 2, 0, + "lstsq: The rank of input left tensor should not be less than " + "2, but %i is given", + a->rankOf()); + REQUIRE_TRUE(b->rankOf() >= 2, 0, + "lstsq: The rank of input right tensor should not be less than " + "2, but %i is given", + b->rankOf()); + + // REQUIRE_TRUE(a->sizeAt(-1) == a->sizeAt(-2), 0, "lstsq: The last + // two dimmensions should be equal, but %i and %i are given", + // a->sizeAt(-1), a->sizeAt(-2)); + REQUIRE_TRUE(a->sizeAt(-2) == b->sizeAt(-2), 0, + "lstsq: The last dimmension of left part should be equal to " + "prelast of right part, but %i and %i are given", + a->sizeAt(-1), b->sizeAt(-2)); + // REQUIRE_TRUE(l2_factor == 0., 0, "lstsq: Implementation of operation is not + // finished for factor difference from 0."); + auto res = Status::OK(); + if (a->isEmpty() || b->isEmpty() || z->isEmpty()) return res; + + res = helpers::leastSquaresSolveFunctor(block.launchContext(), a, b, + l2_factor, fastFlag, z); + + return res; +} + +DECLARE_SYN(MatrixSolveLs, lstsq); + +DECLARE_SHAPE_FN(lstsq) { + auto in0 = inputShape->at(0); + auto in1 = inputShape->at(1); + auto shapeOf = ShapeUtils::shapeAsVector(in1); + auto rank = shapeOf.size(); + shapeOf[rank - 2] = shape::sizeAt(in0, -1); + + if (shape::isEmpty(in0) || shape::isEmpty(in1)) { + shapeOf[rank - 1] = 0; // set output shape to empty + } + auto resShape = ConstantShapeHelper::getInstance().createShapeInfo( + ArrayOptions::dataType(in0), shape::order(in1), + shapeOf); // ShapeBuilders::copyShapeInfoAndType(in1, in0, true, + // block.workspace()); + if (shapeOf[rank - 1] == 0) { + // ArrayOptions::setPropertyBit(resShape, ARRAY_EMPTY); + resShape = ConstantShapeHelper::getInstance().emptyShapeInfo( + ArrayOptions::dataType(in0)); + } + return SHAPELIST(resShape); +} + +DECLARE_TYPES(lstsq) { + getOpDescriptor() + ->setAllowedInputTypes({ALL_FLOATS}) + ->setAllowedOutputTypes({ALL_FLOATS}) + ->setSameMode(false); +} +DECLARE_SHAPE_FN(solve_ls) { + auto in0 = inputShape->at(0); + auto in1 = inputShape->at(1); + auto shapeOf = ShapeUtils::shapeAsVector(in1); + auto rank = shapeOf.size(); + shapeOf[rank - 2] = shape::sizeAt(in0, -1); + + if (shape::isEmpty(in0) || shape::isEmpty(in1)) { + shapeOf[rank - 1] = 0; // set output shape to empty + } + auto resShape = ConstantShapeHelper::getInstance().createShapeInfo( + ArrayOptions::dataType(in0), shape::order(in1), + shapeOf); // ShapeBuilders::copyShapeInfoAndType(in1, in0, true, + // block.workspace()); + if (shapeOf[rank - 1] == 0) { + resShape = ConstantShapeHelper::getInstance().emptyShapeInfo( + ArrayOptions::dataType(in1)); + // ArrayOptions::setPropertyBit(resShape, ARRAY_EMPTY); + } + return SHAPELIST(resShape); +} + +DECLARE_TYPES(solve_ls) { + getOpDescriptor() + ->setAllowedInputTypes({ALL_FLOATS}) + ->setAllowedOutputTypes({ALL_FLOATS}) + ->setSameMode(false); +} +} // namespace ops +} // namespace sd + #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/linalg/lup.cpp b/libnd4j/include/ops/declarable/generic/linalg/lup.cpp index e0b1eb8d7e32..21123947c81c 100644 --- a/libnd4j/include/ops/declarable/generic/linalg/lup.cpp +++ b/libnd4j/include/ops/declarable/generic/linalg/lup.cpp @@ -24,45 +24,62 @@ #include #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(lu, 1, 2, false, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto z = OUTPUT_VARIABLE(0); +namespace ops { +CUSTOM_OP_IMPL(lu, 1, 2, false, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto z = OUTPUT_VARIABLE(0); - auto p = OUTPUT_VARIABLE(1); - if (block.getIArguments()->size()) { - DataType dtype = (DataType)INT_ARG(0); - REQUIRE_TRUE(dtype == sd::DataType::INT32 || dtype == sd::DataType::INT64, 0, "lu: Permutation data type should be 32bit or 64bit int only, but '%s' given.", DataTypeUtils::asString(dtype).c_str()); } + auto p = OUTPUT_VARIABLE(1); + if (block.numI()) { + DataType dtype = (DataType)INT_ARG(0); + REQUIRE_TRUE(dtype == sd::DataType::INT32 || dtype == sd::DataType::INT64, + 0, + "lu: Permutation data type should be 32bit or 64bit int only, " + "but '%s' given.", + DataTypeUtils::asString(dtype).c_str()); + } - REQUIRE_TRUE(input->rankOf() >=2, 0, "lu: The rank of input array should not less than 2, but %i is given", input->rankOf()); - REQUIRE_TRUE(input->sizeAt(-1) == input->sizeAt(-2), 0, "lu: The last two dimmensions should be equal, but %i and %i are given", input->sizeAt(-1), input->sizeAt(-2)); + REQUIRE_TRUE( + input->rankOf() >= 2, 0, + "lu: The rank of input array should not less than 2, but %i is given", + input->rankOf()); + REQUIRE_TRUE( + input->sizeAt(-1) == input->sizeAt(-2), 0, + "lu: The last two dimmensions should be equal, but %i and %i are given", + input->sizeAt(-1), input->sizeAt(-2)); - helpers::lu(block.launchContext(), input, z, p); - return Status::OK(); - } - - DECLARE_SHAPE_FN(lu) { - auto in = inputShape->at(0); - auto shapeVector = ShapeUtils::shapeAsVector(in); - auto luShape = ShapeBuilders::copyShapeInfoAndType(in, in, true, block.workspace()); - auto dtype = sd::DataType::INT32; - if (block.getIArguments()->size()) { - dtype = (DataType)INT_ARG(0); - REQUIRE_TRUE(dtype == sd::DataType::INT32 || dtype == sd::DataType::INT64, 0, "lu: Permutation data type should be 32bit or 64bit int only, but '%s' given.", DataTypeUtils::asString(dtype).c_str()); - } - auto luP = ShapeBuilders::createShapeInfo(dtype, shape::order(in), shapeVector.size() - 1, - shapeVector.data(), block.workspace()); - return SHAPELIST(CONSTANT(luShape), CONSTANT(luP)); - } + helpers::lu(block.launchContext(), input, z, p); + return Status::OK(); +} + +DECLARE_SHAPE_FN(lu) { + auto in = inputShape->at(0); + auto shapeVector = ShapeUtils::shapeAsVector(in); + auto luShape = + ShapeBuilders::copyShapeInfoAndType(in, in, true, block.workspace()); + auto dtype = sd::DataType::INT32; + if (block.numI()) { + dtype = (DataType)INT_ARG(0); + REQUIRE_TRUE(dtype == sd::DataType::INT32 || dtype == sd::DataType::INT64, + 0, + "lu: Permutation data type should be 32bit or 64bit int only, " + "but '%s' given.", + DataTypeUtils::asString(dtype).c_str()); + } + auto luP = ShapeBuilders::createShapeInfo( + dtype, shape::order(in), shapeVector.size() - 1, shapeVector.data(), + block.workspace()); + return SHAPELIST(CONSTANT(luShape), CONSTANT(luP)); +} - DECLARE_TYPES(lu) { - getOpDescriptor() - ->setAllowedInputTypes({ALL_FLOATS}) - ->setAllowedOutputTypes(0, {ALL_FLOATS}) - ->setAllowedOutputTypes(1, {sd::DataType::INT32, sd::DataType::INT64}) - ->setSameMode(false); - } - } +DECLARE_TYPES(lu) { + getOpDescriptor() + ->setAllowedInputTypes({ALL_FLOATS}) + ->setAllowedOutputTypes(0, {ALL_FLOATS}) + ->setAllowedOutputTypes(1, {sd::DataType::INT32, sd::DataType::INT64}) + ->setSameMode(false); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/linalg/matrixDiagPart.cpp b/libnd4j/include/ops/declarable/generic/linalg/matrixDiagPart.cpp index db73fac75bec..7452e6b3d21b 100644 --- a/libnd4j/include/ops/declarable/generic/linalg/matrixDiagPart.cpp +++ b/libnd4j/include/ops/declarable/generic/linalg/matrixDiagPart.cpp @@ -21,52 +21,56 @@ #include #include - namespace sd { - namespace ops { - CUSTOM_OP_IMPL(matrix_diag_part, 1, 1, false, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - const int inRank = input->rankOf(); - - REQUIRE_TRUE(inRank >= 2, 0, "CUSTOM_OP matrix_diag_part: input array must have rank >= 2, but %i given!", inRank); +namespace ops { +CUSTOM_OP_IMPL(matrix_diag_part, 1, 1, false, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + const int inRank = input->rankOf(); - output->nullify(); - return helpers::matrixDiagPart(block.launchContext(), input, output); - } + REQUIRE_TRUE(inRank >= 2, 0, + "CUSTOM_OP matrix_diag_part: input array must have rank >= 2, " + "but %i given!", + inRank); - DECLARE_SHAPE_FN(matrix_diag_part) { - Nd4jLong const* outShapeInfo = nullptr; - auto in = inputShape->at(0); - int inRank = shape::rank(in); + output->nullify(); + return helpers::matrixDiagPart(block.launchContext(), input, output); +} - REQUIRE_TRUE(inRank >= 2, 0, "CUSTOM_OP matrix_diag_part: input array must have rank >= 2, but %i given!", inRank); +DECLARE_SHAPE_FN(matrix_diag_part) { + Nd4jLong const* outShapeInfo = nullptr; + auto in = inputShape->at(0); + int inRank = shape::rank(in); - int outRank = inRank - 1; - int lastDimension = sd::math::nd4j_min(shape::sizeAt(in, -1), shape::sizeAt(in, -2)); - if(outRank == 1) { - //output shape is a vector with size min(sizeAt(0), sizeAt(1)) - outShapeInfo = ConstantShapeHelper::getInstance().vectorShapeInfo(lastDimension, ArrayOptions::dataType(in)); - } - else { - Nd4jLong* anShapeInfo; - ALLOCATE(anShapeInfo, block.getWorkspace(), shape::shapeInfoLength(outRank), Nd4jLong); - anShapeInfo[0] = outRank; - for(int i = 0; i < outRank - 1; ++i) - anShapeInfo[i + 1] = shape::sizeAt(in, i); - anShapeInfo[outRank] = lastDimension; + REQUIRE_TRUE(inRank >= 2, 0, + "CUSTOM_OP matrix_diag_part: input array must have rank >= 2, " + "but %i given!", + inRank); - ShapeUtils::updateStridesAndType(anShapeInfo, in, shape::order(in)); - outShapeInfo = CONSTANT(anShapeInfo); - } - return SHAPELIST(outShapeInfo); - } + int outRank = inRank - 1; + int lastDimension = + sd::math::nd4j_min(shape::sizeAt(in, -1), shape::sizeAt(in, -2)); + if (outRank == 1) { + // output shape is a vector with size min(sizeAt(0), sizeAt(1)) + outShapeInfo = ConstantShapeHelper::getInstance().vectorShapeInfo( + lastDimension, ArrayOptions::dataType(in)); + } else { + Nd4jLong* anShapeInfo; + ALLOCATE(anShapeInfo, block.workspace(), shape::shapeInfoLength(outRank), + Nd4jLong); + anShapeInfo[0] = outRank; + for (int i = 0; i < outRank - 1; ++i) + anShapeInfo[i + 1] = shape::sizeAt(in, i); + anShapeInfo[outRank] = lastDimension; - DECLARE_TYPES(matrix_diag_part) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setSameMode(true); - } -} + ShapeUtils::updateStridesAndType(anShapeInfo, in, shape::order(in)); + outShapeInfo = CONSTANT(anShapeInfo); + } + return SHAPELIST(outShapeInfo); } +DECLARE_TYPES(matrix_diag_part) { + getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setSameMode(true); +} +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/generic/linalg/matrixSetDiag.cpp b/libnd4j/include/ops/declarable/generic/linalg/matrixSetDiag.cpp index 222074c81c8d..730491e80599 100644 --- a/libnd4j/include/ops/declarable/generic/linalg/matrixSetDiag.cpp +++ b/libnd4j/include/ops/declarable/generic/linalg/matrixSetDiag.cpp @@ -25,35 +25,50 @@ #include namespace sd { -namespace ops { +namespace ops { CONFIGURABLE_OP_IMPL(matrix_set_diag, 2, 1, false, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto diagonal = INPUT_VARIABLE(1); + auto input = INPUT_VARIABLE(0); + auto diagonal = INPUT_VARIABLE(1); - auto output = OUTPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - REQUIRE_TRUE(diagonal->rankOf() == input->rankOf()-1, 0, "MATRIX_SET_DIAG op: rank of diagonal array must be smaller by one compared to rank of input array, but got %i and %i correspondingly !", diagonal->rankOf(), input->rankOf()); + REQUIRE_TRUE( + diagonal->rankOf() == input->rankOf() - 1, 0, + "MATRIX_SET_DIAG op: rank of diagonal array must be smaller by one " + "compared to rank of input array, but got %i and %i correspondingly !", + diagonal->rankOf(), input->rankOf()); - for(int i = 0; i < diagonal->rankOf() - 1; ++i) - REQUIRE_TRUE(diagonal->sizeAt(i) == input->sizeAt(i), 0, "MATRIX_SET_DIAG op: the shapes of diagonal and input arrays must be equal till last diagonal dimension but one, however got diagonal=%s and input=%s instead !", ShapeUtils::shapeAsString(diagonal).c_str(), ShapeUtils::shapeAsString(input).c_str()); + for (int i = 0; i < diagonal->rankOf() - 1; ++i) + REQUIRE_TRUE(diagonal->sizeAt(i) == input->sizeAt(i), 0, + "MATRIX_SET_DIAG op: the shapes of diagonal and input arrays " + "must be equal till last diagonal dimension but one, however " + "got diagonal=%s and input=%s instead !", + ShapeUtils::shapeAsString(diagonal).c_str(), + ShapeUtils::shapeAsString(input).c_str()); - REQUIRE_TRUE(diagonal->sizeAt(-1) == (int)sd::math::nd4j_min(input->sizeAt(-1), input->sizeAt(-2)), 0, "MATRIX_SET_DIAG op: the value of last dimension of diagonal array must be equal to min(input_last_shape=%i, input_last_but_one_shape=%i), but got %i instead !", input->sizeAt(-1), input->sizeAt(-2), diagonal->sizeAt(-1)); + REQUIRE_TRUE( + diagonal->sizeAt(-1) == (int)sd::math::nd4j_min( + input->sizeAt(-1), input->sizeAt(-2)), + 0, + "MATRIX_SET_DIAG op: the value of last dimension of diagonal array must " + "be equal to min(input_last_shape=%i, input_last_but_one_shape=%i), but " + "got %i instead !", + input->sizeAt(-1), input->sizeAt(-2), diagonal->sizeAt(-1)); - helpers::matrixSetDiag(block.launchContext(), *input, *diagonal, *output, false); + helpers::matrixSetDiag(block.launchContext(), *input, *diagonal, *output, + false); - return Status::OK(); + return Status::OK(); } DECLARE_SYN(MatrixSetDiag, matrix_set_diag); - DECLARE_TYPES(matrix_set_diag) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setSameMode(true); - } - -} +DECLARE_TYPES(matrix_set_diag) { + getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setSameMode(true); } +} // namespace ops +} // namespace sd + #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/linalg/matrix_band_part.cpp b/libnd4j/include/ops/declarable/generic/linalg/matrix_band_part.cpp index 51beff4c8a13..388a23169255 100644 --- a/libnd4j/include/ops/declarable/generic/linalg/matrix_band_part.cpp +++ b/libnd4j/include/ops/declarable/generic/linalg/matrix_band_part.cpp @@ -20,18 +20,18 @@ #include #if NOT_EXCLUDED(OP_matrix_band_part) -#include #include +#include namespace sd { - namespace ops { - CONFIGURABLE_OP_IMPL(matrix_band_part, 1, 1, true, 0, 0) { +namespace ops { +CONFIGURABLE_OP_IMPL(matrix_band_part, 1, 1, true, 0, 0) { + auto input = INPUT_VARIABLE(0); - auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - Nd4jLong minLower(0LL); + Nd4jLong minLower(0LL); Nd4jLong maxUpper(0LL); if (block.width() == 1) { REQUIRE_TRUE(block.numI() == 2, 0, "matrix_band_part: min and max band numbers should be given before."); @@ -45,20 +45,25 @@ namespace sd { REQUIRE_TRUE(minLowerT->isScalar() && maxUpperT->isScalar(), 0, "matrix_band_part: min and max should be scalars, but %i and %i ranks given", minLowerT->rankOf(), maxUpperT->rankOf()); minLower = minLowerT->e(0); maxUpper = maxUpperT->e(0); - } - REQUIRE_TRUE(input->rankOf() >= 2, 0, "matrix_band_part: Input rank should be 2 or greater."); - Nd4jLong N = input->sizeAt(-2); - Nd4jLong M = input->sizeAt(-1); - REQUIRE_TRUE(minLower > -N && minLower < N, 0, "matrix_band_part: lower diagonal count %i should be less than %i.", - minLower, N); - REQUIRE_TRUE(maxUpper > -M && maxUpper < M, 0, "matrix_band_part: upper diagonal count %i should be less than %i.", - maxUpper, M); + }REQUIRE_TRUE(input->rankOf() >= 2, 0, + "matrix_band_part: Input rank should be 2 or greater."); + Nd4jLong N = input->sizeAt(-2); + Nd4jLong M = input->sizeAt(-1); + REQUIRE_TRUE( + minLower > -N && minLower < N, 0, + "matrix_band_part: lower diagonal count %i should be less than %i.", + minLower, N); + REQUIRE_TRUE( + maxUpper > -M && maxUpper < M, 0, + "matrix_band_part: upper diagonal count %i should be less than %i.", + maxUpper, M); + + helpers::matrixBandPart(block.launchContext(), input, output, minLower, + maxUpper); + return ND4J_STATUS_OK; +} +DECLARE_SYN(band_part, matrix_band_part); - helpers::matrixBandPart(block.launchContext(), input, output, minLower, maxUpper); - return ND4J_STATUS_OK; - } - DECLARE_SYN(band_part, matrix_band_part); - } DECLARE_TYPES(matrix_band_part) { getOpDescriptor() @@ -68,5 +73,6 @@ namespace sd { ->setAllowedInputTypes({ALL_INTS, ALL_FLOATS}); } } +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/linalg/matrix_determinant.cpp b/libnd4j/include/ops/declarable/generic/linalg/matrix_determinant.cpp index 7046b69f90de..2ffd62c65054 100644 --- a/libnd4j/include/ops/declarable/generic/linalg/matrix_determinant.cpp +++ b/libnd4j/include/ops/declarable/generic/linalg/matrix_determinant.cpp @@ -18,128 +18,157 @@ // Created by GS at 2/26/2018 // -#include #include #include +#include #if NOT_EXCLUDED(OP_matrix_determinant) namespace sd { - namespace ops { - CUSTOM_OP_IMPL(matrix_determinant, 1, 1, false, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - REQUIRE_TRUE(input->rankOf() >=2, 0, "matrix_determinant: The rank of input array should not less than 2, but %i is given", input->rankOf()); - REQUIRE_TRUE(input->sizeAt(-1) == input->sizeAt(-2), 0, "matrix_determinant: The last two dimmensions should be equal, but %i and %i are given", input->sizeAt(-1), input->sizeAt(-2)); - - return helpers::determinant(block.launchContext(), input, output); - } - - DECLARE_SHAPE_FN(matrix_determinant) { - auto inShape = inputShape->at(0); - - Nd4jLong const* determinantShape; - int targetRank = shape::rank(inShape) - 2; // last two dimensions will be reduced to scalar - - if (targetRank == 0) { // scalar only - determinantShape = ConstantShapeHelper::getInstance().scalarShapeInfo(ArrayOptions::dataType(inShape)); - } - else if (targetRank == 1) { // vector - determinantShape = ConstantShapeHelper::getInstance().vectorShapeInfo(shape::sizeAt(inShape, 0), ArrayOptions::dataType(inShape)); - } - else { // only two last dimensions are excluded - determinantShape = ConstantShapeHelper::getInstance().createShapeInfo(ArrayOptions::dataType(inShape), shape::order(inShape), targetRank, shape::shapeOf(inShape)); - } - return SHAPELIST(determinantShape); - } - - DECLARE_TYPES(matrix_determinant) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } - } +namespace ops { +CUSTOM_OP_IMPL(matrix_determinant, 1, 1, false, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + REQUIRE_TRUE(input->rankOf() >= 2, 0, + "matrix_determinant: The rank of input array should not less " + "than 2, but %i is given", + input->rankOf()); + REQUIRE_TRUE(input->sizeAt(-1) == input->sizeAt(-2), 0, + "matrix_determinant: The last two dimmensions should be equal, " + "but %i and %i are given", + input->sizeAt(-1), input->sizeAt(-2)); + + return helpers::determinant(block.launchContext(), input, output); +} + +DECLARE_SHAPE_FN(matrix_determinant) { + auto inShape = inputShape->at(0); + + Nd4jLong const* determinantShape; + int targetRank = shape::rank(inShape) - + 2; // last two dimensions will be reduced to scalar + + if (targetRank == 0) { // scalar only + determinantShape = ConstantShapeHelper::getInstance().scalarShapeInfo( + ArrayOptions::dataType(inShape)); + } else if (targetRank == 1) { // vector + determinantShape = ConstantShapeHelper::getInstance().vectorShapeInfo( + shape::sizeAt(inShape, 0), ArrayOptions::dataType(inShape)); + } else { // only two last dimensions are excluded + determinantShape = ConstantShapeHelper::getInstance().createShapeInfo( + ArrayOptions::dataType(inShape), shape::order(inShape), targetRank, + shape::shapeOf(inShape)); + } + return SHAPELIST(determinantShape); } +DECLARE_TYPES(matrix_determinant) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} +} // namespace ops +} // namespace sd + #endif #if NOT_EXCLUDED(OP_log_matrix_determinant) namespace sd { - namespace ops { - DECLARE_TYPES(log_matrix_determinant) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } - - CUSTOM_OP_IMPL(log_matrix_determinant, 1, 1, false, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - REQUIRE_TRUE(input->rankOf() >=2, 0, "log_matrix_determinant: The rank of input array should not less than 2, but %i is given", input->rankOf()); - REQUIRE_TRUE(input->sizeAt(-1) == input->sizeAt(-2), 0, "log_matrix_determinant: The last two dimmensions should be equal, but %i and %i are given", input->sizeAt(-1), input->sizeAt(-2)); - - return helpers::logAbsDeterminant(block.launchContext(), input, output); - } - - DECLARE_SHAPE_FN(log_matrix_determinant) { - auto inShape = inputShape->at(0); - - Nd4jLong const* determinantShape; - int targetRank = shape::rank(inShape) - 2; // last two dimensions will be reduced to scalar - - if (targetRank == 0) { // scalar only - determinantShape = ConstantShapeHelper::getInstance().scalarShapeInfo(ArrayOptions::dataType(inShape)); - } - else if (targetRank == 1) { // vector - determinantShape = ConstantShapeHelper::getInstance().vectorShapeInfo(shape::sizeAt(inShape, 0), ArrayOptions::dataType(inShape)); - } - else { // only two last dimensions are excluded - determinantShape = ConstantShapeHelper::getInstance().createShapeInfo(ArrayOptions::dataType(inShape), shape::order(inShape), targetRank, shape::shapeOf(inShape)); - } - return SHAPELIST(determinantShape); - } - } +namespace ops { +DECLARE_TYPES(log_matrix_determinant) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} + +CUSTOM_OP_IMPL(log_matrix_determinant, 1, 1, false, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + REQUIRE_TRUE(input->rankOf() >= 2, 0, + "log_matrix_determinant: The rank of input array should not " + "less than 2, but %i is given", + input->rankOf()); + REQUIRE_TRUE(input->sizeAt(-1) == input->sizeAt(-2), 0, + "log_matrix_determinant: The last two dimmensions should be " + "equal, but %i and %i are given", + input->sizeAt(-1), input->sizeAt(-2)); + + return helpers::logAbsDeterminant(block.launchContext(), input, output); +} + +DECLARE_SHAPE_FN(log_matrix_determinant) { + auto inShape = inputShape->at(0); + + Nd4jLong const* determinantShape; + int targetRank = shape::rank(inShape) - + 2; // last two dimensions will be reduced to scalar + + if (targetRank == 0) { // scalar only + determinantShape = ConstantShapeHelper::getInstance().scalarShapeInfo( + ArrayOptions::dataType(inShape)); + } else if (targetRank == 1) { // vector + determinantShape = ConstantShapeHelper::getInstance().vectorShapeInfo( + shape::sizeAt(inShape, 0), ArrayOptions::dataType(inShape)); + } else { // only two last dimensions are excluded + determinantShape = ConstantShapeHelper::getInstance().createShapeInfo( + ArrayOptions::dataType(inShape), shape::order(inShape), targetRank, + shape::shapeOf(inShape)); + } + return SHAPELIST(determinantShape); } +} // namespace ops +} // namespace sd #endif #if NOT_EXCLUDED(OP_logdet) namespace sd { - namespace ops { - DECLARE_TYPES(logdet) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } - - CUSTOM_OP_IMPL(logdet, 1, 1, false, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_NULLIFIED(0); - - REQUIRE_TRUE(input->rankOf() >=2, 0, "logdet: The rank of input array should not less than 2, but %i is given", input->rankOf()); - REQUIRE_TRUE(input->sizeAt(-1) == input->sizeAt(-2), 0, "logdet: The last two dimmensions should be equal, but %i and %i are given", input->sizeAt(-1), input->sizeAt(-2)); - REQUIRE_TRUE(helpers::checkCholeskyInput(block.launchContext(), input), 0, "logdet: The input tensor should be positive-defined hermitian."); - - return helpers::logdetFunctor(block.launchContext(), input, output); - } - - DECLARE_SHAPE_FN(logdet) { - auto inShape = inputShape->at(0); - - Nd4jLong const* determinantShape; - int targetRank = shape::rank(inShape) - 2; // last two dimensions will be reduced to scalar - - if (targetRank == 0) { // scalar only - determinantShape = ConstantShapeHelper::getInstance().scalarShapeInfo(ArrayOptions::dataType(inShape)); - } - else if (targetRank == 1) { // vector - determinantShape = ConstantShapeHelper::getInstance().vectorShapeInfo(shape::sizeAt(inShape, 0), ArrayOptions::dataType(inShape)); - } - else { // only two last dimensions are excluded - determinantShape = ConstantShapeHelper::getInstance().createShapeInfo(ArrayOptions::dataType(inShape), shape::order(inShape), targetRank, shape::shapeOf(inShape)); - } - return SHAPELIST(determinantShape); - } - } +namespace ops { +DECLARE_TYPES(logdet) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} + +CUSTOM_OP_IMPL(logdet, 1, 1, false, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_NULLIFIED(0); + + REQUIRE_TRUE( + input->rankOf() >= 2, 0, + "logdet: The rank of input array should not less than 2, but %i is given", + input->rankOf()); + REQUIRE_TRUE(input->sizeAt(-1) == input->sizeAt(-2), 0, + "logdet: The last two dimmensions should be equal, but %i and " + "%i are given", + input->sizeAt(-1), input->sizeAt(-2)); + REQUIRE_TRUE( + helpers::checkCholeskyInput(block.launchContext(), input), 0, + "logdet: The input tensor should be positive-defined hermitian."); + + return helpers::logdetFunctor(block.launchContext(), input, output); +} + +DECLARE_SHAPE_FN(logdet) { + auto inShape = inputShape->at(0); + + Nd4jLong const* determinantShape; + int targetRank = shape::rank(inShape) - + 2; // last two dimensions will be reduced to scalar + + if (targetRank == 0) { // scalar only + determinantShape = ConstantShapeHelper::getInstance().scalarShapeInfo( + ArrayOptions::dataType(inShape)); + } else if (targetRank == 1) { // vector + determinantShape = ConstantShapeHelper::getInstance().vectorShapeInfo( + shape::sizeAt(inShape, 0), ArrayOptions::dataType(inShape)); + } else { // only two last dimensions are excluded + determinantShape = ConstantShapeHelper::getInstance().createShapeInfo( + ArrayOptions::dataType(inShape), shape::order(inShape), targetRank, + shape::shapeOf(inShape)); + } + return SHAPELIST(determinantShape); } +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/generic/linalg/matrix_diag.cpp b/libnd4j/include/ops/declarable/generic/linalg/matrix_diag.cpp index 6e95d127de6d..c07f8e36bbae 100644 --- a/libnd4j/include/ops/declarable/generic/linalg/matrix_diag.cpp +++ b/libnd4j/include/ops/declarable/generic/linalg/matrix_diag.cpp @@ -23,46 +23,45 @@ #include namespace sd { -namespace ops { +namespace ops { CUSTOM_OP_IMPL(matrix_diag, 1, 1, false, 0, 0) { + auto diagonal = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - auto diagonal = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); + REQUIRE_TRUE(!diagonal->isScalar(), 0, + "CUSTOM_OP matrix_diag: input diagonal array must be at list a " + "vector, but scalar was given!"); - REQUIRE_TRUE(!diagonal->isScalar(), 0, "CUSTOM_OP matrix_diag: input diagonal array must be at list a vector, but scalar was given!"); + helpers::matrixSetDiag(block.launchContext(), *output, *diagonal, *output, + true); - helpers::matrixSetDiag(block.launchContext(), *output, *diagonal, *output, true); - - return Status::OK(); + return Status::OK(); } DECLARE_SHAPE_FN(matrix_diag) { + Nd4jLong* outShapeInfo = nullptr; + auto in = inputShape->at(0); + int inRank = shape::rank(in); - Nd4jLong* outShapeInfo = nullptr; - auto in = inputShape->at(0); - int inRank = shape::rank(in); - - // if for example diagonal array has shape [A,B,C] then output array has shape [A,B,C,C] + // if for example diagonal array has shape [A,B,C] then output array has + // shape [A,B,C,C] - int outRank = inRank + 1; + int outRank = inRank + 1; - ALLOCATE(outShapeInfo, block.getWorkspace(), shape::shapeInfoLength(outRank), Nd4jLong); - outShapeInfo[0] = outRank; - for(int i = 0; i < inRank; ++i) - outShapeInfo[i + 1] = shape::sizeAt(in, i); - outShapeInfo[outRank] = shape::sizeAt(in, -1); + ALLOCATE(outShapeInfo, block.workspace(), shape::shapeInfoLength(outRank), + Nd4jLong); + outShapeInfo[0] = outRank; + for (int i = 0; i < inRank; ++i) outShapeInfo[i + 1] = shape::sizeAt(in, i); + outShapeInfo[outRank] = shape::sizeAt(in, -1); - ShapeUtils::updateStridesAndType(outShapeInfo, in, shape::order(in)); + ShapeUtils::updateStridesAndType(outShapeInfo, in, shape::order(in)); - return SHAPELIST(CONSTANT(outShapeInfo)); + return SHAPELIST(CONSTANT(outShapeInfo)); } DECLARE_TYPES(matrix_diag) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setSameMode(true); -} + getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setSameMode(true); } -} - +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/generic/linalg/matrix_inverse.cpp b/libnd4j/include/ops/declarable/generic/linalg/matrix_inverse.cpp index 6a595a92b300..b5c5d7bba107 100644 --- a/libnd4j/include/ops/declarable/generic/linalg/matrix_inverse.cpp +++ b/libnd4j/include/ops/declarable/generic/linalg/matrix_inverse.cpp @@ -24,23 +24,27 @@ #include #include namespace sd { - namespace ops { - OP_IMPL(matrix_inverse, 1, 1, true) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); +namespace ops { +OP_IMPL(matrix_inverse, 1, 1, true) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - REQUIRE_TRUE(input->rankOf() >=2, 0, "matrix_inverse: The rank of input array should not less than 2, but %i is given", input->rankOf()); - REQUIRE_TRUE(input->sizeAt(-1) == input->sizeAt(-2), 0, "matrix_inverse: The last two dimmensions should be equal, but %i and %i are given", input->sizeAt(-1), input->sizeAt(-2)); + REQUIRE_TRUE(input->rankOf() >= 2, 0, + "matrix_inverse: The rank of input array should not less than " + "2, but %i is given", + input->rankOf()); + REQUIRE_TRUE(input->sizeAt(-1) == input->sizeAt(-2), 0, + "matrix_inverse: The last two dimmensions should be equal, but " + "%i and %i are given", + input->sizeAt(-1), input->sizeAt(-2)); - return helpers::inverse(block.launchContext(), input, output); - } + return helpers::inverse(block.launchContext(), input, output); +} - DECLARE_TYPES(matrix_inverse) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setSameMode(true); - } - } +DECLARE_TYPES(matrix_inverse) { + getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setSameMode(true); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/linalg/moments.cpp b/libnd4j/include/ops/declarable/generic/linalg/moments.cpp index c8fdf2e48cf9..5b38c1efb6a1 100644 --- a/libnd4j/include/ops/declarable/generic/linalg/moments.cpp +++ b/libnd4j/include/ops/declarable/generic/linalg/moments.cpp @@ -25,68 +25,69 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(moments, 1, 2, false, 0, -2) { - auto input = INPUT_VARIABLE(0); - auto means = OUTPUT_VARIABLE(0); - auto variances = OUTPUT_VARIABLE(1); - - std::vector axis = *block.getIArguments(); - const bool keepDims = block.getTArguments()->size() > 0 ? (bool)T_ARG(0) : false; - - // axis might be dynamic (i.e. tf mode) - if (block.width() > 1 && axis.size() == 0) { - auto axisVector = INPUT_VARIABLE(1); - helpers::adjustAxis(input->rankOf(), axisVector, axis); -// for (int e = 0; e < axisVector->lengthOf(); e++) { -// int ca = (int) axisVector->e(e); -// if (ca < 0) -// ca += input->rankOf(); -// -// axis.emplace_back(ca); -// } - - } - - std::vector& dims = axis; - input->varianceAlongDimension(variance::SummaryStatsVariance, *variances, false, axis); - input->reduceAlongDimension(reduce::Mean, *means, axis, keepDims); - - return Status::OK(); - } - - DECLARE_SHAPE_FN(moments) { - auto axis = *block.getIArguments(); - auto input = INPUT_VARIABLE(0); - - // axis might be dynamic (i.e. tf mode) - if (block.width() > 1 && axis.size() == 0) { - auto axisVector = INPUT_VARIABLE(1); - - for (int e = 0; e < axisVector->lengthOf(); e++) { - int ca = axisVector->e(e); - if (ca < 0) - ca += input->rankOf(); +namespace ops { +CUSTOM_OP_IMPL(moments, 1, 2, false, 0, -2) { + auto input = INPUT_VARIABLE(0); + auto means = OUTPUT_VARIABLE(0); + auto variances = OUTPUT_VARIABLE(1); + + auto axis = block.getIArguments(); + const bool keepDims = block.numT() > 0 ? (bool)T_ARG(0) : false; + + // axis might be dynamic (i.e. tf mode) + if (block.width() > 1 && axis.size() == 0) { + auto axisVector = INPUT_VARIABLE(1); + helpers::adjustAxis(input->rankOf(), axisVector, axis); + // for (int e = 0; e < axisVector->lengthOf(); e++) { + // int ca = (int) axisVector->e(e); + // if (ca < 0) + // ca += input->rankOf(); + // + // axis.emplace_back(ca); + // } + } + + std::vector& dims = axis; + input->varianceAlongDimension(variance::SummaryStatsVariance, *variances, + false, axis); + input->reduceAlongDimension(reduce::Mean, *means, axis, keepDims); + + return Status::OK(); +} - axis.emplace_back(ca); - } +DECLARE_SHAPE_FN(moments) { + auto axis = block.getIArguments(); + auto input = INPUT_VARIABLE(0); - } - //std::vector dims = ShapeUtils::evalDimsToExclude(input->rankOf(), {axis}); - const bool keepDims = block.getTArguments()->size() > 0 ? (bool)T_ARG(0) : false; + // axis might be dynamic (i.e. tf mode) + if (block.width() > 1 && axis.size() == 0) { + auto axisVector = INPUT_VARIABLE(1); - auto meanShape = ShapeUtils::evalReduceShapeInfo('c', axis, *input, keepDims, false, block.workspace()); - auto varianceShape = ShapeUtils::evalReduceShapeInfo('c', axis, *input, keepDims, false, block.workspace()); - return SHAPELIST(meanShape, varianceShape); - } + for (int e = 0; e < axisVector->lengthOf(); e++) { + int ca = axisVector->e(e); + if (ca < 0) ca += input->rankOf(); - DECLARE_TYPES(moments) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } + axis.emplace_back(ca); } + } + // std::vector dims = ShapeUtils::evalDimsToExclude(input->rankOf(), + // {axis}); + const bool keepDims = block.numT() > 0 ? (bool)T_ARG(0) : false; + + auto meanShape = ShapeUtils::evalReduceShapeInfo('c', axis, *input, keepDims, + false, block.workspace()); + auto varianceShape = ShapeUtils::evalReduceShapeInfo( + 'c', axis, *input, keepDims, false, block.workspace()); + return SHAPELIST(meanShape, varianceShape); +} +DECLARE_TYPES(moments) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } +} // namespace ops + +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/linalg/polygamma.cpp b/libnd4j/include/ops/declarable/generic/linalg/polygamma.cpp index 35ffdcbc6e82..eccf5eafd99a 100644 --- a/libnd4j/include/ops/declarable/generic/linalg/polygamma.cpp +++ b/libnd4j/include/ops/declarable/generic/linalg/polygamma.cpp @@ -25,38 +25,42 @@ #include namespace sd { -namespace ops { +namespace ops { CONFIGURABLE_OP_IMPL(polygamma, 2, 1, false, 0, 0) { - auto n = INPUT_VARIABLE(0); - auto x = INPUT_VARIABLE(1); + auto n = INPUT_VARIABLE(0); + auto x = INPUT_VARIABLE(1); - auto output = OUTPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - REQUIRE_TRUE(n->isSameShape(x), 0, "POLYGAMMA op: two input arrays n and x must have the same shapes, but got n=%s and x=%s instead !", ShapeUtils::shapeAsString(n).c_str(), ShapeUtils::shapeAsString(x).c_str()); + REQUIRE_TRUE(n->isSameShape(x), 0, + "POLYGAMMA op: two input arrays n and x must have the same " + "shapes, but got n=%s and x=%s instead !", + ShapeUtils::shapeAsString(n).c_str(), + ShapeUtils::shapeAsString(x).c_str()); - Nd4jLong arrLen = n->lengthOf(); - // FIXME: this shit should be single op call, not a loop! - auto nNegative = n->reduceNumber(sd::reduce::IsNegative, nullptr); - auto xPositive = x->reduceNumber(sd::reduce::IsPositive, nullptr); - bool nPositiveFlag = !nNegative.e(0); // require all n >= 0 - bool xPositiveFlag = xPositive.e(0); // require all x > 0 - REQUIRE_TRUE(nPositiveFlag, 0, "POLYGAMMA op: all elements of n array must be >= 0 !"); - REQUIRE_TRUE(xPositiveFlag, 0, "POLYGAMMA op: all elements of x array must be > 0 !"); + Nd4jLong arrLen = n->lengthOf(); + // FIXME: this shit should be single op call, not a loop! + auto nNegative = n->reduceNumber(sd::reduce::IsNegative, nullptr); + auto xPositive = x->reduceNumber(sd::reduce::IsPositive, nullptr); + bool nPositiveFlag = !nNegative.e(0); // require all n >= 0 + bool xPositiveFlag = xPositive.e(0); // require all x > 0 + REQUIRE_TRUE(nPositiveFlag, 0, + "POLYGAMMA op: all elements of n array must be >= 0 !"); + REQUIRE_TRUE(xPositiveFlag, 0, + "POLYGAMMA op: all elements of x array must be > 0 !"); - helpers::polyGamma(block.launchContext(), *n, *x, *output); - return Status::OK(); + helpers::polyGamma(block.launchContext(), *n, *x, *output); + return Status::OK(); } DECLARE_SYN(polyGamma, polygamma); DECLARE_SYN(PolyGamma, polygamma); - DECLARE_TYPES(polygamma) { - getOpDescriptor() - ->setAllowedInputTypes({ALL_FLOATS}) - ->setSameMode(true); - } -} +DECLARE_TYPES(polygamma) { + getOpDescriptor()->setAllowedInputTypes({ALL_FLOATS})->setSameMode(true); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/linalg/qr.cpp b/libnd4j/include/ops/declarable/generic/linalg/qr.cpp index 1cdfc6884704..ca59bd365b4e 100644 --- a/libnd4j/include/ops/declarable/generic/linalg/qr.cpp +++ b/libnd4j/include/ops/declarable/generic/linalg/qr.cpp @@ -18,72 +18,84 @@ // Created by GS at 12/20/2019 // -#include #include #include +#include #if NOT_EXCLUDED(OP_qr) namespace sd { - namespace ops { - CUSTOM_OP_IMPL(qr, 1, 2, false, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto outputQ = OUTPUT_VARIABLE(0); - auto outputR = OUTPUT_VARIABLE(1); - auto fullMatricies = false; - if (block.getBArguments()->size()) - fullMatricies = B_ARG(0); - REQUIRE_TRUE(input->rankOf() >=2, 0, "qr: The rank of input array should not be less than 2, but %i is given", input->rankOf()); - REQUIRE_TRUE((fullMatricies && outputQ->sizeAt(-1) == input->sizeAt(-2)) || (!fullMatricies && outputQ->isSameShape(input)), 0, "qr: The last dimmensions should be equal to result Q, but %i and %i are given", outputQ->sizeAt(-1), input->sizeAt(-2)); - REQUIRE_TRUE((fullMatricies && outputR->sizeAt(-1) == input->sizeAt(-1)) || (!fullMatricies && outputR->sizeAt(-1) == outputR->sizeAt(-2)), 0, "qr: The last dimmensions should be equal to result R, but %i and %i are given", outputR->sizeAt(-1), input->sizeAt(-1)); - if (!input->isEmpty() && !outputQ->isEmpty() && !outputR->isEmpty()) - helpers::qr(block.launchContext(), input, outputQ, outputR, fullMatricies); +namespace ops { +CUSTOM_OP_IMPL(qr, 1, 2, false, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto outputQ = OUTPUT_VARIABLE(0); + auto outputR = OUTPUT_VARIABLE(1); + auto fullMatricies = false; + if (block.numB()) fullMatricies = B_ARG(0); - return Status::OK(); - } + REQUIRE_TRUE( + input->rankOf() >= 2, 0, + "qr: The rank of input array should not be less than 2, but %i is given", + input->rankOf()); + REQUIRE_TRUE((fullMatricies && outputQ->sizeAt(-1) == input->sizeAt(-2)) || + (!fullMatricies && outputQ->isSameShape(input)), + 0, + "qr: The last dimmensions should be equal to result Q, but %i " + "and %i are given", + outputQ->sizeAt(-1), input->sizeAt(-2)); + REQUIRE_TRUE( + (fullMatricies && outputR->sizeAt(-1) == input->sizeAt(-1)) || + (!fullMatricies && outputR->sizeAt(-1) == outputR->sizeAt(-2)), + 0, + "qr: The last dimmensions should be equal to result R, but %i and %i are " + "given", + outputR->sizeAt(-1), input->sizeAt(-1)); + if (!input->isEmpty() && !outputQ->isEmpty() && !outputR->isEmpty()) + helpers::qr(block.launchContext(), input, outputQ, outputR, fullMatricies); - DECLARE_SHAPE_FN(qr) { - auto inShape = inputShape->at(0); + return Status::OK(); +} - Nd4jLong const* shapeQ; - Nd4jLong const* shapeR; - int targetRank = shape::rank(inShape); // last two dimensions will be reduced to scalar +DECLARE_SHAPE_FN(qr) { + auto inShape = inputShape->at(0); - auto fullMatricies = false; - if (block.getBArguments()->size()) - fullMatricies = B_ARG(0); + Nd4jLong const* shapeQ; + Nd4jLong const* shapeR; + int targetRank = + shape::rank(inShape); // last two dimensions will be reduced to scalar - auto shape = ShapeUtils::shapeAsVector(inShape); + auto fullMatricies = false; + if (block.numB()) fullMatricies = B_ARG(0); - if (!fullMatricies) { // outputs are: Q is MxN and R is NxN - shape[targetRank - 1] = shape::sizeAt(inShape, -1); - shape[targetRank - 2] = shape[targetRank - 1]; - shapeQ = ConstantShapeHelper::getInstance().createShapeInfo(ArrayOptions::dataType(inShape), - shape::order(inShape), targetRank, - shape::shapeOf(inShape)); - shapeR = ConstantShapeHelper::getInstance().createShapeInfo(ArrayOptions::dataType(inShape), - shape::order(inShape), shape); + auto shape = ShapeUtils::shapeAsVector(inShape); - } - else {// otherwise outputs are Q is MxM and R is MxN with zero filled rows - shape[targetRank - 1] = shape::sizeAt(inShape, -2); - shape[targetRank - 2] = shape[targetRank - 1]; - shapeR = ConstantShapeHelper::getInstance().createShapeInfo(ArrayOptions::dataType(inShape), - shape::order(inShape), targetRank, - shape::shapeOf(inShape)); - shapeQ = ConstantShapeHelper::getInstance().createShapeInfo(ArrayOptions::dataType(inShape), - shape::order(inShape), shape); - } + if (!fullMatricies) { // outputs are: Q is MxN and R is NxN + shape[targetRank - 1] = shape::sizeAt(inShape, -1); + shape[targetRank - 2] = shape[targetRank - 1]; + shapeQ = ConstantShapeHelper::getInstance().createShapeInfo( + ArrayOptions::dataType(inShape), shape::order(inShape), targetRank, + shape::shapeOf(inShape)); + shapeR = ConstantShapeHelper::getInstance().createShapeInfo( + ArrayOptions::dataType(inShape), shape::order(inShape), shape); - return SHAPELIST(shapeQ, shapeR); + } else { // otherwise outputs are Q is MxM and R is MxN with zero filled rows + shape[targetRank - 1] = shape::sizeAt(inShape, -2); + shape[targetRank - 2] = shape[targetRank - 1]; + shapeR = ConstantShapeHelper::getInstance().createShapeInfo( + ArrayOptions::dataType(inShape), shape::order(inShape), targetRank, + shape::shapeOf(inShape)); + shapeQ = ConstantShapeHelper::getInstance().createShapeInfo( + ArrayOptions::dataType(inShape), shape::order(inShape), shape); + } - } + return SHAPELIST(shapeQ, shapeR); +} - DECLARE_TYPES(qr) { - getOpDescriptor() - ->setAllowedInputTypes({ALL_FLOATS}) - ->setAllowedOutputTypes({ALL_FLOATS}); - } - } +DECLARE_TYPES(qr) { + getOpDescriptor() + ->setAllowedInputTypes({ALL_FLOATS}) + ->setAllowedOutputTypes({ALL_FLOATS}); } +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/generic/linalg/solve.cpp b/libnd4j/include/ops/declarable/generic/linalg/solve.cpp index 154001684b14..90e77e621219 100644 --- a/libnd4j/include/ops/declarable/generic/linalg/solve.cpp +++ b/libnd4j/include/ops/declarable/generic/linalg/solve.cpp @@ -24,55 +24,69 @@ #include #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(solve, 2, 1, false, 0, 0) { - auto a = INPUT_VARIABLE(0); - auto b = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); +namespace ops { +CUSTOM_OP_IMPL(solve, 2, 1, false, 0, 0) { + auto a = INPUT_VARIABLE(0); + auto b = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); - bool useAdjoint = false; + bool useAdjoint = false; - if (block.numB() > 0) { - useAdjoint = B_ARG(0); - } - REQUIRE_TRUE(shape::shapeEquals(a->rankOf() - 2, a->shapeInfo(), b->rankOf() - 2, b->shapeInfo()), 0, "solve: Input shapes should be alike."); - REQUIRE_TRUE(a->rankOf() >=2, 0, "solve: The rank of input left tensor should not be less than 2, but %i is given", a->rankOf()); - REQUIRE_TRUE(b->rankOf() >=2, 0, "solve: The rank of input right tensor should not be less than 2, but %i is given", b->rankOf()); + if (block.numB() > 0) { + useAdjoint = B_ARG(0); + } + REQUIRE_TRUE(shape::shapeEquals(a->rankOf() - 2, a->shapeInfo(), + b->rankOf() - 2, b->shapeInfo()), + 0, "solve: Input shapes should be alike."); + REQUIRE_TRUE(a->rankOf() >= 2, 0, + "solve: The rank of input left tensor should not be less than " + "2, but %i is given", + a->rankOf()); + REQUIRE_TRUE(b->rankOf() >= 2, 0, + "solve: The rank of input right tensor should not be less than " + "2, but %i is given", + b->rankOf()); - REQUIRE_TRUE(a->sizeAt(-1) == a->sizeAt(-2), 0, "solve: The last two dimmensions should be equal, but %i and %i are given", a->sizeAt(-1), a->sizeAt(-2)); - REQUIRE_TRUE(a->sizeAt(-1) == b->sizeAt(-2), 0, "solve: The last dimmension of left part should be equal to prelast of right part, but %i and %i are given", a->sizeAt(-1), b->sizeAt(-2)); - if (a->isEmpty() || b->isEmpty() || z->isEmpty()) - return Status::OK(); + REQUIRE_TRUE(a->sizeAt(-1) == a->sizeAt(-2), 0, + "solve: The last two dimmensions should be equal, but %i and %i " + "are given", + a->sizeAt(-1), a->sizeAt(-2)); + REQUIRE_TRUE(a->sizeAt(-1) == b->sizeAt(-2), 0, + "solve: The last dimmension of left part should be equal to " + "prelast of right part, but %i and %i are given", + a->sizeAt(-1), b->sizeAt(-2)); + if (a->isEmpty() || b->isEmpty() || z->isEmpty()) return Status::OK(); - auto input = a; - if (useAdjoint) { - auto adjointA = a->ulike(); - helpers::adjointMatrix(block.launchContext(), a, &adjointA); - input = new NDArray(adjointA); //.detach(); - }; + auto input = a; + if (useAdjoint) { + auto adjointA = a->ulike(); + helpers::adjointMatrix(block.launchContext(), a, &adjointA); + input = new NDArray(adjointA); //.detach(); + }; - auto res = helpers::solveFunctor(block.launchContext(), input, b, useAdjoint, z); - if (input != a) - delete input; + auto res = + helpers::solveFunctor(block.launchContext(), input, b, useAdjoint, z); + if (input != a) delete input; - return Status::OK(); - } - - DECLARE_SHAPE_FN(solve) { - auto in0 = inputShape->at(1); - auto in1 = inputShape->at(1); - auto luShape = ShapeBuilders::copyShapeInfoAndType(in1, in0, true, block.workspace()); + return Status::OK(); +} + +DECLARE_SHAPE_FN(solve) { + auto in0 = inputShape->at(1); + auto in1 = inputShape->at(1); + auto luShape = + ShapeBuilders::copyShapeInfoAndType(in1, in0, true, block.workspace()); - return SHAPELIST(CONSTANT(luShape)); - } + return SHAPELIST(CONSTANT(luShape)); +} - DECLARE_TYPES(solve) { - getOpDescriptor() - ->setAllowedInputTypes({ALL_FLOATS}) - ->setAllowedOutputTypes({ALL_FLOATS}) - ->setSameMode(false); - } - } +DECLARE_TYPES(solve) { + getOpDescriptor() + ->setAllowedInputTypes({ALL_FLOATS}) + ->setAllowedOutputTypes({ALL_FLOATS}) + ->setSameMode(false); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/linalg/sufficient_statistics.cpp b/libnd4j/include/ops/declarable/generic/linalg/sufficient_statistics.cpp index 915ba5fb9642..046cc033a6d2 100644 --- a/libnd4j/include/ops/declarable/generic/linalg/sufficient_statistics.cpp +++ b/libnd4j/include/ops/declarable/generic/linalg/sufficient_statistics.cpp @@ -24,66 +24,68 @@ #include #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(sufficient_statistics, 2, 3, false, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto axisVector = INPUT_VARIABLE(1); - auto dataCount = OUTPUT_VARIABLE(0); - - auto sum = OUTPUT_VARIABLE(1); - auto squares = OUTPUT_VARIABLE(2); - - std::vector axis(axisVector->lengthOf());//*block.getIArguments(); - - // axis might be dynamic (i.e. tf mode) - helpers::adjustAxis(input->rankOf(), axisVector, axis); - - input->reduceAlongDimension(reduce::SquaredNorm, *squares, axis); - input->reduceAlongDimension(reduce::Sum, *sum, axis); - auto count = NDArrayFactory::create(input->dataType(), input->lengthOf() / sum->lengthOf()); - dataCount->assign(count); - if (block.numT() > 0) { - auto shift = OUTPUT_VARIABLE(3); - shift->assign(T_ARG(0)); - } - - return Status::OK(); - } - - DECLARE_TYPES(sufficient_statistics) { - getOpDescriptor() - ->setAllowedInputTypes(0, {ALL_INTS, ALL_FLOATS}); - getOpDescriptor() - ->setAllowedInputTypes(1, {DataType::INT32, DataType::INT64}); - getOpDescriptor() - ->setAllowedOutputTypes(0, DataType::INHERIT); - getOpDescriptor() - ->setAllowedOutputTypes(1, DataType::INHERIT); - getOpDescriptor() - ->setAllowedOutputTypes(2, DataType::INHERIT); - } - - DECLARE_SHAPE_FN(sufficient_statistics) { - auto axisVector = INPUT_VARIABLE(1); - std::vector axis(axisVector->lengthOf()); - - auto input = INPUT_VARIABLE(0); - helpers::adjustAxis(input->rankOf(), axisVector, axis); - - //std::vector dims = ShapeUtils::evalDimsToExclude(input->rankOf(), {axis}); - auto scalarShape = ConstantShapeHelper::getInstance().scalarShapeInfo(ArrayOptions::dataType(inputShape->at(0))); - auto sumShape = ShapeUtils::evalReduceShapeInfo('c', axis, *input, false, false, block.workspace()); - - auto squareShape = ShapeUtils::evalReduceShapeInfo('c', axis, *input, false, false, block.workspace()); - - auto shapeList = SHAPELIST(scalarShape, sumShape, squareShape); - if (block.numT() > 0) - shapeList->push_back(ConstantShapeHelper::getInstance().scalarShapeInfo(ArrayOptions::dataType(inputShape->at(0)))); - - return shapeList; - } - } +namespace ops { +CUSTOM_OP_IMPL(sufficient_statistics, 2, 3, false, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto axisVector = INPUT_VARIABLE(1); + auto dataCount = OUTPUT_VARIABLE(0); + + auto sum = OUTPUT_VARIABLE(1); + auto squares = OUTPUT_VARIABLE(2); + + std::vector axis(axisVector->lengthOf()); //*block.getIArguments(); + + // axis might be dynamic (i.e. tf mode) + helpers::adjustAxis(input->rankOf(), axisVector, axis); + + input->reduceAlongDimension(reduce::SquaredNorm, *squares, axis); + input->reduceAlongDimension(reduce::Sum, *sum, axis); + auto count = NDArrayFactory::create(input->dataType(), + input->lengthOf() / sum->lengthOf()); + dataCount->assign(count); + if (block.numT() > 0) { + auto shift = OUTPUT_VARIABLE(3); + shift->assign(T_ARG(0)); + } + + return Status::OK(); +} +DECLARE_TYPES(sufficient_statistics) { + getOpDescriptor()->setAllowedInputTypes(0, {ALL_INTS, ALL_FLOATS}); + getOpDescriptor()->setAllowedInputTypes(1, + {DataType::INT32, DataType::INT64}); + getOpDescriptor()->setAllowedOutputTypes(0, DataType::INHERIT); + getOpDescriptor()->setAllowedOutputTypes(1, DataType::INHERIT); + getOpDescriptor()->setAllowedOutputTypes(2, DataType::INHERIT); } +DECLARE_SHAPE_FN(sufficient_statistics) { + auto axisVector = INPUT_VARIABLE(1); + std::vector axis(axisVector->lengthOf()); + + auto input = INPUT_VARIABLE(0); + helpers::adjustAxis(input->rankOf(), axisVector, axis); + + // std::vector dims = ShapeUtils::evalDimsToExclude(input->rankOf(), + // {axis}); + auto scalarShape = ConstantShapeHelper::getInstance().scalarShapeInfo( + ArrayOptions::dataType(inputShape->at(0))); + auto sumShape = ShapeUtils::evalReduceShapeInfo('c', axis, *input, false, + false, block.workspace()); + + auto squareShape = ShapeUtils::evalReduceShapeInfo('c', axis, *input, false, + false, block.workspace()); + + auto shapeList = SHAPELIST(scalarShape, sumShape, squareShape); + if (block.numT() > 0) + shapeList->push_back(ConstantShapeHelper::getInstance().scalarShapeInfo( + ArrayOptions::dataType(inputShape->at(0)))); + + return shapeList; +} +} // namespace ops + +} // namespace sd + #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/linalg/svd.cpp b/libnd4j/include/ops/declarable/generic/linalg/svd.cpp index 3331dcdd8582..3c998fd6eb53 100644 --- a/libnd4j/include/ops/declarable/generic/linalg/svd.cpp +++ b/libnd4j/include/ops/declarable/generic/linalg/svd.cpp @@ -25,95 +25,106 @@ #include namespace sd { -namespace ops { +namespace ops { CUSTOM_OP_IMPL(svd, 1, 1, false, 0, 3) { - auto x = INPUT_VARIABLE(0); + auto x = INPUT_VARIABLE(0); - const int rank = x->rankOf(); - REQUIRE_TRUE(rank >= 2 , 0, "SVD OP: the rank of input array must be >=2, but got %i instead!", rank); + const int rank = x->rankOf(); + REQUIRE_TRUE( + rank >= 2, 0, + "SVD OP: the rank of input array must be >=2, but got %i instead!", rank); - bool fullUV = (bool)INT_ARG(0); - const bool calcUV = (bool)INT_ARG(1); + bool fullUV = (bool)INT_ARG(0); + const bool calcUV = (bool)INT_ARG(1); - if(calcUV == false) - fullUV = false; + if (calcUV == false) fullUV = false; - const int switchNum = INT_ARG(2); + const int switchNum = INT_ARG(2); - // #ifndef __CUDABLAS__ - helpers::svd(block.launchContext(), x, {OUTPUT_VARIABLE(0), calcUV ? OUTPUT_VARIABLE(1) : nullptr, calcUV ? OUTPUT_VARIABLE(2) : nullptr}, fullUV, calcUV, switchNum); - // #endif + // #ifndef __CUDABLAS__ + helpers::svd(block.launchContext(), x, + {OUTPUT_VARIABLE(0), calcUV ? OUTPUT_VARIABLE(1) : nullptr, + calcUV ? OUTPUT_VARIABLE(2) : nullptr}, + fullUV, calcUV, switchNum); + // #endif - return Status::OK();; + return Status::OK(); + ; } - - DECLARE_TYPES(svd) { - getOpDescriptor() - ->setAllowedInputTypes(0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}) - ->setSameMode(true); - } +DECLARE_TYPES(svd) { + getOpDescriptor() + ->setAllowedInputTypes( + 0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}) + ->setSameMode(true); +} DECLARE_SHAPE_FN(svd) { - - auto inShapeInfo = inputShape->at(0); - bool fullUV = (bool)INT_ARG(0); - bool calcUV = (bool)INT_ARG(1); - - const int rank = inShapeInfo[0]; - REQUIRE_TRUE(rank >= 2 , 0, "SVD OP: the rank of input array must be >=2, but got %i instead!", rank); - - const int diagSize = inShapeInfo[rank] < inShapeInfo[rank-1] ? inShapeInfo[rank] : inShapeInfo[rank-1]; - - Nd4jLong* sShapeInfo(nullptr); - if(rank == 2) { - ALLOCATE(sShapeInfo, block.getWorkspace(), shape::shapeInfoLength(1), Nd4jLong); - sShapeInfo[0] = 1; - sShapeInfo[1] = diagSize; - } - else { - ALLOCATE(sShapeInfo, block.getWorkspace(), shape::shapeInfoLength(rank-1), Nd4jLong); - sShapeInfo[0] = rank - 1; - for(int i=1; i <= rank-2; ++i) - sShapeInfo[i] = inShapeInfo[i]; - sShapeInfo[rank-1] = diagSize; + auto inShapeInfo = inputShape->at(0); + bool fullUV = (bool)INT_ARG(0); + bool calcUV = (bool)INT_ARG(1); + + const int rank = inShapeInfo[0]; + REQUIRE_TRUE( + rank >= 2, 0, + "SVD OP: the rank of input array must be >=2, but got %i instead!", rank); + + const int diagSize = inShapeInfo[rank] < inShapeInfo[rank - 1] + ? inShapeInfo[rank] + : inShapeInfo[rank - 1]; + + Nd4jLong* sShapeInfo(nullptr); + if (rank == 2) { + ALLOCATE(sShapeInfo, block.workspace(), shape::shapeInfoLength(1), + Nd4jLong); + sShapeInfo[0] = 1; + sShapeInfo[1] = diagSize; + } else { + ALLOCATE(sShapeInfo, block.workspace(), shape::shapeInfoLength(rank - 1), + Nd4jLong); + sShapeInfo[0] = rank - 1; + for (int i = 1; i <= rank - 2; ++i) sShapeInfo[i] = inShapeInfo[i]; + sShapeInfo[rank - 1] = diagSize; + } + + ShapeUtils::updateStridesAndType(sShapeInfo, inShapeInfo, + shape::order(inShapeInfo)); + + if (calcUV) { + Nd4jLong *uShapeInfo(nullptr), *vShapeInfo(nullptr); + COPY_SHAPE(inShapeInfo, uShapeInfo); + COPY_SHAPE(inShapeInfo, vShapeInfo); + + if (fullUV) { + uShapeInfo[rank] = uShapeInfo[rank - 1]; + vShapeInfo[rank - 1] = vShapeInfo[rank]; + } else { + uShapeInfo[rank] = diagSize; + vShapeInfo[rank - 1] = vShapeInfo[rank]; + vShapeInfo[rank] = diagSize; } - ShapeUtils::updateStridesAndType(sShapeInfo, inShapeInfo, shape::order(inShapeInfo)); - - if(calcUV){ - - Nd4jLong *uShapeInfo(nullptr), *vShapeInfo(nullptr); - COPY_SHAPE(inShapeInfo, uShapeInfo); - COPY_SHAPE(inShapeInfo, vShapeInfo); - - if(fullUV) { - uShapeInfo[rank] = uShapeInfo[rank-1]; - vShapeInfo[rank-1] = vShapeInfo[rank]; - } - else { - uShapeInfo[rank] = diagSize; - vShapeInfo[rank-1] = vShapeInfo[rank]; - vShapeInfo[rank] = diagSize; - } - - shape::updateStrides(uShapeInfo, shape::order(inShapeInfo)); - shape::updateStrides(vShapeInfo, shape::order(inShapeInfo)); - - auto result = SHAPELIST(ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(sShapeInfo)), ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(uShapeInfo)), ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(vShapeInfo))); - RELEASE(sShapeInfo, block.workspace()); - RELEASE(uShapeInfo, block.workspace()); - RELEASE(vShapeInfo, block.workspace()); - return result; - } - - return SHAPELIST(ConstantShapeHelper::getInstance().createFromExisting(sShapeInfo, block.workspace())); + shape::updateStrides(uShapeInfo, shape::order(inShapeInfo)); + shape::updateStrides(vShapeInfo, shape::order(inShapeInfo)); + + auto result = SHAPELIST(ConstantShapeHelper::getInstance().createShapeInfo( + ShapeDescriptor(sShapeInfo)), + ConstantShapeHelper::getInstance().createShapeInfo( + ShapeDescriptor(uShapeInfo)), + ConstantShapeHelper::getInstance().createShapeInfo( + ShapeDescriptor(vShapeInfo))); + RELEASE(sShapeInfo, block.workspace()); + RELEASE(uShapeInfo, block.workspace()); + RELEASE(vShapeInfo, block.workspace()); + return result; + } + + return SHAPELIST(ConstantShapeHelper::getInstance().createFromExisting( + sShapeInfo, block.workspace())); } - - -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/linalg/trace.cpp b/libnd4j/include/ops/declarable/generic/linalg/trace.cpp index 1a67ec7542da..6cb4c68a99f9 100644 --- a/libnd4j/include/ops/declarable/generic/linalg/trace.cpp +++ b/libnd4j/include/ops/declarable/generic/linalg/trace.cpp @@ -22,48 +22,55 @@ #if NOT_EXCLUDED(OP_trace) #include -#include +#include namespace sd { -namespace ops { +namespace ops { CUSTOM_OP_IMPL(trace, 1, 1, false, 0, 0) { - - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - REQUIRE_TRUE(input->rankOf() >= 2, 0, "TRACE op: the rank of input array must be >=2, but got %i instead!", input->rankOf()); + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - helpers::trace(block.launchContext(), *input, *output); + REQUIRE_TRUE( + input->rankOf() >= 2, 0, + "TRACE op: the rank of input array must be >=2, but got %i instead!", + input->rankOf()); - return Status::OK(); + helpers::trace(block.launchContext(), *input, *output); + + return Status::OK(); } - DECLARE_TYPES(trace) { - getOpDescriptor()->setAllowedInputTypes(0, {ALL_FLOATS}) - ->setAllowedOutputTypes(0, {ALL_FLOATS}); - } +DECLARE_TYPES(trace) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_FLOATS}) + ->setAllowedOutputTypes(0, {ALL_FLOATS}); +} DECLARE_SHAPE_FN(trace) { - auto inShapeInfo = inputShape->at(0); + auto inShapeInfo = inputShape->at(0); - REQUIRE_TRUE(inShapeInfo[0] >= 2, 0, "TRACE op: the rank of input array must be >=2, but got %i instead!", inShapeInfo[0]); - const int rank = inShapeInfo[0] - 2; + REQUIRE_TRUE( + inShapeInfo[0] >= 2, 0, + "TRACE op: the rank of input array must be >=2, but got %i instead!", + inShapeInfo[0]); + const int rank = inShapeInfo[0] - 2; - Nd4jLong* outShapeInfo(nullptr); - ALLOCATE(outShapeInfo, block.getWorkspace(), shape::shapeInfoLength(rank), Nd4jLong); + Nd4jLong* outShapeInfo(nullptr); + ALLOCATE(outShapeInfo, block.workspace(), shape::shapeInfoLength(rank), + Nd4jLong); - outShapeInfo[0] = rank; - for(int i=1; i <= rank; ++i) - outShapeInfo[i] = inShapeInfo[i]; + outShapeInfo[0] = rank; + for (int i = 1; i <= rank; ++i) outShapeInfo[i] = inShapeInfo[i]; - shape::updateStrides(outShapeInfo, shape::order(inShapeInfo)); - auto result = ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(outShapeInfo, ArrayOptions::dataType(inShapeInfo))); - RELEASE(outShapeInfo, block.getWorkspace()); - return SHAPELIST(result); + shape::updateStrides(outShapeInfo, shape::order(inShapeInfo)); + auto result = ConstantShapeHelper::getInstance().createShapeInfo( + ShapeDescriptor(outShapeInfo, ArrayOptions::dataType(inShapeInfo))); + RELEASE(outShapeInfo, block.workspace()); + return SHAPELIST(result); } -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/linalg/tri.cpp b/libnd4j/include/ops/declarable/generic/linalg/tri.cpp index d0c1f7a6f992..eec440d75762 100644 --- a/libnd4j/include/ops/declarable/generic/linalg/tri.cpp +++ b/libnd4j/include/ops/declarable/generic/linalg/tri.cpp @@ -20,44 +20,42 @@ #include - namespace sd { -namespace ops { - +namespace ops { ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(tri, -2, 1, false, 0, 1) { + auto output = OUTPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - const int diag = block.numI() > 2 ? INT_ARG(2) : 0; + const int diag = block.numI() > 2 ? INT_ARG(2) : 0; - BUILD_SINGLE_SELECTOR(output->dataType(), output->fillAsTriangular, (1., diag + 1, 0, *output, 'l'), LIBND4J_TYPES); // fill with unities lower triangular block of matrix - BUILD_SINGLE_SELECTOR(output->dataType(), output->fillAsTriangular, (0., 0, diag, *output, 'u'), LIBND4J_TYPES); // fill with zeros upper triangular block of matrix + BUILD_SINGLE_SELECTOR( + output->dataType(), output->fillAsTriangular, + (1., diag + 1, 0, *output, 'l'), + LIBND4J_TYPES); // fill with unities lower triangular block of matrix + BUILD_SINGLE_SELECTOR( + output->dataType(), output->fillAsTriangular, (0., 0, diag, *output, 'u'), + LIBND4J_TYPES); // fill with zeros upper triangular block of matrix - // output->setValueInDiagMatrix(1., diag, 'l'); - // output->setValueInDiagMatrix(0., diag+1, 'u'); + // output->setValueInDiagMatrix(1., diag, 'l'); + // output->setValueInDiagMatrix(0., diag+1, 'u'); - return Status::OK(); + return Status::OK(); } - DECLARE_TYPES(tri) { - getOpDescriptor() - ->setAllowedOutputTypes(0, {ALL_FLOATS, ALL_INTS}); - } - +DECLARE_TYPES(tri) { + getOpDescriptor()->setAllowedOutputTypes(0, {ALL_FLOATS, ALL_INTS}); +} DECLARE_SHAPE_FN(tri) { - const int rows = INT_ARG(0); - const int cols = block.numI() > 1 ? INT_ARG(1) : rows; + const int rows = INT_ARG(0); + const int cols = block.numI() > 1 ? INT_ARG(1) : rows; - auto dtype = block.numD() ? D_ARG(0) : DataType::FLOAT32; + auto dtype = block.numD() ? D_ARG(0) : DataType::FLOAT32; - return SHAPELIST(ConstantShapeHelper::getInstance().createShapeInfo(dtype, 'c', {rows, cols})); + return SHAPELIST(ConstantShapeHelper::getInstance().createShapeInfo( + dtype, 'c', {rows, cols})); } - - - -} -} \ No newline at end of file +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/linalg/triangular_solve.cpp b/libnd4j/include/ops/declarable/generic/linalg/triangular_solve.cpp index 49ec1e135d54..f064f38b4a85 100644 --- a/libnd4j/include/ops/declarable/generic/linalg/triangular_solve.cpp +++ b/libnd4j/include/ops/declarable/generic/linalg/triangular_solve.cpp @@ -24,59 +24,71 @@ #include #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(triangular_solve, 2, 1, false, 0, 0) { - auto a = INPUT_VARIABLE(0); - auto b = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); - bool isLower = true; - bool useAdjoint = false; +namespace ops { +CUSTOM_OP_IMPL(triangular_solve, 2, 1, false, 0, 0) { + auto a = INPUT_VARIABLE(0); + auto b = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); + bool isLower = true; + bool useAdjoint = false; - if (block.numB() > 0) { - if (block.numB() > 1) { - isLower = B_ARG(0); - useAdjoint = B_ARG(1); - } - else { - isLower = B_ARG(0); - } - } + if (block.numB() > 0) { + if (block.numB() > 1) { + isLower = B_ARG(0); + useAdjoint = B_ARG(1); + } else { + isLower = B_ARG(0); + } + } - REQUIRE_TRUE(a->rankOf() >=2, 0, "triangular_solve: The rank of input left tensor should not be less than 2, but %i is given", a->rankOf()); - REQUIRE_TRUE(b->rankOf() >=2, 0, "triangular_solve: The rank of input right tensor should not be less than 2, but %i is given", b->rankOf()); + REQUIRE_TRUE(a->rankOf() >= 2, 0, + "triangular_solve: The rank of input left tensor should not be " + "less than 2, but %i is given", + a->rankOf()); + REQUIRE_TRUE(b->rankOf() >= 2, 0, + "triangular_solve: The rank of input right tensor should not be " + "less than 2, but %i is given", + b->rankOf()); - REQUIRE_TRUE(a->sizeAt(-1) == a->sizeAt(-2), 0, "triangular_solve: The last two dimmensions should be equal, but %i and %i are given", a->sizeAt(-1), a->sizeAt(-2)); - REQUIRE_TRUE(a->sizeAt(-1) == b->sizeAt(-2), 0, "triangular_solve: The last dimmension of left part should be equal to prelast of right part, but %i and %i are given", a->sizeAt(-1), b->sizeAt(-2)); - auto input = a; - if (useAdjoint) { - auto adjointA = a->ulike(); - helpers::adjointMatrix(block.launchContext(), a, isLower, &adjointA); - input = new NDArray(adjointA); //.detach(); - isLower = !isLower; - }; + REQUIRE_TRUE(a->sizeAt(-1) == a->sizeAt(-2), 0, + "triangular_solve: The last two dimmensions should be equal, " + "but %i and %i are given", + a->sizeAt(-1), a->sizeAt(-2)); + REQUIRE_TRUE(a->sizeAt(-1) == b->sizeAt(-2), 0, + "triangular_solve: The last dimmension of left part should be " + "equal to prelast of right part, but %i and %i are given", + a->sizeAt(-1), b->sizeAt(-2)); + auto input = a; + if (useAdjoint) { + auto adjointA = a->ulike(); + helpers::adjointMatrix(block.launchContext(), a, isLower, &adjointA); + input = new NDArray(adjointA); //.detach(); + isLower = !isLower; + }; - auto res = helpers::triangularSolveFunctor(block.launchContext(), input, b, isLower, false, z); - if (input != a) - delete input; + auto res = helpers::triangularSolveFunctor(block.launchContext(), input, b, + isLower, false, z); + if (input != a) delete input; - return Status::OK(); - } + return Status::OK(); +} - DECLARE_SHAPE_FN(triangular_solve) { - auto in0 = inputShape->at(1); - auto in1 = inputShape->at(1); - auto luShape = ShapeBuilders::copyShapeInfoAndType(in1, in0, true, block.workspace()); +DECLARE_SHAPE_FN(triangular_solve) { + auto in0 = inputShape->at(1); + auto in1 = inputShape->at(1); + auto luShape = + ShapeBuilders::copyShapeInfoAndType(in1, in0, true, block.workspace()); - return SHAPELIST(CONSTANT(luShape)); - } + return SHAPELIST(CONSTANT(luShape)); +} - DECLARE_TYPES(triangular_solve) { - getOpDescriptor() - ->setAllowedInputTypes({ALL_FLOATS}) - ->setAllowedOutputTypes({ALL_FLOATS}) - ->setSameMode(false); - } - } +DECLARE_TYPES(triangular_solve) { + getOpDescriptor() + ->setAllowedInputTypes({ALL_FLOATS}) + ->setAllowedOutputTypes({ALL_FLOATS}) + ->setSameMode(false); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/linalg/triu.cpp b/libnd4j/include/ops/declarable/generic/linalg/triu.cpp index 839828f62a77..0347ef8b8c52 100644 --- a/libnd4j/include/ops/declarable/generic/linalg/triu.cpp +++ b/libnd4j/include/ops/declarable/generic/linalg/triu.cpp @@ -19,96 +19,106 @@ // #include -#include - +#include namespace sd { -namespace ops { - +namespace ops { ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(triu, 1, 1, false, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - REQUIRE_TRUE(input->rankOf() > 0, 0, "TRIU OP: the rank of input array must be > 0, but got %i instead !", input->rankOf()); + REQUIRE_TRUE( + input->rankOf() > 0, 0, + "TRIU OP: the rank of input array must be > 0, but got %i instead !", + input->rankOf()); - const int diag = block.getIArguments()->size() > 0 ? INT_ARG(0) : 0; + const int diag = block.numI() > 0 ? INT_ARG(0) : 0; - BUILD_SINGLE_SELECTOR(input->dataType(), input->fillAsTriangular, (0, diag, 0, *output, 'l' ), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR(input->dataType(), input->fillAsTriangular, + (0, diag, 0, *output, 'l'), LIBND4J_TYPES); - return Status::OK(); + return Status::OK(); } - DECLARE_TYPES(triu) { - getOpDescriptor()->setAllowedInputTypes(0, {ALL_FLOATS}) - ->setAllowedOutputTypes(0, {ALL_FLOATS}); - } +DECLARE_TYPES(triu) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_FLOATS}) + ->setAllowedOutputTypes(0, {ALL_FLOATS}); +} DECLARE_SHAPE_FN(triu) { + auto inShapeInfo = inputShape->at(0); - auto inShapeInfo = inputShape->at(0); - - REQUIRE_TRUE(inShapeInfo[0] > 0, 0, "TRIU OP: the rank of input array must be > 0, but got %i instead !", inShapeInfo[0]); + REQUIRE_TRUE( + inShapeInfo[0] > 0, 0, + "TRIU OP: the rank of input array must be > 0, but got %i instead !", + inShapeInfo[0]); - int rank = (inShapeInfo[0] == 1) ? 2 : inShapeInfo[0]; + int rank = (inShapeInfo[0] == 1) ? 2 : inShapeInfo[0]; - Nd4jLong *outShapeInfo = nullptr; - ALLOCATE(outShapeInfo, block.getWorkspace(), shape::shapeInfoLength(rank), Nd4jLong); - memcpy(outShapeInfo, inShapeInfo, (1 + rank) * sizeof(Nd4jLong)); // copy rank and dimensions values only + Nd4jLong* outShapeInfo = nullptr; + ALLOCATE(outShapeInfo, block.workspace(), shape::shapeInfoLength(rank), + Nd4jLong); + memcpy( + outShapeInfo, inShapeInfo, + (1 + rank) * sizeof(Nd4jLong)); // copy rank and dimensions values only - if(inShapeInfo[0] == 1) { - outShapeInfo[0] = rank; - outShapeInfo[1] = inShapeInfo[1]; - outShapeInfo[2] = inShapeInfo[1]; - } + if (inShapeInfo[0] == 1) { + outShapeInfo[0] = rank; + outShapeInfo[1] = inShapeInfo[1]; + outShapeInfo[2] = inShapeInfo[1]; + } - ShapeUtils::updateStridesAndType(outShapeInfo, inShapeInfo, shape::order(inShapeInfo)); + ShapeUtils::updateStridesAndType(outShapeInfo, inShapeInfo, + shape::order(inShapeInfo)); - return SHAPELIST(CONSTANT(outShapeInfo)); + return SHAPELIST(CONSTANT(outShapeInfo)); } - - ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(triu_bp, 2, 1, false, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto gradO = INPUT_VARIABLE(1); // dLoss/dO - auto input = INPUT_VARIABLE(0); - auto gradO = INPUT_VARIABLE(1); // dLoss/dO - - auto gradI = OUTPUT_VARIABLE(0); // dLoss/dI + auto gradI = OUTPUT_VARIABLE(0); // dLoss/dI - REQUIRE_TRUE(input->rankOf() > 0, 0, "TRIU_BP OP: the rank of input array must be > 0, but got %i instead !", input->rankOf()); + REQUIRE_TRUE( + input->rankOf() > 0, 0, + "TRIU_BP OP: the rank of input array must be > 0, but got %i instead !", + input->rankOf()); - const int diag = block.getIArguments()->size() > 0 ? INT_ARG(0) : 0; + const int diag = block.numI() > 0 ? INT_ARG(0) : 0; - helpers::triuBP(block.launchContext(), *input, *gradO, *gradI, diag); + helpers::triuBP(block.launchContext(), *input, *gradO, *gradI, diag); - return Status::OK(); + return Status::OK(); } - DECLARE_TYPES(triu_bp) { - getOpDescriptor()->setAllowedInputTypes(0, {ALL_FLOATS}) - ->setAllowedInputTypes(1, {ALL_FLOATS}) - ->setAllowedOutputTypes(0, {ALL_FLOATS}); - } +DECLARE_TYPES(triu_bp) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_FLOATS}) + ->setAllowedInputTypes(1, {ALL_FLOATS}) + ->setAllowedOutputTypes(0, {ALL_FLOATS}); +} DECLARE_SHAPE_FN(triu_bp) { + auto gradOShapeInfo = inputShape->at(1); + int rank = gradOShapeInfo[0]; - auto gradOShapeInfo = inputShape->at(1); - int rank = gradOShapeInfo[0]; + Nd4jLong* outShapeInfo = nullptr; + ALLOCATE(outShapeInfo, block.workspace(), shape::shapeInfoLength(rank), + Nd4jLong); + memcpy( + outShapeInfo, gradOShapeInfo, + (1 + rank) * sizeof(Nd4jLong)); // copy rank and dimensions values only - Nd4jLong* outShapeInfo = nullptr; - ALLOCATE(outShapeInfo, block.getWorkspace(), shape::shapeInfoLength(rank), Nd4jLong); - memcpy(outShapeInfo, gradOShapeInfo, (1 + rank) * sizeof(Nd4jLong)); // copy rank and dimensions values only + auto in = inputShape->at(0); + ShapeUtils::updateStridesAndType(outShapeInfo, in, shape::order(in)); - auto in = inputShape->at(0); - ShapeUtils::updateStridesAndType(outShapeInfo, in, shape::order(in)); - - return SHAPELIST(CONSTANT(outShapeInfo)); + return SHAPELIST(CONSTANT(outShapeInfo)); } - -} -} \ No newline at end of file +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/linalg/zeta.cpp b/libnd4j/include/ops/declarable/generic/linalg/zeta.cpp index 6aba1fc5f4d1..4a53faeb48d7 100644 --- a/libnd4j/include/ops/declarable/generic/linalg/zeta.cpp +++ b/libnd4j/include/ops/declarable/generic/linalg/zeta.cpp @@ -25,35 +25,41 @@ #include namespace sd { - namespace ops { - CONFIGURABLE_OP_IMPL(zeta, 2, 1, false, 0, 0) { - auto x = INPUT_VARIABLE(0); - auto q = INPUT_VARIABLE(1); +namespace ops { +CONFIGURABLE_OP_IMPL(zeta, 2, 1, false, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto q = INPUT_VARIABLE(1); - auto output = OUTPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - REQUIRE_TRUE(x->isSameShape(q), 0, "ZETA op: two input arrays must have the same shapes, bot got x=%s and q=%s !", ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(q).c_str()); + REQUIRE_TRUE(x->isSameShape(q), 0, + "ZETA op: two input arrays must have the same shapes, bot got " + "x=%s and q=%s !", + ShapeUtils::shapeAsString(x).c_str(), + ShapeUtils::shapeAsString(q).c_str()); - Nd4jLong arrLen = x->lengthOf(); + Nd4jLong arrLen = x->lengthOf(); - // FIXME: this should NOT be loop. - for(Nd4jLong i = 0; i < arrLen; ++i ) { - REQUIRE_TRUE(x->e(i) > 1.f, 0, "ZETA op: all elements of x array must be > 1 !"); - REQUIRE_TRUE(q->e(i) > 0.f, 0, "ZETA op: all elements of q array must be > 0 !"); - } + // FIXME: this should NOT be loop. + for (Nd4jLong i = 0; i < arrLen; ++i) { + REQUIRE_TRUE(x->e(i) > 1.f, 0, + "ZETA op: all elements of x array must be > 1 !"); + REQUIRE_TRUE(q->e(i) > 0.f, 0, + "ZETA op: all elements of q array must be > 0 !"); + } - helpers::zeta(block.launchContext(), *x, *q, *output); + helpers::zeta(block.launchContext(), *x, *q, *output); - return Status::OK(); - } - DECLARE_SYN(Zeta, zeta); + return Status::OK(); +} +DECLARE_SYN(Zeta, zeta); - DECLARE_TYPES(zeta) { - getOpDescriptor() - ->setAllowedInputTypes({ALL_FLOATS}) - ->setAllowedOutputTypes({ALL_FLOATS}); - } - } +DECLARE_TYPES(zeta) { + getOpDescriptor() + ->setAllowedInputTypes({ALL_FLOATS}) + ->setAllowedOutputTypes({ALL_FLOATS}); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/list/clone_list.cpp b/libnd4j/include/ops/declarable/generic/list/clone_list.cpp index d100153ec421..468a4566533b 100644 --- a/libnd4j/include/ops/declarable/generic/list/clone_list.cpp +++ b/libnd4j/include/ops/declarable/generic/list/clone_list.cpp @@ -24,19 +24,19 @@ #include namespace sd { - namespace ops { - LIST_OP_IMPL(clone_list, 1, 1, 0, 0) { - auto list = INPUT_LIST(0); +namespace ops { +LIST_OP_IMPL(clone_list, 1, 1, 0, 0) { + auto list = INPUT_LIST(0); - auto newList = list->clone(); + auto newList = list->clone(); - //OVERWRITE_RESULT(newList); - setupResultList(newList, block); - return ND4J_STATUS_OK; - } - DECLARE_SYN(TensorArrayIdentityV3, clone_list); - DECLARE_SYN(tensorarrayidentityv3, clone_list); - } + // OVERWRITE_RESULT(newList); + setupResultList(newList, block); + return ND4J_STATUS_OK; } +DECLARE_SYN(TensorArrayIdentityV3, clone_list); +DECLARE_SYN(tensorarrayidentityv3, clone_list); +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/list/create_list.cpp b/libnd4j/include/ops/declarable/generic/list/create_list.cpp index 606558e7edab..d362df25495e 100644 --- a/libnd4j/include/ops/declarable/generic/list/create_list.cpp +++ b/libnd4j/include/ops/declarable/generic/list/create_list.cpp @@ -24,40 +24,40 @@ #include namespace sd { - namespace ops { - LIST_OP_IMPL(create_list, 1, 2, 0, -2) { - int height = 0; - bool expandable = false; - if (block.numI() == 2) { - height = INT_ARG(0); - expandable = (bool) INT_ARG(1); - } else if (block.numI() == 1) { - height = INT_ARG(0); - } else if (block.width() == 1) { - height = INPUT_VARIABLE(0)->e(0); - expandable = true; - } else { - height = 0; - expandable = true; - } - - auto list = new NDArrayList(height, expandable); - - // we recieve input array for graph integrity purposes only - auto input = INPUT_VARIABLE(0); - setupResultList(list, block); -// OVERWRITE_RESULT(list); - - auto scalar = NDArrayFactory::create_(list->counter()); - block.pushNDArrayToVariableSpace(block.getNodeId(), 1, scalar); - - return ND4J_STATUS_OK; - } - DECLARE_SYN(TensorArrayV3, create_list); - DECLARE_SYN(tensorarrayv3, create_list); - DECLARE_SYN(TensorArrayCreateV3, create_list); - DECLARE_SYN(tensorarraycreatev3, create_list); - } +namespace ops { +LIST_OP_IMPL(create_list, 1, 2, 0, -2) { + int height = 0; + bool expandable = false; + if (block.numI() == 2) { + height = INT_ARG(0); + expandable = (bool)INT_ARG(1); + } else if (block.numI() == 1) { + height = INT_ARG(0); + } else if (block.width() == 1) { + height = INPUT_VARIABLE(0)->e(0); + expandable = true; + } else { + height = 0; + expandable = true; + } + + NDArrayList list(height, expandable); + + // we recieve input array for graph integrity purposes only + auto input = INPUT_VARIABLE(0); + setupResultList(list, block); + + auto scalar = NDArrayFactory::create(list.counter()); + block.pushNDArrayToVariableSpace(block.nodeId(), 1, scalar); + + return Status::OK(); } +DECLARE_SYN(TensorArrayV3, create_list); +DECLARE_SYN(tensorarrayv3, create_list); +DECLARE_SYN(TensorArrayCreateV3, create_list); +DECLARE_SYN(tensorarraycreatev3, create_list); +} // namespace ops +} // namespace sd + #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/list/gather_list.cpp b/libnd4j/include/ops/declarable/generic/list/gather_list.cpp index 943313ad0242..2d1837ed8a88 100644 --- a/libnd4j/include/ops/declarable/generic/list/gather_list.cpp +++ b/libnd4j/include/ops/declarable/generic/list/gather_list.cpp @@ -24,53 +24,59 @@ #include namespace sd { - namespace ops { - LIST_OP_IMPL(gather_list, 2, 1, 0, -2) { - auto list = INPUT_LIST(0); - auto indices = INPUT_VARIABLE(1); +namespace ops { +LIST_OP_IMPL(gather_list, 2, 1, 0, -2) { + auto list = INPUT_LIST(0); + auto indices = INPUT_VARIABLE(1); - indices->printShapeInfo("indices shape"); - indices->printIndexedBuffer("indices"); + indices->printShapeInfo("indices shape"); + indices->printIndexedBuffer("indices"); - REQUIRE_TRUE(indices->isVector() || indices->rankOf() == 1, 0, "Indices for Gather operation should be a vector"); - REQUIRE_TRUE(list->height() > 0, 0, "Number of elements in list should be positive prior to Gather call"); - REQUIRE_TRUE(list->height() == indices->lengthOf(), 1, "Number of indicies should be equal to number of elements in list, but got [%i] indices instead", indices->lengthOf()); + REQUIRE_TRUE(indices->isVector() || indices->rankOf() == 1, 0, + "Indices for Gather operation should be a vector"); + REQUIRE_TRUE( + list->height() > 0, 0, + "Number of elements in list should be positive prior to Gather call"); + REQUIRE_TRUE(list->height() == indices->lengthOf(), 1, + "Number of indicies should be equal to number of elements in " + "list, but got [%i] indices instead", + indices->lengthOf()); - // first of all we need to get shapes - std::vector shape({0}); - shape[0] = indices->lengthOf(); - for (int e = 0; e < list->height(); e++) { - auto array = list->readRaw(e); + // first of all we need to get shapes + std::vector shape({0}); + shape[0] = indices->lengthOf(); + for (int e = 0; e < list->height(); e++) { + auto array = list->readRaw(e); - // now we should fill other dimensions - if (e == 0) { - for (int d = 0; d < array->rankOf(); d++) - shape.emplace_back(array->sizeAt(d)); - } - } + // now we should fill other dimensions + if (e == 0) { + for (int d = 0; d < array.rankOf(); d++) + shape.emplace_back(array.sizeAt(d)); + } + } - auto result = NDArrayFactory::create_('c', shape, list->dataType()); - std::vector indicesList((list->readRaw(0)->rankOf() + 1) * 2, 0); - int skipPosition = 0; - for (int e = 0; e < indices->lengthOf(); e++) { - auto idx = indices->e(e); - auto array = list->readRaw(idx); - - // first dimension - indicesList[0] = skipPosition; - indicesList[1] = skipPosition++ + 1; + auto result = NDArrayFactory::create('c', shape, list->dataType()); + std::vector indicesList((list->readRaw(0).rankOf() + 1) * 2, 0); + int skipPosition = 0; + for (int e = 0; e < indices->lengthOf(); e++) { + auto idx = indices->e(e); + auto array = list->readRaw(idx); - auto subarray = (*result)(indicesList, true); - subarray.assign(array); - } + // first dimension + indicesList[0] = skipPosition; + indicesList[1] = skipPosition++ + 1; - //OVERWRITE_RESULT(result); - setupResult(result, block); - return Status::OK(); - } - DECLARE_SYN(TensorArrayGatherV3, gather_list); - DECLARE_SYN(tensorarraygatherv3, gather_list); - } + auto subarray = (result)(indicesList, true); + subarray.assign(array); + } + + // OVERWRITE_RESULT(result); + setupResult(result, block); + return Status::OK(); } +DECLARE_SYN(TensorArrayGatherV3, gather_list); +DECLARE_SYN(tensorarraygatherv3, gather_list); +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/list/pick_list.cpp b/libnd4j/include/ops/declarable/generic/list/pick_list.cpp index 1254456bda4e..01eb6eb753d8 100644 --- a/libnd4j/include/ops/declarable/generic/list/pick_list.cpp +++ b/libnd4j/include/ops/declarable/generic/list/pick_list.cpp @@ -24,33 +24,36 @@ #include namespace sd { - namespace ops { - LIST_OP_IMPL(pick_list, 1, 1, 0, -2) { - auto list = INPUT_LIST(0); - - std::vector indices; - if (block.width() > 1 && block.getVariable(1)->getNDArray()->isVector()) { - auto ia = INPUT_VARIABLE(1); - for (int e = 0; e < ia->lengthOf(); e++) - indices.emplace_back(ia->e(e)); - } else if (block.getIArguments()->size() > 0) { - indices = *(block.getIArguments()); - } else return ND4J_STATUS_BAD_ARGUMENTS; - - for (auto& v: indices) { - if (v >= list->height()) { - nd4j_printf("Requested index [%i] is higher (or equal) then ArrayList height: [%i]", v, - list->height()); - return ND4J_STATUS_BAD_ARGUMENTS; - } - } - auto result = list->pick(indices); - -// OVERWRITE_RESULT(result); - setupResult(result, block); - return Status::OK(); - } +namespace ops { +LIST_OP_IMPL(pick_list, 1, 1, 0, -2) { + auto list = INPUT_LIST(0); + + std::vector indices; + if (block.width() > 1 && block.getVariable(1)->getNDArray()->isVector()) { + auto ia = INPUT_VARIABLE(1); + for (int e = 0; e < ia->lengthOf(); e++) + indices.emplace_back(ia->e(e)); + } else if (block.numI() > 0) { + indices = block.getIArguments(); + } else + return ND4J_STATUS_BAD_ARGUMENTS; + + for (auto& v : indices) { + if (v >= list->height()) { + nd4j_printf( + "Requested index [%i] is higher (or equal) then ArrayList height: " + "[%i]", + v, list->height()); + return ND4J_STATUS_BAD_ARGUMENTS; } + } + auto result = list->pick(indices); + + // OVERWRITE_RESULT(result); + setupResult(result, block); + return Status::OK(); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/list/read_list.cpp b/libnd4j/include/ops/declarable/generic/list/read_list.cpp index a1320b9b3767..5ab63a124043 100644 --- a/libnd4j/include/ops/declarable/generic/list/read_list.cpp +++ b/libnd4j/include/ops/declarable/generic/list/read_list.cpp @@ -24,40 +24,47 @@ #include namespace sd { - namespace ops { - LIST_OP_IMPL(read_list, 1, 1, 0, 0) { - auto list = INPUT_LIST(0); - NDArray *result = nullptr; +namespace ops { +LIST_OP_IMPL(read_list, 1, 1, 0, 0) { + auto list = INPUT_LIST(0); + NDArray result; - REQUIRE_TRUE(list->height() > 0, 0, "ReadList: number of elements in list should be positive prior to Read call"); + REQUIRE_TRUE(list->height() > 0, 0, + "ReadList: number of elements in list should be positive prior " + "to Read call"); - if (block.numI() > 0) { - auto index = INT_ARG(0); + if (block.numI() > 0) { + auto index = INT_ARG(0); - REQUIRE_TRUE(list->isWritten(index), 0, "ReadList: requested index [%i] wasn't written yet", index); + REQUIRE_TRUE(list->isWritten(index), 0, + "ReadList: requested index [%i] wasn't written yet", index); - result = list->read(index); - } else if (block.width() > 0) { - auto vec = INPUT_VARIABLE(1); + result = list->read(index); + } else if (block.width() > 0) { + auto vec = INPUT_VARIABLE(1); - REQUIRE_TRUE(vec->isScalar(), 0, "ReadList: index operand should be a scalar"); - - auto index = vec->e(0); + REQUIRE_TRUE(vec->isScalar(), 0, + "ReadList: index operand should be a scalar"); - REQUIRE_TRUE(list->isWritten(index), 0, "ReadList: requested index [%i] wasn't written yet", index); + auto index = vec->e(0); - result = list->read(index); - } else { - REQUIRE_TRUE(false, 0, "ReadList: index value should be set either via IntArgs or via second operand"); - } + REQUIRE_TRUE(list->isWritten(index), 0, + "ReadList: requested index [%i] wasn't written yet", index); -// OVERWRITE_RESULT(result); - setupResult(result, block); - return Status::OK(); - } - DECLARE_SYN(TensorArrayReadV3, read_list); - DECLARE_SYN(tensorarrayreadv3, read_list); - } + result = list->read(index); + } else { + REQUIRE_TRUE(false, 0, + "ReadList: index value should be set either via IntArgs or " + "via second operand"); + } + + // OVERWRITE_RESULT(result); + setupResult(result, block); + return Status::OK(); } +DECLARE_SYN(TensorArrayReadV3, read_list); +DECLARE_SYN(tensorarrayreadv3, read_list); +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/list/scatter_list.cpp b/libnd4j/include/ops/declarable/generic/list/scatter_list.cpp index 38a4da7bd9ce..c9c387d8fee7 100644 --- a/libnd4j/include/ops/declarable/generic/list/scatter_list.cpp +++ b/libnd4j/include/ops/declarable/generic/list/scatter_list.cpp @@ -21,56 +21,58 @@ #include #if NOT_EXCLUDED(OP_scatter_list) -#include #include +#include namespace sd { - namespace ops { - LIST_OP_IMPL(scatter_list, 1, 1, 0, -2) { - NDArrayList *list = nullptr; - NDArray *array = nullptr; - NDArray *indices = nullptr; +namespace ops { +LIST_OP_IMPL(scatter_list, 1, 1, 0, -2) { + NDArrayList *list = nullptr; + NDArray *array = nullptr; + NDArray *indices = nullptr; + + bool hasList = false; + auto w = block.width(); + + if (w >= 3) { + list = INPUT_LIST(0); + indices = INPUT_VARIABLE(1); + array = INPUT_VARIABLE(2); + hasList = true; + } else { + array = INPUT_VARIABLE(1); + indices = INPUT_VARIABLE(2); + list = new NDArrayList(indices->lengthOf(), false); - bool hasList = false; - auto w = block.width(); + throw std::runtime_error("scatter_list - Not implemented yet"); + } - if (w >= 3){ - list = INPUT_LIST(0); - indices = INPUT_VARIABLE(1); - array = INPUT_VARIABLE(2); - hasList = true; - } else { - array = INPUT_VARIABLE(1); - indices = INPUT_VARIABLE(2); - list = new NDArrayList(indices->lengthOf(), false); - block.trackList(list); - } + REQUIRE_TRUE(indices->isVector() || indices->rankOf() == 1, 0, + "ScatterList: Indices for Scatter should be a vector") + REQUIRE_TRUE(indices->lengthOf() == array->sizeAt(0), 0, + "ScatterList: Indices length should be equal number of TADs " + "along dim0, but got %i instead", + indices->lengthOf()); - REQUIRE_TRUE(indices->isVector() || indices->rankOf() == 1, 0, "ScatterList: Indices for Scatter should be a vector") - REQUIRE_TRUE(indices->lengthOf() == array->sizeAt(0), 0, "ScatterList: Indices length should be equal number of TADs along dim0, but got %i instead", indices->lengthOf()); + std::vector axis = ShapeUtils::evalDimsToExclude(array->rankOf(), {0}); + auto tads = array->allTensorsAlongDimension(axis); + for (int e = 0; e < tads.size(); e++) { + auto idx = indices->e(e); + if (idx >= tads.size()) return ND4J_STATUS_BAD_ARGUMENTS; - std::vector axis = ShapeUtils::evalDimsToExclude(array->rankOf(), {0}); - auto tads = array->allTensorsAlongDimension( axis); - for (int e = 0; e < tads.size(); e++) { - auto idx = indices->e(e); - if (idx >= tads.size()) - return ND4J_STATUS_BAD_ARGUMENTS; + auto arr = tads.at(e).dup(array->ordering()); + auto res = list->write(idx, arr); - auto arr = new NDArray(tads.at(e)->dup(array->ordering())); - auto res = list->write(idx, arr); - if (res != ND4J_STATUS_OK) - return res; - } + if (res != Status::OK()) return res; + } - if (!hasList) - //OVERWRITE_RESULT(list); - setupResultList(list, block); + if (!hasList) setupResultList(*list, block); - return Status::OK(); - } - DECLARE_SYN(TensorArrayScatterV3, scatter_list); - DECLARE_SYN(tensorarrayscatterv3, scatter_list); - } + return Status::OK(); } +DECLARE_SYN(TensorArrayScatterV3, scatter_list); +DECLARE_SYN(tensorarrayscatterv3, scatter_list); +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/list/size_list.cpp b/libnd4j/include/ops/declarable/generic/list/size_list.cpp index 9c4d7ff70ce5..a1dd7739f5fd 100644 --- a/libnd4j/include/ops/declarable/generic/list/size_list.cpp +++ b/libnd4j/include/ops/declarable/generic/list/size_list.cpp @@ -24,25 +24,26 @@ #include namespace sd { - namespace ops { - LIST_OP_IMPL(size_list, 1, 1, 0, 0) { - auto list = INPUT_LIST(0); +namespace ops { +LIST_OP_IMPL(size_list, 1, 1, 0, 0) { + auto list = INPUT_LIST(0); - auto result = NDArrayFactory::create_(list->height(), block.launchContext()); + auto result = + NDArrayFactory::create(list->height(), block.launchContext()); - //nd4j_printf("List size: [%i]\n", list->height()); - result->printIndexedBuffer("actual height"); + // nd4j_printf("List size: [%i]\n", list->height()); + result.printIndexedBuffer("actual height"); - //nd4j_printf("List size: [%i]\n", list->height()); - result->printIndexedBuffer("actual height"); + // nd4j_printf("List size: [%i]\n", list->height()); + result.printIndexedBuffer("actual height"); - //OVERWRITE_RESULT(result); - setupResult(result, block); - return Status::OK(); - } - DECLARE_SYN(TensorArraySizeV3, size_list); - DECLARE_SYN(tensorarraysizev3, size_list); - } + // OVERWRITE_RESULT(result); + setupResult(result, block); + return Status::OK(); } +DECLARE_SYN(TensorArraySizeV3, size_list); +DECLARE_SYN(tensorarraysizev3, size_list); +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/list/split_list.cpp b/libnd4j/include/ops/declarable/generic/list/split_list.cpp index c490479617c9..5cd69324e019 100644 --- a/libnd4j/include/ops/declarable/generic/list/split_list.cpp +++ b/libnd4j/include/ops/declarable/generic/list/split_list.cpp @@ -21,67 +21,75 @@ #include #if NOT_EXCLUDED(OP_split_list) -#include #include +#include namespace sd { - namespace ops { - LIST_OP_IMPL(split_list, 2, 1, 0, -2) { - NDArrayList *list = nullptr; - NDArray *array = nullptr; - NDArray *sizes = nullptr; - - bool hasList = false; - - if (block.width() >= 3){ - list = INPUT_LIST(0); - array = INPUT_VARIABLE(1); - sizes = INPUT_VARIABLE(2); - hasList = true; - } else { - array = INPUT_VARIABLE(0); - sizes = INPUT_VARIABLE(1); - list = new NDArrayList(sizes->lengthOf(), false); - block.trackList(list); - } - - REQUIRE_TRUE(sizes->isZ(), 0, "split_list: sizes array must have one of integer types"); - REQUIRE_TRUE(sizes->rankOf() == 1, 0, "split_list: sizes array must be 1D") - - list->shape() = array->getShapeAsVector(); - - // now let's build subarrays - int cnt = 0; - std::vector indices(2 * array->rankOf(), 0); - for (Nd4jLong e = 0; e < sizes->lengthOf(); e++) { - int c_size = sizes->e(e); - - REQUIRE_TRUE(c_size > 0, 0, "Slice size should have postive value, but got %i instead", c_size); - REQUIRE_TRUE(cnt < array->sizeAt(0) && cnt + c_size <= array->sizeAt(0), 0, "Slices size should NOT be higher then number of TADs of source array. Source size: [%i]; Slice start: [%i]; Slice size: [%i]", array->sizeAt(0), cnt, c_size); - - // we're adding our interval along zeroth dimension - indices[0] = cnt; - indices[1] = cnt + c_size; - cnt += c_size; - - auto subarray = (*array)(indices); - - auto status = list->write(e, new NDArray(subarray.dup(array->ordering()))); - - if (status != ND4J_STATUS_OK) - return status; - } - - if (!hasList) { - //OVERWRITE_RESULT(list); - setupResultList(list, block); - } - - return Status::OK(); - } - DECLARE_SYN(TensorArraySplitV3, split_list); - DECLARE_SYN(tensorarraysplitv3, split_list); - } +namespace ops { +LIST_OP_IMPL(split_list, 2, 1, 0, -2) { + NDArrayList *list = nullptr; + NDArray *array = nullptr; + NDArray *sizes = nullptr; + + bool hasList = false; + + if (block.width() >= 3) { + list = INPUT_LIST(0); + array = INPUT_VARIABLE(1); + sizes = INPUT_VARIABLE(2); + hasList = true; + } else { + array = INPUT_VARIABLE(0); + sizes = INPUT_VARIABLE(1); + list = new NDArrayList(sizes->lengthOf(), false); + // block.trackList(list); + + throw std::runtime_error("split_list - Not implemented yet"); + } + + REQUIRE_TRUE(sizes->isZ(), 0, + "split_list: sizes array must have one of integer types"); + REQUIRE_TRUE(sizes->rankOf() == 1, 0, "split_list: sizes array must be 1D") + + list->setShape(array->getShapeAsVector()); + + // now let's build subarrays + int cnt = 0; + std::vector indices(2 * array->rankOf(), 0); + for (Nd4jLong e = 0; e < sizes->lengthOf(); e++) { + int c_size = sizes->e(e); + + REQUIRE_TRUE(c_size > 0, 0, + "Slice size should have postive value, but got %i instead", + c_size); + REQUIRE_TRUE( + cnt < array->sizeAt(0) && cnt + c_size <= array->sizeAt(0), 0, + "Slices size should NOT be higher then number of TADs of source array. " + "Source size: [%i]; Slice start: [%i]; Slice size: [%i]", + array->sizeAt(0), cnt, c_size); + + // we're adding our interval along zeroth dimension + indices[0] = cnt; + indices[1] = cnt + c_size; + cnt += c_size; + + auto subarray = (*array)(indices); + + auto status = list->write(e, subarray.dup(array->ordering())); + + if (status != ND4J_STATUS_OK) return status; + } + + if (!hasList) { + // OVERWRITE_RESULT(list); + setupResultList(*list, block); + } + + return Status::OK(); } +DECLARE_SYN(TensorArraySplitV3, split_list); +DECLARE_SYN(tensorarraysplitv3, split_list); +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/list/stack_list.cpp b/libnd4j/include/ops/declarable/generic/list/stack_list.cpp index a0f0f422054d..85d6f99004cd 100644 --- a/libnd4j/include/ops/declarable/generic/list/stack_list.cpp +++ b/libnd4j/include/ops/declarable/generic/list/stack_list.cpp @@ -25,21 +25,21 @@ #include namespace sd { - namespace ops { - LIST_OP_IMPL(stack_list, 1, 1, 0, 0) { - auto list = INPUT_LIST(0); - //auto z = OUTPUT_VARIABLE(0); +namespace ops { +LIST_OP_IMPL(stack_list, 1, 1, 0, 0) { + auto list = INPUT_LIST(0); + // auto z = OUTPUT_VARIABLE(0); - // FIXME: this is obviously bad - auto result = list->stack(); + // FIXME: this is obviously bad + auto result = list->stack(); - //OVERWRITE_RESULT(result); - setupResult(result, block); - return Status::OK(); - } - DECLARE_SYN(TensorArrayConcatV3, stack_list); - DECLARE_SYN(tensorarrayconcatv3, stack_list); - } + // OVERWRITE_RESULT(result); + setupResult(result, block); + return Status::OK(); } +DECLARE_SYN(TensorArrayConcatV3, stack_list); +DECLARE_SYN(tensorarrayconcatv3, stack_list); +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/list/unstack_list.cpp b/libnd4j/include/ops/declarable/generic/list/unstack_list.cpp index 5f452294904a..b89a16bd3e9c 100644 --- a/libnd4j/include/ops/declarable/generic/list/unstack_list.cpp +++ b/libnd4j/include/ops/declarable/generic/list/unstack_list.cpp @@ -25,23 +25,20 @@ namespace sd { namespace ops { - LIST_OP_IMPL(unstack_list, 1, 1, 0, 0) { - auto outputList = INPUT_LIST(0); - auto input = INPUT_VARIABLE(int(outputList != nullptr) ); - - if (outputList == nullptr) { - outputList = new NDArrayList(0, true); - //block.trackList(outputList); - setupResultList(outputList, block); - } - outputList->unstack(input, INT_ARG(0)); - - //OVERWRITE_RESULT(list); - - // - return Status::OK(); - } -} +LIST_OP_IMPL(unstack_list, 1, 1, 0, 0) { + auto outputList = INPUT_LIST(0); + auto input = INPUT_VARIABLE(int(outputList != nullptr)); + + if (outputList == nullptr) { + outputList = new NDArrayList(0, true); + // block.trackList(outputList); + setupResultList(*outputList, block); + } + outputList->unstack(*input, INT_ARG(0)); + + return Status::OK(); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/list/write_list.cpp b/libnd4j/include/ops/declarable/generic/list/write_list.cpp index c61bcb68b3e0..e9a72b052d73 100644 --- a/libnd4j/include/ops/declarable/generic/list/write_list.cpp +++ b/libnd4j/include/ops/declarable/generic/list/write_list.cpp @@ -24,47 +24,47 @@ #include namespace sd { - namespace ops { - LIST_OP_IMPL(write_list, 2, 1, 0, -2) { - auto list = INPUT_LIST(0); - auto output = OUTPUT_VARIABLE(0); - //output->printShapeInfo("Write_list default output shape"); - // nd4j mode - if (block.width() >= 3) { - auto input = INPUT_VARIABLE(block.width() - 2); - auto idx = INPUT_VARIABLE(block.width() - 1); +namespace ops { +LIST_OP_IMPL(write_list, 2, 1, 0, -2) { + auto list = INPUT_LIST(0); + auto output = OUTPUT_VARIABLE(0); + // output->printShapeInfo("Write_list default output shape"); + // nd4j mode + if (block.width() >= 3) { + auto input = INPUT_VARIABLE(block.width() - 2); + auto idx = INPUT_VARIABLE(block.width() - 1); - REQUIRE_TRUE(idx->isScalar(), 0, "Index should be Scalar"); + REQUIRE_TRUE(idx->isScalar(), 0, "Index should be Scalar"); - //nd4j_printf("Writing [%i]:\n", idx->e(0)); - //input->printShapeInfo("input shape"); - //input->printIndexedBuffer("input buffer"); - Nd4jStatus result = list->write(idx->e(0), new NDArray(input->dup())); + // nd4j_printf("Writing [%i]:\n", idx->e(0)); + // input->printShapeInfo("input shape"); + // input->printIndexedBuffer("input buffer"); + Nd4jStatus result = list->write(idx->e(0), input->dup()); - auto res = NDArrayFactory::create_(list->counter(), block.launchContext()); - //res->printShapeInfo("Write_list 2 output shape"); + auto res = NDArrayFactory::create(list->counter(), block.launchContext()); + // res->printShapeInfo("Write_list 2 output shape"); - setupResult(res, block); -// OVERWRITE_RESULT(res); + setupResult(res, block); + // OVERWRITE_RESULT(res); - return result; - } else if (block.getIArguments()->size() == 1) { - auto input = INPUT_VARIABLE(1); - auto idx = INT_ARG(0); + return result; + } else if (block.numI() == 1) { + auto input = INPUT_VARIABLE(1); + auto idx = INT_ARG(0); - Nd4jStatus result = list->write(idx, new NDArray(input->dup())); + Nd4jStatus result = list->write(idx, input->dup()); - auto res = NDArrayFactory::create_(list->counter(), block.launchContext()); - //res->printShapeInfo("Write_list 1 output shape"); - //OVERWRITE_RESULT(res); - setupResult(res, block); - return result; - } else - return ND4J_STATUS_BAD_INPUT; - } - DECLARE_SYN(TensorArrayWriteV3, write_list); - DECLARE_SYN(tensorarraywritev3, write_list); - } + auto res = NDArrayFactory::create(list->counter(), block.launchContext()); + // res->printShapeInfo("Write_list 1 output shape"); + // OVERWRITE_RESULT(res); + setupResult(res, block); + return result; + } else + return ND4J_STATUS_BAD_INPUT; } +DECLARE_SYN(TensorArrayWriteV3, write_list); +DECLARE_SYN(tensorarraywritev3, write_list); +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/loss/absoluteDifference.cpp b/libnd4j/include/ops/declarable/generic/loss/absoluteDifference.cpp index 0d5d1d011d58..6661dc815dc9 100644 --- a/libnd4j/include/ops/declarable/generic/loss/absoluteDifference.cpp +++ b/libnd4j/include/ops/declarable/generic/loss/absoluteDifference.cpp @@ -23,272 +23,369 @@ #include namespace sd { - namespace ops { - +namespace ops { ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(absolute_difference_loss, 3, 1, false, 0, 1) { - - auto predictions = INPUT_VARIABLE(0); - auto weights = INPUT_VARIABLE(1); - auto labels = INPUT_VARIABLE(2); - - auto output = OUTPUT_VARIABLE(0); - - int reductionMode = INT_ARG(0); // 0 - "none"; 1 - "weighted_sum"; 2 - "weighted_mean"; 3 - "weighted_sum_by_nonzero_weights" - - // input validation - REQUIRE_TRUE(labels->isSameShape(predictions), 0, "ABSOLUTE_DIFFERENCE_LOSS OP: labels and predictions arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labels).c_str(), ShapeUtils::shapeAsString(predictions).c_str()); - // weights array can be single scalar or has the same rank as labels, and must be broadcastable to labels - REQUIRE_TRUE(weights->isScalar() || weights->rankOf() == labels->rankOf(), 0, "ABSOLUTE_DIFFERENCE_LOSS OP: weights array should be scalar or have the same rank as labels array, but got %i and %i correspondingly!", weights->rankOf(), labels->rankOf()); - // check whether broadcast operation is possible for weights array - REQUIRE_TRUE(weights->isScalar() || ShapeUtils::areShapesBroadcastable(*weights, *labels), 0, "ABSOLUTE_DIFFERENCE_LOSS OP: shapes of weights and labels arrays should be broadcastable, but got weights = %s and labels = %s instead!", ShapeUtils::shapeAsString(weights).c_str(), ShapeUtils::shapeAsString(labels).c_str()); - // only 4 possible reduction modes exist - REQUIRE_TRUE(reductionMode==0 || reductionMode==1 || reductionMode==2 || reductionMode==3, 0, "ABSOLUTE_DIFFERENCE_LOSS OP: reduction mode value is not acceptable, possible values are 0, 1, 2, 3, but got %i instead!", reductionMode); - - // perform weights broadcasting/tile to labels if needed - auto weightsBroad = weights; - if(!weights->isScalar() && !weights->isSameShape(predictions)) - weightsBroad = new NDArray(weights->tileToShape(predictions->shapeInfo())); - - NDArray E = (*predictions - *labels).transform(sd::transform::Abs); - E *= *weightsBroad; - - switch (reductionMode) { - case 0: // 0 - "none", un-reduced weighted losses with the same shape as labels. - output->assign(E); - break; - - case 1: { // 1 - "weighted_sum", output is scalar and equal to sum of all elements of E array - E.reduceNumber(reduce::Sum, *output); - break; - } - case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array - NDArray sum; - if (weights->isScalar()) - sum = (*weights) * E.lengthOf(); - else - sum = weightsBroad->reduceNumber(reduce::Sum); - - if (sum.e(0) == 0.) - (*output) = 0.; - else - output->assign(E.reduceNumber(reduce::Sum) / sum); - break; - } - case 3: { // 3 - "weighted_sum_by_nonzero_weights", output is scalar and equal to scalar sum of all elements of E array divided by number of non-zero weights - Nd4jLong numOfNonZeroWeights = 0; - if(weights->isScalar()) { - if(weights->e(0) != 0.) - numOfNonZeroWeights = E.lengthOf(); - } - else { - numOfNonZeroWeights = weightsBroad->reduceNumber(reduce::CountNonZero).e(0); - } - - if (numOfNonZeroWeights == 0) - (*output) = 0.; - else - output->assign(E.reduceNumber(reduce::Sum) / double(numOfNonZeroWeights)); - break; - } - } - - if(weightsBroad != weights) - delete weightsBroad; - - return Status::OK(); + auto predictions = INPUT_VARIABLE(0); + auto weights = INPUT_VARIABLE(1); + auto labels = INPUT_VARIABLE(2); + + auto output = OUTPUT_VARIABLE(0); + + int reductionMode = + INT_ARG(0); // 0 - "none"; 1 - "weighted_sum"; 2 - "weighted_mean"; 3 - + // "weighted_sum_by_nonzero_weights" + + // input validation + REQUIRE_TRUE(labels->isSameShape(predictions), 0, + "ABSOLUTE_DIFFERENCE_LOSS OP: labels and predictions arrays " + "must have the same shapes, but got %s and %s correspondingly !", + ShapeUtils::shapeAsString(labels).c_str(), + ShapeUtils::shapeAsString(predictions).c_str()); + // weights array can be single scalar or has the same rank as labels, and must + // be broadcastable to labels + REQUIRE_TRUE( + weights->isScalar() || weights->rankOf() == labels->rankOf(), 0, + "ABSOLUTE_DIFFERENCE_LOSS OP: weights array should be scalar or have the " + "same rank as labels array, but got %i and %i correspondingly!", + weights->rankOf(), labels->rankOf()); + // check whether broadcast operation is possible for weights array + REQUIRE_TRUE( + weights->isScalar() || + ShapeUtils::areShapesBroadcastable(*weights, *labels), + 0, + "ABSOLUTE_DIFFERENCE_LOSS OP: shapes of weights and labels arrays should " + "be broadcastable, but got weights = %s and labels = %s instead!", + ShapeUtils::shapeAsString(weights).c_str(), + ShapeUtils::shapeAsString(labels).c_str()); + // only 4 possible reduction modes exist + REQUIRE_TRUE( + reductionMode == 0 || reductionMode == 1 || reductionMode == 2 || + reductionMode == 3, + 0, + "ABSOLUTE_DIFFERENCE_LOSS OP: reduction mode value is not acceptable, " + "possible values are 0, 1, 2, 3, but got %i instead!", + reductionMode); + + // perform weights broadcasting/tile to labels if needed + auto weightsBroad = weights; + if (!weights->isScalar() && !weights->isSameShape(predictions)) + weightsBroad = new NDArray(weights->tileToShape(predictions->shapeInfo())); + + NDArray E = (*predictions - *labels).transform(sd::transform::Abs); + E *= *weightsBroad; + + switch (reductionMode) { + case 0: // 0 - "none", un-reduced weighted losses with the same shape as + // labels. + output->assign(E); + break; + + case 1: { // 1 - "weighted_sum", output is scalar and equal to sum of all + // elements of E array + E.reduceNumber(reduce::Sum, *output); + break; + } + case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all + // elements of E array divided by sum of all elements of + // weightsBroad array + NDArray sum; + if (weights->isScalar()) + sum = (*weights) * E.lengthOf(); + else + sum = weightsBroad->reduceNumber(reduce::Sum); + + if (sum.e(0) == 0.) + (*output) = 0.; + else + output->assign(E.reduceNumber(reduce::Sum) / sum); + break; + } + case 3: { // 3 - "weighted_sum_by_nonzero_weights", output is scalar and + // equal to scalar sum of all elements of E array divided by + // number of non-zero weights + Nd4jLong numOfNonZeroWeights = 0; + if (weights->isScalar()) { + if (weights->e(0) != 0.) numOfNonZeroWeights = E.lengthOf(); + } else { + numOfNonZeroWeights = + weightsBroad->reduceNumber(reduce::CountNonZero).e(0); + } + + if (numOfNonZeroWeights == 0) + (*output) = 0.; + else + output->assign(E.reduceNumber(reduce::Sum) / + double(numOfNonZeroWeights)); + break; + } + } + + if (weightsBroad != weights) delete weightsBroad; + + return Status::OK(); } DECLARE_TYPES(absolute_difference_loss) { - getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } DECLARE_SHAPE_FN(absolute_difference_loss) { - - auto predictionsShapeInfo = inputShape->at(0); - auto weightsShapeInfo = inputShape->at(1); - auto labelsShapeInfo = inputShape->at(2); - - // labels and predictions must have the same shapes - REQUIRE_TRUE(shape::shapeEquals(labelsShapeInfo, predictionsShapeInfo), 0, "ABSOLUTE_DIFFERENCE_LOSS OP: labels and predictions arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), ShapeUtils::shapeAsString(predictionsShapeInfo).c_str()); - // weights array can be single scalar or has the same rank as labels, and must be broadcastable to labels - REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || shape::rank(weightsShapeInfo) == shape::rank(labelsShapeInfo), 0, "ABSOLUTE_DIFFERENCE_LOSS OP: weights array should be scalar or have the same rank as labels array, but got %i and %i correspondingly!", shape::rank(weightsShapeInfo), shape::rank(labelsShapeInfo)); - // check whether broadcast operation is possible for weights array - REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || ShapeUtils::areShapesBroadcastable(weightsShapeInfo, labelsShapeInfo), 0, "ABSOLUTE_DIFFERENCE_LOSS OP: shapes of weights and labels arrays should be broadcastable, but got weights = %s and labels = %s instead!", ShapeUtils::shapeAsString(weightsShapeInfo).c_str(), ShapeUtils::shapeAsString(labelsShapeInfo).c_str()); - - DataType outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(predictionsShapeInfo)); - Nd4jLong const* outShapeInfo = nullptr; - - if(INT_ARG(0) != 0) // in this case output is scalar - outShapeInfo = ConstantShapeHelper::getInstance().scalarShapeInfo(outType); - else // in this case output has the same shape as labels and predictions - outShapeInfo = ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(outType, shape::order(labelsShapeInfo), shape::shapeOf(labelsShapeInfo), shape::rank(labelsShapeInfo))); - - return SHAPELIST(outShapeInfo); + auto predictionsShapeInfo = inputShape->at(0); + auto weightsShapeInfo = inputShape->at(1); + auto labelsShapeInfo = inputShape->at(2); + + // labels and predictions must have the same shapes + REQUIRE_TRUE(shape::shapeEquals(labelsShapeInfo, predictionsShapeInfo), 0, + "ABSOLUTE_DIFFERENCE_LOSS OP: labels and predictions arrays " + "must have the same shapes, but got %s and %s correspondingly !", + ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), + ShapeUtils::shapeAsString(predictionsShapeInfo).c_str()); + // weights array can be single scalar or has the same rank as labels, and must + // be broadcastable to labels + REQUIRE_TRUE( + shape::isScalar(weightsShapeInfo) || + shape::rank(weightsShapeInfo) == shape::rank(labelsShapeInfo), + 0, + "ABSOLUTE_DIFFERENCE_LOSS OP: weights array should be scalar or have the " + "same rank as labels array, but got %i and %i correspondingly!", + shape::rank(weightsShapeInfo), shape::rank(labelsShapeInfo)); + // check whether broadcast operation is possible for weights array + REQUIRE_TRUE( + shape::isScalar(weightsShapeInfo) || + ShapeUtils::areShapesBroadcastable(weightsShapeInfo, labelsShapeInfo), + 0, + "ABSOLUTE_DIFFERENCE_LOSS OP: shapes of weights and labels arrays should " + "be broadcastable, but got weights = %s and labels = %s instead!", + ShapeUtils::shapeAsString(weightsShapeInfo).c_str(), + ShapeUtils::shapeAsString(labelsShapeInfo).c_str()); + + DataType outType = DataTypeUtils::pickFloatingType( + ArrayOptions::dataType(predictionsShapeInfo)); + Nd4jLong const* outShapeInfo = nullptr; + + if (INT_ARG(0) != 0) // in this case output is scalar + outShapeInfo = ConstantShapeHelper::getInstance().scalarShapeInfo(outType); + else // in this case output has the same shape as labels and predictions + outShapeInfo = + ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor( + outType, shape::order(labelsShapeInfo), + shape::shapeOf(labelsShapeInfo), shape::rank(labelsShapeInfo))); + + return SHAPELIST(outShapeInfo); } - - - - - - - - ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(absolute_difference_loss_grad, 3, 3, false, 0, 1) { - - auto predictions = INPUT_VARIABLE(0); - auto weights = INPUT_VARIABLE(1); - auto labels = INPUT_VARIABLE(2); - - auto dLdp = OUTPUT_VARIABLE(0); // dL/dpredictions - auto dLdw = OUTPUT_VARIABLE(1); // dL/dweights - auto dLdl = OUTPUT_VARIABLE(2); // dL/dlabels - - int reductionMode = INT_ARG(0); // 0 - "none"; 1 - "weighted_sum"; 2 - "weighted_mean"; 3 - "weighted_sum_by_nonzero_weights" - // take into account Alex's proposition to treat "none" the same as "weighted_sum" mode when calculating gradients - if(reductionMode == 0) - reductionMode = 1; - - // inputs validation - REQUIRE_TRUE(labels->isSameShape(predictions), 0, "ABSOLUTE_DIFFERENCE_LOSS OP: labels and predictions arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labels).c_str(), ShapeUtils::shapeAsString(predictions).c_str()); - // weights array can be single scalar or has the same rank as labels, and must be broadcastable to labels - REQUIRE_TRUE(weights->isScalar() || weights->rankOf() == labels->rankOf(), 0, "ABSOLUTE_DIFFERENCE_LOSS OP: weights array should be scalar or have the same rank as labels array, but got %i and %i correspondingly!", weights->rankOf(), labels->rankOf()); - // check whether broadcast operation is possible for weights array - REQUIRE_TRUE(weights->isScalar() || ShapeUtils::areShapesBroadcastable(*weights, *labels), 0, "ABSOLUTE_DIFFERENCE_LOSS OP: shapes of weights and labels arrays should be broadcastable, but got weights = %s and labels = %s instead!", ShapeUtils::shapeAsString(weights).c_str(), ShapeUtils::shapeAsString(labels).c_str()); - // only 4 possible reduction modes exist - REQUIRE_TRUE(reductionMode==0 || reductionMode==1 || reductionMode==2 || reductionMode==3, 0, "ABSOLUTE_DIFFERENCE_LOSS OP: reduction mode value is not acceptable, possible values are 0, 1, 2, 3, but got %i instead!", reductionMode); - - // perform weights broadcasting/tile to labels if needed - auto weightsBroad = weights; - if(!weights->isScalar() && !weights->isSameShape(predictions)) - weightsBroad = new NDArray(weights->tileToShape(predictions->shapeInfo())); - - NDArray E = *predictions - *labels; - - // dE_i/dp_i = sign(p_i - y_i) - E.applyTransform(sd::transform::Sign, *dLdp); // dE/dp - // dE_i/dy_i = -sign(p_i - y_i) - - E.applyTransform(sd::transform::Abs, E); - - switch (reductionMode) { - - case 1: { // 1 - "none" and "weighted_sum", output is scalar and equal to sum of all elements of E array - - *dLdp *= *weightsBroad; - - if(weights->isScalar()) - dLdw->assign(E.reduceNumber(reduce::Sum)); - else if(weights != weightsBroad) { - std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), weightsBroad->shapeInfo()); - E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); - } - else - dLdw->assign(E); - break; - } - case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array - - NDArray sum; - if (weights->isScalar()) - sum = (*weights) * E.lengthOf(); - else - sum = weightsBroad->reduceNumber(reduce::Sum); - - if (sum.e(0) == 0.) { - *dLdp = 0.; - *dLdw = 0.; - } - else { - - *dLdp *= *weightsBroad / sum; - - if(weights->isScalar()) - *dLdw = 0.; - else if(weights != weightsBroad) { - std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), weightsBroad->shapeInfo()); - ((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)).reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); - } - else - dLdw->assign((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)); - } - break; - } - case 3: { // 3 - "weighted_sum_by_nonzero_weights", output is scalar and equal to scalar sum of all elements of E array divided by number of non-zero weights - - Nd4jLong numOfNonZeroWeights = 0; - if(weights->isScalar()) { - if(weights->e(0) != 0.) - numOfNonZeroWeights = E.lengthOf(); - } - else - numOfNonZeroWeights = weightsBroad->reduceNumber(reduce::CountNonZero).e(0); - - if (numOfNonZeroWeights == 0) { - *dLdp = 0.; - *dLdw = 0.; - } - else { - auto numOfNonZeroWeightsScalar = NDArrayFactory::create(dLdw->dataType(), numOfNonZeroWeights, block.launchContext()); - - if(weights->isScalar()) - dLdw->assign(E.reduceNumber(reduce::Sum) / double(numOfNonZeroWeights)); - else if(weights != weightsBroad) { - std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), weightsBroad->shapeInfo()); - E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); - *dLdw /= numOfNonZeroWeightsScalar; - } - else - dLdw->assign(E / numOfNonZeroWeightsScalar); - - NDArray temp = *weightsBroad / numOfNonZeroWeightsScalar; - *dLdp *= temp; - } - break; - } - } - - dLdl->assign(-*dLdp); - - if(weightsBroad != weights) - delete weightsBroad; - - return Status::OK(); + auto predictions = INPUT_VARIABLE(0); + auto weights = INPUT_VARIABLE(1); + auto labels = INPUT_VARIABLE(2); + + auto dLdp = OUTPUT_VARIABLE(0); // dL/dpredictions + auto dLdw = OUTPUT_VARIABLE(1); // dL/dweights + auto dLdl = OUTPUT_VARIABLE(2); // dL/dlabels + + int reductionMode = + INT_ARG(0); // 0 - "none"; 1 - "weighted_sum"; 2 - "weighted_mean"; 3 - + // "weighted_sum_by_nonzero_weights" + // take into account Alex's proposition to treat "none" the same as + // "weighted_sum" mode when calculating gradients + if (reductionMode == 0) reductionMode = 1; + + // inputs validation + REQUIRE_TRUE(labels->isSameShape(predictions), 0, + "ABSOLUTE_DIFFERENCE_LOSS OP: labels and predictions arrays " + "must have the same shapes, but got %s and %s correspondingly !", + ShapeUtils::shapeAsString(labels).c_str(), + ShapeUtils::shapeAsString(predictions).c_str()); + // weights array can be single scalar or has the same rank as labels, and must + // be broadcastable to labels + REQUIRE_TRUE( + weights->isScalar() || weights->rankOf() == labels->rankOf(), 0, + "ABSOLUTE_DIFFERENCE_LOSS OP: weights array should be scalar or have the " + "same rank as labels array, but got %i and %i correspondingly!", + weights->rankOf(), labels->rankOf()); + // check whether broadcast operation is possible for weights array + REQUIRE_TRUE( + weights->isScalar() || + ShapeUtils::areShapesBroadcastable(*weights, *labels), + 0, + "ABSOLUTE_DIFFERENCE_LOSS OP: shapes of weights and labels arrays should " + "be broadcastable, but got weights = %s and labels = %s instead!", + ShapeUtils::shapeAsString(weights).c_str(), + ShapeUtils::shapeAsString(labels).c_str()); + // only 4 possible reduction modes exist + REQUIRE_TRUE( + reductionMode == 0 || reductionMode == 1 || reductionMode == 2 || + reductionMode == 3, + 0, + "ABSOLUTE_DIFFERENCE_LOSS OP: reduction mode value is not acceptable, " + "possible values are 0, 1, 2, 3, but got %i instead!", + reductionMode); + + // perform weights broadcasting/tile to labels if needed + auto weightsBroad = weights; + if (!weights->isScalar() && !weights->isSameShape(predictions)) + weightsBroad = new NDArray(weights->tileToShape(predictions->shapeInfo())); + + NDArray E = *predictions - *labels; + + // dE_i/dp_i = sign(p_i - y_i) + E.applyTransform(sd::transform::Sign, *dLdp); // dE/dp + // dE_i/dy_i = -sign(p_i - y_i) + + E.applyTransform(sd::transform::Abs, E); + + switch (reductionMode) { + case 1: { // 1 - "none" and "weighted_sum", output is scalar and equal to + // sum of all elements of E array + + *dLdp *= *weightsBroad; + + if (weights->isScalar()) + dLdw->assign(E.reduceNumber(reduce::Sum)); + else if (weights != weightsBroad) { + std::vector axesToReduceAlong = + ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), + weightsBroad->shapeInfo()); + E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, + false, false); + } else + dLdw->assign(E); + break; + } + case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all + // elements of E array divided by sum of all elements of + // weightsBroad array + + NDArray sum; + if (weights->isScalar()) + sum = (*weights) * E.lengthOf(); + else + sum = weightsBroad->reduceNumber(reduce::Sum); + + if (sum.e(0) == 0.) { + *dLdp = 0.; + *dLdw = 0.; + } else { + *dLdp *= *weightsBroad / sum; + + if (weights->isScalar()) + *dLdw = 0.; + else if (weights != weightsBroad) { + std::vector axesToReduceAlong = + ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), + weightsBroad->shapeInfo()); + ((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / + (sum * sum)) + .reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, + false, false); + } else + dLdw->assign( + (E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / + (sum * sum)); + } + break; + } + case 3: { // 3 - "weighted_sum_by_nonzero_weights", output is scalar and + // equal to scalar sum of all elements of E array divided by + // number of non-zero weights + + Nd4jLong numOfNonZeroWeights = 0; + if (weights->isScalar()) { + if (weights->e(0) != 0.) numOfNonZeroWeights = E.lengthOf(); + } else + numOfNonZeroWeights = + weightsBroad->reduceNumber(reduce::CountNonZero).e(0); + + if (numOfNonZeroWeights == 0) { + *dLdp = 0.; + *dLdw = 0.; + } else { + auto numOfNonZeroWeightsScalar = NDArrayFactory::create( + dLdw->dataType(), numOfNonZeroWeights, block.launchContext()); + + if (weights->isScalar()) + dLdw->assign(E.reduceNumber(reduce::Sum) / + double(numOfNonZeroWeights)); + else if (weights != weightsBroad) { + std::vector axesToReduceAlong = + ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), + weightsBroad->shapeInfo()); + E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, + false, false); + *dLdw /= numOfNonZeroWeightsScalar; + } else + dLdw->assign(E / numOfNonZeroWeightsScalar); + + NDArray temp = *weightsBroad / numOfNonZeroWeightsScalar; + *dLdp *= temp; + } + break; + } + } + + dLdl->assign(-*dLdp); + + if (weightsBroad != weights) delete weightsBroad; + + return Status::OK(); } DECLARE_TYPES(absolute_difference_loss_grad) { - - getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } DECLARE_SHAPE_FN(absolute_difference_loss_grad) { - - auto predictionsShapeInfo = inputShape->at(0); - auto weightsShapeInfo = inputShape->at(1); - auto labelsShapeInfo = inputShape->at(2); - - // labels and predictions must have the same shapes - REQUIRE_TRUE(shape::shapeEquals(labelsShapeInfo, predictionsShapeInfo), 0, "ABSOLUTE_DIFFERENCE_LOSS OP: labels and predictions arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), ShapeUtils::shapeAsString(predictionsShapeInfo).c_str()); - // weights array can be single scalar or has the same rank as labels, and must be broadcastable to labels - REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || shape::rank(weightsShapeInfo) == shape::rank(labelsShapeInfo), 0, "ABSOLUTE_DIFFERENCE_LOSS OP: weights array should be scalar or have the same rank as labels array, but got %i and %i correspondingly!", shape::rank(weightsShapeInfo), shape::rank(labelsShapeInfo)); - // check whether broadcast operation is possible for weights array - REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || ShapeUtils::areShapesBroadcastable(weightsShapeInfo, labelsShapeInfo), 0, "ABSOLUTE_DIFFERENCE_LOSS OP: shapes of weights and labels arrays should be broadcastable, but got weights = %s and labels = %s instead!", ShapeUtils::shapeAsString(weightsShapeInfo).c_str(), ShapeUtils::shapeAsString(labelsShapeInfo).c_str()); - - DataType outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(predictionsShapeInfo)); - - auto dLdpShapeInfo = ShapeBuilders::copyShapeInfoAndType(predictionsShapeInfo, outType, false, block.getWorkspace()); - auto dLdwShapeInfo = ShapeBuilders::copyShapeInfoAndType(weightsShapeInfo, outType, false, block.getWorkspace()); - auto dLdlShapeInfo = ShapeBuilders::copyShapeInfoAndType(labelsShapeInfo, outType, false, block.getWorkspace()); - - return SHAPELIST(CONSTANT(dLdpShapeInfo), CONSTANT(dLdwShapeInfo), CONSTANT(dLdlShapeInfo)); + auto predictionsShapeInfo = inputShape->at(0); + auto weightsShapeInfo = inputShape->at(1); + auto labelsShapeInfo = inputShape->at(2); + + // labels and predictions must have the same shapes + REQUIRE_TRUE(shape::shapeEquals(labelsShapeInfo, predictionsShapeInfo), 0, + "ABSOLUTE_DIFFERENCE_LOSS OP: labels and predictions arrays " + "must have the same shapes, but got %s and %s correspondingly !", + ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), + ShapeUtils::shapeAsString(predictionsShapeInfo).c_str()); + // weights array can be single scalar or has the same rank as labels, and must + // be broadcastable to labels + REQUIRE_TRUE( + shape::isScalar(weightsShapeInfo) || + shape::rank(weightsShapeInfo) == shape::rank(labelsShapeInfo), + 0, + "ABSOLUTE_DIFFERENCE_LOSS OP: weights array should be scalar or have the " + "same rank as labels array, but got %i and %i correspondingly!", + shape::rank(weightsShapeInfo), shape::rank(labelsShapeInfo)); + // check whether broadcast operation is possible for weights array + REQUIRE_TRUE( + shape::isScalar(weightsShapeInfo) || + ShapeUtils::areShapesBroadcastable(weightsShapeInfo, labelsShapeInfo), + 0, + "ABSOLUTE_DIFFERENCE_LOSS OP: shapes of weights and labels arrays should " + "be broadcastable, but got weights = %s and labels = %s instead!", + ShapeUtils::shapeAsString(weightsShapeInfo).c_str(), + ShapeUtils::shapeAsString(labelsShapeInfo).c_str()); + + DataType outType = DataTypeUtils::pickFloatingType( + ArrayOptions::dataType(predictionsShapeInfo)); + + auto dLdpShapeInfo = ShapeBuilders::copyShapeInfoAndType( + predictionsShapeInfo, outType, false, block.workspace()); + auto dLdwShapeInfo = ShapeBuilders::copyShapeInfoAndType( + weightsShapeInfo, outType, false, block.workspace()); + auto dLdlShapeInfo = ShapeBuilders::copyShapeInfoAndType( + labelsShapeInfo, outType, false, block.workspace()); + + return SHAPELIST(CONSTANT(dLdpShapeInfo), CONSTANT(dLdwShapeInfo), + CONSTANT(dLdlShapeInfo)); } - - -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/loss/cosineDistance.cpp b/libnd4j/include/ops/declarable/generic/loss/cosineDistance.cpp index 99cf2e3c13e6..e6d64290dc3f 100644 --- a/libnd4j/include/ops/declarable/generic/loss/cosineDistance.cpp +++ b/libnd4j/include/ops/declarable/generic/loss/cosineDistance.cpp @@ -21,328 +21,438 @@ #include #if NOT_EXCLUDED(OP_cosine_distance_loss) -#include #include +#include namespace sd { -namespace ops { - +namespace ops { ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(cosine_distance_loss, 3, 1, false, 0, 2) { - - auto predictions = INPUT_VARIABLE(0); - auto weights = INPUT_VARIABLE(1); - auto labels = INPUT_VARIABLE(2); - - auto output = OUTPUT_VARIABLE(0); - - int reductionMode = INT_ARG(0); // 0 - "none"; 1 - "weighted_sum"; 2 - "weighted_mean"; 3 - "weighted_sum_by_nonzero_weights" - int dim = INT_ARG(1); // axis along which sum will be made - if(dim < 0) - dim += labels->rankOf(); - - // labels and predictions must have the same shapes - REQUIRE_TRUE(labels->isSameShape(predictions), 0, "COSINE_DISTANCE_LOSS OP: labels and predictions arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labels).c_str(), ShapeUtils::shapeAsString(predictions).c_str()); - // regard 4 possible reduction modes below - REQUIRE_TRUE(reductionMode==0 || reductionMode==1 || reductionMode==2 || reductionMode==3, 0, "COSINE_DISTANCE_LOSS OP: reduction mode value is not acceptable, possible values are 0, 1, 2, 3, but got %i instead!", reductionMode); - // input dimension can't be larger than labels/predictions/weights rank - REQUIRE_TRUE(dim < labels->rankOf(), 0, "COSINE_DISTANCE_LOSS OP: input reduction dimension (got %i) must be < labels rank %i!", dim, labels->rankOf()); - - if(!output->isScalar()) { - // weights array can be single scalar or has the same shape as output, and must be broadcastable to output shape - REQUIRE_TRUE(weights->isScalar() || weights->rankOf() == output->rankOf(), 0, "SOFTMAX_CROSS_ENTROPY_LOSS OP: weights array should be scalar or have the same rank as output array, but got %i and %i correspondingly!", weights->rankOf(), output->rankOf()); - // check whether broadcast operation is possible for weights array - REQUIRE_TRUE(weights->isScalar() || ShapeUtils::areShapesBroadcastable(*weights, *output), 0, "COSINE_DISTANCE_LOSS OP: shapes of weights and output arrays should be broadcastable, but got weights = %s and output = %s instead!", ShapeUtils::shapeAsString(weights).c_str(), ShapeUtils::shapeAsString(labels).c_str()); + auto predictions = INPUT_VARIABLE(0); + auto weights = INPUT_VARIABLE(1); + auto labels = INPUT_VARIABLE(2); + + auto output = OUTPUT_VARIABLE(0); + + int reductionMode = + INT_ARG(0); // 0 - "none"; 1 - "weighted_sum"; 2 - "weighted_mean"; 3 - + // "weighted_sum_by_nonzero_weights" + int dim = INT_ARG(1); // axis along which sum will be made + if (dim < 0) dim += labels->rankOf(); + + // labels and predictions must have the same shapes + REQUIRE_TRUE(labels->isSameShape(predictions), 0, + "COSINE_DISTANCE_LOSS OP: labels and predictions arrays must " + "have the same shapes, but got %s and %s correspondingly !", + ShapeUtils::shapeAsString(labels).c_str(), + ShapeUtils::shapeAsString(predictions).c_str()); + // regard 4 possible reduction modes below + REQUIRE_TRUE( + reductionMode == 0 || reductionMode == 1 || reductionMode == 2 || + reductionMode == 3, + 0, + "COSINE_DISTANCE_LOSS OP: reduction mode value is not acceptable, " + "possible values are 0, 1, 2, 3, but got %i instead!", + reductionMode); + // input dimension can't be larger than labels/predictions/weights rank + REQUIRE_TRUE(dim < labels->rankOf(), 0, + "COSINE_DISTANCE_LOSS OP: input reduction dimension (got %i) " + "must be < labels rank %i!", + dim, labels->rankOf()); + + if (!output->isScalar()) { + // weights array can be single scalar or has the same shape as output, and + // must be broadcastable to output shape + REQUIRE_TRUE( + weights->isScalar() || weights->rankOf() == output->rankOf(), 0, + "SOFTMAX_CROSS_ENTROPY_LOSS OP: weights array should be scalar or have " + "the same rank as output array, but got %i and %i correspondingly!", + weights->rankOf(), output->rankOf()); + // check whether broadcast operation is possible for weights array + REQUIRE_TRUE( + weights->isScalar() || + ShapeUtils::areShapesBroadcastable(*weights, *output), + 0, + "COSINE_DISTANCE_LOSS OP: shapes of weights and output arrays should " + "be broadcastable, but got weights = %s and output = %s instead!", + ShapeUtils::shapeAsString(weights).c_str(), + ShapeUtils::shapeAsString(labels).c_str()); + } + + NDArray E = + 1. - + (*predictions * *labels).reduceAlongDimension(reduce::Sum, {dim}, true); + + // perform weights broadcasting/tile to E if it is necessary + auto weightsBroad = weights; + if (!weights->isScalar() && !weights->isSameShape(&E)) + weightsBroad = new NDArray(weights->tileToShape(E.shapeInfo())); + + // multiply E on weights + E *= (*weightsBroad); + + switch (reductionMode) { + case 0: // 0 - "none", un-reduced weighted losses with the same shape as + // labels. + output->assign(&E); + break; + + case 1: { // 1 - "weighted_sum", output is scalar and equal to sum of all + // elements of E array + output->assign(E.reduceNumber(reduce::Sum)); + break; + } + case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all + // elements of E array divided by sum of all elements of + // weightsBroad array + NDArray sum; + if (weights->isScalar()) + sum = *weights * E.lengthOf(); + else + sum = weightsBroad->reduceNumber(reduce::Sum); + + if (sum.e(0) == 0.) + *output = 0.; + else + output->assign(E.reduceNumber(reduce::Sum) / sum); + break; } + case 3: { // 3 - "weighted_sum_by_nonzero_weights", output is scalar and + // equal to scalar sum of all elements of E array divided by + // number of non-zero weights + Nd4jLong numOfNonZeroWeights = 0; + if (weights->isScalar()) { + if (weights->e(0) != 0.) numOfNonZeroWeights = E.lengthOf(); + } else + numOfNonZeroWeights = + E.reduceNumber(reduce::CountNonZero).e(0); + + if (numOfNonZeroWeights == 0) + *output = 0.; + else + output->assign(E.reduceNumber(reduce::Sum) / + double(numOfNonZeroWeights)); + + break; + } + } + + STORE_RESULT(*output); + + if (weightsBroad != weights) delete weightsBroad; - NDArray E = 1. - (*predictions * *labels).reduceAlongDimension(reduce::Sum, {dim}, true); - - // perform weights broadcasting/tile to E if it is necessary - auto weightsBroad = weights; - if(!weights->isScalar() && !weights->isSameShape(&E)) - weightsBroad = new NDArray(weights->tileToShape(E.shapeInfo())); - - // multiply E on weights - E *= (*weightsBroad); - - switch (reductionMode) { - case 0: // 0 - "none", un-reduced weighted losses with the same shape as labels. - output->assign(&E); - break; - - case 1: { // 1 - "weighted_sum", output is scalar and equal to sum of all elements of E array - output->assign(E.reduceNumber(reduce::Sum)); - break; - } - case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array - NDArray sum; - if (weights->isScalar()) - sum = *weights * E.lengthOf(); - else - sum = weightsBroad->reduceNumber(reduce::Sum); - - if (sum.e(0) == 0.) - *output = 0.; - else - output->assign(E.reduceNumber(reduce::Sum) / sum); - break; - } - case 3: { // 3 - "weighted_sum_by_nonzero_weights", output is scalar and equal to scalar sum of all elements of E array divided by number of non-zero weights - Nd4jLong numOfNonZeroWeights = 0; - if(weights->isScalar()) { - if(weights->e(0) != 0.) - numOfNonZeroWeights = E.lengthOf(); - } - else - numOfNonZeroWeights = E.reduceNumber(reduce::CountNonZero).e(0); - - if (numOfNonZeroWeights == 0) - *output = 0.; - else - output->assign(E.reduceNumber(reduce::Sum) / double(numOfNonZeroWeights)); - - break; - } - } - - - STORE_RESULT(*output); - - if(weightsBroad != weights) - delete weightsBroad; - - return Status::OK(); + return Status::OK(); } ////////////////////////////////////////////////////////////////////////// DECLARE_TYPES(cosine_distance_loss) { - - getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } ////////////////////////////////////////////////////////////////////////// DECLARE_SHAPE_FN(cosine_distance_loss) { + // labels and predictions must have the same shapes + auto predictionsShapeInfo = inputShape->at(0); + auto weightsShapeInfo = inputShape->at(1); + auto labelsShapeInfo = inputShape->at(2); + + int dim = INT_ARG(1); + if (dim < 0) dim += labelsShapeInfo[0]; + + // labels and predictions must have the same shapes + REQUIRE_TRUE(shape::shapeEquals(labelsShapeInfo, predictionsShapeInfo), 0, + "COSINE_DISTANCE_LOSS OP: labels and predictions arrays must " + "have the same shapes, but got %s and %s correspondingly !", + ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), + ShapeUtils::shapeAsString(predictionsShapeInfo).c_str()); + // input dimension can't be larger than labels/predictions/weights rank + REQUIRE_TRUE(dim < labelsShapeInfo[0], 0, + "COSINE_DISTANCE_LOSS OP: input reduction dimension (got %i) " + "must be < labels rank %i!", + dim, labelsShapeInfo[0]); + + DataType outType = DataTypeUtils::pickFloatingType( + ArrayOptions::dataType(predictionsShapeInfo)); + + // evaluate output shapeInfo + Nd4jLong const* outShapeInfo = nullptr; + if (INT_ARG(0) != 0) // in this case output is scalar + outShapeInfo = ConstantShapeHelper::getInstance().scalarShapeInfo(outType); + else { // in this case output has the same shape as labels reduced by dim + // axis - // labels and predictions must have the same shapes - auto predictionsShapeInfo = inputShape->at(0); - auto weightsShapeInfo = inputShape->at(1); - auto labelsShapeInfo = inputShape->at(2); - - int dim = INT_ARG(1); - if(dim < 0) - dim += labelsShapeInfo[0]; - - // labels and predictions must have the same shapes - REQUIRE_TRUE(shape::shapeEquals(labelsShapeInfo, predictionsShapeInfo), 0, "COSINE_DISTANCE_LOSS OP: labels and predictions arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), ShapeUtils::shapeAsString(predictionsShapeInfo).c_str()); - // input dimension can't be larger than labels/predictions/weights rank - REQUIRE_TRUE(dim < labelsShapeInfo[0], 0, "COSINE_DISTANCE_LOSS OP: input reduction dimension (got %i) must be < labels rank %i!", dim, labelsShapeInfo[0]); - - DataType outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(predictionsShapeInfo)); - - // evaluate output shapeInfo - Nd4jLong const* outShapeInfo = nullptr; - if(INT_ARG(0) != 0) // in this case output is scalar - outShapeInfo = ConstantShapeHelper::getInstance().scalarShapeInfo(outType); - else { // in this case output has the same shape as labels reduced by dim axis - - std::vector dimensions = {dim}; - outShapeInfo = ShapeUtils::evalReduceShapeInfo(shape::order(predictionsShapeInfo), dimensions, predictionsShapeInfo, outType, true, false, block.getWorkspace()); - - // weights array can be single scalar or has the same rank as output, and must be broadcastable to output - REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || shape::rank(weightsShapeInfo) == shape::rank(outShapeInfo), 0, "COSINE_DISTANCE_LOSS OP: weights array should be scalar or have the same rank as output array, but got %i and %i correspondingly!", shape::rank(weightsShapeInfo), shape::rank(outShapeInfo)); - // check whether broadcast operation is possible for weights array - REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || ShapeUtils::areShapesBroadcastable(weightsShapeInfo, outShapeInfo), 0, "COSINE_DISTANCE_LOSS OP: shapes of weights and output arrays should be broadcastable, but got weights = %s and output = %s instead!", ShapeUtils::shapeAsString(weightsShapeInfo).c_str(), ShapeUtils::shapeAsString(outShapeInfo).c_str()); - } - - return SHAPELIST(outShapeInfo); + std::vector dimensions = {dim}; + outShapeInfo = ShapeUtils::evalReduceShapeInfo( + shape::order(predictionsShapeInfo), dimensions, predictionsShapeInfo, + outType, true, false, block.workspace()); + + // weights array can be single scalar or has the same rank as output, and + // must be broadcastable to output + REQUIRE_TRUE( + shape::isScalar(weightsShapeInfo) || + shape::rank(weightsShapeInfo) == shape::rank(outShapeInfo), + 0, + "COSINE_DISTANCE_LOSS OP: weights array should be scalar or have the " + "same rank as output array, but got %i and %i correspondingly!", + shape::rank(weightsShapeInfo), shape::rank(outShapeInfo)); + // check whether broadcast operation is possible for weights array + REQUIRE_TRUE( + shape::isScalar(weightsShapeInfo) || + ShapeUtils::areShapesBroadcastable(weightsShapeInfo, outShapeInfo), + 0, + "COSINE_DISTANCE_LOSS OP: shapes of weights and output arrays should " + "be broadcastable, but got weights = %s and output = %s instead!", + ShapeUtils::shapeAsString(weightsShapeInfo).c_str(), + ShapeUtils::shapeAsString(outShapeInfo).c_str()); + } + + return SHAPELIST(outShapeInfo); } - - ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(cosine_distance_loss_grad, 3, 3, false, 0, 2) { + auto predictions = INPUT_VARIABLE(0); + auto weights = INPUT_VARIABLE(1); + auto labels = INPUT_VARIABLE(2); + + auto dLdp = OUTPUT_VARIABLE(0); // dL/dpredictions + auto dLdw = OUTPUT_VARIABLE(1); // dL/dweights + auto dLdl = OUTPUT_VARIABLE(2); // dL/dlabels + + int reductionMode = + INT_ARG(0); // 0 - "none"; 1 - "weighted_sum"; 2 - "weighted_mean"; 3 - + // "weighted_sum_by_nonzero_weights" + // take into account Alex's proposition to treat "none" the same as + // "weighted_sum" mode when calculating gradients + if (reductionMode == 0) reductionMode = 1; + + int dim = INT_ARG(1); // axis along which sum will be made + if (dim < 0) dim += labels->rankOf(); + + std::vector dimensions = {dim}; + + // input validation + REQUIRE_TRUE(labels->isSameShape(predictions), 0, + "COSINE_DISTANCE_LOSS_GRAD OP: labels and predictions arrays " + "must have the same shapes, but got %s and %s correspondingly !", + ShapeUtils::shapeAsString(labels).c_str(), + ShapeUtils::shapeAsString(predictions).c_str()); + // only 4 possible reduction modes exist + REQUIRE_TRUE( + reductionMode == 0 || reductionMode == 1 || reductionMode == 2 || + reductionMode == 3, + 0, + "COSINE_DISTANCE_LOSS_GRAD OP: reduction mode value is not acceptable, " + "possible values are 0, 1, 2, 3, but got %i instead!", + reductionMode); + auto lossShapeInfo = ShapeUtils::evalReduceShapeInfo( + predictions->ordering(), dimensions, predictions->shapeInfo(), true, + false, block.workspace()); + // weights array can be single scalar or has the same shape as loss, and must + // be broadcastable to loss shape + REQUIRE_TRUE( + weights->isScalar() || weights->rankOf() == shape::rank(lossShapeInfo), 0, + "COSINE_DISTANCE_LOSS_GRAD OP: weights array should be scalar or have " + "the same rank as loss array, but got %i and %i correspondingly!", + weights->rankOf(), shape::rank(lossShapeInfo)); + // check whether broadcast operation is possible for weights array + REQUIRE_TRUE( + weights->isScalar() || ShapeUtils::areShapesBroadcastable( + weights->shapeInfo(), lossShapeInfo), + 0, + "COSINE_DISTANCE_LOSS_GRAD OP: shapes of weights and loss arrays should " + "be broadcastable, but got weights = %s and loss = %s instead!", + ShapeUtils::shapeAsString(weights).c_str(), + ShapeUtils::shapeAsString(lossShapeInfo).c_str()); + // input dimension can't be larger than labels/predictions/weights rank + REQUIRE_TRUE(dim < labels->rankOf(), 0, + "COSINE_DISTANCE_LOSS_GRAD OP: input reduction dimension (got " + "%i) must be < labels rank %i!", + dim, labels->rankOf()); + + NDArray E = + 1. - + (*predictions * *labels).reduceAlongDimension(reduce::Sum, {dim}, true); + + // perform weights broadcasting/tile to E if it is necessary + auto weightsBroad = weights; + if (!weights->isScalar() && !weights->isSameShape(&E)) + weightsBroad = new NDArray(weights->tileToShape(E.shapeInfo())); + + dLdp->assign(-*labels); + dLdl->assign(-*predictions); + + switch (reductionMode) { + case 1: { // 1 - "none" and "weighted_sum", output is scalar and equal to + // sum of all elements of E array + + *dLdp *= *weightsBroad; + *dLdl *= *weightsBroad; + + if (weights->isScalar() || weights->lengthOf() == 1) { + dLdw->assign(E.reduceNumber(reduce::Sum)); + } else { + if (weights != weightsBroad) { + std::vector axesToReduceAlong = + ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), + weightsBroad->shapeInfo()); + E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, + false, false); + } else + dLdw->assign(E); + } + + break; + } + case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all + // elements of E array divided by sum of all elements of + // weightsBroad array + NDArray sum; + if (weights->isScalar()) + sum = (*weights) * E.lengthOf(); + else + sum = weightsBroad->reduceNumber(reduce::Sum); + + if (sum.e(0) == 0.) { + *dLdp = 0.; + *dLdl = 0.; + *dLdw = 0.; + } else { + NDArray temp = *weightsBroad / sum; + *dLdp *= temp; + *dLdl *= temp; + + if (weights->isScalar() || weights->lengthOf() == 1) { + *dLdw = 0.; + } else { + if (weights != weightsBroad) { + std::vector axesToReduceAlong = + ShapeUtils::evalBroadcastBackwardAxis( + weights->shapeInfo(), weightsBroad->shapeInfo()); + ((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / + (sum * sum)) + .reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, + true, false, false); + } else + dLdw->assign( + (E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / + (sum * sum)); + } + } + break; + } + case 3: { // 3 - "weighted_sum_by_nonzero_weights", output is scalar and + // equal to scalar sum of all elements of E array divided by + // number of non-zero weights + Nd4jLong numOfNonZeroWeights = 0; + if (weights->isScalar()) { + if (weights->e(0) != 0.) numOfNonZeroWeights = E.lengthOf(); + } else + numOfNonZeroWeights = + weightsBroad->reduceNumber(reduce::CountNonZero).e(0); + + if (numOfNonZeroWeights == 0) { + *dLdp = 0.; + *dLdl = 0.; + *dLdw = 0.; + } else { + NDArray temp = *weightsBroad / numOfNonZeroWeights; + *dLdp *= temp; + *dLdl *= temp; + + if (weights->isScalar() || weights->lengthOf() == 1) { + dLdw->assign(E.reduceNumber(reduce::Sum) / numOfNonZeroWeights); + } else { + if (weights != weightsBroad) { + std::vector axesToReduceAlong = + ShapeUtils::evalBroadcastBackwardAxis( + weights->shapeInfo(), weightsBroad->shapeInfo()); + E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, + false, false); + *dLdw /= numOfNonZeroWeights; + } else + dLdw->assign(E / numOfNonZeroWeights); + } + } + break; + } + } - auto predictions = INPUT_VARIABLE(0); - auto weights = INPUT_VARIABLE(1); - auto labels = INPUT_VARIABLE(2); - - auto dLdp = OUTPUT_VARIABLE(0); // dL/dpredictions - auto dLdw = OUTPUT_VARIABLE(1); // dL/dweights - auto dLdl = OUTPUT_VARIABLE(2); // dL/dlabels - - int reductionMode = INT_ARG(0); // 0 - "none"; 1 - "weighted_sum"; 2 - "weighted_mean"; 3 - "weighted_sum_by_nonzero_weights" - // take into account Alex's proposition to treat "none" the same as "weighted_sum" mode when calculating gradients - if(reductionMode == 0) - reductionMode = 1; - - int dim = INT_ARG(1); // axis along which sum will be made - if(dim < 0) - dim += labels->rankOf(); - - std::vector dimensions = {dim}; + if (weightsBroad != weights) delete weightsBroad; - // input validation - REQUIRE_TRUE(labels->isSameShape(predictions), 0, "COSINE_DISTANCE_LOSS_GRAD OP: labels and predictions arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labels).c_str(), ShapeUtils::shapeAsString(predictions).c_str()); - // only 4 possible reduction modes exist - REQUIRE_TRUE(reductionMode==0 || reductionMode==1 || reductionMode==2 || reductionMode==3, 0, "COSINE_DISTANCE_LOSS_GRAD OP: reduction mode value is not acceptable, possible values are 0, 1, 2, 3, but got %i instead!", reductionMode); - auto lossShapeInfo = ShapeUtils::evalReduceShapeInfo(predictions->ordering(), dimensions, predictions->shapeInfo(), true, false, block.getWorkspace()); - // weights array can be single scalar or has the same shape as loss, and must be broadcastable to loss shape - REQUIRE_TRUE(weights->isScalar() || weights->rankOf() == shape::rank(lossShapeInfo), 0, "COSINE_DISTANCE_LOSS_GRAD OP: weights array should be scalar or have the same rank as loss array, but got %i and %i correspondingly!", weights->rankOf(), shape::rank(lossShapeInfo)); - // check whether broadcast operation is possible for weights array - REQUIRE_TRUE(weights->isScalar() || ShapeUtils::areShapesBroadcastable(weights->shapeInfo(), lossShapeInfo), 0, "COSINE_DISTANCE_LOSS_GRAD OP: shapes of weights and loss arrays should be broadcastable, but got weights = %s and loss = %s instead!", ShapeUtils::shapeAsString(weights).c_str(), ShapeUtils::shapeAsString(lossShapeInfo).c_str()); - // input dimension can't be larger than labels/predictions/weights rank - REQUIRE_TRUE(dim < labels->rankOf(), 0, "COSINE_DISTANCE_LOSS_GRAD OP: input reduction dimension (got %i) must be < labels rank %i!", dim, labels->rankOf()); - - NDArray E = 1. - (*predictions * *labels).reduceAlongDimension(reduce::Sum, {dim}, true); - - // perform weights broadcasting/tile to E if it is necessary - auto weightsBroad = weights; - if(!weights->isScalar() && !weights->isSameShape(&E)) - weightsBroad = new NDArray(weights->tileToShape(E.shapeInfo())); - - dLdp->assign(-*labels); - dLdl->assign(-*predictions); - - switch (reductionMode) { - case 1: { // 1 - "none" and "weighted_sum", output is scalar and equal to sum of all elements of E array - - *dLdp *= *weightsBroad; - *dLdl *= *weightsBroad; - - if(weights->isScalar() || weights->lengthOf() == 1) { - dLdw->assign(E.reduceNumber(reduce::Sum)); - } - else { - if(weights != weightsBroad) { - std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), weightsBroad->shapeInfo()); - E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); - } - else - dLdw->assign(E); - } - - break; - } - case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array - NDArray sum; - if (weights->isScalar()) - sum = (*weights) * E.lengthOf(); - else - sum = weightsBroad->reduceNumber(reduce::Sum); - - if (sum.e(0) == 0.) { - *dLdp = 0.; - *dLdl = 0.; - *dLdw = 0.; - } - else { - - NDArray temp = *weightsBroad / sum; - *dLdp *= temp; - *dLdl *= temp; - - if(weights->isScalar() || weights->lengthOf() == 1) { - *dLdw = 0.; - } - else { - - if(weights != weightsBroad) { - std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), weightsBroad->shapeInfo()); - ((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)).reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); - } - else - dLdw->assign((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)); - } - } - break; - } - case 3: { // 3 - "weighted_sum_by_nonzero_weights", output is scalar and equal to scalar sum of all elements of E array divided by number of non-zero weights - Nd4jLong numOfNonZeroWeights = 0; - if(weights->isScalar()) { - if(weights->e(0) != 0.) - numOfNonZeroWeights = E.lengthOf(); - } - else - numOfNonZeroWeights = weightsBroad->reduceNumber(reduce::CountNonZero).e(0); - - if (numOfNonZeroWeights == 0) { - *dLdp = 0.; - *dLdl = 0.; - *dLdw = 0.; - } - else { - - NDArray temp = *weightsBroad / numOfNonZeroWeights; - *dLdp *= temp; - *dLdl *= temp; - - if(weights->isScalar() || weights->lengthOf() == 1) { - dLdw->assign(E.reduceNumber(reduce::Sum) / numOfNonZeroWeights); - } - else { - - if(weights != weightsBroad) { - std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), weightsBroad->shapeInfo()); - E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); - *dLdw /= numOfNonZeroWeights; - } - else - dLdw->assign(E / numOfNonZeroWeights); - } - } - break; - } - } - - if(weightsBroad != weights) - delete weightsBroad; - - return Status::OK(); + return Status::OK(); } ////////////////////////////////////////////////////////////////////////// DECLARE_TYPES(cosine_distance_loss_grad) { - - getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } ////////////////////////////////////////////////////////////////////////// DECLARE_SHAPE_FN(cosine_distance_loss_grad) { - - /// labels and predictions must have the same shapes - auto predictionsShapeInfo = inputShape->at(0); - auto weightsShapeInfo = inputShape->at(1); - auto labelsShapeInfo = inputShape->at(2); - - int dim = INT_ARG(1); - if(dim < 0) - dim += labelsShapeInfo[0]; - - std::vector dimensions = {dim}; - - // labels and predictions must have the same shapes - REQUIRE_TRUE(shape::shapeEquals(labelsShapeInfo, predictionsShapeInfo), 0, "COSINE_DISTANCE_LOSS_GRAD OP: labels and predictions arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), ShapeUtils::shapeAsString(predictionsShapeInfo).c_str()); - auto lossShapeInfo = ShapeUtils::evalReduceShapeInfo(shape::order(predictionsShapeInfo), dimensions, predictionsShapeInfo, true, false, block.getWorkspace()); - // weights array can be single scalar or has the same rank as loss, and must be broadcastable to loss - REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || shape::rank(weightsShapeInfo) == shape::rank(lossShapeInfo), 0, "COSINE_DISTANCE_LOSS_GRAD OP: weights array should be scalar or have the same rank as loss array, but got %i and %i correspondingly!", shape::rank(weightsShapeInfo), shape::rank(lossShapeInfo)); - // check whether broadcast operation is possible for weights array - REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || ShapeUtils::areShapesBroadcastable(weightsShapeInfo, lossShapeInfo), 0, "COSINE_DISTANCE_LOSS_GRAD OP: shapes of weights and loss arrays should be broadcastable, but got weights = %s and loss = %s instead!", ShapeUtils::shapeAsString(weightsShapeInfo).c_str(), ShapeUtils::shapeAsString(lossShapeInfo).c_str()); - // input dimension can't be larger than labels/predictions/weights rank - REQUIRE_TRUE(dim < labelsShapeInfo[0], 0, "COSINE_DISTANCE_LOSS_GRAD OP: input reduction dimension (got %i) must be < labels rank %i!", dim, labelsShapeInfo[0]); - - auto outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(predictionsShapeInfo)); - - auto dLdpShapeInfo = ShapeBuilders::copyShapeInfoAndType(predictionsShapeInfo, outType, false, block.getWorkspace()); - auto dLdwShapeInfo = ShapeBuilders::copyShapeInfoAndType(weightsShapeInfo, outType, false, block.getWorkspace()); - auto dLdlShapeInfo = ShapeBuilders::copyShapeInfoAndType(labelsShapeInfo, outType, false, block.getWorkspace()); - - return SHAPELIST(CONSTANT(dLdpShapeInfo), CONSTANT(dLdwShapeInfo), CONSTANT(dLdlShapeInfo)); + /// labels and predictions must have the same shapes + auto predictionsShapeInfo = inputShape->at(0); + auto weightsShapeInfo = inputShape->at(1); + auto labelsShapeInfo = inputShape->at(2); + + int dim = INT_ARG(1); + if (dim < 0) dim += labelsShapeInfo[0]; + + std::vector dimensions = {dim}; + + // labels and predictions must have the same shapes + REQUIRE_TRUE(shape::shapeEquals(labelsShapeInfo, predictionsShapeInfo), 0, + "COSINE_DISTANCE_LOSS_GRAD OP: labels and predictions arrays " + "must have the same shapes, but got %s and %s correspondingly !", + ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), + ShapeUtils::shapeAsString(predictionsShapeInfo).c_str()); + auto lossShapeInfo = ShapeUtils::evalReduceShapeInfo( + shape::order(predictionsShapeInfo), dimensions, predictionsShapeInfo, + true, false, block.workspace()); + // weights array can be single scalar or has the same rank as loss, and must + // be broadcastable to loss + REQUIRE_TRUE( + shape::isScalar(weightsShapeInfo) || + shape::rank(weightsShapeInfo) == shape::rank(lossShapeInfo), + 0, + "COSINE_DISTANCE_LOSS_GRAD OP: weights array should be scalar or have " + "the same rank as loss array, but got %i and %i correspondingly!", + shape::rank(weightsShapeInfo), shape::rank(lossShapeInfo)); + // check whether broadcast operation is possible for weights array + REQUIRE_TRUE( + shape::isScalar(weightsShapeInfo) || + ShapeUtils::areShapesBroadcastable(weightsShapeInfo, lossShapeInfo), + 0, + "COSINE_DISTANCE_LOSS_GRAD OP: shapes of weights and loss arrays should " + "be broadcastable, but got weights = %s and loss = %s instead!", + ShapeUtils::shapeAsString(weightsShapeInfo).c_str(), + ShapeUtils::shapeAsString(lossShapeInfo).c_str()); + // input dimension can't be larger than labels/predictions/weights rank + REQUIRE_TRUE(dim < labelsShapeInfo[0], 0, + "COSINE_DISTANCE_LOSS_GRAD OP: input reduction dimension (got " + "%i) must be < labels rank %i!", + dim, labelsShapeInfo[0]); + + auto outType = DataTypeUtils::pickFloatingType( + ArrayOptions::dataType(predictionsShapeInfo)); + + auto dLdpShapeInfo = ShapeBuilders::copyShapeInfoAndType( + predictionsShapeInfo, outType, false, block.workspace()); + auto dLdwShapeInfo = ShapeBuilders::copyShapeInfoAndType( + weightsShapeInfo, outType, false, block.workspace()); + auto dLdlShapeInfo = ShapeBuilders::copyShapeInfoAndType( + labelsShapeInfo, outType, false, block.workspace()); + + return SHAPELIST(CONSTANT(dLdpShapeInfo), CONSTANT(dLdwShapeInfo), + CONSTANT(dLdlShapeInfo)); } - - -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/loss/hingeLoss.cpp b/libnd4j/include/ops/declarable/generic/loss/hingeLoss.cpp index 71e7489ea8c2..5ecb015d3a25 100644 --- a/libnd4j/include/ops/declarable/generic/loss/hingeLoss.cpp +++ b/libnd4j/include/ops/declarable/generic/loss/hingeLoss.cpp @@ -24,281 +24,378 @@ #include namespace sd { - namespace ops { +namespace ops { ////////////////////////////////////////////////////////////////////////// - CUSTOM_OP_IMPL(hinge_loss, 3, 1, false, 0, 1) { - auto logits = INPUT_VARIABLE(0); - auto weights = INPUT_VARIABLE(1); - auto labels = INPUT_VARIABLE(2); - - auto output = OUTPUT_VARIABLE(0); - - int reductionMode = INT_ARG(0); // 0 - "none"; 1 - "weighted_sum"; 2 - "weighted_mean"; 3 - "weighted_sum_by_nonzero_weights" - - // input validation - REQUIRE_TRUE(labels->isSameShape(logits), 0, "HINGE_LOSS OP: labels and logits arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labels).c_str(), ShapeUtils::shapeAsString(logits).c_str()); - // weights array can be single scalar or has the same rank as labels, and must be broadcastable to labels - REQUIRE_TRUE(weights->isScalar() || weights->rankOf() == labels->rankOf(), 0, "HINGE_LOSS OP: weights array should be scalar or have the same rank as labels array, but got %i and %i correspondingly!", weights->rankOf(), labels->rankOf()); - // check whether broadcast operation is possible for weights array - REQUIRE_TRUE(weights->isScalar() || ShapeUtils::areShapesBroadcastable(*weights, *labels), 0, "HINGE_LOSS OP: shapes of weights and labels arrays should be broadcastable, but got weights = %s and labels = %s instead!", ShapeUtils::shapeAsString(weights).c_str(), ShapeUtils::shapeAsString(labels).c_str()); - // only 4 possible reduction modes exist - REQUIRE_TRUE(reductionMode==0 || reductionMode==1 || reductionMode==2 || reductionMode==3, 0, "HINGE_LOSS OP: reduction mode value is not acceptable, possible values are 0, 1, 2, 3, but got %i instead!", reductionMode); - - // perform weights broadcasting/tile to logits if needed - auto weightsBroad = weights; - if(!weights->isScalar() && !weights->isSameShape(logits)) - weightsBroad = new NDArray(weights->tileToShape(logits->shapeInfo())); - - // We first need to convert binary labels to -1/1 labels (as floats) - NDArray E = 1.f - (*labels * 2.f - 1.f) * (*logits); - E.applyScalar(scalar::RELU, 0.0f, E); - - // multiply E on weights - E *= *weightsBroad; - - switch (reductionMode) { - case 0: { // 0 - "none", un-reduced weighted losses with the same shape as labels. - output->assign(E); - break; - } - case 1: { // 1 - "weighted_sum", output is scalar and equal to sum of all elements of E array - E.reduceNumber(reduce::Sum, *output); - break; - } - case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array - NDArray sum; - sum.setContext(block.launchContext()); - if (weights->isScalar()) - sum = *weights * E.lengthOf(); - else - sum = weightsBroad->reduceNumber(reduce::Sum); - - if (sum.e(0) == 0.) - *output = 0.; - else - output->assign(E.reduceNumber(reduce::Sum) / sum); - break; - } - case 3: { // 3 - "weighted_sum_by_nonzero_weights", output is scalar and equal to scalar sum of all elements of E array divided by number of non-zero weights - Nd4jLong numOfNonZeroWeights = 0; - if(weights->isScalar()) { - if(weights->e(0) != 0.) - numOfNonZeroWeights = E.lengthOf(); - } - else { - numOfNonZeroWeights = weightsBroad->reduceNumber(reduce::CountNonZero).e(0); - } - - if (numOfNonZeroWeights == 0) - (*output) = 0.; - else - output->assign(E.reduceNumber(reduce::Sum) / numOfNonZeroWeights); - break; - } - } - - if(weightsBroad != weights) - delete weightsBroad; - - return Status::OK(); - } +CUSTOM_OP_IMPL(hinge_loss, 3, 1, false, 0, 1) { + auto logits = INPUT_VARIABLE(0); + auto weights = INPUT_VARIABLE(1); + auto labels = INPUT_VARIABLE(2); + + auto output = OUTPUT_VARIABLE(0); + + int reductionMode = + INT_ARG(0); // 0 - "none"; 1 - "weighted_sum"; 2 - "weighted_mean"; 3 - + // "weighted_sum_by_nonzero_weights" + + // input validation + REQUIRE_TRUE(labels->isSameShape(logits), 0, + "HINGE_LOSS OP: labels and logits arrays must have the same " + "shapes, but got %s and %s correspondingly !", + ShapeUtils::shapeAsString(labels).c_str(), + ShapeUtils::shapeAsString(logits).c_str()); + // weights array can be single scalar or has the same rank as labels, and must + // be broadcastable to labels + REQUIRE_TRUE(weights->isScalar() || weights->rankOf() == labels->rankOf(), 0, + "HINGE_LOSS OP: weights array should be scalar or have the same " + "rank as labels array, but got %i and %i correspondingly!", + weights->rankOf(), labels->rankOf()); + // check whether broadcast operation is possible for weights array + REQUIRE_TRUE(weights->isScalar() || + ShapeUtils::areShapesBroadcastable(*weights, *labels), + 0, + "HINGE_LOSS OP: shapes of weights and labels arrays should be " + "broadcastable, but got weights = %s and labels = %s instead!", + ShapeUtils::shapeAsString(weights).c_str(), + ShapeUtils::shapeAsString(labels).c_str()); + // only 4 possible reduction modes exist + REQUIRE_TRUE(reductionMode == 0 || reductionMode == 1 || reductionMode == 2 || + reductionMode == 3, + 0, + "HINGE_LOSS OP: reduction mode value is not acceptable, " + "possible values are 0, 1, 2, 3, but got %i instead!", + reductionMode); + + // perform weights broadcasting/tile to logits if needed + auto weightsBroad = weights; + if (!weights->isScalar() && !weights->isSameShape(logits)) + weightsBroad = new NDArray(weights->tileToShape(logits->shapeInfo())); + + // We first need to convert binary labels to -1/1 labels (as floats) + NDArray E = 1.f - (*labels * 2.f - 1.f) * (*logits); + E.applyScalar(scalar::RELU, 0.0f, E); + + // multiply E on weights + E *= *weightsBroad; + + switch (reductionMode) { + case 0: { // 0 - "none", un-reduced weighted losses with the same shape as + // labels. + output->assign(E); + break; + } + case 1: { // 1 - "weighted_sum", output is scalar and equal to sum of all + // elements of E array + E.reduceNumber(reduce::Sum, *output); + break; + } + case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all + // elements of E array divided by sum of all elements of + // weightsBroad array + NDArray sum; + sum.setContext(block.launchContext()); + if (weights->isScalar()) + sum = *weights * E.lengthOf(); + else + sum = weightsBroad->reduceNumber(reduce::Sum); + + if (sum.e(0) == 0.) + *output = 0.; + else + output->assign(E.reduceNumber(reduce::Sum) / sum); + break; + } + case 3: { // 3 - "weighted_sum_by_nonzero_weights", output is scalar and + // equal to scalar sum of all elements of E array divided by + // number of non-zero weights + Nd4jLong numOfNonZeroWeights = 0; + if (weights->isScalar()) { + if (weights->e(0) != 0.) numOfNonZeroWeights = E.lengthOf(); + } else { + numOfNonZeroWeights = + weightsBroad->reduceNumber(reduce::CountNonZero).e(0); + } + + if (numOfNonZeroWeights == 0) + (*output) = 0.; + else + output->assign(E.reduceNumber(reduce::Sum) / numOfNonZeroWeights); + break; + } + } -////////////////////////////////////////////////////////////////////////// - DECLARE_TYPES(hinge_loss) { + if (weightsBroad != weights) delete weightsBroad; - getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setAllowedOutputTypes({ALL_FLOATS}); - } + return Status::OK(); +} ////////////////////////////////////////////////////////////////////////// - DECLARE_SHAPE_FN(hinge_loss) { - - auto logitsShapeInfo = inputShape->at(0); - auto weightsShapeInfo = inputShape->at(1); - auto labelsShapeInfo = inputShape->at(2); - - // labels and predictions must have the same shapes - REQUIRE_TRUE(shape::shapeEquals(labelsShapeInfo, logitsShapeInfo), 0, "HINGE_LOSS OP: labels and predictions arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), ShapeUtils::shapeAsString(logitsShapeInfo).c_str()); - // weights array can be single scalar or has the same rank as labels, and must be broadcastable to labels - REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || shape::rank(weightsShapeInfo) == shape::rank(labelsShapeInfo), 0, "HINGE_LOSS OP: weights array should be scalar or have the same rank as labels array, but got %i and %i correspondingly!", shape::rank(weightsShapeInfo), shape::rank(labelsShapeInfo)); - // check whether broadcast operation is possible for weights array - REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || ShapeUtils::areShapesBroadcastable(weightsShapeInfo, labelsShapeInfo), 0, "HINGE_LOSS OP: shapes of weights and labels arrays should be broadcastable, but got weights = %s and labels = %s instead!", ShapeUtils::shapeAsString(weightsShapeInfo).c_str(), ShapeUtils::shapeAsString(labelsShapeInfo).c_str()); - - DataType outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(logitsShapeInfo)); - Nd4jLong const* outShapeInfo = nullptr; - - if(INT_ARG(0) != 0) // in this case output is scalar - outShapeInfo = ConstantShapeHelper::getInstance().scalarShapeInfo(outType); - else // in this case output has the same shape as labels and predictions - outShapeInfo = ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(outType, shape::order(labelsShapeInfo), shape::shapeOf(labelsShapeInfo), shape::rank(labelsShapeInfo))); - - return SHAPELIST(outShapeInfo); +DECLARE_TYPES(hinge_loss) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} - } +////////////////////////////////////////////////////////////////////////// +DECLARE_SHAPE_FN(hinge_loss) { + auto logitsShapeInfo = inputShape->at(0); + auto weightsShapeInfo = inputShape->at(1); + auto labelsShapeInfo = inputShape->at(2); + + // labels and predictions must have the same shapes + REQUIRE_TRUE(shape::shapeEquals(labelsShapeInfo, logitsShapeInfo), 0, + "HINGE_LOSS OP: labels and predictions arrays must have the " + "same shapes, but got %s and %s correspondingly !", + ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), + ShapeUtils::shapeAsString(logitsShapeInfo).c_str()); + // weights array can be single scalar or has the same rank as labels, and must + // be broadcastable to labels + REQUIRE_TRUE( + shape::isScalar(weightsShapeInfo) || + shape::rank(weightsShapeInfo) == shape::rank(labelsShapeInfo), + 0, + "HINGE_LOSS OP: weights array should be scalar or have the same rank as " + "labels array, but got %i and %i correspondingly!", + shape::rank(weightsShapeInfo), shape::rank(labelsShapeInfo)); + // check whether broadcast operation is possible for weights array + REQUIRE_TRUE( + shape::isScalar(weightsShapeInfo) || + ShapeUtils::areShapesBroadcastable(weightsShapeInfo, labelsShapeInfo), + 0, + "HINGE_LOSS OP: shapes of weights and labels arrays should be " + "broadcastable, but got weights = %s and labels = %s instead!", + ShapeUtils::shapeAsString(weightsShapeInfo).c_str(), + ShapeUtils::shapeAsString(labelsShapeInfo).c_str()); + + DataType outType = + DataTypeUtils::pickFloatingType(ArrayOptions::dataType(logitsShapeInfo)); + Nd4jLong const *outShapeInfo = nullptr; + + if (INT_ARG(0) != 0) // in this case output is scalar + outShapeInfo = ConstantShapeHelper::getInstance().scalarShapeInfo(outType); + else // in this case output has the same shape as labels and predictions + outShapeInfo = + ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor( + outType, shape::order(labelsShapeInfo), + shape::shapeOf(labelsShapeInfo), shape::rank(labelsShapeInfo))); + + return SHAPELIST(outShapeInfo); +} +////////////////////////////////////////////////////////////////////////// +CUSTOM_OP_IMPL(hinge_loss_grad, 3, 3, false, 0, 1) { + auto logits = INPUT_VARIABLE(0); + auto weights = INPUT_VARIABLE(1); + auto labels = INPUT_VARIABLE(2); + + auto dLdp = OUTPUT_VARIABLE(0); // dL/dpredictions + auto dLdw = OUTPUT_VARIABLE(1); // dL/dweights + auto dLdl = OUTPUT_VARIABLE(2); // dL/dlabels + + int reductionMode = + INT_ARG(0); // 0 - "none"; 1 - "weighted_sum"; 2 - "weighted_mean"; 3 - + // "weighted_sum_by_nonzero_weights" + // take into account Alex's proposition to treat "none" the same as + // "weighted_sum" mode when calculating gradients + if (reductionMode == 0) reductionMode = 1; + + // inputs validation + REQUIRE_TRUE(labels->isSameShape(logits), 0, + "HINGE_LOSS_GRAD OP: labels and logits arrays must have the " + "same shapes, but got %s and %s correspondingly !", + ShapeUtils::shapeAsString(labels).c_str(), + ShapeUtils::shapeAsString(logits).c_str()); + // weights array can be single scalar or has the same rank as labels, and must + // be broadcastable to labels + REQUIRE_TRUE(weights->isScalar() || weights->rankOf() == labels->rankOf(), 0, + "HINGE_LOSS_GRAD OP: weights array should be scalar or have the " + "same rank as labels array, but got %i and %i correspondingly!", + weights->rankOf(), labels->rankOf()); + // check whether broadcast operation is possible for weights array + REQUIRE_TRUE( + weights->isScalar() || + ShapeUtils::areShapesBroadcastable(*weights, *labels), + 0, + "HINGE_LOSS_GRAD OP: shapes of weights and labels arrays should be " + "broadcastable, but got weights = %s and labels = %s instead!", + ShapeUtils::shapeAsString(weights).c_str(), + ShapeUtils::shapeAsString(labels).c_str()); + // only 4 possible reduction modes exist + REQUIRE_TRUE(reductionMode == 0 || reductionMode == 1 || reductionMode == 2 || + reductionMode == 3, + 0, + "HINGE_LOSS_GRAD OP: reduction mode value is not acceptable, " + "possible values are 0, 1, 2, 3, but got %i instead!", + reductionMode); + + // perform weights broadcasting/tile to labels if needed + auto weightsBroad = weights; + if (!weights->isScalar() && !weights->isSameShape(logits)) + weightsBroad = new NDArray(weights->tileToShape(logits->shapeInfo())); + + // We first need to convert binary labels to -1/1 labels (as floats) + NDArray z = (*labels * 2.f - 1.f); + + NDArray E = 1.f - z * (*logits); + E.applyScalar(scalar::RELU, 0.0f, E); + // turn E into gradient mask + + NDArray gradientMask(E.shapeInfo(), block.workspace()); + E.applyTransform(sd::transform::Sign, gradientMask); + + dLdp->assign(-z * gradientMask); + dLdl->assign(-2.f * (*logits) * gradientMask); + + switch (reductionMode) { + case 1: { // 1 - "none" and "weighted_sum", output is scalar and equal to + // sum of all elements of E array + + *dLdp *= *weightsBroad; + *dLdl *= *weightsBroad; + + if (weights->isScalar()) + dLdw->assign(E.reduceNumber(reduce::Sum)); + else if (weights != weightsBroad) { + std::vector axesToReduceAlong = + ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), + weightsBroad->shapeInfo()); + E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, + false, false); + } else + dLdw->assign(E); + break; + } + case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all + // elements of E array divided by sum of all elements of + // weightsBroad array + + NDArray sum; + sum.setContext(block.launchContext()); + if (weights->isScalar()) + sum = (*weights) * E.lengthOf(); + else + sum = weightsBroad->reduceNumber(reduce::Sum); + + if (sum.e(0) == 0.) { + *dLdp = 0.; + *dLdl = 0.; + *dLdw = 0.; + } else { + *dLdp *= *weightsBroad / sum; + *dLdl *= *weightsBroad / sum; + + if (weights->isScalar()) + *dLdw = 0.; + else if (weights != weightsBroad) { + std::vector axesToReduceAlong = + ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), + weightsBroad->shapeInfo()); + ((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / + (sum * sum)) + .reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, + false, false); + } else + dLdw->assign( + (E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / + (sum * sum)); + } + break; + } + case 3: { // 3 - "weighted_sum_by_nonzero_weights", output is scalar and + // equal to scalar sum of all elements of E array divided by + // number of non-zero weights + + Nd4jLong numOfNonZeroWeights = 0; + if (weights->isScalar()) { + if (weights->e(0) != 0.) numOfNonZeroWeights = E.lengthOf(); + } else + numOfNonZeroWeights = + weightsBroad->reduceNumber(reduce::CountNonZero).e(0); + + if (numOfNonZeroWeights == 0) { + *dLdp = 0.; + *dLdl = 0.; + *dLdw = 0.; + } else { + auto numOfNonZeroWeightsScalar = NDArrayFactory::create( + dLdw->dataType(), numOfNonZeroWeights, block.launchContext()); + + if (weights->isScalar()) + dLdw->assign(E.reduceNumber(reduce::Sum) / + double(numOfNonZeroWeights)); + else if (weights != weightsBroad) { + std::vector axesToReduceAlong = + ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), + weightsBroad->shapeInfo()); + E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, + false, false); + *dLdw /= numOfNonZeroWeightsScalar; + } else + dLdw->assign(E / numOfNonZeroWeightsScalar); + + NDArray temp = *weightsBroad / numOfNonZeroWeightsScalar; + *dLdp *= temp; + *dLdl *= temp; + } + break; + } + } + if (weightsBroad != weights) delete weightsBroad; + return Status::OK(); +} -////////////////////////////////////////////////////////////////////////// - CUSTOM_OP_IMPL(hinge_loss_grad, 3, 3, false, 0, 1) { - - auto logits = INPUT_VARIABLE(0); - auto weights = INPUT_VARIABLE(1); - auto labels = INPUT_VARIABLE(2); - - auto dLdp = OUTPUT_VARIABLE(0); // dL/dpredictions - auto dLdw = OUTPUT_VARIABLE(1); // dL/dweights - auto dLdl = OUTPUT_VARIABLE(2); // dL/dlabels - - int reductionMode = INT_ARG(0); // 0 - "none"; 1 - "weighted_sum"; 2 - "weighted_mean"; 3 - "weighted_sum_by_nonzero_weights" - // take into account Alex's proposition to treat "none" the same as "weighted_sum" mode when calculating gradients - if(reductionMode == 0) - reductionMode = 1; - - // inputs validation - REQUIRE_TRUE(labels->isSameShape(logits), 0, "HINGE_LOSS_GRAD OP: labels and logits arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labels).c_str(), ShapeUtils::shapeAsString(logits).c_str()); - // weights array can be single scalar or has the same rank as labels, and must be broadcastable to labels - REQUIRE_TRUE(weights->isScalar() || weights->rankOf() == labels->rankOf(), 0, "HINGE_LOSS_GRAD OP: weights array should be scalar or have the same rank as labels array, but got %i and %i correspondingly!", weights->rankOf(), labels->rankOf()); - // check whether broadcast operation is possible for weights array - REQUIRE_TRUE(weights->isScalar() || ShapeUtils::areShapesBroadcastable(*weights, *labels), 0, "HINGE_LOSS_GRAD OP: shapes of weights and labels arrays should be broadcastable, but got weights = %s and labels = %s instead!", ShapeUtils::shapeAsString(weights).c_str(), ShapeUtils::shapeAsString(labels).c_str()); - // only 4 possible reduction modes exist - REQUIRE_TRUE(reductionMode==0 || reductionMode==1 || reductionMode==2 || reductionMode==3, 0, "HINGE_LOSS_GRAD OP: reduction mode value is not acceptable, possible values are 0, 1, 2, 3, but got %i instead!", reductionMode); - - // perform weights broadcasting/tile to labels if needed - auto weightsBroad = weights; - if(!weights->isScalar() && !weights->isSameShape(logits)) - weightsBroad = new NDArray(weights->tileToShape(logits->shapeInfo())); - - // We first need to convert binary labels to -1/1 labels (as floats) - NDArray z = (*labels * 2.f - 1.f); - - NDArray E = 1.f - z * (*logits); - E.applyScalar(scalar::RELU, 0.0f, E); - // turn E into gradient mask - - NDArray gradientMask(E.shapeInfo(), block.getWorkspace()); - E.applyTransform(sd::transform::Sign, gradientMask); - - dLdp->assign(-z * gradientMask); - dLdl->assign(-2.f * (*logits) * gradientMask); - - switch (reductionMode) { - - case 1: { // 1 - "none" and "weighted_sum", output is scalar and equal to sum of all elements of E array - - *dLdp *= *weightsBroad; - *dLdl *= *weightsBroad; - - if(weights->isScalar()) - dLdw->assign(E.reduceNumber(reduce::Sum)); - else if(weights != weightsBroad) { - std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), weightsBroad->shapeInfo()); - E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); - } - else - dLdw->assign(E); - break; - } - case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array - - NDArray sum; - sum.setContext(block.launchContext()); - if (weights->isScalar()) - sum = (*weights) * E.lengthOf(); - else - sum = weightsBroad->reduceNumber(reduce::Sum); - - if (sum.e(0) == 0.) { - *dLdp = 0.; - *dLdl = 0.; - *dLdw = 0.; - } - else { - - *dLdp *= *weightsBroad / sum; - *dLdl *= *weightsBroad / sum; - - if(weights->isScalar()) - *dLdw = 0.; - else if(weights != weightsBroad) { - std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), weightsBroad->shapeInfo()); - ((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)).reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); - } - else - dLdw->assign((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)); - } - break; - } - case 3: { // 3 - "weighted_sum_by_nonzero_weights", output is scalar and equal to scalar sum of all elements of E array divided by number of non-zero weights - - Nd4jLong numOfNonZeroWeights = 0; - if(weights->isScalar()) { - if(weights->e(0) != 0.) - numOfNonZeroWeights = E.lengthOf(); - } - else - numOfNonZeroWeights = weightsBroad->reduceNumber(reduce::CountNonZero).e(0); - - if (numOfNonZeroWeights == 0) { - *dLdp = 0.; - *dLdl = 0.; - *dLdw = 0.; - } - else { - auto numOfNonZeroWeightsScalar = NDArrayFactory::create(dLdw->dataType(), numOfNonZeroWeights, block.launchContext()); - - if(weights->isScalar()) - dLdw->assign(E.reduceNumber(reduce::Sum) / double(numOfNonZeroWeights)); - else if(weights != weightsBroad) { - std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), weightsBroad->shapeInfo()); - E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); - *dLdw /= numOfNonZeroWeightsScalar; - } - else - dLdw->assign(E / numOfNonZeroWeightsScalar); - - NDArray temp = *weightsBroad / numOfNonZeroWeightsScalar; - *dLdp *= temp; - *dLdl *= temp; - } - break; - } - } - - if(weightsBroad != weights) - delete weightsBroad; - - return Status::OK(); - } - - DECLARE_TYPES(hinge_loss_grad) { - - getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setAllowedOutputTypes({ALL_FLOATS}); - } - - DECLARE_SHAPE_FN(hinge_loss_grad) { - - auto predictionsShapeInfo = inputShape->at(0); - auto weightsShapeInfo = inputShape->at(1); - auto labelsShapeInfo = inputShape->at(2); - - // labels and predictions must have the same shapes - REQUIRE_TRUE(shape::shapeEquals(labelsShapeInfo, predictionsShapeInfo), 0, "HINGE_LOSS_GRAD OP: labels and predictions arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), ShapeUtils::shapeAsString(predictionsShapeInfo).c_str()); - // weights array can be single scalar or has the same rank as labels, and must be broadcastable to labels - REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || shape::rank(weightsShapeInfo) == shape::rank(labelsShapeInfo), 0, "HINGE_LOSS_GRAD OP: weights array should be scalar or have the same rank as labels array, but got %i and %i correspondingly!", shape::rank(weightsShapeInfo), shape::rank(labelsShapeInfo)); - // check whether broadcast operation is possible for weights array - REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || ShapeUtils::areShapesBroadcastable(weightsShapeInfo, labelsShapeInfo), 0, "HINGE_LOSS_GRAD OP: shapes of weights and labels arrays should be broadcastable, but got weights = %s and labels = %s instead!", ShapeUtils::shapeAsString(weightsShapeInfo).c_str(), ShapeUtils::shapeAsString(labelsShapeInfo).c_str()); - - DataType outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(predictionsShapeInfo)); - - Nd4jLong *dLdpShapeInfo = ShapeBuilders::copyShapeInfoAndType(predictionsShapeInfo, outType, false, block.getWorkspace()); - Nd4jLong *dLdwShapeInfo = ShapeBuilders::copyShapeInfoAndType(weightsShapeInfo, outType, false, block.getWorkspace()); - Nd4jLong *dLdlShapeInfo = ShapeBuilders::copyShapeInfoAndType(labelsShapeInfo, outType, false, block.getWorkspace()); - - return SHAPELIST(dLdpShapeInfo, dLdwShapeInfo, dLdlShapeInfo); - } +DECLARE_TYPES(hinge_loss_grad) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} - } +DECLARE_SHAPE_FN(hinge_loss_grad) { + auto predictionsShapeInfo = inputShape->at(0); + auto weightsShapeInfo = inputShape->at(1); + auto labelsShapeInfo = inputShape->at(2); + + // labels and predictions must have the same shapes + REQUIRE_TRUE(shape::shapeEquals(labelsShapeInfo, predictionsShapeInfo), 0, + "HINGE_LOSS_GRAD OP: labels and predictions arrays must have " + "the same shapes, but got %s and %s correspondingly !", + ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), + ShapeUtils::shapeAsString(predictionsShapeInfo).c_str()); + // weights array can be single scalar or has the same rank as labels, and must + // be broadcastable to labels + REQUIRE_TRUE( + shape::isScalar(weightsShapeInfo) || + shape::rank(weightsShapeInfo) == shape::rank(labelsShapeInfo), + 0, + "HINGE_LOSS_GRAD OP: weights array should be scalar or have the same " + "rank as labels array, but got %i and %i correspondingly!", + shape::rank(weightsShapeInfo), shape::rank(labelsShapeInfo)); + // check whether broadcast operation is possible for weights array + REQUIRE_TRUE( + shape::isScalar(weightsShapeInfo) || + ShapeUtils::areShapesBroadcastable(weightsShapeInfo, labelsShapeInfo), + 0, + "HINGE_LOSS_GRAD OP: shapes of weights and labels arrays should be " + "broadcastable, but got weights = %s and labels = %s instead!", + ShapeUtils::shapeAsString(weightsShapeInfo).c_str(), + ShapeUtils::shapeAsString(labelsShapeInfo).c_str()); + + DataType outType = DataTypeUtils::pickFloatingType( + ArrayOptions::dataType(predictionsShapeInfo)); + + Nd4jLong *dLdpShapeInfo = ShapeBuilders::copyShapeInfoAndType( + predictionsShapeInfo, outType, false, block.workspace()); + Nd4jLong *dLdwShapeInfo = ShapeBuilders::copyShapeInfoAndType( + weightsShapeInfo, outType, false, block.workspace()); + Nd4jLong *dLdlShapeInfo = ShapeBuilders::copyShapeInfoAndType( + labelsShapeInfo, outType, false, block.workspace()); + + return SHAPELIST(dLdpShapeInfo, dLdwShapeInfo, dLdlShapeInfo); } +} // namespace ops +} // namespace sd + #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/loss/huberLoss.cpp b/libnd4j/include/ops/declarable/generic/loss/huberLoss.cpp index 2d0b44b3c5e4..c3f4d7e46823 100644 --- a/libnd4j/include/ops/declarable/generic/loss/huberLoss.cpp +++ b/libnd4j/include/ops/declarable/generic/loss/huberLoss.cpp @@ -24,297 +24,395 @@ #include namespace sd { -namespace ops { - +namespace ops { ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(huber_loss, 3, 1, false, 1, 1) { - auto predictions = INPUT_VARIABLE(0); - auto weights = INPUT_VARIABLE(1); - auto labels = INPUT_VARIABLE(2); - auto output = OUTPUT_VARIABLE(0); - - int reductionMode = INT_ARG(0); // 0 - "none"; 1 - "weighted_sum"; 2 - "weighted_mean"; 3 - "weighted_sum_by_nonzero_weights" - // FIXME: double? - double delta = T_ARG(0); - - // input validation - REQUIRE_TRUE(labels->isSameShape(predictions), 0, "HUBER_LOSS OP: labels and predictions arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labels).c_str(), ShapeUtils::shapeAsString(predictions).c_str()); - // weights array can be single scalar or has the same rank as labels, and must be broadcastable to labels - REQUIRE_TRUE(weights->isScalar() || weights->rankOf() == labels->rankOf(), 0, "HUBER_LOSS OP: weights array should be scalar or have the same rank as labels array, but got %i and %i correspondingly!", weights->rankOf(), labels->rankOf()); - // check whether broadcast operation is possible for weights array - REQUIRE_TRUE(weights->isScalar() || ShapeUtils::areShapesBroadcastable(*weights, *labels), 0, "HUBER_LOSS OP: shapes of weights and labels arrays should be broadcastable, but got weights = %s and labels = %s instead!", ShapeUtils::shapeAsString(weights).c_str(), ShapeUtils::shapeAsString(labels).c_str()); - // only 4 possible reduction modes exist - REQUIRE_TRUE(reductionMode==0 || reductionMode==1 || reductionMode==2 || reductionMode==3, 0, "HUBER_LOSS OP: reduction mode value is not acceptable, possible values are 0, 1, 2, 3, but got %i instead!", reductionMode); - - // perform weights broadcasting/tile to predictions if needed - auto weightsBroad = weights; - if(!weights->isScalar() && !weights->isSameShape(predictions)) - weightsBroad = new NDArray(weights->tileToShape(predictions->shapeInfo())); - - auto error = *predictions - *labels; - error.applyTransform(transform::Abs, error); - NDArray quadratic(error.shapeInfo(), block.getWorkspace()); - error.applyScalar(scalar::MinPairwise, delta, quadratic); - - NDArray E = quadratic * quadratic * 0.5f + (error - quadratic)*delta; - - // multiply E on weights - E *= *weightsBroad; - - switch (reductionMode) { - case 0: { // 0 - "none", un-reduced weighted losses with the same shape as labels. - output->assign(E); - break; - } - case 1: { // 1 - "weighted_sum", output is scalar and equal to sum of all elements of E array - E.reduceNumber(reduce::Sum, *output); - break; - } - case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array - NDArray sum; - sum.setContext(block.launchContext()); - if (weights->isScalar()) - sum = *weights * E.lengthOf(); - else - sum = weightsBroad->reduceNumber(reduce::Sum); - - if (sum.e(0) == 0.) - *output = 0.; - else - output->assign(E.reduceNumber(reduce::Sum) / sum); - break; - } - case 3: { // 3 - "weighted_sum_by_nonzero_weights", output is scalar and equal to scalar sum of all elements of E array divided by number of non-zero weights - Nd4jLong numOfNonZeroWeights = 0; - if(weights->isScalar()) { - if(weights->e(0) != 0.) - numOfNonZeroWeights = E.lengthOf(); - } - else { - numOfNonZeroWeights = weightsBroad->reduceNumber(reduce::CountNonZero).e(0); - } - - if (numOfNonZeroWeights == 0) - (*output) = 0.; - else - output->assign(E.reduceNumber(reduce::Sum) / double(numOfNonZeroWeights)); - break; - } - } - - if(weightsBroad != weights) - delete weightsBroad; - - return Status::OK(); + auto predictions = INPUT_VARIABLE(0); + auto weights = INPUT_VARIABLE(1); + auto labels = INPUT_VARIABLE(2); + auto output = OUTPUT_VARIABLE(0); + + int reductionMode = + INT_ARG(0); // 0 - "none"; 1 - "weighted_sum"; 2 - "weighted_mean"; 3 - + // "weighted_sum_by_nonzero_weights" + // FIXME: double? + double delta = T_ARG(0); + + // input validation + REQUIRE_TRUE(labels->isSameShape(predictions), 0, + "HUBER_LOSS OP: labels and predictions arrays must have the " + "same shapes, but got %s and %s correspondingly !", + ShapeUtils::shapeAsString(labels).c_str(), + ShapeUtils::shapeAsString(predictions).c_str()); + // weights array can be single scalar or has the same rank as labels, and must + // be broadcastable to labels + REQUIRE_TRUE(weights->isScalar() || weights->rankOf() == labels->rankOf(), 0, + "HUBER_LOSS OP: weights array should be scalar or have the same " + "rank as labels array, but got %i and %i correspondingly!", + weights->rankOf(), labels->rankOf()); + // check whether broadcast operation is possible for weights array + REQUIRE_TRUE(weights->isScalar() || + ShapeUtils::areShapesBroadcastable(*weights, *labels), + 0, + "HUBER_LOSS OP: shapes of weights and labels arrays should be " + "broadcastable, but got weights = %s and labels = %s instead!", + ShapeUtils::shapeAsString(weights).c_str(), + ShapeUtils::shapeAsString(labels).c_str()); + // only 4 possible reduction modes exist + REQUIRE_TRUE(reductionMode == 0 || reductionMode == 1 || reductionMode == 2 || + reductionMode == 3, + 0, + "HUBER_LOSS OP: reduction mode value is not acceptable, " + "possible values are 0, 1, 2, 3, but got %i instead!", + reductionMode); + + // perform weights broadcasting/tile to predictions if needed + auto weightsBroad = weights; + if (!weights->isScalar() && !weights->isSameShape(predictions)) + weightsBroad = new NDArray(weights->tileToShape(predictions->shapeInfo())); + + auto error = *predictions - *labels; + error.applyTransform(transform::Abs, error); + NDArray quadratic(error.shapeInfo(), block.workspace()); + error.applyScalar(scalar::MinPairwise, delta, quadratic); + + NDArray E = quadratic * quadratic * 0.5f + (error - quadratic) * delta; + + // multiply E on weights + E *= *weightsBroad; + + switch (reductionMode) { + case 0: { // 0 - "none", un-reduced weighted losses with the same shape as + // labels. + output->assign(E); + break; + } + case 1: { // 1 - "weighted_sum", output is scalar and equal to sum of all + // elements of E array + E.reduceNumber(reduce::Sum, *output); + break; + } + case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all + // elements of E array divided by sum of all elements of + // weightsBroad array + NDArray sum; + sum.setContext(block.launchContext()); + if (weights->isScalar()) + sum = *weights * E.lengthOf(); + else + sum = weightsBroad->reduceNumber(reduce::Sum); + + if (sum.e(0) == 0.) + *output = 0.; + else + output->assign(E.reduceNumber(reduce::Sum) / sum); + break; + } + case 3: { // 3 - "weighted_sum_by_nonzero_weights", output is scalar and + // equal to scalar sum of all elements of E array divided by + // number of non-zero weights + Nd4jLong numOfNonZeroWeights = 0; + if (weights->isScalar()) { + if (weights->e(0) != 0.) numOfNonZeroWeights = E.lengthOf(); + } else { + numOfNonZeroWeights = + weightsBroad->reduceNumber(reduce::CountNonZero).e(0); + } + + if (numOfNonZeroWeights == 0) + (*output) = 0.; + else + output->assign(E.reduceNumber(reduce::Sum) / + double(numOfNonZeroWeights)); + break; + } + } + + if (weightsBroad != weights) delete weightsBroad; + + return Status::OK(); } ////////////////////////////////////////////////////////////////////////// DECLARE_TYPES(huber_loss) { - - getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } ////////////////////////////////////////////////////////////////////////// DECLARE_SHAPE_FN(huber_loss) { - - auto predictionsShapeInfo = inputShape->at(0); - auto weightsShapeInfo = inputShape->at(1); - auto labelsShapeInfo = inputShape->at(2); - - // labels and predictions must have the same shapes - REQUIRE_TRUE(shape::shapeEquals(labelsShapeInfo, predictionsShapeInfo), 0, "HUBER_LOSS OP: labels and predictions arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), ShapeUtils::shapeAsString(predictionsShapeInfo).c_str()); - // weights array can be single scalar or has the same rank as labels, and must be broadcastable to labels - REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || shape::rank(weightsShapeInfo) == shape::rank(labelsShapeInfo), 0, "HUBER_LOSS OP: weights array should be scalar or have the same rank as labels array, but got %i and %i correspondingly!", shape::rank(weightsShapeInfo), shape::rank(labelsShapeInfo)); - // check whether broadcast operation is possible for weights array - REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || ShapeUtils::areShapesBroadcastable(weightsShapeInfo, labelsShapeInfo), 0, "HUBER_LOSS OP: shapes of weights and labels arrays should be broadcastable, but got weights = %s and labels = %s instead!", ShapeUtils::shapeAsString(weightsShapeInfo).c_str(), ShapeUtils::shapeAsString(labelsShapeInfo).c_str()); - - DataType outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(predictionsShapeInfo)); - Nd4jLong const* outShapeInfo = nullptr; - - if(INT_ARG(0) != 0) // in this case output is scalar - outShapeInfo = ConstantShapeHelper::getInstance().scalarShapeInfo(outType); - else // in this case output has the same shape as labels and predictions - outShapeInfo = ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(outType, shape::order(labelsShapeInfo), shape::shapeOf(labelsShapeInfo), shape::rank(labelsShapeInfo))); - - return SHAPELIST(outShapeInfo); + auto predictionsShapeInfo = inputShape->at(0); + auto weightsShapeInfo = inputShape->at(1); + auto labelsShapeInfo = inputShape->at(2); + + // labels and predictions must have the same shapes + REQUIRE_TRUE(shape::shapeEquals(labelsShapeInfo, predictionsShapeInfo), 0, + "HUBER_LOSS OP: labels and predictions arrays must have the " + "same shapes, but got %s and %s correspondingly !", + ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), + ShapeUtils::shapeAsString(predictionsShapeInfo).c_str()); + // weights array can be single scalar or has the same rank as labels, and must + // be broadcastable to labels + REQUIRE_TRUE( + shape::isScalar(weightsShapeInfo) || + shape::rank(weightsShapeInfo) == shape::rank(labelsShapeInfo), + 0, + "HUBER_LOSS OP: weights array should be scalar or have the same rank as " + "labels array, but got %i and %i correspondingly!", + shape::rank(weightsShapeInfo), shape::rank(labelsShapeInfo)); + // check whether broadcast operation is possible for weights array + REQUIRE_TRUE( + shape::isScalar(weightsShapeInfo) || + ShapeUtils::areShapesBroadcastable(weightsShapeInfo, labelsShapeInfo), + 0, + "HUBER_LOSS OP: shapes of weights and labels arrays should be " + "broadcastable, but got weights = %s and labels = %s instead!", + ShapeUtils::shapeAsString(weightsShapeInfo).c_str(), + ShapeUtils::shapeAsString(labelsShapeInfo).c_str()); + + DataType outType = DataTypeUtils::pickFloatingType( + ArrayOptions::dataType(predictionsShapeInfo)); + Nd4jLong const* outShapeInfo = nullptr; + + if (INT_ARG(0) != 0) // in this case output is scalar + outShapeInfo = ConstantShapeHelper::getInstance().scalarShapeInfo(outType); + else // in this case output has the same shape as labels and predictions + outShapeInfo = + ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor( + outType, shape::order(labelsShapeInfo), + shape::shapeOf(labelsShapeInfo), shape::rank(labelsShapeInfo))); + + return SHAPELIST(outShapeInfo); } ////////////////////////////////////////////////////////////////////////// - CUSTOM_OP_IMPL(huber_loss_grad, 3, 3, false, 1, 1) { - - auto predictions = INPUT_VARIABLE(0); - auto weights = INPUT_VARIABLE(1); - auto labels = INPUT_VARIABLE(2); - - auto dLdp = OUTPUT_VARIABLE(0); // dL/dpredictions - auto dLdw = OUTPUT_VARIABLE(1); // dL/dweights - auto dLdl = OUTPUT_VARIABLE(2); // dL/dlabels - - auto delta = T_ARG(0); - - int reductionMode = INT_ARG(0); // 0 - "none"; 1 - "weighted_sum"; 2 - "weighted_mean"; 3 - "weighted_sum_by_nonzero_weights" - // take into account Alex's proposition to treat "none" the same as "weighted_sum" mode when calculating gradients - if(reductionMode == 0) - reductionMode = 1; - - // inputs validation - REQUIRE_TRUE(labels->isSameShape(predictions), 0, "HUBER_LOSS_GRAD OP: labels and predictions arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labels).c_str(), ShapeUtils::shapeAsString(predictions).c_str()); - // weights array can be single scalar or has the same rank as labels, and must be broadcastable to labels - REQUIRE_TRUE(weights->isScalar() || weights->rankOf() == labels->rankOf(), 0, "HUBER_LOSS_GRAD OP: weights array should be scalar or have the same rank as labels array, but got %i and %i correspondingly!", weights->rankOf(), labels->rankOf()); - // check whether broadcast operation is possible for weights array - REQUIRE_TRUE(weights->isScalar() || ShapeUtils::areShapesBroadcastable(*weights, *labels), 0, "HUBER_LOSS_GRAD OP: shapes of weights and labels arrays should be broadcastable, but got weights = %s and labels = %s instead!", ShapeUtils::shapeAsString(weights).c_str(), ShapeUtils::shapeAsString(labels).c_str()); - // only 4 possible reduction modes exist - REQUIRE_TRUE(reductionMode==0 || reductionMode==1 || reductionMode==2 || reductionMode==3, 0, "HUBER_LOSS_GRAD OP: reduction mode value is not acceptable, possible values are 0, 1, 2, 3, but got %i instead!", reductionMode); - - // perform weights broadcasting/tile to labels if needed - auto weightsBroad = weights; - if(!weights->isScalar() && !weights->isSameShape(predictions)) - weightsBroad = new NDArray(weights->tileToShape(predictions->shapeInfo())); - - NDArray diff = *predictions - *labels; - NDArray absDiff(diff); - absDiff.applyTransform(transform::Abs, absDiff); - NDArray quadratic(absDiff); - absDiff.applyScalar(scalar::MinPairwise, delta, quadratic); - - NDArray E = quadratic * quadratic * 0.5f + (absDiff - quadratic)*delta; - - NDArray lteMask(diff.shapeInfo(), BOOL, true, block.launchContext()); - absDiff.applyScalar(scalar::LessThanOrEqual, delta, lteMask); - - NDArray gtMask(diff.shapeInfo(), BOOL, true, block.launchContext()); - absDiff.applyScalar(scalar::GreaterThan, delta, gtMask); - - NDArray signDiff(diff); - diff.applyTransform(transform::Sign, signDiff); - - - auto gtMaskFloat = gtMask.cast(diff.dataType()); - auto lteMaskFloat = lteMask.cast(diff.dataType()); - - - dLdp->assign( lteMaskFloat * diff + gtMaskFloat * delta * signDiff); - dLdl->assign(-lteMaskFloat * diff - gtMaskFloat * delta * signDiff); - - switch (reductionMode) { - - case 1: { // 1 - "none" and "weighted_sum", output is scalar and equal to sum of all elements of E array - - *dLdp *= *weightsBroad; - *dLdl *= *weightsBroad; - - if(weights->isScalar()) - dLdw->assign(E.reduceNumber(reduce::Sum)); - else if(weights != weightsBroad) { - std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), weightsBroad->shapeInfo()); - E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); - } - else - dLdw->assign(E); - break; - } - case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array - - NDArray sum; - sum.setContext(block.launchContext()); - if (weights->isScalar()) - sum = (*weights) * E.lengthOf(); - else - sum = weightsBroad->reduceNumber(reduce::Sum); - - if (sum.e(0) == 0.) { - *dLdp = 0.; - *dLdl = 0.; - *dLdw = 0.; - } - else { - - *dLdp *= *weightsBroad / sum; - *dLdl *= *weightsBroad / sum; - - if(weights->isScalar()) - *dLdw = 0.; - else if(weights != weightsBroad) { - std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), weightsBroad->shapeInfo()); - ((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)).reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); - } - else - dLdw->assign((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)); - } - break; - } - case 3: { // 3 - "weighted_sum_by_nonzero_weights", output is scalar and equal to scalar sum of all elements of E array divided by number of non-zero weights - - Nd4jLong numOfNonZeroWeights = 0; - if(weights->isScalar()) { - if(weights->e(0) != 0.) - numOfNonZeroWeights = E.lengthOf(); - } - else - numOfNonZeroWeights = weightsBroad->reduceNumber(reduce::CountNonZero).e(0); - - if (numOfNonZeroWeights == 0) { - *dLdp = 0.; - *dLdl = 0.; - *dLdw = 0.; - } - else { - auto numOfNonZeroWeightsScalar = NDArrayFactory::create(dLdw->dataType(), numOfNonZeroWeights, block.launchContext()); - - if(weights->isScalar()) - dLdw->assign(E.reduceNumber(reduce::Sum) / double(numOfNonZeroWeights)); - else if(weights != weightsBroad) { - std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), weightsBroad->shapeInfo()); - E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); - *dLdw /= numOfNonZeroWeightsScalar; - } - else - dLdw->assign(E / numOfNonZeroWeightsScalar); - - NDArray temp = *weightsBroad / numOfNonZeroWeightsScalar; - *dLdp *= temp; - *dLdl *= temp; - } - break; - } - } - - if(weightsBroad != weights) - delete weightsBroad; - - return Status::OK(); - } - - DECLARE_TYPES(huber_loss_grad) { - - getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setAllowedOutputTypes({ALL_FLOATS}); - } - - DECLARE_SHAPE_FN(huber_loss_grad) { - - auto predictionsShapeInfo = inputShape->at(0); - auto weightsShapeInfo = inputShape->at(1); - auto labelsShapeInfo = inputShape->at(2); - - // labels and predictions must have the same shapes - REQUIRE_TRUE(shape::shapeEquals(labelsShapeInfo, predictionsShapeInfo), 0, "HUBER_LOSS_GRAD OP: labels and predictions arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), ShapeUtils::shapeAsString(predictionsShapeInfo).c_str()); - // weights array can be single scalar or has the same rank as labels, and must be broadcastable to labels - REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || shape::rank(weightsShapeInfo) == shape::rank(labelsShapeInfo), 0, "HUBER_LOSS_GRAD OP: weights array should be scalar or have the same rank as labels array, but got %i and %i correspondingly!", shape::rank(weightsShapeInfo), shape::rank(labelsShapeInfo)); - // check whether broadcast operation is possible for weights array - REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || ShapeUtils::areShapesBroadcastable(weightsShapeInfo, labelsShapeInfo), 0, "HUBER_LOSS_GRAD OP: shapes of weights and labels arrays should be broadcastable, but got weights = %s and labels = %s instead!", ShapeUtils::shapeAsString(weightsShapeInfo).c_str(), ShapeUtils::shapeAsString(labelsShapeInfo).c_str()); - - DataType outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(predictionsShapeInfo)); - - auto dLdpShapeInfo = ShapeBuilders::copyShapeInfoAndType(predictionsShapeInfo, outType, false, block.getWorkspace()); - auto dLdwShapeInfo = ShapeBuilders::copyShapeInfoAndType(weightsShapeInfo, outType, false, block.getWorkspace()); - auto dLdlShapeInfo = ShapeBuilders::copyShapeInfoAndType(labelsShapeInfo, outType, false, block.getWorkspace()); - - return SHAPELIST(dLdpShapeInfo, dLdwShapeInfo, dLdlShapeInfo); - } - +CUSTOM_OP_IMPL(huber_loss_grad, 3, 3, false, 1, 1) { + auto predictions = INPUT_VARIABLE(0); + auto weights = INPUT_VARIABLE(1); + auto labels = INPUT_VARIABLE(2); + + auto dLdp = OUTPUT_VARIABLE(0); // dL/dpredictions + auto dLdw = OUTPUT_VARIABLE(1); // dL/dweights + auto dLdl = OUTPUT_VARIABLE(2); // dL/dlabels + + auto delta = T_ARG(0); + + int reductionMode = + INT_ARG(0); // 0 - "none"; 1 - "weighted_sum"; 2 - "weighted_mean"; 3 - + // "weighted_sum_by_nonzero_weights" + // take into account Alex's proposition to treat "none" the same as + // "weighted_sum" mode when calculating gradients + if (reductionMode == 0) reductionMode = 1; + + // inputs validation + REQUIRE_TRUE(labels->isSameShape(predictions), 0, + "HUBER_LOSS_GRAD OP: labels and predictions arrays must have " + "the same shapes, but got %s and %s correspondingly !", + ShapeUtils::shapeAsString(labels).c_str(), + ShapeUtils::shapeAsString(predictions).c_str()); + // weights array can be single scalar or has the same rank as labels, and must + // be broadcastable to labels + REQUIRE_TRUE(weights->isScalar() || weights->rankOf() == labels->rankOf(), 0, + "HUBER_LOSS_GRAD OP: weights array should be scalar or have the " + "same rank as labels array, but got %i and %i correspondingly!", + weights->rankOf(), labels->rankOf()); + // check whether broadcast operation is possible for weights array + REQUIRE_TRUE( + weights->isScalar() || + ShapeUtils::areShapesBroadcastable(*weights, *labels), + 0, + "HUBER_LOSS_GRAD OP: shapes of weights and labels arrays should be " + "broadcastable, but got weights = %s and labels = %s instead!", + ShapeUtils::shapeAsString(weights).c_str(), + ShapeUtils::shapeAsString(labels).c_str()); + // only 4 possible reduction modes exist + REQUIRE_TRUE(reductionMode == 0 || reductionMode == 1 || reductionMode == 2 || + reductionMode == 3, + 0, + "HUBER_LOSS_GRAD OP: reduction mode value is not acceptable, " + "possible values are 0, 1, 2, 3, but got %i instead!", + reductionMode); + + // perform weights broadcasting/tile to labels if needed + auto weightsBroad = weights; + if (!weights->isScalar() && !weights->isSameShape(predictions)) + weightsBroad = new NDArray(weights->tileToShape(predictions->shapeInfo())); + + NDArray diff = *predictions - *labels; + NDArray absDiff(diff); + absDiff.applyTransform(transform::Abs, absDiff); + NDArray quadratic(absDiff); + absDiff.applyScalar(scalar::MinPairwise, delta, quadratic); + + NDArray E = quadratic * quadratic * 0.5f + (absDiff - quadratic) * delta; + + NDArray lteMask(diff.shapeInfo(), BOOL, true, block.launchContext()); + absDiff.applyScalar(scalar::LessThanOrEqual, delta, lteMask); + + NDArray gtMask(diff.shapeInfo(), BOOL, true, block.launchContext()); + absDiff.applyScalar(scalar::GreaterThan, delta, gtMask); + + NDArray signDiff(diff); + diff.applyTransform(transform::Sign, signDiff); + + auto gtMaskFloat = gtMask.cast(diff.dataType()); + auto lteMaskFloat = lteMask.cast(diff.dataType()); + + dLdp->assign(lteMaskFloat * diff + gtMaskFloat * delta * signDiff); + dLdl->assign(-lteMaskFloat * diff - gtMaskFloat * delta * signDiff); + + switch (reductionMode) { + case 1: { // 1 - "none" and "weighted_sum", output is scalar and equal to + // sum of all elements of E array + + *dLdp *= *weightsBroad; + *dLdl *= *weightsBroad; + + if (weights->isScalar()) + dLdw->assign(E.reduceNumber(reduce::Sum)); + else if (weights != weightsBroad) { + std::vector axesToReduceAlong = + ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), + weightsBroad->shapeInfo()); + E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, + false, false); + } else + dLdw->assign(E); + break; + } + case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all + // elements of E array divided by sum of all elements of + // weightsBroad array + + NDArray sum; + sum.setContext(block.launchContext()); + if (weights->isScalar()) + sum = (*weights) * E.lengthOf(); + else + sum = weightsBroad->reduceNumber(reduce::Sum); + + if (sum.e(0) == 0.) { + *dLdp = 0.; + *dLdl = 0.; + *dLdw = 0.; + } else { + *dLdp *= *weightsBroad / sum; + *dLdl *= *weightsBroad / sum; + + if (weights->isScalar()) + *dLdw = 0.; + else if (weights != weightsBroad) { + std::vector axesToReduceAlong = + ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), + weightsBroad->shapeInfo()); + ((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / + (sum * sum)) + .reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, + false, false); + } else + dLdw->assign( + (E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / + (sum * sum)); + } + break; + } + case 3: { // 3 - "weighted_sum_by_nonzero_weights", output is scalar and + // equal to scalar sum of all elements of E array divided by + // number of non-zero weights + + Nd4jLong numOfNonZeroWeights = 0; + if (weights->isScalar()) { + if (weights->e(0) != 0.) numOfNonZeroWeights = E.lengthOf(); + } else + numOfNonZeroWeights = + weightsBroad->reduceNumber(reduce::CountNonZero).e(0); + + if (numOfNonZeroWeights == 0) { + *dLdp = 0.; + *dLdl = 0.; + *dLdw = 0.; + } else { + auto numOfNonZeroWeightsScalar = NDArrayFactory::create( + dLdw->dataType(), numOfNonZeroWeights, block.launchContext()); + + if (weights->isScalar()) + dLdw->assign(E.reduceNumber(reduce::Sum) / + double(numOfNonZeroWeights)); + else if (weights != weightsBroad) { + std::vector axesToReduceAlong = + ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), + weightsBroad->shapeInfo()); + E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, + false, false); + *dLdw /= numOfNonZeroWeightsScalar; + } else + dLdw->assign(E / numOfNonZeroWeightsScalar); + + NDArray temp = *weightsBroad / numOfNonZeroWeightsScalar; + *dLdp *= temp; + *dLdl *= temp; + } + break; + } + } + + if (weightsBroad != weights) delete weightsBroad; + + return Status::OK(); +} +DECLARE_TYPES(huber_loss_grad) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } + +DECLARE_SHAPE_FN(huber_loss_grad) { + auto predictionsShapeInfo = inputShape->at(0); + auto weightsShapeInfo = inputShape->at(1); + auto labelsShapeInfo = inputShape->at(2); + + // labels and predictions must have the same shapes + REQUIRE_TRUE(shape::shapeEquals(labelsShapeInfo, predictionsShapeInfo), 0, + "HUBER_LOSS_GRAD OP: labels and predictions arrays must have " + "the same shapes, but got %s and %s correspondingly !", + ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), + ShapeUtils::shapeAsString(predictionsShapeInfo).c_str()); + // weights array can be single scalar or has the same rank as labels, and must + // be broadcastable to labels + REQUIRE_TRUE( + shape::isScalar(weightsShapeInfo) || + shape::rank(weightsShapeInfo) == shape::rank(labelsShapeInfo), + 0, + "HUBER_LOSS_GRAD OP: weights array should be scalar or have the same " + "rank as labels array, but got %i and %i correspondingly!", + shape::rank(weightsShapeInfo), shape::rank(labelsShapeInfo)); + // check whether broadcast operation is possible for weights array + REQUIRE_TRUE( + shape::isScalar(weightsShapeInfo) || + ShapeUtils::areShapesBroadcastable(weightsShapeInfo, labelsShapeInfo), + 0, + "HUBER_LOSS_GRAD OP: shapes of weights and labels arrays should be " + "broadcastable, but got weights = %s and labels = %s instead!", + ShapeUtils::shapeAsString(weightsShapeInfo).c_str(), + ShapeUtils::shapeAsString(labelsShapeInfo).c_str()); + + DataType outType = DataTypeUtils::pickFloatingType( + ArrayOptions::dataType(predictionsShapeInfo)); + + auto dLdpShapeInfo = ShapeBuilders::copyShapeInfoAndType( + predictionsShapeInfo, outType, false, block.workspace()); + auto dLdwShapeInfo = ShapeBuilders::copyShapeInfoAndType( + weightsShapeInfo, outType, false, block.workspace()); + auto dLdlShapeInfo = ShapeBuilders::copyShapeInfoAndType( + labelsShapeInfo, outType, false, block.workspace()); + + return SHAPELIST(dLdpShapeInfo, dLdwShapeInfo, dLdlShapeInfo); } +} // namespace ops +} // namespace sd + #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/loss/l2_loss.cpp b/libnd4j/include/ops/declarable/generic/loss/l2_loss.cpp index 48f3a64faa55..f3fd276d0eae 100644 --- a/libnd4j/include/ops/declarable/generic/loss/l2_loss.cpp +++ b/libnd4j/include/ops/declarable/generic/loss/l2_loss.cpp @@ -24,29 +24,32 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(l2_loss, 1, 1, false, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - REQUIRE_TRUE(output->isScalar(), 0, "Rank output should be scalar"); - - // FIXME: output should be used directly here, to avoid sum - input->reduceNumber(reduce::SquaredNorm, *output); - (*output) /= 2.; - - return Status::OK(); - } - DECLARE_SHAPE_FN(l2_loss) { - return SHAPELIST(ConstantShapeHelper::getInstance().scalarShapeInfo(ArrayOptions::dataType(inputShape->at(0)))); - } - - DECLARE_TYPES(l2_loss) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } - } +namespace ops { + +CUSTOM_OP_IMPL(l2_loss, 1, 1, false, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + REQUIRE_TRUE(output->isScalar(), 0, "Rank output should be scalar"); + + // FIXME: output should be used directly here, to avoid sum + input->reduceNumber(reduce::SquaredNorm, *output); + (*output) /= 2.; + + return Status::OK(); } +DECLARE_SHAPE_FN(l2_loss) { + return SHAPELIST(ConstantShapeHelper::getInstance().scalarShapeInfo( + ArrayOptions::dataType(inputShape->at(0)))); +} + +DECLARE_TYPES(l2_loss) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} + +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/loss/logLoss.cpp b/libnd4j/include/ops/declarable/generic/loss/logLoss.cpp index ab0c8923e170..2e33533e4d4b 100644 --- a/libnd4j/include/ops/declarable/generic/loss/logLoss.cpp +++ b/libnd4j/include/ops/declarable/generic/loss/logLoss.cpp @@ -24,290 +24,390 @@ #include namespace sd { -namespace ops { - +namespace ops { ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(log_loss, 3, 1, false, 1, 1) { - - auto predictions = INPUT_VARIABLE(0); - auto weights = INPUT_VARIABLE(1); - auto labels = INPUT_VARIABLE(2); - - auto output = OUTPUT_VARIABLE(0); - - int reductionMode = INT_ARG(0); // 0 - "none"; 1 - "weighted_sum"; 2 - "weighted_mean"; 3 - "weighted_sum_by_nonzero_weights" - // FIXME: double? - double epsilon = T_ARG(0); - - // input validation - REQUIRE_TRUE(labels->isSameShape(predictions), 0, "LOG_LOSS OP: labels and predictions arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labels).c_str(), ShapeUtils::shapeAsString(predictions).c_str()); - // weights array can be single scalar or has the same rank as labels, and must be broadcastable to labels - REQUIRE_TRUE(weights->isScalar() || weights->rankOf() == labels->rankOf(), 0, "LOG_LOSS OP: weights array should be scalar or have the same rank as labels array, but got %i and %i correspondingly!", weights->rankOf(), labels->rankOf()); - // check whether broadcast operation is possible for weights array - REQUIRE_TRUE(weights->isScalar() || ShapeUtils::areShapesBroadcastable(*weights, *labels), 0, "LOG_LOSS OP: shapes of weights and labels arrays should be broadcastable, but got weights = %s and labels = %s instead!", ShapeUtils::shapeAsString(weights).c_str(), ShapeUtils::shapeAsString(labels).c_str()); - // only 4 possible reduction modes exist - REQUIRE_TRUE(reductionMode==0 || reductionMode==1 || reductionMode==2 || reductionMode==3, 0, "LOG_LOSS OP: reduction mode value is not acceptable, possible values are 0, 1, 2, 3, but got %i instead!", reductionMode); - - // perform weights broadcasting/tile to predictions if needed - auto weightsBroad = weights; - if(!weights->isScalar() && !weights->isSameShape(predictions)) - weightsBroad = new NDArray(weights->tileToShape(predictions->shapeInfo())); - - NDArray E = -(*labels)*((*predictions + epsilon).transform(transform::Log)) - (1. - *labels)*(((1. + epsilon) - *predictions).transform(transform::Log)); - - // multiply E on weights - E *= *weightsBroad; - - switch (reductionMode) { - case 0: { // 0 - "none", un-reduced weighted losses with the same shape as labels. - output->assign(E); - break; - } - case 1: { // 1 - "weighted_sum", output is scalar and equal to sum of all elements of E array - E.reduceNumber(reduce::Sum, *output); - break; - } - case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array - NDArray sum; - sum.setContext(block.launchContext()); - if (weights->isScalar()) - sum = *weights * E.lengthOf(); - else - sum = weightsBroad->reduceNumber(reduce::Sum); - - if (sum.e(0) == 0.) - *output = 0.; - else - output->assign(E.reduceNumber(reduce::Sum) / sum); - break; - } - case 3: { // 3 - "weighted_sum_by_nonzero_weights", output is scalar and equal to scalar sum of all elements of E array divided by number of non-zero weights - Nd4jLong numOfNonZeroWeights = 0; - if(weights->isScalar()) { - if(weights->e(0) != 0.) - numOfNonZeroWeights = E.lengthOf(); - } - else { - numOfNonZeroWeights = weightsBroad->reduceNumber(reduce::CountNonZero).e(0); - } - - if (numOfNonZeroWeights == 0) - (*output) = 0.; - else - output->assign(E.reduceNumber(reduce::Sum) / double(numOfNonZeroWeights)); - break; - } - } - - if(weightsBroad != weights) - delete weightsBroad; - - return Status::OK(); + auto predictions = INPUT_VARIABLE(0); + auto weights = INPUT_VARIABLE(1); + auto labels = INPUT_VARIABLE(2); + + auto output = OUTPUT_VARIABLE(0); + + int reductionMode = + INT_ARG(0); // 0 - "none"; 1 - "weighted_sum"; 2 - "weighted_mean"; 3 - + // "weighted_sum_by_nonzero_weights" + // FIXME: double? + double epsilon = T_ARG(0); + + // input validation + REQUIRE_TRUE(labels->isSameShape(predictions), 0, + "LOG_LOSS OP: labels and predictions arrays must have the same " + "shapes, but got %s and %s correspondingly !", + ShapeUtils::shapeAsString(labels).c_str(), + ShapeUtils::shapeAsString(predictions).c_str()); + // weights array can be single scalar or has the same rank as labels, and must + // be broadcastable to labels + REQUIRE_TRUE(weights->isScalar() || weights->rankOf() == labels->rankOf(), 0, + "LOG_LOSS OP: weights array should be scalar or have the same " + "rank as labels array, but got %i and %i correspondingly!", + weights->rankOf(), labels->rankOf()); + // check whether broadcast operation is possible for weights array + REQUIRE_TRUE(weights->isScalar() || + ShapeUtils::areShapesBroadcastable(*weights, *labels), + 0, + "LOG_LOSS OP: shapes of weights and labels arrays should be " + "broadcastable, but got weights = %s and labels = %s instead!", + ShapeUtils::shapeAsString(weights).c_str(), + ShapeUtils::shapeAsString(labels).c_str()); + // only 4 possible reduction modes exist + REQUIRE_TRUE(reductionMode == 0 || reductionMode == 1 || reductionMode == 2 || + reductionMode == 3, + 0, + "LOG_LOSS OP: reduction mode value is not acceptable, possible " + "values are 0, 1, 2, 3, but got %i instead!", + reductionMode); + + // perform weights broadcasting/tile to predictions if needed + auto weightsBroad = weights; + if (!weights->isScalar() && !weights->isSameShape(predictions)) + weightsBroad = new NDArray(weights->tileToShape(predictions->shapeInfo())); + + NDArray E = + -(*labels) * ((*predictions + epsilon).transform(transform::Log)) - + (1. - *labels) * + (((1. + epsilon) - *predictions).transform(transform::Log)); + + // multiply E on weights + E *= *weightsBroad; + + switch (reductionMode) { + case 0: { // 0 - "none", un-reduced weighted losses with the same shape as + // labels. + output->assign(E); + break; + } + case 1: { // 1 - "weighted_sum", output is scalar and equal to sum of all + // elements of E array + E.reduceNumber(reduce::Sum, *output); + break; + } + case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all + // elements of E array divided by sum of all elements of + // weightsBroad array + NDArray sum; + sum.setContext(block.launchContext()); + if (weights->isScalar()) + sum = *weights * E.lengthOf(); + else + sum = weightsBroad->reduceNumber(reduce::Sum); + + if (sum.e(0) == 0.) + *output = 0.; + else + output->assign(E.reduceNumber(reduce::Sum) / sum); + break; + } + case 3: { // 3 - "weighted_sum_by_nonzero_weights", output is scalar and + // equal to scalar sum of all elements of E array divided by + // number of non-zero weights + Nd4jLong numOfNonZeroWeights = 0; + if (weights->isScalar()) { + if (weights->e(0) != 0.) numOfNonZeroWeights = E.lengthOf(); + } else { + numOfNonZeroWeights = + weightsBroad->reduceNumber(reduce::CountNonZero).e(0); + } + + if (numOfNonZeroWeights == 0) + (*output) = 0.; + else + output->assign(E.reduceNumber(reduce::Sum) / + double(numOfNonZeroWeights)); + break; + } + } + + if (weightsBroad != weights) delete weightsBroad; + + return Status::OK(); } ////////////////////////////////////////////////////////////////////////// DECLARE_TYPES(log_loss) { - - getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } ////////////////////////////////////////////////////////////////////////// DECLARE_SHAPE_FN(log_loss) { - - auto predictionsShapeInfo = inputShape->at(0); - auto weightsShapeInfo = inputShape->at(1); - auto labelsShapeInfo = inputShape->at(2); - - // labels and predictions must have the same shapes - REQUIRE_TRUE(shape::shapeEquals(labelsShapeInfo, predictionsShapeInfo), 0, "LOG_LOSS OP: labels and predictions arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), ShapeUtils::shapeAsString(predictionsShapeInfo).c_str()); - // weights array can be single scalar or has the same rank as labels, and must be broadcastable to labels - REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || shape::rank(weightsShapeInfo) == shape::rank(labelsShapeInfo), 0, "LOG_LOSS OP: weights array should be scalar or have the same rank as labels array, but got %i and %i correspondingly!", shape::rank(weightsShapeInfo), shape::rank(labelsShapeInfo)); - // check whether broadcast operation is possible for weights array - REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || ShapeUtils::areShapesBroadcastable(weightsShapeInfo, labelsShapeInfo), 0, "LOG_LOSS OP: shapes of weights and labels arrays should be broadcastable, but got weights = %s and labels = %s instead!", ShapeUtils::shapeAsString(weightsShapeInfo).c_str(), ShapeUtils::shapeAsString(labelsShapeInfo).c_str()); - - DataType outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(predictionsShapeInfo)); - Nd4jLong const* outShapeInfo = nullptr; - - if(INT_ARG(0) != 0) // in this case output is scalar - outShapeInfo = ConstantShapeHelper::getInstance().scalarShapeInfo(outType); - else // in this case output has the same shape as labels and predictions - outShapeInfo = ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(outType, shape::order(labelsShapeInfo), shape::shapeOf(labelsShapeInfo), shape::rank(labelsShapeInfo))); - - return SHAPELIST(outShapeInfo); + auto predictionsShapeInfo = inputShape->at(0); + auto weightsShapeInfo = inputShape->at(1); + auto labelsShapeInfo = inputShape->at(2); + + // labels and predictions must have the same shapes + REQUIRE_TRUE(shape::shapeEquals(labelsShapeInfo, predictionsShapeInfo), 0, + "LOG_LOSS OP: labels and predictions arrays must have the same " + "shapes, but got %s and %s correspondingly !", + ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), + ShapeUtils::shapeAsString(predictionsShapeInfo).c_str()); + // weights array can be single scalar or has the same rank as labels, and must + // be broadcastable to labels + REQUIRE_TRUE( + shape::isScalar(weightsShapeInfo) || + shape::rank(weightsShapeInfo) == shape::rank(labelsShapeInfo), + 0, + "LOG_LOSS OP: weights array should be scalar or have the same rank as " + "labels array, but got %i and %i correspondingly!", + shape::rank(weightsShapeInfo), shape::rank(labelsShapeInfo)); + // check whether broadcast operation is possible for weights array + REQUIRE_TRUE( + shape::isScalar(weightsShapeInfo) || + ShapeUtils::areShapesBroadcastable(weightsShapeInfo, labelsShapeInfo), + 0, + "LOG_LOSS OP: shapes of weights and labels arrays should be " + "broadcastable, but got weights = %s and labels = %s instead!", + ShapeUtils::shapeAsString(weightsShapeInfo).c_str(), + ShapeUtils::shapeAsString(labelsShapeInfo).c_str()); + + DataType outType = DataTypeUtils::pickFloatingType( + ArrayOptions::dataType(predictionsShapeInfo)); + Nd4jLong const* outShapeInfo = nullptr; + + if (INT_ARG(0) != 0) // in this case output is scalar + outShapeInfo = ConstantShapeHelper::getInstance().scalarShapeInfo(outType); + else // in this case output has the same shape as labels and predictions + outShapeInfo = + ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor( + outType, shape::order(labelsShapeInfo), + shape::shapeOf(labelsShapeInfo), shape::rank(labelsShapeInfo))); + + return SHAPELIST(outShapeInfo); } - - - - - ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(log_loss_grad, 3, 3, false, 1, 1) { - - auto predictions = INPUT_VARIABLE(0); - auto weights = INPUT_VARIABLE(1); - auto labels = INPUT_VARIABLE(2); - - auto dLdp = OUTPUT_VARIABLE(0); // dL/dpredictions - auto dLdw = OUTPUT_VARIABLE(1); // dL/dweights - auto dLdl = OUTPUT_VARIABLE(2); // dL/dlabels - - int reductionMode = INT_ARG(0); // 0 - "none"; 1 - "weighted_sum"; 2 - "weighted_mean"; 3 - "weighted_sum_by_nonzero_weights" - // take into account Alex's proposition to treat "none" the same as "weighted_sum" mode when calculating gradients - if(reductionMode == 0) - reductionMode = 1; - - // FIXME: double? - double epsilon = T_ARG(0); - - // input validation - REQUIRE_TRUE(labels->isSameShape(predictions), 0, "LOG_LOSS_GRAD OP: labels and predictions arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labels).c_str(), ShapeUtils::shapeAsString(predictions).c_str()); - // weights array can be single scalar or has the same rank as labels, and must be broadcastable to labels - REQUIRE_TRUE(weights->isScalar() || weights->rankOf() == labels->rankOf(), 0, "LOG_LOSS_GRAD OP: weights array should be scalar or have the same rank as labels array, but got %i and %i correspondingly!", weights->rankOf(), labels->rankOf()); - // check whether broadcast operation is possible for weights array - REQUIRE_TRUE(weights->isScalar() || ShapeUtils::areShapesBroadcastable(*weights, *labels), 0, "LOG_LOSS_GRAD OP: shapes of weights and labels arrays should be broadcastable, but got weights = %s and labels = %s instead!", ShapeUtils::shapeAsString(weights).c_str(), ShapeUtils::shapeAsString(labels).c_str()); - // only 4 possible reduction modes exist - REQUIRE_TRUE(reductionMode==0 || reductionMode==1 || reductionMode==2 || reductionMode==3, 0, "LOG_LOSS_GRAD OP: reduction mode value is not acceptable, possible values are 0, 1, 2, 3, but got %i instead!", reductionMode); - - // perform weights broadcasting/tile to labels if needed - auto weightsBroad = weights; - if(!weights->isScalar() && !weights->isSameShape(predictions)) - weightsBroad = new NDArray(weights->tileToShape(predictions->shapeInfo())); - - NDArray predictPlusEps = *predictions + epsilon; - NDArray oneMinusLabels = 1. - *labels; - NDArray onePlusEpsMinusPredict = (1. + epsilon) - *predictions; - - // dE_i/dp_i = (1-y_i)/(1-p_i+eps) - y_i/(p_i+eps) - dLdp->assign(oneMinusLabels / onePlusEpsMinusPredict - *labels / predictPlusEps); // dE/dp - // dE_i/dy_i = log((1+2eps)/(p_i+eps) - 1) - ((1. + 2. * epsilon) / predictPlusEps - 1.).applyTransform(transform::Log, *dLdl); // dE/dy - - NDArray E = -(*labels) * predictPlusEps.transform(transform::Log) - oneMinusLabels * onePlusEpsMinusPredict.transform(transform::Log); - - // process 3 possible reduction modes below - switch (reductionMode) { - case 1: { // 1 - "none" and "weighted_sum", output is scalar and equal to sum of all elements of E array - - *dLdp *= *weightsBroad; - *dLdl *= *weightsBroad; - - if(weights->isScalar()) - dLdw->assign(E.reduceNumber(reduce::Sum)); - else if(weights != weightsBroad) { - std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), weightsBroad->shapeInfo()); - E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); - } - else - dLdw->assign(E); - - break; - } - case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array - - NDArray sum; - sum.setContext(block.launchContext()); - if (weights->isScalar()) - sum = (*weights) * E.lengthOf(); - else - sum = weightsBroad->reduceNumber(reduce::Sum); - - if (sum.e(0) == 0.) { - *dLdp = 0.; - *dLdl = 0.; - *dLdw = 0.; - } - else { - - NDArray temp = *weightsBroad / sum; - *dLdp *= temp; - *dLdl *= temp; - - if(weights->isScalar()) - *dLdw = 0.; - else if(weights != weightsBroad) { - std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), weightsBroad->shapeInfo()); - ((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)).reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); - } - else - dLdw->assign((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)); - } - break; - } - case 3: { // 3 - "weighted_sum_by_nonzero_weights", output is scalar and equal to scalar sum of all elements of E array divided by number of non-zero weights - - Nd4jLong numOfNonZeroWeights = 0; - if(weights->isScalar()) { - if(weights->e(0) != 0.) - numOfNonZeroWeights = E.lengthOf(); - } - else - numOfNonZeroWeights = weightsBroad->reduceNumber(reduce::CountNonZero).e(0); - - if (numOfNonZeroWeights == 0) { - *dLdp = 0.; - *dLdl = 0.; - *dLdw = 0.; - } - else { - auto numOfNonZeroWeightsScalar = NDArrayFactory::create(dLdw->dataType(), numOfNonZeroWeights, block.launchContext()); - if(weights->isScalar()) - dLdw->assign(E.reduceNumber(reduce::Sum) / numOfNonZeroWeights); - else if(weights != weightsBroad) { - std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), weightsBroad->shapeInfo()); - E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); - *dLdw /= numOfNonZeroWeightsScalar; - } - else - dLdw->assign(E / numOfNonZeroWeightsScalar); - - NDArray temp = *weightsBroad / numOfNonZeroWeightsScalar; - *dLdp *= temp; - *dLdl *= temp; - } - break; - } - } - - if(weightsBroad != weights) - delete weightsBroad; - - return Status::OK(); + auto predictions = INPUT_VARIABLE(0); + auto weights = INPUT_VARIABLE(1); + auto labels = INPUT_VARIABLE(2); + + auto dLdp = OUTPUT_VARIABLE(0); // dL/dpredictions + auto dLdw = OUTPUT_VARIABLE(1); // dL/dweights + auto dLdl = OUTPUT_VARIABLE(2); // dL/dlabels + + int reductionMode = + INT_ARG(0); // 0 - "none"; 1 - "weighted_sum"; 2 - "weighted_mean"; 3 - + // "weighted_sum_by_nonzero_weights" + // take into account Alex's proposition to treat "none" the same as + // "weighted_sum" mode when calculating gradients + if (reductionMode == 0) reductionMode = 1; + + // FIXME: double? + double epsilon = T_ARG(0); + + // input validation + REQUIRE_TRUE(labels->isSameShape(predictions), 0, + "LOG_LOSS_GRAD OP: labels and predictions arrays must have the " + "same shapes, but got %s and %s correspondingly !", + ShapeUtils::shapeAsString(labels).c_str(), + ShapeUtils::shapeAsString(predictions).c_str()); + // weights array can be single scalar or has the same rank as labels, and must + // be broadcastable to labels + REQUIRE_TRUE(weights->isScalar() || weights->rankOf() == labels->rankOf(), 0, + "LOG_LOSS_GRAD OP: weights array should be scalar or have the " + "same rank as labels array, but got %i and %i correspondingly!", + weights->rankOf(), labels->rankOf()); + // check whether broadcast operation is possible for weights array + REQUIRE_TRUE( + weights->isScalar() || + ShapeUtils::areShapesBroadcastable(*weights, *labels), + 0, + "LOG_LOSS_GRAD OP: shapes of weights and labels arrays should be " + "broadcastable, but got weights = %s and labels = %s instead!", + ShapeUtils::shapeAsString(weights).c_str(), + ShapeUtils::shapeAsString(labels).c_str()); + // only 4 possible reduction modes exist + REQUIRE_TRUE(reductionMode == 0 || reductionMode == 1 || reductionMode == 2 || + reductionMode == 3, + 0, + "LOG_LOSS_GRAD OP: reduction mode value is not acceptable, " + "possible values are 0, 1, 2, 3, but got %i instead!", + reductionMode); + + // perform weights broadcasting/tile to labels if needed + auto weightsBroad = weights; + if (!weights->isScalar() && !weights->isSameShape(predictions)) + weightsBroad = new NDArray(weights->tileToShape(predictions->shapeInfo())); + + NDArray predictPlusEps = *predictions + epsilon; + NDArray oneMinusLabels = 1. - *labels; + NDArray onePlusEpsMinusPredict = (1. + epsilon) - *predictions; + + // dE_i/dp_i = (1-y_i)/(1-p_i+eps) - y_i/(p_i+eps) + dLdp->assign(oneMinusLabels / onePlusEpsMinusPredict - + *labels / predictPlusEps); // dE/dp + // dE_i/dy_i = log((1+2eps)/(p_i+eps) - 1) + ((1. + 2. * epsilon) / predictPlusEps - 1.) + .applyTransform(transform::Log, *dLdl); // dE/dy + + NDArray E = -(*labels) * predictPlusEps.transform(transform::Log) - + oneMinusLabels * onePlusEpsMinusPredict.transform(transform::Log); + + // process 3 possible reduction modes below + switch (reductionMode) { + case 1: { // 1 - "none" and "weighted_sum", output is scalar and equal to + // sum of all elements of E array + + *dLdp *= *weightsBroad; + *dLdl *= *weightsBroad; + + if (weights->isScalar()) + dLdw->assign(E.reduceNumber(reduce::Sum)); + else if (weights != weightsBroad) { + std::vector axesToReduceAlong = + ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), + weightsBroad->shapeInfo()); + E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, + false, false); + } else + dLdw->assign(E); + + break; + } + case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all + // elements of E array divided by sum of all elements of + // weightsBroad array + + NDArray sum; + sum.setContext(block.launchContext()); + if (weights->isScalar()) + sum = (*weights) * E.lengthOf(); + else + sum = weightsBroad->reduceNumber(reduce::Sum); + + if (sum.e(0) == 0.) { + *dLdp = 0.; + *dLdl = 0.; + *dLdw = 0.; + } else { + NDArray temp = *weightsBroad / sum; + *dLdp *= temp; + *dLdl *= temp; + + if (weights->isScalar()) + *dLdw = 0.; + else if (weights != weightsBroad) { + std::vector axesToReduceAlong = + ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), + weightsBroad->shapeInfo()); + ((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / + (sum * sum)) + .reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, + false, false); + } else + dLdw->assign( + (E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / + (sum * sum)); + } + break; + } + case 3: { // 3 - "weighted_sum_by_nonzero_weights", output is scalar and + // equal to scalar sum of all elements of E array divided by + // number of non-zero weights + + Nd4jLong numOfNonZeroWeights = 0; + if (weights->isScalar()) { + if (weights->e(0) != 0.) numOfNonZeroWeights = E.lengthOf(); + } else + numOfNonZeroWeights = + weightsBroad->reduceNumber(reduce::CountNonZero).e(0); + + if (numOfNonZeroWeights == 0) { + *dLdp = 0.; + *dLdl = 0.; + *dLdw = 0.; + } else { + auto numOfNonZeroWeightsScalar = NDArrayFactory::create( + dLdw->dataType(), numOfNonZeroWeights, block.launchContext()); + if (weights->isScalar()) + dLdw->assign(E.reduceNumber(reduce::Sum) / numOfNonZeroWeights); + else if (weights != weightsBroad) { + std::vector axesToReduceAlong = + ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), + weightsBroad->shapeInfo()); + E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, + false, false); + *dLdw /= numOfNonZeroWeightsScalar; + } else + dLdw->assign(E / numOfNonZeroWeightsScalar); + + NDArray temp = *weightsBroad / numOfNonZeroWeightsScalar; + *dLdp *= temp; + *dLdl *= temp; + } + break; + } + } + + if (weightsBroad != weights) delete weightsBroad; + + return Status::OK(); } ////////////////////////////////////////////////////////////////////////// DECLARE_TYPES(log_loss_grad) { - - getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } ////////////////////////////////////////////////////////////////////////// DECLARE_SHAPE_FN(log_loss_grad) { - - auto predictionsShapeInfo = inputShape->at(0); - auto weightsShapeInfo = inputShape->at(1); - auto labelsShapeInfo = inputShape->at(2); - - // labels and predictions must have the same shapes - REQUIRE_TRUE(shape::shapeEquals(labelsShapeInfo, predictionsShapeInfo), 0, "LOG_LOSS_GRAD OP: labels and predictions arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), ShapeUtils::shapeAsString(predictionsShapeInfo).c_str()); - // weights array can be single scalar or has the same rank as labels, and must be broadcastable to labels - REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || shape::rank(weightsShapeInfo) == shape::rank(labelsShapeInfo), 0, "LOG_LOSS_GRAD OP: weights array should be scalar or have the same rank as labels array, but got %i and %i correspondingly!", shape::rank(weightsShapeInfo), shape::rank(labelsShapeInfo)); - // check whether broadcast operation is possible for weights array - REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || ShapeUtils::areShapesBroadcastable(weightsShapeInfo, labelsShapeInfo), 0, "LOG_LOSS_GRAD OP: shapes of weights and labels arrays should be broadcastable, but got weights = %s and labels = %s instead!", ShapeUtils::shapeAsString(weightsShapeInfo).c_str(), ShapeUtils::shapeAsString(labelsShapeInfo).c_str()); - - DataType outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(predictionsShapeInfo)); - - auto dLdpShapeInfo = ShapeBuilders::copyShapeInfoAndType(predictionsShapeInfo, outType, false, block.getWorkspace()); - auto dLdwShapeInfo = ShapeBuilders::copyShapeInfoAndType(weightsShapeInfo, outType, false, block.getWorkspace()); - auto dLdlShapeInfo = ShapeBuilders::copyShapeInfoAndType(labelsShapeInfo, outType, false, block.getWorkspace()); - - return SHAPELIST(CONSTANT(dLdpShapeInfo), CONSTANT(dLdwShapeInfo), CONSTANT(dLdlShapeInfo)); + auto predictionsShapeInfo = inputShape->at(0); + auto weightsShapeInfo = inputShape->at(1); + auto labelsShapeInfo = inputShape->at(2); + + // labels and predictions must have the same shapes + REQUIRE_TRUE(shape::shapeEquals(labelsShapeInfo, predictionsShapeInfo), 0, + "LOG_LOSS_GRAD OP: labels and predictions arrays must have the " + "same shapes, but got %s and %s correspondingly !", + ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), + ShapeUtils::shapeAsString(predictionsShapeInfo).c_str()); + // weights array can be single scalar or has the same rank as labels, and must + // be broadcastable to labels + REQUIRE_TRUE( + shape::isScalar(weightsShapeInfo) || + shape::rank(weightsShapeInfo) == shape::rank(labelsShapeInfo), + 0, + "LOG_LOSS_GRAD OP: weights array should be scalar or have the same rank " + "as labels array, but got %i and %i correspondingly!", + shape::rank(weightsShapeInfo), shape::rank(labelsShapeInfo)); + // check whether broadcast operation is possible for weights array + REQUIRE_TRUE( + shape::isScalar(weightsShapeInfo) || + ShapeUtils::areShapesBroadcastable(weightsShapeInfo, labelsShapeInfo), + 0, + "LOG_LOSS_GRAD OP: shapes of weights and labels arrays should be " + "broadcastable, but got weights = %s and labels = %s instead!", + ShapeUtils::shapeAsString(weightsShapeInfo).c_str(), + ShapeUtils::shapeAsString(labelsShapeInfo).c_str()); + + DataType outType = DataTypeUtils::pickFloatingType( + ArrayOptions::dataType(predictionsShapeInfo)); + + auto dLdpShapeInfo = ShapeBuilders::copyShapeInfoAndType( + predictionsShapeInfo, outType, false, block.workspace()); + auto dLdwShapeInfo = ShapeBuilders::copyShapeInfoAndType( + weightsShapeInfo, outType, false, block.workspace()); + auto dLdlShapeInfo = ShapeBuilders::copyShapeInfoAndType( + labelsShapeInfo, outType, false, block.workspace()); + + return SHAPELIST(CONSTANT(dLdpShapeInfo), CONSTANT(dLdwShapeInfo), + CONSTANT(dLdlShapeInfo)); } - - -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/loss/log_poisson_loss.cpp b/libnd4j/include/ops/declarable/generic/loss/log_poisson_loss.cpp index 5cc6b60ab475..36e2638d045c 100644 --- a/libnd4j/include/ops/declarable/generic/loss/log_poisson_loss.cpp +++ b/libnd4j/include/ops/declarable/generic/loss/log_poisson_loss.cpp @@ -25,287 +25,394 @@ namespace sd { namespace ops { - CUSTOM_OP_IMPL(log_poisson_loss, 3, 1, true, 0, 1) { - auto log_predictions = INPUT_VARIABLE(0); - auto weights = INPUT_VARIABLE(1); - auto labels = INPUT_VARIABLE(2); - - auto output = OUTPUT_VARIABLE(0); - - int reductionMode = INT_ARG(0); // 0 - "none"; 1 - "weighted_sum"; 2 - "weighted_mean"; 3 - "weighted_sum_by_nonzero_weights" - - bool computeFullLoss = false; - if (block.numI() > 1) - computeFullLoss = INT_ARG(1) != 0; - - // inputs validation - REQUIRE_TRUE(labels->isSameShape(log_predictions), 0, "LOG_POISSON_LOSS OP: labels and log_predictions arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labels).c_str(), ShapeUtils::shapeAsString(log_predictions).c_str()); - // weights array can be single scalar or has the same rank as labels, and must be broadcastable to labels - REQUIRE_TRUE(weights->isScalar() || weights->rankOf() == labels->rankOf(), 0, "LOG_POISSON_LOSS OP: weights array should be scalar or have the same rank as labels array, but got %i and %i correspondingly!", weights->rankOf(), labels->rankOf()); - // check whether broadcast operation is possible for weights array - REQUIRE_TRUE(weights->isScalar() || ShapeUtils::areShapesBroadcastable(*weights, *labels), 0, "LOG_POISSON_LOSS OP: shapes of weights and labels arrays should be broadcastable, but got weights = %s and labels = %s instead!", ShapeUtils::shapeAsString(weights).c_str(), ShapeUtils::shapeAsString(labels).c_str()); - // only 4 possible reduction modes exist - REQUIRE_TRUE(reductionMode==0 || reductionMode==1 || reductionMode==2 || reductionMode==3, 0, "LOG_POISSON_LOSS OP: reduction mode value is not acceptable, possible values are 0, 1, 2, 3, but got %i instead!", reductionMode); - - // perform weights broadcasting/tile to labels if needed - auto weightsBroad = weights; - if(!weights->isScalar() && !weights->isSameShape(log_predictions)) - weightsBroad = new NDArray(weights->tileToShape(log_predictions->shapeInfo())); - - - NDArray E(labels->shapeInfo(), block.getWorkspace()); - if (computeFullLoss) - labels->applyPairwiseTransform(pairwise::LogPoissonLossFull, *log_predictions, E); - else - labels->applyPairwiseTransform(pairwise::LogPoissonLoss, *log_predictions, E); - - - // multiply E on weights - E *= *weightsBroad; - - switch (reductionMode) { - case 0: { // 0 - "none", un-reduced weighted losses with the same shape as labels. - output->assign(E); - break; - } - case 1: { // 1 - "weighted_sum", output is scalar and equal to sum of all elements of E array - E.reduceNumber(reduce::Sum, *output); - break; - } - case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array - NDArray sum; - sum.setContext(block.launchContext()); - if (weights->isScalar()) - sum = *weights * E.lengthOf(); - else - sum = weightsBroad->reduceNumber(reduce::Sum); - - if (sum.e(0) == 0.) - *output = 0.; - else - output->assign(E.reduceNumber(reduce::Sum) / sum); - break; - } - case 3: { // 3 - "weighted_sum_by_nonzero_weights", output is scalar and equal to scalar sum of all elements of E array divided by number of non-zero weights - Nd4jLong numOfNonZeroWeights = 0; - if(weights->isScalar()) { - if(weights->e(0) != 0.) - numOfNonZeroWeights = E.lengthOf(); - } - else { - numOfNonZeroWeights = weightsBroad->reduceNumber(reduce::CountNonZero).e(0); - } - - if (numOfNonZeroWeights == 0) - (*output) = 0.; - else - output->assign(E.reduceNumber(reduce::Sum) / double(numOfNonZeroWeights)); - break; - } - } - - if(weightsBroad != weights) - delete weightsBroad; - - return Status::OK(); +CUSTOM_OP_IMPL(log_poisson_loss, 3, 1, true, 0, 1) { + auto log_predictions = INPUT_VARIABLE(0); + auto weights = INPUT_VARIABLE(1); + auto labels = INPUT_VARIABLE(2); + + auto output = OUTPUT_VARIABLE(0); + + int reductionMode = + INT_ARG(0); // 0 - "none"; 1 - "weighted_sum"; 2 - "weighted_mean"; 3 - + // "weighted_sum_by_nonzero_weights" + + bool computeFullLoss = false; + if (block.numI() > 1) computeFullLoss = INT_ARG(1) != 0; + + // inputs validation + REQUIRE_TRUE(labels->isSameShape(log_predictions), 0, + "LOG_POISSON_LOSS OP: labels and log_predictions arrays must " + "have the same shapes, but got %s and %s correspondingly !", + ShapeUtils::shapeAsString(labels).c_str(), + ShapeUtils::shapeAsString(log_predictions).c_str()); + // weights array can be single scalar or has the same rank as labels, and must + // be broadcastable to labels + REQUIRE_TRUE( + weights->isScalar() || weights->rankOf() == labels->rankOf(), 0, + "LOG_POISSON_LOSS OP: weights array should be scalar or have the same " + "rank as labels array, but got %i and %i correspondingly!", + weights->rankOf(), labels->rankOf()); + // check whether broadcast operation is possible for weights array + REQUIRE_TRUE( + weights->isScalar() || + ShapeUtils::areShapesBroadcastable(*weights, *labels), + 0, + "LOG_POISSON_LOSS OP: shapes of weights and labels arrays should be " + "broadcastable, but got weights = %s and labels = %s instead!", + ShapeUtils::shapeAsString(weights).c_str(), + ShapeUtils::shapeAsString(labels).c_str()); + // only 4 possible reduction modes exist + REQUIRE_TRUE(reductionMode == 0 || reductionMode == 1 || reductionMode == 2 || + reductionMode == 3, + 0, + "LOG_POISSON_LOSS OP: reduction mode value is not acceptable, " + "possible values are 0, 1, 2, 3, but got %i instead!", + reductionMode); + + // perform weights broadcasting/tile to labels if needed + auto weightsBroad = weights; + if (!weights->isScalar() && !weights->isSameShape(log_predictions)) + weightsBroad = + new NDArray(weights->tileToShape(log_predictions->shapeInfo())); + + NDArray E(labels->shapeInfo(), block.workspace()); + if (computeFullLoss) + labels->applyPairwiseTransform(pairwise::LogPoissonLossFull, + *log_predictions, E); + else + labels->applyPairwiseTransform(pairwise::LogPoissonLoss, *log_predictions, + E); + + // multiply E on weights + E *= *weightsBroad; + + switch (reductionMode) { + case 0: { // 0 - "none", un-reduced weighted losses with the same shape as + // labels. + output->assign(E); + break; } - - ////////////////////////////////////////////////////////////////////////// - DECLARE_TYPES(log_poisson_loss) { - getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setAllowedOutputTypes({ALL_FLOATS}); + case 1: { // 1 - "weighted_sum", output is scalar and equal to sum of all + // elements of E array + E.reduceNumber(reduce::Sum, *output); + break; } + case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all + // elements of E array divided by sum of all elements of + // weightsBroad array + NDArray sum; + sum.setContext(block.launchContext()); + if (weights->isScalar()) + sum = *weights * E.lengthOf(); + else + sum = weightsBroad->reduceNumber(reduce::Sum); + + if (sum.e(0) == 0.) + *output = 0.; + else + output->assign(E.reduceNumber(reduce::Sum) / sum); + break; + } + case 3: { // 3 - "weighted_sum_by_nonzero_weights", output is scalar and + // equal to scalar sum of all elements of E array divided by + // number of non-zero weights + Nd4jLong numOfNonZeroWeights = 0; + if (weights->isScalar()) { + if (weights->e(0) != 0.) numOfNonZeroWeights = E.lengthOf(); + } else { + numOfNonZeroWeights = + weightsBroad->reduceNumber(reduce::CountNonZero).e(0); + } + + if (numOfNonZeroWeights == 0) + (*output) = 0.; + else + output->assign(E.reduceNumber(reduce::Sum) / + double(numOfNonZeroWeights)); + break; + } + } - ////////////////////////////////////////////////////////////////////////// - DECLARE_SHAPE_FN(log_poisson_loss) { - - auto predictionsShapeInfo = inputShape->at(0); - auto weightsShapeInfo = inputShape->at(1); - auto labelsShapeInfo = inputShape->at(2); + if (weightsBroad != weights) delete weightsBroad; - // labels and predictions must have the same shapes - REQUIRE_TRUE(shape::shapeEquals(labelsShapeInfo, predictionsShapeInfo), 0, "LOG_POISSON_LOSS OP: labels and predictions arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), ShapeUtils::shapeAsString(predictionsShapeInfo).c_str()); - // weights array can be single scalar or has the same rank as labels, and must be broadcastable to labels - REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || shape::rank(weightsShapeInfo) == shape::rank(labelsShapeInfo), 0, "LOG_POISSON_LOSS OP: weights array should be scalar or have the same rank as labels array, but got %i and %i correspondingly!", shape::rank(weightsShapeInfo), shape::rank(labelsShapeInfo)); - // check whether broadcast operation is possible for weights array - REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || ShapeUtils::areShapesBroadcastable(weightsShapeInfo, labelsShapeInfo), 0, "LOG_POISSON_LOSS OP: shapes of weights and labels arrays should be broadcastable, but got weights = %s and labels = %s instead!", ShapeUtils::shapeAsString(weightsShapeInfo).c_str(), ShapeUtils::shapeAsString(labelsShapeInfo).c_str()); + return Status::OK(); +} - DataType outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(predictionsShapeInfo)); - Nd4jLong const* outShapeInfo = nullptr; +////////////////////////////////////////////////////////////////////////// +DECLARE_TYPES(log_poisson_loss) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} - if(INT_ARG(0) != 0) // in this case output is scalar - outShapeInfo = ConstantShapeHelper::getInstance().scalarShapeInfo(outType); - else // in this case output has the same shape as labels and predictions - outShapeInfo = ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(labelsShapeInfo, outType)); +////////////////////////////////////////////////////////////////////////// +DECLARE_SHAPE_FN(log_poisson_loss) { + auto predictionsShapeInfo = inputShape->at(0); + auto weightsShapeInfo = inputShape->at(1); + auto labelsShapeInfo = inputShape->at(2); + + // labels and predictions must have the same shapes + REQUIRE_TRUE(shape::shapeEquals(labelsShapeInfo, predictionsShapeInfo), 0, + "LOG_POISSON_LOSS OP: labels and predictions arrays must have " + "the same shapes, but got %s and %s correspondingly !", + ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), + ShapeUtils::shapeAsString(predictionsShapeInfo).c_str()); + // weights array can be single scalar or has the same rank as labels, and must + // be broadcastable to labels + REQUIRE_TRUE( + shape::isScalar(weightsShapeInfo) || + shape::rank(weightsShapeInfo) == shape::rank(labelsShapeInfo), + 0, + "LOG_POISSON_LOSS OP: weights array should be scalar or have the same " + "rank as labels array, but got %i and %i correspondingly!", + shape::rank(weightsShapeInfo), shape::rank(labelsShapeInfo)); + // check whether broadcast operation is possible for weights array + REQUIRE_TRUE( + shape::isScalar(weightsShapeInfo) || + ShapeUtils::areShapesBroadcastable(weightsShapeInfo, labelsShapeInfo), + 0, + "LOG_POISSON_LOSS OP: shapes of weights and labels arrays should be " + "broadcastable, but got weights = %s and labels = %s instead!", + ShapeUtils::shapeAsString(weightsShapeInfo).c_str(), + ShapeUtils::shapeAsString(labelsShapeInfo).c_str()); + + DataType outType = DataTypeUtils::pickFloatingType( + ArrayOptions::dataType(predictionsShapeInfo)); + Nd4jLong const* outShapeInfo = nullptr; + + if (INT_ARG(0) != 0) // in this case output is scalar + outShapeInfo = ConstantShapeHelper::getInstance().scalarShapeInfo(outType); + else // in this case output has the same shape as labels and predictions + outShapeInfo = ConstantShapeHelper::getInstance().createShapeInfo( + ShapeDescriptor(labelsShapeInfo, outType)); + + return SHAPELIST(outShapeInfo); +} - return SHAPELIST(outShapeInfo); +////////////////////////////////////////////////////////////////////////// +CUSTOM_OP_IMPL(log_poisson_loss_grad, 3, 3, false, 0, 1) { + auto log_predictions = INPUT_VARIABLE(0); + auto weights = INPUT_VARIABLE(1); + auto labels = INPUT_VARIABLE(2); + + auto dLdp = OUTPUT_VARIABLE(0); // dL/dpredictions + auto dLdw = OUTPUT_VARIABLE(1); // dL/dweights + auto dLdl = OUTPUT_VARIABLE(2); // dL/dlabels + + int reductionMode = + INT_ARG(0); // 0 - "none"; 1 - "weighted_sum"; 2 - "weighted_mean"; 3 - + // "weighted_sum_by_nonzero_weights" + // take into account Alex's proposition to treat "none" the same as + // "weighted_sum" mode when calculating gradients + if (reductionMode == 0) reductionMode = 1; + + bool computeFullLoss = false; + if (block.numI() > 1) computeFullLoss = INT_ARG(1) != 0; + + // inputs validation + REQUIRE_TRUE(labels->isSameShape(log_predictions), 0, + "LOG_POISSON_LOSS_GRAD OP: labels and log_predictions arrays " + "must have the same shapes, but got %s and %s correspondingly !", + ShapeUtils::shapeAsString(labels).c_str(), + ShapeUtils::shapeAsString(log_predictions).c_str()); + // weights array can be single scalar or has the same rank as labels, and must + // be broadcastable to labels + REQUIRE_TRUE( + weights->isScalar() || weights->rankOf() == labels->rankOf(), 0, + "LOG_POISSON_LOSS_GRAD OP: weights array should be scalar or have the " + "same rank as labels array, but got %i and %i correspondingly!", + weights->rankOf(), labels->rankOf()); + // check whether broadcast operation is possible for weights array + REQUIRE_TRUE( + weights->isScalar() || + ShapeUtils::areShapesBroadcastable(*weights, *labels), + 0, + "LOG_POISSON_LOSS_GRAD OP: shapes of weights and labels arrays should be " + "broadcastable, but got weights = %s and labels = %s instead!", + ShapeUtils::shapeAsString(weights).c_str(), + ShapeUtils::shapeAsString(labels).c_str()); + // only 4 possible reduction modes exist + REQUIRE_TRUE( + reductionMode == 0 || reductionMode == 1 || reductionMode == 2 || + reductionMode == 3, + 0, + "LOG_POISSON_LOSS_GRAD OP: reduction mode value is not acceptable, " + "possible values are 0, 1, 2, 3, but got %i instead!", + reductionMode); + + // perform weights broadcasting/tile to labels if needed + auto weightsBroad = weights; + if (!weights->isScalar() && !weights->isSameShape(log_predictions)) + weightsBroad = + new NDArray(weights->tileToShape(log_predictions->shapeInfo())); + + NDArray E(labels->shapeInfo(), block.workspace()); + if (computeFullLoss) { + labels->applyPairwiseTransform(pairwise::LogPoissonLossFull, + *log_predictions, E); + + NDArray rDiv(labels->shapeInfo(), block.workspace()); + labels->applyScalar(scalar::ReverseDivide, 0.5f, rDiv); + dLdl->assign(rDiv + labels->transform(transform::Log) + + -(*log_predictions)); + } else { + labels->applyPairwiseTransform(pairwise::LogPoissonLoss, *log_predictions, + E); + + dLdl->assign(-(*log_predictions)); + } + + dLdp->assign(log_predictions->transform(transform::Exp) - (*labels)); + + switch (reductionMode) { + case 1: { // 1 - "none" and "weighted_sum", output is scalar and equal to + // sum of all elements of E array + + *dLdp *= *weightsBroad; + *dLdl *= *weightsBroad; + + if (weights->isScalar()) + dLdw->assign(E.reduceNumber(reduce::Sum)); + else if (weights != weightsBroad) { + std::vector axesToReduceAlong = + ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), + weightsBroad->shapeInfo()); + E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, + false, false); + } else + dLdw->assign(E); + break; } - - ////////////////////////////////////////////////////////////////////////// - CUSTOM_OP_IMPL(log_poisson_loss_grad, 3, 3, false, 0, 1) { - - auto log_predictions = INPUT_VARIABLE(0); - auto weights = INPUT_VARIABLE(1); - auto labels = INPUT_VARIABLE(2); - - auto dLdp = OUTPUT_VARIABLE(0); // dL/dpredictions - auto dLdw = OUTPUT_VARIABLE(1); // dL/dweights - auto dLdl = OUTPUT_VARIABLE(2); // dL/dlabels - - int reductionMode = INT_ARG(0); // 0 - "none"; 1 - "weighted_sum"; 2 - "weighted_mean"; 3 - "weighted_sum_by_nonzero_weights" - // take into account Alex's proposition to treat "none" the same as "weighted_sum" mode when calculating gradients - if(reductionMode == 0) - reductionMode = 1; - - bool computeFullLoss = false; - if (block.numI() > 1) - computeFullLoss = INT_ARG(1) != 0; - - // inputs validation - REQUIRE_TRUE(labels->isSameShape(log_predictions), 0, "LOG_POISSON_LOSS_GRAD OP: labels and log_predictions arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labels).c_str(), ShapeUtils::shapeAsString(log_predictions).c_str()); - // weights array can be single scalar or has the same rank as labels, and must be broadcastable to labels - REQUIRE_TRUE(weights->isScalar() || weights->rankOf() == labels->rankOf(), 0, "LOG_POISSON_LOSS_GRAD OP: weights array should be scalar or have the same rank as labels array, but got %i and %i correspondingly!", weights->rankOf(), labels->rankOf()); - // check whether broadcast operation is possible for weights array - REQUIRE_TRUE(weights->isScalar() || ShapeUtils::areShapesBroadcastable(*weights, *labels), 0, "LOG_POISSON_LOSS_GRAD OP: shapes of weights and labels arrays should be broadcastable, but got weights = %s and labels = %s instead!", ShapeUtils::shapeAsString(weights).c_str(), ShapeUtils::shapeAsString(labels).c_str()); - // only 4 possible reduction modes exist - REQUIRE_TRUE(reductionMode==0 || reductionMode==1 || reductionMode==2 || reductionMode==3, 0, "LOG_POISSON_LOSS_GRAD OP: reduction mode value is not acceptable, possible values are 0, 1, 2, 3, but got %i instead!", reductionMode); - - // perform weights broadcasting/tile to labels if needed - auto weightsBroad = weights; - if(!weights->isScalar() && !weights->isSameShape(log_predictions)) - weightsBroad = new NDArray(weights->tileToShape(log_predictions->shapeInfo())); - - - NDArray E(labels->shapeInfo(), block.getWorkspace()); - if (computeFullLoss) { - labels->applyPairwiseTransform(pairwise::LogPoissonLossFull, *log_predictions, E); - - NDArray rDiv(labels->shapeInfo(), block.getWorkspace()); - labels->applyScalar(scalar::ReverseDivide, 0.5f, rDiv); - dLdl->assign(rDiv + labels->transform(transform::Log) + -(*log_predictions)); - } else { - labels->applyPairwiseTransform(pairwise::LogPoissonLoss, *log_predictions, E); - - dLdl->assign(-(*log_predictions)); - } - - dLdp->assign(log_predictions->transform(transform::Exp) - (*labels)); - - switch (reductionMode) { - - case 1: { // 1 - "none" and "weighted_sum", output is scalar and equal to sum of all elements of E array - - *dLdp *= *weightsBroad; - *dLdl *= *weightsBroad; - - if(weights->isScalar()) - dLdw->assign(E.reduceNumber(reduce::Sum)); - else if(weights != weightsBroad) { - std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), weightsBroad->shapeInfo()); - E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); - } - else - dLdw->assign(E); - break; - } - case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array - - NDArray sum; - sum.setContext(block.launchContext()); - if (weights->isScalar()) - sum = (*weights) * E.lengthOf(); - else - sum = weightsBroad->reduceNumber(reduce::Sum); - - if (sum.e(0) == 0.) { - *dLdp = 0.; - *dLdl = 0.; - *dLdw = 0.; - } - else { - - *dLdp *= *weightsBroad / sum; - *dLdl *= *weightsBroad / sum; - - if(weights->isScalar()) - *dLdw = 0.; - else if(weights != weightsBroad) { - std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), weightsBroad->shapeInfo()); - ((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)).reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); - } - else - dLdw->assign((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)); - } - break; - } - case 3: { // 3 - "weighted_sum_by_nonzero_weights", output is scalar and equal to scalar sum of all elements of E array divided by number of non-zero weights - - Nd4jLong numOfNonZeroWeights = 0; - if(weights->isScalar()) { - if(weights->e(0) != 0.) - numOfNonZeroWeights = E.lengthOf(); - } - else - numOfNonZeroWeights = weightsBroad->reduceNumber(reduce::CountNonZero).e(0); - - if (numOfNonZeroWeights == 0) { - *dLdp = 0.; - *dLdl = 0.; - *dLdw = 0.; - } - else { - auto numOfNonZeroWeightsScalar = NDArrayFactory::create(dLdw->dataType(), numOfNonZeroWeights, block.launchContext()); - - if(weights->isScalar()) - dLdw->assign(E.reduceNumber(reduce::Sum) / double(numOfNonZeroWeights)); - else if(weights != weightsBroad) { - std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), weightsBroad->shapeInfo()); - E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); - *dLdw /= numOfNonZeroWeightsScalar; - } - else - dLdw->assign(E / numOfNonZeroWeightsScalar); - - NDArray temp = *weightsBroad / numOfNonZeroWeightsScalar; - *dLdp *= temp; - *dLdl *= temp; - } - break; - } - } - - if(weightsBroad != weights) - delete weightsBroad; - - return Status::OK(); + case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all + // elements of E array divided by sum of all elements of + // weightsBroad array + + NDArray sum; + sum.setContext(block.launchContext()); + if (weights->isScalar()) + sum = (*weights) * E.lengthOf(); + else + sum = weightsBroad->reduceNumber(reduce::Sum); + + if (sum.e(0) == 0.) { + *dLdp = 0.; + *dLdl = 0.; + *dLdw = 0.; + } else { + *dLdp *= *weightsBroad / sum; + *dLdl *= *weightsBroad / sum; + + if (weights->isScalar()) + *dLdw = 0.; + else if (weights != weightsBroad) { + std::vector axesToReduceAlong = + ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), + weightsBroad->shapeInfo()); + ((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / + (sum * sum)) + .reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, + false, false); + } else + dLdw->assign( + (E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / + (sum * sum)); + } + break; } - - DECLARE_TYPES(log_poisson_loss_grad) { - - getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setAllowedOutputTypes({ALL_FLOATS}); + case 3: { // 3 - "weighted_sum_by_nonzero_weights", output is scalar and + // equal to scalar sum of all elements of E array divided by + // number of non-zero weights + + Nd4jLong numOfNonZeroWeights = 0; + if (weights->isScalar()) { + if (weights->e(0) != 0.) numOfNonZeroWeights = E.lengthOf(); + } else + numOfNonZeroWeights = + weightsBroad->reduceNumber(reduce::CountNonZero).e(0); + + if (numOfNonZeroWeights == 0) { + *dLdp = 0.; + *dLdl = 0.; + *dLdw = 0.; + } else { + auto numOfNonZeroWeightsScalar = NDArrayFactory::create( + dLdw->dataType(), numOfNonZeroWeights, block.launchContext()); + + if (weights->isScalar()) + dLdw->assign(E.reduceNumber(reduce::Sum) / + double(numOfNonZeroWeights)); + else if (weights != weightsBroad) { + std::vector axesToReduceAlong = + ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), + weightsBroad->shapeInfo()); + E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, + false, false); + *dLdw /= numOfNonZeroWeightsScalar; + } else + dLdw->assign(E / numOfNonZeroWeightsScalar); + + NDArray temp = *weightsBroad / numOfNonZeroWeightsScalar; + *dLdp *= temp; + *dLdl *= temp; + } + break; } + } - DECLARE_SHAPE_FN(log_poisson_loss_grad) { - - auto predictionsShapeInfo = inputShape->at(0); - auto weightsShapeInfo = inputShape->at(1); - auto labelsShapeInfo = inputShape->at(2); - - // labels and predictions must have the same shapes - REQUIRE_TRUE(shape::shapeEquals(labelsShapeInfo, predictionsShapeInfo), 0, "LOG_POISSON_LOSS_GRAD OP: labels and predictions arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), ShapeUtils::shapeAsString(predictionsShapeInfo).c_str()); - // weights array can be single scalar or has the same rank as labels, and must be broadcastable to labels - REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || shape::rank(weightsShapeInfo) == shape::rank(labelsShapeInfo), 0, "LOG_POISSON_LOSS_GRAD OP: weights array should be scalar or have the same rank as labels array, but got %i and %i correspondingly!", shape::rank(weightsShapeInfo), shape::rank(labelsShapeInfo)); - // check whether broadcast operation is possible for weights array - REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || ShapeUtils::areShapesBroadcastable(weightsShapeInfo, labelsShapeInfo), 0, "LOG_POISSON_LOSS_GRAD OP: shapes of weights and labels arrays should be broadcastable, but got weights = %s and labels = %s instead!", ShapeUtils::shapeAsString(weightsShapeInfo).c_str(), ShapeUtils::shapeAsString(labelsShapeInfo).c_str()); + if (weightsBroad != weights) delete weightsBroad; - DataType outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(predictionsShapeInfo)); - - auto dLdpShapeInfo = ShapeBuilders::copyShapeInfoAndType(predictionsShapeInfo, outType, false, block.getWorkspace()); - auto dLdwShapeInfo = ShapeBuilders::copyShapeInfoAndType(weightsShapeInfo, outType, false, block.getWorkspace()); - auto dLdlShapeInfo = ShapeBuilders::copyShapeInfoAndType(labelsShapeInfo, outType, false, block.getWorkspace()); + return Status::OK(); +} - return SHAPELIST(dLdpShapeInfo, dLdwShapeInfo, dLdlShapeInfo); - } +DECLARE_TYPES(log_poisson_loss_grad) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } + +DECLARE_SHAPE_FN(log_poisson_loss_grad) { + auto predictionsShapeInfo = inputShape->at(0); + auto weightsShapeInfo = inputShape->at(1); + auto labelsShapeInfo = inputShape->at(2); + + // labels and predictions must have the same shapes + REQUIRE_TRUE(shape::shapeEquals(labelsShapeInfo, predictionsShapeInfo), 0, + "LOG_POISSON_LOSS_GRAD OP: labels and predictions arrays must " + "have the same shapes, but got %s and %s correspondingly !", + ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), + ShapeUtils::shapeAsString(predictionsShapeInfo).c_str()); + // weights array can be single scalar or has the same rank as labels, and must + // be broadcastable to labels + REQUIRE_TRUE( + shape::isScalar(weightsShapeInfo) || + shape::rank(weightsShapeInfo) == shape::rank(labelsShapeInfo), + 0, + "LOG_POISSON_LOSS_GRAD OP: weights array should be scalar or have the " + "same rank as labels array, but got %i and %i correspondingly!", + shape::rank(weightsShapeInfo), shape::rank(labelsShapeInfo)); + // check whether broadcast operation is possible for weights array + REQUIRE_TRUE( + shape::isScalar(weightsShapeInfo) || + ShapeUtils::areShapesBroadcastable(weightsShapeInfo, labelsShapeInfo), + 0, + "LOG_POISSON_LOSS_GRAD OP: shapes of weights and labels arrays should be " + "broadcastable, but got weights = %s and labels = %s instead!", + ShapeUtils::shapeAsString(weightsShapeInfo).c_str(), + ShapeUtils::shapeAsString(labelsShapeInfo).c_str()); + + DataType outType = DataTypeUtils::pickFloatingType( + ArrayOptions::dataType(predictionsShapeInfo)); + + auto dLdpShapeInfo = ShapeBuilders::copyShapeInfoAndType( + predictionsShapeInfo, outType, false, block.workspace()); + auto dLdwShapeInfo = ShapeBuilders::copyShapeInfoAndType( + weightsShapeInfo, outType, false, block.workspace()); + auto dLdlShapeInfo = ShapeBuilders::copyShapeInfoAndType( + labelsShapeInfo, outType, false, block.workspace()); + + return SHAPELIST(dLdpShapeInfo, dLdwShapeInfo, dLdlShapeInfo); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/loss/meanPairWsSqErr.cpp b/libnd4j/include/ops/declarable/generic/loss/meanPairWsSqErr.cpp index f36fa3c62ab6..6e3235376883 100644 --- a/libnd4j/include/ops/declarable/generic/loss/meanPairWsSqErr.cpp +++ b/libnd4j/include/ops/declarable/generic/loss/meanPairWsSqErr.cpp @@ -26,357 +26,471 @@ #if NOT_EXCLUDED(OP_mean_pairwssqerr_loss) #include -#include + #include +#include namespace sd { - namespace ops { - +namespace ops { ////////////////////////////////////////////////////////////////////////// - CUSTOM_OP_IMPL(mean_pairwssqerr_loss, 3, 1, false, 0, 1) { - /* - * Implementation of mean pairwise squared error loss - * - * For context on where this loss function may be useful see: - * - * Wei, Z., Zhang, J., Shen, X., Lin, Z., Mech, R., Hoai, M. and Samaras, D., 2018. - * Good view hunting: learning photo composition from dense view pairs. In Proceedings of the IEEE Conference on - * Computer Vision and Pattern Recognition (pp. 5437-5446). - * - * The paper defines the loss function as: - * - * L(y,q) = 1/((n*(n-1))/2) * (sum_(i,j=1..n,i!=j)((y_i - y_j) - (q_i - q_j))^2) - * - * with y: predictions, q: labels, n: length of y and q - * - * As creating those pairs is computationally expensive, we implement a mathematically equivalent function: - * - * L(y,q) = 4/(n*(n-1)) * (n * sum (y_i - q_i)^2 - (sum y_i - q_i)^2) - * - * This equivalency can be derived as: - * - * sum_(i,j=1..n,i!=j)((y_i - y_j) - (q_i - q_j))^2 = sum_(i,j=1..n,i!=j)((y_i - q_i) - (y_j - q_j))^2 - * - * To simplify the following equations we use - * - * sum_(i,j=1..n,i!=j)(d_i - d_j)^2 = sum_(i,j=1..n,i!=j)(d_i^2 + d_j^2 - 2*d_i*d_j) - * - * Due to the pairings each element will appear as both d_i and d_j exactly n-1 times. This allows us to split the sum: - * - * sum_(i,j=1..n,i!=j)(d_i^2 + d_j^2 - 2*d_i*d_j) = 2*(n-1)*sum d_i^2 - 2 * sum_(i,j=1..n,i!=j) d_i * d_j - * = 2*((n-1) * sum d_i^2 - sum_(i,j=1..n,i!=j) d_i * d_j) - * - * Now we use the following equivalency: - * - * (sum d_i)^2 = sum d_i^2 + sum_(i,j=1..n,i!=j) d_i * d_j - * - * This allows us to now use sum d_i^2 and (sum d_i)^2 as a quick way to calculate the sum: - * - * (n-1) * sum d_i^2 - sum_(i,j=1..n,i!=j) d_i * d_j = n * sum d_i^2 - (sum d_i)^2 - * - * And by substituting it into the original definition we get: - * - * 1/((n*(n-1))/2) * 2*(n * sum d_i^2 - (sum d_i)^2) - * - * Which can be again simplified to - * - * 4/(n*(n-1)) * (n * sum d_i^2 - (sum d_i)^2) - * - * After substituting d_i back to (y_i - q_i) this results in the function that we actually implement. - * - */ - auto predictions = INPUT_VARIABLE(0); - auto weights = INPUT_VARIABLE(1); - auto labels = INPUT_VARIABLE(2); - - auto output = OUTPUT_VARIABLE(0); - - int reductionMode = INT_ARG(0); // 0 - "none"; 1 - "weighted_sum"; 2 - "weighted_mean"; 3 - "weighted_sum_by_nonzero_weights" - - - // input validation - REQUIRE_TRUE(labels->isSameShape(predictions), 0, - "MEAN_PAIRWSSQERR_LOSS OP: labels and predictions arrays must have the same shapes, but got %s and %s correspondingly !", - ShapeUtils::shapeAsString(labels).c_str(), ShapeUtils::shapeAsString(predictions).c_str()); - // only 4 possible reduction modes exist - REQUIRE_TRUE(reductionMode==0 || reductionMode==1 || reductionMode==2 || reductionMode==3, 0, "MEAN_PAIRWSSQERR_LOSS OP: reduction mode value is not acceptable, possible values are 0, 1, 2, 3, but got %i instead!", reductionMode); - - if (labels->rankOf() == 1) { // If labels and predictions are of rank 1, it means that all data entries are 0-tensor (scalar) so that the result of becomes always zero. - *output = 0.; - return Status::OK(); - } - - std::vector reductionIdx = ShapeUtils::evalDimsToExclude(labels->rankOf(), {0}); - - auto n = double(labels->sizeAt(1)); - auto diffs = *predictions - *labels; - - auto sumOfSquares = (diffs * diffs).reduceAlongDimension(reduce::Sum, reductionIdx, true); - - auto squareOfSum = diffs.reduceAlongDimension(reduce::Sum, reductionIdx, true); - squareOfSum.applyScalar(scalar::Pow, 2, squareOfSum); - - - auto E = ((sumOfSquares * n) - squareOfSum) * (4/(n*(n-1))); - - // weights array can be single scalar or has the same rank as labels, and must be broadcastable to labels - REQUIRE_TRUE(weights->isScalar() || weights->rankOf() == E.rankOf(), 0, "MEAN_PAIRWSSQERR_LOSS_GRAD OP: weights array should be scalar or have the same rank as results array, but got %i and %i correspondingly!", weights->rankOf(), E.rankOf()); - // check whether broadcast operation is possible for weights array - REQUIRE_TRUE(weights->isScalar() || ShapeUtils::areShapesBroadcastable(*weights, E), 0, "MEAN_PAIRWSSQERR_LOSS_GRAD OP: shapes of weights and labels arrays should be broadcastable, but got weights = %s and results = %s instead!", ShapeUtils::shapeAsString(weights).c_str(), ShapeUtils::shapeAsString(&E).c_str()); - - // perform weights broadcasting/tile to labels if needed - auto weightsBroad = weights; - if(!weights->isScalar() && !weights->isSameShape(E)) - weightsBroad = new NDArray(weights->tileToShape(E.shapeInfo())); - - E *= *weightsBroad; - - switch (reductionMode) { - case 0: // 0 - "none", un-reduced weighted losses with the same shape as labels. - output->assign(E); - break; - - case 1: { // 1 - "weighted_sum", output is scalar and equal to sum of all elements of E array - E.reduceNumber(reduce::Sum, *output); - break; - } - case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array - NDArray sum; - sum.setContext(block.launchContext()); - if (weights->isScalar()) - sum = (*weights) * E.lengthOf(); - else - sum = weightsBroad->reduceNumber(reduce::Sum); - - if (sum.e(0) == 0.) - (*output) = 0.; - else - output->assign(E.reduceNumber(reduce::Sum) / sum); - break; - } - case 3: { // 3 - "weighted_sum_by_nonzero_weights", output is scalar and equal to scalar sum of all elements of E array divided by number of non-zero weights - Nd4jLong numOfNonZeroWeights = 0; - if(weights->isScalar()) { - if(weights->e(0) != 0.) - numOfNonZeroWeights = E.lengthOf(); - } - else { - numOfNonZeroWeights = weightsBroad->reduceNumber(reduce::CountNonZero).e(0); - } - - if (numOfNonZeroWeights == 0) - (*output) = 0.; - else - output->assign(E.reduceNumber(reduce::Sum) / double(numOfNonZeroWeights)); - break; - } - } - - if (weightsBroad != weights) - delete weightsBroad; - - return Status::OK(); - } +CUSTOM_OP_IMPL(mean_pairwssqerr_loss, 3, 1, false, 0, 1) { + /* + * Implementation of mean pairwise squared error loss + * + * For context on where this loss function may be useful see: + * + * Wei, Z., Zhang, J., Shen, X., Lin, Z., Mech, R., Hoai, M. and Samaras, D., + * 2018. Good view hunting: learning photo composition from dense view pairs. + * In Proceedings of the IEEE Conference on Computer Vision and Pattern + * Recognition (pp. 5437-5446). + * + * The paper defines the loss function as: + * + * L(y,q) = 1/((n*(n-1))/2) * (sum_(i,j=1..n,i!=j)((y_i - y_j) - (q_i - + * q_j))^2) + * + * with y: predictions, q: labels, n: length of y and q + * + * As creating those pairs is computationally expensive, we implement a + * mathematically equivalent function: + * + * L(y,q) = 4/(n*(n-1)) * (n * sum (y_i - q_i)^2 - (sum y_i - q_i)^2) + * + * This equivalency can be derived as: + * + * sum_(i,j=1..n,i!=j)((y_i - y_j) - (q_i - q_j))^2 = sum_(i,j=1..n,i!=j)((y_i + * - q_i) - (y_j - q_j))^2 + * + * To simplify the following equations we use + * + * sum_(i,j=1..n,i!=j)(d_i - d_j)^2 = sum_(i,j=1..n,i!=j)(d_i^2 + d_j^2 - + * 2*d_i*d_j) + * + * Due to the pairings each element will appear as both d_i and d_j exactly + * n-1 times. This allows us to split the sum: + * + * sum_(i,j=1..n,i!=j)(d_i^2 + d_j^2 - 2*d_i*d_j) = 2*(n-1)*sum d_i^2 - 2 * + * sum_(i,j=1..n,i!=j) d_i * d_j = 2*((n-1) * sum d_i^2 - sum_(i,j=1..n,i!=j) + * d_i * d_j) + * + * Now we use the following equivalency: + * + * (sum d_i)^2 = sum d_i^2 + sum_(i,j=1..n,i!=j) d_i * d_j + * + * This allows us to now use sum d_i^2 and (sum d_i)^2 as a quick way to + * calculate the sum: + * + * (n-1) * sum d_i^2 - sum_(i,j=1..n,i!=j) d_i * d_j = n * sum d_i^2 - (sum + * d_i)^2 + * + * And by substituting it into the original definition we get: + * + * 1/((n*(n-1))/2) * 2*(n * sum d_i^2 - (sum d_i)^2) + * + * Which can be again simplified to + * + * 4/(n*(n-1)) * (n * sum d_i^2 - (sum d_i)^2) + * + * After substituting d_i back to (y_i - q_i) this results in the function + * that we actually implement. + * + */ + auto predictions = INPUT_VARIABLE(0); + auto weights = INPUT_VARIABLE(1); + auto labels = INPUT_VARIABLE(2); + + auto output = OUTPUT_VARIABLE(0); + + int reductionMode = + INT_ARG(0); // 0 - "none"; 1 - "weighted_sum"; 2 - "weighted_mean"; 3 - + // "weighted_sum_by_nonzero_weights" + + // input validation + REQUIRE_TRUE(labels->isSameShape(predictions), 0, + "MEAN_PAIRWSSQERR_LOSS OP: labels and predictions arrays must " + "have the same shapes, but got %s and %s correspondingly !", + ShapeUtils::shapeAsString(labels).c_str(), + ShapeUtils::shapeAsString(predictions).c_str()); + // only 4 possible reduction modes exist + REQUIRE_TRUE( + reductionMode == 0 || reductionMode == 1 || reductionMode == 2 || + reductionMode == 3, + 0, + "MEAN_PAIRWSSQERR_LOSS OP: reduction mode value is not acceptable, " + "possible values are 0, 1, 2, 3, but got %i instead!", + reductionMode); + + if (labels->rankOf() == + 1) { // If labels and predictions are of rank 1, it means that all data + // entries are 0-tensor (scalar) so that the result of becomes + // always zero. + *output = 0.; + return Status::OK(); + } + + std::vector reductionIdx = + ShapeUtils::evalDimsToExclude(labels->rankOf(), {0}); + + auto n = double(labels->sizeAt(1)); + auto diffs = *predictions - *labels; + + auto sumOfSquares = + (diffs * diffs).reduceAlongDimension(reduce::Sum, reductionIdx, true); + + auto squareOfSum = + diffs.reduceAlongDimension(reduce::Sum, reductionIdx, true); + squareOfSum.applyScalar(scalar::Pow, 2, squareOfSum); + + auto E = ((sumOfSquares * n) - squareOfSum) * (4 / (n * (n - 1))); + + // weights array can be single scalar or has the same rank as labels, and must + // be broadcastable to labels + REQUIRE_TRUE( + weights->isScalar() || weights->rankOf() == E.rankOf(), 0, + "MEAN_PAIRWSSQERR_LOSS_GRAD OP: weights array should be scalar or have " + "the same rank as results array, but got %i and %i correspondingly!", + weights->rankOf(), E.rankOf()); + // check whether broadcast operation is possible for weights array + REQUIRE_TRUE( + weights->isScalar() || ShapeUtils::areShapesBroadcastable(*weights, E), 0, + "MEAN_PAIRWSSQERR_LOSS_GRAD OP: shapes of weights and labels arrays " + "should be broadcastable, but got weights = %s and results = %s instead!", + ShapeUtils::shapeAsString(weights).c_str(), + ShapeUtils::shapeAsString(&E).c_str()); + + // perform weights broadcasting/tile to labels if needed + auto weightsBroad = weights; + if (!weights->isScalar() && !weights->isSameShape(E)) + weightsBroad = new NDArray(weights->tileToShape(E.shapeInfo())); + + E *= *weightsBroad; + + switch (reductionMode) { + case 0: // 0 - "none", un-reduced weighted losses with the same shape as + // labels. + output->assign(E); + break; + + case 1: { // 1 - "weighted_sum", output is scalar and equal to sum of all + // elements of E array + E.reduceNumber(reduce::Sum, *output); + break; + } + case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all + // elements of E array divided by sum of all elements of + // weightsBroad array + NDArray sum; + sum.setContext(block.launchContext()); + if (weights->isScalar()) + sum = (*weights) * E.lengthOf(); + else + sum = weightsBroad->reduceNumber(reduce::Sum); + + if (sum.e(0) == 0.) + (*output) = 0.; + else + output->assign(E.reduceNumber(reduce::Sum) / sum); + break; + } + case 3: { // 3 - "weighted_sum_by_nonzero_weights", output is scalar and + // equal to scalar sum of all elements of E array divided by + // number of non-zero weights + Nd4jLong numOfNonZeroWeights = 0; + if (weights->isScalar()) { + if (weights->e(0) != 0.) numOfNonZeroWeights = E.lengthOf(); + } else { + numOfNonZeroWeights = + weightsBroad->reduceNumber(reduce::CountNonZero).e(0); + } + + if (numOfNonZeroWeights == 0) + (*output) = 0.; + else + output->assign(E.reduceNumber(reduce::Sum) / + double(numOfNonZeroWeights)); + break; + } + } + + if (weightsBroad != weights) delete weightsBroad; + + return Status::OK(); +} ////////////////////////////////////////////////////////////////////////// - DECLARE_TYPES(mean_pairwssqerr_loss) { +DECLARE_TYPES(mean_pairwssqerr_loss) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} - getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setAllowedOutputTypes({ALL_FLOATS}); - } +////////////////////////////////////////////////////////////////////////// +DECLARE_SHAPE_FN(mean_pairwssqerr_loss) { + auto predictionsShapeInfo = inputShape->at(0); + auto weightsShapeInfo = inputShape->at(1); + auto labelsShapeInfo = inputShape->at(2); + + REQUIRE_TRUE(shape::shapeEquals(labelsShapeInfo, predictionsShapeInfo), 0, + "MEAN_PAIRWSSQERR_LOSS OP: labels and predictions arrays must " + "have the same shapes, but got %s and %s correspondingly !", + ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), + ShapeUtils::shapeAsString(predictionsShapeInfo).c_str()); + DataType outType = DataTypeUtils::pickFloatingType( + ArrayOptions::dataType(predictionsShapeInfo)); + Nd4jLong const *outShapeInfo = nullptr; + + if (INT_ARG(0) != 0) // in this case output is scalar + outShapeInfo = ConstantShapeHelper::getInstance().scalarShapeInfo(outType); + else { // in this case output has the shape as labels and logits minus last + // dimension + std::vector dimensions = {-1}; + outShapeInfo = ShapeUtils::evalReduceShapeInfo( + shape::order(predictionsShapeInfo), dimensions, predictionsShapeInfo, + false, true, block.workspace()); + + // weights array can be single scalar or has the same rank as output, and + // must be broadcastable to output + REQUIRE_TRUE( + shape::isScalar(weightsShapeInfo) || + shape::rank(weightsShapeInfo) == shape::rank(outShapeInfo), + 0, + "MEAN_PAIRWSSQERR_LOSS OP: weights array should be scalar or have the " + "same rank as output array, but got %i and %i correspondingly!", + shape::rank(weightsShapeInfo), shape::rank(outShapeInfo)); + // check whether broadcast operation is possible for weights array + REQUIRE_TRUE( + shape::isScalar(weightsShapeInfo) || + ShapeUtils::areShapesBroadcastable(weightsShapeInfo, outShapeInfo), + 0, + "MEAN_PAIRWSSQERR_LOSS OP: shapes of weights and output arrays should " + "be broadcastable, but got weights = %s and output = %s instead!", + ShapeUtils::shapeAsString(weightsShapeInfo).c_str(), + ShapeUtils::shapeAsString(outShapeInfo).c_str()); + } + + return SHAPELIST(outShapeInfo); +} ////////////////////////////////////////////////////////////////////////// - DECLARE_SHAPE_FN(mean_pairwssqerr_loss) { +CUSTOM_OP_IMPL(mean_pairwssqerr_loss_grad, 3, 3, false, 0, 1) { + auto predictions = INPUT_VARIABLE(0); + auto weights = INPUT_VARIABLE(1); + auto labels = INPUT_VARIABLE(2); + + auto dLdp = OUTPUT_VARIABLE(0); // dL/dpredictions + auto dLdw = OUTPUT_VARIABLE(1); // dL/dweights + auto dLdl = OUTPUT_VARIABLE(2); // dL/dlabels + + int reductionMode = + INT_ARG(0); // 0 - "none"; 1 - "weighted_sum"; 2 - "weighted_mean"; 3 - + // "weighted_sum_by_nonzero_weights" + // take into account Alex's proposition to treat "none" the same as + // "weighted_sum" mode when calculating gradients + if (reductionMode == 0) reductionMode = 1; + + // inputs validation + REQUIRE_TRUE(labels->isSameShape(predictions), 0, + "MEAN_PAIRWSSQERR_LOSS_GRAD OP: labels and predictions arrays " + "must have the same shapes, but got %s and %s correspondingly !", + ShapeUtils::shapeAsString(labels).c_str(), + ShapeUtils::shapeAsString(predictions).c_str()); + // only 4 possible reduction modes exist + REQUIRE_TRUE( + reductionMode == 0 || reductionMode == 1 || reductionMode == 2 || + reductionMode == 3, + 0, + "MEAN_PAIRWSSQERR_LOSS_GRAD OP: reduction mode value is not acceptable, " + "possible values are 0, 1, 2, 3, but got %i instead!", + reductionMode); + + auto n = double(labels->sizeAt(1)); + auto diffs = *predictions - *labels; + + std::vector reductionIdx = + ShapeUtils::evalDimsToExclude(labels->rankOf(), {0}); + auto sumOfSquares = + (diffs * diffs).reduceAlongDimension(reduce::Sum, reductionIdx, true); + + auto squareOfSum = + diffs.reduceAlongDimension(reduce::Sum, reductionIdx, true); + squareOfSum.applyScalar(scalar::Pow, 2, squareOfSum); + + auto E = ((sumOfSquares * n) - squareOfSum) * (4 / (n * (n - 1))); + + auto sumPred = + predictions->reduceAlongDimension(reduce::Sum, reductionIdx, true); + auto sumLabel = labels->reduceAlongDimension(reduce::Sum, reductionIdx, true); + + dLdp->assign(((diffs * n) - sumPred + sumLabel) * (8 / (n * (n - 1)))); + + // weights array can be single scalar or has the same rank as labels, and must + // be broadcastable to labels + REQUIRE_TRUE( + weights->isScalar() || weights->rankOf() == E.rankOf(), 0, + "MEAN_PAIRWSSQERR_LOSS_GRAD OP: weights array should be scalar or have " + "the same rank as results array, but got %i and %i correspondingly!", + weights->rankOf(), E.rankOf()); + // check whether broadcast operation is possible for weights array + REQUIRE_TRUE( + weights->isScalar() || ShapeUtils::areShapesBroadcastable(*weights, E), 0, + "MEAN_PAIRWSSQERR_LOSS_GRAD OP: shapes of weights and labels arrays " + "should be broadcastable, but got weights = %s and results = %s instead!", + ShapeUtils::shapeAsString(weights).c_str(), + ShapeUtils::shapeAsString(&E).c_str()); + + // perform weights broadcasting/tile to labels if needed + auto weightsBroad = weights; + if (!weights->isScalar() && !weights->isSameShape(E)) + weightsBroad = new NDArray(weights->tileToShape(E.shapeInfo())); + + switch (reductionMode) { + case 1: { // 1 - "none" and "weighted_sum", output is scalar and equal to + // sum of all elements of E array + + *dLdp *= *weightsBroad; + + if (weights->isScalar()) + dLdw->assign(E.reduceNumber(reduce::Sum)); + else if (weights != weightsBroad) { + std::vector axesToReduceAlong = + ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), + weightsBroad->shapeInfo()); + E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, + false, false); + } else + dLdw->assign(E); + break; + } + case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all + // elements of E array divided by sum of all elements of + // weightsBroad array + + NDArray sum; + sum.setContext(block.launchContext()); + if (weights->isScalar()) + sum = (*weights) * E.lengthOf(); + else + sum = weightsBroad->reduceNumber(reduce::Sum); + + if (sum.e(0) == 0.) { + *dLdp = 0.; + *dLdw = 0.; + } else { + *dLdp *= *weightsBroad / sum; + + if (weights->isScalar()) + *dLdw = 0.; + else if (weights != weightsBroad) { + std::vector axesToReduceAlong = + ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), + weightsBroad->shapeInfo()); + ((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / + (sum * sum)) + .reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, + false, false); + } else + dLdw->assign( + (E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / + (sum * sum)); + } + break; + } + case 3: { // 3 - "weighted_sum_by_nonzero_weights", output is scalar and + // equal to scalar sum of all elements of E array divided by + // number of non-zero weights + + Nd4jLong numOfNonZeroWeights = 0; + if (weights->isScalar()) { + if (weights->e(0) != 0.) numOfNonZeroWeights = E.lengthOf(); + } else + numOfNonZeroWeights = + weightsBroad->reduceNumber(reduce::CountNonZero).e(0); + + if (numOfNonZeroWeights == 0) { + *dLdp = 0.; + *dLdw = 0.; + } else { + auto numOfNonZeroWeightsScalar = NDArrayFactory::create( + dLdw->dataType(), numOfNonZeroWeights, block.launchContext()); + + if (weights->isScalar()) + dLdw->assign(E.reduceNumber(reduce::Sum) / + double(numOfNonZeroWeights)); + else if (weights != weightsBroad) { + std::vector axesToReduceAlong = + ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), + weightsBroad->shapeInfo()); + E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, + false, false); + *dLdw /= numOfNonZeroWeightsScalar; + } else + dLdw->assign(E / numOfNonZeroWeightsScalar); + + NDArray temp = *weightsBroad / numOfNonZeroWeightsScalar; + *dLdp *= temp; + } + break; + } + } - auto predictionsShapeInfo = inputShape->at(0); - auto weightsShapeInfo = inputShape->at(1); - auto labelsShapeInfo = inputShape->at(2); + dLdl->assign(-*dLdp); - REQUIRE_TRUE(shape::shapeEquals(labelsShapeInfo, predictionsShapeInfo), 0, - "MEAN_PAIRWSSQERR_LOSS OP: labels and predictions arrays must have the same shapes, but got %s and %s correspondingly !", - ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), - ShapeUtils::shapeAsString(predictionsShapeInfo).c_str()); - DataType outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(predictionsShapeInfo)); - Nd4jLong const* outShapeInfo = nullptr; + if (weightsBroad != weights) delete weightsBroad; - if(INT_ARG(0) != 0) // in this case output is scalar - outShapeInfo = ConstantShapeHelper::getInstance().scalarShapeInfo(outType); - else { // in this case output has the shape as labels and logits minus last dimension - std::vector dimensions = {-1}; - outShapeInfo = ShapeUtils::evalReduceShapeInfo(shape::order(predictionsShapeInfo), dimensions, predictionsShapeInfo, false, true, block.getWorkspace()); + return Status::OK(); +} - // weights array can be single scalar or has the same rank as output, and must be broadcastable to output - REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || shape::rank(weightsShapeInfo) == shape::rank(outShapeInfo), 0, "MEAN_PAIRWSSQERR_LOSS OP: weights array should be scalar or have the same rank as output array, but got %i and %i correspondingly!", shape::rank(weightsShapeInfo), shape::rank(outShapeInfo)); - // check whether broadcast operation is possible for weights array - REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || ShapeUtils::areShapesBroadcastable(weightsShapeInfo, outShapeInfo), 0, "MEAN_PAIRWSSQERR_LOSS OP: shapes of weights and output arrays should be broadcastable, but got weights = %s and output = %s instead!", ShapeUtils::shapeAsString(weightsShapeInfo).c_str(), ShapeUtils::shapeAsString(outShapeInfo).c_str()); - } - - return SHAPELIST(outShapeInfo); - } - - - ////////////////////////////////////////////////////////////////////////// - CUSTOM_OP_IMPL(mean_pairwssqerr_loss_grad, 3, 3, false, 0, 1) { - - auto predictions = INPUT_VARIABLE(0); - auto weights = INPUT_VARIABLE(1); - auto labels = INPUT_VARIABLE(2); - - auto dLdp = OUTPUT_VARIABLE(0); // dL/dpredictions - auto dLdw = OUTPUT_VARIABLE(1); // dL/dweights - auto dLdl = OUTPUT_VARIABLE(2); // dL/dlabels - - - int reductionMode = INT_ARG(0); // 0 - "none"; 1 - "weighted_sum"; 2 - "weighted_mean"; 3 - "weighted_sum_by_nonzero_weights" - // take into account Alex's proposition to treat "none" the same as "weighted_sum" mode when calculating gradients - if(reductionMode == 0) - reductionMode = 1; - - // inputs validation - REQUIRE_TRUE(labels->isSameShape(predictions), 0, "MEAN_PAIRWSSQERR_LOSS_GRAD OP: labels and predictions arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labels).c_str(), ShapeUtils::shapeAsString(predictions).c_str()); - // only 4 possible reduction modes exist - REQUIRE_TRUE(reductionMode==0 || reductionMode==1 || reductionMode==2 || reductionMode==3, 0, "MEAN_PAIRWSSQERR_LOSS_GRAD OP: reduction mode value is not acceptable, possible values are 0, 1, 2, 3, but got %i instead!", reductionMode); - - auto n = double(labels->sizeAt(1)); - auto diffs = *predictions - *labels; - - std::vector reductionIdx = ShapeUtils::evalDimsToExclude(labels->rankOf(), {0}); - auto sumOfSquares = (diffs * diffs).reduceAlongDimension(reduce::Sum, reductionIdx, true); - - auto squareOfSum = diffs.reduceAlongDimension(reduce::Sum, reductionIdx, true); - squareOfSum.applyScalar(scalar::Pow, 2, squareOfSum); - - auto E = ((sumOfSquares * n) - squareOfSum) * (4/(n*(n-1))); - - auto sumPred = predictions->reduceAlongDimension(reduce::Sum, reductionIdx, true); - auto sumLabel = labels->reduceAlongDimension(reduce::Sum, reductionIdx, true); - - dLdp->assign(((diffs * n) - sumPred + sumLabel)*(8/(n*(n-1)))); - - - // weights array can be single scalar or has the same rank as labels, and must be broadcastable to labels - REQUIRE_TRUE(weights->isScalar() || weights->rankOf() == E.rankOf(), 0, "MEAN_PAIRWSSQERR_LOSS_GRAD OP: weights array should be scalar or have the same rank as results array, but got %i and %i correspondingly!", weights->rankOf(), E.rankOf()); - // check whether broadcast operation is possible for weights array - REQUIRE_TRUE(weights->isScalar() || ShapeUtils::areShapesBroadcastable(*weights, E), 0, "MEAN_PAIRWSSQERR_LOSS_GRAD OP: shapes of weights and labels arrays should be broadcastable, but got weights = %s and results = %s instead!", ShapeUtils::shapeAsString(weights).c_str(), ShapeUtils::shapeAsString(&E).c_str()); - - // perform weights broadcasting/tile to labels if needed - auto weightsBroad = weights; - if(!weights->isScalar() && !weights->isSameShape(E)) - weightsBroad = new NDArray(weights->tileToShape(E.shapeInfo())); - - switch (reductionMode) { - - case 1: { // 1 - "none" and "weighted_sum", output is scalar and equal to sum of all elements of E array - - *dLdp *= *weightsBroad; - - if(weights->isScalar()) - dLdw->assign(E.reduceNumber(reduce::Sum)); - else if(weights != weightsBroad) { - std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), weightsBroad->shapeInfo()); - E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); - } - else - dLdw->assign(E); - break; - } - case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array - - NDArray sum; - sum.setContext(block.launchContext()); - if (weights->isScalar()) - sum = (*weights) * E.lengthOf(); - else - sum = weightsBroad->reduceNumber(reduce::Sum); - - if (sum.e(0) == 0.) { - *dLdp = 0.; - *dLdw = 0.; - } - else { - - *dLdp *= *weightsBroad / sum; - - if(weights->isScalar()) - *dLdw = 0.; - else if(weights != weightsBroad) { - std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), weightsBroad->shapeInfo()); - ((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)).reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); - } - else - dLdw->assign((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)); - } - break; - } - case 3: { // 3 - "weighted_sum_by_nonzero_weights", output is scalar and equal to scalar sum of all elements of E array divided by number of non-zero weights - - Nd4jLong numOfNonZeroWeights = 0; - if(weights->isScalar()) { - if(weights->e(0) != 0.) - numOfNonZeroWeights = E.lengthOf(); - } - else - numOfNonZeroWeights = weightsBroad->reduceNumber(reduce::CountNonZero).e(0); - - if (numOfNonZeroWeights == 0) { - *dLdp = 0.; - *dLdw = 0.; - } - else { - auto numOfNonZeroWeightsScalar = NDArrayFactory::create(dLdw->dataType(), numOfNonZeroWeights, block.launchContext()); - - if(weights->isScalar()) - dLdw->assign(E.reduceNumber(reduce::Sum) / double(numOfNonZeroWeights)); - else if(weights != weightsBroad) { - std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), weightsBroad->shapeInfo()); - E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); - *dLdw /= numOfNonZeroWeightsScalar; - } - else - dLdw->assign(E / numOfNonZeroWeightsScalar); - - NDArray temp = *weightsBroad / numOfNonZeroWeightsScalar; - *dLdp *= temp; - } - break; - } - } - - dLdl->assign(-*dLdp); - - if(weightsBroad != weights) - delete weightsBroad; - - return Status::OK(); - } - - DECLARE_TYPES(mean_pairwssqerr_loss_grad) { - getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setAllowedOutputTypes({ALL_FLOATS}); - } - - DECLARE_SHAPE_FN(mean_pairwssqerr_loss_grad) { - - auto predictionsShapeInfo = inputShape->at(0); - auto weightsShapeInfo = inputShape->at(1); - auto labelsShapeInfo = inputShape->at(2); - - // labels and predictions must have the same shapes - REQUIRE_TRUE(shape::shapeEquals(labelsShapeInfo, predictionsShapeInfo), 0, "MEAN_PAIRWSSQERR_LOSS_GRAD OP: labels and predictions arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), ShapeUtils::shapeAsString(predictionsShapeInfo).c_str()); - // weights array can be single scalar or has the same rank as labels, and must be broadcastable to labels - REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || shape::rank(weightsShapeInfo) == shape::rank(labelsShapeInfo), 0, "MEAN_PAIRWSSQERR_LOSS_GRAD OP: weights array should be scalar or have the same rank as labels array, but got %i and %i correspondingly!", shape::rank(weightsShapeInfo), shape::rank(labelsShapeInfo)); - // check whether broadcast operation is possible for weights array - REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || ShapeUtils::areShapesBroadcastable(weightsShapeInfo, labelsShapeInfo), 0, "MEAN_PAIRWSSQERR_LOSS_GRAD OP: shapes of weights and labels arrays should be broadcastable, but got weights = %s and labels = %s instead!", ShapeUtils::shapeAsString(weightsShapeInfo).c_str(), ShapeUtils::shapeAsString(labelsShapeInfo).c_str()); - - DataType outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(predictionsShapeInfo)); - - Nd4jLong *dLdpShapeInfo = ShapeBuilders::copyShapeInfoAndType(predictionsShapeInfo, outType, false, block.getWorkspace()); - Nd4jLong *dLdwShapeInfo = ShapeBuilders::copyShapeInfoAndType(weightsShapeInfo, outType, false, block.getWorkspace()); - Nd4jLong *dLdlShapeInfo = ShapeBuilders::copyShapeInfoAndType(labelsShapeInfo, outType, false, block.getWorkspace()); - - return SHAPELIST(dLdpShapeInfo, dLdwShapeInfo, dLdlShapeInfo); - } - } +DECLARE_TYPES(mean_pairwssqerr_loss_grad) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} + +DECLARE_SHAPE_FN(mean_pairwssqerr_loss_grad) { + auto predictionsShapeInfo = inputShape->at(0); + auto weightsShapeInfo = inputShape->at(1); + auto labelsShapeInfo = inputShape->at(2); + + // labels and predictions must have the same shapes + REQUIRE_TRUE(shape::shapeEquals(labelsShapeInfo, predictionsShapeInfo), 0, + "MEAN_PAIRWSSQERR_LOSS_GRAD OP: labels and predictions arrays " + "must have the same shapes, but got %s and %s correspondingly !", + ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), + ShapeUtils::shapeAsString(predictionsShapeInfo).c_str()); + // weights array can be single scalar or has the same rank as labels, and must + // be broadcastable to labels + REQUIRE_TRUE( + shape::isScalar(weightsShapeInfo) || + shape::rank(weightsShapeInfo) == shape::rank(labelsShapeInfo), + 0, + "MEAN_PAIRWSSQERR_LOSS_GRAD OP: weights array should be scalar or have " + "the same rank as labels array, but got %i and %i correspondingly!", + shape::rank(weightsShapeInfo), shape::rank(labelsShapeInfo)); + // check whether broadcast operation is possible for weights array + REQUIRE_TRUE( + shape::isScalar(weightsShapeInfo) || + ShapeUtils::areShapesBroadcastable(weightsShapeInfo, labelsShapeInfo), + 0, + "MEAN_PAIRWSSQERR_LOSS_GRAD OP: shapes of weights and labels arrays " + "should be broadcastable, but got weights = %s and labels = %s instead!", + ShapeUtils::shapeAsString(weightsShapeInfo).c_str(), + ShapeUtils::shapeAsString(labelsShapeInfo).c_str()); + + DataType outType = DataTypeUtils::pickFloatingType( + ArrayOptions::dataType(predictionsShapeInfo)); + + Nd4jLong *dLdpShapeInfo = ShapeBuilders::copyShapeInfoAndType( + predictionsShapeInfo, outType, false, block.workspace()); + Nd4jLong *dLdwShapeInfo = ShapeBuilders::copyShapeInfoAndType( + weightsShapeInfo, outType, false, block.workspace()); + Nd4jLong *dLdlShapeInfo = ShapeBuilders::copyShapeInfoAndType( + labelsShapeInfo, outType, false, block.workspace()); + + return SHAPELIST(dLdpShapeInfo, dLdwShapeInfo, dLdlShapeInfo); } +} // namespace ops +} // namespace sd #endif #pragma clang diagnostic pop diff --git a/libnd4j/include/ops/declarable/generic/loss/meanSqErr.cpp b/libnd4j/include/ops/declarable/generic/loss/meanSqErr.cpp index 6c54706c4809..b6189a8980ad 100644 --- a/libnd4j/include/ops/declarable/generic/loss/meanSqErr.cpp +++ b/libnd4j/include/ops/declarable/generic/loss/meanSqErr.cpp @@ -24,280 +24,374 @@ #include namespace sd { - namespace ops { - +namespace ops { ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(mean_sqerr_loss, 3, 1, false, 0, 1) { - auto predictions = INPUT_VARIABLE(0); - auto weights = INPUT_VARIABLE(1); - auto labels = INPUT_VARIABLE(2); - auto output = OUTPUT_VARIABLE(0); - - int reductionMode = INT_ARG(0); // 0 - "none"; 1 - "weighted_sum"; 2 - "weighted_mean"; 3 - "weighted_sum_by_nonzero_weights" - - // inputs validation - REQUIRE_TRUE(labels->isSameShape(predictions), 0, "MEAN_SQERR_LOSS OP: labels and predictions arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labels).c_str(), ShapeUtils::shapeAsString(predictions).c_str()); - // weights array can be single scalar or has the same rank as labels, and must be broadcastable to labels - REQUIRE_TRUE(weights->isScalar() || weights->rankOf() == labels->rankOf(), 0, "MEAN_SQERR_LOSS OP: weights array should be scalar or have the same rank as labels array, but got %i and %i correspondingly!", weights->rankOf(), labels->rankOf()); - // check whether broadcast operation is possible for weights array - REQUIRE_TRUE(weights->isScalar() || ShapeUtils::areShapesBroadcastable(*weights, *labels), 0, "MEAN_SQERR_LOSS OP: shapes of weights and labels arrays should be broadcastable, but got weights = %s and labels = %s instead!", ShapeUtils::shapeAsString(weights).c_str(), ShapeUtils::shapeAsString(labels).c_str()); - // only 4 possible reduction modes exist - REQUIRE_TRUE(reductionMode==0 || reductionMode==1 || reductionMode==2 || reductionMode==3, 0, "MEAN_SQERR_LOSS OP: reduction mode value is not acceptable, possible values are 0, 1, 2, 3, but got %i instead!", reductionMode); - - // perform weights broadcasting/tile to labels if needed - auto weightsBroad = weights; - if(!weights->isScalar() && !weights->isSameShape(predictions)) - weightsBroad = new NDArray(weights->tileToShape(predictions->shapeInfo())); - - NDArray E(labels->shapeInfo(), false, block.launchContext()); - predictions->applyPairwiseTransform(pairwise::SquaredSubtract, *labels, E); - - // multiply E on weights - E *= (*weightsBroad); - - switch (reductionMode) { - case 0: // 0 - "none", un-reduced weighted losses with the same shape as labels. - output->assign(&E); - break; - - case 1: { // 1 - "weighted_sum", output is scalar and equal to sum of all elements of E array - E.reduceNumber(reduce::Sum, *output); - break; - } - case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array - NDArray sum; - sum.setContext(block.launchContext()); - if (weights->isScalar()) - sum = (*weights) * E.lengthOf(); - else - sum = weightsBroad->reduceNumber(reduce::Sum); - - if (sum.e(0) == 0.) - (*output) = 0.; - else - output->assign(E.reduceNumber(reduce::Sum) / sum); - break; - } - case 3: { // 3 - "weighted_sum_by_nonzero_weights", output is scalar and equal to scalar sum of all elements of E array divided by number of non-zero weights - Nd4jLong numOfNonZeroWeights = 0; - if(weights->isScalar()) { - if(weights->e(0) != 0.) - numOfNonZeroWeights = E.lengthOf(); - } - else { - numOfNonZeroWeights = weightsBroad->reduceNumber(reduce::CountNonZero).e(0); - } - - if (numOfNonZeroWeights == 0) - (*output) = 0.; - else - output->assign(E.reduceNumber(reduce::Sum) / double(numOfNonZeroWeights)); - break; - } - } - - - STORE_RESULT(*output); - - if(weightsBroad != weights) - delete weightsBroad; - - return Status::OK(); + auto predictions = INPUT_VARIABLE(0); + auto weights = INPUT_VARIABLE(1); + auto labels = INPUT_VARIABLE(2); + auto output = OUTPUT_VARIABLE(0); + + int reductionMode = + INT_ARG(0); // 0 - "none"; 1 - "weighted_sum"; 2 - "weighted_mean"; 3 - + // "weighted_sum_by_nonzero_weights" + + // inputs validation + REQUIRE_TRUE(labels->isSameShape(predictions), 0, + "MEAN_SQERR_LOSS OP: labels and predictions arrays must have " + "the same shapes, but got %s and %s correspondingly !", + ShapeUtils::shapeAsString(labels).c_str(), + ShapeUtils::shapeAsString(predictions).c_str()); + // weights array can be single scalar or has the same rank as labels, and must + // be broadcastable to labels + REQUIRE_TRUE(weights->isScalar() || weights->rankOf() == labels->rankOf(), 0, + "MEAN_SQERR_LOSS OP: weights array should be scalar or have the " + "same rank as labels array, but got %i and %i correspondingly!", + weights->rankOf(), labels->rankOf()); + // check whether broadcast operation is possible for weights array + REQUIRE_TRUE( + weights->isScalar() || + ShapeUtils::areShapesBroadcastable(*weights, *labels), + 0, + "MEAN_SQERR_LOSS OP: shapes of weights and labels arrays should be " + "broadcastable, but got weights = %s and labels = %s instead!", + ShapeUtils::shapeAsString(weights).c_str(), + ShapeUtils::shapeAsString(labels).c_str()); + // only 4 possible reduction modes exist + REQUIRE_TRUE(reductionMode == 0 || reductionMode == 1 || reductionMode == 2 || + reductionMode == 3, + 0, + "MEAN_SQERR_LOSS OP: reduction mode value is not acceptable, " + "possible values are 0, 1, 2, 3, but got %i instead!", + reductionMode); + + // perform weights broadcasting/tile to labels if needed + auto weightsBroad = weights; + if (!weights->isScalar() && !weights->isSameShape(predictions)) + weightsBroad = new NDArray(weights->tileToShape(predictions->shapeInfo())); + + NDArray E(labels->shapeInfo(), false, block.launchContext()); + predictions->applyPairwiseTransform(pairwise::SquaredSubtract, *labels, E); + + // multiply E on weights + E *= (*weightsBroad); + + switch (reductionMode) { + case 0: // 0 - "none", un-reduced weighted losses with the same shape as + // labels. + output->assign(&E); + break; + + case 1: { // 1 - "weighted_sum", output is scalar and equal to sum of all + // elements of E array + E.reduceNumber(reduce::Sum, *output); + break; + } + case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all + // elements of E array divided by sum of all elements of + // weightsBroad array + NDArray sum; + sum.setContext(block.launchContext()); + if (weights->isScalar()) + sum = (*weights) * E.lengthOf(); + else + sum = weightsBroad->reduceNumber(reduce::Sum); + + if (sum.e(0) == 0.) + (*output) = 0.; + else + output->assign(E.reduceNumber(reduce::Sum) / sum); + break; + } + case 3: { // 3 - "weighted_sum_by_nonzero_weights", output is scalar and + // equal to scalar sum of all elements of E array divided by + // number of non-zero weights + Nd4jLong numOfNonZeroWeights = 0; + if (weights->isScalar()) { + if (weights->e(0) != 0.) numOfNonZeroWeights = E.lengthOf(); + } else { + numOfNonZeroWeights = + weightsBroad->reduceNumber(reduce::CountNonZero).e(0); + } + + if (numOfNonZeroWeights == 0) + (*output) = 0.; + else + output->assign(E.reduceNumber(reduce::Sum) / + double(numOfNonZeroWeights)); + break; + } + } + + STORE_RESULT(*output); + + if (weightsBroad != weights) delete weightsBroad; + + return Status::OK(); } DECLARE_TYPES(mean_sqerr_loss) { - - getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } - DECLARE_SHAPE_FN(mean_sqerr_loss) { - - auto predictionsShapeInfo = inputShape->at(0); - auto weightsShapeInfo = inputShape->at(1); - auto labelsShapeInfo = inputShape->at(2); - - // labels and predictions must have the same shapes - REQUIRE_TRUE(shape::shapeEquals(labelsShapeInfo, predictionsShapeInfo), 0, "MEAN_SQERR_LOSS OP: labels and predictions arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), ShapeUtils::shapeAsString(predictionsShapeInfo).c_str()); - // weights array can be single scalar or has the same rank as labels, and must be broadcastable to labels - REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || shape::rank(weightsShapeInfo) == shape::rank(labelsShapeInfo), 0, "MEAN_SQERR_LOSS OP: weights array should be scalar or have the same rank as labels array, but got %i and %i correspondingly!", shape::rank(weightsShapeInfo), shape::rank(labelsShapeInfo)); - // check whether broadcast operation is possible for weights array - REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || ShapeUtils::areShapesBroadcastable(weightsShapeInfo, labelsShapeInfo), 0, "MEAN_SQERR_LOSS OP: shapes of weights and labels arrays should be broadcastable, but got weights = %s and labels = %s instead!", ShapeUtils::shapeAsString(weightsShapeInfo).c_str(), ShapeUtils::shapeAsString(labelsShapeInfo).c_str()); - - DataType outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(predictionsShapeInfo)); - Nd4jLong const* outShapeInfo = nullptr; - - if(INT_ARG(0) != 0) // in this case output is scalar - outShapeInfo = ConstantShapeHelper::getInstance().scalarShapeInfo(outType); - else // in this case output has the same shape as labels and predictions - outShapeInfo = ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(outType, shape::order(labelsShapeInfo), shape::shapeOf(labelsShapeInfo), shape::rank(labelsShapeInfo))); - - return SHAPELIST(outShapeInfo); - + auto predictionsShapeInfo = inputShape->at(0); + auto weightsShapeInfo = inputShape->at(1); + auto labelsShapeInfo = inputShape->at(2); + + // labels and predictions must have the same shapes + REQUIRE_TRUE(shape::shapeEquals(labelsShapeInfo, predictionsShapeInfo), 0, + "MEAN_SQERR_LOSS OP: labels and predictions arrays must have " + "the same shapes, but got %s and %s correspondingly !", + ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), + ShapeUtils::shapeAsString(predictionsShapeInfo).c_str()); + // weights array can be single scalar or has the same rank as labels, and must + // be broadcastable to labels + REQUIRE_TRUE( + shape::isScalar(weightsShapeInfo) || + shape::rank(weightsShapeInfo) == shape::rank(labelsShapeInfo), + 0, + "MEAN_SQERR_LOSS OP: weights array should be scalar or have the same " + "rank as labels array, but got %i and %i correspondingly!", + shape::rank(weightsShapeInfo), shape::rank(labelsShapeInfo)); + // check whether broadcast operation is possible for weights array + REQUIRE_TRUE( + shape::isScalar(weightsShapeInfo) || + ShapeUtils::areShapesBroadcastable(weightsShapeInfo, labelsShapeInfo), + 0, + "MEAN_SQERR_LOSS OP: shapes of weights and labels arrays should be " + "broadcastable, but got weights = %s and labels = %s instead!", + ShapeUtils::shapeAsString(weightsShapeInfo).c_str(), + ShapeUtils::shapeAsString(labelsShapeInfo).c_str()); + + DataType outType = DataTypeUtils::pickFloatingType( + ArrayOptions::dataType(predictionsShapeInfo)); + Nd4jLong const* outShapeInfo = nullptr; + + if (INT_ARG(0) != 0) // in this case output is scalar + outShapeInfo = ConstantShapeHelper::getInstance().scalarShapeInfo(outType); + else // in this case output has the same shape as labels and predictions + outShapeInfo = + ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor( + outType, shape::order(labelsShapeInfo), + shape::shapeOf(labelsShapeInfo), shape::rank(labelsShapeInfo))); + + return SHAPELIST(outShapeInfo); } - - - - - - ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(mean_sqerr_loss_grad, 3, 3, false, 0, 1) { - - auto predictions = INPUT_VARIABLE(0); - auto weights = INPUT_VARIABLE(1); - auto labels = INPUT_VARIABLE(2); - - auto dLdp = OUTPUT_VARIABLE(0); // dL/dpredictions - auto dLdw = OUTPUT_VARIABLE(1); // dL/dweights - auto dLdl = OUTPUT_VARIABLE(2); // dL/dlabels - - int reductionMode = INT_ARG(0); // 0 - "none"; 1 - "weighted_sum"; 2 - "weighted_mean"; 3 - "weighted_sum_by_nonzero_weights" - // take into account Alex's proposition to treat "none" the same as "weighted_sum" mode when calculating gradients - if(reductionMode == 0) - reductionMode = 1; - - // inputs validation - REQUIRE_TRUE(labels->isSameShape(predictions), 0, "MEAN_SQERR_LOSS_GRAD OP: labels and predictions arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labels).c_str(), ShapeUtils::shapeAsString(predictions).c_str()); - // weights array can be single scalar or has the same rank as labels, and must be broadcastable to labels - REQUIRE_TRUE(weights->isScalar() || weights->rankOf() == labels->rankOf(), 0, "MEAN_SQERR_LOSS_GRAD OP: weights array should be scalar or have the same rank as labels array, but got %i and %i correspondingly!", weights->rankOf(), labels->rankOf()); - // check whether broadcast operation is possible for weights array - REQUIRE_TRUE(weights->isScalar() || ShapeUtils::areShapesBroadcastable(*weights, *labels), 0, "MEAN_SQERR_LOSS_GRAD OP: shapes of weights and labels arrays should be broadcastable, but got weights = %s and labels = %s instead!", ShapeUtils::shapeAsString(weights).c_str(), ShapeUtils::shapeAsString(labels).c_str()); - // only 4 possible reduction modes exist - REQUIRE_TRUE(reductionMode==0 || reductionMode==1 || reductionMode==2 || reductionMode==3, 0, "MEAN_SQERR_LOSS_GRAD OP: reduction mode value is not acceptable, possible values are 0, 1, 2, 3, but got %i instead!", reductionMode); - - // perform weights broadcasting/tile to labels if needed - auto weightsBroad = weights; - if(!weights->isScalar() && !weights->isSameShape(predictions)) - weightsBroad = new NDArray(weights->tileToShape(predictions->shapeInfo())); - - NDArray diff = *predictions - *labels; - - // dE_i/dp_i = 2 * (p_i - y_i) - dLdp->assign(2. * diff); // dE/dp - // dE_i/dy_i = -2 * (p_i - y_i) - // dLdl->assign(-(*dLdp)); // dE/dl - - NDArray E = diff * diff; - - switch (reductionMode) { - - case 1: { // 1 - "none" and "weighted_sum", output is scalar and equal to sum of all elements of E array - - *dLdp *= *weightsBroad; - - if(weights->isScalar()) - dLdw->assign(E.reduceNumber(reduce::Sum)); - else if(weights != weightsBroad) { - std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), weightsBroad->shapeInfo()); - E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); - } - else - dLdw->assign(E); - break; - } - case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array - - NDArray sum; - sum.setContext(block.launchContext()); - if (weights->isScalar()) - sum = (*weights) * E.lengthOf(); - else - sum = weightsBroad->reduceNumber(reduce::Sum); - - if (sum.e(0) == 0.) { - *dLdp = 0.; - *dLdw = 0.; - } - else { - - *dLdp *= *weightsBroad / sum; - - if(weights->isScalar()) - *dLdw = 0.; - else if(weights != weightsBroad) { - std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), weightsBroad->shapeInfo()); - ((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)).reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); - } - else - dLdw->assign((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)); - } - break; - } - case 3: { // 3 - "weighted_sum_by_nonzero_weights", output is scalar and equal to scalar sum of all elements of E array divided by number of non-zero weights - - Nd4jLong numOfNonZeroWeights = 0; - if(weights->isScalar()) { - if(weights->e(0) != 0.) - numOfNonZeroWeights = E.lengthOf(); - } - else - numOfNonZeroWeights = weightsBroad->reduceNumber(reduce::CountNonZero).e(0); - - if (numOfNonZeroWeights == 0) { - *dLdp = 0.; - *dLdw = 0.; - } - else { - auto numOfNonZeroWeightsScalar = NDArrayFactory::create(dLdw->dataType(), numOfNonZeroWeights, block.launchContext()); - - if(weights->isScalar()) - dLdw->assign(E.reduceNumber(reduce::Sum) / double(numOfNonZeroWeights)); - else if(weights != weightsBroad) { - std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), weightsBroad->shapeInfo()); - E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); - *dLdw /= numOfNonZeroWeightsScalar; - } - else - dLdw->assign(E / numOfNonZeroWeights); - - NDArray temp = *weightsBroad / numOfNonZeroWeightsScalar; - *dLdp *= temp; - } - break; - } - } - - dLdl->assign(-(*dLdp)); - - if(weightsBroad != weights) - delete weightsBroad; - - return Status::OK(); + auto predictions = INPUT_VARIABLE(0); + auto weights = INPUT_VARIABLE(1); + auto labels = INPUT_VARIABLE(2); + + auto dLdp = OUTPUT_VARIABLE(0); // dL/dpredictions + auto dLdw = OUTPUT_VARIABLE(1); // dL/dweights + auto dLdl = OUTPUT_VARIABLE(2); // dL/dlabels + + int reductionMode = + INT_ARG(0); // 0 - "none"; 1 - "weighted_sum"; 2 - "weighted_mean"; 3 - + // "weighted_sum_by_nonzero_weights" + // take into account Alex's proposition to treat "none" the same as + // "weighted_sum" mode when calculating gradients + if (reductionMode == 0) reductionMode = 1; + + // inputs validation + REQUIRE_TRUE(labels->isSameShape(predictions), 0, + "MEAN_SQERR_LOSS_GRAD OP: labels and predictions arrays must " + "have the same shapes, but got %s and %s correspondingly !", + ShapeUtils::shapeAsString(labels).c_str(), + ShapeUtils::shapeAsString(predictions).c_str()); + // weights array can be single scalar or has the same rank as labels, and must + // be broadcastable to labels + REQUIRE_TRUE( + weights->isScalar() || weights->rankOf() == labels->rankOf(), 0, + "MEAN_SQERR_LOSS_GRAD OP: weights array should be scalar or have the " + "same rank as labels array, but got %i and %i correspondingly!", + weights->rankOf(), labels->rankOf()); + // check whether broadcast operation is possible for weights array + REQUIRE_TRUE( + weights->isScalar() || + ShapeUtils::areShapesBroadcastable(*weights, *labels), + 0, + "MEAN_SQERR_LOSS_GRAD OP: shapes of weights and labels arrays should be " + "broadcastable, but got weights = %s and labels = %s instead!", + ShapeUtils::shapeAsString(weights).c_str(), + ShapeUtils::shapeAsString(labels).c_str()); + // only 4 possible reduction modes exist + REQUIRE_TRUE( + reductionMode == 0 || reductionMode == 1 || reductionMode == 2 || + reductionMode == 3, + 0, + "MEAN_SQERR_LOSS_GRAD OP: reduction mode value is not acceptable, " + "possible values are 0, 1, 2, 3, but got %i instead!", + reductionMode); + + // perform weights broadcasting/tile to labels if needed + auto weightsBroad = weights; + if (!weights->isScalar() && !weights->isSameShape(predictions)) + weightsBroad = new NDArray(weights->tileToShape(predictions->shapeInfo())); + + NDArray diff = *predictions - *labels; + + // dE_i/dp_i = 2 * (p_i - y_i) + dLdp->assign(2. * diff); // dE/dp + // dE_i/dy_i = -2 * (p_i - y_i) + // dLdl->assign(-(*dLdp)); // dE/dl + + NDArray E = diff * diff; + + switch (reductionMode) { + case 1: { // 1 - "none" and "weighted_sum", output is scalar and equal to + // sum of all elements of E array + + *dLdp *= *weightsBroad; + + if (weights->isScalar()) + dLdw->assign(E.reduceNumber(reduce::Sum)); + else if (weights != weightsBroad) { + std::vector axesToReduceAlong = + ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), + weightsBroad->shapeInfo()); + E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, + false, false); + } else + dLdw->assign(E); + break; + } + case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all + // elements of E array divided by sum of all elements of + // weightsBroad array + + NDArray sum; + sum.setContext(block.launchContext()); + if (weights->isScalar()) + sum = (*weights) * E.lengthOf(); + else + sum = weightsBroad->reduceNumber(reduce::Sum); + + if (sum.e(0) == 0.) { + *dLdp = 0.; + *dLdw = 0.; + } else { + *dLdp *= *weightsBroad / sum; + + if (weights->isScalar()) + *dLdw = 0.; + else if (weights != weightsBroad) { + std::vector axesToReduceAlong = + ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), + weightsBroad->shapeInfo()); + ((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / + (sum * sum)) + .reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, + false, false); + } else + dLdw->assign( + (E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / + (sum * sum)); + } + break; + } + case 3: { // 3 - "weighted_sum_by_nonzero_weights", output is scalar and + // equal to scalar sum of all elements of E array divided by + // number of non-zero weights + + Nd4jLong numOfNonZeroWeights = 0; + if (weights->isScalar()) { + if (weights->e(0) != 0.) numOfNonZeroWeights = E.lengthOf(); + } else + numOfNonZeroWeights = + weightsBroad->reduceNumber(reduce::CountNonZero).e(0); + + if (numOfNonZeroWeights == 0) { + *dLdp = 0.; + *dLdw = 0.; + } else { + auto numOfNonZeroWeightsScalar = NDArrayFactory::create( + dLdw->dataType(), numOfNonZeroWeights, block.launchContext()); + + if (weights->isScalar()) + dLdw->assign(E.reduceNumber(reduce::Sum) / + double(numOfNonZeroWeights)); + else if (weights != weightsBroad) { + std::vector axesToReduceAlong = + ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), + weightsBroad->shapeInfo()); + E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, + false, false); + *dLdw /= numOfNonZeroWeightsScalar; + } else + dLdw->assign(E / numOfNonZeroWeights); + + NDArray temp = *weightsBroad / numOfNonZeroWeightsScalar; + *dLdp *= temp; + } + break; + } + } + + dLdl->assign(-(*dLdp)); + + if (weightsBroad != weights) delete weightsBroad; + + return Status::OK(); } DECLARE_TYPES(mean_sqerr_loss_grad) { - - getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } DECLARE_SHAPE_FN(mean_sqerr_loss_grad) { - - auto predictionsShapeInfo = inputShape->at(0); - auto weightsShapeInfo = inputShape->at(1); - auto labelsShapeInfo = inputShape->at(2); - - // labels and predictions must have the same shapes - REQUIRE_TRUE(shape::shapeEquals(labelsShapeInfo, predictionsShapeInfo), 0, "MEAN_SQERR_LOSS_GRAD OP: labels and predictions arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), ShapeUtils::shapeAsString(predictionsShapeInfo).c_str()); - // weights array can be single scalar or has the same rank as labels, and must be broadcastable to labels - REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || shape::rank(weightsShapeInfo) == shape::rank(labelsShapeInfo), 0, "MEAN_SQERR_LOSS_GRAD OP: weights array should be scalar or have the same rank as labels array, but got %i and %i correspondingly!", shape::rank(weightsShapeInfo), shape::rank(labelsShapeInfo)); - // check whether broadcast operation is possible for weights array - REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || ShapeUtils::areShapesBroadcastable(weightsShapeInfo, labelsShapeInfo), 0, "MEAN_SQERR_LOSS_GRAD OP: shapes of weights and labels arrays should be broadcastable, but got weights = %s and labels = %s instead!", ShapeUtils::shapeAsString(weightsShapeInfo).c_str(), ShapeUtils::shapeAsString(labelsShapeInfo).c_str()); - - DataType outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(predictionsShapeInfo)); - - auto dLdpShapeInfo = ShapeBuilders::copyShapeInfoAndType(predictionsShapeInfo, outType, false, block.getWorkspace()); - auto dLdwShapeInfo = ShapeBuilders::copyShapeInfoAndType(weightsShapeInfo, outType, false, block.getWorkspace()); - auto dLdlShapeInfo = ShapeBuilders::copyShapeInfoAndType(labelsShapeInfo, outType, false, block.getWorkspace()); - - return SHAPELIST(CONSTANT(dLdpShapeInfo), CONSTANT(dLdwShapeInfo), CONSTANT(dLdlShapeInfo)); + auto predictionsShapeInfo = inputShape->at(0); + auto weightsShapeInfo = inputShape->at(1); + auto labelsShapeInfo = inputShape->at(2); + + // labels and predictions must have the same shapes + REQUIRE_TRUE(shape::shapeEquals(labelsShapeInfo, predictionsShapeInfo), 0, + "MEAN_SQERR_LOSS_GRAD OP: labels and predictions arrays must " + "have the same shapes, but got %s and %s correspondingly !", + ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), + ShapeUtils::shapeAsString(predictionsShapeInfo).c_str()); + // weights array can be single scalar or has the same rank as labels, and must + // be broadcastable to labels + REQUIRE_TRUE( + shape::isScalar(weightsShapeInfo) || + shape::rank(weightsShapeInfo) == shape::rank(labelsShapeInfo), + 0, + "MEAN_SQERR_LOSS_GRAD OP: weights array should be scalar or have the " + "same rank as labels array, but got %i and %i correspondingly!", + shape::rank(weightsShapeInfo), shape::rank(labelsShapeInfo)); + // check whether broadcast operation is possible for weights array + REQUIRE_TRUE( + shape::isScalar(weightsShapeInfo) || + ShapeUtils::areShapesBroadcastable(weightsShapeInfo, labelsShapeInfo), + 0, + "MEAN_SQERR_LOSS_GRAD OP: shapes of weights and labels arrays should be " + "broadcastable, but got weights = %s and labels = %s instead!", + ShapeUtils::shapeAsString(weightsShapeInfo).c_str(), + ShapeUtils::shapeAsString(labelsShapeInfo).c_str()); + + DataType outType = DataTypeUtils::pickFloatingType( + ArrayOptions::dataType(predictionsShapeInfo)); + + auto dLdpShapeInfo = ShapeBuilders::copyShapeInfoAndType( + predictionsShapeInfo, outType, false, block.workspace()); + auto dLdwShapeInfo = ShapeBuilders::copyShapeInfoAndType( + weightsShapeInfo, outType, false, block.workspace()); + auto dLdlShapeInfo = ShapeBuilders::copyShapeInfoAndType( + labelsShapeInfo, outType, false, block.workspace()); + + return SHAPELIST(CONSTANT(dLdpShapeInfo), CONSTANT(dLdwShapeInfo), + CONSTANT(dLdlShapeInfo)); } - - -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/loss/sigmCrossEntropy.cpp b/libnd4j/include/ops/declarable/generic/loss/sigmCrossEntropy.cpp index ddd28d43d6c5..f2213c4a0df8 100644 --- a/libnd4j/include/ops/declarable/generic/loss/sigmCrossEntropy.cpp +++ b/libnd4j/include/ops/declarable/generic/loss/sigmCrossEntropy.cpp @@ -25,303 +25,408 @@ #include namespace sd { -namespace ops { - +namespace ops { ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(sigm_cross_entropy_loss, 3, 1, false, 1, 1) { - auto logits = INPUT_VARIABLE(0); - auto weights = INPUT_VARIABLE(1); - auto labels = INPUT_VARIABLE(2); - auto output = OUTPUT_VARIABLE(0); - - int reductionMode = INT_ARG(0); // 0 - "none"; 1 - "weighted_sum"; 2 - "weighted_mean"; 3 - "weighted_sum_by_nonzero_weights" - auto labelsSmoothing = T_ARG(0); - - // input validation - REQUIRE_TRUE(labels->isSameShape(logits), 0, "SIGM_CROSS_ENTROPY_LOSS OP: labels and logits arrays must have the same shapes, but got %s and %s correspondingly!", ShapeUtils::shapeAsString(labels).c_str(), ShapeUtils::shapeAsString(logits).c_str()); - // weights array can be single scalar or has the same rank as labels, and must be broadcastable to labels - REQUIRE_TRUE(weights->isScalar() || weights->rankOf() == labels->rankOf(), 0, "SIGM_CROSS_ENTROPY_LOSS OP: weights array should be scalar or have the same rank as labels array, but got %i and %i correspondingly!", weights->rankOf(), labels->rankOf()); - // check whether broadcast operation is possible for weights array - REQUIRE_TRUE(weights->isScalar() || ShapeUtils::areShapesBroadcastable(*weights, *labels), 0, "SIGM_CROSS_ENTROPY_LOSS OP: shapes of weights and labels arrays should be broadcastable, but got weights = %s and labels = %s instead!", ShapeUtils::shapeAsString(weights).c_str(), ShapeUtils::shapeAsString(labels).c_str()); - // only 4 possible reduction modes exist - REQUIRE_TRUE(reductionMode==0 || reductionMode==1 || reductionMode==2 || reductionMode==3, 0, "SIGM_CROSS_ENTROPY_LOSS OP: reduction mode value is not acceptable, possible values are 0, 1, 2, 3, but got %i instead!", reductionMode); - - // perform weights broadcasting/tile to labels if needed - auto weightsBroad = weights; - if(!weights->isScalar() && !weights->isSameShape(logits)) - weightsBroad = new NDArray(weights->tileToShape(logits->shapeInfo())); - - // If labelsSmoothing is nonzero, smooth the labels towards 1/2: - auto newLabels = labels; - if(labelsSmoothing != 0.) { - newLabels = new NDArray(*labels); - newLabels->applyScalar(scalar::SXELogitsSmoother, labelsSmoothing, *newLabels); - } - - NDArray E(labels, false, block.launchContext()); - - // logits - labels * logits + log(1 + exp(-logits)) -> take into account numerical stability at large logits - helpers::sigmCrossEntropy(block.launchContext(), logits, newLabels, &E); - - // multiply E on weights - E *= *weightsBroad; - - switch (reductionMode) { - case 0: // 0 - "none", un-reduced weighted losses with the same shape as labels. - output->assign(E); - break; - - case 1: { // 1 - "weighted_sum", output is scalar and equal to sum of all elements of E array - E.reduceNumber(reduce::Sum, *output); - break; - } - case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array - NDArray sum; - sum.setContext(block.launchContext()); - if (weights->isScalar()) - sum = (*weights) * E.lengthOf(); - else - sum = weightsBroad->reduceNumber(reduce::Sum); - - if (sum.e(0) == 0.) - *output = 0.; - else - output->assign(E.reduceNumber(reduce::Sum) / sum); - break; - } - case 3: { // 3 - "weighted_sum_by_nonzero_weights", output is scalar and equal to scalar sum of all elements of E array divided by number of non-zero weights - Nd4jLong numOfNonZeroWeights = 0; - if(weights->isScalar()) { - if(weights->e(0) != 0.) - numOfNonZeroWeights = E.lengthOf(); - } - else { - numOfNonZeroWeights = weightsBroad->reduceNumber(reduce::CountNonZero).e(0); - } - - if (numOfNonZeroWeights == 0) - (*output) = 0.; - else - output->assign(E.reduceNumber(reduce::Sum) / double(numOfNonZeroWeights)); - break; - } - } - - if(weightsBroad != weights) - delete weightsBroad; - if(newLabels != labels) - delete newLabels; - - return Status::OK(); + auto logits = INPUT_VARIABLE(0); + auto weights = INPUT_VARIABLE(1); + auto labels = INPUT_VARIABLE(2); + auto output = OUTPUT_VARIABLE(0); + + int reductionMode = + INT_ARG(0); // 0 - "none"; 1 - "weighted_sum"; 2 - "weighted_mean"; 3 - + // "weighted_sum_by_nonzero_weights" + auto labelsSmoothing = T_ARG(0); + + // input validation + REQUIRE_TRUE(labels->isSameShape(logits), 0, + "SIGM_CROSS_ENTROPY_LOSS OP: labels and logits arrays must have " + "the same shapes, but got %s and %s correspondingly!", + ShapeUtils::shapeAsString(labels).c_str(), + ShapeUtils::shapeAsString(logits).c_str()); + // weights array can be single scalar or has the same rank as labels, and must + // be broadcastable to labels + REQUIRE_TRUE( + weights->isScalar() || weights->rankOf() == labels->rankOf(), 0, + "SIGM_CROSS_ENTROPY_LOSS OP: weights array should be scalar or have the " + "same rank as labels array, but got %i and %i correspondingly!", + weights->rankOf(), labels->rankOf()); + // check whether broadcast operation is possible for weights array + REQUIRE_TRUE( + weights->isScalar() || + ShapeUtils::areShapesBroadcastable(*weights, *labels), + 0, + "SIGM_CROSS_ENTROPY_LOSS OP: shapes of weights and labels arrays should " + "be broadcastable, but got weights = %s and labels = %s instead!", + ShapeUtils::shapeAsString(weights).c_str(), + ShapeUtils::shapeAsString(labels).c_str()); + // only 4 possible reduction modes exist + REQUIRE_TRUE( + reductionMode == 0 || reductionMode == 1 || reductionMode == 2 || + reductionMode == 3, + 0, + "SIGM_CROSS_ENTROPY_LOSS OP: reduction mode value is not acceptable, " + "possible values are 0, 1, 2, 3, but got %i instead!", + reductionMode); + + // perform weights broadcasting/tile to labels if needed + auto weightsBroad = weights; + if (!weights->isScalar() && !weights->isSameShape(logits)) + weightsBroad = new NDArray(weights->tileToShape(logits->shapeInfo())); + + // If labelsSmoothing is nonzero, smooth the labels towards 1/2: + auto newLabels = labels; + if (labelsSmoothing != 0.) { + newLabels = new NDArray(*labels); + newLabels->applyScalar(scalar::SXELogitsSmoother, labelsSmoothing, + *newLabels); + } + + NDArray E(labels, false, block.launchContext()); + + // logits - labels * logits + log(1 + exp(-logits)) -> take into account + // numerical stability at large logits + helpers::sigmCrossEntropy(block.launchContext(), logits, newLabels, &E); + + // multiply E on weights + E *= *weightsBroad; + + switch (reductionMode) { + case 0: // 0 - "none", un-reduced weighted losses with the same shape as + // labels. + output->assign(E); + break; + + case 1: { // 1 - "weighted_sum", output is scalar and equal to sum of all + // elements of E array + E.reduceNumber(reduce::Sum, *output); + break; + } + case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all + // elements of E array divided by sum of all elements of + // weightsBroad array + NDArray sum; + sum.setContext(block.launchContext()); + if (weights->isScalar()) + sum = (*weights) * E.lengthOf(); + else + sum = weightsBroad->reduceNumber(reduce::Sum); + + if (sum.e(0) == 0.) + *output = 0.; + else + output->assign(E.reduceNumber(reduce::Sum) / sum); + break; + } + case 3: { // 3 - "weighted_sum_by_nonzero_weights", output is scalar and + // equal to scalar sum of all elements of E array divided by + // number of non-zero weights + Nd4jLong numOfNonZeroWeights = 0; + if (weights->isScalar()) { + if (weights->e(0) != 0.) numOfNonZeroWeights = E.lengthOf(); + } else { + numOfNonZeroWeights = + weightsBroad->reduceNumber(reduce::CountNonZero).e(0); + } + + if (numOfNonZeroWeights == 0) + (*output) = 0.; + else + output->assign(E.reduceNumber(reduce::Sum) / + double(numOfNonZeroWeights)); + break; + } + } + + if (weightsBroad != weights) delete weightsBroad; + if (newLabels != labels) delete newLabels; + + return Status::OK(); } ////////////////////////////////////////////////////////////////////////// DECLARE_TYPES(sigm_cross_entropy_loss) { - - getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } ////////////////////////////////////////////////////////////////////////// DECLARE_SHAPE_FN(sigm_cross_entropy_loss) { - - auto logitsShapeInfo = inputShape->at(0); - auto weightsShapeInfo = inputShape->at(1); - auto labelsShapeInfo = inputShape->at(2); - - // labels and logits must have the same shapes - REQUIRE_TRUE(shape::shapeEquals(labelsShapeInfo, logitsShapeInfo), 0, "SIGM_CROSS_ENTROPY_LOSS OP: labels and logits arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), ShapeUtils::shapeAsString(logitsShapeInfo).c_str()); - // weights array can be single scalar or has the same rank as labels, and must be broadcastable to labels - REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || shape::rank(weightsShapeInfo) == shape::rank(labelsShapeInfo), 0, "SIGM_CROSS_ENTROPY_LOSS OP: weights array should be scalar or have the same rank as labels array, but got %i and %i correspondingly!", shape::rank(weightsShapeInfo), shape::rank(labelsShapeInfo)); - // check whether broadcast operation is possible for weights array - REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || ShapeUtils::areShapesBroadcastable(weightsShapeInfo, labelsShapeInfo), 0, "SIGM_CROSS_ENTROPY_LOSS OP: shapes of weights and labels arrays should be broadcastable, but got weights = %s and labels = %s instead!", ShapeUtils::shapeAsString(weightsShapeInfo).c_str(), ShapeUtils::shapeAsString(labelsShapeInfo).c_str()); - - DataType outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(logitsShapeInfo)); - Nd4jLong const* outShapeInfo = nullptr; - - if(INT_ARG(0) != 0) // in this case output is scalar - outShapeInfo = ConstantShapeHelper::getInstance().scalarShapeInfo(outType); - else // in this case output has the same shape as labels and logits - outShapeInfo = ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(outType, shape::order(labelsShapeInfo), shape::shapeOf(labelsShapeInfo), shape::rank(labelsShapeInfo))); - - return SHAPELIST(outShapeInfo); + auto logitsShapeInfo = inputShape->at(0); + auto weightsShapeInfo = inputShape->at(1); + auto labelsShapeInfo = inputShape->at(2); + + // labels and logits must have the same shapes + REQUIRE_TRUE(shape::shapeEquals(labelsShapeInfo, logitsShapeInfo), 0, + "SIGM_CROSS_ENTROPY_LOSS OP: labels and logits arrays must have " + "the same shapes, but got %s and %s correspondingly !", + ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), + ShapeUtils::shapeAsString(logitsShapeInfo).c_str()); + // weights array can be single scalar or has the same rank as labels, and must + // be broadcastable to labels + REQUIRE_TRUE( + shape::isScalar(weightsShapeInfo) || + shape::rank(weightsShapeInfo) == shape::rank(labelsShapeInfo), + 0, + "SIGM_CROSS_ENTROPY_LOSS OP: weights array should be scalar or have the " + "same rank as labels array, but got %i and %i correspondingly!", + shape::rank(weightsShapeInfo), shape::rank(labelsShapeInfo)); + // check whether broadcast operation is possible for weights array + REQUIRE_TRUE( + shape::isScalar(weightsShapeInfo) || + ShapeUtils::areShapesBroadcastable(weightsShapeInfo, labelsShapeInfo), + 0, + "SIGM_CROSS_ENTROPY_LOSS OP: shapes of weights and labels arrays should " + "be broadcastable, but got weights = %s and labels = %s instead!", + ShapeUtils::shapeAsString(weightsShapeInfo).c_str(), + ShapeUtils::shapeAsString(labelsShapeInfo).c_str()); + + DataType outType = + DataTypeUtils::pickFloatingType(ArrayOptions::dataType(logitsShapeInfo)); + Nd4jLong const* outShapeInfo = nullptr; + + if (INT_ARG(0) != 0) // in this case output is scalar + outShapeInfo = ConstantShapeHelper::getInstance().scalarShapeInfo(outType); + else // in this case output has the same shape as labels and logits + outShapeInfo = + ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor( + outType, shape::order(labelsShapeInfo), + shape::shapeOf(labelsShapeInfo), shape::rank(labelsShapeInfo))); + + return SHAPELIST(outShapeInfo); } - - - ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(sigm_cross_entropy_loss_grad, 3, 3, false, 1, 1) { - - auto logits = INPUT_VARIABLE(0); - auto weights = INPUT_VARIABLE(1); - auto labels = INPUT_VARIABLE(2); - - auto dLdp = OUTPUT_VARIABLE(0); // dL/dlogits - auto dLdw = OUTPUT_VARIABLE(1); // dL/dweights - auto dLdl = OUTPUT_VARIABLE(2); // dL/dlabels - - - NDArray labelsSmoothing = NDArrayFactory::create(logits->dataType(), T_ARG(0), block.launchContext()); - - int reductionMode = INT_ARG(0); // 0 - "none"; 1 - "weighted_sum"; 2 - "weighted_mean"; 3 - "weighted_sum_by_nonzero_weights" - // take into account Alex's proposition to treat "none" the same as "weighted_sum" mode when calculating gradients - if(reductionMode == 0) - reductionMode = 1; - - // input validation - REQUIRE_TRUE(labels->isSameShape(logits), 0, "SIGM_CROSS_ENTROPY_LOSS_GRAD OP: labels and logits arrays must have the same shapes, but got %s and %s correspondingly!", ShapeUtils::shapeAsString(labels).c_str(), ShapeUtils::shapeAsString(logits).c_str()); - // weights array can be single scalar or has the same rank as labels, and must be broadcastable to labels - REQUIRE_TRUE(weights->isScalar() || weights->rankOf() == labels->rankOf(), 0, "SIGM_CROSS_ENTROPY_LOSS_GRAD OP: weights array should be scalar or have the same rank as labels array, but got %i and %i correspondingly!", weights->rankOf(), labels->rankOf()); - // check whether broadcast operation is possible for weights array - REQUIRE_TRUE(weights->isScalar() || ShapeUtils::areShapesBroadcastable(*weights, *labels), 0, "SIGM_CROSS_ENTROPY_LOSS_GRAD OP: shapes of weights and labels arrays should be broadcastable, but got weights = %s and labels = %s instead!", ShapeUtils::shapeAsString(weights).c_str(), ShapeUtils::shapeAsString(labels).c_str()); - // only 4 possible reduction modes exist - REQUIRE_TRUE(reductionMode==0 || reductionMode==1 || reductionMode==2 || reductionMode==3, 0, "SIGM_CROSS_ENTROPY_LOSS_GRAD OP: reduction mode value is not acceptable, possible values are 0, 1, 2, 3, but got %i instead!", reductionMode); - - // perform weights broadcasting/tile to labels if needed - auto weightsBroad = weights; - if(!weights->isScalar() && !weights->isSameShape(logits)) - weightsBroad = new NDArray(weights->tileToShape(logits->shapeInfo())); - - // If labelsSmoothing is nonzero, smooth the labels towards 1/2: - auto newLabels = labels; - if(labelsSmoothing.e(0) != 0.f) { - newLabels = new NDArray(*labels); - newLabels->applyScalar(scalar::SXELogitsSmoother, labelsSmoothing.e(0), *newLabels); - } - - NDArray E(labels, false, block.launchContext()); - - // logits - labels * logits + log(1 + exp(-logits)) -> take into account numerical stability at large logits - helpers::sigmCrossEntropy(block.launchContext(), logits, newLabels, &E); - - // dLdp = 1 - labels - 1 / (1 + exp(logits)) - helpers::sigmCrossEntropyGrad(block.launchContext(), logits, newLabels, dLdp); - - // dLdl = -logits - labelsSmoothing -= 1.f; - dLdl->assign(*logits * labelsSmoothing); - - switch (reductionMode) { - case 1: { // 1 - "none" and "weighted_sum", output is scalar and equal to sum of all elements of E array - - *dLdp *= *weightsBroad; - *dLdl *= *weightsBroad; - - if(weights->isScalar()) - dLdw->assign(E.reduceNumber(reduce::Sum)); - else if(weights != weightsBroad) { - std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), weightsBroad->shapeInfo()); - E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); - } - else - dLdw->assign(E); - break; - } - case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array - NDArray sum; - sum.setContext(block.launchContext()); - if (weights->isScalar()) - sum = (*weights) * E.lengthOf(); - else - sum = weightsBroad->reduceNumber(reduce::Sum); - - if (sum.e(0) == 0.) { - *dLdp = 0.; - *dLdl = 0.; - *dLdw = 0.; - } - else { - - NDArray temp = *weightsBroad / sum; - *dLdp *= temp; - *dLdl *= temp; - - if(weights->isScalar()) - *dLdw = 0.; - else if(weights != weightsBroad) { - std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), weightsBroad->shapeInfo()); - ((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)).reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); - } - else - dLdw->assign((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum * sum)); - } - break; - } - case 3: { // 3 - "weighted_sum_by_nonzero_weights", output is scalar and equal to scalar sum of all elements of E array divided by number of non-zero weights - Nd4jLong numOfNonZeroWeights = 0; - if(weights->isScalar()) { - if(weights->e(0) != 0.) - numOfNonZeroWeights = E.lengthOf(); - } - else - numOfNonZeroWeights = weightsBroad->reduceNumber(reduce::CountNonZero).e(0); - - if (numOfNonZeroWeights == 0) { - *dLdp = 0.; - *dLdl = 0.; - *dLdw = 0.; - } - else { - auto numOfNonZeroWeightsScalar = NDArrayFactory::create(dLdw->dataType(), numOfNonZeroWeights, block.launchContext()); - - if(weights->isScalar()) - dLdw->assign(E.reduceNumber(reduce::Sum) / numOfNonZeroWeightsScalar); - else if(weights != weightsBroad) { - std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), weightsBroad->shapeInfo()); - E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); - *dLdw /= numOfNonZeroWeightsScalar; - } - else - dLdw->assign(E / numOfNonZeroWeightsScalar); - - NDArray temp = *weightsBroad / numOfNonZeroWeightsScalar; - *dLdp *= temp; - *dLdl *= temp; - } - break; - } - } - - if(weightsBroad != weights) - delete weightsBroad; - if(newLabels != labels) - delete newLabels; - - return Status::OK(); + auto logits = INPUT_VARIABLE(0); + auto weights = INPUT_VARIABLE(1); + auto labels = INPUT_VARIABLE(2); + + auto dLdp = OUTPUT_VARIABLE(0); // dL/dlogits + auto dLdw = OUTPUT_VARIABLE(1); // dL/dweights + auto dLdl = OUTPUT_VARIABLE(2); // dL/dlabels + + NDArray labelsSmoothing = NDArrayFactory::create(logits->dataType(), T_ARG(0), + block.launchContext()); + + int reductionMode = + INT_ARG(0); // 0 - "none"; 1 - "weighted_sum"; 2 - "weighted_mean"; 3 - + // "weighted_sum_by_nonzero_weights" + // take into account Alex's proposition to treat "none" the same as + // "weighted_sum" mode when calculating gradients + if (reductionMode == 0) reductionMode = 1; + + // input validation + REQUIRE_TRUE(labels->isSameShape(logits), 0, + "SIGM_CROSS_ENTROPY_LOSS_GRAD OP: labels and logits arrays must " + "have the same shapes, but got %s and %s correspondingly!", + ShapeUtils::shapeAsString(labels).c_str(), + ShapeUtils::shapeAsString(logits).c_str()); + // weights array can be single scalar or has the same rank as labels, and must + // be broadcastable to labels + REQUIRE_TRUE( + weights->isScalar() || weights->rankOf() == labels->rankOf(), 0, + "SIGM_CROSS_ENTROPY_LOSS_GRAD OP: weights array should be scalar or have " + "the same rank as labels array, but got %i and %i correspondingly!", + weights->rankOf(), labels->rankOf()); + // check whether broadcast operation is possible for weights array + REQUIRE_TRUE( + weights->isScalar() || + ShapeUtils::areShapesBroadcastable(*weights, *labels), + 0, + "SIGM_CROSS_ENTROPY_LOSS_GRAD OP: shapes of weights and labels arrays " + "should be broadcastable, but got weights = %s and labels = %s instead!", + ShapeUtils::shapeAsString(weights).c_str(), + ShapeUtils::shapeAsString(labels).c_str()); + // only 4 possible reduction modes exist + REQUIRE_TRUE( + reductionMode == 0 || reductionMode == 1 || reductionMode == 2 || + reductionMode == 3, + 0, + "SIGM_CROSS_ENTROPY_LOSS_GRAD OP: reduction mode value is not " + "acceptable, possible values are 0, 1, 2, 3, but got %i instead!", + reductionMode); + + // perform weights broadcasting/tile to labels if needed + auto weightsBroad = weights; + if (!weights->isScalar() && !weights->isSameShape(logits)) + weightsBroad = new NDArray(weights->tileToShape(logits->shapeInfo())); + + // If labelsSmoothing is nonzero, smooth the labels towards 1/2: + auto newLabels = labels; + if (labelsSmoothing.e(0) != 0.f) { + newLabels = new NDArray(*labels); + newLabels->applyScalar(scalar::SXELogitsSmoother, + labelsSmoothing.e(0), *newLabels); + } + + NDArray E(labels, false, block.launchContext()); + + // logits - labels * logits + log(1 + exp(-logits)) -> take into account + // numerical stability at large logits + helpers::sigmCrossEntropy(block.launchContext(), logits, newLabels, &E); + + // dLdp = 1 - labels - 1 / (1 + exp(logits)) + helpers::sigmCrossEntropyGrad(block.launchContext(), logits, newLabels, dLdp); + + // dLdl = -logits + labelsSmoothing -= 1.f; + dLdl->assign(*logits * labelsSmoothing); + + switch (reductionMode) { + case 1: { // 1 - "none" and "weighted_sum", output is scalar and equal to + // sum of all elements of E array + + *dLdp *= *weightsBroad; + *dLdl *= *weightsBroad; + + if (weights->isScalar()) + dLdw->assign(E.reduceNumber(reduce::Sum)); + else if (weights != weightsBroad) { + std::vector axesToReduceAlong = + ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), + weightsBroad->shapeInfo()); + E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, + false, false); + } else + dLdw->assign(E); + break; + } + case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all + // elements of E array divided by sum of all elements of + // weightsBroad array + NDArray sum; + sum.setContext(block.launchContext()); + if (weights->isScalar()) + sum = (*weights) * E.lengthOf(); + else + sum = weightsBroad->reduceNumber(reduce::Sum); + + if (sum.e(0) == 0.) { + *dLdp = 0.; + *dLdl = 0.; + *dLdw = 0.; + } else { + NDArray temp = *weightsBroad / sum; + *dLdp *= temp; + *dLdl *= temp; + + if (weights->isScalar()) + *dLdw = 0.; + else if (weights != weightsBroad) { + std::vector axesToReduceAlong = + ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), + weightsBroad->shapeInfo()); + ((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / + (sum * sum)) + .reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, + false, false); + } else + dLdw->assign( + (E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / + (sum * sum)); + } + break; + } + case 3: { // 3 - "weighted_sum_by_nonzero_weights", output is scalar and + // equal to scalar sum of all elements of E array divided by + // number of non-zero weights + Nd4jLong numOfNonZeroWeights = 0; + if (weights->isScalar()) { + if (weights->e(0) != 0.) numOfNonZeroWeights = E.lengthOf(); + } else + numOfNonZeroWeights = + weightsBroad->reduceNumber(reduce::CountNonZero).e(0); + + if (numOfNonZeroWeights == 0) { + *dLdp = 0.; + *dLdl = 0.; + *dLdw = 0.; + } else { + auto numOfNonZeroWeightsScalar = NDArrayFactory::create( + dLdw->dataType(), numOfNonZeroWeights, block.launchContext()); + + if (weights->isScalar()) + dLdw->assign(E.reduceNumber(reduce::Sum) / numOfNonZeroWeightsScalar); + else if (weights != weightsBroad) { + std::vector axesToReduceAlong = + ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), + weightsBroad->shapeInfo()); + E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, + false, false); + *dLdw /= numOfNonZeroWeightsScalar; + } else + dLdw->assign(E / numOfNonZeroWeightsScalar); + + NDArray temp = *weightsBroad / numOfNonZeroWeightsScalar; + *dLdp *= temp; + *dLdl *= temp; + } + break; + } + } + + if (weightsBroad != weights) delete weightsBroad; + if (newLabels != labels) delete newLabels; + + return Status::OK(); } ////////////////////////////////////////////////////////////////////////// DECLARE_TYPES(sigm_cross_entropy_loss_grad) { - - getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } ////////////////////////////////////////////////////////////////////////// DECLARE_SHAPE_FN(sigm_cross_entropy_loss_grad) { - - auto logitsShapeInfo = inputShape->at(0); - auto weightsShapeInfo = inputShape->at(1); - auto labelsShapeInfo = inputShape->at(2); - - // labels and logits must have the same shapes - REQUIRE_TRUE(shape::shapeEquals(labelsShapeInfo, logitsShapeInfo), 0, "SIGM_CROSS_ENTROPY_LOSS_GRAD OP: labels and logits arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), ShapeUtils::shapeAsString(logitsShapeInfo).c_str()); - // weights array can be single scalar or has the same rank as labels, and must be broadcastable to labels - REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || shape::rank(weightsShapeInfo) == shape::rank(labelsShapeInfo), 0, "SIGM_CROSS_ENTROPY_LOSS_GRAD OP: weights array should be scalar or have the same rank as labels array, but got %i and %i correspondingly!", shape::rank(weightsShapeInfo), shape::rank(labelsShapeInfo)); - // check whether broadcast operation is possible for weights array - REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || ShapeUtils::areShapesBroadcastable(weightsShapeInfo, labelsShapeInfo), 0, "SIGM_CROSS_ENTROPY_LOSS_GRAD OP: shapes of weights and labels arrays should be broadcastable, but got weights = %s and labels = %s instead!", ShapeUtils::shapeAsString(weightsShapeInfo).c_str(), ShapeUtils::shapeAsString(labelsShapeInfo).c_str()); - - DataType outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(logitsShapeInfo)); - - auto dLdpShapeInfo = ShapeBuilders::copyShapeInfoAndType(logitsShapeInfo, outType, false, block.getWorkspace()); - auto dLdwShapeInfo = ShapeBuilders::copyShapeInfoAndType(weightsShapeInfo, outType, false, block.getWorkspace()); - auto dLdlShapeInfo = ShapeBuilders::copyShapeInfoAndType(labelsShapeInfo, outType, false, block.getWorkspace()); - - return SHAPELIST(CONSTANT(dLdpShapeInfo), CONSTANT(dLdwShapeInfo), CONSTANT(dLdlShapeInfo)); + auto logitsShapeInfo = inputShape->at(0); + auto weightsShapeInfo = inputShape->at(1); + auto labelsShapeInfo = inputShape->at(2); + + // labels and logits must have the same shapes + REQUIRE_TRUE(shape::shapeEquals(labelsShapeInfo, logitsShapeInfo), 0, + "SIGM_CROSS_ENTROPY_LOSS_GRAD OP: labels and logits arrays must " + "have the same shapes, but got %s and %s correspondingly !", + ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), + ShapeUtils::shapeAsString(logitsShapeInfo).c_str()); + // weights array can be single scalar or has the same rank as labels, and must + // be broadcastable to labels + REQUIRE_TRUE( + shape::isScalar(weightsShapeInfo) || + shape::rank(weightsShapeInfo) == shape::rank(labelsShapeInfo), + 0, + "SIGM_CROSS_ENTROPY_LOSS_GRAD OP: weights array should be scalar or have " + "the same rank as labels array, but got %i and %i correspondingly!", + shape::rank(weightsShapeInfo), shape::rank(labelsShapeInfo)); + // check whether broadcast operation is possible for weights array + REQUIRE_TRUE( + shape::isScalar(weightsShapeInfo) || + ShapeUtils::areShapesBroadcastable(weightsShapeInfo, labelsShapeInfo), + 0, + "SIGM_CROSS_ENTROPY_LOSS_GRAD OP: shapes of weights and labels arrays " + "should be broadcastable, but got weights = %s and labels = %s instead!", + ShapeUtils::shapeAsString(weightsShapeInfo).c_str(), + ShapeUtils::shapeAsString(labelsShapeInfo).c_str()); + + DataType outType = + DataTypeUtils::pickFloatingType(ArrayOptions::dataType(logitsShapeInfo)); + + auto dLdpShapeInfo = ShapeBuilders::copyShapeInfoAndType( + logitsShapeInfo, outType, false, block.workspace()); + auto dLdwShapeInfo = ShapeBuilders::copyShapeInfoAndType( + weightsShapeInfo, outType, false, block.workspace()); + auto dLdlShapeInfo = ShapeBuilders::copyShapeInfoAndType( + labelsShapeInfo, outType, false, block.workspace()); + + return SHAPELIST(CONSTANT(dLdpShapeInfo), CONSTANT(dLdwShapeInfo), + CONSTANT(dLdlShapeInfo)); } - -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/loss/softmaxCrossEntropy.cpp b/libnd4j/include/ops/declarable/generic/loss/softmaxCrossEntropy.cpp index 79d46e448977..f08b1e36c521 100644 --- a/libnd4j/include/ops/declarable/generic/loss/softmaxCrossEntropy.cpp +++ b/libnd4j/include/ops/declarable/generic/loss/softmaxCrossEntropy.cpp @@ -24,375 +24,503 @@ #include namespace sd { -namespace ops { - +namespace ops { ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(softmax_cross_entropy_loss, 3, 1, false, 1, 1) { - - auto logits = INPUT_VARIABLE(0); - auto weights = INPUT_VARIABLE(1); - auto labels = INPUT_VARIABLE(2); - auto output = OUTPUT_VARIABLE(0); - - int reductionMode = INT_ARG(0); // 0 - "none"; 1 - "weighted_sum"; 2 - "weighted_mean"; 3 - "weighted_sum_by_nonzero_weights" - double labelsSmoothing = T_ARG(0); - - // input validation - REQUIRE_TRUE(labels->isSameShape(logits), 0, "SOFTMAX_CROSS_ENTROPY_LOSS OP: labels and logits arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labels).c_str(), ShapeUtils::shapeAsString(logits).c_str()); - // only 4 possible reduction modes exist - REQUIRE_TRUE(reductionMode==0 || reductionMode==1 || reductionMode==2 || reductionMode==3, 0, "SOFTMAX_CROSS_ENTROPY_LOSS OP: reduction mode value is not acceptable, possible values are 0, 1, 2, 3, but got %i instead!", reductionMode); - // smoothing is possible for rank of logits/labels > 1 - REQUIRE_TRUE(labels->rankOf() > 1 || (labels->rankOf() == 1 && labelsSmoothing == 0.), 0, "SOFTMAX_CROSS_ENTROPY_LOSS OP: smoothing is not possible when rank of labels/ logits = 1 !"); - - if(!output->isScalar()) { - // weights array can be single scalar or has the same shape as output, and must be broadcastable to output shape - REQUIRE_TRUE(weights->isScalar() || weights->rankOf() == output->rankOf(), 0, "SOFTMAX_CROSS_ENTROPY_LOSS OP: weights array should be scalar or have the same rank as output array, but got %i and %i correspondingly!", weights->rankOf(), output->rankOf()); - // check whether broadcast operation is possible for weights array - REQUIRE_TRUE(weights->isScalar() || ShapeUtils::areShapesBroadcastable(*weights, *output), 0, "SOFTMAX_CROSS_ENTROPY_LOSS OP: shapes of weights and output arrays should be broadcastable, but got weights = %s and output = %s instead!", ShapeUtils::shapeAsString(weights).c_str(), ShapeUtils::shapeAsString(labels).c_str()); + auto logits = INPUT_VARIABLE(0); + auto weights = INPUT_VARIABLE(1); + auto labels = INPUT_VARIABLE(2); + auto output = OUTPUT_VARIABLE(0); + + int reductionMode = + INT_ARG(0); // 0 - "none"; 1 - "weighted_sum"; 2 - "weighted_mean"; 3 - + // "weighted_sum_by_nonzero_weights" + double labelsSmoothing = T_ARG(0); + + // input validation + REQUIRE_TRUE(labels->isSameShape(logits), 0, + "SOFTMAX_CROSS_ENTROPY_LOSS OP: labels and logits arrays must " + "have the same shapes, but got %s and %s correspondingly !", + ShapeUtils::shapeAsString(labels).c_str(), + ShapeUtils::shapeAsString(logits).c_str()); + // only 4 possible reduction modes exist + REQUIRE_TRUE( + reductionMode == 0 || reductionMode == 1 || reductionMode == 2 || + reductionMode == 3, + 0, + "SOFTMAX_CROSS_ENTROPY_LOSS OP: reduction mode value is not acceptable, " + "possible values are 0, 1, 2, 3, but got %i instead!", + reductionMode); + // smoothing is possible for rank of logits/labels > 1 + REQUIRE_TRUE( + labels->rankOf() > 1 || (labels->rankOf() == 1 && labelsSmoothing == 0.), + 0, + "SOFTMAX_CROSS_ENTROPY_LOSS OP: smoothing is not possible when rank of " + "labels/ logits = 1 !"); + + if (!output->isScalar()) { + // weights array can be single scalar or has the same shape as output, and + // must be broadcastable to output shape + REQUIRE_TRUE( + weights->isScalar() || weights->rankOf() == output->rankOf(), 0, + "SOFTMAX_CROSS_ENTROPY_LOSS OP: weights array should be scalar or have " + "the same rank as output array, but got %i and %i correspondingly!", + weights->rankOf(), output->rankOf()); + // check whether broadcast operation is possible for weights array + REQUIRE_TRUE(weights->isScalar() || + ShapeUtils::areShapesBroadcastable(*weights, *output), + 0, + "SOFTMAX_CROSS_ENTROPY_LOSS OP: shapes of weights and output " + "arrays should be broadcastable, but got weights = %s and " + "output = %s instead!", + ShapeUtils::shapeAsString(weights).c_str(), + ShapeUtils::shapeAsString(labels).c_str()); + } + + // If label_smoothing is nonzero, smooth the labels towards 1/num_classes: + // new_onehot_labels = onehot_labels * (1 - label_smoothing) + label_smoothing + // / num_classes num_classes = labels->sizeAt(1) + NDArray* cLabels = new NDArray(labels->cast(weights->dataType())); + NDArray* newLabels = cLabels; + if (labelsSmoothing != 0.) { + newLabels = new NDArray(cLabels); + newLabels->assign((1.f - labelsSmoothing) * *cLabels + + labelsSmoothing / cLabels->sizeAt(1)); + } + + // main formula: result = - sum_i(lables_i * log(softmax_i)) - sum over last + // dimension softmax_i = exp(logits_i) / sum_j(exp(logits_j)) so result = + // sum_i( lables_i * (log(sum_j(exp(logits_j))) - logits_i) ) for numerical + // stability we use shifted logits (one can approve this using simple math): + // softmax_i = exp(logits_i - maxLogit) / sum_j(exp(logits_j - maxLogit)) + // maxLogit is max among logits_i + + std::vector dimensions = {-1}; + NDArray shiftedLogits = + *logits - logits->reduceAlongDimension(reduce::Max, dimensions, true); + NDArray logSumExp = shiftedLogits.transform(transform::Exp) + .reduceAlongDimension(reduce::Sum, dimensions, true) + .transform(transform::Log); + NDArray E = (*newLabels * (logSumExp - shiftedLogits)) + .reduceAlongDimension(reduce::Sum, dimensions); + + // perform weights broadcasting/tile to E if it is necessary + auto weightsBroad = weights; + if (!weights->isScalar() && !weights->isSameShape(&E)) { + if (E.rankOf() == 1 && weights->isVector() && weights->rankOf() > 1) + weightsBroad = new NDArray( + weights->reshape(weights->ordering(), {weights->lengthOf()})); + else + weightsBroad = new NDArray(weights->tileToShape(E.shapeInfo())); + } + + // multiply E on weights + E *= *weightsBroad; + + switch (reductionMode) { + case 0: // 0 - "none", un-reduced weighted losses with the same shape as + // labels. + output->assign(&E); + break; + + case 1: { // 1 - "weighted_sum", output is scalar and equal to sum of all + // elements of E array + E.reduceNumber(reduce::Sum, *output); + break; } + case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all + // elements of E array divided by sum of all elements of + // weightsBroad array + double sum; + if (weights->isScalar()) + sum = weights->e(0) * E.lengthOf(); + else + sum = weightsBroad->reduceNumber(reduce::Sum).e(0); + + if (sum == 0.) + *output = 0.; + else + output->assign(E.reduceNumber(reduce::Sum) / sum); + break; + } + case 3: { // 3 - "weighted_sum_by_nonzero_weights", output is scalar and + // equal to scalar sum of all elements of E array divided by + // number of non-zero weights + Nd4jLong numOfNonZeroWeights = 0; + if (weights->isScalar()) { + if (weights->e(0) != 0.) numOfNonZeroWeights = E.lengthOf(); + } else { + numOfNonZeroWeights = + weightsBroad->reduceNumber(reduce::CountNonZero).e(0); + } + + if (numOfNonZeroWeights == 0) + *output = 0.; + else + output->assign(E.reduceNumber(reduce::Sum) / + double(numOfNonZeroWeights)); + + break; + } + } + + if (weightsBroad != weights) delete weightsBroad; - // If label_smoothing is nonzero, smooth the labels towards 1/num_classes: new_onehot_labels = onehot_labels * (1 - label_smoothing) + label_smoothing / num_classes - // num_classes = labels->sizeAt(1) - NDArray* cLabels = new NDArray(labels->cast(weights->dataType())); - NDArray* newLabels = cLabels; - if(labelsSmoothing != 0.) { - newLabels = new NDArray(cLabels); - newLabels->assign((1.f - labelsSmoothing) * *cLabels + labelsSmoothing / cLabels->sizeAt(1)); - } - - // main formula: result = - sum_i(lables_i * log(softmax_i)) - sum over last dimension - // softmax_i = exp(logits_i) / sum_j(exp(logits_j)) - // so result = sum_i( lables_i * (log(sum_j(exp(logits_j))) - logits_i) ) - // for numerical stability we use shifted logits (one can approve this using simple math): - // softmax_i = exp(logits_i - maxLogit) / sum_j(exp(logits_j - maxLogit)) - // maxLogit is max among logits_i - - - std::vector dimensions = {-1}; - NDArray shiftedLogits = *logits - logits->reduceAlongDimension(reduce::Max, dimensions, true); - NDArray logSumExp = shiftedLogits.transform(transform::Exp).reduceAlongDimension(reduce::Sum, dimensions, true).transform(transform::Log); - NDArray E = (*newLabels * (logSumExp - shiftedLogits)).reduceAlongDimension(reduce::Sum, dimensions); - - // perform weights broadcasting/tile to E if it is necessary - auto weightsBroad = weights; - if(!weights->isScalar() && !weights->isSameShape(&E)) { - if(E.rankOf() == 1 && weights->isVector() && weights->rankOf() > 1) - weightsBroad = new NDArray(weights->reshape(weights->ordering(), {weights->lengthOf()})); - else - weightsBroad = new NDArray(weights->tileToShape(E.shapeInfo())); - } - - // multiply E on weights - E *= *weightsBroad; - - switch (reductionMode) { - case 0: // 0 - "none", un-reduced weighted losses with the same shape as labels. - output->assign(&E); - break; - - case 1: { // 1 - "weighted_sum", output is scalar and equal to sum of all elements of E array - E.reduceNumber(reduce::Sum, *output); - break; - } - case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array - double sum; - if (weights->isScalar()) - sum = weights->e(0) * E.lengthOf(); - else - sum = weightsBroad->reduceNumber(reduce::Sum).e(0); - - if (sum == 0.) - *output = 0.; - else - output->assign(E.reduceNumber(reduce::Sum) / sum); - break; - } - case 3: { // 3 - "weighted_sum_by_nonzero_weights", output is scalar and equal to scalar sum of all elements of E array divided by number of non-zero weights - Nd4jLong numOfNonZeroWeights = 0; - if(weights->isScalar()) { - if(weights->e(0) != 0.) - numOfNonZeroWeights = E.lengthOf(); - } - else { - numOfNonZeroWeights = weightsBroad->reduceNumber(reduce::CountNonZero).e(0); - } - - if (numOfNonZeroWeights == 0) - *output = 0.; - else - output->assign(E.reduceNumber(reduce::Sum) / double(numOfNonZeroWeights)); - - break; - } - } - - if(weightsBroad != weights) - delete weightsBroad; - - if(newLabels != cLabels) - delete newLabels; - - delete cLabels; - - return Status::OK(); + if (newLabels != cLabels) delete newLabels; + + delete cLabels; + + return Status::OK(); } ////////////////////////////////////////////////////////////////////////// DECLARE_TYPES(softmax_cross_entropy_loss) { - - getOpDescriptor()->setAllowedInputTypes(0, {ALL_FLOATS}) - ->setAllowedInputTypes(1, {ALL_FLOATS}) - ->setAllowedInputTypes(2, {ALL_FLOATS, ALL_INTS}) - ->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_FLOATS}) + ->setAllowedInputTypes(1, {ALL_FLOATS}) + ->setAllowedInputTypes(2, {ALL_FLOATS, ALL_INTS}) + ->setAllowedOutputTypes({ALL_FLOATS}); } ////////////////////////////////////////////////////////////////////////// DECLARE_SHAPE_FN(softmax_cross_entropy_loss) { - - auto logitsShapeInfo = inputShape->at(0); - auto weightsShapeInfo = inputShape->at(1); - auto labelsShapeInfo = inputShape->at(2); - - // labels and logits must have the same shapes - REQUIRE_TRUE(shape::shapeEquals(logitsShapeInfo, labelsShapeInfo), 0, "SOFTMAX_CROSS_ENTROPY_LOSS OP: labels and logits arrays must have the same shapes, but got %s and %s correspondingly!", ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), ShapeUtils::shapeAsString(logitsShapeInfo).c_str()); - - DataType outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(logitsShapeInfo)); - Nd4jLong const* outShapeInfo = nullptr; - - if(INT_ARG(0) != 0) // in this case output is scalar - outShapeInfo = ConstantShapeHelper::getInstance().scalarShapeInfo(outType); - else { // in this case output has the shape as labels and logits minus last dimension - std::vector dimensions = {-1}; - outShapeInfo = ShapeUtils::evalReduceShapeInfo(shape::order(logitsShapeInfo), dimensions, logitsShapeInfo, false, true, block.getWorkspace()); - - // weights array can be single scalar or has the same rank as output, and must be broadcastable to output - REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || shape::rank(weightsShapeInfo) == shape::rank(outShapeInfo), 0, "SOFTMAX_CROSS_ENTROPY_LOSS OP: weights array should be scalar or have the same rank as output array, but got %i and %i correspondingly!", shape::rank(weightsShapeInfo), shape::rank(outShapeInfo)); - // check whether broadcast operation is possible for weights array - REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || ShapeUtils::areShapesBroadcastable(weightsShapeInfo, outShapeInfo), 0, "SOFTMAX_CROSS_ENTROPY_LOSS OP: shapes of weights and output arrays should be broadcastable, but got weights = %s and output = %s instead!", ShapeUtils::shapeAsString(weightsShapeInfo).c_str(), ShapeUtils::shapeAsString(outShapeInfo).c_str()); - } - - return SHAPELIST(outShapeInfo); + auto logitsShapeInfo = inputShape->at(0); + auto weightsShapeInfo = inputShape->at(1); + auto labelsShapeInfo = inputShape->at(2); + + // labels and logits must have the same shapes + REQUIRE_TRUE(shape::shapeEquals(logitsShapeInfo, labelsShapeInfo), 0, + "SOFTMAX_CROSS_ENTROPY_LOSS OP: labels and logits arrays must " + "have the same shapes, but got %s and %s correspondingly!", + ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), + ShapeUtils::shapeAsString(logitsShapeInfo).c_str()); + + DataType outType = + DataTypeUtils::pickFloatingType(ArrayOptions::dataType(logitsShapeInfo)); + Nd4jLong const* outShapeInfo = nullptr; + + if (INT_ARG(0) != 0) // in this case output is scalar + outShapeInfo = ConstantShapeHelper::getInstance().scalarShapeInfo(outType); + else { // in this case output has the shape as labels and logits minus last + // dimension + std::vector dimensions = {-1}; + outShapeInfo = ShapeUtils::evalReduceShapeInfo( + shape::order(logitsShapeInfo), dimensions, logitsShapeInfo, false, true, + block.workspace()); + + // weights array can be single scalar or has the same rank as output, and + // must be broadcastable to output + REQUIRE_TRUE( + shape::isScalar(weightsShapeInfo) || + shape::rank(weightsShapeInfo) == shape::rank(outShapeInfo), + 0, + "SOFTMAX_CROSS_ENTROPY_LOSS OP: weights array should be scalar or have " + "the same rank as output array, but got %i and %i correspondingly!", + shape::rank(weightsShapeInfo), shape::rank(outShapeInfo)); + // check whether broadcast operation is possible for weights array + REQUIRE_TRUE( + shape::isScalar(weightsShapeInfo) || + ShapeUtils::areShapesBroadcastable(weightsShapeInfo, outShapeInfo), + 0, + "SOFTMAX_CROSS_ENTROPY_LOSS OP: shapes of weights and output arrays " + "should be broadcastable, but got weights = %s and output = %s " + "instead!", + ShapeUtils::shapeAsString(weightsShapeInfo).c_str(), + ShapeUtils::shapeAsString(outShapeInfo).c_str()); + } + + return SHAPELIST(outShapeInfo); } - - - - - - - - ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(softmax_cross_entropy_loss_grad, 3, 3, false, 1, 1) { + auto logits = INPUT_VARIABLE(0); + auto weights = INPUT_VARIABLE(1); + auto labels = INPUT_VARIABLE(2); + + auto dLdp = OUTPUT_VARIABLE(0); // dL/dlogits + auto dLdw = OUTPUT_VARIABLE(1); // dL/dweights + auto dLdl = OUTPUT_VARIABLE(2); // dL/dlabels + + auto labelsSmoothing = T_ARG(0); + + int reductionMode = + INT_ARG(0); // 0 - "none"; 1 - "weighted_sum"; 2 - "weighted_mean"; 3 - + // "weighted_sum_by_nonzero_weights" + // take into account Alex's proposition to treat "none" the same as + // "weighted_sum" mode when calculating gradients + if (reductionMode == 0) reductionMode = 1; + + std::vector dimensions = {-1}; + + // input validation + REQUIRE_TRUE(labels->isSameShape(logits), 0, + "SOFTMAX_CROSS_ENTROPY_LOSS_GRAD OP: labels and logits arrays " + "must have the same shapes, but got %s and %s correspondingly !", + ShapeUtils::shapeAsString(labels).c_str(), + ShapeUtils::shapeAsString(logits).c_str()); + // only 4 possible reduction modes exist + REQUIRE_TRUE( + reductionMode == 0 || reductionMode == 1 || reductionMode == 2 || + reductionMode == 3, + 0, + "SOFTMAX_CROSS_ENTROPY_LOSS_GRAD OP: reduction mode value is not " + "acceptable, possible values are 0, 1, 2, 3, but got %i instead!", + reductionMode); + auto lossShapeInfo = ShapeUtils::evalReduceShapeInfo( + logits->ordering(), dimensions, logits->shapeInfo(), false, false, + block.workspace()); + // weights array can be single scalar or has the same shape as loss, and must + // be broadcastable to loss shape + REQUIRE_TRUE( + weights->isScalar() || weights->rankOf() == shape::rank(lossShapeInfo), 0, + "SOFTMAX_CROSS_ENTROPY_LOSS_GRAD OP: weights array should be scalar or " + "have the same rank as loss array, but got %i and %i correspondingly!", + weights->rankOf(), shape::rank(lossShapeInfo)); + // check whether broadcast operation is possible for weights array + REQUIRE_TRUE( + weights->isScalar() || ShapeUtils::areShapesBroadcastable( + weights->shapeInfo(), lossShapeInfo), + 0, + "SOFTMAX_CROSS_ENTROPY_LOSS_GRAD OP: shapes of weights and loss arrays " + "should be broadcastable, but got weights = %s and loss = %s instead!", + ShapeUtils::shapeAsString(weights).c_str(), + ShapeUtils::shapeAsString(lossShapeInfo).c_str()); + // smoothing is possible for rank of logits/labels > 1 + REQUIRE_TRUE( + labels->rankOf() > 1 || (labels->rankOf() == 1 && labelsSmoothing == 0.), + 0, + "SOFTMAX_CROSS_ENTROPY_LOSS_GRAD OP: smoothing is not possible when rank " + "of labels/ logits = 1 !"); + + // If label_smoothing is nonzero, smooth the labels towards 1/num_classes: + // new_onehot_labels = onehot_labels * (1 - label_smoothing) + label_smoothing + // / num_classes num_classes = labels->sizeAt(1) + NDArray* cLabels = new NDArray(labels->cast(weights->dataType())); + NDArray* newLabels = cLabels; + if (labelsSmoothing != 0.) { + newLabels = new NDArray(labels->shapeInfo(), dLdl->dataType(), false, + block.launchContext()); + newLabels->assign((1.f - labelsSmoothing) * *cLabels + + labelsSmoothing / cLabels->sizeAt(1)); + } + + NDArray softmax = + (*logits - logits->reduceAlongDimension(reduce::Max, dimensions, true)) + .transform(transform::Exp); + softmax /= softmax.reduceAlongDimension(reduce::Sum, dimensions, true); + + // dEdp = softmax * sum_i(lables_i) - labels + dLdp->assign( + softmax * newLabels->reduceAlongDimension(reduce::Sum, dimensions, true) - + *newLabels); + + // dEdl = -log(softmax) + dLdl->assign(-softmax.transform(transform::Log) * (1.f - labelsSmoothing)); + + NDArray shiftedLogits = + *logits - logits->reduceAlongDimension(reduce::Max, dimensions, true); + NDArray logSumExp = shiftedLogits.transform(transform::Exp) + .reduceAlongDimension(reduce::Sum, dimensions, true) + .transform(transform::Log); + NDArray E = (*newLabels * (logSumExp - shiftedLogits)) + .reduceAlongDimension(reduce::Sum, dimensions); + + // perform weights broadcasting/tile to E if it is necessary + auto weightsBroad = weights; + if (!weights->isScalar() && !weights->isSameShape(&E)) + weightsBroad = new NDArray(weights->tileToShape(E.shapeInfo())); + + dimensions = ShapeUtils::evalDimsToExclude(dLdp->rankOf(), dimensions); + + switch (reductionMode) { + case 1: { // 1 - "none" and "weighted_sum", output is scalar and equal to + // sum of all elements of E array + + if (weights->isScalar() || weights->lengthOf() == 1) { + dLdw->assign(E.reduceNumber(reduce::Sum)); + *dLdp *= *weights; + *dLdl *= *weights; + } else { + dLdp->applyBroadcast(sd::broadcast::Multiply, dimensions, *weightsBroad, + *dLdp); + dLdl->applyBroadcast(sd::broadcast::Multiply, dimensions, *weightsBroad, + *dLdl); + + if (weights != weightsBroad) { + std::vector axesToReduceAlong = + ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), + weightsBroad->shapeInfo()); + E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, + false, false); + } else + dLdw->assign(E); + } + + break; + } + case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all + // elements of E array divided by sum of all elements of + // weightsBroad array + NDArray sum; + if (weights->isScalar()) + sum = (*weights) * E.lengthOf(); + else + sum = weightsBroad->reduceNumber(reduce::Sum); + + if (sum.e(0) == 0.) { + *dLdp = 0.; + *dLdl = 0.; + *dLdw = 0.; + } else { + if (weights->isScalar() || weights->lengthOf() == 1) { + NDArray temp = *weights / sum; + *dLdp *= temp; + *dLdl *= temp; + *dLdw = 0.; + } else { + NDArray temp = *weightsBroad / sum; + dLdp->applyBroadcast(sd::broadcast::Multiply, dimensions, temp, + *dLdp); + dLdl->applyBroadcast(sd::broadcast::Multiply, dimensions, temp, + *dLdl); + + if (weights != weightsBroad) { + std::vector axesToReduceAlong = + ShapeUtils::evalBroadcastBackwardAxis( + weights->shapeInfo(), weightsBroad->shapeInfo()); + ((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / + (sum * sum)) + .reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, + true, false, false); + } else + dLdw->assign( + (E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / + (sum * sum)); + } + } + break; + } + case 3: { // 3 - "weighted_sum_by_nonzero_weights", output is scalar and + // equal to scalar sum of all elements of E array divided by + // number of non-zero weights + Nd4jLong numOfNonZeroWeights = 0; + if (weights->isScalar()) { + if (weights->e(0) != 0.) numOfNonZeroWeights = E.lengthOf(); + } else + numOfNonZeroWeights = + weightsBroad->reduceNumber(reduce::CountNonZero).e(0); + + if (numOfNonZeroWeights == 0) { + *dLdp = 0.; + *dLdl = 0.; + *dLdw = 0.; + } else { + if (weights->isScalar() || weights->lengthOf() == 1) { + NDArray temp = *weights / numOfNonZeroWeights; + *dLdp *= temp; + *dLdl *= temp; + dLdw->assign(E.reduceNumber(reduce::Sum) / numOfNonZeroWeights); + } else { + NDArray temp = *weightsBroad / numOfNonZeroWeights; + dLdp->applyBroadcast(sd::broadcast::Multiply, dimensions, temp, + *dLdp); + dLdl->applyBroadcast(sd::broadcast::Multiply, dimensions, temp, + *dLdl); + + if (weights != weightsBroad) { + std::vector axesToReduceAlong = + ShapeUtils::evalBroadcastBackwardAxis( + weights->shapeInfo(), weightsBroad->shapeInfo()); + E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, + false, false); + *dLdw /= numOfNonZeroWeights; + } else + dLdw->assign(E / numOfNonZeroWeights); + } + } + break; + } + } - auto logits = INPUT_VARIABLE(0); - auto weights = INPUT_VARIABLE(1); - auto labels = INPUT_VARIABLE(2); - - auto dLdp = OUTPUT_VARIABLE(0); // dL/dlogits - auto dLdw = OUTPUT_VARIABLE(1); // dL/dweights - auto dLdl = OUTPUT_VARIABLE(2); // dL/dlabels - - auto labelsSmoothing = T_ARG(0); + if (weightsBroad != weights) delete weightsBroad; - int reductionMode = INT_ARG(0); // 0 - "none"; 1 - "weighted_sum"; 2 - "weighted_mean"; 3 - "weighted_sum_by_nonzero_weights" - // take into account Alex's proposition to treat "none" the same as "weighted_sum" mode when calculating gradients - if(reductionMode == 0) - reductionMode = 1; + if (newLabels != cLabels) delete newLabels; - std::vector dimensions = {-1}; + delete cLabels; - // input validation - REQUIRE_TRUE(labels->isSameShape(logits), 0, "SOFTMAX_CROSS_ENTROPY_LOSS_GRAD OP: labels and logits arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labels).c_str(), ShapeUtils::shapeAsString(logits).c_str()); - // only 4 possible reduction modes exist - REQUIRE_TRUE(reductionMode==0 || reductionMode==1 || reductionMode==2 || reductionMode==3, 0, "SOFTMAX_CROSS_ENTROPY_LOSS_GRAD OP: reduction mode value is not acceptable, possible values are 0, 1, 2, 3, but got %i instead!", reductionMode); - auto lossShapeInfo = ShapeUtils::evalReduceShapeInfo(logits->ordering(), dimensions, logits->shapeInfo(), false, false, block.getWorkspace()); - // weights array can be single scalar or has the same shape as loss, and must be broadcastable to loss shape - REQUIRE_TRUE(weights->isScalar() || weights->rankOf() == shape::rank(lossShapeInfo), 0, "SOFTMAX_CROSS_ENTROPY_LOSS_GRAD OP: weights array should be scalar or have the same rank as loss array, but got %i and %i correspondingly!", weights->rankOf(), shape::rank(lossShapeInfo)); - // check whether broadcast operation is possible for weights array - REQUIRE_TRUE(weights->isScalar() || ShapeUtils::areShapesBroadcastable(weights->shapeInfo(), lossShapeInfo), 0, "SOFTMAX_CROSS_ENTROPY_LOSS_GRAD OP: shapes of weights and loss arrays should be broadcastable, but got weights = %s and loss = %s instead!", ShapeUtils::shapeAsString(weights).c_str(), ShapeUtils::shapeAsString(lossShapeInfo).c_str()); - // smoothing is possible for rank of logits/labels > 1 - REQUIRE_TRUE(labels->rankOf() > 1 || (labels->rankOf() == 1 && labelsSmoothing == 0.), 0, "SOFTMAX_CROSS_ENTROPY_LOSS_GRAD OP: smoothing is not possible when rank of labels/ logits = 1 !"); - - // If label_smoothing is nonzero, smooth the labels towards 1/num_classes: new_onehot_labels = onehot_labels * (1 - label_smoothing) + label_smoothing / num_classes - // num_classes = labels->sizeAt(1) - NDArray* cLabels = new NDArray(labels->cast(weights->dataType())); - NDArray* newLabels = cLabels; - if(labelsSmoothing != 0.) { - newLabels = new NDArray(labels->shapeInfo(), dLdl->dataType(), false, block.launchContext()); - newLabels->assign((1.f - labelsSmoothing) * *cLabels + labelsSmoothing / cLabels->sizeAt(1)); - } - - NDArray softmax = (*logits - logits->reduceAlongDimension(reduce::Max, dimensions, true)).transform(transform::Exp); - softmax /= softmax.reduceAlongDimension(reduce::Sum, dimensions, true); - - // dEdp = softmax * sum_i(lables_i) - labels - dLdp->assign(softmax * newLabels->reduceAlongDimension(reduce::Sum, dimensions, true) - *newLabels); - - // dEdl = -log(softmax) - dLdl->assign(-softmax.transform(transform::Log)* (1.f - labelsSmoothing)); - - NDArray shiftedLogits = *logits - logits->reduceAlongDimension(reduce::Max, dimensions, true); - NDArray logSumExp = shiftedLogits.transform(transform::Exp).reduceAlongDimension(reduce::Sum, dimensions, true).transform(transform::Log); - NDArray E = (*newLabels * (logSumExp - shiftedLogits)).reduceAlongDimension(reduce::Sum, dimensions); - - // perform weights broadcasting/tile to E if it is necessary - auto weightsBroad = weights; - if(!weights->isScalar() && !weights->isSameShape(&E)) - weightsBroad = new NDArray(weights->tileToShape(E.shapeInfo())); - - dimensions = ShapeUtils::evalDimsToExclude(dLdp->rankOf(), dimensions); - - switch (reductionMode) { - case 1: { // 1 - "none" and "weighted_sum", output is scalar and equal to sum of all elements of E array - - if(weights->isScalar() || weights->lengthOf() == 1) { - dLdw->assign(E.reduceNumber(reduce::Sum)); - *dLdp *= *weights; - *dLdl *= *weights; - } - else { - dLdp->applyBroadcast(sd::broadcast::Multiply, dimensions, *weightsBroad, *dLdp); - dLdl->applyBroadcast(sd::broadcast::Multiply, dimensions, *weightsBroad, *dLdl); - - if(weights != weightsBroad) { - std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), weightsBroad->shapeInfo()); - E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); - } - else - dLdw->assign(E); - } - - break; - } - case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array - NDArray sum; - if (weights->isScalar()) - sum = (*weights) * E.lengthOf(); - else - sum = weightsBroad->reduceNumber(reduce::Sum); - - if (sum.e(0) == 0.) { - *dLdp = 0.; - *dLdl = 0.; - *dLdw = 0.; - } - else { - - if(weights->isScalar() || weights->lengthOf() == 1) { - NDArray temp = *weights / sum; - *dLdp *= temp; - *dLdl *= temp; - *dLdw = 0.; - } - else { - - NDArray temp = *weightsBroad / sum; - dLdp->applyBroadcast(sd::broadcast::Multiply, dimensions, temp, *dLdp); - dLdl->applyBroadcast(sd::broadcast::Multiply, dimensions, temp, *dLdl); - - if(weights != weightsBroad) { - std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), weightsBroad->shapeInfo()); - ((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)).reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); - } - else - dLdw->assign((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)); - } - } - break; - } - case 3: { // 3 - "weighted_sum_by_nonzero_weights", output is scalar and equal to scalar sum of all elements of E array divided by number of non-zero weights - Nd4jLong numOfNonZeroWeights = 0; - if(weights->isScalar()) { - if(weights->e(0) != 0.) - numOfNonZeroWeights = E.lengthOf(); - } - else - numOfNonZeroWeights = weightsBroad->reduceNumber(reduce::CountNonZero).e(0); - - if (numOfNonZeroWeights == 0) { - *dLdp = 0.; - *dLdl = 0.; - *dLdw = 0.; - } - else { - - if(weights->isScalar() || weights->lengthOf() == 1) { - NDArray temp = *weights / numOfNonZeroWeights; - *dLdp *= temp; - *dLdl *= temp; - dLdw->assign(E.reduceNumber(reduce::Sum) / numOfNonZeroWeights); - } - else { - NDArray temp = *weightsBroad / numOfNonZeroWeights; - dLdp->applyBroadcast(sd::broadcast::Multiply, dimensions, temp, *dLdp); - dLdl->applyBroadcast(sd::broadcast::Multiply, dimensions, temp, *dLdl); - - if(weights != weightsBroad) { - std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), weightsBroad->shapeInfo()); - E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); - *dLdw /= numOfNonZeroWeights; - } - else - dLdw->assign(E / numOfNonZeroWeights); - } - } - break; - } - } - - if(weightsBroad != weights) - delete weightsBroad; - - if(newLabels != cLabels) - delete newLabels; - - delete cLabels; - - return Status::OK(); + return Status::OK(); } ////////////////////////////////////////////////////////////////////////// DECLARE_TYPES(softmax_cross_entropy_loss_grad) { - - getOpDescriptor()->setAllowedInputTypes(0, {ALL_FLOATS}) - ->setAllowedInputTypes(1, {ALL_FLOATS}) - ->setAllowedInputTypes(2, {ALL_FLOATS, ALL_INTS}) - ->setAllowedInputTypes(3, {ALL_FLOATS}) - ->setAllowedInputTypes(4, {ALL_FLOATS}) - ->setAllowedInputTypes(5, {ALL_FLOATS}) - ->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_FLOATS}) + ->setAllowedInputTypes(1, {ALL_FLOATS}) + ->setAllowedInputTypes(2, {ALL_FLOATS, ALL_INTS}) + ->setAllowedInputTypes(3, {ALL_FLOATS}) + ->setAllowedInputTypes(4, {ALL_FLOATS}) + ->setAllowedInputTypes(5, {ALL_FLOATS}) + ->setAllowedOutputTypes({ALL_FLOATS}); } ////////////////////////////////////////////////////////////////////////// DECLARE_SHAPE_FN(softmax_cross_entropy_loss_grad) { - - auto logitsShapeInfo = inputShape->at(0); - auto weightsShapeInfo = inputShape->at(1); - auto labelsShapeInfo = inputShape->at(2); - - std::vector dimensions = {-1}; - - // labels and logits must have the same shapes - REQUIRE_TRUE(shape::shapeEquals(logitsShapeInfo, labelsShapeInfo), 0, "SOFTMAX_CROSS_ENTROPY_LOSS_GRAD OP: labels and logits arrays must have the same shapes, but got %s and %s correspondingly!", ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), ShapeUtils::shapeAsString(logitsShapeInfo).c_str()); - auto lossShapeInfo = ShapeUtils::evalReduceShapeInfo(shape::order(logitsShapeInfo), dimensions, logitsShapeInfo, false, false, block.getWorkspace()); - // weights array can be single scalar or has the same rank as loss, and must be broadcastable to loss - REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || shape::rank(weightsShapeInfo) == shape::rank(lossShapeInfo), 0, "SOFTMAX_CROSS_ENTROPY_LOSS_GRAD OP: weights array should be scalar or have the same rank as loss array, but got %i and %i correspondingly!", shape::rank(weightsShapeInfo), shape::rank(lossShapeInfo)); - // check whether broadcast operation is possible for weights array - REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || ShapeUtils::areShapesBroadcastable(weightsShapeInfo, lossShapeInfo), 0, "SOFTMAX_CROSS_ENTROPY_LOSS_GRAD OP: shapes of weights and loss arrays should be broadcastable, but got weights = %s and loss = %s instead!", ShapeUtils::shapeAsString(weightsShapeInfo).c_str(), ShapeUtils::shapeAsString(lossShapeInfo).c_str()); - - auto outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(logitsShapeInfo)); - - auto dLdpShapeInfo = ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(outType, shape::order(logitsShapeInfo), shape::shapeOf(logitsShapeInfo), shape::rank(logitsShapeInfo))); - auto dLdwShapeInfo = ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(outType, shape::order(weightsShapeInfo), shape::shapeOf(weightsShapeInfo), shape::rank(weightsShapeInfo))); - auto dLdlShapeInfo = ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(outType, shape::order(labelsShapeInfo), shape::shapeOf(labelsShapeInfo), shape::rank(labelsShapeInfo))); - - return SHAPELIST(dLdpShapeInfo, dLdwShapeInfo, dLdlShapeInfo); + auto logitsShapeInfo = inputShape->at(0); + auto weightsShapeInfo = inputShape->at(1); + auto labelsShapeInfo = inputShape->at(2); + + std::vector dimensions = {-1}; + + // labels and logits must have the same shapes + REQUIRE_TRUE(shape::shapeEquals(logitsShapeInfo, labelsShapeInfo), 0, + "SOFTMAX_CROSS_ENTROPY_LOSS_GRAD OP: labels and logits arrays " + "must have the same shapes, but got %s and %s correspondingly!", + ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), + ShapeUtils::shapeAsString(logitsShapeInfo).c_str()); + auto lossShapeInfo = ShapeUtils::evalReduceShapeInfo( + shape::order(logitsShapeInfo), dimensions, logitsShapeInfo, false, false, + block.workspace()); + // weights array can be single scalar or has the same rank as loss, and must + // be broadcastable to loss + REQUIRE_TRUE( + shape::isScalar(weightsShapeInfo) || + shape::rank(weightsShapeInfo) == shape::rank(lossShapeInfo), + 0, + "SOFTMAX_CROSS_ENTROPY_LOSS_GRAD OP: weights array should be scalar or " + "have the same rank as loss array, but got %i and %i correspondingly!", + shape::rank(weightsShapeInfo), shape::rank(lossShapeInfo)); + // check whether broadcast operation is possible for weights array + REQUIRE_TRUE( + shape::isScalar(weightsShapeInfo) || + ShapeUtils::areShapesBroadcastable(weightsShapeInfo, lossShapeInfo), + 0, + "SOFTMAX_CROSS_ENTROPY_LOSS_GRAD OP: shapes of weights and loss arrays " + "should be broadcastable, but got weights = %s and loss = %s instead!", + ShapeUtils::shapeAsString(weightsShapeInfo).c_str(), + ShapeUtils::shapeAsString(lossShapeInfo).c_str()); + + auto outType = + DataTypeUtils::pickFloatingType(ArrayOptions::dataType(logitsShapeInfo)); + + auto dLdpShapeInfo = + ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor( + outType, shape::order(logitsShapeInfo), + shape::shapeOf(logitsShapeInfo), shape::rank(logitsShapeInfo))); + auto dLdwShapeInfo = + ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor( + outType, shape::order(weightsShapeInfo), + shape::shapeOf(weightsShapeInfo), shape::rank(weightsShapeInfo))); + auto dLdlShapeInfo = + ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor( + outType, shape::order(labelsShapeInfo), + shape::shapeOf(labelsShapeInfo), shape::rank(labelsShapeInfo))); + + return SHAPELIST(dLdpShapeInfo, dLdwShapeInfo, dLdlShapeInfo); } - -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/loss/softmaxCrossEntropyWithLogits.cpp b/libnd4j/include/ops/declarable/generic/loss/softmaxCrossEntropyWithLogits.cpp index 0636450c73d3..7605b1c38204 100644 --- a/libnd4j/include/ops/declarable/generic/loss/softmaxCrossEntropyWithLogits.cpp +++ b/libnd4j/include/ops/declarable/generic/loss/softmaxCrossEntropyWithLogits.cpp @@ -24,118 +24,154 @@ #include namespace sd { -namespace ops { - +namespace ops { ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(softmax_cross_entropy_loss_with_logits, 2, 1, false, 0, 0) { - auto logits = INPUT_VARIABLE(0); - auto labels = INPUT_VARIABLE(1); - auto output = OUTPUT_VARIABLE(0); - - const int classesDim = block.getIArguments()->size() > 0 ? INT_ARG(0) : logits->rankOf()-1; - - // input validation - REQUIRE_TRUE(labels->isSameShape(logits), 0, "SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS OP: labels and logits arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labels).c_str(), ShapeUtils::shapeAsString(logits).c_str()); - REQUIRE_TRUE(classesDim < logits->rankOf(), 0, "SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS OP: class dimension must be smaller than rank of logits, but got %i and %i correspondingly !", classesDim, logits->rankOf()); - - std::vector dimension = {classesDim}; - - auto maxAlongDim = logits->reduceAlongDimension(reduce::Max, {classesDim}, true); - auto logExp = (*logits - maxAlongDim).transform(transform::Exp); - auto logSoftMax = ( logExp / logExp.reduceAlongDimension(reduce::Sum, {classesDim}, true) ).transform(transform::Log); - - (-(*labels) * logSoftMax).reduceAlongDimension(reduce::Sum, *output, dimension); - - return Status::OK(); + auto logits = INPUT_VARIABLE(0); + auto labels = INPUT_VARIABLE(1); + auto output = OUTPUT_VARIABLE(0); + + const int classesDim = block.numI() > 0 ? INT_ARG(0) : logits->rankOf() - 1; + + // input validation + REQUIRE_TRUE( + labels->isSameShape(logits), 0, + "SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS OP: labels and logits arrays " + "must have the same shapes, but got %s and %s correspondingly !", + ShapeUtils::shapeAsString(labels).c_str(), + ShapeUtils::shapeAsString(logits).c_str()); + REQUIRE_TRUE( + classesDim < logits->rankOf(), 0, + "SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS OP: class dimension must be " + "smaller than rank of logits, but got %i and %i correspondingly !", + classesDim, logits->rankOf()); + + std::vector dimension = {classesDim}; + + auto maxAlongDim = + logits->reduceAlongDimension(reduce::Max, {classesDim}, true); + auto logExp = (*logits - maxAlongDim).transform(transform::Exp); + auto logSoftMax = + (logExp / logExp.reduceAlongDimension(reduce::Sum, {classesDim}, true)) + .transform(transform::Log); + + (-(*labels) * logSoftMax) + .reduceAlongDimension(reduce::Sum, *output, dimension); + + return Status::OK(); } ////////////////////////////////////////////////////////////////////////// DECLARE_TYPES(softmax_cross_entropy_loss_with_logits) { - - getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } ////////////////////////////////////////////////////////////////////////// DECLARE_SHAPE_FN(softmax_cross_entropy_loss_with_logits) { - - auto logitsShapeInfo = inputShape->at(0); - auto labelsShapeInfo = inputShape->at(1); - - const int classesDim = block.getIArguments()->size() > 0 ? INT_ARG(0) : -1; - std::vector dimensions = {classesDim}; - - // labels and logits must have the same shapes - REQUIRE_TRUE(shape::shapeEquals(logitsShapeInfo, labelsShapeInfo), 0, "SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS OP: labels and logits arrays must have the same shapes, but got %s and %s correspondingly!", ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), ShapeUtils::shapeAsString(logitsShapeInfo).c_str()); - - auto outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(logitsShapeInfo)); - auto reducedShapeInfo = ShapeUtils::evalReduceShapeInfo(shape::order(labelsShapeInfo), dimensions, labelsShapeInfo, outType, false, false, block.getWorkspace()); - - return SHAPELIST(reducedShapeInfo); - + auto logitsShapeInfo = inputShape->at(0); + auto labelsShapeInfo = inputShape->at(1); + + const int classesDim = block.numI() > 0 ? INT_ARG(0) : -1; + std::vector dimensions = {classesDim}; + + // labels and logits must have the same shapes + REQUIRE_TRUE( + shape::shapeEquals(logitsShapeInfo, labelsShapeInfo), 0, + "SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS OP: labels and logits arrays " + "must have the same shapes, but got %s and %s correspondingly!", + ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), + ShapeUtils::shapeAsString(logitsShapeInfo).c_str()); + + auto outType = + DataTypeUtils::pickFloatingType(ArrayOptions::dataType(logitsShapeInfo)); + auto reducedShapeInfo = ShapeUtils::evalReduceShapeInfo( + shape::order(labelsShapeInfo), dimensions, labelsShapeInfo, outType, + false, false, block.workspace()); + + return SHAPELIST(reducedShapeInfo); } - - - - ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(softmax_cross_entropy_loss_with_logits_grad, 2, 2, false, 0, 0) { - - auto logits = INPUT_VARIABLE(0); - auto labels = INPUT_VARIABLE(1); - auto output = OUTPUT_VARIABLE(0); - - auto dLdp = OUTPUT_VARIABLE(0); // dL/dlogits - auto dLdl = OUTPUT_VARIABLE(1); // dL/dlabels - - const int classesDim = block.getIArguments()->size() > 0 ? INT_ARG(0) : logits->rankOf()-1; - - // input validation - REQUIRE_TRUE(labels->isSameShape(logits), 0, "SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS_GRAD OP: labels and logits arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labels).c_str(), ShapeUtils::shapeAsString(logits).c_str()); - REQUIRE_TRUE(classesDim < logits->rankOf(), 0, "SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS_GRAD OP: class dimension must be smaller than rank of logits, but got %i and %i correspondingly !", classesDim, logits->rankOf()); - - std::vector dimension = {classesDim}; - - NDArray softmax = (*logits - logits->reduceAlongDimension(reduce::Max, dimension, true)).transform(transform::Exp); - softmax /= softmax.reduceAlongDimension(reduce::Sum, dimension, true); - - // dEdp = softmax * sum_i(labels_i) - labels - dLdp->assign(softmax * labels->reduceAlongDimension(reduce::Sum, dimension, true) - *labels); - - // dEdl = -log(softmax) - (-softmax).applyTransform(transform::Log, *dLdl); - - return Status::OK(); + auto logits = INPUT_VARIABLE(0); + auto labels = INPUT_VARIABLE(1); + auto output = OUTPUT_VARIABLE(0); + + auto dLdp = OUTPUT_VARIABLE(0); // dL/dlogits + auto dLdl = OUTPUT_VARIABLE(1); // dL/dlabels + + const int classesDim = block.numI() > 0 ? INT_ARG(0) : logits->rankOf() - 1; + + // input validation + REQUIRE_TRUE( + labels->isSameShape(logits), 0, + "SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS_GRAD OP: labels and logits " + "arrays must have the same shapes, but got %s and %s correspondingly !", + ShapeUtils::shapeAsString(labels).c_str(), + ShapeUtils::shapeAsString(logits).c_str()); + REQUIRE_TRUE( + classesDim < logits->rankOf(), 0, + "SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS_GRAD OP: class dimension must be " + "smaller than rank of logits, but got %i and %i correspondingly !", + classesDim, logits->rankOf()); + + std::vector dimension = {classesDim}; + + NDArray softmax = + (*logits - logits->reduceAlongDimension(reduce::Max, dimension, true)) + .transform(transform::Exp); + softmax /= softmax.reduceAlongDimension(reduce::Sum, dimension, true); + + // dEdp = softmax * sum_i(labels_i) - labels + dLdp->assign(softmax * + labels->reduceAlongDimension(reduce::Sum, dimension, true) - + *labels); + + // dEdl = -log(softmax) + (-softmax).applyTransform(transform::Log, *dLdl); + + return Status::OK(); } - ////////////////////////////////////////////////////////////////////////// DECLARE_TYPES(softmax_cross_entropy_loss_with_logits_grad) { - - getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } ////////////////////////////////////////////////////////////////////////// DECLARE_SHAPE_FN(softmax_cross_entropy_loss_with_logits_grad) { - - auto logitsShapeInfo = inputShape->at(0); - auto labelsShapeInfo = inputShape->at(1); - - // labels and logits must have the same shapes - REQUIRE_TRUE(shape::shapeEquals(logitsShapeInfo, labelsShapeInfo), 0, "SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS_GRAD OP: labels and logits arrays must have the same shapes, but got %s and %s correspondingly!", ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), ShapeUtils::shapeAsString(logitsShapeInfo).c_str()); - - DataType outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(logitsShapeInfo)); - - auto dLdpShapeInfo = ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(outType, shape::order(logitsShapeInfo), shape::shapeOf(logitsShapeInfo), shape::rank(logitsShapeInfo))); - auto dLdlShapeInfo = ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(outType, shape::order(labelsShapeInfo), shape::shapeOf(labelsShapeInfo), shape::rank(labelsShapeInfo))); - - return SHAPELIST(dLdpShapeInfo, dLdlShapeInfo); + auto logitsShapeInfo = inputShape->at(0); + auto labelsShapeInfo = inputShape->at(1); + + // labels and logits must have the same shapes + REQUIRE_TRUE( + shape::shapeEquals(logitsShapeInfo, labelsShapeInfo), 0, + "SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS_GRAD OP: labels and logits " + "arrays must have the same shapes, but got %s and %s correspondingly!", + ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), + ShapeUtils::shapeAsString(logitsShapeInfo).c_str()); + + DataType outType = + DataTypeUtils::pickFloatingType(ArrayOptions::dataType(logitsShapeInfo)); + + auto dLdpShapeInfo = + ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor( + outType, shape::order(logitsShapeInfo), + shape::shapeOf(logitsShapeInfo), shape::rank(logitsShapeInfo))); + auto dLdlShapeInfo = + ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor( + outType, shape::order(labelsShapeInfo), + shape::shapeOf(labelsShapeInfo), shape::rank(labelsShapeInfo))); + + return SHAPELIST(dLdpShapeInfo, dLdlShapeInfo); } - - -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/loss/sparseSoftmaxCrossEntropyWithLogits.cpp b/libnd4j/include/ops/declarable/generic/loss/sparseSoftmaxCrossEntropyWithLogits.cpp index c641bf12f5b0..397feeb35f4f 100644 --- a/libnd4j/include/ops/declarable/generic/loss/sparseSoftmaxCrossEntropyWithLogits.cpp +++ b/libnd4j/include/ops/declarable/generic/loss/sparseSoftmaxCrossEntropyWithLogits.cpp @@ -25,144 +25,189 @@ #include namespace sd { -namespace ops { - +namespace ops { ////////////////////////////////////////////////////////////////////////// -CUSTOM_OP_IMPL(sparse_softmax_cross_entropy_loss_with_logits, 2, 1, false, 0, 0) { - auto labels = INPUT_VARIABLE(0); - auto logits = INPUT_VARIABLE(1); - - auto output = OUTPUT_VARIABLE(0); - - const int labelsRank = labels->rankOf(); - const int logitsRank = logits->rankOf(); - - // input validation - REQUIRE_TRUE(labelsRank == logitsRank - 1, 0, "SPARSE_SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS OP: input arrays should satisfy relation (labels_rank = logits_rank - 1), but got labels_rank = %i and logits_rank = %i instead !", labelsRank, logitsRank); - - std::vector labelsShape = labels->getShapeAsVector(); // this is correct - std::vector logitsShape = logits->getShapeAsVector(); - logitsShape.pop_back(); - bool equalSoft = logitsShape == labelsShape; - - REQUIRE_TRUE(equalSoft, 0, "SPARSE_SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS OP: wrong shape of labels array, its shape should be the same as logits shape with last dimension excluded, however got labels_shape = %s and logits_shape = %s instead !", ShapeUtils::shapeAsString(labelsShape).c_str(), ShapeUtils::shapeAsString(logitsShape).c_str()); - - std::vector dimension = {-1}; - - auto maxAlongDim = logits->reduceAlongDimension(reduce::Max, dimension, true); - auto logitsExp = (*logits - maxAlongDim).transform(transform::Exp, nullptr); - auto logSoftMax = -(( logitsExp / logitsExp.reduceAlongDimension(reduce::Sum, dimension, true) ).transform(transform::Log)); - - helpers::scatterForLoss(block.launchContext(), *labels, logSoftMax, *output, false); - - return Status::OK(); +CUSTOM_OP_IMPL(sparse_softmax_cross_entropy_loss_with_logits, 2, 1, false, 0, + 0) { + auto labels = INPUT_VARIABLE(0); + auto logits = INPUT_VARIABLE(1); + + auto output = OUTPUT_VARIABLE(0); + + const int labelsRank = labels->rankOf(); + const int logitsRank = logits->rankOf(); + + // input validation + REQUIRE_TRUE(labelsRank == logitsRank - 1, 0, + "SPARSE_SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS OP: input arrays " + "should satisfy relation (labels_rank = logits_rank - 1), but " + "got labels_rank = %i and logits_rank = %i instead !", + labelsRank, logitsRank); + + std::vector labelsShape = + labels->getShapeAsVector(); // this is correct + std::vector logitsShape = logits->getShapeAsVector(); + logitsShape.pop_back(); + bool equalSoft = logitsShape == labelsShape; + + REQUIRE_TRUE( + equalSoft, 0, + "SPARSE_SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS OP: wrong shape of labels " + "array, its shape should be the same as logits shape with last dimension " + "excluded, however got labels_shape = %s and logits_shape = %s instead !", + ShapeUtils::shapeAsString(labelsShape).c_str(), + ShapeUtils::shapeAsString(logitsShape).c_str()); + + std::vector dimension = {-1}; + + auto maxAlongDim = logits->reduceAlongDimension(reduce::Max, dimension, true); + auto logitsExp = (*logits - maxAlongDim).transform(transform::Exp, nullptr); + auto logSoftMax = -( + (logitsExp / logitsExp.reduceAlongDimension(reduce::Sum, dimension, true)) + .transform(transform::Log)); + + helpers::scatterForLoss(block.launchContext(), *labels, logSoftMax, *output, + false); + + return Status::OK(); } ////////////////////////////////////////////////////////////////////////// DECLARE_TYPES(sparse_softmax_cross_entropy_loss_with_logits) { - - getOpDescriptor()->setAllowedInputTypes(0, {ALL_INTS})->setAllowedInputTypes(1, {ALL_FLOATS})->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_INTS}) + ->setAllowedInputTypes(1, {ALL_FLOATS}) + ->setAllowedOutputTypes({ALL_FLOATS}); } - ////////////////////////////////////////////////////////////////////////// DECLARE_SHAPE_FN(sparse_softmax_cross_entropy_loss_with_logits) { - - auto labelsShapeInfo = inputShape->at(0); - auto logitsShapeInfo = inputShape->at(1); - - REQUIRE_TRUE(labelsShapeInfo[0] == logitsShapeInfo[0] - 1, 0, "SPARSE_SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS OP: input arrays should satisfy relation (labels_rank = logits_rank - 1), but got labels_rank = %i and logits_rank = %i instead !", labelsShapeInfo[0], logitsShapeInfo[0]); - - bool equalSoft = true; - for (int i = 1; i < labelsShapeInfo[0]; ++i) - if (labelsShapeInfo[i] != logitsShapeInfo[i]) { - equalSoft = false; - break; - } - - REQUIRE_TRUE(equalSoft, 0, "SPARSE_SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS OP: wrong shape of labels array, its shape should be the same as logits shape with last dimension excluded, however got labels_shape = %s and logits_shape = %s instead !", ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), ShapeUtils::shapeAsString(logitsShapeInfo).c_str()); - - auto outShapeInfo = ShapeBuilders::copyShapeInfoAndType(labelsShapeInfo, logitsShapeInfo, false, block.getWorkspace()); - - return SHAPELIST(CONSTANT(outShapeInfo)); + auto labelsShapeInfo = inputShape->at(0); + auto logitsShapeInfo = inputShape->at(1); + + REQUIRE_TRUE(labelsShapeInfo[0] == logitsShapeInfo[0] - 1, 0, + "SPARSE_SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS OP: input arrays " + "should satisfy relation (labels_rank = logits_rank - 1), but " + "got labels_rank = %i and logits_rank = %i instead !", + labelsShapeInfo[0], logitsShapeInfo[0]); + + bool equalSoft = true; + for (int i = 1; i < labelsShapeInfo[0]; ++i) + if (labelsShapeInfo[i] != logitsShapeInfo[i]) { + equalSoft = false; + break; + } + + REQUIRE_TRUE( + equalSoft, 0, + "SPARSE_SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS OP: wrong shape of labels " + "array, its shape should be the same as logits shape with last dimension " + "excluded, however got labels_shape = %s and logits_shape = %s instead !", + ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), + ShapeUtils::shapeAsString(logitsShapeInfo).c_str()); + + auto outShapeInfo = ShapeBuilders::copyShapeInfoAndType( + labelsShapeInfo, logitsShapeInfo, false, block.workspace()); + + return SHAPELIST(CONSTANT(outShapeInfo)); } - - - - - - ////////////////////////////////////////////////////////////////////////// -CUSTOM_OP_IMPL(sparse_softmax_cross_entropy_loss_with_logits_grad, 2, 1, false, 0, 0) { - - auto labels = INPUT_VARIABLE(0); - auto logits = INPUT_VARIABLE(1); - - auto dLdp = OUTPUT_VARIABLE(0); // dL/dlogits - - const int labelsRank = labels->rankOf(); - const int logitsRank = logits->rankOf(); - - // input validation - REQUIRE_TRUE(labelsRank == logitsRank - 1, 0, "SPARSE_SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS_GRAD OP: input arrays should satisfy relation (labels_rank = logits_rank - 1), but got labels_rank = %i and logits_rank = %i instead !", labelsRank, logitsRank); - - std::vector labelsShape = labels->getShapeAsVector(); // this is correct - std::vector logitsShape = logits->getShapeAsVector(); - logitsShape.pop_back(); - bool equalSoft = logitsShape == labelsShape; - - REQUIRE_TRUE(equalSoft, 0, "SPARSE_SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS_GRAD OP: wrong shape of labels array, its shape should be the same as logits shape with last dimension excluded, however got labels_shape = %s and logits_shape = %s instead !", ShapeUtils::shapeAsString(labelsShape).c_str(), ShapeUtils::shapeAsString(logitsShape).c_str()); - - std::vector dimension = {-1}; - - NDArray softmax = (*logits - logits->reduceAlongDimension(reduce::Max, dimension, true)).transform(transform::Exp); - softmax /= softmax.reduceAlongDimension(reduce::Sum, dimension, true); - - // dEdp = softmax - 1 (or 0) - dLdp->assign(softmax); - - // subtract unities at appropriate indexes of dLdp array - helpers::scatterForLoss(block.launchContext(), *labels, *dLdp, *labels /*actually third array is unnecessary for gradient calculation*/, true); - - return Status::OK(); +CUSTOM_OP_IMPL(sparse_softmax_cross_entropy_loss_with_logits_grad, 2, 1, false, + 0, 0) { + auto labels = INPUT_VARIABLE(0); + auto logits = INPUT_VARIABLE(1); + + auto dLdp = OUTPUT_VARIABLE(0); // dL/dlogits + + const int labelsRank = labels->rankOf(); + const int logitsRank = logits->rankOf(); + + // input validation + REQUIRE_TRUE(labelsRank == logitsRank - 1, 0, + "SPARSE_SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS_GRAD OP: input " + "arrays should satisfy relation (labels_rank = logits_rank - " + "1), but got labels_rank = %i and logits_rank = %i instead !", + labelsRank, logitsRank); + + std::vector labelsShape = + labels->getShapeAsVector(); // this is correct + std::vector logitsShape = logits->getShapeAsVector(); + logitsShape.pop_back(); + bool equalSoft = logitsShape == labelsShape; + + REQUIRE_TRUE(equalSoft, 0, + "SPARSE_SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS_GRAD OP: wrong " + "shape of labels array, its shape should be the same as logits " + "shape with last dimension excluded, however got labels_shape = " + "%s and logits_shape = %s instead !", + ShapeUtils::shapeAsString(labelsShape).c_str(), + ShapeUtils::shapeAsString(logitsShape).c_str()); + + std::vector dimension = {-1}; + + NDArray softmax = + (*logits - logits->reduceAlongDimension(reduce::Max, dimension, true)) + .transform(transform::Exp); + softmax /= softmax.reduceAlongDimension(reduce::Sum, dimension, true); + + // dEdp = softmax - 1 (or 0) + dLdp->assign(softmax); + + // subtract unities at appropriate indexes of dLdp array + helpers::scatterForLoss( + block.launchContext(), *labels, *dLdp, + *labels /*actually third array is unnecessary for gradient calculation*/, + true); + + return Status::OK(); } ////////////////////////////////////////////////////////////////////////// DECLARE_TYPES(sparse_softmax_cross_entropy_loss_with_logits_grad) { - - getOpDescriptor()->setAllowedInputTypes(0, {ALL_INTS})->setAllowedInputTypes(1, {ALL_FLOATS})->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_INTS}) + ->setAllowedInputTypes(1, {ALL_FLOATS}) + ->setAllowedOutputTypes({ALL_FLOATS}); } - ////////////////////////////////////////////////////////////////////////// DECLARE_SHAPE_FN(sparse_softmax_cross_entropy_loss_with_logits_grad) { - - auto labelsShapeInfo = inputShape->at(0); - auto logitsShapeInfo = inputShape->at(1); - - REQUIRE_TRUE(labelsShapeInfo[0] == logitsShapeInfo[0] - 1, 0, "SPARSE_SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS_GRAD OP: input arrays should satisfy relation (labels_rank = logits_rank - 1), but got labels_rank = %i and logits_rank = %i instead !", labelsShapeInfo[0], logitsShapeInfo[0]); - - bool equalSoft = true; - for (int i = 1; i < labelsShapeInfo[0]; ++i) - if (labelsShapeInfo[i] != logitsShapeInfo[i]) { - equalSoft = false; - break; - } - - REQUIRE_TRUE(equalSoft, 0, "SPARSE_SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS_GRAD OP: wrong shape of labels array, its shape should be the same as logits shape with last dimension excluded, however got labels_shape = %s and logits_shape = %s instead !", ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), ShapeUtils::shapeAsString(logitsShapeInfo).c_str()); - - DataType outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(logitsShapeInfo)); - - Nd4jLong *dLdpShapeInfo = ShapeBuilders::copyShapeInfoAndType(logitsShapeInfo, outType, false, block.getWorkspace()); - - return SHAPELIST(CONSTANT(dLdpShapeInfo)); + auto labelsShapeInfo = inputShape->at(0); + auto logitsShapeInfo = inputShape->at(1); + + REQUIRE_TRUE(labelsShapeInfo[0] == logitsShapeInfo[0] - 1, 0, + "SPARSE_SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS_GRAD OP: input " + "arrays should satisfy relation (labels_rank = logits_rank - " + "1), but got labels_rank = %i and logits_rank = %i instead !", + labelsShapeInfo[0], logitsShapeInfo[0]); + + bool equalSoft = true; + for (int i = 1; i < labelsShapeInfo[0]; ++i) + if (labelsShapeInfo[i] != logitsShapeInfo[i]) { + equalSoft = false; + break; + } + + REQUIRE_TRUE(equalSoft, 0, + "SPARSE_SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS_GRAD OP: wrong " + "shape of labels array, its shape should be the same as logits " + "shape with last dimension excluded, however got labels_shape = " + "%s and logits_shape = %s instead !", + ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), + ShapeUtils::shapeAsString(logitsShapeInfo).c_str()); + + DataType outType = + DataTypeUtils::pickFloatingType(ArrayOptions::dataType(logitsShapeInfo)); + + Nd4jLong *dLdpShapeInfo = ShapeBuilders::copyShapeInfoAndType( + logitsShapeInfo, outType, false, block.workspace()); + + return SHAPELIST(CONSTANT(dLdpShapeInfo)); } - - -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/nlp/cbow.cpp b/libnd4j/include/ops/declarable/generic/nlp/cbow.cpp index 9b5ed1918875..69ba6bb893b9 100644 --- a/libnd4j/include/ops/declarable/generic/nlp/cbow.cpp +++ b/libnd4j/include/ops/declarable/generic/nlp/cbow.cpp @@ -25,69 +25,74 @@ #include namespace sd { - namespace ops { - CONFIGURABLE_OP_IMPL(cbow, 15, 15, true, 0, 0) { - auto target = INPUT_VARIABLE(0); - auto ngStarter = INPUT_VARIABLE(1); +namespace ops { +CONFIGURABLE_OP_IMPL(cbow, 15, 15, true, 0, 0) { + auto target = INPUT_VARIABLE(0); + auto ngStarter = INPUT_VARIABLE(1); - // required part - auto context = INPUT_VARIABLE(2); - auto indices = INPUT_VARIABLE(3); - auto codes = INPUT_VARIABLE(4); + // required part + auto context = INPUT_VARIABLE(2); + auto indices = INPUT_VARIABLE(3); + auto codes = INPUT_VARIABLE(4); - auto syn0 = INPUT_VARIABLE(5); - auto syn1 = INPUT_VARIABLE(6); - auto syn1neg = INPUT_VARIABLE(7); + auto syn0 = INPUT_VARIABLE(5); + auto syn1 = INPUT_VARIABLE(6); + auto syn1neg = INPUT_VARIABLE(7); - auto expTable = INPUT_VARIABLE(8); - auto negTable = INPUT_VARIABLE(9); + auto expTable = INPUT_VARIABLE(8); + auto negTable = INPUT_VARIABLE(9); - auto alpha = INPUT_VARIABLE(10); - auto randomValue = INPUT_VARIABLE(11); - auto numLabels = INPUT_VARIABLE(12); + auto alpha = INPUT_VARIABLE(10); + auto randomValue = INPUT_VARIABLE(11); + auto numLabels = INPUT_VARIABLE(12); - auto lockedWords = INPUT_VARIABLE(13); + auto lockedWords = INPUT_VARIABLE(13); - auto inferenceVector = INPUT_VARIABLE(14); + auto inferenceVector = INPUT_VARIABLE(14); - auto numWorkers = block.numI() > 0 ? INT_ARG(0) : omp_get_max_threads(); - auto nsRounds = block.numI() > 1 ? INT_ARG(1) : 0; + auto numWorkers = block.numI() > 0 ? INT_ARG(0) : omp_get_max_threads(); + auto nsRounds = block.numI() > 1 ? INT_ARG(1) : 0; - auto trainWords = block.numB() > 0 ? B_ARG(0) : true; - auto isInference = block.numB() > 1 ? B_ARG(1) : false; + auto trainWords = block.numB() > 0 ? B_ARG(0) : true; + auto isInference = block.numB() > 1 ? B_ARG(1) : false; - REQUIRE_TRUE(block.isInplace(), 0, "CBOW: this operation requires inplace execution only"); + REQUIRE_TRUE(block.isInplace(), 0, + "CBOW: this operation requires inplace execution only"); - REQUIRE_TRUE(syn0->dataType() == syn1->dataType() && syn0->dataType() == syn1neg->dataType(), 0, "CBOW: all syn tables must have the same data type"); - REQUIRE_TRUE(syn0->dataType() == expTable->dataType(), 0, "CBOW: expTable must have the same data type as syn0 table"); + REQUIRE_TRUE(syn0->dataType() == syn1->dataType() && + syn0->dataType() == syn1neg->dataType(), + 0, "CBOW: all syn tables must have the same data type"); + REQUIRE_TRUE(syn0->dataType() == expTable->dataType(), 0, + "CBOW: expTable must have the same data type as syn0 table"); + sd::ops::helpers::cbow(*syn0, *syn1, *syn1neg, *expTable, *negTable, *target, + *ngStarter, nsRounds, *context, *lockedWords, *indices, + *codes, *alpha, *randomValue, *numLabels, + *inferenceVector, trainWords, numWorkers); - sd::ops::helpers::cbow(*syn0, *syn1, *syn1neg, *expTable, *negTable, *target, *ngStarter, nsRounds, *context, *lockedWords, *indices, *codes, *alpha, *randomValue, *numLabels, *inferenceVector, trainWords, numWorkers); - - - return Status::OK(); - } + return Status::OK(); +} - DECLARE_TYPES(cbow) { - getOpDescriptor() - ->setAllowedInputTypes(0, sd::DataType::INT32) - ->setAllowedInputTypes(1, sd::DataType::INT32) - ->setAllowedInputTypes(2, sd::DataType::INT32) - ->setAllowedInputTypes(3, sd::DataType::INT32) - ->setAllowedInputTypes(4, sd::DataType::INT8) - ->setAllowedInputTypes(5, {ALL_FLOATS}) - ->setAllowedInputTypes(6, {ALL_FLOATS}) - ->setAllowedInputTypes(7, {ALL_FLOATS}) - ->setAllowedInputTypes(8, {ALL_FLOATS}) - ->setAllowedInputTypes(9, {ALL_FLOATS}) - ->setAllowedInputTypes(10, {ALL_FLOATS}) - ->setAllowedInputTypes(11, sd::DataType::INT64) - ->setAllowedInputTypes(12, sd::DataType::INT32) - ->setAllowedInputTypes(13, sd::DataType::INT32) - ->setAllowedInputTypes(14, {ALL_FLOATS}) - ->setAllowedOutputTypes(sd::DataType::ANY); - } - } +DECLARE_TYPES(cbow) { + getOpDescriptor() + ->setAllowedInputTypes(0, sd::DataType::INT32) + ->setAllowedInputTypes(1, sd::DataType::INT32) + ->setAllowedInputTypes(2, sd::DataType::INT32) + ->setAllowedInputTypes(3, sd::DataType::INT32) + ->setAllowedInputTypes(4, sd::DataType::INT8) + ->setAllowedInputTypes(5, {ALL_FLOATS}) + ->setAllowedInputTypes(6, {ALL_FLOATS}) + ->setAllowedInputTypes(7, {ALL_FLOATS}) + ->setAllowedInputTypes(8, {ALL_FLOATS}) + ->setAllowedInputTypes(9, {ALL_FLOATS}) + ->setAllowedInputTypes(10, {ALL_FLOATS}) + ->setAllowedInputTypes(11, sd::DataType::INT64) + ->setAllowedInputTypes(12, sd::DataType::INT32) + ->setAllowedInputTypes(13, sd::DataType::INT32) + ->setAllowedInputTypes(14, {ALL_FLOATS}) + ->setAllowedOutputTypes(sd::DataType::ANY); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/nlp/skipgram.cpp b/libnd4j/include/ops/declarable/generic/nlp/skipgram.cpp index 921662fa6b5e..d9212086284a 100644 --- a/libnd4j/include/ops/declarable/generic/nlp/skipgram.cpp +++ b/libnd4j/include/ops/declarable/generic/nlp/skipgram.cpp @@ -25,69 +25,76 @@ #include namespace sd { - namespace ops { - CONFIGURABLE_OP_IMPL(skipgram, 12, 12, true, 0, 0) { - auto target = INPUT_VARIABLE(0); - auto ngStarter = INPUT_VARIABLE(1); +namespace ops { +CONFIGURABLE_OP_IMPL(skipgram, 12, 12, true, 0, 0) { + auto target = INPUT_VARIABLE(0); + auto ngStarter = INPUT_VARIABLE(1); - // required part - auto indices = INPUT_VARIABLE(2); - auto codes = INPUT_VARIABLE(3); + // required part + auto indices = INPUT_VARIABLE(2); + auto codes = INPUT_VARIABLE(3); - auto syn0 = INPUT_VARIABLE(4); - auto syn1 = INPUT_VARIABLE(5); - auto syn1neg = INPUT_VARIABLE(6); + auto syn0 = INPUT_VARIABLE(4); + auto syn1 = INPUT_VARIABLE(5); + auto syn1neg = INPUT_VARIABLE(6); - auto expTable = INPUT_VARIABLE(7); - auto negTable = INPUT_VARIABLE(8); + auto expTable = INPUT_VARIABLE(7); + auto negTable = INPUT_VARIABLE(8); - auto alpha = INPUT_VARIABLE(9); - auto randomValue = INPUT_VARIABLE(10); + auto alpha = INPUT_VARIABLE(9); + auto randomValue = INPUT_VARIABLE(10); - auto inferenceVector = INPUT_VARIABLE(11); + auto inferenceVector = INPUT_VARIABLE(11); - //auto neu1e = INPUT_VARIABLE(12); + // auto neu1e = INPUT_VARIABLE(12); - auto numWorkers = block.numI() > 0 ? INT_ARG(0) : omp_get_max_threads(); - auto nsRounds = block.numI() > 1 ? INT_ARG(1) : 0; + auto numWorkers = block.numI() > 0 ? INT_ARG(0) : omp_get_max_threads(); + auto nsRounds = block.numI() > 1 ? INT_ARG(1) : 0; - auto isInference = block.numB() > 0 ? B_ARG(0) : false; - auto isPreciseMode = block.numB() > 1 ? B_ARG(1) : false; + auto isInference = block.numB() > 0 ? B_ARG(0) : false; + auto isPreciseMode = block.numB() > 1 ? B_ARG(1) : false; - REQUIRE_TRUE(block.isInplace(), 0, "SkipGram: this operation requires inplace execution only"); + REQUIRE_TRUE(block.isInplace(), 0, + "SkipGram: this operation requires inplace execution only"); - REQUIRE_TRUE(syn0->dataType() == syn1->dataType() && syn0->dataType() == syn1neg->dataType(), 0, "SkipGram: all syn tables must have the same data type"); - REQUIRE_TRUE(syn0->dataType() == expTable->dataType(), 0, "SkipGram: expTable must have the same data type as syn0 table"); + REQUIRE_TRUE(syn0->dataType() == syn1->dataType() && + syn0->dataType() == syn1neg->dataType(), + 0, "SkipGram: all syn tables must have the same data type"); + REQUIRE_TRUE(syn0->dataType() == expTable->dataType(), 0, + "SkipGram: expTable must have the same data type as syn0 table"); + sd::ops::helpers::skipgram(*syn0, *syn1, *syn1neg, *expTable, *negTable, + *target, *ngStarter, nsRounds, *indices, *codes, + *alpha, *randomValue, *inferenceVector, + isPreciseMode, numWorkers); - sd::ops::helpers::skipgram(*syn0, *syn1, *syn1neg, *expTable, *negTable, *target, *ngStarter, nsRounds, *indices, *codes, *alpha, *randomValue, *inferenceVector, isPreciseMode, numWorkers); - - return Status::OK(); - } + return Status::OK(); +} - DECLARE_TYPES(skipgram) { - getOpDescriptor() - ->setAllowedInputTypes(0, sd::DataType::INT32) - ->setAllowedInputTypes(1, sd::DataType::INT32) - ->setAllowedInputTypes(2, sd::DataType::INT32) - ->setAllowedInputTypes(3, sd::DataType::INT8) - ->setAllowedInputTypes(4, {ALL_FLOATS}) - ->setAllowedInputTypes(5, {ALL_FLOATS}) - ->setAllowedInputTypes(6, {ALL_FLOATS}) - ->setAllowedInputTypes(7, {ALL_FLOATS}) - ->setAllowedInputTypes(8, {ALL_FLOATS}) - ->setAllowedInputTypes(9, {ALL_FLOATS}) - ->setAllowedInputTypes(10, sd::DataType::INT64) - ->setAllowedInputTypes(11, {ALL_FLOATS}) - ->setAllowedOutputTypes(sd::DataType::ANY); - } +DECLARE_TYPES(skipgram) { + getOpDescriptor() + ->setAllowedInputTypes(0, sd::DataType::INT32) + ->setAllowedInputTypes(1, sd::DataType::INT32) + ->setAllowedInputTypes(2, sd::DataType::INT32) + ->setAllowedInputTypes(3, sd::DataType::INT8) + ->setAllowedInputTypes(4, {ALL_FLOATS}) + ->setAllowedInputTypes(5, {ALL_FLOATS}) + ->setAllowedInputTypes(6, {ALL_FLOATS}) + ->setAllowedInputTypes(7, {ALL_FLOATS}) + ->setAllowedInputTypes(8, {ALL_FLOATS}) + ->setAllowedInputTypes(9, {ALL_FLOATS}) + ->setAllowedInputTypes(10, sd::DataType::INT64) + ->setAllowedInputTypes(11, {ALL_FLOATS}) + ->setAllowedOutputTypes(sd::DataType::ANY); +} - /* - DECLARE_SHAPE_FN(skipgram) { - return SHAPELIST(ShapeBuilders::createScalarShapeInfo(DataType::INT8, block.getWorkspace())); - } - */ - } +/* +DECLARE_SHAPE_FN(skipgram) { + return SHAPELIST(ShapeBuilders::createScalarShapeInfo(DataType::INT8, +block.workspace())); } +*/ +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/nn/activations/crelu.cpp b/libnd4j/include/ops/declarable/generic/nn/activations/crelu.cpp index df107451a40f..d2f499cadd82 100644 --- a/libnd4j/include/ops/declarable/generic/nn/activations/crelu.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/activations/crelu.cpp @@ -22,93 +22,94 @@ #if NOT_EXCLUDED(OP_crelu) #include -#include #include +#include + namespace sd { - namespace ops { - CUSTOM_OP_IMPL(crelu, 1, 1, false, 0, 0) { - auto x = INPUT_VARIABLE(0); - - REQUIRE_TRUE(x->isR(), 0, "CRELU: input must be real type"); - - auto tmp = x->dup(); - tmp.applyTransform(sd::transform::Neg, tmp); - - auto z = OUTPUT_VARIABLE(0); - - helpers::concat(block.launchContext(), {x, &tmp}, *z, x->rankOf()-1); - // NDArrayFactory::concat({x, tmp}, -1, z); - - // TODO: make this configurable? - double threshold = 0.0; - z->applyScalar(sd::scalar::RELU, threshold, *z); - - STORE_RESULT(z); - - return Status::OK(); - } - - DECLARE_TYPES(crelu) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setSameMode(true); - } - - DECLARE_SHAPE_FN(crelu) { - auto inShape = inputShape->at(0); - std::vector shape; - for (int e = 0; e < shape::rank(inShape); e++) - shape.emplace_back(shape::shapeOf(inShape)[e]); - - shape[shape.size()-1] *= 2; - auto newShape = ConstantShapeHelper::getInstance().createShapeInfo(ArrayOptions::dataType(inShape), shape::order(inShape), shape); - - return SHAPELIST(newShape); - } - - CUSTOM_OP_IMPL(crelu_bp, 2, 1, false, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto epsilonNext = INPUT_VARIABLE(1); - auto epsilon = OUTPUT_VARIABLE(0); - - // at first step we build fwd activation - sd::ops::crelu op; - auto tmpResult = op.evaluate({input}); - if (tmpResult.status() != ND4J_STATUS_OK) - return tmpResult.status(); - - auto actv = tmpResult.at(0); - - // now we do RELU backward pass - //actv->applyPairwiseTransform(pairwise::RELUDerivativeE, *epsilon, nullptr); - helpers::reluDerivative(block.launchContext(), actv, epsilonNext); - // now we split updated array into 2 chunks along last dimension - sd::ops::concat_bp opc; - auto dec = opc.evaluate({input, input, actv}, {-1}); - if (dec.status() != ND4J_STATUS_OK) - return dec.status(); - - // and now we subtract two parts of epsilons and pass result out - auto pos = dec.at(0); - auto neg = dec.at(1); - - pos->applyPairwiseTransform(sd::pairwise::Subtract, *neg, *epsilon); - - return ND4J_STATUS_OK; - } - - DECLARE_TYPES(crelu_bp) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}) - ->setAllowedOutputTypes(0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}); - } - - DECLARE_SHAPE_FN(crelu_bp) { - auto inShape = inputShape->at(0); - return SHAPELIST(ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(inShape))); - } - } +namespace ops { +CUSTOM_OP_IMPL(crelu, 1, 1, false, 0, 0) { + auto x = INPUT_VARIABLE(0); + + REQUIRE_TRUE(x->isR(), 0, "CRELU: input must be real type"); + + auto tmp = x->dup(); + tmp.applyTransform(sd::transform::Neg, tmp); + + auto z = OUTPUT_VARIABLE(0); + + helpers::concat(block.launchContext(), {x, &tmp}, *z, x->rankOf() - 1); + // NDArrayFactory::concat({x, tmp}, -1, z); + + // TODO: make this configurable? + double threshold = 0.0; + z->applyScalar(sd::scalar::RELU, threshold, *z); + + STORE_RESULT(z); + + return Status::OK(); +} + +DECLARE_TYPES(crelu) { + getOpDescriptor()->setAllowedInputTypes(0, DataType::ANY)->setSameMode(true); +} + +DECLARE_SHAPE_FN(crelu) { + auto inShape = inputShape->at(0); + std::vector shape; + for (int e = 0; e < shape::rank(inShape); e++) + shape.emplace_back(shape::shapeOf(inShape)[e]); + + shape[shape.size() - 1] *= 2; + auto newShape = ConstantShapeHelper::getInstance().createShapeInfo( + ArrayOptions::dataType(inShape), shape::order(inShape), shape); + + return SHAPELIST(newShape); +} + +CUSTOM_OP_IMPL(crelu_bp, 2, 1, false, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto epsilonNext = INPUT_VARIABLE(1); + auto epsilon = OUTPUT_VARIABLE(0); + + // at first step we build fwd activation + sd::ops::crelu op; + auto tmpResult = op.evaluate({input}); + if (tmpResult.status() != ND4J_STATUS_OK) return tmpResult.status(); + + auto actv = tmpResult.at(0); + + // now we do RELU backward pass + // actv->applyPairwiseTransform(pairwise::RELUDerivativeE, *epsilon, nullptr); + helpers::reluDerivative(block.launchContext(), &actv, epsilonNext); + // now we split updated array into 2 chunks along last dimension + sd::ops::concat_bp opc; + auto dec = opc.evaluate({input, input, &actv}, {-1}); + if (dec.status() != ND4J_STATUS_OK) return dec.status(); + + // and now we subtract two parts of epsilons and pass result out + auto pos = dec.at(0); + auto neg = dec.at(1); + + pos.applyPairwiseTransform(sd::pairwise::Subtract, neg, *epsilon); + + return ND4J_STATUS_OK; +} + +DECLARE_TYPES(crelu_bp) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes( + 1, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}) + ->setAllowedOutputTypes( + 0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}); +} + +DECLARE_SHAPE_FN(crelu_bp) { + auto inShape = inputShape->at(0); + return SHAPELIST(ConstantShapeHelper::getInstance().createShapeInfo( + ShapeDescriptor(inShape))); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/nn/activations/cube.cpp b/libnd4j/include/ops/declarable/generic/nn/activations/cube.cpp index d71906d2908f..a8446b3419ff 100644 --- a/libnd4j/include/ops/declarable/generic/nn/activations/cube.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/activations/cube.cpp @@ -18,7 +18,6 @@ // @author raver119@gmail.com // - #include #if NOT_EXCLUDED(OP_cube) @@ -26,41 +25,42 @@ #include namespace sd { - namespace ops { - CONFIGURABLE_OP_IMPL(cube, 1, 1, true, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); +namespace ops { +CONFIGURABLE_OP_IMPL(cube, 1, 1, true, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - input->applyTransform(sd::transform::Cube, *output); - STORE_RESULT(output); + input->applyTransform(sd::transform::Cube, *output); + STORE_RESULT(output); - return Status::OK(); - } + return Status::OK(); +} - DECLARE_TYPES(cube) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setSameMode(true); - } +DECLARE_TYPES(cube) { + getOpDescriptor()->setAllowedInputTypes(0, DataType::ANY)->setSameMode(true); +} - CONFIGURABLE_OP_IMPL(cube_bp, 2, 1, true, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto epsilon = INPUT_VARIABLE(1); +CONFIGURABLE_OP_IMPL(cube_bp, 2, 1, true, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto epsilon = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); + auto z = OUTPUT_VARIABLE(0); - //input->applyPairwiseTransform(pairwise::CUBEDerivativeE, epsilon, z, nullptr); - helpers::cubeDerivative(block.launchContext(), input, epsilon, z); - return Status::OK(); - } + // input->applyPairwiseTransform(pairwise::CUBEDerivativeE, epsilon, z, + // nullptr); + helpers::cubeDerivative(block.launchContext(), input, epsilon, z); + return Status::OK(); +} - DECLARE_TYPES(cube_bp) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}) - ->setAllowedOutputTypes(0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}); - } - } +DECLARE_TYPES(cube_bp) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes( + 1, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}) + ->setAllowedOutputTypes( + 0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/nn/activations/elu.cpp b/libnd4j/include/ops/declarable/generic/nn/activations/elu.cpp index f89f0d2c7c17..accd6f2d2559 100644 --- a/libnd4j/include/ops/declarable/generic/nn/activations/elu.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/activations/elu.cpp @@ -24,47 +24,47 @@ #include #include namespace sd { - namespace ops { - CONFIGURABLE_OP_IMPL(elu, 1, 1, true, -2, 0) { +namespace ops { +CONFIGURABLE_OP_IMPL(elu, 1, 1, true, -2, 0) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); + const auto alpha = block.numT() > 0 ? T_ARG(0) : 1.f; - const auto alpha = block.numT() > 0 ? T_ARG(0) : 1.f; + input->applyScalar(sd::scalar::ELU, alpha, *output); - input->applyScalar(sd::scalar::ELU, alpha, *output); - - return Status::OK(); - } - - DECLARE_TYPES(elu) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedOutputTypes(0, {ALL_FLOATS}); - } + return Status::OK(); +} - CONFIGURABLE_OP_IMPL(elu_bp, 2, 1, true, -2, 0) { +DECLARE_TYPES(elu) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedOutputTypes(0, {ALL_FLOATS}); +} - auto input = INPUT_VARIABLE(0); - auto epsilon = INPUT_VARIABLE(1); +CONFIGURABLE_OP_IMPL(elu_bp, 2, 1, true, -2, 0) { + auto input = INPUT_VARIABLE(0); + auto epsilon = INPUT_VARIABLE(1); - auto output = OUTPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - const auto alpha = block.numT() > 0 ? T_ARG(0) : 1.f; + const auto alpha = block.numT() > 0 ? T_ARG(0) : 1.f; - // input->applyPairwiseTransform(pairwise::ELUDerivativeE, epsilon, output); - helpers::eluDerivative(block.launchContext(), input, epsilon, output, alpha); + // input->applyPairwiseTransform(pairwise::ELUDerivativeE, epsilon, output); + helpers::eluDerivative(block.launchContext(), input, epsilon, output, alpha); - return Status::OK(); - } + return Status::OK(); +} - DECLARE_TYPES(elu_bp) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}) - ->setAllowedOutputTypes(0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}); - } - } +DECLARE_TYPES(elu_bp) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes( + 1, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}) + ->setAllowedOutputTypes( + 0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/nn/activations/hardsigmoid.cpp b/libnd4j/include/ops/declarable/generic/nn/activations/hardsigmoid.cpp index ba498fea98db..bb7261dff267 100644 --- a/libnd4j/include/ops/declarable/generic/nn/activations/hardsigmoid.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/activations/hardsigmoid.cpp @@ -25,41 +25,44 @@ #include namespace sd { - namespace ops { - CONFIGURABLE_OP_IMPL(hardsigmoid, 1, 1, true, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); +namespace ops { +CONFIGURABLE_OP_IMPL(hardsigmoid, 1, 1, true, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - input->applyTransform(sd::transform::HardSigmoid, *output); - STORE_RESULT(output); + input->applyTransform(sd::transform::HardSigmoid, *output); + STORE_RESULT(output); - return Status::OK(); - } + return Status::OK(); +} - DECLARE_TYPES(hardsigmoid) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedOutputTypes(0, {ALL_FLOATS}); - } +DECLARE_TYPES(hardsigmoid) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedOutputTypes(0, {ALL_FLOATS}); +} - CONFIGURABLE_OP_IMPL(hardsigmoid_bp, 2, 1, true, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto epsilon = INPUT_VARIABLE(1); +CONFIGURABLE_OP_IMPL(hardsigmoid_bp, 2, 1, true, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto epsilon = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); + auto z = OUTPUT_VARIABLE(0); - //input->applyPairwiseTransform(pairwise::HardSigmoidDerivativeE, epsilon, z, nullptr); - helpers::hardSigmoidDerivative(block.launchContext(), input, epsilon, z); - return Status::OK(); - } + // input->applyPairwiseTransform(pairwise::HardSigmoidDerivativeE, epsilon, z, + // nullptr); + helpers::hardSigmoidDerivative(block.launchContext(), input, epsilon, z); + return Status::OK(); +} - DECLARE_TYPES(hardsigmoid_bp) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}) - ->setAllowedOutputTypes(0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}); - } - } +DECLARE_TYPES(hardsigmoid_bp) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes( + 1, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}) + ->setAllowedOutputTypes( + 0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/nn/activations/hardtanh.cpp b/libnd4j/include/ops/declarable/generic/nn/activations/hardtanh.cpp index 0a245e6a086e..66fadedeab3c 100644 --- a/libnd4j/include/ops/declarable/generic/nn/activations/hardtanh.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/activations/hardtanh.cpp @@ -25,41 +25,44 @@ #include namespace sd { - namespace ops { - CONFIGURABLE_OP_IMPL(hardtanh, 1, 1, true, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); +namespace ops { +CONFIGURABLE_OP_IMPL(hardtanh, 1, 1, true, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - input->applyTransform(sd::transform::HardTanh, *output); - STORE_RESULT(output); + input->applyTransform(sd::transform::HardTanh, *output); + STORE_RESULT(output); - return Status::OK(); - } + return Status::OK(); +} - DECLARE_TYPES(hardtanh) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedOutputTypes(0, {ALL_FLOATS}); - } +DECLARE_TYPES(hardtanh) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedOutputTypes(0, {ALL_FLOATS}); +} - CONFIGURABLE_OP_IMPL(hardtanh_bp, 2, 1, true, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto epsilon = INPUT_VARIABLE(1); +CONFIGURABLE_OP_IMPL(hardtanh_bp, 2, 1, true, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto epsilon = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); + auto z = OUTPUT_VARIABLE(0); - //input->applyPairwiseTransform(pairwise::HardTanhDerivativeE, epsilon, z, nullptr); - helpers::hardTanhDerivative(block.launchContext(), input, epsilon, z); - return Status::OK(); - } + // input->applyPairwiseTransform(pairwise::HardTanhDerivativeE, epsilon, z, + // nullptr); + helpers::hardTanhDerivative(block.launchContext(), input, epsilon, z); + return Status::OK(); +} - DECLARE_TYPES(hardtanh_bp) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}) - ->setAllowedOutputTypes(0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}); - } - } +DECLARE_TYPES(hardtanh_bp) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes( + 1, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}) + ->setAllowedOutputTypes( + 0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/nn/activations/identity.cpp b/libnd4j/include/ops/declarable/generic/nn/activations/identity.cpp index 38e4a3ae88d6..799a895b34b6 100644 --- a/libnd4j/include/ops/declarable/generic/nn/activations/identity.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/activations/identity.cpp @@ -24,47 +24,43 @@ #include namespace sd { - namespace ops { - OP_IMPL(identity, 1, 1, true) { - auto z = OUTPUT_VARIABLE(0); +namespace ops { +OP_IMPL(identity, 1, 1, true) { + auto z = OUTPUT_VARIABLE(0); - if (!block.isInplace()) { - auto first = INPUT_VARIABLE(0); + if (!block.isInplace()) { + auto first = INPUT_VARIABLE(0); - // we hope for memcpy here - z->assign(first); - } - - return Status::OK(); - } - DECLARE_SYN(linear, identity); - - DECLARE_TYPES(identity) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setSameMode(true); - } + // we hope for memcpy here + z->assign(first); + } + return Status::OK(); +} +DECLARE_SYN(linear, identity); - OP_IMPL(identity_bp, 2, 1, true) { - auto first = INPUT_VARIABLE(0); - auto epsilon = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); +DECLARE_TYPES(identity) { + getOpDescriptor()->setAllowedInputTypes(0, DataType::ANY)->setSameMode(true); +} - z->assign(epsilon); +OP_IMPL(identity_bp, 2, 1, true) { + auto first = INPUT_VARIABLE(0); + auto epsilon = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); - return Status::OK(); - } - DECLARE_SYN(LinearGrad, identity_bp); + z->assign(epsilon); + return Status::OK(); +} +DECLARE_SYN(LinearGrad, identity_bp); - DECLARE_TYPES(identity_bp) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, {ALL_FLOATS}) - ->setAllowedOutputTypes(0, {ALL_FLOATS}); - } - } +DECLARE_TYPES(identity_bp) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes(1, {ALL_FLOATS}) + ->setAllowedOutputTypes(0, {ALL_FLOATS}); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/nn/activations/identity_n.cpp b/libnd4j/include/ops/declarable/generic/nn/activations/identity_n.cpp index 4b7088660de4..b96628165ce4 100644 --- a/libnd4j/include/ops/declarable/generic/nn/activations/identity_n.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/activations/identity_n.cpp @@ -24,39 +24,38 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(identity_n, 1, 1, true, 0, 0) { - - // just for lulz - if (!block.isInplace()) { - for (Nd4jLong i = 0; i < block.width(); ++i) { - auto x = INPUT_VARIABLE(i); - auto z = OUTPUT_VARIABLE(i); - - x->applyTransform(transform::Identity, *z); - } - } - - return Status::OK(); - } - - DECLARE_SHAPE_FN(identity_n) { - auto shapes = SHAPELIST(); - for (size_t i = 0; i < inputShape->size(); ++i) { - Nd4jLong* shape; - COPY_SHAPE_EX(inputShape->at(i), shape, block.getWorkspace()); - shapes->push_back(CONSTANT(shape)); - } - return shapes; - } - - DECLARE_TYPES(identity_n) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes(sd::DataType::ANY); - } - +namespace ops { +CUSTOM_OP_IMPL(identity_n, 1, 1, true, 0, 0) { + // just for lulz + if (!block.isInplace()) { + for (Nd4jLong i = 0; i < block.width(); ++i) { + auto x = INPUT_VARIABLE(i); + auto z = OUTPUT_VARIABLE(i); + + x->applyTransform(transform::Identity, *z); } + } + + return Status::OK(); +} + +DECLARE_SHAPE_FN(identity_n) { + auto shapes = SHAPELIST(); + for (size_t i = 0; i < inputShape->size(); ++i) { + Nd4jLong* shape; + COPY_SHAPE_EX(inputShape->at(i), shape, block.workspace()); + shapes->push_back(CONSTANT(shape)); + } + return shapes; } +DECLARE_TYPES(identity_n) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes(sd::DataType::ANY); +} + +} // namespace ops +} // namespace sd + #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/nn/activations/lrelu.cpp b/libnd4j/include/ops/declarable/generic/nn/activations/lrelu.cpp index 2f4c2dc041a3..df6d11dcd314 100644 --- a/libnd4j/include/ops/declarable/generic/nn/activations/lrelu.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/activations/lrelu.cpp @@ -24,45 +24,48 @@ #include #include namespace sd { - namespace ops { - CONFIGURABLE_OP_IMPL(lrelu, 1, 1, true, -2, 0) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); +namespace ops { +CONFIGURABLE_OP_IMPL(lrelu, 1, 1, true, -2, 0) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - float alpha = block.numT() > 0 ? T_ARG(0) : 0.01f; + float alpha = block.numT() > 0 ? T_ARG(0) : 0.01f; - input->applyScalar(sd::scalar::LeakyRELU, alpha, *output); - STORE_RESULT(output); + input->applyScalar(sd::scalar::LeakyRELU, alpha, *output); + STORE_RESULT(output); - return Status::OK(); - } + return Status::OK(); +} - DECLARE_TYPES(lrelu) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedOutputTypes(0, {ALL_FLOATS}); - } +DECLARE_TYPES(lrelu) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedOutputTypes(0, {ALL_FLOATS}); +} - CONFIGURABLE_OP_IMPL(lrelu_bp, 2, 1, true, -2, 0) { - auto input = INPUT_VARIABLE(0); - auto epsilon = INPUT_VARIABLE(1); +CONFIGURABLE_OP_IMPL(lrelu_bp, 2, 1, true, -2, 0) { + auto input = INPUT_VARIABLE(0); + auto epsilon = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); + auto z = OUTPUT_VARIABLE(0); - float alpha = block.numT() > 0 ? T_ARG(0) : 0.01f; + float alpha = block.numT() > 0 ? T_ARG(0) : 0.01f; - //input->applyPairwiseTransform(pairwise::LRELUDerivativeE, epsilon, z, nullptr); - helpers::leakyReluDerivative(block.launchContext(), input, epsilon, z, alpha); - return Status::OK(); - } + // input->applyPairwiseTransform(pairwise::LRELUDerivativeE, epsilon, z, + // nullptr); + helpers::leakyReluDerivative(block.launchContext(), input, epsilon, z, alpha); + return Status::OK(); +} - DECLARE_TYPES(lrelu_bp) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}) - ->setAllowedOutputTypes(0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}); - } - } +DECLARE_TYPES(lrelu_bp) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes( + 1, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}) + ->setAllowedOutputTypes( + 0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/nn/activations/prelu.cpp b/libnd4j/include/ops/declarable/generic/nn/activations/prelu.cpp index b7d260a4c740..b4d29921527b 100644 --- a/libnd4j/include/ops/declarable/generic/nn/activations/prelu.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/activations/prelu.cpp @@ -18,135 +18,163 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 24.07.2018 // - #include #if NOT_EXCLUDED(OP_prelu) #include #include + #include namespace sd { -namespace ops { - +namespace ops { //////////////////////////////////////////////////////////////////////// CONFIGURABLE_OP_IMPL(prelu, 2, 1, true, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto alpha = INPUT_VARIABLE(1); - auto output = OUTPUT_VARIABLE(0); - - std::vector sharedAxes = *block.getIArguments(); - - const int inputRank = input->rankOf(); - const int numSharedAxes = sharedAxes.size(); // can be zero as well - const Nd4jLong inputLen = input->lengthOf(); - const Nd4jLong alphaLen = alpha->lengthOf(); - const std::vector inputShape = input->getShapeAsVector(); - const std::vector alphaShape = alpha->getShapeAsVector(); - - //***** input validation *****// - std::vector expectedAlphaShape(&inputShape[1], &inputShape[inputRank]); - - REQUIRE_TRUE(inputRank > 1, 0, "PRELU OP: wrong rank of input array, expected rank should be > 1, but got %i instead !", inputRank); - - for(int i = 0; i < numSharedAxes; ++i) { - if(sharedAxes[i] <= 0) - sharedAxes[i] += inputRank - 1; - REQUIRE_TRUE(1 <= sharedAxes[i] && sharedAxes[i] <= inputRank - 1, 0, "PRELU OP: wrong axis value %i in sharedAxes at position %i, axis value must be within range [1, input_rank-1] !", sharedAxes[i], i); - expectedAlphaShape[sharedAxes[i] - 1] = 1; - } - - Nd4jLong product = 1; - for(const auto& item : expectedAlphaShape) - product *= item; - - REQUIRE_TRUE(product == alphaLen, 0, "PRELU OP: wrong shape of alpha array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedAlphaShape).c_str(), ShapeUtils::shapeAsString(alphaShape).c_str()); - // ***** end of validation ***** // - - helpers::prelu(block.launchContext(), *input, alphaShape != expectedAlphaShape ? alpha->reshape(alpha->ordering(), expectedAlphaShape) : *alpha, *output); - - return Status::OK(); + auto input = INPUT_VARIABLE(0); + auto alpha = INPUT_VARIABLE(1); + auto output = OUTPUT_VARIABLE(0); + + std::vector sharedAxes = block.getIArguments(); + + const int inputRank = input->rankOf(); + const int numSharedAxes = sharedAxes.size(); // can be zero as well + const Nd4jLong inputLen = input->lengthOf(); + const Nd4jLong alphaLen = alpha->lengthOf(); + const std::vector inputShape = input->getShapeAsVector(); + const std::vector alphaShape = alpha->getShapeAsVector(); + + //***** input validation *****// + std::vector expectedAlphaShape(&inputShape[1], + &inputShape[inputRank]); + + REQUIRE_TRUE(inputRank > 1, 0, + "PRELU OP: wrong rank of input array, expected rank should be > " + "1, but got %i instead !", + inputRank); + + for (int i = 0; i < numSharedAxes; ++i) { + if (sharedAxes[i] <= 0) sharedAxes[i] += inputRank - 1; + REQUIRE_TRUE(1 <= sharedAxes[i] && sharedAxes[i] <= inputRank - 1, 0, + "PRELU OP: wrong axis value %i in sharedAxes at position %i, " + "axis value must be within range [1, input_rank-1] !", + sharedAxes[i], i); + expectedAlphaShape[sharedAxes[i] - 1] = 1; + } + + Nd4jLong product = 1; + for (const auto& item : expectedAlphaShape) product *= item; + + REQUIRE_TRUE(product == alphaLen, 0, + "PRELU OP: wrong shape of alpha array, expected is %s, but got " + "%s instead !", + ShapeUtils::shapeAsString(expectedAlphaShape).c_str(), + ShapeUtils::shapeAsString(alphaShape).c_str()); + // ***** end of validation ***** // + + helpers::prelu(block.launchContext(), *input, + alphaShape != expectedAlphaShape + ? alpha->reshape(alpha->ordering(), expectedAlphaShape) + : *alpha, + *output); + + return Status::OK(); } - - DECLARE_TYPES(prelu) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, {ALL_FLOATS}) - ->setAllowedOutputTypes(0, {ALL_FLOATS}); - } - +DECLARE_TYPES(prelu) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes(1, {ALL_FLOATS}) + ->setAllowedOutputTypes(0, {ALL_FLOATS}); +} //////////////////////////////////////////////////////////////////////// CONFIGURABLE_OP_IMPL(prelu_bp, 3, 2, true, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto alpha = INPUT_VARIABLE(1); - auto dLdO = INPUT_VARIABLE(2); - - auto dLdI = OUTPUT_VARIABLE(0); - auto dLdA = OUTPUT_VARIABLE(1); - - std::vector sharedAxes = *block.getIArguments(); - - const int inputRank = input->rankOf(); - const int numSharedAxes = sharedAxes.size(); // can be zero as well - const Nd4jLong inputLen = input->lengthOf(); - const Nd4jLong alphaLen = alpha->lengthOf(); - const std::vector inputShape = input->getShapeAsVector(); - const std::vector alphaShape = alpha->getShapeAsVector(); - - //***** input validation *****// - - // temporary limitation imposed by Yurii - REQUIRE_TRUE(inputRank <= MAX_RANK/2, 0, "rank of input array should be <= MAX_RANK/2, but got %i instead!", inputRank); - REQUIRE_TRUE(input->lengthOf() / alpha->lengthOf() <= MAX_RANK*2, 0, "the length of input array should be no more than MAX_RANK*2 times the alpha array length, but got %lld and %lld correspondingly!", input->lengthOf(), alpha->lengthOf()); - - std::vector expectedAlphaShape(&inputShape[1], &inputShape[inputRank]); - - REQUIRE_TRUE(inputRank > 1, 0, "PRELU_BP OP: wrong rank of input array, expected rank should be > 1, but got %i instead !", inputRank); - - for(int i = 0; i < numSharedAxes; ++i) { - if(sharedAxes[i] <= 0) - sharedAxes[i] += inputRank - 1; - REQUIRE_TRUE(1 <= sharedAxes[i] && sharedAxes[i] <= inputRank - 1, 0, "PRELU_BP OP: wrong axis value %i in sharedAxes at position %i, axis value must be within range [1, input_rank-1] !", sharedAxes[i], i); - expectedAlphaShape[sharedAxes[i] - 1] = 1; - } - - Nd4jLong product = 1; - for(const auto& item : expectedAlphaShape) - product *= item; - - REQUIRE_TRUE(product == alphaLen, 0, "PRELU_BP OP: wrong shape of alpha array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedAlphaShape).c_str(), ShapeUtils::shapeAsString(alphaShape).c_str()); - // ***** end of validation ***** // - - - if(alphaShape != expectedAlphaShape) { - alpha = new NDArray(alpha->reshape(alpha->ordering(), expectedAlphaShape)); - dLdA = new NDArray(dLdA->reshape(dLdA->ordering(), expectedAlphaShape)); - } - - helpers::preluBP(block.launchContext(), *input, *alpha, *dLdO, *dLdI, *dLdA); - - if(alphaShape != expectedAlphaShape) { - delete alpha; - delete dLdA; - } - - return Status::OK(); + auto input = INPUT_VARIABLE(0); + auto alpha = INPUT_VARIABLE(1); + auto dLdO = INPUT_VARIABLE(2); + + auto dLdI = OUTPUT_VARIABLE(0); + auto dLdA = OUTPUT_VARIABLE(1); + + std::vector sharedAxes = block.getIArguments(); + + const int inputRank = input->rankOf(); + const int numSharedAxes = sharedAxes.size(); // can be zero as well + const Nd4jLong inputLen = input->lengthOf(); + const Nd4jLong alphaLen = alpha->lengthOf(); + const std::vector inputShape = input->getShapeAsVector(); + const std::vector alphaShape = alpha->getShapeAsVector(); + + //***** input validation *****// + + // temporary limitation imposed by Yurii + REQUIRE_TRUE( + inputRank <= MAX_RANK / 2, 0, + "rank of input array should be <= MAX_RANK/2, but got %i instead!", + inputRank); + REQUIRE_TRUE( + input->lengthOf() / alpha->lengthOf() <= MAX_RANK * 2, 0, + "the length of input array should be no more than MAX_RANK*2 times the " + "alpha array length, but got %lld and %lld correspondingly!", + input->lengthOf(), alpha->lengthOf()); + + std::vector expectedAlphaShape(&inputShape[1], + &inputShape[inputRank]); + + REQUIRE_TRUE(inputRank > 1, 0, + "PRELU_BP OP: wrong rank of input array, expected rank should " + "be > 1, but got %i instead !", + inputRank); + + for (int i = 0; i < numSharedAxes; ++i) { + if (sharedAxes[i] <= 0) sharedAxes[i] += inputRank - 1; + REQUIRE_TRUE(1 <= sharedAxes[i] && sharedAxes[i] <= inputRank - 1, 0, + "PRELU_BP OP: wrong axis value %i in sharedAxes at position " + "%i, axis value must be within range [1, input_rank-1] !", + sharedAxes[i], i); + expectedAlphaShape[sharedAxes[i] - 1] = 1; + } + + Nd4jLong product = 1; + for (const auto& item : expectedAlphaShape) product *= item; + + REQUIRE_TRUE(product == alphaLen, 0, + "PRELU_BP OP: wrong shape of alpha array, expected is %s, but " + "got %s instead !", + ShapeUtils::shapeAsString(expectedAlphaShape).c_str(), + ShapeUtils::shapeAsString(alphaShape).c_str()); + // ***** end of validation ***** // + + if (alphaShape != expectedAlphaShape) { + alpha = new NDArray(alpha->reshape(alpha->ordering(), expectedAlphaShape)); + dLdA = new NDArray(dLdA->reshape(dLdA->ordering(), expectedAlphaShape)); + } + + helpers::preluBP(block.launchContext(), *input, *alpha, *dLdO, *dLdI, *dLdA); + + if (alphaShape != expectedAlphaShape) { + delete alpha; + delete dLdA; + } + + return Status::OK(); } - DECLARE_TYPES(prelu_bp) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}) - ->setAllowedInputTypes(2, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}) - ->setAllowedOutputTypes(0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}) - ->setAllowedOutputTypes(1, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}); - } - - -} +DECLARE_TYPES(prelu_bp) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes( + 1, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}) + ->setAllowedInputTypes( + 2, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}) + ->setAllowedOutputTypes( + 0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}) + ->setAllowedOutputTypes( + 1, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}); } +} // namespace ops +} // namespace sd + #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/nn/activations/rationaltanh.cpp b/libnd4j/include/ops/declarable/generic/nn/activations/rationaltanh.cpp index 3386d1578289..f871bfde2622 100644 --- a/libnd4j/include/ops/declarable/generic/nn/activations/rationaltanh.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/activations/rationaltanh.cpp @@ -25,41 +25,44 @@ #include namespace sd { - namespace ops { - CONFIGURABLE_OP_IMPL(rationaltanh, 1, 1, true, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); +namespace ops { +CONFIGURABLE_OP_IMPL(rationaltanh, 1, 1, true, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - input->applyTransform(sd::transform::RationalTanh, *output); - STORE_RESULT(output); + input->applyTransform(sd::transform::RationalTanh, *output); + STORE_RESULT(output); - return Status::OK(); - } + return Status::OK(); +} - DECLARE_TYPES(rationaltanh) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedOutputTypes(0, {ALL_FLOATS}); - } +DECLARE_TYPES(rationaltanh) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedOutputTypes(0, {ALL_FLOATS}); +} - CONFIGURABLE_OP_IMPL(rationaltanh_bp, 2, 1, true, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto epsilon = INPUT_VARIABLE(1); +CONFIGURABLE_OP_IMPL(rationaltanh_bp, 2, 1, true, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto epsilon = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); + auto z = OUTPUT_VARIABLE(0); - //input->applyPairwiseTransform(pairwise::RationalTanhDerivativeE, epsilon, z, nullptr); - helpers::rationalTanhDerivative(block.launchContext(), input, epsilon, z); - return Status::OK(); - } + // input->applyPairwiseTransform(pairwise::RationalTanhDerivativeE, epsilon, + // z, nullptr); + helpers::rationalTanhDerivative(block.launchContext(), input, epsilon, z); + return Status::OK(); +} - DECLARE_TYPES(rationaltanh_bp) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}) - ->setAllowedOutputTypes(0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}); - } - } +DECLARE_TYPES(rationaltanh_bp) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes( + 1, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}) + ->setAllowedOutputTypes( + 0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/nn/activations/rectifiedtanh.cpp b/libnd4j/include/ops/declarable/generic/nn/activations/rectifiedtanh.cpp index 641ee0d0e9d1..478012372dbe 100644 --- a/libnd4j/include/ops/declarable/generic/nn/activations/rectifiedtanh.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/activations/rectifiedtanh.cpp @@ -25,41 +25,44 @@ #include namespace sd { - namespace ops { - CONFIGURABLE_OP_IMPL(rectifiedtanh, 1, 1, true, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); +namespace ops { +CONFIGURABLE_OP_IMPL(rectifiedtanh, 1, 1, true, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - input->applyTransform(sd::transform::RectifiedTanh, *output); - STORE_RESULT(output); + input->applyTransform(sd::transform::RectifiedTanh, *output); + STORE_RESULT(output); - return Status::OK(); - } + return Status::OK(); +} - DECLARE_TYPES(rectifiedtanh) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedOutputTypes(0, {ALL_FLOATS}); - } +DECLARE_TYPES(rectifiedtanh) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedOutputTypes(0, {ALL_FLOATS}); +} - CONFIGURABLE_OP_IMPL(rectifiedtanh_bp, 2, 1, true, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto epsilon = INPUT_VARIABLE(1); +CONFIGURABLE_OP_IMPL(rectifiedtanh_bp, 2, 1, true, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto epsilon = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); + auto z = OUTPUT_VARIABLE(0); - //input->applyPairwiseTransform(pairwise::RectifiedTanhDerivativeE, epsilon, z, nullptr); - helpers::rectifiedTanhDerivative(block.launchContext(), input, epsilon, z); - return Status::OK(); - } + // input->applyPairwiseTransform(pairwise::RectifiedTanhDerivativeE, epsilon, + // z, nullptr); + helpers::rectifiedTanhDerivative(block.launchContext(), input, epsilon, z); + return Status::OK(); +} - DECLARE_TYPES(rectifiedtanh_bp) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}) - ->setAllowedOutputTypes(0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}); - } - } +DECLARE_TYPES(rectifiedtanh_bp) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes( + 1, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}) + ->setAllowedOutputTypes( + 0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/nn/activations/relu.cpp b/libnd4j/include/ops/declarable/generic/nn/activations/relu.cpp index 3b42c2e5af00..3818dd205233 100644 --- a/libnd4j/include/ops/declarable/generic/nn/activations/relu.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/activations/relu.cpp @@ -25,45 +25,46 @@ #include namespace sd { - namespace ops { - CONFIGURABLE_OP_IMPL(relu, 1, 1, true, 1, 0) { - auto first = INPUT_VARIABLE(0); - auto z = OUTPUT_VARIABLE(0); +namespace ops { +CONFIGURABLE_OP_IMPL(relu, 1, 1, true, 1, 0) { + auto first = INPUT_VARIABLE(0); + auto z = OUTPUT_VARIABLE(0); - auto scalar = block.numT() > 0 ? block.getTArguments()->at(0) : 0.0; + auto scalar = block.numT() > 0 ? block.getTArguments().at(0) : 0.0; - first->applyScalar(sd::scalar::RELU, scalar, *z); + first->applyScalar(sd::scalar::RELU, scalar, *z); - STORE_RESULT(*z); + STORE_RESULT(*z); - return Status::OK(); - } + return Status::OK(); +} - DECLARE_TYPES(relu) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setSameMode(true); - } +DECLARE_TYPES(relu) { + getOpDescriptor()->setAllowedInputTypes(0, DataType::ANY)->setSameMode(true); +} - CONFIGURABLE_OP_IMPL(relu_bp, 2, 1, true, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto epsilon = INPUT_VARIABLE(1); +CONFIGURABLE_OP_IMPL(relu_bp, 2, 1, true, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto epsilon = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); + auto z = OUTPUT_VARIABLE(0); - //input->applyPairwiseTransform(pairwise::RELUDerivativeE, epsilon, z, nullptr); - helpers::reluDerivative(block.launchContext(), input, epsilon, z); - return Status::OK(); - } - DECLARE_SYN(ReluGrad, relu_bp); + // input->applyPairwiseTransform(pairwise::RELUDerivativeE, epsilon, z, + // nullptr); + helpers::reluDerivative(block.launchContext(), input, epsilon, z); + return Status::OK(); +} +DECLARE_SYN(ReluGrad, relu_bp); - DECLARE_TYPES(relu_bp) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}) - ->setAllowedOutputTypes(0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}); - } - } +DECLARE_TYPES(relu_bp) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes( + 1, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}) + ->setAllowedOutputTypes( + 0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/nn/activations/relu6.cpp b/libnd4j/include/ops/declarable/generic/nn/activations/relu6.cpp index 129c09480ac6..65dd0b74b539 100644 --- a/libnd4j/include/ops/declarable/generic/nn/activations/relu6.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/activations/relu6.cpp @@ -25,46 +25,44 @@ #include namespace sd { -namespace ops { - +namespace ops { //////////////////////////////////////////////////////////////////////// CONFIGURABLE_OP_IMPL(relu6, 1, 1, true, 1, 0) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - input->applyScalar(sd::scalar::RELU6, T_ARG(0), *output); + input->applyScalar(sd::scalar::RELU6, T_ARG(0), *output); - return Status::OK(); + return Status::OK(); } - DECLARE_TYPES(relu6) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setSameMode(true); - } +DECLARE_TYPES(relu6) { + getOpDescriptor()->setAllowedInputTypes(0, DataType::ANY)->setSameMode(true); +} //////////////////////////////////////////////////////////////////////// CONFIGURABLE_OP_IMPL(relu6_bp, 2, 1, true, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto gradO = INPUT_VARIABLE(1); - auto gradI = OUTPUT_VARIABLE(0); - - //input->applyPairwiseTransform(pairwise::RELU6DerivativeE, gradO, gradI, nullptr); - helpers::relu6Derivative(block.launchContext(), input, gradO, gradI); - return Status::OK(); + auto input = INPUT_VARIABLE(0); + auto gradO = INPUT_VARIABLE(1); + auto gradI = OUTPUT_VARIABLE(0); + + // input->applyPairwiseTransform(pairwise::RELU6DerivativeE, gradO, gradI, + // nullptr); + helpers::relu6Derivative(block.launchContext(), input, gradO, gradI); + return Status::OK(); } - DECLARE_TYPES(relu6_bp) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}) - ->setAllowedOutputTypes(0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}); - } - - - -} +DECLARE_TYPES(relu6_bp) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes( + 1, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}) + ->setAllowedOutputTypes( + 0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}); } +} // namespace ops +} // namespace sd + #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/nn/activations/selu.cpp b/libnd4j/include/ops/declarable/generic/nn/activations/selu.cpp index 7fc6aa11ac63..ebbddb9a73be 100644 --- a/libnd4j/include/ops/declarable/generic/nn/activations/selu.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/activations/selu.cpp @@ -25,42 +25,45 @@ #include namespace sd { - namespace ops { - CONFIGURABLE_OP_IMPL(selu, 1, 1, true, 0, 0) { - auto first = INPUT_VARIABLE(0); - auto z = OUTPUT_VARIABLE(0); +namespace ops { +CONFIGURABLE_OP_IMPL(selu, 1, 1, true, 0, 0) { + auto first = INPUT_VARIABLE(0); + auto z = OUTPUT_VARIABLE(0); - first->applyTransform(sd::transform::SELU, *z); + first->applyTransform(sd::transform::SELU, *z); - STORE_RESULT(*z); + STORE_RESULT(*z); - return Status::OK(); - } + return Status::OK(); +} - DECLARE_TYPES(selu) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedOutputTypes(0, {ALL_FLOATS}); - } +DECLARE_TYPES(selu) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedOutputTypes(0, {ALL_FLOATS}); +} - CONFIGURABLE_OP_IMPL(selu_bp, 2, 1, true, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto epsilon = INPUT_VARIABLE(1); +CONFIGURABLE_OP_IMPL(selu_bp, 2, 1, true, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto epsilon = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); + auto z = OUTPUT_VARIABLE(0); - //input->applyPairwiseTransform(pairwise::SELUDerivativeE, epsilon, z, nullptr); - helpers::seluDerivative(block.launchContext(), input, epsilon, z); - return Status::OK(); - } + // input->applyPairwiseTransform(pairwise::SELUDerivativeE, epsilon, z, + // nullptr); + helpers::seluDerivative(block.launchContext(), input, epsilon, z); + return Status::OK(); +} - DECLARE_TYPES(selu_bp) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}) - ->setAllowedOutputTypes(0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}); - } - } +DECLARE_TYPES(selu_bp) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes( + 1, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}) + ->setAllowedOutputTypes( + 0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/nn/activations/sigmoid.cpp b/libnd4j/include/ops/declarable/generic/nn/activations/sigmoid.cpp index 047d973e6f23..381fc7136a80 100644 --- a/libnd4j/include/ops/declarable/generic/nn/activations/sigmoid.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/activations/sigmoid.cpp @@ -24,42 +24,45 @@ #include #include namespace sd { - namespace ops { - CONFIGURABLE_OP_IMPL(sigmoid, 1, 1, true, 0, 0) { - auto first = INPUT_VARIABLE(0); - auto z = OUTPUT_VARIABLE(0); +namespace ops { +CONFIGURABLE_OP_IMPL(sigmoid, 1, 1, true, 0, 0) { + auto first = INPUT_VARIABLE(0); + auto z = OUTPUT_VARIABLE(0); - first->applyTransform(sd::transform::Sigmoid, *z); + first->applyTransform(sd::transform::Sigmoid, *z); - STORE_RESULT(*z); + STORE_RESULT(*z); - return Status::OK(); - } + return Status::OK(); +} - DECLARE_TYPES(sigmoid) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedOutputTypes(0, {ALL_FLOATS}); - } +DECLARE_TYPES(sigmoid) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedOutputTypes(0, {ALL_FLOATS}); +} - CONFIGURABLE_OP_IMPL(sigmoid_bp, 2, 1, true, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto epsilon = INPUT_VARIABLE(1); +CONFIGURABLE_OP_IMPL(sigmoid_bp, 2, 1, true, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto epsilon = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); + auto z = OUTPUT_VARIABLE(0); - //input->applyPairwiseTransform(pairwise::SigmoidDerivativeE, epsilon, z, nullptr); - helpers::sigmoidDerivative(block.launchContext(), input, epsilon, z); - return Status::OK(); - } + // input->applyPairwiseTransform(pairwise::SigmoidDerivativeE, epsilon, z, + // nullptr); + helpers::sigmoidDerivative(block.launchContext(), input, epsilon, z); + return Status::OK(); +} - DECLARE_TYPES(sigmoid_bp) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}) - ->setAllowedOutputTypes(0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}); - } - } +DECLARE_TYPES(sigmoid_bp) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes( + 1, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}) + ->setAllowedOutputTypes( + 0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/nn/activations/softplus.cpp b/libnd4j/include/ops/declarable/generic/nn/activations/softplus.cpp index 5cd17e752e93..1340eacf4170 100644 --- a/libnd4j/include/ops/declarable/generic/nn/activations/softplus.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/activations/softplus.cpp @@ -25,43 +25,46 @@ #include namespace sd { - namespace ops { - CONFIGURABLE_OP_IMPL(softplus, 1, 1, true, 0, 0) { - auto first = INPUT_VARIABLE(0); - auto z = OUTPUT_VARIABLE(0); +namespace ops { +CONFIGURABLE_OP_IMPL(softplus, 1, 1, true, 0, 0) { + auto first = INPUT_VARIABLE(0); + auto z = OUTPUT_VARIABLE(0); - first->applyTransform(sd::transform::SoftPlus, *z); + first->applyTransform(sd::transform::SoftPlus, *z); - STORE_RESULT(*z); + STORE_RESULT(*z); - return Status::OK(); - } + return Status::OK(); +} - DECLARE_TYPES(softplus) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedOutputTypes(0, {ALL_FLOATS}); - } +DECLARE_TYPES(softplus) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedOutputTypes(0, {ALL_FLOATS}); +} - CONFIGURABLE_OP_IMPL(softplus_bp, 2, 1, true, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto epsilon = INPUT_VARIABLE(1); +CONFIGURABLE_OP_IMPL(softplus_bp, 2, 1, true, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto epsilon = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); + auto z = OUTPUT_VARIABLE(0); - //input->applyPairwiseTransform(pairwise::SoftplusDerivativeE, epsilon, z, nullptr); - helpers::softPlusDerivative(block.launchContext(), input, epsilon, z); - return Status::OK(); - } - DECLARE_SYN(SoftplusGrad, softplus_bp); + // input->applyPairwiseTransform(pairwise::SoftplusDerivativeE, epsilon, z, + // nullptr); + helpers::softPlusDerivative(block.launchContext(), input, epsilon, z); + return Status::OK(); +} +DECLARE_SYN(SoftplusGrad, softplus_bp); - DECLARE_TYPES(softplus_bp) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}) - ->setAllowedOutputTypes(0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}); - } - } +DECLARE_TYPES(softplus_bp) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes( + 1, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}) + ->setAllowedOutputTypes( + 0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/nn/activations/softsign.cpp b/libnd4j/include/ops/declarable/generic/nn/activations/softsign.cpp index c7fb15fddf29..9b07c512beae 100644 --- a/libnd4j/include/ops/declarable/generic/nn/activations/softsign.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/activations/softsign.cpp @@ -25,44 +25,47 @@ #include namespace sd { - namespace ops { - CONFIGURABLE_OP_IMPL(softsign, 1, 1, true, 0, 0) { - auto first = INPUT_VARIABLE(0); - auto z = OUTPUT_VARIABLE(0); +namespace ops { +CONFIGURABLE_OP_IMPL(softsign, 1, 1, true, 0, 0) { + auto first = INPUT_VARIABLE(0); + auto z = OUTPUT_VARIABLE(0); - first->applyTransform(sd::transform::SoftSign, *z); + first->applyTransform(sd::transform::SoftSign, *z); - STORE_RESULT(*z); + STORE_RESULT(*z); - return Status::OK(); - } + return Status::OK(); +} - DECLARE_TYPES(softsign) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedOutputTypes(0, {ALL_FLOATS}); - } +DECLARE_TYPES(softsign) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedOutputTypes(0, {ALL_FLOATS}); +} - CONFIGURABLE_OP_IMPL(softsign_bp, 2, 1, true, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto epsilon = INPUT_VARIABLE(1); +CONFIGURABLE_OP_IMPL(softsign_bp, 2, 1, true, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto epsilon = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); + auto z = OUTPUT_VARIABLE(0); - //input->applyPairwiseTransform(pairwise::SoftsignDerivativeE, epsilon, z, nullptr); - helpers::softSignDerivative(block.launchContext(), input, epsilon, z); + // input->applyPairwiseTransform(pairwise::SoftsignDerivativeE, epsilon, z, + // nullptr); + helpers::softSignDerivative(block.launchContext(), input, epsilon, z); - return Status::OK(); - } - DECLARE_SYN(SoftsignGrad, softsign_bp); + return Status::OK(); +} +DECLARE_SYN(SoftsignGrad, softsign_bp); - DECLARE_TYPES(softsign_bp) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}) - ->setAllowedOutputTypes(0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}); - } - } +DECLARE_TYPES(softsign_bp) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes( + 1, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}) + ->setAllowedOutputTypes( + 0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/nn/activations/tanh.cpp b/libnd4j/include/ops/declarable/generic/nn/activations/tanh.cpp index a42552f75327..f577083ec21a 100644 --- a/libnd4j/include/ops/declarable/generic/nn/activations/tanh.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/activations/tanh.cpp @@ -25,43 +25,46 @@ #include namespace sd { - namespace ops { - CONFIGURABLE_OP_IMPL(tanh, 1, 1, true, 0, 0) { - auto first = INPUT_VARIABLE(0); - auto z = OUTPUT_VARIABLE(0); +namespace ops { +CONFIGURABLE_OP_IMPL(tanh, 1, 1, true, 0, 0) { + auto first = INPUT_VARIABLE(0); + auto z = OUTPUT_VARIABLE(0); - first->applyTransform(sd::transform::Tanh, *z); + first->applyTransform(sd::transform::Tanh, *z); - STORE_RESULT(*z); + STORE_RESULT(*z); - return Status::OK(); - } + return Status::OK(); +} - DECLARE_TYPES(tanh) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedOutputTypes(0, {ALL_FLOATS}); - } +DECLARE_TYPES(tanh) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedOutputTypes(0, {ALL_FLOATS}); +} - CONFIGURABLE_OP_IMPL(tanh_bp, 2, 1, true, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto epsilon = INPUT_VARIABLE(1); +CONFIGURABLE_OP_IMPL(tanh_bp, 2, 1, true, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto epsilon = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); + auto z = OUTPUT_VARIABLE(0); - //input->applyPairwiseTransform(pairwise::TanhDerivativeE, epsilon, z, nullptr); - helpers::tanhDerivative(block.launchContext(), input, epsilon, z); - return Status::OK(); - } - DECLARE_SYN(TanhGrad, tanh_bp); + // input->applyPairwiseTransform(pairwise::TanhDerivativeE, epsilon, z, + // nullptr); + helpers::tanhDerivative(block.launchContext(), input, epsilon, z); + return Status::OK(); +} +DECLARE_SYN(TanhGrad, tanh_bp); - DECLARE_TYPES(tanh_bp) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}) - ->setAllowedOutputTypes(0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}); - } - } +DECLARE_TYPES(tanh_bp) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes( + 1, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}) + ->setAllowedOutputTypes( + 0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/nn/activations/thresholdedrelu.cpp b/libnd4j/include/ops/declarable/generic/nn/activations/thresholdedrelu.cpp index a0cba155ac05..fe5c1c70a693 100644 --- a/libnd4j/include/ops/declarable/generic/nn/activations/thresholdedrelu.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/activations/thresholdedrelu.cpp @@ -18,57 +18,55 @@ // @author Yurii Shyrma, created on 24.07.2018 // - #include #if NOT_EXCLUDED(OP_thresholdedrelu) #include -#include #include +#include namespace sd { -namespace ops { - +namespace ops { //////////////////////////////////////////////////////////////////////// CONFIGURABLE_OP_IMPL(thresholdedrelu, 1, 1, true, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - auto scalar = block.numT() > 0 ? block.getTArguments()->at(0) : 0.0; + auto scalar = block.numT() > 0 ? block.getTArguments().at(0) : 0.0; - helpers::thresholdRelu(block.launchContext(), *input, scalar, *output); + helpers::thresholdRelu(block.launchContext(), *input, scalar, *output); - return Status::OK(); + return Status::OK(); } - DECLARE_TYPES(thresholdedrelu) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setSameMode(true); - } +DECLARE_TYPES(thresholdedrelu) { + getOpDescriptor()->setAllowedInputTypes(0, DataType::ANY)->setSameMode(true); +} //////////////////////////////////////////////////////////////////////// CONFIGURABLE_OP_IMPL(thresholdedrelu_bp, 2, 1, true, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto dLdO = INPUT_VARIABLE(1); - - auto dLdI = OUTPUT_VARIABLE(0); - auto threshold = block.numT() > 0 ? block.getTArguments()->at(0) : 0.0; + auto input = INPUT_VARIABLE(0); + auto dLdO = INPUT_VARIABLE(1); - helpers::thresholdReluDerivative(block.launchContext(), input, threshold, dLdO, dLdI); - - return Status::OK(); -} - - DECLARE_TYPES(thresholdedrelu_bp) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}) - ->setAllowedOutputTypes(0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}); - } + auto dLdI = OUTPUT_VARIABLE(0); + auto threshold = block.numT() > 0 ? block.getTArguments().at(0) : 0.0; + helpers::thresholdReluDerivative(block.launchContext(), input, threshold, + dLdO, dLdI); + return Status::OK(); } + +DECLARE_TYPES(thresholdedrelu_bp) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes( + 1, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}) + ->setAllowedOutputTypes( + 0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}); } +} // namespace ops +} // namespace sd + #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/nn/apply_sgd.cpp b/libnd4j/include/ops/declarable/generic/nn/apply_sgd.cpp index 389d07c7b651..39fb71b2e483 100644 --- a/libnd4j/include/ops/declarable/generic/nn/apply_sgd.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/apply_sgd.cpp @@ -25,38 +25,45 @@ #include namespace sd { - namespace ops { - CONFIGURABLE_OP_IMPL(apply_sgd, 2, 1, true, -2, 0) { - auto parameters = INPUT_VARIABLE(0); - auto gradients = INPUT_VARIABLE(1); - - double lr = 0.0; - - REQUIRE_TRUE(parameters->isSameShape(gradients), 0, "ApplySGD: parameters and gradients should have the same shape, but got parameters = %s and gradients = %s !", ShapeUtils::shapeAsString(parameters).c_str(), ShapeUtils::shapeAsString(gradients).c_str()); - - if (block.width() == 3) { - auto tarr = INPUT_VARIABLE(2); - lr = tarr->e(0); - } else if (block.getTArguments()->size() == 1) { - lr = T_ARG(0); - } else { - REQUIRE_TRUE(false, 0, "ApplyGradients op should have LR announced either es T argument or additional NDArray!"); - } - - auto Z = OUTPUT_VARIABLE(0); - - helpers::applyGradientDescent(block.launchContext(), parameters, gradients, lr, Z); - - return Status::OK(); - } - DECLARE_SYN(ApplyGradientDescent, apply_sgd); - } - - DECLARE_TYPES(apply_sgd) { - getOpDescriptor() - ->setAllowedInputTypes({ALL_FLOATS}) - ->setAllowedOutputTypes({ALL_FLOATS}); - } +namespace ops { +CONFIGURABLE_OP_IMPL(apply_sgd, 2, 1, true, -2, 0) { + auto parameters = INPUT_VARIABLE(0); + auto gradients = INPUT_VARIABLE(1); + + double lr = 0.0; + + REQUIRE_TRUE(parameters->isSameShape(gradients), 0, + "ApplySGD: parameters and gradients should have the same shape, " + "but got parameters = %s and gradients = %s !", + ShapeUtils::shapeAsString(parameters).c_str(), + ShapeUtils::shapeAsString(gradients).c_str()); + + if (block.width() == 3) { + auto tarr = INPUT_VARIABLE(2); + lr = tarr->e(0); + } else if (block.numT() == 1) { + lr = T_ARG(0); + } else { + REQUIRE_TRUE(false, 0, + "ApplyGradients op should have LR announced either es T " + "argument or additional NDArray!"); + } + + auto Z = OUTPUT_VARIABLE(0); + + helpers::applyGradientDescent(block.launchContext(), parameters, gradients, + lr, Z); + + return Status::OK(); +} +DECLARE_SYN(ApplyGradientDescent, apply_sgd); +} // namespace ops + +DECLARE_TYPES(apply_sgd) { + getOpDescriptor() + ->setAllowedInputTypes({ALL_FLOATS}) + ->setAllowedOutputTypes({ALL_FLOATS}); } +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/nn/batchnorm.cpp b/libnd4j/include/ops/declarable/generic/nn/batchnorm.cpp index 7018ae342ff7..f82200f51de6 100644 --- a/libnd4j/include/ops/declarable/generic/nn/batchnorm.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/batchnorm.cpp @@ -24,336 +24,398 @@ #if NOT_EXCLUDED(OP_batchnorm) #include -#include +#include namespace sd { namespace ops { - ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(batchnorm, 3, 1, false, 1, 2) { - - auto input = INPUT_VARIABLE(0); - auto mean = INPUT_VARIABLE(1); - auto variance = INPUT_VARIABLE(2); - NDArray* gamma = nullptr; - NDArray* beta = nullptr; - - auto output = OUTPUT_VARIABLE(0); - - const bool applyScale = (bool)INT_ARG(0); - const bool applyOffset = (bool)INT_ARG(1); - const double epsilon = T_ARG(0); - - if(applyScale) - gamma = INPUT_VARIABLE(3); - if(applyOffset) - beta = INPUT_VARIABLE(3 + (int)applyScale); - - const int numOfIntArgs = block.getIArguments()->size(); - const int inRank = input->rankOf(); - - // get axes args to normalize input array over - std::vector axes; - if(numOfIntArgs > 2) - for(int i = 2; i < numOfIntArgs; ++i) - axes.push_back(INT_ARG(i)); - else - axes.push_back(inRank-1); // default dimension to reduce along is last dimension - - const uint numOfAxes = axes.size(); - REQUIRE_TRUE(numOfAxes <= inRank, 0, "BATCHNORM op: too big number of input axes to normalize over, expected number should be less or equal to rank of input array, but got %i and %i correspondingly !", numOfAxes, inRank); - - // evaluate expected shape for mean, variance and gamma. These 3 arrays should have identical shapes - // for example if input shape is {2,3,4,5,6} and axes = {1,3}, then expected shape would be {1,3,1,5,1}, and if axes = {3}, then expected shape would be {5} - std::vector expShape; - if(numOfAxes == 1) - expShape.push_back(input->sizeAt(axes[0])); - else { // get, for example, something like {1, inputDim1, 1, inputDim3, 1} if axes = {1, 3} - expShape = std::vector(inRank, 1); - for(uint i = 0; i < numOfAxes; ++i) - expShape[axes[i]] = input->sizeAt(axes[i]); - } - - REQUIRE_TRUE(mean->isSameShape(expShape) , 0, "BATCHNORM op: wrong shape of mean array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(mean).c_str()); - REQUIRE_TRUE(variance->isSameShape(expShape), 0, "BATCHNORM op: wrong shape of variance array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(variance).c_str()); - if(gamma) - REQUIRE_TRUE(gamma->isSameShape(expShape), 0, "BATCHNORM op: wrong shape of gamma array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(gamma).c_str()); - if(beta) - REQUIRE_TRUE(beta->isSameShape(expShape), 0, "BATCHNORM op: wrong shape of beta array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(beta).c_str()); - - // types of all input arrays should be the same - for(unsigned long i = 1; i < block.width(); ++i) - REQUIRE_TRUE(INPUT_VARIABLE(0)->dataType() == INPUT_VARIABLE(i)->dataType(), 0, "BATCHNORM op: types of all input arrays should be the same !"); - - nd4j_debug("MKL-DNN is not used for batchnorm!\n", 0); - - // formula: output = gamma * ((input - mean) / sqrt(variance + epsilon)) + beta - // auto v = input->varianceAlongDimension(variance::SummaryStatsVariance, false, ShapeUtils::evalDimsToExclude(input->rankOf(), axes)); - // auto m = input->reduceAlongDimension(sd::reduce::Mean, ShapeUtils::evalDimsToExclude(input->rankOf(), axes)); - - helpers::batchnorm(input, mean, variance, gamma, beta, output, axes, epsilon); - - // NDArray stdInv = *v + epsilon; - // stdInv.applyTransform(transform::Reciprocal); // 1 / (variance + epsilon) - // stdInv.applyTransform(transform::Sqrt); // 1 / (variance + epsilon)^0.5 - // if(applyScale) - // stdInv *= *gamma; - - // // empty array with same shape as input - // input->applyBroadcast(sd::broadcast::Subtract, axes, m, output); - // output->applyBroadcast(sd::broadcast::Multiply, axes, &stdInv); - - // if(applyOffset) - // output->applyBroadcast(sd::broadcast::Add, axes, beta); - - // delete v; - // delete m; - - return Status::OK(); + auto input = INPUT_VARIABLE(0); + auto mean = INPUT_VARIABLE(1); + auto variance = INPUT_VARIABLE(2); + NDArray* gamma = nullptr; + NDArray* beta = nullptr; + + auto output = OUTPUT_VARIABLE(0); + + const bool applyScale = (bool)INT_ARG(0); + const bool applyOffset = (bool)INT_ARG(1); + const double epsilon = T_ARG(0); + + if (applyScale) gamma = INPUT_VARIABLE(3); + if (applyOffset) beta = INPUT_VARIABLE(3 + (int)applyScale); + + const int numOfIntArgs = block.numI(); + const int inRank = input->rankOf(); + + // get axes args to normalize input array over + std::vector axes; + if (numOfIntArgs > 2) + for (int i = 2; i < numOfIntArgs; ++i) axes.push_back(INT_ARG(i)); + else + axes.push_back(inRank - + 1); // default dimension to reduce along is last dimension + + const uint numOfAxes = axes.size(); + REQUIRE_TRUE(numOfAxes <= inRank, 0, + "BATCHNORM op: too big number of input axes to normalize over, " + "expected number should be less or equal to rank of input " + "array, but got %i and %i correspondingly !", + numOfAxes, inRank); + + // evaluate expected shape for mean, variance and gamma. These 3 arrays should + // have identical shapes for example if input shape is {2,3,4,5,6} and axes = + // {1,3}, then expected shape would be {1,3,1,5,1}, and if axes = {3}, then + // expected shape would be {5} + std::vector expShape; + if (numOfAxes == 1) + expShape.push_back(input->sizeAt(axes[0])); + else { // get, for example, something like {1, inputDim1, 1, inputDim3, 1} if + // axes = {1, 3} + expShape = std::vector(inRank, 1); + for (uint i = 0; i < numOfAxes; ++i) + expShape[axes[i]] = input->sizeAt(axes[i]); + } + + REQUIRE_TRUE(mean->isSameShape(expShape), 0, + "BATCHNORM op: wrong shape of mean array, expected is %s, but " + "got %s instead !", + ShapeUtils::shapeAsString(expShape).c_str(), + ShapeUtils::shapeAsString(mean).c_str()); + REQUIRE_TRUE(variance->isSameShape(expShape), 0, + "BATCHNORM op: wrong shape of variance array, expected is %s, " + "but got %s instead !", + ShapeUtils::shapeAsString(expShape).c_str(), + ShapeUtils::shapeAsString(variance).c_str()); + if (gamma) + REQUIRE_TRUE(gamma->isSameShape(expShape), 0, + "BATCHNORM op: wrong shape of gamma array, expected is %s, " + "but got %s instead !", + ShapeUtils::shapeAsString(expShape).c_str(), + ShapeUtils::shapeAsString(gamma).c_str()); + if (beta) + REQUIRE_TRUE(beta->isSameShape(expShape), 0, + "BATCHNORM op: wrong shape of beta array, expected is %s, but " + "got %s instead !", + ShapeUtils::shapeAsString(expShape).c_str(), + ShapeUtils::shapeAsString(beta).c_str()); + + // types of all input arrays should be the same + for (unsigned long i = 1; i < block.width(); ++i) + REQUIRE_TRUE( + INPUT_VARIABLE(0)->dataType() == INPUT_VARIABLE(i)->dataType(), 0, + "BATCHNORM op: types of all input arrays should be the same !"); + + nd4j_debug("MKL-DNN is not used for batchnorm!\n", 0); + + // formula: output = gamma * ((input - mean) / sqrt(variance + epsilon)) + + // beta auto v = input->varianceAlongDimension(variance::SummaryStatsVariance, + // false, ShapeUtils::evalDimsToExclude(input->rankOf(), axes)); auto m = + // input->reduceAlongDimension(sd::reduce::Mean, + // ShapeUtils::evalDimsToExclude(input->rankOf(), axes)); + + helpers::batchnorm(input, mean, variance, gamma, beta, output, axes, epsilon); + + // NDArray stdInv = *v + epsilon; + // stdInv.applyTransform(transform::Reciprocal); // 1 / + // (variance + epsilon) stdInv.applyTransform(transform::Sqrt); // 1 / + // (variance + epsilon)^0.5 if(applyScale) + // stdInv *= *gamma; + + // // empty array with same shape as input + // input->applyBroadcast(sd::broadcast::Subtract, axes, m, output); + // output->applyBroadcast(sd::broadcast::Multiply, axes, &stdInv); + + // if(applyOffset) + // output->applyBroadcast(sd::broadcast::Add, axes, beta); + + // delete v; + // delete m; + + return Status::OK(); } DECLARE_TYPES(batchnorm) { - getOpDescriptor()->setAllowedInputTypes({ALL_FLOATS})->setSameMode(true); + getOpDescriptor()->setAllowedInputTypes({ALL_FLOATS})->setSameMode(true); } DECLARE_SHAPE_FN(batchnorm) { + auto inShapeInfo = inputShape->at(0); + DataType outType = + DataTypeUtils::pickFloatingType(ArrayOptions::dataType(inShapeInfo)); - auto inShapeInfo = inputShape->at(0); - DataType outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(inShapeInfo)); + auto outShapeInfo = ShapeBuilders::copyShapeInfoAndType( + inShapeInfo, outType, false, + block.workspace()); // output shape is identical to input shape - auto outShapeInfo = ShapeBuilders::copyShapeInfoAndType(inShapeInfo, outType, false, block.getWorkspace()); // output shape is identical to input shape - - return SHAPELIST(CONSTANT(outShapeInfo)); + return SHAPELIST(CONSTANT(outShapeInfo)); } ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(batchnorm_bp, 4, 3, false, 1, 2) { - - NDArray* input = INPUT_VARIABLE(0); - NDArray* mean = INPUT_VARIABLE(1); - NDArray* variance = INPUT_VARIABLE(2); - NDArray* gamma = nullptr; - NDArray* beta = nullptr; - NDArray* dLdO = INPUT_VARIABLE(block.width() - 1); // next epsilon - - NDArray* dLdI = OUTPUT_VARIABLE(0); - NDArray* dLdM = OUTPUT_VARIABLE(1); - NDArray* dLdV = OUTPUT_VARIABLE(2); - NDArray* dLdG = nullptr; - NDArray* dLdB = nullptr; - - const bool applyScale = (bool)INT_ARG(0); - const bool applyOffset = (bool)INT_ARG(1); - const float epsilon = T_ARG(0); - - if(applyScale) { - gamma = INPUT_VARIABLE(3); - dLdG = OUTPUT_VARIABLE(3); - } - if(applyOffset) { - beta = INPUT_VARIABLE(3 + (int)applyScale); - dLdB = OUTPUT_VARIABLE(3 + (int)applyScale); - } - - const int numOfIntArgs = block.getIArguments()->size(); - const int inRank = input->rankOf(); - - // get axes args to normalize input array over - std::vector axes; - if(numOfIntArgs > 2) - for(int i = 2; i < numOfIntArgs; ++i) - axes.push_back(INT_ARG(i)); - else - axes.push_back(inRank-1); // default dimension to reduce along is last dimension - - const uint numOfAxes = axes.size(); - REQUIRE_TRUE(numOfAxes <= inRank, 0, "BATCHNORM_BP op: too big number of input axes to normalize over, expected number should be less or equal to rank of input array, but got %i and %i correspondingly !", numOfAxes, inRank); - - // evaluate expected shape for mean, variance and gamma. These 3 arrays should have identical shapes - // for example if input shape is {2,3,4,5,6} and axes = {1,3}, then expected shape would be {1,3,1,5,1}, and if axes = {3}, then expected shape would be {5} - std::vector expShape; - if(numOfAxes == 1) - expShape.push_back(input->sizeAt(axes[0])); - else { // get, for example, something like {1, inputDim1, 1, inputDim3, 1} if axes = {1, 3} - expShape = std::vector(inRank, 1); - for(uint i = 0; i < numOfAxes; ++i) - expShape[axes[i]] = input->sizeAt(axes[i]); - } - - REQUIRE_TRUE(mean->isSameShape(expShape), 0, "BATCHNORM_BP op: wrong shape of mean array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(mean).c_str()); - REQUIRE_TRUE(variance->isSameShape(expShape), 0, "BATCHNORM_BP op: wrong shape of variance array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(variance).c_str()); - if(gamma) - REQUIRE_TRUE(gamma->isSameShape(expShape), 0, "BATCHNORM_BP op: wrong shape of gamma array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(gamma).c_str()); - if(beta) - REQUIRE_TRUE(beta->isSameShape(expShape), 0, "BATCHNORM_BP op: wrong shape of beta array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(beta).c_str()); - - REQUIRE_TRUE(input->isSameShape(dLdO), 0, "BATCHNORM_BP op: wrong shape of output gradients array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(input).c_str(), ShapeUtils::shapeAsString(dLdO).c_str()); - - // types of all input arrays should be the same (except dLdO) - for(unsigned long i = 1; i < block.width() - 2; ++i) - REQUIRE_TRUE(INPUT_VARIABLE(0)->dataType() == INPUT_VARIABLE(i)->dataType(), 0, "BATCHNORM_BP op: types of arrays (input, mean, variance, gamma, beta) should be the same !"); - - // ***** calculations ***** // - - // notations: - // f = g * (gamma * ((x - m) / (v + eps)^0.5) + beta) -> means dLdO * ff_output, g = dLdO - // stdInv = 1 / (v + eps)^0.5 - // N - batch size (product of spatial dimensions) - - // derivatives: - // dLdI = dfdx + dfdm*dmdx + dfdv*(dvdm*dmdx + dvdx) - - // dfdx = gamma*stdInv*g; - // dfdm = -gamma*stdInv*g_sum; - // dmdx = 1/N; - // dvdx = 2 * (x - m) / N - // dvdm = -2 * [(x - m)]_sum / N - // dfdv = -0.5 * [g*(x - m)]_sum * stdInv^3, drop gamma here for calc convenience - - // finally: - // dLdI = gamma * ( stdInv * (g - g_sum/N) + (2/N) * dfdv * (dvdm/2 + (x - m)) ) - - // dLdG = (g * (x - m))_sum * stdInv - // dLdB = g_sum - - // variance = input->varianceAlongDimension(variance::SummaryStatsVariance, false, ShapeUtils::evalDimsToExclude(input->rankOf(), axes)); - // mean = input->reduceAlongDimension(sd::reduce::Mean, ShapeUtils::evalDimsToExclude(input->rankOf(), axes)); - - const auto excludedAxes = ShapeUtils::evalDimsToExclude(inRank, axes); - const bool keepUnitiesInShape = inRank == mean->rankOf(); - - // inverse batch size 1/N - const float Ninv = 1.f * shape::tadLength(input->shapeInfo(), axes.data(), axes.size()) / input->lengthOf(); - - // input - mean - NDArray xMinusMean(input); // empty array with same shape as input - input->applyBroadcast(sd::broadcast::Subtract, axes, *mean, xMinusMean); - - // stdInv - NDArray stdInv = *variance + epsilon; - stdInv.applyTransform(transform::Reciprocal, stdInv); // 1 / (variance + epsilon) - stdInv.applyTransform(transform::Sqrt, stdInv); // 1 / (variance + epsilon)^0.5 - - // dvdm (use dLdM as storage for dvdm) - xMinusMean.reduceAlongDimension(sd::reduce::Sum, *dLdM, excludedAxes, keepUnitiesInShape); - *dLdM *= -Ninv; - - // g_sum - auto gSum = dLdO->reduceAlongDimension(sd::reduce::Sum, excludedAxes, keepUnitiesInShape); - - // dLdB - if(applyOffset) - dLdB->assign(gSum); - - // stdInv * (g - g_sum/N) (use dLdI as storage for this expression) - gSum *= Ninv; - dLdO->applyBroadcast(sd::broadcast::Subtract, axes, gSum, *dLdI); - dLdI->applyBroadcast(sd::broadcast::Multiply, axes, stdInv, *dLdI); - - // dLdV <- [g*(x - m)]_sum - (xMinusMean * *dLdO).reduceAlongDimension(sd::reduce::Sum, *dLdV, excludedAxes, keepUnitiesInShape); - - // dLdG - *dLdV *= stdInv; - if(applyScale) - dLdG->assign(dLdV); - - // (2 / N) * dfdv (use dLdV as storage for dfdv) - *dLdV *= stdInv*stdInv; // dLdV*stdInv * stdInv^2 - *dLdV *= -Ninv; // -0.5f * (2 / N); - - // dfdv * (dvdm + (x - m)) (use xMinusMean as storage for this expression) - xMinusMean.applyBroadcast(sd::broadcast::Add, axes, *dLdM, xMinusMean); - xMinusMean.applyBroadcast(sd::broadcast::Multiply, axes, *dLdV, xMinusMean); - - // dLdI - *dLdI += xMinusMean; - if(applyScale) - dLdI->applyBroadcast(sd::broadcast::Multiply, axes, *gamma, *dLdI); - - *dLdM = 0; // put zeros so far - *dLdV = 0; // put zeros so far - - // java code - // NDArray std = *variance + epsilon; - // std.applyTransform(transform::Reciprocal); // 1 / (variance + epsilon) - // std.applyTransform(transform::Sqrt); // 1 / (variance + epsilon)^0.5 - // NDArray xMu(input); - // input->applyBroadcast(sd::broadcast::Subtract, axes, mean, &xMu); - // NDArray xHat(input); - // xMu.applyBroadcast(sd::broadcast::Multiply, axes, &std, &xHat); - // NDArray dxhat(input); - // dLdO->applyBroadcast(sd::broadcast::Multiply, axes, gamma, &dxhat); - // NDArray temp = dxhat*xMu; - // temp.reduceAlongDimension(reduce::Sum, dLdV, excludedAxes, keepUnitiesInShape); - // *dLdV *= -0.5f * std*std*std; - // NDArray* dxmu1 = dxhat.reduceAlongDimension(reduce::Sum, excludedAxes, keepUnitiesInShape); - // *dxmu1 *= -std; - // NDArray* dxmu2 = xMu.reduceAlongDimension(reduce::Sum, excludedAxes, keepUnitiesInShape); - // *dxmu2 *= *dLdV * (-2.f/N); - // NDArray dLdmu = *dxmu1 + *dxmu2; - // dLdmu *= (1.f /N); - // *dLdV *= (2.f/N); - // dxhat.applyBroadcast(sd::broadcast::Multiply, axes, &std); - // xMu.applyBroadcast(sd::broadcast::Multiply, axes, dLdV); - // dxhat += xMu; - // dxhat.applyBroadcast(sd::broadcast::Add, axes, &dLdmu, dLdI); - // delete dxmu1; - // delete dxmu2; - // xHat *= *dLdO; - // xHat.reduceAlongDimension(reduce::Sum, dLdG, excludedAxes, keepUnitiesInShape); - - return Status::OK(); + NDArray* input = INPUT_VARIABLE(0); + NDArray* mean = INPUT_VARIABLE(1); + NDArray* variance = INPUT_VARIABLE(2); + NDArray* gamma = nullptr; + NDArray* beta = nullptr; + NDArray* dLdO = INPUT_VARIABLE(block.width() - 1); // next epsilon + + NDArray* dLdI = OUTPUT_VARIABLE(0); + NDArray* dLdM = OUTPUT_VARIABLE(1); + NDArray* dLdV = OUTPUT_VARIABLE(2); + NDArray* dLdG = nullptr; + NDArray* dLdB = nullptr; + + const bool applyScale = (bool)INT_ARG(0); + const bool applyOffset = (bool)INT_ARG(1); + const float epsilon = T_ARG(0); + + if (applyScale) { + gamma = INPUT_VARIABLE(3); + dLdG = OUTPUT_VARIABLE(3); + } + if (applyOffset) { + beta = INPUT_VARIABLE(3 + (int)applyScale); + dLdB = OUTPUT_VARIABLE(3 + (int)applyScale); + } + + const int numOfIntArgs = block.numI(); + const int inRank = input->rankOf(); + + // get axes args to normalize input array over + std::vector axes; + if (numOfIntArgs > 2) + for (int i = 2; i < numOfIntArgs; ++i) axes.push_back(INT_ARG(i)); + else + axes.push_back(inRank - + 1); // default dimension to reduce along is last dimension + + const uint numOfAxes = axes.size(); + REQUIRE_TRUE(numOfAxes <= inRank, 0, + "BATCHNORM_BP op: too big number of input axes to normalize " + "over, expected number should be less or equal to rank of input " + "array, but got %i and %i correspondingly !", + numOfAxes, inRank); + + // evaluate expected shape for mean, variance and gamma. These 3 arrays should + // have identical shapes for example if input shape is {2,3,4,5,6} and axes = + // {1,3}, then expected shape would be {1,3,1,5,1}, and if axes = {3}, then + // expected shape would be {5} + std::vector expShape; + if (numOfAxes == 1) + expShape.push_back(input->sizeAt(axes[0])); + else { // get, for example, something like {1, inputDim1, 1, inputDim3, 1} if + // axes = {1, 3} + expShape = std::vector(inRank, 1); + for (uint i = 0; i < numOfAxes; ++i) + expShape[axes[i]] = input->sizeAt(axes[i]); + } + + REQUIRE_TRUE(mean->isSameShape(expShape), 0, + "BATCHNORM_BP op: wrong shape of mean array, expected is %s, " + "but got %s instead !", + ShapeUtils::shapeAsString(expShape).c_str(), + ShapeUtils::shapeAsString(mean).c_str()); + REQUIRE_TRUE(variance->isSameShape(expShape), 0, + "BATCHNORM_BP op: wrong shape of variance array, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString(expShape).c_str(), + ShapeUtils::shapeAsString(variance).c_str()); + if (gamma) + REQUIRE_TRUE(gamma->isSameShape(expShape), 0, + "BATCHNORM_BP op: wrong shape of gamma array, expected is %s, " + "but got %s instead !", + ShapeUtils::shapeAsString(expShape).c_str(), + ShapeUtils::shapeAsString(gamma).c_str()); + if (beta) + REQUIRE_TRUE(beta->isSameShape(expShape), 0, + "BATCHNORM_BP op: wrong shape of beta array, expected is %s, " + "but got %s instead !", + ShapeUtils::shapeAsString(expShape).c_str(), + ShapeUtils::shapeAsString(beta).c_str()); + + REQUIRE_TRUE(input->isSameShape(dLdO), 0, + "BATCHNORM_BP op: wrong shape of output gradients array, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(input).c_str(), + ShapeUtils::shapeAsString(dLdO).c_str()); + + // types of all input arrays should be the same (except dLdO) + for (unsigned long i = 1; i < block.width() - 2; ++i) + REQUIRE_TRUE(INPUT_VARIABLE(0)->dataType() == INPUT_VARIABLE(i)->dataType(), + 0, + "BATCHNORM_BP op: types of arrays (input, mean, variance, " + "gamma, beta) should be the same !"); + + // ***** calculations ***** // + + // notations: + // f = g * (gamma * ((x - m) / (v + eps)^0.5) + beta) -> means dLdO * + // ff_output, g = dLdO stdInv = 1 / (v + eps)^0.5 N - batch size (product of + // spatial dimensions) + + // derivatives: + // dLdI = dfdx + dfdm*dmdx + dfdv*(dvdm*dmdx + dvdx) + + // dfdx = gamma*stdInv*g; + // dfdm = -gamma*stdInv*g_sum; + // dmdx = 1/N; + // dvdx = 2 * (x - m) / N + // dvdm = -2 * [(x - m)]_sum / N + // dfdv = -0.5 * [g*(x - m)]_sum * stdInv^3, drop gamma here for calc + // convenience + + // finally: + // dLdI = gamma * ( stdInv * (g - g_sum/N) + (2/N) * dfdv * (dvdm/2 + (x - + // m)) ) + + // dLdG = (g * (x - m))_sum * stdInv + // dLdB = g_sum + + // variance = input->varianceAlongDimension(variance::SummaryStatsVariance, + // false, ShapeUtils::evalDimsToExclude(input->rankOf(), axes)); mean = + // input->reduceAlongDimension(sd::reduce::Mean, + // ShapeUtils::evalDimsToExclude(input->rankOf(), axes)); + + const auto excludedAxes = ShapeUtils::evalDimsToExclude(inRank, axes); + const bool keepUnitiesInShape = inRank == mean->rankOf(); + + // inverse batch size 1/N + const float Ninv = + 1.f * shape::tadLength(input->shapeInfo(), axes.data(), axes.size()) / + input->lengthOf(); + + // input - mean + NDArray xMinusMean(input); // empty array with same shape as input + input->applyBroadcast(sd::broadcast::Subtract, axes, *mean, xMinusMean); + + // stdInv + NDArray stdInv = *variance + epsilon; + stdInv.applyTransform(transform::Reciprocal, + stdInv); // 1 / (variance + epsilon) + stdInv.applyTransform(transform::Sqrt, + stdInv); // 1 / (variance + epsilon)^0.5 + + // dvdm (use dLdM as storage for dvdm) + xMinusMean.reduceAlongDimension(sd::reduce::Sum, *dLdM, excludedAxes, + keepUnitiesInShape); + *dLdM *= -Ninv; + + // g_sum + auto gSum = dLdO->reduceAlongDimension(sd::reduce::Sum, excludedAxes, + keepUnitiesInShape); + + // dLdB + if (applyOffset) dLdB->assign(gSum); + + // stdInv * (g - g_sum/N) (use dLdI as storage for this expression) + gSum *= Ninv; + dLdO->applyBroadcast(sd::broadcast::Subtract, axes, gSum, *dLdI); + dLdI->applyBroadcast(sd::broadcast::Multiply, axes, stdInv, *dLdI); + + // dLdV <- [g*(x - m)]_sum + (xMinusMean * *dLdO) + .reduceAlongDimension(sd::reduce::Sum, *dLdV, excludedAxes, + keepUnitiesInShape); + + // dLdG + *dLdV *= stdInv; + if (applyScale) dLdG->assign(dLdV); + + // (2 / N) * dfdv (use dLdV as storage for dfdv) + *dLdV *= stdInv * stdInv; // dLdV*stdInv * stdInv^2 + *dLdV *= -Ninv; // -0.5f * (2 / N); + + // dfdv * (dvdm + (x - m)) (use xMinusMean as storage for this expression) + xMinusMean.applyBroadcast(sd::broadcast::Add, axes, *dLdM, xMinusMean); + xMinusMean.applyBroadcast(sd::broadcast::Multiply, axes, *dLdV, xMinusMean); + + // dLdI + *dLdI += xMinusMean; + if (applyScale) + dLdI->applyBroadcast(sd::broadcast::Multiply, axes, *gamma, *dLdI); + + *dLdM = 0; // put zeros so far + *dLdV = 0; // put zeros so far + + // java code + // NDArray std = *variance + epsilon; + // std.applyTransform(transform::Reciprocal); // 1 / + // (variance + epsilon) std.applyTransform(transform::Sqrt); // 1 / (variance + // + epsilon)^0.5 NDArray xMu(input); + // input->applyBroadcast(sd::broadcast::Subtract, axes, mean, &xMu); + // NDArray xHat(input); + // xMu.applyBroadcast(sd::broadcast::Multiply, axes, &std, &xHat); + // NDArray dxhat(input); + // dLdO->applyBroadcast(sd::broadcast::Multiply, axes, gamma, &dxhat); + // NDArray temp = dxhat*xMu; + // temp.reduceAlongDimension(reduce::Sum, dLdV, excludedAxes, + // keepUnitiesInShape); *dLdV *= -0.5f * std*std*std; NDArray* dxmu1 = + // dxhat.reduceAlongDimension(reduce::Sum, excludedAxes, keepUnitiesInShape); + // *dxmu1 *= -std; + // NDArray* dxmu2 = xMu.reduceAlongDimension(reduce::Sum, excludedAxes, + // keepUnitiesInShape); *dxmu2 *= *dLdV * (-2.f/N); NDArray dLdmu = *dxmu1 + + // *dxmu2; dLdmu *= (1.f /N); *dLdV *= (2.f/N); + // dxhat.applyBroadcast(sd::broadcast::Multiply, axes, &std); + // xMu.applyBroadcast(sd::broadcast::Multiply, axes, dLdV); + // dxhat += xMu; + // dxhat.applyBroadcast(sd::broadcast::Add, axes, &dLdmu, dLdI); + // delete dxmu1; + // delete dxmu2; + // xHat *= *dLdO; + // xHat.reduceAlongDimension(reduce::Sum, dLdG, excludedAxes, + // keepUnitiesInShape); + + return Status::OK(); } DECLARE_TYPES(batchnorm_bp) { - getOpDescriptor() - ->setAllowedInputTypes(0, sd::DataType::ANY) - ->setAllowedInputTypes(1, sd::DataType::ANY) - ->setAllowedInputTypes(2, sd::DataType::ANY) - ->setAllowedInputTypes(3, {ALL_FLOATS}) - ->setAllowedInputTypes(4, sd::DataType::ANY) - ->setAllowedInputTypes(5, sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(0, sd::DataType::ANY) + ->setAllowedInputTypes(1, sd::DataType::ANY) + ->setAllowedInputTypes(2, sd::DataType::ANY) + ->setAllowedInputTypes(3, {ALL_FLOATS}) + ->setAllowedInputTypes(4, sd::DataType::ANY) + ->setAllowedInputTypes(5, sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } ////////////////////////////////////////////////////////////////////////// DECLARE_SHAPE_FN(batchnorm_bp) { + Nd4jLong const* inShapeInfo = inputShape->at(0); + Nd4jLong const* meanShapeInfo = inputShape->at(1); - Nd4jLong const* inShapeInfo = inputShape->at(0); - Nd4jLong const* meanShapeInfo = inputShape->at(1); + const bool applyScale = (bool)INT_ARG(0); + const bool applyOffset = (bool)INT_ARG(1); - const bool applyScale = (bool)INT_ARG(0); - const bool applyOffset = (bool)INT_ARG(1); + DataType outType = + DataTypeUtils::pickFloatingType(ArrayOptions::dataType(inShapeInfo)); - DataType outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(inShapeInfo)); + auto shapes = SHAPELIST(); - auto shapes = SHAPELIST(); + // dLdI shapeInfo + shapes->push_back(ConstantShapeHelper::getInstance().createShapeInfo( + outType, inShapeInfo)); - // dLdI shapeInfo - shapes->push_back(ConstantShapeHelper::getInstance().createShapeInfo(outType, inShapeInfo)); + // dLdM shapeInfo + shapes->push_back(ConstantShapeHelper::getInstance().createShapeInfo( + outType, meanShapeInfo)); - // dLdM shapeInfo - shapes->push_back(ConstantShapeHelper::getInstance().createShapeInfo(outType, meanShapeInfo)); + // dLdV shapeInfo (same as dLdM) + shapes->push_back(shapes->at(shapes->size() - 1)); - // dLdV shapeInfo (same as dLdM) - shapes->push_back(shapes->at(shapes->size()-1)); + // dLdG shapeInfo (same as dLdM) + if (applyScale) shapes->push_back(shapes->at(shapes->size() - 1)); - // dLdG shapeInfo (same as dLdM) - if(applyScale) - shapes->push_back(shapes->at(shapes->size()-1)); + // dLdB shapeInfo (same as dLdM) + if (applyOffset) shapes->push_back(shapes->at(shapes->size() - 1)); - // dLdB shapeInfo (same as dLdM) - if(applyOffset) - shapes->push_back(shapes->at(shapes->size()-1)); - - return shapes; + return shapes; } - -} -} +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/generic/nn/bias_add.cpp b/libnd4j/include/ops/declarable/generic/nn/bias_add.cpp index bc164e9520aa..f09879a974ad 100644 --- a/libnd4j/include/ops/declarable/generic/nn/bias_add.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/bias_add.cpp @@ -23,93 +23,105 @@ #if NOT_EXCLUDED(OP_biasadd) #include -#include +#include namespace sd { namespace ops { //////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(biasadd, 2, 1, true, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto bias = INPUT_VARIABLE(1); - auto input = INPUT_VARIABLE(0); - auto bias = INPUT_VARIABLE(1); + auto output = OUTPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); + const bool isNCHW = !block.getBArguments().empty() ? B_ARG(0) : false; + const int channelDim = isNCHW ? 1 : input->rankOf() - 1; // second or last - const bool isNCHW = !block.getBArguments()->empty() ? B_ARG(0) : false; - const int channelDim = isNCHW ? 1 : input->rankOf() - 1; // second or last + REQUIRE_TRUE(bias->rankOf() == 1, 0, + "BIASADD CUSTOM_OP: bias array should have rank = 1, but got %i " + "instead !", + bias->rankOf()); - REQUIRE_TRUE(bias->rankOf() == 1, 0, "BIASADD CUSTOM_OP: bias array should have rank = 1, but got %i instead !", bias->rankOf()); + REQUIRE_TRUE( + bias->sizeAt(0) == input->sizeAt(channelDim), 0, + "BIASADD CUSTOM_OP: shapes of bias %s and input %s arrays are not " + "suitable for broadcast operation along channel dimension %i !", + ShapeUtils::shapeAsString(bias).c_str(), + ShapeUtils::shapeAsString(input).c_str(), channelDim); - REQUIRE_TRUE(bias->sizeAt(0) == input->sizeAt(channelDim), 0, "BIASADD CUSTOM_OP: shapes of bias %s and input %s arrays are not suitable for broadcast operation along channel dimension %i !", ShapeUtils::shapeAsString(bias).c_str(), ShapeUtils::shapeAsString(input).c_str(), channelDim); + REQUIRE_TRUE(output->isSameShape(input), 0, + "BIASADD CUSTOM_OP: wrong shape of output array, expected is %s " + "but got %s instead !", + ShapeUtils::shapeAsString(input).c_str(), + ShapeUtils::shapeAsString(output).c_str()); - REQUIRE_TRUE(output->isSameShape(input), 0, "BIASADD CUSTOM_OP: wrong shape of output array, expected is %s but got %s instead !", ShapeUtils::shapeAsString(input).c_str(), ShapeUtils::shapeAsString(output).c_str()); + helpers::addBias(block, *input, *bias, *output, isNCHW); + // input->applyBroadcast(sd::broadcast::Add, {channelDim}, bias, output); - helpers::addBias(block, *input, *bias, *output, isNCHW); - // input->applyBroadcast(sd::broadcast::Add, {channelDim}, bias, output); - - return Status::OK(); + return Status::OK(); } DECLARE_SYN(bias_add, biasadd); //////////////////////////////////////////////////////////////////// DECLARE_SHAPE_FN(biasadd) { - auto xShape = inputShape->at(0); - auto yShape = inputShape->at(1); + auto xShape = inputShape->at(0); + auto yShape = inputShape->at(1); - auto dtype = ArrayOptions::dataType(yShape); - return SHAPELIST(ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(xShape, dtype))); + auto dtype = ArrayOptions::dataType(yShape); + return SHAPELIST(ConstantShapeHelper::getInstance().createShapeInfo( + ShapeDescriptor(xShape, dtype))); } DECLARE_TYPES(biasadd) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } //////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(biasadd_bp, 3, 2, false, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto bias = INPUT_VARIABLE(1); + auto gradO = INPUT_VARIABLE(2); - auto input = INPUT_VARIABLE(0); - auto bias = INPUT_VARIABLE(1); - auto gradO = INPUT_VARIABLE(2); - - auto gradI = OUTPUT_VARIABLE(0); - auto gradB = OUTPUT_VARIABLE(1); + auto gradI = OUTPUT_VARIABLE(0); + auto gradB = OUTPUT_VARIABLE(1); - const bool isNCHW = !block.getBArguments()->empty() ? B_ARG(0) : false; - const int channelDim = isNCHW ? 1 : input->rankOf() - 1; // second or last + const bool isNCHW = !block.getBArguments().empty() ? B_ARG(0) : false; + const int channelDim = isNCHW ? 1 : input->rankOf() - 1; // second or last - gradI->assign(gradO); + gradI->assign(gradO); - gradO->reduceAlongDimension(sd::reduce::Sum, *gradB, ShapeUtils::evalDimsToExclude(gradO->rankOf(), {channelDim})); + gradO->reduceAlongDimension( + sd::reduce::Sum, *gradB, + ShapeUtils::evalDimsToExclude(gradO->rankOf(), {channelDim})); - return ND4J_STATUS_OK; + return ND4J_STATUS_OK; } DECLARE_SYN(BiasAddGrad, biasadd_bp); //////////////////////////////////////////////////////////////////// DECLARE_SHAPE_FN(biasadd_bp) { - auto input = inputShape->at(0); - auto bias = inputShape->at(1); + auto input = inputShape->at(0); + auto bias = inputShape->at(1); - Nd4jLong* epsShape; - Nd4jLong* gradShape; + Nd4jLong* epsShape; + Nd4jLong* gradShape; - COPY_SHAPE(input, epsShape); - COPY_SHAPE(bias, gradShape); + COPY_SHAPE(input, epsShape); + COPY_SHAPE(bias, gradShape); - return SHAPELIST(CONSTANT(epsShape), CONSTANT(gradShape)); + return SHAPELIST(CONSTANT(epsShape), CONSTANT(gradShape)); } DECLARE_TYPES(biasadd_bp) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } - -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/nn/convo/col2im.cpp b/libnd4j/include/ops/declarable/generic/nn/convo/col2im.cpp index d6e95a582678..eb27aa59424b 100644 --- a/libnd4j/include/ops/declarable/generic/nn/convo/col2im.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/convo/col2im.cpp @@ -25,68 +25,71 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(col2im, 1, 1, false, 0, 9) { - auto x = INPUT_VARIABLE(0); - auto z = OUTPUT_NULLIFIED(0); - - REQUIRE_TRUE(x->rankOf() == 6, 0, "col2im input should be 6D, but got %i instead", x->rankOf()); - REQUIRE_TRUE(z->rankOf() == 4, 0, "col2im output should be 4D, but got %i instead", z->rankOf()); - - int strideY = INT_ARG(0); - int strideX = INT_ARG(1); - int padHeight = INT_ARG(2); - int padWidth = INT_ARG(3); - int imgHeight = INT_ARG(4); - int imgWidth = INT_ARG(5); - int dY = INT_ARG(6); //Dilation in height/y dimension - int dX = INT_ARG(7); //Dilation in width/x dimension - - LaunchContext* ctx = block.launchContext(); - helpers::col2im(*ctx, *x, *z, strideY, strideX, padHeight, padWidth, imgHeight, imgWidth, dY, dX); - - return ND4J_STATUS_OK; - } - DECLARE_SHAPE_FN(col2im) { - auto inShape = inputShape->at(0); - - int bS = shape::shapeOf(inShape)[0]; - int iD = shape::shapeOf(inShape)[1]; - - int sY = INT_ARG(0); - int sX = INT_ARG(1); - int pY = INT_ARG(2); - int pX = INT_ARG(3); - int inY = INT_ARG(4); - int inX = INT_ARG(5); - int dY = INT_ARG(6); //Dilation, height/y dimension - int dX = INT_ARG(7); //Dilation, width/x dimension - bool isSameMode = INT_ARG(8) > 0; - - Nd4jLong* zShape; - ALLOCATE(zShape, block.getWorkspace(), shape::shapeInfoLength(4), Nd4jLong); - - zShape[0] = 4; - zShape[1] = bS; - zShape[2] = iD; - zShape[3] = inY; - zShape[4] = inX; - - zShape[shape::shapeInfoLength(zShape) - 2] = 1; - zShape[shape::shapeInfoLength(zShape) - 1] = 99; - - ShapeUtils::updateStridesAndType(zShape, inShape, 'c'); - - return SHAPELIST(CONSTANT(zShape)); - } - - DECLARE_TYPES(col2im) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedOutputTypes(0, DataType::INHERIT) - ->setSameMode(true); - } - } +namespace ops { +CUSTOM_OP_IMPL(col2im, 1, 1, false, 0, 9) { + auto x = INPUT_VARIABLE(0); + auto z = OUTPUT_NULLIFIED(0); + + REQUIRE_TRUE(x->rankOf() == 6, 0, + "col2im input should be 6D, but got %i instead", x->rankOf()); + REQUIRE_TRUE(z->rankOf() == 4, 0, + "col2im output should be 4D, but got %i instead", z->rankOf()); + + int strideY = INT_ARG(0); + int strideX = INT_ARG(1); + int padHeight = INT_ARG(2); + int padWidth = INT_ARG(3); + int imgHeight = INT_ARG(4); + int imgWidth = INT_ARG(5); + int dY = INT_ARG(6); // Dilation in height/y dimension + int dX = INT_ARG(7); // Dilation in width/x dimension + + LaunchContext* ctx = block.launchContext(); + helpers::col2im(*ctx, *x, *z, strideY, strideX, padHeight, padWidth, + imgHeight, imgWidth, dY, dX); + + return ND4J_STATUS_OK; } +DECLARE_SHAPE_FN(col2im) { + auto inShape = inputShape->at(0); + + int bS = shape::shapeOf(inShape)[0]; + int iD = shape::shapeOf(inShape)[1]; + + int sY = INT_ARG(0); + int sX = INT_ARG(1); + int pY = INT_ARG(2); + int pX = INT_ARG(3); + int inY = INT_ARG(4); + int inX = INT_ARG(5); + int dY = INT_ARG(6); // Dilation, height/y dimension + int dX = INT_ARG(7); // Dilation, width/x dimension + bool isSameMode = INT_ARG(8) > 0; + + Nd4jLong* zShape; + ALLOCATE(zShape, block.workspace(), shape::shapeInfoLength(4), Nd4jLong); + + zShape[0] = 4; + zShape[1] = bS; + zShape[2] = iD; + zShape[3] = inY; + zShape[4] = inX; + + zShape[shape::shapeInfoLength(zShape) - 2] = 1; + zShape[shape::shapeInfoLength(zShape) - 1] = 99; + + ShapeUtils::updateStridesAndType(zShape, inShape, 'c'); + + return SHAPELIST(CONSTANT(zShape)); +} + +DECLARE_TYPES(col2im) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedOutputTypes(0, DataType::INHERIT) + ->setSameMode(true); +} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/nn/convo/conv1d.cpp b/libnd4j/include/ops/declarable/generic/nn/convo/conv1d.cpp index 881e6010529c..b3abc3ac683f 100644 --- a/libnd4j/include/ops/declarable/generic/nn/convo/conv1d.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/convo/conv1d.cpp @@ -18,291 +18,436 @@ // @author raver119@gmail.com // @author Yurii Shyrma - #include #if NOT_EXCLUDED(OP_conv1d) -#include #include +#include #include namespace sd { -namespace ops { - - +namespace ops { CUSTOM_OP_IMPL(conv1d, 2, 1, false, 0, 5) { - - auto input = INPUT_VARIABLE(0); // [bS, iW, iC] (NWC) or [bS, iC, iW] (NCW) - auto weights = INPUT_VARIABLE(1); // [kW, iC, oC], [oC, iC, kW], [oC, kW, iC] - auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] - - auto output = OUTPUT_NULLIFIED(0); // [bS, oW, oC] (NWC) or [bS, oC, oW] (NCW) - - int kW = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(weights->sizeAt(0));// filter(kernel) width - int sW = INT_ARG(1); // strides width - int pW = INT_ARG(2); // paddings width - int dW = INT_ARG(3); // dilations width - int paddingMode = INT_ARG(4); // 0-VALID, 1-SAME, 2-CAUSAL - int isNCW = block.getIArguments()->size() > 5 ? !INT_ARG(5) : 1; // INT_ARG(4): 0-NCW, 1-NWC - int wFormat = block.getIArguments()->size() > 6 ? INT_ARG(6) : 0; // 0 - [kW, iC, oC], 1 - [oC, iC, kW], 2 - [oC, kW, iC] - - const int rank = 3; - REQUIRE_TRUE(input->rankOf() == rank, 0, "CUSTOM CONV1D OP: rank of input array must be equal to %i, but got %i instead !", rank, input->rankOf()); - REQUIRE_TRUE(weights->rankOf() == rank, 0, "CUSTOM CONV1D OP: rank of weights array must be equal to %i, but got %i instead !", rank, weights->rankOf()); - - int indIOioC, indIiW, indWoC(0 == wFormat ? 2 : 0); - if(!isNCW) { - indIOioC = 2; indIiW = 1; - } - else { - indIOioC = 1; indIiW = 2; - } - - int bS = input->sizeAt(0); // batch size - int iW = input->sizeAt(indIiW); // input width - int iC = input->sizeAt(indIOioC); // input channels - int oC = weights->sizeAt(indWoC); // output channels - - std::vector expectedWeightsShape = 0 == wFormat ? std::vector({kW, iC, oC}) : (1 == wFormat ? std::vector({oC, iC, kW}) : std::vector({oC, kW, iC})); - REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV1D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); - if (bias) - REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV1D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); - - std::vector reshapeForInput, reshapeForOutput; - if(!isNCW) { - reshapeForInput = {input->sizeAt(0), 1, input->sizeAt(1), input->sizeAt(2)}; // [bS, iW, iC] -> [bS, 1, iW, iC] - reshapeForOutput = {output->sizeAt(0), 1, output->sizeAt(1), output->sizeAt(2)}; // [bS, oW, oC] -> [bS, 1, oW, oC] - } - else { - reshapeForInput = {input->sizeAt(0), input->sizeAt(1), 1, input->sizeAt(2)}; // [bS, iC, iW] -> [bS, iC, 1, iW] - reshapeForOutput = {output->sizeAt(0), output->sizeAt(1), 1, output->sizeAt(2)}; // [bS, oC, oW] -> [bS, oC, 1, oW] - } - - auto inputReshaped = input ->reshape(input->ordering(), reshapeForInput); - auto outputReshaped = output ->reshape(output->ordering(), reshapeForOutput, false); - auto weightsReshaped = weights->reshape(weights->ordering(), {1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)}); // [kW, iC, oC] -> [1, kW, iC, oC] - - sd::ops::conv2d conv2d; - const Nd4jStatus status = conv2d.execute({&inputReshaped, &weightsReshaped, bias}, {&outputReshaped}, {}, {1,kW, 1,sW, 0,pW, 1,dW, paddingMode, !isNCW, wFormat}, {}); - if (status != ND4J_STATUS_OK) - return status; - - // ConvolutionUtils::conv2d(block, &inputReshaped, &weightsReshaped, bias, &outputReshaped, 1,kW, 1,sW, 0,pW, 1,dW, paddingMode, isNCW, wFormat); - - return Status::OK(); + auto input = INPUT_VARIABLE(0); // [bS, iW, iC] (NWC) or [bS, iC, iW] (NCW) + auto weights = INPUT_VARIABLE(1); // [kW, iC, oC], [oC, iC, kW], [oC, kW, iC] + auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] + + auto output = + OUTPUT_NULLIFIED(0); // [bS, oW, oC] (NWC) or [bS, oC, oW] (NCW) + + int kW = INT_ARG(0) > 0 + ? INT_ARG(0) + : static_cast(weights->sizeAt(0)); // filter(kernel) width + int sW = INT_ARG(1); // strides width + int pW = INT_ARG(2); // paddings width + int dW = INT_ARG(3); // dilations width + int paddingMode = INT_ARG(4); // 0-VALID, 1-SAME, 2-CAUSAL + int isNCW = block.numI() > 5 ? !INT_ARG(5) : 1; // INT_ARG(4): 0-NCW, 1-NWC + int wFormat = + block.numI() > 6 + ? INT_ARG(6) + : 0; // 0 - [kW, iC, oC], 1 - [oC, iC, kW], 2 - [oC, kW, iC] + + const int rank = 3; + REQUIRE_TRUE(input->rankOf() == rank, 0, + "CUSTOM CONV1D OP: rank of input array must be equal to %i, but " + "got %i instead !", + rank, input->rankOf()); + REQUIRE_TRUE(weights->rankOf() == rank, 0, + "CUSTOM CONV1D OP: rank of weights array must be equal to %i, " + "but got %i instead !", + rank, weights->rankOf()); + + int indIOioC, indIiW, indWoC(0 == wFormat ? 2 : 0); + if (!isNCW) { + indIOioC = 2; + indIiW = 1; + } else { + indIOioC = 1; + indIiW = 2; + } + + int bS = input->sizeAt(0); // batch size + int iW = input->sizeAt(indIiW); // input width + int iC = input->sizeAt(indIOioC); // input channels + int oC = weights->sizeAt(indWoC); // output channels + + std::vector expectedWeightsShape = + 0 == wFormat ? std::vector({kW, iC, oC}) + : (1 == wFormat ? std::vector({oC, iC, kW}) + : std::vector({oC, kW, iC})); + REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, + "CUSTOM CONV1D OP: wrong shape of weights array, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), + ShapeUtils::shapeAsString(weights).c_str()); + if (bias) + REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, + "CUSTOM CONV1D OP: wrong shape of array with biases, expected " + "rank, length: <=2, %i, but got %i, %i instead !", + oC, bias->rankOf(), bias->lengthOf()); + + std::vector reshapeForInput, reshapeForOutput; + if (!isNCW) { + reshapeForInput = {input->sizeAt(0), 1, input->sizeAt(1), + input->sizeAt(2)}; // [bS, iW, iC] -> [bS, 1, iW, iC] + reshapeForOutput = {output->sizeAt(0), 1, output->sizeAt(1), + output->sizeAt(2)}; // [bS, oW, oC] -> [bS, 1, oW, oC] + } else { + reshapeForInput = {input->sizeAt(0), input->sizeAt(1), 1, + input->sizeAt(2)}; // [bS, iC, iW] -> [bS, iC, 1, iW] + reshapeForOutput = {output->sizeAt(0), output->sizeAt(1), 1, + output->sizeAt(2)}; // [bS, oC, oW] -> [bS, oC, 1, oW] + } + + auto inputReshaped = input->reshape(input->ordering(), reshapeForInput); + auto outputReshaped = + output->reshape(output->ordering(), reshapeForOutput, false); + auto weightsReshaped = weights->reshape( + weights->ordering(), + {1, weights->sizeAt(0), weights->sizeAt(1), + weights->sizeAt(2)}); // [kW, iC, oC] -> [1, kW, iC, oC] + + sd::ops::conv2d conv2d; + const Nd4jStatus status = conv2d.execute( + {&inputReshaped, &weightsReshaped, bias}, {&outputReshaped}, {}, + {1, kW, 1, sW, 0, pW, 1, dW, paddingMode, !isNCW, wFormat}, {}); + if (status != ND4J_STATUS_OK) return status; + + // ConvolutionUtils::conv2d(block, &inputReshaped, &weightsReshaped, bias, + // &outputReshaped, 1,kW, 1,sW, 0,pW, 1,dW, paddingMode, isNCW, wFormat); + + return Status::OK(); } - DECLARE_SHAPE_FN(conv1d) { - - auto inputShapeInfo = inputShape->at(0); - auto weightsShapeInfo = inputShape->at(1); - Nd4jLong const* biasShapeInfo = block.width() > 2 ? inputShape->at(2) : nullptr; - - int kW = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(shape::sizeAt(weightsShapeInfo, 0)); // filter(kernel) width - int sW = INT_ARG(1); // strides width - int pW = INT_ARG(2); // paddings width - int dW = INT_ARG(3); // dilations width - int paddingMode = INT_ARG(4); // 0-VALID, 1-SAME - int isNCW = block.getIArguments()->size() > 5 ? !INT_ARG(5) : 1; // INT_ARG(4): 1-NWC, 0-NCW - int wFormat = block.getIArguments()->size() > 6 ? INT_ARG(6) : 0; // 0 - [kW, iC, oC], 1 - [oC, iC, kW], 2 - [oC, kW, iC] - - int indIOioC, indIiW, indWoC(0 == wFormat ? 2 : 0); - if(!isNCW) { - indIOioC = 2; indIiW = 1; - } - else { - indIOioC = 1; indIiW = 2; - } - - const int rank = 3; - REQUIRE_TRUE(inputShapeInfo[0] == rank, 0, "CUSTOM CONV1D OP: rank of input array must be equal to %i, but got %i instead !", rank, inputShapeInfo); - REQUIRE_TRUE(weightsShapeInfo[0] == rank, 0, "CUSTOM CONV1D OP: rank of weights array must be equal to %i, but got %i instead !", rank, weightsShapeInfo); - - int bS = inputShapeInfo[1]; // batch size - int iW = inputShapeInfo[indIiW+1]; // input width - int iC = inputShapeInfo[indIOioC+1]; // input channels - int oC = weightsShapeInfo[indWoC+1]; // output channels - - std::vector expectedWeightsShape = 0 == wFormat ? std::vector({kW, iC, oC}) : (1 == wFormat ? std::vector({oC, iC, kW}) : std::vector({oC, kW, iC})); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsShapeInfo, expectedWeightsShape), 0, "CUSTOM CONV1D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); - if (biasShapeInfo) - REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM CONV1D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo)); - - int oH, oW; // output height, width - ConvolutionUtils::calcOutSizePool2D(oH,oW, 1,kW, 1,sW, 0,pW, 1,dW, 1,iW, paddingMode); - - Nd4jLong* outputShapeInfo = nullptr; - ALLOCATE(outputShapeInfo, block.getWorkspace(), shape::shapeInfoLength(rank), Nd4jLong); - - outputShapeInfo[0] = 3; - outputShapeInfo[1] = bS; - - if (isNCW) { - outputShapeInfo[2] = oC; - outputShapeInfo[3] = oW; - } else { - outputShapeInfo[2] = oW; - outputShapeInfo[3] = oC; - } - - ShapeUtils::updateStridesAndType(outputShapeInfo, weightsShapeInfo, shape::order(weightsShapeInfo)); - - return SHAPELIST(CONSTANT(outputShapeInfo)); + auto inputShapeInfo = inputShape->at(0); + auto weightsShapeInfo = inputShape->at(1); + Nd4jLong const* biasShapeInfo = + block.width() > 2 ? inputShape->at(2) : nullptr; + + int kW = INT_ARG(0) > 0 ? INT_ARG(0) + : static_cast(shape::sizeAt( + weightsShapeInfo, 0)); // filter(kernel) width + int sW = INT_ARG(1); // strides width + int pW = INT_ARG(2); // paddings width + int dW = INT_ARG(3); // dilations width + int paddingMode = INT_ARG(4); // 0-VALID, 1-SAME + int isNCW = block.numI() > 5 ? !INT_ARG(5) : 1; // INT_ARG(4): 1-NWC, 0-NCW + int wFormat = + block.numI() > 6 + ? INT_ARG(6) + : 0; // 0 - [kW, iC, oC], 1 - [oC, iC, kW], 2 - [oC, kW, iC] + + int indIOioC, indIiW, indWoC(0 == wFormat ? 2 : 0); + if (!isNCW) { + indIOioC = 2; + indIiW = 1; + } else { + indIOioC = 1; + indIiW = 2; + } + + const int rank = 3; + REQUIRE_TRUE(inputShapeInfo[0] == rank, 0, + "CUSTOM CONV1D OP: rank of input array must be equal to %i, but " + "got %i instead !", + rank, inputShapeInfo); + REQUIRE_TRUE(weightsShapeInfo[0] == rank, 0, + "CUSTOM CONV1D OP: rank of weights array must be equal to %i, " + "but got %i instead !", + rank, weightsShapeInfo); + + int bS = inputShapeInfo[1]; // batch size + int iW = inputShapeInfo[indIiW + 1]; // input width + int iC = inputShapeInfo[indIOioC + 1]; // input channels + int oC = weightsShapeInfo[indWoC + 1]; // output channels + + std::vector expectedWeightsShape = + 0 == wFormat ? std::vector({kW, iC, oC}) + : (1 == wFormat ? std::vector({oC, iC, kW}) + : std::vector({oC, kW, iC})); + REQUIRE_TRUE( + ShapeUtils::areShapesEqual(weightsShapeInfo, expectedWeightsShape), 0, + "CUSTOM CONV1D OP: wrong shape of weights array, expected is %s, but got " + "%s instead !", + ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), + ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); + if (biasShapeInfo) + REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, + "CUSTOM CONV1D OP: wrong shape of array with biases, expected " + "rank, length: <=2, %i, but got %i, %i instead !", + oC, biasShapeInfo[0], shape::length(biasShapeInfo)); + + int oH, oW; // output height, width + ConvolutionUtils::calcOutSizePool2D(oH, oW, 1, kW, 1, sW, 0, pW, 1, dW, 1, iW, + paddingMode); + + Nd4jLong* outputShapeInfo = nullptr; + ALLOCATE(outputShapeInfo, block.workspace(), shape::shapeInfoLength(rank), + Nd4jLong); + + outputShapeInfo[0] = 3; + outputShapeInfo[1] = bS; + + if (isNCW) { + outputShapeInfo[2] = oC; + outputShapeInfo[3] = oW; + } else { + outputShapeInfo[2] = oW; + outputShapeInfo[3] = oC; + } + + ShapeUtils::updateStridesAndType(outputShapeInfo, weightsShapeInfo, + shape::order(weightsShapeInfo)); + + return SHAPELIST(CONSTANT(outputShapeInfo)); } DECLARE_TYPES(conv1d) { - getOpDescriptor() - ->setAllowedInputTypes(0, {ALL_FLOATS, ALL_INTS, DataType::QINT8, DataType::QINT16}) - ->setAllowedInputTypes(1, {ALL_FLOATS}) - ->setAllowedInputTypes(2, {ALL_FLOATS}) - ->setAllowedOutputTypes(0, {ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes( + 0, {ALL_FLOATS, ALL_INTS, DataType::QINT8, DataType::QINT16}) + ->setAllowedInputTypes(1, {ALL_FLOATS}) + ->setAllowedInputTypes(2, {ALL_FLOATS}) + ->setAllowedOutputTypes(0, {ALL_FLOATS}); } - ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(conv1d_bp, 3, 2, false, 0, 5) { - - auto input = INPUT_VARIABLE(0); // [bS, iW, iC] (NWC) or [bS, iC, iW] (NCW) - auto weights = INPUT_VARIABLE(1); // [kW, iC, oC], [oC, iC, kW], [oC, kW, iC] - auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] - auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oW, oC] (NWC) or [bS, oC, oW] (NCW), epsilon_next - - auto gradI = OUTPUT_NULLIFIED(0); // [bS, iW, iC] (NWC) or [bS, iC, iW] (NCW), epsilon - auto gradW = OUTPUT_NULLIFIED(1); // [kW, iC, oC], [oC, iC, kW], [oC, kW, iC] - auto gradB = block.width() > 3 ? OUTPUT_NULLIFIED(2) : nullptr; // [oC] - - int kW = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(weights->sizeAt(0));// filter(kernel) width - int sW = INT_ARG(1); // strides width - int pW = INT_ARG(2); // paddings width - int dW = INT_ARG(3); // dilations width - int paddingMode = INT_ARG(4); // 0-VALID, 1-SAME, 2-CAUSAL - int isNCW = block.getIArguments()->size() > 5 ? !INT_ARG(5) : 1; // INT_ARG(4): 1-NWC, 0-NCW - int wFormat = block.getIArguments()->size() > 6 ? INT_ARG(6) : 0; // 0 - [kW, iC, oC], 1 - [oC, iC, kW], 2 - [oC, kW, iC] - - const int rank = 3; - REQUIRE_TRUE(input->rankOf() == rank, 0, "CUSTOM CONV1D_BP OP: rank of input array must be equal to %i, but got %i instead !", rank, input->rankOf()); - REQUIRE_TRUE(weights->rankOf() == rank, 0, "CUSTOM CONV1D_BP OP: rank of weights array must be equal to %i, but got %i instead !", rank, weights->rankOf()); - REQUIRE_TRUE(gradO->rankOf() == rank, 0, "CUSTOM CONV1D_BP OP: rank of output gradients (next epsilon) array must be equal to %i, but got %i instead !", rank, gradO->rankOf()); - - int indIOioC, indIiW, indWoC(0 == wFormat ? 2 : 0); - if(!isNCW) { - indIOioC = 2; indIiW = 1; - } - else { - indIOioC = 1; indIiW = 2; - } - - const int bS = input->sizeAt(0); // batch size - const int iW = input->sizeAt(indIiW); // input width - const int iC = input->sizeAt(indIOioC); // input channels - const int oC = weights->sizeAt(indWoC); // output channels - - int trueoH, trueoW; // true output height, width - ConvolutionUtils::calcOutSizePool2D(trueoH,trueoW, 1,kW, 1,sW, 0,pW, 1,dW, 1,iW, paddingMode); - - std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoW, 0,indIOioC,indIiW}); - std::vector expectedWeightsShape = 0 == wFormat ? std::vector({kW, iC, oC}) : (1 == wFormat ? std::vector({oC, iC, kW}) : std::vector({oC, kW, iC})); - REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM CONV1D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); - REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV1D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); - if(bias) - REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV1D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); - - std::vector reshapeForInput, reshapeForGradO; - if(!isNCW) { - reshapeForInput = {input->sizeAt(0), 1, input->sizeAt(1), input->sizeAt(2)}; // [bS, iW, iC] -> [bS, 1, iW, iC] - reshapeForGradO = {gradO->sizeAt(0), 1, gradO->sizeAt(1), gradO->sizeAt(2)}; // [bS, oW, oC] -> [bS, 1, oW, oC] - } - else { - reshapeForInput = {input->sizeAt(0), input->sizeAt(1), 1, input->sizeAt(2)}; // [bS, iC, iW] -> [bS, iC, 1, iW] - reshapeForGradO = {gradO->sizeAt(0), gradO->sizeAt(1), 1, gradO->sizeAt(2)}; // [bS, oC, oW] -> [bS, oC, 1, oW] - } - - auto inputReshaped = input ->reshape(input->ordering(), reshapeForInput); - auto gradIReshaped = gradI ->reshape(gradI->ordering(), reshapeForInput, false); - auto gradOReshaped = gradO ->reshape(gradO->ordering(), reshapeForGradO); - auto weightsReshaped = weights->reshape(weights->ordering(),{1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)}); // [kW, iC, oC] -> [1, kW, iC, oC] - auto gradWReshaped = gradW ->reshape(gradW->ordering(), {1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)}, false);// [kW, iC, oC] -> [1, kW, iC, oC] - - sd::ops::conv2d_bp conv2dBP; - auto status = conv2dBP.execute({&inputReshaped, &weightsReshaped, bias, &gradOReshaped}, {&gradIReshaped, &gradWReshaped, gradB}, {}, {1,kW, 1,sW, 0,pW, 1,dW, paddingMode, !isNCW, wFormat}, {}); - if (status != ND4J_STATUS_OK) - return status; - - // ConvolutionUtils::conv2dBP(block, &inputReshaped, &weightsReshaped, bias, &gradOReshaped, &gradIReshaped, &gradWReshaped, gradB, 1,kW, 1,sW, 0,pW, 1,dW, paddingMode, isNCW, wFormat); - - return Status::OK(); + auto input = INPUT_VARIABLE(0); // [bS, iW, iC] (NWC) or [bS, iC, iW] (NCW) + auto weights = INPUT_VARIABLE(1); // [kW, iC, oC], [oC, iC, kW], [oC, kW, iC] + auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] + auto gradO = + block.width() > 3 + ? INPUT_VARIABLE(3) + : INPUT_VARIABLE( + 2); // [bS, oW, oC] (NWC) or [bS, oC, oW] (NCW), epsilon_next + + auto gradI = + OUTPUT_NULLIFIED(0); // [bS, iW, iC] (NWC) or [bS, iC, iW] (NCW), epsilon + auto gradW = OUTPUT_NULLIFIED(1); // [kW, iC, oC], [oC, iC, kW], [oC, kW, iC] + auto gradB = block.width() > 3 ? OUTPUT_NULLIFIED(2) : nullptr; // [oC] + + int kW = INT_ARG(0) > 0 + ? INT_ARG(0) + : static_cast(weights->sizeAt(0)); // filter(kernel) width + int sW = INT_ARG(1); // strides width + int pW = INT_ARG(2); // paddings width + int dW = INT_ARG(3); // dilations width + int paddingMode = INT_ARG(4); // 0-VALID, 1-SAME, 2-CAUSAL + int isNCW = block.numI() > 5 ? !INT_ARG(5) : 1; // INT_ARG(4): 1-NWC, 0-NCW + int wFormat = + block.numI() > 6 + ? INT_ARG(6) + : 0; // 0 - [kW, iC, oC], 1 - [oC, iC, kW], 2 - [oC, kW, iC] + + const int rank = 3; + REQUIRE_TRUE(input->rankOf() == rank, 0, + "CUSTOM CONV1D_BP OP: rank of input array must be equal to %i, " + "but got %i instead !", + rank, input->rankOf()); + REQUIRE_TRUE(weights->rankOf() == rank, 0, + "CUSTOM CONV1D_BP OP: rank of weights array must be equal to " + "%i, but got %i instead !", + rank, weights->rankOf()); + REQUIRE_TRUE(gradO->rankOf() == rank, 0, + "CUSTOM CONV1D_BP OP: rank of output gradients (next epsilon) " + "array must be equal to %i, but got %i instead !", + rank, gradO->rankOf()); + + int indIOioC, indIiW, indWoC(0 == wFormat ? 2 : 0); + if (!isNCW) { + indIOioC = 2; + indIiW = 1; + } else { + indIOioC = 1; + indIiW = 2; + } + + const int bS = input->sizeAt(0); // batch size + const int iW = input->sizeAt(indIiW); // input width + const int iC = input->sizeAt(indIOioC); // input channels + const int oC = weights->sizeAt(indWoC); // output channels + + int trueoH, trueoW; // true output height, width + ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, 1, kW, 1, sW, 0, pW, 1, + dW, 1, iW, paddingMode); + + std::vector expectedGradOShape = + ShapeUtils::composeShapeUsingDimsAndIdx( + {bS, oC, trueoW, 0, indIOioC, indIiW}); + std::vector expectedWeightsShape = + 0 == wFormat ? std::vector({kW, iC, oC}) + : (1 == wFormat ? std::vector({oC, iC, kW}) + : std::vector({oC, kW, iC})); + REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, + "CUSTOM CONV1D_BP OP: wrong shape of output gradients (next " + "epsilon) array, expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedGradOShape).c_str(), + ShapeUtils::shapeAsString(gradO).c_str()); + REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, + "CUSTOM CONV1D_BP OP: wrong shape of weights array, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), + ShapeUtils::shapeAsString(weights).c_str()); + if (bias) + REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, + "CUSTOM CONV1D_BP OP: wrong shape of array with biases, " + "expected rank, length: <=2, %i, but got %i, %i instead !", + oC, bias->rankOf(), bias->lengthOf()); + + std::vector reshapeForInput, reshapeForGradO; + if (!isNCW) { + reshapeForInput = {input->sizeAt(0), 1, input->sizeAt(1), + input->sizeAt(2)}; // [bS, iW, iC] -> [bS, 1, iW, iC] + reshapeForGradO = {gradO->sizeAt(0), 1, gradO->sizeAt(1), + gradO->sizeAt(2)}; // [bS, oW, oC] -> [bS, 1, oW, oC] + } else { + reshapeForInput = {input->sizeAt(0), input->sizeAt(1), 1, + input->sizeAt(2)}; // [bS, iC, iW] -> [bS, iC, 1, iW] + reshapeForGradO = {gradO->sizeAt(0), gradO->sizeAt(1), 1, + gradO->sizeAt(2)}; // [bS, oC, oW] -> [bS, oC, 1, oW] + } + + auto inputReshaped = input->reshape(input->ordering(), reshapeForInput); + auto gradIReshaped = + gradI->reshape(gradI->ordering(), reshapeForInput, false); + auto gradOReshaped = gradO->reshape(gradO->ordering(), reshapeForGradO); + auto weightsReshaped = weights->reshape( + weights->ordering(), + {1, weights->sizeAt(0), weights->sizeAt(1), + weights->sizeAt(2)}); // [kW, iC, oC] -> [1, kW, iC, oC] + auto gradWReshaped = gradW->reshape( + gradW->ordering(), + {1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)}, + false); // [kW, iC, oC] -> [1, kW, iC, oC] + + sd::ops::conv2d_bp conv2dBP; + auto status = conv2dBP.execute( + {&inputReshaped, &weightsReshaped, bias, &gradOReshaped}, + {&gradIReshaped, &gradWReshaped, gradB}, {}, + {1, kW, 1, sW, 0, pW, 1, dW, paddingMode, !isNCW, wFormat}, {}); + if (status != ND4J_STATUS_OK) return status; + + // ConvolutionUtils::conv2dBP(block, &inputReshaped, &weightsReshaped, bias, + // &gradOReshaped, &gradIReshaped, &gradWReshaped, gradB, 1,kW, 1,sW, 0,pW, + // 1,dW, paddingMode, isNCW, wFormat); + + return Status::OK(); } - DECLARE_SHAPE_FN(conv1d_bp) { - - auto inputShapeInfo = inputShape->at(0); // [bS, iW, iC] (NWC) or [bS, iC, iW] (NCW) - auto weightsShapeInfo = inputShape->at(1); // [kW, iC, oC], [oC, iC, kW], [oC, kW, iC] - Nd4jLong const* biasShapeInfo = block.width() > 3 ? inputShape->at(2) : nullptr; // [oC] - Nd4jLong const* gradOShapeInfo = block.width() > 3 ? inputShape->at(3) : inputShape->at(2); // [bS, oW, oC] (NWC) or [bS, oC, oW] (NCW), epsilon_next - - const int rank = 3; - REQUIRE_TRUE(inputShapeInfo[0] == rank, 0, "CUSTOM CONV1D_BP OP: rank of input array must be equal to %i, but got %i instead !", rank, inputShapeInfo[0]); - REQUIRE_TRUE(weightsShapeInfo[0] == rank, 0, "CUSTOM CONV1D_BP OP: rank of weights array must be equal to %i, but got %i instead !", rank, weightsShapeInfo[0]); - REQUIRE_TRUE(gradOShapeInfo[0] == rank, 0, "CUSTOM CONV1D_BP OP: rank of output gradients (next epsilon) array must be equal to %i, but got %i instead !", rank, gradOShapeInfo[0]); - - int kW = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(shape::sizeAt(weightsShapeInfo, 0));// filter(kernel) width - int sW = INT_ARG(1); // strides width - int pW = INT_ARG(2); // paddings width - int dW = INT_ARG(3); // dilations width - int paddingMode = INT_ARG(4); // 0-VALID, 1-SAME - int isNCW = block.getIArguments()->size() > 5 ? !INT_ARG(5) : 1; // INT_ARG(4): 1-NWC, 0-NCW - int wFormat = block.getIArguments()->size() > 6 ? INT_ARG(6) : 0; // 0 - [kW, iC, oC], 1 - [oC, iC, kW], 2 - [oC, kW, iC] - - int indIOioC, indIiW, indWoC(0 == wFormat ? 2 : 0); - if(!isNCW) { - indIOioC = 2; indIiW = 1; - } - else { - indIOioC = 1; indIiW = 2; - } - - const int bS = inputShapeInfo[1]; // batch size - const int iW = inputShapeInfo[indIiW+1]; // input width - const int iC = inputShapeInfo[indIOioC+1]; // input channels - const int oC = weightsShapeInfo[indWoC+1]; // output channels - - int trueoH, trueoW; // true output height, width - ConvolutionUtils::calcOutSizePool2D(trueoH,trueoW, 1,kW, 1,sW, 0,pW, 1,dW, 1,iW, paddingMode); - - std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoW, 0,indIOioC,indIiW}); - std::vector expectedWeightsShape = 0 == wFormat ? std::vector({kW, iC, oC}) : (1 == wFormat ? std::vector({oC, iC, kW}) : std::vector({oC, kW, iC})); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(gradOShapeInfo, expectedGradOShape), 0, "CUSTOM CONV1D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradOShapeInfo).c_str()); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsShapeInfo, expectedWeightsShape), 0, "CUSTOM CONV1D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); - if(biasShapeInfo) - REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM CONV1D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo)); - - auto gradIshapeInfo = ShapeBuilders::copyShapeInfoAndType(inputShapeInfo, gradOShapeInfo, false, block.getWorkspace()); - auto gradWshapeInfo = ShapeBuilders::copyShapeInfoAndType(weightsShapeInfo, gradOShapeInfo, false, block.getWorkspace()); - - if(biasShapeInfo) { - auto gradBshapeInfo = ShapeBuilders::copyShapeInfoAndType(biasShapeInfo, gradOShapeInfo, false, block.getWorkspace()); - return SHAPELIST(CONSTANT(gradIshapeInfo), CONSTANT(gradWshapeInfo), CONSTANT(gradBshapeInfo)); - } - - return SHAPELIST(CONSTANT(gradIshapeInfo), CONSTANT(gradWshapeInfo)); + auto inputShapeInfo = + inputShape->at(0); // [bS, iW, iC] (NWC) or [bS, iC, iW] (NCW) + auto weightsShapeInfo = + inputShape->at(1); // [kW, iC, oC], [oC, iC, kW], [oC, kW, iC] + Nd4jLong const* biasShapeInfo = + block.width() > 3 ? inputShape->at(2) : nullptr; // [oC] + Nd4jLong const* gradOShapeInfo = + block.width() > 3 + ? inputShape->at(3) + : inputShape->at( + 2); // [bS, oW, oC] (NWC) or [bS, oC, oW] (NCW), epsilon_next + + const int rank = 3; + REQUIRE_TRUE(inputShapeInfo[0] == rank, 0, + "CUSTOM CONV1D_BP OP: rank of input array must be equal to %i, " + "but got %i instead !", + rank, inputShapeInfo[0]); + REQUIRE_TRUE(weightsShapeInfo[0] == rank, 0, + "CUSTOM CONV1D_BP OP: rank of weights array must be equal to " + "%i, but got %i instead !", + rank, weightsShapeInfo[0]); + REQUIRE_TRUE(gradOShapeInfo[0] == rank, 0, + "CUSTOM CONV1D_BP OP: rank of output gradients (next epsilon) " + "array must be equal to %i, but got %i instead !", + rank, gradOShapeInfo[0]); + + int kW = INT_ARG(0) > 0 ? INT_ARG(0) + : static_cast(shape::sizeAt( + weightsShapeInfo, 0)); // filter(kernel) width + int sW = INT_ARG(1); // strides width + int pW = INT_ARG(2); // paddings width + int dW = INT_ARG(3); // dilations width + int paddingMode = INT_ARG(4); // 0-VALID, 1-SAME + int isNCW = block.numI() > 5 ? !INT_ARG(5) : 1; // INT_ARG(4): 1-NWC, 0-NCW + int wFormat = + block.numI() > 6 + ? INT_ARG(6) + : 0; // 0 - [kW, iC, oC], 1 - [oC, iC, kW], 2 - [oC, kW, iC] + + int indIOioC, indIiW, indWoC(0 == wFormat ? 2 : 0); + if (!isNCW) { + indIOioC = 2; + indIiW = 1; + } else { + indIOioC = 1; + indIiW = 2; + } + + const int bS = inputShapeInfo[1]; // batch size + const int iW = inputShapeInfo[indIiW + 1]; // input width + const int iC = inputShapeInfo[indIOioC + 1]; // input channels + const int oC = weightsShapeInfo[indWoC + 1]; // output channels + + int trueoH, trueoW; // true output height, width + ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, 1, kW, 1, sW, 0, pW, 1, + dW, 1, iW, paddingMode); + + std::vector expectedGradOShape = + ShapeUtils::composeShapeUsingDimsAndIdx( + {bS, oC, trueoW, 0, indIOioC, indIiW}); + std::vector expectedWeightsShape = + 0 == wFormat ? std::vector({kW, iC, oC}) + : (1 == wFormat ? std::vector({oC, iC, kW}) + : std::vector({oC, kW, iC})); + REQUIRE_TRUE(ShapeUtils::areShapesEqual(gradOShapeInfo, expectedGradOShape), + 0, + "CUSTOM CONV1D_BP OP: wrong shape of output gradients (next " + "epsilon) array, expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedGradOShape).c_str(), + ShapeUtils::shapeAsString(gradOShapeInfo).c_str()); + REQUIRE_TRUE( + ShapeUtils::areShapesEqual(weightsShapeInfo, expectedWeightsShape), 0, + "CUSTOM CONV1D_BP OP: wrong shape of weights array, expected is %s, but " + "got %s instead !", + ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), + ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); + if (biasShapeInfo) + REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, + "CUSTOM CONV1D_BP OP: wrong shape of array with biases, " + "expected rank, length: <=2, %i, but got %i, %i instead !", + oC, biasShapeInfo[0], shape::length(biasShapeInfo)); + + auto gradIshapeInfo = ShapeBuilders::copyShapeInfoAndType( + inputShapeInfo, gradOShapeInfo, false, block.workspace()); + auto gradWshapeInfo = ShapeBuilders::copyShapeInfoAndType( + weightsShapeInfo, gradOShapeInfo, false, block.workspace()); + + if (biasShapeInfo) { + auto gradBshapeInfo = ShapeBuilders::copyShapeInfoAndType( + biasShapeInfo, gradOShapeInfo, false, block.workspace()); + return SHAPELIST(CONSTANT(gradIshapeInfo), CONSTANT(gradWshapeInfo), + CONSTANT(gradBshapeInfo)); + } + + return SHAPELIST(CONSTANT(gradIshapeInfo), CONSTANT(gradWshapeInfo)); } DECLARE_TYPES(conv1d_bp) { - getOpDescriptor() - ->setAllowedInputTypes(0, {ALL_FLOATS, ALL_INTS, DataType::QINT8, DataType::QINT16}) - ->setAllowedInputTypes(1, {ALL_FLOATS}) - ->setAllowedInputTypes(2, {ALL_FLOATS}) - ->setAllowedInputTypes(3, {ALL_FLOATS}) - ->setAllowedOutputTypes(0, {ALL_FLOATS}) - ->setAllowedOutputTypes(1, {ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes( + 0, {ALL_FLOATS, ALL_INTS, DataType::QINT8, DataType::QINT16}) + ->setAllowedInputTypes(1, {ALL_FLOATS}) + ->setAllowedInputTypes(2, {ALL_FLOATS}) + ->setAllowedInputTypes(3, {ALL_FLOATS}) + ->setAllowedOutputTypes(0, {ALL_FLOATS}) + ->setAllowedOutputTypes(1, {ALL_FLOATS}); } - - -} -} +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/generic/nn/convo/conv2d.cpp b/libnd4j/include/ops/declarable/generic/nn/convo/conv2d.cpp index 4377c1487217..bc553a07f73d 100644 --- a/libnd4j/include/ops/declarable/generic/nn/convo/conv2d.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/convo/conv2d.cpp @@ -24,378 +24,567 @@ #include #if NOT_EXCLUDED(OP_conv2d) -#include -#include #include #include #include +#include -namespace sd { -namespace ops { +#include +namespace sd { +namespace ops { CUSTOM_OP_IMPL(conv2d, 2, 1, false, 0, 9) { - - auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] - auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] - - auto output = OUTPUT_NULLIFIED(0); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) - - int sH = INT_ARG(2); // strides height - int sW = INT_ARG(3); // strides width - int pH = INT_ARG(4); // paddings height - int pW = INT_ARG(5); // paddings width - int dH = INT_ARG(6); // dilations height - int dW = INT_ARG(7); // dilations width - int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME - int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC - int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - [oC, kH, kW, iC] - - int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(weights->sizeAt(0)); // filter(kernel) height - int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast(weights->sizeAt(1)); // filter(kernel) width - - int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; - int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); - - std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, oC); - REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV2D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); - if (bias) - REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); - - ConvolutionUtils::conv2d(block, input, weights, bias, output, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW,wFormat); - - return Status::OK(); + auto input = + INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + auto weights = INPUT_VARIABLE( + 1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] + auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] + + auto output = OUTPUT_NULLIFIED( + 0); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) + + int sH = INT_ARG(2); // strides height + int sW = INT_ARG(3); // strides width + int pH = INT_ARG(4); // paddings height + int pW = INT_ARG(5); // paddings width + int dH = INT_ARG(6); // dilations height + int dW = INT_ARG(7); // dilations width + int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME + int isNCHW = + block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC + int wFormat = + block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, + // iC, kH, kW], 2 - [oC, kH, kW, iC] + + int kH = INT_ARG(0) > 0 + ? INT_ARG(0) + : static_cast(weights->sizeAt(0)); // filter(kernel) height + int kW = INT_ARG(1) > 0 + ? INT_ARG(1) + : static_cast(weights->sizeAt(1)); // filter(kernel) width + + int bS, iC, iH, iW, oC, oH, + oW; // batch size, input channels, input height/width, output channels, + // output height/width; + int indIOioC, indIiH, indWoC, indWiC, indWkH, + indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, + indIiH, indWiC, indWoC, indWkH, indOoH); + + std::vector expectedWeightsShape = + ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, oC); + REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, + "CUSTOM CONV2D OP: wrong shape of weights array, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), + ShapeUtils::shapeAsString(weights).c_str()); + if (bias) + REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, + "CUSTOM CONV2D OP: wrong shape of array with biases, expected " + "rank, length: <=2, %i, but got %i, %i instead !", + oC, bias->rankOf(), bias->lengthOf()); + + ConvolutionUtils::conv2d(block, input, weights, bias, output, kH, kW, sH, sW, + pH, pW, dH, dW, isSameMode, isNCHW, wFormat); + + return Status::OK(); } - - DECLARE_SHAPE_FN(conv2d) { - - auto inputShapeInfo = inputShape->at(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - auto weightsShapeInfo = inputShape->at(1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] - auto biasShapeInfo = block.width() > 2 ? inputShape->at(2) : nullptr; // [oC] - - //output [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) - - int sH = INT_ARG(2); // strides height - int sW = INT_ARG(3); // strides width - int pH = INT_ARG(4); // paddings height - int pW = INT_ARG(5); // paddings width - int dH = INT_ARG(6); // dilations height - int dW = INT_ARG(7); // dilations width - int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME - int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC - int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - [oC, kH, kW, iC] - - int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(shape::sizeAt(weightsShapeInfo, 0)); // filter(kernel) height - int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast(shape::sizeAt(weightsShapeInfo, 1)); // filter(kernel) width - - const int rank = 4; // 4 - - REQUIRE_TRUE(inputShapeInfo[0] == rank, 0, "CUSTOM CONV2D OP: rank of input array must be equal to %i, but got %i instead !", rank, inputShapeInfo[0]); - REQUIRE_TRUE(weightsShapeInfo[0] == rank, 0, "CUSTOM CONV2D OP: rank of weights array must be equal to %i, but got %i instead !", rank, weightsShapeInfo[0]); - - int indIOioC, indIiH, indWoC(0 == wFormat ? 3 : 0); - if(!isNCHW) { - indIOioC = 3; indIiH = 1; - } - else { - indIOioC = 1; indIiH = 2; - } - - const int bS = inputShapeInfo[1]; // batch size - const int iH = inputShapeInfo[indIiH+1]; // input height - const int iW = inputShapeInfo[indIiH+2]; // input width - const int iC = inputShapeInfo[indIOioC+1]; // input channels - const int oC = weightsShapeInfo[indWoC+1]; // output channels - - std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, oC); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsShapeInfo, expectedWeightsShape), 0, "CUSTOM CONV2D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); - if (biasShapeInfo) - REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM CONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo)); - - Nd4jLong* outputShapeInfo = nullptr; - ALLOCATE(outputShapeInfo, block.getWorkspace(), shape::shapeInfoLength(rank), Nd4jLong); - - int oH, oW; // output height, width - ConvolutionUtils::calcOutSizePool2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode); - - outputShapeInfo[0] = rank; - outputShapeInfo[1] = bS; - - if (isNCHW) { - outputShapeInfo[2] = oC; - outputShapeInfo[3] = oH; - outputShapeInfo[4] = oW; - } else { - outputShapeInfo[2] = oH; - outputShapeInfo[3] = oW; - outputShapeInfo[4] = oC; - } - - ShapeUtils::updateStridesAndType(outputShapeInfo, weightsShapeInfo, shape::order(inputShapeInfo)); - - return SHAPELIST(CONSTANT(outputShapeInfo)); + auto inputShapeInfo = + inputShape->at(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + auto weightsShapeInfo = inputShape->at( + 1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] + auto biasShapeInfo = block.width() > 2 ? inputShape->at(2) : nullptr; // [oC] + + // output [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) + + int sH = INT_ARG(2); // strides height + int sW = INT_ARG(3); // strides width + int pH = INT_ARG(4); // paddings height + int pW = INT_ARG(5); // paddings width + int dH = INT_ARG(6); // dilations height + int dW = INT_ARG(7); // dilations width + int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME + int isNCHW = + block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC + int wFormat = + block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, + // iC, kH, kW], 2 - [oC, kH, kW, iC] + + int kH = INT_ARG(0) > 0 ? INT_ARG(0) + : static_cast(shape::sizeAt( + weightsShapeInfo, 0)); // filter(kernel) height + int kW = INT_ARG(1) > 0 ? INT_ARG(1) + : static_cast(shape::sizeAt( + weightsShapeInfo, 1)); // filter(kernel) width + + const int rank = 4; // 4 + + REQUIRE_TRUE(inputShapeInfo[0] == rank, 0, + "CUSTOM CONV2D OP: rank of input array must be equal to %i, but " + "got %i instead !", + rank, inputShapeInfo[0]); + REQUIRE_TRUE(weightsShapeInfo[0] == rank, 0, + "CUSTOM CONV2D OP: rank of weights array must be equal to %i, " + "but got %i instead !", + rank, weightsShapeInfo[0]); + + int indIOioC, indIiH, indWoC(0 == wFormat ? 3 : 0); + if (!isNCHW) { + indIOioC = 3; + indIiH = 1; + } else { + indIOioC = 1; + indIiH = 2; + } + + const int bS = inputShapeInfo[1]; // batch size + const int iH = inputShapeInfo[indIiH + 1]; // input height + const int iW = inputShapeInfo[indIiH + 2]; // input width + const int iC = inputShapeInfo[indIOioC + 1]; // input channels + const int oC = weightsShapeInfo[indWoC + 1]; // output channels + + std::vector expectedWeightsShape = + ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, oC); + REQUIRE_TRUE( + ShapeUtils::areShapesEqual(weightsShapeInfo, expectedWeightsShape), 0, + "CUSTOM CONV2D OP: wrong shape of weights array, expected is %s, but got " + "%s instead !", + ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), + ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); + if (biasShapeInfo) + REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, + "CUSTOM CONV2D OP: wrong shape of array with biases, expected " + "rank, length: <=2, %i, but got %i, %i instead !", + oC, biasShapeInfo[0], shape::length(biasShapeInfo)); + + Nd4jLong* outputShapeInfo = nullptr; + ALLOCATE(outputShapeInfo, block.workspace(), shape::shapeInfoLength(rank), + Nd4jLong); + + int oH, oW; // output height, width + ConvolutionUtils::calcOutSizePool2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, + iH, iW, isSameMode); + + outputShapeInfo[0] = rank; + outputShapeInfo[1] = bS; + + if (isNCHW) { + outputShapeInfo[2] = oC; + outputShapeInfo[3] = oH; + outputShapeInfo[4] = oW; + } else { + outputShapeInfo[2] = oH; + outputShapeInfo[3] = oW; + outputShapeInfo[4] = oC; + } + + ShapeUtils::updateStridesAndType(outputShapeInfo, weightsShapeInfo, + shape::order(inputShapeInfo)); + + return SHAPELIST(CONSTANT(outputShapeInfo)); } - DECLARE_TYPES(conv2d) { - getOpDescriptor() - ->setAllowedInputTypes(0, sd::DataType::ANY) - ->setAllowedInputTypes(1, {ALL_FLOATS}) - ->setAllowedInputTypes(2, {ALL_FLOATS}) - ->setAllowedOutputTypes({ALL_FLOATS}); - } - - DECLARE_TYPES(conv2d_bp) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } +DECLARE_TYPES(conv2d) { + getOpDescriptor() + ->setAllowedInputTypes(0, sd::DataType::ANY) + ->setAllowedInputTypes(1, {ALL_FLOATS}) + ->setAllowedInputTypes(2, {ALL_FLOATS}) + ->setAllowedOutputTypes({ALL_FLOATS}); +} +DECLARE_TYPES(conv2d_bp) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(conv2d_bp, 3, 2, false, 0, 9) { - - auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] - auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] - auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next - - auto gradI = OUTPUT_NULLIFIED(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon - auto gradW = OUTPUT_NULLIFIED(1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] - auto gradB = block.width() > 3 ? OUTPUT_NULLIFIED(2) : nullptr; // [oC] - - int kH = INT_ARG(0); // filter(kernel) height - int kW = INT_ARG(1); // filter(kernel) width - int sH = INT_ARG(2); // strides height - int sW = INT_ARG(3); // strides width - int pH = INT_ARG(4); // paddings height - int pW = INT_ARG(5); // paddings width - int dH = INT_ARG(6); // dilations height - int dW = INT_ARG(7); // dilations width - int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME - int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC - int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - [oC, kH, kW, iC] - - REQUIRE_TRUE(input->rankOf() == 4, 0, "CUSTOM CONV2D_BP OP: rank of input array must be equal to 4, but got %i instead !", input->rankOf()); - REQUIRE_TRUE(weights->rankOf() == 4, 0, "CUSTOM CONV2D_BP OP: rank of weights array must be equal to 4, but got %i instead !", weights->rankOf()); - REQUIRE_TRUE(gradO->rankOf() == 4, 0, "CUSTOM CONV2D_BP OP: rank of output's gradients (next epsilon) array must be equal to 4, but got %i instead !", gradO->rankOf()); - - int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; - int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); - - int trueoH, trueoW; // true output height, width - ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode); - - std::vectorexpectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1}); - std::vectorexpectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, oC); - REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM CONV2D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); - REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV2D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); - if(bias) - REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV2D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); - - ConvolutionUtils::conv2dBP(block, input, weights, bias, gradO, gradI, gradW, gradB, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW,wFormat); - - return Status::OK(); + auto input = + INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + auto weights = INPUT_VARIABLE( + 1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] + auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] + auto gradO = block.width() > 3 + ? INPUT_VARIABLE(3) + : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, + // oH, oW] (NCHW), epsilon_next + + auto gradI = OUTPUT_NULLIFIED( + 0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon + auto gradW = OUTPUT_NULLIFIED( + 1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] + auto gradB = block.width() > 3 ? OUTPUT_NULLIFIED(2) : nullptr; // [oC] + + int kH = INT_ARG(0); // filter(kernel) height + int kW = INT_ARG(1); // filter(kernel) width + int sH = INT_ARG(2); // strides height + int sW = INT_ARG(3); // strides width + int pH = INT_ARG(4); // paddings height + int pW = INT_ARG(5); // paddings width + int dH = INT_ARG(6); // dilations height + int dW = INT_ARG(7); // dilations width + int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME + int isNCHW = + block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC + int wFormat = + block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, + // iC, kH, kW], 2 - [oC, kH, kW, iC] + + REQUIRE_TRUE(input->rankOf() == 4, 0, + "CUSTOM CONV2D_BP OP: rank of input array must be equal to 4, " + "but got %i instead !", + input->rankOf()); + REQUIRE_TRUE(weights->rankOf() == 4, 0, + "CUSTOM CONV2D_BP OP: rank of weights array must be equal to 4, " + "but got %i instead !", + weights->rankOf()); + REQUIRE_TRUE(gradO->rankOf() == 4, 0, + "CUSTOM CONV2D_BP OP: rank of output's gradients (next epsilon) " + "array must be equal to 4, but got %i instead !", + gradO->rankOf()); + + int bS, iC, iH, iW, oC, oH, + oW; // batch size, input channels, input height/width, output channels, + // output height/width; + int indIOioC, indIiH, indWoC, indWiC, indWkH, + indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, + indIiH, indWiC, indWoC, indWkH, indOoH); + + int trueoH, trueoW; // true output height, width + ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, + dH, dW, iH, iW, isSameMode); + + std::vector expectedGradOShape = + ShapeUtils::composeShapeUsingDimsAndIdx( + {bS, oC, trueoH, trueoW, 0, indIOioC, indOoH, indOoH + 1}); + std::vector expectedWeightsShape = + ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, oC); + REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, + "CUSTOM CONV2D_BP OP: wrong shape of output gradients (next " + "epsilon) array, expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedGradOShape).c_str(), + ShapeUtils::shapeAsString(gradO).c_str()); + REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, + "CUSTOM CONV2D_BP OP: wrong shape of weights array, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), + ShapeUtils::shapeAsString(weights).c_str()); + if (bias) + REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, + "CUSTOM CONV2D_BP OP: wrong shape of array with biases, " + "expected rank, length: <=2, %i, but got %i, %i instead !", + oC, bias->rankOf(), bias->lengthOf()); + + ConvolutionUtils::conv2dBP(block, input, weights, bias, gradO, gradI, gradW, + gradB, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, + isNCHW, wFormat); + + return Status::OK(); } - - DECLARE_SHAPE_FN(conv2d_bp) { - - auto inputShapeInfo = inputShape->at(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - auto weightsShapeInfo = inputShape->at(1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] - auto biasShapeInfo = block.width() > 3 ? inputShape->at(2) : nullptr; // [oC] - auto gradOShapeInfo = block.width() > 3 ? inputShape->at(3) : inputShape->at(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next - - const int rank = 4; - - REQUIRE_TRUE(inputShapeInfo[0] == rank, 0, "CUSTOM CONV2D_BP OP: rank of input array must be equal to %i, but got %i instead !", rank, inputShapeInfo[0]); - REQUIRE_TRUE(weightsShapeInfo[0] == rank, 0, "CUSTOM CONV2D_BP OP: rank of weights array must be equal to %i, but got %i instead !", rank, weightsShapeInfo[0]); - REQUIRE_TRUE(gradOShapeInfo[0] == rank, 0, "CUSTOM CONV2D_BP OP: rank of output gradients (next epsilon) array must be equal to %i, but got %i instead !", rank, gradOShapeInfo[0]); - - const int kH = INT_ARG(0); // filter(kernel) height - const int kW = INT_ARG(1); // filter(kernel) width - const int sH = INT_ARG(2); // strides height - const int sW = INT_ARG(3); // strides width - const int pH = INT_ARG(4); // paddings height - const int pW = INT_ARG(5); // paddings width - const int dH = INT_ARG(6); // dilations height - const int dW = INT_ARG(7); // dilations width - const int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME - const int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC - const int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - [oC, kH, kW, iC] - - int indIOioC, indIiH, indOoH, indWoC(0 == wFormat ? 3 : 0); - if(!isNCHW) { - indIOioC = 3; indIiH = 1; indOoH = 1; - } - else { - indIOioC = 1; indIiH = 2; indOoH = 2; - } - - const int bS = inputShapeInfo[1]; // batch size - const int iH = inputShapeInfo[indIiH+1]; // input height - const int iW = inputShapeInfo[indIiH+2]; // input width - const int iC = inputShapeInfo[indIOioC+1]; // input channels - const int oC = weightsShapeInfo[indWoC+1]; // output channels - - int trueoH, trueoW; // true output height, width - ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode); - - std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1}); - std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, oC); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(gradOShapeInfo, expectedGradOShape), 0, "CUSTOM CONV2D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradOShapeInfo).c_str()); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsShapeInfo, expectedWeightsShape), 0, "CUSTOM CONV2D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); - if(biasShapeInfo) - REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM CONV2D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo)); - - auto gradIshapeInfo = ShapeBuilders::copyShapeInfoAndType(inputShapeInfo, gradOShapeInfo, false, block.getWorkspace()); - auto gradWshapeInfo = ShapeBuilders::copyShapeInfoAndType(weightsShapeInfo, gradOShapeInfo, false, block.getWorkspace()); - - if(biasShapeInfo) { - auto gradBshapeInfo = ShapeBuilders::copyShapeInfoAndType(biasShapeInfo, gradOShapeInfo, false, block.getWorkspace()); - return SHAPELIST(CONSTANT(gradIshapeInfo), CONSTANT(gradWshapeInfo), CONSTANT(gradBshapeInfo)); - } - - return SHAPELIST(CONSTANT(gradIshapeInfo), CONSTANT(gradWshapeInfo)); + auto inputShapeInfo = + inputShape->at(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + auto weightsShapeInfo = inputShape->at( + 1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] + auto biasShapeInfo = block.width() > 3 ? inputShape->at(2) : nullptr; // [oC] + auto gradOShapeInfo = + block.width() > 3 + ? inputShape->at(3) + : inputShape->at(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] + // (NCHW), epsilon_next + + const int rank = 4; + + REQUIRE_TRUE(inputShapeInfo[0] == rank, 0, + "CUSTOM CONV2D_BP OP: rank of input array must be equal to %i, " + "but got %i instead !", + rank, inputShapeInfo[0]); + REQUIRE_TRUE(weightsShapeInfo[0] == rank, 0, + "CUSTOM CONV2D_BP OP: rank of weights array must be equal to " + "%i, but got %i instead !", + rank, weightsShapeInfo[0]); + REQUIRE_TRUE(gradOShapeInfo[0] == rank, 0, + "CUSTOM CONV2D_BP OP: rank of output gradients (next epsilon) " + "array must be equal to %i, but got %i instead !", + rank, gradOShapeInfo[0]); + + const int kH = INT_ARG(0); // filter(kernel) height + const int kW = INT_ARG(1); // filter(kernel) width + const int sH = INT_ARG(2); // strides height + const int sW = INT_ARG(3); // strides width + const int pH = INT_ARG(4); // paddings height + const int pW = INT_ARG(5); // paddings width + const int dH = INT_ARG(6); // dilations height + const int dW = INT_ARG(7); // dilations width + const int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME + const int isNCHW = + block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC + const int wFormat = + block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, + // iC, kH, kW], 2 - [oC, kH, kW, iC] + + int indIOioC, indIiH, indOoH, indWoC(0 == wFormat ? 3 : 0); + if (!isNCHW) { + indIOioC = 3; + indIiH = 1; + indOoH = 1; + } else { + indIOioC = 1; + indIiH = 2; + indOoH = 2; + } + + const int bS = inputShapeInfo[1]; // batch size + const int iH = inputShapeInfo[indIiH + 1]; // input height + const int iW = inputShapeInfo[indIiH + 2]; // input width + const int iC = inputShapeInfo[indIOioC + 1]; // input channels + const int oC = weightsShapeInfo[indWoC + 1]; // output channels + + int trueoH, trueoW; // true output height, width + ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, + dH, dW, iH, iW, isSameMode); + + std::vector expectedGradOShape = + ShapeUtils::composeShapeUsingDimsAndIdx( + {bS, oC, trueoH, trueoW, 0, indIOioC, indOoH, indOoH + 1}); + std::vector expectedWeightsShape = + ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, oC); + REQUIRE_TRUE(ShapeUtils::areShapesEqual(gradOShapeInfo, expectedGradOShape), + 0, + "CUSTOM CONV2D_BP OP: wrong shape of output gradients (next " + "epsilon) array, expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedGradOShape).c_str(), + ShapeUtils::shapeAsString(gradOShapeInfo).c_str()); + REQUIRE_TRUE( + ShapeUtils::areShapesEqual(weightsShapeInfo, expectedWeightsShape), 0, + "CUSTOM CONV2D_BP OP: wrong shape of weights array, expected is %s, but " + "got %s instead !", + ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), + ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); + if (biasShapeInfo) + REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, + "CUSTOM CONV2D_BP OP: wrong shape of array with biases, " + "expected rank, length: <=2, %i, but got %i, %i instead !", + oC, biasShapeInfo[0], shape::length(biasShapeInfo)); + + auto gradIshapeInfo = ShapeBuilders::copyShapeInfoAndType( + inputShapeInfo, gradOShapeInfo, false, block.workspace()); + auto gradWshapeInfo = ShapeBuilders::copyShapeInfoAndType( + weightsShapeInfo, gradOShapeInfo, false, block.workspace()); + + if (biasShapeInfo) { + auto gradBshapeInfo = ShapeBuilders::copyShapeInfoAndType( + biasShapeInfo, gradOShapeInfo, false, block.workspace()); + return SHAPELIST(CONSTANT(gradIshapeInfo), CONSTANT(gradWshapeInfo), + CONSTANT(gradBshapeInfo)); + } + + return SHAPELIST(CONSTANT(gradIshapeInfo), CONSTANT(gradWshapeInfo)); } ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(conv2d_input_bp, 3, 1, false, 0, 9) { - - auto gradIShape = INPUT_VARIABLE(0); // [4] - auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] - auto gradO = INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next - - auto gradI = OUTPUT_NULLIFIED(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon - - int kH = INT_ARG(0); // filter(kernel) height - int kW = INT_ARG(1); // filter(kernel) width - int sH = INT_ARG(2); // strides height - int sW = INT_ARG(3); // strides width - int pH = INT_ARG(4); // paddings height - int pW = INT_ARG(5); // paddings width - int dH = INT_ARG(6); // dilations height - int dW = INT_ARG(7); // dilations width - int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME - int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC - int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - [oC, kH, kW, iC] - - const int rank = gradO->rankOf(); - - REQUIRE_TRUE(weights->rankOf() == rank, 0, "CUSTOM CONV2D_INPUT_BP OP: rank of weights array must be equal to 4, but got %i instead !", weights->rankOf()); - REQUIRE_TRUE(gradIShape->rankOf() == 1, 0, "CUSTOM CONV2D_INPUT_BP OP: rank of array with output shape must be equal to 1, but got %i instead !", gradIShape->rankOf()); - REQUIRE_TRUE(gradIShape->lengthOf() == rank, 0, "CUSTOM CONV2D_INPUT_BP OP: length of array with output shape must be equal to 4, but got %i instead !", gradIShape->lengthOf()); - - // create empty conv2d input array - std::vector gradIShapeAsVector(rank); - for(int i = 0; i < rank; ++i) - gradIShapeAsVector[i] = gradIShape->e(i); - NDArray input(gradO->ordering(), gradIShapeAsVector, gradO->dataType(), block.launchContext()); - - - int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; - int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); - - int trueoH, trueoW; // true output height, width - ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode); - - std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1}); - std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, oC); - REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM CONV2D_INPUT_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); - REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV2D_INPUT_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); - - ConvolutionUtils::conv2dBP(block, &input, weights, nullptr, gradO, gradI, nullptr, nullptr, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW,wFormat); - - return Status::OK(); + auto gradIShape = INPUT_VARIABLE(0); // [4] + auto weights = INPUT_VARIABLE( + 1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] + auto gradO = INPUT_VARIABLE( + 2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next + + auto gradI = OUTPUT_NULLIFIED( + 0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon + + int kH = INT_ARG(0); // filter(kernel) height + int kW = INT_ARG(1); // filter(kernel) width + int sH = INT_ARG(2); // strides height + int sW = INT_ARG(3); // strides width + int pH = INT_ARG(4); // paddings height + int pW = INT_ARG(5); // paddings width + int dH = INT_ARG(6); // dilations height + int dW = INT_ARG(7); // dilations width + int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME + int isNCHW = + block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC + int wFormat = + block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, + // iC, kH, kW], 2 - [oC, kH, kW, iC] + + const int rank = gradO->rankOf(); + + REQUIRE_TRUE(weights->rankOf() == rank, 0, + "CUSTOM CONV2D_INPUT_BP OP: rank of weights array must be equal " + "to 4, but got %i instead !", + weights->rankOf()); + REQUIRE_TRUE(gradIShape->rankOf() == 1, 0, + "CUSTOM CONV2D_INPUT_BP OP: rank of array with output shape " + "must be equal to 1, but got %i instead !", + gradIShape->rankOf()); + REQUIRE_TRUE(gradIShape->lengthOf() == rank, 0, + "CUSTOM CONV2D_INPUT_BP OP: length of array with output shape " + "must be equal to 4, but got %i instead !", + gradIShape->lengthOf()); + + // create empty conv2d input array + std::vector gradIShapeAsVector(rank); + for (int i = 0; i < rank; ++i) + gradIShapeAsVector[i] = gradIShape->e(i); + NDArray input(gradO->ordering(), gradIShapeAsVector, gradO->dataType(), + block.launchContext()); + + int bS, iC, iH, iW, oC, oH, + oW; // batch size, input channels, input height/width, output channels, + // output height/width; + int indIOioC, indIiH, indWoC, indWiC, indWkH, + indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, wFormat, input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, + indIiH, indWiC, indWoC, indWkH, indOoH); + + int trueoH, trueoW; // true output height, width + ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, + dH, dW, iH, iW, isSameMode); + + std::vector expectedGradOShape = + ShapeUtils::composeShapeUsingDimsAndIdx( + {bS, oC, trueoH, trueoW, 0, indIOioC, indOoH, indOoH + 1}); + std::vector expectedWeightsShape = + ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, oC); + REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, + "CUSTOM CONV2D_INPUT_BP OP: wrong shape of output gradients " + "(next epsilon) array, expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedGradOShape).c_str(), + ShapeUtils::shapeAsString(gradO).c_str()); + REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, + "CUSTOM CONV2D_INPUT_BP OP: wrong shape of weights array, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), + ShapeUtils::shapeAsString(weights).c_str()); + + ConvolutionUtils::conv2dBP(block, &input, weights, nullptr, gradO, gradI, + nullptr, nullptr, kH, kW, sH, sW, pH, pW, dH, dW, + isSameMode, isNCHW, wFormat); + + return Status::OK(); } - DECLARE_TYPES(conv2d_input_bp) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } - - +DECLARE_TYPES(conv2d_input_bp) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} DECLARE_SHAPE_FN(conv2d_input_bp) { - - auto gradIShapeShapeInfo = inputShape->at(0); // [4] - auto weightsShapeInfo = inputShape->at(1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] - auto gradOShapeInfo = inputShape->at(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next - - const int rank = 4; - - REQUIRE_TRUE(gradIShapeShapeInfo[0] == 1, 0, "CUSTOM CONV2D_INPUT_BP OP: rank of array with output shape must be equal to %i, but got %i instead !", 1, gradIShapeShapeInfo[0]); - REQUIRE_TRUE(weightsShapeInfo[0] == rank, 0, "CUSTOM CONV2D_INPUT_BP OP: rank of weights array must be equal to %i, but got %i instead !", rank, weightsShapeInfo[0]); - REQUIRE_TRUE(gradOShapeInfo[0] == rank, 0, "CUSTOM CONV2D_INPUT_BP OP: rank of output gradients (next epsilon) array must be equal to %i, but got %i instead !", rank, gradOShapeInfo[0]); - - const int kH = INT_ARG(0); // filter(kernel) height - const int kW = INT_ARG(1); // filter(kernel) width - const int sH = INT_ARG(2); // strides height - const int sW = INT_ARG(3); // strides width - const int pH = INT_ARG(4); // paddings height - const int pW = INT_ARG(5); // paddings width - const int dH = INT_ARG(6); // dilations height - const int dW = INT_ARG(7); // dilations width - const int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME - const int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC - const int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - [oC, kH, kW, iC] - - int indIOioC, indIiH, indWoC(0 == wFormat ? 3 : 0), indOoH; - if(!isNCHW) { - indIOioC = 3; indIiH = 1; indOoH = 1; - } - else { - indIOioC = 1; indIiH = 2; indOoH = 2; - } - - std::vector gradIShape = INPUT_VARIABLE(0)->template asVectorT(); - - const int bS = gradIShape[0]; // batch size - const int iH = gradIShape[indIiH]; // input height - const int iW = gradIShape[indIiH+1]; // input width - const int iC = gradIShape[indIOioC]; // input channels - const int oC = weightsShapeInfo[indWoC+1]; // output channels - - int trueoH, trueoW; // true output height, width - ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode); - - std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1}); - std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, oC); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(gradOShapeInfo, expectedGradOShape), 0, "CUSTOM CONV2D_INPUT_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradOShapeInfo).c_str()); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsShapeInfo, expectedWeightsShape), 0, "CUSTOM CONV2D_INPUT_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); - - Nd4jLong* gradIshapeInfo(nullptr); - ALLOCATE(gradIshapeInfo, block.getWorkspace(), shape::shapeInfoLength(rank), Nd4jLong); - - gradIshapeInfo[0] = rank; - gradIshapeInfo[1] = bS; - - if (isNCHW) { - gradIshapeInfo[2] = iC; - gradIshapeInfo[3] = iH; - gradIshapeInfo[4] = iW; - } else { - gradIshapeInfo[2] = iH; - gradIshapeInfo[3] = iW; - gradIshapeInfo[4] = iC; - } - - ShapeUtils::updateStridesAndType(gradIshapeInfo, gradOShapeInfo, shape::order(gradOShapeInfo)); - - return SHAPELIST(CONSTANT(gradIshapeInfo)); + auto gradIShapeShapeInfo = inputShape->at(0); // [4] + auto weightsShapeInfo = inputShape->at( + 1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] + auto gradOShapeInfo = inputShape->at( + 2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next + + const int rank = 4; + + REQUIRE_TRUE(gradIShapeShapeInfo[0] == 1, 0, + "CUSTOM CONV2D_INPUT_BP OP: rank of array with output shape " + "must be equal to %i, but got %i instead !", + 1, gradIShapeShapeInfo[0]); + REQUIRE_TRUE(weightsShapeInfo[0] == rank, 0, + "CUSTOM CONV2D_INPUT_BP OP: rank of weights array must be equal " + "to %i, but got %i instead !", + rank, weightsShapeInfo[0]); + REQUIRE_TRUE(gradOShapeInfo[0] == rank, 0, + "CUSTOM CONV2D_INPUT_BP OP: rank of output gradients (next " + "epsilon) array must be equal to %i, but got %i instead !", + rank, gradOShapeInfo[0]); + + const int kH = INT_ARG(0); // filter(kernel) height + const int kW = INT_ARG(1); // filter(kernel) width + const int sH = INT_ARG(2); // strides height + const int sW = INT_ARG(3); // strides width + const int pH = INT_ARG(4); // paddings height + const int pW = INT_ARG(5); // paddings width + const int dH = INT_ARG(6); // dilations height + const int dW = INT_ARG(7); // dilations width + const int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME + const int isNCHW = + block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC + const int wFormat = + block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, + // iC, kH, kW], 2 - [oC, kH, kW, iC] + + int indIOioC, indIiH, indWoC(0 == wFormat ? 3 : 0), indOoH; + if (!isNCHW) { + indIOioC = 3; + indIiH = 1; + indOoH = 1; + } else { + indIOioC = 1; + indIiH = 2; + indOoH = 2; + } + + std::vector gradIShape = + INPUT_VARIABLE(0)->template asVectorT(); + + const int bS = gradIShape[0]; // batch size + const int iH = gradIShape[indIiH]; // input height + const int iW = gradIShape[indIiH + 1]; // input width + const int iC = gradIShape[indIOioC]; // input channels + const int oC = weightsShapeInfo[indWoC + 1]; // output channels + + int trueoH, trueoW; // true output height, width + ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, + dH, dW, iH, iW, isSameMode); + + std::vector expectedGradOShape = + ShapeUtils::composeShapeUsingDimsAndIdx( + {bS, oC, trueoH, trueoW, 0, indIOioC, indOoH, indOoH + 1}); + std::vector expectedWeightsShape = + ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, oC); + REQUIRE_TRUE(ShapeUtils::areShapesEqual(gradOShapeInfo, expectedGradOShape), + 0, + "CUSTOM CONV2D_INPUT_BP OP: wrong shape of output gradients " + "(next epsilon) array, expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedGradOShape).c_str(), + ShapeUtils::shapeAsString(gradOShapeInfo).c_str()); + REQUIRE_TRUE( + ShapeUtils::areShapesEqual(weightsShapeInfo, expectedWeightsShape), 0, + "CUSTOM CONV2D_INPUT_BP OP: wrong shape of weights array, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), + ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); + + Nd4jLong* gradIshapeInfo(nullptr); + ALLOCATE(gradIshapeInfo, block.workspace(), shape::shapeInfoLength(rank), + Nd4jLong); + + gradIshapeInfo[0] = rank; + gradIshapeInfo[1] = bS; + + if (isNCHW) { + gradIshapeInfo[2] = iC; + gradIshapeInfo[3] = iH; + gradIshapeInfo[4] = iW; + } else { + gradIshapeInfo[2] = iH; + gradIshapeInfo[3] = iW; + gradIshapeInfo[4] = iC; + } + + ShapeUtils::updateStridesAndType(gradIshapeInfo, gradOShapeInfo, + shape::order(gradOShapeInfo)); + + return SHAPELIST(CONSTANT(gradIshapeInfo)); } - -} -} +} // namespace ops +} // namespace sd #endif -#endif //LIBND4J_CONVO_OPS_H +#endif // LIBND4J_CONVO_OPS_H diff --git a/libnd4j/include/ops/declarable/generic/nn/convo/conv3d.cpp b/libnd4j/include/ops/declarable/generic/nn/convo/conv3d.cpp index 889a01b9ab76..baef60fc7ab9 100644 --- a/libnd4j/include/ops/declarable/generic/nn/convo/conv3d.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/convo/conv3d.cpp @@ -14,7 +14,6 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - // // @author Yurii Shyrma, created on 05.02.2018 // @@ -22,347 +21,526 @@ #include #if NOT_EXCLUDED(OP_conv3dnew) +#include #include -#include #include -#include +#include namespace sd { -namespace ops { +namespace ops { CUSTOM_OP_IMPL(conv3dnew, 2, 1, false, 0, 13) { - auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) - auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC] - auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] - auto output = OUTPUT_VARIABLE(0); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW) - - REQUIRE_TRUE(input->rankOf() == 5, 0, "CUSTOM CONV3D OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf()); - REQUIRE_TRUE(weights->rankOf() == 5, 0, "CUSTOM CONV3D OP: rank of weights array must be equal to 5, but got %i instead !", weights->rankOf()); - - int kD = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(weights->sizeAt(0));// filter(kernel) depth - int kH = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast(weights->sizeAt(1));// filter(kernel) height - int kW = INT_ARG(2) > 0 ? INT_ARG(2) : static_cast(weights->sizeAt(2));// filter(kernel) width - int sD = INT_ARG(3); // strides depth - int sH = INT_ARG(4); // strides height - int sW = INT_ARG(5); // strides width - int pD = INT_ARG(6); // paddings depth - int pH = INT_ARG(7); // paddings height - int pW = INT_ARG(8); // paddings width - int dD = INT_ARG(9); // dilations depth - int dH = INT_ARG(10); // dilations height - int dW = INT_ARG(11); // dilations width - int paddingMode = INT_ARG(12); // 0-SAME, 1-VALID - int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW - int wFormat = block.getIArguments()->size() > 14 ? INT_ARG(14) : 0; // 0-[kD, kH, kW, iC, oC], 1-[oC, iC, kD, kH, kW], 2-[oC, kD, kH, kW, iC] - - int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; - int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, wFormat, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); - - REQUIRE_TRUE(paddingMode < 2, 0, "CUSTOM CONV3D OP: causal padding mode (paddingMode = 2) is not allowed for this operation !"); - std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, iC, oC); - REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV3D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); - if (bias) - REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV3D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); - - ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW, paddingMode); - - nd4j_debug("MKL-DNN is not used for conv3dnew!\n", 0); - - std::vector permutForOutput; - - if (isNCDHW) - permutForOutput = {0,2,3,4,1}; // [bS, oC, oD, oH, oW] -> [bS, oD, oH, oW, oC] - else - input = new NDArray(input->permute({0,4,1,2,3})); - - std::vector wAxes; - if(0 == wFormat) - wAxes = {3,0,1,2}; - else if(1 == wFormat) - wAxes = {1,2,3,4}; - else - wAxes = {4,1,2,3}; - - NDArray columns(input->ordering(), {bS, iC, kD, kH, kW, oD, oH, oW}, input->dataType(), block.launchContext()); - ConvolutionUtils::vol2col(block, *input, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW); // [bS, iC, iD, iH, iW] is convoluted to [bS, iC, kD, kH, kW, oD, oH, oW] - // [bS, iC, kD, kH, kW, oD, oH, oW] x [kD, kH, kW, iC, oC] = [bS, oD, oH, oW, oC] - // [bS, iC, kD, kH, kW, oD, oH, oW] x [oC, iC, kD, kH, kW] = [bS, oD, oH, oW, oC] - // [bS, iC, kD, kH, kW, oD, oH, oW] x [oC, kD, kH, kW, iC] = [bS, oD, oH, oW, oC] - MmulHelper::tensorDot(&columns, weights, output, {1,2,3,4}, wAxes, permutForOutput); - - if(bias) - // output->applyBroadcast(broadcast::Add, {indIOioC}, bias); - helpers::addBias(block, *output, *bias, *output, isNCDHW); - - if(!isNCDHW) - delete input; - - return Status::OK(); + auto input = INPUT_VARIABLE( + 0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) + auto weights = INPUT_VARIABLE( + 1); // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC] + auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] + auto output = OUTPUT_VARIABLE( + 0); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW) + + REQUIRE_TRUE(input->rankOf() == 5, 0, + "CUSTOM CONV3D OP: rank of input array must be equal to 5, but " + "got %i instead !", + input->rankOf()); + REQUIRE_TRUE(weights->rankOf() == 5, 0, + "CUSTOM CONV3D OP: rank of weights array must be equal to 5, " + "but got %i instead !", + weights->rankOf()); + + int kD = INT_ARG(0) > 0 + ? INT_ARG(0) + : static_cast(weights->sizeAt(0)); // filter(kernel) depth + int kH = INT_ARG(1) > 0 + ? INT_ARG(1) + : static_cast(weights->sizeAt(1)); // filter(kernel) height + int kW = INT_ARG(2) > 0 + ? INT_ARG(2) + : static_cast(weights->sizeAt(2)); // filter(kernel) width + int sD = INT_ARG(3); // strides depth + int sH = INT_ARG(4); // strides height + int sW = INT_ARG(5); // strides width + int pD = INT_ARG(6); // paddings depth + int pH = INT_ARG(7); // paddings height + int pW = INT_ARG(8); // paddings width + int dD = INT_ARG(9); // dilations depth + int dH = INT_ARG(10); // dilations height + int dW = INT_ARG(11); // dilations width + int paddingMode = INT_ARG(12); // 0-SAME, 1-VALID + int isNCDHW = + block.numI() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW + int wFormat = block.numI() > 14 ? INT_ARG(14) + : 0; // 0-[kD, kH, kW, iC, oC], 1-[oC, iC, + // kD, kH, kW], 2-[oC, kD, kH, kW, iC] + + int bS, iC, iD, iH, iW, oC, oD, oH, + oW; // batch size, input channels, input depth/height/width, output + // channels, output depth/height/width; + int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv3d( + isNCDHW, wFormat, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, + indIOioC, indIOioD, indWiC, indWoC, indWkD); + + REQUIRE_TRUE(paddingMode < 2, 0, + "CUSTOM CONV3D OP: causal padding mode (paddingMode = 2) is not " + "allowed for this operation !"); + std::vector expectedWeightsShape = + ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, iC, oC); + REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, + "CUSTOM CONV3D OP: wrong shape of weights array, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), + ShapeUtils::shapeAsString(weights).c_str()); + if (bias) + REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, + "CUSTOM CONV3D OP: wrong shape of array with biases, expected " + "rank, length: <=2, %i, but got %i, %i instead !", + oC, bias->rankOf(), bias->lengthOf()); + + ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, + kW, sD, sH, sW, dD, dH, dW, paddingMode); + + nd4j_debug("MKL-DNN is not used for conv3dnew!\n", 0); + + std::vector permutForOutput; + + if (isNCDHW) + permutForOutput = {0, 2, 3, 4, + 1}; // [bS, oC, oD, oH, oW] -> [bS, oD, oH, oW, oC] + else + input = new NDArray(input->permute({0, 4, 1, 2, 3})); + + std::vector wAxes; + if (0 == wFormat) + wAxes = {3, 0, 1, 2}; + else if (1 == wFormat) + wAxes = {1, 2, 3, 4}; + else + wAxes = {4, 1, 2, 3}; + + NDArray columns(input->ordering(), {bS, iC, kD, kH, kW, oD, oH, oW}, + input->dataType(), block.launchContext()); + ConvolutionUtils::vol2col(block, *input, columns, sD, sH, sW, pD, pH, pW, dD, + dH, dW); // [bS, iC, iD, iH, iW] is convoluted to + // [bS, iC, kD, kH, kW, oD, oH, oW] + // [bS, iC, kD, kH, kW, oD, oH, oW] x [kD, kH, kW, iC, oC] = [bS, oD, oH, oW, + // oC] [bS, iC, kD, kH, kW, oD, oH, oW] x [oC, iC, kD, kH, kW] = [bS, oD, oH, + // oW, oC] [bS, iC, kD, kH, kW, oD, oH, oW] x [oC, kD, kH, kW, iC] = [bS, oD, + // oH, oW, oC] + MmulHelper::tensorDot(&columns, weights, output, {1, 2, 3, 4}, wAxes, + permutForOutput); + + if (bias) + // output->applyBroadcast(broadcast::Add, {indIOioC}, bias); + helpers::addBias(block, *output, *bias, *output, isNCDHW); + + if (!isNCDHW) delete input; + + return Status::OK(); } - DECLARE_TYPES(conv3dnew) { - getOpDescriptor() - ->setAllowedInputTypes(0, sd::DataType::ANY) - ->setAllowedInputTypes(1, {ALL_FLOATS}) - ->setAllowedInputTypes(2, {ALL_FLOATS}) - ->setAllowedOutputTypes({ALL_FLOATS}); - } +DECLARE_TYPES(conv3dnew) { + getOpDescriptor() + ->setAllowedInputTypes(0, sd::DataType::ANY) + ->setAllowedInputTypes(1, {ALL_FLOATS}) + ->setAllowedInputTypes(2, {ALL_FLOATS}) + ->setAllowedOutputTypes({ALL_FLOATS}); +} DECLARE_SHAPE_FN(conv3dnew) { - - auto inputShapeInfo = inputShape->at(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) - auto weightsShapeInfo = inputShape->at(1); // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC] - auto biasShapeInfo = block.width() > 2 ? inputShape->at(2) : nullptr; // [oC] - - int kD = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(shape::sizeAt(weightsShapeInfo, 0));// filter(kernel) depth - int kH = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast(shape::sizeAt(weightsShapeInfo, 1));// filter(kernel) height - int kW = INT_ARG(2) > 0 ? INT_ARG(2) : static_cast(shape::sizeAt(weightsShapeInfo, 2));// filter(kernel) width - int sD = INT_ARG(3); // strides depth - int sH = INT_ARG(4); // strides height - int sW = INT_ARG(5); // strides width - int pD = INT_ARG(6); // paddings depth - int pH = INT_ARG(7); // paddings height - int pW = INT_ARG(8); // paddings width - int dD = INT_ARG(9); // dilations depth - int dH = INT_ARG(10); // dilations height - int dW = INT_ARG(11); // dilations width - int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID; - int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW - int wFormat = block.getIArguments()->size() > 14 ? INT_ARG(14) : 0; // 0-[kD, kH, kW, iC, oC], 1-[oC, iC, kD, kH, kW], 2-[oC, kD, kH, kW, iC] - - const int rank = 5; - REQUIRE_TRUE(paddingMode < 2, 0, "CUSTOM CONV3D OP: causal padding mode (paddingMode = 2) is not allowed for this operation !"); - REQUIRE_TRUE(inputShapeInfo[0] == rank, 0, "CUSTOM CONV3D OP: rank of input array must be equal to %i, but got %i instead !", rank, inputShapeInfo); - REQUIRE_TRUE(weightsShapeInfo[0] == rank, 0, "CUSTOM CONV3D OP: rank of weights array must be equal to %i, but got %i instead !", rank, weightsShapeInfo); - - int indIOioC, indIiD, indWoC(0 == wFormat ? 4 : 0); - if(!isNCDHW) { - indIOioC = 4; indIiD = 1; - } - else { - indIOioC = 1; indIiD = 2; - } - - int bS = inputShapeInfo[1]; // batch size - int iD = inputShapeInfo[indIiD+1]; // input depth - int iH = inputShapeInfo[indIiD+2]; // input height - int iW = inputShapeInfo[indIiD+3]; // input width - int iC = inputShapeInfo[indIOioC+1]; // input channels - int oC = weightsShapeInfo[indWoC+1]; // output channels - - std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, iC, oC); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsShapeInfo, expectedWeightsShape), 0, "CUSTOM CONV3D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); - if (biasShapeInfo) - REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM CONV3D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo)); - - int oD, oH, oW; // output depth, height, width - ConvolutionUtils::calcOutSizePool3D(oD, oH, oW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, paddingMode); - - Nd4jLong* outputShapeInfo = nullptr; - ALLOCATE(outputShapeInfo, block.getWorkspace(), shape::shapeInfoLength(inputShapeInfo), Nd4jLong); - - outputShapeInfo[0] = rank; - outputShapeInfo[1] = bS; - if (isNCDHW) { - outputShapeInfo[2] = oC; - outputShapeInfo[3] = oD; - outputShapeInfo[4] = oH; - outputShapeInfo[5] = oW; - } else { - outputShapeInfo[2] = oD; - outputShapeInfo[3] = oH; - outputShapeInfo[4] = oW; - outputShapeInfo[5] = oC; - } - - ShapeUtils::updateStridesAndType(outputShapeInfo, weightsShapeInfo, shape::order(inputShapeInfo)); - - return SHAPELIST(CONSTANT(outputShapeInfo)); + auto inputShapeInfo = inputShape->at( + 0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) + auto weightsShapeInfo = inputShape->at( + 1); // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC] + auto biasShapeInfo = block.width() > 2 ? inputShape->at(2) : nullptr; // [oC] + + int kD = INT_ARG(0) > 0 ? INT_ARG(0) + : static_cast(shape::sizeAt( + weightsShapeInfo, 0)); // filter(kernel) depth + int kH = INT_ARG(1) > 0 ? INT_ARG(1) + : static_cast(shape::sizeAt( + weightsShapeInfo, 1)); // filter(kernel) height + int kW = INT_ARG(2) > 0 ? INT_ARG(2) + : static_cast(shape::sizeAt( + weightsShapeInfo, 2)); // filter(kernel) width + int sD = INT_ARG(3); // strides depth + int sH = INT_ARG(4); // strides height + int sW = INT_ARG(5); // strides width + int pD = INT_ARG(6); // paddings depth + int pH = INT_ARG(7); // paddings height + int pW = INT_ARG(8); // paddings width + int dD = INT_ARG(9); // dilations depth + int dH = INT_ARG(10); // dilations height + int dW = INT_ARG(11); // dilations width + int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID; + int isNCDHW = + block.numI() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW + int wFormat = block.numI() > 14 ? INT_ARG(14) + : 0; // 0-[kD, kH, kW, iC, oC], 1-[oC, iC, + // kD, kH, kW], 2-[oC, kD, kH, kW, iC] + + const int rank = 5; + REQUIRE_TRUE(paddingMode < 2, 0, + "CUSTOM CONV3D OP: causal padding mode (paddingMode = 2) is not " + "allowed for this operation !"); + REQUIRE_TRUE(inputShapeInfo[0] == rank, 0, + "CUSTOM CONV3D OP: rank of input array must be equal to %i, but " + "got %i instead !", + rank, inputShapeInfo); + REQUIRE_TRUE(weightsShapeInfo[0] == rank, 0, + "CUSTOM CONV3D OP: rank of weights array must be equal to %i, " + "but got %i instead !", + rank, weightsShapeInfo); + + int indIOioC, indIiD, indWoC(0 == wFormat ? 4 : 0); + if (!isNCDHW) { + indIOioC = 4; + indIiD = 1; + } else { + indIOioC = 1; + indIiD = 2; + } + + int bS = inputShapeInfo[1]; // batch size + int iD = inputShapeInfo[indIiD + 1]; // input depth + int iH = inputShapeInfo[indIiD + 2]; // input height + int iW = inputShapeInfo[indIiD + 3]; // input width + int iC = inputShapeInfo[indIOioC + 1]; // input channels + int oC = weightsShapeInfo[indWoC + 1]; // output channels + + std::vector expectedWeightsShape = + ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, iC, oC); + REQUIRE_TRUE( + ShapeUtils::areShapesEqual(weightsShapeInfo, expectedWeightsShape), 0, + "CUSTOM CONV3D OP: wrong shape of weights array, expected is %s, but got " + "%s instead !", + ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), + ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); + if (biasShapeInfo) + REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, + "CUSTOM CONV3D OP: wrong shape of array with biases, expected " + "rank, length: <=2, %i, but got %i, %i instead !", + oC, biasShapeInfo[0], shape::length(biasShapeInfo)); + + int oD, oH, oW; // output depth, height, width + ConvolutionUtils::calcOutSizePool3D(oD, oH, oW, kD, kH, kW, sD, sH, sW, pD, + pH, pW, dD, dH, dW, iD, iH, iW, + paddingMode); + + Nd4jLong* outputShapeInfo = nullptr; + ALLOCATE(outputShapeInfo, block.workspace(), + shape::shapeInfoLength(inputShapeInfo), Nd4jLong); + + outputShapeInfo[0] = rank; + outputShapeInfo[1] = bS; + if (isNCDHW) { + outputShapeInfo[2] = oC; + outputShapeInfo[3] = oD; + outputShapeInfo[4] = oH; + outputShapeInfo[5] = oW; + } else { + outputShapeInfo[2] = oD; + outputShapeInfo[3] = oH; + outputShapeInfo[4] = oW; + outputShapeInfo[5] = oC; + } + + ShapeUtils::updateStridesAndType(outputShapeInfo, weightsShapeInfo, + shape::order(inputShapeInfo)); + + return SHAPELIST(CONSTANT(outputShapeInfo)); } - ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(conv3dnew_bp, 3, 2, false, 0, 13) { - - auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) - auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC] - auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] - auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next - - auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon - auto gradW = OUTPUT_VARIABLE(1); // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC] - auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC] - - REQUIRE_TRUE(input->rankOf() == 5, 0, "CUSTOM CONV3D_BP OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf()); - REQUIRE_TRUE(weights->rankOf() == 5, 0, "CUSTOM CONV3D_BP OP: rank of weights array must be equal to 5, but got %i instead !", weights->rankOf()); - REQUIRE_TRUE(gradO->rankOf() == 5, 0, "CUSTOM CONV3D_BP OP: rank of output gradients (next epsilon) array must be equal to 5, but got %i instead !", gradO->rankOf()); - - int kD = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(weights->sizeAt(0));// filter(kernel) depth - int kH = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast(weights->sizeAt(1));// filter(kernel) height - int kW = INT_ARG(2) > 0 ? INT_ARG(2) : static_cast(weights->sizeAt(2));// filter(kernel) width - int sD = INT_ARG(3); // strides depth - int sH = INT_ARG(4); // strides height - int sW = INT_ARG(5); // strides width - int pD = INT_ARG(6); // paddings depth - int pH = INT_ARG(7); // paddings height - int pW = INT_ARG(8); // paddings width - int dD = INT_ARG(9); // dilations depth - int dH = INT_ARG(10); // dilations height - int dW = INT_ARG(11); // dilations width - int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID - int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW - int wFormat = block.getIArguments()->size() > 14 ? INT_ARG(14) : 0; // 0-[kD, kH, kW, iC, oC], 1-[oC, iC, kD, kH, kW], 2-[oC, kD, kH, kW, iC] - - int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; - int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, wFormat, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); - - int trueoD, trueoH, trueoW; // true output depth/height/width - ConvolutionUtils::calcOutSizePool3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, paddingMode); - - REQUIRE_TRUE(paddingMode < 2, 0, "CUSTOM CONV3D_BP OP: causal padding mode (paddingMode = 2) is not allowed for this operation !"); - std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoD,trueoH,trueoW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); - std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, iC, oC); - REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM CONV3D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); - REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV3D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); - if(bias) - REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV3D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); - - ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW, paddingMode); - - nd4j_debug("MKL-DNN is not used for conv3dnew_bp!\n", 0); - - std::vector gradOaxesForDot; - - if(!isNCDHW) { - gradOaxesForDot = {0,1,2,3}; // bS, oD, oH, oW - input = new NDArray(input->permute({0,4,1,2,3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW] - gradI = new NDArray(gradI->permute({0,4,1,2,3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW] - } - else { - gradOaxesForDot = {0,2,3,4}; // bS, oD, oH, oW - } - - std::vector wPermut, colPermut; - - if(0 == wFormat) { - wPermut = {3,0,1,2,4}; - colPermut = {2,3,4,1,0,5,6,7}; - } - else if(1 == wFormat) { - wPermut = {1,2,3,4,0}; - colPermut = {1,2,3,4,0,5,6,7}; - } - else { - wPermut = {4,1,2,3,0}; - colPermut = {2,3,4,1,0,5,6,7}; - } - - // ----- calculation of gradW and gradB ----- // - NDArray columns(input->ordering(), {bS, iC, kD, kH, kW, oD, oH, oW}, input->dataType(), block.launchContext()); - ConvolutionUtils::vol2col(block, *input, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW); // [bS, iC, iD, iH, iW] is convoluted to [bS, iC, kD, kH, kW, oD, oH, oW] - MmulHelper::tensorDot(&columns, gradO, gradW, {0,5,6,7}, gradOaxesForDot, wPermut); // [bS, iC, kD, kH, kW, oD, oH, oW] x [bS, oD, oH, oW, oC]/[bS, oC, oD, oH, oW] = [iC, kD, kH, kW, oC] - - //----- calculation of gradO -----// - if(gradB) { - if(gradB->rankOf() == 2) - gradB = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()}, false)); - gradO->reduceAlongDimension(reduce::Sum, *gradB, gradOaxesForDot); // sum over bS oD oH oW - if(gradB != OUTPUT_VARIABLE(2)) - delete gradB; - } - - //----- calculation of gradI -----// - // [kD, kH, kW, iC, oC] x [bS, oD, oH, oW, oC]/[bS, oC, oD, oH, oW] = [kD, kH, kW, iC, bS, oD, oH, oW] - // [oC, iC, kD, kH, kW] x [bS, oD, oH, oW, oC]/[bS, oC, oD, oH, oW] = [kD, kH, kW, iC, bS, oD, oH, oW] - // [oC, kD, kH, kW, iC] x [bS, oD, oH, oW, oC]/[bS, oC, oD, oH, oW] = [kD, kH, kW, iC, bS, oD, oH, oW] - MmulHelper::tensorDot(weights, gradO, &columns, {indWoC}, {indIOioC}, colPermut); - ConvolutionUtils::col2vol(block, columns, *gradI, sD, sH, sW, pD, pH, pW, dD, dH, dW); // columns [bS, iC, kD, kH, kW, oD, oH, oW] is de-convoluted to [bS, iC, iD, iH, iW] - - if(!isNCDHW) { - delete input; - delete gradI; - } - - return Status::OK(); + auto input = INPUT_VARIABLE( + 0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) + auto weights = INPUT_VARIABLE( + 1); // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC] + auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] + auto gradO = + block.width() > 3 + ? INPUT_VARIABLE(3) + : INPUT_VARIABLE(2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, + // oH, oW] (NCDHW), epsilon_next + + auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, + // iD, iH, iW] (NCDHW), epsilon + auto gradW = OUTPUT_VARIABLE( + 1); // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC] + auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC] + + REQUIRE_TRUE(input->rankOf() == 5, 0, + "CUSTOM CONV3D_BP OP: rank of input array must be equal to 5, " + "but got %i instead !", + input->rankOf()); + REQUIRE_TRUE(weights->rankOf() == 5, 0, + "CUSTOM CONV3D_BP OP: rank of weights array must be equal to 5, " + "but got %i instead !", + weights->rankOf()); + REQUIRE_TRUE(gradO->rankOf() == 5, 0, + "CUSTOM CONV3D_BP OP: rank of output gradients (next epsilon) " + "array must be equal to 5, but got %i instead !", + gradO->rankOf()); + + int kD = INT_ARG(0) > 0 + ? INT_ARG(0) + : static_cast(weights->sizeAt(0)); // filter(kernel) depth + int kH = INT_ARG(1) > 0 + ? INT_ARG(1) + : static_cast(weights->sizeAt(1)); // filter(kernel) height + int kW = INT_ARG(2) > 0 + ? INT_ARG(2) + : static_cast(weights->sizeAt(2)); // filter(kernel) width + int sD = INT_ARG(3); // strides depth + int sH = INT_ARG(4); // strides height + int sW = INT_ARG(5); // strides width + int pD = INT_ARG(6); // paddings depth + int pH = INT_ARG(7); // paddings height + int pW = INT_ARG(8); // paddings width + int dD = INT_ARG(9); // dilations depth + int dH = INT_ARG(10); // dilations height + int dW = INT_ARG(11); // dilations width + int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID + int isNCDHW = + block.numI() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW + int wFormat = block.numI() > 14 ? INT_ARG(14) + : 0; // 0-[kD, kH, kW, iC, oC], 1-[oC, iC, + // kD, kH, kW], 2-[oC, kD, kH, kW, iC] + + int bS, iC, iD, iH, iW, oC, oD, oH, + oW; // batch size, input channels, input depth/height/width, output + // channels, output depth/height/width; + int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv3d( + isNCDHW, wFormat, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, + indIOioC, indIOioD, indWiC, indWoC, indWkD); + + int trueoD, trueoH, trueoW; // true output depth/height/width + ConvolutionUtils::calcOutSizePool3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, + sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, + iW, paddingMode); + + REQUIRE_TRUE(paddingMode < 2, 0, + "CUSTOM CONV3D_BP OP: causal padding mode (paddingMode = 2) is " + "not allowed for this operation !"); + std::vector expectedGradOShape = + ShapeUtils::composeShapeUsingDimsAndIdx({bS, oC, trueoD, trueoH, trueoW, + 0, indIOioC, indIOioD, + indIOioD + 1, indIOioD + 2}); + std::vector expectedWeightsShape = + ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, iC, oC); + REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, + "CUSTOM CONV3D_BP OP: wrong shape of output gradients (next " + "epsilon) array, expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedGradOShape).c_str(), + ShapeUtils::shapeAsString(gradO).c_str()); + REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, + "CUSTOM CONV3D_BP OP: wrong shape of weights array, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), + ShapeUtils::shapeAsString(weights).c_str()); + if (bias) + REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, + "CUSTOM CONV3D_BP OP: wrong shape of array with biases, " + "expected rank, length: <=2, %i, but got %i, %i instead !", + oC, bias->rankOf(), bias->lengthOf()); + + ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, + kW, sD, sH, sW, dD, dH, dW, paddingMode); + + nd4j_debug("MKL-DNN is not used for conv3dnew_bp!\n", 0); + + std::vector gradOaxesForDot; + + if (!isNCDHW) { + gradOaxesForDot = {0, 1, 2, 3}; // bS, oD, oH, oW + input = new NDArray(input->permute( + {0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW] + gradI = new NDArray(gradI->permute( + {0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW] + } else { + gradOaxesForDot = {0, 2, 3, 4}; // bS, oD, oH, oW + } + + std::vector wPermut, colPermut; + + if (0 == wFormat) { + wPermut = {3, 0, 1, 2, 4}; + colPermut = {2, 3, 4, 1, 0, 5, 6, 7}; + } else if (1 == wFormat) { + wPermut = {1, 2, 3, 4, 0}; + colPermut = {1, 2, 3, 4, 0, 5, 6, 7}; + } else { + wPermut = {4, 1, 2, 3, 0}; + colPermut = {2, 3, 4, 1, 0, 5, 6, 7}; + } + + // ----- calculation of gradW and gradB ----- // + NDArray columns(input->ordering(), {bS, iC, kD, kH, kW, oD, oH, oW}, + input->dataType(), block.launchContext()); + ConvolutionUtils::vol2col(block, *input, columns, sD, sH, sW, pD, pH, pW, dD, + dH, dW); // [bS, iC, iD, iH, iW] is convoluted to + // [bS, iC, kD, kH, kW, oD, oH, oW] + MmulHelper::tensorDot( + &columns, gradO, gradW, {0, 5, 6, 7}, gradOaxesForDot, + wPermut); // [bS, iC, kD, kH, kW, oD, oH, oW] x [bS, oD, oH, oW, oC]/[bS, + // oC, oD, oH, oW] = [iC, kD, kH, kW, oC] + + //----- calculation of gradO -----// + if (gradB) { + if (gradB->rankOf() == 2) + gradB = new NDArray( + gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()}, false)); + gradO->reduceAlongDimension(reduce::Sum, *gradB, + gradOaxesForDot); // sum over bS oD oH oW + if (gradB != OUTPUT_VARIABLE(2)) delete gradB; + } + + //----- calculation of gradI -----// + // [kD, kH, kW, iC, oC] x [bS, oD, oH, oW, oC]/[bS, oC, oD, oH, oW] = [kD, kH, + // kW, iC, bS, oD, oH, oW] [oC, iC, kD, kH, kW] x [bS, oD, oH, oW, oC]/[bS, + // oC, oD, oH, oW] = [kD, kH, kW, iC, bS, oD, oH, oW] [oC, kD, kH, kW, iC] x + // [bS, oD, oH, oW, oC]/[bS, oC, oD, oH, oW] = [kD, kH, kW, iC, bS, oD, oH, + // oW] + MmulHelper::tensorDot(weights, gradO, &columns, {indWoC}, {indIOioC}, + colPermut); + ConvolutionUtils::col2vol(block, columns, *gradI, sD, sH, sW, pD, pH, pW, dD, + dH, + dW); // columns [bS, iC, kD, kH, kW, oD, oH, oW] is + // de-convoluted to [bS, iC, iD, iH, iW] + + if (!isNCDHW) { + delete input; + delete gradI; + } + + return Status::OK(); } - DECLARE_TYPES(conv3dnew_bp) { - getOpDescriptor() - ->setAllowedInputTypes(0, sd::DataType::ANY) - ->setAllowedInputTypes(1, {ALL_FLOATS}) - ->setAllowedInputTypes(2, {ALL_FLOATS}) - ->setAllowedInputTypes(3, {ALL_FLOATS}) - ->setAllowedOutputTypes({ALL_FLOATS}); - } - +DECLARE_TYPES(conv3dnew_bp) { + getOpDescriptor() + ->setAllowedInputTypes(0, sd::DataType::ANY) + ->setAllowedInputTypes(1, {ALL_FLOATS}) + ->setAllowedInputTypes(2, {ALL_FLOATS}) + ->setAllowedInputTypes(3, {ALL_FLOATS}) + ->setAllowedOutputTypes({ALL_FLOATS}); +} DECLARE_SHAPE_FN(conv3dnew_bp) { - - auto inputShapeInfo = inputShape->at(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) - auto weightsShapeInfo = inputShape->at(1); // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC] - Nd4jLong const* biasShapeInfo = block.width() > 3 ? inputShape->at(2) : nullptr; // [oC] - Nd4jLong const* gradOShapeInfo = block.width() > 3 ? inputShape->at(3) : inputShape->at(2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next - - int kD = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(shape::sizeAt(weightsShapeInfo, 0));// filter(kernel) depth - int kH = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast(shape::sizeAt(weightsShapeInfo, 1));// filter(kernel) height - int kW = INT_ARG(2) > 0 ? INT_ARG(2) : static_cast(shape::sizeAt(weightsShapeInfo, 2));// filter(kernel) width - int sD = INT_ARG(3); // strides depth - int sH = INT_ARG(4); // strides height - int sW = INT_ARG(5); // strides width - int pD = INT_ARG(6); // paddings depth - int pH = INT_ARG(7); // paddings height - int pW = INT_ARG(8); // paddings width - int dD = INT_ARG(9); // dilations depth - int dH = INT_ARG(10); // dilations height - int dW = INT_ARG(11); // dilations width - int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID - int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW - int wFormat = block.getIArguments()->size() > 14 ? INT_ARG(14) : 0; // 0-[kD, kH, kW, iC, oC], 1-[oC, iC, kD, kH, kW], 2-[oC, kD, kH, kW, iC] - - const int rank = 5; - REQUIRE_TRUE(paddingMode < 2, 0, "CUSTOM CONV3D OP: causal padding mode (paddingMode = 2) is not allowed for this operation !"); - REQUIRE_TRUE(inputShapeInfo[0] == rank, 0, "CUSTOM CONV3D_BP OP: rank of input array must be equal to %i, but got %i instead !", rank, inputShapeInfo); - REQUIRE_TRUE(weightsShapeInfo[0] == rank, 0, "CUSTOM CONV3D_BP OP: rank of weights array must be equal to %i, but got %i instead !", rank, weightsShapeInfo); - REQUIRE_TRUE(gradOShapeInfo[0] == rank, 0, "CUSTOM CONV3D_BP OP: rank of output gradients (next epsilon) array must be equal to %i, but got %i instead !", rank, gradOShapeInfo); - - int indIOioC, indIiD, indWoC(0 == wFormat ? 4 : 0); - if(!isNCDHW) { - indIOioC = 4; indIiD = 1; - } - else { - indIOioC = 1; indIiD = 2; - } - - int bS = inputShapeInfo[1]; // batch size - int iD = inputShapeInfo[indIiD+1]; // input depth - int iH = inputShapeInfo[indIiD+2]; // input height - int iW = inputShapeInfo[indIiD+3]; // input width - int iC = inputShapeInfo[indIOioC+1]; // input channels - int oC = weightsShapeInfo[indWoC+1]; // output channels - - int trueoD, trueoH, trueoW; // true output depth/height/width - ConvolutionUtils::calcOutSizePool3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, paddingMode); - - std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoD,trueoH,trueoW, 0,indIOioC,indIiD,indIiD+1,indIiD+2}); - std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, iC, oC); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(gradOShapeInfo, expectedGradOShape), 0, "CUSTOM CONV3D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradOShapeInfo).c_str()); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsShapeInfo, expectedWeightsShape), 0, "CUSTOM CONV3D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); - if(biasShapeInfo) - REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM CONV3D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo)); - - auto gradIshapeInfo = ShapeBuilders::copyShapeInfoAndType(inputShapeInfo, gradOShapeInfo, false, block.getWorkspace()); - auto gradWshapeInfo = ShapeBuilders::copyShapeInfoAndType(weightsShapeInfo, gradOShapeInfo, false, block.getWorkspace()); - - if(biasShapeInfo) { - auto gradBshapeInfo = ShapeBuilders::copyShapeInfoAndType(biasShapeInfo, gradOShapeInfo, false, block.getWorkspace()); - return SHAPELIST(CONSTANT(gradIshapeInfo), CONSTANT(gradWshapeInfo), CONSTANT(gradBshapeInfo)); - } - - return SHAPELIST(CONSTANT(gradIshapeInfo), CONSTANT(gradWshapeInfo)); -} -} + auto inputShapeInfo = inputShape->at( + 0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) + auto weightsShapeInfo = inputShape->at( + 1); // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC] + Nd4jLong const* biasShapeInfo = + block.width() > 3 ? inputShape->at(2) : nullptr; // [oC] + Nd4jLong const* gradOShapeInfo = + block.width() > 3 + ? inputShape->at(3) + : inputShape->at(2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, + // oH, oW] (NCDHW), epsilon_next + + int kD = INT_ARG(0) > 0 ? INT_ARG(0) + : static_cast(shape::sizeAt( + weightsShapeInfo, 0)); // filter(kernel) depth + int kH = INT_ARG(1) > 0 ? INT_ARG(1) + : static_cast(shape::sizeAt( + weightsShapeInfo, 1)); // filter(kernel) height + int kW = INT_ARG(2) > 0 ? INT_ARG(2) + : static_cast(shape::sizeAt( + weightsShapeInfo, 2)); // filter(kernel) width + int sD = INT_ARG(3); // strides depth + int sH = INT_ARG(4); // strides height + int sW = INT_ARG(5); // strides width + int pD = INT_ARG(6); // paddings depth + int pH = INT_ARG(7); // paddings height + int pW = INT_ARG(8); // paddings width + int dD = INT_ARG(9); // dilations depth + int dH = INT_ARG(10); // dilations height + int dW = INT_ARG(11); // dilations width + int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID + int isNCDHW = + block.numI() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW + int wFormat = block.numI() > 14 ? INT_ARG(14) + : 0; // 0-[kD, kH, kW, iC, oC], 1-[oC, iC, + // kD, kH, kW], 2-[oC, kD, kH, kW, iC] + + const int rank = 5; + REQUIRE_TRUE(paddingMode < 2, 0, + "CUSTOM CONV3D OP: causal padding mode (paddingMode = 2) is not " + "allowed for this operation !"); + REQUIRE_TRUE(inputShapeInfo[0] == rank, 0, + "CUSTOM CONV3D_BP OP: rank of input array must be equal to %i, " + "but got %i instead !", + rank, inputShapeInfo); + REQUIRE_TRUE(weightsShapeInfo[0] == rank, 0, + "CUSTOM CONV3D_BP OP: rank of weights array must be equal to " + "%i, but got %i instead !", + rank, weightsShapeInfo); + REQUIRE_TRUE(gradOShapeInfo[0] == rank, 0, + "CUSTOM CONV3D_BP OP: rank of output gradients (next epsilon) " + "array must be equal to %i, but got %i instead !", + rank, gradOShapeInfo); + + int indIOioC, indIiD, indWoC(0 == wFormat ? 4 : 0); + if (!isNCDHW) { + indIOioC = 4; + indIiD = 1; + } else { + indIOioC = 1; + indIiD = 2; + } + + int bS = inputShapeInfo[1]; // batch size + int iD = inputShapeInfo[indIiD + 1]; // input depth + int iH = inputShapeInfo[indIiD + 2]; // input height + int iW = inputShapeInfo[indIiD + 3]; // input width + int iC = inputShapeInfo[indIOioC + 1]; // input channels + int oC = weightsShapeInfo[indWoC + 1]; // output channels + + int trueoD, trueoH, trueoW; // true output depth/height/width + ConvolutionUtils::calcOutSizePool3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, + sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, + iW, paddingMode); + + std::vector expectedGradOShape = + ShapeUtils::composeShapeUsingDimsAndIdx({bS, oC, trueoD, trueoH, trueoW, + 0, indIOioC, indIiD, indIiD + 1, + indIiD + 2}); + std::vector expectedWeightsShape = + ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, iC, oC); + REQUIRE_TRUE(ShapeUtils::areShapesEqual(gradOShapeInfo, expectedGradOShape), + 0, + "CUSTOM CONV3D_BP OP: wrong shape of output gradients (next " + "epsilon) array, expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedGradOShape).c_str(), + ShapeUtils::shapeAsString(gradOShapeInfo).c_str()); + REQUIRE_TRUE( + ShapeUtils::areShapesEqual(weightsShapeInfo, expectedWeightsShape), 0, + "CUSTOM CONV3D_BP OP: wrong shape of weights array, expected is %s, but " + "got %s instead !", + ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), + ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); + if (biasShapeInfo) + REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, + "CUSTOM CONV3D_BP OP: wrong shape of array with biases, " + "expected rank, length: <=2, %i, but got %i, %i instead !", + oC, biasShapeInfo[0], shape::length(biasShapeInfo)); + + auto gradIshapeInfo = ShapeBuilders::copyShapeInfoAndType( + inputShapeInfo, gradOShapeInfo, false, block.workspace()); + auto gradWshapeInfo = ShapeBuilders::copyShapeInfoAndType( + weightsShapeInfo, gradOShapeInfo, false, block.workspace()); + + if (biasShapeInfo) { + auto gradBshapeInfo = ShapeBuilders::copyShapeInfoAndType( + biasShapeInfo, gradOShapeInfo, false, block.workspace()); + return SHAPELIST(CONSTANT(gradIshapeInfo), CONSTANT(gradWshapeInfo), + CONSTANT(gradBshapeInfo)); + } + + return SHAPELIST(CONSTANT(gradIshapeInfo), CONSTANT(gradWshapeInfo)); } +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/generic/nn/convo/deconv2d.cpp b/libnd4j/include/ops/declarable/generic/nn/convo/deconv2d.cpp index d62a98d52bb8..8d4d976763ee 100644 --- a/libnd4j/include/ops/declarable/generic/nn/convo/deconv2d.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/convo/deconv2d.cpp @@ -22,308 +22,473 @@ #include #if NOT_EXCLUDED(OP_deconv2d) -#include #include +#include +#include +#include #include #include -#include -#include namespace sd { -namespace ops { +namespace ops { CUSTOM_OP_IMPL(deconv2d, 2, 1, false, 0, 9) { - - auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - auto weights = INPUT_VARIABLE(1); // [kH, kW, oC, iC], [iC, oC, kH, kW], [iC, kH, kW, oC] - auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] - - auto output = OUTPUT_NULLIFIED(0); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) - - REQUIRE_TRUE(input->rankOf() == 4, 0, "CUSTOM DECONV2D OP: rank of input array must be equal to 4, but got %i instead !", input->rankOf()); - REQUIRE_TRUE(weights->rankOf() == 4, 0, "CUSTOM DECONV2D OP: rank of weights array must be equal to 4, but got %i instead !", weights->rankOf()); - - int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(weights->sizeAt(0));// filter(kernel) height - int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast(weights->sizeAt(1));// filter(kernel) width - int sH = INT_ARG(2); // strides height - int sW = INT_ARG(3); // strides width - int pH = INT_ARG(4); // paddings height - int pW = INT_ARG(5); // paddings width - int dH = INT_ARG(6); // dilations height - int dW = INT_ARG(7); // dilations width - int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME - int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC - int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, oC, iC], 1 - [iC, oC, kH, kW], 2 - [iC, kH, kW, oC] - - int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; - int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH); - - std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, oC, iC); - REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DECONV2D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); - if (bias) - REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DECONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); - - if(!isNCHW) - output = new NDArray(output->permute({0, 3, 1, 2})); // [bS, oH, oW, oC] -> [bS, oC, oH, oW] - - std::vector colPermut; - if(1 == wFormat) - colPermut = {1, 2, 3, 0, 4, 5}; - else - colPermut = {2, 3, 1, 0, 4, 5}; - - if(isSameMode) // Note: we're intentionally swapping iH and oH, to calculated the padding for a"normal" conv (not deconv) forward pass - ConvolutionUtils::calcPadding2D(pH, pW, iH, iW, oH, oW, kH, kW, sH, sW, dH, dW); - - NDArray columns(input->ordering(), {bS, oC, kH, kW, iH, iW}, input->dataType(), block.launchContext()); - - //----- calculation of output -----// - // NHWC: [kH, kW, oC, iC] x [bS, iH, iW, iC] = [kH, kW, oC, bS, iH, iW] - // NHWC: [iC, oC, kH, kW] x [bS, iH, iW, iC] = [oC, kH, kW, bS, iH, iW] - // NHWC: [iC, kH, kW, oC] x [bS, iH, iW, iC] = [kH, kW, oC, bS, iH, iW] - sd::MmulHelper::tensorDot(weights, input, &columns, {indWiC}, {indIOioC}, colPermut); - LaunchContext* ctx = block.launchContext(); - helpers::col2im(*ctx, columns, *output, sH, sW, pH, pW, oH, oW, dH, dW); // [bS, oC, kH, kW, iH, iW] is de-convoluted to [bS, oC, oH, oW] - - //----- add biases if required -----// - if(bias) - // output->applyBroadcast(broadcast::Add, {1}, bias); - helpers::addBias(block, *output, *bias, *output, true); - - if(!isNCHW) - delete output; - - return Status::OK(); + auto input = + INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + auto weights = INPUT_VARIABLE( + 1); // [kH, kW, oC, iC], [iC, oC, kH, kW], [iC, kH, kW, oC] + auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] + + auto output = OUTPUT_NULLIFIED( + 0); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) + + REQUIRE_TRUE(input->rankOf() == 4, 0, + "CUSTOM DECONV2D OP: rank of input array must be equal to 4, " + "but got %i instead !", + input->rankOf()); + REQUIRE_TRUE(weights->rankOf() == 4, 0, + "CUSTOM DECONV2D OP: rank of weights array must be equal to 4, " + "but got %i instead !", + weights->rankOf()); + + int kH = INT_ARG(0) > 0 + ? INT_ARG(0) + : static_cast(weights->sizeAt(0)); // filter(kernel) height + int kW = INT_ARG(1) > 0 + ? INT_ARG(1) + : static_cast(weights->sizeAt(1)); // filter(kernel) width + int sH = INT_ARG(2); // strides height + int sW = INT_ARG(3); // strides width + int pH = INT_ARG(4); // paddings height + int pW = INT_ARG(5); // paddings width + int dH = INT_ARG(6); // dilations height + int dW = INT_ARG(7); // dilations width + int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME + int isNCHW = + block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC + int wFormat = + block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, oC, iC], 1 - [iC, + // oC, kH, kW], 2 - [iC, kH, kW, oC] + + int bS, iC, iH, iW, oC, oH, + oW; // batch size, input channels, input height/width, output channels, + // output height/width; + int indIOioC, indIiH, indWoC, indWiC, indWkH, + indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, + indIiH, indWoC, indWiC, indWkH, indOoH); + + std::vector expectedWeightsShape = + ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, oC, iC); + REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, + "CUSTOM DECONV2D OP: wrong shape of weights array, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), + ShapeUtils::shapeAsString(weights).c_str()); + if (bias) + REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, + "CUSTOM DECONV2D OP: wrong shape of array with biases, " + "expected rank, length: <=2, %i, but got %i, %i instead !", + oC, bias->rankOf(), bias->lengthOf()); + + if (!isNCHW) + output = new NDArray( + output->permute({0, 3, 1, 2})); // [bS, oH, oW, oC] -> [bS, oC, oH, oW] + + std::vector colPermut; + if (1 == wFormat) + colPermut = {1, 2, 3, 0, 4, 5}; + else + colPermut = {2, 3, 1, 0, 4, 5}; + + if (isSameMode) // Note: we're intentionally swapping iH and oH, to + // calculated the padding for a"normal" conv (not deconv) + // forward pass + ConvolutionUtils::calcPadding2D(pH, pW, iH, iW, oH, oW, kH, kW, sH, sW, dH, + dW); + + NDArray columns(input->ordering(), {bS, oC, kH, kW, iH, iW}, + input->dataType(), block.launchContext()); + + //----- calculation of output -----// + // NHWC: [kH, kW, oC, iC] x [bS, iH, iW, iC] = [kH, kW, oC, bS, iH, iW] + // NHWC: [iC, oC, kH, kW] x [bS, iH, iW, iC] = [oC, kH, kW, bS, iH, iW] + // NHWC: [iC, kH, kW, oC] x [bS, iH, iW, iC] = [kH, kW, oC, bS, iH, iW] + sd::MmulHelper::tensorDot(weights, input, &columns, {indWiC}, {indIOioC}, + colPermut); + LaunchContext* ctx = block.launchContext(); + helpers::col2im( + *ctx, columns, *output, sH, sW, pH, pW, oH, oW, dH, + dW); // [bS, oC, kH, kW, iH, iW] is de-convoluted to [bS, oC, oH, oW] + + //----- add biases if required -----// + if (bias) + // output->applyBroadcast(broadcast::Add, {1}, bias); + helpers::addBias(block, *output, *bias, *output, true); + + if (!isNCHW) delete output; + + return Status::OK(); +} +DECLARE_TYPES(deconv2d) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } - DECLARE_TYPES(deconv2d) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } DECLARE_SHAPE_FN(deconv2d) { - - auto inputShapeInfo = inputShape->at(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - auto weightsShapeInfo = inputShape->at(1); // [kH, kW, oC, iC], [iC, oC, kH, kW], [iC, kH, kW, oC] - auto biasShapeInfo = block.width() > 2 ? inputShape->at(2) : nullptr; // [oC] - - const int rank = 4; - REQUIRE_TRUE(shape::rank(inputShapeInfo) == rank, 0, "CUSTOM DECONV2D OP: rank of input array must be equal to %i, but got %i instead !", rank, shape::rank(inputShapeInfo)); - REQUIRE_TRUE(shape::rank(weightsShapeInfo) == rank, 0, "CUSTOM DECONV2D OP: rank of weights array must be equal to %i, but got %i instead !", rank, shape::rank(weightsShapeInfo)); - - int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(shape::sizeAt(weightsShapeInfo, 0));// filter(kernel) height - int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast(shape::sizeAt(weightsShapeInfo, 1));// filter(kernel) width - int sH = INT_ARG(2); // strides height - int sW = INT_ARG(3); // strides width - int pH = INT_ARG(4); // paddings height - int pW = INT_ARG(5); // paddings width - int dH = INT_ARG(6); // dilations height - int dW = INT_ARG(7); // dilations width - int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME - int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW - int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, oC, iC], 1 - [iC, oC, kH, kW], 2 - [iC, kH, kW, oC] - - int indIOioC, indIiH, indWoC(0 == wFormat ? 2 : (1 == wFormat ? 1 : 3)); - if(!isNCHW) { - indIOioC = 3; indIiH = 1; - } - else { - indIOioC = 1; indIiH = 2; - } - - const int bS = inputShapeInfo[1]; // batch size - const int iH = inputShapeInfo[indIiH+1]; // input height - const int iW = inputShapeInfo[indIiH+2]; // input width - const int iC = inputShapeInfo[indIOioC+1]; // input channels - const int oC = weightsShapeInfo[indWoC+1]; // output channels - - std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, oC, iC); - REQUIRE_TRUE(shape::shapeEquals(4, expectedWeightsShape.data(), shape::rank(weightsShapeInfo), shape::shapeOf(weightsShapeInfo)), 0, "CUSTOM DECONV2D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); - if (biasShapeInfo) - REQUIRE_TRUE(shape::rank(biasShapeInfo) <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM DECONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo)); - - int oH, oW; // output height, width - ConvolutionUtils::calcOutSizeDeconv2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode); - - Nd4jLong outputShape[4]; - - outputShape[0] = bS; - if (isNCHW) { - outputShape[1] = oC; - outputShape[2] = oH; - outputShape[3] = oW; - } else { - outputShape[1] = oH; - outputShape[2] = oW; - outputShape[3] = oC; - } - - return SHAPELIST(ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(ArrayOptions::dataType(weightsShapeInfo), shape::order(inputShapeInfo), outputShape, 4))); + auto inputShapeInfo = + inputShape->at(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + auto weightsShapeInfo = inputShape->at( + 1); // [kH, kW, oC, iC], [iC, oC, kH, kW], [iC, kH, kW, oC] + auto biasShapeInfo = block.width() > 2 ? inputShape->at(2) : nullptr; // [oC] + + const int rank = 4; + REQUIRE_TRUE(shape::rank(inputShapeInfo) == rank, 0, + "CUSTOM DECONV2D OP: rank of input array must be equal to %i, " + "but got %i instead !", + rank, shape::rank(inputShapeInfo)); + REQUIRE_TRUE(shape::rank(weightsShapeInfo) == rank, 0, + "CUSTOM DECONV2D OP: rank of weights array must be equal to %i, " + "but got %i instead !", + rank, shape::rank(weightsShapeInfo)); + + int kH = INT_ARG(0) > 0 ? INT_ARG(0) + : static_cast(shape::sizeAt( + weightsShapeInfo, 0)); // filter(kernel) height + int kW = INT_ARG(1) > 0 ? INT_ARG(1) + : static_cast(shape::sizeAt( + weightsShapeInfo, 1)); // filter(kernel) width + int sH = INT_ARG(2); // strides height + int sW = INT_ARG(3); // strides width + int pH = INT_ARG(4); // paddings height + int pW = INT_ARG(5); // paddings width + int dH = INT_ARG(6); // dilations height + int dW = INT_ARG(7); // dilations width + int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME + int isNCHW = + block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW + int wFormat = + block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, oC, iC], 1 - [iC, + // oC, kH, kW], 2 - [iC, kH, kW, oC] + + int indIOioC, indIiH, indWoC(0 == wFormat ? 2 : (1 == wFormat ? 1 : 3)); + if (!isNCHW) { + indIOioC = 3; + indIiH = 1; + } else { + indIOioC = 1; + indIiH = 2; + } + + const int bS = inputShapeInfo[1]; // batch size + const int iH = inputShapeInfo[indIiH + 1]; // input height + const int iW = inputShapeInfo[indIiH + 2]; // input width + const int iC = inputShapeInfo[indIOioC + 1]; // input channels + const int oC = weightsShapeInfo[indWoC + 1]; // output channels + + std::vector expectedWeightsShape = + ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, oC, iC); + REQUIRE_TRUE(shape::shapeEquals(4, expectedWeightsShape.data(), + shape::rank(weightsShapeInfo), + shape::shapeOf(weightsShapeInfo)), + 0, + "CUSTOM DECONV2D OP: wrong shape of weights array, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), + ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); + if (biasShapeInfo) + REQUIRE_TRUE( + shape::rank(biasShapeInfo) <= 2 && oC == shape::length(biasShapeInfo), + 0, + "CUSTOM DECONV2D OP: wrong shape of array with biases, expected rank, " + "length: <=2, %i, but got %i, %i instead !", + oC, biasShapeInfo[0], shape::length(biasShapeInfo)); + + int oH, oW; // output height, width + ConvolutionUtils::calcOutSizeDeconv2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, + iH, iW, isSameMode); + + Nd4jLong outputShape[4]; + + outputShape[0] = bS; + if (isNCHW) { + outputShape[1] = oC; + outputShape[2] = oH; + outputShape[3] = oW; + } else { + outputShape[1] = oH; + outputShape[2] = oW; + outputShape[3] = oC; + } + + return SHAPELIST(ConstantShapeHelper::getInstance().createShapeInfo( + ShapeDescriptor(ArrayOptions::dataType(weightsShapeInfo), + shape::order(inputShapeInfo), outputShape, 4))); } - DECLARE_TYPES(deconv2d_bp) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } +DECLARE_TYPES(deconv2d_bp) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(deconv2d_bp, 3, 2, false, 0, 9) { - - auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCDHW) - auto weights = INPUT_VARIABLE(1); // [kH, kW, oC, iC], [iC, oC, kH, kW], [iC, kH, kW, oC] - auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] - auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next - - auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCDHW), gradI - auto gradW = OUTPUT_VARIABLE(1); // [kH, kW, oC, iC], [iC, oC, kH, kW], [iC, kH, kW, oC] - auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC] - - REQUIRE_TRUE(input->rankOf() == 4, 0, "CUSTOM DECONV2D_BP OP: rank of input array must be equal to 4, but got %i instead !", input->rankOf()); - REQUIRE_TRUE(weights->rankOf() == 4, 0, "CUSTOM DECONV2D_BP OP: rank of weights array must be equal to 4 , but got %i instead !", weights->rankOf()); - REQUIRE_TRUE(gradO->rankOf() == 4, 0, "CUSTOM DECONV2D_BP OP: rank of output gradients (next epsilon) array must be equal to 4, but got %i instead !", gradO->rankOf()); - - - int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(weights->sizeAt(0));// filter(kernel) height - int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast(weights->sizeAt(1));// filter(kernel) width - int sH = INT_ARG(2); // strides height - int sW = INT_ARG(3); // strides width - int pH = INT_ARG(4); // paddings height - int pW = INT_ARG(5); // paddings width - int dH = INT_ARG(6); // dilations height - int dW = INT_ARG(7); // dilations width - int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME - int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW - int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, oC, iC], 1 - [iC, oC, kH, kW], 2 - [iC, kH, kW, oC] - - int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; - int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH); - - int trueoH, trueoW; // true output height, width - ConvolutionUtils::calcOutSizeDeconv2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode); - - std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1}); - std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, oC, iC); - REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM DECONV2D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); - REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DECONV2D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); - if(bias) - REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DECONV2D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); - - if(isSameMode){ // SAME - //Note: we're intentionally swapping iH and oH, to calculated the padding for a"normal" conv (not deconv) forward pass - ConvolutionUtils::calcPadding2D(pH, pW, iH, iW, oH, oW, kH, kW, sH, sW, dH, dW); - } - - // ----- calculation of gradI -> pass it through conv2d_ff ----- // - sd::ops::conv2d conv2d; - const Nd4jStatus status = conv2d.execute({gradO, weights}, {gradI}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, !isNCHW, wFormat}, {}); - if (status != ND4J_STATUS_OK) - return status; - - // -----prepare permutation arrays and axes for dot product ----- // - std::vector inputAxes; - - if(!isNCHW) { - gradO = new NDArray(gradO->permute({0, 3, 1, 2})); // [bS, oH, oW, oC] -> [bS, oC, oH, oW] - inputAxes = {0, 1, 2}; // bS, iH, iW - } - else - inputAxes = {0, 2, 3}; // bS, iH, iW - - std::vector gradWAxes; // empty for wFormat = 1 - if(0 == wFormat) - gradWAxes = {3, 2, 0, 1}; - else if(2 == wFormat) - gradWAxes = {0, 3, 1, 2}; - - // ----- calculation of gradW ----- // - NDArray columns(input->ordering(), {bS, oC, kH, kW, iH, iW}, input->dataType(), block.launchContext()); - - LaunchContext* ctx = block.launchContext(); - helpers::im2col(*ctx, *gradO, columns, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext())); // [bS, oC, oH, oW] is convoluted to [bS, oC, kH, kW, iH, iW] - MmulHelper::tensorDot(input, &columns, gradW, inputAxes, {0, 4, 5}, gradWAxes); // [bS, iC, iH, iW]/[bS, iH, iW, iC] x [bS, oC, kH, kW, iH, iW] = [iC, oC, kH, kW] - - // ----- calculation of gradB ----- // - if(gradB) { - if(gradB->rankOf() == 2) - gradB = new NDArray(gradB->reshape(gradB->ordering(), {gradB->lengthOf()}, false)); - gradO->reduceAlongDimension(reduce::Sum, *gradB, {0, 2, 3}); // sum over bS, oH, oW - if(gradB != OUTPUT_VARIABLE(2)) - delete gradB; - } - - if(!isNCHW) - delete gradO; - - return Status::OK(); + auto input = + INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCDHW) + auto weights = INPUT_VARIABLE( + 1); // [kH, kW, oC, iC], [iC, oC, kH, kW], [iC, kH, kW, oC] + auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] + auto gradO = block.width() > 3 + ? INPUT_VARIABLE(3) + : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, + // oH, oW] (NCDHW), epsilon_next + + auto gradI = OUTPUT_VARIABLE( + 0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCDHW), gradI + auto gradW = OUTPUT_VARIABLE( + 1); // [kH, kW, oC, iC], [iC, oC, kH, kW], [iC, kH, kW, oC] + auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC] + + REQUIRE_TRUE(input->rankOf() == 4, 0, + "CUSTOM DECONV2D_BP OP: rank of input array must be equal to 4, " + "but got %i instead !", + input->rankOf()); + REQUIRE_TRUE(weights->rankOf() == 4, 0, + "CUSTOM DECONV2D_BP OP: rank of weights array must be equal to " + "4 , but got %i instead !", + weights->rankOf()); + REQUIRE_TRUE(gradO->rankOf() == 4, 0, + "CUSTOM DECONV2D_BP OP: rank of output gradients (next epsilon) " + "array must be equal to 4, but got %i instead !", + gradO->rankOf()); + + int kH = INT_ARG(0) > 0 + ? INT_ARG(0) + : static_cast(weights->sizeAt(0)); // filter(kernel) height + int kW = INT_ARG(1) > 0 + ? INT_ARG(1) + : static_cast(weights->sizeAt(1)); // filter(kernel) width + int sH = INT_ARG(2); // strides height + int sW = INT_ARG(3); // strides width + int pH = INT_ARG(4); // paddings height + int pW = INT_ARG(5); // paddings width + int dH = INT_ARG(6); // dilations height + int dW = INT_ARG(7); // dilations width + int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME + int isNCHW = + block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW + int wFormat = + block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, oC, iC], 1 - [iC, + // oC, kH, kW], 2 - [iC, kH, kW, oC] + + int bS, iC, iH, iW, oC, oH, + oW; // batch size, input channels, input height/width, output channels, + // output height/width; + int indIOioC, indIiH, indWoC, indWiC, indWkH, + indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, + indIiH, indWoC, indWiC, indWkH, indOoH); + + int trueoH, trueoW; // true output height, width + ConvolutionUtils::calcOutSizeDeconv2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, + dH, dW, iH, iW, isSameMode); + + std::vector expectedGradOShape = + ShapeUtils::composeShapeUsingDimsAndIdx( + {bS, oC, trueoH, trueoW, 0, indIOioC, indOoH, indOoH + 1}); + std::vector expectedWeightsShape = + ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, oC, iC); + REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, + "CUSTOM DECONV2D_BP OP: wrong shape of output gradients (next " + "epsilon) array, expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedGradOShape).c_str(), + ShapeUtils::shapeAsString(gradO).c_str()); + REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, + "CUSTOM DECONV2D_BP OP: wrong shape of weights array, expected " + "is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), + ShapeUtils::shapeAsString(weights).c_str()); + if (bias) + REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, + "CUSTOM DECONV2D_BP OP: wrong shape of array with biases, " + "expected rank, length: <=2, %i, but got %i, %i instead !", + oC, bias->rankOf(), bias->lengthOf()); + + if (isSameMode) { // SAME + // Note: we're intentionally swapping iH and oH, to calculated the padding + // for a"normal" conv (not deconv) forward pass + ConvolutionUtils::calcPadding2D(pH, pW, iH, iW, oH, oW, kH, kW, sH, sW, dH, + dW); + } + + // ----- calculation of gradI -> pass it through conv2d_ff ----- // + sd::ops::conv2d conv2d; + const Nd4jStatus status = conv2d.execute( + {gradO, weights}, {gradI}, {}, + {kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, !isNCHW, wFormat}, {}); + if (status != ND4J_STATUS_OK) return status; + + // -----prepare permutation arrays and axes for dot product ----- // + std::vector inputAxes; + + if (!isNCHW) { + gradO = new NDArray( + gradO->permute({0, 3, 1, 2})); // [bS, oH, oW, oC] -> [bS, oC, oH, oW] + inputAxes = {0, 1, 2}; // bS, iH, iW + } else + inputAxes = {0, 2, 3}; // bS, iH, iW + + std::vector gradWAxes; // empty for wFormat = 1 + if (0 == wFormat) + gradWAxes = {3, 2, 0, 1}; + else if (2 == wFormat) + gradWAxes = {0, 3, 1, 2}; + + // ----- calculation of gradW ----- // + NDArray columns(input->ordering(), {bS, oC, kH, kW, iH, iW}, + input->dataType(), block.launchContext()); + + LaunchContext* ctx = block.launchContext(); + helpers::im2col( + *ctx, *gradO, columns, kH, kW, sH, sW, pH, pW, dH, dW, + NDArrayFactory::create( + 0.f, input->getContext())); // [bS, oC, oH, oW] is convoluted to [bS, + // oC, kH, kW, iH, iW] + MmulHelper::tensorDot(input, &columns, gradW, inputAxes, {0, 4, 5}, + gradWAxes); // [bS, iC, iH, iW]/[bS, iH, iW, iC] x [bS, + // oC, kH, kW, iH, iW] = [iC, oC, kH, kW] + + // ----- calculation of gradB ----- // + if (gradB) { + if (gradB->rankOf() == 2) + gradB = new NDArray( + gradB->reshape(gradB->ordering(), {gradB->lengthOf()}, false)); + gradO->reduceAlongDimension(reduce::Sum, *gradB, + {0, 2, 3}); // sum over bS, oH, oW + if (gradB != OUTPUT_VARIABLE(2)) delete gradB; + } + + if (!isNCHW) delete gradO; + + return Status::OK(); } DECLARE_SHAPE_FN(deconv2d_bp) { - - auto inputShapeInfo = inputShape->at(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCDHW) - auto weightsShapeInfo = inputShape->at(1); // [kH, kW, oC, iC], [iC, oC, kH, kW], [iC, kH, kW, oC] - Nd4jLong const* biasShapeInfo = block.width() > 3 ? inputShape->at(2) : nullptr; // [oC] - auto gradOShapeInfo = block.width() > 3 ? inputShape->at(3) : inputShape->at(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next - - const int rank = 4; - REQUIRE_TRUE(shape::rank(inputShapeInfo) == rank, 0, "CUSTOM DECONV2D_BP OP: rank of input array must be equal to %i, but got %i instead !", rank, shape::rank(inputShapeInfo)); - REQUIRE_TRUE(shape::rank(weightsShapeInfo) == rank, 0, "CUSTOM DECONV2D_BP OP: rank of weights array must be equal to %i , but got %i instead !", rank, shape::rank(weightsShapeInfo)); - REQUIRE_TRUE(shape::rank(gradOShapeInfo) == rank, 0, "CUSTOM DECONV2D_BP OP: rank of output gradients (next epsilon) array must be equal to %i, but got %i instead !", rank, shape::rank(gradOShapeInfo)); - - int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(shape::sizeAt(weightsShapeInfo, 0));// filter(kernel) height - int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast(shape::sizeAt(weightsShapeInfo, 1));// filter(kernel) width - int sH = INT_ARG(2); // strides height - int sW = INT_ARG(3); // strides width - int pH = INT_ARG(4); // paddings height - int pW = INT_ARG(5); // paddings width - int dH = INT_ARG(6); // dilations height - int dW = INT_ARG(7); // dilations width - int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME - int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW - int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, oC, iC], 1 - [iC, oC, kH, kW], 2 - [iC, kH, kW, oC] - - int indIOioC, indIiH, indOoH, indWoC(0 == wFormat ? 2 : (1 == wFormat ? 1 : 3)); - if(!isNCHW) { - indIOioC = 3; indIiH = 1; indOoH = 1; - } - else { - indIOioC = 1; indIiH = 2; indOoH = 2; - } - - const int bS = inputShapeInfo[1]; // batch size - const int iH = inputShapeInfo[indIiH+1]; // input height - const int iW = inputShapeInfo[indIiH+2]; // input width - const int iC = inputShapeInfo[indIOioC+1]; // input channels - const int oC = weightsShapeInfo[indWoC+1]; // output channels - - int trueoH, trueoW; // true output height, width - ConvolutionUtils::calcOutSizeDeconv2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode); - - std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1}); - std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, oC, iC); - REQUIRE_TRUE(shape::shapeEquals(4, expectedGradOShape.data(), shape::rank(gradOShapeInfo), shape::shapeOf(gradOShapeInfo)), 0, "CUSTOM DECONV2D_BP OP: wrong shape of output gradients next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradOShapeInfo).c_str()); - REQUIRE_TRUE(shape::shapeEquals(4, expectedWeightsShape.data(), shape::rank(weightsShapeInfo), shape::shapeOf(weightsShapeInfo)), 0, "CUSTOM DECONV2D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); - if(biasShapeInfo) - REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM DECONV2D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo)); - - auto gradIShapeInfo = ShapeBuilders::copyShapeInfoAndType(inputShapeInfo, gradOShapeInfo, false, block.getWorkspace()); - auto gradWShapeInfo = ShapeBuilders::copyShapeInfoAndType(weightsShapeInfo, gradOShapeInfo, false, block.getWorkspace()); - - auto shapes = SHAPELIST(CONSTANT(gradIShapeInfo), CONSTANT(gradWShapeInfo)); - - if (biasShapeInfo != nullptr) { - auto gradBShapeInfo = ShapeBuilders::copyShapeInfoAndType(biasShapeInfo, gradOShapeInfo, false, block.getWorkspace()); - shapes->push_back(CONSTANT(gradBShapeInfo)); - } - - return shapes; + auto inputShapeInfo = + inputShape->at(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCDHW) + auto weightsShapeInfo = inputShape->at( + 1); // [kH, kW, oC, iC], [iC, oC, kH, kW], [iC, kH, kW, oC] + Nd4jLong const* biasShapeInfo = + block.width() > 3 ? inputShape->at(2) : nullptr; // [oC] + auto gradOShapeInfo = + block.width() > 3 + ? inputShape->at(3) + : inputShape->at(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] + // (NCDHW), epsilon_next + + const int rank = 4; + REQUIRE_TRUE(shape::rank(inputShapeInfo) == rank, 0, + "CUSTOM DECONV2D_BP OP: rank of input array must be equal to " + "%i, but got %i instead !", + rank, shape::rank(inputShapeInfo)); + REQUIRE_TRUE(shape::rank(weightsShapeInfo) == rank, 0, + "CUSTOM DECONV2D_BP OP: rank of weights array must be equal to " + "%i , but got %i instead !", + rank, shape::rank(weightsShapeInfo)); + REQUIRE_TRUE(shape::rank(gradOShapeInfo) == rank, 0, + "CUSTOM DECONV2D_BP OP: rank of output gradients (next epsilon) " + "array must be equal to %i, but got %i instead !", + rank, shape::rank(gradOShapeInfo)); + + int kH = INT_ARG(0) > 0 ? INT_ARG(0) + : static_cast(shape::sizeAt( + weightsShapeInfo, 0)); // filter(kernel) height + int kW = INT_ARG(1) > 0 ? INT_ARG(1) + : static_cast(shape::sizeAt( + weightsShapeInfo, 1)); // filter(kernel) width + int sH = INT_ARG(2); // strides height + int sW = INT_ARG(3); // strides width + int pH = INT_ARG(4); // paddings height + int pW = INT_ARG(5); // paddings width + int dH = INT_ARG(6); // dilations height + int dW = INT_ARG(7); // dilations width + int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME + int isNCHW = + block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW + int wFormat = + block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, oC, iC], 1 - [iC, + // oC, kH, kW], 2 - [iC, kH, kW, oC] + + int indIOioC, indIiH, indOoH, + indWoC(0 == wFormat ? 2 : (1 == wFormat ? 1 : 3)); + if (!isNCHW) { + indIOioC = 3; + indIiH = 1; + indOoH = 1; + } else { + indIOioC = 1; + indIiH = 2; + indOoH = 2; + } + + const int bS = inputShapeInfo[1]; // batch size + const int iH = inputShapeInfo[indIiH + 1]; // input height + const int iW = inputShapeInfo[indIiH + 2]; // input width + const int iC = inputShapeInfo[indIOioC + 1]; // input channels + const int oC = weightsShapeInfo[indWoC + 1]; // output channels + + int trueoH, trueoW; // true output height, width + ConvolutionUtils::calcOutSizeDeconv2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, + dH, dW, iH, iW, isSameMode); + + std::vector expectedGradOShape = + ShapeUtils::composeShapeUsingDimsAndIdx( + {bS, oC, trueoH, trueoW, 0, indIOioC, indOoH, indOoH + 1}); + std::vector expectedWeightsShape = + ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, oC, iC); + REQUIRE_TRUE(shape::shapeEquals(4, expectedGradOShape.data(), + shape::rank(gradOShapeInfo), + shape::shapeOf(gradOShapeInfo)), + 0, + "CUSTOM DECONV2D_BP OP: wrong shape of output gradients next " + "epsilon) array, expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedGradOShape).c_str(), + ShapeUtils::shapeAsString(gradOShapeInfo).c_str()); + REQUIRE_TRUE(shape::shapeEquals(4, expectedWeightsShape.data(), + shape::rank(weightsShapeInfo), + shape::shapeOf(weightsShapeInfo)), + 0, + "CUSTOM DECONV2D_BP OP: wrong shape of weights array, expected " + "is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), + ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); + if (biasShapeInfo) + REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, + "CUSTOM DECONV2D_BP OP: wrong shape of array with biases, " + "expected rank, length: <=2, %i, but got %i, %i instead !", + oC, biasShapeInfo[0], shape::length(biasShapeInfo)); + + auto gradIShapeInfo = ShapeBuilders::copyShapeInfoAndType( + inputShapeInfo, gradOShapeInfo, false, block.workspace()); + auto gradWShapeInfo = ShapeBuilders::copyShapeInfoAndType( + weightsShapeInfo, gradOShapeInfo, false, block.workspace()); + + auto shapes = SHAPELIST(CONSTANT(gradIShapeInfo), CONSTANT(gradWShapeInfo)); + + if (biasShapeInfo != nullptr) { + auto gradBShapeInfo = ShapeBuilders::copyShapeInfoAndType( + biasShapeInfo, gradOShapeInfo, false, block.workspace()); + shapes->push_back(CONSTANT(gradBShapeInfo)); + } + + return shapes; } - - -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/nn/convo/deconv2d_tf.cpp b/libnd4j/include/ops/declarable/generic/nn/convo/deconv2d_tf.cpp index 9af389bf63b6..418fc61b46bf 100644 --- a/libnd4j/include/ops/declarable/generic/nn/convo/deconv2d_tf.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/convo/deconv2d_tf.cpp @@ -26,128 +26,207 @@ #include namespace sd { -namespace ops { +namespace ops { ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(deconv2d_tf, 3, 1, false, 0, 9) { - - auto gradO = INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next - auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] - auto gradIShape = INPUT_VARIABLE(0); // [4] - shape of input of conv2d (that is shape of gradI) - - auto gradI = OUTPUT_NULLIFIED(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon - - int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(weights->sizeAt(0));// filter(kernel) height - int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast(weights->sizeAt(1));// filter(kernel) width - int sH = INT_ARG(2); // strides height - int sW = INT_ARG(3); // strides width - int pH = INT_ARG(4); // paddings height - int pW = INT_ARG(5); // paddings width - int dH = INT_ARG(6); // dilations height - int dW = INT_ARG(7); // dilations width - int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME - int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW - int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - [oC, kH, kW, iC] - - const int rank = gradO->rankOf(); - - REQUIRE_TRUE(weights->rankOf() == rank, 0, "CUSTOM DECONV2D_TF OP: rank of weights array must be equal to 4, but got %i instead !", weights->rankOf()); - REQUIRE_TRUE(gradIShape->rankOf() == 1, 0, "CUSTOM DECONV2D_TF OP: rank of array with output shape must be equal to 1, but got %i instead !", gradIShape->rankOf()); - REQUIRE_TRUE(gradIShape->lengthOf() == rank, 0, "CUSTOM DECONV2D_TF OP: length of array with output shape must be equal to 4, but got %i instead !", gradIShape->lengthOf()); - - // create empty conv2d input array - NDArray input(gradO->ordering(), gradIShape->asVectorT(), gradO->dataType(), block.launchContext()); - - int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; - int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); - - int trueoH, trueoW; // true output height, width - ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode); - - std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1}); - std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, oC); - REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM DECONV2D_TF OP: wrong shape of input array, basing on array with output shape expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); - REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DECONV2D_TF OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); - - ConvolutionUtils::conv2dBP(block, &input, weights, nullptr, gradO, gradI, nullptr, nullptr, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW,wFormat); - - return Status::OK(); + auto gradO = INPUT_VARIABLE( + 2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next + auto weights = INPUT_VARIABLE( + 1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] + auto gradIShape = INPUT_VARIABLE( + 0); // [4] - shape of input of conv2d (that is shape of gradI) + + auto gradI = OUTPUT_NULLIFIED( + 0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon + + int kH = INT_ARG(0) > 0 + ? INT_ARG(0) + : static_cast(weights->sizeAt(0)); // filter(kernel) height + int kW = INT_ARG(1) > 0 + ? INT_ARG(1) + : static_cast(weights->sizeAt(1)); // filter(kernel) width + int sH = INT_ARG(2); // strides height + int sW = INT_ARG(3); // strides width + int pH = INT_ARG(4); // paddings height + int pW = INT_ARG(5); // paddings width + int dH = INT_ARG(6); // dilations height + int dW = INT_ARG(7); // dilations width + int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME + int isNCHW = + block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW + int wFormat = + block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, + // iC, kH, kW], 2 - [oC, kH, kW, iC] + + const int rank = gradO->rankOf(); + + REQUIRE_TRUE(weights->rankOf() == rank, 0, + "CUSTOM DECONV2D_TF OP: rank of weights array must be equal to " + "4, but got %i instead !", + weights->rankOf()); + REQUIRE_TRUE(gradIShape->rankOf() == 1, 0, + "CUSTOM DECONV2D_TF OP: rank of array with output shape must be " + "equal to 1, but got %i instead !", + gradIShape->rankOf()); + REQUIRE_TRUE(gradIShape->lengthOf() == rank, 0, + "CUSTOM DECONV2D_TF OP: length of array with output shape must " + "be equal to 4, but got %i instead !", + gradIShape->lengthOf()); + + // create empty conv2d input array + NDArray input(gradO->ordering(), gradIShape->asVectorT(), + gradO->dataType(), block.launchContext()); + + int bS, iC, iH, iW, oC, oH, + oW; // batch size, input channels, input height/width, output channels, + // output height/width; + int indIOioC, indIiH, indWoC, indWiC, indWkH, + indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, wFormat, input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, + indIiH, indWiC, indWoC, indWkH, indOoH); + + int trueoH, trueoW; // true output height, width + ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, + dH, dW, iH, iW, isSameMode); + + std::vector expectedGradOShape = + ShapeUtils::composeShapeUsingDimsAndIdx( + {bS, oC, trueoH, trueoW, 0, indIOioC, indOoH, indOoH + 1}); + std::vector expectedWeightsShape = + ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, oC); + REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, + "CUSTOM DECONV2D_TF OP: wrong shape of input array, basing on " + "array with output shape expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedGradOShape).c_str(), + ShapeUtils::shapeAsString(gradO).c_str()); + REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, + "CUSTOM DECONV2D_TF OP: wrong shape of weights array, expected " + "is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), + ShapeUtils::shapeAsString(weights).c_str()); + + ConvolutionUtils::conv2dBP(block, &input, weights, nullptr, gradO, gradI, + nullptr, nullptr, kH, kW, sH, sW, pH, pW, dH, dW, + isSameMode, isNCHW, wFormat); + + return Status::OK(); } - DECLARE_TYPES(deconv2d_tf) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } +DECLARE_TYPES(deconv2d_tf) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} DECLARE_SHAPE_FN(deconv2d_tf) { - - auto gradOShapeInfo = inputShape->at(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next - auto weightsShapeInfo = inputShape->at(1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] - auto gradIShapeShapeInfo = inputShape->at(0); // [4] - - const int rank = 4; - - REQUIRE_TRUE(shape::rank(weightsShapeInfo) == rank, 0, "CUSTOM DECONV2D_TF OP: rank of weights array must be equal to %i, but got %i instead !", rank, shape::rank(weightsShapeInfo)); - REQUIRE_TRUE(shape::rank(gradOShapeInfo) == rank, 0, "CUSTOM DECONV2D_TF OP: rank of input array must be equal to %i, but got %i instead !", rank, shape::rank(gradOShapeInfo)); - REQUIRE_TRUE(shape::rank(gradIShapeShapeInfo) == 1, 0, "CUSTOM DECONV2D_TF OP: rank of array with output shape must be equal to %i, but got %i instead !", 1, shape::rank(gradIShapeShapeInfo)); - - const int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(shape::sizeAt(weightsShapeInfo, 0));// filter(kernel) height - const int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast(shape::sizeAt(weightsShapeInfo, 1));// filter(kernel) width - const int sH = INT_ARG(2); // strides height - const int sW = INT_ARG(3); // strides width - const int pH = INT_ARG(4); // paddings height - const int pW = INT_ARG(5); // paddings width - const int dH = INT_ARG(6); // dilations height - const int dW = INT_ARG(7); // dilations width - const int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME - const int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW - const int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - [oC, kH, kW, iC] - - int indIOioC, indIiH, indWoC(0 == wFormat ? 3 : 0), indOoH; - if(!isNCHW) { - indIOioC = 3; indIiH = 1; indOoH = 1; - } - else { - indIOioC = 1; indIiH = 2; indOoH = 2; - } - - std::vector gradIShape = INPUT_VARIABLE(0)->template asVectorT(); - - const int bS = gradIShape[0]; // batch size - const int iH = gradIShape[indIiH]; // input height - const int iW = gradIShape[indIiH+1]; // input width - const int iC = gradIShape[indIOioC]; // input channels - const int oC = weightsShapeInfo[indWoC+1]; // output channels - const int oH = gradOShapeInfo[indOoH+1]; // input height - const int oW = gradOShapeInfo[indOoH+2]; // input width - - int trueiH, trueiW; // output height, width - ConvolutionUtils::calcOutSizeDeconv2D(trueiH, trueiW, kH, kW, sH, sW, pH, pW, dH, dW, oH, oW, isSameMode); - - std::vector expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,trueiH,trueiW, 0,indIOioC,indIiH,indIiH+1}); - std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, oC); - REQUIRE_TRUE(expectedGradIShape == gradIShape, 0, "CUSTOM DECONV2D_TF OP: wrong shape of array with output shape, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradIShape).c_str(), ShapeUtils::shapeAsString(gradIShape).c_str()); - REQUIRE_TRUE(shape::shapeEquals(4, expectedWeightsShape.data(), shape::rank(weightsShapeInfo), shape::shapeOf(weightsShapeInfo)), 0, "CUSTOM DECONV2D_TF OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); - - Nd4jLong shape[4]; - shape[0] = bS; - - if (isNCHW) { - shape[1] = iC; - shape[2] = iH; - shape[3] = iW; - } else { - shape[1] = iH; - shape[2] = iW; - shape[3] = iC; - } - - return SHAPELIST(ConstantShapeHelper::getInstance().createShapeInfo(ArrayOptions::dataType(weightsShapeInfo), shape::order(gradOShapeInfo), 4, shape)); + auto gradOShapeInfo = inputShape->at( + 2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next + auto weightsShapeInfo = inputShape->at( + 1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] + auto gradIShapeShapeInfo = inputShape->at(0); // [4] + + const int rank = 4; + + REQUIRE_TRUE(shape::rank(weightsShapeInfo) == rank, 0, + "CUSTOM DECONV2D_TF OP: rank of weights array must be equal to " + "%i, but got %i instead !", + rank, shape::rank(weightsShapeInfo)); + REQUIRE_TRUE(shape::rank(gradOShapeInfo) == rank, 0, + "CUSTOM DECONV2D_TF OP: rank of input array must be equal to " + "%i, but got %i instead !", + rank, shape::rank(gradOShapeInfo)); + REQUIRE_TRUE(shape::rank(gradIShapeShapeInfo) == 1, 0, + "CUSTOM DECONV2D_TF OP: rank of array with output shape must be " + "equal to %i, but got %i instead !", + 1, shape::rank(gradIShapeShapeInfo)); + + const int kH = INT_ARG(0) > 0 + ? INT_ARG(0) + : static_cast(shape::sizeAt( + weightsShapeInfo, 0)); // filter(kernel) height + const int kW = INT_ARG(1) > 0 + ? INT_ARG(1) + : static_cast(shape::sizeAt( + weightsShapeInfo, 1)); // filter(kernel) width + const int sH = INT_ARG(2); // strides height + const int sW = INT_ARG(3); // strides width + const int pH = INT_ARG(4); // paddings height + const int pW = INT_ARG(5); // paddings width + const int dH = INT_ARG(6); // dilations height + const int dW = INT_ARG(7); // dilations width + const int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME + const int isNCHW = + block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW + const int wFormat = + block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, + // iC, kH, kW], 2 - [oC, kH, kW, iC] + + int indIOioC, indIiH, indWoC(0 == wFormat ? 3 : 0), indOoH; + if (!isNCHW) { + indIOioC = 3; + indIiH = 1; + indOoH = 1; + } else { + indIOioC = 1; + indIiH = 2; + indOoH = 2; + } + + std::vector gradIShape = + INPUT_VARIABLE(0)->template asVectorT(); + + const int bS = gradIShape[0]; // batch size + const int iH = gradIShape[indIiH]; // input height + const int iW = gradIShape[indIiH + 1]; // input width + const int iC = gradIShape[indIOioC]; // input channels + const int oC = weightsShapeInfo[indWoC + 1]; // output channels + const int oH = gradOShapeInfo[indOoH + 1]; // input height + const int oW = gradOShapeInfo[indOoH + 2]; // input width + + int trueiH, trueiW; // output height, width + ConvolutionUtils::calcOutSizeDeconv2D(trueiH, trueiW, kH, kW, sH, sW, pH, pW, + dH, dW, oH, oW, isSameMode); + + std::vector expectedGradIShape = + ShapeUtils::composeShapeUsingDimsAndIdx( + {bS, iC, trueiH, trueiW, 0, indIOioC, indIiH, indIiH + 1}); + std::vector expectedWeightsShape = + ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, oC); + REQUIRE_TRUE(expectedGradIShape == gradIShape, 0, + "CUSTOM DECONV2D_TF OP: wrong shape of array with output shape, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedGradIShape).c_str(), + ShapeUtils::shapeAsString(gradIShape).c_str()); + REQUIRE_TRUE(shape::shapeEquals(4, expectedWeightsShape.data(), + shape::rank(weightsShapeInfo), + shape::shapeOf(weightsShapeInfo)), + 0, + "CUSTOM DECONV2D_TF OP: wrong shape of weights array, expected " + "is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), + ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); + + Nd4jLong shape[4]; + shape[0] = bS; + + if (isNCHW) { + shape[1] = iC; + shape[2] = iH; + shape[3] = iW; + } else { + shape[1] = iH; + shape[2] = iW; + shape[3] = iC; + } + + return SHAPELIST(ConstantShapeHelper::getInstance().createShapeInfo( + ArrayOptions::dataType(weightsShapeInfo), shape::order(gradOShapeInfo), 4, + shape)); } -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/nn/convo/deconv3d.cpp b/libnd4j/include/ops/declarable/generic/nn/convo/deconv3d.cpp index 7c68ee74caf5..648f088d6eb5 100644 --- a/libnd4j/include/ops/declarable/generic/nn/convo/deconv3d.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/convo/deconv3d.cpp @@ -21,334 +21,514 @@ #include #if NOT_EXCLUDED(OP_deconv3d) +#include #include -#include #include -#include +#include namespace sd { -namespace ops { +namespace ops { CUSTOM_OP_IMPL(deconv3d, 2, 1, false, 0, 13) { - - auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) - auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, oC, iC], [iC, oC, kD, kH, kW], [iC, kD, kH, kW, oC] - auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] - - auto output = OUTPUT_VARIABLE(0); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW) - - REQUIRE_TRUE(input->rankOf() == 5, 0, "CUSTOM DECONV3D OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf()); - REQUIRE_TRUE(weights->rankOf() == 5, 0, "CUSTOM DECONV3D OP: rank of weights array must be equal to 5, but got %i instead !", weights->rankOf()); - - int kD = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(weights->sizeAt(0)); // filter(kernel) depth - int kH = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast(weights->sizeAt(1)); // filter(kernel) height - int kW = INT_ARG(2) > 0 ? INT_ARG(2) : static_cast(weights->sizeAt(2)); // filter(kernel) width - int sD = INT_ARG(3); // strides depth - int sH = INT_ARG(4); // strides height - int sW = INT_ARG(5); // strides width - int pD = INT_ARG(6); // paddings depth - int pH = INT_ARG(7); // paddings height - int pW = INT_ARG(8); // paddings width - int dD = INT_ARG(9); // dilations depth - int dH = INT_ARG(10); // dilations height - int dW = INT_ARG(11); // dilations width - int isSameMode = INT_ARG(12); // 0-SAME, 1-VALID - int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW - int wFormat = block.getIArguments()->size() > 14 ? INT_ARG(14) : 0; // 0 - [kD, kH, kW, oC, iC], 1 - [iC, oC, kD, kH, kW], 2 - [iC, kD, kH, kW, oC] - - int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; - int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, wFormat, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWoC, indWiC, indWkD); - - std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, oC, iC); - REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DECONV3D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); - if (bias) - REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DECONV3D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); - - if(!isNCDHW) - output = new NDArray(output->permute({0, 4, 1, 2, 3})); // [bS, oD, oH, oW, oC] -> [bS, oC, oD, oH, oW] - - std::vector colPermut; - if(1 == wFormat) - colPermut = {1,2,3,4,0,5,6,7}; - else - colPermut = {2,3,4,1,0,5,6,7}; - - if(isSameMode) // Note: we're intentionally swapping iH and oH, to calculated the padding for a"normal" conv (not deconv) forward pass - ConvolutionUtils::calcPadding3D(pD, pH, pW, iD, iH, iW, oD, oH, oW, kD, kH, kW, sD, sH, sW, dD, dH, dW); - - NDArray columns(input->ordering(), {bS, oC, kD, kH, kW, iD, iH, iW}, input->dataType(), block.launchContext()); - - //----- calculation of output -----// - // [kD, kH, kW, oC, iC] x [bS, iD, iH, iW, iC] = [kD, kH, kW, oC, bS, iD, iH, iW] - // [iC, oC, kD, kH, kW] x [bS, iD, iH, iW, iC] = [oC, kD, kH, kW, bS, iD, iH, iW] - // [iC, kD, kH, kW, oC] x [bS, iD, iH, iW, iC] = [kD, kH, kW, oC, bS, iD, iH, iW] - sd::MmulHelper::tensorDot(weights, input, &columns, {indWiC}, {indIOioC}, colPermut); // [bS, oC, kD, kH, kW, iD, iH, iW] -> [kD, kH, kW, oC, bS, iD, iH, iW] - ConvolutionUtils::col2vol(block, columns, *output, sD, sH, sW, pD, pH, pW, dD, dH, dW); // [bS, oC, kD, kH, kW, iD, iH, iW] is de-convoluted to [bS, oC, oD, oH, oW] - - //----- add biases if required -----// - if(bias) - // output->applyBroadcast(broadcast::Add,{1}, bias); - helpers::addBias(block, *output, *bias, *output, true); - - if(!isNCDHW) - delete output; - - return Status::OK(); - + auto input = INPUT_VARIABLE( + 0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) + auto weights = INPUT_VARIABLE( + 1); // [kD, kH, kW, oC, iC], [iC, oC, kD, kH, kW], [iC, kD, kH, kW, oC] + auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] + + auto output = OUTPUT_VARIABLE( + 0); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW) + + REQUIRE_TRUE(input->rankOf() == 5, 0, + "CUSTOM DECONV3D OP: rank of input array must be equal to 5, " + "but got %i instead !", + input->rankOf()); + REQUIRE_TRUE(weights->rankOf() == 5, 0, + "CUSTOM DECONV3D OP: rank of weights array must be equal to 5, " + "but got %i instead !", + weights->rankOf()); + + int kD = INT_ARG(0) > 0 + ? INT_ARG(0) + : static_cast(weights->sizeAt(0)); // filter(kernel) depth + int kH = INT_ARG(1) > 0 + ? INT_ARG(1) + : static_cast(weights->sizeAt(1)); // filter(kernel) height + int kW = INT_ARG(2) > 0 + ? INT_ARG(2) + : static_cast(weights->sizeAt(2)); // filter(kernel) width + int sD = INT_ARG(3); // strides depth + int sH = INT_ARG(4); // strides height + int sW = INT_ARG(5); // strides width + int pD = INT_ARG(6); // paddings depth + int pH = INT_ARG(7); // paddings height + int pW = INT_ARG(8); // paddings width + int dD = INT_ARG(9); // dilations depth + int dH = INT_ARG(10); // dilations height + int dW = INT_ARG(11); // dilations width + int isSameMode = INT_ARG(12); // 0-SAME, 1-VALID + int isNCDHW = + block.numI() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW + int wFormat = block.numI() > 14 + ? INT_ARG(14) + : 0; // 0 - [kD, kH, kW, oC, iC], 1 - [iC, oC, kD, kH, kW], + // 2 - [iC, kD, kH, kW, oC] + + int bS, iC, iD, iH, iW, oC, oD, oH, + oW; // batch size, input channels, input depth/height/width, output + // channels, output depth/height/width; + int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv3d( + isNCDHW, wFormat, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, + indIOioC, indIOioD, indWoC, indWiC, indWkD); + + std::vector expectedWeightsShape = + ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, oC, iC); + REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, + "CUSTOM DECONV3D OP: wrong shape of weights array, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), + ShapeUtils::shapeAsString(weights).c_str()); + if (bias) + REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, + "CUSTOM DECONV3D OP: wrong shape of array with biases, " + "expected rank, length: <=2, %i, but got %i, %i instead !", + oC, bias->rankOf(), bias->lengthOf()); + + if (!isNCDHW) + output = new NDArray(output->permute( + {0, 4, 1, 2, 3})); // [bS, oD, oH, oW, oC] -> [bS, oC, oD, oH, oW] + + std::vector colPermut; + if (1 == wFormat) + colPermut = {1, 2, 3, 4, 0, 5, 6, 7}; + else + colPermut = {2, 3, 4, 1, 0, 5, 6, 7}; + + if (isSameMode) // Note: we're intentionally swapping iH and oH, to + // calculated the padding for a"normal" conv (not deconv) + // forward pass + ConvolutionUtils::calcPadding3D(pD, pH, pW, iD, iH, iW, oD, oH, oW, kD, kH, + kW, sD, sH, sW, dD, dH, dW); + + NDArray columns(input->ordering(), {bS, oC, kD, kH, kW, iD, iH, iW}, + input->dataType(), block.launchContext()); + + //----- calculation of output -----// + // [kD, kH, kW, oC, iC] x [bS, iD, iH, iW, iC] = [kD, kH, kW, oC, bS, iD, iH, + // iW] [iC, oC, kD, kH, kW] x [bS, iD, iH, iW, iC] = [oC, kD, kH, kW, bS, iD, + // iH, iW] [iC, kD, kH, kW, oC] x [bS, iD, iH, iW, iC] = [kD, kH, kW, oC, bS, + // iD, iH, iW] + sd::MmulHelper::tensorDot(weights, input, &columns, {indWiC}, {indIOioC}, + colPermut); // [bS, oC, kD, kH, kW, iD, iH, iW] -> + // [kD, kH, kW, oC, bS, iD, iH, iW] + ConvolutionUtils::col2vol(block, columns, *output, sD, sH, sW, pD, pH, pW, dD, + dH, dW); // [bS, oC, kD, kH, kW, iD, iH, iW] is + // de-convoluted to [bS, oC, oD, oH, oW] + + //----- add biases if required -----// + if (bias) + // output->applyBroadcast(broadcast::Add,{1}, bias); + helpers::addBias(block, *output, *bias, *output, true); + + if (!isNCDHW) delete output; + + return Status::OK(); } - DECLARE_TYPES(deconv3d) { - getOpDescriptor() - ->setAllowedInputTypes(0, sd::DataType::ANY) - ->setAllowedInputTypes(1, {ALL_FLOATS}) - ->setAllowedInputTypes(2, {ALL_FLOATS}) - ->setAllowedOutputTypes({ALL_FLOATS}); - } +DECLARE_TYPES(deconv3d) { + getOpDescriptor() + ->setAllowedInputTypes(0, sd::DataType::ANY) + ->setAllowedInputTypes(1, {ALL_FLOATS}) + ->setAllowedInputTypes(2, {ALL_FLOATS}) + ->setAllowedOutputTypes({ALL_FLOATS}); +} DECLARE_SHAPE_FN(deconv3d) { - - auto inputShapeInfo = inputShape->at(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NDCHW) - auto weightsShapeInfo = inputShape->at(1); // [kD, kH, kW, oC, iC], [iC, oC, kD, kH, kW], [iC, kD, kH, kW, oC] - auto biasShapeInfo = block.width() > 2 ? inputShape->at(2) : nullptr; // [oC] - - const int rank = 5; - REQUIRE_TRUE(shape::rank(inputShapeInfo) == rank, 0, "CUSTOM DECONV3D OP: rank of input array must be equal to %i, but got %i instead !", rank, shape::rank(inputShapeInfo)); - REQUIRE_TRUE(shape::rank(weightsShapeInfo) == rank, 0, "CUSTOM DECONV3D OP: rank of weights array must be equal to %i, but got %i instead !", rank, shape::rank(weightsShapeInfo)); - - int kD = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(shape::sizeAt(weightsShapeInfo, 0));// filter(kernel) depth - int kH = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast(shape::sizeAt(weightsShapeInfo, 1));// filter(kernel) height - int kW = INT_ARG(2) > 0 ? INT_ARG(2) : static_cast(shape::sizeAt(weightsShapeInfo, 2));// filter(kernel) width - int sD = INT_ARG(3); // strides depth - int sH = INT_ARG(4); // strides height - int sW = INT_ARG(5); // strides width - int pD = INT_ARG(6); // paddings depth - int pH = INT_ARG(7); // paddings height - int pW = INT_ARG(8); // paddings width - int dD = INT_ARG(9); // dilations depth - int dH = INT_ARG(10); // dilations height - int dW = INT_ARG(11); // dilations width - int isSameMode = INT_ARG(12); // 0-SAME, 1-VALID - int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW - int wFormat = block.getIArguments()->size() > 14 ? INT_ARG(14) : 0; // 0 - [kD, kH, kW, oC, iC], 1 - [iC, oC, kD, kH, kW], 2 - [iC, kD, kH, kW, oC] - - int indIOioC, indIiD, indWoC(0 == wFormat ? 3 : (1 == wFormat ? 1 : 4)); - if(!isNCDHW) { - indIOioC = 4; indIiD = 1; - } - else { - indIOioC = 1; indIiD = 2; - } - - const int bS = inputShapeInfo[1]; // batch size - const int iD = inputShapeInfo[indIiD+1]; // input depth - const int iH = inputShapeInfo[indIiD+2]; // input height - const int iW = inputShapeInfo[indIiD+3]; // input width - const int iC = inputShapeInfo[indIOioC+1]; // input channels - const int oC = weightsShapeInfo[indWoC+1]; // output channels - - std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, oC, iC); - REQUIRE_TRUE(shape::shapeEquals(5, expectedWeightsShape.data(), shape::rank(weightsShapeInfo), shape::shapeOf(weightsShapeInfo)), 0, "CUSTOM DECONV3D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); - if (biasShapeInfo) - REQUIRE_TRUE(shape::rank(biasShapeInfo) <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM DECONV3D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, shape::rank(biasShapeInfo), shape::length(biasShapeInfo)); - - int oD, oH, oW; // output depth, height, width - ConvolutionUtils::calcOutSizeDeconv3D(oD, oH, oW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, isSameMode); - - Nd4jLong* outputShapeInfo = nullptr; - ALLOCATE(outputShapeInfo, block.getWorkspace(), shape::shapeInfoLength(inputShapeInfo), Nd4jLong); - - outputShapeInfo[0] = rank; - outputShapeInfo[1] = bS; - - if (isNCDHW) { - outputShapeInfo[2] = oC; - outputShapeInfo[3] = oD; - outputShapeInfo[4] = oH; - outputShapeInfo[5] = oW; - } else { - outputShapeInfo[2] = oD; - outputShapeInfo[3] = oH; - outputShapeInfo[4] = oW; - outputShapeInfo[5] = oC; - } - - ShapeUtils::updateStridesAndType(outputShapeInfo, weightsShapeInfo, shape::order(inputShapeInfo)); - - return SHAPELIST(CONSTANT(outputShapeInfo)); + auto inputShapeInfo = inputShape->at( + 0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NDCHW) + auto weightsShapeInfo = inputShape->at( + 1); // [kD, kH, kW, oC, iC], [iC, oC, kD, kH, kW], [iC, kD, kH, kW, oC] + auto biasShapeInfo = block.width() > 2 ? inputShape->at(2) : nullptr; // [oC] + + const int rank = 5; + REQUIRE_TRUE(shape::rank(inputShapeInfo) == rank, 0, + "CUSTOM DECONV3D OP: rank of input array must be equal to %i, " + "but got %i instead !", + rank, shape::rank(inputShapeInfo)); + REQUIRE_TRUE(shape::rank(weightsShapeInfo) == rank, 0, + "CUSTOM DECONV3D OP: rank of weights array must be equal to %i, " + "but got %i instead !", + rank, shape::rank(weightsShapeInfo)); + + int kD = INT_ARG(0) > 0 ? INT_ARG(0) + : static_cast(shape::sizeAt( + weightsShapeInfo, 0)); // filter(kernel) depth + int kH = INT_ARG(1) > 0 ? INT_ARG(1) + : static_cast(shape::sizeAt( + weightsShapeInfo, 1)); // filter(kernel) height + int kW = INT_ARG(2) > 0 ? INT_ARG(2) + : static_cast(shape::sizeAt( + weightsShapeInfo, 2)); // filter(kernel) width + int sD = INT_ARG(3); // strides depth + int sH = INT_ARG(4); // strides height + int sW = INT_ARG(5); // strides width + int pD = INT_ARG(6); // paddings depth + int pH = INT_ARG(7); // paddings height + int pW = INT_ARG(8); // paddings width + int dD = INT_ARG(9); // dilations depth + int dH = INT_ARG(10); // dilations height + int dW = INT_ARG(11); // dilations width + int isSameMode = INT_ARG(12); // 0-SAME, 1-VALID + int isNCDHW = + block.numI() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW + int wFormat = block.numI() > 14 + ? INT_ARG(14) + : 0; // 0 - [kD, kH, kW, oC, iC], 1 - [iC, oC, kD, kH, kW], + // 2 - [iC, kD, kH, kW, oC] + + int indIOioC, indIiD, indWoC(0 == wFormat ? 3 : (1 == wFormat ? 1 : 4)); + if (!isNCDHW) { + indIOioC = 4; + indIiD = 1; + } else { + indIOioC = 1; + indIiD = 2; + } + + const int bS = inputShapeInfo[1]; // batch size + const int iD = inputShapeInfo[indIiD + 1]; // input depth + const int iH = inputShapeInfo[indIiD + 2]; // input height + const int iW = inputShapeInfo[indIiD + 3]; // input width + const int iC = inputShapeInfo[indIOioC + 1]; // input channels + const int oC = weightsShapeInfo[indWoC + 1]; // output channels + + std::vector expectedWeightsShape = + ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, oC, iC); + REQUIRE_TRUE(shape::shapeEquals(5, expectedWeightsShape.data(), + shape::rank(weightsShapeInfo), + shape::shapeOf(weightsShapeInfo)), + 0, + "CUSTOM DECONV3D OP: wrong shape of weights array, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), + ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); + if (biasShapeInfo) + REQUIRE_TRUE( + shape::rank(biasShapeInfo) <= 2 && oC == shape::length(biasShapeInfo), + 0, + "CUSTOM DECONV3D OP: wrong shape of array with biases, expected rank, " + "length: <=2, %i, but got %i, %i instead !", + oC, shape::rank(biasShapeInfo), shape::length(biasShapeInfo)); + + int oD, oH, oW; // output depth, height, width + ConvolutionUtils::calcOutSizeDeconv3D(oD, oH, oW, kD, kH, kW, sD, sH, sW, pD, + pH, pW, dD, dH, dW, iD, iH, iW, + isSameMode); + + Nd4jLong* outputShapeInfo = nullptr; + ALLOCATE(outputShapeInfo, block.workspace(), + shape::shapeInfoLength(inputShapeInfo), Nd4jLong); + + outputShapeInfo[0] = rank; + outputShapeInfo[1] = bS; + + if (isNCDHW) { + outputShapeInfo[2] = oC; + outputShapeInfo[3] = oD; + outputShapeInfo[4] = oH; + outputShapeInfo[5] = oW; + } else { + outputShapeInfo[2] = oD; + outputShapeInfo[3] = oH; + outputShapeInfo[4] = oW; + outputShapeInfo[5] = oC; + } + + ShapeUtils::updateStridesAndType(outputShapeInfo, weightsShapeInfo, + shape::order(inputShapeInfo)); + + return SHAPELIST(CONSTANT(outputShapeInfo)); } - ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(deconv3d_bp, 3, 2, false, 0, 13) { - - auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) - auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, oC, iC], [iC, oC, kD, kH, kW], [iC, kD, kH, kW, oC] - auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] - auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next - - auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), gradI - auto gradW = OUTPUT_VARIABLE(1); // [kD, kH, kW, oC, iC], [iC, oC, kD, kH, kW], [iC, kD, kH, kW, oC] - auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC] - - REQUIRE_TRUE(input->rankOf() == 5, 0, "CUSTOM DECONV3D_BP OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf()); - REQUIRE_TRUE(weights->rankOf() == 5, 0, "CUSTOM DECONV3D_BP OP: rank of weights array must be equal to 5 , but got %i instead !", weights->rankOf()); - REQUIRE_TRUE(gradO->rankOf() == 5, 0, "CUSTOM DECONV3D_BP OP: rank of output gradients (next epsilon) array must be equal to 5, but got %i instead !", gradO->rankOf()); - - - int kD = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(weights->sizeAt(0));// filter(kernel) depth - int kH = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast(weights->sizeAt(1));// filter(kernel) height - int kW = INT_ARG(2) > 0 ? INT_ARG(2) : static_cast(weights->sizeAt(2));// filter(kernel) width - int sD = INT_ARG(3); // strides depth - int sH = INT_ARG(4); // strides height - int sW = INT_ARG(5); // strides width - int pD = INT_ARG(6); // paddings depth - int pH = INT_ARG(7); // paddings height - int pW = INT_ARG(8); // paddings width - int dD = INT_ARG(9); // dilations depth - int dH = INT_ARG(10); // dilations height - int dW = INT_ARG(11); // dilations width - int isSameMode = INT_ARG(12); // 0-SAME, 1-VALID - int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW - int wFormat = block.getIArguments()->size() > 14 ? INT_ARG(14) : 0; // 0 - [kD, kH, kW, oC, iC], 1 - [iC, oC, kD, kH, kW], 2 - [iC, kD, kH, kW, oC] - - int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; - int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, wFormat, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWoC, indWiC, indWkD); - - int trueoD, trueoH, trueoW; // true output height, width - ConvolutionUtils::calcOutSizeDeconv3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, isSameMode); - - std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoD,trueoH,trueoW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); - std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, oC, iC); - REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM DECONV3D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); - REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DECONV3D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); - if(bias) - REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DECONV3D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); - - if(isSameMode) // Note: we're intentionally swapping iH and oH, to calculated the padding for a"normal" conv (not deconv) forward pass - ConvolutionUtils::calcPadding3D(pD, pH, pW, iD, iH, iW, oD, oH, oW, kD, kH, kW, sD, sH, sW, dD, dH, dW); - - // ----- calculation of gradI -> pass it through conv3d_ff ----- // - sd::ops::conv3dnew conv3d; - const Nd4jStatus status = conv3d.execute({gradO, weights}, {gradI}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, isSameMode, !isNCDHW, wFormat}, {}); - if (status != ND4J_STATUS_OK) - return status; - - // -----prepare permutation arrays and axes for dot product ----- // - std::vector inputAxesForDot; - - if(!isNCDHW) { - gradO = new NDArray(gradO->permute({0, 4, 1, 2, 3})); // [bS, oD, oH, oW, oC] -> [bS, oC, oD, oH, oW] - inputAxesForDot = {0, 1, 2, 3}; // bS, iD, iH, iW - } - else - inputAxesForDot = {0, 2, 3, 4}; // bS, iD, iH, iW - - std::vector gradWAxes; // empty for wFormat = 1 - if(0 == wFormat) - gradWAxes = {4,3,0,1,2}; - else if(2 == wFormat) - gradWAxes = {0,4,1,2,3}; - - // ----- calculation of gradW ----- // - auto columns = NDArrayFactory::create(input->ordering(), {bS, oC, kD, kH, kW, iD, iH, iW}, input->dataType(), block.launchContext()); - ConvolutionUtils::vol2col(block, *gradO, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW); // [bS, oC, oD, oH, oW] is deconvoluted to [bS, oC, kD, kH, kW, iD, iH, iW] - MmulHelper::tensorDot(input, &columns, gradW, inputAxesForDot, {0, 5, 6, 7}, gradWAxes); // [bS, iC, iD, iH, iW]/[bS, iD, iH, iW, iC] x [bS, oC, kD, kH, kW, iD, iH, iW] = [iC, oC, kD, kH, kW] - - // ----- calculation of gradB ----- // - if(gradB) { - if(gradB->rankOf() == 2) - gradB = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()}, false)); - gradO->reduceAlongDimension(reduce::Sum, *gradB, {0, 2, 3, 4}); // sum over bS, oD, oH, oW - if(gradB != OUTPUT_VARIABLE(2)) - delete gradB; - } - - if(!isNCDHW) - delete gradO; - - return Status::OK(); + auto input = INPUT_VARIABLE( + 0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) + auto weights = INPUT_VARIABLE( + 1); // [kD, kH, kW, oC, iC], [iC, oC, kD, kH, kW], [iC, kD, kH, kW, oC] + auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] + auto gradO = + block.width() > 3 + ? INPUT_VARIABLE(3) + : INPUT_VARIABLE(2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, + // oH, oW] (NCDHW), epsilon_next + + auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, + // iD, iH, iW] (NCDHW), gradI + auto gradW = OUTPUT_VARIABLE( + 1); // [kD, kH, kW, oC, iC], [iC, oC, kD, kH, kW], [iC, kD, kH, kW, oC] + auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC] + + REQUIRE_TRUE(input->rankOf() == 5, 0, + "CUSTOM DECONV3D_BP OP: rank of input array must be equal to 5, " + "but got %i instead !", + input->rankOf()); + REQUIRE_TRUE(weights->rankOf() == 5, 0, + "CUSTOM DECONV3D_BP OP: rank of weights array must be equal to " + "5 , but got %i instead !", + weights->rankOf()); + REQUIRE_TRUE(gradO->rankOf() == 5, 0, + "CUSTOM DECONV3D_BP OP: rank of output gradients (next epsilon) " + "array must be equal to 5, but got %i instead !", + gradO->rankOf()); + + int kD = INT_ARG(0) > 0 + ? INT_ARG(0) + : static_cast(weights->sizeAt(0)); // filter(kernel) depth + int kH = INT_ARG(1) > 0 + ? INT_ARG(1) + : static_cast(weights->sizeAt(1)); // filter(kernel) height + int kW = INT_ARG(2) > 0 + ? INT_ARG(2) + : static_cast(weights->sizeAt(2)); // filter(kernel) width + int sD = INT_ARG(3); // strides depth + int sH = INT_ARG(4); // strides height + int sW = INT_ARG(5); // strides width + int pD = INT_ARG(6); // paddings depth + int pH = INT_ARG(7); // paddings height + int pW = INT_ARG(8); // paddings width + int dD = INT_ARG(9); // dilations depth + int dH = INT_ARG(10); // dilations height + int dW = INT_ARG(11); // dilations width + int isSameMode = INT_ARG(12); // 0-SAME, 1-VALID + int isNCDHW = + block.numI() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW + int wFormat = block.numI() > 14 + ? INT_ARG(14) + : 0; // 0 - [kD, kH, kW, oC, iC], 1 - [iC, oC, kD, kH, kW], + // 2 - [iC, kD, kH, kW, oC] + + int bS, iC, iD, iH, iW, oC, oD, oH, + oW; // batch size, input channels, input depth/height/width, output + // channels, output depth/height/width; + int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv3d( + isNCDHW, wFormat, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, + indIOioC, indIOioD, indWoC, indWiC, indWkD); + + int trueoD, trueoH, trueoW; // true output height, width + ConvolutionUtils::calcOutSizeDeconv3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, + sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, + iW, isSameMode); + + std::vector expectedGradOShape = + ShapeUtils::composeShapeUsingDimsAndIdx({bS, oC, trueoD, trueoH, trueoW, + 0, indIOioC, indIOioD, + indIOioD + 1, indIOioD + 2}); + std::vector expectedWeightsShape = + ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, oC, iC); + REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, + "CUSTOM DECONV3D_BP OP: wrong shape of output gradients (next " + "epsilon) array, expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedGradOShape).c_str(), + ShapeUtils::shapeAsString(gradO).c_str()); + REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, + "CUSTOM DECONV3D_BP OP: wrong shape of weights array, expected " + "is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), + ShapeUtils::shapeAsString(weights).c_str()); + if (bias) + REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, + "CUSTOM DECONV3D_BP OP: wrong shape of array with biases, " + "expected rank, length: <=2, %i, but got %i, %i instead !", + oC, bias->rankOf(), bias->lengthOf()); + + if (isSameMode) // Note: we're intentionally swapping iH and oH, to + // calculated the padding for a"normal" conv (not deconv) + // forward pass + ConvolutionUtils::calcPadding3D(pD, pH, pW, iD, iH, iW, oD, oH, oW, kD, kH, + kW, sD, sH, sW, dD, dH, dW); + + // ----- calculation of gradI -> pass it through conv3d_ff ----- // + sd::ops::conv3dnew conv3d; + const Nd4jStatus status = + conv3d.execute({gradO, weights}, {gradI}, {}, + {kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, + isSameMode, !isNCDHW, wFormat}, + {}); + if (status != ND4J_STATUS_OK) return status; + + // -----prepare permutation arrays and axes for dot product ----- // + std::vector inputAxesForDot; + + if (!isNCDHW) { + gradO = new NDArray(gradO->permute( + {0, 4, 1, 2, 3})); // [bS, oD, oH, oW, oC] -> [bS, oC, oD, oH, oW] + inputAxesForDot = {0, 1, 2, 3}; // bS, iD, iH, iW + } else + inputAxesForDot = {0, 2, 3, 4}; // bS, iD, iH, iW + + std::vector gradWAxes; // empty for wFormat = 1 + if (0 == wFormat) + gradWAxes = {4, 3, 0, 1, 2}; + else if (2 == wFormat) + gradWAxes = {0, 4, 1, 2, 3}; + + // ----- calculation of gradW ----- // + auto columns = NDArrayFactory::create( + input->ordering(), {bS, oC, kD, kH, kW, iD, iH, iW}, input->dataType(), + block.launchContext()); + ConvolutionUtils::vol2col(block, *gradO, columns, sD, sH, sW, pD, pH, pW, dD, + dH, dW); // [bS, oC, oD, oH, oW] is deconvoluted to + // [bS, oC, kD, kH, kW, iD, iH, iW] + MmulHelper::tensorDot( + input, &columns, gradW, inputAxesForDot, {0, 5, 6, 7}, + gradWAxes); // [bS, iC, iD, iH, iW]/[bS, iD, iH, iW, iC] x [bS, oC, kD, + // kH, kW, iD, iH, iW] = [iC, oC, kD, kH, kW] + + // ----- calculation of gradB ----- // + if (gradB) { + if (gradB->rankOf() == 2) + gradB = new NDArray( + gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()}, false)); + gradO->reduceAlongDimension(reduce::Sum, *gradB, + {0, 2, 3, 4}); // sum over bS, oD, oH, oW + if (gradB != OUTPUT_VARIABLE(2)) delete gradB; + } + + if (!isNCDHW) delete gradO; + + return Status::OK(); } - DECLARE_TYPES(deconv3d_bp) { - getOpDescriptor() - ->setAllowedInputTypes(0, sd::DataType::ANY) - ->setAllowedInputTypes(1, {ALL_FLOATS}) - ->setAllowedInputTypes(2, {ALL_FLOATS}) - ->setAllowedInputTypes(3, {ALL_FLOATS}) - ->setAllowedOutputTypes({ALL_FLOATS}); - } +DECLARE_TYPES(deconv3d_bp) { + getOpDescriptor() + ->setAllowedInputTypes(0, sd::DataType::ANY) + ->setAllowedInputTypes(1, {ALL_FLOATS}) + ->setAllowedInputTypes(2, {ALL_FLOATS}) + ->setAllowedInputTypes(3, {ALL_FLOATS}) + ->setAllowedOutputTypes({ALL_FLOATS}); +} DECLARE_SHAPE_FN(deconv3d_bp) { - - auto inputShapeInfo = inputShape->at(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) - auto weightsShapeInfo = inputShape->at(1); // [kD, kH, kW, oC, iC], [iC, oC, kD, kH, kW], [iC, kD, kH, kW, oC] - auto biasShapeInfo = block.width() > 3 ? inputShape->at(2) : nullptr; // [oC] - auto gradOShapeInfo = block.width() > 3 ? inputShape->at(3) : inputShape->at(2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next - - const int rank = 5; - REQUIRE_TRUE(shape::rank(inputShapeInfo) == rank, 0, "CUSTOM DECONV3D_BP OP: rank of input array must be equal to %i, but got %i instead !", rank, shape::rank(inputShapeInfo)); - REQUIRE_TRUE(shape::rank(weightsShapeInfo) == rank, 0, "CUSTOM DECONV3D_BP OP: rank of weights array must be equal to %i , but got %i instead !", rank, shape::rank(weightsShapeInfo)); - REQUIRE_TRUE(shape::rank(gradOShapeInfo) == rank, 0, "CUSTOM DECONV3D_BP OP: rank of output gradients (next epsilon) array must be equal to %i, but got %i instead !", rank, shape::rank(gradOShapeInfo)); - - int kD = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(shape::sizeAt(weightsShapeInfo, 0));// filter(kernel) depth - int kH = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast(shape::sizeAt(weightsShapeInfo, 1));// filter(kernel) height - int kW = INT_ARG(2) > 0 ? INT_ARG(2) : static_cast(shape::sizeAt(weightsShapeInfo, 2));// filter(kernel) width - int sD = INT_ARG(3); // strides depth - int sH = INT_ARG(4); // strides height - int sW = INT_ARG(5); // strides width - int pD = INT_ARG(6); // paddings depth - int pH = INT_ARG(7); // paddings height - int pW = INT_ARG(8); // paddings width - int dD = INT_ARG(9); // dilations depth - int dH = INT_ARG(10); // dilations height - int dW = INT_ARG(11); // dilations width - int isSameMode = INT_ARG(12); // 0-SAME, 1-VALID - int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW - int wFormat = block.getIArguments()->size() > 14 ? INT_ARG(14) : 0; // 0 - [kD, kH, kW, oC, iC], 1 - [iC, oC, kD, kH, kW], 2 - [iC, kD, kH, kW, oC] - - int indIOioC, indIiD, indWoC(0 == wFormat ? 3 : (1 == wFormat ? 1 : 4)); - if(!isNCDHW) { - indIOioC = 4; indIiD = 1; - } - else { - indIOioC = 1; indIiD = 2; - } - - const int bS = inputShapeInfo[1]; // batch size - const int iD = inputShapeInfo[indIiD+1]; // input depth - const int iH = inputShapeInfo[indIiD+2]; // input height - const int iW = inputShapeInfo[indIiD+3]; // input width - const int iC = inputShapeInfo[indIOioC+1]; // input channels - const int oC = weightsShapeInfo[indWoC+1]; // output channels - - int trueoD, trueoH, trueoW; // true output depth, height, width - ConvolutionUtils::calcOutSizeDeconv3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, isSameMode); - - std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoD,trueoH,trueoW, 0,indIOioC,indIiD,indIiD+1,indIiD+2}); - std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, oC, iC); - REQUIRE_TRUE(shape::shapeEquals(5, expectedGradOShape.data(), shape::rank(gradOShapeInfo), shape::shapeOf(gradOShapeInfo)), 0, "CUSTOM DECONV3D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradOShapeInfo).c_str()); - REQUIRE_TRUE(shape::shapeEquals(5, expectedWeightsShape.data(), shape::rank(weightsShapeInfo), shape::shapeOf(weightsShapeInfo)), 0, "CUSTOM DECONV3D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); - if(biasShapeInfo) - REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM DECONV3D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo)); - - auto gradIShapeInfo = ShapeBuilders::copyShapeInfoAndType(inputShapeInfo, gradOShapeInfo, false, block.getWorkspace()); - auto gradWShapeInfo = ShapeBuilders::copyShapeInfoAndType(weightsShapeInfo, gradOShapeInfo, false, block.getWorkspace()); - - auto shapes = SHAPELIST(CONSTANT(gradIShapeInfo), CONSTANT(gradWShapeInfo)); - - if (biasShapeInfo != nullptr) { - auto gradBShapeInfo = ShapeBuilders::copyShapeInfoAndType(biasShapeInfo, gradOShapeInfo, false, block.getWorkspace()); - shapes->push_back(CONSTANT(gradBShapeInfo)); - } - - return shapes; + auto inputShapeInfo = inputShape->at( + 0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) + auto weightsShapeInfo = inputShape->at( + 1); // [kD, kH, kW, oC, iC], [iC, oC, kD, kH, kW], [iC, kD, kH, kW, oC] + auto biasShapeInfo = block.width() > 3 ? inputShape->at(2) : nullptr; // [oC] + auto gradOShapeInfo = + block.width() > 3 + ? inputShape->at(3) + : inputShape->at(2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, + // oH, oW] (NCDHW), epsilon_next + + const int rank = 5; + REQUIRE_TRUE(shape::rank(inputShapeInfo) == rank, 0, + "CUSTOM DECONV3D_BP OP: rank of input array must be equal to " + "%i, but got %i instead !", + rank, shape::rank(inputShapeInfo)); + REQUIRE_TRUE(shape::rank(weightsShapeInfo) == rank, 0, + "CUSTOM DECONV3D_BP OP: rank of weights array must be equal to " + "%i , but got %i instead !", + rank, shape::rank(weightsShapeInfo)); + REQUIRE_TRUE(shape::rank(gradOShapeInfo) == rank, 0, + "CUSTOM DECONV3D_BP OP: rank of output gradients (next epsilon) " + "array must be equal to %i, but got %i instead !", + rank, shape::rank(gradOShapeInfo)); + + int kD = INT_ARG(0) > 0 ? INT_ARG(0) + : static_cast(shape::sizeAt( + weightsShapeInfo, 0)); // filter(kernel) depth + int kH = INT_ARG(1) > 0 ? INT_ARG(1) + : static_cast(shape::sizeAt( + weightsShapeInfo, 1)); // filter(kernel) height + int kW = INT_ARG(2) > 0 ? INT_ARG(2) + : static_cast(shape::sizeAt( + weightsShapeInfo, 2)); // filter(kernel) width + int sD = INT_ARG(3); // strides depth + int sH = INT_ARG(4); // strides height + int sW = INT_ARG(5); // strides width + int pD = INT_ARG(6); // paddings depth + int pH = INT_ARG(7); // paddings height + int pW = INT_ARG(8); // paddings width + int dD = INT_ARG(9); // dilations depth + int dH = INT_ARG(10); // dilations height + int dW = INT_ARG(11); // dilations width + int isSameMode = INT_ARG(12); // 0-SAME, 1-VALID + int isNCDHW = + block.numI() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW + int wFormat = block.numI() > 14 + ? INT_ARG(14) + : 0; // 0 - [kD, kH, kW, oC, iC], 1 - [iC, oC, kD, kH, kW], + // 2 - [iC, kD, kH, kW, oC] + + int indIOioC, indIiD, indWoC(0 == wFormat ? 3 : (1 == wFormat ? 1 : 4)); + if (!isNCDHW) { + indIOioC = 4; + indIiD = 1; + } else { + indIOioC = 1; + indIiD = 2; + } + + const int bS = inputShapeInfo[1]; // batch size + const int iD = inputShapeInfo[indIiD + 1]; // input depth + const int iH = inputShapeInfo[indIiD + 2]; // input height + const int iW = inputShapeInfo[indIiD + 3]; // input width + const int iC = inputShapeInfo[indIOioC + 1]; // input channels + const int oC = weightsShapeInfo[indWoC + 1]; // output channels + + int trueoD, trueoH, trueoW; // true output depth, height, width + ConvolutionUtils::calcOutSizeDeconv3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, + sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, + iW, isSameMode); + + std::vector expectedGradOShape = + ShapeUtils::composeShapeUsingDimsAndIdx({bS, oC, trueoD, trueoH, trueoW, + 0, indIOioC, indIiD, indIiD + 1, + indIiD + 2}); + std::vector expectedWeightsShape = + ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, oC, iC); + REQUIRE_TRUE(shape::shapeEquals(5, expectedGradOShape.data(), + shape::rank(gradOShapeInfo), + shape::shapeOf(gradOShapeInfo)), + 0, + "CUSTOM DECONV3D_BP OP: wrong shape of output gradients (next " + "epsilon) array, expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedGradOShape).c_str(), + ShapeUtils::shapeAsString(gradOShapeInfo).c_str()); + REQUIRE_TRUE(shape::shapeEquals(5, expectedWeightsShape.data(), + shape::rank(weightsShapeInfo), + shape::shapeOf(weightsShapeInfo)), + 0, + "CUSTOM DECONV3D_BP OP: wrong shape of weights array, expected " + "is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), + ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); + if (biasShapeInfo) + REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, + "CUSTOM DECONV3D_BP OP: wrong shape of array with biases, " + "expected rank, length: <=2, %i, but got %i, %i instead !", + oC, biasShapeInfo[0], shape::length(biasShapeInfo)); + + auto gradIShapeInfo = ShapeBuilders::copyShapeInfoAndType( + inputShapeInfo, gradOShapeInfo, false, block.workspace()); + auto gradWShapeInfo = ShapeBuilders::copyShapeInfoAndType( + weightsShapeInfo, gradOShapeInfo, false, block.workspace()); + + auto shapes = SHAPELIST(CONSTANT(gradIShapeInfo), CONSTANT(gradWShapeInfo)); + + if (biasShapeInfo != nullptr) { + auto gradBShapeInfo = ShapeBuilders::copyShapeInfoAndType( + biasShapeInfo, gradOShapeInfo, false, block.workspace()); + shapes->push_back(CONSTANT(gradBShapeInfo)); + } + + return shapes; } - - -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/nn/convo/depthwiseConv2d.cpp b/libnd4j/include/ops/declarable/generic/nn/convo/depthwiseConv2d.cpp index 744512a13c6b..133f222e5bef 100644 --- a/libnd4j/include/ops/declarable/generic/nn/convo/depthwiseConv2d.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/convo/depthwiseConv2d.cpp @@ -21,240 +21,395 @@ #include #if NOT_EXCLUDED(OP_depthwise_conv2d) -#include #include #include - +#include namespace sd { -namespace ops { +namespace ops { CUSTOM_OP_IMPL(depthwise_conv2d, 2, 1, false, 0, 9) { - - auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] - auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] = iC*mC - - auto output = OUTPUT_NULLIFIED(0); // [bS, oH, oW, iC*mC] (NHWC) or [bS, iC*mC, oH, oW] (NCHW) - - REQUIRE_TRUE(input->rankOf() == 4, 0, "CUSTOM DEPTHWISECONV2D OP: rank of input array must be equal to 4, but got %i instead !", input->rankOf()); - REQUIRE_TRUE(weights->rankOf() == 4, 0, "CUSTOM DEPTHWISECONV2D OP: rank of weights array must be equal to 4, but got %i instead !", weights->rankOf()); - - int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(weights->sizeAt(0));// filter(kernel) height - int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast(weights->sizeAt(1));// filter(kernel) width - int sH = INT_ARG(2); // strides height - int sW = INT_ARG(3); // strides width - int pH = INT_ARG(4); // paddings height - int pW = INT_ARG(5); // paddings width - int dH = INT_ARG(6); // dilations height - int dW = INT_ARG(7); // dilations width - int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME - int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC - int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - [mC, kH, kW, iC] - - int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width - int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); - mC = weights->sizeAt(indWmC); // channels multiplier - - std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, mC); - REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DEPTHWISECONV2D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); - REQUIRE_TRUE(output->sizeAt(indIOioC) == iC*mC, 0, "CUSTOM DEPTHWISECONV2D OP: the output_channels must be equal to input_channels * channels_multiplier = %i !", iC*mC); - if (bias) - REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DEPTHWISECONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); - - ConvolutionUtils::depthwiseConv2d(block, input, weights, bias, output, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW,wFormat); - - return Status::OK(); + auto input = + INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + auto weights = INPUT_VARIABLE( + 1); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] + auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] = iC*mC + + auto output = OUTPUT_NULLIFIED( + 0); // [bS, oH, oW, iC*mC] (NHWC) or [bS, iC*mC, oH, oW] (NCHW) + + REQUIRE_TRUE(input->rankOf() == 4, 0, + "CUSTOM DEPTHWISECONV2D OP: rank of input array must be equal " + "to 4, but got %i instead !", + input->rankOf()); + REQUIRE_TRUE(weights->rankOf() == 4, 0, + "CUSTOM DEPTHWISECONV2D OP: rank of weights array must be equal " + "to 4, but got %i instead !", + weights->rankOf()); + + int kH = INT_ARG(0) > 0 + ? INT_ARG(0) + : static_cast(weights->sizeAt(0)); // filter(kernel) height + int kW = INT_ARG(1) > 0 + ? INT_ARG(1) + : static_cast(weights->sizeAt(1)); // filter(kernel) width + int sH = INT_ARG(2); // strides height + int sW = INT_ARG(3); // strides width + int pH = INT_ARG(4); // paddings height + int pW = INT_ARG(5); // paddings width + int dH = INT_ARG(6); // dilations height + int dW = INT_ARG(7); // dilations width + int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME + int isNCHW = + block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC + int wFormat = + block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, + // iC, kH, kW], 2 - [mC, kH, kW, iC] + + int bS, iC, iH, iW, mC, oC, oH, + oW; // batch size, input channels, input height/width, channels + // multiplier(oC = iC*mC), output channels, output height/width + int indIOioC, indIiH, indWmC, indWiC, indWkH, + indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, + indIiH, indWiC, indWmC, indWkH, indOoH); + mC = weights->sizeAt(indWmC); // channels multiplier + + std::vector expectedWeightsShape = + ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, mC); + REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, + "CUSTOM DEPTHWISECONV2D OP: wrong shape of weights array, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), + ShapeUtils::shapeAsString(weights).c_str()); + REQUIRE_TRUE(output->sizeAt(indIOioC) == iC * mC, 0, + "CUSTOM DEPTHWISECONV2D OP: the output_channels must be equal " + "to input_channels * channels_multiplier = %i !", + iC * mC); + if (bias) + REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, + "CUSTOM DEPTHWISECONV2D OP: wrong shape of array with biases, " + "expected rank, length: <=2, %i, but got %i, %i instead !", + oC, bias->rankOf(), bias->lengthOf()); + + ConvolutionUtils::depthwiseConv2d(block, input, weights, bias, output, kH, kW, + sH, sW, pH, pW, dH, dW, isSameMode, isNCHW, + wFormat); + + return Status::OK(); } - DECLARE_TYPES(depthwise_conv2d) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } +DECLARE_TYPES(depthwise_conv2d) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} DECLARE_SHAPE_FN(depthwise_conv2d) { - auto inputShapeInfo = inputShape->at(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - auto weightsShapeInfo = inputShape->at(1); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] - auto biasShapeInfo = block.width() > 2 ? inputShape->at(2) : nullptr; // [oC] = iC*mC - - const int rank = 4; - REQUIRE_TRUE(shape::rank(inputShapeInfo) == rank, 0, "CUSTOM DEPTHWISECONV2D OP: rank of input array must be equal to %i, but got %i instead !", rank, inputShapeInfo[0]); - REQUIRE_TRUE(shape::rank(weightsShapeInfo) == rank, 0, "CUSTOM DEPTHWISECONV2D OP: rank of weights array must be equal to %i, but got %i instead !", rank, weightsShapeInfo[0]); - - int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(shape::sizeAt(weightsShapeInfo, 0));// filter(kernel) height - int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast(shape::sizeAt(weightsShapeInfo, 1));// filter(kernel) width - int sH = INT_ARG(2); // strides height - int sW = INT_ARG(3); // strides width - int pH = INT_ARG(4); // paddings height - int pW = INT_ARG(5); // paddings width - int dH = INT_ARG(6); // dilations height - int dW = INT_ARG(7); // dilations width - int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME - int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW - int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - [mC, kH, kW, iC] - - int indIOioC, indIiH, indWmC(0 == wFormat ? 3 : 0); - if(!isNCHW) { - indIOioC = 3; indIiH = 1; - } - else { - indIOioC = 1; indIiH = 2; - } - - const int bS = shape::sizeAt(inputShapeInfo, 0); // batch size - const int iH = shape::sizeAt(inputShapeInfo, indIiH); // input height - const int iW = shape::sizeAt(inputShapeInfo, indIiH+1); // input width - const int iC = shape::sizeAt(inputShapeInfo, indIOioC); // input channels - const int mC = shape::sizeAt(weightsShapeInfo, indWmC); // channels multiplier(oC = iC*mC) - const int oC = iC*mC; // output channels - - - std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, mC); - REQUIRE_TRUE(shape::shapeEquals(4, expectedWeightsShape.data(), shape::rank(weightsShapeInfo), shape::shapeOf(weightsShapeInfo)), 0, "DEPTHWISECONV2D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); - if (biasShapeInfo) - REQUIRE_TRUE(shape::rank(biasShapeInfo) <= 2 && oC == shape::length(biasShapeInfo), 0, "DEPTHWISECONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, shape::rank(biasShapeInfo), shape::length(biasShapeInfo)); - - int oH, oW; // output height, width - ConvolutionUtils::calcOutSizePool2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode); - - Nd4jLong* outputShapeInfo = nullptr; - ALLOCATE(outputShapeInfo, block.getWorkspace(), shape::shapeInfoLength(inputShapeInfo), Nd4jLong); - - outputShapeInfo[0] = rank; - outputShapeInfo[1] = bS; - - if (isNCHW) { - outputShapeInfo[2] = oC; - outputShapeInfo[3] = oH; - outputShapeInfo[4] = oW; - } else { - outputShapeInfo[2] = oH; - outputShapeInfo[3] = oW; - outputShapeInfo[4] = oC; - } - - ShapeUtils::updateStridesAndType(outputShapeInfo, weightsShapeInfo, shape::order(inputShapeInfo)); - - return SHAPELIST(CONSTANT(outputShapeInfo)); + auto inputShapeInfo = + inputShape->at(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + auto weightsShapeInfo = inputShape->at( + 1); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] + auto biasShapeInfo = + block.width() > 2 ? inputShape->at(2) : nullptr; // [oC] = iC*mC + + const int rank = 4; + REQUIRE_TRUE(shape::rank(inputShapeInfo) == rank, 0, + "CUSTOM DEPTHWISECONV2D OP: rank of input array must be equal " + "to %i, but got %i instead !", + rank, inputShapeInfo[0]); + REQUIRE_TRUE(shape::rank(weightsShapeInfo) == rank, 0, + "CUSTOM DEPTHWISECONV2D OP: rank of weights array must be equal " + "to %i, but got %i instead !", + rank, weightsShapeInfo[0]); + + int kH = INT_ARG(0) > 0 ? INT_ARG(0) + : static_cast(shape::sizeAt( + weightsShapeInfo, 0)); // filter(kernel) height + int kW = INT_ARG(1) > 0 ? INT_ARG(1) + : static_cast(shape::sizeAt( + weightsShapeInfo, 1)); // filter(kernel) width + int sH = INT_ARG(2); // strides height + int sW = INT_ARG(3); // strides width + int pH = INT_ARG(4); // paddings height + int pW = INT_ARG(5); // paddings width + int dH = INT_ARG(6); // dilations height + int dW = INT_ARG(7); // dilations width + int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME + int isNCHW = + block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW + int wFormat = + block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, + // iC, kH, kW], 2 - [mC, kH, kW, iC] + + int indIOioC, indIiH, indWmC(0 == wFormat ? 3 : 0); + if (!isNCHW) { + indIOioC = 3; + indIiH = 1; + } else { + indIOioC = 1; + indIiH = 2; + } + + const int bS = shape::sizeAt(inputShapeInfo, 0); // batch size + const int iH = shape::sizeAt(inputShapeInfo, indIiH); // input height + const int iW = shape::sizeAt(inputShapeInfo, indIiH + 1); // input width + const int iC = shape::sizeAt(inputShapeInfo, indIOioC); // input channels + const int mC = shape::sizeAt(weightsShapeInfo, + indWmC); // channels multiplier(oC = iC*mC) + const int oC = iC * mC; // output channels + + std::vector expectedWeightsShape = + ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, mC); + REQUIRE_TRUE(shape::shapeEquals(4, expectedWeightsShape.data(), + shape::rank(weightsShapeInfo), + shape::shapeOf(weightsShapeInfo)), + 0, + "DEPTHWISECONV2D OP: wrong shape of weights array, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), + ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); + if (biasShapeInfo) + REQUIRE_TRUE( + shape::rank(biasShapeInfo) <= 2 && oC == shape::length(biasShapeInfo), + 0, + "DEPTHWISECONV2D OP: wrong shape of array with biases, expected rank, " + "length: <=2, %i, but got %i, %i instead !", + oC, shape::rank(biasShapeInfo), shape::length(biasShapeInfo)); + + int oH, oW; // output height, width + ConvolutionUtils::calcOutSizePool2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, + iH, iW, isSameMode); + + Nd4jLong* outputShapeInfo = nullptr; + ALLOCATE(outputShapeInfo, block.workspace(), + shape::shapeInfoLength(inputShapeInfo), Nd4jLong); + + outputShapeInfo[0] = rank; + outputShapeInfo[1] = bS; + + if (isNCHW) { + outputShapeInfo[2] = oC; + outputShapeInfo[3] = oH; + outputShapeInfo[4] = oW; + } else { + outputShapeInfo[2] = oH; + outputShapeInfo[3] = oW; + outputShapeInfo[4] = oC; + } + + ShapeUtils::updateStridesAndType(outputShapeInfo, weightsShapeInfo, + shape::order(inputShapeInfo)); + + return SHAPELIST(CONSTANT(outputShapeInfo)); } - DECLARE_TYPES(depthwise_conv2d_bp) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } +DECLARE_TYPES(depthwise_conv2d_bp) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(depthwise_conv2d_bp, 3, 2, false, 0, 9) { - - auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW) - auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] - auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] = [iC*mC] - auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NDHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next - - auto gradI = OUTPUT_NULLIFIED(0); // [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW), epsilon - auto gradW = OUTPUT_NULLIFIED(1); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] - auto gradB = block.width() > 3 ? OUTPUT_NULLIFIED(2) : nullptr; // [oC] - - REQUIRE_TRUE(input->rankOf() == 4, 0, "CUSTOM DEPTHWISECONV2D_BP OP: rank of input array must be equal to 4, but got %i instead !", input->rankOf()); - REQUIRE_TRUE(weights->rankOf() == 4, 0, "CUSTOM DEPTHWISECONV2D_BP OP: rank of weights array must be equal to 4, but got %i instead !", weights->rankOf()); - REQUIRE_TRUE(gradO->rankOf() == 4, 0, "CUSTOM DEPTHWISECONV2D_BP OP: rank of output gradients (next epsilon) array must be equal to 4, but got %i instead !", gradO->rankOf()); - - int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(weights->sizeAt(0));// filter(kernel) height - int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast(weights->sizeAt(1));// filter(kernel) width - int sH = INT_ARG(2); // strides height - int sW = INT_ARG(3); // strides width - int pH = INT_ARG(4); // paddings height - int pW = INT_ARG(5); // paddings width - int dH = INT_ARG(6); // dilations height - int dW = INT_ARG(7); // dilations width - int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME - int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW - int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - [mC, kH, kW, iC] - - int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width - int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); - mC = weights->sizeAt(indWmC); // channels multiplier - - int trueoH, trueoW; // correct output height, width - ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode); - - std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1}); - std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, mC); - REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM DEPTHWISECONV2D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); - REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DEPTHWISECONV2D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); - if(bias) - REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DEPTHWISECONV2D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); - - ConvolutionUtils::depthwiseConv2dBP(block, input, weights, bias, gradO, gradI, gradW, gradB, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW, wFormat); - - return Status::OK(); + auto input = INPUT_VARIABLE( + 0); // [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW) + auto weights = INPUT_VARIABLE( + 1); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] + auto bias = + block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] = [iC*mC] + auto gradO = block.width() > 3 + ? INPUT_VARIABLE(3) + : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NDHWC) or [bS, oC, + // oH, oW] (NCDHW), epsilon_next + + auto gradI = OUTPUT_NULLIFIED( + 0); // [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW), epsilon + auto gradW = OUTPUT_NULLIFIED( + 1); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] + auto gradB = block.width() > 3 ? OUTPUT_NULLIFIED(2) : nullptr; // [oC] + + REQUIRE_TRUE(input->rankOf() == 4, 0, + "CUSTOM DEPTHWISECONV2D_BP OP: rank of input array must be " + "equal to 4, but got %i instead !", + input->rankOf()); + REQUIRE_TRUE(weights->rankOf() == 4, 0, + "CUSTOM DEPTHWISECONV2D_BP OP: rank of weights array must be " + "equal to 4, but got %i instead !", + weights->rankOf()); + REQUIRE_TRUE(gradO->rankOf() == 4, 0, + "CUSTOM DEPTHWISECONV2D_BP OP: rank of output gradients (next " + "epsilon) array must be equal to 4, but got %i instead !", + gradO->rankOf()); + + int kH = INT_ARG(0) > 0 + ? INT_ARG(0) + : static_cast(weights->sizeAt(0)); // filter(kernel) height + int kW = INT_ARG(1) > 0 + ? INT_ARG(1) + : static_cast(weights->sizeAt(1)); // filter(kernel) width + int sH = INT_ARG(2); // strides height + int sW = INT_ARG(3); // strides width + int pH = INT_ARG(4); // paddings height + int pW = INT_ARG(5); // paddings width + int dH = INT_ARG(6); // dilations height + int dW = INT_ARG(7); // dilations width + int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME + int isNCHW = + block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW + int wFormat = + block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, + // iC, kH, kW], 2 - [mC, kH, kW, iC] + + int bS, iC, iH, iW, mC, oC, oH, + oW; // batch size, input channels, input height/width, channels + // multiplier(oC = iC*mC), output channels, output height/width + int indIOioC, indIiH, indWmC, indWiC, indWkH, + indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, + indIiH, indWiC, indWmC, indWkH, indOoH); + mC = weights->sizeAt(indWmC); // channels multiplier + + int trueoH, trueoW; // correct output height, width + ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, + dH, dW, iH, iW, isSameMode); + + std::vector expectedGradOShape = + ShapeUtils::composeShapeUsingDimsAndIdx( + {bS, oC, trueoH, trueoW, 0, indIOioC, indOoH, indOoH + 1}); + std::vector expectedWeightsShape = + ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, mC); + REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, + "CUSTOM DEPTHWISECONV2D_BP OP: wrong shape of output gradients " + "(next epsilon) array, expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedGradOShape).c_str(), + ShapeUtils::shapeAsString(gradO).c_str()); + REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, + "CUSTOM DEPTHWISECONV2D_BP OP: wrong shape of weights array, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), + ShapeUtils::shapeAsString(weights).c_str()); + if (bias) + REQUIRE_TRUE( + bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, + "CUSTOM DEPTHWISECONV2D_BP OP: wrong shape of array with biases, " + "expected rank, length: <=2, %i, but got %i, %i instead !", + oC, bias->rankOf(), bias->lengthOf()); + + ConvolutionUtils::depthwiseConv2dBP(block, input, weights, bias, gradO, gradI, + gradW, gradB, kH, kW, sH, sW, pH, pW, dH, + dW, isSameMode, isNCHW, wFormat); + + return Status::OK(); } ////////////////////////////////////////////////////////////////////// DECLARE_SHAPE_FN(depthwise_conv2d_bp) { - auto inputShapeInfo = inputShape->at(0); - auto weightsShapeInfo = inputShape->at(1); - auto biasShapeInfo = block.width() > 3 ? inputShape->at(2) : nullptr; - auto gradOShapeInfo = block.width() > 3 ? inputShape->at(3) : inputShape->at(2); - - const int rank = 4; - REQUIRE_TRUE(shape::rank(inputShapeInfo) == rank, 0, "CUSTOM DEPTHWISECONV2D_BP OP: rank of input array must be equal to %i, but got %i instead !", rank, shape::rank(inputShapeInfo)); - REQUIRE_TRUE(shape::rank(weightsShapeInfo) == rank, 0, "CUSTOM DEPTHWISECONV2D_BP OP: rank of weights array must be equal to %i, but got %i instead !", rank, shape::rank(weightsShapeInfo)); - REQUIRE_TRUE(shape::rank(gradOShapeInfo) == rank, 0, "CUSTOM DEPTHWISECONV2D_BP OP: rank of output gradients (next epsilon) array must be equal to %i, but got %i instead !", rank, shape::rank(gradOShapeInfo)); - - int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(shape::sizeAt(weightsShapeInfo, 0));// filter(kernel) height - int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast(shape::sizeAt(weightsShapeInfo, 1));// filter(kernel) width - int sH = INT_ARG(2); // strides height - int sW = INT_ARG(3); // strides width - int pH = INT_ARG(4); // paddings height - int pW = INT_ARG(5); // paddings width - int dH = INT_ARG(6); // dilations height - int dW = INT_ARG(7); // dilations width - int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME - int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW - int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - [mC, kH, kW, iC] - - int indIOioC, indIiH, indWmC(0 == wFormat ? 3 : 0); - if(!isNCHW) { - indIOioC = 3; indIiH = 1; - } - else { - indIOioC = 1; indIiH = 2; - } - - const int bS = shape::sizeAt(inputShapeInfo, 0); // batch size - const int iH = shape::sizeAt(inputShapeInfo, indIiH); // input height - const int iW = shape::sizeAt(inputShapeInfo, indIiH+1); // input width - const int iC = shape::sizeAt(inputShapeInfo, indIOioC); // input channels - const int mC = shape::sizeAt(weightsShapeInfo, indWmC); // channels multiplier(oC = iC*mC) - const int oC = iC*mC; // output channels - - int trueoH, trueoW; // correct output height, width - ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode); - - std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indIiH,indIiH+1}); - std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, mC); - REQUIRE_TRUE(shape::shapeEquals(4, expectedGradOShape.data(), shape::rank(gradOShapeInfo), shape::shapeOf(gradOShapeInfo)), 0, "CUSTOM DEPTHWISECONV2D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradOShapeInfo).c_str()); - REQUIRE_TRUE(shape::shapeEquals(4, expectedWeightsShape.data(), shape::rank(weightsShapeInfo), shape::shapeOf(weightsShapeInfo)), 0, "CUSTOM DEPTHWISECONV2D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); - if(biasShapeInfo) - REQUIRE_TRUE(shape::rank(biasShapeInfo) <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM DEPTHWISECONV2D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, shape::rank(biasShapeInfo), shape::length(biasShapeInfo)); - - auto gradIshapeInfo = ShapeBuilders::copyShapeInfoAndType(inputShapeInfo, gradOShapeInfo, false, block.getWorkspace()); - auto gradWshapeInfo = ShapeBuilders::copyShapeInfoAndType(weightsShapeInfo, gradOShapeInfo, false, block.getWorkspace()); - - if(biasShapeInfo) { - Nd4jLong* gradBshapeInfo = ShapeBuilders::copyShapeInfoAndType(biasShapeInfo, gradOShapeInfo, false, block.getWorkspace()); - return SHAPELIST(CONSTANT(gradIshapeInfo), CONSTANT(gradWshapeInfo), CONSTANT(gradBshapeInfo)); - } - - return SHAPELIST(CONSTANT(gradIshapeInfo), CONSTANT(gradWshapeInfo)); + auto inputShapeInfo = inputShape->at(0); + auto weightsShapeInfo = inputShape->at(1); + auto biasShapeInfo = block.width() > 3 ? inputShape->at(2) : nullptr; + auto gradOShapeInfo = + block.width() > 3 ? inputShape->at(3) : inputShape->at(2); + + const int rank = 4; + REQUIRE_TRUE(shape::rank(inputShapeInfo) == rank, 0, + "CUSTOM DEPTHWISECONV2D_BP OP: rank of input array must be " + "equal to %i, but got %i instead !", + rank, shape::rank(inputShapeInfo)); + REQUIRE_TRUE(shape::rank(weightsShapeInfo) == rank, 0, + "CUSTOM DEPTHWISECONV2D_BP OP: rank of weights array must be " + "equal to %i, but got %i instead !", + rank, shape::rank(weightsShapeInfo)); + REQUIRE_TRUE(shape::rank(gradOShapeInfo) == rank, 0, + "CUSTOM DEPTHWISECONV2D_BP OP: rank of output gradients (next " + "epsilon) array must be equal to %i, but got %i instead !", + rank, shape::rank(gradOShapeInfo)); + + int kH = INT_ARG(0) > 0 ? INT_ARG(0) + : static_cast(shape::sizeAt( + weightsShapeInfo, 0)); // filter(kernel) height + int kW = INT_ARG(1) > 0 ? INT_ARG(1) + : static_cast(shape::sizeAt( + weightsShapeInfo, 1)); // filter(kernel) width + int sH = INT_ARG(2); // strides height + int sW = INT_ARG(3); // strides width + int pH = INT_ARG(4); // paddings height + int pW = INT_ARG(5); // paddings width + int dH = INT_ARG(6); // dilations height + int dW = INT_ARG(7); // dilations width + int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME + int isNCHW = + block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW + int wFormat = + block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, + // iC, kH, kW], 2 - [mC, kH, kW, iC] + + int indIOioC, indIiH, indWmC(0 == wFormat ? 3 : 0); + if (!isNCHW) { + indIOioC = 3; + indIiH = 1; + } else { + indIOioC = 1; + indIiH = 2; + } + + const int bS = shape::sizeAt(inputShapeInfo, 0); // batch size + const int iH = shape::sizeAt(inputShapeInfo, indIiH); // input height + const int iW = shape::sizeAt(inputShapeInfo, indIiH + 1); // input width + const int iC = shape::sizeAt(inputShapeInfo, indIOioC); // input channels + const int mC = shape::sizeAt(weightsShapeInfo, + indWmC); // channels multiplier(oC = iC*mC) + const int oC = iC * mC; // output channels + + int trueoH, trueoW; // correct output height, width + ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, + dH, dW, iH, iW, isSameMode); + + std::vector expectedGradOShape = + ShapeUtils::composeShapeUsingDimsAndIdx( + {bS, oC, trueoH, trueoW, 0, indIOioC, indIiH, indIiH + 1}); + std::vector expectedWeightsShape = + ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, mC); + REQUIRE_TRUE(shape::shapeEquals(4, expectedGradOShape.data(), + shape::rank(gradOShapeInfo), + shape::shapeOf(gradOShapeInfo)), + 0, + "CUSTOM DEPTHWISECONV2D_BP OP: wrong shape of output gradients " + "(next epsilon) array, expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedGradOShape).c_str(), + ShapeUtils::shapeAsString(gradOShapeInfo).c_str()); + REQUIRE_TRUE(shape::shapeEquals(4, expectedWeightsShape.data(), + shape::rank(weightsShapeInfo), + shape::shapeOf(weightsShapeInfo)), + 0, + "CUSTOM DEPTHWISECONV2D_BP OP: wrong shape of weights array, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), + ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); + if (biasShapeInfo) + REQUIRE_TRUE( + shape::rank(biasShapeInfo) <= 2 && oC == shape::length(biasShapeInfo), + 0, + "CUSTOM DEPTHWISECONV2D_BP OP: wrong shape of array with biases, " + "expected rank, length: <=2, %i, but got %i, %i instead !", + oC, shape::rank(biasShapeInfo), shape::length(biasShapeInfo)); + + auto gradIshapeInfo = ShapeBuilders::copyShapeInfoAndType( + inputShapeInfo, gradOShapeInfo, false, block.workspace()); + auto gradWshapeInfo = ShapeBuilders::copyShapeInfoAndType( + weightsShapeInfo, gradOShapeInfo, false, block.workspace()); + + if (biasShapeInfo) { + Nd4jLong* gradBshapeInfo = ShapeBuilders::copyShapeInfoAndType( + biasShapeInfo, gradOShapeInfo, false, block.workspace()); + return SHAPELIST(CONSTANT(gradIshapeInfo), CONSTANT(gradWshapeInfo), + CONSTANT(gradBshapeInfo)); + } + + return SHAPELIST(CONSTANT(gradIshapeInfo), CONSTANT(gradWshapeInfo)); } - - -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/nn/convo/dilation2d.cpp b/libnd4j/include/ops/declarable/generic/nn/convo/dilation2d.cpp index b3a0e1667457..23bc4f0c09f0 100644 --- a/libnd4j/include/ops/declarable/generic/nn/convo/dilation2d.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/convo/dilation2d.cpp @@ -26,111 +26,116 @@ namespace sd { namespace ops { - CUSTOM_OP_IMPL(dilation2d, 2, 1, false, 0, 1) { - auto input = INPUT_VARIABLE(0); - auto weights = INPUT_VARIABLE(1); +CUSTOM_OP_IMPL(dilation2d, 2, 1, false, 0, 1) { + auto input = INPUT_VARIABLE(0); + auto weights = INPUT_VARIABLE(1); - auto output = OUTPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - REQUIRE_TRUE(input->rankOf() == 4, 0, "Dilation2D: input should be 4D"); - REQUIRE_TRUE(weights->rankOf() == 3, 0, "Dilation2D: weights should be 3D"); + REQUIRE_TRUE(input->rankOf() == 4, 0, "Dilation2D: input should be 4D"); + REQUIRE_TRUE(weights->rankOf() == 3, 0, "Dilation2D: weights should be 3D"); - const int bS = input->sizeAt(0); - const int iC = input->sizeAt(3); - const bool isSameShape = INT_ARG(0) == 1; + const int bS = input->sizeAt(0); + const int iC = input->sizeAt(3); + const bool isSameShape = INT_ARG(0) == 1; - REQUIRE_TRUE(input->sizeAt(3) == weights->sizeAt(2), 0, "Dilation2D: number of input channels doesn't match number of channels in weights: %i vs %i", input->sizeAt(3), weights->sizeAt(2)); + REQUIRE_TRUE(input->sizeAt(3) == weights->sizeAt(2), 0, + "Dilation2D: number of input channels doesn't match number of " + "channels in weights: %i vs %i", + input->sizeAt(3), weights->sizeAt(2)); - std::vector strides(4); - std::vector rates(4); + std::vector strides(4); + std::vector rates(4); - if (block.width() > 2) { - REQUIRE_TRUE(block.width() >= 4, 0, "Dilation2D: number of input arrays should be 4 at least"); + if (block.width() > 2) { + REQUIRE_TRUE(block.width() >= 4, 0, + "Dilation2D: number of input arrays should be 4 at least"); + auto r = INPUT_VARIABLE(2); + auto s = INPUT_VARIABLE(3); - auto r = INPUT_VARIABLE(2); - auto s = INPUT_VARIABLE(3); + strides = s->template asVectorT(); + rates = r->template asVectorT(); + } else { + REQUIRE_TRUE(block.numI() >= 9, 0, + "Dilation2D: number of Int arguments should be 9 at least"); - strides = s->template asVectorT(); - rates = r->template asVectorT(); - } else { - REQUIRE_TRUE(block.numI() >= 9, 0, "Dilation2D: number of Int arguments should be 9 at least"); + int e = 1; + for (int cnt = 0; cnt < 4; cnt++) rates[cnt] = INT_ARG(e++); - int e = 1; - for (int cnt = 0;cnt < 4; cnt++) - rates[cnt] = INT_ARG(e++); + for (int cnt = 0; cnt < 4; cnt++) strides[cnt] = INT_ARG(e++); + } + int sH = 0, sW = 0; + int dH = 0, dW = 0; + int pH = 0, pW = 0; + int oH = 0, oW = 0; - for (int cnt = 0; cnt < 4; cnt++) - strides[cnt] = INT_ARG(e++); - } + helpers::dilation_hw(block.launchContext(), input->shapeInfo(), + weights->shapeInfo(), strides, rates, isSameShape, &sH, + &sW, &pH, &pW, &dH, &dW, &oH, &oW); + REQUIRE_TRUE(oH > 0 && oW > 0, 0, + "Dilation2D: outY and outX should have positive values, but got " + "[%i, %i] instead", + oH, oW); - int sH = 0, sW = 0; - int dH = 0, dW = 0; - int pH = 0, pW = 0; - int oH = 0, oW = 0; + helpers::dilation2d(block.launchContext(), input, weights, output, sH, sW, pH, + pW, dH, dW); - helpers::dilation_hw(block.launchContext(), input->shapeInfo(), weights->shapeInfo(), strides, rates, isSameShape, &sH, &sW, &pH, &pW, &dH, &dW, &oH, &oW); - - - REQUIRE_TRUE(oH > 0 && oW > 0, 0, "Dilation2D: outY and outX should have positive values, but got [%i, %i] instead", oH, oW); - - helpers::dilation2d(block.launchContext(), input, weights, output, sH, sW, pH, pW, dH, dW); - - return Status::OK(); - } - - DECLARE_TYPES(dilation2d) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } + return Status::OK(); +} - DECLARE_SHAPE_FN(dilation2d) { - auto input = inputShape->at(0); - auto weights = inputShape->at(1); +DECLARE_TYPES(dilation2d) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} - const int bS = shape::sizeAt(input, 0); - const int iC = shape::sizeAt(input, 3); - const bool isSameShape = INT_ARG(0) == 1; +DECLARE_SHAPE_FN(dilation2d) { + auto input = inputShape->at(0); + auto weights = inputShape->at(1); - std::vector strides(4); - std::vector rates(4); + const int bS = shape::sizeAt(input, 0); + const int iC = shape::sizeAt(input, 3); + const bool isSameShape = INT_ARG(0) == 1; - if (block.width() > 2) { - auto r = INPUT_VARIABLE(2); - auto s = INPUT_VARIABLE(3); + std::vector strides(4); + std::vector rates(4); + if (block.width() > 2) { + auto r = INPUT_VARIABLE(2); + auto s = INPUT_VARIABLE(3); - strides = s->template asVectorT(); - rates = r->template asVectorT(); - } else { - if (block.numI() < 9) { - auto newShape = ConstantShapeHelper::getInstance().scalarShapeInfo(block.dataType()); - return SHAPELIST(newShape); - } + strides = s->template asVectorT(); + rates = r->template asVectorT(); + } else { + if (block.numI() < 9) { + auto newShape = ConstantShapeHelper::getInstance().scalarShapeInfo( + ArrayOptions::dataType(input)); + return SHAPELIST(newShape); + } - int e = 1; - for (int cnt = 0;cnt < 4; cnt++) - rates[cnt] = INT_ARG(e++); + int e = 1; + for (int cnt = 0; cnt < 4; cnt++) rates[cnt] = INT_ARG(e++); - for (int cnt = 0; cnt < 4; cnt++) - strides[cnt] = INT_ARG(e++); - } + for (int cnt = 0; cnt < 4; cnt++) strides[cnt] = INT_ARG(e++); + } - int sH = 0, sW = 0; - int dH = 0, dW = 0; - int pH = 0, pW = 0; - int oH = 0, oW = 0; + int sH = 0, sW = 0; + int dH = 0, dW = 0; + int pH = 0, pW = 0; + int oH = 0, oW = 0; - helpers::dilation_hw(block.launchContext(), input, weights, strides, rates, isSameShape, &sH, &sW, &pH, &pW, &dH, &dW, &oH, &oW); + helpers::dilation_hw(block.launchContext(), input, weights, strides, rates, + isSameShape, &sH, &sW, &pH, &pW, &dH, &dW, &oH, &oW); - std::array shape = {{bS, oH, oW, iC}}; - auto newShape = ConstantShapeHelper::getInstance().createShapeInfo(ArrayOptions::dataType(weights), 'c', 4, shape.data()); - return SHAPELIST(newShape); - } -} + std::array shape = {{bS, oH, oW, iC}}; + auto newShape = ConstantShapeHelper::getInstance().createShapeInfo( + ArrayOptions::dataType(weights), 'c', 4, shape.data()); + return SHAPELIST(newShape); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/nn/convo/im2col.cpp b/libnd4j/include/ops/declarable/generic/nn/convo/im2col.cpp index 2e5818c56754..8d84a172bb89 100644 --- a/libnd4j/include/ops/declarable/generic/nn/convo/im2col.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/convo/im2col.cpp @@ -22,139 +22,153 @@ #if NOT_EXCLUDED(OP_im2col) #include +#include #include #include -#include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(im2col, 1, 1, false, 0, 9) { - auto x = INPUT_VARIABLE(0); - auto z = OUTPUT_NULLIFIED(0); - - REQUIRE_TRUE(x->rankOf() == 4, 0, "im2col input should be 4D, but got %i instead", x->rankOf()); - REQUIRE_TRUE(z->rankOf() == 6, 0, "im2col output should be 6D, but got %i instead", z->rankOf()); - - int kernelHeight = INT_ARG(0); - int kernelWidth = INT_ARG(1); - int strideY = INT_ARG(2); - int strideX = INT_ARG(3); - int padHeight = INT_ARG(4); - int padWidth = INT_ARG(5); - int dY = INT_ARG(6); //Dilation, height/y dimension - int dX = INT_ARG(7); //Dilation, width/x dimension - bool isSameMode = INT_ARG(8) > 0; - double zeroPadVal = 0.0; - if (block.getTArguments()->size() > 0) - zeroPadVal = T_ARG(0); - - // FIXME: zeropad value is void - LaunchContext* ctx = block.launchContext(); - sd::ops::helpers::im2col(*ctx, *x, *z, kernelHeight, kernelWidth, strideY, strideX, padHeight, padWidth, dY, dX, NDArrayFactory::create(zeroPadVal, block.launchContext())); - - return Status::OK(); - } - - DECLARE_SHAPE_FN(im2col) { - auto inShape = inputShape->at(0); - - int bS = shape::shapeOf(inShape)[0]; - int iD = shape::shapeOf(inShape)[1]; - int inY = shape::shapeOf(inShape)[2]; - int inX = shape::shapeOf(inShape)[3]; - - int kY = INT_ARG(0); - int kX = INT_ARG(1); - int sY = INT_ARG(2); - int sX = INT_ARG(3); - int pY = INT_ARG(4); - int pX = INT_ARG(5); - int dY = INT_ARG(6); //Dilation, height/y dimension - int dX = INT_ARG(7); //Dilation, width/x dimension - bool isSameMode = INT_ARG(8) > 0; - - // output is always 6d for im2col - Nd4jLong* zShape; - ALLOCATE(zShape, block.getWorkspace(), shape::shapeInfoLength(6), Nd4jLong); - - int oY = 0; - int oX = 0; - - ConvolutionUtils::calcOutSizePool2D(oY, oX, kY, kX, sY, sX, pY, pX, dY, dX, inY, inX, isSameMode); - - if (isSameMode) - ConvolutionUtils::calcPadding2D(pY, pX, oY, oX, inY, inX, kY, kX, sY, sX, dY, dX); - - zShape[0] = 6; - zShape[1] = bS; - zShape[2] = iD; - zShape[3] = kY; - zShape[4] = kX; - zShape[5] = oY; - zShape[6] = oX; - - zShape[shape::shapeInfoLength(zShape) - 2] = 1; - zShape[shape::shapeInfoLength(zShape) - 1] = 99; - - ShapeUtils::updateStridesAndType(zShape, inShape, 'c'); - - return SHAPELIST(CONSTANT(zShape)); - } - - CUSTOM_OP_IMPL(im2col_bp, 2, 1, false, 0, 9) { - auto input = INPUT_VARIABLE(0); - auto gradAtOutput = INPUT_VARIABLE(1); - auto z = OUTPUT_NULLIFIED(0); - - REQUIRE_TRUE(input->rankOf() == 4, 0, "im2col_bp input should be 4D, but got %i instead", input->rankOf()); - REQUIRE_TRUE(gradAtOutput->rankOf() == 6, 0, "im2col_bp gradient at output (input idx 1) should be 6D, but got %i instead", gradAtOutput->rankOf()); - REQUIRE_TRUE(z->rankOf() == 4, 0, "im2col_bp output (grad at input) should be 4D, but got %i instead", z->rankOf()); - - int kernelHeight = INT_ARG(0); - int kernelWidth = INT_ARG(1); - int strideY = INT_ARG(2); - int strideX = INT_ARG(3); - int pH = INT_ARG(4); - int pW = INT_ARG(5); - int dY = INT_ARG(6); //Dilation, height/y dimension - int dX = INT_ARG(7); //Dilation, width/x dimension - bool isSameMode = INT_ARG(8) > 0; - double zeroPadVal = 0.0; - if (block.getTArguments()->size() > 0) - zeroPadVal = T_ARG(0); - - //Assuming NCHW format here - int imgH = input->sizeAt(2); - int imgW = input->sizeAt(3); - - LaunchContext* ctx = block.launchContext(); - // FIXME:: all helpers should accept NDArray - ops::helpers::col2im(*ctx, *gradAtOutput, *z, strideY, strideX, pH, pW, imgH, imgW, dY, dX); - - return Status::OK(); - } - - DECLARE_TYPES(im2col) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedOutputTypes(0, DataType::INHERIT) - ->setSameMode(true); - } - - DECLARE_TYPES(im2col_bp) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedOutputTypes(0, DataType::INHERIT) - ->setSameMode(true); - } - - DECLARE_SHAPE_FN(im2col_bp) { - Nd4jLong *inShape; - COPY_SHAPE(inputShape->at(0), inShape); - - return SHAPELIST(CONSTANT(inShape)); - } - } +namespace ops { +CUSTOM_OP_IMPL(im2col, 1, 1, false, 0, 9) { + auto x = INPUT_VARIABLE(0); + auto z = OUTPUT_NULLIFIED(0); + + REQUIRE_TRUE(x->rankOf() == 4, 0, + "im2col input should be 4D, but got %i instead", x->rankOf()); + REQUIRE_TRUE(z->rankOf() == 6, 0, + "im2col output should be 6D, but got %i instead", z->rankOf()); + + int kernelHeight = INT_ARG(0); + int kernelWidth = INT_ARG(1); + int strideY = INT_ARG(2); + int strideX = INT_ARG(3); + int padHeight = INT_ARG(4); + int padWidth = INT_ARG(5); + int dY = INT_ARG(6); // Dilation, height/y dimension + int dX = INT_ARG(7); // Dilation, width/x dimension + bool isSameMode = INT_ARG(8) > 0; + double zeroPadVal = 0.0; + if (block.numT() > 0) zeroPadVal = T_ARG(0); + + // FIXME: zeropad value is void + LaunchContext* ctx = block.launchContext(); + sd::ops::helpers::im2col( + *ctx, *x, *z, kernelHeight, kernelWidth, strideY, strideX, padHeight, + padWidth, dY, dX, + NDArrayFactory::create(zeroPadVal, block.launchContext())); + + return Status::OK(); +} + +DECLARE_SHAPE_FN(im2col) { + auto inShape = inputShape->at(0); + + int bS = shape::shapeOf(inShape)[0]; + int iD = shape::shapeOf(inShape)[1]; + int inY = shape::shapeOf(inShape)[2]; + int inX = shape::shapeOf(inShape)[3]; + + int kY = INT_ARG(0); + int kX = INT_ARG(1); + int sY = INT_ARG(2); + int sX = INT_ARG(3); + int pY = INT_ARG(4); + int pX = INT_ARG(5); + int dY = INT_ARG(6); // Dilation, height/y dimension + int dX = INT_ARG(7); // Dilation, width/x dimension + bool isSameMode = INT_ARG(8) > 0; + + // output is always 6d for im2col + Nd4jLong* zShape; + ALLOCATE(zShape, block.workspace(), shape::shapeInfoLength(6), Nd4jLong); + + int oY = 0; + int oX = 0; + + ConvolutionUtils::calcOutSizePool2D(oY, oX, kY, kX, sY, sX, pY, pX, dY, dX, + inY, inX, isSameMode); + + if (isSameMode) + ConvolutionUtils::calcPadding2D(pY, pX, oY, oX, inY, inX, kY, kX, sY, sX, + dY, dX); + + zShape[0] = 6; + zShape[1] = bS; + zShape[2] = iD; + zShape[3] = kY; + zShape[4] = kX; + zShape[5] = oY; + zShape[6] = oX; + + zShape[shape::shapeInfoLength(zShape) - 2] = 1; + zShape[shape::shapeInfoLength(zShape) - 1] = 99; + + ShapeUtils::updateStridesAndType(zShape, inShape, 'c'); + + return SHAPELIST(CONSTANT(zShape)); +} + +CUSTOM_OP_IMPL(im2col_bp, 2, 1, false, 0, 9) { + auto input = INPUT_VARIABLE(0); + auto gradAtOutput = INPUT_VARIABLE(1); + auto z = OUTPUT_NULLIFIED(0); + + REQUIRE_TRUE(input->rankOf() == 4, 0, + "im2col_bp input should be 4D, but got %i instead", + input->rankOf()); + REQUIRE_TRUE(gradAtOutput->rankOf() == 6, 0, + "im2col_bp gradient at output (input idx 1) should be 6D, but " + "got %i instead", + gradAtOutput->rankOf()); + REQUIRE_TRUE( + z->rankOf() == 4, 0, + "im2col_bp output (grad at input) should be 4D, but got %i instead", + z->rankOf()); + + int kernelHeight = INT_ARG(0); + int kernelWidth = INT_ARG(1); + int strideY = INT_ARG(2); + int strideX = INT_ARG(3); + int pH = INT_ARG(4); + int pW = INT_ARG(5); + int dY = INT_ARG(6); // Dilation, height/y dimension + int dX = INT_ARG(7); // Dilation, width/x dimension + bool isSameMode = INT_ARG(8) > 0; + double zeroPadVal = 0.0; + if (block.numT() > 0) zeroPadVal = T_ARG(0); + + // Assuming NCHW format here + int imgH = input->sizeAt(2); + int imgW = input->sizeAt(3); + + LaunchContext* ctx = block.launchContext(); + // FIXME:: all helpers should accept NDArray + ops::helpers::col2im(*ctx, *gradAtOutput, *z, strideY, strideX, pH, pW, imgH, + imgW, dY, dX); + + return Status::OK(); +} + +DECLARE_TYPES(im2col) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedOutputTypes(0, DataType::INHERIT) + ->setSameMode(true); +} + +DECLARE_TYPES(im2col_bp) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedOutputTypes(0, DataType::INHERIT) + ->setSameMode(true); +} + +DECLARE_SHAPE_FN(im2col_bp) { + Nd4jLong* inShape; + COPY_SHAPE(inputShape->at(0), inShape); + + return SHAPELIST(CONSTANT(inShape)); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/nn/convo/ismax.cpp b/libnd4j/include/ops/declarable/generic/nn/convo/ismax.cpp index d786504adb72..247de31c7792 100644 --- a/libnd4j/include/ops/declarable/generic/nn/convo/ismax.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/convo/ismax.cpp @@ -26,30 +26,28 @@ #include namespace sd { -namespace ops { +namespace ops { CONFIGURABLE_OP_IMPL(ismax, 1, 1, true, 0, -2) { - - auto x = INPUT_VARIABLE(0); - auto z = OUTPUT_VARIABLE(0); - auto dimensions = *(block.getIArguments()); // argI - if (x->isScalar()) - z->assign(1); - else - helpers::ismax(block.launchContext(), x, z, dimensions); - - return Status::OK(); + auto x = INPUT_VARIABLE(0); + auto z = OUTPUT_VARIABLE(0); + auto dimensions = block.getIArguments(); // argI + if (x->isScalar()) + z->assign(1); + else + helpers::ismax(block.launchContext(), x, z, dimensions); + + return Status::OK(); } DECLARE_SYN(IsMax, ismax); - DECLARE_TYPES(ismax) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedOutputTypes(0, DataType::ANY); - - } - -} +DECLARE_TYPES(ismax) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedOutputTypes(0, DataType::ANY); } +} // namespace ops +} // namespace sd + #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/nn/convo/pointwiseConv2d.cpp b/libnd4j/include/ops/declarable/generic/nn/convo/pointwiseConv2d.cpp index 0f7bdde10a84..288b276dbeb3 100644 --- a/libnd4j/include/ops/declarable/generic/nn/convo/pointwiseConv2d.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/convo/pointwiseConv2d.cpp @@ -22,92 +22,139 @@ #include namespace sd { -namespace ops { - - +namespace ops { CUSTOM_OP_IMPL(pointwise_conv2d, 2, 1, false, 0, 0) { - - auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - auto weights = INPUT_VARIABLE(1); // [1, 1, iC, oC], [oC, iC, 1, 1], [oC, 1, 1, iC] - auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] - - auto output = OUTPUT_VARIABLE(0); // [bS, iH, iW, oC] (NHWC) or [bS, oC, iH, iW] (NCHW) - - REQUIRE_TRUE(input->rankOf() == 4, 0, "CUSTOM POINTWISECONV2D OP: rank of input array must be equal to 4, but got %i instead !", input->rankOf()); - REQUIRE_TRUE(weights->rankOf() == 4, 0, "CUSTOM POINTWISECONV2D OP: rank of weights array must be equal to 4, but got %i instead !", weights->rankOf()); - if(bias) - REQUIRE_TRUE(bias->rankOf() <= 2, 0, "CUSTOM POINTWISECONV2D OP: rank of biases array must be equal <= 2, but got %i instead !", bias->rankOf()); - - int kH = 1; // filter(kernel) height - int kW = 1; // filter(kernel) width - int sH = 1; // strides height - int sW = 1; // strides width - int pH = 0; // paddings height - int pW = 0; // paddings width - int dH = 1; // dilations height - int dW = 1; // dilations width - int isNCHW = block.getIArguments()->size() > 0 ? !INT_ARG(0) : 1; // INT_ARG(0): 0-NCHW, 1-NHWC - int wFormat = block.getIArguments()->size() > 1 ? INT_ARG(1) : 0; // 0 - [1, 1, iC, oC], 1 - [oC, iC, 1, 1], 2 - [oC, 1, 1, iC] - - int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; - int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); - - std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, 1, 1, iC, oC); - REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM POINTWISECONV2D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); - if (bias) - REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM POINTWISECONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); - - ConvolutionUtils::conv2d(block, input, weights, bias, output, kH,kW, sH,sW, pH,pW, dH,dW, 1/*isSameMode*/, isNCHW, wFormat); - - return Status::OK(); + auto input = + INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + auto weights = + INPUT_VARIABLE(1); // [1, 1, iC, oC], [oC, iC, 1, 1], [oC, 1, 1, iC] + auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] + + auto output = + OUTPUT_VARIABLE(0); // [bS, iH, iW, oC] (NHWC) or [bS, oC, iH, iW] (NCHW) + + REQUIRE_TRUE(input->rankOf() == 4, 0, + "CUSTOM POINTWISECONV2D OP: rank of input array must be equal " + "to 4, but got %i instead !", + input->rankOf()); + REQUIRE_TRUE(weights->rankOf() == 4, 0, + "CUSTOM POINTWISECONV2D OP: rank of weights array must be equal " + "to 4, but got %i instead !", + weights->rankOf()); + if (bias) + REQUIRE_TRUE(bias->rankOf() <= 2, 0, + "CUSTOM POINTWISECONV2D OP: rank of biases array must be " + "equal <= 2, but got %i instead !", + bias->rankOf()); + + int kH = 1; // filter(kernel) height + int kW = 1; // filter(kernel) width + int sH = 1; // strides height + int sW = 1; // strides width + int pH = 0; // paddings height + int pW = 0; // paddings width + int dH = 1; // dilations height + int dW = 1; // dilations width + int isNCHW = + block.numI() > 0 ? !INT_ARG(0) : 1; // INT_ARG(0): 0-NCHW, 1-NHWC + int wFormat = + block.numI() > 1 + ? INT_ARG(1) + : 0; // 0 - [1, 1, iC, oC], 1 - [oC, iC, 1, 1], 2 - [oC, 1, 1, iC] + + int bS, iC, iH, iW, oC, oH, + oW; // batch size, input channels, input height/width, output channels, + // output height/width; + int indIOioC, indIiH, indWoC, indWiC, indWkH, + indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, + indIiH, indWiC, indWoC, indWkH, indOoH); + + std::vector expectedWeightsShape = + ConvolutionUtils::expectWeightsShape(wFormat, 1, 1, iC, oC); + REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, + "CUSTOM POINTWISECONV2D OP: wrong shape of weights array, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), + ShapeUtils::shapeAsString(weights).c_str()); + if (bias) + REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, + "CUSTOM POINTWISECONV2D OP: wrong shape of array with biases, " + "expected rank, length: <=2, %i, but got %i, %i instead !", + oC, bias->rankOf(), bias->lengthOf()); + + ConvolutionUtils::conv2d(block, input, weights, bias, output, kH, kW, sH, sW, + pH, pW, dH, dW, 1 /*isSameMode*/, isNCHW, wFormat); + + return Status::OK(); } - DECLARE_TYPES(pointwise_conv2d) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } - +DECLARE_TYPES(pointwise_conv2d) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} DECLARE_SHAPE_FN(pointwise_conv2d) { - auto inputShapeInfo = inputShape->at(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - auto weightsShapeInfo = inputShape->at(1); // [1, 1, iC, oC], [oC, iC, 1, 1], [oC, 1, 1, iC] - auto biasShapeInfo = block.width() > 2 ? inputShape->at(2) : nullptr; // [oC] - - const int rank = 4; - REQUIRE_TRUE(inputShapeInfo[0] == rank, 0, "CUSTOM POINTWISECONV2D OP: rank of input array must be equal to %i, but got %i instead !", rank, inputShapeInfo[0]); - REQUIRE_TRUE(weightsShapeInfo[0] == rank, 0, "CUSTOM POINTWISECONV2D OP: rank of weights array must be equal to %i, but got %i instead !", rank, weightsShapeInfo[0]); - - int isNCHW = block.getIArguments()->size() > 0 ? !INT_ARG(0) : 1; // INT_ARG(0): 0-NCHW, 1-NHWC - int wFormat = block.getIArguments()->size() > 1 ? INT_ARG(1) : 0; // 0 - [1, 1, iC, oC], 1 - [oC, iC, 1, 1], 2 - [oC, 1, 1, iC] - - int indIOioC, indWoC(0 == wFormat ? 3 : 0); - if(!isNCHW) - indIOioC = 3; - else - indIOioC = 1; - - const int bS = inputShapeInfo[1]; // batch size - const int iC = inputShapeInfo[indIOioC+1]; // input channels - const int oC = weightsShapeInfo[indWoC+1]; // output channels - - std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, 1, 1, iC, oC); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsShapeInfo, expectedWeightsShape), 0, "POINTWISECONV2D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); - if (biasShapeInfo) - REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, "POINTWISECONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo)); - - auto outputShapeInfo = ShapeBuilders::copyShapeInfoAndType(inputShapeInfo, weightsShapeInfo, true, block.getWorkspace()); - - // do not forget to put oC instead of iC in outputShapeInfo - outputShapeInfo[indIOioC + 1] = oC; - - shape::updateStrides(outputShapeInfo, shape::order(inputShapeInfo)); - - return SHAPELIST(CONSTANT(outputShapeInfo)); + auto inputShapeInfo = + inputShape->at(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + auto weightsShapeInfo = + inputShape->at(1); // [1, 1, iC, oC], [oC, iC, 1, 1], [oC, 1, 1, iC] + auto biasShapeInfo = block.width() > 2 ? inputShape->at(2) : nullptr; // [oC] + + const int rank = 4; + REQUIRE_TRUE(inputShapeInfo[0] == rank, 0, + "CUSTOM POINTWISECONV2D OP: rank of input array must be equal " + "to %i, but got %i instead !", + rank, inputShapeInfo[0]); + REQUIRE_TRUE(weightsShapeInfo[0] == rank, 0, + "CUSTOM POINTWISECONV2D OP: rank of weights array must be equal " + "to %i, but got %i instead !", + rank, weightsShapeInfo[0]); + + int isNCHW = + block.numI() > 0 ? !INT_ARG(0) : 1; // INT_ARG(0): 0-NCHW, 1-NHWC + int wFormat = + block.numI() > 1 + ? INT_ARG(1) + : 0; // 0 - [1, 1, iC, oC], 1 - [oC, iC, 1, 1], 2 - [oC, 1, 1, iC] + + int indIOioC, indWoC(0 == wFormat ? 3 : 0); + if (!isNCHW) + indIOioC = 3; + else + indIOioC = 1; + + const int bS = inputShapeInfo[1]; // batch size + const int iC = inputShapeInfo[indIOioC + 1]; // input channels + const int oC = weightsShapeInfo[indWoC + 1]; // output channels + + std::vector expectedWeightsShape = + ConvolutionUtils::expectWeightsShape(wFormat, 1, 1, iC, oC); + REQUIRE_TRUE( + ShapeUtils::areShapesEqual(weightsShapeInfo, expectedWeightsShape), 0, + "POINTWISECONV2D OP: wrong shape of weights array, expected is %s, but " + "got %s instead !", + ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), + ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); + if (biasShapeInfo) + REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, + "POINTWISECONV2D OP: wrong shape of array with biases, " + "expected rank, length: <=2, %i, but got %i, %i instead !", + oC, biasShapeInfo[0], shape::length(biasShapeInfo)); + + auto outputShapeInfo = ShapeBuilders::copyShapeInfoAndType( + inputShapeInfo, weightsShapeInfo, true, block.workspace()); + + // do not forget to put oC instead of iC in outputShapeInfo + outputShapeInfo[indIOioC + 1] = oC; + + shape::updateStrides(outputShapeInfo, shape::order(inputShapeInfo)); + + return SHAPELIST(CONSTANT(outputShapeInfo)); } - -} -} +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/generic/nn/convo/sconv2d.cpp b/libnd4j/include/ops/declarable/generic/nn/convo/sconv2d.cpp index d887d7c2ab00..78c05e01e07f 100644 --- a/libnd4j/include/ops/declarable/generic/nn/convo/sconv2d.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/convo/sconv2d.cpp @@ -24,379 +24,584 @@ #include #include + #include namespace sd { -namespace ops { - +namespace ops { CUSTOM_OP_IMPL(sconv2d, 2, 1, false, 0, 9) { + NDArray *input = + INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + NDArray *weightsDepth = INPUT_VARIABLE( + 1); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] + NDArray *weightsPoint = + nullptr; // [1, 1, iC*mC, oC], [oC, iC*mC, 1, 1], [oC, 1, 1, iC*mC] + NDArray *bias = nullptr; // [oC], if weightsPoint=nullptr then oC = iC*mC + + NDArray *output = OUTPUT_NULLIFIED( + 0); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) + + if (block.width() == 3) { + if ((INPUT_VARIABLE(2))->rankOf() == 4) + weightsPoint = INPUT_VARIABLE(2); + else + bias = INPUT_VARIABLE(2); + } else if (block.width() == 4) { + weightsPoint = INPUT_VARIABLE(2); + bias = INPUT_VARIABLE(3); + } + + REQUIRE_TRUE(input->rankOf() == 4, 0, + " SCONV2D OP: rank of input array must be equal to 4, but got " + "%i instead !", + input->rankOf()); + REQUIRE_TRUE(weightsDepth->rankOf() == 4, 0, + " SCONV2D OP: rank of weightsDepth array must be equal to 4, " + "but got %i instead !", + weightsDepth->rankOf()); + if (weightsPoint) + REQUIRE_TRUE(weightsPoint->rankOf() == 4, 0, + " SCONV2D OP: rank of weightsPoint array must be equal to 4, " + "but got %i instead !", + weightsPoint->rankOf()); + if (bias) + REQUIRE_TRUE(bias->rankOf() == 1 || bias->rankOf() == 2, 0, + " SCONV2D OP: rank of biases array must be equal to 1 or 2, " + "but got %i instead !", + bias->rankOf()); + ; + + int kH = INT_ARG(0); // filter(kernel) height + int kW = INT_ARG(1); // filter(kernel) width + int sH = INT_ARG(2); // strides height + int sW = INT_ARG(3); // strides width + int pH = INT_ARG(4); // paddings height + int pW = INT_ARG(5); // paddings width + int dH = INT_ARG(6); // dilations height + int dW = INT_ARG(7); // dilations width + int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME + int isNCHW = + block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC + int wFormat = + block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, + // iC, kH, kW], 2 - [mC, kH, kW, iC] + + int bS, iC, iH, iW, mC, oC, oH, + oW; // batch size, input channels, input height/width, channels + // multiplier, output channels, output height/width + int indIOioC, indIiH, indWmC, indWiC, indWkH, + indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, + indIiH, indWiC, indWmC, indWkH, indOoH); + mC = weightsDepth->sizeAt(indWmC); // channels multiplier + + std::vector expectedWeightsDShape = + ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, mC); + REQUIRE_TRUE(weightsDepth->isSameShape(expectedWeightsDShape), 0, + " SCONV2D OP: wrong shape of weightsDepth array, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString(expectedWeightsDShape).c_str(), + ShapeUtils::shapeAsString(weightsDepth).c_str()); + if (weightsPoint) { + std::vector expectedWeightsPShape = + ConvolutionUtils::expectWeightsShape(wFormat, 1, 1, iC * mC, oC); + REQUIRE_TRUE(weightsPoint->isSameShape(expectedWeightsPShape), 0, + " SCONV2D OP: wrong shape of weightsPoint array, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString(expectedWeightsPShape).c_str(), + ShapeUtils::shapeAsString(weightsPoint).c_str()); + } + if (bias) + REQUIRE_TRUE(oC == bias->lengthOf(), 0, + " SCONV2D OP: length of bias array must be equal to " + "outChannels, but got %i instead", + bias->lengthOf()); + + if (iC == 1) { + nd4j_debug( + "SCONV2D OP: for input_channels = 1 this op is equivalent to standard " + "conv2d\n", + ""); + ConvolutionUtils::conv2d(block, input, weightsDepth, bias, output, kH, kW, + sH, sW, pH, pW, dH, dW, isSameMode, isNCHW, + wFormat); + return Status::OK(); + } - NDArray *input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - NDArray *weightsDepth = INPUT_VARIABLE(1); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] - NDArray *weightsPoint = nullptr; // [1, 1, iC*mC, oC], [oC, iC*mC, 1, 1], [oC, 1, 1, iC*mC] - NDArray *bias = nullptr; // [oC], if weightsPoint=nullptr then oC = iC*mC - - NDArray *output = OUTPUT_NULLIFIED(0); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) - - if(block.width() == 3) { - if((INPUT_VARIABLE(2))->rankOf() == 4) - weightsPoint = INPUT_VARIABLE(2); - else - bias = INPUT_VARIABLE(2); - } - else if(block.width() == 4) { - weightsPoint = INPUT_VARIABLE(2); - bias = INPUT_VARIABLE(3); - } - - REQUIRE_TRUE(input->rankOf() == 4, 0, " SCONV2D OP: rank of input array must be equal to 4, but got %i instead !", input->rankOf()); - REQUIRE_TRUE(weightsDepth->rankOf() == 4, 0, " SCONV2D OP: rank of weightsDepth array must be equal to 4, but got %i instead !", weightsDepth->rankOf()); - if(weightsPoint) - REQUIRE_TRUE(weightsPoint->rankOf() == 4, 0, " SCONV2D OP: rank of weightsPoint array must be equal to 4, but got %i instead !", weightsPoint->rankOf()); - if(bias) - REQUIRE_TRUE(bias->rankOf() == 1 || bias->rankOf() == 2, 0, " SCONV2D OP: rank of biases array must be equal to 1 or 2, but got %i instead !", bias->rankOf());; - - int kH = INT_ARG(0); // filter(kernel) height - int kW = INT_ARG(1); // filter(kernel) width - int sH = INT_ARG(2); // strides height - int sW = INT_ARG(3); // strides width - int pH = INT_ARG(4); // paddings height - int pW = INT_ARG(5); // paddings width - int dH = INT_ARG(6); // dilations height - int dW = INT_ARG(7); // dilations width - int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME - int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC - int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - [mC, kH, kW, iC] - - - int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier, output channels, output height/width - int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); - mC = weightsDepth->sizeAt(indWmC); // channels multiplier - - std::vector expectedWeightsDShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, mC); - REQUIRE_TRUE(weightsDepth->isSameShape(expectedWeightsDShape), 0, " SCONV2D OP: wrong shape of weightsDepth array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsDShape).c_str(), ShapeUtils::shapeAsString(weightsDepth).c_str()); - if(weightsPoint) { - std::vector expectedWeightsPShape = ConvolutionUtils::expectWeightsShape(wFormat, 1, 1, iC*mC, oC); - REQUIRE_TRUE(weightsPoint->isSameShape(expectedWeightsPShape), 0, " SCONV2D OP: wrong shape of weightsPoint array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsPShape).c_str(), ShapeUtils::shapeAsString(weightsPoint).c_str()); - } - if (bias) - REQUIRE_TRUE(oC == bias->lengthOf(), 0, " SCONV2D OP: length of bias array must be equal to outChannels, but got %i instead", bias->lengthOf()); - - if (iC == 1) { - nd4j_debug("SCONV2D OP: for input_channels = 1 this op is equivalent to standard conv2d\n",""); - ConvolutionUtils::conv2d(block, input, weightsDepth, bias, output, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW, wFormat); - return Status::OK(); - } - - ConvolutionUtils::sconv2d(block, input, weightsDepth, weightsPoint, bias, output, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW, wFormat); + ConvolutionUtils::sconv2d(block, input, weightsDepth, weightsPoint, bias, + output, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, + isNCHW, wFormat); - return Status::OK(); + return Status::OK(); } - DECLARE_TYPES(sconv2d) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } - +DECLARE_TYPES(sconv2d) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} DECLARE_SHAPE_FN(sconv2d) { - - auto inputShapeInfo = inputShape->at(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - auto weightsDShapeInfo = inputShape->at(1); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] - Nd4jLong const* weightsPShapeInfo = nullptr; // [1, 1, iC*mC, oC], [oC, iC*mC, 1, 1], [oC, 1, 1, iC*mC] - Nd4jLong const* biasShapeInfo = nullptr; // [oC], oC = iC*mC if weightsPoint=nullptr - - if(block.width() == 3) - if(inputShape->at(2)[0] == 4) - weightsPShapeInfo = inputShape->at(2); - else - biasShapeInfo = inputShape->at(2); - else if(block.width() == 4) { - weightsPShapeInfo = inputShape->at(2); - biasShapeInfo = inputShape->at(3); - } - - const int rank = 4; - REQUIRE_TRUE(inputShapeInfo[0] == rank, 0, "SCONV2D OP: rank of input array must be equal to %i, but got %i instead !", rank, inputShapeInfo[0]); - REQUIRE_TRUE(weightsDShapeInfo[0] == rank, 0, "SCONV2D OP: rank of weightsDepth array must be equal to %i, but got %i instead !", rank, weightsDShapeInfo[0]); - if(weightsPShapeInfo) - REQUIRE_TRUE(weightsPShapeInfo[0] == rank, 0, "SCONV2D OP: rank of weightsPoint array must be equal to %i, but got %i instead !", rank, weightsPShapeInfo[0]); - if(biasShapeInfo) - REQUIRE_TRUE(biasShapeInfo[0] <= 2, 0, "SCONV2D OP: rank of biases array must be <= 2, but got %i instead !", biasShapeInfo[0]);; - - int kH = INT_ARG(0); // filter(kernel) height - int kW = INT_ARG(1); // filter(kernel) width - int sH = INT_ARG(2); // strides height - int sW = INT_ARG(3); // strides width - int pH = INT_ARG(4); // paddings height - int pW = INT_ARG(5); // paddings width - int dH = INT_ARG(6); // dilations height - int dW = INT_ARG(7); // dilations width - int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME - int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW - int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - [mC, kH, kW, iC] - - int indIOioC, indIiH, indWmC(0 == wFormat ? 3 : 0); - if(!isNCHW) { - indIOioC = 3; indIiH = 1; - } - else { - indIOioC = 1; indIiH = 2; - } - - const int bS = inputShapeInfo[1]; // batch size - const int iH = inputShapeInfo[indIiH+1]; // input height - const int iW = inputShapeInfo[indIiH+2]; // input width - const int iC = inputShapeInfo[indIOioC+1]; // input channels - const int mC = weightsDShapeInfo[indWmC+1]; // channel multiplier - const int oC = weightsPShapeInfo ? weightsPShapeInfo[indWmC+1] : iC*mC; // output channels (oC or iC*mC) - - std::vector expectedWeightsDShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, mC); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsDShapeInfo, expectedWeightsDShape), 0, "SCONV2D OP: wrong shape of depth weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsDShape).c_str(), ShapeUtils::shapeAsString(weightsDShapeInfo).c_str()); - if(weightsPShapeInfo) { - std::vector expectedWeightsPShape = ConvolutionUtils::expectWeightsShape(wFormat, 1, 1, iC*mC, oC); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsPShapeInfo, expectedWeightsPShape), 0, "SCONV2D OP: wrong shape of point array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsPShape).c_str(), ShapeUtils::shapeAsString(weightsPShapeInfo).c_str()); - } - if (biasShapeInfo) - REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, "SCONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo)); - - int oH, oW; // output height, width - ConvolutionUtils::calcOutSizePool2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode); - - Nd4jLong* outputShapeInfo = nullptr; - ALLOCATE(outputShapeInfo, block.getWorkspace(), shape::shapeInfoLength(inputShapeInfo), Nd4jLong); - - outputShapeInfo[0] = 4; - outputShapeInfo[1] = bS; - - if (isNCHW) { - outputShapeInfo[2] = oC; - outputShapeInfo[3] = oH; - outputShapeInfo[4] = oW; - } else { - outputShapeInfo[2] = oH; - outputShapeInfo[3] = oW; - outputShapeInfo[4] = oC; - } - - ShapeUtils::updateStridesAndType(outputShapeInfo, weightsDShapeInfo, shape::order(inputShapeInfo)); - - return SHAPELIST(CONSTANT(outputShapeInfo)); + auto inputShapeInfo = + inputShape->at(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + auto weightsDShapeInfo = inputShape->at( + 1); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] + Nd4jLong const *weightsPShapeInfo = + nullptr; // [1, 1, iC*mC, oC], [oC, iC*mC, 1, 1], [oC, 1, 1, iC*mC] + Nd4jLong const *biasShapeInfo = + nullptr; // [oC], oC = iC*mC if weightsPoint=nullptr + + if (block.width() == 3) + if (inputShape->at(2)[0] == 4) + weightsPShapeInfo = inputShape->at(2); + else + biasShapeInfo = inputShape->at(2); + else if (block.width() == 4) { + weightsPShapeInfo = inputShape->at(2); + biasShapeInfo = inputShape->at(3); + } + + const int rank = 4; + REQUIRE_TRUE(inputShapeInfo[0] == rank, 0, + "SCONV2D OP: rank of input array must be equal to %i, but got " + "%i instead !", + rank, inputShapeInfo[0]); + REQUIRE_TRUE(weightsDShapeInfo[0] == rank, 0, + "SCONV2D OP: rank of weightsDepth array must be equal to %i, " + "but got %i instead !", + rank, weightsDShapeInfo[0]); + if (weightsPShapeInfo) + REQUIRE_TRUE(weightsPShapeInfo[0] == rank, 0, + "SCONV2D OP: rank of weightsPoint array must be equal to %i, " + "but got %i instead !", + rank, weightsPShapeInfo[0]); + if (biasShapeInfo) + REQUIRE_TRUE( + biasShapeInfo[0] <= 2, 0, + "SCONV2D OP: rank of biases array must be <= 2, but got %i instead !", + biasShapeInfo[0]); + ; + + int kH = INT_ARG(0); // filter(kernel) height + int kW = INT_ARG(1); // filter(kernel) width + int sH = INT_ARG(2); // strides height + int sW = INT_ARG(3); // strides width + int pH = INT_ARG(4); // paddings height + int pW = INT_ARG(5); // paddings width + int dH = INT_ARG(6); // dilations height + int dW = INT_ARG(7); // dilations width + int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME + int isNCHW = + block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW + int wFormat = + block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, + // iC, kH, kW], 2 - [mC, kH, kW, iC] + + int indIOioC, indIiH, indWmC(0 == wFormat ? 3 : 0); + if (!isNCHW) { + indIOioC = 3; + indIiH = 1; + } else { + indIOioC = 1; + indIiH = 2; + } + + const int bS = inputShapeInfo[1]; // batch size + const int iH = inputShapeInfo[indIiH + 1]; // input height + const int iW = inputShapeInfo[indIiH + 2]; // input width + const int iC = inputShapeInfo[indIOioC + 1]; // input channels + const int mC = weightsDShapeInfo[indWmC + 1]; // channel multiplier + const int oC = weightsPShapeInfo ? weightsPShapeInfo[indWmC + 1] + : iC * mC; // output channels (oC or iC*mC) + + std::vector expectedWeightsDShape = + ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, mC); + REQUIRE_TRUE( + ShapeUtils::areShapesEqual(weightsDShapeInfo, expectedWeightsDShape), 0, + "SCONV2D OP: wrong shape of depth weights array, expected is %s, but got " + "%s instead !", + ShapeUtils::shapeAsString(expectedWeightsDShape).c_str(), + ShapeUtils::shapeAsString(weightsDShapeInfo).c_str()); + if (weightsPShapeInfo) { + std::vector expectedWeightsPShape = + ConvolutionUtils::expectWeightsShape(wFormat, 1, 1, iC * mC, oC); + REQUIRE_TRUE( + ShapeUtils::areShapesEqual(weightsPShapeInfo, expectedWeightsPShape), 0, + "SCONV2D OP: wrong shape of point array, expected is %s, but got %s " + "instead !", + ShapeUtils::shapeAsString(expectedWeightsPShape).c_str(), + ShapeUtils::shapeAsString(weightsPShapeInfo).c_str()); + } + if (biasShapeInfo) + REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, + "SCONV2D OP: wrong shape of array with biases, expected rank, " + "length: <=2, %i, but got %i, %i instead !", + oC, biasShapeInfo[0], shape::length(biasShapeInfo)); + + int oH, oW; // output height, width + ConvolutionUtils::calcOutSizePool2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, + iH, iW, isSameMode); + + Nd4jLong *outputShapeInfo = nullptr; + ALLOCATE(outputShapeInfo, block.workspace(), + shape::shapeInfoLength(inputShapeInfo), Nd4jLong); + + outputShapeInfo[0] = 4; + outputShapeInfo[1] = bS; + + if (isNCHW) { + outputShapeInfo[2] = oC; + outputShapeInfo[3] = oH; + outputShapeInfo[4] = oW; + } else { + outputShapeInfo[2] = oH; + outputShapeInfo[3] = oW; + outputShapeInfo[4] = oC; + } + + ShapeUtils::updateStridesAndType(outputShapeInfo, weightsDShapeInfo, + shape::order(inputShapeInfo)); + + return SHAPELIST(CONSTANT(outputShapeInfo)); } - DECLARE_TYPES(sconv2d_bp) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } +DECLARE_TYPES(sconv2d_bp) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} //////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(sconv2d_bp, 3, 2, false, 0, 9) { - - NDArray *input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - NDArray *gradO = INPUT_VARIABLE(1); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next - NDArray *weightsDepth = INPUT_VARIABLE(2); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] - NDArray *weightsPoint = nullptr; // [1, 1, iC*mC, oC], [oC, iC*mC, 1, 1], [oC, 1, 1, iC*mC] - NDArray *bias = nullptr; // [oC], oC = iC*mC if weightsPoint=nullptr - - NDArray *gradI = OUTPUT_NULLIFIED(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon - NDArray *gradWD = OUTPUT_NULLIFIED(1); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] - NDArray *gradWP = nullptr; // [1, 1, iC*mC, oC], [oC, iC*mC, 1, 1], [oC, 1, 1, iC*mC] - NDArray *gradB = nullptr; // [oC] - - if(block.width() == 4) { - if((INPUT_VARIABLE(3))->rankOf() == 4) { - weightsPoint = INPUT_VARIABLE(3); - gradWP = OUTPUT_NULLIFIED(2); - } - else { - bias = INPUT_VARIABLE(3); - gradB = OUTPUT_NULLIFIED(2); - } - } - else if(block.width() == 5) { - weightsPoint = INPUT_VARIABLE(3); - bias = INPUT_VARIABLE(4); - gradWP = OUTPUT_NULLIFIED(2); - gradB = OUTPUT_NULLIFIED(3); - } - - - REQUIRE_TRUE(input->rankOf() == 4, 0, " SCONV2D_BP OP: rank of input array must be equal to 4, but got %i instead !", input->rankOf()); - REQUIRE_TRUE(gradO->rankOf() == 4, 0, " SCONV2D_BP OP: rank of output gradients (next epsilon) array must be equal to 4, but got %i instead !", gradO->rankOf()); - REQUIRE_TRUE(weightsDepth->rankOf() == 4, 0, " SCONV2D_BP OP: rank of weightsDepth array must be equal to 4 !, but got %i instead !", weightsDepth->rankOf()); - if(weightsPoint) { - REQUIRE_TRUE(weightsPoint->rankOf() == 4, 0, " SCONV2D_BP OP: rank of weightsPoint array must be equal to 4, but got %i instead !", weightsPoint->rankOf()); - REQUIRE_TRUE(gradWP->rankOf() == 4, 0, " SCONV2D_BP OP: rank of weightsPoint gradients array must be equal to 4, but got %i instead !", gradWP->rankOf()); - } - if(bias) { - REQUIRE_TRUE(bias->rankOf() == 1 || bias->rankOf() == 2, 0, " SCONV2D_BP OP: rank of biases array must be equal to 1 or 2, but got %i instead !", bias->rankOf()); - REQUIRE_TRUE(gradB->rankOf() == 1 || gradB->rankOf() == 2, 0, " SCONV2D_BP OP: rank of biases gradientsarray must be equal to 1 or 2, but got %i instead !", gradB->rankOf()); - } - - int kH = INT_ARG(0); // filter(kernel) height - int kW = INT_ARG(1); // filter(kernel) width - int sH = INT_ARG(2); // strides height - int sW = INT_ARG(3); // strides width - int pH = INT_ARG(4); // paddings height - int pW = INT_ARG(5); // paddings width - int dH = INT_ARG(6); // dilations height - int dW = INT_ARG(7); // dilations width - int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME - int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC - int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - [mC, kH, kW, iC] - - int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier, output channels, output height/width - int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); - mC = weightsDepth->sizeAt(indWmC); // channels multiplier - - std::vector expectedWeightsDShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, mC); - REQUIRE_TRUE(weightsDepth->isSameShape(expectedWeightsDShape), 0, " SCONV2D_BP OP: wrong shape of weightsDepth array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsDShape).c_str(), ShapeUtils::shapeAsString(weightsDepth).c_str()); - REQUIRE_TRUE(gradWD->isSameShape(expectedWeightsDShape), 0, " SCONV2D_BP OP: wrong shape of gradWD array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsDShape).c_str(), ShapeUtils::shapeAsString(gradWD).c_str()); - if(weightsPoint) { - std::vector expectedWeightsPShape = ConvolutionUtils::expectWeightsShape(wFormat, 1, 1, iC*mC, oC); - REQUIRE_TRUE(weightsPoint->isSameShape(expectedWeightsPShape), 0, " SCONV2D_BP OP: wrong shape of weightsPoint array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsPShape).c_str(), ShapeUtils::shapeAsString(weightsPoint).c_str()); - REQUIRE_TRUE(gradWP->isSameShape(expectedWeightsPShape), 0, " SCONV2D_BP OP: wrong shape of gradWP array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsPShape).c_str(), ShapeUtils::shapeAsString(gradWP).c_str()); - } - if (bias) { - REQUIRE_TRUE(oC == bias->lengthOf(), 0, " SCONV2D_BP OP: length of bias array must be equal to outChannels, but got %i instead", bias->lengthOf()); - REQUIRE_TRUE(oC == gradB->lengthOf(), 0, " SCONV2D_BP OP: length of biases gradients array must be equal to outChannels, but got %i instead", gradB->lengthOf()); - } - - // if (iC == 1) { - // nd4j_debug(" SCONV2D_BP OP: for input_channels=1 this op is equivalent to standard conv2d_bp \n",""); - // sd::ops::conv2d_bp op; - // return op.execute(&block); - // } - - // ----- if weightsPoint is present, perform pointwise backprop first and calculate gradWP at this step ----- // - if (weightsPoint){ - - auto resultFFShape = isNCHW ? std::vector({bS, mC*iC, oH, oW}) : std::vector({bS, oH, oW, mC*iC}); - auto resultFF = NDArrayFactory::create_(input->ordering(), resultFFShape, input->dataType(), block.launchContext()); - ConvolutionUtils::sconv2d(block, input, weightsDepth, nullptr, nullptr, resultFF, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW, wFormat); - - auto gradIDepthShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC*mC,oH,oW, 0,indIOioC,indIiH,indIiH+1}); - auto gradIDepth = NDArrayFactory::create_(resultFF->ordering(), gradIDepthShape, resultFF->dataType(), block.launchContext()); // [bS, oH, oW, iC*mC] (NHWC) or [bS, iC*mC, oH, oW] (NCHW) - - ConvolutionUtils::conv2dBP(block, resultFF, weightsPoint, bias, gradO, gradIDepth, gradWP, gradB, 1,1, 1,1, 0,0, 1,1, isSameMode, isNCHW, wFormat); // in this case oH=iH and oW=iW - - gradO = gradIDepth; - bias = gradB = nullptr; // if pointwise backprop was done then don't calculate gradB at depthwise_conv2d_bp step - - delete resultFF; + NDArray *input = + INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + NDArray *gradO = INPUT_VARIABLE( + 1); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next + NDArray *weightsDepth = INPUT_VARIABLE( + 2); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] + NDArray *weightsPoint = + nullptr; // [1, 1, iC*mC, oC], [oC, iC*mC, 1, 1], [oC, 1, 1, iC*mC] + NDArray *bias = nullptr; // [oC], oC = iC*mC if weightsPoint=nullptr + + NDArray *gradI = OUTPUT_NULLIFIED( + 0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon + NDArray *gradWD = OUTPUT_NULLIFIED( + 1); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] + NDArray *gradWP = + nullptr; // [1, 1, iC*mC, oC], [oC, iC*mC, 1, 1], [oC, 1, 1, iC*mC] + NDArray *gradB = nullptr; // [oC] + + if (block.width() == 4) { + if ((INPUT_VARIABLE(3))->rankOf() == 4) { + weightsPoint = INPUT_VARIABLE(3); + gradWP = OUTPUT_NULLIFIED(2); + } else { + bias = INPUT_VARIABLE(3); + gradB = OUTPUT_NULLIFIED(2); } - - // ----- apply depthwise_conv2d_bp ----- // - ConvolutionUtils::depthwiseConv2dBP(block, input, weightsDepth, bias, gradO, gradI, gradWD, gradB, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW, wFormat); - - if(weightsPoint) - delete gradO; - - return Status::OK(); + } else if (block.width() == 5) { + weightsPoint = INPUT_VARIABLE(3); + bias = INPUT_VARIABLE(4); + gradWP = OUTPUT_NULLIFIED(2); + gradB = OUTPUT_NULLIFIED(3); + } + + REQUIRE_TRUE(input->rankOf() == 4, 0, + " SCONV2D_BP OP: rank of input array must be equal to 4, but " + "got %i instead !", + input->rankOf()); + REQUIRE_TRUE(gradO->rankOf() == 4, 0, + " SCONV2D_BP OP: rank of output gradients (next epsilon) array " + "must be equal to 4, but got %i instead !", + gradO->rankOf()); + REQUIRE_TRUE(weightsDepth->rankOf() == 4, 0, + " SCONV2D_BP OP: rank of weightsDepth array must be equal to 4 " + "!, but got %i instead !", + weightsDepth->rankOf()); + if (weightsPoint) { + REQUIRE_TRUE(weightsPoint->rankOf() == 4, 0, + " SCONV2D_BP OP: rank of weightsPoint array must be equal to " + "4, but got %i instead !", + weightsPoint->rankOf()); + REQUIRE_TRUE(gradWP->rankOf() == 4, 0, + " SCONV2D_BP OP: rank of weightsPoint gradients array must be " + "equal to 4, but got %i instead !", + gradWP->rankOf()); + } + if (bias) { + REQUIRE_TRUE(bias->rankOf() == 1 || bias->rankOf() == 2, 0, + " SCONV2D_BP OP: rank of biases array must be equal to 1 or " + "2, but got %i instead !", + bias->rankOf()); + REQUIRE_TRUE(gradB->rankOf() == 1 || gradB->rankOf() == 2, 0, + " SCONV2D_BP OP: rank of biases gradientsarray must be equal " + "to 1 or 2, but got %i instead !", + gradB->rankOf()); + } + + int kH = INT_ARG(0); // filter(kernel) height + int kW = INT_ARG(1); // filter(kernel) width + int sH = INT_ARG(2); // strides height + int sW = INT_ARG(3); // strides width + int pH = INT_ARG(4); // paddings height + int pW = INT_ARG(5); // paddings width + int dH = INT_ARG(6); // dilations height + int dW = INT_ARG(7); // dilations width + int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME + int isNCHW = + block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC + int wFormat = + block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, + // iC, kH, kW], 2 - [mC, kH, kW, iC] + + int bS, iC, iH, iW, mC, oC, oH, + oW; // batch size, input channels, input height/width, channels + // multiplier, output channels, output height/width + int indIOioC, indIiH, indWmC, indWiC, indWkH, + indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, + indIiH, indWiC, indWmC, indWkH, indOoH); + mC = weightsDepth->sizeAt(indWmC); // channels multiplier + + std::vector expectedWeightsDShape = + ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, mC); + REQUIRE_TRUE(weightsDepth->isSameShape(expectedWeightsDShape), 0, + " SCONV2D_BP OP: wrong shape of weightsDepth array, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString(expectedWeightsDShape).c_str(), + ShapeUtils::shapeAsString(weightsDepth).c_str()); + REQUIRE_TRUE(gradWD->isSameShape(expectedWeightsDShape), 0, + " SCONV2D_BP OP: wrong shape of gradWD array, expected is %s, " + "but got %s instead !", + ShapeUtils::shapeAsString(expectedWeightsDShape).c_str(), + ShapeUtils::shapeAsString(gradWD).c_str()); + if (weightsPoint) { + std::vector expectedWeightsPShape = + ConvolutionUtils::expectWeightsShape(wFormat, 1, 1, iC * mC, oC); + REQUIRE_TRUE(weightsPoint->isSameShape(expectedWeightsPShape), 0, + " SCONV2D_BP OP: wrong shape of weightsPoint array, expected " + "is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedWeightsPShape).c_str(), + ShapeUtils::shapeAsString(weightsPoint).c_str()); + REQUIRE_TRUE(gradWP->isSameShape(expectedWeightsPShape), 0, + " SCONV2D_BP OP: wrong shape of gradWP array, expected is %s, " + "but got %s instead !", + ShapeUtils::shapeAsString(expectedWeightsPShape).c_str(), + ShapeUtils::shapeAsString(gradWP).c_str()); + } + if (bias) { + REQUIRE_TRUE(oC == bias->lengthOf(), 0, + " SCONV2D_BP OP: length of bias array must be equal to " + "outChannels, but got %i instead", + bias->lengthOf()); + REQUIRE_TRUE(oC == gradB->lengthOf(), 0, + " SCONV2D_BP OP: length of biases gradients array must be " + "equal to outChannels, but got %i instead", + gradB->lengthOf()); + } + + // if (iC == 1) { + // nd4j_debug(" SCONV2D_BP OP: for input_channels=1 this op is equivalent + // to standard conv2d_bp \n",""); sd::ops::conv2d_bp op; return + // op.execute(&block); + // } + + // ----- if weightsPoint is present, perform pointwise backprop first and + // calculate gradWP at this step ----- // + if (weightsPoint) { + auto resultFFShape = isNCHW ? std::vector({bS, mC * iC, oH, oW}) + : std::vector({bS, oH, oW, mC * iC}); + auto resultFF = + NDArrayFactory::create_(input->ordering(), resultFFShape, + input->dataType(), block.launchContext()); + ConvolutionUtils::sconv2d(block, input, weightsDepth, nullptr, nullptr, + resultFF, kH, kW, sH, sW, pH, pW, dH, dW, + isSameMode, isNCHW, wFormat); + + auto gradIDepthShape = ShapeUtils::composeShapeUsingDimsAndIdx( + {bS, iC * mC, oH, oW, 0, indIOioC, indIiH, indIiH + 1}); + auto gradIDepth = NDArrayFactory::create_( + resultFF->ordering(), gradIDepthShape, resultFF->dataType(), + block.launchContext()); // [bS, oH, oW, iC*mC] (NHWC) or [bS, iC*mC, + // oH, oW] (NCHW) + + ConvolutionUtils::conv2dBP(block, resultFF, weightsPoint, bias, gradO, + gradIDepth, gradWP, gradB, 1, 1, 1, 1, 0, 0, 1, + 1, isSameMode, isNCHW, + wFormat); // in this case oH=iH and oW=iW + + gradO = gradIDepth; + bias = gradB = nullptr; // if pointwise backprop was done then don't + // calculate gradB at depthwise_conv2d_bp step + + delete resultFF; + } + + // ----- apply depthwise_conv2d_bp ----- // + ConvolutionUtils::depthwiseConv2dBP(block, input, weightsDepth, bias, gradO, + gradI, gradWD, gradB, kH, kW, sH, sW, pH, + pW, dH, dW, isSameMode, isNCHW, wFormat); + + if (weightsPoint) delete gradO; + + return Status::OK(); } - DECLARE_SHAPE_FN(sconv2d_bp) { - - auto inputShapeInfo = inputShape->at(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - auto gradOShapeInfo = inputShape->at(1); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next - auto weightsDShapeInfo = inputShape->at(2); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] - Nd4jLong const* weightsPShapeInfo = nullptr; // [1, 1, iC*mC, oC], [oC, iC*mC, 1, 1], [oC, 1, 1, iC*mC] - Nd4jLong const* biasShapeInfo = nullptr; // [oC], oC = iC*mC if weightsPoint=nullptr - - if(block.width() == 4) { - if(inputShape->at(3)[0] == 4) - weightsPShapeInfo = inputShape->at(3); - else - biasShapeInfo = inputShape->at(3); - } - else if(block.width() == 5) { - weightsPShapeInfo = inputShape->at(3); - biasShapeInfo = inputShape->at(4); - } - - const int rank = 4; - REQUIRE_TRUE(inputShapeInfo[0] == rank, 0, " SCONV2D_BP OP: rank of input array must be equal to %i, but got %i instead !", rank, inputShapeInfo[0]); - REQUIRE_TRUE(gradOShapeInfo[0] == rank, 0, " SCONV2D_BP OP: rank of output gradients (next epsilon) array must be equal to %i, but got %i instead !", rank, gradOShapeInfo[0]); - REQUIRE_TRUE(weightsDShapeInfo[0] == rank, 0, " SCONV2D_BP OP: rank of weightsDepth array must be equal to %i, but got %i instead !", rank, weightsDShapeInfo[0]); - if(weightsPShapeInfo) - REQUIRE_TRUE(weightsPShapeInfo[0] == rank, 0, " SCONV2D_BP OP: rank of weightsPoint array must be equal to %i, but got %i instead !", rank, weightsPShapeInfo[0]); - if(biasShapeInfo) - REQUIRE_TRUE(biasShapeInfo[0] ==1 || biasShapeInfo[0] == 2, 0, " SCONV2D_BP OP: rank of biases array must be 1 or 2, but got %i instead !", biasShapeInfo[0]);; - - int kH = INT_ARG(0); // filter(kernel) height - int kW = INT_ARG(1); // filter(kernel) width - int sH = INT_ARG(2); // strides height - int sW = INT_ARG(3); // strides width - int pH = INT_ARG(4); // paddings height - int pW = INT_ARG(5); // paddings width - int dH = INT_ARG(6); // dilations height - int dW = INT_ARG(7); // dilations width - int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME - int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC - int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - [mC, kH, kW, iC] - - int indIOioC, indIiH, indWmC(0 == wFormat ? 3 : 0); - if(!isNCHW) { - indIOioC = 3; indIiH = 1; - } - else { - indIOioC = 1; indIiH = 2; - } - - const int bS = inputShapeInfo[1]; // batch size - const int iH = inputShapeInfo[indIiH+1]; // input height - const int iW = inputShapeInfo[indIiH+2]; // input width - const int iC = inputShapeInfo[indIOioC+1]; // input channels - const int mC = weightsDShapeInfo[indWmC+1]; // channel multiplier - const int oC = weightsPShapeInfo ? weightsPShapeInfo[indWmC+1] : iC*mC; // output channels (oC or iC*mC) - - int trueoH, trueoW; // true output height, width - ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode); - - std::vector expectedGradOShapeInfo = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indIiH,indIiH+1}); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(gradOShapeInfo, expectedGradOShapeInfo), 0, "SCONV2D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShapeInfo).c_str(), ShapeUtils::shapeAsString(gradOShapeInfo).c_str()); - std::vector expectedWeightsDShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, mC); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsDShapeInfo, expectedWeightsDShape), 0, "SCONV2D_BP OP: wrong shape of depth weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsDShape).c_str(), ShapeUtils::shapeAsString(weightsDShapeInfo).c_str()); - if(weightsPShapeInfo) { - std::vector expectedWeightsPShape = ConvolutionUtils::expectWeightsShape(wFormat, 1, 1, iC*mC, oC); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsPShapeInfo, expectedWeightsPShape), 0, "SCONV2D_BP OP: wrong shape of point array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsPShape).c_str(), ShapeUtils::shapeAsString(weightsPShapeInfo).c_str()); - } - if (biasShapeInfo) - REQUIRE_TRUE((biasShapeInfo[0] == 1 || biasShapeInfo[0] == 2) && oC == shape::length(biasShapeInfo), 0, "SCONV2D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo)); - - auto gradIshapeInfo = ShapeBuilders::copyShapeInfoAndType(inputShapeInfo, gradOShapeInfo, false, block.getWorkspace()); - auto gradWDshapeInfo = ShapeBuilders::copyShapeInfoAndType(weightsDShapeInfo, gradOShapeInfo, false, block.getWorkspace()); - - Nd4jLong* gradWPshapeInfo(nullptr), *gradBshapeInfo(nullptr); - - if(weightsPShapeInfo && biasShapeInfo) { - gradWPshapeInfo = ShapeBuilders::copyShapeInfoAndType(weightsPShapeInfo, gradOShapeInfo, false, block.getWorkspace()); - gradBshapeInfo = ShapeBuilders::copyShapeInfoAndType(biasShapeInfo, gradOShapeInfo, false, block.getWorkspace()); - return SHAPELIST(CONSTANT(gradIshapeInfo), CONSTANT(gradWDshapeInfo), CONSTANT(gradWPshapeInfo), CONSTANT(gradBshapeInfo)); - } - - if(weightsPShapeInfo && !biasShapeInfo) { - gradWPshapeInfo = ShapeBuilders::copyShapeInfoAndType(weightsPShapeInfo, gradOShapeInfo, false, block.getWorkspace()); - return SHAPELIST(CONSTANT(gradIshapeInfo), CONSTANT(gradWDshapeInfo), CONSTANT(gradWPshapeInfo)); - } - - if(!weightsPShapeInfo && biasShapeInfo) { - gradBshapeInfo = ShapeBuilders::copyShapeInfoAndType(biasShapeInfo, gradOShapeInfo, false, block.getWorkspace()); - return SHAPELIST(CONSTANT(gradIshapeInfo), CONSTANT(gradWDshapeInfo), CONSTANT(gradBshapeInfo)); - } - - return SHAPELIST(CONSTANT(gradIshapeInfo), CONSTANT(gradWDshapeInfo)); + auto inputShapeInfo = + inputShape->at(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + auto gradOShapeInfo = inputShape->at( + 1); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next + auto weightsDShapeInfo = inputShape->at( + 2); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] + Nd4jLong const *weightsPShapeInfo = + nullptr; // [1, 1, iC*mC, oC], [oC, iC*mC, 1, 1], [oC, 1, 1, iC*mC] + Nd4jLong const *biasShapeInfo = + nullptr; // [oC], oC = iC*mC if weightsPoint=nullptr + + if (block.width() == 4) { + if (inputShape->at(3)[0] == 4) + weightsPShapeInfo = inputShape->at(3); + else + biasShapeInfo = inputShape->at(3); + } else if (block.width() == 5) { + weightsPShapeInfo = inputShape->at(3); + biasShapeInfo = inputShape->at(4); + } + + const int rank = 4; + REQUIRE_TRUE(inputShapeInfo[0] == rank, 0, + " SCONV2D_BP OP: rank of input array must be equal to %i, but " + "got %i instead !", + rank, inputShapeInfo[0]); + REQUIRE_TRUE(gradOShapeInfo[0] == rank, 0, + " SCONV2D_BP OP: rank of output gradients (next epsilon) array " + "must be equal to %i, but got %i instead !", + rank, gradOShapeInfo[0]); + REQUIRE_TRUE(weightsDShapeInfo[0] == rank, 0, + " SCONV2D_BP OP: rank of weightsDepth array must be equal to " + "%i, but got %i instead !", + rank, weightsDShapeInfo[0]); + if (weightsPShapeInfo) + REQUIRE_TRUE(weightsPShapeInfo[0] == rank, 0, + " SCONV2D_BP OP: rank of weightsPoint array must be equal to " + "%i, but got %i instead !", + rank, weightsPShapeInfo[0]); + if (biasShapeInfo) + REQUIRE_TRUE(biasShapeInfo[0] == 1 || biasShapeInfo[0] == 2, 0, + " SCONV2D_BP OP: rank of biases array must be 1 or 2, but got " + "%i instead !", + biasShapeInfo[0]); + ; + + int kH = INT_ARG(0); // filter(kernel) height + int kW = INT_ARG(1); // filter(kernel) width + int sH = INT_ARG(2); // strides height + int sW = INT_ARG(3); // strides width + int pH = INT_ARG(4); // paddings height + int pW = INT_ARG(5); // paddings width + int dH = INT_ARG(6); // dilations height + int dW = INT_ARG(7); // dilations width + int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME + int isNCHW = + block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC + int wFormat = + block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, + // iC, kH, kW], 2 - [mC, kH, kW, iC] + + int indIOioC, indIiH, indWmC(0 == wFormat ? 3 : 0); + if (!isNCHW) { + indIOioC = 3; + indIiH = 1; + } else { + indIOioC = 1; + indIiH = 2; + } + + const int bS = inputShapeInfo[1]; // batch size + const int iH = inputShapeInfo[indIiH + 1]; // input height + const int iW = inputShapeInfo[indIiH + 2]; // input width + const int iC = inputShapeInfo[indIOioC + 1]; // input channels + const int mC = weightsDShapeInfo[indWmC + 1]; // channel multiplier + const int oC = weightsPShapeInfo ? weightsPShapeInfo[indWmC + 1] + : iC * mC; // output channels (oC or iC*mC) + + int trueoH, trueoW; // true output height, width + ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, + dH, dW, iH, iW, isSameMode); + + std::vector expectedGradOShapeInfo = + ShapeUtils::composeShapeUsingDimsAndIdx( + {bS, oC, trueoH, trueoW, 0, indIOioC, indIiH, indIiH + 1}); + REQUIRE_TRUE( + ShapeUtils::areShapesEqual(gradOShapeInfo, expectedGradOShapeInfo), 0, + "SCONV2D_BP OP: wrong shape of output gradients (next epsilon) array, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedGradOShapeInfo).c_str(), + ShapeUtils::shapeAsString(gradOShapeInfo).c_str()); + std::vector expectedWeightsDShape = + ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, mC); + REQUIRE_TRUE( + ShapeUtils::areShapesEqual(weightsDShapeInfo, expectedWeightsDShape), 0, + "SCONV2D_BP OP: wrong shape of depth weights array, expected is %s, but " + "got %s instead !", + ShapeUtils::shapeAsString(expectedWeightsDShape).c_str(), + ShapeUtils::shapeAsString(weightsDShapeInfo).c_str()); + if (weightsPShapeInfo) { + std::vector expectedWeightsPShape = + ConvolutionUtils::expectWeightsShape(wFormat, 1, 1, iC * mC, oC); + REQUIRE_TRUE( + ShapeUtils::areShapesEqual(weightsPShapeInfo, expectedWeightsPShape), 0, + "SCONV2D_BP OP: wrong shape of point array, expected is %s, but got %s " + "instead !", + ShapeUtils::shapeAsString(expectedWeightsPShape).c_str(), + ShapeUtils::shapeAsString(weightsPShapeInfo).c_str()); + } + if (biasShapeInfo) + REQUIRE_TRUE((biasShapeInfo[0] == 1 || biasShapeInfo[0] == 2) && + oC == shape::length(biasShapeInfo), + 0, + "SCONV2D_BP OP: wrong shape of array with biases, expected " + "rank, length: <=2, %i, but got %i, %i instead !", + oC, biasShapeInfo[0], shape::length(biasShapeInfo)); + + auto gradIshapeInfo = ShapeBuilders::copyShapeInfoAndType( + inputShapeInfo, gradOShapeInfo, false, block.workspace()); + auto gradWDshapeInfo = ShapeBuilders::copyShapeInfoAndType( + weightsDShapeInfo, gradOShapeInfo, false, block.workspace()); + + Nd4jLong *gradWPshapeInfo(nullptr), *gradBshapeInfo(nullptr); + + if (weightsPShapeInfo && biasShapeInfo) { + gradWPshapeInfo = ShapeBuilders::copyShapeInfoAndType( + weightsPShapeInfo, gradOShapeInfo, false, block.workspace()); + gradBshapeInfo = ShapeBuilders::copyShapeInfoAndType( + biasShapeInfo, gradOShapeInfo, false, block.workspace()); + return SHAPELIST(CONSTANT(gradIshapeInfo), CONSTANT(gradWDshapeInfo), + CONSTANT(gradWPshapeInfo), CONSTANT(gradBshapeInfo)); + } + + if (weightsPShapeInfo && !biasShapeInfo) { + gradWPshapeInfo = ShapeBuilders::copyShapeInfoAndType( + weightsPShapeInfo, gradOShapeInfo, false, block.workspace()); + return SHAPELIST(CONSTANT(gradIshapeInfo), CONSTANT(gradWDshapeInfo), + CONSTANT(gradWPshapeInfo)); + } + + if (!weightsPShapeInfo && biasShapeInfo) { + gradBshapeInfo = ShapeBuilders::copyShapeInfoAndType( + biasShapeInfo, gradOShapeInfo, false, block.workspace()); + return SHAPELIST(CONSTANT(gradIshapeInfo), CONSTANT(gradWDshapeInfo), + CONSTANT(gradBshapeInfo)); + } + + return SHAPELIST(CONSTANT(gradIshapeInfo), CONSTANT(gradWDshapeInfo)); } - - -} -} +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/generic/nn/convo/upsampling2d.cpp b/libnd4j/include/ops/declarable/generic/nn/convo/upsampling2d.cpp index 4800b3db9ddc..d8eadf08a959 100644 --- a/libnd4j/include/ops/declarable/generic/nn/convo/upsampling2d.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/convo/upsampling2d.cpp @@ -26,103 +26,128 @@ #include namespace sd { -namespace ops { - +namespace ops { ////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(upsampling2d, 1, 1, false, 0, 2) { - auto input = INPUT_VARIABLE(0); // [bS, iC, iH, iW] (NCHW) or [bS, iH, iW, iC] (NHWC) - auto output = OUTPUT_NULLIFIED(0); // [bS, iC, factorH*iH, factorW*iW ] (NCHW) or [bS, factorH*iH, factorW*iW, iC] (NHWC) - - const int factorH = INT_ARG(0); - const int factorW = INT_ARG(1); - const int isNCHW = block.getIArguments()->size() > 2 ? INT_ARG(2) : 0; // INT_ARG(2): 0-NCHW, 1-NHWC - - REQUIRE_TRUE(input->rankOf() == 4, 0, "UPSAMPLING2D op: input should be 4D, but got %i instead!", input->rankOf()); - REQUIRE_TRUE(output->rankOf() == 4, 0, "UPSAMPLING2D op: output should be 4D, but got %i instead!", output->rankOf()); - - ConvolutionUtils::upsampling2d(block, *input, *output, factorH, factorW, (bool)isNCHW); - - return Status::OK(); + auto input = + INPUT_VARIABLE(0); // [bS, iC, iH, iW] (NCHW) or [bS, iH, iW, iC] (NHWC) + auto output = + OUTPUT_NULLIFIED(0); // [bS, iC, factorH*iH, factorW*iW ] (NCHW) or [bS, + // factorH*iH, factorW*iW, iC] (NHWC) + + const int factorH = INT_ARG(0); + const int factorW = INT_ARG(1); + const int isNCHW = + block.numI() > 2 ? INT_ARG(2) : 0; // INT_ARG(2): 0-NCHW, 1-NHWC + + REQUIRE_TRUE(input->rankOf() == 4, 0, + "UPSAMPLING2D op: input should be 4D, but got %i instead!", + input->rankOf()); + REQUIRE_TRUE(output->rankOf() == 4, 0, + "UPSAMPLING2D op: output should be 4D, but got %i instead!", + output->rankOf()); + + ConvolutionUtils::upsampling2d(block, *input, *output, factorH, factorW, + (bool)isNCHW); + + return Status::OK(); } DECLARE_SYN(upsampling, upsampling2d); - DECLARE_TYPES(upsampling2d) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } +DECLARE_TYPES(upsampling2d) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} DECLARE_SHAPE_FN(upsampling2d) { - - auto inputShapeInfo = inputShape->at(0); - - REQUIRE_TRUE(inputShapeInfo[0] == 4, 0, "UPSAMPLING2D op: input should be 4D, but got %i instead!", inputShapeInfo[0]); - - const int factorH = INT_ARG(0); - const int factorW = INT_ARG(1); - const int isNCHW = block.getIArguments()->size() > 2 ? INT_ARG(2) : 0; // INT_ARG(2): 0-NCHW, 1-NHWC - - Nd4jLong *outputShapeInfo = nullptr; - ALLOCATE(outputShapeInfo, block.getWorkspace(), shape::shapeInfoLength(inputShapeInfo[0]), Nd4jLong); - - outputShapeInfo[0] = inputShapeInfo[0]; - outputShapeInfo[1] = inputShapeInfo[1]; - - if(isNCHW) { - outputShapeInfo[2] = inputShapeInfo[2]; - outputShapeInfo[3] = inputShapeInfo[3] * factorH; - outputShapeInfo[4] = inputShapeInfo[4] * factorW; - } - else { - outputShapeInfo[2] = inputShapeInfo[2] * factorH; - outputShapeInfo[3] = inputShapeInfo[3] * factorW; - outputShapeInfo[4] = inputShapeInfo[4]; - } - - ShapeUtils::updateStridesAndType(outputShapeInfo, inputShapeInfo, shape::order(inputShapeInfo)); - - return SHAPELIST(CONSTANT(outputShapeInfo)); + auto inputShapeInfo = inputShape->at(0); + + REQUIRE_TRUE(inputShapeInfo[0] == 4, 0, + "UPSAMPLING2D op: input should be 4D, but got %i instead!", + inputShapeInfo[0]); + + const int factorH = INT_ARG(0); + const int factorW = INT_ARG(1); + const int isNCHW = + block.numI() > 2 ? INT_ARG(2) : 0; // INT_ARG(2): 0-NCHW, 1-NHWC + + Nd4jLong *outputShapeInfo = nullptr; + ALLOCATE(outputShapeInfo, block.workspace(), + shape::shapeInfoLength(inputShapeInfo[0]), Nd4jLong); + + outputShapeInfo[0] = inputShapeInfo[0]; + outputShapeInfo[1] = inputShapeInfo[1]; + + if (isNCHW) { + outputShapeInfo[2] = inputShapeInfo[2]; + outputShapeInfo[3] = inputShapeInfo[3] * factorH; + outputShapeInfo[4] = inputShapeInfo[4] * factorW; + } else { + outputShapeInfo[2] = inputShapeInfo[2] * factorH; + outputShapeInfo[3] = inputShapeInfo[3] * factorW; + outputShapeInfo[4] = inputShapeInfo[4]; + } + + ShapeUtils::updateStridesAndType(outputShapeInfo, inputShapeInfo, + shape::order(inputShapeInfo)); + + return SHAPELIST(CONSTANT(outputShapeInfo)); } - DECLARE_TYPES(upsampling2d_bp) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } - +DECLARE_TYPES(upsampling2d_bp) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} ////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(upsampling2d_bp, 2, 1, false, 0, 0) { - - // NDArray* input = INPUT_VARIABLE(0); // [bS, iC, iH, iW] (NCHW) or [bS, iH, iW, iC] (NHWC) - auto gradO = INPUT_VARIABLE(1); // [bS, iC, factorH*iH, factorW*iW ] (NCHW) or [bS, factorH*iH, factorW*iW, iC] (NHWC) - auto gradI = OUTPUT_NULLIFIED(0); // [bS, iC, iH, iW] (NCHW) or [bS, iH, iW, iC] (NHWC) - - const int isNCHW = block.getIArguments()->size() > 0 ? INT_ARG(0) : 0; // INT_ARG(0): 0-NCHW, 1-NHWC - - // REQUIRE_TRUE(input->rankOf() == 4, 0, "UPSAMPLING2D_BP op: input array must be 4D, but got %i instead!", input->rankOf()); - REQUIRE_TRUE(gradO->rankOf() == 4, 0, "UPSAMPLING2D_BP op: output's gradient array must be 4D, but got %i instead!", gradO->rankOf()); - REQUIRE_TRUE(gradI->rankOf() == 4, 0, "UPSAMPLING2D_BP op: input's gradient array must be 4D, but got %i instead!", gradI->rankOf()); - - ConvolutionUtils::upsampling2dBP(block, *gradO, *gradI, (bool)isNCHW); - - return Status::OK(); + // NDArray* input = INPUT_VARIABLE(0); // [bS, iC, iH, iW] + // (NCHW) or [bS, iH, iW, iC] (NHWC) + auto gradO = INPUT_VARIABLE(1); // [bS, iC, factorH*iH, factorW*iW ] (NCHW) + // or [bS, factorH*iH, factorW*iW, iC] (NHWC) + auto gradI = OUTPUT_NULLIFIED( + 0); // [bS, iC, iH, iW] (NCHW) or [bS, iH, iW, iC] (NHWC) + + const int isNCHW = + block.numI() > 0 ? INT_ARG(0) : 0; // INT_ARG(0): 0-NCHW, 1-NHWC + + // REQUIRE_TRUE(input->rankOf() == 4, 0, "UPSAMPLING2D_BP op: input array must + // be 4D, but got %i instead!", input->rankOf()); + REQUIRE_TRUE(gradO->rankOf() == 4, 0, + "UPSAMPLING2D_BP op: output's gradient array must be 4D, but " + "got %i instead!", + gradO->rankOf()); + REQUIRE_TRUE(gradI->rankOf() == 4, 0, + "UPSAMPLING2D_BP op: input's gradient array must be 4D, but got " + "%i instead!", + gradI->rankOf()); + + ConvolutionUtils::upsampling2dBP(block, *gradO, *gradI, (bool)isNCHW); + + return Status::OK(); } DECLARE_SYN(upsampling_bp, upsampling2d_bp); - DECLARE_SHAPE_FN(upsampling2d_bp) { - - REQUIRE_TRUE(inputShape->at(0)[0] == 4, 0, "UPSAMPLING2D_BP op: input array must be 4D, but got %i instead!", inputShape->at(0)[0]); - REQUIRE_TRUE(inputShape->at(1)[0] == 4, 0, "UPSAMPLING2D_BP op: output's gradient array must be 4D, but got %i instead!", inputShape->at(1)[0]); - - auto gradIShapeInfo = ShapeBuilders::copyShapeInfoAndType(inputShape->at(0), inputShape->at(1), false, block.getWorkspace()); - - return SHAPELIST(CONSTANT(gradIShapeInfo)); + REQUIRE_TRUE( + inputShape->at(0)[0] == 4, 0, + "UPSAMPLING2D_BP op: input array must be 4D, but got %i instead!", + inputShape->at(0)[0]); + REQUIRE_TRUE(inputShape->at(1)[0] == 4, 0, + "UPSAMPLING2D_BP op: output's gradient array must be 4D, but " + "got %i instead!", + inputShape->at(1)[0]); + + auto gradIShapeInfo = ShapeBuilders::copyShapeInfoAndType( + inputShape->at(0), inputShape->at(1), false, block.workspace()); + + return SHAPELIST(CONSTANT(gradIShapeInfo)); } -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/nn/convo/upsampling3d.cpp b/libnd4j/include/ops/declarable/generic/nn/convo/upsampling3d.cpp index 557468d147b0..4609ec4d6f1d 100644 --- a/libnd4j/include/ops/declarable/generic/nn/convo/upsampling3d.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/convo/upsampling3d.cpp @@ -25,103 +25,131 @@ #include namespace sd { -namespace ops { - +namespace ops { ////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(upsampling3d, 1, 1, false, 0, 3) { - auto input = INPUT_VARIABLE(0); // [bS, iC, iD, iH, iW] (NCDHW) or [bS, iD, iH, iW, iC] (NDHWC) - auto output = OUTPUT_NULLIFIED(0); // [bS, iC, factorD*iD, factorH*iH, factorW*iW ] (NCDHW) or [bS, factorD*iD, factorH*iH, factorW*iW, iC] (NDHWC) - - const int factorD = INT_ARG(0); - const int factorH = INT_ARG(1); - const int factorW = INT_ARG(2); - const int isNCDHW = block.getIArguments()->size() > 3 ? INT_ARG(3) : 0; // INT_ARG(3): 0-NCDHW, 1-NDHWC - - REQUIRE_TRUE(input->rankOf() == 5, 0, "UPSAMPLING3D op: input should be 5D, but got %i instead!", input->rankOf()); - REQUIRE_TRUE(output->rankOf() == 5, 0, "UPSAMPLING3D op: output should be 5D, but got %i instead!", output->rankOf()); - - ConvolutionUtils::upsampling3d(block, *input, *output, factorD, factorH, factorW, (bool)isNCDHW); + auto input = INPUT_VARIABLE( + 0); // [bS, iC, iD, iH, iW] (NCDHW) or [bS, iD, iH, iW, iC] (NDHWC) + auto output = OUTPUT_NULLIFIED( + 0); // [bS, iC, factorD*iD, factorH*iH, factorW*iW ] (NCDHW) or [bS, + // factorD*iD, factorH*iH, factorW*iW, iC] (NDHWC) + + const int factorD = INT_ARG(0); + const int factorH = INT_ARG(1); + const int factorW = INT_ARG(2); + const int isNCDHW = + block.numI() > 3 ? INT_ARG(3) : 0; // INT_ARG(3): 0-NCDHW, 1-NDHWC + + REQUIRE_TRUE(input->rankOf() == 5, 0, + "UPSAMPLING3D op: input should be 5D, but got %i instead!", + input->rankOf()); + REQUIRE_TRUE(output->rankOf() == 5, 0, + "UPSAMPLING3D op: output should be 5D, but got %i instead!", + output->rankOf()); + + ConvolutionUtils::upsampling3d(block, *input, *output, factorD, factorH, + factorW, (bool)isNCDHW); + + return Status::OK(); +} - return Status::OK(); +DECLARE_TYPES(upsampling3d) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } - DECLARE_TYPES(upsampling3d) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } - DECLARE_SHAPE_FN(upsampling3d) { - - auto inputShapeInfo = inputShape->at(0); - - REQUIRE_TRUE(inputShapeInfo[0] == 5, 0, "UPSAMPLING2D op: input should be 5D, but got %i instead!", inputShapeInfo[0]); - - const int factorD = INT_ARG(0); - const int factorH = INT_ARG(1); - const int factorW = INT_ARG(2); - const int isNCDHW = block.getIArguments()->size() > 3 ? INT_ARG(3) : 0; // INT_ARG(3): 0-NCHW, 1-NHWC - - Nd4jLong *outputShapeInfo = nullptr; - ALLOCATE(outputShapeInfo, block.getWorkspace(), shape::shapeInfoLength(inputShapeInfo[0]), Nd4jLong); - - outputShapeInfo[0] = inputShapeInfo[0]; - outputShapeInfo[1] = inputShapeInfo[1]; - - if(isNCDHW) { - outputShapeInfo[2] = inputShapeInfo[2]; - outputShapeInfo[3] = inputShapeInfo[3] * factorD; - outputShapeInfo[4] = inputShapeInfo[4] * factorH; - outputShapeInfo[5] = inputShapeInfo[5] * factorW; - } - else { - outputShapeInfo[2] = inputShapeInfo[2] * factorD; - outputShapeInfo[3] = inputShapeInfo[3] * factorH; - outputShapeInfo[4] = inputShapeInfo[4] * factorW; - outputShapeInfo[5] = inputShapeInfo[5]; - } - - ShapeUtils::updateStridesAndType(outputShapeInfo, inputShapeInfo, shape::order(inputShapeInfo)); - - return SHAPELIST(CONSTANT(outputShapeInfo)); + auto inputShapeInfo = inputShape->at(0); + + REQUIRE_TRUE(inputShapeInfo[0] == 5, 0, + "UPSAMPLING2D op: input should be 5D, but got %i instead!", + inputShapeInfo[0]); + + const int factorD = INT_ARG(0); + const int factorH = INT_ARG(1); + const int factorW = INT_ARG(2); + const int isNCDHW = + block.numI() > 3 ? INT_ARG(3) : 0; // INT_ARG(3): 0-NCHW, 1-NHWC + + Nd4jLong *outputShapeInfo = nullptr; + ALLOCATE(outputShapeInfo, block.workspace(), + shape::shapeInfoLength(inputShapeInfo[0]), Nd4jLong); + + outputShapeInfo[0] = inputShapeInfo[0]; + outputShapeInfo[1] = inputShapeInfo[1]; + + if (isNCDHW) { + outputShapeInfo[2] = inputShapeInfo[2]; + outputShapeInfo[3] = inputShapeInfo[3] * factorD; + outputShapeInfo[4] = inputShapeInfo[4] * factorH; + outputShapeInfo[5] = inputShapeInfo[5] * factorW; + } else { + outputShapeInfo[2] = inputShapeInfo[2] * factorD; + outputShapeInfo[3] = inputShapeInfo[3] * factorH; + outputShapeInfo[4] = inputShapeInfo[4] * factorW; + outputShapeInfo[5] = inputShapeInfo[5]; + } + + ShapeUtils::updateStridesAndType(outputShapeInfo, inputShapeInfo, + shape::order(inputShapeInfo)); + + return SHAPELIST(CONSTANT(outputShapeInfo)); } - DECLARE_TYPES(upsampling3d_bp) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } +DECLARE_TYPES(upsampling3d_bp) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} ////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(upsampling3d_bp, 2, 1, false, 0, 0) { - // NDArray* input = INPUT_VARIABLE(0); // [bS, iC, iD, iH, iW] (NCDHW) or [bS, iD, iH, iW, iC] (NDHWC) - auto gradO = INPUT_VARIABLE(1); // [bS, iC, factorD*iD, factorH*iH, factorW*iW ] (NCDHW) or [bS, factorD*iD, factorH*iH, factorW*iW, iC] (NDHWC) - auto gradI = OUTPUT_NULLIFIED(0); // [bS, iC, iD, iH, iW] (NCDHW) or [bS, iD, iH, iW, iC] (NDHWC) - - const int isNCDHW = block.getIArguments()->size() > 0 ? INT_ARG(0) : 0; // INT_ARG(0): 0-NCHW, 1-NHWC - - // REQUIRE_TRUE(input->rankOf() == 5, 0, "UPSAMPLING3D_BP op: input array must be 4D, but got %i instead!", input->rankOf()); - REQUIRE_TRUE(gradO->rankOf() == 5, 0, "UPSAMPLING3D_BP op: output's gradient array must be 4D, but got %i instead!", gradO->rankOf()); - REQUIRE_TRUE(gradI->rankOf() == 5, 0, "UPSAMPLING3D_BP op: input's gradient array must be 4D, but got %i instead!", gradI->rankOf()); - - ConvolutionUtils::upsampling3dBP(block, *gradO, *gradI, (bool)isNCDHW); - - return Status::OK(); + // NDArray* input = INPUT_VARIABLE(0); // [bS, iC, iD, iH, iW] + // (NCDHW) or [bS, iD, iH, iW, iC] (NDHWC) + auto gradO = INPUT_VARIABLE( + 1); // [bS, iC, factorD*iD, factorH*iH, factorW*iW ] (NCDHW) or [bS, + // factorD*iD, factorH*iH, factorW*iW, iC] (NDHWC) + auto gradI = OUTPUT_NULLIFIED( + 0); // [bS, iC, iD, iH, iW] (NCDHW) or [bS, iD, iH, iW, iC] (NDHWC) + + const int isNCDHW = + block.numI() > 0 ? INT_ARG(0) : 0; // INT_ARG(0): 0-NCHW, 1-NHWC + + // REQUIRE_TRUE(input->rankOf() == 5, 0, "UPSAMPLING3D_BP op: input array must + // be 4D, but got %i instead!", input->rankOf()); + REQUIRE_TRUE(gradO->rankOf() == 5, 0, + "UPSAMPLING3D_BP op: output's gradient array must be 4D, but " + "got %i instead!", + gradO->rankOf()); + REQUIRE_TRUE(gradI->rankOf() == 5, 0, + "UPSAMPLING3D_BP op: input's gradient array must be 4D, but got " + "%i instead!", + gradI->rankOf()); + + ConvolutionUtils::upsampling3dBP(block, *gradO, *gradI, (bool)isNCDHW); + + return Status::OK(); } - DECLARE_SHAPE_FN(upsampling3d_bp) { - - REQUIRE_TRUE(inputShape->at(0)[0] == 5, 0, "UPSAMPLING3D_BP op: input array must be 4D, but got %i instead!", inputShape->at(0)[0]); - REQUIRE_TRUE(inputShape->at(1)[0] == 5, 0, "UPSAMPLING3D_BP op: output's gradient array must be 4D, but got %i instead!", inputShape->at(1)[0]); - - auto gradIShapeInfo = ShapeBuilders::copyShapeInfoAndType(inputShape->at(0), inputShape->at(1), false, block.getWorkspace()); - - return SHAPELIST(CONSTANT(gradIShapeInfo)); + REQUIRE_TRUE( + inputShape->at(0)[0] == 5, 0, + "UPSAMPLING3D_BP op: input array must be 4D, but got %i instead!", + inputShape->at(0)[0]); + REQUIRE_TRUE(inputShape->at(1)[0] == 5, 0, + "UPSAMPLING3D_BP op: output's gradient array must be 4D, but " + "got %i instead!", + inputShape->at(1)[0]); + + auto gradIShapeInfo = ShapeBuilders::copyShapeInfoAndType( + inputShape->at(0), inputShape->at(1), false, block.workspace()); + + return SHAPELIST(CONSTANT(gradIShapeInfo)); } -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/nn/dot_product_attention.cpp b/libnd4j/include/ops/declarable/generic/nn/dot_product_attention.cpp index 49dc52a03a73..b3e3732a631c 100644 --- a/libnd4j/include/ops/declarable/generic/nn/dot_product_attention.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/dot_product_attention.cpp @@ -24,202 +24,241 @@ #include #include - namespace sd { -namespace ops { - - CUSTOM_OP_IMPL(dot_product_attention, 3, -1, false, 0, 2) { - auto queries = INPUT_VARIABLE(0); - auto keys = INPUT_VARIABLE(1); - auto values = INPUT_VARIABLE(2); - auto mask = block.width() > 3 ? INPUT_VARIABLE(3) : nullptr; - - auto output = OUTPUT_VARIABLE(0); - NDArray* weights; - bool outputWeights = INT_ARG(1); - if(outputWeights){ - weights = OUTPUT_VARIABLE(1); - }else{ - auto weightShape = ShapeUtils::evalShapeForMatmul(keys->shapeInfo(), queries->shapeInfo(), true, false); - weights = new NDArray('c', weightShape, values->dataType(), block.launchContext()); - } - - int normalization = INT_ARG(0); - - REQUIRE_TRUE(queries->rankOf() == keys->rankOf() && keys->rankOf() == values->rankOf(), 0, - "dot_product_attention: Queries, Keys and Values must have same rank. " - "But got queries = %s, keys = %s, values = %s", ShapeUtils::shapeAsString(queries).c_str(), - ShapeUtils::shapeAsString(keys).c_str(), ShapeUtils::shapeAsString(values).c_str()); - - REQUIRE_TRUE(queries->rankOf() == 3 || queries->rankOf() == 4, 0, - "dot_product_attention: Queries, Keys and Values must be rank 3 arrays for single headed attention " - "or rank 4 arrays for multi headed attention. But got rank = %i", queries->rankOf()); - - REQUIRE_TRUE(queries->sizeAt(0) == keys->sizeAt(0) && keys->sizeAt(0) == values->sizeAt(0), 0, - "dot_product_attention: Queries, Keys and Values must have the same mini batch size. " - "But got queries = %i, keys = %i, values = %i", queries->sizeAt(0), keys->sizeAt(0), values->sizeAt(0)); - - REQUIRE_TRUE(queries->sizeAt(-2) == keys->sizeAt(-2), 0, - "dot_product_attention: Queries and Keys must have the same feature size. " - "But got queries = %i, keys = %i", queries->sizeAt(-2), keys->sizeAt(-2)); - - REQUIRE_TRUE(keys->sizeAt(-1) == values->sizeAt(-1), 0, - "dot_product_attention: Keys and Values must have the same timestep length. " - "But got keys = %i, values = %i", keys->sizeAt(-1), values->sizeAt(-1)); - - sd::ops::matmul mmul; - mmul.execute({keys, queries}, {weights}, {}, {1}, {}); - if(normalization) { - *weights /= sqrt((double)keys->sizeAt(-2)); - } - - if(mask != nullptr){ - NDArray reshapedMask; - if(weights->rankOf() == 4){ - reshapedMask = mask->reshape(mask->ordering(), {mask->sizeAt(0), 1, mask->sizeAt(1), 1}); - }else{ - reshapedMask = mask->reshape(mask->ordering(), {mask->sizeAt(0), mask->sizeAt(1), 1}); - } - - // the mask is 0 for positions we want to skip, and 1 for positions we want to keep. By subtracting 1 from - // it we get -1 for those we want to skip and 0 for those we want to keep. Multiplying it by 1e9 then - // turns all of those we want to skip into very large negative values. By adding this to the weights - // before going through the softmax, we effectively push all masked positions to zero after softmax. - // - // we are using 1e9 to mean effectively infinity - *weights += (reshapedMask - 1) * 1e9; - } - - sd::ops::softmax softmax; - softmax.execute({weights}, std::vector{weights}, {}, {-2}, {}, {}, true); - - mmul.execute({values, weights}, {output}, {}, {}, {}); - - if(!outputWeights){ - delete weights; - } - - return Status::OK(); - } - - - DECLARE_TYPES(dot_product_attention) { - getOpDescriptor()->setAllowedInputTypes({ALL_FLOATS}); - getOpDescriptor()->setAllowedOutputTypes({ALL_FLOATS}); - } - - DECLARE_SHAPE_FN(dot_product_attention) { - auto query_shape = inputShape->at(0); - auto keys_shape = inputShape->at(1); - auto values_shape = inputShape->at(2); - - auto weights_shape = ConstantShapeHelper::getInstance().createShapeInfo(sd::ArrayOptions::dataType(values_shape), 'c', ShapeUtils::evalShapeForMatmul(keys_shape, query_shape, true, false)); - auto output_shape = ConstantShapeHelper::getInstance().createShapeInfo(sd::ArrayOptions::dataType(values_shape), 'c', ShapeUtils::evalShapeForMatmul(values_shape, weights_shape, false, false)); - - if(INT_ARG(1)){ - return SHAPELIST(output_shape, weights_shape); - }else{ - return SHAPELIST(output_shape); - } - +namespace ops { + +CUSTOM_OP_IMPL(dot_product_attention, 3, -1, false, 0, 2) { + auto queries = INPUT_VARIABLE(0); + auto keys = INPUT_VARIABLE(1); + auto values = INPUT_VARIABLE(2); + auto mask = block.width() > 3 ? INPUT_VARIABLE(3) : nullptr; + + auto output = OUTPUT_VARIABLE(0); + NDArray *weights; + bool outputWeights = INT_ARG(1); + if (outputWeights) { + weights = OUTPUT_VARIABLE(1); + } else { + auto weightShape = ShapeUtils::evalShapeForMatmul( + keys->shapeInfo(), queries->shapeInfo(), true, false); + weights = new NDArray('c', weightShape, values->dataType(), + block.launchContext()); + } + + int normalization = INT_ARG(0); + + REQUIRE_TRUE( + queries->rankOf() == keys->rankOf() && keys->rankOf() == values->rankOf(), + 0, + "dot_product_attention: Queries, Keys and Values must have same rank. " + "But got queries = %s, keys = %s, values = %s", + ShapeUtils::shapeAsString(queries).c_str(), + ShapeUtils::shapeAsString(keys).c_str(), + ShapeUtils::shapeAsString(values).c_str()); + + REQUIRE_TRUE(queries->rankOf() == 3 || queries->rankOf() == 4, 0, + "dot_product_attention: Queries, Keys and Values must be rank 3 " + "arrays for single headed attention " + "or rank 4 arrays for multi headed attention. But got rank = %i", + queries->rankOf()); + + REQUIRE_TRUE(queries->sizeAt(0) == keys->sizeAt(0) && + keys->sizeAt(0) == values->sizeAt(0), + 0, + "dot_product_attention: Queries, Keys and Values must have the " + "same mini batch size. " + "But got queries = %i, keys = %i, values = %i", + queries->sizeAt(0), keys->sizeAt(0), values->sizeAt(0)); + + REQUIRE_TRUE(queries->sizeAt(-2) == keys->sizeAt(-2), 0, + "dot_product_attention: Queries and Keys must have the same " + "feature size. " + "But got queries = %i, keys = %i", + queries->sizeAt(-2), keys->sizeAt(-2)); + + REQUIRE_TRUE(keys->sizeAt(-1) == values->sizeAt(-1), 0, + "dot_product_attention: Keys and Values must have the same " + "timestep length. " + "But got keys = %i, values = %i", + keys->sizeAt(-1), values->sizeAt(-1)); + + sd::ops::matmul mmul; + mmul.execute({keys, queries}, {weights}, {}, {1}, {}); + if (normalization) { + *weights /= sqrt((double)keys->sizeAt(-2)); + } + + if (mask != nullptr && mask->defined()) { + NDArray reshapedMask; + if (weights->rankOf() == 4) { + reshapedMask = mask->reshape(mask->ordering(), + {mask->sizeAt(0), 1, mask->sizeAt(1), 1}); + } else { + reshapedMask = mask->reshape(mask->ordering(), + {mask->sizeAt(0), mask->sizeAt(1), 1}); } - CUSTOM_OP_IMPL(dot_product_attention_bp, 4, 3, false, 0, 1) { - auto queries = INPUT_VARIABLE(0); - auto keys = INPUT_VARIABLE(1); - auto values = INPUT_VARIABLE(2); - auto eps = INPUT_VARIABLE(3); - auto mask = block.width() > 4 ? INPUT_VARIABLE(4) : nullptr; - - auto dLdq = OUTPUT_VARIABLE(0); - auto dLdk = OUTPUT_VARIABLE(1); - auto dLdv = OUTPUT_VARIABLE(2); - - int normalization = INT_ARG(0); - - - REQUIRE_TRUE(queries->rankOf() == keys->rankOf() && keys->rankOf() == values->rankOf(), 0, - "dot_product_attention: Queries, Keys and Values must have same rank. " - "But got queries = %s, keys = %s, values = %s", ShapeUtils::shapeAsString(queries).c_str(), - ShapeUtils::shapeAsString(keys).c_str(), ShapeUtils::shapeAsString(values).c_str()); - - REQUIRE_TRUE(queries->rankOf() == 3 || queries->rankOf() == 4, 0, - "dot_product_attention: Queries, Keys and Values must be rank 3 arrays for single headed attention " - "or rank 4 arrays for multi headed attention. But got rank = %i", queries->rankOf()); + // the mask is 0 for positions we want to skip, and 1 for positions we want + // to keep. By subtracting 1 from it we get -1 for those we want to skip and + // 0 for those we want to keep. Multiplying it by 1e9 then turns all of + // those we want to skip into very large negative values. By adding this to + // the weights before going through the softmax, we effectively push all + // masked positions to zero after softmax. + // + // we are using 1e9 to mean effectively infinity + *weights += (reshapedMask - 1) * 1e9; + } - REQUIRE_TRUE(queries->sizeAt(0) == keys->sizeAt(0) && keys->sizeAt(0) == values->sizeAt(0), 0, - "dot_product_attention: Queries, Keys and Values must have the same mini batch size. " - "But got queries = %i, keys = %i, values = %i", queries->sizeAt(0), keys->sizeAt(0), values->sizeAt(0)); + sd::ops::softmax softmax; + softmax.execute({weights}, std::vector{weights}, {}, {-2}, {}, {}, + true); - REQUIRE_TRUE(queries->sizeAt(-2) == keys->sizeAt(-2), 0, - "dot_product_attention: Queries and Keys must have the same feature size. " - "But got queries = %i, keys = %i", queries->sizeAt(-2), keys->sizeAt(-2)); + mmul.execute({values, weights}, {output}, {}, {}, {}); - REQUIRE_TRUE(keys->sizeAt(-1) == values->sizeAt(-1), 0, - "dot_product_attention: Keys and Values must have the same timestep length. " - "But got keys = %i, values = %i", keys->sizeAt(-1), values->sizeAt(-1)); + if (!outputWeights) { + delete weights; + } + return Status::OK(); +} - double factor; - if(normalization) - factor = sqrt((double)keys->sizeAt(-2)); - - auto weightShape = ShapeUtils::evalShapeForMatmul(keys->shapeInfo(), queries->shapeInfo(), true, false); - - sd::ops::matmul mmul; - NDArray preSoftmax('c', weightShape, values->dataType(), block.launchContext()); - mmul.execute({keys, queries}, {&preSoftmax},{}, {1}, {}); - - if(normalization) - preSoftmax /= factor; +DECLARE_TYPES(dot_product_attention) { + getOpDescriptor()->setAllowedInputTypes({ALL_FLOATS}); + getOpDescriptor()->setAllowedOutputTypes({ALL_FLOATS}); +} - if(mask != nullptr){ - NDArray reshapedMask; - if(preSoftmax.rankOf() == 4){ - reshapedMask = mask->reshape(mask->ordering(), {mask->sizeAt(0), 1, mask->sizeAt(1), 1}); - }else{ - reshapedMask = mask->reshape(mask->ordering(), {mask->sizeAt(0), mask->sizeAt(1), 1}); - } - preSoftmax += (reshapedMask - 1) * 1e9; - } +DECLARE_SHAPE_FN(dot_product_attention) { + auto query_shape = inputShape->at(0); + auto keys_shape = inputShape->at(1); + auto values_shape = inputShape->at(2); + + auto weights_shape = ConstantShapeHelper::getInstance().createShapeInfo( + sd::ArrayOptions::dataType(values_shape), 'c', + ShapeUtils::evalShapeForMatmul(keys_shape, query_shape, true, false)); + auto output_shape = ConstantShapeHelper::getInstance().createShapeInfo( + sd::ArrayOptions::dataType(values_shape), 'c', + ShapeUtils::evalShapeForMatmul(values_shape, weights_shape, false, + false)); + + if (INT_ARG(1)) { + return SHAPELIST(output_shape, weights_shape); + } else { + return SHAPELIST(output_shape); + } +} - NDArray weights('c', weightShape, values->dataType(), block.launchContext()); - sd::ops::softmax softmax; - softmax.execute({&preSoftmax}, {&weights},{}, {-2}, {}); +CUSTOM_OP_IMPL(dot_product_attention_bp, 4, 3, false, 0, 1) { + auto queries = INPUT_VARIABLE(0); + auto keys = INPUT_VARIABLE(1); + auto values = INPUT_VARIABLE(2); + auto eps = INPUT_VARIABLE(3); + auto mask = block.width() > 4 ? INPUT_VARIABLE(4) : nullptr; + + auto dLdq = OUTPUT_VARIABLE(0); + auto dLdk = OUTPUT_VARIABLE(1); + auto dLdv = OUTPUT_VARIABLE(2); + + int normalization = INT_ARG(0); + + REQUIRE_TRUE( + queries->rankOf() == keys->rankOf() && keys->rankOf() == values->rankOf(), + 0, + "dot_product_attention: Queries, Keys and Values must have same rank. " + "But got queries = %s, keys = %s, values = %s", + ShapeUtils::shapeAsString(queries).c_str(), + ShapeUtils::shapeAsString(keys).c_str(), + ShapeUtils::shapeAsString(values).c_str()); + + REQUIRE_TRUE(queries->rankOf() == 3 || queries->rankOf() == 4, 0, + "dot_product_attention: Queries, Keys and Values must be rank 3 " + "arrays for single headed attention " + "or rank 4 arrays for multi headed attention. But got rank = %i", + queries->rankOf()); + + REQUIRE_TRUE(queries->sizeAt(0) == keys->sizeAt(0) && + keys->sizeAt(0) == values->sizeAt(0), + 0, + "dot_product_attention: Queries, Keys and Values must have the " + "same mini batch size. " + "But got queries = %i, keys = %i, values = %i", + queries->sizeAt(0), keys->sizeAt(0), values->sizeAt(0)); + + REQUIRE_TRUE(queries->sizeAt(-2) == keys->sizeAt(-2), 0, + "dot_product_attention: Queries and Keys must have the same " + "feature size. " + "But got queries = %i, keys = %i", + queries->sizeAt(-2), keys->sizeAt(-2)); + + REQUIRE_TRUE(keys->sizeAt(-1) == values->sizeAt(-1), 0, + "dot_product_attention: Keys and Values must have the same " + "timestep length. " + "But got keys = %i, values = %i", + keys->sizeAt(-1), values->sizeAt(-1)); + + double factor; + if (normalization) factor = sqrt((double)keys->sizeAt(-2)); + + auto weightShape = ShapeUtils::evalShapeForMatmul( + keys->shapeInfo(), queries->shapeInfo(), true, false); + + sd::ops::matmul mmul; + NDArray preSoftmax('c', weightShape, values->dataType(), + block.launchContext()); + mmul.execute({keys, queries}, {&preSoftmax}, {}, {1}, {}); + + if (normalization) preSoftmax /= factor; + + if (mask != nullptr) { + NDArray reshapedMask; + if (preSoftmax.rankOf() == 4) { + reshapedMask = mask->reshape(mask->ordering(), + {mask->sizeAt(0), 1, mask->sizeAt(1), 1}); + } else { + reshapedMask = mask->reshape(mask->ordering(), + {mask->sizeAt(0), mask->sizeAt(1), 1}); + } + preSoftmax += (reshapedMask - 1) * 1e9; + } - sd::ops::matmul_bp mmul_bp; - NDArray dLdw(weights.shapeInfo(), block.workspace()); - mmul_bp.execute({values, &weights, eps}, std::vector{dLdv, &dLdw}, {}, {}, {}); + NDArray weights('c', weightShape, values->dataType(), block.launchContext()); + sd::ops::softmax softmax; + softmax.execute({&preSoftmax}, {&weights}, {}, {-2}, {}); - NDArray dLds(preSoftmax.shapeInfo(), block.workspace()); - sd::ops::softmax_bp softmax_bp; - softmax_bp.execute({&preSoftmax, &dLdw}, {&dLds}, {}, {-2}, {}); + sd::ops::matmul_bp mmul_bp; + NDArray dLdw(weights.shapeInfo(), block.workspace()); + mmul_bp.execute({values, &weights, eps}, std::vector{dLdv, &dLdw}, + {}, {}, {}); - if(normalization) - dLds /= factor; + NDArray dLds(preSoftmax.shapeInfo(), block.workspace()); + sd::ops::softmax_bp softmax_bp; + softmax_bp.execute({&preSoftmax, &dLdw}, {&dLds}, {}, {-2}, {}); - mmul_bp.execute({keys, queries, &dLds}, std::vector{dLdk, dLdq}, {}, {1}, {}); + if (normalization) dLds /= factor; - return Status::OK(); - } + mmul_bp.execute({keys, queries, &dLds}, std::vector{dLdk, dLdq}, + {}, {1}, {}); - DECLARE_TYPES(dot_product_attention_bp) { - getOpDescriptor()->setAllowedInputTypes({ALL_FLOATS}); - getOpDescriptor()->setAllowedOutputTypes({ALL_FLOATS}); - } + return Status::OK(); +} - DECLARE_SHAPE_FN(dot_product_attention_bp) { - Nd4jLong *dLdq_shape; - COPY_SHAPE(inputShape->at(0), dLdq_shape); - Nd4jLong *dLdk_shape; - COPY_SHAPE(inputShape->at(1), dLdk_shape); - Nd4jLong *dLdv_shape; - COPY_SHAPE(inputShape->at(2), dLdv_shape); +DECLARE_TYPES(dot_product_attention_bp) { + getOpDescriptor()->setAllowedInputTypes({ALL_FLOATS}); + getOpDescriptor()->setAllowedOutputTypes({ALL_FLOATS}); +} - return SHAPELIST(CONSTANT(dLdq_shape), CONSTANT(dLdk_shape), CONSTANT(dLdv_shape)); - } +DECLARE_SHAPE_FN(dot_product_attention_bp) { + Nd4jLong *dLdq_shape; + COPY_SHAPE(inputShape->at(0), dLdq_shape); + Nd4jLong *dLdk_shape; + COPY_SHAPE(inputShape->at(1), dLdk_shape); + Nd4jLong *dLdv_shape; + COPY_SHAPE(inputShape->at(2), dLdv_shape); + return SHAPELIST(CONSTANT(dLdq_shape), CONSTANT(dLdk_shape), + CONSTANT(dLdv_shape)); } -} + +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/nn/embedding_lookup.cpp b/libnd4j/include/ops/declarable/generic/nn/embedding_lookup.cpp index 0f4a01e031c8..44dff501a08a 100644 --- a/libnd4j/include/ops/declarable/generic/nn/embedding_lookup.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/embedding_lookup.cpp @@ -21,99 +21,102 @@ #include #if NOT_EXCLUDED(OP_embedding_lookup) -#include #include -#include -#include +#include +#include +#include namespace sd { namespace ops { - ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(embedding_lookup, 2, 1, false, 0, 1) { - auto input = INPUT_VARIABLE(0); // lookup param - auto indeces = INPUT_VARIABLE(1); // indeces, as is - auto output = OUTPUT_VARIABLE(0); // - - if (block.width() > 2) { // multiple input - indeces = INPUT_VARIABLE(block.width() - 1); - std::vector dims(input->rankOf()); - int i = output->rankOf() - input->rankOf(); - for (auto& v: dims){ - v = i++; - } - - ResultSet outputView = output->allTensorsAlongDimension(dims); - REQUIRE_TRUE(block.width() > output->sizeAt(0), 0, "embedding_lookup: input list should be greater then %i, but %i given.", - output->sizeAt(0), block.width() - ); - for (Nd4jLong e = 0; e < indeces->lengthOf(); ++e) { - Nd4jLong thisIndex = (*indeces).e(e); - input = INPUT_VARIABLE(thisIndex); // lookup param - - outputView.at(e)->assign(input); - } + auto input = INPUT_VARIABLE(0); // lookup param + auto indeces = INPUT_VARIABLE(1); // indeces, as is + auto output = OUTPUT_VARIABLE(0); // + + if (block.width() > 2) { // multiple input + indeces = INPUT_VARIABLE(block.width() - 1); + std::vector dims(input->rankOf()); + int i = output->rankOf() - input->rankOf(); + for (auto& v : dims) { + v = i++; } - else { - int indexRank = indeces->rankOf(); - REQUIRE_TRUE(indexRank > 0, 0, "embeded_lookup: input array of indexes can't be single scalar, the requirement is: rank > 0 !"); - - int inputRank = input->rankOf(); - int lastIndDim = indeces->lengthOf(); - int partition_mode = INT_ARG(0); // partition_mode == 0 - i.e. 'mod' , 1 - 'div' - sd::ops::gather op; + ResultSet outputView = output->allTensorsAlongDimension(dims); + REQUIRE_TRUE( + block.width() > output->sizeAt(0), 0, + "embedding_lookup: input list should be greater then %i, but %i given.", + output->sizeAt(0), block.width()); + for (Nd4jLong e = 0; e < indeces->lengthOf(); ++e) { + Nd4jLong thisIndex = (*indeces).e(e); + input = INPUT_VARIABLE(thisIndex); // lookup param - auto result(op.evaluate({input, indeces}, {0})); - REQUIRE_TRUE(result.status() == Status::OK(), 0, "embedding_lookup: cannot retrieve results from gather op."); - REQUIRE_TRUE(result.at(0)->isSameShape(output), 0, "embedding_lookup: wrong shape of return from gather op."); - output->assign(result.at(0)); + outputView.at(e).assign(input); } - return Status::OK(); + } else { + int indexRank = indeces->rankOf(); + REQUIRE_TRUE(indexRank > 0, 0, + "embeded_lookup: input array of indexes can't be single " + "scalar, the requirement is: rank > 0 !"); + + int inputRank = input->rankOf(); + int lastIndDim = indeces->lengthOf(); + int partition_mode = + INT_ARG(0); // partition_mode == 0 - i.e. 'mod' , 1 - 'div' + + sd::ops::gather op; + + auto result(op.evaluate({input, indeces}, {0})); + REQUIRE_TRUE(result.status() == Status::OK(), 0, + "embedding_lookup: cannot retrieve results from gather op."); + REQUIRE_TRUE(result.at(0).isSameShape(output), 0, + "embedding_lookup: wrong shape of return from gather op."); + output->assign(result.at(0)); + } + return Status::OK(); } DECLARE_TYPES(embedding_lookup) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes(sd::DataType::ANY); + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes(sd::DataType::ANY); } DECLARE_SHAPE_FN(embedding_lookup) { + auto inShapeInfo = inputShape->at(0); + auto indecesShapeInfo = inputShape->at(1); + int inRank = shape::rank(inShapeInfo); + if (inputShape->size() == 2u) { + int outRank = inRank; - auto inShapeInfo = inputShape->at(0); - auto indecesShapeInfo = inputShape->at(1); - int inRank = shape::rank(inShapeInfo); - if (inputShape->size() == 2u) { - int outRank = inRank; - - std::vector shapeInfo(outRank); - - shapeInfo[0] = indecesShapeInfo[1]; // vector - how many elements - for (int e = 1; e < outRank; e++) - shapeInfo[e] = shape::sizeAt(inShapeInfo, e); - - auto outShapeInfo = ConstantShapeHelper::getInstance().createShapeInfo(ArrayOptions::dataType(inShapeInfo), shape::order(inShapeInfo), shapeInfo); - return SHAPELIST(outShapeInfo); - } - - - int outRank = inRank + 1; std::vector shapeInfo(outRank); - auto indeces = INPUT_VARIABLE(block.width() - 1); - shapeInfo[0] = indeces->lengthOf(); // vector - how many elements + + shapeInfo[0] = indecesShapeInfo[1]; // vector - how many elements for (int e = 1; e < outRank; e++) - shapeInfo[e] = shape::sizeAt(inShapeInfo, e); + shapeInfo[e] = shape::sizeAt(inShapeInfo, e); - auto outShapeInfo = ConstantShapeHelper::getInstance().createShapeInfo(ArrayOptions::dataType(inShapeInfo), shape::order(inShapeInfo), shapeInfo); + auto outShapeInfo = ConstantShapeHelper::getInstance().createShapeInfo( + ArrayOptions::dataType(inShapeInfo), shape::order(inShapeInfo), + shapeInfo); return SHAPELIST(outShapeInfo); + } + + int outRank = inRank + 1; + std::vector shapeInfo(outRank); + auto indeces = INPUT_VARIABLE(block.width() - 1); + shapeInfo[0] = indeces->lengthOf(); // vector - how many elements + for (int e = 1; e < outRank; e++) + shapeInfo[e] = shape::sizeAt(inShapeInfo, e); + + auto outShapeInfo = ConstantShapeHelper::getInstance().createShapeInfo( + ArrayOptions::dataType(inShapeInfo), shape::order(inShapeInfo), + shapeInfo); + return SHAPELIST(outShapeInfo); } - - - -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/nn/fusedBatchNorm.cpp b/libnd4j/include/ops/declarable/generic/nn/fusedBatchNorm.cpp index 6e911e405043..99d76e992510 100644 --- a/libnd4j/include/ops/declarable/generic/nn/fusedBatchNorm.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/fusedBatchNorm.cpp @@ -26,130 +26,147 @@ namespace sd { namespace ops { - DECLARE_TYPES(fused_batch_norm) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } +DECLARE_TYPES(fused_batch_norm) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} CUSTOM_OP_IMPL(fused_batch_norm, 3, 3, false, 0, 2) { - auto x = INPUT_VARIABLE(0); // [bS,iH,iW,iD] (NHWC) or [bS,iD,iH,iW] (NCHW) - auto scale = INPUT_VARIABLE(1); // [iD] - auto offset = INPUT_VARIABLE(2); // [iD] - - auto y = OUTPUT_VARIABLE(0); // [bS,iH,iW,iD] (NHWC) or [bS,iD,iH,iW] (NCHW) - auto batchMean = OUTPUT_VARIABLE(1); // [iD] - auto batchVar = OUTPUT_VARIABLE(2); // [iD] - - const bool dataFormat = (bool)INT_ARG(0); // 0->NHWC, 1->NCHW - const bool isTraining = (bool)INT_ARG(1); - - REQUIRE_TRUE(x->rankOf() == 4, 0, "CUSTOM_OP fused_batch_norm: the rank of input x array must be equal to 4, but got %i instead !", x->rankOf()); - - int bS = x->sizeAt(0); // batch size - int iH, iW, iD; // input height, input width, input depth(number of channels) - if(dataFormat) { - iD = x->sizeAt(1); - iH = x->sizeAt(2); - iW = x->sizeAt(3); - } - else { - iD = x->sizeAt(3); - iH = x->sizeAt(1); - iW = x->sizeAt(2); - } - - REQUIRE_TRUE(scale->rankOf() == 1 && scale->sizeAt(0) == iD, 0, "CUSTOM_OP fused_batch_norm: wrong shape of input scale array, expected is [%i], but got %s instead", iD, ShapeUtils::shapeAsString(scale).c_str()); - REQUIRE_TRUE(offset->rankOf() == 1 && offset->sizeAt(0) == iD, 0, "CUSTOM_OP fused_batch_norm: wrong shape of input offset array, expected is [%i], but got %s instead", iD, ShapeUtils::shapeAsString(offset).c_str()); - - NDArray *mean(nullptr), *variance(nullptr); - if(!isTraining){ - mean = INPUT_VARIABLE(3); - variance = INPUT_VARIABLE(4); - REQUIRE_TRUE(mean->rankOf() == 1 && mean->sizeAt(0) == iD, 0, "CUSTOM_OP fused_batch_norm: wrong shape of input mean array, expected is [%i], but got %s instead", iD, ShapeUtils::shapeAsString(mean).c_str()); - REQUIRE_TRUE(variance->rankOf() == 1 && variance->sizeAt(0) == iD, 0, "CUSTOM_OP fused_batch_norm: wrong shape of input variance array, expected is [%i], but got %s instead", iD, ShapeUtils::shapeAsString(variance).c_str()); - } - else { - //REQUIRE_TRUE(block.width() == 3, 0, "CUSTOM_OP fused_batch_norm: when isTraining=true then number of input arrays must be equal to 3, but got %i instead !", block.width()); - std::vector shape = {iD}; - mean = NDArrayFactory::create_(scale->ordering(), shape, scale->dataType(), block.launchContext()); - variance = NDArrayFactory::create_(scale->ordering(), shape, scale->dataType(), block.launchContext()); - } - - // FIXME: double? - double epsilon; - if(block.getTArguments()->size() > 0) - epsilon = T_ARG(0) > 1.001e-5 ? T_ARG(0) : 1.001e-5; - else - epsilon = 0.001; - - const int restSize = x->lengthOf() / iD; - auto xAffected = NDArrayFactory::create(x->ordering(), {restSize, iD}, mean->dataType(), block.launchContext()); - xAffected.assign(x); - - const int restSizeMinusOne = (restSize > 1) ? (restSize - 1) : 1; - // FIXME: float? - const double restSizeInv = 1.0 / restSize; - const double restSizeAdjust = (double)restSize / restSizeMinusOne; - - if(isTraining) { - auto sum = xAffected.reduceAlongDimension(reduce::Sum, {0}); - sum *= restSizeInv; - mean->assign(sum); - *batchMean = *mean; - //delete sum; - } - else - *batchMean = 0.; - - xAffected -= *mean; - - if(isTraining) { - int power = 2; - xAffected.applyScalar(scalar::Pow, power, xAffected); - auto sum = xAffected.reduceAlongDimension(reduce::Sum, {0}); - sum *= restSizeInv; - variance->assign(sum); - *batchVar = (*variance) * restSizeAdjust; - //delete sum; - } - else - *batchVar = 0.; - xAffected *= (*variance + epsilon).transform(transform::RSqrt) * (*scale) + (*offset); - y->assign( xAffected ); - - if(isTraining) { - delete mean; - delete variance; - } - - return Status::OK(); + auto x = INPUT_VARIABLE(0); // [bS,iH,iW,iD] (NHWC) or [bS,iD,iH,iW] (NCHW) + auto scale = INPUT_VARIABLE(1); // [iD] + auto offset = INPUT_VARIABLE(2); // [iD] + + auto y = OUTPUT_VARIABLE(0); // [bS,iH,iW,iD] (NHWC) or [bS,iD,iH,iW] (NCHW) + auto batchMean = OUTPUT_VARIABLE(1); // [iD] + auto batchVar = OUTPUT_VARIABLE(2); // [iD] + + const bool dataFormat = (bool)INT_ARG(0); // 0->NHWC, 1->NCHW + const bool isTraining = (bool)INT_ARG(1); + + REQUIRE_TRUE(x->rankOf() == 4, 0, + "CUSTOM_OP fused_batch_norm: the rank of input x array must be " + "equal to 4, but got %i instead !", + x->rankOf()); + + int bS = x->sizeAt(0); // batch size + int iH, iW, iD; // input height, input width, input depth(number of channels) + if (dataFormat) { + iD = x->sizeAt(1); + iH = x->sizeAt(2); + iW = x->sizeAt(3); + } else { + iD = x->sizeAt(3); + iH = x->sizeAt(1); + iW = x->sizeAt(2); + } + + REQUIRE_TRUE(scale->rankOf() == 1 && scale->sizeAt(0) == iD, 0, + "CUSTOM_OP fused_batch_norm: wrong shape of input scale array, " + "expected is [%i], but got %s instead", + iD, ShapeUtils::shapeAsString(scale).c_str()); + REQUIRE_TRUE(offset->rankOf() == 1 && offset->sizeAt(0) == iD, 0, + "CUSTOM_OP fused_batch_norm: wrong shape of input offset array, " + "expected is [%i], but got %s instead", + iD, ShapeUtils::shapeAsString(offset).c_str()); + + NDArray *mean(nullptr), *variance(nullptr); + if (!isTraining) { + mean = INPUT_VARIABLE(3); + variance = INPUT_VARIABLE(4); + REQUIRE_TRUE(mean->rankOf() == 1 && mean->sizeAt(0) == iD, 0, + "CUSTOM_OP fused_batch_norm: wrong shape of input mean array, " + "expected is [%i], but got %s instead", + iD, ShapeUtils::shapeAsString(mean).c_str()); + REQUIRE_TRUE(variance->rankOf() == 1 && variance->sizeAt(0) == iD, 0, + "CUSTOM_OP fused_batch_norm: wrong shape of input variance " + "array, expected is [%i], but got %s instead", + iD, ShapeUtils::shapeAsString(variance).c_str()); + } else { + // REQUIRE_TRUE(block.width() == 3, 0, "CUSTOM_OP fused_batch_norm: when + // isTraining=true then number of input arrays must be equal to 3, but got + // %i instead !", block.width()); + std::vector shape = {iD}; + mean = NDArrayFactory::create_(scale->ordering(), shape, scale->dataType(), + block.launchContext()); + variance = NDArrayFactory::create_( + scale->ordering(), shape, scale->dataType(), block.launchContext()); + } + + // FIXME: double? + double epsilon; + if (block.numT() > 0) + epsilon = T_ARG(0) > 1.001e-5 ? T_ARG(0) : 1.001e-5; + else + epsilon = 0.001; + + const int restSize = x->lengthOf() / iD; + auto xAffected = NDArrayFactory::create( + x->ordering(), {restSize, iD}, mean->dataType(), block.launchContext()); + xAffected.assign(x); + + const int restSizeMinusOne = (restSize > 1) ? (restSize - 1) : 1; + // FIXME: float? + const double restSizeInv = 1.0 / restSize; + const double restSizeAdjust = (double)restSize / restSizeMinusOne; + + if (isTraining) { + auto sum = xAffected.reduceAlongDimension(reduce::Sum, {0}); + sum *= restSizeInv; + mean->assign(sum); + *batchMean = *mean; + // delete sum; + } else + *batchMean = 0.; + + xAffected -= *mean; + + if (isTraining) { + int power = 2; + xAffected.applyScalar(scalar::Pow, power, xAffected); + auto sum = xAffected.reduceAlongDimension(reduce::Sum, {0}); + sum *= restSizeInv; + variance->assign(sum); + *batchVar = (*variance) * restSizeAdjust; + // delete sum; + } else + *batchVar = 0.; + xAffected *= + (*variance + epsilon).transform(transform::RSqrt) * (*scale) + (*offset); + y->assign(xAffected); + + if (isTraining) { + delete mean; + delete variance; + } + + return Status::OK(); } - - DECLARE_SHAPE_FN(fused_batch_norm) { - auto xShapeInfo = inputShape->at(0); - auto scaleShapeInfo = inputShape->at(1); + auto xShapeInfo = inputShape->at(0); + auto scaleShapeInfo = inputShape->at(1); - const bool dataFormat = (bool)INT_ARG(0); // 0->NHWC, 1->NCHW - const int iD = dataFormat ? xShapeInfo[2] : xShapeInfo[4]; + const bool dataFormat = (bool)INT_ARG(0); // 0->NHWC, 1->NCHW + const int iD = dataFormat ? xShapeInfo[2] : xShapeInfo[4]; - REQUIRE_TRUE(scaleShapeInfo[0] == 1 && scaleShapeInfo[1] == iD, 0, "CUSTOM_OP fused_batch_norm: wrong shape of input scale array, expected is [%i], but got %s instead", iD, ShapeUtils::shapeAsString(scaleShapeInfo).c_str()); + REQUIRE_TRUE(scaleShapeInfo[0] == 1 && scaleShapeInfo[1] == iD, 0, + "CUSTOM_OP fused_batch_norm: wrong shape of input scale array, " + "expected is [%i], but got %s instead", + iD, ShapeUtils::shapeAsString(scaleShapeInfo).c_str()); - Nd4jLong* outShapeInfo(nullptr), *batchMeanShapeInfo(nullptr), *batchVarShapeInfo(nullptr); + Nd4jLong *outShapeInfo(nullptr), *batchMeanShapeInfo(nullptr), + *batchVarShapeInfo(nullptr); - COPY_SHAPE(xShapeInfo, outShapeInfo); - COPY_SHAPE(scaleShapeInfo, batchMeanShapeInfo); - COPY_SHAPE(scaleShapeInfo, batchVarShapeInfo); + COPY_SHAPE(xShapeInfo, outShapeInfo); + COPY_SHAPE(scaleShapeInfo, batchMeanShapeInfo); + COPY_SHAPE(scaleShapeInfo, batchVarShapeInfo); - return SHAPELIST(CONSTANT(outShapeInfo), CONSTANT(batchMeanShapeInfo), CONSTANT(batchVarShapeInfo)); + return SHAPELIST(CONSTANT(outShapeInfo), CONSTANT(batchMeanShapeInfo), + CONSTANT(batchVarShapeInfo)); } - - - -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/nn/layer_norm.cpp b/libnd4j/include/ops/declarable/generic/nn/layer_norm.cpp index 5643932cbf7a..f21242937cab 100644 --- a/libnd4j/include/ops/declarable/generic/nn/layer_norm.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/layer_norm.cpp @@ -22,124 +22,147 @@ #if NOT_EXCLUDED(OP_layer_norm) #include -#include #include +#include namespace sd { -namespace ops { +namespace ops { + +CONFIGURABLE_OP_IMPL(layer_norm, 2, 1, false, 0, -1) { + auto input = INPUT_VARIABLE(0); + auto gain = INPUT_VARIABLE(1); + auto output = OUTPUT_VARIABLE(0); + + auto axis = block.getIArguments(); + + const bool isNCHW = + block.numB() > 0 ? B_ARG(0) : true; // INT_ARG(9): 0-NCHW, 1-NHWC + const int dimC = isNCHW ? 1 : input->rankOf() - 1; + + REQUIRE_TRUE(gain->rankOf() == 1 && gain->sizeAt(0) == input->sizeAt(dimC), 0, + "LAYER_NORM OP: wrong shape of gain array, expected is {%i}, " + "but got %s instead !", + input->sizeAt(dimC), ShapeUtils::shapeAsString(gain).c_str()); + + NDArray *bias = nullptr; + if (block.width() > 2) { + bias = INPUT_VARIABLE(2); + REQUIRE_TRUE(bias->rankOf() == 1 && bias->sizeAt(0) == input->sizeAt(dimC), + 0, + "LAYER_NORM OP: wrong shape of bias array, expected is {%i}, " + "but got %s instead !", + input->sizeAt(dimC), ShapeUtils::shapeAsString(bias).c_str()); + } + + std::vector longAxis = ArrayUtils::toLongVector(axis); + + sd::ops::standardize standardizeOp; + std::vector inputs = {input}; + std::vector outputs = {output}; + std::vector targs = {}; + std::vector bargs = {}; + standardizeOp.execute(inputs, outputs, targs, longAxis, bargs); + + // output->applyTrueBroadcast(sd::BroadcastOpsTuple::Multiply(), gain, + // output); + output->applyBroadcast(sd::broadcast::Multiply, {dimC}, *gain, *output); + if (bias != nullptr) { + // output->applyTrueBroadcast(sd::BroadcastOpsTuple::Add(), bias, output); + // output->applyBroadcast(sd::broadcast::Add, {dimC}, bias); + helpers::addBias(block, *output, *bias, *output, isNCHW); + } + + return Status::OK(); +} - CONFIGURABLE_OP_IMPL(layer_norm, 2, 1, false, 0, -1) { - auto input = INPUT_VARIABLE(0); - auto gain = INPUT_VARIABLE(1); - auto output = OUTPUT_VARIABLE(0); +DECLARE_TYPES(layer_norm) { + getOpDescriptor()->setAllowedInputTypes({ALL_FLOATS}); + getOpDescriptor()->setAllowedOutputTypes({ALL_FLOATS}); +} - std::vector axis = *block.getIArguments(); - - const bool isNCHW = block.getBArguments()->size() > 0 ? B_ARG(0) : true; // INT_ARG(9): 0-NCHW, 1-NHWC - const int dimC = isNCHW ? 1 : input->rankOf() - 1; - - REQUIRE_TRUE(gain->rankOf() == 1 && gain->sizeAt(0) == input->sizeAt(dimC), 0, "LAYER_NORM OP: wrong shape of gain array, expected is {%i}, but got %s instead !", input->sizeAt(dimC), ShapeUtils::shapeAsString(gain).c_str()); - - NDArray* bias = nullptr; - if (block.width() > 2) { - bias = INPUT_VARIABLE(2); - REQUIRE_TRUE(bias->rankOf() == 1 && bias->sizeAt(0) == input->sizeAt(dimC), 0, "LAYER_NORM OP: wrong shape of bias array, expected is {%i}, but got %s instead !", input->sizeAt(dimC), ShapeUtils::shapeAsString(bias).c_str()); - } - - std::vector longAxis = ArrayUtils::toLongVector(axis); - - sd::ops::standardize standardizeOp; - std::vector inputs = {input}; - std::vector outputs = {output}; - std::vector targs = {}; - std::vector bargs = {}; - standardizeOp.execute(inputs, outputs, targs, longAxis, bargs); - - // output->applyTrueBroadcast(sd::BroadcastOpsTuple::Multiply(), gain, output); - output->applyBroadcast(sd::broadcast::Multiply, {dimC}, *gain, *output); - if(bias != nullptr) { - // output->applyTrueBroadcast(sd::BroadcastOpsTuple::Add(), bias, output); - // output->applyBroadcast(sd::broadcast::Add, {dimC}, bias); - helpers::addBias(block, *output, *bias, *output, isNCHW); - } - - return Status::OK(); - } - - - DECLARE_TYPES(layer_norm) { - getOpDescriptor()->setAllowedInputTypes({ALL_FLOATS}); - getOpDescriptor()->setAllowedOutputTypes({ALL_FLOATS}); - } - - CUSTOM_OP_IMPL(layer_norm_bp, 3, -1, false, 0, -1) { - auto input = INPUT_VARIABLE(0); - auto gain = INPUT_VARIABLE(1); - auto bias = block.width() == 4 ? INPUT_VARIABLE(2) : nullptr; - auto eps = block.width() == 4 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); - - auto dLdx = OUTPUT_VARIABLE(0); - auto dLdg = OUTPUT_VARIABLE(1); - auto dLdb = block.width() == 4 ? OUTPUT_VARIABLE(2) : nullptr; - - const bool isNCHW = block.getBArguments()->size() > 0 ? B_ARG(0) : true; // INT_ARG(9): 0-NCHW, 1-NHWC - const int dimC = isNCHW ? 1 : input->rankOf() - 1; - - REQUIRE_TRUE(gain->rankOf() == 1 && gain->sizeAt(0) == input->sizeAt(dimC), 0, "LAYER_NORM_BP OP: wrong shape of gain array, expected is {%i}, but got %s instead !", input->sizeAt(dimC), ShapeUtils::shapeAsString(gain).c_str()); - - std::vector axis = *block.getIArguments(); - - std::vector longAxis = ArrayUtils::toLongVector(axis); - - if(bias != nullptr) { - REQUIRE_TRUE(bias->rankOf() == 1 && bias->sizeAt(0) == input->sizeAt(dimC), 0, "LAYER_NORM_BP OP: wrong shape of bias array, expected is {%i}, but got %s instead !", input->sizeAt(dimC), ShapeUtils::shapeAsString(bias).c_str()); - // eps->reduceAlongDimension(sd::reduce::Sum, *dLdb, {0}, true); - eps->reduceAlongDimension(sd::reduce::Sum, *dLdb, ShapeUtils::evalDimsToExclude(input->rankOf(), {dimC})); - } - - NDArray standardized(input->shapeInfo(), false, block.launchContext()); - - sd::ops::standardize standardizeOp; - std::vector inputs = {input}; - std::vector outputs = {&standardized}; - std::vector targs = {}; - std::vector bargs = {}; - - standardizeOp.execute(inputs, outputs, targs, longAxis, bargs); - standardized.applyPairwiseTransform(sd::pairwise::Multiply, *eps, standardized); - standardized.reduceAlongDimension(sd::reduce::Sum, *dLdg, ShapeUtils::evalDimsToExclude(input->rankOf(), {dimC})); - - sd::ops::standardize_bp standardizeBp; - // eps->applyTrueBroadcast(sd::BroadcastOpsTuple::Multiply(), gain, dLdx); - eps->applyBroadcast(sd::broadcast::Multiply, {dimC}, *gain, *dLdx); - - auto dLdx_tmp = dLdx->dup(); - std::vector standardizeBpArgs = {input, &dLdx_tmp}; - std::vector standardizeBpOut = {dLdx}; - standardizeBp.execute(standardizeBpArgs, standardizeBpOut, targs, longAxis, bargs); - - return Status::OK(); - } - - DECLARE_TYPES(layer_norm_bp) { - getOpDescriptor()->setAllowedInputTypes({ALL_FLOATS}); - getOpDescriptor()->setAllowedOutputTypes({ALL_FLOATS}); - } - - DECLARE_SHAPE_FN(layer_norm_bp) { - Nd4jLong *dLdx_shape; - COPY_SHAPE(inputShape->at(0), dLdx_shape); - Nd4jLong *dLdg_shape; - COPY_SHAPE(inputShape->at(1), dLdg_shape); - if(inputShape->size() > 3){ - Nd4jLong *dLdb_shape; - COPY_SHAPE(inputShape->at(2), dLdb_shape); - return SHAPELIST(CONSTANT(dLdx_shape), CONSTANT(dLdg_shape), CONSTANT(dLdb_shape)); - } - return SHAPELIST(CONSTANT(dLdx_shape), CONSTANT(dLdg_shape)); - } +CUSTOM_OP_IMPL(layer_norm_bp, 3, -1, false, 0, -1) { + auto input = INPUT_VARIABLE(0); + auto gain = INPUT_VARIABLE(1); + auto bias = block.width() == 4 ? INPUT_VARIABLE(2) : nullptr; + auto eps = block.width() == 4 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); + + auto dLdx = OUTPUT_VARIABLE(0); + auto dLdg = OUTPUT_VARIABLE(1); + auto dLdb = block.width() == 4 ? OUTPUT_VARIABLE(2) : nullptr; + + const bool isNCHW = + block.numB() > 0 ? B_ARG(0) : true; // INT_ARG(9): 0-NCHW, 1-NHWC + const int dimC = isNCHW ? 1 : input->rankOf() - 1; + + REQUIRE_TRUE(gain->rankOf() == 1 && gain->sizeAt(0) == input->sizeAt(dimC), 0, + "LAYER_NORM_BP OP: wrong shape of gain array, expected is {%i}, " + "but got %s instead !", + input->sizeAt(dimC), ShapeUtils::shapeAsString(gain).c_str()); + + auto axis = block.getIArguments(); + + std::vector longAxis = ArrayUtils::toLongVector(axis); + + if (bias != nullptr) { + REQUIRE_TRUE(bias->rankOf() == 1 && bias->sizeAt(0) == input->sizeAt(dimC), + 0, + "LAYER_NORM_BP OP: wrong shape of bias array, expected is " + "{%i}, but got %s instead !", + input->sizeAt(dimC), ShapeUtils::shapeAsString(bias).c_str()); + // eps->reduceAlongDimension(sd::reduce::Sum, *dLdb, {0}, true); + eps->reduceAlongDimension( + sd::reduce::Sum, *dLdb, + ShapeUtils::evalDimsToExclude(input->rankOf(), {dimC})); + } + + NDArray standardized(input->shapeInfo(), false, block.launchContext()); + + sd::ops::standardize standardizeOp; + std::vector inputs = {input}; + std::vector outputs = {&standardized}; + std::vector targs = {}; + std::vector bargs = {}; + + standardizeOp.execute(inputs, outputs, targs, longAxis, bargs); + standardized.applyPairwiseTransform(sd::pairwise::Multiply, *eps, + standardized); + standardized.reduceAlongDimension( + sd::reduce::Sum, *dLdg, + ShapeUtils::evalDimsToExclude(input->rankOf(), {dimC})); + + sd::ops::standardize_bp standardizeBp; + // eps->applyTrueBroadcast(sd::BroadcastOpsTuple::Multiply(), gain, dLdx); + eps->applyBroadcast(sd::broadcast::Multiply, {dimC}, *gain, *dLdx); + + auto dLdx_tmp = dLdx->dup(); + std::vector standardizeBpArgs = {input, &dLdx_tmp}; + std::vector standardizeBpOut = {dLdx}; + standardizeBp.execute(standardizeBpArgs, standardizeBpOut, targs, longAxis, + bargs); + + return Status::OK(); +} +DECLARE_TYPES(layer_norm_bp) { + getOpDescriptor()->setAllowedInputTypes({ALL_FLOATS}); + getOpDescriptor()->setAllowedOutputTypes({ALL_FLOATS}); } + +DECLARE_SHAPE_FN(layer_norm_bp) { + Nd4jLong *dLdx_shape; + COPY_SHAPE(inputShape->at(0), dLdx_shape); + Nd4jLong *dLdg_shape; + COPY_SHAPE(inputShape->at(1), dLdg_shape); + if (inputShape->size() > 3) { + Nd4jLong *dLdb_shape; + COPY_SHAPE(inputShape->at(2), dLdb_shape); + return SHAPELIST(CONSTANT(dLdx_shape), CONSTANT(dLdg_shape), + CONSTANT(dLdb_shape)); + } + return SHAPELIST(CONSTANT(dLdx_shape), CONSTANT(dLdg_shape)); } +} // namespace ops +} // namespace sd + #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/nn/logSoftmax.cpp b/libnd4j/include/ops/declarable/generic/nn/logSoftmax.cpp index 64aadce370a7..12e71befd578 100644 --- a/libnd4j/include/ops/declarable/generic/nn/logSoftmax.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/logSoftmax.cpp @@ -27,55 +27,59 @@ namespace sd { namespace ops { - - DECLARE_TYPES(log_softmax) { - getOpDescriptor() - ->setAllowedInputTypes({ALL_FLOATS}) - ->setSameMode(true); - } +DECLARE_TYPES(log_softmax) { + getOpDescriptor()->setAllowedInputTypes({ALL_FLOATS})->setSameMode(true); +} CONFIGURABLE_OP_IMPL(log_softmax, 1, 1, true, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - const int rank = input->rankOf(); - const int dim = block.getIArguments()->size() > 0 ? INT_ARG(0) : rank - 1; + const int rank = input->rankOf(); + const int dim = block.numI() > 0 ? INT_ARG(0) : rank - 1; - REQUIRE_TRUE(dim < rank, 0, "LOG_SOFTMAX OP: the value of input integer parameter (dimension) must be less than input array rank %i, but got dimension = %i instead !", rank, dim); + REQUIRE_TRUE( + dim < rank, 0, + "LOG_SOFTMAX OP: the value of input integer parameter (dimension) must " + "be less than input array rank %i, but got dimension = %i instead !", + rank, dim); - helpers::logSoftmax(block.launchContext(), *input, *output, dim); + helpers::logSoftmax(block.launchContext(), *input, *output, dim); - return Status::OK(); + return Status::OK(); } - DECLARE_TYPES(log_softmax_bp) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, {ALL_FLOATS}) - ->setAllowedOutputTypes({ALL_FLOATS}); - } - +DECLARE_TYPES(log_softmax_bp) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes(1, {ALL_FLOATS}) + ->setAllowedOutputTypes({ALL_FLOATS}); +} CONFIGURABLE_OP_IMPL(log_softmax_bp, 2, 1, true, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto gradO = INPUT_VARIABLE(1); - auto gradI = OUTPUT_VARIABLE(0); + auto input = INPUT_VARIABLE(0); + auto gradO = INPUT_VARIABLE(1); + auto gradI = OUTPUT_VARIABLE(0); - const int rank = input->rankOf(); - const int dim = block.getIArguments()->size() > 0 ? INT_ARG(0) : rank - 1; + const int rank = input->rankOf(); + const int dim = block.numI() > 0 ? INT_ARG(0) : rank - 1; - REQUIRE_TRUE(dim < rank, 0, "LOG_SOFTMAX_BP OP: the value of input integer parameter (dimension) must be less than input array rank %i, but got dimension = %i instead !", rank, dim); + REQUIRE_TRUE( + dim < rank, 0, + "LOG_SOFTMAX_BP OP: the value of input integer parameter (dimension) " + "must be less than input array rank %i, but got dimension = %i instead !", + rank, dim); - helpers::softmax(block.launchContext(), *input, *gradI, dim); + helpers::softmax(block.launchContext(), *input, *gradI, dim); - gradI->assign( *gradO - (*gradI * *gradO).reduceAlongDimension(reduce::Sum, {dim}, true) ); + gradI->assign( + *gradO - + (*gradI * *gradO).reduceAlongDimension(reduce::Sum, {dim}, true)); - return Status::OK(); + return Status::OK(); } - - -} -} +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/generic/nn/lrn.cpp b/libnd4j/include/ops/declarable/generic/nn/lrn.cpp index e9546d1db711..44653807e8ce 100644 --- a/libnd4j/include/ops/declarable/generic/nn/lrn.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/lrn.cpp @@ -23,59 +23,67 @@ #include #if NOT_EXCLUDED(OP_lrn) -#include #include +#include namespace sd { - namespace ops { - - DECLARE_TYPES(lrn) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } - - CONFIGURABLE_OP_IMPL(lrn, 1, 1, true, 3, 1) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - REQUIRE_TRUE(input->rankOf() == 4, 0, "lrn: Input rank of 4 expected, but got %i instead", input->rankOf()); - - double alpha = T_ARG(1); - double beta = T_ARG(2); - double bias = T_ARG(0); - int depth = INT_ARG(0); - - return helpers::lrnFunctor(block, input, output, depth, bias, alpha, beta); - } - - DECLARE_TYPES(lrn_bp) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } - - CONFIGURABLE_OP_IMPL(lrn_bp, 2, 1, true, 3, 1) { - auto input = INPUT_VARIABLE(0); - auto gradO = INPUT_VARIABLE(1); - auto gradI = OUTPUT_VARIABLE(0); - - REQUIRE_TRUE(input->rankOf() == 4, 0, "lrn_bp: Input rank of 4 expected, but got %i instead", input->rankOf()); - REQUIRE_TRUE(input->isSameShape(gradO), 0, "lrn_bp: Both input and grad_output should have the same shape, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(input).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); - - // FIXME: double/float? - float bias = T_ARG(0); - float alpha = T_ARG(1); - float beta = T_ARG(2); - int depth = INT_ARG(0); - - helpers::lrnBP(block, *input, *gradO, *gradI, depth, bias, alpha, beta); - - return Status::OK(); - } - DECLARE_SYN(local_response_normalization, lrn); - - } +namespace ops { + +DECLARE_TYPES(lrn) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } +CONFIGURABLE_OP_IMPL(lrn, 1, 1, true, 3, 1) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + REQUIRE_TRUE(input->rankOf() == 4, 0, + "lrn: Input rank of 4 expected, but got %i instead", + input->rankOf()); + + double alpha = T_ARG(1); + double beta = T_ARG(2); + double bias = T_ARG(0); + int depth = INT_ARG(0); + + return helpers::lrnFunctor(block, input, output, depth, bias, alpha, beta); +} + +DECLARE_TYPES(lrn_bp) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} + +CONFIGURABLE_OP_IMPL(lrn_bp, 2, 1, true, 3, 1) { + auto input = INPUT_VARIABLE(0); + auto gradO = INPUT_VARIABLE(1); + auto gradI = OUTPUT_VARIABLE(0); + + REQUIRE_TRUE(input->rankOf() == 4, 0, + "lrn_bp: Input rank of 4 expected, but got %i instead", + input->rankOf()); + REQUIRE_TRUE(input->isSameShape(gradO), 0, + "lrn_bp: Both input and grad_output should have the same shape, " + "but got %s and %s correspondingly !", + ShapeUtils::shapeAsString(input).c_str(), + ShapeUtils::shapeAsString(gradO).c_str()); + + // FIXME: double/float? + float bias = T_ARG(0); + float alpha = T_ARG(1); + float beta = T_ARG(2); + int depth = INT_ARG(0); + + helpers::lrnBP(block, *input, *gradO, *gradI, depth, bias, alpha, beta); + + return Status::OK(); +} +DECLARE_SYN(local_response_normalization, lrn); + +} // namespace ops +} // namespace sd + #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/nn/multi_head_dot_product_attention.cpp b/libnd4j/include/ops/declarable/generic/nn/multi_head_dot_product_attention.cpp index 7ff8eb4c5972..c43599f8d6f4 100644 --- a/libnd4j/include/ops/declarable/generic/nn/multi_head_dot_product_attention.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/multi_head_dot_product_attention.cpp @@ -21,268 +21,357 @@ #include #if NOT_EXCLUDED(OP_multi_head_dot_product_attention) -#include #include +#include namespace sd { -namespace ops { - - CUSTOM_OP_IMPL(multi_head_dot_product_attention, 7, -1, false, 0, 2) { - auto queries = INPUT_VARIABLE(0); //[batch, nIn, timeSteps] - auto keys = INPUT_VARIABLE(1); //[batch, nIn, timeSteps] - auto values = INPUT_VARIABLE(2); //[batch, nIn, timeSteps] - auto Wq = INPUT_VARIABLE(3); //[nHeads, headSize, nIn] - auto Wk = INPUT_VARIABLE(4); //[nHeads, headSize, nIn] - auto Wv = INPUT_VARIABLE(5); //[nHeads, headSize, nIn] - auto Wo = INPUT_VARIABLE(6); //[nHeads * headSize, nOut] - auto mask = block.width() > 7 ? INPUT_VARIABLE(7) : nullptr; - - - auto output = OUTPUT_VARIABLE(0); - int normalization = INT_ARG(0); - int weights = INT_ARG(1); - - auto numHeads = Wk->sizeAt(0); - auto miniBatchSize = queries->sizeAt(0); - auto queryCount = queries->sizeAt(2); - auto projectedValuesSize = Wv->sizeAt(1); - auto outSize = Wo->sizeAt(1); - - REQUIRE_TRUE(queries->rankOf() == keys->rankOf() && keys->rankOf() == values->rankOf(), 0, - "multi_head_dot_product_attention: Queries, Keys and Values must have same rank. " - "But got queries = %s, keys = %s, values = %s", ShapeUtils::shapeAsString(queries).c_str(), - ShapeUtils::shapeAsString(keys).c_str(), ShapeUtils::shapeAsString(values).c_str()); - - REQUIRE_TRUE(queries->rankOf() == 3, 0, - "multi_head_dot_product_attention: Queries, Keys and Values must be rank 3 arrays" - "But got rank = %i", queries->rankOf()); - - REQUIRE_TRUE(Wq->rankOf() == Wk->rankOf() && Wk->rankOf() == Wv->rankOf(), 0, - "multi_head_dot_product_attention: Input projections weights must have the same rank. " - "But got Wq = %s, Wk = %s, Wv = %s", ShapeUtils::shapeAsString(Wq).c_str(), - ShapeUtils::shapeAsString(Wk).c_str(), ShapeUtils::shapeAsString(Wv).c_str()); - - REQUIRE_TRUE(Wq->sizeAt(0) == Wk->sizeAt(0) && Wk->sizeAt(0) == Wv->sizeAt(0), 0, - "multi_head_dot_product_attention: Projections weights must have the same number of attention heads. " - "But got Wq = %s, Wk = %s, Wv = %s", ShapeUtils::shapeAsString(Wq).c_str(), - ShapeUtils::shapeAsString(Wk).c_str(), ShapeUtils::shapeAsString(Wv).c_str()); - - REQUIRE_TRUE(Wo->rankOf() == 2, 0, - "multi_head_dot_product_attention: Output projection weights must have rank 2. " - "But got Wo = %s", ShapeUtils::shapeAsString(Wo).c_str()); - - REQUIRE_TRUE(Wq->sizeAt(2) == queries->sizeAt(1), 0, - "multi_head_dot_product_attention: Query projection matrix Wq has incompatible size to queries matrix." - "Expected Wq[2] = queries[1] = %i, but got Wq = %s, queries = %s ", queries->sizeAt(1), - ShapeUtils::shapeAsString(Wq).c_str(), ShapeUtils::shapeAsString(queries).c_str()); - - REQUIRE_TRUE(Wk->sizeAt(2) == keys->sizeAt(1), 0, - "multi_head_dot_product_attention: Key projection matrix Wk has incompatible size to keys matrix." - "Expected Wk[2] = keys[1] = %i, but got Wk = %s, keys = %s ", keys->sizeAt(1), - ShapeUtils::shapeAsString(Wk).c_str(), ShapeUtils::shapeAsString(keys).c_str()); - - REQUIRE_TRUE(Wv->sizeAt(2) == values->sizeAt(1), 0, - "multi_head_dot_product_attention: Value projection matrix Wv has incompatible size to values matrix." - "Expected Wv[2] = values[1] = %i, but got Wv = %s, values = %s ", values->sizeAt(1), - ShapeUtils::shapeAsString(Wv).c_str(), ShapeUtils::shapeAsString(values).c_str()); - - REQUIRE_TRUE(Wo->sizeAt(0) == (Wv->sizeAt(1) * Wv->sizeAt(0)), 0, - "multi_head_dot_product_attention: Output projection matrix Wo has incompatible size to attention result." - "Expected Wo[0] = Wv[0] * Wv[1] = %i, but got Wo = %s, Wv = %", (Wv->sizeAt(1) * Wv->sizeAt(0)), - ShapeUtils::shapeAsString(Wo).c_str(), ShapeUtils::shapeAsString(Wv).c_str()); - - - // Project queries, keys, values - auto projectedQueries = AttentionHelper::multiHeadProject(queries, Wq, block.launchContext()); //[minibatch, numHeads, projectedSize, seqLength] - auto projectedKeys = AttentionHelper::multiHeadProject(keys, Wk, block.launchContext()); //[minibatch, numHeads, projectedSize, seqLength] - auto projectedValues = AttentionHelper::multiHeadProject(values, Wv, block.launchContext()); //[minibatch, numHeads, projectedSize, seqLength] - - // Apply Attention - // attnResults = [minibatch, numHeads, projectedSize, seqLenth - NDArray attnResults('c', {projectedQueries.sizeAt(0), projectedValues.sizeAt(1), projectedValues.sizeAt(2), projectedQueries.sizeAt(3)}, projectedValues.dataType(), block.launchContext()); - sd::ops::dot_product_attention attention; - attention.execute({&projectedQueries, &projectedKeys, &projectedValues, mask}, {&attnResults, weights ? OUTPUT_VARIABLE(1) : nullptr}, {}, {normalization, weights}, {}); - - // Project attention results - attnResults.permutei({0, 3, 1, 2}); - attnResults.reshapei(attnResults.ordering(), {miniBatchSize * queryCount, numHeads * projectedValuesSize}); - - sd::ops::matmul mmul; - NDArray projRes('c', {attnResults.sizeAt(0), Wo->sizeAt(1)}, values->dataType(), block.launchContext()); - mmul.execute({&attnResults, Wo},{&projRes}, {}, {}, {}); - projRes.reshapei(projRes.ordering(), {miniBatchSize, queryCount, outSize}); - projRes.permutei({0, 2, 1}); - - // FIXME: bad for performance - output->assign(projRes); - - return Status::OK(); - } - - - DECLARE_TYPES(multi_head_dot_product_attention) { - getOpDescriptor()->setAllowedInputTypes({ALL_FLOATS}); - getOpDescriptor()->setAllowedOutputTypes({ALL_FLOATS}); - } - - DECLARE_SHAPE_FN(multi_head_dot_product_attention) { - auto queryShape = inputShape->at(0); - auto keysShape = inputShape->at(1); - auto valuesShape = inputShape->at(2); - auto WkShape = inputShape->at(3); - auto WoShape = inputShape->at(6); - - auto batchSize = shape::sizeAt(queryShape, 0); - auto outSize = shape::sizeAt(WoShape, 1); - auto queryCount = shape::sizeAt(queryShape, 2); - auto numHeads = shape::sizeAt(WkShape, 0); - auto timeSteps = shape::sizeAt(keysShape, 2); - - auto weightsShape = ConstantShapeHelper::getInstance().createShapeInfo(sd::ArrayOptions::dataType(valuesShape), 'c', {batchSize, numHeads, timeSteps, queryCount}); - auto outputShape = ConstantShapeHelper::getInstance().createShapeInfo(sd::ArrayOptions::dataType(valuesShape), 'c', {batchSize, outSize, queryCount}); - - if(INT_ARG(1)){ - return SHAPELIST(outputShape, weightsShape); - }else{ - return SHAPELIST(outputShape); - } - - } - - CUSTOM_OP_IMPL(multi_head_dot_product_attention_bp, 8, 7, false, 0, 1) { - auto queries = INPUT_VARIABLE(0); - auto keys = INPUT_VARIABLE(1); - auto values = INPUT_VARIABLE(2); - auto Wq = INPUT_VARIABLE(3); - auto Wk = INPUT_VARIABLE(4); - auto Wv = INPUT_VARIABLE(5); - auto Wo = INPUT_VARIABLE(6); - auto eps = INPUT_VARIABLE(7); - auto mask = block.width() > 8 ? INPUT_VARIABLE(8) : nullptr; - - auto dLdq = OUTPUT_VARIABLE(0); - auto dLdk = OUTPUT_VARIABLE(1); - auto dLdv = OUTPUT_VARIABLE(2); - auto dLdWq = OUTPUT_VARIABLE(3); - auto dLdWk = OUTPUT_VARIABLE(4); - auto dLdWv = OUTPUT_VARIABLE(5); - auto dLdWo = OUTPUT_VARIABLE(6); - - int normalization = INT_ARG(0); - - auto numHeads = Wk->sizeAt(0); - auto miniBatchSize = queries->sizeAt(0); - auto queryCount = queries->sizeAt(2); - auto outSize = Wo->sizeAt(1); - auto projectedValuesSize = Wv->sizeAt(1); - - - REQUIRE_TRUE(queries->rankOf() == keys->rankOf() && keys->rankOf() == values->rankOf(), 0, - "multi_head_dot_product_attention: Queries, Keys and Values must have same rank. " - "But got queries = %s, keys = %s, values = %s", ShapeUtils::shapeAsString(queries).c_str(), - ShapeUtils::shapeAsString(keys).c_str(), ShapeUtils::shapeAsString(values).c_str()); - - REQUIRE_TRUE(queries->rankOf() == 3, 0, - "multi_head_dot_product_attention: Queries, Keys and Values must be rank 3 arrays" - "But got rank = %i", queries->rankOf()); - - REQUIRE_TRUE(Wq->rankOf() == Wk->rankOf() && Wk->rankOf() == Wv->rankOf(), 0, - "multi_head_dot_product_attention: Input projections weights must have the same rank. " - "But got Wq = %s, Wk = %s, Wv = %s", ShapeUtils::shapeAsString(Wq).c_str(), - ShapeUtils::shapeAsString(Wk).c_str(), ShapeUtils::shapeAsString(Wv).c_str()); - - REQUIRE_TRUE(Wq->sizeAt(0) == Wk->sizeAt(0) && Wk->sizeAt(0) == Wv->sizeAt(0), 0, - "multi_head_dot_product_attention: Projections weights must have the same number of attention heads. " - "But got Wq = %s, Wk = %s, Wv = %s", ShapeUtils::shapeAsString(Wq).c_str(), - ShapeUtils::shapeAsString(Wk).c_str(), ShapeUtils::shapeAsString(Wv).c_str()); - - REQUIRE_TRUE(Wo->rankOf() == 2, 0, - "multi_head_dot_product_attention: Output projection weights must have rank 2. " - "But got Wo = %s", ShapeUtils::shapeAsString(Wo).c_str()); - - REQUIRE_TRUE(Wq->sizeAt(2) == queries->sizeAt(1), 0, - "multi_head_dot_product_attention: Query projection matrix Wq has incompatible size to queries matrix." - "Expected Wq[2] = queries[1] = %i, but got Wq = %s, queries = %s ", queries->sizeAt(1), - ShapeUtils::shapeAsString(Wq).c_str(), ShapeUtils::shapeAsString(queries).c_str()); - - REQUIRE_TRUE(Wk->sizeAt(2) == keys->sizeAt(1), 0, - "multi_head_dot_product_attention: Key projection matrix Wk has incompatible size to keys matrix." - "Expected Wk[2] = keys[1] = %i, but got Wk = %s, keys = %s ", keys->sizeAt(1), - ShapeUtils::shapeAsString(Wk).c_str(), ShapeUtils::shapeAsString(keys).c_str()); - - REQUIRE_TRUE(Wv->sizeAt(2) == values->sizeAt(1), 0, - "multi_head_dot_product_attention: Value projection matrix Wv has incompatible size to values matrix." - "Expected Wv[2] = values[1] = %i, but got Wv = %s, values = %s ", values->sizeAt(1), - ShapeUtils::shapeAsString(Wv).c_str(), ShapeUtils::shapeAsString(values).c_str()); - - REQUIRE_TRUE(Wo->sizeAt(0) == (Wv->sizeAt(1) * Wv->sizeAt(0)), 0, - "multi_head_dot_product_attention: Output projection matrix Wo has incompatible size to attention result." - "Expected Wo[0] = Wv[0] * Wv[1] = %i, but got Wo = %s, Wv = %", (Wv->sizeAt(1) * Wv->sizeAt(0)), - ShapeUtils::shapeAsString(Wo).c_str(), ShapeUtils::shapeAsString(Wv).c_str()); - - // Project queries, keys, values - auto projectedQueries = AttentionHelper::multiHeadProject(queries, Wq, block.launchContext()); - auto projectedKeys = AttentionHelper::multiHeadProject(keys, Wk, block.launchContext()); - auto projectedValues = AttentionHelper::multiHeadProject(values, Wv, block.launchContext()); - - // Apply Attention - NDArray attnResults('c', {projectedQueries.sizeAt(0), projectedValues.sizeAt(1), projectedValues.sizeAt(2), projectedQueries.sizeAt(3)}, projectedValues.dataType(), block.launchContext()); - sd::ops::dot_product_attention attention; - attention.execute({&projectedQueries, &projectedKeys, &projectedValues, mask}, {&attnResults}, {}, {normalization, 0}, {}); - - // Project attention results - attnResults.permutei({0, 3, 1, 2}); - attnResults.reshapei(attnResults.ordering(), {miniBatchSize * queryCount, numHeads * projectedValuesSize}); - - // dLdWo - auto epsPerm = eps->permute({0, 2, 1}); - auto epsPostReshape = epsPerm.reshape(eps->ordering(), {miniBatchSize * queryCount, outSize}); - sd::ops::matmul_bp matmulBp; - NDArray dLdPreWo(attnResults.shapeInfo(), false, block.launchContext()); - matmulBp.execute({&attnResults, Wo, &epsPostReshape}, std::vector{&dLdPreWo, dLdWo}, {}, {}, {}); - - // dLdAttn - dLdPreWo.reshapei({miniBatchSize, queryCount, numHeads, projectedValues.sizeAt(2)}); - dLdPreWo.permutei({0, 2, 3, 1}); - - sd::ops::dot_product_attention_bp attentionBp; - NDArray dLdProjectedQueries(projectedQueries.shapeInfo(), false, block.launchContext()); - NDArray dLdProjectedKeys(projectedKeys.shapeInfo(), false, block.launchContext()); - NDArray dLdProjectedValues(projectedValues.shapeInfo(), false, block.launchContext()); - attentionBp.execute({&projectedQueries, &projectedKeys, &projectedValues, &dLdPreWo, mask},{&dLdProjectedQueries, &dLdProjectedKeys, &dLdProjectedValues}, {}, {normalization}, {}); - - AttentionHelper::multiHeadProjectBp(queries, Wq, &dLdProjectedQueries, dLdq, dLdWq, block.launchContext()); - AttentionHelper::multiHeadProjectBp(keys, Wk, &dLdProjectedKeys, dLdk, dLdWk, block.launchContext()); - AttentionHelper::multiHeadProjectBp(values, Wv, &dLdProjectedValues, dLdv, dLdWv, block.launchContext()); - - return Status::OK(); - } - - DECLARE_TYPES(multi_head_dot_product_attention_bp) { - getOpDescriptor()->setAllowedInputTypes({ALL_FLOATS}); - getOpDescriptor()->setAllowedOutputTypes({ALL_FLOATS}); - } - - DECLARE_SHAPE_FN(multi_head_dot_product_attention_bp) { - Nd4jLong *dLdq_shape; - COPY_SHAPE(inputShape->at(0), dLdq_shape); - Nd4jLong *dLdk_shape; - COPY_SHAPE(inputShape->at(1), dLdk_shape); - Nd4jLong *dLdv_shape; - COPY_SHAPE(inputShape->at(2), dLdv_shape); - Nd4jLong *dLdWq_shape; - COPY_SHAPE(inputShape->at(3), dLdWq_shape); - Nd4jLong *dLdWk_shape; - COPY_SHAPE(inputShape->at(4), dLdWk_shape); - Nd4jLong *dLdWv_shape; - COPY_SHAPE(inputShape->at(5), dLdWv_shape); - Nd4jLong *dLdWo_shape; - COPY_SHAPE(inputShape->at(6), dLdWo_shape); - - return SHAPELIST(CONSTANT(dLdq_shape), CONSTANT(dLdk_shape), CONSTANT(dLdv_shape), CONSTANT(dLdWq_shape), CONSTANT(dLdWk_shape), CONSTANT(dLdWv_shape), CONSTANT(dLdWo_shape)); - } +namespace ops { + +CUSTOM_OP_IMPL(multi_head_dot_product_attention, 7, -1, false, 0, 2) { + auto queries = INPUT_VARIABLE(0); //[batch, nIn, timeSteps] + auto keys = INPUT_VARIABLE(1); //[batch, nIn, timeSteps] + auto values = INPUT_VARIABLE(2); //[batch, nIn, timeSteps] + auto Wq = INPUT_VARIABLE(3); //[nHeads, headSize, nIn] + auto Wk = INPUT_VARIABLE(4); //[nHeads, headSize, nIn] + auto Wv = INPUT_VARIABLE(5); //[nHeads, headSize, nIn] + auto Wo = INPUT_VARIABLE(6); //[nHeads * headSize, nOut] + auto mask = block.width() > 7 ? INPUT_VARIABLE(7) : nullptr; + + auto output = OUTPUT_VARIABLE(0); + int normalization = INT_ARG(0); + int weights = INT_ARG(1); + + auto numHeads = Wk->sizeAt(0); + auto miniBatchSize = queries->sizeAt(0); + auto queryCount = queries->sizeAt(2); + auto projectedValuesSize = Wv->sizeAt(1); + auto outSize = Wo->sizeAt(1); + + REQUIRE_TRUE( + queries->rankOf() == keys->rankOf() && keys->rankOf() == values->rankOf(), + 0, + "multi_head_dot_product_attention: Queries, Keys and Values must have " + "same rank. " + "But got queries = %s, keys = %s, values = %s", + ShapeUtils::shapeAsString(queries).c_str(), + ShapeUtils::shapeAsString(keys).c_str(), + ShapeUtils::shapeAsString(values).c_str()); + + REQUIRE_TRUE(queries->rankOf() == 3, 0, + "multi_head_dot_product_attention: Queries, Keys and Values " + "must be rank 3 arrays" + "But got rank = %i", + queries->rankOf()); + + REQUIRE_TRUE(Wq->rankOf() == Wk->rankOf() && Wk->rankOf() == Wv->rankOf(), 0, + "multi_head_dot_product_attention: Input projections weights " + "must have the same rank. " + "But got Wq = %s, Wk = %s, Wv = %s", + ShapeUtils::shapeAsString(Wq).c_str(), + ShapeUtils::shapeAsString(Wk).c_str(), + ShapeUtils::shapeAsString(Wv).c_str()); + + REQUIRE_TRUE(Wq->sizeAt(0) == Wk->sizeAt(0) && Wk->sizeAt(0) == Wv->sizeAt(0), + 0, + "multi_head_dot_product_attention: Projections weights must " + "have the same number of attention heads. " + "But got Wq = %s, Wk = %s, Wv = %s", + ShapeUtils::shapeAsString(Wq).c_str(), + ShapeUtils::shapeAsString(Wk).c_str(), + ShapeUtils::shapeAsString(Wv).c_str()); + + REQUIRE_TRUE(Wo->rankOf() == 2, 0, + "multi_head_dot_product_attention: Output projection weights " + "must have rank 2. " + "But got Wo = %s", + ShapeUtils::shapeAsString(Wo).c_str()); + + REQUIRE_TRUE( + Wq->sizeAt(2) == queries->sizeAt(1), 0, + "multi_head_dot_product_attention: Query projection matrix Wq has " + "incompatible size to queries matrix." + "Expected Wq[2] = queries[1] = %i, but got Wq = %s, queries = %s ", + queries->sizeAt(1), ShapeUtils::shapeAsString(Wq).c_str(), + ShapeUtils::shapeAsString(queries).c_str()); + + REQUIRE_TRUE(Wk->sizeAt(2) == keys->sizeAt(1), 0, + "multi_head_dot_product_attention: Key projection matrix Wk has " + "incompatible size to keys matrix." + "Expected Wk[2] = keys[1] = %i, but got Wk = %s, keys = %s ", + keys->sizeAt(1), ShapeUtils::shapeAsString(Wk).c_str(), + ShapeUtils::shapeAsString(keys).c_str()); + + REQUIRE_TRUE(Wv->sizeAt(2) == values->sizeAt(1), 0, + "multi_head_dot_product_attention: Value projection matrix Wv " + "has incompatible size to values matrix." + "Expected Wv[2] = values[1] = %i, but got Wv = %s, values = %s ", + values->sizeAt(1), ShapeUtils::shapeAsString(Wv).c_str(), + ShapeUtils::shapeAsString(values).c_str()); + + REQUIRE_TRUE(Wo->sizeAt(0) == (Wv->sizeAt(1) * Wv->sizeAt(0)), 0, + "multi_head_dot_product_attention: Output projection matrix Wo " + "has incompatible size to attention result." + "Expected Wo[0] = Wv[0] * Wv[1] = %i, but got Wo = %s, Wv = %", + (Wv->sizeAt(1) * Wv->sizeAt(0)), + ShapeUtils::shapeAsString(Wo).c_str(), + ShapeUtils::shapeAsString(Wv).c_str()); + + // Project queries, keys, values + auto projectedQueries = AttentionHelper::multiHeadProject( + queries, Wq, + block.launchContext()); //[minibatch, numHeads, projectedSize, seqLength] + auto projectedKeys = AttentionHelper::multiHeadProject( + keys, Wk, + block.launchContext()); //[minibatch, numHeads, projectedSize, seqLength] + auto projectedValues = AttentionHelper::multiHeadProject( + values, Wv, + block.launchContext()); //[minibatch, numHeads, projectedSize, seqLength] + + // Apply Attention + // attnResults = [minibatch, numHeads, projectedSize, seqLenth + NDArray attnResults('c', + {projectedQueries.sizeAt(0), projectedValues.sizeAt(1), + projectedValues.sizeAt(2), projectedQueries.sizeAt(3)}, + projectedValues.dataType(), block.launchContext()); + sd::ops::dot_product_attention attention; + attention.execute({&projectedQueries, &projectedKeys, &projectedValues, mask}, + {&attnResults, weights ? OUTPUT_VARIABLE(1) : nullptr}, {}, + {normalization, weights}, {}); + + // Project attention results + attnResults.permutei({0, 3, 1, 2}); + attnResults.reshapei( + attnResults.ordering(), + {miniBatchSize * queryCount, numHeads * projectedValuesSize}); + + sd::ops::matmul mmul; + NDArray projRes('c', {attnResults.sizeAt(0), Wo->sizeAt(1)}, + values->dataType(), block.launchContext()); + mmul.execute({&attnResults, Wo}, {&projRes}, {}, {}, {}); + projRes.reshapei(projRes.ordering(), {miniBatchSize, queryCount, outSize}); + projRes.permutei({0, 2, 1}); + + // FIXME: bad for performance + output->assign(projRes); + + return Status::OK(); +} +DECLARE_TYPES(multi_head_dot_product_attention) { + getOpDescriptor()->setAllowedInputTypes({ALL_FLOATS}); + getOpDescriptor()->setAllowedOutputTypes({ALL_FLOATS}); } + +DECLARE_SHAPE_FN(multi_head_dot_product_attention) { + auto queryShape = inputShape->at(0); + auto keysShape = inputShape->at(1); + auto valuesShape = inputShape->at(2); + auto WkShape = inputShape->at(3); + auto WoShape = inputShape->at(6); + + auto batchSize = shape::sizeAt(queryShape, 0); + auto outSize = shape::sizeAt(WoShape, 1); + auto queryCount = shape::sizeAt(queryShape, 2); + auto numHeads = shape::sizeAt(WkShape, 0); + auto timeSteps = shape::sizeAt(keysShape, 2); + + auto weightsShape = ConstantShapeHelper::getInstance().createShapeInfo( + sd::ArrayOptions::dataType(valuesShape), 'c', + {batchSize, numHeads, timeSteps, queryCount}); + auto outputShape = ConstantShapeHelper::getInstance().createShapeInfo( + sd::ArrayOptions::dataType(valuesShape), 'c', + {batchSize, outSize, queryCount}); + + if (INT_ARG(1)) { + return SHAPELIST(outputShape, weightsShape); + } else { + return SHAPELIST(outputShape); + } +} + +CUSTOM_OP_IMPL(multi_head_dot_product_attention_bp, 8, 7, false, 0, 1) { + auto queries = INPUT_VARIABLE(0); + auto keys = INPUT_VARIABLE(1); + auto values = INPUT_VARIABLE(2); + auto Wq = INPUT_VARIABLE(3); + auto Wk = INPUT_VARIABLE(4); + auto Wv = INPUT_VARIABLE(5); + auto Wo = INPUT_VARIABLE(6); + auto eps = INPUT_VARIABLE(7); + auto mask = block.width() > 8 ? INPUT_VARIABLE(8) : nullptr; + + auto dLdq = OUTPUT_VARIABLE(0); + auto dLdk = OUTPUT_VARIABLE(1); + auto dLdv = OUTPUT_VARIABLE(2); + auto dLdWq = OUTPUT_VARIABLE(3); + auto dLdWk = OUTPUT_VARIABLE(4); + auto dLdWv = OUTPUT_VARIABLE(5); + auto dLdWo = OUTPUT_VARIABLE(6); + + int normalization = INT_ARG(0); + + auto numHeads = Wk->sizeAt(0); + auto miniBatchSize = queries->sizeAt(0); + auto queryCount = queries->sizeAt(2); + auto outSize = Wo->sizeAt(1); + auto projectedValuesSize = Wv->sizeAt(1); + + REQUIRE_TRUE( + queries->rankOf() == keys->rankOf() && keys->rankOf() == values->rankOf(), + 0, + "multi_head_dot_product_attention: Queries, Keys and Values must have " + "same rank. " + "But got queries = %s, keys = %s, values = %s", + ShapeUtils::shapeAsString(queries).c_str(), + ShapeUtils::shapeAsString(keys).c_str(), + ShapeUtils::shapeAsString(values).c_str()); + + REQUIRE_TRUE(queries->rankOf() == 3, 0, + "multi_head_dot_product_attention: Queries, Keys and Values " + "must be rank 3 arrays" + "But got rank = %i", + queries->rankOf()); + + REQUIRE_TRUE(Wq->rankOf() == Wk->rankOf() && Wk->rankOf() == Wv->rankOf(), 0, + "multi_head_dot_product_attention: Input projections weights " + "must have the same rank. " + "But got Wq = %s, Wk = %s, Wv = %s", + ShapeUtils::shapeAsString(Wq).c_str(), + ShapeUtils::shapeAsString(Wk).c_str(), + ShapeUtils::shapeAsString(Wv).c_str()); + + REQUIRE_TRUE(Wq->sizeAt(0) == Wk->sizeAt(0) && Wk->sizeAt(0) == Wv->sizeAt(0), + 0, + "multi_head_dot_product_attention: Projections weights must " + "have the same number of attention heads. " + "But got Wq = %s, Wk = %s, Wv = %s", + ShapeUtils::shapeAsString(Wq).c_str(), + ShapeUtils::shapeAsString(Wk).c_str(), + ShapeUtils::shapeAsString(Wv).c_str()); + + REQUIRE_TRUE(Wo->rankOf() == 2, 0, + "multi_head_dot_product_attention: Output projection weights " + "must have rank 2. " + "But got Wo = %s", + ShapeUtils::shapeAsString(Wo).c_str()); + + REQUIRE_TRUE( + Wq->sizeAt(2) == queries->sizeAt(1), 0, + "multi_head_dot_product_attention: Query projection matrix Wq has " + "incompatible size to queries matrix." + "Expected Wq[2] = queries[1] = %i, but got Wq = %s, queries = %s ", + queries->sizeAt(1), ShapeUtils::shapeAsString(Wq).c_str(), + ShapeUtils::shapeAsString(queries).c_str()); + + REQUIRE_TRUE(Wk->sizeAt(2) == keys->sizeAt(1), 0, + "multi_head_dot_product_attention: Key projection matrix Wk has " + "incompatible size to keys matrix." + "Expected Wk[2] = keys[1] = %i, but got Wk = %s, keys = %s ", + keys->sizeAt(1), ShapeUtils::shapeAsString(Wk).c_str(), + ShapeUtils::shapeAsString(keys).c_str()); + + REQUIRE_TRUE(Wv->sizeAt(2) == values->sizeAt(1), 0, + "multi_head_dot_product_attention: Value projection matrix Wv " + "has incompatible size to values matrix." + "Expected Wv[2] = values[1] = %i, but got Wv = %s, values = %s ", + values->sizeAt(1), ShapeUtils::shapeAsString(Wv).c_str(), + ShapeUtils::shapeAsString(values).c_str()); + + REQUIRE_TRUE(Wo->sizeAt(0) == (Wv->sizeAt(1) * Wv->sizeAt(0)), 0, + "multi_head_dot_product_attention: Output projection matrix Wo " + "has incompatible size to attention result." + "Expected Wo[0] = Wv[0] * Wv[1] = %i, but got Wo = %s, Wv = %", + (Wv->sizeAt(1) * Wv->sizeAt(0)), + ShapeUtils::shapeAsString(Wo).c_str(), + ShapeUtils::shapeAsString(Wv).c_str()); + + // Project queries, keys, values + auto projectedQueries = + AttentionHelper::multiHeadProject(queries, Wq, block.launchContext()); + auto projectedKeys = + AttentionHelper::multiHeadProject(keys, Wk, block.launchContext()); + auto projectedValues = + AttentionHelper::multiHeadProject(values, Wv, block.launchContext()); + + // Apply Attention + NDArray attnResults('c', + {projectedQueries.sizeAt(0), projectedValues.sizeAt(1), + projectedValues.sizeAt(2), projectedQueries.sizeAt(3)}, + projectedValues.dataType(), block.launchContext()); + sd::ops::dot_product_attention attention; + attention.execute({&projectedQueries, &projectedKeys, &projectedValues, mask}, + {&attnResults}, {}, {normalization, 0}, {}); + + // Project attention results + attnResults.permutei({0, 3, 1, 2}); + attnResults.reshapei( + attnResults.ordering(), + {miniBatchSize * queryCount, numHeads * projectedValuesSize}); + + // dLdWo + auto epsPerm = eps->permute({0, 2, 1}); + auto epsPostReshape = + epsPerm.reshape(eps->ordering(), {miniBatchSize * queryCount, outSize}); + sd::ops::matmul_bp matmulBp; + NDArray dLdPreWo(attnResults.shapeInfo(), false, block.launchContext()); + matmulBp.execute({&attnResults, Wo, &epsPostReshape}, + std::vector{&dLdPreWo, dLdWo}, {}, {}, {}); + + // dLdAttn + dLdPreWo.reshapei( + {miniBatchSize, queryCount, numHeads, projectedValues.sizeAt(2)}); + dLdPreWo.permutei({0, 2, 3, 1}); + + sd::ops::dot_product_attention_bp attentionBp; + NDArray dLdProjectedQueries(projectedQueries.shapeInfo(), false, + block.launchContext()); + NDArray dLdProjectedKeys(projectedKeys.shapeInfo(), false, + block.launchContext()); + NDArray dLdProjectedValues(projectedValues.shapeInfo(), false, + block.launchContext()); + attentionBp.execute( + {&projectedQueries, &projectedKeys, &projectedValues, &dLdPreWo, mask}, + {&dLdProjectedQueries, &dLdProjectedKeys, &dLdProjectedValues}, {}, + {normalization}, {}); + + AttentionHelper::multiHeadProjectBp(queries, Wq, &dLdProjectedQueries, dLdq, + dLdWq, block.launchContext()); + AttentionHelper::multiHeadProjectBp(keys, Wk, &dLdProjectedKeys, dLdk, dLdWk, + block.launchContext()); + AttentionHelper::multiHeadProjectBp(values, Wv, &dLdProjectedValues, dLdv, + dLdWv, block.launchContext()); + + return Status::OK(); } +DECLARE_TYPES(multi_head_dot_product_attention_bp) { + getOpDescriptor()->setAllowedInputTypes({ALL_FLOATS}); + getOpDescriptor()->setAllowedOutputTypes({ALL_FLOATS}); +} + +DECLARE_SHAPE_FN(multi_head_dot_product_attention_bp) { + Nd4jLong *dLdq_shape; + COPY_SHAPE(inputShape->at(0), dLdq_shape); + Nd4jLong *dLdk_shape; + COPY_SHAPE(inputShape->at(1), dLdk_shape); + Nd4jLong *dLdv_shape; + COPY_SHAPE(inputShape->at(2), dLdv_shape); + Nd4jLong *dLdWq_shape; + COPY_SHAPE(inputShape->at(3), dLdWq_shape); + Nd4jLong *dLdWk_shape; + COPY_SHAPE(inputShape->at(4), dLdWk_shape); + Nd4jLong *dLdWv_shape; + COPY_SHAPE(inputShape->at(5), dLdWv_shape); + Nd4jLong *dLdWo_shape; + COPY_SHAPE(inputShape->at(6), dLdWo_shape); + + return SHAPELIST(CONSTANT(dLdq_shape), CONSTANT(dLdk_shape), + CONSTANT(dLdv_shape), CONSTANT(dLdWq_shape), + CONSTANT(dLdWk_shape), CONSTANT(dLdWv_shape), + CONSTANT(dLdWo_shape)); +} + +} // namespace ops +} // namespace sd + #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/nn/pooling/avgpool2d.cpp b/libnd4j/include/ops/declarable/generic/nn/pooling/avgpool2d.cpp index fde07566795c..5fe628777414 100644 --- a/libnd4j/include/ops/declarable/generic/nn/pooling/avgpool2d.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/pooling/avgpool2d.cpp @@ -26,195 +26,256 @@ #include namespace sd { -namespace ops { +namespace ops { CUSTOM_OP_IMPL(avgpool2d, 1, 1, false, 0, 10) { - - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_NULLIFIED(0); - - // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; - - const auto kH = INT_ARG(0); - const auto kW = INT_ARG(1); - const auto sH = INT_ARG(2); - const auto sW = INT_ARG(3); - auto pH = INT_ARG(4); - auto pW = INT_ARG(5); - const auto dH = INT_ARG(6); - const auto dW = INT_ARG(7); - const auto isSameMode = static_cast(INT_ARG(8)); - const auto extraParam0 = INT_ARG(9); - const int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC - - REQUIRE_TRUE(input->rankOf() == 4, 0, "AVGPOOL2D op: input should have rank of 4, but got %i instead", input->rankOf()); - REQUIRE_TRUE(dH != 0 && dW != 0, 0, "AVGPOOL2D op: dilation must not be zero, but got instead {%i, %i}", dH, dW); - - int oH = 0; - int oW = 0; - - const int iH = static_cast(isNCHW ? input->sizeAt(2) : input->sizeAt(1)); - const int iW = static_cast(isNCHW ? input->sizeAt(3) : input->sizeAt(2)); - - if(!isNCHW) { - input = new NDArray(input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] - output = new NDArray(output->permute({0, 3, 1, 2})); // [bS, oH, oW, iC] -> [bS, iC, oH, oW] - } - - ConvolutionUtils::calcOutSizePool2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode); - - if (isSameMode) - ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); - - // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - poolingMode; 9 - divisor; - ConvolutionUtils::pooling2d(block, *input, *output, kH, kW, sH, sW, pH, pW, dH, dW, PoolingType::AVG_POOL, extraParam0); - - if(!isNCHW) { - delete input; - delete output; - } - - return Status::OK(); + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_NULLIFIED(0); + + // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad + // Height/Width; 6,7 - dilation Height/Width; 8 - same mode; + + const auto kH = INT_ARG(0); + const auto kW = INT_ARG(1); + const auto sH = INT_ARG(2); + const auto sW = INT_ARG(3); + auto pH = INT_ARG(4); + auto pW = INT_ARG(5); + const auto dH = INT_ARG(6); + const auto dW = INT_ARG(7); + const auto isSameMode = static_cast(INT_ARG(8)); + const auto extraParam0 = INT_ARG(9); + const int isNCHW = + block.numI() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC + + REQUIRE_TRUE(input->rankOf() == 4, 0, + "AVGPOOL2D op: input should have rank of 4, but got %i instead", + input->rankOf()); + REQUIRE_TRUE( + dH != 0 && dW != 0, 0, + "AVGPOOL2D op: dilation must not be zero, but got instead {%i, %i}", dH, + dW); + + int oH = 0; + int oW = 0; + + const int iH = static_cast(isNCHW ? input->sizeAt(2) : input->sizeAt(1)); + const int iW = static_cast(isNCHW ? input->sizeAt(3) : input->sizeAt(2)); + + if (!isNCHW) { + input = new NDArray( + input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] + output = new NDArray( + output->permute({0, 3, 1, 2})); // [bS, oH, oW, iC] -> [bS, iC, oH, oW] + } + + ConvolutionUtils::calcOutSizePool2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, + iH, iW, isSameMode); + + if (isSameMode) + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, + dW); + + // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad + // Height/Width; 6,7 - dilation Height/Width; 8 - poolingMode; 9 - divisor; + ConvolutionUtils::pooling2d(block, *input, *output, kH, kW, sH, sW, pH, pW, + dH, dW, PoolingType::AVG_POOL, extraParam0); + + if (!isNCHW) { + delete input; + delete output; + } + + return Status::OK(); } DECLARE_SYN(AvgPool2D, avgpool2d); DECLARE_SYN(AvgPool, avgpool2d); DECLARE_SYN(avgpool, avgpool2d); - DECLARE_TYPES(avgpool2d) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } +DECLARE_TYPES(avgpool2d) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} DECLARE_SHAPE_FN(avgpool2d) { - - auto inShape = inputShape->at(0); - auto shapeOf = shape::shapeOf(inShape); - - // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; - auto argI = *(block.getIArguments()); - const int kH = INT_ARG(0); - const int kW = INT_ARG(1); - const int sH = INT_ARG(2); - const int sW = INT_ARG(3); - const int pH = INT_ARG(4); - const int pW = INT_ARG(5); - const int dH = INT_ARG(6); - const int dW = INT_ARG(7); - const int isSameMode = INT_ARG(8); - - const int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC - - REQUIRE_TRUE(dH != 0 && dW != 0, 0, "AVGPOOL2D op: dilation must not be zero, but got instead {%i, %i}", dH, dW); - - const int bS = shapeOf[0]; - const int iD = isNCHW ? shapeOf[1] : shapeOf[3]; - const int iH = isNCHW ? shapeOf[2] : shapeOf[1]; - const int iW = isNCHW ? shapeOf[3] : shapeOf[2]; - - const char order = shape::order(inShape); // output order must be equal to input order - - // calculate output Height/Width - int oH, oW; - ConvolutionUtils::calcOutSizePool2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode); - - // allocate memory for new shape - Nd4jLong newShape[4]; - if (isNCHW) { - newShape[0] = bS; - newShape[1] = iD; - newShape[2] = oH; - newShape[3] = oW; - } else { - newShape[0] = bS; - newShape[1] = oH; - newShape[2] = oW; - newShape[3] = iD; - } - - return SHAPELIST(ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(ArrayOptions::dataType(inShape), shape::order(inShape), newShape, 4))); + auto inShape = inputShape->at(0); + auto shapeOf = shape::shapeOf(inShape); + + // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad + // Height/Width; 6,7 - dilation Height/Width; 8 - same mode; + auto argI = block.getIArguments(); + const int kH = INT_ARG(0); + const int kW = INT_ARG(1); + const int sH = INT_ARG(2); + const int sW = INT_ARG(3); + const int pH = INT_ARG(4); + const int pW = INT_ARG(5); + const int dH = INT_ARG(6); + const int dW = INT_ARG(7); + const int isSameMode = INT_ARG(8); + + const int isNCHW = + block.numI() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC + + REQUIRE_TRUE( + dH != 0 && dW != 0, 0, + "AVGPOOL2D op: dilation must not be zero, but got instead {%i, %i}", dH, + dW); + + const int bS = shapeOf[0]; + const int iD = isNCHW ? shapeOf[1] : shapeOf[3]; + const int iH = isNCHW ? shapeOf[2] : shapeOf[1]; + const int iW = isNCHW ? shapeOf[3] : shapeOf[2]; + + const char order = + shape::order(inShape); // output order must be equal to input order + + // calculate output Height/Width + int oH, oW; + ConvolutionUtils::calcOutSizePool2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, + iH, iW, isSameMode); + + // allocate memory for new shape + Nd4jLong newShape[4]; + if (isNCHW) { + newShape[0] = bS; + newShape[1] = iD; + newShape[2] = oH; + newShape[3] = oW; + } else { + newShape[0] = bS; + newShape[1] = oH; + newShape[2] = oW; + newShape[3] = iD; + } + + return SHAPELIST(ConstantShapeHelper::getInstance().createShapeInfo( + ShapeDescriptor(ArrayOptions::dataType(inShape), shape::order(inShape), + newShape, 4))); } - DECLARE_TYPES(avgpool2d_bp) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } +DECLARE_TYPES(avgpool2d_bp) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(avgpool2d_bp, 2, 1, false, 0, 10) { - - auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - auto gradO = INPUT_VARIABLE(1); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next - auto gradI = OUTPUT_NULLIFIED(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon - - int kH = INT_ARG(0); // filter(kernel) height - int kW = INT_ARG(1); // filter(kernel) width - int sH = INT_ARG(2); // strides height - int sW = INT_ARG(3); // strides width - int pH = INT_ARG(4); // paddings height - int pW = INT_ARG(5); // paddings width - int dH = INT_ARG(6); // dilations height - int dW = INT_ARG(7); // dilations width - int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME - int extraParam0 = INT_ARG(9); - int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC - - REQUIRE_TRUE(input->rankOf() == 4, 0, "AVGPOOL2D_BP op: input should have rank of 4, but got %i instead", input->rankOf()); - REQUIRE_TRUE(dH != 0 && dW != 0, 0, "AVGPOOL2D_BP op: dilation must not be zero, but got instead {%i, %i}", dH, dW); - - int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; - int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, 0, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); - - std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oH,oW, 0,indIOioC,indIiH,indIiH+1}); - std::vector expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iH,iW, 0,indIOioC,indIiH,indIiH+1}); - REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "AVGPOOL2D_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); - REQUIRE_TRUE(gradI->isSameShape(expectedGradIShape), 0, "AVGPOOL2D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradIShape).c_str(), ShapeUtils::shapeAsString(gradI).c_str()); - - if(!isNCHW) { - input = new NDArray(input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] - gradI = new NDArray(gradI->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] - gradO = new NDArray(gradO->permute({0, 3, 1, 2})); // [bS, oH, oW, iC] -> [bS, iC, oH, oW] - } - - if(isSameMode) // SAME - ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); - - // NDArray columnsWrongShape(input->ordering(), {bS, iC, oH, oW, kH, kW}, input->getWorkspace()); - // NDArray* columns = columnsWrongShape.permute({0, 1, 4, 5, 2, 3}); // [bS, iC, oH, oW, kH, kW] -> [bS, iC, kH, kW, oH, oW] - // NDArray* gradOVector = gradO->reshape('c', {(int) gradO->lengthOf(), 1}); - // NDArray* columns2d = columnsWrongShape.reshape('c', {bS*iC*oH*oW, kH*kW}); - - // columns2d->addiColumnVector(gradOVector); - - // columns->template applyTransform>(gradI, std::vector({(T)sH, (T)sW, (T)pH, (T)pW, (T)iH, (T)iW, (T)dH, (T)dW}).data()); - - // *gradI /= kH*kW; - - // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - poolingMode; 9 - divisor; - ConvolutionUtils::pooling2dBP(block, *input, *gradO, *gradI, kH, kW, sH, sW, pH, pW, dH, dW, 1, extraParam0); - - if(!isNCHW) { - delete input; - delete gradI; - delete gradO; - } - - return Status::OK(); + auto input = + INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + auto gradO = INPUT_VARIABLE( + 1); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next + auto gradI = OUTPUT_NULLIFIED( + 0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon + + int kH = INT_ARG(0); // filter(kernel) height + int kW = INT_ARG(1); // filter(kernel) width + int sH = INT_ARG(2); // strides height + int sW = INT_ARG(3); // strides width + int pH = INT_ARG(4); // paddings height + int pW = INT_ARG(5); // paddings width + int dH = INT_ARG(6); // dilations height + int dW = INT_ARG(7); // dilations width + int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME + int extraParam0 = INT_ARG(9); + int isNCHW = + block.numI() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC + + REQUIRE_TRUE( + input->rankOf() == 4, 0, + "AVGPOOL2D_BP op: input should have rank of 4, but got %i instead", + input->rankOf()); + REQUIRE_TRUE( + dH != 0 && dW != 0, 0, + "AVGPOOL2D_BP op: dilation must not be zero, but got instead {%i, %i}", + dH, dW); + + int bS, iC, iH, iW, oC, oH, + oW; // batch size, input channels, input height/width, output channels, + // output height/width; + int indIOioC, indIiH, indWoC, indWiC, indWkH, + indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, 0, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, + indWiC, indWoC, indWkH, indOoH); + + std::vector expectedGradOShape = + ShapeUtils::composeShapeUsingDimsAndIdx( + {bS, iC, oH, oW, 0, indIOioC, indIiH, indIiH + 1}); + std::vector expectedGradIShape = + ShapeUtils::composeShapeUsingDimsAndIdx( + {bS, iC, iH, iW, 0, indIOioC, indIiH, indIiH + 1}); + REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, + "AVGPOOL2D_BP op: wrong shape of output's gradients array (next " + "epsilon), expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedGradOShape).c_str(), + ShapeUtils::shapeAsString(gradO).c_str()); + REQUIRE_TRUE(gradI->isSameShape(expectedGradIShape), 0, + "AVGPOOL2D_BP op: wrong shape of input's gradients array " + "(epsilon), expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedGradIShape).c_str(), + ShapeUtils::shapeAsString(gradI).c_str()); + + if (!isNCHW) { + input = new NDArray( + input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] + gradI = new NDArray( + gradI->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] + gradO = new NDArray( + gradO->permute({0, 3, 1, 2})); // [bS, oH, oW, iC] -> [bS, iC, oH, oW] + } + + if (isSameMode) // SAME + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, + dW); + + // NDArray columnsWrongShape(input->ordering(), {bS, iC, oH, oW, kH, kW}, + // input->getWorkspace()); NDArray* columns = columnsWrongShape.permute({0, + // 1, 4, 5, 2, 3}); // [bS, iC, oH, oW, kH, kW] + // -> [bS, iC, kH, kW, oH, oW] NDArray* gradOVector = gradO->reshape('c', + // {(int) gradO->lengthOf(), 1}); NDArray* columns2d = + // columnsWrongShape.reshape('c', {bS*iC*oH*oW, kH*kW}); + + // columns2d->addiColumnVector(gradOVector); + + // columns->template applyTransform>(gradI, + // std::vector({(T)sH, (T)sW, (T)pH, (T)pW, (T)iH, (T)iW, (T)dH, + // (T)dW}).data()); + + // *gradI /= kH*kW; + + // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad + // Height/Width; 6,7 - dilation Height/Width; 8 - poolingMode; 9 - divisor; + ConvolutionUtils::pooling2dBP(block, *input, *gradO, *gradI, kH, kW, sH, sW, + pH, pW, dH, dW, 1, extraParam0); + + if (!isNCHW) { + delete input; + delete gradI; + delete gradO; + } + + return Status::OK(); } DECLARE_SHAPE_FN(avgpool2d_bp) { - - REQUIRE_TRUE(inputShape->at(0)[0] == 4, 0, "AVGPOOL2D_BP op: input array must be 4D, but got %i instead!", inputShape->at(0)[0]); - REQUIRE_TRUE(inputShape->at(1)[0] == 4, 0, "AVGPOOL2D_BP op: output's gradient array (next epsilon) must be 4D, but got %i instead!", inputShape->at(1)[0]); - - return SHAPELIST(ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(inputShape->at(0), ArrayOptions::dataType(inputShape->at(1))))); + REQUIRE_TRUE(inputShape->at(0)[0] == 4, 0, + "AVGPOOL2D_BP op: input array must be 4D, but got %i instead!", + inputShape->at(0)[0]); + REQUIRE_TRUE(inputShape->at(1)[0] == 4, 0, + "AVGPOOL2D_BP op: output's gradient array (next epsilon) must " + "be 4D, but got %i instead!", + inputShape->at(1)[0]); + + return SHAPELIST( + ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor( + inputShape->at(0), ArrayOptions::dataType(inputShape->at(1))))); } - -} -} +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/generic/nn/pooling/avgpool3d.cpp b/libnd4j/include/ops/declarable/generic/nn/pooling/avgpool3d.cpp index d8df113852cf..053db8063b56 100644 --- a/libnd4j/include/ops/declarable/generic/nn/pooling/avgpool3d.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/pooling/avgpool3d.cpp @@ -25,189 +25,253 @@ #include namespace sd { -namespace ops { - +namespace ops { ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(avgpool3dnew, 1, 1, false, 0, 14) { - - auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) - auto output = OUTPUT_NULLIFIED(0); // [bS, oD, oH, oW, iC] (NDHWC) or [bS, iC, oD, oH, oW] (NCDHW) - - int kD = INT_ARG(0); // filter(kernel) depth - int kH = INT_ARG(1); // filter(kernel) height - int kW = INT_ARG(2); // filter(kernel) width - int sD = INT_ARG(3); // strides depth - int sH = INT_ARG(4); // strides height - int sW = INT_ARG(5); // strides width - int pD = INT_ARG(6); // paddings depth - int pH = INT_ARG(7); // paddings height - int pW = INT_ARG(8); // paddings width - int dD = INT_ARG(9); // dilations depth - int dH = INT_ARG(10); // dilations height - int dW = INT_ARG(11); // dilations width - int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID - int extraParam0 = INT_ARG(13); - int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1; // 0-NCDHW, 1-NDHWC - - REQUIRE_TRUE(input->rankOf() == 5, 0, "AVGPOOL3DNEW OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf()); - REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "AVGPOOL3DNEW OP: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW); - - int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; - int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, 0, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); - - std::vector expectedOutputShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); - REQUIRE_TRUE(output->isSameShape(expectedOutputShape), 0, "AVGPOOL3DNEW OP: wrong shape of output array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedOutputShape).c_str(), ShapeUtils::shapeAsString(output).c_str()); - - if(!isNCDHW) { - input = new NDArray(input->permute({0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW] - output = new NDArray(output->permute({0, 4, 1, 2, 3})); // [bS, oD, oH, oW, iC] -> [bS, iC, oD, oH, oW] - } - - if(isSameMode) // SAME - ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW); - - //T extraParams[] = {}; - ConvolutionUtils::pooling3d(block, *input, *output, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, 1, extraParam0); - - if(!isNCDHW) { - delete input; - delete output; - } - - return Status::OK(); + auto input = INPUT_VARIABLE( + 0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) + auto output = OUTPUT_NULLIFIED( + 0); // [bS, oD, oH, oW, iC] (NDHWC) or [bS, iC, oD, oH, oW] (NCDHW) + + int kD = INT_ARG(0); // filter(kernel) depth + int kH = INT_ARG(1); // filter(kernel) height + int kW = INT_ARG(2); // filter(kernel) width + int sD = INT_ARG(3); // strides depth + int sH = INT_ARG(4); // strides height + int sW = INT_ARG(5); // strides width + int pD = INT_ARG(6); // paddings depth + int pH = INT_ARG(7); // paddings height + int pW = INT_ARG(8); // paddings width + int dD = INT_ARG(9); // dilations depth + int dH = INT_ARG(10); // dilations height + int dW = INT_ARG(11); // dilations width + int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID + int extraParam0 = INT_ARG(13); + int isNCDHW = block.numI() > 14 ? !INT_ARG(14) : 1; // 0-NCDHW, 1-NDHWC + + REQUIRE_TRUE(input->rankOf() == 5, 0, + "AVGPOOL3DNEW OP: rank of input array must be equal to 5, but " + "got %i instead !", + input->rankOf()); + REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, + "AVGPOOL3DNEW OP: dilation must not be zero, but got instead " + "{%i, %i, %i}", + dD, dH, dW); + + int bS, iC, iD, iH, iW, oC, oD, oH, + oW; // batch size, input channels, input depth/height/width, output + // channels, output depth/height/width; + int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv3d( + isNCDHW, 0, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, + indIOioD, indWiC, indWoC, indWkD); + + std::vector expectedOutputShape = + ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, oD, oH, oW, 0, indIOioC, + indIOioD, indIOioD + 1, + indIOioD + 2}); + REQUIRE_TRUE(output->isSameShape(expectedOutputShape), 0, + "AVGPOOL3DNEW OP: wrong shape of output array, expected is %s, " + "but got %s instead !", + ShapeUtils::shapeAsString(expectedOutputShape).c_str(), + ShapeUtils::shapeAsString(output).c_str()); + + if (!isNCDHW) { + input = new NDArray(input->permute( + {0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW] + output = new NDArray(output->permute( + {0, 4, 1, 2, 3})); // [bS, oD, oH, oW, iC] -> [bS, iC, oD, oH, oW] + } + + if (isSameMode) // SAME + ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, + kW, sD, sH, sW, dD, dH, dW); + + // T extraParams[] = {}; + ConvolutionUtils::pooling3d(block, *input, *output, kD, kH, kW, sD, sH, sW, + pD, pH, pW, dD, dH, dW, 1, extraParam0); + + if (!isNCDHW) { + delete input; + delete output; + } + + return Status::OK(); } - DECLARE_TYPES(avgpool3dnew) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } +DECLARE_TYPES(avgpool3dnew) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} DECLARE_SHAPE_FN(avgpool3dnew) { - - int kD = INT_ARG(0); // filter(kernel) depth - int kH = INT_ARG(1); // filter(kernel) height - int kW = INT_ARG(2); // filter(kernel) width - int sD = INT_ARG(3); // strides depth - int sH = INT_ARG(4); // strides height - int sW = INT_ARG(5); // strides width - int pD = INT_ARG(6); // paddings depth - int pH = INT_ARG(7); // paddings height - int pW = INT_ARG(8); // paddings width - int dD = INT_ARG(9); // dilations depth - int dH = INT_ARG(10); // dilations height - int dW = INT_ARG(11); // dilations width - int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID - int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1; // 0-NCDHW, 1-NDHWC - - REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "AVGPOOL3DNEW op: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW); - - auto inputShapeInfo = inputShape->at(0); - - int idxID, idxIC; - if(isNCDHW) { idxID = 2; idxIC = 1;} - else { idxID = 1; idxIC = 4;} - - int bS = inputShapeInfo[1]; // batch size - int iC = inputShapeInfo[idxIC+1]; // input channels - int iD = inputShapeInfo[idxID+1]; // input depth - int iH = inputShapeInfo[idxID+2]; // input height - int iW = inputShapeInfo[idxID+3]; // input width - - int oD, oH, oW; // output depth, height, width - ConvolutionUtils::calcOutSizePool3D(oD, oH, oW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, isSameMode); - - Nd4jLong outputShape[5]; - - outputShape[0] = bS; - - if (isNCDHW) { - outputShape[1] = iC; - outputShape[2] = oD; - outputShape[3] = oH; - outputShape[4] = oW; - } else { - outputShape[1] = oD; - outputShape[2] = oH; - outputShape[3] = oW; - outputShape[4] = iC; - } - // TF DOC: A Tensor. Has the same type as input. - return SHAPELIST(ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(ArrayOptions::dataType(inputShapeInfo), shape::order(inputShapeInfo), outputShape, 5))); + int kD = INT_ARG(0); // filter(kernel) depth + int kH = INT_ARG(1); // filter(kernel) height + int kW = INT_ARG(2); // filter(kernel) width + int sD = INT_ARG(3); // strides depth + int sH = INT_ARG(4); // strides height + int sW = INT_ARG(5); // strides width + int pD = INT_ARG(6); // paddings depth + int pH = INT_ARG(7); // paddings height + int pW = INT_ARG(8); // paddings width + int dD = INT_ARG(9); // dilations depth + int dH = INT_ARG(10); // dilations height + int dW = INT_ARG(11); // dilations width + int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID + int isNCDHW = block.numI() > 14 ? !INT_ARG(14) : 1; // 0-NCDHW, 1-NDHWC + + REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, + "AVGPOOL3DNEW op: dilation must not be zero, but got instead " + "{%i, %i, %i}", + dD, dH, dW); + + auto inputShapeInfo = inputShape->at(0); + + int idxID, idxIC; + if (isNCDHW) { + idxID = 2; + idxIC = 1; + } else { + idxID = 1; + idxIC = 4; + } + + int bS = inputShapeInfo[1]; // batch size + int iC = inputShapeInfo[idxIC + 1]; // input channels + int iD = inputShapeInfo[idxID + 1]; // input depth + int iH = inputShapeInfo[idxID + 2]; // input height + int iW = inputShapeInfo[idxID + 3]; // input width + + int oD, oH, oW; // output depth, height, width + ConvolutionUtils::calcOutSizePool3D(oD, oH, oW, kD, kH, kW, sD, sH, sW, pD, + pH, pW, dD, dH, dW, iD, iH, iW, + isSameMode); + + Nd4jLong outputShape[5]; + + outputShape[0] = bS; + + if (isNCDHW) { + outputShape[1] = iC; + outputShape[2] = oD; + outputShape[3] = oH; + outputShape[4] = oW; + } else { + outputShape[1] = oD; + outputShape[2] = oH; + outputShape[3] = oW; + outputShape[4] = iC; + } + // TF DOC: A Tensor. Has the same type as input. + return SHAPELIST(ConstantShapeHelper::getInstance().createShapeInfo( + ShapeDescriptor(ArrayOptions::dataType(inputShapeInfo), + shape::order(inputShapeInfo), outputShape, 5))); } - DECLARE_TYPES(avgpool3dnew_bp) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } +DECLARE_TYPES(avgpool3dnew_bp) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(avgpool3dnew_bp, 2, 1, false, 0, 14) { - - auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) - auto gradO = INPUT_VARIABLE(1); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next - auto gradI = OUTPUT_NULLIFIED(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon - - const int kD = INT_ARG(0); // filter(kernel) depth - const int kH = INT_ARG(1); // filter(kernel) height - const int kW = INT_ARG(2); // filter(kernel) width - const int sD = INT_ARG(3); // strides depth - const int sH = INT_ARG(4); // strides height - const int sW = INT_ARG(5); // strides width - int pD = INT_ARG(6); // paddings depth - int pH = INT_ARG(7); // paddings height - int pW = INT_ARG(8); // paddings width - const int dD = INT_ARG(9); // dilations depth - const int dH = INT_ARG(10); // dilations height - const int dW = INT_ARG(11); // dilations width - const int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID - const int extraParam0 = INT_ARG(13); // define what divisor to use while averaging - const int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1; // 0-NCDHW, 1-NDHWC - - REQUIRE_TRUE(input->rankOf() == 5, 0, "AVGPOOL3DNEW_BP op: input should have rank of 5, but got %i instead", input->rankOf()); - REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "AVGPOOL3DNEW_BP op: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW); - - int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; - int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, 0, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); - - std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); - std::vector expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iD,iH,iW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); - REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "AVGPOOL3DNEW_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); - REQUIRE_TRUE(gradI->isSameShape(expectedGradIShape), 0, "AVGPOOL3DNEW_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradIShape).c_str(), ShapeUtils::shapeAsString(gradI).c_str()); - - if(!isNCDHW) { - input = new NDArray(input->permute({0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW] - gradI = new NDArray(gradI->permute({0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW] - gradO = new NDArray(gradO->permute({0, 4, 1, 2, 3})); // [bS, oD, oH, oW, iC] -> [bS, iC, oD, oH, oW] - } - - if(isSameMode) // SAME - ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW); - - // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - poolingMode; 9 - divisor; - ConvolutionUtils::pooling3dBP(block, *input, *gradO, *gradI, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, 1, extraParam0); - - if(!isNCDHW) { - delete input; - delete gradI; - delete gradO; - } - - return Status::OK(); + auto input = INPUT_VARIABLE( + 0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) + auto gradO = INPUT_VARIABLE(1); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, + // oD, oH, oW] (NCDHW), epsilon_next + auto gradI = OUTPUT_NULLIFIED(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, + // iD, iH, iW] (NCDHW), epsilon + + const int kD = INT_ARG(0); // filter(kernel) depth + const int kH = INT_ARG(1); // filter(kernel) height + const int kW = INT_ARG(2); // filter(kernel) width + const int sD = INT_ARG(3); // strides depth + const int sH = INT_ARG(4); // strides height + const int sW = INT_ARG(5); // strides width + int pD = INT_ARG(6); // paddings depth + int pH = INT_ARG(7); // paddings height + int pW = INT_ARG(8); // paddings width + const int dD = INT_ARG(9); // dilations depth + const int dH = INT_ARG(10); // dilations height + const int dW = INT_ARG(11); // dilations width + const int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID + const int extraParam0 = + INT_ARG(13); // define what divisor to use while averaging + const int isNCDHW = block.numI() > 14 ? !INT_ARG(14) : 1; // 0-NCDHW, 1-NDHWC + + REQUIRE_TRUE( + input->rankOf() == 5, 0, + "AVGPOOL3DNEW_BP op: input should have rank of 5, but got %i instead", + input->rankOf()); + REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, + "AVGPOOL3DNEW_BP op: dilation must not be zero, but got instead " + "{%i, %i, %i}", + dD, dH, dW); + + int bS, iC, iD, iH, iW, oC, oD, oH, + oW; // batch size, input channels, input depth/height/width, output + // channels, output depth/height/width; + int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv3d( + isNCDHW, 0, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, + indIOioD, indWiC, indWoC, indWkD); + + std::vector expectedGradOShape = + ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, oD, oH, oW, 0, indIOioC, + indIOioD, indIOioD + 1, + indIOioD + 2}); + std::vector expectedGradIShape = + ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, iD, iH, iW, 0, indIOioC, + indIOioD, indIOioD + 1, + indIOioD + 2}); + REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, + "AVGPOOL3DNEW_BP op: wrong shape of output's gradients array " + "(next epsilon), expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedGradOShape).c_str(), + ShapeUtils::shapeAsString(gradO).c_str()); + REQUIRE_TRUE(gradI->isSameShape(expectedGradIShape), 0, + "AVGPOOL3DNEW_BP op: wrong shape of input's gradients array " + "(epsilon), expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedGradIShape).c_str(), + ShapeUtils::shapeAsString(gradI).c_str()); + + if (!isNCDHW) { + input = new NDArray(input->permute( + {0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW] + gradI = new NDArray(gradI->permute( + {0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW] + gradO = new NDArray(gradO->permute( + {0, 4, 1, 2, 3})); // [bS, oD, oH, oW, iC] -> [bS, iC, oD, oH, oW] + } + + if (isSameMode) // SAME + ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, + kW, sD, sH, sW, dD, dH, dW); + + // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad + // Height/Width; 6,7 - dilation Height/Width; 8 - poolingMode; 9 - divisor; + ConvolutionUtils::pooling3dBP(block, *input, *gradO, *gradI, kD, kH, kW, sD, + sH, sW, pD, pH, pW, dD, dH, dW, 1, extraParam0); + + if (!isNCDHW) { + delete input; + delete gradI; + delete gradO; + } + + return Status::OK(); } - DECLARE_SHAPE_FN(avgpool3dnew_bp) { - return SHAPELIST(ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(inputShape->at(0), ArrayOptions::dataType(inputShape->at(1))))); + return SHAPELIST( + ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor( + inputShape->at(0), ArrayOptions::dataType(inputShape->at(1))))); } - - -} -} +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/generic/nn/pooling/maxpool2d.cpp b/libnd4j/include/ops/declarable/generic/nn/pooling/maxpool2d.cpp index 8a37b90b0da6..cb34b667a73d 100644 --- a/libnd4j/include/ops/declarable/generic/nn/pooling/maxpool2d.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/pooling/maxpool2d.cpp @@ -26,200 +26,261 @@ #include namespace sd { -namespace ops { - +namespace ops { ////////////////////////////////////////////////////////////////////////// // maxpool2d corresponds to poolingMode=0 CUSTOM_OP_IMPL(maxpool2d, 1, 1, false, 0, 9) { - - auto input = INPUT_VARIABLE(0); - - REQUIRE_TRUE(input->rankOf() == 4, 0, "MAXPOOL2D OP: input array should have rank of 4, but got %i instead", input->rankOf()); - - // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; - auto output = OUTPUT_NULLIFIED(0); - - const int kH = INT_ARG(0); - const int kW = INT_ARG(1); - const int sH = INT_ARG(2); - const int sW = INT_ARG(3); - int pH = INT_ARG(4); - int pW = INT_ARG(5); - const int dH = INT_ARG(6); - const int dW = INT_ARG(7); - const bool isSameMode = INT_ARG(8); - - REQUIRE_TRUE(dH != 0 && dW != 0, 0, "MAXPOOL2D op: dilation must not be zero, but got instead {%i, %i}", dH, dW); - - int oH = 0; - int oW = 0; - - int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 1-NHWC, 0-NCHW - - const int iH = isNCHW ? input->sizeAt(2) : input->sizeAt(1); - const int iW = isNCHW ? input->sizeAt(3) : input->sizeAt(2); - - if(!isNCHW) { - input = new NDArray(input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] - output = new NDArray(output->permute({0, 3, 1, 2})); // [bS, oH, oW, iC] -> [bS, iC, oH, oW] - } - - ConvolutionUtils::calcOutSizePool2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode); - - if (isSameMode) - ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); - - // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; poolingMode; 9 - divisor; - ConvolutionUtils::pooling2d(block, *input, *output, kH, kW, sH, sW, pH, pW, dH, dW, PoolingType::MAX_POOL, 1); - - if(!isNCHW) { - delete input; - delete output; - } - - return Status::OK(); + auto input = INPUT_VARIABLE(0); + + REQUIRE_TRUE( + input->rankOf() == 4, 0, + "MAXPOOL2D OP: input array should have rank of 4, but got %i instead", + input->rankOf()); + + // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad + // Height/Width; 6,7 - dilation Height/Width; 8 - same mode; + auto output = OUTPUT_NULLIFIED(0); + + const int kH = INT_ARG(0); + const int kW = INT_ARG(1); + const int sH = INT_ARG(2); + const int sW = INT_ARG(3); + int pH = INT_ARG(4); + int pW = INT_ARG(5); + const int dH = INT_ARG(6); + const int dW = INT_ARG(7); + const bool isSameMode = INT_ARG(8); + + REQUIRE_TRUE( + dH != 0 && dW != 0, 0, + "MAXPOOL2D op: dilation must not be zero, but got instead {%i, %i}", dH, + dW); + + int oH = 0; + int oW = 0; + + int isNCHW = + block.numI() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 1-NHWC, 0-NCHW + + const int iH = isNCHW ? input->sizeAt(2) : input->sizeAt(1); + const int iW = isNCHW ? input->sizeAt(3) : input->sizeAt(2); + + if (!isNCHW) { + input = new NDArray( + input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] + output = new NDArray( + output->permute({0, 3, 1, 2})); // [bS, oH, oW, iC] -> [bS, iC, oH, oW] + } + + ConvolutionUtils::calcOutSizePool2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, + iH, iW, isSameMode); + + if (isSameMode) + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, + dW); + + // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad + // Height/Width; 6,7 - dilation Height/Width; poolingMode; 9 - divisor; + ConvolutionUtils::pooling2d(block, *input, *output, kH, kW, sH, sW, pH, pW, + dH, dW, PoolingType::MAX_POOL, 1); + + if (!isNCHW) { + delete input; + delete output; + } + + return Status::OK(); } DECLARE_SYN(MaxPool2D, maxpool2d); DECLARE_SYN(MaxPool, maxpool2d); DECLARE_SYN(maxpool, maxpool2d); - DECLARE_TYPES(maxpool2d) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setSameMode(true); - } - +DECLARE_TYPES(maxpool2d) { + getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setSameMode(true); +} DECLARE_SHAPE_FN(maxpool2d) { - - //NDArray *x = block.getVariables().at(0)->getNDArray(); - auto inShape = inputShape->at(0); - auto shapeOf = shape::shapeOf(inShape); - // 0 - number of dimensions; 1,2 - kernel Height/Width; 3,4 - stride Height/Width; 5,6 - pad Height/Width; 7,8 - dilation Height/Width; 9,10 - input Height/Width; 11 - batch size; 12 - input depth; 13 - same mode; - int kH = INT_ARG(0); - int kW = INT_ARG(1); - int sH = INT_ARG(2); - int sW = INT_ARG(3); - int pH = INT_ARG(4); - int pW = INT_ARG(5); - int dH = INT_ARG(6); - int dW = INT_ARG(7); - int isSameMode = INT_ARG(8); - int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 1-NHWC, 0-NCHW - - REQUIRE_TRUE(dH != 0 && dW != 0, 0, "MAXPOOL2D op: dilation must not be zero, but got instead {%i, %i}", dH, dW); - - int bS = shapeOf[0]; - int iC = isNCHW ? shapeOf[1] : shapeOf[3]; - int iH = isNCHW ? shapeOf[2] : shapeOf[1]; - int iW = isNCHW ? shapeOf[3] : shapeOf[2]; - - char order = shape::order(inShape); // output order must be equal to input order - - // calculate output Height/Width - int oH, oW; - ConvolutionUtils::calcOutSizePool2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode); - - // allocate memory for new shape - Nd4jLong newShape[4]; - - newShape[0] = bS; - if (isNCHW) { - newShape[1] = iC; - newShape[2] = oH; - newShape[3] = oW; - } else { - newShape[1] = oH; - newShape[2] = oW; - newShape[3] = iC; - } - - return SHAPELIST(ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(ArrayOptions::dataType(inShape), order, newShape, 4))); + // NDArray *x = block.getVariables().at(0)->getNDArray(); + auto inShape = inputShape->at(0); + auto shapeOf = shape::shapeOf(inShape); + // 0 - number of dimensions; 1,2 - kernel Height/Width; 3,4 - stride + // Height/Width; 5,6 - pad Height/Width; 7,8 - dilation Height/Width; 9,10 - + // input Height/Width; 11 - batch size; 12 - input depth; 13 - same mode; + int kH = INT_ARG(0); + int kW = INT_ARG(1); + int sH = INT_ARG(2); + int sW = INT_ARG(3); + int pH = INT_ARG(4); + int pW = INT_ARG(5); + int dH = INT_ARG(6); + int dW = INT_ARG(7); + int isSameMode = INT_ARG(8); + int isNCHW = + block.numI() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 1-NHWC, 0-NCHW + + REQUIRE_TRUE( + dH != 0 && dW != 0, 0, + "MAXPOOL2D op: dilation must not be zero, but got instead {%i, %i}", dH, + dW); + + int bS = shapeOf[0]; + int iC = isNCHW ? shapeOf[1] : shapeOf[3]; + int iH = isNCHW ? shapeOf[2] : shapeOf[1]; + int iW = isNCHW ? shapeOf[3] : shapeOf[2]; + + char order = + shape::order(inShape); // output order must be equal to input order + + // calculate output Height/Width + int oH, oW; + ConvolutionUtils::calcOutSizePool2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, + iH, iW, isSameMode); + + // allocate memory for new shape + Nd4jLong newShape[4]; + + newShape[0] = bS; + if (isNCHW) { + newShape[1] = iC; + newShape[2] = oH; + newShape[3] = oW; + } else { + newShape[1] = oH; + newShape[2] = oW; + newShape[3] = iC; + } + + return SHAPELIST(ConstantShapeHelper::getInstance().createShapeInfo( + ShapeDescriptor(ArrayOptions::dataType(inShape), order, newShape, 4))); } - DECLARE_TYPES(maxpool2d_bp) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } +DECLARE_TYPES(maxpool2d_bp) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(maxpool2d_bp, 2, 1, false, 0, 10) { - - auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - auto gradO = INPUT_VARIABLE(1); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next - auto gradI = OUTPUT_NULLIFIED(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon - - int kH = INT_ARG(0); // filter(kernel) height - int kW = INT_ARG(1); // filter(kernel) width - int sH = INT_ARG(2); // strides height - int sW = INT_ARG(3); // strides width - int pH = INT_ARG(4); // paddings height - int pW = INT_ARG(5); // paddings width - int dH = INT_ARG(6); // dilations height - int dW = INT_ARG(7); // dilations width - int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME - int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 1-NHWC, 0-NCHW - - REQUIRE_TRUE(input->rankOf() == 4, 0, "MAXPOOL2D_BP op: input should have rank of 4, but got %i instead", input->rankOf()); - REQUIRE_TRUE(dH != 0 && dW != 0, 0, "MAXPOOL2D_BP op: dilation must not be zero, but got instead {%i, %i}", dH, dW); - - int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; - int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, 0, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); - - std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oH,oW, 0,indIOioC,indIiH,indIiH+1}); - std::vector expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iH,iW, 0,indIOioC,indIiH,indIiH+1}); - REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "MAXPOOL2D_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); - REQUIRE_TRUE(gradI->isSameShape(expectedGradIShape), 0, "MAXPOOL2D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradIShape).c_str(), ShapeUtils::shapeAsString(gradI).c_str()); - - if(!isNCHW) { - input = new NDArray(input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] - gradI = new NDArray(gradI->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] - gradO = new NDArray(gradO->permute({0, 3, 1, 2})); // [bS, oH, oW, iC] -> [bS, iC, oH, oW] - } - - if(isSameMode) // SAME - ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); - - // NDArray columnsWrongShape(input->ordering(), {bS, iC, oH, oW, kH, kW}, input->getWorkspace()); - // NDArray* columns = columnsWrongShape.permute({0, 1, 4, 5, 2, 3}); // [bS, iC, oH, oW, kH, kW] -> [bS, iC, kH, kW, oH, oW] - - // input->template applyTransform>(columns, std::vector({(T)kH, (T)kW, (T)sH, (T)sW, (T)pH, (T)pW, (T)dH, (T)dW, (T)0.f, (T)0.f}).data()); - - // NDArray* columns2d = columnsWrongShape.reshape('c', {bS*iC*oH*oW, kH*kW}); - // NDArray* gradOVector = gradO->reshape('c', {(int) gradO->lengthOf(), 1}); - - // columns2d->template applyTransform>(std::vector({(T)1., (T)1.}).data()); - // columns2d->muliColumnVector(gradOVector); - - // columns->template applyTransform>(gradI, std::vector({(T)sH, (T)sW, (T)pH, (T)pW, (T)iH, (T)iW, (T)dH, (T)dW}).data()); - - ConvolutionUtils::pooling2dBP(block, *input, *gradO, *gradI, kH, kW, sH, sW, pH, pW, dH, dW, 0., 1.); - - if(!isNCHW) { - delete input; - delete gradI; - delete gradO; - } - - return Status::OK(); + auto input = + INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + auto gradO = INPUT_VARIABLE( + 1); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next + auto gradI = OUTPUT_NULLIFIED( + 0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon + + int kH = INT_ARG(0); // filter(kernel) height + int kW = INT_ARG(1); // filter(kernel) width + int sH = INT_ARG(2); // strides height + int sW = INT_ARG(3); // strides width + int pH = INT_ARG(4); // paddings height + int pW = INT_ARG(5); // paddings width + int dH = INT_ARG(6); // dilations height + int dW = INT_ARG(7); // dilations width + int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME + int isNCHW = + block.numI() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 1-NHWC, 0-NCHW + + REQUIRE_TRUE( + input->rankOf() == 4, 0, + "MAXPOOL2D_BP op: input should have rank of 4, but got %i instead", + input->rankOf()); + REQUIRE_TRUE( + dH != 0 && dW != 0, 0, + "MAXPOOL2D_BP op: dilation must not be zero, but got instead {%i, %i}", + dH, dW); + + int bS, iC, iH, iW, oC, oH, + oW; // batch size, input channels, input height/width, output channels, + // output height/width; + int indIOioC, indIiH, indWoC, indWiC, indWkH, + indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, 0, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, + indWiC, indWoC, indWkH, indOoH); + + std::vector expectedGradOShape = + ShapeUtils::composeShapeUsingDimsAndIdx( + {bS, iC, oH, oW, 0, indIOioC, indIiH, indIiH + 1}); + std::vector expectedGradIShape = + ShapeUtils::composeShapeUsingDimsAndIdx( + {bS, iC, iH, iW, 0, indIOioC, indIiH, indIiH + 1}); + REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, + "MAXPOOL2D_BP op: wrong shape of output's gradients array (next " + "epsilon), expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedGradOShape).c_str(), + ShapeUtils::shapeAsString(gradO).c_str()); + REQUIRE_TRUE(gradI->isSameShape(expectedGradIShape), 0, + "MAXPOOL2D_BP op: wrong shape of input's gradients array " + "(epsilon), expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedGradIShape).c_str(), + ShapeUtils::shapeAsString(gradI).c_str()); + + if (!isNCHW) { + input = new NDArray( + input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] + gradI = new NDArray( + gradI->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] + gradO = new NDArray( + gradO->permute({0, 3, 1, 2})); // [bS, oH, oW, iC] -> [bS, iC, oH, oW] + } + + if (isSameMode) // SAME + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, + dW); + + // NDArray columnsWrongShape(input->ordering(), {bS, iC, oH, oW, kH, kW}, + // input->getWorkspace()); NDArray* columns = columnsWrongShape.permute({0, + // 1, 4, 5, 2, 3}); // [bS, iC, oH, oW, kH, kW] + // -> [bS, iC, kH, kW, oH, oW] + + // input->template applyTransform>(columns, + // std::vector({(T)kH, (T)kW, (T)sH, (T)sW, (T)pH, (T)pW, (T)dH, (T)dW, + // (T)0.f, (T)0.f}).data()); + + // NDArray* columns2d = columnsWrongShape.reshape('c', {bS*iC*oH*oW, + // kH*kW}); NDArray* gradOVector = gradO->reshape('c', {(int) + // gradO->lengthOf(), 1}); + + // columns2d->template + // applyTransform>(std::vector({(T)1., (T)1.}).data()); + // columns2d->muliColumnVector(gradOVector); + + // columns->template applyTransform>(gradI, + // std::vector({(T)sH, (T)sW, (T)pH, (T)pW, (T)iH, (T)iW, (T)dH, + // (T)dW}).data()); + + ConvolutionUtils::pooling2dBP(block, *input, *gradO, *gradI, kH, kW, sH, sW, + pH, pW, dH, dW, 0., 1.); + + if (!isNCHW) { + delete input; + delete gradI; + delete gradO; + } + + return Status::OK(); } DECLARE_SYN(MaxPool2D_bp, maxpool2d_bp); DECLARE_SYN(MaxPool_bp, maxpool2d_bp); DECLARE_SHAPE_FN(maxpool2d_bp) { - - REQUIRE_TRUE(inputShape->at(0)[0] == 4, 0, "MAXPOOL2D_BP op: input array must be 4D, but got %i instead!", inputShape->at(0)[0]); - REQUIRE_TRUE(inputShape->at(1)[0] == 4, 0, "MAXPOOL2D_BP op: output's gradient array (next epsilon) must be 4D, but got %i instead!", inputShape->at(1)[0]); - - return SHAPELIST(ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(inputShape->at(0), ArrayOptions::dataType(inputShape->at(1))))); + REQUIRE_TRUE(inputShape->at(0)[0] == 4, 0, + "MAXPOOL2D_BP op: input array must be 4D, but got %i instead!", + inputShape->at(0)[0]); + REQUIRE_TRUE(inputShape->at(1)[0] == 4, 0, + "MAXPOOL2D_BP op: output's gradient array (next epsilon) must " + "be 4D, but got %i instead!", + inputShape->at(1)[0]); + + return SHAPELIST( + ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor( + inputShape->at(0), ArrayOptions::dataType(inputShape->at(1))))); } - -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/nn/pooling/maxpool3d.cpp b/libnd4j/include/ops/declarable/generic/nn/pooling/maxpool3d.cpp index fd28901cc281..8b75065f71fd 100644 --- a/libnd4j/include/ops/declarable/generic/nn/pooling/maxpool3d.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/pooling/maxpool3d.cpp @@ -25,204 +25,278 @@ #include namespace sd { -namespace ops { - +namespace ops { ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(maxpool3dnew, 1, 1, false, 0, 14) { - - auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) - auto output = OUTPUT_NULLIFIED(0); // [bS, oD, oH, oW, iC] (NDHWC) or [bS, iC, oD, oH, oW] (NCDHW) - - int kD = INT_ARG(0); // filter(kernel) depth - int kH = INT_ARG(1); // filter(kernel) height - int kW = INT_ARG(2); // filter(kernel) width - int sD = INT_ARG(3); // strides depth - int sH = INT_ARG(4); // strides height - int sW = INT_ARG(5); // strides width - int pD = INT_ARG(6); // paddings depth - int pH = INT_ARG(7); // paddings height - int pW = INT_ARG(8); // paddings width - int dD = INT_ARG(9); // dilations depth - int dH = INT_ARG(10); // dilations height - int dW = INT_ARG(11); // dilations width - int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID - // int extraParam0 = INT_ARG(13); // unnecessary for max case, required only for avg and pnorm cases - int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1; // 1-NDHWC, 0-NCDHW - - REQUIRE_TRUE(input->rankOf() == 5, 0, "MAXPOOL3DNEW OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf()); - REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "MAXPOOL3DNEW op: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW); - - int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; - int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, 0, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); - - std::vector expectedOutputShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); - REQUIRE_TRUE(output->isSameShape(expectedOutputShape), 0, "MAXPOOL3D op: wrong shape of output array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedOutputShape).c_str(), ShapeUtils::shapeAsString(output).c_str()); - // REQUIRE_TRUE(iD >= kD && iH >= kH && iW >= kW, 0, "MAXPOOL3D OP: the input depth/height/width must be greater or equal to kernel(filter) depth/height/width, but got [%i, %i, %i] and [%i, %i, %i] correspondingly !", iD,iH,iW, kD,kH,kW); - // REQUIRE_TRUE(kD/2 >= pD && kH/2 >= pH && kW/2 >= pW, 0, "MAXPOOL3D OP: pad depth/height/width must not be greater than half of kernel depth/height/width, but got [%i, %i, %i] and [%i, %i, %i] correspondingly !", pD,pH,pW, kD,kH,kW); - - if(!isNCDHW) { - input = new NDArray(input->permute({0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW] - output = new NDArray(output->permute({0, 4, 1, 2, 3})); // [bS, oD, oH, oW, iC] -> [bS, iC, oD, oH, oW] - } - - if(isSameMode) // SAME - ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW); - - ConvolutionUtils::pooling3d(block, *input, *output, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, 0, 1); - - if(!isNCDHW) { - delete input; - delete output; - } - - return Status::OK(); + auto input = INPUT_VARIABLE( + 0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) + auto output = OUTPUT_NULLIFIED( + 0); // [bS, oD, oH, oW, iC] (NDHWC) or [bS, iC, oD, oH, oW] (NCDHW) + + int kD = INT_ARG(0); // filter(kernel) depth + int kH = INT_ARG(1); // filter(kernel) height + int kW = INT_ARG(2); // filter(kernel) width + int sD = INT_ARG(3); // strides depth + int sH = INT_ARG(4); // strides height + int sW = INT_ARG(5); // strides width + int pD = INT_ARG(6); // paddings depth + int pH = INT_ARG(7); // paddings height + int pW = INT_ARG(8); // paddings width + int dD = INT_ARG(9); // dilations depth + int dH = INT_ARG(10); // dilations height + int dW = INT_ARG(11); // dilations width + int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID + // int extraParam0 = INT_ARG(13); // + // unnecessary for max case, required only for avg and pnorm cases + int isNCDHW = block.numI() > 14 ? !INT_ARG(14) : 1; // 1-NDHWC, 0-NCDHW + + REQUIRE_TRUE(input->rankOf() == 5, 0, + "MAXPOOL3DNEW OP: rank of input array must be equal to 5, but " + "got %i instead !", + input->rankOf()); + REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, + "MAXPOOL3DNEW op: dilation must not be zero, but got instead " + "{%i, %i, %i}", + dD, dH, dW); + + int bS, iC, iD, iH, iW, oC, oD, oH, + oW; // batch size, input channels, input depth/height/width, output + // channels, output depth/height/width; + int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv3d( + isNCDHW, 0, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, + indIOioD, indWiC, indWoC, indWkD); + + std::vector expectedOutputShape = + ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, oD, oH, oW, 0, indIOioC, + indIOioD, indIOioD + 1, + indIOioD + 2}); + REQUIRE_TRUE(output->isSameShape(expectedOutputShape), 0, + "MAXPOOL3D op: wrong shape of output array, expected is %s, but " + "got %s instead !", + ShapeUtils::shapeAsString(expectedOutputShape).c_str(), + ShapeUtils::shapeAsString(output).c_str()); + // REQUIRE_TRUE(iD >= kD && iH >= kH && iW >= kW, 0, "MAXPOOL3D OP: the + // input depth/height/width must be greater or equal to kernel(filter) + // depth/height/width, but got [%i, %i, %i] and [%i, %i, %i] correspondingly + // !", iD,iH,iW, kD,kH,kW); REQUIRE_TRUE(kD/2 >= pD && kH/2 >= pH && kW/2 >= + // pW, 0, "MAXPOOL3D OP: pad depth/height/width must not be greater than half + // of kernel depth/height/width, but got [%i, %i, %i] and [%i, %i, %i] + // correspondingly !", pD,pH,pW, kD,kH,kW); + + if (!isNCDHW) { + input = new NDArray(input->permute( + {0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW] + output = new NDArray(output->permute( + {0, 4, 1, 2, 3})); // [bS, oD, oH, oW, iC] -> [bS, iC, oD, oH, oW] + } + + if (isSameMode) // SAME + ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, + kW, sD, sH, sW, dD, dH, dW); + + ConvolutionUtils::pooling3d(block, *input, *output, kD, kH, kW, sD, sH, sW, + pD, pH, pW, dD, dH, dW, 0, 1); + + if (!isNCDHW) { + delete input; + delete output; + } + + return Status::OK(); } - DECLARE_TYPES(maxpool3dnew) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setSameMode(true); - } +DECLARE_TYPES(maxpool3dnew) { + getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setSameMode(true); +} DECLARE_SHAPE_FN(maxpool3dnew) { - - int kD = INT_ARG(0); // filter(kernel) depth - int kH = INT_ARG(1); // filter(kernel) height - int kW = INT_ARG(2); // filter(kernel) width - int sD = INT_ARG(3); // strides depth - int sH = INT_ARG(4); // strides height - int sW = INT_ARG(5); // strides width - int pD = INT_ARG(6); // paddings depth - int pH = INT_ARG(7); // paddings height - int pW = INT_ARG(8); // paddings width - int dD = INT_ARG(9); // dilations depth - int dH = INT_ARG(10); // dilations height - int dW = INT_ARG(11); // dilations width - int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID - // int extraParam0 = INT_ARG(13); - int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1; // 1-NDHWC, 0-NCDHW - - REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "MAXPOOL3DNEW op: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW); - - auto inputShapeInfo = inputShape->at(0); - - int idxID, idxIC; - if(isNCDHW) { idxID = 2; idxIC = 1;} - else { idxID = 1; idxIC = 4;} - - int bS = inputShapeInfo[1]; // batch size - int iC = inputShapeInfo[idxIC+1]; // input channels - int iD = inputShapeInfo[idxID+1]; // input depth - int iH = inputShapeInfo[idxID+2]; // input height - int iW = inputShapeInfo[idxID+3]; // input width - - int oD, oH, oW; // output depth, height, width - ConvolutionUtils::calcOutSizePool3D(oD, oH, oW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, isSameMode); - - Nd4jLong outputShape[5]; - - - outputShape[0] = bS; - if (isNCDHW) { - outputShape[1] = iC; - outputShape[2] = oD; - outputShape[3] = oH; - outputShape[4] = oW; - } else { - outputShape[1] = oD; - outputShape[2] = oH; - outputShape[3] = oW; - outputShape[4] = iC; - } - - return SHAPELIST(ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(ArrayOptions::dataType(inputShapeInfo), shape::order(inputShapeInfo), outputShape, 5))); + int kD = INT_ARG(0); // filter(kernel) depth + int kH = INT_ARG(1); // filter(kernel) height + int kW = INT_ARG(2); // filter(kernel) width + int sD = INT_ARG(3); // strides depth + int sH = INT_ARG(4); // strides height + int sW = INT_ARG(5); // strides width + int pD = INT_ARG(6); // paddings depth + int pH = INT_ARG(7); // paddings height + int pW = INT_ARG(8); // paddings width + int dD = INT_ARG(9); // dilations depth + int dH = INT_ARG(10); // dilations height + int dW = INT_ARG(11); // dilations width + int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID + // int extraParam0 = INT_ARG(13); + int isNCDHW = block.numI() > 14 ? !INT_ARG(14) : 1; // 1-NDHWC, 0-NCDHW + + REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, + "MAXPOOL3DNEW op: dilation must not be zero, but got instead " + "{%i, %i, %i}", + dD, dH, dW); + + auto inputShapeInfo = inputShape->at(0); + + int idxID, idxIC; + if (isNCDHW) { + idxID = 2; + idxIC = 1; + } else { + idxID = 1; + idxIC = 4; + } + + int bS = inputShapeInfo[1]; // batch size + int iC = inputShapeInfo[idxIC + 1]; // input channels + int iD = inputShapeInfo[idxID + 1]; // input depth + int iH = inputShapeInfo[idxID + 2]; // input height + int iW = inputShapeInfo[idxID + 3]; // input width + + int oD, oH, oW; // output depth, height, width + ConvolutionUtils::calcOutSizePool3D(oD, oH, oW, kD, kH, kW, sD, sH, sW, pD, + pH, pW, dD, dH, dW, iD, iH, iW, + isSameMode); + + Nd4jLong outputShape[5]; + + outputShape[0] = bS; + if (isNCDHW) { + outputShape[1] = iC; + outputShape[2] = oD; + outputShape[3] = oH; + outputShape[4] = oW; + } else { + outputShape[1] = oD; + outputShape[2] = oH; + outputShape[3] = oW; + outputShape[4] = iC; + } + + return SHAPELIST(ConstantShapeHelper::getInstance().createShapeInfo( + ShapeDescriptor(ArrayOptions::dataType(inputShapeInfo), + shape::order(inputShapeInfo), outputShape, 5))); } - DECLARE_TYPES(maxpool3dnew_bp) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } +DECLARE_TYPES(maxpool3dnew_bp) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(maxpool3dnew_bp, 2, 1, false, 0, 14) { - - auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) - auto gradO = INPUT_VARIABLE(1); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next - auto gradI = OUTPUT_NULLIFIED(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon - - const int kD = INT_ARG(0); // filter(kernel) depth - const int kH = INT_ARG(1); // filter(kernel) height - const int kW = INT_ARG(2); // filter(kernel) width - const int sD = INT_ARG(3); // strides depth - const int sH = INT_ARG(4); // strides height - const int sW = INT_ARG(5); // strides width - int pD = INT_ARG(6); // paddings depth - int pH = INT_ARG(7); // paddings height - int pW = INT_ARG(8); // paddings width - const int dD = INT_ARG(9); // dilations depth - const int dH = INT_ARG(10); // dilations height - const int dW = INT_ARG(11); // dilations width - const int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID - // int extraParam0 = INT_ARG(13); // unnecessary for max case, required only for avg and pnorm cases - int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1; // 1-NDHWC, 0-NCDHW - - REQUIRE_TRUE(input->rankOf() == 5, 0, "MAXPOOL3DNEW_BP op: input should have rank of 5, but got %i instead", input->rankOf()); - REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "MAXPOOL3DNEW_BP op: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW); - - int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; - int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, 0, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); - - std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); - std::vector expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iD,iH,iW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); - REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "MAXPOOL3DNEW_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); - REQUIRE_TRUE(gradI->isSameShape(expectedGradIShape), 0, "MAXPOOL3DNEW_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradIShape).c_str(), ShapeUtils::shapeAsString(gradI).c_str()); - - if(!isNCDHW) { - input = new NDArray(input->permute({0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW] - gradI = new NDArray(gradI->permute({0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW] - gradO = new NDArray(gradO->permute({0, 4, 1, 2, 3})); // [bS, oD, oH, oW, iC] -> [bS, iC, oD, oH, oW] - } - - if(isSameMode) // SAME - ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW); - - // NDArray columnsWrongShape(input->ordering(), {bS, iC, oD, oH, oW, kD, kH, kW}, input->getWorkspace()); - // NDArray* columns = columnsWrongShape.permute({0, 1, 5, 6, 7, 2, 3, 4}); // [bS, iC, oD, oH, oW, kD, kH, kW] -> [bS, iC, kD, kH, kW, oD, oH, oW] - - // ConvolutionUtils::vol2col(*input, *columns, sD, sH, sW, pD, pH, pW, dD, dH, dW); // [bS, iC, iD, iH, iW] is convoluted to [bS, iC, kD, kH, kW, oD, oH, oW] - - // NDArray* columns2d = columnsWrongShape.reshape('c', {bS*iC*oD*oH*oW, kD*kH*kW}); - // NDArray* gradOVector = gradO->reshape('c', {(int) gradO->lengthOf(), 1}); - // T extraParams[] = {(T)1., (T)1.}; - // columns2d->template applyTransform>(extraParams); - // columns2d->muliColumnVector(gradOVector); - - // ConvolutionUtils::col2vol(*columns, *gradI, sD, sH, sW, pD, pH, pW, dD, dH, dW); // columns [bS, iC, kD, kH, kW, oD, oH, oW] is de-convoluted to [bS, iC, iD, iH, iW] - - // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - poolingMode; 9 - unnecessary; - ConvolutionUtils::pooling3dBP(block, *input, *gradO, *gradI, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, 0, 1); - - if(!isNCDHW) { - delete input; - delete gradI; - delete gradO; - } - - return Status::OK(); + auto input = INPUT_VARIABLE( + 0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) + auto gradO = INPUT_VARIABLE(1); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, + // oD, oH, oW] (NCDHW), epsilon_next + auto gradI = OUTPUT_NULLIFIED(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, + // iD, iH, iW] (NCDHW), epsilon + + const int kD = INT_ARG(0); // filter(kernel) depth + const int kH = INT_ARG(1); // filter(kernel) height + const int kW = INT_ARG(2); // filter(kernel) width + const int sD = INT_ARG(3); // strides depth + const int sH = INT_ARG(4); // strides height + const int sW = INT_ARG(5); // strides width + int pD = INT_ARG(6); // paddings depth + int pH = INT_ARG(7); // paddings height + int pW = INT_ARG(8); // paddings width + const int dD = INT_ARG(9); // dilations depth + const int dH = INT_ARG(10); // dilations height + const int dW = INT_ARG(11); // dilations width + const int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID + // int extraParam0 = INT_ARG(13); // + // unnecessary for max case, required only for avg and pnorm cases + int isNCDHW = block.numI() > 14 ? !INT_ARG(14) : 1; // 1-NDHWC, 0-NCDHW + + REQUIRE_TRUE( + input->rankOf() == 5, 0, + "MAXPOOL3DNEW_BP op: input should have rank of 5, but got %i instead", + input->rankOf()); + REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, + "MAXPOOL3DNEW_BP op: dilation must not be zero, but got instead " + "{%i, %i, %i}", + dD, dH, dW); + + int bS, iC, iD, iH, iW, oC, oD, oH, + oW; // batch size, input channels, input depth/height/width, output + // channels, output depth/height/width; + int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv3d( + isNCDHW, 0, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, + indIOioD, indWiC, indWoC, indWkD); + + std::vector expectedGradOShape = + ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, oD, oH, oW, 0, indIOioC, + indIOioD, indIOioD + 1, + indIOioD + 2}); + std::vector expectedGradIShape = + ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, iD, iH, iW, 0, indIOioC, + indIOioD, indIOioD + 1, + indIOioD + 2}); + REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, + "MAXPOOL3DNEW_BP op: wrong shape of output's gradients array " + "(next epsilon), expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedGradOShape).c_str(), + ShapeUtils::shapeAsString(gradO).c_str()); + REQUIRE_TRUE(gradI->isSameShape(expectedGradIShape), 0, + "MAXPOOL3DNEW_BP op: wrong shape of input's gradients array " + "(epsilon), expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedGradIShape).c_str(), + ShapeUtils::shapeAsString(gradI).c_str()); + + if (!isNCDHW) { + input = new NDArray(input->permute( + {0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW] + gradI = new NDArray(gradI->permute( + {0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW] + gradO = new NDArray(gradO->permute( + {0, 4, 1, 2, 3})); // [bS, oD, oH, oW, iC] -> [bS, iC, oD, oH, oW] + } + + if (isSameMode) // SAME + ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, + kW, sD, sH, sW, dD, dH, dW); + + // NDArray columnsWrongShape(input->ordering(), {bS, iC, oD, oH, oW, kD, + // kH, kW}, input->getWorkspace()); NDArray* columns = + // columnsWrongShape.permute({0, 1, 5, 6, 7, 2, 3, 4}); // [bS, iC, oD, oH, + // oW, kD, kH, kW] -> [bS, iC, kD, kH, kW, oD, oH, oW] + + // ConvolutionUtils::vol2col(*input, *columns, sD, sH, sW, pD, pH, pW, dD, + // dH, dW); // [bS, iC, iD, iH, iW] is convoluted to [bS, iC, + // kD, kH, kW, oD, oH, oW] + + // NDArray* columns2d = columnsWrongShape.reshape('c', {bS*iC*oD*oH*oW, + // kD*kH*kW}); NDArray* gradOVector = gradO->reshape('c', {(int) + // gradO->lengthOf(), 1}); T extraParams[] = {(T)1., (T)1.}; + // columns2d->template applyTransform>(extraParams); + // columns2d->muliColumnVector(gradOVector); + + // ConvolutionUtils::col2vol(*columns, *gradI, sD, sH, sW, pD, pH, pW, dD, + // dH, dW); // columns [bS, iC, kD, kH, kW, oD, oH, oW] is + // de-convoluted to [bS, iC, iD, iH, iW] + + // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad + // Height/Width; 6,7 - dilation Height/Width; 8 - poolingMode; 9 - + // unnecessary; + ConvolutionUtils::pooling3dBP(block, *input, *gradO, *gradI, kD, kH, kW, sD, + sH, sW, pD, pH, pW, dD, dH, dW, 0, 1); + + if (!isNCDHW) { + delete input; + delete gradI; + delete gradO; + } + + return Status::OK(); } - DECLARE_SHAPE_FN(maxpool3dnew_bp) { - return SHAPELIST(ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(inputShape->at(0), ArrayOptions::dataType(inputShape->at(1))))); + return SHAPELIST( + ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor( + inputShape->at(0), ArrayOptions::dataType(inputShape->at(1))))); } - - -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/nn/pooling/maxpool_with_argmax.cpp b/libnd4j/include/ops/declarable/generic/nn/pooling/maxpool_with_argmax.cpp index b03d19451044..af1c3c8147c7 100644 --- a/libnd4j/include/ops/declarable/generic/nn/pooling/maxpool_with_argmax.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/pooling/maxpool_with_argmax.cpp @@ -26,40 +26,41 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(max_pool_with_argmax, 1, 2, false, 0, 9) { +namespace ops { +CUSTOM_OP_IMPL(max_pool_with_argmax, 1, 2, false, 0, 9) { + auto x = INPUT_VARIABLE(0); + auto z = OUTPUT_NULLIFIED(0); + auto indices = OUTPUT_NULLIFIED(1); - auto x = INPUT_VARIABLE(0); - auto z = OUTPUT_NULLIFIED(0); - auto indices = OUTPUT_NULLIFIED(1); + REQUIRE_TRUE( + x->rankOf() == 4, 0, + "max_pool_with_argmax: Input should have rank of 4, but got %i instead", + x->rankOf()); - REQUIRE_TRUE(x->rankOf() == 4, 0, "max_pool_with_argmax: Input should have rank of 4, but got %i instead", x->rankOf()); + auto argI = block.getIArguments(); - auto argI = *(block.getIArguments()); + helpers::maxPoolingFunctor(block.launchContext(), block, x, z, argI, indices); - helpers::maxPoolingFunctor(block.launchContext(), block, x, z, argI, indices); - - return Status::OK(); - } + return Status::OK(); +} - DECLARE_TYPES(max_pool_with_argmax) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes(0, {ALL_FLOATS, ALL_INTS}) - ->setAllowedOutputTypes(1, {ALL_INDICES}); +DECLARE_TYPES(max_pool_with_argmax) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes(0, {ALL_FLOATS, ALL_INTS}) + ->setAllowedOutputTypes(1, {ALL_INDICES}); +} - } +DECLARE_SHAPE_FN(max_pool_with_argmax) { + auto in = inputShape->at(0);auto dtype = block.numD() ? D_ARG(0) : sd::DataType::INT64; + auto valuesShape = + ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(in)); + auto indicesShape = ConstantShapeHelper::getInstance().createShapeInfo( + ShapeDescriptor(in, dtype)); - DECLARE_SHAPE_FN(max_pool_with_argmax) { - auto in = inputShape->at(0); - auto dtype = block.numD() ? D_ARG(0) : sd::DataType::INT64; - auto valuesShape = ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(in)); - auto indicesShape = ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(in, dtype)); - - return SHAPELIST(valuesShape, indicesShape); - } - } + return SHAPELIST(valuesShape, indicesShape); } +} // namespace ops +} // namespace sd #endif - diff --git a/libnd4j/include/ops/declarable/generic/nn/pooling/pnormpool2d.cpp b/libnd4j/include/ops/declarable/generic/nn/pooling/pnormpool2d.cpp index 927627ff86f9..d72edd586d99 100644 --- a/libnd4j/include/ops/declarable/generic/nn/pooling/pnormpool2d.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/pooling/pnormpool2d.cpp @@ -26,209 +26,270 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(pnormpool2d, 1, 1, false, 0, 10) { - - REQUIRE_OK(this->validateInputLengthMatch(block)); - REQUIRE_OK(this->validateInputDimensionsMatch(block)); - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_NULLIFIED(0); - - REQUIRE_TRUE(input->rankOf() == 4, 0, "PNORMPOOL2D op: input should have rank of 4, but got %i instead", input->rankOf()); - - auto kY = INT_ARG(0); - auto kX = INT_ARG(1); - auto sY = INT_ARG(2); - auto sX = INT_ARG(3); - auto pY = INT_ARG(4); - auto pX = INT_ARG(5); - auto dY = INT_ARG(6); - auto dX = INT_ARG(7); - bool isSameMode = static_cast(INT_ARG(8)); - auto extraParam0 = INT_ARG(9); - - REQUIRE_TRUE(dY != 0 && dX != 0, 0, "PNORMPOOL2D op: dilation must not be zero, but got instead {%i, %i}", dY, dX); - - int oY = 0; - int oX = 0; - - int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // 1-NHWC, 0-NCHW - - if(!isNCHW) { - input = new NDArray(input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] - output = new NDArray(output->permute({0, 3, 1, 2})); // [bS, oH, oW, iC] -> [bS, iC, oH, oW] - } - - const auto inY = static_cast(input->sizeAt(2)); - const auto inX = static_cast(input->sizeAt(3)); - - ConvolutionUtils::calcOutSizePool2D(oY, oX, kY, kX, sY, sX, pY, pX, dY, dX, inY, inX, isSameMode); - - if (isSameMode) - ConvolutionUtils::calcPadding2D(pY, pX, oY, oX, inY, inX, kY, kX, sY, sX, dY, dX); - - // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - poolingMode; 9 - divisor; - ConvolutionUtils::pooling2d(block, *input, *output, kY, kX, sY, sX, pY, pX, dY, dX, PoolingType::PNORM_POOL, extraParam0); - - if(!isNCHW) { - delete input; - delete output; - } - - return Status::OK(); - } - DECLARE_SYN(PnormPool2D, pnormpool2d); - DECLARE_SYN(PnormPool, pnormpool2d); - DECLARE_SYN(pnormpool, pnormpool2d); - - DECLARE_TYPES(pnormpool2d) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } - - - DECLARE_SHAPE_FN(pnormpool2d) { - auto inShape = inputShape->at(0); - auto shapeOf = shape::shapeOf(inShape); - - // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; - std::vector argI = *(block.getIArguments()); - int kH = INT_ARG(0); - int kW = INT_ARG(1); - int sH = INT_ARG(2); - int sW = INT_ARG(3); - int pH = INT_ARG(4); - int pW = INT_ARG(5); - int dH = INT_ARG(6); - int dW = INT_ARG(7); - int isSameMode = INT_ARG(8); - int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // 1-NHWC, 0-NCHW - - REQUIRE_TRUE(dH != 0 && dW != 0, 0, "PNORMPOOL2D op: dilation must not be zero, but got instead {%i, %i}", dH, dW); - - int bS = shapeOf[0]; - int iC = isNCHW ? shapeOf[1] : shapeOf[3]; - int iH = isNCHW ? shapeOf[2] : shapeOf[1]; - int iW = isNCHW ? shapeOf[3] : shapeOf[2]; - char order = shape::order(inShape); // output order must be equal to input order - - // calculate output Height/Width - int oH, oW; - ConvolutionUtils::calcOutSizePool2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode); - // allocate memory for new shape - Nd4jLong newShape[4]; - - newShape[0] = bS; - if (isNCHW) { - newShape[1] = iC; - newShape[2] = oH; - newShape[3] = oW; - } else { - newShape[1] = oH; - newShape[2] = oW; - newShape[3] = iC; - } - - return SHAPELIST(ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(ArrayOptions::dataType(inShape), order, newShape, 4))); - } - - - DECLARE_TYPES(pnormpool2d_bp) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } +namespace ops { +CUSTOM_OP_IMPL(pnormpool2d, 1, 1, false, 0, 10) { + REQUIRE_OK(this->validateInputLengthMatch(block)); + REQUIRE_OK(this->validateInputDimensionsMatch(block)); + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_NULLIFIED(0); + + REQUIRE_TRUE( + input->rankOf() == 4, 0, + "PNORMPOOL2D op: input should have rank of 4, but got %i instead", + input->rankOf()); + + auto kY = INT_ARG(0); + auto kX = INT_ARG(1); + auto sY = INT_ARG(2); + auto sX = INT_ARG(3); + auto pY = INT_ARG(4); + auto pX = INT_ARG(5); + auto dY = INT_ARG(6); + auto dX = INT_ARG(7); + bool isSameMode = static_cast(INT_ARG(8)); + auto extraParam0 = INT_ARG(9); + + REQUIRE_TRUE( + dY != 0 && dX != 0, 0, + "PNORMPOOL2D op: dilation must not be zero, but got instead {%i, %i}", dY, + dX); + + int oY = 0; + int oX = 0; + + int isNCHW = block.numI() > 10 ? !INT_ARG(10) : 1; // 1-NHWC, 0-NCHW + + if (!isNCHW) { + input = new NDArray( + input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] + output = new NDArray( + output->permute({0, 3, 1, 2})); // [bS, oH, oW, iC] -> [bS, iC, oH, oW] + } + + const auto inY = static_cast(input->sizeAt(2)); + const auto inX = static_cast(input->sizeAt(3)); + + ConvolutionUtils::calcOutSizePool2D(oY, oX, kY, kX, sY, sX, pY, pX, dY, dX, + inY, inX, isSameMode); + + if (isSameMode) + ConvolutionUtils::calcPadding2D(pY, pX, oY, oX, inY, inX, kY, kX, sY, sX, + dY, dX); + + // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad + // Height/Width; 6,7 - dilation Height/Width; 8 - poolingMode; 9 - divisor; + ConvolutionUtils::pooling2d(block, *input, *output, kY, kX, sY, sX, pY, pX, + dY, dX, PoolingType::PNORM_POOL, extraParam0); + + if (!isNCHW) { + delete input; + delete output; + } + + return Status::OK(); +} +DECLARE_SYN(PnormPool2D, pnormpool2d); +DECLARE_SYN(PnormPool, pnormpool2d); +DECLARE_SYN(pnormpool, pnormpool2d); + +DECLARE_TYPES(pnormpool2d) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} + +DECLARE_SHAPE_FN(pnormpool2d) { + auto inShape = inputShape->at(0); + auto shapeOf = shape::shapeOf(inShape); + + // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad + // Height/Width; 6,7 - dilation Height/Width; 8 - same mode; + auto argI = block.getIArguments(); + int kH = INT_ARG(0); + int kW = INT_ARG(1); + int sH = INT_ARG(2); + int sW = INT_ARG(3); + int pH = INT_ARG(4); + int pW = INT_ARG(5); + int dH = INT_ARG(6); + int dW = INT_ARG(7); + int isSameMode = INT_ARG(8); + int isNCHW = block.numI() > 10 ? !INT_ARG(10) : 1; // 1-NHWC, 0-NCHW + + REQUIRE_TRUE( + dH != 0 && dW != 0, 0, + "PNORMPOOL2D op: dilation must not be zero, but got instead {%i, %i}", dH, + dW); + + int bS = shapeOf[0]; + int iC = isNCHW ? shapeOf[1] : shapeOf[3]; + int iH = isNCHW ? shapeOf[2] : shapeOf[1]; + int iW = isNCHW ? shapeOf[3] : shapeOf[2]; + char order = + shape::order(inShape); // output order must be equal to input order + + // calculate output Height/Width + int oH, oW; + ConvolutionUtils::calcOutSizePool2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, + iH, iW, isSameMode); + // allocate memory for new shape + Nd4jLong newShape[4]; + + newShape[0] = bS; + if (isNCHW) { + newShape[1] = iC; + newShape[2] = oH; + newShape[3] = oW; + } else { + newShape[1] = oH; + newShape[2] = oW; + newShape[3] = iC; + } + + return SHAPELIST(ConstantShapeHelper::getInstance().createShapeInfo( + ShapeDescriptor(ArrayOptions::dataType(inShape), order, newShape, 4))); +} + +DECLARE_TYPES(pnormpool2d_bp) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(pnormpool2d_bp, 2, 1, false, 1, 10) { - - auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - auto gradO = INPUT_VARIABLE(1); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next - auto gradI = OUTPUT_NULLIFIED(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon - - int kH = INT_ARG(0); // filter(kernel) height - int kW = INT_ARG(1); // filter(kernel) width - int sH = INT_ARG(2); // strides height - int sW = INT_ARG(3); // strides width - int pH = INT_ARG(4); // paddings height - int pW = INT_ARG(5); // paddings width - int dH = INT_ARG(6); // dilations height - int dW = INT_ARG(7); // dilations width - int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME - int pnorm = INT_ARG(9); - int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // 1-NHWC, 0-NCHW - - // FIXME: double? - double eps = T_ARG(0); - - REQUIRE_TRUE(input->rankOf() == 4, 0, "PNORMPOOL2D_BP op: input should have rank of 4, but got %i instead", input->rankOf()); - REQUIRE_TRUE(dH != 0 && dW != 0, 0, "PNORMPOOL2D_BP op: dilation must not be zero, but got instead {%i, %i}", dH, dW); - - int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; - int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, 0, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); - - std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oH,oW, 0,indIOioC,indIiH,indIiH+1}); - std::vector expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iH,iW, 0,indIOioC,indIiH,indIiH+1}); - REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "PNORMPOOL2D_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); - REQUIRE_TRUE(gradI->isSameShape(expectedGradIShape), 0, "PNORMPOOL2D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradIShape).c_str(), ShapeUtils::shapeAsString(gradI).c_str()); - - if(!isNCHW) { - input = new NDArray(input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] - gradI = new NDArray(gradI->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] - gradO = new NDArray(gradO->permute({0, 3, 1, 2})); // [bS, oH, oW, iC] -> [bS, iC, oH, oW] - } - - // if(isSameMode) // SAME - // ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); - - // NDArray columnsWrongShape(input->ordering(), {bS, iC, oH, oW, kH, kW}, input->getWorkspace()); - // NDArray* columns = columnsWrongShape.permute({0, 1, 4, 5, 2, 3}); // [bS, iC, oH, oW, kH, kW] -> [bS, iC, kH, kW, oH, oW] - // NDArray* gradOVector = gradO->reshape('c', {(int) gradO->lengthOf(), 1}); - // NDArray* columns2d = columnsWrongShape.reshape('c', {bS*iC*oH*oW, kH*kW}); - // NDArray pNorm(columns2d->shapeInfo(), block.getWorkspace()); - - // input->template applyTransform>(columns, std::vector({(T)kH, (T)kW, (T)sH, (T)sW, (T)pH, (T)pW, (T)dH, (T)dW, (T)0.f, (T)0.f}).data()); - - // columns2d->template applyTransform>(&pNorm); - // pNorm.template applyTransform>(&pNorm, std::vector({(T)pnorm}).data()); - - // NDArray* denomVec = pNorm.sum({1}); - // denomVec->template applyTransform>(std::vector({(T)1. - (T)1. / pnorm}).data()); - // denomVec->template applyScalar>(eps); // in case of 0 - // denomVec->template applyPairwiseTransform>(gradOVector, denomVec, nullptr); - - // if(pnorm != 2) { - // T extraParams[] = {(T)1. - (T)2. / pnorm}; - // pNorm.template applyTransform>(std::vector({(T)1. - (T)2. / pnorm}).data()); - // *columns2d *= pNorm; - // } - - // columns2d->muliColumnVector(denomVec); - - // columns->template applyTransform>(gradI, std::vector({(T)sH, (T)sW, (T)pH, (T)pW, (T)iH, (T)iW, (T)dH, (T)dW}).data()); - - ConvolutionUtils::pooling2dBP(block, *input, *gradO, *gradI, kH, kW, sH, sW, pH, pW, dH, dW, 2, pnorm); - - if(!isNCHW) { - delete input; - delete gradI; - delete gradO; - } - - return Status::OK(); + auto input = + INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + auto gradO = INPUT_VARIABLE( + 1); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next + auto gradI = OUTPUT_NULLIFIED( + 0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon + + int kH = INT_ARG(0); // filter(kernel) height + int kW = INT_ARG(1); // filter(kernel) width + int sH = INT_ARG(2); // strides height + int sW = INT_ARG(3); // strides width + int pH = INT_ARG(4); // paddings height + int pW = INT_ARG(5); // paddings width + int dH = INT_ARG(6); // dilations height + int dW = INT_ARG(7); // dilations width + int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME + int pnorm = INT_ARG(9); + int isNCHW = block.numI() > 10 ? !INT_ARG(10) : 1; // 1-NHWC, 0-NCHW + + // FIXME: double? + double eps = T_ARG(0); + + REQUIRE_TRUE( + input->rankOf() == 4, 0, + "PNORMPOOL2D_BP op: input should have rank of 4, but got %i instead", + input->rankOf()); + REQUIRE_TRUE( + dH != 0 && dW != 0, 0, + "PNORMPOOL2D_BP op: dilation must not be zero, but got instead {%i, %i}", + dH, dW); + + int bS, iC, iH, iW, oC, oH, + oW; // batch size, input channels, input height/width, output channels, + // output height/width; + int indIOioC, indIiH, indWoC, indWiC, indWkH, + indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, 0, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, + indWiC, indWoC, indWkH, indOoH); + + std::vector expectedGradOShape = + ShapeUtils::composeShapeUsingDimsAndIdx( + {bS, iC, oH, oW, 0, indIOioC, indIiH, indIiH + 1}); + std::vector expectedGradIShape = + ShapeUtils::composeShapeUsingDimsAndIdx( + {bS, iC, iH, iW, 0, indIOioC, indIiH, indIiH + 1}); + REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, + "PNORMPOOL2D_BP op: wrong shape of output's gradients array " + "(next epsilon), expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedGradOShape).c_str(), + ShapeUtils::shapeAsString(gradO).c_str()); + REQUIRE_TRUE(gradI->isSameShape(expectedGradIShape), 0, + "PNORMPOOL2D_BP op: wrong shape of input's gradients array " + "(epsilon), expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedGradIShape).c_str(), + ShapeUtils::shapeAsString(gradI).c_str()); + + if (!isNCHW) { + input = new NDArray( + input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] + gradI = new NDArray( + gradI->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] + gradO = new NDArray( + gradO->permute({0, 3, 1, 2})); // [bS, oH, oW, iC] -> [bS, iC, oH, oW] + } + + // if(isSameMode) // SAME + // ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, + // sW, dH, dW); + + // NDArray columnsWrongShape(input->ordering(), {bS, iC, oH, oW, kH, kW}, + // input->getWorkspace()); NDArray* columns = columnsWrongShape.permute({0, + // 1, 4, 5, 2, 3}); // [bS, iC, oH, oW, kH, kW] + // -> [bS, iC, kH, kW, oH, oW] NDArray* gradOVector = gradO->reshape('c', + // {(int) gradO->lengthOf(), 1}); NDArray* columns2d = + // columnsWrongShape.reshape('c', {bS*iC*oH*oW, kH*kW}); NDArray + // pNorm(columns2d->shapeInfo(), block.workspace()); + + // input->template applyTransform>(columns, + // std::vector({(T)kH, (T)kW, (T)sH, (T)sW, (T)pH, (T)pW, (T)dH, (T)dW, + // (T)0.f, (T)0.f}).data()); + + // columns2d->template applyTransform>(&pNorm); + // pNorm.template applyTransform>(&pNorm, + // std::vector({(T)pnorm}).data()); + + // NDArray* denomVec = pNorm.sum({1}); + // denomVec->template applyTransform>(std::vector({(T)1. - + // (T)1. / pnorm}).data()); denomVec->template + // applyScalar>(eps); // in case of 0 denomVec->template + // applyPairwiseTransform>(gradOVector, denomVec, + // nullptr); + + // if(pnorm != 2) { + // T extraParams[] = {(T)1. - (T)2. / pnorm}; + // pNorm.template applyTransform>(std::vector({(T)1. - + // (T)2. / pnorm}).data()); *columns2d *= pNorm; + // } + + // columns2d->muliColumnVector(denomVec); + + // columns->template applyTransform>(gradI, + // std::vector({(T)sH, (T)sW, (T)pH, (T)pW, (T)iH, (T)iW, (T)dH, + // (T)dW}).data()); + + ConvolutionUtils::pooling2dBP(block, *input, *gradO, *gradI, kH, kW, sH, sW, + pH, pW, dH, dW, 2, pnorm); + + if (!isNCHW) { + delete input; + delete gradI; + delete gradO; + } + + return Status::OK(); } DECLARE_SHAPE_FN(pnormpool2d_bp) { - - REQUIRE_TRUE(inputShape->at(0)[0] == 4, 0, "PNORMPOOL2D_BP op: input array must be 4D, but got %i instead!", inputShape->at(0)[0]); - REQUIRE_TRUE(inputShape->at(1)[0] == 4, 0, "PNORMPOOL2D_BP op: output's gradient array (next epsilon) must be 4D, but got %i instead!", inputShape->at(1)[0]); - - return SHAPELIST(ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(inputShape->at(0), ArrayOptions::dataType(inputShape->at(1))))); + REQUIRE_TRUE(inputShape->at(0)[0] == 4, 0, + "PNORMPOOL2D_BP op: input array must be 4D, but got %i instead!", + inputShape->at(0)[0]); + REQUIRE_TRUE(inputShape->at(1)[0] == 4, 0, + "PNORMPOOL2D_BP op: output's gradient array (next epsilon) must " + "be 4D, but got %i instead!", + inputShape->at(1)[0]); + + return SHAPELIST( + ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor( + inputShape->at(0), ArrayOptions::dataType(inputShape->at(1))))); } -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/nn/recurrent/dynamicBidirectionalRNN.cpp b/libnd4j/include/ops/declarable/generic/nn/recurrent/dynamicBidirectionalRNN.cpp index d03f568b5b9e..cabfcc733810 100644 --- a/libnd4j/include/ops/declarable/generic/nn/recurrent/dynamicBidirectionalRNN.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/recurrent/dynamicBidirectionalRNN.cpp @@ -21,208 +21,340 @@ #include namespace sd { -namespace ops { - +namespace ops { ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(dynamic_bidirectional_rnn, 7, 4, false, 0, 0) { - auto x = INPUT_VARIABLE(0); // input [time x bS x inSize] or [bS x time x inSize], shape depends on timeMajor parameter - auto WxFW = INPUT_VARIABLE(1); // input-to-hidden weights for forward RNN, [inSize x numUnitsFW] - auto WhFW = INPUT_VARIABLE(2); // hidden-to-hidden weights for forward RNN, [numUnitsFW x numUnitsFW] - auto bFW = INPUT_VARIABLE(3); // biases for forward RNN, [2*numUnitsFW] - auto WxBW = INPUT_VARIABLE(4); // input-to-hidden weights for backward RNN, [inSize x numUnitsBW] - auto WhBW = INPUT_VARIABLE(5); // hidden-to-hidden weights for backward RNN, [numUnitsBW x numUnitsBW] - auto bBW = INPUT_VARIABLE(6); // biases for backward RNN, [2*v] - - NDArray* h0FW = nullptr; // initial cell output for forward RNN (at time step = 0) [bS x numUnitsFW] - NDArray* h0BW = nullptr; // initial cell output for backward RNN (at time step = 0) [bS x numUnitsBW] - NDArray* maxTimeStep = nullptr; // vector [bS] containing integer values within [0,time), each element of this vector set max time step per each input in batch, this means there are no calculations for time >= maxTimeStep - - const int timeMajor = block.getIArguments()->size() > 0 ? INT_ARG(0) : 0; // if non zero then [time, bS, ...], else [bS, time, ...] - - switch(block.width()) { - case 8: - maxTimeStep = INPUT_VARIABLE(7); - break; - case 9: - h0FW = INPUT_VARIABLE(7); - h0BW = INPUT_VARIABLE(8); - break; - case 10: - h0FW = INPUT_VARIABLE(7); - h0BW = INPUT_VARIABLE(8); - maxTimeStep = INPUT_VARIABLE(9); - break; - } - - auto hFW = OUTPUT_VARIABLE(0); // cell outputs for forward RNN [time x bS x numUnitsFW] or [bS x time x numUnitsFW], shape depends on timeMajor parameter - auto hBW = OUTPUT_VARIABLE(1); // cell outputs for backward RNN [time x bS x numUnitsBW] or [bS x time x numUnitsBW], shape depends on timeMajor parameter - auto hFWFinal = OUTPUT_VARIABLE(2); // final cell out for forward RNN [bS x numUnitsFW] - auto hBWFinal = OUTPUT_VARIABLE(3); // final cell out for backward RNN [bS x numUnitsBF] - - REQUIRE_TRUE(x->rankOf() == 3, 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: input array must have rank = 3, but got %i instead !", x->rankOf()); - REQUIRE_TRUE(WxFW->rankOf() == 2, 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: input-to-hidden weights array (for forward RNN) must have rank = 2, but got %i instead !", WxFW->rankOf()); - REQUIRE_TRUE(WxBW->rankOf() == 2, 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: input-to-hidden weights array (for backward RNN) must have rank = 2, but got %i instead !", WxBW->rankOf()); - - const int inRank = x->rankOf(); - const int time = timeMajor ? x->sizeAt(0) : x->sizeAt(1); - const int bS = timeMajor ? x->sizeAt(1) : x->sizeAt(0); - const int numUnitsFW = WxFW->sizeAt(1); - const int numUnitsBW = WxBW->sizeAt(1); - - std::vector expectedWhFWshape = {numUnitsFW, numUnitsFW}; - std::vector expectedWhBWshape = {numUnitsBW, numUnitsBW}; - std::vector expectedbFWshape = {2*numUnitsFW}; - std::vector expectedbBWshape = {2*numUnitsBW}; - REQUIRE_TRUE(WhFW->isSameShape(expectedWhFWshape), 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of hidden-to-hidden weights array (for forward RNN), expected is %s but got %s instead !", ShapeUtils::shapeAsString(expectedWhFWshape).c_str(), ShapeUtils::shapeAsString(WhFW).c_str()); - REQUIRE_TRUE(WhBW->isSameShape(expectedWhBWshape) , 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of hidden-to-hidden weights array (for backward RNN), expected is %s but got %s instead !", ShapeUtils::shapeAsString(expectedWhBWshape).c_str(), ShapeUtils::shapeAsString(WhBW).c_str()); - REQUIRE_TRUE(bFW->isSameShape(expectedbFWshape), 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of biases array (for forward RNN), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedbFWshape).c_str(), ShapeUtils::shapeAsString(bFW).c_str()); - REQUIRE_TRUE(bBW->isSameShape(expectedbBWshape), 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of biases array (for backward RNN), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedbBWshape).c_str(), ShapeUtils::shapeAsString(bBW).c_str()); - if(h0FW) { - std::vector expectedh0FWshape = {bS, numUnitsFW}; - REQUIRE_TRUE(h0FW->isSameShape(expectedh0FWshape), 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of initial cell output array (for forward RNN), expected is %s but got %s instead !", ShapeUtils::shapeAsString(expectedh0FWshape).c_str(), ShapeUtils::shapeAsString(h0FW).c_str()); - } - if(h0BW) { - std::vector expectedh0BWshape = {bS, numUnitsBW}; - REQUIRE_TRUE(h0BW->isSameShape(expectedh0BWshape), 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of initial cell output array (for backward RNN), expected is %s but got %s instead !", ShapeUtils::shapeAsString(expectedh0BWshape).c_str(), ShapeUtils::shapeAsString(h0BW).c_str()); - } - if(maxTimeStep) { - std::vector expectedmaxTimeStepshape = {bS}; - REQUIRE_TRUE(maxTimeStep->isSameShape(expectedmaxTimeStepshape), 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of maxTimeStep array, expected is [%i], but got %s instead !", bS, ShapeUtils::shapeAsString(maxTimeStep).c_str()); - } - - - // forward steps - sd::ops::dynamic_rnn dynamicRnn; - auto resultsFW = dynamicRnn.evaluate({x, WxFW, WhFW, bFW, h0FW, maxTimeStep}, {timeMajor}); - hFW->assign(resultsFW.at(0)); // [time x bS x numUnitsFW] or [bS x time x numUnitsFW] - hFWFinal->assign(resultsFW.at(1)); - - auto seqLen = maxTimeStep; - if(seqLen == nullptr) { - // FIXME: which datatype should be used here? - seqLen = new NDArray(x->ordering(), {bS}, sd::DataType::INT64, block.launchContext()); - seqLen->assign(time); // set each element of seqLen to be equal to time - } - - // reverse x - sd::ops::reverse_sequence reverse; - auto resultsIn = timeMajor ? reverse.evaluate({x, seqLen}, {0, 1}) : reverse.evaluate({x, seqLen}, {1, 0}); - REQUIRE_TRUE (resultsIn.status() == ND4J_STATUS_OK, 0, "dynamic_bidirectional_rnn: there is a problem with reverse on the sequence."); - auto revInput = resultsIn.at(0); - - // backward steps - auto resultsBW = dynamicRnn.evaluate({revInput, WxBW, WhBW, bBW, h0BW, maxTimeStep}, {timeMajor}); - auto hBWtemp = resultsBW.at(0); // [time x bS x numUnitsBW] or [ bS x time xnumUnitsBW] - hBWFinal->assign(resultsBW.at(1)); - - // reverse hBWtemp - auto resultsOut = timeMajor ? reverse.evaluate({hBWtemp, seqLen}, {0, 1}) : reverse.evaluate({hBWtemp, seqLen}, {1, 0}); - hBW->assign(resultsOut.at(0)); - - if(seqLen != maxTimeStep) - delete seqLen; - - return Status::OK(); + auto x = INPUT_VARIABLE(0); // input [time x bS x inSize] or [bS x time x + // inSize], shape depends on timeMajor parameter + auto WxFW = INPUT_VARIABLE( + 1); // input-to-hidden weights for forward RNN, [inSize x numUnitsFW] + auto WhFW = INPUT_VARIABLE(2); // hidden-to-hidden weights for forward RNN, + // [numUnitsFW x numUnitsFW] + auto bFW = INPUT_VARIABLE(3); // biases for forward RNN, [2*numUnitsFW] + auto WxBW = INPUT_VARIABLE( + 4); // input-to-hidden weights for backward RNN, [inSize x numUnitsBW] + auto WhBW = INPUT_VARIABLE(5); // hidden-to-hidden weights for backward RNN, + // [numUnitsBW x numUnitsBW] + auto bBW = INPUT_VARIABLE(6); // biases for backward RNN, [2*v] + + NDArray* h0FW = nullptr; // initial cell output for forward RNN (at time + // step = 0) [bS x numUnitsFW] + NDArray* h0BW = nullptr; // initial cell output for backward RNN (at time + // step = 0) [bS x numUnitsBW] + NDArray* maxTimeStep = + nullptr; // vector [bS] containing integer values within [0,time), each + // element of this vector set max time step per each input in + // batch, this means there are no calculations for time >= + // maxTimeStep + + const int timeMajor = + block.numI() > 0 + ? INT_ARG(0) + : 0; // if non zero then [time, bS, ...], else [bS, time, ...] + + switch (block.width()) { + case 8: + maxTimeStep = INPUT_VARIABLE(7); + break; + case 9: + h0FW = INPUT_VARIABLE(7); + h0BW = INPUT_VARIABLE(8); + break; + case 10: + h0FW = INPUT_VARIABLE(7); + h0BW = INPUT_VARIABLE(8); + maxTimeStep = INPUT_VARIABLE(9); + break; + } + + auto hFW = OUTPUT_VARIABLE( + 0); // cell outputs for forward RNN [time x bS x numUnitsFW] or [bS x + // time x numUnitsFW], shape depends on timeMajor parameter + auto hBW = OUTPUT_VARIABLE( + 1); // cell outputs for backward RNN [time x bS x numUnitsBW] or [bS x + // time x numUnitsBW], shape depends on timeMajor parameter + auto hFWFinal = + OUTPUT_VARIABLE(2); // final cell out for forward RNN [bS x numUnitsFW] + auto hBWFinal = + OUTPUT_VARIABLE(3); // final cell out for backward RNN [bS x numUnitsBF] + + REQUIRE_TRUE(x->rankOf() == 3, 0, + "DYNAMIC_BIDIRECTIONAL_RNN custom operation: input array must " + "have rank = 3, but got %i instead !", + x->rankOf()); + REQUIRE_TRUE( + WxFW->rankOf() == 2, 0, + "DYNAMIC_BIDIRECTIONAL_RNN custom operation: input-to-hidden weights " + "array (for forward RNN) must have rank = 2, but got %i instead !", + WxFW->rankOf()); + REQUIRE_TRUE( + WxBW->rankOf() == 2, 0, + "DYNAMIC_BIDIRECTIONAL_RNN custom operation: input-to-hidden weights " + "array (for backward RNN) must have rank = 2, but got %i instead !", + WxBW->rankOf()); + + const int inRank = x->rankOf(); + const int time = timeMajor ? x->sizeAt(0) : x->sizeAt(1); + const int bS = timeMajor ? x->sizeAt(1) : x->sizeAt(0); + const int numUnitsFW = WxFW->sizeAt(1); + const int numUnitsBW = WxBW->sizeAt(1); + + std::vector expectedWhFWshape = {numUnitsFW, numUnitsFW}; + std::vector expectedWhBWshape = {numUnitsBW, numUnitsBW}; + std::vector expectedbFWshape = {2 * numUnitsFW}; + std::vector expectedbBWshape = {2 * numUnitsBW}; + REQUIRE_TRUE(WhFW->isSameShape(expectedWhFWshape), 0, + "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of " + "hidden-to-hidden weights array (for forward RNN), expected is " + "%s but got %s instead !", + ShapeUtils::shapeAsString(expectedWhFWshape).c_str(), + ShapeUtils::shapeAsString(WhFW).c_str()); + REQUIRE_TRUE(WhBW->isSameShape(expectedWhBWshape), 0, + "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of " + "hidden-to-hidden weights array (for backward RNN), expected is " + "%s but got %s instead !", + ShapeUtils::shapeAsString(expectedWhBWshape).c_str(), + ShapeUtils::shapeAsString(WhBW).c_str()); + REQUIRE_TRUE( + bFW->isSameShape(expectedbFWshape), 0, + "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of biases array " + "(for forward RNN), expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedbFWshape).c_str(), + ShapeUtils::shapeAsString(bFW).c_str()); + REQUIRE_TRUE( + bBW->isSameShape(expectedbBWshape), 0, + "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of biases array " + "(for backward RNN), expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedbBWshape).c_str(), + ShapeUtils::shapeAsString(bBW).c_str()); + if (h0FW) { + std::vector expectedh0FWshape = {bS, numUnitsFW}; + REQUIRE_TRUE(h0FW->isSameShape(expectedh0FWshape), 0, + "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of " + "initial cell output array (for forward RNN), expected is %s " + "but got %s instead !", + ShapeUtils::shapeAsString(expectedh0FWshape).c_str(), + ShapeUtils::shapeAsString(h0FW).c_str()); + } + if (h0BW) { + std::vector expectedh0BWshape = {bS, numUnitsBW}; + REQUIRE_TRUE(h0BW->isSameShape(expectedh0BWshape), 0, + "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of " + "initial cell output array (for backward RNN), expected is %s " + "but got %s instead !", + ShapeUtils::shapeAsString(expectedh0BWshape).c_str(), + ShapeUtils::shapeAsString(h0BW).c_str()); + } + if (maxTimeStep) { + std::vector expectedmaxTimeStepshape = {bS}; + REQUIRE_TRUE(maxTimeStep->isSameShape(expectedmaxTimeStepshape), 0, + "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of " + "maxTimeStep array, expected is [%i], but got %s instead !", + bS, ShapeUtils::shapeAsString(maxTimeStep).c_str()); + } + + // forward steps + sd::ops::dynamic_rnn dynamicRnn; + auto resultsFW = + dynamicRnn.evaluate({x, WxFW, WhFW, bFW, h0FW, maxTimeStep}, {timeMajor}); + hFW->assign( + resultsFW.at(0)); // [time x bS x numUnitsFW] or [bS x time x numUnitsFW] + hFWFinal->assign(resultsFW.at(1)); + + auto seqLen = maxTimeStep; + if (seqLen == nullptr) { + // FIXME: which datatype should be used here? + seqLen = new NDArray(x->ordering(), {bS}, sd::DataType::INT64, + block.launchContext()); + seqLen->assign(time); // set each element of seqLen to be equal to time + } + + // reverse x + sd::ops::reverse_sequence reverse; + auto resultsIn = timeMajor ? reverse.evaluate({x, seqLen}, {0, 1}) + : reverse.evaluate({x, seqLen}, {1, 0}); + REQUIRE_TRUE(resultsIn.status() == ND4J_STATUS_OK, 0, + "dynamic_bidirectional_rnn: there is a problem with reverse on " + "the sequence."); + auto revInput = resultsIn.at(0); + + // backward steps + auto resultsBW = dynamicRnn.evaluate( + {&revInput, WxBW, WhBW, bBW, h0BW, maxTimeStep}, {timeMajor}); + auto hBWtemp = + resultsBW.at(0); // [time x bS x numUnitsBW] or [ bS x time xnumUnitsBW] + hBWFinal->assign(resultsBW.at(1)); + + // reverse hBWtemp + auto resultsOut = timeMajor ? reverse.evaluate({&hBWtemp, seqLen}, {0, 1}) + : reverse.evaluate({&hBWtemp, seqLen}, {1, 0}); + hBW->assign(resultsOut.at(0)); + + if (seqLen != maxTimeStep) delete seqLen; + + return Status::OK(); } - DECLARE_TYPES(dynamic_bidirectional_rnn) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } - - -DECLARE_SHAPE_FN(dynamic_bidirectional_rnn) { - - auto x = INPUT_VARIABLE(0); // input [time x bS x inSize] or [bS x time x inSize], shape depends on timeMajor parameter - auto WxFW = INPUT_VARIABLE(1); // input-to-hidden weights for forward RNN, [inSize x numUnitsFW] - auto WhFW = INPUT_VARIABLE(2); // hidden-to-hidden weights for forward RNN, [numUnitsFW x numUnitsFW] - auto bFW = INPUT_VARIABLE(3); // biases for forward RNN, [2*numUnitsFW] - auto WxBW = INPUT_VARIABLE(4); // input-to-hidden weights for backward RNN, [inSize x numUnitsBW] - auto WhBW = INPUT_VARIABLE(5); // hidden-to-hidden weights for backward RNN, [numUnitsBW x numUnitsBW] - auto bBW = INPUT_VARIABLE(6); // biases for backward RNN, [2*numUnitsBW] - - NDArray* h0FW = nullptr; // initial cell output for forward RNN (at time step = 0) [bS x numUnitsFW] - NDArray* h0BW = nullptr; // initial cell output for backward RNN (at time step = 0) [bS x numUnitsBW] - NDArray* maxTimeStep = nullptr; // vector [bS] containing integer values within [0,time), each element of this vector set max time step per each input in batch, this means there are no calculations for time >= maxTimeStep - - const int timeMajor = block.getIArguments()->size() > 0 ? INT_ARG(0) : 0; // if true then [time, bS, ...], else [bS, time, ...] - - switch(block.width()) { - case 8: - maxTimeStep = INPUT_VARIABLE(7); - break; - case 9: - h0FW = INPUT_VARIABLE(7); - h0BW = INPUT_VARIABLE(8); - break; - case 10: - h0FW = INPUT_VARIABLE(7); - h0BW = INPUT_VARIABLE(8); - maxTimeStep = INPUT_VARIABLE(9); - break; - } - - REQUIRE_TRUE(x->rankOf() == 3, 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: input array must have rank = 3, but got %i instead !", x->rankOf()); - REQUIRE_TRUE(WxFW->rankOf() == 2, 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: input-to-hidden weights array (for forward RNN) must have rank = 2, but got %i instead !", WxFW->rankOf()); - REQUIRE_TRUE(WxBW->rankOf() == 2, 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: input-to-hidden weights array (for backward RNN) must have rank = 2, but got %i instead !", WxBW->rankOf()); - - const int inRank = x->rankOf(); - const int time = timeMajor ? x->sizeAt(0) : x->sizeAt(1); - const int bS = timeMajor ? x->sizeAt(1) : x->sizeAt(0); - const int numUnitsFW = WxFW->sizeAt(1); - const int numUnitsBW = WxBW->sizeAt(1); - - - std::vector expectedWhFWshape = {numUnitsFW, numUnitsFW}; - std::vector expectedWhBWshape = {numUnitsBW, numUnitsBW}; - std::vector expectedbFWshape = {2*numUnitsFW}; - std::vector expectedbBWshape = {2*numUnitsBW}; - - REQUIRE_TRUE(WhFW->isSameShape(expectedWhFWshape), 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of hidden-to-hidden weights array (for forward RNN), expected is %s but got %s instead !", ShapeUtils::shapeAsString(expectedWhFWshape).c_str(), ShapeUtils::shapeAsString(WhFW).c_str()); - REQUIRE_TRUE(WhBW->isSameShape(expectedWhBWshape), 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of hidden-to-hidden weights array (for backward RNN), expected is %s but got %s instead !", ShapeUtils::shapeAsString(expectedWhBWshape).c_str(), ShapeUtils::shapeAsString(WhBW).c_str()); - REQUIRE_TRUE(bFW->isSameShape(expectedbFWshape), 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of biases array (for forward RNN), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedbFWshape).c_str(), ShapeUtils::shapeAsString(bFW).c_str()); - REQUIRE_TRUE(bBW->isSameShape(expectedbBWshape), 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of biases array (for backward RNN), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedbBWshape).c_str(), ShapeUtils::shapeAsString(bBW).c_str()); - if(h0FW) { - std::vector expectedh0FWshape = {bS, numUnitsFW}; - REQUIRE_TRUE(h0FW->isSameShape(expectedh0FWshape), 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of initial cell output array (for forward RNN), expected is %s but got %s instead !", ShapeUtils::shapeAsString(expectedh0FWshape).c_str(), ShapeUtils::shapeAsString(h0FW).c_str()); - } - if(h0BW) { - std::vector expectedh0BWshape = {bS, numUnitsBW}; - REQUIRE_TRUE(h0BW->isSameShape(expectedh0BWshape), 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of initial cell output array (for backward RNN), expected is %s but got %s instead !", ShapeUtils::shapeAsString(expectedh0BWshape).c_str(), ShapeUtils::shapeAsString(h0BW).c_str()); - } - if(maxTimeStep) { - std::vector expectedmaxTimeStepshape = {bS}; - REQUIRE_TRUE(maxTimeStep->isSameShape(expectedmaxTimeStepshape), 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of maxTimeStep array, expected is [%i], but got %s instead !", bS, ShapeUtils::shapeAsString(maxTimeStep).c_str()); - } - - // evaluate output shapeInfos - Nd4jLong *hFWShapeInfo(nullptr), *hBWShapeInfo(nullptr), *hFWFinalPrevShapeInfo(nullptr), *hBWFinalPrevShapeInfo(nullptr); - ALLOCATE(hFWShapeInfo, block.getWorkspace(), shape::shapeInfoLength(inRank), Nd4jLong); - ALLOCATE(hBWShapeInfo, block.getWorkspace(), shape::shapeInfoLength(inRank), Nd4jLong); - ALLOCATE(hFWFinalPrevShapeInfo, block.getWorkspace(), shape::shapeInfoLength(inRank-1), Nd4jLong); - ALLOCATE(hBWFinalPrevShapeInfo, block.getWorkspace(), shape::shapeInfoLength(inRank-1), Nd4jLong); - - hFWShapeInfo[0] = hBWShapeInfo[0] = inRank; - hFWShapeInfo[1] = hBWShapeInfo[1] = timeMajor ? time : bS; - hFWShapeInfo[2] = hBWShapeInfo[2] = timeMajor ? bS : time; - hFWShapeInfo[3] = numUnitsFW; - hBWShapeInfo[3] = numUnitsBW; - hFWFinalPrevShapeInfo[0] = hBWFinalPrevShapeInfo[0] = inRank-1; - hFWFinalPrevShapeInfo[1] = hBWFinalPrevShapeInfo[1] = bS; - hFWFinalPrevShapeInfo[2] = numUnitsFW; - hBWFinalPrevShapeInfo[2] = numUnitsBW; - - ShapeUtils::updateStridesAndType(hFWShapeInfo, x->shapeInfo(), x->ordering()); - ShapeUtils::updateStridesAndType(hBWShapeInfo, x->shapeInfo(), x->ordering()); - ShapeUtils::updateStridesAndType(hFWFinalPrevShapeInfo, x->shapeInfo(), x->ordering()); - ShapeUtils::updateStridesAndType(hBWFinalPrevShapeInfo, x->shapeInfo(), x->ordering()); - - return SHAPELIST(CONSTANT(hFWShapeInfo), CONSTANT(hBWShapeInfo), CONSTANT(hFWFinalPrevShapeInfo), CONSTANT(hBWFinalPrevShapeInfo)); +DECLARE_TYPES(dynamic_bidirectional_rnn) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } - -} +DECLARE_SHAPE_FN(dynamic_bidirectional_rnn) { + auto x = INPUT_VARIABLE(0); // input [time x bS x inSize] or [bS x time x + // inSize], shape depends on timeMajor parameter + auto WxFW = INPUT_VARIABLE( + 1); // input-to-hidden weights for forward RNN, [inSize x numUnitsFW] + auto WhFW = INPUT_VARIABLE(2); // hidden-to-hidden weights for forward RNN, + // [numUnitsFW x numUnitsFW] + auto bFW = INPUT_VARIABLE(3); // biases for forward RNN, [2*numUnitsFW] + auto WxBW = INPUT_VARIABLE( + 4); // input-to-hidden weights for backward RNN, [inSize x numUnitsBW] + auto WhBW = INPUT_VARIABLE(5); // hidden-to-hidden weights for backward RNN, + // [numUnitsBW x numUnitsBW] + auto bBW = INPUT_VARIABLE(6); // biases for backward RNN, [2*numUnitsBW] + + NDArray* h0FW = nullptr; // initial cell output for forward RNN (at time + // step = 0) [bS x numUnitsFW] + NDArray* h0BW = nullptr; // initial cell output for backward RNN (at time + // step = 0) [bS x numUnitsBW] + NDArray* maxTimeStep = + nullptr; // vector [bS] containing integer values within [0,time), each + // element of this vector set max time step per each input in + // batch, this means there are no calculations for time >= + // maxTimeStep + + const int timeMajor = + block.numI() > 0 + ? INT_ARG(0) + : 0; // if true then [time, bS, ...], else [bS, time, ...] + + switch (block.width()) { + case 8: + maxTimeStep = INPUT_VARIABLE(7); + break; + case 9: + h0FW = INPUT_VARIABLE(7); + h0BW = INPUT_VARIABLE(8); + break; + case 10: + h0FW = INPUT_VARIABLE(7); + h0BW = INPUT_VARIABLE(8); + maxTimeStep = INPUT_VARIABLE(9); + break; + } + + REQUIRE_TRUE(x->rankOf() == 3, 0, + "DYNAMIC_BIDIRECTIONAL_RNN custom operation: input array must " + "have rank = 3, but got %i instead !", + x->rankOf()); + REQUIRE_TRUE( + WxFW->rankOf() == 2, 0, + "DYNAMIC_BIDIRECTIONAL_RNN custom operation: input-to-hidden weights " + "array (for forward RNN) must have rank = 2, but got %i instead !", + WxFW->rankOf()); + REQUIRE_TRUE( + WxBW->rankOf() == 2, 0, + "DYNAMIC_BIDIRECTIONAL_RNN custom operation: input-to-hidden weights " + "array (for backward RNN) must have rank = 2, but got %i instead !", + WxBW->rankOf()); + + const int inRank = x->rankOf(); + const int time = timeMajor ? x->sizeAt(0) : x->sizeAt(1); + const int bS = timeMajor ? x->sizeAt(1) : x->sizeAt(0); + const int numUnitsFW = WxFW->sizeAt(1); + const int numUnitsBW = WxBW->sizeAt(1); + + std::vector expectedWhFWshape = {numUnitsFW, numUnitsFW}; + std::vector expectedWhBWshape = {numUnitsBW, numUnitsBW}; + std::vector expectedbFWshape = {2 * numUnitsFW}; + std::vector expectedbBWshape = {2 * numUnitsBW}; + + REQUIRE_TRUE(WhFW->isSameShape(expectedWhFWshape), 0, + "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of " + "hidden-to-hidden weights array (for forward RNN), expected is " + "%s but got %s instead !", + ShapeUtils::shapeAsString(expectedWhFWshape).c_str(), + ShapeUtils::shapeAsString(WhFW).c_str()); + REQUIRE_TRUE(WhBW->isSameShape(expectedWhBWshape), 0, + "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of " + "hidden-to-hidden weights array (for backward RNN), expected is " + "%s but got %s instead !", + ShapeUtils::shapeAsString(expectedWhBWshape).c_str(), + ShapeUtils::shapeAsString(WhBW).c_str()); + REQUIRE_TRUE( + bFW->isSameShape(expectedbFWshape), 0, + "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of biases array " + "(for forward RNN), expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedbFWshape).c_str(), + ShapeUtils::shapeAsString(bFW).c_str()); + REQUIRE_TRUE( + bBW->isSameShape(expectedbBWshape), 0, + "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of biases array " + "(for backward RNN), expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedbBWshape).c_str(), + ShapeUtils::shapeAsString(bBW).c_str()); + if (h0FW) { + std::vector expectedh0FWshape = {bS, numUnitsFW}; + REQUIRE_TRUE(h0FW->isSameShape(expectedh0FWshape), 0, + "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of " + "initial cell output array (for forward RNN), expected is %s " + "but got %s instead !", + ShapeUtils::shapeAsString(expectedh0FWshape).c_str(), + ShapeUtils::shapeAsString(h0FW).c_str()); + } + if (h0BW) { + std::vector expectedh0BWshape = {bS, numUnitsBW}; + REQUIRE_TRUE(h0BW->isSameShape(expectedh0BWshape), 0, + "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of " + "initial cell output array (for backward RNN), expected is %s " + "but got %s instead !", + ShapeUtils::shapeAsString(expectedh0BWshape).c_str(), + ShapeUtils::shapeAsString(h0BW).c_str()); + } + if (maxTimeStep) { + std::vector expectedmaxTimeStepshape = {bS}; + REQUIRE_TRUE(maxTimeStep->isSameShape(expectedmaxTimeStepshape), 0, + "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of " + "maxTimeStep array, expected is [%i], but got %s instead !", + bS, ShapeUtils::shapeAsString(maxTimeStep).c_str()); + } + + // evaluate output shapeInfos + Nd4jLong *hFWShapeInfo(nullptr), *hBWShapeInfo(nullptr), + *hFWFinalPrevShapeInfo(nullptr), *hBWFinalPrevShapeInfo(nullptr); + ALLOCATE(hFWShapeInfo, block.workspace(), shape::shapeInfoLength(inRank), + Nd4jLong); + ALLOCATE(hBWShapeInfo, block.workspace(), shape::shapeInfoLength(inRank), + Nd4jLong); + ALLOCATE(hFWFinalPrevShapeInfo, block.workspace(), + shape::shapeInfoLength(inRank - 1), Nd4jLong); + ALLOCATE(hBWFinalPrevShapeInfo, block.workspace(), + shape::shapeInfoLength(inRank - 1), Nd4jLong); + + hFWShapeInfo[0] = hBWShapeInfo[0] = inRank; + hFWShapeInfo[1] = hBWShapeInfo[1] = timeMajor ? time : bS; + hFWShapeInfo[2] = hBWShapeInfo[2] = timeMajor ? bS : time; + hFWShapeInfo[3] = numUnitsFW; + hBWShapeInfo[3] = numUnitsBW; + hFWFinalPrevShapeInfo[0] = hBWFinalPrevShapeInfo[0] = inRank - 1; + hFWFinalPrevShapeInfo[1] = hBWFinalPrevShapeInfo[1] = bS; + hFWFinalPrevShapeInfo[2] = numUnitsFW; + hBWFinalPrevShapeInfo[2] = numUnitsBW; + + ShapeUtils::updateStridesAndType(hFWShapeInfo, x->shapeInfo(), x->ordering()); + ShapeUtils::updateStridesAndType(hBWShapeInfo, x->shapeInfo(), x->ordering()); + ShapeUtils::updateStridesAndType(hFWFinalPrevShapeInfo, x->shapeInfo(), + x->ordering()); + ShapeUtils::updateStridesAndType(hBWFinalPrevShapeInfo, x->shapeInfo(), + x->ordering()); + + return SHAPELIST(CONSTANT(hFWShapeInfo), CONSTANT(hBWShapeInfo), + CONSTANT(hFWFinalPrevShapeInfo), + CONSTANT(hBWFinalPrevShapeInfo)); } +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/generic/nn/recurrent/dynamicRNN.cpp b/libnd4j/include/ops/declarable/generic/nn/recurrent/dynamicRNN.cpp index 9836d65cedb6..0718267534fa 100644 --- a/libnd4j/include/ops/declarable/generic/nn/recurrent/dynamicRNN.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/recurrent/dynamicRNN.cpp @@ -19,155 +19,222 @@ // #include -#include +#include namespace sd { -namespace ops { - +namespace ops { ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(dynamic_rnn, 4, 2, false, 0, 0) { - - auto x = INPUT_VARIABLE(0); // input [time x bS x inSize] or [bS x time x inSize], depends on timeMajor parameter - auto Wx = INPUT_VARIABLE(1); // input-to-hidden weights, [inSize x numUnits] - auto Wh = INPUT_VARIABLE(2); // hidden-to-hidden weights, [numUnits x numUnits] - auto b = INPUT_VARIABLE(3); // biases for, [2*numUnits] - - NDArray* h0 = nullptr; // initial cell output (at time step = 0) [bS x numUnits] - NDArray* maxTimeStep = nullptr; // vector [bS] containing integer values within [0,time), each element of this vector set max time step per each input in batch, this means there are no calculations for time >= maxTimeStep - - const int timeMajor = block.getIArguments()->size() > 0 ? INT_ARG(0) : 0; // if true then [time, bS, ...], else [bS, time, ...] - - if(block.width() == 5) { - if ((*INPUT_VARIABLE(4)).rankOf() == 2) - h0 = INPUT_VARIABLE(4); - else - maxTimeStep = INPUT_VARIABLE(4); - } - else if(block.width() == 6) { - h0 = INPUT_VARIABLE(4); - maxTimeStep = INPUT_VARIABLE(5); - } - - auto h = OUTPUT_VARIABLE(0); // cell outputs [time x bS x numUnits] or [bS x time x numUnits], depends on timeMajor parameter - auto hFinal = OUTPUT_VARIABLE(1); // at the end it will store cell final non-zero output [bS x numUnits] - - REQUIRE_TRUE(x->rankOf() == 3, 0, "DYNAMIC_RNN custom operation: input array x must have rank = 3, but got %i instead !", x->rankOf()); - REQUIRE_TRUE(Wx->rankOf() == 2, 0, "DYNAMIC_RNN custom operation: input-to-hidden weights array must have rank = 2, but got %i instead !", Wx->rankOf()); - - const int inRank = x->rankOf(); - const int time = timeMajor ? x->sizeAt(0) : x->sizeAt(1); - const int bS = timeMajor ? x->sizeAt(1) : x->sizeAt(0); - const int numUnits = Wx->sizeAt(1); - - std::vector expectedWhShape = {numUnits, numUnits}; - std::vector expectedBShape = {2*numUnits}; - REQUIRE_TRUE(Wh->isSameShape(expectedWhShape), 0, "DYNAMIC_RNN custom operation: wrong shape of hidden-to-hidden weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWhShape).c_str(), ShapeUtils::shapeAsString(Wh).c_str()); - REQUIRE_TRUE(b->isSameShape(expectedBShape), 0, "DYNAMIC_RNN custom operation: wrong shape of biases array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedBShape).c_str(), ShapeUtils::shapeAsString(b).c_str()); - if(h0) { - std::vector expectedh0Shape = {bS, numUnits}; - REQUIRE_TRUE(h0->isSameShape(expectedh0Shape), 0, "DYNAMIC_RNN custom operation: wrong shape of initial cell output array, expected is %s but got %s instead !", ShapeUtils::shapeAsString(expectedh0Shape).c_str(), ShapeUtils::shapeAsString(h0).c_str()); - } - if(maxTimeStep) { - std::vector expectedmaxTimeStepShape = {bS}; - REQUIRE_TRUE(maxTimeStep->isSameShape(expectedmaxTimeStepShape), 0, "DYNAMIC_RNN custom operation: wrong shape of maxTimeStep array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedmaxTimeStepShape).c_str(), ShapeUtils::shapeAsString(maxTimeStep).c_str()); - } - - if(timeMajor == false) { - x = new NDArray(x->permute({1, 0, 2})); // [bS x time x inSize] -> [time x bS x inSize] - h = new NDArray(h->permute({1, 0, 2})); // [bS x time x numUnits] -> [time x bS x numUnits] - } - - helpers::rnnTimeLoop(block.launchContext(), x, Wx, Wh, b, h0, maxTimeStep, h, hFinal); - - if(timeMajor == false) { - delete x; - delete h; - } - - return Status::OK(); + auto x = INPUT_VARIABLE(0); // input [time x bS x inSize] or [bS x time x + // inSize], depends on timeMajor parameter + auto Wx = + INPUT_VARIABLE(1); // input-to-hidden weights, [inSize x numUnits] + auto Wh = + INPUT_VARIABLE(2); // hidden-to-hidden weights, [numUnits x numUnits] + auto b = INPUT_VARIABLE(3); // biases for, [2*numUnits] + + NDArray* h0 = + nullptr; // initial cell output (at time step = 0) [bS x numUnits] + NDArray* maxTimeStep = + nullptr; // vector [bS] containing integer values within [0,time), each + // element of this vector set max time step per each input in + // batch, this means there are no calculations for time >= + // maxTimeStep + + const int timeMajor = + block.numI() > 0 + ? INT_ARG(0) + : 0; // if true then [time, bS, ...], else [bS, time, ...] + + if (block.width() == 5) { + if ((*INPUT_VARIABLE(4)).rankOf() == 2) + h0 = INPUT_VARIABLE(4); + else + maxTimeStep = INPUT_VARIABLE(4); + } else if (block.width() == 6) { + h0 = INPUT_VARIABLE(4); + maxTimeStep = INPUT_VARIABLE(5); + } + + auto h = + OUTPUT_VARIABLE(0); // cell outputs [time x bS x numUnits] or [bS x time + // x numUnits], depends on timeMajor parameter + auto hFinal = OUTPUT_VARIABLE(1); // at the end it will store cell final + // non-zero output [bS x numUnits] + + REQUIRE_TRUE(x->rankOf() == 3, 0, + "DYNAMIC_RNN custom operation: input array x must have rank = " + "3, but got %i instead !", + x->rankOf()); + REQUIRE_TRUE(Wx->rankOf() == 2, 0, + "DYNAMIC_RNN custom operation: input-to-hidden weights array " + "must have rank = 2, but got %i instead !", + Wx->rankOf()); + + const int inRank = x->rankOf(); + const int time = timeMajor ? x->sizeAt(0) : x->sizeAt(1); + const int bS = timeMajor ? x->sizeAt(1) : x->sizeAt(0); + const int numUnits = Wx->sizeAt(1); + + std::vector expectedWhShape = {numUnits, numUnits}; + std::vector expectedBShape = {2 * numUnits}; + REQUIRE_TRUE(Wh->isSameShape(expectedWhShape), 0, + "DYNAMIC_RNN custom operation: wrong shape of hidden-to-hidden " + "weights array, expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedWhShape).c_str(), + ShapeUtils::shapeAsString(Wh).c_str()); + REQUIRE_TRUE(b->isSameShape(expectedBShape), 0, + "DYNAMIC_RNN custom operation: wrong shape of biases array, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedBShape).c_str(), + ShapeUtils::shapeAsString(b).c_str()); + if (h0) { + std::vector expectedh0Shape = {bS, numUnits}; + REQUIRE_TRUE(h0->isSameShape(expectedh0Shape), 0, + "DYNAMIC_RNN custom operation: wrong shape of initial cell " + "output array, expected is %s but got %s instead !", + ShapeUtils::shapeAsString(expectedh0Shape).c_str(), + ShapeUtils::shapeAsString(h0).c_str()); + } + if (maxTimeStep) { + std::vector expectedmaxTimeStepShape = {bS}; + REQUIRE_TRUE(maxTimeStep->isSameShape(expectedmaxTimeStepShape), 0, + "DYNAMIC_RNN custom operation: wrong shape of maxTimeStep " + "array, expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedmaxTimeStepShape).c_str(), + ShapeUtils::shapeAsString(maxTimeStep).c_str()); + } + + if (timeMajor == false) { + x = new NDArray(x->permute( + {1, 0, 2})); // [bS x time x inSize] -> [time x bS x inSize] + h = new NDArray(h->permute( + {1, 0, 2})); // [bS x time x numUnits] -> [time x bS x numUnits] + } + + helpers::rnnTimeLoop(block.launchContext(), x, Wx, Wh, b, h0, maxTimeStep, h, + hFinal); + + if (timeMajor == false) { + delete x; + delete h; + } + + return Status::OK(); } - - DECLARE_TYPES(dynamic_rnn) { - getOpDescriptor() - ->setAllowedInputTypes(0, sd::DataType::ANY) - ->setAllowedInputTypes(1, {ALL_FLOATS}) - ->setAllowedInputTypes(2, {ALL_FLOATS}) - ->setAllowedInputTypes(3, {ALL_FLOATS}) - ->setAllowedInputTypes(4, {ALL_FLOATS, ALL_INTS}) - ->setAllowedInputTypes(5, {ALL_FLOATS, ALL_INTS}) - ->setAllowedOutputTypes(0, {ALL_FLOATS}) - ->setAllowedOutputTypes(1, {ALL_FLOATS}); - } - +DECLARE_TYPES(dynamic_rnn) { + getOpDescriptor() + ->setAllowedInputTypes(0, sd::DataType::ANY) + ->setAllowedInputTypes(1, {ALL_FLOATS}) + ->setAllowedInputTypes(2, {ALL_FLOATS}) + ->setAllowedInputTypes(3, {ALL_FLOATS}) + ->setAllowedInputTypes(4, {ALL_FLOATS, ALL_INTS}) + ->setAllowedInputTypes(5, {ALL_FLOATS, ALL_INTS}) + ->setAllowedOutputTypes(0, {ALL_FLOATS}) + ->setAllowedOutputTypes(1, {ALL_FLOATS}); +} DECLARE_SHAPE_FN(dynamic_rnn) { - - auto xShapeInfo = inputShape->at(0); // input [time x bS x inSize] or [bS x time x inSize], depends on timeMajor parameter - auto WxShapeInfo = inputShape->at(1); // input-to-hidden weights, [inSize x numUnits] - auto WhShapeInfo = inputShape->at(2); // hidden-to-hidden weights, [numUnits x numUnits] - auto bShapeInfo = inputShape->at(3); // biases for, [2*numUnits] - - Nd4jLong const* h0ShapeInfo = nullptr; // initial cell output (at time step = 0) [bS x numUnits] - Nd4jLong const* maxTimeStepShapeInfo = nullptr; // vector [bS] containing integer values within [0,time), each element of this vector set max time step per each input in batch, this means there are no calculations for time >= maxTimeStep - - const int timeMajor = block.getIArguments()->size() > 0 ? INT_ARG(0) : 0; // if true then [time, bS, ...], else [bS, time, ...] - - if(block.width() == 5) { - if (inputShape->at(4)[0] == 2) - h0ShapeInfo = inputShape->at(4); - else - maxTimeStepShapeInfo = inputShape->at(4); - } - else if(block.width() == 6) { - h0ShapeInfo = inputShape->at(4); - maxTimeStepShapeInfo = inputShape->at(5); - } - - REQUIRE_TRUE(xShapeInfo[0] == 3, 0, "DYNAMIC_RNN custom operation: input array x must have rank = 3, but got %i instead !", xShapeInfo[0]); - REQUIRE_TRUE(WxShapeInfo[0] == 2, 0, "DYNAMIC_RNN custom operation: input-to-hidden weights array must have rank = 2, but got %i instead !", WxShapeInfo[0]); - - const int inRank = xShapeInfo[0]; - const int time = timeMajor ? xShapeInfo[1] : xShapeInfo[2]; - const int bS = timeMajor ? xShapeInfo[2] : xShapeInfo[1]; - const int numUnits = WxShapeInfo[2]; - - - std::vector expectedWhShape = {numUnits, numUnits}; - std::vector expectedBShape = {2*numUnits}; - REQUIRE_TRUE(ShapeUtils::areShapesEqual(WhShapeInfo, expectedWhShape), 0, "DYNAMIC_RNN custom operation: wrong shape of hidden-to-hidden weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWhShape).c_str(), ShapeUtils::shapeAsString(WhShapeInfo).c_str()); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(bShapeInfo, expectedBShape), 0, "DYNAMIC_RNN custom operation: wrong shape of biases array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedBShape).c_str(), ShapeUtils::shapeAsString(bShapeInfo).c_str()); - if(h0ShapeInfo) { - std::vector expectedh0Shape = {bS, numUnits}; - REQUIRE_TRUE(ShapeUtils::areShapesEqual(h0ShapeInfo, expectedh0Shape), 0, "DYNAMIC_RNN custom operation: wrong shape of initial cell output array, expected is %s but got %s instead !", ShapeUtils::shapeAsString(expectedh0Shape).c_str(), ShapeUtils::shapeAsString(h0ShapeInfo).c_str()); - } - if(maxTimeStepShapeInfo) { - std::vector expectedmaxTimeStepShape = {bS}; - REQUIRE_TRUE(ShapeUtils::areShapesEqual(maxTimeStepShapeInfo, expectedmaxTimeStepShape), 0, "DYNAMIC_RNN custom operation: wrong shape of maxTimeStep array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedmaxTimeStepShape).c_str(), ShapeUtils::shapeAsString(maxTimeStepShapeInfo).c_str()); - } - - // evaluate output shapeInfos - Nd4jLong *hShapeInfo(nullptr), *hPrevShapeInfo(nullptr); - ALLOCATE(hShapeInfo, block.getWorkspace(), shape::shapeInfoLength(inRank), Nd4jLong); - ALLOCATE(hPrevShapeInfo, block.getWorkspace(), shape::shapeInfoLength(inRank-1), Nd4jLong); - - hShapeInfo[0] = inRank; - hPrevShapeInfo[0] = inRank-1; - hShapeInfo[1] = timeMajor ? time : bS; - hShapeInfo[2] = timeMajor ? bS : time; - hPrevShapeInfo[1] = bS; - hShapeInfo[3] = hPrevShapeInfo[2] = numUnits; - - ShapeUtils::updateStridesAndType(hShapeInfo, WhShapeInfo, shape::order(xShapeInfo)); - ShapeUtils::updateStridesAndType(hPrevShapeInfo, WhShapeInfo, shape::order(xShapeInfo)); - - return SHAPELIST(CONSTANT(hShapeInfo), CONSTANT(hPrevShapeInfo)); + auto xShapeInfo = + inputShape->at(0); // input [time x bS x inSize] or [bS x time x inSize], + // depends on timeMajor parameter + auto WxShapeInfo = + inputShape->at(1); // input-to-hidden weights, [inSize x numUnits] + auto WhShapeInfo = + inputShape->at(2); // hidden-to-hidden weights, [numUnits x numUnits] + auto bShapeInfo = inputShape->at(3); // biases for, [2*numUnits] + + Nd4jLong const* h0ShapeInfo = + nullptr; // initial cell output (at time step = 0) [bS x numUnits] + Nd4jLong const* maxTimeStepShapeInfo = + nullptr; // vector [bS] containing integer values within [0,time), each + // element of this vector set max time step per each input in + // batch, this means there are no calculations for time >= + // maxTimeStep + + const int timeMajor = + block.numI() > 0 + ? INT_ARG(0) + : 0; // if true then [time, bS, ...], else [bS, time, ...] + + if (block.width() == 5) { + if (inputShape->at(4)[0] == 2) + h0ShapeInfo = inputShape->at(4); + else + maxTimeStepShapeInfo = inputShape->at(4); + } else if (block.width() == 6) { + h0ShapeInfo = inputShape->at(4); + maxTimeStepShapeInfo = inputShape->at(5); + } + + REQUIRE_TRUE(xShapeInfo[0] == 3, 0, + "DYNAMIC_RNN custom operation: input array x must have rank = " + "3, but got %i instead !", + xShapeInfo[0]); + REQUIRE_TRUE(WxShapeInfo[0] == 2, 0, + "DYNAMIC_RNN custom operation: input-to-hidden weights array " + "must have rank = 2, but got %i instead !", + WxShapeInfo[0]); + + const int inRank = xShapeInfo[0]; + const int time = timeMajor ? xShapeInfo[1] : xShapeInfo[2]; + const int bS = timeMajor ? xShapeInfo[2] : xShapeInfo[1]; + const int numUnits = WxShapeInfo[2]; + + std::vector expectedWhShape = {numUnits, numUnits}; + std::vector expectedBShape = {2 * numUnits}; + REQUIRE_TRUE(ShapeUtils::areShapesEqual(WhShapeInfo, expectedWhShape), 0, + "DYNAMIC_RNN custom operation: wrong shape of hidden-to-hidden " + "weights array, expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedWhShape).c_str(), + ShapeUtils::shapeAsString(WhShapeInfo).c_str()); + REQUIRE_TRUE(ShapeUtils::areShapesEqual(bShapeInfo, expectedBShape), 0, + "DYNAMIC_RNN custom operation: wrong shape of biases array, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedBShape).c_str(), + ShapeUtils::shapeAsString(bShapeInfo).c_str()); + if (h0ShapeInfo) { + std::vector expectedh0Shape = {bS, numUnits}; + REQUIRE_TRUE(ShapeUtils::areShapesEqual(h0ShapeInfo, expectedh0Shape), 0, + "DYNAMIC_RNN custom operation: wrong shape of initial cell " + "output array, expected is %s but got %s instead !", + ShapeUtils::shapeAsString(expectedh0Shape).c_str(), + ShapeUtils::shapeAsString(h0ShapeInfo).c_str()); + } + if (maxTimeStepShapeInfo) { + std::vector expectedmaxTimeStepShape = {bS}; + REQUIRE_TRUE(ShapeUtils::areShapesEqual(maxTimeStepShapeInfo, + expectedmaxTimeStepShape), + 0, + "DYNAMIC_RNN custom operation: wrong shape of maxTimeStep " + "array, expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedmaxTimeStepShape).c_str(), + ShapeUtils::shapeAsString(maxTimeStepShapeInfo).c_str()); + } + + // evaluate output shapeInfos + Nd4jLong *hShapeInfo(nullptr), *hPrevShapeInfo(nullptr); + ALLOCATE(hShapeInfo, block.workspace(), shape::shapeInfoLength(inRank), + Nd4jLong); + ALLOCATE(hPrevShapeInfo, block.workspace(), + shape::shapeInfoLength(inRank - 1), Nd4jLong); + + hShapeInfo[0] = inRank; + hPrevShapeInfo[0] = inRank - 1; + hShapeInfo[1] = timeMajor ? time : bS; + hShapeInfo[2] = timeMajor ? bS : time; + hPrevShapeInfo[1] = bS; + hShapeInfo[3] = hPrevShapeInfo[2] = numUnits; + + ShapeUtils::updateStridesAndType(hShapeInfo, WhShapeInfo, + shape::order(xShapeInfo)); + ShapeUtils::updateStridesAndType(hPrevShapeInfo, WhShapeInfo, + shape::order(xShapeInfo)); + + return SHAPELIST(CONSTANT(hShapeInfo), CONSTANT(hPrevShapeInfo)); } - - - - -} -} +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/generic/nn/recurrent/gru.cpp b/libnd4j/include/ops/declarable/generic/nn/recurrent/gru.cpp index 0be3c839353a..7331513f994a 100644 --- a/libnd4j/include/ops/declarable/generic/nn/recurrent/gru.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/recurrent/gru.cpp @@ -23,167 +23,249 @@ #if NOT_EXCLUDED(OP_gru) #include -#include +#include namespace sd { namespace ops { - ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(gru, 5, 1, false, 0, 0) { - auto x = INPUT_VARIABLE(0); // input [time, bS, nIn] - auto hI = INPUT_VARIABLE(1); // initial cell output (at time step = 0) [bS, nOut] - - auto Wx = INPUT_VARIABLE(2); // input-to-hidden weights, [nIn, 3*nOut] - auto Wh = INPUT_VARIABLE(3); // hidden-to-hidden weights, [nOut, 3*nOut] - auto b = INPUT_VARIABLE(4); // biases, [3*nOut] - - auto h = OUTPUT_VARIABLE(0); // cell outputs [time, bS, nOut], that is per each time step - - const int bS = x->sizeAt(1); - const int nIn = x->sizeAt(2); - const int nOut = hI->sizeAt(1); - - const std::vector h0CorrectShape = {bS, nOut}; - const std::vector wxCorrectShape = {nIn, 3*nOut}; - const std::vector whCorrectShape = {nOut, 3*nOut}; - const std::vector bCorrectShape = {3*nOut}; - - REQUIRE_TRUE(hI->isSameShape(h0CorrectShape), 0, "GRU operation: wrong shape of previous cell output array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(h0CorrectShape).c_str(), ShapeUtils::shapeAsString(hI).c_str()); - REQUIRE_TRUE(Wx->isSameShape(wxCorrectShape), 0, "GRU operation: wrong shape of input-to-hidden weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(wxCorrectShape).c_str(), ShapeUtils::shapeAsString(Wx).c_str()); - REQUIRE_TRUE(Wh->isSameShape(whCorrectShape), 0, "GRU operation: wrong shape of hidden-to-hidden weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(whCorrectShape).c_str(), ShapeUtils::shapeAsString(Wh).c_str()); - REQUIRE_TRUE(b->isSameShape(bCorrectShape), 0, "GRU operation: wrong shape of biases array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(bCorrectShape).c_str(), ShapeUtils::shapeAsString(b).c_str()); - - helpers::gruTimeLoop(block.launchContext(), x, hI, Wx, Wh, b, h); - - return Status::OK(); + auto x = INPUT_VARIABLE(0); // input [time, bS, nIn] + auto hI = + INPUT_VARIABLE(1); // initial cell output (at time step = 0) [bS, nOut] + + auto Wx = INPUT_VARIABLE(2); // input-to-hidden weights, [nIn, 3*nOut] + auto Wh = INPUT_VARIABLE(3); // hidden-to-hidden weights, [nOut, 3*nOut] + auto b = INPUT_VARIABLE(4); // biases, [3*nOut] + + auto h = OUTPUT_VARIABLE( + 0); // cell outputs [time, bS, nOut], that is per each time step + + const int bS = x->sizeAt(1); + const int nIn = x->sizeAt(2); + const int nOut = hI->sizeAt(1); + + const std::vector h0CorrectShape = {bS, nOut}; + const std::vector wxCorrectShape = {nIn, 3 * nOut}; + const std::vector whCorrectShape = {nOut, 3 * nOut}; + const std::vector bCorrectShape = {3 * nOut}; + + REQUIRE_TRUE(hI->isSameShape(h0CorrectShape), 0, + "GRU operation: wrong shape of previous cell output array, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(h0CorrectShape).c_str(), + ShapeUtils::shapeAsString(hI).c_str()); + REQUIRE_TRUE(Wx->isSameShape(wxCorrectShape), 0, + "GRU operation: wrong shape of input-to-hidden weights array, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(wxCorrectShape).c_str(), + ShapeUtils::shapeAsString(Wx).c_str()); + REQUIRE_TRUE(Wh->isSameShape(whCorrectShape), 0, + "GRU operation: wrong shape of hidden-to-hidden weights array, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(whCorrectShape).c_str(), + ShapeUtils::shapeAsString(Wh).c_str()); + REQUIRE_TRUE(b->isSameShape(bCorrectShape), 0, + "GRU operation: wrong shape of biases array, expected is %s, " + "but got %s instead !", + ShapeUtils::shapeAsString(bCorrectShape).c_str(), + ShapeUtils::shapeAsString(b).c_str()); + + helpers::gruTimeLoop(block.launchContext(), x, hI, Wx, Wh, b, h); + + return Status::OK(); } ////////////////////////////////////////////////////////////////////////// DECLARE_TYPES(gru) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } ////////////////////////////////////////////////////////////////////////// DECLARE_SHAPE_FN(gru) { - - auto x = INPUT_VARIABLE(0); // input [time, bS, nIn] - auto hI = INPUT_VARIABLE(1); // initial cell output (at time step = 0) [bS, nOut] - - auto Wx = INPUT_VARIABLE(2); // input-to-hidden weights, [nIn, 3*nOut] - auto Wh = INPUT_VARIABLE(3); // hidden-to-hidden weights, [nOut, 3*nOut] - auto b = INPUT_VARIABLE(4); // biases, [3*nOut] - - const int time = x->sizeAt(0); - const int bS = x->sizeAt(1); - const int nIn = x->sizeAt(2); - const int nOut = hI->sizeAt(1); - - const std::vector h0CorrectShape = {bS, nOut}; - const std::vector wxCorrectShape = {nIn, 3*nOut}; - const std::vector whCorrectShape = {nOut, 3*nOut}; - const std::vector bCorrectShape = {3*nOut}; - - REQUIRE_TRUE(hI->isSameShape(h0CorrectShape), 0, "GRU operation: wrong shape of previous cell output array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(h0CorrectShape).c_str(), ShapeUtils::shapeAsString(hI).c_str()); - REQUIRE_TRUE(Wx->isSameShape(wxCorrectShape), 0, "GRU operation: wrong shape of input-to-hidden weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(wxCorrectShape).c_str(), ShapeUtils::shapeAsString(Wx).c_str()); - REQUIRE_TRUE(Wh->isSameShape(whCorrectShape), 0, "GRU operation: wrong shape of hidden-to-hidden weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(whCorrectShape).c_str(), ShapeUtils::shapeAsString(Wh).c_str()); - REQUIRE_TRUE(b->isSameShape(bCorrectShape), 0, "GRU operation: wrong shape of biases array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(bCorrectShape).c_str(), ShapeUtils::shapeAsString(b).c_str()); - - auto hShapeInfo = ConstantShapeHelper::getInstance().createShapeInfo(hI->dataType(), hI->ordering(), {time, bS, nOut}); - - return SHAPELIST(hShapeInfo); + auto x = INPUT_VARIABLE(0); // input [time, bS, nIn] + auto hI = + INPUT_VARIABLE(1); // initial cell output (at time step = 0) [bS, nOut] + + auto Wx = INPUT_VARIABLE(2); // input-to-hidden weights, [nIn, 3*nOut] + auto Wh = INPUT_VARIABLE(3); // hidden-to-hidden weights, [nOut, 3*nOut] + auto b = INPUT_VARIABLE(4); // biases, [3*nOut] + + const int time = x->sizeAt(0); + const int bS = x->sizeAt(1); + const int nIn = x->sizeAt(2); + const int nOut = hI->sizeAt(1); + + const std::vector h0CorrectShape = {bS, nOut}; + const std::vector wxCorrectShape = {nIn, 3 * nOut}; + const std::vector whCorrectShape = {nOut, 3 * nOut}; + const std::vector bCorrectShape = {3 * nOut}; + + REQUIRE_TRUE(hI->isSameShape(h0CorrectShape), 0, + "GRU operation: wrong shape of previous cell output array, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(h0CorrectShape).c_str(), + ShapeUtils::shapeAsString(hI).c_str()); + REQUIRE_TRUE(Wx->isSameShape(wxCorrectShape), 0, + "GRU operation: wrong shape of input-to-hidden weights array, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(wxCorrectShape).c_str(), + ShapeUtils::shapeAsString(Wx).c_str()); + REQUIRE_TRUE(Wh->isSameShape(whCorrectShape), 0, + "GRU operation: wrong shape of hidden-to-hidden weights array, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(whCorrectShape).c_str(), + ShapeUtils::shapeAsString(Wh).c_str()); + REQUIRE_TRUE(b->isSameShape(bCorrectShape), 0, + "GRU operation: wrong shape of biases array, expected is %s, " + "but got %s instead !", + ShapeUtils::shapeAsString(bCorrectShape).c_str(), + ShapeUtils::shapeAsString(b).c_str()); + + auto hShapeInfo = ConstantShapeHelper::getInstance().createShapeInfo( + hI->dataType(), hI->ordering(), {time, bS, nOut}); + + return SHAPELIST(hShapeInfo); } - ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(gru_bp, 6, 5, false, 0, 0) { - - auto x = INPUT_VARIABLE(0); // input [time, bS, nIn] - auto hI = INPUT_VARIABLE(1); // initial cell output (at time step = 0) [bS, nOut] - - auto Wx = INPUT_VARIABLE(2); // input-to-hidden weights, [nIn, 3*nOut] - auto Wh = INPUT_VARIABLE(3); // hidden-to-hidden weights, [nOut, 3*nOut] - auto b = INPUT_VARIABLE(4); // biases, [3*nOut] - - auto dLdh = INPUT_VARIABLE(5); // gradient vs. ff output, [time, bS, nOut] - - auto dLdx = OUTPUT_VARIABLE(0); // gradient vs. ff input, [time, bS, nIn] - auto dLdhI = OUTPUT_NULLIFIED(1); // gradient vs. initial cell output, [bS, nOut] - auto dLdWx = OUTPUT_NULLIFIED(2); // gradient vs. input-to-hidden weights, [nIn, 3*nOut] - auto dLdWh = OUTPUT_NULLIFIED(3); // gradient vs. hidden-to-hidden weights, [nOut, 3*nOut] - auto dLdb = OUTPUT_NULLIFIED(4); // gradient vs. biases [3*nOut] - - const int time = x->sizeAt(0); - const int bS = x->sizeAt(1); - const int nIn = x->sizeAt(2); - const int nOut = hI->sizeAt(1); - - const std::vector h0CorrectShape = {bS, nOut}; - const std::vector wxCorrectShape = {nIn, 3*nOut}; - const std::vector whCorrectShape = {nOut, 3*nOut}; - const std::vector bCorrectShape = {3*nOut}; - const std::vector hCorrectShape = {time, bS, nOut}; - - REQUIRE_TRUE(hI->isSameShape(h0CorrectShape), 0, "GRU_BP operation: wrong shape of previous cell output array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(h0CorrectShape).c_str(), ShapeUtils::shapeAsString(hI).c_str()); - REQUIRE_TRUE(Wx->isSameShape(wxCorrectShape), 0, "GRU_BP operation: wrong shape of input-to-hidden weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(wxCorrectShape).c_str(), ShapeUtils::shapeAsString(Wx).c_str()); - REQUIRE_TRUE(Wh->isSameShape(whCorrectShape), 0, "GRU_BP operation: wrong shape of hidden-to-hidden weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(whCorrectShape).c_str(), ShapeUtils::shapeAsString(Wh).c_str()); - REQUIRE_TRUE(b->isSameShape(bCorrectShape), 0, "GRU_BP operation: wrong shape of biases array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(bCorrectShape).c_str(), ShapeUtils::shapeAsString(b).c_str()); - REQUIRE_TRUE(dLdh->isSameShape(hCorrectShape),0, "GRU_BP operation: wrong shape of gradient vs. ff output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(hCorrectShape).c_str(), ShapeUtils::shapeAsString(dLdh).c_str()); - - helpers::gruTimeLoopBp(block.launchContext(), x, hI, Wx, Wh, b, dLdh, dLdx, dLdhI, dLdWx, dLdWh, dLdb); - - return Status::OK(); + auto x = INPUT_VARIABLE(0); // input [time, bS, nIn] + auto hI = + INPUT_VARIABLE(1); // initial cell output (at time step = 0) [bS, nOut] + + auto Wx = INPUT_VARIABLE(2); // input-to-hidden weights, [nIn, 3*nOut] + auto Wh = INPUT_VARIABLE(3); // hidden-to-hidden weights, [nOut, 3*nOut] + auto b = INPUT_VARIABLE(4); // biases, [3*nOut] + + auto dLdh = INPUT_VARIABLE(5); // gradient vs. ff output, [time, bS, nOut] + + auto dLdx = OUTPUT_VARIABLE(0); // gradient vs. ff input, [time, bS, nIn] + auto dLdhI = + OUTPUT_NULLIFIED(1); // gradient vs. initial cell output, [bS, nOut] + auto dLdWx = OUTPUT_NULLIFIED( + 2); // gradient vs. input-to-hidden weights, [nIn, 3*nOut] + auto dLdWh = OUTPUT_NULLIFIED( + 3); // gradient vs. hidden-to-hidden weights, [nOut, 3*nOut] + auto dLdb = OUTPUT_NULLIFIED(4); // gradient vs. biases [3*nOut] + + const int time = x->sizeAt(0); + const int bS = x->sizeAt(1); + const int nIn = x->sizeAt(2); + const int nOut = hI->sizeAt(1); + + const std::vector h0CorrectShape = {bS, nOut}; + const std::vector wxCorrectShape = {nIn, 3 * nOut}; + const std::vector whCorrectShape = {nOut, 3 * nOut}; + const std::vector bCorrectShape = {3 * nOut}; + const std::vector hCorrectShape = {time, bS, nOut}; + + REQUIRE_TRUE(hI->isSameShape(h0CorrectShape), 0, + "GRU_BP operation: wrong shape of previous cell output array, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(h0CorrectShape).c_str(), + ShapeUtils::shapeAsString(hI).c_str()); + REQUIRE_TRUE(Wx->isSameShape(wxCorrectShape), 0, + "GRU_BP operation: wrong shape of input-to-hidden weights " + "array, expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(wxCorrectShape).c_str(), + ShapeUtils::shapeAsString(Wx).c_str()); + REQUIRE_TRUE(Wh->isSameShape(whCorrectShape), 0, + "GRU_BP operation: wrong shape of hidden-to-hidden weights " + "array, expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(whCorrectShape).c_str(), + ShapeUtils::shapeAsString(Wh).c_str()); + REQUIRE_TRUE(b->isSameShape(bCorrectShape), 0, + "GRU_BP operation: wrong shape of biases array, expected is %s, " + "but got %s instead !", + ShapeUtils::shapeAsString(bCorrectShape).c_str(), + ShapeUtils::shapeAsString(b).c_str()); + REQUIRE_TRUE(dLdh->isSameShape(hCorrectShape), 0, + "GRU_BP operation: wrong shape of gradient vs. ff output, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(hCorrectShape).c_str(), + ShapeUtils::shapeAsString(dLdh).c_str()); + + helpers::gruTimeLoopBp(block.launchContext(), x, hI, Wx, Wh, b, dLdh, dLdx, + dLdhI, dLdWx, dLdWh, dLdb); + + return Status::OK(); } ////////////////////////////////////////////////////////////////////////// DECLARE_TYPES(gru_bp) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } ////////////////////////////////////////////////////////////////////////// DECLARE_SHAPE_FN(gru_bp) { - - auto x = INPUT_VARIABLE(0); // input [time, bS, nIn] - auto hI = INPUT_VARIABLE(1); // initial cell output (at time step = 0) [bS, nOut] - - auto Wx = INPUT_VARIABLE(2); // input-to-hidden weights, [nIn, 3*nOut] - auto Wh = INPUT_VARIABLE(3); // hidden-to-hidden weights, [nOut, 3*nOut] - auto b = INPUT_VARIABLE(4); // biases, [3*nOut] - - auto dLdh = INPUT_VARIABLE(5); // gradient vs. ff output, [time, bS, nOut] - - const int time = x->sizeAt(0); - const int bS = x->sizeAt(1); - const int nIn = x->sizeAt(2); - const int nOut = hI->sizeAt(1); - - const std::vector h0CorrectShape = {bS, nOut}; - const std::vector wxCorrectShape = {nIn, 3*nOut}; - const std::vector whCorrectShape = {nOut, 3*nOut}; - const std::vector bCorrectShape = {3*nOut}; - const std::vector hCorrectShape = {time, bS, nOut}; - - REQUIRE_TRUE(hI->isSameShape(h0CorrectShape), 0, "GRU_BP operation: wrong shape of previous cell output array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(h0CorrectShape).c_str(), ShapeUtils::shapeAsString(hI).c_str()); - REQUIRE_TRUE(Wx->isSameShape(wxCorrectShape), 0, "GRU_BP operation: wrong shape of input-to-hidden weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(wxCorrectShape).c_str(), ShapeUtils::shapeAsString(Wx).c_str()); - REQUIRE_TRUE(Wh->isSameShape(whCorrectShape), 0, "GRU_BP operation: wrong shape of hidden-to-hidden weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(whCorrectShape).c_str(), ShapeUtils::shapeAsString(Wh).c_str()); - REQUIRE_TRUE(b->isSameShape(bCorrectShape), 0, "GRU_BP operation: wrong shape of biases array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(bCorrectShape).c_str(), ShapeUtils::shapeAsString(b).c_str()); - REQUIRE_TRUE(dLdh->isSameShape(hCorrectShape),0, "GRU_BP operation: wrong shape of gradient vs. ff output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(hCorrectShape).c_str(), ShapeUtils::shapeAsString(dLdh).c_str()); - - auto dLdxShapeInfo = ConstantShapeHelper::getInstance().createShapeInfo(dLdh->dataType(), x->shapeInfo()); - auto dLdhIShapeInfo = ConstantShapeHelper::getInstance().createShapeInfo(dLdh->dataType(), hI->shapeInfo()); - auto dLdWxShapeInfo = ConstantShapeHelper::getInstance().createShapeInfo(dLdh->dataType(), Wx->shapeInfo()); - auto dLdWhShapeInfo = ConstantShapeHelper::getInstance().createShapeInfo(dLdh->dataType(), Wh->shapeInfo()); - auto dLdbShapeInfo = ConstantShapeHelper::getInstance().createShapeInfo(dLdh->dataType(), b->shapeInfo()); - - return SHAPELIST(dLdxShapeInfo, dLdhIShapeInfo, dLdWxShapeInfo, dLdWhShapeInfo, dLdbShapeInfo); -} - -} + auto x = INPUT_VARIABLE(0); // input [time, bS, nIn] + auto hI = + INPUT_VARIABLE(1); // initial cell output (at time step = 0) [bS, nOut] + + auto Wx = INPUT_VARIABLE(2); // input-to-hidden weights, [nIn, 3*nOut] + auto Wh = INPUT_VARIABLE(3); // hidden-to-hidden weights, [nOut, 3*nOut] + auto b = INPUT_VARIABLE(4); // biases, [3*nOut] + + auto dLdh = INPUT_VARIABLE(5); // gradient vs. ff output, [time, bS, nOut] + + const int time = x->sizeAt(0); + const int bS = x->sizeAt(1); + const int nIn = x->sizeAt(2); + const int nOut = hI->sizeAt(1); + + const std::vector h0CorrectShape = {bS, nOut}; + const std::vector wxCorrectShape = {nIn, 3 * nOut}; + const std::vector whCorrectShape = {nOut, 3 * nOut}; + const std::vector bCorrectShape = {3 * nOut}; + const std::vector hCorrectShape = {time, bS, nOut}; + + REQUIRE_TRUE(hI->isSameShape(h0CorrectShape), 0, + "GRU_BP operation: wrong shape of previous cell output array, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(h0CorrectShape).c_str(), + ShapeUtils::shapeAsString(hI).c_str()); + REQUIRE_TRUE(Wx->isSameShape(wxCorrectShape), 0, + "GRU_BP operation: wrong shape of input-to-hidden weights " + "array, expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(wxCorrectShape).c_str(), + ShapeUtils::shapeAsString(Wx).c_str()); + REQUIRE_TRUE(Wh->isSameShape(whCorrectShape), 0, + "GRU_BP operation: wrong shape of hidden-to-hidden weights " + "array, expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(whCorrectShape).c_str(), + ShapeUtils::shapeAsString(Wh).c_str()); + REQUIRE_TRUE(b->isSameShape(bCorrectShape), 0, + "GRU_BP operation: wrong shape of biases array, expected is %s, " + "but got %s instead !", + ShapeUtils::shapeAsString(bCorrectShape).c_str(), + ShapeUtils::shapeAsString(b).c_str()); + REQUIRE_TRUE(dLdh->isSameShape(hCorrectShape), 0, + "GRU_BP operation: wrong shape of gradient vs. ff output, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(hCorrectShape).c_str(), + ShapeUtils::shapeAsString(dLdh).c_str()); + + auto dLdxShapeInfo = ConstantShapeHelper::getInstance().createShapeInfo( + dLdh->dataType(), x->shapeInfo()); + auto dLdhIShapeInfo = ConstantShapeHelper::getInstance().createShapeInfo( + dLdh->dataType(), hI->shapeInfo()); + auto dLdWxShapeInfo = ConstantShapeHelper::getInstance().createShapeInfo( + dLdh->dataType(), Wx->shapeInfo()); + auto dLdWhShapeInfo = ConstantShapeHelper::getInstance().createShapeInfo( + dLdh->dataType(), Wh->shapeInfo()); + auto dLdbShapeInfo = ConstantShapeHelper::getInstance().createShapeInfo( + dLdh->dataType(), b->shapeInfo()); + + return SHAPELIST(dLdxShapeInfo, dLdhIShapeInfo, dLdWxShapeInfo, + dLdWhShapeInfo, dLdbShapeInfo); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/nn/recurrent/gruCell.cpp b/libnd4j/include/ops/declarable/generic/nn/recurrent/gruCell.cpp index 25c8d3744c49..546ecace2cbd 100644 --- a/libnd4j/include/ops/declarable/generic/nn/recurrent/gruCell.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/recurrent/gruCell.cpp @@ -23,225 +23,347 @@ #if NOT_EXCLUDED(OP_gruCell) #include -#include +#include namespace sd { -namespace ops { - +namespace ops { ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(gruCell, 6, 4, false, 0, 0) { - - auto x = INPUT_VARIABLE(0); // input [bS, nIn], nIn - input size - auto hLast = INPUT_VARIABLE(1); // previous cell output [bS, nU], that is at previous time step t-1, nU - number of units - auto Wru = INPUT_VARIABLE(2); // RU weights - [nIn+nU, 2*nU] - reset and update gates (input/recurrent weights) - auto Wc = INPUT_VARIABLE(3); // C weights - [nIn+nU, nU] - cell gate (input/recurrent weights) - auto bru = INPUT_VARIABLE(4); // reset and update biases, [2*nU] - reset and update gates - auto bc = INPUT_VARIABLE(5); // cell biases, [nU] - - auto r = OUTPUT_VARIABLE(0); // Reset gate output [bS, nU] - auto u = OUTPUT_VARIABLE(1); // Update gate output [bS, nU] - auto c = OUTPUT_VARIABLE(2); // Cell gate output [bS, nU] - auto h = OUTPUT_VARIABLE(3); // current cell output [bS, nU] - - REQUIRE_TRUE(x->rankOf()==2 && hLast->rankOf()==2, 0, "gruCell: Input ranks must be 2 for inputs 0 and 1 (x, hLast) - got %i, %i", x->rankOf(), hLast->rankOf()); - - const int rank = x->rankOf(); - const auto bS = x->sizeAt(0); - const auto nIn = x->sizeAt(1); - const auto nU = hLast->sizeAt(1); - - REQUIRE_TRUE(x->sizeAt(0) == hLast->sizeAt(0), 0, "gruCell: Input minibatch sizes (dimension 0) must be same for x and hLast"); - REQUIRE_TRUE(Wru->rankOf()==2 && Wc->rankOf()==2, 0, "gruCell: weight arrays (Wru, Wc) arrays must be 2, got %i and %i", Wru->rankOf(), Wc->rankOf()); - REQUIRE_TRUE(Wru->sizeAt(0)==(nIn+nU) && Wc->sizeAt(0)==(nIn+nU), 0, "gruCell: Weights size(0) must be equal to nIn + nU, got %i", Wru->sizeAt(0)); - REQUIRE_TRUE(Wru->sizeAt(1)==(2*nU), 0, "gruCell: Weights (reset and update) size(1) must be equal to 2*nU, got %i", Wru->sizeAt(1)); - REQUIRE_TRUE(Wc->sizeAt(1)==nU, 0, "gruCell: Weights (cell) size(1) must be equal to nU, got %i", Wc->sizeAt(1)); - REQUIRE_TRUE(bru->rankOf()==1 && bru->sizeAt(0)==(2*nU), 0, "gruCell: reset/update biases must be rank 1, size 2*nU"); - REQUIRE_TRUE(bc->rankOf()==1 && bc->sizeAt(0)==nU, 0, "gruCell: cell biases must be rank 1, size nU"); - REQUIRE_TRUE(r->rankOf()==2 && u->rankOf()==2 && c->rankOf()==2 && h->rankOf()==2 && - r->sizeAt(0)==bS && u->sizeAt(0)==bS && c->sizeAt(0)==bS && h->sizeAt(0)==bS && - r->sizeAt(1)==nU && u->sizeAt(1)==nU && c->sizeAt(1)==nU && h->sizeAt(1)==nU, - 0, "gruCell: Output arrays must all be rank 2 with size(0) == batchSize and size(1) == nU"); - - helpers::gruCell(block.launchContext(), x, hLast, Wru, Wc, bru, bc, r, u, c, h); - - return Status::OK(); + auto x = INPUT_VARIABLE(0); // input [bS, nIn], nIn - input size + auto hLast = + INPUT_VARIABLE(1); // previous cell output [bS, nU], that is at previous + // time step t-1, nU - number of units + auto Wru = INPUT_VARIABLE(2); // RU weights - [nIn+nU, 2*nU] - reset and + // update gates (input/recurrent weights) + auto Wc = INPUT_VARIABLE( + 3); // C weights - [nIn+nU, nU] - cell gate (input/recurrent weights) + auto bru = INPUT_VARIABLE( + 4); // reset and update biases, [2*nU] - reset and update gates + auto bc = INPUT_VARIABLE(5); // cell biases, [nU] + + auto r = OUTPUT_VARIABLE(0); // Reset gate output [bS, nU] + auto u = OUTPUT_VARIABLE(1); // Update gate output [bS, nU] + auto c = OUTPUT_VARIABLE(2); // Cell gate output [bS, nU] + auto h = OUTPUT_VARIABLE(3); // current cell output [bS, nU] + + REQUIRE_TRUE(x->rankOf() == 2 && hLast->rankOf() == 2, 0, + "gruCell: Input ranks must be 2 for inputs 0 and 1 (x, hLast) - " + "got %i, %i", + x->rankOf(), hLast->rankOf()); + + const int rank = x->rankOf(); + const auto bS = x->sizeAt(0); + const auto nIn = x->sizeAt(1); + const auto nU = hLast->sizeAt(1); + + REQUIRE_TRUE(x->sizeAt(0) == hLast->sizeAt(0), 0, + "gruCell: Input minibatch sizes (dimension 0) must be same for " + "x and hLast"); + REQUIRE_TRUE( + Wru->rankOf() == 2 && Wc->rankOf() == 2, 0, + "gruCell: weight arrays (Wru, Wc) arrays must be 2, got %i and %i", + Wru->rankOf(), Wc->rankOf()); + REQUIRE_TRUE(Wru->sizeAt(0) == (nIn + nU) && Wc->sizeAt(0) == (nIn + nU), 0, + "gruCell: Weights size(0) must be equal to nIn + nU, got %i", + Wru->sizeAt(0)); + REQUIRE_TRUE(Wru->sizeAt(1) == (2 * nU), 0, + "gruCell: Weights (reset and update) size(1) must be equal to " + "2*nU, got %i", + Wru->sizeAt(1)); + REQUIRE_TRUE(Wc->sizeAt(1) == nU, 0, + "gruCell: Weights (cell) size(1) must be equal to nU, got %i", + Wc->sizeAt(1)); + REQUIRE_TRUE(bru->rankOf() == 1 && bru->sizeAt(0) == (2 * nU), 0, + "gruCell: reset/update biases must be rank 1, size 2*nU"); + REQUIRE_TRUE(bc->rankOf() == 1 && bc->sizeAt(0) == nU, 0, + "gruCell: cell biases must be rank 1, size nU"); + REQUIRE_TRUE( + r->rankOf() == 2 && u->rankOf() == 2 && c->rankOf() == 2 && + h->rankOf() == 2 && r->sizeAt(0) == bS && u->sizeAt(0) == bS && + c->sizeAt(0) == bS && h->sizeAt(0) == bS && r->sizeAt(1) == nU && + u->sizeAt(1) == nU && c->sizeAt(1) == nU && h->sizeAt(1) == nU, + 0, + "gruCell: Output arrays must all be rank 2 with size(0) == batchSize and " + "size(1) == nU"); + + helpers::gruCell(block.launchContext(), x, hLast, Wru, Wc, bru, bc, r, u, c, + h); + + return Status::OK(); } DECLARE_TYPES(gruCell) { - getOpDescriptor() - ->setAllowedInputTypes(0, sd::DataType::ANY) - ->setAllowedInputTypes(1, {ALL_FLOATS}) - ->setAllowedInputTypes(2, {ALL_FLOATS}) - ->setAllowedInputTypes(3, {ALL_FLOATS}) - ->setAllowedInputTypes(4, {ALL_FLOATS}) - ->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(0, sd::DataType::ANY) + ->setAllowedInputTypes(1, {ALL_FLOATS}) + ->setAllowedInputTypes(2, {ALL_FLOATS}) + ->setAllowedInputTypes(3, {ALL_FLOATS}) + ->setAllowedInputTypes(4, {ALL_FLOATS}) + ->setAllowedOutputTypes({ALL_FLOATS}); } - DECLARE_SHAPE_FN(gruCell) { - - auto x = inputShape->at(0); // input [bS x nIn] - auto hLast = inputShape->at(1); // previous cell output [bS x nU], that is at previous time step t-1 - auto Wru = inputShape->at(2); // RU weights - [(nIn+nU), 2*nU] - reset and update gates (input/recurrent weights) - auto Wc = inputShape->at(3); // C weights - [(nIn+nU), nU] - cell gate (input/recurrent weights) - auto bru = inputShape->at(4); // reset and update biases, [2*nU] - reset and update gates - auto bc = inputShape->at(5); // cell biases, [nU] - - REQUIRE_TRUE(shape::rank(x)==2 && shape::rank(hLast)==2, 0, "gruCell: Input ranks must be 2 for inputs 0 and 1 (x, hLast) - got %i, %i", shape::rank(x), shape::rank(hLast)); - - const int rank = x[0]; - const auto bS = x[1]; - const auto nIn = x[2]; - const auto nU = hLast[2]; - - REQUIRE_TRUE(x[1] == hLast[1], 0, "gruCell: Input minibatch sizes (dimension 0) must be same for x and hLast"); - REQUIRE_TRUE(shape::rank(Wru)==2 && shape::rank(Wc)==2, 0, "gruCell: weight arrays (Wru, Wc) arrays must be 2, got %i and %i", shape::rank(Wru), shape::rank(Wc)); - REQUIRE_TRUE(Wru[1]==(nIn+nU) && Wc[1]==(nIn+nU), 0, "gruCell: Weights size(0) must be equal to nIn + nU, got %i and %i", Wru[1], Wc[1]); - REQUIRE_TRUE(Wru[2]==(2*nU), 0, "gruCell: Weights (reset and update) size(1) must be equal to 2*nU, got %i", Wru[2]); - REQUIRE_TRUE(Wc[2]==nU, 0, "gruCell: Weights (cell) size(1) must be equal to nU, got %i", Wc[2]); - REQUIRE_TRUE(shape::rank(bru)==1 && bru[1]==(2*nU), 0, "gruCell: reset/update biases must be rank 1, size 2*nU"); - REQUIRE_TRUE(shape::rank(bc)==1 && bc[1]==nU, 0, "gruCell: cell biases must be rank 1, size nU"); - - Nd4jLong *s0(nullptr); - ALLOCATE(s0, block.getWorkspace(), shape::shapeInfoLength(rank), Nd4jLong);// [bS x nU] - - s0[0] = rank; - s0[1] = bS; - s0[2] = nU; - - ShapeUtils::updateStridesAndType(s0, x, shape::order(hLast)); - auto ts0 = ConstantShapeHelper::getInstance().createFromExisting(s0, block.workspace()); - - //4 output shapes, all [bs, nU] - return SHAPELIST(ts0, ts0, ts0, ts0); + auto x = inputShape->at(0); // input [bS x nIn] + auto hLast = inputShape->at( + 1); // previous cell output [bS x nU], that is at previous time step t-1 + auto Wru = inputShape->at(2); // RU weights - [(nIn+nU), 2*nU] - reset and + // update gates (input/recurrent weights) + auto Wc = inputShape->at( + 3); // C weights - [(nIn+nU), nU] - cell gate (input/recurrent weights) + auto bru = inputShape->at( + 4); // reset and update biases, [2*nU] - reset and update gates + auto bc = inputShape->at(5); // cell biases, [nU] + + REQUIRE_TRUE(shape::rank(x) == 2 && shape::rank(hLast) == 2, 0, + "gruCell: Input ranks must be 2 for inputs 0 and 1 (x, hLast) - " + "got %i, %i", + shape::rank(x), shape::rank(hLast)); + + const int rank = x[0]; + const auto bS = x[1]; + const auto nIn = x[2]; + const auto nU = hLast[2]; + + REQUIRE_TRUE(x[1] == hLast[1], 0, + "gruCell: Input minibatch sizes (dimension 0) must be same for " + "x and hLast"); + REQUIRE_TRUE( + shape::rank(Wru) == 2 && shape::rank(Wc) == 2, 0, + "gruCell: weight arrays (Wru, Wc) arrays must be 2, got %i and %i", + shape::rank(Wru), shape::rank(Wc)); + REQUIRE_TRUE( + Wru[1] == (nIn + nU) && Wc[1] == (nIn + nU), 0, + "gruCell: Weights size(0) must be equal to nIn + nU, got %i and %i", + Wru[1], Wc[1]); + REQUIRE_TRUE(Wru[2] == (2 * nU), 0, + "gruCell: Weights (reset and update) size(1) must be equal to " + "2*nU, got %i", + Wru[2]); + REQUIRE_TRUE(Wc[2] == nU, 0, + "gruCell: Weights (cell) size(1) must be equal to nU, got %i", + Wc[2]); + REQUIRE_TRUE(shape::rank(bru) == 1 && bru[1] == (2 * nU), 0, + "gruCell: reset/update biases must be rank 1, size 2*nU"); + REQUIRE_TRUE(shape::rank(bc) == 1 && bc[1] == nU, 0, + "gruCell: cell biases must be rank 1, size nU"); + + Nd4jLong *s0(nullptr); + ALLOCATE(s0, block.workspace(), shape::shapeInfoLength(rank), + Nd4jLong); // [bS x nU] + + s0[0] = rank; + s0[1] = bS; + s0[2] = nU; + + ShapeUtils::updateStridesAndType(s0, x, shape::order(hLast)); + auto ts0 = ConstantShapeHelper::getInstance().createFromExisting( + s0, block.workspace()); + + // 4 output shapes, all [bs, nU] + return SHAPELIST(ts0, ts0, ts0, ts0); } - ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(gruCell_bp, 10, 6, false, 0, 0) { - - auto x = INPUT_VARIABLE(0); // input [bS x iS] - auto hi = INPUT_VARIABLE(1); // previous cell output [bS x nU] - auto W = INPUT_VARIABLE(2); // weights, [iS+nU x 2*nU] - auto Wc = INPUT_VARIABLE(3); // c weights, [iS+nU x nU] - auto b = INPUT_VARIABLE(4); // biases, [2*nU] - auto bc = INPUT_VARIABLE(5); // biases, [nU] - auto dLdr = INPUT_VARIABLE(6); // gradient wrt reset gate, [bS, nU] - auto dLdu = INPUT_VARIABLE(7); // gradient wrt update gate, [bS, nU] - auto dLdc = INPUT_VARIABLE(8); // gradient wrt cell state, [bS, nU] - auto dLdh = INPUT_VARIABLE(9); // gradient wrt current cell output, [bS, nU] - - auto dLdx = OUTPUT_VARIABLE(0); // gradient wrt x, [bS, iS] - auto dLdhi = OUTPUT_VARIABLE(1); // gradient wrt hi, [bS, nU] - auto dLdW = OUTPUT_VARIABLE(2); // gradient wrt W, [iS+nU x 2*nU] - auto dLdWc = OUTPUT_VARIABLE(3); // gradient wrt Wc, [iS+nU x nU] - auto dLdb = OUTPUT_VARIABLE(4); // gradient wrt biases, [2*nU] - auto dLdbc = OUTPUT_VARIABLE(5); // gradient wrt c biases, [nU] - - const Nd4jLong bS = x->sizeAt(0); - const Nd4jLong iS = x->sizeAt(1); - const Nd4jLong nU = hi->sizeAt(1); - - REQUIRE_TRUE(x->rankOf() == 2, 0, "GRU_CELL_BP: rank of input array x must be 2, but got %i instead", x->rankOf()); - - const std::vector hiCorrectShape = {bS, nU}; - const std::vector wCorrectShape = {iS+nU, 2*nU}; - const std::vector wcCorrectShape = {iS+nU, nU}; - const std::vector bCorrectShape = {2*nU}; - const std::vector bcCorrectShape = {nU}; - - REQUIRE_TRUE(hi->isSameShape(hiCorrectShape), 0, "GRU_CELL_BP op: wrong shape of previous cell output array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(hiCorrectShape).c_str(), ShapeUtils::shapeAsString(hi).c_str()); - REQUIRE_TRUE(W->isSameShape(wCorrectShape), 0, "GRU_CELL_BP op: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(wCorrectShape).c_str(), ShapeUtils::shapeAsString(W).c_str()); - REQUIRE_TRUE(Wc->isSameShape(wcCorrectShape), 0, "GRU_CELL_BP op: wrong shape of c weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(wcCorrectShape).c_str(), ShapeUtils::shapeAsString(Wc).c_str()); - REQUIRE_TRUE(b->isSameShape(bCorrectShape), 0, "GRU_CELL_BP op: wrong shape of biases array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(bCorrectShape).c_str(), ShapeUtils::shapeAsString(b).c_str()); - REQUIRE_TRUE(bc->isSameShape(bcCorrectShape), 0, "GRU_CELL_BP op: wrong shape of c biases array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(bcCorrectShape).c_str(), ShapeUtils::shapeAsString(bc).c_str()); - REQUIRE_TRUE(dLdr->isSameShape(hiCorrectShape), 0, "GRU_CELL_BP op: wrong shape of dLdr array (gradient wrt reset gate), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(hiCorrectShape).c_str(), ShapeUtils::shapeAsString(dLdr).c_str()); - REQUIRE_TRUE(dLdu->isSameShape(hiCorrectShape), 0, "GRU_CELL_BP op: wrong shape of dLdu array (gradient wrt update gate), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(hiCorrectShape).c_str(), ShapeUtils::shapeAsString(dLdu).c_str()); - REQUIRE_TRUE(dLdc->isSameShape(hiCorrectShape), 0, "GRU_CELL_BP op: wrong shape of dLdc array (gradient wrt cell state), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(hiCorrectShape).c_str(), ShapeUtils::shapeAsString(dLdc).c_str()); - REQUIRE_TRUE(dLdh->isSameShape(hiCorrectShape), 0, "GRU_CELL_BP op: wrong shape of dLdh array (gradient wrt current cell output), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(hiCorrectShape).c_str(), ShapeUtils::shapeAsString(dLdh).c_str()); - - helpers::gruCellBp(block.launchContext(), x, hi, W, Wc, b, bc, dLdr, dLdu, dLdc, dLdh, dLdx, dLdhi, dLdW, dLdWc, dLdb, dLdbc); - - return Status::OK(); + auto x = INPUT_VARIABLE(0); // input [bS x iS] + auto hi = INPUT_VARIABLE(1); // previous cell output [bS x nU] + auto W = INPUT_VARIABLE(2); // weights, [iS+nU x 2*nU] + auto Wc = INPUT_VARIABLE(3); // c weights, [iS+nU x nU] + auto b = INPUT_VARIABLE(4); // biases, [2*nU] + auto bc = INPUT_VARIABLE(5); // biases, [nU] + auto dLdr = INPUT_VARIABLE(6); // gradient wrt reset gate, [bS, nU] + auto dLdu = INPUT_VARIABLE(7); // gradient wrt update gate, [bS, nU] + auto dLdc = INPUT_VARIABLE(8); // gradient wrt cell state, [bS, nU] + auto dLdh = INPUT_VARIABLE(9); // gradient wrt current cell output, [bS, nU] + + auto dLdx = OUTPUT_VARIABLE(0); // gradient wrt x, [bS, iS] + auto dLdhi = OUTPUT_VARIABLE(1); // gradient wrt hi, [bS, nU] + auto dLdW = OUTPUT_VARIABLE(2); // gradient wrt W, [iS+nU x 2*nU] + auto dLdWc = OUTPUT_VARIABLE(3); // gradient wrt Wc, [iS+nU x nU] + auto dLdb = OUTPUT_VARIABLE(4); // gradient wrt biases, [2*nU] + auto dLdbc = OUTPUT_VARIABLE(5); // gradient wrt c biases, [nU] + + const Nd4jLong bS = x->sizeAt(0); + const Nd4jLong iS = x->sizeAt(1); + const Nd4jLong nU = hi->sizeAt(1); + + REQUIRE_TRUE( + x->rankOf() == 2, 0, + "GRU_CELL_BP: rank of input array x must be 2, but got %i instead", + x->rankOf()); + + const std::vector hiCorrectShape = {bS, nU}; + const std::vector wCorrectShape = {iS + nU, 2 * nU}; + const std::vector wcCorrectShape = {iS + nU, nU}; + const std::vector bCorrectShape = {2 * nU}; + const std::vector bcCorrectShape = {nU}; + + REQUIRE_TRUE(hi->isSameShape(hiCorrectShape), 0, + "GRU_CELL_BP op: wrong shape of previous cell output array, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(hiCorrectShape).c_str(), + ShapeUtils::shapeAsString(hi).c_str()); + REQUIRE_TRUE(W->isSameShape(wCorrectShape), 0, + "GRU_CELL_BP op: wrong shape of weights array, expected is %s, " + "but got %s instead !", + ShapeUtils::shapeAsString(wCorrectShape).c_str(), + ShapeUtils::shapeAsString(W).c_str()); + REQUIRE_TRUE(Wc->isSameShape(wcCorrectShape), 0, + "GRU_CELL_BP op: wrong shape of c weights array, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString(wcCorrectShape).c_str(), + ShapeUtils::shapeAsString(Wc).c_str()); + REQUIRE_TRUE(b->isSameShape(bCorrectShape), 0, + "GRU_CELL_BP op: wrong shape of biases array, expected is %s, " + "but got %s instead !", + ShapeUtils::shapeAsString(bCorrectShape).c_str(), + ShapeUtils::shapeAsString(b).c_str()); + REQUIRE_TRUE(bc->isSameShape(bcCorrectShape), 0, + "GRU_CELL_BP op: wrong shape of c biases array, expected is %s, " + "but got %s instead !", + ShapeUtils::shapeAsString(bcCorrectShape).c_str(), + ShapeUtils::shapeAsString(bc).c_str()); + REQUIRE_TRUE(dLdr->isSameShape(hiCorrectShape), 0, + "GRU_CELL_BP op: wrong shape of dLdr array (gradient wrt reset " + "gate), expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(hiCorrectShape).c_str(), + ShapeUtils::shapeAsString(dLdr).c_str()); + REQUIRE_TRUE(dLdu->isSameShape(hiCorrectShape), 0, + "GRU_CELL_BP op: wrong shape of dLdu array (gradient wrt update " + "gate), expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(hiCorrectShape).c_str(), + ShapeUtils::shapeAsString(dLdu).c_str()); + REQUIRE_TRUE(dLdc->isSameShape(hiCorrectShape), 0, + "GRU_CELL_BP op: wrong shape of dLdc array (gradient wrt cell " + "state), expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(hiCorrectShape).c_str(), + ShapeUtils::shapeAsString(dLdc).c_str()); + REQUIRE_TRUE(dLdh->isSameShape(hiCorrectShape), 0, + "GRU_CELL_BP op: wrong shape of dLdh array (gradient wrt " + "current cell output), expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(hiCorrectShape).c_str(), + ShapeUtils::shapeAsString(dLdh).c_str()); + + helpers::gruCellBp(block.launchContext(), x, hi, W, Wc, b, bc, dLdr, dLdu, + dLdc, dLdh, dLdx, dLdhi, dLdW, dLdWc, dLdb, dLdbc); + + return Status::OK(); } DECLARE_TYPES(gruCell_bp) { - getOpDescriptor() - ->setAllowedInputTypes(0, sd::DataType::ANY) - ->setAllowedInputTypes(1, {ALL_FLOATS}) - ->setAllowedInputTypes(2, {ALL_FLOATS}) - ->setAllowedInputTypes(3, {ALL_FLOATS}) - ->setAllowedInputTypes(4, {ALL_FLOATS}) - ->setAllowedInputTypes(5, {ALL_FLOATS}) - ->setAllowedInputTypes(6, {ALL_FLOATS}) - ->setAllowedInputTypes(7, {ALL_FLOATS}) - ->setAllowedInputTypes(8, {ALL_FLOATS}) - ->setAllowedInputTypes(9, {ALL_FLOATS}) - ->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(0, sd::DataType::ANY) + ->setAllowedInputTypes(1, {ALL_FLOATS}) + ->setAllowedInputTypes(2, {ALL_FLOATS}) + ->setAllowedInputTypes(3, {ALL_FLOATS}) + ->setAllowedInputTypes(4, {ALL_FLOATS}) + ->setAllowedInputTypes(5, {ALL_FLOATS}) + ->setAllowedInputTypes(6, {ALL_FLOATS}) + ->setAllowedInputTypes(7, {ALL_FLOATS}) + ->setAllowedInputTypes(8, {ALL_FLOATS}) + ->setAllowedInputTypes(9, {ALL_FLOATS}) + ->setAllowedOutputTypes({ALL_FLOATS}); } DECLARE_SHAPE_FN(gruCell_bp) { - - auto xShapeInfo = inputShape->at(0); // [bS x iS] - auto hiShapeInfo = inputShape->at(1); // [bS x nU] - auto wShapeInfo = inputShape->at(2); // [iS+nU x 2*nU] - auto wcShapeInfo = inputShape->at(3); // [iS+nU x nU] - auto bShapeInfo = inputShape->at(4); // [2*nU] - auto bcShapeInfo = inputShape->at(5); // [nU] - auto dLdrShapeInfo = inputShape->at(6); // [bS, nU] - auto dLduShapeInfo = inputShape->at(7); // [bS, nU] - auto dLdcShapeInfo = inputShape->at(8); // [bS, nU] - auto dLdhShapeInfo = inputShape->at(9); // [bS, nU] - - const int rank = xShapeInfo[0]; // = 2 - const Nd4jLong bS = xShapeInfo[1]; - const Nd4jLong iS = xShapeInfo[2]; - const Nd4jLong nU = hiShapeInfo[2]; - - REQUIRE_TRUE(xShapeInfo[0] == 2, 0, "GRU_CELL_BP: rank of input array x must be 2, but got %i instead", xShapeInfo[0]); - - const std::vector hiCorrectShape = {bS, nU}; - const std::vector wCorrectShape = {iS+nU, 2*nU}; - const std::vector wcCorrectShape = {iS+nU, nU}; - const std::vector bCorrectShape = {2*nU}; - const std::vector bcCorrectShape = {nU}; - - REQUIRE_TRUE(ShapeUtils::areShapesEqual(hiShapeInfo, hiCorrectShape), 0, "GRU_CELL_BP op: wrong shape of previous cell output array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(hiCorrectShape).c_str(), ShapeUtils::shapeAsString(hiShapeInfo).c_str()); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(wShapeInfo, wCorrectShape), 0, "GRU_CELL_BP op: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(wCorrectShape).c_str(), ShapeUtils::shapeAsString(wShapeInfo).c_str()); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(wcShapeInfo, wcCorrectShape), 0, "GRU_CELL_BP op: wrong shape of c weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(wcCorrectShape).c_str(), ShapeUtils::shapeAsString(wcShapeInfo).c_str()); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(bShapeInfo, bCorrectShape), 0, "GRU_CELL_BP op: wrong shape of biases array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(bCorrectShape).c_str(), ShapeUtils::shapeAsString(bShapeInfo).c_str()); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(bcShapeInfo, bcCorrectShape), 0, "GRU_CELL_BP op: wrong shape of c biases array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(bcCorrectShape).c_str(), ShapeUtils::shapeAsString(bcShapeInfo).c_str()); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(dLdrShapeInfo, hiCorrectShape), 0, "GRU_CELL_BP op: wrong shape of dLdr array (gradient wrt reset gate), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(hiCorrectShape).c_str(), ShapeUtils::shapeAsString(dLdrShapeInfo).c_str()); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(dLduShapeInfo, hiCorrectShape), 0, "GRU_CELL_BP op: wrong shape of dLdu array (gradient wrt update gate), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(hiCorrectShape).c_str(), ShapeUtils::shapeAsString(dLduShapeInfo).c_str()); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(dLdcShapeInfo, hiCorrectShape), 0, "GRU_CELL_BP op: wrong shape of dLdc array (gradient wrt cell state), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(hiCorrectShape).c_str(), ShapeUtils::shapeAsString(dLdcShapeInfo).c_str()); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(dLdhShapeInfo, hiCorrectShape), 0, "GRU_CELL_BP op: wrong shape of dLdh array (gradient wrt current cell output), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(hiCorrectShape).c_str(), ShapeUtils::shapeAsString(dLdhShapeInfo).c_str()); - - Nd4jLong *dLdxShapeInfo = nullptr; - COPY_SHAPE(xShapeInfo, dLdxShapeInfo); - - Nd4jLong *dLdhiShapeInfo = nullptr; - COPY_SHAPE(hiShapeInfo, dLdhiShapeInfo); - - Nd4jLong *dLdWShapeInfo = nullptr; - COPY_SHAPE(wShapeInfo, dLdWShapeInfo); - - Nd4jLong *dLdWcShapeInfo = nullptr; - COPY_SHAPE(wcShapeInfo, dLdWcShapeInfo); - - Nd4jLong *dLdbShapeInfo = nullptr; - COPY_SHAPE(bShapeInfo, dLdbShapeInfo); - - Nd4jLong *dLdbcShapeInfo = nullptr; - COPY_SHAPE(bcShapeInfo, dLdbcShapeInfo); - - return SHAPELIST(CONSTANT(dLdxShapeInfo), CONSTANT(dLdhiShapeInfo), CONSTANT(dLdWShapeInfo), CONSTANT(dLdWcShapeInfo), CONSTANT(dLdbShapeInfo), CONSTANT(dLdbcShapeInfo)); + auto xShapeInfo = inputShape->at(0); // [bS x iS] + auto hiShapeInfo = inputShape->at(1); // [bS x nU] + auto wShapeInfo = inputShape->at(2); // [iS+nU x 2*nU] + auto wcShapeInfo = inputShape->at(3); // [iS+nU x nU] + auto bShapeInfo = inputShape->at(4); // [2*nU] + auto bcShapeInfo = inputShape->at(5); // [nU] + auto dLdrShapeInfo = inputShape->at(6); // [bS, nU] + auto dLduShapeInfo = inputShape->at(7); // [bS, nU] + auto dLdcShapeInfo = inputShape->at(8); // [bS, nU] + auto dLdhShapeInfo = inputShape->at(9); // [bS, nU] + + const int rank = xShapeInfo[0]; // = 2 + const Nd4jLong bS = xShapeInfo[1]; + const Nd4jLong iS = xShapeInfo[2]; + const Nd4jLong nU = hiShapeInfo[2]; + + REQUIRE_TRUE( + xShapeInfo[0] == 2, 0, + "GRU_CELL_BP: rank of input array x must be 2, but got %i instead", + xShapeInfo[0]); + + const std::vector hiCorrectShape = {bS, nU}; + const std::vector wCorrectShape = {iS + nU, 2 * nU}; + const std::vector wcCorrectShape = {iS + nU, nU}; + const std::vector bCorrectShape = {2 * nU}; + const std::vector bcCorrectShape = {nU}; + + REQUIRE_TRUE(ShapeUtils::areShapesEqual(hiShapeInfo, hiCorrectShape), 0, + "GRU_CELL_BP op: wrong shape of previous cell output array, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(hiCorrectShape).c_str(), + ShapeUtils::shapeAsString(hiShapeInfo).c_str()); + REQUIRE_TRUE(ShapeUtils::areShapesEqual(wShapeInfo, wCorrectShape), 0, + "GRU_CELL_BP op: wrong shape of weights array, expected is %s, " + "but got %s instead !", + ShapeUtils::shapeAsString(wCorrectShape).c_str(), + ShapeUtils::shapeAsString(wShapeInfo).c_str()); + REQUIRE_TRUE(ShapeUtils::areShapesEqual(wcShapeInfo, wcCorrectShape), 0, + "GRU_CELL_BP op: wrong shape of c weights array, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString(wcCorrectShape).c_str(), + ShapeUtils::shapeAsString(wcShapeInfo).c_str()); + REQUIRE_TRUE(ShapeUtils::areShapesEqual(bShapeInfo, bCorrectShape), 0, + "GRU_CELL_BP op: wrong shape of biases array, expected is %s, " + "but got %s instead !", + ShapeUtils::shapeAsString(bCorrectShape).c_str(), + ShapeUtils::shapeAsString(bShapeInfo).c_str()); + REQUIRE_TRUE(ShapeUtils::areShapesEqual(bcShapeInfo, bcCorrectShape), 0, + "GRU_CELL_BP op: wrong shape of c biases array, expected is %s, " + "but got %s instead !", + ShapeUtils::shapeAsString(bcCorrectShape).c_str(), + ShapeUtils::shapeAsString(bcShapeInfo).c_str()); + REQUIRE_TRUE(ShapeUtils::areShapesEqual(dLdrShapeInfo, hiCorrectShape), 0, + "GRU_CELL_BP op: wrong shape of dLdr array (gradient wrt reset " + "gate), expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(hiCorrectShape).c_str(), + ShapeUtils::shapeAsString(dLdrShapeInfo).c_str()); + REQUIRE_TRUE(ShapeUtils::areShapesEqual(dLduShapeInfo, hiCorrectShape), 0, + "GRU_CELL_BP op: wrong shape of dLdu array (gradient wrt update " + "gate), expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(hiCorrectShape).c_str(), + ShapeUtils::shapeAsString(dLduShapeInfo).c_str()); + REQUIRE_TRUE(ShapeUtils::areShapesEqual(dLdcShapeInfo, hiCorrectShape), 0, + "GRU_CELL_BP op: wrong shape of dLdc array (gradient wrt cell " + "state), expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(hiCorrectShape).c_str(), + ShapeUtils::shapeAsString(dLdcShapeInfo).c_str()); + REQUIRE_TRUE(ShapeUtils::areShapesEqual(dLdhShapeInfo, hiCorrectShape), 0, + "GRU_CELL_BP op: wrong shape of dLdh array (gradient wrt " + "current cell output), expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(hiCorrectShape).c_str(), + ShapeUtils::shapeAsString(dLdhShapeInfo).c_str()); + + Nd4jLong *dLdxShapeInfo = nullptr; + COPY_SHAPE(xShapeInfo, dLdxShapeInfo); + + Nd4jLong *dLdhiShapeInfo = nullptr; + COPY_SHAPE(hiShapeInfo, dLdhiShapeInfo); + + Nd4jLong *dLdWShapeInfo = nullptr; + COPY_SHAPE(wShapeInfo, dLdWShapeInfo); + + Nd4jLong *dLdWcShapeInfo = nullptr; + COPY_SHAPE(wcShapeInfo, dLdWcShapeInfo); + + Nd4jLong *dLdbShapeInfo = nullptr; + COPY_SHAPE(bShapeInfo, dLdbShapeInfo); + + Nd4jLong *dLdbcShapeInfo = nullptr; + COPY_SHAPE(bcShapeInfo, dLdbcShapeInfo); + + return SHAPELIST(CONSTANT(dLdxShapeInfo), CONSTANT(dLdhiShapeInfo), + CONSTANT(dLdWShapeInfo), CONSTANT(dLdWcShapeInfo), + CONSTANT(dLdbShapeInfo), CONSTANT(dLdbcShapeInfo)); } - - - -} -} +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/generic/nn/recurrent/lstm.cpp b/libnd4j/include/ops/declarable/generic/nn/recurrent/lstm.cpp index 915be3129d25..9733af262bcd 100644 --- a/libnd4j/include/ops/declarable/generic/nn/recurrent/lstm.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/recurrent/lstm.cpp @@ -22,131 +22,210 @@ #if NOT_EXCLUDED(OP_lstm) #include -#include +#include namespace sd { -namespace ops { - +namespace ops { ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(lstm, 8, 2, false, 3, 2) { - auto x = INPUT_VARIABLE(0); // input [time x bS x inSize] - auto h0 = INPUT_VARIABLE(1); // initial cell output (at time step = 0) [bS x numProj], in case of projection=false -> numProj == numUnits !!! - auto c0 = INPUT_VARIABLE(2); // initial cell state (at time step = 0) [bS x numUnits], - - auto Wx = INPUT_VARIABLE(3); // input-to-hidden weights, [inSize x 4*numUnits] - auto Wh = INPUT_VARIABLE(4); // hidden-to-hidden weights, [numProj x 4*numUnits] - auto Wc = INPUT_VARIABLE(5); // diagonal weights for peephole connections [3*numUnits] - auto Wp = INPUT_VARIABLE(6); // projection weights [numUnits x numProj] - auto b = INPUT_VARIABLE(7); // biases, [4*numUnits] - - auto h = OUTPUT_VARIABLE(0); // cell outputs [time x bS x numProj], that is per each time step - auto c = OUTPUT_VARIABLE(1); // cell states [time x bS x numUnits] that is per each time step - - const int peephole = INT_ARG(0); // if 1, provide peephole connections - const int projection = INT_ARG(1); // if 1, then projection is performed, if false then numProj==numUnits is mandatory!!!! - - // FIXME: double - const double clippingCellValue = T_ARG(0); // clipping value for ct, if it is not equal to zero, then cell state is clipped - const double clippingProjValue = T_ARG(1); // clipping value for projected ht, if it is not equal to zero, then projected cell output is clipped - const double forgetBias = T_ARG(2); - - const int rank = x->rankOf(); - const int time = x->sizeAt(0); - const int bS = x->sizeAt(1); - const int inSize = x->sizeAt(2); - const int numProj = h0->sizeAt(1); - const int numUnits = c0->sizeAt(1); - - // input shapes validation - const std::vector correctH0Shape = {bS, numProj}; - const std::vector correctC0Shape = {bS, numUnits}; - const std::vector correctWxShape = {inSize, 4*numUnits}; - const std::vector correctWhShape = {numProj, 4*numUnits}; - const std::vector correctWcShape = {3*numUnits}; - const std::vector correctWpShape = {numUnits, numProj}; - const std::vector correctBShape = {4*numUnits}; - - REQUIRE_TRUE(h0->isSameShape(correctH0Shape), 0, "LSTM operation: wrong shape of initial cell output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctH0Shape).c_str(), ShapeUtils::shapeAsString(h0).c_str()); - REQUIRE_TRUE(c0->isSameShape(correctC0Shape), 0, "LSTM operation: wrong shape of initial cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctC0Shape).c_str(), ShapeUtils::shapeAsString(c0).c_str()); - REQUIRE_TRUE(Wx->isSameShape(correctWxShape), 0, "LSTM operation: wrong shape of input-to-hidden weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctWxShape).c_str(), ShapeUtils::shapeAsString(Wx).c_str()); - REQUIRE_TRUE(Wh->isSameShape(correctWhShape), 0, "LSTM operation: wrong shape of hidden-to-hidden weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctWhShape).c_str(), ShapeUtils::shapeAsString(Wh).c_str()); - REQUIRE_TRUE(Wc->isSameShape(correctWcShape), 0, "LSTM operation: wrong shape of diagonal weights for peephole connections, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctWcShape).c_str(), ShapeUtils::shapeAsString(Wc).c_str()); - REQUIRE_TRUE(Wp->isSameShape(correctWpShape), 0, "LSTM operation: wrong shape of projection weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctWpShape).c_str(), ShapeUtils::shapeAsString(Wp).c_str()); - REQUIRE_TRUE(b->isSameShape(correctBShape), 0, "LSTM operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctBShape).c_str(), ShapeUtils::shapeAsString(b).c_str()); - REQUIRE_TRUE(!(!projection && numUnits != numProj), 0, "LSTM operation: projection option is switched of, and in this case output dimensionality for the projection matrices (numProj) must be equal to number of units in lstmCell !"); - - helpers::lstmTimeLoop(block.launchContext(), x, h0, c0, Wx, Wh, Wc, Wp, b, h, c, {(double)peephole, (double)projection, clippingCellValue, clippingProjValue, forgetBias}); - - return Status::OK(); + auto x = INPUT_VARIABLE(0); // input [time x bS x inSize] + auto h0 = INPUT_VARIABLE( + 1); // initial cell output (at time step = 0) [bS x numProj], in case of + // projection=false -> numProj == numUnits !!! + auto c0 = INPUT_VARIABLE( + 2); // initial cell state (at time step = 0) [bS x numUnits], + + auto Wx = + INPUT_VARIABLE(3); // input-to-hidden weights, [inSize x 4*numUnits] + auto Wh = + INPUT_VARIABLE(4); // hidden-to-hidden weights, [numProj x 4*numUnits] + auto Wc = INPUT_VARIABLE( + 5); // diagonal weights for peephole connections [3*numUnits] + auto Wp = INPUT_VARIABLE(6); // projection weights [numUnits x numProj] + auto b = INPUT_VARIABLE(7); // biases, [4*numUnits] + + auto h = OUTPUT_VARIABLE( + 0); // cell outputs [time x bS x numProj], that is per each time step + auto c = OUTPUT_VARIABLE( + 1); // cell states [time x bS x numUnits] that is per each time step + + const int peephole = INT_ARG(0); // if 1, provide peephole connections + const int projection = + INT_ARG(1); // if 1, then projection is performed, if false then + // numProj==numUnits is mandatory!!!! + + // FIXME: double + const double clippingCellValue = + T_ARG(0); // clipping value for ct, if it is not equal to zero, then cell + // state is clipped + const double clippingProjValue = + T_ARG(1); // clipping value for projected ht, if it is not equal to zero, + // then projected cell output is clipped + const double forgetBias = T_ARG(2); + + const int rank = x->rankOf(); + const int time = x->sizeAt(0); + const int bS = x->sizeAt(1); + const int inSize = x->sizeAt(2); + const int numProj = h0->sizeAt(1); + const int numUnits = c0->sizeAt(1); + + // input shapes validation + const std::vector correctH0Shape = {bS, numProj}; + const std::vector correctC0Shape = {bS, numUnits}; + const std::vector correctWxShape = {inSize, 4 * numUnits}; + const std::vector correctWhShape = {numProj, 4 * numUnits}; + const std::vector correctWcShape = {3 * numUnits}; + const std::vector correctWpShape = {numUnits, numProj}; + const std::vector correctBShape = {4 * numUnits}; + + REQUIRE_TRUE(h0->isSameShape(correctH0Shape), 0, + "LSTM operation: wrong shape of initial cell output, expected " + "is %s, but got %s instead !", + ShapeUtils::shapeAsString(correctH0Shape).c_str(), + ShapeUtils::shapeAsString(h0).c_str()); + REQUIRE_TRUE(c0->isSameShape(correctC0Shape), 0, + "LSTM operation: wrong shape of initial cell state, expected " + "is %s, but got %s instead !", + ShapeUtils::shapeAsString(correctC0Shape).c_str(), + ShapeUtils::shapeAsString(c0).c_str()); + REQUIRE_TRUE(Wx->isSameShape(correctWxShape), 0, + "LSTM operation: wrong shape of input-to-hidden weights, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(correctWxShape).c_str(), + ShapeUtils::shapeAsString(Wx).c_str()); + REQUIRE_TRUE(Wh->isSameShape(correctWhShape), 0, + "LSTM operation: wrong shape of hidden-to-hidden weights, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(correctWhShape).c_str(), + ShapeUtils::shapeAsString(Wh).c_str()); + REQUIRE_TRUE(Wc->isSameShape(correctWcShape), 0, + "LSTM operation: wrong shape of diagonal weights for peephole " + "connections, expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(correctWcShape).c_str(), + ShapeUtils::shapeAsString(Wc).c_str()); + REQUIRE_TRUE(Wp->isSameShape(correctWpShape), 0, + "LSTM operation: wrong shape of projection weights, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString(correctWpShape).c_str(), + ShapeUtils::shapeAsString(Wp).c_str()); + REQUIRE_TRUE(b->isSameShape(correctBShape), 0, + "LSTM operation: wrong shape of biases, expected is %s, but got " + "%s instead !", + ShapeUtils::shapeAsString(correctBShape).c_str(), + ShapeUtils::shapeAsString(b).c_str()); + REQUIRE_TRUE(!(!projection && numUnits != numProj), 0, + "LSTM operation: projection option is switched of, and in this " + "case output dimensionality for the projection matrices " + "(numProj) must be equal to number of units in lstmCell !"); + + helpers::lstmTimeLoop(block.launchContext(), x, h0, c0, Wx, Wh, Wc, Wp, b, h, + c, + {(double)peephole, (double)projection, + clippingCellValue, clippingProjValue, forgetBias}); + + return Status::OK(); } - DECLARE_TYPES(lstm) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } - +DECLARE_TYPES(lstm) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} DECLARE_SHAPE_FN(lstm) { - - auto xShapeInfo = inputShape->at(0); // input [time x bS x inSize] - auto h0ShapeInfo = inputShape->at(1); // initial cell output (at time step = 0) [bS x numProj], in case of projection=false -> numProj == numUnits !!! - auto c0ShapeInfo = inputShape->at(2); // initial cell state (at time step = 0) [bS x numUnits], - - auto WxShapeInfo = inputShape->at(3); // input-to-hidden weights, [inSize x 4*numUnits] - auto WhShapeInfo = inputShape->at(4); // hidden-to-hidden weights, [numProj x 4*numUnits] - auto WcShapeInfo = inputShape->at(5); // diagonal weights for peephole connections [3*numUnits] - auto WpShapeInfo = inputShape->at(6); // projection weights [numUnits x numProj] - auto bShapeInfo = inputShape->at(7); // biases, [4*numUnits] - - const int rank = xShapeInfo[0]; - const int time = xShapeInfo[1]; - const int bS = xShapeInfo[2]; - const int inSize = xShapeInfo[3]; - const int numProj = h0ShapeInfo[2]; - const int numUnits = c0ShapeInfo[2]; - - // input shapes validation - const std::vector correctH0Shape = {bS, numProj}; - const std::vector correctC0Shape = {bS, numUnits}; - const std::vector correctWxShape = {inSize, 4*numUnits}; - const std::vector correctWhShape = {numProj, 4*numUnits}; - const std::vector correctWcShape = {3*numUnits}; - const std::vector correctWpShape = {numUnits, numProj}; - const std::vector correctBShape = {4*numUnits}; - - REQUIRE_TRUE(ShapeUtils::areShapesEqual(h0ShapeInfo, correctH0Shape), 0, "LSTM operation: wrong shape of initial cell output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctH0Shape).c_str(), ShapeUtils::shapeAsString(h0ShapeInfo).c_str()); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(c0ShapeInfo, correctC0Shape), 0, "LSTM operation: wrong shape of initial cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctC0Shape).c_str(), ShapeUtils::shapeAsString(c0ShapeInfo).c_str()); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(WxShapeInfo, correctWxShape), 0, "LSTM operation: wrong shape of input-to-hidden weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctWxShape).c_str(), ShapeUtils::shapeAsString(WxShapeInfo).c_str()); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(WhShapeInfo, correctWhShape), 0, "LSTM operation: wrong shape of hidden-to-hidden weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctWhShape).c_str(), ShapeUtils::shapeAsString(WhShapeInfo).c_str()); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(WcShapeInfo, correctWcShape), 0, "LSTM operation: wrong shape of diagonal weights for peephole connections, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctWcShape).c_str(), ShapeUtils::shapeAsString(WcShapeInfo).c_str()); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(WpShapeInfo, correctWpShape), 0, "LSTM operation: wrong shape of projection weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctWpShape).c_str(), ShapeUtils::shapeAsString(WpShapeInfo).c_str()); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(bShapeInfo, correctBShape), 0, "LSTM operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctBShape).c_str(), ShapeUtils::shapeAsString(bShapeInfo).c_str()); - - - // evaluate output shapeInfos - Nd4jLong *hShapeInfo(nullptr), *cShapeInfo(nullptr); - ALLOCATE(hShapeInfo, block.getWorkspace(), shape::shapeInfoLength(rank), Nd4jLong); // [time x bS x numProj] - ALLOCATE(cShapeInfo, block.getWorkspace(), shape::shapeInfoLength(rank), Nd4jLong); // [time x bS x numUnits] - - hShapeInfo[0] = cShapeInfo[0] = rank; - hShapeInfo[1] = cShapeInfo[1] = time; - hShapeInfo[2] = cShapeInfo[2] = bS; - hShapeInfo[3] = numProj; - cShapeInfo[3] = numUnits; - - ShapeUtils::updateStridesAndType(hShapeInfo, xShapeInfo, shape::order(h0ShapeInfo)); - ShapeUtils::updateStridesAndType(cShapeInfo, xShapeInfo, shape::order(c0ShapeInfo)); - - return SHAPELIST(CONSTANT(hShapeInfo), CONSTANT(cShapeInfo)); + auto xShapeInfo = inputShape->at(0); // input [time x bS x inSize] + auto h0ShapeInfo = inputShape->at( + 1); // initial cell output (at time step = 0) [bS x numProj], in case of + // projection=false -> numProj == numUnits !!! + auto c0ShapeInfo = inputShape->at( + 2); // initial cell state (at time step = 0) [bS x numUnits], + + auto WxShapeInfo = + inputShape->at(3); // input-to-hidden weights, [inSize x 4*numUnits] + auto WhShapeInfo = + inputShape->at(4); // hidden-to-hidden weights, [numProj x 4*numUnits] + auto WcShapeInfo = inputShape->at( + 5); // diagonal weights for peephole connections [3*numUnits] + auto WpShapeInfo = + inputShape->at(6); // projection weights [numUnits x numProj] + auto bShapeInfo = inputShape->at(7); // biases, [4*numUnits] + + const int rank = xShapeInfo[0]; + const int time = xShapeInfo[1]; + const int bS = xShapeInfo[2]; + const int inSize = xShapeInfo[3]; + const int numProj = h0ShapeInfo[2]; + const int numUnits = c0ShapeInfo[2]; + + // input shapes validation + const std::vector correctH0Shape = {bS, numProj}; + const std::vector correctC0Shape = {bS, numUnits}; + const std::vector correctWxShape = {inSize, 4 * numUnits}; + const std::vector correctWhShape = {numProj, 4 * numUnits}; + const std::vector correctWcShape = {3 * numUnits}; + const std::vector correctWpShape = {numUnits, numProj}; + const std::vector correctBShape = {4 * numUnits}; + + REQUIRE_TRUE(ShapeUtils::areShapesEqual(h0ShapeInfo, correctH0Shape), 0, + "LSTM operation: wrong shape of initial cell output, expected " + "is %s, but got %s instead !", + ShapeUtils::shapeAsString(correctH0Shape).c_str(), + ShapeUtils::shapeAsString(h0ShapeInfo).c_str()); + REQUIRE_TRUE(ShapeUtils::areShapesEqual(c0ShapeInfo, correctC0Shape), 0, + "LSTM operation: wrong shape of initial cell state, expected " + "is %s, but got %s instead !", + ShapeUtils::shapeAsString(correctC0Shape).c_str(), + ShapeUtils::shapeAsString(c0ShapeInfo).c_str()); + REQUIRE_TRUE(ShapeUtils::areShapesEqual(WxShapeInfo, correctWxShape), 0, + "LSTM operation: wrong shape of input-to-hidden weights, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(correctWxShape).c_str(), + ShapeUtils::shapeAsString(WxShapeInfo).c_str()); + REQUIRE_TRUE(ShapeUtils::areShapesEqual(WhShapeInfo, correctWhShape), 0, + "LSTM operation: wrong shape of hidden-to-hidden weights, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(correctWhShape).c_str(), + ShapeUtils::shapeAsString(WhShapeInfo).c_str()); + REQUIRE_TRUE(ShapeUtils::areShapesEqual(WcShapeInfo, correctWcShape), 0, + "LSTM operation: wrong shape of diagonal weights for peephole " + "connections, expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(correctWcShape).c_str(), + ShapeUtils::shapeAsString(WcShapeInfo).c_str()); + REQUIRE_TRUE(ShapeUtils::areShapesEqual(WpShapeInfo, correctWpShape), 0, + "LSTM operation: wrong shape of projection weights, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString(correctWpShape).c_str(), + ShapeUtils::shapeAsString(WpShapeInfo).c_str()); + REQUIRE_TRUE(ShapeUtils::areShapesEqual(bShapeInfo, correctBShape), 0, + "LSTM operation: wrong shape of biases, expected is %s, but got " + "%s instead !", + ShapeUtils::shapeAsString(correctBShape).c_str(), + ShapeUtils::shapeAsString(bShapeInfo).c_str()); + + // evaluate output shapeInfos + Nd4jLong *hShapeInfo(nullptr), *cShapeInfo(nullptr); + ALLOCATE(hShapeInfo, block.workspace(), shape::shapeInfoLength(rank), + Nd4jLong); // [time x bS x numProj] + ALLOCATE(cShapeInfo, block.workspace(), shape::shapeInfoLength(rank), + Nd4jLong); // [time x bS x numUnits] + + hShapeInfo[0] = cShapeInfo[0] = rank; + hShapeInfo[1] = cShapeInfo[1] = time; + hShapeInfo[2] = cShapeInfo[2] = bS; + hShapeInfo[3] = numProj; + cShapeInfo[3] = numUnits; + + ShapeUtils::updateStridesAndType(hShapeInfo, xShapeInfo, + shape::order(h0ShapeInfo)); + ShapeUtils::updateStridesAndType(cShapeInfo, xShapeInfo, + shape::order(c0ShapeInfo)); + + return SHAPELIST(CONSTANT(hShapeInfo), CONSTANT(cShapeInfo)); } - - - - -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/nn/recurrent/lstmBlock.cpp b/libnd4j/include/ops/declarable/generic/nn/recurrent/lstmBlock.cpp index 1fd7ec8ccc18..56544d8f3144 100644 --- a/libnd4j/include/ops/declarable/generic/nn/recurrent/lstmBlock.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/recurrent/lstmBlock.cpp @@ -22,104 +22,135 @@ #if NOT_EXCLUDED(OP_lstmBlock) #include -#include +#include namespace sd { -namespace ops { - +namespace ops { ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(lstmBlock, 9, 7, false, 2, 2) { - auto maxTSLength = INPUT_VARIABLE(0); - auto x = INPUT_VARIABLE(1); // input [seqLen, bS, nIn] at time t - auto cLast = INPUT_VARIABLE(2); // previous cell state [bS, nOut], time t-1 - auto yLast = INPUT_VARIABLE(3); // previous output [bS, nOut], time t-1 - - auto W = INPUT_VARIABLE(4); // Weights - concatenated (input-to-hidden, hidden-to-hidden weights) weights, [(nIn+nOut), 4*nOut] - auto Wci = INPUT_VARIABLE(5); // weights - cell peephole (t-1) connections to input modulation gate, [nOut] - auto Wcf = INPUT_VARIABLE(6); // weights - cell peephole (t-1) connections to forget gate, [nOut] - auto Wco = INPUT_VARIABLE(7); // weights - cell peephole (t) connections to output gate, [nOut] - auto b = INPUT_VARIABLE(8); // biases, [4*nOut] - - auto i = OUTPUT_VARIABLE(0); // Output - input modulation gate activations [seqLen, bS, nOut] - auto c = OUTPUT_VARIABLE(1); // Activations, cell state (pre tanh) [seqLen, bs, nOut] - auto f = OUTPUT_VARIABLE(2); // Output - forget gate activations [seqLen, bs, nOut] - auto o = OUTPUT_VARIABLE(3); // Output - output gate activations [seqLen, bs, nOut] - auto z = OUTPUT_VARIABLE(4); // Output - input gate activations [seqLen, bs, nOut] - auto h = OUTPUT_VARIABLE(5); // Cell state, post tanh [seqLen, bs, nOut] - auto y = OUTPUT_VARIABLE(6); // current cell output [seqLen, bS, numProj], time t - - const int peephole = INT_ARG(0); // if 1, provide peephole connections - const int dataFormat = INT_ARG(1); // 0=TNS=[seqLen,bS,nIn]; 1=NST=[bS,nIn,seqLen]; 2=NTS=[bS,seqLen,nIn] - const double forgetBias = T_ARG(0); - const double clippingCellValue = T_ARG(1); // clipping value for ct, if it is not equal to zero, then cell state is clipped - - REQUIRE_TRUE(x->rankOf()==3, 0, "lstmBlock: Input array 1 (x) rank must be got input with rank %i", x->rankOf()); - REQUIRE_TRUE(cLast->rankOf()==2 && yLast->rankOf()==2, 0, "lstmBlock: Input ranks must be 2 for inputs 2/3 (cLast, yLast) - got %i, %i", cLast->rankOf(), yLast->rankOf()); - REQUIRE_TRUE(W->rankOf()==2, 0, "lstmBlock: Weights array rank must be 2"); - REQUIRE_TRUE(b->rankOf()==1, 0, "lstmBlock: Biases must be rank 1"); - REQUIRE_TRUE(i->rankOf()==3 && c->rankOf()==3 && f->rankOf()==3 && o->rankOf()==3 && z->rankOf()==3 && h->rankOf()==3 && y->rankOf()==3, - 0, "lstmBlock: Output arrays must all be rank 3"); - - helpers::lstmBlockTimeLoop(maxTSLength, x, cLast, yLast, W, Wci, Wcf, Wco, b, i, c, f, o, z, h, y, {(double)peephole, forgetBias, clippingCellValue}, dataFormat); - - return Status::OK(); + auto maxTSLength = INPUT_VARIABLE(0); + auto x = INPUT_VARIABLE(1); // input [seqLen, bS, nIn] at time t + auto cLast = INPUT_VARIABLE(2); // previous cell state [bS, nOut], time t-1 + auto yLast = INPUT_VARIABLE(3); // previous output [bS, nOut], time t-1 + + auto W = INPUT_VARIABLE( + 4); // Weights - concatenated (input-to-hidden, hidden-to-hidden weights) + // weights, [(nIn+nOut), 4*nOut] + auto Wci = INPUT_VARIABLE(5); // weights - cell peephole (t-1) connections to + // input modulation gate, [nOut] + auto Wcf = INPUT_VARIABLE( + 6); // weights - cell peephole (t-1) connections to forget gate, [nOut] + auto Wco = INPUT_VARIABLE( + 7); // weights - cell peephole (t) connections to output gate, [nOut] + auto b = INPUT_VARIABLE(8); // biases, [4*nOut] + + auto i = OUTPUT_VARIABLE( + 0); // Output - input modulation gate activations [seqLen, bS, nOut] + auto c = OUTPUT_VARIABLE( + 1); // Activations, cell state (pre tanh) [seqLen, bs, nOut] + auto f = OUTPUT_VARIABLE( + 2); // Output - forget gate activations [seqLen, bs, nOut] + auto o = OUTPUT_VARIABLE( + 3); // Output - output gate activations [seqLen, bs, nOut] + auto z = + OUTPUT_VARIABLE(4); // Output - input gate activations [seqLen, bs, nOut] + auto h = OUTPUT_VARIABLE(5); // Cell state, post tanh [seqLen, bs, nOut] + auto y = + OUTPUT_VARIABLE(6); // current cell output [seqLen, bS, numProj], time t + + const int peephole = INT_ARG(0); // if 1, provide peephole connections + const int dataFormat = + INT_ARG(1); // 0=TNS=[seqLen,bS,nIn]; 1=NST=[bS,nIn,seqLen]; + // 2=NTS=[bS,seqLen,nIn] + const double forgetBias = T_ARG(0); + const double clippingCellValue = + T_ARG(1); // clipping value for ct, if it is not equal to zero, then cell + // state is clipped + + REQUIRE_TRUE( + x->rankOf() == 3, 0, + "lstmBlock: Input array 1 (x) rank must be got input with rank %i", + x->rankOf()); + REQUIRE_TRUE(cLast->rankOf() == 2 && yLast->rankOf() == 2, 0, + "lstmBlock: Input ranks must be 2 for inputs 2/3 (cLast, yLast) " + "- got %i, %i", + cLast->rankOf(), yLast->rankOf()); + REQUIRE_TRUE(W->rankOf() == 2, 0, "lstmBlock: Weights array rank must be 2"); + REQUIRE_TRUE(b->rankOf() == 1, 0, "lstmBlock: Biases must be rank 1"); + REQUIRE_TRUE(i->rankOf() == 3 && c->rankOf() == 3 && f->rankOf() == 3 && + o->rankOf() == 3 && z->rankOf() == 3 && h->rankOf() == 3 && + y->rankOf() == 3, + 0, "lstmBlock: Output arrays must all be rank 3"); + + helpers::lstmBlockTimeLoop( + maxTSLength, x, cLast, yLast, W, Wci, Wcf, Wco, b, i, c, f, o, z, h, y, + {(double)peephole, forgetBias, clippingCellValue}, dataFormat); + + return Status::OK(); } DECLARE_TYPES(lstmBlock) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } - DECLARE_SHAPE_FN(lstmBlock) { - auto x = inputShape->at(1); - auto cLast = inputShape->at(2); - auto yLast = inputShape->at(3); - auto W = inputShape->at(4); - auto b = inputShape->at(8); - - REQUIRE_TRUE(shape::rank(x)==3, 0, "lstmBlock: Input array 1 (x) rank must be got input with rank %i", shape::rank(x)); - REQUIRE_TRUE(shape::rank(cLast)==2 && shape::rank(yLast)==2, 0, "lstmBlock: Input ranks must be 2 for inputs 2/3 (cLast, yLast) - got %i, %i", shape::rank(cLast), shape::rank(yLast)); - REQUIRE_TRUE(shape::rank(W)==2, 0, "lstmBlock: Weights array rank must be 2"); - REQUIRE_TRUE(shape::rank(b)==1, 0, "lstmBlock: Biases must be rank 1"); - - - const int dataFormat = INT_ARG(1); // 0=TNS=[seqLen,bS,size]; 1=NST=[bS,size,seqLen]; 2=NTS=[bS,seqLen,size] - int bs; - int t; - int nOut = cLast[2]; //rank, bs, nOut, ...] - - Nd4jLong *s(nullptr); - ALLOCATE(s, block.getWorkspace(), shape::shapeInfoLength(3), Nd4jLong); // [time, bS, nOut] - s[0] = 3; - if(dataFormat == 0){ - //[rank, seqLen, bs, nIn, ...] - s[1] = x[1]; //seqLen - s[2] = x[2]; //bS - s[3] = nOut; - } else if(dataFormat==1){ - //[rank, bs, nIn, seqLen, ...] - s[1] = x[1]; //bS - s[2] = nOut; - s[3] = x[3]; //seqLen - } else { - //[rank, bs, seqLen, nIn, ...] - s[1] = x[1]; //bS - s[2] = x[2]; //seqLen - s[3] = nOut; - - } - ShapeUtils::updateStridesAndType(s, x, 'c'); - - auto s1 = CONSTANT(s); - - //7 outputs, all same shape/type - return SHAPELIST(s1, s1, s1, s1, s1, s1, s1); + auto x = inputShape->at(1); + auto cLast = inputShape->at(2); + auto yLast = inputShape->at(3); + auto W = inputShape->at(4); + auto b = inputShape->at(8); + + REQUIRE_TRUE( + shape::rank(x) == 3, 0, + "lstmBlock: Input array 1 (x) rank must be got input with rank %i", + shape::rank(x)); + REQUIRE_TRUE(shape::rank(cLast) == 2 && shape::rank(yLast) == 2, 0, + "lstmBlock: Input ranks must be 2 for inputs 2/3 (cLast, yLast) " + "- got %i, %i", + shape::rank(cLast), shape::rank(yLast)); + REQUIRE_TRUE(shape::rank(W) == 2, 0, + "lstmBlock: Weights array rank must be 2"); + REQUIRE_TRUE(shape::rank(b) == 1, 0, "lstmBlock: Biases must be rank 1"); + + const int dataFormat = + INT_ARG(1); // 0=TNS=[seqLen,bS,size]; 1=NST=[bS,size,seqLen]; + // 2=NTS=[bS,seqLen,size] + int bs; + int t; + int nOut = cLast[2]; // rank, bs, nOut, ...] + + Nd4jLong *s(nullptr); + ALLOCATE(s, block.workspace(), shape::shapeInfoLength(3), + Nd4jLong); // [time, bS, nOut] + s[0] = 3; + if (dataFormat == 0) { + //[rank, seqLen, bs, nIn, ...] + s[1] = x[1]; // seqLen + s[2] = x[2]; // bS + s[3] = nOut; + } else if (dataFormat == 1) { + //[rank, bs, nIn, seqLen, ...] + s[1] = x[1]; // bS + s[2] = nOut; + s[3] = x[3]; // seqLen + } else { + //[rank, bs, seqLen, nIn, ...] + s[1] = x[1]; // bS + s[2] = x[2]; // seqLen + s[3] = nOut; + } + ShapeUtils::updateStridesAndType(s, x, 'c'); + + auto s1 = CONSTANT(s); + + // 7 outputs, all same shape/type + return SHAPELIST(s1, s1, s1, s1, s1, s1, s1); } -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/nn/recurrent/lstmBlockCell.cpp b/libnd4j/include/ops/declarable/generic/nn/recurrent/lstmBlockCell.cpp index 55d3a6b7a4c8..04f9f133d2c9 100644 --- a/libnd4j/include/ops/declarable/generic/nn/recurrent/lstmBlockCell.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/recurrent/lstmBlockCell.cpp @@ -22,106 +22,165 @@ #if NOT_EXCLUDED(OP_lstmBlockCell) #include -#include +#include namespace sd { namespace ops { - ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(lstmBlockCell, 8, 7, false, 2, 1) { - //Notation: mostly following https://arxiv.org/pdf/1503.04069.pdf - auto xt = INPUT_VARIABLE(0); // input [bS, inSize] at time t - auto cLast = INPUT_VARIABLE(1); // previous cell state [bS, numUnits], time t-1 - auto yLast = INPUT_VARIABLE(2); // previous output [bS, numUnits], time t-1 - - auto W = INPUT_VARIABLE(3); // Weights - concatenated (input-to-hidden, hidden-to-hidden weights) weights, [(inSize+numUnits), 4*numUnits] - auto Wci = INPUT_VARIABLE(4); // weights - cell peephole (t-1) connections to input modulation gate, [numUnits] - auto Wcf = INPUT_VARIABLE(5); // weights - cell peephole (t-1) connections to forget gate, [numUnits] - auto Wco = INPUT_VARIABLE(6); // weights - cell peephole (t) connections to output gate, [numUnits] - auto b = INPUT_VARIABLE(7); // biases, [4*numUnits] - - auto i = OUTPUT_VARIABLE(0); // Output - input modulation gate activations [bS, numUnits] - auto c = OUTPUT_VARIABLE(1); // Activations, cell state (pre tanh) [bs, numUnits] - auto f = OUTPUT_VARIABLE(2); // Output - forget gate activations [bs, numUnits] - auto o = OUTPUT_VARIABLE(3); // Output - output gate activations [bs, numUnits] - auto z = OUTPUT_VARIABLE(4); // Output - input gate activations [bs, numUnits] - auto h = OUTPUT_VARIABLE(5); // Cell state, post tanh [bs, numUnits] - auto y = OUTPUT_VARIABLE(6); // current cell output [bS, numProj], time t - - - const int peephole = INT_ARG(0); // if 1, provide peephole connections - - const double forgetBias = T_ARG(0); - const double clippingCellValue = T_ARG(1); // clipping value for ct, if it is not equal to zero, then cell state is clipped - - REQUIRE_TRUE(xt->rankOf()==2 && cLast->rankOf()==2 && yLast->rankOf()==2, 0, "lstmBlockCell: Input ranks must be 2 for inputs 0/1/2 (x, cLast, outLast) - got %i, %i, %i", xt->rankOf(), cLast->rankOf(), yLast->rankOf()); - const int rank = xt->rankOf(); - const int bS = xt->sizeAt(0); - const int inSize = xt->sizeAt(1); - const int numUnits = cLast->sizeAt(1); - - REQUIRE_TRUE(xt->sizeAt(0) == yLast->sizeAt(0) && xt->sizeAt(0) == cLast->sizeAt(0), 0, "lstmBlockCell: Input minibatch sizes (dimension 0) must be same for xt, cLast, yLast"); - REQUIRE_TRUE(W->rankOf()==2, 0, "lstmBlockCell: Weights array rank must be 2"); - REQUIRE_TRUE(W->sizeAt(0)==(inSize+numUnits), 0, "lstmBlockCell: Weights size(0) must be equal to inSize + numUnits, got %i", W->sizeAt(0)); - REQUIRE_TRUE(W->sizeAt(1)==(4*numUnits), 0, "lstmBlockCell: Weights size(1) must be equal to 4*numUnits, got %i", W->sizeAt(1)); - REQUIRE_TRUE(b->rankOf()==1 && b->sizeAt(0)==(4*numUnits), 0, "lstmBlockCell: Biases must be rank 1, size 4*numUnits"); - REQUIRE_TRUE(i->rankOf()==2 && c->rankOf()==2 && f->rankOf()==2 && o->rankOf()==2 && z->rankOf()==2 && h->rankOf()==2 && y->rankOf()==2 && - i->sizeAt(0)==bS && c->sizeAt(0)==bS && f->sizeAt(0)==bS && o->sizeAt(0)==bS && z->sizeAt(0)==bS && h->sizeAt(0)==bS && y->sizeAt(0)==bS && - i->sizeAt(1)==numUnits && c->sizeAt(1)==numUnits && f->sizeAt(1)==numUnits && o->sizeAt(1)==numUnits && z->sizeAt(1)==numUnits && h->sizeAt(1)==numUnits && y->sizeAt(1)==numUnits, - 0, "lstmBlockCell: Output arrays must all be rank 2 with size(0) == batchSize and size(1) == numUnits"); - - // calculations - helpers::lstmBlockCell(xt, cLast, yLast, W, Wci, Wcf, Wco, b, i, c, f, o, z, h, y, {(double)peephole, forgetBias, clippingCellValue}); - - return Status::OK(); + // Notation: mostly following https://arxiv.org/pdf/1503.04069.pdf + auto xt = INPUT_VARIABLE(0); // input [bS, inSize] at time t + auto cLast = + INPUT_VARIABLE(1); // previous cell state [bS, numUnits], time t-1 + auto yLast = INPUT_VARIABLE(2); // previous output [bS, numUnits], time t-1 + + auto W = INPUT_VARIABLE( + 3); // Weights - concatenated (input-to-hidden, hidden-to-hidden weights) + // weights, [(inSize+numUnits), 4*numUnits] + auto Wci = INPUT_VARIABLE(4); // weights - cell peephole (t-1) connections to + // input modulation gate, [numUnits] + auto Wcf = INPUT_VARIABLE(5); // weights - cell peephole (t-1) connections to + // forget gate, [numUnits] + auto Wco = INPUT_VARIABLE( + 6); // weights - cell peephole (t) connections to output gate, [numUnits] + auto b = INPUT_VARIABLE(7); // biases, [4*numUnits] + + auto i = OUTPUT_VARIABLE( + 0); // Output - input modulation gate activations [bS, numUnits] + auto c = + OUTPUT_VARIABLE(1); // Activations, cell state (pre tanh) [bs, numUnits] + auto f = + OUTPUT_VARIABLE(2); // Output - forget gate activations [bs, numUnits] + auto o = + OUTPUT_VARIABLE(3); // Output - output gate activations [bs, numUnits] + auto z = + OUTPUT_VARIABLE(4); // Output - input gate activations [bs, numUnits] + auto h = OUTPUT_VARIABLE(5); // Cell state, post tanh [bs, numUnits] + auto y = OUTPUT_VARIABLE(6); // current cell output [bS, numProj], time t + + const int peephole = INT_ARG(0); // if 1, provide peephole connections + + const double forgetBias = T_ARG(0); + const double clippingCellValue = + T_ARG(1); // clipping value for ct, if it is not equal to zero, then cell + // state is clipped + + REQUIRE_TRUE( + xt->rankOf() == 2 && cLast->rankOf() == 2 && yLast->rankOf() == 2, 0, + "lstmBlockCell: Input ranks must be 2 for inputs 0/1/2 (x, cLast, " + "outLast) - got %i, %i, %i", + xt->rankOf(), cLast->rankOf(), yLast->rankOf()); + const int rank = xt->rankOf(); + const int bS = xt->sizeAt(0); + const int inSize = xt->sizeAt(1); + const int numUnits = cLast->sizeAt(1); + + REQUIRE_TRUE( + xt->sizeAt(0) == yLast->sizeAt(0) && xt->sizeAt(0) == cLast->sizeAt(0), 0, + "lstmBlockCell: Input minibatch sizes (dimension 0) must be same for xt, " + "cLast, yLast"); + REQUIRE_TRUE(W->rankOf() == 2, 0, + "lstmBlockCell: Weights array rank must be 2"); + REQUIRE_TRUE(W->sizeAt(0) == (inSize + numUnits), 0, + "lstmBlockCell: Weights size(0) must be equal to inSize + " + "numUnits, got %i", + W->sizeAt(0)); + REQUIRE_TRUE( + W->sizeAt(1) == (4 * numUnits), 0, + "lstmBlockCell: Weights size(1) must be equal to 4*numUnits, got %i", + W->sizeAt(1)); + REQUIRE_TRUE(b->rankOf() == 1 && b->sizeAt(0) == (4 * numUnits), 0, + "lstmBlockCell: Biases must be rank 1, size 4*numUnits"); + REQUIRE_TRUE(i->rankOf() == 2 && c->rankOf() == 2 && f->rankOf() == 2 && + o->rankOf() == 2 && z->rankOf() == 2 && h->rankOf() == 2 && + y->rankOf() == 2 && i->sizeAt(0) == bS && + c->sizeAt(0) == bS && f->sizeAt(0) == bS && + o->sizeAt(0) == bS && z->sizeAt(0) == bS && + h->sizeAt(0) == bS && y->sizeAt(0) == bS && + i->sizeAt(1) == numUnits && c->sizeAt(1) == numUnits && + f->sizeAt(1) == numUnits && o->sizeAt(1) == numUnits && + z->sizeAt(1) == numUnits && h->sizeAt(1) == numUnits && + y->sizeAt(1) == numUnits, + 0, + "lstmBlockCell: Output arrays must all be rank 2 with size(0) " + "== batchSize and size(1) == numUnits"); + + // calculations + helpers::lstmBlockCell(xt, cLast, yLast, W, Wci, Wcf, Wco, b, i, c, f, o, z, + h, y, + {(double)peephole, forgetBias, clippingCellValue}); + + return Status::OK(); } DECLARE_TYPES(lstmBlockCell) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } - DECLARE_SHAPE_FN(lstmBlockCell) { - auto xt = inputShape->at(0); // input [bS, inSize] at time t - auto cLast = inputShape->at(1); // previous cell state [bS, numUnits], time t-1 - auto yLast = inputShape->at(2); // previous output [bS, numUnits], time t-1 - - auto W = inputShape->at(3); // Weights - concatenated (input-to-hidden, hidden-to-hidden weights) weights, [(inSize+numUnits), 4*numUnits] - auto Wci = inputShape->at(4); // weights - cell peephole (t-1) connections to input modulation gate, [numUnits] - auto Wcf = inputShape->at(5); // weights - cell peephole (t-1) connections to forget gate, [numUnits] - auto Wco = inputShape->at(6); // weights - cell peephole (t) connections to output gate, [numUnits] - auto b = inputShape->at(7); // biases, [4*numUnits] - - REQUIRE_TRUE(shape::rank(xt)==2 && shape::rank(cLast)==2 && shape::rank(yLast)==2, 0, "lstmBlockCell: Input ranks must be 2 for inputs 0/1/2 (x, cLast, outLast) - got %i, %i, %i", shape::rank(xt), shape::rank(cLast), shape::rank(yLast)); - const int inSize = xt[2]; - const int numUnits = cLast[2]; //[rank, bS, nOut, ...] - REQUIRE_TRUE(xt[1] == yLast[1] && xt[1] == cLast[1], 0, "lstmBlockCell: Input minibatch sizes (dimension 0) must be same for xt, cLast, yLast"); - REQUIRE_TRUE(shape::rank(W)==2, 0, "lstmBlockCell: Weights array rank must be rank 2, got %i", shape::rank(W)); - REQUIRE_TRUE(W[1]==(inSize+numUnits), 0, "lstmBlockCell: Weights size(0) must be equal to inSize + numUnits, got %i", W[1]); - REQUIRE_TRUE(W[2]==(4*numUnits), 0, "lstmBlockCell: Weights size(1) must be equal to 4*numUnits, got %i", W[2]); - REQUIRE_TRUE(shape::rank(b)==1 && b[1]==(4*numUnits), 0, "lstmBlockCell: Biases must be rank 1, size 4*numUnits"); - - // evaluate output shapeInfos - const int bS = xt[1]; - Nd4jLong *s(nullptr); - ALLOCATE(s, block.getWorkspace(), shape::shapeInfoLength(2), Nd4jLong); // [bS, numUnits] - - s[0] = 2; - s[1] = bS; - s[2] = numUnits; - - ShapeUtils::updateStridesAndType(s, xt, 'c'); - - auto s1 = CONSTANT(s); - - //7 outputs, all same shape: z, i, f, o, h, c, y - return SHAPELIST(s1, s1, s1, s1, s1, s1, s1); + auto xt = inputShape->at(0); // input [bS, inSize] at time t + auto cLast = + inputShape->at(1); // previous cell state [bS, numUnits], time t-1 + auto yLast = inputShape->at(2); // previous output [bS, numUnits], time t-1 + + auto W = inputShape->at( + 3); // Weights - concatenated (input-to-hidden, hidden-to-hidden weights) + // weights, [(inSize+numUnits), 4*numUnits] + auto Wci = inputShape->at(4); // weights - cell peephole (t-1) connections to + // input modulation gate, [numUnits] + auto Wcf = inputShape->at(5); // weights - cell peephole (t-1) connections to + // forget gate, [numUnits] + auto Wco = inputShape->at( + 6); // weights - cell peephole (t) connections to output gate, [numUnits] + auto b = inputShape->at(7); // biases, [4*numUnits] + + REQUIRE_TRUE(shape::rank(xt) == 2 && shape::rank(cLast) == 2 && + shape::rank(yLast) == 2, + 0, + "lstmBlockCell: Input ranks must be 2 for inputs 0/1/2 (x, " + "cLast, outLast) - got %i, %i, %i", + shape::rank(xt), shape::rank(cLast), shape::rank(yLast)); + const int inSize = xt[2]; + const int numUnits = cLast[2]; //[rank, bS, nOut, ...] + REQUIRE_TRUE(xt[1] == yLast[1] && xt[1] == cLast[1], 0, + "lstmBlockCell: Input minibatch sizes (dimension 0) must be " + "same for xt, cLast, yLast"); + REQUIRE_TRUE(shape::rank(W) == 2, 0, + "lstmBlockCell: Weights array rank must be rank 2, got %i", + shape::rank(W)); + REQUIRE_TRUE(W[1] == (inSize + numUnits), 0, + "lstmBlockCell: Weights size(0) must be equal to inSize + " + "numUnits, got %i", + W[1]); + REQUIRE_TRUE( + W[2] == (4 * numUnits), 0, + "lstmBlockCell: Weights size(1) must be equal to 4*numUnits, got %i", + W[2]); + REQUIRE_TRUE(shape::rank(b) == 1 && b[1] == (4 * numUnits), 0, + "lstmBlockCell: Biases must be rank 1, size 4*numUnits"); + + // evaluate output shapeInfos + const int bS = xt[1]; + Nd4jLong *s(nullptr); + ALLOCATE(s, block.workspace(), shape::shapeInfoLength(2), + Nd4jLong); // [bS, numUnits] + + s[0] = 2; + s[1] = bS; + s[2] = numUnits; + + ShapeUtils::updateStridesAndType(s, xt, 'c'); + + auto s1 = CONSTANT(s); + + // 7 outputs, all same shape: z, i, f, o, h, c, y + return SHAPELIST(s1, s1, s1, s1, s1, s1, s1); } -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/nn/recurrent/lstmCell.cpp b/libnd4j/include/ops/declarable/generic/nn/recurrent/lstmCell.cpp index 32cb481eeeca..234d3f01a495 100644 --- a/libnd4j/include/ops/declarable/generic/nn/recurrent/lstmCell.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/recurrent/lstmCell.cpp @@ -22,127 +22,214 @@ #if NOT_EXCLUDED(OP_lstmCell) #include -#include +#include namespace sd { namespace ops { - ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(lstmCell, 8, 2, false, 3, 2) { - auto xt = INPUT_VARIABLE(0); // input [bS x inSize] - auto ht_1 = INPUT_VARIABLE(1); // previous cell output [bS x numProj], that is at previous time step t-1, in case of projection=false -> numProj=numUnits!!! - auto ct_1 = INPUT_VARIABLE(2); // previous cell state [bS x numUnits], that is at previous time step t-1 - - auto Wx = INPUT_VARIABLE(3); // input-to-hidden weights, [inSize x 4*numUnits] - auto Wh = INPUT_VARIABLE(4); // hidden-to-hidden weights, [numProj x 4*numUnits] - auto Wc = INPUT_VARIABLE(5); // diagonal weights for peephole connections [3*numUnits] - auto Wp = INPUT_VARIABLE(6); // projection weights [numUnits x numProj] - auto b = INPUT_VARIABLE(7); // biases, [4*numUnits] - - auto ht = OUTPUT_VARIABLE(0); // current cell output [bS x numProj], that is at current time step t - auto ct = OUTPUT_VARIABLE(1); // current cell state [bS x numUnits], that is at current time step t - - const int peephole = INT_ARG(0); // if 1, provide peephole connections - const int projection = INT_ARG(1); // if 1, then projection is performed, if false then numProj==numUnits is mandatory!!!! - - // FIXME: double? - const double clippingCellValue = T_ARG(0); // clipping value for ct, if it is not equal to zero, then cell state is clipped - const double clippingProjValue = T_ARG(1); // clipping value for projected ht, if it is not equal to zero, then projected cell output is clipped - const double forgetBias = T_ARG(2); - - const int rank = xt->rankOf(); - const int bS = xt->sizeAt(0); - const int inSize = xt->sizeAt(1); - const int numProj = ht_1->sizeAt(1); - const int numUnits = ct_1->sizeAt(1); - - // input shapes validation - const std::vector correctHt_1Shape = {bS, numProj}; - const std::vector correctCt_1Shape = {bS, numUnits}; - const std::vector correctWxShape = {inSize, 4*numUnits}; - const std::vector correctWhShape = {numProj, 4*numUnits}; - const std::vector correctWcShape = {3*numUnits}; - const std::vector correctWpShape = {numUnits, numProj}; - const std::vector correctBShape = {4*numUnits}; - - REQUIRE_TRUE(ht_1->isSameShape(correctHt_1Shape), 0, "LSTMCELL operation: wrong shape of initial cell output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctHt_1Shape).c_str(), ShapeUtils::shapeAsString(ht_1).c_str()); - REQUIRE_TRUE(ct_1->isSameShape(correctCt_1Shape), 0, "LSTMCELL operation: wrong shape of initial cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctCt_1Shape).c_str(), ShapeUtils::shapeAsString(ct_1).c_str()); - REQUIRE_TRUE(Wx->isSameShape(correctWxShape), 0, "LSTMCELL operation: wrong shape of input-to-hidden weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctWxShape).c_str(), ShapeUtils::shapeAsString(Wx).c_str()); - REQUIRE_TRUE(Wh->isSameShape(correctWhShape), 0, "LSTMCELL operation: wrong shape of hidden-to-hidden weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctWhShape).c_str(), ShapeUtils::shapeAsString(Wh).c_str()); - REQUIRE_TRUE(Wc->isSameShape(correctWcShape), 0, "LSTMCELL operation: wrong shape of diagonal weights for peephole connections, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctWcShape).c_str(), ShapeUtils::shapeAsString(Wc).c_str()); - REQUIRE_TRUE(Wp->isSameShape(correctWpShape), 0, "LSTMCELL operation: wrong shape of projection weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctWpShape).c_str(), ShapeUtils::shapeAsString(Wp).c_str()); - REQUIRE_TRUE(b->isSameShape(correctBShape), 0, "LSTMCELL operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctBShape).c_str(), ShapeUtils::shapeAsString(b).c_str()); - REQUIRE_TRUE(!(!projection && numUnits != numProj), 0, "LSTMCELL operation: projection option is switched of, and in this case output dimensionality for the projection matrices (numProj) must be equal to number of units in lstmCell !"); - - // calculations - helpers::lstmCell(block.launchContext(), xt,ht_1,ct_1, Wx,Wh,Wc,Wp, b, ht,ct, {(double)peephole, (double)projection, clippingCellValue, clippingProjValue, forgetBias}); - - return Status::OK(); + auto xt = INPUT_VARIABLE(0); // input [bS x inSize] + auto ht_1 = INPUT_VARIABLE( + 1); // previous cell output [bS x numProj], that is at previous time + // step t-1, in case of projection=false -> numProj=numUnits!!! + auto ct_1 = INPUT_VARIABLE(2); // previous cell state [bS x numUnits], that + // is at previous time step t-1 + + auto Wx = + INPUT_VARIABLE(3); // input-to-hidden weights, [inSize x 4*numUnits] + auto Wh = + INPUT_VARIABLE(4); // hidden-to-hidden weights, [numProj x 4*numUnits] + auto Wc = INPUT_VARIABLE( + 5); // diagonal weights for peephole connections [3*numUnits] + auto Wp = INPUT_VARIABLE(6); // projection weights [numUnits x numProj] + auto b = INPUT_VARIABLE(7); // biases, [4*numUnits] + + auto ht = OUTPUT_VARIABLE( + 0); // current cell output [bS x numProj], that is at current time step t + auto ct = OUTPUT_VARIABLE(1); // current cell state [bS x numUnits], that is + // at current time step t + + const int peephole = INT_ARG(0); // if 1, provide peephole connections + const int projection = + INT_ARG(1); // if 1, then projection is performed, if false then + // numProj==numUnits is mandatory!!!! + + // FIXME: double? + const double clippingCellValue = + T_ARG(0); // clipping value for ct, if it is not equal to zero, then cell + // state is clipped + const double clippingProjValue = + T_ARG(1); // clipping value for projected ht, if it is not equal to zero, + // then projected cell output is clipped + const double forgetBias = T_ARG(2); + + const int rank = xt->rankOf(); + const int bS = xt->sizeAt(0); + const int inSize = xt->sizeAt(1); + const int numProj = ht_1->sizeAt(1); + const int numUnits = ct_1->sizeAt(1); + + // input shapes validation + const std::vector correctHt_1Shape = {bS, numProj}; + const std::vector correctCt_1Shape = {bS, numUnits}; + const std::vector correctWxShape = {inSize, 4 * numUnits}; + const std::vector correctWhShape = {numProj, 4 * numUnits}; + const std::vector correctWcShape = {3 * numUnits}; + const std::vector correctWpShape = {numUnits, numProj}; + const std::vector correctBShape = {4 * numUnits}; + + REQUIRE_TRUE(ht_1->isSameShape(correctHt_1Shape), 0, + "LSTMCELL operation: wrong shape of initial cell output, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(correctHt_1Shape).c_str(), + ShapeUtils::shapeAsString(ht_1).c_str()); + REQUIRE_TRUE(ct_1->isSameShape(correctCt_1Shape), 0, + "LSTMCELL operation: wrong shape of initial cell state, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(correctCt_1Shape).c_str(), + ShapeUtils::shapeAsString(ct_1).c_str()); + REQUIRE_TRUE(Wx->isSameShape(correctWxShape), 0, + "LSTMCELL operation: wrong shape of input-to-hidden weights, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(correctWxShape).c_str(), + ShapeUtils::shapeAsString(Wx).c_str()); + REQUIRE_TRUE(Wh->isSameShape(correctWhShape), 0, + "LSTMCELL operation: wrong shape of hidden-to-hidden weights, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(correctWhShape).c_str(), + ShapeUtils::shapeAsString(Wh).c_str()); + REQUIRE_TRUE(Wc->isSameShape(correctWcShape), 0, + "LSTMCELL operation: wrong shape of diagonal weights for " + "peephole connections, expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(correctWcShape).c_str(), + ShapeUtils::shapeAsString(Wc).c_str()); + REQUIRE_TRUE(Wp->isSameShape(correctWpShape), 0, + "LSTMCELL operation: wrong shape of projection weights, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(correctWpShape).c_str(), + ShapeUtils::shapeAsString(Wp).c_str()); + REQUIRE_TRUE(b->isSameShape(correctBShape), 0, + "LSTMCELL operation: wrong shape of biases, expected is %s, but " + "got %s instead !", + ShapeUtils::shapeAsString(correctBShape).c_str(), + ShapeUtils::shapeAsString(b).c_str()); + REQUIRE_TRUE(!(!projection && numUnits != numProj), 0, + "LSTMCELL operation: projection option is switched of, and in " + "this case output dimensionality for the projection matrices " + "(numProj) must be equal to number of units in lstmCell !"); + + // calculations + helpers::lstmCell(block.launchContext(), xt, ht_1, ct_1, Wx, Wh, Wc, Wp, b, + ht, ct, + {(double)peephole, (double)projection, clippingCellValue, + clippingProjValue, forgetBias}); + + return Status::OK(); } - DECLARE_TYPES(lstmCell) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } - +DECLARE_TYPES(lstmCell) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} DECLARE_SHAPE_FN(lstmCell) { - - auto xtShapeInfo = inputShape->at(0); // input [bS x inSize] - auto ht_1ShapeInfo = inputShape->at(1); // previous cell output [bS x numProj], that is at previous time step t-1, in case of projection=false -> numProj=numUnits!!! - auto ct_1ShapeInfo = inputShape->at(2); // previous cell state [bS x numUnits], that is at previous time step t-1 - - auto WxShapeInfo = inputShape->at(3); // input-to-hidden weights, [inSize x 4*numUnits] - auto WhShapeInfo = inputShape->at(4); // hidden-to-hidden weights, [numProj x 4*numUnits] - auto WcShapeInfo = inputShape->at(5); // diagonal weights for peephole connections [3*numUnits] - auto WpShapeInfo = inputShape->at(6); // projection weights [numUnits x numProj] - auto bShapeInfo = inputShape->at(7); // biases, [4*numUnits] - - const int rank = shape::rank(xtShapeInfo); - const auto bS = xtShapeInfo[1]; - const auto inSize = xtShapeInfo[2]; - const auto numProj = ht_1ShapeInfo[2]; - const auto numUnits = ct_1ShapeInfo[2]; - - // input shapes validation - const std::vector correctHt_1Shape = {bS, numProj}; - const std::vector correctCt_1Shape = {bS, numUnits}; - const std::vector correctWxShape = {inSize, 4*numUnits}; - const std::vector correctWhShape = {numProj, 4*numUnits}; - const std::vector correctWcShape = {3*numUnits}; - const std::vector correctWpShape = {numUnits, numProj}; - const std::vector correctBShape = {4*numUnits}; - - REQUIRE_TRUE(ShapeUtils::areShapesEqual(ht_1ShapeInfo, correctHt_1Shape), 0, "LSTMCELL operation: wrong shape of initial cell output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctHt_1Shape).c_str(), ShapeUtils::shapeAsString(ht_1ShapeInfo).c_str()); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(ct_1ShapeInfo, correctCt_1Shape), 0, "LSTMCELL operation: wrong shape of initial cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctCt_1Shape).c_str(), ShapeUtils::shapeAsString(ct_1ShapeInfo).c_str()); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(WxShapeInfo, correctWxShape), 0, "LSTMCELL operation: wrong shape of input-to-hidden weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctWxShape).c_str(), ShapeUtils::shapeAsString(WxShapeInfo).c_str()); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(WhShapeInfo, correctWhShape), 0, "LSTMCELL operation: wrong shape of hidden-to-hidden weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctWhShape).c_str(), ShapeUtils::shapeAsString(WhShapeInfo).c_str()); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(WcShapeInfo, correctWcShape), 0, "LSTMCELL operation: wrong shape of diagonal weights for peephole connections, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctWcShape).c_str(), ShapeUtils::shapeAsString(WcShapeInfo).c_str()); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(WpShapeInfo, correctWpShape), 0, "LSTMCELL operation: wrong shape of projection weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctWpShape).c_str(), ShapeUtils::shapeAsString(WpShapeInfo).c_str()); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(bShapeInfo, correctBShape), 0, "LSTMCELL operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctBShape).c_str(), ShapeUtils::shapeAsString(bShapeInfo).c_str()); - - // evaluate output shapeInfos - Nd4jLong *hShapeInfo(nullptr), *cShapeInfo(nullptr); - ALLOCATE(hShapeInfo, block.getWorkspace(), shape::shapeInfoLength(rank), Nd4jLong); // [bS x numProj] - ALLOCATE(cShapeInfo, block.getWorkspace(), shape::shapeInfoLength(rank), Nd4jLong); // [bS x numUnits] - - hShapeInfo[0] = cShapeInfo[0] = rank; - hShapeInfo[1] = cShapeInfo[1] = bS; - hShapeInfo[2] = numProj; - cShapeInfo[2] = numUnits; - - ShapeUtils::updateStridesAndType(hShapeInfo, xtShapeInfo, shape::order(ht_1ShapeInfo)); - ShapeUtils::updateStridesAndType(cShapeInfo, xtShapeInfo, shape::order(ct_1ShapeInfo)); - - auto result = SHAPELIST(ConstantShapeHelper::getInstance().createShapeInfo(hShapeInfo), ConstantShapeHelper::getInstance().createShapeInfo(cShapeInfo)); - RELEASE(hShapeInfo, block.workspace()); - RELEASE(cShapeInfo, block.workspace()); - return result; + auto xtShapeInfo = inputShape->at(0); // input [bS x inSize] + auto ht_1ShapeInfo = inputShape->at( + 1); // previous cell output [bS x numProj], that is at previous time + // step t-1, in case of projection=false -> numProj=numUnits!!! + auto ct_1ShapeInfo = + inputShape->at(2); // previous cell state [bS x numUnits], that is at + // previous time step t-1 + + auto WxShapeInfo = + inputShape->at(3); // input-to-hidden weights, [inSize x 4*numUnits] + auto WhShapeInfo = + inputShape->at(4); // hidden-to-hidden weights, [numProj x 4*numUnits] + auto WcShapeInfo = inputShape->at( + 5); // diagonal weights for peephole connections [3*numUnits] + auto WpShapeInfo = + inputShape->at(6); // projection weights [numUnits x numProj] + auto bShapeInfo = inputShape->at(7); // biases, [4*numUnits] + + const int rank = shape::rank(xtShapeInfo); + const auto bS = xtShapeInfo[1]; + const auto inSize = xtShapeInfo[2]; + const auto numProj = ht_1ShapeInfo[2]; + const auto numUnits = ct_1ShapeInfo[2]; + + // input shapes validation + const std::vector correctHt_1Shape = {bS, numProj}; + const std::vector correctCt_1Shape = {bS, numUnits}; + const std::vector correctWxShape = {inSize, 4 * numUnits}; + const std::vector correctWhShape = {numProj, 4 * numUnits}; + const std::vector correctWcShape = {3 * numUnits}; + const std::vector correctWpShape = {numUnits, numProj}; + const std::vector correctBShape = {4 * numUnits}; + + REQUIRE_TRUE(ShapeUtils::areShapesEqual(ht_1ShapeInfo, correctHt_1Shape), 0, + "LSTMCELL operation: wrong shape of initial cell output, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(correctHt_1Shape).c_str(), + ShapeUtils::shapeAsString(ht_1ShapeInfo).c_str()); + REQUIRE_TRUE(ShapeUtils::areShapesEqual(ct_1ShapeInfo, correctCt_1Shape), 0, + "LSTMCELL operation: wrong shape of initial cell state, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(correctCt_1Shape).c_str(), + ShapeUtils::shapeAsString(ct_1ShapeInfo).c_str()); + REQUIRE_TRUE(ShapeUtils::areShapesEqual(WxShapeInfo, correctWxShape), 0, + "LSTMCELL operation: wrong shape of input-to-hidden weights, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(correctWxShape).c_str(), + ShapeUtils::shapeAsString(WxShapeInfo).c_str()); + REQUIRE_TRUE(ShapeUtils::areShapesEqual(WhShapeInfo, correctWhShape), 0, + "LSTMCELL operation: wrong shape of hidden-to-hidden weights, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(correctWhShape).c_str(), + ShapeUtils::shapeAsString(WhShapeInfo).c_str()); + REQUIRE_TRUE(ShapeUtils::areShapesEqual(WcShapeInfo, correctWcShape), 0, + "LSTMCELL operation: wrong shape of diagonal weights for " + "peephole connections, expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(correctWcShape).c_str(), + ShapeUtils::shapeAsString(WcShapeInfo).c_str()); + REQUIRE_TRUE(ShapeUtils::areShapesEqual(WpShapeInfo, correctWpShape), 0, + "LSTMCELL operation: wrong shape of projection weights, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(correctWpShape).c_str(), + ShapeUtils::shapeAsString(WpShapeInfo).c_str()); + REQUIRE_TRUE(ShapeUtils::areShapesEqual(bShapeInfo, correctBShape), 0, + "LSTMCELL operation: wrong shape of biases, expected is %s, but " + "got %s instead !", + ShapeUtils::shapeAsString(correctBShape).c_str(), + ShapeUtils::shapeAsString(bShapeInfo).c_str()); + + // evaluate output shapeInfos + Nd4jLong *hShapeInfo(nullptr), *cShapeInfo(nullptr); + ALLOCATE(hShapeInfo, block.workspace(), shape::shapeInfoLength(rank), + Nd4jLong); // [bS x numProj] + ALLOCATE(cShapeInfo, block.workspace(), shape::shapeInfoLength(rank), + Nd4jLong); // [bS x numUnits] + + hShapeInfo[0] = cShapeInfo[0] = rank; + hShapeInfo[1] = cShapeInfo[1] = bS; + hShapeInfo[2] = numProj; + cShapeInfo[2] = numUnits; + + ShapeUtils::updateStridesAndType(hShapeInfo, xtShapeInfo, + shape::order(ht_1ShapeInfo)); + ShapeUtils::updateStridesAndType(cShapeInfo, xtShapeInfo, + shape::order(ct_1ShapeInfo)); + + auto result = SHAPELIST( + ConstantShapeHelper::getInstance().createShapeInfo(hShapeInfo), + ConstantShapeHelper::getInstance().createShapeInfo(cShapeInfo)); + RELEASE(hShapeInfo, block.workspace()); + RELEASE(cShapeInfo, block.workspace()); + return result; } -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/nn/recurrent/lstmLayer.cpp b/libnd4j/include/ops/declarable/generic/nn/recurrent/lstmLayer.cpp index 0a0754a8e2bf..700825e27f4e 100644 --- a/libnd4j/include/ops/declarable/generic/nn/recurrent/lstmLayer.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/recurrent/lstmLayer.cpp @@ -22,786 +22,1068 @@ #if NOT_EXCLUDED(OP_lstmLayer) #include -#include - +#include namespace sd { -namespace ops { +namespace ops { ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(lstmLayer, 3, 1, false, 1, 5) { - - // equations (no peephole connections) - // it = σ(Wxi * xt + Wri * ht-1 + bi) - // ft = σ(Wxf * xt + Wrf * ht-1 + bf) - // c't = tanh(Wxc * xt + Wrc * ht-1 + bc) - // ct = ft ◦ ct-1 + it ◦ c't - // ot = σ(Wxo * xt + Wro * ht-1 + bo) - // ht = ot ◦ tanh(ct) - - // equations (peephole connections are present) - // it = σ(Wxi * xt + Wri * ht-1 + Wpi ◦ ct-1 + bi) - // ft = σ(Wxf * xt + Wrf * ht-1 + Wpf ◦ ct-1 + bf) - // c't = tanh(Wxc * xt + Wrc * ht-1 + bc) - // ct = clip(ft ◦ ct-1 + it ◦ c't) - // ot = σ(Wxo * xt + Wro * ht-1 + Wpo ◦ ct + bo) - // ht = ot ◦ tanh(ct) - - // notations: - // bS - batch size - // sL - sequence length, number of time steps - // nIn - input size - // nOut - output size (hidden size) - - // INPUTS: - - // ******* - // input x: - // 1) [sL, bS, nIn] when dataFormat == 0 - // 2) [bS, sL, nIn] when dataFormat == 1 - // 3) [bS, nIn, sL] when dataFormat == 2 - - // ******* - // input weights Wx: - // 1) [nIn, 4*nOut] when directionMode < 2 - // 2) [2, nIn, 4*nOut] when directionMode >= 2 - - // ******* - // recurrent weights Wr: - // 1) [nOut, 4*nOut] when directionMode < 2 - // 2) [2, nOut, 4*nOut] when directionMode >= 2 - - // ******* - // peephole weights Wp, optional: - // 1) [3*nOut] when directionMode < 2 - // 2) [2, 3*nOut] when directionMode >= 2 - - // ******* - // biases b, optional: - // 1) [4*nOut] when directionMode < 2 - // 2) [2, 4*nOut] when directionMode >= 2 - - // ******* - // sequence length array seqLen, optional: - // 1) [bS] - - // ******* - // initial output hI, optional: - // 1) [bS, nOut] when directionMode < 2 - // 2) [2, bS, nOut] when directionMode >= 2 - - // ******* - // initial cell state cI (same shape as in hI), optional: - // 1) [bS, nOut] when directionMode < 2 - // 2) [2, bS, nOut] when directionMode >= 2 - - - // OUTPUTS: - - // ******* - // output h, optional: - // 1) [sL, bS, nOut] when directionMode <= 2 && dataFormat == 0 - // 2) [bS, sL, nOut] when directionMode <= 2 && dataFormat == 1 - // 3) [bS, nOut, sL] when directionMode <= 2 && dataFormat == 2 - // 4) [sL, bS, 2*nOut] when directionMode == 3 && dataFormat == 0 - // 5) [bS, sL, 2*nOut] when directionMode == 3 && dataFormat == 1 - // 6) [bS, 2*nOut, sL] when directionMode == 3 && dataFormat == 2 - // 7) [sL, 2, bS, nOut] when directionMode == 4 && dataFormat == 3 - - // ******* - // output at last step hL, optional: - // 1) [bS, nOut] when directionMode < 2 - // 2) [2, bS, nOut] when directionMode >= 2 - - // ******* - // cell state at last step cL (same shape as in hL), optional: - // 1) [bS, nOut] when directionMode < 2 - // 2) [2, bS, nOut] when directionMode >= 2 - - // !!! dimension 4*nOut implies order it, ft, c't, ot - // !!! dimension 3*nOut implies order it, ft, ot - - const auto dataFormat = INT_ARG(0); // for unidirectional: 0 = [sL, bS, nIn], 1 = [bS, sL ,nIn], 2 = [bS, nIn, sL], for bidirectional: 3 = [sL, bS, nIn] && [sL, 2, bS, nOut] (for ONNX) - const auto directionMode = INT_ARG(1); // direction: 0 = fwd, 1 = bwd, 2 = bidirectional sum, 3 = bidirectional concat, 4 = bidirectional extra output dim (in conjunction with format dataFormat = 3) - - // integer numbers corresponding to activations: 0=tanh, 1=relu, 2=sigmoid, 3=affine, 4=leaky relu, 5= thresholded relu, 6=scaled tanh, 7=hard sigmoid, 8=ELU, 9=softsign, 10=softplus - const auto gateAct = INT_ARG(2); // activation for input (i), forget (f) and output (o) gates - const auto cellAct = INT_ARG(3); // activation for cell state (c) - const auto outAct = INT_ARG(4); // activation for output (h) - - const auto hasBiases = B_ARG(0); // indicates whether biases array is provided - const auto hasSeqLen = B_ARG(1); // indicates whether seqLen array is provided - const auto hasInitH = B_ARG(2); // indicates whether initial output is provided - const auto hasInitC = B_ARG(3); // indicates whether initial cell state is provided - const auto hasPH = B_ARG(4); // indicates whether peephole connections are present - const auto retFullSeq = B_ARG(5); // indicates whether to return whole time sequence h {h_0, h_1, ... , h_sL-1} - const auto retLastH = B_ARG(6); // indicates whether to return output at last time step only - const auto retLastC = B_ARG(7); // indicates whether to return cells state at last time step only - - const auto gateActHasAlpha = gateAct == 3 || gateAct == 4 || gateAct == 5 || gateAct == 6 || gateAct == 8; - const auto cellActHasAlpha = cellAct == 3 || cellAct == 4 || cellAct == 5 || cellAct == 6 || cellAct == 8; - const auto outActHasAlpha = outAct == 3 || outAct == 4 || outAct == 5 || outAct == 6 || outAct == 8; - const auto gateActHasBeta = gateAct == 3 || gateAct == 6; - const auto cellActHasBeta = cellAct == 3 || cellAct == 6; - const auto outActHasBeta = outAct == 3 || outAct == 6; - - uint count = 1; - const auto cellClip = T_ARG(0); // cell clipping value, if it = 0 then do not apply clipping - const auto gateAlpha = gateActHasAlpha ? T_ARG(count++) : 0; - const auto gateBeta = gateActHasBeta ? T_ARG(count++) : 0; - const auto cellAlpha = cellActHasAlpha ? T_ARG(count++) : 0; - const auto cellBeta = cellActHasBeta ? T_ARG(count++) : 0; - const auto outAlpha = outActHasAlpha ? T_ARG(count++) : 0; - const auto outBeta = outActHasBeta ? T_ARG(count++) : 0; - - const auto x = INPUT_VARIABLE(0); // input - const auto Wx = INPUT_VARIABLE(1); // input weights - const auto Wr = INPUT_VARIABLE(2); // recurrent weights - - count = 3; - const auto b = hasBiases ? INPUT_VARIABLE(count++) : nullptr; // biases - const auto seqLen = hasSeqLen ? INPUT_VARIABLE(count++) : nullptr; // seqLen vector - const auto hI = hasInitH ? INPUT_VARIABLE(count++) : nullptr; // initial output - const auto cI = hasInitC ? INPUT_VARIABLE(count++) : nullptr; // initial cell state - const auto Wp = hasPH ? INPUT_VARIABLE(count++) : nullptr; // peephole weights - - REQUIRE_TRUE(dataFormat < 3 || (dataFormat == 3 && directionMode == 4), 0, "LSTM_LAYER operation: if argument dataFormat = 3, then directionMode = 4, but got dataFormat = %i and directionMode = %i instead !", dataFormat, directionMode); - REQUIRE_TRUE(cellClip >= 0 , 0, "LSTM_LAYER operation: cell clipping value should be nonnegative (>=0) !"); - REQUIRE_TRUE(retFullSeq || retLastH || retLastC, 0, "LSTM_LAYER operation: please specify what output arrays to produce !"); - - count = 0; - auto h = retFullSeq ? OUTPUT_VARIABLE(count++) : nullptr; // output - auto hL = retLastH ? OUTPUT_VARIABLE(count++) : nullptr; // output at last step - auto cL = retLastC ? OUTPUT_VARIABLE(count++) : nullptr; // cell state at last step - - // evaluate dimensions - const Nd4jLong sL = dataFormat == 3 ? x->sizeAt(0) : x->sizeAt(dataFormat); - const Nd4jLong bS = dataFormat == 1 || dataFormat == 2 ? x->sizeAt(0) : x->sizeAt(1); - const Nd4jLong nIn = dataFormat == 2 ? x->sizeAt(1) : x->sizeAt(2); - const Nd4jLong nOut = Wx->sizeAt(-1) / 4; - - // inputs validations - if(directionMode < 2) { // no bidirectional - - // Wx validation - if(Wx->rankOf() != 2 || Wx->sizeAt(0) != nIn) - REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of input weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({nIn, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wx).c_str()); - // Wr validation - if(Wr->rankOf() != 2 || Wr->sizeAt(0) != nOut || Wr->sizeAt(1) != 4*nOut) - REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of recurrent weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({nOut, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wr).c_str()); - // biases validation - if(b != nullptr && (b->rankOf() != 1 || b->sizeAt(0) != 4*nOut)) - REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({4*nOut}).c_str(), ShapeUtils::shapeAsString(b).c_str()); - // initial output validation - if(hI != nullptr && (hI->rankOf() != 2 || hI->sizeAt(0) != bS || hI->sizeAt(1) != nOut)) - REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of initial output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({bS, nOut}).c_str(), ShapeUtils::shapeAsString(hI).c_str()); - // initial cell validation - if(cI != nullptr && (cI->rankOf() != 2 || cI->sizeAt(0) != bS || cI->sizeAt(1) != nOut)) - REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of initial cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({bS, nOut}).c_str(), ShapeUtils::shapeAsString(cI).c_str()); - // peephole weights validation - if(Wp != nullptr && (Wp->rankOf() != 1 || Wp->sizeAt(0) != 3*nOut)) - REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong peephole weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({3*nOut}).c_str(), ShapeUtils::shapeAsString(Wp).c_str()); + // equations (no peephole connections) + // it = σ(Wxi * xt + Wri * ht-1 + bi) + // ft = σ(Wxf * xt + Wrf * ht-1 + bf) + // c't = tanh(Wxc * xt + Wrc * ht-1 + bc) + // ct = ft ◦ ct-1 + it ◦ c't + // ot = σ(Wxo * xt + Wro * ht-1 + bo) + // ht = ot ◦ tanh(ct) + + // equations (peephole connections are present) + // it = σ(Wxi * xt + Wri * ht-1 + Wpi ◦ ct-1 + bi) + // ft = σ(Wxf * xt + Wrf * ht-1 + Wpf ◦ ct-1 + bf) + // c't = tanh(Wxc * xt + Wrc * ht-1 + bc) + // ct = clip(ft ◦ ct-1 + it ◦ c't) + // ot = σ(Wxo * xt + Wro * ht-1 + Wpo ◦ ct + bo) + // ht = ot ◦ tanh(ct) + + // notations: + // bS - batch size + // sL - sequence length, number of time steps + // nIn - input size + // nOut - output size (hidden size) + + // INPUTS: + + // ******* + // input x: + // 1) [sL, bS, nIn] when dataFormat == 0 + // 2) [bS, sL, nIn] when dataFormat == 1 + // 3) [bS, nIn, sL] when dataFormat == 2 + + // ******* + // input weights Wx: + // 1) [nIn, 4*nOut] when directionMode < 2 + // 2) [2, nIn, 4*nOut] when directionMode >= 2 + + // ******* + // recurrent weights Wr: + // 1) [nOut, 4*nOut] when directionMode < 2 + // 2) [2, nOut, 4*nOut] when directionMode >= 2 + + // ******* + // peephole weights Wp, optional: + // 1) [3*nOut] when directionMode < 2 + // 2) [2, 3*nOut] when directionMode >= 2 + + // ******* + // biases b, optional: + // 1) [4*nOut] when directionMode < 2 + // 2) [2, 4*nOut] when directionMode >= 2 + + // ******* + // sequence length array seqLen, optional: + // 1) [bS] + + // ******* + // initial output hI, optional: + // 1) [bS, nOut] when directionMode < 2 + // 2) [2, bS, nOut] when directionMode >= 2 + + // ******* + // initial cell state cI (same shape as in hI), optional: + // 1) [bS, nOut] when directionMode < 2 + // 2) [2, bS, nOut] when directionMode >= 2 + + // OUTPUTS: + + // ******* + // output h, optional: + // 1) [sL, bS, nOut] when directionMode <= 2 && dataFormat == 0 + // 2) [bS, sL, nOut] when directionMode <= 2 && dataFormat == 1 + // 3) [bS, nOut, sL] when directionMode <= 2 && dataFormat == 2 + // 4) [sL, bS, 2*nOut] when directionMode == 3 && dataFormat == 0 + // 5) [bS, sL, 2*nOut] when directionMode == 3 && dataFormat == 1 + // 6) [bS, 2*nOut, sL] when directionMode == 3 && dataFormat == 2 + // 7) [sL, 2, bS, nOut] when directionMode == 4 && dataFormat == 3 + + // ******* + // output at last step hL, optional: + // 1) [bS, nOut] when directionMode < 2 + // 2) [2, bS, nOut] when directionMode >= 2 + + // ******* + // cell state at last step cL (same shape as in hL), optional: + // 1) [bS, nOut] when directionMode < 2 + // 2) [2, bS, nOut] when directionMode >= 2 + + // !!! dimension 4*nOut implies order it, ft, c't, ot + // !!! dimension 3*nOut implies order it, ft, ot + + const auto dataFormat = + INT_ARG(0); // for unidirectional: 0 = [sL, bS, nIn], 1 = [bS, sL ,nIn], + // 2 = [bS, nIn, sL], for bidirectional: 3 = [sL, bS, nIn] && + // [sL, 2, bS, nOut] (for ONNX) + const auto directionMode = + INT_ARG(1); // direction: 0 = fwd, 1 = bwd, 2 = bidirectional sum, 3 = + // bidirectional concat, 4 = bidirectional extra output dim + // (in conjunction with format dataFormat = 3) + + // integer numbers corresponding to activations: 0=tanh, 1=relu, 2=sigmoid, + // 3=affine, 4=leaky relu, 5= thresholded relu, 6=scaled tanh, 7=hard sigmoid, + // 8=ELU, 9=softsign, 10=softplus + const auto gateAct = + INT_ARG(2); // activation for input (i), forget (f) and output (o) gates + const auto cellAct = INT_ARG(3); // activation for cell state (c) + const auto outAct = INT_ARG(4); // activation for output (h) + + const auto hasBiases = + B_ARG(0); // indicates whether biases array is provided + const auto hasSeqLen = + B_ARG(1); // indicates whether seqLen array is provided + const auto hasInitH = + B_ARG(2); // indicates whether initial output is provided + const auto hasInitC = + B_ARG(3); // indicates whether initial cell state is provided + const auto hasPH = + B_ARG(4); // indicates whether peephole connections are present + const auto retFullSeq = B_ARG(5); // indicates whether to return whole time + // sequence h {h_0, h_1, ... , h_sL-1} + const auto retLastH = + B_ARG(6); // indicates whether to return output at last time step only + const auto retLastC = B_ARG( + 7); // indicates whether to return cells state at last time step only + + const auto gateActHasAlpha = gateAct == 3 || gateAct == 4 || gateAct == 5 || + gateAct == 6 || gateAct == 8; + const auto cellActHasAlpha = cellAct == 3 || cellAct == 4 || cellAct == 5 || + cellAct == 6 || cellAct == 8; + const auto outActHasAlpha = + outAct == 3 || outAct == 4 || outAct == 5 || outAct == 6 || outAct == 8; + const auto gateActHasBeta = gateAct == 3 || gateAct == 6; + const auto cellActHasBeta = cellAct == 3 || cellAct == 6; + const auto outActHasBeta = outAct == 3 || outAct == 6; + + uint count = 1; + const auto cellClip = + T_ARG(0); // cell clipping value, if it = 0 then do not apply clipping + const auto gateAlpha = gateActHasAlpha ? T_ARG(count++) : 0; + const auto gateBeta = gateActHasBeta ? T_ARG(count++) : 0; + const auto cellAlpha = cellActHasAlpha ? T_ARG(count++) : 0; + const auto cellBeta = cellActHasBeta ? T_ARG(count++) : 0; + const auto outAlpha = outActHasAlpha ? T_ARG(count++) : 0; + const auto outBeta = outActHasBeta ? T_ARG(count++) : 0; + + const auto x = INPUT_VARIABLE(0); // input + const auto Wx = INPUT_VARIABLE(1); // input weights + const auto Wr = INPUT_VARIABLE(2); // recurrent weights + + count = 3; + const auto b = hasBiases ? INPUT_VARIABLE(count++) : nullptr; // biases + const auto seqLen = + hasSeqLen ? INPUT_VARIABLE(count++) : nullptr; // seqLen vector + const auto hI = + hasInitH ? INPUT_VARIABLE(count++) : nullptr; // initial output + const auto cI = + hasInitC ? INPUT_VARIABLE(count++) : nullptr; // initial cell state + const auto Wp = + hasPH ? INPUT_VARIABLE(count++) : nullptr; // peephole weights + + REQUIRE_TRUE( + dataFormat < 3 || (dataFormat == 3 && directionMode == 4), 0, + "LSTM_LAYER operation: if argument dataFormat = 3, then directionMode = " + "4, but got dataFormat = %i and directionMode = %i instead !", + dataFormat, directionMode); + REQUIRE_TRUE(cellClip >= 0, 0, + "LSTM_LAYER operation: cell clipping value should be " + "nonnegative (>=0) !"); + REQUIRE_TRUE( + retFullSeq || retLastH || retLastC, 0, + "LSTM_LAYER operation: please specify what output arrays to produce !"); + + count = 0; + auto h = retFullSeq ? OUTPUT_VARIABLE(count++) : nullptr; // output + auto hL = + retLastH ? OUTPUT_VARIABLE(count++) : nullptr; // output at last step + auto cL = + retLastC ? OUTPUT_VARIABLE(count++) : nullptr; // cell state at last step + + // evaluate dimensions + const Nd4jLong sL = dataFormat == 3 ? x->sizeAt(0) : x->sizeAt(dataFormat); + const Nd4jLong bS = + dataFormat == 1 || dataFormat == 2 ? x->sizeAt(0) : x->sizeAt(1); + const Nd4jLong nIn = dataFormat == 2 ? x->sizeAt(1) : x->sizeAt(2); + const Nd4jLong nOut = Wx->sizeAt(-1) / 4; + + // inputs validations + if (directionMode < 2) { // no bidirectional + + // Wx validation + if (Wx->rankOf() != 2 || Wx->sizeAt(0) != nIn) + REQUIRE_TRUE(false, 0, + "LSTM_LAYER operation: wrong shape of input weights, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString({nIn, 4 * nOut}).c_str(), + ShapeUtils::shapeAsString(Wx).c_str()); + // Wr validation + if (Wr->rankOf() != 2 || Wr->sizeAt(0) != nOut || Wr->sizeAt(1) != 4 * nOut) + REQUIRE_TRUE(false, 0, + "LSTM_LAYER operation: wrong shape of recurrent weights, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString({nOut, 4 * nOut}).c_str(), + ShapeUtils::shapeAsString(Wr).c_str()); + // biases validation + if (b != nullptr && (b->rankOf() != 1 || b->sizeAt(0) != 4 * nOut)) + REQUIRE_TRUE(false, 0, + "LSTM_LAYER operation: wrong shape of biases, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString({4 * nOut}).c_str(), + ShapeUtils::shapeAsString(b).c_str()); + // initial output validation + if (hI != nullptr && + (hI->rankOf() != 2 || hI->sizeAt(0) != bS || hI->sizeAt(1) != nOut)) + REQUIRE_TRUE(false, 0, + "LSTM_LAYER operation: wrong shape of initial output, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString({bS, nOut}).c_str(), + ShapeUtils::shapeAsString(hI).c_str()); + // initial cell validation + if (cI != nullptr && + (cI->rankOf() != 2 || cI->sizeAt(0) != bS || cI->sizeAt(1) != nOut)) + REQUIRE_TRUE(false, 0, + "LSTM_LAYER operation: wrong shape of initial cell state, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString({bS, nOut}).c_str(), + ShapeUtils::shapeAsString(cI).c_str()); + // peephole weights validation + if (Wp != nullptr && (Wp->rankOf() != 1 || Wp->sizeAt(0) != 3 * nOut)) + REQUIRE_TRUE(false, 0, + "LSTM_LAYER operation: wrong peephole weights, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString({3 * nOut}).c_str(), + ShapeUtils::shapeAsString(Wp).c_str()); + } else { // bidirectional + // Wx validation + if (Wx->rankOf() != 3 || Wx->sizeAt(0) != 2 || Wx->sizeAt(1) != nIn) + REQUIRE_TRUE(false, 0, + "LSTM_LAYER operation: wrong shape of input weights, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString({2, nIn, 4 * nOut}).c_str(), + ShapeUtils::shapeAsString(Wx).c_str()); + // Wr validation + if (Wr->rankOf() != 3 || Wr->sizeAt(0) != 2 || Wr->sizeAt(1) != nOut || + Wr->sizeAt(2) != 4 * nOut) + REQUIRE_TRUE(false, 0, + "LSTM_LAYER operation: wrong shape of recurrent weights, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString({2, nOut, 4 * nOut}).c_str(), + ShapeUtils::shapeAsString(Wr).c_str()); + // biases validation + if (b != nullptr && + (b->rankOf() != 2 || b->sizeAt(0) != 2 || b->sizeAt(1) != 4 * nOut)) + REQUIRE_TRUE(false, 0, + "LSTM_LAYER operation: wrong shape of biases, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString({2, 4 * nOut}).c_str(), + ShapeUtils::shapeAsString(b).c_str()); + // initial output validation + if (hI != nullptr && (hI->rankOf() != 3 || hI->sizeAt(0) != 2 || + hI->sizeAt(1) != bS || hI->sizeAt(2) != nOut)) + REQUIRE_TRUE(false, 0, + "LSTM_LAYER operation: wrong shape of initial output, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString({2, bS, nOut}).c_str(), + ShapeUtils::shapeAsString(hI).c_str()); + // initial cell validation + if (cI != nullptr && (cI->rankOf() != 3 || cI->sizeAt(0) != 2 || + cI->sizeAt(1) != bS || cI->sizeAt(2) != nOut)) + REQUIRE_TRUE(false, 0, + "LSTM_LAYER operation: wrong shape of initial cell state, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString({2, bS, nOut}).c_str(), + ShapeUtils::shapeAsString(cI).c_str()); + // peephole weights validation + if (Wp != nullptr && + (Wp->rankOf() != 2 || Wp->sizeAt(0) != 2 || Wp->sizeAt(1) != 3 * nOut)) + REQUIRE_TRUE(false, 0, + "LSTM_LAYER operation: wrong peephole weights, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString({2, 3 * nOut}).c_str(), + ShapeUtils::shapeAsString(Wp).c_str()); + } + + std::vector params = { + static_cast(dataFormat), static_cast(directionMode), + static_cast(cellClip), static_cast(gateAct), + static_cast(gateAlpha), static_cast(gateBeta), + static_cast(cellAct), static_cast(cellAlpha), + static_cast(cellBeta), static_cast(outAct), + static_cast(outAlpha), static_cast(outBeta)}; + + if (directionMode == 0) { // forward + + helpers::lstmLayerTimeLoop(x, Wx, Wr, b, seqLen, hI, cI, Wp, params, true, + h, hL, cL); + } else if (directionMode == 1) { // backward + + helpers::lstmLayerTimeLoop(x, Wx, Wr, b, seqLen, hI, cI, Wp, params, false, + h, hL, cL); + } else { // bidirectional + + NDArray WxFwd = (*Wx)({0, 1, 0, 0, 0, 0}); + NDArray WxBwd = (*Wx)({1, 2, 0, 0, 0, 0}); + NDArray WrFwd = (*Wr)({0, 1, 0, 0, 0, 0}); + NDArray WrBwd = (*Wr)({1, 2, 0, 0, 0, 0}); + + NDArray *WpFwd(nullptr), *WpBwd(nullptr), *bFwd(nullptr), *bBwd(nullptr), + *hIFwd(nullptr), *hIBwd(nullptr), *cIFwd(nullptr), *cIBwd(nullptr), + *hLFwd(nullptr), *hLBwd(nullptr), *cLFwd(nullptr), *cLBwd(nullptr), + *hFwd(nullptr), *hBwd(nullptr); + + if (Wp) { + WpFwd = new NDArray((*Wp)({0, 1, 0, 0})); + WpBwd = new NDArray((*Wp)({1, 2, 0, 0})); } - else { // bidirectional - // Wx validation - if(Wx->rankOf() != 3 || Wx->sizeAt(0) != 2 || Wx->sizeAt(1) != nIn) - REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of input weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, nIn, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wx).c_str()); - // Wr validation - if(Wr->rankOf() != 3 || Wr->sizeAt(0) != 2 || Wr->sizeAt(1) != nOut || Wr->sizeAt(2) != 4*nOut) - REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of recurrent weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, nOut, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wr).c_str()); - // biases validation - if(b != nullptr && (b->rankOf() != 2 || b->sizeAt(0) != 2 || b->sizeAt(1) != 4*nOut)) - REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, 4*nOut}).c_str(), ShapeUtils::shapeAsString(b).c_str()); - // initial output validation - if(hI != nullptr && (hI->rankOf() != 3 || hI->sizeAt(0) != 2 || hI->sizeAt(1) != bS || hI->sizeAt(2) != nOut)) - REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of initial output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, bS, nOut}).c_str(), ShapeUtils::shapeAsString(hI).c_str()); - // initial cell validation - if(cI != nullptr && (cI->rankOf() != 3 || cI->sizeAt(0) != 2 || cI->sizeAt(1) != bS || cI->sizeAt(2) != nOut)) - REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of initial cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, bS, nOut}).c_str(), ShapeUtils::shapeAsString(cI).c_str()); - // peephole weights validation - if(Wp != nullptr && (Wp->rankOf() != 2 || Wp->sizeAt(0) != 2 || Wp->sizeAt(1) != 3*nOut)) - REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong peephole weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, 3*nOut}).c_str(), ShapeUtils::shapeAsString(Wp).c_str()); + if (b) { + bFwd = new NDArray((*b)({0, 1, 0, 0})); + bBwd = new NDArray((*b)({1, 2, 0, 0})); } - - std::vector params = {static_cast(dataFormat), static_cast(directionMode), static_cast(cellClip), - static_cast(gateAct), static_cast(gateAlpha), static_cast(gateBeta), - static_cast(cellAct), static_cast(cellAlpha), static_cast(cellBeta), - static_cast(outAct), static_cast(outAlpha), static_cast(outBeta)}; - - if(directionMode == 0) { // forward - - helpers::lstmLayerTimeLoop(x, Wx, Wr, b, seqLen, hI, cI, Wp, params, true, h, hL, cL); + if (hI) { + hIFwd = new NDArray((*hI)({0, 1, 0, 0, 0, 0})); + hIBwd = new NDArray((*hI)({1, 2, 0, 0, 0, 0})); } - else if(directionMode == 1) { // backward - - helpers::lstmLayerTimeLoop(x, Wx, Wr, b, seqLen, hI, cI, Wp, params, false, h, hL, cL); + if (cI) { + cIFwd = new NDArray((*cI)({0, 1, 0, 0, 0, 0})); + cIBwd = new NDArray((*cI)({1, 2, 0, 0, 0, 0})); + } + if (hL) { + hLFwd = new NDArray((*hL)({0, 1, 0, 0, 0, 0})); + hLBwd = new NDArray((*hL)({1, 2, 0, 0, 0, 0})); } - else { // bidirectional - - NDArray WxFwd = (*Wx)({0,1, 0,0, 0,0}); - NDArray WxBwd = (*Wx)({1,2, 0,0, 0,0}); - NDArray WrFwd = (*Wr)({0,1, 0,0, 0,0}); - NDArray WrBwd = (*Wr)({1,2, 0,0, 0,0}); - - NDArray *WpFwd(nullptr), *WpBwd(nullptr), *bFwd(nullptr), *bBwd(nullptr), *hIFwd(nullptr), *hIBwd(nullptr), *cIFwd(nullptr), *cIBwd(nullptr), - *hLFwd(nullptr), *hLBwd(nullptr), *cLFwd(nullptr), *cLBwd(nullptr), *hFwd(nullptr), *hBwd(nullptr); - - if(Wp) { - WpFwd = new NDArray((*Wp)({0,1, 0,0})); - WpBwd = new NDArray((*Wp)({1,2, 0,0})); - } - if(b) { - bFwd = new NDArray((*b)({0,1, 0,0})); - bBwd = new NDArray((*b)({1,2, 0,0})); - } - if(hI) { - hIFwd = new NDArray((*hI)({0,1, 0,0, 0,0})); - hIBwd = new NDArray((*hI)({1,2, 0,0, 0,0})); - } - if(cI) { - cIFwd = new NDArray((*cI)({0,1, 0,0, 0,0})); - cIBwd = new NDArray((*cI)({1,2, 0,0, 0,0})); - } - if(hL) { - hLFwd = new NDArray((*hL)({0,1, 0,0, 0,0})); - hLBwd = new NDArray((*hL)({1,2, 0,0, 0,0})); - } - if(cL) { - cLFwd = new NDArray((*cL)({0,1, 0,0, 0,0})); - cLBwd = new NDArray((*cL)({1,2, 0,0, 0,0})); - } - - if(h) { - if(directionMode == 2) { // sum - hFwd = h; - hBwd = new NDArray(h, false, h->getContext()); - } - else if(directionMode == 3) { // concat - hFwd = new NDArray(dataFormat <= 1 ? (*h)({0,0, 0,0, 0,nOut}) : (*h)({0,0, 0,nOut, 0,0})); - hBwd = new NDArray(dataFormat <= 1 ? (*h)({0,0, 0,0, nOut,2*nOut}) : (*h)({0,0, nOut,2*nOut, 0,0})); - } - else { // directionMode == 4 - hFwd = new NDArray((*h)({0,0, 0,1, 0,0, 0,0})); - hBwd = new NDArray((*h)({0,0, 1,2, 0,0, 0,0})); - } - } - - // FIXME - following two calls are independent and may run in different streams - helpers::lstmLayerTimeLoop(x, &WxFwd, &WrFwd, bFwd, seqLen, hIFwd, cIFwd, WpFwd, params, true, hFwd, hLFwd, cLFwd); - helpers::lstmLayerTimeLoop(x, &WxBwd, &WrBwd, bBwd, seqLen, hIBwd, cIBwd, WpBwd, params, false, hBwd, hLBwd, cLBwd); - - if(h && directionMode == 2) - *h += *hBwd; - - delete WpFwd; delete WpBwd; delete bFwd; delete bBwd; delete hIFwd; delete hIBwd; delete cIFwd; - delete cIBwd; delete hLFwd; delete hLBwd; delete cLFwd; delete cLBwd; delete hBwd; - if(hFwd != h) - delete hFwd; + if (cL) { + cLFwd = new NDArray((*cL)({0, 1, 0, 0, 0, 0})); + cLBwd = new NDArray((*cL)({1, 2, 0, 0, 0, 0})); + } + + if (h) { + if (directionMode == 2) { // sum + hFwd = h; + hBwd = new NDArray(h, false, h->getContext()); + } else if (directionMode == 3) { // concat + hFwd = new NDArray(dataFormat <= 1 ? (*h)({0, 0, 0, 0, 0, nOut}) + : (*h)({0, 0, 0, nOut, 0, 0})); + hBwd = + new NDArray(dataFormat <= 1 ? (*h)({0, 0, 0, 0, nOut, 2 * nOut}) + : (*h)({0, 0, nOut, 2 * nOut, 0, 0})); + } else { // directionMode == 4 + hFwd = new NDArray((*h)({0, 0, 0, 1, 0, 0, 0, 0})); + hBwd = new NDArray((*h)({0, 0, 1, 2, 0, 0, 0, 0})); + } } - return Status::OK(); + // FIXME - following two calls are independent and may run in different + // streams + helpers::lstmLayerTimeLoop(x, &WxFwd, &WrFwd, bFwd, seqLen, hIFwd, cIFwd, + WpFwd, params, true, hFwd, hLFwd, cLFwd); + helpers::lstmLayerTimeLoop(x, &WxBwd, &WrBwd, bBwd, seqLen, hIBwd, cIBwd, + WpBwd, params, false, hBwd, hLBwd, cLBwd); + + if (h && directionMode == 2) *h += *hBwd; + + delete WpFwd; + delete WpBwd; + delete bFwd; + delete bBwd; + delete hIFwd; + delete hIBwd; + delete cIFwd; + delete cIBwd; + delete hLFwd; + delete hLBwd; + delete cLFwd; + delete cLBwd; + delete hBwd; + if (hFwd != h) delete hFwd; + } + + return Status::OK(); } DECLARE_TYPES(lstmLayer) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } - DECLARE_SHAPE_FN(lstmLayer) { - - const auto dataFormat = INT_ARG(0); // for unidirectional: 0 = [sL, bS, nIn], 1 = [bS, sL ,nIn], 2 = [bS, nIn, sL], for bidirectional: 3 = [sL, 2, bS, nIn] (for ONNX) - const auto directionMode = INT_ARG(1); // direction: 0 = fwd, 1 = bwd, 2 = bidirectional sum, 3 = bidirectional concat, 4 = bidirectional extra output dim - - const auto retFullSeq = B_ARG(5); // indicates whether to return whole h {h_0, h_1, ... , h_sL-1}, if true, format would be [sL,bS,nOut] (exact shape depends on dataFormat argument) - const auto retLastH = B_ARG(6); // indicates whether to return output at last time step only, in this case shape would be [bS, nOut] (exact shape depends on dataFormat argument) - const auto retLastC = B_ARG(7); // indicates whether to return cells state at last time step only, in this case shape would be [bS, nOut] (exact shape depends on dataFormat argument) - - const auto x = INPUT_VARIABLE(0); // input - const auto Wx = INPUT_VARIABLE(1); // input weights - const auto Wr = INPUT_VARIABLE(2); // recurrent weights - - // evaluate dimensions - const Nd4jLong sL = dataFormat == 3 ? x->sizeAt(0) : x->sizeAt(dataFormat); - const Nd4jLong bS = dataFormat == 1 || dataFormat == 2 ? x->sizeAt(0) : x->sizeAt(1); - const Nd4jLong nIn = dataFormat == 2 ? x->sizeAt(1) : x->sizeAt(2); - const Nd4jLong nOut = Wx->sizeAt(-1) / 4; - - DataType type; - if(x->isR()) - type = x->dataType(); - else - type = sd::DataType::FLOAT32; - - auto shapes = SHAPELIST(); - - // evaluate h shape (output) - if(retFullSeq) { - - std::vector hShape; - - if(directionMode <= 2) { // single direction or bidirectional with sum - if(dataFormat == 0) - hShape = {sL, bS, nOut}; - else if(dataFormat == 1) - hShape = {bS, sL, nOut}; - else if(dataFormat == 2) - hShape = {bS, nOut, sL}; - } - else if(directionMode == 3) { // bidirectional with concat - - if(dataFormat == 0) - hShape = {sL, bS, 2*nOut}; - else if(dataFormat == 1) - hShape = {bS, sL, 2*nOut}; - else if(dataFormat == 2) - hShape = {bS, 2*nOut, sL}; - } - else { // bidirectional with extra output dimension equal to 2 - hShape = {sL, 2, bS, nOut}; - } - - shapes->push_back(ConstantShapeHelper::getInstance().createShapeInfo(type, x->ordering(), hShape)); + const auto dataFormat = INT_ARG( + 0); // for unidirectional: 0 = [sL, bS, nIn], 1 = [bS, sL ,nIn], 2 = [bS, + // nIn, sL], for bidirectional: 3 = [sL, 2, bS, nIn] (for ONNX) + const auto directionMode = + INT_ARG(1); // direction: 0 = fwd, 1 = bwd, 2 = bidirectional sum, 3 = + // bidirectional concat, 4 = bidirectional extra output dim + + const auto retFullSeq = + B_ARG(5); // indicates whether to return whole h {h_0, h_1, ... , + // h_sL-1}, if true, format would be [sL,bS,nOut] (exact shape + // depends on dataFormat argument) + const auto retLastH = + B_ARG(6); // indicates whether to return output at last time step only, + // in this case shape would be [bS, nOut] (exact shape depends + // on dataFormat argument) + const auto retLastC = + B_ARG(7); // indicates whether to return cells state at last time step + // only, in this case shape would be [bS, nOut] (exact shape + // depends on dataFormat argument) + + const auto x = INPUT_VARIABLE(0); // input + const auto Wx = INPUT_VARIABLE(1); // input weights + const auto Wr = INPUT_VARIABLE(2); // recurrent weights + + // evaluate dimensions + const Nd4jLong sL = dataFormat == 3 ? x->sizeAt(0) : x->sizeAt(dataFormat); + const Nd4jLong bS = + dataFormat == 1 || dataFormat == 2 ? x->sizeAt(0) : x->sizeAt(1); + const Nd4jLong nIn = dataFormat == 2 ? x->sizeAt(1) : x->sizeAt(2); + const Nd4jLong nOut = Wx->sizeAt(-1) / 4; + + DataType type; + if (x->isR()) + type = x->dataType(); + else + type = sd::DataType::FLOAT32; + + auto shapes = SHAPELIST(); + + // evaluate h shape (output) + if (retFullSeq) { + std::vector hShape; + + if (directionMode <= 2) { // single direction or bidirectional with sum + if (dataFormat == 0) + hShape = {sL, bS, nOut}; + else if (dataFormat == 1) + hShape = {bS, sL, nOut}; + else if (dataFormat == 2) + hShape = {bS, nOut, sL}; + } else if (directionMode == 3) { // bidirectional with concat + + if (dataFormat == 0) + hShape = {sL, bS, 2 * nOut}; + else if (dataFormat == 1) + hShape = {bS, sL, 2 * nOut}; + else if (dataFormat == 2) + hShape = {bS, 2 * nOut, sL}; + } else { // bidirectional with extra output dimension equal to 2 + hShape = {sL, 2, bS, nOut}; } - // evaluate hL shape (output at last step) - if(retLastH) { - - std::vector hLShape; + shapes->push_back(ConstantShapeHelper::getInstance().createShapeInfo( + type, x->ordering(), hShape)); + } - if(directionMode < 2) - hLShape = {bS, nOut}; - else - hLShape = {2, bS, nOut}; + // evaluate hL shape (output at last step) + if (retLastH) { + std::vector hLShape; - shapes->push_back(ConstantShapeHelper::getInstance().createShapeInfo(type, x->ordering(), hLShape)); + if (directionMode < 2) + hLShape = {bS, nOut}; + else + hLShape = {2, bS, nOut}; - if(retLastC) // cL and hL have same shapes - shapes->push_back(shapes->at(shapes->size() - 1)); - } + shapes->push_back(ConstantShapeHelper::getInstance().createShapeInfo( + type, x->ordering(), hLShape)); - // evaluate cL shape (cell state at last step) - if(retLastC && !retLastH) { + if (retLastC) // cL and hL have same shapes + shapes->push_back(shapes->at(shapes->size() - 1)); + } - std::vector cLShape; + // evaluate cL shape (cell state at last step) + if (retLastC && !retLastH) { + std::vector cLShape; - if(directionMode < 2) - cLShape = {bS, nOut}; - else - cLShape = {2, bS, nOut}; + if (directionMode < 2) + cLShape = {bS, nOut}; + else + cLShape = {2, bS, nOut}; - shapes->push_back(ConstantShapeHelper::getInstance().createShapeInfo(type, x->ordering(), cLShape)); - } + shapes->push_back(ConstantShapeHelper::getInstance().createShapeInfo( + type, x->ordering(), cLShape)); + } - return shapes; + return shapes; } - ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(lstmLayer_bp, 4, 1, false, 1, 5) { - - // equations (no peephole connections) - // it = σ(Wxi * xt + Wri * ht-1 + bi) - // ft = σ(Wxf * xt + Wrf * ht-1 + bf) - // c't = tanh(Wxc * xt + Wrc * ht-1 + bc) - // ct = ft ◦ ct-1 + it ◦ c't - // ot = σ(Wxo * xt + Wro * ht-1 + bo) - // ht = ot ◦ tanh(ct) - - // equations (peephole connections are present) - // it = σ(Wxi * xt + Wri * ht-1 + Wpi ◦ ct-1 + bi) - // ft = σ(Wxf * xt + Wrf * ht-1 + Wpf ◦ ct-1 + bf) - // c't = tanh(Wxc * xt + Wrc * ht-1 + bc) - // ct = clip(ft ◦ ct-1 + it ◦ c't) - // ot = σ(Wxo * xt + Wro * ht-1 + Wpo ◦ ct + bo) - // ht = ot ◦ tanh(ct) - - // notations: - // bS - batch size - // sL - sequence length, number of time steps - // nIn - input size - // nOut - output size (hidden size) - - // INPUTS: - - // ******* - // input x: - // 1) [sL, bS, nIn] when dataFormat == 0 - // 2) [bS, sL, nIn] when dataFormat == 1 - // 3) [bS, nIn, sL] when dataFormat == 2 - - // ******* - // input weights Wx: - // 1) [nIn, 4*nOut] when directionMode < 2 - // 2) [2, nIn, 4*nOut] when directionMode >= 2 - - // ******* - // recurrent weights Wr: - // 1) [nOut, 4*nOut] when directionMode < 2 - // 2) [2, nOut, 4*nOut] when directionMode >= 2 - - // ******* - // peephole weights Wp, optional: - // 1) [3*nOut] when directionMode < 2 - // 2) [2, 3*nOut] when directionMode >= 2 - - // ******* - // biases b, optional: - // 1) [4*nOut] when directionMode < 2 - // 2) [2, 4*nOut] when directionMode >= 2 - - // ******* - // sequence length array seqLen, optional: - // 1) [bS] - - // ******* - // initial output hI, optional: - // 1) [bS, nOut] when directionMode < 2 - // 2) [2, bS, nOut] when directionMode >= 2 - - // ******* - // initial cell state cI (same shape as in hI), optional: - // 1) [bS, nOut] when directionMode < 2 - // 2) [2, bS, nOut] when directionMode >= 2 - - // ******* - // gradient vs. output dLdh, optional: - // 1) [sL, bS, nOut] when directionMode <= 2 && dataFormat == 0 - // 2) [bS, sL, nOut] when directionMode <= 2 && dataFormat == 1 - // 3) [bS, nOut, sL] when directionMode <= 2 && dataFormat == 2 - // 4) [sL, bS, 2*nOut] when directionMode == 3 && dataFormat == 0 - // 5) [bS, sL, 2*nOut] when directionMode == 3 && dataFormat == 1 - // 6) [bS, 2*nOut, sL] when directionMode == 3 && dataFormat == 2 - // 7) [sL, 2, bS, nOut] when directionMode == 4 && dataFormat == 3 - - // ******* - // gradient vs output at last time step dLdhL, optional: - // 1) [bS, nOut] when directionMode < 2 - // 2) [2, bS, nOut] when directionMode >= 2 - - // ******* - // gradient vs cell state at last time step dLdcL(same shape as in dLdhL), optional: - // 1) [bS, nOut] when directionMode < 2 - // 2) [2, bS, nOut] when directionMode >= 2 - - - // OUTPUTS: - - // ******* - // gradient vs. input dLdx: - // 1) [sL, bS, nIn] when dataFormat == 0 - // 2) [bS, sL, nIn] when dataFormat == 1 - // 3) [bS, nIn, sL] when dataFormat == 2 - - // ******* - // gradient vs. input weights dLdWx: - // 1) [nIn, 4*nOut] when directionMode < 2 - // 2) [2, nIn, 4*nOut] when directionMode >= 2 - - // ******* - // gradient vs. recurrent weights dLdWr: - // 1) [nOut, 4*nOut] when directionMode < 2 - // 2) [2, nOut, 4*nOut] when directionMode >= 2 - - // ******* - // gradient vs. peephole weights dLdWp, optional: - // 1) [3*nOut] when directionMode < 2 - // 2) [2, 3*nOut] when directionMode >= 2 - - // ******* - // gradient vs. biases dLdb, optional: - // 1) [4*nOut] when directionMode < 2 - // 2) [2, 4*nOut] when directionMode >= 2 - - // gradient vs. sequence length array dLdsL, optional (do not calculate it!!!): - // 1) [bS] always - - // ******* - // gradient vs. initial output dLdhI, optional: - // 1) [bS, nOut] when directionMode < 2 - // 2) [2, bS, nOut] when directionMode >= 2 - - // ******* - // gradient vs. initial cell state dLdcI (same shape as in dLdhI), optional: - // 1) [bS, nOut] when directionMode < 2 - // 2) [2, bS, nOut] when directionMode >= 2 - - - // !!! dimension 4*nOut implies order it, ft, c't, ot - // !!! dimension 3*nOut implies order it, ft, ot - - const auto dataFormat = INT_ARG(0); // for unidirectional: 0 = [sL, bS, nIn], 1 = [bS, sL ,nIn], 2 = [bS, nIn, sL], for bidirectional: 3 = [sL, bS, nIn] && [sL, 2, bS, nOut] (for ONNX) - const auto directionMode = INT_ARG(1); // direction: 0 = fwd, 1 = bwd, 2 = bidirectional sum, 3 = bidirectional concat, 4 = bidirectional extra output dim (in conjunction with format dataFormat = 3) - - // integer numbers corresponding to activations: 0=tanh, 1=relu, 2=sigmoid, 3=affine, 4=leaky relu, 5= thresholded relu, 6=scaled tanh, 7=hard sigmoid, 8=ELU, 9=softsign, 10=softplus - const auto gateAct = INT_ARG(2); // activation for input (i), forget (f) and output (o) gates - const auto cellAct = INT_ARG(3); // activation for cell state (c) - const auto outAct = INT_ARG(4); // activation for output (h) - - const auto hasBiases = B_ARG(0); // indicates whether biases array is provided - const auto hasSeqLen = B_ARG(1); // indicates whether seqLen array is provided - const auto hasInitH = B_ARG(2); // indicates whether initial output is provided - const auto hasInitC = B_ARG(3); // indicates whether initial cell state is provided - const auto hasPH = B_ARG(4); // indicates whether peephole connections are present - const auto retFullSeq = B_ARG(5); // indicates whether gradient vs. outputs is given for whole time sequence dLdh {dLdh_0, dLdh_1, ... , dLdh_sL-1} - const auto retLastH = B_ARG(6); // indicates whether gradient vs. output at last time step (dLdhL) is given - const auto retLastC = B_ARG(7); // indicates whether gradient vs. cell state at last time step (dLdcL) is given - - const auto gateActHasAlpha = gateAct == 3 || gateAct == 4 || gateAct == 5 || gateAct == 6 || gateAct == 8; - const auto cellActHasAlpha = cellAct == 3 || cellAct == 4 || cellAct == 5 || cellAct == 6 || cellAct == 8; - const auto outActHasAlpha = outAct == 3 || outAct == 4 || outAct == 5 || outAct == 6 || outAct == 8; - const auto gateActHasBeta = gateAct == 3 || gateAct == 6; - const auto cellActHasBeta = cellAct == 3 || cellAct == 6; - const auto outActHasBeta = outAct == 3 || outAct == 6; - - uint count = 1; - const auto cellClip = T_ARG(0); // cell clipping value, if it = 0 then do not apply clipping - const auto gateAlpha = gateActHasAlpha ? T_ARG(count++) : 0; - const auto gateBeta = gateActHasBeta ? T_ARG(count++) : 0; - const auto cellAlpha = cellActHasAlpha ? T_ARG(count++) : 0; - const auto cellBeta = cellActHasBeta ? T_ARG(count++) : 0; - const auto outAlpha = outActHasAlpha ? T_ARG(count++) : 0; - const auto outBeta = outActHasBeta ? T_ARG(count++) : 0; - - REQUIRE_TRUE(dataFormat < 3 || (dataFormat == 3 && directionMode == 4), 0, "LSTM_LAYER_BP operation: if argument dataFormat = 3, then directionMode = 4, but got dataFormat = %i and directionMode = %i instead !", dataFormat, directionMode); - REQUIRE_TRUE(cellClip >= 0 , 0, "LSTM_LAYER_BP operation: cell clipping value should be nonnegative (>=0) !"); - REQUIRE_TRUE(retFullSeq || retLastH || retLastC, 0, "LSTM_LAYER_BP operation: please specify at least one of three input gradient arrays: dLdh, dLdhL or dLdcL !"); - - const auto x = INPUT_VARIABLE(0); // input - const auto Wx = INPUT_VARIABLE(1); // input weights - const auto Wr = INPUT_VARIABLE(2); // recurrent weights - - count = 3; - const auto b = hasBiases ? INPUT_VARIABLE(count++) : nullptr; // biases - const auto seqLen = hasSeqLen ? INPUT_VARIABLE(count++) : nullptr; // seqLen vector - const auto hI = hasInitH ? INPUT_VARIABLE(count++) : nullptr; // initial output - const auto cI = hasInitC ? INPUT_VARIABLE(count++) : nullptr; // initial cell state - const auto Wp = hasPH ? INPUT_VARIABLE(count++) : nullptr; // peephole weights - const auto dLdh = retFullSeq ? INPUT_VARIABLE(count++) : nullptr; // gradient vs. output - const auto dLdhL = retLastH ? INPUT_VARIABLE(count++) : nullptr; // gradient vs. output at last time step - const auto dLdcL = retLastC ? INPUT_VARIABLE(count++) : nullptr; // gradient vs. cell state at last time step - - count = 3; - auto dLdx = OUTPUT_VARIABLE(0); // gradient vs. input - auto dLdWx = OUTPUT_NULLIFIED(1); // gradient vs. input weights - auto dLdWr = OUTPUT_NULLIFIED(2); // gradient vs. recurrent weights - auto dLdb = hasBiases ? OUTPUT_NULLIFIED(count++) : nullptr; // gradient vs. biases - auto dLdsL = hasSeqLen ? INPUT_VARIABLE(count++) : nullptr; // gradient vs. seqLen vector, we don't calculate it !!! - auto dLdhI = hasInitH ? OUTPUT_NULLIFIED(count++) : nullptr; // gradient vs. initial output - auto dLdcI = hasInitC ? OUTPUT_NULLIFIED(count++) : nullptr; // gradient vs. initial cell state - auto dLdWp = hasPH ? OUTPUT_NULLIFIED(count) : nullptr; // gradient vs. peephole weights - - // evaluate dimensions - const Nd4jLong sL = dataFormat == 3 ? x->sizeAt(0) : x->sizeAt(dataFormat); - const Nd4jLong bS = dataFormat == 1 || dataFormat == 2 ? x->sizeAt(0) : x->sizeAt(1); - const Nd4jLong nIn = dataFormat == 2 ? x->sizeAt(1) : x->sizeAt(2); - const Nd4jLong nOut = Wx->sizeAt(-1) / 4; - - // inputs validations - if(directionMode < 2) { // no bidirectional - - // Wx validation - if(Wx->rankOf() != 2 || Wx->sizeAt(0) != nIn) - REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of input weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({nIn, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wx).c_str()); - // Wr validation - if(Wr->rankOf() != 2 || Wr->sizeAt(0) != nOut || Wr->sizeAt(1) != 4*nOut) - REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of recurrent weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({nOut, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wr).c_str()); - // biases validation - if(b != nullptr && (b->rankOf() != 1 || b->sizeAt(0) != 4*nOut)) - REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({4*nOut}).c_str(), ShapeUtils::shapeAsString(b).c_str()); - // initial output validation - if(hI != nullptr && (hI->rankOf() != 2 || hI->sizeAt(0) != bS || hI->sizeAt(1) != nOut)) - REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of initial output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({bS, nOut}).c_str(), ShapeUtils::shapeAsString(hI).c_str()); - // initial cell validation - if(cI != nullptr && (cI->rankOf() != 2 || cI->sizeAt(0) != bS || cI->sizeAt(1) != nOut)) - REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of initial cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({bS, nOut}).c_str(), ShapeUtils::shapeAsString(cI).c_str()); - // peephole weights validation - if(Wp != nullptr && (Wp->rankOf() != 1 || Wp->sizeAt(0) != 3*nOut)) - REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong peephole weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({3*nOut}).c_str(), ShapeUtils::shapeAsString(Wp).c_str()); - // gradient vs. output at last time step validation - if(dLdhL != nullptr && (dLdhL->rankOf() != 2 || dLdhL->sizeAt(0) != bS || dLdhL->sizeAt(1) != nOut)) - REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of gradient vs. output at last time step, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({bS, nOut}).c_str(), ShapeUtils::shapeAsString(dLdhL).c_str()); - // gradient vs. cell state at last time step validation - if(dLdcL != nullptr && (dLdcL->rankOf() != 2 || dLdcL->sizeAt(0) != bS || dLdcL->sizeAt(1) != nOut)) - REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of gradient vs. cell state at last time, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({bS, nOut}).c_str(), ShapeUtils::shapeAsString(dLdcL).c_str()); + // equations (no peephole connections) + // it = σ(Wxi * xt + Wri * ht-1 + bi) + // ft = σ(Wxf * xt + Wrf * ht-1 + bf) + // c't = tanh(Wxc * xt + Wrc * ht-1 + bc) + // ct = ft ◦ ct-1 + it ◦ c't + // ot = σ(Wxo * xt + Wro * ht-1 + bo) + // ht = ot ◦ tanh(ct) + + // equations (peephole connections are present) + // it = σ(Wxi * xt + Wri * ht-1 + Wpi ◦ ct-1 + bi) + // ft = σ(Wxf * xt + Wrf * ht-1 + Wpf ◦ ct-1 + bf) + // c't = tanh(Wxc * xt + Wrc * ht-1 + bc) + // ct = clip(ft ◦ ct-1 + it ◦ c't) + // ot = σ(Wxo * xt + Wro * ht-1 + Wpo ◦ ct + bo) + // ht = ot ◦ tanh(ct) + + // notations: + // bS - batch size + // sL - sequence length, number of time steps + // nIn - input size + // nOut - output size (hidden size) + + // INPUTS: + + // ******* + // input x: + // 1) [sL, bS, nIn] when dataFormat == 0 + // 2) [bS, sL, nIn] when dataFormat == 1 + // 3) [bS, nIn, sL] when dataFormat == 2 + + // ******* + // input weights Wx: + // 1) [nIn, 4*nOut] when directionMode < 2 + // 2) [2, nIn, 4*nOut] when directionMode >= 2 + + // ******* + // recurrent weights Wr: + // 1) [nOut, 4*nOut] when directionMode < 2 + // 2) [2, nOut, 4*nOut] when directionMode >= 2 + + // ******* + // peephole weights Wp, optional: + // 1) [3*nOut] when directionMode < 2 + // 2) [2, 3*nOut] when directionMode >= 2 + + // ******* + // biases b, optional: + // 1) [4*nOut] when directionMode < 2 + // 2) [2, 4*nOut] when directionMode >= 2 + + // ******* + // sequence length array seqLen, optional: + // 1) [bS] + + // ******* + // initial output hI, optional: + // 1) [bS, nOut] when directionMode < 2 + // 2) [2, bS, nOut] when directionMode >= 2 + + // ******* + // initial cell state cI (same shape as in hI), optional: + // 1) [bS, nOut] when directionMode < 2 + // 2) [2, bS, nOut] when directionMode >= 2 + + // ******* + // gradient vs. output dLdh, optional: + // 1) [sL, bS, nOut] when directionMode <= 2 && dataFormat == 0 + // 2) [bS, sL, nOut] when directionMode <= 2 && dataFormat == 1 + // 3) [bS, nOut, sL] when directionMode <= 2 && dataFormat == 2 + // 4) [sL, bS, 2*nOut] when directionMode == 3 && dataFormat == 0 + // 5) [bS, sL, 2*nOut] when directionMode == 3 && dataFormat == 1 + // 6) [bS, 2*nOut, sL] when directionMode == 3 && dataFormat == 2 + // 7) [sL, 2, bS, nOut] when directionMode == 4 && dataFormat == 3 + + // ******* + // gradient vs output at last time step dLdhL, optional: + // 1) [bS, nOut] when directionMode < 2 + // 2) [2, bS, nOut] when directionMode >= 2 + + // ******* + // gradient vs cell state at last time step dLdcL(same shape as in dLdhL), + // optional: 1) [bS, nOut] when directionMode < 2 2) [2, bS, nOut] when + // directionMode >= 2 + + // OUTPUTS: + + // ******* + // gradient vs. input dLdx: + // 1) [sL, bS, nIn] when dataFormat == 0 + // 2) [bS, sL, nIn] when dataFormat == 1 + // 3) [bS, nIn, sL] when dataFormat == 2 + + // ******* + // gradient vs. input weights dLdWx: + // 1) [nIn, 4*nOut] when directionMode < 2 + // 2) [2, nIn, 4*nOut] when directionMode >= 2 + + // ******* + // gradient vs. recurrent weights dLdWr: + // 1) [nOut, 4*nOut] when directionMode < 2 + // 2) [2, nOut, 4*nOut] when directionMode >= 2 + + // ******* + // gradient vs. peephole weights dLdWp, optional: + // 1) [3*nOut] when directionMode < 2 + // 2) [2, 3*nOut] when directionMode >= 2 + + // ******* + // gradient vs. biases dLdb, optional: + // 1) [4*nOut] when directionMode < 2 + // 2) [2, 4*nOut] when directionMode >= 2 + + // gradient vs. sequence length array dLdsL, optional (do not calculate + // it!!!): 1) [bS] always + + // ******* + // gradient vs. initial output dLdhI, optional: + // 1) [bS, nOut] when directionMode < 2 + // 2) [2, bS, nOut] when directionMode >= 2 + + // ******* + // gradient vs. initial cell state dLdcI (same shape as in dLdhI), optional: + // 1) [bS, nOut] when directionMode < 2 + // 2) [2, bS, nOut] when directionMode >= 2 + + // !!! dimension 4*nOut implies order it, ft, c't, ot + // !!! dimension 3*nOut implies order it, ft, ot + + const auto dataFormat = + INT_ARG(0); // for unidirectional: 0 = [sL, bS, nIn], 1 = [bS, sL ,nIn], + // 2 = [bS, nIn, sL], for bidirectional: 3 = [sL, bS, nIn] && + // [sL, 2, bS, nOut] (for ONNX) + const auto directionMode = + INT_ARG(1); // direction: 0 = fwd, 1 = bwd, 2 = bidirectional sum, 3 = + // bidirectional concat, 4 = bidirectional extra output dim + // (in conjunction with format dataFormat = 3) + + // integer numbers corresponding to activations: 0=tanh, 1=relu, 2=sigmoid, + // 3=affine, 4=leaky relu, 5= thresholded relu, 6=scaled tanh, 7=hard sigmoid, + // 8=ELU, 9=softsign, 10=softplus + const auto gateAct = + INT_ARG(2); // activation for input (i), forget (f) and output (o) gates + const auto cellAct = INT_ARG(3); // activation for cell state (c) + const auto outAct = INT_ARG(4); // activation for output (h) + + const auto hasBiases = + B_ARG(0); // indicates whether biases array is provided + const auto hasSeqLen = + B_ARG(1); // indicates whether seqLen array is provided + const auto hasInitH = + B_ARG(2); // indicates whether initial output is provided + const auto hasInitC = + B_ARG(3); // indicates whether initial cell state is provided + const auto hasPH = + B_ARG(4); // indicates whether peephole connections are present + const auto retFullSeq = + B_ARG(5); // indicates whether gradient vs. outputs is given for whole + // time sequence dLdh {dLdh_0, dLdh_1, ... , dLdh_sL-1} + const auto retLastH = B_ARG(6); // indicates whether gradient vs. output at + // last time step (dLdhL) is given + const auto retLastC = B_ARG(7); // indicates whether gradient vs. cell state + // at last time step (dLdcL) is given + + const auto gateActHasAlpha = gateAct == 3 || gateAct == 4 || gateAct == 5 || + gateAct == 6 || gateAct == 8; + const auto cellActHasAlpha = cellAct == 3 || cellAct == 4 || cellAct == 5 || + cellAct == 6 || cellAct == 8; + const auto outActHasAlpha = + outAct == 3 || outAct == 4 || outAct == 5 || outAct == 6 || outAct == 8; + const auto gateActHasBeta = gateAct == 3 || gateAct == 6; + const auto cellActHasBeta = cellAct == 3 || cellAct == 6; + const auto outActHasBeta = outAct == 3 || outAct == 6; + + uint count = 1; + const auto cellClip = + T_ARG(0); // cell clipping value, if it = 0 then do not apply clipping + const auto gateAlpha = gateActHasAlpha ? T_ARG(count++) : 0; + const auto gateBeta = gateActHasBeta ? T_ARG(count++) : 0; + const auto cellAlpha = cellActHasAlpha ? T_ARG(count++) : 0; + const auto cellBeta = cellActHasBeta ? T_ARG(count++) : 0; + const auto outAlpha = outActHasAlpha ? T_ARG(count++) : 0; + const auto outBeta = outActHasBeta ? T_ARG(count++) : 0; + + REQUIRE_TRUE( + dataFormat < 3 || (dataFormat == 3 && directionMode == 4), 0, + "LSTM_LAYER_BP operation: if argument dataFormat = 3, then directionMode " + "= 4, but got dataFormat = %i and directionMode = %i instead !", + dataFormat, directionMode); + REQUIRE_TRUE(cellClip >= 0, 0, + "LSTM_LAYER_BP operation: cell clipping value should be " + "nonnegative (>=0) !"); + REQUIRE_TRUE(retFullSeq || retLastH || retLastC, 0, + "LSTM_LAYER_BP operation: please specify at least one of three " + "input gradient arrays: dLdh, dLdhL or dLdcL !"); + + const auto x = INPUT_VARIABLE(0); // input + const auto Wx = INPUT_VARIABLE(1); // input weights + const auto Wr = INPUT_VARIABLE(2); // recurrent weights + + count = 3; + const auto b = hasBiases ? INPUT_VARIABLE(count++) : nullptr; // biases + const auto seqLen = + hasSeqLen ? INPUT_VARIABLE(count++) : nullptr; // seqLen vector + const auto hI = + hasInitH ? INPUT_VARIABLE(count++) : nullptr; // initial output + const auto cI = + hasInitC ? INPUT_VARIABLE(count++) : nullptr; // initial cell state + const auto Wp = + hasPH ? INPUT_VARIABLE(count++) : nullptr; // peephole weights + const auto dLdh = + retFullSeq ? INPUT_VARIABLE(count++) : nullptr; // gradient vs. output + const auto dLdhL = retLastH + ? INPUT_VARIABLE(count++) + : nullptr; // gradient vs. output at last time step + const auto dLdcL = + retLastC ? INPUT_VARIABLE(count++) + : nullptr; // gradient vs. cell state at last time step + + count = 3; + auto dLdx = OUTPUT_VARIABLE(0); // gradient vs. input + auto dLdWx = OUTPUT_NULLIFIED(1); // gradient vs. input weights + auto dLdWr = OUTPUT_NULLIFIED(2); // gradient vs. recurrent weights + auto dLdb = + hasBiases ? OUTPUT_NULLIFIED(count++) : nullptr; // gradient vs. biases + auto dLdsL = + hasSeqLen + ? INPUT_VARIABLE(count++) + : nullptr; // gradient vs. seqLen vector, we don't calculate it !!! + auto dLdhI = hasInitH ? OUTPUT_NULLIFIED(count++) + : nullptr; // gradient vs. initial output + auto dLdcI = hasInitC ? OUTPUT_NULLIFIED(count++) + : nullptr; // gradient vs. initial cell state + auto dLdWp = hasPH ? OUTPUT_NULLIFIED(count) + : nullptr; // gradient vs. peephole weights + + // evaluate dimensions + const Nd4jLong sL = dataFormat == 3 ? x->sizeAt(0) : x->sizeAt(dataFormat); + const Nd4jLong bS = + dataFormat == 1 || dataFormat == 2 ? x->sizeAt(0) : x->sizeAt(1); + const Nd4jLong nIn = dataFormat == 2 ? x->sizeAt(1) : x->sizeAt(2); + const Nd4jLong nOut = Wx->sizeAt(-1) / 4; + + // inputs validations + if (directionMode < 2) { // no bidirectional + + // Wx validation + if (Wx->rankOf() != 2 || Wx->sizeAt(0) != nIn) + REQUIRE_TRUE(false, 0, + "LSTM_LAYER_BP operation: wrong shape of input weights, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString({nIn, 4 * nOut}).c_str(), + ShapeUtils::shapeAsString(Wx).c_str()); + // Wr validation + if (Wr->rankOf() != 2 || Wr->sizeAt(0) != nOut || Wr->sizeAt(1) != 4 * nOut) + REQUIRE_TRUE(false, 0, + "LSTM_LAYER_BP operation: wrong shape of recurrent weights, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString({nOut, 4 * nOut}).c_str(), + ShapeUtils::shapeAsString(Wr).c_str()); + // biases validation + if (b != nullptr && (b->rankOf() != 1 || b->sizeAt(0) != 4 * nOut)) + REQUIRE_TRUE(false, 0, + "LSTM_LAYER_BP operation: wrong shape of biases, expected " + "is %s, but got %s instead !", + ShapeUtils::shapeAsString({4 * nOut}).c_str(), + ShapeUtils::shapeAsString(b).c_str()); + // initial output validation + if (hI != nullptr && + (hI->rankOf() != 2 || hI->sizeAt(0) != bS || hI->sizeAt(1) != nOut)) + REQUIRE_TRUE(false, 0, + "LSTM_LAYER_BP operation: wrong shape of initial output, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString({bS, nOut}).c_str(), + ShapeUtils::shapeAsString(hI).c_str()); + // initial cell validation + if (cI != nullptr && + (cI->rankOf() != 2 || cI->sizeAt(0) != bS || cI->sizeAt(1) != nOut)) + REQUIRE_TRUE(false, 0, + "LSTM_LAYER_BP operation: wrong shape of initial cell " + "state, expected is %s, but got %s instead !", + ShapeUtils::shapeAsString({bS, nOut}).c_str(), + ShapeUtils::shapeAsString(cI).c_str()); + // peephole weights validation + if (Wp != nullptr && (Wp->rankOf() != 1 || Wp->sizeAt(0) != 3 * nOut)) + REQUIRE_TRUE(false, 0, + "LSTM_LAYER_BP operation: wrong peephole weights, expected " + "is %s, but got %s instead !", + ShapeUtils::shapeAsString({3 * nOut}).c_str(), + ShapeUtils::shapeAsString(Wp).c_str()); + // gradient vs. output at last time step validation + if (dLdhL != nullptr && (dLdhL->rankOf() != 2 || dLdhL->sizeAt(0) != bS || + dLdhL->sizeAt(1) != nOut)) + REQUIRE_TRUE( + false, 0, + "LSTM_LAYER_BP operation: wrong shape of gradient vs. output at last " + "time step, expected is %s, but got %s instead !", + ShapeUtils::shapeAsString({bS, nOut}).c_str(), + ShapeUtils::shapeAsString(dLdhL).c_str()); + // gradient vs. cell state at last time step validation + if (dLdcL != nullptr && (dLdcL->rankOf() != 2 || dLdcL->sizeAt(0) != bS || + dLdcL->sizeAt(1) != nOut)) + REQUIRE_TRUE(false, 0, + "LSTM_LAYER_BP operation: wrong shape of gradient vs. cell " + "state at last time, expected is %s, but got %s instead !", + ShapeUtils::shapeAsString({bS, nOut}).c_str(), + ShapeUtils::shapeAsString(dLdcL).c_str()); + } else { // bidirectional + // Wx validation + if (Wx->rankOf() != 3 || Wx->sizeAt(0) != 2 || Wx->sizeAt(1) != nIn) + REQUIRE_TRUE(false, 0, + "LSTM_LAYER_BP operation: wrong shape of input weights, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString({2, nIn, 4 * nOut}).c_str(), + ShapeUtils::shapeAsString(Wx).c_str()); + // Wr validation + if (Wr->rankOf() != 3 || Wr->sizeAt(0) != 2 || Wr->sizeAt(1) != nOut || + Wr->sizeAt(2) != 4 * nOut) + REQUIRE_TRUE(false, 0, + "LSTM_LAYER_BP operation: wrong shape of recurrent weights, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString({2, nOut, 4 * nOut}).c_str(), + ShapeUtils::shapeAsString(Wr).c_str()); + // biases validation + if (b != nullptr && + (b->rankOf() != 2 || b->sizeAt(0) != 2 || b->sizeAt(1) != 4 * nOut)) + REQUIRE_TRUE(false, 0, + "LSTM_LAYER_BP operation: wrong shape of biases, expected " + "is %s, but got %s instead !", + ShapeUtils::shapeAsString({2, 4 * nOut}).c_str(), + ShapeUtils::shapeAsString(b).c_str()); + // initial output validation + if (hI != nullptr && (hI->rankOf() != 3 || hI->sizeAt(0) != 2 || + hI->sizeAt(1) != bS || hI->sizeAt(2) != nOut)) + REQUIRE_TRUE(false, 0, + "LSTM_LAYER_BP operation: wrong shape of initial output, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString({2, bS, nOut}).c_str(), + ShapeUtils::shapeAsString(hI).c_str()); + // initial cell validation + if (cI != nullptr && (cI->rankOf() != 3 || cI->sizeAt(0) != 2 || + cI->sizeAt(1) != bS || cI->sizeAt(2) != nOut)) + REQUIRE_TRUE(false, 0, + "LSTM_LAYER_BP operation: wrong shape of initial cell " + "state, expected is %s, but got %s instead !", + ShapeUtils::shapeAsString({2, bS, nOut}).c_str(), + ShapeUtils::shapeAsString(cI).c_str()); + // peephole weights validation + if (Wp != nullptr && + (Wp->rankOf() != 2 || Wp->sizeAt(0) != 2 || Wp->sizeAt(1) != 3 * nOut)) + REQUIRE_TRUE(false, 0, + "LSTM_LAYER_BP operation: wrong peephole weights, expected " + "is %s, but got %s instead !", + ShapeUtils::shapeAsString({2, 3 * nOut}).c_str(), + ShapeUtils::shapeAsString(Wp).c_str()); + // gradient vs. output at last time step validation + if (dLdhL != nullptr && + (dLdhL->rankOf() != 3 || dLdhL->sizeAt(0) != 2 || + dLdhL->sizeAt(1) != bS || dLdhL->sizeAt(2) != nOut)) + REQUIRE_TRUE( + false, 0, + "LSTM_LAYER_BP operation: wrong shape of gradient vs. output at last " + "time step, expected is %s, but got %s instead !", + ShapeUtils::shapeAsString({2, bS, nOut}).c_str(), + ShapeUtils::shapeAsString(dLdhL).c_str()); + // gradient vs. cell state at last time step validation + if (dLdcL != nullptr && + (dLdcL->rankOf() != 3 || dLdcL->sizeAt(0) != 2 || + dLdcL->sizeAt(1) != bS || dLdcL->sizeAt(2) != nOut)) + REQUIRE_TRUE(false, 0, + "LSTM_LAYER_BP operation: wrong shape of gradient vs. cell " + "state at last time, expected is %s, but got %s instead !", + ShapeUtils::shapeAsString({2, bS, nOut}).c_str(), + ShapeUtils::shapeAsString(dLdcL).c_str()); + } + + // gradient vs. output validation + if (dLdh) { + int factor = directionMode <= 2 ? 1 : 2; + std::vector expdLdhShape; + if (dataFormat == 0) + expdLdhShape = std::vector{sL, bS, factor * nOut}; + else if (dataFormat == 1) + expdLdhShape = std::vector{bS, sL, factor * nOut}; + else if (dataFormat == 2) + expdLdhShape = std::vector{bS, factor * nOut, sL}; + else + expdLdhShape = std::vector{sL, 2, bS, nOut}; + REQUIRE_TRUE(dLdh->isSameShape(expdLdhShape), 0, + "LSTM_LAYER_CELL_BP operation: wrong shape of gradient vs. " + "output, expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expdLdhShape).c_str(), + ShapeUtils::shapeAsString(dLdh).c_str()); + } + + std::vector params = { + static_cast(dataFormat), static_cast(directionMode), + static_cast(cellClip), static_cast(gateAct), + static_cast(gateAlpha), static_cast(gateBeta), + static_cast(cellAct), static_cast(cellAlpha), + static_cast(cellBeta), static_cast(outAct), + static_cast(outAlpha), static_cast(outBeta)}; + + if (directionMode == 0) { // forward + + helpers::lstmLayerTimeLoopBp(x, Wx, Wr, b, seqLen, hI, cI, Wp, dLdh, dLdhL, + dLdcL, params, true, dLdx, dLdWx, dLdWr, dLdb, + dLdhI, dLdcI, dLdWp); + } else if (directionMode == 1) { // backward + + helpers::lstmLayerTimeLoopBp(x, Wx, Wr, b, seqLen, hI, cI, Wp, dLdh, dLdhL, + dLdcL, params, false, dLdx, dLdWx, dLdWr, dLdb, + dLdhI, dLdcI, dLdWp); + } else { // bidirectional + + NDArray WxFwd = (*Wx)({0, 1, 0, 0, 0, 0}); + NDArray WxBwd = (*Wx)({1, 2, 0, 0, 0, 0}); + NDArray dLdWxFwd = (*dLdWx)({0, 1, 0, 0, 0, 0}); + NDArray dLdWxBwd = (*dLdWx)({1, 2, 0, 0, 0, 0}); + + NDArray WrFwd = (*Wr)({0, 1, 0, 0, 0, 0}); + NDArray WrBwd = (*Wr)({1, 2, 0, 0, 0, 0}); + NDArray dLdWrFwd = (*dLdWr)({0, 1, 0, 0, 0, 0}); + NDArray dLdWrBwd = (*dLdWr)({1, 2, 0, 0, 0, 0}); + + NDArray *WpFwd(nullptr), *WpBwd(nullptr), *bFwd(nullptr), *bBwd(nullptr), + *hIFwd(nullptr), *hIBwd(nullptr), *cIFwd(nullptr), *cIBwd(nullptr), + *dLdhFwd(nullptr), *dLdhBwd(nullptr), *dLdhLFwd(nullptr), + *dLdhLBwd(nullptr), *dLdcLFwd(nullptr), *dLdcLBwd(nullptr), + *dLdWpFwd(nullptr), *dLdWpBwd(nullptr), *dLdbFwd(nullptr), + *dLdbBwd(nullptr), *dLdhIFwd(nullptr), *dLdhIBwd(nullptr), + *dLdcIFwd(nullptr), *dLdcIBwd(nullptr); + + if (Wp) { + WpFwd = new NDArray((*Wp)({0, 1, 0, 0})); + WpBwd = new NDArray((*Wp)({1, 2, 0, 0})); + dLdWpFwd = new NDArray((*dLdWp)({0, 1, 0, 0})); + dLdWpBwd = new NDArray((*dLdWp)({1, 2, 0, 0})); } - else { // bidirectional - // Wx validation - if(Wx->rankOf() != 3 || Wx->sizeAt(0) != 2 || Wx->sizeAt(1) != nIn) - REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of input weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, nIn, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wx).c_str()); - // Wr validation - if(Wr->rankOf() != 3 || Wr->sizeAt(0) != 2 || Wr->sizeAt(1) != nOut || Wr->sizeAt(2) != 4*nOut) - REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of recurrent weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, nOut, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wr).c_str()); - // biases validation - if(b != nullptr && (b->rankOf() != 2 || b->sizeAt(0) != 2 || b->sizeAt(1) != 4*nOut)) - REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, 4*nOut}).c_str(), ShapeUtils::shapeAsString(b).c_str()); - // initial output validation - if(hI != nullptr && (hI->rankOf() != 3 || hI->sizeAt(0) != 2 || hI->sizeAt(1) != bS || hI->sizeAt(2) != nOut)) - REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of initial output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, bS, nOut}).c_str(), ShapeUtils::shapeAsString(hI).c_str()); - // initial cell validation - if(cI != nullptr && (cI->rankOf() != 3 || cI->sizeAt(0) != 2 || cI->sizeAt(1) != bS || cI->sizeAt(2) != nOut)) - REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of initial cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, bS, nOut}).c_str(), ShapeUtils::shapeAsString(cI).c_str()); - // peephole weights validation - if(Wp != nullptr && (Wp->rankOf() != 2 || Wp->sizeAt(0) != 2 || Wp->sizeAt(1) != 3*nOut)) - REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong peephole weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, 3*nOut}).c_str(), ShapeUtils::shapeAsString(Wp).c_str()); - // gradient vs. output at last time step validation - if(dLdhL != nullptr && (dLdhL->rankOf() != 3 || dLdhL->sizeAt(0) != 2 || dLdhL->sizeAt(1) != bS || dLdhL->sizeAt(2) != nOut)) - REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of gradient vs. output at last time step, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, bS, nOut}).c_str(), ShapeUtils::shapeAsString(dLdhL).c_str()); - // gradient vs. cell state at last time step validation - if(dLdcL != nullptr && (dLdcL->rankOf() != 3 || dLdcL->sizeAt(0) != 2 || dLdcL->sizeAt(1) != bS || dLdcL->sizeAt(2) != nOut)) - REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of gradient vs. cell state at last time, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, bS, nOut}).c_str(), ShapeUtils::shapeAsString(dLdcL).c_str()); + if (b) { + bFwd = new NDArray((*b)({0, 1, 0, 0})); + bBwd = new NDArray((*b)({1, 2, 0, 0})); + dLdbFwd = new NDArray((*dLdb)({0, 1, 0, 0})); + dLdbBwd = new NDArray((*dLdb)({1, 2, 0, 0})); } - - // gradient vs. output validation - if(dLdh) { - int factor = directionMode <= 2 ? 1 : 2; - std::vector expdLdhShape; - if(dataFormat == 0) expdLdhShape = std::vector{sL, bS, factor*nOut}; - else if(dataFormat == 1) expdLdhShape = std::vector{bS, sL, factor*nOut}; - else if(dataFormat == 2) expdLdhShape = std::vector{bS, factor*nOut, sL}; - else expdLdhShape = std::vector{sL, 2, bS, nOut}; - REQUIRE_TRUE(dLdh->isSameShape(expdLdhShape), 0, "LSTM_LAYER_CELL_BP operation: wrong shape of gradient vs. output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expdLdhShape).c_str(), ShapeUtils::shapeAsString(dLdh).c_str()); + if (hI) { + hIFwd = new NDArray((*hI)({0, 1, 0, 0, 0, 0})); + hIBwd = new NDArray((*hI)({1, 2, 0, 0, 0, 0})); + dLdhIFwd = new NDArray((*dLdhI)({0, 1, 0, 0, 0, 0})); + dLdhIBwd = new NDArray((*dLdhI)({1, 2, 0, 0, 0, 0})); } - - std::vector params = {static_cast(dataFormat), static_cast(directionMode), static_cast(cellClip), - static_cast(gateAct), static_cast(gateAlpha), static_cast(gateBeta), - static_cast(cellAct), static_cast(cellAlpha), static_cast(cellBeta), - static_cast(outAct), static_cast(outAlpha), static_cast(outBeta)}; - - if(directionMode == 0) { // forward - - helpers::lstmLayerTimeLoopBp(x, Wx, Wr, b, seqLen, hI, cI, Wp, dLdh, dLdhL, dLdcL, params, true, dLdx, dLdWx, dLdWr, dLdb, dLdhI, dLdcI, dLdWp); + if (cI) { + cIFwd = new NDArray((*cI)({0, 1, 0, 0, 0, 0})); + cIBwd = new NDArray((*cI)({1, 2, 0, 0, 0, 0})); + dLdcIFwd = new NDArray((*dLdcI)({0, 1, 0, 0, 0, 0})); + dLdcIBwd = new NDArray((*dLdcI)({1, 2, 0, 0, 0, 0})); + } + if (dLdhL) { + dLdhLFwd = new NDArray((*dLdhL)({0, 1, 0, 0, 0, 0})); + dLdhLBwd = new NDArray((*dLdhL)({1, 2, 0, 0, 0, 0})); + } + if (dLdcL) { + dLdcLFwd = new NDArray((*dLdcL)({0, 1, 0, 0, 0, 0})); + dLdcLBwd = new NDArray((*dLdcL)({1, 2, 0, 0, 0, 0})); } - else if(directionMode == 1) { // backward - helpers::lstmLayerTimeLoopBp(x, Wx, Wr, b, seqLen, hI, cI, Wp, dLdh, dLdhL, dLdcL, params, false, dLdx, dLdWx, dLdWr, dLdb, dLdhI, dLdcI, dLdWp); + if (dLdh) { + if (directionMode == 2) { // sum + dLdhFwd = dLdh; + dLdhBwd = dLdh; + } else if (directionMode == 3) { // concat + dLdhFwd = new NDArray(dataFormat <= 1 ? (*dLdh)({0, 0, 0, 0, 0, nOut}) + : (*dLdh)({0, 0, 0, nOut, 0, 0})); + dLdhBwd = new NDArray(dataFormat <= 1 + ? (*dLdh)({0, 0, 0, 0, nOut, 2 * nOut}) + : (*dLdh)({0, 0, nOut, 2 * nOut, 0, 0})); + } else { // directionMode == 4 + dLdhFwd = new NDArray((*dLdh)({0, 0, 0, 1, 0, 0, 0, 0})); + dLdhBwd = new NDArray((*dLdh)({0, 0, 1, 2, 0, 0, 0, 0})); + } } - else { // bidirectional - - NDArray WxFwd = (*Wx)({0,1, 0,0, 0,0}); - NDArray WxBwd = (*Wx)({1,2, 0,0, 0,0}); - NDArray dLdWxFwd = (*dLdWx)({0,1, 0,0, 0,0}); - NDArray dLdWxBwd = (*dLdWx)({1,2, 0,0, 0,0}); - - NDArray WrFwd = (*Wr)({0,1, 0,0, 0,0}); - NDArray WrBwd = (*Wr)({1,2, 0,0, 0,0}); - NDArray dLdWrFwd = (*dLdWr)({0,1, 0,0, 0,0}); - NDArray dLdWrBwd = (*dLdWr)({1,2, 0,0, 0,0}); - - NDArray *WpFwd(nullptr), *WpBwd(nullptr), *bFwd(nullptr), *bBwd(nullptr), *hIFwd(nullptr), *hIBwd(nullptr), *cIFwd(nullptr), *cIBwd(nullptr), - *dLdhFwd(nullptr), *dLdhBwd(nullptr), *dLdhLFwd(nullptr), *dLdhLBwd(nullptr), *dLdcLFwd(nullptr), *dLdcLBwd(nullptr), - *dLdWpFwd(nullptr), *dLdWpBwd(nullptr), *dLdbFwd(nullptr), *dLdbBwd(nullptr), - *dLdhIFwd(nullptr), *dLdhIBwd(nullptr), *dLdcIFwd(nullptr), *dLdcIBwd(nullptr); - - if(Wp) { - WpFwd = new NDArray((*Wp)({0,1, 0,0})); - WpBwd = new NDArray((*Wp)({1,2, 0,0})); - dLdWpFwd = new NDArray((*dLdWp)({0,1, 0,0})); - dLdWpBwd = new NDArray((*dLdWp)({1,2, 0,0})); - } - if(b) { - bFwd = new NDArray((*b)({0,1, 0,0})); - bBwd = new NDArray((*b)({1,2, 0,0})); - dLdbFwd = new NDArray((*dLdb)({0,1, 0,0})); - dLdbBwd = new NDArray((*dLdb)({1,2, 0,0})); - } - if(hI) { - hIFwd = new NDArray((*hI)({0,1, 0,0, 0,0})); - hIBwd = new NDArray((*hI)({1,2, 0,0, 0,0})); - dLdhIFwd = new NDArray((*dLdhI)({0,1, 0,0, 0,0})); - dLdhIBwd = new NDArray((*dLdhI)({1,2, 0,0, 0,0})); - } - if(cI) { - cIFwd = new NDArray((*cI)({0,1, 0,0, 0,0})); - cIBwd = new NDArray((*cI)({1,2, 0,0, 0,0})); - dLdcIFwd = new NDArray((*dLdcI)({0,1, 0,0, 0,0})); - dLdcIBwd = new NDArray((*dLdcI)({1,2, 0,0, 0,0})); - } - if(dLdhL) { - dLdhLFwd = new NDArray((*dLdhL)({0,1, 0,0, 0,0})); - dLdhLBwd = new NDArray((*dLdhL)({1,2, 0,0, 0,0})); - } - if(dLdcL) { - dLdcLFwd = new NDArray((*dLdcL)({0,1, 0,0, 0,0})); - dLdcLBwd = new NDArray((*dLdcL)({1,2, 0,0, 0,0})); - } - - if(dLdh) { - if(directionMode == 2) { // sum - dLdhFwd = dLdh; - dLdhBwd = dLdh; - } - else if(directionMode == 3) { // concat - dLdhFwd = new NDArray(dataFormat <= 1 ? (*dLdh)({0,0, 0,0, 0,nOut}) : (*dLdh)({0,0, 0,nOut, 0,0})); - dLdhBwd = new NDArray(dataFormat <= 1 ? (*dLdh)({0,0, 0,0, nOut,2*nOut}) : (*dLdh)({0,0, nOut,2*nOut, 0,0})); - } - else { // directionMode == 4 - dLdhFwd = new NDArray((*dLdh)({0,0, 0,1, 0,0, 0,0})); - dLdhBwd = new NDArray((*dLdh)({0,0, 1,2, 0,0, 0,0})); - } - } - - NDArray dLdxBwd = dLdx->ulike(); - - // FIXME - following two calls are independent and may run in different streams - helpers::lstmLayerTimeLoopBp(x, &WxFwd, &WrFwd, bFwd, seqLen, hIFwd, cIFwd, WpFwd, dLdhFwd, dLdhLFwd, dLdcLFwd, params, true, dLdx, &dLdWxFwd, &dLdWrFwd, dLdbFwd, dLdhIFwd, dLdcIFwd, dLdWpFwd); - helpers::lstmLayerTimeLoopBp(x, &WxBwd, &WrBwd, bBwd, seqLen, hIBwd, cIBwd, WpBwd, dLdhBwd, dLdhLBwd, dLdcLBwd, params, false, &dLdxBwd, &dLdWxBwd, &dLdWrBwd, dLdbBwd, dLdhIBwd, dLdcIBwd, dLdWpBwd); - - *dLdx += dLdxBwd; - - delete WpFwd; delete WpBwd; delete bFwd; delete bBwd; delete hIFwd; delete hIBwd; delete cIFwd; delete cIBwd; - delete dLdhLFwd; delete dLdhLBwd; delete dLdcLFwd; delete dLdcLBwd; - delete dLdWpFwd; delete dLdWpBwd; delete dLdbFwd; delete dLdbBwd; - delete dLdhIFwd; delete dLdhIBwd; delete dLdcIFwd; delete dLdcIBwd; - - if(!(dLdh && directionMode == 2)) { delete dLdhFwd; delete dLdhBwd; } + + NDArray dLdxBwd = dLdx->ulike(); + + // FIXME - following two calls are independent and may run in different + // streams + helpers::lstmLayerTimeLoopBp(x, &WxFwd, &WrFwd, bFwd, seqLen, hIFwd, cIFwd, + WpFwd, dLdhFwd, dLdhLFwd, dLdcLFwd, params, + true, dLdx, &dLdWxFwd, &dLdWrFwd, dLdbFwd, + dLdhIFwd, dLdcIFwd, dLdWpFwd); + helpers::lstmLayerTimeLoopBp(x, &WxBwd, &WrBwd, bBwd, seqLen, hIBwd, cIBwd, + WpBwd, dLdhBwd, dLdhLBwd, dLdcLBwd, params, + false, &dLdxBwd, &dLdWxBwd, &dLdWrBwd, dLdbBwd, + dLdhIBwd, dLdcIBwd, dLdWpBwd); + + *dLdx += dLdxBwd; + + delete WpFwd; + delete WpBwd; + delete bFwd; + delete bBwd; + delete hIFwd; + delete hIBwd; + delete cIFwd; + delete cIBwd; + delete dLdhLFwd; + delete dLdhLBwd; + delete dLdcLFwd; + delete dLdcLBwd; + delete dLdWpFwd; + delete dLdWpBwd; + delete dLdbFwd; + delete dLdbBwd; + delete dLdhIFwd; + delete dLdhIBwd; + delete dLdcIFwd; + delete dLdcIBwd; + + if (!(dLdh && directionMode == 2)) { + delete dLdhFwd; + delete dLdhBwd; } + } - return Status::OK(); + return Status::OK(); } DECLARE_TYPES(lstmLayer_bp) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } DECLARE_SHAPE_FN(lstmLayer_bp) { - - const auto hasBiases = B_ARG(0); // indicates whether biases array is provided - const auto hasSeqLen = B_ARG(1); // indicates whether seqLen array is provided - const auto hasInitH = B_ARG(2); // indicates whether initial output is provided - const auto hasInitC = B_ARG(3); // indicates whether initial cell state is provided - const auto hasPH = B_ARG(4); // indicates whether peephole connections are present - - int count = 3; - const auto x = INPUT_VARIABLE(0); // input - const auto Wx = INPUT_VARIABLE(1); // input weights - const auto Wr = INPUT_VARIABLE(2); // recurrent weights - const auto b = hasBiases ? INPUT_VARIABLE(count++) : nullptr; // biases - const auto seqLen = hasSeqLen ? INPUT_VARIABLE(count++) : nullptr; // seqLen vector - const auto hI = hasInitH ? INPUT_VARIABLE(count++) : nullptr; // initial output - const auto cI = hasInitC ? INPUT_VARIABLE(count++) : nullptr; // initial cell state - const auto Wp = hasPH ? INPUT_VARIABLE(count++) : nullptr; // peephole weights - - auto outShapes = SHAPELIST(x->shapeInfo(), Wx->shapeInfo(), Wr->shapeInfo()); - - if(b != nullptr) - outShapes->push_back(b->shapeInfo()); - if(seqLen != nullptr) - outShapes->push_back(seqLen->shapeInfo()); - if(hI != nullptr) - outShapes->push_back(hI->shapeInfo()); - if(cI != nullptr) - outShapes->push_back(cI->shapeInfo()); - if(Wp != nullptr) - outShapes->push_back(Wp->shapeInfo()); - - return outShapes; + const auto hasBiases = + B_ARG(0); // indicates whether biases array is provided + const auto hasSeqLen = + B_ARG(1); // indicates whether seqLen array is provided + const auto hasInitH = + B_ARG(2); // indicates whether initial output is provided + const auto hasInitC = + B_ARG(3); // indicates whether initial cell state is provided + const auto hasPH = + B_ARG(4); // indicates whether peephole connections are present + + int count = 3; + const auto x = INPUT_VARIABLE(0); // input + const auto Wx = INPUT_VARIABLE(1); // input weights + const auto Wr = INPUT_VARIABLE(2); // recurrent weights + const auto b = hasBiases ? INPUT_VARIABLE(count++) : nullptr; // biases + const auto seqLen = + hasSeqLen ? INPUT_VARIABLE(count++) : nullptr; // seqLen vector + const auto hI = + hasInitH ? INPUT_VARIABLE(count++) : nullptr; // initial output + const auto cI = + hasInitC ? INPUT_VARIABLE(count++) : nullptr; // initial cell state + const auto Wp = + hasPH ? INPUT_VARIABLE(count++) : nullptr; // peephole weights + + auto outShapes = SHAPELIST(x->shapeInfo(), Wx->shapeInfo(), Wr->shapeInfo()); + + if (b != nullptr) outShapes->push_back(b->shapeInfo()); + if (seqLen != nullptr) outShapes->push_back(seqLen->shapeInfo()); + if (hI != nullptr) outShapes->push_back(hI->shapeInfo()); + if (cI != nullptr) outShapes->push_back(cI->shapeInfo()); + if (Wp != nullptr) outShapes->push_back(Wp->shapeInfo()); + + return outShapes; } -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/nn/recurrent/lstmLayerCell.cpp b/libnd4j/include/ops/declarable/generic/nn/recurrent/lstmLayerCell.cpp index 645541d6b25a..9775b68cc34f 100644 --- a/libnd4j/include/ops/declarable/generic/nn/recurrent/lstmLayerCell.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/recurrent/lstmLayerCell.cpp @@ -22,318 +22,408 @@ #if NOT_EXCLUDED(OP_lstmLayerCell) #include -#include +#include namespace sd { namespace ops { - ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(lstmLayerCell, 5, 2, false, 1, 3) { - - // equations (no peephole connections) - // it = σ(Wxi * xt + Wri * ht-1 + bi) - // ft = σ(Wxf * xt + Wrf * ht-1 + bf) - // c't = tanh(Wxc * xt + Wrc * ht-1 + bc) - // ct = ft ◦ ct-1 + it ◦ c't - // ot = σ(Wxo * xt + Wro * ht-1 + bo) - // ht = ot ◦ tanh(ct) - - // equations (peephole connections are present) - // it = σ(Wxi * xt + Wri * ht-1 + Wpi ◦ ct-1 + bi) - // ft = σ(Wxf * xt + Wrf * ht-1 + Wpf ◦ ct-1 + bf) - // c't = tanh(Wxc * xt + Wrc * ht-1 + bc) - // ct = clip(ft ◦ ct-1 + it ◦ c't) - // ot = σ(Wxo * xt + Wro * ht-1 + Wpo ◦ ct + bo) - // ht = ot ◦ tanh(ct) - - // notations: - // bS - batch size - // nIn - input size - // nOut - output size (hidden size) - - // INPUTS: - // input x: [bS, nIn] or [nIn] - // input weights Wx: [nIn, 4*nOut] - // recurrent weights Wr: [nOut, 4*nOut] - // initial (previous) output hI: [bS, nOut] or [nOut] - // initial (previous) cell state cI: [bS, nOut] or [nOut] - // biases b (optional): [4*nOut] - // peephole weights Wp (optional): [3*nOut] - - // OUTPUTS: - // current output h: [bS, nOut] or [nOut] - // current cell state c: [bS, nOut] or [nOut] - - // !!! dimension 4*nOut implies order it, ft, c't, ot - // !!! dimension 3*nOut implies order it, ft, ot - - // integer numbers corresponding to activations: 0=tanh, 1=relu, 2=sigmoid, 3=affine, 4=leaky relu, 5= thresholded relu, 6=scaled tanh, 7=hard sigmoid, 8=ELU, 9=softsign, 10=softplus - const auto gateAct = INT_ARG(0); // activation for input (i), forget (f) and output (o) gates - const auto cellAct = INT_ARG(1); // activation for cell state (c) - const auto outAct = INT_ARG(2); // activation for output (h) - - const auto hasBiases = B_ARG(0); // indicates whether biases array is provided - const auto hasPH = B_ARG(1); // indicates whether peephole connections are present - - const auto gateActHasAlpha = gateAct == 3 || gateAct == 4 || gateAct == 5 || gateAct == 6 || gateAct == 8; - const auto cellActHasAlpha = cellAct == 3 || cellAct == 4 || cellAct == 5 || cellAct == 6 || cellAct == 8; - const auto outActHasAlpha = outAct == 3 || outAct == 4 || outAct == 5 || outAct == 6 || outAct == 8; - const auto gateActHasBeta = gateAct == 3 || gateAct == 6; - const auto cellActHasBeta = cellAct == 3 || cellAct == 6; - const auto outActHasBeta = outAct == 3 || outAct == 6; - - uint count = 1; - const auto cellClip = T_ARG(0); // cell clipping value, if it = 0 then do not apply clipping - const auto gateAlpha = gateActHasAlpha ? T_ARG(count++) : 0; - const auto gateBeta = gateActHasBeta ? T_ARG(count++) : 0; - const auto cellAlpha = cellActHasAlpha ? T_ARG(count++) : 0; - const auto cellBeta = cellActHasBeta ? T_ARG(count++) : 0; - const auto outAlpha = outActHasAlpha ? T_ARG(count++) : 0; - const auto outBeta = outActHasBeta ? T_ARG(count++) : 0; - - count = 3; - const auto x = INPUT_VARIABLE(0); // input - const auto Wx = INPUT_VARIABLE(1); // input weights - const auto Wr = INPUT_VARIABLE(2); // recurrent weights - const auto b = hasBiases ? INPUT_VARIABLE(count++) : nullptr; // biases - const auto hI = INPUT_VARIABLE(count++); // initial output - const auto cI = INPUT_VARIABLE(count++); // initial cell state - const auto Wp = hasPH ? INPUT_VARIABLE(count) : nullptr; // peephole weights - - REQUIRE_TRUE(cellClip >= 0 , 0, "LSTM_LAYER_CELL operation: cell clipping value should be nonnegative (>=0) !"); - - auto h = OUTPUT_VARIABLE(0); - auto c = OUTPUT_VARIABLE(1); - - // evaluate dimensions - const Nd4jLong bS = x->rankOf() == 1 ? 0 : x->sizeAt(0); - const Nd4jLong nIn = x->sizeAt(-1); - const Nd4jLong nOut = Wx->sizeAt(-1) / 4; - - // inputs validations - // Wx validation - if(Wx->rankOf() != 2 || Wx->sizeAt(0) != nIn) - REQUIRE_TRUE(false, 0, "LSTM_LAYER_CELL operation: wrong shape of input weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({nIn, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wx).c_str()); - // Wr validation - if(Wr->rankOf() != 2 || Wr->sizeAt(0) != nOut || Wr->sizeAt(1) != 4*nOut) - REQUIRE_TRUE(false, 0, "LSTM_LAYER_CELL operation: wrong shape of recurrent weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({nOut, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wr).c_str()); - // initial output/cell validation - std::vector exphIcIShape = x->rankOf() == 1 ? std::vector{nOut} : std::vector{bS, nOut}; - REQUIRE_TRUE(hI->isSameShape(exphIcIShape), 0, "LSTM_LAYER_CELL operation: wrong shape of initial output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(exphIcIShape).c_str(), ShapeUtils::shapeAsString(hI).c_str()); - REQUIRE_TRUE(cI->isSameShape(exphIcIShape), 0, "LSTM_LAYER_CELL operation: wrong shape of initial cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(exphIcIShape).c_str(), ShapeUtils::shapeAsString(cI).c_str()); - // biases validation - if(b != nullptr && (b->rankOf() != 1 || b->sizeAt(0) != 4*nOut)) - REQUIRE_TRUE(false, 0, "LSTM_LAYER_CELL operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({4*nOut}).c_str(), ShapeUtils::shapeAsString(b).c_str()); - // peephole weights validation - if(Wp != nullptr && (Wp->rankOf() != 1 || Wp->sizeAt(0) != 3*nOut)) - REQUIRE_TRUE(false, 0, "LSTM_LAYER_CELL operation: wrong shape of peephole weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({3*nOut}).c_str(), ShapeUtils::shapeAsString(Wp).c_str()); - - std::vector params = {static_cast(0)/*ignore*/, static_cast(0)/*ignore*/, static_cast(cellClip), - static_cast(gateAct), static_cast(gateAlpha), static_cast(gateBeta), - static_cast(cellAct), static_cast(cellAlpha), static_cast(cellBeta), - static_cast(outAct), static_cast(outAlpha), static_cast(outBeta)}; - - helpers::lstmLayerCell(x, Wx, Wr, b, hI, cI, Wp, params, h, c); - - return Status::OK(); + // equations (no peephole connections) + // it = σ(Wxi * xt + Wri * ht-1 + bi) + // ft = σ(Wxf * xt + Wrf * ht-1 + bf) + // c't = tanh(Wxc * xt + Wrc * ht-1 + bc) + // ct = ft ◦ ct-1 + it ◦ c't + // ot = σ(Wxo * xt + Wro * ht-1 + bo) + // ht = ot ◦ tanh(ct) + + // equations (peephole connections are present) + // it = σ(Wxi * xt + Wri * ht-1 + Wpi ◦ ct-1 + bi) + // ft = σ(Wxf * xt + Wrf * ht-1 + Wpf ◦ ct-1 + bf) + // c't = tanh(Wxc * xt + Wrc * ht-1 + bc) + // ct = clip(ft ◦ ct-1 + it ◦ c't) + // ot = σ(Wxo * xt + Wro * ht-1 + Wpo ◦ ct + bo) + // ht = ot ◦ tanh(ct) + + // notations: + // bS - batch size + // nIn - input size + // nOut - output size (hidden size) + + // INPUTS: + // input x: [bS, nIn] or [nIn] + // input weights Wx: [nIn, 4*nOut] + // recurrent weights Wr: [nOut, 4*nOut] + // initial (previous) output hI: [bS, nOut] or [nOut] + // initial (previous) cell state cI: [bS, nOut] or [nOut] + // biases b (optional): [4*nOut] + // peephole weights Wp (optional): [3*nOut] + + // OUTPUTS: + // current output h: [bS, nOut] or [nOut] + // current cell state c: [bS, nOut] or [nOut] + + // !!! dimension 4*nOut implies order it, ft, c't, ot + // !!! dimension 3*nOut implies order it, ft, ot + + // integer numbers corresponding to activations: 0=tanh, 1=relu, 2=sigmoid, + // 3=affine, 4=leaky relu, 5= thresholded relu, 6=scaled tanh, 7=hard sigmoid, + // 8=ELU, 9=softsign, 10=softplus + const auto gateAct = + INT_ARG(0); // activation for input (i), forget (f) and output (o) gates + const auto cellAct = INT_ARG(1); // activation for cell state (c) + const auto outAct = INT_ARG(2); // activation for output (h) + + const auto hasBiases = + B_ARG(0); // indicates whether biases array is provided + const auto hasPH = + B_ARG(1); // indicates whether peephole connections are present + + const auto gateActHasAlpha = gateAct == 3 || gateAct == 4 || gateAct == 5 || + gateAct == 6 || gateAct == 8; + const auto cellActHasAlpha = cellAct == 3 || cellAct == 4 || cellAct == 5 || + cellAct == 6 || cellAct == 8; + const auto outActHasAlpha = + outAct == 3 || outAct == 4 || outAct == 5 || outAct == 6 || outAct == 8; + const auto gateActHasBeta = gateAct == 3 || gateAct == 6; + const auto cellActHasBeta = cellAct == 3 || cellAct == 6; + const auto outActHasBeta = outAct == 3 || outAct == 6; + + uint count = 1; + const auto cellClip = + T_ARG(0); // cell clipping value, if it = 0 then do not apply clipping + const auto gateAlpha = gateActHasAlpha ? T_ARG(count++) : 0; + const auto gateBeta = gateActHasBeta ? T_ARG(count++) : 0; + const auto cellAlpha = cellActHasAlpha ? T_ARG(count++) : 0; + const auto cellBeta = cellActHasBeta ? T_ARG(count++) : 0; + const auto outAlpha = outActHasAlpha ? T_ARG(count++) : 0; + const auto outBeta = outActHasBeta ? T_ARG(count++) : 0; + + count = 3; + const auto x = INPUT_VARIABLE(0); // input + const auto Wx = INPUT_VARIABLE(1); // input weights + const auto Wr = INPUT_VARIABLE(2); // recurrent weights + const auto b = hasBiases ? INPUT_VARIABLE(count++) : nullptr; // biases + const auto hI = INPUT_VARIABLE(count++); // initial output + const auto cI = INPUT_VARIABLE(count++); // initial cell state + const auto Wp = hasPH ? INPUT_VARIABLE(count) : nullptr; // peephole weights + + REQUIRE_TRUE(cellClip >= 0, 0, + "LSTM_LAYER_CELL operation: cell clipping value should be " + "nonnegative (>=0) !"); + + auto h = OUTPUT_VARIABLE(0); + auto c = OUTPUT_VARIABLE(1); + + // evaluate dimensions + const Nd4jLong bS = x->rankOf() == 1 ? 0 : x->sizeAt(0); + const Nd4jLong nIn = x->sizeAt(-1); + const Nd4jLong nOut = Wx->sizeAt(-1) / 4; + + // inputs validations + // Wx validation + if (Wx->rankOf() != 2 || Wx->sizeAt(0) != nIn) + REQUIRE_TRUE(false, 0, + "LSTM_LAYER_CELL operation: wrong shape of input weights, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString({nIn, 4 * nOut}).c_str(), + ShapeUtils::shapeAsString(Wx).c_str()); + // Wr validation + if (Wr->rankOf() != 2 || Wr->sizeAt(0) != nOut || Wr->sizeAt(1) != 4 * nOut) + REQUIRE_TRUE(false, 0, + "LSTM_LAYER_CELL operation: wrong shape of recurrent weights, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString({nOut, 4 * nOut}).c_str(), + ShapeUtils::shapeAsString(Wr).c_str()); + // initial output/cell validation + std::vector exphIcIShape = x->rankOf() == 1 + ? std::vector{nOut} + : std::vector{bS, nOut}; + REQUIRE_TRUE(hI->isSameShape(exphIcIShape), 0, + "LSTM_LAYER_CELL operation: wrong shape of initial output, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(exphIcIShape).c_str(), + ShapeUtils::shapeAsString(hI).c_str()); + REQUIRE_TRUE(cI->isSameShape(exphIcIShape), 0, + "LSTM_LAYER_CELL operation: wrong shape of initial cell state, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(exphIcIShape).c_str(), + ShapeUtils::shapeAsString(cI).c_str()); + // biases validation + if (b != nullptr && (b->rankOf() != 1 || b->sizeAt(0) != 4 * nOut)) + REQUIRE_TRUE(false, 0, + "LSTM_LAYER_CELL operation: wrong shape of biases, expected " + "is %s, but got %s instead !", + ShapeUtils::shapeAsString({4 * nOut}).c_str(), + ShapeUtils::shapeAsString(b).c_str()); + // peephole weights validation + if (Wp != nullptr && (Wp->rankOf() != 1 || Wp->sizeAt(0) != 3 * nOut)) + REQUIRE_TRUE(false, 0, + "LSTM_LAYER_CELL operation: wrong shape of peephole weights, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString({3 * nOut}).c_str(), + ShapeUtils::shapeAsString(Wp).c_str()); + + std::vector params = { + static_cast(0) /*ignore*/, static_cast(0) /*ignore*/, + static_cast(cellClip), static_cast(gateAct), + static_cast(gateAlpha), static_cast(gateBeta), + static_cast(cellAct), static_cast(cellAlpha), + static_cast(cellBeta), static_cast(outAct), + static_cast(outAlpha), static_cast(outBeta)}; + + helpers::lstmLayerCell(x, Wx, Wr, b, hI, cI, Wp, params, h, c); + + return Status::OK(); } DECLARE_TYPES(lstmLayerCell) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } - DECLARE_SHAPE_FN(lstmLayerCell) { + const auto hasBiases = + B_ARG(0); // indicates whether biases array is provided - const auto hasBiases = B_ARG(0); // indicates whether biases array is provided + uint count = hasBiases ? 4 : 3; + const auto hI = INPUT_VARIABLE(count++); // initial output + const auto cI = INPUT_VARIABLE(count); // initial cell state - uint count = hasBiases ? 4 : 3; - const auto hI = INPUT_VARIABLE(count++); // initial output - const auto cI = INPUT_VARIABLE(count); // initial cell state - - return new ShapeList({hI->shapeInfo(), cI->shapeInfo()}); + return new ShapeList({hI->shapeInfo(), cI->shapeInfo()}); } ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(lstmLayerCellBp, 7, 5, false, 1, 3) { - - // equations (no peephole connections) - // it = σ(Wxi * xt + Wri * ht-1 + bi) - // ft = σ(Wxf * xt + Wrf * ht-1 + bf) - // c't = tanh(Wxc * xt + Wrc * ht-1 + bc) - // ct = ft ◦ ct-1 + it ◦ c't - // ot = σ(Wxo * xt + Wro * ht-1 + bo) - // ht = ot ◦ tanh(ct) - - // equations (peephole connections are present) - // it = σ(Wxi * xt + Wri * ht-1 + Wpi ◦ ct-1 + bi) - // ft = σ(Wxf * xt + Wrf * ht-1 + Wpf ◦ ct-1 + bf) - // c't = tanh(Wxc * xt + Wrc * ht-1 + bc) - // ct = clip(ft ◦ ct-1 + it ◦ c't) - // ot = σ(Wxo * xt + Wro * ht-1 + Wpo ◦ ct + bo) - // ht = ot ◦ tanh(ct) - - // notations: - // bS - batch size - // nIn - input size - // nOut - output size (hidden size) - - // INPUTS: - // input x: [bS, nIn] or [nIn] - // input weights Wx: [nIn, 4*nOut] - // recurrent weights Wr: [nOut, 4*nOut] - // initial (previous) output hI: [bS, nOut] or [nOut] - // initial (previous) cell state cI: [bS, nOut] or [nOut] - // gradient wrt output dLdh: [bS, nOut] or [nOut] - // gradient wrt cell state dLdc: [bS, nOut] or [nOut] - // peephole weights Wp (optional): [3*nOut] - // biases b (optional): [4*nOut] - - // OUTPUTS: - // gradient wrt x dLdx: [bS, nIn] or [nIn] - // gradient wrt Wx dLdWx: [nIn, 4*nOut] - // gradient wrt Wr dLdWr: [nOut, 4*nOut] - // gradient wrt hI dLdhI: [bS, nOut] or [nOut] - // gradient wrt cI dLdcI: [bS, nOut] or [nOut] - // gradient wrt b dLdb (optional): [4*nOut] - // gradient wrt Wp dLdWp (optional): [3*nOut] - - - // !!! dimension 4*nOut implies order it, ft, c't, ot - // !!! dimension 3*nOut implies order it, ft, ot - - // integer numbers corresponding to activations: 0=tanh, 1=relu, 2=sigmoid, 3=affine, 4=leaky relu, 5= thresholded relu, 6=scaled tanh, 7=hard sigmoid, 8=ELU, 9=softsign, 10=softplus - const auto gateAct = INT_ARG(0); // activation for input (i), forget (f) and output (o) gates - const auto cellAct = INT_ARG(1); // activation for cell state (c) - const auto outAct = INT_ARG(2); // activation for output (h) - - const auto hasBiases = B_ARG(0); // indicates whether biases array is provided - const auto hasPH = B_ARG(1); // indicates whether peephole connections are present - - const auto gateActHasAlpha = gateAct == 3 || gateAct == 4 || gateAct == 5 || gateAct == 6 || gateAct == 8; - const auto cellActHasAlpha = cellAct == 3 || cellAct == 4 || cellAct == 5 || cellAct == 6 || cellAct == 8; - const auto outActHasAlpha = outAct == 3 || outAct == 4 || outAct == 5 || outAct == 6 || outAct == 8; - const auto gateActHasBeta = gateAct == 3 || gateAct == 6; - const auto cellActHasBeta = cellAct == 3 || cellAct == 6; - const auto outActHasBeta = outAct == 3 || outAct == 6; - - uint count = 1; - const auto cellClip = T_ARG(0); // cell clipping value, if it = 0 then do not apply clipping - const auto gateAlpha = gateActHasAlpha ? T_ARG(count++) : 0; - const auto gateBeta = gateActHasBeta ? T_ARG(count++) : 0; - const auto cellAlpha = cellActHasAlpha ? T_ARG(count++) : 0; - const auto cellBeta = cellActHasBeta ? T_ARG(count++) : 0; - const auto outAlpha = outActHasAlpha ? T_ARG(count++) : 0; - const auto outBeta = outActHasBeta ? T_ARG(count++) : 0; - - count = 3; - const auto x = INPUT_VARIABLE(0); // input - const auto Wx = INPUT_VARIABLE(1); // input weights - const auto Wr = INPUT_VARIABLE(2); // recurrent weights - const auto b = hasBiases ? INPUT_VARIABLE(count++) : nullptr; // biases - const auto hI = INPUT_VARIABLE(count++); // initial output - const auto cI = INPUT_VARIABLE(count++); // initial cell state - const auto Wp = hasPH ? INPUT_VARIABLE(count++) : nullptr; // peephole weights - const auto dLdh = INPUT_VARIABLE(count); // gradient wrt output - - REQUIRE_TRUE(cellClip >= 0 , 0, "LSTM_LAYER_CELL_BP operation: cell clipping value should be nonnegative (>=0) !"); - - count = 3; - auto dLdx = OUTPUT_VARIABLE(0); - auto dLdWx = OUTPUT_VARIABLE(1); - auto dLdWr = OUTPUT_VARIABLE(2); - auto dLdb = hasBiases ? OUTPUT_VARIABLE(count++) : nullptr; - auto dLdhI = OUTPUT_VARIABLE(count++); - auto dLdcI = OUTPUT_VARIABLE(count++); - auto dLdWp = hasPH ? OUTPUT_VARIABLE(count) : nullptr; - - // evaluate dimensions - const Nd4jLong bS = x->rankOf() == 1 ? 0 : x->sizeAt(0); - const Nd4jLong nIn = x->sizeAt(-1); - const Nd4jLong nOut = Wx->sizeAt(-1) / 4; - - // inputs validations - // Wx validation - if(Wx->rankOf() != 2 || Wx->sizeAt(0) != nIn) - REQUIRE_TRUE(false, 0, "LSTM_LAYER_CELL_BP operation: wrong shape of input weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({nIn, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wx).c_str()); - // Wr validation - if(Wr->rankOf() != 2 || Wr->sizeAt(0) != nOut || Wr->sizeAt(1) != 4*nOut) - REQUIRE_TRUE(false, 0, "LSTM_LAYER_CELL_BP operation: wrong shape of recurrent weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({nOut, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wr).c_str()); - // initial output/cell validation - std::vector exphIcIShape = x->rankOf() == 1 ? std::vector{nOut} : std::vector{bS, nOut}; - REQUIRE_TRUE(hI->isSameShape(exphIcIShape), 0, "LSTM_LAYER_CELL_BP operation: wrong shape of initial output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(exphIcIShape).c_str(), ShapeUtils::shapeAsString(hI).c_str()); - REQUIRE_TRUE(cI->isSameShape(exphIcIShape), 0, "LSTM_LAYER_CELL_BP operation: wrong shape of initial cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(exphIcIShape).c_str(), ShapeUtils::shapeAsString(cI).c_str()); - REQUIRE_TRUE(dLdh->isSameShape(exphIcIShape), 0, "LSTM_LAYER_CELL_BP operation: wrong shape of dLdh gradient, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(exphIcIShape).c_str(), ShapeUtils::shapeAsString(dLdh).c_str()); - // biases validation - if(b != nullptr && (b->rankOf() != 1 || b->sizeAt(0) != 4*nOut)) - REQUIRE_TRUE(false, 0, "LSTM_LAYER_CELL_BP operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({4*nOut}).c_str(), ShapeUtils::shapeAsString(b).c_str()); - if(dLdb != nullptr && (dLdb->rankOf() != 1 || dLdb->sizeAt(0) != 4*nOut)) - REQUIRE_TRUE(false, 0, "LSTM_LAYER_CELL_BP operation: wrong shape of dLdb gradient, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({4*nOut}).c_str(), ShapeUtils::shapeAsString(dLdb).c_str()); - // peephole weights validation - if(Wp != nullptr && (Wp->rankOf() != 1 || Wp->sizeAt(0) != 3*nOut)) - REQUIRE_TRUE(false, 0, "LSTM_LAYER_CELL_BP operation: wrong shape of peephole weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({3*nOut}).c_str(), ShapeUtils::shapeAsString(Wp).c_str()); - if(dLdWp != nullptr && (dLdWp->rankOf() != 1 || dLdWp->sizeAt(0) != 3*nOut)) - REQUIRE_TRUE(false, 0, "LSTM_LAYER_CELL_BP operation: wrong shape of dLdWp gradient, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({3*nOut}).c_str(), ShapeUtils::shapeAsString(dLdWp).c_str()); - - - std::vector params = {static_cast(0)/*ignore*/, static_cast(0)/*ignore*/, static_cast(cellClip), - static_cast(gateAct), static_cast(gateAlpha), static_cast(gateBeta), - static_cast(cellAct), static_cast(cellAlpha), static_cast(cellBeta), - static_cast(outAct), static_cast(outAlpha), static_cast(outBeta)}; - - std::vector zShape = x->rankOf() == 1 ? std::vector({4*nOut}) : std::vector({bS, 4*nOut}); - - NDArray z(x->ordering(), zShape, x->dataType(), block.launchContext()); - NDArray a = z.ulike(); - NDArray h = cI->ulike(); - NDArray c = cI->ulike(); - - helpers::lstmLayerCell(x,Wx, Wr, b, hI, cI, Wp, params, &z, &a, &h, &c); - - helpers::lstmLayerCellBp(x, Wx, Wr, b, hI, cI, Wp, dLdh, nullptr, nullptr, &z, &a, &c, params, dLdx, dLdWx, dLdWr, dLdhI, dLdcI, dLdb, dLdWp); - - return Status::OK(); + // equations (no peephole connections) + // it = σ(Wxi * xt + Wri * ht-1 + bi) + // ft = σ(Wxf * xt + Wrf * ht-1 + bf) + // c't = tanh(Wxc * xt + Wrc * ht-1 + bc) + // ct = ft ◦ ct-1 + it ◦ c't + // ot = σ(Wxo * xt + Wro * ht-1 + bo) + // ht = ot ◦ tanh(ct) + + // equations (peephole connections are present) + // it = σ(Wxi * xt + Wri * ht-1 + Wpi ◦ ct-1 + bi) + // ft = σ(Wxf * xt + Wrf * ht-1 + Wpf ◦ ct-1 + bf) + // c't = tanh(Wxc * xt + Wrc * ht-1 + bc) + // ct = clip(ft ◦ ct-1 + it ◦ c't) + // ot = σ(Wxo * xt + Wro * ht-1 + Wpo ◦ ct + bo) + // ht = ot ◦ tanh(ct) + + // notations: + // bS - batch size + // nIn - input size + // nOut - output size (hidden size) + + // INPUTS: + // input x: [bS, nIn] or [nIn] + // input weights Wx: [nIn, 4*nOut] + // recurrent weights Wr: [nOut, 4*nOut] + // initial (previous) output hI: [bS, nOut] or [nOut] + // initial (previous) cell state cI: [bS, nOut] or [nOut] + // gradient wrt output dLdh: [bS, nOut] or [nOut] + // gradient wrt cell state dLdc: [bS, nOut] or [nOut] + // peephole weights Wp (optional): [3*nOut] + // biases b (optional): [4*nOut] + + // OUTPUTS: + // gradient wrt x dLdx: [bS, nIn] or [nIn] + // gradient wrt Wx dLdWx: [nIn, 4*nOut] + // gradient wrt Wr dLdWr: [nOut, 4*nOut] + // gradient wrt hI dLdhI: [bS, nOut] or [nOut] + // gradient wrt cI dLdcI: [bS, nOut] or [nOut] + // gradient wrt b dLdb (optional): [4*nOut] + // gradient wrt Wp dLdWp (optional): [3*nOut] + + // !!! dimension 4*nOut implies order it, ft, c't, ot + // !!! dimension 3*nOut implies order it, ft, ot + + // integer numbers corresponding to activations: 0=tanh, 1=relu, 2=sigmoid, + // 3=affine, 4=leaky relu, 5= thresholded relu, 6=scaled tanh, 7=hard sigmoid, + // 8=ELU, 9=softsign, 10=softplus + const auto gateAct = + INT_ARG(0); // activation for input (i), forget (f) and output (o) gates + const auto cellAct = INT_ARG(1); // activation for cell state (c) + const auto outAct = INT_ARG(2); // activation for output (h) + + const auto hasBiases = + B_ARG(0); // indicates whether biases array is provided + const auto hasPH = + B_ARG(1); // indicates whether peephole connections are present + + const auto gateActHasAlpha = gateAct == 3 || gateAct == 4 || gateAct == 5 || + gateAct == 6 || gateAct == 8; + const auto cellActHasAlpha = cellAct == 3 || cellAct == 4 || cellAct == 5 || + cellAct == 6 || cellAct == 8; + const auto outActHasAlpha = + outAct == 3 || outAct == 4 || outAct == 5 || outAct == 6 || outAct == 8; + const auto gateActHasBeta = gateAct == 3 || gateAct == 6; + const auto cellActHasBeta = cellAct == 3 || cellAct == 6; + const auto outActHasBeta = outAct == 3 || outAct == 6; + + uint count = 1; + const auto cellClip = + T_ARG(0); // cell clipping value, if it = 0 then do not apply clipping + const auto gateAlpha = gateActHasAlpha ? T_ARG(count++) : 0; + const auto gateBeta = gateActHasBeta ? T_ARG(count++) : 0; + const auto cellAlpha = cellActHasAlpha ? T_ARG(count++) : 0; + const auto cellBeta = cellActHasBeta ? T_ARG(count++) : 0; + const auto outAlpha = outActHasAlpha ? T_ARG(count++) : 0; + const auto outBeta = outActHasBeta ? T_ARG(count++) : 0; + + count = 3; + const auto x = INPUT_VARIABLE(0); // input + const auto Wx = INPUT_VARIABLE(1); // input weights + const auto Wr = INPUT_VARIABLE(2); // recurrent weights + const auto b = hasBiases ? INPUT_VARIABLE(count++) : nullptr; // biases + const auto hI = INPUT_VARIABLE(count++); // initial output + const auto cI = INPUT_VARIABLE(count++); // initial cell state + const auto Wp = + hasPH ? INPUT_VARIABLE(count++) : nullptr; // peephole weights + const auto dLdh = INPUT_VARIABLE(count); // gradient wrt output + + REQUIRE_TRUE(cellClip >= 0, 0, + "LSTM_LAYER_CELL_BP operation: cell clipping value should be " + "nonnegative (>=0) !"); + + count = 3; + auto dLdx = OUTPUT_VARIABLE(0); + auto dLdWx = OUTPUT_VARIABLE(1); + auto dLdWr = OUTPUT_VARIABLE(2); + auto dLdb = hasBiases ? OUTPUT_VARIABLE(count++) : nullptr; + auto dLdhI = OUTPUT_VARIABLE(count++); + auto dLdcI = OUTPUT_VARIABLE(count++); + auto dLdWp = hasPH ? OUTPUT_VARIABLE(count) : nullptr; + + // evaluate dimensions + const Nd4jLong bS = x->rankOf() == 1 ? 0 : x->sizeAt(0); + const Nd4jLong nIn = x->sizeAt(-1); + const Nd4jLong nOut = Wx->sizeAt(-1) / 4; + + // inputs validations + // Wx validation + if (Wx->rankOf() != 2 || Wx->sizeAt(0) != nIn) + REQUIRE_TRUE(false, 0, + "LSTM_LAYER_CELL_BP operation: wrong shape of input weights, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString({nIn, 4 * nOut}).c_str(), + ShapeUtils::shapeAsString(Wx).c_str()); + // Wr validation + if (Wr->rankOf() != 2 || Wr->sizeAt(0) != nOut || Wr->sizeAt(1) != 4 * nOut) + REQUIRE_TRUE(false, 0, + "LSTM_LAYER_CELL_BP operation: wrong shape of recurrent " + "weights, expected is %s, but got %s instead !", + ShapeUtils::shapeAsString({nOut, 4 * nOut}).c_str(), + ShapeUtils::shapeAsString(Wr).c_str()); + // initial output/cell validation + std::vector exphIcIShape = x->rankOf() == 1 + ? std::vector{nOut} + : std::vector{bS, nOut}; + REQUIRE_TRUE(hI->isSameShape(exphIcIShape), 0, + "LSTM_LAYER_CELL_BP operation: wrong shape of initial output, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(exphIcIShape).c_str(), + ShapeUtils::shapeAsString(hI).c_str()); + REQUIRE_TRUE(cI->isSameShape(exphIcIShape), 0, + "LSTM_LAYER_CELL_BP operation: wrong shape of initial cell " + "state, expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(exphIcIShape).c_str(), + ShapeUtils::shapeAsString(cI).c_str()); + REQUIRE_TRUE(dLdh->isSameShape(exphIcIShape), 0, + "LSTM_LAYER_CELL_BP operation: wrong shape of dLdh gradient, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(exphIcIShape).c_str(), + ShapeUtils::shapeAsString(dLdh).c_str()); + // biases validation + if (b != nullptr && (b->rankOf() != 1 || b->sizeAt(0) != 4 * nOut)) + REQUIRE_TRUE(false, 0, + "LSTM_LAYER_CELL_BP operation: wrong shape of biases, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString({4 * nOut}).c_str(), + ShapeUtils::shapeAsString(b).c_str()); + if (dLdb != nullptr && (dLdb->rankOf() != 1 || dLdb->sizeAt(0) != 4 * nOut)) + REQUIRE_TRUE(false, 0, + "LSTM_LAYER_CELL_BP operation: wrong shape of dLdb gradient, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString({4 * nOut}).c_str(), + ShapeUtils::shapeAsString(dLdb).c_str()); + // peephole weights validation + if (Wp != nullptr && (Wp->rankOf() != 1 || Wp->sizeAt(0) != 3 * nOut)) + REQUIRE_TRUE(false, 0, + "LSTM_LAYER_CELL_BP operation: wrong shape of peephole " + "weights, expected is %s, but got %s instead !", + ShapeUtils::shapeAsString({3 * nOut}).c_str(), + ShapeUtils::shapeAsString(Wp).c_str()); + if (dLdWp != nullptr && + (dLdWp->rankOf() != 1 || dLdWp->sizeAt(0) != 3 * nOut)) + REQUIRE_TRUE(false, 0, + "LSTM_LAYER_CELL_BP operation: wrong shape of dLdWp gradient, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString({3 * nOut}).c_str(), + ShapeUtils::shapeAsString(dLdWp).c_str()); + + std::vector params = { + static_cast(0) /*ignore*/, static_cast(0) /*ignore*/, + static_cast(cellClip), static_cast(gateAct), + static_cast(gateAlpha), static_cast(gateBeta), + static_cast(cellAct), static_cast(cellAlpha), + static_cast(cellBeta), static_cast(outAct), + static_cast(outAlpha), static_cast(outBeta)}; + + std::vector zShape = x->rankOf() == 1 + ? std::vector({4 * nOut}) + : std::vector({bS, 4 * nOut}); + + NDArray z(x->ordering(), zShape, x->dataType(), block.launchContext()); + NDArray a = z.ulike(); + NDArray h = cI->ulike(); + NDArray c = cI->ulike(); + + helpers::lstmLayerCell(x, Wx, Wr, b, hI, cI, Wp, params, &z, &a, &h, &c); + + helpers::lstmLayerCellBp(x, Wx, Wr, b, hI, cI, Wp, dLdh, nullptr, nullptr, &z, + &a, &c, params, dLdx, dLdWx, dLdWr, dLdhI, dLdcI, + dLdb, dLdWp); + + return Status::OK(); } DECLARE_TYPES(lstmLayerCellBp) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } - DECLARE_SHAPE_FN(lstmLayerCellBp) { + const auto hasBiases = + B_ARG(0); // indicates whether biases array is provided + const auto hasPH = + B_ARG(1); // indicates whether peephole connections are present - const auto hasBiases = B_ARG(0); // indicates whether biases array is provided - const auto hasPH = B_ARG(1); // indicates whether peephole connections are present - - uint count = 3; - const auto x = INPUT_VARIABLE(0); // input - const auto Wx = INPUT_VARIABLE(1); // input weights - const auto Wr = INPUT_VARIABLE(2); // recurrent weights - const auto b = hasBiases ? INPUT_VARIABLE(count++) : nullptr; // biases - const auto hI = INPUT_VARIABLE(count++); // initial output - const auto cI = INPUT_VARIABLE(count++); // initial cell state - const auto Wp = hasPH ? INPUT_VARIABLE(count) : nullptr; // peephole weights + uint count = 3; + const auto x = INPUT_VARIABLE(0); // input + const auto Wx = INPUT_VARIABLE(1); // input weights + const auto Wr = INPUT_VARIABLE(2); // recurrent weights + const auto b = hasBiases ? INPUT_VARIABLE(count++) : nullptr; // biases + const auto hI = INPUT_VARIABLE(count++); // initial output + const auto cI = INPUT_VARIABLE(count++); // initial cell state + const auto Wp = hasPH ? INPUT_VARIABLE(count) : nullptr; // peephole weights - auto shapes = SHAPELIST(x->shapeInfo(), Wx->shapeInfo(), Wr->shapeInfo()); + auto shapes = SHAPELIST(x->shapeInfo(), Wx->shapeInfo(), Wr->shapeInfo()); - if(b != nullptr) - shapes->push_back(b->shapeInfo()); + if (b != nullptr) shapes->push_back(b->shapeInfo()); - shapes->push_back(hI->shapeInfo()); - shapes->push_back(cI->shapeInfo()); + shapes->push_back(hI->shapeInfo()); + shapes->push_back(cI->shapeInfo()); - if(Wp != nullptr) - shapes->push_back(Wp->shapeInfo()); + if (Wp != nullptr) shapes->push_back(Wp->shapeInfo()); - return shapes; + return shapes; } -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/nn/recurrent/sru.cpp b/libnd4j/include/ops/declarable/generic/nn/recurrent/sru.cpp index ba4e3d52f964..bb66fb60726e 100644 --- a/libnd4j/include/ops/declarable/generic/nn/recurrent/sru.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/recurrent/sru.cpp @@ -15,7 +15,8 @@ ******************************************************************************/ // -// implementation of operations for Simple Recurrent Unit: arXiv:1709.02755v2 [cs.CL] 12 Sep 2017 +// implementation of operations for Simple Recurrent Unit: arXiv:1709.02755v2 +// [cs.CL] 12 Sep 2017 // //@author Yurii Shyrma // @@ -23,563 +24,863 @@ #include #if NOT_EXCLUDED(OP_sru) -#include -#include #include #include +#include +#include namespace sd { -namespace ops { +namespace ops { ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(sru, 5, 2, false, 0, 0) { - auto x = INPUT_VARIABLE(0); // X, input 3d tensor [bS x inSize x time], time - number of time steps, bS - batch size, inSize - number of features - auto w = INPUT_VARIABLE(1); // W, 2d tensor of weights [3*inSize x inSize] - auto b = INPUT_VARIABLE(2); // B, row of biases with twice length [2*inSize] - auto c0 = INPUT_VARIABLE(3); // C_{0}, 2d tensor of initial state [bS x inSize] at time t=0 - auto mask = block.width() > 4 ? INPUT_VARIABLE(4) : nullptr; // optional, 2d tensor of dropout mask [bS x inSize] - - auto h = OUTPUT_VARIABLE(0); // cell outputs, [bS x inSize x time] - auto c = OUTPUT_VARIABLE(1); // cell states, [bS x inSize x time] - - const int rank = x->rankOf(); // = 3 - const auto bS = x->sizeAt(0); - const auto inSize = x->sizeAt(1); - const auto time = x->sizeAt(2); - - // input shapes validation - REQUIRE_TRUE(w->rankOf() == rank-1, 0, "SRU operation: wrong rank of weights array, expected is %i, but got %i instead !", rank-1, w->rankOf()); - REQUIRE_TRUE(b->rankOf() == 1, 0, "SRU operation: wrong rank of biases array, expected is %i, but got %i instead !", 1, b->rankOf()); - REQUIRE_TRUE(c0->rankOf() == rank-1, 0, "SRU operation: wrong rank of initial state array, expected is %i, but got %i instead !", rank-1, c0->rankOf()); - if(mask) - REQUIRE_TRUE(mask->rankOf() == rank-1, 0, "SRU operation: wrong rank of mask array, expected is %i, but got %i instead !", rank-1, mask->rankOf()); - - const std::vector wCorrectShape = {3*inSize, inSize}; - const std::vector bCorrectShape = {2*inSize}; - const std::vector c0CorrectShape = {bS, inSize}; - - REQUIRE_TRUE(w->isSameShape(wCorrectShape), 0, "SRU operation: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(wCorrectShape).c_str(), ShapeUtils::shapeAsString(w).c_str()); - REQUIRE_TRUE(b->isSameShape(bCorrectShape), 0, "SRU operation: wrong shape of biases array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(bCorrectShape).c_str(), ShapeUtils::shapeAsString(b).c_str()); - REQUIRE_TRUE(c0->isSameShape(c0CorrectShape), 0, "SRU operation: wrong shape of initial state array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(c0CorrectShape).c_str(), ShapeUtils::shapeAsString(c0).c_str()); - if(mask) - REQUIRE_TRUE(mask->isSameShape(c0CorrectShape), 0, "SRU operation: wrong shape of mask array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(c0CorrectShape).c_str(), ShapeUtils::shapeAsString(mask).c_str()); - - // xm = x * mask - auto xm = x; - if(mask) { - xm = new NDArray(x->shapeInfo(), true, block.launchContext()); - x->applyBroadcast(broadcast::Multiply, {0, 1}, *mask, *xm); - } - - // time loop - helpers::sruTimeLoop(block.launchContext(), xm, c0, w, b, h, c); - - if(mask) - delete xm; - - return Status::OK(); + auto x = INPUT_VARIABLE( + 0); // X, input 3d tensor [bS x inSize x time], time - number of time + // steps, bS - batch size, inSize - number of features + auto w = INPUT_VARIABLE(1); // W, 2d tensor of weights [3*inSize x inSize] + auto b = INPUT_VARIABLE(2); // B, row of biases with twice length [2*inSize] + auto c0 = INPUT_VARIABLE( + 3); // C_{0}, 2d tensor of initial state [bS x inSize] at time t=0 + auto mask = + block.width() > 4 + ? INPUT_VARIABLE(4) + : nullptr; // optional, 2d tensor of dropout mask [bS x inSize] + + auto h = OUTPUT_VARIABLE(0); // cell outputs, [bS x inSize x time] + auto c = OUTPUT_VARIABLE(1); // cell states, [bS x inSize x time] + + const int rank = x->rankOf(); // = 3 + const auto bS = x->sizeAt(0); + const auto inSize = x->sizeAt(1); + const auto time = x->sizeAt(2); + + // input shapes validation + REQUIRE_TRUE(w->rankOf() == rank - 1, 0, + "SRU operation: wrong rank of weights array, expected is %i, " + "but got %i instead !", + rank - 1, w->rankOf()); + REQUIRE_TRUE(b->rankOf() == 1, 0, + "SRU operation: wrong rank of biases array, expected is %i, " + "but got %i instead !", + 1, b->rankOf()); + REQUIRE_TRUE(c0->rankOf() == rank - 1, 0, + "SRU operation: wrong rank of initial state array, expected is " + "%i, but got %i instead !", + rank - 1, c0->rankOf()); + if (mask) + REQUIRE_TRUE(mask->rankOf() == rank - 1, 0, + "SRU operation: wrong rank of mask array, expected is %i, but " + "got %i instead !", + rank - 1, mask->rankOf()); + + const std::vector wCorrectShape = {3 * inSize, inSize}; + const std::vector bCorrectShape = {2 * inSize}; + const std::vector c0CorrectShape = {bS, inSize}; + + REQUIRE_TRUE(w->isSameShape(wCorrectShape), 0, + "SRU operation: wrong shape of weights array, expected is %s, " + "but got %s instead !", + ShapeUtils::shapeAsString(wCorrectShape).c_str(), + ShapeUtils::shapeAsString(w).c_str()); + REQUIRE_TRUE(b->isSameShape(bCorrectShape), 0, + "SRU operation: wrong shape of biases array, expected is %s, " + "but got %s instead !", + ShapeUtils::shapeAsString(bCorrectShape).c_str(), + ShapeUtils::shapeAsString(b).c_str()); + REQUIRE_TRUE(c0->isSameShape(c0CorrectShape), 0, + "SRU operation: wrong shape of initial state array, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString(c0CorrectShape).c_str(), + ShapeUtils::shapeAsString(c0).c_str()); + if (mask) + REQUIRE_TRUE(mask->isSameShape(c0CorrectShape), 0, + "SRU operation: wrong shape of mask array, expected is %s, " + "but got %s instead !", + ShapeUtils::shapeAsString(c0CorrectShape).c_str(), + ShapeUtils::shapeAsString(mask).c_str()); + + // xm = x * mask + auto xm = x; + if (mask) { + xm = new NDArray(x->shapeInfo(), true, block.launchContext()); + x->applyBroadcast(broadcast::Multiply, {0, 1}, *mask, *xm); + } + + // time loop + helpers::sruTimeLoop(block.launchContext(), xm, c0, w, b, h, c); + + if (mask) delete xm; + + return Status::OK(); } - DECLARE_TYPES(sru) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } +DECLARE_TYPES(sru) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} DECLARE_SHAPE_FN(sru) { - - auto xShapeInfo = inputShape->at(0); // X, input 3d tensor [bS x inSize x time], time - number of time steps, bS - batch size, inSize - number of features - auto wShapeInfo = inputShape->at(1); // W, 2d tensor of weights [3*inSize x inSize] - auto bShapeInfo = inputShape->at(2); // B, row of biases with twice length [2*inSize] - auto c0ShapeInfo = inputShape->at(3); // C_{0}, 2d tensor of initial state [bS x inSize] at time t=0 - auto maskShapeInfo = block.width() > 4 ? inputShape->at(4) : nullptr; // optional, 2d tensor of dropout mask [bS x inSize] - - const int rank = xShapeInfo[0]; // = 3 - const int bS = xShapeInfo[1]; - const int inSize = xShapeInfo[2]; - const int time = xShapeInfo[3]; - - // input shapes validation - REQUIRE_TRUE(wShapeInfo[0] == rank-1, 0, "SRU operation: wrong rank of weights array, expected is %i, but got %i instead !", rank-1, wShapeInfo[0]); - REQUIRE_TRUE(bShapeInfo[0] == 1, 0, "SRU operation: wrong rank of biases array, expected is %i, but got %i instead !", 1, bShapeInfo[0]); - REQUIRE_TRUE(c0ShapeInfo[0] == rank-1, 0, "SRU operation: wrong rank of initial state array, expected is %i, but got %i instead !", rank-1, c0ShapeInfo[0]); - if(maskShapeInfo) - REQUIRE_TRUE(maskShapeInfo[0] == rank-1, 0, "SRU operation: wrong rank of mask array, expected is %i, but got %i instead !", rank-1, maskShapeInfo[0]); - - const std::vector wCorrectShape = {3*inSize, inSize}; - const std::vector bCorrectShape = {2*inSize}; - const std::vector c0CorrectShape = {bS, inSize}; - - REQUIRE_TRUE(ShapeUtils::areShapesEqual(wShapeInfo, wCorrectShape), 0, "SRU operation: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(wCorrectShape).c_str(), ShapeUtils::shapeAsString(wShapeInfo).c_str()); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(bShapeInfo, bCorrectShape), 0, "SRU operation: wrong shape of biases array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(bCorrectShape).c_str(), ShapeUtils::shapeAsString(bShapeInfo).c_str()); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(c0ShapeInfo, c0CorrectShape), 0, "SRU operation: wrong shape of initial state array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(c0CorrectShape).c_str(), ShapeUtils::shapeAsString(c0ShapeInfo).c_str()); - if(maskShapeInfo) - REQUIRE_TRUE(ShapeUtils::areShapesEqual(maskShapeInfo, c0CorrectShape), 0, "SRU operation: wrong shape of mask array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(c0CorrectShape).c_str(), ShapeUtils::shapeAsString(maskShapeInfo).c_str()); - - Nd4jLong* newShapeInfo1 = nullptr; - ALLOCATE(newShapeInfo1, block.getWorkspace(), shape::shapeInfoLength(rank), Nd4jLong); // [bS x inSize x time] - - newShapeInfo1[0] = rank; - newShapeInfo1[1] = bS; - newShapeInfo1[2] = inSize; - newShapeInfo1[3] = time; - - ShapeUtils::updateStridesAndType(newShapeInfo1, xShapeInfo, shape::order(xShapeInfo)); - ShapeDescriptor descriptor(newShapeInfo1); - RELEASE(newShapeInfo1, block.getWorkspace()); - auto result = ConstantShapeHelper::getInstance().createShapeInfo(descriptor); - return SHAPELIST(result, result); + auto xShapeInfo = inputShape->at( + 0); // X, input 3d tensor [bS x inSize x time], time - number of time + // steps, bS - batch size, inSize - number of features + auto wShapeInfo = + inputShape->at(1); // W, 2d tensor of weights [3*inSize x inSize] + auto bShapeInfo = + inputShape->at(2); // B, row of biases with twice length [2*inSize] + auto c0ShapeInfo = inputShape->at( + 3); // C_{0}, 2d tensor of initial state [bS x inSize] at time t=0 + auto maskShapeInfo = + block.width() > 4 + ? inputShape->at(4) + : nullptr; // optional, 2d tensor of dropout mask [bS x inSize] + + const int rank = xShapeInfo[0]; // = 3 + const int bS = xShapeInfo[1]; + const int inSize = xShapeInfo[2]; + const int time = xShapeInfo[3]; + + // input shapes validation + REQUIRE_TRUE(wShapeInfo[0] == rank - 1, 0, + "SRU operation: wrong rank of weights array, expected is %i, " + "but got %i instead !", + rank - 1, wShapeInfo[0]); + REQUIRE_TRUE(bShapeInfo[0] == 1, 0, + "SRU operation: wrong rank of biases array, expected is %i, " + "but got %i instead !", + 1, bShapeInfo[0]); + REQUIRE_TRUE(c0ShapeInfo[0] == rank - 1, 0, + "SRU operation: wrong rank of initial state array, expected is " + "%i, but got %i instead !", + rank - 1, c0ShapeInfo[0]); + if (maskShapeInfo) + REQUIRE_TRUE(maskShapeInfo[0] == rank - 1, 0, + "SRU operation: wrong rank of mask array, expected is %i, but " + "got %i instead !", + rank - 1, maskShapeInfo[0]); + + const std::vector wCorrectShape = {3 * inSize, inSize}; + const std::vector bCorrectShape = {2 * inSize}; + const std::vector c0CorrectShape = {bS, inSize}; + + REQUIRE_TRUE(ShapeUtils::areShapesEqual(wShapeInfo, wCorrectShape), 0, + "SRU operation: wrong shape of weights array, expected is %s, " + "but got %s instead !", + ShapeUtils::shapeAsString(wCorrectShape).c_str(), + ShapeUtils::shapeAsString(wShapeInfo).c_str()); + REQUIRE_TRUE(ShapeUtils::areShapesEqual(bShapeInfo, bCorrectShape), 0, + "SRU operation: wrong shape of biases array, expected is %s, " + "but got %s instead !", + ShapeUtils::shapeAsString(bCorrectShape).c_str(), + ShapeUtils::shapeAsString(bShapeInfo).c_str()); + REQUIRE_TRUE(ShapeUtils::areShapesEqual(c0ShapeInfo, c0CorrectShape), 0, + "SRU operation: wrong shape of initial state array, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString(c0CorrectShape).c_str(), + ShapeUtils::shapeAsString(c0ShapeInfo).c_str()); + if (maskShapeInfo) + REQUIRE_TRUE(ShapeUtils::areShapesEqual(maskShapeInfo, c0CorrectShape), 0, + "SRU operation: wrong shape of mask array, expected is %s, " + "but got %s instead !", + ShapeUtils::shapeAsString(c0CorrectShape).c_str(), + ShapeUtils::shapeAsString(maskShapeInfo).c_str()); + + Nd4jLong* newShapeInfo1 = nullptr; + ALLOCATE(newShapeInfo1, block.workspace(), shape::shapeInfoLength(rank), + Nd4jLong); // [bS x inSize x time] + + newShapeInfo1[0] = rank; + newShapeInfo1[1] = bS; + newShapeInfo1[2] = inSize; + newShapeInfo1[3] = time; + + ShapeUtils::updateStridesAndType(newShapeInfo1, xShapeInfo, + shape::order(xShapeInfo)); + ShapeDescriptor descriptor(newShapeInfo1); + RELEASE(newShapeInfo1, block.workspace()); + auto result = ConstantShapeHelper::getInstance().createShapeInfo(descriptor); + return SHAPELIST(result, result); } ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(sru_bp, 8, 4, true, 0, 0) { - auto x = INPUT_VARIABLE(0); // X, input 3d tensor [bS x K x N], N - number of time steps, bS - batch size, K - number of features - auto w = INPUT_VARIABLE(1); // W, 2d tensor of weights [3K x K] - auto b = INPUT_VARIABLE(2); // B, row of biases with twice length [1 x 2*K] - auto c0 = INPUT_VARIABLE(3); // C_{0}, 2d tensor of initial state [bS x K] at time t=0 - auto c = INPUT_VARIABLE(4); // C, [bS x K x N] - auto inGradCt = INPUT_VARIABLE(5); // [bS x K] - auto inGradH = INPUT_VARIABLE(6); // [bS x K x N] - NDArray* mask = nullptr; // optional, 2d tensor of dropout mask [bS x K] - - bool applyMask = false; - if (block.width() > 7) { - mask = INPUT_VARIABLE(7); - applyMask = true; - } - - auto gradX = OUTPUT_VARIABLE(0); // [bS x K x N] - auto gradW = OUTPUT_VARIABLE(1); // [bS x 3K x K] - auto gradB = OUTPUT_VARIABLE(2); // [1 x 2K] - auto gradInit = OUTPUT_VARIABLE(3); // [bS x K] - - const int bS = x->shapeOf()[0]; - const int K = x->shapeOf()[1]; - const int N = x->shapeOf()[2]; // N - number of time steps - - auto gradBias = NDArrayFactory::create_(x->ordering(), {bS, 2*K, N}, gradX->dataType(), block.launchContext()); - auto gradU = NDArrayFactory::create_(x->ordering(), {bS, 3*K, N}, gradX->dataType(), block.launchContext()); - auto gradHX = NDArrayFactory::create_(x->ordering(), {bS, K, N}, gradX->dataType(), block.launchContext()); - auto gct = NDArrayFactory::create_(c->ordering(), {bS, K}, gradX->dataType(), block.launchContext()); - auto gradTanh = NDArrayFactory::create_(c->ordering(), {bS, K}, gradX->dataType(), block.launchContext()); - auto gradCt = NDArrayFactory::create_(c->ordering(), {bS, K}, gradX->dataType(), block.launchContext()); - auto ftMinus = NDArrayFactory::create_(c->ordering(), {bS, K}, gradX->dataType(), block.launchContext()); - auto rtMinus = NDArrayFactory::create_(c->ordering(), {bS, K}, gradX->dataType(), block.launchContext()); - auto temp1 = NDArrayFactory::create_(c->ordering(), {bS, K}, gradX->dataType(), block.launchContext()); - auto temp2 = NDArrayFactory::create_(c->ordering(), {bS, K}, gradX->dataType(), block.launchContext()); - - // x = x * mask - if(applyMask) - x->applyBroadcast(broadcast::Multiply, {0, 1}, *mask, *x); // apply mask - // multiplication matrix wi = matmul(w,x), U = WX - auto wi = MmulHelper::mmul(w, x, nullptr, 1., 0.); // U [bS x 3K x N] - - auto wiZ = (*wi)({0,0, 0,K, 0,0}, true); // [bS x K x N] - auto wiF = (*wi)({0,0, K,2*K, 0,0}, true); // forget gate [bS x K x N] - auto wiR = (*wi)({0,0, 2*K,3*K, 0,0}, true); // reset gate [bS x K x N] - auto bF = (*b) ({0,0, 0,K }, true); // biases for forget gate [1 x K] - auto bR = (*b) ({0,0, K,2*K}, true); // biases for reset gate [1 x K] - auto gradBF = (*gradBias)({0,0, 0,K, 0,0}, true); // [bS x K x N] - auto gradBR = (*gradBias)({0,0, K,2*K, 0,0}, true); // [bS x K x N] - auto gradUZ = (*gradU) ({0,0, 0,K, 0,0}, true ); // [bS x K x N] - auto gradUF = (*gradU) ({0,0, K,2*K, 0,0}, true ); // [bS x K x N] - auto gradUR = (*gradU) ({0,0, 2*K,3*K, 0,0}, true ); // [bS x K x N] - - NDArray* ct_1 = nullptr; - - std::vector idx = {0,0, 0,0, 0,0}; - - for (int t = N-1; t >=0 ; --t) { - // initialization - idx[4] = t; - idx[5] = t + 1; - auto xt = (*x)(idx); // [bS x K x N] -> [bS x K x 1] -> [bS x K] - auto zt = wiZ(idx); // [bS x K x N] -> [bS x K x 1] -> [bS x K] - auto ft = wiF(idx); // [bS x K x N] -> [bS x K x 1] -> [bS x K] - auto rt = wiR(idx); // [bS x K x N] -> [bS x K x 1] -> [bS x K] - auto ct = (*c)(idx); // [bS x K x N] -> [bS x K x 1] -> [bS x K] - auto inGradHt = (*inGradH)(idx); // [bS x K x N] -> [bS x K x 1] -> [bS x K] - auto gradBRt = gradBR(idx); // [bS x K x N] -> [bS x K x 1] -> [bS x K] - auto gradBFt = gradBF(idx); // [bS x K x N] -> [bS x K x 1] -> [bS x K] - auto gradHXt = (*gradHX)(idx); // [bS x K x N] -> [bS x K x 1] -> [bS x K] - auto gradUZt = gradUZ(idx); // [bS x K x N] -> [bS x K x 1] -> [bS x K] - auto gradUFt = gradUF(idx); // [bS x K x N] -> [bS x K x 1] -> [bS x K] - auto gradURt = gradUR(idx); // [bS x K x N] -> [bS x K x 1] -> [bS x K] - - if(t != 0) { - idx[4] = t - 1; - idx[5] = t; - ct_1 = new NDArray((*c)(idx)); // previous c_{t-1} - } - else - ct_1 = c0; - - ///////////////// forward - // ft = sigmoid(ft + bf), rt = sigmoid(rt + bR) - ft.addRowVector(bF, ft); - rt.addRowVector(bR, rt); - ft.applyTransform(transform::Sigmoid, ft); - rt.applyTransform(transform::Sigmoid, rt); - - // TODO T val = (activation_type == 1) ? tanh(cur) : ((activation_type == 2) ? reluf(cur) : cur ); - ct.applyTransform(transform::Tanh, *gct); - // ftMinus = 1-ft, rtMinus = 1-rt - ft.applyTransform(transform::OneMinus, *ftMinus); - rt.applyTransform(transform::OneMinus, *rtMinus); - - ///////////////// backward - // bR, *grad_brt_ptr = inGradHt * (g_ct - xt) * (1.0f - rt) * rt; - gct->applyPairwiseTransform(pairwise::Subtract, xt, *temp1); // temp1 = (g_ct - xt) - rtMinus->applyPairwiseTransform(pairwise::Multiply, rt, *temp2); // temp2 = (1.0f - rt) * rt; - temp1->applyPairwiseTransform(pairwise::Multiply, *temp2); // temp1 = (g_ct - xt) * (1.0f - rt) * rt; - inGradHt.applyPairwiseTransform(pairwise::Multiply, *temp1, gradBRt); // = inGradHt * (g_ct - xt) * (1.0f - rt) * rt; - - // bF, TODO - tanh - // gradTanh = (1.0f - g_ct * g_ct); - gct->applyPairwiseTransform(pairwise::Multiply, *gct, *gradTanh); // gradTanh = g_ct * g_ct - gradTanh->applyTransform(transform::OneMinus, *gradTanh); // gradTanh = (1.0f - g_ct * g_ct) - // gradCt = inGradHt * rt * gradTanh - rt.applyPairwiseTransform(pairwise::Multiply, *gradTanh, *gradCt); // gradCt = rt * gradTanh - inGradHt.applyPairwiseTransform(pairwise::Multiply, *gradCt, *gradCt); // gradCt = inGradHt * rt * gradTanh - // gradBFt = (gradCt + inGradCt) * (ct_1 - zt) * (1 - ft) * ft; - gradCt->applyPairwiseTransform(pairwise::Add, *inGradCt, *temp1); // temp1 = (gradCt + inGradCt) - ct_1->applyPairwiseTransform(pairwise::Subtract, zt, *temp2); // temp2 = (ct_1 - zt) - temp1->applyPairwiseTransform(pairwise::Multiply, *ftMinus, *temp1); // temp1 = (gradCt + inGradCt)*(1-ft) - temp1->applyPairwiseTransform(pairwise::Multiply, ft, *temp1); // temp1 = (gradCt + inGradCt)*(1-ft)*ft - temp1->applyPairwiseTransform(pairwise::Multiply, *temp2, gradBFt); // gradBFt = (gradCt + inGradCt) * (ct_1 - zt) * (1 - ft) * ft; - - // x_t (highway connection), gradHXt = inGradHt * (1.0f - rt); - inGradHt.applyPairwiseTransform(pairwise::Multiply, *rtMinus, gradHXt); - - // U_t, gradUZt = (inGradHt * rt * grad_tanh + inGradCt) * (1.0f - ft); - rt.applyPairwiseTransform(pairwise::Multiply, *gradTanh, *temp1); // temp1 = rt * grad_tanh - inGradHt.applyPairwiseTransform(pairwise::Multiply, *temp1, *temp1); // temp1 = inGradHt * rt * grad_tanh - temp1->applyPairwiseTransform(pairwise::Add, *inGradCt, *temp1); // temp1 = inGradHt * rt * grad_tanh + inGradCt - temp1->applyPairwiseTransform(pairwise::Multiply, *ftMinus, gradUZt); // gradUZt = (inGradHt * rt * grad_tanh + inGradCt) * (1.0f - ft); - gradUFt.assign(&gradBFt); - gradURt.assign(&gradBRt); - - // c_{t-1}, inGradCt = (gradCt + inGradCt) * ft; - gradCt->applyPairwiseTransform(pairwise::Add, *inGradCt, *temp1); // temp1 = (gradCt + inGradCt) - temp1->applyPairwiseTransform(pairwise::Multiply, ft, *inGradCt); // inGradCt = (gradCt + inGradCt) * ft; - - if(t != 0) - delete ct_1; - } - - // gradInit - gradInit->assign(inGradCt); - - // gradX - auto weightsT = w->transpose(); // [K x 3K] - MmulHelper::mmul(&weightsT, gradU, gradX, 1., 0.); // [bS x K x N] - gradX->applyPairwiseTransform(pairwise::Add, *gradHX, *gradX); // + grad_highway_x - if(applyMask) - gradX->applyBroadcast(broadcast::Multiply, {0,1}, *mask, *gradX); // apply mask - - // gradB - auto temp3 = gradBias->reduceAlongDimension(reduce::Sum, {0,2}, false, true); // [1 x 2K] - gradB->assign(temp3); - - // gradW [bS x 3K x K] - x->permutei({0, 2, 1}); // [bS x N x K] - MmulHelper::mmul(gradU, x, gradW, 1., 0.); // [bS x 3K x K] - - delete gct; delete gradU; delete gradHX; - delete temp1; delete temp2; delete gradCt; delete wi; - delete gradTanh; delete ftMinus; delete rtMinus; delete gradBias; - - return Status::OK(); + auto x = + INPUT_VARIABLE(0); // X, input 3d tensor [bS x K x N], N - number of time + // steps, bS - batch size, K - number of features + auto w = INPUT_VARIABLE(1); // W, 2d tensor of weights [3K x K] + auto b = INPUT_VARIABLE(2); // B, row of biases with twice length [1 x 2*K] + auto c0 = INPUT_VARIABLE( + 3); // C_{0}, 2d tensor of initial state [bS x K] at time t=0 + auto c = INPUT_VARIABLE(4); // C, [bS x K x N] + auto inGradCt = INPUT_VARIABLE(5); // [bS x K] + auto inGradH = INPUT_VARIABLE(6); // [bS x K x N] + NDArray* mask = nullptr; // optional, 2d tensor of dropout mask [bS x K] + + bool applyMask = false; + if (block.width() > 7) { + mask = INPUT_VARIABLE(7); + applyMask = true; + } + + auto gradX = OUTPUT_VARIABLE(0); // [bS x K x N] + auto gradW = OUTPUT_VARIABLE(1); // [bS x 3K x K] + auto gradB = OUTPUT_VARIABLE(2); // [1 x 2K] + auto gradInit = OUTPUT_VARIABLE(3); // [bS x K] + + const int bS = x->shapeOf()[0]; + const int K = x->shapeOf()[1]; + const int N = x->shapeOf()[2]; // N - number of time steps + + auto gradBias = NDArrayFactory::create_( + x->ordering(), {bS, 2 * K, N}, gradX->dataType(), block.launchContext()); + auto gradU = NDArrayFactory::create_( + x->ordering(), {bS, 3 * K, N}, gradX->dataType(), block.launchContext()); + auto gradHX = NDArrayFactory::create_( + x->ordering(), {bS, K, N}, gradX->dataType(), block.launchContext()); + auto gct = NDArrayFactory::create_(c->ordering(), {bS, K}, gradX->dataType(), + block.launchContext()); + auto gradTanh = NDArrayFactory::create_( + c->ordering(), {bS, K}, gradX->dataType(), block.launchContext()); + auto gradCt = NDArrayFactory::create_( + c->ordering(), {bS, K}, gradX->dataType(), block.launchContext()); + auto ftMinus = NDArrayFactory::create_( + c->ordering(), {bS, K}, gradX->dataType(), block.launchContext()); + auto rtMinus = NDArrayFactory::create_( + c->ordering(), {bS, K}, gradX->dataType(), block.launchContext()); + auto temp1 = NDArrayFactory::create_( + c->ordering(), {bS, K}, gradX->dataType(), block.launchContext()); + auto temp2 = NDArrayFactory::create_( + c->ordering(), {bS, K}, gradX->dataType(), block.launchContext()); + + // x = x * mask + if (applyMask) + x->applyBroadcast(broadcast::Multiply, {0, 1}, *mask, *x); // apply mask + // multiplication matrix wi = matmul(w,x), U = WX + auto wi = MmulHelper::mmul(w, x, nullptr, 1., 0.); // U [bS x 3K x N] + + auto wiZ = (*wi)({0, 0, 0, K, 0, 0}, true); // [bS x K x N] + auto wiF = (*wi)({0, 0, K, 2 * K, 0, 0}, true); // forget gate [bS x K x N] + auto wiR = + (*wi)({0, 0, 2 * K, 3 * K, 0, 0}, true); // reset gate [bS x K x N] + auto bF = (*b)({0, 0, 0, K}, true); // biases for forget gate [1 x K] + auto bR = (*b)({0, 0, K, 2 * K}, true); // biases for reset gate [1 x K] + auto gradBF = (*gradBias)({0, 0, 0, K, 0, 0}, true); // [bS x K x N] + auto gradBR = (*gradBias)({0, 0, K, 2 * K, 0, 0}, true); // [bS x K x N] + auto gradUZ = (*gradU)({0, 0, 0, K, 0, 0}, true); // [bS x K x N] + auto gradUF = (*gradU)({0, 0, K, 2 * K, 0, 0}, true); // [bS x K x N] + auto gradUR = (*gradU)({0, 0, 2 * K, 3 * K, 0, 0}, true); // [bS x K x N] + + NDArray* ct_1 = nullptr; + + std::vector idx = {0, 0, 0, 0, 0, 0}; + + for (int t = N - 1; t >= 0; --t) { + // initialization + idx[4] = t; + idx[5] = t + 1; + auto xt = (*x)(idx); // [bS x K x N] -> [bS x K x 1] -> [bS x K] + auto zt = wiZ(idx); // [bS x K x N] -> [bS x K x 1] -> [bS x K] + auto ft = wiF(idx); // [bS x K x N] -> [bS x K x 1] -> [bS x K] + auto rt = wiR(idx); // [bS x K x N] -> [bS x K x 1] -> [bS x K] + auto ct = (*c)(idx); // [bS x K x N] -> [bS x K x 1] -> [bS x K] + auto inGradHt = + (*inGradH)(idx); // [bS x K x N] -> [bS x K x 1] -> [bS x K] + auto gradBRt = gradBR(idx); // [bS x K x N] -> [bS x K x 1] -> [bS x K] + auto gradBFt = gradBF(idx); // [bS x K x N] -> [bS x K x 1] -> [bS x K] + auto gradHXt = (*gradHX)(idx); // [bS x K x N] -> [bS x K x 1] -> [bS x K] + auto gradUZt = gradUZ(idx); // [bS x K x N] -> [bS x K x 1] -> [bS x K] + auto gradUFt = gradUF(idx); // [bS x K x N] -> [bS x K x 1] -> [bS x K] + auto gradURt = gradUR(idx); // [bS x K x N] -> [bS x K x 1] -> [bS x K] + + if (t != 0) { + idx[4] = t - 1; + idx[5] = t; + ct_1 = new NDArray((*c)(idx)); // previous c_{t-1} + } else + ct_1 = c0; + + ///////////////// forward + // ft = sigmoid(ft + bf), rt = sigmoid(rt + bR) + ft.addRowVector(bF, ft); + rt.addRowVector(bR, rt); + ft.applyTransform(transform::Sigmoid, ft); + rt.applyTransform(transform::Sigmoid, rt); + + // TODO T val = (activation_type == 1) ? tanh(cur) : ((activation_type == 2) + // ? reluf(cur) : cur ); + ct.applyTransform(transform::Tanh, *gct); + // ftMinus = 1-ft, rtMinus = 1-rt + ft.applyTransform(transform::OneMinus, *ftMinus); + rt.applyTransform(transform::OneMinus, *rtMinus); + + ///////////////// backward + // bR, *grad_brt_ptr = inGradHt * (g_ct - xt) * (1.0f - rt) * rt; + gct->applyPairwiseTransform(pairwise::Subtract, xt, + *temp1); // temp1 = (g_ct - xt) + rtMinus->applyPairwiseTransform(pairwise::Multiply, rt, + *temp2); // temp2 = (1.0f - rt) * rt; + temp1->applyPairwiseTransform( + pairwise::Multiply, *temp2); // temp1 = (g_ct - xt) * (1.0f - rt) * rt; + inGradHt.applyPairwiseTransform( + pairwise::Multiply, *temp1, + gradBRt); // = inGradHt * (g_ct - xt) * (1.0f - rt) * rt; + + // bF, TODO - tanh + // gradTanh = (1.0f - g_ct * g_ct); + gct->applyPairwiseTransform(pairwise::Multiply, *gct, + *gradTanh); // gradTanh = g_ct * g_ct + gradTanh->applyTransform(transform::OneMinus, + *gradTanh); // gradTanh = (1.0f - g_ct * g_ct) + // gradCt = inGradHt * rt * gradTanh + rt.applyPairwiseTransform(pairwise::Multiply, *gradTanh, + *gradCt); // gradCt = rt * gradTanh + inGradHt.applyPairwiseTransform( + pairwise::Multiply, *gradCt, + *gradCt); // gradCt = inGradHt * rt * gradTanh + // gradBFt = (gradCt + inGradCt) * (ct_1 - zt) * (1 - ft) * ft; + gradCt->applyPairwiseTransform(pairwise::Add, *inGradCt, + *temp1); // temp1 = (gradCt + inGradCt) + ct_1->applyPairwiseTransform(pairwise::Subtract, zt, + *temp2); // temp2 = (ct_1 - zt) + temp1->applyPairwiseTransform( + pairwise::Multiply, *ftMinus, + *temp1); // temp1 = (gradCt + inGradCt)*(1-ft) + temp1->applyPairwiseTransform( + pairwise::Multiply, ft, + *temp1); // temp1 = (gradCt + inGradCt)*(1-ft)*ft + temp1->applyPairwiseTransform(pairwise::Multiply, *temp2, + gradBFt); // gradBFt = (gradCt + inGradCt) * + // (ct_1 - zt) * (1 - ft) * ft; + + // x_t (highway connection), gradHXt = inGradHt * (1.0f - rt); + inGradHt.applyPairwiseTransform(pairwise::Multiply, *rtMinus, gradHXt); + + // U_t, gradUZt = (inGradHt * rt * grad_tanh + inGradCt) * (1.0f - ft); + rt.applyPairwiseTransform(pairwise::Multiply, *gradTanh, + *temp1); // temp1 = rt * grad_tanh + inGradHt.applyPairwiseTransform( + pairwise::Multiply, *temp1, + *temp1); // temp1 = inGradHt * rt * grad_tanh + temp1->applyPairwiseTransform( + pairwise::Add, *inGradCt, + *temp1); // temp1 = inGradHt * rt * grad_tanh + inGradCt + temp1->applyPairwiseTransform( + pairwise::Multiply, *ftMinus, + gradUZt); // gradUZt = (inGradHt * rt * grad_tanh + inGradCt) * (1.0f - + // ft); + gradUFt.assign(&gradBFt); + gradURt.assign(&gradBRt); + + // c_{t-1}, inGradCt = (gradCt + inGradCt) * ft; + gradCt->applyPairwiseTransform(pairwise::Add, *inGradCt, + *temp1); // temp1 = (gradCt + inGradCt) + temp1->applyPairwiseTransform( + pairwise::Multiply, ft, + *inGradCt); // inGradCt = (gradCt + inGradCt) * ft; + + if (t != 0) delete ct_1; + } + + // gradInit + gradInit->assign(inGradCt); + + // gradX + auto weightsT = w->transpose(); // [K x 3K] + MmulHelper::mmul(&weightsT, gradU, gradX, 1., 0.); // [bS x K x N] + gradX->applyPairwiseTransform(pairwise::Add, *gradHX, + *gradX); // + grad_highway_x + if (applyMask) + gradX->applyBroadcast(broadcast::Multiply, {0, 1}, *mask, + *gradX); // apply mask + + // gradB + auto temp3 = gradBias->reduceAlongDimension(reduce::Sum, {0, 2}, false, + true); // [1 x 2K] + gradB->assign(temp3); + + // gradW [bS x 3K x K] + x->permutei({0, 2, 1}); // [bS x N x K] + MmulHelper::mmul(gradU, x, gradW, 1., 0.); // [bS x 3K x K] + + delete gct; + delete gradU; + delete gradHX; + delete temp1; + delete temp2; + delete gradCt; + delete wi; + delete gradTanh; + delete ftMinus; + delete rtMinus; + delete gradBias; + + return Status::OK(); } - DECLARE_TYPES(sru_bp) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } +DECLARE_TYPES(sru_bp) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} DECLARE_SHAPE_FN(sru_bp) { - - auto inShape = inputShape->at(0); // [bS x inSize x time] - auto bS = inShape[1]; - auto inSize = inShape[2]; - auto time = inShape[3]; - char order = (char)(inShape[9]); - - ShapeDescriptor descriptor1(ArrayOptions::dataType(inShape), order, {bS, inSize, time}); - ShapeDescriptor descriptor2(ArrayOptions::dataType(inShape), order, {bS, 3 * inSize, inSize}); - ShapeDescriptor descriptor3(ArrayOptions::dataType(inShape), order, {1, 2 * inSize}); - ShapeDescriptor descriptor4(ArrayOptions::dataType(inShape), order, {bS, inSize}); - - return SHAPELIST(ConstantShapeHelper::getInstance().createShapeInfo(descriptor1), ConstantShapeHelper::getInstance().createShapeInfo(descriptor2), ConstantShapeHelper::getInstance().createShapeInfo(descriptor3), ConstantShapeHelper::getInstance().createShapeInfo(descriptor4)); + auto inShape = inputShape->at(0); // [bS x inSize x time] + auto bS = inShape[1]; + auto inSize = inShape[2]; + auto time = inShape[3]; + char order = (char)(inShape[9]); + + ShapeDescriptor descriptor1(ArrayOptions::dataType(inShape), order, + {bS, inSize, time}); + ShapeDescriptor descriptor2(ArrayOptions::dataType(inShape), order, + {bS, 3 * inSize, inSize}); + ShapeDescriptor descriptor3(ArrayOptions::dataType(inShape), order, + {1, 2 * inSize}); + ShapeDescriptor descriptor4(ArrayOptions::dataType(inShape), order, + {bS, inSize}); + + return SHAPELIST( + ConstantShapeHelper::getInstance().createShapeInfo(descriptor1), + ConstantShapeHelper::getInstance().createShapeInfo(descriptor2), + ConstantShapeHelper::getInstance().createShapeInfo(descriptor3), + ConstantShapeHelper::getInstance().createShapeInfo(descriptor4)); } - - ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(sru_bi, 5, 2, true, 0, 0) { - - auto x = INPUT_VARIABLE(0); // X, input 3d tensor [time x bS x 2*inSize], time - number of time steps, bS - batch size, inSize - number of features - auto w = INPUT_VARIABLE(1); // W, 2d tensor of weights [2*inSize x 6*inSize] - auto b = INPUT_VARIABLE(2); // B, row of biases with twice length [1 x 4*inSize] - auto c0 = INPUT_VARIABLE(3); // C_{0}, 2d tensor of initial state [bS x 2*inSize] at time t=0 - NDArray* mask = block.width() > 4 ? INPUT_VARIABLE(4) : nullptr; // optional, 2d tensor of dropout mask [bS x 2*inSize] - - auto ht = OUTPUT_VARIABLE(0); // h_t, [time x bS x 2*inSize] - auto ct = OUTPUT_VARIABLE(1); // c_t, [time x bS x 2*inSize] - - // input shapes validation - const int rank = x->rankOf(); - const Nd4jLong bS = x->sizeAt(1); - const Nd4jLong inSize = x->sizeAt(2) / 2; - - REQUIRE_TRUE(x->rankOf() == rank, 0, "SRU_BI operation: wrong rank of input array, expected is %i, but got %i instead !", rank, x->rankOf()); - REQUIRE_TRUE(w->rankOf() == rank-1, 0, "SRU_BI operation: wrong rank of weights array, expected is %i, but got %i instead !", rank-1, w->rankOf()); - REQUIRE_TRUE(b->rankOf() == 1, 0, "SRU_BI operation: wrong rank of biases array, expected is 1, but got %i instead !", b->rankOf()); - REQUIRE_TRUE(c0->rankOf() == rank-1, 0, "SRU_BI operation: wrong rank of initial state array, expected is %i, but got %i instead !", rank-1, c0->rankOf()); - if(mask) - REQUIRE_TRUE(mask->rankOf() == rank-1, 0, "SRU_BI operation: wrong rank of mask array, expected is %i, but got %i instead !", rank-1, mask->rankOf()); - - const std::vector wCorrectShape = {2*inSize, 6*inSize}; - const std::vector bCorrectShape = {4*inSize}; - const std::vector c0CorrectShape = {bS, 2*inSize}; - - REQUIRE_TRUE(w->isSameShape(wCorrectShape), 0, "SRU_BI operation: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(wCorrectShape).c_str(), ShapeUtils::shapeAsString(w).c_str()); - REQUIRE_TRUE(b->isSameShape(bCorrectShape), 0, "SRU_BI operation: wrong shape of biases array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(bCorrectShape).c_str(), ShapeUtils::shapeAsString(b).c_str()); - REQUIRE_TRUE(c0->isSameShape(c0CorrectShape), 0, "SRU_BI operation: wrong shape of initial state array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(c0CorrectShape).c_str(), ShapeUtils::shapeAsString(c0).c_str()); - if(mask) - REQUIRE_TRUE(mask->isSameShape(c0CorrectShape), 0, "SRU_BI operation: wrong shape of mask array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(c0CorrectShape).c_str(), ShapeUtils::shapeAsString(mask).c_str()); - - helpers::sruBI(block.launchContext(), x, w, b, c0, mask, ht, ct); - - return Status::OK(); + auto x = INPUT_VARIABLE( + 0); // X, input 3d tensor [time x bS x 2*inSize], time - number of time + // steps, bS - batch size, inSize - number of features + auto w = INPUT_VARIABLE(1); // W, 2d tensor of weights [2*inSize x 6*inSize] + auto b = + INPUT_VARIABLE(2); // B, row of biases with twice length [1 x 4*inSize] + auto c0 = INPUT_VARIABLE( + 3); // C_{0}, 2d tensor of initial state [bS x 2*inSize] at time t=0 + NDArray* mask = + block.width() > 4 + ? INPUT_VARIABLE(4) + : nullptr; // optional, 2d tensor of dropout mask [bS x 2*inSize] + + auto ht = OUTPUT_VARIABLE(0); // h_t, [time x bS x 2*inSize] + auto ct = OUTPUT_VARIABLE(1); // c_t, [time x bS x 2*inSize] + + // input shapes validation + const int rank = x->rankOf(); + const Nd4jLong bS = x->sizeAt(1); + const Nd4jLong inSize = x->sizeAt(2) / 2; + + REQUIRE_TRUE(x->rankOf() == rank, 0, + "SRU_BI operation: wrong rank of input array, expected is %i, " + "but got %i instead !", + rank, x->rankOf()); + REQUIRE_TRUE(w->rankOf() == rank - 1, 0, + "SRU_BI operation: wrong rank of weights array, expected is %i, " + "but got %i instead !", + rank - 1, w->rankOf()); + REQUIRE_TRUE(b->rankOf() == 1, 0, + "SRU_BI operation: wrong rank of biases array, expected is 1, " + "but got %i instead !", + b->rankOf()); + REQUIRE_TRUE(c0->rankOf() == rank - 1, 0, + "SRU_BI operation: wrong rank of initial state array, expected " + "is %i, but got %i instead !", + rank - 1, c0->rankOf()); + if (mask) + REQUIRE_TRUE(mask->rankOf() == rank - 1, 0, + "SRU_BI operation: wrong rank of mask array, expected is %i, " + "but got %i instead !", + rank - 1, mask->rankOf()); + + const std::vector wCorrectShape = {2 * inSize, 6 * inSize}; + const std::vector bCorrectShape = {4 * inSize}; + const std::vector c0CorrectShape = {bS, 2 * inSize}; + + REQUIRE_TRUE(w->isSameShape(wCorrectShape), 0, + "SRU_BI operation: wrong shape of weights array, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString(wCorrectShape).c_str(), + ShapeUtils::shapeAsString(w).c_str()); + REQUIRE_TRUE(b->isSameShape(bCorrectShape), 0, + "SRU_BI operation: wrong shape of biases array, expected is %s, " + "but got %s instead !", + ShapeUtils::shapeAsString(bCorrectShape).c_str(), + ShapeUtils::shapeAsString(b).c_str()); + REQUIRE_TRUE(c0->isSameShape(c0CorrectShape), 0, + "SRU_BI operation: wrong shape of initial state array, expected " + "is %s, but got %s instead !", + ShapeUtils::shapeAsString(c0CorrectShape).c_str(), + ShapeUtils::shapeAsString(c0).c_str()); + if (mask) + REQUIRE_TRUE(mask->isSameShape(c0CorrectShape), 0, + "SRU_BI operation: wrong shape of mask array, expected is %s, " + "but got %s instead !", + ShapeUtils::shapeAsString(c0CorrectShape).c_str(), + ShapeUtils::shapeAsString(mask).c_str()); + + helpers::sruBI(block.launchContext(), x, w, b, c0, mask, ht, ct); + + return Status::OK(); } - DECLARE_TYPES(sru_bi) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } +DECLARE_TYPES(sru_bi) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} DECLARE_SHAPE_FN(sru_bi) { - - auto xShapeInfo = inputShape->at(0); // [time x bS x 2K ] - auto wShapeInfo = inputShape->at(1); - auto bShapeInfo = inputShape->at(2); - auto c0ShapeInfo = inputShape->at(3); - auto maskShapeInfo = block.width() > 4 ? inputShape->at(4) : nullptr; // optional, 2d tensor of dropout mask [bS x inSize] - - const int rank = xShapeInfo[0]; // = 3 - const Nd4jLong time = xShapeInfo[1]; - const Nd4jLong bS = xShapeInfo[2]; - const Nd4jLong inSize = xShapeInfo[3] / 2; - - - // input shapes validation - REQUIRE_TRUE(wShapeInfo[0] == rank-1, 0, "SRU_BI operation: wrong rank of weights array, expected is %i, but got %i instead !", rank-1, wShapeInfo[0]); - REQUIRE_TRUE(bShapeInfo[0] == 1, 0, "SRU_BI operation: wrong rank of biases array, expected is 1, but got %i instead !", bShapeInfo[0]); - REQUIRE_TRUE(c0ShapeInfo[0] == rank-1, 0, "SRU_BI operation: wrong rank of initial state array, expected is %i, but got %i instead !", rank-1, c0ShapeInfo[0]); - if(maskShapeInfo) - REQUIRE_TRUE(maskShapeInfo[0] == rank-1, 0, "SRU_BI operation: wrong rank of mask array, expected is %i, but got %i instead !", rank-1, maskShapeInfo[0]); - - const std::vector wCorrectShape = {2*inSize, 6*inSize}; - const std::vector bCorrectShape = {4*inSize}; - const std::vector c0CorrectShape = {bS, 2*inSize}; - - REQUIRE_TRUE(ShapeUtils::areShapesEqual(wShapeInfo, wCorrectShape), 0, "SRU_BI operation: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(wCorrectShape).c_str(), ShapeUtils::shapeAsString(wShapeInfo).c_str()); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(bShapeInfo, bCorrectShape), 0, "SRU_BI operation: wrong shape of biases array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(bCorrectShape).c_str(), ShapeUtils::shapeAsString(bShapeInfo).c_str()); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(c0ShapeInfo, c0CorrectShape), 0, "SRU_BI operation: wrong shape of initial state array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(c0CorrectShape).c_str(), ShapeUtils::shapeAsString(c0ShapeInfo).c_str()); - if(maskShapeInfo) - REQUIRE_TRUE(ShapeUtils::areShapesEqual(maskShapeInfo, c0CorrectShape), 0, "SRU_BI operation: wrong shape of mask array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(c0CorrectShape).c_str(), ShapeUtils::shapeAsString(maskShapeInfo).c_str()); - - - char order = shape::order(xShapeInfo); - - ShapeDescriptor descriptor(ArrayOptions::dataType(xShapeInfo), order, {time, bS, 2 * inSize}); - auto result = ConstantShapeHelper::getInstance().createShapeInfo(descriptor); - return SHAPELIST(result, result); + auto xShapeInfo = inputShape->at(0); // [time x bS x 2K ] + auto wShapeInfo = inputShape->at(1); + auto bShapeInfo = inputShape->at(2); + auto c0ShapeInfo = inputShape->at(3); + auto maskShapeInfo = + block.width() > 4 + ? inputShape->at(4) + : nullptr; // optional, 2d tensor of dropout mask [bS x inSize] + + const int rank = xShapeInfo[0]; // = 3 + const Nd4jLong time = xShapeInfo[1]; + const Nd4jLong bS = xShapeInfo[2]; + const Nd4jLong inSize = xShapeInfo[3] / 2; + + // input shapes validation + REQUIRE_TRUE(wShapeInfo[0] == rank - 1, 0, + "SRU_BI operation: wrong rank of weights array, expected is %i, " + "but got %i instead !", + rank - 1, wShapeInfo[0]); + REQUIRE_TRUE(bShapeInfo[0] == 1, 0, + "SRU_BI operation: wrong rank of biases array, expected is 1, " + "but got %i instead !", + bShapeInfo[0]); + REQUIRE_TRUE(c0ShapeInfo[0] == rank - 1, 0, + "SRU_BI operation: wrong rank of initial state array, expected " + "is %i, but got %i instead !", + rank - 1, c0ShapeInfo[0]); + if (maskShapeInfo) + REQUIRE_TRUE(maskShapeInfo[0] == rank - 1, 0, + "SRU_BI operation: wrong rank of mask array, expected is %i, " + "but got %i instead !", + rank - 1, maskShapeInfo[0]); + + const std::vector wCorrectShape = {2 * inSize, 6 * inSize}; + const std::vector bCorrectShape = {4 * inSize}; + const std::vector c0CorrectShape = {bS, 2 * inSize}; + + REQUIRE_TRUE(ShapeUtils::areShapesEqual(wShapeInfo, wCorrectShape), 0, + "SRU_BI operation: wrong shape of weights array, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString(wCorrectShape).c_str(), + ShapeUtils::shapeAsString(wShapeInfo).c_str()); + REQUIRE_TRUE(ShapeUtils::areShapesEqual(bShapeInfo, bCorrectShape), 0, + "SRU_BI operation: wrong shape of biases array, expected is %s, " + "but got %s instead !", + ShapeUtils::shapeAsString(bCorrectShape).c_str(), + ShapeUtils::shapeAsString(bShapeInfo).c_str()); + REQUIRE_TRUE(ShapeUtils::areShapesEqual(c0ShapeInfo, c0CorrectShape), 0, + "SRU_BI operation: wrong shape of initial state array, expected " + "is %s, but got %s instead !", + ShapeUtils::shapeAsString(c0CorrectShape).c_str(), + ShapeUtils::shapeAsString(c0ShapeInfo).c_str()); + if (maskShapeInfo) + REQUIRE_TRUE(ShapeUtils::areShapesEqual(maskShapeInfo, c0CorrectShape), 0, + "SRU_BI operation: wrong shape of mask array, expected is %s, " + "but got %s instead !", + ShapeUtils::shapeAsString(c0CorrectShape).c_str(), + ShapeUtils::shapeAsString(maskShapeInfo).c_str()); + + char order = shape::order(xShapeInfo); + + ShapeDescriptor descriptor(ArrayOptions::dataType(xShapeInfo), order, + {time, bS, 2 * inSize}); + auto result = ConstantShapeHelper::getInstance().createShapeInfo(descriptor); + return SHAPELIST(result, result); } - - DECLARE_TYPES(sru_bi_bp) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } +DECLARE_TYPES(sru_bi_bp) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(sru_bi_bp, 8, 4, true, 0, 0) { - - auto x = INPUT_VARIABLE(0); // X, input 3d tensor [time x bS x 2*inSize], time - number of time steps, bS - batch size, inSize - number of features - auto w = INPUT_VARIABLE(1); // W, 2d tensor of weights [2*inSize x 6*inSize] - auto b = INPUT_VARIABLE(2); // B, row of biases with twice length [4*inSize] - auto c0 = INPUT_VARIABLE(3); // C_{0}, 2d tensor of initial state [bS x 2*inSize] at time t=0 - auto ct = INPUT_VARIABLE(4); // C, [time x bS x 2*inSize] - auto inGradC0 = INPUT_VARIABLE(5); // [bS x 2*inSize] - auto inGradHt = INPUT_VARIABLE(6); // [time x bS x 2*inSize] - NDArray* mask = block.width() > 7 ? INPUT_VARIABLE(7) : nullptr; // optional, 2d tensor of dropout mask [bS x 2*inSize] - - // input shapes validation - const int rank = x->rankOf(); - const Nd4jLong time = x->sizeAt(0); - const Nd4jLong bS = x->sizeAt(1); - const Nd4jLong inSize = x->sizeAt(2) / 2; - - REQUIRE_TRUE(w->rankOf() == rank-1, 0, "SRU_BI_BP operation: wrong rank of weights array, expected is %i, but got %i instead !", rank-1, w->rankOf()); - REQUIRE_TRUE(b->rankOf() == 1, 0, "SRU_BI_BP operation: wrong rank of biases array, expected is 1, but got %i instead !", b->rankOf()); - REQUIRE_TRUE(c0->rankOf() == rank-1, 0, "SRU_BI_BP operation: wrong rank of initial state array, expected is %i, but got %i instead !", rank-1, c0->rankOf()); - REQUIRE_TRUE(ct->rankOf() == rank, 0, "SRU_BI_BP operation: wrong rank of state array, expected is %i, but got %i instead !", rank, ct->rankOf()); - REQUIRE_TRUE(inGradC0->rankOf() == rank-1, 0, "SRU_BI_BP operation: wrong rank of gradient c0, expected is %i, but got %i instead !", rank-1, inGradC0->rankOf()); - REQUIRE_TRUE(inGradHt->rankOf() == rank, 0, "SRU_BI_BP operation: wrong rank of gradient ht, expected is %i, but got %i instead !", rank, inGradHt->rankOf()); - if(mask) - REQUIRE_TRUE(mask->rankOf() == rank-1, 0, "SRU_BI_BP operation: wrong rank of mask array, expected is %i, but got %i instead !", rank-1, mask->rankOf()); - - const std::vector wCorrectShape = {2*inSize, 6*inSize}; - const std::vector bCorrectShape = {4*inSize}; - const std::vector c0CorrectShape = {bS, 2*inSize}; - const std::vector ctCorrectShape = {time, bS, 2*inSize}; - - REQUIRE_TRUE(w->isSameShape(wCorrectShape), 0, "SRU_BI operation: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(wCorrectShape).c_str(), ShapeUtils::shapeAsString(w).c_str()); - REQUIRE_TRUE(b->isSameShape(bCorrectShape), 0, "SRU_BI operation: wrong shape of biases array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(bCorrectShape).c_str(), ShapeUtils::shapeAsString(b).c_str()); - REQUIRE_TRUE(c0->isSameShape(c0CorrectShape), 0, "SRU_BI operation: wrong shape of initial state array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(c0CorrectShape).c_str(), ShapeUtils::shapeAsString(c0).c_str()); - REQUIRE_TRUE(ct->isSameShape(ctCorrectShape), 0, "SRU_BI operation: wrong shape of state array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(ctCorrectShape).c_str(), ShapeUtils::shapeAsString(ct).c_str()); - if(mask) - REQUIRE_TRUE(mask->isSameShape(c0CorrectShape), 0, "SRU_BI operation: wrong shape of mask array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(c0CorrectShape).c_str(), ShapeUtils::shapeAsString(mask).c_str()); - - auto gradI = OUTPUT_VARIABLE(0); // [time x bS x 2*inSize] - auto gradW = OUTPUT_VARIABLE(1); // [time x 2*inSize x 6*inSize] - auto gradB = OUTPUT_VARIABLE(2); // [1 x 4*inSize] - auto gradC0 = OUTPUT_VARIABLE(3); // [bS x 2*inSize] - - helpers::sruBIBP(block.launchContext(), x, w, b, c0, ct, inGradC0, inGradHt, mask, gradI, gradW, gradB, gradC0); - - return Status::OK(); + auto x = INPUT_VARIABLE( + 0); // X, input 3d tensor [time x bS x 2*inSize], time - number of time + // steps, bS - batch size, inSize - number of features + auto w = INPUT_VARIABLE(1); // W, 2d tensor of weights [2*inSize x 6*inSize] + auto b = INPUT_VARIABLE(2); // B, row of biases with twice length [4*inSize] + auto c0 = INPUT_VARIABLE( + 3); // C_{0}, 2d tensor of initial state [bS x 2*inSize] at time t=0 + auto ct = INPUT_VARIABLE(4); // C, [time x bS x 2*inSize] + auto inGradC0 = INPUT_VARIABLE(5); // [bS x 2*inSize] + auto inGradHt = INPUT_VARIABLE(6); // [time x bS x 2*inSize] + NDArray* mask = + block.width() > 7 + ? INPUT_VARIABLE(7) + : nullptr; // optional, 2d tensor of dropout mask [bS x 2*inSize] + + // input shapes validation + const int rank = x->rankOf(); + const Nd4jLong time = x->sizeAt(0); + const Nd4jLong bS = x->sizeAt(1); + const Nd4jLong inSize = x->sizeAt(2) / 2; + + REQUIRE_TRUE(w->rankOf() == rank - 1, 0, + "SRU_BI_BP operation: wrong rank of weights array, expected is " + "%i, but got %i instead !", + rank - 1, w->rankOf()); + REQUIRE_TRUE(b->rankOf() == 1, 0, + "SRU_BI_BP operation: wrong rank of biases array, expected is " + "1, but got %i instead !", + b->rankOf()); + REQUIRE_TRUE(c0->rankOf() == rank - 1, 0, + "SRU_BI_BP operation: wrong rank of initial state array, " + "expected is %i, but got %i instead !", + rank - 1, c0->rankOf()); + REQUIRE_TRUE(ct->rankOf() == rank, 0, + "SRU_BI_BP operation: wrong rank of state array, expected is " + "%i, but got %i instead !", + rank, ct->rankOf()); + REQUIRE_TRUE(inGradC0->rankOf() == rank - 1, 0, + "SRU_BI_BP operation: wrong rank of gradient c0, expected is " + "%i, but got %i instead !", + rank - 1, inGradC0->rankOf()); + REQUIRE_TRUE(inGradHt->rankOf() == rank, 0, + "SRU_BI_BP operation: wrong rank of gradient ht, expected is " + "%i, but got %i instead !", + rank, inGradHt->rankOf()); + if (mask) + REQUIRE_TRUE(mask->rankOf() == rank - 1, 0, + "SRU_BI_BP operation: wrong rank of mask array, expected is " + "%i, but got %i instead !", + rank - 1, mask->rankOf()); + + const std::vector wCorrectShape = {2 * inSize, 6 * inSize}; + const std::vector bCorrectShape = {4 * inSize}; + const std::vector c0CorrectShape = {bS, 2 * inSize}; + const std::vector ctCorrectShape = {time, bS, 2 * inSize}; + + REQUIRE_TRUE(w->isSameShape(wCorrectShape), 0, + "SRU_BI operation: wrong shape of weights array, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString(wCorrectShape).c_str(), + ShapeUtils::shapeAsString(w).c_str()); + REQUIRE_TRUE(b->isSameShape(bCorrectShape), 0, + "SRU_BI operation: wrong shape of biases array, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString(bCorrectShape).c_str(), + ShapeUtils::shapeAsString(b).c_str()); + REQUIRE_TRUE(c0->isSameShape(c0CorrectShape), 0, + "SRU_BI operation: wrong shape of initial state array, expected " + "is %s, but got %s instead !", + ShapeUtils::shapeAsString(c0CorrectShape).c_str(), + ShapeUtils::shapeAsString(c0).c_str()); + REQUIRE_TRUE(ct->isSameShape(ctCorrectShape), 0, + "SRU_BI operation: wrong shape of state array, expected is %s, " + "but got %s instead !", + ShapeUtils::shapeAsString(ctCorrectShape).c_str(), + ShapeUtils::shapeAsString(ct).c_str()); + if (mask) + REQUIRE_TRUE(mask->isSameShape(c0CorrectShape), 0, + "SRU_BI operation: wrong shape of mask array, expected is %s, " + "but got %s instead !", + ShapeUtils::shapeAsString(c0CorrectShape).c_str(), + ShapeUtils::shapeAsString(mask).c_str()); + + auto gradI = OUTPUT_VARIABLE(0); // [time x bS x 2*inSize] + auto gradW = OUTPUT_VARIABLE(1); // [time x 2*inSize x 6*inSize] + auto gradB = OUTPUT_VARIABLE(2); // [1 x 4*inSize] + auto gradC0 = OUTPUT_VARIABLE(3); // [bS x 2*inSize] + + helpers::sruBIBP(block.launchContext(), x, w, b, c0, ct, inGradC0, inGradHt, + mask, gradI, gradW, gradB, gradC0); + + return Status::OK(); } DECLARE_SHAPE_FN(sru_bi_bp) { - - auto xShapeInfo = inputShape->at(0); // [time x bS x 2K ] - auto wShapeInfo = inputShape->at(1); - auto bShapeInfo = inputShape->at(2); - auto c0ShapeInfo = inputShape->at(3); - auto ctShapeInfo = inputShape->at(4); - auto inGradC0ShapeInfo = inputShape->at(5); - auto inGradHtShapeInfo = inputShape->at(6); - auto maskShapeInfo = block.width() > 7 ? inputShape->at(7) : nullptr; // optional, 2d tensor of dropout mask [bS x inSize] - - // input shapes validation - const int rank = xShapeInfo[0]; - const Nd4jLong time = xShapeInfo[1]; - const Nd4jLong bS = xShapeInfo[2]; - const Nd4jLong inSize = xShapeInfo[3] / 2; - - REQUIRE_TRUE(wShapeInfo[0] == rank-1, 0, "SRU_BI_BP operation: wrong rank of weights array, expected is %i, but got %i instead !", rank-1, wShapeInfo[0]); - REQUIRE_TRUE(bShapeInfo[0] == 1, 0, "SRU_BI_BP operation: wrong rank of biases array, expected is 1, but got %i instead !", bShapeInfo[0]); - REQUIRE_TRUE(c0ShapeInfo[0] == rank-1, 0, "SRU_BI_BP operation: wrong rank of initial state array, expected is %i, but got %i instead !", rank-1, c0ShapeInfo); - REQUIRE_TRUE(ctShapeInfo[0] == rank, 0, "SRU_BI_BP operation: wrong rank of state array, expected is %i, but got %i instead !", rank, ctShapeInfo); - REQUIRE_TRUE(inGradC0ShapeInfo[0] == rank-1, 0, "SRU_BI_BP operation: wrong rank of gradient c0, expected is %i, but got %i instead !", rank-1, inGradC0ShapeInfo[0]); - REQUIRE_TRUE(inGradHtShapeInfo[0] == rank, 0, "SRU_BI_BP operation: wrong rank of gradient ht, expected is %i, but got %i instead !", rank, inGradHtShapeInfo[0]); - if(maskShapeInfo) - REQUIRE_TRUE(maskShapeInfo[0] == rank-1, 0, "SRU_BI_BP operation: wrong rank of mask array, expected is %i, but got %i instead !", rank-1, maskShapeInfo[0]); - - const std::vector wCorrectShape = {2*inSize, 6*inSize}; - const std::vector bCorrectShape = {4*inSize}; - const std::vector c0CorrectShape = {bS, 2*inSize}; - const std::vector ctCorrectShape = {time, bS, 2*inSize}; - const std::vector inGradC0CorrectShape = {bS, 2*inSize}; - const std::vector inGradHtCorrectShape = {time, bS, 2*inSize}; - - REQUIRE_TRUE(ShapeUtils::areShapesEqual(wShapeInfo, wCorrectShape), 0, "SRU_BI operation: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(wCorrectShape).c_str(), ShapeUtils::shapeAsString(wShapeInfo).c_str()); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(bShapeInfo, bCorrectShape), 0, "SRU_BI operation: wrong shape of biases array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(bCorrectShape).c_str(), ShapeUtils::shapeAsString(bShapeInfo).c_str()); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(c0ShapeInfo, c0CorrectShape), 0, "SRU_BI operation: wrong shape of initial state array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(c0CorrectShape).c_str(), ShapeUtils::shapeAsString(c0ShapeInfo).c_str()); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(ctShapeInfo, ctCorrectShape), 0, "SRU_BI operation: wrong shape of state array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(ctCorrectShape).c_str(), ShapeUtils::shapeAsString(ctShapeInfo).c_str()); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(inGradC0ShapeInfo, inGradC0CorrectShape), 0, "SRU_BI operation: wrong shape of gradient c0 array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(inGradC0CorrectShape).c_str(), ShapeUtils::shapeAsString(inGradC0ShapeInfo).c_str()); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(inGradHtShapeInfo, inGradHtCorrectShape), 0, "SRU_BI operation: wrong shape of gradient ht array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(inGradHtCorrectShape).c_str(), ShapeUtils::shapeAsString(inGradHtShapeInfo).c_str()); - if(maskShapeInfo) - REQUIRE_TRUE(ShapeUtils::areShapesEqual(maskShapeInfo, c0CorrectShape), 0, "SRU_BI operation: wrong shape of mask array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(c0CorrectShape).c_str(), ShapeUtils::shapeAsString(maskShapeInfo).c_str()); - - const char order = shape::order(xShapeInfo); - - ShapeDescriptor descriptor1(ArrayOptions::dataType(xShapeInfo), order, {time, bS, 2 * inSize}); - ShapeDescriptor descriptor2(ArrayOptions::dataType(xShapeInfo), order, {time, 2 * inSize, 6 * inSize}); - ShapeDescriptor descriptor3(ArrayOptions::dataType(xShapeInfo), order, {4 * inSize}); - ShapeDescriptor descriptor4(ArrayOptions::dataType(xShapeInfo), order, {bS, 2 * inSize}); - - return SHAPELIST(ConstantShapeHelper::getInstance().createShapeInfo(descriptor1), ConstantShapeHelper::getInstance().createShapeInfo(descriptor2), ConstantShapeHelper::getInstance().createShapeInfo(descriptor3), ConstantShapeHelper::getInstance().createShapeInfo(descriptor4)); + auto xShapeInfo = inputShape->at(0); // [time x bS x 2K ] + auto wShapeInfo = inputShape->at(1); + auto bShapeInfo = inputShape->at(2); + auto c0ShapeInfo = inputShape->at(3); + auto ctShapeInfo = inputShape->at(4); + auto inGradC0ShapeInfo = inputShape->at(5); + auto inGradHtShapeInfo = inputShape->at(6); + auto maskShapeInfo = + block.width() > 7 + ? inputShape->at(7) + : nullptr; // optional, 2d tensor of dropout mask [bS x inSize] + + // input shapes validation + const int rank = xShapeInfo[0]; + const Nd4jLong time = xShapeInfo[1]; + const Nd4jLong bS = xShapeInfo[2]; + const Nd4jLong inSize = xShapeInfo[3] / 2; + + REQUIRE_TRUE(wShapeInfo[0] == rank - 1, 0, + "SRU_BI_BP operation: wrong rank of weights array, expected is " + "%i, but got %i instead !", + rank - 1, wShapeInfo[0]); + REQUIRE_TRUE(bShapeInfo[0] == 1, 0, + "SRU_BI_BP operation: wrong rank of biases array, expected is " + "1, but got %i instead !", + bShapeInfo[0]); + REQUIRE_TRUE(c0ShapeInfo[0] == rank - 1, 0, + "SRU_BI_BP operation: wrong rank of initial state array, " + "expected is %i, but got %i instead !", + rank - 1, c0ShapeInfo); + REQUIRE_TRUE(ctShapeInfo[0] == rank, 0, + "SRU_BI_BP operation: wrong rank of state array, expected is " + "%i, but got %i instead !", + rank, ctShapeInfo); + REQUIRE_TRUE(inGradC0ShapeInfo[0] == rank - 1, 0, + "SRU_BI_BP operation: wrong rank of gradient c0, expected is " + "%i, but got %i instead !", + rank - 1, inGradC0ShapeInfo[0]); + REQUIRE_TRUE(inGradHtShapeInfo[0] == rank, 0, + "SRU_BI_BP operation: wrong rank of gradient ht, expected is " + "%i, but got %i instead !", + rank, inGradHtShapeInfo[0]); + if (maskShapeInfo) + REQUIRE_TRUE(maskShapeInfo[0] == rank - 1, 0, + "SRU_BI_BP operation: wrong rank of mask array, expected is " + "%i, but got %i instead !", + rank - 1, maskShapeInfo[0]); + + const std::vector wCorrectShape = {2 * inSize, 6 * inSize}; + const std::vector bCorrectShape = {4 * inSize}; + const std::vector c0CorrectShape = {bS, 2 * inSize}; + const std::vector ctCorrectShape = {time, bS, 2 * inSize}; + const std::vector inGradC0CorrectShape = {bS, 2 * inSize}; + const std::vector inGradHtCorrectShape = {time, bS, 2 * inSize}; + + REQUIRE_TRUE(ShapeUtils::areShapesEqual(wShapeInfo, wCorrectShape), 0, + "SRU_BI operation: wrong shape of weights array, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString(wCorrectShape).c_str(), + ShapeUtils::shapeAsString(wShapeInfo).c_str()); + REQUIRE_TRUE(ShapeUtils::areShapesEqual(bShapeInfo, bCorrectShape), 0, + "SRU_BI operation: wrong shape of biases array, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString(bCorrectShape).c_str(), + ShapeUtils::shapeAsString(bShapeInfo).c_str()); + REQUIRE_TRUE(ShapeUtils::areShapesEqual(c0ShapeInfo, c0CorrectShape), 0, + "SRU_BI operation: wrong shape of initial state array, expected " + "is %s, but got %s instead !", + ShapeUtils::shapeAsString(c0CorrectShape).c_str(), + ShapeUtils::shapeAsString(c0ShapeInfo).c_str()); + REQUIRE_TRUE(ShapeUtils::areShapesEqual(ctShapeInfo, ctCorrectShape), 0, + "SRU_BI operation: wrong shape of state array, expected is %s, " + "but got %s instead !", + ShapeUtils::shapeAsString(ctCorrectShape).c_str(), + ShapeUtils::shapeAsString(ctShapeInfo).c_str()); + REQUIRE_TRUE( + ShapeUtils::areShapesEqual(inGradC0ShapeInfo, inGradC0CorrectShape), 0, + "SRU_BI operation: wrong shape of gradient c0 array, expected is %s, but " + "got %s instead !", + ShapeUtils::shapeAsString(inGradC0CorrectShape).c_str(), + ShapeUtils::shapeAsString(inGradC0ShapeInfo).c_str()); + REQUIRE_TRUE( + ShapeUtils::areShapesEqual(inGradHtShapeInfo, inGradHtCorrectShape), 0, + "SRU_BI operation: wrong shape of gradient ht array, expected is %s, but " + "got %s instead !", + ShapeUtils::shapeAsString(inGradHtCorrectShape).c_str(), + ShapeUtils::shapeAsString(inGradHtShapeInfo).c_str()); + if (maskShapeInfo) + REQUIRE_TRUE(ShapeUtils::areShapesEqual(maskShapeInfo, c0CorrectShape), 0, + "SRU_BI operation: wrong shape of mask array, expected is %s, " + "but got %s instead !", + ShapeUtils::shapeAsString(c0CorrectShape).c_str(), + ShapeUtils::shapeAsString(maskShapeInfo).c_str()); + + const char order = shape::order(xShapeInfo); + + ShapeDescriptor descriptor1(ArrayOptions::dataType(xShapeInfo), order, + {time, bS, 2 * inSize}); + ShapeDescriptor descriptor2(ArrayOptions::dataType(xShapeInfo), order, + {time, 2 * inSize, 6 * inSize}); + ShapeDescriptor descriptor3(ArrayOptions::dataType(xShapeInfo), order, + {4 * inSize}); + ShapeDescriptor descriptor4(ArrayOptions::dataType(xShapeInfo), order, + {bS, 2 * inSize}); + + return SHAPELIST( + ConstantShapeHelper::getInstance().createShapeInfo(descriptor1), + ConstantShapeHelper::getInstance().createShapeInfo(descriptor2), + ConstantShapeHelper::getInstance().createShapeInfo(descriptor3), + ConstantShapeHelper::getInstance().createShapeInfo(descriptor4)); } -} -} +} // namespace ops +} // namespace sd #endif ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of operations for Simple Recurrent Unit: "Training RNNs as Fast as CNNs" Tao Lei, Yu Zhang, Yoav Artzi - * - * Input arrays: - * 0: input 3d tensor with shape [bS x K x N], N - number of time steps, bS - batch size, K - number of features - * 1: 2d tensor of weights [3K x K] - * 2: row of biases with twice length [1 x 2K] - * 3: 2d tensor of previous cell state [bS x K] - * 4: optional, 2d tensor of dropout mask [bS x K] - * - * Output arrays: - * 0: 3d tensor of cell output [bS x K x N] - * 1: 3d tensor of cell state [bS x K x N] - */ - // #if NOT_EXCLUDED(OP_sru) - // DECLARE_CUSTOM_OP(sru_old, 5, 2, false, 0, 0); - - - ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of operation for Simple Recurrent Unit: "Training RNNs as Fast as CNNs" Tao Lei, Yu Zhang, Yoav Artzi - * - * Input arrays: - * 0: input 3d tensor with shape [bS x K x N], N - number of time steps, bS - batch size, K - number of features - * 1: 2d tensor of weights [3K x K] - * 2: row of biases with twice length [1 x 2K] - * 3: 2d tensor of previous cell state [bS x K] - * 4: optional, 2d tensor of dropout mask [bS x K] - * - * Output arrays: - * 0: 3d tensor of cell output [bS x K x N] - * 1: 3d tensor of cell state [bS x K x N] - */ - // #if NOT_EXCLUDED(OP_sru_logic) - // DECLARE_CUSTOM_OP(sru_logic, 5, 2, false, 0, 0); - // #endif +/** + * Implementation of operations for Simple Recurrent Unit: "Training RNNs as + * Fast as CNNs" Tao Lei, Yu Zhang, Yoav Artzi + * + * Input arrays: + * 0: input 3d tensor with shape [bS x K x N], N - number of time steps, bS - + * batch size, K - number of features 1: 2d tensor of weights [3K x K] 2: row of + * biases with twice length [1 x 2K] 3: 2d tensor of previous cell state [bS x + * K] 4: optional, 2d tensor of dropout mask [bS x K] + * + * Output arrays: + * 0: 3d tensor of cell output [bS x K x N] + * 1: 3d tensor of cell state [bS x K x N] + */ +// #if NOT_EXCLUDED(OP_sru) +// DECLARE_CUSTOM_OP(sru_old, 5, 2, false, 0, 0); +////////////////////////////////////////////////////////////////////////// +/** + * Implementation of operation for Simple Recurrent Unit: "Training RNNs as Fast + * as CNNs" Tao Lei, Yu Zhang, Yoav Artzi + * + * Input arrays: + * 0: input 3d tensor with shape [bS x K x N], N - number of time steps, bS - + * batch size, K - number of features 1: 2d tensor of weights [3K x K] 2: row of + * biases with twice length [1 x 2K] 3: 2d tensor of previous cell state [bS x + * K] 4: optional, 2d tensor of dropout mask [bS x K] + * + * Output arrays: + * 0: 3d tensor of cell output [bS x K x N] + * 1: 3d tensor of cell state [bS x K x N] + */ +// #if NOT_EXCLUDED(OP_sru_logic) +// DECLARE_CUSTOM_OP(sru_logic, 5, 2, false, 0, 0); +// #endif ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of operation for back propagation in Simple Recurrent Unit: "Training RNNs as Fast as CNNs" Tao Lei, Yu Zhang, Yoav Artzi - * - * Input arrays: - * 0: input 3d tensor with shape [bS x K x N], N - number of time steps, bS - batch size, K - number of features - * 1: 2d tensor of weights [3K x K] - * 2: row of biases with twice length [1 x 2K] - * 3: 2d tensor of previous cell state [bS x K] - * 4: 3d tensor of cell state [bS x K x N] - * 5: 2d tensor of cell state gradients [bS x K] - * 6: 3d tensor of state output gradients [bS x K x N] - * 7: optional, 2d tensor of dropout mask [bS x K] - * - * Output arrays: - * 0: 3d tensor of input gradients [bS x K x N] - * 1: 3d tensor of weights gradients [bS x 3K x K] - * 2: 2d, row of biases gradients [1 x 2K] - * 3: 2d, tensor of state gradients [bS x K] - */ - // #if NOT_EXCLUDED(OP_sru_logic) - // DECLARE_CUSTOM_OP(sru_bp_logic,8, 4, true, 0, 0); - // #endif +/** + * Implementation of operation for back propagation in Simple Recurrent Unit: + * "Training RNNs as Fast as CNNs" Tao Lei, Yu Zhang, Yoav Artzi + * + * Input arrays: + * 0: input 3d tensor with shape [bS x K x N], N - number of time steps, bS - + * batch size, K - number of features 1: 2d tensor of weights [3K x K] 2: row of + * biases with twice length [1 x 2K] 3: 2d tensor of previous cell state [bS x + * K] 4: 3d tensor of cell state [bS x K x N] 5: 2d tensor of cell state + * gradients [bS x K] 6: 3d tensor of state output gradients [bS x K x N] 7: + * optional, 2d tensor of dropout mask [bS x K] + * + * Output arrays: + * 0: 3d tensor of input gradients [bS x K x N] + * 1: 3d tensor of weights gradients [bS x 3K x K] + * 2: 2d, row of biases gradients [1 x 2K] + * 3: 2d, tensor of state gradients [bS x K] + */ +// #if NOT_EXCLUDED(OP_sru_logic) +// DECLARE_CUSTOM_OP(sru_bp_logic,8, 4, true, 0, 0); +// #endif // return 2d array evaluated though last dimension interval t1-t2 // static NDArray* timestep(const NDArray* arr, const int t1, const int t2) { // NDArray* result = new NDArray((*arr)({0,0, 0,0, t1,t2}, true)); -// result->reshapei(result->ordering(), {arr->shapeOf()[0], arr->shapeOf()[1]} ); +// result->reshapei(result->ordering(), {arr->shapeOf()[0], +// arr->shapeOf()[1]} ); // return result; // } @@ -587,11 +888,14 @@ DECLARE_SHAPE_FN(sru_bi_bp) { ///////////////////////////////////////////////////////////////////////// // CUSTOM_OP_IMPL(sru_logic, 5, 2, false, 0, 0) { -// auto input = INPUT_VARIABLE(0); // X, input 3d tensor [bS x K x N], N - number of time steps, bS - batch size, K - number of features -// auto weights = INPUT_VARIABLE(1); // W, 2d tensor of weights [3K x K] -// auto bias = INPUT_VARIABLE(2); // B, row of biases with twice length [1 x 2*K] -// auto init = INPUT_VARIABLE(3); // C_{0}, 2d tensor of initial state [bS x K] at time t=0 -// NDArray* mask = nullptr; // optional, 2d tensor of dropout mask [bS x K] +// auto input = INPUT_VARIABLE(0); // X, input 3d tensor +// [bS x K x N], N - number of time steps, bS - batch size, K - number of +// features auto weights = INPUT_VARIABLE(1); // W, 2d tensor +// of weights [3K x K] auto bias = INPUT_VARIABLE(2); // +// B, row of biases with twice length [1 x 2*K] auto init = +// INPUT_VARIABLE(3); // C_{0}, 2d tensor of initial state +// [bS x K] at time t=0 NDArray* mask = nullptr; // optional, 2d tensor +// of dropout mask [bS x K] // bool applyMask = false; // if (block.width() > 4) { @@ -602,13 +906,15 @@ DECLARE_SHAPE_FN(sru_bi_bp) { // auto output = OUTPUT_VARIABLE(0); // h_t, [bS x K x N] // auto state = OUTPUT_VARIABLE(1); // c_t, [bS x K x N] -// const int bS = input->shapeOf()[0]; // bS - batch size -// const int K = input->shapeOf()[1]; // K - number of features -// const int N = input->shapeOf()[2]; // N - number of time steps +// const int bS = input->shapeOf()[0]; // bS - batch +// size const int K = input->shapeOf()[1]; // K - +// number of features const int N = input->shapeOf()[2]; // N - number +// of time steps -// const auto wi = mmul(*weights, *input); // U [bS x 3K x N] -// const auto bF = (*bias)({0,0, 0, K}); // biases for forget gate [1 x K] -// const auto bR = (*bias)({0,0, K,2*K}); // biases for reset gate [1 x K] +// const auto wi = mmul(*weights, *input); // U [bS x 3K +// x N] const auto bF = (*bias)({0,0, 0, K}); // +// biases for forget gate [1 x K] const auto bR = (*bias)({0,0, K,2*K}); // +// biases for reset gate [1 x K] // NDArray xt(input->dataType(), block.launchContext()); // NDArray zt(input->dataType(), block.launchContext()); @@ -616,23 +922,28 @@ DECLARE_SHAPE_FN(sru_bi_bp) { // NDArray rt(input->dataType(), block.launchContext()); // NDArray ht(input->dataType(), block.launchContext()); // NDArray ct = *init; -// NDArray gct(state->ordering(), {bS, K}, input->dataType(), block.launchContext()); -// NDArray xmt = *input; +// NDArray gct(state->ordering(), {bS, K}, input->dataType(), +// block.launchContext()); NDArray xmt = *input; // // input = input * mask // if(applyMask) // xmt.applyBroadcast(broadcast::Multiply, {0, 1}, mask, &xmt, nullptr); // for (int t = 0; t < N; ++t) { -// xt = xmt({0,0, 0,0, t,t+1}); xt.reshapei(xt.ordering(), {bS, K}); // [bS x K x N] -> [bS x K x 1] -> [bS x K] -// zt = wi({0,0, 0, K, t,t+1}); zt.reshapei(zt.ordering(), {bS, K}); // [bS x 3K x N] -> [bS x K x 1] -> [bS x K] -// ft = wi({0,0, K, 2*K, t,t+1}); ft.reshapei(ft.ordering(), {bS, K}); // [bS x 3K x N] -> [bS x K x 1] -> [bS x K] -// rt = wi({0,0, 2*K,3*K, t,t+1}); rt.reshapei(rt.ordering(), {bS, K}); // [bS x 3K x N] -> [bS x K x 1] -> [bS x K] +// xt = xmt({0,0, 0,0, t,t+1}); xt.reshapei(xt.ordering(), {bS, K}); +// // [bS x K x N] -> [bS x K x 1] -> [bS x K] zt = wi({0,0, 0, K, +// t,t+1}); zt.reshapei(zt.ordering(), {bS, K}); // [bS x 3K x N] +// -> [bS x K x 1] -> [bS x K] ft = wi({0,0, K, 2*K, t,t+1}); +// ft.reshapei(ft.ordering(), {bS, K}); // [bS x 3K x N] -> [bS x +// K x 1] -> [bS x K] rt = wi({0,0, 2*K,3*K, t,t+1}); +// rt.reshapei(rt.ordering(), {bS, K}); // [bS x 3K x N] -> [bS x +// K x 1] -> [bS x K] // ft = sigmoid_(ft + bF); // rt = sigmoid_(rt + bR); // ct = ft * (ct - zt) + zt; -// // TODO T val = (activation_type == 1) ? tanh(cur) : ((activation_type == 2) ? reluf(cur) : cur ); +// // TODO T val = (activation_type == 1) ? tanh(cur) : +// ((activation_type == 2) ? reluf(cur) : cur ); // ct.applyTransform(transform::Tanh, &gct); // ht = rt * (gct - xt) + xt; @@ -660,7 +971,7 @@ DECLARE_SHAPE_FN(sru_bi_bp) { // char order = (char)(inShape[size-1]); // Nd4jLong* newShapeInfo1 = nullptr; -// ALLOCATE(newShapeInfo1, block.getWorkspace(), size, Nd4jLong); +// ALLOCATE(newShapeInfo1, block.workspace(), size, Nd4jLong); // newShapeInfo1[0] = rank; // newShapeInfo1[1] = bS; @@ -672,14 +983,16 @@ DECLARE_SHAPE_FN(sru_bi_bp) { // return SHAPELIST(result, result); // } - // ////////////////////////////////////////////////////////////////////////// // CUSTOM_OP_IMPL(sru_old, 5, 2, false, 0, 0) { -// auto x = INPUT_VARIABLE(0); // X, input 3d tensor [bS x inSize x time], time - number of time steps, bS - batch size, inSize - number of features -// auto w = INPUT_VARIABLE(1); // W, 2d tensor of weights [3K x inSize] -// auto b = INPUT_VARIABLE(2); // B, row of biases with twice length [1 x 2*inSize] -// auto c0 = INPUT_VARIABLE(3); // C_{0}, 2d tensor of initial state [bS x inSize] at time t=0 -// NDArray* mask = nullptr; // optional, 2d tensor of dropout mask [bS x inSize] +// auto x = INPUT_VARIABLE(0); // X, input 3d tensor [bS x +// inSize x time], time - number of time steps, bS - batch size, inSize - +// number of features auto w = INPUT_VARIABLE(1); // W, 2d +// tensor of weights [3K x inSize] auto b = INPUT_VARIABLE(2); // B, row +// of biases with twice length [1 x 2*inSize] auto c0 = +// INPUT_VARIABLE(3); // C_{0}, 2d tensor of initial state +// [bS x inSize] at time t=0 NDArray* mask = nullptr; // optional, 2d +// tensor of dropout mask [bS x inSize] // bool applyMask = false; // if (block.width() > 4) { @@ -688,35 +1001,44 @@ DECLARE_SHAPE_FN(sru_bi_bp) { // } // auto h = OUTPUT_VARIABLE(0); // h_t, [bS x inSize x time] -// auto state = OUTPUT_VARIABLE(1); // c_t, [bS x inSize x time] +// auto state = OUTPUT_VARIABLE(1); // c_t, [bS x inSize x +// time] -// const int bS = x->shapeOf()[0]; // bS - batch size -// const int inSize = x->shapeOf()[1]; // inSize - number of features -// const int time = x->shapeOf()[2]; // time - number of time steps +// const int bS = x->shapeOf()[0]; // bS - batch +// size const int inSize = x->shapeOf()[1]; // +// inSize - number of features const int time = x->shapeOf()[2]; // +// time - number of time steps // // multiplication matrix = matmul(w,x) -// auto wi = MmulHelper::mmul(w, x, nullptr, 1., 0.); // U [bS x 3K x time] -// auto wiZ = (*wi)({0,0, 0,inSize, 0,0}, true); // [bS x inSize x time] -// auto wiF = (*wi)({0,0, inSize,2*inSize, 0,0}, true); // forget gate [bS x inSize x time] -// auto wiR = (*wi)({0,0, 2*inSize,3*inSize, 0,0}, true); // reset gate [bS x inSize x time] -// auto bF = (*b) ({0,0, 0,inSize }, true); // biases for forget gate [1 x inSize] -// auto bR = (*b) ({0,0, inSize,2*inSize}, true); // biases for reset gate [1 x inSize] - -// NDArray* xt(nullptr), *zt(nullptr), *ft(nullptr), *rt(nullptr), *ct(nullptr), *ht(nullptr); -// auto ct_1 = c0->dup(c0->ordering()); -// auto gct = NDArrayFactory::create_(state->ordering(), {bS, inSize}, state->dataType(), state->getContext()); -// auto xmt = x->dup(x->ordering()); +// auto wi = MmulHelper::mmul(w, x, nullptr, 1., 0.); // U [bS x +// 3K x time] auto wiZ = (*wi)({0,0, 0,inSize, 0,0}, true); // [bS +// x inSize x time] auto wiF = (*wi)({0,0, inSize,2*inSize, 0,0}, true); +// // forget gate [bS x inSize x time] auto wiR = (*wi)({0,0, +// 2*inSize,3*inSize, 0,0}, true); // reset gate [bS x inSize x time] +// auto bF = (*b) ({0,0, 0,inSize }, true); // biases +// for forget gate [1 x inSize] auto bR = (*b) ({0,0, inSize,2*inSize}, +// true); // biases for reset gate [1 x inSize] + +// NDArray* xt(nullptr), *zt(nullptr), *ft(nullptr), *rt(nullptr), +// *ct(nullptr), *ht(nullptr); auto ct_1 = c0->dup(c0->ordering()); auto gct +// = NDArrayFactory::create_(state->ordering(), {bS, inSize}, +// state->dataType(), state->getContext()); auto xmt = +// x->dup(x->ordering()); // // x = x * mask // if(applyMask) -// xmt->applyBroadcast(broadcast::Multiply, {0, 1}, mask, xmt, nullptr); // apply mask +// xmt->applyBroadcast(broadcast::Multiply, {0, 1}, mask, xmt, nullptr); +// // apply mask // for (int t = 0; t < time; ++t) { -// xt = timestep(xmt, t, t+1); // [bS x inSize x time] -> [bS x inSize x 1] -> [bS x inSize] -// zt = timestep(&wiZ, t, t+1); // [bS x inSize x time] -> [bS x inSize x 1] -> [bS x inSize] -// ft = timestep(&wiF, t, t+1); // [bS x inSize x time] -> [bS x inSize x 1] -> [bS x inSize] -// rt = timestep(&wiR, t, t+1); // [bS x inSize x time] -> [bS x inSize x 1] -> [bS x inSize] -// ct = timestep(state, t, t+1); // [bS x inSize x time] -> [bS x inSize x 1] -> [bS x inSize] -// ht = timestep(h, t, t+1); // [bS x inSize x time] -> [bS x inSize x 1] -> [bS x inSize] +// xt = timestep(xmt, t, t+1); // [bS x inSize x time] -> [bS x +// inSize x 1] -> [bS x inSize] zt = timestep(&wiZ, t, t+1); // +// [bS x inSize x time] -> [bS x inSize x 1] -> [bS x inSize] ft = +// timestep(&wiF, t, t+1); // [bS x inSize x time] -> [bS x +// inSize x 1] -> [bS x inSize] rt = timestep(&wiR, t, t+1); // +// [bS x inSize x time] -> [bS x inSize x 1] -> [bS x inSize] ct = +// timestep(state, t, t+1); // [bS x inSize x time] -> [bS x +// inSize x 1] -> [bS x inSize] ht = timestep(h, t, t+1); // +// [bS x inSize x time] -> [bS x inSize x 1] -> [bS x inSize] // // ft = sigmoid(ft + bf), rt = sigmoid(rt + bR) // ft->addRowVector(&bF, ft); @@ -728,7 +1050,8 @@ DECLARE_SHAPE_FN(sru_bi_bp) { // ft->applyTransform(transform::OneMinus, ft); // ft->applyPairwiseTransform(pairwise::Multiply, *zt, nullptr); // ct->applyPairwiseTransform(pairwise::Add, *ft, nullptr); -// // TODO T val = (activation_type == 1) ? tanh(cur) : ((activation_type == 2) ? reluf(cur) : cur ); +// // TODO T val = (activation_type == 1) ? tanh(cur) : +// ((activation_type == 2) ? reluf(cur) : cur ); // ct->applyTransform(transform::Tanh, gct); // // ht = rt * gct + (1 - rt) * xt @@ -762,7 +1085,7 @@ DECLARE_SHAPE_FN(sru_bi_bp) { // char order = (char)(inShape[size-1]); // Nd4jLong *newShapeInfo1 = nullptr; -// ALLOCATE(newShapeInfo1, block.getWorkspace(), size, Nd4jLong); +// ALLOCATE(newShapeInfo1, block.workspace(), size, Nd4jLong); // newShapeInfo1[0] = rank; // newShapeInfo1[1] = bS; @@ -771,8 +1094,9 @@ DECLARE_SHAPE_FN(sru_bi_bp) { // ShapeUtils::updateStridesAndType(newShapeInfo1, inShape, order); -// auto result = ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(newShapeInfo1)); -// RELEASE(newShapeInfo1, block.getWorkspace()); +// auto result = +// ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(newShapeInfo1)); +// RELEASE(newShapeInfo1, block.workspace()); // return SHAPELIST(result, result); // } @@ -785,90 +1109,128 @@ DECLARE_SHAPE_FN(sru_bi_bp) { ////////////////////////////////////////////////////////////////////////// // CUSTOM_OP_IMPL(sru_bp_logic, 8, 4, true, 0, 0) { -// auto x = INPUT_VARIABLE(0); // X, input 3d tensor [bS x inSize x time], time - number of time steps, bS - batch size, inSize - number of features -// auto w = INPUT_VARIABLE(1); // W, 2d tensor of weights [3*inSize x inSize] -// auto b = INPUT_VARIABLE(2); // B, row of biases with twice length [1 x 2*inSize] -// auto c0 = INPUT_VARIABLE(3); // C_{0}, 2d tensor of initial state [bS x inSize] at time t=0 -// auto c = INPUT_VARIABLE(4); // C, [bS x inSize x time] -// auto inGradCt = INPUT_VARIABLE(5); // [bS x inSize] -// auto inGradH = INPUT_VARIABLE(6); // [bS x inSize x time] -// auto mask = block.width() > 7 ? INPUT_VARIABLE(7) : nullptr; // optional, 2d tensor of dropout mask [bS x inSize] +// auto x = INPUT_VARIABLE(0); // +// X, input 3d tensor [bS x inSize x time], time - number of time steps, bS +// - batch size, inSize - number of features auto w = +// INPUT_VARIABLE(1); // W, 2d tensor of +// weights [3*inSize x inSize] auto b = INPUT_VARIABLE(2); // B, row +// of biases with twice length [1 x 2*inSize] auto c0 = +// INPUT_VARIABLE(3); // C_{0}, 2d tensor +// of initial state [bS x inSize] at time t=0 auto c = +// INPUT_VARIABLE(4); // C, [bS x inSize x +// time] auto inGradCt = INPUT_VARIABLE(5); // [bS x inSize] auto inGradH = +// INPUT_VARIABLE(6); // [bS x inSize x +// time] auto mask = block.width() > 7 ? INPUT_VARIABLE(7) : nullptr; // +// optional, 2d tensor of dropout mask [bS x inSize] // auto gradX = OUTPUT_VARIABLE(0); // [bS x inSize x time] -// auto gradW = OUTPUT_VARIABLE(1); // [bS x 3*inSize x inSize] -// auto gradB = OUTPUT_VARIABLE(2); // [2*inSize] +// auto gradW = OUTPUT_VARIABLE(1); // [bS x 3*inSize x +// inSize] auto gradB = OUTPUT_VARIABLE(2); // [2*inSize] // auto gradInit = OUTPUT_VARIABLE(3); // [bS x inSize] // // input shapes validation // const int rank = 3; -// REQUIRE_TRUE(x->rankOf() == rank, 0, "SRU_BP operation: wrong rank of input array, expected is %i, but got %i instead !", rank, x->rankOf()); -// REQUIRE_TRUE(w->rankOf() == rank-1, 0, "SRU_BP operation: wrong rank of weights array, expected is %i, but got %i instead !", rank-1, w->rankOf()); -// REQUIRE_TRUE(b->rankOf() <= 2, 0, "SRU_BP operation: wrong rank of biases array, expected is <=2, but got %i instead !", b->rankOf()); -// REQUIRE_TRUE(c0->rankOf() == rank-1, 0, "SRU_BP operation: wrong rank of initial state array, expected is %i, but got %i instead !", rank-1, c0->rankOf()); -// REQUIRE_TRUE(c->rankOf() == rank, 0, "SRU_BP operation: wrong rank of cell states array, expected is %i, but got %i instead !", rank, c->rankOf()); -// REQUIRE_TRUE(inGradCt->rankOf() == rank-1, 0, "SRU_BP operation: wrong rank of array of cell state gradient, expected is %i, but got %i instead !", rank-1, inGradCt->rankOf()); -// REQUIRE_TRUE(inGradH->rankOf() == rank, 0, "SRU_BP operation: wrong rank of array of cell outputs gradients, expected is %i, but got %i instead !", rank, inGradH->rankOf()); -// if(mask) -// REQUIRE_TRUE(mask->rankOf() == rank-1, 0, "SRU_BP operation: wrong rank of mask array, expected is %i, but got %i instead !", rank-1, mask->rankOf()); +// REQUIRE_TRUE(x->rankOf() == rank, 0, "SRU_BP operation: wrong rank of +// input array, expected is %i, but got %i instead !", rank, x->rankOf()); +// REQUIRE_TRUE(w->rankOf() == rank-1, 0, "SRU_BP operation: wrong rank of +// weights array, expected is %i, but got %i instead !", rank-1, +// w->rankOf()); REQUIRE_TRUE(b->rankOf() <= 2, 0, "SRU_BP operation: +// wrong rank of biases array, expected is <=2, but got %i instead !", +// b->rankOf()); REQUIRE_TRUE(c0->rankOf() == rank-1, 0, "SRU_BP operation: +// wrong rank of initial state array, expected is %i, but got %i instead !", +// rank-1, c0->rankOf()); REQUIRE_TRUE(c->rankOf() == rank, 0, "SRU_BP +// operation: wrong rank of cell states array, expected is %i, but got %i +// instead !", rank, c->rankOf()); REQUIRE_TRUE(inGradCt->rankOf() == +// rank-1, 0, "SRU_BP operation: wrong rank of array of cell state gradient, +// expected is %i, but got %i instead !", rank-1, inGradCt->rankOf()); +// REQUIRE_TRUE(inGradH->rankOf() == rank, 0, "SRU_BP operation: wrong +// rank of array of cell outputs gradients, expected is %i, but got %i +// instead !", rank, inGradH->rankOf()); if(mask) +// REQUIRE_TRUE(mask->rankOf() == rank-1, 0, "SRU_BP operation: wrong +// rank of mask array, expected is %i, but got %i instead !", rank-1, +// mask->rankOf()); // const int bS = x->shapeOf()[0]; // const int inSize = x->shapeOf()[1]; -// const int time = x->shapeOf()[2]; // time - number of time steps +// const int time = x->shapeOf()[2]; // time - number +// of time steps // const std::string wShape = ShapeUtils::shapeAsString(w); -// const std::string wCorrectShape = ShapeUtils::shapeAsString({3*inSize, inSize}); +// const std::string wCorrectShape = +// ShapeUtils::shapeAsString({3*inSize, inSize}); // // const std::string bShape = ShapeUtils::shapeAsString(b); -// // const std::string bCorrectShape = ShapeUtils::shapeAsString({2*inSize}); -// const std::string c0Shape = ShapeUtils::shapeAsString(c0); -// const std::string c0CorrectShape = ShapeUtils::shapeAsString({bS, inSize}); -// const std::string cShape = ShapeUtils::shapeAsString(c); -// const std::string cCorrectShape = ShapeUtils::shapeAsString({bS, inSize, time}); -// const std::string inGradCtShape = ShapeUtils::shapeAsString(inGradCt); -// const std::string inGradCtCorrectShape = ShapeUtils::shapeAsString({bS, inSize}); -// const std::string inGradHShape = ShapeUtils::shapeAsString(inGradH); -// const std::string inGradHCorrectShape = ShapeUtils::shapeAsString({bS, inSize, time}); - -// REQUIRE_TRUE(wShape == wCorrectShape, 0, "SRU_BP operation: wrong shape of weights array, expected is %s, but got %s instead !", wCorrectShape.c_str(), wShape.c_str()); -// // REQUIRE_TRUE(bShape == bCorrectShape, 0, "SRU_BP operation: wrong shape of biases array, expected is %s, but got %s instead !", bCorrectShape.c_str(), bShape.c_str()); -// REQUIRE_TRUE(c0Shape == c0CorrectShape, 0, "SRU_BP operation: wrong shape of initial state array, expected is %s, but got %s instead !", c0CorrectShape.c_str(), c0Shape.c_str()); -// REQUIRE_TRUE(cShape == cCorrectShape, 0, "SRU_BP operation: wrong shape of cell states array, expected is %s, but got %s instead !", cCorrectShape.c_str(), cShape.c_str()); -// REQUIRE_TRUE(inGradCtShape == inGradCtCorrectShape, 0, "SRU_BP operation: wrong shape of array of cell state gradient, expected is %s, but got %s instead !", inGradCtCorrectShape.c_str(), inGradCtShape.c_str()); -// REQUIRE_TRUE(inGradHShape == inGradHCorrectShape, 0, "SRU_BP operation: wrong shape of array of cell outputs gradients, expected is %s, but got %s instead !", inGradHCorrectShape.c_str(), inGradHShape.c_str()); +// // const std::string bCorrectShape = +// ShapeUtils::shapeAsString({2*inSize}); const std::string c0Shape = +// ShapeUtils::shapeAsString(c0); const std::string c0CorrectShape = +// ShapeUtils::shapeAsString({bS, inSize}); const std::string cShape = +// ShapeUtils::shapeAsString(c); const std::string cCorrectShape = +// ShapeUtils::shapeAsString({bS, inSize, time}); const std::string +// inGradCtShape = ShapeUtils::shapeAsString(inGradCt); const +// std::string inGradCtCorrectShape = ShapeUtils::shapeAsString({bS, +// inSize}); const std::string inGradHShape = +// ShapeUtils::shapeAsString(inGradH); const std::string inGradHCorrectShape +// = ShapeUtils::shapeAsString({bS, inSize, time}); + +// REQUIRE_TRUE(wShape == wCorrectShape, 0, "SRU_BP operation: wrong shape +// of weights array, expected is %s, but got %s instead !", +// wCorrectShape.c_str(), wShape.c_str()); +// // REQUIRE_TRUE(bShape == bCorrectShape, 0, "SRU_BP operation: wrong +// shape of biases array, expected is %s, but got %s instead !", +// bCorrectShape.c_str(), bShape.c_str()); REQUIRE_TRUE(c0Shape == +// c0CorrectShape, 0, "SRU_BP operation: wrong shape of initial state array, +// expected is %s, but got %s instead !", c0CorrectShape.c_str(), +// c0Shape.c_str()); REQUIRE_TRUE(cShape == cCorrectShape, 0, "SRU_BP +// operation: wrong shape of cell states array, expected is %s, but got %s +// instead !", cCorrectShape.c_str(), cShape.c_str()); +// REQUIRE_TRUE(inGradCtShape == inGradCtCorrectShape, 0, "SRU_BP operation: +// wrong shape of array of cell state gradient, expected is %s, but got %s +// instead !", inGradCtCorrectShape.c_str(), inGradCtShape.c_str()); +// REQUIRE_TRUE(inGradHShape == inGradHCorrectShape, 0, "SRU_BP operation: +// wrong shape of array of cell outputs gradients, expected is %s, but got +// %s instead !", inGradHCorrectShape.c_str(), inGradHShape.c_str()); // if(mask) { // const std::string maskShape = ShapeUtils::shapeAsString(mask); -// REQUIRE_TRUE(maskShape == c0CorrectShape, 0, "SRU_BP operation: wrong shape of mask array, expected is %s, but got %s instead !", c0CorrectShape.c_str(), maskShape.c_str()); +// REQUIRE_TRUE(maskShape == c0CorrectShape, 0, "SRU_BP operation: wrong +// shape of mask array, expected is %s, but got %s instead !", +// c0CorrectShape.c_str(), maskShape.c_str()); // } - -// const auto bF = (*b)({0,0, 0, inSize}); // biases for forget gate [1 x inSize] -// const auto bR = (*b)({0,0, inSize,2*inSize}); // biases for reset gate [1 x inSize] -// NDArray gradBias(x->ordering(), {bS, 2*inSize, time}, x->dataType(), block.launchContext()); -// NDArray gradU (x->ordering(), {bS, 3*inSize, time}, x->dataType(), block.launchContext()); -// NDArray gradHX (x->ordering(), {bS, inSize, time}, x->dataType(), block.launchContext()); -// NDArray gct (c->ordering(), {bS, inSize}, x->dataType(), block.launchContext()); +// const auto bF = (*b)({0,0, 0, inSize}); // biases for forget gate +// [1 x inSize] const auto bR = (*b)({0,0, inSize,2*inSize}); // biases for +// reset gate [1 x inSize] NDArray gradBias(x->ordering(), {bS, 2*inSize, +// time}, x->dataType(), block.launchContext()); NDArray gradU +// (x->ordering(), {bS, 3*inSize, time}, x->dataType(), +// block.launchContext()); NDArray gradHX (x->ordering(), {bS, inSize, +// time}, x->dataType(), block.launchContext()); NDArray gct (c->ordering(), +// {bS, inSize}, x->dataType(), block.launchContext()); // // x = x * mask // if(mask) -// x->applyBroadcast(broadcast::Multiply, {0, 1}, mask, x, nullptr); // apply mask +// x->applyBroadcast(broadcast::Multiply, {0, 1}, mask, x, nullptr); // +// apply mask // // multiplication matrix wi = matmul(w,x), U = WX -// const auto wi = mmul(*w, *x); // U [bS x 3K x time] +// const auto wi = mmul(*w, *x); // U [bS x 3K x time] // for (int t = time-1; t >=0 ; --t) { // // initialization -// auto xt = (*x)({0,0, 0,0, t,t+1}); // [bS x inSize x time] -> [bS x inSize] -// auto zt = wi({0,0, 0, inSize, t,t+1}); // [bS x 3K x time] -> [bS x inSize] -// auto ft = wi({0,0, inSize, 2*inSize, t,t+1}); // [bS x 3K x time] -> [bS x inSize] -// auto rt = wi({0,0, 2*inSize,3*inSize, t,t+1}); // [bS x 3K x time] -> [bS x inSize] -// auto ct = (*c)({0,0, 0,0, t,t+1}); // [bS x inSize x time] -> [bS x inSize] -// auto inGradHt = (*inGradH)({ 0,0, 0,0, t,t+1}); // [bS x inSize x time] -> [bS x inSize] - -// auto ct_1 = t ? (*c)({ 0,0, 0,0, t-1,t}) : *c0; // previous c_{t-1} +// auto xt = (*x)({0,0, 0,0, t,t+1}); // +// [bS x inSize x time] -> [bS x inSize] auto zt = wi({0,0, 0, inSize, +// t,t+1}); // [bS x 3K x time] -> [bS x inSize] auto ft = wi({0,0, +// inSize, 2*inSize, t,t+1}); // [bS x 3K x time] -> [bS x inSize] +// auto rt = wi({0,0, 2*inSize,3*inSize, t,t+1}); // +// [bS x 3K x time] -> [bS x inSize] auto ct = (*c)({0,0, 0,0, +// t,t+1}); // [bS x inSize x time] -> [bS x inSize] auto inGradHt = +// (*inGradH)({ 0,0, 0,0, t,t+1}); // [bS x inSize x +// time] -> [bS x inSize] + +// auto ct_1 = t ? (*c)({ 0,0, 0,0, t-1,t}) : *c0; // previous c_{t-1} // ///////////////// forward // // ft = sigmoid(ft + bf), rt = sigmoid(rt + bR) // ft = sigmoid_(ft + bF); // rt = sigmoid_(rt + bR); -// // TODO T val = (activation_type == 1) ? tanh(cur) : ((activation_type == 2) ? reluf(cur) : cur ); +// // TODO T val = (activation_type == 1) ? tanh(cur) : +// ((activation_type == 2) ? reluf(cur) : cur ); // ct.applyTransform(transform::Tanh, &gct); // ///////////////// backward @@ -884,8 +1246,9 @@ DECLARE_SHAPE_FN(sru_bi_bp) { // // x_t (highway connection), gradHXt = inGradHt * (1.0f - rt); // NDArray gradHXt = inGradHt * rtMinus; -// // U_t, gradUZt = (inGradHt * rt * grad_tanh + inGradCt) * (1.0f - ft); -// NDArray gradUZt = (inGradHt * rt * gradTanh + *inGradCt) * ftMinus; +// // U_t, gradUZt = (inGradHt * rt * grad_tanh + inGradCt) * (1.0f - +// ft); NDArray gradUZt = (inGradHt * rt * gradTanh + *inGradCt) * +// ftMinus; // // c_{t-1}, inGradCt = (gradCt + inGradCt) * ft; // *inGradCt = (gradCt + *inGradCt) * ft; @@ -902,17 +1265,18 @@ DECLARE_SHAPE_FN(sru_bi_bp) { // // gradInit // gradInit->assign(inGradCt); // // gradX -// w->transposei(); // [inSize x 3K] -// gradX->assign( mmul(*w, gradU) + gradHX); -// if(mask) -// gradX->applyBroadcast(broadcast::Multiply, {0,1}, mask, gradX, nullptr); // apply mask +// w->transposei(); // [inSize x 3K] gradX->assign( mmul(*w, gradU) + +// gradHX); if(mask) +// gradX->applyBroadcast(broadcast::Multiply, {0,1}, mask, gradX, +// nullptr); // apply mask // // gradB -// gradBias.reduceAlongDimension(reduce::Sum, *gradB, {0,2}, false, true); // [1 x 2K] +// gradBias.reduceAlongDimension(reduce::Sum, *gradB, {0,2}, false, true); +// // [1 x 2K] // // gradW [bS x 3K x inSize] -// x->permutei({0, 2, 1}); // [bS x time x inSize] -// gradW->assign( mmul(gradU, *x) ); +// x->permutei({0, 2, 1}); // +// [bS x time x inSize] gradW->assign( mmul(gradU, *x) ); // return Status::OK(); // } @@ -930,10 +1294,16 @@ DECLARE_SHAPE_FN(sru_bi_bp) { // auto time = inShape[3]; // char order = shape::order(inShape); -// ShapeDescriptor descriptor1(ArrayOptions::dataType(inShape), order, {bS, inSize, time}); -// ShapeDescriptor descriptor2(ArrayOptions::dataType(inShape), order, {bS, 3 * inSize, inSize}); -// ShapeDescriptor descriptor3(ArrayOptions::dataType(inShape), order, {1, 2 * inSize}); -// ShapeDescriptor descriptor4(ArrayOptions::dataType(inShape), order, {bS, inSize}); - -// return SHAPELIST(ConstantShapeHelper::getInstance().createShapeInfo(descriptor1), ConstantShapeHelper::getInstance().createShapeInfo(descriptor2), ConstantShapeHelper::getInstance().createShapeInfo(descriptor3), ConstantShapeHelper::getInstance().createShapeInfo(descriptor4)); +// ShapeDescriptor descriptor1(ArrayOptions::dataType(inShape), order, {bS, +// inSize, time}); ShapeDescriptor +// descriptor2(ArrayOptions::dataType(inShape), order, {bS, 3 * inSize, +// inSize}); ShapeDescriptor descriptor3(ArrayOptions::dataType(inShape), +// order, {1, 2 * inSize}); ShapeDescriptor +// descriptor4(ArrayOptions::dataType(inShape), order, {bS, inSize}); + +// return +// SHAPELIST(ConstantShapeHelper::getInstance().createShapeInfo(descriptor1), +// ConstantShapeHelper::getInstance().createShapeInfo(descriptor2), +// ConstantShapeHelper::getInstance().createShapeInfo(descriptor3), +// ConstantShapeHelper::getInstance().createShapeInfo(descriptor4)); // } diff --git a/libnd4j/include/ops/declarable/generic/nn/recurrent/sruCell.cpp b/libnd4j/include/ops/declarable/generic/nn/recurrent/sruCell.cpp index 3268da4539aa..3d04b192d79b 100644 --- a/libnd4j/include/ops/declarable/generic/nn/recurrent/sruCell.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/recurrent/sruCell.cpp @@ -22,88 +22,119 @@ #if NOT_EXCLUDED(OP_sruCell) #include -#include - +#include namespace sd { -namespace ops { - +namespace ops { ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(sruCell, 4, 2, false, 0, 0) { - auto xt = INPUT_VARIABLE(0); // input [bS x inSize], bS - batch size, inSize - number of features - auto ct_1 = INPUT_VARIABLE(1); // previous cell state ct [bS x inSize], that is at previous time step t-1 - auto w = INPUT_VARIABLE(2); // weights [inSize x 3*inSize] - auto b = INPUT_VARIABLE(3); // biases [2*inSize] - - auto ht = OUTPUT_VARIABLE(0); // current cell output [bS x inSize], that is at current time step t - auto ct = OUTPUT_VARIABLE(1); // current cell state [bS x inSize], that is at current time step t - - const int rank = xt->rankOf(); - const int bS = xt->sizeAt(0); - const int inSize = xt->sizeAt(1); // inSize - number of features - - // input shapes validation - const std::vector correctCt_1Shape = {bS, inSize}; - const std::vector correctWShape = {inSize, 3*inSize}; - const std::vector correctBShape = {2*inSize}; - - REQUIRE_TRUE(ct_1->isSameShape(correctCt_1Shape), 0, "SRUCELL operation: wrong shape of previous cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctCt_1Shape).c_str(), ShapeUtils::shapeAsString(ct_1).c_str()); - REQUIRE_TRUE(w->isSameShape(correctWShape), 0, "SRUCELL operation: wrong shape of weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctWShape).c_str(), ShapeUtils::shapeAsString(w).c_str()); - REQUIRE_TRUE(b->isSameShape(correctBShape), 0, "SRUCELL operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctBShape).c_str(), ShapeUtils::shapeAsString(b).c_str()); - - - // fixme: shitty initializer lists - helpers::sruCell(block.launchContext(), xt, ct_1, w, b, ht, ct); - - return Status::OK(); + auto xt = INPUT_VARIABLE( + 0); // input [bS x inSize], bS - batch size, inSize - number of features + auto ct_1 = INPUT_VARIABLE(1); // previous cell state ct [bS x inSize], that + // is at previous time step t-1 + auto w = INPUT_VARIABLE(2); // weights [inSize x 3*inSize] + auto b = INPUT_VARIABLE(3); // biases [2*inSize] + + auto ht = OUTPUT_VARIABLE( + 0); // current cell output [bS x inSize], that is at current time step t + auto ct = OUTPUT_VARIABLE( + 1); // current cell state [bS x inSize], that is at current time step t + + const int rank = xt->rankOf(); + const int bS = xt->sizeAt(0); + const int inSize = xt->sizeAt(1); // inSize - number of features + + // input shapes validation + const std::vector correctCt_1Shape = {bS, inSize}; + const std::vector correctWShape = {inSize, 3 * inSize}; + const std::vector correctBShape = {2 * inSize}; + + REQUIRE_TRUE(ct_1->isSameShape(correctCt_1Shape), 0, + "SRUCELL operation: wrong shape of previous cell state, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(correctCt_1Shape).c_str(), + ShapeUtils::shapeAsString(ct_1).c_str()); + REQUIRE_TRUE(w->isSameShape(correctWShape), 0, + "SRUCELL operation: wrong shape of weights, expected is %s, but " + "got %s instead !", + ShapeUtils::shapeAsString(correctWShape).c_str(), + ShapeUtils::shapeAsString(w).c_str()); + REQUIRE_TRUE(b->isSameShape(correctBShape), 0, + "SRUCELL operation: wrong shape of biases, expected is %s, but " + "got %s instead !", + ShapeUtils::shapeAsString(correctBShape).c_str(), + ShapeUtils::shapeAsString(b).c_str()); + + // fixme: shitty initializer lists + helpers::sruCell(block.launchContext(), xt, ct_1, w, b, ht, ct); + + return Status::OK(); } - DECLARE_TYPES(sruCell) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } +DECLARE_TYPES(sruCell) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} DECLARE_SHAPE_FN(sruCell) { - - auto xtShapeInfo = inputShape->at(0); // input [bS x inSize], bS - batch size, inSize - number of features - auto ct_1ShapeInfo = inputShape->at(1); // previous cell state ct [bS x inSize], that is at previous time step t-1 - auto wShapeInfo = inputShape->at(2); // weights [inSize x 3*inSize] - auto bShapeInfo = inputShape->at(3); // biases [2*inSize] - - const int rank = xtShapeInfo[0]; - const int bS = xtShapeInfo[1]; - const int inSize = xtShapeInfo[2]; // inSize - number of features - - // input shapes validation - const std::vector correctCt_1Shape = {bS, inSize}; - const std::vector correctWShape = {inSize, 3*inSize}; - const std::vector correctBShape = {2*inSize}; - - REQUIRE_TRUE(ShapeUtils::areShapesEqual(ct_1ShapeInfo, correctCt_1Shape) , 0, "SRUCELL operation: wrong shape of previous cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctCt_1Shape).c_str(), ShapeUtils::shapeAsString(ct_1ShapeInfo).c_str()); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(wShapeInfo ,correctWShape), 0, "SRUCELL operation: wrong shape of weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctWShape).c_str(), ShapeUtils::shapeAsString(wShapeInfo).c_str()); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(bShapeInfo ,correctBShape), 0, "SRUCELL operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctBShape).c_str(), ShapeUtils::shapeAsString(bShapeInfo).c_str()); - - // evaluate output shapeInfos - Nd4jLong *hShapeInfo(nullptr), *cShapeInfo(nullptr); - ALLOCATE(hShapeInfo, block.getWorkspace(), shape::shapeInfoLength(rank), Nd4jLong); // [bS x numProj] - ALLOCATE(cShapeInfo, block.getWorkspace(), shape::shapeInfoLength(rank), Nd4jLong); // [bS x numUnits] - - hShapeInfo[0] = cShapeInfo[0] = rank; - hShapeInfo[1] = cShapeInfo[1] = bS; - hShapeInfo[2] = cShapeInfo[2] = inSize; - - ShapeUtils::updateStridesAndType(hShapeInfo, ct_1ShapeInfo, shape::order(ct_1ShapeInfo)); - ShapeUtils::updateStridesAndType(cShapeInfo, ct_1ShapeInfo, shape::order(ct_1ShapeInfo)); - - return SHAPELIST(ConstantShapeHelper::getInstance().createFromExisting(hShapeInfo, block.workspace()), ConstantShapeHelper::getInstance().createFromExisting(cShapeInfo, block.workspace())); + auto xtShapeInfo = inputShape->at( + 0); // input [bS x inSize], bS - batch size, inSize - number of features + auto ct_1ShapeInfo = + inputShape->at(1); // previous cell state ct [bS x inSize], that is at + // previous time step t-1 + auto wShapeInfo = inputShape->at(2); // weights [inSize x 3*inSize] + auto bShapeInfo = inputShape->at(3); // biases [2*inSize] + + const int rank = xtShapeInfo[0]; + const int bS = xtShapeInfo[1]; + const int inSize = xtShapeInfo[2]; // inSize - number of features + + // input shapes validation + const std::vector correctCt_1Shape = {bS, inSize}; + const std::vector correctWShape = {inSize, 3 * inSize}; + const std::vector correctBShape = {2 * inSize}; + + REQUIRE_TRUE(ShapeUtils::areShapesEqual(ct_1ShapeInfo, correctCt_1Shape), 0, + "SRUCELL operation: wrong shape of previous cell state, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(correctCt_1Shape).c_str(), + ShapeUtils::shapeAsString(ct_1ShapeInfo).c_str()); + REQUIRE_TRUE(ShapeUtils::areShapesEqual(wShapeInfo, correctWShape), 0, + "SRUCELL operation: wrong shape of weights, expected is %s, but " + "got %s instead !", + ShapeUtils::shapeAsString(correctWShape).c_str(), + ShapeUtils::shapeAsString(wShapeInfo).c_str()); + REQUIRE_TRUE(ShapeUtils::areShapesEqual(bShapeInfo, correctBShape), 0, + "SRUCELL operation: wrong shape of biases, expected is %s, but " + "got %s instead !", + ShapeUtils::shapeAsString(correctBShape).c_str(), + ShapeUtils::shapeAsString(bShapeInfo).c_str()); + + // evaluate output shapeInfos + Nd4jLong *hShapeInfo(nullptr), *cShapeInfo(nullptr); + ALLOCATE(hShapeInfo, block.workspace(), shape::shapeInfoLength(rank), + Nd4jLong); // [bS x numProj] + ALLOCATE(cShapeInfo, block.workspace(), shape::shapeInfoLength(rank), + Nd4jLong); // [bS x numUnits] + + hShapeInfo[0] = cShapeInfo[0] = rank; + hShapeInfo[1] = cShapeInfo[1] = bS; + hShapeInfo[2] = cShapeInfo[2] = inSize; + + ShapeUtils::updateStridesAndType(hShapeInfo, ct_1ShapeInfo, + shape::order(ct_1ShapeInfo)); + ShapeUtils::updateStridesAndType(cShapeInfo, ct_1ShapeInfo, + shape::order(ct_1ShapeInfo)); + + return SHAPELIST(ConstantShapeHelper::getInstance().createFromExisting( + hShapeInfo, block.workspace()), + ConstantShapeHelper::getInstance().createFromExisting( + cShapeInfo, block.workspace())); } - - - -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/nn/recurrent/staticBidirectionalRNN.cpp b/libnd4j/include/ops/declarable/generic/nn/recurrent/staticBidirectionalRNN.cpp index fbe604a31835..a37cb4c3a992 100644 --- a/libnd4j/include/ops/declarable/generic/nn/recurrent/staticBidirectionalRNN.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/recurrent/staticBidirectionalRNN.cpp @@ -19,211 +19,331 @@ // #include -#include -#include -#include +#include +#include +#include namespace sd { -namespace ops { - +namespace ops { ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(static_bidirectional_rnn, 7, 3, false, 0, 0) { - auto x = INPUT_VARIABLE(0); // input [time x bS x inSize] - auto WxFW = INPUT_VARIABLE(1); // input-to-hidden weights for forward RNN, [inSize x numUnitsFW] - auto WhFW = INPUT_VARIABLE(2); // hidden-to-hidden weights for forward RNN, [numUnitsFW x numUnitsFW] - auto bFW = INPUT_VARIABLE(3); // biases for forward RNN, [2*numUnitsFW] - auto WxBW = INPUT_VARIABLE(4); // input-to-hidden weights for backward RNN, [inSize x numUnitsBW] - auto WhBW = INPUT_VARIABLE(5); // hidden-to-hidden weights for backward RNN, [numUnitsBW x numUnitsBW] - auto bBW = INPUT_VARIABLE(6); // biases for backward RNN, [2*v] - - NDArray* h0FW = nullptr; // initial cell output for forward RNN (at time step = 0) [bS x numUnitsFW] - NDArray* h0BW = nullptr; // initial cell output for backward RNN (at time step = 0) [bS x numUnitsBW] - NDArray* maxTimeStep = nullptr; // vector [bS] containing integer values within [0,time), each element of this vector set max time step per each input in batch, this means there are no calculations for time >= maxTimeStep - - switch(block.width()) { - case 8: - maxTimeStep = INPUT_VARIABLE(7); - break; - case 9: - h0FW = INPUT_VARIABLE(7); - h0BW = INPUT_VARIABLE(8); - break; - case 10: - h0FW = INPUT_VARIABLE(7); - h0BW = INPUT_VARIABLE(8); - maxTimeStep = INPUT_VARIABLE(9); - break; - } - - auto h = OUTPUT_VARIABLE(0); // cell outputs [time x bS x (numUnitsFW + numUnitsBW)], that is per each time step - auto hFWFinal = OUTPUT_VARIABLE(1); // final cell out for forward RNN [bS x numUnitsFW] - auto hBWFinal = OUTPUT_VARIABLE(2); // final cell out for backward RNN [bS x numUnitsBF] - - REQUIRE_TRUE(x->rankOf() == 3, 0, "STATIC_BIDIRECTIONAL_RNN custom operation: input array must have rank = 3, but got %i instead !", x->rankOf()); - REQUIRE_TRUE(WxFW->rankOf() == 2, 0, "STATIC_BIDIRECTIONAL_RNN custom operation: input-to-hidden weights array (for forward RNN) must have rank = 2, but got %i instead !", WxFW->rankOf()); - REQUIRE_TRUE(WxBW->rankOf() == 2, 0, "STATIC_BIDIRECTIONAL_RNN custom operation: input-to-hidden weights array (for backward RNN) must have rank = 2, but got %i instead !", WxBW->rankOf()); - - const Nd4jLong inRank = x->rankOf(); - const Nd4jLong time = x->sizeAt(0); - const Nd4jLong bS = x->sizeAt(1); - const Nd4jLong numUnitsFW = WxFW->sizeAt(1); - const Nd4jLong numUnitsBW = WxBW->sizeAt(1); - - const std::vector expectedWhFWshape = {numUnitsFW, numUnitsFW}; - const std::vector expectedWhBWshape = {numUnitsBW, numUnitsBW}; - const std::vector expectedbFWshape = {2 * numUnitsFW}; - const std::vector expectedbBWshape = {2 * numUnitsBW}; - - REQUIRE_TRUE(WhFW->isSameShape(expectedWhFWshape), 0, "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of hidden-to-hidden weights array (for forward RNN), expected is %s but got %s instead !", ShapeUtils::shapeAsString(expectedWhFWshape).c_str(), ShapeUtils::shapeAsString(WhFW).c_str()); - REQUIRE_TRUE(WhBW->isSameShape(expectedWhBWshape), 0, "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of hidden-to-hidden weights array (for backward RNN), expected is %s but got %s instead !", ShapeUtils::shapeAsString(expectedWhBWshape).c_str(), ShapeUtils::shapeAsString(WhBW).c_str()); - REQUIRE_TRUE(bFW->isSameShape(expectedbFWshape), 0, "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of biases array (for forward RNN), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedbFWshape).c_str(), ShapeUtils::shapeAsString(bFW).c_str()); - REQUIRE_TRUE(bBW->isSameShape(expectedbBWshape), 0, "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of biases array (for backward RNN), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedbBWshape).c_str(), ShapeUtils::shapeAsString(bBW).c_str()); - if(h0FW) { - const std::vector expectedh0FWshape = {bS, numUnitsFW}; - REQUIRE_TRUE(h0FW->isSameShape(expectedh0FWshape), 0, "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of initial cell output array (for forward RNN), expected is %s but got %s instead !", ShapeUtils::shapeAsString(expectedh0FWshape).c_str(), ShapeUtils::shapeAsString(h0FW).c_str()); - } - if(h0BW) { - const std::vector expectedh0BWshape = {bS, numUnitsBW}; - REQUIRE_TRUE(h0BW->isSameShape(expectedh0BWshape), 0, "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of initial cell output array (for backward RNN), expected is %s but got %s instead !", ShapeUtils::shapeAsString(expectedh0BWshape).c_str(), ShapeUtils::shapeAsString(h0BW).c_str()); - } - if(maxTimeStep) - REQUIRE_TRUE(maxTimeStep->isSameShape({bS}), 0, "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of maxTimeStep array, expected is [%i], but got %s instead !", bS, ShapeUtils::shapeAsString(maxTimeStep).c_str()); - - // forward steps - auto hFW = new NDArray(x->ordering(), {time, bS, numUnitsFW}, x->dataType(), block.launchContext()); - helpers::rnnTimeLoop(block.launchContext(), x, WxFW, WhFW, bFW, h0FW, maxTimeStep, hFW, hFWFinal); - - auto seqLen = maxTimeStep; - if(seqLen == nullptr) { -// seqLen = new NDArray(x->ordering(), {x->sizeAt(1)}, x->dataType(), block.launchContext()); // [bS] - seqLen = new NDArray(x->ordering(), {x->sizeAt(1)}, sd::DataType::INT64, block.launchContext()); // [bS] - *seqLen = x->sizeAt(0); // set each element of seqLen to be equal to time - } - - // reverse x - auto revOut = new NDArray(x, false, block.launchContext()); - helpers::reverseSequence(block.launchContext(), x, seqLen, revOut, 0, 1); - - // backward steps - auto hBW = new NDArray(x->ordering(), {time, bS, numUnitsBW}, x->dataType(), block.launchContext()); - - helpers::rnnTimeLoop(block.launchContext(), revOut, WxBW, WhBW, bBW, h0BW, maxTimeStep, hBW, hBWFinal); - - // reverse hBW - auto hBWcopy = new NDArray(*hBW); - helpers::reverseSequence(block.launchContext(), hBWcopy, seqLen, hBW, 0, 1); - - // concatenate hFW and hBW along last third dimension - // NDArrayFactory::concat({hFW, hBW}, 2, h); - helpers::concat(block.launchContext(), {hFW, hBW}, *h, 2); - - delete hBW; - delete hFW; - delete hBWcopy; - delete revOut; - - if(seqLen != maxTimeStep) - delete seqLen; - - - return Status::OK(); + auto x = INPUT_VARIABLE(0); // input [time x bS x inSize] + auto WxFW = INPUT_VARIABLE( + 1); // input-to-hidden weights for forward RNN, [inSize x numUnitsFW] + auto WhFW = INPUT_VARIABLE(2); // hidden-to-hidden weights for forward RNN, + // [numUnitsFW x numUnitsFW] + auto bFW = INPUT_VARIABLE(3); // biases for forward RNN, [2*numUnitsFW] + auto WxBW = INPUT_VARIABLE( + 4); // input-to-hidden weights for backward RNN, [inSize x numUnitsBW] + auto WhBW = INPUT_VARIABLE(5); // hidden-to-hidden weights for backward RNN, + // [numUnitsBW x numUnitsBW] + auto bBW = INPUT_VARIABLE(6); // biases for backward RNN, [2*v] + + NDArray* h0FW = nullptr; // initial cell output for forward RNN (at time + // step = 0) [bS x numUnitsFW] + NDArray* h0BW = nullptr; // initial cell output for backward RNN (at time + // step = 0) [bS x numUnitsBW] + NDArray* maxTimeStep = + nullptr; // vector [bS] containing integer values within [0,time), each + // element of this vector set max time step per each input in + // batch, this means there are no calculations for time >= + // maxTimeStep + + switch (block.width()) { + case 8: + maxTimeStep = INPUT_VARIABLE(7); + break; + case 9: + h0FW = INPUT_VARIABLE(7); + h0BW = INPUT_VARIABLE(8); + break; + case 10: + h0FW = INPUT_VARIABLE(7); + h0BW = INPUT_VARIABLE(8); + maxTimeStep = INPUT_VARIABLE(9); + break; + } + + auto h = OUTPUT_VARIABLE(0); // cell outputs [time x bS x (numUnitsFW + + // numUnitsBW)], that is per each time step + auto hFWFinal = + OUTPUT_VARIABLE(1); // final cell out for forward RNN [bS x numUnitsFW] + auto hBWFinal = + OUTPUT_VARIABLE(2); // final cell out for backward RNN [bS x numUnitsBF] + + REQUIRE_TRUE(x->rankOf() == 3, 0, + "STATIC_BIDIRECTIONAL_RNN custom operation: input array must " + "have rank = 3, but got %i instead !", + x->rankOf()); + REQUIRE_TRUE( + WxFW->rankOf() == 2, 0, + "STATIC_BIDIRECTIONAL_RNN custom operation: input-to-hidden weights " + "array (for forward RNN) must have rank = 2, but got %i instead !", + WxFW->rankOf()); + REQUIRE_TRUE( + WxBW->rankOf() == 2, 0, + "STATIC_BIDIRECTIONAL_RNN custom operation: input-to-hidden weights " + "array (for backward RNN) must have rank = 2, but got %i instead !", + WxBW->rankOf()); + + const Nd4jLong inRank = x->rankOf(); + const Nd4jLong time = x->sizeAt(0); + const Nd4jLong bS = x->sizeAt(1); + const Nd4jLong numUnitsFW = WxFW->sizeAt(1); + const Nd4jLong numUnitsBW = WxBW->sizeAt(1); + + const std::vector expectedWhFWshape = {numUnitsFW, numUnitsFW}; + const std::vector expectedWhBWshape = {numUnitsBW, numUnitsBW}; + const std::vector expectedbFWshape = {2 * numUnitsFW}; + const std::vector expectedbBWshape = {2 * numUnitsBW}; + + REQUIRE_TRUE(WhFW->isSameShape(expectedWhFWshape), 0, + "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of " + "hidden-to-hidden weights array (for forward RNN), expected is " + "%s but got %s instead !", + ShapeUtils::shapeAsString(expectedWhFWshape).c_str(), + ShapeUtils::shapeAsString(WhFW).c_str()); + REQUIRE_TRUE(WhBW->isSameShape(expectedWhBWshape), 0, + "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of " + "hidden-to-hidden weights array (for backward RNN), expected is " + "%s but got %s instead !", + ShapeUtils::shapeAsString(expectedWhBWshape).c_str(), + ShapeUtils::shapeAsString(WhBW).c_str()); + REQUIRE_TRUE( + bFW->isSameShape(expectedbFWshape), 0, + "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of biases array " + "(for forward RNN), expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedbFWshape).c_str(), + ShapeUtils::shapeAsString(bFW).c_str()); + REQUIRE_TRUE( + bBW->isSameShape(expectedbBWshape), 0, + "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of biases array " + "(for backward RNN), expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedbBWshape).c_str(), + ShapeUtils::shapeAsString(bBW).c_str()); + if (h0FW) { + const std::vector expectedh0FWshape = {bS, numUnitsFW}; + REQUIRE_TRUE(h0FW->isSameShape(expectedh0FWshape), 0, + "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of " + "initial cell output array (for forward RNN), expected is %s " + "but got %s instead !", + ShapeUtils::shapeAsString(expectedh0FWshape).c_str(), + ShapeUtils::shapeAsString(h0FW).c_str()); + } + if (h0BW) { + const std::vector expectedh0BWshape = {bS, numUnitsBW}; + REQUIRE_TRUE(h0BW->isSameShape(expectedh0BWshape), 0, + "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of " + "initial cell output array (for backward RNN), expected is %s " + "but got %s instead !", + ShapeUtils::shapeAsString(expectedh0BWshape).c_str(), + ShapeUtils::shapeAsString(h0BW).c_str()); + } + if (maxTimeStep) + REQUIRE_TRUE(maxTimeStep->isSameShape({bS}), 0, + "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of " + "maxTimeStep array, expected is [%i], but got %s instead !", + bS, ShapeUtils::shapeAsString(maxTimeStep).c_str()); + + // forward steps + auto hFW = new NDArray(x->ordering(), {time, bS, numUnitsFW}, x->dataType(), + block.launchContext()); + helpers::rnnTimeLoop(block.launchContext(), x, WxFW, WhFW, bFW, h0FW, + maxTimeStep, hFW, hFWFinal); + + auto seqLen = maxTimeStep; + if (seqLen == nullptr) { + // seqLen = new NDArray(x->ordering(), {x->sizeAt(1)}, x->dataType(), + // block.launchContext()); // [bS] + seqLen = new NDArray(x->ordering(), {x->sizeAt(1)}, sd::DataType::INT64, + block.launchContext()); // [bS] + *seqLen = x->sizeAt(0); // set each element of seqLen to be equal to time + } + + // reverse x + auto revOut = new NDArray(x, false, block.launchContext()); + helpers::reverseSequence(block.launchContext(), x, seqLen, revOut, 0, 1); + + // backward steps + auto hBW = new NDArray(x->ordering(), {time, bS, numUnitsBW}, x->dataType(), + block.launchContext()); + + helpers::rnnTimeLoop(block.launchContext(), revOut, WxBW, WhBW, bBW, h0BW, + maxTimeStep, hBW, hBWFinal); + + // reverse hBW + auto hBWcopy = new NDArray(*hBW); + helpers::reverseSequence(block.launchContext(), hBWcopy, seqLen, hBW, 0, 1); + + // concatenate hFW and hBW along last third dimension + // NDArrayFactory::concat({hFW, hBW}, 2, h); + helpers::concat(block.launchContext(), {hFW, hBW}, *h, 2); + + delete hBW; + delete hFW; + delete hBWcopy; + delete revOut; + + if (seqLen != maxTimeStep) delete seqLen; + + return Status::OK(); } - DECLARE_TYPES(static_bidirectional_rnn) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } - - -DECLARE_SHAPE_FN(static_bidirectional_rnn) { - - auto xShapeInfo = inputShape->at(0); // input [time x bS x inSize] - auto WxFWShapeInfo = inputShape->at(1); // input-to-hidden weights for forward RNN, [inSize x numUnitsFW] - auto WhFWShapeInfo = inputShape->at(2); // hidden-to-hidden weights for forward RNN, [numUnitsFW x numUnitsFW] - auto bFWShapeInfo = inputShape->at(3); // biases for forward RNN, [2*numUnitsFW] - auto WxBWShapeInfo = inputShape->at(4); // input-to-hidden weights for backward RNN, [inSize x numUnitsBW] - auto WhBWShapeInfo = inputShape->at(5); // hidden-to-hidden weights for backward RNN, [numUnitsBW x numUnitsBW] - auto bBWShapeInfo = inputShape->at(6); // biases for backward RNN, [2*numUnitsBW] - - Nd4jLong const* h0FWShapeInfo = nullptr; // initial cell output for forward RNN (at time step = 0) [bS x numUnitsFW] - Nd4jLong const* h0BWShapeInfo = nullptr; // initial cell output for backward RNN (at time step = 0) [bS x numUnitsBW] - Nd4jLong const* maxTimeStepShapeInfo = nullptr; // vector [bS] containing integer values within [0,time), each element of this vector set max time step per each input in batch, this means there are no calculations for time >= maxTimeStep - - switch(block.width()) { - case 8: - maxTimeStepShapeInfo = inputShape->at(7); - break; - case 9: - h0FWShapeInfo = inputShape->at(7); - h0BWShapeInfo = inputShape->at(8); - break; - case 10: - h0FWShapeInfo = inputShape->at(7); - h0BWShapeInfo = inputShape->at(8); - maxTimeStepShapeInfo = inputShape->at(9); - break; - } - - REQUIRE_TRUE(xShapeInfo[0] == 3, 0, "STATIC_BIDIRECTIONAL_RNN custom operation: input array must have rank = 3, but got %i instead !", xShapeInfo[0]); - REQUIRE_TRUE(WxFWShapeInfo[0] == 2, 0, "STATIC_BIDIRECTIONAL_RNN custom operation: input-to-hidden weights array (for forward RNN) must have rank = 2, but got %i instead !", WxFWShapeInfo[0]); - REQUIRE_TRUE(WxBWShapeInfo[0] == 2, 0, "STATIC_BIDIRECTIONAL_RNN custom operation: input-to-hidden weights array (for backward RNN) must have rank = 2, but got %i instead !", WxBWShapeInfo[0]); - - const int inRank = xShapeInfo[0]; - const int time = xShapeInfo[1]; - const int bS = xShapeInfo[2]; - const int numUnitsFW = WxFWShapeInfo[2]; - const int numUnitsBW = WxBWShapeInfo[2]; - - const std::vector expectedWhFWshape = {numUnitsFW, numUnitsFW}; - const std::vector expectedWhBWshape = {numUnitsBW, numUnitsBW}; - const std::vector expectedbFWshape = {2 * numUnitsFW}; - const std::vector expectedbBWshape = {2 * numUnitsBW}; - - REQUIRE_TRUE(ShapeUtils::areShapesEqual(WhFWShapeInfo, expectedWhFWshape), 0, "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of hidden-to-hidden weights array (for forward RNN), expected is %s but got %s instead !", ShapeUtils::shapeAsString(expectedWhFWshape).c_str(), ShapeUtils::shapeAsString(WhFWShapeInfo).c_str()); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(WhBWShapeInfo, expectedWhBWshape), 0, "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of hidden-to-hidden weights array (for backward RNN), expected is %s but got %s instead !", ShapeUtils::shapeAsString(expectedWhBWshape).c_str(), ShapeUtils::shapeAsString(WhBWShapeInfo).c_str()); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(bFWShapeInfo, expectedbFWshape), 0, "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of biases array (for forward RNN), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedbFWshape).c_str(), ShapeUtils::shapeAsString(bFWShapeInfo).c_str()); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(bBWShapeInfo, expectedbBWshape), 0, "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of biases array (for backward RNN), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedbBWshape).c_str(), ShapeUtils::shapeAsString(bBWShapeInfo).c_str()); - if(h0FWShapeInfo) { - const std::vector expectedh0FWshape = {bS, numUnitsFW}; - REQUIRE_TRUE(ShapeUtils::areShapesEqual(h0FWShapeInfo, expectedh0FWshape), 0, "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of initial cell output array (for forward RNN), expected is %s but got %s instead !", ShapeUtils::shapeAsString(expectedh0FWshape).c_str(), ShapeUtils::shapeAsString(h0FWShapeInfo).c_str()); - } - if(h0BWShapeInfo) { - const std::vector expectedh0BWshape = {bS, numUnitsBW}; - REQUIRE_TRUE(ShapeUtils::areShapesEqual(h0BWShapeInfo, expectedh0BWshape), 0, "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of initial cell output array (for backward RNN), expected is %s but got %s instead !", ShapeUtils::shapeAsString(expectedh0BWshape).c_str(), ShapeUtils::shapeAsString(h0BWShapeInfo).c_str()); - } - if(maxTimeStepShapeInfo) - REQUIRE_TRUE(ShapeUtils::areShapesEqual(maxTimeStepShapeInfo, {bS}), 0, "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of maxTimeStep array, expected is [%i], but got %s instead !", bS, ShapeUtils::shapeAsString(maxTimeStepShapeInfo).c_str()); - - // evaluate output shapeInfos - Nd4jLong *hShapeInfo(nullptr), *hFWFinalPrevShapeInfo(nullptr), *hBWFinalPrevShapeInfo(nullptr); - ALLOCATE(hShapeInfo, block.getWorkspace(), shape::shapeInfoLength(inRank), Nd4jLong); - ALLOCATE(hFWFinalPrevShapeInfo, block.getWorkspace(), shape::shapeInfoLength(inRank-1), Nd4jLong); - ALLOCATE(hBWFinalPrevShapeInfo, block.getWorkspace(), shape::shapeInfoLength(inRank-1), Nd4jLong); - - hShapeInfo[0] = inRank; - hFWFinalPrevShapeInfo[0] = hBWFinalPrevShapeInfo[0] = inRank-1; - hShapeInfo[1] = time; - hShapeInfo[2] = hFWFinalPrevShapeInfo[1] = hBWFinalPrevShapeInfo[1] = bS; - hShapeInfo[3] = numUnitsFW + numUnitsBW; - hFWFinalPrevShapeInfo[2] = numUnitsFW; - hBWFinalPrevShapeInfo[2] = numUnitsBW; - - ShapeUtils::updateStridesAndType(hShapeInfo, xShapeInfo, shape::order(xShapeInfo)); - ShapeUtils::updateStridesAndType(hFWFinalPrevShapeInfo, xShapeInfo, shape::order(xShapeInfo)); - ShapeUtils::updateStridesAndType(hBWFinalPrevShapeInfo, xShapeInfo, shape::order(xShapeInfo)); - - return SHAPELIST(CONSTANT(hShapeInfo), CONSTANT(hFWFinalPrevShapeInfo), CONSTANT(hBWFinalPrevShapeInfo)); +DECLARE_TYPES(static_bidirectional_rnn) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } - - - - - - - -} +DECLARE_SHAPE_FN(static_bidirectional_rnn) { + auto xShapeInfo = inputShape->at(0); // input [time x bS x inSize] + auto WxFWShapeInfo = inputShape->at( + 1); // input-to-hidden weights for forward RNN, [inSize x numUnitsFW] + auto WhFWShapeInfo = + inputShape->at(2); // hidden-to-hidden weights for forward RNN, + // [numUnitsFW x numUnitsFW] + auto bFWShapeInfo = + inputShape->at(3); // biases for forward RNN, [2*numUnitsFW] + auto WxBWShapeInfo = inputShape->at( + 4); // input-to-hidden weights for backward RNN, [inSize x numUnitsBW] + auto WhBWShapeInfo = + inputShape->at(5); // hidden-to-hidden weights for backward RNN, + // [numUnitsBW x numUnitsBW] + auto bBWShapeInfo = + inputShape->at(6); // biases for backward RNN, [2*numUnitsBW] + + Nd4jLong const* h0FWShapeInfo = + nullptr; // initial cell output for forward RNN (at time step = 0) [bS x + // numUnitsFW] + Nd4jLong const* h0BWShapeInfo = + nullptr; // initial cell output for backward RNN (at time step = 0) [bS x + // numUnitsBW] + Nd4jLong const* maxTimeStepShapeInfo = + nullptr; // vector [bS] containing integer values within [0,time), each + // element of this vector set max time step per each input in + // batch, this means there are no calculations for time >= + // maxTimeStep + + switch (block.width()) { + case 8: + maxTimeStepShapeInfo = inputShape->at(7); + break; + case 9: + h0FWShapeInfo = inputShape->at(7); + h0BWShapeInfo = inputShape->at(8); + break; + case 10: + h0FWShapeInfo = inputShape->at(7); + h0BWShapeInfo = inputShape->at(8); + maxTimeStepShapeInfo = inputShape->at(9); + break; + } + + REQUIRE_TRUE(xShapeInfo[0] == 3, 0, + "STATIC_BIDIRECTIONAL_RNN custom operation: input array must " + "have rank = 3, but got %i instead !", + xShapeInfo[0]); + REQUIRE_TRUE( + WxFWShapeInfo[0] == 2, 0, + "STATIC_BIDIRECTIONAL_RNN custom operation: input-to-hidden weights " + "array (for forward RNN) must have rank = 2, but got %i instead !", + WxFWShapeInfo[0]); + REQUIRE_TRUE( + WxBWShapeInfo[0] == 2, 0, + "STATIC_BIDIRECTIONAL_RNN custom operation: input-to-hidden weights " + "array (for backward RNN) must have rank = 2, but got %i instead !", + WxBWShapeInfo[0]); + + const int inRank = xShapeInfo[0]; + const int time = xShapeInfo[1]; + const int bS = xShapeInfo[2]; + const int numUnitsFW = WxFWShapeInfo[2]; + const int numUnitsBW = WxBWShapeInfo[2]; + + const std::vector expectedWhFWshape = {numUnitsFW, numUnitsFW}; + const std::vector expectedWhBWshape = {numUnitsBW, numUnitsBW}; + const std::vector expectedbFWshape = {2 * numUnitsFW}; + const std::vector expectedbBWshape = {2 * numUnitsBW}; + + REQUIRE_TRUE(ShapeUtils::areShapesEqual(WhFWShapeInfo, expectedWhFWshape), 0, + "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of " + "hidden-to-hidden weights array (for forward RNN), expected is " + "%s but got %s instead !", + ShapeUtils::shapeAsString(expectedWhFWshape).c_str(), + ShapeUtils::shapeAsString(WhFWShapeInfo).c_str()); + REQUIRE_TRUE(ShapeUtils::areShapesEqual(WhBWShapeInfo, expectedWhBWshape), 0, + "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of " + "hidden-to-hidden weights array (for backward RNN), expected is " + "%s but got %s instead !", + ShapeUtils::shapeAsString(expectedWhBWshape).c_str(), + ShapeUtils::shapeAsString(WhBWShapeInfo).c_str()); + REQUIRE_TRUE( + ShapeUtils::areShapesEqual(bFWShapeInfo, expectedbFWshape), 0, + "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of biases array " + "(for forward RNN), expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedbFWshape).c_str(), + ShapeUtils::shapeAsString(bFWShapeInfo).c_str()); + REQUIRE_TRUE( + ShapeUtils::areShapesEqual(bBWShapeInfo, expectedbBWshape), 0, + "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of biases array " + "(for backward RNN), expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedbBWshape).c_str(), + ShapeUtils::shapeAsString(bBWShapeInfo).c_str()); + if (h0FWShapeInfo) { + const std::vector expectedh0FWshape = {bS, numUnitsFW}; + REQUIRE_TRUE(ShapeUtils::areShapesEqual(h0FWShapeInfo, expectedh0FWshape), + 0, + "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of " + "initial cell output array (for forward RNN), expected is %s " + "but got %s instead !", + ShapeUtils::shapeAsString(expectedh0FWshape).c_str(), + ShapeUtils::shapeAsString(h0FWShapeInfo).c_str()); + } + if (h0BWShapeInfo) { + const std::vector expectedh0BWshape = {bS, numUnitsBW}; + REQUIRE_TRUE(ShapeUtils::areShapesEqual(h0BWShapeInfo, expectedh0BWshape), + 0, + "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of " + "initial cell output array (for backward RNN), expected is %s " + "but got %s instead !", + ShapeUtils::shapeAsString(expectedh0BWshape).c_str(), + ShapeUtils::shapeAsString(h0BWShapeInfo).c_str()); + } + if (maxTimeStepShapeInfo) + REQUIRE_TRUE(ShapeUtils::areShapesEqual(maxTimeStepShapeInfo, {bS}), 0, + "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of " + "maxTimeStep array, expected is [%i], but got %s instead !", + bS, ShapeUtils::shapeAsString(maxTimeStepShapeInfo).c_str()); + + // evaluate output shapeInfos + Nd4jLong *hShapeInfo(nullptr), *hFWFinalPrevShapeInfo(nullptr), + *hBWFinalPrevShapeInfo(nullptr); + ALLOCATE(hShapeInfo, block.workspace(), shape::shapeInfoLength(inRank), + Nd4jLong); + ALLOCATE(hFWFinalPrevShapeInfo, block.workspace(), + shape::shapeInfoLength(inRank - 1), Nd4jLong); + ALLOCATE(hBWFinalPrevShapeInfo, block.workspace(), + shape::shapeInfoLength(inRank - 1), Nd4jLong); + + hShapeInfo[0] = inRank; + hFWFinalPrevShapeInfo[0] = hBWFinalPrevShapeInfo[0] = inRank - 1; + hShapeInfo[1] = time; + hShapeInfo[2] = hFWFinalPrevShapeInfo[1] = hBWFinalPrevShapeInfo[1] = bS; + hShapeInfo[3] = numUnitsFW + numUnitsBW; + hFWFinalPrevShapeInfo[2] = numUnitsFW; + hBWFinalPrevShapeInfo[2] = numUnitsBW; + + ShapeUtils::updateStridesAndType(hShapeInfo, xShapeInfo, + shape::order(xShapeInfo)); + ShapeUtils::updateStridesAndType(hFWFinalPrevShapeInfo, xShapeInfo, + shape::order(xShapeInfo)); + ShapeUtils::updateStridesAndType(hBWFinalPrevShapeInfo, xShapeInfo, + shape::order(xShapeInfo)); + + return SHAPELIST(CONSTANT(hShapeInfo), CONSTANT(hFWFinalPrevShapeInfo), + CONSTANT(hBWFinalPrevShapeInfo)); } +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/generic/nn/recurrent/staticRNN.cpp b/libnd4j/include/ops/declarable/generic/nn/recurrent/staticRNN.cpp index 26d2e0818541..384493fcddc5 100644 --- a/libnd4j/include/ops/declarable/generic/nn/recurrent/staticRNN.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/recurrent/staticRNN.cpp @@ -19,130 +19,184 @@ // #include -#include +#include namespace sd { -namespace ops { - +namespace ops { ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(static_rnn, 4, 2, false, 0, 0) { - - auto x = INPUT_VARIABLE(0); // input [time x bS x inSize] - auto Wx = INPUT_VARIABLE(1); // input-to-hidden weights, [inSize x numUnits] - auto Wh = INPUT_VARIABLE(2); // hidden-to-hidden weights, [numUnits x numUnits] - auto b = INPUT_VARIABLE(3); // biases for, [2*numUnits] - - NDArray* h0 = nullptr; // initial cell output (at time step = 0) [bS x numUnits] - NDArray* maxTimeStep = nullptr; // vector [bS] containing integer values within [0,time), each element of this vector set max time step per each input in batch, this means there are no calculations for time >= maxTimeStep - - if(block.width() == 5) { - if ((*INPUT_VARIABLE(4)).rankOf() == 2) - h0 = INPUT_VARIABLE(4); - else - maxTimeStep = INPUT_VARIABLE(4); - } - else if(block.width() == 6) { - h0 = INPUT_VARIABLE(4); - maxTimeStep = INPUT_VARIABLE(5); - } - - auto h = OUTPUT_VARIABLE(0); // cell outputs [time x bS x numUnits] - auto hFinal = OUTPUT_VARIABLE(1); // at the end it will store cell final non-zero output [bS x numUnits] - - REQUIRE_TRUE(x->rankOf() == 3, 0, "STATIC_RNN custom operation: input array x must have rank = 3, but got %i instead !", x->rankOf()); - REQUIRE_TRUE(Wx->rankOf() == 2, 0, "STATIC_RNN custom operation: input-to-hidden weights array must have rank = 2, but got %i instead !", Wx->rankOf()); - - const int time = x->sizeAt(0); - const int bS = x->sizeAt(1); - const int inSize = x->sizeAt(2); - const int numUnits = Wx->sizeAt(1); - - const std::vector expectedWhShape = {numUnits, numUnits}; - const std::vector expectedbShape = {2 * numUnits}; - - REQUIRE_TRUE(Wh->isSameShape(expectedWhShape), 0, "STATIC_RNN custom operation: wrong shape of hidden-to-hidden weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWhShape).c_str(), ShapeUtils::shapeAsString(Wh).c_str()); - REQUIRE_TRUE(b->isSameShape(expectedbShape), 0, "STATIC_RNN custom operation: wrong shape of biases array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedbShape).c_str(), ShapeUtils::shapeAsString(b).c_str()); - if(h0) { - const std::vector expectedh0Shape = {bS, numUnits}; - REQUIRE_TRUE(h0->isSameShape(expectedh0Shape), 0, "STATIC_RNN custom operation: wrong shape of initial cell output array, expected is %s but got %s instead !", ShapeUtils::shapeAsString(expectedh0Shape).c_str(), ShapeUtils::shapeAsString(h0).c_str()); - } - if(maxTimeStep) - REQUIRE_TRUE(maxTimeStep->isSameShape({bS}), 0, "STATIC_RNN custom operation: wrong shape of maxTimeStep array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({bS}).c_str(), ShapeUtils::shapeAsString(maxTimeStep).c_str()); - - - helpers::rnnTimeLoop(block.launchContext(), x, Wx, Wh, b, h0, maxTimeStep, h, hFinal); - - return Status::OK(); + auto x = INPUT_VARIABLE(0); // input [time x bS x inSize] + auto Wx = + INPUT_VARIABLE(1); // input-to-hidden weights, [inSize x numUnits] + auto Wh = + INPUT_VARIABLE(2); // hidden-to-hidden weights, [numUnits x numUnits] + auto b = INPUT_VARIABLE(3); // biases for, [2*numUnits] + + NDArray* h0 = + nullptr; // initial cell output (at time step = 0) [bS x numUnits] + NDArray* maxTimeStep = + nullptr; // vector [bS] containing integer values within [0,time), each + // element of this vector set max time step per each input in + // batch, this means there are no calculations for time >= + // maxTimeStep + + if (block.width() == 5) { + if ((*INPUT_VARIABLE(4)).rankOf() == 2) + h0 = INPUT_VARIABLE(4); + else + maxTimeStep = INPUT_VARIABLE(4); + } else if (block.width() == 6) { + h0 = INPUT_VARIABLE(4); + maxTimeStep = INPUT_VARIABLE(5); + } + + auto h = OUTPUT_VARIABLE(0); // cell outputs [time x bS x numUnits] + auto hFinal = OUTPUT_VARIABLE(1); // at the end it will store cell final + // non-zero output [bS x numUnits] + + REQUIRE_TRUE(x->rankOf() == 3, 0, + "STATIC_RNN custom operation: input array x must have rank = 3, " + "but got %i instead !", + x->rankOf()); + REQUIRE_TRUE(Wx->rankOf() == 2, 0, + "STATIC_RNN custom operation: input-to-hidden weights array " + "must have rank = 2, but got %i instead !", + Wx->rankOf()); + + const int time = x->sizeAt(0); + const int bS = x->sizeAt(1); + const int inSize = x->sizeAt(2); + const int numUnits = Wx->sizeAt(1); + + const std::vector expectedWhShape = {numUnits, numUnits}; + const std::vector expectedbShape = {2 * numUnits}; + + REQUIRE_TRUE(Wh->isSameShape(expectedWhShape), 0, + "STATIC_RNN custom operation: wrong shape of hidden-to-hidden " + "weights array, expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedWhShape).c_str(), + ShapeUtils::shapeAsString(Wh).c_str()); + REQUIRE_TRUE(b->isSameShape(expectedbShape), 0, + "STATIC_RNN custom operation: wrong shape of biases array, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedbShape).c_str(), + ShapeUtils::shapeAsString(b).c_str()); + if (h0) { + const std::vector expectedh0Shape = {bS, numUnits}; + REQUIRE_TRUE(h0->isSameShape(expectedh0Shape), 0, + "STATIC_RNN custom operation: wrong shape of initial cell " + "output array, expected is %s but got %s instead !", + ShapeUtils::shapeAsString(expectedh0Shape).c_str(), + ShapeUtils::shapeAsString(h0).c_str()); + } + if (maxTimeStep) + REQUIRE_TRUE(maxTimeStep->isSameShape({bS}), 0, + "STATIC_RNN custom operation: wrong shape of maxTimeStep " + "array, expected is %s, but got %s instead !", + ShapeUtils::shapeAsString({bS}).c_str(), + ShapeUtils::shapeAsString(maxTimeStep).c_str()); + + helpers::rnnTimeLoop(block.launchContext(), x, Wx, Wh, b, h0, maxTimeStep, h, + hFinal); + + return Status::OK(); } DECLARE_TYPES(static_rnn) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } - DECLARE_SHAPE_FN(static_rnn) { - - auto xShapeInfo = inputShape->at(0); // input [time x bS x inSize] - auto WxShapeInfo = inputShape->at(1); // input-to-hidden weights, [inSize x numUnits] - auto WhShapeInfo = inputShape->at(2); // hidden-to-hidden weights, [numUnits x numUnits] - auto bShapeInfo = inputShape->at(3); // biases for, [2*numUnits] - - const Nd4jLong* h0ShapeInfo = nullptr; // initial cell output (at time step = 0) [bS x numUnits] - const Nd4jLong* maxTimeStepShapeInfo = nullptr; // vector [bS] containing integer values within [0,time), each element of this vector set max time step per each input in batch, this means there are no calculations for time >= maxTimeStep - - if(block.width() == 5) { - if (inputShape->at(4)[0] == 2) - h0ShapeInfo = inputShape->at(4); - else - maxTimeStepShapeInfo = inputShape->at(4); - } - else if(block.width() == 6) { - h0ShapeInfo = inputShape->at(4); - maxTimeStepShapeInfo = inputShape->at(5); - } - - REQUIRE_TRUE(xShapeInfo[0] == 3, 0, "STATIC_RNN custom operation: input array x must have rank = 3, but got %i instead !", xShapeInfo[0]); - REQUIRE_TRUE(WxShapeInfo[0] == 2, 0, "STATIC_RNN custom operation: input-to-hidden weights array must have rank = 2, but got %i instead !", WxShapeInfo[0]); - - const int inRank = xShapeInfo[0]; - const int time = xShapeInfo[1]; - const int bS = xShapeInfo[2]; - const int numUnits = WxShapeInfo[2]; - - const std::vector expectedWhShape = {numUnits, numUnits}; - const std::vector expectedbShape = {2 * numUnits}; - - REQUIRE_TRUE(ShapeUtils::areShapesEqual(WhShapeInfo, expectedWhShape), 0, "STATIC_RNN custom operation: wrong shape of hidden-to-hidden weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWhShape).c_str(), ShapeUtils::shapeAsString(WhShapeInfo).c_str()); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(bShapeInfo, expectedbShape), 0, "STATIC_RNN custom operation: wrong shape of biases array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedbShape).c_str(), ShapeUtils::shapeAsString(bShapeInfo).c_str()); - if(h0ShapeInfo){ - const std::vector expectedh0Shape = {bS, numUnits}; - REQUIRE_TRUE(ShapeUtils::areShapesEqual(h0ShapeInfo, expectedh0Shape), 0, "STATIC_RNN custom operation: wrong shape of initial cell output array, expected is %s but got %s instead !", ShapeUtils::shapeAsString(expectedh0Shape).c_str(), ShapeUtils::shapeAsString(h0ShapeInfo).c_str()); - } - if(maxTimeStepShapeInfo) - REQUIRE_TRUE(ShapeUtils::areShapesEqual(maxTimeStepShapeInfo, {bS}), 0, "STATIC_RNN custom operation: wrong shape of maxTimeStep array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({bS}).c_str(), ShapeUtils::shapeAsString(maxTimeStepShapeInfo).c_str()); - - // evaluate output shapeInfos - Nd4jLong *hShapeInfo(nullptr), *hPrevShapeInfo(nullptr); - ALLOCATE(hShapeInfo, block.getWorkspace(), shape::shapeInfoLength(inRank), Nd4jLong); - ALLOCATE(hPrevShapeInfo, block.getWorkspace(), shape::shapeInfoLength(inRank-1), Nd4jLong); - - hShapeInfo[0] = inRank; - hPrevShapeInfo[0] = inRank-1; - hShapeInfo[1] = time; - hShapeInfo[2] = hPrevShapeInfo[1] = bS; - hShapeInfo[3] = hPrevShapeInfo[2] = numUnits; - - ShapeUtils::updateStridesAndType(hShapeInfo, xShapeInfo, shape::order(xShapeInfo)); - ShapeUtils::updateStridesAndType(hPrevShapeInfo, xShapeInfo, shape::order(xShapeInfo)); - - return SHAPELIST(CONSTANT(hShapeInfo), CONSTANT(hPrevShapeInfo)); + auto xShapeInfo = inputShape->at(0); // input [time x bS x inSize] + auto WxShapeInfo = + inputShape->at(1); // input-to-hidden weights, [inSize x numUnits] + auto WhShapeInfo = + inputShape->at(2); // hidden-to-hidden weights, [numUnits x numUnits] + auto bShapeInfo = inputShape->at(3); // biases for, [2*numUnits] + + const Nd4jLong* h0ShapeInfo = + nullptr; // initial cell output (at time step = 0) [bS x numUnits] + const Nd4jLong* maxTimeStepShapeInfo = + nullptr; // vector [bS] containing integer values within [0,time), each + // element of this vector set max time step per each input in + // batch, this means there are no calculations for time >= + // maxTimeStep + + if (block.width() == 5) { + if (inputShape->at(4)[0] == 2) + h0ShapeInfo = inputShape->at(4); + else + maxTimeStepShapeInfo = inputShape->at(4); + } else if (block.width() == 6) { + h0ShapeInfo = inputShape->at(4); + maxTimeStepShapeInfo = inputShape->at(5); + } + + REQUIRE_TRUE(xShapeInfo[0] == 3, 0, + "STATIC_RNN custom operation: input array x must have rank = 3, " + "but got %i instead !", + xShapeInfo[0]); + REQUIRE_TRUE(WxShapeInfo[0] == 2, 0, + "STATIC_RNN custom operation: input-to-hidden weights array " + "must have rank = 2, but got %i instead !", + WxShapeInfo[0]); + + const int inRank = xShapeInfo[0]; + const int time = xShapeInfo[1]; + const int bS = xShapeInfo[2]; + const int numUnits = WxShapeInfo[2]; + + const std::vector expectedWhShape = {numUnits, numUnits}; + const std::vector expectedbShape = {2 * numUnits}; + + REQUIRE_TRUE(ShapeUtils::areShapesEqual(WhShapeInfo, expectedWhShape), 0, + "STATIC_RNN custom operation: wrong shape of hidden-to-hidden " + "weights array, expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedWhShape).c_str(), + ShapeUtils::shapeAsString(WhShapeInfo).c_str()); + REQUIRE_TRUE(ShapeUtils::areShapesEqual(bShapeInfo, expectedbShape), 0, + "STATIC_RNN custom operation: wrong shape of biases array, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedbShape).c_str(), + ShapeUtils::shapeAsString(bShapeInfo).c_str()); + if (h0ShapeInfo) { + const std::vector expectedh0Shape = {bS, numUnits}; + REQUIRE_TRUE(ShapeUtils::areShapesEqual(h0ShapeInfo, expectedh0Shape), 0, + "STATIC_RNN custom operation: wrong shape of initial cell " + "output array, expected is %s but got %s instead !", + ShapeUtils::shapeAsString(expectedh0Shape).c_str(), + ShapeUtils::shapeAsString(h0ShapeInfo).c_str()); + } + if (maxTimeStepShapeInfo) + REQUIRE_TRUE(ShapeUtils::areShapesEqual(maxTimeStepShapeInfo, {bS}), 0, + "STATIC_RNN custom operation: wrong shape of maxTimeStep " + "array, expected is %s, but got %s instead !", + ShapeUtils::shapeAsString({bS}).c_str(), + ShapeUtils::shapeAsString(maxTimeStepShapeInfo).c_str()); + + // evaluate output shapeInfos + Nd4jLong *hShapeInfo(nullptr), *hPrevShapeInfo(nullptr); + ALLOCATE(hShapeInfo, block.workspace(), shape::shapeInfoLength(inRank), + Nd4jLong); + ALLOCATE(hPrevShapeInfo, block.workspace(), + shape::shapeInfoLength(inRank - 1), Nd4jLong); + + hShapeInfo[0] = inRank; + hPrevShapeInfo[0] = inRank - 1; + hShapeInfo[1] = time; + hShapeInfo[2] = hPrevShapeInfo[1] = bS; + hShapeInfo[3] = hPrevShapeInfo[2] = numUnits; + + ShapeUtils::updateStridesAndType(hShapeInfo, xShapeInfo, + shape::order(xShapeInfo)); + ShapeUtils::updateStridesAndType(hPrevShapeInfo, xShapeInfo, + shape::order(xShapeInfo)); + + return SHAPELIST(CONSTANT(hShapeInfo), CONSTANT(hPrevShapeInfo)); } - - - -} -} +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/generic/nn/relu_layer.cpp b/libnd4j/include/ops/declarable/generic/nn/relu_layer.cpp index c76b79b7b7e1..18d9124322dc 100644 --- a/libnd4j/include/ops/declarable/generic/nn/relu_layer.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/relu_layer.cpp @@ -21,45 +21,62 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(relu_layer, 3, 1, false, 0, 0) { - auto x = INPUT_VARIABLE(0); - auto w = INPUT_VARIABLE(1); - auto b = INPUT_VARIABLE(2); +namespace ops { +CUSTOM_OP_IMPL(relu_layer, 3, 1, false, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto w = INPUT_VARIABLE(1); + auto b = INPUT_VARIABLE(2); - REQUIRE_TRUE(x->isMatrix(), 0, "relu_layer: x argument should be a 2D tensor, but got rank %i instead!", x->rankOf()); - REQUIRE_TRUE(w->isMatrix(), 0, "relu_layer: weights argument should be a 2D tensor, but got rank %i instead!", w->rankOf()); - REQUIRE_TRUE(b->isVector(), 0, "relu_layer: biases argument should be a 1D tensor, but got rank %i instead!", b->rankOf()); - REQUIRE_TRUE(b->lengthOf() == w->sizeAt(1), 0, "relu_layer: biases array length should match to columns of weights matrix, however got length = %i and columns = %i!", b->lengthOf(), w->sizeAt(1)); - REQUIRE_TRUE(x->sizeAt(1) == w->sizeAt(0), 0, "relu_layer: number of x columns should match to row number of weights matrix, but got x_columns = %i and weights_rows = %i!", x->sizeAt(1), w->sizeAt(0)); + REQUIRE_TRUE( + x->isMatrix(), 0, + "relu_layer: x argument should be a 2D tensor, but got rank %i instead!", + x->rankOf()); + REQUIRE_TRUE(w->isMatrix(), 0, + "relu_layer: weights argument should be a 2D tensor, but got " + "rank %i instead!", + w->rankOf()); + REQUIRE_TRUE(b->isVector(), 0, + "relu_layer: biases argument should be a 1D tensor, but got " + "rank %i instead!", + b->rankOf()); + REQUIRE_TRUE(b->lengthOf() == w->sizeAt(1), 0, + "relu_layer: biases array length should match to columns of " + "weights matrix, however got length = %i and columns = %i!", + b->lengthOf(), w->sizeAt(1)); + REQUIRE_TRUE(x->sizeAt(1) == w->sizeAt(0), 0, + "relu_layer: number of x columns should match to row number of " + "weights matrix, but got x_columns = %i and weights_rows = %i!", + x->sizeAt(1), w->sizeAt(0)); - auto output = OUTPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - sd::ops::xw_plus_b op; - auto status = op.execute({x, w, b}, {output}); - REQUIRE_TRUE(Status::OK() == status, 0, "relu_layer: xw_plus_b op failed on input data."); + sd::ops::xw_plus_b op; + auto status = op.execute({x, w, b}, {output}); + REQUIRE_TRUE(Status::OK() == status, 0, + "relu_layer: xw_plus_b op failed on input data."); - auto scalar = block.numT() > 0 ? block.getTArguments()->at(0) : 0.0; + auto scalar = block.numT() > 0 ? block.getTArguments().at(0) : 0.0; - output->applyScalar(sd::scalar::RELU, scalar, *output); + output->applyScalar(sd::scalar::RELU, scalar, *output); - return Status::OK(); - } - - DECLARE_SHAPE_FN(relu_layer) { - auto inShape = inputShape->at(0); - auto weightsShape = inputShape->at(1); - auto outputShape = ShapeUtils::matrixProductShape(inShape, weightsShape, false, false, ArrayOptions::dataType(inShape), block.getWorkspace()); + return Status::OK(); +} - return SHAPELIST(outputShape); - } +DECLARE_SHAPE_FN(relu_layer) { + auto inShape = inputShape->at(0); + auto weightsShape = inputShape->at(1); + auto outputShape = ShapeUtils::matrixProductShape( + inShape, weightsShape, false, false, ArrayOptions::dataType(inShape), + block.workspace()); - DECLARE_TYPES(relu_layer) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) -// ->setAllowedInputTypes(1, {ALL_FLOATS}) - ->setAllowedOutputTypes({ALL_FLOATS}); - } - } + return SHAPELIST(outputShape); } +DECLARE_TYPES(relu_layer) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + // ->setAllowedInputTypes(1, {ALL_FLOATS}) + ->setAllowedOutputTypes({ALL_FLOATS}); +} +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/generic/nn/softmax.cpp b/libnd4j/include/ops/declarable/generic/nn/softmax.cpp index d5c58bb7a24a..fba7d33fe91a 100644 --- a/libnd4j/include/ops/declarable/generic/nn/softmax.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/softmax.cpp @@ -28,53 +28,61 @@ namespace sd { namespace ops { - DECLARE_TYPES(softmax) { - getOpDescriptor() - ->setAllowedInputTypes({ALL_FLOATS}) - ->setAllowedOutputTypes({ALL_FLOATS}) - ->setSameMode(true); - } +DECLARE_TYPES(softmax) { + getOpDescriptor() + ->setAllowedInputTypes({ALL_FLOATS}) + ->setAllowedOutputTypes({ALL_FLOATS}) + ->setSameMode(true); +} CONFIGURABLE_OP_IMPL(softmax, 1, 1, true, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - const int rank = input->rankOf(); - const int dim = block.getIArguments()->size() > 0 ? INT_ARG(0) : rank - 1; + const int rank = input->rankOf(); + const int dim = block.numI() > 0 ? INT_ARG(0) : rank - 1; - REQUIRE_TRUE(dim < rank, 0, "SOFTMAX OP: the value of input integer parameter (dimension) must be less than input array rank %i, but got dimension = %i instead !", rank, dim); + REQUIRE_TRUE( + dim < rank, 0, + "SOFTMAX OP: the value of input integer parameter (dimension) must be " + "less than input array rank %i, but got dimension = %i instead !", + rank, dim); - helpers::softmax(block.launchContext(), *input, *output, dim); + helpers::softmax(block.launchContext(), *input, *output, dim); - return Status::OK(); + return Status::OK(); } - CONFIGURABLE_OP_IMPL(softmax_bp, 2, 1, true, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto gradO = INPUT_VARIABLE(1); - auto gradI = OUTPUT_VARIABLE(0); + auto input = INPUT_VARIABLE(0); + auto gradO = INPUT_VARIABLE(1); + auto gradI = OUTPUT_VARIABLE(0); - const int rank = input->rankOf(); - const int dim = block.getIArguments()->size() > 0 ? INT_ARG(0) : rank - 1; + const int rank = input->rankOf(); + const int dim = block.numI() > 0 ? INT_ARG(0) : rank - 1; - REQUIRE_TRUE(dim < rank, 0, "SOFTMAX_BP OP: the value of input integer parameter (dimension) must be less than input array rank %i, but got dimension = %i instead !", rank, dim); + REQUIRE_TRUE( + dim < rank, 0, + "SOFTMAX_BP OP: the value of input integer parameter (dimension) must be " + "less than input array rank %i, but got dimension = %i instead !", + rank, dim); - helpers::softmax(block.launchContext(), *input, *gradI, dim); + helpers::softmax(block.launchContext(), *input, *gradI, dim); - auto sumAlongDim = (*gradI * *gradO).reduceAlongDimension(reduce::Sum, {dim}, true); - gradI->assign(*gradI * (*gradO - sumAlongDim)); + auto sumAlongDim = + (*gradI * *gradO).reduceAlongDimension(reduce::Sum, {dim}, true); + gradI->assign(*gradI * (*gradO - sumAlongDim)); - return Status::OK(); + return Status::OK(); } - DECLARE_TYPES(softmax_bp) { - getOpDescriptor() - ->setAllowedInputTypes({ALL_FLOATS}) - ->setAllowedOutputTypes({ALL_FLOATS}); - } - -} +DECLARE_TYPES(softmax_bp) { + getOpDescriptor() + ->setAllowedInputTypes({ALL_FLOATS}) + ->setAllowedOutputTypes({ALL_FLOATS}); } +} // namespace ops +} // namespace sd + #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/nn/xw_plus_b.cpp b/libnd4j/include/ops/declarable/generic/nn/xw_plus_b.cpp index 5b36ee0e5c89..229dd02b3658 100644 --- a/libnd4j/include/ops/declarable/generic/nn/xw_plus_b.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/xw_plus_b.cpp @@ -14,129 +14,156 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - // - // xw_plus_b op. Created by GS 31.01.2018 - // @author Oleg Semeniv - // - // +// +// xw_plus_b op. Created by GS 31.01.2018 +// @author Oleg Semeniv +// +// #include #if NOT_EXCLUDED(OP_xw_plus_b) +#include #include #include -#include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(xw_plus_b, 3, 1, false, 0, 0) { - - auto x = INPUT_VARIABLE(0); - - auto b = INPUT_VARIABLE(2); - auto z = OUTPUT_VARIABLE(0); - - if (x->isEmpty() || INPUT_VARIABLE(1)->isEmpty() || b->isEmpty()) - return Status::OK(); - - const bool bTranspose = (block.getIArguments()->size() > 0 ? INT_ARG(0) == 1 : false); - - auto w = bTranspose ? new NDArray(INPUT_VARIABLE(1)->transpose()) : INPUT_VARIABLE(1); - - REQUIRE_TRUE(x->rankOf() == 2, 0, "xw_plus_b: Input x array should have rank equal 2, but got instead %i!", x->rankOf()); - REQUIRE_TRUE(w->rankOf() == 2, 0, "xw_plus_b: Input weights array should have rank equal 2, but got instead %i!", w->rankOf()); - REQUIRE_TRUE(z->rankOf() == 2, 0, "xw_plus_b: Output array should have rank equal 2, but got instead %i!", z->rankOf()); - - REQUIRE_TRUE(1 == b->rankOf() && b->lengthOf() == z->sizeAt(-1), 0, "xw_plus_b: Input bias vector should be 1D and have proper dimension 1x%i." - " But got rank %i, and got length %i instead %i.", z->sizeAt(-1), b->rankOf(), b->lengthOf(), z->sizeAt(-1)); - - // multiply x to y - MmulHelper::mmul(x, w, z, 1.0, 0.0); +namespace ops { +CUSTOM_OP_IMPL(xw_plus_b, 3, 1, false, 0, 0) { + auto x = INPUT_VARIABLE(0); - // adding b vector - z->addiRowVector(*b); + auto b = INPUT_VARIABLE(2); + auto z = OUTPUT_VARIABLE(0); - if (bTranspose) - delete w; + if (x->isEmpty() || INPUT_VARIABLE(1)->isEmpty() || b->isEmpty()) + return Status::OK(); - return Status::OK(); - } + const bool bTranspose = (block.numI() > 0 ? INT_ARG(0) == 1 : false); - DECLARE_SHAPE_FN(xw_plus_b) { + auto w = bTranspose ? new NDArray(INPUT_VARIABLE(1)->transpose()) + : INPUT_VARIABLE(1); - auto weights = INPUT_VARIABLE(1); + REQUIRE_TRUE( + x->rankOf() == 2, 0, + "xw_plus_b: Input x array should have rank equal 2, but got instead %i!", + x->rankOf()); + REQUIRE_TRUE(w->rankOf() == 2, 0, + "xw_plus_b: Input weights array should have rank equal 2, but " + "got instead %i!", + w->rankOf()); + REQUIRE_TRUE( + z->rankOf() == 2, 0, + "xw_plus_b: Output array should have rank equal 2, but got instead %i!", + z->rankOf()); - const int nWeightsFormat = block.getIArguments()->size() > 0 ? INT_ARG(0) : 0; + REQUIRE_TRUE(1 == b->rankOf() && b->lengthOf() == z->sizeAt(-1), 0, + "xw_plus_b: Input bias vector should be 1D and have proper " + "dimension 1x%i." + " But got rank %i, and got length %i instead %i.", + z->sizeAt(-1), b->rankOf(), b->lengthOf(), z->sizeAt(-1)); - auto weightsShape = (1 == nWeightsFormat) ? ShapeUtils::evalTranspShapeInfo(*weights, block.getWorkspace()) : inputShape->at(1); + // multiply x to y + MmulHelper::mmul(x, w, z, 1.0, 0.0); - auto outputShape = ShapeUtils::matrixProductShape(inputShape->at(0), weightsShape, false, false, - ArrayOptions::dataType(inputShape->at(0)), block.getWorkspace()); + // adding b vector + z->addiRowVector(*b); - return SHAPELIST(outputShape); - } + if (bTranspose) delete w; - DECLARE_TYPES(xw_plus_b) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ ALL_FLOATS }); - } - - - CUSTOM_OP_IMPL(xw_plus_b_bp, 4, 3, false, 0, 0) { - - auto x = INPUT_VARIABLE(0); - auto b = INPUT_VARIABLE(2); - auto dLdz = INPUT_VARIABLE(3); - - auto dLdx = OUTPUT_VARIABLE(0); - auto dLdb = OUTPUT_VARIABLE(2); - - if (x->isEmpty() || INPUT_VARIABLE(1)->isEmpty() || b->isEmpty() || dLdz->isEmpty()) - return Status::OK(); + return Status::OK(); +} - const bool bTranspose = (block.getIArguments()->size() > 0 ? INT_ARG(0) == 1 : false); +DECLARE_SHAPE_FN(xw_plus_b) { + auto weights = INPUT_VARIABLE(1); - auto w = bTranspose ? new NDArray(INPUT_VARIABLE(1)->transpose()) : INPUT_VARIABLE(1); + const int nWeightsFormat = block.numI() > 0 ? INT_ARG(0) : 0; - REQUIRE_TRUE(x->rankOf() == 2, 0, "xw_plus_b BP: Input x array should have rank equal 2, but got instead %i!", x->rankOf()); - REQUIRE_TRUE(w->rankOf() == 2, 0, "xw_plus_b BP: Input weights array should have rank equal 2, but got instead %i!", w->rankOf()); - REQUIRE_TRUE(dLdz->rankOf() == 2, 0, "xw_plus_b BP: Output array should have rank equal 2, but got instead %i!", dLdz->rankOf()); - REQUIRE_TRUE(1 == b->rankOf() && b->lengthOf() == dLdz->sizeAt(-1), 0, "xw_plus_b BP: Input bias vector should be 1D and have proper dimension 1x%i." - " But got rank %i, and got length %i instead %i.", dLdz->sizeAt(-1), b->rankOf(), b->lengthOf(), dLdz->sizeAt(-1)); + auto weightsShape = (1 == nWeightsFormat) ? ShapeUtils::evalTranspShapeInfo( + *weights, block.workspace()) + : inputShape->at(1); - auto dLdw = (bTranspose) ? new NDArray(OUTPUT_VARIABLE(1)->transpose()) : OUTPUT_VARIABLE(1); + auto outputShape = ShapeUtils::matrixProductShape( + inputShape->at(0), weightsShape, false, false, + ArrayOptions::dataType(inputShape->at(0)), block.workspace()); - // dLdb - dLdb->assign(dLdz->reduceAlongDimension(reduce::Sum, { 0 })); + return SHAPELIST(outputShape); +} - matmul_bp mmul_bp; - mmul_bp.execute({ x, w, dLdz }, std::vector{dLdx, dLdw}, {}, {}, {}); +DECLARE_TYPES(xw_plus_b) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} - if (bTranspose) { - delete w; - delete dLdw; - } - return Status::OK(); - } +CUSTOM_OP_IMPL(xw_plus_b_bp, 4, 3, false, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto b = INPUT_VARIABLE(2); + auto dLdz = INPUT_VARIABLE(3); + + auto dLdx = OUTPUT_VARIABLE(0); + auto dLdb = OUTPUT_VARIABLE(2); + + if (x->isEmpty() || INPUT_VARIABLE(1)->isEmpty() || b->isEmpty() || + dLdz->isEmpty()) + return Status::OK(); + + const bool bTranspose = (block.numI() > 0 ? INT_ARG(0) == 1 : false); + + auto w = bTranspose ? new NDArray(INPUT_VARIABLE(1)->transpose()) + : INPUT_VARIABLE(1); + + REQUIRE_TRUE(x->rankOf() == 2, 0, + "xw_plus_b BP: Input x array should have rank equal 2, but got " + "instead %i!", + x->rankOf()); + REQUIRE_TRUE(w->rankOf() == 2, 0, + "xw_plus_b BP: Input weights array should have rank equal 2, " + "but got instead %i!", + w->rankOf()); + REQUIRE_TRUE(dLdz->rankOf() == 2, 0, + "xw_plus_b BP: Output array should have rank equal 2, but got " + "instead %i!", + dLdz->rankOf()); + REQUIRE_TRUE(1 == b->rankOf() && b->lengthOf() == dLdz->sizeAt(-1), 0, + "xw_plus_b BP: Input bias vector should be 1D and have proper " + "dimension 1x%i." + " But got rank %i, and got length %i instead %i.", + dLdz->sizeAt(-1), b->rankOf(), b->lengthOf(), dLdz->sizeAt(-1)); + + auto dLdw = (bTranspose) ? new NDArray(OUTPUT_VARIABLE(1)->transpose()) + : OUTPUT_VARIABLE(1); + + // dLdb + dLdb->assign(dLdz->reduceAlongDimension(reduce::Sum, {0})); + + matmul_bp mmul_bp; + mmul_bp.execute({x, w, dLdz}, std::vector{dLdx, dLdw}, {}, {}, {}); + + if (bTranspose) { + delete w; + delete dLdw; + } + return Status::OK(); +} - DECLARE_SHAPE_FN(xw_plus_b_bp) { - Nd4jLong* xShapeInfo; - Nd4jLong* wShapeInfo; - Nd4jLong* bShapeInfo; +DECLARE_SHAPE_FN(xw_plus_b_bp) { + Nd4jLong* xShapeInfo; + Nd4jLong* wShapeInfo; + Nd4jLong* bShapeInfo; - COPY_SHAPE(inputShape->at(0), xShapeInfo); - COPY_SHAPE(inputShape->at(1), wShapeInfo); - COPY_SHAPE(inputShape->at(2), bShapeInfo); - return SHAPELIST(CONSTANT(xShapeInfo), CONSTANT(wShapeInfo), CONSTANT(bShapeInfo)); - } + COPY_SHAPE(inputShape->at(0), xShapeInfo); + COPY_SHAPE(inputShape->at(1), wShapeInfo); + COPY_SHAPE(inputShape->at(2), bShapeInfo); + return SHAPELIST(CONSTANT(xShapeInfo), CONSTANT(wShapeInfo), + CONSTANT(bShapeInfo)); +} - DECLARE_TYPES(xw_plus_b_bp) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ ALL_FLOATS }); - } - } +DECLARE_TYPES(xw_plus_b_bp) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/generic/parity_ops.cpp b/libnd4j/include/ops/declarable/generic/parity_ops.cpp index 3595512a258f..32846bf43c5b 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops.cpp @@ -23,27 +23,25 @@ #ifndef LIBND4J_PARITY_OPS_H #define LIBND4J_PARITY_OPS_H -#include -#include -#include -#include -#include -#include #include +#include #include +#include +#include +#include +#include +#include #include #include -#include -#include -#include #include -#include - -namespace sd { - namespace ops { +#include +#include - } -} +#include +#include -#endif //LIBND4J_PARITY_OPS_H +namespace sd { +namespace ops {} +} // namespace sd +#endif // LIBND4J_PARITY_OPS_H diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/assert.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/assert.cpp index 362d51c83c35..b796e715bd23 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/assert.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/assert.cpp @@ -25,22 +25,21 @@ #include namespace sd { - namespace ops { - OP_IMPL(Assert, 1, 1, false) { - auto x = INPUT_VARIABLE(0); +namespace ops { +OP_IMPL(Assert, 1, 1, false) { + auto x = INPUT_VARIABLE(0); - if (!x->e(0)) { - REQUIRE_TRUE(false, 0, "Assertion failed for node [%i]\n", block.getNodeId()); - } + if (!x->e(0)) { + REQUIRE_TRUE(false, 0, "Assertion failed for node [%i]\n", + block.getNodeId()); + } - return Status::OK(); - } - DECLARE_TYPES(Assert) { - getOpDescriptor() - ->setAllowedInputTypes(DataType::ANY) - ->setSameMode(true); - } - } + return Status::OK(); } +DECLARE_TYPES(Assert) { + getOpDescriptor()->setAllowedInputTypes(DataType::ANY)->setSameMode(true); +} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/bincount.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/bincount.cpp index 45b864f26010..d6f96b9ac8ea 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/bincount.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/bincount.cpp @@ -26,98 +26,97 @@ #include namespace sd { - namespace ops { - DECLARE_TYPES(bincount) { - getOpDescriptor() - ->setAllowedInputTypes(0, sd::DataType::INT32) - ->setAllowedInputTypes(1, sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS}); - } - - CUSTOM_OP_IMPL(bincount, 1, 1, false, 0, 0) { - auto values = INPUT_VARIABLE(0); - - NDArray *weights = nullptr; - - int maxLength = -1; - int minLength = 0; - int maxIndex = values->argMax(); - maxLength = values->e(maxIndex) + 1; - - if (block.numI() > 0) { - minLength = sd::math::nd4j_max(INT_ARG(0), 0); - if (block.numI() == 2) - maxLength = sd::math::nd4j_min(maxLength, INT_ARG(1)); - } - - if (block.width() == 2) { // the second argument is weights - weights = INPUT_VARIABLE(1); - REQUIRE_TRUE(values->isSameShape(weights), 0, "bincount: the input and weights shapes should be equals"); - } - else if (block.width() == 3) { // the second argument is min and the third is max - auto min= INPUT_VARIABLE(1); - auto max = INPUT_VARIABLE(2); - minLength = min->e(0); - maxLength = max->e(0); - } - else if (block.width() > 3) { - auto min= INPUT_VARIABLE(2); - auto max = INPUT_VARIABLE(3); - minLength = min->e(0); - maxLength = max->e(0); - weights = INPUT_VARIABLE(1); - REQUIRE_TRUE(values->isSameShape(weights), 0, "bincount: the input and weights shapes should be equals"); - - } - minLength = sd::math::nd4j_max(minLength, 0); - maxLength = sd::math::nd4j_min(maxLength, values->e(maxIndex) + 1); - - auto result = OUTPUT_VARIABLE(0); - result->assign(0.0f); - - helpers::adjustWeights(block.launchContext(), values, weights, result, minLength, maxLength); - - return Status::OK(); - } - - DECLARE_SHAPE_FN(bincount) { - auto shapeList = SHAPELIST(); - auto in = INPUT_VARIABLE(0); - sd::DataType dtype = DataType::INT32; - if (block.width() > 1) - dtype = ArrayOptions::dataType(inputShape->at(1)); - else if (block.numI() > 2) - dtype = (sd::DataType)INT_ARG(2); - - int maxIndex = in->argMax(); - int maxLength = in->e(maxIndex) + 1; - int outLength = maxLength; - if (block.numI() > 0) - outLength = sd::math::nd4j_max(maxLength, INT_ARG(0)); - - if (block.numI() > 1) - outLength = sd::math::nd4j_min(outLength, INT_ARG(1)); - - if (block.width() == 3) { // the second argument is min and the third is max - auto min= INPUT_VARIABLE(1)->e(0); - auto max = INPUT_VARIABLE(2)->e(0); - outLength = sd::math::nd4j_max(maxLength, min); - outLength = sd::math::nd4j_min(outLength, max); - } - else if (block.width() > 3) { - auto min= INPUT_VARIABLE(2); - auto max = INPUT_VARIABLE(3); - outLength = sd::math::nd4j_max(maxLength, min->e(0)); - outLength = sd::math::nd4j_min(outLength, max->e(0)); - } - - auto newshape = ConstantShapeHelper::getInstance().vectorShapeInfo(outLength, dtype); - - shapeList->push_back(newshape); - return shapeList; - } - - } +namespace ops { +DECLARE_TYPES(bincount) { + getOpDescriptor() + ->setAllowedInputTypes(0, sd::DataType::INT32) + ->setAllowedInputTypes(1, sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS}); } +CUSTOM_OP_IMPL(bincount, 1, 1, false, 0, 0) { + auto values = INPUT_VARIABLE(0); + + NDArray *weights = nullptr; + + int maxLength = -1; + int minLength = 0; + int maxIndex = values->argMax(); + maxLength = values->e(maxIndex) + 1; + + if (block.numI() > 0) { + minLength = sd::math::nd4j_max(INT_ARG(0), 0); + if (block.numI() == 2) + maxLength = sd::math::nd4j_min(maxLength, INT_ARG(1)); + } + + if (block.width() == 2) { // the second argument is weights + weights = INPUT_VARIABLE(1); + REQUIRE_TRUE(values->isSameShape(weights), 0, + "bincount: the input and weights shapes should be equals"); + } else if (block.width() == + 3) { // the second argument is min and the third is max + auto min = INPUT_VARIABLE(1); + auto max = INPUT_VARIABLE(2); + minLength = min->e(0); + maxLength = max->e(0); + } else if (block.width() > 3) { + auto min = INPUT_VARIABLE(2); + auto max = INPUT_VARIABLE(3); + minLength = min->e(0); + maxLength = max->e(0); + weights = INPUT_VARIABLE(1); + REQUIRE_TRUE(values->isSameShape(weights), 0, + "bincount: the input and weights shapes should be equals"); + } + minLength = sd::math::nd4j_max(minLength, 0); + maxLength = sd::math::nd4j_min(maxLength, values->e(maxIndex) + 1); + + auto result = OUTPUT_VARIABLE(0); + result->assign(0.0f); + + helpers::adjustWeights(block.launchContext(), values, weights, result, + minLength, maxLength); + + return Status::OK(); +} + +DECLARE_SHAPE_FN(bincount) { + auto shapeList = SHAPELIST(); + auto in = INPUT_VARIABLE(0); + sd::DataType dtype = DataType::INT32; + if (block.width() > 1) + dtype = ArrayOptions::dataType(inputShape->at(1)); + else if (block.numI() > 2) + dtype = (sd::DataType)INT_ARG(2); + + int maxIndex = in->argMax(); + int maxLength = in->e(maxIndex) + 1; + int outLength = maxLength; + if (block.numI() > 0) outLength = sd::math::nd4j_max(maxLength, INT_ARG(0)); + + if (block.numI() > 1) outLength = sd::math::nd4j_min(outLength, INT_ARG(1)); + + if (block.width() == 3) { // the second argument is min and the third is max + auto min = INPUT_VARIABLE(1)->e(0); + auto max = INPUT_VARIABLE(2)->e(0); + outLength = sd::math::nd4j_max(maxLength, min); + outLength = sd::math::nd4j_min(outLength, max); + } else if (block.width() > 3) { + auto min = INPUT_VARIABLE(2); + auto max = INPUT_VARIABLE(3); + outLength = sd::math::nd4j_max(maxLength, min->e(0)); + outLength = sd::math::nd4j_min(outLength, max->e(0)); + } + + auto newshape = + ConstantShapeHelper::getInstance().vectorShapeInfo(outLength, dtype); + + shapeList->push_back(newshape); + return shapeList; +} + +} // namespace ops +} // namespace sd + #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/broadcast_dynamic_shape.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/broadcast_dynamic_shape.cpp index d954a0b4411f..0e976a8bb135 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/broadcast_dynamic_shape.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/broadcast_dynamic_shape.cpp @@ -28,66 +28,81 @@ namespace ops { ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(broadcast_dynamic_shape, 2, 1, false, 0, 0) { - - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - - auto z = OUTPUT_VARIABLE(0); - - REQUIRE_TRUE(x->rankOf() == 1, 0, "BROADCAST_DYNAMIC_SHAPE OP: the first input array must have rank = 1, but got %i instead!", x->rankOf()); - REQUIRE_TRUE(y->rankOf() == 1, 0, "BROADCAST_DYNAMIC_SHAPE OP: the second input array must have rank = 1, but got %i instead!", y->rankOf()); - REQUIRE_TRUE(x->dataType() == y->dataType(), 0, "BROADCAST_DYNAMIC_SHAPE OP: both input arrays must have the same integer type !"); - - // contract shapeInfos, neglect and don't fill strides, ews, order - // shapes are of interest only - std::vector xShapeInfo(shape::shapeInfoLength(x->lengthOf())); - std::vector yShapeInfo(shape::shapeInfoLength(y->lengthOf())); - - // fill rank and data type - xShapeInfo[0] = x->lengthOf(); - yShapeInfo[0] = y->lengthOf(); - ArrayOptions::setDataType(xShapeInfo.data(), sd::DataType::INT64); // fill with some data type, it doesn't matter what type exactly to choose - ArrayOptions::setDataType(yShapeInfo.data(), sd::DataType::INT64); - - for (Nd4jLong i = 0; i < x->lengthOf(); ++i) - xShapeInfo[i + 1] = x->e(i); - - for (Nd4jLong i = 0; i < y->lengthOf(); ++i) - yShapeInfo[i + 1] = y->e(i); - - const Nd4jLong* poinerOnOutShapeInfo = nullptr; - - const bool isBroadcastPossible = ShapeUtils::evalBroadcastShapeInfo(xShapeInfo.data(), yShapeInfo.data(), true, poinerOnOutShapeInfo, block.launchContext()->getWorkspace()); - - REQUIRE_TRUE(isBroadcastPossible, 0, "BROADCAST_DYNAMIC_SHAPE OP: the shapes of two input arrays %s and %s are not suitable for broadcast operation !", ShapeUtils::shapeAsString(xShapeInfo.data()).c_str(), ShapeUtils::shapeAsString(yShapeInfo.data()).c_str()); - - for (Nd4jLong i = 0; i < z->lengthOf(); ++i) - z->p(i, poinerOnOutShapeInfo[i + 1]); - - return Status::OK(); + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + + auto z = OUTPUT_VARIABLE(0); + + REQUIRE_TRUE(x->rankOf() == 1, 0, + "BROADCAST_DYNAMIC_SHAPE OP: the first input array must have " + "rank = 1, but got %i instead!", + x->rankOf()); + REQUIRE_TRUE(y->rankOf() == 1, 0, + "BROADCAST_DYNAMIC_SHAPE OP: the second input array must have " + "rank = 1, but got %i instead!", + y->rankOf()); + REQUIRE_TRUE(x->dataType() == y->dataType(), 0, + "BROADCAST_DYNAMIC_SHAPE OP: both input arrays must have the " + "same integer type !"); + + // contract shapeInfos, neglect and don't fill strides, ews, order + // shapes are of interest only + std::vector xShapeInfo(shape::shapeInfoLength(x->lengthOf())); + std::vector yShapeInfo(shape::shapeInfoLength(y->lengthOf())); + + // fill rank and data type + xShapeInfo[0] = x->lengthOf(); + yShapeInfo[0] = y->lengthOf(); + ArrayOptions::setDataType( + xShapeInfo.data(), + sd::DataType::INT64); // fill with some data type, it doesn't matter what + // type exactly to choose + ArrayOptions::setDataType(yShapeInfo.data(), sd::DataType::INT64); + + for (Nd4jLong i = 0; i < x->lengthOf(); ++i) + xShapeInfo[i + 1] = x->e(i); + + for (Nd4jLong i = 0; i < y->lengthOf(); ++i) + yShapeInfo[i + 1] = y->e(i); + + const Nd4jLong* poinerOnOutShapeInfo = nullptr; + + const bool isBroadcastPossible = ShapeUtils::evalBroadcastShapeInfo( + xShapeInfo.data(), yShapeInfo.data(), true, poinerOnOutShapeInfo, + block.launchContext()->getWorkspace()); + + REQUIRE_TRUE(isBroadcastPossible, 0, + "BROADCAST_DYNAMIC_SHAPE OP: the shapes of two input arrays %s " + "and %s are not suitable for broadcast operation !", + ShapeUtils::shapeAsString(xShapeInfo.data()).c_str(), + ShapeUtils::shapeAsString(yShapeInfo.data()).c_str()); + + for (Nd4jLong i = 0; i < z->lengthOf(); ++i) + z->p(i, poinerOnOutShapeInfo[i + 1]); + + return Status::OK(); } DECLARE_TYPES(broadcast_dynamic_shape) { - getOpDescriptor() - ->setAllowedOutputTypes({ALL_INTS}) - ->setAllowedInputTypes({ALL_INTS}); + getOpDescriptor() + ->setAllowedOutputTypes({ALL_INTS}) + ->setAllowedInputTypes({ALL_INTS}); } - ////////////////////////////////////////////////////////////////////////// DECLARE_SHAPE_FN(broadcast_dynamic_shape) { + const int xRank = INPUT_VARIABLE(0)->lengthOf(); + const int yRank = INPUT_VARIABLE(1)->lengthOf(); - const int xRank = INPUT_VARIABLE(0)->lengthOf(); - const int yRank = INPUT_VARIABLE(1)->lengthOf(); - - const int maxRank = xRank > yRank ? xRank : yRank; + const int maxRank = xRank > yRank ? xRank : yRank; - auto outputShapeInfo = ConstantShapeHelper::getInstance().vectorShapeInfo(maxRank, ArrayOptions::dataType(inputShape->at(0))); + auto outputShapeInfo = ConstantShapeHelper::getInstance().vectorShapeInfo( + maxRank, ArrayOptions::dataType(inputShape->at(0))); - return SHAPELIST(outputShapeInfo); + return SHAPELIST(outputShapeInfo); } -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/check_numerics.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/check_numerics.cpp index 3d06d4cedb5d..461513e2a1b8 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/check_numerics.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/check_numerics.cpp @@ -24,33 +24,34 @@ #include namespace sd { - namespace ops { +namespace ops { - CUSTOM_OP_IMPL(check_numerics, 2, 1, true, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto message = INPUT_VARIABLE(1); - auto output = OUTPUT_VARIABLE(0); +CUSTOM_OP_IMPL(check_numerics, 2, 1, true, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto message = INPUT_VARIABLE(1); + auto output = OUTPUT_VARIABLE(0); - auto allFinite = input->reduceNumber(reduce::BoolOps::IsFinite); - REQUIRE_TRUE(allFinite.e(0), 0, "CheckNumerics: %s", message->e(0).c_str()); + auto allFinite = input->reduceNumber(reduce::BoolOps::IsFinite); + REQUIRE_TRUE(allFinite.e(0), 0, "CheckNumerics: %s", + message->e(0).c_str()); - if (!block.isInplace()) - output->assign(input); + if (!block.isInplace()) output->assign(input); - return Status::OK(); - } + return Status::OK(); +} - DECLARE_SHAPE_FN(check_numerics) { - return SHAPELIST(ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(inputShape->at(0)))); - } +DECLARE_SHAPE_FN(check_numerics) { + return SHAPELIST(ConstantShapeHelper::getInstance().createShapeInfo( + ShapeDescriptor(inputShape->at(0)))); +} - DECLARE_TYPES(check_numerics) { - getOpDescriptor() - ->setAllowedInputTypes(0, {ALL_FLOATS}) - ->setAllowedInputTypes(1, sd::DataType::UTF8) - ->setAllowedOutputTypes({ALL_FLOATS}); - } - } +DECLARE_TYPES(check_numerics) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_FLOATS}) + ->setAllowedInputTypes(1, sd::DataType::UTF8) + ->setAllowedOutputTypes({ALL_FLOATS}); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/compare_and_bitpack.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/compare_and_bitpack.cpp index f694502b3429..1442a90c151a 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/compare_and_bitpack.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/compare_and_bitpack.cpp @@ -14,47 +14,50 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - // - // @author sgazeos@gmail.com - // +// +// @author sgazeos@gmail.com +// +#include #include -#include #include -#include +#include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(compare_and_bitpack, 2, 1, false, 0, 0) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); - auto z0 = NDArrayFactory::create(x->ordering(), x->getShapeAsVector(), block.launchContext()); - BROADCAST_CHECK_EMPTY(x, y, (&z0)); - - auto tZ = BroadcastHelper::broadcastApply(BROADCAST_BOOL(GreaterThan), x, y, &z0); - bitcast res; - auto status = res.execute({ tZ }, { z }, {}, { DataType::UINT8 }, {}, {}, false); - if (tZ != &z0) { - delete tZ; - } - - return status; - } - - DECLARE_TYPES(compare_and_bitpack) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, DataType::ANY) - ->setAllowedOutputTypes(0, DataType::UINT8); - } - - DECLARE_SHAPE_FN(compare_and_bitpack) { - auto inShape = inputShape->at(0); - DataType newType = DataType::UINT8; - - return SHAPELIST(ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(inShape, newType))); - } - - } -} \ No newline at end of file +namespace ops { +CUSTOM_OP_IMPL(compare_and_bitpack, 2, 1, false, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); + auto z0 = NDArrayFactory::create(x->ordering(), x->getShapeAsVector(), + block.launchContext()); + BROADCAST_CHECK_EMPTY(x, y, (&z0)); + + auto tZ = + BroadcastHelper::broadcastApply(BROADCAST_BOOL(GreaterThan), x, y, &z0); + bitcast res; + auto status = res.execute({tZ}, {z}, {}, {DataType::UINT8}, {}, {}, false); + if (tZ != &z0) { + delete tZ; + } + + return status; +} + +DECLARE_TYPES(compare_and_bitpack) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes(1, DataType::ANY) + ->setAllowedOutputTypes(0, DataType::UINT8); +} + +DECLARE_SHAPE_FN(compare_and_bitpack) { + auto inShape = inputShape->at(0); + DataType newType = DataType::UINT8; + + return SHAPELIST(ConstantShapeHelper::getInstance().createShapeInfo( + ShapeDescriptor(inShape, newType))); +} + +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/confusion_matrix.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/confusion_matrix.cpp index f5c5cbb919ab..f955e416deb8 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/confusion_matrix.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/confusion_matrix.cpp @@ -21,66 +21,79 @@ #include #if NOT_EXCLUDED(OP_confusion_matrix) -#include -#include #include #include -#include +#include +#include #include +#include + namespace sd { - namespace ops { - DECLARE_TYPES(confusion_matrix) { - getOpDescriptor() - ->setAllowedInputTypes({ALL_INTS, ALL_FLOATS}) - ->setAllowedOutputTypes({ALL_FLOATS, ALL_INTS}); - } +namespace ops { +DECLARE_TYPES(confusion_matrix) { + getOpDescriptor() + ->setAllowedInputTypes({ALL_INTS, ALL_FLOATS}) + ->setAllowedOutputTypes({ALL_FLOATS, ALL_INTS}); +} - CUSTOM_OP_IMPL(confusion_matrix, 2, 1, false, 0, -2) { +CUSTOM_OP_IMPL(confusion_matrix, 2, 1, false, 0, -2) { + auto labels = INPUT_VARIABLE(0); + auto predictions = INPUT_VARIABLE(1); + NDArray *weights = nullptr; + if (block.width() > 2) { + weights = INPUT_VARIABLE(2); + REQUIRE_TRUE( + weights->isSameShape(predictions), 0, + "CONFUSION_MATRIX: Weights and predictions should have equal shape"); + } + auto output = OUTPUT_NULLIFIED(0); - auto labels = INPUT_VARIABLE(0); - auto predictions = INPUT_VARIABLE(1); - NDArray *weights = nullptr; - if(block.width() > 2){ - weights = INPUT_VARIABLE(2); - REQUIRE_TRUE(weights->isSameShape(predictions),0, "CONFUSION_MATRIX: Weights and predictions should have equal shape"); - } - auto output = OUTPUT_NULLIFIED(0); + int minPrediction = predictions->reduceNumber(reduce::Min).e(0); + int minLabel = labels->reduceNumber(reduce::Min).e(0); - int minPrediction = predictions->reduceNumber(reduce::Min).e(0); - int minLabel = labels->reduceNumber(reduce::Min).e(0); + REQUIRE_TRUE(minLabel >= 0, 0, + "CONFUSION_MATRIX: Labels contains negative values !"); + REQUIRE_TRUE(minPrediction >= 0, 0, + "CONFUSION_MATRIX: Predictions contains negative values !"); + REQUIRE_TRUE( + labels->isVector(), 0, + "CONFUSION_MATRIX: Labels input should be a Vector, but got %iD instead", + labels->rankOf()); + REQUIRE_TRUE(predictions->isVector(), 0, + "CONFUSION_MATRIX: Predictions input should be Vector, but got " + "%iD instead", + predictions->rankOf()); + REQUIRE_TRUE( + labels->isSameShape(predictions), 0, + "CONFUSION_MATRIX: Labels and predictions should have equal shape"); - REQUIRE_TRUE(minLabel >=0, 0, "CONFUSION_MATRIX: Labels contains negative values !"); - REQUIRE_TRUE(minPrediction >=0, 0, "CONFUSION_MATRIX: Predictions contains negative values !"); - REQUIRE_TRUE(labels->isVector(), 0, "CONFUSION_MATRIX: Labels input should be a Vector, but got %iD instead", labels->rankOf()); - REQUIRE_TRUE(predictions->isVector(), 0, "CONFUSION_MATRIX: Predictions input should be Vector, but got %iD instead", predictions->rankOf()); - REQUIRE_TRUE(labels->isSameShape(predictions),0, "CONFUSION_MATRIX: Labels and predictions should have equal shape"); + helpers::confusionFunctor(block.launchContext(), labels, predictions, weights, + output); - helpers::confusionFunctor(block.launchContext(), labels, predictions, weights, output); + return Status::OK(); +} - return Status::OK(); - } +DECLARE_SHAPE_FN(confusion_matrix) { + auto labels = INPUT_VARIABLE(0); + auto predictions = INPUT_VARIABLE(1); + auto dtype = block.numD() ? D_ARG(0) : sd::DataType::INT64; + int numClasses = 0; - DECLARE_SHAPE_FN(confusion_matrix) { - auto labels = INPUT_VARIABLE(0); - auto predictions = INPUT_VARIABLE(1); - auto dtype = block.numD() ? D_ARG(0) : sd::DataType::INT64; - int numClasses = 0; + if (block.numI() > 0) { + numClasses = INT_ARG(0); + } else { + int maxPrediction = predictions->reduceNumber(reduce::Max).e(0); + int maxLabel = labels->reduceNumber(reduce::Max).e(0); + numClasses = (maxPrediction >= maxLabel) ? maxPrediction + 1 : maxLabel + 1; + } - if (block.getIArguments()->size() > 0) { - numClasses = INT_ARG(0); - } - else { - int maxPrediction = predictions->reduceNumber(reduce::Max).e(0); - int maxLabel = labels->reduceNumber(reduce::Max).e(0); - numClasses = (maxPrediction >= maxLabel) ? maxPrediction+1 : maxLabel+1; - } - - std::array shape = {{numClasses,numClasses}}; - auto newShape = ConstantShapeHelper::getInstance().createShapeInfo(dtype, 'c', 2, shape.data()); - return SHAPELIST(newShape); - } - } + std::array shape = {{numClasses, numClasses}}; + auto newShape = ConstantShapeHelper::getInstance().createShapeInfo( + dtype, 'c', 2, shape.data()); + return SHAPELIST(newShape); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/expose.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/expose.cpp index d9c931f2119d..b3b2689046fd 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/expose.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/expose.cpp @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -15,57 +16,54 @@ ******************************************************************************/ // -// Created by raver119 on 12/11/17. +// @author raver119@gmail.com // #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(expose, -1, -1, true, 0, 0) { +namespace ops { +CUSTOM_OP_IMPL(expose, -1, -1, true, 0, 0) { + for (int e = 0; e < block.width(); e++) { + auto inVar = block.variable(e); + if (inVar->variableType() == VariableType::NDARRAY) { + auto in = INPUT_VARIABLE(e); + auto out = OUTPUT_VARIABLE(e); - for (int e = 0; e < block.width(); e++) { - auto inVar = block.variable(e); - if (inVar->variableType() == VariableType::NDARRAY) { - auto in = INPUT_VARIABLE(e); - auto out = OUTPUT_VARIABLE(e); - - out->assign(in); - } else if (inVar->variableType() == VariableType::ARRAY_LIST) { - auto var = block.ensureVariable(e); - if (!var->hasNDArrayList()) { - auto list = inVar->getNDArrayList(); - - block.pushNDArrayListToVariableSpace(block.nodeId(), e, list, false); - } - } - } - - return ND4J_STATUS_OK; - } - DECLARE_SYN(Enter, expose); - DECLARE_SYN(enter, expose); + out->assign(in); + } else if (inVar->variableType() == VariableType::ARRAY_LIST) { + auto var = block.ensureVariable(block.name(), block.nodeId(), e); + if (!var->hasNDArrayList()) { + auto list = inVar->getNDArrayList(); + block.pushNDArrayListToVariableSpace(block.nodeId(), e, list); + } + } + } - DECLARE_TYPES(expose) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setSameMode(true); - } + return ND4J_STATUS_OK; +} +DECLARE_SYN(Enter, expose); +DECLARE_SYN(enter, expose); - DECLARE_SHAPE_FN(expose) { - auto shapeList = SHAPELIST(); +DECLARE_TYPES(expose) { + getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setSameMode(true); +} - for (int e = 0; e < block.width(); e++) { - auto p = block.input(e); - auto var = block.getVariable(e); - if (var->variableType() == VariableType::NDARRAY) { - auto inShape = inputShape->at(e); - shapeList->push_back(ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(inShape))); - } - } +DECLARE_SHAPE_FN(expose) { + auto shapeList = SHAPELIST(); - return shapeList; - } + for (int e = 0; e < block.width(); e++) { + auto p = block.input(e); + auto var = block.getVariable(e); + if (var->variableType() == VariableType::NDARRAY) { + auto inShape = inputShape->at(e); + shapeList->push_back(ConstantShapeHelper::getInstance().createShapeInfo( + ShapeDescriptor(inShape))); } -} \ No newline at end of file + } + + return shapeList; +} +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/fake_quant_with_min_max_vars.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/fake_quant_with_min_max_vars.cpp index 291e8b7c1fd7..ebd22aa7b333 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/fake_quant_with_min_max_vars.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/fake_quant_with_min_max_vars.cpp @@ -24,51 +24,55 @@ #include #include namespace sd { - namespace ops { - CONFIGURABLE_OP_IMPL(fake_quant_with_min_max_vars, 1, 1, true, 0, 0) { +namespace ops { +CONFIGURABLE_OP_IMPL(fake_quant_with_min_max_vars, 1, 1, true, 0, 0) { + auto x = INPUT_VARIABLE(0); - auto x = INPUT_VARIABLE(0); + NDArray* min; + NDArray* max; - NDArray* min; - NDArray* max; + REQUIRE_TRUE(block.width() == 3 || block.numT() == 2, 0, + "fake_quant_with_min_max_vars: No minimum/maximum values " + "provided by either input arrays or TArgs"); - REQUIRE_TRUE(block.width() == 3 || block.getTArguments()->size() == 2, 0, "fake_quant_with_min_max_vars: No minimum/maximum values provided by either input arrays or TArgs"); + NDArray m; + NDArray m2; + if (block.width() == 3) { + min = INPUT_VARIABLE(1); + max = INPUT_VARIABLE(2); + } else if (block.numT() == 2) { + m = NDArrayFactory::create(x->dataType(), T_ARG(0), block.launchContext()); + m2 = NDArrayFactory::create(x->dataType(), T_ARG(1), block.launchContext()); + min = &m; + max = &m2; + } + auto output = OUTPUT_VARIABLE(0); + REQUIRE_TRUE(x->dataType() == output->dataType(), 0, + "fake_quant_with_min_max_vars: input and output data types must " + "be the same"); - NDArray m; - NDArray m2; - if(block.width() == 3){ - min = INPUT_VARIABLE(1); - max = INPUT_VARIABLE(2); - } else if(block.getTArguments()->size() == 2){ - m = NDArrayFactory::create(x->dataType(), T_ARG(0), block.launchContext()); - m2 = NDArrayFactory::create(x->dataType(), T_ARG(1), block.launchContext()); - min = &m; - max = &m2; - } - auto output = OUTPUT_VARIABLE(0); - REQUIRE_TRUE(x->dataType() == output->dataType(), 0, "fake_quant_with_min_max_vars: input and output data types must be the same"); - - int numBits = 8; - if (block.getIArguments() && block.getIArguments()->size()) - numBits = INT_ARG(0); - bool narrowed = false; - if (block.getBArguments() && block.getBArguments()->size()) { - narrowed = B_ARG(0); - } - REQUIRE_TRUE(numBits > 1 && numBits < 17, 0, "fake_quant_with_min_max_vars: Number of \ - bits for quantization should be in between 2 and 16, but %i was given.", numBits); - helpers::fakeQuantWithMinMaxVars(x, min, max, numBits, narrowed, output); - return ND4J_STATUS_OK; - } - - DECLARE_TYPES(fake_quant_with_min_max_vars) { - getOpDescriptor() - -> setAllowedOutputTypes({ALL_FLOATS}) - -> setAllowedInputTypes({ALL_INTS, ALL_FLOATS}); - } + int numBits = 8; + if (block.numI()) numBits = INT_ARG(0); + bool narrowed = false; + if (block.numB()) { + narrowed = B_ARG(0); + } + REQUIRE_TRUE(numBits > 1 && numBits < 17, 0, + "fake_quant_with_min_max_vars: Number of \ + bits for quantization should be in between 2 and 16, but %i was given.", + numBits); + helpers::fakeQuantWithMinMaxVars(x, min, max, numBits, narrowed, output); + return ND4J_STATUS_OK; +} - DECLARE_SYN(fake_quant_with_min_max_args, fake_quant_with_min_max_vars); - } +DECLARE_TYPES(fake_quant_with_min_max_vars) { + getOpDescriptor() + ->setAllowedOutputTypes({ALL_FLOATS}) + ->setAllowedInputTypes({ALL_INTS, ALL_FLOATS}); } +DECLARE_SYN(fake_quant_with_min_max_args, fake_quant_with_min_max_vars); +} // namespace ops +} // namespace sd + #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/fake_quant_with_min_max_vars_per_channel.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/fake_quant_with_min_max_vars_per_channel.cpp index 4af483e22e48..5c4ae1754353 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/fake_quant_with_min_max_vars_per_channel.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/fake_quant_with_min_max_vars_per_channel.cpp @@ -24,49 +24,61 @@ #include #include namespace sd { - namespace ops { - CONFIGURABLE_OP_IMPL(fake_quant_with_min_max_vars_per_channel, 3, 1, true, 0, 0) { +namespace ops { +CONFIGURABLE_OP_IMPL(fake_quant_with_min_max_vars_per_channel, 3, 1, true, 0, + 0) { + auto x = INPUT_VARIABLE(0); + auto min = INPUT_VARIABLE(1); + auto max = INPUT_VARIABLE(2); - auto x = INPUT_VARIABLE(0); - auto min = INPUT_VARIABLE(1); - auto max = INPUT_VARIABLE(2); + auto depth = x->sizeAt(-1); + REQUIRE_TRUE(min->rankOf() == 1 && max->rankOf() == 1 && + min->lengthOf() == max->lengthOf(), + 0, + "fake_quant_with_min_max_vars_per_channel: Min and Max should " + "be 1D tensors with the same length"); + REQUIRE_TRUE(depth == min->lengthOf(), 0, + "fake_quant_with_min_max_vars_per_channel: Min length should be" + " %lld, but %lld occurs.", + depth, min->lengthOf()); - auto depth = x->sizeAt(-1); - REQUIRE_TRUE(min->rankOf() == 1 && max->rankOf() == 1 && min->lengthOf() == max->lengthOf(), 0, - "fake_quant_with_min_max_vars_per_channel: Min and Max should be 1D tensors with the same length"); - REQUIRE_TRUE(depth == min->lengthOf(), 0, "fake_quant_with_min_max_vars_per_channel: Min length should be" - " %lld, but %lld occurs.", depth, min->lengthOf()); + REQUIRE_TRUE(depth == max->lengthOf(), 0, + "fake_quant_with_min_max_vars_per_channel: Max length should be" + "%lld, but %lld occurs.", + depth, max->lengthOf()); + auto output = OUTPUT_VARIABLE(0); - REQUIRE_TRUE(depth == max->lengthOf(), 0, "fake_quant_with_min_max_vars_per_channel: Max length should be" - "%lld, but %lld occurs.", depth, max->lengthOf()); - auto output = OUTPUT_VARIABLE(0); + REQUIRE_TRUE(x->dataType() == output->dataType(), 0, + "fake_quant_with_min_max_vars_per_channel: input and output " + "data types must be the same"); - REQUIRE_TRUE(x->dataType() == output->dataType(), 0, "fake_quant_with_min_max_vars_per_channel: input and output data types must be the same"); + int numBits = 8; + if (block.numI()) numBits = INT_ARG(0); + bool narrowed = false; + // INT_ARG(1); + if (block.numB()) { + narrowed = B_ARG(0); + } - int numBits = 8; - if (block.getIArguments() && block.getIArguments()->size()) - numBits = INT_ARG(0); - bool narrowed = false; - //INT_ARG(1); - if (block.getBArguments() && block.getBArguments()->size()) { - narrowed = B_ARG(0); - } - - REQUIRE_TRUE(numBits > 1 && numBits < 17, 0, "fake_quant_with_min_max_vars_per_channel: Number of bits" - " for quatization should be in between 2 and 16, but %i " - "was given.", numBits); - helpers::fakeQuantWithMinMaxVarsPerChannel(block.launchContext(), x, min, max, numBits, narrowed, output); - return ND4J_STATUS_OK; - } - - DECLARE_TYPES(fake_quant_with_min_max_vars_per_channel) { - getOpDescriptor() - -> setAllowedOutputTypes({ALL_FLOATS}) - -> setAllowedInputTypes({ALL_INTS, ALL_FLOATS}); - } + REQUIRE_TRUE(numBits > 1 && numBits < 17, 0, + "fake_quant_with_min_max_vars_per_channel: Number of bits" + " for quatization should be in between 2 and 16, but %i " + "was given.", + numBits); + helpers::fakeQuantWithMinMaxVarsPerChannel(block.launchContext(), x, min, max, + numBits, narrowed, output); + return ND4J_STATUS_OK; +} - DECLARE_SYN(fake_quant_with_min_max_args_per_channel, fake_quant_with_min_max_vars_per_channel); - } +DECLARE_TYPES(fake_quant_with_min_max_vars_per_channel) { + getOpDescriptor() + ->setAllowedOutputTypes({ALL_FLOATS}) + ->setAllowedInputTypes({ALL_INTS, ALL_FLOATS}); } +DECLARE_SYN(fake_quant_with_min_max_args_per_channel, + fake_quant_with_min_max_vars_per_channel); +} // namespace ops +} // namespace sd + #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/in_top_k.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/in_top_k.cpp index 7618de5b1b0a..6f35b27cc703 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/in_top_k.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/in_top_k.cpp @@ -26,39 +26,51 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(in_top_k, 2, 1, true, 0, 1) { - auto predictions = INPUT_VARIABLE(0); - auto target = INPUT_VARIABLE(1); +namespace ops { - auto result = OUTPUT_VARIABLE(0); - REQUIRE_TRUE(block.numI() > 0, 0, "in_top_k: Parameter k is needed to be set"); - REQUIRE_TRUE(predictions->sizeAt(0) == target->lengthOf(), 0, "in_top_k: the number of predictions rows should be equal to target array length, but got %i and %i correspondingly !", predictions->sizeAt(0), target->lengthOf()); - REQUIRE_TRUE(predictions->rankOf() == 2, 0, "in_top_k: The predictions array should have rank 2, but %i given", predictions->rankOf()); - REQUIRE_TRUE(target->rankOf() == 1, 0, "in_top_k: The target should be a vector"); +CUSTOM_OP_IMPL(in_top_k, 2, 1, true, 0, 1) { + auto predictions = INPUT_VARIABLE(0); + auto target = INPUT_VARIABLE(1); - int k = INT_ARG(0); - return helpers::inTopKFunctor(block.launchContext(), predictions, target, result, k); - } + auto result = OUTPUT_VARIABLE(0); + REQUIRE_TRUE(block.numI() > 0, 0, + "in_top_k: Parameter k is needed to be set"); + REQUIRE_TRUE(predictions->sizeAt(0) == target->lengthOf(), 0, + "in_top_k: the number of predictions rows should be equal to " + "target array length, but got %i and %i correspondingly !", + predictions->sizeAt(0), target->lengthOf()); + REQUIRE_TRUE( + predictions->rankOf() == 2, 0, + "in_top_k: The predictions array should have rank 2, but %i given", + predictions->rankOf()); + REQUIRE_TRUE(target->rankOf() == 1, 0, + "in_top_k: The target should be a vector"); - DECLARE_SHAPE_FN(in_top_k) { - auto shapeList = SHAPELIST(); - auto in = inputShape->at(1); - int shapeRank = shape::rank(in); + int k = INT_ARG(0); + return helpers::inTopKFunctor(block.launchContext(), predictions, target, + result, k); +} - auto aShape = ConstantShapeHelper::getInstance().createShapeInfo(sd::DataType::BOOL, shape::order(in), shape::rank(in), shape::shapeOf(in)); - shapeList->push_back(aShape); - return shapeList; - } +DECLARE_SHAPE_FN(in_top_k) { + auto shapeList = SHAPELIST(); + auto in = inputShape->at(1); + int shapeRank = shape::rank(in); - DECLARE_TYPES(in_top_k) { - getOpDescriptor() - ->setAllowedInputTypes(0, {ALL_FLOATS}) - ->setAllowedInputTypes(1, {ALL_INTS}) - ->setAllowedOutputTypes(DataType::BOOL); - } + auto aShape = ConstantShapeHelper::getInstance().createShapeInfo( + sd::DataType::BOOL, shape::order(in), shape::rank(in), + shape::shapeOf(in)); + shapeList->push_back(aShape); + return shapeList; +} - } +DECLARE_TYPES(in_top_k) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_FLOATS}) + ->setAllowedInputTypes(1, {ALL_INTS}) + ->setAllowedOutputTypes(DataType::BOOL); } +} // namespace ops +} // namespace sd + #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/listdiff.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/listdiff.cpp index 86a37619eb8c..a5051ec56fd5 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/listdiff.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/listdiff.cpp @@ -26,46 +26,58 @@ // this op will probably never become GPU-compatible namespace sd { - namespace ops { - CUSTOM_OP_IMPL(listdiff, 2, 2, false, 0, 0) { - auto values = INPUT_VARIABLE(0); - auto keep = INPUT_VARIABLE(1); - auto output1 = OUTPUT_VARIABLE(0); - auto output2 = OUTPUT_VARIABLE(1); +namespace ops { +CUSTOM_OP_IMPL(listdiff, 2, 2, false, 0, 0) { + auto values = INPUT_VARIABLE(0); + auto keep = INPUT_VARIABLE(1); + auto output1 = OUTPUT_VARIABLE(0); + auto output2 = OUTPUT_VARIABLE(1); - REQUIRE_TRUE(values->rankOf() == 1, 0, "ListDiff: rank of values should be 1D, but got %iD instead", values->rankOf()); - REQUIRE_TRUE(keep->rankOf() == 1, 0, "ListDiff: rank of keep should be 1D, but got %iD instead", keep->rankOf()); - REQUIRE_TRUE(keep->dataType() == values->dataType(), 0, "ListDiff: both inputs must have same data type"); + REQUIRE_TRUE(values->rankOf() == 1, 0, + "ListDiff: rank of values should be 1D, but got %iD instead", + values->rankOf()); + REQUIRE_TRUE(keep->rankOf() == 1, 0, + "ListDiff: rank of keep should be 1D, but got %iD instead", + keep->rankOf()); + REQUIRE_TRUE(keep->dataType() == values->dataType(), 0, + "ListDiff: both inputs must have same data type"); - return helpers::listDiffFunctor(block.launchContext(), values, keep, output1, output2); - }; + return helpers::listDiffFunctor(block.launchContext(), values, keep, output1, + output2); +}; - DECLARE_SHAPE_FN(listdiff) { - auto values = INPUT_VARIABLE(0); - auto keep = INPUT_VARIABLE(1); +DECLARE_SHAPE_FN(listdiff) { + auto values = INPUT_VARIABLE(0); + auto keep = INPUT_VARIABLE(1); - REQUIRE_TRUE(values->rankOf() == 1, 0, "ListDiff: rank of values should be 1D, but got %iD instead", values->rankOf()); - REQUIRE_TRUE(keep->rankOf() == 1, 0, "ListDiff: rank of keep should be 1D, but got %iD instead", keep->rankOf()); - auto v = values->dataType(); - auto k = keep->dataType(); - REQUIRE_TRUE(k == v, 0, "ListDiff: both inputs must have same data type"); + REQUIRE_TRUE(values->rankOf() == 1, 0, + "ListDiff: rank of values should be 1D, but got %iD instead", + values->rankOf()); + REQUIRE_TRUE(keep->rankOf() == 1, 0, + "ListDiff: rank of keep should be 1D, but got %iD instead", + keep->rankOf()); + auto v = values->dataType(); + auto k = keep->dataType(); + REQUIRE_TRUE(k == v, 0, "ListDiff: both inputs must have same data type"); - auto saved = helpers::listDiffCount(block.launchContext(), values, keep); + auto saved = helpers::listDiffCount(block.launchContext(), values, keep); - REQUIRE_TRUE(saved > 0, 0, "ListDiff: no matches found"); + REQUIRE_TRUE(saved > 0, 0, "ListDiff: no matches found"); - auto shapeX = ConstantShapeHelper::getInstance().vectorShapeInfo(saved, values->dataType()); - auto shapeY = ConstantShapeHelper::getInstance().vectorShapeInfo(saved, DataType::INT64); - return SHAPELIST(shapeX, shapeY); - } + auto shapeX = ConstantShapeHelper::getInstance().vectorShapeInfo( + saved, values->dataType()); + auto shapeY = ConstantShapeHelper::getInstance().vectorShapeInfo( + saved, DataType::INT64); + return SHAPELIST(shapeX, shapeY); +} - DECLARE_TYPES(listdiff) { - getOpDescriptor() - ->setAllowedInputTypes({ALL_INTS, ALL_FLOATS}) - ->setAllowedOutputTypes(0, DataType::INHERIT) - ->setAllowedOutputTypes(1, {ALL_INTS}); - } - } +DECLARE_TYPES(listdiff) { + getOpDescriptor() + ->setAllowedInputTypes({ALL_INTS, ALL_FLOATS}) + ->setAllowedOutputTypes(0, DataType::INHERIT) + ->setAllowedOutputTypes(1, {ALL_INTS}); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/non_max_suppression.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/non_max_suppression.cpp index 91512b2f71df..ef0f0241ba32 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/non_max_suppression.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/non_max_suppression.cpp @@ -21,201 +21,226 @@ #include #include - namespace sd { - namespace ops { +namespace ops { #if NOT_EXCLUDED(OP_image_non_max_suppression) - CUSTOM_OP_IMPL(non_max_suppression, 2, 1, false, 0, 0) { - auto boxes = INPUT_VARIABLE(0); - auto scales = INPUT_VARIABLE(1); - auto output = OUTPUT_VARIABLE(0); - int maxOutputSize; // = INT_ARG(0); - if (block.width() > 2) - maxOutputSize = INPUT_VARIABLE(2)->e(0); - else if (block.getIArguments()->size() == 1) - maxOutputSize = INT_ARG(0); - else - REQUIRE_TRUE(false, 0, "image.non_max_suppression: Max output size argument cannot be retrieved."); - - double overlayThreshold = 0.5; - double scoreThreshold = - DataTypeUtils::infOrMax(); - - if (block.width() > 3) { - overlayThreshold = INPUT_VARIABLE(3)->e(0); - } - else if (block.getTArguments()->size() > 0) { - overlayThreshold = T_ARG(0); - } - - if (block.width() > 4) { - scoreThreshold = INPUT_VARIABLE(4)->e(0); - } - else if (block.getTArguments()->size() > 1) { - scoreThreshold = T_ARG(1); - } - if (boxes->isEmpty() || scales->isEmpty()) - return Status::OK(); - - if (output->isEmpty()) - return Status::OK(); - - REQUIRE_TRUE(boxes->rankOf() == 2, 0, "image.non_max_suppression: The rank of boxes array should be 2, " - "but %i is given", boxes->rankOf()); - REQUIRE_TRUE(boxes->sizeAt(1) == 4, 0, "image.non_max_suppression: The last dimension of boxes array " - "should be 4, but %i is given", boxes->sizeAt(1)); - REQUIRE_TRUE(scales->rankOf() == 1 && scales->lengthOf() == boxes->sizeAt(0), 0, - "image.non_max_suppression: The rank of scales array should be 1, but %i is given", boxes->rankOf()); - REQUIRE_TRUE(overlayThreshold >= 0. && overlayThreshold <= 1., 0, "image.non_max_suppressio: The overlay " - "threashold should be in [0, 1], but " - "%lf is given.", overlayThreshold); - REQUIRE_TRUE(boxes->dataType() == scales->dataType(), 0, - "image.non_max_suppression: Boxes and scores inputs should have the same data type, but %s and %s " - "were given.", DataTypeUtils::asString(boxes->dataType()).c_str(), - DataTypeUtils::asString(scales->dataType()).c_str()); - helpers::nonMaxSuppression(block.launchContext(), boxes, scales, maxOutputSize, overlayThreshold, - scoreThreshold, output); - return Status::OK(); - } +CUSTOM_OP_IMPL(non_max_suppression, 2, 1, false, 0, 0) { + auto boxes = INPUT_VARIABLE(0); + auto scales = INPUT_VARIABLE(1); + auto output = OUTPUT_VARIABLE(0); + int maxOutputSize; // = INT_ARG(0); + if (block.width() > 2) + maxOutputSize = INPUT_VARIABLE(2)->e(0); + else if (block.numI() == 1) + maxOutputSize = INT_ARG(0); + else + REQUIRE_TRUE(false, 0, + "image.non_max_suppression: Max output size argument cannot " + "be retrieved."); + + double overlayThreshold = 0.5; + double scoreThreshold = -DataTypeUtils::infOrMax(); + + if (block.width() > 3) { + overlayThreshold = INPUT_VARIABLE(3)->e(0); + } else if (block.numT() > 0) { + overlayThreshold = T_ARG(0); + } + + if (block.width() > 4) { + scoreThreshold = INPUT_VARIABLE(4)->e(0); + } else if (block.numT() > 1) { + scoreThreshold = T_ARG(1); + } + if (boxes->isEmpty() || scales->isEmpty()) return Status::OK(); + + if (output->isEmpty()) return Status::OK(); + + REQUIRE_TRUE( + boxes->rankOf() == 2, 0, + "image.non_max_suppression: The rank of boxes array should be 2, " + "but %i is given", + boxes->rankOf()); + REQUIRE_TRUE(boxes->sizeAt(1) == 4, 0, + "image.non_max_suppression: The last dimension of boxes array " + "should be 4, but %i is given", + boxes->sizeAt(1)); + REQUIRE_TRUE(scales->rankOf() == 1 && scales->lengthOf() == boxes->sizeAt(0), + 0, + "image.non_max_suppression: The rank of scales array should be " + "1, but %i is given", + boxes->rankOf()); + REQUIRE_TRUE(overlayThreshold >= 0. && overlayThreshold <= 1., 0, + "image.non_max_suppressio: The overlay " + "threashold should be in [0, 1], but " + "%lf is given.", + overlayThreshold); + REQUIRE_TRUE(boxes->dataType() == scales->dataType(), 0, + "image.non_max_suppression: Boxes and scores inputs should have " + "the same data type, but %s and %s " + "were given.", + DataTypeUtils::asString(boxes->dataType()).c_str(), + DataTypeUtils::asString(scales->dataType()).c_str()); + helpers::nonMaxSuppression(block.launchContext(), boxes, scales, + maxOutputSize, overlayThreshold, scoreThreshold, + output); + return Status::OK(); +} - DECLARE_SHAPE_FN(non_max_suppression) { - auto in = inputShape->at(0); - int outRank = shape::rank(in); - const Nd4jLong *outputShape = nullptr; - - int maxOutputSize; - if (block.width() > 2) - maxOutputSize = INPUT_VARIABLE(2)->e(0); - else if (block.getIArguments()->size() == 1) - maxOutputSize = INT_ARG(0); - else - REQUIRE_TRUE(false, 0, "image.non_max_suppression: Max output size argument cannot be retrieved."); - - if (maxOutputSize > 0) { - auto actualIndicesCount = shape::sizeAt(in, 0); - if (block.getTArguments()->size() > 1 || block.width() > 4) { - auto scoreThreshold = - block.getTArguments()->size() > 1 ? T_ARG(1) : INPUT_VARIABLE(4)->e(0); - auto scales = INPUT_VARIABLE(1); - scales->syncToHost(); - for (auto e = 0; e < scales->lengthOf(); e++) { - if (scales->e(e) < (float) scoreThreshold) { - actualIndicesCount--; - } - } - } - if (actualIndicesCount < maxOutputSize) - maxOutputSize = actualIndicesCount; - } - outputShape = ConstantShapeHelper::getInstance().vectorShapeInfo(maxOutputSize, DataType::INT32); - - return SHAPELIST(outputShape); - } - DECLARE_TYPES(non_max_suppression) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_INDICES}); +DECLARE_SHAPE_FN(non_max_suppression) { + auto in = inputShape->at(0); + int outRank = shape::rank(in); + const Nd4jLong *outputShape = nullptr; + + int maxOutputSize; + if (block.width() > 2) + maxOutputSize = INPUT_VARIABLE(2)->e(0); + else if (block.numI() == 1) + maxOutputSize = INT_ARG(0); + else + REQUIRE_TRUE(false, 0, + "image.non_max_suppression: Max output size argument cannot " + "be retrieved."); + + if (maxOutputSize > 0) { + auto actualIndicesCount = shape::sizeAt(in, 0); + if (block.numT() > 1 || block.width() > 4) { + auto scoreThreshold = + block.numT() > 1 ? T_ARG(1) : INPUT_VARIABLE(4)->e(0); + auto scales = INPUT_VARIABLE(1); + scales->syncToHost(); + for (auto e = 0; e < scales->lengthOf(); e++) { + if (scales->e(e) < (float)scoreThreshold) { + actualIndicesCount--; } + } + } + if (actualIndicesCount < maxOutputSize) maxOutputSize = actualIndicesCount; + } + outputShape = ConstantShapeHelper::getInstance().vectorShapeInfo( + maxOutputSize, DataType::INT32); + + return SHAPELIST(outputShape); +} +DECLARE_TYPES(non_max_suppression) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_INDICES}); +} #endif #if NOT_EXCLUDED(OP_image_non_max_suppression_v3) - DECLARE_TYPES(non_max_suppression_v3) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_INDICES}); - } +DECLARE_TYPES(non_max_suppression_v3) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_INDICES}); +} - CUSTOM_OP_IMPL(non_max_suppression_v3, 2, 1, false, 0, 0) { - auto boxes = INPUT_VARIABLE(0); - auto scales = INPUT_VARIABLE(1); - auto output = OUTPUT_VARIABLE(0); - int maxOutputSize; // = INT_ARG(0); - if (block.width() > 2) - maxOutputSize = INPUT_VARIABLE(2)->e(0); - else if (block.getIArguments()->size() == 1) - maxOutputSize = INT_ARG(0); - else - REQUIRE_TRUE(false, 0, "image.non_max_suppression: Max output size argument cannot be retrieved."); - - double overlayThreshold = 0.5; - double scoreThreshold = - DataTypeUtils::infOrMax(); - - if (block.width() > 3) { - overlayThreshold = INPUT_VARIABLE(3)->e(0); - } - else if (block.getTArguments()->size() > 0) { - overlayThreshold = T_ARG(0); - } - - if (block.width() > 4) { - scoreThreshold = INPUT_VARIABLE(4)->e(0); - } - else if (block.getTArguments()->size() > 1) { - scoreThreshold = T_ARG(1); - } - if (boxes->isEmpty() || scales->isEmpty()) - return Status::OK(); - if (output->isEmpty()) - return Status::OK(); - - REQUIRE_TRUE(boxes->rankOf() == 2, 0, "image.non_max_suppression: The rank of boxes array should be 2, but " - "%i is given", boxes->rankOf()); - REQUIRE_TRUE(boxes->sizeAt(1) == 4, 0, "image.non_max_suppression: The last dimension of boxes array should " - "be 4, but %i is given", boxes->sizeAt(1)); - REQUIRE_TRUE(scales->rankOf() == 1 && scales->lengthOf() == boxes->sizeAt(0), 0, - "image.non_max_suppression: The rank of scales array should be 1, but %i is given", boxes->rankOf()); - REQUIRE_TRUE(overlayThreshold >= 0. && overlayThreshold <= 1., 0, - "image.non_max_suppression_v3: The overlay threashold should be in [0, 1], but %lf given.", - overlayThreshold); - REQUIRE_TRUE(boxes->dataType() == scales->dataType(), 0, - "image.non_max_suppression_v3: Boxes and scores inputs should have the same data type, but %s and %s " - "were given.", DataTypeUtils::asString(boxes->dataType()).c_str(), - DataTypeUtils::asString(scales->dataType()).c_str()); - - helpers::nonMaxSuppressionV3(block.launchContext(), boxes, scales, maxOutputSize, overlayThreshold, - scoreThreshold, output); - return Status::OK(); - } +CUSTOM_OP_IMPL(non_max_suppression_v3, 2, 1, false, 0, 0) { + auto boxes = INPUT_VARIABLE(0); + auto scales = INPUT_VARIABLE(1); + auto output = OUTPUT_VARIABLE(0); + int maxOutputSize; // = INT_ARG(0); + if (block.width() > 2) + maxOutputSize = INPUT_VARIABLE(2)->e(0); + else if (block.numI() == 1) + maxOutputSize = INT_ARG(0); + else + REQUIRE_TRUE(false, 0, + "image.non_max_suppression: Max output size argument cannot " + "be retrieved."); + + double overlayThreshold = 0.5; + double scoreThreshold = -DataTypeUtils::infOrMax(); + + if (block.width() > 3) { + overlayThreshold = INPUT_VARIABLE(3)->e(0); + } else if (block.numT() > 0) { + overlayThreshold = T_ARG(0); + } + + if (block.width() > 4) { + scoreThreshold = INPUT_VARIABLE(4)->e(0); + } else if (block.numT() > 1) { + scoreThreshold = T_ARG(1); + } + if (boxes->isEmpty() || scales->isEmpty()) return Status::OK(); + if (output->isEmpty()) return Status::OK(); + + REQUIRE_TRUE( + boxes->rankOf() == 2, 0, + "image.non_max_suppression: The rank of boxes array should be 2, but " + "%i is given", + boxes->rankOf()); + REQUIRE_TRUE( + boxes->sizeAt(1) == 4, 0, + "image.non_max_suppression: The last dimension of boxes array should " + "be 4, but %i is given", + boxes->sizeAt(1)); + REQUIRE_TRUE(scales->rankOf() == 1 && scales->lengthOf() == boxes->sizeAt(0), + 0, + "image.non_max_suppression: The rank of scales array should be " + "1, but %i is given", + boxes->rankOf()); + REQUIRE_TRUE(overlayThreshold >= 0. && overlayThreshold <= 1., 0, + "image.non_max_suppression_v3: The overlay threashold should be " + "in [0, 1], but %lf given.", + overlayThreshold); + REQUIRE_TRUE(boxes->dataType() == scales->dataType(), 0, + "image.non_max_suppression_v3: Boxes and scores inputs should " + "have the same data type, but %s and %s " + "were given.", + DataTypeUtils::asString(boxes->dataType()).c_str(), + DataTypeUtils::asString(scales->dataType()).c_str()); + + helpers::nonMaxSuppressionV3(block.launchContext(), boxes, scales, + maxOutputSize, overlayThreshold, scoreThreshold, + output); + return Status::OK(); +} - DECLARE_SHAPE_FN(non_max_suppression_v3) { - auto in = inputShape->at(0); - int outRank = shape::rank(in); - - - int maxOutputSize; - if (block.width() > 2) - maxOutputSize = INPUT_VARIABLE(2)->e(0); - else if (block.getIArguments()->size() == 1) - maxOutputSize = INT_ARG(0); - else - REQUIRE_TRUE(false, 0, "image.non_max_suppression: Max output size argument cannot be retrieved."); - auto boxes = INPUT_VARIABLE(0); - auto scales = INPUT_VARIABLE(1); - - double overlayThreshold = 0.5; - double scoreThreshold = - DataTypeUtils::infOrMax(); - - if (block.width() > 3) { - overlayThreshold = INPUT_VARIABLE(3)->e(0); - } - else if (block.getTArguments()->size() > 0) { - overlayThreshold = T_ARG(0); - } - - if (block.width() > 4) { - scoreThreshold = INPUT_VARIABLE(4)->e(0); - } - else if (block.getTArguments()->size() > 1) { - scoreThreshold = T_ARG(1); - } - - auto len = maxOutputSize; - if (len > 0) - len = helpers::nonMaxSuppressionV3(block.launchContext(), boxes, scales, maxOutputSize, overlayThreshold, scoreThreshold, nullptr); - - auto outputShape = ConstantShapeHelper::getInstance().vectorShapeInfo(len, DataType::INT32); - - return SHAPELIST(outputShape); - } +DECLARE_SHAPE_FN(non_max_suppression_v3) { + auto in = inputShape->at(0); + int outRank = shape::rank(in); + + int maxOutputSize; + if (block.width() > 2) + maxOutputSize = INPUT_VARIABLE(2)->e(0); + else if (block.numI() == 1) + maxOutputSize = INT_ARG(0); + else + REQUIRE_TRUE(false, 0, + "image.non_max_suppression: Max output size argument cannot " + "be retrieved."); + auto boxes = INPUT_VARIABLE(0); + auto scales = INPUT_VARIABLE(1); + + double overlayThreshold = 0.5; + double scoreThreshold = -DataTypeUtils::infOrMax(); + + if (block.width() > 3) { + overlayThreshold = INPUT_VARIABLE(3)->e(0); + } else if (block.numT() > 0) { + overlayThreshold = T_ARG(0); + } + + if (block.width() > 4) { + scoreThreshold = INPUT_VARIABLE(4)->e(0); + } else if (block.numT() > 1) { + scoreThreshold = T_ARG(1); + } + + auto len = maxOutputSize; + if (len > 0) + len = helpers::nonMaxSuppressionV3(block.launchContext(), boxes, scales, + maxOutputSize, overlayThreshold, + scoreThreshold, nullptr); + + auto outputShape = + ConstantShapeHelper::getInstance().vectorShapeInfo(len, DataType::INT32); + + return SHAPELIST(outputShape); +} #endif - } -} +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/non_max_suppression_overlaps.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/non_max_suppression_overlaps.cpp index 1cc4addbc72a..4bbe71c41946 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/non_max_suppression_overlaps.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/non_max_suppression_overlaps.cpp @@ -24,69 +24,84 @@ #if NOT_EXCLUDED(OP_image_non_max_suppression_overlaps) namespace sd { - namespace ops { - CUSTOM_OP_IMPL(non_max_suppression_overlaps, 2, 1, false, 0, 0) { - auto boxes = INPUT_VARIABLE(0); - auto scales = INPUT_VARIABLE(1); - auto output = OUTPUT_VARIABLE(0); - int maxOutputSize; // = INT_ARG(0); - if (block.width() > 2) - maxOutputSize = INPUT_VARIABLE(2)->e(0); - else if (block.getIArguments()->size() == 1) - maxOutputSize = INT_ARG(0); - else - REQUIRE_TRUE(false, 0, "image.non_max_suppression_overlaps: Max output size argument cannot be retrieved."); - REQUIRE_TRUE(boxes->rankOf() == 2, 0, "image.non_max_suppression_overlaps: The rank of boxes array should be 2, but %i is given", boxes->rankOf()); - REQUIRE_TRUE(boxes->sizeAt(0) == boxes->sizeAt(1), 0, "image.non_max_suppression_overlaps: The boxes array should be square, but {%lld, %lld} is given", boxes->sizeAt(0), boxes->sizeAt(1)); - REQUIRE_TRUE(scales->rankOf() == 1 && scales->lengthOf() == boxes->sizeAt(0), 0, "image.non_max_suppression_overlaps: The rank of scales array should be 1, but %i is given", boxes->rankOf()); +namespace ops { +CUSTOM_OP_IMPL(non_max_suppression_overlaps, 2, 1, false, 0, 0) { + auto boxes = INPUT_VARIABLE(0); + auto scales = INPUT_VARIABLE(1); + auto output = OUTPUT_VARIABLE(0); + int maxOutputSize; // = INT_ARG(0); + if (block.width() > 2) + maxOutputSize = INPUT_VARIABLE(2)->e(0); + else if (block.numI() == 1) + maxOutputSize = INT_ARG(0); + else + REQUIRE_TRUE(false, 0, + "image.non_max_suppression_overlaps: Max output size argument " + "cannot be retrieved."); + REQUIRE_TRUE(boxes->rankOf() == 2, 0, + "image.non_max_suppression_overlaps: The rank of boxes array " + "should be 2, but %i is given", + boxes->rankOf()); + REQUIRE_TRUE(boxes->sizeAt(0) == boxes->sizeAt(1), 0, + "image.non_max_suppression_overlaps: The boxes array should be " + "square, but {%lld, %lld} is given", + boxes->sizeAt(0), boxes->sizeAt(1)); + REQUIRE_TRUE(scales->rankOf() == 1 && scales->lengthOf() == boxes->sizeAt(0), + 0, + "image.non_max_suppression_overlaps: The rank of scales array " + "should be 1, but %i is given", + boxes->rankOf()); -// if (scales->lengthOf() < maxOutputSize) -// maxOutputSize = scales->lengthOf(); - double overlapThreshold = 0.5; - double scoreThreshold = -DataTypeUtils::infOrMax(); - if (block.getTArguments()->size() > 0) - overlapThreshold = T_ARG(0); - if (block.getTArguments()->size() > 1) - scoreThreshold = T_ARG(1); + // if (scales->lengthOf() < maxOutputSize) + // maxOutputSize = scales->lengthOf(); + double overlapThreshold = 0.5; + double scoreThreshold = -DataTypeUtils::infOrMax(); + if (block.numT() > 0) overlapThreshold = T_ARG(0); + if (block.numT() > 1) scoreThreshold = T_ARG(1); - // TODO: refactor helpers to multithreaded facility - helpers::nonMaxSuppressionGeneric(block.launchContext(), boxes, scales, maxOutputSize, overlapThreshold, - scoreThreshold, output); - return Status::OK(); - } - - DECLARE_SHAPE_FN(non_max_suppression_overlaps) { - auto in = inputShape->at(0); - int outRank = shape::rank(in); + // TODO: refactor helpers to multithreaded facility + helpers::nonMaxSuppressionGeneric(block.launchContext(), boxes, scales, + maxOutputSize, overlapThreshold, + scoreThreshold, output); + return Status::OK(); +} - int maxOutputSize; - if (block.width() > 2) - maxOutputSize = INPUT_VARIABLE(2)->e(0); - else if (block.getIArguments()->size() == 1) - maxOutputSize = INT_ARG(0); - else - REQUIRE_TRUE(false, 0, "image.non_max_suppression: Max output size argument cannot be retrieved."); +DECLARE_SHAPE_FN(non_max_suppression_overlaps) { + auto in = inputShape->at(0); + int outRank = shape::rank(in); - double overlapThreshold = 0.5; - double scoreThreshold = 0.; + int maxOutputSize; + if (block.width() > 2) + maxOutputSize = INPUT_VARIABLE(2)->e(0); + else if (block.numI() == 1) + maxOutputSize = INT_ARG(0); + else + REQUIRE_TRUE(false, 0, + "image.non_max_suppression: Max output size argument cannot " + "be retrieved."); - Nd4jLong boxSize = helpers::nonMaxSuppressionGeneric(block.launchContext(), INPUT_VARIABLE(0), - INPUT_VARIABLE(1), maxOutputSize, overlapThreshold, scoreThreshold, nullptr); //shape::sizeAt(in, 0); - if (boxSize < maxOutputSize) - maxOutputSize = boxSize; + double overlapThreshold = 0.5; + double scoreThreshold = 0.; - auto outputShape = ConstantShapeHelper::getInstance().vectorShapeInfo(maxOutputSize, DataType::INT32); + Nd4jLong boxSize = helpers::nonMaxSuppressionGeneric( + block.launchContext(), INPUT_VARIABLE(0), INPUT_VARIABLE(1), + maxOutputSize, overlapThreshold, scoreThreshold, + nullptr); // shape::sizeAt(in, 0); + if (boxSize < maxOutputSize) maxOutputSize = boxSize; - return SHAPELIST(outputShape); - } - DECLARE_TYPES(non_max_suppression_overlaps) { - getOpDescriptor() - ->setAllowedInputTypes(0, {ALL_FLOATS}) - ->setAllowedInputTypes(1, {ALL_FLOATS}) - ->setAllowedInputTypes(2, {ALL_INTS}) - ->setAllowedOutputTypes({ALL_INDICES}); - } + auto outputShape = ConstantShapeHelper::getInstance().vectorShapeInfo( + maxOutputSize, DataType::INT32); - } + return SHAPELIST(outputShape); } +DECLARE_TYPES(non_max_suppression_overlaps) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_FLOATS}) + ->setAllowedInputTypes(1, {ALL_FLOATS}) + ->setAllowedInputTypes(2, {ALL_INTS}) + ->setAllowedOutputTypes({ALL_INDICES}); +} + +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/normalize_moments.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/normalize_moments.cpp index f8a4c5c6ed1a..f8664f645abd 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/normalize_moments.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/normalize_moments.cpp @@ -14,9 +14,9 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - // - // Created by george@skymind.io on 26.01.2018. - // +// +// Created by george@skymind.io on 26.01.2018. +// #include #if NOT_EXCLUDED(OP_normalize_moments) @@ -24,62 +24,63 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(normalize_moments, 3, 2, false, 1, 0) { - auto counts = INPUT_VARIABLE(0); - auto means = INPUT_VARIABLE(1); - auto variances = INPUT_VARIABLE(2); +namespace ops { +CUSTOM_OP_IMPL(normalize_moments, 3, 2, false, 1, 0) { + auto counts = INPUT_VARIABLE(0); + auto means = INPUT_VARIABLE(1); + auto variances = INPUT_VARIABLE(2); - auto resMeans = OUTPUT_VARIABLE(0); - auto resVariances = OUTPUT_VARIABLE(1); + auto resMeans = OUTPUT_VARIABLE(0); + auto resVariances = OUTPUT_VARIABLE(1); - // FIXME: double? - NDArray shift = NDArrayFactory::create(0., block.launchContext()); + // FIXME: double? + NDArray shift = NDArrayFactory::create(0., block.launchContext()); - if (block.getTArguments()->size() > 0) { - shift.assign(T_ARG(0)); - } + if (block.numT() > 0) { + shift.assign(T_ARG(0)); + } - means->applyScalarArr(scalar::Divide, *counts, *resMeans); + means->applyScalarArr(scalar::Divide, *counts, *resMeans); - NDArray squareMeans = resMeans->dup('c'); - NDArray tempVariances = resVariances->dup('c'); + NDArray squareMeans = resMeans->dup('c'); + NDArray tempVariances = resVariances->dup('c'); - squareMeans.applyTransform(transform::Square, squareMeans, nullptr); - variances->applyScalarArr(scalar::Divide, *counts, tempVariances); - // tempVariances.printIndexedBuffer("varianced divided by count"); - tempVariances.applyPairwiseTransform(pairwise::Subtract, squareMeans, *resVariances); + squareMeans.applyTransform(transform::Square, squareMeans, nullptr); + variances->applyScalarArr(scalar::Divide, *counts, tempVariances); + // tempVariances.printIndexedBuffer("varianced divided by count"); + tempVariances.applyPairwiseTransform(pairwise::Subtract, squareMeans, + *resVariances); - if (shift.e(0) != 0) { - resMeans->applyScalarArr(scalar::Add, shift, *resMeans); - } + if (shift.e(0) != 0) { + resMeans->applyScalarArr(scalar::Add, shift, *resMeans); + } - return Status::OK(); - } - - DECLARE_SHAPE_FN(normalize_moments) { - auto in = inputShape->at(1); + return Status::OK(); +} - Nd4jLong* meanShape = nullptr; - Nd4jLong* varianceShape = nullptr; +DECLARE_SHAPE_FN(normalize_moments) { + auto in = inputShape->at(1); - COPY_SHAPE_EX(in, meanShape, block.getWorkspace()); - COPY_SHAPE_EX(in, varianceShape, block.getWorkspace()); + Nd4jLong* meanShape = nullptr; + Nd4jLong* varianceShape = nullptr; - auto shapeList = SHAPELIST(); - shapeList->push_back(CONSTANT(meanShape)); - shapeList->push_back(CONSTANT(varianceShape)); + COPY_SHAPE_EX(in, meanShape, block.workspace()); + COPY_SHAPE_EX(in, varianceShape, block.workspace()); - return shapeList; - } + auto shapeList = SHAPELIST(); + shapeList->push_back(CONSTANT(meanShape)); + shapeList->push_back(CONSTANT(varianceShape)); - DECLARE_TYPES(normalize_moments) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ ALL_FLOATS }); - } - } + return shapeList; +} +DECLARE_TYPES(normalize_moments) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } +} // namespace ops + +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/nth_element.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/nth_element.cpp index b9326a981a72..1c843fbc270a 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/nth_element.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/nth_element.cpp @@ -22,58 +22,63 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(nth_element, 2, 1, false, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto n = INPUT_VARIABLE(1); - bool reverse = false; - if (block.getIArguments()->size() > 0) - reverse = (bool)INT_ARG(0); +namespace ops { +CUSTOM_OP_IMPL(nth_element, 2, 1, false, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto n = INPUT_VARIABLE(1); + bool reverse = false; + if (block.numI() > 0) reverse = (bool)INT_ARG(0); - auto output = OUTPUT_VARIABLE(0); - Nd4jLong lastDim = input->sizeAt(-1); - int nVal = n->e(0); - REQUIRE_TRUE(nVal < lastDim && nVal >= 0, 0, "nth_element: n should be non-negative and less than last dimension size (%lld), but %i was given.", lastDim, n); - REQUIRE_TRUE(input->rankOf() > 0, 0, "nth_element: The rank of input array should be at least 1, but %i is given", input->rankOf()); // - if (output->lengthOf() == input->lengthOf()) - output->assign(input); - else { -// if (!input->isVector() && reverse) -// n->assign(lastDim - n->e(0) - 1); - helpers::nthElementFunctor(block.launchContext(), input, nVal, output, reverse); - } - return ND4J_STATUS_OK; - } + auto output = OUTPUT_VARIABLE(0); + Nd4jLong lastDim = input->sizeAt(-1); + int nVal = n->e(0); + REQUIRE_TRUE(nVal < lastDim && nVal >= 0, 0, + "nth_element: n should be non-negative and less than last " + "dimension size (%lld), but %i was given.", + lastDim, n); + REQUIRE_TRUE(input->rankOf() > 0, 0, + "nth_element: The rank of input array should be at least 1, but " + "%i is given", + input->rankOf()); // + if (output->lengthOf() == input->lengthOf()) + output->assign(input); + else { + // if (!input->isVector() && reverse) + // n->assign(lastDim - n->e(0) - 1); + helpers::nthElementFunctor(block.launchContext(), input, nVal, output, + reverse); + } + return ND4J_STATUS_OK; +} - DECLARE_SHAPE_FN(nth_element) { +DECLARE_SHAPE_FN(nth_element) { + auto in = inputShape->at(0); + int outRank = shape::rank(in) - 1; + Nd4jLong const* outShape = nullptr; + if (outRank > 1) { + Nd4jLong* outputShape = nullptr; + ALLOCATE(outputShape, block.workspace(), shape::shapeInfoLength(outRank), + Nd4jLong); + outputShape[0] = outRank; + for (Nd4jLong e = 0; e < outRank; e++) outputShape[e + 1] = in[e + 1]; - auto in = inputShape->at(0); - int outRank = shape::rank(in) - 1; - Nd4jLong const* outShape = nullptr; - if (outRank > 1) { - Nd4jLong *outputShape = nullptr; - ALLOCATE(outputShape, block.getWorkspace(), shape::shapeInfoLength(outRank), Nd4jLong); - outputShape[0] = outRank; - for (Nd4jLong e = 0; e < outRank; e++) - outputShape[e + 1] = in[e + 1]; + ShapeUtils::updateStridesAndType(outputShape, in, shape::order(in)); + outShape = CONSTANT(outputShape); + } else if (outRank == 1) { + outShape = ConstantShapeHelper::getInstance().vectorShapeInfo( + shape::sizeAt(in, 0), ArrayOptions::dataType(in)); + } else { + // outputShape = shape::createScalarShapeInfo(); + outShape = ConstantShapeHelper::getInstance().scalarShapeInfo( + ArrayOptions::dataType(in)); + } + return SHAPELIST(outShape); +} +DECLARE_TYPES(nth_element) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes(sd::DataType::ANY); +} - ShapeUtils::updateStridesAndType(outputShape, in, shape::order(in)); - outShape = CONSTANT(outputShape); - } - else if (outRank == 1) { - outShape = ConstantShapeHelper::getInstance().vectorShapeInfo(shape::sizeAt(in, 0), ArrayOptions::dataType(in)); - } - else { - //outputShape = shape::createScalarShapeInfo(); - outShape = ConstantShapeHelper::getInstance().scalarShapeInfo(ArrayOptions::dataType(in)); - } - return SHAPELIST(outShape); - } - DECLARE_TYPES(nth_element) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes(sd::DataType::ANY); - } - - } -} \ No newline at end of file +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/onehot.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/onehot.cpp index 5b25ea7e6608..12d061a896a4 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/onehot.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/onehot.cpp @@ -21,96 +21,89 @@ #include #if NOT_EXCLUDED(OP_onehot) -#include #include +#include #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(onehot, 1, 1, false, -2, -2) { - auto input = INPUT_VARIABLE(0); - - // FIXME: double? - double on(1.0f); // T_ARG(0); - double off(0.0f); //T_ARG(1); +namespace ops { +CUSTOM_OP_IMPL(onehot, 1, 1, false, -2, -2) { + auto input = INPUT_VARIABLE(0); - auto depth = -1; //INT_ARG(0); - auto axis = -1; //INT_ARG(1); + // FIXME: double? + double on(1.0f); // T_ARG(0); + double off(0.0f); // T_ARG(1); - if (block.numI() > 0) - axis = INT_ARG(0); + auto depth = -1; // INT_ARG(0); + auto axis = -1; // INT_ARG(1); - if (block.numI() > 1) { - depth = INT_ARG(1); - } else if (block.width() > 1) { - depth = INPUT_VARIABLE(1)->e(0); - } + if (block.numI() > 0) axis = INT_ARG(0); - REQUIRE_TRUE(depth > 0, 0, "OneHot: depth must be positive value"); + if (block.numI() > 1) { + depth = INT_ARG(1); + } else if (block.width() > 1) { + depth = INPUT_VARIABLE(1)->e(0); + } + REQUIRE_TRUE(depth > 0, 0, "OneHot: depth must be positive value"); - if (block.width() > 2) { - on = INPUT_VARIABLE(2)->e(0); + if (block.width() > 2) { + on = INPUT_VARIABLE(2)->e(0); - if (block.width() > 3) - off = INPUT_VARIABLE(3)->e(0); - } else if (block.numT() > 0) { - on = T_ARG(0); + if (block.width() > 3) off = INPUT_VARIABLE(3)->e(0); + } else if (block.numT() > 0) { + on = T_ARG(0); - if (block.numT() > 1) - off = T_ARG(1); - } + if (block.numT() > 1) off = T_ARG(1); + } - auto output = OUTPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - if (axis < 0) - axis = output->rankOf() + axis; + if (axis < 0) axis = output->rankOf() + axis; - helpers::onehot(block.launchContext(), input, output, axis, depth, on, off); + helpers::onehot(block.launchContext(), input, output, axis, depth, on, off); - return Status::OK(); - } + return Status::OK(); +} - DECLARE_SHAPE_FN(onehot) { - auto inShape = inputShape->at(0); +DECLARE_SHAPE_FN(onehot) { + auto inShape = inputShape->at(0); - sd::DataType dtype = block.numD() > 0 ? D_ARG(0) : sd::DataType::FLOAT32; + sd::DataType dtype = block.numD() > 0 ? D_ARG(0) : sd::DataType::FLOAT32; - int depth = -1; - Nd4jLong axis = -1; + int depth = -1; + Nd4jLong axis = -1; - if (block.numI() > 0) - axis = INT_ARG(0); + if (block.numI() > 0) axis = INT_ARG(0); - if (block.numI() > 1) { - depth = INT_ARG(1); - } else if (block.width() > 1) { - depth = INPUT_VARIABLE(1)->e(0); - } + if (block.numI() > 1) { + depth = INT_ARG(1); + } else if (block.width() > 1) { + depth = INPUT_VARIABLE(1)->e(0); + } - REQUIRE_TRUE(depth > 0, 0, "OneHot: depth must be positive value"); + REQUIRE_TRUE(depth > 0, 0, "OneHot: depth must be positive value"); - int rank = shape::rank(inShape); + int rank = shape::rank(inShape); - if (axis < 0) - axis = rank + 1 + axis; + if (axis < 0) axis = rank + 1 + axis; - std::vector shape; - for (int e = 0; e < rank; e++) - shape.push_back(shape::shapeOf(inShape)[e]); + std::vector shape; + for (int e = 0; e < rank; e++) shape.push_back(shape::shapeOf(inShape)[e]); - shape.insert(shape.begin() + axis, depth); - auto newShape = ConstantShapeHelper::getInstance().createShapeInfo(dtype, 'c', rank + 1, shape.data()); + shape.insert(shape.begin() + axis, depth); + auto newShape = ConstantShapeHelper::getInstance().createShapeInfo( + dtype, 'c', rank + 1, shape.data()); - return SHAPELIST(newShape); - } + return SHAPELIST(newShape); +} - DECLARE_TYPES(onehot) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS, ALL_INTS}); - } - } +DECLARE_TYPES(onehot) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS, ALL_INTS}); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/rint.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/rint.cpp index 1e1058397f08..ed3ef117832d 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/rint.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/rint.cpp @@ -24,22 +24,22 @@ #include namespace sd { - namespace ops { - OP_IMPL(rint, 1, 1, true) { - auto x = INPUT_VARIABLE(0); - auto z = OUTPUT_VARIABLE(0); - - x->applyTransform(transform::Rint, *z); - - return Status::OK(); - } - } - - DECLARE_TYPES(rint) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } +namespace ops { +OP_IMPL(rint, 1, 1, true) { + auto x = INPUT_VARIABLE(0); + auto z = OUTPUT_VARIABLE(0); + + x->applyTransform(transform::Rint, *z); + + return Status::OK(); +} +} // namespace ops + +DECLARE_TYPES(rint) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/roll.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/roll.cpp index 75f102fa0cc3..84ef5988c763 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/roll.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/roll.cpp @@ -22,89 +22,98 @@ #if NOT_EXCLUDED(OP_roll) #include -#include #include +#include namespace sd { namespace ops { - CONFIGURABLE_OP_IMPL(roll, 1, 1, true, 0, 0) { - auto output = OUTPUT_VARIABLE(0); - auto input = INPUT_VARIABLE(0); - int inputLen = input->lengthOf(); +CONFIGURABLE_OP_IMPL(roll, 1, 1, true, 0, 0) { + auto output = OUTPUT_VARIABLE(0); + auto input = INPUT_VARIABLE(0); + int inputLen = input->lengthOf(); - bool shiftIsLinear = block.width() == 1; - std::vector axes; - std::vector shifts; - if (block.width() > 1) { - REQUIRE_TRUE(block.width() == 3, 0, "roll: 3 arguments required for roll - input, shifts and axes. But %i given.", block.width()); - auto axesI = INPUT_VARIABLE(2); - auto shiftsI = INPUT_VARIABLE(1); - REQUIRE_TRUE(axesI->rankOf() == shiftsI->rankOf(), 0, "roll: shifts and axes should be the same rank, but %i and %i given.", (int)shiftsI->rankOf(), (int)axesI->rankOf()); - REQUIRE_TRUE(axesI->lengthOf() == shiftsI->lengthOf(), 0, "roll: shifts and axes should be the same length, but %i and %i given.", (int)shiftsI->lengthOf(), (int)axesI->lengthOf()); - helpers::adjustAxis(axesI->lengthOf(), axesI, axes ); - shifts.resize(shiftsI->lengthOf()); - for (Nd4jLong i = 0; i < shiftsI->lengthOf(); i++) { - auto shift = shiftsI->e(i); - if (shift < 0) { - shift -= input->sizeAt(i) * (shift / inputLen - 1); - } - else { - shift %= input->sizeAt(i); - } - shifts[i] = shift; - } - - } - else { - int shift = INT_ARG(0); - if (shift < 0) { - // convert shift to positive value between 1 and inputLen - 1 - shift -= inputLen * (shift / inputLen - 1); - } - else - // cut shift to value between 1 and inputLen - 1 - shift %= inputLen; - axes.resize(block.getIArguments()->size() - 1); - if (axes.size()) - shifts.resize(axes.size());//emplace_back(shift); - else - shifts.push_back(shift); + bool shiftIsLinear = block.width() == 1; + std::vector axes; + std::vector shifts; + if (block.width() > 1) { + REQUIRE_TRUE(block.width() == 3, 0, + "roll: 3 arguments required for roll - input, shifts and " + "axes. But %i given.", + block.width()); + auto axesI = INPUT_VARIABLE(2); + auto shiftsI = INPUT_VARIABLE(1); + REQUIRE_TRUE( + axesI->rankOf() == shiftsI->rankOf(), 0, + "roll: shifts and axes should be the same rank, but %i and %i given.", + (int)shiftsI->rankOf(), (int)axesI->rankOf()); + REQUIRE_TRUE( + axesI->lengthOf() == shiftsI->lengthOf(), 0, + "roll: shifts and axes should be the same length, but %i and %i given.", + (int)shiftsI->lengthOf(), (int)axesI->lengthOf()); + helpers::adjustAxis(axesI->lengthOf(), axesI, axes); + shifts.resize(shiftsI->lengthOf()); + for (Nd4jLong i = 0; i < shiftsI->lengthOf(); i++) { + auto shift = shiftsI->e(i); + if (shift < 0) { + shift -= input->sizeAt(i) * (shift / inputLen - 1); + } else { + shift %= input->sizeAt(i); + } + shifts[i] = shift; + } - for (auto& s: shifts) - s = shift; + } else { + int shift = INT_ARG(0); + if (shift < 0) { + // convert shift to positive value between 1 and inputLen - 1 + shift -= inputLen * (shift / inputLen - 1); + } else + // cut shift to value between 1 and inputLen - 1 + shift %= inputLen; + axes.resize(block.numI() - 1); + if (axes.size()) + shifts.resize(axes.size()); // emplace_back(shift); + else + shifts.push_back(shift); - for (unsigned e = 0; e < axes.size(); ++e) { - int axis = INT_ARG(e + 1); - REQUIRE_TRUE(axis < input->rankOf() && axis >= -input->rankOf(), 0, "roll: axe value should be between -%i and %i, but %i was given.", - input->rankOf(), input->rankOf() - 1, axis); - axes[e] = (axis < 0? (input->rankOf() + axis) : axis); - } - } + for (auto& s : shifts) s = shift; - if (block.isInplace()) output = input; + for (unsigned e = 0; e < axes.size(); ++e) { + int axis = INT_ARG(e + 1); + REQUIRE_TRUE( + axis < input->rankOf() && axis >= -input->rankOf(), 0, + "roll: axe value should be between -%i and %i, but %i was given.", + input->rankOf(), input->rankOf() - 1, axis); + axes[e] = (axis < 0 ? (input->rankOf() + axis) : axis); + } + } - shiftIsLinear = (axes.size() == 0) || (input->rankOf() == 1); + if (block.isInplace()) output = input; - if (shiftIsLinear) { - helpers::rollFunctorLinear(block.launchContext(), input, output, shifts[0], block.isInplace()); - } - else { - helpers::rollFunctorFull(block.launchContext(), input, output, shifts, axes, block.isInplace()); - } + shiftIsLinear = (axes.size() == 0) || (input->rankOf() == 1); - return Status::OK(); - } + if (shiftIsLinear) { + helpers::rollFunctorLinear(block.launchContext(), input, output, shifts[0], + block.isInplace()); + } else { + helpers::rollFunctorFull(block.launchContext(), input, output, shifts, axes, + block.isInplace()); + } - DECLARE_TYPES(roll) { - getOpDescriptor() - ->setAllowedInputTypes(0,sd::DataType::ANY) - ->setAllowedInputTypes(1,sd::DataType::INT32) // TODO: all ints in future - ->setAllowedInputTypes(2,sd::DataType::INT32) - ->setAllowedOutputTypes(sd::DataType::ANY) - ->setSameMode(true); - } + return Status::OK(); } + +DECLARE_TYPES(roll) { + getOpDescriptor() + ->setAllowedInputTypes(0, sd::DataType::ANY) + ->setAllowedInputTypes(1, + sd::DataType::INT32) // TODO: all ints in future + ->setAllowedInputTypes(2, sd::DataType::INT32) + ->setAllowedOutputTypes(sd::DataType::ANY) + ->setSameMode(true); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/segment_max.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/segment_max.cpp index b348c4549643..4d65859fc49d 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/segment_max.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/segment_max.cpp @@ -22,84 +22,96 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(segment_max, 2, 1, false, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto idxSegments = INPUT_VARIABLE(1); - auto segmentedOutput = OUTPUT_VARIABLE(0); - REQUIRE_TRUE(idxSegments->isVector(), 0, "segment_max: segment indexes array should be a vector, but it rank is %i.", idxSegments->rankOf()); - REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, "segment_max: segment indexes array length should be equal to the input first dimension, but %i != %i.", idxSegments->lengthOf(), input->sizeAt(0)); - - - auto expected = NDArrayFactory::create(input->dataType(), 0.f, block.launchContext()); - auto wrong = NDArrayFactory::create(input->dataType(), 0.f, block.launchContext()); - - REQUIRE_TRUE(helpers::segmentIndicesValidate(block.launchContext(), idxSegments, expected, wrong), 0, "segment_max: segment indices should be arranged, but %2.1f > %2.1f", expected.e(0), wrong.e(0)); - - segmentedOutput->nullify(); - helpers::segmentMaxFunctor(block.launchContext(), input, idxSegments, segmentedOutput); - - return Status::OK(); - } - - DECLARE_SHAPE_FN(segment_max) { - auto idxVector = INPUT_VARIABLE(1); - - auto in = inputShape->at(0); - int outRank = shape::rank(in); - Nd4jLong* outputShape = nullptr; - idxVector->syncToHost(); - int val = (*idxVector).e(idxVector->lengthOf() - 1); - - int numOfClasses = val + 1; - - ALLOCATE(outputShape, block.getWorkspace(), shape::shapeInfoLength(outRank), Nd4jLong); - - outputShape[0] = outRank; - outputShape[1] = numOfClasses; - for(int i = 1; i < outRank; ++i) - outputShape[i + 1] = shape::sizeAt(in, i); - - ShapeUtils::updateStridesAndType(outputShape, in, shape::order(in)); - - return SHAPELIST(CONSTANT(outputShape)); - } - - DECLARE_TYPES(segment_max) { - getOpDescriptor() - ->setAllowedInputTypes(0, {ALL_INTS, ALL_FLOATS}) - ->setAllowedInputTypes(1, {ALL_INTS}) - ->setAllowedOutputTypes({ALL_FLOATS, ALL_INTS}) - ->setSameMode(false); - } - CUSTOM_OP_IMPL(segment_max_bp, 3, 2, false, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto indices = INPUT_VARIABLE(1); - auto gradOut = INPUT_VARIABLE(2); - auto output = OUTPUT_NULLIFIED(0); - auto outIndices = OUTPUT_NULLIFIED(1); - outIndices->assign(indices); - return helpers::segmentMaxFunctorBP(block.launchContext(), input, indices, gradOut, output); - } - DECLARE_SHAPE_FN(segment_max_bp){ - auto in = inputShape->at(0); - auto inIdx = inputShape->at(1); - - Nd4jLong* outShape; - Nd4jLong* outIndex; - COPY_SHAPE(in, outShape); - COPY_SHAPE(inIdx, outIndex); - return SHAPELIST(CONSTANT(outShape), CONSTANT(outIndex)); - - } - DECLARE_TYPES(segment_max_bp) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes(0, {ALL_FLOATS}) - ->setAllowedOutputTypes(1, {ALL_INTS}) - ->setSameMode(true); - } - - } +namespace ops { +CUSTOM_OP_IMPL(segment_max, 2, 1, false, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto idxSegments = INPUT_VARIABLE(1); + auto segmentedOutput = OUTPUT_VARIABLE(0); + REQUIRE_TRUE(idxSegments->isVector(), 0, + "segment_max: segment indexes array should be a vector, but it " + "rank is %i.", + idxSegments->rankOf()); + REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, + "segment_max: segment indexes array length should be equal to " + "the input first dimension, but %i != %i.", + idxSegments->lengthOf(), input->sizeAt(0)); + + auto expected = + NDArrayFactory::create(input->dataType(), 0.f, block.launchContext()); + auto wrong = + NDArrayFactory::create(input->dataType(), 0.f, block.launchContext()); + + REQUIRE_TRUE( + helpers::segmentIndicesValidate(block.launchContext(), idxSegments, + expected, wrong), + 0, "segment_max: segment indices should be arranged, but %2.1f > %2.1f", + expected.e(0), wrong.e(0)); + + segmentedOutput->nullify(); + helpers::segmentMaxFunctor(block.launchContext(), input, idxSegments, + segmentedOutput); + + return Status::OK(); +} + +DECLARE_SHAPE_FN(segment_max) { + auto idxVector = INPUT_VARIABLE(1); + + auto in = inputShape->at(0); + int outRank = shape::rank(in); + Nd4jLong* outputShape = nullptr; + idxVector->syncToHost(); + int val = (*idxVector).e(idxVector->lengthOf() - 1); + + int numOfClasses = val + 1; + + ALLOCATE(outputShape, block.workspace(), shape::shapeInfoLength(outRank), + Nd4jLong); + + outputShape[0] = outRank; + outputShape[1] = numOfClasses; + for (int i = 1; i < outRank; ++i) outputShape[i + 1] = shape::sizeAt(in, i); + + ShapeUtils::updateStridesAndType(outputShape, in, shape::order(in)); + return SHAPELIST(CONSTANT(outputShape)); } + +DECLARE_TYPES(segment_max) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_INTS, ALL_FLOATS}) + ->setAllowedInputTypes(1, {ALL_INTS}) + ->setAllowedOutputTypes({ALL_FLOATS, ALL_INTS}) + ->setSameMode(false); +} +CUSTOM_OP_IMPL(segment_max_bp, 3, 2, false, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto indices = INPUT_VARIABLE(1); + auto gradOut = INPUT_VARIABLE(2); + auto output = OUTPUT_NULLIFIED(0); + auto outIndices = OUTPUT_NULLIFIED(1); + outIndices->assign(indices); + return helpers::segmentMaxFunctorBP(block.launchContext(), input, indices, + gradOut, output); +} +DECLARE_SHAPE_FN(segment_max_bp) { + auto in = inputShape->at(0); + auto inIdx = inputShape->at(1); + + Nd4jLong* outShape; + Nd4jLong* outIndex; + COPY_SHAPE(in, outShape); + COPY_SHAPE(inIdx, outIndex); + return SHAPELIST(CONSTANT(outShape), CONSTANT(outIndex)); +} +DECLARE_TYPES(segment_max_bp) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes(0, {ALL_FLOATS}) + ->setAllowedOutputTypes(1, {ALL_INTS}) + ->setSameMode(true); +} + +} // namespace ops + +} // namespace sd diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/segment_mean.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/segment_mean.cpp index 1d8a5bb7f787..2e7ccbad8fd6 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/segment_mean.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/segment_mean.cpp @@ -22,82 +22,95 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(segment_mean, 2, 1, false, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto idxSegments = INPUT_VARIABLE(1); - auto segmentedOutput = OUTPUT_VARIABLE(0); - REQUIRE_TRUE(idxSegments->isVector(), 0, "segment_mean: segment indexes array should be a vector, but it rank is %i.", idxSegments->rankOf()); - REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, "segment_mean: segment indexes array length should be equal to the input first dimension, but %i != %i.", idxSegments->lengthOf(), input->sizeAt(0)); - - auto expected = NDArrayFactory::create(input->dataType(), 0.f, block.launchContext()); - auto wrong = NDArrayFactory::create(input->dataType(), 0.f, block.launchContext()); - - REQUIRE_TRUE(helpers::segmentIndicesValidate(block.launchContext(), idxSegments, expected, wrong), 0, "segment_mean: segment indices should be arranged, but %2.1f > %2.1f", expected.e(0), wrong.e(0)); - - segmentedOutput->nullify(); - helpers::segmentMeanFunctor(block.launchContext(), input, idxSegments, segmentedOutput); - - return Status::OK(); - } - - DECLARE_SHAPE_FN(segment_mean) { - auto idxVector = INPUT_VARIABLE(1); - - auto in = inputShape->at(0); - int outRank = shape::rank(in); - Nd4jLong* outputShape = nullptr; - int val = (*idxVector).e(idxVector->lengthOf() - 1); - - int numOfClasses = val + 1; - - ALLOCATE(outputShape, block.getWorkspace(), shape::shapeInfoLength(outRank), Nd4jLong); - - outputShape[0] = outRank; - outputShape[1] = numOfClasses; - for(int i = 1; i < outRank; ++i) - outputShape[i + 1] = shape::sizeAt(in, i); - - ShapeUtils::updateStridesAndType(outputShape, in, shape::order(in)); - - return SHAPELIST(CONSTANT(outputShape)); - } - - DECLARE_TYPES(segment_mean) { - getOpDescriptor() - ->setAllowedInputTypes({ALL_INTS, ALL_FLOATS}) - ->setAllowedInputTypes(1, {ALL_INTS}) - ->setAllowedOutputTypes({ALL_FLOATS}) - ->setSameMode(false); - } - - - CUSTOM_OP_IMPL(segment_mean_bp, 3, 2, false, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto indices = INPUT_VARIABLE(1); - auto gradOut = INPUT_VARIABLE(2); - auto output = OUTPUT_NULLIFIED(0); - auto outIndices = OUTPUT_NULLIFIED(1); - outIndices->assign(indices); - return helpers::segmentMeanFunctorBP(block.launchContext(), input, indices, gradOut, output); - } - DECLARE_SHAPE_FN(segment_mean_bp){ - auto in = inputShape->at(0); - auto inIdx = inputShape->at(1); - - Nd4jLong* outShape; - Nd4jLong* outIndex; - COPY_SHAPE(in, outShape); - COPY_SHAPE(inIdx, outIndex); - return SHAPELIST(CONSTANT(outShape), CONSTANT(outIndex)); -// return SHAPELIST(in, inIdx); - } - DECLARE_TYPES(segment_mean_bp) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes(0, {ALL_FLOATS}) - ->setAllowedOutputTypes(1, {ALL_INTS}) - ->setSameMode(false); - } - } +namespace ops { +CUSTOM_OP_IMPL(segment_mean, 2, 1, false, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto idxSegments = INPUT_VARIABLE(1); + auto segmentedOutput = OUTPUT_VARIABLE(0); + REQUIRE_TRUE(idxSegments->isVector(), 0, + "segment_mean: segment indexes array should be a vector, but it " + "rank is %i.", + idxSegments->rankOf()); + REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, + "segment_mean: segment indexes array length should be equal to " + "the input first dimension, but %i != %i.", + idxSegments->lengthOf(), input->sizeAt(0)); + + auto expected = + NDArrayFactory::create(input->dataType(), 0.f, block.launchContext()); + auto wrong = + NDArrayFactory::create(input->dataType(), 0.f, block.launchContext()); + + REQUIRE_TRUE( + helpers::segmentIndicesValidate(block.launchContext(), idxSegments, + expected, wrong), + 0, "segment_mean: segment indices should be arranged, but %2.1f > %2.1f", + expected.e(0), wrong.e(0)); + + segmentedOutput->nullify(); + helpers::segmentMeanFunctor(block.launchContext(), input, idxSegments, + segmentedOutput); + + return Status::OK(); } + +DECLARE_SHAPE_FN(segment_mean) { + auto idxVector = INPUT_VARIABLE(1); + + auto in = inputShape->at(0); + int outRank = shape::rank(in); + Nd4jLong* outputShape = nullptr; + int val = (*idxVector).e(idxVector->lengthOf() - 1); + + int numOfClasses = val + 1; + + ALLOCATE(outputShape, block.workspace(), shape::shapeInfoLength(outRank), + Nd4jLong); + + outputShape[0] = outRank; + outputShape[1] = numOfClasses; + for (int i = 1; i < outRank; ++i) outputShape[i + 1] = shape::sizeAt(in, i); + + ShapeUtils::updateStridesAndType(outputShape, in, shape::order(in)); + + return SHAPELIST(CONSTANT(outputShape)); +} + +DECLARE_TYPES(segment_mean) { + getOpDescriptor() + ->setAllowedInputTypes({ALL_INTS, ALL_FLOATS}) + ->setAllowedInputTypes(1, {ALL_INTS}) + ->setAllowedOutputTypes({ALL_FLOATS}) + ->setSameMode(false); +} + +CUSTOM_OP_IMPL(segment_mean_bp, 3, 2, false, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto indices = INPUT_VARIABLE(1); + auto gradOut = INPUT_VARIABLE(2); + auto output = OUTPUT_NULLIFIED(0); + auto outIndices = OUTPUT_NULLIFIED(1); + outIndices->assign(indices); + return helpers::segmentMeanFunctorBP(block.launchContext(), input, indices, + gradOut, output); +} +DECLARE_SHAPE_FN(segment_mean_bp) { + auto in = inputShape->at(0); + auto inIdx = inputShape->at(1); + + Nd4jLong* outShape; + Nd4jLong* outIndex; + COPY_SHAPE(in, outShape); + COPY_SHAPE(inIdx, outIndex); + return SHAPELIST(CONSTANT(outShape), CONSTANT(outIndex)); + // return SHAPELIST(in, inIdx); +} +DECLARE_TYPES(segment_mean_bp) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes(0, {ALL_FLOATS}) + ->setAllowedOutputTypes(1, {ALL_INTS}) + ->setSameMode(false); +} +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/segment_min.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/segment_min.cpp index 10bc1dd26c55..d78eedc7f209 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/segment_min.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/segment_min.cpp @@ -22,80 +22,94 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(segment_min, 2, 1, false, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto idxSegments = INPUT_VARIABLE(1); - auto segmentedOutput = OUTPUT_VARIABLE(0); - REQUIRE_TRUE(idxSegments->isVector(), 0, "segment_min: segment indexes array should be a vector, but it rank is %i.", idxSegments->rankOf()); - REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, "segment_min: segment indexes array length should be equal to the input first dimension, but %i != %i.", idxSegments->lengthOf(), input->sizeAt(0)); - - auto expected = NDArrayFactory::create(input->dataType(), 0.f, block.launchContext()); - auto wrong = NDArrayFactory::create(input->dataType(), 0.f, block.launchContext()); - - REQUIRE_TRUE(helpers::segmentIndicesValidate(block.launchContext(), idxSegments, expected, wrong), 0, "segment_min: segment indices should be arranged, but %2.1f > %2.1f", expected.e(0), wrong.e(0)); - - segmentedOutput->nullify(); - helpers::segmentMinFunctor(block.launchContext(), input, idxSegments, segmentedOutput); - - return Status::OK(); - } - - DECLARE_SHAPE_FN(segment_min) { - auto idxVector = INPUT_VARIABLE(1); - - auto in = inputShape->at(0); - int outRank = shape::rank(in); - Nd4jLong* outputShape = nullptr; - int val = (*idxVector).e(idxVector->lengthOf() - 1); - - int numOfClasses = val + 1; - - ALLOCATE(outputShape, block.getWorkspace(), shape::shapeInfoLength(outRank), Nd4jLong); - - outputShape[0] = outRank; - outputShape[1] = numOfClasses; - for(int i = 1; i < outRank; ++i) - outputShape[i + 1] = shape::sizeAt(in, i); - - ShapeUtils::updateStridesAndType(outputShape, in, shape::order(in)); - - return SHAPELIST(CONSTANT(outputShape)); - } - CUSTOM_OP_IMPL(segment_min_bp, 3, 2, false, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto indices = INPUT_VARIABLE(1); - auto gradOut = INPUT_VARIABLE(2); - auto output = OUTPUT_NULLIFIED(0); - auto outIndices = OUTPUT_NULLIFIED(1); - outIndices->assign(indices); - return helpers::segmentMinFunctorBP(block.launchContext(), input, indices, gradOut, output); - } - DECLARE_SHAPE_FN(segment_min_bp){ - auto in = inputShape->at(0); - auto inIdx = inputShape->at(1); - - Nd4jLong* outShape; - Nd4jLong* outIndex; - COPY_SHAPE(in, outShape); - COPY_SHAPE(inIdx, outIndex); - return SHAPELIST(CONSTANT(outShape), CONSTANT(outIndex)); -// return SHAPELIST(in, inIdx); - } - - DECLARE_TYPES(segment_min) { - getOpDescriptor() - ->setAllowedInputTypes(0, {ALL_FLOATS, ALL_INTS}) - ->setAllowedInputTypes(1, {ALL_INTS}) - ->setAllowedOutputTypes({ALL_FLOATS, ALL_INTS}) - ->setSameMode(false); - } - DECLARE_TYPES(segment_min_bp) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes(0, {ALL_FLOATS}) - ->setAllowedOutputTypes(1, {ALL_INTS}) - ->setSameMode(true); - } - } +namespace ops { +CUSTOM_OP_IMPL(segment_min, 2, 1, false, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto idxSegments = INPUT_VARIABLE(1); + auto segmentedOutput = OUTPUT_VARIABLE(0); + REQUIRE_TRUE(idxSegments->isVector(), 0, + "segment_min: segment indexes array should be a vector, but it " + "rank is %i.", + idxSegments->rankOf()); + REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, + "segment_min: segment indexes array length should be equal to " + "the input first dimension, but %i != %i.", + idxSegments->lengthOf(), input->sizeAt(0)); + + auto expected = + NDArrayFactory::create(input->dataType(), 0.f, block.launchContext()); + auto wrong = + NDArrayFactory::create(input->dataType(), 0.f, block.launchContext()); + + REQUIRE_TRUE( + helpers::segmentIndicesValidate(block.launchContext(), idxSegments, + expected, wrong), + 0, "segment_min: segment indices should be arranged, but %2.1f > %2.1f", + expected.e(0), wrong.e(0)); + + segmentedOutput->nullify(); + helpers::segmentMinFunctor(block.launchContext(), input, idxSegments, + segmentedOutput); + + return Status::OK(); } + +DECLARE_SHAPE_FN(segment_min) { + auto idxVector = INPUT_VARIABLE(1); + + auto in = inputShape->at(0); + int outRank = shape::rank(in); + Nd4jLong* outputShape = nullptr; + int val = (*idxVector).e(idxVector->lengthOf() - 1); + + int numOfClasses = val + 1; + + ALLOCATE(outputShape, block.workspace(), shape::shapeInfoLength(outRank), + Nd4jLong); + + outputShape[0] = outRank; + outputShape[1] = numOfClasses; + for (int i = 1; i < outRank; ++i) outputShape[i + 1] = shape::sizeAt(in, i); + + ShapeUtils::updateStridesAndType(outputShape, in, shape::order(in)); + + return SHAPELIST(CONSTANT(outputShape)); +} +CUSTOM_OP_IMPL(segment_min_bp, 3, 2, false, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto indices = INPUT_VARIABLE(1); + auto gradOut = INPUT_VARIABLE(2); + auto output = OUTPUT_NULLIFIED(0); + auto outIndices = OUTPUT_NULLIFIED(1); + outIndices->assign(indices); + return helpers::segmentMinFunctorBP(block.launchContext(), input, indices, + gradOut, output); +} +DECLARE_SHAPE_FN(segment_min_bp) { + auto in = inputShape->at(0); + auto inIdx = inputShape->at(1); + + Nd4jLong* outShape; + Nd4jLong* outIndex; + COPY_SHAPE(in, outShape); + COPY_SHAPE(inIdx, outIndex); + return SHAPELIST(CONSTANT(outShape), CONSTANT(outIndex)); + // return SHAPELIST(in, inIdx); +} + +DECLARE_TYPES(segment_min) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_FLOATS, ALL_INTS}) + ->setAllowedInputTypes(1, {ALL_INTS}) + ->setAllowedOutputTypes({ALL_FLOATS, ALL_INTS}) + ->setSameMode(false); +} +DECLARE_TYPES(segment_min_bp) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes(0, {ALL_FLOATS}) + ->setAllowedOutputTypes(1, {ALL_INTS}) + ->setSameMode(true); +} +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/segment_prod.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/segment_prod.cpp index 4f83ac9b0dc3..9cbfea381f14 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/segment_prod.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/segment_prod.cpp @@ -22,85 +22,98 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(segment_prod, 2, 1, false, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto idxSegments = INPUT_VARIABLE(1); - auto segmentedOutput = OUTPUT_VARIABLE(0); - REQUIRE_TRUE(idxSegments->isVector(), 0, "segment_prod: segment indexes array should be a vector, but it rank is %i.", idxSegments->rankOf()); - REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, "segment_prod: segment indexes array length should be equal to the input first dimension, but %i != %i.", idxSegments->lengthOf(), input->sizeAt(0)); - - auto expected = NDArrayFactory::create(input->dataType(), 0.f, block.launchContext()); - auto wrong = NDArrayFactory::create(input->dataType(), 0.f, block.launchContext()); - - REQUIRE_TRUE(helpers::segmentIndicesValidate(block.launchContext(), idxSegments, expected, wrong), 0, "segment_prod: segment indices should be arranged, but %2.1f > %2.1f", expected.e(0), wrong.e(0)); - - segmentedOutput->nullify(); - helpers::segmentProdFunctor(block.launchContext(), input, idxSegments, segmentedOutput); - - return Status::OK(); - } - - DECLARE_SHAPE_FN(segment_prod) { - auto idxVector = INPUT_VARIABLE(1); - - auto in = inputShape->at(0); - int outRank = shape::rank(in); - Nd4jLong* outputShape = nullptr; - int val = (*idxVector).e(idxVector->lengthOf() - 1); - - int numOfClasses = val + 1; - - ALLOCATE(outputShape, block.getWorkspace(), shape::shapeInfoLength(outRank), Nd4jLong); - - outputShape[0] = outRank; - outputShape[1] = numOfClasses; - for(int i = 1; i < outRank; ++i) - outputShape[i + 1] = shape::sizeAt(in, i); - - ShapeUtils::updateStridesAndType(outputShape, in, shape::order(in)); - - return SHAPELIST(CONSTANT(outputShape)); - } - - CUSTOM_OP_IMPL(segment_prod_bp, 3, 2, false, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto indices = INPUT_VARIABLE(1); - auto gradOut = INPUT_VARIABLE(2); - auto output = OUTPUT_NULLIFIED(0); - auto outIndices = OUTPUT_NULLIFIED(1); - outIndices->assign(indices); - helpers::segmentProdFunctorBP(block.launchContext(), input, indices, gradOut, output); - - return Status::OK(); - } - - DECLARE_TYPES(segment_prod) { - getOpDescriptor() - ->setAllowedInputTypes(0, {ALL_FLOATS, ALL_INTS}) - ->setAllowedInputTypes(1, {ALL_INTS}) - ->setAllowedOutputTypes({ALL_FLOATS, ALL_INTS}) - ->setSameMode(false); - } - - - DECLARE_SHAPE_FN(segment_prod_bp){ - auto in = inputShape->at(0); - auto inIdx = inputShape->at(1); - - Nd4jLong* outShape; - Nd4jLong* outIndex; - COPY_SHAPE(in, outShape); - COPY_SHAPE(inIdx, outIndex); - return SHAPELIST(CONSTANT(outShape), CONSTANT(outIndex)); - } - - DECLARE_TYPES(segment_prod_bp) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes(0, {ALL_FLOATS}) - ->setAllowedOutputTypes(1, {ALL_INTS}) - ->setSameMode(false); - } - } +namespace ops { +CUSTOM_OP_IMPL(segment_prod, 2, 1, false, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto idxSegments = INPUT_VARIABLE(1); + auto segmentedOutput = OUTPUT_VARIABLE(0); + REQUIRE_TRUE(idxSegments->isVector(), 0, + "segment_prod: segment indexes array should be a vector, but it " + "rank is %i.", + idxSegments->rankOf()); + REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, + "segment_prod: segment indexes array length should be equal to " + "the input first dimension, but %i != %i.", + idxSegments->lengthOf(), input->sizeAt(0)); + + auto expected = + NDArrayFactory::create(input->dataType(), 0.f, block.launchContext()); + auto wrong = + NDArrayFactory::create(input->dataType(), 0.f, block.launchContext()); + + REQUIRE_TRUE( + helpers::segmentIndicesValidate(block.launchContext(), idxSegments, + expected, wrong), + 0, "segment_prod: segment indices should be arranged, but %2.1f > %2.1f", + expected.e(0), wrong.e(0)); + + segmentedOutput->nullify(); + helpers::segmentProdFunctor(block.launchContext(), input, idxSegments, + segmentedOutput); + + return Status::OK(); } + +DECLARE_SHAPE_FN(segment_prod) { + auto idxVector = INPUT_VARIABLE(1); + + auto in = inputShape->at(0); + int outRank = shape::rank(in); + Nd4jLong* outputShape = nullptr; + int val = (*idxVector).e(idxVector->lengthOf() - 1); + + int numOfClasses = val + 1; + + ALLOCATE(outputShape, block.workspace(), shape::shapeInfoLength(outRank), + Nd4jLong); + + outputShape[0] = outRank; + outputShape[1] = numOfClasses; + for (int i = 1; i < outRank; ++i) outputShape[i + 1] = shape::sizeAt(in, i); + + ShapeUtils::updateStridesAndType(outputShape, in, shape::order(in)); + + return SHAPELIST(CONSTANT(outputShape)); +} + +CUSTOM_OP_IMPL(segment_prod_bp, 3, 2, false, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto indices = INPUT_VARIABLE(1); + auto gradOut = INPUT_VARIABLE(2); + auto output = OUTPUT_NULLIFIED(0); + auto outIndices = OUTPUT_NULLIFIED(1); + outIndices->assign(indices); + helpers::segmentProdFunctorBP(block.launchContext(), input, indices, gradOut, + output); + + return Status::OK(); +} + +DECLARE_TYPES(segment_prod) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_FLOATS, ALL_INTS}) + ->setAllowedInputTypes(1, {ALL_INTS}) + ->setAllowedOutputTypes({ALL_FLOATS, ALL_INTS}) + ->setSameMode(false); +} + +DECLARE_SHAPE_FN(segment_prod_bp) { + auto in = inputShape->at(0); + auto inIdx = inputShape->at(1); + + Nd4jLong* outShape; + Nd4jLong* outIndex; + COPY_SHAPE(in, outShape); + COPY_SHAPE(inIdx, outIndex); + return SHAPELIST(CONSTANT(outShape), CONSTANT(outIndex)); +} + +DECLARE_TYPES(segment_prod_bp) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes(0, {ALL_FLOATS}) + ->setAllowedOutputTypes(1, {ALL_INTS}) + ->setSameMode(false); +} +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/segment_sum.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/segment_sum.cpp index cb4734c5fb8f..21b80d4e6e6c 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/segment_sum.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/segment_sum.cpp @@ -22,75 +22,87 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(segment_sum, 2, 1, false, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto idxSegments = INPUT_VARIABLE(1); - auto segmentedOutput = OUTPUT_VARIABLE(0); - REQUIRE_TRUE(idxSegments->isVector(), 0, "segment_sum: segment indexes array should be a vector, but it rank is %i.", idxSegments->rankOf()); - REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, "segment_sum: segment indexes array length should be equal to the input first dimension, but %i != %i.", idxSegments->lengthOf(), input->sizeAt(0)); - - auto expected = NDArrayFactory::create(input->dataType(), 0.f, block.launchContext()); - auto wrong = NDArrayFactory::create(input->dataType(), 0.f, block.launchContext()); - - REQUIRE_TRUE(helpers::segmentIndicesValidate(block.launchContext(), idxSegments, expected, wrong), 0, "segment_sum: segment indices should be arranged, but %2.1f > %2.1f", expected.e(0), wrong.e(0)); - - segmentedOutput->nullify(); - helpers::segmentSumFunctor(block.launchContext(), input, idxSegments, segmentedOutput); - - return ND4J_STATUS_OK; - } - - DECLARE_SHAPE_FN(segment_sum) { - auto idxVector = INPUT_VARIABLE(1); - - auto in = inputShape->at(0); - int outRank = shape::rank(in); - Nd4jLong* outputShape = nullptr; - int val = (*idxVector).e(idxVector->lengthOf() - 1); - - int numOfClasses = static_cast(val) + 1; - - ALLOCATE(outputShape, block.getWorkspace(), shape::shapeInfoLength(outRank), Nd4jLong); - - outputShape[0] = outRank; - outputShape[1] = numOfClasses; - for(int i = 1; i < outRank; ++i) - outputShape[i + 1] = shape::sizeAt(in, i); - - ShapeUtils::updateStridesAndType(outputShape, in, shape::order(in)); - - return SHAPELIST(CONSTANT(outputShape)); - } - - CUSTOM_OP_IMPL(segment_sum_bp, 3, 2, false, 0, 0) { - - return helpers::segmentSumFunctorBP(block.launchContext(), INPUT_VARIABLE(0), INPUT_VARIABLE(1), INPUT_VARIABLE(2), OUTPUT_NULLIFIED(0)); - } - DECLARE_SHAPE_FN(segment_sum_bp){ - auto in = inputShape->at(0); - auto inIdx = inputShape->at(1); - - Nd4jLong* outShape; - Nd4jLong* outIndex; - COPY_SHAPE(in, outShape); - COPY_SHAPE(inIdx, outIndex); - return SHAPELIST(CONSTANT(outShape), CONSTANT(outIndex)); -// return SHAPELIST(in, inIdx); - } - - DECLARE_TYPES(segment_sum) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setSameMode(true); - } - DECLARE_TYPES(segment_sum_bp) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes(0, {ALL_FLOATS}) - ->setAllowedOutputTypes(1, {ALL_INTS}) - ->setSameMode(false); - } - } +namespace ops { +CUSTOM_OP_IMPL(segment_sum, 2, 1, false, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto idxSegments = INPUT_VARIABLE(1); + auto segmentedOutput = OUTPUT_VARIABLE(0); + REQUIRE_TRUE(idxSegments->isVector(), 0, + "segment_sum: segment indexes array should be a vector, but it " + "rank is %i.", + idxSegments->rankOf()); + REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, + "segment_sum: segment indexes array length should be equal to " + "the input first dimension, but %i != %i.", + idxSegments->lengthOf(), input->sizeAt(0)); + + auto expected = + NDArrayFactory::create(input->dataType(), 0.f, block.launchContext()); + auto wrong = + NDArrayFactory::create(input->dataType(), 0.f, block.launchContext()); + + REQUIRE_TRUE( + helpers::segmentIndicesValidate(block.launchContext(), idxSegments, + expected, wrong), + 0, "segment_sum: segment indices should be arranged, but %2.1f > %2.1f", + expected.e(0), wrong.e(0)); + + segmentedOutput->nullify(); + helpers::segmentSumFunctor(block.launchContext(), input, idxSegments, + segmentedOutput); + + return ND4J_STATUS_OK; +} + +DECLARE_SHAPE_FN(segment_sum) { + auto idxVector = INPUT_VARIABLE(1); + + auto in = inputShape->at(0); + int outRank = shape::rank(in); + Nd4jLong* outputShape = nullptr; + int val = (*idxVector).e(idxVector->lengthOf() - 1); + + int numOfClasses = static_cast(val) + 1; + + ALLOCATE(outputShape, block.workspace(), shape::shapeInfoLength(outRank), + Nd4jLong); + + outputShape[0] = outRank; + outputShape[1] = numOfClasses; + for (int i = 1; i < outRank; ++i) outputShape[i + 1] = shape::sizeAt(in, i); + + ShapeUtils::updateStridesAndType(outputShape, in, shape::order(in)); + return SHAPELIST(CONSTANT(outputShape)); } + +CUSTOM_OP_IMPL(segment_sum_bp, 3, 2, false, 0, 0) { + return helpers::segmentSumFunctorBP(block.launchContext(), INPUT_VARIABLE(0), + INPUT_VARIABLE(1), INPUT_VARIABLE(2), + OUTPUT_NULLIFIED(0)); +} +DECLARE_SHAPE_FN(segment_sum_bp) { + auto in = inputShape->at(0); + auto inIdx = inputShape->at(1); + + Nd4jLong* outShape; + Nd4jLong* outIndex; + COPY_SHAPE(in, outShape); + COPY_SHAPE(inIdx, outIndex); + return SHAPELIST(CONSTANT(outShape), CONSTANT(outIndex)); + // return SHAPELIST(in, inIdx); +} + +DECLARE_TYPES(segment_sum) { + getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setSameMode(true); +} +DECLARE_TYPES(segment_sum_bp) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes(0, {ALL_FLOATS}) + ->setAllowedOutputTypes(1, {ALL_INTS}) + ->setSameMode(false); +} +} // namespace ops + +} // namespace sd diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/sequence_mask.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/sequence_mask.cpp index 6b0402ebb6ce..0bce3fd0b06c 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/sequence_mask.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/sequence_mask.cpp @@ -22,85 +22,78 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(sequence_mask, 1, 1, false, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_NULLIFIED(0); - const int inRank = input->rankOf(); +namespace ops { +CUSTOM_OP_IMPL(sequence_mask, 1, 1, false, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_NULLIFIED(0); + const int inRank = input->rankOf(); - //REQUIRE_TRUE(inRank >= 1, 0, "sequence_mask: input array must have rank >= 1, but %i given!", inRank); - Nd4jLong maxInd = input->argMax(); - float max = input->e(maxInd); - if (block.getIArguments()->size() > 0) { - maxInd = INT_ARG(0); - if (maxInd < max) - maxInd = static_cast(max); - } - else if (block.width() > 1) { - auto maxlen = INPUT_VARIABLE(1); - //REQUIRE_TRUE(maxlen->lengthOf() == 1, "sequence_mask: 2nd input (max length) should be a scalar array."); - float tmaxlen = maxlen->e(0); - if (tmaxlen > max) - maxInd = static_cast(tmaxlen); - } - else - maxInd = static_cast(max); + // REQUIRE_TRUE(inRank >= 1, 0, "sequence_mask: input array must have rank >= + // 1, but %i given!", inRank); + Nd4jLong maxInd = input->argMax(); + float max = input->e(maxInd); + if (block.numI() > 0) { + maxInd = INT_ARG(0); + if (maxInd < max) maxInd = static_cast(max); + } else if (block.width() > 1) { + auto maxlen = INPUT_VARIABLE(1); + // REQUIRE_TRUE(maxlen->lengthOf() == 1, "sequence_mask: 2nd input (max + // length) should be a scalar array."); + float tmaxlen = maxlen->e(0); + if (tmaxlen > max) maxInd = static_cast(tmaxlen); + } else + maxInd = static_cast(max); - helpers::sequenceMask(block.launchContext(), input, output, maxInd); + helpers::sequenceMask(block.launchContext(), input, output, maxInd); - return Status::OK(); - } - - DECLARE_SHAPE_FN(sequence_mask) { - - Nd4jLong* outShapeInfo = nullptr; - auto in = inputShape->at(0); - int outRank = shape::rank(in) + 1; - auto input = INPUT_VARIABLE(0); - auto dtype = DataType::BOOL; - auto argMaxInd = input->argMax(); - Nd4jLong max = input->e(argMaxInd); - Nd4jLong maxInd = max; + return Status::OK(); +} - if (block.numD() > 0) - dtype = D_ARG(0); +DECLARE_SHAPE_FN(sequence_mask) { + Nd4jLong* outShapeInfo = nullptr; + auto in = inputShape->at(0); + int outRank = shape::rank(in) + 1; + auto input = INPUT_VARIABLE(0); + auto dtype = DataType::BOOL; + auto argMaxInd = input->argMax(); + Nd4jLong max = input->e(argMaxInd); + Nd4jLong maxInd = max; - if (block.width() > 1) { - auto maxlen = INPUT_VARIABLE(1); - Nd4jLong tmaxlen = maxlen->e(0); - if (tmaxlen > max) - maxInd = static_cast(tmaxlen); - if (block.numI() > 0) { - dtype = (DataType) INT_ARG(0); - } - } - else { - if (block.numI() > 0) { - maxInd = INT_ARG(0); - } - if (maxInd < max) - maxInd = max; - if (block.numI() > 1) - dtype = (DataType)INT_ARG(1); // to work with legacy code - } + if (block.numD() > 0) dtype = D_ARG(0); - int lastDimension = maxInd; - ALLOCATE(outShapeInfo, block.getWorkspace(), shape::shapeInfoLength(outRank), Nd4jLong); - outShapeInfo[0] = outRank; - for(int i = 0; i < outRank - 1; ++i) - outShapeInfo[i + 1] = shape::sizeAt(in, i); - outShapeInfo[outRank] = lastDimension; + if (block.width() > 1) { + auto maxlen = INPUT_VARIABLE(1); + Nd4jLong tmaxlen = maxlen->e(0); + if (tmaxlen > max) maxInd = static_cast(tmaxlen); + if (block.numI() > 0) { + dtype = (DataType)INT_ARG(0); + } + } else { + if (block.numI() > 0) { + maxInd = INT_ARG(0); + } + if (maxInd < max) maxInd = max; + if (block.numI() > 1) + dtype = (DataType)INT_ARG(1); // to work with legacy code + } - ShapeUtils::updateStridesAndType(outShapeInfo, dtype, shape::order(in)); + int lastDimension = maxInd; + ALLOCATE(outShapeInfo, block.workspace(), shape::shapeInfoLength(outRank), + Nd4jLong); + outShapeInfo[0] = outRank; + for (int i = 0; i < outRank - 1; ++i) + outShapeInfo[i + 1] = shape::sizeAt(in, i); + outShapeInfo[outRank] = lastDimension; - return SHAPELIST(CONSTANT(outShapeInfo)); - } + ShapeUtils::updateStridesAndType(outShapeInfo, dtype, shape::order(in)); - DECLARE_TYPES(sequence_mask) { - getOpDescriptor() - ->setAllowedInputTypes({ALL_INTS}) - ->setAllowedOutputTypes(sd::DataType::ANY); - } -} + return SHAPELIST(CONSTANT(outShapeInfo)); } +DECLARE_TYPES(sequence_mask) { + getOpDescriptor() + ->setAllowedInputTypes({ALL_INTS}) + ->setAllowedOutputTypes(sd::DataType::ANY); +} +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/square.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/square.cpp index adc3a2cb59e7..f4dd1da572d2 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/square.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/square.cpp @@ -24,23 +24,21 @@ #include namespace sd { - namespace ops { - OP_IMPL(square, 1, 1, true) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - int extras = 2; - input->applyScalar(scalar::Pow, extras, *output); - - return Status::OK(); - } - - DECLARE_TYPES(square) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setSameMode(true); - } - } +namespace ops { +OP_IMPL(square, 1, 1, true) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + int extras = 2; + input->applyScalar(scalar::Pow, extras, *output); + + return Status::OK(); +} + +DECLARE_TYPES(square) { + getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setSameMode(true); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/stop_gradient.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/stop_gradient.cpp index 621df4462e3f..810f995709e5 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/stop_gradient.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/stop_gradient.cpp @@ -24,26 +24,24 @@ #include namespace sd { - namespace ops { - OP_IMPL(stop_gradient, 1, 1, true) { - auto out = OUTPUT_VARIABLE(0); - - if (!block.isInplace()) { - auto x = INPUT_VARIABLE(0); - // we hope for memcpy here - out->assign(x); - } - - return Status::OK(); - } - DECLARE_SYN(StopGradient, stop_gradient); - - DECLARE_TYPES(stop_gradient) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setSameMode(true); - } - } +namespace ops { +OP_IMPL(stop_gradient, 1, 1, true) { + auto out = OUTPUT_VARIABLE(0); + + if (!block.isInplace()) { + auto x = INPUT_VARIABLE(0); + // we hope for memcpy here + out->assign(x); + } + + return Status::OK(); +} +DECLARE_SYN(StopGradient, stop_gradient); + +DECLARE_TYPES(stop_gradient) { + getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setSameMode(true); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/top_k.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/top_k.cpp index b042e94fe60f..20df9a1d152c 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/top_k.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/top_k.cpp @@ -22,74 +22,77 @@ #if NOT_EXCLUDED(OP_top_k) //#include -#include #include +#include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(top_k, 1, 2, false, 0, -1) { - auto x = INPUT_VARIABLE(0); - int k = 1;// from params - bool needSort = true; - - auto values = OUTPUT_VARIABLE(0); - auto indices = OUTPUT_VARIABLE(1); - - if (block.numB() == 1) - needSort = B_ARG(0); - - if (block.width() == 1) { - if (block.numI() > 0) { - k = INT_ARG(0); - } - } else { - k = INPUT_VARIABLE(1)->e(0); - } - - REQUIRE_TRUE(k <= x->sizeAt(-1), 0, "top_k: k should not be greater than last dimension"); - REQUIRE_TRUE(k > 0, 0, "top_k: k should be positive, but %i given.", k); - - int res = helpers::topKFunctor(block.launchContext(), x, values, indices, k, needSort); - return res; - } - - DECLARE_SHAPE_FN(top_k) { - auto shapeList = SHAPELIST(); - auto in = inputShape->at(0); - int shapeRank = shape::rank(in); - int k = 1; // default output shape is size 1 - - if (block.width() == 2) { - k = INPUT_VARIABLE(1)->e(0); - } else if (block.numI() > 0) { - k = INT_ARG(0); - } - - REQUIRE_TRUE(k > 0, 0, "top_k: k should be positive, but %i given.", k); - - for (int e = 0; e < 2; e++) { // 2 element tuple at output - Nd4jLong* aShape; - ALLOCATE(aShape, block.getWorkspace(), shape::shapeInfoLength(shapeRank), Nd4jLong); - aShape[0] = shapeRank; - for (int i = 1 ; i < shapeRank; ++i) - aShape[i] = shape::sizeAt(in, i - 1); - aShape[shapeRank] = k; - - shape::updateStrides(aShape, shape::order(in)); - shapeList->push_back(ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(aShape, (e == 0?ArrayOptions::dataType(in):sd::DataType::INT64)))); - - RELEASE(aShape, block.getWorkspace()); - } - return shapeList; - } - - DECLARE_TYPES(top_k) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes(0, sd::DataType::ANY) - ->setAllowedOutputTypes(1, {ALL_INDICES}); - } +namespace ops { +CUSTOM_OP_IMPL(top_k, 1, 2, false, 0, -1) { + auto x = INPUT_VARIABLE(0); + int k = 1; // from params + bool needSort = true; + + auto values = OUTPUT_VARIABLE(0); + auto indices = OUTPUT_VARIABLE(1); + + if (block.numB() == 1) needSort = B_ARG(0); + + if (block.width() == 1) { + if (block.numI() > 0) { + k = INT_ARG(0); } + } else { + k = INPUT_VARIABLE(1)->e(0); + } + + REQUIRE_TRUE(k <= x->sizeAt(-1), 0, + "top_k: k should not be greater than last dimension"); + REQUIRE_TRUE(k > 0, 0, "top_k: k should be positive, but %i given.", k); + + int res = helpers::topKFunctor(block.launchContext(), x, values, indices, k, + needSort); + return res; +} + +DECLARE_SHAPE_FN(top_k) { + auto shapeList = SHAPELIST(); + auto in = inputShape->at(0); + int shapeRank = shape::rank(in); + int k = 1; // default output shape is size 1 + + if (block.width() == 2) { + k = INPUT_VARIABLE(1)->e(0); + } else if (block.numI() > 0) { + k = INT_ARG(0); + } + + REQUIRE_TRUE(k > 0, 0, "top_k: k should be positive, but %i given.", k); + + for (int e = 0; e < 2; e++) { // 2 element tuple at output + Nd4jLong* aShape; + ALLOCATE(aShape, block.workspace(), shape::shapeInfoLength(shapeRank), + Nd4jLong); + aShape[0] = shapeRank; + for (int i = 1; i < shapeRank; ++i) aShape[i] = shape::sizeAt(in, i - 1); + aShape[shapeRank] = k; + + shape::updateStrides(aShape, shape::order(in)); + shapeList->push_back(ConstantShapeHelper::getInstance().createShapeInfo( + ShapeDescriptor(aShape, (e == 0 ? ArrayOptions::dataType(in) + : sd::DataType::INT64)))); + + RELEASE(aShape, block.workspace()); + } + return shapeList; +} + +DECLARE_TYPES(top_k) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes(0, sd::DataType::ANY) + ->setAllowedOutputTypes(1, {ALL_INDICES}); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/unique.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/unique.cpp index 9d234abaacec..08f87a1acfad 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/unique.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/unique.cpp @@ -25,85 +25,92 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(unique, 1, 2, false, 0, 0) { - auto x = INPUT_VARIABLE(0); - auto values = OUTPUT_VARIABLE(0); - auto indices = OUTPUT_VARIABLE(1); - - REQUIRE_TRUE(x->dataType() == values->dataType(), 0, "Unique: input and output data types must be the same"); - - return helpers::uniqueFunctor(block.launchContext(), x, values, indices, (NDArray*)nullptr); - } - - DECLARE_SHAPE_FN(unique) { - auto in = inputShape->at(0); - auto source = INPUT_VARIABLE(0); -// auto shapeList = SHAPELIST(); - const Nd4jLong* valuesShape; - const Nd4jLong* indicesShape; - - int uniqueCount = helpers::uniqueCount(block.launchContext(), source); - - if (uniqueCount == 0) { // empty value Shape - valuesShape = ConstantShapeHelper::getInstance().emptyShapeInfo(source->dataType()); - } - else { - // all output shapes are 1D arrays (vectors) - valuesShape = ConstantShapeHelper::getInstance().vectorShapeInfo(uniqueCount, ArrayOptions::dataType(in)); - } - // second output is always LONG - indicesShape = ConstantShapeHelper::getInstance().vectorShapeInfo(shape::length(in), sd::DataType::INT64); - - //COPY_SHAPE_EX(in, indicesShape, block.getWorkspace()); - - return SHAPELIST(valuesShape, indicesShape); - - } - - CUSTOM_OP_IMPL(unique_with_counts, 1, 3, false, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto values = OUTPUT_VARIABLE(0); - auto indices = OUTPUT_VARIABLE(1); - auto counts = OUTPUT_VARIABLE(2); - - return helpers::uniqueFunctor(block.launchContext(), input, values, indices, counts); - } - - DECLARE_SHAPE_FN(unique_with_counts) { - auto in = inputShape->at(0); - auto source = INPUT_VARIABLE(0); - - int uniqueCount = helpers::uniqueCount(block.launchContext(), source); - // all output shapes are 1D arrays (vectors) - // all output shapes are 1D arrays (vectors) - auto valuesShape = ConstantShapeHelper::getInstance().vectorShapeInfo(uniqueCount, source->dataType()); - - // second output is always LONG - auto indicesShape = ConstantShapeHelper::getInstance().vectorShapeInfo(source->lengthOf(), sd::DataType::INT64); - - // third one as well - auto countsShape = ConstantShapeHelper::getInstance().vectorShapeInfo(uniqueCount, sd::DataType::INT64); - - return SHAPELIST(valuesShape, indicesShape, countsShape); - } - - DECLARE_TYPES(unique) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes(0, {ALL_INTS, ALL_FLOATS}) - ->setAllowedOutputTypes(1, {ALL_INTS}); - } - - DECLARE_TYPES(unique_with_counts) { - getOpDescriptor() - ->setAllowedInputTypes({ALL_INTS, ALL_FLOATS}) - ->setAllowedOutputTypes(0, {ALL_INTS, ALL_FLOATS}) - ->setAllowedOutputTypes(1, {ALL_INTS}) - ->setAllowedOutputTypes(2, {ALL_INTS}); - } - - } +namespace ops { +CUSTOM_OP_IMPL(unique, 1, 2, false, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto values = OUTPUT_VARIABLE(0); + auto indices = OUTPUT_VARIABLE(1); + + REQUIRE_TRUE(x->dataType() == values->dataType(), 0, + "Unique: input and output data types must be the same"); + + return helpers::uniqueFunctor(block.launchContext(), x, values, indices, + (NDArray*)nullptr); +} + +DECLARE_SHAPE_FN(unique) { + auto in = inputShape->at(0); + auto source = INPUT_VARIABLE(0); + // auto shapeList = SHAPELIST(); + const Nd4jLong* valuesShape; + const Nd4jLong* indicesShape; + + int uniqueCount = helpers::uniqueCount(block.launchContext(), source); + + if (uniqueCount == 0) { // empty value Shape + valuesShape = + ConstantShapeHelper::getInstance().emptyShapeInfo(source->dataType()); + } else { + // all output shapes are 1D arrays (vectors) + valuesShape = ConstantShapeHelper::getInstance().vectorShapeInfo( + uniqueCount, ArrayOptions::dataType(in)); + } + // second output is always LONG + indicesShape = ConstantShapeHelper::getInstance().vectorShapeInfo( + shape::length(in), sd::DataType::INT64); + + // COPY_SHAPE_EX(in, indicesShape, block.workspace()); + + return SHAPELIST(valuesShape, indicesShape); +} + +CUSTOM_OP_IMPL(unique_with_counts, 1, 3, false, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto values = OUTPUT_VARIABLE(0); + auto indices = OUTPUT_VARIABLE(1); + auto counts = OUTPUT_VARIABLE(2); + + return helpers::uniqueFunctor(block.launchContext(), input, values, indices, + counts); +} + +DECLARE_SHAPE_FN(unique_with_counts) { + auto in = inputShape->at(0); + auto source = INPUT_VARIABLE(0); + + int uniqueCount = helpers::uniqueCount(block.launchContext(), source); + // all output shapes are 1D arrays (vectors) + // all output shapes are 1D arrays (vectors) + auto valuesShape = ConstantShapeHelper::getInstance().vectorShapeInfo( + uniqueCount, source->dataType()); + + // second output is always LONG + auto indicesShape = ConstantShapeHelper::getInstance().vectorShapeInfo( + source->lengthOf(), sd::DataType::INT64); + + // third one as well + auto countsShape = ConstantShapeHelper::getInstance().vectorShapeInfo( + uniqueCount, sd::DataType::INT64); + + return SHAPELIST(valuesShape, indicesShape, countsShape); } +DECLARE_TYPES(unique) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes(0, {ALL_INTS, ALL_FLOATS}) + ->setAllowedOutputTypes(1, {ALL_INTS}); +} + +DECLARE_TYPES(unique_with_counts) { + getOpDescriptor() + ->setAllowedInputTypes({ALL_INTS, ALL_FLOATS}) + ->setAllowedOutputTypes(0, {ALL_INTS, ALL_FLOATS}) + ->setAllowedOutputTypes(1, {ALL_INTS}) + ->setAllowedOutputTypes(2, {ALL_INTS}); +} + +} // namespace ops +} // namespace sd + #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_max.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_max.cpp index 1909005a7ae1..71a2e07f8e27 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_max.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_max.cpp @@ -22,73 +22,87 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(unsorted_segment_max, 2, 1, false, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto idxSegments = INPUT_VARIABLE(1); - auto segmentedOutput = OUTPUT_NULLIFIED(0); - Nd4jLong numOfClasses = block.width() == 3 ? INPUT_VARIABLE(2)->e(0) : INT_ARG(0); - REQUIRE_TRUE(idxSegments->isVector(), 0, "unsorted_segment_max: segment indexes array should be a vector, but it rank is %i.", idxSegments->rankOf()); - REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, "unsorted_segment_max: segment indexes array length should be equal to the input first dimension, but %ld != %ild.", idxSegments->lengthOf(), input->sizeAt(0)); - - Nd4jLong wrong; - - REQUIRE_TRUE(helpers::unsortedSegmentIndicesValidate(block.launchContext(), idxSegments, numOfClasses, wrong), 0, "unsorted_segment_max: segment indices should be in range [0, %ld), but %ld != %ld", - numOfClasses, wrong, numOfClasses); - - helpers::unsortedSegmentMaxFunctor(block.launchContext(), input, idxSegments, numOfClasses, segmentedOutput); - - return ND4J_STATUS_OK; - } - DECLARE_TYPES(unsorted_segment_max) { - getOpDescriptor() - ->setAllowedOutputTypes({ALL_FLOATS, ALL_INTS}) - ->setAllowedInputTypes(0, {ALL_FLOATS, ALL_INTS}) - ->setAllowedInputTypes(1, {ALL_INTS}) - ->setSameMode(true); - } - DECLARE_SHAPE_FN(unsorted_segment_max) { - - auto in = inputShape->at(0); - int outRank = shape::rank(in); - Nd4jLong numOfClasses = block.width() == 3 ? INPUT_VARIABLE(2)->e(0) : INT_ARG(0); - Nd4jLong* outputShape; +namespace ops { +CUSTOM_OP_IMPL(unsorted_segment_max, 2, 1, false, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto idxSegments = INPUT_VARIABLE(1); + auto segmentedOutput = OUTPUT_NULLIFIED(0); + Nd4jLong numOfClasses = + block.width() == 3 ? INPUT_VARIABLE(2)->e(0) : INT_ARG(0); + REQUIRE_TRUE(idxSegments->isVector(), 0, + "unsorted_segment_max: segment indexes array should be a " + "vector, but it rank is %i.", + idxSegments->rankOf()); + REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, + "unsorted_segment_max: segment indexes array length should be " + "equal to the input first dimension, but %ld != %ild.", + idxSegments->lengthOf(), input->sizeAt(0)); + + Nd4jLong wrong; + + REQUIRE_TRUE(helpers::unsortedSegmentIndicesValidate( + block.launchContext(), idxSegments, numOfClasses, wrong), + 0, + "unsorted_segment_max: segment indices should be in range [0, " + "%ld), but %ld != %ld", + numOfClasses, wrong, numOfClasses); + + helpers::unsortedSegmentMaxFunctor(block.launchContext(), input, idxSegments, + numOfClasses, segmentedOutput); + + return ND4J_STATUS_OK; +} +DECLARE_TYPES(unsorted_segment_max) { + getOpDescriptor() + ->setAllowedOutputTypes({ALL_FLOATS, ALL_INTS}) + ->setAllowedInputTypes(0, {ALL_FLOATS, ALL_INTS}) + ->setAllowedInputTypes(1, {ALL_INTS}) + ->setSameMode(true); +} +DECLARE_SHAPE_FN(unsorted_segment_max) { + auto in = inputShape->at(0); + int outRank = shape::rank(in); + Nd4jLong numOfClasses = + block.width() == 3 ? INPUT_VARIABLE(2)->e(0) : INT_ARG(0); + Nd4jLong* outputShape; - ALLOCATE(outputShape, block.getWorkspace(), shape::shapeInfoLength(outRank), Nd4jLong); + ALLOCATE(outputShape, block.workspace(), shape::shapeInfoLength(outRank), + Nd4jLong); - outputShape[0] = outRank; - outputShape[1] = numOfClasses; - for(int i = 1; i < outRank; ++i) - outputShape[i + 1] = shape::sizeAt(in, i); + outputShape[0] = outRank; + outputShape[1] = numOfClasses; + for (int i = 1; i < outRank; ++i) outputShape[i + 1] = shape::sizeAt(in, i); - ShapeUtils::updateStridesAndType(outputShape, in, shape::order(in)); + ShapeUtils::updateStridesAndType(outputShape, in, shape::order(in)); - return SHAPELIST(CONSTANT(outputShape)); - } + return SHAPELIST(CONSTANT(outputShape)); +} - CUSTOM_OP_IMPL(unsorted_segment_max_bp, 3, 2, false, 0, 1) { - return helpers::unsortedSegmentMaxFunctorBP(block.launchContext(), INPUT_VARIABLE(0), INPUT_VARIABLE(1), INPUT_VARIABLE(2), INT_ARG(0), OUTPUT_NULLIFIED(0)); - } +CUSTOM_OP_IMPL(unsorted_segment_max_bp, 3, 2, false, 0, 1) { + return helpers::unsortedSegmentMaxFunctorBP( + block.launchContext(), INPUT_VARIABLE(0), INPUT_VARIABLE(1), + INPUT_VARIABLE(2), INT_ARG(0), OUTPUT_NULLIFIED(0)); +} - DECLARE_TYPES(unsorted_segment_max_bp) { - getOpDescriptor() - ->setAllowedOutputTypes(0, {ALL_FLOATS}) - ->setAllowedOutputTypes(1, {ALL_INTS}) - ->setAllowedInputTypes(0, {ALL_FLOATS}) - ->setAllowedInputTypes(1, {ALL_INTS}) - ->setAllowedInputTypes(2, {ALL_FLOATS}) - ->setSameMode(false); - } +DECLARE_TYPES(unsorted_segment_max_bp) { + getOpDescriptor() + ->setAllowedOutputTypes(0, {ALL_FLOATS}) + ->setAllowedOutputTypes(1, {ALL_INTS}) + ->setAllowedInputTypes(0, {ALL_FLOATS}) + ->setAllowedInputTypes(1, {ALL_INTS}) + ->setAllowedInputTypes(2, {ALL_FLOATS}) + ->setSameMode(false); +} - DECLARE_SHAPE_FN(unsorted_segment_max_bp){ - auto in = inputShape->at(0); - auto inIdx = inputShape->at(1); +DECLARE_SHAPE_FN(unsorted_segment_max_bp) { + auto in = inputShape->at(0); + auto inIdx = inputShape->at(1); - Nd4jLong* outShape; - Nd4jLong* outIndex; - COPY_SHAPE(in, outShape); - COPY_SHAPE(inIdx, outIndex); - return SHAPELIST(CONSTANT(outShape), CONSTANT(outIndex)); - } - } + Nd4jLong* outShape; + Nd4jLong* outIndex; + COPY_SHAPE(in, outShape); + COPY_SHAPE(inIdx, outIndex); + return SHAPELIST(CONSTANT(outShape), CONSTANT(outIndex)); } +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_mean.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_mean.cpp index def3adb6ae97..e711cea60e02 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_mean.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_mean.cpp @@ -22,76 +22,90 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(unsorted_segment_mean, 2, 1, false, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto idxSegments = INPUT_VARIABLE(1); - auto segmentedOutput = OUTPUT_NULLIFIED(0); - Nd4jLong numOfClasses = block.width() == 3 ? INPUT_VARIABLE(2)->e(0) : INT_ARG(0); - - REQUIRE_TRUE(idxSegments->isVector(), 0, "unsorted_segment_mean: segment indexes array should be a vector, but it rank is %i.", idxSegments->rankOf()); - REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, "unsorted_segment_mean: segment indexes array length should be equal to the input first dimension, but %ld != %ld.", idxSegments->lengthOf(), input->sizeAt(0)); - - Nd4jLong wrong; - - REQUIRE_TRUE(helpers::unsortedSegmentIndicesValidate(block.launchContext(), idxSegments, numOfClasses, wrong), 0, "unsorted_segment_mean: segment indices should be in range [0, %ld), but %ld != %ld", - numOfClasses, wrong, numOfClasses); - - helpers::unsortedSegmentMeanFunctor(block.launchContext(), input, idxSegments, numOfClasses, segmentedOutput); - - return ND4J_STATUS_OK; - } - DECLARE_TYPES(unsorted_segment_mean) { - getOpDescriptor() - ->setAllowedOutputTypes({ALL_FLOATS}) - ->setAllowedInputTypes(0, {ALL_FLOATS, ALL_INTS}) - ->setAllowedInputTypes(1, {ALL_INTS}) - ->setSameMode(false); - } - - DECLARE_SHAPE_FN(unsorted_segment_mean) { - - auto in = inputShape->at(0); - int outRank = shape::rank(in); - Nd4jLong* outputShape = nullptr; - Nd4jLong numOfClasses = block.width() == 3 ? INPUT_VARIABLE(2)->e(0) : INT_ARG(0); - - ALLOCATE(outputShape, block.getWorkspace(), shape::shapeInfoLength(outRank), Nd4jLong); - - outputShape[0] = outRank; - outputShape[1] = numOfClasses; - for(int i = 1; i < outRank; ++i) - outputShape[i + 1] = shape::sizeAt(in, i); - - ShapeUtils::updateStridesAndType(outputShape, in, shape::order(in)); - - return SHAPELIST(CONSTANT(outputShape)); - } - - CUSTOM_OP_IMPL(unsorted_segment_mean_bp, 3, 2, false, 0, 1) { - return helpers::unsortedSegmentMeanFunctorBP(block.launchContext(), INPUT_VARIABLE(0), INPUT_VARIABLE(1), INPUT_VARIABLE(2), INT_ARG(0), OUTPUT_NULLIFIED(0)); - } - - DECLARE_TYPES(unsorted_segment_mean_bp) { - getOpDescriptor() - ->setAllowedOutputTypes(0, {ALL_FLOATS}) - ->setAllowedOutputTypes(1, {ALL_INTS}) - ->setAllowedInputTypes(0, {ALL_FLOATS}) - ->setAllowedInputTypes(1, {ALL_INTS}) - ->setAllowedInputTypes(2, {ALL_FLOATS}) - ->setSameMode(false); - } - - DECLARE_SHAPE_FN(unsorted_segment_mean_bp){ - auto in = inputShape->at(0); - auto inIdx = inputShape->at(1); - - Nd4jLong* outShape; - Nd4jLong* outIndex; - COPY_SHAPE(in, outShape); - COPY_SHAPE(inIdx, outIndex); - return SHAPELIST(CONSTANT(outShape), CONSTANT(outIndex)); -// return SHAPELIST(in, inIdx); - } - } +namespace ops { +CUSTOM_OP_IMPL(unsorted_segment_mean, 2, 1, false, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto idxSegments = INPUT_VARIABLE(1); + auto segmentedOutput = OUTPUT_NULLIFIED(0); + Nd4jLong numOfClasses = + block.width() == 3 ? INPUT_VARIABLE(2)->e(0) : INT_ARG(0); + + REQUIRE_TRUE(idxSegments->isVector(), 0, + "unsorted_segment_mean: segment indexes array should be a " + "vector, but it rank is %i.", + idxSegments->rankOf()); + REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, + "unsorted_segment_mean: segment indexes array length should be " + "equal to the input first dimension, but %ld != %ld.", + idxSegments->lengthOf(), input->sizeAt(0)); + + Nd4jLong wrong; + + REQUIRE_TRUE(helpers::unsortedSegmentIndicesValidate( + block.launchContext(), idxSegments, numOfClasses, wrong), + 0, + "unsorted_segment_mean: segment indices should be in range [0, " + "%ld), but %ld != %ld", + numOfClasses, wrong, numOfClasses); + + helpers::unsortedSegmentMeanFunctor(block.launchContext(), input, idxSegments, + numOfClasses, segmentedOutput); + + return ND4J_STATUS_OK; } +DECLARE_TYPES(unsorted_segment_mean) { + getOpDescriptor() + ->setAllowedOutputTypes({ALL_FLOATS}) + ->setAllowedInputTypes(0, {ALL_FLOATS, ALL_INTS}) + ->setAllowedInputTypes(1, {ALL_INTS}) + ->setSameMode(false); +} + +DECLARE_SHAPE_FN(unsorted_segment_mean) { + auto in = inputShape->at(0); + int outRank = shape::rank(in); + Nd4jLong* outputShape = nullptr; + Nd4jLong numOfClasses = + block.width() == 3 ? INPUT_VARIABLE(2)->e(0) : INT_ARG(0); + + ALLOCATE(outputShape, block.workspace(), shape::shapeInfoLength(outRank), + Nd4jLong); + + outputShape[0] = outRank; + outputShape[1] = numOfClasses; + for (int i = 1; i < outRank; ++i) outputShape[i + 1] = shape::sizeAt(in, i); + + ShapeUtils::updateStridesAndType(outputShape, in, shape::order(in)); + + return SHAPELIST(CONSTANT(outputShape)); +} + +CUSTOM_OP_IMPL(unsorted_segment_mean_bp, 3, 2, false, 0, 1) { + return helpers::unsortedSegmentMeanFunctorBP( + block.launchContext(), INPUT_VARIABLE(0), INPUT_VARIABLE(1), + INPUT_VARIABLE(2), INT_ARG(0), OUTPUT_NULLIFIED(0)); +} + +DECLARE_TYPES(unsorted_segment_mean_bp) { + getOpDescriptor() + ->setAllowedOutputTypes(0, {ALL_FLOATS}) + ->setAllowedOutputTypes(1, {ALL_INTS}) + ->setAllowedInputTypes(0, {ALL_FLOATS}) + ->setAllowedInputTypes(1, {ALL_INTS}) + ->setAllowedInputTypes(2, {ALL_FLOATS}) + ->setSameMode(false); +} + +DECLARE_SHAPE_FN(unsorted_segment_mean_bp) { + auto in = inputShape->at(0); + auto inIdx = inputShape->at(1); + + Nd4jLong* outShape; + Nd4jLong* outIndex; + COPY_SHAPE(in, outShape); + COPY_SHAPE(inIdx, outIndex); + return SHAPELIST(CONSTANT(outShape), CONSTANT(outIndex)); + // return SHAPELIST(in, inIdx); +} +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_min.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_min.cpp index da31477ebf83..54544b266a55 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_min.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_min.cpp @@ -22,78 +22,93 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(unsorted_segment_min, 2, 1, false, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto idxSegments = INPUT_VARIABLE(1); - auto segmentedOutput = OUTPUT_NULLIFIED(0); - Nd4jLong numOfClasses = block.width() == 3 ? INPUT_VARIABLE(2)->e(0) : INT_ARG(0); - REQUIRE_TRUE(idxSegments->isVector(), 0, "unsorted_segment_min: segment indexes array should be a vector, but it rank is %i.", idxSegments->rankOf()); - REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, "unsorted_segment_min: segment indexes array length should be equal to the input first dimension, but %ld != %ld.", idxSegments->lengthOf(), input->sizeAt(0)); - - Nd4jLong wrong; - - REQUIRE_TRUE(helpers::unsortedSegmentIndicesValidate(block.launchContext(), idxSegments, numOfClasses, wrong), 0, "unsorted_segment_min: segment indices should be in range [0, %ld), but %ld > %ld", - numOfClasses, wrong, numOfClasses); - - helpers::unsortedSegmentMinFunctor(block.launchContext(), input, idxSegments, numOfClasses, segmentedOutput); - - return ND4J_STATUS_OK; - } - - DECLARE_SHAPE_FN(unsorted_segment_min) { - - auto in = inputShape->at(0); - int outRank = shape::rank(in); - Nd4jLong* outputShape = nullptr; - Nd4jLong numOfClasses = block.width() == 3 ? INPUT_VARIABLE(2)->e(0) : INT_ARG(0); +namespace ops { +CUSTOM_OP_IMPL(unsorted_segment_min, 2, 1, false, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto idxSegments = INPUT_VARIABLE(1); + auto segmentedOutput = OUTPUT_NULLIFIED(0); + Nd4jLong numOfClasses = block.width() == 3 + ? INPUT_VARIABLE(2)->e(0) + : INT_ARG(0); + + REQUIRE_TRUE(idxSegments->isVector(), 0, + "unsorted_segment_min: segment indexes array should be a " + "vector, but it rank is %i.", + idxSegments->rankOf()); + REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, + "unsorted_segment_min: segment indexes array length should be " + "equal to the input first dimension, but %ld != %ld.", + idxSegments->lengthOf(), input->sizeAt(0)); + + Nd4jLong wrong; + + REQUIRE_TRUE(helpers::unsortedSegmentIndicesValidate( + block.launchContext(), idxSegments, numOfClasses, wrong), + 0, + "unsorted_segment_min: segment indices should be in range [0, " + "%ld), but %ld > %ld", + numOfClasses, wrong, numOfClasses); + + helpers::unsortedSegmentMinFunctor(block.launchContext(), input, idxSegments, + numOfClasses, segmentedOutput); + + return ND4J_STATUS_OK; +} - ALLOCATE(outputShape, block.getWorkspace(), shape::shapeInfoLength(outRank), Nd4jLong); +DECLARE_SHAPE_FN(unsorted_segment_min) { + auto in = inputShape->at(0); + int outRank = shape::rank(in); + Nd4jLong* outputShape = nullptr; + Nd4jLong numOfClasses = + block.width() == 3 ? INPUT_VARIABLE(2)->e(0) : INT_ARG(0); - outputShape[0] = outRank; - outputShape[1] = numOfClasses; - for(int i = 1; i < outRank; ++i) - outputShape[i + 1] = shape::sizeAt(in, i); + ALLOCATE(outputShape, block.workspace(), shape::shapeInfoLength(outRank), + Nd4jLong); - ShapeUtils::updateStridesAndType(outputShape, in, shape::order(in)); + outputShape[0] = outRank; + outputShape[1] = numOfClasses; + for (int i = 1; i < outRank; ++i) outputShape[i + 1] = shape::sizeAt(in, i); - return SHAPELIST(CONSTANT(outputShape)); - } + ShapeUtils::updateStridesAndType(outputShape, in, shape::order(in)); - DECLARE_TYPES(unsorted_segment_min) { - getOpDescriptor() - ->setAllowedOutputTypes({ALL_FLOATS, ALL_INTS}) - ->setAllowedInputTypes(0, {ALL_FLOATS, ALL_INTS}) - ->setAllowedInputTypes(1, {ALL_INTS}) - ->setSameMode(true); - } + return SHAPELIST(CONSTANT(outputShape)); +} - CUSTOM_OP_IMPL(unsorted_segment_min_bp, 3, 2, false, 0, 1) { - return helpers::unsortedSegmentMinFunctorBP(block.launchContext(), INPUT_VARIABLE(0), INPUT_VARIABLE(1), INPUT_VARIABLE(2), INT_ARG(0), OUTPUT_NULLIFIED(0)); - } +DECLARE_TYPES(unsorted_segment_min) { + getOpDescriptor() + ->setAllowedOutputTypes({ALL_FLOATS, ALL_INTS}) + ->setAllowedInputTypes(0, {ALL_FLOATS, ALL_INTS}) + ->setAllowedInputTypes(1, {ALL_INTS}) + ->setSameMode(true); +} - DECLARE_TYPES(unsorted_segment_min_bp) { - getOpDescriptor() - ->setAllowedOutputTypes(0, {ALL_FLOATS, ALL_INTS}) - ->setAllowedOutputTypes(1, {ALL_INTS}) - ->setAllowedInputTypes(0, {ALL_FLOATS, ALL_INTS}) - ->setAllowedInputTypes(1, {ALL_INTS}) - ->setAllowedInputTypes(2, {ALL_FLOATS, ALL_INTS}) - ->setSameMode(false); - } +CUSTOM_OP_IMPL(unsorted_segment_min_bp, 3, 2, false, 0, 1) { + return helpers::unsortedSegmentMinFunctorBP( + block.launchContext(), INPUT_VARIABLE(0), INPUT_VARIABLE(1), + INPUT_VARIABLE(2), INT_ARG(0), OUTPUT_NULLIFIED(0)); +} - DECLARE_SHAPE_FN(unsorted_segment_min_bp){ - auto in = inputShape->at(0); - auto inIdx = inputShape->at(1); +DECLARE_TYPES(unsorted_segment_min_bp) { + getOpDescriptor() + ->setAllowedOutputTypes(0, {ALL_FLOATS, ALL_INTS}) + ->setAllowedOutputTypes(1, {ALL_INTS}) + ->setAllowedInputTypes(0, {ALL_FLOATS, ALL_INTS}) + ->setAllowedInputTypes(1, {ALL_INTS}) + ->setAllowedInputTypes(2, {ALL_FLOATS, ALL_INTS}) + ->setSameMode(false); +} - Nd4jLong* outShape; - Nd4jLong* outIndex; - COPY_SHAPE(in, outShape); - COPY_SHAPE(inIdx, outIndex); - return SHAPELIST(CONSTANT(outShape), CONSTANT(outIndex)); +DECLARE_SHAPE_FN(unsorted_segment_min_bp) { + auto in = inputShape->at(0); + auto inIdx = inputShape->at(1); - } + Nd4jLong* outShape; + Nd4jLong* outIndex; + COPY_SHAPE(in, outShape); + COPY_SHAPE(inIdx, outIndex); + return SHAPELIST(CONSTANT(outShape), CONSTANT(outIndex)); +} - } +} // namespace ops -} +} // namespace sd diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_prod.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_prod.cpp index 905a04b3611d..552d1fd23a2e 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_prod.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_prod.cpp @@ -22,91 +22,115 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(unsorted_segment_prod, 2, 1, false, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto idxSegments = INPUT_VARIABLE(1); - auto segmentedOutput = OUTPUT_NULLIFIED(0); - Nd4jLong numOfClasses = block.width() == 3 ? INPUT_VARIABLE(2)->e(0) : INT_ARG(0); - REQUIRE_TRUE(idxSegments->isVector(), 0, "unsorted_segment_prod: segment indexes array should be a vector, but it rank is %i.", idxSegments->rankOf()); - REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, "unsorted_segment_prod: segment indexes array length should be equal to the input first dimension, but %ld != %ld.", idxSegments->lengthOf(), input->sizeAt(0)); - - Nd4jLong wrong = 0; - - REQUIRE_TRUE(helpers::unsortedSegmentIndicesValidate(block.launchContext(), idxSegments, numOfClasses, wrong), 0, "unsorted_segment_prod: segment indices should be in range [0, %ld), but %ld != %ld", - numOfClasses, wrong, numOfClasses); - - helpers::unsortedSegmentProdFunctor(block.launchContext(), input, idxSegments, numOfClasses, segmentedOutput); - - return ND4J_STATUS_OK; - } - - DECLARE_SHAPE_FN(unsorted_segment_prod) { - - auto in = inputShape->at(0); - int outRank = shape::rank(in); - Nd4jLong* outputShape = nullptr; - Nd4jLong numOfClasses = block.width() == 3 ? INPUT_VARIABLE(2)->e(0) : INT_ARG(0); - - ALLOCATE(outputShape, block.getWorkspace(), shape::shapeInfoLength(outRank), Nd4jLong); - - outputShape[0] = outRank; - outputShape[1] = numOfClasses; - for(int i = 1; i < outRank; ++i) - outputShape[i + 1] = shape::sizeAt(in, i); - - ShapeUtils::updateStridesAndType(outputShape, in, shape::order(in)); - - return SHAPELIST(CONSTANT(outputShape)); - } - DECLARE_TYPES(unsorted_segment_prod) { - getOpDescriptor() - ->setAllowedOutputTypes({ALL_FLOATS, ALL_INTS}) - ->setAllowedInputTypes(0, {ALL_FLOATS, ALL_INTS}) - ->setAllowedInputTypes(1, {ALL_INDICES}) - ->setSameMode(false); - } - - CUSTOM_OP_IMPL(unsorted_segment_prod_bp, 3, 2, false, 0, 1) { - auto input = INPUT_VARIABLE(0); - auto indices = INPUT_VARIABLE(1); - auto eps = INPUT_VARIABLE(2); -// auto numOfClasses = INT_ARG(0); - auto output = OUTPUT_NULLIFIED(0); - - Nd4jLong numOfClasses = block.width() == 4 ? INPUT_VARIABLE(3)->e(0) : INT_ARG(0); - REQUIRE_TRUE(indices->isVector(), 0, "unsorted_segment_prod_bp: segment indexes array should be a vector, but it rank is %i.", indices->rankOf()); - REQUIRE_TRUE(indices->lengthOf() == input->sizeAt(0), 0, "unsorted_segment_prod_bp: segment indexes array length should be equal to the input first dimension, but %lld != %lld.", indices->lengthOf(), input->sizeAt(0)); - - Nd4jLong wrong = numOfClasses; - - REQUIRE_TRUE(helpers::unsortedSegmentIndicesValidate(block.launchContext(), indices, numOfClasses, wrong), 0, "unsorted_segment_prod_bp: segment indices should be in range [0, %lld), but %lld > %lld", - numOfClasses, wrong, numOfClasses); - - return helpers::unsortedSegmentProdFunctorBP(block.launchContext(), input, indices, eps, numOfClasses, output); - } - DECLARE_TYPES(unsorted_segment_prod_bp) { - getOpDescriptor() - ->setAllowedOutputTypes(0, {ALL_FLOATS}) - ->setAllowedOutputTypes(1, {ALL_INDICES}) - ->setAllowedInputTypes(0, {ALL_FLOATS}) - ->setAllowedInputTypes(1, {ALL_INDICES}) - ->setAllowedInputTypes(2,{ALL_FLOATS, ALL_INTS}) - ->setSameMode(false); - } - - DECLARE_SHAPE_FN(unsorted_segment_prod_bp){ - auto in = inputShape->at(0); - auto inIdx = inputShape->at(1); - - Nd4jLong* outShape; - Nd4jLong* outIndex; - COPY_SHAPE(in, outShape); - COPY_SHAPE(inIdx, outIndex); - return SHAPELIST(CONSTANT(outShape), CONSTANT(outIndex)); -// return SHAPELIST(in, inIdx); - } - - } +namespace ops { +CUSTOM_OP_IMPL(unsorted_segment_prod, 2, 1, false, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto idxSegments = INPUT_VARIABLE(1); + auto segmentedOutput = OUTPUT_NULLIFIED(0); + Nd4jLong numOfClasses = + block.width() == 3 ? INPUT_VARIABLE(2)->e(0) : INT_ARG(0); + REQUIRE_TRUE(idxSegments->isVector(), 0, + "unsorted_segment_prod: segment indexes array should be a " + "vector, but it rank is %i.", + idxSegments->rankOf()); + REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, + "unsorted_segment_prod: segment indexes array length should be " + "equal to the input first dimension, but %ld != %ld.", + idxSegments->lengthOf(), input->sizeAt(0)); + + Nd4jLong wrong = 0; + + REQUIRE_TRUE(helpers::unsortedSegmentIndicesValidate( + block.launchContext(), idxSegments, numOfClasses, wrong), + 0, + "unsorted_segment_prod: segment indices should be in range [0, " + "%ld), but %ld != %ld", + numOfClasses, wrong, numOfClasses); + + helpers::unsortedSegmentProdFunctor(block.launchContext(), input, idxSegments, + numOfClasses, segmentedOutput); + + return ND4J_STATUS_OK; +} + +DECLARE_SHAPE_FN(unsorted_segment_prod) { + auto in = inputShape->at(0); + int outRank = shape::rank(in); + Nd4jLong* outputShape = nullptr; + Nd4jLong numOfClasses = + block.width() == 3 ? INPUT_VARIABLE(2)->e(0) : INT_ARG(0); + + ALLOCATE(outputShape, block.workspace(), shape::shapeInfoLength(outRank), + Nd4jLong); + + outputShape[0] = outRank; + outputShape[1] = numOfClasses; + for (int i = 1; i < outRank; ++i) outputShape[i + 1] = shape::sizeAt(in, i); + + ShapeUtils::updateStridesAndType(outputShape, in, shape::order(in)); + + return SHAPELIST(CONSTANT(outputShape)); +} +DECLARE_TYPES(unsorted_segment_prod) { + getOpDescriptor() + ->setAllowedOutputTypes({ALL_FLOATS, ALL_INTS}) + ->setAllowedInputTypes(0, {ALL_FLOATS, ALL_INTS}) + ->setAllowedInputTypes(1, {ALL_INDICES}) + ->setSameMode(false); +} +CUSTOM_OP_IMPL(unsorted_segment_prod_bp, 3, 2, false, 0, 1) { + auto input = INPUT_VARIABLE(0); + auto indices = INPUT_VARIABLE(1); + auto eps = INPUT_VARIABLE(2); + // auto numOfClasses = INT_ARG(0); + auto output = OUTPUT_NULLIFIED(0); + + Nd4jLong numOfClasses = + block.width() == 4 ? INPUT_VARIABLE(3)->e(0) : INT_ARG(0); + REQUIRE_TRUE(indices->isVector(), 0, + "unsorted_segment_prod_bp: segment indexes array should be a " + "vector, but it rank is %i.", + indices->rankOf()); + REQUIRE_TRUE(indices->lengthOf() == input->sizeAt(0), 0, + "unsorted_segment_prod_bp: segment indexes array length should " + "be equal to the input first dimension, but %lld != %lld.", + indices->lengthOf(), input->sizeAt(0)); + + Nd4jLong wrong = numOfClasses; + + REQUIRE_TRUE(helpers::unsortedSegmentIndicesValidate( + block.launchContext(), indices, numOfClasses, wrong), + 0, + "unsorted_segment_prod_bp: segment indices should be in range " + "[0, %lld), but %lld > %lld", + numOfClasses, wrong, numOfClasses); + + return helpers::unsortedSegmentProdFunctorBP( + block.launchContext(), input, indices, eps, numOfClasses, output); +} +DECLARE_TYPES(unsorted_segment_prod_bp) { + getOpDescriptor() + ->setAllowedOutputTypes(0, {ALL_FLOATS}) + ->setAllowedOutputTypes(1, {ALL_INDICES}) + ->setAllowedInputTypes(0, {ALL_FLOATS}) + ->setAllowedInputTypes(1, {ALL_INDICES}) + ->setAllowedInputTypes(2, {ALL_FLOATS, ALL_INTS}) + ->setSameMode(false); } + +DECLARE_SHAPE_FN(unsorted_segment_prod_bp) { + auto in = inputShape->at(0); + auto inIdx = inputShape->at(1); + + Nd4jLong* outShape; + Nd4jLong* outIndex; + COPY_SHAPE(in, outShape); + COPY_SHAPE(inIdx, outIndex); + return SHAPELIST(CONSTANT(outShape), CONSTANT(outIndex)); + // return SHAPELIST(in, inIdx); +} + +} // namespace ops + +} // namespace sd diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_sqrt_n.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_sqrt_n.cpp index e208f448909e..0fc14029c840 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_sqrt_n.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_sqrt_n.cpp @@ -22,75 +22,89 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(unsorted_segment_sqrt_n, 2, 1, false, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto idxSegments = INPUT_VARIABLE(1); - auto segmentedOutput = OUTPUT_NULLIFIED(0); - Nd4jLong numOfClasses = block.width() == 3 ? INPUT_VARIABLE(2)->e(0) : INT_ARG(0); - REQUIRE_TRUE(idxSegments->isVector(), 0, "unsorted_segment_sqrt_n: segment indexes array should be a vector, but it rank is %i.", idxSegments->rankOf()); - REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, "unsorted_segment_sqrt_n: segment indexes array length should be equal to the input first dimension, but %ld != %ld.", idxSegments->lengthOf(), input->sizeAt(0)); - - Nd4jLong wrong; - - REQUIRE_TRUE(helpers::unsortedSegmentIndicesValidate(block.launchContext(), idxSegments, numOfClasses, wrong), 0, "unsorted_segment_sqrt_n: segment indices should be in range [0, %ld), but %ld != %ld", - numOfClasses, wrong, numOfClasses); - - helpers::unsortedSegmentSqrtNFunctor(block.launchContext(), input, idxSegments, numOfClasses, segmentedOutput); - - return ND4J_STATUS_OK; - } - - DECLARE_SHAPE_FN(unsorted_segment_sqrt_n) { - - auto in = inputShape->at(0); - int outRank = shape::rank(in); - Nd4jLong* outputShape = nullptr; - Nd4jLong numOfClasses = block.width() == 3 ? INPUT_VARIABLE(2)->e(0) : INT_ARG(0); - - ALLOCATE(outputShape, block.getWorkspace(), shape::shapeInfoLength(outRank), Nd4jLong); - - outputShape[0] = outRank; - outputShape[1] = numOfClasses; - for(int i = 1; i < outRank; ++i) - outputShape[i + 1] = shape::sizeAt(in, i); - - ShapeUtils::updateStridesAndType(outputShape, in, shape::order(in)); - - return SHAPELIST(CONSTANT(outputShape)); - } - DECLARE_TYPES(unsorted_segment_sqrt_n) { - getOpDescriptor() - ->setAllowedOutputTypes({ALL_FLOATS}) - ->setAllowedInputTypes(0, {ALL_FLOATS}) - ->setAllowedInputTypes(1, {ALL_INTS}) - ->setSameMode(false); - } - - CUSTOM_OP_IMPL(unsorted_segment_sqrt_n_bp, 3, 2, false, 0, 1) { - return helpers::unsortedSegmentSqrtNFunctorBP(block.launchContext(), INPUT_VARIABLE(0), INPUT_VARIABLE(1), INPUT_VARIABLE(2), INT_ARG(0), OUTPUT_NULLIFIED(0)); - } - DECLARE_TYPES(unsorted_segment_sqrt_n_bp) { - getOpDescriptor() - ->setAllowedOutputTypes(0, {ALL_FLOATS}) - ->setAllowedOutputTypes(1, {ALL_INTS}) - ->setAllowedInputTypes(0, {ALL_FLOATS}) - ->setAllowedInputTypes(1, {ALL_INTS}) - ->setAllowedInputTypes(2, {ALL_FLOATS}) - ->setSameMode(false); - } - - DECLARE_SHAPE_FN(unsorted_segment_sqrt_n_bp){ - auto in = inputShape->at(0); - auto inIdx = inputShape->at(1); - - Nd4jLong* outShape; - Nd4jLong* outIndex; - COPY_SHAPE(in, outShape); - COPY_SHAPE(inIdx, outIndex); - return SHAPELIST(CONSTANT(outShape), CONSTANT(outIndex)); -// return SHAPELIST(in, inIdx); - } - - } +namespace ops { +CUSTOM_OP_IMPL(unsorted_segment_sqrt_n, 2, 1, false, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto idxSegments = INPUT_VARIABLE(1); + auto segmentedOutput = OUTPUT_NULLIFIED(0); + Nd4jLong numOfClasses = + block.width() == 3 ? INPUT_VARIABLE(2)->e(0) : INT_ARG(0); + REQUIRE_TRUE(idxSegments->isVector(), 0, + "unsorted_segment_sqrt_n: segment indexes array should be a " + "vector, but it rank is %i.", + idxSegments->rankOf()); + REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, + "unsorted_segment_sqrt_n: segment indexes array length should " + "be equal to the input first dimension, but %ld != %ld.", + idxSegments->lengthOf(), input->sizeAt(0)); + + Nd4jLong wrong; + + REQUIRE_TRUE(helpers::unsortedSegmentIndicesValidate( + block.launchContext(), idxSegments, numOfClasses, wrong), + 0, + "unsorted_segment_sqrt_n: segment indices should be in range " + "[0, %ld), but %ld != %ld", + numOfClasses, wrong, numOfClasses); + + helpers::unsortedSegmentSqrtNFunctor( + block.launchContext(), input, idxSegments, numOfClasses, segmentedOutput); + + return ND4J_STATUS_OK; } + +DECLARE_SHAPE_FN(unsorted_segment_sqrt_n) { + auto in = inputShape->at(0); + int outRank = shape::rank(in); + Nd4jLong* outputShape = nullptr; + Nd4jLong numOfClasses = + block.width() == 3 ? INPUT_VARIABLE(2)->e(0) : INT_ARG(0); + + ALLOCATE(outputShape, block.workspace(), shape::shapeInfoLength(outRank), + Nd4jLong); + + outputShape[0] = outRank; + outputShape[1] = numOfClasses; + for (int i = 1; i < outRank; ++i) outputShape[i + 1] = shape::sizeAt(in, i); + + ShapeUtils::updateStridesAndType(outputShape, in, shape::order(in)); + + return SHAPELIST(CONSTANT(outputShape)); +} +DECLARE_TYPES(unsorted_segment_sqrt_n) { + getOpDescriptor() + ->setAllowedOutputTypes({ALL_FLOATS}) + ->setAllowedInputTypes(0, {ALL_FLOATS}) + ->setAllowedInputTypes(1, {ALL_INTS}) + ->setSameMode(false); +} + +CUSTOM_OP_IMPL(unsorted_segment_sqrt_n_bp, 3, 2, false, 0, 1) { + return helpers::unsortedSegmentSqrtNFunctorBP( + block.launchContext(), INPUT_VARIABLE(0), INPUT_VARIABLE(1), + INPUT_VARIABLE(2), INT_ARG(0), OUTPUT_NULLIFIED(0)); +} +DECLARE_TYPES(unsorted_segment_sqrt_n_bp) { + getOpDescriptor() + ->setAllowedOutputTypes(0, {ALL_FLOATS}) + ->setAllowedOutputTypes(1, {ALL_INTS}) + ->setAllowedInputTypes(0, {ALL_FLOATS}) + ->setAllowedInputTypes(1, {ALL_INTS}) + ->setAllowedInputTypes(2, {ALL_FLOATS}) + ->setSameMode(false); +} + +DECLARE_SHAPE_FN(unsorted_segment_sqrt_n_bp) { + auto in = inputShape->at(0); + auto inIdx = inputShape->at(1); + + Nd4jLong* outShape; + Nd4jLong* outIndex; + COPY_SHAPE(in, outShape); + COPY_SHAPE(inIdx, outIndex); + return SHAPELIST(CONSTANT(outShape), CONSTANT(outIndex)); + // return SHAPELIST(in, inIdx); +} + +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_sum.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_sum.cpp index 325385a86007..32833510e610 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_sum.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_sum.cpp @@ -22,73 +22,86 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(unsorted_segment_sum, 2, 1, false, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto idxSegments = INPUT_VARIABLE(1); - auto segmentedOutput = OUTPUT_NULLIFIED(0); - Nd4jLong numOfClasses = block.width() == 3 ? INPUT_VARIABLE(2)->e(0) : INT_ARG(0); - REQUIRE_TRUE(idxSegments->isVector(), 0, "unsorted_segment_sum: segment indexes array should be a vector, but it rank is %i.", idxSegments->rankOf()); - REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, "unsorted_segment_sum: segment indexes array length should be equal to the input first dimension, but %ld != %ld", idxSegments->lengthOf(), input->sizeAt(0)); - - Nd4jLong wrong; - - REQUIRE_TRUE(helpers::unsortedSegmentIndicesValidate(block.launchContext(), idxSegments, numOfClasses, wrong), 0, "unsorted_segment_sum: segment indices should be in range [0, %ld), but %ld > %ld", - numOfClasses, wrong, numOfClasses); - - helpers::unsortedSegmentSumFunctor(block.launchContext(), input, idxSegments, numOfClasses, segmentedOutput); - - return ND4J_STATUS_OK; - } - DECLARE_TYPES(unsorted_segment_sum) { - getOpDescriptor() - ->setAllowedOutputTypes({ALL_FLOATS, ALL_INTS}) - ->setAllowedInputTypes(0, {ALL_FLOATS, ALL_INTS}) - ->setAllowedInputTypes(1, {ALL_INTS}) - ->setSameMode(false); - } - - DECLARE_SHAPE_FN(unsorted_segment_sum) { - - auto in = inputShape->at(0); - int outRank = shape::rank(in); - Nd4jLong* outputShape = nullptr; - Nd4jLong numOfClasses = block.width() == 3 ? INPUT_VARIABLE(2)->e(0) : INT_ARG(0); - - ALLOCATE(outputShape, block.getWorkspace(), shape::shapeInfoLength(outRank), Nd4jLong); - - outputShape[0] = outRank; - outputShape[1] = numOfClasses; - for(int i = 1; i < outRank; ++i) - outputShape[i + 1] = shape::sizeAt(in, i); - - ShapeUtils::updateStridesAndType(outputShape, in, shape::order(in)); - - return SHAPELIST(CONSTANT(outputShape)); - } - CUSTOM_OP_IMPL(unsorted_segment_sum_bp, 3, 2, false, 0, 1) { - return helpers::unsortedSegmentSumFunctorBP(block.launchContext(), INPUT_VARIABLE(0), INPUT_VARIABLE(1), INPUT_VARIABLE(2), INT_ARG(0), OUTPUT_NULLIFIED(0)); - } - - DECLARE_SHAPE_FN(unsorted_segment_sum_bp){ - auto in = inputShape->at(0); - auto inIdx = inputShape->at(1); - - Nd4jLong* outShape; - Nd4jLong* outIndex; - COPY_SHAPE(in, outShape); - COPY_SHAPE(inIdx, outIndex); - return SHAPELIST(CONSTANT(outShape), CONSTANT(outIndex)); - - } - DECLARE_TYPES(unsorted_segment_sum_bp) { - getOpDescriptor() - ->setAllowedOutputTypes(0, {ALL_FLOATS}) - ->setAllowedOutputTypes(1, {ALL_INTS}) - ->setAllowedInputTypes(sd::DataType::ANY) - ->setSameMode(false); - } - - } +namespace ops { +CUSTOM_OP_IMPL(unsorted_segment_sum, 2, 1, false, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto idxSegments = INPUT_VARIABLE(1); + auto segmentedOutput = OUTPUT_NULLIFIED(0); + Nd4jLong numOfClasses = + block.width() == 3 ? INPUT_VARIABLE(2)->e(0) : INT_ARG(0); + REQUIRE_TRUE(idxSegments->isVector(), 0, + "unsorted_segment_sum: segment indexes array should be a " + "vector, but it rank is %i.", + idxSegments->rankOf()); + REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, + "unsorted_segment_sum: segment indexes array length should be " + "equal to the input first dimension, but %ld != %ld", + idxSegments->lengthOf(), input->sizeAt(0)); + + Nd4jLong wrong; + + REQUIRE_TRUE(helpers::unsortedSegmentIndicesValidate( + block.launchContext(), idxSegments, numOfClasses, wrong), + 0, + "unsorted_segment_sum: segment indices should be in range [0, " + "%ld), but %ld > %ld", + numOfClasses, wrong, numOfClasses); + + helpers::unsortedSegmentSumFunctor(block.launchContext(), input, idxSegments, + numOfClasses, segmentedOutput); + + return ND4J_STATUS_OK; +} +DECLARE_TYPES(unsorted_segment_sum) { + getOpDescriptor() + ->setAllowedOutputTypes({ALL_FLOATS, ALL_INTS}) + ->setAllowedInputTypes(0, {ALL_FLOATS, ALL_INTS}) + ->setAllowedInputTypes(1, {ALL_INTS}) + ->setSameMode(false); +} + +DECLARE_SHAPE_FN(unsorted_segment_sum) { + auto in = inputShape->at(0); + int outRank = shape::rank(in); + Nd4jLong* outputShape = nullptr; + Nd4jLong numOfClasses = + block.width() == 3 ? INPUT_VARIABLE(2)->e(0) : INT_ARG(0); + + ALLOCATE(outputShape, block.workspace(), shape::shapeInfoLength(outRank), + Nd4jLong); + outputShape[0] = outRank; + outputShape[1] = numOfClasses; + for (int i = 1; i < outRank; ++i) outputShape[i + 1] = shape::sizeAt(in, i); + + ShapeUtils::updateStridesAndType(outputShape, in, shape::order(in)); + + return SHAPELIST(CONSTANT(outputShape)); +} +CUSTOM_OP_IMPL(unsorted_segment_sum_bp, 3, 2, false, 0, 1) { + return helpers::unsortedSegmentSumFunctorBP( + block.launchContext(), INPUT_VARIABLE(0), INPUT_VARIABLE(1), + INPUT_VARIABLE(2), INT_ARG(0), OUTPUT_NULLIFIED(0)); } + +DECLARE_SHAPE_FN(unsorted_segment_sum_bp) { + auto in = inputShape->at(0); + auto inIdx = inputShape->at(1); + + Nd4jLong* outShape; + Nd4jLong* outIndex; + COPY_SHAPE(in, outShape); + COPY_SHAPE(inIdx, outIndex); + return SHAPELIST(CONSTANT(outShape), CONSTANT(outIndex)); +} +DECLARE_TYPES(unsorted_segment_sum_bp) { + getOpDescriptor() + ->setAllowedOutputTypes(0, {ALL_FLOATS}) + ->setAllowedOutputTypes(1, {ALL_INTS}) + ->setAllowedInputTypes(sd::DataType::ANY) + ->setSameMode(false); +} + +} // namespace ops + +} // namespace sd diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/weighted_cross_entropy_with_logits.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/weighted_cross_entropy_with_logits.cpp index 71d337c7cf40..24383776988e 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/weighted_cross_entropy_with_logits.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/weighted_cross_entropy_with_logits.cpp @@ -27,26 +27,37 @@ namespace sd { namespace ops { - OP_IMPL(weighted_cross_entropy_with_logits, 3, 1, true) { - auto targets = INPUT_VARIABLE(0); - auto input = INPUT_VARIABLE(1); - auto weights = INPUT_VARIABLE(2); - auto output = OUTPUT_VARIABLE(0); - - REQUIRE_TRUE(targets->isSameShape(input), 0, "WEIGHTED_CROSS_ENTROPY_WITH_LOGITS op: The shape of both input params should be equal, but got input_shape=%s and targets_shape=%s !", ShapeUtils::shapeAsString(input).c_str(), ShapeUtils::shapeAsString(targets).c_str()); - REQUIRE_TRUE(weights->isScalar() || targets->sizeAt(-1) == weights->lengthOf(), 0, "WEIGHTED_CROSS_ENTROPY_WITH_LOGITS op: The weights should be scalar or vector with length equal to size of last targets dimension, but got weights_shape=%s instead!", ShapeUtils::shapeAsString(weights).c_str()); - - helpers::weightedCrossEntropyWithLogitsFunctor(block.launchContext(), targets, input, weights, output); - - return Status::OK(); - } - - DECLARE_TYPES(weighted_cross_entropy_with_logits) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } +OP_IMPL(weighted_cross_entropy_with_logits, 3, 1, true) { + auto targets = INPUT_VARIABLE(0); + auto input = INPUT_VARIABLE(1); + auto weights = INPUT_VARIABLE(2); + auto output = OUTPUT_VARIABLE(0); + + REQUIRE_TRUE( + targets->isSameShape(input), 0, + "WEIGHTED_CROSS_ENTROPY_WITH_LOGITS op: The shape of both input params " + "should be equal, but got input_shape=%s and targets_shape=%s !", + ShapeUtils::shapeAsString(input).c_str(), + ShapeUtils::shapeAsString(targets).c_str()); + REQUIRE_TRUE( + weights->isScalar() || targets->sizeAt(-1) == weights->lengthOf(), 0, + "WEIGHTED_CROSS_ENTROPY_WITH_LOGITS op: The weights should be scalar or " + "vector with length equal to size of last targets dimension, but got " + "weights_shape=%s instead!", + ShapeUtils::shapeAsString(weights).c_str()); + + helpers::weightedCrossEntropyWithLogitsFunctor(block.launchContext(), targets, + input, weights, output); + + return Status::OK(); } + +DECLARE_TYPES(weighted_cross_entropy_with_logits) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/zero_fraction.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/zero_fraction.cpp index 91f0a564d187..2d60813bcb0d 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/zero_fraction.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/zero_fraction.cpp @@ -24,39 +24,42 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(zero_fraction, 1, 1, false, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - REQUIRE_TRUE(output->isScalar(), 0, "Rank output should be scalar"); - - if(input->isEmpty()){ - output->p(0, std::numeric_limits::quiet_NaN()); - return Status::OK(); - } - - int numZeros = 0; -// for (int e = 0; e < input->lengthOf(); e++) -// if ((*input)(e) == T(0)) -// numZeros++; - auto countZero = input->reduceNumber(reduce::CountZero); - //nd4j_printf("Zero count is %f for %i elements.", countZero.e(0), input->lengthOf()); - //countZero /= double(input->lengthOf()); - output->p(0, countZero.e(0) / double(input->lengthOf())); //printIndexedBuffer("Zero count"); - - return Status::OK(); - } - DECLARE_SHAPE_FN(zero_fraction) { - return SHAPELIST(ConstantShapeHelper::getInstance().scalarShapeInfo(sd::DataType::DOUBLE)); - } - - DECLARE_TYPES(zero_fraction) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } - } +namespace ops { +CUSTOM_OP_IMPL(zero_fraction, 1, 1, false, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + REQUIRE_TRUE(output->isScalar(), 0, "Rank output should be scalar"); + + if (input->isEmpty()) { + output->p(0, std::numeric_limits::quiet_NaN()); + return Status::OK(); + } + + int numZeros = 0; + // for (int e = 0; e < input->lengthOf(); e++) + // if ((*input)(e) == T(0)) + // numZeros++; + auto countZero = input->reduceNumber(reduce::CountZero); + // nd4j_printf("Zero count is %f for %i elements.", countZero.e(0), + // input->lengthOf()); countZero /= double(input->lengthOf()); + output->p( + 0, countZero.e(0) / + double(input->lengthOf())); // printIndexedBuffer("Zero count"); + + return Status::OK(); +} +DECLARE_SHAPE_FN(zero_fraction) { + return SHAPELIST(ConstantShapeHelper::getInstance().scalarShapeInfo( + sd::DataType::DOUBLE)); +} + +DECLARE_TYPES(zero_fraction) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/random/bernoulli.cpp b/libnd4j/include/ops/declarable/generic/random/bernoulli.cpp index ded5bfee5c45..08f32e98a3e8 100644 --- a/libnd4j/include/ops/declarable/generic/random/bernoulli.cpp +++ b/libnd4j/include/ops/declarable/generic/random/bernoulli.cpp @@ -21,49 +21,51 @@ #include #if NOT_EXCLUDED(OP_random_bernoulli) -#include #include +#include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(random_bernoulli, 1, 1, true, 1, 0) { - auto rng = block.getRng(); - // FIXME: to be implemented -/* - if (rng == nullptr) - return Status::THROW("RNG is null, aborting..."); - - auto x = INPUT_VARIABLE(0); - auto z = OUTPUT_VARIABLE(0); +namespace ops { +CUSTOM_OP_IMPL(random_bernoulli, 1, 1, true, 1, 0) { + auto rng = block.getRng(); + // FIXME: to be implemented + /* + if (rng == nullptr) + return Status::THROW("RNG is null, aborting..."); - T f = T_ARG(0); + auto x = INPUT_VARIABLE(0); + auto z = OUTPUT_VARIABLE(0); - functions::random::RandomFunction::template execTransform>(block.getRNG(), z->buffer(), z->shapeInfo(), &f); -*/ + T f = T_ARG(0); - auto z = OUTPUT_VARIABLE(0); - auto f = T_ARG(0); + functions::random::RandomFunction::template + execTransform>(block.getRNG(), + z->buffer(), z->shapeInfo(), &f); + */ - RandomLauncher::fillBernoulli(block.launchContext(), rng, z, f); + auto z = OUTPUT_VARIABLE(0); + auto f = T_ARG(0); - return Status::OK(); - } + RandomLauncher::fillBernoulli(block.launchContext(), rng, z, f); - DECLARE_SHAPE_FN(random_bernoulli) { - auto in = INPUT_VARIABLE(0); - auto shape = in->template asVectorT(); + return Status::OK(); +} - auto newShape = ConstantShapeHelper::getInstance().createShapeInfo(block.dataType(), 'c', shape); - return SHAPELIST(newShape); - } +DECLARE_SHAPE_FN(random_bernoulli) { + auto in = INPUT_VARIABLE(0); + auto shape = in->template asVectorT(); + auto newShape = ConstantShapeHelper::getInstance().createShapeInfo( + DataType::FLOAT32, 'c', shape); + return SHAPELIST(newShape); +} - DECLARE_TYPES(random_bernoulli) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } - } +DECLARE_TYPES(random_bernoulli) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/random/dropout.cpp b/libnd4j/include/ops/declarable/generic/random/dropout.cpp index b64fd49d5677..2deb1842f801 100644 --- a/libnd4j/include/ops/declarable/generic/random/dropout.cpp +++ b/libnd4j/include/ops/declarable/generic/random/dropout.cpp @@ -27,102 +27,105 @@ namespace sd { namespace ops { - ////////////////////////////////////////////////////////////////////////// CONFIGURABLE_OP_IMPL(dropout, 1, 1, true, 1, 1) { - auto input = INPUT_VARIABLE(0); // lookup param + auto input = INPUT_VARIABLE(0); // lookup param + + NDArray* reduceShape = nullptr; // this param is optional + auto output = OUTPUT_NULLIFIED(0); // - NDArray *reduceShape = nullptr; // this param is optional - auto output = OUTPUT_NULLIFIED(0); // - - int seed = INT_ARG(0); + int seed = INT_ARG(0); - // FIXME: float? - double probValue = T_ARG(0); - if (block.width() > 1) - reduceShape = INPUT_VARIABLE(1); + // FIXME: float? + double probValue = T_ARG(0); + if (block.width() > 1) reduceShape = INPUT_VARIABLE(1); - REQUIRE_TRUE(probValue > 0.f && probValue <= 1.f, 0, "dropout: Probability should be with range 0 to 1."); + REQUIRE_TRUE(probValue > 0.f && probValue <= 1.f, 0, + "dropout: Probability should be with range 0 to 1."); - if (probValue == 1.0) { - *output = *input; - return Status::OK(); - } + if (probValue == 1.0) { + *output = *input; + return Status::OK(); + } - return helpers::dropOutFunctor(block, input, output, reduceShape, seed, probValue); + return helpers::dropOutFunctor(block, input, output, reduceShape, seed, + probValue); } - DECLARE_TYPES(dropout) { - getOpDescriptor() - ->setAllowedInputTypes(0, {ALL_FLOATS}) - ->setAllowedInputTypes(1, {ALL_INTS}) - ->setAllowedOutputTypes({ALL_FLOATS}) - ->setSameMode(true); - } +DECLARE_TYPES(dropout) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_FLOATS}) + ->setAllowedInputTypes(1, {ALL_INTS}) + ->setAllowedOutputTypes({ALL_FLOATS}) + ->setSameMode(true); +} ////////////////////////////////////////////////////////////////////////// CONFIGURABLE_OP_IMPL(dropout_bp, 2, 1, false, 1, 1) { - NDArray* input = INPUT_VARIABLE(0); // lookup param - NDArray* gradOut = INPUT_VARIABLE(1); // lookup param - - NDArray* reduceShape = nullptr; // this param is optional - NDArray* output = OUTPUT_NULLIFIED(0); // - - int seed = INT_ARG(0); - - double probValue = T_ARG(0); - if (block.width() > 2) - reduceShape = INPUT_VARIABLE(2); - - REQUIRE_TRUE((probValue > 0. && probValue <= 1.), 0, "dropout_bp: Probability should be with range 0 to 1."); - if (probValue == 1.0) { - output->assign(0.f); // fill up output with 0 - return ND4J_STATUS_OK; - } - - REQUIRE_TRUE(helpers::dropOutFunctorBP(block, input, gradOut, output, reduceShape, seed, probValue) == ND4J_STATUS_OK, 0, "dropout_bp: Cannot backprop dropout." ); + NDArray* input = INPUT_VARIABLE(0); // lookup param + NDArray* gradOut = INPUT_VARIABLE(1); // lookup param + + NDArray* reduceShape = nullptr; // this param is optional + NDArray* output = OUTPUT_NULLIFIED(0); // + + int seed = INT_ARG(0); + + double probValue = T_ARG(0); + if (block.width() > 2) reduceShape = INPUT_VARIABLE(2); + REQUIRE_TRUE((probValue > 0. && probValue <= 1.), 0, + "dropout_bp: Probability should be with range 0 to 1."); + if (probValue == 1.0) { + output->assign(0.f); // fill up output with 0 return ND4J_STATUS_OK; + } + + REQUIRE_TRUE( + helpers::dropOutFunctorBP(block, input, gradOut, output, reduceShape, + seed, probValue) == ND4J_STATUS_OK, + 0, "dropout_bp: Cannot backprop dropout."); + + return ND4J_STATUS_OK; } DECLARE_TYPES(dropout_bp) { - getOpDescriptor() - ->setAllowedInputTypes({ALL_FLOATS, ALL_INTS}) - ->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes({ALL_FLOATS, ALL_INTS}) + ->setAllowedOutputTypes({ALL_FLOATS}); } ////////////////////////////////////////////////////////////////////////// CONFIGURABLE_OP_IMPL(alpha_dropout_bp, 2, 1, false, 4, 1) { - NDArray* input = INPUT_VARIABLE(0); // lookup param - NDArray* gradOut = INPUT_VARIABLE(1); // lookup param + NDArray* input = INPUT_VARIABLE(0); // lookup param + NDArray* gradOut = INPUT_VARIABLE(1); // lookup param - NDArray* reduceShape = nullptr; // this param is optional - NDArray* output = OUTPUT_VARIABLE(0); // + NDArray* reduceShape = nullptr; // this param is optional + NDArray* output = OUTPUT_VARIABLE(0); // - if (block.width() > 2) - reduceShape = INPUT_VARIABLE(2); + if (block.width() > 2) reduceShape = INPUT_VARIABLE(2); - int seed = INT_ARG(0); - - double probValue = T_ARG(0); - double alphaValue = T_ARG(0); - double alpha1Value = T_ARG(2); - double betaValue = T_ARG(3); + int seed = INT_ARG(0); - REQUIRE_TRUE(probValue > 0. && probValue <= 1., 0, "dropout_bp: Probability should be with range 0 to 1."); - if (probValue == 1.0) { - output->assign(0.); // fill up output with 0 - return ND4J_STATUS_OK; - } + double probValue = T_ARG(0); + double alphaValue = T_ARG(0); + double alpha1Value = T_ARG(2); + double betaValue = T_ARG(3); - return helpers::alphaDropOutFunctorBP(block, input, gradOut, output, reduceShape, seed, probValue, alphaValue, alpha1Value, betaValue); -} - DECLARE_TYPES(alpha_dropout_bp) { - getOpDescriptor() - ->setAllowedInputTypes({ALL_FLOATS}) - ->setSameMode(true); - } + REQUIRE_TRUE(probValue > 0. && probValue <= 1., 0, + "dropout_bp: Probability should be with range 0 to 1."); + if (probValue == 1.0) { + output->assign(0.); // fill up output with 0 + return ND4J_STATUS_OK; + } + + return helpers::alphaDropOutFunctorBP(block, input, gradOut, output, + reduceShape, seed, probValue, + alphaValue, alpha1Value, betaValue); } +DECLARE_TYPES(alpha_dropout_bp) { + getOpDescriptor()->setAllowedInputTypes({ALL_FLOATS})->setSameMode(true); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/random/exponential.cpp b/libnd4j/include/ops/declarable/generic/random/exponential.cpp index 735bab5831cd..c6f5b0aa1d5e 100644 --- a/libnd4j/include/ops/declarable/generic/random/exponential.cpp +++ b/libnd4j/include/ops/declarable/generic/random/exponential.cpp @@ -21,37 +21,37 @@ #include #if NOT_EXCLUDED(OP_random_exponential) -#include #include +#include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(random_exponential, 1, 1, true, 1, 0) { - // random generator for distribution - auto rng = block.randomGenerator(); - auto z = OUTPUT_VARIABLE(0); - auto lambda = T_ARG(0); - - RandomLauncher::fillExponential(block.launchContext(), rng, z, lambda); +namespace ops { +CUSTOM_OP_IMPL(random_exponential, 1, 1, true, 1, 0) { + // random generator for distribution + auto rng = block.randomGenerator(); + auto z = OUTPUT_VARIABLE(0); + auto lambda = T_ARG(0); - return Status::OK(); - } + RandomLauncher::fillExponential(block.launchContext(), rng, z, lambda); + return Status::OK(); +} - DECLARE_SHAPE_FN(random_exponential) { - auto in = INPUT_VARIABLE(0); - auto shape = in->template asVectorT(); +DECLARE_SHAPE_FN(random_exponential) { + auto in = INPUT_VARIABLE(0); + auto shape = in->template asVectorT(); - auto newShape = ConstantShapeHelper::getInstance().createShapeInfo(block.dataType(), 'c', shape); - return SHAPELIST(newShape); - } + auto newShape = ConstantShapeHelper::getInstance().createShapeInfo( + DataType::FLOAT32, 'c', shape); + return SHAPELIST(newShape); +} - DECLARE_TYPES(random_exponential) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } - } +DECLARE_TYPES(random_exponential) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/random/gamma.cpp b/libnd4j/include/ops/declarable/generic/random/gamma.cpp index b7dfc9f0614c..bfafb377351f 100644 --- a/libnd4j/include/ops/declarable/generic/random/gamma.cpp +++ b/libnd4j/include/ops/declarable/generic/random/gamma.cpp @@ -25,61 +25,68 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(random_gamma, 2, 1, false, 0, 0) { - // gamma distribution - auto rng = block.randomGenerator(); - auto shape = INPUT_VARIABLE(0); - auto alpha = INPUT_VARIABLE(1); - NDArray* beta = nullptr; +namespace ops { +CUSTOM_OP_IMPL(random_gamma, 2, 1, false, 0, 0) { + // gamma distribution + auto rng = block.randomGenerator(); + auto shape = INPUT_VARIABLE(0); + auto alpha = INPUT_VARIABLE(1); + NDArray* beta = nullptr; - if (block.width() > 2) { - beta = INPUT_VARIABLE(2); - REQUIRE_TRUE(ShapeUtils::areShapesBroadcastable(*alpha, *beta), 0, "random_gamma: alpha and beta shapes should be broadcastable."); - } + if (block.width() > 2) { + beta = INPUT_VARIABLE(2); + REQUIRE_TRUE( + ShapeUtils::areShapesBroadcastable(*alpha, *beta), 0, + "random_gamma: alpha and beta shapes should be broadcastable."); + } - auto output = OUTPUT_VARIABLE(0); - auto seed = 0; + auto output = OUTPUT_VARIABLE(0); + auto seed = 0; - if (block.getIArguments()->size()) { - seed = INT_ARG(0); - } + if (block.numI()) { + seed = INT_ARG(0); + } - rng.setSeed(seed); + rng.setSeed(seed); - helpers::fillRandomGamma(block.launchContext(), rng, alpha, beta, output); + helpers::fillRandomGamma(block.launchContext(), rng, alpha, beta, output); - return Status::OK(); - } + return Status::OK(); +} - DECLARE_SHAPE_FN(random_gamma) { - auto in = INPUT_VARIABLE(0); - auto shape = in->template asVectorT(); - auto alphaShape = inputShape->at(1); - auto additionalShape = alphaShape; - if (inputShape->size() > 2) { - auto rest = inputShape->at(2); additionalShape = nullptr; - REQUIRE_TRUE(ShapeUtils::areShapesBroadcastable(alphaShape, rest), 0, "random_gamma: alpha and beta shapes should be broadcastable."); - const Nd4jLong* additionalShapeBroadcasted = nullptr; - ShapeUtils::evalBroadcastShapeInfo(alphaShape, rest, true, additionalShapeBroadcasted, block.workspace()); - additionalShape = additionalShapeBroadcasted; - } - auto lastDim = shape::sizeAt(alphaShape, 0); - auto dtype = block.numD() > 0? D_ARG(0): ArrayOptions::dataType(alphaShape); - for (auto i = 0; i < shape::rank(additionalShape); i++) - shape.push_back(shape::sizeAt(additionalShape, i)); - auto newShape = ConstantShapeHelper::getInstance().createShapeInfo(dtype, 'c', shape); - return SHAPELIST(newShape); - } +DECLARE_SHAPE_FN(random_gamma) { + auto in = INPUT_VARIABLE(0); + auto shape = in->template asVectorT(); + auto alphaShape = inputShape->at(1); + auto additionalShape = alphaShape; + if (inputShape->size() > 2) { + auto rest = inputShape->at(2); + additionalShape = nullptr; + REQUIRE_TRUE( + ShapeUtils::areShapesBroadcastable(alphaShape, rest), 0, + "random_gamma: alpha and beta shapes should be broadcastable."); + const Nd4jLong* additionalShapeBroadcasted = nullptr; + ShapeUtils::evalBroadcastShapeInfo( + alphaShape, rest, true, additionalShapeBroadcasted, block.workspace()); + additionalShape = additionalShapeBroadcasted; + } + auto lastDim = shape::sizeAt(alphaShape, 0); + auto dtype = block.numD() > 0? D_ARG(0): ArrayOptions::dataType(alphaShape); + for (auto i = 0; i < shape::rank(additionalShape); i++) + shape.push_back(shape::sizeAt(additionalShape, i)); + auto newShape = + ConstantShapeHelper::getInstance().createShapeInfo(dtype, 'c', shape); + return SHAPELIST(newShape); +} - DECLARE_TYPES(random_gamma) { - getOpDescriptor() - ->setAllowedInputTypes(0, {ALL_INTS}) - ->setAllowedInputTypes(1, {ALL_FLOATS}) - ->setAllowedInputTypes(2, {ALL_FLOATS}) - ->setAllowedOutputTypes({ALL_FLOATS}); - } - } +DECLARE_TYPES(random_gamma) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_INTS}) + ->setAllowedInputTypes(1, {ALL_FLOATS}) + ->setAllowedInputTypes(2, {ALL_FLOATS}) + ->setAllowedOutputTypes({ALL_FLOATS}); } +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/generic/random/get_seed.cpp b/libnd4j/include/ops/declarable/generic/random/get_seed.cpp index 9f768e9f3aad..87f9a4550524 100644 --- a/libnd4j/include/ops/declarable/generic/random/get_seed.cpp +++ b/libnd4j/include/ops/declarable/generic/random/get_seed.cpp @@ -24,28 +24,30 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(get_seed, -2, 1, false, 0, 0) { -// REQUIRE_TRUE(block.getRNG() != nullptr, 0, "RNG should be defined in Graph"); - auto rng = block.getRng(); - auto z = OUTPUT_VARIABLE(0); - - z->p(Nd4jLong(0), rng.rootState()); - - return Status::OK(); - } - - DECLARE_SHAPE_FN(get_seed) { - auto newshape = ConstantShapeHelper::getInstance().scalarShapeInfo(DataType::INT64); - return SHAPELIST(newshape); - } - - DECLARE_TYPES(get_seed) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes(DataType::INT64); - } - } +namespace ops { +CUSTOM_OP_IMPL(get_seed, -2, 1, false, 0, 0) { + // REQUIRE_TRUE(block.getRNG() != nullptr, 0, "RNG should be + // defined in Graph"); + auto rng = block.getRng(); + auto z = OUTPUT_VARIABLE(0); + + z->p(Nd4jLong(0), rng.rootState()); + + return Status::OK(); +} + +DECLARE_SHAPE_FN(get_seed) { + auto newshape = + ConstantShapeHelper::getInstance().scalarShapeInfo(DataType::INT64); + return SHAPELIST(newshape); +} + +DECLARE_TYPES(get_seed) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes(DataType::INT64); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/random/multinomial.cpp b/libnd4j/include/ops/declarable/generic/random/multinomial.cpp index 2e8225d2c171..126ca5a588fe 100644 --- a/libnd4j/include/ops/declarable/generic/random/multinomial.cpp +++ b/libnd4j/include/ops/declarable/generic/random/multinomial.cpp @@ -22,93 +22,113 @@ #include #if NOT_EXCLUDED(OP_random_multinomial) -#include #include +#include #include namespace sd { - namespace ops { - /////////////////////// - /** - * multinomial (categorical) random generator - * takes 2D ndarray with logits with shape [batch_size (N), num_classes (K)] - * and array with one scalar value of samples number, number of independent samples to draw for each experiment 1,N. - * represents the unnormalized log-probabilities for all classes. - * Int arguments: 0 - optional argument, corresponds to dimension with batch_size - * Int arguments: 1 - optional argument, integer type to use for the output. Default int64. - */ - // used https://en.wikipedia.org/wiki/Categorical_distribution - // methods: gumbel trick + softmax + argmax - CUSTOM_OP_IMPL(random_multinomial, 2, 1, false, 0, 0) { - - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_NULLIFIED(0); - auto inputSamples = INPUT_VARIABLE(1); - - - REQUIRE_TRUE(!input->isEmpty(), 0, "RANDOM_MULTINOMIAL OP: Have to be provided at least one logits. "); - - REQUIRE_TRUE(inputSamples->lengthOf() == 1, 0, "RANDOM_MULTINOMIAL OP: Have to be specified at least one sample," - " but got no argumets instead."); - - Nd4jLong numOfSamples = static_cast(inputSamples->e(0)); - // do nothing if number of samples = 0 - if (0 == numOfSamples) - return Status::OK(); - - REQUIRE_TRUE(numOfSamples > 0, 0, "RANDOM_MULTINOMIAL OP: Number of samples should be greater then 0, got %i. ", numOfSamples); - - const int rank = input->rankOf(); - REQUIRE_TRUE(rank == 2, 0, "RANDOM_MULTINOMIAL OP: Logits should be a matrix with rank = 2, but got instead rank = %i.", rank); - - const int argSize = block.getIArguments()->size(); - const int dimC = argSize > 0 ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) : rank - 1; - - auto dimA = (0 == dimC) ? 1 : 0; - if (1 == input->sizeAt(dimA)) { - *output = 0; - return Status::OK(); - } - - auto rng = block.randomGenerator(); - helpers::fillRandomMultiNomial(block.launchContext(), rng, *input, *output, numOfSamples, dimC); - return Status::OK(); - } - - - DECLARE_SHAPE_FN(random_multinomial) { - - auto input = INPUT_VARIABLE(0); - auto inputSamples = INPUT_VARIABLE(1); - - REQUIRE_TRUE(inputSamples->lengthOf() == 1, 0, "RANDOM_MULTINOMIAL OP: Have to be specified at least one sample," - " but got no argumets instead."); - - Nd4jLong numOfSamples = static_cast(inputSamples->e(0)); - - REQUIRE_TRUE(numOfSamples > 0, 0, "RANDOM_MULTINOMIAL OP: Number of samples should be greater then 0, got %i. ", numOfSamples); - - const int rank = input->rankOf(); - REQUIRE_TRUE(rank == 2, 0, "RANDOM_MULTINOMIAL OP: Logits should be a matrix with rank = 2, but got instead rank = %i.", rank); - - const int argSize = block.getIArguments()->size(); - const int dimC = argSize > 0 ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) : rank - 1; - - auto nShape = input->getShapeAsVector(); - auto dimA = (0 == dimC) ? 1 : 0; - nShape[dimA] = numOfSamples; - - DataType nType = (argSize > 1) ? ( INT_ARG(1) >= 0 ? static_cast(INT_ARG(1)) : sd::DataType::INT64) : sd::DataType::INT64; - return SHAPELIST(ConstantShapeHelper::getInstance().createShapeInfo(nType, input->ordering(), nShape)); - } - - DECLARE_TYPES(random_multinomial) { - getOpDescriptor() - ->setAllowedInputTypes(0, { ALL_FLOATS, ALL_INTS }) - ->setAllowedInputTypes(1, { sd::DataType::INT32 }) - ->setAllowedOutputTypes(0, { ALL_INDICES }); - } - } +namespace ops { +/////////////////////// +/** + * multinomial (categorical) random generator + * takes 2D ndarray with logits with shape [batch_size (N), num_classes (K)] + * and array with one scalar value of samples number, number of independent + * samples to draw for each experiment 1,N. represents the unnormalized + * log-probabilities for all classes. Int arguments: 0 - optional argument, + * corresponds to dimension with batch_size Int arguments: 1 - optional + * argument, integer type to use for the output. Default int64. + */ +// used https://en.wikipedia.org/wiki/Categorical_distribution +// methods: gumbel trick + softmax + argmax +CUSTOM_OP_IMPL(random_multinomial, 2, 1, false, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_NULLIFIED(0); + auto inputSamples = INPUT_VARIABLE(1); + + REQUIRE_TRUE( + !input->isEmpty(), 0, + "RANDOM_MULTINOMIAL OP: Have to be provided at least one logits. "); + + REQUIRE_TRUE( + inputSamples->lengthOf() == 1, 0, + "RANDOM_MULTINOMIAL OP: Have to be specified at least one sample," + " but got no argumets instead."); + + Nd4jLong numOfSamples = static_cast(inputSamples->e(0)); + // do nothing if number of samples = 0 + if (0 == numOfSamples) return Status::OK(); + + REQUIRE_TRUE(numOfSamples > 0, 0, + "RANDOM_MULTINOMIAL OP: Number of samples should be greater " + "then 0, got %i. ", + numOfSamples); + + const int rank = input->rankOf(); + REQUIRE_TRUE(rank == 2, 0, + "RANDOM_MULTINOMIAL OP: Logits should be a matrix with rank = " + "2, but got instead rank = %i.", + rank); + + const int argSize = block.numI(); + const int dimC = argSize > 0 + ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) + : rank - 1; + + auto dimA = (0 == dimC) ? 1 : 0; + if (1 == input->sizeAt(dimA)) { + *output = 0; + return Status::OK(); + } + + auto rng = block.randomGenerator(); + helpers::fillRandomMultiNomial(block.launchContext(), rng, *input, *output, + numOfSamples, dimC); + return Status::OK(); +} + +DECLARE_SHAPE_FN(random_multinomial) { + auto input = INPUT_VARIABLE(0); + auto inputSamples = INPUT_VARIABLE(1); + + REQUIRE_TRUE( + inputSamples->lengthOf() == 1, 0, + "RANDOM_MULTINOMIAL OP: Have to be specified at least one sample," + " but got no argumets instead."); + + Nd4jLong numOfSamples = static_cast(inputSamples->e(0)); + + REQUIRE_TRUE(numOfSamples > 0, 0, + "RANDOM_MULTINOMIAL OP: Number of samples should be greater " + "then 0, got %i. ", + numOfSamples); + + const int rank = input->rankOf(); + REQUIRE_TRUE(rank == 2, 0, + "RANDOM_MULTINOMIAL OP: Logits should be a matrix with rank = " + "2, but got instead rank = %i.", + rank); + + const int argSize = block.numI(); + const int dimC = argSize > 0 + ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) + : rank - 1; + + auto nShape = input->getShapeAsVector(); + auto dimA = (0 == dimC) ? 1 : 0; + nShape[dimA] = numOfSamples; + + DataType nType = block.numD() ? D_ARG(0) : sd::DataType::INT64; + return SHAPELIST(ConstantShapeHelper::getInstance().createShapeInfo( + nType, input->ordering(), nShape)); +} + +DECLARE_TYPES(random_multinomial) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_FLOATS, ALL_INTS}) + ->setAllowedInputTypes(1, sd::DataType::INT32) + ->setAllowedOutputTypes(0, {ALL_INDICES}); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/random/normal.cpp b/libnd4j/include/ops/declarable/generic/random/normal.cpp index 701570784ddb..e9f0f70c9380 100644 --- a/libnd4j/include/ops/declarable/generic/random/normal.cpp +++ b/libnd4j/include/ops/declarable/generic/random/normal.cpp @@ -21,45 +21,51 @@ #include #if NOT_EXCLUDED(OP_random_normal) -#include #include +#include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(random_normal, 1, 1, true, 2, 0) { - // normal distribution - auto rng = block.randomGenerator(); - // FIXME: to be implemented -/* - REQUIRE_TRUE(rng != nullptr, 0, "RNG isn't defined for this Graph instance"); +namespace ops { +CUSTOM_OP_IMPL(random_normal, 1, 1, true, 2, 0) { + // normal distribution + auto rng = block.randomGenerator(); + // FIXME: to be implemented + /* + REQUIRE_TRUE(rng != nullptr, 0, "RNG isn't defined for this Graph + instance"); - auto x = INPUT_VARIABLE(0); - auto z = OUTPUT_VARIABLE(0); + auto x = INPUT_VARIABLE(0); + auto z = OUTPUT_VARIABLE(0); - functions::random::RandomFunction::template execTransform>(block.getRNG(), z->buffer(), z->shapeInfo(), z->buffer(), z->shapeInfo(), z->buffer(), z->shapeInfo(), block.getTArguments()->data()); -*/ + functions::random::RandomFunction::template + execTransform>(block.getRNG(), + z->buffer(), z->shapeInfo(), z->buffer(), z->shapeInfo(), z->buffer(), + z->shapeInfo(), block.getTArguments()->data()); + */ - RandomLauncher::fillGaussian(block.launchContext(), rng, OUTPUT_VARIABLE(0), T_ARG(0), T_ARG(1)); + RandomLauncher::fillGaussian(block.launchContext(), rng, OUTPUT_VARIABLE(0), + T_ARG(0), T_ARG(1)); - return Status::OK(); - } + return Status::OK(); +} - DECLARE_SHAPE_FN(random_normal) { - auto in = INPUT_VARIABLE(0); - auto shape = in->template asVectorT(); +DECLARE_SHAPE_FN(random_normal) { + auto in = INPUT_VARIABLE(0); + auto shape = in->template asVectorT(); + + auto newShape = ConstantShapeHelper::getInstance().createShapeInfo( + DataType::FLOAT32, 'c', shape); + return SHAPELIST(newShape); +} - auto newShape = ConstantShapeHelper::getInstance().createShapeInfo(block.dataType(), 'c', shape); - return SHAPELIST(newShape); - } - - DECLARE_SYN(randomnormal, random_normal); +DECLARE_SYN(randomnormal, random_normal); - DECLARE_TYPES(random_normal) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } - } +DECLARE_TYPES(random_normal) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/random/poisson.cpp b/libnd4j/include/ops/declarable/generic/random/poisson.cpp index 2eb601bc9be9..0c190b2c8372 100644 --- a/libnd4j/include/ops/declarable/generic/random/poisson.cpp +++ b/libnd4j/include/ops/declarable/generic/random/poisson.cpp @@ -25,43 +25,43 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(random_poisson, 2, 1, false, 0, 0) { - // gamma distribution - auto rng = block.randomGenerator(); - auto shape = INPUT_VARIABLE(0); - auto lambda = INPUT_VARIABLE(1); - auto output = OUTPUT_VARIABLE(0); - auto seed = 0; - if (block.getIArguments()->size()) { - seed = INT_ARG(0); - } - rng.setSeed(seed); - helpers::fillRandomPoisson(block.launchContext(), rng, lambda, output); - - return Status::OK(); - } +namespace ops { +CUSTOM_OP_IMPL(random_poisson, 2, 1, false, 0, 0) { + // gamma distribution + auto rng = block.randomGenerator(); + auto shape = INPUT_VARIABLE(0); + auto lambda = INPUT_VARIABLE(1); + auto output = OUTPUT_VARIABLE(0); + auto seed = 0; + if (block.numI()) { + seed = INT_ARG(0); + } + rng.setSeed(seed); + helpers::fillRandomPoisson(block.launchContext(), rng, lambda, output); + return Status::OK(); +} - DECLARE_SHAPE_FN(random_poisson) { - auto in = INPUT_VARIABLE(0); - auto shape = in->template asVectorT(); - auto lambdaShape = inputShape->at(1); - auto dtype = block.numD() > 0? D_ARG(0) : ArrayOptions::dataType(lambdaShape); - for (auto d = 0; d < shape::rank(lambdaShape); ++d ) { - shape.emplace_back(shape::sizeAt(lambdaShape, d)); - } - auto newShape = ConstantShapeHelper::getInstance().createShapeInfo(dtype, 'c', shape); - return SHAPELIST(newShape); - } +DECLARE_SHAPE_FN(random_poisson) { + auto in = INPUT_VARIABLE(0); + auto shape = in->template asVectorT(); + auto lambdaShape = inputShape->at(1); + auto dtype = block.numD() > 0? D_ARG(0) : ArrayOptions::dataType(lambdaShape); + for (auto d = 0; d < shape::rank(lambdaShape); ++d) { + shape.emplace_back(shape::sizeAt(lambdaShape, d)); + } + auto newShape = + ConstantShapeHelper::getInstance().createShapeInfo(dtype, 'c', shape); + return SHAPELIST(newShape); +} - DECLARE_TYPES(random_poisson) { - getOpDescriptor() - ->setAllowedInputTypes(0, {ALL_INTS}) - ->setAllowedInputTypes(1, {ALL_FLOATS}) - ->setAllowedOutputTypes({ALL_FLOATS}); - } - } +DECLARE_TYPES(random_poisson) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_INTS}) + ->setAllowedInputTypes(1, {ALL_FLOATS}) + ->setAllowedOutputTypes({ALL_FLOATS}); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/random/random_crop.cpp b/libnd4j/include/ops/declarable/generic/random/random_crop.cpp index 1b30b2f91fe3..fdcd1490c63c 100644 --- a/libnd4j/include/ops/declarable/generic/random/random_crop.cpp +++ b/libnd4j/include/ops/declarable/generic/random/random_crop.cpp @@ -18,55 +18,58 @@ // Created by GS // - #include #include namespace sd { namespace ops { - ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(random_crop, 2, 1, false, 0, 0) { - auto input = INPUT_VARIABLE(0); // values for crop - auto shape = INPUT_VARIABLE(1); // shape for result + auto input = INPUT_VARIABLE(0); // values for crop + auto shape = INPUT_VARIABLE(1); // shape for result + + NDArray* reduceShape = nullptr; // this param is optional + auto output = OUTPUT_VARIABLE(0); // - NDArray* reduceShape = nullptr; // this param is optional - auto output = OUTPUT_VARIABLE(0); // - - int seed = 0; + int seed = 0; - if (block.getIArguments()->size() > 0) - seed = INT_ARG(0); + if (block.numI() > 0) seed = INT_ARG(0); - REQUIRE_TRUE(shape->isVector(), 0, "random_crop: Shape tensor should be a vector."); - - REQUIRE_TRUE(input->rankOf() == shape->lengthOf(), 0, "random_crop: The length of the shape vector is not match input rank. %i and %i were given.", - input->rankOf(), shape->lengthOf()); + REQUIRE_TRUE(shape->isVector(), 0, + "random_crop: Shape tensor should be a vector."); - for (int e = 0; e < shape->lengthOf(); ++e) { - REQUIRE_TRUE((*shape).e(e) <= input->sizeAt(e), 0, "random_crop: Shape tensor should be less than proper input dimension (dim %i, %i > %i).", e, (*shape).e(e), input->sizeAt(e)); - } + REQUIRE_TRUE(input->rankOf() == shape->lengthOf(), 0, + "random_crop: The length of the shape vector is not match input " + "rank. %i and %i were given.", + input->rankOf(), shape->lengthOf()); - return helpers::randomCropFunctor(block, input, shape, output, seed); + for (int e = 0; e < shape->lengthOf(); ++e) { + REQUIRE_TRUE((*shape).e(e) <= input->sizeAt(e), 0, + "random_crop: Shape tensor should be less than proper input " + "dimension (dim %i, %i > %i).", + e, (*shape).e(e), input->sizeAt(e)); + } + + return helpers::randomCropFunctor(block, input, shape, output, seed); } DECLARE_SHAPE_FN(random_crop) { - auto in = INPUT_VARIABLE(1); - auto typeShape = inputShape->at(0); - std::vector shape(in->lengthOf()); - - for (int e = 0; e < shape.size(); e++) - shape[e] = (*in).e(e); - - auto newShape = ConstantShapeHelper::getInstance().createShapeInfo(ArrayOptions::dataType(typeShape), 'c', shape); - return SHAPELIST(newShape); + auto in = INPUT_VARIABLE(1); + auto typeShape = inputShape->at(0); + std::vector shape(in->lengthOf()); + + for (int e = 0; e < shape.size(); e++) shape[e] = (*in).e(e); + + auto newShape = ConstantShapeHelper::getInstance().createShapeInfo( + ArrayOptions::dataType(typeShape), 'c', shape); + return SHAPELIST(newShape); } - DECLARE_TYPES(random_crop) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } +DECLARE_TYPES(random_crop) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } -} \ No newline at end of file +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/random/random_shuffle.cpp b/libnd4j/include/ops/declarable/generic/random/random_shuffle.cpp index 012d2e55ae43..4f8741207154 100644 --- a/libnd4j/include/ops/declarable/generic/random/random_shuffle.cpp +++ b/libnd4j/include/ops/declarable/generic/random/random_shuffle.cpp @@ -22,33 +22,31 @@ #if NOT_EXCLUDED(OP_random_shuffle) #include -#include +#include namespace sd { -namespace ops { +namespace ops { OP_IMPL(random_shuffle, 1, 1, true) { - - auto input = INPUT_VARIABLE(0); - const bool isInplace = block.isInplace(); - auto output = isInplace ? nullptr : OUTPUT_VARIABLE(0); - -// sd::random::RandomBuffer* rng = block.getRNG(); - sd::graph::RandomGenerator rng = block.randomGenerator(); -// REQUIRE_TRUE(rng != nullptr, 0, "RANDOM_SHUFFLE op: RNG should be defined in Graph !"); - - helpers::randomShuffle(block.launchContext(), *input, *output, rng, isInplace); - - return Status::OK(); -} + auto input = INPUT_VARIABLE(0); + const bool isInplace = block.isInplace(); + auto output = isInplace ? nullptr : OUTPUT_VARIABLE(0); + + // sd::random::RandomBuffer* rng = block.getRNG(); + sd::graph::RandomGenerator rng = block.randomGenerator(); + // REQUIRE_TRUE(rng != nullptr, 0, "RANDOM_SHUFFLE op: RNG should be + // defined in Graph !"); + helpers::randomShuffle(block.launchContext(), *input, *output, rng, + isInplace); - DECLARE_TYPES(random_shuffle) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setSameMode(true); - } + return Status::OK(); } + +DECLARE_TYPES(random_shuffle) { + getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setSameMode(true); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/random/set_seed.cpp b/libnd4j/include/ops/declarable/generic/random/set_seed.cpp index f7050f3ab828..f573a6a69faa 100644 --- a/libnd4j/include/ops/declarable/generic/random/set_seed.cpp +++ b/libnd4j/include/ops/declarable/generic/random/set_seed.cpp @@ -21,43 +21,48 @@ #include #if NOT_EXCLUDED(OP_set_seed) -#include #include +#include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(set_seed, -2, 1, false, 0, -2) { -// REQUIRE_TRUE(block.getRNG() != nullptr, 0, "RNG should be defined in Graph"); - auto rng = block.getRng(); //.getRNG(); - - Nd4jLong seed = 0; - if (block.getIArguments()->size() > 0) { - seed = INT_ARG(0); - } else if (block.width() > 0) { - auto input = INPUT_VARIABLE(0); - REQUIRE_TRUE(input->isScalar(),0 ,"SetSeed: Seed operand should be scalar"); - seed = input->e(0); - } else { - REQUIRE_TRUE(false, 0, "SetSeed: either IArg or scalr input should be provided"); - } - - // FIXME: this approach isn't really good for cuda, since it'll assume that CUDA might get nullptr instead of stream - //refreshBuffer(nullptr, seed, (Nd4jPointer) rng); - rng.setSeed((int)seed); - return Status::OK(); - } - - DECLARE_SHAPE_FN(set_seed) { - auto newshape = ConstantShapeHelper::getInstance().scalarShapeInfo(block.dataType()); - return SHAPELIST(newshape); - } - - DECLARE_TYPES(set_seed) { - getOpDescriptor() - ->setAllowedInputTypes({ALL_INTS}) - ->setAllowedOutputTypes({ALL_FLOATS}); - } - } +namespace ops { +CUSTOM_OP_IMPL(set_seed, -2, 1, false, 0, -2) { + // REQUIRE_TRUE(block.getRNG() != nullptr, 0, "RNG should be + // defined in Graph"); + auto rng = block.getRng(); //.getRNG(); + + Nd4jLong seed = 0; + if (block.numI() > 0) { + seed = INT_ARG(0); + } else if (block.width() > 0) { + auto input = INPUT_VARIABLE(0); + REQUIRE_TRUE(input->isScalar(), 0, + "SetSeed: Seed operand should be scalar"); + seed = input->e(0); + } else { + REQUIRE_TRUE(false, 0, + "SetSeed: either IArg or scalr input should be provided"); + } + + // FIXME: this approach isn't really good for cuda, since it'll assume that + // CUDA might get nullptr instead of stream + // refreshBuffer(nullptr, seed, (Nd4jPointer) rng); + rng.setSeed((int)seed); + return Status::OK(); +} + +DECLARE_SHAPE_FN(set_seed) { + auto newshape = + ConstantShapeHelper::getInstance().scalarShapeInfo(DataType::FLOAT32); + return SHAPELIST(newshape); +} + +DECLARE_TYPES(set_seed) { + getOpDescriptor() + ->setAllowedInputTypes({ALL_INTS}) + ->setAllowedOutputTypes({ALL_FLOATS}); } +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/generic/random/uniform.cpp b/libnd4j/include/ops/declarable/generic/random/uniform.cpp index d4abccf78308..a4936342d302 100644 --- a/libnd4j/include/ops/declarable/generic/random/uniform.cpp +++ b/libnd4j/include/ops/declarable/generic/random/uniform.cpp @@ -22,76 +22,78 @@ #include #if NOT_EXCLUDED(OP_randomuniform) -#include #include +#include #include namespace sd { - namespace ops { - /////////////////////// - /** - * uniform distribution - * takes 1 ndarray - * - * T argumens map: - * TArgs[0] - min for rng - * TArgs[1] - max for rng - */ - CUSTOM_OP_IMPL(randomuniform, 1, 1, true, 0, 0) { - // uniform distribution - auto rng = block.randomGenerator(); - auto dtype = DataType::FLOAT32; - if (block.getIArguments()->size()) - dtype = (DataType)INT_ARG(0); - - auto min = block.width() > 1 ? INPUT_VARIABLE(1) : (NDArray*) nullptr; - auto max = block.width() > 2 ? INPUT_VARIABLE(2) : (NDArray*) nullptr; - bool disposable = false; +namespace ops { +/////////////////////// +/** + * uniform distribution + * takes 1 ndarray + * + * T argumens map: + * TArgs[0] - min for rng + * TArgs[1] - max for rng + */ +CUSTOM_OP_IMPL(randomuniform, 1, 1, true, 0, 0) { + // uniform distribution + auto rng = block.randomGenerator(); + auto dtype = DataType::FLOAT32; + if (block.numI()) dtype = (DataType)INT_ARG(0); - if (min == nullptr && max == nullptr && block.numT() >= 2) { - min = NDArrayFactory::create_(dtype, block.launchContext()); - max = NDArrayFactory::create_(dtype, block.launchContext()); - min->p(0, T_ARG(0)); - max->p(0, T_ARG(1)); - disposable = true; - } + auto min = block.width() > 1 ? INPUT_VARIABLE(1) : (NDArray*)nullptr; + auto max = block.width() > 2 ? INPUT_VARIABLE(2) : (NDArray*)nullptr; + bool disposable = false; - auto output = OUTPUT_VARIABLE(0); - REQUIRE_TRUE(output->dataType() == dtype, 0, "RandomUniform: data type of output should be equals to given."); + if (min == nullptr && max == nullptr && block.numT() >= 2) { + min = NDArrayFactory::create_(dtype, block.launchContext()); + max = NDArrayFactory::create_(dtype, block.launchContext()); + min->p(0, T_ARG(0)); + max->p(0, T_ARG(1)); + disposable = true; + } - helpers::fillRandomUniform(block.launchContext(), rng, min, max, output); + auto output = OUTPUT_VARIABLE(0); + REQUIRE_TRUE(output->dataType() == dtype, 0, + "RandomUniform: data type of output should be equals to given."); - if (disposable) { - delete min; - delete max; - } - return Status::OK(); - } + helpers::fillRandomUniform(block.launchContext(), rng, min, max, output); + if (disposable) { + delete min; + delete max; + } + return Status::OK(); +} - DECLARE_SHAPE_FN(randomuniform) { - auto in = INPUT_VARIABLE(0); - //auto min = INPUT_VARIABLE(1); - auto shape = in->template asVectorT(); - auto dtype = DataType::FLOAT32; //ArrayOptions::dataType(inputShape->at(1)); // output type is by given min +DECLARE_SHAPE_FN(randomuniform) { + auto in = INPUT_VARIABLE(0); + // auto min = INPUT_VARIABLE(1); + auto shape = in->template asVectorT(); + auto dtype = DataType::FLOAT32; // ArrayOptions::dataType(inputShape->at(1)); + // // output type is by given min - if (block.getIArguments()->size()) - dtype = (DataType)INT_ARG(0); - if (block.width() > 1) - REQUIRE_TRUE(dtype == INPUT_VARIABLE(1)->dataType(), 0, "RandomUniform: data type of output and min/max args should be the same"); + if (block.numI()) dtype = (DataType)INT_ARG(0); + if (block.width() > 1) + REQUIRE_TRUE(dtype == INPUT_VARIABLE(1)->dataType(), 0, + "RandomUniform: data type of output and min/max args should " + "be the same"); - auto newShape = ConstantShapeHelper::getInstance().createShapeInfo(dtype, 'c', shape); - return SHAPELIST(newShape); - } + auto newShape = + ConstantShapeHelper::getInstance().createShapeInfo(dtype, 'c', shape); + return SHAPELIST(newShape); +} - DECLARE_TYPES(randomuniform) { - getOpDescriptor() - ->setAllowedInputTypes(0, {ALL_INTS}) - ->setAllowedInputTypes(1, {ALL_INTS, ALL_FLOATS}) - ->setAllowedInputTypes(2, {ALL_INTS, ALL_FLOATS}) - ->setAllowedOutputTypes({ALL_FLOATS, ALL_INTS}); - } - } +DECLARE_TYPES(randomuniform) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_INTS}) + ->setAllowedInputTypes(1, {ALL_INTS, ALL_FLOATS}) + ->setAllowedInputTypes(2, {ALL_INTS, ALL_FLOATS}) + ->setAllowedOutputTypes({ALL_FLOATS, ALL_INTS}); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/reduce/argamax.cpp b/libnd4j/include/ops/declarable/generic/reduce/argamax.cpp index a347c398a8b1..b3008c923ded 100644 --- a/libnd4j/include/ops/declarable/generic/reduce/argamax.cpp +++ b/libnd4j/include/ops/declarable/generic/reduce/argamax.cpp @@ -39,7 +39,7 @@ namespace sd { if (output->isEmpty()) return Status::OK(); - auto axis = *block.getIArguments(); + auto axis = block.getIArguments(); // axis might be dynamic (i.e. tf mode) if (block.width() > 1 && axis.size() == 0) { @@ -60,7 +60,7 @@ namespace sd { std::vector dims; if (block.width() == 1) { - dims = *block.getIArguments(); + dims = block.getIArguments(); } else { auto y = INPUT_VARIABLE(1); dims = y->template asVectorT(); @@ -87,7 +87,7 @@ namespace sd { return SHAPELIST(ConstantShapeHelper::getInstance().scalarShapeInfo(dtype)); } - return SHAPELIST(ShapeUtils::evalReduceShapeInfo('c', dims, inputShape->at(0), dtype, keepDims, false, block.getWorkspace())); + return SHAPELIST(ShapeUtils::evalReduceShapeInfo('c', dims, inputShape->at(0), dtype, keepDims, false, block.workspace())); } } } diff --git a/libnd4j/include/ops/declarable/generic/reduce/argamin.cpp b/libnd4j/include/ops/declarable/generic/reduce/argamin.cpp index 68ad9d2e5b7f..491696858b2a 100644 --- a/libnd4j/include/ops/declarable/generic/reduce/argamin.cpp +++ b/libnd4j/include/ops/declarable/generic/reduce/argamin.cpp @@ -39,7 +39,7 @@ namespace sd { if (output->isEmpty()) return Status::OK(); - auto axis = *block.getIArguments(); + auto axis = block.getIArguments(); // axis might be dynamic (i.e. tf mode) if (block.width() > 1 && axis.size() == 0) { @@ -60,7 +60,7 @@ namespace sd { std::vector dims; if (block.width() == 1) { - dims = *block.getIArguments(); + dims = block.getIArguments(); } else { auto y = INPUT_VARIABLE(1); dims = y->template asVectorT(); @@ -87,7 +87,7 @@ namespace sd { return SHAPELIST(ConstantShapeHelper::getInstance().scalarShapeInfo(dtype)); } - return SHAPELIST(ShapeUtils::evalReduceShapeInfo('c', dims, inputShape->at(0), dtype, keepDims, false, block.getWorkspace())); + return SHAPELIST(ShapeUtils::evalReduceShapeInfo('c', dims, inputShape->at(0), dtype, keepDims, false, block.workspace())); } } } diff --git a/libnd4j/include/ops/declarable/generic/reduce/argmax.cpp b/libnd4j/include/ops/declarable/generic/reduce/argmax.cpp index f8a2486fa1c8..acd8a669138a 100644 --- a/libnd4j/include/ops/declarable/generic/reduce/argmax.cpp +++ b/libnd4j/include/ops/declarable/generic/reduce/argmax.cpp @@ -23,75 +23,82 @@ #include #include -#include #include +#include +#include namespace sd { - namespace ops { - DECLARE_TYPES(argmax) { - getOpDescriptor() - ->setAllowedInputTypes({ ALL_FLOATS,ALL_INTS }) - ->setAllowedOutputTypes({ALL_INTS}); - } - - CUSTOM_OP_IMPL(argmax, 1, 1, false, 0, -2) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - if (output->isEmpty()) +namespace ops { +DECLARE_TYPES(argmax) { + getOpDescriptor() + ->setAllowedInputTypes({ ALL_FLOATS,ALL_INTS }) + ->setAllowedOutputTypes({ALL_INTS}); +} + +CUSTOM_OP_IMPL(argmax, 1, 1, false, 0, -2) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + if (output->isEmpty()) return Status::OK(); - auto axis = *block.getIArguments(); - - // axis might be dynamic (i.e. tf mode) - if (block.width() > 1 && axis.size() == 0) { - auto axisVector = INPUT_VARIABLE(1); - helpers::adjustAxis(input->rankOf(), axisVector, axis); - helpers::argMax(*input, *output, axis); - } else { - helpers::argMax(*input, *output, axis); + auto axis = block.getIArguments(); - } + // axis might be dynamic (i.e. tf mode) + if (block.width() > 1 && axis.size() == 0) { + auto axisVector = INPUT_VARIABLE(1); + helpers::adjustAxis(input->rankOf(), axisVector, axis); - STORE_RESULT(output); + helpers::argMax(*input, *output, axis); + } else { + helpers::argMax(*input, *output, axis); - return Status::OK(); - } - DECLARE_SHAPE_FN(argmax) { - std::vector dims; + } - if (block.width() == 1) { - dims = *block.getIArguments(); - } else { - auto y = INPUT_VARIABLE(1); - dims = y->template asVectorT(); - } + STORE_RESULT(output); - auto keepDims = block.numB() ? B_ARG(0) : false; - auto dtype = block.numD() ? D_ARG(0) : DataType::INT64; + return Status::OK(); +} + +DECLARE_SHAPE_FN(argmax) { + std::vector dims; + + if (block.width() == 1) { + dims = block.getIArguments(); + } else { + auto y = INPUT_VARIABLE(1); + dims = y->template asVectorT(); + } - // we're resolving negative axis here - helpers::adjustAxis(shape::rank(inputShape->at(0)), dims); + auto keepDims = block.numB() ? B_ARG(0) : false; + auto dtype = block.numD() ? D_ARG(0) : DataType::INT64;// we're resolving negative axis here + helpers::adjustAxis(shape::rank(inputShape->at(0)), dims); - auto in = inputShape->at(0); - for (auto d : dims) { - // we have special case here + auto in = inputShape->at(0); + + for (auto d : dims) { + // we have special case here if (d == sd::DataTypeUtils::max()) continue; REQUIRE_TRUE(d < shape::rank(in), 0, "ArgMax: axis can't be above rank") - REQUIRE_TRUE(in[d + 1] != 0, 0, "ArgMax: you can't reduce along axis with 0 in shape"); - } - - // special case - output is scalar - if (dims.empty() || (dims.size() == 1 && dims.at(0) == sd::DataTypeUtils::max())) { - return SHAPELIST(ConstantShapeHelper::getInstance().scalarShapeInfo(dtype)); - } - - return SHAPELIST(ShapeUtils::evalReduceShapeInfo('c', dims, inputShape->at(0), dtype, keepDims, false, block.getWorkspace())); - } - } + REQUIRE_TRUE(in[d + 1] != 0, 0, + "ArgMax: you can't reduce along axis with 0 in shape"); + } + + // special case - output is scalar + if (dims.empty() || + (dims.size() == 1 && dims.at(0) == sd::DataTypeUtils::max())) { + return SHAPELIST(ConstantShapeHelper::getInstance().scalarShapeInfo( + dtype)); + } + + return SHAPELIST(ShapeUtils::evalReduceShapeInfo('c', dims, inputShape->at(0), + dtype, keepDims, + false, block.workspace())); } +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/generic/reduce/argmin.cpp b/libnd4j/include/ops/declarable/generic/reduce/argmin.cpp index 40648b7f6ae7..768bb297d7ce 100644 --- a/libnd4j/include/ops/declarable/generic/reduce/argmin.cpp +++ b/libnd4j/include/ops/declarable/generic/reduce/argmin.cpp @@ -27,74 +27,74 @@ #include namespace sd { - namespace ops { +namespace ops { - DECLARE_TYPES(argmin) { - getOpDescriptor() - ->setAllowedInputTypes({ ALL_FLOATS,ALL_INTS }) - ->setAllowedOutputTypes({ALL_INTS}); - } - - CUSTOM_OP_IMPL(argmin, 1, 1, false, 0, -2) { - auto input = INPUT_VARIABLE(0); - auto axis = *block.getIArguments(); +DECLARE_TYPES(argmin) { + getOpDescriptor() + ->setAllowedInputTypes({ ALL_FLOATS,ALL_INTS }) + ->setAllowedOutputTypes({ALL_INTS}); +} - auto output = OUTPUT_VARIABLE(0); +CUSTOM_OP_IMPL(argmin, 1, 1, false, 0, -2) { + auto input = INPUT_VARIABLE(0); + auto axis = block.getIArguments(); - if (output->isEmpty()) - return Status::OK(); + auto output = OUTPUT_VARIABLE(0); - // axis might be dynamic (i.e. tf mode) - if (block.width() > 1 && axis.size() == 0) { - auto axisVector = INPUT_VARIABLE(1); - helpers::adjustAxis(input->rankOf(), axisVector, axis); - helpers::argMin(*input, *output, axis); + if (output->isEmpty()) + return Status::OK();// axis might be dynamic (i.e. tf mode) + if (block.width() > 1 && axis.size() == 0) { + auto axisVector = INPUT_VARIABLE(1); + helpers::adjustAxis(input->rankOf(), axisVector, axis);helpers::argMin(*input, *output, axis); } else { helpers::argMin(*input, *output, axis); - } - STORE_RESULT(output); + } - return ND4J_STATUS_OK; - } + STORE_RESULT(output); - DECLARE_SHAPE_FN(argmin) { - std::vector dims; + return ND4J_STATUS_OK; +} - if (block.width() == 1) { - dims = *block.getIArguments(); - } else { - auto y = INPUT_VARIABLE(1); - dims = y->template asVectorT(); - } +DECLARE_SHAPE_FN(argmin) { + std::vector dims; - auto keepDims = block.numB() ? B_ARG(0) : false; - auto dtype = block.numD() ? D_ARG(0) : DataType::INT64; + if (block.width() == 1) { + dims = block.getIArguments(); + } else { + auto y = INPUT_VARIABLE(1); + dims = y->template asVectorT(); + } - // we're resolving negative axis here - helpers::adjustAxis(shape::rank(inputShape->at(0)), dims); + auto keepDims = block.numB() ? B_ARG(0) : false; + auto dtype = block.numD() ? D_ARG(0) : DataType::INT64;// we're resolving negative axis here + helpers::adjustAxis(shape::rank(inputShape->at(0)), dims); - auto in = inputShape->at(0); + auto in = inputShape->at(0) ; for (auto d : dims) { // we have special case here - if (d == sd::DataTypeUtils::max()) + if ( d == sd::DataTypeUtils::max()) continue; + REQUIRE_TRUE(d < shape::rank(in), 0, "ArgMin: axis can't be above rank") + REQUIRE_TRUE(in[d + 1] != 0, 0, + "ArgMin: you can't reduce along axis with 0 in shape"); + } - REQUIRE_TRUE(d < shape::rank(in), 0, "ArgMin: axis can't be above rank") - REQUIRE_TRUE(in[d + 1] != 0, 0, "ArgMin: you can't reduce along axis with 0 in shape"); - } + // special case - output is scalar + if (dims.empty() || + (dims.size() == 1 && dims.at(0) == sd::DataTypeUtils::max())) { + return SHAPELIST( + ConstantShapeHelper::getInstance().scalarShapeInfo(dtype)); + } - // special case - output is scalar - if (dims.empty() || (dims.size() == 1 && dims.at(0) == sd::DataTypeUtils::max())) { - return SHAPELIST(ConstantShapeHelper::getInstance().scalarShapeInfo(dtype)); - } + return SHAPELIST( ShapeUtils::evalReduceShapeInfo( + 'c', dims, inputShape->at(0), dtype, keepDims, false, block.workspace())); - return SHAPELIST(ShapeUtils::evalReduceShapeInfo('c', dims, inputShape->at(0), dtype, keepDims, false, block.getWorkspace())); - } - - } } +} // namespace ops +} // namespace sd + #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/reduce/norm.cpp b/libnd4j/include/ops/declarable/generic/reduce/norm.cpp index 64c2e5ccb672..6ae9ea962588 100644 --- a/libnd4j/include/ops/declarable/generic/reduce/norm.cpp +++ b/libnd4j/include/ops/declarable/generic/reduce/norm.cpp @@ -25,82 +25,87 @@ #include namespace sd { - namespace ops { - REDUCTION_OP_IMPL(norm, 1, 1, false, 1, -2) { - auto input = INPUT_VARIABLE(0); - NDArray *output = OUTPUT_VARIABLE(0); +namespace ops { +REDUCTION_OP_IMPL(norm, 1, 1, false, 1, -2) { + auto input = INPUT_VARIABLE(0); + NDArray *output = OUTPUT_VARIABLE(0); - auto mode = (int) T_ARG(0); - std::vector dims = *block.getIArguments(); - bool overwrite = false; + auto mode = (int)T_ARG(0); + std::vector dims = block.getIArguments(); + bool overwrite = false; - if (block.width() == 1) { - output = OUTPUT_VARIABLE(0); - } else { - auto axisVector = INPUT_VARIABLE(1); - dims.resize(axisVector->lengthOf()); - helpers::adjustAxis(input->rankOf(), axisVector, dims); - axisVector->printIndexedBuffer("AXIS"); - auto shape = ShapeUtils::evalReduceShapeInfo(input->ordering(), dims, *input, false, false); - if (!shape::equalsStrict(shape, output->shapeInfo())) { - output = new NDArray(shape, false, block.launchContext()); - overwrite = true; - } - } - output->printShapeInfo("Output Shape Info"); - switch(mode) { - case 0: { - REQUIRE_TRUE(dims.size() == 2 || (input->rankOf() == 2 && dims.size() == 0), 0, "Norm: Frobenius is defined for 2D matrices or TADS only"); - // fro - input->reduceAlongDimension(reduce::NormFrobenius, *output, dims, false, output->rankOf() == 2); - } - break; - case 1: { - // euclidean - if ((input->rankOf() == 2 && dims.size() == 0) || dims.size() == 2) { - input->reduceAlongDimension(reduce::NormFrobenius, *output, dims, false, output->rankOf() == 2); - } else { - input->reduceAlongDimension(reduce::Norm2, *output, dims, false, output->rankOf() == 2); - } - } - break; - case 2: { - // 1 - input->reduceAlongDimension(reduce::Norm1, *output, dims, false, output->rankOf() == 2); - } - break; - case 3: { - // 2 - input->reduceAlongDimension(reduce::Norm2, *output, dims, false, output->rankOf() == 2); - } - break; - case 4: { - // inf-norm - input->reduceAlongDimension(reduce::NormMax, *output, dims, false, output->rankOf() == 2); - } - break; - default: { - // p-norm - REQUIRE_TRUE(block.getIArguments()->size() > 1, 0, "P-Norm reductions requires 2 TArguments, but only 1 was provided"); - // FIXME: p is required here - //T p = T_ARG(1); - input->reduceAlongDimension(reduce::NormP, *output, dims, false, output->rankOf() == 2); - } - } + if (block.width() == 1) { + output = OUTPUT_VARIABLE(0); + } else { + auto axisVector = INPUT_VARIABLE(1); + dims.resize(axisVector->lengthOf()); + helpers::adjustAxis(input->rankOf(), axisVector, dims); + auto shape = ShapeUtils::evalReduceShapeInfo(input->ordering(), dims, + *input, false, false); + if (!shape::equalsStrict(shape, output->shapeInfo())) { + output = new NDArray(shape, false, block.launchContext()); + overwrite = true; + } + } + switch (mode) { + case 0: { + REQUIRE_TRUE( + dims.size() == 2 || (input->rankOf() == 2 && dims.size() == 0), 0, + "Norm: Frobenius is defined for 2D matrices or TADS only"); + // fro + input->reduceAlongDimension(reduce::NormFrobenius, *output, dims, false, + output->rankOf() == 2); + } break; + case 1: { + // euclidean + if ((input->rankOf() == 2 && dims.size() == 0) || dims.size() == 2) { + input->reduceAlongDimension(reduce::NormFrobenius, *output, dims, false, + output->rankOf() == 2); + } else { + input->reduceAlongDimension(reduce::Norm2, *output, dims, false, + output->rankOf() == 2); + } + } break; + case 2: { + // 1 + input->reduceAlongDimension(reduce::Norm1, *output, dims, false, + output->rankOf() == 2); + } break; + case 3: { + // 2 + input->reduceAlongDimension(reduce::Norm2, *output, dims, false, + output->rankOf() == 2); + } break; + case 4: { + // inf-norm + input->reduceAlongDimension(reduce::NormMax, *output, dims, false, + output->rankOf() == 2); + } break; + default: { + // p-norm + REQUIRE_TRUE( + block.numI() > 1, 0, + "P-Norm reductions requires 2 TArguments, but only 1 was provided"); + // FIXME: p is required here + // T p = T_ARG(1); + input->reduceAlongDimension(reduce::NormP, *output, dims, false, + output->rankOf() == 2); + } + } - if (overwrite) { - OVERWRITE_RESULT(output); - } + if (overwrite) { + OVERWRITE_RESULT(output); + } - return ND4J_STATUS_OK; - }; + return ND4J_STATUS_OK; +}; - DECLARE_TYPES(norm) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } - } +DECLARE_TYPES(norm) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/reduce/reduceMean.cpp b/libnd4j/include/ops/declarable/generic/reduce/reduceMean.cpp index 90560bbb6feb..5681a7e49eeb 100644 --- a/libnd4j/include/ops/declarable/generic/reduce/reduceMean.cpp +++ b/libnd4j/include/ops/declarable/generic/reduce/reduceMean.cpp @@ -18,143 +18,165 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 01.06.2018 // - #include #include -namespace sd { -namespace ops { +namespace sd { +namespace ops { ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(reduce_mean, 1, 1, false, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - auto dimensions = *block.getIArguments(); - if (block.width() > 1) { - auto axesVector = INPUT_VARIABLE(1); - helpers::adjustAxis(input->rankOf(), axesVector, dimensions); - } - - bool keepDims = false; - if (block.getBArguments()->size()) - keepDims = B_ARG(0); - else if (block.getTArguments()->size()) - keepDims = (bool)T_ARG(0); - - REQUIRE_TRUE(dimensions.size() <= input->rankOf(), 0, "REDUCE_MEAN OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); - - for(const auto& item : dimensions) - REQUIRE_TRUE(item >= -input->rankOf() && item < input->rankOf(), 0, "REDUCE_MEAN OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , input->rankOf(), input->rankOf(), item); - - input->reduceAlongDimension(reduce::Mean, *output, dimensions, keepDims); - - return Status::OK(); + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + auto dimensions = block.getIArguments(); + if (block.width() > 1) { + auto axesVector = INPUT_VARIABLE(1); + helpers::adjustAxis(input->rankOf(), axesVector, dimensions); + } + + bool keepDims = false; + if (block.numB()) + keepDims = B_ARG(0); + else if (block.numT()) + keepDims = (bool)T_ARG(0); + + REQUIRE_TRUE(dimensions.size() <= input->rankOf(), 0, + "REDUCE_MEAN OP: the number of dimensions to reduce along must " + "be <= input array rank, but got %i instead", + dimensions.size()); + + for (const auto& item : dimensions) + REQUIRE_TRUE(item >= -input->rankOf() && item < input->rankOf(), 0, + "REDUCE_MEAN OP: the input dimension to reduce along must be " + "in range [-%i, %i), but got %i instead !", + input->rankOf(), input->rankOf(), item); + + input->reduceAlongDimension(reduce::Mean, *output, dimensions, keepDims); + + return Status::OK(); } DECLARE_SHAPE_FN(reduce_mean) { - - auto dimensions = *block.getIArguments(); - auto in = inputShape->at(0); - if (block.width() > 1) { - auto axesVector = INPUT_VARIABLE(1); - helpers::adjustAxis(shape::rank(in), axesVector, dimensions); - } - - bool keepDims = false; - if (block.getBArguments()->size()) - keepDims = B_ARG(0); - else if (block.getTArguments()->size()) - keepDims = (bool)T_ARG(0); - - REQUIRE_TRUE(dimensions.size() <= in[0], 0, "REDUCE_MEAN OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); - - for(const auto& item : dimensions) - REQUIRE_TRUE(item >= -inputShape->at(0)[0] && item < inputShape->at(0)[0], 0, "REDUCE_MEAN OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , inputShape->at(0)[0], inputShape->at(0)[0], item); - - auto outShapeInfo = ShapeUtils::evalReduceShapeInfo(shape::order(in), dimensions, in, keepDims, false, block.getWorkspace()); - - return SHAPELIST(outShapeInfo); + auto dimensions = block.getIArguments(); + auto in = inputShape->at(0); + if (block.width() > 1) { + auto axesVector = INPUT_VARIABLE(1); + helpers::adjustAxis(shape::rank(in), axesVector, dimensions); + } + + bool keepDims = false; + if (block.numB()) + keepDims = B_ARG(0); + else if (block.numT()) + keepDims = (bool)T_ARG(0); + + REQUIRE_TRUE(dimensions.size() <= in[0], 0, + "REDUCE_MEAN OP: the number of dimensions to reduce along must " + "be <= input array rank, but got %i instead", + dimensions.size()); + + for (const auto& item : dimensions) + REQUIRE_TRUE(item >= -inputShape->at(0)[0] && item < inputShape->at(0)[0], + 0, + "REDUCE_MEAN OP: the input dimension to reduce along must be " + "in range [-%i, %i), but got %i instead !", + inputShape->at(0)[0], inputShape->at(0)[0], item); + + auto outShapeInfo = ShapeUtils::evalReduceShapeInfo( + shape::order(in), dimensions, in, keepDims, false, block.workspace()); + + return SHAPELIST(outShapeInfo); } DECLARE_TYPES(reduce_mean) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } - ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(reduce_mean_bp, 2, 1, false, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto gradO = INPUT_VARIABLE(1); - - auto gradI = OUTPUT_VARIABLE(0); - - auto dimensions = *block.getIArguments(); - if (block.width() > 2) { - auto axesVector = INPUT_VARIABLE(2); - helpers::adjustAxis(input->rankOf(), axesVector, dimensions); - } - - bool keepDims = false; - if (block.getBArguments()->size()) - keepDims = B_ARG(0); - else if (block.getTArguments()->size()) - keepDims = (bool)T_ARG(0); - - REQUIRE_TRUE(dimensions.size() <= input->rankOf(), 0, "REDUCE_MEAN_BP OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); - - for(const auto& item : dimensions) - REQUIRE_TRUE(item >= -input->rankOf() && item < input->rankOf(), 0, "REDUCE_MEAN_BP OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , input->rankOf(), input->rankOf(), item); - - if(gradO->lengthOf() == 1) { - gradI->assign(gradO->e(0) / input->lengthOf()); - } - else { - - gradI->assign((gradO->lengthOf() + 0.) / input->lengthOf()); - - if(!keepDims) { - auto gradOShapeKeepDims = ShapeUtils::evalReduceShapeInfo(gradO->ordering(), dimensions, *input, true, false, block.getWorkspace()); - *gradI *= gradO->reshape(gradO->ordering(), ShapeUtils::pullShapeFromShapeInfo(gradOShapeKeepDims)); // for example could be something like [a,b] -> [1,a,1,b] - } - else - *gradI *= *gradO; - - } - - return Status::OK(); + auto input = INPUT_VARIABLE(0); + auto gradO = INPUT_VARIABLE(1); + + auto gradI = OUTPUT_VARIABLE(0); + + auto dimensions = block.getIArguments(); + if (block.width() > 2) { + auto axesVector = INPUT_VARIABLE(2); + helpers::adjustAxis(input->rankOf(), axesVector, dimensions); + } + + bool keepDims = false; + if (block.numB()) + keepDims = B_ARG(0); + else if (block.numT()) + keepDims = (bool)T_ARG(0); + + REQUIRE_TRUE(dimensions.size() <= input->rankOf(), 0, + "REDUCE_MEAN_BP OP: the number of dimensions to reduce along " + "must be <= input array rank, but got %i instead", + dimensions.size()); + + for (const auto& item : dimensions) + REQUIRE_TRUE(item >= -input->rankOf() && item < input->rankOf(), 0, + "REDUCE_MEAN_BP OP: the input dimension to reduce along must " + "be in range [-%i, %i), but got %i instead !", + input->rankOf(), input->rankOf(), item); + + if (gradO->lengthOf() == 1) { + gradI->assign(gradO->e(0) / input->lengthOf()); + } else { + gradI->assign((gradO->lengthOf() + 0.) / input->lengthOf()); + + if (!keepDims) { + auto gradOShapeKeepDims = + ShapeUtils::evalReduceShapeInfo(gradO->ordering(), dimensions, *input, + true, false, block.workspace()); + *gradI *= gradO->reshape( + gradO->ordering(), + ShapeUtils::pullShapeFromShapeInfo( + gradOShapeKeepDims)); // for example could be something like + // [a,b] -> [1,a,1,b] + } else + *gradI *= *gradO; + } + + return Status::OK(); } DECLARE_SHAPE_FN(reduce_mean_bp) { - auto in = inputShape->at(0); - auto dimensions = *block.getIArguments(); - auto rank = shape::rank(in); - - if (block.width() > 2) { - auto axesVector = INPUT_VARIABLE(2); - helpers::adjustAxis(rank, axesVector, dimensions); - } - REQUIRE_TRUE(dimensions.size() <= rank, 0, "REDUCE_MEAN_BP OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); - - for(const auto& item : dimensions) - REQUIRE_TRUE(item >= -rank || item < rank, 0, "REDUCE_MEAN_BP OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , rank, rank, item); - - Nd4jLong* gradIshapeInfo(nullptr); - COPY_SHAPE(inputShape->at(0), gradIshapeInfo); - - return SHAPELIST(CONSTANT(gradIshapeInfo)); + auto in = inputShape->at(0); + auto dimensions = block.getIArguments(); + auto rank = shape::rank(in); + + if (block.width() > 2) { + auto axesVector = INPUT_VARIABLE(2); + helpers::adjustAxis(rank, axesVector, dimensions); + } + REQUIRE_TRUE(dimensions.size() <= rank, 0, + "REDUCE_MEAN_BP OP: the number of dimensions to reduce along " + "must be <= input array rank, but got %i instead", + dimensions.size()); + + for (const auto& item : dimensions) + REQUIRE_TRUE(item >= -rank || item < rank, 0, + "REDUCE_MEAN_BP OP: the input dimension to reduce along must " + "be in range [-%i, %i), but got %i instead !", + rank, rank, item); + + Nd4jLong* gradIshapeInfo(nullptr); + COPY_SHAPE(inputShape->at(0), gradIshapeInfo); + + return SHAPELIST(CONSTANT(gradIshapeInfo)); } - DECLARE_TYPES(reduce_mean_bp) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } - - -} -} +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/generic/reduce/reduceStDev.cpp b/libnd4j/include/ops/declarable/generic/reduce/reduceStDev.cpp index d101a6a79f12..cdc0698e35e2 100644 --- a/libnd4j/include/ops/declarable/generic/reduce/reduceStDev.cpp +++ b/libnd4j/include/ops/declarable/generic/reduce/reduceStDev.cpp @@ -18,162 +18,193 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 04.06.2018 // - #include #include -namespace sd { -namespace ops { +namespace sd { +namespace ops { ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(reduce_stdev, 1, 1, false, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - bool keepDims = false;//block.getTArguments()->size() > 0 ? (bool)T_ARG(0) : false; - bool biasCorrected = false;//block.getTArguments()->size() > 1 ? (bool)T_ARG(1) : false; - - auto dimensions = *block.getIArguments(); - if (block.width() > 1) { - auto axesVector = INPUT_VARIABLE(1); - helpers::adjustAxis(input->rankOf(), axesVector, dimensions); - } - - if (block.getBArguments()->size()) { - keepDims = B_ARG(0); - if (block.getBArguments()->size() > 1) - biasCorrected = B_ARG(1); - } - else if (block.getTArguments()->size()) { - keepDims = (bool)T_ARG(0); - if (block.getTArguments()->size() > 1) - biasCorrected = (bool)T_ARG(1); - } - - REQUIRE_TRUE(dimensions.size() <= input->rankOf(), 0, "REDUCE_STDEV OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); - - for(const auto& item : dimensions) - REQUIRE_TRUE(item >= -input->rankOf() && item < input->rankOf(), 0, "REDUCE_STDEV OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , input->rankOf(), input->rankOf(), item); - - input->varianceAlongDimension(variance::SummaryStatsStandardDeviation, *output, biasCorrected, dimensions); - - return Status::OK(); + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + bool keepDims = + false; // block.getTArguments()->size() > 0 ? (bool)T_ARG(0) : false; + bool biasCorrected = + false; // block.getTArguments()->size() > 1 ? (bool)T_ARG(1) : false; + + auto dimensions = block.getIArguments(); + if (block.width() > 1) { + auto axesVector = INPUT_VARIABLE(1); + helpers::adjustAxis(input->rankOf(), axesVector, dimensions); + } + + if (block.numB()) { + keepDims = B_ARG(0); + if (block.numB() > 1) biasCorrected = B_ARG(1); + } else if (block.numT()) { + keepDims = (bool)T_ARG(0); + if (block.numT() > 1) biasCorrected = (bool)T_ARG(1); + } + + REQUIRE_TRUE(dimensions.size() <= input->rankOf(), 0, + "REDUCE_STDEV OP: the number of dimensions to reduce along must " + "be <= input array rank, but got %i instead", + dimensions.size()); + + for (const auto& item : dimensions) + REQUIRE_TRUE(item >= -input->rankOf() && item < input->rankOf(), 0, + "REDUCE_STDEV OP: the input dimension to reduce along must be " + "in range [-%i, %i), but got %i instead !", + input->rankOf(), input->rankOf(), item); + + input->varianceAlongDimension(variance::SummaryStatsStandardDeviation, + *output, biasCorrected, dimensions); + + return Status::OK(); } DECLARE_SHAPE_FN(reduce_stdev) { - auto in = inputShape->at(0); - auto rank = shape::rank(in); - bool keepDims = false;//block.getTArguments()->size() > 0 ? (bool)T_ARG(0) : false; - auto dimensions = *block.getIArguments(); - - if (block.width() > 1) { - auto axesVector = INPUT_VARIABLE(1); - helpers::adjustAxis(rank, axesVector, dimensions); - } - - if (block.getBArguments()->size()) { - keepDims = B_ARG(0); - } - else if (block.getTArguments()->size()) { - keepDims = (bool)T_ARG(0); - } - - REQUIRE_TRUE(dimensions.size() <= rank, 0, "REDUCE_STDEV OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); - - for(const auto& item : dimensions) - REQUIRE_TRUE(item >= -inputShape->at(0)[0] && item < inputShape->at(0)[0], 0, "REDUCE_STDEV OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , inputShape->at(0)[0], inputShape->at(0)[0], item); - - auto outShapeInfo = ShapeUtils::evalReduceShapeInfo(shape::order(in), dimensions, in, keepDims, false, block.getWorkspace()); - - return SHAPELIST(outShapeInfo); + auto in = inputShape->at(0); + auto rank = shape::rank(in); + bool keepDims = + false; // block.getTArguments()->size() > 0 ? (bool)T_ARG(0) : false; + auto dimensions = block.getIArguments(); + + if (block.width() > 1) { + auto axesVector = INPUT_VARIABLE(1); + helpers::adjustAxis(rank, axesVector, dimensions); + } + + if (block.numB()) { + keepDims = B_ARG(0); + } else if (block.numT()) { + keepDims = (bool)T_ARG(0); + } + + REQUIRE_TRUE(dimensions.size() <= rank, 0, + "REDUCE_STDEV OP: the number of dimensions to reduce along must " + "be <= input array rank, but got %i instead", + dimensions.size()); + + for (const auto& item : dimensions) + REQUIRE_TRUE(item >= -inputShape->at(0)[0] && item < inputShape->at(0)[0], + 0, + "REDUCE_STDEV OP: the input dimension to reduce along must be " + "in range [-%i, %i), but got %i instead !", + inputShape->at(0)[0], inputShape->at(0)[0], item); + + auto outShapeInfo = ShapeUtils::evalReduceShapeInfo( + shape::order(in), dimensions, in, keepDims, false, block.workspace()); + + return SHAPELIST(outShapeInfo); } DECLARE_TYPES(reduce_stdev) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } - ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(reduce_stdev_bp, 2, 1, false, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto gradO = INPUT_VARIABLE(1); - - auto gradI = OUTPUT_VARIABLE(0); - - bool keepDims = false;//block.getTArguments()->size() > 0 ? (bool)T_ARG(0) : false; - bool biasCorrected = false;//block.getTArguments()->size() > 1 ? (bool)T_ARG(1) : false; - - auto dimensions = *block.getIArguments(); - if (block.width() > 2) { - auto axesVector = INPUT_VARIABLE(2); - helpers::adjustAxis(input->rankOf(), axesVector, dimensions); - } - - if (block.getBArguments()->size()) { - keepDims = B_ARG(0); - if (block.getBArguments()->size() > 1) - biasCorrected = B_ARG(1); - } - else if (block.getTArguments()->size()) { - keepDims = (bool)T_ARG(0); - if (block.getTArguments()->size() > 1) - biasCorrected = (bool)T_ARG(1); - } - - REQUIRE_TRUE(dimensions.size() <= input->rankOf(), 0, "REDUCE_STDEV_BP OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); - - for(const auto& item : dimensions) - REQUIRE_TRUE(item >= -input->rankOf() && item < input->rankOf(), 0, "REDUCE_STDEV_BP OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , input->rankOf(), input->rankOf(), item); - - const Nd4jLong N = input->lengthOf() / gradO->lengthOf(); - const Nd4jLong NminusOne = biasCorrected ? N - 1 : N; - - auto mean = input->reduceAlongDimension(reduce::Mean, dimensions, true); - - NDArray variance(mean.shapeInfo(), true, block.launchContext()); // create empty array with shape matching shape of mean array - input->varianceAlongDimension(variance::SummaryStatsStandardDeviation, variance, biasCorrected, dimensions); - - gradI->assign( (*input - mean) / (variance * NminusOne)); // automatic broadcasting happens here - - if(!keepDims) { - auto gradOShapeKeepDims = ShapeUtils::evalReduceShapeInfo(gradO->ordering(), dimensions, *input, true, false, block.getWorkspace()); - *gradI *= gradO->reshape(gradO->ordering(), ShapeUtils::pullShapeFromShapeInfo(gradOShapeKeepDims)); // for example could be something like [a,b] -> [1,a,1,b] - } - else - *gradI *= *gradO; // automatic broadcasting happens here - - return Status::OK(); + auto input = INPUT_VARIABLE(0); + auto gradO = INPUT_VARIABLE(1); + + auto gradI = OUTPUT_VARIABLE(0); + + bool keepDims = + false; // block.getTArguments()->size() > 0 ? (bool)T_ARG(0) : false; + bool biasCorrected = + false; // block.getTArguments()->size() > 1 ? (bool)T_ARG(1) : false; + + auto dimensions = block.getIArguments(); + if (block.width() > 2) { + auto axesVector = INPUT_VARIABLE(2); + helpers::adjustAxis(input->rankOf(), axesVector, dimensions); + } + + if (block.numB()) { + keepDims = B_ARG(0); + if (block.numB() > 1) biasCorrected = B_ARG(1); + } else if (block.numT()) { + keepDims = (bool)T_ARG(0); + if (block.numT() > 1) biasCorrected = (bool)T_ARG(1); + } + + REQUIRE_TRUE(dimensions.size() <= input->rankOf(), 0, + "REDUCE_STDEV_BP OP: the number of dimensions to reduce along " + "must be <= input array rank, but got %i instead", + dimensions.size()); + + for (const auto& item : dimensions) + REQUIRE_TRUE(item >= -input->rankOf() && item < input->rankOf(), 0, + "REDUCE_STDEV_BP OP: the input dimension to reduce along must " + "be in range [-%i, %i), but got %i instead !", + input->rankOf(), input->rankOf(), item); + + const Nd4jLong N = input->lengthOf() / gradO->lengthOf(); + const Nd4jLong NminusOne = biasCorrected ? N - 1 : N; + + auto mean = input->reduceAlongDimension(reduce::Mean, dimensions, true); + + NDArray variance(mean.shapeInfo(), true, + block.launchContext()); // create empty array with shape + // matching shape of mean array + input->varianceAlongDimension(variance::SummaryStatsStandardDeviation, + variance, biasCorrected, dimensions); + + gradI->assign((*input - mean) / + (variance * NminusOne)); // automatic broadcasting happens here + + if (!keepDims) { + auto gradOShapeKeepDims = ShapeUtils::evalReduceShapeInfo( + gradO->ordering(), dimensions, *input, true, false, block.workspace()); + *gradI *= gradO->reshape( + gradO->ordering(), + ShapeUtils::pullShapeFromShapeInfo( + gradOShapeKeepDims)); // for example could be something like [a,b] + // -> [1,a,1,b] + } else + *gradI *= *gradO; // automatic broadcasting happens here + + return Status::OK(); } DECLARE_SHAPE_FN(reduce_stdev_bp) { - auto in = inputShape->at(0); - auto rank = shape::rank(in); - auto dimensions = *block.getIArguments(); - if (block.width() > 2) { - auto axesVector = INPUT_VARIABLE(2); - helpers::adjustAxis(rank, axesVector, dimensions); - } - - REQUIRE_TRUE(dimensions.size() <= rank, 0, "REDUCE_STDEV_BP OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); - - for(const auto& item : dimensions) - REQUIRE_TRUE(item >= -inputShape->at(0)[0] && item < inputShape->at(0)[0], 0, "REDUCE_STDEV_BP OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , inputShape->at(0)[0], inputShape->at(0)[0], item); - - Nd4jLong* gradIshapeInfo(nullptr); - COPY_SHAPE(in, gradIshapeInfo); - - return SHAPELIST(CONSTANT(gradIshapeInfo)); -// return SHAPELIST(in); + auto in = inputShape->at(0); + auto rank = shape::rank(in); + auto dimensions = block.getIArguments(); + if (block.width() > 2) { + auto axesVector = INPUT_VARIABLE(2); + helpers::adjustAxis(rank, axesVector, dimensions); + } + + REQUIRE_TRUE(dimensions.size() <= rank, 0, + "REDUCE_STDEV_BP OP: the number of dimensions to reduce along " + "must be <= input array rank, but got %i instead", + dimensions.size()); + + for (const auto& item : dimensions) + REQUIRE_TRUE(item >= -inputShape->at(0)[0] && item < inputShape->at(0)[0], + 0, + "REDUCE_STDEV_BP OP: the input dimension to reduce along must " + "be in range [-%i, %i), but got %i instead !", + inputShape->at(0)[0], inputShape->at(0)[0], item); + + Nd4jLong* gradIshapeInfo(nullptr); + COPY_SHAPE(in, gradIshapeInfo); + + return SHAPELIST(CONSTANT(gradIshapeInfo)); + // return SHAPELIST(in); } DECLARE_TYPES(reduce_stdev_bp) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } - -} -} +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/generic/reduce/reduceVariance.cpp b/libnd4j/include/ops/declarable/generic/reduce/reduceVariance.cpp index cd7441304f2c..29916c8f9950 100644 --- a/libnd4j/include/ops/declarable/generic/reduce/reduceVariance.cpp +++ b/libnd4j/include/ops/declarable/generic/reduce/reduceVariance.cpp @@ -18,157 +18,186 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 04.06.2018 // - #include #include -namespace sd { -namespace ops { +namespace sd { +namespace ops { ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(reduce_variance, 1, 1, false, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - bool keepDims = false;//block.getTArguments()->size() > 0 ? (bool)T_ARG(0) : false; - bool biasCorrected = false;//block.getTArguments()->size() > 1 ? (bool)T_ARG(1) : false; - - auto dimensions = *block.getIArguments(); - if (block.width() > 1) { - auto axesVector = INPUT_VARIABLE(1); - helpers::adjustAxis(input->rankOf(), axesVector, dimensions); - } - - if (block.getBArguments()->size()) { - keepDims = B_ARG(0); - if (block.getBArguments()->size() > 1) - biasCorrected = B_ARG(1); - } - else if (block.getTArguments()->size()) { - keepDims = (bool)T_ARG(0); - if (block.getTArguments()->size() > 1) - biasCorrected = (bool)T_ARG(1); - } - - REQUIRE_TRUE(dimensions.size() <= input->rankOf(), 0, "REDUCE_VARIANCE OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); - - for(const auto& item : dimensions) - REQUIRE_TRUE(item >= -input->rankOf() && item < input->rankOf(), 0, "REDUCE_VARIANCE OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , input->rankOf(), input->rankOf(), item); - - input->varianceAlongDimension(variance::SummaryStatsVariance, *output, biasCorrected, dimensions); - - return Status::OK(); + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + bool keepDims = + false; // block.getTArguments()->size() > 0 ? (bool)T_ARG(0) : false; + bool biasCorrected = + false; // block.getTArguments()->size() > 1 ? (bool)T_ARG(1) : false; + + auto dimensions = block.getIArguments(); + if (block.width() > 1) { + auto axesVector = INPUT_VARIABLE(1); + helpers::adjustAxis(input->rankOf(), axesVector, dimensions); + } + + if (block.numB()) { + keepDims = B_ARG(0); + if (block.numB() > 1) biasCorrected = B_ARG(1); + } else if (block.numT()) { + keepDims = (bool)T_ARG(0); + if (block.numT() > 1) biasCorrected = (bool)T_ARG(1); + } + + REQUIRE_TRUE(dimensions.size() <= input->rankOf(), 0, + "REDUCE_VARIANCE OP: the number of dimensions to reduce along " + "must be <= input array rank, but got %i instead", + dimensions.size()); + + for (const auto& item : dimensions) + REQUIRE_TRUE(item >= -input->rankOf() && item < input->rankOf(), 0, + "REDUCE_VARIANCE OP: the input dimension to reduce along must " + "be in range [-%i, %i), but got %i instead !", + input->rankOf(), input->rankOf(), item); + + input->varianceAlongDimension(variance::SummaryStatsVariance, *output, + biasCorrected, dimensions); + + return Status::OK(); } DECLARE_SHAPE_FN(reduce_variance) { - - bool keepDims = false;//block.getTArguments()->size() > 0 ? (bool)T_ARG(0) : false; - auto dimensions = *block.getIArguments(); - if (block.width() > 1) { - auto axesVector = INPUT_VARIABLE(1); - helpers::adjustAxis(INPUT_VARIABLE(0)->rankOf(), axesVector, dimensions); - } - - if (block.getBArguments()->size()) { - keepDims = B_ARG(0); - } - else if (block.getTArguments()->size()) { - keepDims = (bool)T_ARG(0); - } - - REQUIRE_TRUE(dimensions.size() <= INPUT_VARIABLE(0)->rankOf(), 0, "REDUCE_VARIANCE OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); - - for(const auto& item : dimensions) - REQUIRE_TRUE(item >= -inputShape->at(0)[0] && item < inputShape->at(0)[0], 0, "REDUCE_VARIANCE OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , inputShape->at(0)[0], inputShape->at(0)[0], item); - - auto outShapeInfo = ShapeUtils::evalReduceShapeInfo(shape::order(inputShape->at(0)), dimensions, inputShape->at(0), keepDims, false, block.getWorkspace()); - - return SHAPELIST(outShapeInfo); + bool keepDims = + false; // block.getTArguments()->size() > 0 ? (bool)T_ARG(0) : false; + auto dimensions = block.getIArguments(); + if (block.width() > 1) { + auto axesVector = INPUT_VARIABLE(1); + helpers::adjustAxis(INPUT_VARIABLE(0)->rankOf(), axesVector, dimensions); + } + + if (block.numB()) { + keepDims = B_ARG(0); + } else if (block.numT()) { + keepDims = (bool)T_ARG(0); + } + + REQUIRE_TRUE(dimensions.size() <= INPUT_VARIABLE(0)->rankOf(), 0, + "REDUCE_VARIANCE OP: the number of dimensions to reduce along " + "must be <= input array rank, but got %i instead", + dimensions.size()); + + for (const auto& item : dimensions) + REQUIRE_TRUE(item >= -inputShape->at(0)[0] && item < inputShape->at(0)[0], + 0, + "REDUCE_VARIANCE OP: the input dimension to reduce along must " + "be in range [-%i, %i), but got %i instead !", + inputShape->at(0)[0], inputShape->at(0)[0], item); + + auto outShapeInfo = ShapeUtils::evalReduceShapeInfo( + shape::order(inputShape->at(0)), dimensions, inputShape->at(0), keepDims, + false, block.workspace()); + + return SHAPELIST(outShapeInfo); } DECLARE_TYPES(reduce_variance) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(reduce_variance_bp, 2, 1, false, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto gradO = INPUT_VARIABLE(1); - - auto gradI = OUTPUT_VARIABLE(0); - - bool keepDims = false;//block.getTArguments()->size() > 0 ? (bool)T_ARG(0) : false; - bool biasCorrected = false;//block.getTArguments()->size() > 1 ? (bool)T_ARG(1) : false; - - auto dimensions = *block.getIArguments(); - if (block.width() > 2) { - auto axesVector = INPUT_VARIABLE(2); - helpers::adjustAxis(input->rankOf(), axesVector, dimensions); - } -// else if (block.getIArguments()->size()) - if (block.getBArguments()->size()) { - keepDims = B_ARG(0); - if (block.getBArguments()->size() > 1) - biasCorrected = B_ARG(1); - } - else if (block.getTArguments()->size()) { - keepDims = (bool)T_ARG(0); - if (block.getTArguments()->size() > 1) - biasCorrected = (bool)T_ARG(1); - } - - REQUIRE_TRUE(dimensions.size() <= input->rankOf(), 0, "REDUCE_VARIANCE_BP OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); - - for(const auto& item : dimensions) - REQUIRE_TRUE(item >= -input->rankOf() && item < input->rankOf(), 0, "REDUCE_VARIANCE_BP OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , input->rankOf(), input->rankOf(), item); - - const Nd4jLong N = input->lengthOf() / gradO->lengthOf(); - const Nd4jLong NminusOne = biasCorrected ? N - 1 : N; - const double factor1 = 2.0 / NminusOne; - const double factor2 = 2.0 / (N * NminusOne); - - auto mean = input->reduceAlongDimension(reduce::Mean, dimensions, true); - - gradI->assign( (*input - mean) * (2.0f / NminusOne)); // automatic broadcasting happens here - - if(!keepDims) { - auto gradOShapeKeepDims = ShapeUtils::evalReduceShapeInfo(gradO->ordering(), dimensions, *input, true, false, block.getWorkspace()); - *gradI *= gradO->reshape(gradO->ordering(), ShapeUtils::pullShapeFromShapeInfo(gradOShapeKeepDims)); // for example could be something like [a,b] -> [1,a,1,b] - } else - *gradI *= *gradO; // automatic broadcasting happens here - - return Status::OK(); + auto input = INPUT_VARIABLE(0); + auto gradO = INPUT_VARIABLE(1); + + auto gradI = OUTPUT_VARIABLE(0); + + bool keepDims = + false; // block.getTArguments()->size() > 0 ? (bool)T_ARG(0) : false; + bool biasCorrected = + false; // block.getTArguments()->size() > 1 ? (bool)T_ARG(1) : false; + + auto dimensions = block.getIArguments(); + if (block.width() > 2) { + auto axesVector = INPUT_VARIABLE(2); + helpers::adjustAxis(input->rankOf(), axesVector, dimensions); + } + // else if (block.getIArguments()->size()) + if (block.numB()) { + keepDims = B_ARG(0); + if (block.numB() > 1) biasCorrected = B_ARG(1); + } else if (block.numT()) { + keepDims = (bool)T_ARG(0); + if (block.numT() > 1) biasCorrected = (bool)T_ARG(1); + } + + REQUIRE_TRUE(dimensions.size() <= input->rankOf(), 0, + "REDUCE_VARIANCE_BP OP: the number of dimensions to reduce " + "along must be <= input array rank, but got %i instead", + dimensions.size()); + + for (const auto& item : dimensions) + REQUIRE_TRUE(item >= -input->rankOf() && item < input->rankOf(), 0, + "REDUCE_VARIANCE_BP OP: the input dimension to reduce along " + "must be in range [-%i, %i), but got %i instead !", + input->rankOf(), input->rankOf(), item); + + const Nd4jLong N = input->lengthOf() / gradO->lengthOf(); + const Nd4jLong NminusOne = biasCorrected ? N - 1 : N; + const double factor1 = 2.0 / NminusOne; + const double factor2 = 2.0 / (N * NminusOne); + + auto mean = input->reduceAlongDimension(reduce::Mean, dimensions, true); + + gradI->assign((*input - mean) * + (2.0f / NminusOne)); // automatic broadcasting happens here + + if (!keepDims) { + auto gradOShapeKeepDims = ShapeUtils::evalReduceShapeInfo( + gradO->ordering(), dimensions, *input, true, false, block.workspace()); + *gradI *= gradO->reshape( + gradO->ordering(), + ShapeUtils::pullShapeFromShapeInfo( + gradOShapeKeepDims)); // for example could be something like [a,b] + // -> [1,a,1,b] + } else + *gradI *= *gradO; // automatic broadcasting happens here + + return Status::OK(); } - DECLARE_SHAPE_FN(reduce_variance_bp) { - auto in = inputShape->at(0); - auto rank = shape::rank(in); - auto dimensions = *block.getIArguments(); - if (block.width() > 2) { - auto axesVector = INPUT_VARIABLE(2); - helpers::adjustAxis(rank, axesVector, dimensions); - } - - REQUIRE_TRUE(dimensions.size() <= rank, 0, "REDUCE_VARIANCE_BP OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); - - for(const auto& item : dimensions) - REQUIRE_TRUE(item >= -inputShape->at(0)[0] && item < inputShape->at(0)[0], 0, "REDUCE_VARIANCE_BP OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , inputShape->at(0)[0], inputShape->at(0)[0], item); - - Nd4jLong* gradIshapeInfo(nullptr); - COPY_SHAPE(in, gradIshapeInfo); - - return SHAPELIST(CONSTANT(gradIshapeInfo)); + auto in = inputShape->at(0); + auto rank = shape::rank(in); + auto dimensions = block.getIArguments(); + if (block.width() > 2) { + auto axesVector = INPUT_VARIABLE(2); + helpers::adjustAxis(rank, axesVector, dimensions); + } + + REQUIRE_TRUE(dimensions.size() <= rank, 0, + "REDUCE_VARIANCE_BP OP: the number of dimensions to reduce " + "along must be <= input array rank, but got %i instead", + dimensions.size()); + + for (const auto& item : dimensions) + REQUIRE_TRUE(item >= -inputShape->at(0)[0] && item < inputShape->at(0)[0], + 0, + "REDUCE_VARIANCE_BP OP: the input dimension to reduce along " + "must be in range [-%i, %i), but got %i instead !", + inputShape->at(0)[0], inputShape->at(0)[0], item); + + Nd4jLong* gradIshapeInfo(nullptr); + COPY_SHAPE(in, gradIshapeInfo); + + return SHAPELIST(CONSTANT(gradIshapeInfo)); } - DECLARE_TYPES(reduce_variance_bp) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } -} -} +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/generic/reduce/reduce_dot.cpp b/libnd4j/include/ops/declarable/generic/reduce/reduce_dot.cpp index 75cb40ca27bf..86b61a254771 100644 --- a/libnd4j/include/ops/declarable/generic/reduce/reduce_dot.cpp +++ b/libnd4j/include/ops/declarable/generic/reduce/reduce_dot.cpp @@ -19,8 +19,8 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include #include +#include namespace sd { namespace ops { @@ -28,97 +28,112 @@ namespace ops { //////////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(reduce_dot_bp, 3, 2, false, 0, 0) { - - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto gradO = INPUT_VARIABLE(2); - - auto gradX = OUTPUT_VARIABLE(0); - auto gradY = OUTPUT_VARIABLE(1); - - // L(x,y) = SUM(x_i * y_i) - // dL/dx_i = y_i - - REQUIRE_TRUE(x->isSameShape(y), 0, "REDUCE_DOT_BP OP: both input arrays x and y should have same shapes, but got %s and %s correspondingly", ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str()); - - if (gradO->lengthOf() == 1) { // scalar of reduced to scalar with keep dimensions - gradX->assign((*y) * (*gradO)); - gradY->assign((*x) * (*gradO)); + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto gradO = INPUT_VARIABLE(2); + + auto gradX = OUTPUT_VARIABLE(0); + auto gradY = OUTPUT_VARIABLE(1); + + // L(x,y) = SUM(x_i * y_i) + // dL/dx_i = y_i + + REQUIRE_TRUE(x->isSameShape(y), 0, + "REDUCE_DOT_BP OP: both input arrays x and y should have same " + "shapes, but got %s and %s correspondingly", + ShapeUtils::shapeAsString(x).c_str(), + ShapeUtils::shapeAsString(y).c_str()); + + if (gradO->lengthOf() == + 1) { // scalar of reduced to scalar with keep dimensions + gradX->assign((*y) * (*gradO)); + gradY->assign((*x) * (*gradO)); + } else { + bool keepDims = false; + auto dimensions = block.getIArguments(); + + if (block.width() > 3) { + auto axesVector = INPUT_VARIABLE(3); + helpers::adjustAxis(x->rankOf(), axesVector, dimensions); } - else { - - bool keepDims = false; - auto dimensions = *block.getIArguments(); - - if (block.width() > 3) { - auto axesVector = INPUT_VARIABLE(3); - helpers::adjustAxis(x->rankOf(), axesVector, dimensions); - } - - if (block.getBArguments()->size()) - keepDims = B_ARG(0); - else if (block.getTArguments()->size()) - keepDims = (bool)T_ARG(0); - - REQUIRE_TRUE(dimensions.size() <= x->rankOf(), 0, "REDUCE_DOT_BP OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); - - for(const auto& item : dimensions) - REQUIRE_TRUE(item >= -x->rankOf() && item < x->rankOf(), 0, "REDUCE_DOT_BP OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , x->rankOf(), x->rankOf(), item); - - if(!keepDims) { - auto gradOShapeKeepDims = ShapeUtils::evalReduceShapeInfo(gradO->ordering(), dimensions, *x, true, false, block.getWorkspace()); - auto r = gradO->reshape(gradO->ordering(), ShapeUtils::pullShapeFromShapeInfo(gradOShapeKeepDims)); // for example could be something like [a,b] -> [1,a,1,b] - - gradX->assign((*y) * r); - gradY->assign((*x) * r); - } - else { - gradX->assign((*y) * (*gradO)); - gradY->assign((*x) * (*gradO)); - } + if (block.numB()) + keepDims = B_ARG(0); + else if (block.numT()) + keepDims = (bool)T_ARG(0); + + REQUIRE_TRUE(dimensions.size() <= x->rankOf(), 0, + "REDUCE_DOT_BP OP: the number of dimensions to reduce along " + "must be <= input array rank, but got %i instead", + dimensions.size()); + + for (const auto& item : dimensions) + REQUIRE_TRUE(item >= -x->rankOf() && item < x->rankOf(), 0, + "REDUCE_DOT_BP OP: the input dimension to reduce along must " + "be in range [-%i, %i), but got %i instead !", + x->rankOf(), x->rankOf(), item); + + if (!keepDims) { + auto gradOShapeKeepDims = ShapeUtils::evalReduceShapeInfo( + gradO->ordering(), dimensions, *x, true, false, block.workspace()); + auto r = gradO->reshape( + gradO->ordering(), + ShapeUtils::pullShapeFromShapeInfo( + gradOShapeKeepDims)); // for example could be something like + // [a,b] -> [1,a,1,b] + + gradX->assign((*y) * r); + gradY->assign((*x) * r); + } else { + gradX->assign((*y) * (*gradO)); + gradY->assign((*x) * (*gradO)); } - return Status::OK(); + } + return Status::OK(); } - DECLARE_SHAPE_FN(reduce_dot_bp) { + if (shape::length(inputShape->at(2)) > 1) { + bool keepDims = false; + auto dimensions = block.getIArguments(); - if(shape::length(inputShape->at(2)) > 1) { - - bool keepDims = false; - auto dimensions = *block.getIArguments(); - - if (block.width() > 3) { - auto axesVector = INPUT_VARIABLE(3); - helpers::adjustAxis(INPUT_VARIABLE(0)->rankOf(), axesVector, dimensions); - } - - if (block.getBArguments()->size()) - keepDims = B_ARG(0); - else if (block.getTArguments()->size()) - keepDims = (bool)T_ARG(0); - - REQUIRE_TRUE(dimensions.size() <= inputShape->at(0)[0], 0, "REDUCE_DOT_BP OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); - - for(const auto& item : dimensions) - REQUIRE_TRUE(item >= -inputShape->at(0)[0] && item < inputShape->at(0)[0], 0, "REDUCE_DOT_BP OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , inputShape->at(0)[0], inputShape->at(0)[0], item); + if (block.width() > 3) { + auto axesVector = INPUT_VARIABLE(3); + helpers::adjustAxis(INPUT_VARIABLE(0)->rankOf(), axesVector, dimensions); } - Nd4jLong *outShapeInfo1, *outShapeInfo2; - COPY_SHAPE(inputShape->at(0), outShapeInfo1); - COPY_SHAPE(inputShape->at(1), outShapeInfo2); - - return SHAPELIST(CONSTANT(outShapeInfo1), CONSTANT(outShapeInfo2)); + if (block.numB()) + keepDims = B_ARG(0); + else if (block.numT()) + keepDims = (bool)T_ARG(0); + + REQUIRE_TRUE(dimensions.size() <= inputShape->at(0)[0], 0, + "REDUCE_DOT_BP OP: the number of dimensions to reduce along " + "must be <= input array rank, but got %i instead", + dimensions.size()); + + for (const auto& item : dimensions) + REQUIRE_TRUE(item >= -inputShape->at(0)[0] && item < inputShape->at(0)[0], + 0, + "REDUCE_DOT_BP OP: the input dimension to reduce along must " + "be in range [-%i, %i), but got %i instead !", + inputShape->at(0)[0], inputShape->at(0)[0], item); + } + + Nd4jLong *outShapeInfo1, *outShapeInfo2; + COPY_SHAPE(inputShape->at(0), outShapeInfo1); + COPY_SHAPE(inputShape->at(1), outShapeInfo2); + + return SHAPELIST(CONSTANT(outShapeInfo1), CONSTANT(outShapeInfo2)); } DECLARE_TYPES(reduce_dot_bp) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } #endif -} -} +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/generic/reduce/reduce_logsumexp.cpp b/libnd4j/include/ops/declarable/generic/reduce/reduce_logsumexp.cpp index 556ad2a7c5ae..a51cca63d9c6 100644 --- a/libnd4j/include/ops/declarable/generic/reduce/reduce_logsumexp.cpp +++ b/libnd4j/include/ops/declarable/generic/reduce/reduce_logsumexp.cpp @@ -24,56 +24,60 @@ namespace sd { namespace ops { #if NOT_EXCLUDED(OP_reduce_logsumexp) - CUSTOM_OP_IMPL(reduce_logsumexp, 1, 1, false, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - std::vector axes;// = *block.getIArguments(); - if (block.width() > 1) { - auto axisVector = INPUT_VARIABLE(1); - helpers::adjustAxis(input->rankOf(), axisVector, axes ); - } - else if (block.getIArguments()->size() > 0) { - axes = *block.getIArguments(); - } +CUSTOM_OP_IMPL(reduce_logsumexp, 1, 1, false, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + std::vector axes; // = *block.getIArguments(); + if (block.width() > 1) { + auto axisVector = INPUT_VARIABLE(1); + helpers::adjustAxis(input->rankOf(), axisVector, axes); + } else if (block.numI() > 0) { + axes = block.getIArguments(); + } - for(const auto& item : axes) - REQUIRE_TRUE(item >= -input->shapeInfo()[0] && item shapeInfo()[0], 0, "REDUCE_LOGSUMEXP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , input->rankOf(), input->rankOf(), item); + for (const auto& item : axes) + REQUIRE_TRUE(item >= -input->shapeInfo()[0] && item < input->shapeInfo()[0], + 0, + "REDUCE_LOGSUMEXP: the input dimension to reduce along must " + "be in range [-%i, %i), but got %i instead !", + input->rankOf(), input->rankOf(), item); - const bool keepDims = block.getTArguments()->size() > 0 ? (bool)T_ARG(0) : false; - Nd4jLong maxI = input->argMax(); - auto maxVals = input->e(maxI); - //void* whereMax = (void*)(); - auto internal = (*input); - internal -= maxVals; - internal.applyTransform(transform::Exp, internal); - internal.reduceAlongDimension(reduce::Sum, *output, axes, keepDims, false); //, (void*)&maxVals); - output->applyTransform(transform::Log, *output); - (*output) += maxVals; - return ND4J_STATUS_OK; - } - DECLARE_TYPES(reduce_logsumexp) { - getOpDescriptor() - -> setAllowedInputTypes({ALL_INTS, ALL_FLOATS}) - -> setAllowedOutputTypes({ALL_FLOATS}); - } - DECLARE_SHAPE_FN(reduce_logsumexp) { - - const bool keepDims = block.getTArguments()->size() > 0 ? (bool)T_ARG(0) : false; - auto input = INPUT_VARIABLE(0); + const bool keepDims = block.numT() > 0 ? (bool)T_ARG(0) : false; + Nd4jLong maxI = input->argMax(); + auto maxVals = input->e(maxI); + // void* whereMax = (void*)(); + auto internal = (*input); + internal -= maxVals; + internal.applyTransform(transform::Exp, internal); + internal.reduceAlongDimension(reduce::Sum, *output, axes, keepDims, + false); //, (void*)&maxVals); + output->applyTransform(transform::Log, *output); + (*output) += maxVals; + return ND4J_STATUS_OK; +} +DECLARE_TYPES(reduce_logsumexp) { + getOpDescriptor() + ->setAllowedInputTypes({ALL_INTS, ALL_FLOATS}) + ->setAllowedOutputTypes({ALL_FLOATS}); +} +DECLARE_SHAPE_FN(reduce_logsumexp) { + const bool keepDims = block.numT() > 0 ? (bool)T_ARG(0) : false; + auto input = INPUT_VARIABLE(0); - std::vector axes; // = *block.getIArguments(); - if (block.width() > 1) { - auto axisVector = INPUT_VARIABLE(1); - helpers::adjustAxis(input->rankOf(), axisVector, axes ); - } - else if (block.getIArguments()->size() > 0) { - axes = *block.getIArguments(); - } + std::vector axes; // = *block.getIArguments(); + if (block.width() > 1) { + auto axisVector = INPUT_VARIABLE(1); + helpers::adjustAxis(input->rankOf(), axisVector, axes); + } else if (block.numI() > 0) { + axes = block.getIArguments(); + } - auto outShapeInfo = ShapeUtils::evalReduceShapeInfo(shape::order(inputShape->at(0)), axes, inputShape->at(0), keepDims, false, block.getWorkspace()); + auto outShapeInfo = ShapeUtils::evalReduceShapeInfo( + shape::order(inputShape->at(0)), axes, inputShape->at(0), keepDims, false, + block.workspace()); - return SHAPELIST(outShapeInfo); - } -#endif -} + return SHAPELIST(outShapeInfo); } +#endif +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/generic/reduce/reduce_max.cpp b/libnd4j/include/ops/declarable/generic/reduce/reduce_max.cpp index bea1e7eccc02..f7416ff91f6d 100644 --- a/libnd4j/include/ops/declarable/generic/reduce/reduce_max.cpp +++ b/libnd4j/include/ops/declarable/generic/reduce/reduce_max.cpp @@ -20,8 +20,8 @@ // #include -#include #include +#include namespace sd { namespace ops { @@ -30,62 +30,74 @@ namespace ops { ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(reduce_max, 1, 1, false, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); + std::vector dimensions = block.getIArguments(); - std::vector dimensions = *block.getIArguments(); + if (block.width() > 1) { + auto axesVector = INPUT_VARIABLE(1); + helpers::adjustAxis(input->rankOf(), axesVector, dimensions); + } - if (block.width() > 1) { - auto axesVector = INPUT_VARIABLE(1); - helpers::adjustAxis(input->rankOf(), axesVector, dimensions); - } + REQUIRE_TRUE(dimensions.size() <= input->rankOf(), 0, + "REDUCE_MAX OP: the number of dimensions to reduce along must " + "be <= input array rank, but got %i instead", + dimensions.size()); - REQUIRE_TRUE(dimensions.size() <= input->rankOf(), 0, "REDUCE_MAX OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); + for (const auto& item : dimensions) + REQUIRE_TRUE(item >= -input->shapeInfo()[0] && item < input->shapeInfo()[0], + 0, + "REDUCE_MAX OP: the input dimension to reduce along must be " + "in range [-%i, %i), but got %i instead !", + input->rankOf(), input->rankOf(), item); - for(const auto& item : dimensions) - REQUIRE_TRUE(item >= -input->shapeInfo()[0] && item < input->shapeInfo()[0], 0, "REDUCE_MAX OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , input->rankOf(), input->rankOf(), item); + bool keepDims = false; //: false; + if (block.numB() > 0) + keepDims = B_ARG(0); + else if (block.numT() > 0) + keepDims = (bool)T_ARG(0); - bool keepDims = false;//: false; - if (block.getBArguments()->size() > 0) - keepDims = B_ARG(0); - else if (block.getTArguments()->size() > 0) - keepDims = (bool)T_ARG(0); + input->reduceAlongDimension(reduce::Max, *output, dimensions, keepDims); - input->reduceAlongDimension(reduce::Max, *output, dimensions, keepDims); - - return Status::OK(); + return Status::OK(); } DECLARE_SHAPE_FN(reduce_max) { - - bool keepDims = false;//: false; - - if (block.getBArguments()->size() > 0) - keepDims = B_ARG(0); - else if (block.getTArguments()->size() > 0) - keepDims = (bool)T_ARG(0); - - auto dimensions = *block.getIArguments(); - if (block.width() > 1) { - auto axesVector = INPUT_VARIABLE(1); - helpers::adjustAxis(INPUT_VARIABLE(0)->rankOf(), axesVector, dimensions); - } - - REQUIRE_TRUE(dimensions.size() <= inputShape->at(0)[0], 0, "REDUCE_MAX OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); - - for(const auto& item : dimensions) - REQUIRE_TRUE(item >= -inputShape->at(0)[0] && item < inputShape->at(0)[0], 0, "REDUCE_MAX OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , inputShape->at(0)[0], inputShape->at(0)[0], item); - - auto outShapeInfo = ShapeUtils::evalReduceShapeInfo(shape::order(inputShape->at(0)), dimensions, inputShape->at(0), keepDims, false, block.getWorkspace()); - - return SHAPELIST(outShapeInfo); + bool keepDims = false; //: false; + + if (block.numB() > 0) + keepDims = B_ARG(0); + else if (block.numT() > 0) + keepDims = (bool)T_ARG(0); + + auto dimensions = block.getIArguments(); + if (block.width() > 1) { + auto axesVector = INPUT_VARIABLE(1); + helpers::adjustAxis(INPUT_VARIABLE(0)->rankOf(), axesVector, dimensions); + } + + REQUIRE_TRUE(dimensions.size() <= inputShape->at(0)[0], 0, + "REDUCE_MAX OP: the number of dimensions to reduce along must " + "be <= input array rank, but got %i instead", + dimensions.size()); + + for (const auto& item : dimensions) + REQUIRE_TRUE(item >= -inputShape->at(0)[0] && item < inputShape->at(0)[0], + 0, + "REDUCE_MAX OP: the input dimension to reduce along must be " + "in range [-%i, %i), but got %i instead !", + inputShape->at(0)[0], inputShape->at(0)[0], item); + + auto outShapeInfo = ShapeUtils::evalReduceShapeInfo( + shape::order(inputShape->at(0)), dimensions, inputShape->at(0), keepDims, + false, block.workspace()); + + return SHAPELIST(outShapeInfo); } DECLARE_TYPES(reduce_max) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setSameMode(true); + getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setSameMode(true); } #endif @@ -93,69 +105,81 @@ DECLARE_TYPES(reduce_max) { ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(reduce_max_bp, 2, 1, false, 0, 0) { - - auto input = INPUT_VARIABLE(0); - auto gradO = INPUT_VARIABLE(1); - auto gradI = OUTPUT_VARIABLE(0); - - std::vector dimensions = *block.getIArguments(); - - if (block.width() > 2) { - auto axesVector = INPUT_VARIABLE(2); - helpers::adjustAxis(input->rankOf(), axesVector, dimensions); - } - - REQUIRE_TRUE(dimensions.size() <= input->rankOf(), 0, "REDUCE_MAX_BP OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); - - for(const auto& item : dimensions) - REQUIRE_TRUE(item >= -input->shapeInfo()[0] && item < input->shapeInfo()[0], 0, "REDUCE_MAX_BP OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , input->rankOf(), input->rankOf(), item); - - // *** calculations *** // - - *gradI = 0; - - if(gradO->lengthOf() == 1) { - - auto indOfMaxElem = input->indexReduceNumber(sd::indexreduce::IndexMax); - gradI->p(indOfMaxElem.t(0), gradO->e(0)); - } - else { - - auto indicesArr = input->applyIndexReduce(sd::indexreduce::IndexMax, dimensions); - helpers::scatterSimple(block.launchContext(), 6, *gradI, *gradO, indicesArr, ShapeUtils::evalDimsToExclude(gradI->rankOf(), dimensions)); // 6 corresponds to copy operation - } - - return Status::OK(); + auto input = INPUT_VARIABLE(0); + auto gradO = INPUT_VARIABLE(1); + auto gradI = OUTPUT_VARIABLE(0); + + std::vector dimensions = block.getIArguments(); + + if (block.width() > 2) { + auto axesVector = INPUT_VARIABLE(2); + helpers::adjustAxis(input->rankOf(), axesVector, dimensions); + } + + REQUIRE_TRUE(dimensions.size() <= input->rankOf(), 0, + "REDUCE_MAX_BP OP: the number of dimensions to reduce along " + "must be <= input array rank, but got %i instead", + dimensions.size()); + + for (const auto& item : dimensions) + REQUIRE_TRUE(item >= -input->shapeInfo()[0] && item < input->shapeInfo()[0], + 0, + "REDUCE_MAX_BP OP: the input dimension to reduce along must " + "be in range [-%i, %i), but got %i instead !", + input->rankOf(), input->rankOf(), item); + + // *** calculations *** // + + *gradI = 0; + + if (gradO->lengthOf() == 1) { + auto indOfMaxElem = input->indexReduceNumber(sd::indexreduce::IndexMax); + gradI->p(indOfMaxElem.t(0), gradO->e(0)); + } else { + auto indicesArr = + input->applyIndexReduce(sd::indexreduce::IndexMax, dimensions); + helpers::scatterSimple( + block.launchContext(), 6, *gradI, *gradO, indicesArr, + ShapeUtils::evalDimsToExclude( + gradI->rankOf(), dimensions)); // 6 corresponds to copy operation + } + + return Status::OK(); } - DECLARE_SHAPE_FN(reduce_max_bp) { + std::vector dimensions = block.getIArguments(); - std::vector dimensions = *block.getIArguments(); + if (block.width() > 2) { + auto axesVector = INPUT_VARIABLE(2); + helpers::adjustAxis(INPUT_VARIABLE(0)->rankOf(), axesVector, dimensions); + } - if (block.width() > 2) { - auto axesVector = INPUT_VARIABLE(2); - helpers::adjustAxis(INPUT_VARIABLE(0)->rankOf(), axesVector, dimensions); - } + REQUIRE_TRUE(dimensions.size() <= inputShape->at(0)[0], 0, + "REDUCE_MAX_BP OP: the number of dimensions to reduce along " + "must be <= input array rank, but got %i instead", + dimensions.size()); - REQUIRE_TRUE(dimensions.size() <= inputShape->at(0)[0], 0, "REDUCE_MAX_BP OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); + for (const auto& item : dimensions) + REQUIRE_TRUE(item >= -inputShape->at(0)[0] && item < inputShape->at(0)[0], + 0, + "REDUCE_MAX_BP OP: the input dimension to reduce along must " + "be in range [-%i, %i), but got %i instead !", + inputShape->at(0)[0], inputShape->at(0)[0], item); - for(const auto& item : dimensions) - REQUIRE_TRUE(item >= -inputShape->at(0)[0] && item < inputShape->at(0)[0], 0, "REDUCE_MAX_BP OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !", inputShape->at(0)[0], inputShape->at(0)[0], item); + Nd4jLong* outShapeInfo; + COPY_SHAPE(inputShape->at(0), outShapeInfo); - Nd4jLong* outShapeInfo; - COPY_SHAPE(inputShape->at(0), outShapeInfo); - - return SHAPELIST(CONSTANT(outShapeInfo)); + return SHAPELIST(CONSTANT(outShapeInfo)); } DECLARE_TYPES(reduce_max_bp) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } #endif -} -} +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/generic/reduce/reduce_min.cpp b/libnd4j/include/ops/declarable/generic/reduce/reduce_min.cpp index d4b470b8ed77..fb9b17506309 100644 --- a/libnd4j/include/ops/declarable/generic/reduce/reduce_min.cpp +++ b/libnd4j/include/ops/declarable/generic/reduce/reduce_min.cpp @@ -20,8 +20,8 @@ // #include -#include #include +#include namespace sd { namespace ops { @@ -30,135 +30,157 @@ namespace ops { ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(reduce_min, 1, 1, false, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); + std::vector dimensions = block.getIArguments(); - std::vector dimensions = *block.getIArguments(); + if (block.width() > 1) { + auto axesVector = INPUT_VARIABLE(1); + helpers::adjustAxis(input->rankOf(), axesVector, dimensions); + } - if (block.width() > 1) { - auto axesVector = INPUT_VARIABLE(1); - helpers::adjustAxis(input->rankOf(), axesVector, dimensions); - } + REQUIRE_TRUE(dimensions.size() <= input->rankOf(), 0, + "REDUCE_MIN OP: the number of dimensions to reduce along must " + "be <= input array rank, but got %i instead", + dimensions.size()); - REQUIRE_TRUE(dimensions.size() <= input->rankOf(), 0, "REDUCE_MIN OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); + for (const auto& item : dimensions) + REQUIRE_TRUE(item >= -input->shapeInfo()[0] && item < input->shapeInfo()[0], + 0, + "REDUCE_MIN OP: the input dimension to reduce along must be " + "in range [-%i, %i), but got %i instead !", + input->rankOf(), input->rankOf(), item); - for(const auto& item : dimensions) - REQUIRE_TRUE(item >= -input->shapeInfo()[0] && item < input->shapeInfo()[0], 0, "REDUCE_MIN OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , input->rankOf(), input->rankOf(), item); + bool keepDims = false; //: false; + if (block.numB() > 0) + keepDims = B_ARG(0); + else if (block.numT() > 0) + keepDims = (bool)T_ARG(0); - bool keepDims = false;//: false; - if (block.getBArguments()->size() > 0) - keepDims = B_ARG(0); - else if (block.getTArguments()->size() > 0) - keepDims = (bool)T_ARG(0); + input->reduceAlongDimension(reduce::Min, *output, dimensions, keepDims); - input->reduceAlongDimension(reduce::Min, *output, dimensions, keepDims); - - return Status::OK(); + return Status::OK(); } DECLARE_SHAPE_FN(reduce_min) { - - bool keepDims = false;//: false; - - if (block.getBArguments()->size() > 0) - keepDims = B_ARG(0); - else if (block.getTArguments()->size() > 0) - keepDims = (bool)T_ARG(0); - - auto dimensions = *block.getIArguments(); - if (block.width() > 1) { - auto axesVector = INPUT_VARIABLE(1); - helpers::adjustAxis(INPUT_VARIABLE(0)->rankOf(), axesVector, dimensions); - } - - REQUIRE_TRUE(dimensions.size() <= inputShape->at(0)[0], 0, "REDUCE_MIN OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); - - for(const auto& item : dimensions) - REQUIRE_TRUE(item >= -inputShape->at(0)[0] && item < inputShape->at(0)[0], 0, "REDUCE_MIN OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , inputShape->at(0)[0], inputShape->at(0)[0], item); - - auto outShapeInfo = ShapeUtils::evalReduceShapeInfo(shape::order(inputShape->at(0)), dimensions, inputShape->at(0), keepDims, false, block.getWorkspace()); - - return SHAPELIST(outShapeInfo); + bool keepDims = false; //: false; + + if (block.numB() > 0) + keepDims = B_ARG(0); + else if (block.numT() > 0) + keepDims = (bool)T_ARG(0); + + auto dimensions = block.getIArguments(); + if (block.width() > 1) { + auto axesVector = INPUT_VARIABLE(1); + helpers::adjustAxis(INPUT_VARIABLE(0)->rankOf(), axesVector, dimensions); + } + + REQUIRE_TRUE(dimensions.size() <= inputShape->at(0)[0], 0, + "REDUCE_MIN OP: the number of dimensions to reduce along must " + "be <= input array rank, but got %i instead", + dimensions.size()); + + for (const auto& item : dimensions) + REQUIRE_TRUE(item >= -inputShape->at(0)[0] && item < inputShape->at(0)[0], + 0, + "REDUCE_MIN OP: the input dimension to reduce along must be " + "in range [-%i, %i), but got %i instead !", + inputShape->at(0)[0], inputShape->at(0)[0], item); + + auto outShapeInfo = ShapeUtils::evalReduceShapeInfo( + shape::order(inputShape->at(0)), dimensions, inputShape->at(0), keepDims, + false, block.workspace()); + + return SHAPELIST(outShapeInfo); } DECLARE_TYPES(reduce_min) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setSameMode(true); + getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setSameMode(true); } - #endif - #if NOT_EXCLUDED(OP_reduce_min_bp) ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(reduce_min_bp, 2, 1, false, 0, 0) { - - auto input = INPUT_VARIABLE(0); - auto gradO = INPUT_VARIABLE(1); - auto gradI = OUTPUT_VARIABLE(0); - - std::vector dimensions = *block.getIArguments(); - - if (block.width() > 2) { - auto axesVector = INPUT_VARIABLE(2); - helpers::adjustAxis(input->rankOf(), axesVector, dimensions); - } - - REQUIRE_TRUE(dimensions.size() <= input->rankOf(), 0, "REDUCE_MIN_BP OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); - - for(const auto& item : dimensions) - REQUIRE_TRUE(item >= -input->shapeInfo()[0] && item < input->shapeInfo()[0], 0, "REDUCE_MIN_BP OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , input->rankOf(), input->rankOf(), item); - - // *** calculations *** // - - *gradI = 0; - - if(gradO->lengthOf() == 1) { - - auto indOfMaxElem = input->indexReduceNumber(sd::indexreduce::IndexMin); - gradI->p(indOfMaxElem.e(0), gradO->e(0)); - } - else { - - auto indicesArr = input->applyIndexReduce(sd::indexreduce::IndexMin, dimensions); - helpers::scatterSimple(block.launchContext(), 6, *gradI, *gradO, indicesArr, ShapeUtils::evalDimsToExclude(gradI->rankOf(), dimensions)); // 6 corresponds to copy operation - } - - return Status::OK(); + auto input = INPUT_VARIABLE(0); + auto gradO = INPUT_VARIABLE(1); + auto gradI = OUTPUT_VARIABLE(0); + + std::vector dimensions = block.getIArguments(); + + if (block.width() > 2) { + auto axesVector = INPUT_VARIABLE(2); + helpers::adjustAxis(input->rankOf(), axesVector, dimensions); + } + + REQUIRE_TRUE(dimensions.size() <= input->rankOf(), 0, + "REDUCE_MIN_BP OP: the number of dimensions to reduce along " + "must be <= input array rank, but got %i instead", + dimensions.size()); + + for (const auto& item : dimensions) + REQUIRE_TRUE(item >= -input->shapeInfo()[0] && item < input->shapeInfo()[0], + 0, + "REDUCE_MIN_BP OP: the input dimension to reduce along must " + "be in range [-%i, %i), but got %i instead !", + input->rankOf(), input->rankOf(), item); + + // *** calculations *** // + + *gradI = 0; + + if (gradO->lengthOf() == 1) { + auto indOfMaxElem = input->indexReduceNumber(sd::indexreduce::IndexMin); + gradI->p(indOfMaxElem.e(0), gradO->e(0)); + } else { + auto indicesArr = + input->applyIndexReduce(sd::indexreduce::IndexMin, dimensions); + helpers::scatterSimple( + block.launchContext(), 6, *gradI, *gradO, indicesArr, + ShapeUtils::evalDimsToExclude( + gradI->rankOf(), dimensions)); // 6 corresponds to copy operation + } + + return Status::OK(); } DECLARE_SHAPE_FN(reduce_min_bp) { + std::vector dimensions = block.getIArguments(); - std::vector dimensions = *block.getIArguments(); - - if (block.width() > 2) { - auto axesVector = INPUT_VARIABLE(2); - helpers::adjustAxis(INPUT_VARIABLE(0)->rankOf(), axesVector, dimensions); - } + if (block.width() > 2) { + auto axesVector = INPUT_VARIABLE(2); + helpers::adjustAxis(INPUT_VARIABLE(0)->rankOf(), axesVector, dimensions); + } - REQUIRE_TRUE(dimensions.size() <= inputShape->at(0)[0], 0, "REDUCE_MIN_BP OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); + REQUIRE_TRUE(dimensions.size() <= inputShape->at(0)[0], 0, + "REDUCE_MIN_BP OP: the number of dimensions to reduce along " + "must be <= input array rank, but got %i instead", + dimensions.size()); - for(const auto& item : dimensions) - REQUIRE_TRUE(item >= -inputShape->at(0)[0] && item < inputShape->at(0)[0], 0, "REDUCE_MIN_BP OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !", inputShape->at(0)[0], inputShape->at(0)[0], item); + for (const auto& item : dimensions) + REQUIRE_TRUE(item >= -inputShape->at(0)[0] && item < inputShape->at(0)[0], + 0, + "REDUCE_MIN_BP OP: the input dimension to reduce along must " + "be in range [-%i, %i), but got %i instead !", + inputShape->at(0)[0], inputShape->at(0)[0], item); - Nd4jLong* outShapeInfo; - COPY_SHAPE(inputShape->at(0), outShapeInfo); + Nd4jLong* outShapeInfo; + COPY_SHAPE(inputShape->at(0), outShapeInfo); - return SHAPELIST(CONSTANT(outShapeInfo)); + return SHAPELIST(CONSTANT(outShapeInfo)); } DECLARE_TYPES(reduce_min_bp) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } - #endif -} -} +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/generic/reduce/reduce_norm1.cpp b/libnd4j/include/ops/declarable/generic/reduce/reduce_norm1.cpp index 31261fe5ccd6..fb6269605440 100644 --- a/libnd4j/include/ops/declarable/generic/reduce/reduce_norm1.cpp +++ b/libnd4j/include/ops/declarable/generic/reduce/reduce_norm1.cpp @@ -19,8 +19,8 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include #include +#include namespace sd { namespace ops { @@ -28,142 +28,169 @@ namespace ops { ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(reduce_norm1, 1, 1, false, 0, 0) { - - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - std::vector dimensions; - if (block.width() > 1) { - auto axesVector = INPUT_VARIABLE(1); - helpers::adjustAxis(input->rankOf(), axesVector, dimensions); - } - else if (block.getIArguments()->size()) - dimensions = *block.getIArguments(); - - REQUIRE_TRUE(dimensions.size() <= input->rankOf(), 0, "REDUCE_NORM1 OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); - - for(const auto& item : dimensions) - REQUIRE_TRUE(item >= -input->shapeInfo()[0] && item < input->shapeInfo()[0], 0, "REDUCE_NORM1 OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , input->rankOf(), input->rankOf(), item); - - bool keepDims = false; - if (block.getBArguments()->size()) - keepDims = B_ARG(0); - else if (block.getTArguments()->size()) - keepDims = (bool)T_ARG(0); - - input->reduceAlongDimension(reduce::Norm1, *output, dimensions, keepDims); - - return Status::OK(); + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + std::vector dimensions; + if (block.width() > 1) { + auto axesVector = INPUT_VARIABLE(1); + helpers::adjustAxis(input->rankOf(), axesVector, dimensions); + } else if (block.numI()) + dimensions = block.getIArguments(); + + REQUIRE_TRUE(dimensions.size() <= input->rankOf(), 0, + "REDUCE_NORM1 OP: the number of dimensions to reduce along must " + "be <= input array rank, but got %i instead", + dimensions.size()); + + for (const auto& item : dimensions) + REQUIRE_TRUE(item >= -input->shapeInfo()[0] && item < input->shapeInfo()[0], + 0, + "REDUCE_NORM1 OP: the input dimension to reduce along must be " + "in range [-%i, %i), but got %i instead !", + input->rankOf(), input->rankOf(), item); + + bool keepDims = false; + if (block.numB()) + keepDims = B_ARG(0); + else if (block.numT()) + keepDims = (bool)T_ARG(0); + + input->reduceAlongDimension(reduce::Norm1, *output, dimensions, keepDims); + + return Status::OK(); } DECLARE_SHAPE_FN(reduce_norm1) { - - bool keepDims = false; - if (block.getBArguments()->size()) - keepDims = B_ARG(0); - else if (block.getTArguments()->size()) - keepDims = (bool)T_ARG(0); - - std::vector dimensions; - if (block.width() > 1) { - auto axesVector = INPUT_VARIABLE(1); - helpers::adjustAxis(INPUT_VARIABLE(0)->rankOf(), axesVector, dimensions); - } - else if (block.getIArguments()->size()) - dimensions = *block.getIArguments(); - - REQUIRE_TRUE(dimensions.size() <= inputShape->at(0)[0], 0, "REDUCE_NORM1 OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); - - for(const auto& item : dimensions) - REQUIRE_TRUE(item >= -inputShape->at(0)[0] && item < inputShape->at(0)[0], 0, "REDUCE_NORM1 OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , inputShape->at(0)[0], inputShape->at(0)[0], item); - - return SHAPELIST(ShapeUtils::evalReduceShapeInfo(shape::order(inputShape->at(0)), dimensions, inputShape->at(0), keepDims, false, block.getWorkspace())); + bool keepDims = false; + if (block.numB()) + keepDims = B_ARG(0); + else if (block.numT()) + keepDims = (bool)T_ARG(0); + + std::vector dimensions; + if (block.width() > 1) { + auto axesVector = INPUT_VARIABLE(1); + helpers::adjustAxis(INPUT_VARIABLE(0)->rankOf(), axesVector, dimensions); + } else if (block.numI()) + dimensions = block.getIArguments(); + + REQUIRE_TRUE(dimensions.size() <= inputShape->at(0)[0], 0, + "REDUCE_NORM1 OP: the number of dimensions to reduce along must " + "be <= input array rank, but got %i instead", + dimensions.size()); + + for (const auto& item : dimensions) + REQUIRE_TRUE(item >= -inputShape->at(0)[0] && item < inputShape->at(0)[0], + 0, + "REDUCE_NORM1 OP: the input dimension to reduce along must be " + "in range [-%i, %i), but got %i instead !", + inputShape->at(0)[0], inputShape->at(0)[0], item); + + return SHAPELIST(ShapeUtils::evalReduceShapeInfo( + shape::order(inputShape->at(0)), dimensions, inputShape->at(0), keepDims, + false, block.workspace())); } DECLARE_TYPES(reduce_norm1) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } #endif #if NOT_EXCLUDED(OP_reduce_norm1_bp) ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(reduce_norm1_bp, 2, 1, false, 0, 0) { - // L = Sum abs(x_i) for all i = 1 to N - // dL/dx_i = 1 if x_i >= 0 and -1 when x_i < 0 - // out_i = epsilon_i if x_i > 0 and -epsilon_i when x_i < 0 - // when gradO is non a scalar, using dimensions to split output onto gradO like parts - // and use LAMBDA with that formula for it. - - auto input = INPUT_VARIABLE(0); - auto gradO = INPUT_VARIABLE(1); - auto gradI = OUTPUT_VARIABLE(0); + // L = Sum abs(x_i) for all i = 1 to N + // dL/dx_i = 1 if x_i >= 0 and -1 when x_i < 0 + // out_i = epsilon_i if x_i > 0 and -epsilon_i when x_i < 0 + // when gradO is non a scalar, using dimensions to split output onto gradO + // like parts and use LAMBDA with that formula for it. - input->applyTransform(sd::transform::Sign, *gradI); + auto input = INPUT_VARIABLE(0); + auto gradO = INPUT_VARIABLE(1); + auto gradI = OUTPUT_VARIABLE(0); - if (gradO->lengthOf() == 1) { - *gradI *= *gradO; - } - else { - - bool keepDims = false; - auto dimensions = *block.getIArguments(); - - if (block.width() > 2) { - auto axesVector = INPUT_VARIABLE(2); - helpers::adjustAxis(input->rankOf(), axesVector, dimensions); - } - - if (block.getBArguments()->size()) - keepDims = B_ARG(0); - else if (block.getTArguments()->size()) - keepDims = (bool)T_ARG(0); - - REQUIRE_TRUE(dimensions.size() <= input->rankOf(), 0, "REDUCE_NORM1_BP OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); + input->applyTransform(sd::transform::Sign, *gradI); - for(const auto& item : dimensions) - REQUIRE_TRUE(item >= -input->rankOf() && item < input->rankOf(), 0, "REDUCE_NORM1_BP OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , input->rankOf(), input->rankOf(), item); - - // *** calculations *** // + if (gradO->lengthOf() == 1) { + *gradI *= *gradO; + } else { + bool keepDims = false; + auto dimensions = block.getIArguments(); - if(!keepDims) { - auto gradOShapeKeepDims = ShapeUtils::evalReduceShapeInfo(gradO->ordering(), dimensions, *input, true, false, block.getWorkspace()); - *gradI *= gradO->reshape(gradO->ordering(), ShapeUtils::pullShapeFromShapeInfo(gradOShapeKeepDims)); // for example could be something like [a,b] -> [1,a,1,b] - } else - *gradI *= *gradO; + if (block.width() > 2) { + auto axesVector = INPUT_VARIABLE(2); + helpers::adjustAxis(input->rankOf(), axesVector, dimensions); } - return Status::OK(); + if (block.numB()) + keepDims = B_ARG(0); + else if (block.numT()) + keepDims = (bool)T_ARG(0); + + REQUIRE_TRUE(dimensions.size() <= input->rankOf(), 0, + "REDUCE_NORM1_BP OP: the number of dimensions to reduce along " + "must be <= input array rank, but got %i instead", + dimensions.size()); + + for (const auto& item : dimensions) + REQUIRE_TRUE(item >= -input->rankOf() && item < input->rankOf(), 0, + "REDUCE_NORM1_BP OP: the input dimension to reduce along " + "must be in range [-%i, %i), but got %i instead !", + input->rankOf(), input->rankOf(), item); + + // *** calculations *** // + + if (!keepDims) { + auto gradOShapeKeepDims = + ShapeUtils::evalReduceShapeInfo(gradO->ordering(), dimensions, *input, + true, false, block.workspace()); + *gradI *= gradO->reshape( + gradO->ordering(), + ShapeUtils::pullShapeFromShapeInfo( + gradOShapeKeepDims)); // for example could be something like + // [a,b] -> [1,a,1,b] + } else + *gradI *= *gradO; + } + + return Status::OK(); } DECLARE_SHAPE_FN(reduce_norm1_bp) { - - auto dimensions = *block.getIArguments(); - if (block.width() > 2) { - auto axesVector = INPUT_VARIABLE(2); - helpers::adjustAxis(INPUT_VARIABLE(0)->rankOf(), axesVector, dimensions); - } - - REQUIRE_TRUE(dimensions.size() <= inputShape->at(0)[0], 0, "REDUCE_NORM1_BP OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); - - for(const auto& item : dimensions) - REQUIRE_TRUE(item >= -inputShape->at(0)[0] && item < inputShape->at(0)[0], 0, "REDUCE_NORM1_BP OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !", inputShape->at(0)[0], inputShape->at(0)[0], item); - - Nd4jLong* outShapeInfo; - COPY_SHAPE(inputShape->at(0), outShapeInfo); - - return SHAPELIST(CONSTANT(outShapeInfo)); + auto dimensions = block.getIArguments(); + if (block.width() > 2) { + auto axesVector = INPUT_VARIABLE(2); + helpers::adjustAxis(INPUT_VARIABLE(0)->rankOf(), axesVector, dimensions); + } + + REQUIRE_TRUE(dimensions.size() <= inputShape->at(0)[0], 0, + "REDUCE_NORM1_BP OP: the number of dimensions to reduce along " + "must be <= input array rank, but got %i instead", + dimensions.size()); + + for (const auto& item : dimensions) + REQUIRE_TRUE(item >= -inputShape->at(0)[0] && item < inputShape->at(0)[0], + 0, + "REDUCE_NORM1_BP OP: the input dimension to reduce along must " + "be in range [-%i, %i), but got %i instead !", + inputShape->at(0)[0], inputShape->at(0)[0], item); + + Nd4jLong* outShapeInfo; + COPY_SHAPE(inputShape->at(0), outShapeInfo); + + return SHAPELIST(CONSTANT(outShapeInfo)); } DECLARE_TYPES(reduce_norm1_bp) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } - #endif -} -} +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/generic/reduce/reduce_norm2.cpp b/libnd4j/include/ops/declarable/generic/reduce/reduce_norm2.cpp index c9ea8e374d79..72188923a569 100644 --- a/libnd4j/include/ops/declarable/generic/reduce/reduce_norm2.cpp +++ b/libnd4j/include/ops/declarable/generic/reduce/reduce_norm2.cpp @@ -28,62 +28,74 @@ namespace ops { ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(reduce_norm2, 1, 1, false, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - std::vector dimensions; - if (block.width() > 1) { - auto axesVector = INPUT_VARIABLE(1); - helpers::adjustAxis(input->rankOf(), axesVector, dimensions); - } - else if (block.getIArguments()->size()) - dimensions = *block.getIArguments(); - - REQUIRE_TRUE(dimensions.size() <= input->rankOf(), 0, "REDUCE_NORM2 OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); - - for(const auto& item : dimensions) - REQUIRE_TRUE(item >= -input->shapeInfo()[0] && item < input->shapeInfo()[0], 0, "REDUCE_NORM2 OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , input->rankOf(), input->rankOf(), item); - - bool keepDims = false; - if (block.getBArguments()->size()) - keepDims = B_ARG(0); - else if (block.getTArguments()->size()) - keepDims = (bool)T_ARG(0); - - input->reduceAlongDimension(reduce::Norm2, *output, dimensions, keepDims); - - return Status::OK(); + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + std::vector dimensions; + if (block.width() > 1) { + auto axesVector = INPUT_VARIABLE(1); + helpers::adjustAxis(input->rankOf(), axesVector, dimensions); + } else if (block.numI()) + dimensions = block.getIArguments(); + + REQUIRE_TRUE(dimensions.size() <= input->rankOf(), 0, + "REDUCE_NORM2 OP: the number of dimensions to reduce along must " + "be <= input array rank, but got %i instead", + dimensions.size()); + + for (const auto& item : dimensions) + REQUIRE_TRUE(item >= -input->shapeInfo()[0] && item < input->shapeInfo()[0], + 0, + "REDUCE_NORM2 OP: the input dimension to reduce along must be " + "in range [-%i, %i), but got %i instead !", + input->rankOf(), input->rankOf(), item); + + bool keepDims = false; + if (block.numB()) + keepDims = B_ARG(0); + else if (block.numT()) + keepDims = (bool)T_ARG(0); + + input->reduceAlongDimension(reduce::Norm2, *output, dimensions, keepDims); + + return Status::OK(); } - DECLARE_SHAPE_FN(reduce_norm2) { - - bool keepDims = false; - if (block.getBArguments()->size()) - keepDims = B_ARG(0); - else if (block.getTArguments()->size()) - keepDims = (bool)T_ARG(0); - - std::vector dimensions; - if (block.width() > 1) { - auto axesVector = INPUT_VARIABLE(1); - helpers::adjustAxis(INPUT_VARIABLE(0)->rankOf(), axesVector, dimensions); - } - else if (block.getIArguments()->size()) - dimensions = *block.getIArguments(); - - REQUIRE_TRUE(dimensions.size() <= inputShape->at(0)[0], 0, "REDUCE_NORM2 OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); - - for(const auto& item : dimensions) - REQUIRE_TRUE(item >= -inputShape->at(0)[0] && item < inputShape->at(0)[0], 0, "REDUCE_NORM2 OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , inputShape->at(0)[0], inputShape->at(0)[0], item); - - return SHAPELIST(ShapeUtils::evalReduceShapeInfo(shape::order(inputShape->at(0)), dimensions, inputShape->at(0), keepDims, false, block.getWorkspace())); + bool keepDims = false; + if (block.numB()) + keepDims = B_ARG(0); + else if (block.numT()) + keepDims = (bool)T_ARG(0); + + std::vector dimensions; + if (block.width() > 1) { + auto axesVector = INPUT_VARIABLE(1); + helpers::adjustAxis(INPUT_VARIABLE(0)->rankOf(), axesVector, dimensions); + } else if (block.numI()) + dimensions = block.getIArguments(); + + REQUIRE_TRUE(dimensions.size() <= inputShape->at(0)[0], 0, + "REDUCE_NORM2 OP: the number of dimensions to reduce along must " + "be <= input array rank, but got %i instead", + dimensions.size()); + + for (const auto& item : dimensions) + REQUIRE_TRUE(item >= -inputShape->at(0)[0] && item < inputShape->at(0)[0], + 0, + "REDUCE_NORM2 OP: the input dimension to reduce along must be " + "in range [-%i, %i), but got %i instead !", + inputShape->at(0)[0], inputShape->at(0)[0], item); + + return SHAPELIST(ShapeUtils::evalReduceShapeInfo( + shape::order(inputShape->at(0)), dimensions, inputShape->at(0), keepDims, + false, block.workspace())); } DECLARE_TYPES(reduce_norm2) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } #endif @@ -91,77 +103,91 @@ DECLARE_TYPES(reduce_norm2) { ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(reduce_norm2_bp, 2, 1, false, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto gradO = INPUT_VARIABLE(1); + auto gradI = OUTPUT_VARIABLE(0); - auto input = INPUT_VARIABLE(0); - auto gradO = INPUT_VARIABLE(1); - auto gradI = OUTPUT_VARIABLE(0); + gradI->assign(input); - gradI->assign(input); + if (gradO->lengthOf() == 1) { + *gradI /= input->reduceNumber(reduce::Norm2); + *gradI *= *gradO; + } else { + bool keepDims = false; + auto dimensions = block.getIArguments(); - if (gradO->lengthOf() == 1) { - *gradI /= input->reduceNumber(reduce::Norm2); - *gradI *= *gradO; + if (block.width() > 2) { + auto axesVector = INPUT_VARIABLE(2); + helpers::adjustAxis(input->rankOf(), axesVector, dimensions); } - else { - - bool keepDims = false; - auto dimensions = *block.getIArguments(); - if (block.width() > 2) { - auto axesVector = INPUT_VARIABLE(2); - helpers::adjustAxis(input->rankOf(), axesVector, dimensions); - } - - if (block.getBArguments()->size()) - keepDims = B_ARG(0); - else if (block.getTArguments()->size()) - keepDims = (bool)T_ARG(0); - - REQUIRE_TRUE(dimensions.size() <= input->rankOf(), 0, "REDUCE_NORM2_BP OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); - - for(const auto& item : dimensions) - REQUIRE_TRUE(item >= -input->rankOf() && item < input->rankOf(), 0, "REDUCE_NORM2_BP OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , input->rankOf(), input->rankOf(), item); - - // *** calculations *** // - - *gradI /= input->reduceAlongDimension(reduce::Norm2, dimensions, true); - - if(!keepDims) { - auto gradOShapeKeepDims = ShapeUtils::evalReduceShapeInfo(gradO->ordering(), dimensions, *input, true, false, block.getWorkspace()); - *gradI *= gradO->reshape(gradO->ordering(), ShapeUtils::pullShapeFromShapeInfo(gradOShapeKeepDims)); // for example could be something like [a,b] -> [1,a,1,b] - } else - *gradI *= *gradO; - } - return Status::OK(); + if (block.numB()) + keepDims = B_ARG(0); + else if (block.numT()) + keepDims = (bool)T_ARG(0); + + REQUIRE_TRUE(dimensions.size() <= input->rankOf(), 0, + "REDUCE_NORM2_BP OP: the number of dimensions to reduce along " + "must be <= input array rank, but got %i instead", + dimensions.size()); + + for (const auto& item : dimensions) + REQUIRE_TRUE(item >= -input->rankOf() && item < input->rankOf(), 0, + "REDUCE_NORM2_BP OP: the input dimension to reduce along " + "must be in range [-%i, %i), but got %i instead !", + input->rankOf(), input->rankOf(), item); + + // *** calculations *** // + + *gradI /= input->reduceAlongDimension(reduce::Norm2, dimensions, true); + + if (!keepDims) { + auto gradOShapeKeepDims = + ShapeUtils::evalReduceShapeInfo(gradO->ordering(), dimensions, *input, + true, false, block.workspace()); + *gradI *= gradO->reshape( + gradO->ordering(), + ShapeUtils::pullShapeFromShapeInfo( + gradOShapeKeepDims)); // for example could be something like + // [a,b] -> [1,a,1,b] + } else + *gradI *= *gradO; + } + return Status::OK(); } DECLARE_SHAPE_FN(reduce_norm2_bp) { - - auto dimensions = *block.getIArguments(); - if (block.width() > 2) { - auto axesVector = INPUT_VARIABLE(2); - helpers::adjustAxis(INPUT_VARIABLE(0)->rankOf(), axesVector, dimensions); - } - - REQUIRE_TRUE(dimensions.size() <= inputShape->at(0)[0], 0, "REDUCE_NORM2_BP OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); - - for(const auto& item : dimensions) - REQUIRE_TRUE(item >= -inputShape->at(0)[0] && item < inputShape->at(0)[0], 0, "REDUCE_NORM2_BP OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !", inputShape->at(0)[0], inputShape->at(0)[0], item); - - Nd4jLong* outShapeInfo; - COPY_SHAPE(inputShape->at(0), outShapeInfo); - - return SHAPELIST(CONSTANT(outShapeInfo)); + auto dimensions = block.getIArguments(); + if (block.width() > 2) { + auto axesVector = INPUT_VARIABLE(2); + helpers::adjustAxis(INPUT_VARIABLE(0)->rankOf(), axesVector, dimensions); + } + + REQUIRE_TRUE(dimensions.size() <= inputShape->at(0)[0], 0, + "REDUCE_NORM2_BP OP: the number of dimensions to reduce along " + "must be <= input array rank, but got %i instead", + dimensions.size()); + + for (const auto& item : dimensions) + REQUIRE_TRUE(item >= -inputShape->at(0)[0] && item < inputShape->at(0)[0], + 0, + "REDUCE_NORM2_BP OP: the input dimension to reduce along must " + "be in range [-%i, %i), but got %i instead !", + inputShape->at(0)[0], inputShape->at(0)[0], item); + + Nd4jLong* outShapeInfo; + COPY_SHAPE(inputShape->at(0), outShapeInfo); + + return SHAPELIST(CONSTANT(outShapeInfo)); } DECLARE_TYPES(reduce_norm2_bp) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } - #endif -} -} +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/generic/reduce/reduce_norm_max.cpp b/libnd4j/include/ops/declarable/generic/reduce/reduce_norm_max.cpp index b1a0189009e1..10110d0ee4ed 100644 --- a/libnd4j/include/ops/declarable/generic/reduce/reduce_norm_max.cpp +++ b/libnd4j/include/ops/declarable/generic/reduce/reduce_norm_max.cpp @@ -20,8 +20,8 @@ // #include -#include #include +#include namespace sd { namespace ops { @@ -29,63 +29,74 @@ namespace ops { ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(reduce_norm_max, 1, 1, false, 0, 0) { - - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - std::vector dimensions; - if (block.width() > 1) { - auto axesVector = INPUT_VARIABLE(1); - helpers::adjustAxis(input->rankOf(), axesVector, dimensions); - } - else if (block.getIArguments()->size()) - dimensions = *block.getIArguments(); - - REQUIRE_TRUE(dimensions.size() <= input->rankOf(), 0, "REDUCE_NORM_MAX OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); - - for(const auto& item : dimensions) - REQUIRE_TRUE(item >= -input->shapeInfo()[0] && item < input->shapeInfo()[0], 0, "REDUCE_NORM_MAX OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , input->rankOf(), input->rankOf(), item); - - bool keepDims = false; - if (block.getBArguments()->size()) - keepDims = B_ARG(0); - else if (block.getTArguments()->size()) - keepDims = (bool)T_ARG(0); - - input->reduceAlongDimension(reduce::NormMax, *output, dimensions, keepDims); - - return Status::OK(); + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + std::vector dimensions; + if (block.width() > 1) { + auto axesVector = INPUT_VARIABLE(1); + helpers::adjustAxis(input->rankOf(), axesVector, dimensions); + } else if (block.numI()) + dimensions = block.getIArguments(); + + REQUIRE_TRUE(dimensions.size() <= input->rankOf(), 0, + "REDUCE_NORM_MAX OP: the number of dimensions to reduce along " + "must be <= input array rank, but got %i instead", + dimensions.size()); + + for (const auto& item : dimensions) + REQUIRE_TRUE(item >= -input->shapeInfo()[0] && item < input->shapeInfo()[0], + 0, + "REDUCE_NORM_MAX OP: the input dimension to reduce along must " + "be in range [-%i, %i), but got %i instead !", + input->rankOf(), input->rankOf(), item); + + bool keepDims = false; + if (block.numB()) + keepDims = B_ARG(0); + else if (block.numT()) + keepDims = (bool)T_ARG(0); + + input->reduceAlongDimension(reduce::NormMax, *output, dimensions, keepDims); + + return Status::OK(); } DECLARE_SHAPE_FN(reduce_norm_max) { - - auto in = inputShape->at(0); - bool keepDims = false; - if (block.getBArguments()->size()) - keepDims = B_ARG(0); - else if (block.getTArguments()->size()) - keepDims = (bool)T_ARG(0); - - std::vector dimensions; - if (block.width() > 1) { - auto axesVector = INPUT_VARIABLE(1); - helpers::adjustAxis(INPUT_VARIABLE(0)->rankOf(), axesVector, dimensions); - } - else if (block.getIArguments()->size()) - dimensions = *block.getIArguments(); - - REQUIRE_TRUE(dimensions.size() <= inputShape->at(0)[0], 0, "REDUCE_NORM_MAX OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); - - for(const auto& item : dimensions) - REQUIRE_TRUE(item >= -inputShape->at(0)[0] && item < inputShape->at(0)[0], 0, "REDUCE_NORM_MAX OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , inputShape->at(0)[0], inputShape->at(0)[0], item); - - return SHAPELIST(ShapeUtils::evalReduceShapeInfo(shape::order(in), dimensions, in, keepDims, false, block.getWorkspace())); + auto in = inputShape->at(0); + bool keepDims = false; + if (block.numB()) + keepDims = B_ARG(0); + else if (block.numT()) + keepDims = (bool)T_ARG(0); + + std::vector dimensions; + if (block.width() > 1) { + auto axesVector = INPUT_VARIABLE(1); + helpers::adjustAxis(INPUT_VARIABLE(0)->rankOf(), axesVector, dimensions); + } else if (block.numI()) + dimensions = block.getIArguments(); + + REQUIRE_TRUE(dimensions.size() <= inputShape->at(0)[0], 0, + "REDUCE_NORM_MAX OP: the number of dimensions to reduce along " + "must be <= input array rank, but got %i instead", + dimensions.size()); + + for (const auto& item : dimensions) + REQUIRE_TRUE(item >= -inputShape->at(0)[0] && item < inputShape->at(0)[0], + 0, + "REDUCE_NORM_MAX OP: the input dimension to reduce along must " + "be in range [-%i, %i), but got %i instead !", + inputShape->at(0)[0], inputShape->at(0)[0], item); + + return SHAPELIST(ShapeUtils::evalReduceShapeInfo( + shape::order(in), dimensions, in, keepDims, false, block.workspace())); } DECLARE_TYPES(reduce_norm_max) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } #endif @@ -93,73 +104,84 @@ DECLARE_TYPES(reduce_norm_max) { ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(reduce_norm_max_bp, 2, 1, false, 0, 0) { - - auto input = INPUT_VARIABLE(0); - auto gradO = INPUT_VARIABLE(1); - auto gradI = OUTPUT_VARIABLE(0); - - std::vector dimensions = *block.getIArguments(); - - if (block.width() > 2) { - auto axesVector = INPUT_VARIABLE(2); - helpers::adjustAxis(input->rankOf(), axesVector, dimensions); - } - - REQUIRE_TRUE(dimensions.size() <= input->rankOf(), 0, "REDUCE_NORM_MAX_BP OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); - - for(const auto& item : dimensions) - REQUIRE_TRUE(item >= -input->shapeInfo()[0] && item < input->shapeInfo()[0], 0, "REDUCE_NORM_MAX_BP OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , input->rankOf(), input->rankOf(), item); - - // *** calculations *** // - - *gradI = 0; - - if(gradO->lengthOf() == 1) { - - auto indOfAbsMaxElem = input->indexReduceNumber(sd::indexreduce::IndexAbsoluteMax); - const Nd4jLong ind = indOfAbsMaxElem.t(0); - const int sign = input->e(ind) >= 0 ? 1 : -1; - gradI->p(ind, sign * gradO->e(0)); - } - else { - - auto indicesArr = input->applyIndexReduce(sd::indexreduce::IndexAbsoluteMax, dimensions); - helpers::scatterSimple(block.launchContext(), 6, *gradI, *gradO, indicesArr, ShapeUtils::evalDimsToExclude(gradI->rankOf(), dimensions)); // 6 corresponds to copy operation - *gradI *= input->transform(sd::transform::Sign); - } - - return Status::OK(); + auto input = INPUT_VARIABLE(0); + auto gradO = INPUT_VARIABLE(1); + auto gradI = OUTPUT_VARIABLE(0); + + std::vector dimensions = block.getIArguments(); + + if (block.width() > 2) { + auto axesVector = INPUT_VARIABLE(2); + helpers::adjustAxis(input->rankOf(), axesVector, dimensions); + } + + REQUIRE_TRUE(dimensions.size() <= input->rankOf(), 0, + "REDUCE_NORM_MAX_BP OP: the number of dimensions to reduce " + "along must be <= input array rank, but got %i instead", + dimensions.size()); + + for (const auto& item : dimensions) + REQUIRE_TRUE(item >= -input->shapeInfo()[0] && item < input->shapeInfo()[0], + 0, + "REDUCE_NORM_MAX_BP OP: the input dimension to reduce along " + "must be in range [-%i, %i), but got %i instead !", + input->rankOf(), input->rankOf(), item); + + // *** calculations *** // + + *gradI = 0; + + if (gradO->lengthOf() == 1) { + auto indOfAbsMaxElem = + input->indexReduceNumber(sd::indexreduce::IndexAbsoluteMax); + const Nd4jLong ind = indOfAbsMaxElem.t(0); + const int sign = input->e(ind) >= 0 ? 1 : -1; + gradI->p(ind, sign * gradO->e(0)); + } else { + auto indicesArr = + input->applyIndexReduce(sd::indexreduce::IndexAbsoluteMax, dimensions); + helpers::scatterSimple( + block.launchContext(), 6, *gradI, *gradO, indicesArr, + ShapeUtils::evalDimsToExclude( + gradI->rankOf(), dimensions)); // 6 corresponds to copy operation + *gradI *= input->transform(sd::transform::Sign); + } + + return Status::OK(); } DECLARE_SHAPE_FN(reduce_norm_max_bp) { - - auto dimensions = *block.getIArguments(); - if (block.width() > 2) { - auto axesVector = INPUT_VARIABLE(2); - helpers::adjustAxis(INPUT_VARIABLE(0)->rankOf(), axesVector, dimensions); - } - - REQUIRE_TRUE(dimensions.size() <= inputShape->at(0)[0], 0, "REDUCE_NORM_MAX_BP OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); - - for(const auto& item : dimensions) - REQUIRE_TRUE(item >= -inputShape->at(0)[0] && item < inputShape->at(0)[0], 0, "REDUCE_NORM_MAX_BP OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !", inputShape->at(0)[0], inputShape->at(0)[0], item); - - Nd4jLong* outShapeInfo; - COPY_SHAPE(inputShape->at(0), outShapeInfo); - - return SHAPELIST(CONSTANT(outShapeInfo)); + auto dimensions = block.getIArguments(); + if (block.width() > 2) { + auto axesVector = INPUT_VARIABLE(2); + helpers::adjustAxis(INPUT_VARIABLE(0)->rankOf(), axesVector, dimensions); + } + + REQUIRE_TRUE(dimensions.size() <= inputShape->at(0)[0], 0, + "REDUCE_NORM_MAX_BP OP: the number of dimensions to reduce " + "along must be <= input array rank, but got %i instead", + dimensions.size()); + + for (const auto& item : dimensions) + REQUIRE_TRUE(item >= -inputShape->at(0)[0] && item < inputShape->at(0)[0], + 0, + "REDUCE_NORM_MAX_BP OP: the input dimension to reduce along " + "must be in range [-%i, %i), but got %i instead !", + inputShape->at(0)[0], inputShape->at(0)[0], item); + + Nd4jLong* outShapeInfo; + COPY_SHAPE(inputShape->at(0), outShapeInfo); + + return SHAPELIST(CONSTANT(outShapeInfo)); } DECLARE_TYPES(reduce_norm_max_bp) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } - - - #endif -} -} +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/generic/reduce/reduce_prod.cpp b/libnd4j/include/ops/declarable/generic/reduce/reduce_prod.cpp index e873220efa77..68072e639cbf 100644 --- a/libnd4j/include/ops/declarable/generic/reduce/reduce_prod.cpp +++ b/libnd4j/include/ops/declarable/generic/reduce/reduce_prod.cpp @@ -19,8 +19,8 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include #include +#include namespace sd { namespace ops { @@ -28,62 +28,74 @@ namespace ops { ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(reduce_prod, 1, 1, false, 0, 0) { - - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - std::vector dimensions; - if (block.width() > 1) { - auto axesVector = INPUT_VARIABLE(1); - helpers::adjustAxis(input->rankOf(), axesVector, dimensions); - } - else if (block.getIArguments()->size()) - dimensions = *block.getIArguments(); - - REQUIRE_TRUE(dimensions.size() <= input->rankOf(), 0, "REDUCE_PROD OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); - - for(const auto& item : dimensions) - REQUIRE_TRUE(item >= -input->shapeInfo()[0] && item < input->shapeInfo()[0], 0, "REDUCE_PROD OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , input->rankOf(), input->rankOf(), item); - - bool keepDims = false; - if (block.getBArguments()->size()) - keepDims = B_ARG(0); - else if (block.getTArguments()->size()) - keepDims = (bool)T_ARG(0); - - input->reduceAlongDimension(reduce::Prod, *output, dimensions, keepDims); - - return Status::OK(); + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + std::vector dimensions; + if (block.width() > 1) { + auto axesVector = INPUT_VARIABLE(1); + helpers::adjustAxis(input->rankOf(), axesVector, dimensions); + } else if (block.numI()) + dimensions = block.getIArguments(); + + REQUIRE_TRUE(dimensions.size() <= input->rankOf(), 0, + "REDUCE_PROD OP: the number of dimensions to reduce along must " + "be <= input array rank, but got %i instead", + dimensions.size()); + + for (const auto& item : dimensions) + REQUIRE_TRUE(item >= -input->shapeInfo()[0] && item < input->shapeInfo()[0], + 0, + "REDUCE_PROD OP: the input dimension to reduce along must be " + "in range [-%i, %i), but got %i instead !", + input->rankOf(), input->rankOf(), item); + + bool keepDims = false; + if (block.numB()) + keepDims = B_ARG(0); + else if (block.numT()) + keepDims = (bool)T_ARG(0); + + input->reduceAlongDimension(reduce::Prod, *output, dimensions, keepDims); + + return Status::OK(); } DECLARE_SHAPE_FN(reduce_prod) { - - bool keepDims = false; - if (block.getBArguments()->size()) - keepDims = B_ARG(0); - else if (block.getTArguments()->size()) - keepDims = (bool)T_ARG(0); - - std::vector dimensions; - if (block.width() > 1) { - auto axesVector = INPUT_VARIABLE(1); - helpers::adjustAxis(INPUT_VARIABLE(0)->rankOf(), axesVector, dimensions); - } - else if (block.getIArguments()->size()) - dimensions = *block.getIArguments(); - - REQUIRE_TRUE(dimensions.size() <= inputShape->at(0)[0], 0, "REDUCE_PROD OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); - - for(const auto& item : dimensions) - REQUIRE_TRUE(item >= -inputShape->at(0)[0] && item < inputShape->at(0)[0], 0, "REDUCE_PROD OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , inputShape->at(0)[0], inputShape->at(0)[0], item); - - return SHAPELIST(ShapeUtils::evalReduceShapeInfo(shape::order(inputShape->at(0)), dimensions, inputShape->at(0), keepDims, false, block.getWorkspace())); + bool keepDims = false; + if (block.numB()) + keepDims = B_ARG(0); + else if (block.numT()) + keepDims = (bool)T_ARG(0); + + std::vector dimensions; + if (block.width() > 1) { + auto axesVector = INPUT_VARIABLE(1); + helpers::adjustAxis(INPUT_VARIABLE(0)->rankOf(), axesVector, dimensions); + } else if (block.numI()) + dimensions = block.getIArguments(); + + REQUIRE_TRUE(dimensions.size() <= inputShape->at(0)[0], 0, + "REDUCE_PROD OP: the number of dimensions to reduce along must " + "be <= input array rank, but got %i instead", + dimensions.size()); + + for (const auto& item : dimensions) + REQUIRE_TRUE(item >= -inputShape->at(0)[0] && item < inputShape->at(0)[0], + 0, + "REDUCE_PROD OP: the input dimension to reduce along must be " + "in range [-%i, %i), but got %i instead !", + inputShape->at(0)[0], inputShape->at(0)[0], item); + + return SHAPELIST(ShapeUtils::evalReduceShapeInfo( + shape::order(inputShape->at(0)), dimensions, inputShape->at(0), keepDims, + false, block.workspace())); } DECLARE_TYPES(reduce_prod) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } #endif @@ -91,78 +103,94 @@ DECLARE_TYPES(reduce_prod) { ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(reduce_prod_bp, 2, 1, false, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto gradO = INPUT_VARIABLE(1); + auto gradI = OUTPUT_VARIABLE(0); + + if (gradO->lengthOf() == 1) { + gradI->assign(input->reduceNumber(sd::reduce::Prod)); + *gradI /= *input; + *gradI *= gradO->e(0); + } else { + bool keepDims = false; + auto dimensions = block.getIArguments(); - auto input = INPUT_VARIABLE(0); - auto gradO = INPUT_VARIABLE(1); - auto gradI = OUTPUT_VARIABLE(0); - - if (gradO->lengthOf() == 1) { - gradI->assign(input->reduceNumber(sd::reduce::Prod)); - *gradI /= *input; - *gradI *= gradO->e(0); - } - else { - - bool keepDims = false; - auto dimensions = *block.getIArguments(); - - if (block.width() > 2) { - auto axesVector = INPUT_VARIABLE(2); - helpers::adjustAxis(input->rankOf(), axesVector, dimensions); - } - - if (block.getBArguments()->size()) - keepDims = B_ARG(0); - else if (block.getTArguments()->size()) - keepDims = (bool)T_ARG(0); - - REQUIRE_TRUE(dimensions.size() <= input->rankOf(), 0, "REDUCE_NORM1_BP OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); - - for(const auto& item : dimensions) - REQUIRE_TRUE(item >= -input->rankOf() && item < input->rankOf(), 0, "REDUCE_NORM1_BP OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , input->rankOf(), input->rankOf(), item); - - // *** calculations *** // - - auto products = input->reduceAlongDimension(reduce::Prod, dimensions, true); - gradI->applyTrueBroadcast(sd::BroadcastOpsTuple::Assign(), products, *gradI); - *gradI /= *input; - - if(!keepDims) { - auto gradOShapeKeepDims = ShapeUtils::evalReduceShapeInfo(gradO->ordering(), dimensions, *input, true, false, block.getWorkspace()); - *gradI *= gradO->reshape(gradO->ordering(), ShapeUtils::pullShapeFromShapeInfo(gradOShapeKeepDims)); // for example could be something like [a,b] -> [1,a,1,b] - } else - *gradI *= *gradO; + if (block.width() > 2) { + auto axesVector = INPUT_VARIABLE(2); + helpers::adjustAxis(input->rankOf(), axesVector, dimensions); } - return Status::OK(); + if (block.numB()) + keepDims = B_ARG(0); + else if (block.numT()) + keepDims = (bool)T_ARG(0); + + REQUIRE_TRUE(dimensions.size() <= input->rankOf(), 0, + "REDUCE_NORM1_BP OP: the number of dimensions to reduce along " + "must be <= input array rank, but got %i instead", + dimensions.size()); + + for (const auto& item : dimensions) + REQUIRE_TRUE(item >= -input->rankOf() && item < input->rankOf(), 0, + "REDUCE_NORM1_BP OP: the input dimension to reduce along " + "must be in range [-%i, %i), but got %i instead !", + input->rankOf(), input->rankOf(), item); + + // *** calculations *** // + + auto products = input->reduceAlongDimension(reduce::Prod, dimensions, true); + gradI->applyTrueBroadcast(sd::BroadcastOpsTuple::Assign(), products, + *gradI); + *gradI /= *input; + + if (!keepDims) { + auto gradOShapeKeepDims = + ShapeUtils::evalReduceShapeInfo(gradO->ordering(), dimensions, *input, + true, false, block.workspace()); + *gradI *= gradO->reshape( + gradO->ordering(), + ShapeUtils::pullShapeFromShapeInfo( + gradOShapeKeepDims)); // for example could be something like + // [a,b] -> [1,a,1,b] + } else + *gradI *= *gradO; + } + + return Status::OK(); } DECLARE_SHAPE_FN(reduce_prod_bp) { - - auto dimensions = *block.getIArguments(); - if (block.width() > 2) { - auto axesVector = INPUT_VARIABLE(2); - helpers::adjustAxis(INPUT_VARIABLE(0)->rankOf(), axesVector, dimensions); - } - - REQUIRE_TRUE(dimensions.size() <= inputShape->at(0)[0], 0, "REDUCE_PROD_BP OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); - - for(const auto& item : dimensions) - REQUIRE_TRUE(item >= -inputShape->at(0)[0] && item < inputShape->at(0)[0], 0, "REDUCE_PROD_BP OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !", inputShape->at(0)[0], inputShape->at(0)[0], item); - - Nd4jLong* outShapeInfo; - COPY_SHAPE(inputShape->at(0), outShapeInfo); - - return SHAPELIST(CONSTANT(outShapeInfo)); + auto dimensions = block.getIArguments(); + if (block.width() > 2) { + auto axesVector = INPUT_VARIABLE(2); + helpers::adjustAxis(INPUT_VARIABLE(0)->rankOf(), axesVector, dimensions); + } + + REQUIRE_TRUE(dimensions.size() <= inputShape->at(0)[0], 0, + "REDUCE_PROD_BP OP: the number of dimensions to reduce along " + "must be <= input array rank, but got %i instead", + dimensions.size()); + + for (const auto& item : dimensions) + REQUIRE_TRUE(item >= -inputShape->at(0)[0] && item < inputShape->at(0)[0], + 0, + "REDUCE_PROD_BP OP: the input dimension to reduce along must " + "be in range [-%i, %i), but got %i instead !", + inputShape->at(0)[0], inputShape->at(0)[0], item); + + Nd4jLong* outShapeInfo; + COPY_SHAPE(inputShape->at(0), outShapeInfo); + + return SHAPELIST(CONSTANT(outShapeInfo)); } DECLARE_TYPES(reduce_prod_bp) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } #endif -} -} +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/generic/reduce/reduce_sqnorm.cpp b/libnd4j/include/ops/declarable/generic/reduce/reduce_sqnorm.cpp index 22d2c6e1bc4f..e08f2d04a3f6 100644 --- a/libnd4j/include/ops/declarable/generic/reduce/reduce_sqnorm.cpp +++ b/libnd4j/include/ops/declarable/generic/reduce/reduce_sqnorm.cpp @@ -18,8 +18,8 @@ // Created by george@skymind.io on 6/4/2018. // -#include #include +#include namespace sd { namespace ops { @@ -27,63 +27,77 @@ namespace ops { ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(reduce_sqnorm, 1, 1, false, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto gradI = OUTPUT_VARIABLE(0); - auto input = INPUT_VARIABLE(0); - auto gradI = OUTPUT_VARIABLE(0); + bool keepDims = false; - bool keepDims = false; + auto dimensions = block.getIArguments(); - auto dimensions = *block.getIArguments(); + if (block.width() > 1) { + auto axesVector = INPUT_VARIABLE(1); + helpers::adjustAxis(input->rankOf(), axesVector, dimensions); + } - if (block.width() > 1) { - auto axesVector = INPUT_VARIABLE(1); - helpers::adjustAxis(input->rankOf(), axesVector, dimensions); - } + if (block.numB()) + keepDims = B_ARG(0); + else if (block.numT()) + keepDims = (bool)T_ARG(0); - if (block.getBArguments()->size()) - keepDims = B_ARG(0); - else if (block.getTArguments()->size()) - keepDims = (bool)T_ARG(0); + REQUIRE_TRUE(dimensions.size() <= input->rankOf(), 0, + "REDUCE_SQNORM OP: the number of dimensions to reduce along " + "must be <= input array rank, but got %i instead", + dimensions.size()); - REQUIRE_TRUE(dimensions.size() <= input->rankOf(), 0, "REDUCE_SQNORM OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); + for (const auto& item : dimensions) + REQUIRE_TRUE(item >= -input->rankOf() && item < input->rankOf(), 0, + "REDUCE_SQNORM OP: the input dimension to reduce along must " + "be in range [-%i, %i), but got %i instead !", + input->rankOf(), input->rankOf(), item); - for(const auto& item : dimensions) - REQUIRE_TRUE(item >= -input->rankOf() && item < input->rankOf(), 0, "REDUCE_SQNORM OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , input->rankOf(), input->rankOf(), item); + input->reduceAlongDimension(reduce::SquaredNorm, *gradI, dimensions, + keepDims); - input->reduceAlongDimension(reduce::SquaredNorm, *gradI, dimensions, keepDims); - - return Status::OK(); + return Status::OK(); } DECLARE_SHAPE_FN(reduce_sqnorm) { - - auto dimensions = *block.getIArguments(); - bool keepDims = false; - - if (block.width() > 1) { - auto axesVector = INPUT_VARIABLE(1); - helpers::adjustAxis(INPUT_VARIABLE(0)->rankOf(), axesVector, dimensions); - } - - if (block.getBArguments()->size()) - keepDims = B_ARG(0); - else if (block.getTArguments()->size()) - keepDims = (bool)T_ARG(0); - - REQUIRE_TRUE(dimensions.size() <= inputShape->at(0)[0], 0, "REDUCE_SQNORM OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); - - for(const auto& item : dimensions) - REQUIRE_TRUE(item >= -inputShape->at(0)[0] && item < inputShape->at(0)[0], 0, "REDUCE_SQNORM OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , inputShape->at(0)[0], inputShape->at(0)[0], item); - - auto outShapeInfo = ShapeUtils::evalReduceShapeInfo(shape::order(inputShape->at(0)), dimensions, inputShape->at(0), keepDims, false, block.getWorkspace()); - - return SHAPELIST(outShapeInfo); + auto dimensions = block.getIArguments(); + bool keepDims = false; + + if (block.width() > 1) { + auto axesVector = INPUT_VARIABLE(1); + helpers::adjustAxis(INPUT_VARIABLE(0)->rankOf(), axesVector, dimensions); + } + + if (block.numB()) + keepDims = B_ARG(0); + else if (block.numT()) + keepDims = (bool)T_ARG(0); + + REQUIRE_TRUE(dimensions.size() <= inputShape->at(0)[0], 0, + "REDUCE_SQNORM OP: the number of dimensions to reduce along " + "must be <= input array rank, but got %i instead", + dimensions.size()); + + for (const auto& item : dimensions) + REQUIRE_TRUE(item >= -inputShape->at(0)[0] && item < inputShape->at(0)[0], + 0, + "REDUCE_SQNORM OP: the input dimension to reduce along must " + "be in range [-%i, %i), but got %i instead !", + inputShape->at(0)[0], inputShape->at(0)[0], item); + + auto outShapeInfo = ShapeUtils::evalReduceShapeInfo( + shape::order(inputShape->at(0)), dimensions, inputShape->at(0), keepDims, + false, block.workspace()); + + return SHAPELIST(outShapeInfo); } DECLARE_TYPES(reduce_sqnorm) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } #endif @@ -92,74 +106,90 @@ DECLARE_TYPES(reduce_sqnorm) { ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(reduce_sqnorm_bp, 2, 1, false, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto gradO = INPUT_VARIABLE(1); + auto gradI = OUTPUT_VARIABLE(0); - auto input = INPUT_VARIABLE(0); - auto gradO = INPUT_VARIABLE(1); - auto gradI = OUTPUT_VARIABLE(0); + if (gradO->lengthOf() == 1) { + gradI->assign(2 * (*input) * gradO->e(0)); + } else { + bool keepDims = false; + auto dimensions = block.getIArguments(); - if (gradO->lengthOf() == 1) { - gradI->assign( 2 * (*input) * gradO->e(0)); + if (block.width() > 2) { + auto axesVector = INPUT_VARIABLE(2); + helpers::adjustAxis(input->rankOf(), axesVector, dimensions); } - else { - - bool keepDims = false; - auto dimensions = *block.getIArguments(); - - if (block.width() > 2) { - auto axesVector = INPUT_VARIABLE(2); - helpers::adjustAxis(input->rankOf(), axesVector, dimensions); - } - - if (block.getBArguments()->size()) - keepDims = B_ARG(0); - else if (block.getTArguments()->size()) - keepDims = (bool)T_ARG(0); - - REQUIRE_TRUE(dimensions.size() <= input->rankOf(), 0, "REDUCE_SQNORM_BP OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); - - for(const auto& item : dimensions) - REQUIRE_TRUE(item >= -input->rankOf() && item < input->rankOf(), 0, "REDUCE_SQNORM_BP OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , input->rankOf(), input->rankOf(), item); - - // *** calculations *** // - if(!keepDims) { - auto gradOShapeKeepDims = ShapeUtils::evalReduceShapeInfo(gradO->ordering(), dimensions, *input, true, false, block.getWorkspace()); - gradI->assign(2. * (*input) *gradO->reshape(gradO->ordering(), ShapeUtils::pullShapeFromShapeInfo(gradOShapeKeepDims))); // for example could be something like [a,b] -> [1,a,1,b] - } else - gradI->assign(2. * (*input) * *gradO); - } - return Status::OK(); + if (block.numB()) + keepDims = B_ARG(0); + else if (block.numT()) + keepDims = (bool)T_ARG(0); + + REQUIRE_TRUE(dimensions.size() <= input->rankOf(), 0, + "REDUCE_SQNORM_BP OP: the number of dimensions to reduce " + "along must be <= input array rank, but got %i instead", + dimensions.size()); + + for (const auto& item : dimensions) + REQUIRE_TRUE(item >= -input->rankOf() && item < input->rankOf(), 0, + "REDUCE_SQNORM_BP OP: the input dimension to reduce along " + "must be in range [-%i, %i), but got %i instead !", + input->rankOf(), input->rankOf(), item); + + // *** calculations *** // + + if (!keepDims) { + auto gradOShapeKeepDims = + ShapeUtils::evalReduceShapeInfo(gradO->ordering(), dimensions, *input, + true, false, block.workspace()); + gradI->assign( + 2. * (*input) * + gradO->reshape( + gradO->ordering(), + ShapeUtils::pullShapeFromShapeInfo( + gradOShapeKeepDims))); // for example could be something like + // [a,b] -> [1,a,1,b] + } else + gradI->assign(2. * (*input) * *gradO); + } + return Status::OK(); } DECLARE_SHAPE_FN(reduce_sqnorm_bp) { + if (shape::length(inputShape->at(1)) > 1) { + auto dimensions = block.getIArguments(); + if (block.width() > 2) { + auto axesVector = INPUT_VARIABLE(2); + helpers::adjustAxis(INPUT_VARIABLE(0)->rankOf(), axesVector, dimensions); + } - if(shape::length(inputShape->at(1)) > 1) { - - auto dimensions = *block.getIArguments(); - if (block.width() > 2) { - auto axesVector = INPUT_VARIABLE(2); - helpers::adjustAxis(INPUT_VARIABLE(0)->rankOf(), axesVector, dimensions); - } + REQUIRE_TRUE(dimensions.size() <= inputShape->at(0)[0], 0, + "REDUCE_SQNORM_BP OP: the number of dimensions to reduce " + "along must be <= input array rank, but got %i instead", + dimensions.size()); - REQUIRE_TRUE(dimensions.size() <= inputShape->at(0)[0], 0, "REDUCE_SQNORM_BP OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); + for (const auto& item : dimensions) + REQUIRE_TRUE(item >= -inputShape->at(0)[0] && item < inputShape->at(0)[0], + 0, + "REDUCE_SQNORM_BP OP: the input dimension to reduce along " + "must be in range [-%i, %i), but got %i instead !", + inputShape->at(0)[0], inputShape->at(0)[0], item); + } - for(const auto& item : dimensions) - REQUIRE_TRUE(item >= -inputShape->at(0)[0] && item < inputShape->at(0)[0], 0, "REDUCE_SQNORM_BP OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , inputShape->at(0)[0], inputShape->at(0)[0], item); - } + Nd4jLong* gradIshapeInfo(nullptr); + COPY_SHAPE(inputShape->at(0), gradIshapeInfo); - Nd4jLong* gradIshapeInfo(nullptr); - COPY_SHAPE(inputShape->at(0), gradIshapeInfo); - - return SHAPELIST(CONSTANT(gradIshapeInfo)); + return SHAPELIST(CONSTANT(gradIshapeInfo)); } DECLARE_TYPES(reduce_sqnorm_bp) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } #endif -} -} +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/generic/reduce/reduce_sum.cpp b/libnd4j/include/ops/declarable/generic/reduce/reduce_sum.cpp index 0f4a5f467556..5b561d8afbb4 100644 --- a/libnd4j/include/ops/declarable/generic/reduce/reduce_sum.cpp +++ b/libnd4j/include/ops/declarable/generic/reduce/reduce_sum.cpp @@ -28,136 +28,161 @@ namespace ops { ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(reduce_sum, 1, 1, false, 0, 0) { - - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - std::vector dimensions; - if (block.width() > 1) { - auto axesVector = INPUT_VARIABLE(1); - helpers::adjustAxis(input->rankOf(), axesVector, dimensions); - } - else if (block.getIArguments()->size()) - dimensions = *block.getIArguments(); - - REQUIRE_TRUE(dimensions.size() <= input->rankOf(), 0, "REDUCE_SUM OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); - - for(const auto& item : dimensions) - REQUIRE_TRUE(item >= -input->shapeInfo()[0] && item < input->shapeInfo()[0], 0, "REDUCE_SUM OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , input->rankOf(), input->rankOf(), item); - - bool keepDims = false; - if (block.getBArguments()->size()) - keepDims = B_ARG(0); - else if (block.getTArguments()->size()) - keepDims = (bool)T_ARG(0); - - input->reduceAlongDimension(reduce::Sum, *output, dimensions, keepDims); - - return Status::OK(); + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + std::vector dimensions; + if (block.width() > 1) { + auto axesVector = INPUT_VARIABLE(1); + helpers::adjustAxis(input->rankOf(), axesVector, dimensions); + } else if (block.numI()) + dimensions = block.getIArguments(); + + REQUIRE_TRUE(dimensions.size() <= input->rankOf(), 0, + "REDUCE_SUM OP: the number of dimensions to reduce along must " + "be <= input array rank, but got %i instead", + dimensions.size()); + + for (const auto& item : dimensions) + REQUIRE_TRUE(item >= -input->shapeInfo()[0] && item < input->shapeInfo()[0], + 0, + "REDUCE_SUM OP: the input dimension to reduce along must be " + "in range [-%i, %i), but got %i instead !", + input->rankOf(), input->rankOf(), item); + + bool keepDims = false; + if (block.numB()) + keepDims = B_ARG(0); + else if (block.numT()) + keepDims = (bool)T_ARG(0); + + input->reduceAlongDimension(reduce::Sum, *output, dimensions, keepDims); + + return Status::OK(); } DECLARE_SHAPE_FN(reduce_sum) { - - bool keepDims = false; - if (block.getBArguments()->size()) - keepDims = B_ARG(0); - else if (block.getTArguments()->size()) - keepDims = (bool)T_ARG(0); - - std::vector dimensions; - if (block.width() > 1) { - auto axesVector = INPUT_VARIABLE(1); - helpers::adjustAxis(INPUT_VARIABLE(0)->rankOf(), axesVector, dimensions); - } - else if (block.getIArguments()->size()) - dimensions = *block.getIArguments(); - - REQUIRE_TRUE(dimensions.size() <= inputShape->at(0)[0], 0, "REDUCE_SUM OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); - - for(const auto& item : dimensions) - REQUIRE_TRUE(item >= -inputShape->at(0)[0] && item < inputShape->at(0)[0], 0, "REDUCE_SUM OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , inputShape->at(0)[0], inputShape->at(0)[0], item); - - return SHAPELIST(ShapeUtils::evalReduceShapeInfo(shape::order(inputShape->at(0)), dimensions, inputShape->at(0), keepDims, false, block.getWorkspace())); + bool keepDims = false; + if (block.numB()) + keepDims = B_ARG(0); + else if (block.numT()) + keepDims = (bool)T_ARG(0); + + std::vector dimensions; + if (block.width() > 1) { + auto axesVector = INPUT_VARIABLE(1); + helpers::adjustAxis(INPUT_VARIABLE(0)->rankOf(), axesVector, dimensions); + } else if (block.numI()) + dimensions = block.getIArguments(); + + REQUIRE_TRUE(dimensions.size() <= inputShape->at(0)[0], 0, + "REDUCE_SUM OP: the number of dimensions to reduce along must " + "be <= input array rank, but got %i instead", + dimensions.size()); + + for (const auto& item : dimensions) + REQUIRE_TRUE(item >= -inputShape->at(0)[0] && item < inputShape->at(0)[0], + 0, + "REDUCE_SUM OP: the input dimension to reduce along must be " + "in range [-%i, %i), but got %i instead !", + inputShape->at(0)[0], inputShape->at(0)[0], item); + + return SHAPELIST(ShapeUtils::evalReduceShapeInfo( + shape::order(inputShape->at(0)), dimensions, inputShape->at(0), keepDims, + false, block.workspace())); } DECLARE_TYPES(reduce_sum) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setSameMode(true); + getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setSameMode(true); } #endif #if NOT_EXCLUDED(OP_reduce_sum_bp) ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(reduce_sum_bp, 2, 1, false, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto gradO = INPUT_VARIABLE(1); + auto gradI = OUTPUT_VARIABLE(0); - auto input = INPUT_VARIABLE(0); - auto gradO = INPUT_VARIABLE(1); - auto gradI = OUTPUT_VARIABLE(0); - - if (gradO->lengthOf() == 1) { - gradI->assign(gradO->e(0)); - } - else { - - bool keepDims = false; - auto dimensions = *block.getIArguments(); - - if (block.width() > 2) { - auto axesVector = INPUT_VARIABLE(2); - helpers::adjustAxis(input->rankOf(), axesVector, dimensions); - } - - if (block.getBArguments()->size()) - keepDims = B_ARG(0); - else if (block.getTArguments()->size()) - keepDims = (bool)T_ARG(0); - - REQUIRE_TRUE(dimensions.size() <= input->rankOf(), 0, "REDUCE_SUM_BP OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); - - for(const auto& item : dimensions) - REQUIRE_TRUE(item >= -input->rankOf() && item < input->rankOf(), 0, "REDUCE_SUM_BP OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , input->rankOf(), input->rankOf(), item); - - // *** calculations *** // + if (gradO->lengthOf() == 1) { + gradI->assign(gradO->e(0)); + } else { + bool keepDims = false; + auto dimensions = block.getIArguments(); - if(!keepDims) { - auto gradOShapeKeepDims = ShapeUtils::evalReduceShapeInfo(gradO->ordering(), dimensions, *input, true, false, block.getWorkspace()); - auto r = gradO->reshape(gradO->ordering(), ShapeUtils::pullShapeFromShapeInfo(gradOShapeKeepDims)); // for example could be something like [a,b] -> [1,a,1,b] - gradI->applyTrueBroadcast(sd::BroadcastOpsTuple::Assign(), r, *gradI); - } else - gradI->applyTrueBroadcast(sd::BroadcastOpsTuple::Assign(), *gradO, *gradI); + if (block.width() > 2) { + auto axesVector = INPUT_VARIABLE(2); + helpers::adjustAxis(input->rankOf(), axesVector, dimensions); } - return Status::OK(); + if (block.numB()) + keepDims = B_ARG(0); + else if (block.numT()) + keepDims = (bool)T_ARG(0); + + REQUIRE_TRUE(dimensions.size() <= input->rankOf(), 0, + "REDUCE_SUM_BP OP: the number of dimensions to reduce along " + "must be <= input array rank, but got %i instead", + dimensions.size()); + + for (const auto& item : dimensions) + REQUIRE_TRUE(item >= -input->rankOf() && item < input->rankOf(), 0, + "REDUCE_SUM_BP OP: the input dimension to reduce along must " + "be in range [-%i, %i), but got %i instead !", + input->rankOf(), input->rankOf(), item); + + // *** calculations *** // + + if (!keepDims) { + auto gradOShapeKeepDims = + ShapeUtils::evalReduceShapeInfo(gradO->ordering(), dimensions, *input, + true, false, block.workspace()); + auto r = gradO->reshape( + gradO->ordering(), + ShapeUtils::pullShapeFromShapeInfo( + gradOShapeKeepDims)); // for example could be something like + // [a,b] -> [1,a,1,b] + gradI->applyTrueBroadcast(sd::BroadcastOpsTuple::Assign(), r, *gradI); + } else + gradI->applyTrueBroadcast(sd::BroadcastOpsTuple::Assign(), *gradO, + *gradI); + } + + return Status::OK(); } DECLARE_SHAPE_FN(reduce_sum_bp) { - - auto dimensions = *block.getIArguments(); - if (block.width() > 2) { - auto axesVector = INPUT_VARIABLE(2); - helpers::adjustAxis(INPUT_VARIABLE(0)->rankOf(), axesVector, dimensions); - } - - REQUIRE_TRUE(dimensions.size() <= inputShape->at(0)[0], 0, "REDUCE_SUM_BP OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); - - for(const auto& item : dimensions) - REQUIRE_TRUE(item >= -inputShape->at(0)[0] && item < inputShape->at(0)[0], 0, "REDUCE_SUM_BP OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !", inputShape->at(0)[0], inputShape->at(0)[0], item); - - Nd4jLong* outShapeInfo; - COPY_SHAPE(inputShape->at(0), outShapeInfo); - - return SHAPELIST(CONSTANT(outShapeInfo)); + auto dimensions = block.getIArguments(); + if (block.width() > 2) { + auto axesVector = INPUT_VARIABLE(2); + helpers::adjustAxis(INPUT_VARIABLE(0)->rankOf(), axesVector, dimensions); + } + + REQUIRE_TRUE(dimensions.size() <= inputShape->at(0)[0], 0, + "REDUCE_SUM_BP OP: the number of dimensions to reduce along " + "must be <= input array rank, but got %i instead", + dimensions.size()); + + for (const auto& item : dimensions) + REQUIRE_TRUE(item >= -inputShape->at(0)[0] && item < inputShape->at(0)[0], + 0, + "REDUCE_SUM_BP OP: the input dimension to reduce along must " + "be in range [-%i, %i), but got %i instead !", + inputShape->at(0)[0], inputShape->at(0)[0], item); + + Nd4jLong* outShapeInfo; + COPY_SHAPE(inputShape->at(0), outShapeInfo); + + return SHAPELIST(CONSTANT(outShapeInfo)); } DECLARE_TYPES(reduce_sum_bp) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } - #endif -} -} +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/generic/shape/broadcast_to.cpp b/libnd4j/include/ops/declarable/generic/shape/broadcast_to.cpp index 18d10be7b8ff..e729ef783fff 100644 --- a/libnd4j/include/ops/declarable/generic/shape/broadcast_to.cpp +++ b/libnd4j/include/ops/declarable/generic/shape/broadcast_to.cpp @@ -24,63 +24,85 @@ #include namespace sd { -namespace ops { +namespace ops { CUSTOM_OP_IMPL(broadcast_to, 2, 1, false, 0, 0) { - - auto input = INPUT_VARIABLE(0); - auto shape = INPUT_VARIABLE(1); - - auto output = OUTPUT_VARIABLE(0); - - const int inputRank = input->rankOf(); - const int shapeRank = shape->rankOf(); - const Nd4jLong shapeLen = shape->lengthOf(); - - REQUIRE_TRUE(shapeRank <= 1, 0, "BROADCAST_TO op: rank of shape array should be <= 1, bot got %i instead !", shapeRank); - REQUIRE_TRUE(inputRank <= shapeLen, 0, "BROADCAST_TO op: rank of input shape array should be <= length of shape array, bot got %i and %i correspondingly !", inputRank, shapeLen); - - std::vector shapeBuff = shape->getBufferAsVector(); - std::vector outShape(shapeBuff.begin(), shapeBuff.end()); - - for(int i = 1; i <= inputRank; ++i) - REQUIRE_TRUE(input->sizeAt(inputRank-i) == outShape[shapeLen-i] || input->sizeAt(inputRank-i) == 1, 0, "BROADCAST_TO op: shape of input array %s can't be broadcasted to the shape %s !", ShapeUtils::shapeAsString(input).c_str(), ShapeUtils::shapeAsString(outShape).c_str()); - - input->tile(*output); - - return Status::OK(); + auto input = INPUT_VARIABLE(0); + auto shape = INPUT_VARIABLE(1); + + auto output = OUTPUT_VARIABLE(0); + + const int inputRank = input->rankOf(); + const int shapeRank = shape->rankOf(); + const Nd4jLong shapeLen = shape->lengthOf(); + + REQUIRE_TRUE(shapeRank <= 1, 0, + "BROADCAST_TO op: rank of shape array should be <= 1, bot got " + "%i instead !", + shapeRank); + REQUIRE_TRUE(inputRank <= shapeLen, 0, + "BROADCAST_TO op: rank of input shape array should be <= length " + "of shape array, bot got %i and %i correspondingly !", + inputRank, shapeLen); + + std::vector shapeBuff = shape->getBufferAsVector(); + std::vector outShape(shapeBuff.begin(), shapeBuff.end()); + + for (int i = 1; i <= inputRank; ++i) + REQUIRE_TRUE(input->sizeAt(inputRank - i) == outShape[shapeLen - i] || + input->sizeAt(inputRank - i) == 1, + 0, + "BROADCAST_TO op: shape of input array %s can't be " + "broadcasted to the shape %s !", + ShapeUtils::shapeAsString(input).c_str(), + ShapeUtils::shapeAsString(outShape).c_str()); + + input->tile(*output); + + return Status::OK(); } DECLARE_TYPES(broadcast_to) { - getOpDescriptor() - ->setAllowedInputTypes(DataType::ANY) - ->setSameMode(true); + getOpDescriptor()->setAllowedInputTypes(DataType::ANY)->setSameMode(true); } ////////////////////////////////////////////////////////////////////////// DECLARE_SHAPE_FN(broadcast_to) { - - auto inputShapeInfo = inputShape->at(0); - auto shape = INPUT_VARIABLE(1); - - const int inputRank = inputShapeInfo[0]; - const int shapeRank = shape->rankOf(); - const Nd4jLong shapeLen = shape->lengthOf(); - - REQUIRE_TRUE(shapeRank <= 1, 0, "BROADCAST_TO op: rank of input shape array should be <= 1, bit got %i instead !", shapeRank); - REQUIRE_TRUE(inputRank <= shapeLen, 0, "BROADCAST_TO op: rank of input shape array should be <= length of shape array, bot got %i and %i correspondingly !", inputRank, shapeLen); - - std::vector shapeBuff = shape->getBufferAsVector(); - std::vector outShape(shapeBuff.begin(), shapeBuff.end()); - - for(int i = 1; i <= inputRank; ++i) - REQUIRE_TRUE(inputShapeInfo[inputRank+1-i] == outShape[shapeLen-i] || inputShapeInfo[inputRank+1-i] == 1, 0, "BROADCAST_TO op: shape of input array %s can't be broadcasted to the shape %s !", ShapeUtils::shapeAsString(inputShapeInfo).c_str(), ShapeUtils::shapeAsString(outShape).c_str()); - - auto outShapeInfo = ConstantShapeHelper::getInstance().createShapeInfo(ArrayOptions::dataType(inputShapeInfo), shape::order(inputShapeInfo), outShape); - return SHAPELIST(outShapeInfo); + auto inputShapeInfo = inputShape->at(0); + auto shape = INPUT_VARIABLE(1); + + const int inputRank = inputShapeInfo[0]; + const int shapeRank = shape->rankOf(); + const Nd4jLong shapeLen = shape->lengthOf(); + + REQUIRE_TRUE(shapeRank <= 1, 0, + "BROADCAST_TO op: rank of input shape array should be <= 1, bit " + "got %i instead !", + shapeRank); + REQUIRE_TRUE(inputRank <= shapeLen, 0, + "BROADCAST_TO op: rank of input shape array should be <= length " + "of shape array, bot got %i and %i correspondingly !", + inputRank, shapeLen); + + std::vector shapeBuff = shape->getBufferAsVector(); + std::vector outShape(shapeBuff.begin(), shapeBuff.end()); + + for (int i = 1; i <= inputRank; ++i) + REQUIRE_TRUE(inputShapeInfo[inputRank + 1 - i] == outShape[shapeLen - i] || + inputShapeInfo[inputRank + 1 - i] == 1, + 0, + "BROADCAST_TO op: shape of input array %s can't be " + "broadcasted to the shape %s !", + ShapeUtils::shapeAsString(inputShapeInfo).c_str(), + ShapeUtils::shapeAsString(outShape).c_str()); + + auto outShapeInfo = ConstantShapeHelper::getInstance().createShapeInfo( + ArrayOptions::dataType(inputShapeInfo), shape::order(inputShapeInfo), + outShape); + return SHAPELIST(outShapeInfo); } -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/shape/evaluate_reduction_shape.cpp b/libnd4j/include/ops/declarable/generic/shape/evaluate_reduction_shape.cpp index c35a81279d64..33966b0249c8 100644 --- a/libnd4j/include/ops/declarable/generic/shape/evaluate_reduction_shape.cpp +++ b/libnd4j/include/ops/declarable/generic/shape/evaluate_reduction_shape.cpp @@ -24,57 +24,63 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(evaluate_reduction_shape, 2, 1, false, 0, 0) { - auto inputShape = INPUT_VARIABLE(0); - auto axis = INPUT_VARIABLE(1)->asVectorT(); - auto keepDims = block.numB() > 0 ? B_ARG(0) : false; - auto oldFormat = block.numB() > 1 ? B_ARG(1) : false; - auto output = OUTPUT_VARIABLE(0); +namespace ops { +CUSTOM_OP_IMPL(evaluate_reduction_shape, 2, 1, false, 0, 0) { + auto inputShape = INPUT_VARIABLE(0); + auto axis = INPUT_VARIABLE(1)->asVectorT(); + auto keepDims = block.numB() > 0 ? B_ARG(0) : false; + auto oldFormat = block.numB() > 1 ? B_ARG(1) : false; + auto output = OUTPUT_VARIABLE(0); - auto shape = inputShape->asVectorT(); + auto shape = inputShape->asVectorT(); - auto tempShapeInfo = ConstantShapeHelper::getInstance().createShapeInfo(sd::DataType::INT64, 'c', shape); - auto tempReductionShapeInfo = ShapeUtils::evalReduceShapeInfo('c', axis, tempShapeInfo, keepDims, oldFormat, block.workspace()); + auto tempShapeInfo = ConstantShapeHelper::getInstance().createShapeInfo( + sd::DataType::INT64, 'c', shape); + auto tempReductionShapeInfo = ShapeUtils::evalReduceShapeInfo( + 'c', axis, tempShapeInfo, keepDims, oldFormat, block.workspace()); - REQUIRE_TRUE(output->lengthOf() == shape::rank(tempReductionShapeInfo), 0, "evaluate_reduction_shape: output length should be %i, but got %i instead", shape::rank(tempReductionShapeInfo), output->lengthOf()); + REQUIRE_TRUE(output->lengthOf() == shape::rank(tempReductionShapeInfo), 0, + "evaluate_reduction_shape: output length should be %i, but got " + "%i instead", + shape::rank(tempReductionShapeInfo), output->lengthOf()); - for (int e = 0; e < shape::rank(tempReductionShapeInfo); e++) - output->p(e, tempReductionShapeInfo[e+1]); + for (int e = 0; e < shape::rank(tempReductionShapeInfo); e++) + output->p(e, tempReductionShapeInfo[e + 1]); - return Status::OK(); - } - - DECLARE_TYPES(evaluate_reduction_shape) { - getOpDescriptor() - ->setAllowedInputTypes(0, {ALL_INTS}) - ->setAllowedInputTypes(1, {ALL_INTS}) - ->setAllowedOutputTypes(0, sd::DataType::INT64); - } + return Status::OK(); +} - DECLARE_SHAPE_FN(evaluate_reduction_shape) { - auto input = INPUT_VARIABLE(0); - auto axis = INPUT_VARIABLE(1)->asVectorT(); +DECLARE_TYPES(evaluate_reduction_shape) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_INTS}) + ->setAllowedInputTypes(1, {ALL_INTS}) + ->setAllowedOutputTypes(0, sd::DataType::INT64); +} - auto keepDims = block.numB() > 0 ? B_ARG(0) : false; - auto oldFormat = block.numB() > 1 ? B_ARG(1) : false; +DECLARE_SHAPE_FN(evaluate_reduction_shape) { + auto input = INPUT_VARIABLE(0); + auto axis = INPUT_VARIABLE(1)->asVectorT(); - Nd4jLong length = input->lengthOf(); + auto keepDims = block.numB() > 0 ? B_ARG(0) : false; + auto oldFormat = block.numB() > 1 ? B_ARG(1) : false; - if (keepDims) { - if (oldFormat) { - // for oldFormat we can't go below rank 2 - length = sd::math::nd4j_max(2, length); - } - } else { - length -= axis.size(); - if (oldFormat) { - length = sd::math::nd4j_max(2, length); - } - } + Nd4jLong length = input->lengthOf(); - return SHAPELIST(ConstantShapeHelper::getInstance().vectorShapeInfo(length, sd::DataType::INT64)); - } + if (keepDims) { + if (oldFormat) { + // for oldFormat we can't go below rank 2 + length = sd::math::nd4j_max(2, length); } + } else { + length -= axis.size(); + if (oldFormat) { + length = sd::math::nd4j_max(2, length); + } + } + + return SHAPELIST(ConstantShapeHelper::getInstance().vectorShapeInfo( + length, sd::DataType::INT64)); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/shape/expand_dims.cpp b/libnd4j/include/ops/declarable/generic/shape/expand_dims.cpp index df31f5109786..1c586c6b38e7 100644 --- a/libnd4j/include/ops/declarable/generic/shape/expand_dims.cpp +++ b/libnd4j/include/ops/declarable/generic/shape/expand_dims.cpp @@ -24,80 +24,84 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(expand_dims, 1, 1, false, 0, -2) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - if (input->isScalar()) { - output->assign(input); - return Status::OK(); - } - - Nd4jLong axis = block.numI() > 0 ? INT_ARG(0) : INPUT_VARIABLE(1)->e(0); - - if (axis < 0) - axis += input->rankOf() + 1; - - REQUIRE_TRUE(axis >= 0 && axis <= input->rankOf()+1, 0, "ExpandDims: axis should be in range of 0...%i in this case, but got %i instead", input->rankOf() + 1, axis); - - std::vector shape(input->rankOf()); - - for(int e = 0; e < input->rankOf(); e++) - shape[input->sizeAt(e)]; - - shape.insert(shape.begin() + axis, 1); - - if (input->ews() == 1 && output->ews() == 1 && input->ordering() == output->ordering()) { - output->dataBuffer()->copyBufferFrom(*input->dataBuffer().get(), output->lengthOf() * DataTypeUtils::sizeOfElement(output->dataType()), 0, input->bufferOffset()); - } else { - auto tmp = input->reshape(input->ordering(), shape); - output->assign(tmp); - } - return Status::OK(); - } - - DECLARE_TYPES(expand_dims) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setSameMode(true); - } +namespace ops { +CUSTOM_OP_IMPL(expand_dims, 1, 1, false, 0, -2) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + if (input->isScalar()) { + output->assign(input); + return Status::OK(); + } + + Nd4jLong axis = block.numI() > 0 ? INT_ARG(0) : INPUT_VARIABLE(1)->e(0); + + if (axis < 0) axis += input->rankOf() + 1; + + REQUIRE_TRUE(axis >= 0 && axis <= input->rankOf() + 1, 0, + "ExpandDims: axis should be in range of 0...%i in this case, " + "but got %i instead", + input->rankOf() + 1, axis); + + std::vector shape(input->rankOf()); + + for (int e = 0; e < input->rankOf(); e++) shape[input->sizeAt(e)]; + + shape.insert(shape.begin() + axis, 1); + + if (input->ews() == 1 && output->ews() == 1 && + input->ordering() == output->ordering()) { + output->dataBuffer()->copyBufferFrom( + *input->dataBuffer().get(), + output->lengthOf() * DataTypeUtils::sizeOfElement(output->dataType()), + 0, input->bufferOffset()); + } else { + auto tmp = input->reshape(input->ordering(), shape); + output->assign(tmp); + } + return Status::OK(); +} - DECLARE_SHAPE_FN(expand_dims) { - auto inShape = inputShape->at(0); +DECLARE_TYPES(expand_dims) { + getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setSameMode(true); +} - // 0D scalar edge case - if (shape::rank(inShape) == 0) { +DECLARE_SHAPE_FN(expand_dims) { + auto inShape = inputShape->at(0); - Nd4jLong x = 1; - auto newShape = ConstantShapeHelper::getInstance().createShapeInfo(ArrayOptions::dataType(inShape), 'c', 1, &x); - return SHAPELIST(newShape); - } + // 0D scalar edge case + if (shape::rank(inShape) == 0) { + Nd4jLong x = 1; + auto newShape = ConstantShapeHelper::getInstance().createShapeInfo( + ArrayOptions::dataType(inShape), 'c', 1, &x); + return SHAPELIST(newShape); + } - // FIXME: temp workaround for TF - if (shape::isScalar(inShape)) { - auto newShape = ConstantShapeHelper::getInstance().createShapeInfo(ArrayOptions::dataType(inShape), 'c', 2, shape::shapeOf(inShape)); - return SHAPELIST(newShape); - } + // FIXME: temp workaround for TF + if (shape::isScalar(inShape)) { + auto newShape = ConstantShapeHelper::getInstance().createShapeInfo( + ArrayOptions::dataType(inShape), 'c', 2, shape::shapeOf(inShape)); + return SHAPELIST(newShape); + } - auto x_rank = shape::rank(inShape); - char order = shape::order(inShape); + auto x_rank = shape::rank(inShape); + char order = shape::order(inShape); - Nd4jLong axis = block.numI() > 0 ? INT_ARG(0) : INPUT_VARIABLE(1)->e(0); + Nd4jLong axis = block.numI() > 0 ? INT_ARG(0) : INPUT_VARIABLE(1)->e(0); - if (axis < 0) - axis += x_rank + 1; + if (axis < 0) axis += x_rank + 1; - std::vector shape; - for(int e = 0; e < x_rank; e++) - shape.emplace_back(shape::shapeOf(inShape)[e]); + std::vector shape; + for (int e = 0; e < x_rank; e++) + shape.emplace_back(shape::shapeOf(inShape)[e]); - shape.insert(shape.begin() + axis, 1); + shape.insert(shape.begin() + axis, 1); - auto newShape = ConstantShapeHelper::getInstance().createShapeInfo(ArrayOptions::dataType(inShape), order, shape); - return SHAPELIST(newShape); - } - } + auto newShape = ConstantShapeHelper::getInstance().createShapeInfo( + ArrayOptions::dataType(inShape), order, shape); + return SHAPELIST(newShape); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/shape/flatten.cpp b/libnd4j/include/ops/declarable/generic/shape/flatten.cpp index 8327ca1a1750..3cb37cf7b38d 100644 --- a/libnd4j/include/ops/declarable/generic/shape/flatten.cpp +++ b/libnd4j/include/ops/declarable/generic/shape/flatten.cpp @@ -22,47 +22,54 @@ #if NOT_EXCLUDED(OP_) #include -#include +#include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(flatten, -1, 1, false, 0, 1) { - auto output = OUTPUT_VARIABLE(0); - auto zType = output->dataType(); - auto xType = INPUT_VARIABLE(0)->dataType(); +namespace ops { +CUSTOM_OP_IMPL(flatten, -1, 1, false, 0, 1) { + auto output = OUTPUT_VARIABLE(0); + auto zType = output->dataType(); + auto xType = INPUT_VARIABLE(0)->dataType(); - REQUIRE_TRUE(xType == zType, 0, "Flatten: output array must have same data type as input arrays"); - std::vector arrays(block.width()); - for (int e = 0; e < block.width(); e++) { - auto input = INPUT_VARIABLE(e); + REQUIRE_TRUE( + xType == zType, 0, + "Flatten: output array must have same data type as input arrays"); + std::vector arrays(block.width()); + for (int e = 0; e < block.width(); e++) { + auto input = INPUT_VARIABLE(e); - REQUIRE_TRUE(xType == input->dataType(), 0, "Flatten: all input arrays must have the same data type"); + REQUIRE_TRUE(xType == input->dataType(), 0, + "Flatten: all input arrays must have the same data type"); - arrays[e] = input; - } + arrays[e] = input; + } - char order = (char) INT_ARG(0); - helpers::flatten(block.launchContext(), arrays, output, order); + char order = (char)INT_ARG(0); + helpers::flatten(block.launchContext(), arrays, output, order); - return Status::OK(); - } + return Status::OK(); +} - DECLARE_TYPES(flatten) { - getOpDescriptor()->setAllowedInputTypes({ALL_INTS, ALL_FLOATS, sd::DataType::BOOL}); - getOpDescriptor()->setAllowedOutputTypes(0, {ALL_FLOATS, ALL_INTS, sd::DataType::BOOL}); - } +DECLARE_TYPES(flatten) { + getOpDescriptor()->setAllowedInputTypes( + {ALL_INTS, ALL_FLOATS, sd::DataType::BOOL}); + getOpDescriptor()->setAllowedOutputTypes( + 0, {ALL_FLOATS, ALL_INTS, sd::DataType::BOOL}); +} - DECLARE_SHAPE_FN(flatten) { - Nd4jLong length = 0; - sd::DataType dtype = ArrayOptions::dataType(inputShape->at(0)); - for (int e = 0; e < inputShape->size(); e++) { - length += shape::length(inputShape->at(e)); - REQUIRE_TRUE(dtype == ArrayOptions::dataType(inputShape->at(e)), 0, "Flatten: all input arrays must have the same datatype"); - } +DECLARE_SHAPE_FN(flatten) { + Nd4jLong length = 0; + sd::DataType dtype = ArrayOptions::dataType(inputShape->at(0)); + for (int e = 0; e < inputShape->size(); e++) { + length += shape::length(inputShape->at(e)); + REQUIRE_TRUE(dtype == ArrayOptions::dataType(inputShape->at(e)), 0, + "Flatten: all input arrays must have the same datatype"); + } - return SHAPELIST(ConstantShapeHelper::getInstance().vectorShapeInfo(length, dtype)); - } - } + return SHAPELIST( + ConstantShapeHelper::getInstance().vectorShapeInfo(length, dtype)); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/shape/order.cpp b/libnd4j/include/ops/declarable/generic/shape/order.cpp index 2d7e0994c80a..6fdade2058f2 100644 --- a/libnd4j/include/ops/declarable/generic/shape/order.cpp +++ b/libnd4j/include/ops/declarable/generic/shape/order.cpp @@ -24,31 +24,33 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(order, 1, 1, false, 0, 1) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); +namespace ops { +CUSTOM_OP_IMPL(order, 1, 1, false, 0, 1) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - output->assign(input); + output->assign(input); - return Status::OK(); - } + return Status::OK(); +} - DECLARE_TYPES(order) { - getOpDescriptor() - ->setAllowedInputTypes(0, sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_INTS}); - } +DECLARE_TYPES(order) { + getOpDescriptor() + ->setAllowedInputTypes(0, sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_INTS}); +} - DECLARE_SHAPE_FN(order) { - auto input = inputShape->at(0); +DECLARE_SHAPE_FN(order) { + auto input = inputShape->at(0); - auto isFOrder = INT_ARG(0) == 1; + auto isFOrder = INT_ARG(0) == 1; - auto newShape = ConstantShapeHelper::getInstance().createShapeInfo(ArrayOptions::dataType(input), isFOrder ? 'f' : 'c', shape::rank(input), shape::shapeOf(input)); - return SHAPELIST(newShape); - } - } + auto newShape = ConstantShapeHelper::getInstance().createShapeInfo( + ArrayOptions::dataType(input), isFOrder ? 'f' : 'c', shape::rank(input), + shape::shapeOf(input)); + return SHAPELIST(newShape); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/shape/permute.cpp b/libnd4j/include/ops/declarable/generic/shape/permute.cpp index f612aec92755..76b6de592289 100644 --- a/libnd4j/include/ops/declarable/generic/shape/permute.cpp +++ b/libnd4j/include/ops/declarable/generic/shape/permute.cpp @@ -22,60 +22,65 @@ #include #if NOT_EXCLUDED(OP_permute) -#include #include +#include namespace sd { - namespace ops { +namespace ops { ////////////////////////////////////////////////////////////////////////// // here iArgs is int vector of ordered set of dimensions to be permuted CUSTOM_OP_IMPL(permute, 1, 1, true, 0, -2) { + auto x = INPUT_VARIABLE(0); + auto z = OUTPUT_VARIABLE(0); - auto x = INPUT_VARIABLE(0); - auto z = OUTPUT_VARIABLE(0); + if (x->isEmpty()) { + REQUIRE_TRUE(z->isEmpty(), 0, + "PERMUTE OP: when input is empty, output must also be empty"); + return Status::OK(); // No op + } - if (x->isEmpty()) { - REQUIRE_TRUE(z->isEmpty(), 0, "PERMUTE OP: when input is empty, output must also be empty"); - return Status::OK(); //No op - } - - if (block.width() == 1 && block.getIArguments()->size() == 0) { - z->assign(x->transpose()); - return Status::OK(); - } + if (block.width() == 1 && block.numI() == 0) { + z->assign(x->transpose()); + return Status::OK(); + } - std::vector permutationVector = block.width() > 1 ? INPUT_VARIABLE(1)->asVectorT() : *block.getIArguments(); + std::vector permutationVector = block.width() > 1 + ? INPUT_VARIABLE(1)->asVectorT() + : block.getIArguments(); - z->assign(x->permute(permutationVector)); + z->assign(x->permute(permutationVector)); - return Status::OK(); + return Status::OK(); } ////////////////////////////////////////////////////////////////////////// DECLARE_TYPES(permute) { - getOpDescriptor() - ->setAllowedInputTypes(0, sd::DataType::ANY) - ->setAllowedInputTypes(1, {ALL_INTS}) - ->setSameMode(true); + getOpDescriptor() + ->setAllowedInputTypes(0, sd::DataType::ANY) + ->setAllowedInputTypes(1, {ALL_INTS}) + ->setSameMode(true); } ////////////////////////////////////////////////////////////////////////// DECLARE_SHAPE_FN(permute) { + auto x = INPUT_VARIABLE(0); - auto x = INPUT_VARIABLE(0); - - if (block.width() == 1 && block.getIArguments()->size() == 0) - return SHAPELIST(ShapeUtils::evalTranspShapeInfo(*x, block.workspace(), true)); + if (block.width() == 1 && block.numI() == 0) + return SHAPELIST( + ShapeUtils::evalTranspShapeInfo(*x, block.workspace(), true)); - std::vector permutationVector = block.width() > 1 ? INPUT_VARIABLE(1)->asVectorT() : *block.getIArguments(); + std::vector permutationVector = block.width() > 1 + ? INPUT_VARIABLE(1)->asVectorT() + : block.getIArguments(); - auto outputShapeInfo = ShapeUtils::evalPermShapeInfo(permutationVector.data(), x->rankOf(), *x, block.workspace(), true); + auto outputShapeInfo = ShapeUtils::evalPermShapeInfo( + permutationVector.data(), x->rankOf(), *x, block.workspace(), true); - return SHAPELIST(outputShapeInfo); + return SHAPELIST(outputShapeInfo); } -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/shape/rank.cpp b/libnd4j/include/ops/declarable/generic/shape/rank.cpp index d12e152390d3..58ded154c9bb 100644 --- a/libnd4j/include/ops/declarable/generic/shape/rank.cpp +++ b/libnd4j/include/ops/declarable/generic/shape/rank.cpp @@ -24,30 +24,30 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(rank, 1, 1, false, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - REQUIRE_TRUE(output->isScalar(), 0, "Rank output should be scalar"); - - output->p(0, input->rankOf()); - output->syncToDevice(); - - return Status::OK(); - } - DECLARE_SHAPE_FN(rank) { - return SHAPELIST(ConstantShapeHelper::getInstance().scalarShapeInfo(sd::DataType::INT32)); - } - - - DECLARE_TYPES(rank) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS}) - ->allowOverride(true); - } - } +namespace ops { +CUSTOM_OP_IMPL(rank, 1, 1, false, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + REQUIRE_TRUE(output->isScalar(), 0, "Rank output should be scalar"); + + output->p(0, input->rankOf()); + output->syncToDevice(); + + return Status::OK(); +} +DECLARE_SHAPE_FN(rank) { + return SHAPELIST( + ConstantShapeHelper::getInstance().scalarShapeInfo(sd::DataType::INT32)); +} + +DECLARE_TYPES(rank) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS}) + ->allowOverride(true); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/shape/reshape.cpp b/libnd4j/include/ops/declarable/generic/shape/reshape.cpp index 38bae587ef8a..174652b71523 100644 --- a/libnd4j/include/ops/declarable/generic/shape/reshape.cpp +++ b/libnd4j/include/ops/declarable/generic/shape/reshape.cpp @@ -24,147 +24,150 @@ #include namespace sd { -namespace ops { +namespace ops { ////////////////////////////////////////////////////////////////////////// // here iArgs is a vector with (optional) negative of order as first element: // ({-order, dim1, dim2, dim3, ...}) CUSTOM_OP_IMPL(reshape, 1, 1, false, 0, -2) { + auto x = INPUT_VARIABLE(0); + auto z = OUTPUT_VARIABLE(0); - auto x = INPUT_VARIABLE(0); - auto z = OUTPUT_VARIABLE(0); + // Special case: empty.reshape() -> return empty + if (x->isEmpty()) { + REQUIRE_TRUE(z->isEmpty(), 0, + "Reshape: when input is empty, output must also be empty"); + return Status::OK(); // No op + } - //Special case: empty.reshape() -> return empty - if (x->isEmpty()) { - REQUIRE_TRUE(z->isEmpty(), 0, "Reshape: when input is empty, output must also be empty"); - return Status::OK(); //No op - } + REQUIRE_TRUE(x->lengthOf() == z->lengthOf(), 0, + "Reshape: lengths before and after reshape should match, but " + "got %i vs %i", + x->lengthOf(), z->lengthOf()); - REQUIRE_TRUE(x->lengthOf() == z->lengthOf(), 0, "Reshape: lengths before and after reshape should match, but got %i vs %i", x->lengthOf(), z->lengthOf()); + if (Environment::getInstance().isDebugAndVerbose()) + nd4j_printv("Reshape: new shape", z->getShapeAsVector()); - if (Environment::getInstance().isDebugAndVerbose()) - nd4j_printv("Reshape: new shape", z->getShapeAsVector()); + z->assign(x->reshape(z->ordering(), z->getShapeAsVector())); - z->assign(x->reshape(z->ordering(), z->getShapeAsVector())); - - return Status::OK(); + return Status::OK(); } - DECLARE_TYPES(reshape) { - getOpDescriptor() - ->setAllowedInputTypes(0, sd::DataType::ANY) - ->setAllowedInputTypes(1, {ALL_INTS}) - ->setSameMode(true); + getOpDescriptor() + ->setAllowedInputTypes(0, sd::DataType::ANY) + ->setAllowedInputTypes(1, {ALL_INTS}) + ->setSameMode(true); } DECLARE_SHAPE_FN(reshape) { - - const auto x = INPUT_VARIABLE(0); - - std::vector reshapeArgs; - std::vector shapeNew; - char orderNew = 'c'; - - if (block.width() == 1) { - reshapeArgs = *block.getIArguments(); - if(!reshapeArgs.empty()) { - orderNew = (char) -reshapeArgs[0]; - if(orderNew == 'c' || orderNew == 'f') - reshapeArgs.erase(reshapeArgs.begin()); // remove first element being order in this case - } + const auto x = INPUT_VARIABLE(0); + + std::vector reshapeArgs; + std::vector shapeNew; + char orderNew = 'c'; + + if (block.width() == 1) { + reshapeArgs = block.getIArguments(); + if (!reshapeArgs.empty()) { + orderNew = (char)-reshapeArgs[0]; + if (orderNew == 'c' || orderNew == 'f') + reshapeArgs.erase( + reshapeArgs + .begin()); // remove first element being order in this case } - else { - reshapeArgs = INPUT_VARIABLE(1)->getBufferAsVector(); - orderNew = block.numI() > 0 ? (char) -INT_ARG(0) : 'c'; + } else { + reshapeArgs = INPUT_VARIABLE(1)->getBufferAsVector(); + orderNew = block.numI() > 0 ? (char)-INT_ARG(0) : 'c'; + } + + REQUIRE_TRUE(!reshapeArgs.empty() || x->lengthOf() == 1, 0, + "Reshape buffer should have at least 1 dimension !"); + + // Nd4jLong xLen = x->lengthOf(); + // if(x->isEmpty()) { + // xLen = 1; + // for (uint i = 0; i < x->rankOf(); ++i) // + // take into account possible empty shapes + // if(x->sizeAt(i) != 0) + // xLen *= x->sizeAt(i); + // } + + // for (uint i = 0; i < reshapeArgs.size(); ++i) { + + // if (reshapeArgs[i] == -1) { + + // uint shapeLength = 1, numOfZeros = 0; + + // for(uint j = 0; j < i; ++j) + // if(reshapeArgs[j] != 0) + // shapeLength *= reshapeArgs[j]; + // else + // ++numOfZeros; + + // for(uint j = i + 1; j < reshapeArgs.size(); ++j) { + // REQUIRE_TRUE(reshapeArgs[j] != -1, 0, "Reshape : Only one + // unknown dimension (-1) is allowed."); if(reshapeArgs[j] != 0) + // shapeLength *= reshapeArgs[j]; + // else + // ++numOfZeros; + // } + + // const auto dim = xLen / shapeLength; + + // if(x->isEmpty() && (1 == dim || 0 == numOfZeros)) + // shapeNew.push_back(0); + // else + // shapeNew.push_back(dim); + // } + // else + // shapeNew.push_back(reshapeArgs[i]); + // } + + Nd4jLong newShapeLen = 1; + int pos = -1; + bool newShapeEmpty = false; + + for (int i = 0; i < reshapeArgs.size(); ++i) { + const int dim = reshapeArgs[i]; + + if (dim == -1) { + REQUIRE_TRUE(pos == -1, 0, + "Reshape : Only one unknown dimension (-1) is allowed."); + pos = i; + shapeNew.push_back(1); + } else if (dim == 0) { + shapeNew.push_back(0); + newShapeEmpty = true; + } else { + shapeNew.push_back(dim); + newShapeLen *= dim; } + } - REQUIRE_TRUE(!reshapeArgs.empty() || x->lengthOf() == 1, 0, "Reshape buffer should have at least 1 dimension !"); - - // Nd4jLong xLen = x->lengthOf(); - // if(x->isEmpty()) { - // xLen = 1; - // for (uint i = 0; i < x->rankOf(); ++i) // take into account possible empty shapes - // if(x->sizeAt(i) != 0) - // xLen *= x->sizeAt(i); - // } - - // for (uint i = 0; i < reshapeArgs.size(); ++i) { - - // if (reshapeArgs[i] == -1) { - - // uint shapeLength = 1, numOfZeros = 0; - - // for(uint j = 0; j < i; ++j) - // if(reshapeArgs[j] != 0) - // shapeLength *= reshapeArgs[j]; - // else - // ++numOfZeros; - - // for(uint j = i + 1; j < reshapeArgs.size(); ++j) { - // REQUIRE_TRUE(reshapeArgs[j] != -1, 0, "Reshape : Only one unknown dimension (-1) is allowed."); - // if(reshapeArgs[j] != 0) - // shapeLength *= reshapeArgs[j]; - // else - // ++numOfZeros; - // } - - // const auto dim = xLen / shapeLength; - - // if(x->isEmpty() && (1 == dim || 0 == numOfZeros)) - // shapeNew.push_back(0); - // else - // shapeNew.push_back(dim); - // } - // else - // shapeNew.push_back(reshapeArgs[i]); - // } - - Nd4jLong newShapeLen = 1; - int pos = -1; - bool newShapeEmpty = false; - - for (int i = 0; i < reshapeArgs.size(); ++i) { - - const int dim = reshapeArgs[i]; - - if (dim == -1) { - REQUIRE_TRUE(pos == -1, 0, "Reshape : Only one unknown dimension (-1) is allowed."); - pos = i; - shapeNew.push_back(1); - } - else if (dim == 0) { - shapeNew.push_back(0); - newShapeEmpty = true; - } - else { - shapeNew.push_back(dim); - newShapeLen *= dim; - } + if (pos != -1) { + Nd4jLong xLen = x->lengthOf(); + if (x->isEmpty()) { + xLen = 1; + for (uint i = 0; i < x->rankOf(); + ++i) // take into account possible empty shapes + if (x->sizeAt(i) > 0 || !newShapeEmpty) xLen *= x->sizeAt(i); } - if (pos != -1) { + shapeNew[pos] = xLen / newShapeLen; + } - Nd4jLong xLen = x->lengthOf(); - if(x->isEmpty()) { - xLen = 1; - for (uint i = 0; i < x->rankOf(); ++i) // take into account possible empty shapes - if(x->sizeAt(i) > 0 || !newShapeEmpty) - xLen *= x->sizeAt(i); - } + auto len = shape::prodLong(shapeNew.data(), shapeNew.size()); + REQUIRE_TRUE(x->lengthOf() == len, 0, + "Reshape: lengths before and after reshape should match, but " + "got %i vs %i", + x->lengthOf(), len); - shapeNew[pos] = xLen / newShapeLen; - } - - auto len = shape::prodLong(shapeNew.data(), shapeNew.size()); - REQUIRE_TRUE(x->lengthOf() == len, 0, "Reshape: lengths before and after reshape should match, but got %i vs %i", x->lengthOf(), len); - - return SHAPELIST(ConstantShapeHelper::getInstance().createShapeInfo(x->dataType(), orderNew, shapeNew)); + return SHAPELIST(ConstantShapeHelper::getInstance().createShapeInfo( + x->dataType(), orderNew, shapeNew)); } - - -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/shape/reshape_as.cpp b/libnd4j/include/ops/declarable/generic/shape/reshape_as.cpp index 9a8dc00c24a1..4c60649183a4 100644 --- a/libnd4j/include/ops/declarable/generic/shape/reshape_as.cpp +++ b/libnd4j/include/ops/declarable/generic/shape/reshape_as.cpp @@ -24,38 +24,32 @@ #include namespace sd { - namespace ops { +namespace ops { +////////////////////////////////////////////////////////////////////////// +CUSTOM_OP_IMPL(reshapeas, 2, 1, false, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); - ////////////////////////////////////////////////////////////////////////// - CUSTOM_OP_IMPL(reshapeas, 2, 1, false, 0, 0) { + auto z = OUTPUT_VARIABLE(0); - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); + // FIXME: add validation here? + auto tmp = x->reshape(y->ordering(), y->getShapeAsVector()); + z->assign(tmp); - auto z = OUTPUT_VARIABLE(0); - - if (x->reshapei(y->ordering(), y->getShapeAsVector())) { - - z->assign(x); - return Status::OK(); - } - - return ND4J_STATUS_BAD_INPUT; - } - DECLARE_SYN(reshape_as, reshapeas); - - DECLARE_SHAPE_FN(reshapeas) { - - return SHAPELIST(ShapeBuilders::copyShapeInfo(INPUT_VARIABLE(1)->shapeInfo(), false, block.workspace())); - } + return Status::OK(); +} +DECLARE_SYN(reshape_as, reshapeas); - DECLARE_TYPES(reshapeas) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setSameMode(true); - } +DECLARE_SHAPE_FN(reshapeas) { + return SHAPELIST(CONSTANT(ShapeBuilders::copyShapeInfo( + INPUT_VARIABLE(1)->shapeInfo(), false, block.workspace()))); } + +DECLARE_TYPES(reshapeas) { + getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setSameMode(true); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/shape/shape.cpp b/libnd4j/include/ops/declarable/generic/shape/shape.cpp index 098825df386c..4a8dfd606a77 100644 --- a/libnd4j/include/ops/declarable/generic/shape/shape.cpp +++ b/libnd4j/include/ops/declarable/generic/shape/shape.cpp @@ -24,37 +24,36 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(shape_of, 1, 1, false, 0, 0) { - auto x = INPUT_VARIABLE(0); - auto z = OUTPUT_VARIABLE(0); +namespace ops { +CUSTOM_OP_IMPL(shape_of, 1, 1, false, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto z = OUTPUT_VARIABLE(0); - for (int e = 0; e < x->rankOf(); e++) - z->p(e, x->sizeAt(e)); + for (int e = 0; e < x->rankOf(); e++) z->p(e, x->sizeAt(e)); - STORE_RESULT(z); + STORE_RESULT(z); - return Status::OK(); - }; - DECLARE_SYN(shape, shape_of); + return Status::OK(); +}; +DECLARE_SYN(shape, shape_of); - DECLARE_SHAPE_FN(shape_of) { - auto inShape = inputShape->at(0); +DECLARE_SHAPE_FN(shape_of) { + auto inShape = inputShape->at(0); - // LONG by default - auto dtype = DataType::INT64; - if (block.numI() > 0) - dtype = DataTypeUtils::fromInt(INT_ARG(0)); + // LONG by default + auto dtype = DataType::INT64; + if (block.numI() > 0) dtype = DataTypeUtils::fromInt(INT_ARG(0)); - return SHAPELIST(ConstantShapeHelper::getInstance().vectorShapeInfo(shape::rank(inShape), dtype)); - }; + return SHAPELIST(ConstantShapeHelper::getInstance().vectorShapeInfo( + shape::rank(inShape), dtype)); +}; - DECLARE_TYPES(shape_of) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_INTS}); - } - } +DECLARE_TYPES(shape_of) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_INTS}); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/shape/shapes.cpp b/libnd4j/include/ops/declarable/generic/shape/shapes.cpp index 3f5428122729..f412ba55ceae 100644 --- a/libnd4j/include/ops/declarable/generic/shape/shapes.cpp +++ b/libnd4j/include/ops/declarable/generic/shape/shapes.cpp @@ -24,37 +24,37 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(shapes_of, -1, -1, false, 0, 0) { - for (int e = 0; e < block.width(); e++) { - auto x = INPUT_VARIABLE(e); - auto z = OUTPUT_VARIABLE(e); - - for (int i = 0; i < x->rankOf(); i++) - z->p(i, x->sizeAt(i)); - } - - return Status::OK(); - }; - DECLARE_SYN(shape_n, shapes_of); - - DECLARE_SHAPE_FN(shapes_of) { - auto shapeList = SHAPELIST(); - - for (int e = 0; e < inputShape->size(); e++) { - auto inShape = inputShape->at(e); - shapeList->push_back(ConstantShapeHelper::getInstance().vectorShapeInfo(shape::rank(inShape), sd::DataType::INT64)); - } - - return shapeList; - }; - - DECLARE_TYPES(shapes_of) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_INTS}); - } - } +namespace ops { +CUSTOM_OP_IMPL(shapes_of, -1, -1, false, 0, 0) { + for (int e = 0; e < block.width(); e++) { + auto x = INPUT_VARIABLE(e); + auto z = OUTPUT_VARIABLE(e); + + for (int i = 0; i < x->rankOf(); i++) z->p(i, x->sizeAt(i)); + } + + return Status::OK(); +}; +DECLARE_SYN(shape_n, shapes_of); + +DECLARE_SHAPE_FN(shapes_of) { + auto shapeList = SHAPELIST(); + + for (int e = 0; e < inputShape->size(); e++) { + auto inShape = inputShape->at(e); + shapeList->push_back(ConstantShapeHelper::getInstance().vectorShapeInfo( + shape::rank(inShape), sd::DataType::INT64)); + } + + return shapeList; +}; + +DECLARE_TYPES(shapes_of) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_INTS}); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/shape/size.cpp b/libnd4j/include/ops/declarable/generic/shape/size.cpp index c30ed1b582df..954f5c4e6757 100644 --- a/libnd4j/include/ops/declarable/generic/shape/size.cpp +++ b/libnd4j/include/ops/declarable/generic/shape/size.cpp @@ -24,29 +24,30 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(size, 1, 1, false, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - REQUIRE_TRUE(output->isScalar(), 0, "Size output should be scalar"); - - output->p(0, input->lengthOf()); - output->syncToDevice(); - - return Status::OK(); - } - DECLARE_SHAPE_FN(size) { - return SHAPELIST(ConstantShapeHelper::getInstance().scalarShapeInfo(sd::DataType::INT64)); - } - - DECLARE_TYPES(size) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS}) - ->allowOverride(true); - } - } +namespace ops { +CUSTOM_OP_IMPL(size, 1, 1, false, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + REQUIRE_TRUE(output->isScalar(), 0, "Size output should be scalar"); + + output->p(0, input->lengthOf()); + output->syncToDevice(); + + return Status::OK(); +} +DECLARE_SHAPE_FN(size) { + return SHAPELIST( + ConstantShapeHelper::getInstance().scalarShapeInfo(sd::DataType::INT64)); +} + +DECLARE_TYPES(size) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS}) + ->allowOverride(true); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/shape/size_at.cpp b/libnd4j/include/ops/declarable/generic/shape/size_at.cpp index 46491e688fa1..b6a5768e84a4 100644 --- a/libnd4j/include/ops/declarable/generic/shape/size_at.cpp +++ b/libnd4j/include/ops/declarable/generic/shape/size_at.cpp @@ -24,34 +24,35 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(size_at, 1, 1, false, 0, 1) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - auto dim = INT_ARG(0); - if (dim < 0) - dim += input->rankOf(); - - REQUIRE_TRUE(dim < input->rankOf(), 0, "Size_At: Dim can't be higher then input rank") - - output->p(0, input->sizeAt(dim)); - output->syncToDevice(); - - return Status::OK(); - } - - DECLARE_SHAPE_FN(size_at) { - return SHAPELIST(ConstantShapeHelper::getInstance().scalarShapeInfo(sd::DataType::INT64)); - } - - DECLARE_TYPES(size_at) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes(DataType::INT64) - ->allowOverride(true); - } - } +namespace ops { +CUSTOM_OP_IMPL(size_at, 1, 1, false, 0, 1) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + auto dim = INT_ARG(0); + if (dim < 0) dim += input->rankOf(); + + REQUIRE_TRUE(dim < input->rankOf(), 0, + "Size_At: Dim can't be higher then input rank") + + output->p(0, input->sizeAt(dim)); + output->syncToDevice(); + + return Status::OK(); +} + +DECLARE_SHAPE_FN(size_at) { + return SHAPELIST( + ConstantShapeHelper::getInstance().scalarShapeInfo(sd::DataType::INT64)); +} + +DECLARE_TYPES(size_at) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes(DataType::INT64) + ->allowOverride(true); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/shape/squeeze.cpp b/libnd4j/include/ops/declarable/generic/shape/squeeze.cpp index 5698f957faec..14458571cdb1 100644 --- a/libnd4j/include/ops/declarable/generic/shape/squeeze.cpp +++ b/libnd4j/include/ops/declarable/generic/shape/squeeze.cpp @@ -24,135 +24,136 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(squeeze, 1, 1, false, 0, -2) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - std::vector axis; - - if (block.numI() > 0) - for (int e = 0; e < block.numI(); e++) { - int _a = INT_ARG(e); - if (_a < 0) - _a += input->rankOf(); - - axis.emplace_back(_a); - } - else if (block.width() > 1) { - auto a = INPUT_VARIABLE(1); - for (Nd4jLong e = 0; e < a->lengthOf(); e++) { - int _a = a->e(e); - - if (_a < 0) - _a += input->rankOf(); - - axis.emplace_back(_a); - } - } - - if (input->rankOf() == 0 || (input->rankOf() == 1 && input->lengthOf() == 1)) { - output->assign(input); - return Status::OK(); - } - - std::vector shape; - if (axis.size() == 0) { - for (int d = 0; d < input->rankOf(); d++) - if (input->sizeAt(d) > 1) - shape.emplace_back(input->sizeAt(d)); - } else { - for (int d = 0; d < input->rankOf(); d++) { - if (input->sizeAt(d) == 1) { - if (std::find(axis.begin(), axis.end(), d) == axis.end()) - shape.emplace_back(input->sizeAt(d)); - } else shape.emplace_back(input->sizeAt(d)); - } - } - - if (block.isInplace()) { - output->reshapei(input->ordering(), shape, false); - } else { - if (input->ews() == 1 && output->ews() == 1 && input->ordering() == output->ordering()) { - output->dataBuffer()->copyBufferFrom(*input->dataBuffer().get(), output->lengthOf() * DataTypeUtils::sizeOfElement(output->dataType()), 0, input->bufferOffset()); - } else { - auto tmp = input->reshape(input->ordering(), shape); - output->assign(tmp); - } - } - - return Status::OK(); - } - - DECLARE_TYPES(squeeze) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setSameMode(true); - } - - DECLARE_SHAPE_FN(squeeze) { - auto shapeList = SHAPELIST(); - -// Nd4jLong* newShape; - auto in = inputShape->at(0); - auto rank = shape::rank(in); - auto length = shape::length(in); - - if (rank == 0 || (rank == 1 && length == 1)) { - shapeList->push_back(ConstantShapeHelper::getInstance().scalarShapeInfo(ArrayOptions::dataType(in))); - return shapeList; - } - - std::vector axis; - - if (block.numI() > 0) - for (int e = 0; e < block.numI(); e++) { - int _a = INT_ARG(e); - if (_a < 0) - _a += rank; - - axis.emplace_back(_a); - } - else if (block.width() > 1) { - auto a = INPUT_VARIABLE(1); - for (int e = 0; e < a->lengthOf(); e++) { - int _a = a->e(e); - - if (_a < 0) - _a += rank; - - axis.emplace_back(_a); - } - - } - - auto order = shape::order(in); - auto oldShape = shape::shapeOf(in); - - std::vector shape; - if (axis.size() == 0) { - for (int d = 0; d < rank; d++) - if (oldShape[d] > 1) - shape.emplace_back(oldShape[d]); - } else { - for (int d = 0; d < rank; d++) { - if (oldShape[d] == 1) { - if (std::find(axis.begin(), axis.end(), d) == axis.end()) - shape.emplace_back(oldShape[d]); - } else shape.emplace_back(oldShape[d]); - } - } - - if (shape.size() == 0) { - shapeList->push_back(ConstantShapeHelper::getInstance().scalarShapeInfo(ArrayOptions::dataType(in))); - return shapeList; - } - - auto newShape = ConstantShapeHelper::getInstance().createShapeInfo(ArrayOptions::dataType(in), order, shape); - shapeList->push_back(newShape); - return shapeList; - } +namespace ops { +CUSTOM_OP_IMPL(squeeze, 1, 1, false, 0, -2) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + std::vector axis; + + if (block.numI() > 0) + for (int e = 0; e < block.numI(); e++) { + int _a = INT_ARG(e); + if (_a < 0) _a += input->rankOf(); + + axis.emplace_back(_a); + } + else if (block.width() > 1) { + auto a = INPUT_VARIABLE(1); + for (Nd4jLong e = 0; e < a->lengthOf(); e++) { + int _a = a->e(e); + + if (_a < 0) _a += input->rankOf(); + + axis.emplace_back(_a); + } + } + + if (input->rankOf() == 0 || + (input->rankOf() == 1 && input->lengthOf() == 1)) { + output->assign(input); + return Status::OK(); + } + + std::vector shape; + if (axis.size() == 0) { + for (int d = 0; d < input->rankOf(); d++) + if (input->sizeAt(d) > 1) shape.emplace_back(input->sizeAt(d)); + } else { + for (int d = 0; d < input->rankOf(); d++) { + if (input->sizeAt(d) == 1) { + if (std::find(axis.begin(), axis.end(), d) == axis.end()) + shape.emplace_back(input->sizeAt(d)); + } else + shape.emplace_back(input->sizeAt(d)); + } + } + + if (block.isInplace()) { + output->reshapei(input->ordering(), shape, false); + } else { + if (input->ews() == 1 && output->ews() == 1 && + input->ordering() == output->ordering()) { + output->dataBuffer()->copyBufferFrom( + *input->dataBuffer().get(), + output->lengthOf() * DataTypeUtils::sizeOfElement(output->dataType()), + 0, input->bufferOffset()); + } else { + auto tmp = input->reshape(input->ordering(), shape); + output->assign(tmp); + } + } + + return Status::OK(); +} + +DECLARE_TYPES(squeeze) { + getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setSameMode(true); +} + +DECLARE_SHAPE_FN(squeeze) { + auto shapeList = SHAPELIST(); + + // Nd4jLong* newShape; + auto in = inputShape->at(0); + auto rank = shape::rank(in); + auto length = shape::length(in); + + if (rank == 0 || (rank == 1 && length == 1)) { + shapeList->push_back(ConstantShapeHelper::getInstance().scalarShapeInfo( + ArrayOptions::dataType(in))); + return shapeList; + } + + std::vector axis; + + if (block.numI() > 0) + for (int e = 0; e < block.numI(); e++) { + int _a = INT_ARG(e); + if (_a < 0) _a += rank; + + axis.emplace_back(_a); + } + else if (block.width() > 1) { + auto a = INPUT_VARIABLE(1); + for (int e = 0; e < a->lengthOf(); e++) { + int _a = a->e(e); + + if (_a < 0) _a += rank; + + axis.emplace_back(_a); + } + } + + auto order = shape::order(in); + auto oldShape = shape::shapeOf(in); + + std::vector shape; + if (axis.size() == 0) { + for (int d = 0; d < rank; d++) + if (oldShape[d] > 1) shape.emplace_back(oldShape[d]); + } else { + for (int d = 0; d < rank; d++) { + if (oldShape[d] == 1) { + if (std::find(axis.begin(), axis.end(), d) == axis.end()) + shape.emplace_back(oldShape[d]); + } else + shape.emplace_back(oldShape[d]); } + } + + if (shape.size() == 0) { + shapeList->push_back(ConstantShapeHelper::getInstance().scalarShapeInfo( + ArrayOptions::dataType(in))); + return shapeList; + } + + auto newShape = ConstantShapeHelper::getInstance().createShapeInfo( + ArrayOptions::dataType(in), order, shape); + shapeList->push_back(newShape); + return shapeList; } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/shape/tile_to_shape.cpp b/libnd4j/include/ops/declarable/generic/shape/tile_to_shape.cpp index ec0476e049b4..28ee81026d87 100644 --- a/libnd4j/include/ops/declarable/generic/shape/tile_to_shape.cpp +++ b/libnd4j/include/ops/declarable/generic/shape/tile_to_shape.cpp @@ -25,75 +25,75 @@ namespace sd { namespace ops { - CUSTOM_OP_IMPL(tile_to_shape, 1, 1, false, 0, -1) { +CUSTOM_OP_IMPL(tile_to_shape, 1, 1, false, 0, -1) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); + std::vector outShape(block.getIArguments().begin(), + block.getIArguments().end()); - std::vector outShape(block.getIArguments()->begin(), block.getIArguments()->end()); + if (block.isInplace()) { + input->tileToShape(outShape, *input); + } else { + input->tileToShape(outShape, *output); + } - if (block.isInplace()) { - input->tileToShape(outShape, *input); - } else { - input->tileToShape(outShape, *output); - } - - return Status::OK(); - } - - DECLARE_SHAPE_FN(tile_to_shape) { - auto in = inputShape->at(0); + return Status::OK(); +} - // output shape always equals to arguments +DECLARE_SHAPE_FN(tile_to_shape) { + auto in = inputShape->at(0); - auto conv = ArrayUtils::toLongVector(*block.getIArguments()); + // output shape always equals to arguments - auto newShape = ConstantShapeHelper::getInstance().createShapeInfo(ArrayOptions::dataType(in), shape::order(in), conv); + auto conv = ArrayUtils::toLongVector(block.getIArguments()); - return SHAPELIST(newShape); - } + auto newShape = ConstantShapeHelper::getInstance().createShapeInfo( + ArrayOptions::dataType(in), shape::order(in), conv); - DECLARE_TYPES(tile_to_shape) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setSameMode(true); - } + return SHAPELIST(newShape); +} - DECLARE_TYPES(tile_to_shape_bp) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } +DECLARE_TYPES(tile_to_shape) { + getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setSameMode(true); +} +DECLARE_TYPES(tile_to_shape_bp) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} - CUSTOM_OP_IMPL(tile_to_shape_bp, 2, 1, true, 0, -1) { - auto input = INPUT_VARIABLE(0); - auto epsNext = INPUT_VARIABLE(1); +CUSTOM_OP_IMPL(tile_to_shape_bp, 2, 1, true, 0, -1) { + auto input = INPUT_VARIABLE(0); + auto epsNext = INPUT_VARIABLE(1); - auto gradX = OUTPUT_VARIABLE(0); + auto gradX = OUTPUT_VARIABLE(0); - auto axisX = ShapeUtils::evalBroadcastBackwardAxis(input->shapeInfo(), epsNext->shapeInfo()); - // FIX ME: reduceAlongDimension should have a signature with result pass to avoid assigning twice - if (!axisX.empty()) { - auto tempRes = epsNext->reduceAlongDimension(reduce::Sum, axisX); - gradX->assign(tempRes); - } else - gradX->assign(epsNext); + auto axisX = ShapeUtils::evalBroadcastBackwardAxis(input->shapeInfo(), + epsNext->shapeInfo()); + // FIX ME: reduceAlongDimension should have a signature with result pass to + // avoid assigning twice + if (!axisX.empty()) { + auto tempRes = epsNext->reduceAlongDimension(reduce::Sum, axisX); + gradX->assign(tempRes); + } else + gradX->assign(epsNext); - STORE_RESULT(gradX); + STORE_RESULT(gradX); - return Status::OK(); - } + return Status::OK(); +} - DECLARE_SHAPE_FN(tile_to_shape_bp) { - auto in = inputShape->at(0); +DECLARE_SHAPE_FN(tile_to_shape_bp) { + auto in = inputShape->at(0); - Nd4jLong *newShape; - COPY_SHAPE(in, newShape); + Nd4jLong *newShape; + COPY_SHAPE(in, newShape); - return SHAPELIST(CONSTANT(newShape)); - } -} + return SHAPELIST(CONSTANT(newShape)); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/shape/transpose.cpp b/libnd4j/include/ops/declarable/generic/shape/transpose.cpp index 0b12f415ffc6..fdd406e62909 100644 --- a/libnd4j/include/ops/declarable/generic/shape/transpose.cpp +++ b/libnd4j/include/ops/declarable/generic/shape/transpose.cpp @@ -22,57 +22,61 @@ #include #if NOT_EXCLUDED(OP_transpose) -#include #include +#include namespace sd { -namespace ops { +namespace ops { ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(transpose, 1, 1, false, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto z = OUTPUT_VARIABLE(0); + + // Special case: empty.reshape() -> return empty + if (x->isEmpty()) { + REQUIRE_TRUE( + z->isEmpty(), 0, + "TRANSPOSE OP: when input is empty, output must also be empty"); + return Status::OK(); // No op + } + + if (block.width() == 1 && block.numI() == 0) { + z->assign(x->transpose()); + return Status::OK(); + } - auto x = INPUT_VARIABLE(0); - auto z = OUTPUT_VARIABLE(0); - - //Special case: empty.reshape() -> return empty - if (x->isEmpty()) { - REQUIRE_TRUE(z->isEmpty(), 0, "TRANSPOSE OP: when input is empty, output must also be empty"); - return Status::OK(); //No op - } - - if (block.width() == 1 && block.getIArguments()->size() == 0) { - z->assign(x->transpose()); - return Status::OK(); - } - - std::vector permutationVector = block.width() > 1 ? INPUT_VARIABLE(1)->asVectorT() : *block.getIArguments(); + std::vector permutationVector = block.width() > 1 + ? INPUT_VARIABLE(1)->asVectorT() + : block.getIArguments(); - z->assign(x->permute(permutationVector)); + z->assign(x->permute(permutationVector)); - return Status::OK(); + return Status::OK(); } DECLARE_TYPES(transpose) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setSameMode(true); + getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setSameMode(true); } DECLARE_SHAPE_FN(transpose) { + auto x = INPUT_VARIABLE(0); - auto x = INPUT_VARIABLE(0); - - if (block.width() == 1 && block.getIArguments()->size() == 0) - return SHAPELIST(ShapeUtils::evalTranspShapeInfo(*x, block.workspace(), true)); + if (block.width() == 1 && block.numI() == 0) + return SHAPELIST( + ShapeUtils::evalTranspShapeInfo(*x, block.workspace(), true)); - std::vector permutationVector = block.width() > 1 ? INPUT_VARIABLE(1)->asVectorT() : *block.getIArguments(); + std::vector permutationVector = block.width() > 1 + ? INPUT_VARIABLE(1)->asVectorT() + : block.getIArguments(); - auto outputShapeInfo = ShapeUtils::evalPermShapeInfo(permutationVector.data(), x->rankOf(), *x, block.workspace(), true); + auto outputShapeInfo = ShapeUtils::evalPermShapeInfo( + permutationVector.data(), x->rankOf(), *x, block.workspace(), true); - return SHAPELIST(outputShapeInfo); + return SHAPELIST(outputShapeInfo); } -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/strings/split_string.cpp b/libnd4j/include/ops/declarable/generic/strings/split_string.cpp index e42591b5c4ad..a9a6d4111775 100644 --- a/libnd4j/include/ops/declarable/generic/strings/split_string.cpp +++ b/libnd4j/include/ops/declarable/generic/strings/split_string.cpp @@ -24,27 +24,27 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(split_string, 2, 1, true, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto delim = INPUT_VARIABLE(1); - - return Status::OK(); - }; - - DECLARE_SHAPE_FN(split_string) { - auto input = INPUT_VARIABLE(0); - auto delim = INPUT_VARIABLE(1); - - return SHAPELIST(); - } - - DECLARE_TYPES(split_string) { - getOpDescriptor() - ->setAllowedInputTypes({ALL_STRINGS}) - ->setAllowedOutputTypes({ALL_STRINGS}); - } - } +namespace ops { +CUSTOM_OP_IMPL(split_string, 2, 1, true, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto delim = INPUT_VARIABLE(1); + + return Status::OK(); +}; + +DECLARE_SHAPE_FN(split_string) { + auto input = INPUT_VARIABLE(0); + auto delim = INPUT_VARIABLE(1); + + return SHAPELIST(); +} + +DECLARE_TYPES(split_string) { + getOpDescriptor() + ->setAllowedInputTypes({ALL_STRINGS}) + ->setAllowedOutputTypes({ALL_STRINGS}); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/tensor/create.cpp b/libnd4j/include/ops/declarable/generic/tensor/create.cpp index c692a74d86df..16ee044bf25e 100644 --- a/libnd4j/include/ops/declarable/generic/tensor/create.cpp +++ b/libnd4j/include/ops/declarable/generic/tensor/create.cpp @@ -24,35 +24,36 @@ #include namespace sd { - namespace ops { +namespace ops { - CUSTOM_OP_IMPL(create, 1, 1, false, 0, 1) { - auto init = block.numB() > 0 ? B_ARG(0) : true; +CUSTOM_OP_IMPL(create, 1, 1, false, 0, 1) { + auto init = block.numB() > 0 ? B_ARG(0) : true; - if (init) - OUTPUT_VARIABLE(0)->nullify(); + if (init) OUTPUT_VARIABLE(0)->nullify(); - return Status::OK(); - } + return Status::OK(); +} - DECLARE_SHAPE_FN(create) { - auto shapeInput = INPUT_VARIABLE(0); - auto order = (char) INT_ARG(0); - auto dtype = DataTypeUtils::fromInt(INT_ARG(1)); +DECLARE_SHAPE_FN(create) { + auto shapeInput = INPUT_VARIABLE(0); + auto order = (char)INT_ARG(0); + auto dtype = DataTypeUtils::fromInt(INT_ARG(1)); - REQUIRE_TRUE(order == 'c' || order == 'f', 0, "create: order must be either c or f"); + REQUIRE_TRUE(order == 'c' || order == 'f', 0, + "create: order must be either c or f"); - auto shape = shapeInput->getBufferAsVector(); + auto shape = shapeInput->getBufferAsVector(); - return SHAPELIST(sd::ConstantShapeHelper::getInstance().createShapeInfo(dtype, order, shape)); - } + return SHAPELIST(sd::ConstantShapeHelper::getInstance().createShapeInfo( + dtype, order, shape)); +} - DECLARE_TYPES(create) { - getOpDescriptor() - ->setAllowedInputTypes({ALL_INTS}) - ->setAllowedOutputTypes(sd::DataType::ANY); - } - } +DECLARE_TYPES(create) { + getOpDescriptor() + ->setAllowedInputTypes({ALL_INTS}) + ->setAllowedOutputTypes(sd::DataType::ANY); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/tensor/fill.cpp b/libnd4j/include/ops/declarable/generic/tensor/fill.cpp index 81cece90130b..3284158c5ea8 100644 --- a/libnd4j/include/ops/declarable/generic/tensor/fill.cpp +++ b/libnd4j/include/ops/declarable/generic/tensor/fill.cpp @@ -24,75 +24,76 @@ #include namespace sd { - namespace ops { - - CUSTOM_OP_IMPL(fill, 1, 1, false, -2, 0) { - auto shapeArray = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - auto w = block.width(); - auto i = block.numI(); - auto t = block.numT(); - - REQUIRE_TRUE( w > 1 || t > 0 || i > 0, 0, "Fill: either additional variable should exist, or scalar value should be present"); - - if(output->isEmpty()){ - //Empty output array - no-op - return Status::OK(); - } - - if (w > 1) { - output->assign(INPUT_VARIABLE(1)); - } else { - if (t > 0) { - output->assign(T_ARG(0)); - } else if (i > 0) { - output->assign(INT_ARG(0)); - } - } - - STORE_RESULT(output); - - return Status::OK(); - }; - - DECLARE_TYPES(fill) { - getOpDescriptor() - ->setAllowedInputTypes(0, {ALL_INTS}) - ->setAllowedInputTypes(1, {ALL_INTS, ALL_FLOATS}) - ->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS}); - } - - DECLARE_SHAPE_FN(fill) { - - auto shapeArray = INPUT_VARIABLE(0); - const int len = (int) shapeArray->lengthOf(); - Nd4jLong *newShape = nullptr; - ALLOCATE(newShape, block.getWorkspace(), shape::shapeInfoLength(len), Nd4jLong); - - newShape[0] = len; - for (int e = 0; e < shapeArray->lengthOf(); e++){ - newShape[e+1] = shapeArray->e(e); - } - - sd::DataType dataType; - - if (block.width() > 1) { - dataType = INPUT_VARIABLE(1)->dataType(); - } else if (block.numT() > 0) { - dataType = Environment::getInstance().defaultFloatDataType(); - } else if (block.numI() > 0) { - dataType = sd::DataType::INT32; - } else if (block.numB() > 0) { - dataType = sd::DataType::BOOL; - } else - throw std::runtime_error("Fill: missing value to fill output array with"); - - ShapeUtils::updateStridesAndType(newShape, dataType, 'c'); - - return SHAPELIST(CONSTANT(newShape)); - }; +namespace ops { + +CUSTOM_OP_IMPL(fill, 1, 1, false, -2, 0) { + auto shapeArray = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + auto w = block.width(); + auto i = block.numI(); + auto t = block.numT(); + + REQUIRE_TRUE(w > 1 || t > 0 || i > 0, 0, + "Fill: either additional variable should exist, or scalar value " + "should be present"); + + if (output->isEmpty()) { + // Empty output array - no-op + return Status::OK(); + } + + if (w > 1) { + output->assign(INPUT_VARIABLE(1)); + } else { + if (t > 0) { + output->assign(T_ARG(0)); + } else if (i > 0) { + output->assign(INT_ARG(0)); } + } + + STORE_RESULT(output); + + return Status::OK(); +}; + +DECLARE_TYPES(fill) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_INTS}) + ->setAllowedInputTypes(1, {ALL_INTS, ALL_FLOATS}) + ->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS}); } +DECLARE_SHAPE_FN(fill) { + auto shapeArray = INPUT_VARIABLE(0); + const int len = (int)shapeArray->lengthOf(); + Nd4jLong *newShape = nullptr; + ALLOCATE(newShape, block.workspace(), shape::shapeInfoLength(len), Nd4jLong); + + newShape[0] = len; + for (int e = 0; e < shapeArray->lengthOf(); e++) { + newShape[e + 1] = shapeArray->e(e); + } + + sd::DataType dataType; + + if (block.width() > 1) { + dataType = INPUT_VARIABLE(1)->dataType(); + } else if (block.numT() > 0) { + dataType = Environment::getInstance().defaultFloatDataType(); + } else if (block.numI() > 0) { + dataType = sd::DataType::INT32; + } else if (block.numB() > 0) { + dataType = sd::DataType::BOOL; + } else + throw std::runtime_error("Fill: missing value to fill output array with"); + + ShapeUtils::updateStridesAndType(newShape, dataType, 'c'); + + return SHAPELIST(CONSTANT(newShape)); +}; +} // namespace ops +} // namespace sd + #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/tensor/fill_as.cpp b/libnd4j/include/ops/declarable/generic/tensor/fill_as.cpp index f2d74572d2b2..746b4eaa3cb8 100644 --- a/libnd4j/include/ops/declarable/generic/tensor/fill_as.cpp +++ b/libnd4j/include/ops/declarable/generic/tensor/fill_as.cpp @@ -24,33 +24,30 @@ #include namespace sd { - namespace ops { - CONFIGURABLE_OP_IMPL(fill_as, 1, 1, true, 0, 0) { - auto output = OUTPUT_VARIABLE(0); - - if (block.width() > 1) { - auto s = INPUT_VARIABLE(1); - output->assign(s); - } else if (block.numT() > 0) { - output->assign(T_ARG(0)); - } else if (block.numI() > 0) { - output->assign(INT_ARG(0)); - } - - STORE_RESULT(output); - - return ND4J_STATUS_OK; - } - DECLARE_SYN(filllike, fill_as); - DECLARE_SYN(fill_like, fill_as); - - - DECLARE_TYPES(fill_as) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setSameMode(true); - } - } +namespace ops { +CONFIGURABLE_OP_IMPL(fill_as, 1, 1, true, 0, 0) { + auto output = OUTPUT_VARIABLE(0); + + if (block.width() > 1) { + auto s = INPUT_VARIABLE(1); + output->assign(s); + } else if (block.numT() > 0) { + output->assign(T_ARG(0)); + } else if (block.numI() > 0) { + output->assign(INT_ARG(0)); + } + + STORE_RESULT(output); + + return ND4J_STATUS_OK; } +DECLARE_SYN(filllike, fill_as); +DECLARE_SYN(fill_like, fill_as); + +DECLARE_TYPES(fill_as) { + getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setSameMode(true); +} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/tensor/lin_space.cpp b/libnd4j/include/ops/declarable/generic/tensor/lin_space.cpp index 97f7b390f0b0..6e93c6791930 100644 --- a/libnd4j/include/ops/declarable/generic/tensor/lin_space.cpp +++ b/libnd4j/include/ops/declarable/generic/tensor/lin_space.cpp @@ -26,50 +26,64 @@ namespace sd { namespace ops { - CUSTOM_OP_IMPL(lin_space, 0, 1, false, 0, 0) { - - auto output = OUTPUT_VARIABLE(0); - - const int nInputs = block.width(); - bool bInputs = (3 == nInputs || 3 == block.numI() || (2 == block.numT() && block.numI() > 0)); - - REQUIRE_TRUE(bInputs, 0, "lin_space OP: Have to be supplied correct inputs, input size or T_ARG size have to be equal 3, but got inputs - %i, T_ARGS - %i!", nInputs, block.numT()); - - auto start = (nInputs > 0) ? INPUT_VARIABLE(0)->e(0) : static_cast(T_ARG(0)); - auto finish = (nInputs > 0) ? INPUT_VARIABLE(1)->e(0) : static_cast(T_ARG(1)); - auto numOfElements = (nInputs > 0) ? INPUT_VARIABLE(2)->e(0) : static_cast(I_ARG(0)); - - if (numOfElements == 1) { - output->assign(start); - return Status::OK(); - } - - output->linspace(start, (finish - start) / ( numOfElements - 1.0 )); - return Status::OK(); - } - - DECLARE_SHAPE_FN(lin_space) { - - const int nInputs = block.width(); - bool bInputs = (3 == nInputs || 3 == block.numI() || (2 == block.numT() && block.numI() > 0)); - REQUIRE_TRUE(bInputs, 0, "lin_space OP: Have to be supplied correct inputs, input size or T_ARG size have to be equal 3, but got inputs - %i, T_ARGS - %i!", nInputs, block.numT() ); - - - auto dataType = (nInputs > 0) ? ArrayOptions::dataType(inputShape->at(0)) : ( block.numD() > 0 ? static_cast(D_ARG(0)) : DataType::FLOAT32) ; - Nd4jLong steps = (nInputs > 0) ? INPUT_VARIABLE(2)->e(0) : static_cast(I_ARG(0)); - - return SHAPELIST(ConstantShapeHelper::getInstance().vectorShapeInfo(steps, dataType)); - } - +CUSTOM_OP_IMPL(lin_space, 0, 1, false, 0, 0) { + auto output = OUTPUT_VARIABLE(0); + + const int nInputs = block.width(); + bool bInputs = (3 == nInputs || 3 == block.numI() || + (2 == block.numT() && block.numI() > 0)); + + REQUIRE_TRUE( + bInputs, 0, + "lin_space OP: Have to be supplied correct inputs, input size or T_ARG " + "size have to be equal 3, but got inputs - %i, T_ARGS - %i!", + nInputs, block.numT()); + + auto start = (nInputs > 0) ? INPUT_VARIABLE(0)->e(0) + : static_cast(T_ARG(0)); + auto finish = (nInputs > 0) ? INPUT_VARIABLE(1)->e(0) + : static_cast(T_ARG(1)); + auto numOfElements = (nInputs > 0) ? INPUT_VARIABLE(2)->e(0) + : static_cast(I_ARG(0)); + + if (numOfElements == 1) { + output->assign(start); + return Status::OK(); + } + + output->linspace(start, (finish - start) / (numOfElements - 1.0)); + return Status::OK(); +} - DECLARE_TYPES(lin_space) { - getOpDescriptor() - ->setAllowedInputTypes(0, {ALL_FLOATS, ALL_INTS}) - ->setAllowedInputTypes(1, {ALL_FLOATS, ALL_INTS}) - ->setAllowedInputTypes(2, {ALL_INTS}) - ->setAllowedOutputTypes({ALL_FLOATS, ALL_INTS}); - } +DECLARE_SHAPE_FN(lin_space) { + const int nInputs = block.width(); + bool bInputs = (3 == nInputs || 3 == block.numI() || + (2 == block.numT() && block.numI() > 0)); + REQUIRE_TRUE( + bInputs, 0, + "lin_space OP: Have to be supplied correct inputs, input size or T_ARG " + "size have to be equal 3, but got inputs - %i, T_ARGS - %i!", + nInputs, block.numT()); + + auto dataType = (nInputs > 0) + ? ArrayOptions::dataType(inputShape->at(0)) + : (block.numD() > 0 ? static_cast(D_ARG(0)) + : DataType::FLOAT32); + Nd4jLong steps = (nInputs > 0) ? INPUT_VARIABLE(2)->e(0) + : static_cast(I_ARG(0)); + + return SHAPELIST( + ConstantShapeHelper::getInstance().vectorShapeInfo(steps, dataType)); } + +DECLARE_TYPES(lin_space) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_FLOATS, ALL_INTS}) + ->setAllowedInputTypes(1, {ALL_FLOATS, ALL_INTS}) + ->setAllowedInputTypes(2, {ALL_INTS}) + ->setAllowedOutputTypes({ALL_FLOATS, ALL_INTS}); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/tensor/ones_as.cpp b/libnd4j/include/ops/declarable/generic/tensor/ones_as.cpp index 0fb8fe283ea1..910429167b80 100644 --- a/libnd4j/include/ops/declarable/generic/tensor/ones_as.cpp +++ b/libnd4j/include/ops/declarable/generic/tensor/ones_as.cpp @@ -24,32 +24,35 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(ones_as, 1, 1, false, 0, 0) { - auto output = OUTPUT_VARIABLE(0); +namespace ops { - output->assign(1); +CUSTOM_OP_IMPL(ones_as, 1, 1, false, 0, 0) { + auto output = OUTPUT_VARIABLE(0); - return Status::OK(); - } + output->assign(1); - DECLARE_SHAPE_FN(ones_as) { - auto in = inputShape->at(0); - auto dtype = block.numD() ? D_ARG(0) : ArrayOptions::dataType(in); - auto shape = sd::ConstantShapeHelper::getInstance().createShapeInfo(dtype, in); + return Status::OK(); +} + +DECLARE_SHAPE_FN(ones_as) { + auto in = inputShape->at(0); + auto dtype = block.numD() ? D_ARG(0) : ArrayOptions::dataType(in); + auto shape = + sd::ConstantShapeHelper::getInstance().createShapeInfo(dtype, in); - //nd4j_printf("numD: %i; dtype: %s\n", block.numD(), DataTypeUtils::asString(dtype).c_str()); + // nd4j_printf("numD: %i; dtype: %s\n", block.numD(), + // DataTypeUtils::asString(dtype).c_str()); - return SHAPELIST(shape); - } + return SHAPELIST(shape); +} - DECLARE_TYPES(ones_as) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes(sd::DataType::ANY) - ->setSameMode(false); - } - } +DECLARE_TYPES(ones_as) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes(sd::DataType::ANY) + ->setSameMode(false); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/tensor/range.cpp b/libnd4j/include/ops/declarable/generic/tensor/range.cpp index 2f88b819b187..1751ae1c1c16 100644 --- a/libnd4j/include/ops/declarable/generic/tensor/range.cpp +++ b/libnd4j/include/ops/declarable/generic/tensor/range.cpp @@ -29,257 +29,262 @@ namespace sd { namespace ops { CUSTOM_OP_IMPL(range, -2, 1, false, -2, -2) { - auto output = OUTPUT_VARIABLE(0); - - const int numInArrs = block.width(); - const int numTArgs = block.getTArguments()->size(); - const int numIArgs = block.getIArguments()->size(); - - NDArray *s = nullptr; - NDArray *d = nullptr; - - bool localS = false; - bool localD = false; - // FIXME: this op should be fully moved to helpers - - if (output->isEmpty()) - return Status::OK(); - - if (numInArrs > 0) { - if(numInArrs == 1) { - //limit = (*INPUT_VARIABLE(0))(0.); - if (output->isR()) { - s = NDArrayFactory::create_(0.0f, block.launchContext()); - d = NDArrayFactory::create_(1.0f, block.launchContext()); - } else { - s = NDArrayFactory::create_(0, block.launchContext()); - d = NDArrayFactory::create_(1, block.launchContext()); - } - localS = true; - localD = true; - } else if(numInArrs == 2) { - s = INPUT_VARIABLE(0); - //limit = (*INPUT_VARIABLE(1))(0.); - if (output->isR()) { - d = NDArrayFactory::create_(1.0f, block.launchContext()); - } else { - d = NDArrayFactory::create_(1, block.launchContext()); - } - localD = true; - } else { - s = INPUT_VARIABLE(0); - //limit = (*INPUT_VARIABLE(1))(0.); - d = INPUT_VARIABLE(2); - } - } else if (numIArgs > 0) { - - if(numIArgs == 1) { - // limit = INT_ARG(0); - } else if(numIArgs == 2) { - s = NDArrayFactory::create_(INT_ARG(0), block.launchContext()); - //limit = INT_ARG(1); - d = NDArrayFactory::create_(1, block.launchContext()); - } - else { - s = NDArrayFactory::create_(INT_ARG(0), block.launchContext()); - //limit = INT_ARG(1); - d = NDArrayFactory::create_(INT_ARG(2), block.launchContext()); - } - - localS = true; - localD = true; + auto output = OUTPUT_VARIABLE(0); + + const int numInArrs = block.width(); + const int numTArgs = block.numT(); + const int numIArgs = block.numI(); + + NDArray *s = nullptr; + NDArray *d = nullptr; + + bool localS = false; + bool localD = false; + // FIXME: this op should be fully moved to helpers + + if (output->isEmpty()) return Status::OK(); + + if (numInArrs > 0) { + if (numInArrs == 1) { + // limit = (*INPUT_VARIABLE(0))(0.); + if (output->isR()) { + s = NDArrayFactory::create_(0.0f, block.launchContext()); + d = NDArrayFactory::create_(1.0f, block.launchContext()); + } else { + s = NDArrayFactory::create_(0, block.launchContext()); + d = NDArrayFactory::create_(1, block.launchContext()); + } + localS = true; + localD = true; + } else if (numInArrs == 2) { + s = INPUT_VARIABLE(0); + // limit = (*INPUT_VARIABLE(1))(0.); + if (output->isR()) { + d = NDArrayFactory::create_(1.0f, block.launchContext()); + } else { + d = NDArrayFactory::create_(1, block.launchContext()); + } + localD = true; + } else { + s = INPUT_VARIABLE(0); + // limit = (*INPUT_VARIABLE(1))(0.); + d = INPUT_VARIABLE(2); + } + } else if (numIArgs > 0) { + if (numIArgs == 1) { + // limit = INT_ARG(0); + } else if (numIArgs == 2) { + s = NDArrayFactory::create_(INT_ARG(0), block.launchContext()); + // limit = INT_ARG(1); + d = NDArrayFactory::create_(1, block.launchContext()); + } else { + s = NDArrayFactory::create_(INT_ARG(0), block.launchContext()); + // limit = INT_ARG(1); + d = NDArrayFactory::create_(INT_ARG(2), block.launchContext()); } - else if (numTArgs > 0) { - - if(numTArgs == 1) { - //limit = T_ARG(0); - s = NDArrayFactory::create_(0.0f, block.launchContext()); - d = NDArrayFactory::create_(1.0f, block.launchContext()); - } else if(numTArgs == 2) { - s = NDArrayFactory::create_(T_ARG(0), block.launchContext()); - //limit = T_ARG(1); - d = NDArrayFactory::create_(1.0f, block.launchContext()); - } - else { - s = NDArrayFactory::create_(T_ARG(0), block.launchContext()); - //limit = T_ARG(1); - d = NDArrayFactory::create_(T_ARG(2), block.launchContext()); - } - - localS = true; - localD = true; + + localS = true; + localD = true; + } else if (numTArgs > 0) { + if (numTArgs == 1) { + // limit = T_ARG(0); + s = NDArrayFactory::create_(0.0f, block.launchContext()); + d = NDArrayFactory::create_(1.0f, block.launchContext()); + } else if (numTArgs == 2) { + s = NDArrayFactory::create_(T_ARG(0), block.launchContext()); + // limit = T_ARG(1); + d = NDArrayFactory::create_(1.0f, block.launchContext()); } else { - REQUIRE_TRUE(false, 0, "CUSTOM RANGE OP: op should have inputs defined in any possible way: T_args, INT_args, or INPUT variables!"); + s = NDArrayFactory::create_(T_ARG(0), block.launchContext()); + // limit = T_ARG(1); + d = NDArrayFactory::create_(T_ARG(2), block.launchContext()); } - helpers::range(block.launchContext(), *s, *d, *output); + localS = true; + localD = true; + } else { + REQUIRE_TRUE(false, 0, + "CUSTOM RANGE OP: op should have inputs defined in any " + "possible way: T_args, INT_args, or INPUT variables!"); + } + + helpers::range(block.launchContext(), *s, *d, *output); - if (localS) - delete s; + if (localS) delete s; - if (localD) - delete d; + if (localD) delete d; - return Status::OK(); + return Status::OK(); } DECLARE_SHAPE_FN(range) { - - const int numInArrs = block.width(); - const int numTArgs = block.getTArguments()->size(); - const int numIArgs = block.getIArguments()->size(); - - Nd4jLong steps = 0; - sd::DataType dataType = block.numD() ? D_ARG(0) : sd::DataType::INHERIT; - - if (numInArrs > 0) { - auto isR = INPUT_VARIABLE(0)->isR(); - auto isZ = INPUT_VARIABLE(0)->isZ(); - auto dtype = INPUT_VARIABLE(0)->dataType(); - - if (isR) { - double start(0), limit, delta(1); - - if (numInArrs == 1) - limit = INPUT_VARIABLE(0)->e(0); - else if (numInArrs == 2) { - start = INPUT_VARIABLE(0)->e(0); - limit = INPUT_VARIABLE(1)->e(0); - } else { - start = INPUT_VARIABLE(0)->e(0); - limit = INPUT_VARIABLE(1)->e(0); - delta = INPUT_VARIABLE(2)->e(0); - } - - if (limit == start){ - //Return [0] to match TF - return SHAPELIST(ConstantShapeHelper::getInstance().vectorShapeInfo(0, dtype)); - } - - REQUIRE_TRUE(delta != 0, 0, "CUSTOM RANGE OP: delta should not be equal to zero !"); - - steps = static_cast((limit - start) / delta); - - if (!block.numD()) - dataType = INPUT_VARIABLE(0)->dataType(); - - if(math::nd4j_abs(start + steps * delta) < math::nd4j_abs(limit)) - ++steps; - } else if (isZ) { - Nd4jLong start(0), limit, delta(1); - - if (numInArrs == 1) - limit = INPUT_VARIABLE(0)->e(0); - else if (numInArrs == 2) { - start = INPUT_VARIABLE(0)->e(0); - limit = INPUT_VARIABLE(1)->e(0); - } else { - start = INPUT_VARIABLE(0)->e(0); - limit = INPUT_VARIABLE(1)->e(0); - delta = INPUT_VARIABLE(2)->e(0); - } - - //nd4j_printf("Start: [%lld]; Limit: [%lld]; Delta: [%lld];\n", start, limit, delta) - - if (limit == start){ - //Return [0] to match TF - return SHAPELIST(ConstantShapeHelper::getInstance().vectorShapeInfo(0, dtype)); - } - - REQUIRE_TRUE(delta != 0, 0, "CUSTOM RANGE OP: delta should not be equal to zero !"); - - steps = static_cast((limit - start) / delta); - - if (!block.numD()) - dataType = INPUT_VARIABLE(0)->dataType(); - - if(math::nd4j_abs(start + steps * delta) < math::nd4j_abs(limit)) - ++steps; - } - } else if (numIArgs > 0) { - Nd4jLong start(0), limit, delta(1); - - if(numIArgs == 1) - limit = INT_ARG(0); - else if(numIArgs == 2) { - start = INT_ARG(0); - limit = INT_ARG(1); - } - else { - start = INT_ARG(0); - limit = INT_ARG(1); - delta = INT_ARG(2); - } - - if (limit == start){ - //Return [0] to match TF - return SHAPELIST(ConstantShapeHelper::getInstance().vectorShapeInfo(0, sd::DataType::INT32)); - } - - REQUIRE_TRUE(delta != 0, 0, "CUSTOM RANGE OP: delta should not be equal to zero !"); - - if (!block.numD()) { - if (limit > DataTypeUtils::max()) - dataType = sd::DataType::INT64; - else - dataType = sd::DataType::INT32; - } - - steps = (limit - start) / delta; - - if(math::nd4j_abs(start + steps * delta) < math::nd4j_abs(limit)) - ++steps; - } - else if (numTArgs > 0) { - double start(0), limit, delta(1); - - if(numTArgs == 1) - limit = T_ARG(0); - else if(numTArgs == 2) { - start = T_ARG(0); - limit = T_ARG(1); - } - else { - start = T_ARG(0); - limit = T_ARG(1); - delta = T_ARG(2); - } - - if (limit == start){ - //Return [0] to match TF - return SHAPELIST(ConstantShapeHelper::getInstance().vectorShapeInfo(0, Environment::getInstance().defaultFloatDataType())); - } - - - REQUIRE_TRUE(delta != 0, 0, "CUSTOM RANGE OP: delta should not be equal to zero !"); - - steps = static_cast((limit - start) / delta); - - if (!block.numD()) { - if (Environment::getInstance().precisionBoostAllowed()) - dataType = sd::DataType::DOUBLE; - else - dataType = Environment::getInstance().defaultFloatDataType(); - } - - if(math::nd4j_abs(start + steps * delta) < math::nd4j_abs(limit)) - ++steps; + const int numInArrs = block.width(); + const int numTArgs = block.numT(); + const int numIArgs = block.numI(); + + Nd4jLong steps = 0; + sd::DataType dataType = block.numD() ? D_ARG(0) : sd::DataType::INHERIT; + + if (numInArrs > 0) { + auto isR = INPUT_VARIABLE(0)->isR(); + auto isZ = INPUT_VARIABLE(0)->isZ(); + auto dtype = INPUT_VARIABLE(0)->dataType(); + + if (isR) { + double start(0), limit, delta(1); + + if (numInArrs == 1) + limit = INPUT_VARIABLE(0)->e(0); + else if (numInArrs == 2) { + start = INPUT_VARIABLE(0)->e(0); + limit = INPUT_VARIABLE(1)->e(0); + } else { + start = INPUT_VARIABLE(0)->e(0); + limit = INPUT_VARIABLE(1)->e(0); + delta = INPUT_VARIABLE(2)->e(0); + } + + if (limit == start) { + // Return [0] to match TF + return SHAPELIST( + ConstantShapeHelper::getInstance().vectorShapeInfo(0, dtype)); + } + + REQUIRE_TRUE(delta != 0, 0, + "CUSTOM RANGE OP: delta should not be equal to zero !"); + + steps = static_cast((limit - start) / delta); + + if (!block.numD()) dataType = INPUT_VARIABLE(0)->dataType(); + + if (math::nd4j_abs(start + steps * delta) < + math::nd4j_abs(limit)) + ++steps; + } else if (isZ) { + Nd4jLong start(0), limit, delta(1); + + if (numInArrs == 1) + limit = INPUT_VARIABLE(0)->e(0); + else if (numInArrs == 2) { + start = INPUT_VARIABLE(0)->e(0); + limit = INPUT_VARIABLE(1)->e(0); + } else { + start = INPUT_VARIABLE(0)->e(0); + limit = INPUT_VARIABLE(1)->e(0); + delta = INPUT_VARIABLE(2)->e(0); + } + + // nd4j_printf("Start: [%lld]; Limit: [%lld]; Delta: [%lld];\n", start, + // limit, delta) + + if (limit == start) { + // Return [0] to match TF + return SHAPELIST( + ConstantShapeHelper::getInstance().vectorShapeInfo(0, dtype)); + } + + REQUIRE_TRUE(delta != 0, 0, + "CUSTOM RANGE OP: delta should not be equal to zero !"); + + steps = static_cast((limit - start) / delta); + + if (!block.numD()) dataType = INPUT_VARIABLE(0)->dataType(); + + if (math::nd4j_abs(start + steps * delta) < + math::nd4j_abs(limit)) + ++steps; + } + } else if (numIArgs > 0) { + Nd4jLong start(0), limit, delta(1); + + if (numIArgs == 1) + limit = INT_ARG(0); + else if (numIArgs == 2) { + start = INT_ARG(0); + limit = INT_ARG(1); } else { - REQUIRE_TRUE(false, 0, "CUSTOM RANGE OP: op should have inputs defined in any possible way: T_args, INT_args, or INPUT variables!"); + start = INT_ARG(0); + limit = INT_ARG(1); + delta = INT_ARG(2); } - REQUIRE_TRUE(steps > 0, 0, "CUSTOM RANGE OP: value of (limit-start)/delta should be positive !"); + if (limit == start) { + // Return [0] to match TF + return SHAPELIST(ConstantShapeHelper::getInstance().vectorShapeInfo( + 0, sd::DataType::INT32)); + } - return SHAPELIST(ConstantShapeHelper::getInstance().vectorShapeInfo(steps, dataType)); -} + REQUIRE_TRUE(delta != 0, 0, + "CUSTOM RANGE OP: delta should not be equal to zero !"); + + if (!block.numD()) { + if (limit > DataTypeUtils::max()) + dataType = sd::DataType::INT64; + else + dataType = sd::DataType::INT32; + } + + steps = (limit - start) / delta; + if (math::nd4j_abs(start + steps * delta) < + math::nd4j_abs(limit)) + ++steps; + } else if (numTArgs > 0) { + double start(0), limit, delta(1); - DECLARE_TYPES(range) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS, ALL_INTS}); + if (numTArgs == 1) + limit = T_ARG(0); + else if (numTArgs == 2) { + start = T_ARG(0); + limit = T_ARG(1); + } else { + start = T_ARG(0); + limit = T_ARG(1); + delta = T_ARG(2); + } + + if (limit == start) { + // Return [0] to match TF + return SHAPELIST(ConstantShapeHelper::getInstance().vectorShapeInfo( + 0, Environment::getInstance().defaultFloatDataType())); + } + + REQUIRE_TRUE(delta != 0, 0, + "CUSTOM RANGE OP: delta should not be equal to zero !"); + + steps = static_cast((limit - start) / delta); + + if (!block.numD()) { + if (Environment::getInstance().precisionBoostAllowed()) + dataType = sd::DataType::DOUBLE; + else + dataType = Environment::getInstance().defaultFloatDataType(); } + + if (math::nd4j_abs(start + steps * delta) < + math::nd4j_abs(limit)) + ++steps; + } else { + REQUIRE_TRUE(false, 0, + "CUSTOM RANGE OP: op should have inputs defined in any " + "possible way: T_args, INT_args, or INPUT variables!"); + } + + REQUIRE_TRUE( + steps > 0, 0, + "CUSTOM RANGE OP: value of (limit-start)/delta should be positive !"); + + return SHAPELIST( + ConstantShapeHelper::getInstance().vectorShapeInfo(steps, dataType)); } + +DECLARE_TYPES(range) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS, ALL_INTS}); } +} // namespace ops +} // namespace sd + #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/tensor/strided_slice.cpp b/libnd4j/include/ops/declarable/generic/tensor/strided_slice.cpp index bbdc84ce5b95..d1d44c96bfcf 100644 --- a/libnd4j/include/ops/declarable/generic/tensor/strided_slice.cpp +++ b/libnd4j/include/ops/declarable/generic/tensor/strided_slice.cpp @@ -16,127 +16,131 @@ limitations under the License. #include #if NOT_EXCLUDED(OP_strided_slice) -#include -#include -#include #include +#include +#include + +#include namespace sd { - namespace ops { - - constexpr int kShrinkAxis = -1, kNewAxis = -2; - - struct StridedSliceSparseSpec { - int dims; - int num_add_axis_after_ellipsis; - std::vector* begin_tensor; - const std::vector* end_tensor; - const std::vector* strides_tensor; - const int begin_mask, end_mask; - int ellipsis_mask; - const int new_axis_mask, shrink_axis_mask; - }; - - struct StridedSliceDenseSpec { - const int dims; - int begin_mask; - int end_mask; - bool begin_valid; - bool end_valid; - std::vector& begin; - std::vector& end; - std::vector& strides; - std::vector final_shape_gather_indices; - int shrink_axis_mask; - - public: - bool buildDenseSpec(StridedSliceSparseSpec& sparse_spec) { - if (this->begin.size() < dims) - this->begin.resize(dims); - - if (this->end.size() < dims) - this->end.resize(dims); - - if (this->strides.size() < dims) - this->strides.resize(dims); - this->begin_mask = 0; - this->end_mask = 0; - this->shrink_axis_mask = 0; - { - int full_index = 0; - - this->begin_valid = sparse_spec.begin_tensor != nullptr; - this->end_valid = sparse_spec.end_tensor != nullptr; - - for (int e = 0; e < sparse_spec.dims; e++) { - if ((1 << e) & sparse_spec.ellipsis_mask) { - int next_index = sd::math::nd4j_min(this->dims - (sparse_spec.dims - e) + 1 + sparse_spec.num_add_axis_after_ellipsis, this->dims); - - for (; full_index < next_index; full_index++) { - // new_axis' aren't real axis so you have to skip - this->begin[full_index] = this->end[full_index] = 0; - this->strides[full_index] = 1; - this->begin_mask |= (1 << full_index); - this->end_mask |= (1 << full_index); - this->final_shape_gather_indices.push_back(full_index); - } - } else if ((1 << e) & sparse_spec.new_axis_mask) { - this->final_shape_gather_indices.emplace_back(kNewAxis); - } else { - if (full_index == this->begin.size()) { - nd4j_printf("Index out of range: %i out of %i\n", full_index, this->dims); - return false; - } - - // Gather slicing spec into appropriate index - if (sparse_spec.begin_tensor != nullptr) - this->begin[full_index] = sparse_spec.begin_tensor->at(e); - - - if (sparse_spec.end_tensor != nullptr) - this->end[full_index] = sparse_spec.end_tensor->at(e); - - this->strides[full_index] = sparse_spec.strides_tensor->at(e); - - if (sparse_spec.begin_mask & (1 << e)) - this->begin_mask |= (1 << full_index); - - - if (sparse_spec.end_mask & (1 << e)) - this->end_mask |= (1 << full_index); - - - // If shrink, record where to get the dimensionality from (i.e. - // new_axis creates a fake 1 size dimension. Also remember shrink - // axis (now in dense form) so we can ignore dense->end below. - if (sparse_spec.shrink_axis_mask & (1 << e)) { - this->final_shape_gather_indices.push_back(kShrinkAxis); - this->shrink_axis_mask |= (1 << full_index); - } else { - this->final_shape_gather_indices.push_back(full_index); - } - full_index++; - } - } - } - return true; - } - }; - - void vectorize(std::vector& input_shape) { - if (input_shape.size() == 2 && input_shape[0] == 1) { - int v = input_shape[1]; - input_shape.clear(); - input_shape.emplace_back(v); - } +namespace ops { + +constexpr int kShrinkAxis = -1, kNewAxis = -2; + +struct StridedSliceSparseSpec { + int dims; + int num_add_axis_after_ellipsis; + std::vector* begin_tensor; + const std::vector* end_tensor; + const std::vector* strides_tensor; + const int begin_mask, end_mask; + int ellipsis_mask; + const int new_axis_mask, shrink_axis_mask; +}; + +struct StridedSliceDenseSpec { + const int dims; + int begin_mask; + int end_mask; + bool begin_valid; + bool end_valid; + std::vector& begin; + std::vector& end; + std::vector& strides; + std::vector final_shape_gather_indices; + int shrink_axis_mask; + + public: + bool buildDenseSpec(StridedSliceSparseSpec& sparse_spec) { + if (this->begin.size() < dims) this->begin.resize(dims); + + if (this->end.size() < dims) this->end.resize(dims); + + if (this->strides.size() < dims) this->strides.resize(dims); + this->begin_mask = 0; + this->end_mask = 0; + this->shrink_axis_mask = 0; + { + int full_index = 0; + + this->begin_valid = sparse_spec.begin_tensor != nullptr; + this->end_valid = sparse_spec.end_tensor != nullptr; + + for (int e = 0; e < sparse_spec.dims; e++) { + if ((1 << e) & sparse_spec.ellipsis_mask) { + int next_index = sd::math::nd4j_min( + this->dims - (sparse_spec.dims - e) + 1 + + sparse_spec.num_add_axis_after_ellipsis, + this->dims); + + for (; full_index < next_index; full_index++) { + // new_axis' aren't real axis so you have to skip + this->begin[full_index] = this->end[full_index] = 0; + this->strides[full_index] = 1; + this->begin_mask |= (1 << full_index); + this->end_mask |= (1 << full_index); + this->final_shape_gather_indices.push_back(full_index); + } + } else if ((1 << e) & sparse_spec.new_axis_mask) { + this->final_shape_gather_indices.emplace_back(kNewAxis); + } else { + if (full_index == this->begin.size()) { + nd4j_printf("Index out of range: %i out of %i\n", full_index, + this->dims); + return false; + } + + // Gather slicing spec into appropriate index + if (sparse_spec.begin_tensor != nullptr) + this->begin[full_index] = sparse_spec.begin_tensor->at(e); + + if (sparse_spec.end_tensor != nullptr) + this->end[full_index] = sparse_spec.end_tensor->at(e); + + this->strides[full_index] = sparse_spec.strides_tensor->at(e); + + if (sparse_spec.begin_mask & (1 << e)) + this->begin_mask |= (1 << full_index); + + if (sparse_spec.end_mask & (1 << e)) + this->end_mask |= (1 << full_index); + + // If shrink, record where to get the dimensionality from (i.e. + // new_axis creates a fake 1 size dimension. Also remember shrink + // axis (now in dense form) so we can ignore dense->end below. + if (sparse_spec.shrink_axis_mask & (1 << e)) { + this->final_shape_gather_indices.push_back(kShrinkAxis); + this->shrink_axis_mask |= (1 << full_index); + } else { + this->final_shape_gather_indices.push_back(full_index); + } + full_index++; } + } + } + return true; + } +}; + +void vectorize(std::vector& input_shape) { + if (input_shape.size() == 2 && input_shape[0] == 1) { + int v = input_shape[1]; + input_shape.clear(); + input_shape.emplace_back(v); + } +} - bool _preprocess_strided_slice(std::vector* indicesList, std::vector* final_shape, std::vector& input_shape, std::vector& begin, std::vector& end, std::vector& strides, int begin_mask, int ellipsis_mask, int end_mask, int new_axis_mask, int shrink_axis_mask, bool* is_identity, bool* is_simple_slice, bool* slice_dim0) { - std::vector preshape; +bool _preprocess_strided_slice( + std::vector* indicesList, std::vector* final_shape, + std::vector& input_shape, std::vector& begin, + std::vector& end, std::vector& strides, int begin_mask, + int ellipsis_mask, int end_mask, int new_axis_mask, int shrink_axis_mask, + bool* is_identity, bool* is_simple_slice, bool* slice_dim0) { + std::vector preshape; - bool ellipsis_seen = false; + bool ellipsis_seen = false; - StridedSliceSparseSpec sparse_spec = {(int) strides.size(), + StridedSliceSparseSpec sparse_spec = {(int)strides.size(), 0, &begin, &end, @@ -147,530 +151,601 @@ namespace sd { new_axis_mask, shrink_axis_mask}; - for (int i = 0; i < sparse_spec.dims; i++) { - if (ellipsis_seen && ((1 << i) & new_axis_mask) != 0) { - sparse_spec.num_add_axis_after_ellipsis++; - } - if ((1 << i) & ellipsis_mask) { - ellipsis_seen = true; - } - } - // If no ellipsis insert one at the end - if (!ellipsis_seen) { - sparse_spec.ellipsis_mask |= (1 << sparse_spec.dims); - sparse_spec.dims++; // this effects loop iteration below - } - - StridedSliceDenseSpec dense_spec = {(int) input_shape.size(), 0, 0, false, false, begin, end, strides}; - if (!dense_spec.buildDenseSpec(sparse_spec)) - return false; - - //nd4j_printv("Input shape: ", input_shape); - - for (int e = 0; e < (int) input_shape.size(); e++) { - int begin_idx = begin[e]; - int end_idx = end[e]; - int stride_idx = strides[e]; - int size_idx = input_shape[e]; - - bool shrink_i = (dense_spec.shrink_axis_mask & (1 << e)); - - if (stride_idx == 0) { - nd4j_printf("Stride is 0 at index %i\n", e); - return false; - } - if (size_idx == -1) { - preshape.emplace_back(shrink_i ? 1 : -1); - continue; - } - - const std::array masks = {{dense_spec.begin_mask & (1 << e), dense_spec.end_mask & (1 << e)}}; - const std::array valid_range = {{stride_idx > 0 ? 0 : -1, stride_idx > 0 ? size_idx : size_idx - 1}}; - - auto canonical = [stride_idx, e, size_idx, masks, valid_range](int x, int c) { - if (masks[c]) { - return stride_idx > 0 ? valid_range[c] : valid_range[(c + 1) & 1]; - } else { - int x_fwd = x < 0 ? size_idx + x : x; // make negative indices positive - return x_fwd < valid_range[0] ? valid_range[0] : x_fwd > valid_range[1] ? valid_range[1] : x_fwd; - } - }; - - if (shrink_i && stride_idx <= 0) { - nd4j_printf("StridedSlice: only stride 1 allowed on non-range indexing\n", e); - return false; - } - - (*is_simple_slice) &= stride_idx == 1; - - const bool begin_and_end_masked = (begin_mask & (1 << e)) && (end_mask & (1 << e)); - - if (dense_spec.begin_valid && dense_spec.end_valid) { - if (shrink_i) { - int x_fwd = begin_idx < 0 ? size_idx + begin_idx : begin_idx; - begin_idx = x_fwd; - end_idx = begin_idx + 1; - if (x_fwd < 0 || x_fwd >= size_idx) { - nd4j_printf("slice index %i of dimension %i out of bounds.\n", begin_idx, e); - return false; - } - } else { - begin_idx = canonical(begin_idx, 0); - end_idx = canonical(end_idx, 1); - } - } else { - (*is_identity) &= stride_idx == 1 && begin_and_end_masked; - (*slice_dim0) &= (e == 0 && stride_idx == 1) || begin_and_end_masked; - } - - int interval_length = 1; - bool known_interval = false; - if (dense_spec.begin_valid && dense_spec.end_valid) { - interval_length = end_idx - begin_idx; - known_interval = true; - } else if (shrink_i) { - interval_length = 1; - known_interval = true; - } else if (begin_and_end_masked) { - if (size_idx > 0) { - if (stride_idx < 0) { - interval_length = -size_idx; - } else { - interval_length = size_idx; - } - - known_interval = true; - } - } - - if (known_interval) { - int size_i; - if (interval_length == 0 || ((interval_length < 0) != (stride_idx < 0))) { - size_i = input_shape.size() == 2 && input_shape[0] == 1? 1: 0; - } else { - size_i = interval_length / stride_idx + (interval_length % stride_idx != 0 ? 1 : 0); - } - - if (indicesList != nullptr) { - if (interval_length > 1) { - indicesList->push_back(begin_idx); - indicesList->push_back(end_idx); - indicesList->push_back(stride_idx); - // (*indicesList)[3*e] = begin_idx; - // (*indicesList)[3*e+1] = end_idx; - // (*indicesList)[3*e+2] = stride_idx; - } - else if (interval_length == 1) { - indicesList->push_back(begin_idx); - indicesList->push_back(begin_idx + 1); - indicesList->push_back(1); - // (*indicesList)[3*e] = begin_idx; - // (*indicesList)[3*e+1] = begin_idx + 1; - // (*indicesList)[3*e+2] = 1; - } - } - - preshape.emplace_back(size_i); - } else { - preshape.emplace_back(-1); - } - } - - - std::vector postshape; - //nd4j_printv("Preshape: ", preshape); - - final_shape->clear(); - for (auto gather_index : dense_spec.final_shape_gather_indices) { - if (gather_index >= 0) { - if (preshape.size() > gather_index) - final_shape->emplace_back(preshape.at(gather_index)); - else - final_shape->emplace_back(1); - } else if (gather_index == kNewAxis) { - final_shape->emplace_back(1); - } - } - - //nd4j_printv("Preshape: ", preshape); - //nd4j_printv("Postshape: ", *final_shape); - - return true; - } + for (int i = 0; i < sparse_spec.dims; i++) { + if (ellipsis_seen && ((1 << i) & new_axis_mask) != 0) { + sparse_spec.num_add_axis_after_ellipsis++; + } + if ((1 << i) & ellipsis_mask) { + ellipsis_seen = true; + } + } + // If no ellipsis insert one at the end + if (!ellipsis_seen) { + sparse_spec.ellipsis_mask |= (1 << sparse_spec.dims); + sparse_spec.dims++; // this effects loop iteration below + } + + StridedSliceDenseSpec dense_spec = { + (int)input_shape.size(), 0, 0, false, false, begin, end, strides}; + if (!dense_spec.buildDenseSpec(sparse_spec)) return false; + + // nd4j_printv("Input shape: ", input_shape); + + for (int e = 0; e < (int)input_shape.size(); e++) { + int begin_idx = begin[e]; + int end_idx = end[e]; + int stride_idx = strides[e]; + int size_idx = input_shape[e]; + + bool shrink_i = (dense_spec.shrink_axis_mask & (1 << e)); + + if (stride_idx == 0) { + nd4j_printf("Stride is 0 at index %i\n", e); + return false; + } + if (size_idx == -1) { + preshape.emplace_back(shrink_i ? 1 : -1); + continue; + } + + const std::array masks = { + {dense_spec.begin_mask & (1 << e), dense_spec.end_mask & (1 << e)}}; + const std::array valid_range = { + {stride_idx > 0 ? 0 : -1, stride_idx > 0 ? size_idx : size_idx - 1}}; + + auto canonical = [stride_idx, e, size_idx, masks, valid_range](int x, + int c) { + if (masks[c]) { + return stride_idx > 0 ? valid_range[c] : valid_range[(c + 1) & 1]; + } else { + int x_fwd = x < 0 ? size_idx + x : x; // make negative indices positive + return x_fwd < valid_range[0] + ? valid_range[0] + : x_fwd > valid_range[1] ? valid_range[1] : x_fwd; + } + }; + + if (shrink_i && stride_idx <= 0) { + nd4j_printf("StridedSlice: only stride 1 allowed on non-range indexing\n", + e); + return false; + } + (*is_simple_slice) &= stride_idx == 1; - CUSTOM_OP_IMPL(strided_slice, 1, 1, false, 0, 5) { - auto x = INPUT_VARIABLE(0); - auto z = OUTPUT_VARIABLE(0); - if (z->isEmpty()) { - return ND4J_STATUS_OK; - } - - int begin_mask = INT_ARG(0); - int ellipsis_mask = INT_ARG(1); - int end_mask = INT_ARG(2); - int new_axis_mask = INT_ARG(3); - int shrink_axis_mask = INT_ARG(4); - - int dim_values = 0; //block.getIArguments()->size() - 5; - int delta = 0; //dim_values % 3; - int elements = 0; //dim_values / 3; - - std::vector begin; - std::vector end; - std::vector strides; - - bool isLive = false; - - std::vector args; - - // statically evaluated - if (block.getIArguments()->size() > 5) { - dim_values = block.getIArguments()->size() - 5; - delta = dim_values % 3; - elements = dim_values / 3; - - for (int e = 5; e < block.getIArguments()->size(); e++) - args.emplace_back(INT_ARG(e)); - - REQUIRE_TRUE(delta == 0, 0, "StridedSlice: Number of Integer arguments should be equal to input rank x 3 = %i, but got %i instead", (x->rankOf() * 3), dim_values); - - ShapeUtils::copyVectorPart(begin, args, elements, 0); - ShapeUtils::copyVectorPart(end, args, elements, elements); - ShapeUtils::copyVectorPart(strides, args, elements, elements * 2); - - } else if (block.width() > 1) { - isLive = true; - - auto v_begin = INPUT_VARIABLE(1); - auto v_end = INPUT_VARIABLE(2); - - elements = v_begin->lengthOf(); - - REQUIRE_TRUE(v_begin->lengthOf() == v_end->lengthOf(), 0, "StridedSlice: Length of begin/end should match, but got %i vs %i instead", (int) v_begin->lengthOf(), (int) v_end->lengthOf()); - REQUIRE_TRUE((v_begin->rankOf() == 1 ) && (v_begin->rankOf() == v_end->rankOf()), 0, "StridedSlice: Rank of begin and ends should be 1, but %i given instead", (int)v_end->rankOf()); - - for (int e = 0; e < v_begin->lengthOf(); e++) - begin.emplace_back(v_begin->e(e)); - - for (int e = 0; e < v_end->lengthOf(); e++) - end.emplace_back(v_end->e(e)); - - if (block.width() > 3) { - auto v_stride = INPUT_VARIABLE(3); - - REQUIRE_TRUE(v_stride->lengthOf() == v_begin->lengthOf(), 0, "StridedSlice: Length of begin/end/stride should match, but got %i vs %i vs %i instead", (int) v_begin->lengthOf(), (int) v_end->lengthOf(), (int) v_stride->lengthOf()); - REQUIRE_TRUE((v_begin->rankOf() == v_stride->rankOf()), 0, "StridedSlice: Rank of begin and ends should be %i, but %i given instead", (int) v_begin->rankOf(), v_stride->rankOf()); - - for (int e = 0; e < v_stride->lengthOf(); e++) - strides.emplace_back(v_stride->e(e)); - } else { - for (int e = 0; e < v_begin->lengthOf(); e++) - strides.emplace_back(1); - } - } else { - REQUIRE_TRUE(false, 0, "StridedSlice: Can't find begin/end/stride information neither in IArguments or in input arrays"); - } - - // validation of begin and start - std::vector ignoreBegin = BitwiseUtils::valueBits(begin_mask); - std::vector ignoreEnd = BitwiseUtils::valueBits(end_mask); - std::vector addAxes = BitwiseUtils::valueBits(new_axis_mask); - std::vector moveAxes = BitwiseUtils::valueBits(shrink_axis_mask); - if (shrink_axis_mask == 0) - for (int dim = 0, b = 0, e = 0; dim < x->rankOf(); ++dim) { - - if(moveAxes[dim]) - continue; - - if(b < begin.size() && !ignoreBegin[b] && !addAxes[dim]) { - int first = strides[b] > 0 ? begin[b] : math::nd4j_abs(begin[b]) - 1; - REQUIRE_TRUE(first <= x->sizeAt(dim), 0, "StridedSlice: begin index should be <= corresponding dimension of input array, but got end_index = %i for dimension %i!", begin[b], dim); - } - if(e < end.size() && !ignoreEnd[e] && !addAxes[dim]) { - int last = strides[e] > 0 ? end[e] : math::nd4j_abs(end[e]) - 1; - REQUIRE_TRUE(last <= x->sizeAt(dim), 0, "StridedSlice: end index should be <= corresponding dimension of input array, but got end_index = %i for dimension %i!", end[e], dim); - } - ++b; - ++e; - } - - - std::vector indices; - auto input_shape = x->getShapeAsVector(); - std::vector final_shape; - bool is_identity; - bool is_simple_slice; - bool is_dim0; - - // FIXME: remove this method once we get 1D vectors supported - //vectorize(input_shape); - REQUIRE_TRUE(_preprocess_strided_slice(&indices, &final_shape, input_shape, begin, end, strides, begin_mask, ellipsis_mask, end_mask, new_axis_mask, shrink_axis_mask, &is_identity, &is_simple_slice, &is_dim0), 0, "StridedSlice: shape calculation failed"); -// if(z->lengthOf() == 1 && !z->isEmpty() && (input_shape.size() == 2 && input_shape[0] == 1)) { //(indices.size() == 6) && (indices[2] - indices[0] == 1)) { -// z->assign(x->e(indices[0])); -// } -// else { - if (indices.size()) { - Nd4jLong* subArrShapeInfo = nullptr; - ALLOCATE(subArrShapeInfo, block.getWorkspace(), shape::shapeInfoLength(x->rankOf()), Nd4jLong); - Nd4jLong offset; - - shape::calcSubArrShapeInfoAndOffset(indices.data(), x->shapeInfo(), subArrShapeInfo, offset, true, true); - auto subArrShapeInfoPack = ConstantShapeHelper::getInstance().bufferForShapeInfo(subArrShapeInfo); - - NDArray::prepareSpecialUse({z}, {x}); - - NativeOpExecutioner::execTransformAny(block.launchContext(), sd::transform::Assign, - x->bufferWithOffset(offset), subArrShapeInfoPack.primary(), - x->specialBufferWithOffset(offset), subArrShapeInfoPack.special(), - z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), - nullptr, nullptr, nullptr, true); - - NDArray::registerSpecialUse({z}, {x}); - - RELEASE(subArrShapeInfo, block.getWorkspace()); - } - else if (!z->isEmpty()){ - z->assign(x->e(0)); - } - return Status::OK(); + const bool begin_and_end_masked = + (begin_mask & (1 << e)) && (end_mask & (1 << e)); + + if (dense_spec.begin_valid && dense_spec.end_valid) { + if (shrink_i) { + int x_fwd = begin_idx < 0 ? size_idx + begin_idx : begin_idx; + begin_idx = x_fwd; + end_idx = begin_idx + 1; + if (x_fwd < 0 || x_fwd >= size_idx) { + nd4j_printf("slice index %i of dimension %i out of bounds.\n", + begin_idx, e); + return false; } - DECLARE_SYN(stridedslice, strided_slice); - - DECLARE_SHAPE_FN(strided_slice) { - auto inShape = inputShape->at(0); - - int begin_mask = INT_ARG(0); - int ellipsis_mask = INT_ARG(1); - int end_mask = INT_ARG(2); - int new_axis_mask = INT_ARG(3); - int shrink_axis_mask = INT_ARG(4); - - int x_rank = shape::rank(inShape); - - int dim_values = block.getIArguments()->size() - 5; - int delta = dim_values % 3; - int elements = dim_values / 3; - - - std::vector begin; - std::vector end; - std::vector strides; - - // if that's live - shape will be resolved in runtime - if (block.width() > 1) { - begin = INPUT_VARIABLE(1)->template asVectorT(); - end = INPUT_VARIABLE(2)->template asVectorT(); - strides = INPUT_VARIABLE(3)->template asVectorT(); - } else if (dim_values > 0) { - int delta2 = dim_values / x_rank; - - std::vector args; - for (int e = 5; e < block.getIArguments()->size(); e++) - args.emplace_back(INT_ARG(e)); - - // FIXME: propably template required here - ShapeUtils::copyVectorPart(begin, args, elements, 0); - ShapeUtils::copyVectorPart(end, args, elements, elements); - ShapeUtils::copyVectorPart(strides, args, elements, elements * 2); - } - - REQUIRE_TRUE(begin.size() > 0 && end.size() > 0 && strides.size() > 0, 0, "Strided_Slice: empty arguments"); - - // validation of begin and start - std::vector ignoreBegin = BitwiseUtils::valueBits(begin_mask); - std::vector ignoreEnd = BitwiseUtils::valueBits(end_mask); - std::vector addAxes = BitwiseUtils::valueBits(new_axis_mask); - std::vector moveAxes = BitwiseUtils::valueBits(shrink_axis_mask); - - //if (0 == shrink_axis_mask) - if (false) - for (int dim = 0, b = 0, e = 0; dim < x_rank; ++dim) { - - if(moveAxes[dim]) - continue; - - if(b < begin.size() && !ignoreBegin[b] && !addAxes[dim]) { - int first = strides[b] > 0 ? begin[b] : math::nd4j_abs(begin[b]) - 1; - REQUIRE_TRUE(first <= inShape[dim + 1], 0, "StridedSlice: begin index should be <= corresponding dimension of input array, but got end_index = %i for dimension %i!", begin[b], dim); - } - if(e < end.size() && !ignoreEnd[e] && !addAxes[dim]) { - int last = strides[e] > 0 ? end[e] : math::nd4j_abs(end[e]) - 1; - REQUIRE_TRUE(last <= inShape[dim + 1], 0, "StridedSlice: end index should be <= corresponding dimension of input array, but got end_index = %i for dimension %i!", end[e], dim); - } - ++b; - ++e; - } - - std::vector input_shape; //(shape::rank(inShape)); - auto inputLen = shape::length(inShape); - std::vector shape; - - auto rank = shape::rank(inShape); - auto shortShape = shape::shapeOf(inShape); - for (auto e = 0; e < rank; e++) - input_shape.emplace_back(shortShape[e]); - - bool is_identity; - bool is_simple_slice; - bool is_dim0; - - std::vector indices; - bool result = _preprocess_strided_slice(&indices, &shape, input_shape, begin, end, strides, begin_mask, ellipsis_mask, end_mask, new_axis_mask, shrink_axis_mask, &is_identity, &is_simple_slice, &is_dim0); - if (indices.size()) { - auto newShape = ConstantShapeHelper::getInstance().createShapeInfo(ArrayOptions::dataType(inShape), 'c', - shape); -// if (inputLen > 1) { -// newShape = ConstantShapeHelper::getInstance().createShapeInfo(ArrayOptions::dataType(inShape), 'c', -// shape); -// } else { -// newShape = ConstantShapeHelper::getInstance().scalarShapeInfo(ArrayOptions::dataType(inShape)); -// } - return SHAPELIST(newShape); - } - - return SHAPELIST(ConstantShapeHelper::getInstance().emptyShapeInfo(ArrayOptions::dataType(inShape))); + } else { + begin_idx = canonical(begin_idx, 0); + end_idx = canonical(end_idx, 1); + } + } else { + (*is_identity) &= stride_idx == 1 && begin_and_end_masked; + (*slice_dim0) &= (e == 0 && stride_idx == 1) || begin_and_end_masked; + } + + int interval_length = 1; + bool known_interval = false; + if (dense_spec.begin_valid && dense_spec.end_valid) { + interval_length = end_idx - begin_idx; + known_interval = true; + } else if (shrink_i) { + interval_length = 1; + known_interval = true; + } else if (begin_and_end_masked) { + if (size_idx > 0) { + if (stride_idx < 0) { + interval_length = -size_idx; + } else { + interval_length = size_idx; } + known_interval = true; + } + } - CUSTOM_OP_IMPL(strided_slice_bp, 2, 1, false, 0, 5) { - auto x = INPUT_VARIABLE(0); - auto epsNext = INPUT_VARIABLE(1); - auto output = OUTPUT_VARIABLE(0); - - int begin_mask = INT_ARG(0); - int ellipsis_mask = INT_ARG(1); - int end_mask = INT_ARG(2); - int new_axis_mask = INT_ARG(3); - int shrink_axis_mask = INT_ARG(4); - - int dim_values = 0; //block.getIArguments()->size() - 5; - int delta = 0; //dim_values % 3; - int elements = 0; //dim_values / 3; - - std::vector begin; - std::vector end; - std::vector strides; - - bool isLive = false; - - std::vector args; - - // statically evaluated - if (block.getIArguments()->size() > 5) { - dim_values = block.getIArguments()->size() - 5; - delta = dim_values % 3; - elements = dim_values / 3; - - for (int e = 5; e < block.getIArguments()->size(); e++) - args.emplace_back(INT_ARG(e)); - - REQUIRE_TRUE(delta == 0, 0, "StridedSliceBP: Number of Integer arguments should be equal to input rank x 3 = %i, but got %i instead", (x->rankOf() * 3), dim_values); - - ShapeUtils::copyVectorPart(begin, args, elements, 0); - ShapeUtils::copyVectorPart(end, args, elements, elements); - ShapeUtils::copyVectorPart(strides, args, elements, elements * 2); - - } else if (block.width() >= 3) { - isLive = true; - - auto v_begin = INPUT_VARIABLE(2); - auto v_end = INPUT_VARIABLE(3); - - elements = v_begin->lengthOf(); - - REQUIRE_TRUE(v_begin->lengthOf() == v_end->lengthOf(), 0, "StridedSliceBP: Length of begin/end should match, but got %i vs %i instead", (int) v_begin->lengthOf(), (int) v_end->lengthOf()); - - for (int e = 0; e < v_begin->lengthOf(); e++) - begin.emplace_back(v_begin->e(e)); - - for (int e = 0; e < v_end->lengthOf(); e++) - end.emplace_back(v_end->e(e)); - - if (block.width() >= 4) { - auto v_stride = INPUT_VARIABLE(4); - - REQUIRE_TRUE(v_stride->lengthOf() == v_begin->lengthOf(), 0, "StridedSliceBP: Length of begin/end/stride should match, but got %i vs %i vs %i instead", (int) v_begin->lengthOf(), (int) v_end->lengthOf(), (int) v_stride->lengthOf()); - - for (int e = 0; e < v_stride->lengthOf(); e++) - strides.emplace_back(v_stride->e(e)); - } else { - for (int e = 0; e < v_begin->lengthOf(); e++) - strides.emplace_back(1); - } - } else { - REQUIRE_TRUE(false, 0, "StridedSliceBP: Can't find begin/end/stride information neither in IArguments or in input arrays"); - } - - // validation of begin and start - std::vector ignoreBegin = BitwiseUtils::valueBits(begin_mask); - std::vector ignoreEnd = BitwiseUtils::valueBits(end_mask); - std::vector addAxes = BitwiseUtils::valueBits(new_axis_mask); - std::vector moveAxes = BitwiseUtils::valueBits(shrink_axis_mask); - - for (int dim = 0, b = 0, e = 0; dim < x->rankOf(); ++dim) { - - if(moveAxes[dim]) - continue; - - if(b < begin.size() && !ignoreBegin[b] && !addAxes[dim]) { - int first = strides[b] > 0 ? begin[b] : math::nd4j_abs(begin[b]) - 1; - REQUIRE_TRUE(first <= x->sizeAt(dim), 0, "StridedSlice: begin index should be <= corresponding dimension of input array, but got end_index = %i for dimension %i!", begin[b], dim); - } - if(e < end.size() && !ignoreEnd[e] && !addAxes[dim]) { - int last = strides[e] > 0 ? end[e] : math::nd4j_abs(end[e]) - 1; - REQUIRE_TRUE(last <= x->sizeAt(dim), 0, "StridedSlice: end index should be <= corresponding dimension of input array, but got end_index = %i for dimension %i!", end[e], dim); - } - ++b; - ++e; - } - - auto input_shape = x->getShapeAsVector(); - std::vector indices; - std::vector final_shape; - bool is_identity; - bool is_simple_slice; - bool is_dim0; - - // FIXME: remove this method once we get 1D vectors supported - vectorize(input_shape); - REQUIRE_TRUE(_preprocess_strided_slice(&indices, &final_shape, input_shape, begin, end, strides, begin_mask, ellipsis_mask, end_mask, new_axis_mask, shrink_axis_mask, &is_identity, &is_simple_slice, &is_dim0), 0, "StridedSliceBP: shape calculation failed"); - //REQUIRE_TRUE(epsNext->isSameShape(final_shape), 0, "StridedSlice_bp: gradOut shape should be equals to output from strided_slice op."); - //Zero output array, so unused elements have 0 gradient - output->nullify(); - // - // the first case: only for scalar gradient step - if(epsNext->lengthOf() == 1 && (indices.size() == 3 && (indices[1] - indices[0]) == 1 || (indices[2] - indices[0] == 1))) { - output->p(indices[0], *epsNext); - } - else { // else for other cases - auto sub = (*output)(indices, true, true); - sub.assign(epsNext); - } - - return Status::OK(); + if (known_interval) { + int size_i; + if (interval_length == 0 || ((interval_length < 0) != (stride_idx < 0))) { + size_i = input_shape.size() == 2 && input_shape[0] == 1 ? 1 : 0; + } else { + size_i = interval_length / stride_idx + + (interval_length % stride_idx != 0 ? 1 : 0); + } + + if (indicesList != nullptr) { + if (interval_length > 1) { + indicesList->push_back(begin_idx); + indicesList->push_back(end_idx); + indicesList->push_back(stride_idx); + // (*indicesList)[3*e] = begin_idx; + // (*indicesList)[3*e+1] = end_idx; + // (*indicesList)[3*e+2] = stride_idx; + } else if (interval_length == 1) { + indicesList->push_back(begin_idx); + indicesList->push_back(begin_idx + 1); + indicesList->push_back(1); + // (*indicesList)[3*e] = begin_idx; + // (*indicesList)[3*e+1] = begin_idx + 1; + // (*indicesList)[3*e+2] = 1; } + } - DECLARE_SHAPE_FN(strided_slice_bp) { - auto inShape = inputShape->at(0); - Nd4jLong *newShape; - COPY_SHAPE(inShape, newShape); + preshape.emplace_back(size_i); + } else { + preshape.emplace_back(-1); + } + } + + std::vector postshape; + // nd4j_printv("Preshape: ", preshape); + + final_shape->clear(); + for (auto gather_index : dense_spec.final_shape_gather_indices) { + if (gather_index >= 0) { + if (preshape.size() > gather_index) + final_shape->emplace_back(preshape.at(gather_index)); + else + final_shape->emplace_back(1); + } else if (gather_index == kNewAxis) { + final_shape->emplace_back(1); + } + } - return SHAPELIST(CONSTANT(newShape)); - } + // nd4j_printv("Preshape: ", preshape); + // nd4j_printv("Postshape: ", *final_shape); - DECLARE_TYPES(strided_slice) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setSameMode(true); - } + return true; +} - DECLARE_TYPES(strided_slice_bp) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } +CUSTOM_OP_IMPL(strided_slice, 1, 1, false, 0, 5) { + auto x = INPUT_VARIABLE(0); + auto z = OUTPUT_VARIABLE(0); + if (z->isEmpty()) { + return ND4J_STATUS_OK; + } + + int begin_mask = INT_ARG(0); + int ellipsis_mask = INT_ARG(1); + int end_mask = INT_ARG(2); + int new_axis_mask = INT_ARG(3); + int shrink_axis_mask = INT_ARG(4); + + int dim_values = 0; // block.getIArguments()->size() - 5; + int delta = 0; // dim_values % 3; + int elements = 0; // dim_values / 3; + + std::vector begin; + std::vector end; + std::vector strides; + + bool isLive = false; + + std::vector args; + + // statically evaluated + if (block.numI() > 5) { + dim_values = block.numI() - 5; + delta = dim_values % 3; + elements = dim_values / 3; + + for (int e = 5; e < block.numI(); e++) args.emplace_back(INT_ARG(e)); + + REQUIRE_TRUE(delta == 0, 0, + "StridedSlice: Number of Integer arguments should be equal to " + "input rank x 3 = %i, but got %i instead", + (x->rankOf() * 3), dim_values); + + ShapeUtils::copyVectorPart(begin, args, elements, 0); + ShapeUtils::copyVectorPart(end, args, elements, elements); + ShapeUtils::copyVectorPart(strides, args, elements, elements * 2); + + } else if (block.width() > 1) { + isLive = true; + + auto v_begin = INPUT_VARIABLE(1); + auto v_end = INPUT_VARIABLE(2); + + elements = v_begin->lengthOf(); + + REQUIRE_TRUE(v_begin->lengthOf() == v_end->lengthOf(), 0, + "StridedSlice: Length of begin/end should match, but got %i " + "vs %i instead", + (int)v_begin->lengthOf(), (int)v_end->lengthOf()); + REQUIRE_TRUE( + (v_begin->rankOf() == 1) && (v_begin->rankOf() == v_end->rankOf()), 0, + "StridedSlice: Rank of begin and ends should be 1, but %i given " + "instead", + (int)v_end->rankOf()); + + for (int e = 0; e < v_begin->lengthOf(); e++) + begin.emplace_back(v_begin->e(e)); + + for (int e = 0; e < v_end->lengthOf(); e++) + end.emplace_back(v_end->e(e)); + + if (block.width() > 3) { + auto v_stride = INPUT_VARIABLE(3); + + REQUIRE_TRUE(v_stride->lengthOf() == v_begin->lengthOf(), 0, + "StridedSlice: Length of begin/end/stride should match, but " + "got %i vs %i vs %i instead", + (int)v_begin->lengthOf(), (int)v_end->lengthOf(), + (int)v_stride->lengthOf()); + REQUIRE_TRUE((v_begin->rankOf() == v_stride->rankOf()), 0, + "StridedSlice: Rank of begin and ends should be %i, but %i " + "given instead", + (int)v_begin->rankOf(), v_stride->rankOf()); + + for (int e = 0; e < v_stride->lengthOf(); e++) + strides.emplace_back(v_stride->e(e)); + } else { + for (int e = 0; e < v_begin->lengthOf(); e++) strides.emplace_back(1); } + } else { + REQUIRE_TRUE(false, 0, + "StridedSlice: Can't find begin/end/stride information " + "neither in IArguments or in input arrays"); + } + + // validation of begin and start + std::vector ignoreBegin = BitwiseUtils::valueBits(begin_mask); + std::vector ignoreEnd = BitwiseUtils::valueBits(end_mask); + std::vector addAxes = BitwiseUtils::valueBits(new_axis_mask); + std::vector moveAxes = BitwiseUtils::valueBits(shrink_axis_mask); + if (shrink_axis_mask == 0) + for (int dim = 0, b = 0, e = 0; dim < x->rankOf(); ++dim) { + if (moveAxes[dim]) continue; + + if (b < begin.size() && !ignoreBegin[b] && !addAxes[dim]) { + int first = + strides[b] > 0 ? begin[b] : math::nd4j_abs(begin[b]) - 1; + REQUIRE_TRUE( + first <= x->sizeAt(dim), 0, + "StridedSlice: begin index should be <= corresponding dimension of " + "input array, but got end_index = %i for dimension %i!", + begin[b], dim); + } + if (e < end.size() && !ignoreEnd[e] && !addAxes[dim]) { + int last = strides[e] > 0 ? end[e] : math::nd4j_abs(end[e]) - 1; + REQUIRE_TRUE( + last <= x->sizeAt(dim), 0, + "StridedSlice: end index should be <= corresponding dimension of " + "input array, but got end_index = %i for dimension %i!", + end[e], dim); + } + ++b; + ++e; + } + + std::vector indices; + auto input_shape = x->getShapeAsVector(); + std::vector final_shape; + bool is_identity; + bool is_simple_slice; + bool is_dim0; + + // FIXME: remove this method once we get 1D vectors supported + // vectorize(input_shape); + REQUIRE_TRUE(_preprocess_strided_slice( + &indices, &final_shape, input_shape, begin, end, strides, + begin_mask, ellipsis_mask, end_mask, new_axis_mask, + shrink_axis_mask, &is_identity, &is_simple_slice, &is_dim0), + 0, "StridedSlice: shape calculation failed"); + // if(z->lengthOf() == 1 && !z->isEmpty() && (input_shape.size() == + // 2 && input_shape[0] == 1)) { //(indices.size() == 6) && + // (indices[2] - indices[0] == 1)) { + // z->assign(x->e(indices[0])); + // } + // else { + if (indices.size()) { + Nd4jLong* subArrShapeInfo = nullptr; + ALLOCATE(subArrShapeInfo, block.workspace(), + shape::shapeInfoLength(x->rankOf()), Nd4jLong); + Nd4jLong offset; + + shape::calcSubArrShapeInfoAndOffset(indices.data(), x->shapeInfo(), + subArrShapeInfo, offset, true, true); + auto subArrShapeInfoPack = + ConstantShapeHelper::getInstance().bufferForShapeInfo(subArrShapeInfo); + + NDArray::prepareSpecialUse({z}, {x}); + + NativeOpExecutioner::execTransformAny( + block.launchContext(), sd::transform::Assign, + x->bufferWithOffset(offset), + subArrShapeInfoPack.primary(), + x->specialBufferWithOffset(offset), + subArrShapeInfoPack.special(), z->buffer(), + z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), nullptr, + nullptr, nullptr, true); + + NDArray::registerSpecialUse({z}, {x}); + + RELEASE(subArrShapeInfo, block.workspace()); + } else if (!z->isEmpty()) { + z->assign(x->e(0)); + } + return Status::OK(); +} +DECLARE_SYN(stridedslice, strided_slice); + +DECLARE_SHAPE_FN(strided_slice) { + auto inShape = inputShape->at(0); + + int begin_mask = INT_ARG(0); + int ellipsis_mask = INT_ARG(1); + int end_mask = INT_ARG(2); + int new_axis_mask = INT_ARG(3); + int shrink_axis_mask = INT_ARG(4); + + int x_rank = shape::rank(inShape); + + int dim_values = block.numI() - 5; + int delta = dim_values % 3; + int elements = dim_values / 3; + + std::vector begin; + std::vector end; + std::vector strides; + + // if that's live - shape will be resolved in runtime + if (block.width() > 1) { + begin = INPUT_VARIABLE(1)->template asVectorT(); + end = INPUT_VARIABLE(2)->template asVectorT(); + strides = INPUT_VARIABLE(3)->template asVectorT(); + } else if (dim_values > 0) { + int delta2 = dim_values / x_rank; + + std::vector args; + for (int e = 5; e < block.numI(); e++) args.emplace_back(INT_ARG(e)); + + // FIXME: propably template required here + ShapeUtils::copyVectorPart(begin, args, elements, 0); + ShapeUtils::copyVectorPart(end, args, elements, elements); + ShapeUtils::copyVectorPart(strides, args, elements, elements * 2); + } + + REQUIRE_TRUE(begin.size() > 0 && end.size() > 0 && strides.size() > 0, 0, + "Strided_Slice: empty arguments"); + + // validation of begin and start + std::vector ignoreBegin = BitwiseUtils::valueBits(begin_mask); + std::vector ignoreEnd = BitwiseUtils::valueBits(end_mask); + std::vector addAxes = BitwiseUtils::valueBits(new_axis_mask); + std::vector moveAxes = BitwiseUtils::valueBits(shrink_axis_mask); + + // if (0 == shrink_axis_mask) + if (false) + for (int dim = 0, b = 0, e = 0; dim < x_rank; ++dim) { + if (moveAxes[dim]) continue; + + if (b < begin.size() && !ignoreBegin[b] && !addAxes[dim]) { + int first = + strides[b] > 0 ? begin[b] : math::nd4j_abs(begin[b]) - 1; + REQUIRE_TRUE( + first <= inShape[dim + 1], 0, + "StridedSlice: begin index should be <= corresponding dimension of " + "input array, but got end_index = %i for dimension %i!", + begin[b], dim); + } + if (e < end.size() && !ignoreEnd[e] && !addAxes[dim]) { + int last = strides[e] > 0 ? end[e] : math::nd4j_abs(end[e]) - 1; + REQUIRE_TRUE( + last <= inShape[dim + 1], 0, + "StridedSlice: end index should be <= corresponding dimension of " + "input array, but got end_index = %i for dimension %i!", + end[e], dim); + } + ++b; + ++e; + } + + std::vector input_shape; //(shape::rank(inShape)); + auto inputLen = shape::length(inShape); + std::vector shape; + + auto rank = shape::rank(inShape); + auto shortShape = shape::shapeOf(inShape); + for (auto e = 0; e < rank; e++) input_shape.emplace_back(shortShape[e]); + + bool is_identity; + bool is_simple_slice; + bool is_dim0; + + std::vector indices; + bool result = _preprocess_strided_slice( + &indices, &shape, input_shape, begin, end, strides, begin_mask, + ellipsis_mask, end_mask, new_axis_mask, shrink_axis_mask, &is_identity, + &is_simple_slice, &is_dim0); + if (indices.size()) { + auto newShape = ConstantShapeHelper::getInstance().createShapeInfo( + ArrayOptions::dataType(inShape), 'c', shape); + // if (inputLen > 1) { + // newShape = + // ConstantShapeHelper::getInstance().createShapeInfo(ArrayOptions::dataType(inShape), + // 'c', + // shape); + // } else { + // newShape = + // ConstantShapeHelper::getInstance().scalarShapeInfo(ArrayOptions::dataType(inShape)); + // } + return SHAPELIST(newShape); + } + + return SHAPELIST(ConstantShapeHelper::getInstance().emptyShapeInfo( + ArrayOptions::dataType(inShape))); +} + +CUSTOM_OP_IMPL(strided_slice_bp, 2, 1, false, 0, 5) { + auto x = INPUT_VARIABLE(0); + auto epsNext = INPUT_VARIABLE(1); + auto output = OUTPUT_VARIABLE(0); + + int begin_mask = INT_ARG(0); + int ellipsis_mask = INT_ARG(1); + int end_mask = INT_ARG(2); + int new_axis_mask = INT_ARG(3); + int shrink_axis_mask = INT_ARG(4); + + int dim_values = 0; // block.getIArguments()->size() - 5; + int delta = 0; // dim_values % 3; + int elements = 0; // dim_values / 3; + + std::vector begin; + std::vector end; + std::vector strides; + + bool isLive = false; + + std::vector args; + + // statically evaluated + if (block.numI() > 5) { + dim_values = block.numI() - 5; + delta = dim_values % 3; + elements = dim_values / 3; + + for (int e = 5; e < block.numI(); e++) args.emplace_back(INT_ARG(e)); + + REQUIRE_TRUE(delta == 0, 0, + "StridedSliceBP: Number of Integer arguments should be equal " + "to input rank x 3 = %i, but got %i instead", + (x->rankOf() * 3), dim_values); + + ShapeUtils::copyVectorPart(begin, args, elements, 0); + ShapeUtils::copyVectorPart(end, args, elements, elements); + ShapeUtils::copyVectorPart(strides, args, elements, elements * 2); + + } else if (block.width() >= 3) { + isLive = true; + + auto v_begin = INPUT_VARIABLE(2); + auto v_end = INPUT_VARIABLE(3); + + elements = v_begin->lengthOf(); + + REQUIRE_TRUE(v_begin->lengthOf() == v_end->lengthOf(), 0, + "StridedSliceBP: Length of begin/end should match, but got %i " + "vs %i instead", + (int)v_begin->lengthOf(), (int)v_end->lengthOf()); + + for (int e = 0; e < v_begin->lengthOf(); e++) + begin.emplace_back(v_begin->e(e)); + + for (int e = 0; e < v_end->lengthOf(); e++) + end.emplace_back(v_end->e(e)); + + if (block.width() >= 4) { + auto v_stride = INPUT_VARIABLE(4); + + REQUIRE_TRUE(v_stride->lengthOf() == v_begin->lengthOf(), 0, + "StridedSliceBP: Length of begin/end/stride should match, " + "but got %i vs %i vs %i instead", + (int)v_begin->lengthOf(), (int)v_end->lengthOf(), + (int)v_stride->lengthOf()); + + for (int e = 0; e < v_stride->lengthOf(); e++) + strides.emplace_back(v_stride->e(e)); + } else { + for (int e = 0; e < v_begin->lengthOf(); e++) strides.emplace_back(1); + } + } else { + REQUIRE_TRUE(false, 0, + "StridedSliceBP: Can't find begin/end/stride information " + "neither in IArguments or in input arrays"); + } + + // validation of begin and start + std::vector ignoreBegin = BitwiseUtils::valueBits(begin_mask); + std::vector ignoreEnd = BitwiseUtils::valueBits(end_mask); + std::vector addAxes = BitwiseUtils::valueBits(new_axis_mask); + std::vector moveAxes = BitwiseUtils::valueBits(shrink_axis_mask); + + for (int dim = 0, b = 0, e = 0; dim < x->rankOf(); ++dim) { + if (moveAxes[dim]) continue; + + if (b < begin.size() && !ignoreBegin[b] && !addAxes[dim]) { + int first = strides[b] > 0 ? begin[b] : math::nd4j_abs(begin[b]) - 1; + REQUIRE_TRUE( + first <= x->sizeAt(dim), 0, + "StridedSlice: begin index should be <= corresponding dimension of " + "input array, but got end_index = %i for dimension %i!", + begin[b], dim); + } + if (e < end.size() && !ignoreEnd[e] && !addAxes[dim]) { + int last = strides[e] > 0 ? end[e] : math::nd4j_abs(end[e]) - 1; + REQUIRE_TRUE( + last <= x->sizeAt(dim), 0, + "StridedSlice: end index should be <= corresponding dimension of " + "input array, but got end_index = %i for dimension %i!", + end[e], dim); + } + ++b; + ++e; + } + + auto input_shape = x->getShapeAsVector(); + std::vector indices; + std::vector final_shape; + bool is_identity; + bool is_simple_slice; + bool is_dim0; + + // FIXME: remove this method once we get 1D vectors supported + vectorize(input_shape); + REQUIRE_TRUE(_preprocess_strided_slice( + &indices, &final_shape, input_shape, begin, end, strides, + begin_mask, ellipsis_mask, end_mask, new_axis_mask, + shrink_axis_mask, &is_identity, &is_simple_slice, &is_dim0), + 0, "StridedSliceBP: shape calculation failed"); + // REQUIRE_TRUE(epsNext->isSameShape(final_shape), 0, "StridedSlice_bp: + // gradOut shape should be equals to output from strided_slice op."); Zero + // output array, so unused elements have 0 gradient + output->nullify(); + // + // the first case: only for scalar gradient step + if (epsNext->lengthOf() == 1 && + (indices.size() == 3 && (indices[1] - indices[0]) == 1 || + (indices[2] - indices[0] == 1))) { + output->p(indices[0], *epsNext); + } else { // else for other cases + auto sub = (*output)(indices, true, true); + sub.assign(epsNext); + } + + return Status::OK(); +} + +DECLARE_SHAPE_FN(strided_slice_bp) { + auto inShape = inputShape->at(0); + Nd4jLong* newShape; + COPY_SHAPE(inShape, newShape); + + return SHAPELIST(CONSTANT(newShape)); +} + +DECLARE_TYPES(strided_slice) { + getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setSameMode(true); +} + +DECLARE_TYPES(strided_slice_bp) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/generic/tensor/zeros_as.cpp b/libnd4j/include/ops/declarable/generic/tensor/zeros_as.cpp index 7935c567e541..e7c27f784fde 100644 --- a/libnd4j/include/ops/declarable/generic/tensor/zeros_as.cpp +++ b/libnd4j/include/ops/declarable/generic/tensor/zeros_as.cpp @@ -24,33 +24,33 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(zeros_as, 1, 1, false, 0, 0) { - auto out = OUTPUT_VARIABLE(0); +namespace ops { +CUSTOM_OP_IMPL(zeros_as, 1, 1, false, 0, 0) { + auto out = OUTPUT_VARIABLE(0); - out->assign(0); // output is filled by zero by default - - return Status::OK(); - } - DECLARE_SYN(zeroslike, zeros_as); - DECLARE_SYN(zeros_like, zeros_as); + out->assign(0); // output is filled by zero by default + return Status::OK(); +} +DECLARE_SYN(zeroslike, zeros_as); +DECLARE_SYN(zeros_like, zeros_as); - DECLARE_SHAPE_FN(zeros_as) { - auto in = inputShape->at(0); - auto dtype = block.numD() ? D_ARG(0) : ArrayOptions::dataType(in); - auto shape = sd::ConstantShapeHelper::getInstance().createShapeInfo(dtype, in); +DECLARE_SHAPE_FN(zeros_as) { + auto in = inputShape->at(0); + auto dtype = block.numD() ? D_ARG(0) : ArrayOptions::dataType(in); + auto shape = + sd::ConstantShapeHelper::getInstance().createShapeInfo(dtype, in); - return SHAPELIST(shape); - } + return SHAPELIST(shape); +} - DECLARE_TYPES(zeros_as) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes(sd::DataType::ANY) - ->setSameMode(false); - } - } +DECLARE_TYPES(zeros_as) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes(sd::DataType::ANY) + ->setSameMode(false); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/tests/noop.cpp b/libnd4j/include/ops/declarable/generic/tests/noop.cpp index 37980c7e6a44..7f69b7aa6e2d 100644 --- a/libnd4j/include/ops/declarable/generic/tests/noop.cpp +++ b/libnd4j/include/ops/declarable/generic/tests/noop.cpp @@ -24,18 +24,16 @@ #include namespace sd { - namespace ops { - OP_IMPL(noop, -2, -2, true) { - // Fastest op ever. - return Status::OK(); - } +namespace ops { +OP_IMPL(noop, -2, -2, true) { + // Fastest op ever. + return Status::OK(); +} - DECLARE_TYPES(noop) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setSameMode(true); - } - } +DECLARE_TYPES(noop) { + getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setSameMode(true); } +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/generic/tests/test_output_reshape.cpp b/libnd4j/include/ops/declarable/generic/tests/test_output_reshape.cpp index 7ded29e2019a..2323ee4eda0e 100644 --- a/libnd4j/include/ops/declarable/generic/tests/test_output_reshape.cpp +++ b/libnd4j/include/ops/declarable/generic/tests/test_output_reshape.cpp @@ -24,25 +24,22 @@ #include namespace sd { - namespace ops { - OP_IMPL(test_output_reshape, 1, 1, true) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); +namespace ops { +OP_IMPL(test_output_reshape, 1, 1, true) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - if (!block.isInplace()) - output->assign(input); + if (!block.isInplace()) output->assign(input); - output->reshapei({-1}); + output->reshapei({-1}); - return Status::OK(); - } + return Status::OK(); +} - DECLARE_TYPES(test_output_reshape) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setSameMode(true); - } - } +DECLARE_TYPES(test_output_reshape) { + getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setSameMode(true); } +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/generic/tests/test_scalar.cpp b/libnd4j/include/ops/declarable/generic/tests/test_scalar.cpp index e67122b05ba5..e066b1d0dd69 100644 --- a/libnd4j/include/ops/declarable/generic/tests/test_scalar.cpp +++ b/libnd4j/include/ops/declarable/generic/tests/test_scalar.cpp @@ -24,43 +24,43 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(test_scalar, 1, 1, false, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); +namespace ops { +CUSTOM_OP_IMPL(test_scalar, 1, 1, false, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - double val = input->e(0) + 2.0; - output->p(0, val); + double val = input->e(0) + 2.0; + output->p(0, val); - return Status::OK(); - } + return Status::OK(); +} - DECLARE_SHAPE_FN(test_scalar) { - Nd4jLong *newShape; - ALLOCATE(newShape, block.workspace(), shape::shapeInfoLength(2), Nd4jLong); +DECLARE_SHAPE_FN(test_scalar) { + Nd4jLong *newShape; + ALLOCATE(newShape, block.workspace(), shape::shapeInfoLength(2), Nd4jLong); - newShape[0] = 2; - newShape[1] = 1; - newShape[2] = 1; - newShape[3] = 1; - newShape[4] = 1; - newShape[5] = 0; - newShape[6] = 1; - newShape[7] = 99; + newShape[0] = 2; + newShape[1] = 1; + newShape[2] = 1; + newShape[3] = 1; + newShape[4] = 1; + newShape[5] = 0; + newShape[6] = 1; + newShape[7] = 99; - ArrayOptions::setDataType(newShape, ArrayOptions::dataType(inputShape->at(0))); + ArrayOptions::setDataType(newShape, + ArrayOptions::dataType(inputShape->at(0))); - auto shape = ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(newShape)); - RELEASE(newShape, block.getWorkspace()); - return SHAPELIST(shape); - } + auto shape = ConstantShapeHelper::getInstance().createShapeInfo( + ShapeDescriptor(newShape)); + RELEASE(newShape, block.workspace()); + return SHAPELIST(shape); +} - DECLARE_TYPES(test_scalar) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setSameMode(true); - } - } +DECLARE_TYPES(test_scalar) { + getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setSameMode(true); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/tests/testcustom.cpp b/libnd4j/include/ops/declarable/generic/tests/testcustom.cpp index e8d7fc6c3c29..65085aa393ad 100644 --- a/libnd4j/include/ops/declarable/generic/tests/testcustom.cpp +++ b/libnd4j/include/ops/declarable/generic/tests/testcustom.cpp @@ -24,32 +24,32 @@ #include namespace sd { - namespace ops { - ////////////////////////////////////////////////////////////////////////// - CUSTOM_OP_IMPL(testcustom, 1, 1, false, 0, -1) { - auto z = this->getZ(block); - - STORE_RESULT(*z); - return Status::OK(); - } - DECLARE_SHAPE_FN(testcustom) { - // this test op will just return back original shape doubled - Nd4jLong *shapeOf; - ALLOCATE(shapeOf, block.getWorkspace(), shape::rank(inputShape->at(0)), Nd4jLong); - for (int e = 0; e < shape::rank(inputShape->at(0)); e++) - shapeOf[e] = inputShape->at(0)[e+1] * 2; - - auto newShape = ConstantShapeHelper::getInstance().createShapeInfo(block.dataType(), 'c', shape::rank(inputShape->at(0)), shapeOf); - RELEASE(shapeOf, block.getWorkspace()); - return SHAPELIST(newShape); - } - - DECLARE_TYPES(testcustom) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setSameMode(true); - } - } +namespace ops { +////////////////////////////////////////////////////////////////////////// +CUSTOM_OP_IMPL(testcustom, 1, 1, false, 0, -1) { + auto z = this->getZ(block); + + STORE_RESULT(*z); + return Status::OK(); +} +DECLARE_SHAPE_FN(testcustom) { + // this test op will just return back original shape doubled + Nd4jLong *shapeOf; + ALLOCATE(shapeOf, block.workspace(), shape::rank(inputShape->at(0)), + Nd4jLong); + for (int e = 0; e < shape::rank(inputShape->at(0)); e++) + shapeOf[e] = inputShape->at(0)[e + 1] * 2; + + auto newShape = ConstantShapeHelper::getInstance().createShapeInfo( + DataType::FLOAT32, 'c', shape::rank(inputShape->at(0)), shapeOf); + RELEASE(shapeOf, block.workspace()); + return SHAPELIST(newShape); +} + +DECLARE_TYPES(testcustom) { + getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setSameMode(true); } +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/generic/tests/testop2i2o.cpp b/libnd4j/include/ops/declarable/generic/tests/testop2i2o.cpp index f4d4d3159d3a..9546be7c3874 100644 --- a/libnd4j/include/ops/declarable/generic/tests/testop2i2o.cpp +++ b/libnd4j/include/ops/declarable/generic/tests/testop2i2o.cpp @@ -24,32 +24,30 @@ #include namespace sd { - namespace ops { - ////////////////////////////////////////////////////////////////////////// - // test op, non-divergent - OP_IMPL(testop2i2o, 2, 2, true) { - //nd4j_printf("CPU op used!\n",""); - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - - auto xO = OUTPUT_VARIABLE(0); - auto yO = OUTPUT_VARIABLE(1); - - x->applyScalar(scalar::Add, 1.0, *xO); - y->applyScalar(scalar::Add, 2.0, *yO); - - STORE_2_RESULTS(*xO, *yO); - - return Status::OK(); - } - DECLARE_SYN(TestOp2i2o, testop2i2o); - - DECLARE_TYPES(testop2i2o) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setSameMode(true); - } - } +namespace ops { +////////////////////////////////////////////////////////////////////////// +// test op, non-divergent +OP_IMPL(testop2i2o, 2, 2, true) { + // nd4j_printf("CPU op used!\n",""); + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + + auto xO = OUTPUT_VARIABLE(0); + auto yO = OUTPUT_VARIABLE(1); + + x->applyScalar(scalar::Add, 1.0, *xO); + y->applyScalar(scalar::Add, 2.0, *yO); + + STORE_2_RESULTS(*xO, *yO); + + return Status::OK(); +} +DECLARE_SYN(TestOp2i2o, testop2i2o); + +DECLARE_TYPES(testop2i2o) { + getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setSameMode(true); } +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/generic/tests/testreduction.cpp b/libnd4j/include/ops/declarable/generic/tests/testreduction.cpp index a0749ed7f6e0..a85dbf906219 100644 --- a/libnd4j/include/ops/declarable/generic/tests/testreduction.cpp +++ b/libnd4j/include/ops/declarable/generic/tests/testreduction.cpp @@ -24,20 +24,18 @@ #include namespace sd { - namespace ops { - REDUCTION_OP_IMPL(testreduction, 1, 1, false, 0, -1) { - auto z = OUTPUT_VARIABLE(0); +namespace ops { +REDUCTION_OP_IMPL(testreduction, 1, 1, false, 0, -1) { + auto z = OUTPUT_VARIABLE(0); -// STORE_RESULT(*z); - return Status::OK(); - } + // STORE_RESULT(*z); + return Status::OK(); +} - DECLARE_TYPES(testreduction) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setSameMode(true); - } - } +DECLARE_TYPES(testreduction) { + getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setSameMode(true); } +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/generic/thrid_party/firas_sparse.cpp b/libnd4j/include/ops/declarable/generic/thrid_party/firas_sparse.cpp index 3a115b8db286..a650ac3ba020 100644 --- a/libnd4j/include/ops/declarable/generic/thrid_party/firas_sparse.cpp +++ b/libnd4j/include/ops/declarable/generic/thrid_party/firas_sparse.cpp @@ -26,85 +26,82 @@ #ifndef LIBND4J_THIRD_PARTY_H #define LIBND4J_THIRD_PARTY_H -#include -#include +#include #include #include -#include +#include #include #include -#include - -namespace sd { - namespace ops { - - - /** - * This op is special one, and suited only for ProjectionLayer by @firasdib - * - * - * - * @tparam T - */ - CUSTOM_OP_IMPL(firas_sparse, 1, 1, false, 0, -1) { - auto x = INPUT_VARIABLE(0); - auto z = OUTPUT_NULLIFIED(0); +#include - int batchSize = x->sizeAt(0); - int numColumns = x->sizeAt(1); +#include - std::vector indices(*block.getIArguments()); - std::map sparse2dense; +namespace sd { +namespace ops { +/** + * This op is special one, and suited only for ProjectionLayer by @firasdib + * + * + * + * @tparam T + */ +CUSTOM_OP_IMPL(firas_sparse, 1, 1, false, 0, -1) { + auto x = INPUT_VARIABLE(0); + auto z = OUTPUT_NULLIFIED(0); - int cnt = 0; - for (auto v: indices) { - std::pair pair(v, cnt++); - sparse2dense.insert(pair); - } + int batchSize = x->sizeAt(0); + int numColumns = x->sizeAt(1); - ResultSet rows = x->allTensorsAlongDimension({1}); + std::vector indices(block.getIArguments()); + std::map sparse2dense; - //PRAGMA_OMP_PARALLEL_FOR - for (int r = 0; r < batchSize; r++) { - auto row = rows.at(r); + int cnt = 0; + for (auto v : indices) { + std::pair pair(v, cnt++); + sparse2dense.insert(pair); + } - for (int e = 0; e < numColumns; e += 2) { - int idx = row->e(e); - if (idx < 0) - break; + ResultSet rows = x->allTensorsAlongDimension({1}); - int denseIdx = sparse2dense.at(idx); + // PRAGMA_OMP_PARALLEL_FOR + for (int r = 0; r < batchSize; r++) { + auto row = rows.at(r); + for (int e = 0; e < numColumns; e += 2) { + int idx = row.e(e); + if (idx < 0) break; - float value = row->e(e); - float current = z->e(r, denseIdx); - z->p(r, denseIdx, value + current); - } - } + int denseIdx = sparse2dense.at(idx); + float value = row.e(e); + float current = z->e(r, denseIdx); + z->p(r, denseIdx, value + current); + } + } - //STORE_RESULT(*z); + // STORE_RESULT(*z); - return Status::OK(); - } + return Status::OK(); +} - DECLARE_SHAPE_FN(firas_sparse) { - auto inP = inputShape->at(0); +DECLARE_SHAPE_FN(firas_sparse) { + auto inP = inputShape->at(0); - std::vector shape({shape::shapeOf(inP)[0], (Nd4jLong) block.getIArguments()->size()}); - auto newShape = ConstantShapeHelper::getInstance().createShapeInfo(ArrayOptions::dataType(inP), 'c', shape); - return SHAPELIST(newShape); - } + std::vector shape({shape::shapeOf(inP)[0], (Nd4jLong)block.numI()}); + auto newShape = ConstantShapeHelper::getInstance().createShapeInfo( + ArrayOptions::dataType(inP), 'c', shape); + return SHAPELIST(newShape); +} - DECLARE_TYPES(firas_sparse) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } - } +DECLARE_TYPES(firas_sparse) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } +} // namespace ops +} // namespace sd #endif -#endif //LIBND4J_THIRD_PARTY_H +#endif // LIBND4J_THIRD_PARTY_H diff --git a/libnd4j/include/ops/declarable/generic/transforms/batch_to_space.cpp b/libnd4j/include/ops/declarable/generic/transforms/batch_to_space.cpp index 0ffad12a2830..2ef218b7e76c 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/batch_to_space.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/batch_to_space.cpp @@ -40,90 +40,128 @@ limitations under the License. #include namespace sd { -namespace ops { - +namespace ops { CUSTOM_OP_IMPL(batch_to_space, 2, 1, false, 0, 1) { - - // [bS*blockSize*blockSize, H/blockSize, W/blockSize, iC] is rearranged/permuted to [bS, oH, oW, iC] - // oH = H - cropTop - cropBottom - // oW = W - cropLeft - cropRight - - auto input = INPUT_VARIABLE(0); - auto crop = INPUT_VARIABLE(1); - - auto output = OUTPUT_VARIABLE(0); - - const uint blockSize = INT_ARG(0); - REQUIRE_TRUE(blockSize >= 2, 0, "BatchToSpace: integer parameter block_size must be >= 2, but got %i instead", blockSize); - - const int rank = input->rankOf(); - const int dim0 = input->sizeAt(0); - REQUIRE_TRUE(rank == 4, 0, "BatchToSpace: rank of input array must be equal 4, but got %i instead", rank); - REQUIRE_TRUE(dim0 % (blockSize * blockSize) == 0, 0, "BatchToSpace: first dimension of input array must be divisible by blockSize * blockSize (that is by %i), but got first dimension equal to %i", blockSize * blockSize, dim0); - - if(crop->sizeAt(0) != 2 || crop->sizeAt(1) != 2) - REQUIRE_TRUE(false, 0, "BatchToSpace: operation expects crop shape to be {2, 2}, but got %s instead", ShapeUtils::shapeAsString(crop).c_str()); - - const uint cropBottom = crop->e(0,0); - const uint cropTop = crop->e(0,1); - const uint cropLeft = crop->e(1,0); - const uint cropRight = crop->e(1,1); - - const int oH = input->sizeAt(1) * blockSize - cropBottom - cropTop; // top and bottom - const int oW = input->sizeAt(2) * blockSize - cropLeft - cropRight; // left and right - REQUIRE_TRUE(oH >= 0, 0, "BatchToSpace: crop top/bottom values are too big and cause negative output height dimension !"); - REQUIRE_TRUE(oW >= 0, 0, "BatchToSpace: crop left/right values are too big and cause negative output width dimension !"); - - if (shape::strideDescendingCAscendingF(input->shapeInfo())) - helpers::batchToSpace(block.launchContext(), *input, *output, cropBottom, cropTop, cropLeft, cropRight, blockSize); - else - helpers::batchToSpace(block.launchContext(), input->dup(), *output, cropBottom, cropTop, cropLeft, cropRight, blockSize); - - return Status::OK(); + // [bS*blockSize*blockSize, H/blockSize, W/blockSize, iC] is + // rearranged/permuted to [bS, oH, oW, iC] oH = H - cropTop - cropBottom oW = + // W - cropLeft - cropRight + + auto input = INPUT_VARIABLE(0); + auto crop = INPUT_VARIABLE(1); + + auto output = OUTPUT_VARIABLE(0); + + const uint blockSize = INT_ARG(0); + REQUIRE_TRUE(blockSize >= 2, 0, + "BatchToSpace: integer parameter block_size must be >= 2, but " + "got %i instead", + blockSize); + + const int rank = input->rankOf(); + const int dim0 = input->sizeAt(0); + REQUIRE_TRUE( + rank == 4, 0, + "BatchToSpace: rank of input array must be equal 4, but got %i instead", + rank); + REQUIRE_TRUE(dim0 % (blockSize * blockSize) == 0, 0, + "BatchToSpace: first dimension of input array must be divisible " + "by blockSize * blockSize (that is by %i), but got first " + "dimension equal to %i", + blockSize * blockSize, dim0); + + if (crop->sizeAt(0) != 2 || crop->sizeAt(1) != 2) + REQUIRE_TRUE(false, 0, + "BatchToSpace: operation expects crop shape to be {2, 2}, but " + "got %s instead", + ShapeUtils::shapeAsString(crop).c_str()); + + const uint cropBottom = crop->e(0, 0); + const uint cropTop = crop->e(0, 1); + const uint cropLeft = crop->e(1, 0); + const uint cropRight = crop->e(1, 1); + + const int oH = + input->sizeAt(1) * blockSize - cropBottom - cropTop; // top and bottom + const int oW = + input->sizeAt(2) * blockSize - cropLeft - cropRight; // left and right + REQUIRE_TRUE(oH >= 0, 0, + "BatchToSpace: crop top/bottom values are too big and cause " + "negative output height dimension !"); + REQUIRE_TRUE(oW >= 0, 0, + "BatchToSpace: crop left/right values are too big and cause " + "negative output width dimension !"); + + if (shape::strideDescendingCAscendingF(input->shapeInfo())) + helpers::batchToSpace(block.launchContext(), *input, *output, cropBottom, + cropTop, cropLeft, cropRight, blockSize); + else + helpers::batchToSpace(block.launchContext(), input->dup(), *output, + cropBottom, cropTop, cropLeft, cropRight, blockSize); + + return Status::OK(); } //////////////////////////////////////////////////////////////////////////////// DECLARE_TYPES(batch_to_space) { - - getOpDescriptor()->setAllowedInputTypes(0, sd::DataType::ANY) - ->setAllowedInputTypes(1, {ALL_INTS}) - ->setSameMode(true); + getOpDescriptor() + ->setAllowedInputTypes(0, sd::DataType::ANY) + ->setAllowedInputTypes(1, {ALL_INTS}) + ->setSameMode(true); } //////////////////////////////////////////////////////////////////////////////// DECLARE_SHAPE_FN(batch_to_space) { - - auto inputShapeInfo = inputShape->at(0); - auto cropShapeInfo = inputShape->at(1); - - const uint blockSize = INT_ARG(0); - REQUIRE_TRUE(blockSize >= 2, 0, "BatchToSpace: integer parameter block_size must be >= 2, but got %i instead", blockSize); - - const int rank = inputShapeInfo[0]; - const int dim0 = inputShapeInfo[1]; - REQUIRE_TRUE(rank == 4, 0, "BatchToSpace: rank of input array must be equal 4, but got %i instead", rank); - REQUIRE_TRUE(dim0 % (blockSize * blockSize) == 0, 0, "BatchToSpace: first dimension of input array must be divisible by blockSize * blockSize (that is by %i), but got first dimension equal to %i", blockSize * blockSize, dim0); - - if(cropShapeInfo[1] != 2 || cropShapeInfo[2] != 2) - REQUIRE_TRUE(false, 0, "BatchToSpace: operation expects crop shape to be {2, 2}, but got %s instead", ShapeUtils::shapeAsString(cropShapeInfo).c_str()); - - const uint cropBottom = INPUT_VARIABLE(1)->e(0,0); - const uint cropTop = INPUT_VARIABLE(1)->e(0,1); - const uint cropLeft = INPUT_VARIABLE(1)->e(1,0); - const uint cropRight = INPUT_VARIABLE(1)->e(1,1); - - const int oH = inputShapeInfo[2] * blockSize - cropTop - cropBottom; // top and bottom - const int oW = inputShapeInfo[3] * blockSize - cropLeft - cropRight; // left and right - REQUIRE_TRUE(oH >= 0, 0, "BatchToSpace: crop top/bottom values are too big and cause negative output height dimension !"); - REQUIRE_TRUE(oW >= 0, 0, "BatchToSpace: crop left/right values are too big and cause negative output width dimension !"); - - // we always give out C order here - return SHAPELIST(ConstantShapeHelper::getInstance().createShapeInfo(ArrayOptions::dataType(inputShapeInfo), 'c', {dim0 / (blockSize * blockSize), oH, oW, inputShapeInfo[4]})); + auto inputShapeInfo = inputShape->at(0); + auto cropShapeInfo = inputShape->at(1); + + const uint blockSize = INT_ARG(0); + REQUIRE_TRUE(blockSize >= 2, 0, + "BatchToSpace: integer parameter block_size must be >= 2, but " + "got %i instead", + blockSize); + + const int rank = inputShapeInfo[0]; + const int dim0 = inputShapeInfo[1]; + REQUIRE_TRUE( + rank == 4, 0, + "BatchToSpace: rank of input array must be equal 4, but got %i instead", + rank); + REQUIRE_TRUE(dim0 % (blockSize * blockSize) == 0, 0, + "BatchToSpace: first dimension of input array must be divisible " + "by blockSize * blockSize (that is by %i), but got first " + "dimension equal to %i", + blockSize * blockSize, dim0); + + if (cropShapeInfo[1] != 2 || cropShapeInfo[2] != 2) + REQUIRE_TRUE(false, 0, + "BatchToSpace: operation expects crop shape to be {2, 2}, but " + "got %s instead", + ShapeUtils::shapeAsString(cropShapeInfo).c_str()); + + const uint cropBottom = INPUT_VARIABLE(1)->e(0, 0); + const uint cropTop = INPUT_VARIABLE(1)->e(0, 1); + const uint cropLeft = INPUT_VARIABLE(1)->e(1, 0); + const uint cropRight = INPUT_VARIABLE(1)->e(1, 1); + + const int oH = + inputShapeInfo[2] * blockSize - cropTop - cropBottom; // top and bottom + const int oW = + inputShapeInfo[3] * blockSize - cropLeft - cropRight; // left and right + REQUIRE_TRUE(oH >= 0, 0, + "BatchToSpace: crop top/bottom values are too big and cause " + "negative output height dimension !"); + REQUIRE_TRUE(oW >= 0, 0, + "BatchToSpace: crop left/right values are too big and cause " + "negative output width dimension !"); + + // we always give out C order here + return SHAPELIST(ConstantShapeHelper::getInstance().createShapeInfo( + ArrayOptions::dataType(inputShapeInfo), 'c', + {dim0 / (blockSize * blockSize), oH, oW, inputShapeInfo[4]})); } - -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/transforms/batch_to_space_nd.cpp b/libnd4j/include/ops/declarable/generic/transforms/batch_to_space_nd.cpp index 1ae1a2e61fb2..8fe68d59d954 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/batch_to_space_nd.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/batch_to_space_nd.cpp @@ -40,89 +40,124 @@ limitations under the License. #include namespace sd { -namespace ops { - +namespace ops { CUSTOM_OP_IMPL(batch_to_space_nd, 3, 1, false, 0, 0) { - - // 4D example, numOfSpatialDims = 2 - two spatial dimensions - // [bS*blockShape[0]*blockShape[1], iH, iW, iC] is rearranged/permuted to [bS, iH*blockShape[0] - cropTop - cropBottom, iW*blockShape[1] - cropLeft - cropRight, iC] - - auto input = INPUT_VARIABLE(0); - auto blockShape = INPUT_VARIABLE(1); - auto crop = INPUT_VARIABLE(2); - - auto output = OUTPUT_VARIABLE(0); - - REQUIRE_TRUE(blockShape->rankOf() == 1, 0, "BatchToSpaceND: rank of blockShape array must be equal to one, but got %i instead !", blockShape->rankOf()); - - const uint numOfSpatialDims = blockShape->sizeAt(0); - - const auto product = blockShape->reduceNumber(sd::reduce::Prod).e(0); - REQUIRE_TRUE(input->sizeAt(0) % product == 0, 0, "BatchToSpaceND: first dimension of input array must be divisible by product of blockShape array elements (= %lld), but got first dimension equal to %i", product, input->sizeAt(0)); - - if(crop->sizeAt(0) != numOfSpatialDims || crop->sizeAt(1) != 2) { - const std::string expectedCropShape = "[" + std::to_string(numOfSpatialDims) + ", 2]"; // [numOfSpatialDims, 2] - REQUIRE_TRUE(false, 0, "BatchToSpaceND: operation expects padding shape to be %s, but got %s instead", expectedCropShape.c_str(), ShapeUtils::shapeAsString(crop).c_str()); - } - - // FIXME - should we use this time-consuming validation ? - for (uint i = 0; i < numOfSpatialDims; ++i) { - const auto cropLeft = crop->e(i,0); - const auto cropRight = crop->e(i,1); - const auto outSpatialDim = input->sizeAt(i + 1) * blockShape->e(i) - cropLeft - cropRight; - REQUIRE_TRUE(outSpatialDim >= 0, 0, "BatchToSpaceND: crop left/right values are too big and cause negative output spatial dimension/dimensions !"); - } - - if (shape::strideDescendingCAscendingF(input->shapeInfo())) - helpers::batchToSpaceND(block.launchContext(), *input, *blockShape, *crop, *output); - else - helpers::batchToSpaceND(block.launchContext(), input->dup(), *blockShape, *crop, *output); - - return Status::OK(); + // 4D example, numOfSpatialDims = 2 - two spatial dimensions + // [bS*blockShape[0]*blockShape[1], iH, iW, iC] is rearranged/permuted to [bS, + // iH*blockShape[0] - cropTop - cropBottom, iW*blockShape[1] - cropLeft - + // cropRight, iC] + + auto input = INPUT_VARIABLE(0); + auto blockShape = INPUT_VARIABLE(1); + auto crop = INPUT_VARIABLE(2); + + auto output = OUTPUT_VARIABLE(0); + + REQUIRE_TRUE(blockShape->rankOf() == 1, 0, + "BatchToSpaceND: rank of blockShape array must be equal to one, " + "but got %i instead !", + blockShape->rankOf()); + + const uint numOfSpatialDims = blockShape->sizeAt(0); + + const auto product = + blockShape->reduceNumber(sd::reduce::Prod).e(0); + REQUIRE_TRUE(input->sizeAt(0) % product == 0, 0, + "BatchToSpaceND: first dimension of input array must be " + "divisible by product of blockShape array elements (= %lld), " + "but got first dimension equal to %i", + product, input->sizeAt(0)); + + if (crop->sizeAt(0) != numOfSpatialDims || crop->sizeAt(1) != 2) { + const std::string expectedCropShape = "[" + + std::to_string(numOfSpatialDims) + + ", 2]"; // [numOfSpatialDims, 2] + REQUIRE_TRUE(false, 0, + "BatchToSpaceND: operation expects padding shape to be %s, " + "but got %s instead", + expectedCropShape.c_str(), + ShapeUtils::shapeAsString(crop).c_str()); + } + + // FIXME - should we use this time-consuming validation ? + for (uint i = 0; i < numOfSpatialDims; ++i) { + const auto cropLeft = crop->e(i, 0); + const auto cropRight = crop->e(i, 1); + const auto outSpatialDim = + input->sizeAt(i + 1) * blockShape->e(i) - cropLeft - + cropRight; + REQUIRE_TRUE(outSpatialDim >= 0, 0, + "BatchToSpaceND: crop left/right values are too big and cause " + "negative output spatial dimension/dimensions !"); + } + + if (shape::strideDescendingCAscendingF(input->shapeInfo())) + helpers::batchToSpaceND(block.launchContext(), *input, *blockShape, *crop, + *output); + else + helpers::batchToSpaceND(block.launchContext(), input->dup(), *blockShape, + *crop, *output); + + return Status::OK(); } //////////////////////////////////////////////////////////////////////////////// DECLARE_TYPES(batch_to_space_nd) { - - getOpDescriptor()->setAllowedInputTypes(0, sd::DataType::ANY) - ->setAllowedInputTypes(1, {ALL_INTS}) - ->setAllowedInputTypes(2, {ALL_INTS}) - ->setSameMode(true); + getOpDescriptor() + ->setAllowedInputTypes(0, sd::DataType::ANY) + ->setAllowedInputTypes(1, {ALL_INTS}) + ->setAllowedInputTypes(2, {ALL_INTS}) + ->setSameMode(true); } //////////////////////////////////////////////////////////////////////////////// DECLARE_SHAPE_FN(batch_to_space_nd) { - - auto inputShapeInfo = inputShape->at(0); - auto blockShapeInfo = inputShape->at(1); - auto cropShapeInfo = inputShape->at(2); - - REQUIRE_TRUE(blockShapeInfo[0] == 1, 0, "BatchToSpaceND: rank of blockShape array must be equal to one, but got %i instead !", blockShapeInfo[0]); - - const auto product = INPUT_VARIABLE(1)->reduceNumber(sd::reduce::Prod).e(0); - REQUIRE_TRUE(inputShapeInfo[1] % product == 0, 0, "BatchToSpaceND: first dimension of input array must be divisible by product of blockShape array elements (= %lld), but got first dimension equal to %i", product, inputShapeInfo[1]); - - const auto numOfSpatialDims = blockShapeInfo[1]; - - if(cropShapeInfo[1] != numOfSpatialDims || cropShapeInfo[2] != 2) { - const std::string expectedCropShape = "[" + std::to_string(numOfSpatialDims) + ", 2]"; // [numOfSpatialDims, 2] - REQUIRE_TRUE(false, 0, "BatchToSpaceND: operation expects padding shape to be %s, but got %s instead", expectedCropShape.c_str(), ShapeUtils::shapeAsString(cropShapeInfo).c_str()); - } - - - std::vector outShape(inputShapeInfo + 1, inputShapeInfo + 1 + inputShapeInfo[0]); - - outShape[0] /= product; - - for (uint i = 0; i < numOfSpatialDims; ++i) - outShape[i + 1] = outShape[i + 1] * INPUT_VARIABLE(1)->e(i) - INPUT_VARIABLE(2)->e(i,0) - INPUT_VARIABLE(2)->e(i,1); - - return SHAPELIST(ConstantShapeHelper::getInstance().createShapeInfo(ArrayOptions::dataType(inputShapeInfo), 'c', outShape)); + auto inputShapeInfo = inputShape->at(0); + auto blockShapeInfo = inputShape->at(1); + auto cropShapeInfo = inputShape->at(2); + + REQUIRE_TRUE(blockShapeInfo[0] == 1, 0, + "BatchToSpaceND: rank of blockShape array must be equal to one, " + "but got %i instead !", + blockShapeInfo[0]); + + const auto product = + INPUT_VARIABLE(1)->reduceNumber(sd::reduce::Prod).e(0); + REQUIRE_TRUE(inputShapeInfo[1] % product == 0, 0, + "BatchToSpaceND: first dimension of input array must be " + "divisible by product of blockShape array elements (= %lld), " + "but got first dimension equal to %i", + product, inputShapeInfo[1]); + + const auto numOfSpatialDims = blockShapeInfo[1]; + + if (cropShapeInfo[1] != numOfSpatialDims || cropShapeInfo[2] != 2) { + const std::string expectedCropShape = "[" + + std::to_string(numOfSpatialDims) + + ", 2]"; // [numOfSpatialDims, 2] + REQUIRE_TRUE(false, 0, + "BatchToSpaceND: operation expects padding shape to be %s, " + "but got %s instead", + expectedCropShape.c_str(), + ShapeUtils::shapeAsString(cropShapeInfo).c_str()); + } + + std::vector outShape(inputShapeInfo + 1, + inputShapeInfo + 1 + inputShapeInfo[0]); + + outShape[0] /= product; + + for (uint i = 0; i < numOfSpatialDims; ++i) + outShape[i + 1] = outShape[i + 1] * INPUT_VARIABLE(1)->e(i) - + INPUT_VARIABLE(2)->e(i, 0) - + INPUT_VARIABLE(2)->e(i, 1); + + return SHAPELIST(ConstantShapeHelper::getInstance().createShapeInfo( + ArrayOptions::dataType(inputShapeInfo), 'c', outShape)); } - -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/transforms/clip_by_averaged_norm.cpp b/libnd4j/include/ops/declarable/generic/transforms/clip_by_averaged_norm.cpp index a7340bf217e2..1255566462f1 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/clip_by_averaged_norm.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/clip_by_averaged_norm.cpp @@ -23,29 +23,29 @@ #if NOT_EXCLUDED(OP_clipbyavgnorm) #include -#include +#include namespace sd { -namespace ops { +namespace ops { ////////////////////////////////////////////////////////////////////////// CONFIGURABLE_OP_IMPL(clipbyavgnorm, 1, 1, true, 1, 0) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); + const bool isInplace = block.isInplace(); + auto clipNorm = NDArrayFactory::create(T_ARG(0), block.launchContext()); - const bool isInplace = block.isInplace(); - auto clipNorm = NDArrayFactory::create(T_ARG(0), block.launchContext()); + helpers::clipByNorm(block.launchContext(), *input, *output, + block.getIArguments(), clipNorm, isInplace, true); - helpers::clipByNorm(block.launchContext(), *input, *output, *block.getIArguments(), clipNorm, isInplace, true); - - return Status::OK(); + return Status::OK(); } DECLARE_TYPES(clipbyavgnorm) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } ////////////////////////////////////////////////////////////////////////// @@ -58,7 +58,7 @@ CUSTOM_OP_IMPL(clipbyavgnorm_bp, 2, 1, false, 1, 0) { const auto clipNorm = NDArrayFactory::create(gradI->dataType(), T_ARG(0), block.launchContext()); - helpers::clipByNormBp(block.launchContext(), *input, *gradO, *gradI, *block.getIArguments(), clipNorm, true); + helpers::clipByNormBp(block.launchContext(), *input, *gradO, *gradI, block.getIArguments(), clipNorm, true); return Status::OK(); } @@ -81,7 +81,7 @@ DECLARE_TYPES(clipbyavgnorm_bp) { } -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/transforms/clip_by_global_norm.cpp b/libnd4j/include/ops/declarable/generic/transforms/clip_by_global_norm.cpp index 7758cf2989fa..6b0e2aa9b1bf 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/clip_by_global_norm.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/clip_by_global_norm.cpp @@ -25,48 +25,47 @@ #include namespace sd { -namespace ops { +namespace ops { CUSTOM_OP_IMPL(clip_by_global_norm, 1, 2, true, 1, 0) { + std::vector inputs(block.width()); + std::vector outputs(block.width() + 1); + for (size_t i = 0; i < inputs.size(); ++i) { + inputs[i] = INPUT_VARIABLE(i); + outputs[i] = OUTPUT_VARIABLE(i); + } + outputs[inputs.size()] = OUTPUT_VARIABLE(inputs.size()); + double clipNorm = T_ARG(0); + bool isInplace = block.isInplace(); + helpers::clipByGlobalNorm(block.launchContext(), inputs, clipNorm, + block.workspace(), outputs, isInplace); - std::vector inputs(block.width()); - std::vector outputs(block.width() + 1); - for (size_t i = 0; i < inputs.size(); ++i) { - inputs[i] = INPUT_VARIABLE(i); - outputs[i] = OUTPUT_VARIABLE(i); - } - outputs[inputs.size()] = OUTPUT_VARIABLE(inputs.size()); - double clipNorm = T_ARG(0); - bool isInplace = block.isInplace(); - helpers::clipByGlobalNorm(block.launchContext(), inputs, clipNorm, block.workspace(), outputs, isInplace); - - return Status::OK(); + return Status::OK(); } DECLARE_SHAPE_FN(clip_by_global_norm) { + auto shapeList = SHAPELIST(); - auto shapeList = SHAPELIST(); - - for (int e = 0; e < block.width(); e++) { - auto in = inputShape->at(e); - - Nd4jLong* newShape; - COPY_SHAPE(in, newShape); - shapeList->push_back(CONSTANT(newShape)); - } - - shapeList->push_back(ConstantShapeHelper::getInstance().scalarShapeInfo(ArrayOptions::dataType(inputShape->at(0)))); - return shapeList; -} - - DECLARE_TYPES(clip_by_global_norm) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } + for (int e = 0; e < block.width(); e++) { + auto in = inputShape->at(e); + Nd4jLong* newShape; + COPY_SHAPE(in, newShape); + shapeList->push_back(CONSTANT(newShape)); + } + shapeList->push_back(ConstantShapeHelper::getInstance().scalarShapeInfo( + ArrayOptions::dataType(inputShape->at(0)))); + return shapeList; } + +DECLARE_TYPES(clip_by_global_norm) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } +} // namespace ops +} // namespace sd + #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/transforms/clip_by_norm.cpp b/libnd4j/include/ops/declarable/generic/transforms/clip_by_norm.cpp index 75145f7ccfb1..e3daa26d1979 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/clip_by_norm.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/clip_by_norm.cpp @@ -22,58 +22,60 @@ #if NOT_EXCLUDED(OP_clipbynorm) #include -#include +#include namespace sd { -namespace ops { +namespace ops { - CONFIGURABLE_OP_IMPL(clipbynorm, 1, 1, true, 1, 0) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); +CONFIGURABLE_OP_IMPL(clipbynorm, 1, 1, true, 1, 0) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - const auto clipNorm = NDArrayFactory::create(output->dataType(), T_ARG(0), block.launchContext()); - const bool isInplace = block.isInplace(); + const auto clipNorm = NDArrayFactory::create(output->dataType(), T_ARG(0), + block.launchContext()); + const bool isInplace = block.isInplace(); - helpers::clipByNorm(block.launchContext(), *input, *output, *block.getIArguments(), clipNorm, isInplace, false); + helpers::clipByNorm(block.launchContext(), *input, *output, + block.getIArguments(), clipNorm, isInplace, false); - return Status::OK(); - } - - - CUSTOM_OP_IMPL(clipbynorm_bp, 2, 1, false, 1, 0) { - auto input = INPUT_VARIABLE(0); - auto gradO = INPUT_VARIABLE(1); + return Status::OK(); +} - auto gradI = OUTPUT_VARIABLE(0); - const auto clipNorm = NDArrayFactory::create(gradI->dataType(), T_ARG(0), block.launchContext()); +CUSTOM_OP_IMPL(clipbynorm_bp, 2, 1, false, 1, 0) { + auto input = INPUT_VARIABLE(0); + auto gradO = INPUT_VARIABLE(1); - helpers::clipByNormBp(block.launchContext(), *input, *gradO, *gradI, *block.getIArguments(), clipNorm, false); + auto gradI = OUTPUT_VARIABLE(0); + const auto clipNorm = NDArrayFactory::create(gradI->dataType(), T_ARG(0), block.launchContext()); - return Status::OK(); - } + helpers::clipByNormBp(block.launchContext(), *input, *gradO, *gradI, + block.getIArguments(), clipNorm, false); - DECLARE_SHAPE_FN(clipbynorm_bp) { - auto inShapeInfo = inputShape->at(1); + return Status::OK(); +} - Nd4jLong *newShape = nullptr; - COPY_SHAPE(inShapeInfo, newShape); +DECLARE_SHAPE_FN(clipbynorm_bp) { + auto inShapeInfo = inputShape->at(1); - return SHAPELIST(CONSTANT(newShape)); - } + Nd4jLong *newShape = nullptr; + COPY_SHAPE(inShapeInfo, newShape); - DECLARE_TYPES(clipbynorm) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } + return SHAPELIST(CONSTANT(newShape)); +} - DECLARE_TYPES(clipbynorm_bp) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, {ALL_FLOATS}) - ->setAllowedOutputTypes(0, {ALL_FLOATS}); - } +DECLARE_TYPES(clipbynorm) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } + +DECLARE_TYPES(clipbynorm_bp) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes(1, {ALL_FLOATS}) + ->setAllowedOutputTypes(0, {ALL_FLOATS}); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/transforms/clip_by_value.cpp b/libnd4j/include/ops/declarable/generic/transforms/clip_by_value.cpp index 4275e4837b15..4e7e18ea96d0 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/clip_by_value.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/clip_by_value.cpp @@ -25,30 +25,34 @@ #include namespace sd { - namespace ops { - CONFIGURABLE_OP_IMPL(clipbyvalue, 1, 1, true, 2, 0) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - // FIXME: extra args!!! - auto left = T_ARG(0); - auto right = T_ARG(1); - - REQUIRE_TRUE(left < right, 0, "clip_by_value: left bound should be lesser than right. But %f >= %f given.", left, right); - //input->applyTransform(transform::ClipByValue, output, block.getTArguments()->data()); - helpers::clipByValue(block.launchContext(), *input, left, right, *output); - //STORE_RESULT(*output); - - return Status::OK(); - } - DECLARE_SYN(ClipByValue, clipbyvalue); - - DECLARE_TYPES(clipbyvalue) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } - } +namespace ops { +CONFIGURABLE_OP_IMPL(clipbyvalue, 1, 1, true, 2, 0) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + // FIXME: extra args!!! + auto left = T_ARG(0); + auto right = T_ARG(1); + + REQUIRE_TRUE(left < right, 0, + "clip_by_value: left bound should be lesser than right. But %f " + ">= %f given.", + left, right); + // input->applyTransform(transform::ClipByValue, output, + // block.getTArguments()->data()); + helpers::clipByValue(block.launchContext(), *input, left, right, *output); + // STORE_RESULT(*output); + + return Status::OK(); } +DECLARE_SYN(ClipByValue, clipbyvalue); + +DECLARE_TYPES(clipbyvalue) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/transforms/concat.cpp b/libnd4j/include/ops/declarable/generic/transforms/concat.cpp index 6c0901201156..22c5d15c5747 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/concat.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/concat.cpp @@ -19,421 +19,450 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include -#include -#include +#include +#include + +#include namespace sd { namespace ops { - ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(concat, -1, 1, false, 0, 0) { - - REQUIRE_TRUE(block.width() > 0, 0, "CONCAT op: No input arrays were provided"); - - const bool isAxisInLastArr = block.getBArguments()->size() == 0 ? false : B_ARG(0); - - const int numOfInArrs = isAxisInLastArr ? block.width() - 1 : block.width(); - - // first of all take into account possible presence of empty arrays - // also if scalar is present -> copy its value to vector with length=1 - std::vector nonEmptyArrs; - std::vector arrsToDelete; - int index = 0; - bool allOfSameType = true; - auto rankOfFirstArr = block.width() > 0 ? INPUT_VARIABLE(0)->rankOf() : 0; - auto typeOfFirstArr = block.width() > 0 ? INPUT_VARIABLE(0)->dataType() : block.dataType(); - - for(int i = 0; i < numOfInArrs; ++i) { - auto input = INPUT_VARIABLE(i); - auto currentRank = input->rankOf(); - -// TODO: follow two lines are in accordance to current tf.concat spec. Commented for compatibility with legacy -// REQUIRE_TRUE(currentRank > 0, 0, "Rank of input variable %i must be greater 0, but is %lld instead.", i, currentRank); -// REQUIRE_TRUE(rankOfFirstArr == currentRank, 0, "Number of dimensions in concat should be equals, but for %i input variable %lld != %lld appears.", i, currentRank, rankOfFirstArr); - if(!input->isEmpty()) { - - allOfSameType &= (typeOfFirstArr == input->dataType()); - - if(input->rankOf() == 0) { - auto vec = new NDArray('c', {1}, input->dataType(), block.launchContext()); - vec->assign(input); - nonEmptyArrs.push_back(vec); - arrsToDelete.push_back(index); - } - else{ - nonEmptyArrs.push_back(input); - } - ++index; - } - } - - const int numOfNonEmptyArrs = nonEmptyArrs.size(); - - if(numOfNonEmptyArrs == 0){ - //All inputs are empty arrays -> return empty, mainly for TF import compatibility (no op) - REQUIRE_TRUE(OUTPUT_VARIABLE(0)->isEmpty(), 0, "CONCAT op: If all input variables are empty, output must be empty"); - return Status::OK(); - } - - const int rank = nonEmptyArrs[0]->rankOf(); // look up to first non-empty array - int axis = isAxisInLastArr ? INPUT_VARIABLE(block.width() - 1)->e(0) : INT_ARG(0); - if(axis < 0){ - axis += rank; - } - - // ******** input validation ******** // - REQUIRE_TRUE(allOfSameType, 0, "CONCAT op: all of input arrays must have same type !"); - REQUIRE_TRUE(nonEmptyArrs[0]->dataType() == OUTPUT_VARIABLE(0)->dataType(), 0, "CONCAT op: output array should have the same type as inputs arrays !"); - REQUIRE_TRUE(0 <= axis && (axis < rank || (axis == 0 && rank == 0)), 0, "CONCAT op: input axis must be in range [0, %i], but got %i instead!", rank-1, axis); - - for(int i = 1; i < numOfNonEmptyArrs; ++i) - REQUIRE_TRUE(nonEmptyArrs[i]->rankOf() == rank, 0, "CONCAT op: all input arrays must have the same rank !"); - - for(int i = 1; i < numOfNonEmptyArrs; ++i) { - for(int dim = 0; dim < rank; ++dim) - if(dim != axis) - REQUIRE_TRUE(nonEmptyArrs[i]->sizeAt(dim) == nonEmptyArrs[0]->sizeAt(dim), 0, "CONCAT op: all input arrays must have the same dimensions (except those on input axis) !"); + REQUIRE_TRUE(block.width() > 0, 0, + "CONCAT op: No input arrays were provided"); + + const bool isAxisInLastArr = block.numB() == 0 ? false : B_ARG(0); + + const int numOfInArrs = isAxisInLastArr ? block.width() - 1 : block.width(); + + // first of all take into account possible presence of empty arrays + // also if scalar is present -> copy its value to vector with length=1 + std::vector nonEmptyArrs; + std::vector arrsToDelete; + int index = 0; + bool allOfSameType = true; + auto rankOfFirstArr = block.width() > 0 ? INPUT_VARIABLE(0)->rankOf() : 0; + auto typeOfFirstArr = + block.width() > 0 ? INPUT_VARIABLE(0)->dataType() : DataType::FLOAT32; + + for (int i = 0; i < numOfInArrs; ++i) { + auto input = INPUT_VARIABLE(i); + auto currentRank = input->rankOf(); + + // TODO: follow two lines are in accordance to current tf.concat spec. + // Commented for compatibility with legacy + // REQUIRE_TRUE(currentRank > 0, 0, "Rank of input variable %i must + // be greater 0, but is %lld instead.", i, currentRank); + // REQUIRE_TRUE(rankOfFirstArr == currentRank, 0, "Number of + // dimensions in concat should be equals, but for %i input variable + // %lld != %lld appears.", i, currentRank, rankOfFirstArr); + if (!input->isEmpty()) { + allOfSameType &= (typeOfFirstArr == input->dataType()); + + if (input->rankOf() == 0) { + auto vec = + new NDArray('c', {1}, input->dataType(), block.launchContext()); + vec->assign(input); + nonEmptyArrs.push_back(vec); + arrsToDelete.push_back(index); + } else { + nonEmptyArrs.push_back(input); + } + ++index; } - // ******** end of input validation ******** // - - auto output = OUTPUT_VARIABLE(0); - - if(numOfNonEmptyArrs == 1) - output->assign(nonEmptyArrs[0]); - else - helpers::concat(block.launchContext(), nonEmptyArrs, *output, axis); + } - // delete dynamically allocated vectors with length=1 - for(int index : arrsToDelete) - delete nonEmptyArrs[index]; + const int numOfNonEmptyArrs = nonEmptyArrs.size(); + if (numOfNonEmptyArrs == 0) { + // All inputs are empty arrays -> return empty, mainly for TF import + // compatibility (no op) + REQUIRE_TRUE( + OUTPUT_VARIABLE(0)->isEmpty(), 0, + "CONCAT op: If all input variables are empty, output must be empty"); return Status::OK(); + } + + const int rank = + nonEmptyArrs[0]->rankOf(); // look up to first non-empty array + int axis = isAxisInLastArr ? INPUT_VARIABLE(block.width() - 1)->e(0) + : INT_ARG(0); + if (axis < 0) { + axis += rank; + } + + // ******** input validation ******** // + REQUIRE_TRUE(allOfSameType, 0, + "CONCAT op: all of input arrays must have same type !"); + REQUIRE_TRUE(nonEmptyArrs[0]->dataType() == OUTPUT_VARIABLE(0)->dataType(), 0, "CONCAT op: output array should have the same type as inputs arrays !");REQUIRE_TRUE( + 0 <= axis && (axis < rank || (axis == 0 && rank == 0)), 0, + "CONCAT op: input axis must be in range [0, %i], but got %i instead!", + rank - 1, axis); + + for (int i = 1; i < numOfNonEmptyArrs; ++i) + REQUIRE_TRUE(nonEmptyArrs[i]->rankOf() == rank, 0, + "CONCAT op: all input arrays must have the same rank !"); + + for (int i = 1; i < numOfNonEmptyArrs; ++i) { + for (int dim = 0; dim < rank; ++dim) + if (dim != axis) + REQUIRE_TRUE( + nonEmptyArrs[i]->sizeAt(dim) == nonEmptyArrs[0]->sizeAt(dim), 0, + "CONCAT op: all input arrays must have the same dimensions (except " + "those on input axis) !"); + } + // ******** end of input validation ******** // + + auto output = OUTPUT_VARIABLE(0); + + if (numOfNonEmptyArrs == 1) + output->assign(nonEmptyArrs[0]); + else + helpers::concat(block.launchContext(), nonEmptyArrs, *output, axis); + + // delete dynamically allocated vectors with length=1 + for (int index : arrsToDelete) delete nonEmptyArrs[index]; + + return Status::OK(); } - DECLARE_SYN(ParallelConcat, concat); - DECLARE_SYN(concat_v2, concat); - DECLARE_SYN(concatv2, concat); +DECLARE_SYN(ParallelConcat, concat); +DECLARE_SYN(concat_v2, concat); +DECLARE_SYN(concatv2, concat); - DECLARE_TYPES(concat) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY); - // ->setSameMode(true); - } +DECLARE_TYPES(concat) { + getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY); + // ->setSameMode(true); +} ////////////////////////////////////////////////////////////////////////// DECLARE_SHAPE_FN(concat) { - - REQUIRE_TRUE(block.width() > 0, 0, "CONCAT op: No input arrays were provided"); - - const bool isAxisInLastArr = block.getBArguments()->size() == 0 ? false : B_ARG(0); - - const int numOfInArrs = isAxisInLastArr ? block.width() - 1 : block.width(); - - // first of all take into account possible presence of empty arrays - // also if scalar is present -> use the shape of vector with length=1 instead - ShapeList arrShapes; - std::vector shapesToDelete; - int index = 0; - for(int i = 0; i < numOfInArrs; ++i) { - - if(inputShape->at(i)[0] == 0) { - if (shape::isEmpty(inputShape->at(i))) - arrShapes.push_back(ConstantShapeHelper::getInstance().vectorShapeInfo(0, INPUT_VARIABLE(0)->dataType())); - else - arrShapes.push_back(ConstantShapeHelper::getInstance().vectorShapeInfo(1, INPUT_VARIABLE(0)->dataType())); - } - else{ - arrShapes.push_back(inputShape->at(i)); - } - ++index; + REQUIRE_TRUE(block.width() > 0, 0, + "CONCAT op: No input arrays were provided"); + + const bool isAxisInLastArr = block.numB() == 0 ? false : B_ARG(0); + + const int numOfInArrs = isAxisInLastArr ? block.width() - 1 : block.width(); + + // first of all take into account possible presence of empty arrays + // also if scalar is present -> use the shape of vector with length=1 instead + ShapeList arrShapes; + std::vector shapesToDelete; + int index = 0; + for (int i = 0; i < numOfInArrs; ++i) { + if (inputShape->at(i)[0] == 0) { + if (shape::isEmpty(inputShape->at(i))) + arrShapes.push_back(ConstantShapeHelper::getInstance().vectorShapeInfo( + 0, INPUT_VARIABLE(0)->dataType())); + else + arrShapes.push_back(ConstantShapeHelper::getInstance().vectorShapeInfo( + 1, INPUT_VARIABLE(0)->dataType())); + } else { + arrShapes.push_back(inputShape->at(i)); } - - const int numOfNonEmptyArrs = arrShapes.size(); - - const int rank = shape::rank(arrShapes.at(0)); - - int axis = isAxisInLastArr ? INPUT_VARIABLE(block.width() - 1)->e(0) : INT_ARG(0); - if(axis < 0){ - axis += rank; - } - - // ******** input validation ******** // - REQUIRE_TRUE(0 <= axis && axis < rank, 0, "CONCAT op: input axis must be in range [0, %i], but got %i instead!", rank-1, axis); - - for(int i = 1; i < numOfNonEmptyArrs; ++i) - REQUIRE_TRUE(shape::rank(arrShapes.at(i)) == rank, 0, "CONCAT op: all input arrays must have the same rank !"); - - for(int i = 1; i < numOfNonEmptyArrs; ++i) { - for(int dim = 0; dim < rank; ++dim) - if(dim != axis) - REQUIRE_TRUE(arrShapes.at(i)[dim+1] == arrShapes.at(0)[dim+1], 0, "CONCAT op: all input arrays must have the same dimensions (except those on input axis) !"); - } - // ******** end of input validation ******** // - - - Nd4jLong* outShapeInfo(nullptr); - COPY_SHAPE(arrShapes.at(0), outShapeInfo); - - // case when we have only one input array - if(numOfNonEmptyArrs == 1) { - ShapeUtils::updateStridesAndType(outShapeInfo, arrShapes.at(0), shape::order(arrShapes.at(0))); - return SHAPELIST(CONSTANT(outShapeInfo)); - } - - for(int i = 1; i < numOfNonEmptyArrs; ++i) - outShapeInfo[axis + 1] += arrShapes.at(i)[axis + 1]; - - ShapeUtils::updateStridesAndType(outShapeInfo, arrShapes.at(0), shape::order(arrShapes.at(0))); - - // delete dynamically allocated vectors shapes with length=1 -// for(int index : shapesToDelete) -// RELEASE(arrShapes[index], block.getWorkspace()); - - auto result = ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(outShapeInfo)); - RELEASE(outShapeInfo, block.getWorkspace()); - return SHAPELIST(result); + ++index; + } + + const int numOfNonEmptyArrs = arrShapes.size(); + + const int rank = shape::rank(arrShapes.at(0)); + + int axis = isAxisInLastArr ? INPUT_VARIABLE(block.width() - 1)->e(0) + : INT_ARG(0); + if (axis < 0) { + axis += rank; + } + + // ******** input validation ******** // + REQUIRE_TRUE( + 0 <= axis && axis < rank, 0, + "CONCAT op: input axis must be in range [0, %i], but got %i instead!", + rank - 1, axis); + + for (int i = 1; i < numOfNonEmptyArrs; ++i) + REQUIRE_TRUE(shape::rank(arrShapes.at(i)) == rank, 0, + "CONCAT op: all input arrays must have the same rank !"); + + for (int i = 1; i < numOfNonEmptyArrs; ++i) { + for (int dim = 0; dim < rank; ++dim) + if (dim != axis) + REQUIRE_TRUE(arrShapes.at(i)[dim + 1] == arrShapes.at(0)[dim + 1], 0, + "CONCAT op: all input arrays must have the same " + "dimensions (except those on input axis) !"); + } + // ******** end of input validation ******** // + + Nd4jLong* outShapeInfo(nullptr); + COPY_SHAPE(arrShapes.at(0), outShapeInfo); + + // case when we have only one input array + if (numOfNonEmptyArrs == 1) { + ShapeUtils::updateStridesAndType(outShapeInfo, arrShapes.at(0), + shape::order(arrShapes.at(0))); + return SHAPELIST(CONSTANT(outShapeInfo)); + } + + for (int i = 1; i < numOfNonEmptyArrs; ++i) + outShapeInfo[axis + 1] += arrShapes.at(i)[axis + 1]; + + ShapeUtils::updateStridesAndType(outShapeInfo, arrShapes.at(0), + shape::order(arrShapes.at(0))); + + // delete dynamically allocated vectors shapes with length=1 + // for(int index : shapesToDelete) + // RELEASE(arrShapes[index], block.workspace()); + + auto result = ConstantShapeHelper::getInstance().createShapeInfo( + ShapeDescriptor(outShapeInfo)); + RELEASE(outShapeInfo, block.workspace()); + return SHAPELIST(result); } +// ////////////////////////////////////////////////////////////////////////// +// CUSTOM_OP_IMPL(concat, -1, 1, false, 0, -2){ +// // do something here{ +// NDArray *last = INPUT_VARIABLE((int) block.width() - 1); - // ////////////////////////////////////////////////////////////////////////// - // CUSTOM_OP_IMPL(concat, -1, 1, false, 0, -2){ - // // do something here{ - // NDArray *last = INPUT_VARIABLE((int) block.width() - 1); - - // int _dimension = 0; - // if (block.numI() > 0) - // _dimension = INT_ARG(0); - // else { - // _dimension = (int) last->e(0); - // } +// int _dimension = 0; +// if (block.numI() > 0) +// _dimension = INT_ARG(0); +// else { +// _dimension = (int) last->e(0); +// } - // // we want to ensure that all - // NDArray *first = nullptr; - // auto output = OUTPUT_VARIABLE(0); +// // we want to ensure that all +// NDArray *first = nullptr; +// auto output = OUTPUT_VARIABLE(0); - // int elements = 0; +// int elements = 0; - // for (int e = 0; e < block.width(); e++) { - // auto arr = INPUT_VARIABLE(e); - // if (!arr->isEmpty()) - // elements++; +// for (int e = 0; e < block.width(); e++) { +// auto arr = INPUT_VARIABLE(e); +// if (!arr->isEmpty()) +// elements++; - // // we must find first non-empty element here - // if (!arr->isEmpty() && first == nullptr) - // first = arr; - // } +// // we must find first non-empty element here +// if (!arr->isEmpty() && first == nullptr) +// first = arr; +// } - // REQUIRE_TRUE(first != nullptr, 0, "Concat: at least 1 non-empty input required!"); +// REQUIRE_TRUE(first != nullptr, 0, "Concat: at least 1 non-empty input +// required!"); - // // it's possible to get into situation when your input has only 1 input. That's just assign - // if (elements == 1) { - // output->assign(first); - // return Status::OK(); - // } +// // it's possible to get into situation when your input has only 1 input. +// That's just assign if (elements == 1) { +// output->assign(first); +// return Status::OK(); +// } - // bool oldScalars = first->rankOf() == 2 && first->isScalar(); +// bool oldScalars = first->rankOf() == 2 && first->isScalar(); - // auto buffers = new Nd4jPointer[elements]; - // auto shapes = new Nd4jPointer[elements]; +// auto buffers = new Nd4jPointer[elements]; +// auto shapes = new Nd4jPointer[elements]; - // buffers[0] = (Nd4jPointer) first->buffer(); - // shapes[0] = (Nd4jPointer) first->shapeInfo(); +// buffers[0] = (Nd4jPointer) first->buffer(); +// shapes[0] = (Nd4jPointer) first->shapeInfo(); - // if (_dimension < 0) - // _dimension += first->rankOf(); +// if (_dimension < 0) +// _dimension += first->rankOf(); - // if (sd::Environment::getInstance().isDebugAndVerbose()) { - // printf("Shape %i: ", 0); - // shape::printShapeInfoLinear((Nd4jLong *) shapes[0]); - // } +// if (sd::Environment::getInstance().isDebugAndVerbose()) { +// printf("Shape %i: ", 0); +// shape::printShapeInfoLinear((Nd4jLong *) shapes[0]); +// } - // int er = 0; - // for (int e = 0; e < block.width(); e++) { - // Variable *var = block.variable(e); - // auto array = var->getNDArray(); +// int er = 0; +// for (int e = 0; e < block.width(); e++) { +// Variable *var = block.variable(e); +// auto array = var->getNDArray(); - // if (array->isEmpty()) - // continue; +// if (array->isEmpty()) +// continue; - // buffers[er] = reinterpret_cast(array->buffer()); - // shapes[er++] = reinterpret_cast(array->shapeInfo()); +// buffers[er] = reinterpret_cast(array->buffer()); +// shapes[er++] = reinterpret_cast(array->shapeInfo()); - // oldScalars &= array->rankOf() == 2 && array->isScalar(); +// oldScalars &= array->rankOf() == 2 && array->isScalar(); - // if (sd::Environment::getInstance().isDebugAndVerbose()) { - // printf("Shape %i: ", e); - // shape::printShapeInfoLinear(array->shapeInfo()); - // } - // } - // if (sd::Environment::getInstance().isDebugAndVerbose()) - // fflush(stdout); +// if (sd::Environment::getInstance().isDebugAndVerbose()) { +// printf("Shape %i: ", e); +// shape::printShapeInfoLinear(array->shapeInfo()); +// } +// } +// if (sd::Environment::getInstance().isDebugAndVerbose()) +// fflush(stdout); - // if (oldScalars) { - // nd4j_debug("OLD_SCALARS!\n",""); - // _dimension = 1; - // } +// if (oldScalars) { +// nd4j_debug("OLD_SCALARS!\n",""); +// _dimension = 1; +// } - // sd::SpecialMethods::concatCpuGeneric(_dimension, elements, buffers, shapes, output->buffer(), output->shapeInfo()); +// sd::SpecialMethods::concatCpuGeneric(_dimension, elements, buffers, +// shapes, output->buffer(), output->shapeInfo()); - // STORE_RESULT(*output); +// STORE_RESULT(*output); - // if (sd::Environment::getInstance().isDebugAndVerbose()) - // output->printShapeInfo("Concat result shape"); +// if (sd::Environment::getInstance().isDebugAndVerbose()) +// output->printShapeInfo("Concat result shape"); - // delete[] buffers; - // delete[] shapes; +// delete[] buffers; +// delete[] shapes; - // return ND4J_STATUS_OK; - // } +// return ND4J_STATUS_OK; +// } - // DECLARE_SYN(ParallelConcat, concat); - // DECLARE_SYN(concat_v2, concat); - // DECLARE_SYN(concatv2, concat); +// DECLARE_SYN(ParallelConcat, concat); +// DECLARE_SYN(concat_v2, concat); +// DECLARE_SYN(concatv2, concat); - // DECLARE_SHAPE_FN(concat) { - // auto inp = inputShape->at(0); - // int _dimension = INT_ARG(0); +// DECLARE_SHAPE_FN(concat) { +// auto inp = inputShape->at(0); +// int _dimension = INT_ARG(0); - // NDArray* first = nullptr; - // auto last = inputShape->at(inputShape->size() - 1); +// NDArray* first = nullptr; +// auto last = inputShape->at(inputShape->size() - 1); - // Nd4jLong elements = 0; - // Nd4jLong *newShape; +// Nd4jLong elements = 0; +// Nd4jLong *newShape; - // for (int e = 0; e < inputShape->size(); e++) { - // auto s = INPUT_VARIABLE(e); +// for (int e = 0; e < inputShape->size(); e++) { +// auto s = INPUT_VARIABLE(e); - // if (!s->isEmpty()) { - // elements++; +// if (!s->isEmpty()) { +// elements++; - // if (first == nullptr) - // first = s; - // } - // } +// if (first == nullptr) +// first = s; +// } +// } +// { // special cases for 0D concat +// bool allScalars = true; +// bool hasScalars = false; +// for (int e = 0; e < block.width(); e++) { +// auto c = INPUT_VARIABLE(e); - // { // special cases for 0D concat - // bool allScalars = true; - // bool hasScalars = false; - // for (int e = 0; e < block.width(); e++) { - // auto c = INPUT_VARIABLE(e); +// if (c->isEmpty()) +// continue; - // if (c->isEmpty()) - // continue; +// allScalars &= c->rankOf() == 0; +// hasScalars |= c->rankOf() == 0; +// } - // allScalars &= c->rankOf() == 0; - // hasScalars |= c->rankOf() == 0; - // } +// // all scalars +// if (allScalars) { +// ALLOCATE(newShape, block.workspace(), shape::shapeInfoLength(1), +// Nd4jLong); - // // all scalars - // if (allScalars) { - // ALLOCATE(newShape, block.getWorkspace(), shape::shapeInfoLength(1), Nd4jLong); +// shape::shapeBuffer(1, &elements, newShape); +// return SHAPELIST(newShape); +// } - // shape::shapeBuffer(1, &elements, newShape); - // return SHAPELIST(newShape); - // } +// // any scalar +// if (hasScalars) { +// ALLOCATE(newShape, block.workspace(), shape::shapeInfoLength(1), +// Nd4jLong); Nd4jLong length = shape::length(inp); for (int i = 1; +// i < block.width(); i++) { +// auto c = INPUT_VARIABLE(i); +// if (c->isEmpty()) +// continue; - // // any scalar - // if (hasScalars) { - // ALLOCATE(newShape, block.getWorkspace(), shape::shapeInfoLength(1), Nd4jLong); - // Nd4jLong length = shape::length(inp); - // for (int i = 1; i < block.width(); i++) { - // auto c = INPUT_VARIABLE(i); - // if (c->isEmpty()) - // continue; +// length += c->lengthOf(); +// } - // length += c->lengthOf(); - // } +// shape::shapeBuffer(1, &length, newShape); +// return SHAPELIST(newShape); +// } +// } - // shape::shapeBuffer(1, &length, newShape); - // return SHAPELIST(newShape); - // } - // } +// ALLOCATE(newShape, block.workspace(), +// shape::shapeInfoLength(first->shapeInfo()), Nd4jLong); +// if (_dimension < 0) +// _dimension += first->rankOf(); - // ALLOCATE(newShape, block.getWorkspace(), shape::shapeInfoLength(first->shapeInfo()), Nd4jLong); +// std::memcpy(newShape, first->shapeInfo(), +// shape::shapeInfoByteLength(first->shapeInfo())); for (int i = 0; i < +// inputShape->size(); i++) { +// auto s = INPUT_VARIABLE(i); - // if (_dimension < 0) - // _dimension += first->rankOf(); +// // FIXME: s == first is bad, but fast. alternatively we can subtract +// first size out of result if (s->isEmpty() || s == first) +// continue; - // std::memcpy(newShape, first->shapeInfo(), shape::shapeInfoByteLength(first->shapeInfo())); - // for (int i = 0; i < inputShape->size(); i++) { - // auto s = INPUT_VARIABLE(i); +// newShape[_dimension + 1] += +// shape::shapeOf(inputShape->at(i))[_dimension]; +// } - // // FIXME: s == first is bad, but fast. alternatively we can subtract first size out of result - // if (s->isEmpty() || s == first) - // continue; +// shape::updateStrides(newShape, first->ordering()); - // newShape[_dimension + 1] += shape::shapeOf(inputShape->at(i))[_dimension]; - // } - - // shape::updateStrides(newShape, first->ordering()); - - // return SHAPELIST(newShape); - // } +// return SHAPELIST(newShape); +// } ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(concat_bp, -1, -1, false, 0, 0) { + const bool isAxisInLastArr = block.numB() == 0 ? false : B_ARG(0); - const bool isAxisInLastArr = block.getBArguments()->size() == 0 ? false : B_ARG(0); - - const int numOfInArrs = isAxisInLastArr ? block.width() - 1 : block.width(); + const int numOfInArrs = isAxisInLastArr ? block.width() - 1 : block.width(); - auto epsilonNext = INPUT_VARIABLE(numOfInArrs - 1); + auto epsilonNext = INPUT_VARIABLE(numOfInArrs - 1); - auto first = INPUT_VARIABLE(0); + auto first = INPUT_VARIABLE(0); - const int axis = isAxisInLastArr ? INPUT_VARIABLE(block.width() - 1)->e(0) : (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + INPUT_VARIABLE(0)->rankOf()); + const int axis = + isAxisInLastArr + ? INPUT_VARIABLE(block.width() - 1)->e(0) + : (INT_ARG(0) >= 0 ? INT_ARG(0) + : INT_ARG(0) + INPUT_VARIABLE(0)->rankOf()); - int startPos = 0; + int startPos = 0; - for (int e = 0; e < numOfInArrs - 1; e++) { - auto originalChunk = INPUT_VARIABLE(e); - auto epsilonChunk = OUTPUT_VARIABLE(e); - std::vector indices(2 * epsilonNext->rankOf()); + for (int e = 0; e < numOfInArrs - 1; e++) { + auto originalChunk = INPUT_VARIABLE(e); + auto epsilonChunk = OUTPUT_VARIABLE(e); + std::vector indices(2 * epsilonNext->rankOf()); - int width = originalChunk->sizeAt(axis); + int width = originalChunk->sizeAt(axis); - for (int e = 0; e < epsilonNext->rankOf(); e++) { - if (e == axis) - indices[2*e + 1] = (indices[2*e] = startPos) + width; - else - indices[2*e + 1] = indices[2*e] = 0; - } + for (int e = 0; e < epsilonNext->rankOf(); e++) { + if (e == axis) + indices[2 * e + 1] = (indices[2 * e] = startPos) + width; + else + indices[2 * e + 1] = indices[2 * e] = 0; + } - auto subarray = (*epsilonNext)(indices, true); - epsilonChunk->assign(subarray); + auto subarray = (*epsilonNext)(indices, true); + epsilonChunk->assign(subarray); - startPos += width; - } + startPos += width; + } - return ND4J_STATUS_OK; + return ND4J_STATUS_OK; } DECLARE_TYPES(concat_bp) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } DECLARE_SHAPE_FN(concat_bp) { + const bool isAxisInLastArr = block.numB() == 0 ? false : B_ARG(0); - const bool isAxisInLastArr = block.getBArguments()->size() == 0 ? false : B_ARG(0); + const int numOfInArrs = isAxisInLastArr ? block.width() - 1 : block.width(); - const int numOfInArrs = isAxisInLastArr ? block.width() - 1 : block.width(); + auto shapeList = SHAPELIST(); - auto shapeList = SHAPELIST(); + for (int e = 0; e < numOfInArrs - 1; e++) { + auto inShape = inputShape->at(e); + shapeList->push_back(ConstantShapeHelper::getInstance().createShapeInfo( + ShapeDescriptor(ArrayOptions::dataType(inShape), shape::order(inShape), + shape::shapeOf(inShape), shape::rank(inShape)))); + } - for (int e = 0; e < numOfInArrs - 1; e++) { - auto inShape = inputShape->at(e); - shapeList->push_back(ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(ArrayOptions::dataType(inShape), shape::order(inShape), shape::shapeOf(inShape), shape::rank(inShape)))); - } - - return shapeList; + return shapeList; } - -} -} +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/generic/transforms/cumprod.cpp b/libnd4j/include/ops/declarable/generic/transforms/cumprod.cpp index c0b011f997f1..66910591290f 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/cumprod.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/cumprod.cpp @@ -21,138 +21,144 @@ #include #if NOT_EXCLUDED(OP_cumprod) -#include #include +#include namespace sd { - namespace ops { - CONFIGURABLE_OP_IMPL(cumprod, 1, 1, true, 0, 2) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - REQUIRE_TRUE(input->dataType() == output->dataType(), 0, "CumSum: input and output data types must be equal"); - - if(input->isEmpty()){ - //No-op - return Status::OK(); - } - - const bool exclusive = INT_ARG(0) == 1; - const bool reverse = INT_ARG(1) == 1; - - if (block.getIArguments()->size() == 2 && block.width() == 1) { - // all at once case - sd::ops::helpers::prefix(block.launchContext(), scalar::Multiply, input, output, exclusive, reverse); - } else { - std::vector dims(block.numI() - 2); - - if (block.width() == 1) { - - for (int e = 0; e < block.numI() - 2; e++) - dims[e] = INT_ARG(e + 2); - } else { - auto ax = INPUT_VARIABLE(1); - dims = ax->template asVectorT(); - } - - for (int e = 0; e < dims.size(); e++) - if (dims[e] < 0) - dims[e] += input->rankOf(); - - sd::ops::helpers::prefix(block.launchContext(), scalar::Multiply, input, output, dims, exclusive, reverse); - } - - return Status::OK(); - } - - DECLARE_TYPES(cumprod) { - getOpDescriptor() - ->setAllowedInputTypes(0, sd::DataType::ANY) - ->setAllowedInputTypes(1, {ALL_INTS}) - ->setAllowedOutputTypes({ALL_FLOATS}) - ->setSameMode(true); - } - - DECLARE_TYPES(cumprod_bp) { - getOpDescriptor() - ->setAllowedInputTypes(0, sd::DataType::ANY) - ->setAllowedInputTypes(1, {ALL_INTS, ALL_FLOATS}) // there is a case when axes given as IArgs - ->setAllowedInputTypes(2, {ALL_FLOATS}) - ->setAllowedOutputTypes({ALL_FLOATS}) - ->setSameMode(true); - } - - CUSTOM_OP_IMPL(cumprod_bp, 2, 1, false, 0, 2) { - auto input = INPUT_VARIABLE(0); - auto axis = block.width() == 3 ? INPUT_VARIABLE(1) : nullptr; - auto gradOut = block.width() == 3 ? INPUT_VARIABLE(2) : INPUT_VARIABLE(1); - auto output = OUTPUT_VARIABLE(0); - - const bool exclusive = INT_ARG(0) == 1; - const bool reverse = INT_ARG(1) == 1; - - std::vector dims; - - if (block.width() > 2) { - dims = axis->template asVectorT(); - OUTPUT_VARIABLE(1)->assign(1.0f); - } else if (int newSize = (block.numI() - 2)) { - dims.resize(newSize); - - for (int e = 0; e < newSize; e++) - dims[e] = INT_ARG(e + 2); - } - - sd::ops::helpers::prefix(block.launchContext(), scalar::Multiply, input, output, dims, exclusive, reverse); - NDArray val = NDArray(output->dup()); - - gradOut->applyPairwiseTransform(pairwise::Multiply, *output, val); - val.applyPairwiseTransform(pairwise::Divide, *input, val); - if (!exclusive && !reverse) { - if (dims.size()) - sd::ops::helpers::prefix(block.launchContext(), scalar::Add, &val, output, dims, true, false); - else - sd::ops::helpers::prefix(block.launchContext(), scalar::Add, &val, output, false, true); - - } - else if (!exclusive && reverse){ - if (dims.size()) - sd::ops::helpers::prefix(block.launchContext(), scalar::Add, &val, output, dims, false, false); - else - sd::ops::helpers::prefix(block.launchContext(), scalar::Add, &val, output, false, false); - } - else if (exclusive && !reverse) { - if (dims.size()) - sd::ops::helpers::prefix(block.launchContext(), scalar::Add, &val, output, dims, true, true); - else - sd::ops::helpers::prefix(block.launchContext(), scalar::Add, &val, output, true, true); - } - else { - if (dims.size()) - sd::ops::helpers::prefix(block.launchContext(), scalar::Add, &val, output, dims, true, false); - else - sd::ops::helpers::prefix(block.launchContext(), scalar::Add, &val, output, true, false); - } - - return Status::OK(); - } - - - DECLARE_SHAPE_FN(cumprod_bp) { - auto inp = inputShape->at(0); - Nd4jLong *newShapeX = nullptr; - COPY_SHAPE(inp, newShapeX); - - if (block.width() == 2) { - return SHAPELIST(CONSTANT(newShapeX)); - } else { - Nd4jLong *newShapeA = nullptr; - COPY_SHAPE(inputShape->at(1), newShapeA); - - return SHAPELIST(CONSTANT(newShapeX), CONSTANT(newShapeA)); - } - } +namespace ops { +CONFIGURABLE_OP_IMPL(cumprod, 1, 1, true, 0, 2) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + REQUIRE_TRUE(input->dataType() == output->dataType(), 0, + "CumSum: input and output data types must be equal"); + + if (input->isEmpty()) { + // No-op + return Status::OK(); + } + + const bool exclusive = INT_ARG(0) == 1; + const bool reverse = INT_ARG(1) == 1; + + if (block.numI() == 2 && block.width() == 1) { + // all at once case + sd::ops::helpers::prefix(block.launchContext(), scalar::Multiply, input, + output, exclusive, reverse); + } else { + std::vector dims(block.numI() - 2); + + if (block.width() == 1) { + for (int e = 0; e < block.numI() - 2; e++) dims[e] = INT_ARG(e + 2); + } else { + auto ax = INPUT_VARIABLE(1); + dims = ax->template asVectorT(); } + + for (int e = 0; e < dims.size(); e++) + if (dims[e] < 0) dims[e] += input->rankOf(); + + sd::ops::helpers::prefix(block.launchContext(), scalar::Multiply, input, + output, dims, exclusive, reverse); + } + + return Status::OK(); +} + +DECLARE_TYPES(cumprod) { + getOpDescriptor() + ->setAllowedInputTypes(0, sd::DataType::ANY) + ->setAllowedInputTypes(1, {ALL_INTS}) + ->setAllowedOutputTypes({ALL_FLOATS}) + ->setSameMode(true); +} + +DECLARE_TYPES(cumprod_bp) { + getOpDescriptor() + ->setAllowedInputTypes(0, sd::DataType::ANY) + ->setAllowedInputTypes( + 1, + {ALL_INTS, ALL_FLOATS}) // there is a case when axes given as IArgs + ->setAllowedInputTypes(2, {ALL_FLOATS}) + ->setAllowedOutputTypes({ALL_FLOATS}) + ->setSameMode(true); +} + +CUSTOM_OP_IMPL(cumprod_bp, 2, 1, false, 0, 2) { + auto input = INPUT_VARIABLE(0); + auto axis = block.width() == 3 ? INPUT_VARIABLE(1) : nullptr; + auto gradOut = block.width() == 3 ? INPUT_VARIABLE(2) : INPUT_VARIABLE(1); + auto output = OUTPUT_VARIABLE(0); + + const bool exclusive = INT_ARG(0) == 1; + const bool reverse = INT_ARG(1) == 1; + + std::vector dims; + + if (block.width() > 2) { + dims = axis->template asVectorT(); + OUTPUT_VARIABLE(1)->assign(1.0f); + } else if (int newSize = (block.numI() - 2)) { + dims.resize(newSize); + + for (int e = 0; e < newSize; e++) dims[e] = INT_ARG(e + 2); + } + + sd::ops::helpers::prefix(block.launchContext(), scalar::Multiply, input, + output, dims, exclusive, reverse); + NDArray val = NDArray(output->dup()); + + gradOut->applyPairwiseTransform(pairwise::Multiply, *output, val); + val.applyPairwiseTransform(pairwise::Divide, *input, val); + if (!exclusive && !reverse) { + if (dims.size()) + sd::ops::helpers::prefix(block.launchContext(), scalar::Add, &val, output, + dims, true, false); + else + sd::ops::helpers::prefix(block.launchContext(), scalar::Add, &val, output, + false, true); + + } else if (!exclusive && reverse) { + if (dims.size()) + sd::ops::helpers::prefix(block.launchContext(), scalar::Add, &val, output, + dims, false, false); + else + sd::ops::helpers::prefix(block.launchContext(), scalar::Add, &val, output, + false, false); + } else if (exclusive && !reverse) { + if (dims.size()) + sd::ops::helpers::prefix(block.launchContext(), scalar::Add, &val, output, + dims, true, true); + else + sd::ops::helpers::prefix(block.launchContext(), scalar::Add, &val, output, + true, true); + } else { + if (dims.size()) + sd::ops::helpers::prefix(block.launchContext(), scalar::Add, &val, output, + dims, true, false); + else + sd::ops::helpers::prefix(block.launchContext(), scalar::Add, &val, output, + true, false); + } + + return Status::OK(); +} + +DECLARE_SHAPE_FN(cumprod_bp) { + auto inp = inputShape->at(0); + Nd4jLong *newShapeX = nullptr; + COPY_SHAPE(inp, newShapeX); + + if (block.width() == 2) { + return SHAPELIST(CONSTANT(newShapeX)); + } else { + Nd4jLong *newShapeA = nullptr; + COPY_SHAPE(inputShape->at(1), newShapeA); + + return SHAPELIST(CONSTANT(newShapeX), CONSTANT(newShapeA)); + } } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/transforms/cumsum.cpp b/libnd4j/include/ops/declarable/generic/transforms/cumsum.cpp index 97389fddbfb5..7879fe5a55d5 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/cumsum.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/cumsum.cpp @@ -21,131 +21,133 @@ #include #if NOT_EXCLUDED(OP_cumsum) -#include #include +#include namespace sd { -namespace ops { +namespace ops { CONFIGURABLE_OP_IMPL(cumsum, 1, 1, true, 0, 2) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - const bool exclusive = INT_ARG(0) == 1; - const bool reverse = INT_ARG(1) == 1; + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - REQUIRE_TRUE(input->dataType() == output->dataType(), 0, "CumSum: input and output data types must be equal"); + const bool exclusive = INT_ARG(0) == 1; + const bool reverse = INT_ARG(1) == 1; - if(input->isEmpty()){ - //No-op - return Status::OK(); - } + REQUIRE_TRUE(input->dataType() == output->dataType(), 0, + "CumSum: input and output data types must be equal"); - if (block.getIArguments()->size() == 2 && block.width() == 1) { - // all at once case - sd::ops::helpers::prefix(block.launchContext(), scalar::Add, input, output, exclusive, reverse); + if (input->isEmpty()) { + // No-op + return Status::OK(); + } + + if (block.numI() == 2 && block.width() == 1) { + // all at once case + sd::ops::helpers::prefix(block.launchContext(), scalar::Add, input, output, + exclusive, reverse); + } else { + std::vector dims(block.numI() - 2); + + if (block.width() == 1) { + for (int e = 0; e < block.numI() - 2; e++) dims[e] = INT_ARG(e + 2); + } else { + auto ax = INPUT_VARIABLE(1); + dims = ax->template asVectorT(); } - else { - std::vector dims(block.numI() - 2); - - if (block.width() == 1) { - for (int e = 0; e < block.numI() - 2; e++) - dims[e] = INT_ARG(e + 2); - } - else { - auto ax = INPUT_VARIABLE(1); - dims = ax->template asVectorT(); - } + for (int e = 0; e < dims.size(); e++) + if (dims[e] < 0) dims[e] += input->rankOf(); - for (int e = 0; e < dims.size(); e++) - if (dims[e] < 0) - dims[e] += input->rankOf(); + sd::ops::helpers::prefix(block.launchContext(), scalar::Add, input, output, + dims, exclusive, reverse); + } - sd::ops::helpers::prefix(block.launchContext(), scalar::Add, input, output, dims, exclusive, reverse); - } - - return Status::OK(); + return Status::OK(); +} +DECLARE_TYPES(cumsum) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_FLOATS, ALL_INTS}) + ->setAllowedInputTypes(1, {ALL_INTS}) + ->setAllowedOutputTypes({ALL_FLOATS}) + ->setSameMode(false); } - DECLARE_TYPES(cumsum) { - - getOpDescriptor() - ->setAllowedInputTypes(0, {ALL_FLOATS, ALL_INTS}) - ->setAllowedInputTypes(1, {ALL_INTS}) - ->setAllowedOutputTypes({ALL_FLOATS}) - ->setSameMode(false); - } CUSTOM_OP_IMPL(cumsum_bp, 2, -1, true, 0, 2) { - auto input = INPUT_VARIABLE(0); - auto axis = block.width() == 3 ? INPUT_VARIABLE(1) : nullptr; - auto gradOut = block.width() == 3 ? INPUT_VARIABLE(2) : INPUT_VARIABLE(1); - auto output = OUTPUT_VARIABLE(0); -// output->assign(gradOut); - const bool exclusive = INT_ARG(0) == 1; - const bool reverse = INT_ARG(1) == 1; - - std::vector dims; - - if (block.width() > 2) { - dims = axis->template asVectorT(); - OUTPUT_VARIABLE(1)->assign(1.0f); - } else if (int newSize = (block.numI() - 2)) { - dims.resize(newSize); - - for (int e = 0; e < newSize; e++) - dims[e] = INT_ARG(e + 2); - } - if (!exclusive && !reverse) { - if (dims.size()) - sd::ops::helpers::prefix(block.launchContext(), scalar::Add, gradOut, output, dims, false, true); - else - sd::ops::helpers::prefix(block.launchContext(), scalar::Add, gradOut, output, false, true); - - } - else if (!exclusive && reverse){ - if (dims.size()) - sd::ops::helpers::prefix(block.launchContext(), scalar::Add, gradOut, output, dims, false, false); - else - sd::ops::helpers::prefix(block.launchContext(), scalar::Add, gradOut, output, false, false); - } - else if (exclusive && !reverse) { - if (dims.size()) - sd::ops::helpers::prefix(block.launchContext(), scalar::Add, gradOut, output, dims, true, true); - else - sd::ops::helpers::prefix(block.launchContext(), scalar::Add, gradOut, output, true, true); - } - else { - if (dims.size()) - sd::ops::helpers::prefix(block.launchContext(), scalar::Add, gradOut, output, dims, true, false); - else - sd::ops::helpers::prefix(block.launchContext(), scalar::Add, gradOut, output, true, false); - } - - return Status::OK(); + auto input = INPUT_VARIABLE(0); + auto axis = block.width() == 3 ? INPUT_VARIABLE(1) : nullptr; + auto gradOut = block.width() == 3 ? INPUT_VARIABLE(2) : INPUT_VARIABLE(1); + auto output = OUTPUT_VARIABLE(0); + // output->assign(gradOut); + const bool exclusive = INT_ARG(0) == 1; + const bool reverse = INT_ARG(1) == 1; + + std::vector dims; + + if (block.width() > 2) { + dims = axis->template asVectorT(); + OUTPUT_VARIABLE(1)->assign(1.0f); + } else if (int newSize = (block.numI() - 2)) { + dims.resize(newSize); + + for (int e = 0; e < newSize; e++) dims[e] = INT_ARG(e + 2); + } + if (!exclusive && !reverse) { + if (dims.size()) + sd::ops::helpers::prefix(block.launchContext(), scalar::Add, gradOut, + output, dims, false, true); + else + sd::ops::helpers::prefix(block.launchContext(), scalar::Add, gradOut, + output, false, true); + + } else if (!exclusive && reverse) { + if (dims.size()) + sd::ops::helpers::prefix(block.launchContext(), scalar::Add, gradOut, + output, dims, false, false); + else + sd::ops::helpers::prefix(block.launchContext(), scalar::Add, gradOut, + output, false, false); + } else if (exclusive && !reverse) { + if (dims.size()) + sd::ops::helpers::prefix(block.launchContext(), scalar::Add, gradOut, + output, dims, true, true); + else + sd::ops::helpers::prefix(block.launchContext(), scalar::Add, gradOut, + output, true, true); + } else { + if (dims.size()) + sd::ops::helpers::prefix(block.launchContext(), scalar::Add, gradOut, + output, dims, true, false); + else + sd::ops::helpers::prefix(block.launchContext(), scalar::Add, gradOut, + output, true, false); + } + + return Status::OK(); +} +DECLARE_TYPES(cumsum_bp) { + getOpDescriptor()->setAllowedInputTypes(0, {ALL_FLOATS, ALL_INTS}); + getOpDescriptor()->setAllowedInputTypes( + 1, {ALL_FLOATS, ALL_INTS}); // axes can be set as the second param + getOpDescriptor()->setAllowedInputTypes(2, {ALL_FLOATS}); + getOpDescriptor()->setAllowedOutputTypes(0, {ALL_FLOATS}); } - DECLARE_TYPES(cumsum_bp) { - getOpDescriptor()->setAllowedInputTypes(0, {ALL_FLOATS, ALL_INTS}); - getOpDescriptor()->setAllowedInputTypes(1, {ALL_FLOATS, ALL_INTS}); // axes can be set as the second param - getOpDescriptor()->setAllowedInputTypes(2, {ALL_FLOATS}); - getOpDescriptor()->setAllowedOutputTypes(0, {ALL_FLOATS}); - } - DECLARE_SHAPE_FN(cumsum_bp) { - auto inp = inputShape->at(0); - Nd4jLong *newShapeX = nullptr; - COPY_SHAPE(inp, newShapeX); +DECLARE_SHAPE_FN(cumsum_bp) { + auto inp = inputShape->at(0); + Nd4jLong *newShapeX = nullptr; + COPY_SHAPE(inp, newShapeX); - if (block.width() == 2) { - return SHAPELIST(CONSTANT(newShapeX)); - } else { - Nd4jLong *newShapeA = nullptr; - COPY_SHAPE(inputShape->at(1), newShapeA); + if (block.width() == 2) { + return SHAPELIST(CONSTANT(newShapeX)); + } else { + Nd4jLong *newShapeA = nullptr; + COPY_SHAPE(inputShape->at(1), newShapeA); - return SHAPELIST(CONSTANT(newShapeX), CONSTANT(newShapeA)); - } - } -} + return SHAPELIST(CONSTANT(newShapeX), CONSTANT(newShapeA)); + } } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/transforms/depth_to_space.cpp b/libnd4j/include/ops/declarable/generic/transforms/depth_to_space.cpp index cb966472f19e..52d989fb3ad0 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/depth_to_space.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/depth_to_space.cpp @@ -23,69 +23,73 @@ #include #include + #include namespace sd { namespace ops { - CUSTOM_OP_IMPL(depth_to_space, 1, 1, false, 0, 2) { - int block_size = INT_ARG(0); - bool isNHWC = INT_ARG(1) == 1; - - auto input = INPUT_VARIABLE(0); - - REQUIRE_TRUE(input->rankOf() == 4, 0, "DepthToSpace: input should be 4D array, but got %f instead", input->rankOf()); - - int bS = input->sizeAt(0); - int iD = isNHWC ? input->sizeAt(3) : input->sizeAt(1); - int iH = isNHWC ? input->sizeAt(1) : input->sizeAt(2); - int iW = isNHWC ? input->sizeAt(2) : input->sizeAt(3); - - REQUIRE_TRUE(iD % (block_size * block_size) == 0, 0, "DepthToSpace: input number of channels should be divisible by square(block_size)"); - - auto output = OUTPUT_VARIABLE(0); - - if (shape::strideDescendingCAscendingF(input->shapeInfo())) - helpers::_depthToSpace(block.launchContext(), *input, output, block_size, isNHWC); - else - helpers::_depthToSpace(block.launchContext(), input->dup(), output, block_size, isNHWC); - - STORE_RESULT(output); - - return ND4J_STATUS_OK; - } - - DECLARE_TYPES(depth_to_space) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setSameMode(true); - } - - - DECLARE_SHAPE_FN(depth_to_space) { - auto in = inputShape->at(0); - auto block_size = INT_ARG(0); - bool isNHWC = INT_ARG(1) == 1; - - int bS = shape::sizeAt(in, 0); - int iD = isNHWC ? shape::sizeAt(in, 3) : shape::sizeAt(in, 1); - int iH = isNHWC ? shape::sizeAt(in, 1) : shape::sizeAt(in, 2); - int iW = isNHWC ? shape::sizeAt(in, 2) : shape::sizeAt(in, 3); - - int oD = iD / (block_size * block_size); - int oH = iH * block_size; - int oW = iW * block_size; - - - std::array shape; - if (isNHWC) - shape = {{bS, oH, oW, oD }}; - else - shape = {{bS, oD, oH, oW }}; - - auto newShape = ConstantShapeHelper::getInstance().createShapeInfo(ArrayOptions::dataType(in), 'c', 4, shape.data()); - return SHAPELIST(newShape); - } +CUSTOM_OP_IMPL(depth_to_space, 1, 1, false, 0, 2) { + int block_size = INT_ARG(0); + bool isNHWC = INT_ARG(1) == 1; + + auto input = INPUT_VARIABLE(0); + + REQUIRE_TRUE(input->rankOf() == 4, 0, + "DepthToSpace: input should be 4D array, but got %f instead", + input->rankOf()); + + int bS = input->sizeAt(0); + int iD = isNHWC ? input->sizeAt(3) : input->sizeAt(1); + int iH = isNHWC ? input->sizeAt(1) : input->sizeAt(2); + int iW = isNHWC ? input->sizeAt(2) : input->sizeAt(3); + + REQUIRE_TRUE(iD % (block_size * block_size) == 0, 0, + "DepthToSpace: input number of channels should be divisible by " + "square(block_size)"); + + auto output = OUTPUT_VARIABLE(0); + + if (shape::strideDescendingCAscendingF(input->shapeInfo())) + helpers::_depthToSpace(block.launchContext(), *input, output, block_size, + isNHWC); + else + helpers::_depthToSpace(block.launchContext(), input->dup(), output, + block_size, isNHWC); + + STORE_RESULT(output); + + return ND4J_STATUS_OK; } + +DECLARE_TYPES(depth_to_space) { + getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setSameMode(true); +} + +DECLARE_SHAPE_FN(depth_to_space) { + auto in = inputShape->at(0); + auto block_size = INT_ARG(0); + bool isNHWC = INT_ARG(1) == 1; + + int bS = shape::sizeAt(in, 0); + int iD = isNHWC ? shape::sizeAt(in, 3) : shape::sizeAt(in, 1); + int iH = isNHWC ? shape::sizeAt(in, 1) : shape::sizeAt(in, 2); + int iW = isNHWC ? shape::sizeAt(in, 2) : shape::sizeAt(in, 3); + + int oD = iD / (block_size * block_size); + int oH = iH * block_size; + int oW = iW * block_size; + + std::array shape; + if (isNHWC) + shape = {{bS, oH, oW, oD}}; + else + shape = {{bS, oD, oH, oW}}; + + auto newShape = ConstantShapeHelper::getInstance().createShapeInfo( + ArrayOptions::dataType(in), 'c', 4, shape.data()); + return SHAPELIST(newShape); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/transforms/dynamic_parititon.cpp b/libnd4j/include/ops/declarable/generic/transforms/dynamic_parititon.cpp index 6a055c02c380..08cea23d3ee6 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/dynamic_parititon.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/dynamic_parititon.cpp @@ -22,131 +22,138 @@ #if NOT_EXCLUDED(OP_dynamic_partition) #include -#include #include +#include + namespace sd { namespace ops { - CUSTOM_OP_IMPL(dynamic_partition, 2, 1, false, 0, 1) { - auto input = INPUT_VARIABLE(0); - auto indices = INPUT_VARIABLE(1); - - // input->printShapeInfo("input"); - // indices->printShapeInfo("indices"); - - REQUIRE_TRUE(input->rankOf() >= indices->rankOf(), 0, - "dynamic_partition: data tensor rank should be non-lesser than indices\' tensor, but %i < %i given,", - input->rankOf(), indices->rankOf()); - for (int dim = 0; dim < indices->rankOf(); dim++) { - REQUIRE_TRUE(input->sizeAt(dim) == indices->sizeAt(dim), 0, - "dynamic_partition: dimensions should be equals for data and indices tensors, but at axis[%i] %i != %i given", - dim, input->sizeAt(dim), indices->sizeAt(dim)); - } - - auto numPartition = INT_ARG(0); - std::vector outputList(numPartition); - for (int o = 0; o < numPartition; ++o) { - outputList[o] = OUTPUT_VARIABLE(o); - } - helpers::dynamicPartitionFunctor(block.launchContext(), input, indices, outputList); - - return Status::OK(); - } - - DECLARE_SHAPE_FN(dynamic_partition) { - auto numPartition = INT_ARG(0); - auto indices = INPUT_VARIABLE(1); - std::vector partitionSizes(numPartition, 0); - auto in = inputShape->at(0); - auto idx = inputShape->at(1); - for (int i = 0; i < numPartition; i++) { - for (int e = 0; e < indices->lengthOf(); ++e) - if (indices->e(e) == i) - partitionSizes[i]++; - } - - auto shapes = SHAPELIST(); - int outRank = shape::rank(in) - shape::rank(idx) + 1; - for (int e = 0; e < numPartition; e++) { - Nd4jLong *newShape; - ALLOCATE(newShape, block.getWorkspace(), shape::shapeInfoLength(outRank), Nd4jLong); - //shape::shapeVector(partitionSizes[e], newShape); - newShape[0] = outRank; - newShape[1] = partitionSizes[e]; - for (int i = 1; i < outRank; ++i) - newShape[i + 1] = shape::sizeAt(in, outRank + i - 1); - - shape::updateStrides(newShape, shape::order(in)); - ArrayOptions::setDataType(newShape, ArrayOptions::dataType(in)); - shapes->push_back(CONSTANT(newShape)); - } - - return shapes; - } - - DECLARE_TYPES(dynamic_partition) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS, ALL_INTS}); - } - - DECLARE_TYPES(dynamic_partition_bp) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setSameMode(true); - } - - CUSTOM_OP_IMPL(dynamic_partition_bp, 3, 2, false, 0, 1) { - auto input = INPUT_VARIABLE(0); - auto indices = INPUT_VARIABLE(1); - //auto gradOut = ; - auto numPartition = INT_ARG(0); - - std::vector outputList(2); // only for output - std::vector gradOutList(numPartition); - for (Nd4jLong e = 0; e < numPartition; e++) { - gradOutList[e] = INPUT_VARIABLE(e + 2); - } - outputList[0] = OUTPUT_VARIABLE(0); - outputList[1] = OUTPUT_VARIABLE(1); - NDArray originalIndices(*indices); //->ordering(), indices->shapeInfo(), indices->dataType()); - originalIndices.linspace(0); - ops::dynamic_partition op; - auto res = op.evaluate({&originalIndices, indices}, {numPartition}); - REQUIRE_TRUE(res.status() == ND4J_STATUS_OK, 0, "dynamic_partition_bp: Error with dynamic partitioning."); - ops::dynamic_stitch stichOp; - std::vector partitions(numPartition * 2); - for (size_t i = 0; i < res.size(); i++) { - partitions[i] = res.at(i); - partitions[i + numPartition] = gradOutList[i]; - } - - auto result = stichOp.evaluate(partitions, {numPartition}); - REQUIRE_TRUE(result.status() == ND4J_STATUS_OK, 0, "dynamic_partition_bp: Error with dynamic partitioning."); - result.at(0)->reshapei(outputList[0]->getShapeAsVector()); - outputList[1]->assign(indices); - outputList[0]->assign(result.at(0)); - -// helpers::dynamicPartitionFunctorBP(block.launchContext(), input, indices, gradOutList, outputList); - return ND4J_STATUS_OK; - } - - DECLARE_SHAPE_FN(dynamic_partition_bp) { - auto numPartition = INT_ARG(0); - auto indices = INPUT_VARIABLE(1); - std::vector partitionSizes(numPartition, 0); - - auto shapes = SHAPELIST(); - // just copy shape info from input and indices to output - for (Nd4jLong i = 0; i < 2; i++) { - Nd4jLong *newShape; - COPY_SHAPE(inputShape->at(i), newShape); - shapes->push_back(CONSTANT(newShape)); - } - - return shapes; - } +CUSTOM_OP_IMPL(dynamic_partition, 2, 1, false, 0, 1) { + auto input = INPUT_VARIABLE(0); + auto indices = INPUT_VARIABLE(1); + + // input->printShapeInfo("input"); + // indices->printShapeInfo("indices"); + + REQUIRE_TRUE(input->rankOf() >= indices->rankOf(), 0, + "dynamic_partition: data tensor rank should be non-lesser than " + "indices\' tensor, but %i < %i given,", + input->rankOf(), indices->rankOf()); + for (int dim = 0; dim < indices->rankOf(); dim++) { + REQUIRE_TRUE(input->sizeAt(dim) == indices->sizeAt(dim), 0, + "dynamic_partition: dimensions should be equals for data and " + "indices tensors, but at axis[%i] %i != %i given", + dim, input->sizeAt(dim), indices->sizeAt(dim)); + } + + auto numPartition = INT_ARG(0); + std::vector outputList(numPartition); + for (int o = 0; o < numPartition; ++o) { + outputList[o] = OUTPUT_VARIABLE(o); + } + helpers::dynamicPartitionFunctor(block.launchContext(), input, indices, + outputList); + + return Status::OK(); +} + +DECLARE_SHAPE_FN(dynamic_partition) { + auto numPartition = INT_ARG(0); + auto indices = INPUT_VARIABLE(1); + std::vector partitionSizes(numPartition, 0); + auto in = inputShape->at(0); + auto idx = inputShape->at(1); + for (int i = 0; i < numPartition; i++) { + for (int e = 0; e < indices->lengthOf(); ++e) + if (indices->e(e) == i) partitionSizes[i]++; + } + + auto shapes = SHAPELIST(); + int outRank = shape::rank(in) - shape::rank(idx) + 1; + for (int e = 0; e < numPartition; e++) { + Nd4jLong *newShape; + ALLOCATE(newShape, block.workspace(), shape::shapeInfoLength(outRank), + Nd4jLong); + // shape::shapeVector(partitionSizes[e], newShape); + newShape[0] = outRank; + newShape[1] = partitionSizes[e]; + for (int i = 1; i < outRank; ++i) + newShape[i + 1] = shape::sizeAt(in, outRank + i - 1); + + shape::updateStrides(newShape, shape::order(in)); + ArrayOptions::setDataType(newShape, ArrayOptions::dataType(in)); + shapes->push_back(CONSTANT(newShape)); + } + + return shapes; +} + +DECLARE_TYPES(dynamic_partition) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS, ALL_INTS}); } + +DECLARE_TYPES(dynamic_partition_bp) { + getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setSameMode(true); +} + +CUSTOM_OP_IMPL(dynamic_partition_bp, 3, 2, false, 0, 1) { + auto input = INPUT_VARIABLE(0); + auto indices = INPUT_VARIABLE(1); + // auto gradOut = ; + auto numPartition = INT_ARG(0); + + std::vector outputList(2); // only for output + std::vector gradOutList(numPartition); + for (Nd4jLong e = 0; e < numPartition; e++) { + gradOutList[e] = INPUT_VARIABLE(e + 2); + } + outputList[0] = OUTPUT_VARIABLE(0); + outputList[1] = OUTPUT_VARIABLE(1); + auto originalIndices = + indices + ->dup(); //->ordering(), indices->shapeInfo(), indices->dataType()); + originalIndices.linspace(0); + ops::dynamic_partition op; + auto res = op.evaluate({&originalIndices, indices}, {numPartition}); + REQUIRE_TRUE(res.status() == ND4J_STATUS_OK, 0, + "dynamic_partition_bp: Error with dynamic partitioning."); + ops::dynamic_stitch stichOp; + std::vector partitions(numPartition * 2); + for (size_t i = 0; i < res.size(); i++) { + partitions[i] = &res.at(i); + partitions[i + numPartition] = gradOutList[i]; + } + + auto result = stichOp.evaluate(partitions, {numPartition}); + REQUIRE_TRUE(result.status() == ND4J_STATUS_OK, 0, + "dynamic_partition_bp: Error with dynamic partitioning."); + result.at(0).reshapei(outputList[0]->getShapeAsVector()); + outputList[1]->assign(indices); + outputList[0]->assign(result.at(0)); + + // helpers::dynamicPartitionFunctorBP(block.launchContext(), input, + // indices, gradOutList, outputList); + return ND4J_STATUS_OK; +} + +DECLARE_SHAPE_FN(dynamic_partition_bp) { + auto numPartition = INT_ARG(0); + auto indices = INPUT_VARIABLE(1); + std::vector partitionSizes(numPartition, 0); + + auto shapes = SHAPELIST(); + // just copy shape info from input and indices to output + for (Nd4jLong i = 0; i < 2; i++) { + Nd4jLong *newShape; + COPY_SHAPE(inputShape->at(i), newShape); + shapes->push_back(CONSTANT(newShape)); + } + + return shapes; } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/transforms/dynamic_stitch.cpp b/libnd4j/include/ops/declarable/generic/transforms/dynamic_stitch.cpp index d3c419b55f92..6ea53cac028d 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/dynamic_stitch.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/dynamic_stitch.cpp @@ -26,62 +26,70 @@ namespace sd { namespace ops { - CUSTOM_OP_IMPL(dynamic_stitch, 2, 1, false, 0, 0) { - int numOfData = block.width(); -// int k = 0; - // checking input data size - REQUIRE_TRUE(numOfData % 2 == 0, 0, - "dynamic_stitch: The input params should contains" - " both indeces and data lists with same length."); - // split input data list on two equal parts - numOfData /= 2; +CUSTOM_OP_IMPL(dynamic_stitch, 2, 1, false, 0, 0) { + int numOfData = block.width(); + // int k = 0; + // checking input data size + REQUIRE_TRUE(numOfData % 2 == 0, 0, + "dynamic_stitch: The input params should contains" + " both indeces and data lists with same length."); + // split input data list on two equal parts + numOfData /= 2; - // form input lists to use with helpers - both indices and float data inputs - auto output = OUTPUT_VARIABLE(0); - std::vector inputs(numOfData); - std::vector indices(numOfData); + // form input lists to use with helpers - both indices and float data inputs + auto output = OUTPUT_VARIABLE(0); + std::vector inputs(numOfData); + std::vector indices(numOfData); - for (int e = 0; e < numOfData; e++) { - auto data = INPUT_VARIABLE(numOfData + e); - auto index = INPUT_VARIABLE(e); + for (int e = 0; e < numOfData; e++) { + auto data = INPUT_VARIABLE(numOfData + e); + auto index = INPUT_VARIABLE(e); - inputs[e] = data; - indices[e] = index; - } - // run helper - return helpers::dynamicStitchFunctor(block.launchContext(), inputs, indices, output); - } + inputs[e] = data; + indices[e] = index; + } + // run helper + return helpers::dynamicStitchFunctor(block.launchContext(), inputs, indices, + output); +} - DECLARE_TYPES(dynamic_stitch) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS}); - } +DECLARE_TYPES(dynamic_stitch) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS}); +} - DECLARE_SHAPE_FN(dynamic_stitch) { - Nd4jLong maxValue = 0; - auto numOfData = block.width(); - numOfData /= 2; // only index part it's needed to review - auto restShape = inputShape->at(numOfData); - auto firstShape = inputShape->at(0); - // check up inputs to avoid non-int indices and calculate max value from indices to output shape length - for(int i = 0; i < numOfData; i++) { - auto input = INPUT_VARIABLE(i); - REQUIRE_TRUE(input->isZ(), 0, "dynamic_stitch: Indices should be integer, but %d type given.", (int)input->dataType() ); - auto maxV = input->reduceNumber(reduce::Max); - if (maxV.e(0) > maxValue) maxValue = maxV.e(0); - } - // calculate output rank - difference between indices shape and data shape - int outRank = shape::rank(restShape) - shape::rank(firstShape) + 1; // at least 1D tensor - std::vector outShape(outRank); - // fill up output shape template: the first to max index, and rests - to vals from the first data input - outShape[0] = maxValue + 1; - for(int i = 1; i < outRank; ++i) - outShape[i] = shape::sizeAt(restShape, i); +DECLARE_SHAPE_FN(dynamic_stitch) { + Nd4jLong maxValue = 0; + auto numOfData = block.width(); + numOfData /= 2; // only index part it's needed to review + auto restShape = inputShape->at(numOfData); + auto firstShape = inputShape->at(0); + // check up inputs to avoid non-int indices and calculate max value from + // indices to output shape length + for (int i = 0; i < numOfData; i++) { + auto input = INPUT_VARIABLE(i); + REQUIRE_TRUE( + input->isZ(), 0, + "dynamic_stitch: Indices should be integer, but %d type given.", + (int)input->dataType()); + auto maxV = input->reduceNumber(reduce::Max); + if (maxV.e(0) > maxValue) maxValue = maxV.e(0); + } + // calculate output rank - difference between indices shape and data shape + int outRank = shape::rank(restShape) - shape::rank(firstShape) + + 1; // at least 1D tensor + std::vector outShape(outRank); + // fill up output shape template: the first to max index, and rests - to vals + // from the first data input + outShape[0] = maxValue + 1; + for (int i = 1; i < outRank; ++i) outShape[i] = shape::sizeAt(restShape, i); - return SHAPELIST(ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(ArrayOptions::dataType(restShape), shape::order(firstShape), outShape))); - } -} + return SHAPELIST(ConstantShapeHelper::getInstance().createShapeInfo( + ShapeDescriptor(ArrayOptions::dataType(restShape), + shape::order(firstShape), outShape))); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/transforms/floor.cpp b/libnd4j/include/ops/declarable/generic/transforms/floor.cpp index 984708de5654..3156a9e7bf58 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/floor.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/floor.cpp @@ -24,25 +24,23 @@ #include namespace sd { - namespace ops { - OP_IMPL(Floor, 1, 1, true) { - auto first = INPUT_VARIABLE(0); - auto z = OUTPUT_VARIABLE(0); +namespace ops { +OP_IMPL(Floor, 1, 1, true) { + auto first = INPUT_VARIABLE(0); + auto z = OUTPUT_VARIABLE(0); - first->applyTransform(transform::Floor, *z); + first->applyTransform(transform::Floor, *z); - STORE_RESULT(*z); + STORE_RESULT(*z); - return Status::OK(); - } - DECLARE_SYN(floor, Floor); + return Status::OK(); +} +DECLARE_SYN(floor, Floor); - DECLARE_TYPES(Floor) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setSameMode(true); - } - } +DECLARE_TYPES(Floor) { + getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setSameMode(true); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/transforms/gather.cpp b/libnd4j/include/ops/declarable/generic/transforms/gather.cpp index a979c5abd351..03145c3eefeb 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/gather.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/gather.cpp @@ -25,150 +25,156 @@ #include #include - namespace sd { -namespace ops { - +namespace ops { ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(gather, 1, 1, false, 0, -2) { - - auto input = INPUT_VARIABLE(0); - auto indices = block.width() > 1 ? INPUT_VARIABLE(1) : nullptr; - auto output = OUTPUT_VARIABLE(0); - - const bool checkIndices = block.getBArguments()->empty() ? false : B_ARG(0); - - //Edge case: empty indices -> empty output - if(indices != nullptr && indices->isEmpty()){ - REQUIRE_TRUE(output->isEmpty(), 0, "Gather op: If indices are empty, output must also be empty"); - return Status::OK(); //No op - } - - const int numOfIntArgs = block.numI(); - - std::vector intArgs; - if (block.width() > 2) { - intArgs = INPUT_VARIABLE(2)->template asVectorT(); - } - else { - if (numOfIntArgs == 0) - intArgs.emplace_back(0); - else - for (int i = 0; i < numOfIntArgs; ++i) - intArgs.emplace_back(block.getIArguments()->at(i)); - } - - const int inputRank = input->rankOf(); - if(intArgs[0] < 0) - intArgs[0] += inputRank; - - // input validation - REQUIRE_TRUE(intArgs[0] < inputRank, 0, "GATHER op: input axis must be smaller than input array rank, but got %i and %i correspondingly!", intArgs[0], inputRank); - REQUIRE_TRUE(indices != nullptr || numOfIntArgs > 1, 0, "GATHER op: indices should be provided either as additional input array or as IntArguments !"); - - if(checkIndices) { - - NDArray* pIndices = indices; - if(indices == nullptr) - pIndices = new NDArray(input->ordering(), {static_cast(intArgs.size()) - 1}, std::vector(intArgs.begin() + 1, intArgs.end()), DataType::INT64, block.launchContext()); - const Nd4jLong numOfBadIndx = helpers::checkIndices(block.launchContext(), *pIndices, *input, intArgs[0]); - REQUIRE_TRUE(numOfBadIndx == 0, 0, "GATHER OP: please check elements of indices-array, total number of wrong elements is %lld!", numOfBadIndx); - if(indices == nullptr) - delete pIndices; - } - - helpers::gather(block.launchContext(), input, indices, output, intArgs); - - return Status::OK(); + auto input = INPUT_VARIABLE(0); + auto indices = block.width() > 1 ? INPUT_VARIABLE(1) : nullptr; + auto output = OUTPUT_VARIABLE(0); + + const bool checkIndices = block.getBArguments().empty() ? false : B_ARG(0); + + // Edge case: empty indices -> empty output + if (indices != nullptr && indices->isEmpty()) { + REQUIRE_TRUE(output->isEmpty(), 0, + "Gather op: If indices are empty, output must also be empty"); + return Status::OK(); // No op + } + + const int numOfIntArgs = block.numI(); + + std::vector intArgs; + if (block.width() > 2) { + intArgs = INPUT_VARIABLE(2)->template asVectorT(); + } else { + if (numOfIntArgs == 0) + intArgs.emplace_back(0); + else + for (int i = 0; i < numOfIntArgs; ++i) + intArgs.emplace_back(block.getIArguments().at(i)); + } + + const int inputRank = input->rankOf(); + if (intArgs[0] < 0) intArgs[0] += inputRank; + + // input validation + REQUIRE_TRUE(intArgs[0] < inputRank, 0, + "GATHER op: input axis must be smaller than input array rank, " + "but got %i and %i correspondingly!", + intArgs[0], inputRank); + REQUIRE_TRUE(indices != nullptr || numOfIntArgs > 1, 0, + "GATHER op: indices should be provided either as additional " + "input array or as IntArguments !"); + + if (checkIndices) { + NDArray* pIndices = indices; + if (indices == nullptr) + pIndices = + new NDArray(input->ordering(), {static_cast(intArgs.size()) - 1}, + std::vector(intArgs.begin() + 1, intArgs.end()), + DataType::INT64, block.launchContext()); + const Nd4jLong numOfBadIndx = helpers::checkIndices( + block.launchContext(), *pIndices, *input, intArgs[0]); + REQUIRE_TRUE(numOfBadIndx == 0, 0, + "GATHER OP: please check elements of indices-array, total " + "number of wrong elements is %lld!", + numOfBadIndx); + if (indices == nullptr) delete pIndices; + } + + helpers::gather(block.launchContext(), input, indices, output, intArgs); + + return Status::OK(); } DECLARE_TYPES(gather) { - getOpDescriptor()->setAllowedInputTypes(0, {ALL_INTS, ALL_FLOATS}); - getOpDescriptor()->setAllowedInputTypes(1, {ALL_INTS}); - getOpDescriptor()->setAllowedOutputTypes(0, {ALL_INTS, ALL_FLOATS}); + getOpDescriptor()->setAllowedInputTypes(0, {ALL_INTS, ALL_FLOATS}); + getOpDescriptor()->setAllowedInputTypes(1, {ALL_INTS}); + getOpDescriptor()->setAllowedOutputTypes(0, {ALL_INTS, ALL_FLOATS}); } - DECLARE_SHAPE_FN(gather) { + // check shape of paddings + auto inputShapeInfo = inputShape->at(0); + Nd4jLong* outputShapeInfo = nullptr; - // check shape of paddings - auto inputShapeInfo = inputShape->at(0); - Nd4jLong* outputShapeInfo = nullptr; - - int axis = 0; + int axis = 0; - if (block.width() > 2) { - axis = INPUT_VARIABLE(2)->e(0); - } else - axis = block.numI() > 0 ? block.getIArguments()->at(0) : 0; + if (block.width() > 2) { + axis = INPUT_VARIABLE(2)->e(0); + } else + axis = block.numI() > 0 ? block.getIArguments().at(0) : 0; - int inputRank = shape::rank(inputShapeInfo); - if(axis < 0) - axis += inputRank; + int inputRank = shape::rank(inputShapeInfo); + if (axis < 0) axis += inputRank; - REQUIRE_TRUE(axis < inputRank, 0, "GATHER op: input axis must be smaller than input array rank, but got %i and %i correspondingly!", axis, inputRank); + REQUIRE_TRUE(axis < inputRank, 0, + "GATHER op: input axis must be smaller than input array rank, " + "but got %i and %i correspondingly!", + axis, inputRank); - bool isEmpty = false; + bool isEmpty = false; - if (block.width() > 1) { - auto indicesShapeInfo = inputShape->at(1); + if (block.width() > 1) { + auto indicesShapeInfo = inputShape->at(1); - int indicesRank = shape::rank(indicesShapeInfo); + int indicesRank = shape::rank(indicesShapeInfo); - int outputRank = inputRank + indicesRank - 1; + int outputRank = inputRank + indicesRank - 1; - ALLOCATE(outputShapeInfo, block.getWorkspace(), shape::shapeInfoLength(outputRank), Nd4jLong); + ALLOCATE(outputShapeInfo, block.workspace(), + shape::shapeInfoLength(outputRank), Nd4jLong); - // fill output shapeInfo - outputShapeInfo[0] = outputRank; - int shapeIdx = 1; + // fill output shapeInfo + outputShapeInfo[0] = outputRank; + int shapeIdx = 1; - for(int i = 0; i < axis; ++i) - outputShapeInfo[shapeIdx++] = inputShapeInfo[i+1]; + for (int i = 0; i < axis; ++i) + outputShapeInfo[shapeIdx++] = inputShapeInfo[i + 1]; - for(int i = 0; i < indicesRank; ++i) - outputShapeInfo[shapeIdx++] = indicesShapeInfo[i+1]; + for (int i = 0; i < indicesRank; ++i) + outputShapeInfo[shapeIdx++] = indicesShapeInfo[i + 1]; - for(int i = axis+1; i < inputRank; ++i) - outputShapeInfo[shapeIdx++] = inputShapeInfo[i+1]; - } - else if (block.numI() > 1) { + for (int i = axis + 1; i < inputRank; ++i) + outputShapeInfo[shapeIdx++] = inputShapeInfo[i + 1]; + } else if (block.numI() > 1) { + int indicesRank = block.numI() == 2 ? 0 : 1; - int indicesRank = block.numI() == 2 ? 0 : 1; + int outputRank = inputRank + indicesRank - 1; + ALLOCATE(outputShapeInfo, block.workspace(), + shape::shapeInfoLength(outputRank), Nd4jLong); - int outputRank = inputRank + indicesRank - 1; - ALLOCATE(outputShapeInfo, block.getWorkspace(), shape::shapeInfoLength(outputRank), Nd4jLong); + // building shape manually + outputShapeInfo[0] = outputRank; + int shapeIdx = 1; + for (int i = 0; i < axis; ++i) + outputShapeInfo[shapeIdx++] = inputShapeInfo[i + 1]; - // building shape manually - outputShapeInfo[0] = outputRank; - int shapeIdx = 1; - for(int i = 0; i < axis; ++i) - outputShapeInfo[shapeIdx++] = inputShapeInfo[i+1]; + if (block.numI() > 2) outputShapeInfo[shapeIdx++] = block.numI() - 1; - if (block.numI() > 2) - outputShapeInfo[shapeIdx++] = block.numI() - 1; + for (int i = axis + 1; i < inputRank; ++i) + outputShapeInfo[shapeIdx++] = inputShapeInfo[i + 1]; + } else + REQUIRE_TRUE(false, 0, + "GATHER op: indices should be provided either as additional " + "input array or as IntArguments !"); - for(int i = axis+1; i < inputRank; ++i) - outputShapeInfo[shapeIdx++] = inputShapeInfo[i+1]; - } - else - REQUIRE_TRUE(false, 0, "GATHER op: indices should be provided either as additional input array or as IntArguments !"); - - ShapeUtils::updateStridesAndType(outputShapeInfo, inputShapeInfo, shape::order(inputShapeInfo)); - - if(isEmpty){ - ArrayOptions::setPropertyBit(outputShapeInfo, ARRAY_EMPTY); - } + ShapeUtils::updateStridesAndType(outputShapeInfo, inputShapeInfo, + shape::order(inputShapeInfo)); - auto result = ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(outputShapeInfo)); - RELEASE(outputShapeInfo, block.getWorkspace()); - return SHAPELIST(result); + if (isEmpty) { + ArrayOptions::setPropertyBit(outputShapeInfo, ARRAY_EMPTY); + } + auto result = ConstantShapeHelper::getInstance().createShapeInfo( + ShapeDescriptor(outputShapeInfo)); + RELEASE(outputShapeInfo, block.workspace()); + return SHAPELIST(result); } -} -} - +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/generic/transforms/gatherNd.cpp b/libnd4j/include/ops/declarable/generic/transforms/gatherNd.cpp index 30b5b19ef252..eb6b042a7cc1 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/gatherNd.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/gatherNd.cpp @@ -22,79 +22,90 @@ #if NOT_EXCLUDED(OP_gather_nd) #include -#include #include +#include namespace sd { namespace ops { - ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(gather_nd, 2, 1, false, 0, 0) { - - auto input = INPUT_VARIABLE(0); - auto indices = INPUT_VARIABLE(1); - auto output = OUTPUT_VARIABLE(0); - - const bool checkIndices = block.getBArguments()->empty() ? false : B_ARG(0); - - const int rankIn = input->rankOf(); - const int rankInd = indices->rankOf(); - - REQUIRE_TRUE(rankInd > 0, 0, "GATHER_ND op: array of indexes can't be single scalar, the requirement is: rank > 0, but got rank = %i instead!", rankInd); - int lastIndDim = indices->sizeAt(-1); - REQUIRE_TRUE(lastIndDim <= rankIn, 0, "GATHER_ND op: the last dimension of indices array must be <= rank of input array but got %i and %i correspondingly!", lastIndDim, rankIn); - - if(checkIndices) { - const Nd4jLong numOfBadIndx = helpers::checkIndices(block.launchContext(), *indices, *input); - REQUIRE_TRUE(numOfBadIndx == 0, 0, "GATHER_ND OP: please check elements of indices-array, total number of wrong elements is %lld!", numOfBadIndx); - } - - helpers::gatherND(block.launchContext(), *input, *indices, *output); - - return Status::OK(); + auto input = INPUT_VARIABLE(0); + auto indices = INPUT_VARIABLE(1); + auto output = OUTPUT_VARIABLE(0); + + const bool checkIndices = block.getBArguments().empty() ? false : B_ARG(0); + + const int rankIn = input->rankOf(); + const int rankInd = indices->rankOf(); + + REQUIRE_TRUE(rankInd > 0, 0, + "GATHER_ND op: array of indexes can't be single scalar, the " + "requirement is: rank > 0, but got rank = %i instead!", + rankInd); + int lastIndDim = indices->sizeAt(-1); + REQUIRE_TRUE(lastIndDim <= rankIn, 0, + "GATHER_ND op: the last dimension of indices array must be <= " + "rank of input array but got %i and %i correspondingly!", + lastIndDim, rankIn); + + if (checkIndices) { + const Nd4jLong numOfBadIndx = + helpers::checkIndices(block.launchContext(), *indices, *input); + REQUIRE_TRUE(numOfBadIndx == 0, 0, + "GATHER_ND OP: please check elements of indices-array, total " + "number of wrong elements is %lld!", + numOfBadIndx); + } + + helpers::gatherND(block.launchContext(), *input, *indices, *output); + + return Status::OK(); } DECLARE_TYPES(gather_nd) { - getOpDescriptor() - ->setAllowedInputTypes(0, {ALL_INTS, ALL_FLOATS}) - ->setAllowedInputTypes(1, {ALL_INTS}) - ->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_INTS, ALL_FLOATS}) + ->setAllowedInputTypes(1, {ALL_INTS}) + ->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS}); } DECLARE_SHAPE_FN(gather_nd) { - - auto inShapeInfoIn = inputShape->at(0); - auto inShapeInfoInd = inputShape->at(1); - - const int rankIn = inShapeInfoIn[0]; - const int rankInd = inShapeInfoInd[0]; - REQUIRE_TRUE(rankInd > 0, 0, "GATHER_ND op: array of indexes can't be single scalar, the requirement is: rank > 0, but got rank = %i instead!", rankInd); - const int lastIndDim = inShapeInfoInd[rankInd]; - REQUIRE_TRUE(lastIndDim <= rankIn, 0, "GATHER_ND op: the last dimension of indices array must be <= rank of input array but got %i and %i correspondingly!", lastIndDim, rankIn); - - int outRank = (rankInd - 1) + (rankIn - lastIndDim); - - Nd4jLong* outShapeInfo = nullptr; - ALLOCATE(outShapeInfo, block.getWorkspace(), shape::shapeInfoLength(outRank), Nd4jLong); - - outShapeInfo[0] = outRank; - - for(int i = 1; i <= rankInd-1; ++i) - outShapeInfo[i] = inShapeInfoInd[i]; - - for(int i = 0; i < rankIn-lastIndDim; ++i) - outShapeInfo[rankInd + i] = inShapeInfoIn[lastIndDim + i + 1]; - - ShapeUtils::updateStridesAndType(outShapeInfo, inShapeInfoIn, 'c'); - //ArrayOptions::setDataType(outShapeInfo, ArrayOptions::dataType(inShapeInfoIn)); - return SHAPELIST(CONSTANT(outShapeInfo)); + auto inShapeInfoIn = inputShape->at(0); + auto inShapeInfoInd = inputShape->at(1); + + const int rankIn = inShapeInfoIn[0]; + const int rankInd = inShapeInfoInd[0]; + REQUIRE_TRUE(rankInd > 0, 0, + "GATHER_ND op: array of indexes can't be single scalar, the " + "requirement is: rank > 0, but got rank = %i instead!", + rankInd); + const int lastIndDim = inShapeInfoInd[rankInd]; + REQUIRE_TRUE(lastIndDim <= rankIn, 0, + "GATHER_ND op: the last dimension of indices array must be <= " + "rank of input array but got %i and %i correspondingly!", + lastIndDim, rankIn); + + int outRank = (rankInd - 1) + (rankIn - lastIndDim); + + Nd4jLong* outShapeInfo = nullptr; + ALLOCATE(outShapeInfo, block.workspace(), shape::shapeInfoLength(outRank), + Nd4jLong); + + outShapeInfo[0] = outRank; + + for (int i = 1; i <= rankInd - 1; ++i) outShapeInfo[i] = inShapeInfoInd[i]; + + for (int i = 0; i < rankIn - lastIndDim; ++i) + outShapeInfo[rankInd + i] = inShapeInfoIn[lastIndDim + i + 1]; + + ShapeUtils::updateStridesAndType(outShapeInfo, inShapeInfoIn, 'c'); + // ArrayOptions::setDataType(outShapeInfo, + // ArrayOptions::dataType(inShapeInfoIn)); + return SHAPELIST(CONSTANT(outShapeInfo)); } - - - -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/transforms/hashcode.cpp b/libnd4j/include/ops/declarable/generic/transforms/hashcode.cpp index 0ef9d71cef23..42e0602e4d30 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/hashcode.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/hashcode.cpp @@ -22,37 +22,38 @@ #if NOT_EXCLUDED(OP_hashcode) #include -#include #include +#include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(hashcode, 1, 1, false, 0, 0) { - REQUIRE_TRUE(block.width() == 1, 0, "hashcode: this op can't be applied along dimension"); - - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); +namespace ops { +CUSTOM_OP_IMPL(hashcode, 1, 1, false, 0, 0) { + REQUIRE_TRUE(block.width() == 1, 0, + "hashcode: this op can't be applied along dimension"); - REQUIRE_TRUE(output->isScalar(), 0, "hashcode: this op requires scalar output"); + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - helpers::hashCode(block.launchContext(), *input, *output); + REQUIRE_TRUE(output->isScalar(), 0, + "hashcode: this op requires scalar output"); - return Status::OK(); - }; + helpers::hashCode(block.launchContext(), *input, *output); - DECLARE_SHAPE_FN(hashcode) { - return SHAPELIST(ConstantShapeHelper::getInstance().scalarShapeInfo(sd::DataType::INT64)); - } + return Status::OK(); +}; - - DECLARE_TYPES(hashcode) { - getOpDescriptor() - ->setAllowedInputTypes(0, {ALL_INTS, ALL_FLOATS}) - ->setAllowedInputTypes(1, {ALL_INTS}) - ->setAllowedOutputTypes({sd::DataType::INT64}); - }; - } +DECLARE_SHAPE_FN(hashcode) { + return SHAPELIST( + ConstantShapeHelper::getInstance().scalarShapeInfo(sd::DataType::INT64)); } -#endif +DECLARE_TYPES(hashcode) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_INTS, ALL_FLOATS}) + ->setAllowedInputTypes(1, {ALL_INTS}) + ->setAllowedOutputTypes({sd::DataType::INT64}); +}; +} // namespace ops +} // namespace sd +#endif diff --git a/libnd4j/include/ops/declarable/generic/transforms/histogram.cpp b/libnd4j/include/ops/declarable/generic/transforms/histogram.cpp index e08fcdbf5534..f0d9d12b8e7e 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/histogram.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/histogram.cpp @@ -22,37 +22,39 @@ #if NOT_EXCLUDED(OP_histogram) #include -#include #include +#include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(histogram, 1, 1, false, 0, 1) { - auto input = INPUT_VARIABLE(0); - auto numBins = INT_ARG(0); - auto output = OUTPUT_VARIABLE(0); - - REQUIRE_TRUE(numBins == output->lengthOf(), 0, "Histogram: numBins must match output length") +namespace ops { - output->nullify(); - helpers::histogramHelper(block.launchContext(), *input, *output); +CUSTOM_OP_IMPL(histogram, 1, 1, false, 0, 1) { + auto input = INPUT_VARIABLE(0); + auto numBins = INT_ARG(0); + auto output = OUTPUT_VARIABLE(0); - return Status::OK(); - } + REQUIRE_TRUE(numBins == output->lengthOf(), 0, + "Histogram: numBins must match output length") - DECLARE_SHAPE_FN(histogram) { - auto numBins = INT_ARG(0); + output->nullify(); + helpers::histogramHelper(block.launchContext(), *input, *output); - return SHAPELIST(ConstantShapeHelper::getInstance().vectorShapeInfo(numBins, sd::DataType::INT64)); - } + return Status::OK(); +} +DECLARE_SHAPE_FN(histogram) { + auto numBins = INT_ARG(0); - DECLARE_TYPES(histogram) { - getOpDescriptor() - ->setAllowedInputTypes(0, {ALL_INTS, ALL_FLOATS}) - ->setAllowedOutputTypes({ALL_INTS}); - }; - } + return SHAPELIST(ConstantShapeHelper::getInstance().vectorShapeInfo( + numBins, sd::DataType::INT64)); } +DECLARE_TYPES(histogram) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_INTS, ALL_FLOATS}) + ->setAllowedOutputTypes({ALL_INTS}); +}; +} // namespace ops +} // namespace sd + #endif diff --git a/libnd4j/include/ops/declarable/generic/transforms/histogram_fixed_width.cpp b/libnd4j/include/ops/declarable/generic/transforms/histogram_fixed_width.cpp index 208baa5a9e0d..d42466b9d9ef 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/histogram_fixed_width.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/histogram_fixed_width.cpp @@ -25,44 +25,52 @@ #include namespace sd { -namespace ops { +namespace ops { CUSTOM_OP_IMPL(histogram_fixed_width, 2, 1, false, 0, 0) { - - auto input = INPUT_VARIABLE(0); - auto range = INPUT_VARIABLE(1); - auto output = OUTPUT_VARIABLE(0); - - const int nbins = block.width() == 3 ? INPUT_VARIABLE(2)->e(0) : block.getIArguments()->empty() ? 100 : INT_ARG(0); - - const double leftEdge = range->e(0); - const double rightEdge = range->e(1); - - REQUIRE_TRUE(leftEdge < rightEdge, 0, "HISTOGRAM_FIXED_WIDTH OP: wrong content of range input array, bottom_edge must be smaller than top_edge, but got %f and %f correspondingly !", leftEdge, rightEdge); - REQUIRE_TRUE(nbins >= 1, 0, "HISTOGRAM_FIXED_WIDTH OP: wrong nbins value, expected value should be >= 1, however got %i instead !", nbins); - - helpers::histogramFixedWidth(block.launchContext(), *input, *range, *output); - - return Status::OK(); + auto input = INPUT_VARIABLE(0); + auto range = INPUT_VARIABLE(1); + auto output = OUTPUT_VARIABLE(0); + + const int nbins = block.width() == 3 + ? INPUT_VARIABLE(2)->e(0) + : block.getIArguments().empty() ? 100 : INT_ARG(0); + + const double leftEdge = range->e(0); + const double rightEdge = range->e(1); + + REQUIRE_TRUE(leftEdge < rightEdge, 0, + "HISTOGRAM_FIXED_WIDTH OP: wrong content of range input array, " + "bottom_edge must be smaller than top_edge, but got %f and %f " + "correspondingly !", + leftEdge, rightEdge); + REQUIRE_TRUE(nbins >= 1, 0, + "HISTOGRAM_FIXED_WIDTH OP: wrong nbins value, expected value " + "should be >= 1, however got %i instead !", + nbins); + + helpers::histogramFixedWidth(block.launchContext(), *input, *range, *output); + + return Status::OK(); } DECLARE_TYPES(histogram_fixed_width) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_INDICES}); + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_INDICES}); } - ////////////////////////////////////////////////////////////////////////// DECLARE_SHAPE_FN(histogram_fixed_width) { - - const int nbins = block.width() == 3 ? INPUT_VARIABLE(2)->e(0) : block.getIArguments()->empty() ? 100 : INT_ARG(0); - auto outShapeInfo = ConstantShapeHelper::getInstance().vectorShapeInfo(nbins, DataType::INT64); - return SHAPELIST(outShapeInfo); + const int nbins = block.width() == 3 + ? INPUT_VARIABLE(2)->e(0) + : block.getIArguments().empty() ? 100 : INT_ARG(0); + auto outShapeInfo = ConstantShapeHelper::getInstance().vectorShapeInfo( + nbins, DataType::INT64); + return SHAPELIST(outShapeInfo); } - -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/transforms/invertPermutation.cpp b/libnd4j/include/ops/declarable/generic/transforms/invertPermutation.cpp index 3814106cf7ba..2e7ad21d1d73 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/invertPermutation.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/invertPermutation.cpp @@ -22,32 +22,32 @@ #if NOT_EXCLUDED(OP_invert_permutation) #include -#include +#include namespace sd { -namespace ops { +namespace ops { //////////////////////////////////////////////////////////////////////// CONFIGURABLE_OP_IMPL(invert_permutation, 1, 1, false, 0, 0) { - - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - REQUIRE_TRUE(input->isVector(), 0 , "INVERT_PERMUTATION op: input array must be vector, but got shape %s instead !", ShapeUtils::shapeAsString(input).c_str()); - - helpers::invertPermutation(block.launchContext(), *input, *output); - - return Status::OK(); + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + REQUIRE_TRUE(input->isVector(), 0, + "INVERT_PERMUTATION op: input array must be vector, but got " + "shape %s instead !", + ShapeUtils::shapeAsString(input).c_str()); + + helpers::invertPermutation(block.launchContext(), *input, *output); + + return Status::OK(); } - + DECLARE_SYN(InvertPermutation, invert_permutation); - DECLARE_TYPES(invert_permutation) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setSameMode(true); - } -} +DECLARE_TYPES(invert_permutation) { + getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setSameMode(true); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/transforms/merge_add.cpp b/libnd4j/include/ops/declarable/generic/transforms/merge_add.cpp index 0fade28bfcae..452678db50ef 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/merge_add.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/merge_add.cpp @@ -22,25 +22,23 @@ #if NOT_EXCLUDED(OP_mergeadd) #include -#include +#include namespace sd { -namespace ops { +namespace ops { OP_IMPL(mergeadd, -1, 1, false) { - - REQUIRE_OK(this->validateInputDimensionsMatch(block)); - - auto output = OUTPUT_VARIABLE(0); + REQUIRE_OK(this->validateInputDimensionsMatch(block)); - std::vector inArrs(block.width()); - - for(int i = 0; i < block.width(); ++i) - inArrs[i] = INPUT_VARIABLE(i); + auto output = OUTPUT_VARIABLE(0); - helpers::mergeAdd(block.launchContext(), inArrs, *output); + std::vector inArrs(block.width()); - return Status::OK(); + for (int i = 0; i < block.width(); ++i) inArrs[i] = INPUT_VARIABLE(i); + + helpers::mergeAdd(block.launchContext(), inArrs, *output); + + return Status::OK(); } DECLARE_SYN(mergesum, mergeadd); DECLARE_SYN(add_n, mergeadd); @@ -48,51 +46,51 @@ DECLARE_SYN(addn, mergeadd); DECLARE_SYN(accumulaten, mergeadd); DECLARE_SYN(accumulate_n, mergeadd); - DECLARE_TYPES(mergeadd) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes(sd::DataType::ANY); - } - - - CUSTOM_OP_IMPL(mergeadd_bp, 2, 1, false, 0, 0) { +DECLARE_TYPES(mergeadd) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes(sd::DataType::ANY); +} - auto inSize = block.width() - 1; +CUSTOM_OP_IMPL(mergeadd_bp, 2, 1, false, 0, 0) { + auto inSize = block.width() - 1; - REQUIRE_OK(this->validateInputDimensionsMatch(block)); + REQUIRE_OK(this->validateInputDimensionsMatch(block)); - std::vector outArrs(inSize); - - const auto gradient = INPUT_VARIABLE(inSize); + std::vector outArrs(inSize); - for (int i = 0; i < inSize; ++i) { - outArrs[i] = OUTPUT_VARIABLE(i); - } - helpers::mergeAddBp(block.launchContext(), *gradient, outArrs); + const auto gradient = INPUT_VARIABLE(inSize); - return Status::OK(); - } + for (int i = 0; i < inSize; ++i) { + outArrs[i] = OUTPUT_VARIABLE(i); + } + helpers::mergeAddBp(block.launchContext(), *gradient, outArrs); - DECLARE_TYPES(mergeadd_bp) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes(sd::DataType::ANY); - } - DECLARE_SHAPE_FN(mergeadd_bp) { + return Status::OK(); +} - const int numOfInArrs = block.width() - 1; +DECLARE_TYPES(mergeadd_bp) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes(sd::DataType::ANY); +} - auto shapeList = SHAPELIST(); +DECLARE_SHAPE_FN(mergeadd_bp) { + const int numOfInArrs = block.width() - 1; - for (int e = 0; e < numOfInArrs; e++) { - auto inShape = inputShape->at(e); - shapeList->push_back(ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(ArrayOptions::dataType(inShape), shape::order(inShape), shape::shapeOf(inShape), shape::rank(inShape)))); - } + auto shapeList = SHAPELIST(); - return shapeList; - } + for (int e = 0; e < numOfInArrs; e++) { + auto inShape = inputShape->at(e); + shapeList->push_back(ConstantShapeHelper::getInstance().createShapeInfo( + ShapeDescriptor(ArrayOptions::dataType(inShape), shape::order(inShape), + shape::shapeOf(inShape), shape::rank(inShape)))); + } + return shapeList; } -} + +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/transforms/merge_avg.cpp b/libnd4j/include/ops/declarable/generic/transforms/merge_avg.cpp index 2ea0d501b412..fa81160adaa5 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/merge_avg.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/merge_avg.cpp @@ -22,71 +22,68 @@ #if NOT_EXCLUDED(OP_mergeavg) #include -#include +#include namespace sd { -namespace ops { +namespace ops { OP_IMPL(mergeavg, -1, 1, false) { - - REQUIRE_OK(this->validateInputDimensionsMatch(block)); - - auto output = OUTPUT_VARIABLE(0); + REQUIRE_OK(this->validateInputDimensionsMatch(block)); - std::vector inArrs(block.width()); - - for(int i = 0; i < block.width(); ++i) - inArrs[i] = INPUT_VARIABLE(i); + auto output = OUTPUT_VARIABLE(0); - helpers::mergeAvg(block.launchContext(), inArrs, *output); + std::vector inArrs(block.width()); - return Status::OK(); -} - - DECLARE_TYPES(mergeavg) { - getOpDescriptor() - ->setAllowedInputTypes({ALL_FLOATS}) - ->setAllowedOutputTypes({ALL_FLOATS}); - } + for (int i = 0; i < block.width(); ++i) inArrs[i] = INPUT_VARIABLE(i); + helpers::mergeAvg(block.launchContext(), inArrs, *output); - CUSTOM_OP_IMPL(mergeavg_bp, 2, 1, false, 0, 0) { + return Status::OK(); +} - auto inSize = block.width() - 1; +DECLARE_TYPES(mergeavg) { + getOpDescriptor() + ->setAllowedInputTypes({ALL_FLOATS}) + ->setAllowedOutputTypes({ALL_FLOATS}); +} - REQUIRE_OK(this->validateInputDimensionsMatch(block)); +CUSTOM_OP_IMPL(mergeavg_bp, 2, 1, false, 0, 0) { + auto inSize = block.width() - 1; - std::vector outArrs(inSize); + REQUIRE_OK(this->validateInputDimensionsMatch(block)); - const auto gradient = INPUT_VARIABLE(inSize); - - for (int i = 0; i < inSize; ++i) { - outArrs[i] = OUTPUT_VARIABLE(i); - } - helpers::mergeAvgBp(block.launchContext(), *gradient, outArrs); - return Status::OK(); - } + std::vector outArrs(inSize); - DECLARE_TYPES(mergeavg_bp) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes(sd::DataType::ANY); - } - DECLARE_SHAPE_FN(mergeavg_bp) { + const auto gradient = INPUT_VARIABLE(inSize); - const int numOfInArrs = block.width() - 1; + for (int i = 0; i < inSize; ++i) { + outArrs[i] = OUTPUT_VARIABLE(i); + } + helpers::mergeAvgBp(block.launchContext(), *gradient, outArrs); + return Status::OK(); +} - auto shapeList = SHAPELIST(); +DECLARE_TYPES(mergeavg_bp) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes(sd::DataType::ANY); +} +DECLARE_SHAPE_FN(mergeavg_bp) { + const int numOfInArrs = block.width() - 1; - for (int e = 0; e < numOfInArrs; e++) { - auto inShape = inputShape->at(e); - shapeList->push_back(ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(ArrayOptions::dataType(inShape), shape::order(inShape), shape::shapeOf(inShape), shape::rank(inShape)))); - } + auto shapeList = SHAPELIST(); - return shapeList; - } + for (int e = 0; e < numOfInArrs; e++) { + auto inShape = inputShape->at(e); + shapeList->push_back(ConstantShapeHelper::getInstance().createShapeInfo( + ShapeDescriptor(ArrayOptions::dataType(inShape), shape::order(inShape), + shape::shapeOf(inShape), shape::rank(inShape)))); + } + return shapeList; } -} + +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/transforms/merge_max.cpp b/libnd4j/include/ops/declarable/generic/transforms/merge_max.cpp index e95092f3879e..b6370e810e42 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/merge_max.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/merge_max.cpp @@ -22,76 +22,72 @@ #if NOT_EXCLUDED(OP_mergemax) #include -#include +#include namespace sd { -namespace ops { - -OP_IMPL(mergemax, -1, 1, false) { - - REQUIRE_OK(this->validateInputDimensionsMatch(block)); - - auto output = OUTPUT_VARIABLE(0); +namespace ops { - std::vector inArrs(block.width()); - - for(int i = 0; i < block.width(); ++i) - inArrs[i] = INPUT_VARIABLE(i); +OP_IMPL(mergemax, -1, 1, false) { + REQUIRE_OK(this->validateInputDimensionsMatch(block)); - helpers::mergeMax(block.launchContext(), inArrs, *output); + auto output = OUTPUT_VARIABLE(0); - return Status::OK(); -} -DECLARE_SYN(MergeMax, mergemax); + std::vector inArrs(block.width()); - DECLARE_TYPES(mergemax) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes(sd::DataType::ANY); - } + for (int i = 0; i < block.width(); ++i) inArrs[i] = INPUT_VARIABLE(i); + helpers::mergeMax(block.launchContext(), inArrs, *output); - CUSTOM_OP_IMPL(mergemax_bp, 2, 1, false, 0, 0) { + return Status::OK(); +} +DECLARE_SYN(MergeMax, mergemax); - auto inSize = block.width(); +DECLARE_TYPES(mergemax) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes(sd::DataType::ANY); +} - REQUIRE_OK(this->validateInputDimensionsMatch(block)); +CUSTOM_OP_IMPL(mergemax_bp, 2, 1, false, 0, 0) { + auto inSize = block.width(); - std::vector inArrs(inSize); - std::vector outArrs(inSize - 1); + REQUIRE_OK(this->validateInputDimensionsMatch(block)); - for (int i = 0; i < inSize; ++i) - inArrs[i] = INPUT_VARIABLE(i); + std::vector inArrs(inSize); + std::vector outArrs(inSize - 1); - for (int i = 0; i < (inSize - 1); ++i) { - outArrs[i] = OUTPUT_NULLIFIED(i); - } + for (int i = 0; i < inSize; ++i) inArrs[i] = INPUT_VARIABLE(i); - helpers::mergeMaxBp(block.launchContext(), inArrs, outArrs); + for (int i = 0; i < (inSize - 1); ++i) { + outArrs[i] = OUTPUT_NULLIFIED(i); + } - return Status::OK(); - } + helpers::mergeMaxBp(block.launchContext(), inArrs, outArrs); - DECLARE_TYPES(mergemax_bp) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes(sd::DataType::ANY); - } - DECLARE_SHAPE_FN(mergemax_bp) { + return Status::OK(); +} - const int numOfInArrs = block.width() - 1; +DECLARE_TYPES(mergemax_bp) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes(sd::DataType::ANY); +} +DECLARE_SHAPE_FN(mergemax_bp) { + const int numOfInArrs = block.width() - 1; - auto shapeList = SHAPELIST(); - - for (int e = 0; e < numOfInArrs; e++) { - auto inShape = inputShape->at(e); - shapeList->push_back(ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(ArrayOptions::dataType(inShape), shape::order(inShape), shape::shapeOf(inShape), shape::rank(inShape)))); - } + auto shapeList = SHAPELIST(); - return shapeList; - } + for (int e = 0; e < numOfInArrs; e++) { + auto inShape = inputShape->at(e); + shapeList->push_back(ConstantShapeHelper::getInstance().createShapeInfo( + ShapeDescriptor(ArrayOptions::dataType(inShape), shape::order(inShape), + shape::shapeOf(inShape), shape::rank(inShape)))); + } + return shapeList; } -} + +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/transforms/merge_max_idx.cpp b/libnd4j/include/ops/declarable/generic/transforms/merge_max_idx.cpp index 3c76450aa9fb..aadcf88afa67 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/merge_max_idx.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/merge_max_idx.cpp @@ -22,43 +22,42 @@ #if NOT_EXCLUDED(OP_mergemaxindex) #include -#include +#include namespace sd { -namespace ops { +namespace ops { CUSTOM_OP_IMPL(mergemaxindex, -1, 1, false, 0, 0) { + REQUIRE_OK(this->validateInputDimensionsMatch(block)); + auto output = OUTPUT_VARIABLE(0); - REQUIRE_OK(this->validateInputDimensionsMatch(block)); - auto output = OUTPUT_VARIABLE(0); + std::vector inArrs(block.width()); - std::vector inArrs(block.width()); + for (int i = 0; i < block.width(); ++i) inArrs[i] = INPUT_VARIABLE(i); - for(int i = 0; i < block.width(); ++i) - inArrs[i] = INPUT_VARIABLE(i); + helpers::mergeMaxIndex(block.launchContext(), inArrs, *output); - helpers::mergeMaxIndex(block.launchContext(), inArrs, *output); - - return Status::OK(); + return Status::OK(); } DECLARE_SYN(MergeMaxIndex, mergemaxindex); - DECLARE_TYPES(mergemaxindex) { - getOpDescriptor() - ->setAllowedInputTypes({ALL_INTS, ALL_FLOATS}) - ->setAllowedOutputTypes({ALL_INDICES}); - } +DECLARE_TYPES(mergemaxindex) { + getOpDescriptor() + ->setAllowedInputTypes({ALL_INTS, ALL_FLOATS}) + ->setAllowedOutputTypes({ALL_INDICES}); } + DECLARE_SHAPE_FN(mergemaxindex) { - auto in = inputShape->at(0); - auto dtype = DataType::INT32; - if (block.getIArguments()->size()> 0) - dtype = (DataType)INT_ARG(0); + auto in = inputShape->at(0); + auto dtype = DataType::INT32; + if (block.numI() > 0) dtype = (DataType)INT_ARG(0); - auto resShape = ShapeBuilders::copyShapeInfoAndType(in, dtype, block.workspace()); - return SHAPELIST(CONSTANT(resShape)); -} + auto resShape = ShapeBuilders::copyShapeInfoAndType(in, dtype, block.workspace()); + return SHAPELIST(CONSTANT(resShape)); } +} // namespace ops +} // namespace sd + #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/transforms/mirrorPad.cpp b/libnd4j/include/ops/declarable/generic/transforms/mirrorPad.cpp index 403272530e0f..611e24d6f2a9 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/mirrorPad.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/mirrorPad.cpp @@ -21,82 +21,132 @@ #if NOT_EXCLUDED(OP_mirror_pad) #include -#include +#include namespace sd { namespace ops { - CUSTOM_OP_IMPL(mirror_pad, 2, 1, false, 0, 1) { - - auto input = INPUT_VARIABLE(0); - auto paddings = INPUT_VARIABLE(1); - - auto output = OUTPUT_VARIABLE(0); - - const int mode = INT_ARG(0); // 0 - REFLECT, else - SYMMETRIC - const int includeBorder = mode ? 0 : 1; - - if(input->rankOf() <= 1) { // when input is scalar or vector; - REQUIRE_TRUE(paddings->lengthOf() == 2, 0, "MIRROR_PAD OP: the length of paddings array must be equal 2, when input array is vector or scalar, bot but got %i instead !", paddings->rankOf()); - REQUIRE_TRUE( (paddings->e(0) <= (input->lengthOf() - includeBorder)) && (paddings->e(1) <= (input->lengthOf() - includeBorder)), 0, "MIRROR_PAD OP: wrong content of paddings array, its elements must be no grater then length of input array (being vector or scalar) for symmetric mode (or length-1 for reflect mode) !"); - } - else { - REQUIRE_TRUE(paddings->rankOf() == 2, 0, "MIRROR_PAD OP: the rank of paddings array must be equal 2, but got %i instead !", paddings->rankOf()); - REQUIRE_TRUE(paddings->sizeAt(0) == input->rankOf(), 0, "MIRROR_PAD OP: zero dimension of paddings array must be equal to input array rank, but got %i and %i correspondingly !", paddings->sizeAt(0), input->rankOf()); - - for(int i = 0; i < input->rankOf(); ++i) - REQUIRE_TRUE( (paddings->e(i,0) <= (input->sizeAt(i) - includeBorder)) && (paddings->e(i,1) <= (input->sizeAt(i) - includeBorder)), 0, "MIRROR_PAD OP: wrong content of paddings array, its elements must be no grater then corresponding dimension of input array for symmetric mode (or dimension-1 for reflect mode) !"); - } - - helpers::mirrorPad(block.launchContext(), *input, *paddings, *output, mode); - - return Status::OK(); + auto input = INPUT_VARIABLE(0); + auto paddings = INPUT_VARIABLE(1); + + auto output = OUTPUT_VARIABLE(0); + + const int mode = INT_ARG(0); // 0 - REFLECT, else - SYMMETRIC + const int includeBorder = mode ? 0 : 1; + + if (input->rankOf() <= 1) { // when input is scalar or vector; + REQUIRE_TRUE( + paddings->lengthOf() == 2, 0, + "MIRROR_PAD OP: the length of paddings array must be equal 2, when " + "input array is vector or scalar, bot but got %i instead !", + paddings->rankOf()); + REQUIRE_TRUE( + (paddings->e(0) <= (input->lengthOf() - includeBorder)) && + (paddings->e(1) <= (input->lengthOf() - includeBorder)), + 0, + "MIRROR_PAD OP: wrong content of paddings array, its elements must be " + "no grater then length of input array (being vector or scalar) for " + "symmetric mode (or length-1 for reflect mode) !"); + } else { + REQUIRE_TRUE(paddings->rankOf() == 2, 0, + "MIRROR_PAD OP: the rank of paddings array must be equal 2, " + "but got %i instead !", + paddings->rankOf()); + REQUIRE_TRUE( + paddings->sizeAt(0) == input->rankOf(), 0, + "MIRROR_PAD OP: zero dimension of paddings array must be equal to " + "input array rank, but got %i and %i correspondingly !", + paddings->sizeAt(0), input->rankOf()); + + for (int i = 0; i < input->rankOf(); ++i) + REQUIRE_TRUE( + (paddings->e(i, 0) <= (input->sizeAt(i) - includeBorder)) && + (paddings->e(i, 1) <= + (input->sizeAt(i) - includeBorder)), + 0, + "MIRROR_PAD OP: wrong content of paddings array, its elements must " + "be no grater then corresponding dimension of input array for " + "symmetric mode (or dimension-1 for reflect mode) !"); + } + + helpers::mirrorPad(block.launchContext(), *input, *paddings, *output, mode); + + return Status::OK(); } - DECLARE_TYPES(mirror_pad) { - getOpDescriptor()->setAllowedInputTypes(0, {ALL_FLOATS}); - getOpDescriptor()->setAllowedInputTypes(1, {DataType::INT32, DataType::INT64}); // to conform with TF - getOpDescriptor()->setAllowedOutputTypes(0, {ALL_FLOATS}); - } +DECLARE_TYPES(mirror_pad) { + getOpDescriptor()->setAllowedInputTypes(0, {ALL_FLOATS}); + getOpDescriptor()->setAllowedInputTypes( + 1, {DataType::INT32, DataType::INT64}); // to conform with TF + getOpDescriptor()->setAllowedOutputTypes(0, {ALL_FLOATS}); +} DECLARE_SHAPE_FN(mirror_pad) { - - auto input = INPUT_VARIABLE(0); - auto paddings = INPUT_VARIABLE(1); - - const int rank = input->rankOf() ? input->rankOf() : 1; // if scalar is input then vector is output - const int includeBorder = static_cast(INT_ARG(0)) ? 0 : 1; - - if(rank == 1) { // when input is scalar or vector; - REQUIRE_TRUE(paddings->lengthOf() == 2, 0, "MIRROR_PAD OP: the length of paddings array must be equal 2, when input array is vector or scalar, bot but got %i instead !", paddings->rankOf()); - REQUIRE_TRUE( (paddings->e(0) <= (input->lengthOf() - includeBorder)) && (paddings->e(1) <= (input->lengthOf() - includeBorder)), 0, "MIRROR_PAD OP: wrong content of paddings array, its elements must be no grater then length of input array (being vector or scalar) for symmetric mode (or length-1 for reflect mode) !"); - } - else { - REQUIRE_TRUE(paddings->rankOf() == 2, 0, "MIRROR_PAD OP: the rank of paddings array must be equal 2, but got %i instead !", paddings->rankOf()); - REQUIRE_TRUE(paddings->sizeAt(0) == input->rankOf(), 0, "MIRROR_PAD OP: zero dimension of paddings array must be equal to input array rank, but got %i and %i correspondingly !", paddings->sizeAt(0), input->rankOf()); - for(int i = 0; i < input->rankOf(); ++i) - REQUIRE_TRUE( (paddings->e(i,0) <= (input->sizeAt(i) - includeBorder)) && (paddings->e(i,1) <= (input->sizeAt(i) - includeBorder)), 0, "MIRROR_PAD OP: wrong content of paddings array, its elements must be no grater then corresponding dimension of input array for symmetric mode (or dimension-1 for reflect mode) !"); - } - - if(rank == 1) { - Nd4jLong len = input->lengthOf() + paddings->e(0) + paddings->e(1); - return SHAPELIST(ConstantShapeHelper::getInstance().vectorShapeInfo(len, input->dataType())); - } - - Nd4jLong* outShapeInfo(nullptr); - - ALLOCATE(outShapeInfo, block.getWorkspace(), shape::shapeInfoLength(rank), Nd4jLong); - outShapeInfo[0] = rank; - for(int i = 0; i < rank; ++i) - outShapeInfo[i+1] = input->sizeAt(i) + paddings->e(i,0) + paddings->e(i,1); - ShapeUtils::updateStridesAndType(outShapeInfo, input->shapeInfo(), input->ordering()); - - return SHAPELIST(CONSTANT(outShapeInfo)); + auto input = INPUT_VARIABLE(0); + auto paddings = INPUT_VARIABLE(1); + + const int rank = input->rankOf() + ? input->rankOf() + : 1; // if scalar is input then vector is output + const int includeBorder = static_cast(INT_ARG(0)) ? 0 : 1; + + if (rank == 1) { // when input is scalar or vector; + REQUIRE_TRUE( + paddings->lengthOf() == 2, 0, + "MIRROR_PAD OP: the length of paddings array must be equal 2, when " + "input array is vector or scalar, bot but got %i instead !", + paddings->rankOf()); + REQUIRE_TRUE( + (paddings->e(0) <= (input->lengthOf() - includeBorder)) && + (paddings->e(1) <= (input->lengthOf() - includeBorder)), + 0, + "MIRROR_PAD OP: wrong content of paddings array, its elements must be " + "no grater then length of input array (being vector or scalar) for " + "symmetric mode (or length-1 for reflect mode) !"); + } else { + REQUIRE_TRUE(paddings->rankOf() == 2, 0, + "MIRROR_PAD OP: the rank of paddings array must be equal 2, " + "but got %i instead !", + paddings->rankOf()); + REQUIRE_TRUE( + paddings->sizeAt(0) == input->rankOf(), 0, + "MIRROR_PAD OP: zero dimension of paddings array must be equal to " + "input array rank, but got %i and %i correspondingly !", + paddings->sizeAt(0), input->rankOf()); + for (int i = 0; i < input->rankOf(); ++i) + REQUIRE_TRUE( + (paddings->e(i, 0) <= (input->sizeAt(i) - includeBorder)) && + (paddings->e(i, 1) <= + (input->sizeAt(i) - includeBorder)), + 0, + "MIRROR_PAD OP: wrong content of paddings array, its elements must " + "be no grater then corresponding dimension of input array for " + "symmetric mode (or dimension-1 for reflect mode) !"); + } + + if (rank == 1) { + Nd4jLong len = + input->lengthOf() + paddings->e(0) + paddings->e(1); + return SHAPELIST(ConstantShapeHelper::getInstance().vectorShapeInfo( + len, input->dataType())); + } + + Nd4jLong* outShapeInfo(nullptr); + + ALLOCATE(outShapeInfo, block.workspace(), shape::shapeInfoLength(rank), + Nd4jLong); + outShapeInfo[0] = rank; + for (int i = 0; i < rank; ++i) + outShapeInfo[i + 1] = input->sizeAt(i) + paddings->e(i, 0) + + paddings->e(i, 1); + ShapeUtils::updateStridesAndType(outShapeInfo, input->shapeInfo(), + input->ordering()); + + return SHAPELIST(CONSTANT(outShapeInfo)); } - -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/transforms/pad.cpp b/libnd4j/include/ops/declarable/generic/transforms/pad.cpp index d09063a95383..9ac4adb884fe 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/pad.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/pad.cpp @@ -22,94 +22,119 @@ #if NOT_EXCLUDED(OP_pad) #include -#include +#include + #include namespace sd { -namespace ops { - +namespace ops { ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(pad, 2, 1, false, 0, 1) { - - auto input = INPUT_VARIABLE(0); - auto paddings = INPUT_VARIABLE(1); - auto output = OUTPUT_VARIABLE(0); - - const int rank = input->rankOf(); - - // input validation - std::vector expectedPaddingsShape = {rank, 2}; - std::vector currentPaddingsShape = paddings->getShapeAsVector(); - REQUIRE_TRUE(expectedPaddingsShape == currentPaddingsShape, 0, "PAD op: wrong shape of paddings array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedPaddingsShape).c_str(), ShapeUtils::shapeAsString(currentPaddingsShape).c_str()); - - NDArray padValue(input->dataType(), block.launchContext()); - - // in case of REFLECT and SYMMETRIC modes paddings must obey additional shape requirements - if (INT_ARG(0) == 0) { // CONSTANT mode - if(block.width() > 2) { - REQUIRE_TRUE(input->dataType() == INPUT_VARIABLE(2)->dataType(), 0, "PAD op: data types of input and padValue arrays should be the same but got %i and %i correspondingly !", input->dataType(), INPUT_VARIABLE(2)->dataType()); - padValue.assign(INPUT_VARIABLE(2)->e(0)); - } - else if (!block.getTArguments()->empty()) - padValue = T_ARG(0); - } - else if(INT_ARG(0) == 1) { // REFLECT mode - for(int dim=0; dim < rank; ++dim) - REQUIRE_TRUE(paddings->e(dim,0) <= (input->shapeOf()[dim]-1) && paddings->e(dim,1) <= (input->shapeOf()[dim]-1), 0, "PAD op: wrong content of paddings array for REFLECT mode !"); - } - if(INT_ARG(0) == 2) { // SYMMETRIC mode - for(int dim=0; dim < rank; ++dim) - REQUIRE_TRUE(paddings->e(dim,0) <= input->shapeOf()[dim] && paddings->e(dim,1) <= input->shapeOf()[dim], 0, "PAD op: wrong content of paddings array for SYMMETRIC mode !"); - } - - // CONSTANT->0, REFLECT->1, SYMMETRIC->2 - REQUIRE_TRUE(INT_ARG(0) >= 0 && INT_ARG(0) <= 2, 0, "PAD op: unknown padding mode, there are only three possible legal values -> 0,1,2, but got %i instead !", INT_ARG(0)); - - // std::vector dimensions(input->rankOf()); - // std::iota(dimensions.begin(), dimensions.end(), 0); // fill with 0, 1, ... rank-1 - - // helpers::recursiveLoopForPad(INT_ARG(0), *input, *paddings, *output, dimensions, 0, 0, 0, padValue); - helpers::pad(block.launchContext(), INT_ARG(0), *input, *paddings, *output, padValue); - - return Status::OK(); + auto input = INPUT_VARIABLE(0); + auto paddings = INPUT_VARIABLE(1); + auto output = OUTPUT_VARIABLE(0); + + const int rank = input->rankOf(); + + // input validation + std::vector expectedPaddingsShape = {rank, 2}; + std::vector currentPaddingsShape = paddings->getShapeAsVector(); + REQUIRE_TRUE(expectedPaddingsShape == currentPaddingsShape, 0, + "PAD op: wrong shape of paddings array, expected is %s, but got " + "%s instead !", + ShapeUtils::shapeAsString(expectedPaddingsShape).c_str(), + ShapeUtils::shapeAsString(currentPaddingsShape).c_str()); + + NDArray padValue(input->dataType(), block.launchContext()); + + // in case of REFLECT and SYMMETRIC modes paddings must obey additional shape + // requirements + if (INT_ARG(0) == 0) { // CONSTANT mode + if (block.width() > 2) { + REQUIRE_TRUE(input->dataType() == INPUT_VARIABLE(2)->dataType(), 0, + "PAD op: data types of input and padValue arrays should be " + "the same but got %i and %i correspondingly !", + input->dataType(), INPUT_VARIABLE(2)->dataType()); + padValue.assign(INPUT_VARIABLE(2)->e(0)); + } else if (!block.getTArguments().empty()) + padValue = T_ARG(0); + } else if (INT_ARG(0) == 1) { // REFLECT mode + for (int dim = 0; dim < rank; ++dim) + REQUIRE_TRUE( + paddings->e(dim, 0) <= (input->shapeOf()[dim] - 1) && + paddings->e(dim, 1) <= (input->shapeOf()[dim] - 1), + 0, "PAD op: wrong content of paddings array for REFLECT mode !"); + } + if (INT_ARG(0) == 2) { // SYMMETRIC mode + for (int dim = 0; dim < rank; ++dim) + REQUIRE_TRUE( + paddings->e(dim, 0) <= input->shapeOf()[dim] && + paddings->e(dim, 1) <= input->shapeOf()[dim], + 0, "PAD op: wrong content of paddings array for SYMMETRIC mode !"); + } + + // CONSTANT->0, REFLECT->1, SYMMETRIC->2 + REQUIRE_TRUE(INT_ARG(0) >= 0 && INT_ARG(0) <= 2, 0, + "PAD op: unknown padding mode, there are only three possible " + "legal values -> 0,1,2, but got %i instead !", + INT_ARG(0)); + + // std::vector dimensions(input->rankOf()); + // std::iota(dimensions.begin(), dimensions.end(), 0); + // // fill with 0, 1, ... rank-1 + + // helpers::recursiveLoopForPad(INT_ARG(0), *input, *paddings, *output, + // dimensions, 0, 0, 0, padValue); + helpers::pad(block.launchContext(), INT_ARG(0), *input, *paddings, *output, + padValue); + + return Status::OK(); } DECLARE_TYPES(pad) { - getOpDescriptor() - ->setAllowedInputTypes(0, sd::DataType::ANY) - ->setAllowedInputTypes(1, {DataType::INT32, DataType::INT64}) // INT32 with TF -// ->setAllowedInputTypes(1, {DataType::INT32, DataType::INT64}) // INT32 with TF, but used also INT64 due long shapes - ->setSameMode(true); + getOpDescriptor() + ->setAllowedInputTypes(0, sd::DataType::ANY) + ->setAllowedInputTypes( + 1, {DataType::INT32, DataType::INT64}) // INT32 with TF + // ->setAllowedInputTypes(1, {DataType::INT32, DataType::INT64}) // + // INT32 with TF, but used also INT64 due long shapes + ->setSameMode(true); } DECLARE_SHAPE_FN(pad) { - - // check shape of paddings - auto inputShapeInfo = inputShape->at(0); - auto paddings = INPUT_VARIABLE(1); - const int rank = inputShapeInfo[0]; - - // paddings validation - const std::vector expectedPaddingsShape = {rank, 2}; - const std::vector currentPaddingsShape = paddings->getShapeAsVector(); - REQUIRE_TRUE(expectedPaddingsShape == currentPaddingsShape, 0, "PAD op: wrong shape of paddings array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedPaddingsShape).c_str(), ShapeUtils::shapeAsString(currentPaddingsShape).c_str()); - - Nd4jLong* outShapeInfo = nullptr; - ALLOCATE(outShapeInfo, block.getWorkspace(), shape::shapeInfoLength(rank), Nd4jLong); - outShapeInfo[0] = rank; - for(int i=1; i <= rank; ++i) - outShapeInfo[i] = inputShapeInfo[i] + paddings->e(i-1,0) + paddings->e(i-1,1); - - ShapeUtils::updateStridesAndType(outShapeInfo, inputShapeInfo, shape::order(inputShapeInfo)); - ShapeDescriptor descriptor(outShapeInfo); - RELEASE(outShapeInfo, block.getWorkspace()); - return SHAPELIST(ConstantShapeHelper::getInstance().createShapeInfo(descriptor)); + // check shape of paddings + auto inputShapeInfo = inputShape->at(0); + auto paddings = INPUT_VARIABLE(1); + const int rank = inputShapeInfo[0]; + + // paddings validation + const std::vector expectedPaddingsShape = {rank, 2}; + const std::vector currentPaddingsShape = + paddings->getShapeAsVector(); + REQUIRE_TRUE(expectedPaddingsShape == currentPaddingsShape, 0, + "PAD op: wrong shape of paddings array, expected is %s, but got " + "%s instead !", + ShapeUtils::shapeAsString(expectedPaddingsShape).c_str(), + ShapeUtils::shapeAsString(currentPaddingsShape).c_str()); + + Nd4jLong* outShapeInfo = nullptr; + ALLOCATE(outShapeInfo, block.workspace(), shape::shapeInfoLength(rank), + Nd4jLong); + outShapeInfo[0] = rank; + for (int i = 1; i <= rank; ++i) + outShapeInfo[i] = inputShapeInfo[i] + paddings->e(i - 1, 0) + + paddings->e(i - 1, 1); + + ShapeUtils::updateStridesAndType(outShapeInfo, inputShapeInfo, + shape::order(inputShapeInfo)); + ShapeDescriptor descriptor(outShapeInfo); + RELEASE(outShapeInfo, block.workspace()); + return SHAPELIST( + ConstantShapeHelper::getInstance().createShapeInfo(descriptor)); } - - -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/transforms/parallelStack.cpp b/libnd4j/include/ops/declarable/generic/transforms/parallelStack.cpp index 46572d88eb48..65817f6ad680 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/parallelStack.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/parallelStack.cpp @@ -22,56 +22,57 @@ #if NOT_EXCLUDED(OP_parallel_stack) #include -#include +#include namespace sd { -namespace ops { - +namespace ops { CUSTOM_OP_IMPL(parallel_stack, -1, 1, false, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - // check whether shapes of all input array are the same - for (int i = 0; i < (int) block.width() - 1; ++i) - REQUIRE_TRUE(shape::equalsSoft((INPUT_VARIABLE(i))->shapeInfo(), (INPUT_VARIABLE(i+1))->shapeInfo()), 0, "PARALLEL_STACK op: the shapes of all input arrays must be the same !"); + // check whether shapes of all input array are the same + for (int i = 0; i < (int)block.width() - 1; ++i) + REQUIRE_TRUE( + shape::equalsSoft((INPUT_VARIABLE(i))->shapeInfo(), + (INPUT_VARIABLE(i + 1))->shapeInfo()), + 0, + "PARALLEL_STACK op: the shapes of all input arrays must be the same !"); - std::vector inArrs(block.width()); - for(int i = 0; i < block.width(); ++i) - inArrs[i] = INPUT_VARIABLE(i); + std::vector inArrs(block.width()); + for (int i = 0; i < block.width(); ++i) inArrs[i] = INPUT_VARIABLE(i); - const int dim = 0; - helpers::stack(block.launchContext(), inArrs, *output, dim); + const int dim = 0; + helpers::stack(block.launchContext(), inArrs, *output, dim); - return Status::OK(); + return Status::OK(); } - DECLARE_TYPES(parallel_stack) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } +DECLARE_TYPES(parallel_stack) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} DECLARE_SHAPE_FN(parallel_stack) { + auto inShapeInfo = inputShape->at(0); + int rank = inShapeInfo[0]; - auto inShapeInfo = inputShape->at(0); - int rank = inShapeInfo[0]; - - Nd4jLong* outShapeInfo = nullptr; - ALLOCATE(outShapeInfo, block.getWorkspace(), shape::shapeInfoLength(rank+1), Nd4jLong); + Nd4jLong* outShapeInfo = nullptr; + ALLOCATE(outShapeInfo, block.workspace(), shape::shapeInfoLength(rank + 1), + Nd4jLong); - outShapeInfo[0] = rank + 1; - outShapeInfo[1] = block.width(); - for(int i = 1; i <= rank; ++i) - outShapeInfo[i+1] = inShapeInfo[i]; + outShapeInfo[0] = rank + 1; + outShapeInfo[1] = block.width(); + for (int i = 1; i <= rank; ++i) outShapeInfo[i + 1] = inShapeInfo[i]; - ShapeUtils::updateStridesAndType(outShapeInfo, inShapeInfo, shape::order(inShapeInfo)); + ShapeUtils::updateStridesAndType(outShapeInfo, inShapeInfo, + shape::order(inShapeInfo)); - return SHAPELIST(CONSTANT(outShapeInfo)); + return SHAPELIST(CONSTANT(outShapeInfo)); } - -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/transforms/repeat.cpp b/libnd4j/include/ops/declarable/generic/transforms/repeat.cpp index b02f7010c7f1..49421f1abf85 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/repeat.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/repeat.cpp @@ -24,52 +24,58 @@ #include namespace sd { -namespace ops { +namespace ops { ////////////////////////////////////////////////////////////////////////// -// here iArgs is int vector of repeats at the beginning and last element in iArgs is dimension +// here iArgs is int vector of repeats at the beginning and last element in +// iArgs is dimension CUSTOM_OP_IMPL(repeat, 1, 1, true, 0, -1) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); + std::vector repeats = block.getIArguments(); - std::vector repeats = *block.getIArguments(); + const int axis = + repeats.back() < 0 ? repeats.back() + input->rankOf() : repeats.back(); - const int axis = repeats.back() < 0 ? repeats.back() + input->rankOf() : repeats.back(); + repeats.pop_back(); - repeats.pop_back(); + REQUIRE_TRUE(0 <= axis && axis < input->rankOf(), 0, + "CUSTOM REPEAT OP: wrong axis argument it should be less then " + "input array rank %i, but got %i instead !", + input->rankOf(), axis); - REQUIRE_TRUE(0 <= axis && axis < input->rankOf(), 0, "CUSTOM REPEAT OP: wrong axis argument it should be less then input array rank %i, but got %i instead !", input->rankOf(), axis); + REQUIRE_TRUE(repeats.size() == 1 || repeats.size() == input->sizeAt(axis), 0, + "CUSTOM REPEAT OP: wrong axis argument, size of repeats vector " + "must be 1 or equal to dimension at given axis, but got " + "repeats.size = %i and axis = %i !", + repeats.size(), axis); - REQUIRE_TRUE(repeats.size() == 1 || repeats.size() == input->sizeAt(axis), 0, "CUSTOM REPEAT OP: wrong axis argument, size of repeats vector must be 1 or equal to dimension at given axis, but got repeats.size = %i and axis = %i !", repeats.size(), axis); + input->repeat(axis, repeats, *output); - input->repeat(axis, repeats, *output); - - return Status::OK(); + return Status::OK(); } DECLARE_TYPES(repeat) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setSameMode(true); + getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setSameMode(true); } DECLARE_SHAPE_FN(repeat) { + auto input = INPUT_VARIABLE(0); - auto input = INPUT_VARIABLE(0); - - std::vector repeats = *block.getIArguments(); - - const int axis = repeats.back() < 0 ? repeats.back() + input->rankOf() : repeats.back(); + auto repeats = block.getIArguments(); - repeats.pop_back(); + const int axis = + repeats.back() < 0 ? repeats.back() + input->rankOf() : repeats.back(); - auto outShape = ShapeUtils::evalRepeatShape(axis, repeats, *input); + repeats.pop_back(); - return SHAPELIST(ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(input->dataType(), input->ordering(), outShape))); + auto outShape = ShapeUtils::evalRepeatShape(axis, repeats, *input); + return SHAPELIST(ConstantShapeHelper::getInstance().createShapeInfo( + ShapeDescriptor(input->dataType(), input->ordering(), outShape))); } -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/transforms/reverse.cpp b/libnd4j/include/ops/declarable/generic/transforms/reverse.cpp index e8f659c5def0..c1d758df80d9 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/reverse.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/reverse.cpp @@ -24,88 +24,84 @@ #include #include - namespace sd { -namespace ops { - - CONFIGURABLE_OP_IMPL(reverse, 1, 1, true, 0, -2) { - - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - if(output->isEmpty()){ - //No-op - return Status::OK(); - } - - std::vector axis; - - if (block.width() > 1) - axis = INPUT_VARIABLE(1)->template asVectorT(); - else if (block.numI() > 0) - axis = *block.getIArguments(); - - if(axis.empty()) { // do not perform reversion - if (!block.isInplace()) - output->assign(input); - } - else { - // check the consistency of input dimensions to reverse along - shape::checkDimensions(input->rankOf(), axis); - helpers::reverse(block.launchContext(), input, output, &axis); - } - - return Status::OK(); - } - - DECLARE_SYN(reverse_v2, reverse); - - DECLARE_TYPES(reverse) { - getOpDescriptor()->setAllowedInputTypes(0, DataType::ANY); - getOpDescriptor()->setAllowedInputTypes(1, {DataType::INT32, DataType::INT64}); - getOpDescriptor()->setAllowedOutputTypes(0, DataType::INHERIT); - } - - CUSTOM_OP_IMPL(reverse_bp, 2, 1, false, 0, -2) { - auto input = INPUT_VARIABLE(0); - auto eps = block.width() == 3 ? INPUT_VARIABLE(2) : INPUT_VARIABLE(1); - - auto output = OUTPUT_VARIABLE(0); - std::vector axis; - - if (block.width() == 3) - axis = INPUT_VARIABLE(1)->template asVectorT(); - else if (block.numI() > 0) - axis = *block.getIArguments(); - - if(axis.empty()) { // reversion is not performed in this case - output->assign(eps); - } - else { - // check the consistency of input dimensions to reverse along - shape::checkDimensions(input->rankOf(), axis); - // we just reverse back original array - helpers::reverse(block.launchContext(), eps, output, &axis); - } - - return Status::OK(); - } - - DECLARE_TYPES(reverse_bp) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } - - DECLARE_SHAPE_FN(reverse_bp) { - auto in = inputShape->at(0); - Nd4jLong *out; - COPY_SHAPE(in, out); - - return SHAPELIST(CONSTANT(out)); - } +namespace ops { + +CONFIGURABLE_OP_IMPL(reverse, 1, 1, true, 0, -2) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + if (output->isEmpty()) { + // No-op + return Status::OK(); + } + + std::vector axis; + + if (block.width() > 1) + axis = INPUT_VARIABLE(1)->template asVectorT(); + else if (block.numI() > 0) + axis = block.getIArguments(); + + if (axis.empty()) { // do not perform reversion + if (!block.isInplace()) output->assign(input); + } else { + // check the consistency of input dimensions to reverse along + shape::checkDimensions(input->rankOf(), axis); + helpers::reverse(block.launchContext(), input, output, &axis); + } + + return Status::OK(); +} + +DECLARE_SYN(reverse_v2, reverse); + +DECLARE_TYPES(reverse) { + getOpDescriptor()->setAllowedInputTypes(0, DataType::ANY); + getOpDescriptor()->setAllowedInputTypes(1, + {DataType::INT32, DataType::INT64}); + getOpDescriptor()->setAllowedOutputTypes(0, DataType::INHERIT); +} + +CUSTOM_OP_IMPL(reverse_bp, 2, 1, false, 0, -2) { + auto input = INPUT_VARIABLE(0); + auto eps = block.width() == 3 ? INPUT_VARIABLE(2) : INPUT_VARIABLE(1); + + auto output = OUTPUT_VARIABLE(0); + std::vector axis; + + if (block.width() == 3) + axis = INPUT_VARIABLE(1)->template asVectorT(); + else if (block.numI() > 0) + axis = block.getIArguments(); + + if (axis.empty()) { // reversion is not performed in this case + output->assign(eps); + } else { + // check the consistency of input dimensions to reverse along + shape::checkDimensions(input->rankOf(), axis); + // we just reverse back original array + helpers::reverse(block.launchContext(), eps, output, &axis); + } + return Status::OK(); } + +DECLARE_TYPES(reverse_bp) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } +DECLARE_SHAPE_FN(reverse_bp) { + auto in = inputShape->at(0); + Nd4jLong *out; + COPY_SHAPE(in, out); + + return SHAPELIST(CONSTANT(out)); +} + +} // namespace ops +} // namespace sd + #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/transforms/reverseSequence.cpp b/libnd4j/include/ops/declarable/generic/transforms/reverseSequence.cpp index c7dcc6e36dbf..058a5d631e4f 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/reverseSequence.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/reverseSequence.cpp @@ -28,57 +28,99 @@ namespace sd { namespace ops { CUSTOM_OP_IMPL(reverse_sequence, 2, 1, false, 0, 2) { - - auto input = INPUT_VARIABLE(0); - auto seqLengths = INPUT_VARIABLE(1); - auto output = OUTPUT_VARIABLE(0); - - int seqDim = INT_ARG(0); - int batchDim = block.numI() > 1 ? INT_ARG(1) : 0; - - REQUIRE_TRUE(input->rankOf() > 1, 0, "REVERSE_SEQUENSE operation: input array must have rank > 1, but got %i instead !", input->rankOf()); - REQUIRE_TRUE(seqLengths->rankOf() == 1, 0, "REVERSE_SEQUENSE operation: input array seqLengths must be 1D vector, that is it must have rank == 1, but got %i instead !", seqLengths->rankOf()); - REQUIRE_TRUE(seqLengths->lengthOf() == input->sizeAt(batchDim), 0, "REVERSE_SEQUENSE custom operation: the length of array seqLengths must be equal to the value of batchDim dimension of input array, but got %i and %i correspondingly !", seqLengths->lengthOf(), input->sizeAt(batchDim)); - REQUIRE_TRUE(seqDim != batchDim, 0, "REVERSE_SEQUENSE operation: input integer parameters seqDim and batchDim must be different, but they both are equal to %i !", batchDim); - REQUIRE_TRUE(batchDim < input->rankOf(), 0, "REVERSE_SEQUENSE operation: input integer parameter batchDim must be smaller than input array rank, but got %i and %i correspondingly !", batchDim, input->rankOf()); - REQUIRE_TRUE(seqDim < input->rankOf(), 0, "REVERSE_SEQUENSE operation: input integer parameter seqDim must be smaller than input array rank, but got %i and %i correspondingly !", seqDim, input->rankOf()); - - auto maxElem = seqLengths->reduceNumber(reduce::Max); - REQUIRE_TRUE(maxElem.e(0) <= input->sizeAt(seqDim), 0, "REVERSE_SEQUENSE operation: max element in seqLengths array must be not greater than value of seqDim dimension of input array !"); - - helpers::reverseSequence(block.launchContext(), input, seqLengths, output, seqDim, batchDim); - - return Status::OK(); + auto input = INPUT_VARIABLE(0); + auto seqLengths = INPUT_VARIABLE(1); + auto output = OUTPUT_VARIABLE(0); + + int seqDim = INT_ARG(0); + int batchDim = block.numI() > 1 ? INT_ARG(1) : 0; + + REQUIRE_TRUE(input->rankOf() > 1, 0, + "REVERSE_SEQUENSE operation: input array must have rank > 1, " + "but got %i instead !", + input->rankOf()); + REQUIRE_TRUE(seqLengths->rankOf() == 1, 0, + "REVERSE_SEQUENSE operation: input array seqLengths must be 1D " + "vector, that is it must have rank == 1, but got %i instead !", + seqLengths->rankOf()); + REQUIRE_TRUE(seqLengths->lengthOf() == input->sizeAt(batchDim), 0, + "REVERSE_SEQUENSE custom operation: the length of array " + "seqLengths must be equal to the value of batchDim dimension of " + "input array, but got %i and %i correspondingly !", + seqLengths->lengthOf(), input->sizeAt(batchDim)); + REQUIRE_TRUE( + seqDim != batchDim, 0, + "REVERSE_SEQUENSE operation: input integer parameters seqDim and " + "batchDim must be different, but they both are equal to %i !", + batchDim); + REQUIRE_TRUE( + batchDim < input->rankOf(), 0, + "REVERSE_SEQUENSE operation: input integer parameter batchDim must be " + "smaller than input array rank, but got %i and %i correspondingly !", + batchDim, input->rankOf()); + REQUIRE_TRUE( + seqDim < input->rankOf(), 0, + "REVERSE_SEQUENSE operation: input integer parameter seqDim must be " + "smaller than input array rank, but got %i and %i correspondingly !", + seqDim, input->rankOf()); + + auto maxElem = seqLengths->reduceNumber(reduce::Max); + REQUIRE_TRUE( + maxElem.e(0) <= input->sizeAt(seqDim), 0, + "REVERSE_SEQUENSE operation: max element in seqLengths array must be not " + "greater than value of seqDim dimension of input array !"); + + helpers::reverseSequence(block.launchContext(), input, seqLengths, output, + seqDim, batchDim); + + return Status::OK(); } - DECLARE_TYPES(reverse_sequence) { - getOpDescriptor()->setAllowedInputTypes(0, {ALL_FLOATS, ALL_INTS}); - getOpDescriptor()->setAllowedInputTypes(1, {DataType::INT32, DataType::INT64}); - getOpDescriptor()->setAllowedOutputTypes(0, DataType::INHERIT); - } +DECLARE_TYPES(reverse_sequence) { + getOpDescriptor()->setAllowedInputTypes(0, {ALL_FLOATS, ALL_INTS}); + getOpDescriptor()->setAllowedInputTypes(1, + {DataType::INT32, DataType::INT64}); + getOpDescriptor()->setAllowedOutputTypes(0, DataType::INHERIT); +} DECLARE_SHAPE_FN(reverse_sequence) { - - auto inShapeInfo = inputShape->at(0); - auto seqLenShapeInfo = inputShape->at(1); - - int seqDim = INT_ARG(0); - int batchDim = block.numI() > 1 ? INT_ARG(1) : 0; - - REQUIRE_TRUE(batchDim < inShapeInfo[0], 0, "REVERSE_SEQUENSE operation: input integer parameter batchDim must be smaller than input array rank, but got %i and %i correspondingly !", batchDim, inShapeInfo[0]); - REQUIRE_TRUE(seqDim < inShapeInfo[0], 0, "REVERSE_SEQUENSE operation: input integer parameter seqDim must be smaller than input array rank, but got %i and %i correspondingly !", seqDim, inShapeInfo[0]); - REQUIRE_TRUE(inShapeInfo[0] > 1, 0, "REVERSE_SEQUENSE operation: input array must have rank > 1, but got %i instead !", inShapeInfo[0]); - REQUIRE_TRUE(seqLenShapeInfo[0] == 1, 0, "REVERSE_SEQUENSE operation: input array seqLengths must be 1D vector, that is it must have rank == 1, but got %i instead !", seqLenShapeInfo[0]); - REQUIRE_TRUE(seqLenShapeInfo[1] == inShapeInfo[batchDim+1], 0, "REVERSE_SEQUENSE custom operation: the length of array seqLengths must be equal to the value of batchDim dimension of input array, but got %i and %i correspondingly !", seqLenShapeInfo[1], inShapeInfo[batchDim+1]); - - Nd4jLong* outShapeInfo = nullptr; - COPY_SHAPE(inShapeInfo, outShapeInfo); - - return SHAPELIST(CONSTANT(outShapeInfo)); + auto inShapeInfo = inputShape->at(0); + auto seqLenShapeInfo = inputShape->at(1); + + int seqDim = INT_ARG(0); + int batchDim = block.numI() > 1 ? INT_ARG(1) : 0; + + REQUIRE_TRUE( + batchDim < inShapeInfo[0], 0, + "REVERSE_SEQUENSE operation: input integer parameter batchDim must be " + "smaller than input array rank, but got %i and %i correspondingly !", + batchDim, inShapeInfo[0]); + REQUIRE_TRUE( + seqDim < inShapeInfo[0], 0, + "REVERSE_SEQUENSE operation: input integer parameter seqDim must be " + "smaller than input array rank, but got %i and %i correspondingly !", + seqDim, inShapeInfo[0]); + REQUIRE_TRUE(inShapeInfo[0] > 1, 0, + "REVERSE_SEQUENSE operation: input array must have rank > 1, " + "but got %i instead !", + inShapeInfo[0]); + REQUIRE_TRUE(seqLenShapeInfo[0] == 1, 0, + "REVERSE_SEQUENSE operation: input array seqLengths must be 1D " + "vector, that is it must have rank == 1, but got %i instead !", + seqLenShapeInfo[0]); + REQUIRE_TRUE(seqLenShapeInfo[1] == inShapeInfo[batchDim + 1], 0, + "REVERSE_SEQUENSE custom operation: the length of array " + "seqLengths must be equal to the value of batchDim dimension of " + "input array, but got %i and %i correspondingly !", + seqLenShapeInfo[1], inShapeInfo[batchDim + 1]); + + Nd4jLong* outShapeInfo = nullptr; + COPY_SHAPE(inShapeInfo, outShapeInfo); + + return SHAPELIST(CONSTANT(outShapeInfo)); } - -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/transforms/scatter_add.cpp b/libnd4j/include/ops/declarable/generic/transforms/scatter_add.cpp index e624afeb1e57..e0a8f48241df 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/scatter_add.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/scatter_add.cpp @@ -29,73 +29,90 @@ namespace sd { namespace ops { OP_IMPL(scatter_add, 3, 1, true) { - auto input = INPUT_VARIABLE(0); - auto indices = INPUT_VARIABLE(1); - auto updates = INPUT_VARIABLE(2); - - auto output = OUTPUT_VARIABLE(0); - - if (!block.isInplace()) - output->assign(input); - - const bool lock = block.getBArguments()->empty() ? false : B_ARG(0); - const bool checkIndices = block.getBArguments()->size() <= 1 ? false : B_ARG(1); - - const int inRank = input->rankOf(); - const int indRank = indices->rankOf(); - const int updRank = updates->rankOf(); - const Nd4jLong indLen = indices->lengthOf(); - - REQUIRE_TRUE(inRank > 0, 0, "SCATTER_ADD OP: input should not be scalar !"); - - if(inRank == 1) { - REQUIRE_TRUE(indices->isSameShape(updates), 0, "SCATTER_ADD OP: when input array has rank = 1 then indices and updates must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(indices).c_str(), ShapeUtils::shapeAsString(updates).c_str()); - } - else if (inRank == updRank && indices->isVector()) { - - std::vector updShape = updates->getShapeAsVector(); - std::vector inShape = input->getShapeAsVector(); - std::vector expectedUpdShape = {indices->lengthOf()}; - expectedUpdShape.insert(expectedUpdShape.end(), inShape.begin()+1, inShape.end()); - - REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_ADD OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str()); + auto input = INPUT_VARIABLE(0); + auto indices = INPUT_VARIABLE(1); + auto updates = INPUT_VARIABLE(2); + + auto output = OUTPUT_VARIABLE(0); + + if (!block.isInplace()) output->assign(input); + + const bool lock = block.getBArguments().empty() ? false : B_ARG(0); + const bool checkIndices = block.numB() <= 1 ? false : B_ARG(1); + + const int inRank = input->rankOf(); + const int indRank = indices->rankOf(); + const int updRank = updates->rankOf(); + const Nd4jLong indLen = indices->lengthOf(); + + REQUIRE_TRUE(inRank > 0, 0, "SCATTER_ADD OP: input should not be scalar !"); + + if (inRank == 1) { + REQUIRE_TRUE(indices->isSameShape(updates), 0, + "SCATTER_ADD OP: when input array has rank = 1 then indices " + "and updates must have the same shapes, but got %s and %s " + "correspondingly !", + ShapeUtils::shapeAsString(indices).c_str(), + ShapeUtils::shapeAsString(updates).c_str()); + } else if (inRank == updRank && indices->isVector()) { + std::vector updShape = updates->getShapeAsVector(); + std::vector inShape = input->getShapeAsVector(); + std::vector expectedUpdShape = {indices->lengthOf()}; + expectedUpdShape.insert(expectedUpdShape.end(), inShape.begin() + 1, + inShape.end()); + + REQUIRE_TRUE(expectedUpdShape == updShape, 0, + "SCATTER_ADD OP: wrong shape of updates array, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString(expectedUpdShape).c_str(), + ShapeUtils::shapeAsString(updShape).c_str()); + } else { + REQUIRE_TRUE(updRank == indRank + inRank - 1, 0, + "SCATTER_ADD OP: wrong rank of updates array, expected is %i, " + "but got %i instead !", + indRank + inRank - 1, updRank); + + std::vector updShape = updates->getShapeAsVector(); + std::vector inShape = input->getShapeAsVector(); + std::vector expectedUpdShape = indices->getShapeAsVector(); + expectedUpdShape.insert(expectedUpdShape.end(), + inShape.begin() + Nd4jLong(1L), inShape.end()); + + REQUIRE_TRUE(expectedUpdShape == updShape, 0, + "SCATTER_ADD OP: wrong shape of updates array, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString(expectedUpdShape).c_str(), + ShapeUtils::shapeAsString(updShape).c_str()); + } + + if (!indices->isEmpty()) { + if (checkIndices) { + const Nd4jLong numOfBadIndx = + helpers::checkIndices(block.launchContext(), *indices, *output, 0); + REQUIRE_TRUE(numOfBadIndx == 0, 0, + "SCATTER_ADD OP: please check elements of indices-array, " + "total number of wrong elements is %lld!", + numOfBadIndx); } - else { - REQUIRE_TRUE(updRank == indRank + inRank - 1, 0, "SCATTER_ADD OP: wrong rank of updates array, expected is %i, but got %i instead !", indRank + inRank - 1 , updRank); + helpers::scatter(block.launchContext(), pairwise::Add, *indices, *updates, + *output, lock); + } - std::vector updShape = updates->getShapeAsVector(); - std::vector inShape = input->getShapeAsVector(); - std::vector expectedUpdShape = indices->getShapeAsVector(); - expectedUpdShape.insert(expectedUpdShape.end(), inShape.begin() + Nd4jLong(1L), inShape.end()); - - REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_ADD OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str()); - } - - if (!indices->isEmpty()) { - - if(checkIndices) { - const Nd4jLong numOfBadIndx = helpers::checkIndices(block.launchContext(), *indices, *output, 0); - REQUIRE_TRUE(numOfBadIndx == 0, 0, "SCATTER_ADD OP: please check elements of indices-array, total number of wrong elements is %lld!", numOfBadIndx); - } - - helpers::scatter(block.launchContext(), pairwise::Add, *indices, *updates, *output, lock); - } - - return Status::OK(); + return Status::OK(); } DECLARE_SYN(ScatterAdd, scatter_add); DECLARE_TYPES(scatter_add) { - getOpDescriptor() - ->setAllowedInputTypes(0, {ALL_INTS, ALL_FLOATS}) - ->setAllowedInputTypes(1, {ALL_INTS}) - ->setAllowedInputTypes(2, {ALL_INTS, ALL_FLOATS}) - ->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_INTS, ALL_FLOATS}) + ->setAllowedInputTypes(1, {ALL_INTS}) + ->setAllowedInputTypes(2, {ALL_INTS, ALL_FLOATS}) + ->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS}); } -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/transforms/scatter_div.cpp b/libnd4j/include/ops/declarable/generic/transforms/scatter_div.cpp index fd0b2a7305a3..71b8483cb8d2 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/scatter_div.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/scatter_div.cpp @@ -26,73 +26,89 @@ #include namespace sd { - namespace ops { - OP_IMPL(scatter_div, 3, 1, true) { - - auto input = INPUT_VARIABLE(0); - auto indices = INPUT_VARIABLE(1); - auto updates = INPUT_VARIABLE(2); - - auto output = OUTPUT_VARIABLE(0); - - if (!block.isInplace()) - output->assign(input); - - const bool lock = block.getBArguments()->empty() ? false : B_ARG(0); - const bool checkIndices = block.getBArguments()->size() <= 1 ? false : B_ARG(1); - - const int inRank = input->rankOf(); - const int indRank = indices->rankOf(); - const int updRank = updates->rankOf(); - - REQUIRE_TRUE(inRank > 0, 0, "SCATTER_DIV OP: input should not be scalar !"); - - if(inRank == 1) { - REQUIRE_TRUE(indices->isSameShape(updates), 0, "SCATTER_DIV OP: when input array has rank = 1 then indices and updates must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(indices).c_str(), ShapeUtils::shapeAsString(updates).c_str()); - } - else if (inRank == updRank && indices->isVector()) { - - std::vector updShape = updates->getShapeAsVector(); - std::vector inShape = input->getShapeAsVector(); - std::vector expectedUpdShape = {indices->lengthOf()}; - expectedUpdShape.insert(expectedUpdShape.end(), inShape.begin()+1, inShape.end()); - - REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_DIV OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str()); - } - else { - - REQUIRE_TRUE(updRank == indRank + inRank - 1, 0, "SCATTER_DIV OP: wrong rank of updates array, expected is %i, but got %i instead !", indRank + inRank - 1 , updRank); - - std::vector updShape = updates->getShapeAsVector(); - std::vector inShape = input->getShapeAsVector(); - std::vector expectedUpdShape = indices->getShapeAsVector(); - expectedUpdShape.insert(expectedUpdShape.end(), inShape.begin()+1, inShape.end()); - - REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_DIV OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str()); - } - - if (!indices->isEmpty()) { - - if(checkIndices) { - const Nd4jLong numOfBadIndx = helpers::checkIndices(block.launchContext(), *indices, *output, 0); - REQUIRE_TRUE(numOfBadIndx == 0, 0, "SCATTER_DIV OP: please check elements of indices-array, total number of wrong elements is %lld!", numOfBadIndx); - } - - helpers::scatter(block.launchContext(), pairwise::Divide, *indices, *updates, *output, lock); - } +namespace ops { +OP_IMPL(scatter_div, 3, 1, true) { + auto input = INPUT_VARIABLE(0); + auto indices = INPUT_VARIABLE(1); + auto updates = INPUT_VARIABLE(2); + + auto output = OUTPUT_VARIABLE(0); + + if (!block.isInplace()) output->assign(input); + + const bool lock = block.getBArguments().empty() ? false : B_ARG(0); + const bool checkIndices = block.numB() <= 1 ? false : B_ARG(1); + + const int inRank = input->rankOf(); + const int indRank = indices->rankOf(); + const int updRank = updates->rankOf(); + + REQUIRE_TRUE(inRank > 0, 0, "SCATTER_DIV OP: input should not be scalar !"); + + if (inRank == 1) { + REQUIRE_TRUE(indices->isSameShape(updates), 0, + "SCATTER_DIV OP: when input array has rank = 1 then indices " + "and updates must have the same shapes, but got %s and %s " + "correspondingly !", + ShapeUtils::shapeAsString(indices).c_str(), + ShapeUtils::shapeAsString(updates).c_str()); + } else if (inRank == updRank && indices->isVector()) { + std::vector updShape = updates->getShapeAsVector(); + std::vector inShape = input->getShapeAsVector(); + std::vector expectedUpdShape = {indices->lengthOf()}; + expectedUpdShape.insert(expectedUpdShape.end(), inShape.begin() + 1, + inShape.end()); + + REQUIRE_TRUE(expectedUpdShape == updShape, 0, + "SCATTER_DIV OP: wrong shape of updates array, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString(expectedUpdShape).c_str(), + ShapeUtils::shapeAsString(updShape).c_str()); + } else { + REQUIRE_TRUE(updRank == indRank + inRank - 1, 0, + "SCATTER_DIV OP: wrong rank of updates array, expected is %i, " + "but got %i instead !", + indRank + inRank - 1, updRank); + + std::vector updShape = updates->getShapeAsVector(); + std::vector inShape = input->getShapeAsVector(); + std::vector expectedUpdShape = indices->getShapeAsVector(); + expectedUpdShape.insert(expectedUpdShape.end(), inShape.begin() + 1, + inShape.end()); + + REQUIRE_TRUE(expectedUpdShape == updShape, 0, + "SCATTER_DIV OP: wrong shape of updates array, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString(expectedUpdShape).c_str(), + ShapeUtils::shapeAsString(updShape).c_str()); + } + + if (!indices->isEmpty()) { + if (checkIndices) { + const Nd4jLong numOfBadIndx = + helpers::checkIndices(block.launchContext(), *indices, *output, 0); + REQUIRE_TRUE(numOfBadIndx == 0, 0, + "SCATTER_DIV OP: please check elements of indices-array, " + "total number of wrong elements is %lld!", + numOfBadIndx); + } - return Status::OK(); - } - DECLARE_SYN(ScatterDiv, scatter_div); + helpers::scatter(block.launchContext(), pairwise::Divide, *indices, + *updates, *output, lock); + } - DECLARE_TYPES(scatter_div) { - getOpDescriptor() - ->setAllowedInputTypes(0, {ALL_INTS, ALL_FLOATS}) - ->setAllowedInputTypes(1, {ALL_INTS}) - ->setAllowedInputTypes(2, {ALL_INTS, ALL_FLOATS}) - ->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS}); - } - } + return Status::OK(); +} +DECLARE_SYN(ScatterDiv, scatter_div); + +DECLARE_TYPES(scatter_div) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_INTS, ALL_FLOATS}) + ->setAllowedInputTypes(1, {ALL_INTS}) + ->setAllowedInputTypes(2, {ALL_INTS, ALL_FLOATS}) + ->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS}); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/transforms/scatter_max.cpp b/libnd4j/include/ops/declarable/generic/transforms/scatter_max.cpp index b3342c5a58ac..370e95a5e1e0 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/scatter_max.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/scatter_max.cpp @@ -28,72 +28,88 @@ namespace sd { namespace ops { OP_IMPL(scatter_max, 3, 1, true) { - - auto input = INPUT_VARIABLE(0); - auto indices = INPUT_VARIABLE(1); - auto updates = INPUT_VARIABLE(2); - - auto output = OUTPUT_VARIABLE(0); - - if (!block.isInplace()) - output->assign(input); - - const bool lock = block.getBArguments()->empty() ? false : B_ARG(0); - const bool checkIndices = block.getBArguments()->size() <= 1 ? false : B_ARG(1); - - const int inRank = input->rankOf(); - const int indRank = indices->rankOf(); - const int updRank = updates->rankOf(); - - REQUIRE_TRUE(inRank > 0, 0, "SCATTER_MAX OP: input should not be scalar !"); - - if(inRank == 1) { - REQUIRE_TRUE(indices->isSameShape(updates), 0, "SCATTER_MAX OP: when input array has rank = 1 then indices and updates must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(indices).c_str(), ShapeUtils::shapeAsString(updates).c_str()); + auto input = INPUT_VARIABLE(0); + auto indices = INPUT_VARIABLE(1); + auto updates = INPUT_VARIABLE(2); + + auto output = OUTPUT_VARIABLE(0); + + if (!block.isInplace()) output->assign(input); + + const bool lock = block.getBArguments().empty() ? false : B_ARG(0); + const bool checkIndices = block.numB() <= 1 ? false : B_ARG(1); + + const int inRank = input->rankOf(); + const int indRank = indices->rankOf(); + const int updRank = updates->rankOf(); + + REQUIRE_TRUE(inRank > 0, 0, "SCATTER_MAX OP: input should not be scalar !"); + + if (inRank == 1) { + REQUIRE_TRUE(indices->isSameShape(updates), 0, + "SCATTER_MAX OP: when input array has rank = 1 then indices " + "and updates must have the same shapes, but got %s and %s " + "correspondingly !", + ShapeUtils::shapeAsString(indices).c_str(), + ShapeUtils::shapeAsString(updates).c_str()); + } else if (inRank == updRank && indices->isVector()) { + std::vector updShape = updates->getShapeAsVector(); + std::vector inShape = input->getShapeAsVector(); + std::vector expectedUpdShape = {indices->lengthOf()}; + expectedUpdShape.insert(expectedUpdShape.end(), inShape.begin() + 1, + inShape.end()); + + REQUIRE_TRUE(expectedUpdShape == updShape, 0, + "SCATTER_MAX OP: wrong shape of updates array, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString(expectedUpdShape).c_str(), + ShapeUtils::shapeAsString(updShape).c_str()); + } else { + REQUIRE_TRUE(updRank == indRank + inRank - 1, 0, + "SCATTER_MAX OP: wrong rank of updates array, expected is %i, " + "but got %i instead !", + indRank + inRank - 1, updRank); + + std::vector updShape = updates->getShapeAsVector(); + std::vector inShape = input->getShapeAsVector(); + std::vector expectedUpdShape = indices->getShapeAsVector(); + expectedUpdShape.insert(expectedUpdShape.end(), inShape.begin() + 1, + inShape.end()); + + REQUIRE_TRUE(expectedUpdShape == updShape, 0, + "SCATTER_MAX OP: wrong shape of updates array, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString(expectedUpdShape).c_str(), + ShapeUtils::shapeAsString(updShape).c_str()); + } + + if (!indices->isEmpty()) { + if (checkIndices) { + const Nd4jLong numOfBadIndx = + helpers::checkIndices(block.launchContext(), *indices, *output, 0); + REQUIRE_TRUE(numOfBadIndx == 0, 0, + "SCATTER_MAX OP: please check elements of indices-array, " + "total number of wrong elements is %lld!", + numOfBadIndx); } - else if (inRank == updRank && indices->isVector()) { - - std::vector updShape = updates->getShapeAsVector(); - std::vector inShape = input->getShapeAsVector(); - std::vector expectedUpdShape = {indices->lengthOf()}; - expectedUpdShape.insert(expectedUpdShape.end(), inShape.begin()+1, inShape.end()); - - REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_MAX OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str()); - } - else { - - REQUIRE_TRUE(updRank == indRank + inRank - 1, 0, "SCATTER_MAX OP: wrong rank of updates array, expected is %i, but got %i instead !", indRank + inRank - 1 , updRank); - std::vector updShape = updates->getShapeAsVector(); - std::vector inShape = input->getShapeAsVector(); - std::vector expectedUpdShape = indices->getShapeAsVector(); - expectedUpdShape.insert(expectedUpdShape.end(), inShape.begin()+1, inShape.end()); + helpers::scatter(block.launchContext(), pairwise::MaxPairwise, *indices, + *updates, *output, lock); + } - REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_MAX OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str()); - } - - if (!indices->isEmpty()) { - - if(checkIndices) { - const Nd4jLong numOfBadIndx = helpers::checkIndices(block.launchContext(), *indices, *output, 0); - REQUIRE_TRUE(numOfBadIndx == 0, 0, "SCATTER_MAX OP: please check elements of indices-array, total number of wrong elements is %lld!", numOfBadIndx); - } - - helpers::scatter(block.launchContext(), pairwise::MaxPairwise, *indices, *updates, *output, lock); - } - - return Status::OK(); + return Status::OK(); } DECLARE_SYN(ScatterMax, scatter_max); - DECLARE_TYPES(scatter_max) { - getOpDescriptor() - ->setAllowedInputTypes(0, {ALL_INTS, ALL_FLOATS}) - ->setAllowedInputTypes(1, {ALL_INTS}) - ->setAllowedInputTypes(2, {ALL_INTS, ALL_FLOATS}) - ->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS}); - } - -} +DECLARE_TYPES(scatter_max) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_INTS, ALL_FLOATS}) + ->setAllowedInputTypes(1, {ALL_INTS}) + ->setAllowedInputTypes(2, {ALL_INTS, ALL_FLOATS}) + ->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS}); } +} // namespace ops +} // namespace sd + #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/transforms/scatter_min.cpp b/libnd4j/include/ops/declarable/generic/transforms/scatter_min.cpp index d37adb692a4e..069a8f10d400 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/scatter_min.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/scatter_min.cpp @@ -28,72 +28,88 @@ namespace sd { namespace ops { OP_IMPL(scatter_min, 3, 1, true) { - - auto input = INPUT_VARIABLE(0); - auto indices = INPUT_VARIABLE(1); - auto updates = INPUT_VARIABLE(2); - - auto output = OUTPUT_VARIABLE(0); - - if (!block.isInplace()) - output->assign(input); - - const bool lock = block.getBArguments()->empty() ? false : B_ARG(0); - const bool checkIndices = block.getBArguments()->size() <= 1 ? false : B_ARG(1); - - const int inRank = input->rankOf(); - const int indRank = indices->rankOf(); - const int updRank = updates->rankOf(); - - REQUIRE_TRUE(inRank > 0, 0, "SCATTER_MIN OP: input should not be scalar !"); - - if(inRank == 1) { - REQUIRE_TRUE(indices->isSameShape(updates), 0, "SCATTER_MIN OP: when input array has rank = 1 then indices and updates must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(indices).c_str(), ShapeUtils::shapeAsString(updates).c_str()); + auto input = INPUT_VARIABLE(0); + auto indices = INPUT_VARIABLE(1); + auto updates = INPUT_VARIABLE(2); + + auto output = OUTPUT_VARIABLE(0); + + if (!block.isInplace()) output->assign(input); + + const bool lock = block.getBArguments().empty() ? false : B_ARG(0); + const bool checkIndices = block.numB() <= 1 ? false : B_ARG(1); + + const int inRank = input->rankOf(); + const int indRank = indices->rankOf(); + const int updRank = updates->rankOf(); + + REQUIRE_TRUE(inRank > 0, 0, "SCATTER_MIN OP: input should not be scalar !"); + + if (inRank == 1) { + REQUIRE_TRUE(indices->isSameShape(updates), 0, + "SCATTER_MIN OP: when input array has rank = 1 then indices " + "and updates must have the same shapes, but got %s and %s " + "correspondingly !", + ShapeUtils::shapeAsString(indices).c_str(), + ShapeUtils::shapeAsString(updates).c_str()); + } else if (inRank == updRank && indices->isVector()) { + std::vector updShape = updates->getShapeAsVector(); + std::vector inShape = input->getShapeAsVector(); + std::vector expectedUpdShape = {indices->lengthOf()}; + expectedUpdShape.insert(expectedUpdShape.end(), inShape.begin() + 1, + inShape.end()); + + REQUIRE_TRUE(expectedUpdShape == updShape, 0, + "SCATTER_MIN OP: wrong shape of updates array, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString(expectedUpdShape).c_str(), + ShapeUtils::shapeAsString(updShape).c_str()); + } else { + REQUIRE_TRUE(updRank == indRank + inRank - 1, 0, + "SCATTER_MIN OP: wrong rank of updates array, expected is %i, " + "but got %i instead !", + indRank + inRank - 1, updRank); + + std::vector updShape = updates->getShapeAsVector(); + std::vector inShape = input->getShapeAsVector(); + std::vector expectedUpdShape = indices->getShapeAsVector(); + expectedUpdShape.insert(expectedUpdShape.end(), inShape.begin() + 1, + inShape.end()); + + REQUIRE_TRUE(expectedUpdShape == updShape, 0, + "SCATTER_MIN OP: wrong shape of updates array, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString(expectedUpdShape).c_str(), + ShapeUtils::shapeAsString(updShape).c_str()); + } + + if (!indices->isEmpty()) { + if (checkIndices) { + const Nd4jLong numOfBadIndx = + helpers::checkIndices(block.launchContext(), *indices, *output, 0); + REQUIRE_TRUE(numOfBadIndx == 0, 0, + "SCATTER_MIN OP: please check elements of indices-array, " + "total number of wrong elements is %lld!", + numOfBadIndx); } - else if (inRank == updRank && indices->isVector()) { - - std::vector updShape = updates->getShapeAsVector(); - std::vector inShape = input->getShapeAsVector(); - std::vector expectedUpdShape = {indices->lengthOf()}; - expectedUpdShape.insert(expectedUpdShape.end(), inShape.begin()+1, inShape.end()); - - REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_MIN OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str()); - } - else { - - REQUIRE_TRUE(updRank == indRank + inRank - 1, 0, "SCATTER_MIN OP: wrong rank of updates array, expected is %i, but got %i instead !", indRank + inRank - 1 , updRank); - std::vector updShape = updates->getShapeAsVector(); - std::vector inShape = input->getShapeAsVector(); - std::vector expectedUpdShape = indices->getShapeAsVector(); - expectedUpdShape.insert(expectedUpdShape.end(), inShape.begin()+1, inShape.end()); + helpers::scatter(block.launchContext(), pairwise::MinPairwise, *indices, + *updates, *output, lock); + } - REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_MIN OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str()); - } - - if (!indices->isEmpty()) { - - if(checkIndices) { - const Nd4jLong numOfBadIndx = helpers::checkIndices(block.launchContext(), *indices, *output, 0); - REQUIRE_TRUE(numOfBadIndx == 0, 0, "SCATTER_MIN OP: please check elements of indices-array, total number of wrong elements is %lld!", numOfBadIndx); - } - - helpers::scatter(block.launchContext(), pairwise::MinPairwise, *indices, *updates, *output, lock); - } - - return Status::OK(); + return Status::OK(); } DECLARE_SYN(ScatterMin, scatter_min); - DECLARE_TYPES(scatter_min) { - getOpDescriptor() - ->setAllowedInputTypes(0, {ALL_INTS, ALL_FLOATS}) - ->setAllowedInputTypes(1, {ALL_INTS}) - ->setAllowedInputTypes(2, {ALL_INTS, ALL_FLOATS}) - ->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS}); - } - -} +DECLARE_TYPES(scatter_min) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_INTS, ALL_FLOATS}) + ->setAllowedInputTypes(1, {ALL_INTS}) + ->setAllowedInputTypes(2, {ALL_INTS, ALL_FLOATS}) + ->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS}); } +} // namespace ops +} // namespace sd + #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/transforms/scatter_mul.cpp b/libnd4j/include/ops/declarable/generic/transforms/scatter_mul.cpp index 9bf5be7487b8..29c65052706b 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/scatter_mul.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/scatter_mul.cpp @@ -26,74 +26,89 @@ #include namespace sd { - namespace ops { - OP_IMPL(scatter_mul, 3, 1, true) { - auto input = INPUT_VARIABLE(0); - auto indices = INPUT_VARIABLE(1); - auto updates = INPUT_VARIABLE(2); - - auto output = OUTPUT_VARIABLE(0); - - const bool lock = block.getBArguments()->empty() ? false : B_ARG(0); - const bool checkIndices = block.getBArguments()->size() <= 1 ? false : B_ARG(1); - - const int inRank = input->rankOf(); - const int indRank = indices->rankOf(); - const int updRank = updates->rankOf(); - - if (!block.isInplace()) - output->assign(input); - - - REQUIRE_TRUE(inRank > 0, 0, "SCATTER_MUL OP: input should not be scalar !"); - - if(inRank == 1) { - REQUIRE_TRUE(indices->isSameShape(updates), 0, "SCATTER_MUL OP: when input array has rank = 1 then indices and updates must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(indices).c_str(), ShapeUtils::shapeAsString(updates).c_str()); - } - else if (inRank == updRank && indices->isVector()) { - - std::vector updShape = updates->getShapeAsVector(); - std::vector inShape = input->getShapeAsVector(); - std::vector expectedUpdShape = {indices->lengthOf()}; - expectedUpdShape.insert(expectedUpdShape.end(), inShape.begin()+1, inShape.end()); - - REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_MUL OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str()); - } - else { - - REQUIRE_TRUE(updRank == indRank + inRank - 1, 0, "SCATTER_MUL OP: wrong rank of updates array, expected is %i, but got %i instead !", indRank + inRank - 1 , updRank); - - std::vector updShape = updates->getShapeAsVector(); - std::vector inShape = input->getShapeAsVector(); - std::vector expectedUpdShape = indices->getShapeAsVector(); - expectedUpdShape.insert(expectedUpdShape.end(), inShape.begin()+1, inShape.end()); - - REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_MUL OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str()); - } - - if (!indices->isEmpty()) { - - if(checkIndices) { - const Nd4jLong numOfBadIndx = helpers::checkIndices(block.launchContext(), *indices, *output, 0); - REQUIRE_TRUE(numOfBadIndx == 0, 0, "SCATTER_MUL OP: please check elements of indices-array, total number of wrong elements is %lld!", numOfBadIndx); - } - - helpers::scatter(block.launchContext(), pairwise::Multiply, *indices, *updates, *output, lock); - } - - return Status::OK(); - } - DECLARE_SYN(ScatterMul, scatter_mul); +namespace ops { +OP_IMPL(scatter_mul, 3, 1, true) { + auto input = INPUT_VARIABLE(0); + auto indices = INPUT_VARIABLE(1); + auto updates = INPUT_VARIABLE(2); + + auto output = OUTPUT_VARIABLE(0); + + const bool lock = block.getBArguments().empty() ? false : B_ARG(0); + const bool checkIndices = block.numB() <= 1 ? false : B_ARG(1); + + const int inRank = input->rankOf(); + const int indRank = indices->rankOf(); + const int updRank = updates->rankOf(); + + if (!block.isInplace()) output->assign(input); + + REQUIRE_TRUE(inRank > 0, 0, "SCATTER_MUL OP: input should not be scalar !"); + + if (inRank == 1) { + REQUIRE_TRUE(indices->isSameShape(updates), 0, + "SCATTER_MUL OP: when input array has rank = 1 then indices " + "and updates must have the same shapes, but got %s and %s " + "correspondingly !", + ShapeUtils::shapeAsString(indices).c_str(), + ShapeUtils::shapeAsString(updates).c_str()); + } else if (inRank == updRank && indices->isVector()) { + std::vector updShape = updates->getShapeAsVector(); + std::vector inShape = input->getShapeAsVector(); + std::vector expectedUpdShape = {indices->lengthOf()}; + expectedUpdShape.insert(expectedUpdShape.end(), inShape.begin() + 1, + inShape.end()); + + REQUIRE_TRUE(expectedUpdShape == updShape, 0, + "SCATTER_MUL OP: wrong shape of updates array, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString(expectedUpdShape).c_str(), + ShapeUtils::shapeAsString(updShape).c_str()); + } else { + REQUIRE_TRUE(updRank == indRank + inRank - 1, 0, + "SCATTER_MUL OP: wrong rank of updates array, expected is %i, " + "but got %i instead !", + indRank + inRank - 1, updRank); + + std::vector updShape = updates->getShapeAsVector(); + std::vector inShape = input->getShapeAsVector(); + std::vector expectedUpdShape = indices->getShapeAsVector(); + expectedUpdShape.insert(expectedUpdShape.end(), inShape.begin() + 1, + inShape.end()); + + REQUIRE_TRUE(expectedUpdShape == updShape, 0, + "SCATTER_MUL OP: wrong shape of updates array, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString(expectedUpdShape).c_str(), + ShapeUtils::shapeAsString(updShape).c_str()); + } + + if (!indices->isEmpty()) { + if (checkIndices) { + const Nd4jLong numOfBadIndx = + helpers::checkIndices(block.launchContext(), *indices, *output, 0); + REQUIRE_TRUE(numOfBadIndx == 0, 0, + "SCATTER_MUL OP: please check elements of indices-array, " + "total number of wrong elements is %lld!", + numOfBadIndx); + } - DECLARE_TYPES(scatter_mul) { - getOpDescriptor() - ->setAllowedInputTypes(0, {ALL_INTS, ALL_FLOATS}) - ->setAllowedInputTypes(1, {ALL_INTS}) - ->setAllowedInputTypes(2, {ALL_INTS, ALL_FLOATS}) - ->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS}); + helpers::scatter(block.launchContext(), pairwise::Multiply, *indices, + *updates, *output, lock); + } - } - } + return Status::OK(); +} +DECLARE_SYN(ScatterMul, scatter_mul); + +DECLARE_TYPES(scatter_mul) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_INTS, ALL_FLOATS}) + ->setAllowedInputTypes(1, {ALL_INTS}) + ->setAllowedInputTypes(2, {ALL_INTS, ALL_FLOATS}) + ->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS}); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/transforms/scatter_nd.cpp b/libnd4j/include/ops/declarable/generic/transforms/scatter_nd.cpp index 7c2194c6c9da..34eb2f1c0ca6 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/scatter_nd.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/scatter_nd.cpp @@ -27,73 +27,99 @@ namespace sd { namespace ops { - CUSTOM_OP_IMPL(scatter_nd, 3, 1, false, 0, 0) { - auto indices = INPUT_VARIABLE(0); - auto updates = INPUT_VARIABLE(1); - auto shape = INPUT_VARIABLE(2); - - auto output = OUTPUT_VARIABLE(0); - - const bool lock = block.getBArguments()->empty() ? false : B_ARG(0); - const bool checkIndices = block.getBArguments()->size() <= 1 ? false : B_ARG(1); - - const int indRank = indices->rankOf(); - const int updRank = updates->rankOf(); - const int shapeRank = shape->rankOf(); - const Nd4jLong shapeLen = shape->lengthOf(); - - REQUIRE_TRUE(shapeRank == 1, 0, "SCATTER_ND OP: the rank of shape array must be 1, but got %i instead !", shapeRank); - REQUIRE_TRUE(indices->sizeAt(-1) <= shapeLen, 0, "SCATTER_ND OP: last dimension of indices array must be <= length of shape array, but got %i and %i correspondingly !", indices->sizeAt(-1), shapeLen); - // REQUIRE_TRUE(updRank == (indRank + shapeLen - 2), 0, "SCATTER_ND OP: the equality updates_rank = (indices_rank + shape_length - 2) must be true for input arrays, but got instead: updates_rank = %i, indices_rank = %i, shape_length = %i !", updRank, indRank, shapeLen); - REQUIRE_TRUE(updRank == (indRank - 1 + shapeLen - indices->sizeAt(-1)), 0, "SCATTER_ND OP: the equality updates_rank = (indices_rank - 1 + shape_length - last_indices_dimension) must be true for input arrays, but got instead: updates_rank = %i, shape_length = %i, last_indices_dimension = %i !", updRank, shapeLen, indices->sizeAt(-1)); - - std::vector outShape = shape->getBufferAsVector(); - std::vector updShape = updates->getShapeAsVector(); - std::vector indShape = indices->getShapeAsVector(); - std::vector expectedUpdShape(std::begin(indShape), std::end(indShape) - 1); - std::move(std::begin(outShape) + indices->sizeAt(-1), std::end(outShape), std::back_inserter(expectedUpdShape)); - REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_ND OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str()); - - if(checkIndices) { - const Nd4jLong numOfBadIndx = helpers::checkIndices(block.launchContext(), *indices, *output); - REQUIRE_TRUE(numOfBadIndx == 0, 0, "SCATTER_ND OP: please check elements of indices-array, total number of wrong elements is %lld!", numOfBadIndx); - } - - // initial zeroing of output - *output = 0; - - helpers::scatterND(block.launchContext(), pairwise::Add, *indices, *updates, *output, lock); - - return Status::OK(); - } - - DECLARE_TYPES(scatter_nd) { - getOpDescriptor() - ->setAllowedInputTypes(0, {ALL_INTS}) - ->setAllowedInputTypes(1, {ALL_INTS, ALL_FLOATS}) - ->setAllowedInputTypes(2, {ALL_INTS}) - ->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS}); - } - -//////////////////////////////////////////////////////////////////////// - DECLARE_SHAPE_FN(scatter_nd) { +CUSTOM_OP_IMPL(scatter_nd, 3, 1, false, 0, 0) { + auto indices = INPUT_VARIABLE(0); + auto updates = INPUT_VARIABLE(1); + auto shape = INPUT_VARIABLE(2); + + auto output = OUTPUT_VARIABLE(0); + + const bool lock = block.getBArguments().empty() ? false : B_ARG(0); + const bool checkIndices = block.numB() <= 1 ? false : B_ARG(1); + + const int indRank = indices->rankOf(); + const int updRank = updates->rankOf(); + const int shapeRank = shape->rankOf(); + const Nd4jLong shapeLen = shape->lengthOf(); + + REQUIRE_TRUE( + shapeRank == 1, 0, + "SCATTER_ND OP: the rank of shape array must be 1, but got %i instead !", + shapeRank); + REQUIRE_TRUE(indices->sizeAt(-1) <= shapeLen, 0, + "SCATTER_ND OP: last dimension of indices array must be <= " + "length of shape array, but got %i and %i correspondingly !", + indices->sizeAt(-1), shapeLen); + // REQUIRE_TRUE(updRank == (indRank + shapeLen - 2), 0, "SCATTER_ND OP: the + // equality updates_rank = (indices_rank + shape_length - 2) must be true for + // input arrays, but got instead: updates_rank = %i, indices_rank = %i, + // shape_length = %i !", updRank, indRank, shapeLen); + REQUIRE_TRUE(updRank == (indRank - 1 + shapeLen - indices->sizeAt(-1)), 0, + "SCATTER_ND OP: the equality updates_rank = (indices_rank - 1 + " + "shape_length - last_indices_dimension) must be true for input " + "arrays, but got instead: updates_rank = %i, shape_length = %i, " + "last_indices_dimension = %i !", + updRank, shapeLen, indices->sizeAt(-1)); + + std::vector outShape = shape->getBufferAsVector(); + std::vector updShape = updates->getShapeAsVector(); + std::vector indShape = indices->getShapeAsVector(); + std::vector expectedUpdShape(std::begin(indShape), + std::end(indShape) - 1); + std::move(std::begin(outShape) + indices->sizeAt(-1), std::end(outShape), + std::back_inserter(expectedUpdShape)); + REQUIRE_TRUE(expectedUpdShape == updShape, 0, + "SCATTER_ND OP: wrong shape of updates array, expected is %s, " + "but got %s instead !", + ShapeUtils::shapeAsString(expectedUpdShape).c_str(), + ShapeUtils::shapeAsString(updShape).c_str()); + + if (checkIndices) { + const Nd4jLong numOfBadIndx = + helpers::checkIndices(block.launchContext(), *indices, *output); + REQUIRE_TRUE(numOfBadIndx == 0, 0, + "SCATTER_ND OP: please check elements of indices-array, total " + "number of wrong elements is %lld!", + numOfBadIndx); + } + + // initial zeroing of output + *output = 0; + + helpers::scatterND(block.launchContext(), pairwise::Add, *indices, *updates, + *output, lock); + + return Status::OK(); +} - auto shape = INPUT_VARIABLE(2); - auto updShapeInfo = inputShape->at(1); +DECLARE_TYPES(scatter_nd) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_INTS}) + ->setAllowedInputTypes(1, {ALL_INTS, ALL_FLOATS}) + ->setAllowedInputTypes(2, {ALL_INTS}) + ->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS}); +} - Nd4jLong *outShapeInfo; - ALLOCATE(outShapeInfo, block.getWorkspace(), shape::shapeInfoLength(shape->lengthOf()), Nd4jLong); +//////////////////////////////////////////////////////////////////////// +DECLARE_SHAPE_FN(scatter_nd) { + auto shape = INPUT_VARIABLE(2); + auto updShapeInfo = inputShape->at(1); - outShapeInfo[0] = shape->lengthOf(); - for (int i = 0; i < outShapeInfo[0]; ++i) - outShapeInfo[i + 1] = shape->e(i); + Nd4jLong *outShapeInfo; + ALLOCATE(outShapeInfo, block.workspace(), + shape::shapeInfoLength(shape->lengthOf()), Nd4jLong); - ShapeUtils::updateStridesAndType(outShapeInfo, updShapeInfo, shape::order(updShapeInfo)); + outShapeInfo[0] = shape->lengthOf(); + for (int i = 0; i < outShapeInfo[0]; ++i) + outShapeInfo[i + 1] = shape->e(i); - return SHAPELIST(CONSTANT(outShapeInfo)); - } + ShapeUtils::updateStridesAndType(outShapeInfo, updShapeInfo, + shape::order(updShapeInfo)); + return SHAPELIST(CONSTANT(outShapeInfo)); } -} + +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/generic/transforms/scatter_nd_add.cpp b/libnd4j/include/ops/declarable/generic/transforms/scatter_nd_add.cpp index 8fb4288ee86d..946494955dbf 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/scatter_nd_add.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/scatter_nd_add.cpp @@ -25,57 +25,75 @@ #include namespace sd { -namespace ops { +namespace ops { OP_IMPL(scatter_nd_add, 3, 1, true) { - auto input = INPUT_VARIABLE(0); - auto indices = INPUT_VARIABLE(1); - auto updates = INPUT_VARIABLE(2); - - auto output = OUTPUT_VARIABLE(0); - - const bool lock = block.getBArguments()->empty() ? false : B_ARG(0); - const bool checkIndices = block.getBArguments()->size() <= 1 ? false : B_ARG(1); - - const int inRank = input->rankOf(); - const int indRank = indices->rankOf(); - const int updRank = updates->rankOf(); - - const Nd4jLong indLastDim = indices->sizeAt(-1); - - REQUIRE_TRUE(indLastDim <= inRank, 0, "SCATTER_ND_ADD OP: the last dimension of indices array must be <= input_array_rank, but got %i instead !", indLastDim); - REQUIRE_TRUE(updRank == (indRank - 1 + inRank - indLastDim), 0, "SCATTER_ND_ADD OP: the equality updates_rank = (indices_rank - 1 + input_rank - last_indices_dimension) must be true for input arrays, but got instead: updates_rank = %i, indices_rank = %i, last_indices_dimension = %i !", updRank, indRank, indLastDim); - - std::vector inShape = input->getShapeAsVector(); - std::vector updShape = updates->getShapeAsVector(); - std::vector indShape = indices->getShapeAsVector(); - std::vector expectedUpdShape(std::begin(indShape), std::end(indShape) - 1); - if(inRank > indLastDim) - std::move(std::begin(inShape) + indLastDim, std::end(inShape), std::back_inserter(expectedUpdShape)); - REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_ND_ADD OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str()); - - if(checkIndices) { - const Nd4jLong numOfBadIndx = helpers::checkIndices(block.launchContext(), *indices, *output); - REQUIRE_TRUE(numOfBadIndx == 0, 0, "SCATTER_ND_ADD OP: please check elements of indices-array, total number of wrong elements is %lld!", numOfBadIndx); - } - - if (!block.isInplace()) - output->assign(input); - - helpers::scatterND(block.launchContext(), pairwise::Add, *indices, *updates, *output, lock); - - return Status::OK(); + auto input = INPUT_VARIABLE(0); + auto indices = INPUT_VARIABLE(1); + auto updates = INPUT_VARIABLE(2); + + auto output = OUTPUT_VARIABLE(0); + + const bool lock = block.getBArguments().empty() ? false : B_ARG(0); + const bool checkIndices = block.numB() <= 1 ? false : B_ARG(1); + + const int inRank = input->rankOf(); + const int indRank = indices->rankOf(); + const int updRank = updates->rankOf(); + + const Nd4jLong indLastDim = indices->sizeAt(-1); + + REQUIRE_TRUE(indLastDim <= inRank, 0, + "SCATTER_ND_ADD OP: the last dimension of indices array must be " + "<= input_array_rank, but got %i instead !", + indLastDim); + REQUIRE_TRUE(updRank == (indRank - 1 + inRank - indLastDim), 0, + "SCATTER_ND_ADD OP: the equality updates_rank = (indices_rank - " + "1 + input_rank - last_indices_dimension) must be true for " + "input arrays, but got instead: updates_rank = %i, indices_rank " + "= %i, last_indices_dimension = %i !", + updRank, indRank, indLastDim); + + std::vector inShape = input->getShapeAsVector(); + std::vector updShape = updates->getShapeAsVector(); + std::vector indShape = indices->getShapeAsVector(); + std::vector expectedUpdShape(std::begin(indShape), + std::end(indShape) - 1); + if (inRank > indLastDim) + std::move(std::begin(inShape) + indLastDim, std::end(inShape), + std::back_inserter(expectedUpdShape)); + REQUIRE_TRUE(expectedUpdShape == updShape, 0, + "SCATTER_ND_ADD OP: wrong shape of updates array, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString(expectedUpdShape).c_str(), + ShapeUtils::shapeAsString(updShape).c_str()); + + if (checkIndices) { + const Nd4jLong numOfBadIndx = + helpers::checkIndices(block.launchContext(), *indices, *output); + REQUIRE_TRUE(numOfBadIndx == 0, 0, + "SCATTER_ND_ADD OP: please check elements of indices-array, " + "total number of wrong elements is %lld!", + numOfBadIndx); + } + + if (!block.isInplace()) output->assign(input); + + helpers::scatterND(block.launchContext(), pairwise::Add, *indices, *updates, + *output, lock); + + return Status::OK(); } - DECLARE_TYPES(scatter_nd_add) { - getOpDescriptor() - ->setAllowedInputTypes(0, {ALL_INTS, ALL_FLOATS}) - ->setAllowedInputTypes(1, {ALL_INTS}) - ->setAllowedInputTypes(2, {ALL_INTS, ALL_FLOATS}) - ->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS}); - } - -} +DECLARE_TYPES(scatter_nd_add) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_INTS, ALL_FLOATS}) + ->setAllowedInputTypes(1, {ALL_INTS}) + ->setAllowedInputTypes(2, {ALL_INTS, ALL_FLOATS}) + ->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS}); } +} // namespace ops +} // namespace sd + #endif diff --git a/libnd4j/include/ops/declarable/generic/transforms/scatter_nd_sub.cpp b/libnd4j/include/ops/declarable/generic/transforms/scatter_nd_sub.cpp index 6cfa5d0463c9..1ed2154109e3 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/scatter_nd_sub.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/scatter_nd_sub.cpp @@ -25,57 +25,75 @@ #include namespace sd { -namespace ops { +namespace ops { OP_IMPL(scatter_nd_sub, 3, 1, true) { - auto input = INPUT_VARIABLE(0); - auto indices = INPUT_VARIABLE(1); - auto updates = INPUT_VARIABLE(2); - - auto output = OUTPUT_VARIABLE(0); - - const bool lock = block.getBArguments()->empty() ? false : B_ARG(0); - const bool checkIndices = block.getBArguments()->size() <= 1 ? false : B_ARG(1); - - const int inRank = input->rankOf(); - const int indRank = indices->rankOf(); - const int updRank = updates->rankOf(); - - const Nd4jLong indLastDim = indices->sizeAt(-1); - - REQUIRE_TRUE(indLastDim <= inRank, 0, "SCATTER_ND_SUB OP: the last dimension of indices array must be <= input_array_rank, but got %i instead !", indLastDim); - REQUIRE_TRUE(updRank == (indRank - 1 + inRank - indLastDim), 0, "SCATTER_ND_SUB OP: the equality updates_rank = (indices_rank - 1 + input_rank - last_indices_dimension) must be true for input arrays, but got instead: updates_rank = %i, indices_rank = %i, last_indices_dimension = %i !", updRank, indRank, indLastDim); - - std::vector inShape = input->getShapeAsVector(); - std::vector updShape = updates->getShapeAsVector(); - std::vector indShape = indices->getShapeAsVector(); - std::vector expectedUpdShape(std::begin(indShape), std::end(indShape) - 1); - if(inRank > indLastDim) - std::move(std::begin(inShape) + indLastDim, std::end(inShape), std::back_inserter(expectedUpdShape)); - REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_ND_SUB OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str()); - - if(checkIndices) { - const Nd4jLong numOfBadIndx = helpers::checkIndices(block.launchContext(), *indices, *output); - REQUIRE_TRUE(numOfBadIndx == 0, 0, "SCATTER_ND_SUB OP: please check elements of indices-array, total number of wrong elements is %lld!", numOfBadIndx); - } - - if (!block.isInplace()) - output->assign(input); - - helpers::scatterND(block.launchContext(), pairwise::Subtract, *indices, *updates, *output, lock); - - return Status::OK(); + auto input = INPUT_VARIABLE(0); + auto indices = INPUT_VARIABLE(1); + auto updates = INPUT_VARIABLE(2); + + auto output = OUTPUT_VARIABLE(0); + + const bool lock = block.getBArguments().empty() ? false : B_ARG(0); + const bool checkIndices = block.numB() <= 1 ? false : B_ARG(1); + + const int inRank = input->rankOf(); + const int indRank = indices->rankOf(); + const int updRank = updates->rankOf(); + + const Nd4jLong indLastDim = indices->sizeAt(-1); + + REQUIRE_TRUE(indLastDim <= inRank, 0, + "SCATTER_ND_SUB OP: the last dimension of indices array must be " + "<= input_array_rank, but got %i instead !", + indLastDim); + REQUIRE_TRUE(updRank == (indRank - 1 + inRank - indLastDim), 0, + "SCATTER_ND_SUB OP: the equality updates_rank = (indices_rank - " + "1 + input_rank - last_indices_dimension) must be true for " + "input arrays, but got instead: updates_rank = %i, indices_rank " + "= %i, last_indices_dimension = %i !", + updRank, indRank, indLastDim); + + std::vector inShape = input->getShapeAsVector(); + std::vector updShape = updates->getShapeAsVector(); + std::vector indShape = indices->getShapeAsVector(); + std::vector expectedUpdShape(std::begin(indShape), + std::end(indShape) - 1); + if (inRank > indLastDim) + std::move(std::begin(inShape) + indLastDim, std::end(inShape), + std::back_inserter(expectedUpdShape)); + REQUIRE_TRUE(expectedUpdShape == updShape, 0, + "SCATTER_ND_SUB OP: wrong shape of updates array, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString(expectedUpdShape).c_str(), + ShapeUtils::shapeAsString(updShape).c_str()); + + if (checkIndices) { + const Nd4jLong numOfBadIndx = + helpers::checkIndices(block.launchContext(), *indices, *output); + REQUIRE_TRUE(numOfBadIndx == 0, 0, + "SCATTER_ND_SUB OP: please check elements of indices-array, " + "total number of wrong elements is %lld!", + numOfBadIndx); + } + + if (!block.isInplace()) output->assign(input); + + helpers::scatterND(block.launchContext(), pairwise::Subtract, *indices, + *updates, *output, lock); + + return Status::OK(); } - DECLARE_TYPES(scatter_nd_sub) { - getOpDescriptor() - ->setAllowedInputTypes(0, {ALL_INTS, ALL_FLOATS}) - ->setAllowedInputTypes(1, {ALL_INTS}) - ->setAllowedInputTypes(2, {ALL_INTS, ALL_FLOATS}) - ->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS}); - } - -} +DECLARE_TYPES(scatter_nd_sub) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_INTS, ALL_FLOATS}) + ->setAllowedInputTypes(1, {ALL_INTS}) + ->setAllowedInputTypes(2, {ALL_INTS, ALL_FLOATS}) + ->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS}); } +} // namespace ops +} // namespace sd + #endif diff --git a/libnd4j/include/ops/declarable/generic/transforms/scatter_nd_update.cpp b/libnd4j/include/ops/declarable/generic/transforms/scatter_nd_update.cpp index b6122c724937..a7d0c8e0a181 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/scatter_nd_update.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/scatter_nd_update.cpp @@ -25,58 +25,75 @@ #include namespace sd { -namespace ops { +namespace ops { OP_IMPL(scatter_nd_update, 3, 1, true) { - auto input = INPUT_VARIABLE(0); - auto indices = INPUT_VARIABLE(1); - auto updates = INPUT_VARIABLE(2); - - auto output = OUTPUT_VARIABLE(0); - - const bool lock = block.getBArguments()->empty() ? true : B_ARG(0); - const bool checkIndices = block.getBArguments()->size() <= 1 ? false : B_ARG(1); - - const int inRank = input->rankOf(); - const int indRank = indices->rankOf(); - const int updRank = updates->rankOf(); - - const Nd4jLong indLastDim = indices->sizeAt(-1); - - REQUIRE_TRUE(indLastDim <= inRank, 0, "SCATTER_ND_UPDATE OP: the last dimension of indices array must be <= input_array_rank, but got %i instead !", indLastDim); - REQUIRE_TRUE(updRank == (indRank - 1 + inRank - indLastDim), 0, "SCATTER_ND_UPDATE OP: the equality updates_rank = (indices_rank - 1 + input_rank - last_indices_dimension) must be true for input arrays, but got instead: updates_rank = %i, indices_rank = %i, last_indices_dimension = %i !", updRank, indRank, indLastDim); - - std::vector inShape = input->getShapeAsVector(); - std::vector updShape = updates->getShapeAsVector(); - std::vector indShape = indices->getShapeAsVector(); - std::vector expectedUpdShape(std::begin(indShape), std::end(indShape) - 1); - if(inRank > indLastDim) - std::move(std::begin(inShape) + indLastDim, std::end(inShape), std::back_inserter(expectedUpdShape)); - REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_ND_UPDATE OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str()); - - if(checkIndices) { - const Nd4jLong numOfBadIndx = helpers::checkIndices(block.launchContext(), *indices, *output); - REQUIRE_TRUE(numOfBadIndx == 0, 0, "SCATTER_ND_UPDATE OP: please check elements of indices-array, total number of wrong elements is %lld!", numOfBadIndx); - } - - if (!block.isInplace()) - output->assign(input); - - helpers::scatterND(block.launchContext(), pairwise::CopyPws, *indices, *updates, *output, lock); - - return Status::OK(); + auto input = INPUT_VARIABLE(0); + auto indices = INPUT_VARIABLE(1); + auto updates = INPUT_VARIABLE(2); + + auto output = OUTPUT_VARIABLE(0); + + const bool lock = block.getBArguments().empty() ? true : B_ARG(0); + const bool checkIndices = block.numB() <= 1 ? false : B_ARG(1); + + const int inRank = input->rankOf(); + const int indRank = indices->rankOf(); + const int updRank = updates->rankOf(); + + const Nd4jLong indLastDim = indices->sizeAt(-1); + + REQUIRE_TRUE(indLastDim <= inRank, 0, + "SCATTER_ND_UPDATE OP: the last dimension of indices array must " + "be <= input_array_rank, but got %i instead !", + indLastDim); + REQUIRE_TRUE(updRank == (indRank - 1 + inRank - indLastDim), 0, + "SCATTER_ND_UPDATE OP: the equality updates_rank = " + "(indices_rank - 1 + input_rank - last_indices_dimension) must " + "be true for input arrays, but got instead: updates_rank = %i, " + "indices_rank = %i, last_indices_dimension = %i !", + updRank, indRank, indLastDim); + + std::vector inShape = input->getShapeAsVector(); + std::vector updShape = updates->getShapeAsVector(); + std::vector indShape = indices->getShapeAsVector(); + std::vector expectedUpdShape(std::begin(indShape), + std::end(indShape) - 1); + if (inRank > indLastDim) + std::move(std::begin(inShape) + indLastDim, std::end(inShape), + std::back_inserter(expectedUpdShape)); + REQUIRE_TRUE(expectedUpdShape == updShape, 0, + "SCATTER_ND_UPDATE OP: wrong shape of updates array, expected " + "is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedUpdShape).c_str(), + ShapeUtils::shapeAsString(updShape).c_str()); + + if (checkIndices) { + const Nd4jLong numOfBadIndx = + helpers::checkIndices(block.launchContext(), *indices, *output); + REQUIRE_TRUE(numOfBadIndx == 0, 0, + "SCATTER_ND_UPDATE OP: please check elements of " + "indices-array, total number of wrong elements is %lld!", + numOfBadIndx); + } + + if (!block.isInplace()) output->assign(input); + + helpers::scatterND(block.launchContext(), pairwise::CopyPws, *indices, + *updates, *output, lock); + + return Status::OK(); } - DECLARE_TYPES(scatter_nd_update) { - getOpDescriptor() - ->setAllowedInputTypes(0, {ALL_INTS, ALL_FLOATS}) - ->setAllowedInputTypes(1, {ALL_INTS}) - ->setAllowedInputTypes(2, {ALL_INTS, ALL_FLOATS}) - ->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS}); - - } - -} +DECLARE_TYPES(scatter_nd_update) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_INTS, ALL_FLOATS}) + ->setAllowedInputTypes(1, {ALL_INTS}) + ->setAllowedInputTypes(2, {ALL_INTS, ALL_FLOATS}) + ->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS}); } +} // namespace ops +} // namespace sd + #endif diff --git a/libnd4j/include/ops/declarable/generic/transforms/scatter_sub.cpp b/libnd4j/include/ops/declarable/generic/transforms/scatter_sub.cpp index c955ac04221a..7cf7e3612dca 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/scatter_sub.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/scatter_sub.cpp @@ -26,74 +26,93 @@ #include namespace sd { - namespace ops { - OP_IMPL(scatter_sub, 3, 1, true) { - auto input = INPUT_VARIABLE(0); - auto indices = INPUT_VARIABLE(1); - auto updates = INPUT_VARIABLE(2); - - auto output = OUTPUT_VARIABLE(0); - - if (!block.isInplace()) - output->assign(input); - - const bool lock = block.getBArguments()->empty() ? false : B_ARG(0); - const bool checkIndices = block.getBArguments()->size() <= 1 ? false : B_ARG(1); - - const int inRank = input->rankOf(); - const int indRank = indices->rankOf(); - const int updRank = updates->rankOf(); - - REQUIRE_TRUE(inRank > 0, 0, "SCATTER_SUB OP: input should not be scalar !"); - - if(inRank == 1) { - REQUIRE_TRUE(indices->isSameShape(updates), 0, "SCATTER_SUB OP: when input array has rank = 1 then indices and updates must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(indices).c_str(), ShapeUtils::shapeAsString(updates).c_str()); - } - else if (inRank == updRank && indices->isVector()) { - - std::vector updShape = updates->getShapeAsVector(); - std::vector inShape = input->getShapeAsVector(); - std::vector expectedUpdShape = {indices->lengthOf()}; - expectedUpdShape.insert(expectedUpdShape.end(), inShape.begin()+1, inShape.end()); - - REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_SUB OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str()); - } - - else { - - REQUIRE_TRUE(updRank == indRank + inRank - 1, 0, "SCATTER_SUB OP: wrong rank of updates array, expected is %i, but got %i instead !", indRank + inRank - 1 , updRank); - - std::vector updShape = updates->getShapeAsVector(); - std::vector inShape = input->getShapeAsVector(); - std::vector expectedUpdShape = indices->getShapeAsVector(); - expectedUpdShape.insert(expectedUpdShape.end(), inShape.begin()+1, inShape.end()); - - REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_SUB OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str()); - } - - if (!indices->isEmpty()) { - - if(checkIndices) { - const Nd4jLong numOfBadIndx = helpers::checkIndices(block.launchContext(), *indices, *output, 0); - REQUIRE_TRUE(numOfBadIndx == 0, 0, "SCATTER_SUB OP: please check elements of indices-array, total number of wrong elements is %lld!", numOfBadIndx); - } - - // ScatterHelper::template scatterApply>(output, indices, updates); - helpers::scatter(block.launchContext(), pairwise::Subtract, *indices, *updates, *output, lock); - } +namespace ops { +OP_IMPL(scatter_sub, 3, 1, true) { + auto input = INPUT_VARIABLE(0); + auto indices = INPUT_VARIABLE(1); + auto updates = INPUT_VARIABLE(2); + + auto output = OUTPUT_VARIABLE(0); + + if (!block.isInplace()) output->assign(input); + + const bool lock = block.getBArguments().empty() ? false : B_ARG(0); + const bool checkIndices = block.numB() <= 1 ? false : B_ARG(1); + + const int inRank = input->rankOf(); + const int indRank = indices->rankOf(); + const int updRank = updates->rankOf(); + + REQUIRE_TRUE(inRank > 0, 0, "SCATTER_SUB OP: input should not be scalar !"); + + if (inRank == 1) { + REQUIRE_TRUE(indices->isSameShape(updates), 0, + "SCATTER_SUB OP: when input array has rank = 1 then indices " + "and updates must have the same shapes, but got %s and %s " + "correspondingly !", + ShapeUtils::shapeAsString(indices).c_str(), + ShapeUtils::shapeAsString(updates).c_str()); + } else if (inRank == updRank && indices->isVector()) { + std::vector updShape = updates->getShapeAsVector(); + std::vector inShape = input->getShapeAsVector(); + std::vector expectedUpdShape = {indices->lengthOf()}; + expectedUpdShape.insert(expectedUpdShape.end(), inShape.begin() + 1, + inShape.end()); + + REQUIRE_TRUE(expectedUpdShape == updShape, 0, + "SCATTER_SUB OP: wrong shape of updates array, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString(expectedUpdShape).c_str(), + ShapeUtils::shapeAsString(updShape).c_str()); + } + + else { + REQUIRE_TRUE(updRank == indRank + inRank - 1, 0, + "SCATTER_SUB OP: wrong rank of updates array, expected is %i, " + "but got %i instead !", + indRank + inRank - 1, updRank); + + std::vector updShape = updates->getShapeAsVector(); + std::vector inShape = input->getShapeAsVector(); + std::vector expectedUpdShape = indices->getShapeAsVector(); + expectedUpdShape.insert(expectedUpdShape.end(), inShape.begin() + 1, + inShape.end()); + + REQUIRE_TRUE(expectedUpdShape == updShape, 0, + "SCATTER_SUB OP: wrong shape of updates array, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString(expectedUpdShape).c_str(), + ShapeUtils::shapeAsString(updShape).c_str()); + } + + if (!indices->isEmpty()) { + if (checkIndices) { + const Nd4jLong numOfBadIndx = + helpers::checkIndices(block.launchContext(), *indices, *output, 0); + REQUIRE_TRUE(numOfBadIndx == 0, 0, + "SCATTER_SUB OP: please check elements of indices-array, " + "total number of wrong elements is %lld!", + numOfBadIndx); + } - return Status::OK(); - } - DECLARE_SYN(ScatterSub, scatter_sub); + // ScatterHelper::template scatterApply>(output, + // indices, updates); + helpers::scatter(block.launchContext(), pairwise::Subtract, *indices, + *updates, *output, lock); + } - DECLARE_TYPES(scatter_sub) { - getOpDescriptor() - ->setAllowedInputTypes(0, {ALL_INTS, ALL_FLOATS}) - ->setAllowedInputTypes(1, {ALL_INTS}) - ->setAllowedInputTypes(2, {ALL_INTS, ALL_FLOATS}) - ->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS}); - } - } + return Status::OK(); +} +DECLARE_SYN(ScatterSub, scatter_sub); + +DECLARE_TYPES(scatter_sub) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_INTS, ALL_FLOATS}) + ->setAllowedInputTypes(1, {ALL_INTS}) + ->setAllowedInputTypes(2, {ALL_INTS, ALL_FLOATS}) + ->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS}); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/transforms/scatter_upd.cpp b/libnd4j/include/ops/declarable/generic/transforms/scatter_upd.cpp index ef54b98138e0..b4e318cdc043 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/scatter_upd.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/scatter_upd.cpp @@ -25,74 +25,91 @@ #include namespace sd { - namespace ops { - OP_IMPL(scatter_upd, 3, 1, true) { - auto input = INPUT_VARIABLE(0); - auto indices = INPUT_VARIABLE(1); - auto updates = INPUT_VARIABLE(2); - - auto output = OUTPUT_VARIABLE(0); - - if (!block.isInplace()) - output->assign(input); - - const bool lock = block.getBArguments()->empty() ? true : B_ARG(0); - const bool checkIndices = block.getBArguments()->size() <= 1 ? false : B_ARG(1); - - const int inRank = input->rankOf(); - const int indRank = indices->rankOf(); - const int updRank = updates->rankOf(); - - REQUIRE_TRUE(inRank > 0, 0, "SCATTER_UPD OP: input should not be scalar !"); - - if(inRank == 1) { - REQUIRE_TRUE(indices->isSameShape(updates), 0, "SCATTER_UPD OP: when input array has rank = 1 then indices and updates must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(indices).c_str(), ShapeUtils::shapeAsString(updates).c_str()); - } - else if (inRank == updRank && indices->isVector()) { - - std::vector updShape = updates->getShapeAsVector(); - std::vector inShape = input->getShapeAsVector(); - std::vector expectedUpdShape = {indices->lengthOf()}; - expectedUpdShape.insert(expectedUpdShape.end(), inShape.begin()+1, inShape.end()); - - REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_UPD OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str()); - } - else { - - REQUIRE_TRUE(updRank == indRank + inRank - 1, 0, "SCATTER_UPD OP: wrong rank of updates array, expected is %i, but got %i instead !", indRank + inRank - 1 , updRank); - - std::vector updShape = updates->getShapeAsVector(); - std::vector inShape = input->getShapeAsVector(); - std::vector expectedUpdShape = indices->getShapeAsVector(); - expectedUpdShape.insert(expectedUpdShape.end(), inShape.begin()+1, inShape.end()); - - REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_UPD OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str()); - } - - if (!indices->isEmpty()) { - - if(checkIndices) { - const Nd4jLong numOfBadIndx = helpers::checkIndices(block.launchContext(), *indices, *output, 0); - REQUIRE_TRUE(numOfBadIndx == 0, 0, "SCATTER_UPD OP: please check elements of indices-array, total number of wrong elements is %lld!", numOfBadIndx); - } - - // ScatterHelper::template scatterApply>(output, indices, updates); - helpers::scatter(block.launchContext(), pairwise::CopyPws, *indices, *updates, *output, lock); - } - - return Status::OK(); - } - DECLARE_SYN(ScatterUpdate, scatter_upd); +namespace ops { +OP_IMPL(scatter_upd, 3, 1, true) { + auto input = INPUT_VARIABLE(0); + auto indices = INPUT_VARIABLE(1); + auto updates = INPUT_VARIABLE(2); + + auto output = OUTPUT_VARIABLE(0); + + if (!block.isInplace()) output->assign(input); + + const bool lock = block.getBArguments().empty() ? true : B_ARG(0); + const bool checkIndices = block.numB() <= 1 ? false : B_ARG(1); + + const int inRank = input->rankOf(); + const int indRank = indices->rankOf(); + const int updRank = updates->rankOf(); + + REQUIRE_TRUE(inRank > 0, 0, "SCATTER_UPD OP: input should not be scalar !"); + + if (inRank == 1) { + REQUIRE_TRUE(indices->isSameShape(updates), 0, + "SCATTER_UPD OP: when input array has rank = 1 then indices " + "and updates must have the same shapes, but got %s and %s " + "correspondingly !", + ShapeUtils::shapeAsString(indices).c_str(), + ShapeUtils::shapeAsString(updates).c_str()); + } else if (inRank == updRank && indices->isVector()) { + std::vector updShape = updates->getShapeAsVector(); + std::vector inShape = input->getShapeAsVector(); + std::vector expectedUpdShape = {indices->lengthOf()}; + expectedUpdShape.insert(expectedUpdShape.end(), inShape.begin() + 1, + inShape.end()); + + REQUIRE_TRUE(expectedUpdShape == updShape, 0, + "SCATTER_UPD OP: wrong shape of updates array, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString(expectedUpdShape).c_str(), + ShapeUtils::shapeAsString(updShape).c_str()); + } else { + REQUIRE_TRUE(updRank == indRank + inRank - 1, 0, + "SCATTER_UPD OP: wrong rank of updates array, expected is %i, " + "but got %i instead !", + indRank + inRank - 1, updRank); + + std::vector updShape = updates->getShapeAsVector(); + std::vector inShape = input->getShapeAsVector(); + std::vector expectedUpdShape = indices->getShapeAsVector(); + expectedUpdShape.insert(expectedUpdShape.end(), inShape.begin() + 1, + inShape.end()); + + REQUIRE_TRUE(expectedUpdShape == updShape, 0, + "SCATTER_UPD OP: wrong shape of updates array, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString(expectedUpdShape).c_str(), + ShapeUtils::shapeAsString(updShape).c_str()); + } + + if (!indices->isEmpty()) { + if (checkIndices) { + const Nd4jLong numOfBadIndx = + helpers::checkIndices(block.launchContext(), *indices, *output, 0); + REQUIRE_TRUE(numOfBadIndx == 0, 0, + "SCATTER_UPD OP: please check elements of indices-array, " + "total number of wrong elements is %lld!", + numOfBadIndx); + } + // ScatterHelper::template scatterApply>(output, + // indices, updates); + helpers::scatter(block.launchContext(), pairwise::CopyPws, *indices, + *updates, *output, lock); + } - DECLARE_TYPES(scatter_upd) { - getOpDescriptor() - ->setAllowedInputTypes(0, {ALL_INTS, ALL_FLOATS}) - ->setAllowedInputTypes(1, {ALL_INTS}) - ->setAllowedInputTypes(2, {ALL_INTS, ALL_FLOATS}) - ->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS}); - } - } + return Status::OK(); +} +DECLARE_SYN(ScatterUpdate, scatter_upd); + +DECLARE_TYPES(scatter_upd) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_INTS, ALL_FLOATS}) + ->setAllowedInputTypes(1, {ALL_INTS}) + ->setAllowedInputTypes(2, {ALL_INTS, ALL_FLOATS}) + ->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS}); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/transforms/scatter_update.cpp b/libnd4j/include/ops/declarable/generic/transforms/scatter_update.cpp index d15b4c85949d..1d7b9ec3a110 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/scatter_update.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/scatter_update.cpp @@ -22,39 +22,36 @@ #if NOT_EXCLUDED(OP_scatter_update) #include -#include +#include namespace sd { - namespace ops { - /** - * scatter update operation - * - * IArgs map: - * IArgs[0] - update operation: 0 - add; 1 - sub; 2 - mul; 3 - div; 4 - rsub; 5 - rdiv; 6 - assign - * IArgs[1] - number of dimensions - * IArgs[...] - dimensions - * IArgs[...] - number of indices - * IArgs[...] - indices - * - * @tparam T - */ - CONFIGURABLE_OP_IMPL(scatter_update, 2, 1, true, 0, -1) { - - auto operand = INPUT_VARIABLE(0); - auto updates = INPUT_VARIABLE(1); - - helpers::scatterUpdate(block.launchContext(), *operand, *updates, block.getIArguments()); - - return Status::OK(); - } - DECLARE_SYN(scatterupdate, scatter_update); - - DECLARE_TYPES(scatter_update) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setSameMode(true); - } - } +namespace ops { +/** + * scatter update operation + * + * IArgs map: + * IArgs[0] - update operation: 0 - add; 1 - sub; 2 - mul; 3 - div; 4 - rsub; 5 + * - rdiv; 6 - assign IArgs[1] - number of dimensions IArgs[...] - dimensions + * IArgs[...] - number of indices + * IArgs[...] - indices + * + * @tparam T + */ +CONFIGURABLE_OP_IMPL(scatter_update, 2, 1, true, 0, -1) { + auto operand = INPUT_VARIABLE(0); + auto updates = INPUT_VARIABLE(1); + + helpers::scatterUpdate(block.launchContext(), *operand, *updates, + &block.getIArguments()); + + return Status::OK(); +} +DECLARE_SYN(scatterupdate, scatter_update); + +DECLARE_TYPES(scatter_update) { + getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setSameMode(true); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/transforms/slice.cpp b/libnd4j/include/ops/declarable/generic/transforms/slice.cpp index 822f48681b33..55087b271198 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/slice.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/slice.cpp @@ -21,210 +21,258 @@ #include //#if NOT_EXCLUDED(OP_slice) -#include #include +#include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(slice, 1, 1, false, 0, -2) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - int x_rank = input->rankOf(); - - std::vector begin; - std::vector sz; - - if (block.width() == 3) { - auto b = INPUT_VARIABLE(1); - auto e = INPUT_VARIABLE(2); - - begin = b->template asVectorT(); - sz = e->template asVectorT(); - } else { - REQUIRE_TRUE(block.numI() >= x_rank * 2, 0, "Number of IArgs should be equal to [%i] but got [%i] instead", x_rank * 2, block.numI()); - - ShapeUtils::copyVectorPart(begin, *(block.getIArguments()), x_rank, 0); - ShapeUtils::copyVectorPart(sz, *(block.getIArguments()), x_rank, x_rank); - } - - REQUIRE_TRUE(begin.size() == x_rank, 0, "begin array should have length of [%i] but got [%i] instead", x_rank, begin.size()); - REQUIRE_TRUE(sz.size() == x_rank, 0, "size array should have length of [%i] but got [%i] instead", x_rank, sz.size()); - - std::vector indices(2 * x_rank); - auto empty = false; - for (int e = 0; e < x_rank; e++) { - int size = sz[e]; - int start = begin[e]; - - REQUIRE_TRUE(start >= 0, 0, "Slice: start index should not be negative"); - - REQUIRE_TRUE(start <= input->sizeAt(e), 0, "Index %i is invalid for dimension %i with size %i.", start, e, input->shapeInfo()[e + 1]); - if (size == -1){ - size = input->sizeAt(e) - start; - } - REQUIRE_TRUE(size >= 0, 0, "Slice: interval for dimension %i is less then 1"); - REQUIRE_TRUE(start + size <= input->sizeAt(e), 0, "Slice: interval [%i, %i] is out of bounds for dimension %i with size %i", start, start + size, e, input->sizeAt(e)); - - if(start == input->sizeAt(e) || size == 0 ){ - empty = true; - //Don't break to perform input validation on other dims - } - - indices[2*e] = start; - indices[2*e+1] = start + size; - } - - if(empty){ - REQUIRE_TRUE(output->isEmpty(), 0, "Slice: empty array indices requested, but output array is not empty"); - return Status::OK(); - } +namespace ops { +CUSTOM_OP_IMPL(slice, 1, 1, false, 0, -2) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + int x_rank = input->rankOf(); + + std::vector begin; + std::vector sz; + + if (block.width() == 3) { + auto b = INPUT_VARIABLE(1); + auto e = INPUT_VARIABLE(2); + + begin = b->template asVectorT(); + sz = e->template asVectorT(); + } else { + REQUIRE_TRUE(block.numI() >= x_rank * 2, 0, + "Number of IArgs should be equal to [%i] but got [%i] instead", + x_rank * 2, block.numI()); + + auto vec = block.getIArguments(); + + ShapeUtils::copyVectorPart(begin, vec, x_rank, 0); + ShapeUtils::copyVectorPart(sz, vec, x_rank, x_rank); + } + + REQUIRE_TRUE(begin.size() == x_rank, 0, + "begin array should have length of [%i] but got [%i] instead", + x_rank, begin.size()); + REQUIRE_TRUE(sz.size() == x_rank, 0, + "size array should have length of [%i] but got [%i] instead", + x_rank, sz.size()); + + std::vector indices(2 * x_rank); + auto empty = false; + for (int e = 0; e < x_rank; e++) { + int size = sz[e]; + int start = begin[e]; + + REQUIRE_TRUE(start >= 0, 0, "Slice: start index should not be negative"); + + REQUIRE_TRUE(start <= input->sizeAt(e), 0, + "Index %i is invalid for dimension %i with size %i.", start, e, + input->shapeInfo()[e + 1]); + if (size == -1) { + size = input->sizeAt(e) - start; + } + REQUIRE_TRUE(size >= 0, 0, + "Slice: interval for dimension %i is less then 1"); + REQUIRE_TRUE(start + size <= input->sizeAt(e), 0, + "Slice: interval [%i, %i] is out of bounds for dimension %i " + "with size %i", + start, start + size, e, input->sizeAt(e)); + + if (start == input->sizeAt(e) || size == 0) { + empty = true; + // Don't break to perform input validation on other dims + } - Nd4jLong* subArrShapeInfo = nullptr; - ALLOCATE(subArrShapeInfo, block.getWorkspace(), shape::shapeInfoLength(input->rankOf()), Nd4jLong); + indices[2 * e] = start; + indices[2 * e + 1] = start + size; + } - Nd4jLong offset; + if (empty) { + REQUIRE_TRUE( + output->isEmpty(), 0, + "Slice: empty array indices requested, but output array is not empty"); + return Status::OK(); + } - shape::calcSubArrShapeInfoAndOffset(indices.data(), input->shapeInfo(), subArrShapeInfo, offset, true); + Nd4jLong *subArrShapeInfo = nullptr; + ALLOCATE(subArrShapeInfo, block.workspace(), + shape::shapeInfoLength(input->rankOf()), Nd4jLong); - auto subArrShapeInfoPack = ConstantShapeHelper::getInstance().bufferForShapeInfo(subArrShapeInfo); + Nd4jLong offset; - NDArray::prepareSpecialUse({output}, {input}); + shape::calcSubArrShapeInfoAndOffset(indices.data(), input->shapeInfo(), + subArrShapeInfo, offset, true); - NativeOpExecutioner::execTransformAny(block.launchContext(), sd::transform::Assign, - input->bufferWithOffset(offset), subArrShapeInfoPack.primary(), - input->specialBufferWithOffset(offset), subArrShapeInfoPack.special(), - output->buffer(), output->shapeInfo(), output->specialBuffer(), output->specialShapeInfo(), - nullptr, nullptr, nullptr, true); + auto subArrShapeInfoPack = + ConstantShapeHelper::getInstance().bufferForShapeInfo(subArrShapeInfo); - NDArray::registerSpecialUse({output}, {input}); + NDArray::prepareSpecialUse({output}, {input}); - RELEASE(subArrShapeInfo, block.getWorkspace()); + NativeOpExecutioner::execTransformAny( + block.launchContext(), sd::transform::Assign, + input->bufferWithOffset(offset), + subArrShapeInfoPack.primary(), + input->specialBufferWithOffset(offset), + subArrShapeInfoPack.special(), + output->buffer(), output->shapeInfo(), output->specialBuffer(), + output->specialShapeInfo(), nullptr, nullptr, nullptr, true); - // auto sub = (*input)(indices, true); - // output->assign(sub); + NDArray::registerSpecialUse({output}, {input}); - STORE_RESULT(output); + RELEASE(subArrShapeInfo, block.workspace()); - return Status::OK(); - } + // auto sub = (*input)(indices, true); + // output->assign(sub); - DECLARE_TYPES(slice) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setSameMode(true); - } + STORE_RESULT(output); - DECLARE_SHAPE_FN(slice) { - auto inShape = inputShape->at(0); - auto x_rank = shape::rank(inShape); + return Status::OK(); +} - std::vector begin; - std::vector sz; +DECLARE_TYPES(slice) { + getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setSameMode(true); +} - if (block.width() == 3) { - auto b = INPUT_VARIABLE(1); - auto e = INPUT_VARIABLE(2); +DECLARE_SHAPE_FN(slice) { + auto inShape = inputShape->at(0); + auto x_rank = shape::rank(inShape); + + std::vector begin; + std::vector sz; + + if (block.width() == 3) { + auto b = INPUT_VARIABLE(1); + auto e = INPUT_VARIABLE(2); + + begin = b->template asVectorT(); + sz = e->template asVectorT(); + } else { + REQUIRE_TRUE(block.numI() >= x_rank * 2, 0, + "Number of IArgs should be equal to [%i] but got [%i] instead", + x_rank * 2, block.numI()); + + auto vec = block.getIArguments(); + ShapeUtils::copyVectorPart(begin, vec, x_rank, 0); + ShapeUtils::copyVectorPart(sz, vec, x_rank, x_rank); + } + + REQUIRE_TRUE(begin.size() == x_rank, 0, + "Begin array should have length of [%i] but got [%i] instead", + x_rank, begin.size()); + REQUIRE_TRUE(sz.size() == x_rank, 0, + "Size array should have length of [%i] but got [%i] instead", + x_rank, sz.size()); + + std::vector shape; + auto empty = false; + for (int e = 0; e < x_rank; e++) { + auto size = sz[e]; + auto start = begin[e]; + + if (size == -1) { + size = inShape[e + 1] - start; + } - begin = b->template asVectorT(); - sz = e->template asVectorT(); - } else { - REQUIRE_TRUE(block.numI() >= x_rank * 2, 0, "Number of IArgs should be equal to [%i] but got [%i] instead", x_rank * 2, block.numI()); + // Bounds checking. Note that begin[i] == size[i] means empty array + REQUIRE_TRUE(start >= 0 && start <= inShape[e + 1], 0, + "Invalid begin[%i] value: Begin must satisfy 0 <= begin <= " + "size[i], got begin=%i for dimension size %i", + e, start, inShape[e + 1]); + REQUIRE_TRUE(size == -1 || size >= 0, 0, + "Invalid size[%i] value: must be positive (or -1 for 'all " + "remaining'), got %i", + e, size, inShape[e + 1]); + REQUIRE_TRUE(start >= 0 && start <= inShape[e + 1], 0, + "Invalid begin[%i] value: Begin must satisfy 0 <= begin <= " + "size[i], got begin=%i for dimension size %i", + e, start, inShape[e + 1]); + REQUIRE_TRUE(start + size <= inShape[e + 1], 0, + "Slice: interval [%i, %i] is out of bounds for dimension %i " + "with size %i", + start, start + size, e, inShape[e + 1]); + if (start == inShape[e + 1]) { + size = 0; + } - ShapeUtils::copyVectorPart(begin, *(block.getIArguments()), x_rank, 0); - ShapeUtils::copyVectorPart(sz, *(block.getIArguments()), x_rank, x_rank); - } + shape.emplace_back(size); + } - REQUIRE_TRUE(begin.size() == x_rank, 0, "Begin array should have length of [%i] but got [%i] instead", x_rank, begin.size()); - REQUIRE_TRUE(sz.size() == x_rank, 0, "Size array should have length of [%i] but got [%i] instead", x_rank, sz.size()); + auto newShape = ConstantShapeHelper::getInstance().createShapeInfo( + ArrayOptions::dataType(inShape), 'c', shape); + return SHAPELIST(newShape); +} - std::vector shape; - auto empty = false; - for (int e = 0; e < x_rank; e++) { - auto size = sz[e]; - auto start = begin[e]; +DECLARE_TYPES(slice_bp) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} - if(size == -1){ - size = inShape[e+1] - start; - } +CUSTOM_OP_IMPL(slice_bp, 2, 1, false, 0, -2) { + auto input = INPUT_VARIABLE(0); + auto epsNext = block.width() == 4 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(1); - //Bounds checking. Note that begin[i] == size[i] means empty array - REQUIRE_TRUE(start >= 0 && start <= inShape[e+1], 0, "Invalid begin[%i] value: Begin must satisfy 0 <= begin <= size[i], got begin=%i for dimension size %i", e, start, inShape[e+1]); - REQUIRE_TRUE(size == -1 || size >= 0, 0, "Invalid size[%i] value: must be positive (or -1 for 'all remaining'), got %i", e, size, inShape[e+1]); - REQUIRE_TRUE(start >= 0 && start <= inShape[e+1], 0, "Invalid begin[%i] value: Begin must satisfy 0 <= begin <= size[i], got begin=%i for dimension size %i", e, start, inShape[e+1]); - REQUIRE_TRUE(start + size <= inShape[e+1], 0, "Slice: interval [%i, %i] is out of bounds for dimension %i with size %i", start, start + size, e, inShape[e+1]); - if(start == inShape[e+1] ){ - size = 0; - } + auto output = OUTPUT_VARIABLE(0); + output->assign(0.); + int x_rank = input->rankOf(); - shape.emplace_back(size); - } + std::vector begin; + std::vector end; - auto newShape = ConstantShapeHelper::getInstance().createShapeInfo(ArrayOptions::dataType(inShape), 'c', shape); - return SHAPELIST(newShape); - } + if (block.width() == 4) { + auto b = INPUT_VARIABLE(1); + auto e = INPUT_VARIABLE(2); - DECLARE_TYPES(slice_bp) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } + begin = b->template asVectorT(); + end = e->template asVectorT(); + } else { + REQUIRE_TRUE(block.numI() >= x_rank * 2, 0, + "Number of IArgs should be equal to [%i] but got [%i] instead", + x_rank * 2, block.numI()); + auto vec = block.getIArguments(); - CUSTOM_OP_IMPL(slice_bp, 2, 1, false, 0, -2) { - auto input = INPUT_VARIABLE(0); - auto epsNext = block.width() == 4 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(1); + ShapeUtils::copyVectorPart(begin, vec, x_rank, 0); + ShapeUtils::copyVectorPart(end, vec, x_rank, x_rank); + } - auto output = OUTPUT_VARIABLE(0); - output->assign(0.); - int x_rank = input->rankOf(); + REQUIRE_TRUE(begin.size() == x_rank, 0, + "begin array should have length of [%i] but got [%i] instead", + x_rank, begin.size()); + REQUIRE_TRUE(end.size() == x_rank, 0, + "end array should have length of [%i] but got [%i] instead", + x_rank, end.size()); - std::vector begin; - std::vector end; + std::vector indices(2 * x_rank); + for (int e = 0; e < x_rank; e++) { + int size = end[e]; + int start = begin[e]; - if (block.width() == 4) { - auto b = INPUT_VARIABLE(1); - auto e = INPUT_VARIABLE(2); + if (size == -1) { //-1 means all remaining values + size = input->sizeAt(e) - start; + } + REQUIRE_TRUE(size > 0, 0, "Slice: interval for dimension %i is less then 1", + e); - begin = b->template asVectorT(); - end = e->template asVectorT(); - } else { - REQUIRE_TRUE(block.numI() >= x_rank * 2, 0, "Number of IArgs should be equal to [%i] but got [%i] instead", x_rank * 2, block.numI()); + indices[2 * e] = start; + indices[2 * e + 1] = start + size; + } + auto sub = (*output)(indices, true); + sub.assign(epsNext); - ShapeUtils::copyVectorPart(begin, *(block.getIArguments()), x_rank, 0); - ShapeUtils::copyVectorPart(end, *(block.getIArguments()), x_rank, x_rank); - } + return Status::OK(); +} - REQUIRE_TRUE(begin.size() == x_rank, 0, "begin array should have length of [%i] but got [%i] instead", x_rank, begin.size()); - REQUIRE_TRUE(end.size() == x_rank, 0, "end array should have length of [%i] but got [%i] instead", x_rank, end.size()); +DECLARE_SHAPE_FN(slice_bp) { + auto inShape = inputShape->at(0); + Nd4jLong *newShape; + COPY_SHAPE(inShape, newShape); - std::vector indices(2 * x_rank); - for (int e = 0; e < x_rank; e++) { - int size = end[e]; - int start = begin[e]; - - if (size == -1){ //-1 means all remaining values - size = input->sizeAt(e) - start; - } - REQUIRE_TRUE(size > 0, 0, "Slice: interval for dimension %i is less then 1", e); - - indices[2*e] = start; - indices[2*e + 1] = start + size; - } - auto sub = (*output)(indices, true); - sub.assign(epsNext); - - return Status::OK(); - } - - DECLARE_SHAPE_FN(slice_bp) { - auto inShape = inputShape->at(0); - Nd4jLong *newShape; - COPY_SHAPE(inShape, newShape); - - return SHAPELIST(CONSTANT(newShape)); - } - } + return SHAPELIST(CONSTANT(newShape)); } +} // namespace ops +} // namespace sd //#endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/transforms/space_to_batch.cpp b/libnd4j/include/ops/declarable/generic/transforms/space_to_batch.cpp index ffffb5396944..79f4f9c4704b 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/space_to_batch.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/space_to_batch.cpp @@ -26,75 +26,108 @@ limitations under the License. namespace sd { namespace ops { - CUSTOM_OP_IMPL(space_to_batch, 2, 1, false, 0, 1) { - - // [bS, iH, iW, iC] is rearranged/permuted to [bS*blockSize*blockSize, (iH + padBottom + padTop)/blockSize, (iW + padLeft + padRight)/blockSize, iC] - - auto input = INPUT_VARIABLE(0); - auto padding = INPUT_VARIABLE(1); - - auto output = OUTPUT_VARIABLE(0); - - - const uint blockSize = INT_ARG(0); - REQUIRE_TRUE(blockSize >= 2, 0, "SpaceToBatch: integer parameter block_size must be >= 2, but got %i instead", blockSize); - - REQUIRE_TRUE(input->rankOf() == 4, 0, "SpaceToBatch: rank of input array must be equal 4, but got %i instead", input->rankOf()); - REQUIRE_TRUE(output->rankOf() == 4, 0, "SpaceToBatch: rank of output array must be equal 4, but got %i instead", output->rankOf()); - - if(padding->sizeAt(0) != 2 || padding->sizeAt(1) != 2) - REQUIRE_TRUE(false, 0, "SpaceToBatch: operation expects padding shape to be {2, 2}, but got %s instead", ShapeUtils::shapeAsString(padding).c_str()); - - const uint padBottom = padding->e(0,0); - const uint padTop = padding->e(0,1); - const uint padLeft = padding->e(1,0); - const uint padRight = padding->e(1,1); - - REQUIRE_TRUE((input->sizeAt(1) + padBottom + padTop) % blockSize == 0 && (input->sizeAt(2) + padLeft + padRight) % blockSize == 0, 0, "SpaceToBatch: after padding, second and third dimensions of input array must be divisible by blockSize !"); - - if (shape::strideDescendingCAscendingF(input->shapeInfo())) - helpers::spaceToBatch(block.launchContext(), *input, *output, padBottom, padTop, padLeft, padRight, blockSize); - else - helpers::spaceToBatch(block.launchContext(), input->dup(), *output, padBottom, padTop, padLeft, padRight, blockSize); - - return Status::OK(); + // [bS, iH, iW, iC] is rearranged/permuted to [bS*blockSize*blockSize, (iH + + // padBottom + padTop)/blockSize, (iW + padLeft + padRight)/blockSize, iC] + + auto input = INPUT_VARIABLE(0); + auto padding = INPUT_VARIABLE(1); + + auto output = OUTPUT_VARIABLE(0); + + const uint blockSize = INT_ARG(0); + REQUIRE_TRUE(blockSize >= 2, 0, + "SpaceToBatch: integer parameter block_size must be >= 2, but " + "got %i instead", + blockSize); + + REQUIRE_TRUE( + input->rankOf() == 4, 0, + "SpaceToBatch: rank of input array must be equal 4, but got %i instead", + input->rankOf()); + REQUIRE_TRUE( + output->rankOf() == 4, 0, + "SpaceToBatch: rank of output array must be equal 4, but got %i instead", + output->rankOf()); + + if (padding->sizeAt(0) != 2 || padding->sizeAt(1) != 2) + REQUIRE_TRUE(false, 0, + "SpaceToBatch: operation expects padding shape to be {2, 2}, " + "but got %s instead", + ShapeUtils::shapeAsString(padding).c_str()); + + const uint padBottom = padding->e(0, 0); + const uint padTop = padding->e(0, 1); + const uint padLeft = padding->e(1, 0); + const uint padRight = padding->e(1, 1); + + REQUIRE_TRUE((input->sizeAt(1) + padBottom + padTop) % blockSize == 0 && + (input->sizeAt(2) + padLeft + padRight) % blockSize == 0, + 0, + "SpaceToBatch: after padding, second and third dimensions of " + "input array must be divisible by blockSize !"); + + if (shape::strideDescendingCAscendingF(input->shapeInfo())) + helpers::spaceToBatch(block.launchContext(), *input, *output, padBottom, + padTop, padLeft, padRight, blockSize); + else + helpers::spaceToBatch(block.launchContext(), input->dup(), *output, + padBottom, padTop, padLeft, padRight, blockSize); + + return Status::OK(); } //////////////////////////////////////////////////////////////////////////////// DECLARE_TYPES(space_to_batch) { - - getOpDescriptor()->setAllowedInputTypes(0, sd::DataType::ANY) - ->setAllowedInputTypes(1, {ALL_INTS}) - ->setSameMode(true); + getOpDescriptor() + ->setAllowedInputTypes(0, sd::DataType::ANY) + ->setAllowedInputTypes(1, {ALL_INTS}) + ->setSameMode(true); } //////////////////////////////////////////////////////////////////////////////// DECLARE_SHAPE_FN(space_to_batch) { - - auto inputShapeInfo = inputShape->at(0); - auto paddingShapeInfo = inputShape->at(1); - - const uint blockSize = INT_ARG(0); - REQUIRE_TRUE(blockSize >= 2, 0, "SpaceToBatch: integer parameter block_size must be >= 2, but got %i instead", blockSize); - - const int rank = inputShapeInfo[0]; - REQUIRE_TRUE(rank == 4, 0, "SpaceToBatch: rank of input array must be equal 4, but got %i instead", rank); - - if(paddingShapeInfo[1] != 2 || paddingShapeInfo[1] != 2) - REQUIRE_TRUE(false, 0, "SpaceToBatch: operation expects padding shape to be {2, 2}, but got %s instead", ShapeUtils::shapeAsString(paddingShapeInfo).c_str()); - - const uint padBottom = INPUT_VARIABLE(1)->e(0,0); - const uint padTop = INPUT_VARIABLE(1)->e(0,1); - const uint padLeft = INPUT_VARIABLE(1)->e(1,0); - const uint padRight = INPUT_VARIABLE(1)->e(1,1); - - REQUIRE_TRUE((inputShapeInfo[2] + padBottom + padTop) % blockSize == 0 && (inputShapeInfo[3] + padLeft + padRight) % blockSize == 0, 0, "SpaceToBatch: after padding, second and third dimensions of input array must be divisible by blockSize !"); - - return SHAPELIST(ConstantShapeHelper::getInstance().createShapeInfo(ArrayOptions::dataType(inputShapeInfo), 'c', {inputShapeInfo[1] * blockSize * blockSize, (inputShapeInfo[2] + padBottom + padTop) / blockSize, (inputShapeInfo[3] + padLeft + padRight) / blockSize, inputShapeInfo[4]})); + auto inputShapeInfo = inputShape->at(0); + auto paddingShapeInfo = inputShape->at(1); + + const uint blockSize = INT_ARG(0); + REQUIRE_TRUE(blockSize >= 2, 0, + "SpaceToBatch: integer parameter block_size must be >= 2, but " + "got %i instead", + blockSize); + + const int rank = inputShapeInfo[0]; + REQUIRE_TRUE( + rank == 4, 0, + "SpaceToBatch: rank of input array must be equal 4, but got %i instead", + rank); + + if (paddingShapeInfo[1] != 2 || paddingShapeInfo[1] != 2) + REQUIRE_TRUE(false, 0, + "SpaceToBatch: operation expects padding shape to be {2, 2}, " + "but got %s instead", + ShapeUtils::shapeAsString(paddingShapeInfo).c_str()); + + const uint padBottom = INPUT_VARIABLE(1)->e(0, 0); + const uint padTop = INPUT_VARIABLE(1)->e(0, 1); + const uint padLeft = INPUT_VARIABLE(1)->e(1, 0); + const uint padRight = INPUT_VARIABLE(1)->e(1, 1); + + REQUIRE_TRUE((inputShapeInfo[2] + padBottom + padTop) % blockSize == 0 && + (inputShapeInfo[3] + padLeft + padRight) % blockSize == 0, + 0, + "SpaceToBatch: after padding, second and third dimensions of " + "input array must be divisible by blockSize !"); + + return SHAPELIST(ConstantShapeHelper::getInstance().createShapeInfo( + ArrayOptions::dataType(inputShapeInfo), 'c', + {inputShapeInfo[1] * blockSize * blockSize, + (inputShapeInfo[2] + padBottom + padTop) / blockSize, + (inputShapeInfo[3] + padLeft + padRight) / blockSize, + inputShapeInfo[4]})); } -} -} +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/generic/transforms/space_to_batch_nd.cpp b/libnd4j/include/ops/declarable/generic/transforms/space_to_batch_nd.cpp index 5adc35ee622a..028457236e17 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/space_to_batch_nd.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/space_to_batch_nd.cpp @@ -23,83 +23,113 @@ limitations under the License. #include namespace sd { -namespace ops { - +namespace ops { CUSTOM_OP_IMPL(space_to_batch_nd, 3, 1, false, 0, 0) { - - // 4D example, numOfSpatialDims = 2 - two spatial dimensions - // [bS, iH, iW, iC] is rearranged/permuted to [bS*blockShape[0]*blockShape[1], (iH + padBottom + padTop)/blockSize[0], (iW + padLeft + padRight)/blockSize[1], iC] - - auto input = INPUT_VARIABLE(0); - auto blockShape = INPUT_VARIABLE(1); - auto padding = INPUT_VARIABLE(2); - - auto output = OUTPUT_VARIABLE(0); - - REQUIRE_TRUE(blockShape->rankOf() == 1, 0, "SpaceToBatchND: rank of blockShape array must be equal to one, but got %i instead !", blockShape->rankOf()); - - const uint numOfSpatialDims = blockShape->sizeAt(0); - - REQUIRE_TRUE(input->rankOf() == output->rankOf(), 0, "SpaceToBatchND: rank of input and output array must be the same, but got %i and %i correspondingly !", input->rankOf(), output->rankOf()); - - if(padding->sizeAt(0) != numOfSpatialDims || padding->sizeAt(1) != 2) { - const std::string expectedpaddingShape = "[" + std::to_string(numOfSpatialDims) + ", 2]"; // [numOfSpatialDims, 2] - REQUIRE_TRUE(false, 0, "SpaceToBatchND: operation expects padding shape to be %s, but got %s instead", expectedpaddingShape.c_str(), ShapeUtils::shapeAsString(padding).c_str()); - } - - // FIXME - should we use this time-consuming validation ? - for (uint i = 0; i < numOfSpatialDims; ++i) { - const uint padLeft = padding->e(i,0); - const uint padRight = padding->e(i,1); - const Nd4jLong blockSize = blockShape->e(i); - REQUIRE_TRUE((input->sizeAt(i + 1) + padLeft + padRight) % blockSize == 0, 0, "SpaceToBatchND: after padding, spatial dimensions of input array must be divisible by blockSize !"); - } - - if (shape::strideDescendingCAscendingF(input->shapeInfo())) - helpers::spaceToBatchND(block.launchContext(), *input, *blockShape, *padding, *output); - else - helpers::spaceToBatchND(block.launchContext(), input->dup(), *blockShape, *padding, *output); - - return Status::OK(); + // 4D example, numOfSpatialDims = 2 - two spatial dimensions + // [bS, iH, iW, iC] is rearranged/permuted to [bS*blockShape[0]*blockShape[1], + // (iH + padBottom + padTop)/blockSize[0], (iW + padLeft + + // padRight)/blockSize[1], iC] + + auto input = INPUT_VARIABLE(0); + auto blockShape = INPUT_VARIABLE(1); + auto padding = INPUT_VARIABLE(2); + + auto output = OUTPUT_VARIABLE(0); + + REQUIRE_TRUE(blockShape->rankOf() == 1, 0, + "SpaceToBatchND: rank of blockShape array must be equal to one, " + "but got %i instead !", + blockShape->rankOf()); + + const uint numOfSpatialDims = blockShape->sizeAt(0); + + REQUIRE_TRUE(input->rankOf() == output->rankOf(), 0, + "SpaceToBatchND: rank of input and output array must be the " + "same, but got %i and %i correspondingly !", + input->rankOf(), output->rankOf()); + + if (padding->sizeAt(0) != numOfSpatialDims || padding->sizeAt(1) != 2) { + const std::string expectedpaddingShape = "[" + + std::to_string(numOfSpatialDims) + + ", 2]"; // [numOfSpatialDims, 2] + REQUIRE_TRUE(false, 0, + "SpaceToBatchND: operation expects padding shape to be %s, " + "but got %s instead", + expectedpaddingShape.c_str(), + ShapeUtils::shapeAsString(padding).c_str()); + } + + // FIXME - should we use this time-consuming validation ? + for (uint i = 0; i < numOfSpatialDims; ++i) { + const uint padLeft = padding->e(i, 0); + const uint padRight = padding->e(i, 1); + const Nd4jLong blockSize = blockShape->e(i); + REQUIRE_TRUE((input->sizeAt(i + 1) + padLeft + padRight) % blockSize == 0, + 0, + "SpaceToBatchND: after padding, spatial dimensions of input " + "array must be divisible by blockSize !"); + } + + if (shape::strideDescendingCAscendingF(input->shapeInfo())) + helpers::spaceToBatchND(block.launchContext(), *input, *blockShape, + *padding, *output); + else + helpers::spaceToBatchND(block.launchContext(), input->dup(), *blockShape, + *padding, *output); + + return Status::OK(); } //////////////////////////////////////////////////////////////////////////////// DECLARE_TYPES(space_to_batch_nd) { - - getOpDescriptor()->setAllowedInputTypes(0, sd::DataType::ANY) - ->setAllowedInputTypes(1, {ALL_INTS}) - ->setAllowedInputTypes(2, {ALL_INTS}) - ->setSameMode(true); + getOpDescriptor() + ->setAllowedInputTypes(0, sd::DataType::ANY) + ->setAllowedInputTypes(1, {ALL_INTS}) + ->setAllowedInputTypes(2, {ALL_INTS}) + ->setSameMode(true); } //////////////////////////////////////////////////////////////////////////////// DECLARE_SHAPE_FN(space_to_batch_nd) { - - auto inputShapeInfo = inputShape->at(0); - auto blockShapeInfo = inputShape->at(1); - auto paddingShapeInfo = inputShape->at(2); - - REQUIRE_TRUE(blockShapeInfo[0] == 1, 0, "SpaceToBatchND: rank of blockShape array must be equal to one, but got %i instead !", blockShapeInfo[0]); - - const uint numOfSpatialDims = blockShapeInfo[1]; - - if(paddingShapeInfo[1] != numOfSpatialDims || paddingShapeInfo[2] != 2) { - const std::string expectedpaddingShape = "[" + std::to_string(numOfSpatialDims) + ", 2]"; // [numOfSpatialDims, 2] - REQUIRE_TRUE(false, 0, "SpaceToBatchND: operation expects padding shape to be %s, but got %s instead", expectedpaddingShape.c_str(), ShapeUtils::shapeAsString(paddingShapeInfo).c_str()); - } - - std::vector outShape(inputShapeInfo + 1, inputShapeInfo + 1 + inputShapeInfo[0]); - - outShape[0] *= INPUT_VARIABLE(1)->reduceNumber(sd::reduce::Prod).e(0); - - for (uint i = 0; i < numOfSpatialDims; ++i) - outShape[i + 1] = (outShape[i + 1] + INPUT_VARIABLE(2)->e(i,0) + INPUT_VARIABLE(2)->e(i,1)) / INPUT_VARIABLE(1)->e(i); - - return SHAPELIST(ConstantShapeHelper::getInstance().createShapeInfo(ArrayOptions::dataType(inputShapeInfo), 'c', outShape)); + auto inputShapeInfo = inputShape->at(0); + auto blockShapeInfo = inputShape->at(1); + auto paddingShapeInfo = inputShape->at(2); + + REQUIRE_TRUE(blockShapeInfo[0] == 1, 0, + "SpaceToBatchND: rank of blockShape array must be equal to one, " + "but got %i instead !", + blockShapeInfo[0]); + + const uint numOfSpatialDims = blockShapeInfo[1]; + + if (paddingShapeInfo[1] != numOfSpatialDims || paddingShapeInfo[2] != 2) { + const std::string expectedpaddingShape = "[" + + std::to_string(numOfSpatialDims) + + ", 2]"; // [numOfSpatialDims, 2] + REQUIRE_TRUE(false, 0, + "SpaceToBatchND: operation expects padding shape to be %s, " + "but got %s instead", + expectedpaddingShape.c_str(), + ShapeUtils::shapeAsString(paddingShapeInfo).c_str()); + } + + std::vector outShape(inputShapeInfo + 1, + inputShapeInfo + 1 + inputShapeInfo[0]); + + outShape[0] *= + INPUT_VARIABLE(1)->reduceNumber(sd::reduce::Prod).e(0); + + for (uint i = 0; i < numOfSpatialDims; ++i) + outShape[i + 1] = (outShape[i + 1] + INPUT_VARIABLE(2)->e(i, 0) + + INPUT_VARIABLE(2)->e(i, 1)) / + INPUT_VARIABLE(1)->e(i); + + return SHAPELIST(ConstantShapeHelper::getInstance().createShapeInfo( + ArrayOptions::dataType(inputShapeInfo), 'c', outShape)); } -} -} +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/generic/transforms/space_to_depth.cpp b/libnd4j/include/ops/declarable/generic/transforms/space_to_depth.cpp index 7e108028a915..8565d5e0b385 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/space_to_depth.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/space_to_depth.cpp @@ -22,68 +22,73 @@ #if NOT_EXCLUDED(OP_space_to_depth) #include -#include #include +#include + namespace sd { namespace ops { - DECLARE_TYPES(space_to_depth) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setSameMode(true); - } - - CUSTOM_OP_IMPL(space_to_depth, 1, 1, false, 0, 2) { - int block_size = INT_ARG(0); - bool isNHWC = INT_ARG(1) == 1; - - auto input = INPUT_VARIABLE(0); - - REQUIRE_TRUE(input->rankOf() == 4, 0, "SpaceToDepth: input should be 4D array, but got %f instead", input->rankOf()); - - int bS = input->sizeAt(0); - int iD = isNHWC ? input->sizeAt(3) : input->sizeAt(1); - int iH = isNHWC ? input->sizeAt(1) : input->sizeAt(2); - int iW = isNHWC ? input->sizeAt(2) : input->sizeAt(3); - - REQUIRE_TRUE(iH % block_size == 0 && iW % block_size == 0, 0, "SpaceToDepth: input Height & Width should be divisible by block_size"); - - auto output = OUTPUT_VARIABLE(0); - - if (shape::strideDescendingCAscendingF(input->shapeInfo())) - helpers::_spaceTodepth(block.launchContext(), *input, output, block_size, isNHWC); - else - helpers::_spaceTodepth(block.launchContext(), input->dup(), output, block_size, isNHWC); - - return Status::OK(); - } - - - DECLARE_SHAPE_FN(space_to_depth) { - auto in = inputShape->at(0); - int block_size = INT_ARG(0); - bool isNHWC = INT_ARG(1) == 1; - - int bS = shape::sizeAt(in, 0); - int iD = isNHWC ? shape::sizeAt(in, 3) : shape::sizeAt(in, 1); - int iH = isNHWC ? shape::sizeAt(in, 1) : shape::sizeAt(in, 2); - int iW = isNHWC ? shape::sizeAt(in, 2) : shape::sizeAt(in, 3); - - int oD = iD * block_size * block_size; - int oH = iH / block_size; - int oW = iW / block_size; - - std::array shape; - if (isNHWC) - shape = {{bS, oH, oW, oD }}; - else - shape = {{bS, oD, oH, oW }}; - - auto newShape = ConstantShapeHelper::getInstance().createShapeInfo(ArrayOptions::dataType(in), 'c', 4, shape.data()); - return SHAPELIST(newShape); - } +DECLARE_TYPES(space_to_depth) { + getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setSameMode(true); } + +CUSTOM_OP_IMPL(space_to_depth, 1, 1, false, 0, 2) { + int block_size = INT_ARG(0); + bool isNHWC = INT_ARG(1) == 1; + + auto input = INPUT_VARIABLE(0); + + REQUIRE_TRUE(input->rankOf() == 4, 0, + "SpaceToDepth: input should be 4D array, but got %f instead", + input->rankOf()); + + int bS = input->sizeAt(0); + int iD = isNHWC ? input->sizeAt(3) : input->sizeAt(1); + int iH = isNHWC ? input->sizeAt(1) : input->sizeAt(2); + int iW = isNHWC ? input->sizeAt(2) : input->sizeAt(3); + + REQUIRE_TRUE( + iH % block_size == 0 && iW % block_size == 0, 0, + "SpaceToDepth: input Height & Width should be divisible by block_size"); + + auto output = OUTPUT_VARIABLE(0); + + if (shape::strideDescendingCAscendingF(input->shapeInfo())) + helpers::_spaceTodepth(block.launchContext(), *input, output, block_size, + isNHWC); + else + helpers::_spaceTodepth(block.launchContext(), input->dup(), output, + block_size, isNHWC); + + return Status::OK(); +} + +DECLARE_SHAPE_FN(space_to_depth) { + auto in = inputShape->at(0); + int block_size = INT_ARG(0); + bool isNHWC = INT_ARG(1) == 1; + + int bS = shape::sizeAt(in, 0); + int iD = isNHWC ? shape::sizeAt(in, 3) : shape::sizeAt(in, 1); + int iH = isNHWC ? shape::sizeAt(in, 1) : shape::sizeAt(in, 2); + int iW = isNHWC ? shape::sizeAt(in, 2) : shape::sizeAt(in, 3); + + int oD = iD * block_size * block_size; + int oH = iH / block_size; + int oW = iW / block_size; + + std::array shape; + if (isNHWC) + shape = {{bS, oH, oW, oD}}; + else + shape = {{bS, oD, oH, oW}}; + + auto newShape = ConstantShapeHelper::getInstance().createShapeInfo( + ArrayOptions::dataType(in), 'c', 4, shape.data()); + return SHAPELIST(newShape); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/transforms/split.cpp b/libnd4j/include/ops/declarable/generic/transforms/split.cpp index 3fb925dfc515..e41fd4b9efaa 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/split.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/split.cpp @@ -22,127 +22,134 @@ #if NOT_EXCLUDED(OP_split) #include -#include +#include + #include namespace sd { namespace ops { - CUSTOM_OP_IMPL(split, 1, -1, false, 0, 1) { - NDArray *input = nullptr; - int num_splits = INT_ARG(0); - - // axis is 0 by default - int axis = 0; - - if (block.width() == 1) { - input = INPUT_VARIABLE(0); - } else { - auto a = INPUT_VARIABLE(0); - auto b = INPUT_VARIABLE(1); - - if (a->isScalar()) { - // axis goes first - axis = a->e(0); - input = b; - } else if (b->isScalar()) { - axis = b->e(0); - input = a; - } - } - - //Edge case: splitting empty array (mainly for TF import compatibility) -> return N empty arrays - if(input->isEmpty()){ - for( int i=0; i< num_splits; i++ ){ - REQUIRE_TRUE(OUTPUT_VARIABLE(i)->isEmpty(), 0, "Split: When input array is empty, all output arrays must be empty"); - } - //No op - return Status::OK(); - } - - if (block.numI() == 2) - axis = INT_ARG(1); - - if(axis < 0) axis += input->rankOf(); - - REQUIRE_TRUE(input->sizeAt(axis) % num_splits == 0, 0, "Split: num_splits has wrong value, remainder of division should be 0, but it's %i", input->sizeAt(axis) % num_splits); - - std::vector outArrs(num_splits); - for (int e = 0; e < num_splits; e++) { - outArrs[e] = OUTPUT_VARIABLE(e); - } - - helpers::split(block.launchContext(), *input, outArrs, axis); - - return Status::OK(); +CUSTOM_OP_IMPL(split, 1, -1, false, 0, 1) { + NDArray *input = nullptr; + int num_splits = INT_ARG(0); + + // axis is 0 by default + int axis = 0; + + if (block.width() == 1) { + input = INPUT_VARIABLE(0); + } else { + auto a = INPUT_VARIABLE(0); + auto b = INPUT_VARIABLE(1); + + if (a->isScalar()) { + // axis goes first + axis = a->e(0); + input = b; + } else if (b->isScalar()) { + axis = b->e(0); + input = a; } - - DECLARE_TYPES(split) { - getOpDescriptor() - ->setAllowedInputTypes({ALL_INTS, ALL_FLOATS}) - ->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS}); + } + + // Edge case: splitting empty array (mainly for TF import compatibility) -> + // return N empty arrays + if (input->isEmpty()) { + for (int i = 0; i < num_splits; i++) { + REQUIRE_TRUE( + OUTPUT_VARIABLE(i)->isEmpty(), 0, + "Split: When input array is empty, all output arrays must be empty"); } + // No op + return Status::OK(); + } - DECLARE_SHAPE_FN(split) { - int num_splits = INT_ARG(0); - auto input = inputShape->at(0); - sd::DataType dataType = ArrayOptions::dataType(input); - - // axis is 0 by default - int axis = 0; - - int inputVar = 0; - if (inputShape->size() != 1) { - auto shape0 = inputShape->at(0); - auto shape1 = inputShape->at(1); - - if (shape::isScalar(shape0)) { - input = shape1; - auto _a = INPUT_VARIABLE(0); - axis = _a->e(0); - dataType = ArrayOptions::dataType(shape1); - inputVar = 1; - } else if (shape::isScalar(shape1)) { - input = shape0; - auto _a = INPUT_VARIABLE(1); - axis = _a->e(0); - dataType = ArrayOptions::dataType(shape0); - inputVar = 0; - } - } - - auto shapes = SHAPELIST(); - - //Edge case: splitting empty array (mainly for TF import compatibility) -> return N empty arrays - // if(INPUT_VARIABLE(inputVar)->isEmpty()){ - // for (int e = 0; e < num_splits; e++) { - // auto empty = ConstantShapeHelper::getInstance().emptyShapeInfo(dataType); - // shapes->push_back(empty); - // } - // return shapes; - // } - - if (block.numI() == 2) - axis = INT_ARG(1); - - if (axis < 0) - axis += shape::rank(input); - - std::vector shape(shape::rank(input)); - - for (int e = 0; e < shape::rank(input); e++) - if (e == axis) - shape[e] = shape::sizeAt(input, e) / num_splits; - else - shape[e] = shape::sizeAt(input, e); - - for (int e = 0; e < num_splits; e++) { - auto newShape = ConstantShapeHelper::getInstance().createShapeInfo(dataType, shape::order(input), shape); - shapes->push_back(newShape); - } - - return shapes; - } + if (block.numI() == 2) axis = INT_ARG(1); + + if (axis < 0) axis += input->rankOf(); + + REQUIRE_TRUE(input->sizeAt(axis) % num_splits == 0, 0, + "Split: num_splits has wrong value, remainder of division " + "should be 0, but it's %i", + input->sizeAt(axis) % num_splits); + + std::vector outArrs(num_splits); + for (int e = 0; e < num_splits; e++) { + outArrs[e] = OUTPUT_VARIABLE(e); + } + + helpers::split(block.launchContext(), *input, outArrs, axis); + + return Status::OK(); +} + +DECLARE_TYPES(split) { + getOpDescriptor() + ->setAllowedInputTypes({ALL_INTS, ALL_FLOATS}) + ->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS}); } + +DECLARE_SHAPE_FN(split) { + int num_splits = INT_ARG(0); + auto input = inputShape->at(0); + sd::DataType dataType = ArrayOptions::dataType(input); + + // axis is 0 by default + int axis = 0; + + int inputVar = 0; + if (inputShape->size() != 1) { + auto shape0 = inputShape->at(0); + auto shape1 = inputShape->at(1); + + if (shape::isScalar(shape0)) { + input = shape1; + auto _a = INPUT_VARIABLE(0); + axis = _a->e(0); + dataType = ArrayOptions::dataType(shape1); + inputVar = 1; + } else if (shape::isScalar(shape1)) { + input = shape0; + auto _a = INPUT_VARIABLE(1); + axis = _a->e(0); + dataType = ArrayOptions::dataType(shape0); + inputVar = 0; + } + } + + auto shapes = SHAPELIST(); + + // Edge case: splitting empty array (mainly for TF import compatibility) -> + // return N empty arrays + // if(INPUT_VARIABLE(inputVar)->isEmpty()){ + // for (int e = 0; e < num_splits; e++) { + // auto empty = + // ConstantShapeHelper::getInstance().emptyShapeInfo(dataType); + // shapes->push_back(empty); + // } + // return shapes; + // } + + if (block.numI() == 2) axis = INT_ARG(1); + + if (axis < 0) axis += shape::rank(input); + + std::vector shape(shape::rank(input)); + + for (int e = 0; e < shape::rank(input); e++) + if (e == axis) + shape[e] = shape::sizeAt(input, e) / num_splits; + else + shape[e] = shape::sizeAt(input, e); + + for (int e = 0; e < num_splits; e++) { + auto newShape = ConstantShapeHelper::getInstance().createShapeInfo( + dataType, shape::order(input), shape); + shapes->push_back(newShape); + } + + return shapes; } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/transforms/split_v.cpp b/libnd4j/include/ops/declarable/generic/transforms/split_v.cpp index decda2e2d335..8248694990f5 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/split_v.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/split_v.cpp @@ -25,104 +25,104 @@ namespace sd { namespace ops { - CUSTOM_OP_IMPL(split_v, 2, -1, false, 0, -2) { - auto input = INPUT_VARIABLE(0); - auto sizes = INPUT_VARIABLE(1); +CUSTOM_OP_IMPL(split_v, 2, -1, false, 0, -2) { + auto input = INPUT_VARIABLE(0); + auto sizes = INPUT_VARIABLE(1); - int axis = 0; + int axis = 0; - if (block.getIArguments()->size() > 0) { - axis = INT_ARG(0); - } else if (block.width() > 2){ - auto _a = INPUT_VARIABLE(2); - axis = _a->e(0); - } + if (block.numI() > 0) { + axis = INT_ARG(0); + } else if (block.width() > 2) { + auto _a = INPUT_VARIABLE(2); + axis = _a->e(0); + } - if (axis < 0) - axis += input->rankOf(); + if (axis < 0) axis += input->rankOf(); - std::vector dims = ShapeUtils::evalDimsToExclude(input->rankOf(), {axis}); + std::vector dims = + ShapeUtils::evalDimsToExclude(input->rankOf(), {axis}); - int pos = 0; - std::vector indices(2 * input->rankOf()); - - for (Nd4jLong e = 0; e < sizes->lengthOf(); e++) { - int c_size = sizes->e(e); - - for (int d = 0; d < input->rankOf(); d++) { - if (d == axis) - indices[2*d + 1] = (indices[2*d] = pos) + c_size; - else - indices[2*d] = indices[2*d + 1] = 0; - } + int pos = 0; + std::vector indices(2 * input->rankOf()); - auto output = OUTPUT_VARIABLE(e); - REQUIRE_TRUE(output->dataType() == input->dataType(), 0, "SplitV: all outputs must have same data type as input"); + for (Nd4jLong e = 0; e < sizes->lengthOf(); e++) { + int c_size = sizes->e(e); - auto sub = (*input)(indices); + for (int d = 0; d < input->rankOf(); d++) { + if (d == axis) + indices[2 * d + 1] = (indices[2 * d] = pos) + c_size; + else + indices[2 * d] = indices[2 * d + 1] = 0; + } - output->assign(sub); + auto output = OUTPUT_VARIABLE(e); + REQUIRE_TRUE(output->dataType() == input->dataType(), 0, + "SplitV: all outputs must have same data type as input"); - pos += c_size; - } + auto sub = (*input)(indices); - //delete tads; - return Status::OK(); - } + output->assign(sub); - DECLARE_TYPES(split_v) { - getOpDescriptor() - ->setAllowedInputTypes(0, {ALL_INTS, ALL_FLOATS}) - ->setAllowedInputTypes(1, {ALL_INTS}) - ->setAllowedInputTypes(2, {ALL_INTS}) - ->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS}); - } + pos += c_size; + } - DECLARE_SHAPE_FN(split_v) { - auto input = inputShape->at(0); - //auto sizes = inputShape->at(1); - - auto shapeList = SHAPELIST(); - int rank = shape::rank(input); - - // 0 is just default axis - int axis = 0; - - if (block.getIArguments()->size() > 0) - axis = INT_ARG(0); - else if (block.width() > 2) { - auto _a = INPUT_VARIABLE(2); - axis = _a->e(0); - } - - if (axis < 0) - axis += shape::rank(input); - - // this op assumes we have sizes defined - auto sizes = INPUT_VARIABLE(1); - - auto length = sizes->lengthOf(); - int pos = 0; - for (Nd4jLong e = 0; e < length; e++) { - int c_size = sizes->e(e); - - - std::vector shape(rank); - - for (int d = 0; d < rank; d++) { - if (d != axis) - shape[d] = shape::sizeAt(input, d); - else - shape[d] = c_size; - } - - auto newShape = ConstantShapeHelper::getInstance().createShapeInfo(ArrayOptions::dataType(input), shape::order(input), shape); - shapeList->push_back(newShape); - } - - return shapeList; - } + // delete tads; + return Status::OK(); } + +DECLARE_TYPES(split_v) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_INTS, ALL_FLOATS}) + ->setAllowedInputTypes(1, {ALL_INTS}) + ->setAllowedInputTypes(2, {ALL_INTS}) + ->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS}); +} + +DECLARE_SHAPE_FN(split_v) { + auto input = inputShape->at(0); + // auto sizes = inputShape->at(1); + + auto shapeList = SHAPELIST(); + int rank = shape::rank(input); + + // 0 is just default axis + int axis = 0; + + if (block.numI() > 0) + axis = INT_ARG(0); + else if (block.width() > 2) { + auto _a = INPUT_VARIABLE(2); + axis = _a->e(0); + } + + if (axis < 0) axis += shape::rank(input); + + // this op assumes we have sizes defined + auto sizes = INPUT_VARIABLE(1); + + auto length = sizes->lengthOf(); + int pos = 0; + for (Nd4jLong e = 0; e < length; e++) { + int c_size = sizes->e(e); + + std::vector shape(rank); + + for (int d = 0; d < rank; d++) { + if (d != axis) + shape[d] = shape::sizeAt(input, d); + else + shape[d] = c_size; + } + + auto newShape = ConstantShapeHelper::getInstance().createShapeInfo( + ArrayOptions::dataType(input), shape::order(input), shape); + shapeList->push_back(newShape); + } + + return shapeList; } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/transforms/stack.cpp b/libnd4j/include/ops/declarable/generic/transforms/stack.cpp index af03d5ef1d02..e37af1082359 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/stack.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/stack.cpp @@ -22,88 +22,96 @@ #if NOT_EXCLUDED(OP_stack) #include -#include +#include namespace sd { namespace ops { CUSTOM_OP_IMPL(stack, -1, 1, false, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - int dim = block.getIArguments()->size() > 0 ? INT_ARG(0) : 0; - if(dim < 0) - dim += input->rankOf() + 1; + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + int dim = block.numI() > 0 ? INT_ARG(0) : 0; + if (dim < 0) dim += input->rankOf() + 1; - // no-op in case of empty output array - if (output->isEmpty()) - return Status::OK(); + // no-op in case of empty output array + if (output->isEmpty()) return Status::OK(); - // input validation - // check whether shapes of all input array are the same - for (int i = 0; i < (int) block.width() - 1; ++i) - REQUIRE_TRUE(shape::equalsSoft((INPUT_VARIABLE(i))->shapeInfo(), (INPUT_VARIABLE(i+1))->shapeInfo()), 0, "STACK op: the shapes of all input arrays must be the same !"); + // input validation + // check whether shapes of all input array are the same + for (int i = 0; i < (int)block.width() - 1; ++i) + REQUIRE_TRUE(shape::equalsSoft((INPUT_VARIABLE(i))->shapeInfo(), + (INPUT_VARIABLE(i + 1))->shapeInfo()), + 0, + "STACK op: the shapes of all input arrays must be the same !"); - REQUIRE_TRUE(dim <= input->rankOf(), 0, "STACK op: the input dimension parameter must be <= rank of input arrays shapes (rank=%i), but got %i instead !", input->shapeOf(), dim); + REQUIRE_TRUE(dim <= input->rankOf(), 0, + "STACK op: the input dimension parameter must be <= rank of " + "input arrays shapes (rank=%i), but got %i instead !", + input->shapeOf(), dim); + std::vector inArrs(block.width()); + for (int i = 0; i < block.width(); ++i) inArrs[i] = INPUT_VARIABLE(i); - std::vector inArrs(block.width()); - for(int i = 0; i < block.width(); ++i) - inArrs[i] = INPUT_VARIABLE(i); + helpers::stack(block.launchContext(), inArrs, *output, dim); - helpers::stack(block.launchContext(), inArrs, *output, dim); - - return Status::OK(); + return Status::OK(); } DECLARE_SYN(pack, stack); DECLARE_SYN(Pack, stack); - DECLARE_TYPES(stack) { - //getOpDescriptor()->setSameMode(true); - getOpDescriptor() - ->setAllowedInputTypes(DataType::ANY) - ->setAllowedOutputTypes(DataType::ANY); - - } +DECLARE_TYPES(stack) { + // getOpDescriptor()->setSameMode(true); + getOpDescriptor() + ->setAllowedInputTypes(DataType::ANY) + ->setAllowedOutputTypes(DataType::ANY); +} DECLARE_SHAPE_FN(stack) { - - // check whether input dimension is within rank range - auto inShapeInfo = inputShape->at(0); - int rank = shape::rank(inShapeInfo); - int dim = block.getIArguments()->size() > 0 ? INT_ARG(0) : 0; - if(dim < 0 ) - dim += rank + 1; - - REQUIRE_TRUE(dim <= inShapeInfo[0], 0, "STACK op: the input dimension parameter must be <= rank of input arrays shapes (rank=%i), but got %i instead !", inShapeInfo[0], dim); - - // empty input arrays require some special handling - if (shape::isEmpty(inShapeInfo)) { - switch (rank) { - case 0: { - // we're going to return rank 1 here - if (block.width() == 1) { - return SHAPELIST(ConstantShapeHelper::getInstance().vectorShapeInfo(0, ArrayOptions::dataType(inShapeInfo))); - } else { - return SHAPELIST(ConstantShapeHelper::getInstance().createShapeInfo(ArrayOptions::dataType(inShapeInfo), 'c', {(Nd4jLong) block.width(), 0})); - } - } - } - } - - if(rank == 0) { - return SHAPELIST(ConstantShapeHelper::getInstance().vectorShapeInfo(block.width(), ArrayOptions::dataType(inShapeInfo))); - } - - //the rank of output ShapeInfo is larger by one compared to input ShapeInfo - std::vector outShape(inShapeInfo + 1, inShapeInfo + 1 + rank); - - // insert (int) block.width() at dim position of input shape to get output shape - outShape.insert(outShape.begin() + Nd4jLong(dim), (Nd4jLong) block.width()); - return SHAPELIST(ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(ArrayOptions::dataType(inShapeInfo), shape::order(inShapeInfo), outShape))); + // check whether input dimension is within rank range + auto inShapeInfo = inputShape->at(0); + int rank = shape::rank(inShapeInfo); + int dim = block.numI() > 0 ? INT_ARG(0) : 0; + if (dim < 0) dim += rank + 1; + + REQUIRE_TRUE(dim <= inShapeInfo[0], 0, + "STACK op: the input dimension parameter must be <= rank of " + "input arrays shapes (rank=%i), but got %i instead !", + inShapeInfo[0], dim); + + // empty input arrays require some special handling + if (shape::isEmpty(inShapeInfo)) { + switch (rank) { + case 0: { + // we're going to return rank 1 here + if (block.width() == 1) { + return SHAPELIST(ConstantShapeHelper::getInstance().vectorShapeInfo( + 0, ArrayOptions::dataType(inShapeInfo))); + } else { + return SHAPELIST(ConstantShapeHelper::getInstance().createShapeInfo( + ArrayOptions::dataType(inShapeInfo), 'c', + {(Nd4jLong)block.width(), 0})); + } + } + } + } + + if (rank == 0) { + return SHAPELIST(ConstantShapeHelper::getInstance().vectorShapeInfo( + block.width(), ArrayOptions::dataType(inShapeInfo))); + } + + // the rank of output ShapeInfo is larger by one compared to input ShapeInfo + std::vector outShape(inShapeInfo + 1, inShapeInfo + 1 + rank); + + // insert (int) block.width() at dim position of input shape to get output + // shape + outShape.insert(outShape.begin() + Nd4jLong(dim), (Nd4jLong)block.width()); + return SHAPELIST(ConstantShapeHelper::getInstance().createShapeInfo( + ShapeDescriptor(ArrayOptions::dataType(inShapeInfo), + shape::order(inShapeInfo), outShape))); } - -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/transforms/standardize.cpp b/libnd4j/include/ops/declarable/generic/transforms/standardize.cpp index f4e8a6f7acc8..8de950c3c91c 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/standardize.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/standardize.cpp @@ -24,117 +24,121 @@ #include #include - namespace sd { -namespace ops { - - CONFIGURABLE_OP_IMPL(standardize, 1, 1, true, 0, -2) { - - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - std::vector axis; - - if (block.width() > 1) - axis = INPUT_VARIABLE(1)->template asVectorT(); - else if (block.numI() > 0) - axis = *block.getIArguments(); - - REQUIRE_TRUE(!axis.empty(), 0, "STANDARDIZE OP: axis has to be non-empty") - - shape::checkDimensions(input->rankOf(), axis); - - auto means = input->reduceAlongDimension(reduce::Mean, axis, true); - auto stdev = input->varianceAlongDimension(variance::SummaryStatsStandardDeviation, false, axis); - stdev.reshapei(means.getShapeAsVector()); - - input->applyTrueBroadcast(sd::BroadcastOpsTuple::Subtract(), means, *output, false); - output->applyTrueBroadcast(sd::BroadcastOpsTuple::Divide(), stdev, *output, false); - output->applyScalar(sd::scalar::ReplaceNans, 0, *output); - - return Status::OK(); - } - - - DECLARE_TYPES(standardize) { - getOpDescriptor()->setAllowedInputTypes(0, {ALL_FLOATS}); - getOpDescriptor()->setAllowedInputTypes(1, {DataType::INT32, DataType::INT64}); - getOpDescriptor()->setAllowedOutputTypes(0, DataType::INHERIT); - } +namespace ops { - CUSTOM_OP_IMPL(standardize_bp, 2, 1, false, 0, -2) { - auto input = INPUT_VARIABLE(0); - auto eps = block.width() == 3 ? INPUT_VARIABLE(2) : INPUT_VARIABLE(1); +CONFIGURABLE_OP_IMPL(standardize, 1, 1, true, 0, -2) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - std::vector axis; + std::vector axis; - if (block.width() == 3) - axis = INPUT_VARIABLE(1)->template asVectorT(); - else if (block.numI() > 0) - axis = *block.getIArguments(); + if (block.width() > 1) + axis = INPUT_VARIABLE(1)->template asVectorT(); + else if (block.numI() > 0) + axis = block.getIArguments(); - REQUIRE_TRUE(!axis.empty(), 0, "STANDARDIZE OP: axis has to be non-empty") + REQUIRE_TRUE(!axis.empty(), 0, "STANDARDIZE OP: axis has to be non-empty") + shape::checkDimensions(input->rankOf(), axis); - shape::checkDimensions(input->rankOf(), axis); - auto longAxis = ArrayUtils::toLongVector(axis); + auto means = input->reduceAlongDimension(reduce::Mean, axis, true); + auto stdev = input->varianceAlongDimension( + variance::SummaryStatsStandardDeviation, false, axis); + stdev.reshapei(means.getShapeAsVector()); - auto means = input->reduceAlongDimension(reduce::Mean, axis, true); - auto stdev = input->varianceAlongDimension(variance::SummaryStatsStandardDeviation, false, axis); - stdev.reshapei(means.getShapeAsVector()); + input->applyTrueBroadcast(sd::BroadcastOpsTuple::Subtract(), means, *output, + false); + output->applyTrueBroadcast(sd::BroadcastOpsTuple::Divide(), stdev, *output, + false); + output->applyScalar(sd::scalar::ReplaceNans, 0, *output); - eps->applyTrueBroadcast(sd::BroadcastOpsTuple::Divide(), stdev, *output, false); - - NDArray dldu_sum = -output->reduceAlongDimension(reduce::Sum, axis, true); - - NDArray dldx_u(input->shapeInfo(), false, block.launchContext()); - std::vector meanBpArgs = {input, &dldu_sum}; - std::vector meanBpOutput = {&dldx_u}; - std::vector meanBpTArgs = {}; - std::vector meanBpBArgs = {}; - - sd::ops::reduce_mean_bp meanBp; - meanBp.execute(meanBpArgs, meanBpOutput, meanBpTArgs, longAxis, meanBpBArgs); - *output += dldx_u; - - // (eps * (means - input) / (stdev * stdev)) - NDArray tmp(eps->shapeInfo(), false, block.launchContext()); - means.applyTrueBroadcast(sd::BroadcastOpsTuple::Subtract(), *input, tmp, false); - tmp.applyPairwiseTransform(sd::pairwise::Multiply, *eps, tmp); - stdev.applyPairwiseTransform(sd::pairwise::Multiply, stdev, stdev); - tmp.applyTrueBroadcast(sd::BroadcastOpsTuple::Divide(), stdev, tmp, false); - - auto dlds_sum = tmp.reduceAlongDimension(reduce::Sum, axis, true); - NDArray dldx_s(input->shapeInfo(), false, block.launchContext()); - std::vector stdevBpArgs = {input, &dlds_sum}; - std::vector stdevBpOutput = {&dldx_s}; - std::vector stdevBpTArgs = {}; - std::vector stdevBpBArgs = {}; - sd::ops::reduce_stdev_bp stdevBp; - stdevBp.execute(stdevBpArgs, stdevBpOutput, stdevBpTArgs, longAxis, stdevBpBArgs); - *output += dldx_s; - - output->applyScalar(sd::scalar::ReplaceNans, 0, *output); + return Status::OK(); +} - return Status::OK(); - } +DECLARE_TYPES(standardize) { + getOpDescriptor()->setAllowedInputTypes(0, {ALL_FLOATS}); + getOpDescriptor()->setAllowedInputTypes(1, + {DataType::INT32, DataType::INT64}); + getOpDescriptor()->setAllowedOutputTypes(0, DataType::INHERIT); +} - DECLARE_TYPES(standardize_bp) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } +CUSTOM_OP_IMPL(standardize_bp, 2, 1, false, 0, -2) { + auto input = INPUT_VARIABLE(0); + auto eps = block.width() == 3 ? INPUT_VARIABLE(2) : INPUT_VARIABLE(1); + + auto output = OUTPUT_VARIABLE(0); + std::vector axis; + + if (block.width() == 3) + axis = INPUT_VARIABLE(1)->template asVectorT(); + else if (block.numI() > 0) + axis = block.getIArguments(); + + REQUIRE_TRUE(!axis.empty(), 0, "STANDARDIZE OP: axis has to be non-empty") + + shape::checkDimensions(input->rankOf(), axis); + auto longAxis = ArrayUtils::toLongVector(axis); + + auto means = input->reduceAlongDimension(reduce::Mean, axis, true); + auto stdev = input->varianceAlongDimension( + variance::SummaryStatsStandardDeviation, false, axis); + stdev.reshapei(means.getShapeAsVector()); + + eps->applyTrueBroadcast(sd::BroadcastOpsTuple::Divide(), stdev, *output, + false); + + NDArray dldu_sum = -output->reduceAlongDimension(reduce::Sum, axis, true); + + NDArray dldx_u(input->shapeInfo(), false, block.launchContext()); + std::vector meanBpArgs = {input, &dldu_sum}; + std::vector meanBpOutput = {&dldx_u}; + std::vector meanBpTArgs = {}; + std::vector meanBpBArgs = {}; + + sd::ops::reduce_mean_bp meanBp; + meanBp.execute(meanBpArgs, meanBpOutput, meanBpTArgs, longAxis, meanBpBArgs); + *output += dldx_u; + + // (eps * (means - input) / (stdev * stdev)) + NDArray tmp(eps->shapeInfo(), false, block.launchContext()); + means.applyTrueBroadcast(sd::BroadcastOpsTuple::Subtract(), *input, tmp, + false); + tmp.applyPairwiseTransform(sd::pairwise::Multiply, *eps, tmp); + stdev.applyPairwiseTransform(sd::pairwise::Multiply, stdev, stdev); + tmp.applyTrueBroadcast(sd::BroadcastOpsTuple::Divide(), stdev, tmp, false); + + auto dlds_sum = tmp.reduceAlongDimension(reduce::Sum, axis, true); + NDArray dldx_s(input->shapeInfo(), false, block.launchContext()); + std::vector stdevBpArgs = {input, &dlds_sum}; + std::vector stdevBpOutput = {&dldx_s}; + std::vector stdevBpTArgs = {}; + std::vector stdevBpBArgs = {}; + sd::ops::reduce_stdev_bp stdevBp; + stdevBp.execute(stdevBpArgs, stdevBpOutput, stdevBpTArgs, longAxis, + stdevBpBArgs); + *output += dldx_s; + + output->applyScalar(sd::scalar::ReplaceNans, 0, *output); + + return Status::OK(); +} - DECLARE_SHAPE_FN(standardize_bp) { - auto in = inputShape->at(0); - Nd4jLong *out; - COPY_SHAPE(in, out); +DECLARE_TYPES(standardize_bp) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} - return SHAPELIST(CONSTANT(out)); - } +DECLARE_SHAPE_FN(standardize_bp) { + auto in = inputShape->at(0); + Nd4jLong *out; + COPY_SHAPE(in, out); + return SHAPELIST(CONSTANT(out)); } -} + +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/transforms/tear.cpp b/libnd4j/include/ops/declarable/generic/transforms/tear.cpp index b2292e2b989b..54bc7584b5a6 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/tear.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/tear.cpp @@ -21,60 +21,65 @@ #include #if NOT_EXCLUDED(OP_tear) -#include -#include #include +#include +#include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(tear, 1, -1, false, 0, -1) { - auto input = INPUT_VARIABLE(0); +namespace ops { +CUSTOM_OP_IMPL(tear, 1, -1, false, 0, -1) { + auto input = INPUT_VARIABLE(0); - REQUIRE_TRUE(!block.getIArguments()->empty(), 0, "At least 1 dimension should be specified for Tear"); + REQUIRE_TRUE(!block.getIArguments().empty(), 0, + "At least 1 dimension should be specified for Tear"); - std::vector dims(*block.getIArguments()); + std::vector dims(block.getIArguments()); - for (auto &v: dims) - REQUIRE_TRUE(v >= 0 && v < input->rankOf(), 0, "Tear dimensions should be non-negative values, and lower then input rank. Got %i instead", v); + for (auto &v : dims) + REQUIRE_TRUE(v >= 0 && v < input->rankOf(), 0, + "Tear dimensions should be non-negative values, and lower " + "then input rank. Got %i instead", + v); - auto tads = input->allTensorsAlongDimension(dims); - for (Nd4jLong e = 0; e < tads.size(); e++) { - auto outE = OUTPUT_VARIABLE(e); - outE->assign(tads.at(e)); + auto tads = input->allTensorsAlongDimension(dims); + for (Nd4jLong e = 0; e < tads.size(); e++) { + auto outE = OUTPUT_VARIABLE(e); + outE->assign(tads.at(e)); - // just for debugging purposes - this->storeResult(block, e, *outE); - } + // just for debugging purposes + this->storeResult(block, e, *outE); + } - return Status::OK(); - } + return Status::OK(); +} - DECLARE_SHAPE_FN(tear) { - auto inShape = inputShape->at(0); +DECLARE_SHAPE_FN(tear) { + auto inShape = inputShape->at(0); - std::vector dims(*block.getIArguments()); + std::vector dims(block.getIArguments()); - if (dims.size() > 1) - std::sort(dims.begin(), dims.end()); + if (dims.size() > 1) std::sort(dims.begin(), dims.end()); - auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(inShape, dims); - auto numTads = tadPack.numberOfTads(); + auto tadPack = + sd::ConstantTadHelper::getInstance().tadForDimensions(inShape, dims); + auto numTads = tadPack.numberOfTads(); - auto result = SHAPELIST(); - for (Nd4jLong e = 0; e < numTads; e++) { - auto newShape = ConstantShapeHelper::getInstance().createShapeInfo(block.dataType(), shape::order(inShape), shape::rank(tadPack.primaryShapeInfo()), shape::shapeOf(tadPack.primaryShapeInfo())); - result->push_back(newShape); - } + auto result = SHAPELIST(); + for (Nd4jLong e = 0; e < numTads; e++) { + auto newShape = ConstantShapeHelper::getInstance().createShapeInfo( + ArrayOptions::dataType(inShape), shape::order(inShape), + shape::rank(tadPack.primaryShapeInfo()), + shape::shapeOf(tadPack.primaryShapeInfo())); + result->push_back(newShape); + } - return result; - } + return result; +} - DECLARE_TYPES(tear) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setSameMode(true); - } - } +DECLARE_TYPES(tear) { + getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setSameMode(true); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/transforms/tile.cpp b/libnd4j/include/ops/declarable/generic/transforms/tile.cpp index e8a502e7465d..f6a506b86c13 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/tile.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/tile.cpp @@ -23,162 +23,178 @@ #if NOT_EXCLUDED(OP_tile) #include -#include +#include namespace sd { namespace ops { CUSTOM_OP_IMPL(tile, 1, 1, false, 0, -2) { - - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - const int inRank = input->rankOf(); - std::vector reps; - - if (block.getIArguments()->size() == inRank) { - - reps = ArrayUtils::toLongVector(*(block.getIArguments())); - } - else if (block.width() > 1) { - - auto reps_vector = INPUT_VARIABLE(1); - REQUIRE_TRUE(reps_vector->lengthOf() == inRank, 0, "TILE op: repeats vector length should be equal to input rank, but got %i and %i correspondingly !", reps_vector->lengthOf(), inRank); - - reps = reps_vector->template asVectorT(); - } - else { - REQUIRE_TRUE(false, 0, "TILE op: this op requires repeats vector, either as IArgs or second array with length equal to rank of input array to be tiled !"); - } - - auto repProd = shape::prodLong(reps.data(), reps.size()); - REQUIRE_TRUE(repProd > 0, 0, "TILE op: reps can't contain 0s"); - - input->tile(reps, *output); - - return Status::OK(); + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + const int inRank = input->rankOf(); + std::vector reps; + + if (block.numI() == inRank) { + reps = ArrayUtils::toLongVector(block.getIArguments()); + } else if (block.width() > 1) { + auto reps_vector = INPUT_VARIABLE(1); + REQUIRE_TRUE(reps_vector->lengthOf() == inRank, 0, + "TILE op: repeats vector length should be equal to input " + "rank, but got %i and %i correspondingly !", + reps_vector->lengthOf(), inRank); + + reps = reps_vector->template asVectorT(); + } else { + REQUIRE_TRUE( + false, 0, + "TILE op: this op requires repeats vector, either as IArgs or second " + "array with length equal to rank of input array to be tiled !"); + } + + auto repProd = shape::prodLong(reps.data(), reps.size()); + REQUIRE_TRUE(repProd > 0, 0, "TILE op: reps can't contain 0s"); + + input->tile(reps, *output); + + return Status::OK(); } - DECLARE_TYPES(tile) { - getOpDescriptor()->setAllowedInputTypes(0, sd::DataType::ANY) - ->setAllowedInputTypes(1, {ALL_INTS}) - ->setAllowedOutputTypes(sd::DataType::ANY); - } - +DECLARE_TYPES(tile) { + getOpDescriptor() + ->setAllowedInputTypes(0, sd::DataType::ANY) + ->setAllowedInputTypes(1, {ALL_INTS}) + ->setAllowedOutputTypes(sd::DataType::ANY); +} DECLARE_SHAPE_FN(tile) { - - auto inShape = inputShape->at(0); - const int inRank = inShape[0]; - std::vector reps; - - if (block.getIArguments()->size() == inRank) { - - reps = ArrayUtils::toLongVector(*(block.getIArguments())); - } - else if (block.width() > 1) { - - auto reps_vector = INPUT_VARIABLE(1); - REQUIRE_TRUE(reps_vector->lengthOf() == inRank, 0, "TILE op: repeats vector length should be equal to input rank, but got %i and %i correspondingly !", reps_vector->lengthOf(), inRank); - reps = reps_vector->template asVectorT(); - } - else { - REQUIRE_TRUE(false, 0, "TILE op: this op requires repeats vector, either as IArgs or second array with length equal to rank of input array to be tiled !"); - } - - auto repProd = shape::prodLong(reps.data(), reps.size()); - REQUIRE_TRUE(repProd > 0, 0, "TILE op: reps can't contain 0s"); - - std::vector shape(inRank); - for (int e = 0; e < shape::rank(inShape); e++) - shape[e] = shape::sizeAt(inShape, e) * reps[e]; - - auto newShape = ConstantShapeHelper::getInstance().createShapeInfo(ArrayOptions::dataType(inShape), shape::order(inShape), shape); - return SHAPELIST(newShape); + auto inShape = inputShape->at(0); + const int inRank = inShape[0]; + std::vector reps; + + if (block.numI() == inRank) { + reps = ArrayUtils::toLongVector(block.getIArguments()); + } else if (block.width() > 1) { + auto reps_vector = INPUT_VARIABLE(1); + REQUIRE_TRUE(reps_vector->lengthOf() == inRank, 0, + "TILE op: repeats vector length should be equal to input " + "rank, but got %i and %i correspondingly !", + reps_vector->lengthOf(), inRank); + reps = reps_vector->template asVectorT(); + } else { + REQUIRE_TRUE( + false, 0, + "TILE op: this op requires repeats vector, either as IArgs or second " + "array with length equal to rank of input array to be tiled !"); + } + + auto repProd = shape::prodLong(reps.data(), reps.size()); + REQUIRE_TRUE(repProd > 0, 0, "TILE op: reps can't contain 0s"); + + std::vector shape(inRank); + for (int e = 0; e < shape::rank(inShape); e++) + shape[e] = shape::sizeAt(inShape, e) * reps[e]; + + auto newShape = ConstantShapeHelper::getInstance().createShapeInfo( + ArrayOptions::dataType(inShape), shape::order(inShape), shape); + return SHAPELIST(newShape); } - //////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(tile_bp, 2, 1, false, 0, -2) { - - auto input = INPUT_VARIABLE(0); - auto gradO = INPUT_VARIABLE(1); - auto gradI = OUTPUT_VARIABLE(0); - - const int inRank = input->rankOf(); - - std::vector reps; - - if (block.getIArguments()->size() == inRank) { - - reps = ArrayUtils::toLongVector(*(block.getIArguments())); - } - else if (block.width() > 2) { - - auto reps_vector = INPUT_VARIABLE(1); - REQUIRE_TRUE(reps_vector->lengthOf() == inRank, 0, "TILE_BP op: repeats vector length should be equal to input rank, but got %i and %i correspondingly !", reps_vector->lengthOf(), inRank); - - reps = reps_vector->template asVectorT(); - gradO = INPUT_VARIABLE(2); - } - else { - REQUIRE_TRUE(false, 0, "TILE_BP op: this op requires repeats vector, either as IArgs or second array with length equal to rank of input array to be tiled !"); - } - - REQUIRE_TRUE(inRank == gradO->rankOf(), 0, "TILE_BP op: the ranks of input array and output's gradients array (next epsilon) must be equal, but got %i and %i correspondingly !", inRank, gradO->rankOf()); - - for (int i = 0; i < inRank; ++i) - REQUIRE_TRUE(gradO->sizeAt(i) == gradI->sizeAt(i) * reps[i], 0, "TILE_BP op: shapes of input array and output's gradients array (next epsilon) are inconsistent !"); - - helpers::tileBP(block.launchContext(), *gradO, *gradI, reps); - - return Status::OK(); + auto input = INPUT_VARIABLE(0); + auto gradO = INPUT_VARIABLE(1); + auto gradI = OUTPUT_VARIABLE(0); + + const int inRank = input->rankOf(); + + std::vector reps; + + if (block.numI() == inRank) { + reps = ArrayUtils::toLongVector(block.getIArguments()); + } else if (block.width() > 2) { + auto reps_vector = INPUT_VARIABLE(1); + REQUIRE_TRUE(reps_vector->lengthOf() == inRank, 0, + "TILE_BP op: repeats vector length should be equal to input " + "rank, but got %i and %i correspondingly !", + reps_vector->lengthOf(), inRank); + + reps = reps_vector->template asVectorT(); + gradO = INPUT_VARIABLE(2); + } else { + REQUIRE_TRUE( + false, 0, + "TILE_BP op: this op requires repeats vector, either as IArgs or " + "second array with length equal to rank of input array to be tiled !"); + } + + REQUIRE_TRUE( + inRank == gradO->rankOf(), 0, + "TILE_BP op: the ranks of input array and output's gradients array (next " + "epsilon) must be equal, but got %i and %i correspondingly !", + inRank, gradO->rankOf()); + + for (int i = 0; i < inRank; ++i) + REQUIRE_TRUE(gradO->sizeAt(i) == gradI->sizeAt(i) * reps[i], 0, + "TILE_BP op: shapes of input array and output's gradients " + "array (next epsilon) are inconsistent !"); + + helpers::tileBP(block.launchContext(), *gradO, *gradI, reps); + + return Status::OK(); } - DECLARE_TYPES(tile_bp) { - getOpDescriptor()->setAllowedInputTypes(0, {ALL_FLOATS}); - getOpDescriptor()->setAllowedInputTypes(1, {ALL_INTS, ALL_FLOATS}); - getOpDescriptor()->setAllowedInputTypes(2, {ALL_FLOATS}); +DECLARE_TYPES(tile_bp) { + getOpDescriptor()->setAllowedInputTypes(0, {ALL_FLOATS}); + getOpDescriptor()->setAllowedInputTypes(1, {ALL_INTS, ALL_FLOATS}); + getOpDescriptor()->setAllowedInputTypes(2, {ALL_FLOATS}); - getOpDescriptor()->setAllowedOutputTypes({ALL_FLOATS}); - } + getOpDescriptor()->setAllowedOutputTypes({ALL_FLOATS}); +} DECLARE_SHAPE_FN(tile_bp) { - - auto inShape = inputShape->at(0); - auto gradOShape = inputShape->at(1); - const int inRank = inShape[0]; - - std::vector reps; - - if (block.getIArguments()->size() == inRank) { - - reps = ArrayUtils::toLongVector(*(block.getIArguments())); - } - else if (block.width() > 2) { - auto reps_vector = INPUT_VARIABLE(1); - REQUIRE_TRUE(reps_vector->lengthOf() == inRank, 0, "TILE_BP op: repeats vector length should be equal to input rank, but got %i and %i correspondingly !", reps_vector->lengthOf(), inRank); - reps = reps_vector->template asVectorT(); - gradOShape = inputShape->at(2); - } - else { - REQUIRE_TRUE(false, 0, "TILE_BP op: this op requires repeats vector, either as IArgs or second array with length equal to rank of input array to be tiled !"); - } - - REQUIRE_TRUE(inRank == gradOShape[0], 0, "TILE_BP op: the ranks of input array and output's gradients array (next epsilon) must be equal, but got %i and %i correspondingly !", inRank, gradOShape[0]); - - for (int i = 0; i < inRank; ++i) - REQUIRE_TRUE(shape::sizeAt(gradOShape, i) == shape::sizeAt(inShape, i) * reps[i], 0, "TILE_BP op: shapes of input array and output's gradients array (next epsilon) are inconsistent !"); - - Nd4jLong *gradIShape; - COPY_SHAPE(inShape, gradIShape); - - return SHAPELIST(CONSTANT(gradIShape)); - + auto inShape = inputShape->at(0); + auto gradOShape = inputShape->at(1); + const int inRank = inShape[0]; + + std::vector reps; + + if (block.numI() == inRank) { + reps = ArrayUtils::toLongVector(block.getIArguments()); + } else if (block.width() > 2) { + auto reps_vector = INPUT_VARIABLE(1); + REQUIRE_TRUE(reps_vector->lengthOf() == inRank, 0, + "TILE_BP op: repeats vector length should be equal to input " + "rank, but got %i and %i correspondingly !", + reps_vector->lengthOf(), inRank); + reps = reps_vector->template asVectorT(); + gradOShape = inputShape->at(2); + } else { + REQUIRE_TRUE( + false, 0, + "TILE_BP op: this op requires repeats vector, either as IArgs or " + "second array with length equal to rank of input array to be tiled !"); + } + + REQUIRE_TRUE( + inRank == gradOShape[0], 0, + "TILE_BP op: the ranks of input array and output's gradients array (next " + "epsilon) must be equal, but got %i and %i correspondingly !", + inRank, gradOShape[0]); + + for (int i = 0; i < inRank; ++i) + REQUIRE_TRUE( + shape::sizeAt(gradOShape, i) == shape::sizeAt(inShape, i) * reps[i], 0, + "TILE_BP op: shapes of input array and output's gradients array (next " + "epsilon) are inconsistent !"); + + Nd4jLong *gradIShape; + COPY_SHAPE(inShape, gradIShape); + + return SHAPELIST(CONSTANT(gradIShape)); } - -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/transforms/unstack.cpp b/libnd4j/include/ops/declarable/generic/transforms/unstack.cpp index 0dfe1e54cd59..ff6a0729eaf8 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/unstack.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/unstack.cpp @@ -23,107 +23,111 @@ #if NOT_EXCLUDED(OP_unstack) #include -#include +#include namespace sd { -namespace ops { +namespace ops { CUSTOM_OP_IMPL(unstack, 1, -1, false, 0, 1) { + auto input = INPUT_VARIABLE(0); - auto input = INPUT_VARIABLE(0); + auto dim = INT_ARG(0); + if (dim < 0) dim += input->rankOf(); - auto dim = INT_ARG(0); - if (dim < 0) - dim += input->rankOf(); + REQUIRE_TRUE(dim < input->rankOf(), 0, + "Unstack dimension should be lower then rank of input %i, but " + "got dimension=%i !", + input->rankOf(), dim); + REQUIRE_TRUE(dim >= 0, 0, + "Unstack dimension should be non-negative value, but got %i !", + dim); + if (input->isEmpty()) return Status::OK(); - REQUIRE_TRUE(dim < input->rankOf(), 0, "Unstack dimension should be lower then rank of input %i, but got dimension=%i !", input->rankOf(), dim); - REQUIRE_TRUE(dim >= 0, 0, "Unstack dimension should be non-negative value, but got %i !", dim); + std::vector outArrs(input->sizeAt(dim)); + for (uint i = 0; i < outArrs.size(); ++i) outArrs[i] = OUTPUT_VARIABLE(i); - if(input->isEmpty()) - return Status::OK(); + helpers::unstack(block.launchContext(), *input, outArrs, dim); - std::vector outArrs(input->sizeAt(dim)); - for(uint i = 0; i < outArrs.size(); ++i) - outArrs[i] = OUTPUT_VARIABLE(i); - - helpers::unstack(block.launchContext(), *input, outArrs, dim); - - return Status::OK(); + return Status::OK(); } DECLARE_SYN(unpack, unstack); DECLARE_SHAPE_FN(unstack) { - auto inShapeInfo = inputShape->at(0); - - auto dim = INT_ARG(0); - if (dim < 0) - dim += shape::rank(inShapeInfo); - - REQUIRE_TRUE(dim < inShapeInfo[0], 0, "UNSTACK op: dimension should be lower then rank of input %i, but got dimension=%i !", inShapeInfo[0], dim); - REQUIRE_TRUE(dim >= 0, 0, "UNSTACK op: dimension should be non-negative value, but got %i !", dim); - - if(ArrayOptions::arrayType(inShapeInfo) == ArrayType::EMPTY) { - - if(shape::shapeOf(inShapeInfo)[dim] == 0) - return SHAPELIST(); - - const Nd4jLong numTads = shape::shapeOf(inShapeInfo)[dim]; - std::vector outShape; - for(uint i = 0; i < shape::rank(inShapeInfo); ++i) - if(i != dim) - outShape.push_back(shape::shapeOf(inShapeInfo)[i]); - - auto result = SHAPELIST(); - for(uint i = 0; i < numTads; ++i) - result->push_back(ConstantShapeHelper::getInstance().createShapeInfo(ArrayOptions::dataType(inShapeInfo), shape::order(inShapeInfo), outShape)); + auto inShapeInfo = inputShape->at(0); - return result; - } + auto dim = INT_ARG(0); + if (dim < 0) dim += shape::rank(inShapeInfo); - std::vector dims = ShapeUtils::evalDimsToExclude(inShapeInfo[0], {dim}); + REQUIRE_TRUE(dim < inShapeInfo[0], 0, + "UNSTACK op: dimension should be lower then rank of input %i, " + "but got dimension=%i !", + inShapeInfo[0], dim); + REQUIRE_TRUE( + dim >= 0, 0, + "UNSTACK op: dimension should be non-negative value, but got %i !", dim); - if (dims.size() == 0 && shape::rank(inShapeInfo) == 1) { // split vector into lenthOf scalars + if (ArrayOptions::arrayType(inShapeInfo) == ArrayType::EMPTY) { + if (shape::shapeOf(inShapeInfo)[dim] == 0) return SHAPELIST(); - auto result = SHAPELIST(); - for (Nd4jLong e = 0; e < shape::length(inShapeInfo); e++) - result->push_back(ConstantShapeHelper::getInstance().scalarShapeInfo(ArrayOptions::dataType(inShapeInfo))); + const Nd4jLong numTads = shape::shapeOf(inShapeInfo)[dim]; + std::vector outShape; + for (uint i = 0; i < shape::rank(inShapeInfo); ++i) + if (i != dim) outShape.push_back(shape::shapeOf(inShapeInfo)[i]); - return result; - } - - std::vector subArrShape(shape::rank(inShapeInfo) - 1); + auto result = SHAPELIST(); + for (uint i = 0; i < numTads; ++i) + result->push_back(ConstantShapeHelper::getInstance().createShapeInfo( + ArrayOptions::dataType(inShapeInfo), shape::order(inShapeInfo), + outShape)); - for(uint j = 0, i = 0; i < shape::rank(inShapeInfo); i++) - if(i != dim) - subArrShape[j++] = shape::shapeOf(inShapeInfo)[i]; + return result; + } - // remove leading and trailing 1 - if (inShapeInfo[0] == 2 && subArrShape.size() == 2) { + std::vector dims = ShapeUtils::evalDimsToExclude(inShapeInfo[0], {dim}); - if (subArrShape[0] == 1) - subArrShape.erase(subArrShape.begin()); - else if (subArrShape[1] == 1) - subArrShape.erase(subArrShape.end()); - } + if (dims.size() == 0 && + shape::rank(inShapeInfo) == 1) { // split vector into lenthOf scalars auto result = SHAPELIST(); - for (int e = 0; e < shape::shapeOf(inShapeInfo)[dim]; e++) { - auto newShape = ConstantShapeHelper::getInstance().createShapeInfo(ArrayOptions::dataType(inShapeInfo), shape::order(inShapeInfo), subArrShape); - result->push_back(newShape); - } + for (Nd4jLong e = 0; e < shape::length(inShapeInfo); e++) + result->push_back(ConstantShapeHelper::getInstance().scalarShapeInfo( + ArrayOptions::dataType(inShapeInfo))); + return result; + } + + std::vector subArrShape(shape::rank(inShapeInfo) - 1); + + for (uint j = 0, i = 0; i < shape::rank(inShapeInfo); i++) + if (i != dim) subArrShape[j++] = shape::shapeOf(inShapeInfo)[i]; + + // remove leading and trailing 1 + if (inShapeInfo[0] == 2 && subArrShape.size() == 2) { + if (subArrShape[0] == 1) + subArrShape.erase(subArrShape.begin()); + else if (subArrShape[1] == 1) + subArrShape.erase(subArrShape.end()); + } + + auto result = SHAPELIST(); + for (int e = 0; e < shape::shapeOf(inShapeInfo)[dim]; e++) { + auto newShape = ConstantShapeHelper::getInstance().createShapeInfo( + ArrayOptions::dataType(inShapeInfo), shape::order(inShapeInfo), + subArrShape); + result->push_back(newShape); + } + return result; } DECLARE_TYPES(unstack) { - getOpDescriptor() - ->setAllowedInputTypes({ALL_FLOATS, ALL_INTS}) - ->setSameMode(true); + getOpDescriptor() + ->setAllowedInputTypes({ALL_FLOATS, ALL_INTS}) + ->setSameMode(true); } - -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/tsne/cell_contains.cpp b/libnd4j/include/ops/declarable/generic/tsne/cell_contains.cpp index e176797b0dc3..26b103af4bea 100644 --- a/libnd4j/include/ops/declarable/generic/tsne/cell_contains.cpp +++ b/libnd4j/include/ops/declarable/generic/tsne/cell_contains.cpp @@ -25,30 +25,31 @@ #include namespace sd { - namespace ops { - - CUSTOM_OP_IMPL(cell_contains, 3, 1, false, 0, 1) { - auto corner = INPUT_VARIABLE(0); - auto width = INPUT_VARIABLE(1); - auto point = INPUT_VARIABLE(2); - - auto output = OUTPUT_VARIABLE(0); - auto dimension = INT_ARG(0); - output->assign(helpers::cell_contains(corner, width, point, dimension)); - return Status::OK(); - } - - DECLARE_TYPES(cell_contains) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes(sd::DataType::BOOL) - ->setSameMode(false); - } - - DECLARE_SHAPE_FN(cell_contains) { - return SHAPELIST(CONSTANT(ShapeBuilders::createScalarShapeInfo(sd::DataType::BOOL, block.workspace()))); - } - } +namespace ops { + +CUSTOM_OP_IMPL(cell_contains, 3, 1, false, 0, 1) { + auto corner = INPUT_VARIABLE(0); + auto width = INPUT_VARIABLE(1); + auto point = INPUT_VARIABLE(2); + + auto output = OUTPUT_VARIABLE(0); + auto dimension = INT_ARG(0); + output->assign(helpers::cell_contains(corner, width, point, dimension)); + return Status::OK(); +} + +DECLARE_TYPES(cell_contains) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes(sd::DataType::BOOL) + ->setSameMode(false); +} + +DECLARE_SHAPE_FN(cell_contains) { + return SHAPELIST(CONSTANT(ShapeBuilders::createScalarShapeInfo( + sd::DataType::BOOL, block.workspace()))); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/tsne/edge_force.cpp b/libnd4j/include/ops/declarable/generic/tsne/edge_force.cpp index 64be499fb719..30da9ad68af8 100644 --- a/libnd4j/include/ops/declarable/generic/tsne/edge_force.cpp +++ b/libnd4j/include/ops/declarable/generic/tsne/edge_force.cpp @@ -25,45 +25,55 @@ #include namespace sd { -namespace ops { - - CUSTOM_OP_IMPL(barnes_edge_forces, 4, 1, false, 0, 1) { - auto rowP = INPUT_VARIABLE(0); - auto colP = INPUT_VARIABLE(1); - auto valP = INPUT_VARIABLE(2); - auto dataP = INPUT_VARIABLE(3); - auto N = INT_ARG(0); +namespace ops { - auto output = OUTPUT_NULLIFIED(0); +CUSTOM_OP_IMPL(barnes_edge_forces, 4, 1, false, 0, 1) { + auto rowP = INPUT_VARIABLE(0); + auto colP = INPUT_VARIABLE(1); + auto valP = INPUT_VARIABLE(2); + auto dataP = INPUT_VARIABLE(3); + auto N = INT_ARG(0); - REQUIRE_TRUE(rowP->isVector(), 0, "barnes_edge_force: row input must be a vector, but its rank is %i instead !", rowP->rankOf()); - REQUIRE_TRUE(colP->isVector(), 0, "barnes_edge_force: col input must be a vector, but its rank is %i instead !", colP->rankOf()); - REQUIRE_TRUE(dataP->dataType() == output->dataType() && dataP->dataType() == valP->dataType(), 0, "barnes_edge_force: data type of dataP, valP and output must be the same"); + auto output = OUTPUT_NULLIFIED(0); - helpers::barnes_edge_forces(rowP, colP, valP, N, output, *dataP); + REQUIRE_TRUE(rowP->isVector(), 0, + "barnes_edge_force: row input must be a vector, but its rank is " + "%i instead !", + rowP->rankOf()); + REQUIRE_TRUE(colP->isVector(), 0, + "barnes_edge_force: col input must be a vector, but its rank is " + "%i instead !", + colP->rankOf()); + REQUIRE_TRUE(dataP->dataType() == output->dataType() && + dataP->dataType() == valP->dataType(), + 0, + "barnes_edge_force: data type of dataP, valP and output must be " + "the same"); - return Status::OK(); - } - - DECLARE_TYPES(barnes_edge_forces) { - getOpDescriptor() - ->setAllowedInputTypes(0, {ALL_INTS}) - ->setAllowedInputTypes(1, {ALL_INTS}) - ->setAllowedInputTypes(2, {ALL_FLOATS}) - ->setAllowedInputTypes(3, {ALL_FLOATS}) - ->setAllowedOutputTypes(0, {ALL_FLOATS}) - ->setSameMode(false); - } - - DECLARE_SHAPE_FN(barnes_edge_forces) { - Nd4jLong* bufShape; - Nd4jLong* outShapeInfo; - outShapeInfo = ShapeBuilders::copyShapeInfoAndType(inputShape->at(3), inputShape->at(3), false, block.getWorkspace()); - return SHAPELIST(CONSTANT(outShapeInfo)); - } + helpers::barnes_edge_forces(rowP, colP, valP, N, output, *dataP); + return Status::OK(); +} +DECLARE_TYPES(barnes_edge_forces) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_INTS}) + ->setAllowedInputTypes(1, {ALL_INTS}) + ->setAllowedInputTypes(2, {ALL_FLOATS}) + ->setAllowedInputTypes(3, {ALL_FLOATS}) + ->setAllowedOutputTypes(0, {ALL_FLOATS}) + ->setSameMode(false); } + +DECLARE_SHAPE_FN(barnes_edge_forces) { + Nd4jLong* bufShape; + Nd4jLong* outShapeInfo; + outShapeInfo = ShapeBuilders::copyShapeInfoAndType( + inputShape->at(3), inputShape->at(3), false, block.workspace()); + return SHAPELIST(CONSTANT(outShapeInfo)); } +} // namespace ops +} // namespace sd + #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/tsne/gains.cpp b/libnd4j/include/ops/declarable/generic/tsne/gains.cpp index 4fb943483b79..87cd08bd398e 100644 --- a/libnd4j/include/ops/declarable/generic/tsne/gains.cpp +++ b/libnd4j/include/ops/declarable/generic/tsne/gains.cpp @@ -25,25 +25,23 @@ #include namespace sd { -namespace ops { - - OP_IMPL(barnes_gains, 3, 1, true) { - auto input = INPUT_VARIABLE(0); - auto gradX = INPUT_VARIABLE(1); - auto epsilon = INPUT_VARIABLE(2); - - auto output = OUTPUT_VARIABLE(0); - - helpers::barnes_gains(input, gradX, epsilon, output); - return Status::OK(); - } - - DECLARE_TYPES(barnes_gains) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setSameMode(true); - } +namespace ops { + +OP_IMPL(barnes_gains, 3, 1, true) { + auto input = INPUT_VARIABLE(0); + auto gradX = INPUT_VARIABLE(1); + auto epsilon = INPUT_VARIABLE(2); + + auto output = OUTPUT_VARIABLE(0); + + helpers::barnes_gains(input, gradX, epsilon, output); + return Status::OK(); } + +DECLARE_TYPES(barnes_gains) { + getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setSameMode(true); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/tsne/symmetrized.cpp b/libnd4j/include/ops/declarable/generic/tsne/symmetrized.cpp index cf3675122283..016ea38553b7 100644 --- a/libnd4j/include/ops/declarable/generic/tsne/symmetrized.cpp +++ b/libnd4j/include/ops/declarable/generic/tsne/symmetrized.cpp @@ -14,9 +14,9 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - // - // @author George A. Shulinok , created on 4/18/2019. - // +// +// @author George A. Shulinok , created on 4/18/2019. +// #include #if NOT_EXCLUDED(OP_barnes_symmetrized) @@ -25,68 +25,80 @@ #include namespace sd { - namespace ops { - NDArray* rowCountsPtr = nullptr; +namespace ops { +NDArray* rowCountsPtr = nullptr; - CUSTOM_OP_IMPL(barnes_symmetrized, 3, 3, false, 0, -1) { - auto rowP = INPUT_VARIABLE(0); - auto colP = INPUT_VARIABLE(1); - auto valP = INPUT_VARIABLE(2); - auto N = rowP->lengthOf() - 1; - auto outputRows = OUTPUT_VARIABLE(0); - auto outputCols = OUTPUT_VARIABLE(1); - auto outputVals = OUTPUT_VARIABLE(2); +CUSTOM_OP_IMPL(barnes_symmetrized, 3, 3, false, 0, -1) { + auto rowP = INPUT_VARIABLE(0); + auto colP = INPUT_VARIABLE(1); + auto valP = INPUT_VARIABLE(2); + auto N = rowP->lengthOf() - 1; + auto outputRows = OUTPUT_VARIABLE(0); + auto outputCols = OUTPUT_VARIABLE(1); + auto outputVals = OUTPUT_VARIABLE(2); - if (block.getIArguments()->size() > 0) - N = INT_ARG(0); + if (block.numI() > 0) N = INT_ARG(0); - if (rowCountsPtr) { - helpers::barnes_symmetrize(rowP, colP, valP, N, outputRows, outputCols, outputVals, rowCountsPtr); - delete rowCountsPtr; - return Status::OK(); - } - return Status::THROW("barnes_symmetrized: Cannot loop due wrong input data."); - } + if (rowCountsPtr) { + helpers::barnes_symmetrize(rowP, colP, valP, N, outputRows, outputCols, + outputVals, rowCountsPtr); + delete rowCountsPtr; + return Status::OK(); + } + return Status::THROW("barnes_symmetrized: Cannot loop due wrong input data."); +} - DECLARE_TYPES(barnes_symmetrized) { - getOpDescriptor() - ->setAllowedInputTypes(0, { DataType::INT32 }) - ->setAllowedInputTypes(1, { DataType::INT32 }) - ->setAllowedInputTypes(2, { ALL_INTS, ALL_FLOATS }) - ->setAllowedOutputTypes(1, { DataType::INT32 }) - ->setAllowedOutputTypes(1, { DataType::INT32 }) - ->setAllowedOutputTypes(2, { ALL_INTS, ALL_FLOATS }) - ->setSameMode(false); - } +DECLARE_TYPES(barnes_symmetrized) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::INT32) + ->setAllowedInputTypes(1, DataType::INT32) + ->setAllowedInputTypes(2, {ALL_INTS, ALL_FLOATS}) + ->setAllowedOutputTypes(1, DataType::INT32) + ->setAllowedOutputTypes(1, DataType::INT32) + ->setAllowedOutputTypes(2, {ALL_INTS, ALL_FLOATS}) + ->setSameMode(false); +} - DECLARE_SHAPE_FN(barnes_symmetrized) { - auto valPShapeInfo = inputShape->at(2); - Nd4jLong* outShapeInfo; - auto rowP = INPUT_VARIABLE(0); - auto colP = INPUT_VARIABLE(1); - auto N = rowP->lengthOf() - 1; - if (block.getIArguments()->size() > 0) - N = INT_ARG(0); - auto dataType = rowP->dataType(); //ArrayOptions::dataType(inputShape->at(0)); - NDArray* rowCounts = NDArrayFactory::create_('c', { N }, block.launchContext()); //rowP->dup(); - //srowCounts->assign(0); - Nd4jLong len = helpers::barnes_row_count(rowP, colP, N, *rowCounts); - rowCounts->syncToHost(); - // rowCounts->printBuffer("Row Counts"); - if (len <= 0) throw std::runtime_error("barnes_symmetrized: Cannot allocate shape due non-positive len."); - rowCountsPtr = rowCounts; - //ALLOCATE(outShapeInfo, block.workspace(), shape::shapeInfoLength(2), Nd4jLong); -// outShapeInfo[1] = 1; -// outShapeInfo[2] = len; - // ShapeUtils::updateStridesAndType(outShapeInfo, ArrayOptions::dataType(valPShapeInfo), 'c'); - //outShapeInfo = ShapeBuilders::createVectorShapeInfo(ArrayOptions::dataType(valPShapeInfo), len, block.workspace()); - outShapeInfo = sd::ShapeBuilders::createShapeInfo(ArrayOptions::dataType(valPShapeInfo), 'c', { 1, len }, block.getWorkspace()); - auto outColsShapeInfo = sd::ShapeBuilders::createShapeInfo(dataType, 'c', { 1, len }, block.getWorkspace()); - auto outRowsShapeInfo = sd::ShapeBuilders::createShapeInfo(dataType, 'c', { 1, N + 1 }, block.getWorkspace()); - return SHAPELIST(CONSTANT(outRowsShapeInfo), CONSTANT(outColsShapeInfo), CONSTANT(outShapeInfo)); - } +DECLARE_SHAPE_FN(barnes_symmetrized) { + auto valPShapeInfo = inputShape->at(2); + Nd4jLong* outShapeInfo; + auto rowP = INPUT_VARIABLE(0); + auto colP = INPUT_VARIABLE(1); + auto N = rowP->lengthOf() - 1; + if (block.numI() > 0) N = INT_ARG(0); - } + auto dataType = + rowP->dataType(); // ArrayOptions::dataType(inputShape->at(0)); + NDArray* rowCounts = NDArrayFactory::create_( + 'c', {N}, block.launchContext()); // rowP->dup(); + // srowCounts->assign(0); + Nd4jLong len = helpers::barnes_row_count(rowP, colP, N, *rowCounts); + rowCounts->syncToHost(); + // rowCounts->printBuffer("Row Counts"); + if (len <= 0) + throw std::runtime_error( + "barnes_symmetrized: Cannot allocate shape due non-positive len."); + rowCountsPtr = rowCounts; + // ALLOCATE(outShapeInfo, block.workspace(), shape::shapeInfoLength(2), + // Nd4jLong); + // outShapeInfo[1] = 1; + // outShapeInfo[2] = len; + // ShapeUtils::updateStridesAndType(outShapeInfo, + // ArrayOptions::dataType(valPShapeInfo), 'c'); + // outShapeInfo = + // ShapeBuilders::createVectorShapeInfo(ArrayOptions::dataType(valPShapeInfo), + // len, block.workspace()); + outShapeInfo = sd::ShapeBuilders::createShapeInfo( + ArrayOptions::dataType(valPShapeInfo), 'c', {1, len}, block.workspace()); + auto outColsShapeInfo = sd::ShapeBuilders::createShapeInfo( + dataType, 'c', {1, len}, block.workspace()); + auto outRowsShapeInfo = sd::ShapeBuilders::createShapeInfo( + dataType, 'c', {1, N + 1}, block.workspace()); + return SHAPELIST(CONSTANT(outRowsShapeInfo), CONSTANT(outColsShapeInfo), + CONSTANT(outShapeInfo)); } +} // namespace ops +} // namespace sd + #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/updaters/adaDeltaUpdater.cpp b/libnd4j/include/ops/declarable/generic/updaters/adaDeltaUpdater.cpp index 93f01ae1fb0f..f6fe9f0afac3 100644 --- a/libnd4j/include/ops/declarable/generic/updaters/adaDeltaUpdater.cpp +++ b/libnd4j/include/ops/declarable/generic/updaters/adaDeltaUpdater.cpp @@ -14,68 +14,81 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - // - // @author Oleh Semeniv (oleg.semeniv@gmail.com) - // +// +// @author Oleh Semeniv (oleg.semeniv@gmail.com) +// -#include -#include -#include -#include #include +#include +#include +#include +#include namespace sd { - namespace ops { - - CONFIGURABLE_OP_IMPL(ada_delta_updater, 3, 3, true, 0, 0) { - - const auto gradient = INPUT_VARIABLE(0); - const auto initStateMsg = INPUT_VARIABLE(1); - const auto initStateMsdx = INPUT_VARIABLE(2); - - auto update = OUTPUT_VARIABLE(0); - auto stateMsg = OUTPUT_VARIABLE(1); - auto stateMsdx = OUTPUT_VARIABLE(2); - - if (gradient->isEmpty() || initStateMsg->isEmpty() || initStateMsdx->isEmpty()) - return Status::OK(); - - REQUIRE_TRUE(gradient->isSameShape(initStateMsg), 0, "ADA_DELTA UPDATER OP: input state Msg must have the same shape as gradient," - " expected shape %s, but got %s!", ShapeUtils::shapeAsString(gradient->shapeInfo()).c_str(), - ShapeUtils::shapeAsString(initStateMsg->shapeInfo()).c_str()); - REQUIRE_TRUE(gradient->isSameShape(initStateMsdx), 0, "ADA_DELTA UPDATER OP: input state Msdx must have the same shape as gradient," - " expected shape %s, but got %s!", ShapeUtils::shapeAsString(gradient->shapeInfo()).c_str(), - ShapeUtils::shapeAsString(initStateMsdx->shapeInfo()).c_str()); - - bool bParamsSupply = 5 == block.width() || 2 == block.getTArguments()->size(); - - REQUIRE_TRUE(bParamsSupply, 0, "ADA_DELTA UPDATER OP: Rho and epsilon were not provided!"); - - double dRho, dEpsilon; - - if (block.width() > 3) { - const auto rho = INPUT_VARIABLE(3); - const auto epsilon = INPUT_VARIABLE(4); - - REQUIRE_TRUE(rho->isScalar(), 0, "ADA_DELTA UPDATER OP: Rho has to be a scalar, but instead got rank %i!", rho->rankOf()); - REQUIRE_TRUE(epsilon->isScalar(), 0, "ADA_DELTA UPDATER OP: Epsilon has to be a scalar, but instead got rank %i!", epsilon->rankOf()); - - dRho = rho->e(0); - dEpsilon = epsilon->e(0); - } - else { - dRho = T_ARG(0); - dEpsilon = T_ARG(1); - } - - helpers::updaterAdaDelta(block.launchContext(), *gradient, *initStateMsg, *initStateMsdx, *update, *stateMsg, *stateMsdx, dRho, dEpsilon); - return Status::OK(); - } - - DECLARE_TYPES(ada_delta_updater) { - getOpDescriptor()->setAllowedInputTypes({ ALL_FLOATS }) - ->setSameMode(true); - } +namespace ops { + +CONFIGURABLE_OP_IMPL(ada_delta_updater, 3, 3, true, 0, 0) { + const auto gradient = INPUT_VARIABLE(0); + const auto initStateMsg = INPUT_VARIABLE(1); + const auto initStateMsdx = INPUT_VARIABLE(2); + + auto update = OUTPUT_VARIABLE(0); + auto stateMsg = OUTPUT_VARIABLE(1); + auto stateMsdx = OUTPUT_VARIABLE(2); + + if (gradient->isEmpty() || initStateMsg->isEmpty() || + initStateMsdx->isEmpty()) + return Status::OK(); + + REQUIRE_TRUE(gradient->isSameShape(initStateMsg), 0, + "ADA_DELTA UPDATER OP: input state Msg must have the same shape " + "as gradient," + " expected shape %s, but got %s!", + ShapeUtils::shapeAsString(gradient->shapeInfo()).c_str(), + ShapeUtils::shapeAsString(initStateMsg->shapeInfo()).c_str()); + REQUIRE_TRUE(gradient->isSameShape(initStateMsdx), 0, + "ADA_DELTA UPDATER OP: input state Msdx must have the same " + "shape as gradient," + " expected shape %s, but got %s!", + ShapeUtils::shapeAsString(gradient->shapeInfo()).c_str(), + ShapeUtils::shapeAsString(initStateMsdx->shapeInfo()).c_str()); + + bool bParamsSupply = 5 == block.width() || 2 == block.numT(); + + REQUIRE_TRUE(bParamsSupply, 0, + "ADA_DELTA UPDATER OP: Rho and epsilon were not provided!"); + + double dRho, dEpsilon; + + if (block.width() > 3) { + const auto rho = INPUT_VARIABLE(3); + const auto epsilon = INPUT_VARIABLE(4); + + REQUIRE_TRUE(rho->isScalar(), 0, + "ADA_DELTA UPDATER OP: Rho has to be a scalar, but instead " + "got rank %i!", + rho->rankOf()); + REQUIRE_TRUE(epsilon->isScalar(), 0, + "ADA_DELTA UPDATER OP: Epsilon has to be a scalar, but " + "instead got rank %i!", + epsilon->rankOf()); + + dRho = rho->e(0); + dEpsilon = epsilon->e(0); + } else { + dRho = T_ARG(0); + dEpsilon = T_ARG(1); + } + + helpers::updaterAdaDelta(block.launchContext(), *gradient, *initStateMsg, + *initStateMsdx, *update, *stateMsg, *stateMsdx, dRho, + dEpsilon); + return Status::OK(); +} - } +DECLARE_TYPES(ada_delta_updater) { + getOpDescriptor()->setAllowedInputTypes({ALL_FLOATS})->setSameMode(true); } + +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/generic/updaters/adaGradUpdater.cpp b/libnd4j/include/ops/declarable/generic/updaters/adaGradUpdater.cpp index 4cd5b0504142..569381cc5b73 100644 --- a/libnd4j/include/ops/declarable/generic/updaters/adaGradUpdater.cpp +++ b/libnd4j/include/ops/declarable/generic/updaters/adaGradUpdater.cpp @@ -14,64 +14,71 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - // - // @author Oleh Semeniv (oleg.semeniv@gmail.com) - // +// +// @author Oleh Semeniv (oleg.semeniv@gmail.com) +// -#include -#include -#include -#include #include +#include +#include +#include +#include namespace sd { - namespace ops { - - CONFIGURABLE_OP_IMPL(ada_grad_updater, 2, 2, true, 0, 0) { - - const auto gradient = INPUT_VARIABLE(0); - const auto initState = INPUT_VARIABLE(1); - - auto update = OUTPUT_VARIABLE(0); - auto stateH = OUTPUT_VARIABLE(1); - - if (gradient->isEmpty() || initState->isEmpty()) - return Status::OK(); - - REQUIRE_TRUE(gradient->isSameShape(initState), 0, "ADA_GRAD UPDATER OP: input state must have the same shape as gradient," - " expected shape %s, but got %s!", ShapeUtils::shapeAsString(gradient->shapeInfo()).c_str(), - ShapeUtils::shapeAsString(initState->shapeInfo()).c_str()); - - - bool bParamsSupply = 4 == block.width() || 2 == block.getTArguments()->size(); - - REQUIRE_TRUE(bParamsSupply, 0, "ADA_GRAD UPDATER OP: learning rate and epsilon were not provided!"); - - double dLr, dEpsilon; - - if (block.width() > 2) { - const auto lr = INPUT_VARIABLE(2); - const auto epsilon = INPUT_VARIABLE(3); - - REQUIRE_TRUE(lr->isScalar(), 0, "ADA_GRAD UPDATER OP: Learning rate has to be a scalar, but instead got rank %i!", lr->rankOf()); - REQUIRE_TRUE(epsilon->isScalar(), 0, "ADA_GRAD UPDATER OP: Epsilon has to be a scalar, but instead got rank %i!", epsilon->rankOf()); - - dLr = lr->e(0); - dEpsilon = epsilon->e(0); - } - else { - dLr = T_ARG(0); - dEpsilon = T_ARG(1); - } - - helpers::updaterAdaGrad(block.launchContext(), *gradient, *initState, *update, *stateH, dLr, dEpsilon); - return Status::OK(); - } - - DECLARE_TYPES(ada_grad_updater) { - getOpDescriptor()->setAllowedInputTypes({ ALL_FLOATS }) - ->setSameMode(true); - } +namespace ops { + +CONFIGURABLE_OP_IMPL(ada_grad_updater, 2, 2, true, 0, 0) { + const auto gradient = INPUT_VARIABLE(0); + const auto initState = INPUT_VARIABLE(1); + + auto update = OUTPUT_VARIABLE(0); + auto stateH = OUTPUT_VARIABLE(1); + + if (gradient->isEmpty() || initState->isEmpty()) return Status::OK(); + + REQUIRE_TRUE( + gradient->isSameShape(initState), 0, + "ADA_GRAD UPDATER OP: input state must have the same shape as gradient," + " expected shape %s, but got %s!", + ShapeUtils::shapeAsString(gradient->shapeInfo()).c_str(), + ShapeUtils::shapeAsString(initState->shapeInfo()).c_str()); + + bool bParamsSupply = 4 == block.width() || 2 == block.numT(); + + REQUIRE_TRUE( + bParamsSupply, 0, + "ADA_GRAD UPDATER OP: learning rate and epsilon were not provided!"); + + double dLr, dEpsilon; + + if (block.width() > 2) { + const auto lr = INPUT_VARIABLE(2); + const auto epsilon = INPUT_VARIABLE(3); + + REQUIRE_TRUE(lr->isScalar(), 0, + "ADA_GRAD UPDATER OP: Learning rate has to be a scalar, but " + "instead got rank %i!", + lr->rankOf()); + REQUIRE_TRUE(epsilon->isScalar(), 0, + "ADA_GRAD UPDATER OP: Epsilon has to be a scalar, but instead " + "got rank %i!", + epsilon->rankOf()); + + dLr = lr->e(0); + dEpsilon = epsilon->e(0); + } else { + dLr = T_ARG(0); + dEpsilon = T_ARG(1); + } + + helpers::updaterAdaGrad(block.launchContext(), *gradient, *initState, *update, + *stateH, dLr, dEpsilon); + return Status::OK(); +} - } +DECLARE_TYPES(ada_grad_updater) { + getOpDescriptor()->setAllowedInputTypes({ALL_FLOATS})->setSameMode(true); } + +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/generic/updaters/adaMaxUpdater.cpp b/libnd4j/include/ops/declarable/generic/updaters/adaMaxUpdater.cpp index 9f4bb574b4a0..8a9d98289112 100644 --- a/libnd4j/include/ops/declarable/generic/updaters/adaMaxUpdater.cpp +++ b/libnd4j/include/ops/declarable/generic/updaters/adaMaxUpdater.cpp @@ -14,80 +14,98 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - // - // @author Oleh Semeniv (oleg.semeniv@gmail.com) - // +// +// @author Oleh Semeniv (oleg.semeniv@gmail.com) +// -#include -#include -#include -#include #include +#include +#include +#include +#include namespace sd { - namespace ops { - - CONFIGURABLE_OP_IMPL(ada_max_updater, 3, 3, true, 0, 0) { - - const auto gradient = INPUT_VARIABLE(0); - const auto initStateU = INPUT_VARIABLE(1); - const auto initStateM = INPUT_VARIABLE(2); - - auto update = OUTPUT_VARIABLE(0); - auto stateU = OUTPUT_VARIABLE(1); - auto stateM = OUTPUT_VARIABLE(2); - - // todo maybe we need an error like on Java side - if (gradient->isEmpty() || initStateU->isEmpty() || initStateM->isEmpty()) - return Status::OK(); - - REQUIRE_TRUE(gradient->isSameShape(initStateU), 0, "ADA_MAX UPDATER OP: input state V must have the same shape as gradient," - " expected shape %s, but got %s!", ShapeUtils::shapeAsString(gradient->shapeInfo()).c_str(), - ShapeUtils::shapeAsString(initStateU->shapeInfo()).c_str()); - REQUIRE_TRUE(gradient->isSameShape(initStateM), 0, "ADA_MAX UPDATER OP: input state M must have the same shape as gradient," - " expected shape %s, but got %s!", ShapeUtils::shapeAsString(gradient->shapeInfo()).c_str(), - ShapeUtils::shapeAsString(initStateM->shapeInfo()).c_str()); - - - bool bParamsSupply = 7 == block.width() || 4 == block.getTArguments()->size(); - - int iteration = block.getIArguments()->size() > 0 ? INT_ARG(0) : 0; - - REQUIRE_TRUE(bParamsSupply, 0, "ADA_MAX UPDATER OP: learning rate, beta 1, beta 2 and epsilon were not provided!"); - - double dLr, dBeta1, dBeta2, dEpsilon; - - if (block.width() > 3) { - const auto lr = INPUT_VARIABLE(3); - const auto beta1 = INPUT_VARIABLE(4); - const auto beta2 = INPUT_VARIABLE(5); - const auto epsilon = INPUT_VARIABLE(6); - - REQUIRE_TRUE(lr->isScalar(), 0, "ADA_MAX UPDATER OP: Learning rate has to be a scalar, but instead got rank %i!", lr->rankOf()); - REQUIRE_TRUE(beta1->isScalar(), 0, "ADA_MAX UPDATER OP: beta 1 has to be a scalar, but instead got rank %i!", beta1->rankOf()); - REQUIRE_TRUE(beta2->isScalar(), 0, "ADA_MAX UPDATER OP: beta 2 has to be a scalar, but instead got rank %i!", beta2->rankOf()); - REQUIRE_TRUE(epsilon->isScalar(), 0, "ADA_MAX UPDATER OP: Epsilon has to be a scalar, but instead got rank %i!", epsilon->rankOf()); - - dLr = lr->e(0); - dBeta1 = beta1->e(0); - dBeta2 = beta2->e(0); - dEpsilon = epsilon->e(0); - } - else { - dLr = T_ARG(0); - dBeta1 = T_ARG(1); - dBeta2 = T_ARG(2); - dEpsilon = T_ARG(3); - } - - helpers::updaterAdaMax(block.launchContext(), *gradient, *initStateU, *initStateM, *update, *stateU, *stateM, dLr, dBeta1, dBeta2, dEpsilon, iteration); - return Status::OK(); - } - - DECLARE_TYPES(ada_max_updater) { - getOpDescriptor()->setAllowedInputTypes({ ALL_FLOATS }) - ->setSameMode(true); - } +namespace ops { + +CONFIGURABLE_OP_IMPL(ada_max_updater, 3, 3, true, 0, 0) { + const auto gradient = INPUT_VARIABLE(0); + const auto initStateU = INPUT_VARIABLE(1); + const auto initStateM = INPUT_VARIABLE(2); + + auto update = OUTPUT_VARIABLE(0); + auto stateU = OUTPUT_VARIABLE(1); + auto stateM = OUTPUT_VARIABLE(2); + + // todo maybe we need an error like on Java side + if (gradient->isEmpty() || initStateU->isEmpty() || initStateM->isEmpty()) + return Status::OK(); + + REQUIRE_TRUE( + gradient->isSameShape(initStateU), 0, + "ADA_MAX UPDATER OP: input state V must have the same shape as gradient," + " expected shape %s, but got %s!", + ShapeUtils::shapeAsString(gradient->shapeInfo()).c_str(), + ShapeUtils::shapeAsString(initStateU->shapeInfo()).c_str()); + REQUIRE_TRUE( + gradient->isSameShape(initStateM), 0, + "ADA_MAX UPDATER OP: input state M must have the same shape as gradient," + " expected shape %s, but got %s!", + ShapeUtils::shapeAsString(gradient->shapeInfo()).c_str(), + ShapeUtils::shapeAsString(initStateM->shapeInfo()).c_str()); + + bool bParamsSupply = 7 == block.width() || 4 == block.numT(); + + int iteration = block.numI() > 0 ? INT_ARG(0) : 0; + + REQUIRE_TRUE(bParamsSupply, 0, + "ADA_MAX UPDATER OP: learning rate, beta 1, beta 2 and epsilon " + "were not provided!"); + + double dLr, dBeta1, dBeta2, dEpsilon; + + if (block.width() > 3) { + const auto lr = INPUT_VARIABLE(3); + const auto beta1 = INPUT_VARIABLE(4); + const auto beta2 = INPUT_VARIABLE(5); + const auto epsilon = INPUT_VARIABLE(6); + + REQUIRE_TRUE(lr->isScalar(), 0, + "ADA_MAX UPDATER OP: Learning rate has to be a scalar, but " + "instead got rank %i!", + lr->rankOf()); + REQUIRE_TRUE(beta1->isScalar(), 0, + "ADA_MAX UPDATER OP: beta 1 has to be a scalar, but instead " + "got rank %i!", + beta1->rankOf()); + REQUIRE_TRUE(beta2->isScalar(), 0, + "ADA_MAX UPDATER OP: beta 2 has to be a scalar, but instead " + "got rank %i!", + beta2->rankOf()); + REQUIRE_TRUE(epsilon->isScalar(), 0, + "ADA_MAX UPDATER OP: Epsilon has to be a scalar, but instead " + "got rank %i!", + epsilon->rankOf()); + + dLr = lr->e(0); + dBeta1 = beta1->e(0); + dBeta2 = beta2->e(0); + dEpsilon = epsilon->e(0); + } else { + dLr = T_ARG(0); + dBeta1 = T_ARG(1); + dBeta2 = T_ARG(2); + dEpsilon = T_ARG(3); + } + + helpers::updaterAdaMax(block.launchContext(), *gradient, *initStateU, + *initStateM, *update, *stateU, *stateM, dLr, dBeta1, + dBeta2, dEpsilon, iteration); + return Status::OK(); +} - } +DECLARE_TYPES(ada_max_updater) { + getOpDescriptor()->setAllowedInputTypes({ALL_FLOATS})->setSameMode(true); } + +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/generic/updaters/adamUpdater.cpp b/libnd4j/include/ops/declarable/generic/updaters/adamUpdater.cpp index 96386c45b955..354d0d48a6f4 100644 --- a/libnd4j/include/ops/declarable/generic/updaters/adamUpdater.cpp +++ b/libnd4j/include/ops/declarable/generic/updaters/adamUpdater.cpp @@ -14,79 +14,98 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - // - // @author Oleh Semeniv (oleg.semeniv@gmail.com) - // +// +// @author Oleh Semeniv (oleg.semeniv@gmail.com) +// -#include -#include -#include -#include #include +#include +#include +#include +#include namespace sd { - namespace ops { - - CONFIGURABLE_OP_IMPL(adam_updater, 3, 3, true, 0, 0) { - - const auto gradient = INPUT_VARIABLE(0); - const auto initStateU = INPUT_VARIABLE(1); - const auto initStateM = INPUT_VARIABLE(2); - - auto update = OUTPUT_VARIABLE(0); - auto stateU = OUTPUT_VARIABLE(1); - auto stateM = OUTPUT_VARIABLE(2); - - // todo maybe we need an error like on Java side - if (gradient->isEmpty() || initStateU->isEmpty() || initStateM->isEmpty()) - return Status::OK(); - - REQUIRE_TRUE(gradient->isSameShape(initStateU), 0, "ADAM UPDATER OP: input state V must have the same shape as gradient," - " expected shape %s, but got %s!", ShapeUtils::shapeAsString(gradient->shapeInfo()).c_str(), - ShapeUtils::shapeAsString(initStateU->shapeInfo()).c_str()); - REQUIRE_TRUE(gradient->isSameShape(initStateM), 0, "ADAM UPDATER OP: input state M must have the same shape as gradient," - " expected shape %s, but got %s!", ShapeUtils::shapeAsString(gradient->shapeInfo()).c_str(), - ShapeUtils::shapeAsString(initStateM->shapeInfo()).c_str()); - - bool bParamsSupply = 7 == block.width() || 4 == block.getTArguments()->size(); - - auto iteration = block.getIArguments()->size() > 0 ? INT_ARG(0) : 0; - - REQUIRE_TRUE(bParamsSupply, 0, "ADAM UPDATER OP: learning rate, beta 1, beta 2 and epsilon were not provided!"); - - double dLr, dBeta1, dBeta2, dEpsilon; - - if (block.width() > 3) { - const auto lr = INPUT_VARIABLE(3); - const auto beta1 = INPUT_VARIABLE(4); - const auto beta2 = INPUT_VARIABLE(5); - const auto epsilon = INPUT_VARIABLE(6); - - REQUIRE_TRUE(lr->isScalar(), 0, "ADAM UPDATER OP: Learning rate has to be a scalar, but instead got rank %i!", lr->rankOf()); - REQUIRE_TRUE(beta1->isScalar(), 0, "ADAM UPDATER OP: beta 1 has to be a scalar, but instead got rank %i!", beta1->rankOf()); - REQUIRE_TRUE(beta2->isScalar(), 0, "ADAM UPDATER OP: beta 2 has to be a scalar, but instead got rank %i!", beta2->rankOf()); - REQUIRE_TRUE(epsilon->isScalar(), 0, "ADAM UPDATER OP: Epsilon has to be a scalar, but instead got rank %i!", epsilon->rankOf()); - - dLr = lr->e(0); - dBeta1 = beta1->e(0); - dBeta2 = beta2->e(0); - dEpsilon = epsilon->e(0); - } - else { - dLr = T_ARG(0); - dBeta1 = T_ARG(1); - dBeta2 = T_ARG(2); - dEpsilon = T_ARG(3); - } - - helpers::updaterAdam(block.launchContext(), *gradient, *initStateU, *initStateM, *update, *stateU, *stateM, dLr, dBeta1, dBeta2, dEpsilon, iteration); - return Status::OK(); - } - - DECLARE_TYPES(adam_updater) { - getOpDescriptor()->setAllowedInputTypes({ ALL_FLOATS }) - ->setSameMode(true); - } - - } +namespace ops { + +CONFIGURABLE_OP_IMPL(adam_updater, 3, 3, true, 0, 0) { + const auto gradient = INPUT_VARIABLE(0); + const auto initStateU = INPUT_VARIABLE(1); + const auto initStateM = INPUT_VARIABLE(2); + + auto update = OUTPUT_VARIABLE(0); + auto stateU = OUTPUT_VARIABLE(1); + auto stateM = OUTPUT_VARIABLE(2); + + // todo maybe we need an error like on Java side + if (gradient->isEmpty() || initStateU->isEmpty() || initStateM->isEmpty()) + return Status::OK(); + + REQUIRE_TRUE( + gradient->isSameShape(initStateU), 0, + "ADAM UPDATER OP: input state V must have the same shape as gradient," + " expected shape %s, but got %s!", + ShapeUtils::shapeAsString(gradient->shapeInfo()).c_str(), + ShapeUtils::shapeAsString(initStateU->shapeInfo()).c_str()); + REQUIRE_TRUE( + gradient->isSameShape(initStateM), 0, + "ADAM UPDATER OP: input state M must have the same shape as gradient," + " expected shape %s, but got %s!", + ShapeUtils::shapeAsString(gradient->shapeInfo()).c_str(), + ShapeUtils::shapeAsString(initStateM->shapeInfo()).c_str()); + + bool bParamsSupply = 7 == block.width() || 4 == block.numT(); + + auto iteration = block.numI() > 0 ? INT_ARG(0) : 0; + + REQUIRE_TRUE(bParamsSupply, 0, + "ADAM UPDATER OP: learning rate, beta 1, beta 2 and epsilon " + "were not provided!"); + + double dLr, dBeta1, dBeta2, dEpsilon; + + if (block.width() > 3) { + const auto lr = INPUT_VARIABLE(3); + const auto beta1 = INPUT_VARIABLE(4); + const auto beta2 = INPUT_VARIABLE(5); + const auto epsilon = INPUT_VARIABLE(6); + + REQUIRE_TRUE(lr->isScalar(), 0, + "ADAM UPDATER OP: Learning rate has to be a scalar, but " + "instead got rank %i!", + lr->rankOf()); + REQUIRE_TRUE( + beta1->isScalar(), 0, + "ADAM UPDATER OP: beta 1 has to be a scalar, but instead got rank %i!", + beta1->rankOf()); + REQUIRE_TRUE( + beta2->isScalar(), 0, + "ADAM UPDATER OP: beta 2 has to be a scalar, but instead got rank %i!", + beta2->rankOf()); + REQUIRE_TRUE( + epsilon->isScalar(), 0, + "ADAM UPDATER OP: Epsilon has to be a scalar, but instead got rank %i!", + epsilon->rankOf()); + + dLr = lr->e(0); + dBeta1 = beta1->e(0); + dBeta2 = beta2->e(0); + dEpsilon = epsilon->e(0); + } else { + dLr = T_ARG(0); + dBeta1 = T_ARG(1); + dBeta2 = T_ARG(2); + dEpsilon = T_ARG(3); + } + + helpers::updaterAdam(block.launchContext(), *gradient, *initStateU, + *initStateM, *update, *stateU, *stateM, dLr, dBeta1, + dBeta2, dEpsilon, iteration); + return Status::OK(); } + +DECLARE_TYPES(adam_updater) { + getOpDescriptor()->setAllowedInputTypes({ALL_FLOATS})->setSameMode(true); +} + +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/generic/updaters/amsGradUpdater.cpp b/libnd4j/include/ops/declarable/generic/updaters/amsGradUpdater.cpp index 32084d970989..5e71fb3710fe 100644 --- a/libnd4j/include/ops/declarable/generic/updaters/amsGradUpdater.cpp +++ b/libnd4j/include/ops/declarable/generic/updaters/amsGradUpdater.cpp @@ -14,85 +14,107 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - // - // @author Oleh Semeniv (oleg.semeniv@gmail.com) - // +// +// @author Oleh Semeniv (oleg.semeniv@gmail.com) +// -#include -#include -#include -#include #include +#include +#include +#include +#include namespace sd { - namespace ops { - - CONFIGURABLE_OP_IMPL(ams_grad_updater, 4, 4, true, 0, 0) { - - const auto gradient = INPUT_VARIABLE(0); - const auto initStateV = INPUT_VARIABLE(1); - const auto initStateM = INPUT_VARIABLE(2); - const auto initStateH = INPUT_VARIABLE(3); - - auto update = OUTPUT_VARIABLE(0); - auto stateV = OUTPUT_VARIABLE(1); - auto stateM = OUTPUT_VARIABLE(2); - auto stateH = OUTPUT_VARIABLE(3); - - // todo maybe we need an error like on Java side - if (gradient->isEmpty() || initStateV->isEmpty() || initStateM->isEmpty() || initStateH->isEmpty()) - return Status::OK(); - - REQUIRE_TRUE(gradient->isSameShape(initStateV), 0, "AMSGRAD UPDATER OP: input state Msg must have the same shape as gradient," - " expected shape %s, but got %s!", ShapeUtils::shapeAsString(gradient->shapeInfo()).c_str(), - ShapeUtils::shapeAsString(initStateV->shapeInfo()).c_str()); - REQUIRE_TRUE(gradient->isSameShape(initStateM), 0, "AMSGRAD UPDATER OP: input state Msdx must have the same shape as gradient," - " expected shape %s, but got %s!", ShapeUtils::shapeAsString(gradient->shapeInfo()).c_str(), - ShapeUtils::shapeAsString(initStateM->shapeInfo()).c_str()); - REQUIRE_TRUE(gradient->isSameShape(initStateH), 0, "AMSGRAD UPDATER OP: input state Msdx must have the same shape as gradient!," - " expected shape %s, but got %s!", ShapeUtils::shapeAsString(gradient->shapeInfo()).c_str(), - ShapeUtils::shapeAsString(initStateH->shapeInfo()).c_str()); - - bool bParamsSupply = 8 == block.width() || 4 == block.getTArguments()->size(); - - auto iteration = block.getIArguments()->size() > 0 ? INT_ARG(0) : 0; - - REQUIRE_TRUE(bParamsSupply, 0, "AMSGRAD UPDATER OP: learning rate, beta 1, beta 2 and epsilon were not provided!"); - - double dLr, dBeta1, dBeta2, dEpsilon; - - if (block.width() > 4) { - const auto lr = INPUT_VARIABLE(4); - const auto beta1 = INPUT_VARIABLE(5); - const auto beta2 = INPUT_VARIABLE(6); - const auto epsilon = INPUT_VARIABLE(7); - - REQUIRE_TRUE(lr->isScalar(), 0, "AMSGRAD UPDATER OP: Learning rate has to be a scalar, but instead got rank %i!", lr->rankOf()); - REQUIRE_TRUE(beta1->isScalar(), 0, "AMSGRAD UPDATER OP: beta 1 has to be a scalar, but instead got rank %i!", beta1->rankOf()); - REQUIRE_TRUE(beta2->isScalar(), 0, "AMSGRAD UPDATER OP: beta 2 has to be a scalar, but instead got rank %i!", beta2->rankOf()); - REQUIRE_TRUE(epsilon->isScalar(), 0, "AMSGRAD UPDATER OP: Epsilon has to be a scalar, but instead got rank %i!", epsilon->rankOf()); - - dLr = lr->e(0); - dBeta1 = beta1->e(0); - dBeta2 = beta2->e(0); - dEpsilon = epsilon->e(0); - } - else { - dLr = T_ARG(0); - dBeta1 = T_ARG(1); - dBeta2 = T_ARG(2); - dEpsilon = T_ARG(3); - } - - helpers::updaterAmsGrad(block.launchContext(), *gradient, *initStateV, *initStateM, *initStateH, - *update, *stateV, *stateM, *stateH, dLr, dBeta1, dBeta2, dEpsilon, iteration); - return Status::OK(); - } - - DECLARE_TYPES(ams_grad_updater) { - getOpDescriptor()->setAllowedInputTypes({ ALL_FLOATS }) - ->setSameMode(true); - } - - } +namespace ops { + +CONFIGURABLE_OP_IMPL(ams_grad_updater, 4, 4, true, 0, 0) { + const auto gradient = INPUT_VARIABLE(0); + const auto initStateV = INPUT_VARIABLE(1); + const auto initStateM = INPUT_VARIABLE(2); + const auto initStateH = INPUT_VARIABLE(3); + + auto update = OUTPUT_VARIABLE(0); + auto stateV = OUTPUT_VARIABLE(1); + auto stateM = OUTPUT_VARIABLE(2); + auto stateH = OUTPUT_VARIABLE(3); + + // todo maybe we need an error like on Java side + if (gradient->isEmpty() || initStateV->isEmpty() || initStateM->isEmpty() || + initStateH->isEmpty()) + return Status::OK(); + + REQUIRE_TRUE(gradient->isSameShape(initStateV), 0, + "AMSGRAD UPDATER OP: input state Msg must have the same shape " + "as gradient," + " expected shape %s, but got %s!", + ShapeUtils::shapeAsString(gradient->shapeInfo()).c_str(), + ShapeUtils::shapeAsString(initStateV->shapeInfo()).c_str()); + REQUIRE_TRUE(gradient->isSameShape(initStateM), 0, + "AMSGRAD UPDATER OP: input state Msdx must have the same shape " + "as gradient," + " expected shape %s, but got %s!", + ShapeUtils::shapeAsString(gradient->shapeInfo()).c_str(), + ShapeUtils::shapeAsString(initStateM->shapeInfo()).c_str()); + REQUIRE_TRUE(gradient->isSameShape(initStateH), 0, + "AMSGRAD UPDATER OP: input state Msdx must have the same shape " + "as gradient!," + " expected shape %s, but got %s!", + ShapeUtils::shapeAsString(gradient->shapeInfo()).c_str(), + ShapeUtils::shapeAsString(initStateH->shapeInfo()).c_str()); + + bool bParamsSupply = 8 == block.width() || 4 == block.numT(); + + auto iteration = block.numI() > 0 ? INT_ARG(0) : 0; + + REQUIRE_TRUE(bParamsSupply, 0, + "AMSGRAD UPDATER OP: learning rate, beta 1, beta 2 and epsilon " + "were not provided!"); + + double dLr, dBeta1, dBeta2, dEpsilon; + + if (block.width() > 4) { + const auto lr = INPUT_VARIABLE(4); + const auto beta1 = INPUT_VARIABLE(5); + const auto beta2 = INPUT_VARIABLE(6); + const auto epsilon = INPUT_VARIABLE(7); + + REQUIRE_TRUE(lr->isScalar(), 0, + "AMSGRAD UPDATER OP: Learning rate has to be a scalar, but " + "instead got rank %i!", + lr->rankOf()); + REQUIRE_TRUE(beta1->isScalar(), 0, + "AMSGRAD UPDATER OP: beta 1 has to be a scalar, but instead " + "got rank %i!", + beta1->rankOf()); + REQUIRE_TRUE(beta2->isScalar(), 0, + "AMSGRAD UPDATER OP: beta 2 has to be a scalar, but instead " + "got rank %i!", + beta2->rankOf()); + REQUIRE_TRUE(epsilon->isScalar(), 0, + "AMSGRAD UPDATER OP: Epsilon has to be a scalar, but instead " + "got rank %i!", + epsilon->rankOf()); + + dLr = lr->e(0); + dBeta1 = beta1->e(0); + dBeta2 = beta2->e(0); + dEpsilon = epsilon->e(0); + } else { + dLr = T_ARG(0); + dBeta1 = T_ARG(1); + dBeta2 = T_ARG(2); + dEpsilon = T_ARG(3); + } + + helpers::updaterAmsGrad(block.launchContext(), *gradient, *initStateV, + *initStateM, *initStateH, *update, *stateV, *stateM, + *stateH, dLr, dBeta1, dBeta2, dEpsilon, iteration); + return Status::OK(); } + +DECLARE_TYPES(ams_grad_updater) { + getOpDescriptor()->setAllowedInputTypes({ALL_FLOATS})->setSameMode(true); +} + +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/generic/updaters/nadamUpdater.cpp b/libnd4j/include/ops/declarable/generic/updaters/nadamUpdater.cpp index 4d5e4e12e8e9..df3e95d6d777 100644 --- a/libnd4j/include/ops/declarable/generic/updaters/nadamUpdater.cpp +++ b/libnd4j/include/ops/declarable/generic/updaters/nadamUpdater.cpp @@ -14,79 +14,98 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - // - // @author Oleh Semeniv (oleg.semeniv@gmail.com) - // +// +// @author Oleh Semeniv (oleg.semeniv@gmail.com) +// -#include -#include -#include -#include #include +#include +#include +#include +#include namespace sd { - namespace ops { - - CONFIGURABLE_OP_IMPL(nadam_updater, 3, 3, true, 0, 0) { - - const auto gradient = INPUT_VARIABLE(0); - const auto initStateV = INPUT_VARIABLE(1); - const auto initStateM = INPUT_VARIABLE(2); - - auto update = OUTPUT_VARIABLE(0); - auto stateV = OUTPUT_VARIABLE(1); - auto stateM = OUTPUT_VARIABLE(2); - - // todo maybe we need an error like on Java side - if (gradient->isEmpty() || initStateV->isEmpty() || initStateM->isEmpty()) - return Status::OK(); - - REQUIRE_TRUE(gradient->isSameShape(initStateM), 0, "NADAM UPDATER OP: input state M must have the same shape as gradient," - " expected shape %s, but got %s!", ShapeUtils::shapeAsString(gradient->shapeInfo()).c_str(), - ShapeUtils::shapeAsString(initStateM->shapeInfo()).c_str()); - REQUIRE_TRUE(gradient->isSameShape(initStateV), 0, "NADAM UPDATER OP: input state V must have the same shape as gradient," - " expected shape %s, but got %s!", ShapeUtils::shapeAsString(gradient->shapeInfo()).c_str(), - ShapeUtils::shapeAsString(initStateV->shapeInfo()).c_str()); - - bool bParamsSupply = 7 == block.width() || 4 == block.getTArguments()->size(); - - auto nIteration = block.getIArguments()->size() > 0 ? INT_ARG(0) : 0; - - REQUIRE_TRUE(bParamsSupply, 0, "NADAM UPDATER OP: learning rate, beta 1, beta 2 and epsilon were not provided!"); - - double dLr, dBeta1, dBeta2, dEpsilon; - - if (block.width() > 3) { - const auto lr = INPUT_VARIABLE(3); - const auto beta1 = INPUT_VARIABLE(4); - const auto beta2 = INPUT_VARIABLE(5); - const auto epsilon = INPUT_VARIABLE(6); - - REQUIRE_TRUE(lr->isScalar(), 0, "NADAM UPDATER OP: Learning rate has to be a scalar, but instead got rank %i!", lr->rankOf()); - REQUIRE_TRUE(beta1->isScalar(), 0, "NADAM UPDATER OP: beta 1 has to be a scalar, but instead got rank %i!", beta1->rankOf()); - REQUIRE_TRUE(beta2->isScalar(), 0, "NADAM UPDATER OP: beta 2 has to be a scalar, but instead got rank %i!", beta2->rankOf()); - REQUIRE_TRUE(epsilon->isScalar(), 0, "NADAM UPDATER OP: Epsilon has to be a scalar, but instead got rank %i!", epsilon->rankOf()); - - dLr = lr->e(0); - dBeta1 = beta1->e(0); - dBeta2 = beta2->e(0); - dEpsilon = epsilon->e(0); - } - else { - dLr = T_ARG(0); - dBeta1 = T_ARG(1); - dBeta2 = T_ARG(2); - dEpsilon = T_ARG(3); - } - - helpers::updaterNadam(block.launchContext(), *gradient, *initStateV, *initStateM, *update, *stateV, *stateM, dLr, dBeta1, dBeta2, dEpsilon, nIteration); - return Status::OK(); - } - - DECLARE_TYPES(nadam_updater) { - getOpDescriptor()->setAllowedInputTypes({ ALL_FLOATS }) - ->setSameMode(true); - } - - } +namespace ops { + +CONFIGURABLE_OP_IMPL(nadam_updater, 3, 3, true, 0, 0) { + const auto gradient = INPUT_VARIABLE(0); + const auto initStateV = INPUT_VARIABLE(1); + const auto initStateM = INPUT_VARIABLE(2); + + auto update = OUTPUT_VARIABLE(0); + auto stateV = OUTPUT_VARIABLE(1); + auto stateM = OUTPUT_VARIABLE(2); + + // todo maybe we need an error like on Java side + if (gradient->isEmpty() || initStateV->isEmpty() || initStateM->isEmpty()) + return Status::OK(); + + REQUIRE_TRUE( + gradient->isSameShape(initStateM), 0, + "NADAM UPDATER OP: input state M must have the same shape as gradient," + " expected shape %s, but got %s!", + ShapeUtils::shapeAsString(gradient->shapeInfo()).c_str(), + ShapeUtils::shapeAsString(initStateM->shapeInfo()).c_str()); + REQUIRE_TRUE( + gradient->isSameShape(initStateV), 0, + "NADAM UPDATER OP: input state V must have the same shape as gradient," + " expected shape %s, but got %s!", + ShapeUtils::shapeAsString(gradient->shapeInfo()).c_str(), + ShapeUtils::shapeAsString(initStateV->shapeInfo()).c_str()); + + bool bParamsSupply = 7 == block.width() || 4 == block.numT(); + + auto nIteration = block.numI() > 0 ? INT_ARG(0) : 0; + + REQUIRE_TRUE(bParamsSupply, 0, + "NADAM UPDATER OP: learning rate, beta 1, beta 2 and epsilon " + "were not provided!"); + + double dLr, dBeta1, dBeta2, dEpsilon; + + if (block.width() > 3) { + const auto lr = INPUT_VARIABLE(3); + const auto beta1 = INPUT_VARIABLE(4); + const auto beta2 = INPUT_VARIABLE(5); + const auto epsilon = INPUT_VARIABLE(6); + + REQUIRE_TRUE(lr->isScalar(), 0, + "NADAM UPDATER OP: Learning rate has to be a scalar, but " + "instead got rank %i!", + lr->rankOf()); + REQUIRE_TRUE( + beta1->isScalar(), 0, + "NADAM UPDATER OP: beta 1 has to be a scalar, but instead got rank %i!", + beta1->rankOf()); + REQUIRE_TRUE( + beta2->isScalar(), 0, + "NADAM UPDATER OP: beta 2 has to be a scalar, but instead got rank %i!", + beta2->rankOf()); + REQUIRE_TRUE(epsilon->isScalar(), 0, + "NADAM UPDATER OP: Epsilon has to be a scalar, but instead " + "got rank %i!", + epsilon->rankOf()); + + dLr = lr->e(0); + dBeta1 = beta1->e(0); + dBeta2 = beta2->e(0); + dEpsilon = epsilon->e(0); + } else { + dLr = T_ARG(0); + dBeta1 = T_ARG(1); + dBeta2 = T_ARG(2); + dEpsilon = T_ARG(3); + } + + helpers::updaterNadam(block.launchContext(), *gradient, *initStateV, + *initStateM, *update, *stateV, *stateM, dLr, dBeta1, + dBeta2, dEpsilon, nIteration); + return Status::OK(); } + +DECLARE_TYPES(nadam_updater) { + getOpDescriptor()->setAllowedInputTypes({ALL_FLOATS})->setSameMode(true); +} + +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/generic/updaters/nesterovsUpdater.cpp b/libnd4j/include/ops/declarable/generic/updaters/nesterovsUpdater.cpp index bcbefe36bc18..d685ecbdc904 100644 --- a/libnd4j/include/ops/declarable/generic/updaters/nesterovsUpdater.cpp +++ b/libnd4j/include/ops/declarable/generic/updaters/nesterovsUpdater.cpp @@ -14,62 +14,70 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - // - // @author Oleh Semeniv (oleg.semeniv@gmail.com) - // +// +// @author Oleh Semeniv (oleg.semeniv@gmail.com) +// -#include -#include -#include -#include #include +#include +#include +#include +#include namespace sd { - namespace ops { - - CONFIGURABLE_OP_IMPL(nesterovs_updater, 2, 2, true, 0, 0) { - - const auto gradient = INPUT_VARIABLE(0); - const auto initState = INPUT_VARIABLE(1); - - auto update = OUTPUT_VARIABLE(0); - auto stateV = OUTPUT_VARIABLE(1); - - if (gradient->isEmpty() || initState->isEmpty()) - return Status::OK(); - - REQUIRE_TRUE(gradient->isSameShape(initState), 0, "NESTEROVS UPDATER OP: input state Msg must have the same shape as gradient," - " expected shape %s, but got %s!", ShapeUtils::shapeAsString(gradient->shapeInfo()).c_str(), - ShapeUtils::shapeAsString(initState->shapeInfo()).c_str()); - - bool bParamsSupply = 4 == block.width() || 2 == block.getTArguments()->size(); - - REQUIRE_TRUE(bParamsSupply, 0, "NESTEROVS UPDATER OP: learning rate and momentum were not provided!"); - - double dLr, dMomentum; - - if (block.width() > 2) { - const auto lr = INPUT_VARIABLE(2); - const auto momentum = INPUT_VARIABLE(3); - - REQUIRE_TRUE(lr->isScalar(), 0, "NESTEROVS UPDATER OP: Learning rate has to be a scalar, but instead got rank %i!", lr->rankOf()); - REQUIRE_TRUE(momentum->isScalar(), 0, "NESTEROVS UPDATER OP: Momentum has to be a scalar, but instead got rank %i!", momentum->rankOf()); - - dLr = lr->e(0); - dMomentum = momentum->e(0); - } - else { - dLr = T_ARG(0); - dMomentum = T_ARG(1); - } - helpers::updaterNesterovs(block.launchContext(), *gradient, *initState, *update, *stateV, dLr, dMomentum); - return Status::OK(); - } - - DECLARE_TYPES(nesterovs_updater) { - getOpDescriptor()->setAllowedInputTypes({ ALL_FLOATS }) - ->setSameMode(true); - } +namespace ops { + +CONFIGURABLE_OP_IMPL(nesterovs_updater, 2, 2, true, 0, 0) { + const auto gradient = INPUT_VARIABLE(0); + const auto initState = INPUT_VARIABLE(1); + + auto update = OUTPUT_VARIABLE(0); + auto stateV = OUTPUT_VARIABLE(1); + + if (gradient->isEmpty() || initState->isEmpty()) return Status::OK(); + + REQUIRE_TRUE(gradient->isSameShape(initState), 0, + "NESTEROVS UPDATER OP: input state Msg must have the same shape " + "as gradient," + " expected shape %s, but got %s!", + ShapeUtils::shapeAsString(gradient->shapeInfo()).c_str(), + ShapeUtils::shapeAsString(initState->shapeInfo()).c_str()); + + bool bParamsSupply = 4 == block.width() || 2 == block.numT(); + + REQUIRE_TRUE( + bParamsSupply, 0, + "NESTEROVS UPDATER OP: learning rate and momentum were not provided!"); + + double dLr, dMomentum; + + if (block.width() > 2) { + const auto lr = INPUT_VARIABLE(2); + const auto momentum = INPUT_VARIABLE(3); + + REQUIRE_TRUE(lr->isScalar(), 0, + "NESTEROVS UPDATER OP: Learning rate has to be a scalar, but " + "instead got rank %i!", + lr->rankOf()); + REQUIRE_TRUE(momentum->isScalar(), 0, + "NESTEROVS UPDATER OP: Momentum has to be a scalar, but " + "instead got rank %i!", + momentum->rankOf()); + + dLr = lr->e(0); + dMomentum = momentum->e(0); + } else { + dLr = T_ARG(0); + dMomentum = T_ARG(1); + } + helpers::updaterNesterovs(block.launchContext(), *gradient, *initState, + *update, *stateV, dLr, dMomentum); + return Status::OK(); +} - } +DECLARE_TYPES(nesterovs_updater) { + getOpDescriptor()->setAllowedInputTypes({ALL_FLOATS})->setSameMode(true); } + +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/generic/updaters/rmsPropUpdater.cpp b/libnd4j/include/ops/declarable/generic/updaters/rmsPropUpdater.cpp index a611a4fbe59d..84030f5df0f6 100644 --- a/libnd4j/include/ops/declarable/generic/updaters/rmsPropUpdater.cpp +++ b/libnd4j/include/ops/declarable/generic/updaters/rmsPropUpdater.cpp @@ -14,67 +14,78 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - // - // @author Oleh Semeniv (oleg.semeniv@gmail.com) - // +// +// @author Oleh Semeniv (oleg.semeniv@gmail.com) +// -#include -#include -#include -#include #include +#include +#include +#include +#include namespace sd { - namespace ops { - - CONFIGURABLE_OP_IMPL(rms_prop_updater, 2, 2, true, 0, 0) { - - const auto gradient = INPUT_VARIABLE(0); - const auto initState = INPUT_VARIABLE(1); - - auto update = OUTPUT_VARIABLE(0); - auto stateG = OUTPUT_VARIABLE(1); - - if (gradient->isEmpty() || initState->isEmpty()) - return Status::OK(); - - REQUIRE_TRUE(gradient->isSameShape(initState), 0, "RMS_PROB UPDATER OP: input state must have the same shape as gradient," - " expected shape %s, but got %s!", ShapeUtils::shapeAsString(gradient->shapeInfo()).c_str(), - ShapeUtils::shapeAsString(initState->shapeInfo()).c_str()); - - bool bParamsSupply = 5 == block.width() || 3 == block.getTArguments()->size(); - - REQUIRE_TRUE(bParamsSupply, 0, "RSM_PROB UPDATER OP: learning rate, rsm decay and epsilon were not provided!"); - - double dLr, dRmsDecay, dEpsilon; - - if (block.width() > 2) { - const auto lr = INPUT_VARIABLE(2); - const auto rmsDecay = INPUT_VARIABLE(3); - const auto epsilon = INPUT_VARIABLE(4); - - REQUIRE_TRUE(lr->isScalar(), 0, "RSM_PROB UPDATER OP: Learning rate has to be a scalar, but instead got rank %i!", lr->rankOf()); - REQUIRE_TRUE(rmsDecay->isScalar(), 0, "RSM_PROB UPDATER OP: Rms decay has to be a scalar, but instead got rank %i!", rmsDecay->rankOf()); - REQUIRE_TRUE(epsilon->isScalar(), 0, "RSM_PROB UPDATER OP: Epsilon has to be a scalar, but instead got rank %i!", epsilon->rankOf()); - - dLr = lr->e(0); - dRmsDecay = rmsDecay->e(0); - dEpsilon = epsilon->e(0); - } - else { - dLr = T_ARG(0); - dRmsDecay = T_ARG(1); - dEpsilon = T_ARG(2); - } - - helpers::updaterRmsProp(block.launchContext(), *gradient, *initState, *update, *stateG, dLr, dRmsDecay, dEpsilon); - return Status::OK(); - } - - DECLARE_TYPES(rms_prop_updater) { - getOpDescriptor()->setAllowedInputTypes({ ALL_FLOATS }) - ->setSameMode(true); - } +namespace ops { + +CONFIGURABLE_OP_IMPL(rms_prop_updater, 2, 2, true, 0, 0) { + const auto gradient = INPUT_VARIABLE(0); + const auto initState = INPUT_VARIABLE(1); + + auto update = OUTPUT_VARIABLE(0); + auto stateG = OUTPUT_VARIABLE(1); + + if (gradient->isEmpty() || initState->isEmpty()) return Status::OK(); + + REQUIRE_TRUE( + gradient->isSameShape(initState), 0, + "RMS_PROB UPDATER OP: input state must have the same shape as gradient," + " expected shape %s, but got %s!", + ShapeUtils::shapeAsString(gradient->shapeInfo()).c_str(), + ShapeUtils::shapeAsString(initState->shapeInfo()).c_str()); + + bool bParamsSupply = 5 == block.width() || 3 == block.numT(); + + REQUIRE_TRUE(bParamsSupply, 0, + "RSM_PROB UPDATER OP: learning rate, rsm decay and epsilon were " + "not provided!"); + + double dLr, dRmsDecay, dEpsilon; + + if (block.width() > 2) { + const auto lr = INPUT_VARIABLE(2); + const auto rmsDecay = INPUT_VARIABLE(3); + const auto epsilon = INPUT_VARIABLE(4); + + REQUIRE_TRUE(lr->isScalar(), 0, + "RSM_PROB UPDATER OP: Learning rate has to be a scalar, but " + "instead got rank %i!", + lr->rankOf()); + REQUIRE_TRUE(rmsDecay->isScalar(), 0, + "RSM_PROB UPDATER OP: Rms decay has to be a scalar, but " + "instead got rank %i!", + rmsDecay->rankOf()); + REQUIRE_TRUE(epsilon->isScalar(), 0, + "RSM_PROB UPDATER OP: Epsilon has to be a scalar, but instead " + "got rank %i!", + epsilon->rankOf()); + + dLr = lr->e(0); + dRmsDecay = rmsDecay->e(0); + dEpsilon = epsilon->e(0); + } else { + dLr = T_ARG(0); + dRmsDecay = T_ARG(1); + dEpsilon = T_ARG(2); + } + + helpers::updaterRmsProp(block.launchContext(), *gradient, *initState, *update, + *stateG, dLr, dRmsDecay, dEpsilon); + return Status::OK(); +} - } +DECLARE_TYPES(rms_prop_updater) { + getOpDescriptor()->setAllowedInputTypes({ALL_FLOATS})->setSameMode(true); } + +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/generic/updaters/sgdUpdater.cpp b/libnd4j/include/ops/declarable/generic/updaters/sgdUpdater.cpp index 491d7b53e203..07abb72d1833 100644 --- a/libnd4j/include/ops/declarable/generic/updaters/sgdUpdater.cpp +++ b/libnd4j/include/ops/declarable/generic/updaters/sgdUpdater.cpp @@ -14,48 +14,48 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - // - // @author Oleh Semeniv (oleg.semeniv@gmail.com) - // +// +// @author Oleh Semeniv (oleg.semeniv@gmail.com) +// -#include -#include -#include -#include #include +#include +#include +#include +#include namespace sd { - namespace ops { - - CONFIGURABLE_OP_IMPL(sgd_updater, 1, 1, true, 0, 0) { +namespace ops { - const auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); +CONFIGURABLE_OP_IMPL(sgd_updater, 1, 1, true, 0, 0) { + const auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - if (input->isEmpty()) - return Status::OK(); + if (input->isEmpty()) return Status::OK(); - bool bLearningRate = 2 == block.width() || 1 == block.getTArguments()->size(); + bool bLearningRate = 2 == block.width() || 1 == block.numT(); - REQUIRE_TRUE(bLearningRate, 0, "SGD UPDATER OP: Learning rate was not provided!"); + REQUIRE_TRUE(bLearningRate, 0, + "SGD UPDATER OP: Learning rate was not provided!"); - if (block.width() > 1) { - const auto lr = INPUT_VARIABLE(1); - REQUIRE_TRUE(lr->isScalar(), 0, "SGD UPDATER OP: Learning rate has to be a scalar, but instead got rank %i!", lr->rankOf()); + if (block.width() > 1) { + const auto lr = INPUT_VARIABLE(1); + REQUIRE_TRUE(lr->isScalar(), 0, + "SGD UPDATER OP: Learning rate has to be a scalar, but " + "instead got rank %i!", + lr->rankOf()); - input->applyScalarArr(scalar::Multiply, *lr, *output); - } - else { - input->applyScalar(scalar::Multiply, T_ARG(0), *output); - } + input->applyScalarArr(scalar::Multiply, *lr, *output); + } else { + input->applyScalar(scalar::Multiply, T_ARG(0), *output); + } - return Status::OK(); - } - - DECLARE_TYPES(sgd_updater) { - getOpDescriptor()->setAllowedInputTypes({ ALL_FLOATS }) - ->setSameMode(true); - } + return Status::OK(); +} - } +DECLARE_TYPES(sgd_updater) { + getOpDescriptor()->setAllowedInputTypes({ALL_FLOATS})->setSameMode(true); } + +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/generic/util/print_affinity.cpp b/libnd4j/include/ops/declarable/generic/util/print_affinity.cpp index f7a758af6f0b..02a7020d1ccc 100644 --- a/libnd4j/include/ops/declarable/generic/util/print_affinity.cpp +++ b/libnd4j/include/ops/declarable/generic/util/print_affinity.cpp @@ -25,28 +25,35 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(print_affinity, 1, 1, true, 0, 0) { - // TODO: make this op compatible with ArrayList etc - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - nd4j_printf(": Actuality: [HOST: %s; DEVICE: %s]; affinity: [%i]; Pointers: [HOST: %p; DEVICE: %p]; DataBuffer length: %lld\n", block.nodeId(), input->isActualOnHostSide() ? "true" : "false", input->isActualOnDeviceSide() ? "true" : "false", input->dataBuffer()->deviceId(), input->buffer(), input->specialBuffer(), input->dataBuffer()->getLenInBytes()); - - return Status::OK(); - } - - DECLARE_TYPES(print_affinity) { - getOpDescriptor() - ->setAllowedInputTypes(0, sd::DataType::ANY) - ->setAllowedInputTypes(1, {ALL_STRINGS}) - ->setAllowedOutputTypes(0, sd::DataType::INT32); - } - - DECLARE_SHAPE_FN(print_affinity) { - return SHAPELIST(ConstantShapeHelper::getInstance().scalarShapeInfo(DataType::INT32)); - } - } +namespace ops { +CUSTOM_OP_IMPL(print_affinity, 1, 1, true, 0, 0) { + // TODO: make this op compatible with ArrayList etc + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + nd4j_printf( + ": Actuality: [HOST: %s; DEVICE: %s]; affinity: [%i]; Pointers: " + "[HOST: %p; DEVICE: %p]; DataBuffer length: %lld\n", + block.nodeId(), input->isActualOnHostSide() ? "true" : "false", + input->isActualOnDeviceSide() ? "true" : "false", + input->dataBuffer()->deviceId(), input->buffer(), input->specialBuffer(), + input->dataBuffer()->getLenInBytes()); + + return Status::OK(); } +DECLARE_TYPES(print_affinity) { + getOpDescriptor() + ->setAllowedInputTypes(0, sd::DataType::ANY) + ->setAllowedInputTypes(1, {ALL_STRINGS}) + ->setAllowedOutputTypes(0, sd::DataType::INT32); +} + +DECLARE_SHAPE_FN(print_affinity) { + return SHAPELIST( + ConstantShapeHelper::getInstance().scalarShapeInfo(DataType::INT32)); +} +} // namespace ops +} // namespace sd + #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/util/print_variable.cpp b/libnd4j/include/ops/declarable/generic/util/print_variable.cpp index 74ff99fd2f68..8ac24959155a 100644 --- a/libnd4j/include/ops/declarable/generic/util/print_variable.cpp +++ b/libnd4j/include/ops/declarable/generic/util/print_variable.cpp @@ -24,54 +24,57 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(print_variable, 1, 1, true, 0, 0) { - // TODO: make this op compatible with ArrayList etc - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - std::string str; +namespace ops { - if (block.width() == 2) { - auto message = INPUT_VARIABLE(1); - REQUIRE_TRUE(message->isS(), 0, "print_variable: message variable must be a String"); +CUSTOM_OP_IMPL(print_variable, 1, 1, true, 0, 0) { + // TODO: make this op compatible with ArrayList etc + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + std::string str; - str = message->e(0); - } + if (block.width() == 2) { + auto message = INPUT_VARIABLE(1); + REQUIRE_TRUE(message->isS(), 0, + "print_variable: message variable must be a String"); - bool printSpecial = false; - if (block.numB() > 0) - printSpecial = B_ARG(0); + str = message->e(0); + } - if (printSpecial && !sd::Environment::getInstance().isCPU()) { - // only specific backends support special printout. for cpu-based backends it's the same as regular print + bool printSpecial = false; + if (block.numB() > 0) printSpecial = B_ARG(0); - if (block.width() == 2) - helpers::print_special(*block.launchContext(), *input, str); - else - helpers::print_special(*block.launchContext(), *input); - } else { - // optionally add message to the print out - if (block.width() == 2) { - input->printIndexedBuffer(str.c_str()); - } else { - input->printIndexedBuffer(); - } - } + if (printSpecial && !sd::Environment::getInstance().isCPU()) { + // only specific backends support special printout. for cpu-based backends + // it's the same as regular print - return Status::OK(); - } + if (block.width() == 2) + helpers::print_special(*block.launchContext(), *input, str); + else + helpers::print_special(*block.launchContext(), *input); + } else { + // optionally add message to the print out + if (block.width() == 2) { + input->printIndexedBuffer(str.c_str()); + } else { + input->printIndexedBuffer(); + } + } - DECLARE_TYPES(print_variable) { - getOpDescriptor() - ->setAllowedInputTypes(0, sd::DataType::ANY) - ->setAllowedInputTypes(1, {ALL_STRINGS}) - ->setAllowedOutputTypes(0, sd::DataType::INT32); - } + return Status::OK(); +} - DECLARE_SHAPE_FN(print_variable) { - return SHAPELIST(ConstantShapeHelper::getInstance().scalarShapeInfo(DataType::INT32)); - } - } +DECLARE_TYPES(print_variable) { + getOpDescriptor() + ->setAllowedInputTypes(0, sd::DataType::ANY) + ->setAllowedInputTypes(1, {ALL_STRINGS}) + ->setAllowedOutputTypes(0, sd::DataType::INT32); +} + +DECLARE_SHAPE_FN(print_variable) { + return SHAPELIST( + ConstantShapeHelper::getInstance().scalarShapeInfo(DataType::INT32)); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/headers/BarnesHutTsne.h b/libnd4j/include/ops/declarable/headers/BarnesHutTsne.h index 3f0d86e19f11..858ac32c36b3 100644 --- a/libnd4j/include/ops/declarable/headers/BarnesHutTsne.h +++ b/libnd4j/include/ops/declarable/headers/BarnesHutTsne.h @@ -24,78 +24,79 @@ #include namespace sd { - namespace ops { - /** - * This operation used as helper with BarnesHutTsne class - * to compute edge forces using barnes hut - * - * Expected input: - * 0: 1D row-vector (or with shape (1, m)) - * 1: 1D integer vector with slice nums - * 2: 1D float-point values vector with same shape as above - * 3: 2D float-point matrix with data to search - * - * Int args: - * 0: N - number of slices - * - * Output: - * 0: 2D matrix with the same shape and type as the 3th argument - */ - #if NOT_EXCLUDED(OP_barnes_edge_forces) - DECLARE_CUSTOM_OP(barnes_edge_forces, 4, 1, false, 0, 1); - #endif +namespace ops { +/** + * This operation used as helper with BarnesHutTsne class + * to compute edge forces using barnes hut + * + * Expected input: + * 0: 1D row-vector (or with shape (1, m)) + * 1: 1D integer vector with slice nums + * 2: 1D float-point values vector with same shape as above + * 3: 2D float-point matrix with data to search + * + * Int args: + * 0: N - number of slices + * + * Output: + * 0: 2D matrix with the same shape and type as the 3th argument + */ +#if NOT_EXCLUDED(OP_barnes_edge_forces) +DECLARE_CUSTOM_OP(barnes_edge_forces, 4, 1, false, 0, 1); +#endif - /** - * This operation used as helper with BarnesHutTsne class - * to Symmetrize the value matrix - * - * Expected input: - * 0: 1D int row-vector - * 1: 1D int col-vector - * 2: 1D float vector with values - * - * Output: - * 0: 1D int result row-vector - * 1: 1D int result col-vector - * 2: a float-point tensor with shape 1xN, with values from the last input vector - */ - #if NOT_EXCLUDED(OP_barnes_symmetrized) - DECLARE_CUSTOM_OP(barnes_symmetrized, 3, 3, false, 0, -1); - #endif +/** + * This operation used as helper with BarnesHutTsne class + * to Symmetrize the value matrix + * + * Expected input: + * 0: 1D int row-vector + * 1: 1D int col-vector + * 2: 1D float vector with values + * + * Output: + * 0: 1D int result row-vector + * 1: 1D int result col-vector + * 2: a float-point tensor with shape 1xN, with values from the last input + * vector + */ +#if NOT_EXCLUDED(OP_barnes_symmetrized) +DECLARE_CUSTOM_OP(barnes_symmetrized, 3, 3, false, 0, -1); +#endif - /** - * This operation used as helper with BranesHutTsne class - * to compute x = x + 2 * yGrads / abs(yGrads) != yIncs / abs(yIncs) - * - * Expected input: - * 0: input tensor - * 1: input gradient - * 2: gradient step tensor - * - * Output: - * 0: result of expression above - */ - #if NOT_EXCLUDED(OP_barnes_gains) - DECLARE_OP(barnes_gains, 3, 1, true); - #endif +/** + * This operation used as helper with BranesHutTsne class + * to compute x = x + 2 * yGrads / abs(yGrads) != yIncs / abs(yIncs) + * + * Expected input: + * 0: input tensor + * 1: input gradient + * 2: gradient step tensor + * + * Output: + * 0: result of expression above + */ +#if NOT_EXCLUDED(OP_barnes_gains) +DECLARE_OP(barnes_gains, 3, 1, true); +#endif - /** - * This operation used as helper with Cell class - * to check vals in given set - * - * Expected input: - * 0: 1D float row-vector (corners) - * 1: 1D float col-vector (widths) - * 2: 1D float vector (point) - * - * Output: - * 0: bool val - */ - #if NOT_EXCLUDED(OP_cell_contains) - DECLARE_CUSTOM_OP(cell_contains, 3, 1, false, 0, 1); - #endif +/** + * This operation used as helper with Cell class + * to check vals in given set + * + * Expected input: + * 0: 1D float row-vector (corners) + * 1: 1D float col-vector (widths) + * 2: 1D float vector (point) + * + * Output: + * 0: bool val + */ +#if NOT_EXCLUDED(OP_cell_contains) +DECLARE_CUSTOM_OP(cell_contains, 3, 1, false, 0, 1); +#endif - } -} +} // namespace ops +} // namespace sd -#endif // LIBND4J_HEADERS_BARNES_HUT_TSNE_H +#endif // LIBND4J_HEADERS_BARNES_HUT_TSNE_H diff --git a/libnd4j/include/ops/declarable/headers/activations.h b/libnd4j/include/ops/declarable/headers/activations.h index db9e8186a062..4eccc17c5e1e 100644 --- a/libnd4j/include/ops/declarable/headers/activations.h +++ b/libnd4j/include/ops/declarable/headers/activations.h @@ -21,180 +21,177 @@ #ifndef LIBND4J_HEADERS_ACTIVATIONS_H #define LIBND4J_HEADERS_ACTIVATIONS_H - #include namespace sd { - namespace ops { - /** - * This is Sigmoid activation function implementation - * Math is: 1 / 1 + exp(-x) - */ - #if NOT_EXCLUDED(OP_sigmoid) - DECLARE_CONFIGURABLE_OP(sigmoid, 1, 1, true, 0, 0); - DECLARE_CONFIGURABLE_OP(sigmoid_bp, 2, 1, true, 0, 0); - #endif - - /** - * This is Softsign activation function implementation - * Math is: x / 1 + abs(x) - */ - #if NOT_EXCLUDED(OP_softsign) - DECLARE_CONFIGURABLE_OP(softsign, 1, 1, true, 0, 0); - DECLARE_CONFIGURABLE_OP(softsign_bp, 2, 1, true, 0, 0); - #endif - - /** - * This is Tanh activation function implementation - */ - #if NOT_EXCLUDED(OP_tanh) - DECLARE_CONFIGURABLE_OP(tanh, 1, 1, true, 0, 0); - DECLARE_CONFIGURABLE_OP(tanh_bp, 2, 1, true, 0, 0); - #endif - - /** - * This is Softplus activation function implementation - * Math is: log(1 + exp(x)) - */ - #if NOT_EXCLUDED(OP_softplus) - DECLARE_CONFIGURABLE_OP(softplus, 1, 1, true, 0, 0); - DECLARE_CONFIGURABLE_OP(softplus_bp, 2, 1, true, 0, 0); - #endif - - /** - * This is RELU activation function implementation - */ - #if NOT_EXCLUDED(OP_relu) - DECLARE_CONFIGURABLE_OP(relu, 1, 1, true, 1, 0); - DECLARE_CONFIGURABLE_OP(relu_bp, 2, 1, true, 0, 0); - #endif - - /** - * This is SELU activation function implementation - */ - #if NOT_EXCLUDED(OP_selu) - DECLARE_CONFIGURABLE_OP(selu, 1, 1, true, 0, 0); - DECLARE_CONFIGURABLE_OP(selu_bp, 2, 1, true, 0, 0); - #endif - - /** - * This is Leaky RELU activation function. - * Math is: x < 0 ? alpha * x : x; - */ - #if NOT_EXCLUDED(OP_lrelu) - DECLARE_CONFIGURABLE_OP(lrelu, 1, 1, true, -2, 0); - DECLARE_CONFIGURABLE_OP(lrelu_bp, 2, 1, true, -2, 0); - #endif - - /** - * This op is ELU activation function. - * Math is: x >= 0 ? x : exp(x) - 1; - */ - #if NOT_EXCLUDED(OP_elu) - DECLARE_CONFIGURABLE_OP(elu, 1, 1, true, -2, 0); - DECLARE_CONFIGURABLE_OP(elu_bp, 2, 1, true, -2, 0); - #endif - - /** - * This is Cube activation function. - * Math is: x^3 - */ - #if NOT_EXCLUDED(OP_cube) - DECLARE_CONFIGURABLE_OP(cube, 1, 1, true, 0, 0); - DECLARE_CONFIGURABLE_OP(cube_bp, 2, 1, true, 0, 0); - #endif - - /** - * This is RectifiedTanh activation function. - * Math is: max(0, tanh(x)) - */ - #if NOT_EXCLUDED(OP_rectifiedtanh) - DECLARE_CONFIGURABLE_OP(rectifiedtanh, 1, 1, true, 0, 0); - DECLARE_CONFIGURABLE_OP(rectifiedtanh_bp, 2, 1, true, 0, 0); - #endif - - /** - * This is RationalTanh activation function. - */ - #if NOT_EXCLUDED(OP_rationaltanh) - DECLARE_CONFIGURABLE_OP(rationaltanh, 1, 1, true, 0, 0); - DECLARE_CONFIGURABLE_OP(rationaltanh_bp, 2, 1, true, 0, 0); - #endif - - /** - * This is HardTanh activation function. - * Math is: x < -1.0 ? -1.0 : x > 1.0 ? 1.0 : x; - */ - #if NOT_EXCLUDED(OP_hardtanh) - DECLARE_CONFIGURABLE_OP(hardtanh, 1, 1, true, 0, 0); - DECLARE_CONFIGURABLE_OP(hardtanh_bp, 2, 1, true, 0, 0); - #endif - - /** - * This is HardSigmoid activation function. - * Math is: min(1, max(0, 0.2 * x + 0.5)) - */ - #if NOT_EXCLUDED(OP_hardsigmoid) - DECLARE_CONFIGURABLE_OP(hardsigmoid, 1, 1, true, 0, 0); - DECLARE_CONFIGURABLE_OP(hardsigmoid_bp, 2, 1, true, 0, 0); - #endif - - /** - * This is Indentity operation. It passes signal umodified in both directions. - */ - #if NOT_EXCLUDED(OP_identity) - DECLARE_OP(identity, 1, 1, true); - DECLARE_OP(identity_bp, 2, 1, true); - #endif - - /** - * This is Indentity operation. It passes signal umodified in both directions. - */ - #if NOT_EXCLUDED(OP_identity_n) - DECLARE_CUSTOM_OP(identity_n, 1, 1, true, 0, 0); - #endif - - /** - * This is Concatenated RELU implementation. - * What happens inside: RELU(Concat((x, -x, {-1}))) - * - * PLEASE NOTE: Concatenation will double amount of features available in input - */ - #if NOT_EXCLUDED(OP_crelu) - DECLARE_CUSTOM_OP(crelu, 1, 1, false, 0, 0); - DECLARE_CUSTOM_OP(crelu_bp, 2, 1, false, 0, 0); - #endif - - /** - * This is RELU6 activation function implementation - */ - #if NOT_EXCLUDED(OP_relu6) - DECLARE_CONFIGURABLE_OP(relu6, 1, 1, true, 1, 0); - DECLARE_CONFIGURABLE_OP(relu6_bp, 2, 1, true, 0, 0); - #endif - - - /** - * Parametric Rectified Linear Unit - * f(x) = alpha * x for x < 0, f(x) = x for x >= 0 - */ - #if NOT_EXCLUDED(OP_prelu) - DECLARE_CONFIGURABLE_OP(prelu, 2, 1, true, 0, 0); - DECLARE_CONFIGURABLE_OP(prelu_bp, 3, 2, true, 0, 0); - #endif - - /** - * Thresholded Rectified Linear Unit - * f(x) = x for x > theta, f(x) = 0 otherwise - * theta must be >= 0 - */ - #if NOT_EXCLUDED(OP_thresholdedrelu) - DECLARE_CONFIGURABLE_OP(thresholdedrelu, 1, 1, true, 0, 0); - DECLARE_CONFIGURABLE_OP(thresholdedrelu_bp, 2, 1, true, 0, 0); - #endif - - - } -} +namespace ops { +/** + * This is Sigmoid activation function implementation + * Math is: 1 / 1 + exp(-x) + */ +#if NOT_EXCLUDED(OP_sigmoid) +DECLARE_CONFIGURABLE_OP(sigmoid, 1, 1, true, 0, 0); +DECLARE_CONFIGURABLE_OP(sigmoid_bp, 2, 1, true, 0, 0); +#endif + +/** + * This is Softsign activation function implementation + * Math is: x / 1 + abs(x) + */ +#if NOT_EXCLUDED(OP_softsign) +DECLARE_CONFIGURABLE_OP(softsign, 1, 1, true, 0, 0); +DECLARE_CONFIGURABLE_OP(softsign_bp, 2, 1, true, 0, 0); +#endif + +/** + * This is Tanh activation function implementation + */ +#if NOT_EXCLUDED(OP_tanh) +DECLARE_CONFIGURABLE_OP(tanh, 1, 1, true, 0, 0); +DECLARE_CONFIGURABLE_OP(tanh_bp, 2, 1, true, 0, 0); +#endif + +/** + * This is Softplus activation function implementation + * Math is: log(1 + exp(x)) + */ +#if NOT_EXCLUDED(OP_softplus) +DECLARE_CONFIGURABLE_OP(softplus, 1, 1, true, 0, 0); +DECLARE_CONFIGURABLE_OP(softplus_bp, 2, 1, true, 0, 0); +#endif + +/** + * This is RELU activation function implementation + */ +#if NOT_EXCLUDED(OP_relu) +DECLARE_CONFIGURABLE_OP(relu, 1, 1, true, 1, 0); +DECLARE_CONFIGURABLE_OP(relu_bp, 2, 1, true, 0, 0); +#endif + +/** + * This is SELU activation function implementation + */ +#if NOT_EXCLUDED(OP_selu) +DECLARE_CONFIGURABLE_OP(selu, 1, 1, true, 0, 0); +DECLARE_CONFIGURABLE_OP(selu_bp, 2, 1, true, 0, 0); +#endif + +/** + * This is Leaky RELU activation function. + * Math is: x < 0 ? alpha * x : x; + */ +#if NOT_EXCLUDED(OP_lrelu) +DECLARE_CONFIGURABLE_OP(lrelu, 1, 1, true, -2, 0); +DECLARE_CONFIGURABLE_OP(lrelu_bp, 2, 1, true, -2, 0); +#endif + +/** + * This op is ELU activation function. + * Math is: x >= 0 ? x : exp(x) - 1; + */ +#if NOT_EXCLUDED(OP_elu) +DECLARE_CONFIGURABLE_OP(elu, 1, 1, true, -2, 0); +DECLARE_CONFIGURABLE_OP(elu_bp, 2, 1, true, -2, 0); +#endif + +/** + * This is Cube activation function. + * Math is: x^3 + */ +#if NOT_EXCLUDED(OP_cube) +DECLARE_CONFIGURABLE_OP(cube, 1, 1, true, 0, 0); +DECLARE_CONFIGURABLE_OP(cube_bp, 2, 1, true, 0, 0); +#endif + +/** + * This is RectifiedTanh activation function. + * Math is: max(0, tanh(x)) + */ +#if NOT_EXCLUDED(OP_rectifiedtanh) +DECLARE_CONFIGURABLE_OP(rectifiedtanh, 1, 1, true, 0, 0); +DECLARE_CONFIGURABLE_OP(rectifiedtanh_bp, 2, 1, true, 0, 0); +#endif + +/** + * This is RationalTanh activation function. + */ +#if NOT_EXCLUDED(OP_rationaltanh) +DECLARE_CONFIGURABLE_OP(rationaltanh, 1, 1, true, 0, 0); +DECLARE_CONFIGURABLE_OP(rationaltanh_bp, 2, 1, true, 0, 0); +#endif + +/** + * This is HardTanh activation function. + * Math is: x < -1.0 ? -1.0 : x > 1.0 ? 1.0 : x; + */ +#if NOT_EXCLUDED(OP_hardtanh) +DECLARE_CONFIGURABLE_OP(hardtanh, 1, 1, true, 0, 0); +DECLARE_CONFIGURABLE_OP(hardtanh_bp, 2, 1, true, 0, 0); +#endif + +/** + * This is HardSigmoid activation function. + * Math is: min(1, max(0, 0.2 * x + 0.5)) + */ +#if NOT_EXCLUDED(OP_hardsigmoid) +DECLARE_CONFIGURABLE_OP(hardsigmoid, 1, 1, true, 0, 0); +DECLARE_CONFIGURABLE_OP(hardsigmoid_bp, 2, 1, true, 0, 0); +#endif + +/** + * This is Indentity operation. It passes signal umodified in both directions. + */ +#if NOT_EXCLUDED(OP_identity) +DECLARE_OP(identity, 1, 1, true); +DECLARE_OP(identity_bp, 2, 1, true); +#endif + +/** + * This is Indentity operation. It passes signal umodified in both directions. + */ +#if NOT_EXCLUDED(OP_identity_n) +DECLARE_CUSTOM_OP(identity_n, 1, 1, true, 0, 0); +#endif + +/** + * This is Concatenated RELU implementation. + * What happens inside: RELU(Concat((x, -x, {-1}))) + * + * PLEASE NOTE: Concatenation will double amount of features available in input + */ +#if NOT_EXCLUDED(OP_crelu) +DECLARE_CUSTOM_OP(crelu, 1, 1, false, 0, 0); +DECLARE_CUSTOM_OP(crelu_bp, 2, 1, false, 0, 0); +#endif + +/** + * This is RELU6 activation function implementation + */ +#if NOT_EXCLUDED(OP_relu6) +DECLARE_CONFIGURABLE_OP(relu6, 1, 1, true, 1, 0); +DECLARE_CONFIGURABLE_OP(relu6_bp, 2, 1, true, 0, 0); +#endif + +/** + * Parametric Rectified Linear Unit + * f(x) = alpha * x for x < 0, f(x) = x for x >= 0 + */ +#if NOT_EXCLUDED(OP_prelu) +DECLARE_CONFIGURABLE_OP(prelu, 2, 1, true, 0, 0); +DECLARE_CONFIGURABLE_OP(prelu_bp, 3, 2, true, 0, 0); +#endif + +/** + * Thresholded Rectified Linear Unit + * f(x) = x for x > theta, f(x) = 0 otherwise + * theta must be >= 0 + */ +#if NOT_EXCLUDED(OP_thresholdedrelu) +DECLARE_CONFIGURABLE_OP(thresholdedrelu, 1, 1, true, 0, 0); +DECLARE_CONFIGURABLE_OP(thresholdedrelu_bp, 2, 1, true, 0, 0); +#endif + +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/headers/bitwise.h b/libnd4j/include/ops/declarable/headers/bitwise.h index b5f29896f779..db8188267dc2 100644 --- a/libnd4j/include/ops/declarable/headers/bitwise.h +++ b/libnd4j/include/ops/declarable/headers/bitwise.h @@ -24,107 +24,109 @@ #include namespace sd { - namespace ops { - /** - * This operation toggles individual bits of each element in array - * - * PLEASE NOTE: This operation is possible only on integer data types - * - * @tparam T - */ - #if NOT_EXCLUDED(OP_toggle_bits) - DECLARE_OP(toggle_bits, -1, -1, true); - #endif - +namespace ops { +/** + * This operation toggles individual bits of each element in array + * + * PLEASE NOTE: This operation is possible only on integer data types + * + * @tparam T + */ +#if NOT_EXCLUDED(OP_toggle_bits) +DECLARE_OP(toggle_bits, -1, -1, true); +#endif - /** - * This operation shift individual bits of each element in array to the left: << - * - * PLEASE NOTE: This operation is applicable only to integer data types - * - * @tparam T - */ - #if NOT_EXCLUDED(OP_shift_bits) - DECLARE_BROADCASTABLE_OP(shift_bits, 0, 0); - #endif +/** + * This operation shift individual bits of each element in array to the left: << + * + * PLEASE NOTE: This operation is applicable only to integer data types + * + * @tparam T + */ +#if NOT_EXCLUDED(OP_shift_bits) +DECLARE_BROADCASTABLE_OP(shift_bits, 0, 0); +#endif - /** - * This operation shift individual bits of each element in array to the right: >> - * - * PLEASE NOTE: This operation is applicable only to integer data types - * - * @tparam T - */ - #if NOT_EXCLUDED(OP_rshift_bits) - DECLARE_BROADCASTABLE_OP(rshift_bits, 0, 0); - #endif +/** + * This operation shift individual bits of each element in array to the right: + * >> + * + * PLEASE NOTE: This operation is applicable only to integer data types + * + * @tparam T + */ +#if NOT_EXCLUDED(OP_rshift_bits) +DECLARE_BROADCASTABLE_OP(rshift_bits, 0, 0); +#endif - /** - * This operation shift individual bits of each element in array, shifting to the left - * - * PLEASE NOTE: This operation is applicable only to integer data types - * - * @tparam T - */ - #if NOT_EXCLUDED(OP_cyclic_shift_bits) - DECLARE_BROADCASTABLE_OP(cyclic_shift_bits, 0, 0); - #endif +/** + * This operation shift individual bits of each element in array, shifting to + * the left + * + * PLEASE NOTE: This operation is applicable only to integer data types + * + * @tparam T + */ +#if NOT_EXCLUDED(OP_cyclic_shift_bits) +DECLARE_BROADCASTABLE_OP(cyclic_shift_bits, 0, 0); +#endif - /** - * This operation shift individual bits of each element in array, shifting to the right - * - * PLEASE NOTE: This operation is applicable only to integer data types - * - * @tparam T - */ - #if NOT_EXCLUDED(OP_cyclic_rshift_bits) - DECLARE_BROADCASTABLE_OP(cyclic_rshift_bits, 0, 0); - #endif +/** + * This operation shift individual bits of each element in array, shifting to + * the right + * + * PLEASE NOTE: This operation is applicable only to integer data types + * + * @tparam T + */ +#if NOT_EXCLUDED(OP_cyclic_rshift_bits) +DECLARE_BROADCASTABLE_OP(cyclic_rshift_bits, 0, 0); +#endif - /** - * This operation applies bitwise AND - * - * PLEASE NOTE: This operation is applicable only to integer data types - * - * @tparam T - */ - #if NOT_EXCLUDED(OP_bitwise_and) - DECLARE_BROADCASTABLE_OP(bitwise_and, 0, 0); - #endif +/** + * This operation applies bitwise AND + * + * PLEASE NOTE: This operation is applicable only to integer data types + * + * @tparam T + */ +#if NOT_EXCLUDED(OP_bitwise_and) +DECLARE_BROADCASTABLE_OP(bitwise_and, 0, 0); +#endif - /** - * This operation applies bitwise OR - * - * PLEASE NOTE: This operation is applicable only to integer data types - * - * @tparam T - */ - #if NOT_EXCLUDED(OP_bitwise_or) - DECLARE_BROADCASTABLE_OP(bitwise_or, 0, 0); - #endif +/** + * This operation applies bitwise OR + * + * PLEASE NOTE: This operation is applicable only to integer data types + * + * @tparam T + */ +#if NOT_EXCLUDED(OP_bitwise_or) +DECLARE_BROADCASTABLE_OP(bitwise_or, 0, 0); +#endif - /** - * This operation applies bitwise XOR - * - * PLEASE NOTE: This operation is applicable only to integer data types - * - * @tparam T - */ - #if NOT_EXCLUDED(OP_bitwise_xor) - DECLARE_BROADCASTABLE_OP(bitwise_xor, 0, 0); - #endif +/** + * This operation applies bitwise XOR + * + * PLEASE NOTE: This operation is applicable only to integer data types + * + * @tparam T + */ +#if NOT_EXCLUDED(OP_bitwise_xor) +DECLARE_BROADCASTABLE_OP(bitwise_xor, 0, 0); +#endif - /** - * This operation returns hamming distance based on bits - * - * PLEASE NOTE: This operation is applicable only to integer data types - * - * @tparam T - */ - #if NOT_EXCLUDED(OP_bits_hamming_distance) - DECLARE_CUSTOM_OP(bits_hamming_distance, 2, 1, true, 0, 0); - #endif - } -} +/** + * This operation returns hamming distance based on bits + * + * PLEASE NOTE: This operation is applicable only to integer data types + * + * @tparam T + */ +#if NOT_EXCLUDED(OP_bits_hamming_distance) +DECLARE_CUSTOM_OP(bits_hamming_distance, 2, 1, true, 0, 0); +#endif +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/headers/blas.h b/libnd4j/include/ops/declarable/headers/blas.h index 6fd5a389445a..733f4299efbb 100644 --- a/libnd4j/include/ops/declarable/headers/blas.h +++ b/libnd4j/include/ops/declarable/headers/blas.h @@ -23,68 +23,67 @@ #include namespace sd { - namespace ops { +namespace ops { - /** - * This op is general matmum implementation. Depending on inputs dimensionality output result might be different. - * matrix x matrix = BLAS gemm - * vector x matrix = BLAS gemm - * vector x vector = BLAS dot - * vector x scalar = element-wise mul - * scalar x vector = element-wise mul - * - * Optional T arguments: - * 0: alpha (where applicable) - * 1: beta (where applicable) - * - * Optional Integer arguments: - * 0: transA (where applicable) - * 1: transB (where applicable) - */ - #if NOT_EXCLUDED(OP_matmul) - DECLARE_CUSTOM_OP(matmul, 2, 1, false, 0, -2); - DECLARE_CUSTOM_OP(matmul_bp, 3, 2, false, 0, -2); - #endif +/** + * This op is general matmum implementation. Depending on inputs dimensionality + * output result might be different. matrix x matrix = BLAS gemm vector x matrix + * = BLAS gemm vector x vector = BLAS dot vector x scalar = element-wise mul + * scalar x vector = element-wise mul + * + * Optional T arguments: + * 0: alpha (where applicable) + * 1: beta (where applicable) + * + * Optional Integer arguments: + * 0: transA (where applicable) + * 1: transB (where applicable) + */ +#if NOT_EXCLUDED(OP_matmul) +DECLARE_CUSTOM_OP(matmul, 2, 1, false, 0, -2); +DECLARE_CUSTOM_OP(matmul_bp, 3, 2, false, 0, -2); +#endif - /** - * tensorMmul/tensorDot operation - * takes 2 ndarrays, and 2 sets of axes - * - * Integer argumens map: - * IArgs[0] - number of axes along for first array - * IArgs[1]... axes values for first array - * IArgs[] - number of axes along for second array - * IArgs[1]... axes values for second array - */ - #if NOT_EXCLUDED(OP_tensormmul) - DECLARE_CUSTOM_OP(tensormmul, 2, 1, false, 0, -1); - DECLARE_CUSTOM_OP(tensormmul_bp, 3, 2, false, 0, -1); - #endif +/** + * tensorMmul/tensorDot operation + * takes 2 ndarrays, and 2 sets of axes + * + * Integer argumens map: + * IArgs[0] - number of axes along for first array + * IArgs[1]... axes values for first array + * IArgs[] - number of axes along for second array + * IArgs[1]... axes values for second array + */ +#if NOT_EXCLUDED(OP_tensormmul) +DECLARE_CUSTOM_OP(tensormmul, 2, 1, false, 0, -1); +DECLARE_CUSTOM_OP(tensormmul_bp, 3, 2, false, 0, -1); +#endif - /** - * This op is simple implementation of BLAS AXPY method. - * Math is: y += a * x; - */ - #if NOT_EXCLUDED(OP_axpy) - DECLARE_CONFIGURABLE_OP(axpy, 2, 1, false, -2, 0); - #endif +/** + * This op is simple implementation of BLAS AXPY method. + * Math is: y += a * x; + */ +#if NOT_EXCLUDED(OP_axpy) +DECLARE_CONFIGURABLE_OP(axpy, 2, 1, false, -2, 0); +#endif - /** - * This operation implements batched matrix multiplication - * Expected arguments: - * alpha: vector of T - * beta: vector of T - * ...: A, B matrices sequentially. i.e: AAAAABBBBB - * - * Integer arguments: - * transA, transB, M, N, K, ldA, ldB, ldC - usual BLAS gemm arguments - * batchCount - number of operations in this batch - * - * PLEASE NOTE: M, N, K, ldA, ldB, ldC should be equal for all matrices within batch. - */ - #if NOT_EXCLUDED(OP_batched_gemm) - DECLARE_CUSTOM_OP(batched_gemm, -1, -1, false, 0, 9); - #endif +/** + * This operation implements batched matrix multiplication + * Expected arguments: + * alpha: vector of T + * beta: vector of T + * ...: A, B matrices sequentially. i.e: AAAAABBBBB + * + * Integer arguments: + * transA, transB, M, N, K, ldA, ldB, ldC - usual BLAS gemm arguments + * batchCount - number of operations in this batch + * + * PLEASE NOTE: M, N, K, ldA, ldB, ldC should be equal for all matrices within + * batch. + */ +#if NOT_EXCLUDED(OP_batched_gemm) +DECLARE_CUSTOM_OP(batched_gemm, -1, -1, false, 0, 9); +#endif /** * performs singular value decomposition (SVD) of one or more matrices, evaluates the SVD of each inner-most 2D matrix in input array: diff --git a/libnd4j/include/ops/declarable/headers/boolean.h b/libnd4j/include/ops/declarable/headers/boolean.h index 75e95f630349..c61179c05e84 100644 --- a/libnd4j/include/ops/declarable/headers/boolean.h +++ b/libnd4j/include/ops/declarable/headers/boolean.h @@ -24,133 +24,137 @@ #include namespace sd { - namespace ops { - - /** - * This is scalar boolean op. - * Both operands should be scalars. - * - * Returns true if x < y - */ - #if NOT_EXCLUDED(OP_lt_scalar) - DECLARE_BOOLEAN_OP(lt_scalar, 2, true); - #endif - - /** - * This is scalar boolean op. - * Both operands should be scalars. - * - * Returns true if x > y - */ - #if NOT_EXCLUDED(OP_gt_scalar) - DECLARE_BOOLEAN_OP(gt_scalar, 2, true); - #endif - - /** - * This is scalar boolean op. - * Both operands should be scalars. - * - * Returns true if x <= y - */ - #if NOT_EXCLUDED(OP_lte_scalar) - DECLARE_BOOLEAN_OP(lte_scalar, 2, true); - #endif - - /** - * This is scalar boolean op. - * Both operands should be scalars. - * - * Returns true if x >= y - */ - #if NOT_EXCLUDED(OP_gte_scalar) - DECLARE_BOOLEAN_OP(gte_scalar, 2, true); - #endif - - /** - * This is scalar boolean op. - * Both operands should be scalars. - * - * Returns true if both operands are equal. - */ - #if NOT_EXCLUDED(OP_eq_scalar) - DECLARE_BOOLEAN_OP(eq_scalar, 2, true); - #endif - - /** - * This is scalar boolean op. - * Both operands should be scalars. - * - * Returns true if x != y - */ - #if NOT_EXCLUDED(OP_neq_scalar) - DECLARE_BOOLEAN_OP(neq_scalar, 2, true); - #endif - - /** - * This op takes 2 n-dimensional arrays as input, and return - * array of the same shape, with elements, either from x or y, depending on the condition. - */ - #if NOT_EXCLUDED(OP_where) - DECLARE_CUSTOM_OP(Where, 1, 1, false, 0, 0); - #endif - - #if NOT_EXCLUDED(OP_where_np) - DECLARE_CUSTOM_OP(where_np, 1, 1, false, 0, 0); - #endif - - /** - * This op takes 2 n-dimensional arrays as input, and return - * array of the same shape, with elements, either from x or y, depending on the condition. - */ - #if NOT_EXCLUDED(OP_select) - DECLARE_CUSTOM_OP(select, 3, 1, false, 0, 0); - #endif - - /** - * This op takes either 1 argument and 1 scalar - * or 1 argument and another comparison array - * and runs a pre defined conditional op. - * - * The output of the op is dynamic in size and returns a flat vector of elements - * that return true on the given condition. - * In numpy parlance, most people might understand: - * a[a > 2] - * where a is a numpy array and the condition is true when an element is - * > 2. Libnd4j already implements a number of pre defined conditions. - * @tparam T - */ - #if NOT_EXCLUDED(OP_choose) - DECLARE_CUSTOM_OP(choose, -1, 1, false, -2, -1); - #endif - - /** - * This op takes 1 n-dimensional array as input, and returns true if for every adjacent pair we have x[i] <= x[i+1]. - */ - #if NOT_EXCLUDED(OP_is_non_decreasing) - DECLARE_BOOLEAN_OP(is_non_decreasing, 1, true); - #endif - - /** - * This op takes 1 n-dimensional array as input, and returns true if for every adjacent pair we have x[i] < x[i+1]. - */ - #if NOT_EXCLUDED(OP_is_strictly_increasing) - DECLARE_BOOLEAN_OP(is_strictly_increasing, 1, true); - #endif - - /** - * This op takes 1 n-dimensional array as input, and returns true if input is a numeric array. - */ - #if NOT_EXCLUDED(OP_is_numeric_tensor) - DECLARE_BOOLEAN_OP(is_numeric_tensor, 1, true); - #endif - - /** - * - */ - #if NOT_EXCLUDED(OP_boolean_not) - DECLARE_OP(boolean_not, 1, 1, true); - #endif - } -} +namespace ops { + +/** + * This is scalar boolean op. + * Both operands should be scalars. + * + * Returns true if x < y + */ +#if NOT_EXCLUDED(OP_lt_scalar) +DECLARE_BOOLEAN_OP(lt_scalar, 2, true); +#endif + +/** + * This is scalar boolean op. + * Both operands should be scalars. + * + * Returns true if x > y + */ +#if NOT_EXCLUDED(OP_gt_scalar) +DECLARE_BOOLEAN_OP(gt_scalar, 2, true); +#endif + +/** + * This is scalar boolean op. + * Both operands should be scalars. + * + * Returns true if x <= y + */ +#if NOT_EXCLUDED(OP_lte_scalar) +DECLARE_BOOLEAN_OP(lte_scalar, 2, true); +#endif + +/** + * This is scalar boolean op. + * Both operands should be scalars. + * + * Returns true if x >= y + */ +#if NOT_EXCLUDED(OP_gte_scalar) +DECLARE_BOOLEAN_OP(gte_scalar, 2, true); +#endif + +/** + * This is scalar boolean op. + * Both operands should be scalars. + * + * Returns true if both operands are equal. + */ +#if NOT_EXCLUDED(OP_eq_scalar) +DECLARE_BOOLEAN_OP(eq_scalar, 2, true); +#endif + +/** + * This is scalar boolean op. + * Both operands should be scalars. + * + * Returns true if x != y + */ +#if NOT_EXCLUDED(OP_neq_scalar) +DECLARE_BOOLEAN_OP(neq_scalar, 2, true); +#endif + +/** + * This op takes 2 n-dimensional arrays as input, and return + * array of the same shape, with elements, either from x or y, depending on the + * condition. + */ +#if NOT_EXCLUDED(OP_where) +DECLARE_CUSTOM_OP(Where, 1, 1, false, 0, 0); +#endif + +#if NOT_EXCLUDED(OP_where_np) +DECLARE_CUSTOM_OP(where_np, 1, 1, false, 0, 0); +#endif + +/** + * This op takes 2 n-dimensional arrays as input, and return + * array of the same shape, with elements, either from x or y, depending on the + * condition. + */ +#if NOT_EXCLUDED(OP_select) +DECLARE_CUSTOM_OP(select, 3, 1, false, 0, 0); +#endif + +/** + * This op takes either 1 argument and 1 scalar + * or 1 argument and another comparison array + * and runs a pre defined conditional op. + * + * The output of the op is dynamic in size and returns a flat vector of + * elements that return true on the given condition. In numpy parlance, most + * people might understand: a[a > 2] where a is a numpy array and the condition + * is true when an element is > 2. Libnd4j already implements a number of pre + * defined conditions. + * @tparam T + */ +#if NOT_EXCLUDED(OP_choose) +DECLARE_CUSTOM_OP(choose, -1, 1, false, -2, -1); +#endif + +/** + * This op takes 1 n-dimensional array as input, and returns true if for every + * adjacent pair we have x[i] <= x[i+1]. + */ +#if NOT_EXCLUDED(OP_is_non_decreasing) +DECLARE_BOOLEAN_OP(is_non_decreasing, 1, true); +#endif + +/** + * This op takes 1 n-dimensional array as input, and returns true if for every + * adjacent pair we have x[i] < x[i+1]. + */ +#if NOT_EXCLUDED(OP_is_strictly_increasing) +DECLARE_BOOLEAN_OP(is_strictly_increasing, 1, true); +#endif + +/** + * This op takes 1 n-dimensional array as input, and returns true if input is a + * numeric array. + */ +#if NOT_EXCLUDED(OP_is_numeric_tensor) +DECLARE_BOOLEAN_OP(is_numeric_tensor, 1, true); +#endif + +/** + * + */ +#if NOT_EXCLUDED(OP_boolean_not) +DECLARE_OP(boolean_not, 1, 1, true); +#endif +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/headers/broadcastable.h b/libnd4j/include/ops/declarable/headers/broadcastable.h index 7380412a4156..fa89ab47c9ef 100644 --- a/libnd4j/include/ops/declarable/headers/broadcastable.h +++ b/libnd4j/include/ops/declarable/headers/broadcastable.h @@ -21,367 +21,397 @@ #ifndef LIBND4J_HEADERS_BROADCASTABLE_H #define LIBND4J_HEADERS_BROADCASTABLE_H -#include #include -#include +#include #include +#include namespace sd { - namespace ops { - // TODO: make broadcastables separate class - - /** - * This is one of auto-broadcastable operations. It accepts 2 operands, and operation is applied based on their shapes: - * 1) if shapes are equal that's pairwise operation, result will have the same shape. - * 2) if shape X is scalar and shape Y is array - result will have shape equal to Y. - * 3) if shape X is array and shape Y is scalar - result will have shape equal to X. - * 4) if shape X and Y are both arrays, but shapes aren't equal - result shape will be broadcast result. - * - * This operation returns Z = Max(X, Y) - */ - #if NOT_EXCLUDED(OP_maximum) - DECLARE_BROADCASTABLE_OP(maximum, 0, 0); - DECLARE_CUSTOM_OP(maximum_bp, 3, 2, false, 0, 0); - #endif - - /** - * This is one of auto-broadcastable operations. It accepts 2 operands, and operation is applied based on their shapes: - * 1) if shapes are equal that's pairwise operation, result will have the same shape. - * 2) if shape X is scalar and shape Y is array - result will have shape equal to Y. - * 3) if shape X is array and shape Y is scalar - result will have shape equal to X. - * 4) if shape X and Y are both arrays, but shapes aren't equal - result shape will be broadcast result. - * - * This operation returns Z = Min(X, Y) - */ - #if NOT_EXCLUDED(OP_minimum) - DECLARE_BROADCASTABLE_OP(minimum, 0, 0); - DECLARE_CUSTOM_OP(minimum_bp, 3, 2, false, 0, 0); - #endif - - /** - * This is one of auto-broadcastable operations. It accepts 2 operands, and operation is applied based on their shapes: - * 1) if shapes are equal that's pairwise operation, result will have the same shape. - * 2) if shape X is scalar and shape Y is array - result will have shape equal to Y. - * 3) if shape X is array and shape Y is scalar - result will have shape equal to X. - * 4) if shape X and Y are both arrays, but shapes aren't equal - result shape will be broadcast result. - * - * This operation returns Z = Add(X, Y) - */ - #if NOT_EXCLUDED(OP_add) - DECLARE_BROADCASTABLE_OP(add, 0, 0); - DECLARE_CUSTOM_OP(add_bp, 3, 2, false, 0, 0); - #endif - - /** - * This is one of auto-broadcastable operations. It accepts 2 operands, and operation is applied based on their shapes: - * 1) if shapes are equal that's pairwise operation, result will have the same shape. - * 2) if shape X is scalar and shape Y is array - result will have shape equal to Y. - * 3) if shape X is array and shape Y is scalar - result will have shape equal to X. - * 4) if shape X and Y are both arrays, but shapes aren't equal - result shape will be broadcast result. - * - * This operation returns Z = Subtract(X, Y) - */ - #if NOT_EXCLUDED(OP_subtract) - DECLARE_BROADCASTABLE_OP(subtract, 0, 0); - DECLARE_CUSTOM_OP(subtract_bp, 3, 2, false, 0, 0); - #endif - - /** - * This is one of auto-broadcastable operations. It accepts 2 operands, and operation is applied based on their shapes: - * 1) if shapes are equal that's pairwise operation, result will have the same shape. - * 2) if shape X is scalar and shape Y is array - result will have shape equal to Y. - * 3) if shape X is array and shape Y is scalar - result will have shape equal to X. - * 4) if shape X and Y are both arrays, but shapes aren't equal - result shape will be broadcast result. - * - * This operation returns Z = Subtract(Y, X) - */ - #if NOT_EXCLUDED(OP_reversesubtract) - DECLARE_BROADCASTABLE_OP(reversesubtract, 0, 0); - DECLARE_CUSTOM_OP(reversesubtract_bp, 3, 2, false, 0, 0); - #endif - - /** - * This is one of auto-broadcastable operations. It accepts 2 operands, and operation is applied based on their shapes: - * 1) if shapes are equal that's pairwise operation, result will have the same shape. - * 2) if shape X is scalar and shape Y is array - result will have shape equal to Y. - * 3) if shape X is array and shape Y is scalar - result will have shape equal to X. - * 4) if shape X and Y are both arrays, but shapes aren't equal - result shape will be broadcast result. - * - * This operation returns Z = ReverseMod(X, Y) == Mod(Y, X) - */ - #if NOT_EXCLUDED(OP_reversemod) - DECLARE_BROADCASTABLE_OP(reversemod, 0, 0); - DECLARE_CUSTOM_OP(reversemod_bp, 3, 2, true, 0, 0); - #endif - - - /** - * This is one of auto-broadcastable operations. It accepts 2 operands, and operation is applied based on their shapes: - * 1) if shapes are equal that's pairwise operation, result will have the same shape. - * 2) if shape X is scalar and shape Y is array - result will have shape equal to Y. - * 3) if shape X is array and shape Y is scalar - result will have shape equal to X. - * 4) if shape X and Y are both arrays, but shapes aren't equal - result shape will be broadcast result. - * - * This operation returns Z = Subtract(X, Y) * Subtract(X, Y) - */ - #if NOT_EXCLUDED(OP_squaredsubtract) - DECLARE_BROADCASTABLE_OP(squaredsubtract, 0, 0) - DECLARE_CUSTOM_OP(squaredsubtract_bp, 3, 2, false, 0, 0); - #endif - - /** - * This is one of auto-broadcastable operations. It accepts 2 operands, and operation is applied based on their shapes: - * 1) if shapes are equal that's pairwise operation, result will have the same shape. - * 2) if shape X is scalar and shape Y is array - result will have shape equal to Y. - * 3) if shape X is array and shape Y is scalar - result will have shape equal to X. - * 4) if shape X and Y are both arrays, but shapes aren't equal - result shape will be broadcast result. - * - * This operation returns Z = Multiply(X, Y) - */ - #if NOT_EXCLUDED(OP_multiply) - DECLARE_BROADCASTABLE_OP(multiply, 0, 0); - DECLARE_CUSTOM_OP(multiply_bp, 3, 2, false, 0, 0); - #endif - - /** - * This is one of auto-broadcastable operations. It accepts 2 operands, and operation is applied based on their shapes: - * 1) if shapes are equal that's pairwise operation, result will have the same shape. - * 2) if shape X is scalar and shape Y is array - result will have shape equal to Y. - * 3) if shape X is array and shape Y is scalar - result will have shape equal to X. - * 4) if shape X and Y are both arrays, but shapes aren't equal - result shape will be broadcast result. - * - * This operation returns Z = Divide(X, Y) - */ - #if NOT_EXCLUDED(OP_divide) - DECLARE_BROADCASTABLE_OP(divide, 0, 0); - DECLARE_CUSTOM_OP(divide_bp, 3, 2, false, 0, 0); - #endif - - /** - * This is one of auto-broadcastable operations. It accepts 2 operands, and operation is applied based on their shapes: - * 1) if shapes are equal that's pairwise operation, result will have the same shape. - * 2) if shape X is scalar and shape Y is array - result will have shape equal to Y. - * 3) if shape X is array and shape Y is scalar - result will have shape equal to X. - * 4) if shape X and Y are both arrays, but shapes aren't equal - result shape will be broadcast result. - * - * This operation returns Z = Divide(X, Y) with exception, 0 if Y = 0 - */ - #if NOT_EXCLUDED(OP_divide_no_nan) - DECLARE_BROADCASTABLE_OP(divide_no_nan, 0, 0); - #endif - /** - * This is one of auto-broadcastable operations. It accepts 2 operands, and operation is applied based on their shapes: - * 1) if shapes are equal that's pairwise operation, result will have the same shape. - * 2) if shape X is scalar and shape Y is array - result will have shape equal to Y. - * 3) if shape X is array and shape Y is scalar - result will have shape equal to X. - * 4) if shape X and Y are both arrays, but shapes aren't equal - result shape will be broadcast result. - * - * This operation returns Z = Divide(Y, x) - */ - #if NOT_EXCLUDED(OP_reversedivide) - DECLARE_BROADCASTABLE_OP(reversedivide, 0, 0); - DECLARE_CUSTOM_OP(reversedivide_bp, 3, 2, false, 0, 0); - #endif - - /** - * This is one of auto-broadcastable operations. It accepts 2 operands, and operation is applied based on their shapes: - * 1) if shapes are equal that's pairwise operation, result will have the same shape. - * 2) if shape X is scalar and shape Y is array - result will have shape equal to Y. - * 3) if shape X is array and shape Y is scalar - result will have shape equal to X. - * 4) if shape X and Y are both arrays, but shapes aren't equal - result shape will be broadcast result. - * - * This operation returns Z = FloorMod(X, Y) - */ - #if NOT_EXCLUDED(OP_floormod) - DECLARE_BROADCASTABLE_OP(floormod, 0, 0); - DECLARE_CUSTOM_OP(floormod_bp, 3, 2, true, 0, 0); - #endif - - #if NOT_EXCLUDED(OP_mod) - DECLARE_BROADCASTABLE_OP(mod, 0, 0); - DECLARE_CUSTOM_OP(mod_bp, 3, 2, true, 0, 0); - #endif - - /** - * This is one of auto-broadcastable operations. It accepts 2 operands, and operation is applied based on their shapes: - * 1) if shapes are equal that's pairwise operation, result will have the same shape. - * 2) if shape X is scalar and shape Y is array - result will have shape equal to Y. - * 3) if shape X is array and shape Y is scalar - result will have shape equal to X. - * 4) if shape X and Y are both arrays, but shapes aren't equal - result shape will be broadcast result. - * - * This operation returns Z = FloorDiv(X, Y) - */ - #if NOT_EXCLUDED(OP_floordiv) - DECLARE_BROADCASTABLE_OP(floordiv, 0, 0) - DECLARE_CUSTOM_OP(floordiv_bp, 2, 1, true, 0, 0) - #endif - - /** - * This is one of auto-broadcastable operations. It accepts 2 operands, and operation is applied based on their shapes: - * 1) if shapes are equal that's pairwise operation, result will have the same shape. - * 2) if shape X is scalar and shape Y is array - result will have shape equal to Y. - * 3) if shape X is array and shape Y is scalar - result will have shape equal to X. - * 4) if shape X and Y are both arrays, but shapes aren't equal - result shape will be broadcast result. - * - * This operation returns Z = Divide(X, Y) - */ - #if NOT_EXCLUDED(OP_realdiv) - DECLARE_BROADCASTABLE_OP(realdiv, 0, 0); - DECLARE_CUSTOM_OP(realdiv_bp, 3, 2, false, 0, 0); - #endif - - - /** - * - * - * @tparam T - */ - DECLARE_BROADCASTABLE_OP(truncatediv, 0, 0); - - /** - * This is one of auto-broadcastable operations. It accepts 2 operands, and operation is applied based on their shapes: - * 1) if shapes are equal that's pairwise operation, result will have the same shape. - * 2) if shape X is scalar and shape Y is array - result will have shape equal to Y. - * 3) if shape X is array and shape Y is scalar - result will have shape equal to X. - * 4) if shape X and Y are both arrays, but shapes aren't equal - result shape will be broadcast result. - * - * This operation returns Z = Assign(X, Y) - */ - #if NOT_EXCLUDED(OP_assign) - DECLARE_BROADCASTABLE_OP(assign, 0, 0); - DECLARE_CUSTOM_OP(assign_bp, 3, 2, false, 0, 0); - #endif - - #if NOT_EXCLUDED(OP_meshgrid) - DECLARE_CUSTOM_OP(meshgrid, -1, -1, false, 0, 0); - #endif - - /** - * This op takes 2 equally shaped arrays as input, and provides binary matrix as output. - * Math is: _x == _y ? (T) 1.0f : (T) 0.0f; - * - */ - #if NOT_EXCLUDED(OP_equals) - DECLARE_BROADCASTABLE_BOOL_OP(equals, 0, 0); - #endif - - /** - * This op takes 2 equally shaped arrays as input, and provides binary matrix as output. - * Math is: _x != _y ? (T) 1.0f : (T) 0.0f; - */ - #if NOT_EXCLUDED(OP_not_equals) - DECLARE_BROADCASTABLE_BOOL_OP(not_equals, 0, 0); - #endif - - /** - * This op takes 2 equally shaped arrays as input, and provides binary matrix as output. - * Math is: _x <= _y ? (T) 1.0f : (T) 0.0f; - */ - #if NOT_EXCLUDED(OP_less_equal) - DECLARE_BROADCASTABLE_BOOL_OP(less_equal, 0, 0); - #endif - - /** - * This op takes 2 equally shaped arrays as input, and provides binary matrix as output. - * Math is: _x >= _y ? (T) 1.0f : (T) 0.0f; - */ - #if NOT_EXCLUDED(OP_greater_equal) - DECLARE_BROADCASTABLE_BOOL_OP(greater_equal, 0, 0); - #endif - - /** - * This op takes 2 equally shaped arrays as input, and provides binary matrix as output. - * Math is: _x < _y ? (T) 1.0f : (T) 0.0f; - */ - #if NOT_EXCLUDED(OP_less) - DECLARE_BROADCASTABLE_BOOL_OP(less, 0, 0); - #endif - - /** - * This op takes 2 equally shaped arrays as input, and provides binary matrix as output. - * Math is: _x > _y ? (T) 1.0f : (T) 0.0f; - */ - #if NOT_EXCLUDED(OP_greater) - DECLARE_BROADCASTABLE_BOOL_OP(greater, 0, 0); - #endif - - /** - * - */ - #if NOT_EXCLUDED(OP_boolean_and) - DECLARE_BROADCASTABLE_OP(boolean_and, 0, 0); - #endif - - /** - * - */ - #if NOT_EXCLUDED(OP_boolean_or) - DECLARE_BROADCASTABLE_OP(boolean_or, 0, 0); - #endif - - /** - * - */ - #if NOT_EXCLUDED(OP_boolean_xor) - DECLARE_BROADCASTABLE_OP(boolean_xor, 0, 0); - #endif - - /** - * This operation performs calculation of percentile of input array along given axises - * - * Input - tensor with rank N > 0 - * Output - tensor with rank (N - length(axis)) or scalar if number of Integer arguments is zero - * Float arguments: - * 0: percentile (scalar) in range [0,100] (inclusively) - * 1: interpolation (optional), possible values are 0-"lower", 1-"higher", 2-"nearest"(default) - * 2: keepDims (optional), if it is non zero, then unities are kept in reduced resulting shape of output array, default is 0 - * Integer arguments - axis - the sequence of axises to calculate percentile along, if sequence is empty then calculate percentile for whole input tensor and return result as scalar - * - */ - #if NOT_EXCLUDED(OP_percentile) - DECLARE_CUSTOM_OP(percentile, 1, 1, false, 1, -2); - #endif - - - /** - * Special atan2 op impl for TF's args order - * @tparam T - */ - #if NOT_EXCLUDED(OP_tf_atan2) - DECLARE_BROADCASTABLE_OP(tf_atan2, 0, 0); - #endif - - /** - * Broadcastable pow implementation - * @tparam T - */ - #if NOT_EXCLUDED(OP_Pow) - DECLARE_BROADCASTABLE_OP(Pow, 0, 0); - DECLARE_CUSTOM_OP(Pow_bp, 3, 2, false, 0, 0); - #endif - - /** - * Broadcastable igamma implementation - * - * igamma(a, x) = gamma(а, x) / Gamma(a) - Gamma distribution function P(a,x) - * Gamma(a) = int from 0 to infinity { t ^ {a - 1} e^{-t}dt } - * gamma(a, x) = int from 0 to x { t ^ {a - 1} e^{-t}dt } - * @tparam T - */ - #if NOT_EXCLUDED(OP_igamma) - DECLARE_BROADCASTABLE_OP(igamma, 0, 0); - #endif - /** - * Broadcastable igammac implementation - * igammac(a, x) = Gamma(a,x)/Gamma(а) - Gamma distribution function Q(a,x) - * Gamma(a) = int from 0 to infinity { t ^ {a - 1} e^{-t}dt } - * Gamma(a, x) = int from x to infinity { t ^ {a - 1} e^{-t}dt } - * @tparam T - */ - #if NOT_EXCLUDED(OP_igammac) - DECLARE_BROADCASTABLE_OP(igammac, 0, 0); - #endif - } -} +namespace ops { +// TODO: make broadcastables separate class + +/** + * This is one of auto-broadcastable operations. It accepts 2 operands, and + * operation is applied based on their shapes: 1) if shapes are equal that's + * pairwise operation, result will have the same shape. 2) if shape X is scalar + * and shape Y is array - result will have shape equal to Y. 3) if shape X is + * array and shape Y is scalar - result will have shape equal to X. 4) if shape + * X and Y are both arrays, but shapes aren't equal - result shape will be + * broadcast result. + * + * This operation returns Z = Max(X, Y) + */ +#if NOT_EXCLUDED(OP_maximum) +DECLARE_BROADCASTABLE_OP(maximum, 0, 0); +DECLARE_CUSTOM_OP(maximum_bp, 3, 2, false, 0, 0); +#endif + +/** + * This is one of auto-broadcastable operations. It accepts 2 operands, and + * operation is applied based on their shapes: 1) if shapes are equal that's + * pairwise operation, result will have the same shape. 2) if shape X is scalar + * and shape Y is array - result will have shape equal to Y. 3) if shape X is + * array and shape Y is scalar - result will have shape equal to X. 4) if shape + * X and Y are both arrays, but shapes aren't equal - result shape will be + * broadcast result. + * + * This operation returns Z = Min(X, Y) + */ +#if NOT_EXCLUDED(OP_minimum) +DECLARE_BROADCASTABLE_OP(minimum, 0, 0); +DECLARE_CUSTOM_OP(minimum_bp, 3, 2, false, 0, 0); +#endif + +/** + * This is one of auto-broadcastable operations. It accepts 2 operands, and + * operation is applied based on their shapes: 1) if shapes are equal that's + * pairwise operation, result will have the same shape. 2) if shape X is scalar + * and shape Y is array - result will have shape equal to Y. 3) if shape X is + * array and shape Y is scalar - result will have shape equal to X. 4) if shape + * X and Y are both arrays, but shapes aren't equal - result shape will be + * broadcast result. + * + * This operation returns Z = Add(X, Y) + */ +#if NOT_EXCLUDED(OP_add) +DECLARE_BROADCASTABLE_OP(add, 0, 0); +DECLARE_CUSTOM_OP(add_bp, 3, 2, false, 0, 0); +#endif + +/** + * This is one of auto-broadcastable operations. It accepts 2 operands, and + * operation is applied based on their shapes: 1) if shapes are equal that's + * pairwise operation, result will have the same shape. 2) if shape X is scalar + * and shape Y is array - result will have shape equal to Y. 3) if shape X is + * array and shape Y is scalar - result will have shape equal to X. 4) if shape + * X and Y are both arrays, but shapes aren't equal - result shape will be + * broadcast result. + * + * This operation returns Z = Subtract(X, Y) + */ +#if NOT_EXCLUDED(OP_subtract) +DECLARE_BROADCASTABLE_OP(subtract, 0, 0); +DECLARE_CUSTOM_OP(subtract_bp, 3, 2, false, 0, 0); +#endif + +/** + * This is one of auto-broadcastable operations. It accepts 2 operands, and + * operation is applied based on their shapes: 1) if shapes are equal that's + * pairwise operation, result will have the same shape. 2) if shape X is scalar + * and shape Y is array - result will have shape equal to Y. 3) if shape X is + * array and shape Y is scalar - result will have shape equal to X. 4) if shape + * X and Y are both arrays, but shapes aren't equal - result shape will be + * broadcast result. + * + * This operation returns Z = Subtract(Y, X) + */ +#if NOT_EXCLUDED(OP_reversesubtract) +DECLARE_BROADCASTABLE_OP(reversesubtract, 0, 0); +DECLARE_CUSTOM_OP(reversesubtract_bp, 3, 2, false, 0, 0); +#endif + +/** + * This is one of auto-broadcastable operations. It accepts 2 operands, and + * operation is applied based on their shapes: 1) if shapes are equal that's + * pairwise operation, result will have the same shape. 2) if shape X is scalar + * and shape Y is array - result will have shape equal to Y. 3) if shape X is + * array and shape Y is scalar - result will have shape equal to X. 4) if shape + * X and Y are both arrays, but shapes aren't equal - result shape will be + * broadcast result. + * + * This operation returns Z = ReverseMod(X, Y) == Mod(Y, X) + */ +#if NOT_EXCLUDED(OP_reversemod) +DECLARE_BROADCASTABLE_OP(reversemod, 0, 0); +DECLARE_CUSTOM_OP(reversemod_bp, 3, 2, true, 0, 0); +#endif + +/** + * This is one of auto-broadcastable operations. It accepts 2 operands, and + * operation is applied based on their shapes: 1) if shapes are equal that's + * pairwise operation, result will have the same shape. 2) if shape X is scalar + * and shape Y is array - result will have shape equal to Y. 3) if shape X is + * array and shape Y is scalar - result will have shape equal to X. 4) if shape + * X and Y are both arrays, but shapes aren't equal - result shape will be + * broadcast result. + * + * This operation returns Z = Subtract(X, Y) * Subtract(X, Y) + */ +#if NOT_EXCLUDED(OP_squaredsubtract) +DECLARE_BROADCASTABLE_OP(squaredsubtract, 0, 0) +DECLARE_CUSTOM_OP(squaredsubtract_bp, 3, 2, false, 0, 0); +#endif + +/** + * This is one of auto-broadcastable operations. It accepts 2 operands, and + * operation is applied based on their shapes: 1) if shapes are equal that's + * pairwise operation, result will have the same shape. 2) if shape X is scalar + * and shape Y is array - result will have shape equal to Y. 3) if shape X is + * array and shape Y is scalar - result will have shape equal to X. 4) if shape + * X and Y are both arrays, but shapes aren't equal - result shape will be + * broadcast result. + * + * This operation returns Z = Multiply(X, Y) + */ +#if NOT_EXCLUDED(OP_multiply) +DECLARE_BROADCASTABLE_OP(multiply, 0, 0); +DECLARE_CUSTOM_OP(multiply_bp, 3, 2, false, 0, 0); +#endif + +/** + * This is one of auto-broadcastable operations. It accepts 2 operands, and + * operation is applied based on their shapes: 1) if shapes are equal that's + * pairwise operation, result will have the same shape. 2) if shape X is scalar + * and shape Y is array - result will have shape equal to Y. 3) if shape X is + * array and shape Y is scalar - result will have shape equal to X. 4) if shape + * X and Y are both arrays, but shapes aren't equal - result shape will be + * broadcast result. + * + * This operation returns Z = Divide(X, Y) + */ +#if NOT_EXCLUDED(OP_divide) +DECLARE_BROADCASTABLE_OP(divide, 0, 0); +DECLARE_CUSTOM_OP(divide_bp, 3, 2, false, 0, 0); +#endif + +/** + * This is one of auto-broadcastable operations. It accepts 2 operands, and + * operation is applied based on their shapes: 1) if shapes are equal that's + * pairwise operation, result will have the same shape. 2) if shape X is scalar + * and shape Y is array - result will have shape equal to Y. 3) if shape X is + * array and shape Y is scalar - result will have shape equal to X. 4) if shape + * X and Y are both arrays, but shapes aren't equal - result shape will be + * broadcast result. + * + * This operation returns Z = Divide(X, Y) with exception, 0 if Y = 0 + */ +#if NOT_EXCLUDED(OP_divide_no_nan) +DECLARE_BROADCASTABLE_OP(divide_no_nan, 0, 0); +#endif +/** + * This is one of auto-broadcastable operations. It accepts 2 operands, and + * operation is applied based on their shapes: 1) if shapes are equal that's + * pairwise operation, result will have the same shape. 2) if shape X is scalar + * and shape Y is array - result will have shape equal to Y. 3) if shape X is + * array and shape Y is scalar - result will have shape equal to X. 4) if shape + * X and Y are both arrays, but shapes aren't equal - result shape will be + * broadcast result. + * + * This operation returns Z = Divide(Y, x) + */ +#if NOT_EXCLUDED(OP_reversedivide) +DECLARE_BROADCASTABLE_OP(reversedivide, 0, 0); +DECLARE_CUSTOM_OP(reversedivide_bp, 3, 2, false, 0, 0); +#endif + +/** + * This is one of auto-broadcastable operations. It accepts 2 operands, and + * operation is applied based on their shapes: 1) if shapes are equal that's + * pairwise operation, result will have the same shape. 2) if shape X is scalar + * and shape Y is array - result will have shape equal to Y. 3) if shape X is + * array and shape Y is scalar - result will have shape equal to X. 4) if shape + * X and Y are both arrays, but shapes aren't equal - result shape will be + * broadcast result. + * + * This operation returns Z = FloorMod(X, Y) + */ +#if NOT_EXCLUDED(OP_floormod) +DECLARE_BROADCASTABLE_OP(floormod, 0, 0); +DECLARE_CUSTOM_OP(floormod_bp, 3, 2, true, 0, 0); +#endif + +#if NOT_EXCLUDED(OP_mod) +DECLARE_BROADCASTABLE_OP(mod, 0, 0); +DECLARE_CUSTOM_OP(mod_bp, 3, 2, true, 0, 0); +#endif + +/** + * This is one of auto-broadcastable operations. It accepts 2 operands, and + * operation is applied based on their shapes: 1) if shapes are equal that's + * pairwise operation, result will have the same shape. 2) if shape X is scalar + * and shape Y is array - result will have shape equal to Y. 3) if shape X is + * array and shape Y is scalar - result will have shape equal to X. 4) if shape + * X and Y are both arrays, but shapes aren't equal - result shape will be + * broadcast result. + * + * This operation returns Z = FloorDiv(X, Y) + */ +#if NOT_EXCLUDED(OP_floordiv) +DECLARE_BROADCASTABLE_OP(floordiv, 0, 0) +DECLARE_CUSTOM_OP(floordiv_bp, 2, 1, true, 0, 0) +#endif + +/** + * This is one of auto-broadcastable operations. It accepts 2 operands, and + * operation is applied based on their shapes: 1) if shapes are equal that's + * pairwise operation, result will have the same shape. 2) if shape X is scalar + * and shape Y is array - result will have shape equal to Y. 3) if shape X is + * array and shape Y is scalar - result will have shape equal to X. 4) if shape + * X and Y are both arrays, but shapes aren't equal - result shape will be + * broadcast result. + * + * This operation returns Z = Divide(X, Y) + */ +#if NOT_EXCLUDED(OP_realdiv) +DECLARE_BROADCASTABLE_OP(realdiv, 0, 0); +DECLARE_CUSTOM_OP(realdiv_bp, 3, 2, false, 0, 0); +#endif + +/** + * + * + * @tparam T + */ +DECLARE_BROADCASTABLE_OP(truncatediv, 0, 0); + +/** + * This is one of auto-broadcastable operations. It accepts 2 operands, and + * operation is applied based on their shapes: 1) if shapes are equal that's + * pairwise operation, result will have the same shape. 2) if shape X is scalar + * and shape Y is array - result will have shape equal to Y. 3) if shape X is + * array and shape Y is scalar - result will have shape equal to X. 4) if shape + * X and Y are both arrays, but shapes aren't equal - result shape will be + * broadcast result. + * + * This operation returns Z = Assign(X, Y) + */ +#if NOT_EXCLUDED(OP_assign) +DECLARE_BROADCASTABLE_OP(assign, 0, 0); +DECLARE_CUSTOM_OP(assign_bp, 3, 2, false, 0, 0); +#endif + +#if NOT_EXCLUDED(OP_meshgrid) +DECLARE_CUSTOM_OP(meshgrid, -1, -1, false, 0, 0); +#endif + +/** + * This op takes 2 equally shaped arrays as input, and provides binary matrix as + * output. Math is: _x == _y ? (T) 1.0f : (T) 0.0f; + * + */ +#if NOT_EXCLUDED(OP_equals) +DECLARE_BROADCASTABLE_BOOL_OP(equals, 0, 0); +#endif + +/** + * This op takes 2 equally shaped arrays as input, and provides binary matrix as + * output. Math is: _x != _y ? (T) 1.0f : (T) 0.0f; + */ +#if NOT_EXCLUDED(OP_not_equals) +DECLARE_BROADCASTABLE_BOOL_OP(not_equals, 0, 0); +#endif + +/** + * This op takes 2 equally shaped arrays as input, and provides binary matrix as + * output. Math is: _x <= _y ? (T) 1.0f : (T) 0.0f; + */ +#if NOT_EXCLUDED(OP_less_equal) +DECLARE_BROADCASTABLE_BOOL_OP(less_equal, 0, 0); +#endif + +/** + * This op takes 2 equally shaped arrays as input, and provides binary matrix as + * output. Math is: _x >= _y ? (T) 1.0f : (T) 0.0f; + */ +#if NOT_EXCLUDED(OP_greater_equal) +DECLARE_BROADCASTABLE_BOOL_OP(greater_equal, 0, 0); +#endif + +/** + * This op takes 2 equally shaped arrays as input, and provides binary matrix as + * output. Math is: _x < _y ? (T) 1.0f : (T) 0.0f; + */ +#if NOT_EXCLUDED(OP_less) +DECLARE_BROADCASTABLE_BOOL_OP(less, 0, 0); +#endif + +/** + * This op takes 2 equally shaped arrays as input, and provides binary matrix as + * output. Math is: _x > _y ? (T) 1.0f : (T) 0.0f; + */ +#if NOT_EXCLUDED(OP_greater) +DECLARE_BROADCASTABLE_BOOL_OP(greater, 0, 0); +#endif + +/** + * + */ +#if NOT_EXCLUDED(OP_boolean_and) +DECLARE_BROADCASTABLE_OP(boolean_and, 0, 0); +#endif + +/** + * + */ +#if NOT_EXCLUDED(OP_boolean_or) +DECLARE_BROADCASTABLE_OP(boolean_or, 0, 0); +#endif + +/** + * + */ +#if NOT_EXCLUDED(OP_boolean_xor) +DECLARE_BROADCASTABLE_OP(boolean_xor, 0, 0); +#endif + +/** + * This operation performs calculation of percentile of input array along given + * axises + * + * Input - tensor with rank N > 0 + * Output - tensor with rank (N - length(axis)) or scalar if number of Integer + * arguments is zero Float arguments: 0: percentile (scalar) in range [0,100] + * (inclusively) 1: interpolation (optional), possible values are 0-"lower", + * 1-"higher", 2-"nearest"(default) 2: keepDims (optional), if it is non zero, + * then unities are kept in reduced resulting shape of output array, default is + * 0 Integer arguments - axis - the sequence of axises to calculate percentile + * along, if sequence is empty then calculate percentile for whole input tensor + * and return result as scalar + * + */ +#if NOT_EXCLUDED(OP_percentile) +DECLARE_CUSTOM_OP(percentile, 1, 1, false, 1, -2); +#endif + +/** + * Special atan2 op impl for TF's args order + * @tparam T + */ +#if NOT_EXCLUDED(OP_tf_atan2) +DECLARE_BROADCASTABLE_OP(tf_atan2, 0, 0); +#endif + +/** + * Broadcastable pow implementation + * @tparam T + */ +#if NOT_EXCLUDED(OP_Pow) +DECLARE_BROADCASTABLE_OP(Pow, 0, 0); +DECLARE_CUSTOM_OP(Pow_bp, 3, 2, false, 0, 0); +#endif + +/** + * Broadcastable igamma implementation + * + * igamma(a, x) = gamma(а, x) / Gamma(a) - Gamma distribution function P(a,x) + * Gamma(a) = int from 0 to infinity { t ^ {a - 1} e^{-t}dt } + * gamma(a, x) = int from 0 to x { t ^ {a - 1} e^{-t}dt } + * @tparam T + */ +#if NOT_EXCLUDED(OP_igamma) +DECLARE_BROADCASTABLE_OP(igamma, 0, 0); +#endif +/** + * Broadcastable igammac implementation + * igammac(a, x) = Gamma(a,x)/Gamma(а) - Gamma distribution function Q(a,x) + * Gamma(a) = int from 0 to infinity { t ^ {a - 1} e^{-t}dt } + * Gamma(a, x) = int from x to infinity { t ^ {a - 1} e^{-t}dt } + * @tparam T + */ +#if NOT_EXCLUDED(OP_igammac) +DECLARE_BROADCASTABLE_OP(igammac, 0, 0); +#endif +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/headers/common.h b/libnd4j/include/ops/declarable/headers/common.h index d70e6beb8eef..bde901a61049 100644 --- a/libnd4j/include/ops/declarable/headers/common.h +++ b/libnd4j/include/ops/declarable/headers/common.h @@ -21,22 +21,25 @@ #ifndef LIBND4J_OPS_DECLARABLE_COMMON_H #define LIBND4J_OPS_DECLARABLE_COMMON_H -#include -#include -#include #include +#include #include -#include +#include +#include +#include #include -#include -#include +#include +#include #include #include +#include +#include +#include #include -#include -#include -#include -#include +#include +#include + +#include #include #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/headers/compat.h b/libnd4j/include/ops/declarable/headers/compat.h index 37894517a0b9..1ec0e332b258 100644 --- a/libnd4j/include/ops/declarable/headers/compat.h +++ b/libnd4j/include/ops/declarable/headers/compat.h @@ -24,31 +24,30 @@ #include namespace sd { - namespace ops { - /** - * This operation splits input string into pieces separated by delimiter - * PLEASE NOTE: This implementation is compatible with TF 1.x - * - * Input[0] - string to split - * Input[1] - delimiter - * - * Returns: - * Output[0] - indices tensor - * Output[1] - values tensor - */ - #if NOT_EXCLUDED(OP_compat_string_split) - DECLARE_CUSTOM_OP(compat_string_split, 2, 2, false, 0, 0); - #endif - - /** - * This operation converts TF sparse array representation to dense NDArray - */ - #if NOT_EXCLUDED(OP_compat_sparse_to_dense) - DECLARE_CUSTOM_OP(compat_sparse_to_dense, 4, 1, false, 0, 0); - #endif - - } -} - - -#endif //SAMEDIFF_COMPAT_H +namespace ops { +/** + * This operation splits input string into pieces separated by delimiter + * PLEASE NOTE: This implementation is compatible with TF 1.x + * + * Input[0] - string to split + * Input[1] - delimiter + * + * Returns: + * Output[0] - indices tensor + * Output[1] - values tensor + */ +#if NOT_EXCLUDED(OP_compat_string_split) +DECLARE_CUSTOM_OP(compat_string_split, 2, 2, false, 0, 0); +#endif + +/** + * This operation converts TF sparse array representation to dense NDArray + */ +#if NOT_EXCLUDED(OP_compat_sparse_to_dense) +DECLARE_CUSTOM_OP(compat_sparse_to_dense, 4, 1, false, 0, 0); +#endif + +} // namespace ops +} // namespace sd + +#endif // SAMEDIFF_COMPAT_H diff --git a/libnd4j/include/ops/declarable/headers/compression.h b/libnd4j/include/ops/declarable/headers/compression.h index 9c177f8a4af3..3ffca17c3a30 100644 --- a/libnd4j/include/ops/declarable/headers/compression.h +++ b/libnd4j/include/ops/declarable/headers/compression.h @@ -24,39 +24,40 @@ #include namespace sd { - namespace ops { - - /** - * encode_bitmap - reinterpret 3D float tensor into uint8_t vector with length N. - * - * Input: - * 0 - 3D float tensor with shape {height, width, channels} - * - * Output: - * 0 - 1D uint8_t tensor with shape {N} - */ - #if NOT_EXCLUDED(OP_encode_bitmap) - DECLARE_CUSTOM_OP(encode_bitmap, 1, 3, true, 1, 0); - #endif +namespace ops { - /** - * decode_bitmap - reinterpret uint8_t linear tensor as data to float tensor with shape - * - * Input: - * 0 - uint8_t vector with length N ( shape {N}) - * - * Output: - * 0 - 3D tensor with shape {height, width, channels} - * - */ - #if NOT_EXCLUDED(OP_decode_bitmap) - DECLARE_CUSTOM_OP(decode_bitmap, 2, 1, true, 0, 0); - #endif +/** + * encode_bitmap - reinterpret 3D float tensor into uint8_t vector with length + * N. + * + * Input: + * 0 - 3D float tensor with shape {height, width, channels} + * + * Output: + * 0 - 1D uint8_t tensor with shape {N} + */ +#if NOT_EXCLUDED(OP_encode_bitmap) +DECLARE_CUSTOM_OP(encode_bitmap, 1, 3, true, 1, 0); +#endif +/** + * decode_bitmap - reinterpret uint8_t linear tensor as data to float tensor + * with shape + * + * Input: + * 0 - uint8_t vector with length N ( shape {N}) + * + * Output: + * 0 - 3D tensor with shape {height, width, channels} + * + */ +#if NOT_EXCLUDED(OP_decode_bitmap) +DECLARE_CUSTOM_OP(decode_bitmap, 2, 1, true, 0, 0); +#endif - DECLARE_CUSTOM_OP(encode_threshold, 2, 1, true, 1, 0); - DECLARE_CUSTOM_OP(decode_threshold, 2, 1, true, 0, 0); - } -} +DECLARE_CUSTOM_OP(encode_threshold, 2, 1, true, 1, 0); +DECLARE_CUSTOM_OP(decode_threshold, 2, 1, true, 0, 0); +} // namespace ops +} // namespace sd -#endif // SD_HEADERS_COMPRESSION_H \ No newline at end of file +#endif // SD_HEADERS_COMPRESSION_H \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/headers/convo.h b/libnd4j/include/ops/declarable/headers/convo.h index d00da07f23d4..b472209b7e28 100644 --- a/libnd4j/include/ops/declarable/headers/convo.h +++ b/libnd4j/include/ops/declarable/headers/convo.h @@ -24,307 +24,306 @@ #include namespace sd { - namespace ops { +namespace ops { - /** - * 1D temporal convolution implementation - * Expected input: - * x: 3D array - * weight: 3D Array - * bias: optional vector - * - * Int args: - * 0: kernel - * 1: stride - * 2: padding - */ - #if NOT_EXCLUDED(OP_conv1d) - DECLARE_CUSTOM_OP(conv1d, 2, 1, false, 0, 5); - DECLARE_CUSTOM_OP(conv1d_bp, 3, 2, false, 0, 5); - #endif - - /** - * 2D convolution implementation - * Expected input: - * x: 4D array - * weight: 4D Array - * bias: optional vector, length of outputChannels - * - * IntArgs: - * 0: kernel height - * 1: kernel width - * 2: stride height - * 3: stride width - * 4: padding height - * 5: padding width - * 6: dilation height - * 7: dilation width - * 8: same mode: 1 true, 0 false - * 9: data format: 1 NHWC, 0 NCHW - */ - #if NOT_EXCLUDED(OP_conv2d) - DECLARE_CUSTOM_OP(conv2d, 2, 1, false, 0, 9); - DECLARE_CUSTOM_OP(conv2d_bp, 3, 2, false, 0, 9); - DECLARE_CUSTOM_OP(conv2d_input_bp, 3, 1, false, 0, 9); - #endif - - /** - * Depthwise convolution2d op: - * Expected inputs: - * x: 4D array, NCHW format - * weightsDepth: 4D array, - * weightsPointwise: optional, 4D array - * bias: optional, vector - */ - #if NOT_EXCLUDED(OP_sconv2d) - DECLARE_CUSTOM_OP(sconv2d, 2, 1, false, 0, 9); - DECLARE_CUSTOM_OP(sconv2d_bp, 3, 2, false, 0, 9); - #endif - - /** - * 2D deconvolution implementation - * - * IntArgs: - * 0: kernel height - * 1: kernel width - * 2: stride height - * 3: stride width - * 4: padding height - * 5: padding width - * 6: dilation height - * 7: dilation width - * 8: same mode: 0 false, 1 true - */ - #if NOT_EXCLUDED(OP_deconv2d) - DECLARE_CUSTOM_OP(deconv2d, 2, 1, false, 0, 9); - DECLARE_CUSTOM_OP(deconv2d_bp, 3, 2, false, 0, 9); - #endif +/** + * 1D temporal convolution implementation + * Expected input: + * x: 3D array + * weight: 3D Array + * bias: optional vector + * + * Int args: + * 0: kernel + * 1: stride + * 2: padding + */ +#if NOT_EXCLUDED(OP_conv1d) +DECLARE_CUSTOM_OP(conv1d, 2, 1, false, 0, 5); +DECLARE_CUSTOM_OP(conv1d_bp, 3, 2, false, 0, 5); +#endif - /** - * 3D deconvolution implementation - * - * IntArgs: - * 0: filter(kernel) depth - * 1: filter(kernel) height - * 2: filter(kernel) width - * 3: strides depth - * 4: strides height - * 5: strides width - * 6: paddings depth - * 7: paddings height - * 8: paddings width - * 9: dilations depth - * 10: dilations height - * 11: dilations width - * 12: same mode: 0 false, 1 true - * 13: data format (optional): 0-NDHWC, 1-NCDHW, default is 1 - */ +/** + * 2D convolution implementation + * Expected input: + * x: 4D array + * weight: 4D Array + * bias: optional vector, length of outputChannels + * + * IntArgs: + * 0: kernel height + * 1: kernel width + * 2: stride height + * 3: stride width + * 4: padding height + * 5: padding width + * 6: dilation height + * 7: dilation width + * 8: same mode: 1 true, 0 false + * 9: data format: 1 NHWC, 0 NCHW + */ +#if NOT_EXCLUDED(OP_conv2d) +DECLARE_CUSTOM_OP(conv2d, 2, 1, false, 0, 9); +DECLARE_CUSTOM_OP(conv2d_bp, 3, 2, false, 0, 9); +DECLARE_CUSTOM_OP(conv2d_input_bp, 3, 1, false, 0, 9); +#endif - #if NOT_EXCLUDED(OP_deconv3d) - DECLARE_CUSTOM_OP(deconv3d, 2, 1, false, 0, 13); - DECLARE_CUSTOM_OP(deconv3d_bp, 3, 2, false, 0, 13); - #endif +/** + * Depthwise convolution2d op: + * Expected inputs: + * x: 4D array, NCHW format + * weightsDepth: 4D array, + * weightsPointwise: optional, 4D array + * bias: optional, vector + */ +#if NOT_EXCLUDED(OP_sconv2d) +DECLARE_CUSTOM_OP(sconv2d, 2, 1, false, 0, 9); +DECLARE_CUSTOM_OP(sconv2d_bp, 3, 2, false, 0, 9); +#endif +/** + * 2D deconvolution implementation + * + * IntArgs: + * 0: kernel height + * 1: kernel width + * 2: stride height + * 3: stride width + * 4: padding height + * 5: padding width + * 6: dilation height + * 7: dilation width + * 8: same mode: 0 false, 1 true + */ +#if NOT_EXCLUDED(OP_deconv2d) +DECLARE_CUSTOM_OP(deconv2d, 2, 1, false, 0, 9); +DECLARE_CUSTOM_OP(deconv2d_bp, 3, 2, false, 0, 9); +#endif - /** - * This op implements max pooling for convolution networks. - * Expected Input: 4D array, NCHW format. - * - * IntArgs: - * 0: kernel height - * 1: kernel width - * 2: stride height - * 3: stride width - * 4: padding height - * 5: padding width - * 6: dilation height - * 7: dilation width - * 8: same mode: 0 false, 1 true - */ - #if NOT_EXCLUDED(OP_maxpool2d) - DECLARE_CUSTOM_OP(maxpool2d, 1, 1, false, 0, 10); - DECLARE_CUSTOM_OP(maxpool2d_bp, 2, 1, false, 0, 10); - #endif +/** + * 3D deconvolution implementation + * + * IntArgs: + * 0: filter(kernel) depth + * 1: filter(kernel) height + * 2: filter(kernel) width + * 3: strides depth + * 4: strides height + * 5: strides width + * 6: paddings depth + * 7: paddings height + * 8: paddings width + * 9: dilations depth + * 10: dilations height + * 11: dilations width + * 12: same mode: 0 false, 1 true + * 13: data format (optional): 0-NDHWC, 1-NCDHW, default is 1 + */ - /** - * This op implements average pooling for convolution networks. - * Expected Input: 4D array, NCHW format. - * - * IntArgs: - * 0: kernel height - * 1: kernel width - * 2: stride height - * 3: stride width - * 4: padding height - * 5: padding width - * 6: dilation height - * 7: dilation width - * 8: same mode: 0 false, 1 true - */ - #if NOT_EXCLUDED(OP_avgpool2d) - DECLARE_CUSTOM_OP(avgpool2d, 1, 1, false, 0, 10); - DECLARE_CUSTOM_OP(avgpool2d_bp, 2, 1, false, 0, 10); - #endif +#if NOT_EXCLUDED(OP_deconv3d) +DECLARE_CUSTOM_OP(deconv3d, 2, 1, false, 0, 13); +DECLARE_CUSTOM_OP(deconv3d_bp, 3, 2, false, 0, 13); +#endif - /** - * This op implements pnorm pooling for convolution networks. - * Expected Input: 4D array, NCHW format. - * - * IntArgs: - * 0: kernel height - * 1: kernel width - * 2: stride height - * 3: stride width - * 4: padding height - * 5: padding width - * 6: dilation height - * 7: dilation width - * 8: same mode: 0 false, 1 true - * 9: p for p-norm - */ - #if NOT_EXCLUDED(OP_pnormpool2d) - DECLARE_CUSTOM_OP(pnormpool2d, 1, 1, false, 0, 10); - DECLARE_CUSTOM_OP(pnormpool2d_bp, 2, 1, false, 1, 10); - #endif +/** + * This op implements max pooling for convolution networks. + * Expected Input: 4D array, NCHW format. + * + * IntArgs: + * 0: kernel height + * 1: kernel width + * 2: stride height + * 3: stride width + * 4: padding height + * 5: padding width + * 6: dilation height + * 7: dilation width + * 8: same mode: 0 false, 1 true + */ +#if NOT_EXCLUDED(OP_maxpool2d) +DECLARE_CUSTOM_OP(maxpool2d, 1, 1, false, 0, 10); +DECLARE_CUSTOM_OP(maxpool2d_bp, 2, 1, false, 0, 10); +#endif - /** - * This op implements im2col algorithm, widely used in convolution neural networks - * Input: 4D input expected - * - * Int args: - * 0: kernel height - * 1: kernel width - * 2: stride height - * 3: stride width - * 4: padding height - * 5: padding width - * 6: dilation height - * 7: dilation width - * 8: isSameMode - */ - #if NOT_EXCLUDED(OP_im2col) - DECLARE_CUSTOM_OP(im2col, 1, 1, false, 0, 9); - DECLARE_CUSTOM_OP(im2col_bp, 2, 1, false, 0, 9); - #endif +/** + * This op implements average pooling for convolution networks. + * Expected Input: 4D array, NCHW format. + * + * IntArgs: + * 0: kernel height + * 1: kernel width + * 2: stride height + * 3: stride width + * 4: padding height + * 5: padding width + * 6: dilation height + * 7: dilation width + * 8: same mode: 0 false, 1 true + */ +#if NOT_EXCLUDED(OP_avgpool2d) +DECLARE_CUSTOM_OP(avgpool2d, 1, 1, false, 0, 10); +DECLARE_CUSTOM_OP(avgpool2d_bp, 2, 1, false, 0, 10); +#endif - /** - * This op implements col2im algorithm, widely used in convolution neural networks - * Input: 6D input expected (like output of im2col op) - * - * Int args: - * 0: stride height - * 1: stride width - * 2: padding height - * 3: padding width - * 4: image height - * 5: image width - * 6: dilation height - * 7: dilation width - */ - #if NOT_EXCLUDED(OP_col2im) - DECLARE_CUSTOM_OP(col2im, 1, 1, false, 0, 9); - #endif +/** + * This op implements pnorm pooling for convolution networks. + * Expected Input: 4D array, NCHW format. + * + * IntArgs: + * 0: kernel height + * 1: kernel width + * 2: stride height + * 3: stride width + * 4: padding height + * 5: padding width + * 6: dilation height + * 7: dilation width + * 8: same mode: 0 false, 1 true + * 9: p for p-norm + */ +#if NOT_EXCLUDED(OP_pnormpool2d) +DECLARE_CUSTOM_OP(pnormpool2d, 1, 1, false, 0, 10); +DECLARE_CUSTOM_OP(pnormpool2d_bp, 2, 1, false, 1, 10); +#endif - /** - * Expected input: 4D array - * - * IntArgs: - * 0: scale factor for rows (height) - * 1: scale factor for columns (width) - * 2: data format: 0 NHWC (default), 1 NCHW - */ - #if NOT_EXCLUDED(OP_upsampling2d) - DECLARE_CUSTOM_OP(upsampling2d, 1, 1, false, 0, 2); - DECLARE_CUSTOM_OP(upsampling2d_bp, 2, 1, false, 0, 0); - #endif +/** + * This op implements im2col algorithm, widely used in convolution neural + * networks Input: 4D input expected + * + * Int args: + * 0: kernel height + * 1: kernel width + * 2: stride height + * 3: stride width + * 4: padding height + * 5: padding width + * 6: dilation height + * 7: dilation width + * 8: isSameMode + */ +#if NOT_EXCLUDED(OP_im2col) +DECLARE_CUSTOM_OP(im2col, 1, 1, false, 0, 9); +DECLARE_CUSTOM_OP(im2col_bp, 2, 1, false, 0, 9); +#endif - /** - * Expected input: 4D array - * - * IntArgs: - * 0: scale factor for depth - * 1: scale factor for rows (height) - * 2: scale factor for columns (width) - * 3: data format: 0 NDHWC (default), 1 NCDHW - */ - #if NOT_EXCLUDED(OP_upsampling3d) - DECLARE_CUSTOM_OP(upsampling3d, 1, 1, false, 0, 3); - DECLARE_CUSTOM_OP(upsampling3d_bp, 2, 1, false, 0, 0); - #endif +/** + * This op implements col2im algorithm, widely used in convolution neural + * networks Input: 6D input expected (like output of im2col op) + * + * Int args: + * 0: stride height + * 1: stride width + * 2: padding height + * 3: padding width + * 4: image height + * 5: image width + * 6: dilation height + * 7: dilation width + */ +#if NOT_EXCLUDED(OP_col2im) +DECLARE_CUSTOM_OP(col2im, 1, 1, false, 0, 9); +#endif - /** - * This op produces binary matrix wrt to target dimension. - * Maximum value within each TAD is replaced with 1, other values are set to true. - * - * Int args: - * 0: axis - */ - #if NOT_EXCLUDED(OP_ismax) - DECLARE_CONFIGURABLE_OP(ismax, 1, 1, true, 0, -2); - #endif +/** + * Expected input: 4D array + * + * IntArgs: + * 0: scale factor for rows (height) + * 1: scale factor for columns (width) + * 2: data format: 0 NHWC (default), 1 NCHW + */ +#if NOT_EXCLUDED(OP_upsampling2d) +DECLARE_CUSTOM_OP(upsampling2d, 1, 1, false, 0, 2); +DECLARE_CUSTOM_OP(upsampling2d_bp, 2, 1, false, 0, 0); +#endif - /** - * Dilation2D op - * - * Int args: - * 0: isSameMode - */ - #if NOT_EXCLUDED(OP_dilation2d) - DECLARE_CUSTOM_OP(dilation2d, 2, 1, false, 0, 1); - #endif +/** + * Expected input: 4D array + * + * IntArgs: + * 0: scale factor for depth + * 1: scale factor for rows (height) + * 2: scale factor for columns (width) + * 3: data format: 0 NDHWC (default), 1 NCDHW + */ +#if NOT_EXCLUDED(OP_upsampling3d) +DECLARE_CUSTOM_OP(upsampling3d, 1, 1, false, 0, 3); +DECLARE_CUSTOM_OP(upsampling3d_bp, 2, 1, false, 0, 0); +#endif - #if NOT_EXCLUDED(OP_conv3dnew) - DECLARE_CUSTOM_OP(conv3dnew, 2, 1, false, 0, 13); - DECLARE_CUSTOM_OP(conv3dnew_bp, 3, 2, false, 0, 13); - #endif +/** + * This op produces binary matrix wrt to target dimension. + * Maximum value within each TAD is replaced with 1, other values are set to + * true. + * + * Int args: + * 0: axis + */ +#if NOT_EXCLUDED(OP_ismax) +DECLARE_CONFIGURABLE_OP(ismax, 1, 1, true, 0, -2); +#endif - #if NOT_EXCLUDED(OP_avgpool3dnew) - DECLARE_CUSTOM_OP(avgpool3dnew, 1, 1, false, 0, 10); - DECLARE_CUSTOM_OP(avgpool3dnew_bp, 2, 1, false, 0, 10); - #endif +/** + * Dilation2D op + * + * Int args: + * 0: isSameMode + */ +#if NOT_EXCLUDED(OP_dilation2d) +DECLARE_CUSTOM_OP(dilation2d, 2, 1, false, 0, 1); +#endif - #if NOT_EXCLUDED(OP_maxpool3dnew) - DECLARE_CUSTOM_OP(maxpool3dnew, 1, 1, false, 0, 10); - DECLARE_CUSTOM_OP(maxpool3dnew_bp, 2, 1, false, 0, 10); - #endif +#if NOT_EXCLUDED(OP_conv3dnew) +DECLARE_CUSTOM_OP(conv3dnew, 2, 1, false, 0, 13); +DECLARE_CUSTOM_OP(conv3dnew_bp, 3, 2, false, 0, 13); +#endif - /** - * This op same as maxpool2d with a variant to return a matrix of indexes for max values - * - * Input - 4D tensor - * Output: - * 0 - 4D tensor as input - * 1 - 4D tensor with max value indexes - * - * Int params: - * 9 int with 2x4 vectors and 1 bool value - */ - #if NOT_EXCLUDED(OP_max_pool_woth_argmax) - DECLARE_CUSTOM_OP(max_pool_with_argmax, 1, 2, false, 0, 9); - #endif +#if NOT_EXCLUDED(OP_avgpool3dnew) +DECLARE_CUSTOM_OP(avgpool3dnew, 1, 1, false, 0, 10); +DECLARE_CUSTOM_OP(avgpool3dnew_bp, 2, 1, false, 0, 10); +#endif +#if NOT_EXCLUDED(OP_maxpool3dnew) +DECLARE_CUSTOM_OP(maxpool3dnew, 1, 1, false, 0, 10); +DECLARE_CUSTOM_OP(maxpool3dnew_bp, 2, 1, false, 0, 10); +#endif - #if NOT_EXCLUDED(OP_depthwise_conv2d) - DECLARE_CUSTOM_OP(depthwise_conv2d, 2, 1, false, 0, 9); - DECLARE_CUSTOM_OP(depthwise_conv2d_bp, 3, 2, false, 0, 9); - #endif +/** + * This op same as maxpool2d with a variant to return a matrix of indexes for + * max values + * + * Input - 4D tensor + * Output: + * 0 - 4D tensor as input + * 1 - 4D tensor with max value indexes + * + * Int params: + * 9 int with 2x4 vectors and 1 bool value + */ +#if NOT_EXCLUDED(OP_max_pool_woth_argmax) +DECLARE_CUSTOM_OP(max_pool_with_argmax, 1, 2, false, 0, 9); +#endif - /** - * point-wise 2D convolution - * Expected input: - * x: 4D array - * weight: 4D Array [1, 1, iC, oC] (NHWC) or [oC, iC, 1, 1] (NCHW) - * bias: optional vector, length of oC - * - * IntArgs: - * 0: data format: 1 NHWC, 0 NCHW (optional, by default = NHWC) - */ - DECLARE_CUSTOM_OP(pointwise_conv2d, 2, 1, false, 0, 0); +#if NOT_EXCLUDED(OP_depthwise_conv2d) +DECLARE_CUSTOM_OP(depthwise_conv2d, 2, 1, false, 0, 9); +DECLARE_CUSTOM_OP(depthwise_conv2d_bp, 3, 2, false, 0, 9); +#endif - DECLARE_CUSTOM_OP(deconv2d_tf, 2, 1, false, 0, 0); +/** + * point-wise 2D convolution + * Expected input: + * x: 4D array + * weight: 4D Array [1, 1, iC, oC] (NHWC) or [oC, iC, 1, 1] (NCHW) + * bias: optional vector, length of oC + * + * IntArgs: + * 0: data format: 1 NHWC, 0 NCHW (optional, by default = NHWC) + */ +DECLARE_CUSTOM_OP(pointwise_conv2d, 2, 1, false, 0, 0); - } -} +DECLARE_CUSTOM_OP(deconv2d_tf, 2, 1, false, 0, 0); +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/headers/datatypes.h b/libnd4j/include/ops/declarable/headers/datatypes.h index e467539199c7..643784ce5e13 100644 --- a/libnd4j/include/ops/declarable/headers/datatypes.h +++ b/libnd4j/include/ops/declarable/headers/datatypes.h @@ -23,91 +23,94 @@ #include namespace sd { - namespace ops { - /** - * This operation casts elements of input array to double data type - * - * PLEASE NOTE: This op is disabled atm, and reserved for future releases. - */ - #if NOT_EXCLUDED(OP_to_double) - DECLARE_CUSTOM_OP(to_double, 1, 1, true, 0, 0); - #endif +namespace ops { +/** + * This operation casts elements of input array to double data type + * + * PLEASE NOTE: This op is disabled atm, and reserved for future releases. + */ +#if NOT_EXCLUDED(OP_to_double) +DECLARE_CUSTOM_OP(to_double, 1, 1, true, 0, 0); +#endif - /** - * This operation casts elements of input array to float16 data type - * - * PLEASE NOTE: This op is disabled atm, and reserved for future releases. - */ - #if NOT_EXCLUDED(OP_to_float16) - DECLARE_CUSTOM_OP(to_float16, 1, 1, true, 0, 0); - #endif +/** + * This operation casts elements of input array to float16 data type + * + * PLEASE NOTE: This op is disabled atm, and reserved for future releases. + */ +#if NOT_EXCLUDED(OP_to_float16) +DECLARE_CUSTOM_OP(to_float16, 1, 1, true, 0, 0); +#endif - /** - * This operation casts elements of input array to float data type - * - * PLEASE NOTE: This op is disabled atm, and reserved for future releases. - */ - #if NOT_EXCLUDED(OP_to_float32) - DECLARE_CUSTOM_OP(to_float32, 1, 1, true, 0, 0); - #endif +/** + * This operation casts elements of input array to float data type + * + * PLEASE NOTE: This op is disabled atm, and reserved for future releases. + */ +#if NOT_EXCLUDED(OP_to_float32) +DECLARE_CUSTOM_OP(to_float32, 1, 1, true, 0, 0); +#endif - /** - * This operation casts elements of input array to int32 data type - * - * PLEASE NOTE: This op is disabled atm, and reserved for future releases. - */ - #if NOT_EXCLUDED(OP_to_int32) - DECLARE_CUSTOM_OP(to_int32, 1, 1, true, 0, 0); - #endif +/** + * This operation casts elements of input array to int32 data type + * + * PLEASE NOTE: This op is disabled atm, and reserved for future releases. + */ +#if NOT_EXCLUDED(OP_to_int32) +DECLARE_CUSTOM_OP(to_int32, 1, 1, true, 0, 0); +#endif - /** - * This operation casts elements of input array to int64 (aka long long) data type - * - * PLEASE NOTE: This op is disabled atm, and reserved for future releases. - */ - #if NOT_EXCLUDED(OP_to_int64) - DECLARE_CUSTOM_OP(to_int64, 1, 1, true, 0, 0); - #endif +/** + * This operation casts elements of input array to int64 (aka long long) data + * type + * + * PLEASE NOTE: This op is disabled atm, and reserved for future releases. + */ +#if NOT_EXCLUDED(OP_to_int64) +DECLARE_CUSTOM_OP(to_int64, 1, 1, true, 0, 0); +#endif - /** - * This operation casts elements of input array to unsinged int32 data type - * - * PLEASE NOTE: This op is disabled atm, and reserved for future releases. - */ - #if NOT_EXCLUDED(OP_to_uint32) - DECLARE_CUSTOM_OP(to_uint32, 1, 1, true, 0, 0); - #endif +/** + * This operation casts elements of input array to unsinged int32 data type + * + * PLEASE NOTE: This op is disabled atm, and reserved for future releases. + */ +#if NOT_EXCLUDED(OP_to_uint32) +DECLARE_CUSTOM_OP(to_uint32, 1, 1, true, 0, 0); +#endif - /** - * This operation casts elements of input array to unsigned int64 (aka unsigned long long) data type - * - * PLEASE NOTE: This op is disabled atm, and reserved for future releases. - */ - #if NOT_EXCLUDED(OP_to_uint64) - DECLARE_CUSTOM_OP(to_uint64, 1, 1, true, 0, 0); - #endif +/** + * This operation casts elements of input array to unsigned int64 (aka unsigned + * long long) data type + * + * PLEASE NOTE: This op is disabled atm, and reserved for future releases. + */ +#if NOT_EXCLUDED(OP_to_uint64) +DECLARE_CUSTOM_OP(to_uint64, 1, 1, true, 0, 0); +#endif - /** - * This operation casts elements of input array to specified data type - * - * PLEASE NOTE: This op is disabled atm, and reserved for future releases. - * - * - * Int args: - * 0: target DataType - */ - #if NOT_EXCLUDED(OP_cast) - DECLARE_CUSTOM_OP(cast, 1, 1, false, 0, 1); - #endif - /** - * This operation change type of input and modified shape of output to conform with given data type - * - * all as above op - * */ - #if NOT_EXCLUDED(OP_bitcast) - DECLARE_CUSTOM_OP(bitcast, 1, 1, false, 0, 1); - #endif - } -} +/** + * This operation casts elements of input array to specified data type + * + * PLEASE NOTE: This op is disabled atm, and reserved for future releases. + * + * + * Int args: + * 0: target DataType + */ +#if NOT_EXCLUDED(OP_cast) +DECLARE_CUSTOM_OP(cast, 1, 1, false, 0, 1); +#endif +/** + * This operation change type of input and modified shape of output to conform + * with given data type + * + * all as above op + * */ +#if NOT_EXCLUDED(OP_bitcast) +DECLARE_CUSTOM_OP(bitcast, 1, 1, false, 0, 1); +#endif +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/headers/images.h b/libnd4j/include/ops/declarable/headers/images.h index aa21145407fb..813791738467 100644 --- a/libnd4j/include/ops/declarable/headers/images.h +++ b/libnd4j/include/ops/declarable/headers/images.h @@ -16,7 +16,7 @@ // // @author Oleh Semeniv (oleg.semeniv@gmail.com) -// +// // // @author AbdelRauf (rauf@konduit.ai) // @@ -24,89 +24,89 @@ #ifndef LIBND4J_HEADERS_IMAGES_H #define LIBND4J_HEADERS_IMAGES_H -#include -#include -#include #include +#include +#include +#include #include namespace sd { namespace ops { - /** * Rgb To Hsv * Input arrays: - * 0 - input array with rank >= 1, must have at least one dimension equal 3, that is dimension containing channels. - * Int arguments: - * 0 - optional argument, corresponds to dimension with 3 channels + * 0 - input array with rank >= 1, must have at least one dimension equal 3, + * that is dimension containing channels. Int arguments: 0 - optional argument, + * corresponds to dimension with 3 channels */ #if NOT_EXCLUDED(OP_rgb_to_hsv) - DECLARE_CONFIGURABLE_OP(rgb_to_hsv, 1, 1, true, 0, 0); +DECLARE_CONFIGURABLE_OP(rgb_to_hsv, 1, 1, true, 0, 0); #endif /** * Hsv To Rgb * Input arrays: - * 0 - input array with rank >= 1, must have at least one dimension equal 3, that is dimension containing channels. - * Int arguments: - * 0 - optional argument, corresponds to dimension with 3 channels + * 0 - input array with rank >= 1, must have at least one dimension equal 3, + * that is dimension containing channels. Int arguments: 0 - optional argument, + * corresponds to dimension with 3 channels */ #if NOT_EXCLUDED(OP_hsv_to_rgb) - DECLARE_CONFIGURABLE_OP(hsv_to_rgb, 1, 1, true, 0, 0); +DECLARE_CONFIGURABLE_OP(hsv_to_rgb, 1, 1, true, 0, 0); #endif /** -* Rgb To GrayScale -* Input arrays: -* 0 - input array with rank >= 1, the RGB tensor to convert. Last dimension must have size 3 and should contain RGB values. -*/ + * Rgb To GrayScale + * Input arrays: + * 0 - input array with rank >= 1, the RGB tensor to convert. Last dimension + * must have size 3 and should contain RGB values. + */ #if NOT_EXCLUDED(OP_rgb_to_grs) - DECLARE_CUSTOM_OP(rgb_to_grs, 1, 1, false, 0, 0); +DECLARE_CUSTOM_OP(rgb_to_grs, 1, 1, false, 0, 0); #endif - /** - * Rgb To Yuv - * Input arrays: - * 0 - input array with rank >= 1, must have at least one dimension equal 3, that is dimension containing channels. - * Int arguments: - * 0 - optional argument, corresponds to dimension with 3 channels - */ +/** + * Rgb To Yuv + * Input arrays: + * 0 - input array with rank >= 1, must have at least one dimension equal 3, + * that is dimension containing channels. Int arguments: 0 - optional argument, + * corresponds to dimension with 3 channels + */ #if NOT_EXCLUDED(OP_rgb_to_yuv) - DECLARE_CONFIGURABLE_OP(rgb_to_yuv, 1, 1, true, 0, 0); +DECLARE_CONFIGURABLE_OP(rgb_to_yuv, 1, 1, true, 0, 0); #endif - /** - * Yuv To Rgb - * Input arrays: - * 0 - input array with rank >= 1, must have at least one dimension equal 3, that is dimension containing channels. - * Int arguments: - * 0 - optional argument, corresponds to dimension with 3 channels - */ +/** + * Yuv To Rgb + * Input arrays: + * 0 - input array with rank >= 1, must have at least one dimension equal 3, + * that is dimension containing channels. Int arguments: 0 - optional argument, + * corresponds to dimension with 3 channels + */ #if NOT_EXCLUDED(OP_rgb_to_yuv) - DECLARE_CONFIGURABLE_OP(yuv_to_rgb, 1, 1, true, 0, 0); +DECLARE_CONFIGURABLE_OP(yuv_to_rgb, 1, 1, true, 0, 0); #endif /** -* Rgb To Yiq -* Input arrays: -* 0 - input array with rank >= 1, must have at least one dimension equal 3, that is dimension containing channels. -* Int arguments: -* 0 - optional argument, corresponds to dimension with 3 channels -*/ + * Rgb To Yiq + * Input arrays: + * 0 - input array with rank >= 1, must have at least one dimension equal 3, + * that is dimension containing channels. Int arguments: 0 - optional argument, + * corresponds to dimension with 3 channels + */ #if NOT_EXCLUDED(OP_rgb_to_yiq) - DECLARE_CONFIGURABLE_OP(rgb_to_yiq, 1, 1, true, 0, 0); +DECLARE_CONFIGURABLE_OP(rgb_to_yiq, 1, 1, true, 0, 0); #endif /** -* Yiq To Rgb -* Input arrays: -* 0 - input array with rank >= 1, must have at least one dimension equal 3, that is dimension containing channels. -* Int arguments: -* 0 - optional argument, corresponds to dimension with 3 channels -*/ + * Yiq To Rgb + * Input arrays: + * 0 - input array with rank >= 1, must have at least one dimension equal 3, + * that is dimension containing channels. Int arguments: 0 - optional argument, + * corresponds to dimension with 3 channels + */ #if NOT_EXCLUDED(OP_yiq_to_rgb) - DECLARE_CONFIGURABLE_OP(yiq_to_rgb, 1, 1, true, 0, 0); +DECLARE_CONFIGURABLE_OP(yiq_to_rgb, 1, 1, true, 0, 0); #endif /** diff --git a/libnd4j/include/ops/declarable/headers/kernels.h b/libnd4j/include/ops/declarable/headers/kernels.h index c4cc02cb59ae..dbcdf20f7113 100644 --- a/libnd4j/include/ops/declarable/headers/kernels.h +++ b/libnd4j/include/ops/declarable/headers/kernels.h @@ -24,11 +24,11 @@ #include namespace sd { - namespace ops { - #if NOT_EXCLUDED(OP_knn_mindistance) - DECLARE_CUSTOM_OP(knn_mindistance, 3, 1, false, 0, 0); - #endif - } -} +namespace ops { +#if NOT_EXCLUDED(OP_knn_mindistance) +DECLARE_CUSTOM_OP(knn_mindistance, 3, 1, false, 0, 0); +#endif +} // namespace ops +} // namespace sd -#endif //LIBND4J_KERNELS_H +#endif // LIBND4J_KERNELS_H diff --git a/libnd4j/include/ops/declarable/headers/list.h b/libnd4j/include/ops/declarable/headers/list.h index af4fb5706e95..08703f61023b 100644 --- a/libnd4j/include/ops/declarable/headers/list.h +++ b/libnd4j/include/ops/declarable/headers/list.h @@ -24,108 +24,105 @@ #include namespace sd { - namespace ops { - // list operations, basically all around NDArrayList - - /** - * This operations puts given NDArray into (optionally) given NDArrayList. - * If no NDArrayList was provided - new one will be created - */ - #if NOT_EXCLUDED(OP_write_list) - DECLARE_LIST_OP(write_list, 2, 1, 0, -2); - #endif - - /** - * This operation concatenates given NDArrayList, and returns NDArray as result - */ - #if NOT_EXCLUDED(OP_stack_list) - DECLARE_LIST_OP(stack_list, 1, 1, 0, 0); - #endif - - /** - * This operations selects specified index fron NDArrayList and returns it as NDArray - * Expected arguments: - * x: non-empty list - * indices: optional, scalar with index - * - * Int args: - * optional, index - */ - #if NOT_EXCLUDED(OP_read_list) - DECLARE_LIST_OP(read_list, 1, 1, 0, 0); - #endif - - /** - * This operations selects specified indices fron NDArrayList and returns them as NDArray - * Expected arguments: - * x: non-empty list - * indices: optional, vector with indices - * - * Int args: - * optional, indices - */ - #if NOT_EXCLUDED(OP_pick_list) - DECLARE_LIST_OP(pick_list, 1, 1, -2, -2); - #endif - - /** - * This operations returns scalar, with number of existing arrays within given NDArrayList - * Expected arguments: - * x: list - */ - #if NOT_EXCLUDED(OP_size_list) - DECLARE_LIST_OP(size_list, 1, 1, 0, 0); - #endif - - /** - * This operation creates new empty NDArrayList - */ - #if NOT_EXCLUDED(OP_create_list) - DECLARE_LIST_OP(create_list, 1, 2, 0, -2); - #endif - - /** - * This operation unpacks given NDArray into specified NDArrayList wrt specified indices - */ - #if NOT_EXCLUDED(OP_scatter_list) - DECLARE_LIST_OP(scatter_list, 1, 1, 0, -2); - #endif - - /** - * This operation splits given NDArray into chunks, and stores them into given NDArrayList wert sizes - * Expected arguments: - * list: optional, NDArrayList. if not available - new NDArrayList will be created - * array: array to be split - * sizes: vector with sizes for each chunk - */ - #if NOT_EXCLUDED(OP_split_list) - DECLARE_LIST_OP(split_list, 2, 1, 0, -2); - #endif - - /** - * This operation builds NDArray from NDArrayList using indices - * Expected arguments: - * x: non-empty list - * indices: vector with indices for gather operation - */ - #if NOT_EXCLUDED(OP_gather_list) - DECLARE_LIST_OP(gather_list, 2, 1, 0, -2); - #endif - - /** - * This operation clones given NDArrayList - */ - #if NOT_EXCLUDED(OP_clone_list) - DECLARE_LIST_OP(clone_list, 1, 1, 0, 0); - #endif - - /** - * This operation unstacks given NDArray into NDArrayList by the first dimension - */ - #if NOT_EXCLUDED(OP_unstack_list) - DECLARE_LIST_OP(unstack_list, 1, 1, 0, 0); - #endif - } -} +namespace ops { +// list operations, basically all around NDArrayList + +/** + * This operations puts given NDArray into (optionally) given NDArrayList. + * If no NDArrayList was provided - new one will be created + */ +#if NOT_EXCLUDED(OP_write_list) +DECLARE_LIST_OP(write_list, 2, 1, 0, -2); +#endif + +/** + * This operation concatenates given NDArrayList, and returns NDArray as result + */ +#if NOT_EXCLUDED(OP_stack_list) +DECLARE_LIST_OP(stack_list, 1, 1, 0, 0); +#endif + +/** + * This operations selects specified index fron NDArrayList and returns it as + * NDArray Expected arguments: x: non-empty list indices: optional, scalar with + * index + * + * Int args: + * optional, index + */ +#if NOT_EXCLUDED(OP_read_list) +DECLARE_LIST_OP(read_list, 1, 1, 0, 0); +#endif + +/** + * This operations selects specified indices fron NDArrayList and returns them + * as NDArray Expected arguments: x: non-empty list indices: optional, vector + * with indices + * + * Int args: + * optional, indices + */ +#if NOT_EXCLUDED(OP_pick_list) +DECLARE_LIST_OP(pick_list, 1, 1, -2, -2); +#endif + +/** + * This operations returns scalar, with number of existing arrays within given + * NDArrayList Expected arguments: x: list + */ +#if NOT_EXCLUDED(OP_size_list) +DECLARE_LIST_OP(size_list, 1, 1, 0, 0); +#endif + +/** + * This operation creates new empty NDArrayList + */ +#if NOT_EXCLUDED(OP_create_list) +DECLARE_LIST_OP(create_list, 1, 2, 0, -2); +#endif + +/** + * This operation unpacks given NDArray into specified NDArrayList wrt specified + * indices + */ +#if NOT_EXCLUDED(OP_scatter_list) +DECLARE_LIST_OP(scatter_list, 1, 1, 0, -2); +#endif + +/** + * This operation splits given NDArray into chunks, and stores them into given + * NDArrayList wert sizes Expected arguments: list: optional, NDArrayList. if + * not available - new NDArrayList will be created array: array to be split + * sizes: vector with sizes for each chunk + */ +#if NOT_EXCLUDED(OP_split_list) +DECLARE_LIST_OP(split_list, 2, 1, 0, -2); +#endif + +/** + * This operation builds NDArray from NDArrayList using indices + * Expected arguments: + * x: non-empty list + * indices: vector with indices for gather operation + */ +#if NOT_EXCLUDED(OP_gather_list) +DECLARE_LIST_OP(gather_list, 2, 1, 0, -2); +#endif + +/** + * This operation clones given NDArrayList + */ +#if NOT_EXCLUDED(OP_clone_list) +DECLARE_LIST_OP(clone_list, 1, 1, 0, 0); +#endif + +/** + * This operation unstacks given NDArray into NDArrayList by the first dimension + */ +#if NOT_EXCLUDED(OP_unstack_list) +DECLARE_LIST_OP(unstack_list, 1, 1, 0, 0); +#endif +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/headers/loss.h b/libnd4j/include/ops/declarable/headers/loss.h index 3f36250404f9..bc8fa22183f9 100644 --- a/libnd4j/include/ops/declarable/headers/loss.h +++ b/libnd4j/include/ops/declarable/headers/loss.h @@ -25,341 +25,383 @@ namespace sd { namespace ops { - - ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of hinge loss function max(0, 1 - labels*logits) - * - * Input arrays: - * 0: logits - logits, type float - * 1: weights - is used for weighting (multiplying) of loss values, type float. - * Can be single scalar or has the same rank as labels and must be broadcastable to labels. - * 2: labels - ground truth vales, expected to be 0. or 1., type float. - * Must have the same shape as logits. - * - * Input integer arguments: - * 0: type of reduction to apply to loss - * 0 - "none", unreduced weighted losses with the same shape as logits. - * 1 - "weighted_sum", output is scalar and equal to sum of all elements of weightedLosses array - * 2 - "weighted_mean", output is scalar and equal to sum of all elements of weightedLosses array divided by sum of all elements of weightsBroad array - * 3 - "weighted_sum_by_nonzero_weights", output is scalar and equal to scalar sum of all elements of weightedLosses array divided by number of non-zero weights - * - * Output array: - * 0: loss values, type float. - * Can be an array with the same shape as logits or just single scalar, depending on reduction mode (see input integer argument) - */ - #if NOT_EXCLUDED(OP_hinge_loss) - DECLARE_CUSTOM_OP(hinge_loss, 3, 1, false, 0, 1); - DECLARE_CUSTOM_OP(hinge_loss_grad, 3, 3, false, 0, 1); - #endif +////////////////////////////////////////////////////////////////////////// +/** + * Implementation of hinge loss function max(0, 1 - labels*logits) + * + * Input arrays: + * 0: logits - logits, type float + * 1: weights - is used for weighting (multiplying) of loss values, type + * float. Can be single scalar or has the same rank as labels and must be + * broadcastable to labels. 2: labels - ground truth vales, expected to be 0. + * or 1., type float. Must have the same shape as logits. + * + * Input integer arguments: + * 0: type of reduction to apply to loss + * 0 - "none", unreduced weighted losses with the same shape as logits. + * 1 - "weighted_sum", output is scalar and equal to sum of all elements + * of weightedLosses array 2 - "weighted_mean", output is scalar and equal to + * sum of all elements of weightedLosses array divided by sum of all elements of + * weightsBroad array 3 - "weighted_sum_by_nonzero_weights", output is scalar + * and equal to scalar sum of all elements of weightedLosses array divided by + * number of non-zero weights + * + * Output array: + * 0: loss values, type float. + * Can be an array with the same shape as logits or just single scalar, + * depending on reduction mode (see input integer argument) + */ +#if NOT_EXCLUDED(OP_hinge_loss) +DECLARE_CUSTOM_OP(hinge_loss, 3, 1, false, 0, 1); +DECLARE_CUSTOM_OP(hinge_loss_grad, 3, 3, false, 0, 1); +#endif - ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of Huber loss function: - * 0.5 * (labels-predictions)^2 if |labels-predictions| <= delta - * 0.5 * delta^2 + delta * (|labels-predictions| - delta) if |labels-predictions| > delta - * - * Input arrays: - * 0: predictions - the predicted values, type float - * 1: weights - is used for weighting (multiplying) of loss values, type float. - * Can be single scalar or has the same rank as labels, and must be broadcastable to labels. - * 2: labels - ground truth vales, type float. - * Must have the same shape as predictions. - * - * Input integer arguments: - * 0: type of reduction to apply to loss - * 0 - "none", unreduced weighted losses with the same shape as predictions - * 1 - "weighted_sum", output is scalar and equal to sum of all elements of weightedLosses array - * 2 - "weighted_mean", output is scalar and equal to sum of all elements of weightedLosses array divided by sum of all elements of weightsBroad array - * 3 - "weighted_sum_by_nonzero_weights", output is scalar and equal to scalar sum of all elements of weightedLosses array divided by number of non-zero weights - * - * Input float arguments: - * 0: point where the huber loss function changes from a quadratic to linear. - * - * Output array: - * 0: loss values, type float. - * Can be an array with the same shape as predictions or just single scalar, depending on reduction mode (see input integer argument) - */ - #if NOT_EXCLUDED(OP_huber_loss) - DECLARE_CUSTOM_OP(huber_loss, 3, 1, false, 1, 1); - DECLARE_CUSTOM_OP(huber_loss_grad, 3, 1, false, 1, 1); - #endif - - - ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of logarithmic loss function ( y_i * log(p_i) + (1 - y_i) * log(1 - p_i) ) - * - * Input arrays: - * 0: predictions - the predicted values, type float - * 1: weights - is used for weighting (multiplying) of loss values, type float. - * Can be single scalar or has the same rank as labels, and must be broadcastable to labels. - * 2: labels - ground truth vales, type float. - * Must have the same shape as predictions. - * - * Input integer arguments: - * 0: type of reduction to apply to loss - * 0 - "none", unreduced weighted losses with the same shape as predictions - * 1 - "weighted_sum", output is scalar and equal to sum of all elements of weightedLosses array - * 2 - "weighted_mean", output is scalar and equal to sum of all elements of weightedLosses array divided by sum of all elements of weightsBroad array - * 3 - "weighted_sum_by_nonzero_weights", output is scalar and equal to scalar sum of all elements of weightedLosses array divided by number of non-zero weights - * - * Input float arguments: - * 0: a small increment to add to avoid taking a log of zero. - * - * Output array: - * 0: loss values, type float. - * Can be an array with the same shape as predictions or just single scalar, depending on reduction mode (see input integer argument) - */ - #if NOT_EXCLUDED(OP_log_loss) - DECLARE_CUSTOM_OP(log_loss, 3, 1, false, 1, 1); - DECLARE_CUSTOM_OP(log_loss_grad, 3, 3, false, 1, 1); - #endif - - /** - * l2_loss op. - * compute a l2 norm for given array. - * - * input param - an array (tensor) - * output value - a real number with given type (e.g. float or double) - */ - #if NOT_EXCLUDED(OP_l2_loss) - DECLARE_CUSTOM_OP(l2_loss, 1, 1, false, 0, 0); - #endif - - - /** - * This op calculates logarithmic loss of poisson distributed input. - * Input arrays: - * 0: log_predictions - must be already pre-transformed to log(x) - * 1: weights - is used for weighting (multiplying) of loss values, type float. - * Can be single scalar or has the same rank as labels and must be broadcastable to labels. - * 2: labels - ground truth vales, expected to be 0. or 1., type float. - * Must have the same shape as logits. - * - * Input integer arguments: - * 0: type of reduction to apply to loss - * 0 - "none", unreduced weighted losses with the same shape as logits. - * 1 - "weighted_sum", output is scalar and equal to sum of all elements of weightedLosses array - * 2 - "weighted_mean", output is scalar and equal to sum of all elements of weightedLosses array divided by sum of all elements of weightsBroad array - * 3 - "weighted_sum_by_nonzero_weights", output is scalar and equal to scalar sum of all elements of weightedLosses array divided by number of non-zero weights - * 1: optional - boolean value compute_full_loss: 0 (default) or 1 (compute) - * - * Output array: - * 0: loss values, type float. - * Can be an array with the same shape as log_predictions or just single scalar, depending on reduction mode (see input integer argument) - */ - #if NOT_EXCLUDED(OP_log_poisson_loss) - DECLARE_CUSTOM_OP(log_poisson_loss, 3, 1, true, 0, 1); - DECLARE_CUSTOM_OP(log_poisson_loss_grad, 3, 3, true, 0, 1); - #endif - - ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of pairwise-errors-squared loss function - * - * Input arrays: - * 0: predictions - the predicted values, type float. - * 1: weights - is used for weighting (multiplying) of loss values, type float. - * Can be single scalar or has the same rank as labels and must be broadcastable to labels. - * 2: labels - ground truth vales, type float. - * Must have the same shape as predictions. - * - * Output array: - * 0: loss value, it is just single scalar, type float. - */ - #if NOT_EXCLUDED(OP_mean_pairwssqerr_loss) - DECLARE_CUSTOM_OP(mean_pairwssqerr_loss, 3, 1, false, 0, 0); - DECLARE_CUSTOM_OP(mean_pairwssqerr_loss_grad, 3, 3, false, 0, 0); - #endif +////////////////////////////////////////////////////////////////////////// +/** + * Implementation of Huber loss function: + * 0.5 * (labels-predictions)^2 if + * |labels-predictions| <= delta 0.5 * delta^2 + delta * (|labels-predictions| - + * delta) if |labels-predictions| > delta + * + * Input arrays: + * 0: predictions - the predicted values, type float + * 1: weights - is used for weighting (multiplying) of loss values, type + * float. Can be single scalar or has the same rank as labels, and must be + * broadcastable to labels. 2: labels - ground truth vales, type float. Must + * have the same shape as predictions. + * + * Input integer arguments: + * 0: type of reduction to apply to loss + * 0 - "none", unreduced weighted losses with the same shape as + * predictions 1 - "weighted_sum", output is scalar and equal to sum of all + * elements of weightedLosses array 2 - "weighted_mean", output is scalar and + * equal to sum of all elements of weightedLosses array divided by sum of all + * elements of weightsBroad array 3 - "weighted_sum_by_nonzero_weights", output + * is scalar and equal to scalar sum of all elements of weightedLosses array + * divided by number of non-zero weights + * + * Input float arguments: + * 0: point where the huber loss function changes from a quadratic to linear. + * + * Output array: + * 0: loss values, type float. + * Can be an array with the same shape as predictions or just single + * scalar, depending on reduction mode (see input integer argument) + */ +#if NOT_EXCLUDED(OP_huber_loss) +DECLARE_CUSTOM_OP(huber_loss, 3, 1, false, 1, 1); +DECLARE_CUSTOM_OP(huber_loss_grad, 3, 1, false, 1, 1); +#endif - - ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of Sum-of-Squares loss function 1/N * sum_{i}^{N}(predictions_i - labels_i)^2 - * - * Input arrays: - * 0: predictions - the predicted values, type float - * 1: weights - is used for weighting (multiplying) of loss values, type float. - * Can be single scalar or has the same rank as labels and must be broadcastable to labels. - * 2: labels - ground truth vales, type float. - * Must have the same shape as predictions. - * - * Input integer arguments: - * 0: type of reduction to apply to loss - * 0 - "none", unreduced weighted losses with the same shape as predictions - * 1 - "weighted_sum", output is scalar and equal to sum of all elements of weightedLosses array - * 2 - "weighted_mean", output is scalar and equal to sum of all elements of weightedLosses array divided by sum of all elements of weightsBroad array - * 3 - "weighted_sum_by_nonzero_weights", output is scalar and equal to scalar sum of all elements of weightedLosses array divided by number of non-zero weights - * - * Output array: - * 0: loss values, type float. - * Can be an array with the same shape as predictions or just single scalar, depending on reduction mode (see input integer argument) - */ - #if NOT_EXCLUDED(OP_mean_sqerr_loss) - DECLARE_CUSTOM_OP(mean_sqerr_loss, 3, 1, false, 0, 1); - DECLARE_CUSTOM_OP(mean_sqerr_loss_grad, 3, 3, false, 0, 1); - #endif +////////////////////////////////////////////////////////////////////////// +/** + * Implementation of logarithmic loss function ( y_i * log(p_i) + (1 - y_i) * + * log(1 - p_i) ) + * + * Input arrays: + * 0: predictions - the predicted values, type float + * 1: weights - is used for weighting (multiplying) of loss values, type + * float. Can be single scalar or has the same rank as labels, and must be + * broadcastable to labels. 2: labels - ground truth vales, type float. Must + * have the same shape as predictions. + * + * Input integer arguments: + * 0: type of reduction to apply to loss + * 0 - "none", unreduced weighted losses with the same shape as + * predictions 1 - "weighted_sum", output is scalar and equal to sum of all + * elements of weightedLosses array 2 - "weighted_mean", output is scalar and + * equal to sum of all elements of weightedLosses array divided by sum of all + * elements of weightsBroad array 3 - "weighted_sum_by_nonzero_weights", output + * is scalar and equal to scalar sum of all elements of weightedLosses array + * divided by number of non-zero weights + * + * Input float arguments: + * 0: a small increment to add to avoid taking a log of zero. + * + * Output array: + * 0: loss values, type float. + * Can be an array with the same shape as predictions or just single + * scalar, depending on reduction mode (see input integer argument) + */ +#if NOT_EXCLUDED(OP_log_loss) +DECLARE_CUSTOM_OP(log_loss, 3, 1, false, 1, 1); +DECLARE_CUSTOM_OP(log_loss_grad, 3, 3, false, 1, 1); +#endif +/** + * l2_loss op. + * compute a l2 norm for given array. + * + * input param - an array (tensor) + * output value - a real number with given type (e.g. float or double) + */ +#if NOT_EXCLUDED(OP_l2_loss) +DECLARE_CUSTOM_OP(l2_loss, 1, 1, false, 0, 0); +#endif - ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of sigmoid cross-entropy loss function max(logits, 0.) - logits * labels + log(1. + exp(-abs(logits))); - * - * Input arrays: - * 0: logits - logits, type float - * 1: weights - is used for weighting (multiplying) of loss values, type float. - * Can be single scalar or has the same rank as labels, and must be broadcastable to labels. - * 2: labels - ground truth vales, expected to be 0. or 1., type float. - * Must have the same shape as logits. - * - * Input integer arguments: - * 0: type of reduction to apply to loss - * 0 - "none", unreduced weighted losses with the same shape as logits. - * 1 - "weighted_sum", output is scalar and equal to sum of all elements of weightedLosses array - * 2 - "weighted_mean", output is scalar and equal to sum of all elements of weightedLosses array divided by sum of all elements of weightsBroad array - * 3 - "weighted_sum_by_nonzero_weights", output is scalar and equal to scalar sum of all elements of weightedLosses array divided by number of non-zero weights - * - * Input float arguments: - * 0: smoothing value, if it is greater than 0 then apply smoothing to the labels (smooth the labels towards 1/2): new_labels = labels * (1 - labelsSmoothing)+ 0.5 * labelsSmoothing - * - * Output array: - * 0: loss values, type float. - * Can be an array with the same shape as logits or just single scalar, depending on reduction mode (see input integer argument) - */ - #if NOT_EXCLUDED(OP_sigm_cross_entropy_loss) - DECLARE_CUSTOM_OP(sigm_cross_entropy_loss, 3, 1, false, 1, 1); - DECLARE_CUSTOM_OP(sigm_cross_entropy_loss_grad, 3, 3, false, 1, 1); - #endif - +/** + * This op calculates logarithmic loss of poisson distributed input. + * Input arrays: + * 0: log_predictions - must be already pre-transformed to log(x) + * 1: weights - is used for weighting (multiplying) of loss values, type + * float. Can be single scalar or has the same rank as labels and must be + * broadcastable to labels. 2: labels - ground truth vales, expected to be 0. + * or 1., type float. Must have the same shape as logits. + * + * Input integer arguments: + * 0: type of reduction to apply to loss + * 0 - "none", unreduced weighted losses with the same shape as logits. + * 1 - "weighted_sum", output is scalar and equal to sum of all elements + * of weightedLosses array 2 - "weighted_mean", output is scalar and equal to + * sum of all elements of weightedLosses array divided by sum of all elements of + * weightsBroad array 3 - "weighted_sum_by_nonzero_weights", output is scalar + * and equal to scalar sum of all elements of weightedLosses array divided by + * number of non-zero weights 1: optional - boolean value compute_full_loss: 0 + * (default) or 1 (compute) + * + * Output array: + * 0: loss values, type float. + * Can be an array with the same shape as log_predictions or just single + * scalar, depending on reduction mode (see input integer argument) + */ +#if NOT_EXCLUDED(OP_log_poisson_loss) +DECLARE_CUSTOM_OP(log_poisson_loss, 3, 1, true, 0, 1); +DECLARE_CUSTOM_OP(log_poisson_loss_grad, 3, 3, true, 0, 1); +#endif - ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of softmax cross-entropy loss function max(logits, 0.) - logits * labels + log(1. + exp(-abs(logits))); - * - * Input arrays: - * 0: logits - logits, type float - * 1: weights - is used for weighting (multiplying) of loss values, type float. - * Can be single scalar or has the same rank as labels, and must be broadcastable to labels. - * 2: labels - ground truth vales, expected to be 0. or 1., type float. - * Must have the same shape as logits. - * - * Input integer arguments: - * 0: type of reduction to apply to loss - * 0 - "none", unreduced weighted losses with the same shape as logits. - * 1 - "weighted_sum", output is scalar and equal to sum of all elements of weightedLosses array - * 2 - "weighted_mean", output is scalar and equal to sum of all elements of weightedLosses array divided by sum of all elements of weightsBroad array - * 3 - "weighted_sum_by_nonzero_weights", output is scalar and equal to scalar sum of all elements of weightedLosses array divided by number of non-zero weights - * - * Input float arguments: - * 0: smoothing value, if it is greater than 0 then apply smoothing to the labels (smooth the labels towards 1/numClasses): new_labels = labels * (1 - labelsSmoothing) + labelsSmoothing / numClasses - * - * Output array: - * 0: loss values, type float. - * Can be an array with shape as in logits except last dimension is equal to unity or just single scalar, depending on reduction mode (see input integer argument) - */ - #if NOT_EXCLUDED(OP_softmax_cross_entropy_loss) - DECLARE_CUSTOM_OP(softmax_cross_entropy_loss, 3, 1, false, 1, 1); - DECLARE_CUSTOM_OP(softmax_cross_entropy_loss_grad, 3, 3, false, 1, 1); - #endif +////////////////////////////////////////////////////////////////////////// +/** + * Implementation of pairwise-errors-squared loss function + * + * Input arrays: + * 0: predictions - the predicted values, type float. + * 1: weights - is used for weighting (multiplying) of loss values, type + * float. Can be single scalar or has the same rank as labels and must be + * broadcastable to labels. 2: labels - ground truth vales, type float. Must + * have the same shape as predictions. + * + * Output array: + * 0: loss value, it is just single scalar, type float. + */ +#if NOT_EXCLUDED(OP_mean_pairwssqerr_loss) +DECLARE_CUSTOM_OP(mean_pairwssqerr_loss, 3, 1, false, 0, 0); +DECLARE_CUSTOM_OP(mean_pairwssqerr_loss_grad, 3, 3, false, 0, 0); +#endif +////////////////////////////////////////////////////////////////////////// +/** + * Implementation of Sum-of-Squares loss function 1/N * + * sum_{i}^{N}(predictions_i - labels_i)^2 + * + * Input arrays: + * 0: predictions - the predicted values, type float + * 1: weights - is used for weighting (multiplying) of loss values, type + * float. Can be single scalar or has the same rank as labels and must be + * broadcastable to labels. 2: labels - ground truth vales, type float. Must + * have the same shape as predictions. + * + * Input integer arguments: + * 0: type of reduction to apply to loss + * 0 - "none", unreduced weighted losses with the same shape as + * predictions 1 - "weighted_sum", output is scalar and equal to sum of all + * elements of weightedLosses array 2 - "weighted_mean", output is scalar and + * equal to sum of all elements of weightedLosses array divided by sum of all + * elements of weightsBroad array 3 - "weighted_sum_by_nonzero_weights", output + * is scalar and equal to scalar sum of all elements of weightedLosses array + * divided by number of non-zero weights + * + * Output array: + * 0: loss values, type float. + * Can be an array with the same shape as predictions or just single + * scalar, depending on reduction mode (see input integer argument) + */ +#if NOT_EXCLUDED(OP_mean_sqerr_loss) +DECLARE_CUSTOM_OP(mean_sqerr_loss, 3, 1, false, 0, 1); +DECLARE_CUSTOM_OP(mean_sqerr_loss_grad, 3, 3, false, 0, 1); +#endif - ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of Absolute Difference loss function |predictions - labels| - * - * Input arrays: - * 0: predictions - the predicted values, type float. - * 1: weights - is used for weighting (multiplying) of loss values, type float. - * Can be single scalar or has the same rank as labels and must be broadcastable to labels. - * 2: labels - ground truth vales, type float. - * Must have the same shape as predictions. - * - * Input integer arguments: - * 0: type of reduction to apply to loss - * 0 - "none", unreduced weighted losses with the same shape as predictions - * 1 - "weighted_sum", output is scalar and equal to sum of all elements of weightedLosses array - * 2 - "weighted_mean", output is scalar and equal to sum of all elements of weightedLosses array divided by sum of all elements of weightsBroad array - * 3 - "weighted_sum_by_nonzero_weights", output is scalar and equal to scalar sum of all elements of weightedLosses array divided by number of non-zero weights - * - * Output array: - * 0: loss values, type float. - * Can be an array with the same shape as predictions or just single scalar, depending on reduction mode (see input integer argument) - */ - #if NOT_EXCLUDED(OP_absolute_difference_loss) - DECLARE_CUSTOM_OP(absolute_difference_loss, 3, 1, false, 0, 1); - DECLARE_CUSTOM_OP(absolute_difference_loss_grad, 3, 3, false, 0, 1); - #endif +////////////////////////////////////////////////////////////////////////// +/** + * Implementation of sigmoid cross-entropy loss function max(logits, 0.) - + * logits * labels + log(1. + exp(-abs(logits))); + * + * Input arrays: + * 0: logits - logits, type float + * 1: weights - is used for weighting (multiplying) of loss values, type + * float. Can be single scalar or has the same rank as labels, and must be + * broadcastable to labels. 2: labels - ground truth vales, expected to be 0. + * or 1., type float. Must have the same shape as logits. + * + * Input integer arguments: + * 0: type of reduction to apply to loss + * 0 - "none", unreduced weighted losses with the same shape as logits. + * 1 - "weighted_sum", output is scalar and equal to sum of all elements + * of weightedLosses array 2 - "weighted_mean", output is scalar and equal to + * sum of all elements of weightedLosses array divided by sum of all elements of + * weightsBroad array 3 - "weighted_sum_by_nonzero_weights", output is scalar + * and equal to scalar sum of all elements of weightedLosses array divided by + * number of non-zero weights + * + * Input float arguments: + * 0: smoothing value, if it is greater than 0 then apply smoothing to the + * labels (smooth the labels towards 1/2): new_labels = labels * (1 - + * labelsSmoothing)+ 0.5 * labelsSmoothing + * + * Output array: + * 0: loss values, type float. + * Can be an array with the same shape as logits or just single scalar, + * depending on reduction mode (see input integer argument) + */ +#if NOT_EXCLUDED(OP_sigm_cross_entropy_loss) +DECLARE_CUSTOM_OP(sigm_cross_entropy_loss, 3, 1, false, 1, 1); +DECLARE_CUSTOM_OP(sigm_cross_entropy_loss_grad, 3, 3, false, 1, 1); +#endif +////////////////////////////////////////////////////////////////////////// +/** + * Implementation of softmax cross-entropy loss function max(logits, 0.) - + * logits * labels + log(1. + exp(-abs(logits))); + * + * Input arrays: + * 0: logits - logits, type float + * 1: weights - is used for weighting (multiplying) of loss values, type + * float. Can be single scalar or has the same rank as labels, and must be + * broadcastable to labels. 2: labels - ground truth vales, expected to be 0. + * or 1., type float. Must have the same shape as logits. + * + * Input integer arguments: + * 0: type of reduction to apply to loss + * 0 - "none", unreduced weighted losses with the same shape as logits. + * 1 - "weighted_sum", output is scalar and equal to sum of all elements + * of weightedLosses array 2 - "weighted_mean", output is scalar and equal to + * sum of all elements of weightedLosses array divided by sum of all elements of + * weightsBroad array 3 - "weighted_sum_by_nonzero_weights", output is scalar + * and equal to scalar sum of all elements of weightedLosses array divided by + * number of non-zero weights + * + * Input float arguments: + * 0: smoothing value, if it is greater than 0 then apply smoothing to the + * labels (smooth the labels towards 1/numClasses): new_labels = labels * (1 - + * labelsSmoothing) + labelsSmoothing / numClasses + * + * Output array: + * 0: loss values, type float. + * Can be an array with shape as in logits except last dimension is equal + * to unity or just single scalar, depending on reduction mode (see input + * integer argument) + */ +#if NOT_EXCLUDED(OP_softmax_cross_entropy_loss) +DECLARE_CUSTOM_OP(softmax_cross_entropy_loss, 3, 1, false, 1, 1); +DECLARE_CUSTOM_OP(softmax_cross_entropy_loss_grad, 3, 3, false, 1, 1); +#endif - ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of cosine-distance loss function 1. - (predictions * labels).reduce_sum_along(dimension) - * - * Input arrays: - * 0: predictions - the predicted values, type float - * 1: weights - is used for weighting (multiplying) of loss values, type float. - * Can be single scalar or has the same rank as labels and must be broadcastable to labels. - * 2: labels - ground truth vales, type float. - * Must have the same shape as predictions. - * - * Input integer arguments: - * 0: type of reduction to apply to loss - * 0 - "none", unreduced weighted losses with the same shape as predictions - * 1 - "weighted_sum", output is scalar and equal to sum of all elements of weightedLosses array - * 2 - "weighted_mean", output is scalar and equal to sum of all elements of weightedLosses array divided by sum of all elements of weightsBroad array - * 3 - "weighted_sum_by_nonzero_weights", output is scalar and equal to scalar sum of all elements of weightedLosses array divided by number of non-zero weights - * 1: dimension along which the cosine distance is computed. - * - * Output array: - * 0: loss values, type float. - * Can be an array with the same shape as predictions or just single scalar, depending on reduction mode (see input integer argument) - */ - #if NOT_EXCLUDED(OP_cosine_distance_loss) - DECLARE_CUSTOM_OP(cosine_distance_loss, 3, 1, false, 0, 2); - DECLARE_CUSTOM_OP(cosine_distance_loss_grad, 3, 3, false, 0, 2); - #endif +////////////////////////////////////////////////////////////////////////// +/** + * Implementation of Absolute Difference loss function |predictions - labels| + * + * Input arrays: + * 0: predictions - the predicted values, type float. + * 1: weights - is used for weighting (multiplying) of loss values, type + * float. Can be single scalar or has the same rank as labels and must be + * broadcastable to labels. 2: labels - ground truth vales, type float. Must + * have the same shape as predictions. + * + * Input integer arguments: + * 0: type of reduction to apply to loss + * 0 - "none", unreduced weighted losses with the same shape as + * predictions 1 - "weighted_sum", output is scalar and equal to sum of all + * elements of weightedLosses array 2 - "weighted_mean", output is scalar and + * equal to sum of all elements of weightedLosses array divided by sum of all + * elements of weightsBroad array 3 - "weighted_sum_by_nonzero_weights", output + * is scalar and equal to scalar sum of all elements of weightedLosses array + * divided by number of non-zero weights + * + * Output array: + * 0: loss values, type float. + * Can be an array with the same shape as predictions or just single + * scalar, depending on reduction mode (see input integer argument) + */ +#if NOT_EXCLUDED(OP_absolute_difference_loss) +DECLARE_CUSTOM_OP(absolute_difference_loss, 3, 1, false, 0, 1); +DECLARE_CUSTOM_OP(absolute_difference_loss_grad, 3, 3, false, 0, 1); +#endif - ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of softmax cross-entropy loss function - * - * Input arrays: - * 0: logits - logits, type float - * 1: labels - ground truth vales, expected to be 0. or 1., type float. - * Must have the same shape as logits. - * - * Input integer arguments: - * 0: optional (default is last dimension) dimension with classes - * - * Output array: - * 0: loss values, type float. An array with shape resulting from reducing of logits shape along dimension with classes - */ - #if NOT_EXCLUDED(OP_softmax_cross_entropy_loss_with_logits) - DECLARE_CUSTOM_OP(softmax_cross_entropy_loss_with_logits, 2, 1, false, 0, 0); - DECLARE_CUSTOM_OP(softmax_cross_entropy_loss_with_logits_grad, 2, 2, false, 0, 0); - #endif +////////////////////////////////////////////////////////////////////////// +/** + * Implementation of cosine-distance loss function 1. - (predictions * + * labels).reduce_sum_along(dimension) + * + * Input arrays: + * 0: predictions - the predicted values, type float + * 1: weights - is used for weighting (multiplying) of loss values, type + * float. Can be single scalar or has the same rank as labels and must be + * broadcastable to labels. 2: labels - ground truth vales, type float. Must + * have the same shape as predictions. + * + * Input integer arguments: + * 0: type of reduction to apply to loss + * 0 - "none", unreduced weighted losses with the same shape as + * predictions 1 - "weighted_sum", output is scalar and equal to sum of all + * elements of weightedLosses array 2 - "weighted_mean", output is scalar and + * equal to sum of all elements of weightedLosses array divided by sum of all + * elements of weightsBroad array 3 - "weighted_sum_by_nonzero_weights", output + * is scalar and equal to scalar sum of all elements of weightedLosses array + * divided by number of non-zero weights 1: dimension along which the cosine + * distance is computed. + * + * Output array: + * 0: loss values, type float. + * Can be an array with the same shape as predictions or just single + * scalar, depending on reduction mode (see input integer argument) + */ +#if NOT_EXCLUDED(OP_cosine_distance_loss) +DECLARE_CUSTOM_OP(cosine_distance_loss, 3, 1, false, 0, 2); +DECLARE_CUSTOM_OP(cosine_distance_loss_grad, 3, 3, false, 0, 2); +#endif - ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of sparse softmax cross-entropy loss function - * - * Input arrays: - * 0: labels - ground truth vales, expected to be within range [0, num_classes), type float. - * Must have rank equal logits rank minus 1. - * 1: logits - logits, type float - * - * Output array: - * 0: loss values, type float. Has the same shape as labels - */ - #if NOT_EXCLUDED(OP_sparse_softmax_cross_entropy_loss_with_logits) - DECLARE_CUSTOM_OP(sparse_softmax_cross_entropy_loss_with_logits, 2, 1, false, 0, 0); - DECLARE_CUSTOM_OP(sparse_softmax_cross_entropy_loss_with_logits_grad, 2, 1, false, 0, 0); - #endif +////////////////////////////////////////////////////////////////////////// +/** + * Implementation of softmax cross-entropy loss function + * + * Input arrays: + * 0: logits - logits, type float + * 1: labels - ground truth vales, expected to be 0. or 1., type float. + * Must have the same shape as logits. + * + * Input integer arguments: + * 0: optional (default is last dimension) dimension with classes + * + * Output array: + * 0: loss values, type float. An array with shape resulting from reducing of + * logits shape along dimension with classes + */ +#if NOT_EXCLUDED(OP_softmax_cross_entropy_loss_with_logits) +DECLARE_CUSTOM_OP(softmax_cross_entropy_loss_with_logits, 2, 1, false, 0, 0); +DECLARE_CUSTOM_OP(softmax_cross_entropy_loss_with_logits_grad, 2, 2, false, 0, + 0); +#endif +////////////////////////////////////////////////////////////////////////// +/** + * Implementation of sparse softmax cross-entropy loss function + * + * Input arrays: + * 0: labels - ground truth vales, expected to be within range [0, + * num_classes), type float. Must have rank equal logits rank minus 1. 1: logits + * - logits, type float + * + * Output array: + * 0: loss values, type float. Has the same shape as labels + */ +#if NOT_EXCLUDED(OP_sparse_softmax_cross_entropy_loss_with_logits) +DECLARE_CUSTOM_OP(sparse_softmax_cross_entropy_loss_with_logits, 2, 1, false, 0, + 0); +DECLARE_CUSTOM_OP(sparse_softmax_cross_entropy_loss_with_logits_grad, 2, 1, + false, 0, 0); +#endif -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/headers/nlp.h b/libnd4j/include/ops/declarable/headers/nlp.h index e12db1402576..ab399fe8abce 100644 --- a/libnd4j/include/ops/declarable/headers/nlp.h +++ b/libnd4j/include/ops/declarable/headers/nlp.h @@ -18,21 +18,21 @@ // @author raver119@gmail.com // -#ifndef DEV_TESTS_NLP_H -#define DEV_TESTS_NLP_H +#ifndef SD_NLP_H +#define SD_NLP_H #include namespace sd { - namespace ops { +namespace ops { - #if NOT_EXCLUDED(OP_skipgram) - DECLARE_CONFIGURABLE_OP(skipgram, 12, 12, true, 0, 0); - #endif +#if NOT_EXCLUDED(OP_skipgram) +DECLARE_CONFIGURABLE_OP(skipgram, 12, 12, true, 0, 0); +#endif - #if NOT_EXCLUDED(OP_cbow) - DECLARE_CONFIGURABLE_OP(cbow, 15, 15, true, 0, 0); - #endif - } -} +#if NOT_EXCLUDED(OP_cbow) +DECLARE_CONFIGURABLE_OP(cbow, 15, 15, true, 0, 0); +#endif +} // namespace ops +} // namespace sd -#endif //DEV_TESTS_NLP_H +#endif // SD_NLP_H diff --git a/libnd4j/include/ops/declarable/headers/nn.h b/libnd4j/include/ops/declarable/headers/nn.h index 26699aa324d4..6b4ad2e3212b 100644 --- a/libnd4j/include/ops/declarable/headers/nn.h +++ b/libnd4j/include/ops/declarable/headers/nn.h @@ -24,236 +24,244 @@ #include namespace sd { - namespace ops { +namespace ops { - #if NOT_EXCLUDED(OP_softmax) - DECLARE_CONFIGURABLE_OP(softmax, 1, 1, true, 0, 0); - DECLARE_CONFIGURABLE_OP(softmax_bp, 2, 1, true, 0, 0); - #endif +#if NOT_EXCLUDED(OP_softmax) +DECLARE_CONFIGURABLE_OP(softmax, 1, 1, true, 0, 0); +DECLARE_CONFIGURABLE_OP(softmax_bp, 2, 1, true, 0, 0); +#endif - /** - * Local response normalization implementation as TF. - * input: 4D array - * - * T args: - * - * 0: bias - * 1: alpha - * 2: beta - * - * Int arg: depth - optional local radius - * - * output - 4D array - */ - #if NOT_EXCLUDED(OP_lrn) - DECLARE_CONFIGURABLE_OP(lrn, 1, 1, true, 3, 0); - #endif - - /** - * Local response normalization - backprop variant. - * input: - * 0 - 4D array of data - * 1 - epsilon - 4D array of approximation - * - * T args: - * - * 0: bias - * 1: alpha - * 2: beta - * - * Int arg: depth - optional local radius - * - * output - next approximation as 4D array - */ - #if NOT_EXCLUDED(OP_lrn) - DECLARE_CONFIGURABLE_OP(lrn_bp, 2, 1, true, 3, 0); - #endif - - /** - * Batch normalization implementation. - * Reference: https://arxiv.org/abs/1502.03167v3 - * - * Expected arguments: - * input: input array (any number of dimensions) - * mean: - * variance: - * gamma: - * beta: - * - * Int args: - * 0: apply scale - * 1: apply offset - * - * - * T args: - * 0: epsilon - */ - #if NOT_EXCLUDED(OP_batchnorm) - DECLARE_CUSTOM_OP(batchnorm, 3, 1, false, 1, 2); - #endif - - /** - * back prop in batch normalization - * - * Expected arguments: - * input: input array (any number of dimensions) - * mean: - * variance: - * gamma: optional - * beta: optional - * dLdOut: next epsilon - * - * Int args: - * 0: apply scale - * 1: apply offset - * - * T args: - * 0: epsilon - * - * output arrays: - * dL/dInput - * dL/dMean - * dL/dVariance - * dL/dGamma, optional - * dL/dBeta, optional - */ - #if NOT_EXCLUDED(OP_batchnorm) - DECLARE_CUSTOM_OP(batchnorm_bp, 4, 3, false, 1, 2); - #endif +/** + * Local response normalization implementation as TF. + * input: 4D array + * + * T args: + * + * 0: bias + * 1: alpha + * 2: beta + * + * Int arg: depth - optional local radius + * + * output - 4D array + */ +#if NOT_EXCLUDED(OP_lrn) +DECLARE_CONFIGURABLE_OP(lrn, 1, 1, true, 3, 0); +#endif +/** + * Local response normalization - backprop variant. + * input: + * 0 - 4D array of data + * 1 - epsilon - 4D array of approximation + * + * T args: + * + * 0: bias + * 1: alpha + * 2: beta + * + * Int arg: depth - optional local radius + * + * output - next approximation as 4D array + */ +#if NOT_EXCLUDED(OP_lrn) +DECLARE_CONFIGURABLE_OP(lrn_bp, 2, 1, true, 3, 0); +#endif - /** - * This operation updates parameters with provided gradients, wrt learning rate - * Expected arguments: - * x: parameters, any shape - * y: gradients. same shape as x - * lr: optional, learning rate - * - * T args: - * 0: optional, learning rate - */ - #if NOT_EXCLUDED(OP_apply_sgd) - DECLARE_CONFIGURABLE_OP(apply_sgd, 2, 1, true, -2, 0); - #endif +/** + * Batch normalization implementation. + * Reference: https://arxiv.org/abs/1502.03167v3 + * + * Expected arguments: + * input: input array (any number of dimensions) + * mean: + * variance: + * gamma: + * beta: + * + * Int args: + * 0: apply scale + * 1: apply offset + * + * + * T args: + * 0: epsilon + */ +#if NOT_EXCLUDED(OP_batchnorm) +DECLARE_CUSTOM_OP(batchnorm, 3, 1, false, 1, 2); +#endif - /** - * This operation performs batch normalization of layer, it is based on following article https://arxiv.org/abs/1502.03167. - * Expected arguments: - * x: input 4D array of shape [bS,iH,iW,iD] (data format = NHWC) or [bS,iD,iH,iW] (data format = NCHW), where - * bS - batch size - * iH - input height - * iW - input width - * iD - input depth (or number of channels) - * scale: 1D input array of scale factors, shape [iD] - * offset: 1D input array of offsets (shifts), shape [iD] - * mean: 1D input array of population mean used for inference, shape [iD], this array is required only if isTraining = false - * variance: 1D input array of population mean used for inference, shape [iD], this array is required only if isTraining = false - * - * T input arguments: - * 0: epsilon, it is optional argument, default value is 0.001, this is small number to be added to the variance of x - * - * integer input arguments: - * 0: dataFormat, may have two values: zero -> NHWC, unity -> NCHW - * 1: isTraining, may have two values: zero -> inference, unity -> training - */ - #if NOT_EXCLUDED(OP_fused_batch_norm) - DECLARE_CUSTOM_OP(fused_batch_norm, 3, 1, false, 0, 2); - #endif +/** + * back prop in batch normalization + * + * Expected arguments: + * input: input array (any number of dimensions) + * mean: + * variance: + * gamma: optional + * beta: optional + * dLdOut: next epsilon + * + * Int args: + * 0: apply scale + * 1: apply offset + * + * T args: + * 0: epsilon + * + * output arrays: + * dL/dInput + * dL/dMean + * dL/dVariance + * dL/dGamma, optional + * dL/dBeta, optional + */ +#if NOT_EXCLUDED(OP_batchnorm) +DECLARE_CUSTOM_OP(batchnorm_bp, 4, 3, false, 1, 2); +#endif - #if NOT_EXCLUDED(OP_log_softmax) - DECLARE_CONFIGURABLE_OP(log_softmax, 1, 1, true, 0, 0); - DECLARE_CONFIGURABLE_OP(log_softmax_bp, 2, 1, true, 0, 0); - #endif +/** + * This operation updates parameters with provided gradients, wrt learning rate + * Expected arguments: + * x: parameters, any shape + * y: gradients. same shape as x + * lr: optional, learning rate + * + * T args: + * 0: optional, learning rate + */ +#if NOT_EXCLUDED(OP_apply_sgd) +DECLARE_CONFIGURABLE_OP(apply_sgd, 2, 1, true, -2, 0); +#endif +/** + * This operation performs batch normalization of layer, it is based on + * following article https://arxiv.org/abs/1502.03167. Expected arguments: x: + * input 4D array of shape [bS,iH,iW,iD] (data format = NHWC) or [bS,iD,iH,iW] + * (data format = NCHW), where bS - batch size iH - input height iW - input + * width iD - input depth (or number of channels) scale: 1D input array of + * scale factors, shape [iD] offset: 1D input array of offsets (shifts), shape + * [iD] mean: 1D input array of population mean used for inference, shape [iD], + * this array is required only if isTraining = false variance: 1D input array of + * population mean used for inference, shape [iD], this array is required only + * if isTraining = false + * + * T input arguments: + * 0: epsilon, it is optional argument, default value is 0.001, this is small + * number to be added to the variance of x + * + * integer input arguments: + * 0: dataFormat, may have two values: zero -> NHWC, unity -> NCHW + * 1: isTraining, may have two values: zero -> inference, unity -> training + */ +#if NOT_EXCLUDED(OP_fused_batch_norm) +DECLARE_CUSTOM_OP(fused_batch_norm, 3, 1, false, 0, 2); +#endif - /** - * relu_layer = relu(x*w + b) - */ - DECLARE_CUSTOM_OP(relu_layer, 3, 1, false, 0, 0); +#if NOT_EXCLUDED(OP_log_softmax) +DECLARE_CONFIGURABLE_OP(log_softmax, 1, 1, true, 0, 0); +DECLARE_CONFIGURABLE_OP(log_softmax_bp, 2, 1, true, 0, 0); +#endif - /** - * applies layer normalization to input - * y = g * standardize(x) + b - * - * see sd::ops::standardize - * - */ - #if NOT_EXCLUDED(OP_layer_norm) - DECLARE_CONFIGURABLE_OP(layer_norm, 3, 1, true, 0, -2); - DECLARE_CUSTOM_OP(layer_norm_bp, 4, 1, false, 0, -2); - #endif +/** + * relu_layer = relu(x*w + b) + */ +DECLARE_CUSTOM_OP(relu_layer, 3, 1, false, 0, 0); - /** - * This operation performs dot product attention on the given timeseries input with the given queries - * out = sum(similarity(k_i, q) * v_i) - * - * similarity(k, q) = softmax(k * q) where x * q is the dot product of x and q - * - * Optionally with normalization step: - * similarity(k, q) = softmax(k * q / sqrt(size(q)) - * - * See also "Attention is all you need" (https://arxiv.org/abs/1706.03762, p. 4, eq. 1) - * - * Note: This supports multiple queries at once, if only one query is available the queries vector still has to - * be 3D but can have queryCount = 1 - * - * Note: keys and values usually is the same array. If you want to use it as the same array, simply pass it for - * both. - * - * Expected arguments: - * q: input 3D array "queries" of shape [batchSize, featureKeys, queryCount] or 4D array of shape [batchSize, numHeads, featureKeys, queryCount] - * k: input 3D array "keys" of shape [batchSize, featureKeys, timesteps] or 4D array of shape [batchSize, numHeads, featureKeys, timesteps] - * v: input 3D array "values" of shape [batchSize, featureValues, timesteps] or 4D array of shape [batchSize, numHeads, featureValues, timesteps] - * mask: OPTIONAL; array that defines which values should be skipped of shape [batchSize, timesteps] - * - * integer input arguments: - * 0: normalization, may have two values: zero -> do not apply normalization, one -> apply normalization - * 1: withWeights, may have two values: zero -> do not return weights, one -> return weights - * - * Output Arrays: - * 0: Attention result arrays of shape [batchSize, featureValues, queryCount] or [batchSize, numHeads, featureValues, queryCount] - * 1: OPTIONAL; Attention weights of shape [batchSize, timesteps, queryCount] or [batchSize, numHeads, timesteps, queryCount] - */ - #if NOT_EXCLUDED(OP_dot_product_attention) - DECLARE_CUSTOM_OP(dot_product_attention, 3, -1, false, 0, 2); - DECLARE_CUSTOM_OP(dot_product_attention_bp, 4, 3, false, 0, 1); - #endif +/** + * applies layer normalization to input + * y = g * standardize(x) + b + * + * see sd::ops::standardize + * + */ +#if NOT_EXCLUDED(OP_layer_norm) +DECLARE_CONFIGURABLE_OP(layer_norm, 3, 1, true, 0, -2); +DECLARE_CUSTOM_OP(layer_norm_bp, 4, 1, false, 0, -2); +#endif +/** + * This operation performs dot product attention on the given timeseries input + * with the given queries out = sum(similarity(k_i, q) * v_i) + * + * similarity(k, q) = softmax(k * q) where x * q is the dot product of x and q + * + * Optionally with normalization step: + * similarity(k, q) = softmax(k * q / sqrt(size(q)) + * + * See also "Attention is all you need" (https://arxiv.org/abs/1706.03762, p. 4, + * eq. 1) + * + * Note: This supports multiple queries at once, if only one query is available + * the queries vector still has to be 3D but can have queryCount = 1 + * + * Note: keys and values usually is the same array. If you want to use it as the + * same array, simply pass it for both. + * + * Expected arguments: + * q: input 3D array "queries" of shape [batchSize, featureKeys, queryCount] or + * 4D array of shape [batchSize, numHeads, featureKeys, queryCount] k: input 3D + * array "keys" of shape [batchSize, featureKeys, timesteps] or 4D array of + * shape [batchSize, numHeads, featureKeys, timesteps] v: input 3D array + * "values" of shape [batchSize, featureValues, timesteps] or 4D array of shape + * [batchSize, numHeads, featureValues, timesteps] mask: OPTIONAL; array that + * defines which values should be skipped of shape [batchSize, timesteps] + * + * integer input arguments: + * 0: normalization, may have two values: zero -> do not apply normalization, + * one -> apply normalization 1: withWeights, may have two values: zero -> do + * not return weights, one -> return weights + * + * Output Arrays: + * 0: Attention result arrays of shape [batchSize, featureValues, queryCount] or + * [batchSize, numHeads, featureValues, queryCount] 1: OPTIONAL; Attention + * weights of shape [batchSize, timesteps, queryCount] or [batchSize, numHeads, + * timesteps, queryCount] + */ +#if NOT_EXCLUDED(OP_dot_product_attention) +DECLARE_CUSTOM_OP(dot_product_attention, 3, -1, false, 0, 2); +DECLARE_CUSTOM_OP(dot_product_attention_bp, 4, 3, false, 0, 1); +#endif - /** - * This performs multi-headed dot product attention on the given timeseries input - * out = concat(head_1, head_2, ..., head_n) * Wo - * head_i = dot_product_attention(Wq_i*q, Wk_i*k, Wv_i*v) - * - * Optionally with normalization when calculating the attention for each head. - * - * See also "Attention is all you need" (https://arxiv.org/abs/1706.03762, pp. 4,5, "3.2.2 Multi-Head Attention") - * - * This makes use of dot_product_attention OP support for rank 4 inputs. - * - * Expected arguments: - * q: input 3D array "queries" of shape [batchSize, featureKeys, queryCount] - * k: input 3D array "keys" of shape [batchSize, featureKeys, timesteps] - * v: input 3D array "values" of shape [batchSize, featureValues, timesteps] - * Wq: input query projection weights of shape [numHeads, projectedKeys, featureKeys] - * Wk: input key projection weights of shape [numHeads, projectedKeys, featureKeys] - * Wv: input value projection weights of shape [numHeads, projectedValues, featureValues] - * Wo: output projection weights of shape [numHeads * projectedValues, outSize] - * mask: OPTIONAL; array that defines which values should be skipped of shape [batchSize, timesteps] - * - * integer input arguments: - * 0: normalization, may have two values: zero -> do not apply normalization, one -> apply normalization - * 1: withWeights, may have two values: zero -> do not return weights, one -> return weights - * - * Output Arrays: - * 0: Attention result arrays of shape [batchSize, outSize, queryCount] - * 1: OPTIONAL; Attention weights of shape [batchSize, numHeads, timesteps, queryCount] - */ - #if NOT_EXCLUDED(OP_multi_head_dot_product_attention) - DECLARE_CUSTOM_OP(multi_head_dot_product_attention, 7, -1, false, 0, 2); - DECLARE_CUSTOM_OP(multi_head_dot_product_attention_bp, 8, 7, false, 0, 1); - #endif - } -} +/** + * This performs multi-headed dot product attention on the given timeseries + * input out = concat(head_1, head_2, ..., head_n) * Wo head_i = + * dot_product_attention(Wq_i*q, Wk_i*k, Wv_i*v) + * + * Optionally with normalization when calculating the attention for each head. + * + * See also "Attention is all you need" (https://arxiv.org/abs/1706.03762, pp. + * 4,5, "3.2.2 Multi-Head Attention") + * + * This makes use of dot_product_attention OP support for rank 4 inputs. + * + * Expected arguments: + * q: input 3D array "queries" of shape [batchSize, featureKeys, queryCount] + * k: input 3D array "keys" of shape [batchSize, featureKeys, timesteps] + * v: input 3D array "values" of shape [batchSize, featureValues, timesteps] + * Wq: input query projection weights of shape [numHeads, projectedKeys, + * featureKeys] Wk: input key projection weights of shape [numHeads, + * projectedKeys, featureKeys] Wv: input value projection weights of shape + * [numHeads, projectedValues, featureValues] Wo: output projection weights of + * shape [numHeads * projectedValues, outSize] mask: OPTIONAL; array that + * defines which values should be skipped of shape [batchSize, timesteps] + * + * integer input arguments: + * 0: normalization, may have two values: zero -> do not apply normalization, + * one -> apply normalization 1: withWeights, may have two values: zero -> do + * not return weights, one -> return weights + * + * Output Arrays: + * 0: Attention result arrays of shape [batchSize, outSize, queryCount] + * 1: OPTIONAL; Attention weights of shape [batchSize, numHeads, timesteps, + * queryCount] + */ +#if NOT_EXCLUDED(OP_multi_head_dot_product_attention) +DECLARE_CUSTOM_OP(multi_head_dot_product_attention, 7, -1, false, 0, 2); +DECLARE_CUSTOM_OP(multi_head_dot_product_attention_bp, 8, 7, false, 0, 1); +#endif +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/headers/parity_ops.h b/libnd4j/include/ops/declarable/headers/parity_ops.h index 27c012214664..839f54172fa7 100644 --- a/libnd4j/include/ops/declarable/headers/parity_ops.h +++ b/libnd4j/include/ops/declarable/headers/parity_ops.h @@ -25,34 +25,32 @@ #include namespace sd { - namespace ops { - /** - * This operation returns index of max element in a given NDArray (optionally: along given dimension(s)) - * Expected input: - * 0: N-dimensional array - * 1: optional axis vector - * - * Int args: - * 0: optional axis - */ - #if NOT_EXCLUDED(OP_argmax) - DECLARE_CUSTOM_OP(argmax, 1, 1, false, 0, -2); - #endif +namespace ops { +/** + * This operation returns index of max element in a given NDArray (optionally: + * along given dimension(s)) Expected input: 0: N-dimensional array 1: optional + * axis vector + * + * Int args: + * 0: optional axis + */ +#if NOT_EXCLUDED(OP_argmax) +DECLARE_CUSTOM_OP(argmax, 1, 1, false, 0, -2); +#endif - /** - * This operation returns index of min element in a given NDArray (optionally: along given dimension(s)) - * Expected input: - * 0: N-dimensional array - * 1: optional axis vector - * - * Int args: - * 0: optional axis - */ - #if NOT_EXCLUDED(OP_argmin) - DECLARE_CUSTOM_OP(argmin, 1, 1, false, 0, -2); - #endif +/** + * This operation returns index of min element in a given NDArray (optionally: + * along given dimension(s)) Expected input: 0: N-dimensional array 1: optional + * axis vector + * + * Int args: + * 0: optional axis + */ +#if NOT_EXCLUDED(OP_argmin) +DECLARE_CUSTOM_OP(argmin, 1, 1, false, 0, -2); +#endif - /** +/** * This operation returns index of absolute max element in a given NDArray (optionally: along given dimension(s)) * Expected input: * 0: N-dimensional array @@ -79,1697 +77,1764 @@ namespace sd { #endif /** - * This operation provides various normalization modes: - * 0: frobenius - * 1: euclidean (norm2) - * 2: norm1 - * 3: norm2 - * 4: inf-norm - * 5: p-norm - * - * Expected arguments: - * input: N-dimensional array - * - * - * Int args: - * 0...: axis - * - * T args: - * 0: norm mode - * 1: p for p-norm - */ - #if NOT_EXCLUDED(OP_norm) - DECLARE_REDUCTION_OP(norm, 1, 1, false, 1, -2); - #endif - - /** - * Inserts elements provided by diagonal array into the main diagonal of innermost matrices of input array - * - * Input arrays: - * 0: input array, considered as batch of matrices - * 1: diagonal array containing elements to be inserted into input array, - * following rank condition should be satisfied: diagonal_rank = input_rank - 1, - * the shapes of diagonal and input arrays must be equal except last dimension of input array, - * for example if input_shape = [A,B,C,D] then diagonal_shape = [A,B,C], - * also last dimension of diagonal array should be equal to smaller of last and last but one input dimensions - * that is: diagonal_shape[-1] = min(input_shape[-1], input_shape[-2]) - * - * Output array: - * 0: has the same shape as input, corresponding diagonal elements are substituted - */ - #if NOT_EXCLUDED(OP_matrix_set_diag) - DECLARE_CONFIGURABLE_OP(matrix_set_diag, 2, 1, false, 0, 0); - #endif + * This operation provides various normalization modes: + * 0: frobenius + * 1: euclidean (norm2) + * 2: norm1 + * 3: norm2 + * 4: inf-norm + * 5: p-norm + * + * Expected arguments: + * input: N-dimensional array + * + * + * Int args: + * 0...: axis + * + * T args: + * 0: norm mode + * 1: p for p-norm + */ +#if NOT_EXCLUDED(OP_norm) +DECLARE_REDUCTION_OP(norm, 1, 1, false, 1, -2); +#endif - /** - * Inserts elements provided by diagonal array into the main diagonal of innermost matrices of output array, - * rest output elements are set to zeros - * - * Input array: - * diagonal: array containing elements to be inserted into output array, - * following rank condition is present: diagonal_rank = ouput_rank - 1 - * - * Output array: - * 0: is considered as batch of matrices, if for example diagonal array has shape [A,B,C] then output array has shape [A,B,C,C] - */ - DECLARE_CUSTOM_OP(matrix_diag, 1, 1, false, 0, 0); +/** + * Inserts elements provided by diagonal array into the main diagonal of + * innermost matrices of input array + * + * Input arrays: + * 0: input array, considered as batch of matrices + * 1: diagonal array containing elements to be inserted into input array, + * following rank condition should be satisfied: diagonal_rank = input_rank + * - 1, the shapes of diagonal and input arrays must be equal except last + * dimension of input array, for example if input_shape = [A,B,C,D] then + * diagonal_shape = [A,B,C], also last dimension of diagonal array should be + * equal to smaller of last and last but one input dimensions that is: + * diagonal_shape[-1] = min(input_shape[-1], input_shape[-2]) + * + * Output array: + * 0: has the same shape as input, corresponding diagonal elements are + * substituted + */ +#if NOT_EXCLUDED(OP_matrix_set_diag) +DECLARE_CONFIGURABLE_OP(matrix_set_diag, 2, 1, false, 0, 0); +#endif - /** - * This op calculates regularized incomplete beta integral Ix(a, b). - * Implementation is based on two algorithms depending on input values of a and b: - * - when a and b are both > maxValue (3000.), then Gauss-Legendre quadrature method is applied - * - when a and b are both <= maxValue (3000.), then modified Lentz’s algorithm for continued fractions is applied - * - * Input arrays: - * a: defines power t^{a-1}, must be > 0, type float. - * b: defines power (1-t)^{b-1}, must be > 0, type float. - * x: defines upper limit of integration, must be within (0 <= x <= 1) range, type float. - * - * Output array: - * 0: values of regularized incomplete beta integral that corresponds to variable upper limit x, type float - * - * Three input and one output arrays must have the same shape - */ - #if NOT_EXCLUDED(OP_betainc) - DECLARE_CONFIGURABLE_OP(betainc, 3, 1, false, 0, 0); - #endif +/** + * Inserts elements provided by diagonal array into the main diagonal of + * innermost matrices of output array, rest output elements are set to zeros + * + * Input array: + * diagonal: array containing elements to be inserted into output array, + * following rank condition is present: diagonal_rank = ouput_rank + * - 1 + * + * Output array: + * 0: is considered as batch of matrices, if for example diagonal array has + * shape [A,B,C] then output array has shape [A,B,C,C] + */ +DECLARE_CUSTOM_OP(matrix_diag, 1, 1, false, 0, 0); - /** - * This operation is added for compatibility purposes mostly. - * PLEASE NOTE: Please consider using Add instead - * Expected arguments: - * 0: N-dimensional input - * 1: bias vector - */ - #if NOT_EXCLUDED(OP_biasadd) - DECLARE_CUSTOM_OP(biasadd, 2, 1, true, 0, 0); - DECLARE_CUSTOM_OP(biasadd_bp, 3, 2, false, 0, 0); - #endif +/** + * This op calculates regularized incomplete beta integral Ix(a, b). + * Implementation is based on two algorithms depending on input values of a and + * b: + * - when a and b are both > maxValue (3000.), then Gauss-Legendre quadrature + * method is applied + * - when a and b are both <= maxValue (3000.), then modified Lentz’s algorithm + * for continued fractions is applied + * + * Input arrays: + * a: defines power t^{a-1}, must be > 0, type float. + * b: defines power (1-t)^{b-1}, must be > 0, type float. + * x: defines upper limit of integration, must be within (0 <= x <= 1) range, + * type float. + * + * Output array: + * 0: values of regularized incomplete beta integral that corresponds to + * variable upper limit x, type float + * + * Three input and one output arrays must have the same shape + */ +#if NOT_EXCLUDED(OP_betainc) +DECLARE_CONFIGURABLE_OP(betainc, 3, 1, false, 0, 0); +#endif - /** - * Returns a diagonal tensor with a given diagonal values. Given a diagonal, this operation returns a tensor with the diagonal and everything else padded with zeros. - */ - #if NOT_EXCLUDED(OP_diag) - DECLARE_CUSTOM_OP(diag, 1, 1, false, 0, 0); - #endif +/** + * This operation is added for compatibility purposes mostly. + * PLEASE NOTE: Please consider using Add instead + * Expected arguments: + * 0: N-dimensional input + * 1: bias vector + */ +#if NOT_EXCLUDED(OP_biasadd) +DECLARE_CUSTOM_OP(biasadd, 2, 1, true, 0, 0); +DECLARE_CUSTOM_OP(biasadd_bp, 3, 2, false, 0, 0); +#endif - /** - * Returns a diagonal tensor with a given diagonal values. Given a diagonal, this operation returns a tensor with the diagonal and everything else padded with zeros. - */ - #if NOT_EXCLUDED(OP_diag_part) - DECLARE_CUSTOM_OP(diag_part, 1, 1, false, 0, 0); - #endif +/** + * Returns a diagonal tensor with a given diagonal values. Given a diagonal, + * this operation returns a tensor with the diagonal and everything else padded + * with zeros. + */ +#if NOT_EXCLUDED(OP_diag) +DECLARE_CUSTOM_OP(diag, 1, 1, false, 0, 0); +#endif - /** - * Returns a diagonal vector for any submatricies with in a given tensor. - * It is an op inverse to matrix_set_giag. - * Using input tensor as batched 2D diagonals flat them to vector (1D) with diagonal values. - * - * Input : batched tensor with rank >=2 - * Output: tensor with rank lesser by 1 from input - */ - #if NOT_EXCLUDED(OP_matrix_diag_part) - DECLARE_CUSTOM_OP(matrix_diag_part, 1, 1, false, 0, 0); - #endif +/** + * Returns a diagonal tensor with a given diagonal values. Given a diagonal, + * this operation returns a tensor with the diagonal and everything else padded + * with zeros. + */ +#if NOT_EXCLUDED(OP_diag_part) +DECLARE_CUSTOM_OP(diag_part, 1, 1, false, 0, 0); +#endif - /** - * QR decomposition: A = QR, where Q is ortogonal (Q * QT = I) and R is upper triangular. - * For A (MxN) Q is M x M and R is (NxN). - * - * Input : - * 0 - float (or complex float) tensor with shape {.,..,...,M,N} - batch of float matricies - * - * Output: - * 0 - float tensor with shape {.,..,...,MxN} - batch of ortogonal matricies {Qs} - * 1 - float tensor with shape {.,..,...,NxN} - batch of upper triangular matricies {Rs} - */ - #if NOT_EXCLUDED(OP_qr) - DECLARE_CUSTOM_OP(qr, 1, 2, false, 0, 0); - #endif +/** + * Returns a diagonal vector for any submatricies with in a given tensor. + * It is an op inverse to matrix_set_giag. + * Using input tensor as batched 2D diagonals flat them to vector (1D) with + * diagonal values. + * + * Input : batched tensor with rank >=2 + * Output: tensor with rank lesser by 1 from input + */ +#if NOT_EXCLUDED(OP_matrix_diag_part) +DECLARE_CUSTOM_OP(matrix_diag_part, 1, 1, false, 0, 0); +#endif - /** - * This operation takes 2 arrays: original values, and values to be excluded. And returns 2 arrays: values left after exclusion, and indices in original array for surivals. - * Expected arguments: - * 0: vector with original values - * 1: vector with values to exclude - */ - #if NOT_EXCLUDED(OP_listdiff) - DECLARE_CUSTOM_OP(listdiff, 2, 2, false, 0, 0); - #endif +/** + * QR decomposition: A = QR, where Q is ortogonal (Q * QT = I) and R is upper + * triangular. For A (MxN) Q is M x M and R is (NxN). + * + * Input : + * 0 - float (or complex float) tensor with shape {.,..,...,M,N} - batch of + * float matricies + * + * Output: + * 0 - float tensor with shape {.,..,...,MxN} - batch of ortogonal matricies + * {Qs} 1 - float tensor with shape {.,..,...,NxN} - batch of upper triangular + * matricies {Rs} + */ +#if NOT_EXCLUDED(OP_qr) +DECLARE_CUSTOM_OP(qr, 1, 2, false, 0, 0); +#endif - /** - * This operation applies Add operation to specific inputs wrt indices - * Expected arguments: - * input: array to be updated - * indices: array containing indexes for first dimension of input - * updates: array containing elements to be interfered with input - */ - #if NOT_EXCLUDED(OP_scatter_add) - DECLARE_OP(scatter_add, 3, 1, true); - #endif +/** + * This operation takes 2 arrays: original values, and values to be excluded. + * And returns 2 arrays: values left after exclusion, and indices in original + * array for surivals. Expected arguments: 0: vector with original values 1: + * vector with values to exclude + */ +#if NOT_EXCLUDED(OP_listdiff) +DECLARE_CUSTOM_OP(listdiff, 2, 2, false, 0, 0); +#endif - /** - * This operation applies Subtract operation to specific inputs wrt indices - * Expected arguments: - * input: array to be updated - * indices: array containing indexes for first dimension of input - * updates: array containing elements to be interfered with input - */ - #if NOT_EXCLUDED(OP_scatter_sub) - DECLARE_OP(scatter_sub, 3, 1, true); - #endif +/** + * This operation applies Add operation to specific inputs wrt indices + * Expected arguments: + * input: array to be updated + * indices: array containing indexes for first dimension of input + * updates: array containing elements to be interfered with input + */ +#if NOT_EXCLUDED(OP_scatter_add) +DECLARE_OP(scatter_add, 3, 1, true); +#endif - /** - * This operation applies Multiply operation to specific inputs wrt indices - * Expected arguments: - * input: array to be updated - * indices: array containing indexes for first dimension of input - * updates: array containing elements to be interfered with input - */ - #if NOT_EXCLUDED(OP_scatter_mul) - DECLARE_OP(scatter_mul, 3, 1, true); - #endif +/** + * This operation applies Subtract operation to specific inputs wrt indices + * Expected arguments: + * input: array to be updated + * indices: array containing indexes for first dimension of input + * updates: array containing elements to be interfered with input + */ +#if NOT_EXCLUDED(OP_scatter_sub) +DECLARE_OP(scatter_sub, 3, 1, true); +#endif - /** - * This operation applies Divide operation to specific inputs wrt indices - * Expected arguments: - * input: array to be updated - * indices: array containing indexes for first dimension of input - * updates: array containing elements to be interfered with input - */ - #if NOT_EXCLUDED(OP_scatter_div) - DECLARE_OP(scatter_div, 3, 1, true); - #endif +/** + * This operation applies Multiply operation to specific inputs wrt indices + * Expected arguments: + * input: array to be updated + * indices: array containing indexes for first dimension of input + * updates: array containing elements to be interfered with input + */ +#if NOT_EXCLUDED(OP_scatter_mul) +DECLARE_OP(scatter_mul, 3, 1, true); +#endif - /** - * This operation applies Assign operation to specific inputs wrt indices - * Expected arguments: - * input: array to be updated - * indices: array containing indexes for first dimension of input - * updates: array containing elements to be interfered with input - */ - #if NOT_EXCLUDED(OP_scatter_upd) - DECLARE_OP(scatter_upd, 3, 1, true); - #endif +/** + * This operation applies Divide operation to specific inputs wrt indices + * Expected arguments: + * input: array to be updated + * indices: array containing indexes for first dimension of input + * updates: array containing elements to be interfered with input + */ +#if NOT_EXCLUDED(OP_scatter_div) +DECLARE_OP(scatter_div, 3, 1, true); +#endif - /** - * This operation applies Max operation to specific inputs through given indices - * Expected arguments: - * input: array to be updated - * indices: array containing indexes for first dimension of input - * updates: array containing elements to be interfered with input - */ - #if NOT_EXCLUDED(OP_scatter_max) - DECLARE_OP(scatter_max, 3, 1, true); - #endif +/** + * This operation applies Assign operation to specific inputs wrt indices + * Expected arguments: + * input: array to be updated + * indices: array containing indexes for first dimension of input + * updates: array containing elements to be interfered with input + */ +#if NOT_EXCLUDED(OP_scatter_upd) +DECLARE_OP(scatter_upd, 3, 1, true); +#endif - /** - * This operation applies Min operation to specific inputs through given indices - * Expected arguments: - * input: array to be updated - * indices: array containing indexes for first dimension of input - * updates: array containing elements to be interfered with input - */ - #if NOT_EXCLUDED(OP_scatter_min) - DECLARE_OP(scatter_min, 3, 1, true); - #endif +/** + * This operation applies Max operation to specific inputs through given indices + * Expected arguments: + * input: array to be updated + * indices: array containing indexes for first dimension of input + * updates: array containing elements to be interfered with input + */ +#if NOT_EXCLUDED(OP_scatter_max) +DECLARE_OP(scatter_max, 3, 1, true); +#endif - /** - * This operation scatter "updates" elements into new output array according to given "indices" - * Expected arguments: - * indices: array containing elements/slices indexes of output array to put "updates" elements into, the rest output elements will be zeros - * updates: array containing elements to be inserted into output array - * shape: contains shape of output array - */ - #if NOT_EXCLUDED(OP_scatter_nd) - DECLARE_CUSTOM_OP(scatter_nd, 3, 1, false, 0, 0); - #endif +/** + * This operation applies Min operation to specific inputs through given indices + * Expected arguments: + * input: array to be updated + * indices: array containing indexes for first dimension of input + * updates: array containing elements to be interfered with input + */ +#if NOT_EXCLUDED(OP_scatter_min) +DECLARE_OP(scatter_min, 3, 1, true); +#endif - /** - * This operation scatter "updates" elements into input array along given "indices" - * Expected arguments: - * input: array to be updated - * indices: array containing elements/slices indexes of input array to put "updates" elements into - * updates: array containing elements to be inserted into input array - */ - #if NOT_EXCLUDED(OP_scatter_nd_update) - DECLARE_OP(scatter_nd_update, 3, 1, true); - #endif +/** + * This operation scatter "updates" elements into new output array according to + * given "indices" Expected arguments: indices: array containing elements/slices + * indexes of output array to put "updates" elements into, the rest output + * elements will be zeros updates: array containing elements to be inserted into + * output array shape: contains shape of output array + */ +#if NOT_EXCLUDED(OP_scatter_nd) +DECLARE_CUSTOM_OP(scatter_nd, 3, 1, false, 0, 0); +#endif - /** - * This operation adds "updates" elements to input array along given "indices" - * Expected arguments: - * input: array to be updated - * indices: array containing elements/slices indexes of input array to add "updates" elements to - * updates: array containing elements to be interfered with input - */ - #if NOT_EXCLUDED(OP_scatter_add) - DECLARE_OP(scatter_nd_add, 3, 1, true); - #endif +/** + * This operation scatter "updates" elements into input array along given + * "indices" Expected arguments: input: array to be updated indices: array + * containing elements/slices indexes of input array to put "updates" elements + * into updates: array containing elements to be inserted into input array + */ +#if NOT_EXCLUDED(OP_scatter_nd_update) +DECLARE_OP(scatter_nd_update, 3, 1, true); +#endif - /** - * This operation subtract "updates" elements from input array along given "indices" - * Expected arguments: - * input: array to be updated - * indices: array containing elements/slices indexes of input array to subtract "updates" elements from - * updates: array containing elements to be interfered with input - */ - #if NOT_EXCLUDED(OP_scatter_sub) - DECLARE_OP(scatter_nd_sub, 3, 1, true); - #endif +/** + * This operation adds "updates" elements to input array along given "indices" + * Expected arguments: + * input: array to be updated + * indices: array containing elements/slices indexes of input array to add + * "updates" elements to updates: array containing elements to be interfered + * with input + */ +#if NOT_EXCLUDED(OP_scatter_add) +DECLARE_OP(scatter_nd_add, 3, 1, true); +#endif - /** - * This operation takes input's shape, and returns new NDArray filled with specified value - * Expected arguments: - * input: N-dimensional array - * - * T args: - * 0: scalar value, used to fill NDArray - */ - #if NOT_EXCLUDED(OP_fill_as) - DECLARE_CONFIGURABLE_OP(fill_as, 1, 1, true, 1, 0); - #endif +/** + * This operation subtract "updates" elements from input array along given + * "indices" Expected arguments: input: array to be updated indices: array + * containing elements/slices indexes of input array to subtract "updates" + * elements from updates: array containing elements to be interfered with input + */ +#if NOT_EXCLUDED(OP_scatter_sub) +DECLARE_OP(scatter_nd_sub, 3, 1, true); +#endif - /** - * This operation applies element-wise rint (round to integral value) operation - */ - #if NOT_EXCLUDED(OP_rint) - DECLARE_OP(rint, 1, 1, true); - #endif +/** + * This operation takes input's shape, and returns new NDArray filled with + * specified value Expected arguments: input: N-dimensional array + * + * T args: + * 0: scalar value, used to fill NDArray + */ +#if NOT_EXCLUDED(OP_fill_as) +DECLARE_CONFIGURABLE_OP(fill_as, 1, 1, true, 1, 0); +#endif - /** - * This operation returns unique elements from input array as vector, and their original indices in input array - * Expected input: - * input: N-dimensional array - */ - #if NOT_EXCLUDED(OP_unique) - DECLARE_CUSTOM_OP(unique, 1, 2, false, 0, 0); - #endif +/** + * This operation applies element-wise rint (round to integral value) operation + */ +#if NOT_EXCLUDED(OP_rint) +DECLARE_OP(rint, 1, 1, true); +#endif - /** - * This operation returns 3 1D arrays for given 1D array with unique element count and indexes - * input: - * 0 - 1D array - * - * output: - * 0 - 1D array with unique values - * 1 - 1D array with ids for values in array above - * 2 - 1D array with counts for values in array above - */ - #if NOT_EXCLUDED(OP_unique_with_counts) - DECLARE_CUSTOM_OP(unique_with_counts, 1, 3, false, 0, 0); - #endif +/** + * This operation returns unique elements from input array as vector, and their + * original indices in input array Expected input: input: N-dimensional array + */ +#if NOT_EXCLUDED(OP_unique) +DECLARE_CUSTOM_OP(unique, 1, 2, false, 0, 0); +#endif - /** - * This operation splits input NDArray into multiple TADs along given dimensions - * Expected arguments: - * input: N-dimensional array - * - * Int args: - * 0..: TAD axis - */ - #if NOT_EXCLUDED(OP_tear) - DECLARE_CUSTOM_OP(tear, 1, -1, false, 0, -1); - #endif +/** + * This operation returns 3 1D arrays for given 1D array with unique element + * count and indexes input: 0 - 1D array + * + * output: + * 0 - 1D array with unique values + * 1 - 1D array with ids for values in array above + * 2 - 1D array with counts for values in array above + */ +#if NOT_EXCLUDED(OP_unique_with_counts) +DECLARE_CUSTOM_OP(unique_with_counts, 1, 3, false, 0, 0); +#endif - /** - * This op does the same as tear, just uses different input format: - * @tparam T - */ - #if NOT_EXCLUDED(OP_unstack) - DECLARE_CUSTOM_OP(unstack, 1, -1, false, 0, 1); - #endif +/** + * This operation splits input NDArray into multiple TADs along given dimensions + * Expected arguments: + * input: N-dimensional array + * + * Int args: + * 0..: TAD axis + */ +#if NOT_EXCLUDED(OP_tear) +DECLARE_CUSTOM_OP(tear, 1, -1, false, 0, -1); +#endif - /** - * This operation extracts a strided (optionally) slice from a tensor, - */ - #if NOT_EXCLUDED(OP_strided_slice) - DECLARE_CUSTOM_OP(strided_slice, 1, 1, false, 0, 5); // TODO: new op type needed. that returns VIEW - DECLARE_CUSTOM_OP(strided_slice_bp, 2, 1, false, 0, 5); - #endif +/** + * This op does the same as tear, just uses different input format: + * @tparam T + */ +#if NOT_EXCLUDED(OP_unstack) +DECLARE_CUSTOM_OP(unstack, 1, -1, false, 0, 1); +#endif - /** - * This operation extracts a slice from a tensor. - * - */ - #if NOT_EXCLUDED(OP_slice) - DECLARE_CUSTOM_OP(slice, 1, 1, false, 0, -2); - DECLARE_CUSTOM_OP(slice_bp, 2, 1, false, 0, -2); - #endif +/** + * This operation extracts a strided (optionally) slice from a tensor, + */ +#if NOT_EXCLUDED(OP_strided_slice) +DECLARE_CUSTOM_OP(strided_slice, 1, 1, false, 0, + 5); // TODO: new op type needed. that returns VIEW +DECLARE_CUSTOM_OP(strided_slice_bp, 2, 1, false, 0, 5); +#endif - /** - * This operation generate sequences. Basically from......to, with step used as increment. - * Expected arguments: - * start: optional scalar with starting value - * stop: optional scalar with end value - * step: optional scalar witn step value - * - * Int args: (optional) - * 0: optional scalar with starting value - * 1: optional scalar with end value - * 1: optional scalar witn step value - * - * T args: (optional) - * 0: optional scalar with starting value - * 1: optional scalar with end value - * 1: optional scalar witn step value - */ - #if NOT_EXCLUDED(OP_range) - DECLARE_CUSTOM_OP(range, -2, 1, false, -2, -2); - #endif +/** + * This operation extracts a slice from a tensor. + * + */ +#if NOT_EXCLUDED(OP_slice) +DECLARE_CUSTOM_OP(slice, 1, 1, false, 0, -2); +DECLARE_CUSTOM_OP(slice_bp, 2, 1, false, 0, -2); +#endif - /** - * This operation return one-hot encoded n-dimensional array - * Expected arguments: - * input: N-dimensional array - * - * T args: - * 0: 'on' value - * 1: 'off' value - * - * Int args: - * 0: depth - * 1: axis - */ - #if NOT_EXCLUDED(OP_onehot) - DECLARE_CUSTOM_OP(onehot, 1, 1, false, -2, -2); - #endif +/** + * This operation generate sequences. Basically from......to, with step used as + * increment. Expected arguments: start: optional scalar with starting value + * stop: optional scalar with end value + * step: optional scalar witn step value + * + * Int args: (optional) + * 0: optional scalar with starting value + * 1: optional scalar with end value + * 1: optional scalar witn step value + * + * T args: (optional) + * 0: optional scalar with starting value + * 1: optional scalar with end value + * 1: optional scalar witn step value + */ +#if NOT_EXCLUDED(OP_range) +DECLARE_CUSTOM_OP(range, -2, 1, false, -2, -2); +#endif +/** + * This operation return one-hot encoded n-dimensional array + * Expected arguments: + * input: N-dimensional array + * + * T args: + * 0: 'on' value + * 1: 'off' value + * + * Int args: + * 0: depth + * 1: axis + */ +#if NOT_EXCLUDED(OP_onehot) +DECLARE_CUSTOM_OP(onehot, 1, 1, false, -2, -2); +#endif - /** - * This operation calculate the confusion matrix for a - * pair of prediction and label 1-D arrays. - * Expected arguments: - * Input arrays: - * 0 - predictions: 1-D array - * 1 - labels: 1-D array - * 2 - weights : optional - * Int args: - * 0 - num_classes: optional - * - */ - #if NOT_EXCLUDED(OP_confusion_matrix) - DECLARE_CUSTOM_OP(confusion_matrix, 2, 1, false, 0, -2); - #endif +/** + * This operation calculate the confusion matrix for a + * pair of prediction and label 1-D arrays. + * Expected arguments: + * Input arrays: + * 0 - predictions: 1-D array + * 1 - labels: 1-D array + * 2 - weights : optional + * Int args: + * 0 - num_classes: optional + * + */ +#if NOT_EXCLUDED(OP_confusion_matrix) +DECLARE_CUSTOM_OP(confusion_matrix, 2, 1, false, 0, -2); +#endif - /** - * This operation stacks a list of rank tensors into one rank-(R+1) tensor. - * Expected arguments: - * 0...: N-Dimensional arrays to stack - * - */ - #if NOT_EXCLUDED(OP_stack) - DECLARE_CUSTOM_OP(stack, -1, 1, false, 0, 0); - #endif +/** + * This operation stacks a list of rank tensors into one rank-(R+1) tensor. + * Expected arguments: + * 0...: N-Dimensional arrays to stack + * + */ +#if NOT_EXCLUDED(OP_stack) +DECLARE_CUSTOM_OP(stack, -1, 1, false, 0, 0); +#endif - /** - * This operation returns length of input array - * Expected arguments: - * input: N-dimensional array - * - * TODO: make this operation reduction, to allow TAD -> size - */ - #if NOT_EXCLUDED(OP_size) - DECLARE_CUSTOM_OP(size, 1, 1, false, 0, 0); // add DeclarableScalarOp? - #endif +/** + * This operation returns length of input array + * Expected arguments: + * input: N-dimensional array + * + * TODO: make this operation reduction, to allow TAD -> size + */ +#if NOT_EXCLUDED(OP_size) +DECLARE_CUSTOM_OP(size, 1, 1, false, 0, 0); // add DeclarableScalarOp? +#endif +/** + * This operation returns rank of input array as scalar value. + */ +#if NOT_EXCLUDED(OP_rank) +DECLARE_CUSTOM_OP(rank, 1, 1, false, 0, 0); // ^ +#endif - /** - * This operation returns rank of input array as scalar value. - */ - #if NOT_EXCLUDED(OP_rank) - DECLARE_CUSTOM_OP(rank, 1, 1, false, 0, 0); // ^ - #endif +#if NOT_EXCLUDED(OP_broadcastgradientargs) +DECLARE_OP(broadcastgradientargs, 2, 2, true); +#endif +/** + * This operation takes input's shape, and returns new NDArray filled with zeros + * Expected arguments: + * input: N-dimensional array + * + */ +#if NOT_EXCLUDED(OP_zeros_as) +DECLARE_CUSTOM_OP(zeros_as, 1, 1, false, 0, 0); +#endif - #if NOT_EXCLUDED(OP_broadcastgradientargs) - DECLARE_OP(broadcastgradientargs, 2, 2, true); - #endif +/** + * This operation takes input's shape, and returns new NDArray filled with ones + * Expected arguments: + * input: N-dimensional array + * + */ +#if NOT_EXCLUDED(OP_ones_as) +DECLARE_CUSTOM_OP(ones_as, 1, 1, false, 0, 0); +#endif - /** - * This operation takes input's shape, and returns new NDArray filled with zeros - * Expected arguments: - * input: N-dimensional array - * - */ - #if NOT_EXCLUDED(OP_zeros_as) - DECLARE_CUSTOM_OP(zeros_as, 1, 1, false, 0, 0); - #endif +/** + * This operation applies element-wise pow(x, 2) to the given input + * Expected arguments: + * input: N-Dimensional array + */ +#if NOT_EXCLUDED(OP_square) +DECLARE_OP(square, 1, 1, true); +#endif - /** - * This operation takes input's shape, and returns new NDArray filled with ones - * Expected arguments: - * input: N-dimensional array - * - */ - #if NOT_EXCLUDED(OP_ones_as) - DECLARE_CUSTOM_OP(ones_as, 1, 1, false, 0, 0); - #endif +/** + * This op calculates Hurwitz zeta function zeta(x, q) = sum_{n=0}^{inf} (q + + * n)^{-x} Implementation is based on Euler-Maclaurin summation formula + * + * Input arrays: + * x: define power {-x}, must be > 1, type float. + * q: define summand in denominator, must be > 0, type float. + * + * Output array: + * 0: corresponding values of Hurwitz zeta function + * + * Two input and one output arrays must have the same shape + */ +#if NOT_EXCLUDED(OP_zeta) +DECLARE_CONFIGURABLE_OP(zeta, 2, 1, false, 0, 0); +#endif - /** - * This operation applies element-wise pow(x, 2) to the given input - * Expected arguments: - * input: N-Dimensional array - */ - #if NOT_EXCLUDED(OP_square) - DECLARE_OP(square, 1, 1, true); - #endif +/** + * This op calculates polygamma function psi^(n)(x). Implementation is based on + * serial representation written in terms of the Hurwitz zeta function: + * polygamma = (-1)^{n+1} * n! * zeta(n+1, x). + * + * Input arrays: + * 0: n - define derivative order (n+1), type integer (however currently is + * implemented as float casted to integer) 1: x - abscissa points where to + * evaluate the polygamma function, type float + * + * Output array: + * 0: values of polygamma function at corresponding x, type float + * + * Two input and one output arrays have the same shape + */ +#if NOT_EXCLUDED(OP_polygamma) +DECLARE_CONFIGURABLE_OP(polygamma, 2, 1, false, 0, 0); +#endif - /** - * This op calculates Hurwitz zeta function zeta(x, q) = sum_{n=0}^{inf} (q + n)^{-x} - * Implementation is based on Euler-Maclaurin summation formula - * - * Input arrays: - * x: define power {-x}, must be > 1, type float. - * q: define summand in denominator, must be > 0, type float. - * - * Output array: - * 0: corresponding values of Hurwitz zeta function - * - * Two input and one output arrays must have the same shape - */ - #if NOT_EXCLUDED(OP_zeta) - DECLARE_CONFIGURABLE_OP(zeta, 2, 1, false, 0, 0); - #endif +/** + * This op calculates lgamma function lgamma(x) = log(Gamma(x)) + * + * Input arrays: + * 0: x - input matrix + * + * Output array: + * 0: log of Gamma(x) + * + */ +#if NOT_EXCLUDED(OP_lgamma) +DECLARE_OP(lgamma, 1, 1, true); +#endif - /** - * This op calculates polygamma function psi^(n)(x). Implementation is based on serial representation written in - * terms of the Hurwitz zeta function: polygamma = (-1)^{n+1} * n! * zeta(n+1, x). - * - * Input arrays: - * 0: n - define derivative order (n+1), type integer (however currently is implemented as float casted to integer) - * 1: x - abscissa points where to evaluate the polygamma function, type float - * - * Output array: - * 0: values of polygamma function at corresponding x, type float - * - * Two input and one output arrays have the same shape - */ - #if NOT_EXCLUDED(OP_polygamma) - DECLARE_CONFIGURABLE_OP(polygamma, 2, 1, false, 0, 0); - #endif +/** + * This op calculates digamma function psi(x) = derivative of log(Gamma(x)) + * + * Input arrays: + * 0: x - abscissa points where to evaluate the digamma function, type float + * + * Output array: + * 0: values of digamma function at corresponding x, type float + * + */ +#if NOT_EXCLUDED(OP_digamma) +DECLARE_CONFIGURABLE_OP(digamma, 1, 1, false, 0, 0); +#endif - /** - * This op calculates lgamma function lgamma(x) = log(Gamma(x)) - * - * Input arrays: - * 0: x - input matrix - * - * Output array: - * 0: log of Gamma(x) - * - */ - #if NOT_EXCLUDED(OP_lgamma) - DECLARE_OP(lgamma, 1, 1, true); - #endif +/** + * This operation takes shape as first argument, and returns new NDArray filled + * with specific scalar value. Input arrays: 0 - shape vector 1 - optional + * scalar NDArray + * + * T arguments: + * 0 - optional scalar value + * + */ +#if NOT_EXCLUDED(OP_fill) +DECLARE_CUSTOM_OP(fill, 1, 1, false, -2, 0); +#endif - /** - * This op calculates digamma function psi(x) = derivative of log(Gamma(x)) - * - * Input arrays: - * 0: x - abscissa points where to evaluate the digamma function, type float - * - * Output array: - * 0: values of digamma function at corresponding x, type float - * - */ - #if NOT_EXCLUDED(OP_digamma) - DECLARE_CONFIGURABLE_OP(digamma, 1, 1, false, 0, 0); - #endif +/** + * This operation splits given NDArray into chunks of specific size, along given + * dimension Input arrays: 0 - input array 1 - array of sizes 2 - optional axis + * + * Integer arguments: + * 0 - optional axis + * + */ +#if NOT_EXCLUDED(OP_split_v) +DECLARE_CUSTOM_OP(split_v, 2, -1, false, 0, -2); +#endif - /** - * This operation takes shape as first argument, and returns new NDArray filled with specific scalar value. - * Input arrays: - * 0 - shape vector - * 1 - optional scalar NDArray - * - * T arguments: - * 0 - optional scalar value - * - */ - #if NOT_EXCLUDED(OP_fill) - DECLARE_CUSTOM_OP(fill, 1, 1, false, -2, 0); - #endif +/** + * This operation splits given NDArray into chunks of specific size, along given + * dimension 0 - input array 1 - optional axis + * + * Integer arguments: + * 0 - number of splits + * 1 - optional axis + */ +#if NOT_EXCLUDED(OP_split) +DECLARE_CUSTOM_OP(split, 1, -1, false, 0, 1); +#endif - /** - * This operation splits given NDArray into chunks of specific size, along given dimension - * Input arrays: - * 0 - input array - * 1 - array of sizes - * 2 - optional axis - * - * Integer arguments: - * 0 - optional axis - * - */ - #if NOT_EXCLUDED(OP_split_v) - DECLARE_CUSTOM_OP(split_v, 2, -1, false, 0, -2); - #endif +/** + * This operation adjusts image hue by delta + * Input arrays: + * 0 - input array with rank >= 3, must have at least one dimension equal 3, + * that is dimension containing channels. 1 - optional argument, input + * scalar-array containing delta + * + * T arguments: + * 0 - optional argument, delta value + * + * Int arguments: + * 0 - optional argument, corresponds to dimension with 3 channels + */ +#if NOT_EXCLUDED(OP_adjust_hue) +DECLARE_CONFIGURABLE_OP(adjust_hue, 1, 1, true, 0, 0); +#endif - /** - * This operation splits given NDArray into chunks of specific size, along given dimension - * 0 - input array - * 1 - optional axis - * - * Integer arguments: - * 0 - number of splits - * 1 - optional axis - */ - #if NOT_EXCLUDED(OP_split) - DECLARE_CUSTOM_OP(split, 1, -1, false, 0, 1); - #endif +/** + * This operation adjusts image saturation by delta + * Input arrays: + * 0 - input array with rank >= 3, must have at least one dimension equal 3, + * that is dimension containing channels. 1 - optional argument, input + * scalar-array containing saturation factor + * + * T arguments: + * 0 - optional argument, saturation factor + * + * Int arguments: + * 0 - optional argument, corresponds to dimension with 3 channels + */ +#if NOT_EXCLUDED(OP_adjust_saturation) +DECLARE_CONFIGURABLE_OP(adjust_saturation, 1, 1, true, 0, 0); +#endif +/** + * This operation adjusts image contrast by given factor ( z = (x - mean) * + * factor + mean ) Input arrays: 0 - input array with rank >= 3, must have last + * one dimension equal 3, that is dimension containing channels. 1 - optional + * argument, input scalar-array containing saturation contrast factor + * + * T arguments: + * 0 - optional argument, contrast factor + * + */ +#if NOT_EXCLUDED(OP_adjust_contrast) +DECLARE_CONFIGURABLE_OP(adjust_contrast, 1, 1, true, 0, 0); +DECLARE_CONFIGURABLE_OP(adjust_contrast_v2, 1, 1, true, 0, 0); +#endif - /** - * This operation adjusts image hue by delta - * Input arrays: - * 0 - input array with rank >= 3, must have at least one dimension equal 3, that is dimension containing channels. - * 1 - optional argument, input scalar-array containing delta - * - * T arguments: - * 0 - optional argument, delta value - * - * Int arguments: - * 0 - optional argument, corresponds to dimension with 3 channels - */ - #if NOT_EXCLUDED(OP_adjust_hue) - DECLARE_CONFIGURABLE_OP(adjust_hue, 1, 1, true, 0, 0); - #endif +/** + * This operation rearranges data from depth into blocks of spatial data. This + * is the reverse transformation of space_to_depth op. This op output is a copy + * of the input tensor where values from the depth dimension are moved in + * spatial blocks to the height and width dimensions. Int attr 0 indicates the + * input block size and how the data is moved. Input: 0 - 4D tensor on given + * type Output: 0 - 4D tensor of given type and proper shape + * + * Int arguments: + * 0 - block size + * 1 - output data format: 0 ("NHWC"): shape{ batch, height, width, channels + * } 1 ("NCHW"): shape{ batch, channels, height, width } 2 ("NCHW_VECT_C"): int8 + * shape{ batch, channels / 4, height, width, 4 } optional (default 0) + */ +#if NOT_EXCLUDED(OP_depth_to_space) +DECLARE_CUSTOM_OP(depth_to_space, 1, 1, false, 0, -1); +#endif - /** - * This operation adjusts image saturation by delta - * Input arrays: - * 0 - input array with rank >= 3, must have at least one dimension equal 3, that is dimension containing channels. - * 1 - optional argument, input scalar-array containing saturation factor - * - * T arguments: - * 0 - optional argument, saturation factor - * - * Int arguments: - * 0 - optional argument, corresponds to dimension with 3 channels - */ - #if NOT_EXCLUDED(OP_adjust_saturation) - DECLARE_CONFIGURABLE_OP(adjust_saturation, 1, 1, true, 0, 0); - #endif +/** + * This operation rearranges blocks of spatial data, into depth.This op output + * is a copy of the input tensor where values from the height and width + * dimensions are moved to the depth dimension. Int attr 0 indicates the input + * block size. + * + * Input: + * - 4D tensor of given type + * Output: + * - 4D tensor + * + * Int arguments: + * 0 - block size + * 1 - output data format: 0 ("NHWC"): shape{ batch, height, width, channels + * } 1 ("NCHW"): shape{ batch, channels, height, width } 2 ("NCHW_VECT_C"): int8 + * shape{ batch, channels / 4, height, width, 4 } optional (default 0) + * + */ +#if NOT_EXCLUDED(OP_space_to_depth) +DECLARE_CUSTOM_OP(space_to_depth, 1, 1, false, 0, -1); +#endif - /** - * This operation adjusts image contrast by given factor ( z = (x - mean) * factor + mean ) - * Input arrays: - * 0 - input array with rank >= 3, must have last one dimension equal 3, that is dimension containing channels. - * 1 - optional argument, input scalar-array containing saturation contrast factor - * - * T arguments: - * 0 - optional argument, contrast factor - * - */ - #if NOT_EXCLUDED(OP_adjust_contrast) - DECLARE_CONFIGURABLE_OP(adjust_contrast, 1, 1, true, 0, 0); - DECLARE_CONFIGURABLE_OP(adjust_contrast_v2, 1, 1, true, 0, 0); - #endif +/** + * This op calculates cross-product between input arguments + * Input arguments + * 0 - vector or tensor A + * 1 - vector or tensor B + */ +#if NOT_EXCLUDED(OP_cross) +DECLARE_OP(cross, 2, 1, false); +#endif +/** + * Zero-pads and then rearranges (permutes) blocks of spatial data into batch. + * More specifically, this op outputs a copy of the input tensor where values + * from the height and width dimensions are moved to the batch dimension. After + * the zero-padding, both height and width of the input must be divisible by the + * block size. + * + * Inputs: + * 0 - input tensor + * 1 - 2D paddings tensor (shape {M, 2}) + * + * Output: + * - result tensor + * + * Int args: + * 0 - block size (M) + * + */ +#if NOT_EXCLUDED(OP_space_to_batch) +DECLARE_CUSTOM_OP(space_to_batch, 2, 1, false, 0, 1); +#endif +/* + * This operation divides "spatial" dimensions [1, ..., M] of the input into a + * grid of blocks of shape block_shape, and interleaves these blocks with the + * "batch" dimension (0) such that in the output, the spatial dimensions [1, + * ..., M] correspond to the position within the grid, and the batch dimension + * combines both the position within a spatial block and the original batch + * position. Prior to division into blocks, the spatial dimensions of the input + * are optionally zero padded according to paddings. + * + * Inputs: + * 0 - input (N-D tensor) + * 1 - block_shape - int 1D tensor with M length + * 2 - paddings - int 2D tensor with shape {M, 2} + * + * Output: + * - N-D tensor with the same type as input 0. + * + * */ +#if NOT_EXCLUDED(OP_space_to_batch_nd) +DECLARE_CUSTOM_OP(space_to_batch_nd, 3, 1, false, 0, 0); +#endif +/** + * + * + */ +#if NOT_EXCLUDED(OP_batch_to_space) +DECLARE_CUSTOM_OP(batch_to_space, 2, 1, false, 0, 1); +#endif +#if NOT_EXCLUDED(OP_batch_to_space_nd) +DECLARE_CUSTOM_OP(batch_to_space_nd, 3, 1, false, 0, 0); +#endif - /** - * This operation rearranges data from depth into blocks of spatial data. This is the reverse transformation - * of space_to_depth op. This op output is a copy of the input tensor where values from the depth dimension - * are moved in spatial blocks to the height and width dimensions. Int attr 0 indicates the input - * block size and how the data is moved. - * Input: - * 0 - 4D tensor on given type - * Output: - * 0 - 4D tensor of given type and proper shape - * - * Int arguments: - * 0 - block size - * 1 - output data format: 0 ("NHWC"): shape{ batch, height, width, channels } - * 1 ("NCHW"): shape{ batch, channels, height, width } - * 2 ("NCHW_VECT_C"): int8 shape{ batch, channels / 4, height, width, 4 } - * optional (default 0) - */ - #if NOT_EXCLUDED(OP_depth_to_space) - DECLARE_CUSTOM_OP(depth_to_space, 1, 1, false, 0, -1); - #endif +/** + * top_k operation returns a vector of k top values for + * given NDArray as tensor with default boolean (true) + * as sort for result index array + * will be sorted by the values in descending order. + * The first parameter is a NDArray for working. + * The second is k (default 1) - optional + * The third is boolean value(default is true) (0 - as is, 1 - sorted by value) + * optional + */ +#if NOT_EXCLUDED(OP_top_k) +DECLARE_CUSTOM_OP(top_k, 1, 2, false, 0, -1); +#endif - /** - * This operation rearranges blocks of spatial data, into depth.This op output is a copy of the input tensor - * where values from the height and width dimensions are moved to the depth dimension. Int attr 0 indicates - * the input block size. - * - * Input: - * - 4D tensor of given type - * Output: - * - 4D tensor - * - * Int arguments: - * 0 - block size - * 1 - output data format: 0 ("NHWC"): shape{ batch, height, width, channels } - * 1 ("NCHW"): shape{ batch, channels, height, width } - * 2 ("NCHW_VECT_C"): int8 shape{ batch, channels / 4, height, width, 4 } - * optional (default 0) - * - */ - #if NOT_EXCLUDED(OP_space_to_depth) - DECLARE_CUSTOM_OP(space_to_depth, 1, 1, false, 0, -1); - #endif +/** + * in_top_k operation returns a vector of k boolean values for + * given NDArray as 2D matrix of predicted in the NDArray k top values + * The first parameter is a NDArray of predicted values (2d array). + * The second is NDArray as vector of indeces k top values will be search. + * The third is k + */ +#if NOT_EXCLUDED(OP_in_top_k) +DECLARE_CUSTOM_OP(in_top_k, 2, 1, true, 1, 1); +#endif - /** - * This op calculates cross-product between input arguments - * Input arguments - * 0 - vector or tensor A - * 1 - vector or tensor B - */ - #if NOT_EXCLUDED(OP_cross) - DECLARE_OP(cross, 2, 1, false); - #endif +/** + * moments operation calculate a mean and variation for given NDArray + * with reduce a result according to axis array given. + * For full axis the result is both mean and variance of all members in array. + * Otherwise there are two NDArrays with means and variances for + * Axes can be put as the second NDArray or as int vector. + * + * the optional flag "keep_dims" can be set as T param + */ +#if NOT_EXCLUDED(OP_moments) +DECLARE_CUSTOM_OP(moments, 1, 2, false, 0, -2); +#endif - /** - * Zero-pads and then rearranges (permutes) blocks of spatial data into batch. More specifically, this op - * outputs a copy of the input tensor where values from the height and width dimensions are moved to the - * batch dimension. After the zero-padding, both height and width of the input must be divisible by the block - * size. - * - * Inputs: - * 0 - input tensor - * 1 - 2D paddings tensor (shape {M, 2}) - * - * Output: - * - result tensor - * - * Int args: - * 0 - block size (M) - * - */ - #if NOT_EXCLUDED(OP_space_to_batch) - DECLARE_CUSTOM_OP(space_to_batch, 2, 1, false, 0, 1); - #endif +/** + * embedding_lookup - search for submatrices in given matrix and retunts them + * accordingly to index array given. + */ +#if NOT_EXCLUDED(OP_embedding_lookup) +DECLARE_CUSTOM_OP(embedding_lookup, 2, 1, false, 0, 1); +#endif - /* - * This operation divides "spatial" dimensions [1, ..., M] of the input into a grid of blocks of shape - * block_shape, and interleaves these blocks with the "batch" dimension (0) such that in the output, - * the spatial dimensions [1, ..., M] correspond to the position within the grid, and the batch dimension - * combines both the position within a spatial block and the original batch position. Prior to division into - * blocks, the spatial dimensions of the input are optionally zero padded according to paddings. - * - * Inputs: - * 0 - input (N-D tensor) - * 1 - block_shape - int 1D tensor with M length - * 2 - paddings - int 2D tensor with shape {M, 2} - * - * Output: - * - N-D tensor with the same type as input 0. - * - * */ - #if NOT_EXCLUDED(OP_space_to_batch_nd) - DECLARE_CUSTOM_OP(space_to_batch_nd, 3, 1, false, 0, 0); - #endif +/** + * dynamic_partition - partition a input tensor onto num_partitions + * accordingly to index array given. + * + * the first param - NDArray to be partitioned. + * the second param - index array + * the third param (integer param) - num or partitions. + * + * returns a num of NDArrays as output + */ +#if NOT_EXCLUDED(OP_dynamic_partition) +DECLARE_CUSTOM_OP(dynamic_partition, 2, 1, false, 0, 1); +#endif - /** - * - * - */ - #if NOT_EXCLUDED(OP_batch_to_space) - DECLARE_CUSTOM_OP(batch_to_space, 2, 1, false, 0, 1); - #endif - #if NOT_EXCLUDED(OP_batch_to_space_nd) - DECLARE_CUSTOM_OP(batch_to_space_nd, 3, 1, false, 0, 0); - #endif +#if NOT_EXCLUDED(OP_dynamic_partition_bp) +DECLARE_CUSTOM_OP(dynamic_partition_bp, 3, 2, false, 0, 1); +#endif - /** - * top_k operation returns a vector of k top values for - * given NDArray as tensor with default boolean (true) - * as sort for result index array - * will be sorted by the values in descending order. - * The first parameter is a NDArray for working. - * The second is k (default 1) - optional - * The third is boolean value(default is true) (0 - as is, 1 - sorted by value) optional - */ - #if NOT_EXCLUDED(OP_top_k) - DECLARE_CUSTOM_OP(top_k, 1, 2, false, 0, -1); - #endif +/** + * dynamic_stitch - merge partitions from the second param a input tensor + * into a single tensor accordingly to index array given. + * + * the first param - index array + * the second params - tensors to be merged + * + * returns a num of NDArrays as output + * + * the operation is inversion od dynamic_partition + */ +#if NOT_EXCLUDED(OP_dynamic_stitch) +DECLARE_CUSTOM_OP(dynamic_stitch, 2, 1, false, 0, 0); +#endif - /** - * in_top_k operation returns a vector of k boolean values for - * given NDArray as 2D matrix of predicted in the NDArray k top values - * The first parameter is a NDArray of predicted values (2d array). - * The second is NDArray as vector of indeces k top values will be search. - * The third is k - */ - #if NOT_EXCLUDED(OP_in_top_k) - DECLARE_CUSTOM_OP(in_top_k, 2, 1, true, 1, 1); - #endif +/** + * zero_fraction op. + * compute a fraction of zeros in given array + * + * input param - an array (tensor) + * output value - a real number with given type (e.g. float or double) + */ +#if NOT_EXCLUDED(OP_zero_fraction) +DECLARE_CUSTOM_OP(zero_fraction, 1, 1, false, 0, 0); +#endif - /** - * moments operation calculate a mean and variation for given NDArray - * with reduce a result according to axis array given. - * For full axis the result is both mean and variance of all members in array. - * Otherwise there are two NDArrays with means and variances for - * Axes can be put as the second NDArray or as int vector. - * - * the optional flag "keep_dims" can be set as T param - */ - #if NOT_EXCLUDED(OP_moments) - DECLARE_CUSTOM_OP(moments, 1, 2, false, 0, -2); - #endif +/** + * xw_plus_b op. + * multiply two first matrices and add third vector to each row of result + * + * input params: + * - 2D matrix NxM + * - 2D matrix MxN + * - 1D vector with N elements + * output value - 2D matrix NxN as multiply of matrixes and add vector + * Int args: + * 0 - optional switcher of weights format, if int arg == 1 - mkldnn, else + * mmul + */ +#if NOT_EXCLUDED(OP_xw_plus_b) +DECLARE_CUSTOM_OP(xw_plus_b, 3, 1, false, 0, 0); +DECLARE_CUSTOM_OP(xw_plus_b_bp, 4, 3, false, 0, 0); +#endif - /** - * embedding_lookup - search for submatrices in given matrix and retunts them - * accordingly to index array given. - */ - #if NOT_EXCLUDED(OP_embedding_lookup) - DECLARE_CUSTOM_OP(embedding_lookup, 2, 1, false, 0, 1); - #endif +/** + * This operation is missed due it simplicy. + * Input and output params are the same after operation. + * Input - NDArray, output - NDArray with the same shape. + */ +#if NOT_EXCLUDED(OP_stop_gradient) +DECLARE_OP(stop_gradient, 1, 1, true); +#endif - /** - * dynamic_partition - partition a input tensor onto num_partitions - * accordingly to index array given. - * - * the first param - NDArray to be partitioned. - * the second param - index array - * the third param (integer param) - num or partitions. - * - * returns a num of NDArrays as output - */ - #if NOT_EXCLUDED(OP_dynamic_partition) - DECLARE_CUSTOM_OP(dynamic_partition, 2, 1, false, 0, 1); - #endif +#if NOT_EXCLUDED(OP_parallel_stack) +DECLARE_CUSTOM_OP(parallel_stack, -1, 1, false, 0, 0); +#endif - #if NOT_EXCLUDED(OP_dynamic_partition_bp) - DECLARE_CUSTOM_OP(dynamic_partition_bp, 3, 2, false, 0, 1); - #endif +/** + * normalize_moments operation normalize already calculated mean and variation + * accordingly to shift and count. + * input params: + * - count of data + * - tensor with mean + * - tensor with variance (the same shape as before) + * + * - optional floating point param shift. + * + * returns a normalized pair mean and variance with the same shapes as input + */ +#if NOT_EXCLUDED(OP_normalize_moments) +DECLARE_CUSTOM_OP(normalize_moments, 3, 2, false, 1, 0); +#endif - /** - * dynamic_stitch - merge partitions from the second param a input tensor - * into a single tensor accordingly to index array given. - * - * the first param - index array - * the second params - tensors to be merged - * - * returns a num of NDArrays as output - * - * the operation is inversion od dynamic_partition - */ - #if NOT_EXCLUDED(OP_dynamic_stitch) - DECLARE_CUSTOM_OP(dynamic_stitch, 2, 1, false, 0, 0); - #endif +/** + * sufficient_statistics operation return calculated mean and variation with + * data count. this operation is invert for moments accordingly to shift and + * count. input params: + * - input tensor + * - axes vector + * + * + * - optional floating point param shift. + * - optional int (as bool) keep_dimension + * + * returns four tensors: + * - scalar tensor (data count) + * - sum elements of input (accross axises) + * - sum of squares of input (accross axises) + * - shift (if was given by input floating param) + */ +#if NOT_EXCLUDED(OP_sufficient_statistics) +DECLARE_CUSTOM_OP(sufficient_statistics, 2, 3, false, 0, 0); +#endif - /** - * zero_fraction op. - * compute a fraction of zeros in given array - * - * input param - an array (tensor) - * output value - a real number with given type (e.g. float or double) - */ - #if NOT_EXCLUDED(OP_zero_fraction) - DECLARE_CUSTOM_OP(zero_fraction, 1, 1, false, 0, 0); - #endif +/** + * This op calculates weighted logarithmic loss of input + * Input arguments + * 0 - target + * 1 - input + * 2 - weights (scalar or vector with same as last dimension) + * + * return value - a tensor with the same shape as target or input + */ +#if NOT_EXCLUDED(OP_weighted_cross_entropy_with_logits) +DECLARE_OP(weighted_cross_entropy_with_logits, 3, 1, true); +#endif - /** - * xw_plus_b op. - * multiply two first matrices and add third vector to each row of result - * - * input params: - * - 2D matrix NxM - * - 2D matrix MxN - * - 1D vector with N elements - * output value - 2D matrix NxN as multiply of matrixes and add vector - * Int args: - * 0 - optional switcher of weights format, if int arg == 1 - mkldnn, else mmul - */ - #if NOT_EXCLUDED(OP_xw_plus_b) - DECLARE_CUSTOM_OP(xw_plus_b, 3, 1, false, 0, 0); - DECLARE_CUSTOM_OP(xw_plus_b_bp, 4, 3, false, 0, 0); - #endif +/** + * This op calculates dropout of input + * Input arguments + * 0 - input tensor + * 1 - noise_shape - (vector with shape to reduce) - optional + * + * int parameter - seed for random numbers + * T parameter - probability (should be between 0 and 1) + * return value - a tensor with the same shape as target or input + */ +#if NOT_EXCLUDED(OP_dropout) +DECLARE_CONFIGURABLE_OP(dropout, 1, 1, true, 1, 1); +#endif +#if NOT_EXCLUDED(OP_dropout_bp) +DECLARE_CONFIGURABLE_OP(dropout_bp, 2, 1, false, 1, 1); +#endif - /** - * This operation is missed due it simplicy. - * Input and output params are the same after operation. - * Input - NDArray, output - NDArray with the same shape. - */ - #if NOT_EXCLUDED(OP_stop_gradient) - DECLARE_OP(stop_gradient, 1, 1, true); - #endif +/* Calculates alpha weighted dropout + T params: + 0 - drop probability + 1 - alpha value + 2 - alpha' value + 3 - beta value + */ +#if NOT_EXCLUDED(OP_alpha_dropout_bp) +DECLARE_CONFIGURABLE_OP(alpha_dropout_bp, 2, 1, false, 4, 1); +#endif - #if NOT_EXCLUDED(OP_parallel_stack) - DECLARE_CUSTOM_OP(parallel_stack, -1, 1, false, 0, 0); - #endif +/** + * bincount operation return a vector with element counted. + * + * input params: + * - input tensor - only int part are accepted + * - weights - the same shape tensor with integer weights for element + * (optional) default weight - 1,1,1..,1 for all values in the tensor + * + * optional ints: + * - min_length - zero or greater + * - max_length - between min_length and max(input) + 1 + * + * returns four tensors: + * - vector tensor with length to min(max_len, max(input) + 1) with count + * of values in indexed place + * + */ +#if NOT_EXCLUDED(OP_bincount) +DECLARE_CUSTOM_OP(bincount, 1, 1, false, 0, 0); +#endif - /** - * normalize_moments operation normalize already calculated mean and variation - * accordingly to shift and count. - * input params: - * - count of data - * - tensor with mean - * - tensor with variance (the same shape as before) - * - * - optional floating point param shift. - * - * returns a normalized pair mean and variance with the same shapes as input - */ - #if NOT_EXCLUDED(OP_normalize_moments) - DECLARE_CUSTOM_OP(normalize_moments, 3, 2, false, 1, 0); - #endif - - /** - * sufficient_statistics operation return calculated mean and variation with data count. - * this operation is invert for moments - * accordingly to shift and count. - * input params: - * - input tensor - * - axes vector - * - * - * - optional floating point param shift. - * - optional int (as bool) keep_dimension - * - * returns four tensors: - * - scalar tensor (data count) - * - sum elements of input (accross axises) - * - sum of squares of input (accross axises) - * - shift (if was given by input floating param) - */ - #if NOT_EXCLUDED(OP_sufficient_statistics) - DECLARE_CUSTOM_OP(sufficient_statistics, 2, 3, false, 0, 0); - #endif - - /** - * This op calculates weighted logarithmic loss of input - * Input arguments - * 0 - target - * 1 - input - * 2 - weights (scalar or vector with same as last dimension) - * - * return value - a tensor with the same shape as target or input - */ - #if NOT_EXCLUDED(OP_weighted_cross_entropy_with_logits) - DECLARE_OP(weighted_cross_entropy_with_logits, 3, 1, true); - #endif - - /** - * This op calculates dropout of input - * Input arguments - * 0 - input tensor - * 1 - noise_shape - (vector with shape to reduce) - optional - * - * int parameter - seed for random numbers - * T parameter - probability (should be between 0 and 1) - * return value - a tensor with the same shape as target or input - */ - #if NOT_EXCLUDED(OP_dropout) - DECLARE_CONFIGURABLE_OP(dropout, 1, 1, true, 1, 1); - #endif - #if NOT_EXCLUDED(OP_dropout_bp) - DECLARE_CONFIGURABLE_OP(dropout_bp, 2, 1, false, 1, 1); - #endif - - /* Calculates alpha weighted dropout - T params: - 0 - drop probability - 1 - alpha value - 2 - alpha' value - 3 - beta value - */ - #if NOT_EXCLUDED(OP_alpha_dropout_bp) - DECLARE_CONFIGURABLE_OP(alpha_dropout_bp, 2, 1, false, 4, 1); - #endif - - - /** - * bincount operation return a vector with element counted. - * - * input params: - * - input tensor - only int part are accepted - * - weights - the same shape tensor with integer weights for element (optional) - * default weight - 1,1,1..,1 for all values in the tensor - * - * optional ints: - * - min_length - zero or greater - * - max_length - between min_length and max(input) + 1 - * - * returns four tensors: - * - vector tensor with length to min(max_len, max(input) + 1) with count - * of values in indexed place - * - */ - #if NOT_EXCLUDED(OP_bincount) - DECLARE_CUSTOM_OP(bincount, 1, 1, false, 0, 0); - #endif - - /** - * broadcast_dynamic_shape op. - * - * input params: - * 0 - the first shape (vector with shape) - * 1 - the second shape (vector with shape) - * - * return value: - * vector with broadcasted shape - */ - #if NOT_EXCLUDED(OP_broadcast_dynamic_shape) - DECLARE_CUSTOM_OP(broadcast_dynamic_shape, 2, 1, false, 0, 0); - #endif - - /** - * matrix_determinant op. - * - * input params: - * 0 - the tensor with dimension (x * y * z * ::: * M * M) - * - * return value: - * tensor with dimension (x * y * z * ::: *) with determinant for all - * M x M matricies - */ - #if NOT_EXCLUDED(OP_matrix_determinant) - DECLARE_CUSTOM_OP(matrix_determinant, 1, 1, false, 0, 0); - #endif +/** + * broadcast_dynamic_shape op. + * + * input params: + * 0 - the first shape (vector with shape) + * 1 - the second shape (vector with shape) + * + * return value: + * vector with broadcasted shape + */ +#if NOT_EXCLUDED(OP_broadcast_dynamic_shape) +DECLARE_CUSTOM_OP(broadcast_dynamic_shape, 2, 1, false, 0, 0); +#endif - /** - * log_matrix_determinant op. - * - * input params: - * 0 - the tensor with dimension (x * y * z * ::: * M * M) - * - * return value: - * tensor with dimension (x * y * z * ::: *) with log determinant for all - * M x M matricies - */ +/** + * matrix_determinant op. + * + * input params: + * 0 - the tensor with dimension (x * y * z * ::: * M * M) + * + * return value: + * tensor with dimension (x * y * z * ::: *) with determinant for all + * M x M matricies + */ +#if NOT_EXCLUDED(OP_matrix_determinant) +DECLARE_CUSTOM_OP(matrix_determinant, 1, 1, false, 0, 0); +#endif - #if NOT_EXCLUDED(OP_log_matrix_determinant) - DECLARE_CUSTOM_OP(log_matrix_determinant, 1, 1, false, 0, 0); - #endif +/** + * log_matrix_determinant op. + * + * input params: + * 0 - the tensor with dimension (x * y * z * ::: * M * M) + * + * return value: + * tensor with dimension (x * y * z * ::: *) with log determinant for all + * M x M matricies + */ - /** - * logdet op. Logarithm of the determinant of hermitian positive matricies. - * - * input params: - * 0 - the tensor with dimension (x * y * z * ::: * M * M) - * - * return value: - * tensor with dimension (x * y * z * ::: *) with log determinant for all - * M x M matricies - */ +#if NOT_EXCLUDED(OP_log_matrix_determinant) +DECLARE_CUSTOM_OP(log_matrix_determinant, 1, 1, false, 0, 0); +#endif - #if NOT_EXCLUDED(OP_logdet) - DECLARE_CUSTOM_OP(logdet, 1, 1, false, 0, 0); - #endif +/** + * logdet op. Logarithm of the determinant of hermitian positive matricies. + * + * input params: + * 0 - the tensor with dimension (x * y * z * ::: * M * M) + * + * return value: + * tensor with dimension (x * y * z * ::: *) with log determinant for all + * M x M matricies + */ - /** - * matrix_solve_ls op (lstsq) - solves one or more linear least-squares problems. - * - * input params: - * 0 - the tensor with dimension (x * y * z * ::: * M * N) - left parts of equations - * 1 - the tensor with dimension (x * y * z * ::: * M * K) - right parts of equations - * - * float args: - * 0 - l2_regularizer (default 0. and only for 0 implemented) - * - * boolean args: - * 0 - fast - default is true (optional) - use Cholesky decomposition instead QR decomposition of matricies. - * - * return value: - * tensor with dimension (x * y * z * ::: * N * K) with solutions - * - */ - #if NOT_EXCLUDED(OP_lstsq) - DECLARE_CUSTOM_OP(lstsq, 2, 1, false, 0, 0); - #endif +#if NOT_EXCLUDED(OP_logdet) +DECLARE_CUSTOM_OP(logdet, 1, 1, false, 0, 0); +#endif - /* solve_ls - analog of lstsq op with another solution approach - * - * input params: - * 0 - the tensor with dimension (x * y * z * ::: * M * N) - left parts of equations - * 1 - the tensor with dimension (x * y * z * ::: * M * K) - right parts of equations - * - * float args: - * 0 - l2_regularizer (default 0. and only for 0 implemented) - * - * boolean args: - * 0 - fast - default is true (optional) - use Cholesky decomposition instead QR decomposition of matricies. - * - * return value: - * tensor with dimension (x * y * z * ::: * N * K) with solutions - * - * Note: if fast is false - then l2_regularizer arg is ignored and used lstsq method due QR decomposition - * */ - #if NOT_EXCLUDED(OP_solve_ls) - DECLARE_CUSTOM_OP(solve_ls, 2, 1, false, 0, 0); - #endif +/** + * matrix_solve_ls op (lstsq) - solves one or more linear least-squares + * problems. + * + * input params: + * 0 - the tensor with dimension (x * y * z * ::: * M * N) - left parts of + * equations 1 - the tensor with dimension (x * y * z * ::: * M * K) - right + * parts of equations + * + * float args: + * 0 - l2_regularizer (default 0. and only for 0 implemented) + * + * boolean args: + * 0 - fast - default is true (optional) - use Cholesky decomposition instead + * QR decomposition of matricies. + * + * return value: + * tensor with dimension (x * y * z * ::: * N * K) with solutions + * + */ +#if NOT_EXCLUDED(OP_lstsq) +DECLARE_CUSTOM_OP(lstsq, 2, 1, false, 0, 0); +#endif - /** - * matrix_inverse op. - make inverse for all 2D square matricies found in the input tensor - * - * input params: - * 0 - the tensor with dimension (x * y * z * ::: * M * M) - * - * return value: - * tensor with dimension (x * y * z * ::: * M * M) with inverse M x M matricies in it - */ - #if NOT_EXCLUDED(OP_matrix_inverse) - DECLARE_OP(matrix_inverse, 1, 1, true); - #endif +/* solve_ls - analog of lstsq op with another solution approach + * + * input params: + * 0 - the tensor with dimension (x * y * z * ::: * M * N) - left parts of + * equations 1 - the tensor with dimension (x * y * z * ::: * M * K) - right + * parts of equations + * + * float args: + * 0 - l2_regularizer (default 0. and only for 0 implemented) + * + * boolean args: + * 0 - fast - default is true (optional) - use Cholesky decomposition instead + * QR decomposition of matricies. + * + * return value: + * tensor with dimension (x * y * z * ::: * N * K) with solutions + * + * Note: if fast is false - then l2_regularizer arg is ignored and used lstsq + * method due QR decomposition + * */ +#if NOT_EXCLUDED(OP_solve_ls) +DECLARE_CUSTOM_OP(solve_ls, 2, 1, false, 0, 0); +#endif - /** - * triangular_solve op. - reverse Gaussian method for solve systems of linear equations. - * - * input params: - * 0 - the tensor with dimension (x * y * z * ::: * M * M) - left parts of equations - * 1 - the tensor with dimension (x * y * z * ::: * M * K) - right parts of equations - * - * boolean args: - * 0 - lower - default is true (optional) - left part is lower triangular matrix - * 1 - adjoint - default is false (optional) - indicate input matrix or its adjoint (hermitian addition) should be used - * - * return value: - * tensor with dimension (x * y * z * ::: * M * K) with solutions - * - */ - #if NOT_EXCLUDED(OP_triangular_solve) - DECLARE_CUSTOM_OP(triangular_solve, 2, 1, false, 0, 0); - #endif +/** + * matrix_inverse op. - make inverse for all 2D square matricies found in the + * input tensor + * + * input params: + * 0 - the tensor with dimension (x * y * z * ::: * M * M) + * + * return value: + * tensor with dimension (x * y * z * ::: * M * M) with inverse M x M + * matricies in it + */ +#if NOT_EXCLUDED(OP_matrix_inverse) +DECLARE_OP(matrix_inverse, 1, 1, true); +#endif - /** - * solve op. - solve systems of linear equations - general method. - * - * input params: - * 0 - the tensor with dimension (x * y * z * ::: * M * M) - left parts of equations - * 1 - the tensor with dimension (x * y * z * ::: * M * K) - right parts of equations - * - * boolean args: - * 0 - adjoint - default is false (optional) - indicate input matrix or its adjoint (hermitian addition) should be used - * - * return value: - * tensor with dimension (x * y * z * ::: * M * K) with solutions - * - */ - #if NOT_EXCLUDED(OP_solve) - DECLARE_CUSTOM_OP(solve, 2, 1, true, 0, 0); - #endif +/** + * triangular_solve op. - reverse Gaussian method for solve systems of linear + * equations. + * + * input params: + * 0 - the tensor with dimension (x * y * z * ::: * M * M) - left parts of + * equations 1 - the tensor with dimension (x * y * z * ::: * M * K) - right + * parts of equations + * + * boolean args: + * 0 - lower - default is true (optional) - left part is lower triangular + * matrix 1 - adjoint - default is false (optional) - indicate input matrix or + * its adjoint (hermitian addition) should be used + * + * return value: + * tensor with dimension (x * y * z * ::: * M * K) with solutions + * + */ +#if NOT_EXCLUDED(OP_triangular_solve) +DECLARE_CUSTOM_OP(triangular_solve, 2, 1, false, 0, 0); +#endif - /** - * lu op. - make LUP decomposition of given batch of 2D square matricies - * - * input params: - * 0 - float tensor with dimension (x * y * z * ::: * M * M) - * - * return value: - * 0 - float tensor with dimension (x * y * z * ::: * M * M) with LU M x M matricies in it - * 1 - int (32 or 64) batched vector of permutations with length M - shape (x * y * z * ::: * M) - * - * int argument: - * 0 - data type of output permutaion vector (int32 or int64), optional, default INT32 - */ +/** + * solve op. - solve systems of linear equations - general method. + * + * input params: + * 0 - the tensor with dimension (x * y * z * ::: * M * M) - left parts of + * equations 1 - the tensor with dimension (x * y * z * ::: * M * K) - right + * parts of equations + * + * boolean args: + * 0 - adjoint - default is false (optional) - indicate input matrix or its + * adjoint (hermitian addition) should be used + * + * return value: + * tensor with dimension (x * y * z * ::: * M * K) with solutions + * + */ +#if NOT_EXCLUDED(OP_solve) +DECLARE_CUSTOM_OP(solve, 2, 1, true, 0, 0); +#endif - #if NOT_EXCLUDED(OP_matrix_inverse) - DECLARE_CUSTOM_OP(lu, 1, 2, false, 0, 0); - #endif +/** + * lu op. - make LUP decomposition of given batch of 2D square matricies + * + * input params: + * 0 - float tensor with dimension (x * y * z * ::: * M * M) + * + * return value: + * 0 - float tensor with dimension (x * y * z * ::: * M * M) with LU M x M + * matricies in it 1 - int (32 or 64) batched vector of permutations with length + * M - shape (x * y * z * ::: * M) + * + * int argument: + * 0 - data type of output permutaion vector (int32 or int64), optional, + * default INT32 + */ - /** - * sequence_mask op. - make mask for given tensor filled by (j > x[i_1, i_2,...,i_n]) -> z[i_1, i_2,...,i_n,j] - * - * input params: - * 0 - the ND-tensor filled by integer-like values - * - * optional int param - maxlength (maxlength >= max(x)). By default maxlength = max(x). - * return value: - * (N+1)D tensor filled by 0 and 1 accordingly the mask - */ - #if NOT_EXCLUDED(OP_sequence_mask) - DECLARE_CUSTOM_OP(sequence_mask, 1, 1, false, 0, 0); - #endif - /** - * segment_max op. - make a tensor filled by max values according to index tensor given. - * - * input params: - * 0 - the tensor with data; - * 1 - the tensor with indices. - * - * return value: - * tensor with max values according to indices sets. - */ +#if NOT_EXCLUDED(OP_matrix_inverse) +DECLARE_CUSTOM_OP(lu, 1, 2, false, 0, 0); +#endif - #if NOT_EXCLUDED(OP_segment_max) - DECLARE_CUSTOM_OP(segment_max, 2, 1, false, 0, 0); - #endif - #if NOT_EXCLUDED(OP_segment_max_bp) - DECLARE_CUSTOM_OP(segment_max_bp, 3, 2, false, 0, 0); - #endif +/** + * sequence_mask op. - make mask for given tensor filled by (j > x[i_1, + * i_2,...,i_n]) -> z[i_1, i_2,...,i_n,j] + * + * input params: + * 0 - the ND-tensor filled by integer-like values + * + * optional int param - maxlength (maxlength >= max(x)). By default maxlength = + * max(x). return value: (N+1)D tensor filled by 0 and 1 accordingly the mask + */ +#if NOT_EXCLUDED(OP_sequence_mask) +DECLARE_CUSTOM_OP(sequence_mask, 1, 1, false, 0, 0); +#endif +/** + * segment_max op. - make a tensor filled by max values according to index + * tensor given. + * + * input params: + * 0 - the tensor with data; + * 1 - the tensor with indices. + * + * return value: + * tensor with max values according to indices sets. + */ - /** - * segment_min op. - make a tensor filled by min values according to index tensor given. - * - * input params: - * 0 - the tensor with data; - * 1 - the tensor with indices. - * - * return value: - * tensor with min values according to indices sets. - */ - #if NOT_EXCLUDED(OP_segment_min) - DECLARE_CUSTOM_OP(segment_min, 2, 1, false, 0, 0); - #endif - #if NOT_EXCLUDED(OP_segment_min_bp) - DECLARE_CUSTOM_OP(segment_min_bp, 3, 2, false, 0, 0); - #endif +#if NOT_EXCLUDED(OP_segment_max) +DECLARE_CUSTOM_OP(segment_max, 2, 1, false, 0, 0); +#endif +#if NOT_EXCLUDED(OP_segment_max_bp) +DECLARE_CUSTOM_OP(segment_max_bp, 3, 2, false, 0, 0); +#endif - /** - * segment_sum op. - make a tensor filled by sum of values according to index tensor given. - * - * input params: - * 0 - the tensor with data; - * 1 - the tensor with indices. - * - * return value: - * tensor with sum of values according to indices sets. - */ - #if NOT_EXCLUDED(OP_segment_sum) - DECLARE_CUSTOM_OP(segment_sum, 2, 1, false, 0, 0); - #endif - #if NOT_EXCLUDED(OP_segment_sum_bp) - DECLARE_CUSTOM_OP(segment_sum_bp, 3, 2, false, 0, 0); - #endif +/** + * segment_min op. - make a tensor filled by min values according to index + * tensor given. + * + * input params: + * 0 - the tensor with data; + * 1 - the tensor with indices. + * + * return value: + * tensor with min values according to indices sets. + */ +#if NOT_EXCLUDED(OP_segment_min) +DECLARE_CUSTOM_OP(segment_min, 2, 1, false, 0, 0); +#endif +#if NOT_EXCLUDED(OP_segment_min_bp) +DECLARE_CUSTOM_OP(segment_min_bp, 3, 2, false, 0, 0); +#endif - /** - * segment_prod op. - make a tensor filled by product of values according to index tensor given. - * - * input params: - * 0 - the tensor with data; - * 1 - the tensor with indices. - * - * return value: - * tensor with product of values according to indices sets. - */ - #if NOT_EXCLUDED(OP_segment_prod) - DECLARE_CUSTOM_OP(segment_prod, 2, 1, false, 0, 0); - #endif - #if NOT_EXCLUDED(OP_segment_prod_bp) - DECLARE_CUSTOM_OP(segment_prod_bp, 3, 2, false, 0, 0); - #endif - /** - * segment_mean op. - make a tensor filled by average of values according to index tensor given. - * - * input params: - * 0 - the tensor with data; - * 1 - the tensor with indices. - * - * return value: - * tensor with average of values according to indices sets. - */ - #if NOT_EXCLUDED(OP_segment_mean) - DECLARE_CUSTOM_OP(segment_mean, 2, 1, false, 0, 0); - #endif - #if NOT_EXCLUDED(OP_segment_mean_bp) - DECLARE_CUSTOM_OP(segment_mean_bp, 3, 2, false, 0, 0); - #endif +/** + * segment_sum op. - make a tensor filled by sum of values according to index + * tensor given. + * + * input params: + * 0 - the tensor with data; + * 1 - the tensor with indices. + * + * return value: + * tensor with sum of values according to indices sets. + */ +#if NOT_EXCLUDED(OP_segment_sum) +DECLARE_CUSTOM_OP(segment_sum, 2, 1, false, 0, 0); +#endif +#if NOT_EXCLUDED(OP_segment_sum_bp) +DECLARE_CUSTOM_OP(segment_sum_bp, 3, 2, false, 0, 0); +#endif - /** - * unsorted_segment_max op. - make a tensor filled by max values according to index tensor given. - * - * input params: - * 0 - the tensor with data; - * 1 - the tensor with indices. - * - * return value: - * tensor with max values according to indices sets. - */ - #if NOT_EXCLUDED(OP_unsorted_segment_max) - DECLARE_CUSTOM_OP(unsorted_segment_max, 2, 1, false, 0, 0); - #endif - #if NOT_EXCLUDED(OP_unsorted_segment_max_bp) - DECLARE_CUSTOM_OP(unsorted_segment_max_bp, 3, 2, false, 0, 1); - #endif +/** + * segment_prod op. - make a tensor filled by product of values according to + * index tensor given. + * + * input params: + * 0 - the tensor with data; + * 1 - the tensor with indices. + * + * return value: + * tensor with product of values according to indices sets. + */ +#if NOT_EXCLUDED(OP_segment_prod) +DECLARE_CUSTOM_OP(segment_prod, 2, 1, false, 0, 0); +#endif +#if NOT_EXCLUDED(OP_segment_prod_bp) +DECLARE_CUSTOM_OP(segment_prod_bp, 3, 2, false, 0, 0); +#endif +/** + * segment_mean op. - make a tensor filled by average of values according to + * index tensor given. + * + * input params: + * 0 - the tensor with data; + * 1 - the tensor with indices. + * + * return value: + * tensor with average of values according to indices sets. + */ +#if NOT_EXCLUDED(OP_segment_mean) +DECLARE_CUSTOM_OP(segment_mean, 2, 1, false, 0, 0); +#endif +#if NOT_EXCLUDED(OP_segment_mean_bp) +DECLARE_CUSTOM_OP(segment_mean_bp, 3, 2, false, 0, 0); +#endif - /** - * unsorted_segment_min op. - make a tensor filled by min values according to index tensor given. - * - * input params: - * 0 - the tensor with data; - * 1 - the tensor with indices. - * - * integer param: - * 0 - num of segments - * - * return value: - * tensor with min values according to indices sets. - */ - #if NOT_EXCLUDED(OP_unsorted_segment_min_bp) - DECLARE_CUSTOM_OP(unsorted_segment_min, 2, 1, false, 0, 0); - #endif - #if NOT_EXCLUDED(OP_unsorted_segment_min_bp) - DECLARE_CUSTOM_OP(unsorted_segment_min_bp, 3, 2, false, 0, 1); - #endif +/** + * unsorted_segment_max op. - make a tensor filled by max values according to + * index tensor given. + * + * input params: + * 0 - the tensor with data; + * 1 - the tensor with indices. + * + * return value: + * tensor with max values according to indices sets. + */ +#if NOT_EXCLUDED(OP_unsorted_segment_max) +DECLARE_CUSTOM_OP(unsorted_segment_max, 2, 1, false, 0, 0); +#endif +#if NOT_EXCLUDED(OP_unsorted_segment_max_bp) +DECLARE_CUSTOM_OP(unsorted_segment_max_bp, 3, 2, false, 0, 1); +#endif - /** - * unsorted_segment_sum op. - make a tensor filled by sum of values according to index tensor given. - * - * input params: - * 0 - the tensor with data; - * 1 - the tensor with indices. - * - * integer param: - * 0 - num of segments - * - * return value: - * tensor with sum of values according to indices sets. - */ - #if NOT_EXCLUDED(OP_unsorted_segment_sum) - DECLARE_CUSTOM_OP(unsorted_segment_sum, 2, 1, false, 0, 0); - #endif - #if NOT_EXCLUDED(OP_unsorted_segment_sum_bp) - DECLARE_CUSTOM_OP(unsorted_segment_sum_bp, 3, 2, false, 0, 1); - #endif +/** + * unsorted_segment_min op. - make a tensor filled by min values according to + * index tensor given. + * + * input params: + * 0 - the tensor with data; + * 1 - the tensor with indices. + * + * integer param: + * 0 - num of segments + * + * return value: + * tensor with min values according to indices sets. + */ +#if NOT_EXCLUDED(OP_unsorted_segment_min_bp) +DECLARE_CUSTOM_OP(unsorted_segment_min, 2, 1, false, 0, 0); +#endif +#if NOT_EXCLUDED(OP_unsorted_segment_min_bp) +DECLARE_CUSTOM_OP(unsorted_segment_min_bp, 3, 2, false, 0, 1); +#endif - /** - * unsorted_segment_prod op. - make a tensor filled by product of values according to index tensor given. - * - * input params: - * 0 - the tensor with data; - * 1 - the tensor with indices. - * - * integer param: - * 0 - num of segments - * - * return value: - * tensor with product of values according to indices sets. - */ - #if NOT_EXCLUDED(OP_unsorted_segment_prod) - DECLARE_CUSTOM_OP(unsorted_segment_prod, 2, 1, false, 0, 0); - #endif - #if NOT_EXCLUDED(OP_unsorted_segment_prod_bp) - DECLARE_CUSTOM_OP(unsorted_segment_prod_bp, 3, 2, false, 0, 1); - #endif +/** + * unsorted_segment_sum op. - make a tensor filled by sum of values according to + * index tensor given. + * + * input params: + * 0 - the tensor with data; + * 1 - the tensor with indices. + * + * integer param: + * 0 - num of segments + * + * return value: + * tensor with sum of values according to indices sets. + */ +#if NOT_EXCLUDED(OP_unsorted_segment_sum) +DECLARE_CUSTOM_OP(unsorted_segment_sum, 2, 1, false, 0, 0); +#endif +#if NOT_EXCLUDED(OP_unsorted_segment_sum_bp) +DECLARE_CUSTOM_OP(unsorted_segment_sum_bp, 3, 2, false, 0, 1); +#endif - /** - * unsorted_segment_mean op. - make a tensor filled by average of values according to index tensor given. - * - * input params: - * 0 - the tensor with data; - * 1 - the tensor with indices. - * - * integer param: - * 0 - num of segments - * - * return value: - * tensor with average of values according to indices sets. - */ - #if NOT_EXCLUDED(OP_unsorted_segment_mean) - DECLARE_CUSTOM_OP(unsorted_segment_mean, 2, 1, false, 0, 0); - #endif - #if NOT_EXCLUDED(OP_unsorted_segment_mean_bp) - DECLARE_CUSTOM_OP(unsorted_segment_mean_bp, 3, 2, false, 0, 1); - #endif +/** + * unsorted_segment_prod op. - make a tensor filled by product of values + * according to index tensor given. + * + * input params: + * 0 - the tensor with data; + * 1 - the tensor with indices. + * + * integer param: + * 0 - num of segments + * + * return value: + * tensor with product of values according to indices sets. + */ +#if NOT_EXCLUDED(OP_unsorted_segment_prod) +DECLARE_CUSTOM_OP(unsorted_segment_prod, 2, 1, false, 0, 0); +#endif +#if NOT_EXCLUDED(OP_unsorted_segment_prod_bp) +DECLARE_CUSTOM_OP(unsorted_segment_prod_bp, 3, 2, false, 0, 1); +#endif - /** - * unsorted_segment_sqrt_n op. - computes the sum along segments of a tensor divided by the sqrt(N). - * - * input params: - * 0 - the tensor with data; - * 1 - the tensor with indices. - * - * integer param: - * 0 - num of segments - * - * return value: - * tensor with average of values according to indices sets. - */ - #if NOT_EXCLUDED(OP_unsorted_segment_sqrt) - DECLARE_CUSTOM_OP(unsorted_segment_sqrt_n, 2, 1, false, 0, 0); - #endif - #if NOT_EXCLUDED(OP_unsorted_segment_sqrt_n_bp) - DECLARE_CUSTOM_OP(unsorted_segment_sqrt_n_bp, 3, 2, false, 0, 1); - #endif +/** + * unsorted_segment_mean op. - make a tensor filled by average of values + * according to index tensor given. + * + * input params: + * 0 - the tensor with data; + * 1 - the tensor with indices. + * + * integer param: + * 0 - num of segments + * + * return value: + * tensor with average of values according to indices sets. + */ +#if NOT_EXCLUDED(OP_unsorted_segment_mean) +DECLARE_CUSTOM_OP(unsorted_segment_mean, 2, 1, false, 0, 0); +#endif +#if NOT_EXCLUDED(OP_unsorted_segment_mean_bp) +DECLARE_CUSTOM_OP(unsorted_segment_mean_bp, 3, 2, false, 0, 1); +#endif - /** - * extract_image_patches op - Extract patches from images and put them in the "depth" output dimension. - * - * input params: - * 0 - images tensor (4D) - * - * int params: - * 0 - ksize_rows - * 1 - ksize_cols - * 2 - strides_rows - * 3 - strides_cols - * 4 - rates_rows - * 5 - rates_cols - * 6 - padding_type - 0 - equiv 'VALID', 1 - 'SAME' - */ - #if NOT_EXCLUDED(OP_extract_image_patches) - DECLARE_CUSTOM_OP(extract_image_patches, 1, 1, false, 0, 7); - #endif +/** + * unsorted_segment_sqrt_n op. - computes the sum along segments of a tensor + * divided by the sqrt(N). + * + * input params: + * 0 - the tensor with data; + * 1 - the tensor with indices. + * + * integer param: + * 0 - num of segments + * + * return value: + * tensor with average of values according to indices sets. + */ +#if NOT_EXCLUDED(OP_unsorted_segment_sqrt) +DECLARE_CUSTOM_OP(unsorted_segment_sqrt_n, 2, 1, false, 0, 0); +#endif +#if NOT_EXCLUDED(OP_unsorted_segment_sqrt_n_bp) +DECLARE_CUSTOM_OP(unsorted_segment_sqrt_n_bp, 3, 2, false, 0, 1); +#endif - /** - * draw_bounding_boxes op - modified input image with given colors exept given boxes. - * - * input params: - * 0 - images tensor (4D) with shape {batch, width, height, channels}, where channes is 1 (BW image), - * 3 (RGB) or 4 (RGBA) - * 1 - boxes tensor (3D) with shape {batch, number_of_boxes, 4} where last dimension encoded as - * (y_min, x_min, y_max, x_max), all values in between 0. and 1. - * 2 - colours tensor (2D) with shape {number_of_boxes, channels} -- bordering color set (palette) - * - * output: - * 0 - 4D tensor with same shape as images (input 0) - */ - #if NOT_EXCLUDED(OP_draw_bounding_boxes) - DECLARE_OP(draw_bounding_boxes, 3, 1, true); - #endif +/** + * extract_image_patches op - Extract patches from images and put them in the + * "depth" output dimension. + * + * input params: + * 0 - images tensor (4D) + * + * int params: + * 0 - ksize_rows + * 1 - ksize_cols + * 2 - strides_rows + * 3 - strides_cols + * 4 - rates_rows + * 5 - rates_cols + * 6 - padding_type - 0 - equiv 'VALID', 1 - 'SAME' + */ +#if NOT_EXCLUDED(OP_extract_image_patches) +DECLARE_CUSTOM_OP(extract_image_patches, 1, 1, false, 0, 7); +#endif - /** - * roll - op porting from numpy (https://docs.scipy.org/doc/numpy-1.14.0/reference/generated/numpy.roll.html) - * - * input params: - * 0 - NDArray - * - * int params: - * 0 - shift - * 1 - axe 1 - * 2 - axe 2 - * ... - * N - axe N - * - * All axes are optional and should be between 0 and input->rankOf(). Of course, all axes can be repeated. - * - * output: - * 0 - NDArray with the same shape as input. - */ - #if NOT_EXCLUDED(OP_roll) - DECLARE_CONFIGURABLE_OP(roll, 1, 1, true, 0, 1); - #endif +/** + * draw_bounding_boxes op - modified input image with given colors exept given + * boxes. + * + * input params: + * 0 - images tensor (4D) with shape {batch, width, height, channels}, where + * channes is 1 (BW image), 3 (RGB) or 4 (RGBA) 1 - boxes tensor (3D) with shape + * {batch, number_of_boxes, 4} where last dimension encoded as (y_min, x_min, + * y_max, x_max), all values in between 0. and 1. 2 - colours tensor (2D) with + * shape {number_of_boxes, channels} -- bordering color set (palette) + * + * output: + * 0 - 4D tensor with same shape as images (input 0) + */ +#if NOT_EXCLUDED(OP_draw_bounding_boxes) +DECLARE_OP(draw_bounding_boxes, 3, 1, true); +#endif - /** - * lin_space - op porting from TF (https://www.tensorflow.org/api_docs/python/tf/lin_space) - * - * optional input params: - * 0 - startVal - NDArray scalar (float point) - * 1 - finishVal - NDArray scalar (float point) - * 2 - numOfElements - NDArray scalar (integer) - * Optional: - * T args - * 0 - startVal - * 1 - finishVal] - * 2 - numOfElements - * output: - * 0 - 1D NDArray with the same type as input and length as given with numOfElements param. - */ - #if NOT_EXCLUDED(OP_lin_space) - DECLARE_CUSTOM_OP(lin_space, 0, 1, false, 0, 0); - #endif +/** + * roll - op porting from numpy + * (https://docs.scipy.org/doc/numpy-1.14.0/reference/generated/numpy.roll.html) + * + * input params: + * 0 - NDArray + * + * int params: + * 0 - shift + * 1 - axe 1 + * 2 - axe 2 + * ... + * N - axe N + * + * All axes are optional and should be between 0 and input->rankOf(). Of + * course, all axes can be repeated. + * + * output: + * 0 - NDArray with the same shape as input. + */ +#if NOT_EXCLUDED(OP_roll) +DECLARE_CONFIGURABLE_OP(roll, 1, 1, true, 0, 1); +#endif - /** - * reduction_sum - tf.reduction_sum operation - * - * input params: - * 0 - NDArray - * - * T_ARG param (optional): - * 0 - keep_dims != 0. - * - * int params (optional): - * 0 - axe 1 - * 1 - axe 2 - * ... - * N-1 axe N - * - * All axes are optional and should be between 0 and input->rankOf() - 1 - * - * output: - * 0 - NDArray with reduces shape accordingly to axes (the scalar in default case). - */ - #if NOT_EXCLUDED(OP_reduce_sum) - DECLARE_CUSTOM_OP(reduce_sum, 1, 1, false, 0, 0); - #endif +/** + * lin_space - op porting from TF + * (https://www.tensorflow.org/api_docs/python/tf/lin_space) + * + * optional input params: + * 0 - startVal - NDArray scalar (float point) + * 1 - finishVal - NDArray scalar (float point) + * 2 - numOfElements - NDArray scalar (integer) + * Optional: + * T args + * 0 - startVal + * 1 - finishVal] + * 2 - numOfElements + * output: + * 0 - 1D NDArray with the same type as input and length as given with + * numOfElements param. + */ +#if NOT_EXCLUDED(OP_lin_space) +DECLARE_CUSTOM_OP(lin_space, 0, 1, false, 0, 0); +#endif - #if NOT_EXCLUDED(OP_reduce_sum_bp) - DECLARE_CUSTOM_OP(reduce_sum_bp, 2, 1, false, 0, 0); - #endif +/** + * reduction_sum - tf.reduction_sum operation + * + * input params: + * 0 - NDArray + * + * T_ARG param (optional): + * 0 - keep_dims != 0. + * + * int params (optional): + * 0 - axe 1 + * 1 - axe 2 + * ... + * N-1 axe N + * + * All axes are optional and should be between 0 and input->rankOf() - 1 + * + * output: + * 0 - NDArray with reduces shape accordingly to axes (the scalar in default + * case). + */ +#if NOT_EXCLUDED(OP_reduce_sum) +DECLARE_CUSTOM_OP(reduce_sum, 1, 1, false, 0, 0); +#endif - /** - * reduction_prod - tf.reduction_prod operation - * - * input params: - * 0 - NDArray - * - * T_ARG param (optional): - * 0 - keep_dims != 0. - * - * int params (optional): - * 0 - axe 1 - * 1 - axe 2 - * ... - * N-1 axe N - * - * All axes are optional and should be between 0 and input->rankOf() - 1 - * - * output: - * 0 - NDArray with reduces shape accordingly to axes (the scalar in default case). - */ - #if NOT_EXCLUDED(OP_reduce_prod) - DECLARE_CUSTOM_OP(reduce_prod, 1, 1, false, 0, 0); - #endif +#if NOT_EXCLUDED(OP_reduce_sum_bp) +DECLARE_CUSTOM_OP(reduce_sum_bp, 2, 1, false, 0, 0); +#endif - #if NOT_EXCLUDED(OP_reduce_prod_bp) - DECLARE_CUSTOM_OP(reduce_prod_bp, 2, 1, false, 0, 0); - #endif +/** + * reduction_prod - tf.reduction_prod operation + * + * input params: + * 0 - NDArray + * + * T_ARG param (optional): + * 0 - keep_dims != 0. + * + * int params (optional): + * 0 - axe 1 + * 1 - axe 2 + * ... + * N-1 axe N + * + * All axes are optional and should be between 0 and input->rankOf() - 1 + * + * output: + * 0 - NDArray with reduces shape accordingly to axes (the scalar in default + * case). + */ +#if NOT_EXCLUDED(OP_reduce_prod) +DECLARE_CUSTOM_OP(reduce_prod, 1, 1, false, 0, 0); +#endif - /** - * This op calculates min of elements along given dimensions - * - * input array: - * x: tensor to calculate mins for - * - * float arguments: - * keepDims: if non zero, then keep reduced dimensions with length = 1, default value is zero - * - * int arguments: - * list of integers - dimensions to calculate min along, default corresponds to empty list in which case calculation is performed for all dimensions and scalar is returned - * - * output array: - * reduced tensor with calculated mins - */ - #if NOT_EXCLUDED(OP_reduce_min) - DECLARE_CUSTOM_OP(reduce_min, 1, 1, false, 0, 0); - #endif - #if NOT_EXCLUDED(OP_reduce_min_bp) - DECLARE_CUSTOM_OP(reduce_min_bp, 2, 1, false, 0, 0); - #endif +#if NOT_EXCLUDED(OP_reduce_prod_bp) +DECLARE_CUSTOM_OP(reduce_prod_bp, 2, 1, false, 0, 0); +#endif - /** - * This op calculates max of elements along given dimensions - * - * input array: - * x: tensor to calculate maxes for - * - * float arguments: - * keepDims: if non zero, then keep reduced dimensions with length = 1, default value is zero - * - * int arguments: - * list of integers - dimensions to calculate max along, default corresponds to empty list in which case calculation is performed for all dimensions and scalar is returned - * - * output array: - * reduced tensor with calculated maxes - */ - #if NOT_EXCLUDED(OP_reduce_max) - DECLARE_CUSTOM_OP(reduce_max, 1, 1, false, 0, 0); - #endif - #if NOT_EXCLUDED(OP_reduce_max_bp) - DECLARE_CUSTOM_OP(reduce_max_bp, 2, 1, false, 0, 0); - #endif +/** + * This op calculates min of elements along given dimensions + * + * input array: + * x: tensor to calculate mins for + * + * float arguments: + * keepDims: if non zero, then keep reduced dimensions with length = 1, + * default value is zero + * + * int arguments: + * list of integers - dimensions to calculate min along, default corresponds + * to empty list in which case calculation is performed for all dimensions and + * scalar is returned + * + * output array: + * reduced tensor with calculated mins + */ +#if NOT_EXCLUDED(OP_reduce_min) +DECLARE_CUSTOM_OP(reduce_min, 1, 1, false, 0, 0); +#endif +#if NOT_EXCLUDED(OP_reduce_min_bp) +DECLARE_CUSTOM_OP(reduce_min_bp, 2, 1, false, 0, 0); +#endif - /** - * This op calculates norm1 of elements along given dimensions - * - * input array: - * x: tensor to calculate norm1 for - * - * float arguments: - * keepDims: if non zero, then keep reduced dimensions with length = 1, default value is zero - * - * int arguments: - * list of integers - dimensions to calculate norm1 along, default corresponds to empty list in which case calculation is performed for all dimensions and scalar is returned - * - * output array: - * reduced tensor with calculated norm1 - */ - #if NOT_EXCLUDED(OP_reduce_norm1) - DECLARE_CUSTOM_OP(reduce_norm1, 1, 1, false, 0, 0); - #endif - #if NOT_EXCLUDED(OP_reduce_norm1_bp) - DECLARE_CUSTOM_OP(reduce_norm1_bp, 2, 1, false, 0, 0); - #endif +/** + * This op calculates max of elements along given dimensions + * + * input array: + * x: tensor to calculate maxes for + * + * float arguments: + * keepDims: if non zero, then keep reduced dimensions with length = 1, + * default value is zero + * + * int arguments: + * list of integers - dimensions to calculate max along, default corresponds + * to empty list in which case calculation is performed for all dimensions and + * scalar is returned + * + * output array: + * reduced tensor with calculated maxes + */ +#if NOT_EXCLUDED(OP_reduce_max) +DECLARE_CUSTOM_OP(reduce_max, 1, 1, false, 0, 0); +#endif +#if NOT_EXCLUDED(OP_reduce_max_bp) +DECLARE_CUSTOM_OP(reduce_max_bp, 2, 1, false, 0, 0); +#endif - /** - * This op calculates norm2 of elements along given dimensions - * - * input array: - * x: tensor to calculate norm2 for - * - * float arguments: - * keepDims: if non zero, then keep reduced dimensions with length = 1, default value is zero - * - * int arguments: - * list of integers - dimensions to calculate norm2 along, default corresponds to empty list in which case calculation is performed for all dimensions and scalar is returned - * - * output array: - * reduced tensor with calculated norm2 - */ - #if NOT_EXCLUDED(OP_reduce_norm2) - DECLARE_CUSTOM_OP(reduce_norm2, 1, 1, false, 0, 0); - #endif - #if NOT_EXCLUDED(OP_reduce_norm2_bp) - DECLARE_CUSTOM_OP(reduce_norm2_bp, 2, 1, false, 0, 0); - #endif +/** + * This op calculates norm1 of elements along given dimensions + * + * input array: + * x: tensor to calculate norm1 for + * + * float arguments: + * keepDims: if non zero, then keep reduced dimensions with length = 1, + * default value is zero + * + * int arguments: + * list of integers - dimensions to calculate norm1 along, default + * corresponds to empty list in which case calculation is performed for all + * dimensions and scalar is returned + * + * output array: + * reduced tensor with calculated norm1 + */ +#if NOT_EXCLUDED(OP_reduce_norm1) +DECLARE_CUSTOM_OP(reduce_norm1, 1, 1, false, 0, 0); +#endif +#if NOT_EXCLUDED(OP_reduce_norm1_bp) +DECLARE_CUSTOM_OP(reduce_norm1_bp, 2, 1, false, 0, 0); +#endif +/** + * This op calculates norm2 of elements along given dimensions + * + * input array: + * x: tensor to calculate norm2 for + * + * float arguments: + * keepDims: if non zero, then keep reduced dimensions with length = 1, + * default value is zero + * + * int arguments: + * list of integers - dimensions to calculate norm2 along, default + * corresponds to empty list in which case calculation is performed for all + * dimensions and scalar is returned + * + * output array: + * reduced tensor with calculated norm2 + */ +#if NOT_EXCLUDED(OP_reduce_norm2) +DECLARE_CUSTOM_OP(reduce_norm2, 1, 1, false, 0, 0); +#endif +#if NOT_EXCLUDED(OP_reduce_norm2_bp) +DECLARE_CUSTOM_OP(reduce_norm2_bp, 2, 1, false, 0, 0); +#endif - /** - * This op calculates squared norm of elements along given dimensions - * - * input array: - * x: tensor to calculate squared norm for - * - * float arguments: - * keepDims: if non zero, then keep reduced dimensions with length = 1, default value is zero - * - * int arguments: - * list of integers - dimensions to calculate squared norm along, default corresponds to empty list in which case calculation is performed for all dimensions and scalar is returned - * - * output array: - * reduced tensor with calculated norm - */ - #if NOT_EXCLUDED(OP_reduce_sqnorm) - DECLARE_CUSTOM_OP(reduce_sqnorm, 1, 1, false, 0, 0); - #endif - #if NOT_EXCLUDED(OP_reduce_sqnorm_bp) - DECLARE_CUSTOM_OP(reduce_sqnorm_bp, 2, 1, false, 0, 0); - #endif +/** + * This op calculates squared norm of elements along given dimensions + * + * input array: + * x: tensor to calculate squared norm for + * + * float arguments: + * keepDims: if non zero, then keep reduced dimensions with length = 1, + * default value is zero + * + * int arguments: + * list of integers - dimensions to calculate squared norm along, default + * corresponds to empty list in which case calculation is performed for all + * dimensions and scalar is returned + * + * output array: + * reduced tensor with calculated norm + */ +#if NOT_EXCLUDED(OP_reduce_sqnorm) +DECLARE_CUSTOM_OP(reduce_sqnorm, 1, 1, false, 0, 0); +#endif +#if NOT_EXCLUDED(OP_reduce_sqnorm_bp) +DECLARE_CUSTOM_OP(reduce_sqnorm_bp, 2, 1, false, 0, 0); +#endif - /** - * This op calculates norm max of elements along given dimensions - * - * input array: - * x: tensor to calculate norm max for - * - * float arguments: - * keepDims: if non zero, then keep reduced dimensions with length = 1, default value is zero - * - * int arguments: - * list of integers - dimensions to calculate norm max along, default corresponds to empty list in which case calculation is performed for all dimensions and scalar is returned - * - * output array: - * reduced tensor with calculated norm - */ - #if NOT_EXCLUDED(OP_reduce_norm_max) - DECLARE_CUSTOM_OP(reduce_norm_max, 1, 1, false, 0, 0); - #endif - #if NOT_EXCLUDED(OP_reduce_norm_max_bp) - DECLARE_CUSTOM_OP(reduce_norm_max_bp, 2, 1, false, 0, 0); - #endif +/** + * This op calculates norm max of elements along given dimensions + * + * input array: + * x: tensor to calculate norm max for + * + * float arguments: + * keepDims: if non zero, then keep reduced dimensions with length = 1, + * default value is zero + * + * int arguments: + * list of integers - dimensions to calculate norm max along, default + * corresponds to empty list in which case calculation is performed for all + * dimensions and scalar is returned + * + * output array: + * reduced tensor with calculated norm + */ +#if NOT_EXCLUDED(OP_reduce_norm_max) +DECLARE_CUSTOM_OP(reduce_norm_max, 1, 1, false, 0, 0); +#endif +#if NOT_EXCLUDED(OP_reduce_norm_max_bp) +DECLARE_CUSTOM_OP(reduce_norm_max_bp, 2, 1, false, 0, 0); +#endif - /** - * This op calculates mean of elements along given dimensions - * - * input array: - * x: tensor to calculate mean for - * - * float arguments: - * keepDims: if non zero, then keep reduced dimensions with length = 1, default value is zero - * - * int arguments: - * list of integers - dimensions to calculate mean along, default corresponds to empty list in which case calculation is performed for all dimensions and scalar is returned - * - * output array: - * reduced tensor with calculated means - */ - #if NOT_EXCLUDED(OP_reduce_mean) - DECLARE_CUSTOM_OP(reduce_mean, 1, 1, false, 0, 0); - #endif +/** + * This op calculates mean of elements along given dimensions + * + * input array: + * x: tensor to calculate mean for + * + * float arguments: + * keepDims: if non zero, then keep reduced dimensions with length = 1, + * default value is zero + * + * int arguments: + * list of integers - dimensions to calculate mean along, default corresponds + * to empty list in which case calculation is performed for all dimensions and + * scalar is returned + * + * output array: + * reduced tensor with calculated means + */ +#if NOT_EXCLUDED(OP_reduce_mean) +DECLARE_CUSTOM_OP(reduce_mean, 1, 1, false, 0, 0); +#endif - #if NOT_EXCLUDED(OP_reduce_mean_bp) - DECLARE_CUSTOM_OP(reduce_mean_bp, 2, 1, false, 0, 0) - #endif - /** - * This op calculates sample variance of elements along given dimensions - * - * input array: - * x: tensor to calculate mean for - * - * float arguments: - * keepDims: if non zero, then keep reduced dimensions with length = 1, default value is zero - * biasCorrected - if non zero, then bias correction will be applied, default value is zero - * - * int arguments: - * list of integers - dimensions to calculate mean along, default corresponds to empty list in which case calculation is performed for all dimensions and scalar is returned - * - * output array: - * reduced tensor with calculated means - */ - DECLARE_CUSTOM_OP(reduce_variance, 1, 1, false, 0, 0); - DECLARE_CUSTOM_OP(reduce_variance_bp, 2, 1, false, 0, 0) +#if NOT_EXCLUDED(OP_reduce_mean_bp) +DECLARE_CUSTOM_OP(reduce_mean_bp, 2, 1, false, 0, 0) +#endif +/** + * This op calculates sample variance of elements along given dimensions + * + * input array: + * x: tensor to calculate mean for + * + * float arguments: + * keepDims: if non zero, then keep reduced dimensions with length = 1, + * default value is zero biasCorrected - if non zero, then bias correction will + * be applied, default value is zero + * + * int arguments: + * list of integers - dimensions to calculate mean along, default corresponds + * to empty list in which case calculation is performed for all dimensions and + * scalar is returned + * + * output array: + * reduced tensor with calculated means + */ +DECLARE_CUSTOM_OP(reduce_variance, 1, 1, false, 0, 0); +DECLARE_CUSTOM_OP(reduce_variance_bp, 2, 1, false, 0, 0) - /** - * This op calculates sample standard deviation of elements along given dimensions - * - * input array: - * x: tensor to calculate mean for - * - * float arguments: - * keepDims: if non zero, then keep reduced dimensions with length = 1, default value is zero - * biasCorrected - if non zero, then bias correction will be applied, default value is zero - * - * int arguments: - * list of integers - dimensions to calculate mean along, default corresponds to empty list in which case calculation is performed for all dimensions and scalar is returned - * - * output array: - * reduced tensor with calculated means - */ - DECLARE_CUSTOM_OP(reduce_stdev, 1, 1, false, 0, 0); - DECLARE_CUSTOM_OP(reduce_stdev_bp, 2, 1, false, 0, 0) +/** + * This op calculates sample standard deviation of elements along given + * dimensions + * + * input array: + * x: tensor to calculate mean for + * + * float arguments: + * keepDims: if non zero, then keep reduced dimensions with length = 1, + * default value is zero biasCorrected - if non zero, then bias correction will + * be applied, default value is zero + * + * int arguments: + * list of integers - dimensions to calculate mean along, default corresponds + * to empty list in which case calculation is performed for all dimensions and + * scalar is returned + * + * output array: + * reduced tensor with calculated means + */ +DECLARE_CUSTOM_OP(reduce_stdev, 1, 1, false, 0, 0); +DECLARE_CUSTOM_OP(reduce_stdev_bp, 2, 1, false, 0, 0) - /** - * This op calculates backprop dot for two tensors along given dimensions - * - * input array: - * x: tensor to calculate dot for - * y: tensor to calculate dot for - * z: tensor with gradient output of the FF dot for x and y - * - * int arguments: - * list of integers - dimensions to calculate dot along, - * default corresponds to empty list in which case calculation - * is performed for all dimensions and scalar is returned. - * - * output array: - * the tensor with calculated backproped dots - * - */ +/** + * This op calculates backprop dot for two tensors along given dimensions + * + * input array: + * x: tensor to calculate dot for + * y: tensor to calculate dot for + * z: tensor with gradient output of the FF dot for x and y + * + * int arguments: + * list of integers - dimensions to calculate dot along, + * default corresponds to empty list in which case calculation + * is performed for all dimensions and scalar is returned. + * + * output array: + * the tensor with calculated backproped dots + * + */ - #if NOT_EXCLUDED(OP_reduce_dot_bp) - DECLARE_CUSTOM_OP(reduce_dot_bp, 3, 2, false, 0, 0); - #endif - /** - * reduce_logsumexp - tf.reduce_logsumexe operation - * - * input params: - * 0 - NDArray (input) - * 1 - 1D NDArray (axis) (optional) - integer array - * - * T_ARG param (optional): - * 0 - keep_dims != 0. - * - * int params (optional): - * 0 - axe 1 - * 1 - axe 2 - * ... - * N-1 axe N - * - * CAUTION: All axes are optional and should be between 0 and input->rankOf() - 1 - * and put either with second param or as integers but not both - * - * output: - * 0 - NDArray with reduces shape accordingly to axes (the scalar in default case). - */ - #if NOT_EXCLUDED(OP_reduce_logsumexp) - DECLARE_CUSTOM_OP(reduce_logsumexp, 1, 1, false, 0, 0); - #endif +#if NOT_EXCLUDED(OP_reduce_dot_bp) +DECLARE_CUSTOM_OP(reduce_dot_bp, 3, 2, false, 0, 0); +#endif +/** + * reduce_logsumexp - tf.reduce_logsumexe operation + * + * input params: + * 0 - NDArray (input) + * 1 - 1D NDArray (axis) (optional) - integer array + * + * T_ARG param (optional): + * 0 - keep_dims != 0. + * + * int params (optional): + * 0 - axe 1 + * 1 - axe 2 + * ... + * N-1 axe N + * + * CAUTION: All axes are optional and should be between 0 and input->rankOf() - + * 1 and put either with second param or as integers but not both + * + * output: + * 0 - NDArray with reduces shape accordingly to axes (the scalar in default + * case). + */ +#if NOT_EXCLUDED(OP_reduce_logsumexp) +DECLARE_CUSTOM_OP(reduce_logsumexp, 1, 1, false, 0, 0); +#endif /** * Copy a tensor setting everything outside a central band in each innermost matrix @@ -1786,139 +1851,143 @@ namespace sd { * */ - #if NOT_EXCLUDED(OP_matrix_band_part) - DECLARE_CONFIGURABLE_OP(matrix_band_part, 1, 1, true, 0, 2); - #endif - +#if NOT_EXCLUDED(OP_matrix_band_part) +DECLARE_CONFIGURABLE_OP(matrix_band_part, 1, 1, true, 0, 2); +#endif - #if NOT_EXCLUDED(OP_Assert) - DECLARE_OP(Assert, 1, 1, false); - #endif +#if NOT_EXCLUDED(OP_Assert) +DECLARE_OP(Assert, 1, 1, false); +#endif - /** - * image.non_max_suppression ops. - * input: - * 0 - boxes - 2D-tensor with shape (num_boxes, 4) by float type - * 1 - scales - 1D-tensor with shape (num_boxes) by float type - * 2 - output_size - 0D-tensor by int type (optional) - * float args: - * 0 - overlap_threshold - threshold value for overlap checks (optional, by default 0.5) - * 1 - score_threshold - the threshold for deciding when to remove boxes based on score (optional, by default -inf) - * int args: - * 0 - output_size - as arg 2 used for same target. Eigher this or arg 2 should be provided. - * - * output: - * - vector with size M, where M <= output_size by int type - * - * */ - #if NOT_EXCLUDED(OP_image_non_max_suppression) - DECLARE_CUSTOM_OP(non_max_suppression, 2, 1, false, 0, 0); - #endif - #if NOT_EXCLUDED(OP_image_non_max_suppression_v3) - DECLARE_CUSTOM_OP(non_max_suppression_v3, 2, 1, false, 0, 0); - #endif +/** + * image.non_max_suppression ops. + * input: + * 0 - boxes - 2D-tensor with shape (num_boxes, 4) by float type + * 1 - scales - 1D-tensor with shape (num_boxes) by float type + * 2 - output_size - 0D-tensor by int type (optional) + * float args: + * 0 - overlap_threshold - threshold value for overlap checks (optional, by + * default 0.5) 1 - score_threshold - the threshold for deciding when to remove + * boxes based on score (optional, by default -inf) int args: 0 - output_size - + * as arg 2 used for same target. Eigher this or arg 2 should be provided. + * + * output: + * - vector with size M, where M <= output_size by int type + * + * */ +#if NOT_EXCLUDED(OP_image_non_max_suppression) +DECLARE_CUSTOM_OP(non_max_suppression, 2, 1, false, 0, 0); +#endif +#if NOT_EXCLUDED(OP_image_non_max_suppression_v3) +DECLARE_CUSTOM_OP(non_max_suppression_v3, 2, 1, false, 0, 0); +#endif - /* - * image.non_max_suppression_overlaps op. - * input: - * 0 - boxes - 2D-tensor with shape (num_boxes, 4) by float type - * 1 - scales - 1D-tensor with shape (num_boxes) by float type - * 2 - output_size - 0D-tensor by int type (optional) - * float args: - * 0 - overlap_threshold - threshold value for overlap checks (optional, by default 0.5) - * 1 - score_threshold - the threshold for deciding when to remove boxes based on score (optional, by default -inf) - * int args: - * 0 - output_size - as arg 2 used for same target. Eigher this or arg 2 should be provided. - * - * output: - * 0 - 1D integer tensor with shape [M], epresenting the selected indices from the overlaps tensor, where M <= max_output_size - * */ - #if NOT_EXCLUDED(OP_image_non_max_suppression_overlaps) - DECLARE_CUSTOM_OP(non_max_suppression_overlaps, 2, 1, false, 0, 0); - #endif +/* + * image.non_max_suppression_overlaps op. + * input: + * 0 - boxes - 2D-tensor with shape (num_boxes, 4) by float type + * 1 - scales - 1D-tensor with shape (num_boxes) by float type + * 2 - output_size - 0D-tensor by int type (optional) + * float args: + * 0 - overlap_threshold - threshold value for overlap checks (optional, by + * default 0.5) 1 - score_threshold - the threshold for deciding when to remove + * boxes based on score (optional, by default -inf) int args: 0 - output_size - + * as arg 2 used for same target. Eigher this or arg 2 should be provided. + * + * output: + * 0 - 1D integer tensor with shape [M], epresenting the selected indices + * from the overlaps tensor, where M <= max_output_size + * */ +#if NOT_EXCLUDED(OP_image_non_max_suppression_overlaps) +DECLARE_CUSTOM_OP(non_max_suppression_overlaps, 2, 1, false, 0, 0); +#endif - /* - * cholesky op - decomposite positive square symetric matrix (or matricies when rank > 2). - * input: - * 0 - matricies - tensor with shape (..., N, N) by float type - * - * output - lower triangular matrix (matricies when rank > 2) with the same shape as input. - * */ - #if NOT_EXCLUDED(OP_cholesky) - DECLARE_OP(cholesky, 1, 1, true); - #endif - /* - * nth_element - apply nth_element for last dimension of input tensor - * input array: - * 0 - input array - * 1 - scalar tensor with n for operation. n should be less than last dimension - * - * output: - * 0 - NDArray with the same shape as input - */ - #if NOT_EXCLUDED(OP_nth_element) - DECLARE_CUSTOM_OP(nth_element, 2, 1, false, 0, 0); - #endif +/* + * cholesky op - decomposite positive square symetric matrix (or matricies when + * rank > 2). input: 0 - matricies - tensor with shape (..., N, N) by float type + * + * output - lower triangular matrix (matricies when rank > 2) with the same + * shape as input. + * */ +#if NOT_EXCLUDED(OP_cholesky) +DECLARE_OP(cholesky, 1, 1, true); +#endif +/* + * nth_element - apply nth_element for last dimension of input tensor + * input array: + * 0 - input array + * 1 - scalar tensor with n for operation. n should be less than last + * dimension + * + * output: + * 0 - NDArray with the same shape as input + */ +#if NOT_EXCLUDED(OP_nth_element) +DECLARE_CUSTOM_OP(nth_element, 2, 1, false, 0, 0); +#endif - /** - * This op checks for Inf/NaN values within input array, and throws exception if there's at least one - */ - #if NOT_EXCLUDED(OP_check_numerics) - DECLARE_CUSTOM_OP(check_numerics, 2, 1, true, 0, 0); - #endif /** - * fake_quant_with_min_max_vals - tf.quantization.fake_quant_with_min_max_vars - * - * input params: - * 0 - NDArray (input) - * 1 - 0D Tensor - min value - * 2 - 0D Tensor - max value - * - * int params (optional): - * 0 - num_bits (allowed interval [2, 16], default 8) - * 1 - narrow_range (default False) - * - * output: - * 0 - NDArray with the same shape as input - */ - #if NOT_EXCLUDED(OP_fake_quant_with_min_max_vars) - DECLARE_CONFIGURABLE_OP(fake_quant_with_min_max_vars, 3, 1, true, 0, -2); - #endif + * This op checks for Inf/NaN values within input array, and throws exception if + * there's at least one + */ +#if NOT_EXCLUDED(OP_check_numerics) +DECLARE_CUSTOM_OP(check_numerics, 2, 1, true, 0, 0); +#endif +/** + * fake_quant_with_min_max_vals - tf.quantization.fake_quant_with_min_max_vars + * + * input params: + * 0 - NDArray (input) + * 1 - 0D Tensor - min value + * 2 - 0D Tensor - max value + * + * int params (optional): + * 0 - num_bits (allowed interval [2, 16], default 8) + * 1 - narrow_range (default False) + * + * output: + * 0 - NDArray with the same shape as input + */ +#if NOT_EXCLUDED(OP_fake_quant_with_min_max_vars) +DECLARE_CONFIGURABLE_OP(fake_quant_with_min_max_vars, 3, 1, true, 0, -2); +#endif /** - * fake_quant_with_min_max_vals_per_channel - tf.quantization.fake_quant_with_min_max_vars_per_channel - * - * input params: - * 0 - NDArray (input) - at least 2D. - * 1 - 1D Tensor - min values (min length equals to last dim of input) - * 2 - 1D Tensor - max value (length equals to min) - * - * int params (optional): - * 0 - num_bits (allowed interval [2, 16], default 8) - * 1 - narrow_range (default False) - * - * output: - * 0 - NDArray with the same shape as input - */ - #if NOT_EXCLUDED(OP_fake_quant_with_min_max_vars_per_channel) - DECLARE_CONFIGURABLE_OP(fake_quant_with_min_max_vars_per_channel, 3, 1, true, 0, -2); - #endif + * fake_quant_with_min_max_vals_per_channel - + * tf.quantization.fake_quant_with_min_max_vars_per_channel + * + * input params: + * 0 - NDArray (input) - at least 2D. + * 1 - 1D Tensor - min values (min length equals to last dim of input) + * 2 - 1D Tensor - max value (length equals to min) + * + * int params (optional): + * 0 - num_bits (allowed interval [2, 16], default 8) + * 1 - narrow_range (default False) + * + * output: + * 0 - NDArray with the same shape as input + */ +#if NOT_EXCLUDED(OP_fake_quant_with_min_max_vars_per_channel) +DECLARE_CONFIGURABLE_OP(fake_quant_with_min_max_vars_per_channel, 3, 1, true, 0, + -2); +#endif - /** - * compare_and_bitpack - compare with greater and pack result with uint8 - * - * input params: - * 0 - NDArray (input) - * 1 - 0D Tensor - threshold - * - * - * output: - * 0 - NDArray with the same shape as input and type uint8 - */ - #if NOT_EXCLUDED(OP_compare_and_bitpack) - DECLARE_CUSTOM_OP(compare_and_bitpack, 2, 1, false, 0, 0); - #endif - } -} +/** + * compare_and_bitpack - compare with greater and pack result with uint8 + * + * input params: + * 0 - NDArray (input) + * 1 - 0D Tensor - threshold + * + * + * output: + * 0 - NDArray with the same shape as input and type uint8 + */ +#if NOT_EXCLUDED(OP_compare_and_bitpack) +DECLARE_CUSTOM_OP(compare_and_bitpack, 2, 1, false, 0, 0); +#endif +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/headers/random.h b/libnd4j/include/ops/declarable/headers/random.h index 367a41995e05..d478e2b0daba 100644 --- a/libnd4j/include/ops/declarable/headers/random.h +++ b/libnd4j/include/ops/declarable/headers/random.h @@ -25,78 +25,79 @@ #include namespace sd { - namespace ops { - #if NOT_EXCLUDED(OP_set_seed) - DECLARE_CUSTOM_OP(set_seed, -2, 1, false, 0, -2); - #endif +namespace ops { +#if NOT_EXCLUDED(OP_set_seed) +DECLARE_CUSTOM_OP(set_seed, -2, 1, false, 0, -2); +#endif - #if NOT_EXCLUDED(OP_get_seed) - DECLARE_CUSTOM_OP(get_seed, -2, 1, false, 0, 0); - #endif +#if NOT_EXCLUDED(OP_get_seed) +DECLARE_CUSTOM_OP(get_seed, -2, 1, false, 0, 0); +#endif - /* - * random_uniform distribution for types int32,int64, float16, float and double - * by default dtype is float32 - * - * input: - * 0 - shape of output (1D int tensor) - * 1 - min val (0D of output type) - optional (0 as default) - * 2 - max val (0D of output type) - optional (inf as default) - * - * output: - * 0 - uniformly distributed values of given type (between min and max) - */ - #if NOT_EXCLUDED(OP_randomuniform) - DECLARE_CUSTOM_OP(randomuniform, 1, 1, false, 0, 0); - #endif - /* - * multinomial (categorical) random generator draws samples from a multinomial distribution - * - * Input array: - * 0 - 2D ndarray with unnormalized log-probabilities with shape [batch_size (N), num_classes (K)] - * 1 - array with one int value of samples number, number of independent samples to draw for each experiment 1,N. - * Int arguments: - * 0 - optional argument, corresponds to dimension with batch_size - * 1 - optional argument, integer type to use for the output. Default int64. - * - * Output array: - * 0 - 2D ndarray with the drawn samples of shape [batch_size, num_samples] - */ - #if NOT_EXCLUDED(OP_random_multinomial) - DECLARE_CUSTOM_OP(random_multinomial, 2, 1, false, 0, 0); - #endif +/* + * random_uniform distribution for types int32,int64, float16, float and double + * by default dtype is float32 + * + * input: + * 0 - shape of output (1D int tensor) + * 1 - min val (0D of output type) - optional (0 as default) + * 2 - max val (0D of output type) - optional (inf as default) + * + * output: + * 0 - uniformly distributed values of given type (between min and max) + */ +#if NOT_EXCLUDED(OP_randomuniform) +DECLARE_CUSTOM_OP(randomuniform, 1, 1, false, 0, 0); +#endif +/* + * multinomial (categorical) random generator draws samples from a multinomial + * distribution + * + * Input array: + * 0 - 2D ndarray with unnormalized log-probabilities with shape [batch_size + * (N), num_classes (K)] 1 - array with one int value of samples number, number + * of independent samples to draw for each experiment 1,N. Int arguments: 0 - + * optional argument, corresponds to dimension with batch_size 1 - optional + * argument, integer type to use for the output. Default int64. + * + * Output array: + * 0 - 2D ndarray with the drawn samples of shape [batch_size, num_samples] + */ +#if NOT_EXCLUDED(OP_random_multinomial) +DECLARE_CUSTOM_OP(random_multinomial, 2, 1, false, 0, 0); +#endif - #if NOT_EXCLUDED(OP_random_normal) - DECLARE_CUSTOM_OP(random_normal, 1, 1, true, 2, 0); - #endif +#if NOT_EXCLUDED(OP_random_normal) +DECLARE_CUSTOM_OP(random_normal, 1, 1, true, 2, 0); +#endif - #if NOT_EXCLUDED(OP_random_bernoulli) - DECLARE_CUSTOM_OP(random_bernoulli, 1, 1, true, 0, 1); - #endif +#if NOT_EXCLUDED(OP_random_bernoulli) +DECLARE_CUSTOM_OP(random_bernoulli, 1, 1, true, 0, 1); +#endif - #if NOT_EXCLUDED(OP_random_exponential) - DECLARE_CUSTOM_OP(random_exponential, 1, 1, true, 1, 0); - #endif +#if NOT_EXCLUDED(OP_random_exponential) +DECLARE_CUSTOM_OP(random_exponential, 1, 1, true, 1, 0); +#endif - #if NOT_EXCLUDED(OP_random_crop) - DECLARE_CUSTOM_OP(random_crop, 2, 1, false, 0, 0); - #endif +#if NOT_EXCLUDED(OP_random_crop) +DECLARE_CUSTOM_OP(random_crop, 2, 1, false, 0, 0); +#endif - /** - * random_gamma op. - */ - #if NOT_EXCLUDED(OP_random_gamma) - DECLARE_CUSTOM_OP(random_gamma, 2, 1, false, 0, 0); - #endif +/** + * random_gamma op. + */ +#if NOT_EXCLUDED(OP_random_gamma) +DECLARE_CUSTOM_OP(random_gamma, 2, 1, false, 0, 0); +#endif - /** - * random_poisson op. - */ - #if NOT_EXCLUDED(OP_random_poisson) - DECLARE_CUSTOM_OP(random_poisson, 2, 1, false, 0, 0); - #endif +/** + * random_poisson op. + */ +#if NOT_EXCLUDED(OP_random_poisson) +DECLARE_CUSTOM_OP(random_poisson, 2, 1, false, 0, 0); +#endif - } -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/headers/recurrent.h b/libnd4j/include/ops/declarable/headers/recurrent.h index aeeae24c42fe..870f0ceea2f2 100644 --- a/libnd4j/include/ops/declarable/headers/recurrent.h +++ b/libnd4j/include/ops/declarable/headers/recurrent.h @@ -24,420 +24,448 @@ #include namespace sd { -namespace ops { - - ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of operation for Simple Recurrent Unit: "Training RNNs as Fast as CNNs" Tao Lei, Yu Zhang, Yoav Artzi - * - * Input arrays: - * 0: input 3d tensor with shape [bS x K x N], N - number of time steps, bS - batch size, K - number of features - * 1: 2d tensor of weights [3K x K] - * 2: row of biases with twice length [1 x 2K] - * 3: 2d tensor of previous cell state [bS x K] - * 4: optional, 2d tensor of dropout mask [bS x K] - * - * Output arrays: - * 0: 3d tensor of cell output [bS x K x N] - * 1: 3d tensor of cell state [bS x K x N] - */ - #if NOT_EXCLUDED(OP_sru) - DECLARE_CUSTOM_OP(sru, 5, 2, false, 0, 0); - #endif - - ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of operation for Simple Recurrent Unit (bidirectional case): "Training RNNs as Fast as CNNs" Tao Lei, Yu Zhang, Yoav Artzi - * - * Input arrays: - * 0: input 3d tensor with shape [N x bS x 2K], N - number of time steps, bS - batch size, K - number of features - * 1: 2d tensor of weights [2K x 6K] - * 2: row of biases with twice length [1 x 4K] - * 3: 2d tensor of previous cell state [bS x 2K] - * 4: optional, 2d tensor of dropout mask [bS x 2K] - * - * Output arrays: - * 0: 3d tensor of cell output [N x bS x 2K] - * 1: 3d tensor of cell state [N x bS x 2K] - */ - #if NOT_EXCLUDED(OP_sru_bi) - DECLARE_CUSTOM_OP(sru_bi, 5, 2, true, 0, 0); - #endif - - - ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of operation for back propagation in Simple Recurrent Unit: "Training RNNs as Fast as CNNs" Tao Lei, Yu Zhang, Yoav Artzi - * - * Input arrays: - * 0: input 3d tensor with shape [bS x K x N], N - number of time steps, bS - batch size, K - number of features - * 1: 2d tensor of weights [3K x K] - * 2: row of biases with twice length [1 x 2K] - * 3: 2d tensor of previous cell state [bS x K] - * 4: 3d tensor of cell state [bS x K x N] - * 5: 2d tensor of cell state gradients [bS x K] - * 6: 3d tensor of state output gradients [bS x K x N] - * 7: optional, 2d tensor of dropout mask [bS x K] - * - * Output arrays: - * 0: 3d tensor of input gradients [bS x K x N] - * 1: 3d tensor of weights gradients [bS x 3K x K] - * 2: 2d, row of biases gradients [1 x 2K] - * 3: 2d, tensor of state gradients [bS x K] - */ - #if NOT_EXCLUDED(OP_sru) - DECLARE_CUSTOM_OP(sru_bp, 8, 4, true, 0, 0); - #endif - - ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of operation for back propagation in Simple Recurrent Unit (bidirectional case): "Training RNNs as Fast as CNNs" Tao Lei, Yu Zhang, Yoav Artzi - * - * Input arrays: - * 0: input 3d tensor with shape [N x bS x 2K], N - number of time steps, bS - batch size, K - number of features - * 1: 2d tensor of weights [2K x 6K] - * 2: row of biases with twice length [1 x 4K] - * 3: 2d tensor of previous cell state [bS x 2K] - * 4: 3d tensor of cell state [N x bS x 2K] - * 5: 2d tensor of cell state gradients [bS x 2K] - * 6: 3d tensor of state output gradients [N x bS x 2K] - * 7: optional, 2d tensor of dropout mask [bS x 2K] - * - * Output arrays: - * 0: 3d tensor of input gradients [N x bS x 2K] - * 1: 3d tensor of weights gradients [N x 2K x 6K] - * 2: 2d, row of biases gradients [1 x 4K] - * 3: 2d, tensor of state gradients [bS x 2K] - */ - #if NOT_EXCLUDED(OP_sru_bi) - DECLARE_CUSTOM_OP(sru_bi_bp, 8, 4, true, 0, 0); - #endif +namespace ops { +////////////////////////////////////////////////////////////////////////// +/** + * Implementation of operation for Simple Recurrent Unit: "Training RNNs as Fast + * as CNNs" Tao Lei, Yu Zhang, Yoav Artzi + * + * Input arrays: + * 0: input 3d tensor with shape [bS x K x N], N - number of time steps, bS - + * batch size, K - number of features 1: 2d tensor of weights [3K x K] 2: row of + * biases with twice length [1 x 2K] 3: 2d tensor of previous cell state [bS x + * K] 4: optional, 2d tensor of dropout mask [bS x K] + * + * Output arrays: + * 0: 3d tensor of cell output [bS x K x N] + * 1: 3d tensor of cell state [bS x K x N] + */ +#if NOT_EXCLUDED(OP_sru) +DECLARE_CUSTOM_OP(sru, 5, 2, false, 0, 0); +#endif - ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of operation for LSTM cell with peep hole connections: - * S. Hochreiter and J. Schmidhuber. "Long Short-Term Memory". Neural Computation - * and - * https://research.google.com/pubs/archive/43905.pdf - * Hasim Sak, Andrew Senior, and Francoise Beaufays. "Long short-term memory recurrent neural network architectures for large scale acoustic modeling." INTERSPEECH, 2014. - * - * Input arrays: - * 0: input with shape [batchSize x inSize], batchSize - batch size, inSize - number of features - * 1: previous cell output [batchSize x numProj], that is at previous time step t-1, in case of projection=false -> numProj=numUnits!!! - * 2: previous cell state [batchSize x numUnits], that is at previous time step t-1 - * 3: input-to-hidden weights, [inSize x 4*numUnits] - * 4: hidden-to-hidden weights, [numProj x 4*numUnits] - * 5: diagonal weights for peephole connections [3*numUnits] - * 6: projection weights [numUnits x numProj] - * 7: biases, [4*numUnits] - * - * Input integer arguments: - * 0: if not zero, provide peephole connections - * 1: if not zero, then projection is performed, if zero then numProj==numUnits is mandatory! - * - * Input float arguments: - * 0: clipping value for cell state, if it is not equal to zero, then cell state is clipped - * 1: clipping value for projected cell output, if it is not equal to zero, then projected cell output is clipped - * 2: the bias added to forget gates in order to reduce the scale of forgetting in the beginning of the training - * - * Output arrays: - * 0: current cell output [batchSize x numProj], that is at current time step t - * 1: current cell state [batchSize x numUnits], that is at current time step t - */ - #if NOT_EXCLUDED(OP_lstmCell) - DECLARE_CUSTOM_OP(lstmCell, 8, 2, false, 3, 2); - #endif +////////////////////////////////////////////////////////////////////////// +/** + * Implementation of operation for Simple Recurrent Unit (bidirectional case): + * "Training RNNs as Fast as CNNs" Tao Lei, Yu Zhang, Yoav Artzi + * + * Input arrays: + * 0: input 3d tensor with shape [N x bS x 2K], N - number of time steps, bS + * - batch size, K - number of features 1: 2d tensor of weights [2K x 6K] 2: row + * of biases with twice length [1 x 4K] 3: 2d tensor of previous cell state [bS + * x 2K] 4: optional, 2d tensor of dropout mask [bS x 2K] + * + * Output arrays: + * 0: 3d tensor of cell output [N x bS x 2K] + * 1: 3d tensor of cell state [N x bS x 2K] + */ +#if NOT_EXCLUDED(OP_sru_bi) +DECLARE_CUSTOM_OP(sru_bi, 5, 2, true, 0, 0); +#endif - #if NOT_EXCLUDED(OP_lstmLayerCell) - DECLARE_CUSTOM_OP(lstmLayerCell, 5, 2, false, 1, 3); - #endif - #if NOT_EXCLUDED(OP_lstmLayerCell) - DECLARE_CUSTOM_OP(lstmLayerCellBp, 7, 5, false, 1, 3); - #endif +////////////////////////////////////////////////////////////////////////// +/** + * Implementation of operation for back propagation in Simple Recurrent Unit: + * "Training RNNs as Fast as CNNs" Tao Lei, Yu Zhang, Yoav Artzi + * + * Input arrays: + * 0: input 3d tensor with shape [bS x K x N], N - number of time steps, bS - + * batch size, K - number of features 1: 2d tensor of weights [3K x K] 2: row of + * biases with twice length [1 x 2K] 3: 2d tensor of previous cell state [bS x + * K] 4: 3d tensor of cell state [bS x K x N] 5: 2d tensor of cell state + * gradients [bS x K] 6: 3d tensor of state output gradients [bS x K x N] 7: + * optional, 2d tensor of dropout mask [bS x K] + * + * Output arrays: + * 0: 3d tensor of input gradients [bS x K x N] + * 1: 3d tensor of weights gradients [bS x 3K x K] + * 2: 2d, row of biases gradients [1 x 2K] + * 3: 2d, tensor of state gradients [bS x K] + */ +#if NOT_EXCLUDED(OP_sru) +DECLARE_CUSTOM_OP(sru_bp, 8, 4, true, 0, 0); +#endif +////////////////////////////////////////////////////////////////////////// +/** + * Implementation of operation for back propagation in Simple Recurrent Unit + * (bidirectional case): "Training RNNs as Fast as CNNs" Tao Lei, Yu Zhang, Yoav + * Artzi + * + * Input arrays: + * 0: input 3d tensor with shape [N x bS x 2K], N - number of time steps, bS + * - batch size, K - number of features 1: 2d tensor of weights [2K x 6K] 2: row + * of biases with twice length [1 x 4K] 3: 2d tensor of previous cell state [bS + * x 2K] 4: 3d tensor of cell state [N x bS x 2K] 5: 2d tensor of cell state + * gradients [bS x 2K] 6: 3d tensor of state output gradients [N x bS x 2K] 7: + * optional, 2d tensor of dropout mask [bS x 2K] + * + * Output arrays: + * 0: 3d tensor of input gradients [N x bS x 2K] + * 1: 3d tensor of weights gradients [N x 2K x 6K] + * 2: 2d, row of biases gradients [1 x 4K] + * 3: 2d, tensor of state gradients [bS x 2K] + */ +#if NOT_EXCLUDED(OP_sru_bi) +DECLARE_CUSTOM_OP(sru_bi_bp, 8, 4, true, 0, 0); +#endif - ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of operation for LSTM cell with optional peep hole connections: - * S. Hochreiter and J. Schmidhuber. "Long Short-Term Memory". Neural Computation - * and - * https://research.google.com/pubs/archive/43905.pdf - * Hasim Sak, Andrew Senior, and Francoise Beaufays. "Long short-term memory recurrent neural network architectures for large scale acoustic modeling." INTERSPEECH, 2014. - * See also: https://arxiv.org/pdf/1503.04069.pdf - * - * Input arrays: - * 0: input [bS, inSize] at time t - * 1: previous cell state [bS, numUnits], time t-1 - * 2: previous output [bS, numUnits], time t-1 - * 3: Weights - concatenated (input-to-hidden, hidden-to-hidden weights) weights, [(inSize+numUnits), 4*numUnits] - * 4: weights - cell peephole (t-1) connections to input modulation gate, [numUnits] - * 5: weights - cell peephole (t-1) connections to forget gate, [numUnits] - * 6: weights - cell peephole (t) connections to output gate, [numUnits] - * 7: biases, shape [4*numUnits] - * - * Input integer arguments: - * 0: if not zero, provide peephole connections - * - * Input float arguments: - * 0: the bias added to forget gates in order to reduce the scale of forgetting in the beginning of the training - * 1: clipping value for cell state, if it is not equal to zero, then cell state is clipped - * - * Output arrays: - * 0: i - Input modulation gate activations [bS, numUnits] - * 1: c (cs) - Cell state (pre tanh) [bs, numUnits] (cs) - * 2: f - Output - forget gate activations [bs, numUnits] - * 3: o - Output - output gate activations [bs, numUnits] - * 4: z (ci) - Output - block input [bs, numUnits] - * 5: h (co) - Cell state, post tanh [bs, numUnits] - * 6: y (h) - Current cell output [bS, numUnits], time t - */ - #if NOT_EXCLUDED(OP_lstmBlockCell) - DECLARE_CUSTOM_OP(lstmBlockCell, 8, 7, false, 2, 1); - #endif +////////////////////////////////////////////////////////////////////////// +/** + * Implementation of operation for LSTM cell with peep hole connections: + * S. Hochreiter and J. Schmidhuber. "Long Short-Term Memory". Neural + * Computation and https://research.google.com/pubs/archive/43905.pdf Hasim Sak, + * Andrew Senior, and Francoise Beaufays. "Long short-term memory recurrent + * neural network architectures for large scale acoustic modeling." INTERSPEECH, + * 2014. + * + * Input arrays: + * 0: input with shape [batchSize x inSize], batchSize - batch size, inSize - + * number of features 1: previous cell output [batchSize x numProj], that is at + * previous time step t-1, in case of projection=false -> numProj=numUnits!!! 2: + * previous cell state [batchSize x numUnits], that is at previous time step + * t-1 3: input-to-hidden weights, [inSize x 4*numUnits] 4: hidden-to-hidden + * weights, [numProj x 4*numUnits] 5: diagonal weights for peephole connections + * [3*numUnits] 6: projection weights [numUnits x numProj] 7: biases, + * [4*numUnits] + * + * Input integer arguments: + * 0: if not zero, provide peephole connections + * 1: if not zero, then projection is performed, if zero then + * numProj==numUnits is mandatory! + * + * Input float arguments: + * 0: clipping value for cell state, if it is not equal to zero, then cell + * state is clipped 1: clipping value for projected cell output, if it is not + * equal to zero, then projected cell output is clipped 2: the bias added to + * forget gates in order to reduce the scale of forgetting in the beginning of + * the training + * + * Output arrays: + * 0: current cell output [batchSize x numProj], that is at current time step + * t 1: current cell state [batchSize x numUnits], that is at current time step + * t + */ +#if NOT_EXCLUDED(OP_lstmCell) +DECLARE_CUSTOM_OP(lstmCell, 8, 2, false, 3, 2); +#endif - ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of operation for LSTM layer with optional peep hole connections. - * See lstmBlockCell for details. lstmBlockCell is used internally for computation. - * This method expects as input (and returns as output) sequences in one of 3 formats, depending on the data format arg: - * dataFormat = 0 -> TNS: shape [timeLength, numExamples, inOutSize] - sometimes referred to as "time major" - * dataFormat = 1 -> NST: shape [numExamples, inOutSize, timeLength] - * dataFormat = 2 -> NTS: shape [numExamples, timeLength, inOutSize] - TF "time_major=false" layout - * - * - * Input arrays: - * 0: max sequence length; long/int64 scalar - * 1: input [seqLength, bS, inSize] at time t - * 2: previous/initial cell state [bS, numUnits] - * 3: previous/initial output [bS, numUnits] - * 4: Weights - concatenated (input-to-hidden, hidden-to-hidden weights) weights, [(inSize+numUnits), 4*numUnits] - * 5: weights - cell peephole (t-1) connections to input modulation gate, [numUnits] - * 6: weights - cell peephole (t-1) connections to forget gate, [numUnits] - * 7: weights - cell peephole (t) connections to output gate, [numUnits] - * 8: biases, Shape [4*numUnits] - * - * Input integer arguments: - * 0: if not zero, provide peephole connections - * 1: Data format - 0=TNS=[seqLen,mb,size]; 1=NST=[mb,size,seqLen]; 2=NTS=[mb,seqLen,size] - * - * Input float arguments: - * 0: the bias added to forget gates in order to reduce the scale of forgetting in the beginning of the training - * 1: clipping value for cell state, if it is not equal to zero, then cell state is clipped - * - * Output arrays: - * 0: i - Input modulation gate activations, rank 3, shape as per dataFormat - * 1: c (cs) - Cell state (pre tanh), rank 3, shape as per dataFormat - * 2: f - Output - forget gate activations, rank 3, shape as per dataFormat - * 3: o - Output - output gate activations, rank 3, shape as per dataFormat - * 4: z (ci) - Output - block input, rank 3, shape as per dataFormat - * 5: h (co) - Cell state, post tanh, rank 3, shape as per dataFormat - * 6: y (h) - Current cell output, rank 3, shape as per dataFormat - */ - #if NOT_EXCLUDED(OP_lstmBlock) - DECLARE_CUSTOM_OP(lstmBlock, 9, 7, false, 2, 2); - #endif +#if NOT_EXCLUDED(OP_lstmLayerCell) +DECLARE_CUSTOM_OP(lstmLayerCell, 5, 2, false, 1, 3); +#endif +#if NOT_EXCLUDED(OP_lstmLayerCell) +DECLARE_CUSTOM_OP(lstmLayerCellBp, 7, 5, false, 1, 3); +#endif - ////////////////////////////////////////////////////////////////////////// - #if NOT_EXCLUDED(OP_lstmLayer) - DECLARE_CUSTOM_OP(lstmLayer, 3, 1, false, 1, 5); - #endif +////////////////////////////////////////////////////////////////////////// +/** + * Implementation of operation for LSTM cell with optional peep hole + * connections: S. Hochreiter and J. Schmidhuber. "Long Short-Term Memory". + * Neural Computation and https://research.google.com/pubs/archive/43905.pdf + * Hasim Sak, Andrew Senior, and Francoise Beaufays. "Long short-term memory + * recurrent neural network architectures for large scale acoustic modeling." + * INTERSPEECH, 2014. See also: https://arxiv.org/pdf/1503.04069.pdf + * + * Input arrays: + * 0: input [bS, inSize] at time t + * 1: previous cell state [bS, numUnits], time t-1 + * 2: previous output [bS, numUnits], time t-1 + * 3: Weights - concatenated (input-to-hidden, hidden-to-hidden weights) + * weights, [(inSize+numUnits), 4*numUnits] 4: weights - cell peephole (t-1) + * connections to input modulation gate, [numUnits] 5: weights - cell peephole + * (t-1) connections to forget gate, [numUnits] 6: weights - cell peephole (t) + * connections to output gate, [numUnits] 7: biases, shape [4*numUnits] + * + * Input integer arguments: + * 0: if not zero, provide peephole connections + * + * Input float arguments: + * 0: the bias added to forget gates in order to reduce the scale of + * forgetting in the beginning of the training 1: clipping value for cell state, + * if it is not equal to zero, then cell state is clipped + * + * Output arrays: + * 0: i - Input modulation gate activations [bS, numUnits] + * 1: c (cs) - Cell state (pre tanh) [bs, numUnits] (cs) + * 2: f - Output - forget gate activations [bs, numUnits] + * 3: o - Output - output gate activations [bs, numUnits] + * 4: z (ci) - Output - block input [bs, numUnits] + * 5: h (co) - Cell state, post tanh [bs, numUnits] + * 6: y (h) - Current cell output [bS, numUnits], time t + */ +#if NOT_EXCLUDED(OP_lstmBlockCell) +DECLARE_CUSTOM_OP(lstmBlockCell, 8, 7, false, 2, 1); +#endif - ////////////////////////////////////////////////////////////////////////// - #if NOT_EXCLUDED(OP_lstmLayer) - DECLARE_CUSTOM_OP(lstmLayer_bp, 4, 1, false, 1, 5); - #endif +////////////////////////////////////////////////////////////////////////// +/** + * Implementation of operation for LSTM layer with optional peep hole + * connections. See lstmBlockCell for details. lstmBlockCell is used internally + * for computation. This method expects as input (and returns as output) + * sequences in one of 3 formats, depending on the data format arg: dataFormat = + * 0 -> TNS: shape [timeLength, numExamples, inOutSize] - sometimes referred to + * as "time major" dataFormat = 1 -> NST: shape [numExamples, inOutSize, + * timeLength] dataFormat = 2 -> NTS: shape [numExamples, timeLength, inOutSize] + * - TF "time_major=false" layout + * + * + * Input arrays: + * 0: max sequence length; long/int64 scalar + * 1: input [seqLength, bS, inSize] at time t + * 2: previous/initial cell state [bS, numUnits] + * 3: previous/initial output [bS, numUnits] + * 4: Weights - concatenated (input-to-hidden, hidden-to-hidden weights) + * weights, [(inSize+numUnits), 4*numUnits] 5: weights - cell peephole (t-1) + * connections to input modulation gate, [numUnits] 6: weights - cell peephole + * (t-1) connections to forget gate, [numUnits] 7: weights - cell peephole (t) + * connections to output gate, [numUnits] 8: biases, Shape [4*numUnits] + * + * Input integer arguments: + * 0: if not zero, provide peephole connections + * 1: Data format - 0=TNS=[seqLen,mb,size]; 1=NST=[mb,size,seqLen]; + * 2=NTS=[mb,seqLen,size] + * + * Input float arguments: + * 0: the bias added to forget gates in order to reduce the scale of + * forgetting in the beginning of the training 1: clipping value for cell state, + * if it is not equal to zero, then cell state is clipped + * + * Output arrays: + * 0: i - Input modulation gate activations, rank 3, shape as per + * dataFormat 1: c (cs) - Cell state (pre tanh), rank 3, shape as per dataFormat + * 2: f - Output - forget gate activations, rank 3, shape as per + * dataFormat 3: o - Output - output gate activations, rank 3, shape as per + * dataFormat 4: z (ci) - Output - block input, rank 3, shape as per dataFormat + * 5: h (co) - Cell state, post tanh, rank 3, shape as per dataFormat + * 6: y (h) - Current cell output, rank 3, shape as per dataFormat + */ +#if NOT_EXCLUDED(OP_lstmBlock) +DECLARE_CUSTOM_OP(lstmBlock, 9, 7, false, 2, 2); +#endif +////////////////////////////////////////////////////////////////////////// +#if NOT_EXCLUDED(OP_lstmLayer) +DECLARE_CUSTOM_OP(lstmLayer, 3, 1, false, 1, 5); +#endif - ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of operations for Simple Recurrent Unit cell: "Training RNNs as Fast as CNNs" Tao Lei, Yu Zhang, Yoav Artzi - * - * Input arrays: - * 0: input with shape [batchSize x inSize], batchSize - batch size, inSize - number of features - * 1: previous cell state [batchSize x inSize], that is at previous time step t-1 - * 2: weights [inSize x 3*inSize] - * 3: biases [1 x 2*inSize] - * - * Output arrays: - * 0: current cell output [batchSize x inSize], that is at current time step t - * 1: current cell state [batchSize x inSize], that is at current time step t - */ - #if NOT_EXCLUDED(OP_sruCell) - DECLARE_CUSTOM_OP(sruCell, 4, 2, false, 0, 0); - #endif +////////////////////////////////////////////////////////////////////////// +#if NOT_EXCLUDED(OP_lstmLayer) +DECLARE_CUSTOM_OP(lstmLayer_bp, 4, 1, false, 1, 5); +#endif +////////////////////////////////////////////////////////////////////////// +/** + * Implementation of operations for Simple Recurrent Unit cell: "Training RNNs + * as Fast as CNNs" Tao Lei, Yu Zhang, Yoav Artzi + * + * Input arrays: + * 0: input with shape [batchSize x inSize], batchSize - batch size, inSize - + * number of features 1: previous cell state [batchSize x inSize], that is at + * previous time step t-1 2: weights [inSize x 3*inSize] 3: biases [1 x + * 2*inSize] + * + * Output arrays: + * 0: current cell output [batchSize x inSize], that is at current time step + * t 1: current cell state [batchSize x inSize], that is at current time step t + */ +#if NOT_EXCLUDED(OP_sruCell) +DECLARE_CUSTOM_OP(sruCell, 4, 2, false, 0, 0); +#endif - ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of gated Recurrent Unit cell: - * Kyunghyun Cho, Bart van Merrienboer, Caglar Gulcehre, Dzmitry Bahdanau, Fethi Bougares, Holger Schwenk, Yoshua Bengio - * "Learning Phrase Representations using RNN Encoder-Decoder for Statistical Machine Translation" - * - * Input arrays: - * 0: input with shape [batchSize x inSize], batchSize - batch size, inSize - number of features - * 1: previous cell output [batchSize x numUnits], that is at previous time step t-1 - * 2: RU weights - [(inSize+numUnits), 2*numUnits] - reset and update gates (input/recurrent weights) - * 3: C weights - [(inSize+numUnits), numUnits] - cell gate (input/recurrent weights) - * 4: reset and update biases, [2*numUnits] - reset and update gates - * 5: cell biases, [numUnits] - * - * Output arrays: - * 0: Reset gate output [bS, numUnits] - * 1: Update gate output [bS, numUnits] - * 2: Cell gate output [bS, numUnits] - * 3: Current cell output [bS, numUnits] - */ - #if NOT_EXCLUDED(OP_gruCell) - DECLARE_CUSTOM_OP(gruCell, 6, 4, false, 0, 0); - #endif +////////////////////////////////////////////////////////////////////////// +/** + * Implementation of gated Recurrent Unit cell: + * Kyunghyun Cho, Bart van Merrienboer, Caglar Gulcehre, Dzmitry Bahdanau, + * Fethi Bougares, Holger Schwenk, Yoshua Bengio "Learning Phrase + * Representations using RNN Encoder-Decoder for Statistical Machine + * Translation" + * + * Input arrays: + * 0: input with shape [batchSize x inSize], batchSize - batch size, inSize - + * number of features 1: previous cell output [batchSize x numUnits], that is + * at previous time step t-1 2: RU weights - [(inSize+numUnits), 2*numUnits] - + * reset and update gates (input/recurrent weights) 3: C weights - + * [(inSize+numUnits), numUnits] - cell gate (input/recurrent weights) 4: reset + * and update biases, [2*numUnits] - reset and update gates 5: cell biases, + * [numUnits] + * + * Output arrays: + * 0: Reset gate output [bS, numUnits] + * 1: Update gate output [bS, numUnits] + * 2: Cell gate output [bS, numUnits] + * 3: Current cell output [bS, numUnits] + */ +#if NOT_EXCLUDED(OP_gruCell) +DECLARE_CUSTOM_OP(gruCell, 6, 4, false, 0, 0); +#endif - #if NOT_EXCLUDED(OP_gruCell) - DECLARE_CUSTOM_OP(gruCell_bp, 10, 6, false, 0, 0); - #endif +#if NOT_EXCLUDED(OP_gruCell) +DECLARE_CUSTOM_OP(gruCell_bp, 10, 6, false, 0, 0); +#endif - ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of operation "LSTM time sequences" with peep hole connections: - * - * Input arrays: - * 0: input with shape [time x batchSize x inSize], time - number of time steps, batchSize - batch size, inSize - number of features - * 1: initial cell output [batchSize x numProj], that is at time step = 0, in case of projection=false -> numProj=numUnits!!! - * 2: initial cell state [batchSize x numUnits], that is at time step = 0 - * 3: input-to-hidden weights, [inSize x 4*numUnits] - * 4: hidden-to-hidden weights, [numProj x 4*numUnits] - * 5: diagonal weights for peephole connections [3*numUnits] - * 6: projection weights [numUnits x numProj] - * 7: biases, [4*numUnits] - * - * Input integer arguments: - * 0: if not zero, provide peephole connections - * 1: if not zero, then projection is performed, if zero then numProj==numUnits is mandatory! - * - * Input float arguments: - * 0: clipping value for cell state, if it is not equal to zero, then cell state is clipped - * 1: clipping value for projected cell output, if it is not equal to zero, then projected cell output is clipped - * 2: the bias added to forget gates in order to reduce the scale of forgetting in the beginning of the training - * - * Output arrays: - * 0: cell outputs [time x batchSize x numProj], that is per each time step - * 1: cell states [time x batchSize x numUnits], that is per each time step - */ - #if NOT_EXCLUDED(OP_lstm) - DECLARE_CUSTOM_OP(lstm, 8, 2, false, 3, 2); - #endif +////////////////////////////////////////////////////////////////////////// +/** + * Implementation of operation "LSTM time sequences" with peep hole connections: + * + * Input arrays: + * 0: input with shape [time x batchSize x inSize], time - number of time + * steps, batchSize - batch size, inSize - number of features 1: initial cell + * output [batchSize x numProj], that is at time step = 0, in case of + * projection=false -> numProj=numUnits!!! 2: initial cell state [batchSize x + * numUnits], that is at time step = 0 3: input-to-hidden weights, [inSize x + * 4*numUnits] 4: hidden-to-hidden weights, [numProj x 4*numUnits] 5: diagonal + * weights for peephole connections [3*numUnits] 6: projection weights [numUnits + * x numProj] 7: biases, [4*numUnits] + * + * Input integer arguments: + * 0: if not zero, provide peephole connections + * 1: if not zero, then projection is performed, if zero then + * numProj==numUnits is mandatory! + * + * Input float arguments: + * 0: clipping value for cell state, if it is not equal to zero, then cell + * state is clipped 1: clipping value for projected cell output, if it is not + * equal to zero, then projected cell output is clipped 2: the bias added to + * forget gates in order to reduce the scale of forgetting in the beginning of + * the training + * + * Output arrays: + * 0: cell outputs [time x batchSize x numProj], that is per each time step + * 1: cell states [time x batchSize x numUnits], that is per each time step + */ +#if NOT_EXCLUDED(OP_lstm) +DECLARE_CUSTOM_OP(lstm, 8, 2, false, 3, 2); +#endif - ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of gated Recurrent Unit: - * - * Input arrays: - * 0: input with shape [time x batchSize x inSize], time - number of time steps, batchSize - batch size, inSize - number of features - * 1: initial cell output [batchSize x numUnits], that is at time step = 0 - * 2: input-to-hidden weights, [inSize x 3*numUnits] - * 3: hidden-to-hidden weights, [numUnits x 3*numUnits] - * 4: biases, [3*numUnits] - * - * Output arrays: - * 0: cell outputs [time x batchSize x numUnits], that is per each time step - */ - #if NOT_EXCLUDED(OP_gru) - DECLARE_CUSTOM_OP(gru, 5, 1, false, 0, 0); - #endif +////////////////////////////////////////////////////////////////////////// +/** + * Implementation of gated Recurrent Unit: + * + * Input arrays: + * 0: input with shape [time x batchSize x inSize], time - number of time + * steps, batchSize - batch size, inSize - number of features 1: initial cell + * output [batchSize x numUnits], that is at time step = 0 2: input-to-hidden + * weights, [inSize x 3*numUnits] 3: hidden-to-hidden weights, [numUnits x + * 3*numUnits] 4: biases, [3*numUnits] + * + * Output arrays: + * 0: cell outputs [time x batchSize x numUnits], that is per each time step + */ +#if NOT_EXCLUDED(OP_gru) +DECLARE_CUSTOM_OP(gru, 5, 1, false, 0, 0); +#endif - #if NOT_EXCLUDED(OP_gru) - DECLARE_CUSTOM_OP(gru_bp, 6, 5, false, 0, 0); - #endif +#if NOT_EXCLUDED(OP_gru) +DECLARE_CUSTOM_OP(gru_bp, 6, 5, false, 0, 0); +#endif - ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of operation "static RNN time sequences" with peep hole connections: - * - * Input arrays: - * 0: input with shape [time x batchSize x inSize], time - number of time steps, batchSize - batch size, inSize - number of features - * 1: input-to-hidden weights, [inSize x numUnits] - * 2: hidden-to-hidden weights, [numUnits x numUnits] - * 3: biases, [2*numUnits] - * 4: (optional) initial cell output [batchSize x numUnits], that is at time step = 0 - * 5: (optional) vector with shape [batchSize] containing integer values within [0,time), each element of this vector set max time step per each input in batch, this provides no calculations for time >= maxTimeStep - * - * Output arrays: - * 0: cell outputs [time x batchSize x numUnits] - * 1: cell final non-zero output [batchSize x numUnits] - */ - DECLARE_CUSTOM_OP(static_rnn, 4, 2, false, 0, 0); +////////////////////////////////////////////////////////////////////////// +/** + * Implementation of operation "static RNN time sequences" with peep hole + * connections: + * + * Input arrays: + * 0: input with shape [time x batchSize x inSize], time - number of time + * steps, batchSize - batch size, inSize - number of features 1: input-to-hidden + * weights, [inSize x numUnits] 2: hidden-to-hidden weights, [numUnits x + * numUnits] 3: biases, [2*numUnits] 4: (optional) initial cell output + * [batchSize x numUnits], that is at time step = 0 5: (optional) vector with + * shape [batchSize] containing integer values within [0,time), each element of + * this vector set max time step per each input in batch, this provides no + * calculations for time >= maxTimeStep + * + * Output arrays: + * 0: cell outputs [time x batchSize x numUnits] + * 1: cell final non-zero output [batchSize x numUnits] + */ +DECLARE_CUSTOM_OP(static_rnn, 4, 2, false, 0, 0); - ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of operation "static RNN time sequences" with peep hole connections: - * - * Input arrays: - * 0: input with shape [time x batchSize x inSize] or [batchSize x time x numUnits], time - number of time steps, batchSize - batch size, inSize - number of features - * 1: input-to-hidden weights, [inSize x numUnits] - * 2: hidden-to-hidden weights, [numUnits x numUnits] - * 3: biases, [2*numUnits] - * 4: (optional) initial cell output [batchSize x numUnits], that is at time step = 0 - * 5: (optional) vector with shape [batchSize] containing integer values within [0,time), each element of this vector set max time step per each input in batch, this provides no calculations for time >= maxTimeStep - * - * Input integer arguments: - * 0: (optional) timeMajor - if non zero then input shape is [time, batchSize, ...], else [batchSize, time, ...] - * - * Output arrays: - * 0: cell outputs [time x batchSize x numUnits] or [batchSize x time x numUnits] - * 1: cell final non-zero output [batchSize x numUnits] - */ - DECLARE_CUSTOM_OP(dynamic_rnn, 4, 2, false, 0, 0); +////////////////////////////////////////////////////////////////////////// +/** + * Implementation of operation "static RNN time sequences" with peep hole + * connections: + * + * Input arrays: + * 0: input with shape [time x batchSize x inSize] or [batchSize x time x + * numUnits], time - number of time steps, batchSize - batch size, inSize - + * number of features 1: input-to-hidden weights, [inSize x numUnits] 2: + * hidden-to-hidden weights, [numUnits x numUnits] 3: biases, [2*numUnits] 4: + * (optional) initial cell output [batchSize x numUnits], that is at time step = + * 0 5: (optional) vector with shape [batchSize] containing integer values + * within [0,time), each element of this vector set max time step per each input + * in batch, this provides no calculations for time >= maxTimeStep + * + * Input integer arguments: + * 0: (optional) timeMajor - if non zero then input shape is [time, + * batchSize, ...], else [batchSize, time, ...] + * + * Output arrays: + * 0: cell outputs [time x batchSize x numUnits] or [batchSize x time x + * numUnits] 1: cell final non-zero output [batchSize x numUnits] + */ +DECLARE_CUSTOM_OP(dynamic_rnn, 4, 2, false, 0, 0); - ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of operation "static RNN time sequences" with peep hole connections: - * - * Input arrays: - * 0: input with shape [time x batchSize x inSize], time - number of time steps, batchSize - batch size, inSize - number of features - * 1: input-to-hidden weights for forward RNN, [inSize x numUnitsFW] - * 2: hidden-to-hidden weights for forward RNN, [numUnitsFW x numUnitsFW] - * 3: biases for forward RNN, [2*numUnitsFW] - * 4: input-to-hidden weights for backward RNN, [inSize x numUnitsBW] - * 5: hidden-to-hidden weights for backward RNN, [numUnitsBW x numUnitsBW] - * 6: biases for backward RNN, [2*numUnitsBW] - * 7: (optional) initial cell output for forward RNN [batchSize x numUnitsFW], that is at time step = 0 - * 8: (optional) initial cell output for backward RNN [batchSize x numUnitsBW], that is at time step = 0 - * 9: (optional) vector with shape [batchSize] containing integer values within [0,time), each element of this vector set max time step per each input in batch, this provides no calculations for time >= maxTimeStep - * - * Output arrays: - * 0: cell outputs [time x batchSize x (numUnitsFW + numUnitsBW)] - * 1: cell final non-zero output for forward RNN [batchSize x numUnitsFW] - * 2: cell final non-zero output for backward RNN [batchSize x numUnitsBW] - */ - DECLARE_CUSTOM_OP(static_bidirectional_rnn, 7, 3, false, 0, 0); +////////////////////////////////////////////////////////////////////////// +/** + * Implementation of operation "static RNN time sequences" with peep hole + * connections: + * + * Input arrays: + * 0: input with shape [time x batchSize x inSize], time - number of time + * steps, batchSize - batch size, inSize - number of features 1: input-to-hidden + * weights for forward RNN, [inSize x numUnitsFW] 2: hidden-to-hidden weights + * for forward RNN, [numUnitsFW x numUnitsFW] 3: biases for forward RNN, + * [2*numUnitsFW] 4: input-to-hidden weights for backward RNN, [inSize x + * numUnitsBW] 5: hidden-to-hidden weights for backward RNN, [numUnitsBW x + * numUnitsBW] 6: biases for backward RNN, [2*numUnitsBW] 7: (optional) initial + * cell output for forward RNN [batchSize x numUnitsFW], that is at time step = + * 0 8: (optional) initial cell output for backward RNN [batchSize x + * numUnitsBW], that is at time step = 0 9: (optional) vector with shape + * [batchSize] containing integer values within [0,time), each element of this + * vector set max time step per each input in batch, this provides no + * calculations for time >= maxTimeStep + * + * Output arrays: + * 0: cell outputs [time x batchSize x (numUnitsFW + numUnitsBW)] + * 1: cell final non-zero output for forward RNN [batchSize x numUnitsFW] + * 2: cell final non-zero output for backward RNN [batchSize x numUnitsBW] + */ +DECLARE_CUSTOM_OP(static_bidirectional_rnn, 7, 3, false, 0, 0); - ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of operation "static RNN time sequences" with peep hole connections: - * - * Input arrays: - * 0: input with shape [time x batchSize x inSize] or [batchSize x time x inSize], time - number of time steps, batchSize - batch size, inSize - number of features - * 1: input-to-hidden weights for forward RNN, [inSize x numUnitsFW] - * 2: hidden-to-hidden weights for forward RNN, [numUnitsFW x numUnitsFW] - * 3: biases for forward RNN, [2*numUnitsFW] - * 4: input-to-hidden weights for backward RNN, [inSize x numUnitsBW] - * 5: hidden-to-hidden weights for backward RNN, [numUnitsBW x numUnitsBW] - * 6: biases for backward RNN, [2*numUnitsBW] - * 7: (optional) initial cell output for forward RNN [batchSize x numUnitsFW], that is at time step = 0 - * 8: (optional) initial cell output for backward RNN [batchSize x numUnitsBW], that is at time step = 0 - * 9: (optional) vector with shape [batchSize] containing integer values within [0,time), each element of this vector set max time step per each input in batch, this provides no calculations for time >= maxTimeStep - * - * Input integer arguments: - * 0: (optional) timeMajor - if non zero then input shape is [time, batchSize, ...], else [batchSize, time, ...] - * - * Output arrays: - * 0: cell outputs for forward RNN [time x batchSize x numUnitsFW] or [batchSize x time x numUnitsFW] - * 1: cell outputs for backward RNN [time x batchSize x numUnitsBW] or [batchSize x time x numUnitsBW] - * 2: cell final non-zero output for forward RNN [batchSize x numUnitsFW] - * 3: cell final non-zero output for backward RNN [batchSize x numUnitsBW] - */ - DECLARE_CUSTOM_OP(dynamic_bidirectional_rnn, 7, 4, false, 0, 0); +////////////////////////////////////////////////////////////////////////// +/** + * Implementation of operation "static RNN time sequences" with peep hole + * connections: + * + * Input arrays: + * 0: input with shape [time x batchSize x inSize] or [batchSize x time x + * inSize], time - number of time steps, batchSize - batch size, inSize - number + * of features 1: input-to-hidden weights for forward RNN, [inSize x + * numUnitsFW] 2: hidden-to-hidden weights for forward RNN, [numUnitsFW x + * numUnitsFW] 3: biases for forward RNN, [2*numUnitsFW] 4: input-to-hidden + * weights for backward RNN, [inSize x numUnitsBW] 5: hidden-to-hidden weights + * for backward RNN, [numUnitsBW x numUnitsBW] 6: biases for backward RNN, + * [2*numUnitsBW] 7: (optional) initial cell output for forward RNN [batchSize x + * numUnitsFW], that is at time step = 0 8: (optional) initial cell output for + * backward RNN [batchSize x numUnitsBW], that is at time step = 0 9: (optional) + * vector with shape [batchSize] containing integer values within [0,time), each + * element of this vector set max time step per each input in batch, this + * provides no calculations for time >= maxTimeStep + * + * Input integer arguments: + * 0: (optional) timeMajor - if non zero then input shape is [time, + * batchSize, ...], else [batchSize, time, ...] + * + * Output arrays: + * 0: cell outputs for forward RNN [time x batchSize x numUnitsFW] or + * [batchSize x time x numUnitsFW] 1: cell outputs for backward RNN [time x + * batchSize x numUnitsBW] or [batchSize x time x numUnitsBW] 2: cell final + * non-zero output for forward RNN [batchSize x numUnitsFW] 3: cell final + * non-zero output for backward RNN [batchSize x numUnitsBW] + */ +DECLARE_CUSTOM_OP(dynamic_bidirectional_rnn, 7, 4, false, 0, 0); -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/headers/shape.h b/libnd4j/include/ops/declarable/headers/shape.h index 7f93303429f2..66622bd1a12e 100644 --- a/libnd4j/include/ops/declarable/headers/shape.h +++ b/libnd4j/include/ops/declarable/headers/shape.h @@ -24,98 +24,97 @@ #include namespace sd { - namespace ops { - #if NOT_EXCLUDED(OP_permute) - DECLARE_CUSTOM_OP(permute, 1, 1, false, 0, -2); - #endif - - #if NOT_EXCLUDED(OP_reshapeas) - DECLARE_CUSTOM_OP(reshapeas, 2, 1, false, 0, 0); - #endif - - #if NOT_EXCLUDED(OP_transpose) - DECLARE_CUSTOM_OP(transpose, 1, 1, false, 0, 0); - #endif - - #if NOT_EXCLUDED(OP_shape_of) - DECLARE_CUSTOM_OP(shape_of, 1, 1, false, 0, 0); - #endif - - #if NOT_EXCLUDED(OP_shapes_of) - DECLARE_CUSTOM_OP(shapes_of, -1, -1, false, 0, 0); - #endif - - #if NOT_EXCLUDED(OP_squeeze) - DECLARE_CUSTOM_OP(squeeze, 1, 1, false, 0, -2); - #endif - - #if NOT_EXCLUDED(OP_expand_dims) - DECLARE_CUSTOM_OP(expand_dims, 1, 1, false, 0, -2); - #endif - - #if NOT_EXCLUDED(OP_reshape) - DECLARE_CUSTOM_OP(reshape, 1, 1, false, 0, -2); - #endif - - #if NOT_EXCLUDED(OP_size_at) - DECLARE_CUSTOM_OP(size_at, 1, 1, false, 0, 1); - #endif - - /** - * This op changes order of given array to specified order. - * In other words: C/F order switch - * - * Int args: - * 0 - isForder. set to 1 for F order output, or 0 for C order output - * - * @tparam T - */ - #if NOT_EXCLUDED(OP_order) - DECLARE_CUSTOM_OP(order, 1, 1, false, 0, 1); - #endif - - /** - * This op boosts specified input up to specified shape - * - * @tparam T - */ - #if NOT_EXCLUDED(OP_tile_to_shape) - DECLARE_CUSTOM_OP(tile_to_shape, 1, 1, false, 0, -1); - DECLARE_CUSTOM_OP(tile_to_shape_bp, 2, 1, false, 0, -1); - #endif - - /** - * This op broadcast given input up to given shape - * - * inputs: - * input array - array to be broadcasted to given shape - * shape array - array containing shape be broadcasted to - */ - #if NOT_EXCLUDED(OP_broadcast_to) - DECLARE_CUSTOM_OP(broadcast_to, 2, 1, false, 0, 0); - #endif - - - #if NOT_EXCLUDED(OP_evaluate_reduction_shape) - DECLARE_CUSTOM_OP(evaluate_reduction_shape, 2, 1, false, 0, 0); - #endif - - /** - * This operation creates new array - * Input: - * array with shape values - * - * IArgs: - * order value - * data type value - * - * BArgs: - * initialization option - */ - #if NOT_EXCLUDED(OP_create) - DECLARE_CUSTOM_OP(create, 1, 1, false, 0, 1); - #endif - } -} +namespace ops { +#if NOT_EXCLUDED(OP_permute) +DECLARE_CUSTOM_OP(permute, 1, 1, false, 0, -2); +#endif + +#if NOT_EXCLUDED(OP_reshapeas) +DECLARE_CUSTOM_OP(reshapeas, 2, 1, false, 0, 0); +#endif + +#if NOT_EXCLUDED(OP_transpose) +DECLARE_CUSTOM_OP(transpose, 1, 1, false, 0, 0); +#endif + +#if NOT_EXCLUDED(OP_shape_of) +DECLARE_CUSTOM_OP(shape_of, 1, 1, false, 0, 0); +#endif + +#if NOT_EXCLUDED(OP_shapes_of) +DECLARE_CUSTOM_OP(shapes_of, -1, -1, false, 0, 0); +#endif + +#if NOT_EXCLUDED(OP_squeeze) +DECLARE_CUSTOM_OP(squeeze, 1, 1, false, 0, -2); +#endif + +#if NOT_EXCLUDED(OP_expand_dims) +DECLARE_CUSTOM_OP(expand_dims, 1, 1, false, 0, -2); +#endif + +#if NOT_EXCLUDED(OP_reshape) +DECLARE_CUSTOM_OP(reshape, 1, 1, false, 0, -2); +#endif + +#if NOT_EXCLUDED(OP_size_at) +DECLARE_CUSTOM_OP(size_at, 1, 1, false, 0, 1); +#endif + +/** + * This op changes order of given array to specified order. + * In other words: C/F order switch + * + * Int args: + * 0 - isForder. set to 1 for F order output, or 0 for C order output + * + * @tparam T + */ +#if NOT_EXCLUDED(OP_order) +DECLARE_CUSTOM_OP(order, 1, 1, false, 0, 1); +#endif + +/** + * This op boosts specified input up to specified shape + * + * @tparam T + */ +#if NOT_EXCLUDED(OP_tile_to_shape) +DECLARE_CUSTOM_OP(tile_to_shape, 1, 1, false, 0, -1); +DECLARE_CUSTOM_OP(tile_to_shape_bp, 2, 1, false, 0, -1); +#endif + +/** + * This op broadcast given input up to given shape + * + * inputs: + * input array - array to be broadcasted to given shape + * shape array - array containing shape be broadcasted to + */ +#if NOT_EXCLUDED(OP_broadcast_to) +DECLARE_CUSTOM_OP(broadcast_to, 2, 1, false, 0, 0); +#endif + +#if NOT_EXCLUDED(OP_evaluate_reduction_shape) +DECLARE_CUSTOM_OP(evaluate_reduction_shape, 2, 1, false, 0, 0); +#endif + +/** + * This operation creates new array + * Input: + * array with shape values + * + * IArgs: + * order value + * data type value + * + * BArgs: + * initialization option + */ +#if NOT_EXCLUDED(OP_create) +DECLARE_CUSTOM_OP(create, 1, 1, false, 0, 1); +#endif +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/headers/strings.h b/libnd4j/include/ops/declarable/headers/strings.h index bd4b8b94996e..303a406a2b6d 100644 --- a/libnd4j/include/ops/declarable/headers/strings.h +++ b/libnd4j/include/ops/declarable/headers/strings.h @@ -24,19 +24,18 @@ #include namespace sd { - namespace ops { - /** - * This operation splits input string into pieces separated by delimiter - * - * Input[0] - string to split - * Input[1] - delimiter - */ - #if NOT_EXCLUDED(OP_split_string) - DECLARE_CUSTOM_OP(split_string, 2, 1, true, 0, 0); - #endif - - } -} +namespace ops { +/** + * This operation splits input string into pieces separated by delimiter + * + * Input[0] - string to split + * Input[1] - delimiter + */ +#if NOT_EXCLUDED(OP_split_string) +DECLARE_CUSTOM_OP(split_string, 2, 1, true, 0, 0); +#endif +} // namespace ops +} // namespace sd -#endif //SAMEDIFF_STRINGS_H +#endif // SAMEDIFF_STRINGS_H diff --git a/libnd4j/include/ops/declarable/headers/tests.h b/libnd4j/include/ops/declarable/headers/tests.h index cad12b3c85f9..c54d5ed90320 100644 --- a/libnd4j/include/ops/declarable/headers/tests.h +++ b/libnd4j/include/ops/declarable/headers/tests.h @@ -20,29 +20,29 @@ #include namespace sd { - namespace ops { - #if NOT_EXCLUDED(OP_test_output_reshape) - DECLARE_OP(test_output_reshape, 1, 1, true); - #endif +namespace ops { +#if NOT_EXCLUDED(OP_test_output_reshape) +DECLARE_OP(test_output_reshape, 1, 1, true); +#endif - #if NOT_EXCLUDED(OP_test_scalar) - DECLARE_CUSTOM_OP(test_scalar, 1, 1, false, 0, 0); - #endif +#if NOT_EXCLUDED(OP_test_scalar) +DECLARE_CUSTOM_OP(test_scalar, 1, 1, false, 0, 0); +#endif - #if NOT_EXCLUDED(OP_testreduction) - DECLARE_REDUCTION_OP(testreduction, 1, 1, false, 0, -1); - #endif +#if NOT_EXCLUDED(OP_testreduction) +DECLARE_REDUCTION_OP(testreduction, 1, 1, false, 0, -1); +#endif - #if NOT_EXCLUDED(OP_noop) - DECLARE_OP(noop, -2, -2, true); - #endif +#if NOT_EXCLUDED(OP_noop) +DECLARE_OP(noop, -2, -2, true); +#endif - #if NOT_EXCLUDED(OP_testop2i2o) - DECLARE_OP(testop2i2o, 2, 2, true); - #endif +#if NOT_EXCLUDED(OP_testop2i2o) +DECLARE_OP(testop2i2o, 2, 2, true); +#endif - #if NOT_EXCLUDED(OP_testcustom) - DECLARE_CUSTOM_OP(testcustom, 1, 1, false, 0, -1); - #endif - } -} \ No newline at end of file +#if NOT_EXCLUDED(OP_testcustom) +DECLARE_CUSTOM_OP(testcustom, 1, 1, false, 0, -1); +#endif +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/headers/third_party.h b/libnd4j/include/ops/declarable/headers/third_party.h index 705a02903112..c4cad9eabb0d 100644 --- a/libnd4j/include/ops/declarable/headers/third_party.h +++ b/libnd4j/include/ops/declarable/headers/third_party.h @@ -24,11 +24,11 @@ #include namespace sd { - namespace ops { - #if NOT_EXCLUDED(OP_firas_sparse) - DECLARE_CUSTOM_OP(firas_sparse, 1, 1, false, 0, -1); - #endif - } -} +namespace ops { +#if NOT_EXCLUDED(OP_firas_sparse) +DECLARE_CUSTOM_OP(firas_sparse, 1, 1, false, 0, -1); +#endif +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/headers/transforms.h b/libnd4j/include/ops/declarable/headers/transforms.h index 3fe2f1223f3e..eda225d79f31 100644 --- a/libnd4j/include/ops/declarable/headers/transforms.h +++ b/libnd4j/include/ops/declarable/headers/transforms.h @@ -24,209 +24,210 @@ #include namespace sd { - namespace ops { - #if NOT_EXCLUDED(OP_clipbyvalue) - DECLARE_CONFIGURABLE_OP(clipbyvalue, 1, 1, true, 2, 0); - #endif - - #if NOT_EXCLUDED(OP_clipbynorm) - DECLARE_CONFIGURABLE_OP(clipbynorm, 1, 1, true, 1, 0); - DECLARE_CUSTOM_OP(clipbynorm_bp, 2, 1, false, 1, 0); - #endif - - #if NOT_EXCLUDED(OP_clipbyavgnorm) - DECLARE_CONFIGURABLE_OP(clipbyavgnorm, 1, 1, true, 1, 0); - DECLARE_CUSTOM_OP(clipbyavgnorm_bp, 2, 1, false, 1, 0); - #endif - - #if NOT_EXCLUDED(OP_cumsum) - DECLARE_CONFIGURABLE_OP(cumsum, 1, 1, true, 0, 2); - #endif - - #if NOT_EXCLUDED(OP_cumprod) - DECLARE_CONFIGURABLE_OP(cumprod, 1, 1, true, 0, 2); - #endif - - #if NOT_EXCLUDED(OP_tile) - DECLARE_CUSTOM_OP(tile, 1, 1, false, 0, -2); - DECLARE_CUSTOM_OP(tile_bp, 2, 1, false, 0, -2); - #endif - - #if NOT_EXCLUDED(OP_repeat) - DECLARE_CUSTOM_OP(repeat, 1, 1, true, 0, -1); - #endif - - #if NOT_EXCLUDED(OP_invert_permutation) - DECLARE_CONFIGURABLE_OP(invert_permutation, 1, 1, false, 0, 0); - #endif - - DECLARE_CUSTOM_OP(concat, -1, 1, false, 0, 0); - DECLARE_CUSTOM_OP(concat_bp, -1, -1, false, 0, 0); - - #if NOT_EXCLUDED(OP_mergemax) - DECLARE_OP(mergemax, -1, 1, false); - DECLARE_CUSTOM_OP(mergemax_bp, 2, 1, false, 0, 0); - #endif - /* - * Complete tensor with max indices merged from all input tensors list - * - * INPUT: tensors with the same shape - * OUTPUT: integer tensor with the same shape - * INT_ARG: result type (one of int), INT32 by default - */ - #if NOT_EXCLUDED(OP_mergemaxindex) - DECLARE_CUSTOM_OP(mergemaxindex, -1, 1, false, 0, 0); - #endif - - #if NOT_EXCLUDED(OP_mergeadd) - DECLARE_OP(mergeadd, -1, 1, false); - DECLARE_CUSTOM_OP(mergeadd_bp, 2, 1, false, 0, 0); - #endif - - #if NOT_EXCLUDED(OP_mergeavg) - DECLARE_OP(mergeavg, -1, 1, false); - DECLARE_CUSTOM_OP(mergeavg_bp, 2, 1, false, 0, 0); - #endif - - #if NOT_EXCLUDED(OP_scatter_update) - DECLARE_CONFIGURABLE_OP(scatter_update, 2, 1, true, 0, -1); - #endif - - #if NOT_EXCLUDED(OP_Floor) - DECLARE_OP(Floor, 1, 1, true); - #endif - - #if NOT_EXCLUDED(OP_Log1p) - DECLARE_OP(Log1p, 2, 1, true); - #endif - - #if NOT_EXCLUDED(OP_reverse) - DECLARE_CONFIGURABLE_OP(reverse, 1, 1, true, 0, -2); - DECLARE_CUSTOM_OP(reverse_bp, 2, 1, false, 0, -2); - #endif - - #if NOT_EXCLUDED(OP_gather) - DECLARE_CUSTOM_OP(gather, 1, 1, false, 0, -2); - #endif - - #if NOT_EXCLUDED(OP_pad) - DECLARE_CUSTOM_OP(pad, 2, 1, false, 0, 1); - #endif - - /** - * creates identity 2D matrix or batch of identical 2D identity matrices - * - * Input array: - * provide some array - in any case operation simply neglects it - * - * Input float argument (if passed): - * TArgs[0] - type of elements of output array, default value is 5 (float) - * - * Input integer arguments: - * IArgs[0] - order of output identity matrix, 99 -> 'c'-order, 102 -> 'f'-order - * IArgs[1] - the number of rows in output inner-most 2D identity matrix - * IArgs[2] - optional, the number of columns in output inner-most 2D identity matrix, if this argument is not provided then it is taken to be equal to number of rows - * IArgs[3,4,...] - optional, shape of batch, output matrix will have leading batch dimensions of this shape - */ - #if NOT_EXCLUDED(OP_eye) - DECLARE_CUSTOM_OP(eye, -2, 1, false, -2, 2); - #endif - - #if NOT_EXCLUDED(OP_gather_nd) - DECLARE_CUSTOM_OP(gather_nd, 2, 1, false, 0, 0); - #endif - - #if NOT_EXCLUDED(OP_reverse_sequence) - DECLARE_CUSTOM_OP(reverse_sequence, 2, 1, false, 0, 2); - #endif - - #if NOT_EXCLUDED(OP_trace) - DECLARE_CUSTOM_OP(trace, 1, 1, false, 0, 0); - #endif - - #if NOT_EXCLUDED(OP_random_shuffle) - DECLARE_OP(random_shuffle, 1, 1, true); - #endif - - /** - * clip a list of given tensors with given average norm when needed - * - * Input: - * a list of tensors (at least one) - * - * Input floating point argument: - * clip_norm - a value that used as threshold value and norm to be used - * - * return a list of clipped tensors - * and global_norm as scalar tensor at the end - */ - #if NOT_EXCLUDED(OP_clip_by_global_norm) - DECLARE_CUSTOM_OP(clip_by_global_norm, 1, 2, true, 1, 0); - #endif - - DECLARE_CUSTOM_OP(tri, -2, 1, false, 0, 1); - - DECLARE_CUSTOM_OP(triu, 1, 1, false, 0, 0); - - DECLARE_CUSTOM_OP(triu_bp, 2, 1, false, 0, 0); - - #if NOT_EXCLUDED(OP_mirror_pad) - DECLARE_CUSTOM_OP(mirror_pad, 2, 1, false, 0, 1); - #endif - - #if NOT_EXCLUDED(OP_cumsum) - DECLARE_CUSTOM_OP(cumsum_bp, 2, -1, false, 0, 2); - #endif - - #if NOT_EXCLUDED(OP_cumprod) - DECLARE_CUSTOM_OP(cumprod_bp, 2, -21, false, 0, 2); - #endif - - - #if NOT_EXCLUDED(OP_flatten) - DECLARE_CUSTOM_OP(flatten, -1, 1, false, 0, 1); - #endif - - /** - * returns histogram (as 1D array) with fixed bins width - * - * Input arrays: - * - input array with elements to be binned into output histogram - * - range array with first element being bottom limit and second element being top limit of histogram, - please note that input_value <= range[0] will be mapped to histogram[0], input_value >= range[1] will be mapped to histogram[-1] - * - * Input integer arguments: - * nbins (optional) - number of histogram bins, default value is 100 - */ - #if NOT_EXCLUDED(OP_histogram_fixed_width) - DECLARE_CUSTOM_OP(histogram_fixed_width, 2, 1, false, 0, 0); - #endif - - - /** - * standardizes input array to be zero mean unit variance along the given axis - * - * - */ - #if NOT_EXCLUDED(OP_standardize) - DECLARE_CONFIGURABLE_OP(standardize, 1, 1, true, 0, -2); - DECLARE_CUSTOM_OP(standardize_bp, 2, 1, false, 0, -2); - #endif - - /** - * This operation calculates hash code, optionally along dimension - */ - #if NOT_EXCLUDED(OP_hashcode) - DECLARE_CUSTOM_OP(hashcode, 1, 1, false, 0, 0); - #endif - - /** - * This operation calculates number of entries per bin - */ - #if NOT_EXCLUDED(OP_histogram) - DECLARE_CUSTOM_OP(histogram, 1, 1, false, 0, 1); - #endif - } -} +namespace ops { +#if NOT_EXCLUDED(OP_clipbyvalue) +DECLARE_CONFIGURABLE_OP(clipbyvalue, 1, 1, true, 2, 0); +#endif + +#if NOT_EXCLUDED(OP_clipbynorm) +DECLARE_CONFIGURABLE_OP(clipbynorm, 1, 1, true, 1, 0); +DECLARE_CUSTOM_OP(clipbynorm_bp, 2, 1, false, 1, 0); +#endif + +#if NOT_EXCLUDED(OP_clipbyavgnorm) +DECLARE_CONFIGURABLE_OP(clipbyavgnorm, 1, 1, true, 1, 0); +DECLARE_CUSTOM_OP(clipbyavgnorm_bp, 2, 1, false, 1, 0); +#endif + +#if NOT_EXCLUDED(OP_cumsum) +DECLARE_CONFIGURABLE_OP(cumsum, 1, 1, true, 0, 2); +#endif + +#if NOT_EXCLUDED(OP_cumprod) +DECLARE_CONFIGURABLE_OP(cumprod, 1, 1, true, 0, 2); +#endif + +#if NOT_EXCLUDED(OP_tile) +DECLARE_CUSTOM_OP(tile, 1, 1, false, 0, -2); +DECLARE_CUSTOM_OP(tile_bp, 2, 1, false, 0, -2); +#endif + +#if NOT_EXCLUDED(OP_repeat) +DECLARE_CUSTOM_OP(repeat, 1, 1, true, 0, -1); +#endif + +#if NOT_EXCLUDED(OP_invert_permutation) +DECLARE_CONFIGURABLE_OP(invert_permutation, 1, 1, false, 0, 0); +#endif + +DECLARE_CUSTOM_OP(concat, -1, 1, false, 0, 0); +DECLARE_CUSTOM_OP(concat_bp, -1, -1, false, 0, 0); + +#if NOT_EXCLUDED(OP_mergemax) +DECLARE_OP(mergemax, -1, 1, false); +DECLARE_CUSTOM_OP(mergemax_bp, 2, 1, false, 0, 0); +#endif +/* + * Complete tensor with max indices merged from all input tensors list + * + * INPUT: tensors with the same shape + * OUTPUT: integer tensor with the same shape + * INT_ARG: result type (one of int), INT32 by default + */ +#if NOT_EXCLUDED(OP_mergemaxindex) +DECLARE_CUSTOM_OP(mergemaxindex, -1, 1, false, 0, 0); +#endif + +#if NOT_EXCLUDED(OP_mergeadd) +DECLARE_OP(mergeadd, -1, 1, false); +DECLARE_CUSTOM_OP(mergeadd_bp, 2, 1, false, 0, 0); +#endif + +#if NOT_EXCLUDED(OP_mergeavg) +DECLARE_OP(mergeavg, -1, 1, false); +DECLARE_CUSTOM_OP(mergeavg_bp, 2, 1, false, 0, 0); +#endif + +#if NOT_EXCLUDED(OP_scatter_update) +DECLARE_CONFIGURABLE_OP(scatter_update, 2, 1, true, 0, -1); +#endif + +#if NOT_EXCLUDED(OP_Floor) +DECLARE_OP(Floor, 1, 1, true); +#endif + +#if NOT_EXCLUDED(OP_Log1p) +DECLARE_OP(Log1p, 2, 1, true); +#endif + +#if NOT_EXCLUDED(OP_reverse) +DECLARE_CONFIGURABLE_OP(reverse, 1, 1, true, 0, -2); +DECLARE_CUSTOM_OP(reverse_bp, 2, 1, false, 0, -2); +#endif + +#if NOT_EXCLUDED(OP_gather) +DECLARE_CUSTOM_OP(gather, 1, 1, false, 0, -2); +#endif + +#if NOT_EXCLUDED(OP_pad) +DECLARE_CUSTOM_OP(pad, 2, 1, false, 0, 1); +#endif + +/** + * creates identity 2D matrix or batch of identical 2D identity matrices + * + * Input array: + * provide some array - in any case operation simply neglects it + * + * Input float argument (if passed): + * TArgs[0] - type of elements of output array, default value is 5 (float) + * + * Input integer arguments: + * IArgs[0] - order of output identity matrix, 99 -> 'c'-order, 102 -> + * 'f'-order IArgs[1] - the number of rows in output inner-most 2D + * identity matrix IArgs[2] - optional, the number of columns in output + * inner-most 2D identity matrix, if this argument is not provided then it is + * taken to be equal to number of rows IArgs[3,4,...] - optional, shape of + * batch, output matrix will have leading batch dimensions of this shape + */ +#if NOT_EXCLUDED(OP_eye) +DECLARE_CUSTOM_OP(eye, -2, 1, false, -2, 2); +#endif + +#if NOT_EXCLUDED(OP_gather_nd) +DECLARE_CUSTOM_OP(gather_nd, 2, 1, false, 0, 0); +#endif + +#if NOT_EXCLUDED(OP_reverse_sequence) +DECLARE_CUSTOM_OP(reverse_sequence, 2, 1, false, 0, 2); +#endif + +#if NOT_EXCLUDED(OP_trace) +DECLARE_CUSTOM_OP(trace, 1, 1, false, 0, 0); +#endif + +#if NOT_EXCLUDED(OP_random_shuffle) +DECLARE_OP(random_shuffle, 1, 1, true); +#endif + +/** + * clip a list of given tensors with given average norm when needed + * + * Input: + * a list of tensors (at least one) + * + * Input floating point argument: + * clip_norm - a value that used as threshold value and norm to be used + * + * return a list of clipped tensors + * and global_norm as scalar tensor at the end + */ +#if NOT_EXCLUDED(OP_clip_by_global_norm) +DECLARE_CUSTOM_OP(clip_by_global_norm, 1, 2, true, 1, 0); +#endif + +DECLARE_CUSTOM_OP(tri, -2, 1, false, 0, 1); + +DECLARE_CUSTOM_OP(triu, 1, 1, false, 0, 0); + +DECLARE_CUSTOM_OP(triu_bp, 2, 1, false, 0, 0); + +#if NOT_EXCLUDED(OP_mirror_pad) +DECLARE_CUSTOM_OP(mirror_pad, 2, 1, false, 0, 1); +#endif + +#if NOT_EXCLUDED(OP_cumsum) +DECLARE_CUSTOM_OP(cumsum_bp, 2, -1, false, 0, 2); +#endif + +#if NOT_EXCLUDED(OP_cumprod) +DECLARE_CUSTOM_OP(cumprod_bp, 2, -21, false, 0, 2); +#endif + +#if NOT_EXCLUDED(OP_flatten) +DECLARE_CUSTOM_OP(flatten, -1, 1, false, 0, 1); +#endif + +/** + * returns histogram (as 1D array) with fixed bins width + * + * Input arrays: + * - input array with elements to be binned into output histogram + * - range array with first element being bottom limit and second element being + top limit of histogram, please note that input_value <= range[0] will be mapped + to histogram[0], input_value >= range[1] will be mapped to histogram[-1] + * + * Input integer arguments: + * nbins (optional) - number of histogram bins, default value is 100 + */ +#if NOT_EXCLUDED(OP_histogram_fixed_width) +DECLARE_CUSTOM_OP(histogram_fixed_width, 2, 1, false, 0, 0); +#endif + +/** + * standardizes input array to be zero mean unit variance along the given axis + * + * + */ +#if NOT_EXCLUDED(OP_standardize) +DECLARE_CONFIGURABLE_OP(standardize, 1, 1, true, 0, -2); +DECLARE_CUSTOM_OP(standardize_bp, 2, 1, false, 0, -2); +#endif + +/** + * This operation calculates hash code, optionally along dimension + */ +#if NOT_EXCLUDED(OP_hashcode) +DECLARE_CUSTOM_OP(hashcode, 1, 1, false, 0, 0); +#endif + +/** + * This operation calculates number of entries per bin + */ +#if NOT_EXCLUDED(OP_histogram) +DECLARE_CUSTOM_OP(histogram, 1, 1, false, 0, 1); +#endif +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/headers/updaters.h b/libnd4j/include/ops/declarable/headers/updaters.h index dc08ff1f2116..c7e6f177505c 100644 --- a/libnd4j/include/ops/declarable/headers/updaters.h +++ b/libnd4j/include/ops/declarable/headers/updaters.h @@ -14,197 +14,194 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - // - // @author Oleh Semeniv (oleg.semeniv@gmail.com) - // - +// +// @author Oleh Semeniv (oleg.semeniv@gmail.com) +// #ifndef LIBND4J_HEADERS_UPDATERS_H #define LIBND4J_HEADERS_UPDATERS_H -#include -#include -#include #include +#include +#include +#include #include - namespace sd { - namespace ops { - +namespace ops { - /** - * SGD updater - * Input arrays: - * 0 - input array with gradients. - * Optional: - * 1 - scalar learning rate value - * Optional: - * T args - * 0 - scalar learning rate value - */ +/** + * SGD updater + * Input arrays: + * 0 - input array with gradients. + * Optional: + * 1 - scalar learning rate value + * Optional: + * T args + * 0 - scalar learning rate value + */ #if NOT_EXCLUDED(OP_sgd_updater) - DECLARE_CONFIGURABLE_OP(sgd_updater, 1, 1, true, 0, 0); +DECLARE_CONFIGURABLE_OP(sgd_updater, 1, 1, true, 0, 0); #endif - /** - * RmsPropUpdater updater - * Input arrays: - * 0 - input array with gradients. - * 1 - Initial state - * Optional: - * 2 - scalar learning rate value - * 3 - scalar rms decay - * 4 - epsilon - * Optional: - * T args - * 0 - scalar learning rate value - * 1 - scalar rms decay - * 2 - epsilon - */ +/** + * RmsPropUpdater updater + * Input arrays: + * 0 - input array with gradients. + * 1 - Initial state + * Optional: + * 2 - scalar learning rate value + * 3 - scalar rms decay + * 4 - epsilon + * Optional: + * T args + * 0 - scalar learning rate value + * 1 - scalar rms decay + * 2 - epsilon + */ #if NOT_EXCLUDED(OP_rms_prop_updater) - DECLARE_CONFIGURABLE_OP(rms_prop_updater, 2, 2, true, 0, 0); +DECLARE_CONFIGURABLE_OP(rms_prop_updater, 2, 2, true, 0, 0); #endif - // AdaGrad - /* Input arrays : - * 0 - input array with gradients. - * 1 - historical grad state - * Optional : - * 2 - scalar learning rate value - * 3 - epsilon - * Optional: - * T args - * 0 - scalar learning rate value - * 1 - epsilon - */ +// AdaGrad +/* Input arrays : + * 0 - input array with gradients. + * 1 - historical grad state + * Optional : + * 2 - scalar learning rate value + * 3 - epsilon + * Optional: + * T args + * 0 - scalar learning rate value + * 1 - epsilon + */ #if NOT_EXCLUDED(OP_ada_grad_updater) - DECLARE_CONFIGURABLE_OP(ada_grad_updater, 2, 2, true, 0, 0); +DECLARE_CONFIGURABLE_OP(ada_grad_updater, 2, 2, true, 0, 0); #endif - // AdaMax - /* Input arrays : - * 0 - input array with gradients. - * 1 - gradient state V - * 2 - gradient state M - * Optional : - * 3 - scalar learning rate value - * 4 - beta 1 value - * 5 - beta 2 value - * 6 - epsilon - * Optional: - * T args - * 0 - scalar learning rate value - * 1 - beta 1 value - * 2 - beta 2 value - * 3 - epsilon - * Optional: - * I args - * 0 - iteration - */ +// AdaMax +/* Input arrays : + * 0 - input array with gradients. + * 1 - gradient state V + * 2 - gradient state M + * Optional : + * 3 - scalar learning rate value + * 4 - beta 1 value + * 5 - beta 2 value + * 6 - epsilon + * Optional: + * T args + * 0 - scalar learning rate value + * 1 - beta 1 value + * 2 - beta 2 value + * 3 - epsilon + * Optional: + * I args + * 0 - iteration + */ #if NOT_EXCLUDED(OP_ada_max_updater) - DECLARE_CONFIGURABLE_OP(ada_max_updater, 3, 3, true, 0, 0); +DECLARE_CONFIGURABLE_OP(ada_max_updater, 3, 3, true, 0, 0); #endif - // Nesterov's momentum - /* Input arrays : - * 0 - input array with gradients. - * 1 - V grad state - * Optional : - * 2 - scalar learning rate value - * 3 - scalar momentum value - * Optional: - * T args - * 0 - learning rate value - * 1 - momentum value - */ +// Nesterov's momentum +/* Input arrays : + * 0 - input array with gradients. + * 1 - V grad state + * Optional : + * 2 - scalar learning rate value + * 3 - scalar momentum value + * Optional: + * T args + * 0 - learning rate value + * 1 - momentum value + */ #if NOT_EXCLUDED(OP_nesterovs_updater) - DECLARE_CONFIGURABLE_OP(nesterovs_updater, 2, 2, true, 0, 0); +DECLARE_CONFIGURABLE_OP(nesterovs_updater, 2, 2, true, 0, 0); #endif - // Adam - /* Input arrays : - * 0 - input array with gradients. - * 1 - gradient state V - * 2 - gradient state M - * Optional : - * 3 - scalar learning rate value - * 4 - beta 1 value - * 5 - beta 2 value - * 6 - epsilon - * Optional: - * T args - * 0 - scalar learning rate value - * 1 - beta 1 value - * 2 - beta 2 value - * 3 - epsilon - * Optional: - * I args - * 0 - iteration - */ +// Adam +/* Input arrays : + * 0 - input array with gradients. + * 1 - gradient state V + * 2 - gradient state M + * Optional : + * 3 - scalar learning rate value + * 4 - beta 1 value + * 5 - beta 2 value + * 6 - epsilon + * Optional: + * T args + * 0 - scalar learning rate value + * 1 - beta 1 value + * 2 - beta 2 value + * 3 - epsilon + * Optional: + * I args + * 0 - iteration + */ #if NOT_EXCLUDED(OP_adam_updater) - DECLARE_CONFIGURABLE_OP(adam_updater, 3, 3, true, 0, 0); +DECLARE_CONFIGURABLE_OP(adam_updater, 3, 3, true, 0, 0); #endif - // AdaDelta - /* Input arrays : - * 0 - input array with gradients. - * 1 - gradient state V - * 2 - gradient state M - * Optional : - * 3 - rho value - * 6 - epsilon - * Optional: - * T args - * 0 - rho - * 1 - epsilon - */ +// AdaDelta +/* Input arrays : + * 0 - input array with gradients. + * 1 - gradient state V + * 2 - gradient state M + * Optional : + * 3 - rho value + * 6 - epsilon + * Optional: + * T args + * 0 - rho + * 1 - epsilon + */ #if NOT_EXCLUDED(OP_ada_delta_updater) - DECLARE_CONFIGURABLE_OP(ada_delta_updater, 3, 3, true, 0, 0); +DECLARE_CONFIGURABLE_OP(ada_delta_updater, 3, 3, true, 0, 0); #endif - // Nadam - /* Input arrays : - * 0 - input array with gradients. - * 1 - gradient state V - * 2 - gradient state M - * Optional : - * 3 - scalar learning rate value - * 4 - beta 1 value - * 5 - beta 2 value - * 6 - epsilon - * Optional: - * T args - * 0 - scalar learning rate value - * 1 - beta 1 value - * 2 - beta 2 value - * 3 - epsilon - * Optional: - * I args - * 0 - iteration - */ +// Nadam +/* Input arrays : + * 0 - input array with gradients. + * 1 - gradient state V + * 2 - gradient state M + * Optional : + * 3 - scalar learning rate value + * 4 - beta 1 value + * 5 - beta 2 value + * 6 - epsilon + * Optional: + * T args + * 0 - scalar learning rate value + * 1 - beta 1 value + * 2 - beta 2 value + * 3 - epsilon + * Optional: + * I args + * 0 - iteration + */ #if NOT_EXCLUDED(OP_nadam_updater) - DECLARE_CONFIGURABLE_OP(nadam_updater, 3, 3, true, 0, 0); -#endif - // AmsGrad - /* Input arrays : - * 0 - input array with gradients. - * 1 - gradient state V - sqrd gradients - * 2 - gradient state M - moving avg - * 3 - gradient state H - max - * Optional : - * 4 - scalar learning rate value - * 5 - beta 1 value - * 6 - beta 2 value - * 7 - epsilon - * Optional: - * T args - * 0 - scalar learning rate value - * 1 - beta 1 value - * 2 - beta 2 value - * 3 - epsilon - * Optional: - * I args - * 0 - iteration - */ +DECLARE_CONFIGURABLE_OP(nadam_updater, 3, 3, true, 0, 0); +#endif +// AmsGrad +/* Input arrays : + * 0 - input array with gradients. + * 1 - gradient state V - sqrd gradients + * 2 - gradient state M - moving avg + * 3 - gradient state H - max + * Optional : + * 4 - scalar learning rate value + * 5 - beta 1 value + * 6 - beta 2 value + * 7 - epsilon + * Optional: + * T args + * 0 - scalar learning rate value + * 1 - beta 1 value + * 2 - beta 2 value + * 3 - epsilon + * Optional: + * I args + * 0 - iteration + */ #if NOT_EXCLUDED(OP_ams_grad_updater) - DECLARE_CONFIGURABLE_OP(ams_grad_updater, 4, 4, true, 0, 0); -#endif -} -} +DECLARE_CONFIGURABLE_OP(ams_grad_updater, 4, 4, true, 0, 0); +#endif +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/headers/util.h b/libnd4j/include/ops/declarable/headers/util.h index 57b013f298b5..a32a700bcd09 100644 --- a/libnd4j/include/ops/declarable/headers/util.h +++ b/libnd4j/include/ops/declarable/headers/util.h @@ -24,21 +24,21 @@ #include namespace sd { - namespace ops { - /** - * This operation prints out NDArray content, either on host or device. - */ - #if NOT_EXCLUDED(OP_print_variable) - DECLARE_CUSTOM_OP(print_variable, 1, 1, true, 0, 0); - #endif +namespace ops { +/** + * This operation prints out NDArray content, either on host or device. + */ +#if NOT_EXCLUDED(OP_print_variable) +DECLARE_CUSTOM_OP(print_variable, 1, 1, true, 0, 0); +#endif - /** - * This operation prints out affinity & locality status of given NDArray - */ - #if NOT_EXCLUDED(OP_print_affinity) - DECLARE_CUSTOM_OP(print_affinity, 1, 1, true, 0, 0); - #endif - } -} +/** + * This operation prints out affinity & locality status of given NDArray + */ +#if NOT_EXCLUDED(OP_print_affinity) +DECLARE_CUSTOM_OP(print_affinity, 1, 1, true, 0, 0); +#endif +} // namespace ops +} // namespace sd -#endif //LIBND4J_UTILS_H +#endif // LIBND4J_UTILS_H diff --git a/libnd4j/include/ops/declarable/helpers/BarnesHutTsne.h b/libnd4j/include/ops/declarable/helpers/BarnesHutTsne.h index f52dd9ba44ac..bb03c96fa66b 100644 --- a/libnd4j/include/ops/declarable/helpers/BarnesHutTsne.h +++ b/libnd4j/include/ops/declarable/helpers/BarnesHutTsne.h @@ -27,15 +27,22 @@ namespace sd { namespace ops { namespace helpers { - Nd4jLong barnes_row_count(const NDArray* rowP, const NDArray* colP, Nd4jLong N, NDArray& rowCounts); - void barnes_symmetrize(const NDArray* rowP, const NDArray* colP, const NDArray* valP, Nd4jLong N, NDArray* outputRows, NDArray* outputCols, NDArray* outputVals, NDArray* rowCounts = nullptr); - void barnes_edge_forces(const NDArray* rowP, NDArray const* colP, NDArray const* valP, int N, NDArray* output, NDArray const& data); - void barnes_gains(NDArray* input, NDArray* gradX, NDArray* epsilon, NDArray* output); - bool cell_contains(NDArray* corner, NDArray* width, NDArray* point, Nd4jLong dimension); - -} -} -} - - -#endif //LIBND4J_ACTIVATIONS_H +Nd4jLong barnes_row_count(const NDArray* rowP, const NDArray* colP, Nd4jLong N, + NDArray& rowCounts); +void barnes_symmetrize(const NDArray* rowP, const NDArray* colP, + const NDArray* valP, Nd4jLong N, NDArray* outputRows, + NDArray* outputCols, NDArray* outputVals, + NDArray* rowCounts = nullptr); +void barnes_edge_forces(const NDArray* rowP, NDArray const* colP, + NDArray const* valP, int N, NDArray* output, + NDArray const& data); +void barnes_gains(NDArray* input, NDArray* gradX, NDArray* epsilon, + NDArray* output); +bool cell_contains(NDArray* corner, NDArray* width, NDArray* point, + Nd4jLong dimension); + +} // namespace helpers +} // namespace ops +} // namespace sd + +#endif // LIBND4J_ACTIVATIONS_H diff --git a/libnd4j/include/ops/declarable/helpers/activations.h b/libnd4j/include/ops/declarable/helpers/activations.h index ab652ab24fe1..9065464ebd62 100644 --- a/libnd4j/include/ops/declarable/helpers/activations.h +++ b/libnd4j/include/ops/declarable/helpers/activations.h @@ -27,26 +27,37 @@ namespace sd { namespace ops { namespace helpers { - ND4J_EXPORT void softMaxForVector(sd::LaunchContext * context, const NDArray &input, NDArray &output); +SD_EXPORT void softMaxForVector(sd::LaunchContext *context, + const NDArray &input, NDArray &output); - ND4J_EXPORT void logSoftMaxForVector(sd::LaunchContext * context, const NDArray &input, NDArray &output); +SD_EXPORT void logSoftMaxForVector(sd::LaunchContext *context, + const NDArray &input, NDArray &output); - ND4J_EXPORT void softmax(sd::LaunchContext * context, const NDArray &input, NDArray &output, const int dimension); +SD_EXPORT void softmax(sd::LaunchContext *context, const NDArray &input, + NDArray &output, const int dimension); - ND4J_EXPORT void logSoftmax(sd::LaunchContext * context, const NDArray &input, NDArray &output, const int dimension); +SD_EXPORT void logSoftmax(sd::LaunchContext *context, const NDArray &input, + NDArray &output, const int dimension); - ND4J_EXPORT void softmaxDerivative(sd::LaunchContext * context, const NDArray& input, NDArray& output, const int dimension); +SD_EXPORT void softmaxDerivative(sd::LaunchContext *context, + const NDArray &input, NDArray &output, + const int dimension); - ND4J_EXPORT void prelu(sd::LaunchContext * context, const NDArray &input, const NDArray &alpha, NDArray &output); +SD_EXPORT void prelu(sd::LaunchContext *context, const NDArray &input, + const NDArray &alpha, NDArray &output); - ND4J_EXPORT void preluBP(sd::LaunchContext * context, const NDArray &input, const NDArray &alpha, const NDArray &dLdO, NDArray &dLdI, NDArray &dLdA); +SD_EXPORT void preluBP(sd::LaunchContext *context, const NDArray &input, + const NDArray &alpha, const NDArray &dLdO, NDArray &dLdI, + NDArray &dLdA); - ND4J_EXPORT void thresholdRelu(sd::LaunchContext * context, const NDArray &input, double threshold, NDArray &output); +SD_EXPORT void thresholdRelu(sd::LaunchContext *context, const NDArray &input, + double threshold, NDArray &output); - ND4J_EXPORT void thresholdReluDerivative(sd::LaunchContext * context, NDArray *input, double threshold, NDArray* dLdO, NDArray *output); -} -} -} +SD_EXPORT void thresholdReluDerivative(sd::LaunchContext *context, + NDArray *input, double threshold, + NDArray *dLdO, NDArray *output); +} // namespace helpers +} // namespace ops +} // namespace sd - -#endif //LIBND4J_ACTIVATIONS_H +#endif // LIBND4J_ACTIVATIONS_H diff --git a/libnd4j/include/ops/declarable/helpers/addBias.h b/libnd4j/include/ops/declarable/helpers/addBias.h index 8eff731f7401..3760d90999d5 100644 --- a/libnd4j/include/ops/declarable/helpers/addBias.h +++ b/libnd4j/include/ops/declarable/helpers/addBias.h @@ -21,20 +21,18 @@ #ifndef LIBND4J_ADDBIAS_H #define LIBND4J_ADDBIAS_H -#include #include +#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { - - void addBias(graph::Context& block, const NDArray& input, const NDArray& bias, NDArray& output, const bool isNCHW); - +void addBias(graph::Context& block, const NDArray& input, const NDArray& bias, + NDArray& output, const bool isNCHW); } -} -} - +} // namespace ops +} // namespace sd -#endif // LIBND4J_ADDBIAS_H +#endif // LIBND4J_ADDBIAS_H diff --git a/libnd4j/include/ops/declarable/helpers/adjust_hue.h b/libnd4j/include/ops/declarable/helpers/adjust_hue.h index 2d0e2f08762d..3dda2bcd90d6 100644 --- a/libnd4j/include/ops/declarable/helpers/adjust_hue.h +++ b/libnd4j/include/ops/declarable/helpers/adjust_hue.h @@ -20,114 +20,108 @@ // @author Oleh Semeniv (oleg.semeniv@gmail.com) // -#include #include +#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { - - void adjustHue(sd::LaunchContext* context, const NDArray *input, const NDArray* deltaScalarArr, NDArray *output, const int dimC); - - +void adjustHue(sd::LaunchContext* context, const NDArray* input, + const NDArray* deltaScalarArr, NDArray* output, const int dimC); //////////////////////////////////////////////////////////////////////////////// template -FORCEINLINE _CUDA_HD void rgbToHsv(const T& r, const T& g, const T& b, T& h, T& s, T& v) { - - // h values are in range [0, 360) - // s and v values are in range [0, 1] - - const T max = sd::math::nd4j_max(r, sd::math::nd4j_max(g, b)); - const T min = sd::math::nd4j_min(r, sd::math::nd4j_min(g, b)); - const T c = max - min; - const T _p6 = (T)1 / (T)6; - // calculate h - if(c == 0) { - h = 0; - } - else if(max == r) { - h = _p6 * ((g - b) / c) + (g >= b ? (T)0 : (T)1); - } - else if(max == g) { - h = _p6 * ((b - r) / c + (T)2); - } - else { // max == b - h = _p6 * ((r - g) / c + (T)4); - } - - // calculate s - s = max == (T)0 ? (T)0 : c / max; - - // calculate v - v = max;// / 255.f; +FORCEINLINE _CUDA_HD void rgbToHsv(const T& r, const T& g, const T& b, T& h, + T& s, T& v) { + // h values are in range [0, 360) + // s and v values are in range [0, 1] + + const T max = sd::math::nd4j_max(r, sd::math::nd4j_max(g, b)); + const T min = sd::math::nd4j_min(r, sd::math::nd4j_min(g, b)); + const T c = max - min; + const T _p6 = (T)1 / (T)6; + // calculate h + if (c == 0) { + h = 0; + } else if (max == r) { + h = _p6 * ((g - b) / c) + (g >= b ? (T)0 : (T)1); + } else if (max == g) { + h = _p6 * ((b - r) / c + (T)2); + } else { // max == b + h = _p6 * ((r - g) / c + (T)4); + } + + // calculate s + s = max == (T)0 ? (T)0 : c / max; + + // calculate v + v = max; // / 255.f; } //////////////////////////////////////////////////////////////////////////////// template -FORCEINLINE _CUDA_HD void hsvToRgb(const T& h, const T& s, const T& v, T& r, T& g, T& b) { - - const float sector = h * 6.f; - const T c = v * s; - - if(0.f <= sector && sector < 1.f) { - r = v; - g = v - c * (1 - sector); - b = v - c; - } - else if(1.f <= sector && sector < 2.f) { - r = v - c * (sector - 1); - g = v; - b = v - c; - } - else if(2.f <= sector && sector < 3.f) { - r = v - c; - g = v; - b = v - c * (3 - sector); - } - else if(3.f <= sector && sector < 4.f) { - r = v - c; - g = v - c * (sector - 3); - b = v; - } - else if(4.f <= sector && sector < 5.f) { - r = v - c * (5 - sector); - g = v - c; - b = v; - } - else { // 5.f <= sector < 6.f - r = v; - g = v - c; - b = v - c * (sector - 5); - } - -// r *= 255; -// g *= 255; -// b *= 255; +FORCEINLINE _CUDA_HD void hsvToRgb(const T& h, const T& s, const T& v, T& r, + T& g, T& b) { + const float sector = h * 6.f; + const T c = v * s; + + if (0.f <= sector && sector < 1.f) { + r = v; + g = v - c * (1 - sector); + b = v - c; + } else if (1.f <= sector && sector < 2.f) { + r = v - c * (sector - 1); + g = v; + b = v - c; + } else if (2.f <= sector && sector < 3.f) { + r = v - c; + g = v; + b = v - c * (3 - sector); + } else if (3.f <= sector && sector < 4.f) { + r = v - c; + g = v - c * (sector - 3); + b = v; + } else if (4.f <= sector && sector < 5.f) { + r = v - c * (5 - sector); + g = v - c; + b = v; + } else { // 5.f <= sector < 6.f + r = v; + g = v - c; + b = v - c * (sector - 5); + } + + // r *= 255; + // g *= 255; + // b *= 255; } //////////////////////////////////////////////////////////////////////////////// template -FORCEINLINE _CUDA_HD void rgbYuv(const T& r, const T& g, const T& b, T& y, T& u, T& v) { - y = static_cast(0.299) * r + static_cast(0.587) *g + static_cast(0.114) * b; - u = -static_cast(0.14714119) * r - static_cast(0.2888691) * g + static_cast(0.43601035) * b; - v = static_cast(0.61497538) * r - static_cast(0.51496512) * g - static_cast(0.10001026) * b; +FORCEINLINE _CUDA_HD void rgbYuv(const T& r, const T& g, const T& b, T& y, T& u, + T& v) { + y = static_cast(0.299) * r + static_cast(0.587) * g + + static_cast(0.114) * b; + u = -static_cast(0.14714119) * r - static_cast(0.2888691) * g + + static_cast(0.43601035) * b; + v = static_cast(0.61497538) * r - static_cast(0.51496512) * g - + static_cast(0.10001026) * b; } //////////////////////////////////////////////////////////////////////////////// template -FORCEINLINE _CUDA_HD void yuvRgb(const T& y, const T& u, const T& v, T& r, T& g, T& b) { - r = y + static_cast(1.13988303) * v; - g = y - static_cast(0.394642334) * u - static_cast(0.58062185) * v; - b = y + static_cast(2.03206185) * u; +FORCEINLINE _CUDA_HD void yuvRgb(const T& y, const T& u, const T& v, T& r, T& g, + T& b) { + r = y + static_cast(1.13988303) * v; + g = y - static_cast(0.394642334) * u - static_cast(0.58062185) * v; + b = y + static_cast(2.03206185) * u; } /*//////////////////////////////////////////////////////////////////////////////// template -static FORCEINLINE _CUDA_HD void rgb_to_hv(T r, T g, T b, T* h, T* v_min, T* v_max) { - T v_mid; - int h_category; +static FORCEINLINE _CUDA_HD void rgb_to_hv(T r, T g, T b, T* h, T* v_min, T* +v_max) { T v_mid; int h_category; // According to the figures in: // https://en.wikipedia.org/wiki/HSL_and_HSV#Hue_and_chroma // For the conditions, we don't care about the case where two components are @@ -185,12 +179,9 @@ static FORCEINLINE _CUDA_HD void rgb_to_hv(T r, T g, T b, T* h, T* v_min, T* v_m //////////////////////////////////////////////////////////////////////////////// template -static FORCEINLINE _CUDA_HD void hv_to_rgb(T h, T v_min, T v_max, T* r, T* g, T* b) { - int h_category = static_cast(h); - T ratio = h - (T)h_category; - bool increase = ((h_category & 0x1) == 0); - if (!increase) - ratio = 1 - ratio; +static FORCEINLINE _CUDA_HD void hv_to_rgb(T h, T v_min, T v_max, T* r, T* g, T* +b) { int h_category = static_cast(h); T ratio = h - (T)h_category; bool +increase = ((h_category & 0x1) == 0); if (!increase) ratio = 1 - ratio; T v_mid = v_min + ratio * (v_max - v_min); // According to the figures in: @@ -230,6 +221,6 @@ static FORCEINLINE _CUDA_HD void hv_to_rgb(T h, T v_min, T v_max, T* r, T* g, T* } */ -} -} -} \ No newline at end of file +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/adjust_saturation.h b/libnd4j/include/ops/declarable/helpers/adjust_saturation.h index 25dc30f102ad..2f95d8643a24 100644 --- a/libnd4j/include/ops/declarable/helpers/adjust_saturation.h +++ b/libnd4j/include/ops/declarable/helpers/adjust_saturation.h @@ -19,25 +19,24 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include -#include #include +#include +#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { - void adjustSaturation(sd::LaunchContext* context, const NDArray *input, const NDArray* factorScalarArr, NDArray *output, const int dimC); +void adjustSaturation(sd::LaunchContext* context, const NDArray* input, + const NDArray* factorScalarArr, NDArray* output, + const int dimC); /* template - static FORCEINLINE _CUDA_HD void rgb_to_hsv(T r, T g, T b, T* h, T* s, T* v) { - T vv = sd::math::nd4j_max(r, sd::math::nd4j_max(g, b)); - T range = vv - sd::math::nd4j_min(r, sd::math::nd4j_min(g, b)); - if (vv > 0) { - *s = range / vv; - } else { - *s = 0; + static FORCEINLINE _CUDA_HD void rgb_to_hsv(T r, T g, T b, T* h, T* s, T* v) + { T vv = sd::math::nd4j_max(r, sd::math::nd4j_max(g, b)); T range = vv + - sd::math::nd4j_min(r, sd::math::nd4j_min(g, b)); if (vv > 0) { *s = + range / vv; } else { *s = 0; } T norm = 1.0f / (6.0f * range); T hh; @@ -59,15 +58,9 @@ namespace helpers { } template - static FORCEINLINE _CUDA_HD void hsv_to_rgb(T h, T s, T v, T* r, T* g, T* b) { - T c = s * v; - T m = v - c; - T dh = h * 6; - T rr, gg, bb; - int h_category = static_cast(dh); - T fmodu = dh; - while (fmodu <= (T) 0) - fmodu += (T) 2.0f; + static FORCEINLINE _CUDA_HD void hsv_to_rgb(T h, T s, T v, T* r, T* g, T* b) + { T c = s * v; T m = v - c; T dh = h * 6; T rr, gg, bb; int h_category = + static_cast(dh); T fmodu = dh; while (fmodu <= (T) 0) fmodu += (T) 2.0f; while (fmodu >= (T) 2.0f) fmodu -= (T) 2.0f; @@ -116,6 +109,6 @@ namespace helpers { } */ -} -} -} \ No newline at end of file +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/axis.h b/libnd4j/include/ops/declarable/helpers/axis.h index 76c5a070ddc0..8379b4591b23 100644 --- a/libnd4j/include/ops/declarable/helpers/axis.h +++ b/libnd4j/include/ops/declarable/helpers/axis.h @@ -19,20 +19,20 @@ // #ifndef __AXIS_H_HELPERS__ #define __AXIS_H_HELPERS__ -#include #include +#include namespace sd { namespace ops { namespace helpers { - /* - * adjustAxis routines: adjust data with output to non-negative values. - * */ - void adjustAxis(Nd4jLong rank, NDArray* axisVector, std::vector& output); - void adjustAxis(Nd4jLong rank, std::vector& output); +/* + * adjustAxis routines: adjust data with output to non-negative values. + * */ +void adjustAxis(Nd4jLong rank, NDArray* axisVector, std::vector& output); +void adjustAxis(Nd4jLong rank, std::vector& output); -} -} -} +} // namespace helpers +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/helpers/batched_gemm.h b/libnd4j/include/ops/declarable/helpers/batched_gemm.h index 26651cf3c63e..63e18a1df15a 100644 --- a/libnd4j/include/ops/declarable/helpers/batched_gemm.h +++ b/libnd4j/include/ops/declarable/helpers/batched_gemm.h @@ -18,15 +18,19 @@ // @author raver119@gmail.com // -#include #include +#include + namespace sd { namespace ops { namespace helpers { -void bgemm(const std::vector& vA, const std::vector& vB, std::vector& vC, const NDArray* alphas, const NDArray* betas, int transA, int transB, int M, int N, int K, const int lda, const int ldb, const int ldc); +void bgemm(const std::vector& vA, const std::vector& vB, + std::vector& vC, const NDArray* alphas, + const NDArray* betas, int transA, int transB, int M, int N, int K, + const int lda, const int ldb, const int ldc); } -} -} \ No newline at end of file +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/batchnorm.h b/libnd4j/include/ops/declarable/helpers/batchnorm.h index 72bc69718e37..62fc8baefc64 100644 --- a/libnd4j/include/ops/declarable/helpers/batchnorm.h +++ b/libnd4j/include/ops/declarable/helpers/batchnorm.h @@ -23,17 +23,17 @@ #include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { +void batchnorm(const NDArray* input, const NDArray* mean, + const NDArray* variance, const NDArray* gamma, + const NDArray* beta, NDArray* output, + const std::vector& axes, const double epsilon); - void batchnorm(const NDArray* input, const NDArray* mean, const NDArray* variance, const NDArray* gamma, const NDArray* beta, NDArray* output, const std::vector& axes, const double epsilon); - - -} } -} - +} // namespace ops +} // namespace sd -#endif //LIBND4J_BATCHNORM_H +#endif // LIBND4J_BATCHNORM_H diff --git a/libnd4j/include/ops/declarable/helpers/betaInc.h b/libnd4j/include/ops/declarable/helpers/betaInc.h index 4c37d7c5dc3a..ab16f8c0083f 100644 --- a/libnd4j/include/ops/declarable/helpers/betaInc.h +++ b/libnd4j/include/ops/declarable/helpers/betaInc.h @@ -22,20 +22,23 @@ #define LIBND4J_BETAINC_H #include + #include "array/NDArray.h" namespace sd { namespace ops { namespace helpers { - const uint maxIter = MAX_NUM_THREADS /*articles propose 10000*/; // max number of loop iterations in function for continued fractions - - void betaInc(sd::LaunchContext* context, const NDArray& a, const NDArray& b, const NDArray& x, NDArray& output); - +const uint maxIter = + MAX_NUM_THREADS /*articles propose 10000*/; // max number of loop + // iterations in function for + // continued fractions -} -} -} +void betaInc(sd::LaunchContext* context, const NDArray& a, const NDArray& b, + const NDArray& x, NDArray& output); +} // namespace helpers +} // namespace ops +} // namespace sd -#endif //LIBND4J_BETAINC_H \ No newline at end of file +#endif // LIBND4J_BETAINC_H \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/choose.h b/libnd4j/include/ops/declarable/helpers/choose.h index 233c6b1ff95c..4a1c5a389fa0 100644 --- a/libnd4j/include/ops/declarable/helpers/choose.h +++ b/libnd4j/include/ops/declarable/helpers/choose.h @@ -19,17 +19,20 @@ // #ifndef __CHOOSE_H_HELPERS__ #define __CHOOSE_H_HELPERS__ -#include #include +#include namespace sd { namespace ops { namespace helpers { - void chooseFunctorArray(sd::LaunchContext * context, NDArray* arg, NDArray* comp, int mode, NDArray* result, NDArray* numResults); - void chooseFunctorScalar(sd::LaunchContext * context, NDArray* arg, double scalar, int mode, NDArray* result, NDArray* numResults); +void chooseFunctorArray(sd::LaunchContext* context, NDArray* arg, NDArray* comp, + int mode, NDArray* result, NDArray* numResults); +void chooseFunctorScalar(sd::LaunchContext* context, NDArray* arg, + double scalar, int mode, NDArray* result, + NDArray* numResults); -} -} -} +} // namespace helpers +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/helpers/col2im.h b/libnd4j/include/ops/declarable/helpers/col2im.h index 39a29da85581..d139c6606ccc 100644 --- a/libnd4j/include/ops/declarable/helpers/col2im.h +++ b/libnd4j/include/ops/declarable/helpers/col2im.h @@ -27,12 +27,13 @@ namespace sd { namespace ops { namespace helpers { - ND4J_EXPORT void col2im(sd::LaunchContext & context, const NDArray& input, NDArray& output, const int sH, const int sW, const int pH, const int pW, const int iH, const int iW, const int dH, const int dW); +SD_EXPORT void col2im(sd::LaunchContext& context, const NDArray& input, + NDArray& output, const int sH, const int sW, const int pH, + const int pW, const int iH, const int iW, const int dH, + const int dW); - -} } -} - +} // namespace ops +} // namespace sd -#endif //LIBND4J_COL2IM_H +#endif // LIBND4J_COL2IM_H diff --git a/libnd4j/include/ops/declarable/helpers/compare_elem.h b/libnd4j/include/ops/declarable/helpers/compare_elem.h index 18faee3288b9..0d0f9f8eb8cc 100644 --- a/libnd4j/include/ops/declarable/helpers/compare_elem.h +++ b/libnd4j/include/ops/declarable/helpers/compare_elem.h @@ -14,21 +14,21 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - #ifndef LIBND4J_COMPARE_ELEM_H #define LIBND4J_COMPARE_ELEM_H #include + #include "array/NDArray.h" namespace sd { - namespace ops { - namespace helpers { +namespace ops { +namespace helpers { - void compare_elem(sd::LaunchContext * context, NDArray* input, bool isStrictlyIncreasing, bool& output); - } - } +void compare_elem(sd::LaunchContext* context, NDArray* input, + bool isStrictlyIncreasing, bool& output); } +} // namespace ops +} // namespace sd - -#endif //LIBND4J_COMPARE_ELEM_H \ No newline at end of file +#endif // LIBND4J_COMPARE_ELEM_H \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/compression.h b/libnd4j/include/ops/declarable/helpers/compression.h index b9c70a91b886..966d2209dcd3 100644 --- a/libnd4j/include/ops/declarable/helpers/compression.h +++ b/libnd4j/include/ops/declarable/helpers/compression.h @@ -19,16 +19,18 @@ // #ifndef __COMPRESSION_H_HELPERS__ #define __COMPRESSION_H_HELPERS__ -#include #include +#include namespace sd { namespace ops { namespace helpers { - void decodeBitmap(sd::LaunchContext* context, const NDArray* input, NDArray* output); - Nd4jLong encodeBitmap(sd::LaunchContext* context, NDArray* input, NDArray* output, float threshold); -} -} -} +void decodeBitmap(sd::LaunchContext* context, const NDArray* input, + NDArray* output); +Nd4jLong encodeBitmap(sd::LaunchContext* context, NDArray* input, + NDArray* output, float threshold); +} // namespace helpers +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/helpers/confusion.h b/libnd4j/include/ops/declarable/helpers/confusion.h index a4d27be4fbe6..e660c2b7bdd1 100644 --- a/libnd4j/include/ops/declarable/helpers/confusion.h +++ b/libnd4j/include/ops/declarable/helpers/confusion.h @@ -19,16 +19,17 @@ // #ifndef __CONFUSION_H_HELPERS__ #define __CONFUSION_H_HELPERS__ -#include #include +#include namespace sd { namespace ops { namespace helpers { - void confusionFunctor(sd::LaunchContext * context, NDArray* labels, NDArray* predictions, NDArray* weights, NDArray* output); +void confusionFunctor(sd::LaunchContext* context, NDArray* labels, + NDArray* predictions, NDArray* weights, NDArray* output); } -} -} +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/helpers/convolutions.h b/libnd4j/include/ops/declarable/helpers/convolutions.h index eb41ae637e10..90e0b2467cbd 100644 --- a/libnd4j/include/ops/declarable/helpers/convolutions.h +++ b/libnd4j/include/ops/declarable/helpers/convolutions.h @@ -22,312 +22,427 @@ #define LIBND4J_CONVOLUTIONS_H #include +#include #include #include -#include - namespace sd { - namespace ops { - - enum PoolingType { - MAX_POOL = 0, - AVG_POOL = 1, - PNORM_POOL = 2, - }; - - class ND4J_EXPORT ConvolutionUtils { - public: - static inline void calcOutSizePool2D(int& oH, int& oW, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int iH, const int iW, const int paddingMode) { - - if(paddingMode == 0) { // valid - // oH = (iH - (kH + (kH-1)*(dH-1)) + 2*pH)/sH + 1; - // oW = (iW - (kW + (kW-1)*(dW-1)) + 2*pW)/sW + 1; - oH = (iH - ((kH - 1) * dH + 1) + 2 * pH) / sH + 1; - oW = (iW - ((kW - 1) * dW + 1) + 2 * pW) / sW + 1; - } - else if (paddingMode == 1) { // same - oH = (int) math::nd4j_ceil(iH * 1. / sH); - oW = (int) math::nd4j_ceil(iW * 1. / sW); - } - else { // causal - oH = (iH - 1) / sH + 1; // 2*pH = (kH-1)*dH - oW = (iW - 1) / sW + 1; - } - } - - static inline void calcOutSizePool3D(int& oD, int& oH, int& oW, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int iD, const int iH, const int iW, const int paddingMode) { - - if(paddingMode == 0) { // valid - oD = (iD - ((kD - 1) * dD + 1) + 2 * pD) / sD + 1; - oH = (iH - ((kH - 1) * dH + 1) + 2 * pH) / sH + 1; - oW = (iW - ((kW - 1) * dW + 1) + 2 * pW) / sW + 1; - } - else if(paddingMode == 1) { // same - oD = (int) sd::math::nd4j_ceil(iD * 1. / sD); - oH = (int) sd::math::nd4j_ceil(iH * 1. / sH); - oW = (int) sd::math::nd4j_ceil(iW * 1. / sW); - - } - else { // causal - oD = (iD - 1) / sD + 1; - oH = (iH - 1) / sH + 1; // 2*pH = (kH-1)*dH - oW = (iW - 1) / sW + 1; - } - } - - static inline void calcPadding2D(int& pH, int& pW, int oH, int oW, int iH, int iW, int kH, int kW, int sH, int sW, int dH, int dW, const int paddingMode = 1 /* default is same mode*/) { - - if(paddingMode == 0) // valid - return; - - if(paddingMode == 1) { // same - - const int eKH = (kH - 1) * dH + 1; - const int eKW = (kW - 1) * dW + 1; - - pH = ((oH - 1) * sH + eKH - iH) / 2; //Note that padBottom is 1 bigger than this if bracketed term is not divisible by 2 - pW = ((oW - 1) * sW + eKW - iW) / 2; - } - else { // causal - pH = (kH - 1) * dH; - pW = (kW - 1) * dW; - } - } - - static inline void calcPadding3D(int& pD, int& pH, int& pW, const int oD, const int oH, const int oW, const int iD, const int iH, const int iW, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int dD, const int dH, const int dW, const int paddingMode = 1 /* default is same mode*/) { - - if(paddingMode == 0) // valid - return; - - if(paddingMode == 1) { // same - - const int eKD = (kD - 1) * dD + 1; - const int eKH = (kH - 1) * dH + 1; - const int eKW = (kW - 1) * dW + 1; - - pD = ((oD - 1) * sD + eKD - iD) / 2; - pH = ((oH - 1) * sH + eKH - iH) / 2; //Note that padBottom is 1 bigger than this if bracketed term is not divisible by 2 - pW = ((oW - 1) * sW + eKW - iW) / 2; - } - else { // causal - pD = (kD - 1) * dD; - pH = (kH - 1) * dH; - pW = (kW - 1) * dW; - } - } - - // calculation of output height and width in 2D deconvolution procedure - static inline void calcOutSizeDeconv2D(int& oH, int& oW, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int iH, const int iW, const int paddingMode) { - - if (paddingMode) { - oH = sH * iH; - oW = sW * iW; - } - else { - const int ekH = (kH - 1) * dH + 1; - const int ekW = (kW - 1) * dW + 1; - - oH = sH * (iH - 1) + ekH - 2 * pH; - oW = sW * (iW - 1) + ekW - 2 * pW; - } - } - - // calculation of output height and width in 3D deconvolution procedure - static inline void calcOutSizeDeconv3D(int& oD, int& oH, int& oW, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int iD, const int iH, const int iW, const int paddingMode) { - - if (paddingMode) { - oD = sD * iD; - oH = sH * iH; - oW = sW * iW; - } - else { - - const int ekD = (kD - 1) * dD + 1; - const int ekH = (kH - 1) * dH + 1; - const int ekW = (kW - 1) * dW + 1; - - oD = sD * (iD - 1) + ekD - 2 * pD; - oH = sH * (iH - 1) + ekH - 2 * pH; - oW = sW * (iW - 1) + ekW - 2 * pW; - } - } - - // evaluates sizes values and indexes using input and output arrays depending on data format - static inline void getSizesAndIndexesConv2d(const bool isNCHW, const int wFormat, const NDArray& input, const NDArray& output, int& bS, int& iC, int& iH, int& iW, int& oC, int& oH, int& oW, int& indIOioC, int& indIiH, int& indWiC, int& indWoC, int& indWkH, int& indOoH) { - getSizesAndIndexesConv2d(isNCHW, wFormat, input.shapeInfo(), output.shapeInfo(), bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); - } - - static inline void getSizesAndIndexesConv2d(const bool isNCHW, const int wFormat, const Nd4jLong* inShapeInfo, const Nd4jLong* outShapeInfo, int& bS, int& iC, int& iH, int& iW, int& oC, int& oH, int& oW, int& indIOioC, int& indIiH, int& indWiC, int& indWoC, int& indWkH, int& indOoH) { - // input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - // weights [kH, kW, iC, oC] (wFormat = 0), [oC, iC, kH, kW] (wFormat = 1), [oC, kH, kW, iC] (wFormat = 2) - // output [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) - - if(0 == wFormat) { - indWkH = 0; indWiC = 2; indWoC = 3; - } - else if(1 == wFormat) { - indWkH = 2; indWiC = 1; indWoC = 0; - } - else { - indWkH = 1; indWiC = 3; indWoC = 0; - } - - if(!isNCHW) { - indIOioC = 3; indIiH = 1; indOoH = 1; - } - else { - indIOioC = 1; indIiH = 2; indOoH = 2; - } - - bS = inShapeInfo[1]; // batch size - iC = inShapeInfo[indIOioC+1]; // input channels - iH = inShapeInfo[indIiH+1]; // input height - iW = inShapeInfo[indIiH+2]; // input width - oC = outShapeInfo[indIOioC+1]; // output channels - oH = outShapeInfo[indOoH+1]; // output height - oW = outShapeInfo[indOoH+2]; // output width - } - - // evaluates sizes values and indexes using input and output arrays depending on data format - static inline void getSizesAndIndexesConv3d(const bool isNCDHW, const int wFormat, const NDArray& input, const NDArray& output, int& bS, int& iC, int& iD, int& iH, int& iW, int& oC, int& oD, int& oH, int& oW, int& indIOioC, int& indIOioD, int& indWiC, int& indWoC, int& indWkD) { - // input [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) - // weights [kD, kH, kW, iC, oC] (wFormat = 0), [oC, iC, kD, kH, kW] (wFormat = 1), [oC, kD, kH, kW, iC] (wFormat = 2) - // output [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW) - - if(0 == wFormat) { - indWkD = 0; indWiC = 3; indWoC = 4; - } - else if(1 == wFormat) { - indWkD = 2; indWiC = 1; indWoC = 0; - } - else { - indWkD = 1; indWiC = 4; indWoC = 0; - } - - if(!isNCDHW) { - indIOioC = 4; indIOioD = 1; - } - else { - indIOioC = 1; indIOioD = 2; - } - - bS = input.sizeAt(0); // batch size - iC = input.sizeAt(indIOioC); // input channels - iD = input.sizeAt(indIOioD); // input depth - iH = input.sizeAt(indIOioD+1); // input height - iW = input.sizeAt(indIOioD+2); // input width - oC = output.sizeAt(indIOioC); // output channels - oD = output.sizeAt(indIOioD); // output depth - oH = output.sizeAt(indIOioD+1); // output height - oW = output.sizeAt(indIOioD+2); // output width - } - - // static inline void calcPaddingAndDilationForConv2DMKL(const int iH, const int iW, const int oH, const int oW, const int kH, const int kW, const int sH, const int sW, const int paddingMode, int& pH, int& pW, int& dH, int& dW) { - - // if(kH != 1) { - // if(paddingMode) { - // pH = (oH - 1) * sH - iH + kH - pH; - // dH = dH - 1; - // } - // else - // dH = (iH + 2*pH - (oH - 1) * sH - kH) / (kH - 1); - // } - // if(kW != 1) { - // if(paddingMode) { - // pW = (oW - 1) * sW - iW + kW - pW; - // dW = dW - 1; - // } - // else - // dW = (iW + 2*pW - (oW - 1) * sW - kW) / (kW - 1); - // } - // } - - // static inline void calcPaddingAndDilationForConv3DMKL(const int iD, const int iH, const int iW, const int oD, const int oH, const int oW, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int paddingMode, int& pD, int& pH, int& pW, int& dD, int& dH, int& dW) { - - // if(kD != 1) { - // if(paddingMode) { - // pD = (oD - 1) * sD - iD + kD - pD; - // dD = dD - 1; - // } - // else - // dD = (iD + 2*pD - (oD - 1) * sD - kD) / (kD - 1); - // } - // if(kH != 1) { - // if(paddingMode) { - // pH = (oH - 1) * sH - iH + kH - pH; - // dH = dH - 1; - // } - // else - // dH = (iH + 2*pH - (oH - 1) * sH - kH) / (kH - 1); - // } - // if(kW != 1) { - // if(paddingMode) { - // pW = (oW - 1) * sW - iW + kW - pW; - // dW = dW - 1; - // } - // else - // dW = (iW + 2*pW - (oW - 1) * sW - kW) / (kW - 1); - // } - // } - - static std::vector expectWeightsShape(const int wFormat, const int kH, const int kW, const int iC, const int oC) { - - if(0 == wFormat) - return std::vector({kH, kW, iC, oC}); - - if(1 == wFormat) - return std::vector({oC, iC, kH, kW}); - - return std::vector({oC, kH, kW, iC}); - } - - static std::vector expectWeightsShape(const int wFormat, const int kD, const int kH, const int kW, const int iC, const int oC) { - - if(0 == wFormat) - return std::vector({kD, kH, kW, iC, oC}); - - if(1 == wFormat) - return std::vector({oC, iC, kD, kH, kW}); - - return std::vector({oC, kD, kH, kW, iC}); - } - - static void conv2d(sd::graph::Context &context, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat); - - // static void conv2d(sd::graph::Context & block, const std::vector& inArrs, NDArray* output, const std::vector& intArgs); - - // static void conv2dBP(sd::graph::Context & block, const std::vector& inArrs, const std::vector& outArrs, const std::vector& intArgs); - - static void conv2dBP(sd::graph::Context & block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat); - - static void depthwiseConv2d(sd::graph::Context & block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat); - - static void depthwiseConv2dBP(sd::graph::Context & block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat); - - static void sconv2d(sd::graph::Context & block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat); - - static void vol2col(sd::graph::Context & block, const NDArray& vol, NDArray& col, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW); - - static void col2vol(sd::graph::Context & block, const NDArray& col, NDArray& vol, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW); - - static void upsampling2d(sd::graph::Context & block, const NDArray& input, NDArray& output, const int factorH, const int factorW, const bool isNCHW); - - static void upsampling3d(sd::graph::Context & block, const NDArray& input, NDArray& output, const int factorD, const int factorH, const int factorW, const bool isNCDHW); - - static void upsampling2dBP(sd::graph::Context & block, const NDArray& gradO, NDArray& gradI, const bool isNCHW); - - static void upsampling3dBP(sd::graph::Context & block, const NDArray& gradO, NDArray& gradI, const bool isNCDHW); - - static void pooling2d(sd::graph::Context & block, const NDArray& input, NDArray& output, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const PoolingType poolingMode, const int extraParam0); - - static void pooling3d(sd::graph::Context & block, const NDArray& input, NDArray& output, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0); - - static void pooling2dBP(sd::graph::Context & block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int poolingMode, const int extraParam0); - - static void pooling3dBP(sd::graph::Context & block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0); - }; - -} -} -#endif //LIBND4J_CONVOLUTIONS_H +namespace ops { + +enum PoolingType { + MAX_POOL = 0, + AVG_POOL = 1, + PNORM_POOL = 2, +}; + +class SD_EXPORT ConvolutionUtils { + public: + static inline void calcOutSizePool2D(int& oH, int& oW, const int kH, + const int kW, const int sH, const int sW, + const int pH, const int pW, const int dH, + const int dW, const int iH, const int iW, + const int paddingMode) { + if (paddingMode == 0) { // valid + // oH = (iH - (kH + (kH-1)*(dH-1)) + 2*pH)/sH + 1; + // oW = (iW - (kW + (kW-1)*(dW-1)) + 2*pW)/sW + 1; + oH = (iH - ((kH - 1) * dH + 1) + 2 * pH) / sH + 1; + oW = (iW - ((kW - 1) * dW + 1) + 2 * pW) / sW + 1; + } else if (paddingMode == 1) { // same + oH = (int)math::nd4j_ceil(iH * 1. / sH); + oW = (int)math::nd4j_ceil(iW * 1. / sW); + } else { // causal + oH = (iH - 1) / sH + 1; // 2*pH = (kH-1)*dH + oW = (iW - 1) / sW + 1; + } + } + + static inline void calcOutSizePool3D(int& oD, int& oH, int& oW, const int kD, + const int kH, const int kW, const int sD, + const int sH, const int sW, const int pD, + const int pH, const int pW, const int dD, + const int dH, const int dW, const int iD, + const int iH, const int iW, + const int paddingMode) { + if (paddingMode == 0) { // valid + oD = (iD - ((kD - 1) * dD + 1) + 2 * pD) / sD + 1; + oH = (iH - ((kH - 1) * dH + 1) + 2 * pH) / sH + 1; + oW = (iW - ((kW - 1) * dW + 1) + 2 * pW) / sW + 1; + } else if (paddingMode == 1) { // same + oD = (int)sd::math::nd4j_ceil(iD * 1. / sD); + oH = (int)sd::math::nd4j_ceil(iH * 1. / sH); + oW = (int)sd::math::nd4j_ceil(iW * 1. / sW); + + } else { // causal + oD = (iD - 1) / sD + 1; + oH = (iH - 1) / sH + 1; // 2*pH = (kH-1)*dH + oW = (iW - 1) / sW + 1; + } + } + + static inline void calcPadding2D( + int& pH, int& pW, int oH, int oW, int iH, int iW, int kH, int kW, int sH, + int sW, int dH, int dW, + const int paddingMode = 1 /* default is same mode*/) { + if (paddingMode == 0) // valid + return; + + if (paddingMode == 1) { // same + + const int eKH = (kH - 1) * dH + 1; + const int eKW = (kW - 1) * dW + 1; + + pH = ((oH - 1) * sH + eKH - iH) / + 2; // Note that padBottom is 1 bigger than this if bracketed term is + // not divisible by 2 + pW = ((oW - 1) * sW + eKW - iW) / 2; + } else { // causal + pH = (kH - 1) * dH; + pW = (kW - 1) * dW; + } + } + + static inline void calcPadding3D( + int& pD, int& pH, int& pW, const int oD, const int oH, const int oW, + const int iD, const int iH, const int iW, const int kD, const int kH, + const int kW, const int sD, const int sH, const int sW, const int dD, + const int dH, const int dW, + const int paddingMode = 1 /* default is same mode*/) { + if (paddingMode == 0) // valid + return; + + if (paddingMode == 1) { // same + + const int eKD = (kD - 1) * dD + 1; + const int eKH = (kH - 1) * dH + 1; + const int eKW = (kW - 1) * dW + 1; + + pD = ((oD - 1) * sD + eKD - iD) / 2; + pH = ((oH - 1) * sH + eKH - iH) / + 2; // Note that padBottom is 1 bigger than this if bracketed term is + // not divisible by 2 + pW = ((oW - 1) * sW + eKW - iW) / 2; + } else { // causal + pD = (kD - 1) * dD; + pH = (kH - 1) * dH; + pW = (kW - 1) * dW; + } + } + + // calculation of output height and width in 2D deconvolution procedure + static inline void calcOutSizeDeconv2D(int& oH, int& oW, const int kH, + const int kW, const int sH, + const int sW, const int pH, + const int pW, const int dH, + const int dW, const int iH, + const int iW, const int paddingMode) { + if (paddingMode) { + oH = sH * iH; + oW = sW * iW; + } else { + const int ekH = (kH - 1) * dH + 1; + const int ekW = (kW - 1) * dW + 1; + + oH = sH * (iH - 1) + ekH - 2 * pH; + oW = sW * (iW - 1) + ekW - 2 * pW; + } + } + + // calculation of output height and width in 3D deconvolution procedure + static inline void calcOutSizeDeconv3D( + int& oD, int& oH, int& oW, const int kD, const int kH, const int kW, + const int sD, const int sH, const int sW, const int pD, const int pH, + const int pW, const int dD, const int dH, const int dW, const int iD, + const int iH, const int iW, const int paddingMode) { + if (paddingMode) { + oD = sD * iD; + oH = sH * iH; + oW = sW * iW; + } else { + const int ekD = (kD - 1) * dD + 1; + const int ekH = (kH - 1) * dH + 1; + const int ekW = (kW - 1) * dW + 1; + + oD = sD * (iD - 1) + ekD - 2 * pD; + oH = sH * (iH - 1) + ekH - 2 * pH; + oW = sW * (iW - 1) + ekW - 2 * pW; + } + } + + // evaluates sizes values and indexes using input and output arrays depending + // on data format + static inline void getSizesAndIndexesConv2d( + const bool isNCHW, const int wFormat, const NDArray& input, + const NDArray& output, int& bS, int& iC, int& iH, int& iW, int& oC, + int& oH, int& oW, int& indIOioC, int& indIiH, int& indWiC, int& indWoC, + int& indWkH, int& indOoH) { + getSizesAndIndexesConv2d(isNCHW, wFormat, input.shapeInfo(), + output.shapeInfo(), bS, iC, iH, iW, oC, oH, oW, + indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); + } + + static inline void getSizesAndIndexesConv2d( + const bool isNCHW, const int wFormat, const Nd4jLong* inShapeInfo, + const Nd4jLong* outShapeInfo, int& bS, int& iC, int& iH, int& iW, int& oC, + int& oH, int& oW, int& indIOioC, int& indIiH, int& indWiC, int& indWoC, + int& indWkH, int& indOoH) { + // input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + // weights [kH, kW, iC, oC] (wFormat = 0), [oC, iC, kH, kW] (wFormat = 1), + // [oC, kH, kW, iC] (wFormat = 2) output [bS, oH, oW, oC] (NHWC) or [bS, + // oC, oH, oW] (NCHW) + + if (0 == wFormat) { + indWkH = 0; + indWiC = 2; + indWoC = 3; + } else if (1 == wFormat) { + indWkH = 2; + indWiC = 1; + indWoC = 0; + } else { + indWkH = 1; + indWiC = 3; + indWoC = 0; + } + + if (!isNCHW) { + indIOioC = 3; + indIiH = 1; + indOoH = 1; + } else { + indIOioC = 1; + indIiH = 2; + indOoH = 2; + } + + bS = inShapeInfo[1]; // batch size + iC = inShapeInfo[indIOioC + 1]; // input channels + iH = inShapeInfo[indIiH + 1]; // input height + iW = inShapeInfo[indIiH + 2]; // input width + oC = outShapeInfo[indIOioC + 1]; // output channels + oH = outShapeInfo[indOoH + 1]; // output height + oW = outShapeInfo[indOoH + 2]; // output width + } + + // evaluates sizes values and indexes using input and output arrays depending + // on data format + static inline void getSizesAndIndexesConv3d( + const bool isNCDHW, const int wFormat, const NDArray& input, + const NDArray& output, int& bS, int& iC, int& iD, int& iH, int& iW, + int& oC, int& oD, int& oH, int& oW, int& indIOioC, int& indIOioD, + int& indWiC, int& indWoC, int& indWkD) { + // input [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) + // weights [kD, kH, kW, iC, oC] (wFormat = 0), [oC, iC, kD, kH, kW] (wFormat + // = 1), [oC, kD, kH, kW, iC] (wFormat = 2) output [bS, oD, oH, oW, oC] + // (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW) + + if (0 == wFormat) { + indWkD = 0; + indWiC = 3; + indWoC = 4; + } else if (1 == wFormat) { + indWkD = 2; + indWiC = 1; + indWoC = 0; + } else { + indWkD = 1; + indWiC = 4; + indWoC = 0; + } + + if (!isNCDHW) { + indIOioC = 4; + indIOioD = 1; + } else { + indIOioC = 1; + indIOioD = 2; + } + + bS = input.sizeAt(0); // batch size + iC = input.sizeAt(indIOioC); // input channels + iD = input.sizeAt(indIOioD); // input depth + iH = input.sizeAt(indIOioD + 1); // input height + iW = input.sizeAt(indIOioD + 2); // input width + oC = output.sizeAt(indIOioC); // output channels + oD = output.sizeAt(indIOioD); // output depth + oH = output.sizeAt(indIOioD + 1); // output height + oW = output.sizeAt(indIOioD + 2); // output width + } + + // static inline void calcPaddingAndDilationForConv2DMKL(const int iH, const + // int iW, const int oH, const int oW, const int kH, const int kW, const int + // sH, const int sW, const int paddingMode, int& pH, int& pW, int& dH, int& + // dW) { + + // if(kH != 1) { + // if(paddingMode) { + // pH = (oH - 1) * sH - iH + kH - pH; + // dH = dH - 1; + // } + // else + // dH = (iH + 2*pH - (oH - 1) * sH - kH) / (kH - 1); + // } + // if(kW != 1) { + // if(paddingMode) { + // pW = (oW - 1) * sW - iW + kW - pW; + // dW = dW - 1; + // } + // else + // dW = (iW + 2*pW - (oW - 1) * sW - kW) / (kW - 1); + // } + // } + + // static inline void calcPaddingAndDilationForConv3DMKL(const int iD, const + // int iH, const int iW, const int oD, const int oH, const int oW, const int + // kD, const int kH, const int kW, const int sD, const int sH, const int sW, + // const int paddingMode, int& pD, int& pH, int& pW, int& dD, int& dH, int& + // dW) { + + // if(kD != 1) { + // if(paddingMode) { + // pD = (oD - 1) * sD - iD + kD - pD; + // dD = dD - 1; + // } + // else + // dD = (iD + 2*pD - (oD - 1) * sD - kD) / (kD - 1); + // } + // if(kH != 1) { + // if(paddingMode) { + // pH = (oH - 1) * sH - iH + kH - pH; + // dH = dH - 1; + // } + // else + // dH = (iH + 2*pH - (oH - 1) * sH - kH) / (kH - 1); + // } + // if(kW != 1) { + // if(paddingMode) { + // pW = (oW - 1) * sW - iW + kW - pW; + // dW = dW - 1; + // } + // else + // dW = (iW + 2*pW - (oW - 1) * sW - kW) / (kW - 1); + // } + // } + + static std::vector expectWeightsShape(const int wFormat, + const int kH, const int kW, + const int iC, const int oC) { + if (0 == wFormat) return std::vector({kH, kW, iC, oC}); + + if (1 == wFormat) return std::vector({oC, iC, kH, kW}); + + return std::vector({oC, kH, kW, iC}); + } + + static std::vector expectWeightsShape(const int wFormat, + const int kD, const int kH, + const int kW, const int iC, + const int oC) { + if (0 == wFormat) return std::vector({kD, kH, kW, iC, oC}); + + if (1 == wFormat) return std::vector({oC, iC, kD, kH, kW}); + + return std::vector({oC, kD, kH, kW, iC}); + } + + static void conv2d(sd::graph::Context& context, const NDArray* input, + const NDArray* weights, const NDArray* bias, + NDArray* output, const int kH, const int kW, const int sH, + const int sW, int pH, int pW, const int dH, const int dW, + const int paddingMode, const int isNCHW, + const int wFormat); + + // static void conv2d(sd::graph::Context & block, const std::vector& + // inArrs, NDArray* output, const std::vector& intArgs); + + // static void conv2dBP(sd::graph::Context & block, const + // std::vector& inArrs, const std::vector& outArrs, const + // std::vector& intArgs); + + static void conv2dBP(sd::graph::Context& block, const NDArray* input, + const NDArray* weights, const NDArray* bias, + const NDArray* gradO, NDArray* gradI, NDArray* gradW, + NDArray* gradB, const int kH, const int kW, const int sH, + const int sW, int pH, int pW, const int dH, const int dW, + const int paddingMode, const int isNCHW, + const int wFormat); + + static void depthwiseConv2d(sd::graph::Context& block, const NDArray* input, + const NDArray* weights, const NDArray* bias, + NDArray* output, const int kH, const int kW, + const int sH, const int sW, int pH, int pW, + const int dH, const int dW, const int paddingMode, + const int isNCHW, const int wFormat); + + static void depthwiseConv2dBP(sd::graph::Context& block, const NDArray* input, + const NDArray* weights, const NDArray* bias, + const NDArray* gradO, NDArray* gradI, + NDArray* gradW, NDArray* gradB, const int kH, + const int kW, const int sH, const int sW, + int pH, int pW, const int dH, const int dW, + const int paddingMode, const int isNCHW, + const int wFormat); + + static void sconv2d(sd::graph::Context& block, const NDArray* input, + const NDArray* weightsDepth, const NDArray* weightsPoint, + const NDArray* bias, NDArray* output, const int kH, + const int kW, const int sH, const int sW, int pH, int pW, + const int dH, const int dW, const int paddingMode, + const int isNCHW, const int wFormat); + + static void vol2col(sd::graph::Context& block, const NDArray& vol, + NDArray& col, const int sD, const int sH, const int sW, + const int pD, const int pH, const int pW, const int dD, + const int dH, const int dW); + + static void col2vol(sd::graph::Context& block, const NDArray& col, + NDArray& vol, const int sD, const int sH, const int sW, + const int pD, const int pH, const int pW, const int dD, + const int dH, const int dW); + + static void upsampling2d(sd::graph::Context& block, const NDArray& input, + NDArray& output, const int factorH, + const int factorW, const bool isNCHW); + + static void upsampling3d(sd::graph::Context& block, const NDArray& input, + NDArray& output, const int factorD, + const int factorH, const int factorW, + const bool isNCDHW); + + static void upsampling2dBP(sd::graph::Context& block, const NDArray& gradO, + NDArray& gradI, const bool isNCHW); + + static void upsampling3dBP(sd::graph::Context& block, const NDArray& gradO, + NDArray& gradI, const bool isNCDHW); + + static void pooling2d(sd::graph::Context& block, const NDArray& input, + NDArray& output, const int kH, const int kW, + const int sH, const int sW, const int pH, const int pW, + const int dH, const int dW, + const PoolingType poolingMode, const int extraParam0); + + static void pooling3d(sd::graph::Context& block, const NDArray& input, + NDArray& output, const int kD, const int kH, + const int kW, const int sD, const int sH, const int sW, + const int pD, const int pH, const int pW, const int dD, + const int dH, const int dW, const int poolingMode, + const int extraParam0); + + static void pooling2dBP(sd::graph::Context& block, const NDArray& input, + const NDArray& gradO, NDArray& gradI, const int kH, + const int kW, const int sH, const int sW, + const int pH, const int pW, const int dH, + const int dW, const int poolingMode, + const int extraParam0); + + static void pooling3dBP(sd::graph::Context& block, const NDArray& input, + const NDArray& gradO, NDArray& gradI, const int kD, + const int kH, const int kW, const int sD, + const int sH, const int sW, const int pD, + const int pH, const int pW, const int dD, + const int dH, const int dW, const int poolingMode, + const int extraParam0); +}; + +} // namespace ops +} // namespace sd +#endif // LIBND4J_CONVOLUTIONS_H diff --git a/libnd4j/include/ops/declarable/helpers/cpu/BarnesHutTsne.cpp b/libnd4j/include/ops/declarable/helpers/cpu/BarnesHutTsne.cpp index 97bdd5c8963d..2ffd204a90dc 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/BarnesHutTsne.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/BarnesHutTsne.cpp @@ -18,40 +18,40 @@ // @author George A. Shulinok , created on 4/18/2019 // -#include #include +#include namespace sd { namespace ops { namespace helpers { - Nd4jLong barnes_row_count(const NDArray* rowP, const NDArray* colP, Nd4jLong N, NDArray& rowCounts) { - - int* pRowCounts = reinterpret_cast(rowCounts.buffer()); - int const* pRows = reinterpret_cast(rowP->buffer()); - int const* pCols = reinterpret_cast(colP->buffer()); - for (Nd4jLong n = 0; n < N; n++) { - int begin = pRows[n];//->e(n); - int end = pRows[n + 1];//rowP->e(n + 1); - for (int i = begin; i < end; i++) { - bool present = false; - for (int m = pRows[pCols[i]]; m < pRows[pCols[i] + 1]; m++) - if (pCols[m] == n) { - present = true; - break; - } - - ++pRowCounts[n]; - - if (!present) - ++pRowCounts[pCols[i]]; - } +Nd4jLong barnes_row_count(const NDArray* rowP, const NDArray* colP, Nd4jLong N, + NDArray& rowCounts) { + int* pRowCounts = reinterpret_cast(rowCounts.buffer()); + int const* pRows = reinterpret_cast(rowP->buffer()); + int const* pCols = reinterpret_cast(colP->buffer()); + for (Nd4jLong n = 0; n < N; n++) { + int begin = pRows[n]; //->e(n); + int end = pRows[n + 1]; // rowP->e(n + 1); + for (int i = begin; i < end; i++) { + bool present = false; + for (int m = pRows[pCols[i]]; m < pRows[pCols[i] + 1]; m++) + if (pCols[m] == n) { + present = true; + break; } - NDArray numElementsArr = rowCounts.sumNumber(); //reduceAlongDimension(reduce::Sum, {}); - //rowCounts.printBuffer("Row counts"); - auto numElements = numElementsArr.e(0); - return numElements; + + ++pRowCounts[n]; + + if (!present) ++pRowCounts[pCols[i]]; } + } + NDArray numElementsArr = + rowCounts.sumNumber(); // reduceAlongDimension(reduce::Sum, {}); + // rowCounts.printBuffer("Row counts"); + auto numElements = numElementsArr.e(0); + return numElements; +} // static // void printVector(std::vector const& v) { // for (auto x: v) { @@ -61,179 +61,214 @@ namespace helpers { // fflush(stdout); // } - template - static void barnes_symmetrize_(const NDArray* rowP, const NDArray* colP, const NDArray* valP, Nd4jLong N, NDArray* outputRows, NDArray* outputCols, NDArray* outputVals, NDArray* rowCounts) { - //auto N = rowP->lengthOf() - 1; /// 2 + rowP->lengthOf() % 2; - //auto numElements = output->lengthOf(); - //std::vector symRowP = rowCounts->asVectorT();//NDArrayFactory::create('c', {numElements}); - //NDArray symValP = NDArrayFactory::create('c', {numElements}); - //symRowP.insert(symRowP.begin(),0); - //symRowP(1, {0}) = *rowCounts; - int const* pRows = reinterpret_cast(rowP->buffer()); - int* symRowP = reinterpret_cast(outputRows->buffer()); - symRowP[0] = 0; - for (Nd4jLong n = 0; n < N; n++) - symRowP[n + 1] = symRowP[n] + rowCounts->e(n); -// outputRows->printBuffer("output rows"); - - int* symColP = reinterpret_cast(outputCols->buffer()); -// symRowP.p(n + 1, symRowP.e(n) + rowCounts.e(n)) -// outputRows->printBuffer("SymRows are"); - int const* pCols = reinterpret_cast(colP->buffer()); - T const* pVals = reinterpret_cast(valP->buffer()); - T* pOutput = reinterpret_cast(outputVals->buffer()); - //std::vector rowCountsV = rowCounts->getBufferAsVector(); - std::vector offset(N);// = NDArrayFactory::create('c', {N}); - -//PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(schedule(guided) shared(offset)) - for (Nd4jLong n = 0; n < N; n++) { - int begin = pRows[n]; - int bound = pRows[n + 1]; - - for (int i = begin; i < bound; i++) { - bool present = false; - int colPI = pCols[i]; - int start = pRows[colPI]; - int end = pRows[colPI + 1]; - - //PRAGMA_OMP_PARALLEL_FOR_ARGS(schedule(guided) firstprivate(offset)) - for (int m = start; m < end; m++) { - if (pCols[m] == n) { - present = true; - if (n <= colPI) { - symColP[symRowP[n] + offset[n]] = colPI; - symColP[symRowP[colPI] + offset[colPI]] = n; - pOutput[symRowP[n] + offset[n]] = pVals[i] + pVals[m]; - pOutput[symRowP[colPI] + offset[colPI]] = pVals[i] + pVals[m]; - } - } - } - - // If (colP[i], n) is not present, there is no addition involved - if (!present) { - //int colPI = pCols[i]; - //if (n <= colPI) { - symColP[symRowP[n] + offset[n]] = colPI; - symColP[symRowP[pCols[i]] + offset[colPI]] = n; - pOutput[symRowP[n] + offset[n]] = pVals[i]; - pOutput[symRowP[colPI] + offset[colPI]] = pVals[i]; - //} - - } - // Update offsets - if (!present || (present && n <= colPI)) { - ++offset[n]; - - if (colPI != n) - ++offset[colPI]; - } -// printVector(offset); - } +template +static void barnes_symmetrize_(const NDArray* rowP, const NDArray* colP, + const NDArray* valP, Nd4jLong N, + NDArray* outputRows, NDArray* outputCols, + NDArray* outputVals, NDArray* rowCounts) { + // auto N = rowP->lengthOf() - 1; /// 2 + rowP->lengthOf() % 2; + // auto numElements = output->lengthOf(); + // std::vector symRowP = + // rowCounts->asVectorT();//NDArrayFactory::create('c', + // {numElements}); NDArray symValP = NDArrayFactory::create('c', + // {numElements}); symRowP.insert(symRowP.begin(),0); symRowP(1, {0}) = + // *rowCounts; + int const* pRows = reinterpret_cast(rowP->buffer()); + int* symRowP = reinterpret_cast(outputRows->buffer()); + symRowP[0] = 0; + for (Nd4jLong n = 0; n < N; n++) + symRowP[n + 1] = symRowP[n] + rowCounts->e(n); + // outputRows->printBuffer("output rows"); + + int* symColP = reinterpret_cast(outputCols->buffer()); + // symRowP.p(n + 1, symRowP.e(n) + rowCounts.e(n)) + // outputRows->printBuffer("SymRows are"); + int const* pCols = reinterpret_cast(colP->buffer()); + T const* pVals = reinterpret_cast(valP->buffer()); + T* pOutput = reinterpret_cast(outputVals->buffer()); + // std::vector rowCountsV = rowCounts->getBufferAsVector(); + std::vector offset(N); // = NDArrayFactory::create('c', {N}); + + // PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(schedule(guided) shared(offset)) + for (Nd4jLong n = 0; n < N; n++) { + int begin = pRows[n]; + int bound = pRows[n + 1]; + + for (int i = begin; i < bound; i++) { + bool present = false; + int colPI = pCols[i]; + int start = pRows[colPI]; + int end = pRows[colPI + 1]; + + // PRAGMA_OMP_PARALLEL_FOR_ARGS(schedule(guided) firstprivate(offset)) + for (int m = start; m < end; m++) { + if (pCols[m] == n) { + present = true; + if (n <= colPI) { + symColP[symRowP[n] + offset[n]] = colPI; + symColP[symRowP[colPI] + offset[colPI]] = n; + pOutput[symRowP[n] + offset[n]] = pVals[i] + pVals[m]; + pOutput[symRowP[colPI] + offset[colPI]] = pVals[i] + pVals[m]; + } } + } + + // If (colP[i], n) is not present, there is no addition involved + if (!present) { + // int colPI = pCols[i]; + // if (n <= colPI) { + symColP[symRowP[n] + offset[n]] = colPI; + symColP[symRowP[pCols[i]] + offset[colPI]] = n; + pOutput[symRowP[n] + offset[n]] = pVals[i]; + pOutput[symRowP[colPI] + offset[colPI]] = pVals[i]; + //} + } + // Update offsets + if (!present || (present && n <= colPI)) { + ++offset[n]; + + if (colPI != n) ++offset[colPI]; + } + // printVector(offset); } - void barnes_symmetrize(const NDArray* rowP, const NDArray* colP, const NDArray* valP, Nd4jLong N, NDArray* outputRows, NDArray* outputCols, NDArray* outputVals, NDArray* rowCounts) { - - // Divide the result by two - BUILD_SINGLE_SELECTOR(valP->dataType(), barnes_symmetrize_, (rowP, colP, valP, N, outputRows, outputCols, outputVals, rowCounts), NUMERIC_TYPES); + } +} +void barnes_symmetrize(const NDArray* rowP, const NDArray* colP, + const NDArray* valP, Nd4jLong N, NDArray* outputRows, + NDArray* outputCols, NDArray* outputVals, + NDArray* rowCounts) { + // Divide the result by two + BUILD_SINGLE_SELECTOR( + valP->dataType(), barnes_symmetrize_, + (rowP, colP, valP, N, outputRows, outputCols, outputVals, rowCounts), + NUMERIC_TYPES); + + *outputVals /= 2.0; + // output->assign(symValP); +} +BUILD_SINGLE_TEMPLATE(template void barnes_symmetrize_, + (const NDArray* rowP, const NDArray* colP, + const NDArray* valP, Nd4jLong N, NDArray* outputRows, + NDArray* outputCols, NDArray* outputVals, + NDArray* rowCounts), + NUMERIC_TYPES); + +template +static void barnes_edge_forces_(const NDArray* rowP, NDArray const* colP, + NDArray const* valP, int N, NDArray const* data, + NDArray* output) { + T const* dataP = reinterpret_cast(data->buffer()); + T const* vals = reinterpret_cast(valP->buffer()); + T* outputP = reinterpret_cast(output->buffer()); + int colCount = data->columns(); + + // auto shift = 0; + auto rowSize = sizeof(T) * colCount; + + auto func = PRAGMA_THREADS_FOR { + for (auto n = start; n < stop; n++) { + int s = rowP->e(n); + int end = rowP->e(n + 1); + int shift = n * colCount; + for (int i = s; i < end; i++) { + T const* thisSlice = dataP + colP->e(i) * colCount; + T res = 1; + + for (int k = 0; k < colCount; k++) { + auto tempVal = dataP[shift + k] - thisSlice[k]; // thisSlice[k]; + res += tempVal * tempVal; + } - *outputVals /= 2.0; - //output->assign(symValP); - } - BUILD_SINGLE_TEMPLATE(template void barnes_symmetrize_, (const NDArray* rowP, const NDArray* colP, const NDArray* valP, Nd4jLong N, NDArray* outputRows, NDArray* outputCols, NDArray* outputVals, NDArray* rowCounts), NUMERIC_TYPES); - - template - static void barnes_edge_forces_(const NDArray* rowP, NDArray const* colP, NDArray const* valP, int N, NDArray const* data, NDArray* output) { - T const* dataP = reinterpret_cast(data->buffer()); - T const* vals = reinterpret_cast(valP->buffer()); - T* outputP = reinterpret_cast(output->buffer()); - int colCount = data->columns(); - - -// auto shift = 0; - auto rowSize = sizeof(T) * colCount; - - auto func = PRAGMA_THREADS_FOR { - for (auto n = start; n < stop; n++) { - int s = rowP->e(n); - int end = rowP->e(n + 1); - int shift = n * colCount; - for (int i = s; i < end; i++) { - T const *thisSlice = dataP + colP->e(i) * colCount; - T res = 1; - - for (int k = 0; k < colCount; k++) { - auto tempVal = dataP[shift + k] - thisSlice[k];//thisSlice[k]; - res += tempVal * tempVal; - } - - res = vals[i] / res; - for (int k = 0; k < colCount; k++) - outputP[shift + k] += ((dataP[shift + k] - thisSlice[k]) * res); - } - //shift += colCount; - } - }; - - samediff::Threads::parallel_tad(func, 0, N); + res = vals[i] / res; + for (int k = 0; k < colCount; k++) + outputP[shift + k] += ((dataP[shift + k] - thisSlice[k]) * res); + } + // shift += colCount; } + }; - void barnes_edge_forces(const NDArray* rowP, NDArray const* colP, NDArray const* valP, int N, NDArray* output, NDArray const& data) { - // Loop over all edges in the graph - BUILD_SINGLE_SELECTOR(output->dataType(), barnes_edge_forces_, (rowP, colP, valP, N, &data, output), FLOAT_TYPES); - } - BUILD_SINGLE_TEMPLATE(template void barnes_edge_forces_, (const NDArray* rowP, NDArray const* colP, NDArray const* valP, int N, NDArray const* data, NDArray* output), FLOAT_TYPES); - - template - static void barnes_gains_(NDArray* input, NDArray* gradX, NDArray* epsilon, NDArray* output) { - // gains = gains.add(.2).muli(sign(yGrads)).neq(sign(yIncs)).castTo(Nd4j.defaultFloatingPointType()) - // .addi(gains.mul(0.8).muli(sign(yGrads)).neq(sign(yIncs))); - auto gainsInternal = LAMBDA_TTT(x, grad, eps) { -// return T((x + 2.) * sd::math::nd4j_sign(grad) != sd::math::nd4j_sign(eps)) + T(x * 0.8 * sd::math::nd4j_sign(grad) != sd::math::nd4j_sign(eps)); - //return T((x + 2.) * sd::math::nd4j_sign(grad) == sd::math::nd4j_sign(eps)) + T(x * 0.8 * sd::math::nd4j_sign(grad) == sd::math::nd4j_sign(eps)); - T res = sd::math::nd4j_sign(grad) != sd::math::nd4j_sign(eps) ? x + T(.2) : x * T(.8); - if(res < .01) res = .01; - return res; - }; - - input->applyTriplewiseLambda(*gradX, *epsilon, gainsInternal, *output); - } + samediff::Threads::parallel_tad(func, 0, N); +} - void barnes_gains(NDArray* input, NDArray* gradX, NDArray* epsilon, NDArray* output) { - // gains = gains.add(.2).muli(sign(yGrads)).neq(sign(yIncs)).castTo(Nd4j.defaultFloatingPointType()) - // .addi(gains.mul(0.8).muli(sign(yGrads)).neq(sign(yIncs))); - BUILD_SINGLE_SELECTOR(input->dataType(), barnes_gains_, (input, gradX, epsilon, output), NUMERIC_TYPES); -// auto signGradX = *gradX; -// auto signEpsilon = *epsilon; -// gradX->applyTransform(transform::Sign, &signGradX, nullptr); -// epsilon->applyTransform(transform::Sign, &signEpsilon, nullptr); -// auto leftPart = (*input + 2.) * signGradX; -// auto leftPartBool = NDArrayFactory::create(leftPart.ordering(), leftPart.getShapeAsVector()); -// -// leftPart.applyPairwiseTransform(pairwise::NotEqualTo, &signEpsilon, &leftPartBool, nullptr); -// auto rightPart = *input * 0.8 * signGradX; -// auto rightPartBool = NDArrayFactory::create(rightPart.ordering(), rightPart.getShapeAsVector()); -// rightPart.applyPairwiseTransform(pairwise::NotEqualTo, &signEpsilon, &rightPartBool, nullptr); -// leftPart.assign(leftPartBool); -// rightPart.assign(rightPartBool); -// leftPart.applyPairwiseTransform(pairwise::Add, &rightPart, output, nullptr); +void barnes_edge_forces(const NDArray* rowP, NDArray const* colP, + NDArray const* valP, int N, NDArray* output, + NDArray const& data) { + // Loop over all edges in the graph + BUILD_SINGLE_SELECTOR(output->dataType(), barnes_edge_forces_, + (rowP, colP, valP, N, &data, output), FLOAT_TYPES); +} +BUILD_SINGLE_TEMPLATE(template void barnes_edge_forces_, + (const NDArray* rowP, NDArray const* colP, + NDArray const* valP, int N, NDArray const* data, + NDArray* output), + FLOAT_TYPES); + +template +static void barnes_gains_(NDArray* input, NDArray* gradX, NDArray* epsilon, + NDArray* output) { + // gains = + // gains.add(.2).muli(sign(yGrads)).neq(sign(yIncs)).castTo(Nd4j.defaultFloatingPointType()) + // .addi(gains.mul(0.8).muli(sign(yGrads)).neq(sign(yIncs))); + auto gainsInternal = LAMBDA_TTT(x, grad, eps) { + // return T((x + 2.) * sd::math::nd4j_sign(grad) != + // sd::math::nd4j_sign(eps)) + T(x * 0.8 * + // sd::math::nd4j_sign(grad) != + // sd::math::nd4j_sign(eps)); + // return T((x + 2.) * sd::math::nd4j_sign(grad) == + // sd::math::nd4j_sign(eps)) + T(x * 0.8 * + // sd::math::nd4j_sign(grad) == sd::math::nd4j_sign(eps)); + T res = sd::math::nd4j_sign(grad) != sd::math::nd4j_sign(eps) + ? x + T(.2) + : x * T(.8); + + if (res < .01) res = .01; + + return res; + }; + + input->applyTriplewiseLambda(*gradX, *epsilon, gainsInternal, *output); +} - } - BUILD_SINGLE_TEMPLATE(template void barnes_gains_, (NDArray* input, NDArray* gradX, NDArray* epsilon, NDArray* output), NUMERIC_TYPES); +void barnes_gains(NDArray* input, NDArray* gradX, NDArray* epsilon, + NDArray* output) { + // gains = + // gains.add(.2).muli(sign(yGrads)).neq(sign(yIncs)).castTo(Nd4j.defaultFloatingPointType()) + // .addi(gains.mul(0.8).muli(sign(yGrads)).neq(sign(yIncs))); + BUILD_SINGLE_SELECTOR(input->dataType(), barnes_gains_, + (input, gradX, epsilon, output), FLOAT_TYPES); + // auto signGradX = *gradX; + // auto signEpsilon = *epsilon; + // gradX->applyTransform(transform::Sign, &signGradX, nullptr); + // epsilon->applyTransform(transform::Sign, &signEpsilon, nullptr); + // auto leftPart = (*input + 2.) * signGradX; + // auto leftPartBool = + // NDArrayFactory::create(leftPart.ordering(), + // leftPart.getShapeAsVector()); + // + // leftPart.applyPairwiseTransform(pairwise::NotEqualTo, &signEpsilon, + // &leftPartBool, nullptr); auto rightPart = *input * 0.8 * signGradX; + // auto rightPartBool = + // NDArrayFactory::create(rightPart.ordering(), + // rightPart.getShapeAsVector()); + // rightPart.applyPairwiseTransform(pairwise::NotEqualTo, &signEpsilon, + // &rightPartBool, nullptr); leftPart.assign(leftPartBool); + // rightPart.assign(rightPartBool); + // leftPart.applyPairwiseTransform(pairwise::Add, &rightPart, output, + // nullptr); +} - bool cell_contains(NDArray* corner, NDArray* width, NDArray* point, Nd4jLong dimension) { - auto cornerMinusWidth = *corner - *width; - auto cornerPlusWidth = *corner + *width; +bool cell_contains(NDArray* corner, NDArray* width, NDArray* point, + Nd4jLong dimension) { + auto cornerMinusWidth = *corner - *width; + auto cornerPlusWidth = *corner + *width; - for (Nd4jLong i = 0; i < dimension; i++) { - if (cornerMinusWidth.e(i) > point->e(i)) - return false; - if (cornerPlusWidth.e(i) < point->e(i)) - return false; - } + for (Nd4jLong i = 0; i < dimension; i++) { + if (cornerMinusWidth.e(i) > point->e(i)) return false; + if (cornerPlusWidth.e(i) < point->e(i)) return false; + } - return true; - } -} + return true; } -} - +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/activations.cpp b/libnd4j/include/ops/declarable/helpers/cpu/activations.cpp index ccc4d676aeb4..3e6cf25f1f70 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/activations.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/activations.cpp @@ -19,226 +19,263 @@ // @author raver119@gmail.com // -#include +#include +#include #include +#include + #include -#include -#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { /////////////////////////////////////////////////////////////////// - template - void static _softMaxDerivForVector(sd::LaunchContext * context, const void *input, const Nd4jLong *inShapeInfo, void *output) { - - const T* inBuff = reinterpret_cast(input); - T* outBuff = reinterpret_cast(output); - - T max = -DataTypeUtils::max(); - T sum = 0.; - int length = shape::length(inShapeInfo); - - for (int i = 0; i < length; i++) { - const Nd4jLong offset = shape::getIndexOffset(i, inShapeInfo); - max = sd::math::nd4j_max(max, inBuff[offset]); - } - - for (int i = 0; i < length; i++) { - const Nd4jLong offset = shape::getIndexOffset(i, inShapeInfo); - outBuff[offset] = sd::math::nd4j_exp(inBuff[offset] - max); - sum += outBuff[offset]; - } - - for (int i = 0; i < length; i++) { - const Nd4jLong offset = shape::getIndexOffset(i, inShapeInfo); - outBuff[offset] /= sum; - outBuff[offset] *= (1.f - outBuff[offset]); // derivative - } - } +template +void static _softMaxDerivForVector(sd::LaunchContext* context, + const void* input, + const Nd4jLong* inShapeInfo, void* output) { + const T* inBuff = reinterpret_cast(input); + T* outBuff = reinterpret_cast(output); + + T max = -DataTypeUtils::max(); + T sum = 0.; + int length = shape::length(inShapeInfo); + + for (int i = 0; i < length; i++) { + const Nd4jLong offset = shape::getIndexOffset(i, inShapeInfo); + max = sd::math::nd4j_max(max, inBuff[offset]); + } + + for (int i = 0; i < length; i++) { + const Nd4jLong offset = shape::getIndexOffset(i, inShapeInfo); + outBuff[offset] = sd::math::nd4j_exp(inBuff[offset] - max); + sum += outBuff[offset]; + } + + for (int i = 0; i < length; i++) { + const Nd4jLong offset = shape::getIndexOffset(i, inShapeInfo); + outBuff[offset] /= sum; + outBuff[offset] *= (1.f - outBuff[offset]); // derivative + } +} /////////////////////////////////////////////////////////////////// - void softmaxDerivative(sd::LaunchContext * context, const NDArray& input, NDArray& output, const int dimension) { - - const int rank = input.rankOf(); - int temp; - - if(shape::isCommonVector(input.shapeInfo(), temp)) { - - BUILD_SINGLE_SELECTOR(input.dataType(), _softMaxDerivForVector, (context, input.buffer(), input.shapeInfo(), output.buffer()), FLOAT_TYPES); - } - else { - auto maxAlongDim = const_cast(input).reduceAlongDimension(reduce::Max, {dimension}, true); - (input - maxAlongDim).applyTransform(transform::Exp, output); // output contains exponents temporarily - auto sumAlongDim = output.reduceAlongDimension(reduce::Sum, {dimension}, true); - output /= sumAlongDim; - output *= (1.f - output); // derivative - } - } +void softmaxDerivative(sd::LaunchContext* context, const NDArray& input, + NDArray& output, const int dimension) { + const int rank = input.rankOf(); + int temp; + + if (shape::isCommonVector(input.shapeInfo(), temp)) { + BUILD_SINGLE_SELECTOR( + input.dataType(), _softMaxDerivForVector, + (context, input.buffer(), input.shapeInfo(), output.buffer()), + FLOAT_TYPES); + } else { + auto maxAlongDim = const_cast(input).reduceAlongDimension( + reduce::Max, {dimension}, true); + (input - maxAlongDim) + .applyTransform(transform::Exp, + output); // output contains exponents temporarily + auto sumAlongDim = + output.reduceAlongDimension(reduce::Sum, {dimension}, true); + output /= sumAlongDim; + output *= (1.f - output); // derivative + } +} /////////////////////////////////////////////////////////////////// - template - void logSoftMaxForVector_(void const* input, Nd4jLong const* inShapeInfo, void *output, Nd4jLong const* outShapeInfo) { - auto inBuff = reinterpret_cast(input); - auto outBuff = reinterpret_cast(output); - - T max = -DataTypeUtils::max(); - T sum = 0; - - auto inEWS = shape::elementWiseStride(inShapeInfo); - auto length = shape::length(inShapeInfo); - - if (inEWS == 1) { - for (Nd4jLong i = 0; i < length; i++) - max = sd::math::nd4j_max(max, inBuff[i]); - - PRAGMA_OMP_SIMD_SUM(sum) - for (Nd4jLong i = 0; i < length; i++) { - outBuff[i] = sd::math::nd4j_exp(inBuff[i] - max); - sum += outBuff[i]; - } - - PRAGMA_OMP_SIMD - for (Nd4jLong i = 0; i < length; i++) { - outBuff[i] /= sum; - outBuff[i] = sd::math::nd4j_log(outBuff[i]); - } - } - else if (inEWS > 1) { - - PRAGMA_OMP_SIMD_MAX(max) - for (Nd4jLong i = 0; i < length; i++) - max = sd::math::nd4j_max(max, inBuff[i * inEWS]); - - PRAGMA_OMP_SIMD_SUM(sum) - for (Nd4jLong i = 0; i < length; i++) { - outBuff[i * inEWS] = sd::math::nd4j_exp(inBuff[i * inEWS] - max); - sum += outBuff[i * inEWS]; - } - - PRAGMA_OMP_SIMD - for (Nd4jLong i = 0; i < length; i++) { - outBuff[i * inEWS] /= sum; - outBuff[i * inEWS] = sd::math::nd4j_log(outBuff[i * inEWS]); - } - } +template +void logSoftMaxForVector_(void const* input, Nd4jLong const* inShapeInfo, + void* output, Nd4jLong const* outShapeInfo) { + auto inBuff = reinterpret_cast(input); + auto outBuff = reinterpret_cast(output); + + T max = -DataTypeUtils::max(); + T sum = 0; + + auto inEWS = shape::elementWiseStride(inShapeInfo); + auto length = shape::length(inShapeInfo); + + if (inEWS == 1) { + for (Nd4jLong i = 0; i < length; i++) + max = sd::math::nd4j_max(max, inBuff[i]); + + PRAGMA_OMP_SIMD_SUM(sum) + for (Nd4jLong i = 0; i < length; i++) { + outBuff[i] = sd::math::nd4j_exp(inBuff[i] - max); + sum += outBuff[i]; } - /////////////////////////////////////////////////////////////////// - void logSoftMaxForVector(sd::LaunchContext* context, const NDArray& input, NDArray& output) { - - if(!input.isVector() || !output.isVector()) - throw std::runtime_error("ops::helpers::logSoftMaxForVector function input and output arrays must be vectors !"); - - auto xType = input.dataType(); - BUILD_SINGLE_SELECTOR(xType, logSoftMaxForVector_, (input.buffer(), input.shapeInfo(), output.buffer(), output.shapeInfo()), FLOAT_TYPES); + PRAGMA_OMP_SIMD + for (Nd4jLong i = 0; i < length; i++) { + outBuff[i] /= sum; + outBuff[i] = sd::math::nd4j_log(outBuff[i]); + } + } else if (inEWS > 1) { + PRAGMA_OMP_SIMD_MAX(max) + for (Nd4jLong i = 0; i < length; i++) + max = sd::math::nd4j_max(max, inBuff[i * inEWS]); + + PRAGMA_OMP_SIMD_SUM(sum) + for (Nd4jLong i = 0; i < length; i++) { + outBuff[i * inEWS] = sd::math::nd4j_exp(inBuff[i * inEWS] - max); + sum += outBuff[i * inEWS]; } + PRAGMA_OMP_SIMD + for (Nd4jLong i = 0; i < length; i++) { + outBuff[i * inEWS] /= sum; + outBuff[i * inEWS] = sd::math::nd4j_log(outBuff[i * inEWS]); + } + } +} -////////////////////////////////////////////////////////////////////////// -void prelu(sd::LaunchContext * context, const NDArray& input, const NDArray& alpha, NDArray& output) { - const Nd4jLong inputLen = input.lengthOf(); - const Nd4jLong* inputShapeInfo = input.shapeInfo(); - const Nd4jLong* alphaShapeInfo = alpha.shapeInfo(); - - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - // FIXME: double! - double x = input.e(i); - if (x < 0.0) { - // FIXME: double - output.p(i, (x * alpha.e(shape::subArrayIndex(i, inputShapeInfo, alphaShapeInfo)))); - } else - output.p(i, x); - } - }; - - samediff::Threads::parallel_for(func, 0, inputLen); +/////////////////////////////////////////////////////////////////// +void logSoftMaxForVector(sd::LaunchContext* context, const NDArray& input, + NDArray& output) { + if (!input.isVector() || !output.isVector()) + throw std::runtime_error( + "ops::helpers::logSoftMaxForVector function input and output arrays " + "must be vectors !"); + + auto xType = input.dataType(); + BUILD_SINGLE_SELECTOR( + xType, logSoftMaxForVector_, + (input.buffer(), input.shapeInfo(), output.buffer(), output.shapeInfo()), + FLOAT_TYPES); } ////////////////////////////////////////////////////////////////////////// -void preluBP(sd::LaunchContext * context, const NDArray& input, const NDArray& alpha, const NDArray& dLdO, NDArray& dLdI, NDArray& dLdA) { - - const Nd4jLong inputLen = input.lengthOf(); - const Nd4jLong* inputShapeInfo = input.shapeInfo(); - const Nd4jLong* alphaShapeInfo = alpha.shapeInfo(); - - dLdA.assign(0.0f); - - for(Nd4jLong i = 0; i < inputLen; ++i) { +void prelu(sd::LaunchContext* context, const NDArray& input, + const NDArray& alpha, NDArray& output) { + const Nd4jLong inputLen = input.lengthOf(); + const Nd4jLong* inputShapeInfo = input.shapeInfo(); + const Nd4jLong* alphaShapeInfo = alpha.shapeInfo(); + + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + // FIXME: double! + double x = input.e(i); + if (x < 0.0) { // FIXME: double - double x = input.e(i); - double grO = dLdO.e(i); - if(x < 0.0) { - Nd4jLong alphaInd = shape::subArrayIndex(i, inputShapeInfo, alphaShapeInfo); - dLdI.p(i, grO * alpha.e(alphaInd)); - double prevVal = dLdA.e(alphaInd); - prevVal += (grO * x); - dLdA.p(alphaInd, prevVal); - } - else - dLdI.p(i, grO); - } -} - - - bool checkAlphaShapeLen(std::vector const& expectedShape, Nd4jLong shapeLen) { - Nd4jLong expectedAlphaLen = std::accumulate(expectedShape.cbegin(), expectedShape.cend(), 1, std::multiplies()); - return expectedAlphaLen == shapeLen; - } - template - static void thresholdRelu_(NDArray const& input, double threshold, NDArray& output) { - auto routine = LAMBDA_T(_x, threshold) { - return _x > (T)threshold? _x: (T)0.f; - }; - const_cast(input).applyLambda(routine, output); - } - - void thresholdRelu(sd::LaunchContext * context, NDArray const& input, double threshold, NDArray& output) { - BUILD_SINGLE_SELECTOR(input.dataType(), thresholdRelu_, (input, threshold, output), FLOAT_TYPES); - } - - template - static void thresholdReluDerivative_(sd::LaunchContext * context, NDArray* input, double theta, NDArray* dLdO, NDArray* output) { - auto derivative = LAMBDA_TT(_x, grO, theta) {if (_x > theta) return grO; else return static_cast(0); }; - - input->applyPairwiseLambda(*dLdO, derivative, *output); - + output.p(i, (x * alpha.e(shape::subArrayIndex( + i, inputShapeInfo, alphaShapeInfo)))); + } else + output.p(i, x); } + }; - void thresholdReluDerivative(sd::LaunchContext * context, NDArray* input, double threshold, NDArray* dLdO, NDArray* output) { - BUILD_SINGLE_SELECTOR(input->dataType(), thresholdReluDerivative_, (context, input, threshold, dLdO, output), FLOAT_TYPES); - } - - /////////////////////////////////////////////////////////////////// - void logSoftmax(sd::LaunchContext * context, const NDArray& input, NDArray& output, const int dimension) { - - const int rank = input.rankOf(); - - if(input.isVector()) { + samediff::Threads::parallel_for(func, 0, inputLen); +} - if(rank == 1 || input.sizeAt(dimension) != 1) { - BUILD_SINGLE_SELECTOR(input.dataType(), logSoftMaxForVector_, (input.buffer(), input.shapeInfo(), output.buffer(), output.shapeInfo()), FLOAT_TYPES); - } - else - output = 0.; - } - else { +////////////////////////////////////////////////////////////////////////// +void preluBP(sd::LaunchContext* context, const NDArray& input, + const NDArray& alpha, const NDArray& dLdO, NDArray& dLdI, + NDArray& dLdA) { + const Nd4jLong inputLen = input.lengthOf(); + const Nd4jLong* inputShapeInfo = input.shapeInfo(); + const Nd4jLong* alphaShapeInfo = alpha.shapeInfo(); + + dLdA.assign(0.0f); + + for (Nd4jLong i = 0; i < inputLen; ++i) { + // FIXME: double + double x = input.e(i); + double grO = dLdO.e(i); + if (x < 0.0) { + Nd4jLong alphaInd = + shape::subArrayIndex(i, inputShapeInfo, alphaShapeInfo); + dLdI.p(i, grO * alpha.e(alphaInd)); + double prevVal = dLdA.e(alphaInd); + prevVal += (grO * x); + dLdA.p(alphaInd, prevVal); + } else + dLdI.p(i, grO); + } +} - auto maxAlongDim = const_cast(input).reduceAlongDimension(reduce::Max, {dimension}, true); - (input - maxAlongDim).applyTransform(transform::Exp, output); // output contains exponents temporarily - auto sumAlongDim = output.reduceAlongDimension(reduce::Sum, {dimension}, true); - output /= sumAlongDim; - output.applyTransform(transform::Log, output); - } - } +bool checkAlphaShapeLen(std::vector const& expectedShape, + Nd4jLong shapeLen) { + Nd4jLong expectedAlphaLen = + std::accumulate(expectedShape.cbegin(), expectedShape.cend(), 1, + std::multiplies()); + return expectedAlphaLen == shapeLen; +} +template +static void thresholdRelu_(NDArray const& input, double threshold, + NDArray& output) { + auto routine = LAMBDA_T(_x, threshold) { + return _x > (T)threshold ? _x : (T)0.f; + }; + const_cast(input).applyLambda(routine, output); +} - BUILD_SINGLE_TEMPLATE(template void thresholdReluDerivative_, (sd::LaunchContext * context, NDArray* input, double threshold, NDArray* dLdO, NDArray* output), FLOAT_TYPES); - BUILD_SINGLE_TEMPLATE(template void logSoftMaxForVector_, (void const* input, Nd4jLong const* inShapeInfo, void *output, Nd4jLong const* outShapeInfo), FLOAT_TYPES); - BUILD_SINGLE_TEMPLATE(template void _softMaxDerivForVector, (sd::LaunchContext * context, const void *input, const Nd4jLong *inShapeInfo, void *output), FLOAT_TYPES); +void thresholdRelu(sd::LaunchContext* context, NDArray const& input, + double threshold, NDArray& output) { + BUILD_SINGLE_SELECTOR(input.dataType(), thresholdRelu_, + (input, threshold, output), FLOAT_TYPES); +} +template +static void thresholdReluDerivative_(sd::LaunchContext* context, NDArray* input, + double theta, NDArray* dLdO, + NDArray* output) { + auto derivative = LAMBDA_TT(_x, grO, theta) { + if (_x > theta) + return grO; + else + return static_cast(0); + }; + + input->applyPairwiseLambda(*dLdO, derivative, *output); } + +void thresholdReluDerivative(sd::LaunchContext* context, NDArray* input, + double threshold, NDArray* dLdO, NDArray* output) { + BUILD_SINGLE_SELECTOR(input->dataType(), thresholdReluDerivative_, + (context, input, threshold, dLdO, output), FLOAT_TYPES); } + +/////////////////////////////////////////////////////////////////// +void logSoftmax(sd::LaunchContext* context, const NDArray& input, + NDArray& output, const int dimension) { + const int rank = input.rankOf(); + + if (input.isVector()) { + if (rank == 1 || input.sizeAt(dimension) != 1) { + BUILD_SINGLE_SELECTOR(input.dataType(), logSoftMaxForVector_, + (input.buffer(), input.shapeInfo(), output.buffer(), + output.shapeInfo()), + FLOAT_TYPES); + } else + output = 0.; + } else { + auto maxAlongDim = const_cast(input).reduceAlongDimension( + reduce::Max, {dimension}, true); + (input - maxAlongDim) + .applyTransform(transform::Exp, + output); // output contains exponents temporarily + auto sumAlongDim = + output.reduceAlongDimension(reduce::Sum, {dimension}, true); + output /= sumAlongDim; + output.applyTransform(transform::Log, output); + } } +BUILD_SINGLE_TEMPLATE(template void thresholdReluDerivative_, + (sd::LaunchContext * context, NDArray* input, + double threshold, NDArray* dLdO, NDArray* output), + FLOAT_TYPES); +BUILD_SINGLE_TEMPLATE(template void logSoftMaxForVector_, + (void const* input, Nd4jLong const* inShapeInfo, + void* output, Nd4jLong const* outShapeInfo), + FLOAT_TYPES); +BUILD_SINGLE_TEMPLATE(template void _softMaxDerivForVector, + (sd::LaunchContext * context, const void* input, + const Nd4jLong* inShapeInfo, void* output), + FLOAT_TYPES); + +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/addBias.cpp b/libnd4j/include/ops/declarable/helpers/cpu/addBias.cpp index aa86ea041dcb..b1777e186d26 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/addBias.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/addBias.cpp @@ -15,634 +15,699 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - // - // @author Yurii Shyrma, created on 26.02.2018 - // - // - // @author AbdelRauf - // +// +// @author Yurii Shyrma, created on 26.02.2018 +// +// +// @author AbdelRauf +// -#include -#include -#include -#include -#include #include +#include #include #include -#if defined(__GNUC__) +#include +#include +#include +#include + +#if defined(__GNUC__) #define align32 __attribute__((aligned(32))) #elif defined(_MSC_VER) #define align32 __declspec(align(32)) #else -#define align32 -#endif +#define align32 +#endif namespace sd { - namespace ops { - namespace helpers { - - template - static FORCEINLINE void _add(const T* __restrict xx, const T* __restrict yy, T* __restrict zz, const size_t& N) { - PRAGMA_OMP_SIMD - for (size_t c = 0; c < N; c++) - zz[c] = xx[c] + yy[c]; - } - - template - static FORCEINLINE void _add_inplace(T* __restrict xx, const T* __restrict yy, const size_t& N) { - PRAGMA_OMP_SIMD - for (size_t c = 0; c < N; c++) - xx[c] = xx[c] + yy[c]; - } - - template - static FORCEINLINE void _add_broadcast_inplace(T* __restrict xx, const T yy, const size_t& N) { - PRAGMA_OMP_SIMD - for (size_t c = 0; c < N; c++) - xx[c] = xx[c] + yy; - } - - template - static FORCEINLINE void _add_broadcast(const T* __restrict xx, const T yy, T* __restrict zz, const size_t& N) { - PRAGMA_OMP_SIMD - for (size_t c = 0; c < N; c++) - zz[c] = xx[c] + yy; - } - - static constexpr size_t MIN_NN = 32; - static constexpr size_t MIN_NN_K = 2; - - template - static typename std::enable_if::value, const X*>::type - flattened_bias(const Y* b_real, X* b_stack, const size_t b_stack_size, std::unique_ptr& b_heap, const Nd4jLong num, Nd4jLong yStrideC) - { - //best results when buffer used much , may result bad perf if buffer is used once - X* b_new = nullptr; - if (yStrideC != 1) { - if (num > b_stack_size) { - b_heap.reset(new X[num]); - b_new = b_heap.get(); - } - else { - b_new = b_stack; - } - for (size_t i = 0; i < num; i++) { - b_new[i] = b_real[i * yStrideC]; - } - } - else { - //no need , just pass normal bias - return static_cast(b_real); - } - return const_cast(b_new); - } - - template - static typename std::enable_if::value, const X*>::type - flattened_bias(const Y* b_real, X* b_stack, const size_t b_stack_size, std::unique_ptr& b_heap, const Nd4jLong num, Nd4jLong yStrideC) - { - //best results when buffer used much , may result bad perf if buffer is used once - X* b_new = nullptr; - if (num > b_stack_size) { - b_heap.reset(new X[num]); - b_new = b_heap.get(); - } - else { - b_new = b_stack; - } - if (yStrideC != 1) { - for (size_t i = 0; i < num; i++) { - b_new[i] = static_cast(b_real[i * yStrideC]); - } - } - else { - for (size_t i = 0; i < num; i++) { - b_new[i] = static_cast(b_real[i]); - } - } - return const_cast(b_new); - } - - template - static void channel_atTheEnd_stride1_C(const Nd4jLong*& x_strides, const Nd4jLong*& bases, T* x, const T* b, T* z, const bool& inplace, const Nd4jLong& start, const Nd4jLong& stop, const Nd4jLong& inc) - { - size_t loop_count = (stop - start) / inc; - sd::CoordsState cst; - size_t offset = sd::init_coords(cst, start, bases, x_strides); - - if (!inplace) { - for (size_t i = 0; i < loop_count; i++) { - _add(&(x[offset]), b, &(z[offset]), inc); - offset = sd::inc_coords(cst, offset); - } - } - else { - for (size_t i = 0; i < loop_count; i++) { - _add_inplace(&(x[offset]), b, inc); - offset = sd::inc_coords(cst, offset); - } - } - } - - - template - static void channel_atTheEnd_generic_C(const Nd4jLong* bases, const Nd4jLong* x_strides, const Nd4jLong* z_strides, const bool& inplaceOp, const bool same_stride, const bool same_order, T* x, const T* b, T* z, Nd4jLong start, Nd4jLong stop, Nd4jLong inc) { - - //just ensure that passed sameStride is correct, because when bases are equal orders matters - bool sameOrderStride = same_order && same_stride; - if (sameOrderStride && x_strides[constRank - 1] == 1) { - channel_atTheEnd_stride1_C(x_strides, bases, x, b, z, inplaceOp, start, stop, inc); - } - else { - size_t loop_count = (stop - start) / inc; - sd::ZipCoordsState cst; - sd::zip_size_t offset = sd::init_coords(cst, start, bases, x_strides, z_strides); - Nd4jLong x_stride = ZIP_STRIDE1(cst, constRank - 1); - Nd4jLong z_stride = ZIP_STRIDE2(cst, constRank - 1); - - if (same_order && x_stride == 1 && z_stride == 1) { - /* bases are equal with different strides , but the last one is 1. So we can still vectorize it */ - for (size_t i = 0; i < loop_count; i++) { - _add(&(x[offset.first]), b, &(z[offset.second]), inc); - offset = sd::inc_coords(cst, offset); - } - } - else { - for (size_t i = 0; i < loop_count; i++) { - T* xx = &(x[offset.first]); - T* zz = &(z[offset.second]); - for (size_t j = 0; j < inc; j++) - zz[j * z_stride] = xx[j * x_stride] + b[j]; - offset = sd::inc_coords(cst, offset); - } - } - } - - } - - /** - * this is our main optimization which benefits from everything for the continuous last_channel C order case - * as it is intended for full continous we do not need any rank info - */ - template - void channel_atTheEnd_continous_C(T* x, const T* b, T* z, bool inplaceOp, Nd4jLong start, Nd4jLong stop, Nd4jLong inc) { - size_t nums = (stop - start); - size_t num_inc = nums - nums % inc; - if (inplaceOp) { - - size_t offset_p = start; - for (size_t i = 0; i < num_inc; i += inc) { - _add_inplace(&(x[offset_p]), b, inc); - offset_p += inc; - } - if (nums > num_inc) - _add_inplace(&(x[offset_p]), b, nums - num_inc); - } - else { - size_t offset_p = start; - for (size_t i = 0; i < num_inc; i += inc) { - _add(&(x[offset_p]), b, &(z[offset_p]), inc); - offset_p += inc; - } - if (nums > num_inc) - _add(&(x[offset_p]), b, &(z[offset_p]), nums - num_inc); - } - } - - template - static void channel_NC_stride1_C(const Nd4jLong*& x_strides, const Nd4jLong*& bases, T* x, const T2* b, T* z, const bool& inplace, const Nd4jLong yStrideC, const Nd4jLong& start, const Nd4jLong& stop, const Nd4jLong& inc) - { - size_t loop_count = (stop - start) / inc; - sd::CoordsState cst; - size_t offset = sd::init_coords(cst, start, bases, x_strides); - - if (!inplace) { - for (size_t i = 0; i < loop_count; i++) { - T yy = static_cast(b[COORDS(cst, 1) * yStrideC]); - _add_broadcast(&(x[offset]), yy, &(z[offset]), inc); - offset = sd::inc_coords(cst, offset); - } - } - else { - for (size_t i = 0; i < loop_count; i++) { - T yy = static_cast(b[COORDS(cst, 1) * yStrideC]); - _add_broadcast_inplace(&(x[offset]), yy, inc); - offset = sd::inc_coords(cst, offset); - } - } - } - - template - void channel_NC_generic_C(const Nd4jLong* bases, const Nd4jLong* x_strides, const Nd4jLong* z_strides, const bool& inplaceOp, const bool same_stride, const bool same_order, const Nd4jLong yStrideC, T* x, const T2* b, T* z, Nd4jLong start, Nd4jLong stop, Nd4jLong inc) { - - //just ensure that passed sameStride is correct, because when bases are equal orders matters - - bool sameOrderStride = same_order && same_stride; - - if (sameOrderStride && x_strides[constRank - 1] == 1) { - channel_NC_stride1_C(x_strides, bases, x, b, z, inplaceOp, yStrideC, start, stop, inc); - } - else { - - // (stop-start) % inc == 0 because we handled inside partitioning using the channel size - size_t loop_count = (stop - start) / inc; - sd::ZipCoordsState cst; - sd::zip_size_t offset = sd::init_coords(cst, start, bases, x_strides, z_strides); - Nd4jLong x_stride = ZIP_STRIDE1(cst, constRank - 1); - Nd4jLong z_stride = ZIP_STRIDE2(cst, constRank - 1); - if (same_order && z_stride == 1 && x_stride == 1) { - /* bases are equal with different strides , but the last one is 1. So we can still vectorize it */ - for (size_t i = 0; i < loop_count; i++) { - T yy = static_cast(b[ZIP_COORDS(cst, 1) * yStrideC]); - _add_broadcast(&(x[offset.first]), yy, &(z[offset.second]), inc); - offset = sd::inc_coords(cst, offset); - } - } - else { - for (size_t i = 0; i < loop_count; i++) { - T* xx = &(x[offset.first]); - T* zz = &(z[offset.second]); - T yy = static_cast(b[ZIP_COORDS(cst, 1) * yStrideC]); - for (size_t j = 0; j < inc; j++) - zz[j * z_stride] = xx[j * x_stride] + yy; - offset = sd::inc_coords(cst, offset); - } - } - } - } - - /// - template - void channel_NC_continous_numHW_C(Nd4jLong rank, const Nd4jLong* bases, const Nd4jLong* x_strides, T* x, const T2* b, T* z, bool inplaceOp, const Nd4jLong yStrideC, Nd4jLong start, Nd4jLong stop, Nd4jLong inc) { - - // (stop-start) % inc == 0 because we handled inside partitioning using the channel size - size_t loop_count = (stop - start) / inc; - - sd::CoordsState<1> cst; - //note: we had to manually pass index - size_t offset_p = sd::init_coords<2>(cst, start / inc, bases, x_strides); - - //partitioning was done using numHW, so we can increment from rank 2 - if (inplaceOp) { - for (size_t i = 0; i < loop_count; i++) { - T yy = static_cast(b[COORDS(cst, 1) * yStrideC]); - _add_broadcast_inplace(&(x[offset_p]), yy, inc); - offset_p = sd::inc_coords<2>(cst, offset_p); - } - } - else { - if (yStrideC == 1) { - for (size_t i = 0; i < loop_count; i++) { - T yy = static_cast(b[COORDS(cst, 1)]); - _add_broadcast(&(x[offset_p]), yy, &(z[offset_p]), inc); - offset_p = sd::inc_coords<2>(cst, offset_p); - } - } - else { - for (size_t i = 0; i < loop_count; i++) { - T yy = static_cast(b[COORDS(cst, 1) * yStrideC]); - _add_broadcast(&(x[offset_p]), yy, &(z[offset_p]), inc); - offset_p = sd::inc_coords<2>(cst, offset_p); - } - } - } - } - - // - template - static void channel_generic_stride_skip_F(const Nd4jLong*& x_strides, const Nd4jLong*& bases, T* x, const T2* b, T* z, const bool& inplace, const Nd4jLong yStrideC, const Nd4jLong& start, const Nd4jLong& stop, const Nd4jLong& inc) - { - // (stop-start) % inc == 0 because we handled inside partitioning using the channel size - size_t loop_count = (stop - start) / inc; - sd::CoordsState cst; - size_t offset_p = sd::init_coords(cst, start, bases, x_strides); - if (!inplace) { - for (size_t i = 0; i < loop_count; i++) { - T yy = static_cast(b[COORDS(cst, b_index) * yStrideC]); - _add_broadcast(&(x[offset_p]), yy, &(z[offset_p]), inc); - offset_p = sd::inc_coords(cst, offset_p); - } - } - else { - for (size_t i = 0; i < loop_count; i++) { - T yy = static_cast(b[COORDS(cst, b_index) * yStrideC]); - _add_broadcast_inplace(&(x[offset_p]), yy, inc); - offset_p = sd::inc_coords(cst, offset_p); - } - } - } - - /// - template - void channel_generic_F(const Nd4jLong* bases, const Nd4jLong* x_strides, const Nd4jLong* z_strides, const bool& inplaceOp, const bool same_stride, const bool same_order, const Nd4jLong yStrideC, T* x, const T2* b, T* z, Nd4jLong start, Nd4jLong stop, Nd4jLong inc) { - //just ensure that passed sameStride is correct, because when bases are equal orders matters - bool sameOrderStride = same_order && same_stride; - if (sameOrderStride && x_strides[0] == 1) { - channel_generic_stride_skip_F(x_strides, bases, x, b, z, inplaceOp, yStrideC, start, stop, inc); - } - else { - // (stop-start) % inc == 0 because we handled inside partitioning using the channel size - - size_t loop_count = (stop - start) / inc; - sd::ZipCoordsState cst; - sd::zip_size_t offset = sd::init_coords(cst, start, bases, x_strides, z_strides); - Nd4jLong x_stride = ZIP_STRIDE1(cst, 0); - Nd4jLong z_stride = ZIP_STRIDE2(cst, 0); - if (same_order && z_stride == 1 && x_stride == 1) { - - for (size_t i = 0; i < loop_count; i++) { - T yy = static_cast(b[ZIP_COORDS(cst, b_index) * yStrideC]); - _add_broadcast(&(x[offset.first]), yy, &(z[offset.second]), inc); - offset = sd::inc_coords(cst, offset); - } - } - else { - for (size_t i = 0; i < loop_count; i++) { - T* xx = &(x[offset.first]); - T* zz = &(z[offset.second]); - T yy = static_cast(b[ZIP_COORDS(cst, b_index) * yStrideC]); - for (size_t j = 0; j < inc; j++) - zz[j * z_stride] = xx[j * x_stride] + yy; - offset = sd::inc_coords(cst, offset); - } - } - } - } - - - template - static void addBias_(const NDArray& input, const NDArray& bias, NDArray& output, const bool isNCHW) { - /* - if (input.rankOf() == 2 && bias.rankOf() == 1 && input.sizeAt(1) == bias.sizeAt(0) && input.ordering() == 'c') { - int rows = input.sizeAt(0); - int biasLen = bias.lengthOf(); - - auto inB = input.bufferAsT(); - auto bB = bias.bufferAsT(); - auto outB = output.bufferAsT(); - - for (int e = 0; e < rows; e++) { - auto row = inB + (e * biasLen); - auto out = outB + (e * biasLen); - - for (int t = 0; t < biasLen; t++) { - out[t] = row[t] + bB[t]; - } - } - - return; - } - */ - - auto x_shapeInfo = input.shapeInfo(); - auto z_shapeInfo = output.shapeInfo(); - auto x = input.bufferAsT(); - auto z = output.bufferAsT(); - auto b = bias.bufferAsT(); - const Nd4jLong rank = x_shapeInfo[0]; - auto bases = &(x_shapeInfo[1]); - auto x_strides = &(x_shapeInfo[rank + 1]); - auto z_strides = &(z_shapeInfo[rank + 1]); - const bool inplaceOp = (x == z); - const bool same_order = inplaceOp || (input.ordering() == output.ordering()); - const bool channel_atTheEnd = !isNCHW; - const bool same_stride = inplaceOp || shape::strideEquals(x_shapeInfo, z_shapeInfo); - bool isContinuous = false; - int posOfNonUnityDim; - bias.isCommonVector(posOfNonUnityDim); - const Nd4jLong yStrideC = bias.strideAt(posOfNonUnityDim); - char order = input.ordering(); - - //for rank>5 - if (rank > 5) { - const int channelDim = isNCHW ? 1 : input.rankOf() - 1; // second or last - const_cast(input).applyBroadcast(sd::broadcast::Add, { channelDim }, bias, output); - return; - } - - if (same_order && same_stride) { - isContinuous = shape::elementWiseStride(x_shapeInfo) == 1 && shape::elementWiseStride(z_shapeInfo) == 1; - // check_continuity(order, bases, x_strides, rank); - }//if ( sameOrder && same_stride) - - bool treat_as_lastC = false; - // - if (rank == 2 && isNCHW) { - //we believe we better treat it as channel at the end case; - treat_as_lastC = true; - } - if (channel_atTheEnd || treat_as_lastC) { - //N..HWC case here - //flattened bias variables - constexpr size_t BSIZE1 = 3 * MIN_NN * MIN_NN; - constexpr size_t BSIZE2 = BSIZE1 + MIN_NN * MIN_NN; - X flatBias_stack[BSIZE2] align32; - std::unique_ptr flatBias_heap; - const X* bias_new; - X* bias_extra = nullptr; - size_t total_num = 1; - for (Nd4jLong i = 0; i < rank; i++) { - total_num *= bases[i]; - } - Nd4jLong inc; - size_t rank_skip = 1; - if (order == 'c') { - size_t b_stack_size = BSIZE2; - inc = bases[rank - 1]; - if (isContinuous) { - //for continous we need extra stack memory - // to create vectorizable bias from small size - b_stack_size = BSIZE1; - bias_extra = &(flatBias_stack[BSIZE1]); - } - bias_new = flattened_bias(b, (X*)flatBias_stack, b_stack_size, flatBias_heap, inc, yStrideC); - if (isContinuous && inc < MIN_NN_K * MIN_NN && total_num > inc * MIN_NN_K) { - //for small size where total_num is sufficient we need to recreate vectorizable buffer - size_t old_inc = inc; - //sizeof bias_extra is MIN_NN * MIN_NN - size_t new_inc = inc < MIN_NN ? inc * MIN_NN : inc * MIN_NN / MIN_NN_K; - //if there is a room then lets multiply - new_inc = (new_inc * MIN_NN_K <= total_num && new_inc < MIN_NN * MIN_NN / MIN_NN_K) ? MIN_NN_K * new_inc : new_inc; - for (size_t i = 0; i < new_inc; i += inc) { - //copy to our buffer - X* cp = &(bias_extra[i]); - for (size_t j = 0; j < inc; j++) { - cp[j] = bias_new[j]; - } - } - //vectorizable buffer - inc = new_inc; - bias_new = bias_extra; - } - } - else { - inc = bases[0]; - if (isContinuous) { - //we can choose other inc and index for that case - //but for now lets choose all till the last one - uint32_t req_numThreads = sd::Environment::getInstance().maxMasterThreads(); - isContinuous = false; - if (rank > 2) { - if (req_numThreads < 2 || bases[rank - 1] >= req_numThreads) { - inc = total_num / bases[rank - 1]; - isContinuous = true; - rank_skip = rank - 1; - } - else if (rank > 3 && bases[rank - 1] * bases[rank - 2] >= req_numThreads) { - inc = total_num / bases[rank - 1] / bases[rank - 2]; //for continuous case it is its stride - rank_skip = rank - 2; - isContinuous = true; - } - } - } - } - - FUNC_1D func = [order, isContinuous, rank, x, b, bias_new, z, x_shapeInfo, z_shapeInfo, same_stride, same_order, yStrideC, rank_skip] - (uint64_t thread_id, int64_t start, int64_t stop, int64_t increment) -> void { - const Nd4jLong rank = x_shapeInfo[0]; - auto bases = &(x_shapeInfo[1]); - auto x_strides = &(x_shapeInfo[rank + 1]); - auto z_strides = &(z_shapeInfo[rank + 1]); - const bool inplaceOp = (x == z); - if (order == 'c') { - if (isContinuous) { - channel_atTheEnd_continous_C(const_cast(x), bias_new, z, inplaceOp, start, stop, increment); - } - // rank is in [2,5] - else if (rank == 4) { - channel_atTheEnd_generic_C(bases, x_strides, z_strides, inplaceOp, same_stride, same_order, const_cast(x), bias_new, z, start, stop, increment); - - } - else if (rank == 5) { - channel_atTheEnd_generic_C(bases, x_strides, z_strides, inplaceOp, same_stride, same_order, const_cast(x), bias_new, z, start, stop, increment); - } - else if (rank == 2) { - channel_atTheEnd_generic_C(bases, x_strides, z_strides, inplaceOp, same_stride, same_order, const_cast(x), bias_new, z, start, stop, increment); - } - else if (rank == 3) { - channel_atTheEnd_generic_C(bases, x_strides, z_strides, inplaceOp, same_stride, same_order, const_cast(x), bias_new, z, start, stop, increment); - } - } - else { - //generic F case - if (isContinuous) { - if (rank == 4) { - if (rank_skip == rank - 2) { - channel_generic_stride_skip_F(x_strides, bases, const_cast(x), b, z, inplaceOp, yStrideC, start, stop, increment); - } - else { - channel_generic_stride_skip_F(x_strides, bases, const_cast(x), b, z, inplaceOp, yStrideC, start, stop, increment); - } - } - else if (rank == 5) { - if (rank_skip == rank - 2) { - //skip==3 - channel_generic_stride_skip_F(x_strides, bases, const_cast(x), b, z, inplaceOp, yStrideC, start, stop, increment); - } - else { - channel_generic_stride_skip_F(x_strides, bases, const_cast(x), b, z, inplaceOp, yStrideC, start, stop, increment); - } - } - else if (rank == 3) { - channel_generic_stride_skip_F(x_strides, bases, const_cast(x), b, z, inplaceOp, yStrideC, start, stop, increment); - } - } - else if (rank == 4) { - channel_generic_F(bases, x_strides, z_strides, inplaceOp, same_stride, same_order, yStrideC, const_cast(x), b, z, start, stop, increment); - } - else if (rank == 5) { - channel_generic_F(bases, x_strides, z_strides, inplaceOp, same_stride, same_order, yStrideC, const_cast(x), b, z, start, stop, increment); - } - else if (rank == 2) { - channel_generic_F(bases, x_strides, z_strides, inplaceOp, same_stride, same_order, yStrideC, const_cast(x), b, z, start, stop, increment); - } - else if (rank == 3) { - channel_generic_F(bases, x_strides, z_strides, inplaceOp, same_stride, same_order, yStrideC, const_cast(x), b, z, start, stop, increment); - } - - } - }; - // - samediff::Threads::parallel_aligned_increment(func, 0, total_num, inc); - } - else { - //NC...HW case here - size_t numNC = 1; - size_t numHW = 1; - for (size_t i = 0; i < 2; i++) { - numNC *= bases[i]; - } - for (Nd4jLong i = 2; i < rank; i++) { - numHW *= bases[i]; - } - Nd4jLong total_num = numNC * numHW; - Nd4jLong inc = (order == 'c') ? bases[rank - 1] : bases[0]; - if (order == 'c' && isContinuous) { - //sometimes last dimension is too big and multithreading could suffer using unfair partitioning - //so we will do it only when inc is smaller our value or multithreading turned off - uint32_t req_numThreads = sd::Environment::getInstance().maxMasterThreads(); - if (req_numThreads < 2 || numNC >= req_numThreads || inc <= 2 * 8196 || rank == 3) { - inc = numHW; - } - else { - //treat it as stride1c case - isContinuous = false; - } - } - FUNC_1D func = [order, isContinuous, rank, x, b, z, x_shapeInfo, z_shapeInfo, same_stride, same_order, yStrideC] - (uint64_t thread_id, int64_t start, int64_t stop, int64_t increment) -> void { - const Nd4jLong rank = x_shapeInfo[0]; - const Nd4jLong* bases = &(x_shapeInfo[1]); - const Nd4jLong* x_strides = &(x_shapeInfo[rank + 1]); - const Nd4jLong* z_strides = &(z_shapeInfo[rank + 1]); - const bool inplaceOp = (x == z); - if (order == 'c') { - if (isContinuous) { - channel_NC_continous_numHW_C(rank, bases, x_strides, const_cast(x), b, z, inplaceOp, yStrideC, start, stop, increment); - } - // rank is in [3,5] - else if (rank == 4) { - channel_NC_generic_C(bases, x_strides, z_strides, inplaceOp, same_stride, same_order, yStrideC, const_cast(x), b, z, start, stop, increment); - - } - else if (rank == 5) { - channel_NC_generic_C(bases, x_strides, z_strides, inplaceOp, same_stride, same_order, yStrideC, const_cast(x), b, z, start, stop, increment); - } - else if (rank == 3) { - channel_NC_generic_C(bases, x_strides, z_strides, inplaceOp, same_stride, same_order, yStrideC, const_cast(x), b, z, start, stop, increment); - } - } - else { - //the same can be applied for NCHW case - //generic F case - //continous case is missing - - if (rank == 4) { - channel_generic_F(bases, x_strides, z_strides, inplaceOp, same_stride, same_order, yStrideC, const_cast(x), b, z, start, stop, increment); - } - else if (rank == 5) { - channel_generic_F(bases, x_strides, z_strides, inplaceOp, same_stride, same_order, yStrideC, const_cast(x), b, z, start, stop, increment); - } - else if (rank == 3) { - channel_generic_F(bases, x_strides, z_strides, inplaceOp, same_stride, same_order, yStrideC, const_cast(x), b, z, start, stop, increment); - } - } - }; - // - samediff::Threads::parallel_aligned_increment(func, 0, total_num, inc); - } - } - ////////////////////////////////////////////////////////////////////////// - void addBias(sd::graph::Context& block, const NDArray& input, const NDArray& bias, NDArray& output, const bool isNCHW) { - - // bias.rankOf() == 1 ? bias : bias.reshape(bias.ordering(), {bias.lengthOf()}) - BUILD_DOUBLE_SELECTOR(input.dataType(), bias.dataType(), addBias_, (input, bias, output, isNCHW), FLOAT_TYPES, FLOAT_TYPES); - } - - - BUILD_DOUBLE_TEMPLATE(template void addBias_, (const NDArray& input, const NDArray& bias, NDArray& output, const bool isNCHW), FLOAT_TYPES, FLOAT_TYPES); - } - } +namespace ops { +namespace helpers { + +template +static FORCEINLINE void _add(const T* __restrict xx, const T* __restrict yy, + T* __restrict zz, const size_t& N) { + PRAGMA_OMP_SIMD + for (size_t c = 0; c < N; c++) zz[c] = xx[c] + yy[c]; +} + +template +static FORCEINLINE void _add_inplace(T* __restrict xx, const T* __restrict yy, + const size_t& N) { + PRAGMA_OMP_SIMD + for (size_t c = 0; c < N; c++) xx[c] = xx[c] + yy[c]; +} + +template +static FORCEINLINE void _add_broadcast_inplace(T* __restrict xx, const T yy, + const size_t& N) { + PRAGMA_OMP_SIMD + for (size_t c = 0; c < N; c++) xx[c] = xx[c] + yy; +} + +template +static FORCEINLINE void _add_broadcast(const T* __restrict xx, const T yy, + T* __restrict zz, const size_t& N) { + PRAGMA_OMP_SIMD + for (size_t c = 0; c < N; c++) zz[c] = xx[c] + yy; +} + +static constexpr size_t MIN_NN = 32; +static constexpr size_t MIN_NN_K = 2; + +template +static typename std::enable_if::value, const X*>::type +flattened_bias(const Y* b_real, X* b_stack, const size_t b_stack_size, + std::unique_ptr& b_heap, const Nd4jLong num, + Nd4jLong yStrideC) { + // best results when buffer used much , may result bad perf if buffer is used + // once + X* b_new = nullptr; + if (yStrideC != 1) { + if (num > b_stack_size) { + b_heap.reset(new X[num]); + b_new = b_heap.get(); + } else { + b_new = b_stack; + } + for (size_t i = 0; i < num; i++) { + b_new[i] = b_real[i * yStrideC]; + } + } else { + // no need , just pass normal bias + return static_cast(b_real); + } + return const_cast(b_new); +} + +template +static typename std::enable_if::value, const X*>::type +flattened_bias(const Y* b_real, X* b_stack, const size_t b_stack_size, + std::unique_ptr& b_heap, const Nd4jLong num, + Nd4jLong yStrideC) { + // best results when buffer used much , may result bad perf if buffer is used + // once + X* b_new = nullptr; + if (num > b_stack_size) { + b_heap.reset(new X[num]); + b_new = b_heap.get(); + } else { + b_new = b_stack; + } + if (yStrideC != 1) { + for (size_t i = 0; i < num; i++) { + b_new[i] = static_cast(b_real[i * yStrideC]); + } + } else { + for (size_t i = 0; i < num; i++) { + b_new[i] = static_cast(b_real[i]); + } + } + return const_cast(b_new); } + +template +static void channel_atTheEnd_stride1_C(const Nd4jLong*& x_strides, + const Nd4jLong*& bases, T* x, const T* b, + T* z, const bool& inplace, + const Nd4jLong& start, + const Nd4jLong& stop, + const Nd4jLong& inc) { + size_t loop_count = (stop - start) / inc; + sd::CoordsState cst; + size_t offset = sd::init_coords(cst, start, bases, x_strides); + + if (!inplace) { + for (size_t i = 0; i < loop_count; i++) { + _add(&(x[offset]), b, &(z[offset]), inc); + offset = sd::inc_coords(cst, offset); + } + } else { + for (size_t i = 0; i < loop_count; i++) { + _add_inplace(&(x[offset]), b, inc); + offset = sd::inc_coords(cst, offset); + } + } +} + +template +static void channel_atTheEnd_generic_C( + const Nd4jLong* bases, const Nd4jLong* x_strides, const Nd4jLong* z_strides, + const bool& inplaceOp, const bool same_stride, const bool same_order, T* x, + const T* b, T* z, Nd4jLong start, Nd4jLong stop, Nd4jLong inc) { + // just ensure that passed sameStride is correct, because when bases are + // equal orders matters + bool sameOrderStride = same_order && same_stride; + if (sameOrderStride && x_strides[constRank - 1] == 1) { + channel_atTheEnd_stride1_C(x_strides, bases, x, b, z, + inplaceOp, start, stop, inc); + } else { + size_t loop_count = (stop - start) / inc; + sd::ZipCoordsState cst; + sd::zip_size_t offset = + sd::init_coords(cst, start, bases, x_strides, z_strides); + Nd4jLong x_stride = ZIP_STRIDE1(cst, constRank - 1); + Nd4jLong z_stride = ZIP_STRIDE2(cst, constRank - 1); + + if (same_order && x_stride == 1 && z_stride == 1) { + /* bases are equal with different strides , but the last one is 1. So we + * can still vectorize it */ + for (size_t i = 0; i < loop_count; i++) { + _add(&(x[offset.first]), b, &(z[offset.second]), inc); + offset = sd::inc_coords(cst, offset); + } + } else { + for (size_t i = 0; i < loop_count; i++) { + T* xx = &(x[offset.first]); + T* zz = &(z[offset.second]); + for (size_t j = 0; j < inc; j++) + zz[j * z_stride] = xx[j * x_stride] + b[j]; + offset = sd::inc_coords(cst, offset); + } + } + } +} + +/** + * this is our main optimization which benefits from everything for the + * continuous last_channel C order case as it is intended for full continous we + * do not need any rank info + */ +template +void channel_atTheEnd_continous_C(T* x, const T* b, T* z, bool inplaceOp, + Nd4jLong start, Nd4jLong stop, Nd4jLong inc) { + size_t nums = (stop - start); + size_t num_inc = nums - nums % inc; + if (inplaceOp) { + size_t offset_p = start; + for (size_t i = 0; i < num_inc; i += inc) { + _add_inplace(&(x[offset_p]), b, inc); + offset_p += inc; + } + if (nums > num_inc) _add_inplace(&(x[offset_p]), b, nums - num_inc); + } else { + size_t offset_p = start; + for (size_t i = 0; i < num_inc; i += inc) { + _add(&(x[offset_p]), b, &(z[offset_p]), inc); + offset_p += inc; + } + if (nums > num_inc) + _add(&(x[offset_p]), b, &(z[offset_p]), nums - num_inc); + } +} + +template +static void channel_NC_stride1_C(const Nd4jLong*& x_strides, + const Nd4jLong*& bases, T* x, const T2* b, + T* z, const bool& inplace, + const Nd4jLong yStrideC, const Nd4jLong& start, + const Nd4jLong& stop, const Nd4jLong& inc) { + size_t loop_count = (stop - start) / inc; + sd::CoordsState cst; + size_t offset = sd::init_coords(cst, start, bases, x_strides); + + if (!inplace) { + for (size_t i = 0; i < loop_count; i++) { + T yy = static_cast(b[COORDS(cst, 1) * yStrideC]); + _add_broadcast(&(x[offset]), yy, &(z[offset]), inc); + offset = sd::inc_coords(cst, offset); + } + } else { + for (size_t i = 0; i < loop_count; i++) { + T yy = static_cast(b[COORDS(cst, 1) * yStrideC]); + _add_broadcast_inplace(&(x[offset]), yy, inc); + offset = sd::inc_coords(cst, offset); + } + } +} + +template +void channel_NC_generic_C(const Nd4jLong* bases, const Nd4jLong* x_strides, + const Nd4jLong* z_strides, const bool& inplaceOp, + const bool same_stride, const bool same_order, + const Nd4jLong yStrideC, T* x, const T2* b, T* z, + Nd4jLong start, Nd4jLong stop, Nd4jLong inc) { + // just ensure that passed sameStride is correct, because when bases are + // equal orders matters + + bool sameOrderStride = same_order && same_stride; + + if (sameOrderStride && x_strides[constRank - 1] == 1) { + channel_NC_stride1_C(x_strides, bases, x, b, z, inplaceOp, + yStrideC, start, stop, inc); + } else { + // (stop-start) % inc == 0 because we handled inside partitioning using + // the channel size + size_t loop_count = (stop - start) / inc; + sd::ZipCoordsState cst; + sd::zip_size_t offset = + sd::init_coords(cst, start, bases, x_strides, z_strides); + Nd4jLong x_stride = ZIP_STRIDE1(cst, constRank - 1); + Nd4jLong z_stride = ZIP_STRIDE2(cst, constRank - 1); + if (same_order && z_stride == 1 && x_stride == 1) { + /* bases are equal with different strides , but the last one is 1. So we + * can still vectorize it */ + for (size_t i = 0; i < loop_count; i++) { + T yy = static_cast(b[ZIP_COORDS(cst, 1) * yStrideC]); + _add_broadcast(&(x[offset.first]), yy, &(z[offset.second]), inc); + offset = sd::inc_coords(cst, offset); + } + } else { + for (size_t i = 0; i < loop_count; i++) { + T* xx = &(x[offset.first]); + T* zz = &(z[offset.second]); + T yy = static_cast(b[ZIP_COORDS(cst, 1) * yStrideC]); + for (size_t j = 0; j < inc; j++) + zz[j * z_stride] = xx[j * x_stride] + yy; + offset = sd::inc_coords(cst, offset); + } + } + } +} + +/// +template +void channel_NC_continous_numHW_C(Nd4jLong rank, const Nd4jLong* bases, + const Nd4jLong* x_strides, T* x, const T2* b, + T* z, bool inplaceOp, const Nd4jLong yStrideC, + Nd4jLong start, Nd4jLong stop, Nd4jLong inc) { + // (stop-start) % inc == 0 because we handled inside partitioning using the + // channel size + size_t loop_count = (stop - start) / inc; + + sd::CoordsState<1> cst; + // note: we had to manually pass index + size_t offset_p = sd::init_coords<2>(cst, start / inc, bases, x_strides); + + // partitioning was done using numHW, so we can increment from rank 2 + if (inplaceOp) { + for (size_t i = 0; i < loop_count; i++) { + T yy = static_cast(b[COORDS(cst, 1) * yStrideC]); + _add_broadcast_inplace(&(x[offset_p]), yy, inc); + offset_p = sd::inc_coords<2>(cst, offset_p); + } + } else { + if (yStrideC == 1) { + for (size_t i = 0; i < loop_count; i++) { + T yy = static_cast(b[COORDS(cst, 1)]); + _add_broadcast(&(x[offset_p]), yy, &(z[offset_p]), inc); + offset_p = sd::inc_coords<2>(cst, offset_p); + } + } else { + for (size_t i = 0; i < loop_count; i++) { + T yy = static_cast(b[COORDS(cst, 1) * yStrideC]); + _add_broadcast(&(x[offset_p]), yy, &(z[offset_p]), inc); + offset_p = sd::inc_coords<2>(cst, offset_p); + } + } + } +} + +// +template +static void channel_generic_stride_skip_F( + const Nd4jLong*& x_strides, const Nd4jLong*& bases, T* x, const T2* b, T* z, + const bool& inplace, const Nd4jLong yStrideC, const Nd4jLong& start, + const Nd4jLong& stop, const Nd4jLong& inc) { + // (stop-start) % inc == 0 because we handled inside partitioning using the + // channel size + size_t loop_count = (stop - start) / inc; + sd::CoordsState cst; + size_t offset_p = + sd::init_coords(cst, start, bases, x_strides); + if (!inplace) { + for (size_t i = 0; i < loop_count; i++) { + T yy = static_cast(b[COORDS(cst, b_index) * yStrideC]); + _add_broadcast(&(x[offset_p]), yy, &(z[offset_p]), inc); + offset_p = sd::inc_coords(cst, offset_p); + } + } else { + for (size_t i = 0; i < loop_count; i++) { + T yy = static_cast(b[COORDS(cst, b_index) * yStrideC]); + _add_broadcast_inplace(&(x[offset_p]), yy, inc); + offset_p = sd::inc_coords(cst, offset_p); + } + } +} + +/// +template +void channel_generic_F(const Nd4jLong* bases, const Nd4jLong* x_strides, + const Nd4jLong* z_strides, const bool& inplaceOp, + const bool same_stride, const bool same_order, + const Nd4jLong yStrideC, T* x, const T2* b, T* z, + Nd4jLong start, Nd4jLong stop, Nd4jLong inc) { + // just ensure that passed sameStride is correct, because when bases are + // equal orders matters + bool sameOrderStride = same_order && same_stride; + if (sameOrderStride && x_strides[0] == 1) { + channel_generic_stride_skip_F( + x_strides, bases, x, b, z, inplaceOp, yStrideC, start, stop, inc); + } else { + // (stop-start) % inc == 0 because we handled inside partitioning using + // the channel size + + size_t loop_count = (stop - start) / inc; + sd::ZipCoordsState cst; + sd::zip_size_t offset = sd::init_coords( + cst, start, bases, x_strides, z_strides); + Nd4jLong x_stride = ZIP_STRIDE1(cst, 0); + Nd4jLong z_stride = ZIP_STRIDE2(cst, 0); + if (same_order && z_stride == 1 && x_stride == 1) { + for (size_t i = 0; i < loop_count; i++) { + T yy = static_cast(b[ZIP_COORDS(cst, b_index) * yStrideC]); + _add_broadcast(&(x[offset.first]), yy, &(z[offset.second]), inc); + offset = sd::inc_coords(cst, offset); + } + } else { + for (size_t i = 0; i < loop_count; i++) { + T* xx = &(x[offset.first]); + T* zz = &(z[offset.second]); + T yy = static_cast(b[ZIP_COORDS(cst, b_index) * yStrideC]); + for (size_t j = 0; j < inc; j++) + zz[j * z_stride] = xx[j * x_stride] + yy; + offset = sd::inc_coords(cst, offset); + } + } + } +} + +template +static void addBias_(const NDArray& input, const NDArray& bias, NDArray& output, + const bool isNCHW) { + /* + if (input.rankOf() == 2 && bias.rankOf() == 1 && input.sizeAt(1) == +bias.sizeAt(0) && input.ordering() == 'c') { int rows = input.sizeAt(0); int +biasLen = bias.lengthOf(); + +auto inB = input.bufferAsT(); +auto bB = bias.bufferAsT(); +auto outB = output.bufferAsT(); + + for (int e = 0; e < rows; e++) { + auto row = inB + (e * biasLen); +auto out = outB + (e * biasLen); + + for (int t = 0; t < biasLen; t++) { + out[t] = row[t] + bB[t]; + } + } + +return; + } + */ + + auto x_shapeInfo = input.shapeInfo(); + auto z_shapeInfo = output.shapeInfo(); + auto x = input.bufferAsT(); + auto z = output.bufferAsT(); + auto b = bias.bufferAsT(); + const Nd4jLong rank = x_shapeInfo[0]; + auto bases = &(x_shapeInfo[1]); + auto x_strides = &(x_shapeInfo[rank + 1]); + auto z_strides = &(z_shapeInfo[rank + 1]); + const bool inplaceOp = (x == z); + const bool same_order = inplaceOp || (input.ordering() == output.ordering()); + const bool channel_atTheEnd = !isNCHW; + const bool same_stride = + inplaceOp || shape::strideEquals(x_shapeInfo, z_shapeInfo); + bool isContinuous = false; + int posOfNonUnityDim; + bias.isCommonVector(posOfNonUnityDim); + const Nd4jLong yStrideC = bias.strideAt(posOfNonUnityDim); + char order = input.ordering(); + + // for rank>5 + if (rank > 5) { + const int channelDim = isNCHW ? 1 : input.rankOf() - 1; // second or last + const_cast(input).applyBroadcast(sd::broadcast::Add, {channelDim}, + bias, output); + return; + } + + if (same_order && same_stride) { + isContinuous = shape::elementWiseStride(x_shapeInfo) == 1 && + shape::elementWiseStride(z_shapeInfo) == 1; + // check_continuity(order, bases, x_strides, rank); + } // if ( sameOrder && same_stride) + + bool treat_as_lastC = false; + // + if (rank == 2 && isNCHW) { + // we believe we better treat it as channel at the end case; + treat_as_lastC = true; + } + if (channel_atTheEnd || treat_as_lastC) { + // N..HWC case here + // flattened bias variables + constexpr size_t BSIZE1 = 3 * MIN_NN * MIN_NN; + constexpr size_t BSIZE2 = BSIZE1 + MIN_NN * MIN_NN; + X flatBias_stack[BSIZE2] align32; + std::unique_ptr flatBias_heap; + const X* bias_new; + X* bias_extra = nullptr; + size_t total_num = 1; + for (Nd4jLong i = 0; i < rank; i++) { + total_num *= bases[i]; + } + Nd4jLong inc; + size_t rank_skip = 1; + if (order == 'c') { + size_t b_stack_size = BSIZE2; + inc = bases[rank - 1]; + if (isContinuous) { + // for continous we need extra stack memory + // to create vectorizable bias from small size + b_stack_size = BSIZE1; + bias_extra = &(flatBias_stack[BSIZE1]); + } + bias_new = flattened_bias(b, (X*)flatBias_stack, b_stack_size, + flatBias_heap, inc, yStrideC); + if (isContinuous && inc < MIN_NN_K * MIN_NN && + total_num > inc * MIN_NN_K) { + // for small size where total_num is sufficient we need to recreate + // vectorizable buffer + size_t old_inc = inc; + // sizeof bias_extra is MIN_NN * MIN_NN + size_t new_inc = inc < MIN_NN ? inc * MIN_NN : inc * MIN_NN / MIN_NN_K; + // if there is a room then lets multiply + new_inc = (new_inc * MIN_NN_K <= total_num && + new_inc < MIN_NN * MIN_NN / MIN_NN_K) + ? MIN_NN_K * new_inc + : new_inc; + for (size_t i = 0; i < new_inc; i += inc) { + // copy to our buffer + X* cp = &(bias_extra[i]); + for (size_t j = 0; j < inc; j++) { + cp[j] = bias_new[j]; + } + } + // vectorizable buffer + inc = new_inc; + bias_new = bias_extra; + } + } else { + inc = bases[0]; + if (isContinuous) { + // we can choose other inc and index for that case + // but for now lets choose all till the last one + uint32_t req_numThreads = + sd::Environment::getInstance().maxMasterThreads(); + isContinuous = false; + if (rank > 2) { + if (req_numThreads < 2 || bases[rank - 1] >= req_numThreads) { + inc = total_num / bases[rank - 1]; + isContinuous = true; + rank_skip = rank - 1; + } else if (rank > 3 && + bases[rank - 1] * bases[rank - 2] >= req_numThreads) { + inc = total_num / bases[rank - 1] / + bases[rank - 2]; // for continuous case it is its stride + rank_skip = rank - 2; + isContinuous = true; + } + } + } + } + + FUNC_1D func = [order, isContinuous, rank, x, b, bias_new, z, x_shapeInfo, + z_shapeInfo, same_stride, same_order, yStrideC, + rank_skip](uint64_t thread_id, int64_t start, int64_t stop, + int64_t increment) -> void { + const Nd4jLong rank = x_shapeInfo[0]; + auto bases = &(x_shapeInfo[1]); + auto x_strides = &(x_shapeInfo[rank + 1]); + auto z_strides = &(z_shapeInfo[rank + 1]); + const bool inplaceOp = (x == z); + if (order == 'c') { + if (isContinuous) { + channel_atTheEnd_continous_C(const_cast(x), bias_new, z, + inplaceOp, start, stop, increment); + } + // rank is in [2,5] + else if (rank == 4) { + channel_atTheEnd_generic_C( + bases, x_strides, z_strides, inplaceOp, same_stride, same_order, + const_cast(x), bias_new, z, start, stop, increment); + + } else if (rank == 5) { + channel_atTheEnd_generic_C( + bases, x_strides, z_strides, inplaceOp, same_stride, same_order, + const_cast(x), bias_new, z, start, stop, increment); + } else if (rank == 2) { + channel_atTheEnd_generic_C( + bases, x_strides, z_strides, inplaceOp, same_stride, same_order, + const_cast(x), bias_new, z, start, stop, increment); + } else if (rank == 3) { + channel_atTheEnd_generic_C( + bases, x_strides, z_strides, inplaceOp, same_stride, same_order, + const_cast(x), bias_new, z, start, stop, increment); + } + } else { + // generic F case + if (isContinuous) { + if (rank == 4) { + if (rank_skip == rank - 2) { + channel_generic_stride_skip_F( + x_strides, bases, const_cast(x), b, z, inplaceOp, + yStrideC, start, stop, increment); + } else { + channel_generic_stride_skip_F( + x_strides, bases, const_cast(x), b, z, inplaceOp, + yStrideC, start, stop, increment); + } + } else if (rank == 5) { + if (rank_skip == rank - 2) { + // skip==3 + channel_generic_stride_skip_F( + x_strides, bases, const_cast(x), b, z, inplaceOp, + yStrideC, start, stop, increment); + } else { + channel_generic_stride_skip_F( + x_strides, bases, const_cast(x), b, z, inplaceOp, + yStrideC, start, stop, increment); + } + } else if (rank == 3) { + channel_generic_stride_skip_F( + x_strides, bases, const_cast(x), b, z, inplaceOp, yStrideC, + start, stop, increment); + } + } else if (rank == 4) { + channel_generic_F( + bases, x_strides, z_strides, inplaceOp, same_stride, same_order, + yStrideC, const_cast(x), b, z, start, stop, increment); + } else if (rank == 5) { + channel_generic_F( + bases, x_strides, z_strides, inplaceOp, same_stride, same_order, + yStrideC, const_cast(x), b, z, start, stop, increment); + } else if (rank == 2) { + channel_generic_F( + bases, x_strides, z_strides, inplaceOp, same_stride, same_order, + yStrideC, const_cast(x), b, z, start, stop, increment); + } else if (rank == 3) { + channel_generic_F( + bases, x_strides, z_strides, inplaceOp, same_stride, same_order, + yStrideC, const_cast(x), b, z, start, stop, increment); + } + } + }; + // + samediff::Threads::parallel_aligned_increment(func, 0, total_num, inc); + } else { + // NC...HW case here + size_t numNC = 1; + size_t numHW = 1; + for (size_t i = 0; i < 2; i++) { + numNC *= bases[i]; + } + for (Nd4jLong i = 2; i < rank; i++) { + numHW *= bases[i]; + } + Nd4jLong total_num = numNC * numHW; + Nd4jLong inc = (order == 'c') ? bases[rank - 1] : bases[0]; + if (order == 'c' && isContinuous) { + // sometimes last dimension is too big and multithreading could suffer + // using unfair partitioning so we will do it only when inc is smaller our + // value or multithreading turned off + uint32_t req_numThreads = + sd::Environment::getInstance().maxMasterThreads(); + if (req_numThreads < 2 || numNC >= req_numThreads || inc <= 2 * 8196 || + rank == 3) { + inc = numHW; + } else { + // treat it as stride1c case + isContinuous = false; + } + } + FUNC_1D func = [order, isContinuous, rank, x, b, z, x_shapeInfo, + z_shapeInfo, same_stride, same_order, + yStrideC](uint64_t thread_id, int64_t start, int64_t stop, + int64_t increment) -> void { + const Nd4jLong rank = x_shapeInfo[0]; + const Nd4jLong* bases = &(x_shapeInfo[1]); + const Nd4jLong* x_strides = &(x_shapeInfo[rank + 1]); + const Nd4jLong* z_strides = &(z_shapeInfo[rank + 1]); + const bool inplaceOp = (x == z); + if (order == 'c') { + if (isContinuous) { + channel_NC_continous_numHW_C(rank, bases, x_strides, + const_cast(x), b, z, inplaceOp, + yStrideC, start, stop, increment); + } + // rank is in [3,5] + else if (rank == 4) { + channel_NC_generic_C( + bases, x_strides, z_strides, inplaceOp, same_stride, same_order, + yStrideC, const_cast(x), b, z, start, stop, increment); + + } else if (rank == 5) { + channel_NC_generic_C( + bases, x_strides, z_strides, inplaceOp, same_stride, same_order, + yStrideC, const_cast(x), b, z, start, stop, increment); + } else if (rank == 3) { + channel_NC_generic_C( + bases, x_strides, z_strides, inplaceOp, same_stride, same_order, + yStrideC, const_cast(x), b, z, start, stop, increment); + } + } else { + // the same can be applied for NCHW case + // generic F case + // continous case is missing + + if (rank == 4) { + channel_generic_F( + bases, x_strides, z_strides, inplaceOp, same_stride, same_order, + yStrideC, const_cast(x), b, z, start, stop, increment); + } else if (rank == 5) { + channel_generic_F( + bases, x_strides, z_strides, inplaceOp, same_stride, same_order, + yStrideC, const_cast(x), b, z, start, stop, increment); + } else if (rank == 3) { + channel_generic_F( + bases, x_strides, z_strides, inplaceOp, same_stride, same_order, + yStrideC, const_cast(x), b, z, start, stop, increment); + } + } + }; + // + samediff::Threads::parallel_aligned_increment(func, 0, total_num, inc); + } +} +////////////////////////////////////////////////////////////////////////// +void addBias(sd::graph::Context& block, const NDArray& input, + const NDArray& bias, NDArray& output, const bool isNCHW) { + // bias.rankOf() == 1 ? bias : bias.reshape(bias.ordering(), + // {bias.lengthOf()}) + BUILD_DOUBLE_SELECTOR(input.dataType(), bias.dataType(), addBias_, + (input, bias, output, isNCHW), FLOAT_TYPES, + FLOAT_TYPES); +} + +BUILD_DOUBLE_TEMPLATE(template void addBias_, + (const NDArray& input, const NDArray& bias, + NDArray& output, const bool isNCHW), + FLOAT_TYPES, FLOAT_TYPES); +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/adjust_hue.cpp b/libnd4j/include/ops/declarable/helpers/cpu/adjust_hue.cpp index 3f37666e7638..f6c1e0c2c72e 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/adjust_hue.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/adjust_hue.cpp @@ -19,86 +19,85 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include -#include #include +#include +#include namespace sd { namespace ops { namespace helpers { - template -static void adjustHue_(const NDArray *input, const NDArray* deltaScalarArr, NDArray *output, const int dimC) { - - const T delta = deltaScalarArr->e(0); - const int rank = input->rankOf(); - - const T* x = input->bufferAsT(); - T* z = output->bufferAsT(); - - if(dimC == rank - 1 && input->ews() == 1 && output->ews() == 1 && input->ordering() == 'c' && output->ordering() == 'c') { - - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i += increment) { - T h, s, v; - - rgbToHsv(x[i], x[i + 1], x[i + 2], h, s, v); - - h += delta ; - if (h > (T)1) - h -= (T)1; - else if (h < 0) - h += (T)1; - - hsvToRgb(h, s, v, z[i], z[i + 1], z[i + 2]); - } - }; - - samediff::Threads::parallel_for(func, 0, input->lengthOf(), 3); - } - else { - - auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(input->shapeInfo(), dimC); - auto packZ = sd::ConstantTadHelper::getInstance().tadForDimensions(output->shapeInfo(), dimC); - - const Nd4jLong numOfTads = packX.numberOfTads(); - const Nd4jLong xDimCstride = input->stridesOf()[dimC]; - const Nd4jLong zDimCstride = output->stridesOf()[dimC]; - - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - - const T *xTad = x + packX.platformOffsets()[i]; - T *zTad = z + packZ.platformOffsets()[i]; - - T h, s, v; - - rgbToHsv(xTad[0], xTad[xDimCstride], xTad[2 * xDimCstride], h, s, v); - - h += delta ; - if (h > (T)1) - h -= (T)1; - else if (h < 0) - h += (T)1; - - hsvToRgb(h, s, v, zTad[0], zTad[zDimCstride], zTad[2 * zDimCstride]); - - } - }; - - samediff::Threads::parallel_tad(func, 0, numOfTads); - } +static void adjustHue_(const NDArray *input, const NDArray *deltaScalarArr, + NDArray *output, const int dimC) { + const T delta = deltaScalarArr->e(0); + const int rank = input->rankOf(); + + const T *x = input->bufferAsT(); + T *z = output->bufferAsT(); + + if (dimC == rank - 1 && input->ews() == 1 && output->ews() == 1 && + input->ordering() == 'c' && output->ordering() == 'c') { + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i += increment) { + T h, s, v; + + rgbToHsv(x[i], x[i + 1], x[i + 2], h, s, v); + + h += delta; + if (h > (T)1) + h -= (T)1; + else if (h < 0) + h += (T)1; + + hsvToRgb(h, s, v, z[i], z[i + 1], z[i + 2]); + } + }; + + samediff::Threads::parallel_for(func, 0, input->lengthOf(), 3); + } else { + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions( + input->shapeInfo(), dimC); + auto packZ = sd::ConstantTadHelper::getInstance().tadForDimensions( + output->shapeInfo(), dimC); + + const Nd4jLong numOfTads = packX.numberOfTads(); + const Nd4jLong xDimCstride = input->stridesOf()[dimC]; + const Nd4jLong zDimCstride = output->stridesOf()[dimC]; + + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + const T *xTad = x + packX.platformOffsets()[i]; + T *zTad = z + packZ.platformOffsets()[i]; + + T h, s, v; + + rgbToHsv(xTad[0], xTad[xDimCstride], xTad[2 * xDimCstride], h, s, v); + + h += delta; + if (h > (T)1) + h -= (T)1; + else if (h < 0) + h += (T)1; + + hsvToRgb(h, s, v, zTad[0], zTad[zDimCstride], zTad[2 * zDimCstride]); + } + }; + + samediff::Threads::parallel_tad(func, 0, numOfTads); + } } - -void adjustHue(sd::LaunchContext* context, const NDArray *input, const NDArray* deltaScalarArr, NDArray *output, const int dimC) { - BUILD_SINGLE_SELECTOR(input->dataType(), adjustHue_, (input, deltaScalarArr, output, dimC), FLOAT_TYPES); +void adjustHue(sd::LaunchContext *context, const NDArray *input, + const NDArray *deltaScalarArr, NDArray *output, const int dimC) { + BUILD_SINGLE_SELECTOR(input->dataType(), adjustHue_, + (input, deltaScalarArr, output, dimC), FLOAT_TYPES); } /* template -static void adjust_hue_single_(sd::LaunchContext * context, NDArray *array, NDArray *output, float delta, bool isNHWC) { +static void adjust_hue_single_(sd::LaunchContext * context, NDArray *array, +NDArray *output, float delta, bool isNHWC) { // we're 100% sure it's 3 const int numChannels = 3; int tuples = array->lengthOf() / numChannels; @@ -166,8 +165,8 @@ static void adjust_hue_single_(sd::LaunchContext * context, NDArray *array, NDAr } } -void adjust_hue_(sd::LaunchContext * context, NDArray *array, NDArray *output, NDArray* delta, bool isNHWC) { - auto xType = array->dataType(); +void adjust_hue_(sd::LaunchContext * context, NDArray *array, NDArray *output, +NDArray* delta, bool isNHWC) { auto xType = array->dataType(); float d = delta->e(0); if (array->rankOf() == 4) { @@ -177,21 +176,24 @@ void adjust_hue_(sd::LaunchContext * context, NDArray *array, NDArray *output, N // FIXME: template selector should be moved out of loop PRAGMA_OMP_PARALLEL_FOR for (int e = 0; e < tSize; e++) { - BUILD_SINGLE_SELECTOR(xType, adjust_hue_single_, (context, tadsIn->at(e), tadsOut->at(e), d, isNHWC);, FLOAT_TYPES); + BUILD_SINGLE_SELECTOR(xType, adjust_hue_single_, (context, +tadsIn->at(e), tadsOut->at(e), d, isNHWC);, FLOAT_TYPES); } delete tadsIn; delete tadsOut; } else { - BUILD_SINGLE_SELECTOR(xType, adjust_hue_single_, (context, array, output, d, isNHWC);, FLOAT_TYPES); + BUILD_SINGLE_SELECTOR(xType, adjust_hue_single_, (context, array, +output, d, isNHWC);, FLOAT_TYPES); } } -BUILD_SINGLE_TEMPLATE(template void adjust_hue_single_, (sd::LaunchContext * context, NDArray *array, NDArray *output, float delta, bool isNHWC);, FLOAT_TYPES); +BUILD_SINGLE_TEMPLATE(template void adjust_hue_single_, (sd::LaunchContext * +context, NDArray *array, NDArray *output, float delta, bool isNHWC);, +FLOAT_TYPES); */ - -} -} -} \ No newline at end of file +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/adjust_saturation.cpp b/libnd4j/include/ops/declarable/helpers/cpu/adjust_saturation.cpp index 63f26c90fa3e..d5afcf9b7766 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/adjust_saturation.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/adjust_saturation.cpp @@ -19,84 +19,88 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include -#include -#include #include +#include +#include +#include - -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { template -static void adjustSaturation_(const NDArray *input, const NDArray* factorScalarArr, NDArray *output, const int dimC) { - - const T factor = factorScalarArr->e(0); - const int rank = input->rankOf(); - - const T* x = input->bufferAsT(); - T* z = output->bufferAsT(); - - if(dimC == rank - 1 && input->ews() == 1 && output->ews() == 1 && input->ordering() == 'c' && output->ordering() == 'c') { - - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i += increment) { - T h, s, v; - - rgbToHsv(x[i], x[i + 1], x[i + 2], h, s, v); - - s *= factor; - if (s > 1.f) - s = 1.f; - else if (s < 0.f) - s = 0.f; - - hsvToRgb(h, s, v, z[i], z[i + 1], z[i + 2]); - } - }; - - samediff::Threads::parallel_for(func, 0, input->lengthOf(), 3); - } else { - auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(input->shapeInfo(), dimC); - auto packZ = sd::ConstantTadHelper::getInstance().tadForDimensions(output->shapeInfo(), dimC); - - const Nd4jLong numOfTads = packX.numberOfTads(); - const Nd4jLong xDimCstride = input->stridesOf()[dimC]; - const Nd4jLong zDimCstride = output->stridesOf()[dimC]; - - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - const T *xTad = x + packX.platformOffsets()[i]; - T *zTad = z + packZ.platformOffsets()[i]; - - T h, s, v; - - rgbToHsv(xTad[0], xTad[xDimCstride], xTad[2 * xDimCstride], h, s, v); - - s *= factor; - if (s > 1.f) - s = 1.f; - else if (s < 0.f) - s = 0.f; - - hsvToRgb(h, s, v, zTad[0], zTad[zDimCstride], zTad[2 * zDimCstride]); - } - }; - - samediff::Threads::parallel_tad(func, 0, numOfTads); - } +static void adjustSaturation_(const NDArray *input, + const NDArray *factorScalarArr, NDArray *output, + const int dimC) { + const T factor = factorScalarArr->e(0); + const int rank = input->rankOf(); + + const T *x = input->bufferAsT(); + T *z = output->bufferAsT(); + + if (dimC == rank - 1 && input->ews() == 1 && output->ews() == 1 && + input->ordering() == 'c' && output->ordering() == 'c') { + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i += increment) { + T h, s, v; + + rgbToHsv(x[i], x[i + 1], x[i + 2], h, s, v); + + s *= factor; + if (s > 1.f) + s = 1.f; + else if (s < 0.f) + s = 0.f; + + hsvToRgb(h, s, v, z[i], z[i + 1], z[i + 2]); + } + }; + + samediff::Threads::parallel_for(func, 0, input->lengthOf(), 3); + } else { + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions( + input->shapeInfo(), dimC); + auto packZ = sd::ConstantTadHelper::getInstance().tadForDimensions( + output->shapeInfo(), dimC); + + const Nd4jLong numOfTads = packX.numberOfTads(); + const Nd4jLong xDimCstride = input->stridesOf()[dimC]; + const Nd4jLong zDimCstride = output->stridesOf()[dimC]; + + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + const T *xTad = x + packX.platformOffsets()[i]; + T *zTad = z + packZ.platformOffsets()[i]; + + T h, s, v; + + rgbToHsv(xTad[0], xTad[xDimCstride], xTad[2 * xDimCstride], h, s, v); + + s *= factor; + if (s > 1.f) + s = 1.f; + else if (s < 0.f) + s = 0.f; + + hsvToRgb(h, s, v, zTad[0], zTad[zDimCstride], zTad[2 * zDimCstride]); + } + }; + + samediff::Threads::parallel_tad(func, 0, numOfTads); + } } - -void adjustSaturation(sd::LaunchContext* context, const NDArray *input, const NDArray* factorScalarArr, NDArray *output, const int dimC) { - - BUILD_SINGLE_SELECTOR(input->dataType(), adjustSaturation_, (input, factorScalarArr, output, dimC), FLOAT_TYPES); +void adjustSaturation(sd::LaunchContext *context, const NDArray *input, + const NDArray *factorScalarArr, NDArray *output, + const int dimC) { + BUILD_SINGLE_SELECTOR(input->dataType(), adjustSaturation_, + (input, factorScalarArr, output, dimC), FLOAT_TYPES); } /* template -static void adjust_saturation_single_(sd::LaunchContext * context, NDArray *array, NDArray *output, float delta, bool isNHWC) { +static void adjust_saturation_single_(sd::LaunchContext * context, NDArray +*array, NDArray *output, float delta, bool isNHWC) { // we're 100% sure it's 3 const int numChannels = 3; int tuples = array->lengthOf() / numChannels; @@ -114,7 +118,8 @@ static void adjust_saturation_single_(sd::LaunchContext * context, NDArray *arra T h, s, v; // Convert the RGB color to Hue/V-range. helpers::rgb_to_hsv(i[0], i[1], i[2], &h, &s, &v); - s = sd::math::nd4j_min((T) 1.0f, sd::math::nd4j_max((T) 0.0f, s * delta)); + s = sd::math::nd4j_min((T) 1.0f, sd::math::nd4j_max((T) 0.0f, +s * delta)); // Convert the hue and v-range back into RGB. helpers::hsv_to_rgb(h, s, v, o, o + 1, o + 2); } @@ -143,7 +148,8 @@ static void adjust_saturation_single_(sd::LaunchContext * context, NDArray *arra T h, s, v; // Convert the RGB color to Hue/V-range. helpers::rgb_to_hsv(_ri[0], _gi[0], _bi[0], &h, &s, &v); - s = sd::math::nd4j_min((T) 1.0f, sd::math::nd4j_max((T) 0.0f, s * delta)); + s = sd::math::nd4j_min((T) 1.0f, sd::math::nd4j_max((T) 0.0f, +s * delta)); // Convert the hue and v-range back into RGB. helpers::hsv_to_rgb(h, s, v, _ro, _go, _bo); } @@ -153,8 +159,8 @@ static void adjust_saturation_single_(sd::LaunchContext * context, NDArray *arra } } -void adjust_saturation(sd::LaunchContext * context, NDArray *array, NDArray *output, NDArray* delta, bool isNHWC) { - auto xType = array->dataType(); +void adjust_saturation(sd::LaunchContext * context, NDArray *array, NDArray +*output, NDArray* delta, bool isNHWC) { auto xType = array->dataType(); float d = delta->e(0); if (array->rankOf() == 4) { @@ -165,7 +171,8 @@ void adjust_saturation(sd::LaunchContext * context, NDArray *array, NDArray *out // FIXME: template selector should be moved out of loop PRAGMA_OMP_PARALLEL_FOR for (int e = 0; e < tSize; e++) { - BUILD_SINGLE_SELECTOR(xType, adjust_saturation_single_, (context, tadsIn->at(e), tadsOut->at(e), d, isNHWC);, FLOAT_TYPES); + BUILD_SINGLE_SELECTOR(xType, adjust_saturation_single_, (context, +tadsIn->at(e), tadsOut->at(e), d, isNHWC);, FLOAT_TYPES); } @@ -173,13 +180,16 @@ void adjust_saturation(sd::LaunchContext * context, NDArray *array, NDArray *out delete tadsOut; } else { - BUILD_SINGLE_SELECTOR(xType, adjust_saturation_single_, (context, array, output, d, isNHWC);, FLOAT_TYPES); + BUILD_SINGLE_SELECTOR(xType, adjust_saturation_single_, (context, array, +output, d, isNHWC);, FLOAT_TYPES); } } -BUILD_SINGLE_TEMPLATE(template void adjust_saturation_single_, (sd::LaunchContext * context, NDArray *array, NDArray *output, float delta, bool isNHWC), FLOAT_TYPES); +BUILD_SINGLE_TEMPLATE(template void adjust_saturation_single_, +(sd::LaunchContext * context, NDArray *array, NDArray *output, float delta, bool +isNHWC), FLOAT_TYPES); */ -} -} -} \ No newline at end of file +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/axis.cpp b/libnd4j/include/ops/declarable/helpers/cpu/axis.cpp index 36f7166309f5..7ea8d2458f38 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/axis.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/axis.cpp @@ -20,29 +20,26 @@ #include - namespace sd { namespace ops { namespace helpers { - void adjustAxis(Nd4jLong rank, NDArray* axisVector, std::vector& output) { - output.resize(axisVector->lengthOf()); - for (Nd4jLong e = 0; e < axisVector->lengthOf(); e++) { - auto ca = axisVector->e(e); - if (ca < 0) - ca += rank; - - output[e] = ca; - } - } +void adjustAxis(Nd4jLong rank, NDArray* axisVector, std::vector& output) { + output.resize(axisVector->lengthOf()); + for (Nd4jLong e = 0; e < axisVector->lengthOf(); e++) { + auto ca = axisVector->e(e); + if (ca < 0) ca += rank; - void adjustAxis(Nd4jLong rank, std::vector &axisVector) { - for (size_t e = 0; e < axisVector.size(); e++) { - auto a = axisVector[e]; - if (a < 0) - axisVector[e] = a + rank; - } - } -} + output[e] = ca; + } } + +void adjustAxis(Nd4jLong rank, std::vector& axisVector) { + for (size_t e = 0; e < axisVector.size(); e++) { + auto a = axisVector[e]; + if (a < 0) axisVector[e] = a + rank; + } } +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/batched_gemm.cpp b/libnd4j/include/ops/declarable/helpers/cpu/batched_gemm.cpp index 0c9338a8eff7..6f9d8e220815 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/batched_gemm.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/batched_gemm.cpp @@ -18,114 +18,135 @@ // @author raver119@gmail.com // +#include +#include +#include #include #include -#include -#include -#include - -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { - template -void bgemm_(const std::vector& vA, const std::vector& vB, std::vector& vC, const NDArray* alphas, const NDArray* betas, int transA, int transB, int M, int N, int K, const int lda, const int ldb, const int ldc) { - - int batchSize = vA.size(); - - if (BlasHelper::getInstance().hasBatchedGEMM()) { - auto arr = vA.at(0); - CBLAS_TRANSPOSE *tA, *tB; - int *tM, *tN, *tK, *tldA, *tldB, *tldC, *tsize; - // mkl requires mnk etc as arrays, cuda doesn't - ALLOCATE(tA, arr->getContext()->getWorkspace(), batchSize, CBLAS_TRANSPOSE); - ALLOCATE(tB, arr->getContext()->getWorkspace(), batchSize, CBLAS_TRANSPOSE); - ALLOCATE(tM, arr->getContext()->getWorkspace(), batchSize, int); - ALLOCATE(tN, arr->getContext()->getWorkspace(), batchSize, int); - ALLOCATE(tK, arr->getContext()->getWorkspace(), batchSize, int); - ALLOCATE(tldA, arr->getContext()->getWorkspace(), batchSize, int); - ALLOCATE(tldB, arr->getContext()->getWorkspace(), batchSize, int); - ALLOCATE(tldC, arr->getContext()->getWorkspace(), batchSize, int); - ALLOCATE(tsize, arr->getContext()->getWorkspace(), batchSize, int); - - shape::fill(tA, (CBLAS_TRANSPOSE) transA, batchSize); - shape::fill(tB, (CBLAS_TRANSPOSE) transB, batchSize); - - shape::fill(tM, M, batchSize); - shape::fill(tN, N, batchSize); - shape::fill(tK, K, batchSize); - shape::fill(tldA, lda, batchSize); - shape::fill(tldB, ldb, batchSize); - shape::fill(tldC, ldc, batchSize); - shape::fill(tsize, 1, batchSize); - - std::vector buffersA(batchSize); - std::vector buffersB(batchSize); - std::vector buffersC(batchSize); - - for (int e = 0; e < batchSize; e++) { - buffersA[e] = reinterpret_cast(vA[e]->buffer()); - buffersB[e] = reinterpret_cast(vB[e]->buffer()); - buffersC[e] = reinterpret_cast(vC[e]->buffer()); - } - - if (std::is_same::value) { - BlasHelper::getInstance().dgemmBatched()(CblasColMajor, tA, tB, tM, tN, tK, (double *) alphas->buffer(), (double **) buffersA.data(), tldA, (double **) buffersB.data(), tldB, (double *) betas->buffer(),(double **) buffersC.data(), tldC, vA.size(), tsize); - } else if (std::is_same::value) { - BlasHelper::getInstance().sgemmBatched()(CblasColMajor, tA, tB, tM, tN, tK, (float *) alphas->buffer(), (float **) buffersA.data(), tldA, (float **) buffersB.data(), tldB, (float *) betas->buffer(), (float **) buffersC.data(), tldC, vA.size(), tsize); - } +void bgemm_(const std::vector &vA, const std::vector &vB, + std::vector &vC, const NDArray *alphas, + const NDArray *betas, int transA, int transB, int M, int N, int K, + const int lda, const int ldb, const int ldc) { + int batchSize = vA.size(); + + if (BlasHelper::getInstance().hasBatchedGEMM()) { + auto arr = vA.at(0); + CBLAS_TRANSPOSE *tA, *tB; + int *tM, *tN, *tK, *tldA, *tldB, *tldC, *tsize; + // mkl requires mnk etc as arrays, cuda doesn't + ALLOCATE(tA, arr->getContext()->getWorkspace(), batchSize, CBLAS_TRANSPOSE); + ALLOCATE(tB, arr->getContext()->getWorkspace(), batchSize, CBLAS_TRANSPOSE); + ALLOCATE(tM, arr->getContext()->getWorkspace(), batchSize, int); + ALLOCATE(tN, arr->getContext()->getWorkspace(), batchSize, int); + ALLOCATE(tK, arr->getContext()->getWorkspace(), batchSize, int); + ALLOCATE(tldA, arr->getContext()->getWorkspace(), batchSize, int); + ALLOCATE(tldB, arr->getContext()->getWorkspace(), batchSize, int); + ALLOCATE(tldC, arr->getContext()->getWorkspace(), batchSize, int); + ALLOCATE(tsize, arr->getContext()->getWorkspace(), batchSize, int); + + shape::fill(tA, (CBLAS_TRANSPOSE)transA, batchSize); + shape::fill(tB, (CBLAS_TRANSPOSE)transB, batchSize); + + shape::fill(tM, M, batchSize); + shape::fill(tN, N, batchSize); + shape::fill(tK, K, batchSize); + shape::fill(tldA, lda, batchSize); + shape::fill(tldB, ldb, batchSize); + shape::fill(tldC, ldc, batchSize); + shape::fill(tsize, 1, batchSize); + + std::vector buffersA(batchSize); + std::vector buffersB(batchSize); + std::vector buffersC(batchSize); + + for (int e = 0; e < batchSize; e++) { + buffersA[e] = reinterpret_cast(vA[e]->buffer()); + buffersB[e] = reinterpret_cast(vB[e]->buffer()); + buffersC[e] = reinterpret_cast(vC[e]->buffer()); + } - // release temporary arrays - RELEASE(tA, arr->getContext()->getWorkspace()); - RELEASE(tB, arr->getContext()->getWorkspace()); - RELEASE(tM, arr->getContext()->getWorkspace()); - RELEASE(tN, arr->getContext()->getWorkspace()); - RELEASE(tK, arr->getContext()->getWorkspace()); - RELEASE(tldA, arr->getContext()->getWorkspace()); - RELEASE(tldB, arr->getContext()->getWorkspace()); - RELEASE(tldC, arr->getContext()->getWorkspace()); - RELEASE(tsize, arr->getContext()->getWorkspace()); - } else { - CBLAS_TRANSPOSE tA = (CBLAS_TRANSPOSE) transA; - CBLAS_TRANSPOSE tB = (CBLAS_TRANSPOSE) transB; - - int vaSize = vA.size(); - - auto func = PRAGMA_THREADS_FOR { - for (auto p = start; p < stop; p++) { - auto A = reinterpret_cast(vA.at(p)->buffer()); - auto B = reinterpret_cast(vB.at(p)->buffer()); - auto C = reinterpret_cast(vC.at(p)->buffer()); - auto alpha = alphas->e(p); - auto beta = betas->e(p); - for (int m = 0; m < M; ++m) { - for (int n = 0; n < N; ++n) { - T c_mnp = 0; - - PRAGMA_OMP_SIMD - for (int k = 0; k < K; ++k) - c_mnp += A[tA == CblasNoTrans ? (m + k * lda) : (m * lda + k)] * B[tB == CblasNoTrans ? (k + n * ldb) : (k * ldb + n)]; - - C[m + n * ldc] = alpha * c_mnp + beta * C[m + n * ldc]; - } - } - } - }; - - samediff::Threads::parallel_tad(func, 0, vaSize); + if (std::is_same::value) { + BlasHelper::getInstance().dgemmBatched()( + CblasColMajor, tA, tB, tM, tN, tK, (double *)alphas->buffer(), + (double **)buffersA.data(), tldA, (double **)buffersB.data(), tldB, + (double *)betas->buffer(), (double **)buffersC.data(), tldC, + vA.size(), tsize); + } else if (std::is_same::value) { + BlasHelper::getInstance().sgemmBatched()( + CblasColMajor, tA, tB, tM, tN, tK, (float *)alphas->buffer(), + (float **)buffersA.data(), tldA, (float **)buffersB.data(), tldB, + (float *)betas->buffer(), (float **)buffersC.data(), tldC, vA.size(), + tsize); } -} + // release temporary arrays + RELEASE(tA, arr->getContext()->getWorkspace()); + RELEASE(tB, arr->getContext()->getWorkspace()); + RELEASE(tM, arr->getContext()->getWorkspace()); + RELEASE(tN, arr->getContext()->getWorkspace()); + RELEASE(tK, arr->getContext()->getWorkspace()); + RELEASE(tldA, arr->getContext()->getWorkspace()); + RELEASE(tldB, arr->getContext()->getWorkspace()); + RELEASE(tldC, arr->getContext()->getWorkspace()); + RELEASE(tsize, arr->getContext()->getWorkspace()); + } else { + CBLAS_TRANSPOSE tA = (CBLAS_TRANSPOSE)transA; + CBLAS_TRANSPOSE tB = (CBLAS_TRANSPOSE)transB; + + int vaSize = vA.size(); + + auto func = PRAGMA_THREADS_FOR { + for (auto p = start; p < stop; p++) { + auto A = reinterpret_cast(vA.at(p)->buffer()); + auto B = reinterpret_cast(vB.at(p)->buffer()); + auto C = reinterpret_cast(vC.at(p)->buffer()); + auto alpha = alphas->e(p); + auto beta = betas->e(p); + for (int m = 0; m < M; ++m) { + for (int n = 0; n < N; ++n) { + T c_mnp = 0; + + PRAGMA_OMP_SIMD + for (int k = 0; k < K; ++k) + c_mnp += A[tA == CblasNoTrans ? (m + k * lda) : (m * lda + k)] * + B[tB == CblasNoTrans ? (k + n * ldb) : (k * ldb + n)]; + + C[m + n * ldc] = alpha * c_mnp + beta * C[m + n * ldc]; + } + } + } + }; -void bgemm(const std::vector& vA, const std::vector& vB, std::vector& vC, const NDArray* alphas, const NDArray* betas, int transA, int transB, int M, int N, int K, const int lda, const int ldb, const int ldc) { - auto xType = vA.at(0)->dataType(); - BUILD_SINGLE_SELECTOR(xType, bgemm_, (vA, vB, vC, alphas, betas, transA, transB, M, N, K, lda, ldb, ldc), FLOAT_TYPES); + samediff::Threads::parallel_tad(func, 0, vaSize); + } } -BUILD_SINGLE_TEMPLATE(template void bgemm_, (const std::vector& vA, const std::vector& vB, std::vector& vC, const NDArray* alphas, const NDArray* betas, int transA, int transB, int M, int N, int K, const int lda, const int ldb, const int ldc), FLOAT_TYPES); - +void bgemm(const std::vector &vA, const std::vector &vB, + std::vector &vC, const NDArray *alphas, + const NDArray *betas, int transA, int transB, int M, int N, int K, + const int lda, const int ldb, const int ldc) { + auto xType = vA.at(0)->dataType(); + BUILD_SINGLE_SELECTOR( + xType, bgemm_, + (vA, vB, vC, alphas, betas, transA, transB, M, N, K, lda, ldb, ldc), + FLOAT_TYPES); } -} -} \ No newline at end of file + +BUILD_SINGLE_TEMPLATE(template void bgemm_, + (const std::vector &vA, + const std::vector &vB, + std::vector &vC, const NDArray *alphas, + const NDArray *betas, int transA, int transB, int M, + int N, int K, const int lda, const int ldb, + const int ldc), + FLOAT_TYPES); + +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/batchnorm.cpp b/libnd4j/include/ops/declarable/helpers/cpu/batchnorm.cpp index 65c342d9ca09..988c703efc6c 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/batchnorm.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/batchnorm.cpp @@ -18,183 +18,213 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // - -#include -#include -#include #include +#include +#include +#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { - ////////////////////////////////////////////////////////////////////////// template -static void batchnorm_(const NDArray* input, const NDArray* mean, const NDArray* variance, const NDArray* gamma, const NDArray* beta, - NDArray* output, +static void batchnorm_(const NDArray* input, const NDArray* mean, + const NDArray* variance, const NDArray* gamma, + const NDArray* beta, NDArray* output, const std::vector& axes, const double epsilon) { - - // formula: output = gamma * ((input - mean) / sqrt(variance + epsilon)) + beta - - const T* x = input->bufferAsT(); - T* z = output->bufferAsT(); - const T* m = mean->bufferAsT(); - const T* v = variance->bufferAsT(); - const T* g = gamma == nullptr ? nullptr : gamma->bufferAsT(); - const T* b = beta == nullptr ? nullptr : beta->bufferAsT(); - - const bool xzSameOffset = shape::haveSameShapeAndStrides(input->shapeInfo(), output->shapeInfo()); - - bool paramSameOffset = shape::haveSameShapeAndStrides(mean->shapeInfo(), variance->shapeInfo()); - if(paramSameOffset && gamma != nullptr) - paramSameOffset &= shape::haveSameShapeAndStrides(mean->shapeInfo(), gamma->shapeInfo()); - if(paramSameOffset && beta != nullptr) - paramSameOffset &= shape::haveSameShapeAndStrides(mean->shapeInfo(), beta->shapeInfo()); - - const Nd4jLong lenBig = input->lengthOf(); - const Nd4jLong lenSmall = mean->lengthOf(); - - const Nd4jLong steps = lenBig / lenSmall; - std::vector dimsToExclude = ShapeUtils::evalDimsToExclude(input->rankOf(), axes); - - OmpLaunchHelper info(lenBig, lenSmall); - - auto func = PRAGMA_THREADS_DO { - - Nd4jLong* xOffsets = new Nd4jLong[steps]; - Nd4jLong* zOffsets = xzSameOffset ? xOffsets : new Nd4jLong[steps]; - int* auxBuff = new int[2 * input->rankOf()]; - - for (Nd4jLong j = 0; j < lenSmall; ++j) { - - const bool isOwner = (j < info._numThreads) ? thread_id == j : thread_id == (j % info._numThreads); - - if(!isOwner) - continue; - - const auto meanOffset = shape::getIndexOffset(j, mean->shapeInfo()); - const auto varOffset = paramSameOffset ? meanOffset : shape::getIndexOffset(j, variance->shapeInfo()); - - const auto meanVal = m[meanOffset]; - auto sigmaInvGam = static_cast(1) / sd::math::nd4j_sqrt(v[varOffset] + epsilon); - - if(g != nullptr) { - const auto gammaOffset = paramSameOffset ? meanOffset : shape::getIndexOffset(j, gamma->shapeInfo()); - sigmaInvGam *= g[gammaOffset]; - } - - T betaVal = static_cast(0); - if(b != nullptr) { - const auto betaOffset = paramSameOffset ? meanOffset : shape::getIndexOffset(j, beta->shapeInfo()); - betaVal = b[betaOffset]; - } - - // calculate offsets for input and output - shape::outerArrayOffsets(xOffsets, j, input->shapeInfo(), mean->shapeInfo(), auxBuff, dimsToExclude.data()); - if(!xzSameOffset) - shape::outerArrayOffsets(zOffsets, j, output->shapeInfo(), mean->shapeInfo(), auxBuff, dimsToExclude.data()); - - PRAGMA_OMP_SIMD - for (Nd4jLong i = 0; i < steps; ++i) - z[zOffsets[i]] = (x[xOffsets[i]] - meanVal) * sigmaInvGam + betaVal; - } - - delete []auxBuff; - delete []xOffsets; - if(!xzSameOffset) - delete []zOffsets; - }; - - samediff::Threads::parallel_do(func, info._numThreads); + // formula: output = gamma * ((input - mean) / sqrt(variance + epsilon)) + + // beta + + const T* x = input->bufferAsT(); + T* z = output->bufferAsT(); + const T* m = mean->bufferAsT(); + const T* v = variance->bufferAsT(); + const T* g = gamma == nullptr ? nullptr : gamma->bufferAsT(); + const T* b = beta == nullptr ? nullptr : beta->bufferAsT(); + + const bool xzSameOffset = + shape::haveSameShapeAndStrides(input->shapeInfo(), output->shapeInfo()); + + bool paramSameOffset = + shape::haveSameShapeAndStrides(mean->shapeInfo(), variance->shapeInfo()); + if (paramSameOffset && gamma != nullptr) + paramSameOffset &= + shape::haveSameShapeAndStrides(mean->shapeInfo(), gamma->shapeInfo()); + if (paramSameOffset && beta != nullptr) + paramSameOffset &= + shape::haveSameShapeAndStrides(mean->shapeInfo(), beta->shapeInfo()); + + const Nd4jLong lenBig = input->lengthOf(); + const Nd4jLong lenSmall = mean->lengthOf(); + + const Nd4jLong steps = lenBig / lenSmall; + std::vector dimsToExclude = + ShapeUtils::evalDimsToExclude(input->rankOf(), axes); + + OmpLaunchHelper info(lenBig, lenSmall); + + auto func = PRAGMA_THREADS_DO { + Nd4jLong* xOffsets = new Nd4jLong[steps]; + Nd4jLong* zOffsets = xzSameOffset ? xOffsets : new Nd4jLong[steps]; + int* auxBuff = new int[2 * input->rankOf()]; + + for (Nd4jLong j = 0; j < lenSmall; ++j) { + const bool isOwner = (j < info._numThreads) + ? thread_id == j + : thread_id == (j % info._numThreads); + + if (!isOwner) continue; + + const auto meanOffset = shape::getIndexOffset(j, mean->shapeInfo()); + const auto varOffset = + paramSameOffset ? meanOffset + : shape::getIndexOffset(j, variance->shapeInfo()); + + const auto meanVal = m[meanOffset]; + auto sigmaInvGam = + static_cast(1) / sd::math::nd4j_sqrt(v[varOffset] + epsilon); + + if (g != nullptr) { + const auto gammaOffset = + paramSameOffset ? meanOffset + : shape::getIndexOffset(j, gamma->shapeInfo()); + sigmaInvGam *= g[gammaOffset]; + } + + T betaVal = static_cast(0); + if (b != nullptr) { + const auto betaOffset = + paramSameOffset ? meanOffset + : shape::getIndexOffset(j, beta->shapeInfo()); + betaVal = b[betaOffset]; + } + + // calculate offsets for input and output + shape::outerArrayOffsets(xOffsets, j, input->shapeInfo(), + mean->shapeInfo(), auxBuff, + dimsToExclude.data()); + if (!xzSameOffset) + shape::outerArrayOffsets(zOffsets, j, output->shapeInfo(), + mean->shapeInfo(), auxBuff, + dimsToExclude.data()); + + PRAGMA_OMP_SIMD + for (Nd4jLong i = 0; i < steps; ++i) + z[zOffsets[i]] = (x[xOffsets[i]] - meanVal) * sigmaInvGam + betaVal; + } + + delete[] auxBuff; + delete[] xOffsets; + if (!xzSameOffset) delete[] zOffsets; + }; + + samediff::Threads::parallel_do(func, info._numThreads); } ////////////////////////////////////////////////////////////////////////// template -static void batchnorm2_(const NDArray* input, const NDArray* mean, const NDArray* variance, const NDArray* gamma, const NDArray* beta, - NDArray* output, +static void batchnorm2_(const NDArray* input, const NDArray* mean, + const NDArray* variance, const NDArray* gamma, + const NDArray* beta, NDArray* output, const std::vector& axes, const double epsilon) { - - // formula: output = gamma * ((input - mean) / sqrt(variance + epsilon)) + beta - - const auto x = input->bufferAsT(); - auto z = output->bufferAsT(); - const auto m = mean->bufferAsT(); - const auto v = variance->bufferAsT(); - const auto g = gamma == nullptr ? nullptr : gamma->bufferAsT(); - const auto b = beta == nullptr ? nullptr : beta->bufferAsT(); - - // xRank == zRank, minRank = meanRank = varianceRank = gammaRank = betaRank - const uint xRank = input->rankOf(); - const uint minRank = mean->rankOf(); - const uint numAxes = axes.size(); - - const bool xzSameOffset = shape::haveSameShapeAndStrides(input->shapeInfo(), output->shapeInfo()); - - bool paramSameOffset = shape::haveSameShapeAndStrides(mean->shapeInfo(), variance->shapeInfo()); - if(paramSameOffset && gamma != nullptr) - paramSameOffset &= shape::haveSameShapeAndStrides(mean->shapeInfo(), gamma->shapeInfo()); - if(paramSameOffset && beta != nullptr) - paramSameOffset &= shape::haveSameShapeAndStrides(mean->shapeInfo(), beta->shapeInfo()); - - auto func = PRAGMA_THREADS_FOR { - - int xzCoords[MAX_RANK], minCoords[MAX_RANK]; - - for (uint i = 0, j = 0; i < xRank; ++i) - if(j < numAxes && i != axes[j]) - minCoords[i] = 0; - else - ++j; - - for (auto i = start; i < stop; i++) { - - shape::index2coordsCPU(start, i, input->shapeInfo(), xzCoords); - - const auto xOffset = shape::getOffset(input->shapeInfo(), xzCoords); - const auto zOffset = xzSameOffset ? xOffset : shape::getOffset(output->shapeInfo(), xzCoords); - - if(minRank == xRank) { - for (uint j = 0; j < numAxes; ++j) - minCoords[axes[j]] = xzCoords[axes[j]]; - } - else // minRank = numAxes = 1 in this case - minCoords[0] = xzCoords[axes[0]]; - - const auto meanOffset = shape::getOffset(mean->shapeInfo(), minCoords); - const auto varianceOffset = paramSameOffset ? meanOffset : shape::getOffset(variance->shapeInfo(), minCoords); - - T sigmaInvGam = 1. / sd::math::nd4j_sqrt(v[varianceOffset] + epsilon); - - if(g != nullptr) { - const auto gammaOffset = paramSameOffset ? meanOffset : shape::getOffset(gamma->shapeInfo(), minCoords); - sigmaInvGam *= g[gammaOffset]; - } - - z[zOffset] = (x[xOffset] - m[meanOffset]) * sigmaInvGam; - - if(b != nullptr) { - const auto betaOffset = paramSameOffset ? meanOffset : shape::getOffset(beta->shapeInfo(), minCoords); - z[zOffset] += b[betaOffset]; - } - } - }; - - samediff::Threads::parallel_for(func, 0, input->lengthOf()); + // formula: output = gamma * ((input - mean) / sqrt(variance + epsilon)) + + // beta + + const auto x = input->bufferAsT(); + auto z = output->bufferAsT(); + const auto m = mean->bufferAsT(); + const auto v = variance->bufferAsT(); + const auto g = gamma == nullptr ? nullptr : gamma->bufferAsT(); + const auto b = beta == nullptr ? nullptr : beta->bufferAsT(); + + // xRank == zRank, minRank = meanRank = varianceRank = gammaRank = betaRank + const uint xRank = input->rankOf(); + const uint minRank = mean->rankOf(); + const uint numAxes = axes.size(); + + const bool xzSameOffset = + shape::haveSameShapeAndStrides(input->shapeInfo(), output->shapeInfo()); + + bool paramSameOffset = + shape::haveSameShapeAndStrides(mean->shapeInfo(), variance->shapeInfo()); + if (paramSameOffset && gamma != nullptr) + paramSameOffset &= + shape::haveSameShapeAndStrides(mean->shapeInfo(), gamma->shapeInfo()); + if (paramSameOffset && beta != nullptr) + paramSameOffset &= + shape::haveSameShapeAndStrides(mean->shapeInfo(), beta->shapeInfo()); + + auto func = PRAGMA_THREADS_FOR { + int xzCoords[MAX_RANK], minCoords[MAX_RANK]; + + for (uint i = 0, j = 0; i < xRank; ++i) + if (j < numAxes && i != axes[j]) + minCoords[i] = 0; + else + ++j; + + for (auto i = start; i < stop; i++) { + shape::index2coordsCPU(start, i, input->shapeInfo(), xzCoords); + + const auto xOffset = shape::getOffset(input->shapeInfo(), xzCoords); + const auto zOffset = + xzSameOffset ? xOffset + : shape::getOffset(output->shapeInfo(), xzCoords); + + if (minRank == xRank) { + for (uint j = 0; j < numAxes; ++j) + minCoords[axes[j]] = xzCoords[axes[j]]; + } else // minRank = numAxes = 1 in this case + minCoords[0] = xzCoords[axes[0]]; + + const auto meanOffset = shape::getOffset(mean->shapeInfo(), minCoords); + const auto varianceOffset = + paramSameOffset ? meanOffset + : shape::getOffset(variance->shapeInfo(), minCoords); + + T sigmaInvGam = + 1. / sd::math::nd4j_sqrt(v[varianceOffset] + epsilon); + + if (g != nullptr) { + const auto gammaOffset = + paramSameOffset ? meanOffset + : shape::getOffset(gamma->shapeInfo(), minCoords); + sigmaInvGam *= g[gammaOffset]; + } + + z[zOffset] = (x[xOffset] - m[meanOffset]) * sigmaInvGam; + + if (b != nullptr) { + const auto betaOffset = + paramSameOffset ? meanOffset + : shape::getOffset(beta->shapeInfo(), minCoords); + z[zOffset] += b[betaOffset]; + } + } + }; + + samediff::Threads::parallel_for(func, 0, input->lengthOf()); } ////////////////////////////////////////////////////////////////////////// -void batchnorm(const NDArray* input, const NDArray* mean, const NDArray* variance, const NDArray* gamma, const NDArray* beta, NDArray* output, const std::vector& axes, const double epsilon) { - - // batchnorm2_ is still slower ? - BUILD_SINGLE_SELECTOR(input->dataType(), batchnorm_, (input, mean, variance, gamma, beta, output, axes, epsilon), FLOAT_TYPES); +void batchnorm(const NDArray* input, const NDArray* mean, + const NDArray* variance, const NDArray* gamma, + const NDArray* beta, NDArray* output, + const std::vector& axes, const double epsilon) { + // batchnorm2_ is still slower ? + BUILD_SINGLE_SELECTOR( + input->dataType(), batchnorm_, + (input, mean, variance, gamma, beta, output, axes, epsilon), FLOAT_TYPES); } +BUILD_SINGLE_TEMPLATE(template void batchnorm_, + (const NDArray* input, const NDArray* mean, + const NDArray* variance, const NDArray* gamma, + const NDArray* beta, NDArray* output, + const std::vector& axes, const double epsilon), + FLOAT_TYPES); - -BUILD_SINGLE_TEMPLATE(template void batchnorm_, (const NDArray* input, const NDArray* mean, const NDArray* variance, const NDArray* gamma, const NDArray* beta, NDArray* output, const std::vector& axes, const double epsilon), FLOAT_TYPES); - -} -} -} - +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/betaInc.cpp b/libnd4j/include/ops/declarable/helpers/cpu/betaInc.cpp index 0056fec6d2df..d2076244dd84 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/betaInc.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/betaInc.cpp @@ -18,11 +18,12 @@ // Created by Yurii Shyrma on 11.12.2017 // -#include #include -#include #include #include +#include + +#include namespace sd { namespace ops { @@ -30,111 +31,112 @@ namespace helpers { /////////////////////////////////////////////////////////////////// // modified Lentz’s algorithm for continued fractions, -// reference: Lentz, W.J. 1976, “Generating Bessel Functions in Mie Scattering Calculations Using Continued Fractions” +// reference: Lentz, W.J. 1976, “Generating Bessel Functions in Mie Scattering +// Calculations Using Continued Fractions” template static T continuedFraction(const T a, const T b, const T x) { - - const T min = DataTypeUtils::min() / DataTypeUtils::eps(); - const T aPlusb = a + b; - T val, aPlus2i; - - T t2 = 1; - T t1 = static_cast(1) - aPlusb * x / (a + static_cast(1)); - if(math::nd4j_abs(t1) < min) - t1 = min; - t1 = static_cast(1) / t1; - T result = t1; - - for(uint i = 1; i <= maxIter; ++i) { - - aPlus2i = a + static_cast(2*i); - val = i * (b - i) * x / ((aPlus2i - static_cast(1)) * aPlus2i); - // t1 - t1 = static_cast(1) + val * t1; - if(math::nd4j_abs(t1) < min) - t1 = min; - t1 = static_cast(1) / t1; - // t2 - t2 = static_cast(1) + val / t2; - if(math::nd4j_abs(t2) < min) - t2 = min; - // result - result *= t2 * t1; - val = -(a + i) * (aPlusb + i) * x / ((aPlus2i + static_cast(1)) * aPlus2i); - // t1 - t1 = static_cast(1) + val * t1; - if(math::nd4j_abs(t1) < min) - t1 = min; - t1 = static_cast(1) / t1; - // t2 - t2 = static_cast(1) + val / t2; - if(math::nd4j_abs(t2) < min) - t2 = min; - // result - val = t2 * t1; - result *= val; - - // condition to stop loop - if(math::nd4j_abs(val - static_cast(1)) <= DataTypeUtils::eps()) - return result; - } - - return DataTypeUtils::infOrMax(); // no convergence, more iterations is required, return infinity + const T min = DataTypeUtils::min() / DataTypeUtils::eps(); + const T aPlusb = a + b; + T val, aPlus2i; + + T t2 = 1; + T t1 = static_cast(1) - aPlusb * x / (a + static_cast(1)); + if (math::nd4j_abs(t1) < min) t1 = min; + t1 = static_cast(1) / t1; + T result = t1; + + for (uint i = 1; i <= maxIter; ++i) { + aPlus2i = a + static_cast(2 * i); + val = i * (b - i) * x / ((aPlus2i - static_cast(1)) * aPlus2i); + // t1 + t1 = static_cast(1) + val * t1; + if (math::nd4j_abs(t1) < min) t1 = min; + t1 = static_cast(1) / t1; + // t2 + t2 = static_cast(1) + val / t2; + if (math::nd4j_abs(t2) < min) t2 = min; + // result + result *= t2 * t1; + val = + -(a + i) * (aPlusb + i) * x / ((aPlus2i + static_cast(1)) * aPlus2i); + // t1 + t1 = static_cast(1) + val * t1; + if (math::nd4j_abs(t1) < min) t1 = min; + t1 = static_cast(1) / t1; + // t2 + t2 = static_cast(1) + val / t2; + if (math::nd4j_abs(t2) < min) t2 = min; + // result + val = t2 * t1; + result *= val; + + // condition to stop loop + if (math::nd4j_abs(val - static_cast(1)) <= DataTypeUtils::eps()) + return result; + } + + return DataTypeUtils::infOrMax(); // no convergence, more iterations is + // required, return infinity } /////////////////////////////////////////////////////////////////// -// evaluates incomplete beta function for positive a and b, and x between 0 and 1. +// evaluates incomplete beta function for positive a and b, and x between 0 +// and 1. template static T betaIncCore(T a, T b, T x) { - // if (a <= (T)0. || b <= (T)0.) - // throw("betaInc function: a and b must be > 0 !"); - - // if (x < (T)0. || x > (T)1.) - // throw("betaInc function: x must be within (0, 1) interval !"); + // if (a <= (T)0. || b <= (T)0.) + // throw("betaInc function: a and b must be > 0 !"); + // if (x < (T)0. || x > (T)1.) + // throw("betaInc function: x must be within (0, 1) interval !"); - // t^{n-1} * (1 - t)^{n-1} is symmetric function with respect to x = 0.5 - if(a == b && x == static_cast(0.5)) - return static_cast(0.5); + // t^{n-1} * (1 - t)^{n-1} is symmetric function with respect to x = 0.5 + if (a == b && x == static_cast(0.5)) return static_cast(0.5); - if (x == static_cast(0) || x == static_cast(1)) - return x; + if (x == static_cast(0) || x == static_cast(1)) return x; - const T gammaPart = lgamma(a) + lgamma(b) - lgamma(a + b); - const T front = math::nd4j_exp(math::nd4j_log(x) * a + math::nd4j_log(1.f - x) * b - gammaPart); + const T gammaPart = lgamma(a) + lgamma(b) - lgamma(a + b); + const T front = + math::nd4j_exp(math::nd4j_log(x) * a + + math::nd4j_log(1.f - x) * b - gammaPart); - if (x <= (a + static_cast(1)) / (a + b + static_cast(2))) - return front * continuedFraction(a, b, x) / a; - else // symmetry relation - return static_cast(1) - front * continuedFraction(b, a, static_cast(1) - x) / b; + if (x <= (a + static_cast(1)) / (a + b + static_cast(2))) + return front * continuedFraction(a, b, x) / a; + else // symmetry relation + return static_cast(1) - + front * continuedFraction(b, a, static_cast(1) - x) / b; } /////////////////////////////////////////////////////////////////// -template -static void betaIncForArray(sd::LaunchContext * context, const NDArray& a, const NDArray& b, const NDArray& x, NDArray& output) { - - int xLen = x.lengthOf(); +template +static void betaIncForArray(sd::LaunchContext* context, const NDArray& a, + const NDArray& b, const NDArray& x, + NDArray& output) { + int xLen = x.lengthOf(); - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) - output.r(i) = betaIncCore(a.t(i), b.t(i), x.t(i)); - }; + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) + output.r(i) = betaIncCore(a.t(i), b.t(i), x.t(i)); + }; - samediff::Threads::parallel_for(func, 0, xLen); + samediff::Threads::parallel_for(func, 0, xLen); } /////////////////////////////////////////////////////////////////// // overload betaInc for arrays, shapes of a, b and x must be the same !!! -void betaInc(sd::LaunchContext * context, const NDArray& a, const NDArray& b, const NDArray& x, NDArray& output) { - auto xType = a.dataType(); - BUILD_SINGLE_SELECTOR(xType, betaIncForArray, (context, a, b, x, output), FLOAT_TYPES); +void betaInc(sd::LaunchContext* context, const NDArray& a, const NDArray& b, + const NDArray& x, NDArray& output) { + auto xType = a.dataType(); + BUILD_SINGLE_SELECTOR(xType, betaIncForArray, (context, a, b, x, output), + FLOAT_TYPES); } -BUILD_SINGLE_TEMPLATE(template void betaIncForArray, (sd::LaunchContext * context, const NDArray& a, const NDArray& b, const NDArray& x, NDArray& output), FLOAT_TYPES); - - -} -} -} +BUILD_SINGLE_TEMPLATE(template void betaIncForArray, + (sd::LaunchContext * context, const NDArray& a, + const NDArray& b, const NDArray& x, NDArray& output), + FLOAT_TYPES); +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/clip.cpp b/libnd4j/include/ops/declarable/helpers/cpu/clip.cpp index 2c2d9a111ddd..7520bf127012 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/clip.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/clip.cpp @@ -24,39 +24,37 @@ #include #include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { ////////////////////////////////////////////////////////////////////////// -void clipByNorm(sd::LaunchContext* context, NDArray& input, NDArray& output, const std::vector& dimensions, const NDArray& clipNorm, const bool isInplace, const bool useAverage) { +void clipByNorm(sd::LaunchContext* context,NDArray& input, NDArray& output, + const std::vector& dimensions, + const NDArray& clipNorm, const bool isInplace, + const bool useAverage) { - NDArray* z = nullptr; + NDArray* z = nullptr; - if(isInplace) { - z = &input; + if (isInplace) { +z = &input; } else { output.assign(input); z = &output; - } - - if(dimensions.empty()) { + } if (dimensions.empty()) { const NDArray actualNorm = useAverage ? z->reduceAlongDimension(reduce::Norm2, {}) / z->lengthOf() : z->reduceAlongDimension(reduce::Norm2, {}); - - if(actualNorm.e(0) > clipNorm.e(0)) + if (actualNorm.e(0) > clipNorm.e(0) ) *z *= clipNorm / actualNorm; - } - else { - - auto listOfSubArrs = z->allTensorsAlongDimension(dimensions); + } else { + auto listOfSubArrs = z->allTensorsAlongDimension(dimensions); auto func = PRAGMA_THREADS_FOR { for (auto i = start; i < stop; i++) { - const NDArray actualNorm = useAverage ? listOfSubArrs.at(i)->reduceAlongDimension(reduce::Norm2, {}) / listOfSubArrs.at(i)->lengthOf() : listOfSubArrs.at(i)->reduceAlongDimension(reduce::Norm2, {}); + const NDArray actualNorm = useAverage ? listOfSubArrs.at(i).reduceAlongDimension(reduce::Norm2, {}) / listOfSubArrs.at(i).lengthOf() : listOfSubArrs.at(i).reduceAlongDimension(reduce::Norm2, {}); if(actualNorm.e(0) > clipNorm.e(0)) - *listOfSubArrs.at(i) *= clipNorm / actualNorm; + listOfSubArrs.at(i) *= clipNorm / actualNorm; } }; samediff::Threads::parallel_tad(func, 0, listOfSubArrs.size()); @@ -106,10 +104,10 @@ static void clipByNormBp_(const NDArray& input, const NDArray& gradO, NDArray& g for (auto i = start; i < stop; i++) { - auto gradOSubArr = gradOSubArrs.at(i); + auto gradOSubArr = gradOSubArrs.at(i); auto gradISubArr = gradISubArrs.at(i); - const T norm = useAverage ? norm2.e(i) / gradISubArr->lengthOf() : norm2.e(i); + const T norm = useAverage ? norm2.e(i) / gradISubArr.lengthOf() : norm2.e(i); if (norm > clipVal) { @@ -123,10 +121,10 @@ static void clipByNormBp_(const NDArray& input, const NDArray& gradO, NDArray& g return factor1 * y * (static_cast(1.f) - factor2 * x * sum); }; - inputSubArr->applyPairwiseLambda(*gradOSubArr, lambda, *gradISubArr); + inputSubArr.applyPairwiseLambda(gradOSubArr, lambda, gradISubArr); } else - gradISubArr->assign(gradOSubArr); + gradISubArr.assign(gradOSubArr); } }; samediff::Threads::parallel_tad(func, 0, gradISubArrs.size()); @@ -135,72 +133,91 @@ static void clipByNormBp_(const NDArray& input, const NDArray& gradO, NDArray& g BUILD_SINGLE_TEMPLATE(template void clipByNormBp_, (const NDArray& input, const NDArray& gradO, NDArray& gradI, const std::vector& dimensions, const NDArray& clipNorm, const bool useAverage), FLOAT_TYPES); ////////////////////////////////////////////////////////////////////////// -void clipByNormBp(sd::LaunchContext* context, const NDArray& input, const NDArray& gradO, NDArray& gradI, const std::vector& dimensions, const NDArray& clipNorm, const bool useAverage) { +void clipByNormBp(sd::LaunchContext* context, const NDArray& input, const NDArray& gradO, NDArray& gradI, + const std::vector& dimensions, const NDArray& clipNorm, + const bool useAverage) { const NDArray& castedInput = gradI.dataType() == input.dataType() ? input : input.cast(gradI.dataType()); - - BUILD_SINGLE_SELECTOR(gradI.dataType(), clipByNormBp_, (castedInput, gradO, gradI, dimensions, clipNorm, useAverage), FLOAT_TYPES); + BUILD_SINGLE_SELECTOR(gradI.dataType(), clipByNormBp_, (castedInput, gradO, gradI, dimensions, clipNorm, useAverage), + FLOAT_TYPES); } - template - static void clipByGlobalNorm_(std::vector const& inputs, double clipNorm, sd::memory::Workspace* workspace, std::vector& outputs, bool isInplace) { - T globalNorm = 0; //NDArrayFactory::create(0, inputs[0]->getContext()); //sqrt(sum([l2norm(t)**2 for t in t_list])) -// PRAGMA_OMP_PARALLEL_FOR_SIMD_REDUCTION(sumT : globalNorm) - for (size_t i = 0; i < inputs.size(); i++) { - auto input = inputs[i]; - auto l2norm = input->reduceNumber(reduce::Norm2); - globalNorm += l2norm.t(0) * l2norm.t(0); - } - - //globalNorm.applyTransform(transform::Sqrt, nullptr, nullptr);// = sd::math::nd4j_sqrt(globalNorm); - auto normS = sd::math::nd4j_sqrt(globalNorm); - outputs[inputs.size()]->p(0, normS); +template +static void clipByGlobalNorm_(std::vector const& inputs, + double clipNorm, sd::memory::Workspace* workspace, + std::vector& outputs, bool isInplace) { + T globalNorm = 0; // NDArrayFactory::create(0, inputs[0]->getContext()); + // //sqrt(sum([l2norm(t)**2 for t in t_list])) + // PRAGMA_OMP_PARALLEL_FOR_SIMD_REDUCTION(sumT : globalNorm) + for (size_t i = 0; i < inputs.size(); i++) { + auto input = inputs[i]; + auto l2norm = input->reduceNumber(reduce::Norm2); + globalNorm += l2norm.t(0) * l2norm.t(0); + } - const T factor = clipNorm / normS; + // globalNorm.applyTransform(transform::Sqrt, nullptr, nullptr);// = + // sd::math::nd4j_sqrt(globalNorm); + auto normS = sd::math::nd4j_sqrt(globalNorm); + outputs[inputs.size()]->p(0, normS); -// PRAGMA_OMP_PARALLEL_FOR - for (size_t e = 0; e < inputs.size(); e++) { - // all-reduce - auto input = inputs[e]; - auto output = outputs[e]; + const T factor = clipNorm / normS; - if (normS <= clipNorm) { - output->assign(input); - } - else { + // PRAGMA_OMP_PARALLEL_FOR + for (size_t e = 0; e < inputs.size(); e++) { + // all-reduce + auto input = inputs[e]; + auto output = outputs[e]; - auto lambda = LAMBDA_T(_x, factor) { return _x * factor; }; - input->applyLambda(lambda, *output); - } - } - } - void clipByGlobalNorm(sd::LaunchContext * context, std::vector const& inputs, double clipNorm, sd::memory::Workspace* workspace, std::vector& outputs, bool isInplace) { - BUILD_SINGLE_SELECTOR(outputs[0]->dataType(), clipByGlobalNorm_, (inputs, clipNorm, workspace, outputs, isInplace), FLOAT_TYPES); + if (normS <= clipNorm) { + output->assign(input); + } else { + auto lambda = LAMBDA_T(_x, factor) { return _x * factor; }; + input->applyLambda(lambda, *output); } + } +} +void clipByGlobalNorm(sd::LaunchContext* context, + std::vector const& inputs, double clipNorm, + sd::memory::Workspace* workspace, + std::vector& outputs, bool isInplace) { + BUILD_SINGLE_SELECTOR(outputs[0]->dataType(), clipByGlobalNorm_, + (inputs, clipNorm, workspace, outputs, isInplace), + FLOAT_TYPES); +} - BUILD_SINGLE_TEMPLATE(template void clipByGlobalNorm_, (std::vector const& inputs, double clipNorm, sd::memory::Workspace* workspace, std::vector& outputs, bool isInplace), FLOAT_TYPES); +BUILD_SINGLE_TEMPLATE(template void clipByGlobalNorm_, + (std::vector const& inputs, double clipNorm, + sd::memory::Workspace* workspace, + std::vector& outputs, bool isInplace), + FLOAT_TYPES); - template - static void clipByValue_(NDArray& input, double leftBound, double rightBound, NDArray& output) { - auto routine = LAMBDA_T(_x, leftBound, rightBound) { - if (_x > rightBound) return rightBound; - if (_x < leftBound) return leftBound; - return _x; - }; +template +static void clipByValue_(NDArray& input, double leftBound, double rightBound, + NDArray& output) { + auto routine = LAMBDA_T(_x, leftBound, rightBound) { + if (_x > rightBound) return rightBound; + if (_x < leftBound) return leftBound; + return _x; + }; - input.applyLambda(routine, output); - } + input.applyLambda(routine, output); +} - void clipByValue(sd::LaunchContext * context, NDArray& input, double leftBound, double rightBound, NDArray& output) { - BUILD_SINGLE_SELECTOR(input.dataType(), clipByValue_, (input, leftBound, rightBound, output), FLOAT_TYPES); - } +void clipByValue(sd::LaunchContext* context, NDArray& input, double leftBound, + double rightBound, NDArray& output) { + BUILD_SINGLE_SELECTOR(input.dataType(), clipByValue_, + (input, leftBound, rightBound, output), FLOAT_TYPES); +} - BUILD_SINGLE_TEMPLATE(template void clipByValue_, (NDArray& input, double leftBound, double rightBound, NDArray& output);, FLOAT_TYPES); +BUILD_SINGLE_TEMPLATE(template void clipByValue_, + (NDArray & input, double leftBound, double rightBound, + NDArray& output); + , FLOAT_TYPES); -} -} -} +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/col2im.cpp b/libnd4j/include/ops/declarable/helpers/cpu/col2im.cpp index 42d4af529217..9f13b3af177e 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/col2im.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/col2im.cpp @@ -18,8 +18,8 @@ // Created by raver119 on 30.11.17. // -#include #include +#include namespace sd { namespace ops { @@ -27,115 +27,120 @@ namespace helpers { // [bS, iC, kH, kW, oH, oW] is de-convoluted to [bS, iC, iH, iW] template -void col2im_(sd::LaunchContext & context, const NDArray& input, NDArray& output, const int sH, const int sW, const int pH, const int pW, const int iH, const int iW, const int dH, const int dW) { - - auto imBuff = output.bufferAsT(); - auto colBuff = input.bufferAsT(); - auto imShapeBuffer = output.shapeInfo(); - auto colShapeBuffer = input.shapeInfo(); - auto colShape = shape::shapeOf(colShapeBuffer); - auto colStride = shape::stride(colShapeBuffer); - auto imShape = shape::shapeOf(imShapeBuffer); - auto imStride = shape::stride(imShapeBuffer); - - const int bS = imShape[0]; - const int iC = imShape[1]; - const int kH = colShape[2]; - const int kW = colShape[3]; - const int oH = colShape[4]; - const int oW = colShape[5]; - const Nd4jLong colStride0 = colStride[0]; - const Nd4jLong colStride1 = colStride[1]; - const Nd4jLong colStride2 = colStride[2]; - const Nd4jLong colStride3 = colStride[3]; - const Nd4jLong colStride4 = colStride[4]; - const Nd4jLong colStride5 = colStride[5]; - const Nd4jLong imStride0 = imStride[0]; - const Nd4jLong imStride1 = imStride[1]; - const Nd4jLong imStride2 = imStride[2]; - const Nd4jLong imStride3 = imStride[3]; - - - // if (shape::order(colShapeBuffer) == 'c' && shape::order(imShapeBuffer) == 'c' && shape::strideDescendingCAscendingF(colShapeBuffer) && shape::strideDescendingCAscendingF(imShapeBuffer)) { - if (false) { - - auto func = PRAGMA_THREADS_FOR_2D { - T const* col; - T* im; - - int imRow, imCol; - - for (auto b = start_x; b < stop_x; b += inc_x) { - for (auto c = start_y; c < stop_y; c += inc_y) { - for (int kRow = 0; kRow < kH; ++kRow) { - for (int kCol = 0; kCol < kW; ++kCol) { - for (int colH = 0; colH < oH; ++colH) { - for (int colW = 0; colW < oW; ++colW) { - - imRow = (-pH + kRow * dH) + colH * sH; - imCol = (-pW + kCol * dW) + colW * sW; - - col = colBuff + b * colStride0 + c * colStride1 + kRow * colStride2 + kCol * colStride3 + colH * colStride4 + colW * colStride5; - im = imBuff + b * imStride0 + c * imStride1 + imRow * imStride2 + imCol * imStride3; - - if (static_cast(imRow) < static_cast(iH) && - static_cast(imCol) < static_cast(iW)) - *im += *col; - } - } - } - } +void col2im_(sd::LaunchContext& context, const NDArray& input, NDArray& output, + const int sH, const int sW, const int pH, const int pW, + const int iH, const int iW, const int dH, const int dW) { + auto imBuff = output.bufferAsT(); + auto colBuff = input.bufferAsT(); + auto imShapeBuffer = output.shapeInfo(); + auto colShapeBuffer = input.shapeInfo(); + auto colShape = shape::shapeOf(colShapeBuffer); + auto colStride = shape::stride(colShapeBuffer); + auto imShape = shape::shapeOf(imShapeBuffer); + auto imStride = shape::stride(imShapeBuffer); + + const int bS = imShape[0]; + const int iC = imShape[1]; + const int kH = colShape[2]; + const int kW = colShape[3]; + const int oH = colShape[4]; + const int oW = colShape[5]; + const Nd4jLong colStride0 = colStride[0]; + const Nd4jLong colStride1 = colStride[1]; + const Nd4jLong colStride2 = colStride[2]; + const Nd4jLong colStride3 = colStride[3]; + const Nd4jLong colStride4 = colStride[4]; + const Nd4jLong colStride5 = colStride[5]; + const Nd4jLong imStride0 = imStride[0]; + const Nd4jLong imStride1 = imStride[1]; + const Nd4jLong imStride2 = imStride[2]; + const Nd4jLong imStride3 = imStride[3]; + + // if (shape::order(colShapeBuffer) == 'c' && shape::order(imShapeBuffer) == + // 'c' && shape::strideDescendingCAscendingF(colShapeBuffer) && + // shape::strideDescendingCAscendingF(imShapeBuffer)) { + if (false) { + auto func = PRAGMA_THREADS_FOR_2D { + T const* col; + T* im; + + int imRow, imCol; + + for (auto b = start_x; b < stop_x; b += inc_x) { + for (auto c = start_y; c < stop_y; c += inc_y) { + for (int kRow = 0; kRow < kH; ++kRow) { + for (int kCol = 0; kCol < kW; ++kCol) { + for (int colH = 0; colH < oH; ++colH) { + for (int colW = 0; colW < oW; ++colW) { + imRow = (-pH + kRow * dH) + colH * sH; + imCol = (-pW + kCol * dW) + colW * sW; + + col = colBuff + b * colStride0 + c * colStride1 + + kRow * colStride2 + kCol * colStride3 + + colH * colStride4 + colW * colStride5; + im = imBuff + b * imStride0 + c * imStride1 + + imRow * imStride2 + imCol * imStride3; + + if (static_cast(imRow) < + static_cast(iH) && + static_cast(imCol) < static_cast(iW)) + *im += *col; } + } } - }; - - samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1); - } - else { - - auto func = PRAGMA_THREADS_FOR { - T *col, *im; - - for (auto b = start; b < stop; b++) { - T *im0 = imBuff + b * imStride0; - T const* col4 = colBuff + b * colStride0; - for (int colH = 0; colH < oH; ++colH, col4 += colStride4) { - T const* col5 = col4; - for (int colW = 0; colW < oW; ++colW, col5 += colStride5) { - T const* col1 = col5; - T *im1 = im0; - for (int c = 0; c < iC; ++c, col1 += colStride1, im1 += imStride1) { - int imRow = (-pH + colH * sH); - T const* col2 = col1; - T *im2 = im1 + imRow * imStride2; - for (int kRow = 0; - kRow < kH; ++kRow, col2 += colStride2, imRow += dH, im2 += dH * imStride2) { - int imCol = -pW + colW * sW; - T const* col3 = col2; - T *im3 = im2 + imCol * imStride3; - for (int kCol = 0; - kCol < kW; ++kCol, col3 += colStride3, imCol += dW, im3 += dW * imStride3) { - - if (static_cast(imRow) < static_cast(iH) && - static_cast(imCol) < static_cast(iW)) - *im3 += *col3; - } - } - } - } + } + } + } + }; + + samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1); + } else { + auto func = PRAGMA_THREADS_FOR { + T *col, *im; + + for (auto b = start; b < stop; b++) { + T* im0 = imBuff + b * imStride0; + T const* col4 = colBuff + b * colStride0; + for (int colH = 0; colH < oH; ++colH, col4 += colStride4) { + T const* col5 = col4; + for (int colW = 0; colW < oW; ++colW, col5 += colStride5) { + T const* col1 = col5; + T* im1 = im0; + for (int c = 0; c < iC; ++c, col1 += colStride1, im1 += imStride1) { + int imRow = (-pH + colH * sH); + T const* col2 = col1; + T* im2 = im1 + imRow * imStride2; + for (int kRow = 0; kRow < kH; ++kRow, col2 += colStride2, + imRow += dH, im2 += dH * imStride2) { + int imCol = -pW + colW * sW; + T const* col3 = col2; + T* im3 = im2 + imCol * imStride3; + for (int kCol = 0; kCol < kW; ++kCol, col3 += colStride3, + imCol += dW, im3 += dW * imStride3) { + if (static_cast(imRow) < + static_cast(iH) && + static_cast(imCol) < static_cast(iW)) + *im3 += *col3; } + } } - }; + } + } + } + }; - samediff::Threads::parallel_tad(func, 0, bS); - } + samediff::Threads::parallel_tad(func, 0, bS); + } } - -void col2im(sd::LaunchContext & context, const NDArray& input, NDArray& output, const int sH, const int sW, const int pH, const int pW, const int iH, const int iW, const int dH, const int dW) { - BUILD_SINGLE_SELECTOR(input.dataType(), col2im_, (context, input, output, sH, sW, pH, pW, iH, iW, dH, dW), FLOAT_TYPES); +void col2im(sd::LaunchContext& context, const NDArray& input, NDArray& output, + const int sH, const int sW, const int pH, const int pW, + const int iH, const int iW, const int dH, const int dW) { + BUILD_SINGLE_SELECTOR( + input.dataType(), col2im_, + (context, input, output, sH, sW, pH, pW, iH, iW, dH, dW), FLOAT_TYPES); } -} -} -} +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/compare_elem.cpp b/libnd4j/include/ops/declarable/helpers/cpu/compare_elem.cpp index 32dc3d7c79b9..e33d89b7d302 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/compare_elem.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/compare_elem.cpp @@ -14,62 +14,65 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -#include #include +#include namespace sd { - namespace ops { - namespace helpers { - template - static void _compare_elem(NDArray *input, bool isStrictlyIncreasing, bool& output) { - auto length = shape::length(input->shapeInfo()); - - int elementsPerThread = length / ELEMENT_THRESHOLD; - int num_threads = sd::math::nd4j_max(1, elementsPerThread); - num_threads = sd::math::nd4j_min(num_threads, omp_get_max_threads()); - Nd4jLong sumt = 0; - - if(isStrictlyIncreasing) { - //PRAGMA_OMP_PARALLEL_FOR_SIMD_REDUCTION(+:sum) - auto func = PRAGMA_REDUCE_LONG { - Nd4jLong sum = 0; - for (auto i = start; i < stop; i++) { - auto val0 = input->t(i); - auto val1 = input->t(i + 1); - sum += val0 >= val1 ? -1 : 0; - } - return sum; - }; - sumt = samediff::Threads::parallel_long(func, LAMBDA_SUML, 0, length - 1); - } else { - //PRAGMA_OMP_PARALLEL_FOR_SIMD_REDUCTION(+:sum) - auto func = PRAGMA_REDUCE_LONG { - Nd4jLong sum = 0; - for (auto i = start; i < stop; i++) { - auto val0 = input->t(i); - auto val1 = input->t(i + 1); - sum += val0 > val1 ? -1 : 0; - } +namespace ops { +namespace helpers { +template +static void _compare_elem(NDArray* input, bool isStrictlyIncreasing, + bool& output) { + auto length = shape::length(input->shapeInfo()); - return sum; - }; - sumt = samediff::Threads::parallel_long(func, LAMBDA_SUML, 0, length - 1); - } + int elementsPerThread = length / ELEMENT_THRESHOLD; + int num_threads = sd::math::nd4j_max(1, elementsPerThread); + num_threads = sd::math::nd4j_min(num_threads, omp_get_max_threads()); + Nd4jLong sumt = 0; - //nd4j_printf("Sum: %lld\n", sumt) + if (isStrictlyIncreasing) { + // PRAGMA_OMP_PARALLEL_FOR_SIMD_REDUCTION(+:sum) + auto func = PRAGMA_REDUCE_LONG { + Nd4jLong sum = 0; + for (auto i = start; i < stop; i++) { + auto val0 = input->t(i); + auto val1 = input->t(i + 1); + sum += val0 >= val1 ? -1 : 0; + } + return sum; + }; + sumt = samediff::Threads::parallel_long(func, LAMBDA_SUML, 0, length - 1); + } else { + // PRAGMA_OMP_PARALLEL_FOR_SIMD_REDUCTION(+:sum) + auto func = PRAGMA_REDUCE_LONG { + Nd4jLong sum = 0; + for (auto i = start; i < stop; i++) { + auto val0 = input->t(i); + auto val1 = input->t(i + 1); + sum += val0 > val1 ? -1 : 0; + } - output = (sumt > -1); + return sum; + }; + sumt = samediff::Threads::parallel_long(func, LAMBDA_SUML, 0, length - 1); + } - } + // nd4j_printf("Sum: %lld\n", sumt) - void compare_elem(sd::LaunchContext * context, NDArray *input, bool isStrictlyIncreasing, bool& output) { - auto xType = input->dataType(); - - BUILD_SINGLE_SELECTOR(xType, _compare_elem, (input, isStrictlyIncreasing, output), LIBND4J_TYPES); - } + output = (sumt > -1); +} +void compare_elem(sd::LaunchContext* context, NDArray* input, + bool isStrictlyIncreasing, bool& output) { + auto xType = input->dataType(); - BUILD_SINGLE_TEMPLATE(template void _compare_elem, (NDArray *A, bool isStrictlyIncreasing, bool& output);, LIBND4J_TYPES); - } - } + BUILD_SINGLE_SELECTOR(xType, _compare_elem, + (input, isStrictlyIncreasing, output), LIBND4J_TYPES); } + +BUILD_SINGLE_TEMPLATE(template void _compare_elem, + (NDArray * A, bool isStrictlyIncreasing, bool& output); + , LIBND4J_TYPES); +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/compression/compression.cpp b/libnd4j/include/ops/declarable/helpers/cpu/compression/compression.cpp index 0911b0619478..53eade55f857 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/compression/compression.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/compression/compression.cpp @@ -17,21 +17,25 @@ // // @author sgazeos@gmail.com // -#include #include +#include namespace sd { namespace ops { namespace helpers { - void decodeBitmap(sd::LaunchContext* context, const NDArray* input, NDArray* output) { - NativeOpExecutioner::decodeBitmap(input->buffer(), output->lengthOf(), output->buffer(), output->shapeInfo()); - } - - - Nd4jLong encodeBitmap(sd::LaunchContext* context, NDArray* input, NDArray* output, float threshold) { - return NativeOpExecutioner::encodeBitmap(input->buffer(), input->shapeInfo(), input->lengthOf(), output->bufferAsT(), threshold); - } -} +void decodeBitmap(sd::LaunchContext* context, const NDArray* input, + NDArray* output) { + NativeOpExecutioner::decodeBitmap(input->buffer(), output->lengthOf(), + output->buffer(), output->shapeInfo()); } + +Nd4jLong encodeBitmap(sd::LaunchContext* context, NDArray* input, + NDArray* output, float threshold) { + return NativeOpExecutioner::encodeBitmap(input->buffer(), input->shapeInfo(), + input->lengthOf(), + output->bufferAsT(), threshold); } +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/compression/threshold.cpp b/libnd4j/include/ops/declarable/helpers/cpu/compression/threshold.cpp index bac3812d1739..33f2cffb711f 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/compression/threshold.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/compression/threshold.cpp @@ -18,45 +18,53 @@ // @author raver119@gmail.com // -#include #include #include +#include namespace sd { - namespace ops { - namespace helpers { - template - static int32_t thresholdEstimate_(const NDArray &updates, const float threshold) { - auto N = updates.lengthOf(); - const auto buffer = updates.bufferAsT(); - - auto func = PRAGMA_REDUCE_LONG { - int64_t cnt = 0; - for (auto e = start; e < stop; e++) { - auto v = sd::math::nd4j_abs(buffer[e]); - if (v >= threshold) - cnt++; - } - - return cnt; - }; - - return samediff::Threads::parallel_long(func, LAMBDA_AL { return _old + _new; }, 0, N); - } - - int32_t thresholdEstimate(const NDArray &updates, const float threshold) { - BUILD_SINGLE_SELECTOR(updates.dataType(), return thresholdEstimate_, (updates, threshold), FLOAT_TYPES); - - return 0; - } - - void thresholdEncode(NDArray &updates, NDArray &encoded, float threshold) { - BUILD_SINGLE_SELECTOR(updates.dataType(), sd::TypeCast::convertToThreshold, (nullptr, updates.buffer(), updates.lengthOf(), encoded.buffer()), FLOAT_TYPES); - } - - void thresholdDecode(const NDArray &encoded, NDArray &updates) { - BUILD_SINGLE_SELECTOR(updates.dataType(), sd::TypeCast::convertFromThreshold, (nullptr, encoded.buffer(), updates.lengthOf(), updates.buffer()), FLOAT_TYPES); - } - } +namespace ops { +namespace helpers { +template +static int32_t thresholdEstimate_(const NDArray &updates, + const float threshold) { + auto N = updates.lengthOf(); + const auto buffer = updates.bufferAsT(); + + auto func = PRAGMA_REDUCE_LONG { + int64_t cnt = 0; + for (auto e = start; e < stop; e++) { + auto v = sd::math::nd4j_abs(buffer[e]); + if (v >= threshold) cnt++; } + + return cnt; + }; + + return samediff::Threads::parallel_long( + func, LAMBDA_AL { return _old + _new; }, 0, N); +} + +int32_t thresholdEstimate(const NDArray &updates, const float threshold) { + BUILD_SINGLE_SELECTOR(updates.dataType(), return thresholdEstimate_, + (updates, threshold), FLOAT_TYPES); + + return 0; +} + +void thresholdEncode(NDArray &updates, NDArray &encoded, float threshold) { + BUILD_SINGLE_SELECTOR( + updates.dataType(), sd::TypeCast::convertToThreshold, + (nullptr, updates.buffer(), updates.lengthOf(), encoded.buffer()), + FLOAT_TYPES); +} + +void thresholdDecode(const NDArray &encoded, NDArray &updates) { + BUILD_SINGLE_SELECTOR( + updates.dataType(), sd::TypeCast::convertFromThreshold, + (nullptr, encoded.buffer(), updates.lengthOf(), updates.buffer()), + FLOAT_TYPES); } +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/concat.cpp b/libnd4j/include/ops/declarable/helpers/cpu/concat.cpp index 9fd2b5b02381..cdb74c618ca0 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/concat.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/concat.cpp @@ -18,24 +18,30 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 20.04.2018 // - #include #include namespace sd { - namespace ops { - namespace helpers { - ////////////////////////////////////////////////////////////////////////// - template - static void concat_(const std::vector& inArrs, NDArray& output, const int axis) { - sd::SpecialMethods::concatCpuGeneric(inArrs, output, axis); - } +namespace ops { +namespace helpers { +////////////////////////////////////////////////////////////////////////// +template +static void concat_(const std::vector& inArrs, NDArray& output, + const int axis) { + sd::SpecialMethods::concatCpuGeneric(inArrs, output, axis); +} - void concat(sd::LaunchContext * context, const std::vector& inArrs, NDArray& output, const int axis) { - BUILD_SINGLE_SELECTOR(output.dataType(), concat_,(inArrs, output, axis), LIBND4J_TYPES); - } +void concat(sd::LaunchContext* context, + const std::vector& inArrs, NDArray& output, + const int axis) { + BUILD_SINGLE_SELECTOR(output.dataType(), concat_, (inArrs, output, axis), + LIBND4J_TYPES); +} - BUILD_SINGLE_TEMPLATE(template void concat_, (const std::vector& inArrs, NDArray& output, const int axis), LIBND4J_TYPES); - } - } -} \ No newline at end of file +BUILD_SINGLE_TEMPLATE(template void concat_, + (const std::vector& inArrs, + NDArray& output, const int axis), + LIBND4J_TYPES); +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/confusion.cpp b/libnd4j/include/ops/declarable/helpers/cpu/confusion.cpp index 685d80d2dcf0..f4ec0c667dc1 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/confusion.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/confusion.cpp @@ -18,39 +18,44 @@ // @author GS // -#include #include - +#include namespace sd { namespace ops { namespace helpers { - template - void _confusionFunctor(NDArray* labels, NDArray* predictions, NDArray* weights, NDArray* output) { - ResultSet arrs = output->allTensorsAlongDimension({1}); - int lLen = labels->lengthOf(); - - auto func = PRAGMA_THREADS_FOR { - for (int j = start; j < stop; j++) { - auto label = labels->e(j); - auto pred = predictions->e(j); - T value = (weights == nullptr ? (T) 1.0f : weights->e(j)); - arrs.at(label)->p(pred, value); - } - }; - - samediff::Threads::parallel_for(func, 0, lLen); +template +void _confusionFunctor(NDArray* labels, NDArray* predictions, NDArray* weights, + NDArray* output) { + ResultSet arrs = output->allTensorsAlongDimension({1}); + int lLen = labels->lengthOf(); + + auto func = PRAGMA_THREADS_FOR { + for (int j = start; j < stop; j++) { + auto label = labels->e(j); + auto pred = predictions->e(j); + T value = (weights == nullptr ? (T)1.0f : weights->e(j)); + arrs.at(label).p(pred, value); } + }; - void confusionFunctor(sd::LaunchContext * context, NDArray* labels, NDArray* predictions, NDArray* weights, NDArray* output) { - auto xType = output->dataType(); // weights can be null - - BUILD_SINGLE_SELECTOR(xType, _confusionFunctor, (labels, predictions, weights, output), NUMERIC_TYPES); - } + samediff::Threads::parallel_for(func, 0, lLen); +} - BUILD_SINGLE_TEMPLATE(template void _confusionFunctor, (NDArray* labels, NDArray* predictions, NDArray* weights, NDArray* output);, NUMERIC_TYPES); +void confusionFunctor(sd::LaunchContext* context, NDArray* labels, + NDArray* predictions, NDArray* weights, NDArray* output) { + auto xType = output->dataType(); // weights can be null + BUILD_SINGLE_SELECTOR(xType, _confusionFunctor, + (labels, predictions, weights, output), NUMERIC_TYPES); } -} -} \ No newline at end of file + +BUILD_SINGLE_TEMPLATE(template void _confusionFunctor, + (NDArray * labels, NDArray* predictions, NDArray* weights, + NDArray* output); + , NUMERIC_TYPES); + +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/convolutions_col2vol.cpp b/libnd4j/include/ops/declarable/helpers/cpu/convolutions_col2vol.cpp index b12064cacda8..2715fb0ae539 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/convolutions_col2vol.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/convolutions_col2vol.cpp @@ -18,126 +18,150 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 18.09.2018 // -#include #include +#include namespace sd { - namespace ops { +namespace ops { ////////////////////////////////////////////////////////////////////////// // [bS, iC, kD, kH, kW, oD, oH, oW] is de-convoluted to [bS, iC, iD, iH, iW] template -static void col2vol_(const NDArray& columns, NDArray& volume, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW) { - - // initial zeroing of volume content - volume.nullify(); - - const int bS = volume.sizeAt(0); - const int iC = volume.sizeAt(1); - const int iD = volume.sizeAt(2); - const int iH = volume.sizeAt(3); - const int iW = volume.sizeAt(4); - const int kD = columns.sizeAt(2); - const int kH = columns.sizeAt(3); - const int kW = columns.sizeAt(4); - const int oD = columns.sizeAt(5); - const int oH = columns.sizeAt(6); - const int oW = columns.sizeAt(7); - const Nd4jLong colStride0 = columns.stridesOf()[0]; - const Nd4jLong colStride1 = columns.stridesOf()[1]; - const Nd4jLong colStride2 = columns.stridesOf()[2]; - const Nd4jLong colStride3 = columns.stridesOf()[3]; - const Nd4jLong colStride4 = columns.stridesOf()[4]; - const Nd4jLong colStride5 = columns.stridesOf()[5]; - const Nd4jLong colStride6 = columns.stridesOf()[6]; - const Nd4jLong colStride7 = columns.stridesOf()[7]; - const Nd4jLong volStride0 = volume.stridesOf()[0]; - const Nd4jLong volStride1 = volume.stridesOf()[1]; - const Nd4jLong volStride2 = volume.stridesOf()[2]; - const Nd4jLong volStride3 = volume.stridesOf()[3]; - const Nd4jLong volStride4 = volume.stridesOf()[4]; - - T* volBuff = volume.bufferAsT(); - T* colBuff = const_cast(columns).bufferAsT(); - - - if (volume.ordering() == 'c' && columns.ordering() == 'c' && shape::strideDescendingCAscendingF(volume.shapeInfo()) && shape::strideDescendingCAscendingF(columns.shapeInfo())) { - - auto func = PRAGMA_THREADS_FOR { - T* col, *vol; - int volDep, volRow, volCol; - - for (int b = start; b < stop; b++) { - for (int c = 0; c < iC; c++) { - for (int kDep = 0; kDep < kD; ++kDep) { - for (int kRow = 0; kRow < kH; ++kRow) { - for (int kCol = 0; kCol < kW; ++kCol) { - for (int colD = 0; colD < oD; ++colD) { - for (int colH = 0; colH < oH; ++colH) { - for (int colW = 0; colW < oW; ++colW) { - - volDep = -pD + kDep * dD + colD * sD; - volRow = -pH + kRow * dH + colH * sH; - volCol = -pW + kCol * dW + colW * sW; - - if (static_cast(volDep) < static_cast(iD) && static_cast(volRow) < static_cast(iH) && static_cast(volCol) < static_cast(iW)) { - col = colBuff + b * colStride0 + c * colStride1 + kDep * colStride2 + kRow * colStride3 + kCol * colStride4 + colD * colStride5 + colH * colStride6 + colW * colStride7; - vol = volBuff + b * volStride0 + c * volStride1 + volDep * volStride2 + volRow * volStride3 + volCol * volStride4; - *vol += *col; - } - } - } - } - } - } - } - } +static void col2vol_(const NDArray& columns, NDArray& volume, const int sD, + const int sH, const int sW, const int pD, const int pH, + const int pW, const int dD, const int dH, const int dW) { + // initial zeroing of volume content + volume.nullify(); + + const int bS = volume.sizeAt(0); + const int iC = volume.sizeAt(1); + const int iD = volume.sizeAt(2); + const int iH = volume.sizeAt(3); + const int iW = volume.sizeAt(4); + const int kD = columns.sizeAt(2); + const int kH = columns.sizeAt(3); + const int kW = columns.sizeAt(4); + const int oD = columns.sizeAt(5); + const int oH = columns.sizeAt(6); + const int oW = columns.sizeAt(7); + const Nd4jLong colStride0 = columns.stridesOf()[0]; + const Nd4jLong colStride1 = columns.stridesOf()[1]; + const Nd4jLong colStride2 = columns.stridesOf()[2]; + const Nd4jLong colStride3 = columns.stridesOf()[3]; + const Nd4jLong colStride4 = columns.stridesOf()[4]; + const Nd4jLong colStride5 = columns.stridesOf()[5]; + const Nd4jLong colStride6 = columns.stridesOf()[6]; + const Nd4jLong colStride7 = columns.stridesOf()[7]; + const Nd4jLong volStride0 = volume.stridesOf()[0]; + const Nd4jLong volStride1 = volume.stridesOf()[1]; + const Nd4jLong volStride2 = volume.stridesOf()[2]; + const Nd4jLong volStride3 = volume.stridesOf()[3]; + const Nd4jLong volStride4 = volume.stridesOf()[4]; + + T* volBuff = volume.bufferAsT(); + T* colBuff = const_cast(columns).bufferAsT(); + + if (volume.ordering() == 'c' && columns.ordering() == 'c' && + shape::strideDescendingCAscendingF(volume.shapeInfo()) && + shape::strideDescendingCAscendingF(columns.shapeInfo())) { + auto func = PRAGMA_THREADS_FOR { + T *col, *vol; + int volDep, volRow, volCol; + + for (int b = start; b < stop; b++) { + for (int c = 0; c < iC; c++) { + for (int kDep = 0; kDep < kD; ++kDep) { + for (int kRow = 0; kRow < kH; ++kRow) { + for (int kCol = 0; kCol < kW; ++kCol) { + for (int colD = 0; colD < oD; ++colD) { + for (int colH = 0; colH < oH; ++colH) { + for (int colW = 0; colW < oW; ++colW) { + volDep = -pD + kDep * dD + colD * sD; + volRow = -pH + kRow * dH + colH * sH; + volCol = -pW + kCol * dW + colW * sW; + + if (static_cast(volDep) < + static_cast(iD) && + static_cast(volRow) < + static_cast(iH) && + static_cast(volCol) < + static_cast(iW)) { + col = colBuff + b * colStride0 + c * colStride1 + + kDep * colStride2 + kRow * colStride3 + + kCol * colStride4 + colD * colStride5 + + colH * colStride6 + colW * colStride7; + vol = volBuff + b * volStride0 + c * volStride1 + + volDep * volStride2 + volRow * volStride3 + + volCol * volStride4; + *vol += *col; + } } - }; - - samediff::Threads::parallel_tad(func, 0, bS); - - } else { - - auto func = PRAGMA_THREADS_FOR { - T* col, *vol; - int volDep, volRow, volCol; - - for (int b = start; b < stop; b++) { - for (int colD = 0; colD < oD; colD++) { - for (int colH = 0; colH < oH; ++colH) { - for (int colW = 0; colW < oW; ++colW) { - for (int c = 0; c < iC; ++c) { - for (int kDep = 0; kDep < kD; ++kDep) { - for (int kRow = 0; kRow < kH; ++kRow) { - for (int kCol = 0; kCol < kW; ++kCol) { - - volDep = (-pD + kDep * dD) + colD * sD; - volRow = (-pH + kRow * dH) + colH * sH; - volCol = (-pW + kCol * dW) + colW * sW; - - if (static_cast(volDep) < static_cast(iD) && static_cast(volRow) < static_cast(iH) && static_cast(volCol) < static_cast(iW)) { - col = colBuff + b * colStride0 + c * colStride1 + kDep * colStride2 + kRow * colStride3 + kCol * colStride4 + colD * colStride5 + colH * colStride6 + colW * colStride7; - vol = volBuff + b * volStride0 + c * volStride1 + volDep * volStride2 + volRow * volStride3 + volCol * volStride4; - *vol += *col; - } - } - } - } - } - } - } - } + } + } + } + } + } + } + } + }; + + samediff::Threads::parallel_tad(func, 0, bS); + + } else { + auto func = PRAGMA_THREADS_FOR { + T *col, *vol; + int volDep, volRow, volCol; + + for (int b = start; b < stop; b++) { + for (int colD = 0; colD < oD; colD++) { + for (int colH = 0; colH < oH; ++colH) { + for (int colW = 0; colW < oW; ++colW) { + for (int c = 0; c < iC; ++c) { + for (int kDep = 0; kDep < kD; ++kDep) { + for (int kRow = 0; kRow < kH; ++kRow) { + for (int kCol = 0; kCol < kW; ++kCol) { + volDep = (-pD + kDep * dD) + colD * sD; + volRow = (-pH + kRow * dH) + colH * sH; + volCol = (-pW + kCol * dW) + colW * sW; + + if (static_cast(volDep) < + static_cast(iD) && + static_cast(volRow) < + static_cast(iH) && + static_cast(volCol) < + static_cast(iW)) { + col = colBuff + b * colStride0 + c * colStride1 + + kDep * colStride2 + kRow * colStride3 + + kCol * colStride4 + colD * colStride5 + + colH * colStride6 + colW * colStride7; + vol = volBuff + b * volStride0 + c * volStride1 + + volDep * volStride2 + volRow * volStride3 + + volCol * volStride4; + *vol += *col; + } } - }; - - samediff::Threads::parallel_tad(func, 0, bS); + } + } + } } + } } + } + }; -void ConvolutionUtils::col2vol(sd::graph::Context& block, const NDArray& columns, NDArray& volume, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW) { - BUILD_SINGLE_SELECTOR(volume.dataType(), col2vol_, (columns, volume, sD, sH, sW, pD, pH, pW, dD, dH, dW), FLOAT_TYPES); + samediff::Threads::parallel_tad(func, 0, bS); + } } +void ConvolutionUtils::col2vol(sd::graph::Context& block, + const NDArray& columns, NDArray& volume, + const int sD, const int sH, const int sW, + const int pD, const int pH, const int pW, + const int dD, const int dH, const int dW) { + BUILD_SINGLE_SELECTOR(volume.dataType(), col2vol_, + (columns, volume, sD, sH, sW, pD, pH, pW, dD, dH, dW), + FLOAT_TYPES); } -} + +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/convolutions_conv2d.cpp b/libnd4j/include/ops/declarable/helpers/cpu/convolutions_conv2d.cpp index 45e66651c19c..cd05df093c7e 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/convolutions_conv2d.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/convolutions_conv2d.cpp @@ -18,90 +18,115 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 18.09.2018 // -#include -#include -#include -#include #include -#include #include +#include +#include +#include +#include +#include namespace sd { - namespace ops { - +namespace ops { ////////////////////////////////////////////////////////////////////////// template -static void conv2d_(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { - - // input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - // weights [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] - // bias [oC] - // output [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) - - // kH filter(kernel) height - // kW filter(kernel) width - // sH strides height - // sW strides width - // pH paddings height - // pW paddings width - // dH dilations height - // dW dilations width - // paddingMode 0-VALID, 1-SAME - // isNCHW 1-NCHW, 0-NHWC - - int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; - int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); - - ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode); - - nd4j_debug("MKL-DNN is not used for conv2d!\n", 0); - - std::vector permutForOutput; - - if(isNCHW) - permutForOutput = {0, 3, 1, 2}; // [bS, oH, oW, oC] -> [bS, oC, oH, oW] - else - input = new NDArray(input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] if NHWC - - std::vector wAxes; - if(0 == wFormat) - wAxes = {0, 1, 2}; - else if(1 == wFormat) - wAxes = {2, 3, 1}; - else - wAxes = {1, 2, 3}; - - NDArray col('c', {bS, oH, oW, kH, kW, iC}, input->dataType(), input->getContext()); - NDArray colP = col.permute({0, 5, 3, 4, 1, 2}); // {bS, iC, kH, kW, oH, oW} - NDArray mmulResult('f', {bS*oH*oW, oC}, output->dataType(), output->getContext()); - - //----- calculation of output -----// - auto ctx = block.launchContext(); - helpers::im2col(*ctx, *input, colP, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext())); // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW] - MmulHelper::tensorDot(&col, weights, &mmulResult, {3,4,5}, wAxes, {}); // [bS, oH, oW, kH, kW, iC] x [kH, kW, iC, oC] = [bS, oH, oW, oC] - - //----- assign outTemp to output -----// - if(isNCHW) { - mmulResult.reshapei({bS, oH, oW, oC}); - mmulResult.permutei(permutForOutput); - } - output->assign(mmulResult); - - //----- add biases if required -----// - if(bias) - // output->applyBroadcast(broadcast::Add, {indIOioC}, bias); - helpers::addBias(block, *output, *bias, *output, isNCHW); - - if(!isNCHW) - delete input; - - } - -void ConvolutionUtils::conv2d(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { - BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), conv2d_, (block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat), FLOAT_TYPES); +static void conv2d_(sd::graph::Context& block, const NDArray* input, + const NDArray* weights, const NDArray* bias, + NDArray* output, const int kH, const int kW, const int sH, + const int sW, int pH, int pW, const int dH, const int dW, + const int paddingMode, const int isNCHW, + const int wFormat) { + // input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + // weights [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] + // bias [oC] + // output [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) + + // kH filter(kernel) height + // kW filter(kernel) width + // sH strides height + // sW strides width + // pH paddings height + // pW paddings width + // dH dilations height + // dW dilations width + // paddingMode 0-VALID, 1-SAME + // isNCHW 1-NCHW, 0-NHWC + + int bS, iC, iH, iW, oC, oH, + oW; // batch size, input channels, input height/width, output channels, + // output height/width; + int indIOioC, indIiH, indWoC, indWiC, indWkH, + indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, + indIiH, indWiC, indWoC, indWkH, indOoH); + + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, + dW, paddingMode); + + nd4j_debug("MKL-DNN is not used for conv2d!\n", 0); + + std::vector permutForOutput; + + if (isNCHW) + permutForOutput = {0, 3, 1, 2}; // [bS, oH, oW, oC] -> [bS, oC, oH, oW] + else + input = new NDArray(input->permute( + {0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] if NHWC + + std::vector wAxes; + if (0 == wFormat) + wAxes = {0, 1, 2}; + else if (1 == wFormat) + wAxes = {2, 3, 1}; + else + wAxes = {1, 2, 3}; + + NDArray col('c', {bS, oH, oW, kH, kW, iC}, input->dataType(), + input->getContext()); + NDArray colP = col.permute({0, 5, 3, 4, 1, 2}); // {bS, iC, kH, kW, oH, oW} + NDArray mmulResult('f', {bS * oH * oW, oC}, output->dataType(), + output->getContext()); + + //----- calculation of output -----// + auto ctx = block.launchContext(); + helpers::im2col( + *ctx, *input, colP, kH, kW, sH, sW, pH, pW, dH, dW, + NDArrayFactory::create( + 0.f, input->getContext())); // [bS, iC, iH, iW] is convoluted to [bS, + // iC, kH, kW, oH, oW] + MmulHelper::tensorDot( + &col, weights, &mmulResult, {3, 4, 5}, wAxes, + {}); // [bS, oH, oW, kH, kW, iC] x [kH, kW, iC, oC] = [bS, oH, oW, oC] + + //----- assign outTemp to output -----// + if (isNCHW) { + mmulResult.reshapei({bS, oH, oW, oC}); + mmulResult.permutei(permutForOutput); + } + output->assign(mmulResult); + + //----- add biases if required -----// + if (bias) + // output->applyBroadcast(broadcast::Add, {indIOioC}, bias); + helpers::addBias(block, *output, *bias, *output, isNCHW); + + if (!isNCHW) delete input; } +void ConvolutionUtils::conv2d(sd::graph::Context& block, const NDArray* input, + const NDArray* weights, const NDArray* bias, + NDArray* output, const int kH, const int kW, + const int sH, const int sW, int pH, int pW, + const int dH, const int dW, const int paddingMode, + const int isNCHW, const int wFormat) { + BUILD_SINGLE_SELECTOR_TWICE( + input->dataType(), conv2d_, + (block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, + paddingMode, isNCHW, wFormat), + FLOAT_TYPES); } -} + +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/convolutions_conv2dBP.cpp b/libnd4j/include/ops/declarable/helpers/cpu/convolutions_conv2dBP.cpp index 6a01a4a4dda5..9c367713d961 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/convolutions_conv2dBP.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/convolutions_conv2dBP.cpp @@ -18,110 +18,144 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 18.09.2018 // -#include -#include -#include #include -#include #include +#include +#include +#include +#include namespace sd { - namespace ops { - +namespace ops { ////////////////////////////////////////////////////////////////////////// template -static void conv2dBP_(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { - - // input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - // weights [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] - // bias [oC] - // gradO [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next - - // gradI [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon - // gradW [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] - // gradB [oC] - - // kH filter(kernel) height - // kW filter(kernel) width - // sH strides height - // sW strides width - // pH paddings height - // pW paddings width - // dH dilations height - // dW dilations width - // paddingMode 0-VALID, 1-SAME - // isNCHW 0-NHWC, 1-NCHW - - int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; - int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); - - ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode); - - nd4j_debug("MKL-DNN is not used for conv2d_bp!\n", 0); - - std::vector gradOaxesForDot; - - if(!isNCHW) { - gradOaxesForDot = {0, 1, 2}; // bS, oH, oW - input = new NDArray(input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] - gradI = new NDArray(gradI->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] - } else { - gradOaxesForDot = {0, 2, 3}; // bS, oH, oW - } - - std::vector wPermut, colPermut; - - if(0 == wFormat) { - wPermut = {2, 0, 1, 3}; - colPermut = {2, 3, 1, 0, 4, 5}; - } - else if(1 == wFormat) { - wPermut = {1, 2, 3, 0}; - colPermut = {1, 2, 3, 0, 4, 5}; - } - else { - wPermut = {3, 1, 2, 0}; - colPermut = {2, 3, 1, 0, 4, 5}; - } - - NDArray columns(input->ordering(), {bS, iC, kH, kW, oH, oW}, input->dataType(), input->getContext()); - - // ----- calculation of gradW ----- // - if(gradW) { - auto ctx = block.launchContext(); - helpers::im2col(*ctx, *input, columns, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext())); // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW] - sd::MmulHelper::tensorDot(&columns, gradO, gradW, {0,4,5}, gradOaxesForDot, wPermut); // [bS, iC, kH, kW, oH, oW] x [bS, oH, oW, oC]/[bS, oC, oH, oW] = [iC, kH, kW, oC] - } - - // ----- calculation of gradB ----- // - if(gradB) { - NDArray* gradBR = gradB; - if(gradB->rankOf() == 2) - gradBR = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()})); - gradO->reduceAlongDimension(reduce::Sum, *gradBR, gradOaxesForDot); // sum over bS, oH, oW - if(gradBR != gradB) - delete gradBR; - } - - //----- calculation of gradI -----// - // [kH, kW, iC, oC] x [bS, oH, oW, oC]/[bS, oC, oH, oW] = [kH, kW, iC, bS, oH, oW] - // [oC, iC, kH, kW] x [bS, oH, oW, oC]/[bS, oC, oH, oW] = [iC, kH, kW, bS, oH, oW] - // [oC, kH, kW, iC] x [bS, oH, oW, oC]/[bS, oC, oH, oW] = [kH, kW, iC, bS, oH, oW] - sd::MmulHelper::tensorDot(weights, gradO, &columns, {indWoC}, {indIOioC}, colPermut); - - helpers::col2im(*block.launchContext(), columns, *gradI, sH, sW, pH, pW, iH, iW, dH, dW); // [bS, iC, kH, kW, oH, oW] is de-convoluted to [bS, iC, iH, iW] - - if(!isNCHW) { - delete input; - delete gradI; - } - } - -void ConvolutionUtils::conv2dBP(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { - BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), conv2dBP_, (block, input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat), FLOAT_TYPES); +static void conv2dBP_(sd::graph::Context& block, const NDArray* input, + const NDArray* weights, const NDArray* bias, + const NDArray* gradO, NDArray* gradI, NDArray* gradW, + NDArray* gradB, const int kH, const int kW, const int sH, + const int sW, int pH, int pW, const int dH, const int dW, + const int paddingMode, const int isNCHW, + const int wFormat) { + // input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + // weights [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] + // bias [oC] + // gradO [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next + + // gradI [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon + // gradW [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] + // gradB [oC] + + // kH filter(kernel) height + // kW filter(kernel) width + // sH strides height + // sW strides width + // pH paddings height + // pW paddings width + // dH dilations height + // dW dilations width + // paddingMode 0-VALID, 1-SAME + // isNCHW 0-NHWC, 1-NCHW + + int bS, iC, iH, iW, oC, oH, + oW; // batch size, input channels, input height/width, output channels, + // output height/width; + int indIOioC, indIiH, indWoC, indWiC, indWkH, + indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, + indIiH, indWiC, indWoC, indWkH, indOoH); + + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, + dW, paddingMode); + + nd4j_debug("MKL-DNN is not used for conv2d_bp!\n", 0); + + std::vector gradOaxesForDot; + + if (!isNCHW) { + gradOaxesForDot = {0, 1, 2}; // bS, oH, oW + input = new NDArray( + input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] + gradI = new NDArray( + gradI->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] + } else { + gradOaxesForDot = {0, 2, 3}; // bS, oH, oW + } + + std::vector wPermut, colPermut; + + if (0 == wFormat) { + wPermut = {2, 0, 1, 3}; + colPermut = {2, 3, 1, 0, 4, 5}; + } else if (1 == wFormat) { + wPermut = {1, 2, 3, 0}; + colPermut = {1, 2, 3, 0, 4, 5}; + } else { + wPermut = {3, 1, 2, 0}; + colPermut = {2, 3, 1, 0, 4, 5}; + } + + NDArray columns(input->ordering(), {bS, iC, kH, kW, oH, oW}, + input->dataType(), input->getContext()); + + // ----- calculation of gradW ----- // + if (gradW) { + auto ctx = block.launchContext(); + helpers::im2col( + *ctx, *input, columns, kH, kW, sH, sW, pH, pW, dH, dW, + NDArrayFactory::create( + 0.f, input->getContext())); // [bS, iC, iH, iW] is convoluted to + // [bS, iC, kH, kW, oH, oW] + sd::MmulHelper::tensorDot( + &columns, gradO, gradW, {0, 4, 5}, gradOaxesForDot, + wPermut); // [bS, iC, kH, kW, oH, oW] x [bS, oH, oW, oC]/[bS, oC, oH, + // oW] = [iC, kH, kW, oC] + } + + // ----- calculation of gradB ----- // + if (gradB) { + NDArray* gradBR = gradB; + if (gradB->rankOf() == 2) + gradBR = new NDArray( + gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()})); + gradO->reduceAlongDimension(reduce::Sum, *gradBR, + gradOaxesForDot); // sum over bS, oH, oW + if (gradBR != gradB) delete gradBR; + } + + //----- calculation of gradI -----// + // [kH, kW, iC, oC] x [bS, oH, oW, oC]/[bS, oC, oH, oW] = [kH, kW, iC, bS, oH, + // oW] [oC, iC, kH, kW] x [bS, oH, oW, oC]/[bS, oC, oH, oW] = [iC, kH, kW, bS, + // oH, oW] [oC, kH, kW, iC] x [bS, oH, oW, oC]/[bS, oC, oH, oW] = [kH, kW, iC, + // bS, oH, oW] + sd::MmulHelper::tensorDot(weights, gradO, &columns, {indWoC}, {indIOioC}, + colPermut); + + helpers::col2im( + *block.launchContext(), columns, *gradI, sH, sW, pH, pW, iH, iW, dH, + dW); // [bS, iC, kH, kW, oH, oW] is de-convoluted to [bS, iC, iH, iW] + + if (!isNCHW) { + delete input; + delete gradI; + } } +void ConvolutionUtils::conv2dBP(sd::graph::Context& block, const NDArray* input, + const NDArray* weights, const NDArray* bias, + const NDArray* gradO, NDArray* gradI, + NDArray* gradW, NDArray* gradB, const int kH, + const int kW, const int sH, const int sW, + int pH, int pW, const int dH, const int dW, + const int paddingMode, const int isNCHW, + const int wFormat) { + BUILD_SINGLE_SELECTOR_TWICE( + input->dataType(), conv2dBP_, + (block, input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, + pH, pW, dH, dW, paddingMode, isNCHW, wFormat), + FLOAT_TYPES); } -} \ No newline at end of file + +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/convolutions_depthwiseConv2d.cpp b/libnd4j/include/ops/declarable/helpers/cpu/convolutions_depthwiseConv2d.cpp index fa86dbd6c199..6be9e4014710 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/convolutions_depthwiseConv2d.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/convolutions_depthwiseConv2d.cpp @@ -18,84 +18,119 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 18.09.2018 // -#include -#include -#include -#include #include -#include #include +#include +#include +#include +#include +#include namespace sd { - namespace ops { - +namespace ops { ////////////////////////////////////////////////////////////////////////// template -static void depthwiseConv2d_(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { - - // input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - // weights [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] - // bias [oC] = iC*mC - // output [bS, oH, oW, iC*mC] (NHWC) or [bS, iC*mC, oH, oW] (NCHW) - - // kH filter(kernel) height - // kW filter(kernel) width - // sH strides height - // sW strides width - // pH paddings height - // pW paddings width - // dH dilations height - // dW dilations width - // paddingMode 0-VALID, 1-SAME - // isNCHW 0-NCHW, 1-NHWC - - int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width - int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); - mC = weights->sizeAt(indWmC); // channels multiplier - - std::vector> modifColumns = {{1,0,4,5,2,3}, {iC,bS*oH*oW,kH*kW}}; // [bS,iC,kH,kW,oH,oW] -> [iC,bS,oH,oW,kH,kW] -> [iC,bS*oH*oW,kH*kW] - std::vector> modifOutput, modifWeights; - std::vector outReShape; - - if(!isNCHW) { - outReShape = {bS, oH, oW, iC, mC}; // [bS,oH,oW,iC*mC] -> [bS,oH,oW,iC,mC] - modifOutput = {{3,0,1,2,4},{iC, bS*oH*oW, mC}}; // [bS,oH,oW,iC,mC] -> [iC,bS,oH,oW,mC] -> [iC,bS*oH*oW,mC] - input = new NDArray(input->permute({0, 3, 1, 2})); // [bS,iH,iW,iC] -> [bS,iC,iH,iW] - } - else { - outReShape = {bS, iC, mC, oH, oW}; // [bS,iC*mC,oH,oW] -> [bS,iC,mC,oH,oW] - modifOutput = {{1,0,3,4,2},{iC, bS*oH*oW, mC}}; // [bS,iC,mC,oH,oW] -> [iC,bS,oH,oW,mC] -> [iC,bS*oH*oW,mC] - } - - if(0 == wFormat) - modifWeights = {{2,0,1,3},{iC,kH*kW,mC}}; - else if(1 == wFormat) - modifWeights = {{1,2,3,0},{iC,kH*kW,mC}}; - else - modifWeights = {{3,1,2,0},{iC,kH*kW,mC}}; - - if(paddingMode == 1) // SAME - ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); - - NDArray columns(input->ordering(), {bS, iC, kH, kW, oH, oW}, input->dataType(), input->getContext()); - NDArray outputReshaped = output->reshape(output->ordering(), outReShape, false); - - helpers::im2col(*output->getContext(), *input, columns, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext())); // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW] - MmulHelper::tensorDot(&columns, weights, &outputReshaped, modifColumns, modifWeights, modifOutput); // [iC, bS*oH*oW, kW*kH] x [iC, kH*kW, mC] = [iC, bS*oH*oW, mC] - - if(bias) - // output->applyBroadcast(broadcast::Add, {indIOioC}, bias); - helpers::addBias(block, *output, *bias, *output, isNCHW); - - if(!isNCHW) - delete input; - } - -void ConvolutionUtils::depthwiseConv2d(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { - BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), depthwiseConv2d_, (block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat), FLOAT_TYPES); - } - +static void depthwiseConv2d_(sd::graph::Context& block, const NDArray* input, + const NDArray* weights, const NDArray* bias, + NDArray* output, const int kH, const int kW, + const int sH, const int sW, int pH, int pW, + const int dH, const int dW, const int paddingMode, + const int isNCHW, const int wFormat) { + // input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + // weights [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] + // bias [oC] = iC*mC + // output [bS, oH, oW, iC*mC] (NHWC) or [bS, iC*mC, oH, oW] (NCHW) + + // kH filter(kernel) height + // kW filter(kernel) width + // sH strides height + // sW strides width + // pH paddings height + // pW paddings width + // dH dilations height + // dW dilations width + // paddingMode 0-VALID, 1-SAME + // isNCHW 0-NCHW, 1-NHWC + + int bS, iC, iH, iW, mC, oC, oH, + oW; // batch size, input channels, input height/width, channels + // multiplier(oC = iC*mC), output channels, output height/width + int indIOioC, indIiH, indWmC, indWiC, indWkH, + indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, + indIiH, indWiC, indWmC, indWkH, indOoH); + mC = weights->sizeAt(indWmC); // channels multiplier + + std::vector> modifColumns = { + {1, 0, 4, 5, 2, 3}, + {iC, bS * oH * oW, + kH * kW}}; // [bS,iC,kH,kW,oH,oW] -> [iC,bS,oH,oW,kH,kW] -> + // [iC,bS*oH*oW,kH*kW] + std::vector> modifOutput, modifWeights; + std::vector outReShape; + + if (!isNCHW) { + outReShape = {bS, oH, oW, iC, mC}; // [bS,oH,oW,iC*mC] -> [bS,oH,oW,iC,mC] + modifOutput = { + {3, 0, 1, 2, 4}, + {iC, bS * oH * oW, + mC}}; // [bS,oH,oW,iC,mC] -> [iC,bS,oH,oW,mC] -> [iC,bS*oH*oW,mC] + input = new NDArray( + input->permute({0, 3, 1, 2})); // [bS,iH,iW,iC] -> [bS,iC,iH,iW] + } else { + outReShape = {bS, iC, mC, oH, oW}; // [bS,iC*mC,oH,oW] -> [bS,iC,mC,oH,oW] + modifOutput = { + {1, 0, 3, 4, 2}, + {iC, bS * oH * oW, + mC}}; // [bS,iC,mC,oH,oW] -> [iC,bS,oH,oW,mC] -> [iC,bS*oH*oW,mC] + } + + if (0 == wFormat) + modifWeights = {{2, 0, 1, 3}, {iC, kH * kW, mC}}; + else if (1 == wFormat) + modifWeights = {{1, 2, 3, 0}, {iC, kH * kW, mC}}; + else + modifWeights = {{3, 1, 2, 0}, {iC, kH * kW, mC}}; + + if (paddingMode == 1) // SAME + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, + dW); + + NDArray columns(input->ordering(), {bS, iC, kH, kW, oH, oW}, + input->dataType(), input->getContext()); + NDArray outputReshaped = + output->reshape(output->ordering(), outReShape, false); + + helpers::im2col( + *output->getContext(), *input, columns, kH, kW, sH, sW, pH, pW, dH, dW, + NDArrayFactory::create( + 0.f, input->getContext())); // [bS, iC, iH, iW] is convoluted to [bS, + // iC, kH, kW, oH, oW] + MmulHelper::tensorDot(&columns, weights, &outputReshaped, modifColumns, + modifWeights, + modifOutput); // [iC, bS*oH*oW, kW*kH] x [iC, kH*kW, + // mC] = [iC, bS*oH*oW, mC] + + if (bias) + // output->applyBroadcast(broadcast::Add, {indIOioC}, bias); + helpers::addBias(block, *output, *bias, *output, isNCHW); + + if (!isNCHW) delete input; } + +void ConvolutionUtils::depthwiseConv2d( + sd::graph::Context& block, const NDArray* input, const NDArray* weights, + const NDArray* bias, NDArray* output, const int kH, const int kW, + const int sH, const int sW, int pH, int pW, const int dH, const int dW, + const int paddingMode, const int isNCHW, const int wFormat) { + BUILD_SINGLE_SELECTOR_TWICE( + input->dataType(), depthwiseConv2d_, + (block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, + paddingMode, isNCHW, wFormat), + FLOAT_TYPES); } + +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/convolutions_depthwiseConv2dBP.cpp b/libnd4j/include/ops/declarable/helpers/cpu/convolutions_depthwiseConv2dBP.cpp index 7c0d933e235a..5b704ce7f464 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/convolutions_depthwiseConv2dBP.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/convolutions_depthwiseConv2dBP.cpp @@ -18,103 +18,155 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 18.09.2018 // +#include +#include +#include #include #include -#include -#include -#include namespace sd { - namespace ops { - +namespace ops { ////////////////////////////////////////////////////////////////////////// template -static void depthwiseConv2dBP_(const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { - - // input [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW) - // weights [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] - // bias [oC] = [iC*mC] - // gradO [bS, oH, oW, oC] (NDHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next - // gradI [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW), epsilon - // gradW [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] - // gradB [oC] - - // kH filter(kernel) height - // kW filter(kernel) width - // sH strides height - // sW strides width - // pH paddings height - // pW paddings width - // dH dilations height - // dW dilations width - // paddingMode 0-VALID, 1-SAME - // isNCHW 0-NHWC, 1-NCHW - - int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width - int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); - mC = weights->sizeAt(indWmC); // channels multiplier - - std::vector> modifColumns = {{1,2,3,0,4,5}, {iC, kH*kW, bS*oH*oW}}; // [bS,iC,kH,kW,oH,oW] -> [iC, kH*kW, bS*oH*oW] - std::vector> modifGradO1, modifGradO2, modifWeights; - std::vector gradOreShape; - - if(!isNCHW) { - gradOreShape = {bS, oH, oW, iC, mC}; // [bS,oH,oW,iC*mC] -> [bS,oH,oW,iC,mC] - modifGradO1 = {{3,0,1,2,4},{iC, bS*oH*oW, mC}}; // [bS,oH,oW,iC,mC] -> [iC,bS,oH,oW,mC] -> [iC,bS*oH*oW,mC] - modifGradO2 = {{3,0,1,2},{iC, mC, bS*oH*oW}}; // [bS,oH,oW,iC*mC] -> [iC*mC,bS,oH,oW] -> [iC,mC,bS*oH*oW] - input = new NDArray(input->permute({0, 3, 1, 2})); // [bS,iH,iW,iC] -> [bS,iC,iH,iW] - gradI = new NDArray(gradI->permute({0, 3, 1, 2})); // [bS,iH,iW,iC] -> [bS,iC,iH,iW] - } - else { - gradOreShape = {bS, iC, mC, oH, oW}; // [bS,iC*mC,oH,oW] -> [bS,iC,mC,oH,oW] - modifGradO1 = {{1,0,3,4,2},{iC, bS*oH*oW, mC}}; // [bS,iC,mC,oH,oW] -> [iC,bS,oH,oW,mC] -> [iC,bS*oH*oW,mC] - modifGradO2 = {{1,0,2,3},{iC, mC, bS*oH*oW}}; // [bS,iC*mC,oH,oW] -> [iC*mC,bS,oH,oW] -> [iC,mC,bS*oH*oW] - } - - if(0 == wFormat) - modifWeights = {{2,0,1,3},{iC,kH*kW,mC}}; - else if(1 == wFormat) - modifWeights = {{1,2,3,0},{iC,kH*kW,mC}}; - else - modifWeights = {{3,1,2,0},{iC,kH*kW,mC}}; - - if(paddingMode == 1) // SAME - ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); - - NDArray columns(input->ordering(), {bS, iC, kH, kW, oH, oW}, input->dataType(), input->getContext()); - NDArray gradOreshaped = gradO->reshape(gradO->ordering(), gradOreShape); - - // ----- calculation of gradW and gradB ----- // - - helpers::im2col(*input->getContext(), *input, columns, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext())); // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW] - sd::MmulHelper::tensorDot(&columns, &gradOreshaped, gradW, modifColumns, modifGradO1, modifWeights); // [iC, kW*kH, bS*oH*oW] x [iC, bS*oH*oW, mC] = [iC, kH*kW, mC] - - // ----- calculation of gradB ----- // - if(gradB) { - NDArray* gradBR = gradB; - if(gradB->rankOf() == 2) - gradBR = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()}, false)); - gradO->reduceAlongDimension(reduce::Sum, *gradBR, {0,indOoH,indOoH+1}); // sum over bS, oH, oW - - if(gradBR != gradB) - delete gradBR; - } - - //----- calculation of gradI -----// - sd::MmulHelper::tensorDot(weights, gradO, &columns, modifWeights, modifGradO2, modifColumns); // [iC, kH*kW, mC] x [iC, mC, bS*oH*oW] = [iC, kW*kH, bS*oH*oW] - helpers::col2im(*input->getContext(), columns, *gradI, sH, sW, pH, pW, iH, iW, dH, dW); // [bS, iC, kH, kW, oH, oW] is de-convoluted to [bS, iC, iH, iW] - - if(!isNCHW) { - delete input; - delete gradI; - } - } - -void ConvolutionUtils::depthwiseConv2dBP(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { - BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), depthwiseConv2dBP_, (input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat), FLOAT_TYPES); - } +static void depthwiseConv2dBP_(const NDArray* input, const NDArray* weights, + const NDArray* bias, const NDArray* gradO, + NDArray* gradI, NDArray* gradW, NDArray* gradB, + const int kH, const int kW, const int sH, + const int sW, int pH, int pW, const int dH, + const int dW, const int paddingMode, + const int isNCHW, const int wFormat) { + // input [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW) + // weights [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] + // bias [oC] = [iC*mC] + // gradO [bS, oH, oW, oC] (NDHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next + // gradI [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW), epsilon + // gradW [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] + // gradB [oC] + + // kH filter(kernel) height + // kW filter(kernel) width + // sH strides height + // sW strides width + // pH paddings height + // pW paddings width + // dH dilations height + // dW dilations width + // paddingMode 0-VALID, 1-SAME + // isNCHW 0-NHWC, 1-NCHW + + int bS, iC, iH, iW, mC, oC, oH, + oW; // batch size, input channels, input height/width, channels + // multiplier(oC = iC*mC), output channels, output height/width + int indIOioC, indIiH, indWmC, indWiC, indWkH, + indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, + indIiH, indWiC, indWmC, indWkH, indOoH); + mC = weights->sizeAt(indWmC); // channels multiplier + + std::vector> modifColumns = { + {1, 2, 3, 0, 4, 5}, + {iC, kH * kW, + bS * oH * oW}}; // [bS,iC,kH,kW,oH,oW] -> [iC, kH*kW, bS*oH*oW] + std::vector> modifGradO1, modifGradO2, modifWeights; + std::vector gradOreShape; + + if (!isNCHW) { + gradOreShape = {bS, oH, oW, iC, + mC}; // [bS,oH,oW,iC*mC] -> [bS,oH,oW,iC,mC] + modifGradO1 = { + {3, 0, 1, 2, 4}, + {iC, bS * oH * oW, + mC}}; // [bS,oH,oW,iC,mC] -> [iC,bS,oH,oW,mC] -> [iC,bS*oH*oW,mC] + modifGradO2 = { + {3, 0, 1, 2}, + {iC, mC, + bS * oH * + oW}}; // [bS,oH,oW,iC*mC] -> [iC*mC,bS,oH,oW] -> [iC,mC,bS*oH*oW] + input = new NDArray( + input->permute({0, 3, 1, 2})); // [bS,iH,iW,iC] -> [bS,iC,iH,iW] + gradI = new NDArray( + gradI->permute({0, 3, 1, 2})); // [bS,iH,iW,iC] -> [bS,iC,iH,iW] + } else { + gradOreShape = {bS, iC, mC, oH, + oW}; // [bS,iC*mC,oH,oW] -> [bS,iC,mC,oH,oW] + modifGradO1 = { + {1, 0, 3, 4, 2}, + {iC, bS * oH * oW, + mC}}; // [bS,iC,mC,oH,oW] -> [iC,bS,oH,oW,mC] -> [iC,bS*oH*oW,mC] + modifGradO2 = { + {1, 0, 2, 3}, + {iC, mC, + bS * oH * + oW}}; // [bS,iC*mC,oH,oW] -> [iC*mC,bS,oH,oW] -> [iC,mC,bS*oH*oW] + } + + if (0 == wFormat) + modifWeights = {{2, 0, 1, 3}, {iC, kH * kW, mC}}; + else if (1 == wFormat) + modifWeights = {{1, 2, 3, 0}, {iC, kH * kW, mC}}; + else + modifWeights = {{3, 1, 2, 0}, {iC, kH * kW, mC}}; + + if (paddingMode == 1) // SAME + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, + dW); + + NDArray columns(input->ordering(), {bS, iC, kH, kW, oH, oW}, + input->dataType(), input->getContext()); + NDArray gradOreshaped = gradO->reshape(gradO->ordering(), gradOreShape); + + // ----- calculation of gradW and gradB ----- // + + helpers::im2col( + *input->getContext(), *input, columns, kH, kW, sH, sW, pH, pW, dH, dW, + NDArrayFactory::create( + 0.f, input->getContext())); // [bS, iC, iH, iW] is convoluted to [bS, + // iC, kH, kW, oH, oW] + sd::MmulHelper::tensorDot(&columns, &gradOreshaped, gradW, modifColumns, + modifGradO1, + modifWeights); // [iC, kW*kH, bS*oH*oW] x [iC, + // bS*oH*oW, mC] = [iC, kH*kW, mC] + + // ----- calculation of gradB ----- // + if (gradB) { + NDArray* gradBR = gradB; + if (gradB->rankOf() == 2) + gradBR = new NDArray( + gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()}, false)); + gradO->reduceAlongDimension( + reduce::Sum, *gradBR, {0, indOoH, indOoH + 1}); // sum over bS, oH, oW + + if (gradBR != gradB) delete gradBR; + } + + //----- calculation of gradI -----// + sd::MmulHelper::tensorDot(weights, gradO, &columns, modifWeights, modifGradO2, + modifColumns); // [iC, kH*kW, mC] x [iC, mC, + // bS*oH*oW] = [iC, kW*kH, bS*oH*oW] + helpers::col2im( + *input->getContext(), columns, *gradI, sH, sW, pH, pW, iH, iW, dH, + dW); // [bS, iC, kH, kW, oH, oW] is de-convoluted to [bS, iC, iH, iW] + + if (!isNCHW) { + delete input; + delete gradI; + } +} +void ConvolutionUtils::depthwiseConv2dBP( + sd::graph::Context& block, const NDArray* input, const NDArray* weights, + const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, + NDArray* gradB, const int kH, const int kW, const int sH, const int sW, + int pH, int pW, const int dH, const int dW, const int paddingMode, + const int isNCHW, const int wFormat) { + BUILD_SINGLE_SELECTOR_TWICE( + input->dataType(), depthwiseConv2dBP_, + (input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, + dH, dW, paddingMode, isNCHW, wFormat), + FLOAT_TYPES); } -} \ No newline at end of file + +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/convolutions_pooling2d.cpp b/libnd4j/include/ops/declarable/helpers/cpu/convolutions_pooling2d.cpp index 26dc4f99eae3..a6942a7d30a8 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/convolutions_pooling2d.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/convolutions_pooling2d.cpp @@ -18,206 +18,272 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 18.09.2018 // -#include #include +#include namespace sd { - namespace ops { +namespace ops { ////////////////////////////////////////////////////////////////////////// - template - static void pooling2d_(sd::graph::Context& block, const NDArray& input, NDArray& output, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int poolingMode, const int extraParam0) { - // input is [bS, iC, iH, iW] - // output is [bS, iC, oH, oW] - T* out = output.bufferAsT(); - T* in = const_cast(input).bufferAsT(); - - const int kHEff = kH + (kH-1)*(dH-1); - const int kWEff = kW + (kW-1)*(dW-1); - - const int bS = input.sizeAt(0); - const int iC = input.sizeAt(1); - const int iH = input.sizeAt(2); - const int iW = input.sizeAt(3); - const int oC = output.sizeAt(1); - const int oH = output.sizeAt(2); - const int oW = output.sizeAt(3); - - nd4j_debug("MKL-DNN is not used for pooling2d!\n", 0); - - const Nd4jLong iStride0 = input.stridesOf()[0]; - const Nd4jLong iStride1 = input.stridesOf()[1]; - const Nd4jLong iStride2 = input.stridesOf()[2]; - const Nd4jLong iStride3 = input.stridesOf()[3]; - const Nd4jLong oStride0 = output.stridesOf()[0]; - const Nd4jLong oStride1 = output.stridesOf()[1]; - const Nd4jLong oStride2 = output.stridesOf()[2]; - const Nd4jLong oStride3 = output.stridesOf()[3]; - - const Nd4jLong iStep2 = dH*iStride2; - const Nd4jLong iStep3 = dW*iStride3; - const int kProd = kH*kW; - - if(poolingMode == 0) { // max - auto func = PRAGMA_THREADS_FOR_2D { - Nd4jLong hstart, wstart, hend, wend; - T *pIn; - - for (int b = start_x; b < stop_x; b += inc_x) { - for (int c = start_y; c < stop_y; c += inc_y) { - for (int oh = 0; oh < oH; ++oh) { - for (int ow = 0; ow < oW; ++ow) { - - pIn = in + b * iStride0 + c * iStride1; - - hstart = oh * sH - pH; - wstart = ow * sW - pW; - hend = hstart + kHEff; - wend = wstart + kWEff; - - if (hstart < 0) - hstart += dH * ((-hstart + dH - 1) / dH); // (Nd4jLong)sd::math::nd4j_ceil(static_cast(-hstart) / static_cast(dH)); - if (wstart < 0) - wstart += dW * ((-wstart + dW - 1) / dW); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(-wstart) / static_cast(dW)); - if (hend > iH) - hend -= dH * ((hend - iH + dH - 1) / dH); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(hend-iH) / static_cast(dH)); - if (wend > iW) - wend -= dW * ((wend - iW + dW - 1) / dW); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(wend-iW) / static_cast(dW)); - - hstart *= iStride2; - hend *= iStride2; - wstart *= iStride3; - wend *= iStride3; - - T max = -DataTypeUtils::max(); - - for (Nd4jLong kh = hstart; kh < hend; kh += iStep2) - for (Nd4jLong kw = wstart; kw < wend; kw += iStep3) { - T val = pIn[kh + kw]; - if (val > max) - max = val; - } - out[b * oStride0 + c * oStride1 + oh * oStride2 + ow * oStride3] = max; - } - } - } - } - }; - - samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1); - } -/*************************************************************************/ - else if(poolingMode == 1) { // avg - auto func = PRAGMA_THREADS_FOR_2D { - Nd4jLong hstart, wstart, hend, wend; - T *pIn; - - for (int b = start_x; b < stop_x; b += inc_x) { - for (int c = start_y; c < stop_y; c += inc_y) { - for (int oh = 0; oh < oH; ++oh) { - for (int ow = 0; ow < oW; ++ow) { - - pIn = in + b * iStride0 + c * iStride1; - - hstart = oh * sH - pH; - wstart = ow * sW - pW; - hend = hstart + kHEff; - wend = wstart + kWEff; - - if (hstart < 0) - hstart += dH * ((-hstart + dH - 1) / dH); // (Nd4jLong)sd::math::nd4j_ceil(static_cast(-hstart) / static_cast(dH)); - if (wstart < 0) - wstart += dW * ((-wstart + dW - 1) / dW); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(-wstart) / static_cast(dW)); - if (hend > iH) - hend -= dH * ((hend - iH + dH - 1) / dH); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(hend-iH) / static_cast(dH)); - if (wend > iW) - wend -= dW * ((wend - iW + dW - 1) / dW); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(wend-iW) / static_cast(dW)); - - hstart *= iStride2; - hend *= iStride2; - wstart *= iStride3; - wend *= iStride3; - - T sum = static_cast(0.f); - - for (Nd4jLong kh = hstart; kh < hend; kh += iStep2) - for (Nd4jLong kw = wstart; kw < wend; kw += iStep3) - sum += pIn[kh + kw]; - - if (extraParam0 == 0) { //Exclude padding - int a = (hend - hstart) / iStep2 + ((hend - hstart) % iStep2 == 0 ? 0 : 1); - int r = (wend - wstart) / iStep3 + ((wend - wstart) % iStep3 == 0 ? 0 : 1); - sum /= static_cast(a * r); // Accounts for dilation - } else if (extraParam0 == 1) //Include padding - sum /= kProd; - - out[b * oStride0 + c * oStride1 + oh * oStride2 + ow * oStride3] = sum; - } - } - } - } - }; - - samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1); +template +static void pooling2d_(sd::graph::Context& block, const NDArray& input, + NDArray& output, const int kH, const int kW, + const int sH, const int sW, const int pH, const int pW, + const int dH, const int dW, const int poolingMode, + const int extraParam0) { + // input is [bS, iC, iH, iW] + // output is [bS, iC, oH, oW] + T* out = output.bufferAsT(); + T* in = const_cast(input).bufferAsT(); + + const int kHEff = kH + (kH - 1) * (dH - 1); + const int kWEff = kW + (kW - 1) * (dW - 1); + + const int bS = input.sizeAt(0); + const int iC = input.sizeAt(1); + const int iH = input.sizeAt(2); + const int iW = input.sizeAt(3); + const int oC = output.sizeAt(1); + const int oH = output.sizeAt(2); + const int oW = output.sizeAt(3); + + nd4j_debug("MKL-DNN is not used for pooling2d!\n", 0); + + const Nd4jLong iStride0 = input.stridesOf()[0]; + const Nd4jLong iStride1 = input.stridesOf()[1]; + const Nd4jLong iStride2 = input.stridesOf()[2]; + const Nd4jLong iStride3 = input.stridesOf()[3]; + const Nd4jLong oStride0 = output.stridesOf()[0]; + const Nd4jLong oStride1 = output.stridesOf()[1]; + const Nd4jLong oStride2 = output.stridesOf()[2]; + const Nd4jLong oStride3 = output.stridesOf()[3]; + + const Nd4jLong iStep2 = dH * iStride2; + const Nd4jLong iStep3 = dW * iStride3; + const int kProd = kH * kW; + + if (poolingMode == 0) { // max + auto func = PRAGMA_THREADS_FOR_2D { + Nd4jLong hstart, wstart, hend, wend; + T* pIn; + + for (int b = start_x; b < stop_x; b += inc_x) { + for (int c = start_y; c < stop_y; c += inc_y) { + for (int oh = 0; oh < oH; ++oh) { + for (int ow = 0; ow < oW; ++ow) { + pIn = in + b * iStride0 + c * iStride1; + + hstart = oh * sH - pH; + wstart = ow * sW - pW; + hend = hstart + kHEff; + wend = wstart + kWEff; + + if (hstart < 0) + hstart += + dH * + ((-hstart + dH - 1) / + dH); // (Nd4jLong)sd::math::nd4j_ceil(static_cast(-hstart) + // / static_cast(dH)); + if (wstart < 0) + wstart += + dW * + ((-wstart + dW - 1) / + dW); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(-wstart) + /// static_cast(dW)); + if (hend > iH) + hend -= + dH * + ((hend - iH + dH - 1) / + dH); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(hend-iH) + /// static_cast(dH)); + if (wend > iW) + wend -= + dW * + ((wend - iW + dW - 1) / + dW); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(wend-iW) + /// static_cast(dW)); + + hstart *= iStride2; + hend *= iStride2; + wstart *= iStride3; + wend *= iStride3; + + T max = -DataTypeUtils::max(); + + for (Nd4jLong kh = hstart; kh < hend; kh += iStep2) + for (Nd4jLong kw = wstart; kw < wend; kw += iStep3) { + T val = pIn[kh + kw]; + if (val > max) max = val; + } + out[b * oStride0 + c * oStride1 + oh * oStride2 + ow * oStride3] = + max; } -/*************************************************************************/ - else if(poolingMode == 2) { // pnorm - auto func = PRAGMA_THREADS_FOR_2D { - Nd4jLong hstart, wstart, hend, wend; - T *pIn; - - for (int b = start_x; b < stop_x; b += inc_x) { - for (int c = start_y; c < stop_y; c += inc_y) { - for (int oh = 0; oh < oH; ++oh) { - for (int ow = 0; ow < oW; ++ow) { - - pIn = in + b * iStride0 + c * iStride1; - - hstart = oh * sH - pH; - wstart = ow * sW - pW; - hend = hstart + kHEff; - wend = wstart + kWEff; - - if (hstart < 0) - hstart += dH * ((-hstart + dH - 1) / dH); // (Nd4jLong)sd::math::nd4j_ceil(static_cast(-hstart) / static_cast(dH)); - if (wstart < 0) - wstart += dW * ((-wstart + dW - 1) / dW); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(-wstart) / static_cast(dW)); - if (hend > iH) - hend -= dH * ((hend - iH + dH - 1) / dH); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(hend-iH) / static_cast(dH)); - if (wend > iW) - wend -= dW * ((wend - iW + dW - 1) / dW); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(wend-iW) / static_cast(dW)); - - hstart *= iStride2; - hend *= iStride2; - wstart *= iStride3; - wend *= iStride3; - - T sum = static_cast(0.f); - - for (Nd4jLong kh = hstart; kh < hend; kh += iStep2) - for (Nd4jLong kw = wstart; kw < wend; kw += iStep3) - sum += sd::math::nd4j_pow(sd::math::nd4j_abs(pIn[kh + kw]), extraParam0); - - sum = sd::math::nd4j_pow(sum, static_cast((T) 1.f) / extraParam0); - - out[b * oStride0 + c * oStride1 + oh * oStride2 + ow * oStride3] = sum; - } - } - } - } - }; - - samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1); - } - else { - nd4j_printf("ConvolutionUtils::pooling2d: pooling mode argument can take three values only: 0, 1, 2, but got %i instead !\n", poolingMode); - throw ""; + } + } + } + }; + + samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1); + } + /*************************************************************************/ + else if (poolingMode == 1) { // avg + auto func = PRAGMA_THREADS_FOR_2D { + Nd4jLong hstart, wstart, hend, wend; + T* pIn; + + for (int b = start_x; b < stop_x; b += inc_x) { + for (int c = start_y; c < stop_y; c += inc_y) { + for (int oh = 0; oh < oH; ++oh) { + for (int ow = 0; ow < oW; ++ow) { + pIn = in + b * iStride0 + c * iStride1; + + hstart = oh * sH - pH; + wstart = ow * sW - pW; + hend = hstart + kHEff; + wend = wstart + kWEff; + + if (hstart < 0) + hstart += + dH * + ((-hstart + dH - 1) / + dH); // (Nd4jLong)sd::math::nd4j_ceil(static_cast(-hstart) + // / static_cast(dH)); + if (wstart < 0) + wstart += + dW * + ((-wstart + dW - 1) / + dW); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(-wstart) + /// static_cast(dW)); + if (hend > iH) + hend -= + dH * + ((hend - iH + dH - 1) / + dH); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(hend-iH) + /// static_cast(dH)); + if (wend > iW) + wend -= + dW * + ((wend - iW + dW - 1) / + dW); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(wend-iW) + /// static_cast(dW)); + + hstart *= iStride2; + hend *= iStride2; + wstart *= iStride3; + wend *= iStride3; + + T sum = static_cast(0.f); + + for (Nd4jLong kh = hstart; kh < hend; kh += iStep2) + for (Nd4jLong kw = wstart; kw < wend; kw += iStep3) + sum += pIn[kh + kw]; + + if (extraParam0 == 0) { // Exclude padding + int a = (hend - hstart) / iStep2 + + ((hend - hstart) % iStep2 == 0 ? 0 : 1); + int r = (wend - wstart) / iStep3 + + ((wend - wstart) % iStep3 == 0 ? 0 : 1); + sum /= static_cast(a * r); // Accounts for dilation + } else if (extraParam0 == 1) // Include padding + sum /= kProd; + + out[b * oStride0 + c * oStride1 + oh * oStride2 + ow * oStride3] = + sum; } + } } - - void ConvolutionUtils::pooling2d(sd::graph::Context& block, const NDArray& input, NDArray& output, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const PoolingType poolingMode, const int extraParam0) { - BUILD_SINGLE_SELECTOR(input.dataType(), pooling2d_, (block, input, output, kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0), FLOAT_TYPES); + } + }; + + samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1); + } + /*************************************************************************/ + else if (poolingMode == 2) { // pnorm + auto func = PRAGMA_THREADS_FOR_2D { + Nd4jLong hstart, wstart, hend, wend; + T* pIn; + + for (int b = start_x; b < stop_x; b += inc_x) { + for (int c = start_y; c < stop_y; c += inc_y) { + for (int oh = 0; oh < oH; ++oh) { + for (int ow = 0; ow < oW; ++ow) { + pIn = in + b * iStride0 + c * iStride1; + + hstart = oh * sH - pH; + wstart = ow * sW - pW; + hend = hstart + kHEff; + wend = wstart + kWEff; + + if (hstart < 0) + hstart += + dH * + ((-hstart + dH - 1) / + dH); // (Nd4jLong)sd::math::nd4j_ceil(static_cast(-hstart) + // / static_cast(dH)); + if (wstart < 0) + wstart += + dW * + ((-wstart + dW - 1) / + dW); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(-wstart) + /// static_cast(dW)); + if (hend > iH) + hend -= + dH * + ((hend - iH + dH - 1) / + dH); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(hend-iH) + /// static_cast(dH)); + if (wend > iW) + wend -= + dW * + ((wend - iW + dW - 1) / + dW); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(wend-iW) + /// static_cast(dW)); + + hstart *= iStride2; + hend *= iStride2; + wstart *= iStride3; + wend *= iStride3; + + T sum = static_cast(0.f); + + for (Nd4jLong kh = hstart; kh < hend; kh += iStep2) + for (Nd4jLong kw = wstart; kw < wend; kw += iStep3) + sum += sd::math::nd4j_pow( + sd::math::nd4j_abs(pIn[kh + kw]), extraParam0); + + sum = sd::math::nd4j_pow( + sum, static_cast((T)1.f) / extraParam0); + + out[b * oStride0 + c * oStride1 + oh * oStride2 + ow * oStride3] = + sum; + } + } } - + } + }; + + samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1); + } else { + nd4j_printf( + "ConvolutionUtils::pooling2d: pooling mode argument can take three " + "values only: 0, 1, 2, but got %i instead !\n", + poolingMode); + throw ""; + } } + +void ConvolutionUtils::pooling2d(sd::graph::Context& block, + const NDArray& input, NDArray& output, + const int kH, const int kW, const int sH, + const int sW, const int pH, const int pW, + const int dH, const int dW, + const PoolingType poolingMode, + const int extraParam0) { + BUILD_SINGLE_SELECTOR(input.dataType(), pooling2d_, + (block, input, output, kH, kW, sH, sW, pH, pW, dH, dW, + poolingMode, extraParam0), + FLOAT_TYPES); } + +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/convolutions_pooling2dBP.cpp b/libnd4j/include/ops/declarable/helpers/cpu/convolutions_pooling2dBP.cpp index 03f34bfae342..a6874e8f82a9 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/convolutions_pooling2dBP.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/convolutions_pooling2dBP.cpp @@ -18,289 +18,346 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 18.09.2018 // -#include #include +#include namespace sd { - namespace ops { +namespace ops { ////////////////////////////////////////////////////////////////////////// - template - static void pooling2dBP_(sd::graph::Context& block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int poolingMode, const int extraParam0) { - // input [bS, iC, iH, iW] - // gradI [bS, iC, iH, iW] -> gradI is output in this function - // gradO [bS, iC, oH, oW] - - // initial zeroing of gradI - gradI.nullify(); - - T* in = const_cast(input).bufferAsT(); - T* gO = const_cast(gradO).bufferAsT(); - T* gI = gradI.bufferAsT(); - - const int kHEff = kH + (kH-1)*(dH-1); - const int kWEff = kW + (kW-1)*(dW-1); - - const int bS = gradI.sizeAt(0); - const int iC = gradI.sizeAt(1); - const int iH = gradI.sizeAt(2); - const int iW = gradI.sizeAt(3); - const int oC = gradO.sizeAt(1); - const int oH = gradO.sizeAt(2); - const int oW = gradO.sizeAt(3); - - nd4j_debug("MKL-DNN is not used for pooling2d_bp!\n", 0); - - const Nd4jLong iStride0 = input.stridesOf()[0]; - const Nd4jLong iStride1 = input.stridesOf()[1]; - const Nd4jLong iStride2 = input.stridesOf()[2]; - const Nd4jLong iStride3 = input.stridesOf()[3]; - const Nd4jLong gIStride0 = gradI.stridesOf()[0]; - const Nd4jLong gIStride1 = gradI.stridesOf()[1]; - const Nd4jLong gIStride2 = gradI.stridesOf()[2]; - const Nd4jLong gIStride3 = gradI.stridesOf()[3]; - const Nd4jLong oStride0 = gradO.stridesOf()[0]; - const Nd4jLong oStride1 = gradO.stridesOf()[1]; - const Nd4jLong oStride2 = gradO.stridesOf()[2]; - const Nd4jLong oStride3 = gradO.stridesOf()[3]; - const Nd4jLong iStep2 = dH*iStride2; - const Nd4jLong iStep3 = dW*iStride3; - const Nd4jLong gIStep2 = dH*gIStride2; - const Nd4jLong gIStep3 = dW*gIStride3; - const int kProd = kH*kW; - - const bool sameStrides = iStride0 == gIStride0 && iStride1 == gIStride1 && iStride2 == gIStride2 && iStride3 == gIStride3; - - if(poolingMode == 0) { // max - auto func = PRAGMA_THREADS_FOR_2D { - Nd4jLong hstart, wstart,hend, wend, maxKH, maxKW; - T sum, valO, *pIn, *pgI; - - for (int b = start_x; b < stop_x; b += inc_x) { - for (int c = start_y; c < stop_y; c += inc_y) { - for (int oh = 0; oh < oH; ++oh) { - for (int ow = 0; ow < oW; ++ow) { - - pIn = in + b * iStride0 + c * iStride1; - - hstart = oh * sH - pH; - wstart = ow * sW - pW; - hend = hstart + kHEff; - wend = wstart + kWEff; - - if (hstart < 0) - hstart += dH * ((-hstart + dH - 1) / dH); // (Nd4jLong)sd::math::nd4j_ceil(static_cast(-hstart) / static_cast(dH)); - if (wstart < 0) - wstart += dW * ((-wstart + dW - 1) / dW); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(-wstart) / static_cast(dW)); - if (hend > iH) - hend -= dH * ((hend - iH + dH - 1) / dH); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(hend-iH) / static_cast(dH)); - if (wend > iW) - wend -= dW * ((wend - iW + dW - 1) / dW); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(wend-iW) / static_cast(dW)); - - sum = -DataTypeUtils::max(); - valO = gO[b * oStride0 + c * oStride1 + oh * oStride2 + ow * oStride3]; - - if (sameStrides) { - - hstart *= iStride2; - hend *= iStride2; - wstart *= iStride3; - wend *= iStride3; - - // we set these to default values - maxKH = hstart; - maxKW = wstart; - - for (Nd4jLong kh = hstart; kh < hend; kh += iStep2) - for (Nd4jLong kw = wstart; kw < wend; kw += iStep3) { - T valIn = pIn[kh + kw]; - if (valIn > sum) { - sum = valIn; - maxKH = kh; - maxKW = kw; - } - } - gI[pIn - in + maxKH + maxKW] += valO; - } else { - - // we set these to default values - maxKH = hstart; - maxKW = wstart; - - for (Nd4jLong kh = hstart; kh < hend; kh += dH) - for (Nd4jLong kw = wstart; kw < wend; kw += dW) { - T valIn = pIn[kh * iStride2 + kw * iStride3]; - if (valIn > sum) { - sum = valIn; - maxKH = kh; - maxKW = kw; - } - } - - gI[b * gIStride0 + c * gIStride1 + maxKH * gIStride2 + maxKW * gIStride3] += valO; - } - } - } - } +template +static void pooling2dBP_(sd::graph::Context& block, const NDArray& input, + const NDArray& gradO, NDArray& gradI, const int kH, + const int kW, const int sH, const int sW, const int pH, + const int pW, const int dH, const int dW, + const int poolingMode, const int extraParam0) { + // input [bS, iC, iH, iW] + // gradI [bS, iC, iH, iW] -> gradI is output in this function + // gradO [bS, iC, oH, oW] + + // initial zeroing of gradI + gradI.nullify(); + + T* in = const_cast(input).bufferAsT(); + T* gO = const_cast(gradO).bufferAsT(); + T* gI = gradI.bufferAsT(); + + const int kHEff = kH + (kH - 1) * (dH - 1); + const int kWEff = kW + (kW - 1) * (dW - 1); + + const int bS = gradI.sizeAt(0); + const int iC = gradI.sizeAt(1); + const int iH = gradI.sizeAt(2); + const int iW = gradI.sizeAt(3); + const int oC = gradO.sizeAt(1); + const int oH = gradO.sizeAt(2); + const int oW = gradO.sizeAt(3); + + nd4j_debug("MKL-DNN is not used for pooling2d_bp!\n", 0); + + const Nd4jLong iStride0 = input.stridesOf()[0]; + const Nd4jLong iStride1 = input.stridesOf()[1]; + const Nd4jLong iStride2 = input.stridesOf()[2]; + const Nd4jLong iStride3 = input.stridesOf()[3]; + const Nd4jLong gIStride0 = gradI.stridesOf()[0]; + const Nd4jLong gIStride1 = gradI.stridesOf()[1]; + const Nd4jLong gIStride2 = gradI.stridesOf()[2]; + const Nd4jLong gIStride3 = gradI.stridesOf()[3]; + const Nd4jLong oStride0 = gradO.stridesOf()[0]; + const Nd4jLong oStride1 = gradO.stridesOf()[1]; + const Nd4jLong oStride2 = gradO.stridesOf()[2]; + const Nd4jLong oStride3 = gradO.stridesOf()[3]; + const Nd4jLong iStep2 = dH * iStride2; + const Nd4jLong iStep3 = dW * iStride3; + const Nd4jLong gIStep2 = dH * gIStride2; + const Nd4jLong gIStep3 = dW * gIStride3; + const int kProd = kH * kW; + + const bool sameStrides = iStride0 == gIStride0 && iStride1 == gIStride1 && + iStride2 == gIStride2 && iStride3 == gIStride3; + + if (poolingMode == 0) { // max + auto func = PRAGMA_THREADS_FOR_2D { + Nd4jLong hstart, wstart, hend, wend, maxKH, maxKW; + T sum, valO, *pIn, *pgI; + + for (int b = start_x; b < stop_x; b += inc_x) { + for (int c = start_y; c < stop_y; c += inc_y) { + for (int oh = 0; oh < oH; ++oh) { + for (int ow = 0; ow < oW; ++ow) { + pIn = in + b * iStride0 + c * iStride1; + + hstart = oh * sH - pH; + wstart = ow * sW - pW; + hend = hstart + kHEff; + wend = wstart + kWEff; + + if (hstart < 0) + hstart += + dH * + ((-hstart + dH - 1) / + dH); // (Nd4jLong)sd::math::nd4j_ceil(static_cast(-hstart) + // / static_cast(dH)); + if (wstart < 0) + wstart += + dW * + ((-wstart + dW - 1) / + dW); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(-wstart) + /// static_cast(dW)); + if (hend > iH) + hend -= + dH * + ((hend - iH + dH - 1) / + dH); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(hend-iH) + /// static_cast(dH)); + if (wend > iW) + wend -= + dW * + ((wend - iW + dW - 1) / + dW); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(wend-iW) + /// static_cast(dW)); + + sum = -DataTypeUtils::max(); + valO = gO[b * oStride0 + c * oStride1 + oh * oStride2 + + ow * oStride3]; + + if (sameStrides) { + hstart *= iStride2; + hend *= iStride2; + wstart *= iStride3; + wend *= iStride3; + + // we set these to default values + maxKH = hstart; + maxKW = wstart; + + for (Nd4jLong kh = hstart; kh < hend; kh += iStep2) + for (Nd4jLong kw = wstart; kw < wend; kw += iStep3) { + T valIn = pIn[kh + kw]; + if (valIn > sum) { + sum = valIn; + maxKH = kh; + maxKW = kw; } - }; - - samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1); - } -/*************************************************************************/ - else if(poolingMode == 1) { // avg - auto func = PRAGMA_THREADS_FOR_2D { - Nd4jLong hstart, wstart, hend, wend, maxKH, maxKW; - T sum, valO, *pIn, *pgI; - - for (int b = start_x; b < stop_x; b += inc_x) { - for (int c = start_y; c < stop_y; c += inc_y) { - for (int oh = 0; oh < oH; ++oh) { - for (int ow = 0; ow < oW; ++ow) { - - pgI = gI + b * gIStride0 + c * gIStride1; - - hstart = oh * sH - pH; - wstart = ow * sW - pW; - hend = hstart + kHEff; - wend = wstart + kWEff; - - if (hstart < 0) - hstart += dH * ((-hstart + dH - 1) / - dH); // (Nd4jLong)sd::math::nd4j_ceil(static_cast(-hstart) / static_cast(dH)); - if (wstart < 0) - wstart += dW * ((-wstart + dW - 1) / - dW); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(-wstart) / static_cast(dW)); - if (hend > iH) - hend -= dH * ((hend - iH + dH - 1) / - dH); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(hend-iH) / static_cast(dH)); - if (wend > iW) - wend -= dW * ((wend - iW + dW - 1) / - dW); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(wend-iW) / static_cast(dW)); - - hstart *= gIStride2; - hend *= gIStride2; - wstart *= gIStride3; - wend *= gIStride3; - - valO = gO[b * oStride0 + c * oStride1 + oh * oStride2 + ow * oStride3]; - - if ((int) extraParam0 == 0) //Exclude padding - valO /= static_cast(sd::math::nd4j_ceil( - static_cast(hend - hstart) / static_cast(gIStep2))) * - static_cast(sd::math::nd4j_ceil( - static_cast(wend - wstart) / - static_cast(gIStep3))); //Accounts for dilation - else if ((int) extraParam0 == 1) //Include padding - valO /= kProd; - - for (Nd4jLong kh = hstart; kh < hend; kh += gIStep2) - for (Nd4jLong kw = wstart; kw < wend; kw += gIStep3) - pgI[kh + kw] += valO; - } - } - } - } - }; - - samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1); - } -/*************************************************************************/ - else if(poolingMode == 2) { // pnorm - auto func = PRAGMA_THREADS_FOR_2D { - Nd4jLong hstart, wstart, hend, wend, maxKH, maxKW; - T sum, valO, *pIn, *pgI; - - for (int b = start_x; b < stop_x; b += inc_x) { - for (int c = start_y; c < stop_y; c += inc_y) { - for (int oh = 0; oh < oH; ++oh) { - for (int ow = 0; ow < oW; ++ow) { - - pIn = in + b * iStride0 + c * iStride1; - pgI = sameStrides ? gI + (pIn - in) : gI + b * gIStride0 + c * gIStride1; - - hstart = oh * sH - pH; - wstart = ow * sW - pW; - hend = hstart + kHEff; - wend = wstart + kWEff; - - if (hstart < 0) - hstart += dH * ((-hstart + dH - 1) / - dH); // (Nd4jLong)sd::math::nd4j_ceil(static_cast(-hstart) / static_cast(dH)); - if (wstart < 0) - wstart += dW * ((-wstart + dW - 1) / - dW); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(-wstart) / static_cast(dW)); - if (hend > iH) - hend -= dH * ((hend - iH + dH - 1) / - dH); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(hend-iH) / static_cast(dH)); - if (wend > iW) - wend -= dW * ((wend - iW + dW - 1) / - dW); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(wend-iW) / static_cast(dW)); - - sum = static_cast(0.f); - valO = gO[b * oStride0 + c * oStride1 + oh * oStride2 + ow * oStride3]; - - if (sameStrides) { - - hstart *= iStride2; - hend *= iStride2; - wstart *= iStride3; - wend *= iStride3; - - for (Nd4jLong kh = hstart; kh < hend; kh += iStep2) - for (Nd4jLong kw = wstart; kw < wend; kw += iStep3) - sum += sd::math::nd4j_pow( - sd::math::nd4j_abs(pIn[kh + kw]), extraParam0); - - valO *= sd::math::nd4j_pow(sum, - ((T) 1. - extraParam0) / extraParam0); - - for (Nd4jLong kh = hstart; kh < hend; kh += iStep2) - for (Nd4jLong kw = wstart; kw < wend; kw += iStep3) - pgI[kh + kw] += valO * sd::math::nd4j_pow( - sd::math::nd4j_abs(pIn[kh + kw]), extraParam0 - 1.f) * - sd::math::nd4j_sgn(pIn[kh + kw]); - } else { - - for (Nd4jLong kh = hstart; kh < hend; kh += dH) - for (Nd4jLong kw = wstart; kw < wend; kw += dW) - sum += sd::math::nd4j_pow( - sd::math::nd4j_abs(pIn[kh * iStride2 + kw * iStride3]), - extraParam0); - - valO *= sd::math::nd4j_pow(sum, - ((T) 1. - extraParam0) / extraParam0); - - for (Nd4jLong kh = hstart; kh < hend; kh += dH) { - for (Nd4jLong kw = wstart; kw < wend; kw += dW) { - const auto inVal = pIn[kh * iStride2 + kw * iStride3]; - pgI[kh * gIStride2 + kw * gIStride3] += valO * - sd::math::nd4j_pow( - sd::math::nd4j_abs( - inVal), - extraParam0 - 1.f) * - sd::math::nd4j_sgn( - inVal); - } - } - } - } - } - } + } + gI[pIn - in + maxKH + maxKW] += valO; + } else { + // we set these to default values + maxKH = hstart; + maxKW = wstart; + + for (Nd4jLong kh = hstart; kh < hend; kh += dH) + for (Nd4jLong kw = wstart; kw < wend; kw += dW) { + T valIn = pIn[kh * iStride2 + kw * iStride3]; + if (valIn > sum) { + sum = valIn; + maxKH = kh; + maxKW = kw; } - }; + } - samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1); + gI[b * gIStride0 + c * gIStride1 + maxKH * gIStride2 + + maxKW * gIStride3] += valO; + } } - else { - nd4j_printf("ConvolutionUtils::pooling2dBP: pooling mode argument can take three values only: 0, 1, 2, but got %i instead !\n", poolingMode); - throw std::runtime_error("Incorrect pooling2dBP mode"); + } + } + } + }; + + samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1); + } + /*************************************************************************/ + else if (poolingMode == 1) { // avg + auto func = PRAGMA_THREADS_FOR_2D { + Nd4jLong hstart, wstart, hend, wend, maxKH, maxKW; + T sum, valO, *pIn, *pgI; + + for (int b = start_x; b < stop_x; b += inc_x) { + for (int c = start_y; c < stop_y; c += inc_y) { + for (int oh = 0; oh < oH; ++oh) { + for (int ow = 0; ow < oW; ++ow) { + pgI = gI + b * gIStride0 + c * gIStride1; + + hstart = oh * sH - pH; + wstart = ow * sW - pW; + hend = hstart + kHEff; + wend = wstart + kWEff; + + if (hstart < 0) + hstart += + dH * + ((-hstart + dH - 1) / + dH); // (Nd4jLong)sd::math::nd4j_ceil(static_cast(-hstart) + // / static_cast(dH)); + if (wstart < 0) + wstart += + dW * + ((-wstart + dW - 1) / + dW); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(-wstart) + /// static_cast(dW)); + if (hend > iH) + hend -= + dH * + ((hend - iH + dH - 1) / + dH); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(hend-iH) + /// static_cast(dH)); + if (wend > iW) + wend -= + dW * + ((wend - iW + dW - 1) / + dW); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(wend-iW) + /// static_cast(dW)); + + hstart *= gIStride2; + hend *= gIStride2; + wstart *= gIStride3; + wend *= gIStride3; + + valO = gO[b * oStride0 + c * oStride1 + oh * oStride2 + + ow * oStride3]; + + if ((int)extraParam0 == 0) // Exclude padding + valO /= static_cast(sd::math::nd4j_ceil( + static_cast(hend - hstart) / + static_cast(gIStep2))) * + static_cast(sd::math::nd4j_ceil( + static_cast(wend - wstart) / + static_cast( + gIStep3))); // Accounts for dilation + else if ((int)extraParam0 == 1) // Include padding + valO /= kProd; + + for (Nd4jLong kh = hstart; kh < hend; kh += gIStep2) + for (Nd4jLong kw = wstart; kw < wend; kw += gIStep3) + pgI[kh + kw] += valO; } + } } - -void ConvolutionUtils::pooling2dBP(sd::graph::Context& block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int poolingMode, const int extraParam0) { - BUILD_SINGLE_SELECTOR(input.dataType(), pooling2dBP_, (block, input, gradO, gradI, kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0), FLOAT_TYPES); + } + }; + + samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1); + } + /*************************************************************************/ + else if (poolingMode == 2) { // pnorm + auto func = PRAGMA_THREADS_FOR_2D { + Nd4jLong hstart, wstart, hend, wend, maxKH, maxKW; + T sum, valO, *pIn, *pgI; + + for (int b = start_x; b < stop_x; b += inc_x) { + for (int c = start_y; c < stop_y; c += inc_y) { + for (int oh = 0; oh < oH; ++oh) { + for (int ow = 0; ow < oW; ++ow) { + pIn = in + b * iStride0 + c * iStride1; + pgI = sameStrides ? gI + (pIn - in) + : gI + b * gIStride0 + c * gIStride1; + + hstart = oh * sH - pH; + wstart = ow * sW - pW; + hend = hstart + kHEff; + wend = wstart + kWEff; + + if (hstart < 0) + hstart += + dH * + ((-hstart + dH - 1) / + dH); // (Nd4jLong)sd::math::nd4j_ceil(static_cast(-hstart) + // / static_cast(dH)); + if (wstart < 0) + wstart += + dW * + ((-wstart + dW - 1) / + dW); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(-wstart) + /// static_cast(dW)); + if (hend > iH) + hend -= + dH * + ((hend - iH + dH - 1) / + dH); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(hend-iH) + /// static_cast(dH)); + if (wend > iW) + wend -= + dW * + ((wend - iW + dW - 1) / + dW); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(wend-iW) + /// static_cast(dW)); + + sum = static_cast(0.f); + valO = gO[b * oStride0 + c * oStride1 + oh * oStride2 + + ow * oStride3]; + + if (sameStrides) { + hstart *= iStride2; + hend *= iStride2; + wstart *= iStride3; + wend *= iStride3; + + for (Nd4jLong kh = hstart; kh < hend; kh += iStep2) + for (Nd4jLong kw = wstart; kw < wend; kw += iStep3) + sum += sd::math::nd4j_pow( + sd::math::nd4j_abs(pIn[kh + kw]), extraParam0); + + valO *= sd::math::nd4j_pow( + sum, ((T)1. - extraParam0) / extraParam0); + + for (Nd4jLong kh = hstart; kh < hend; kh += iStep2) + for (Nd4jLong kw = wstart; kw < wend; kw += iStep3) + pgI[kh + kw] += valO * + sd::math::nd4j_pow( + sd::math::nd4j_abs(pIn[kh + kw]), + extraParam0 - 1.f) * + sd::math::nd4j_sgn(pIn[kh + kw]); + } else { + for (Nd4jLong kh = hstart; kh < hend; kh += dH) + for (Nd4jLong kw = wstart; kw < wend; kw += dW) + sum += sd::math::nd4j_pow( + sd::math::nd4j_abs( + pIn[kh * iStride2 + kw * iStride3]), + extraParam0); + + valO *= sd::math::nd4j_pow( + sum, ((T)1. - extraParam0) / extraParam0); + + for (Nd4jLong kh = hstart; kh < hend; kh += dH) { + for (Nd4jLong kw = wstart; kw < wend; kw += dW) { + const auto inVal = pIn[kh * iStride2 + kw * iStride3]; + pgI[kh * gIStride2 + kw * gIStride3] += + valO * + sd::math::nd4j_pow( + sd::math::nd4j_abs(inVal), extraParam0 - 1.f) * + sd::math::nd4j_sgn(inVal); + } + } + } + } + } } - + } + }; + + samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1); + } else { + nd4j_printf( + "ConvolutionUtils::pooling2dBP: pooling mode argument can take three " + "values only: 0, 1, 2, but got %i instead !\n", + poolingMode); + throw std::runtime_error("Incorrect pooling2dBP mode"); + } } + +void ConvolutionUtils::pooling2dBP(sd::graph::Context& block, + const NDArray& input, const NDArray& gradO, + NDArray& gradI, const int kH, const int kW, + const int sH, const int sW, const int pH, + const int pW, const int dH, const int dW, + const int poolingMode, + const int extraParam0) { + BUILD_SINGLE_SELECTOR(input.dataType(), pooling2dBP_, + (block, input, gradO, gradI, kH, kW, sH, sW, pH, pW, dH, + dW, poolingMode, extraParam0), + FLOAT_TYPES); } + +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/convolutions_pooling3d.cpp b/libnd4j/include/ops/declarable/helpers/cpu/convolutions_pooling3d.cpp index 04d5f993ad4b..860b5f774ba2 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/convolutions_pooling3d.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/convolutions_pooling3d.cpp @@ -18,244 +18,251 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 18.09.2018 // -#include #include +#include namespace sd { - namespace ops { - +namespace ops { ////////////////////////////////////////////////////////////////////////// - template - static void pooling3d_(sd::graph::Context& block, const NDArray& input, NDArray& output, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0) { - // input is [bS, iC, iD, iH, iW] - // output is [bS, iC, oD, oH, oW] - T* out = output.bufferAsT(); - T* in = const_cast(input).bufferAsT(); - - const int kDEff = kD + (kD-1)*(dD-1); - const int kHEff = kH + (kH-1)*(dH-1); - const int kWEff = kW + (kW-1)*(dW-1); - - const int bS = input.sizeAt(0); - const int iC = input.sizeAt(1); - const int iD = input.sizeAt(2); - const int iH = input.sizeAt(3); - const int iW = input.sizeAt(4); - const int oC = output.sizeAt(1); - const int oD = output.sizeAt(2); - const int oH = output.sizeAt(3); - const int oW = output.sizeAt(4); - - nd4j_debug("MKL-DNN is not used for pooling3d!\n", 0); - - const Nd4jLong iStride0 = input.stridesOf()[0]; - const Nd4jLong iStride1 = input.stridesOf()[1]; - const Nd4jLong iStride2 = input.stridesOf()[2]; - const Nd4jLong iStride3 = input.stridesOf()[3]; - const Nd4jLong iStride4 = input.stridesOf()[4]; - const Nd4jLong oStride0 = output.stridesOf()[0]; - const Nd4jLong oStride1 = output.stridesOf()[1]; - const Nd4jLong oStride2 = output.stridesOf()[2]; - const Nd4jLong oStride3 = output.stridesOf()[3]; - const Nd4jLong oStride4 = output.stridesOf()[4]; - const Nd4jLong iStep2 = dD*iStride2; - const Nd4jLong iStep3 = dH*iStride3; - const Nd4jLong iStep4 = dW*iStride4; - const int kProd = kD*kH*kW; - - if(poolingMode == 0) { // max - auto func = PRAGMA_THREADS_FOR_3D { - Nd4jLong dstart, hstart, wstart, dend, hend, wend; - T sum, *pIn; - - for (int b = start_x; b < stop_x; b += inc_x) { - for (int c = start_y; c < stop_y; c += inc_y) { - for (int od = start_z; od < stop_z; od += inc_z) { - for (int oh = 0; oh < oH; ++oh) { - for (int ow = 0; ow < oW; ++ow) { - - pIn = in + b * iStride0 + c * iStride1; - - dstart = od * sD - pD; - hstart = oh * sH - pH; - wstart = ow * sW - pW; - dend = dstart + kDEff; - hend = hstart + kHEff; - wend = wstart + kWEff; - - if (dstart < 0) - dstart += dD * ((-dstart + dD - 1) / dD); - if (hstart < 0) - hstart += dH * ((-hstart + dH - 1) / dH); - if (wstart < 0) - wstart += dW * ((-wstart + dW - 1) / dW); - if (dend > iD) - dend -= dD * ((dend - iD + dD - 1) / dD); - if (hend > iH) - hend -= dH * ((hend - iH + dH - 1) / dH); - if (wend > iW) - wend -= dW * ((wend - iW + dW - 1) / dW); - - dstart *= iStride2; - dend *= iStride2; - hstart *= iStride3; - hend *= iStride3; - wstart *= iStride4; - wend *= iStride4; - - sum = -DataTypeUtils::max(); - - for (Nd4jLong kd = dstart; kd < dend; kd += iStep2) - for (Nd4jLong kh = hstart; kh < hend; kh += iStep3) - for (Nd4jLong kw = wstart; kw < wend; kw += iStep4) { - T val = pIn[kd + kh + kw]; - if (val > sum) - sum = val; - } - - out[b * oStride0 + c * oStride1 + od * oStride2 + oh * oStride3 + ow * oStride4] = sum; - } - } - } - } - } - }; - - samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, oD, 1); - } -/*************************************************************************/ - else if(poolingMode == 1) { // avg - auto func = PRAGMA_THREADS_FOR_3D { - Nd4jLong dstart, hstart, wstart, dend, hend, wend; - T sum, *pIn; - - for (int b = start_x; b < stop_x; b += inc_x) { - for (int c = start_y; c < stop_y; c += inc_y) { - for (int od = start_z; od < stop_z; od += inc_z) { - for (int oh = 0; oh < oH; ++oh) { - for (int ow = 0; ow < oW; ++ow) { - - pIn = in + b * iStride0 + c * iStride1; - - dstart = od * sD - pD; - hstart = oh * sH - pH; - wstart = ow * sW - pW; - dend = dstart + kDEff; - hend = hstart + kHEff; - wend = wstart + kWEff; - - if (dstart < 0) - dstart += dD * ((-dstart + dD - 1) / dD); - if (hstart < 0) - hstart += dH * ((-hstart + dH - 1) / dH); - if (wstart < 0) - wstart += dW * ((-wstart + dW - 1) / dW); - if (dend > iD) - dend -= dD * ((dend - iD + dD - 1) / dD); - if (hend > iH) - hend -= dH * ((hend - iH + dH - 1) / dH); - if (wend > iW) - wend -= dW * ((wend - iW + dW - 1) / dW); - - dstart *= iStride2; - dend *= iStride2; - hstart *= iStride3; - hend *= iStride3; - wstart *= iStride4; - wend *= iStride4; - - sum = static_cast(0.); - - for (Nd4jLong kd = dstart; kd < dend; kd += iStep2) - for (Nd4jLong kh = hstart; kh < hend; kh += iStep3) - for (Nd4jLong kw = wstart; kw < wend; kw += iStep4) - sum += pIn[kd + kh + kw]; - - if (extraParam0 == 0) //Exclude padding - sum /= sd::math::nd4j_ceil(static_cast(dend - dstart) / static_cast(iStep2)) * sd::math::nd4j_ceil(static_cast(hend - hstart) / static_cast(iStep3)) * sd::math::nd4j_ceil(static_cast(wend - wstart) / static_cast(iStep4)); //Accounts for dilation - else if (extraParam0 == 1) //Include padding - sum /= kProd; - - out[b * oStride0 + c * oStride1 + od * oStride2 + oh * oStride3 + ow * oStride4] = sum; - } - } - } - } +template +static void pooling3d_(sd::graph::Context& block, const NDArray& input, + NDArray& output, const int kD, const int kH, + const int kW, const int sD, const int sH, const int sW, + const int pD, const int pH, const int pW, const int dD, + const int dH, const int dW, const int poolingMode, + const int extraParam0) { + // input is [bS, iC, iD, iH, iW] + // output is [bS, iC, oD, oH, oW] + T* out = output.bufferAsT(); + T* in = const_cast(input).bufferAsT(); + + const int kDEff = kD + (kD - 1) * (dD - 1); + const int kHEff = kH + (kH - 1) * (dH - 1); + const int kWEff = kW + (kW - 1) * (dW - 1); + + const int bS = input.sizeAt(0); + const int iC = input.sizeAt(1); + const int iD = input.sizeAt(2); + const int iH = input.sizeAt(3); + const int iW = input.sizeAt(4); + const int oC = output.sizeAt(1); + const int oD = output.sizeAt(2); + const int oH = output.sizeAt(3); + const int oW = output.sizeAt(4); + + nd4j_debug("MKL-DNN is not used for pooling3d!\n", 0); + + const Nd4jLong iStride0 = input.stridesOf()[0]; + const Nd4jLong iStride1 = input.stridesOf()[1]; + const Nd4jLong iStride2 = input.stridesOf()[2]; + const Nd4jLong iStride3 = input.stridesOf()[3]; + const Nd4jLong iStride4 = input.stridesOf()[4]; + const Nd4jLong oStride0 = output.stridesOf()[0]; + const Nd4jLong oStride1 = output.stridesOf()[1]; + const Nd4jLong oStride2 = output.stridesOf()[2]; + const Nd4jLong oStride3 = output.stridesOf()[3]; + const Nd4jLong oStride4 = output.stridesOf()[4]; + const Nd4jLong iStep2 = dD * iStride2; + const Nd4jLong iStep3 = dH * iStride3; + const Nd4jLong iStep4 = dW * iStride4; + const int kProd = kD * kH * kW; + + if (poolingMode == 0) { // max + auto func = PRAGMA_THREADS_FOR_3D { + Nd4jLong dstart, hstart, wstart, dend, hend, wend; + T sum, *pIn; + + for (int b = start_x; b < stop_x; b += inc_x) { + for (int c = start_y; c < stop_y; c += inc_y) { + for (int od = start_z; od < stop_z; od += inc_z) { + for (int oh = 0; oh < oH; ++oh) { + for (int ow = 0; ow < oW; ++ow) { + pIn = in + b * iStride0 + c * iStride1; + + dstart = od * sD - pD; + hstart = oh * sH - pH; + wstart = ow * sW - pW; + dend = dstart + kDEff; + hend = hstart + kHEff; + wend = wstart + kWEff; + + if (dstart < 0) dstart += dD * ((-dstart + dD - 1) / dD); + if (hstart < 0) hstart += dH * ((-hstart + dH - 1) / dH); + if (wstart < 0) wstart += dW * ((-wstart + dW - 1) / dW); + if (dend > iD) dend -= dD * ((dend - iD + dD - 1) / dD); + if (hend > iH) hend -= dH * ((hend - iH + dH - 1) / dH); + if (wend > iW) wend -= dW * ((wend - iW + dW - 1) / dW); + + dstart *= iStride2; + dend *= iStride2; + hstart *= iStride3; + hend *= iStride3; + wstart *= iStride4; + wend *= iStride4; + + sum = -DataTypeUtils::max(); + + for (Nd4jLong kd = dstart; kd < dend; kd += iStep2) + for (Nd4jLong kh = hstart; kh < hend; kh += iStep3) + for (Nd4jLong kw = wstart; kw < wend; kw += iStep4) { + T val = pIn[kd + kh + kw]; + if (val > sum) sum = val; } - }; - samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, oD, 1); + out[b * oStride0 + c * oStride1 + od * oStride2 + + oh * oStride3 + ow * oStride4] = sum; + } } -/*************************************************************************/ - else if(poolingMode == 2) { // pnorm - auto func = PRAGMA_THREADS_FOR_3D { - Nd4jLong dstart, hstart, wstart, dend, hend, wend; - T sum, *pIn; - - for (int b = start_x; b < stop_x; b += inc_x) { - for (int c = start_y; c < stop_y; c += inc_y) { - for (int od = start_z; od < stop_z; od += inc_z) { - for (int oh = 0; oh < oH; ++oh) { - for (int ow = 0; ow < oW; ++ow) { - - pIn = in + b * iStride0 + c * iStride1; - - dstart = od * sD - pD; - hstart = oh * sH - pH; - wstart = ow * sW - pW; - dend = dstart + kDEff; - hend = hstart + kHEff; - wend = wstart + kWEff; - - if (dstart < 0) - dstart += dD * ((-dstart + dD - 1) / dD); - if (hstart < 0) - hstart += dH * ((-hstart + dH - 1) / dH); - if (wstart < 0) - wstart += dW * ((-wstart + dW - 1) / dW); - if (dend > iD) - dend -= dD * ((dend - iD + dD - 1) / dD); - if (hend > iH) - hend -= dH * ((hend - iH + dH - 1) / dH); - if (wend > iW) - wend -= dW * ((wend - iW + dW - 1) / dW); - - dstart *= iStride2; - dend *= iStride2; - hstart *= iStride3; - hend *= iStride3; - wstart *= iStride4; - wend *= iStride4; - - sum = static_cast(0.); - - for (Nd4jLong kd = dstart; kd < dend; kd += iStep2) - for (Nd4jLong kh = hstart; kh < hend; kh += iStep3) - for (Nd4jLong kw = wstart; kw < wend; kw += iStep4) - sum += sd::math::nd4j_pow(sd::math::nd4j_abs(pIn[kd + kh + kw]), extraParam0); - - sum = sd::math::nd4j_pow(sum, (T) 1.f / extraParam0); - - out[b * oStride0 + c * oStride1 + od * oStride2 + oh * oStride3 + ow * oStride4] = sum; - } - } - } - } - } - }; - - samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, oD, 1); - } - else { - nd4j_printf("ConvolutionUtils::pooling3d: pooling mode argument can take three values only: 0, 1, 2, but got %i instead !\n", poolingMode); - throw std::runtime_error("Incorrect poooling3d mode"); + } + } + } + }; + + samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, oD, 1); + } + /*************************************************************************/ + else if (poolingMode == 1) { // avg + auto func = PRAGMA_THREADS_FOR_3D { + Nd4jLong dstart, hstart, wstart, dend, hend, wend; + T sum, *pIn; + + for (int b = start_x; b < stop_x; b += inc_x) { + for (int c = start_y; c < stop_y; c += inc_y) { + for (int od = start_z; od < stop_z; od += inc_z) { + for (int oh = 0; oh < oH; ++oh) { + for (int ow = 0; ow < oW; ++ow) { + pIn = in + b * iStride0 + c * iStride1; + + dstart = od * sD - pD; + hstart = oh * sH - pH; + wstart = ow * sW - pW; + dend = dstart + kDEff; + hend = hstart + kHEff; + wend = wstart + kWEff; + + if (dstart < 0) dstart += dD * ((-dstart + dD - 1) / dD); + if (hstart < 0) hstart += dH * ((-hstart + dH - 1) / dH); + if (wstart < 0) wstart += dW * ((-wstart + dW - 1) / dW); + if (dend > iD) dend -= dD * ((dend - iD + dD - 1) / dD); + if (hend > iH) hend -= dH * ((hend - iH + dH - 1) / dH); + if (wend > iW) wend -= dW * ((wend - iW + dW - 1) / dW); + + dstart *= iStride2; + dend *= iStride2; + hstart *= iStride3; + hend *= iStride3; + wstart *= iStride4; + wend *= iStride4; + + sum = static_cast(0.); + + for (Nd4jLong kd = dstart; kd < dend; kd += iStep2) + for (Nd4jLong kh = hstart; kh < hend; kh += iStep3) + for (Nd4jLong kw = wstart; kw < wend; kw += iStep4) + sum += pIn[kd + kh + kw]; + + if (extraParam0 == 0) // Exclude padding + sum /= sd::math::nd4j_ceil( + static_cast(dend - dstart) / + static_cast(iStep2)) * + sd::math::nd4j_ceil( + static_cast(hend - hstart) / + static_cast(iStep3)) * + sd::math::nd4j_ceil( + static_cast(wend - wstart) / + static_cast( + iStep4)); // Accounts for dilation + else if (extraParam0 == 1) // Include padding + sum /= kProd; + + out[b * oStride0 + c * oStride1 + od * oStride2 + + oh * oStride3 + ow * oStride4] = sum; + } } + } } - -void ConvolutionUtils::pooling3d(sd::graph::Context& block, const NDArray& input, NDArray& output, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0) { - BUILD_SINGLE_SELECTOR(input.dataType(), pooling3d_, (block, input, output, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode, extraParam0), FLOAT_TYPES); + } + }; + + samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, oD, 1); + } + /*************************************************************************/ + else if (poolingMode == 2) { // pnorm + auto func = PRAGMA_THREADS_FOR_3D { + Nd4jLong dstart, hstart, wstart, dend, hend, wend; + T sum, *pIn; + + for (int b = start_x; b < stop_x; b += inc_x) { + for (int c = start_y; c < stop_y; c += inc_y) { + for (int od = start_z; od < stop_z; od += inc_z) { + for (int oh = 0; oh < oH; ++oh) { + for (int ow = 0; ow < oW; ++ow) { + pIn = in + b * iStride0 + c * iStride1; + + dstart = od * sD - pD; + hstart = oh * sH - pH; + wstart = ow * sW - pW; + dend = dstart + kDEff; + hend = hstart + kHEff; + wend = wstart + kWEff; + + if (dstart < 0) dstart += dD * ((-dstart + dD - 1) / dD); + if (hstart < 0) hstart += dH * ((-hstart + dH - 1) / dH); + if (wstart < 0) wstart += dW * ((-wstart + dW - 1) / dW); + if (dend > iD) dend -= dD * ((dend - iD + dD - 1) / dD); + if (hend > iH) hend -= dH * ((hend - iH + dH - 1) / dH); + if (wend > iW) wend -= dW * ((wend - iW + dW - 1) / dW); + + dstart *= iStride2; + dend *= iStride2; + hstart *= iStride3; + hend *= iStride3; + wstart *= iStride4; + wend *= iStride4; + + sum = static_cast(0.); + + for (Nd4jLong kd = dstart; kd < dend; kd += iStep2) + for (Nd4jLong kh = hstart; kh < hend; kh += iStep3) + for (Nd4jLong kw = wstart; kw < wend; kw += iStep4) + sum += sd::math::nd4j_pow( + sd::math::nd4j_abs(pIn[kd + kh + kw]), + extraParam0); + + sum = sd::math::nd4j_pow(sum, (T)1.f / extraParam0); + + out[b * oStride0 + c * oStride1 + od * oStride2 + + oh * oStride3 + ow * oStride4] = sum; + } + } + } } - + } + }; + + samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, oD, 1); + } else { + nd4j_printf( + "ConvolutionUtils::pooling3d: pooling mode argument can take three " + "values only: 0, 1, 2, but got %i instead !\n", + poolingMode); + throw std::runtime_error("Incorrect poooling3d mode"); + } } + +void ConvolutionUtils::pooling3d(sd::graph::Context& block, + const NDArray& input, NDArray& output, + const int kD, const int kH, const int kW, + const int sD, const int sH, const int sW, + const int pD, const int pH, const int pW, + const int dD, const int dH, const int dW, + const int poolingMode, const int extraParam0) { + BUILD_SINGLE_SELECTOR(input.dataType(), pooling3d_, + (block, input, output, kD, kH, kW, sD, sH, sW, pD, pH, + pW, dD, dH, dW, poolingMode, extraParam0), + FLOAT_TYPES); } + +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/convolutions_pooling3dBP.cpp b/libnd4j/include/ops/declarable/helpers/cpu/convolutions_pooling3dBP.cpp index 02f6f57aca77..7bea63eedb56 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/convolutions_pooling3dBP.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/convolutions_pooling3dBP.cpp @@ -18,309 +18,337 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 18.09.2018 // -#include #include +#include namespace sd { - namespace ops { +namespace ops { ////////////////////////////////////////////////////////////////////////// - template - static void pooling3dBP_(sd::graph::Context& block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0) { - // input [bS, iC, iD, iH, iW] - // gradI [bS, iC, iD, iH, iW] -> gradI is output in this function - // gradO [bS, iC, oD, oH, oW] - - // initial zeroing of gradI - gradI.nullify(); - - T* in = const_cast(input).bufferAsT(); - T* gO = const_cast(gradO).bufferAsT(); - T* gI = gradI.bufferAsT(); - - const int kDEff = kD + (kD-1)*(dD-1); - const int kHEff = kH + (kH-1)*(dH-1); - const int kWEff = kW + (kW-1)*(dW-1); - - const int bS = gradI.sizeAt(0); - const int iC = gradI.sizeAt(1); - const int iD = gradI.sizeAt(2); - const int iH = gradI.sizeAt(3); - const int iW = gradI.sizeAt(4); - const int oC = gradO.sizeAt(1); - const int oD = gradO.sizeAt(2); - const int oH = gradO.sizeAt(3); - const int oW = gradO.sizeAt(4); - - nd4j_debug("MKL-DNN is not used for pooling3d_bp!\n", 0); - - const Nd4jLong iStride0 = input.stridesOf()[0]; - const Nd4jLong iStride1 = input.stridesOf()[1]; - const Nd4jLong iStride2 = input.stridesOf()[2]; - const Nd4jLong iStride3 = input.stridesOf()[3]; - const Nd4jLong iStride4 = input.stridesOf()[4]; - const Nd4jLong gIStride0 = gradI.stridesOf()[0]; - const Nd4jLong gIStride1 = gradI.stridesOf()[1]; - const Nd4jLong gIStride2 = gradI.stridesOf()[2]; - const Nd4jLong gIStride3 = gradI.stridesOf()[3]; - const Nd4jLong gIStride4 = gradI.stridesOf()[4]; - const Nd4jLong oStride0 = gradO.stridesOf()[0]; - const Nd4jLong oStride1 = gradO.stridesOf()[1]; - const Nd4jLong oStride2 = gradO.stridesOf()[2]; - const Nd4jLong oStride3 = gradO.stridesOf()[3]; - const Nd4jLong oStride4 = gradO.stridesOf()[4]; - const Nd4jLong iStep2 = dD*iStride2; - const Nd4jLong iStep3 = dH*iStride3; - const Nd4jLong iStep4 = dW*iStride4; - const Nd4jLong gIStep2 = dD*gIStride2; - const Nd4jLong gIStep3 = dH*gIStride3; - const Nd4jLong gIStep4 = dW*gIStride4; - const int kProd = kD*kH*kW; - - const bool sameStrides = iStride0 == gIStride0 && iStride1 == gIStride1 && iStride2 == gIStride2 && iStride3 == gIStride3 && iStride4 == gIStride4; - - if(poolingMode == 0) { // max - auto func = PRAGMA_THREADS_FOR_2D { - Nd4jLong dstart, hstart, wstart, dend, hend, wend, maxKD, maxKH, maxKW; - T sum, valO, *pIn, *pgI; - - for (int b = start_x; b < stop_x; b++) { - for (int c = start_y; c < stop_y; c++) { - for (int od = 0; od < oD; od++) { - for (int oh = 0; oh < oH; ++oh) { - for (int ow = 0; ow < oW; ++ow) { - - pIn = in + b * iStride0 + c * iStride1; - - dstart = od * sD - pD; - hstart = oh * sH - pH; - wstart = ow * sW - pW; - dend = dstart + kDEff; - hend = hstart + kHEff; - wend = wstart + kWEff; - - if (dstart < 0) - dstart += dD * ((-dstart + dD - 1) / dD); - if (hstart < 0) - hstart += dH * ((-hstart + dH - 1) / dH); - if (wstart < 0) - wstart += dW * ((-wstart + dW - 1) / dW); - if (dend > iD) - dend -= dD * ((dend - iD + dD - 1) / dD); - if (hend > iH) - hend -= dH * ((hend - iH + dH - 1) / dH); - if (wend > iW) - wend -= dW * ((wend - iW + dW - 1) / dW); - - sum = -DataTypeUtils::max(); - valO = gO[b * oStride0 + c * oStride1 + od * oStride2 + oh * oStride3 + ow * oStride4]; - - if (sameStrides) { - - dstart *= iStride2; - dend *= iStride2; - hstart *= iStride3; - hend *= iStride3; - wstart *= iStride4; - wend *= iStride4; - - maxKD = dstart; - maxKH = hstart; - maxKW = wstart; - - for (Nd4jLong kd = dstart; kd < dend; kd += iStep2) - for (Nd4jLong kh = hstart; kh < hend; kh += iStep3) - for (Nd4jLong kw = wstart; kw < wend; kw += iStep4) { - T valIn = pIn[kd + kh + kw]; - if (valIn > sum) { - sum = valIn; - maxKD = kd; - maxKH = kh; - maxKW = kw; - } - } - gI[pIn - in + maxKD + maxKH + maxKW] += valO; - } else { - - // we set these to default values - maxKH = hstart; - maxKW = wstart; - maxKD = dstart; - - for (Nd4jLong kd = dstart; kd < dend; kd += dD) - for (Nd4jLong kh = hstart; kh < hend; kh += dH) - for (Nd4jLong kw = wstart; kw < wend; kw += dW) { - T valIn = pIn[kd * iStride2 + kh * iStride3 + kw * iStride4]; - if (valIn > sum) { - sum = valIn; - maxKD = kd; - maxKH = kh; - maxKW = kw; - } - } - - gI[b * gIStride0 + c * gIStride1 + maxKD * gIStride2 + maxKH * gIStride3 + maxKW * gIStride4] += valO; - } - } - } - } +template +static void pooling3dBP_(sd::graph::Context& block, const NDArray& input, + const NDArray& gradO, NDArray& gradI, const int kD, + const int kH, const int kW, const int sD, const int sH, + const int sW, const int pD, const int pH, const int pW, + const int dD, const int dH, const int dW, + const int poolingMode, const int extraParam0) { + // input [bS, iC, iD, iH, iW] + // gradI [bS, iC, iD, iH, iW] -> gradI is output in this function + // gradO [bS, iC, oD, oH, oW] + + // initial zeroing of gradI + gradI.nullify(); + + T* in = const_cast(input).bufferAsT(); + T* gO = const_cast(gradO).bufferAsT(); + T* gI = gradI.bufferAsT(); + + const int kDEff = kD + (kD - 1) * (dD - 1); + const int kHEff = kH + (kH - 1) * (dH - 1); + const int kWEff = kW + (kW - 1) * (dW - 1); + + const int bS = gradI.sizeAt(0); + const int iC = gradI.sizeAt(1); + const int iD = gradI.sizeAt(2); + const int iH = gradI.sizeAt(3); + const int iW = gradI.sizeAt(4); + const int oC = gradO.sizeAt(1); + const int oD = gradO.sizeAt(2); + const int oH = gradO.sizeAt(3); + const int oW = gradO.sizeAt(4); + + nd4j_debug("MKL-DNN is not used for pooling3d_bp!\n", 0); + + const Nd4jLong iStride0 = input.stridesOf()[0]; + const Nd4jLong iStride1 = input.stridesOf()[1]; + const Nd4jLong iStride2 = input.stridesOf()[2]; + const Nd4jLong iStride3 = input.stridesOf()[3]; + const Nd4jLong iStride4 = input.stridesOf()[4]; + const Nd4jLong gIStride0 = gradI.stridesOf()[0]; + const Nd4jLong gIStride1 = gradI.stridesOf()[1]; + const Nd4jLong gIStride2 = gradI.stridesOf()[2]; + const Nd4jLong gIStride3 = gradI.stridesOf()[3]; + const Nd4jLong gIStride4 = gradI.stridesOf()[4]; + const Nd4jLong oStride0 = gradO.stridesOf()[0]; + const Nd4jLong oStride1 = gradO.stridesOf()[1]; + const Nd4jLong oStride2 = gradO.stridesOf()[2]; + const Nd4jLong oStride3 = gradO.stridesOf()[3]; + const Nd4jLong oStride4 = gradO.stridesOf()[4]; + const Nd4jLong iStep2 = dD * iStride2; + const Nd4jLong iStep3 = dH * iStride3; + const Nd4jLong iStep4 = dW * iStride4; + const Nd4jLong gIStep2 = dD * gIStride2; + const Nd4jLong gIStep3 = dH * gIStride3; + const Nd4jLong gIStep4 = dW * gIStride4; + const int kProd = kD * kH * kW; + + const bool sameStrides = iStride0 == gIStride0 && iStride1 == gIStride1 && + iStride2 == gIStride2 && iStride3 == gIStride3 && + iStride4 == gIStride4; + + if (poolingMode == 0) { // max + auto func = PRAGMA_THREADS_FOR_2D { + Nd4jLong dstart, hstart, wstart, dend, hend, wend, maxKD, maxKH, maxKW; + T sum, valO, *pIn, *pgI; + + for (int b = start_x; b < stop_x; b++) { + for (int c = start_y; c < stop_y; c++) { + for (int od = 0; od < oD; od++) { + for (int oh = 0; oh < oH; ++oh) { + for (int ow = 0; ow < oW; ++ow) { + pIn = in + b * iStride0 + c * iStride1; + + dstart = od * sD - pD; + hstart = oh * sH - pH; + wstart = ow * sW - pW; + dend = dstart + kDEff; + hend = hstart + kHEff; + wend = wstart + kWEff; + + if (dstart < 0) dstart += dD * ((-dstart + dD - 1) / dD); + if (hstart < 0) hstart += dH * ((-hstart + dH - 1) / dH); + if (wstart < 0) wstart += dW * ((-wstart + dW - 1) / dW); + if (dend > iD) dend -= dD * ((dend - iD + dD - 1) / dD); + if (hend > iH) hend -= dH * ((hend - iH + dH - 1) / dH); + if (wend > iW) wend -= dW * ((wend - iW + dW - 1) / dW); + + sum = -DataTypeUtils::max(); + valO = gO[b * oStride0 + c * oStride1 + od * oStride2 + + oh * oStride3 + ow * oStride4]; + + if (sameStrides) { + dstart *= iStride2; + dend *= iStride2; + hstart *= iStride3; + hend *= iStride3; + wstart *= iStride4; + wend *= iStride4; + + maxKD = dstart; + maxKH = hstart; + maxKW = wstart; + + for (Nd4jLong kd = dstart; kd < dend; kd += iStep2) + for (Nd4jLong kh = hstart; kh < hend; kh += iStep3) + for (Nd4jLong kw = wstart; kw < wend; kw += iStep4) { + T valIn = pIn[kd + kh + kw]; + if (valIn > sum) { + sum = valIn; + maxKD = kd; + maxKH = kh; + maxKW = kw; } - } - }; - - samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1); - } -/*************************************************************************/ - else if(poolingMode == 1) { // avg - auto func = PRAGMA_THREADS_FOR_2D { - Nd4jLong dstart, hstart, wstart, dend, hend, wend, maxKD, maxKH, maxKW; - T sum, valO, *pIn, *pgI; - - for (int b = start_x; b < stop_x; b++) { - for (int c = start_y; c < stop_y; c++) { - for (int od = 0; od < oD; od++) { - for (int oh = 0; oh < oH; ++oh) { - for (int ow = 0; ow < oW; ++ow) { - - pgI = gI + b * gIStride0 + c * gIStride1; - - dstart = od * sD - pD; - hstart = oh * sH - pH; - wstart = ow * sW - pW; - dend = dstart + kDEff; - hend = hstart + kHEff; - wend = wstart + kWEff; - - if (dstart < 0) - dstart += dD * ((-dstart + dD - 1) / dD); - if (hstart < 0) - hstart += dH * ((-hstart + dH - 1) / dH); - if (wstart < 0) - wstart += dW * ((-wstart + dW - 1) / dW); - if (dend > iD) - dend -= dD * ((dend - iD + dD - 1) / dD); - if (hend > iH) - hend -= dH * ((hend - iH + dH - 1) / dH); - if (wend > iW) - wend -= dW * ((wend - iW + dW - 1) / dW); - - dstart *= gIStride2; - dend *= gIStride2; - hstart *= gIStride3; - hend *= gIStride3; - wstart *= gIStride4; - wend *= gIStride4; - - valO = gO[b * oStride0 + c * oStride1 + od * oStride2 + oh * oStride3 + ow * oStride4]; - - if (extraParam0 == 0) //Exclude padding - valO /= sd::math::nd4j_ceil(static_cast(dend - dstart) / static_cast(gIStep2)) * sd::math::nd4j_ceil(static_cast(hend - hstart) / static_cast(gIStep3)) * sd::math::nd4j_ceil(static_cast(wend - wstart) / static_cast(gIStep4)); //Accounts for dilation - else if (extraParam0 == 1) //Include padding - valO /= kProd; - - for (Nd4jLong kd = dstart; kd < dend; kd += gIStep2) - for (Nd4jLong kh = hstart; kh < hend; kh += gIStep3) - for (Nd4jLong kw = wstart; kw < wend; kw += gIStep4) - pgI[kd + kh + kw] += valO; - } - } - } + } + gI[pIn - in + maxKD + maxKH + maxKW] += valO; + } else { + // we set these to default values + maxKH = hstart; + maxKW = wstart; + maxKD = dstart; + + for (Nd4jLong kd = dstart; kd < dend; kd += dD) + for (Nd4jLong kh = hstart; kh < hend; kh += dH) + for (Nd4jLong kw = wstart; kw < wend; kw += dW) { + T valIn = + pIn[kd * iStride2 + kh * iStride3 + kw * iStride4]; + if (valIn > sum) { + sum = valIn; + maxKD = kd; + maxKH = kh; + maxKW = kw; } - } - }; + } - samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1); + gI[b * gIStride0 + c * gIStride1 + maxKD * gIStride2 + + maxKH * gIStride3 + maxKW * gIStride4] += valO; + } + } } -/*************************************************************************/ - else if(poolingMode == 2) { // pnorm - auto func = PRAGMA_THREADS_FOR_2D { - Nd4jLong dstart, hstart, wstart, dend, hend, wend, maxKD, maxKH, maxKW; - T sum, valO, *pIn, *pgI; - - for (int b = start_x; b < stop_x; b++) { - for (int c = start_y; c < stop_y; c++) { - for (int od = 0; od < oD; od++) { - for (int oh = 0; oh < oH; ++oh) { - for (int ow = 0; ow < oW; ++ow) { - - pIn = in + b * iStride0 + c * iStride1; - pgI = gI + (pIn - in); - - dstart = od * sD - pD; - hstart = oh * sH - pH; - wstart = ow * sW - pW; - dend = dstart + kDEff; - hend = hstart + kHEff; - wend = wstart + kWEff; - - if (dstart < 0) - dstart += dD * ((-dstart + dD - 1) / dD); - if (hstart < 0) - hstart += dH * ((-hstart + dH - 1) / dH); - if (wstart < 0) - wstart += dW * ((-wstart + dW - 1) / dW); - if (dend > iD) - dend -= dD * ((dend - iD + dD - 1) / dD); - if (hend > iH) - hend -= dH * ((hend - iH + dH - 1) / dH); - if (wend > iW) - wend -= dW * ((wend - iW + dW - 1) / dW); - - sum = static_cast(0.); - valO = gO[b * oStride0 + c * oStride1 + od * oStride2 + oh * oStride3 + ow * oStride4]; - - if (sameStrides) { - - dstart *= iStride2; - dend *= iStride2; - hstart *= iStride3; - hend *= iStride3; - wstart *= iStride4; - wend *= iStride4; - - for (Nd4jLong kd = dstart; kd < dend; kd += iStep2) - for (Nd4jLong kh = hstart; kh < hend; kh += iStep3) - for (Nd4jLong kw = wstart; kw < wend; kw += iStep4) - sum += sd::math::nd4j_pow(sd::math::nd4j_abs(pIn[kd + kh + kw]), extraParam0); - - valO *= sd::math::nd4j_pow(sum, ((T) 1.f - extraParam0) / extraParam0); - - for (Nd4jLong kd = dstart; kd < dend; kd += iStep2) - for (Nd4jLong kh = hstart; kh < hend; kh += iStep3) - for (Nd4jLong kw = wstart; kw < wend; kw += iStep4) - pgI[kd + kh + kw] += valO * sd::math::nd4j_pow(sd::math::nd4j_abs(pIn[kd + kh + kw]),extraParam0 - (T) 1.f) * sd::math::nd4j_sgn(pIn[kd + kh + kw]); - } else { - for (Nd4jLong kd = dstart; kd < dend; kd += dD) - for (Nd4jLong kh = hstart; kh < hend; kh += dH) - for (Nd4jLong kw = wstart; kw < wend; kw += dW) - sum += sd::math::nd4j_pow(sd::math::nd4j_abs(pIn[kd * iStride2 + kh * iStride3 + kw * iStride4]), extraParam0); - - valO *= sd::math::nd4j_pow(sum, ((T) 1.f - extraParam0) / extraParam0); - - for (Nd4jLong kd = dstart; kd < dend; kd += dD) - for (Nd4jLong kh = hstart; kh < hend; kh += dH) - for (Nd4jLong kw = wstart; kw < wend; kw += dW) { - const auto inVal = pIn[kD * iStride2 + kh * iStride3 + kw * iStride4]; - pgI[kd * gIStride2 + kh * gIStride3 + kw * gIStride4] += valO * sd::math::nd4j_pow(sd::math::nd4j_abs(inVal), extraParam0 - 1.f) * sd::math::nd4j_sgn(inVal); - } - } - } - } - } - } - } - }; - - samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1); + } + } + } + }; + + samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1); + } + /*************************************************************************/ + else if (poolingMode == 1) { // avg + auto func = PRAGMA_THREADS_FOR_2D { + Nd4jLong dstart, hstart, wstart, dend, hend, wend, maxKD, maxKH, maxKW; + T sum, valO, *pIn, *pgI; + + for (int b = start_x; b < stop_x; b++) { + for (int c = start_y; c < stop_y; c++) { + for (int od = 0; od < oD; od++) { + for (int oh = 0; oh < oH; ++oh) { + for (int ow = 0; ow < oW; ++ow) { + pgI = gI + b * gIStride0 + c * gIStride1; + + dstart = od * sD - pD; + hstart = oh * sH - pH; + wstart = ow * sW - pW; + dend = dstart + kDEff; + hend = hstart + kHEff; + wend = wstart + kWEff; + + if (dstart < 0) dstart += dD * ((-dstart + dD - 1) / dD); + if (hstart < 0) hstart += dH * ((-hstart + dH - 1) / dH); + if (wstart < 0) wstart += dW * ((-wstart + dW - 1) / dW); + if (dend > iD) dend -= dD * ((dend - iD + dD - 1) / dD); + if (hend > iH) hend -= dH * ((hend - iH + dH - 1) / dH); + if (wend > iW) wend -= dW * ((wend - iW + dW - 1) / dW); + + dstart *= gIStride2; + dend *= gIStride2; + hstart *= gIStride3; + hend *= gIStride3; + wstart *= gIStride4; + wend *= gIStride4; + + valO = gO[b * oStride0 + c * oStride1 + od * oStride2 + + oh * oStride3 + ow * oStride4]; + + if (extraParam0 == 0) // Exclude padding + valO /= sd::math::nd4j_ceil( + static_cast(dend - dstart) / + static_cast(gIStep2)) * + sd::math::nd4j_ceil( + static_cast(hend - hstart) / + static_cast(gIStep3)) * + sd::math::nd4j_ceil( + static_cast(wend - wstart) / + static_cast( + gIStep4)); // Accounts for dilation + else if (extraParam0 == 1) // Include padding + valO /= kProd; + + for (Nd4jLong kd = dstart; kd < dend; kd += gIStep2) + for (Nd4jLong kh = hstart; kh < hend; kh += gIStep3) + for (Nd4jLong kw = wstart; kw < wend; kw += gIStep4) + pgI[kd + kh + kw] += valO; + } } - else { - nd4j_printf("ConvolutionUtils::pooling3dBP: pooling mode argument can take three values only: 0, 1, 2, but got %i instead !\n", poolingMode); - throw ""; + } + } + } + }; + + samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1); + } + /*************************************************************************/ + else if (poolingMode == 2) { // pnorm + auto func = PRAGMA_THREADS_FOR_2D { + Nd4jLong dstart, hstart, wstart, dend, hend, wend, maxKD, maxKH, maxKW; + T sum, valO, *pIn, *pgI; + + for (int b = start_x; b < stop_x; b++) { + for (int c = start_y; c < stop_y; c++) { + for (int od = 0; od < oD; od++) { + for (int oh = 0; oh < oH; ++oh) { + for (int ow = 0; ow < oW; ++ow) { + pIn = in + b * iStride0 + c * iStride1; + pgI = gI + (pIn - in); + + dstart = od * sD - pD; + hstart = oh * sH - pH; + wstart = ow * sW - pW; + dend = dstart + kDEff; + hend = hstart + kHEff; + wend = wstart + kWEff; + + if (dstart < 0) dstart += dD * ((-dstart + dD - 1) / dD); + if (hstart < 0) hstart += dH * ((-hstart + dH - 1) / dH); + if (wstart < 0) wstart += dW * ((-wstart + dW - 1) / dW); + if (dend > iD) dend -= dD * ((dend - iD + dD - 1) / dD); + if (hend > iH) hend -= dH * ((hend - iH + dH - 1) / dH); + if (wend > iW) wend -= dW * ((wend - iW + dW - 1) / dW); + + sum = static_cast(0.); + valO = gO[b * oStride0 + c * oStride1 + od * oStride2 + + oh * oStride3 + ow * oStride4]; + + if (sameStrides) { + dstart *= iStride2; + dend *= iStride2; + hstart *= iStride3; + hend *= iStride3; + wstart *= iStride4; + wend *= iStride4; + + for (Nd4jLong kd = dstart; kd < dend; kd += iStep2) + for (Nd4jLong kh = hstart; kh < hend; kh += iStep3) + for (Nd4jLong kw = wstart; kw < wend; kw += iStep4) + sum += sd::math::nd4j_pow( + sd::math::nd4j_abs(pIn[kd + kh + kw]), + extraParam0); + + valO *= sd::math::nd4j_pow( + sum, ((T)1.f - extraParam0) / extraParam0); + + for (Nd4jLong kd = dstart; kd < dend; kd += iStep2) + for (Nd4jLong kh = hstart; kh < hend; kh += iStep3) + for (Nd4jLong kw = wstart; kw < wend; kw += iStep4) + pgI[kd + kh + kw] += + valO * + sd::math::nd4j_pow( + sd::math::nd4j_abs(pIn[kd + kh + kw]), + extraParam0 - (T)1.f) * + sd::math::nd4j_sgn(pIn[kd + kh + kw]); + } else { + for (Nd4jLong kd = dstart; kd < dend; kd += dD) + for (Nd4jLong kh = hstart; kh < hend; kh += dH) + for (Nd4jLong kw = wstart; kw < wend; kw += dW) + sum += sd::math::nd4j_pow( + sd::math::nd4j_abs( + pIn[kd * iStride2 + kh * iStride3 + + kw * iStride4]), + extraParam0); + + valO *= sd::math::nd4j_pow( + sum, ((T)1.f - extraParam0) / extraParam0); + + for (Nd4jLong kd = dstart; kd < dend; kd += dD) + for (Nd4jLong kh = hstart; kh < hend; kh += dH) + for (Nd4jLong kw = wstart; kw < wend; kw += dW) { + const auto inVal = + pIn[kD * iStride2 + kh * iStride3 + kw * iStride4]; + pgI[kd * gIStride2 + kh * gIStride3 + kw * gIStride4] += + valO * + sd::math::nd4j_pow( + sd::math::nd4j_abs(inVal), + extraParam0 - 1.f) * + sd::math::nd4j_sgn(inVal); + } + } + } } + } } + } + }; + + samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1); + } else { + nd4j_printf( + "ConvolutionUtils::pooling3dBP: pooling mode argument can take three " + "values only: 0, 1, 2, but got %i instead !\n", + poolingMode); + throw ""; + } +} - void ConvolutionUtils::pooling3dBP(sd::graph::Context& block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0) { - BUILD_SINGLE_SELECTOR(input.dataType(), pooling3dBP_, (block, input, gradO, gradI, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode, extraParam0), FLOAT_TYPES); - } - } +void ConvolutionUtils::pooling3dBP(sd::graph::Context& block, + const NDArray& input, const NDArray& gradO, + NDArray& gradI, const int kD, const int kH, + const int kW, const int sD, const int sH, + const int sW, const int pD, const int pH, + const int pW, const int dD, const int dH, + const int dW, const int poolingMode, + const int extraParam0) { + BUILD_SINGLE_SELECTOR(input.dataType(), pooling3dBP_, + (block, input, gradO, gradI, kD, kH, kW, sD, sH, sW, pD, + pH, pW, dD, dH, dW, poolingMode, extraParam0), + FLOAT_TYPES); } +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/convolutions_sconv2d.cpp b/libnd4j/include/ops/declarable/helpers/cpu/convolutions_sconv2d.cpp index 742f88c3bad2..0f76569916a3 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/convolutions_sconv2d.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/convolutions_sconv2d.cpp @@ -18,56 +18,84 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 18.09.2018 // -#include #include +#include namespace sd { - namespace ops { - +namespace ops { ////////////////////////////////////////////////////////////////////////// template -static void sconv2d_(sd::graph::Context& block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { - - // input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - // weightsDepth [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] - // weightsPoint [1, 1, iC*mC, oC], [oC, iC*mC, 1, 1], [oC, 1, 1, iC*mC] - // bias [oC], oC = iC*mC if weightsPoint=nullptr - // output is [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) +static void sconv2d_(sd::graph::Context& block, const NDArray* input, + const NDArray* weightsDepth, const NDArray* weightsPoint, + const NDArray* bias, NDArray* output, const int kH, + const int kW, const int sH, const int sW, int pH, int pW, + const int dH, const int dW, const int paddingMode, + const int isNCHW, const int wFormat) { + // input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + // weightsDepth [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] + // weightsPoint [1, 1, iC*mC, oC], [oC, iC*mC, 1, 1], [oC, 1, 1, iC*mC] + // bias [oC], oC = iC*mC if weightsPoint=nullptr + // output is [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) - // kH filter(kernel) height - // kW filter(kernel) width - // sH strides height - // sW strides width - // pH paddings height - // pW paddings width - // dH dilations height - // dW dilations width - // paddingMode 0-VALID, 1-SAME - // isNCHW 1-NCHW, 0-NHWC + // kH filter(kernel) height + // kW filter(kernel) width + // sH strides height + // sW strides width + // pH paddings height + // pW paddings width + // dH dilations height + // dW dilations width + // paddingMode 0-VALID, 1-SAME + // isNCHW 1-NCHW, 0-NHWC - int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier, output channels, output height/width - int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); - mC = weightsDepth->sizeAt(indWmC); // channels multiplier + int bS, iC, iH, iW, mC, oC, oH, + oW; // batch size, input channels, input height/width, channels + // multiplier, output channels, output height/width + int indIOioC, indIiH, indWmC, indWiC, indWkH, + indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, + indIiH, indWiC, indWmC, indWkH, indOoH); + mC = weightsDepth->sizeAt(indWmC); // channels multiplier - NDArray* outputDepth = output; - if(weightsPoint) // if pointwise convolution is expected - outputDepth = new NDArray(output->ordering(), !isNCHW ? std::vector({bS, oH, oW, iC*mC}) : std::vector({bS, iC*mC, oH, oW}), input->dataType(), input->getContext()); + NDArray* outputDepth = output; + if (weightsPoint) // if pointwise convolution is expected + outputDepth = + new NDArray(output->ordering(), + !isNCHW ? std::vector({bS, oH, oW, iC * mC}) + : std::vector({bS, iC * mC, oH, oW}), + input->dataType(), input->getContext()); - // ----- perform depthwise convolution (if weightsPoint is absent then oC = iC*mC) ----- // - ConvolutionUtils::depthwiseConv2d(block, input, weightsDepth, weightsPoint ? nullptr : bias, outputDepth, kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, isNCHW, wFormat); - - // ----- perform pointwise convolution (oH = iH, oW = iW) ----- // - if (weightsPoint) { - ConvolutionUtils::conv2d(block, outputDepth, weightsPoint, bias, output, 1,1, 1,1, 0,0, 1,1, paddingMode, isNCHW, wFormat); // in this case oH=iH, oW=iW - delete outputDepth; - } - } - -void ConvolutionUtils::sconv2d(sd::graph::Context& block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { - BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), sconv2d_, (block, input, weightsDepth, weightsPoint, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat), FLOAT_TYPES); - } + // ----- perform depthwise convolution (if weightsPoint is absent then oC = + // iC*mC) ----- // + ConvolutionUtils::depthwiseConv2d( + block, input, weightsDepth, weightsPoint ? nullptr : bias, outputDepth, + kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat); + // ----- perform pointwise convolution (oH = iH, oW = iW) ----- // + if (weightsPoint) { + ConvolutionUtils::conv2d(block, outputDepth, weightsPoint, bias, output, 1, + 1, 1, 1, 0, 0, 1, 1, paddingMode, isNCHW, + wFormat); // in this case oH=iH, oW=iW + delete outputDepth; + } } + +void ConvolutionUtils::sconv2d(sd::graph::Context& block, const NDArray* input, + const NDArray* weightsDepth, + const NDArray* weightsPoint, const NDArray* bias, + NDArray* output, const int kH, const int kW, + const int sH, const int sW, int pH, int pW, + const int dH, const int dW, + const int paddingMode, const int isNCHW, + const int wFormat) { + BUILD_SINGLE_SELECTOR_TWICE( + input->dataType(), sconv2d_, + (block, input, weightsDepth, weightsPoint, bias, output, kH, kW, sH, sW, + pH, pW, dH, dW, paddingMode, isNCHW, wFormat), + FLOAT_TYPES); } + +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/convolutions_upsampling2d.cpp b/libnd4j/include/ops/declarable/helpers/cpu/convolutions_upsampling2d.cpp index ffdd5c34b361..6e07a5e0110b 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/convolutions_upsampling2d.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/convolutions_upsampling2d.cpp @@ -18,63 +18,71 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 18.09.2018 // -#include #include +#include namespace sd { - namespace ops { +namespace ops { ////////////////////////////////////////////////////////////////////////// template -static void upsampling2d_(const NDArray& input, NDArray& output, const int factorH, const int factorW, const bool isNCHW) { - // input has shape [bS, iC, iH, iW] (NCHW) or [bS, iH, iW, iC] (NHWC) - // output has shape [bS, iC, factorH*iH, factorW*iW ] (NCHW) or [bS, factorH*iH, factorW*iW, iC] (NHWC) - - const T* x = input.bufferAsT(); - T* z = output.bufferAsT(); +static void upsampling2d_(const NDArray& input, NDArray& output, + const int factorH, const int factorW, + const bool isNCHW) { + // input has shape [bS, iC, iH, iW] (NCHW) or [bS, iH, iW, iC] (NHWC) + // output has shape [bS, iC, factorH*iH, factorW*iW ] (NCHW) or [bS, + // factorH*iH, factorW*iW, iC] (NHWC) - const uint dimIH = isNCHW ? 2 : 1; - const uint dimIC = isNCHW ? 1 : 3; + const T* x = input.bufferAsT(); + T* z = output.bufferAsT(); - const uint bS = input.sizeAt(0); - const uint iC = input.sizeAt(dimIC); - const uint oH = output.sizeAt(dimIH); - const uint oW = output.sizeAt(dimIH + 1); + const uint dimIH = isNCHW ? 2 : 1; + const uint dimIC = isNCHW ? 1 : 3; - const Nd4jLong xStride0 = input.stridesOf()[0]; - const Nd4jLong xStride1 = input.stridesOf()[dimIC]; - const Nd4jLong xStride2 = input.stridesOf()[dimIH]; - const Nd4jLong xStride3 = input.stridesOf()[dimIH + 1]; + const uint bS = input.sizeAt(0); + const uint iC = input.sizeAt(dimIC); + const uint oH = output.sizeAt(dimIH); + const uint oW = output.sizeAt(dimIH + 1); - const Nd4jLong zStride0 = output.stridesOf()[0]; - const Nd4jLong zStride1 = output.stridesOf()[dimIC]; - const Nd4jLong zStride2 = output.stridesOf()[dimIH]; - const Nd4jLong zStride3 = output.stridesOf()[dimIH + 1]; + const Nd4jLong xStride0 = input.stridesOf()[0]; + const Nd4jLong xStride1 = input.stridesOf()[dimIC]; + const Nd4jLong xStride2 = input.stridesOf()[dimIH]; + const Nd4jLong xStride3 = input.stridesOf()[dimIH + 1]; - // loop through output array - auto func = PRAGMA_THREADS_FOR_3D { - uint xCoord2, xCoord3; - for (uint b = start_x; b < stop_x; b += inc_x) { - for (uint c = start_y; c < stop_y; c += inc_y) { - for (uint h = start_z; h < stop_z; h += inc_z) { - for (uint w = 0; w < oW; ++w) { - xCoord2 = h / factorH; - xCoord3 = w / factorW; + const Nd4jLong zStride0 = output.stridesOf()[0]; + const Nd4jLong zStride1 = output.stridesOf()[dimIC]; + const Nd4jLong zStride2 = output.stridesOf()[dimIH]; + const Nd4jLong zStride3 = output.stridesOf()[dimIH + 1]; - z[b * zStride0 + c * zStride1 + h * zStride2 + w * zStride3] = x[b * xStride0 + c * xStride1 + xCoord2 * xStride2 + xCoord3 * xStride3]; - } - } - } - } - }; + // loop through output array + auto func = PRAGMA_THREADS_FOR_3D { + uint xCoord2, xCoord3; + for (uint b = start_x; b < stop_x; b += inc_x) { + for (uint c = start_y; c < stop_y; c += inc_y) { + for (uint h = start_z; h < stop_z; h += inc_z) { + for (uint w = 0; w < oW; ++w) { + xCoord2 = h / factorH; + xCoord3 = w / factorW; - samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, oH, 1); + z[b * zStride0 + c * zStride1 + h * zStride2 + w * zStride3] = + x[b * xStride0 + c * xStride1 + xCoord2 * xStride2 + + xCoord3 * xStride3]; + } } + } + } + }; - -void ConvolutionUtils::upsampling2d(sd::graph::Context& block, const NDArray& input, NDArray& output, const int factorH, const int factorW, const bool isNCHW) { - BUILD_SINGLE_SELECTOR(input.dataType(), upsampling2d_, (input, output, factorH, factorW, isNCHW), FLOAT_TYPES); + samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, oH, 1); } +void ConvolutionUtils::upsampling2d(sd::graph::Context& block, + const NDArray& input, NDArray& output, + const int factorH, const int factorW, + const bool isNCHW) { + BUILD_SINGLE_SELECTOR(input.dataType(), upsampling2d_, + (input, output, factorH, factorW, isNCHW), FLOAT_TYPES); } -} + +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/convolutions_upsampling2dBP.cpp b/libnd4j/include/ops/declarable/helpers/cpu/convolutions_upsampling2dBP.cpp index aba46aabc119..623ae6645411 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/convolutions_upsampling2dBP.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/convolutions_upsampling2dBP.cpp @@ -18,69 +18,74 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 18.09.2018 // -#include #include +#include namespace sd { - namespace ops { - +namespace ops { ////////////////////////////////////////////////////////////////////////// template -static void upsampling2dBP_(const NDArray& gradO, NDArray& gradI, const bool isNCHW) { - // gradO has shape [bS, iC, factorH*iH, factorW*iW ] (NCHW) or [bS, factorH*iH, factorW*iW, iC] (NHWC) - // gradI has shape [bS, iC, iH, iW] (NCHW) or [bS, iH, iW, iC] (NHWC) - - const T* x = gradO.bufferAsT(); - T* z = gradI.bufferAsT(); - - const uint dimIH = isNCHW ? 2 : 1; - const uint dimIC = isNCHW ? 1 : 3; - - const uint bS = gradI.sizeAt(0); - const uint iC = gradI.sizeAt(dimIC); - const uint iH = gradI.sizeAt(dimIH); - const uint iW = gradI.sizeAt(dimIH + 1); - - const uint factorH = gradO.sizeAt(dimIH) / iH; - const uint factorW = gradO.sizeAt(dimIH + 1) / iW; - - const Nd4jLong xStride0 = gradO.stridesOf()[0]; - const Nd4jLong xStride1 = gradO.stridesOf()[dimIC]; - const Nd4jLong xStride2 = gradO.stridesOf()[dimIH]; - const Nd4jLong xStride3 = gradO.stridesOf()[dimIH + 1]; - - const Nd4jLong zStride0 = gradI.stridesOf()[0]; - const Nd4jLong zStride1 = gradI.stridesOf()[dimIC]; - const Nd4jLong zStride2 = gradI.stridesOf()[dimIH]; - const Nd4jLong zStride3 = gradI.stridesOf()[dimIH + 1]; - - // loop through output array - auto func = PRAGMA_THREADS_FOR_3D { - for (uint b = start_x; b < stop_x; b += inc_x) { - for (uint c = start_y; c < stop_y; c += inc_y) { - for (uint h = start_z; h < stop_z; h += inc_z) { - for (uint w = 0; w < iW; ++w) { - - const auto zOffset = b * zStride0 + c * zStride1 + h * zStride2 + w * zStride3; - - z[zOffset] = 0; - - for (uint xh = h * factorH; xh < h * factorH + factorH; ++xh) - for (uint xw = w * factorW; xw < w * factorW + factorW; ++xw) - z[zOffset] += x[b * xStride0 + c * xStride1 + xh * xStride2 + xw * xStride3]; - } - } - } - } - }; - - samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, iH, 1); +static void upsampling2dBP_(const NDArray& gradO, NDArray& gradI, + const bool isNCHW) { + // gradO has shape [bS, iC, factorH*iH, factorW*iW ] (NCHW) or [bS, + // factorH*iH, factorW*iW, iC] (NHWC) gradI has shape [bS, iC, iH, iW] (NCHW) + // or [bS, iH, iW, iC] (NHWC) + + const T* x = gradO.bufferAsT(); + T* z = gradI.bufferAsT(); + + const uint dimIH = isNCHW ? 2 : 1; + const uint dimIC = isNCHW ? 1 : 3; + + const uint bS = gradI.sizeAt(0); + const uint iC = gradI.sizeAt(dimIC); + const uint iH = gradI.sizeAt(dimIH); + const uint iW = gradI.sizeAt(dimIH + 1); + + const uint factorH = gradO.sizeAt(dimIH) / iH; + const uint factorW = gradO.sizeAt(dimIH + 1) / iW; + + const Nd4jLong xStride0 = gradO.stridesOf()[0]; + const Nd4jLong xStride1 = gradO.stridesOf()[dimIC]; + const Nd4jLong xStride2 = gradO.stridesOf()[dimIH]; + const Nd4jLong xStride3 = gradO.stridesOf()[dimIH + 1]; + + const Nd4jLong zStride0 = gradI.stridesOf()[0]; + const Nd4jLong zStride1 = gradI.stridesOf()[dimIC]; + const Nd4jLong zStride2 = gradI.stridesOf()[dimIH]; + const Nd4jLong zStride3 = gradI.stridesOf()[dimIH + 1]; + + // loop through output array + auto func = PRAGMA_THREADS_FOR_3D { + for (uint b = start_x; b < stop_x; b += inc_x) { + for (uint c = start_y; c < stop_y; c += inc_y) { + for (uint h = start_z; h < stop_z; h += inc_z) { + for (uint w = 0; w < iW; ++w) { + const auto zOffset = + b * zStride0 + c * zStride1 + h * zStride2 + w * zStride3; + + z[zOffset] = 0; + + for (uint xh = h * factorH; xh < h * factorH + factorH; ++xh) + for (uint xw = w * factorW; xw < w * factorW + factorW; ++xw) + z[zOffset] += x[b * xStride0 + c * xStride1 + xh * xStride2 + + xw * xStride3]; + } } + } + } + }; -void ConvolutionUtils::upsampling2dBP(sd::graph::Context& block, const NDArray& gradO, NDArray& gradI, const bool isNCHW) { - BUILD_SINGLE_SELECTOR(gradO.dataType(), upsampling2dBP_, (gradO, gradI, isNCHW), FLOAT_TYPES); + samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, iH, 1); } +void ConvolutionUtils::upsampling2dBP(sd::graph::Context& block, + const NDArray& gradO, NDArray& gradI, + const bool isNCHW) { + BUILD_SINGLE_SELECTOR(gradO.dataType(), upsampling2dBP_, + (gradO, gradI, isNCHW), FLOAT_TYPES); } -} + +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/convolutions_upsampling3d.cpp b/libnd4j/include/ops/declarable/helpers/cpu/convolutions_upsampling3d.cpp index 7b86ec5a150f..b73981e763aa 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/convolutions_upsampling3d.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/convolutions_upsampling3d.cpp @@ -18,72 +18,80 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 18.09.2018 // -#include #include +#include namespace sd { - namespace ops { +namespace ops { ////////////////////////////////////////////////////////////////////////// template -static void upsampling3d_(const NDArray& input, NDArray& output, const int factorD, const int factorH, const int factorW, const bool isNCDHW) { - // input has shape [bS, iC, iD, iH, iW] (NCDHW) or [bS, iD, iH, iW, iC] (NDHWC) - // output has shape [bS, iC, factorD*iD, factorH*iH, factorW*iW ] (NCDHW) or [bS, factorD*iD, factorH*iH, factorW*iW, iC] (NDHWC) - - const T* x = input.bufferAsT(); - T* z = output.bufferAsT(); - - const uint dimID = isNCDHW ? 2 : 1; - const uint dimIC = isNCDHW ? 1 : 4; - - const uint bS = input.sizeAt(0); - const uint iC = input.sizeAt(dimIC); - const uint oD = output.sizeAt(dimID); - const uint oH = output.sizeAt(dimID + 1); - const uint oW = output.sizeAt(dimID + 2); - - const Nd4jLong xStride0 = input.stridesOf()[0]; - const Nd4jLong xStride1 = input.stridesOf()[dimIC]; - const Nd4jLong xStride2 = input.stridesOf()[dimID]; - const Nd4jLong xStride3 = input.stridesOf()[dimID + 1]; - const Nd4jLong xStride4 = input.stridesOf()[dimID + 2]; - - const Nd4jLong zStride0 = output.stridesOf()[0]; - const Nd4jLong zStride1 = output.stridesOf()[dimIC]; - const Nd4jLong zStride2 = output.stridesOf()[dimID]; - const Nd4jLong zStride3 = output.stridesOf()[dimID + 1]; - const Nd4jLong zStride4 = output.stridesOf()[dimID + 2]; - - // loop through output array - auto func = PRAGMA_THREADS_FOR_3D { - uint xCoord2, xCoord3, xCoord4; - - for (uint b = start_x; b < stop_x; b += inc_x) { - for (uint c = start_y; c < stop_y; c += inc_y) { - for (uint d = start_z; d < stop_z; d += inc_z) { - for (uint h = 0; h < oH; ++h) { - for (uint w = 0; w < oW; ++w) { - - xCoord2 = d / factorD; - xCoord3 = h / factorH; - xCoord4 = w / factorW; - - z[b * zStride0 + c * zStride1 + d * zStride2 + h * zStride3 + w * zStride4] = x[ - b * xStride0 + c * xStride1 + xCoord2 * xStride2 + xCoord3 * xStride3 + - xCoord4 * xStride4]; - } - } - } - } - } - }; - - samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, oD, 1); - } - - void ConvolutionUtils::upsampling3d(sd::graph::Context& block, const NDArray& input, NDArray& output, const int factorD, const int factorH, const int factorW, const bool isNCDHW) { - BUILD_SINGLE_SELECTOR(input.dataType(), upsampling3d_, (input, output, factorD, factorH, factorW, isNCDHW), FLOAT_TYPES); +static void upsampling3d_(const NDArray& input, NDArray& output, + const int factorD, const int factorH, + const int factorW, const bool isNCDHW) { + // input has shape [bS, iC, iD, iH, iW] (NCDHW) or [bS, iD, iH, iW, iC] + // (NDHWC) output has shape [bS, iC, factorD*iD, factorH*iH, factorW*iW ] + // (NCDHW) or [bS, factorD*iD, factorH*iH, factorW*iW, iC] (NDHWC) + + const T* x = input.bufferAsT(); + T* z = output.bufferAsT(); + + const uint dimID = isNCDHW ? 2 : 1; + const uint dimIC = isNCDHW ? 1 : 4; + + const uint bS = input.sizeAt(0); + const uint iC = input.sizeAt(dimIC); + const uint oD = output.sizeAt(dimID); + const uint oH = output.sizeAt(dimID + 1); + const uint oW = output.sizeAt(dimID + 2); + + const Nd4jLong xStride0 = input.stridesOf()[0]; + const Nd4jLong xStride1 = input.stridesOf()[dimIC]; + const Nd4jLong xStride2 = input.stridesOf()[dimID]; + const Nd4jLong xStride3 = input.stridesOf()[dimID + 1]; + const Nd4jLong xStride4 = input.stridesOf()[dimID + 2]; + + const Nd4jLong zStride0 = output.stridesOf()[0]; + const Nd4jLong zStride1 = output.stridesOf()[dimIC]; + const Nd4jLong zStride2 = output.stridesOf()[dimID]; + const Nd4jLong zStride3 = output.stridesOf()[dimID + 1]; + const Nd4jLong zStride4 = output.stridesOf()[dimID + 2]; + + // loop through output array + auto func = PRAGMA_THREADS_FOR_3D { + uint xCoord2, xCoord3, xCoord4; + + for (uint b = start_x; b < stop_x; b += inc_x) { + for (uint c = start_y; c < stop_y; c += inc_y) { + for (uint d = start_z; d < stop_z; d += inc_z) { + for (uint h = 0; h < oH; ++h) { + for (uint w = 0; w < oW; ++w) { + xCoord2 = d / factorD; + xCoord3 = h / factorH; + xCoord4 = w / factorW; + + z[b * zStride0 + c * zStride1 + d * zStride2 + h * zStride3 + + w * zStride4] = + x[b * xStride0 + c * xStride1 + xCoord2 * xStride2 + + xCoord3 * xStride3 + xCoord4 * xStride4]; + } + } } + } + } + }; + samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, oD, 1); } + +void ConvolutionUtils::upsampling3d(sd::graph::Context& block, + const NDArray& input, NDArray& output, + const int factorD, const int factorH, + const int factorW, const bool isNCDHW) { + BUILD_SINGLE_SELECTOR(input.dataType(), upsampling3d_, + (input, output, factorD, factorH, factorW, isNCDHW), + FLOAT_TYPES); } + +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/convolutions_upsampling3dBP.cpp b/libnd4j/include/ops/declarable/helpers/cpu/convolutions_upsampling3dBP.cpp index 93c2746fbfc6..40480b9e6360 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/convolutions_upsampling3dBP.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/convolutions_upsampling3dBP.cpp @@ -18,78 +18,82 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 18.09.2018 // -#include #include +#include namespace sd { - namespace ops { - +namespace ops { ////////////////////////////////////////////////////////////////////////// template -static void upsampling3dBP_(const NDArray& gradO, NDArray& gradI, const bool isNCDHW) { - - // input has shape [bS, iC, iD, iH, iW] (NCDHW) or [bS, iD, iH, iW, iC] (NDHWC) - // output has shape [bS, iC, factorD*iD, factorH*iH, factorW*iW ] (NCDHW) or [bS, factorD*iD, factorH*iH, factorW*iW, iC] (NDHWC) - - const T* x = gradO.bufferAsT(); - T* z = gradI.bufferAsT(); - - const uint dimID = isNCDHW ? 2 : 1; - const uint dimIC = isNCDHW ? 1 : 4; - - const uint bS = gradI.sizeAt(0); - const uint iC = gradI.sizeAt(dimIC); - const uint iD = gradI.sizeAt(dimID); - const uint iH = gradI.sizeAt(dimID + 1); - const uint iW = gradI.sizeAt(dimID + 2); - - const uint factorD = gradO.sizeAt(dimID) / iD; - const uint factorH = gradO.sizeAt(dimID + 1) / iH; - const uint factorW = gradO.sizeAt(dimID + 2) / iW; - - const Nd4jLong xStride0 = gradO.stridesOf()[0]; - const Nd4jLong xStride1 = gradO.stridesOf()[dimIC]; - const Nd4jLong xStride2 = gradO.stridesOf()[dimID]; - const Nd4jLong xStride3 = gradO.stridesOf()[dimID + 1]; - const Nd4jLong xStride4 = gradO.stridesOf()[dimID + 2]; - - const Nd4jLong zStride0 = gradI.stridesOf()[0]; - const Nd4jLong zStride1 = gradI.stridesOf()[dimIC]; - const Nd4jLong zStride2 = gradI.stridesOf()[dimID]; - const Nd4jLong zStride3 = gradI.stridesOf()[dimID + 1]; - const Nd4jLong zStride4 = gradI.stridesOf()[dimID + 2]; - - // loop through output array - auto func = PRAGMA_THREADS_FOR_3D { - for (uint b = start_x; b < stop_x; b += inc_x) { - for (uint c = start_y; c < stop_y; c += inc_y) { - for (uint d = start_z; d < stop_z; d += inc_z) { - for (uint h = 0; h < iH; ++h) { - for (uint w = 0; w < iW; ++w) { - - const auto zOffset = b * zStride0 + c * zStride1 + d * zStride2 + h * zStride3 + w * zStride4; - - z[zOffset] = 0; - - for (uint xd = d * factorD; xd < d * factorD + factorD; ++xd) - for (uint xh = h * factorH; xh < h * factorH + factorH; ++xh) - for (uint xw = w * factorW; xw < w * factorW + factorW; ++xw) - z[zOffset] += x[b * xStride0 + c * xStride1 + xd * xStride2 + xh * xStride3 + xw * xStride4]; - } - } - } - } - } - }; - - samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, iD, 1); - } - - - void ConvolutionUtils::upsampling3dBP(sd::graph::Context& block, const NDArray& gradO, NDArray& gradI, const bool isNCHW) { - BUILD_SINGLE_SELECTOR(gradO.dataType(), upsampling3dBP_, (gradO, gradI, isNCHW), FLOAT_TYPES); +static void upsampling3dBP_(const NDArray& gradO, NDArray& gradI, + const bool isNCDHW) { + // input has shape [bS, iC, iD, iH, iW] (NCDHW) or [bS, iD, iH, iW, iC] + // (NDHWC) output has shape [bS, iC, factorD*iD, factorH*iH, factorW*iW ] + // (NCDHW) or [bS, factorD*iD, factorH*iH, factorW*iW, iC] (NDHWC) + + const T* x = gradO.bufferAsT(); + T* z = gradI.bufferAsT(); + + const uint dimID = isNCDHW ? 2 : 1; + const uint dimIC = isNCDHW ? 1 : 4; + + const uint bS = gradI.sizeAt(0); + const uint iC = gradI.sizeAt(dimIC); + const uint iD = gradI.sizeAt(dimID); + const uint iH = gradI.sizeAt(dimID + 1); + const uint iW = gradI.sizeAt(dimID + 2); + + const uint factorD = gradO.sizeAt(dimID) / iD; + const uint factorH = gradO.sizeAt(dimID + 1) / iH; + const uint factorW = gradO.sizeAt(dimID + 2) / iW; + + const Nd4jLong xStride0 = gradO.stridesOf()[0]; + const Nd4jLong xStride1 = gradO.stridesOf()[dimIC]; + const Nd4jLong xStride2 = gradO.stridesOf()[dimID]; + const Nd4jLong xStride3 = gradO.stridesOf()[dimID + 1]; + const Nd4jLong xStride4 = gradO.stridesOf()[dimID + 2]; + + const Nd4jLong zStride0 = gradI.stridesOf()[0]; + const Nd4jLong zStride1 = gradI.stridesOf()[dimIC]; + const Nd4jLong zStride2 = gradI.stridesOf()[dimID]; + const Nd4jLong zStride3 = gradI.stridesOf()[dimID + 1]; + const Nd4jLong zStride4 = gradI.stridesOf()[dimID + 2]; + + // loop through output array + auto func = PRAGMA_THREADS_FOR_3D { + for (uint b = start_x; b < stop_x; b += inc_x) { + for (uint c = start_y; c < stop_y; c += inc_y) { + for (uint d = start_z; d < stop_z; d += inc_z) { + for (uint h = 0; h < iH; ++h) { + for (uint w = 0; w < iW; ++w) { + const auto zOffset = b * zStride0 + c * zStride1 + d * zStride2 + + h * zStride3 + w * zStride4; + + z[zOffset] = 0; + + for (uint xd = d * factorD; xd < d * factorD + factorD; ++xd) + for (uint xh = h * factorH; xh < h * factorH + factorH; ++xh) + for (uint xw = w * factorW; xw < w * factorW + factorW; ++xw) + z[zOffset] += + x[b * xStride0 + c * xStride1 + xd * xStride2 + + xh * xStride3 + xw * xStride4]; + } + } } + } + } + }; + samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, iD, 1); } + +void ConvolutionUtils::upsampling3dBP(sd::graph::Context& block, + const NDArray& gradO, NDArray& gradI, + const bool isNCHW) { + BUILD_SINGLE_SELECTOR(gradO.dataType(), upsampling3dBP_, + (gradO, gradI, isNCHW), FLOAT_TYPES); } + +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/convolutions_vol2col.cpp b/libnd4j/include/ops/declarable/helpers/cpu/convolutions_vol2col.cpp index 4c8b5bad1474..5f560950d925 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/convolutions_vol2col.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/convolutions_vol2col.cpp @@ -18,130 +18,153 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 18.09.2018 // -#include #include +#include namespace sd { - namespace ops { - +namespace ops { ////////////////////////////////////////////////////////////////////////// // [bS, iC, iD, iH, iW] is convoluted to [bS, iC, kD, kH, kW, oD, oH, oW] template -static void vol2col_(const NDArray& volume, NDArray& columns, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW) { - - const int bS = volume.sizeAt(0); - const int iC = volume.sizeAt(1); - const int iD = volume.sizeAt(2); - const int iH = volume.sizeAt(3); - const int iW = volume.sizeAt(4); - const int kD = columns.sizeAt(2); - const int kH = columns.sizeAt(3); - const int kW = columns.sizeAt(4); - const int oD = columns.sizeAt(5); - const int oH = columns.sizeAt(6); - const int oW = columns.sizeAt(7); - const Nd4jLong colStride0 = columns.stridesOf()[0]; - const Nd4jLong colStride1 = columns.stridesOf()[1]; - const Nd4jLong colStride2 = columns.stridesOf()[2]; - const Nd4jLong colStride3 = columns.stridesOf()[3]; - const Nd4jLong colStride4 = columns.stridesOf()[4]; - const Nd4jLong colStride5 = columns.stridesOf()[5]; - const Nd4jLong colStride6 = columns.stridesOf()[6]; - const Nd4jLong colStride7 = columns.stridesOf()[7]; - const Nd4jLong volStride0 = volume.stridesOf()[0]; - const Nd4jLong volStride1 = volume.stridesOf()[1]; - const Nd4jLong volStride2 = volume.stridesOf()[2]; - const Nd4jLong volStride3 = volume.stridesOf()[3]; - const Nd4jLong volStride4 = volume.stridesOf()[4]; - - T* colBuff = columns.bufferAsT(); - T* volBuff = const_cast(volume).bufferAsT(); - - - if (volume.ordering() == 'c' && columns.ordering() == 'c' && shape::strideDescendingCAscendingF(volume.shapeInfo()) && shape::strideDescendingCAscendingF(columns.shapeInfo())) { - - auto func = PRAGMA_THREADS_FOR_3D { - T *col, *vol; - int volDep, volRow, volCol; - - for (int b = start_x; b < stop_x; b += inc_x) { - for (int c = start_y; c < stop_y; c += inc_y) { - for (int kDep = start_z; kDep < stop_z; kDep += inc_z) { - for (int kRow = 0; kRow < kH; ++kRow) { - for (int kCol = 0; kCol < kW; ++kCol) { - for (int colD = 0; colD < oD; ++colD) { - for (int colH = 0; colH < oH; ++colH) { - for (int colW = 0; colW < oW; ++colW) { - - volDep = (-pD + kDep * dD) + colD * sD; - volRow = (-pH + kRow * dH) + colH * sH; - volCol = (-pW + kCol * dW) + colW * sW; - - col = colBuff + b * colStride0 + c * colStride1 + kDep * colStride2 + kRow * colStride3 + kCol * colStride4 + colD * colStride5 + colH * colStride6 + colW * colStride7; - - if (static_cast(volDep) >= static_cast(iD) || static_cast(volRow) >= static_cast(iH) || static_cast(volCol) >= static_cast(iW)) - *col = static_cast(0.); - else { - vol = volBuff + b * volStride0 + c * volStride1 + volDep * volStride2 + volRow * volStride3 + volCol * volStride4; - *col = *vol; - } - } - } - } - } - } - } - } +static void vol2col_(const NDArray& volume, NDArray& columns, const int sD, + const int sH, const int sW, const int pD, const int pH, + const int pW, const int dD, const int dH, const int dW) { + const int bS = volume.sizeAt(0); + const int iC = volume.sizeAt(1); + const int iD = volume.sizeAt(2); + const int iH = volume.sizeAt(3); + const int iW = volume.sizeAt(4); + const int kD = columns.sizeAt(2); + const int kH = columns.sizeAt(3); + const int kW = columns.sizeAt(4); + const int oD = columns.sizeAt(5); + const int oH = columns.sizeAt(6); + const int oW = columns.sizeAt(7); + const Nd4jLong colStride0 = columns.stridesOf()[0]; + const Nd4jLong colStride1 = columns.stridesOf()[1]; + const Nd4jLong colStride2 = columns.stridesOf()[2]; + const Nd4jLong colStride3 = columns.stridesOf()[3]; + const Nd4jLong colStride4 = columns.stridesOf()[4]; + const Nd4jLong colStride5 = columns.stridesOf()[5]; + const Nd4jLong colStride6 = columns.stridesOf()[6]; + const Nd4jLong colStride7 = columns.stridesOf()[7]; + const Nd4jLong volStride0 = volume.stridesOf()[0]; + const Nd4jLong volStride1 = volume.stridesOf()[1]; + const Nd4jLong volStride2 = volume.stridesOf()[2]; + const Nd4jLong volStride3 = volume.stridesOf()[3]; + const Nd4jLong volStride4 = volume.stridesOf()[4]; + + T* colBuff = columns.bufferAsT(); + T* volBuff = const_cast(volume).bufferAsT(); + + if (volume.ordering() == 'c' && columns.ordering() == 'c' && + shape::strideDescendingCAscendingF(volume.shapeInfo()) && + shape::strideDescendingCAscendingF(columns.shapeInfo())) { + auto func = PRAGMA_THREADS_FOR_3D { + T *col, *vol; + int volDep, volRow, volCol; + + for (int b = start_x; b < stop_x; b += inc_x) { + for (int c = start_y; c < stop_y; c += inc_y) { + for (int kDep = start_z; kDep < stop_z; kDep += inc_z) { + for (int kRow = 0; kRow < kH; ++kRow) { + for (int kCol = 0; kCol < kW; ++kCol) { + for (int colD = 0; colD < oD; ++colD) { + for (int colH = 0; colH < oH; ++colH) { + for (int colW = 0; colW < oW; ++colW) { + volDep = (-pD + kDep * dD) + colD * sD; + volRow = (-pH + kRow * dH) + colH * sH; + volCol = (-pW + kCol * dW) + colW * sW; + + col = colBuff + b * colStride0 + c * colStride1 + + kDep * colStride2 + kRow * colStride3 + + kCol * colStride4 + colD * colStride5 + + colH * colStride6 + colW * colStride7; + + if (static_cast(volDep) >= + static_cast(iD) || + static_cast(volRow) >= + static_cast(iH) || + static_cast(volCol) >= + static_cast(iW)) + *col = static_cast(0.); + else { + vol = volBuff + b * volStride0 + c * volStride1 + + volDep * volStride2 + volRow * volStride3 + + volCol * volStride4; + *col = *vol; + } } - }; - - samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, kD, 1); - - } else { - - auto func = PRAGMA_THREADS_FOR_2D { - T *col, *vol; - int volDep, volRow, volCol; - for (int b = start_x; b < stop_x; b++) { - for (int colD = start_y; colD < stop_y; colD++) { - for (int colH = 0; colH < oH; ++colH) { - for (int colW = 0; colW < oW; ++colW) { - for (int c = 0; c < iC; ++c) { - for (int kDep = 0; kDep < kD; ++kDep) { - for (int kRow = 0; kRow < kH; ++kRow) { - for (int kCol = 0; kCol < kW; ++kCol) { - - volDep = (-pD + kDep * dD) + colD * sD; - volRow = (-pH + kRow * dH) + colH * sH; - volCol = (-pW + kCol * dW) + colW * sW; - - col = colBuff + b * colStride0 + c * colStride1 + kDep * colStride2 + kRow * colStride3 + kCol * colStride4 + colD * colStride5 + colH * colStride6 + colW * colStride7; - - if (static_cast(volDep) >= static_cast(iD) || static_cast(volRow) >= static_cast(iH) || static_cast(volCol) >= static_cast(iW)) - *col = static_cast(0.f); - else { - vol = volBuff + b * volStride0 + c * volStride1 + volDep * volStride2 + volRow * volStride3 + volCol * volStride4; - *col = *vol; - } - } - } - } - } - } - } - } + } + } + } + } + } + } + } + }; + + samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, kD, 1); + + } else { + auto func = PRAGMA_THREADS_FOR_2D { + T *col, *vol; + int volDep, volRow, volCol; + for (int b = start_x; b < stop_x; b++) { + for (int colD = start_y; colD < stop_y; colD++) { + for (int colH = 0; colH < oH; ++colH) { + for (int colW = 0; colW < oW; ++colW) { + for (int c = 0; c < iC; ++c) { + for (int kDep = 0; kDep < kD; ++kDep) { + for (int kRow = 0; kRow < kH; ++kRow) { + for (int kCol = 0; kCol < kW; ++kCol) { + volDep = (-pD + kDep * dD) + colD * sD; + volRow = (-pH + kRow * dH) + colH * sH; + volCol = (-pW + kCol * dW) + colW * sW; + + col = colBuff + b * colStride0 + c * colStride1 + + kDep * colStride2 + kRow * colStride3 + + kCol * colStride4 + colD * colStride5 + + colH * colStride6 + colW * colStride7; + + if (static_cast(volDep) >= + static_cast(iD) || + static_cast(volRow) >= + static_cast(iH) || + static_cast(volCol) >= + static_cast(iW)) + *col = static_cast(0.f); + else { + vol = volBuff + b * volStride0 + c * volStride1 + + volDep * volStride2 + volRow * volStride3 + + volCol * volStride4; + *col = *vol; + } } - }; - - samediff::Threads::parallel_for(func, 0, bS, 1, 0, oD, 1); - //func(0, 0, bS, 1, 0, oD, 1); + } + } + } } + } } + } + }; -void ConvolutionUtils::vol2col(sd::graph::Context& block, const NDArray& volume, NDArray& columns, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW) { - BUILD_SINGLE_SELECTOR(volume.dataType(), vol2col_, (volume, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW), FLOAT_TYPES); + samediff::Threads::parallel_for(func, 0, bS, 1, 0, oD, 1); + // func(0, 0, bS, 1, 0, oD, 1); + } } +void ConvolutionUtils::vol2col(sd::graph::Context& block, const NDArray& volume, + NDArray& columns, const int sD, const int sH, + const int sW, const int pD, const int pH, + const int pW, const int dD, const int dH, + const int dW) { + BUILD_SINGLE_SELECTOR(volume.dataType(), vol2col_, + (volume, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW), + FLOAT_TYPES); } -} \ No newline at end of file + +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/crop_and_resize.cpp b/libnd4j/include/ops/declarable/helpers/cpu/crop_and_resize.cpp index ab6503946c14..5165f6396d0b 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/crop_and_resize.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/crop_and_resize.cpp @@ -33,31 +33,38 @@ limitations under the License. // @author sgazeos@gmail.com // -#include #include +#include namespace sd { - namespace ops { - namespace helpers { +namespace ops { +namespace helpers { -// ------------------------------------------------------------------------------------------------------------------ // -// ------------------------------------------------------------------------------------------------------------------ // -// crop and resize helper functor: +// ------------------------------------------------------------------------------------------------------------------ +// // +// ------------------------------------------------------------------------------------------------------------------ +// // crop and resize helper functor: // \@param context - launch context for operation -// \@param images - batch of images (4D tensor) with shape {batch, width, height, channels} with given type +// \@param images - batch of images (4D tensor) with shape {batch, width, +// height, channels} with given type // \@param boxes - float boxes for crop // \@param indices - integer boxes indices for crop // \@param cropSize - integer size (newWidth, newHeight) -// \@param method - one of bilinear (0) or nearest neighbour (1) interpolation algorithm +// \@param method - one of bilinear (0) or nearest neighbour (1) interpolation +// algorithm // \@param extrapolationVal - radix to increase/decrease image // \@param crops - output image batch (4D with given type) // - void - cropAndResizeFunctor(sd::LaunchContext * context, NDArray const *images, NDArray const *boxes, - NDArray const *indices, NDArray const *cropSize, - int method, double extrapolationVal, NDArray *crops) { - BUILD_TRIPLE_SELECTOR(images->dataType(), boxes->dataType(), indices->dataType(), cropAndResizeFunctor_, (images, boxes, indices, cropSize, method, extrapolationVal, crops), NUMERIC_TYPES, FLOAT_TYPES, INTEGER_TYPES); - } - } - } -} \ No newline at end of file +void cropAndResizeFunctor(sd::LaunchContext *context, NDArray const *images, + NDArray const *boxes, NDArray const *indices, + NDArray const *cropSize, int method, + double extrapolationVal, NDArray *crops) { + BUILD_TRIPLE_SELECTOR( + images->dataType(), boxes->dataType(), indices->dataType(), + cropAndResizeFunctor_, + (images, boxes, indices, cropSize, method, extrapolationVal, crops), + NUMERIC_TYPES, FLOAT_TYPES, INTEGER_TYPES); +} +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/crop_and_resize.hpp b/libnd4j/include/ops/declarable/helpers/cpu/crop_and_resize.hpp index c7d29c47131e..58d09800f5b3 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/crop_and_resize.hpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/crop_and_resize.hpp @@ -18,106 +18,118 @@ // @author sgazeos@gmail.com // -#include #include +#include namespace sd { - namespace ops { - namespace helpers { - template - void cropAndResizeFunctor_(NDArray const *images, NDArray const *boxes, NDArray const *indices, NDArray const *cropSize, int method, double extrapolationVal, NDArray *crops) { - const int batchSize = images->sizeAt(0); - const int imageHeight = images->sizeAt(1); - const int imageWidth = images->sizeAt(2); - - const int numBoxes = crops->sizeAt(0); - const int cropHeight = crops->sizeAt(1); - const int cropWidth = crops->sizeAt(2); - const int depth = crops->sizeAt(3); +namespace ops { +namespace helpers { +template +void cropAndResizeFunctor_(NDArray const *images, NDArray const *boxes, + NDArray const *indices, NDArray const *cropSize, + int method, double extrapolationVal, + NDArray *crops) { + const int batchSize = images->sizeAt(0); + const int imageHeight = images->sizeAt(1); + const int imageWidth = images->sizeAt(2); - for (auto b = 0; b < numBoxes; ++b) { - T y1 = boxes->t(b, 0); - T x1 = boxes->t(b, 1); - T y2 = boxes->t(b, 2); - T x2 = boxes->t(b, 3); + const int numBoxes = crops->sizeAt(0); + const int cropHeight = crops->sizeAt(1); + const int cropWidth = crops->sizeAt(2); + const int depth = crops->sizeAt(3); - int bIn = indices->e(b); - if (bIn >= batchSize) { - continue; - } + for (auto b = 0; b < numBoxes; ++b) { + T y1 = boxes->t(b, 0); + T x1 = boxes->t(b, 1); + T y2 = boxes->t(b, 2); + T x2 = boxes->t(b, 3); - T heightScale = (cropHeight > 1) ? (y2 - y1) * (imageHeight - 1) / (cropHeight - 1) : T(0); - T widthScale = (cropWidth > 1) ? (x2 - x1) * (imageWidth - 1) / (cropWidth - 1) : T(0); + int bIn = indices->e(b); + if (bIn >= batchSize) { + continue; + } - auto func = PRAGMA_THREADS_FOR { - for (auto y = start; y < stop; y++) { - const float inY = (cropHeight > 1) - ? y1 * (imageHeight - 1) + y * heightScale - : 0.5 * (y1 + y2) * (imageHeight - 1); + T heightScale = (cropHeight > 1) + ? (y2 - y1) * (imageHeight - 1) / (cropHeight - 1) + : T(0); + T widthScale = + (cropWidth > 1) ? (x2 - x1) * (imageWidth - 1) / (cropWidth - 1) : T(0); - if (inY < 0 || inY > imageHeight - 1) { - for (auto x = 0; x < cropWidth; ++x) { - for (auto d = 0; d < depth; ++d) { - crops->p(b, y, x, d, extrapolationVal); - } - } - continue; - } - if (method == 0 /* bilinear */) { - const int topYIndex = sd::math::p_floor(inY); - const int bottomYIndex = sd::math::p_ceil(inY); - const float y_lerp = inY - topYIndex; + auto func = PRAGMA_THREADS_FOR { + for (auto y = start; y < stop; y++) { + const float inY = (cropHeight > 1) + ? y1 * (imageHeight - 1) + y * heightScale + : 0.5 * (y1 + y2) * (imageHeight - 1); - for (auto x = 0; x < cropWidth; ++x) { - const float in_x = (cropWidth > 1) - ? x1 * (imageWidth - 1) + x * widthScale - : 0.5 * (x1 + x2) * (imageWidth - 1); + if (inY < 0 || inY > imageHeight - 1) { + for (auto x = 0; x < cropWidth; ++x) { + for (auto d = 0; d < depth; ++d) { + crops->p(b, y, x, d, extrapolationVal); + } + } + continue; + } + if (method == 0 /* bilinear */) { + const int topYIndex = sd::math::p_floor(inY); + const int bottomYIndex = sd::math::p_ceil(inY); + const float y_lerp = inY - topYIndex; - if (in_x < 0 || in_x > imageWidth - 1) { - for (auto d = 0; d < depth; ++d) { - crops->p(b, y, x, d, extrapolationVal); - } - continue; - } - int left_x_index = math::p_floor(in_x); - int right_x_index = math::p_ceil(in_x); - T x_lerp = in_x - left_x_index; + for (auto x = 0; x < cropWidth; ++x) { + const float in_x = (cropWidth > 1) + ? x1 * (imageWidth - 1) + x * widthScale + : 0.5 * (x1 + x2) * (imageWidth - 1); - for (auto d = 0; d < depth; ++d) { - const float topLeft(images->e(bIn, topYIndex, left_x_index, d)); - const float topRight(images->e(bIn, topYIndex, right_x_index, d)); - const float bottomLeft(images->e(bIn, bottomYIndex, left_x_index, d)); - const float bottomRight(images->e(bIn, bottomYIndex, right_x_index, d)); - const float top = topLeft + (topRight - topLeft) * x_lerp; - const float bottom = bottomLeft + (bottomRight - bottomLeft) * x_lerp; - crops->p(b, y, x, d, top + (bottom - top) * y_lerp); - } - } - } else { // method is "nearest neighbor" - for (auto x = 0; x < cropWidth; ++x) { - const float inX = (cropWidth > 1) - ? x1 * (imageWidth - 1) + x * widthScale - : 0.5 * (x1 + x2) * (imageWidth - 1); + if (in_x < 0 || in_x > imageWidth - 1) { + for (auto d = 0; d < depth; ++d) { + crops->p(b, y, x, d, extrapolationVal); + } + continue; + } + int left_x_index = math::p_floor(in_x); + int right_x_index = math::p_ceil(in_x); + T x_lerp = in_x - left_x_index; - if (inX < 0 || inX > imageWidth - 1) { - for (auto d = 0; d < depth; ++d) { - crops->p(b, y, x, d, extrapolationVal); - } - continue; - } - const int closestXIndex = roundf(inX); - const int closestYIndex = roundf(inY); - for (auto d = 0; d < depth; ++d) { - crops->p(b, y, x, d, images->e(bIn, closestYIndex, closestXIndex, d)); - } - } - } - } - }; + for (auto d = 0; d < depth; ++d) { + const float topLeft( + images->e(bIn, topYIndex, left_x_index, d)); + const float topRight( + images->e(bIn, topYIndex, right_x_index, d)); + const float bottomLeft( + images->e(bIn, bottomYIndex, left_x_index, d)); + const float bottomRight( + images->e(bIn, bottomYIndex, right_x_index, d)); + const float top = topLeft + (topRight - topLeft) * x_lerp; + const float bottom = + bottomLeft + (bottomRight - bottomLeft) * x_lerp; + crops->p(b, y, x, d, top + (bottom - top) * y_lerp); + } + } + } else { // method is "nearest neighbor" + for (auto x = 0; x < cropWidth; ++x) { + const float inX = (cropWidth > 1) + ? x1 * (imageWidth - 1) + x * widthScale + : 0.5 * (x1 + x2) * (imageWidth - 1); - samediff::Threads::parallel_for(func, 0, cropHeight); - } + if (inX < 0 || inX > imageWidth - 1) { + for (auto d = 0; d < depth; ++d) { + crops->p(b, y, x, d, extrapolationVal); + } + continue; + } + const int closestXIndex = roundf(inX); + const int closestYIndex = roundf(inY); + for (auto d = 0; d < depth; ++d) { + crops->p(b, y, x, d, + images->e(bIn, closestYIndex, closestXIndex, d)); } + } } - } -} \ No newline at end of file + } + }; + + samediff::Threads::parallel_for(func, 0, cropHeight); + } +} +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/cross.cpp b/libnd4j/include/ops/declarable/helpers/cpu/cross.cpp index 51af1840be42..33a5b57498fa 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/cross.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/cross.cpp @@ -18,39 +18,39 @@ // @author GS (sgazeos@gmail.com), created on 10/1/2018 // - -#include #include #include +#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { -void crossBatched(sd::LaunchContext * context, NDArray *a, NDArray *b, NDArray *o) { - auto _a = a->reshape(a->ordering(), {-1, 3}); - auto _b = b->reshape(b->ordering(), {-1, 3}); - auto _o = o->reshape(o->ordering(), {-1, 3}, false); +void crossBatched(sd::LaunchContext *context, NDArray *a, NDArray *b, + NDArray *o) { + auto _a = a->reshape(a->ordering(), {-1, 3}); + auto _b = b->reshape(b->ordering(), {-1, 3}); + auto _o = o->reshape(o->ordering(), {-1, 3}, false); - auto tadsA = _a.allTensorsAlongDimension({1}); - auto tadsB = _b.allTensorsAlongDimension({1}); - auto tadsO = _o.allTensorsAlongDimension({1}); + auto tadsA = _a.allTensorsAlongDimension({1}); + auto tadsB = _b.allTensorsAlongDimension({1}); + auto tadsO = _o.allTensorsAlongDimension({1}); - int tads = tadsA.size(); + int tads = tadsA.size(); - auto func = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) { - auto a_ = tadsA.at(e); - auto b_ = tadsB.at(e); - auto o_ = tadsO.at(e); + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + auto a_ = tadsA.at(e); + auto b_ = tadsB.at(e); + auto o_ = tadsO.at(e); - helpers::cross(context, a_, b_, o_); - } - }; + helpers::cross(context, &a_, &b_, &o_); + } + }; - samediff::Threads::parallel_tad(func, 0, tads); + samediff::Threads::parallel_tad(func, 0, tads); } -} -} -} \ No newline at end of file +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/d_t_s.cpp b/libnd4j/include/ops/declarable/helpers/cpu/d_t_s.cpp index 27b73d001eb5..b0e782113727 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/d_t_s.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/d_t_s.cpp @@ -18,89 +18,104 @@ // // -#include #include +#include namespace sd { namespace ops { namespace helpers { - template - static void __depthToSpace(const NDArray &input, NDArray *output, int block_size, bool isNHWC) { - T const*input_ptr = reinterpret_cast(input.buffer()); - T *output_ptr = reinterpret_cast(output->buffer()); - - const int batch_size = input.sizeAt(0); - const int input_depth = isNHWC ? input.sizeAt(3) : input.sizeAt(1); - const int input_height = isNHWC ? input.sizeAt(1) : input.sizeAt(2); - const int input_width = isNHWC ? input.sizeAt(2) : input.sizeAt(3); - - const int output_depth = isNHWC ? output->sizeAt(3) : output->sizeAt(1); - const int output_height = isNHWC ? output->sizeAt(1) : output->sizeAt(2); - const int output_width = isNHWC ? output->sizeAt(2) : output->sizeAt(3); - - const int input_area = input_width * input_height; - const int input_depth_by_input_area = input_depth * input_area; - const int output_depth_by_input_height = output_depth * input_height; - - if (isNHWC) { - const int total_count = batch_size * output_height * output_width * output_depth; - auto func = PRAGMA_THREADS_FOR { - for (auto out_idx = start; out_idx < stop; out_idx++) { - const int d = out_idx % output_depth; - const int out_idx2 = out_idx / output_depth; - const int w = out_idx2 % output_width; - const int out_idx3 = out_idx2 / output_width; - const int h = out_idx3 % output_height; - const int b = out_idx3 / output_height; - - const int in_h = h / block_size; - const int offset_h = h % block_size; - const int in_w = w / block_size; - const int offset_w = w % block_size; - const int offset_d = (offset_h * block_size + offset_w) * output_depth; - const int in_d = d + offset_d; - const int inp_idx = in_d + input_depth * (in_w + input_width * (in_h + input_height * b)); - (output_ptr + out_idx)[0] = (input_ptr + inp_idx)[0]; - } - }; - - samediff::Threads::parallel_for(func, 0, total_count); - } else { - const int total_count = batch_size * input_depth_by_input_area; - - auto func = PRAGMA_THREADS_FOR { - for (int input_idx = start; input_idx < stop; input_idx++) { - const int n_bY_bX_oC_iY = input_idx / input_width; - const int iX = input_idx - n_bY_bX_oC_iY * input_width; - - const int n_bY_bX = n_bY_bX_oC_iY / output_depth_by_input_height; - const int oC_iY = n_bY_bX_oC_iY - n_bY_bX * output_depth_by_input_height; - - const int n_bY = n_bY_bX / block_size; - const int bX = n_bY_bX - n_bY * block_size; - - const int n = n_bY / block_size; - const int bY = n_bY - n * block_size; - - const int output_idx = bX + block_size * (iX + input_width * (bY + block_size * (oC_iY + n * output_depth_by_input_height))); - - (output_ptr + output_idx)[0] = (input_ptr + input_idx)[0]; - } - }; - - samediff::Threads::parallel_for(func, 0, total_count); - } - } - - void _depthToSpace(sd::LaunchContext * context, const NDArray &input, NDArray *output, int block_size, bool isNHWC) { - auto xType = input.dataType(); - - BUILD_SINGLE_SELECTOR(xType, __depthToSpace, (input, output, block_size, isNHWC), LIBND4J_TYPES); - } - - BUILD_SINGLE_TEMPLATE(template void __depthToSpace, (const NDArray &input, NDArray *output, int block_size, bool isNHWC);, LIBND4J_TYPES); - +template +static void __depthToSpace(const NDArray &input, NDArray *output, + int block_size, bool isNHWC) { + T const *input_ptr = reinterpret_cast(input.buffer()); + T *output_ptr = reinterpret_cast(output->buffer()); + + const int batch_size = input.sizeAt(0); + const int input_depth = isNHWC ? input.sizeAt(3) : input.sizeAt(1); + const int input_height = isNHWC ? input.sizeAt(1) : input.sizeAt(2); + const int input_width = isNHWC ? input.sizeAt(2) : input.sizeAt(3); + + const int output_depth = isNHWC ? output->sizeAt(3) : output->sizeAt(1); + const int output_height = isNHWC ? output->sizeAt(1) : output->sizeAt(2); + const int output_width = isNHWC ? output->sizeAt(2) : output->sizeAt(3); + + const int input_area = input_width * input_height; + const int input_depth_by_input_area = input_depth * input_area; + const int output_depth_by_input_height = output_depth * input_height; + + if (isNHWC) { + const int total_count = + batch_size * output_height * output_width * output_depth; + auto func = PRAGMA_THREADS_FOR { + for (auto out_idx = start; out_idx < stop; out_idx++) { + const int d = out_idx % output_depth; + const int out_idx2 = out_idx / output_depth; + const int w = out_idx2 % output_width; + const int out_idx3 = out_idx2 / output_width; + const int h = out_idx3 % output_height; + const int b = out_idx3 / output_height; + + const int in_h = h / block_size; + const int offset_h = h % block_size; + const int in_w = w / block_size; + const int offset_w = w % block_size; + const int offset_d = (offset_h * block_size + offset_w) * output_depth; + const int in_d = d + offset_d; + const int inp_idx = + in_d + + input_depth * (in_w + input_width * (in_h + input_height * b)); + (output_ptr + out_idx)[0] = (input_ptr + inp_idx)[0]; + } + }; + + samediff::Threads::parallel_for(func, 0, total_count); + } else { + const int total_count = batch_size * input_depth_by_input_area; + + auto func = PRAGMA_THREADS_FOR { + for (int input_idx = start; input_idx < stop; input_idx++) { + const int n_bY_bX_oC_iY = input_idx / input_width; + const int iX = input_idx - n_bY_bX_oC_iY * input_width; + + const int n_bY_bX = n_bY_bX_oC_iY / output_depth_by_input_height; + const int oC_iY = + n_bY_bX_oC_iY - n_bY_bX * output_depth_by_input_height; + + const int n_bY = n_bY_bX / block_size; + const int bX = n_bY_bX - n_bY * block_size; + + const int n = n_bY / block_size; + const int bY = n_bY - n * block_size; + + const int output_idx = + bX + block_size * + (iX + input_width * + (bY + block_size * + (oC_iY + + n * output_depth_by_input_height))); + + (output_ptr + output_idx)[0] = (input_ptr + input_idx)[0]; + } + }; + + samediff::Threads::parallel_for(func, 0, total_count); + } } + +void _depthToSpace(sd::LaunchContext *context, const NDArray &input, + NDArray *output, int block_size, bool isNHWC) { + auto xType = input.dataType(); + + BUILD_SINGLE_SELECTOR(xType, __depthToSpace, + (input, output, block_size, isNHWC), LIBND4J_TYPES); } -} \ No newline at end of file + +BUILD_SINGLE_TEMPLATE(template void __depthToSpace, + (const NDArray &input, NDArray *output, int block_size, + bool isNHWC); + , LIBND4J_TYPES); + +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/diGamma.cpp b/libnd4j/include/ops/declarable/helpers/cpu/diGamma.cpp index 37abaf559883..bb53d65084ce 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/diGamma.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/diGamma.cpp @@ -19,8 +19,8 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include #include +#include namespace sd { namespace ops { @@ -30,24 +30,19 @@ namespace helpers { // calculate digamma function for array elements template static void diGamma_(const NDArray& x, NDArray& z) { - - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) - z.p(i, diGammaScalar(x.e(i))); - }; - samediff::Threads::parallel_for(func, 0, x.lengthOf()); + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) z.p(i, diGammaScalar(x.e(i))); + }; + samediff::Threads::parallel_for(func, 0, x.lengthOf()); } void diGamma(sd::LaunchContext* context, const NDArray& x, NDArray& z) { - - BUILD_SINGLE_SELECTOR(x.dataType(), diGamma_, (x, z), FLOAT_TYPES); + BUILD_SINGLE_SELECTOR(x.dataType(), diGamma_, (x, z), FLOAT_TYPES); } -BUILD_SINGLE_TEMPLATE(template void diGamma_, (const NDArray& x, NDArray& z), FLOAT_TYPES); - - - -} -} -} +BUILD_SINGLE_TEMPLATE(template void diGamma_, (const NDArray& x, NDArray& z), + FLOAT_TYPES); +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/diag.cpp b/libnd4j/include/ops/declarable/helpers/cpu/diag.cpp index 670ad532239d..351df65532b3 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/diag.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/diag.cpp @@ -25,39 +25,41 @@ namespace sd { namespace ops { namespace helpers { - ////////////////////////////////////////////////////////////////////////// // Returns a batched matrix tensor with new batched diagonal values. -// for detailed explanations please take a look on web page: https://www.tensorflow.org/api_docs/python/tf/matrix_set_diag +// for detailed explanations please take a look on web page: +// https://www.tensorflow.org/api_docs/python/tf/matrix_set_diag template static void _diagFunctor(const NDArray* input, NDArray* output) { + const int inLength = input->lengthOf(); - const int inLength = input->lengthOf(); - - for(int i = 0; i < inLength; ++i) - output->p(i * (inLength + 1), (*input).e(i)); + for (int i = 0; i < inLength; ++i) + output->p(i * (inLength + 1), (*input).e(i)); } - void diagFunctor(sd::LaunchContext * context, const NDArray* input, NDArray* output) { - auto xType = input->dataType(); - - BUILD_SINGLE_SELECTOR(xType, _diagFunctor, (input, output), LIBND4J_TYPES); - } +void diagFunctor(sd::LaunchContext* context, const NDArray* input, + NDArray* output) { + auto xType = input->dataType(); -BUILD_SINGLE_TEMPLATE(template void _diagFunctor, (const NDArray* input, NDArray* output);, LIBND4J_TYPES); - -void diagPartFunctor(sd::LaunchContext * context, NDArray const* input, NDArray* output) { - const int outLen = output->lengthOf(); - const int inLen = input->lengthOf(); - int i(0), j(0); - while (j < outLen) { - output->p(j, input->e(i)); - i += outLen + 1; - ++j; - } + BUILD_SINGLE_SELECTOR(xType, _diagFunctor, (input, output), LIBND4J_TYPES); } +BUILD_SINGLE_TEMPLATE(template void _diagFunctor, + (const NDArray* input, NDArray* output); + , LIBND4J_TYPES); +void diagPartFunctor(sd::LaunchContext* context, NDArray const* input, + NDArray* output) { + const int outLen = output->lengthOf(); + const int inLen = input->lengthOf(); + int i(0), j(0); + while (j < outLen) { + output->p(j, input->e(i)); + i += outLen + 1; + ++j; + } } -} -} \ No newline at end of file + +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/dilation2d.cpp b/libnd4j/include/ops/declarable/helpers/cpu/dilation2d.cpp index 1688dcbc421d..ee426c1fd7b7 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/dilation2d.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/dilation2d.cpp @@ -18,83 +18,87 @@ // @autkhor raver119@gmail.com // -#include #include #include +#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { ////////////////////////////////////////////////////////////////////// template -static void dilation2d_(NDArray *input, NDArray *weights, NDArray *output, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW) { - - // input [bS, iH, iW, iC] - // weights [kH, kW, iC] - // output [bS, oH, oW, iC] - - const X* x = input->bufferAsT(); - const X* y = weights->bufferAsT(); - Z* z = output->bufferAsT(); - - const Nd4jLong* xShapeInfo = input->shapeInfo(); - const Nd4jLong* yShapeInfo = weights->shapeInfo(); - const Nd4jLong* zShapeInfo = output->shapeInfo(); - - const uint bS = input->sizeAt(0); - const uint iH = input->sizeAt(1); - const uint iW = input->sizeAt(2); - const uint iC = input->sizeAt(3); - - const uint kH = weights->sizeAt(0); - const uint kW = weights->sizeAt(1); - - const uint oH = output->sizeAt(1); - const uint oW = output->sizeAt(2); - - auto func = PRAGMA_THREADS_FOR_2D { - - for (auto b = start_x; b < stop_x; b += inc_x) { - for (auto oh = start_y; oh < stop_y; oh += inc_y) { - for (uint ow = 0; ow < oW; ++ow) { - for (uint c = 0; c < iC; ++c) { - - X max = -DataTypeUtils::max(); - - for (uint kh = 0; kh < kH; ++kh) { - const int ih = oh * sH - pH + kh * dH; - if (ih < 0 || ih >= iH) continue; - - for (uint kw = 0; kw < kW; ++kw) { - const int iw = ow * sW - pW + kw * dW; - if (iw < 0 || iw >= iW) continue; - - uint xCoords[4] = { static_cast(b), static_cast(ih), static_cast(iw), c}; - uint yCoords[3] = {kh, kw, c}; - - const X val = x[shape::getOffset(xShapeInfo, xCoords)] + y[shape::getOffset(yShapeInfo, yCoords)]; - if (val > max) - max = val; - } - } - - uint zCoords[4] = { static_cast(b), static_cast(oh), ow, c}; - z[shape::getOffset(zShapeInfo, zCoords)] = static_cast(max); - } - } +static void dilation2d_(NDArray* input, NDArray* weights, NDArray* output, + const int sH, const int sW, const int pH, const int pW, + const int dH, const int dW) { + // input [bS, iH, iW, iC] + // weights [kH, kW, iC] + // output [bS, oH, oW, iC] + + const X* x = input->bufferAsT(); + const X* y = weights->bufferAsT(); + Z* z = output->bufferAsT(); + + const Nd4jLong* xShapeInfo = input->shapeInfo(); + const Nd4jLong* yShapeInfo = weights->shapeInfo(); + const Nd4jLong* zShapeInfo = output->shapeInfo(); + + const uint bS = input->sizeAt(0); + const uint iH = input->sizeAt(1); + const uint iW = input->sizeAt(2); + const uint iC = input->sizeAt(3); + + const uint kH = weights->sizeAt(0); + const uint kW = weights->sizeAt(1); + + const uint oH = output->sizeAt(1); + const uint oW = output->sizeAt(2); + + auto func = PRAGMA_THREADS_FOR_2D { + for (auto b = start_x; b < stop_x; b += inc_x) { + for (auto oh = start_y; oh < stop_y; oh += inc_y) { + for (uint ow = 0; ow < oW; ++ow) { + for (uint c = 0; c < iC; ++c) { + X max = -DataTypeUtils::max(); + + for (uint kh = 0; kh < kH; ++kh) { + const int ih = oh * sH - pH + kh * dH; + if (ih < 0 || ih >= iH) continue; + + for (uint kw = 0; kw < kW; ++kw) { + const int iw = ow * sW - pW + kw * dW; + if (iw < 0 || iw >= iW) continue; + + uint xCoords[4] = {static_cast(b), static_cast(ih), + static_cast(iw), c}; + uint yCoords[3] = {kh, kw, c}; + + const X val = x[shape::getOffset(xShapeInfo, xCoords)] + + y[shape::getOffset(yShapeInfo, yCoords)]; + if (val > max) max = val; + } } + + uint zCoords[4] = {static_cast(b), static_cast(oh), ow, + c}; + z[shape::getOffset(zShapeInfo, zCoords)] = static_cast(max); + } } - }; + } + } + }; - samediff::Threads::parallel_for(func, 0, bS, 1, 0, oH, 1); + samediff::Threads::parallel_for(func, 0, bS, 1, 0, oH, 1); } -void dilation2d(sd::LaunchContext* context, NDArray *input, NDArray *weights, NDArray *output, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW) { - BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), dilation2d_, (input, weights, output, sH, sW, pH, pW, dH, dW), FLOAT_TYPES); +void dilation2d(sd::LaunchContext* context, NDArray* input, NDArray* weights, + NDArray* output, const int sH, const int sW, const int pH, + const int pW, const int dH, const int dW) { + BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), dilation2d_, + (input, weights, output, sH, sW, pH, pW, dH, dW), + FLOAT_TYPES); } - -} -} -} \ No newline at end of file +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/dropout.cpp b/libnd4j/include/ops/declarable/helpers/cpu/dropout.cpp index 54981dea5c5e..d8b04295e3d1 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/dropout.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/dropout.cpp @@ -18,157 +18,217 @@ // @author raver119@gmail.com // -#include +#include #include -#include +#include + #include -#include +#include namespace sd { namespace ops { namespace helpers { - template - static void dropoutSimple(NDArray const* input, NDArray* output, double probValue, int seed) { - - sd::graph::RandomGenerator nodeRng(3019L, seed); - int inLen = input->lengthOf(); - - auto func = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) { - float val = nodeRng.relativeT(e, T(0.f), T(1.f)); +template +static void dropoutSimple(NDArray const* input, NDArray* output, + double probValue, int seed) { + sd::graph::RandomGenerator nodeRng(3019L, seed); + int inLen = input->lengthOf(); - if (val < probValue) - output->p(e, input->e(e) / probValue); - } - }; + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + float val = nodeRng.relativeT(e, T(0.f), T(1.f)); - samediff::Threads::parallel_for(func, 0, inLen); + if (val < probValue) output->p(e, input->e(e) / probValue); } - BUILD_SINGLE_TEMPLATE(template void dropoutSimple, (NDArray const* input, NDArray* output, double probValue, int seed), FLOAT_TYPES); - - template - int dropOutFunctor_(graph::Context& context, NDArray* input, NDArray* output, NDArray* reduceShape, int seed, double probValue) { - //NativeOps native; - //sd::graph::RandomGenerator nodeRng(seed); //static int dropOutFunctor_(sd::random::RandomBuffer* rng, NDArray* input, NDArray* output, NDArray* reduceShape, int seed, double probValue) { - //NativeOps native; - //native.reSeedBuffer(nullptr, (long)seed, rng); - //if (newRng ) - if (reduceShape == nullptr){ - dropoutSimple(input, output, probValue, seed); - } - else { - REQUIRE_TRUE(reduceShape->lengthOf() <= input->rankOf(), 0, "dropout: Noise shape should be fittable to input"); - - std::vector dims(reduceShape->lengthOf()); - - bool fit = true; - for(auto i = 0; i < dims.size(); i++ ) { - if (fit) { - dims[i] = reduceShape->e(i); - for (int e = 0; e < input->rankOf(); ++e) - if (fit) - if (input->sizeAt(e) % dims[i]) { - fit = false; - } - } - } - - // check dims to fit input - REQUIRE_TRUE(fit, 0, "dropout: Noise shape should fit to input rank."); - std::unique_ptr chunk(new NDArray('c', dims, output->dataType(), output->getContext())); - chunk->assign(1.f); - //chunk->applyRandom>(rng, nullptr, chunk.get(), &probValue); - //NativeOpExecutioner::execRandom(random::DropOutInverted, rng, chunk->buffer(), chunk->shapeInfo(), chunk->buffer(), chunk->shapeInfo(), &prob); - dropoutSimple(chunk.get(), chunk.get(), probValue, seed); - // broadcast chunk to full matrix - std::unique_ptr dropOutMultiplier(new NDArray(*input)); - dropOutMultiplier->assign(1.f); - - *dropOutMultiplier += *chunk; - - output->assign(*input * *dropOutMultiplier); //input->applyPairwiseTransform(pairwise::Multiply, dropOutMultiplier.get(), output, nullptr); - } + }; - return Status::OK(); - } - - int dropOutFunctor(graph::Context& context, NDArray* input, NDArray* output, NDArray* reduceShape, int seed, double probValue) { - auto xType = input->dataType(); - - BUILD_SINGLE_SELECTOR(xType, return dropOutFunctor_, (context, input, output, reduceShape, seed, probValue), FLOAT_TYPES); - } - - BUILD_SINGLE_TEMPLATE(template int dropOutFunctor_, (graph::Context& context, NDArray* input, NDArray* output, NDArray* reduceShape, int seed, double probValue);, FLOAT_TYPES); - -/////////////////////////////////// backrpopagations /////////////////////////////////////////////// - template - static int dropOutFunctorBP_(graph::Context& context, NDArray* input, NDArray* gradOut, NDArray* output, NDArray* reduceShape, int seed, double probValue) { - - int res = dropOutFunctor(context, input, output, reduceShape, seed, probValue); - - if (ND4J_STATUS_OK == res) - for (Nd4jLong e = 0; e < output->lengthOf(); e++) { - if (output->e(e) != 0.f) output->p(e, gradOut->e(e) / probValue); -// else (*output)(e) = T(0.f); + samediff::Threads::parallel_for(func, 0, inLen); +} +BUILD_SINGLE_TEMPLATE(template void dropoutSimple, + (NDArray const* input, NDArray* output, double probValue, + int seed), + FLOAT_TYPES); + +template +int dropOutFunctor_(graph::Context& context, NDArray* input, NDArray* output, + NDArray* reduceShape, int seed, double probValue) { + // NativeOps native; + // sd::graph::RandomGenerator nodeRng(seed); //static int + // dropOutFunctor_(sd::random::RandomBuffer* rng, NDArray* input, NDArray* + // output, NDArray* reduceShape, int seed, double probValue) { NativeOps + // native; native.reSeedBuffer(nullptr, (long)seed, rng); if (newRng ) + if (reduceShape == nullptr) { + dropoutSimple(input, output, probValue, seed); + } else { + REQUIRE_TRUE(reduceShape->lengthOf() <= input->rankOf(), 0, + "dropout: Noise shape should be fittable to input"); + + std::vector dims(reduceShape->lengthOf()); + + bool fit = true; + for (auto i = 0; i < dims.size(); i++) { + if (fit) { + dims[i] = reduceShape->e(i); + for (int e = 0; e < input->rankOf(); ++e) + if (fit) + if (input->sizeAt(e) % dims[i]) { + fit = false; } - - return res; + } } - template - static int alphaDropOutFunctor_(graph::Context& context, NDArray* input, NDArray* output, - NDArray* reduceShape, int seed, double probValue, double alpha, double alpha1, double beta) { - - //NativeOps native; - //auto rng = context.getRNG(); - //native.reSeedBuffer(nullptr, (long)seed, rng); - //if (rng == nullptr) - // return ND4J_STATUS_BAD_RNG; - //T probValueArr[] = {probValue, alpha, alpha1, beta}; - //input->template applyRandom>(rng, nullptr, output, probValueArr); - sd::graph::RandomGenerator nodeRng(3019L, seed); - - auto func = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) { - float randVal = nodeRng.relativeT(e, T(0.f), T(1.f)); - float xVal = input->e(e); - output->p(e, randVal >= probValue ? alpha * beta + alpha1 : alpha * xVal + alpha1); - } - }; + // check dims to fit input + REQUIRE_TRUE(fit, 0, "dropout: Noise shape should fit to input rank."); + std::unique_ptr chunk( + new NDArray('c', dims, output->dataType(), output->getContext())); + chunk->assign(1.f); + // chunk->applyRandom>(rng, nullptr, + // chunk.get(), &probValue); + // NativeOpExecutioner::execRandom(random::DropOutInverted, rng, + // chunk->buffer(), chunk->shapeInfo(), chunk->buffer(), chunk->shapeInfo(), + // &prob); + dropoutSimple(chunk.get(), chunk.get(), probValue, seed); + // broadcast chunk to full matrix + std::unique_ptr dropOutMultiplier(new NDArray(*input)); + dropOutMultiplier->assign(1.f); + + *dropOutMultiplier += *chunk; + + output->assign( + *input * + *dropOutMultiplier); // input->applyPairwiseTransform(pairwise::Multiply, + // dropOutMultiplier.get(), output, nullptr); + } + + return Status::OK(); +} - samediff::Threads::parallel_for(func, 0, input->lengthOf()); +int dropOutFunctor(graph::Context& context, NDArray* input, NDArray* output, + NDArray* reduceShape, int seed, double probValue) { + auto xType = input->dataType(); - return Status::OK(); + BUILD_SINGLE_SELECTOR(xType, return dropOutFunctor_, + (context, input, output, reduceShape, seed, probValue), + FLOAT_TYPES); +} + +BUILD_SINGLE_TEMPLATE(template int dropOutFunctor_, + (graph::Context & context, NDArray* input, + NDArray* output, NDArray* reduceShape, int seed, + double probValue); + , FLOAT_TYPES); + +/////////////////////////////////// backrpopagations +////////////////////////////////////////////////// +template +static int dropOutFunctorBP_(graph::Context& context, NDArray* input, + NDArray* gradOut, NDArray* output, + NDArray* reduceShape, int seed, double probValue) { + int res = + dropOutFunctor(context, input, output, reduceShape, seed, probValue); + + if (ND4J_STATUS_OK == res) + for (Nd4jLong e = 0; e < output->lengthOf(); e++) { + if (output->e(e) != 0.f) + output->p(e, gradOut->e(e) / probValue); + // else (*output)(e) = T(0.f); } - template - int alphaDropOutFunctorBP_(graph::Context& context, NDArray* input, NDArray* gradOut, NDArray* output, - NDArray* reduceShape, int seed, double probValue, double alpha, double alpha1, double beta) { + return res; +} - int res = alphaDropOutFunctor(context, input, output, reduceShape, seed, probValue, alpha, alpha1, beta); - if (res == ND4J_STATUS_OK) { - (*output) *= alpha; - (*output) *= (*gradOut); //->applyPairwiseTransform(gradOut, output, nullptr); - } - return res; +template +static int alphaDropOutFunctor_(graph::Context& context, NDArray* input, + NDArray* output, NDArray* reduceShape, int seed, + double probValue, double alpha, double alpha1, + double beta) { + // NativeOps native; + // auto rng = context.getRNG(); + // native.reSeedBuffer(nullptr, (long)seed, rng); + // if (rng == nullptr) + // return ND4J_STATUS_BAD_RNG; + // T probValueArr[] = {probValue, alpha, alpha1, beta}; + // input->template applyRandom>(rng, nullptr, + // output, probValueArr); + sd::graph::RandomGenerator nodeRng(3019L, seed); + + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + float randVal = nodeRng.relativeT(e, T(0.f), T(1.f)); + float xVal = input->e(e); + output->p(e, randVal >= probValue ? alpha * beta + alpha1 + : alpha * xVal + alpha1); } + }; - int dropOutFunctorBP(graph::Context& context, NDArray* input, NDArray* gradOut, NDArray* output, NDArray* reduceShape, int seed, double probValue) { - BUILD_SINGLE_SELECTOR(context.dataType(), return dropOutFunctorBP_, (context, input, gradOut, output, reduceShape, seed, probValue), FLOAT_TYPES); - } - BUILD_SINGLE_TEMPLATE(template int dropOutFunctorBP_, (graph::Context& context, NDArray* input, NDArray* gradOut, NDArray* output, NDArray* reduceShape, int seed, double probValue), FLOAT_TYPES); + samediff::Threads::parallel_for(func, 0, input->lengthOf()); - int alphaDropOutFunctor(graph::Context& context, NDArray* input, NDArray* output, NDArray* reduceShape, int seed, double probValue, double alpha, double alpha1, double beta) { - BUILD_SINGLE_SELECTOR(context.dataType(), return alphaDropOutFunctor_, (context, input, output, reduceShape, seed, probValue, alpha, alpha1, beta), FLOAT_TYPES); - } - BUILD_SINGLE_TEMPLATE(template int alphaDropOutFunctor_, (graph::Context& context, NDArray* input, NDArray* output, NDArray* reduceShape, int seed, double probValue, double alpha, double alpha1, double beta), FLOAT_TYPES); + return Status::OK(); +} - int alphaDropOutFunctorBP(graph::Context& context, NDArray* input, NDArray* gradOut, NDArray* output, NDArray* reduceShape, int seed, double probValue, double alpha, double alpha1, double beta) { - BUILD_SINGLE_SELECTOR(context.dataType(), return alphaDropOutFunctorBP_, (context, input, gradOut, output, reduceShape, seed, probValue, alpha, alpha1, beta), FLOAT_TYPES); - } - BUILD_SINGLE_TEMPLATE(template int alphaDropOutFunctorBP_, (graph::Context& context, NDArray* input, NDArray* gradOut, NDArray* output, NDArray* reduceShape, int seed, double probValue, double alpha, double alpha1, double beta), FLOAT_TYPES); +template +int alphaDropOutFunctorBP_(graph::Context& context, NDArray* input, + NDArray* gradOut, NDArray* output, + NDArray* reduceShape, int seed, double probValue, + double alpha, double alpha1, double beta) { + int res = alphaDropOutFunctor(context, input, output, reduceShape, seed, + probValue, alpha, alpha1, beta); + if (res == ND4J_STATUS_OK) { + (*output) *= alpha; + (*output) *= + (*gradOut); //->applyPairwiseTransform(gradOut, + // output, nullptr); + } + return res; +} +int dropOutFunctorBP(graph::Context& context, NDArray* input, NDArray* gradOut, + NDArray* output, NDArray* reduceShape, int seed, + double probValue) { + BUILD_SINGLE_SELECTOR( + gradOut->dataType(), return dropOutFunctorBP_, + (context, input, gradOut, output, reduceShape, seed, probValue), + FLOAT_TYPES); +} +BUILD_SINGLE_TEMPLATE(template int dropOutFunctorBP_, + (graph::Context & context, NDArray* input, + NDArray* gradOut, NDArray* output, NDArray* reduceShape, + int seed, double probValue), + FLOAT_TYPES); + +int alphaDropOutFunctor(graph::Context& context, NDArray* input, + NDArray* output, NDArray* reduceShape, int seed, + double probValue, double alpha, double alpha1, + double beta) { + BUILD_SINGLE_SELECTOR(output->dataType(), return alphaDropOutFunctor_, + (context, input, output, reduceShape, seed, probValue, + alpha, alpha1, beta), + FLOAT_TYPES); } +BUILD_SINGLE_TEMPLATE(template int alphaDropOutFunctor_, + (graph::Context & context, NDArray* input, + NDArray* output, NDArray* reduceShape, int seed, + double probValue, double alpha, double alpha1, + double beta), + FLOAT_TYPES); + +int alphaDropOutFunctorBP(graph::Context& context, NDArray* input, + NDArray* gradOut, NDArray* output, + NDArray* reduceShape, int seed, double probValue, + double alpha, double alpha1, double beta) { + BUILD_SINGLE_SELECTOR(gradOut->dataType(), return alphaDropOutFunctorBP_, + (context, input, gradOut, output, reduceShape, seed, + probValue, alpha, alpha1, beta), + FLOAT_TYPES); } -} \ No newline at end of file +BUILD_SINGLE_TEMPLATE(template int alphaDropOutFunctorBP_, + (graph::Context & context, NDArray* input, + NDArray* gradOut, NDArray* output, NDArray* reduceShape, + int seed, double probValue, double alpha, double alpha1, + double beta), + FLOAT_TYPES); + +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/dynamic.cpp b/libnd4j/include/ops/declarable/helpers/cpu/dynamic.cpp index 89cf680d470b..62056794eea3 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/dynamic.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/dynamic.cpp @@ -17,209 +17,258 @@ // // Created by george on 05.04.18. // -#include #include +#include namespace sd { - namespace ops { - namespace helpers { - - template - static void _dynamicPartitionFunctor(NDArray const* input, NDArray const* indices, std::vector& outputList) { - std::vector> outputs(outputList.size()); - int sourceDimsLen = input->rankOf() - indices->rankOf(); - if (sourceDimsLen) { - std::vector sourceDims(sourceDimsLen); - - for (int i = sourceDimsLen; i > 0; i--) - sourceDims[sourceDimsLen - i] = input->rankOf() - i; - - ResultSet listOfTensors = input->allTensorsAlongDimension(sourceDims); - - unsigned int outSize = outputList.size(); - - //PRAGMA_OMP_PARALLEL_FOR_IF(outSize > Environment::getInstance().tadThreshold()) - for (unsigned int i = 0; i < outSize; i++) { - outputs[i].first = outputList[i]; - std::vector outDims(outputs[i].first->rankOf() - 1); - - int r = outputs[i].first->rankOf(); - - for (int k = 1; k < r; k++) - outDims[k - 1] = k; - - ResultSet listOutForCurrent = outputs[i].first->allTensorsAlongDimension(outDims); - - outputs[i].second = 0; - - //PRAGMA_OMP_PARALLEL_FOR_IF(indices->lengthOf() > Environment::getInstance().elementwiseThreshold()) - for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) - if ((*indices).e(e) == i) - listOutForCurrent.at(outputs[i].second++)->assign(listOfTensors.at(e)); - } - - } else { - unsigned int outSize = outputList.size(); - - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - outputs[i].first = outputList[i]; - outputs[i].second = 0; - for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) - if (indices->e(e) == i) - outputs[i].first->p(outputs[i].second++, input->e(e)); - } - }; - - samediff::Threads::parallel_tad(func, 0, outSize); - } - } - template - static int _dynamicStitchFunctor(std::vector const& inputs, std::vector const& indices, NDArray* output){ - - int numOfData = inputs.size(); - - if (output->isVector()) { - for (int e = 0; e < numOfData; e++) { - auto data = inputs[e]; - auto index = indices[e]; - for (Nd4jLong i = 0; i < index->lengthOf(); i++) { - Nd4jLong pos = index->e(i); - if (pos < 0) { - nd4j_printf("dynamic_stitch: Index value should be non-negative. But %i was given", pos); - return ND4J_STATUS_VALIDATION; - } - if (pos >= output->lengthOf()) { - nd4j_printf("dynamic_stitch: Index should be less than %i. But %i was given", - output->lengthOf(), pos); - return ND4J_STATUS_VALIDATION; - } - output->p(pos, data->e(i)); - } - } - } - else { - std::vector restDims(output->rankOf() - 1); - for (auto i = restDims.size(); i > 0; i--) - restDims[restDims.size() - i] = output->rankOf() - i; - - ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); - - for (int e = 0; e < numOfData; e++) { - auto data = inputs[e]; - auto index = indices[e]; - std::vector sourceDims(data->rankOf() - index->rankOf()); - for (auto i = sourceDims.size(); i > 0; i--) - sourceDims[sourceDims.size() - i] = data->rankOf() - i; - - ResultSet listOfTensors = data->allTensorsAlongDimension(sourceDims) ; - - for (Nd4jLong i = 0; i < index->lengthOf(); i++) { - auto pos = index->e(i); - if (pos < 0) { - nd4j_printf("dynamic_stitch: Index value should be non-negative. But %i was given", pos); - return ND4J_STATUS_VALIDATION; - } - if (pos >= output->lengthOf()) { - nd4j_printf("dynamic_stitch: Index should be less than %i. But %i was given", - output->lengthOf(), pos); - return ND4J_STATUS_VALIDATION; - } - - listOfOutTensors.at(pos)->assign(listOfTensors.at(i)); - } - } - } - return ND4J_STATUS_OK; - } - - template - static void _dynamicPartitionFunctorBP(NDArray const* input, NDArray const* indices, std::vector const& inputGradientList, std::vector& outputList) { - std::vector> outputs(inputGradientList.size()); - - int sourceDimsLen = input->rankOf() - indices->rankOf(); - if (sourceDimsLen) { // multidimensional case - std::vector sourceDims(sourceDimsLen); - - for (int i = sourceDimsLen; i > 0; i--) - sourceDims[sourceDimsLen - i] = input->rankOf() - i; - - ResultSet listOfTensors = outputList[0]->allTensorsAlongDimension(sourceDims); - - for (auto i = 0; i < inputGradientList.size(); i++) { - outputs[i].first = inputGradientList[i]; - if (outputs[i].first->rankOf() < 1) continue; // skip empty gradient outs - std::vector outDims(outputs[i].first->rankOf() - 1); - - for (int k = 1; k < outputs[i].first->rankOf(); k++) - outDims[k - 1] = k; - - ResultSet listOutForCurrent = outputs[i].first->allTensorsAlongDimension(outDims); - - outputs[i].second = 0; - - for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) - if (indices->e(e) == i) - listOfTensors.at(e)->assign(listOutForCurrent.at(outputs[i].second++)); - } - } - else { // one-dimensional case - auto output = outputList[0]; - unsigned int gradsSize = inputGradientList.size(); - - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - outputs[i].first = inputGradientList[i]; - outputs[i].second = 0; - for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) - if (indices->e(e) == i) - output->p(e, outputs[i].first->e(outputs[i].second++)); - } - }; - - samediff::Threads::parallel_tad(func, 0, gradsSize); - } - - outputList[1]->assign(indices); - } - - void dynamicPartitionFunctor(sd::LaunchContext * context, NDArray const* input, NDArray const* indices, std::vector& outputList) { - auto xType = input->dataType(); - - BUILD_SINGLE_SELECTOR(xType, _dynamicPartitionFunctor, (input, indices, outputList), LIBND4J_TYPES); - } - - template - static int _dynamicStitchFunctorBP(std::vector const& inputs, std::vector const& indices, NDArray const* gradInput, std::vector& outputList){ - throw std::runtime_error("Not umplemented yet"); - } - - int dynamicStitchFunctor(sd::LaunchContext * context, std::vector const& inputs, std::vector const& indices, NDArray* output){ - auto xType = inputs.at(0)->dataType(); +namespace ops { +namespace helpers { + +template +static void _dynamicPartitionFunctor(NDArray const* input, + NDArray const* indices, + std::vector& outputList) { + std::vector> outputs(outputList.size()); + int sourceDimsLen = input->rankOf() - indices->rankOf(); + if (sourceDimsLen) { + std::vector sourceDims(sourceDimsLen); - BUILD_SINGLE_SELECTOR(xType, return _dynamicStitchFunctor, (inputs, indices, output), LIBND4J_TYPES); - } + for (int i = sourceDimsLen; i > 0; i--) + sourceDims[sourceDimsLen - i] = input->rankOf() - i; - int dynamicStitchFunctorBP(sd::LaunchContext * context, std::vector const& inputs, std::vector const& indices, NDArray const* gradInput, std::vector& outputList) { - auto xType = inputs.at(0)->dataType(); + ResultSet listOfTensors = input->allTensorsAlongDimension(sourceDims); - BUILD_SINGLE_SELECTOR(xType, return _dynamicStitchFunctorBP, (inputs, indices, gradInput, outputList), LIBND4J_TYPES); - } + unsigned int outSize = outputList.size(); - void dynamicPartitionFunctorBP(sd::LaunchContext * context, NDArray const* input, NDArray const* indices, std::vector const& inputGradientList, std::vector& outputList) { - auto xType = input->dataType(); + // PRAGMA_OMP_PARALLEL_FOR_IF(outSize > + // Environment::getInstance().tadThreshold()) + for (unsigned int i = 0; i < outSize; i++) { + outputs[i].first = outputList[i]; + std::vector outDims(outputs[i].first->rankOf() - 1); - BUILD_SINGLE_SELECTOR(xType, _dynamicPartitionFunctorBP, (input, indices, inputGradientList, outputList), LIBND4J_TYPES); - } + int r = outputs[i].first->rankOf(); - BUILD_SINGLE_TEMPLATE(template void _dynamicPartitionFunctorBP, (NDArray const* input, NDArray const* indices, std::vector const& inputGradientList, std::vector& outputList);, LIBND4J_TYPES); - BUILD_SINGLE_TEMPLATE(template int _dynamicStitchFunctorBP, (std::vector const& inputs, std::vector const& indices, NDArray const* gradInput, std::vector& outputList);, LIBND4J_TYPES); + for (int k = 1; k < r; k++) outDims[k - 1] = k; - BUILD_SINGLE_TEMPLATE(template void _dynamicPartitionFunctor, (NDArray const* input, NDArray const* indices, std::vector& outputList);, LIBND4J_TYPES); - BUILD_SINGLE_TEMPLATE(template int _dynamicStitchFunctor, (std::vector const& inputs, std::vector const& indices, NDArray* output);, LIBND4J_TYPES); + ResultSet listOutForCurrent = + outputs[i].first->allTensorsAlongDimension(outDims); + outputs[i].second = 0; + + // PRAGMA_OMP_PARALLEL_FOR_IF(indices->lengthOf() > + // Environment::getInstance().elementwiseThreshold()) + for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) + if ((*indices).e(e) == i) + listOutForCurrent.at(outputs[i].second++).assign(listOfTensors.at(e)); + } + } else { + unsigned int outSize = outputList.size(); + + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + outputs[i].first = outputList[i]; + outputs[i].second = 0; + for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) + if (indices->e(e) == i) + outputs[i].first->p(outputs[i].second++, input->e(e)); + } + }; + + samediff::Threads::parallel_tad(func, 0, outSize); + } +} +template +static int _dynamicStitchFunctor(std::vector const& inputs, + std::vector const& indices, + NDArray* output) { + int numOfData = inputs.size(); + + if (output->isVector()) { + for (int e = 0; e < numOfData; e++) { + auto data = inputs[e]; + auto index = indices[e]; + for (Nd4jLong i = 0; i < index->lengthOf(); i++) { + Nd4jLong pos = index->e(i); + if (pos < 0) { + nd4j_printf( + "dynamic_stitch: Index value should be non-negative. But %i was " + "given", + pos); + return ND4J_STATUS_VALIDATION; + } + if (pos >= output->lengthOf()) { + nd4j_printf( + "dynamic_stitch: Index should be less than %i. But %i was given", + output->lengthOf(), pos); + return ND4J_STATUS_VALIDATION; + } + output->p(pos, data->e(i)); + } + } + } else { + std::vector restDims(output->rankOf() - 1); + for (auto i = restDims.size(); i > 0; i--) + restDims[restDims.size() - i] = output->rankOf() - i; + + ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); + + for (int e = 0; e < numOfData; e++) { + auto data = inputs[e]; + auto index = indices[e]; + std::vector sourceDims(data->rankOf() - index->rankOf()); + for (auto i = sourceDims.size(); i > 0; i--) + sourceDims[sourceDims.size() - i] = data->rankOf() - i; + + ResultSet listOfTensors = data->allTensorsAlongDimension(sourceDims); + + for (Nd4jLong i = 0; i < index->lengthOf(); i++) { + auto pos = index->e(i); + if (pos < 0) { + nd4j_printf( + "dynamic_stitch: Index value should be non-negative. But %i was " + "given", + pos); + return ND4J_STATUS_VALIDATION; } + if (pos >= output->lengthOf()) { + nd4j_printf( + "dynamic_stitch: Index should be less than %i. But %i was given", + output->lengthOf(), pos); + return ND4J_STATUS_VALIDATION; + } + + listOfOutTensors.at(pos).assign(listOfTensors.at(i)); + } + } + } + return ND4J_STATUS_OK; +} + +template +static void _dynamicPartitionFunctorBP( + NDArray const* input, NDArray const* indices, + std::vector const& inputGradientList, + std::vector& outputList) { + std::vector> outputs(inputGradientList.size()); + + int sourceDimsLen = input->rankOf() - indices->rankOf(); + if (sourceDimsLen) { // multidimensional case + std::vector sourceDims(sourceDimsLen); + + for (int i = sourceDimsLen; i > 0; i--) + sourceDims[sourceDimsLen - i] = input->rankOf() - i; + + ResultSet listOfTensors = + outputList[0]->allTensorsAlongDimension(sourceDims); + + for (auto i = 0; i < inputGradientList.size(); i++) { + outputs[i].first = inputGradientList[i]; + if (outputs[i].first->rankOf() < 1) continue; // skip empty gradient outs + std::vector outDims(outputs[i].first->rankOf() - 1); + + for (int k = 1; k < outputs[i].first->rankOf(); k++) outDims[k - 1] = k; + + ResultSet listOutForCurrent = + outputs[i].first->allTensorsAlongDimension(outDims); + + outputs[i].second = 0; + + for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) + if (indices->e(e) == i) + listOfTensors.at(e).assign(listOutForCurrent.at(outputs[i].second++)); } + } else { // one-dimensional case + auto output = outputList[0]; + unsigned int gradsSize = inputGradientList.size(); + + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + outputs[i].first = inputGradientList[i]; + outputs[i].second = 0; + for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) + if (indices->e(e) == i) + output->p(e, outputs[i].first->e(outputs[i].second++)); + } + }; + + samediff::Threads::parallel_tad(func, 0, gradsSize); + } + + outputList[1]->assign(indices); +} + +void dynamicPartitionFunctor(sd::LaunchContext* context, NDArray const* input, + NDArray const* indices, + std::vector& outputList) { + auto xType = input->dataType(); + + BUILD_SINGLE_SELECTOR(xType, _dynamicPartitionFunctor, + (input, indices, outputList), LIBND4J_TYPES); +} + +template +static int _dynamicStitchFunctorBP(std::vector const& inputs, + std::vector const& indices, + NDArray const* gradInput, + std::vector& outputList) { + throw std::runtime_error("Not umplemented yet"); +} + +int dynamicStitchFunctor(sd::LaunchContext* context, + std::vector const& inputs, + std::vector const& indices, + NDArray* output) { + auto xType = inputs.at(0)->dataType(); + + BUILD_SINGLE_SELECTOR(xType, return _dynamicStitchFunctor, + (inputs, indices, output), LIBND4J_TYPES); +} + +int dynamicStitchFunctorBP(sd::LaunchContext* context, + std::vector const& inputs, + std::vector const& indices, + NDArray const* gradInput, + std::vector& outputList) { + auto xType = inputs.at(0)->dataType(); + + BUILD_SINGLE_SELECTOR(xType, return _dynamicStitchFunctorBP, + (inputs, indices, gradInput, outputList), + LIBND4J_TYPES); +} + +void dynamicPartitionFunctorBP(sd::LaunchContext* context, NDArray const* input, + NDArray const* indices, + std::vector const& inputGradientList, + std::vector& outputList) { + auto xType = input->dataType(); + + BUILD_SINGLE_SELECTOR(xType, _dynamicPartitionFunctorBP, + (input, indices, inputGradientList, outputList), + LIBND4J_TYPES); } +BUILD_SINGLE_TEMPLATE(template void _dynamicPartitionFunctorBP, + (NDArray const* input, NDArray const* indices, + std::vector const& inputGradientList, + std::vector& outputList); + , LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template int _dynamicStitchFunctorBP, + (std::vector const& inputs, + std::vector const& indices, + NDArray const* gradInput, + std::vector& outputList); + , LIBND4J_TYPES); + +BUILD_SINGLE_TEMPLATE(template void _dynamicPartitionFunctor, + (NDArray const* input, NDArray const* indices, + std::vector& outputList); + , LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template int _dynamicStitchFunctor, + (std::vector const& inputs, + std::vector const& indices, NDArray* output); + , LIBND4J_TYPES); + +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/extract_patches.cpp b/libnd4j/include/ops/declarable/helpers/cpu/extract_patches.cpp index ba04fd9aac19..9de2b03103d2 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/extract_patches.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/extract_patches.cpp @@ -18,41 +18,43 @@ // @author sgazeos@gmail.com // -#include #include +#include namespace sd { namespace ops { namespace helpers { - template - static void _extractPatches(NDArray* images, NDArray* output, int sizeRow, int sizeCol, int strideRow, int strideCol, int rateRow, int rateCol, bool theSame){ - std::vector restDims({1, 2, 3}); // the first and the last dims - ResultSet listOfMatricies = images->allTensorsAlongDimension(restDims); - ResultSet listOfOutputs = output->allTensorsAlongDimension(restDims); - // 3D matricies - 2D matricies of vectors (if last dim is greater than 1) - //int e = 0; - const int ksizeRowsEffective = sizeRow + (sizeRow - 1) * (rateRow - 1); - const int ksizeColsEffective = sizeCol + (sizeCol - 1) * (rateCol - 1); - const int ksize = ksizeRowsEffective * ksizeColsEffective; - int batchCount = listOfMatricies.size(); //lengthOf() / ksize; - Nd4jLong lastDim = images->sizeAt(3); - Nd4jLong outLastDim = output->sizeAt(3); - Nd4jLong rowDim = images->sizeAt(1); - Nd4jLong colDim = images->sizeAt(2); - Nd4jLong outRowDim = output->sizeAt(1); - Nd4jLong outColDim = output->sizeAt(2); - auto rowCast = 1; //(sizeRow - 1)*rateRow < outRowDim/sizeRow ?0:1;///(ksize * lastDim > rowDim * ksizeColsEffective + lastDim?1:0); - auto colCast = 1; //colDim / ksizeColsEffective +2 <= sizeCol?0:1;//(ksize * lastDim > ksizeRowsEffective * colDim + lastDim?1:0); - if (sizeRow * rateRow < 3) - rowCast = 0; - if (sizeCol * rateCol < 3) - colCast = 0; +template +static void _extractPatches(NDArray* images, NDArray* output, int sizeRow, + int sizeCol, int strideRow, int strideCol, + int rateRow, int rateCol, bool theSame) { + std::vector restDims({1, 2, 3}); // the first and the last dims + ResultSet listOfMatricies = images->allTensorsAlongDimension(restDims); + ResultSet listOfOutputs = output->allTensorsAlongDimension(restDims); + // 3D matricies - 2D matricies of vectors (if last dim is greater than 1) + // int e = 0; + const int ksizeRowsEffective = sizeRow + (sizeRow - 1) * (rateRow - 1); + const int ksizeColsEffective = sizeCol + (sizeCol - 1) * (rateCol - 1); + const int ksize = ksizeRowsEffective * ksizeColsEffective; + int batchCount = listOfMatricies.size(); // lengthOf() / ksize; + Nd4jLong lastDim = images->sizeAt(3); + Nd4jLong outLastDim = output->sizeAt(3); + Nd4jLong rowDim = images->sizeAt(1); + Nd4jLong colDim = images->sizeAt(2); + Nd4jLong outRowDim = output->sizeAt(1); + Nd4jLong outColDim = output->sizeAt(2); + auto rowCast = 1; //(sizeRow - 1)*rateRow < outRowDim/sizeRow ?0:1;///(ksize + //* lastDim > rowDim * ksizeColsEffective + lastDim?1:0); + auto colCast = 1; // colDim / ksizeColsEffective +2 <= sizeCol?0:1;//(ksize * + // lastDim > ksizeRowsEffective * colDim + lastDim?1:0); + if (sizeRow * rateRow < 3) rowCast = 0; + if (sizeCol * rateCol < 3) colCast = 0; - auto func = PRAGMA_THREADS_FOR { - for (auto batch = 0; batch < stop; batch++) { - auto patch = listOfMatricies.at(batch); - auto outMatrix = listOfOutputs.at(batch); + auto func = PRAGMA_THREADS_FOR { + for (auto batch = 0; batch < stop; batch++) { + auto patch = listOfMatricies.at(batch); + auto outMatrix = listOfOutputs.at(batch); for (Nd4jLong i = 0; i < outRowDim; i++) { for (Nd4jLong j = 0; j < outColDim; j++) { @@ -73,7 +75,7 @@ namespace helpers { bool setUp = (theSame && row >= 0 && col >= 0 && row < rowDim && col < colDim) || (!theSame); if (setUp) { - outMatrix->r(i, j, pos) = patch->e(row, col, pixel); + outMatrix.r(i, j, pos) = patch.e(row, col, pixel); } pos++; } @@ -82,18 +84,26 @@ namespace helpers { } }; - samediff::Threads::parallel_tad(func, 0, batchCount); - } - + samediff::Threads::parallel_tad(func, 0, batchCount); +} - void extractPatches(sd::LaunchContext * context, NDArray* images, NDArray* output, int sizeRow, int sizeCol, int stradeRow, int stradeCol, int rateRow, int rateCol, bool theSame){ - auto xType = images->dataType(); +void extractPatches(sd::LaunchContext* context, NDArray* images, + NDArray* output, int sizeRow, int sizeCol, int stradeRow, + int stradeCol, int rateRow, int rateCol, bool theSame) { + auto xType = images->dataType(); - BUILD_SINGLE_SELECTOR(xType, _extractPatches, (images, output, sizeRow, sizeCol, stradeRow, stradeCol, rateRow, rateCol, theSame), LIBND4J_TYPES); - } + BUILD_SINGLE_SELECTOR(xType, _extractPatches, + (images, output, sizeRow, sizeCol, stradeRow, stradeCol, + rateRow, rateCol, theSame), + LIBND4J_TYPES); +} - BUILD_SINGLE_TEMPLATE(template void _extractPatches, (NDArray* input, NDArray* output, int sizeRow, int sizeCol, int stradeRow, int stradeCol, int rateRow, int rateCol, bool theSame), LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template void _extractPatches, + (NDArray * input, NDArray* output, int sizeRow, + int sizeCol, int stradeRow, int stradeCol, int rateRow, + int rateCol, bool theSame), + LIBND4J_TYPES); -} -} -} \ No newline at end of file +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/eye.cpp b/libnd4j/include/ops/declarable/helpers/cpu/eye.cpp index 30a83b8713c9..1a7fc19e758d 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/eye.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/eye.cpp @@ -18,28 +18,25 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 20.04.2018 // - -#include #include +#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { ////////////////////////////////////////////////////////////////////////// -void eye(sd::LaunchContext * context, NDArray& output) { - - const int rank = output.rankOf(); - auto arrs = output.allTensorsAlongDimension({rank-2, rank-1}); +void eye(sd::LaunchContext* context, NDArray& output) { + const int rank = output.rankOf(); + auto arrs = output.allTensorsAlongDimension({rank - 2, rank - 1}); - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) - arrs.at(i)->setIdentity(); - }; + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) arrs.at(i).setIdentity(); + }; - samediff::Threads::parallel_tad(func, 0, arrs.size()); + samediff::Threads::parallel_tad(func, 0, arrs.size()); } -} -} -} +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/fake_quantization.cpp b/libnd4j/include/ops/declarable/helpers/cpu/fake_quantization.cpp index 7317f8a73baa..f36ea457e117 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/fake_quantization.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/fake_quantization.cpp @@ -18,105 +18,124 @@ // @author sgazeos@gmail.com // -#include #include +#include namespace sd { namespace ops { namespace helpers { - // - // nudge - nudged min max over scale - // scale = (Max - Min) / (quantMax - quantMin) - // quantMin = 0 or 1, quantMax = 2^b - 1 == (1 << b) - 1 - // - template - static void nudge(T min, T max, int quantMin, int quantMax, T* scale, T* nudgedMin, T* nudgedMax) { - // floating point instead integers - T quantMaxF = static_cast(quantMax); - T quantMinF = static_cast(quantMin); - // compute scale - *scale = (max - min) / (quantMaxF - quantMinF); - // compute left bound point - auto zeroPointFromMin = quantMinF - min / *scale; - // bound zero point to conform with range [0 or 1, 2^b - 1] - uint16_t const nudged_zero_point = [zeroPointFromMin, quantMin, quantMax, quantMaxF, quantMinF] { - if (zeroPointFromMin < quantMinF) { - return static_cast(quantMin); - } - if (zeroPointFromMin > quantMaxF) { - return static_cast(quantMax); - } - return (uint16_t)sd::math::nd4j_round(zeroPointFromMin); - }(); - // compute nudged min and max with computed nudged zero point - *nudgedMin = (quantMinF - nudged_zero_point) * (*scale); - *nudgedMax = (quantMaxF - nudged_zero_point) * (*scale); +// +// nudge - nudged min max over scale +// scale = (Max - Min) / (quantMax - quantMin) +// quantMin = 0 or 1, quantMax = 2^b - 1 == (1 << b) - 1 +// +template +static void nudge(T min, T max, int quantMin, int quantMax, T* scale, + T* nudgedMin, T* nudgedMax) { + // floating point instead integers + T quantMaxF = static_cast(quantMax); + T quantMinF = static_cast(quantMin); + // compute scale + *scale = (max - min) / (quantMaxF - quantMinF); + // compute left bound point + auto zeroPointFromMin = quantMinF - min / *scale; + // bound zero point to conform with range [0 or 1, 2^b - 1] + uint16_t const nudged_zero_point = [zeroPointFromMin, quantMin, quantMax, + quantMaxF, quantMinF] { + if (zeroPointFromMin < quantMinF) { + return static_cast(quantMin); } + if (zeroPointFromMin > quantMaxF) { + return static_cast(quantMax); + } + return (uint16_t)sd::math::nd4j_round(zeroPointFromMin); + }(); + // compute nudged min and max with computed nudged zero point + *nudgedMin = (quantMinF - nudged_zero_point) * (*scale); + *nudgedMax = (quantMaxF - nudged_zero_point) * (*scale); +} - template - void fakeQuantWithMinMaxVarsPerChannel_(NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output) { - int lowIntBound = narrowed ? 1 : 0; // 0 or 1 - int upperIntBound = (1 << numBits) - 1; // 2^b - 1 - auto channels = input->sizeAt(-1); // last dimension +template +void fakeQuantWithMinMaxVarsPerChannel_(NDArray* input, NDArray* min, + NDArray* max, int numBits, + bool narrowed, NDArray* output) { + int lowIntBound = narrowed ? 1 : 0; // 0 or 1 + int upperIntBound = (1 << numBits) - 1; // 2^b - 1 + auto channels = input->sizeAt(-1); // last dimension - PRAGMA_OMP_PARALLEL_FOR - for (auto i = 0; i < channels; i++) { - T scale, nudged_min, nudged_max; - // nudge min and max first, with scale computing - nudge(min->t(i), max->t(i), lowIntBound, upperIntBound, &scale, &nudged_min, &nudged_max); - // slide using last dimension and process all for given channel - for (auto e = 0; e < input->lengthOf(); e += channels) { - T val = input->t(e + i); - if ( val <= nudged_min) - val = nudged_min; - else if (val >= nudged_max) - val = nudged_max; - // quantization itself - output->r(e + i) = math::nd4j_floor((val - nudged_min)/scale + T(0.5)) * scale + nudged_min; - } - } + PRAGMA_OMP_PARALLEL_FOR + for (auto i = 0; i < channels; i++) { + T scale, nudged_min, nudged_max; + // nudge min and max first, with scale computing + nudge(min->t(i), max->t(i), lowIntBound, upperIntBound, &scale, + &nudged_min, &nudged_max); + // slide using last dimension and process all for given channel + for (auto e = 0; e < input->lengthOf(); e += channels) { + T val = input->t(e + i); + if (val <= nudged_min) + val = nudged_min; + else if (val >= nudged_max) + val = nudged_max; + // quantization itself + output->r(e + i) = + math::nd4j_floor((val - nudged_min) / scale + T(0.5)) * scale + + nudged_min; } + } +} // -//const auto clamped = inputs.cwiseMin(nudged_max).cwiseMax(nudged_min); +// const auto clamped = inputs.cwiseMin(nudged_max).cwiseMax(nudged_min); // const auto clamped_shifted = clamped - nudged_min; // outputs.device(d) = (clamped_shifted / nudged_scale_repl + 0.5f).floor() * // nudged_scale_repl + // nudged_min; // - template - void fakeQuantWithMinMaxVars_(NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output) { - int lowIntBound = narrowed ? 1 : 0; - int upperIntBound = (1 << numBits) - 1; - - T nudgedMin, nudgedMax, scale; - // nudge with given min and max and compute scale and nudged min and max - nudge(min->t(0), max->t(0), lowIntBound, upperIntBound, &scale, &nudgedMin, &nudgedMax); - // quantization as one - auto fakeQuantizationWithMinMax = LAMBDA_T(x, nudgedMin, nudgedMax, scale) { - T val = x; // boundign value between nudged min and max - if (val < nudgedMin) { - val = nudgedMin; - } - else if (val > nudgedMax) - val = nudgedMax; - // converse value with scale and shifted with nudged min - val -= nudgedMin; - return (sd::math::nd4j_floor(val / scale + T(0.5f)) * scale + nudgedMin); - }; +template +void fakeQuantWithMinMaxVars_(NDArray* input, NDArray* min, NDArray* max, + int numBits, bool narrowed, NDArray* output) { + int lowIntBound = narrowed ? 1 : 0; + int upperIntBound = (1 << numBits) - 1; - input->applyLambda(fakeQuantizationWithMinMax, *output); - } - - void fakeQuantWithMinMaxVars(NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output) { - BUILD_SINGLE_SELECTOR(input->dataType(), fakeQuantWithMinMaxVars_, (input, min, max, numBits, narrowed, output), FLOAT_TYPES); - } - void fakeQuantWithMinMaxVarsPerChannel(LaunchContext* context, NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output) { - BUILD_SINGLE_SELECTOR(input->dataType(), fakeQuantWithMinMaxVarsPerChannel_, (input, min, max, numBits, narrowed, output), FLOAT_TYPES); - } - - BUILD_SINGLE_TEMPLATE(template void fakeQuantWithMinMaxVars_, (NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output), FLOAT_TYPES); + T nudgedMin, nudgedMax, scale; + // nudge with given min and max and compute scale and nudged min and max + nudge(min->t(0), max->t(0), lowIntBound, upperIntBound, &scale, + &nudgedMin, &nudgedMax); + // quantization as one + auto fakeQuantizationWithMinMax = LAMBDA_T(x, nudgedMin, nudgedMax, scale) { + T val = x; // boundign value between nudged min and max + if (val < nudgedMin) { + val = nudgedMin; + } else if (val > nudgedMax) + val = nudgedMax; + // converse value with scale and shifted with nudged min + val -= nudgedMin; + return (sd::math::nd4j_floor(val / scale + T(0.5f)) * scale + + nudgedMin); + }; + input->applyLambda(fakeQuantizationWithMinMax, *output); } + +void fakeQuantWithMinMaxVars(NDArray* input, NDArray* min, NDArray* max, + int numBits, bool narrowed, NDArray* output) { + BUILD_SINGLE_SELECTOR(input->dataType(), fakeQuantWithMinMaxVars_, + (input, min, max, numBits, narrowed, output), + FLOAT_TYPES); } +void fakeQuantWithMinMaxVarsPerChannel(LaunchContext* context, NDArray* input, + NDArray* min, NDArray* max, int numBits, + bool narrowed, NDArray* output) { + BUILD_SINGLE_SELECTOR(input->dataType(), fakeQuantWithMinMaxVarsPerChannel_, + (input, min, max, numBits, narrowed, output), + FLOAT_TYPES); } + +BUILD_SINGLE_TEMPLATE(template void fakeQuantWithMinMaxVars_, + (NDArray * input, NDArray* min, NDArray* max, int numBits, + bool narrowed, NDArray* output), + FLOAT_TYPES); + +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/flatten.cpp b/libnd4j/include/ops/declarable/helpers/cpu/flatten.cpp index aadd74298e7a..9b1385ca2dac 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/flatten.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/flatten.cpp @@ -21,38 +21,40 @@ #include namespace sd { - namespace ops { - namespace helpers { - - template - static void flatten_(std::vector &inputs, NDArray *output, const char order) { - - int numArrays = inputs.size(); - std::vector offsets(numArrays); - Nd4jLong cOffset = 0; - - // calculating offsets in output - for (int e = 0; e < numArrays; e++) { - offsets[e] = cOffset; - cOffset += inputs[e]->lengthOf(); - } - - // actually transferring data - for (int e = 0; e < numArrays; e++) { - auto z = reinterpret_cast(output->bufferWithOffset(offsets[e])); - - auto xBuffer = inputs[e]->bufferAsT(); - auto xShapeInfo = inputs[e]->shapeInfo(); - auto xLength = inputs[e]->lengthOf(); - - for (Nd4jLong i = 0; i < xLength; i++) - z[i] = xBuffer[getIndexOffsetOrdered(i, xShapeInfo, order)]; - } - } - - void flatten(sd::LaunchContext *context, std::vector &inputs, NDArray *output, char order) { - BUILD_SINGLE_SELECTOR(output->dataType(), flatten_, (inputs, output, order), LIBND4J_TYPES); - } - } - } -} \ No newline at end of file +namespace ops { +namespace helpers { + +template +static void flatten_(std::vector &inputs, NDArray *output, + const char order) { + int numArrays = inputs.size(); + std::vector offsets(numArrays); + Nd4jLong cOffset = 0; + + // calculating offsets in output + for (int e = 0; e < numArrays; e++) { + offsets[e] = cOffset; + cOffset += inputs[e]->lengthOf(); + } + + // actually transferring data + for (int e = 0; e < numArrays; e++) { + auto z = reinterpret_cast(output->bufferWithOffset(offsets[e])); + + auto xBuffer = inputs[e]->bufferAsT(); + auto xShapeInfo = inputs[e]->shapeInfo(); + auto xLength = inputs[e]->lengthOf(); + + for (Nd4jLong i = 0; i < xLength; i++) + z[i] = xBuffer[getIndexOffsetOrdered(i, xShapeInfo, order)]; + } +} + +void flatten(sd::LaunchContext *context, std::vector &inputs, + NDArray *output, char order) { + BUILD_SINGLE_SELECTOR(output->dataType(), flatten_, (inputs, output, order), + LIBND4J_TYPES); +} +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/gather.cpp b/libnd4j/include/ops/declarable/helpers/cpu/gather.cpp index c28101558d9e..b6120eaacac4 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/gather.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/gather.cpp @@ -18,160 +18,169 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 07.03.2019 // -#include -#include #include -#include #include +#include +#include + +#include namespace sd { namespace ops { namespace helpers { //////////////////////////////////////////////////////////////////////// -void gather(sd::LaunchContext * context, const NDArray* input, const NDArray* indices, NDArray* output, const std::vector& intArgs) { - - int axis = intArgs.size() > 0 ? intArgs[0] : 0; - const int inputRank = input->rankOf(); - if(axis < 0) - axis += inputRank; - - const int numOfIntArgs = intArgs.size(); - - if (indices != nullptr) { - - // first case: indices consist of only one scalar - if(indices->isScalar()) { - - if(input->rankOf() <= 1){ - //For scalar indices, rank 0 or 1 input: can't do tensor along dimension 0 as this is whole array... instead, we want to get a scalar - auto idx = indices->e(0); - auto scalarNDArray = input->e(idx); - output->assign(scalarNDArray); +void gather(sd::LaunchContext* context, const NDArray* input, + const NDArray* indices, NDArray* output, + const std::vector& intArgs) { + int axis = intArgs.size() > 0 ? intArgs[0] : 0; + const int inputRank = input->rankOf(); + if (axis < 0) axis += inputRank; + + const int numOfIntArgs = intArgs.size(); + + if (indices != nullptr) { + // first case: indices consist of only one scalar + if (indices->isScalar()) { + if (input->rankOf() <= 1) { + // For scalar indices, rank 0 or 1 input: can't do tensor along + // dimension 0 as this is whole array... instead, we want to get a + // scalar + auto idx = indices->e(0); + auto scalarNDArray = input->e(idx); + output->assign(scalarNDArray); + } else { + NDArray inSubArr = (*input)(indices->e(0), {axis}); + output->assign(inSubArr); + } + } else { + if (input->rankOf() == 1 && output->rankOf() == 1) { + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) + output->p(i, input->e(indices->e(i))); + }; + + samediff::Threads::parallel_for(func, 0, output->lengthOf()); + + } else { + std::vector dimsOut; + for (int i = 0; i < axis; ++i) dimsOut.push_back(i); + for (int i = axis + indices->rankOf(); i < output->rankOf(); ++i) + dimsOut.push_back(i); + + std::vector dimsIn = + ShapeUtils::evalDimsToExclude(input->rankOf(), {axis}); + + const Nd4jLong numOfSubArrs = indices->lengthOf(); + + auto inTadPack = ConstantTadHelper::getInstance().tadForDimensions( + input->shapeInfo(), dimsIn); + auto outTadPack = ConstantTadHelper::getInstance().tadForDimensions( + output->shapeInfo(), dimsOut); + + auto inTadShapeInfo = inTadPack.primaryShapeInfo(); + auto outTadShapeInfo = outTadPack.primaryShapeInfo(); + + if (shape::order(inTadShapeInfo) == shape::order(outTadShapeInfo) && + shape::order(inTadShapeInfo) == 'c' && + input->dataType() == output->dataType() && + shape::elementWiseStride(inTadShapeInfo) == 1 && + shape::elementWiseStride(outTadShapeInfo) == 1) { + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + auto inBuff = input->bufferWithOffset( + inTadPack.primaryOffsets()[indices->e(i)]); + auto outBuff = + output->bufferWithOffset(outTadPack.primaryOffsets()[i]); + + memcpy(outBuff, inBuff, + shape::length(inTadShapeInfo) * input->sizeOfT()); } - else { - NDArray inSubArr = (*input)(indices->e(0), {axis}); - output->assign(inSubArr); + }; + samediff::Threads::parallel_tad(func, 0, numOfSubArrs); + } else { + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + auto inBuff = input->bufferWithOffset( + inTadPack.primaryOffsets()[indices->e(i)]); + auto outBuff = + output->bufferWithOffset(outTadPack.primaryOffsets()[i]); + + NativeOpExecutioner::execTransformAny( + input->getContext(), transform::Assign, inBuff, + inTadShapeInfo, nullptr /*input specialBuffer*/, + nullptr /*input special*/, outBuff, outTadShapeInfo, + nullptr /*output specialBuffer*/, + nullptr /*output special*/, nullptr, nullptr, + nullptr, false /*allowParallelism*/); } - } - else { - - if(input->rankOf() == 1 && output->rankOf() == 1) { - - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) - output->p(i, input->e(indices->e(i))); - }; - - samediff::Threads::parallel_for(func, 0, output->lengthOf()); - - } - else { - - std::vector dimsOut; - for (int i = 0; i < axis; ++i) - dimsOut.push_back(i); - for (int i = axis+indices->rankOf(); i < output->rankOf(); ++i) - dimsOut.push_back(i); - - std::vector dimsIn = ShapeUtils::evalDimsToExclude(input->rankOf(), {axis}); - - const Nd4jLong numOfSubArrs = indices->lengthOf(); - - auto inTadPack = ConstantTadHelper::getInstance().tadForDimensions(input->shapeInfo(), dimsIn); - auto outTadPack = ConstantTadHelper::getInstance().tadForDimensions(output->shapeInfo(), dimsOut); - - auto inTadShapeInfo = inTadPack.primaryShapeInfo(); - auto outTadShapeInfo = outTadPack.primaryShapeInfo(); - - if (shape::order(inTadShapeInfo) == shape::order(outTadShapeInfo) && shape::order(inTadShapeInfo) == 'c' && input->dataType() == output->dataType() && shape::elementWiseStride(inTadShapeInfo) == 1 && shape::elementWiseStride(outTadShapeInfo) == 1) { - - auto func = PRAGMA_THREADS_FOR { - - for (auto i = start; i < stop; i++) { - auto inBuff = input->bufferWithOffset(inTadPack.primaryOffsets()[indices->e(i)]); - auto outBuff = output->bufferWithOffset(outTadPack.primaryOffsets()[i]); - - memcpy(outBuff, inBuff, shape::length(inTadShapeInfo) * input->sizeOfT()); - } - }; - samediff::Threads::parallel_tad(func, 0, numOfSubArrs); - } - else { - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { + }; - auto inBuff = input->bufferWithOffset(inTadPack.primaryOffsets()[indices->e(i)]); - auto outBuff = output->bufferWithOffset(outTadPack.primaryOffsets()[i]); - - NativeOpExecutioner::execTransformAny(input->getContext(), transform::Assign, - inBuff, inTadShapeInfo, nullptr/*input specialBuffer*/, nullptr/*input special*/, - outBuff, outTadShapeInfo, nullptr/*output specialBuffer*/, nullptr/*output special*/, - nullptr, nullptr, nullptr, false/*allowParallelism*/); - } - }; - - samediff::Threads::parallel_tad(func, 0, numOfSubArrs); - } - } + samediff::Threads::parallel_tad(func, 0, numOfSubArrs); } + } } - else { - - // we only allow scalar/vector case here - if (numOfIntArgs == 2) { // scalar case - - output->assign((*input)(intArgs[1], {axis})); - } - else { // vector case - - const Nd4jLong numOfSubArrs = intArgs.size() - 1; - - std::vector dims = ShapeUtils::evalDimsToExclude(input->rankOf(), {axis}); - - auto inTadPack = ConstantTadHelper::getInstance().tadForDimensions(input->shapeInfo(), dims); - auto outTadPack = ConstantTadHelper::getInstance().tadForDimensions(output->shapeInfo(), dims); - - auto inTadShapeInfo = inTadPack.primaryShapeInfo(); - auto outTadShapeInfo = outTadPack.primaryShapeInfo(); - - if (shape::order(inTadShapeInfo) == shape::order(outTadShapeInfo) && shape::order(inTadShapeInfo) == 'c' && input->dataType() == output->dataType() && shape::elementWiseStride(inTadShapeInfo) == 1 && shape::elementWiseStride(outTadShapeInfo) == 1) { - - auto func = PRAGMA_THREADS_FOR { - - for (auto i = start; i < stop; i++) { - auto inBuff = input->bufferWithOffset(inTadPack.primaryOffsets()[intArgs[i + 1]]); - void* outBuff = output->bufferWithOffset(outTadPack.primaryOffsets()[i]); - - std::memcpy(outBuff, inBuff, shape::length(inTadShapeInfo) * input->sizeOfT()); - } - }; - samediff::Threads::parallel_tad(func, 0, numOfSubArrs); - - } - else { - - auto func = PRAGMA_THREADS_FOR { - - for (auto i = start; i < stop; i++) { - auto inBuff = input->bufferWithOffset(inTadPack.primaryOffsets()[intArgs[i + 1]]); - auto outBuff = output->bufferWithOffset(outTadPack.primaryOffsets()[i]); - - NativeOpExecutioner::execTransformAny(input->getContext(), transform::Assign, - inBuff, inTadShapeInfo, nullptr/*input specialBuffer*/, nullptr/*input special*/, - outBuff, outTadShapeInfo, nullptr/*output specialBuffer*/, nullptr/*output special*/, - nullptr, nullptr, nullptr, false/*allowParallelism*/); - - } - }; - samediff::Threads::parallel_tad(func, 0, numOfSubArrs); - } - - } + } else { + // we only allow scalar/vector case here + if (numOfIntArgs == 2) { // scalar case + + output->assign((*input)(intArgs[1], {axis})); + } else { // vector case + + const Nd4jLong numOfSubArrs = intArgs.size() - 1; + + std::vector dims = + ShapeUtils::evalDimsToExclude(input->rankOf(), {axis}); + + auto inTadPack = ConstantTadHelper::getInstance().tadForDimensions( + input->shapeInfo(), dims); + auto outTadPack = ConstantTadHelper::getInstance().tadForDimensions( + output->shapeInfo(), dims); + + auto inTadShapeInfo = inTadPack.primaryShapeInfo(); + auto outTadShapeInfo = outTadPack.primaryShapeInfo(); + + if (shape::order(inTadShapeInfo) == shape::order(outTadShapeInfo) && + shape::order(inTadShapeInfo) == 'c' && + input->dataType() == output->dataType() && + shape::elementWiseStride(inTadShapeInfo) == 1 && + shape::elementWiseStride(outTadShapeInfo) == 1) { + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + auto inBuff = input->bufferWithOffset( + inTadPack.primaryOffsets()[intArgs[i + 1]]); + void* outBuff = + output->bufferWithOffset(outTadPack.primaryOffsets()[i]); + + std::memcpy(outBuff, inBuff, + shape::length(inTadShapeInfo) * input->sizeOfT()); + } + }; + samediff::Threads::parallel_tad(func, 0, numOfSubArrs); + + } else { + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + auto inBuff = input->bufferWithOffset( + inTadPack.primaryOffsets()[intArgs[i + 1]]); + auto outBuff = + output->bufferWithOffset(outTadPack.primaryOffsets()[i]); + + NativeOpExecutioner::execTransformAny( + input->getContext(), transform::Assign, inBuff, inTadShapeInfo, + nullptr /*input specialBuffer*/, + nullptr /*input special*/, outBuff, outTadShapeInfo, + nullptr /*output specialBuffer*/, + nullptr /*output special*/, nullptr, nullptr, nullptr, + false /*allowParallelism*/); + } + }; + samediff::Threads::parallel_tad(func, 0, numOfSubArrs); + } } + } } - -} -} -} +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/gatherTransforms.cpp b/libnd4j/include/ops/declarable/helpers/cpu/gatherTransforms.cpp index e6f1a389668b..daafa041997b 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/gatherTransforms.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/gatherTransforms.cpp @@ -18,166 +18,175 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 20.04.2018 // - - -#include +#include #include +#include + #include -#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { - //////////////////////////////////////////////////////////////////////// -template +template static void gatherND_(NDArray& input, NDArray& indices, NDArray& output) { + const X* x = reinterpret_cast(input.buffer()); + const Y* y = reinterpret_cast(indices.buffer()); + X* z = reinterpret_cast(output.buffer()); - const X* x = reinterpret_cast(input.buffer()); - const Y* y = reinterpret_cast(indices.buffer()); - X* z = reinterpret_cast(output.buffer()); + const int xRank = input.rankOf(); + const int yRank = indices.rankOf(); + const int zRank = output.rankOf(); + const int maxRank = + sd::math::nd4j_max(yRank, sd::math::nd4j_max(xRank, zRank)); - const int xRank = input.rankOf(); - const int yRank = indices.rankOf(); - const int zRank = output.rankOf(); - const int maxRank = sd::math::nd4j_max(yRank, sd::math::nd4j_max(xRank, zRank)); + const Nd4jLong zLen = output.lengthOf(); - const Nd4jLong zLen = output.lengthOf(); + const uint yLastDim = indices.sizeAt(-1); - const uint yLastDim = indices.sizeAt(-1); + const int diff = zRank - xRank; + const bool bEqual = yLastDim == xRank; - const int diff = zRank - xRank; - const bool bEqual = yLastDim == xRank; + auto func = PRAGMA_THREADS_FOR { + int xCoords[MAX_RANK], zCoords[MAX_RANK], temp; - auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + shape::index2coordsCPU(start, i, output.shapeInfo(), zCoords); - int xCoords[MAX_RANK], zCoords[MAX_RANK], temp; + const auto zOffset = shape::getOffset(output.shapeInfo(), zCoords); - for (auto i = start; i < stop; i++) { - - shape::index2coordsCPU(start, i, output.shapeInfo(), zCoords); - - const auto zOffset = shape::getOffset(output.shapeInfo(), zCoords); - - temp = zCoords[yRank - 1]; - zCoords[yRank - 1] = 0; - const auto yOffset = shape::getOffset(indices.shapeInfo(), zCoords); - zCoords[yRank - 1] = temp; + temp = zCoords[yRank - 1]; + zCoords[yRank - 1] = 0; + const auto yOffset = shape::getOffset(indices.shapeInfo(), zCoords); + zCoords[yRank - 1] = temp; - if(bEqual) - memcpy(xCoords, zCoords, zRank * sizeof(int)); - else if(diff >= 0) - memcpy(xCoords, zCoords + diff, xRank * sizeof(int)); - else - memcpy(xCoords - diff, zCoords, zRank * sizeof(int)); + if (bEqual) + memcpy(xCoords, zCoords, zRank * sizeof(int)); + else if (diff >= 0) + memcpy(xCoords, zCoords + diff, xRank * sizeof(int)); + else + memcpy(xCoords - diff, zCoords, zRank * sizeof(int)); - for (uint j = 0; j < yLastDim; ++j) - xCoords[j] = y[yOffset + j * indices.stridesOf()[yRank - 1]]; // last stride + for (uint j = 0; j < yLastDim; ++j) + xCoords[j] = + y[yOffset + j * indices.stridesOf()[yRank - 1]]; // last stride - const auto xOffset = shape::getOffset(input.shapeInfo(), xCoords); + const auto xOffset = shape::getOffset(input.shapeInfo(), xCoords); - z[zOffset] = x[xOffset]; - } - }; + z[zOffset] = x[xOffset]; + } + }; - samediff::Threads::parallel_tad(func, 0, zLen); + samediff::Threads::parallel_tad(func, 0, zLen); } //////////////////////////////////////////////////////////////////////// -void gatherND(sd::LaunchContext * context, NDArray& input, NDArray& indices, NDArray& output) { - BUILD_DOUBLE_SELECTOR(input.dataType(), indices.dataType(), gatherND_, (input, indices, output), LIBND4J_TYPES, INDEXING_TYPES); +void gatherND(sd::LaunchContext* context, NDArray& input, NDArray& indices, + NDArray& output) { + BUILD_DOUBLE_SELECTOR(input.dataType(), indices.dataType(), gatherND_, + (input, indices, output), LIBND4J_TYPES, + INDEXING_TYPES); } - //////////////////////////////////////////////////////////////////////// -template -static void gather_(NDArray* input, const NDArray* indices, NDArray* output, const std::vector& intArgs) { - - int axis = intArgs.size() > 0 ? intArgs[0] : 0; - const int inputRank = input->rankOf(); - if(axis < 0) - axis += inputRank; - - const int numOfIntArgs = intArgs.size(); - - if (indices != nullptr) { - - for(Nd4jLong i = 0; i < indices->lengthOf(); ++i) - if(indices->e(i) >= input->sizeAt(axis)) - throw std::runtime_error("helpers::gather function: indices array contains wrong elements, each element must be smaller than corresponding dimension of input array !"); - - // first case: indices consist of only one scalar - if(indices->isScalar()) { - if(input->rankOf() <= 1){ - //For scalar indices, rank 0 or 1 input: can't do tensor along dimension 0 as this is whole array... instead, we want to get a scalar - auto idx = indices->e(0); - auto scalarNDArray = input->e(idx); - output->assign(scalarNDArray); - } else { - auto dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {axis}); - auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(input->shapeInfo(), dimensions); - - auto tadArr = NDArray(reinterpret_cast(reinterpret_cast(input->buffer()) + tadPack.primaryOffsets()[indices->e(0)]), tadPack.primaryShapeInfo(), output->getContext()); - output->assign(&tadArr); - } - } - else if (input->rankOf() == 1 && indices->isVector()) { - // special case - auto func = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) - output->p(e, input->e(indices->e(e))); - }; - - samediff::Threads::parallel_for(func, 0, indices->lengthOf()); +template +static void gather_(NDArray* input, const NDArray* indices, NDArray* output, + const std::vector& intArgs) { + int axis = intArgs.size() > 0 ? intArgs[0] : 0; + const int inputRank = input->rankOf(); + if (axis < 0) axis += inputRank; + + const int numOfIntArgs = intArgs.size(); + + if (indices != nullptr) { + for (Nd4jLong i = 0; i < indices->lengthOf(); ++i) + if (indices->e(i) >= input->sizeAt(axis)) + throw std::runtime_error( + "helpers::gather function: indices array contains wrong elements, " + "each element must be smaller than corresponding dimension of " + "input array !"); + + // first case: indices consist of only one scalar + if (indices->isScalar()) { + if (input->rankOf() <= 1) { + // For scalar indices, rank 0 or 1 input: can't do tensor along + // dimension 0 as this is whole array... instead, we want to get a + // scalar + auto idx = indices->e(0); + auto scalarNDArray = input->e(idx); + output->assign(scalarNDArray); + } else { + auto dimensions = + ShapeUtils::evalDimsToExclude(input->rankOf(), {axis}); + auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions( + input->shapeInfo(), dimensions); + + auto tadArr = + NDArray(reinterpret_cast( + reinterpret_cast(input->buffer()) + + tadPack.primaryOffsets()[indices->e(0)]), + tadPack.primaryShapeInfo(), output->getContext()); + output->assign(&tadArr); + } + } else if (input->rankOf() == 1 && indices->isVector()) { + // special case + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) + output->p(e, input->e(indices->e(e))); + }; + + samediff::Threads::parallel_for(func, 0, indices->lengthOf()); + } else { + std::vector dimsOut(indices->rankOf()); + std::iota(dimsOut.begin(), dimsOut.end(), + axis); // fill with axis, axis+1, ... indices->rankOf()-1 + const Nd4jLong numOfSubArrs = + ShapeUtils::getNumOfSubArrs(output->shapeInfo(), dimsOut); + + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + NDArray subArrOut = (*output)(i, dimsOut); + NDArray subArrIn = (*input)(indices->e(i), {axis}); + subArrOut.assign(subArrIn); } - else { - - std::vector dimsOut(indices->rankOf()); - std::iota(dimsOut.begin(), dimsOut.end(), axis); // fill with axis, axis+1, ... indices->rankOf()-1 - const Nd4jLong numOfSubArrs = ShapeUtils::getNumOfSubArrs(output->shapeInfo(), dimsOut); - - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - NDArray subArrOut = (*output)(i, dimsOut); - NDArray subArrIn = (*input)(indices->e(i), {axis}); - subArrOut.assign(subArrIn); - } - }; + }; - samediff::Threads::parallel_tad(func, 0, numOfSubArrs); - } + samediff::Threads::parallel_tad(func, 0, numOfSubArrs); } - else { - - for(int i = 1; i < numOfIntArgs; ++i) - if(intArgs[i] >= input->sizeAt(axis)) - throw std::runtime_error("helpers::gather function: some of input indexes is larger than corresponding shape of input array !"); - - // we only allow scalar/vector case here - if (numOfIntArgs == 2) { // scalar case - output->assign((*input)(intArgs[1], {axis})); - } - else { // vector case - const Nd4jLong numOfSubArrs = ShapeUtils::getNumOfSubArrs(output->shapeInfo(), {axis}); - - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - NDArray subArrOut = (*output)(i, {axis}); - NDArray subArrIn = (*input)(intArgs[i + 1], {axis}); - subArrOut.assign(subArrIn); - } - }; - - samediff::Threads::parallel_tad(func, 0, numOfSubArrs); + } else { + for (int i = 1; i < numOfIntArgs; ++i) + if (intArgs[i] >= input->sizeAt(axis)) + throw std::runtime_error( + "helpers::gather function: some of input indexes is larger than " + "corresponding shape of input array !"); + + // we only allow scalar/vector case here + if (numOfIntArgs == 2) { // scalar case + output->assign((*input)(intArgs[1], {axis})); + } else { // vector case + const Nd4jLong numOfSubArrs = + ShapeUtils::getNumOfSubArrs(output->shapeInfo(), {axis}); + + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + NDArray subArrOut = (*output)(i, {axis}); + NDArray subArrIn = (*input)(intArgs[i + 1], {axis}); + subArrOut.assign(subArrIn); } - } -} + }; - void gather(NDArray* input, const NDArray* indices, NDArray* output, const std::vector& intArgs) { - BUILD_SINGLE_SELECTOR(input->dataType(), gather_, (input, indices, output, intArgs), LIBND4J_TYPES); + samediff::Threads::parallel_tad(func, 0, numOfSubArrs); } - -} + } } + +void gather(NDArray* input, const NDArray* indices, NDArray* output, + const std::vector& intArgs) { + BUILD_SINGLE_SELECTOR(input->dataType(), gather_, + (input, indices, output, intArgs), LIBND4J_TYPES); } + +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/gradient.cpp b/libnd4j/include/ops/declarable/helpers/cpu/gradient.cpp index df5ee1afcf7b..e8786a3a1ac0 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/gradient.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/gradient.cpp @@ -25,18 +25,22 @@ namespace sd { namespace ops { namespace helpers { template -static void applyGradientDescent_(NDArray* input, NDArray* step, double weight, NDArray* output) { - auto lambda = LAMBDA_TT(_x, _y, weight) { - return _x - (_y * weight); - }; +static void applyGradientDescent_(NDArray* input, NDArray* step, double weight, + NDArray* output) { + auto lambda = LAMBDA_TT(_x, _y, weight) { return _x - (_y * weight); }; - input->applyPairwiseLambda(*step, lambda, *output); + input->applyPairwiseLambda(*step, lambda, *output); } -void applyGradientDescent(sd::LaunchContext* context, NDArray* input, NDArray* step, double weight, NDArray* output) { - BUILD_SINGLE_SELECTOR(input->dataType(), applyGradientDescent_, (input, step, weight, output), FLOAT_TYPES); -} -BUILD_SINGLE_TEMPLATE(template void applyGradientDescent_, (NDArray* input, NDArray* step, double weight, NDArray* output), FLOAT_TYPES); -} -} +void applyGradientDescent(sd::LaunchContext* context, NDArray* input, + NDArray* step, double weight, NDArray* output) { + BUILD_SINGLE_SELECTOR(input->dataType(), applyGradientDescent_, + (input, step, weight, output), FLOAT_TYPES); } +BUILD_SINGLE_TEMPLATE(template void applyGradientDescent_, + (NDArray * input, NDArray* step, double weight, + NDArray* output), + FLOAT_TYPES); +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/hamming.cpp b/libnd4j/include/ops/declarable/helpers/cpu/hamming.cpp index 10b6a27e026d..22df1684e685 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/hamming.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/hamming.cpp @@ -18,87 +18,84 @@ // @author raver119@gmail.com // -#include -#include #include +#include +#include namespace sd { - namespace ops { - namespace helpers { - - static Nd4jLong hamming_distance(unsigned long long x, unsigned long long y) { - Nd4jLong dist = 0; - - for (unsigned long long val = x ^ y; val > 0; val /= 2) { - if (val & 1) - dist++; - } - return dist; - } - - - template - static void _hamming(NDArray &x, NDArray &y, NDArray &z) { - auto xEws = x.ews(); - auto yEws = y.ews(); - - auto xBuffer = x.bufferAsT(); - auto yBuffer = y.bufferAsT(); - - Nd4jLong distance = 0; - auto lengthOf = x.lengthOf(); - int maxThreads = sd::math::nd4j_min(256, omp_get_max_threads()); - Nd4jLong intermediate[256]; - - // nullify temp values - for (int e = 0; e < maxThreads; e++) - intermediate[e] = 0; - - if (xEws == 1 && yEws == 1 && x.ordering() == y.ordering()) { - auto func = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) { - auto _x = static_cast(xBuffer[e]); - auto _y = static_cast(yBuffer[e]); - - intermediate[thread_id] += hamming_distance(_x, _y); - } - }; - - maxThreads = samediff::Threads::parallel_for(func, 0, lengthOf); - } else if (xEws > 1 && yEws > 1 && x.ordering() == y.ordering()) { - auto func = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) { - auto _x = static_cast(xBuffer[e * xEws]); - auto _y = static_cast(yBuffer[e * yEws]); - - intermediate[thread_id] += hamming_distance(_x, _y); - } - }; - - maxThreads = samediff::Threads::parallel_for(func, 0, lengthOf); - } else { - auto func = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) { - auto _x = static_cast(x.e(e)); - auto _y = static_cast(y.e(e)); - - intermediate[thread_id] += hamming_distance(_x, _y); - } - }; - - maxThreads = samediff::Threads::parallel_for(func, 0, lengthOf); - } - - // accumulate intermediate variables into output array - for (int e = 0; e < maxThreads; e++) - distance += intermediate[e]; - - z.p(0, distance); - } - - void hamming(LaunchContext *context, NDArray &x, NDArray &y, NDArray &output) { - BUILD_DOUBLE_SELECTOR(x.dataType(), output.dataType(), _hamming, (x, y, output), INTEGER_TYPES, INDEXING_TYPES); - } - } - } -} \ No newline at end of file +namespace ops { +namespace helpers { + +static Nd4jLong hamming_distance(unsigned long long x, unsigned long long y) { + Nd4jLong dist = 0; + + for (unsigned long long val = x ^ y; val > 0; val /= 2) { + if (val & 1) dist++; + } + return dist; +} + +template +static void _hamming(NDArray &x, NDArray &y, NDArray &z) { + auto xEws = x.ews(); + auto yEws = y.ews(); + + auto xBuffer = x.bufferAsT(); + auto yBuffer = y.bufferAsT(); + + Nd4jLong distance = 0; + auto lengthOf = x.lengthOf(); + int maxThreads = sd::math::nd4j_min(256, omp_get_max_threads()); + Nd4jLong intermediate[256]; + + // nullify temp values + for (int e = 0; e < maxThreads; e++) intermediate[e] = 0; + + if (xEws == 1 && yEws == 1 && x.ordering() == y.ordering()) { + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + auto _x = static_cast(xBuffer[e]); + auto _y = static_cast(yBuffer[e]); + + intermediate[thread_id] += hamming_distance(_x, _y); + } + }; + + maxThreads = samediff::Threads::parallel_for(func, 0, lengthOf); + } else if (xEws > 1 && yEws > 1 && x.ordering() == y.ordering()) { + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + auto _x = static_cast(xBuffer[e * xEws]); + auto _y = static_cast(yBuffer[e * yEws]); + + intermediate[thread_id] += hamming_distance(_x, _y); + } + }; + + maxThreads = samediff::Threads::parallel_for(func, 0, lengthOf); + } else { + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + auto _x = static_cast(x.e(e)); + auto _y = static_cast(y.e(e)); + + intermediate[thread_id] += hamming_distance(_x, _y); + } + }; + + maxThreads = samediff::Threads::parallel_for(func, 0, lengthOf); + } + + // accumulate intermediate variables into output array + for (int e = 0; e < maxThreads; e++) distance += intermediate[e]; + + z.p(0, distance); +} + +void hamming(LaunchContext *context, NDArray &x, NDArray &y, NDArray &output) { + BUILD_DOUBLE_SELECTOR(x.dataType(), output.dataType(), _hamming, + (x, y, output), INTEGER_TYPES, INDEXING_TYPES); +} +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/hashcode.cpp b/libnd4j/include/ops/declarable/helpers/cpu/hashcode.cpp index 5893b2c88363..822d155d0523 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/hashcode.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/hashcode.cpp @@ -18,88 +18,90 @@ // @author raver119@gmail.com // -#include #include +#include namespace sd { - namespace ops { - namespace helpers { - template - static void hashCode_(LaunchContext *context, NDArray &array, NDArray &result) { - Nd4jLong blockSize = 32; - auto length = array.lengthOf(); - int numBlocks = length / blockSize + ((length % blockSize == 0) ? 0 : 1); - auto tempA = NDArrayFactory::create('c', {numBlocks}, context); - auto tempB = NDArrayFactory::create('c', { numBlocks / blockSize + 1}, context); - - auto buffer = array.bufferAsT(); - auto tempBufferA = tempA.bufferAsT(); - auto tempBufferB = tempB.bufferAsT(); - - // default buffer is the first one, because it might be the last one in case of small arrays (< blockSize) - auto tempBuffer = tempBufferA; - auto tempResult = tempBufferB; - - // we divide array into 32 element chunks, and store intermediate results once - auto func = PRAGMA_THREADS_FOR { - for (auto b = start; b < stop; b++) { - auto blockBuffer = buffer + b * numBlocks; - - Nd4jLong r = 1; - for (Nd4jLong e = 0; e < blockSize && e + (b * numBlocks) < length; e++) { - auto v = longBytes(blockBuffer[e]); - r = 31 * r + v; - } - - tempBuffer[b] = r; - } - }; - samediff::Threads::parallel_tad(func, 0, numBlocks); - - // we replace pointer with intermediate one, and repeat only one chunk left - int iterationCount = 0; - while (numBlocks > 1) { - int lastLength = numBlocks; - numBlocks = lastLength / blockSize + ((lastLength % blockSize == 0) ? 0 : 1); - - - auto func2 = PRAGMA_THREADS_FOR { - for (auto b = start; b < stop; b++) { - auto blockBuffer = tempBuffer + b * numBlocks; - - Nd4jLong r = 1; - for (Nd4jLong e = 0; e < blockSize && e + (b * numBlocks) < lastLength; e++) { - auto v = longBytes(blockBuffer[e]); - r = 31 * r + v; - } - - tempResult[b] = r; - } - }; - samediff::Threads::parallel_tad(func2, 0, numBlocks); - - - iterationCount++; - // swapping buffers - if (iterationCount % 2 == 0) { - tempBuffer = tempBufferA; - tempResult = tempBufferB; - } else { - tempBuffer = tempBufferB; - tempResult = tempBufferA; - } - } +namespace ops { +namespace helpers { +template +static void hashCode_(LaunchContext *context, NDArray &array, NDArray &result) { + Nd4jLong blockSize = 32; + auto length = array.lengthOf(); + int numBlocks = length / blockSize + ((length % blockSize == 0) ? 0 : 1); + auto tempA = NDArrayFactory::create('c', {numBlocks}, context); + auto tempB = NDArrayFactory::create( + 'c', {numBlocks / blockSize + 1}, context); + + auto buffer = array.bufferAsT(); + auto tempBufferA = tempA.bufferAsT(); + auto tempBufferB = tempB.bufferAsT(); + + // default buffer is the first one, because it might be the last one in case + // of small arrays (< blockSize) + auto tempBuffer = tempBufferA; + auto tempResult = tempBufferB; + + // we divide array into 32 element chunks, and store intermediate results once + auto func = PRAGMA_THREADS_FOR { + for (auto b = start; b < stop; b++) { + auto blockBuffer = buffer + b * numBlocks; + + Nd4jLong r = 1; + for (Nd4jLong e = 0; e < blockSize && e + (b * numBlocks) < length; e++) { + auto v = longBytes(blockBuffer[e]); + r = 31 * r + v; + } + + tempBuffer[b] = r; + } + }; + samediff::Threads::parallel_tad(func, 0, numBlocks); + + // we replace pointer with intermediate one, and repeat only one chunk left + int iterationCount = 0; + while (numBlocks > 1) { + int lastLength = numBlocks; + numBlocks = + lastLength / blockSize + ((lastLength % blockSize == 0) ? 0 : 1); + + auto func2 = PRAGMA_THREADS_FOR { + for (auto b = start; b < stop; b++) { + auto blockBuffer = tempBuffer + b * numBlocks; + + Nd4jLong r = 1; + for (Nd4jLong e = 0; e < blockSize && e + (b * numBlocks) < lastLength; + e++) { + auto v = longBytes(blockBuffer[e]); + r = 31 * r + v; + } - if (length <= blockSize) - result.p(0, tempBufferA[0]); - else - result.p(0, tempResult[0]); - } + tempResult[b] = r; + } + }; + samediff::Threads::parallel_tad(func2, 0, numBlocks); + + iterationCount++; + // swapping buffers + if (iterationCount % 2 == 0) { + tempBuffer = tempBufferA; + tempResult = tempBufferB; + } else { + tempBuffer = tempBufferB; + tempResult = tempBufferA; + } + } + if (length <= blockSize) + result.p(0, tempBufferA[0]); + else + result.p(0, tempResult[0]); +} - void hashCode(LaunchContext *context, NDArray &array, NDArray &result) { - BUILD_SINGLE_SELECTOR(array.dataType(), hashCode_, (context, array, result), LIBND4J_TYPES); - } - } - } +void hashCode(LaunchContext *context, NDArray &array, NDArray &result) { + BUILD_SINGLE_SELECTOR(array.dataType(), hashCode_, (context, array, result), + LIBND4J_TYPES); } +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/histogram.cpp b/libnd4j/include/ops/declarable/helpers/cpu/histogram.cpp index cb815110d415..9dafacb11c55 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/histogram.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/histogram.cpp @@ -21,50 +21,55 @@ #include namespace sd { - namespace ops { - namespace helpers { - template - static void histogram_(void const* xBuffer, Nd4jLong const* xShapeInfo, void *zBuffer, Nd4jLong const* zShapeInfo, Nd4jLong numBins, double min_val, double max_val) { - auto dx = reinterpret_cast(xBuffer); - auto result = reinterpret_cast(zBuffer); +namespace ops { +namespace helpers { +template +static void histogram_(void const *xBuffer, Nd4jLong const *xShapeInfo, + void *zBuffer, Nd4jLong const *zShapeInfo, + Nd4jLong numBins, double min_val, double max_val) { + auto dx = reinterpret_cast(xBuffer); + auto result = reinterpret_cast(zBuffer); - int length = shape::length(xShapeInfo); + int length = shape::length(xShapeInfo); - X binSize = (max_val - min_val) / (numBins); + X binSize = (max_val - min_val) / (numBins); - // FIXME: this op should be parallelized - { - int *bins = new int[numBins]; - std::memset(bins, 0, sizeof(int) * numBins); + // FIXME: this op should be parallelized + { + int *bins = new int[numBins]; + std::memset(bins, 0, sizeof(int) * numBins); - PRAGMA_OMP_SIMD - for (int x = 0; x < length; x++) { - int idx = (int) ((dx[x] - min_val) / binSize); - if (idx < 0) - idx = 0; - else if (idx >= numBins) - idx = numBins - 1; + PRAGMA_OMP_SIMD + for (int x = 0; x < length; x++) { + int idx = (int)((dx[x] - min_val) / binSize); + if (idx < 0) + idx = 0; + else if (idx >= numBins) + idx = numBins - 1; - bins[idx]++; - } - - PRAGMA_OMP_SIMD - for (Nd4jLong x = 0; x < numBins; x++) { - result[x] += bins[x]; - } + bins[idx]++; + } + PRAGMA_OMP_SIMD + for (Nd4jLong x = 0; x < numBins; x++) { + result[x] += bins[x]; + } - delete[] bins; - } - } + delete[] bins; + } +} - void histogramHelper(sd::LaunchContext *context, NDArray &input, NDArray &output) { - Nd4jLong numBins = output.lengthOf(); - double min_val = input.reduceNumber(reduce::SameOps::Min).e(0); - double max_val = input.reduceNumber(reduce::SameOps::Max).e(0); +void histogramHelper(sd::LaunchContext *context, NDArray &input, + NDArray &output) { + Nd4jLong numBins = output.lengthOf(); + double min_val = input.reduceNumber(reduce::SameOps::Min).e(0); + double max_val = input.reduceNumber(reduce::SameOps::Max).e(0); - BUILD_DOUBLE_SELECTOR(input.dataType(), output.dataType(), histogram_, (input.buffer(), input.shapeInfo(), output.buffer(), output.shapeInfo(), numBins, min_val, max_val), LIBND4J_TYPES, INDEXING_TYPES); - } - } - } -} \ No newline at end of file + BUILD_DOUBLE_SELECTOR(input.dataType(), output.dataType(), histogram_, + (input.buffer(), input.shapeInfo(), output.buffer(), + output.shapeInfo(), numBins, min_val, max_val), + LIBND4J_TYPES, INDEXING_TYPES); +} +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/histogramFixedWidth.cpp b/libnd4j/include/ops/declarable/helpers/cpu/histogramFixedWidth.cpp index 9376e80bf8a4..a150c5652efb 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/histogramFixedWidth.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/histogramFixedWidth.cpp @@ -24,45 +24,48 @@ namespace sd { namespace ops { namespace helpers { - template -void histogramFixedWidth_(const NDArray& input, const NDArray& range, NDArray& output) { - - const int nbins = output.lengthOf(); +void histogramFixedWidth_(const NDArray& input, const NDArray& range, + NDArray& output) { + const int nbins = output.lengthOf(); - // firstly initialize output with zeros - output.nullify(); + // firstly initialize output with zeros + output.nullify(); - const T leftEdge = range.e(0); - const T rightEdge = range.e(1); + const T leftEdge = range.e(0); + const T rightEdge = range.e(1); - const T binWidth = (rightEdge - leftEdge ) / nbins; - const T secondEdge = leftEdge + binWidth; - const T lastButOneEdge = rightEdge - binWidth; + const T binWidth = (rightEdge - leftEdge) / nbins; + const T secondEdge = leftEdge + binWidth; + const T lastButOneEdge = rightEdge - binWidth; - Nd4jLong inputLength = input.lengthOf(); + Nd4jLong inputLength = input.lengthOf(); - // FIXME: make this one parallel without CRITICAL section - for(Nd4jLong i = 0; i < inputLength; ++i) { - const T value = input.e(i); + // FIXME: make this one parallel without CRITICAL section + for (Nd4jLong i = 0; i < inputLength; ++i) { + const T value = input.e(i); - if(value < secondEdge) { - output.p(0, output.e(0) + 1); - } else if(value >= lastButOneEdge) { - output.p(nbins - 1, output.e(nbins - 1) + 1); - } else { - Nd4jLong currInd = static_cast((value - leftEdge) / binWidth); - output.p(currInd, output.e(currInd) + 1); - } + if (value < secondEdge) { + output.p(0, output.e(0) + 1); + } else if (value >= lastButOneEdge) { + output.p(nbins - 1, output.e(nbins - 1) + 1); + } else { + Nd4jLong currInd = static_cast((value - leftEdge) / binWidth); + output.p(currInd, output.e(currInd) + 1); } + } } -void histogramFixedWidth(sd::LaunchContext * context, const NDArray& input, const NDArray& range, NDArray& output) { - BUILD_SINGLE_SELECTOR(input.dataType(), histogramFixedWidth_, (input, range, output), LIBND4J_TYPES); +void histogramFixedWidth(sd::LaunchContext* context, const NDArray& input, + const NDArray& range, NDArray& output) { + BUILD_SINGLE_SELECTOR(input.dataType(), histogramFixedWidth_, + (input, range, output), LIBND4J_TYPES); } -BUILD_SINGLE_TEMPLATE(template void histogramFixedWidth_, (const NDArray& input, const NDArray& range, NDArray& output), LIBND4J_TYPES); - +BUILD_SINGLE_TEMPLATE(template void histogramFixedWidth_, + (const NDArray& input, const NDArray& range, + NDArray& output), + LIBND4J_TYPES); -} -} -} \ No newline at end of file +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/im2col.cpp b/libnd4j/include/ops/declarable/helpers/cpu/im2col.cpp index 2434fddcc658..60bb27dd6b56 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/im2col.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/im2col.cpp @@ -18,123 +18,135 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 19.09.2018 // -#include #include +#include - -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { ////////////////////////////////////////////////////////////////////////// template -static void im2col_(sd::LaunchContext & context, const NDArray& input, NDArray& output, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const NDArray& arrZeroPadVal) { - - // input [bS, iC, iH, iW] is convoluted to output [bS, iC, kH, kW, oH, oW] - - auto imBuff = static_cast(input.buffer()); - auto colBuff = static_cast(output.buffer()); - auto imShapeBuffer = input.shapeInfo(); - auto colShapeBuffer = output.shapeInfo(); - auto colShape = shape::shapeOf(colShapeBuffer); - auto colStride = shape::stride(colShapeBuffer); - auto imShape = shape::shapeOf(imShapeBuffer); - auto imStride = shape::stride(imShapeBuffer); - - const T zeroPadVal = arrZeroPadVal.e(0); - - const int bS = imShape[0]; - const int iC = imShape[1]; - const int iH = imShape[2]; - const int iW = imShape[3]; - const int oH = colShape[4]; - const int oW = colShape[5]; - const Nd4jLong colStride0 = colStride[0]; - const Nd4jLong colStride1 = colStride[1]; - const Nd4jLong colStride2 = colStride[2]; - const Nd4jLong colStride3 = colStride[3]; - const Nd4jLong colStride4 = colStride[4]; - const Nd4jLong colStride5 = colStride[5]; - const Nd4jLong imStride0 = imStride[0]; - const Nd4jLong imStride1 = imStride[1]; - const Nd4jLong imStride2 = imStride[2]; - const Nd4jLong imStride3 = imStride[3]; - - - if (shape::order(imShapeBuffer) == 'c' && shape::order(colShapeBuffer) == 'c' && shape::strideDescendingCAscendingF(imShapeBuffer) && shape::strideDescendingCAscendingF(colShapeBuffer)) { - - auto func = PRAGMA_THREADS_FOR_2D { - for (auto b = start_x; b < stop_x; b++) { - for (auto c = start_y; c < stop_y; c++) { - for (int kRow = 0; kRow < kH; ++kRow) { - for (int kCol = 0; kCol < kW; ++kCol) { - for (int colH = 0; colH < oH; ++colH) { - for (int colW = 0; colW < oW; ++colW) { - - int imRow = (-pH + kRow * dH) + colH * sH; - int imCol = (-pW + kCol * dW) + colW * sW; - - auto col = colBuff + b * colStride0 + c * colStride1 + kRow * colStride2 + kCol * colStride3 + colH * colStride4 + colW * colStride5; - - if (static_cast(imRow) >= static_cast(iH) || static_cast(imCol) >= static_cast(iW)) - *col = zeroPadVal; - else { - auto im = imBuff + b * imStride0 + c * imStride1 + imRow * imStride2 + imCol * imStride3; - *col = *im; - } - } - } - } - } +static void im2col_(sd::LaunchContext& context, const NDArray& input, + NDArray& output, const int kH, const int kW, const int sH, + const int sW, const int pH, const int pW, const int dH, + const int dW, const NDArray& arrZeroPadVal) { + // input [bS, iC, iH, iW] is convoluted to output [bS, iC, kH, kW, oH, oW] + + auto imBuff = static_cast(input.buffer()); + auto colBuff = static_cast(output.buffer()); + auto imShapeBuffer = input.shapeInfo(); + auto colShapeBuffer = output.shapeInfo(); + auto colShape = shape::shapeOf(colShapeBuffer); + auto colStride = shape::stride(colShapeBuffer); + auto imShape = shape::shapeOf(imShapeBuffer); + auto imStride = shape::stride(imShapeBuffer); + + const T zeroPadVal = arrZeroPadVal.e(0); + + const int bS = imShape[0]; + const int iC = imShape[1]; + const int iH = imShape[2]; + const int iW = imShape[3]; + const int oH = colShape[4]; + const int oW = colShape[5]; + const Nd4jLong colStride0 = colStride[0]; + const Nd4jLong colStride1 = colStride[1]; + const Nd4jLong colStride2 = colStride[2]; + const Nd4jLong colStride3 = colStride[3]; + const Nd4jLong colStride4 = colStride[4]; + const Nd4jLong colStride5 = colStride[5]; + const Nd4jLong imStride0 = imStride[0]; + const Nd4jLong imStride1 = imStride[1]; + const Nd4jLong imStride2 = imStride[2]; + const Nd4jLong imStride3 = imStride[3]; + + if (shape::order(imShapeBuffer) == 'c' && + shape::order(colShapeBuffer) == 'c' && + shape::strideDescendingCAscendingF(imShapeBuffer) && + shape::strideDescendingCAscendingF(colShapeBuffer)) { + auto func = PRAGMA_THREADS_FOR_2D { + for (auto b = start_x; b < stop_x; b++) { + for (auto c = start_y; c < stop_y; c++) { + for (int kRow = 0; kRow < kH; ++kRow) { + for (int kCol = 0; kCol < kW; ++kCol) { + for (int colH = 0; colH < oH; ++colH) { + for (int colW = 0; colW < oW; ++colW) { + int imRow = (-pH + kRow * dH) + colH * sH; + int imCol = (-pW + kCol * dW) + colW * sW; + + auto col = colBuff + b * colStride0 + c * colStride1 + + kRow * colStride2 + kCol * colStride3 + + colH * colStride4 + colW * colStride5; + + if (static_cast(imRow) >= + static_cast(iH) || + static_cast(imCol) >= static_cast(iW)) + *col = zeroPadVal; + else { + auto im = imBuff + b * imStride0 + c * imStride1 + + imRow * imStride2 + imCol * imStride3; + *col = *im; + } } + } } - }; - - samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1); - } - else { - - auto func = PRAGMA_THREADS_FOR_2D { - T *col; - T const* im; - int imRow, imCol; - - for (auto b = start_x; b < stop_x; b += inc_x) { - for (auto colH = start_y; colH < stop_y; colH += inc_y) { - for (int colW = 0; colW < oW; ++colW) { - for (int c = 0; c < iC; ++c) { - for (int kRow = 0; kRow < kH; ++kRow) { - for (int kCol = 0; kCol < kW; ++kCol) { - - imRow = (-pH + kRow * dH) + colH * sH; - imCol = (-pW + kCol * dW) + colW * sW; - - col = colBuff + b * colStride0 + c * colStride1 + kRow * colStride2 + kCol * colStride3 + colH * colStride4 + colW * colStride5; - - if (static_cast(imRow) >= static_cast(iH) || static_cast(imCol) >= static_cast(iW)) - *col = zeroPadVal; - else { - im = imBuff + b * imStride0 + c * imStride1 + imRow * imStride2 + imCol * imStride3; - *col = *im; - } - } - } - } - } + } + } + } + }; + + samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1); + } else { + auto func = PRAGMA_THREADS_FOR_2D { + T* col; + T const* im; + int imRow, imCol; + + for (auto b = start_x; b < stop_x; b += inc_x) { + for (auto colH = start_y; colH < stop_y; colH += inc_y) { + for (int colW = 0; colW < oW; ++colW) { + for (int c = 0; c < iC; ++c) { + for (int kRow = 0; kRow < kH; ++kRow) { + for (int kCol = 0; kCol < kW; ++kCol) { + imRow = (-pH + kRow * dH) + colH * sH; + imCol = (-pW + kCol * dW) + colW * sW; + + col = colBuff + b * colStride0 + c * colStride1 + + kRow * colStride2 + kCol * colStride3 + + colH * colStride4 + colW * colStride5; + + if (static_cast(imRow) >= + static_cast(iH) || + static_cast(imCol) >= static_cast(iW)) + *col = zeroPadVal; + else { + im = imBuff + b * imStride0 + c * imStride1 + + imRow * imStride2 + imCol * imStride3; + *col = *im; + } } + } } - }; + } + } + } + }; - samediff::Threads::parallel_for(func, 0, bS, 1, 0, oH, 1); - } + samediff::Threads::parallel_for(func, 0, bS, 1, 0, oH, 1); + } } - -void im2col(sd::LaunchContext & context, const NDArray& im, NDArray& col, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const NDArray& arrZeroPadVal) { - BUILD_SINGLE_SELECTOR(im.dataType(), im2col_, (context, im, col, kH, kW, sH, sW, pH, pW, dH, dW, arrZeroPadVal), FLOAT_TYPES); +void im2col(sd::LaunchContext& context, const NDArray& im, NDArray& col, + const int kH, const int kW, const int sH, const int sW, + const int pH, const int pW, const int dH, const int dW, + const NDArray& arrZeroPadVal) { + BUILD_SINGLE_SELECTOR( + im.dataType(), im2col_, + (context, im, col, kH, kW, sH, sW, pH, pW, dH, dW, arrZeroPadVal), + FLOAT_TYPES); } - -} -} -} \ No newline at end of file +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/image_draw_bounding_boxes.cpp b/libnd4j/include/ops/declarable/helpers/cpu/image_draw_bounding_boxes.cpp index ee4faafb024f..d2de40cc7dc1 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/image_draw_bounding_boxes.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/image_draw_bounding_boxes.cpp @@ -31,128 +31,139 @@ limitations under the License. // // @author sgazeos@gmail.com // -#include #include #include +#include namespace sd { namespace ops { namespace helpers { - typedef std::vector> ColorTable_t; - static ColorTable_t DefaultColorTable(int depth) { - std::vector> colorTable; - colorTable.emplace_back(std::vector({1, 1, 0, 1})); // 0: yellow - colorTable.emplace_back(std::vector({0, 0, 1, 1})); // 1: blue - colorTable.emplace_back(std::vector({1, 0, 0, 1})); // 2: red - colorTable.emplace_back(std::vector({0, 1, 0, 1})); // 3: lime - colorTable.emplace_back(std::vector({0.5, 0, 0.5, 1})); // 4: purple - colorTable.emplace_back(std::vector({0.5, 0.5, 0, 1})); // 5: olive - colorTable.emplace_back(std::vector({0.5, 0, 0, 1})); // 6: maroon - colorTable.emplace_back(std::vector({0, 0, 0.5, 1})); // 7: navy blue - colorTable.emplace_back(std::vector({0, 1, 1, 1})); // 8: aqua - colorTable.emplace_back(std::vector({1, 0, 1, 1})); // 9: fuchsia +typedef std::vector> ColorTable_t; +static ColorTable_t DefaultColorTable(int depth) { + std::vector> colorTable; + colorTable.emplace_back(std::vector({1, 1, 0, 1})); // 0: yellow + colorTable.emplace_back(std::vector({0, 0, 1, 1})); // 1: blue + colorTable.emplace_back(std::vector({1, 0, 0, 1})); // 2: red + colorTable.emplace_back(std::vector({0, 1, 0, 1})); // 3: lime + colorTable.emplace_back(std::vector({0.5, 0, 0.5, 1})); // 4: purple + colorTable.emplace_back(std::vector({0.5, 0.5, 0, 1})); // 5: olive + colorTable.emplace_back(std::vector({0.5, 0, 0, 1})); // 6: maroon + colorTable.emplace_back(std::vector({0, 0, 0.5, 1})); // 7: navy blue + colorTable.emplace_back(std::vector({0, 1, 1, 1})); // 8: aqua + colorTable.emplace_back(std::vector({1, 0, 1, 1})); // 9: fuchsia - if (depth == 1) { - for (Nd4jLong i = 0; i < colorTable.size(); i++) { - colorTable[i][0] = 1; - } - } - return colorTable; + if (depth == 1) { + for (Nd4jLong i = 0; i < colorTable.size(); i++) { + colorTable[i][0] = 1; } + } + return colorTable; +} - void drawBoundingBoxesFunctor(sd::LaunchContext * context, NDArray* images, NDArray* boxes, NDArray* colors, NDArray* output) { - // images - batch of 3D images with BW (last dim = 1), RGB (last dim = 3) or RGBA (last dim = 4) channel set - // boxes - batch of 2D bounds with last dim (y_start, x_start, y_end, x_end) to compute i and j as - // floor((height - 1 ) * y_start) => rowStart, floor((height - 1) * y_end) => rowEnd - // floor((width - 1 ) * x_start) => colStart, floor((width - 1) * x_end) => colEnd - // height = images->sizeAt(1), width = images->sizeAt(2) - // colors - colors for each box given - // set up color for each box as frame - auto batchSize = images->sizeAt(0); - auto boxSize = boxes->sizeAt(0); - auto height = images->sizeAt(1); - auto width = images->sizeAt(2); - auto channels = images->sizeAt(3); - //auto imageList = images->allTensorsAlongDimension({1, 2, 3}); // split images by batch -// auto boxList = boxes->allTensorsAlongDimension({1, 2}); // split boxes by batch - //auto colorSet = colors->allTensorsAlongDimension({0}); - output->assign(images); // fill up all output with input images, then fill up boxes - ColorTable_t colorTable; - if (colors) { - for (auto i = 0; i < colors->sizeAt(0); i++) { - std::vector colorValue(4); - for (auto j = 0; j < 4; j++) { - colorValue[j] = j < colors->sizeAt(1) ? colors->e(i, j) : 1.f; - } - colorTable.emplace_back(colorValue); - } - } - if (colorTable.empty()) - colorTable = DefaultColorTable(channels); - auto func = PRAGMA_THREADS_FOR { - for (auto batch = start; batch < stop; ++batch) { // loop by batch - const Nd4jLong numBoxes = boxes->sizeAt(1); - for (auto boxIndex = 0; boxIndex < numBoxes; ++boxIndex) { - auto colorIndex = boxIndex % colorTable.size(); - auto rowStart = Nd4jLong((height - 1) * boxes->t(batch, boxIndex, 0)); - auto rowStartBound = sd::math::nd4j_max(Nd4jLong(0), rowStart); - auto rowEnd = Nd4jLong((height - 1) * boxes->t(batch, boxIndex, 2)); - auto rowEndBound = sd::math::nd4j_min(Nd4jLong(height - 1), rowEnd); - auto colStart = Nd4jLong((width - 1) * boxes->t(batch, boxIndex, 1)); - auto colStartBound = sd::math::nd4j_max(Nd4jLong(0), colStart); - auto colEnd = Nd4jLong((width - 1) * boxes->t(batch, boxIndex, 3)); - auto colEndBound = sd::math::nd4j_min(Nd4jLong(width - 1), colEnd); - - if (rowStart > rowEnd || colStart > colEnd) { - nd4j_debug( - "helpers::drawBoundingBoxesFunctor: Bounding box (%lld, %lld, %lld, %lld) is inverted " - "and will not be drawn\n", rowStart, colStart, rowEnd, colEnd); - continue; - } - if (rowStart >= height || rowEnd < 0 || colStart >= width || - colEnd < 0) { - nd4j_debug( - "helpers::drawBoundingBoxesFunctor: Bounding box (%lld, %lld, %lld, %lld) is completely " - "outside the image and not be drawn\n ", rowStart, colStart, rowEnd, colEnd); - continue; - } +void drawBoundingBoxesFunctor(sd::LaunchContext* context, NDArray* images, + NDArray* boxes, NDArray* colors, + NDArray* output) { + // images - batch of 3D images with BW (last dim = 1), RGB (last dim = 3) or + // RGBA (last dim = 4) channel set boxes - batch of 2D bounds with last dim + // (y_start, x_start, y_end, x_end) to compute i and j as floor((height - 1 ) + // * y_start) => rowStart, floor((height - 1) * y_end) => rowEnd floor((width + // - 1 ) * x_start) => colStart, floor((width - 1) * x_end) => colEnd height = + // images->sizeAt(1), width = images->sizeAt(2) colors - colors for each box + // given set up color for each box as frame + auto batchSize = images->sizeAt(0); + auto boxSize = boxes->sizeAt(0); + auto height = images->sizeAt(1); + auto width = images->sizeAt(2); + auto channels = images->sizeAt(3); + // auto imageList = images->allTensorsAlongDimension({1, 2, 3}); // split + // images by batch + // auto boxList = boxes->allTensorsAlongDimension({1, 2}); // split + // boxes by batch + // auto colorSet = colors->allTensorsAlongDimension({0}); + output->assign( + images); // fill up all output with input images, then fill up boxes + ColorTable_t colorTable; + if (colors) { + for (auto i = 0; i < colors->sizeAt(0); i++) { + std::vector colorValue(4); + for (auto j = 0; j < 4; j++) { + colorValue[j] = j < colors->sizeAt(1) ? colors->e(i, j) : 1.f; + } + colorTable.emplace_back(colorValue); + } + } + if (colorTable.empty()) colorTable = DefaultColorTable(channels); + auto func = PRAGMA_THREADS_FOR { + for (auto batch = start; batch < stop; ++batch) { // loop by batch + const Nd4jLong numBoxes = boxes->sizeAt(1); + for (auto boxIndex = 0; boxIndex < numBoxes; ++boxIndex) { + auto colorIndex = boxIndex % colorTable.size(); + auto rowStart = + Nd4jLong((height - 1) * boxes->t(batch, boxIndex, 0)); + auto rowStartBound = sd::math::nd4j_max(Nd4jLong(0), rowStart); + auto rowEnd = + Nd4jLong((height - 1) * boxes->t(batch, boxIndex, 2)); + auto rowEndBound = sd::math::nd4j_min(Nd4jLong(height - 1), rowEnd); + auto colStart = + Nd4jLong((width - 1) * boxes->t(batch, boxIndex, 1)); + auto colStartBound = sd::math::nd4j_max(Nd4jLong(0), colStart); + auto colEnd = + Nd4jLong((width - 1) * boxes->t(batch, boxIndex, 3)); + auto colEndBound = sd::math::nd4j_min(Nd4jLong(width - 1), colEnd); - // Draw upper line - if (rowStart >= 0) { - for (auto j = colStartBound; j <= colEndBound; ++j) - for (auto c = 0; c < channels; c++) { - output->p(batch, rowStart, j, c, colorTable[colorIndex][c]); - } - } - // Draw bottom line. - if (rowEnd < height) { - for (auto j = colStartBound; j <= colEndBound; ++j) - for (auto c = 0; c < channels; c++) { - output->p(batch, rowEnd, j, c, colorTable[colorIndex][c]); - } - } + if (rowStart > rowEnd || colStart > colEnd) { + nd4j_debug( + "helpers::drawBoundingBoxesFunctor: Bounding box (%lld, %lld, " + "%lld, %lld) is inverted " + "and will not be drawn\n", + rowStart, colStart, rowEnd, colEnd); + continue; + } + if (rowStart >= height || rowEnd < 0 || colStart >= width || + colEnd < 0) { + nd4j_debug( + "helpers::drawBoundingBoxesFunctor: Bounding box (%lld, %lld, " + "%lld, %lld) is completely " + "outside the image and not be drawn\n ", + rowStart, colStart, rowEnd, colEnd); + continue; + } - // Draw left line. - if (colStart >= 0) { - for (auto i = rowStartBound; i <= rowEndBound; ++i) - for (auto c = 0; c < channels; c++) { - output->p(batch, i, colStart, c, colorTable[colorIndex][c]); - } - } - // Draw right line. - if (colEnd < width) { - for (auto i = rowStartBound; i <= rowEndBound; ++i) - for (auto c = 0; c < channels; c++) { - output->p(batch, i, colEnd, c, colorTable[colorIndex][c]); - } - } - } + // Draw upper line + if (rowStart >= 0) { + for (auto j = colStartBound; j <= colEndBound; ++j) + for (auto c = 0; c < channels; c++) { + output->p(batch, rowStart, j, c, colorTable[colorIndex][c]); } - }; - samediff::Threads::parallel_tad(func, 0, batchSize); + } + // Draw bottom line. + if (rowEnd < height) { + for (auto j = colStartBound; j <= colEndBound; ++j) + for (auto c = 0; c < channels; c++) { + output->p(batch, rowEnd, j, c, colorTable[colorIndex][c]); + } + } + // Draw left line. + if (colStart >= 0) { + for (auto i = rowStartBound; i <= rowEndBound; ++i) + for (auto c = 0; c < channels; c++) { + output->p(batch, i, colStart, c, colorTable[colorIndex][c]); + } + } + // Draw right line. + if (colEnd < width) { + for (auto i = rowStartBound; i <= rowEndBound; ++i) + for (auto c = 0; c < channels; c++) { + output->p(batch, i, colEnd, c, colorTable[colorIndex][c]); + } + } + } } - -} -} + }; + samediff::Threads::parallel_tad(func, 0, batchSize); } + +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/image_resize.cpp b/libnd4j/include/ops/declarable/helpers/cpu/image_resize.cpp index 7206b03e5e1d..f086339c368c 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/image_resize.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/image_resize.cpp @@ -33,964 +33,1050 @@ limitations under the License. // @author sgazeos@gmail.com // -#include #include #include +#include + #include "../cross.h" namespace sd { namespace ops { namespace helpers { - struct BilinearInterpolationData { - Nd4jLong _bottomIndex; // Lower source index used in the interpolation - Nd4jLong _topIndex; // Upper source index used in the interpolation - // 1D linear iterpolation scale (see: - // https://en.wikipedia.org/wiki/Bilinear_interpolation) - double _interpolarValue; - }; - // calculateResizeScale determines the float scaling factor. - inline float calculateResizeScale(Nd4jLong inSize, Nd4jLong outSize, - bool alignCorners) { - return (alignCorners && outSize > 1) - ? (inSize - 1) / static_cast(outSize - 1) - : inSize / static_cast(outSize); - } - - template - struct ImageResizerStateCommon { - explicit ImageResizerStateCommon(bool alignCorners, bool halfPixelCenters) - : _alignCorners(alignCorners), - _halfPixelCenters(halfPixelCenters) {} - - // ValidateAndCalculateOutputSize checks the bounds on the input tensors - // and requested size, sets up some of the resizing state such as the - // heightScale and widthScale, and calculates the output size. - // If any of these operations fails, it sets an error status in - // the context, which the caller must check. - int validateAndCalculateOutputSize(NDArray const* input, int const width, int const height) { - // - batchSize = input->sizeAt(0);//.dim_size(0); - outHeight = height; - outWidth = width; //internal::SubtleMustCopy(Svec(1)); - inHeight = static_cast(input->sizeAt(1)); - inWidth = static_cast(input->sizeAt(2)); - channels = input->sizeAt(3); //.dim_size(3); - heightScale = calculateResizeScale(inHeight, outHeight, _alignCorners); - widthScale = calculateResizeScale(inWidth, outWidth, _alignCorners); - - // Guard against overflows - if (ceilf((outHeight - 1) * heightScale) > static_cast(DataTypeUtils::max())) { - nd4j_printf("resize_bicubic: Upper overflow occurs for resize height (%f)\n", ceilf((outHeight - 1) * heightScale)); - return Status::CODE(ND4J_STATUS_BAD_INPUT, "resize_bicubic: Upper overflow occurs for resize height"); - } - if (ceilf((outWidth - 1) * heightScale) > static_cast(DataTypeUtils::max())) { - nd4j_printf("resize_bicubic: Upper overflow occurs for resize height (%f)\n", ceilf((outHeight - 1) * heightScale)); - return Status::CODE(ND4J_STATUS_BAD_INPUT, "resize_bicubic: Upper overflow occurs for resize width"); - } - - return Status::OK(); - } - - // Calculates all the required variables, and allocates the output. - int validateAndCreateOutput(NDArray const* input, int const width, int const height) { - return validateAndCalculateOutputSize(input, width, height); - } - - I batchSize; - I outHeight; - I outWidth; - I inHeight; - I inWidth; - I channels; - F heightScale; - F widthScale; - NDArray* output = nullptr; - - private: - bool _alignCorners; - bool _halfPixelCenters; - }; +struct BilinearInterpolationData { + Nd4jLong _bottomIndex; // Lower source index used in the interpolation + Nd4jLong _topIndex; // Upper source index used in the interpolation + // 1D linear iterpolation scale (see: + // https://en.wikipedia.org/wiki/Bilinear_interpolation) + double _interpolarValue; +}; +// calculateResizeScale determines the float scaling factor. +inline float calculateResizeScale(Nd4jLong inSize, Nd4jLong outSize, + bool alignCorners) { + return (alignCorners && outSize > 1) + ? (inSize - 1) / static_cast(outSize - 1) + : inSize / static_cast(outSize); +} - typedef ImageResizerStateCommon ImageResizerState; +template +struct ImageResizerStateCommon { + explicit ImageResizerStateCommon(bool alignCorners, bool halfPixelCenters) + : _alignCorners(alignCorners), _halfPixelCenters(halfPixelCenters) {} + + // ValidateAndCalculateOutputSize checks the bounds on the input tensors + // and requested size, sets up some of the resizing state such as the + // heightScale and widthScale, and calculates the output size. + // If any of these operations fails, it sets an error status in + // the context, which the caller must check. + int validateAndCalculateOutputSize(NDArray const* input, int const width, + int const height) { + // + batchSize = input->sizeAt(0); //.dim_size(0); + outHeight = height; + outWidth = width; // internal::SubtleMustCopy(Svec(1)); + inHeight = static_cast(input->sizeAt(1)); + inWidth = static_cast(input->sizeAt(2)); + channels = input->sizeAt(3); //.dim_size(3); + heightScale = calculateResizeScale(inHeight, outHeight, _alignCorners); + widthScale = calculateResizeScale(inWidth, outWidth, _alignCorners); + + // Guard against overflows + if (ceilf((outHeight - 1) * heightScale) > + static_cast(DataTypeUtils::max())) { + nd4j_printf( + "resize_bicubic: Upper overflow occurs for resize height (%f)\n", + ceilf((outHeight - 1) * heightScale)); + return Status::CODE( + ND4J_STATUS_BAD_INPUT, + "resize_bicubic: Upper overflow occurs for resize height"); + } + if (ceilf((outWidth - 1) * heightScale) > + static_cast(DataTypeUtils::max())) { + nd4j_printf( + "resize_bicubic: Upper overflow occurs for resize height (%f)\n", + ceilf((outHeight - 1) * heightScale)); + return Status::CODE( + ND4J_STATUS_BAD_INPUT, + "resize_bicubic: Upper overflow occurs for resize width"); + } - // Half pixel scaler scales assuming that the pixel centers are at 0.5, i.e. the + return Status::OK(); + } + + // Calculates all the required variables, and allocates the output. + int validateAndCreateOutput(NDArray const* input, int const width, + int const height) { + return validateAndCalculateOutputSize(input, width, height); + } + + I batchSize; + I outHeight; + I outWidth; + I inHeight; + I inWidth; + I channels; + F heightScale; + F widthScale; + NDArray* output = nullptr; + + private: + bool _alignCorners; + bool _halfPixelCenters; +}; + +typedef ImageResizerStateCommon ImageResizerState; + +// Half pixel scaler scales assuming that the pixel centers are at 0.5, i.e. the // floating point coordinates of the top,left pixel is 0.5,0.5. - struct HalfPixelScaler { - HalfPixelScaler(){}; - inline float operator()(const int x, const float scale) const { - // Note that we subtract 0.5 from the return value, as the existing bilinear - // sampling code etc assumes pixels are in the old coordinate system. - return (static_cast(x) + 0.5f) * scale - 0.5f; - } - }; - - // Half pixel scaler scales assuming that the pixel centers are at 0.5, i.e. the +struct HalfPixelScaler { + HalfPixelScaler(){}; + inline float operator()(const int x, const float scale) const { + // Note that we subtract 0.5 from the return value, as the existing bilinear + // sampling code etc assumes pixels are in the old coordinate system. + return (static_cast(x) + 0.5f) * scale - 0.5f; + } +}; + +// Half pixel scaler scales assuming that the pixel centers are at 0.5, i.e. the // floating point coordinates of the top,left pixel is 0.5,0.5. - struct HalfPixelScalerNN { - HalfPixelScalerNN(){}; - inline float operator()(const int x, const float scale) const { - // Note that we subtract 0.5 from the return value, as the existing bilinear - // sampling code etc assumes pixels are in the old coordinate system. - return (static_cast(x) + 0.5f) * scale; - } - }; +struct HalfPixelScalerNN { + HalfPixelScalerNN(){}; + inline float operator()(const int x, const float scale) const { + // Note that we subtract 0.5 from the return value, as the existing bilinear + // sampling code etc assumes pixels are in the old coordinate system. + return (static_cast(x) + 0.5f) * scale; + } +}; // Older incorrect scaling method that causes all resizes to have a slight // translation leading to inconsistent results. For example, a flip then a // resize gives different results then a resize then a flip. - struct LegacyScaler { - LegacyScaler(){}; - inline float operator()(const int x, const float scale) const { - return static_cast(x) * scale; - } - }; - - struct WeightsAndIndices { - float _weight0; - float _weight1; - float _weight2; - float _weight3; - Nd4jLong _index0; - Nd4jLong _index1; - Nd4jLong _index2; - Nd4jLong _index3; - - int _advance; // advance value. - }; - - template - inline void computeInterpolationWeights(const Scaler scaler, Nd4jLong outSize, - Nd4jLong inSize, - double scale, - BilinearInterpolationData *interpolationData) { - interpolationData[outSize]._bottomIndex = 0; - interpolationData[outSize]._topIndex = 0; - - auto func = PRAGMA_THREADS_FOR { - for (auto k = start; k < stop; k++) { - auto i = (outSize - k - 1); - double const in = scaler(i, scale); - double const in_f = sd::math::nd4j_floor(in); - double const in_c = sd::math::nd4j_ceil(in); - interpolationData[i]._bottomIndex = sd::math::nd4j_max(static_cast(in_f), (Nd4jLong)0LL);//static_cast(in); - interpolationData[i]._topIndex = sd::math::nd4j_min(static_cast(in_c), inSize - 1); - interpolationData[i]._interpolarValue = in - in_f; - } - }; - samediff::Threads::parallel_for(func, 0, outSize); +struct LegacyScaler { + LegacyScaler(){}; + inline float operator()(const int x, const float scale) const { + return static_cast(x) * scale; + } +}; + +struct WeightsAndIndices { + float _weight0; + float _weight1; + float _weight2; + float _weight3; + Nd4jLong _index0; + Nd4jLong _index1; + Nd4jLong _index2; + Nd4jLong _index3; + + int _advance; // advance value. +}; + +template +inline void computeInterpolationWeights( + const Scaler scaler, Nd4jLong outSize, Nd4jLong inSize, double scale, + BilinearInterpolationData* interpolationData) { + interpolationData[outSize]._bottomIndex = 0; + interpolationData[outSize]._topIndex = 0; + + auto func = PRAGMA_THREADS_FOR { + for (auto k = start; k < stop; k++) { + auto i = (outSize - k - 1); + double const in = scaler(i, scale); + double const in_f = sd::math::nd4j_floor(in); + double const in_c = sd::math::nd4j_ceil(in); + interpolationData[i]._bottomIndex = + sd::math::nd4j_max(static_cast(in_f), + (Nd4jLong)0LL); // static_cast(in); + interpolationData[i]._topIndex = + sd::math::nd4j_min(static_cast(in_c), inSize - 1); + interpolationData[i]._interpolarValue = in - in_f; } + }; + samediff::Threads::parallel_for(func, 0, outSize); +} /** * Computes the bilinear interpolation from the appropriate 4 float points * and the linear interpolation weights. */ // static void -// resizeImage(NDArray const *images, Nd4jLong batchSize, Nd4jLong inHeight, Nd4jLong inWidth, Nd4jLong outHeight, +// resizeImage(NDArray const *images, Nd4jLong batchSize, Nd4jLong inHeight, +// Nd4jLong inWidth, Nd4jLong outHeight, // Nd4jLong outWidth, Nd4jLong channels, // std::vector const& xs, // std::vector const& ys, // NDArray *output); - template - static void - resizeImage_(T const* pInputBuf, Nd4jLong batchSize, Nd4jLong inHeight, Nd4jLong inWidth, Nd4jLong outHeight, - Nd4jLong outWidth, Nd4jLong channels, - std::vector const &xs, - std::vector const &ys, - Z* pOutputBuf) { - - Nd4jLong inRowSize = inWidth * channels; - Nd4jLong inBatchNumValues = inHeight * inRowSize; - Nd4jLong outRowSize = outWidth * channels; - -// T const *pInputBuf = images->getDataBuffer()->primaryAsT(); // this works only with 'c' direction - BilinearInterpolationData const* xsPtr = xs.data(); - -// T* pOutputBuf = output->dataBuffer()->primaryAsT(); - auto computeBilinear = [](double topLeft, double topRight, - double bottomLeft, double bottomRight, - double xVal, double yVal) { - double top = topLeft + (topRight - topLeft) * xVal; - double bottom = bottomLeft + (bottomRight - bottomLeft) * xVal; - return top + (bottom - top) * yVal; - }; - - auto func = PRAGMA_THREADS_FOR { - for (auto batch = start; batch < stop; ++batch) { - auto pInput = pInputBuf + batch * inBatchNumValues; - for (Nd4jLong y = 0; y < outHeight; ++y) { - auto pOutput = pOutputBuf + (batch * outHeight + y) * outRowSize; - const T* ysInputLowerPtr = pInput + ys[y]._bottomIndex * inRowSize; - const T* ysInputUpperPtr = pInput + ys[y]._topIndex * inRowSize; - double yVal = ys[y]._interpolarValue; - for (Nd4jLong x = 0; x < outWidth; ++x) { - auto xsBottom = xsPtr[x]._bottomIndex; - auto xsTop = xsPtr[x]._topIndex; - auto xVal = xsPtr[x]._interpolarValue; - for (Nd4jLong c = 0; c < channels; ++c) { - double topLeft(ysInputLowerPtr[xsBottom + c]); - double topRight(ysInputLowerPtr[xsTop + c]); - double bottomLeft(ysInputUpperPtr[xsBottom + c]); - double bottomRight(ysInputUpperPtr[xsTop + c]); - pOutput[x * channels + c] = computeBilinear(topLeft, topRight, bottomLeft, bottomRight, - xVal, yVal); - } - } - } - } - }; - samediff::Threads::parallel_tad(func, 0, batchSize); +template +static void resizeImage_(T const* pInputBuf, Nd4jLong batchSize, + Nd4jLong inHeight, Nd4jLong inWidth, + Nd4jLong outHeight, Nd4jLong outWidth, + Nd4jLong channels, + std::vector const& xs, + std::vector const& ys, + Z* pOutputBuf) { + Nd4jLong inRowSize = inWidth * channels; + Nd4jLong inBatchNumValues = inHeight * inRowSize; + Nd4jLong outRowSize = outWidth * channels; + + // T const *pInputBuf = images->getDataBuffer()->primaryAsT(); // + // this works only with 'c' direction + BilinearInterpolationData const* xsPtr = xs.data(); + + // T* pOutputBuf = output->dataBuffer()->primaryAsT(); + auto computeBilinear = [](double topLeft, double topRight, double bottomLeft, + double bottomRight, double xVal, double yVal) { + double top = topLeft + (topRight - topLeft) * xVal; + double bottom = bottomLeft + (bottomRight - bottomLeft) * xVal; + return top + (bottom - top) * yVal; + }; + + auto func = PRAGMA_THREADS_FOR { + for (auto batch = start; batch < stop; ++batch) { + auto pInput = pInputBuf + batch * inBatchNumValues; + for (Nd4jLong y = 0; y < outHeight; ++y) { + auto pOutput = pOutputBuf + (batch * outHeight + y) * outRowSize; + const T* ysInputLowerPtr = pInput + ys[y]._bottomIndex * inRowSize; + const T* ysInputUpperPtr = pInput + ys[y]._topIndex * inRowSize; + double yVal = ys[y]._interpolarValue; + for (Nd4jLong x = 0; x < outWidth; ++x) { + auto xsBottom = xsPtr[x]._bottomIndex; + auto xsTop = xsPtr[x]._topIndex; + auto xVal = xsPtr[x]._interpolarValue; + for (Nd4jLong c = 0; c < channels; ++c) { + double topLeft(ysInputLowerPtr[xsBottom + c]); + double topRight(ysInputLowerPtr[xsTop + c]); + double bottomLeft(ysInputUpperPtr[xsBottom + c]); + double bottomRight(ysInputUpperPtr[xsTop + c]); + pOutput[x * channels + c] = computeBilinear( + topLeft, topRight, bottomLeft, bottomRight, xVal, yVal); + } + } + } } + }; + samediff::Threads::parallel_tad(func, 0, batchSize); +} - template - static int resizeBilinearFunctor_(NDArray const *images, int const width, int const height, bool const alignCorners, - bool const halfPixelCenter, NDArray *output) { - ImageResizerState st(alignCorners, halfPixelCenter); - st.validateAndCalculateOutputSize(images, width, height); - - const Nd4jLong batchSize = images->sizeAt(0); - const Nd4jLong inHeight = images->sizeAt(1); - const Nd4jLong inWidth = images->sizeAt(2); - const Nd4jLong channels = images->sizeAt(3); - - const Nd4jLong outHeight = output->sizeAt(1); - const Nd4jLong outWidth = output->sizeAt(2); +template +static int resizeBilinearFunctor_(NDArray const* images, int const width, + int const height, bool const alignCorners, + bool const halfPixelCenter, NDArray* output) { + ImageResizerState st(alignCorners, halfPixelCenter); + st.validateAndCalculateOutputSize(images, width, height); + + const Nd4jLong batchSize = images->sizeAt(0); + const Nd4jLong inHeight = images->sizeAt(1); + const Nd4jLong inWidth = images->sizeAt(2); + const Nd4jLong channels = images->sizeAt(3); + + const Nd4jLong outHeight = output->sizeAt(1); + const Nd4jLong outWidth = output->sizeAt(2); + + // Handle no-op resizes efficiently. + if (outHeight == inHeight && outWidth == inWidth) { + output->assign(images); + return Status::OK(); + } + + std::vector ys(outHeight + 1); + std::vector xs(outWidth + 1); + if (halfPixelCenter) { + computeInterpolationWeights(HalfPixelScaler(), outHeight, inHeight, + st.heightScale, ys.data()); + computeInterpolationWeights(HalfPixelScaler(), outWidth, inWidth, + st.widthScale, xs.data()); + + } else { + // Compute the cached interpolation weights on the x and y dimensions. + computeInterpolationWeights(LegacyScaler(), outHeight, inHeight, + st.heightScale, ys.data()); + computeInterpolationWeights(LegacyScaler(), outWidth, inWidth, + st.widthScale, xs.data()); + } + int xsSize = xs.size(); + // Scale x interpolation weights to avoid a multiplication during iteration. + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + xs[i]._bottomIndex *= channels; + xs[i]._topIndex *= channels; + } + }; + samediff::Threads::parallel_for(func, 0, xsSize); - // Handle no-op resizes efficiently. - if (outHeight == inHeight && outWidth == inWidth) { - output->assign(images); - return Status::OK(); - } + resizeImage_(images->getDataBuffer()->primaryAsT(), batchSize, + inHeight, inWidth, outHeight, outWidth, channels, xs, ys, + output->dataBuffer()->primaryAsT()); + return Status::OK(); +} - std::vector ys(outHeight + 1); - std::vector xs(outWidth + 1); +template +void resizeNeighbor(ImageResizerState const& st, NDArray const* images, + bool const alignCorners, bool const halfPixelCenter, + NDArray* output) { + const Nd4jLong batchSize = st.batchSize; + const Nd4jLong inHeight = st.inHeight; + const Nd4jLong inWidth = st.inWidth; + const Nd4jLong channels = st.channels; + + const Nd4jLong outHeight = st.outHeight; + const Nd4jLong outWidth = st.outWidth; + Scaler scaler; + + auto func = PRAGMA_THREADS_FOR_2D { + for (auto b = start_x; b < stop_x; b += inc_x) { + for (auto y = start_y; y < stop_y; y += inc_y) { + auto posY = + alignCorners + ? static_cast( + sd::math::p_round(scaler(y, st.heightScale))) + : static_cast( + sd::math::p_floor(scaler(y, st.heightScale))); + Nd4jLong inY = sd::math::nd4j_min(posY, inHeight - 1); if (halfPixelCenter) { - computeInterpolationWeights(HalfPixelScaler(), outHeight, inHeight, st.heightScale, - ys.data()); - computeInterpolationWeights(HalfPixelScaler(), outWidth, inWidth, st.widthScale, xs.data()); - + inY = sd::math::nd4j_max(0LL, inY); } - else { - // Compute the cached interpolation weights on the x and y dimensions. - computeInterpolationWeights(LegacyScaler(), outHeight, inHeight, st.heightScale, - ys.data()); - computeInterpolationWeights(LegacyScaler(), outWidth, inWidth, st.widthScale, xs.data()); + for (Nd4jLong x = 0; x < outWidth; ++x) { + auto posX = + alignCorners + ? static_cast( + sd::math::p_round(scaler(x, st.widthScale))) + : static_cast( + sd::math::p_floor(scaler(x, st.widthScale))); + Nd4jLong inX = sd::math::nd4j_min(posX, inWidth - 1); + if (halfPixelCenter) { + inX = sd::math::nd4j_max(0LL, inX); + } + // copy pixel over all channels + for (Nd4jLong e = 0; e < channels; e++) + output->r(b, y, x, e) = images->t(b, inY, inX, e); } - int xsSize = xs.size(); - // Scale x interpolation weights to avoid a multiplication during iteration. - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - xs[i]._bottomIndex *= channels; - xs[i]._topIndex *= channels; - } - }; - samediff::Threads::parallel_for(func, 0, xsSize); - - resizeImage_(images->getDataBuffer()->primaryAsT(), batchSize, inHeight, inWidth, outHeight, outWidth, channels, xs, ys, output->dataBuffer()->primaryAsT()); - return Status::OK(); + } } + }; + samediff::Threads::parallel_for(func, 0, batchSize, 1, 0, outHeight, 1); +} - template - void resizeNeighbor(ImageResizerState const& st, NDArray const *images, bool const alignCorners, bool const halfPixelCenter, NDArray *output) { - const Nd4jLong batchSize = st.batchSize; - const Nd4jLong inHeight = st.inHeight; - const Nd4jLong inWidth = st.inWidth; - const Nd4jLong channels = st.channels; - - const Nd4jLong outHeight = st.outHeight; - const Nd4jLong outWidth = st.outWidth; - Scaler scaler; - - auto func = PRAGMA_THREADS_FOR_2D { - for (auto b = start_x; b < stop_x; b += inc_x) { - for (auto y = start_y; y < stop_y; y += inc_y) { - auto posY = alignCorners ? static_cast(sd::math::p_round(scaler(y, st.heightScale))) : static_cast(sd::math::p_floor(scaler(y, st.heightScale))); - Nd4jLong inY = sd::math::nd4j_min(posY, inHeight - 1); - if (halfPixelCenter) { - inY = sd::math::nd4j_max(0LL, inY); - } - for (Nd4jLong x = 0; x < outWidth; ++x) { - auto posX = alignCorners ? static_cast(sd::math::p_round(scaler(x, st.widthScale))) : static_cast(sd::math::p_floor(scaler(x, st.widthScale))); - Nd4jLong inX = sd::math::nd4j_min(posX,inWidth - 1); - if (halfPixelCenter) { - inX = sd::math::nd4j_max(0LL, inX); - } - // copy pixel over all channels - for (Nd4jLong e = 0; e < channels; e++) - output->r(b, y, x, e) = images->t(b, inY, inX, e); - } - } - } - }; - samediff::Threads::parallel_for(func, 0, batchSize, 1, 0, outHeight, 1); - } - - template - int resizeNeighborFunctor_(NDArray const *images, int const width, int const height, bool const alignCorners, bool const halfPixelCenter, NDArray *output) { - ImageResizerState st(alignCorners, halfPixelCenter); - st.validateAndCalculateOutputSize(images, width, height); - - // Handle no-op resizes efficiently. - if (output->sizeAt(1) == images->sizeAt(1) && output->sizeAt(2) == images->sizeAt(2)) { - output->assign(images); - return Status::OK(); - } - - if (halfPixelCenter) - resizeNeighbor(st, images, alignCorners, true, output); - else - resizeNeighbor(st, images, alignCorners, false, output); - - return Status::OK(); - } +template +int resizeNeighborFunctor_(NDArray const* images, int const width, + int const height, bool const alignCorners, + bool const halfPixelCenter, NDArray* output) { + ImageResizerState st(alignCorners, halfPixelCenter); + st.validateAndCalculateOutputSize(images, width, height); + + // Handle no-op resizes efficiently. + if (output->sizeAt(1) == images->sizeAt(1) && + output->sizeAt(2) == images->sizeAt(2)) { + output->assign(images); + return Status::OK(); + } + + if (halfPixelCenter) + resizeNeighbor(st, images, alignCorners, true, + output); + else + resizeNeighbor(st, images, alignCorners, false, output); + + return Status::OK(); +} -// void resizeImage(NDArray const *images, Nd4jLong batchSize, Nd4jLong inHeight, Nd4jLong inWidth, Nd4jLong outHeight, +// void resizeImage(NDArray const *images, Nd4jLong batchSize, Nd4jLong +// inHeight, Nd4jLong inWidth, Nd4jLong outHeight, // Nd4jLong outWidth, Nd4jLong channels, // std::vector const &xs, // std::vector const &ys, // NDArray *output) { -// BUILD_DOUBLE_SELECTOR(images->dataType(), output->dataType(), resizeImage_, -// (images, batchSize, inHeight, inWidth, outHeight, outWidth, channels, xs, ys, output), +// BUILD_DOUBLE_SELECTOR(images->dataType(), output->dataType(), +// resizeImage_, +// (images, batchSize, inHeight, inWidth, +// outHeight, outWidth, channels, xs, ys, output), // NUMERIC_TYPES, FLOAT_TYPES); // } - int resizeBilinearFunctor(sd::LaunchContext * context, NDArray const *images, int const width, int const height, - bool const alignCorners, bool const halfPixelCenter, NDArray *output) { - BUILD_DOUBLE_SELECTOR(images->dataType(), output->dataType(), return resizeBilinearFunctor_, (images, width, height, alignCorners, halfPixelCenter, output), NUMERIC_TYPES, FLOAT_TYPES); - return Status::OK(); - } +int resizeBilinearFunctor(sd::LaunchContext* context, NDArray const* images, + int const width, int const height, + bool const alignCorners, bool const halfPixelCenter, + NDArray* output) { + BUILD_DOUBLE_SELECTOR( + images->dataType(), output->dataType(), return resizeBilinearFunctor_, + (images, width, height, alignCorners, halfPixelCenter, output), + NUMERIC_TYPES, FLOAT_TYPES); + return Status::OK(); +} - int resizeNeighborFunctor(sd::LaunchContext * context, NDArray const *images, int const width, int const height, - bool const alignCorners, bool const halfPixelCenter, NDArray *output) { - BUILD_SINGLE_SELECTOR(images->dataType(), return resizeNeighborFunctor_, (images, width, height, alignCorners, halfPixelCenter, output), LIBND4J_TYPES); - } +int resizeNeighborFunctor(sd::LaunchContext* context, NDArray const* images, + int const width, int const height, + bool const alignCorners, bool const halfPixelCenter, + NDArray* output) { + BUILD_SINGLE_SELECTOR( + images->dataType(), return resizeNeighborFunctor_, + (images, width, height, alignCorners, halfPixelCenter, output), + LIBND4J_TYPES); +} //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// -// ------------------------------------------------------------------------------------------------------------------ // -// Bicubic interpolation -// ------------------------------------------------------------------------------------------------------------------ // - class CachedInterpolationCalculator { - public: - CachedInterpolationCalculator() : _indexes{-1, -1, -1, -1} {} - - // Advances iteration. Returns the number of values that should be copied from - // the current point to the next point. The copying should always be done by - // copying the last values from the old point to the first - // values of the new point. - inline int Advance(const Nd4jLong x0, const Nd4jLong x1, const Nd4jLong x2, - const Nd4jLong x3) { - // We use 2 hands and walk through, copying from one to another where - // we already have values. - // Invariant, new_indicies_hand <= cached_values_hand - const Nd4jLong new_x_indices[] = {x0, x1, x2, x3}; - int cachedValuesHand = 0; - int newIndiciesHand = 0; - while (cachedValuesHand < 4) { - if (_indexes[cachedValuesHand] == new_x_indices[newIndiciesHand]) { - if (newIndiciesHand < cachedValuesHand) { - _indexes[newIndiciesHand] = _indexes[cachedValuesHand]; - } - newIndiciesHand++; - } - cachedValuesHand++; - } - switch (newIndiciesHand) { - case 0: - _indexes[0] = x0; - case 1: - _indexes[1] = x1; - case 2: - _indexes[2] = x2; - case 3: - _indexes[3] = x3; - break; - } - return newIndiciesHand; +// ------------------------------------------------------------------------------------------------------------------ +// // Bicubic interpolation +// ------------------------------------------------------------------------------------------------------------------ +// // +class CachedInterpolationCalculator { + public: + CachedInterpolationCalculator() : _indexes{-1, -1, -1, -1} {} + + // Advances iteration. Returns the number of values that should be copied from + // the current point to the next point. The copying should always be done by + // copying the last values from the old point to the first + // values of the new point. + inline int Advance(const Nd4jLong x0, const Nd4jLong x1, const Nd4jLong x2, + const Nd4jLong x3) { + // We use 2 hands and walk through, copying from one to another where + // we already have values. + // Invariant, new_indicies_hand <= cached_values_hand + const Nd4jLong new_x_indices[] = {x0, x1, x2, x3}; + int cachedValuesHand = 0; + int newIndiciesHand = 0; + while (cachedValuesHand < 4) { + if (_indexes[cachedValuesHand] == new_x_indices[newIndiciesHand]) { + if (newIndiciesHand < cachedValuesHand) { + _indexes[newIndiciesHand] = _indexes[cachedValuesHand]; } - - private: - Nd4jLong _indexes[4]; - }; - static const Nd4jLong kTableSize = 1024LL; //(1 << 10); - - const float* initCoeffsTable(const double a) { - // Allocate and initialize coefficients table using Bicubic - // convolution algorithm. - // https://en.wikipedia.org/wiki/Bicubic_interpolation - float* coeffsTable = new float[(kTableSize + 1) * 2]; - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i <= stop; ++i) { - float x = i * 1.0 / kTableSize; - coeffsTable[i * 2] = ((a + 2) * x - (a + 3)) * x * x + 1; - x += 1.0; - coeffsTable[i * 2 + 1] = ((a * x - 5 * a) * x + 8 * a) * x - 4 * a; - } - }; - samediff::Threads::parallel_for(func, 0, kTableSize); - return coeffsTable; + newIndiciesHand++; + } + cachedValuesHand++; } - - const float* getCoeffsTable(const bool use_keys_cubic) { - // Static so that we initialize it on first use - if (use_keys_cubic) { - // http://ieeexplore.ieee.org/document/1163711/ - // R. G. Keys. Cubic convolution interpolation for digital image - // processing. IEEE Transactions on Acoustics, Speech, and Signal - // Processing, 29(6):1153–1160, 1981. - static const float* coeffs_table = initCoeffsTable(-0.5f); - return coeffs_table; - } else { - static const float* coeffs_table = initCoeffsTable(-0.75f); - return coeffs_table; - } + switch (newIndiciesHand) { + case 0: + _indexes[0] = x0; + case 1: + _indexes[1] = x1; + case 2: + _indexes[2] = x2; + case 3: + _indexes[3] = x3; + break; } - - inline Nd4jLong bound(Nd4jLong val, Nd4jLong limit) { - return math::nd4j_min(limit - 1ll, math::nd4j_max(Nd4jLong{0}, val)); + return newIndiciesHand; + } + + private: + Nd4jLong _indexes[4]; +}; +static const Nd4jLong kTableSize = 1024LL; //(1 << 10); + +const float* initCoeffsTable(const double a) { + // Allocate and initialize coefficients table using Bicubic + // convolution algorithm. + // https://en.wikipedia.org/wiki/Bicubic_interpolation + float* coeffsTable = new float[(kTableSize + 1) * 2]; + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i <= stop; ++i) { + float x = i * 1.0 / kTableSize; + coeffsTable[i * 2] = ((a + 2) * x - (a + 3)) * x * x + 1; + x += 1.0; + coeffsTable[i * 2 + 1] = ((a * x - 5 * a) * x + 8 * a) * x - 4 * a; } + }; + samediff::Threads::parallel_for(func, 0, kTableSize); + return coeffsTable; +} + +const float* getCoeffsTable(const bool use_keys_cubic) { + // Static so that we initialize it on first use + if (use_keys_cubic) { + // http://ieeexplore.ieee.org/document/1163711/ + // R. G. Keys. Cubic convolution interpolation for digital image + // processing. IEEE Transactions on Acoustics, Speech, and Signal + // Processing, 29(6):1153–1160, 1981. + static const float* coeffs_table = initCoeffsTable(-0.5f); + return coeffs_table; + } else { + static const float* coeffs_table = initCoeffsTable(-0.75f); + return coeffs_table; + } +} - template - int resizeBicubicFunctor_(sd::LaunchContext * context, NDArray const* image, int width, int height, - bool preserveAspectRatio, bool antialias, NDArray* output) { - return ND4J_STATUS_OK; +inline Nd4jLong bound(Nd4jLong val, Nd4jLong limit) { + return math::nd4j_min(limit - 1ll, math::nd4j_max(Nd4jLong{0}, val)); +} + +template +int resizeBicubicFunctor_(sd::LaunchContext* context, NDArray const* image, + int width, int height, bool preserveAspectRatio, + bool antialias, NDArray* output) { + return ND4J_STATUS_OK; +} + +int resizeBicubicFunctor(sd::LaunchContext* context, NDArray const* image, + int width, int height, bool preserveAspectRatio, + bool antialias, NDArray* output) { + BUILD_SINGLE_SELECTOR( + image->dataType(), return resizeBicubicFunctor_, + (context, image, width, height, preserveAspectRatio, antialias, output), + NUMERIC_TYPES); +} +// ------------------------------------------------------------------------------------------------------------------ +// // + +template +inline float interpolate1D(const float weight0, const float weight1, + const float weight2, const float weight3, + const T value0, const T value1, const T value2, + const T value3) { + return static_cast(value0) * weight0 + + static_cast(value1) * weight1 + + static_cast(value2) * weight2 + + static_cast(value3) * weight3; +} + +// Compute the 1D interpolation for a given X index using the y_weights +static float compute(float values[4], const float xW0, const float xW1, + const float xW2, const float xW3) { + return interpolate1D(xW0, xW1, xW2, xW3, values[0], values[1], values[2], + values[3]); +} + +template +inline void getWeightsAndIndices(const float scale, const Nd4jLong out_loc, + const Nd4jLong limit, WeightsAndIndices* out) { + const Scaler scaler; + const float in_loc_f = scaler(out_loc, scale); + const Nd4jLong in_loc = std::floor(in_loc_f); + const float delta = in_loc_f - in_loc; + const Nd4jLong offset = lrintf(delta * kTableSize); + const float* coeffs_table = getCoeffsTable(use_keys_cubic); + if (use_keys_cubic) { + // The legacy code placed more weight on the edge pixels, since bounding + // the set of inputs to sample could cause an edge pixel to be repeated. + // Here we change the behavior at borders to match that used by the + // scale_and_translate_op, where sampling locations outside the image have + // their weight set to 0, and the weights are renormalized so that their sum + // is 1.0. + out->_index0 = bound(in_loc - 1, limit); + out->_weight0 = + (out->_index0 == in_loc - 1 ? coeffs_table[offset * 2 + 1] : 0.0f); + out->_index1 = bound(in_loc, limit); + out->_weight1 = (out->_index1 == in_loc ? coeffs_table[offset * 2] : 0.0f); + out->_index2 = bound(in_loc + 1, limit); + out->_weight2 = + (out->_index2 == in_loc + 1 ? coeffs_table[(kTableSize - offset) * 2] + : 0.0f); + out->_index3 = bound(in_loc + 2, limit); + out->_weight3 = (out->_index3 == in_loc + 2 + ? coeffs_table[(kTableSize - offset) * 2 + 1] + : 0.0f); + + const float weight_sum = + out->_weight0 + out->_weight1 + out->_weight2 + out->_weight3; + if (std::abs(weight_sum) >= 1000.0f * std::numeric_limits::min()) { + const float one_over_weight_sum = 1.0f / weight_sum; + out->_weight0 *= one_over_weight_sum; + out->_weight1 *= one_over_weight_sum; + out->_weight2 *= one_over_weight_sum; + out->_weight3 *= one_over_weight_sum; } + } else { + out->_weight0 = coeffs_table[offset * 2 + 1]; + out->_weight1 = coeffs_table[offset * 2]; + out->_weight2 = coeffs_table[(kTableSize - offset) * 2]; + out->_weight3 = coeffs_table[(kTableSize - offset) * 2 + 1]; + out->_index0 = bound(in_loc - 1, limit); + out->_index1 = bound(in_loc, limit); + out->_index2 = bound(in_loc + 1, limit); + out->_index3 = bound(in_loc + 2, limit); + } +} - int resizeBicubicFunctor(sd::LaunchContext * context, NDArray const* image, int width, int height, - bool preserveAspectRatio, bool antialias, NDArray* output) { - BUILD_SINGLE_SELECTOR(image->dataType(), return resizeBicubicFunctor_, (context, image, - width, height, preserveAspectRatio, antialias, output), NUMERIC_TYPES); +static void computeXWeightsAndIndices(const ImageResizerState& resizer_state, + const bool half_pixel_centers, + std::vector* x_wais) { + CachedInterpolationCalculator calc; + if (half_pixel_centers) { + auto func = PRAGMA_THREADS_FOR { + for (auto x = start; x < stop; ++x) { + getWeightsAndIndices( + resizer_state.widthScale, x, resizer_state.inWidth, &(*x_wais)[x]); + auto& x_wai = (*x_wais)[x]; + x_wai._advance = calc.Advance(x_wai._index0, x_wai._index1, + x_wai._index2, x_wai._index3); + } + }; + samediff::Threads::parallel_for(func, 0, resizer_state.outWidth); + } else { + auto func = PRAGMA_THREADS_FOR { + for (auto x = start; x < stop; ++x) { + getWeightsAndIndices( + resizer_state.widthScale, x, resizer_state.inWidth, &(*x_wais)[x]); + auto& x_wai = (*x_wais)[x]; + x_wai._advance = calc.Advance(x_wai._index0, x_wai._index1, + x_wai._index2, x_wai._index3); + } + }; + samediff::Threads::parallel_for(func, 0, resizer_state.outWidth); + } + // Scale the values so they can be used as offsets into buffers. + auto func = PRAGMA_THREADS_FOR { + for (auto x = start; x < stop; ++x) { + (*x_wais)[x]._index0 *= resizer_state.channels; + (*x_wais)[x]._index1 *= resizer_state.channels; + (*x_wais)[x]._index2 *= resizer_state.channels; + (*x_wais)[x]._index3 *= resizer_state.channels; } -// ------------------------------------------------------------------------------------------------------------------ // + }; + samediff::Threads::parallel_for(func, 0, resizer_state.outWidth); +} - template - inline float interpolate1D(const float weight0, const float weight1, const float weight2, const float weight3, - const T value0, const T value1, const T value2, const T value3) { - return static_cast(value0) * weight0 + - static_cast(value1) * weight1 + - static_cast(value2) * weight2 + - static_cast(value3) * weight3; - } +template +static FORCEINLINE float computeYInterpolation(int which, int channelNum, + const WeightsAndIndices& yWai, + const T* pY0, const T* pY1, + const T* pY2, const T* pY3, + const WeightsAndIndices& xWai) { + int xIndex; + switch (which) { + case 0: + xIndex = xWai._index0; + break; + case 1: + xIndex = xWai._index1; + break; + case 2: + xIndex = xWai._index2; + break; + default: + xIndex = xWai._index3; + break; + } + const Nd4jLong pt_index = xIndex + channelNum; + return interpolate1D(yWai._weight0, yWai._weight1, yWai._weight2, + yWai._weight3, pY0[pt_index], pY1[pt_index], + pY2[pt_index], pY3[pt_index]); +} -// Compute the 1D interpolation for a given X index using the y_weights - static float compute(float values[4], const float xW0, const float xW1, const float xW2, const float xW3) { - return interpolate1D(xW0, xW1, xW2, xW3, values[0], values[1],values[2], values[3]); +template +static void bicubicInterpolateWithCaching(NDArray const* image, + ImageResizerState const& resizerState, + bool const halfPixelCenters, + NDArray* output) { + std::vector xWais(resizerState.outWidth); + computeXWeightsAndIndices(resizerState, halfPixelCenters, &xWais); + + const auto numChannels = resizerState.channels; + const Nd4jLong inRowWidth = resizerState.inWidth * numChannels; + const Nd4jLong inBatchWidth = resizerState.inHeight * inRowWidth; + const auto batchNum = resizerState.batchSize; + const auto outHeight = resizerState.outHeight; + const auto outWidth = resizerState.outWidth; + + auto func = PRAGMA_THREADS_FOR { + const T* inputPtr = image->getDataBuffer()->primaryAsT(); + F* pOutputY = + output->dataBuffer()->primaryAsT(); // output is float anyway + std::vector cachedValue(numChannels == 3 ? 0 : 4 * numChannels, 0); + + for (auto b = start; b < stop; ++b) { + auto pInput = inputPtr + b * inBatchWidth; + + for (Nd4jLong y = 0; y < outHeight; ++y) { + auto pOutput = &pOutputY[(b * outHeight + y) * outWidth * numChannels]; + + WeightsAndIndices yWai; + if (halfPixelCenters) { + getWeightsAndIndices( + resizerState.heightScale, y, resizerState.inHeight, &yWai); + } else { + getWeightsAndIndices( + resizerState.heightScale, y, resizerState.inHeight, &yWai); } + // Make pointers represent offsets of data in inputBPtr. + const T* y_ptr_0 = pInput + yWai._index0 * inRowWidth; + const T* y_ptr_1 = pInput + yWai._index1 * inRowWidth; + const T* y_ptr_2 = pInput + yWai._index2 * inRowWidth; + const T* y_ptr_3 = pInput + yWai._index3 * inRowWidth; + + if (numChannels == 3) { + // Manually unroll case of 3 channels. + F cached_value_0[4] = {0}; + F cached_value_1[4] = {0}; + F cached_value_2[4] = {0}; + for (Nd4jLong x = 0; x < resizerState.outWidth; ++x) { + const WeightsAndIndices& xWai = xWais[x]; + // Shift values in cached_value_* to fill first '_advance' values. + switch (xWai._advance) { + case 3: + cached_value_0[0] = cached_value_0[1]; + cached_value_0[1] = cached_value_0[2]; + cached_value_0[2] = cached_value_0[3]; + cached_value_1[0] = cached_value_1[1]; + cached_value_1[1] = cached_value_1[2]; + cached_value_1[2] = cached_value_1[3]; + cached_value_2[0] = cached_value_2[1]; + cached_value_2[1] = cached_value_2[2]; + cached_value_2[2] = cached_value_2[3]; + break; + case 2: + cached_value_0[0] = cached_value_0[2]; + cached_value_0[1] = cached_value_0[3]; + cached_value_1[0] = cached_value_1[2]; + cached_value_1[1] = cached_value_1[3]; + cached_value_2[0] = cached_value_2[2]; + cached_value_2[1] = cached_value_2[3]; + break; + case 1: { + cached_value_0[0] = cached_value_0[3]; + cached_value_1[0] = cached_value_1[3]; + cached_value_2[0] = cached_value_2[3]; + break; + } + } - template - inline void getWeightsAndIndices(const float scale, const Nd4jLong out_loc, const Nd4jLong limit, WeightsAndIndices* out) { - const Scaler scaler; - const float in_loc_f = scaler(out_loc, scale); - const Nd4jLong in_loc = std::floor(in_loc_f); - const float delta = in_loc_f - in_loc; - const Nd4jLong offset = lrintf(delta * kTableSize); - const float* coeffs_table = getCoeffsTable(use_keys_cubic); - if (use_keys_cubic) { - // The legacy code placed more weight on the edge pixels, since bounding - // the set of inputs to sample could cause an edge pixel to be repeated. - // Here we change the behavior at borders to match that used by the - // scale_and_translate_op, where sampling locations outside the image have - // their weight set to 0, and the weights are renormalized so that their sum - // is 1.0. - out->_index0 = bound(in_loc - 1, limit); - out->_weight0 = - (out->_index0 == in_loc - 1 ? coeffs_table[offset * 2 + 1] : 0.0f); - out->_index1 = bound(in_loc, limit); - out->_weight1 = (out->_index1 == in_loc ? coeffs_table[offset * 2] : 0.0f); - out->_index2 = bound(in_loc + 1, limit); - out->_weight2 = - (out->_index2 == in_loc + 1 ? coeffs_table[(kTableSize - offset) * 2] - : 0.0f); - out->_index3 = bound(in_loc + 2, limit); - out->_weight3 = (out->_index3 == in_loc + 2 - ? coeffs_table[(kTableSize - offset) * 2 + 1] - : 0.0f); - - const float weight_sum = - out->_weight0 + out->_weight1 + out->_weight2 + out->_weight3; - if (std::abs(weight_sum) >= 1000.0f * std::numeric_limits::min()) { - const float one_over_weight_sum = 1.0f / weight_sum; - out->_weight0 *= one_over_weight_sum; - out->_weight1 *= one_over_weight_sum; - out->_weight2 *= one_over_weight_sum; - out->_weight3 *= one_over_weight_sum; + // Set the remaining '4-_advance' values by computing. + switch (xWai._advance) { + case 0: + cached_value_0[0] = computeYInterpolation( + 0, 0, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); + cached_value_1[0] = computeYInterpolation( + 0, 1, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); + cached_value_2[0] = computeYInterpolation( + 0, 2, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); + + case 1: + cached_value_0[1] = computeYInterpolation( + 1, 0, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); + cached_value_1[1] = computeYInterpolation( + 1, 1, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); + cached_value_2[1] = computeYInterpolation( + 1, 2, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); + + case 2: + cached_value_0[2] = computeYInterpolation( + 2, 0, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); + cached_value_1[2] = computeYInterpolation( + 2, 1, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); + cached_value_2[2] = computeYInterpolation( + 2, 2, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); + + case 3: + cached_value_0[3] = computeYInterpolation( + 3, 0, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); + cached_value_1[3] = computeYInterpolation( + 3, 1, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); + cached_value_2[3] = computeYInterpolation( + 3, 2, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); + break; + } + pOutput[x * numChannels + 0] = + compute(cached_value_0, xWai._weight0, xWai._weight1, + xWai._weight2, xWai._weight3); + pOutput[x * numChannels + 1] = + compute(cached_value_1, xWai._weight0, xWai._weight1, + xWai._weight2, xWai._weight3); + pOutput[x * numChannels + 2] = + compute(cached_value_2, xWai._weight0, xWai._weight1, + xWai._weight2, xWai._weight3); + } + } else { + for (Nd4jLong x = 0; x < resizerState.outWidth; ++x) { + const WeightsAndIndices& xWai = xWais[x]; + // Shift values in cachedValue to fill first '_advance' values. + switch (xWai._advance) { + case 3: + for (auto c = 0; c < numChannels; ++c) { + cachedValue[4 * c + 0] = cachedValue[4 * c + 1]; + cachedValue[4 * c + 1] = cachedValue[4 * c + 2]; + cachedValue[4 * c + 2] = cachedValue[4 * c + 3]; } - } else { - out->_weight0 = coeffs_table[offset * 2 + 1]; - out->_weight1 = coeffs_table[offset * 2]; - out->_weight2 = coeffs_table[(kTableSize - offset) * 2]; - out->_weight3 = coeffs_table[(kTableSize - offset) * 2 + 1]; - out->_index0 = bound(in_loc - 1, limit); - out->_index1 = bound(in_loc, limit); - out->_index2 = bound(in_loc + 1, limit); - out->_index3 = bound(in_loc + 2, limit); + break; + case 2: + for (auto c = 0; c < numChannels; ++c) { + cachedValue[4 * c + 0] = cachedValue[4 * c + 2]; + cachedValue[4 * c + 1] = cachedValue[4 * c + 3]; + } + break; + case 1: { + for (auto c = 0; c < numChannels; ++c) { + cachedValue[4 * c + 0] = cachedValue[4 * c + 3]; + } + break; + } } - } - static void computeXWeightsAndIndices(const ImageResizerState& resizer_state, - const bool half_pixel_centers, - std::vector* x_wais) { - CachedInterpolationCalculator calc; - if (half_pixel_centers) { - auto func = PRAGMA_THREADS_FOR { - for (auto x = start; x < stop; ++x) { - getWeightsAndIndices( - resizer_state.widthScale, x, resizer_state.inWidth, &(*x_wais)[x]); - auto &x_wai = (*x_wais)[x]; - x_wai._advance = calc.Advance(x_wai._index0, x_wai._index1, x_wai._index2, - x_wai._index3); - } - }; - samediff::Threads::parallel_for(func, 0, resizer_state.outWidth); - } else { - auto func = PRAGMA_THREADS_FOR { - for (auto x = start; x < stop; ++x) { - getWeightsAndIndices( - resizer_state.widthScale, x, resizer_state.inWidth, &(*x_wais)[x]); - auto& x_wai = (*x_wais)[x]; - x_wai._advance = calc.Advance(x_wai._index0, x_wai._index1, x_wai._index2, - x_wai._index3); - } - }; - samediff::Threads::parallel_for(func, 0, resizer_state.outWidth); - } - // Scale the values so they can be used as offsets into buffers. - auto func = PRAGMA_THREADS_FOR { - for (auto x = start; x < stop; ++x) { - (*x_wais)[x]._index0 *= resizer_state.channels; - (*x_wais)[x]._index1 *= resizer_state.channels; - (*x_wais)[x]._index2 *= resizer_state.channels; - (*x_wais)[x]._index3 *= resizer_state.channels; + // Set the remaining '4-_advance' values by computing. + switch (xWai._advance) { + case 0: + for (auto c = 0; c < numChannels; ++c) { + cachedValue[4 * c + 0] = computeYInterpolation( + 0, c, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); } - }; - samediff::Threads::parallel_for(func, 0, resizer_state.outWidth); - } - template - static FORCEINLINE float computeYInterpolation( - int which, int channelNum, const WeightsAndIndices& yWai, - const T* pY0, const T* pY1, const T* pY2, const T* pY3, - const WeightsAndIndices& xWai) { - int xIndex; - switch (which) { - case 0: - xIndex = xWai._index0; - break; - case 1: - xIndex = xWai._index1; - break; - case 2: - xIndex = xWai._index2; - break; - default: - xIndex = xWai._index3; - break; - } - const Nd4jLong pt_index = xIndex + channelNum; - return interpolate1D(yWai._weight0, yWai._weight1, yWai._weight2, - yWai._weight3, pY0[pt_index], pY1[pt_index], - pY2[pt_index], pY3[pt_index]); - } + case 1: + for (auto c = 0; c < numChannels; ++c) { + cachedValue[4 * c + 1] = computeYInterpolation( + 1, c, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); + } - template - static void - bicubicInterpolateWithCaching(NDArray const* image, ImageResizerState const& resizerState, bool const halfPixelCenters, NDArray* output) { - std::vector xWais(resizerState.outWidth); - computeXWeightsAndIndices(resizerState, halfPixelCenters, &xWais); - - const auto numChannels = resizerState.channels; - const Nd4jLong inRowWidth = resizerState.inWidth * numChannels; - const Nd4jLong inBatchWidth = resizerState.inHeight * inRowWidth; - const auto batchNum = resizerState.batchSize; - const auto outHeight = resizerState.outHeight; - const auto outWidth = resizerState.outWidth; - - auto func = PRAGMA_THREADS_FOR { - const T* inputPtr = image->getDataBuffer()->primaryAsT(); - F* pOutputY = output->dataBuffer()->primaryAsT(); // output is float anyway - std::vector cachedValue(numChannels == 3 ? 0 : 4 * numChannels, 0); - - for (auto b = start; b < stop; ++b) { - auto pInput = inputPtr + b * inBatchWidth; - - for (Nd4jLong y = 0; y < outHeight; ++y) { - auto pOutput = &pOutputY[(b * outHeight + y) * outWidth * numChannels]; - - WeightsAndIndices yWai; - if (halfPixelCenters) { - getWeightsAndIndices( - resizerState.heightScale, y, resizerState.inHeight, &yWai); - } else { - getWeightsAndIndices( - resizerState.heightScale, y, resizerState.inHeight, &yWai); - } - // Make pointers represent offsets of data in inputBPtr. - const T* y_ptr_0 = pInput + yWai._index0 * inRowWidth; - const T* y_ptr_1 = pInput + yWai._index1 * inRowWidth; - const T* y_ptr_2 = pInput + yWai._index2 * inRowWidth; - const T* y_ptr_3 = pInput + yWai._index3 * inRowWidth; - - if (numChannels == 3) { - // Manually unroll case of 3 channels. - F cached_value_0[4] = {0}; - F cached_value_1[4] = {0}; - F cached_value_2[4] = {0}; - for (Nd4jLong x = 0; x < resizerState.outWidth; ++x) { - const WeightsAndIndices &xWai = xWais[x]; - // Shift values in cached_value_* to fill first '_advance' values. - switch (xWai._advance) { - case 3: - cached_value_0[0] = cached_value_0[1]; - cached_value_0[1] = cached_value_0[2]; - cached_value_0[2] = cached_value_0[3]; - cached_value_1[0] = cached_value_1[1]; - cached_value_1[1] = cached_value_1[2]; - cached_value_1[2] = cached_value_1[3]; - cached_value_2[0] = cached_value_2[1]; - cached_value_2[1] = cached_value_2[2]; - cached_value_2[2] = cached_value_2[3]; - break; - case 2: - cached_value_0[0] = cached_value_0[2]; - cached_value_0[1] = cached_value_0[3]; - cached_value_1[0] = cached_value_1[2]; - cached_value_1[1] = cached_value_1[3]; - cached_value_2[0] = cached_value_2[2]; - cached_value_2[1] = cached_value_2[3]; - break; - case 1: { - cached_value_0[0] = cached_value_0[3]; - cached_value_1[0] = cached_value_1[3]; - cached_value_2[0] = cached_value_2[3]; - break; - } - } - - // Set the remaining '4-_advance' values by computing. - switch (xWai._advance) { - case 0: - cached_value_0[0] = computeYInterpolation( - 0, 0, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); - cached_value_1[0] = computeYInterpolation( - 0, 1, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); - cached_value_2[0] = computeYInterpolation( - 0, 2, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); - - case 1: - cached_value_0[1] = computeYInterpolation( - 1, 0, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); - cached_value_1[1] = computeYInterpolation( - 1, 1, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); - cached_value_2[1] = computeYInterpolation( - 1, 2, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); - - case 2: - cached_value_0[2] = computeYInterpolation( - 2, 0, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); - cached_value_1[2] = computeYInterpolation( - 2, 1, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); - cached_value_2[2] = computeYInterpolation( - 2, 2, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); - - case 3: - cached_value_0[3] = computeYInterpolation( - 3, 0, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); - cached_value_1[3] = computeYInterpolation( - 3, 1, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); - cached_value_2[3] = computeYInterpolation( - 3, 2, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); - break; - } - pOutput[x * numChannels + 0] = - compute(cached_value_0, xWai._weight0, xWai._weight1, - xWai._weight2, xWai._weight3); - pOutput[x * numChannels + 1] = - compute(cached_value_1, xWai._weight0, xWai._weight1, - xWai._weight2, xWai._weight3); - pOutput[x * numChannels + 2] = - compute(cached_value_2, xWai._weight0, xWai._weight1, - xWai._weight2, xWai._weight3); - } - } else { - for (Nd4jLong x = 0; x < resizerState.outWidth; ++x) { - const WeightsAndIndices &xWai = xWais[x]; - // Shift values in cachedValue to fill first '_advance' values. - switch (xWai._advance) { - case 3: - for (auto c = 0; c < numChannels; ++c) { - cachedValue[4 * c + 0] = cachedValue[4 * c + 1]; - cachedValue[4 * c + 1] = cachedValue[4 * c + 2]; - cachedValue[4 * c + 2] = cachedValue[4 * c + 3]; - } - break; - case 2: - for (auto c = 0; c < numChannels; ++c) { - cachedValue[4 * c + 0] = cachedValue[4 * c + 2]; - cachedValue[4 * c + 1] = cachedValue[4 * c + 3]; - } - break; - case 1: { - for (auto c = 0; c < numChannels; ++c) { - cachedValue[4 * c + 0] = cachedValue[4 * c + 3]; - } - break; - } - } - - // Set the remaining '4-_advance' values by computing. - switch (xWai._advance) { - case 0: - for (auto c = 0; c < numChannels; ++c) { - cachedValue[4 * c + 0] = computeYInterpolation( - 0, c, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); - } - - case 1: - for (auto c = 0; c < numChannels; ++c) { - cachedValue[4 * c + 1] = computeYInterpolation( - 1, c, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); - } - - case 2: - for (auto c = 0; c < numChannels; ++c) { - cachedValue[4 * c + 2] = computeYInterpolation( - 2, c, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); - } - - case 3: - for (auto c = 0; c < numChannels; ++c) { - cachedValue[4 * c + 3] = computeYInterpolation( - 3, c, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); - } - break; - } - for (auto c = 0; c < numChannels; ++c) { - pOutput[x * numChannels + c] = - (F)compute(&cachedValue[4 * c], xWai._weight0, xWai._weight1, - xWai._weight2, xWai._weight3); - } - } - } + case 2: + for (auto c = 0; c < numChannels; ++c) { + cachedValue[4 * c + 2] = computeYInterpolation( + 2, c, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); + } + + case 3: + for (auto c = 0; c < numChannels; ++c) { + cachedValue[4 * c + 3] = computeYInterpolation( + 3, c, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); } + break; } - }; - samediff::Threads::parallel_tad(func, 0, batchNum); + for (auto c = 0; c < numChannels; ++c) { + pOutput[x * numChannels + c] = + (F)compute(&cachedValue[4 * c], xWai._weight0, xWai._weight1, + xWai._weight2, xWai._weight3); + } + } + } + } } + }; + samediff::Threads::parallel_tad(func, 0, batchNum); +} // simplified bicubic resize without antialiasing // - template - int resizeBicubicFunctorA_(sd::LaunchContext * context, NDArray const* image, int const width, int const height, - bool const alignCorners, bool const halfPixelAlign, NDArray* output) { - ImageResizerState st(alignCorners, halfPixelAlign); // align_corners, half_pixel_align - int res = st.validateAndCreateOutput(image, width, height); - if (res == Status::OK()) - bicubicInterpolateWithCaching(image, st, halfPixelAlign, output); - - return res; +template +int resizeBicubicFunctorA_(sd::LaunchContext* context, NDArray const* image, + int const width, int const height, + bool const alignCorners, bool const halfPixelAlign, + NDArray* output) { + ImageResizerState st(alignCorners, + halfPixelAlign); // align_corners, half_pixel_align + int res = st.validateAndCreateOutput(image, width, height); + if (res == Status::OK()) + bicubicInterpolateWithCaching(image, st, halfPixelAlign, output); + + return res; +} +int resizeBicubicFunctorA(sd::LaunchContext* context, NDArray const* image, + int const width, int const height, + bool const alignCorners, bool const halfPixelAlign, + NDArray* output) { + BUILD_SINGLE_SELECTOR( + image->dataType(), return resizeBicubicFunctorA_, + (context, image, width, height, alignCorners, halfPixelAlign, output), + NUMERIC_TYPES); +} +// ------------------------------------------------------------------------------------------------------------------ +// // +struct CachedInterpolation { + Nd4jLong start; + Nd4jLong end; + float startScale; + float endMinusOneScale; + bool needsBounding; +}; + +template +struct ScaleCache { + float yScale; + T const* yPtr; +}; +// Computes the sum of all x values defined by taken across +// the y offsets and scales defined by y_ptrs and y_scales, for channel c. +// +// Note that is a template parameter to avoid a performance +// penalty from dynamically checking it. +template +static void computePatchSumOf3Channels(float scale, ImageResizerState const& st, + std::vector> const& yPtrs, + CachedInterpolation const& xCache, + float* outputPtr) { + bool const needsXBounding = xCache.needsBounding; + + auto boundIfNeeded = [needsXBounding](Nd4jLong x, Nd4jLong y) -> Nd4jLong { + return (needsXBounding ? bound(x, y) : (x)); + }; + + float sum_0 = 0; + float sum_1 = 0; + float sum_2 = 0; + for (size_t i = 0; i < yPtrs.size(); ++i) { + const T* ptr = yPtrs[i].yPtr; + float scaleX = xCache.startScale; + Nd4jLong offset = 3 * boundIfNeeded(xCache.start, st.inWidth); + float sum_y_0 = static_cast(ptr[offset + 0]) * scaleX; + float sum_y_1 = static_cast(ptr[offset + 1]) * scaleX; + float sum_y_2 = static_cast(ptr[offset + 2]) * scaleX; + + if (xCache.start + 1 != xCache.end) { + for (Nd4jLong x = xCache.start + 1; x < xCache.end - 1; ++x) { + Nd4jLong offset = 3 * boundIfNeeded(x, st.inWidth); + sum_y_0 += static_cast(ptr[offset + 0]); + sum_y_1 += static_cast(ptr[offset + 1]); + sum_y_2 += static_cast(ptr[offset + 2]); + } + scaleX = xCache.endMinusOneScale; + offset = st.channels * boundIfNeeded(xCache.end - 1, st.inWidth); + sum_y_0 += static_cast(ptr[offset + 0]) * scaleX; + sum_y_1 += static_cast(ptr[offset + 1]) * scaleX; + sum_y_2 += static_cast(ptr[offset + 2]) * scaleX; } - int resizeBicubicFunctorA(sd::LaunchContext * context, NDArray const* image, int const width, int const height, - bool const alignCorners, bool const halfPixelAlign, NDArray* output) { - BUILD_SINGLE_SELECTOR(image->dataType(), return resizeBicubicFunctorA_, (context, image, width, height, alignCorners, halfPixelAlign, output), NUMERIC_TYPES); + sum_0 += sum_y_0 * yPtrs[i].yScale; + sum_1 += sum_y_1 * yPtrs[i].yScale; + sum_2 += sum_y_2 * yPtrs[i].yScale; + } + + outputPtr[0] = sum_0 * scale; + outputPtr[1] = sum_1 * scale; + outputPtr[2] = sum_2 * scale; +} + +// Computes the sum of all x values defined by taken across +// the y offsets and scales defined by y_ptrs and y_scales, for channel c. +// +// Note that is a template parameter to avoid a performance +// penalty from dynamically checking it. +template +static void computePatchSum(float scale, const ImageResizerState& st, + const std::vector>& yPtrs, + const CachedInterpolation& xCache, + float* outputPtr) { + bool const needsXBounding = xCache.needsBounding; + + auto boundIfNeeded = [needsXBounding](Nd4jLong x, Nd4jLong y) -> Nd4jLong { + return (needsXBounding ? bound(x, y) : (x)); + }; + + const auto numChannels = st.channels; + for (Nd4jLong c = 0; c < numChannels; ++c) { + float sum = 0; + for (size_t i = 0; i < yPtrs.size(); ++i) { + T const* ptr = yPtrs[i].yPtr; + float scaleX = xCache.startScale; + float sumY = + static_cast( + ptr[numChannels * boundIfNeeded(xCache.start, st.inWidth) + c]) * + scaleX; + if (xCache.start + 1 != xCache.end) { + for (Nd4jLong x = xCache.start + 1; x < xCache.end - 1; ++x) { + sumY += static_cast( + ptr[numChannels * boundIfNeeded(x, st.inWidth) + c]); + } + scaleX = xCache.endMinusOneScale; + sumY += + static_cast( + ptr[numChannels * boundIfNeeded(xCache.end - 1, st.inWidth) + + c]) * + scaleX; + } + sum += sumY * yPtrs[i].yScale; } -// ------------------------------------------------------------------------------------------------------------------ // - struct CachedInterpolation { - Nd4jLong start; - Nd4jLong end; - float startScale; - float endMinusOneScale; - bool needsBounding; - }; + outputPtr[c] = sum * scale; + } +} - template - struct ScaleCache { - float yScale; - T const* yPtr; - }; - // Computes the sum of all x values defined by taken across - // the y offsets and scales defined by y_ptrs and y_scales, for channel c. - // - // Note that is a template parameter to avoid a performance - // penalty from dynamically checking it. - template - static void computePatchSumOf3Channels(float scale, - ImageResizerState const& st, - std::vector> const& yPtrs, - CachedInterpolation const& xCache, - float* outputPtr) { - - bool const needsXBounding = xCache.needsBounding; - - auto boundIfNeeded = [needsXBounding](Nd4jLong x, Nd4jLong y) -> Nd4jLong { - return (needsXBounding ? bound(x, y) : (x)); +template +static void resizeArea(ImageResizerState const& st, + std::vector const& caches, + NDArray const* input, NDArray* output) { + T const* inputPtr = input->bufferAsT(); + float scale = 1.f / (st.heightScale * st.widthScale); + auto outputPtr = + output->bufferAsT(); // output is always float. TO DO: provide + // another float types also with template + // declaration + + auto batchProcess = PRAGMA_THREADS_FOR { + for (auto batch = start; batch < stop; batch++) { + for (auto y = 0; y < st.outHeight; ++y) { + const float inY = y * st.heightScale; + const float inY1 = (y + 1) * st.heightScale; + // The start and end height indices of all the cells that could + // contribute to the target cell. + const Nd4jLong yStart = math::nd4j_floor(inY); + const Nd4jLong yEnd = math::nd4j_ceil(inY1); + + std::vector> yCaches; + auto cacheLen = yEnd - yStart; + if (cacheLen) { + yCaches.resize(cacheLen); }; - float sum_0 = 0; - float sum_1 = 0; - float sum_2 = 0; - for (size_t i = 0; i < yPtrs.size(); ++i) { - const T* ptr = yPtrs[i].yPtr; - float scaleX = xCache.startScale; - Nd4jLong offset = 3 * boundIfNeeded(xCache.start, st.inWidth); - float sum_y_0 = static_cast(ptr[offset + 0]) * scaleX; - float sum_y_1 = static_cast(ptr[offset + 1]) * scaleX; - float sum_y_2 = static_cast(ptr[offset + 2]) * scaleX; - - if (xCache.start + 1 != xCache.end) { - for (Nd4jLong x = xCache.start + 1; x < xCache.end - 1; ++x) { - Nd4jLong offset = 3 * boundIfNeeded(x, st.inWidth); - sum_y_0 += static_cast(ptr[offset + 0]); - sum_y_1 += static_cast(ptr[offset + 1]); - sum_y_2 += static_cast(ptr[offset + 2]); - } - scaleX = xCache.endMinusOneScale; - offset = st.channels * boundIfNeeded(xCache.end - 1, st.inWidth); - sum_y_0 += static_cast(ptr[offset + 0]) * scaleX; - sum_y_1 += static_cast(ptr[offset + 1]) * scaleX; - sum_y_2 += static_cast(ptr[offset + 2]) * scaleX; - } - sum_0 += sum_y_0 * yPtrs[i].yScale; - sum_1 += sum_y_1 * yPtrs[i].yScale; - sum_2 += sum_y_2 * yPtrs[i].yScale; + for (auto i = yStart, k = 0LL; i < yEnd; ++i, ++k) { + ScaleCache scaleCache; + if (i < inY) { + scaleCache.yScale = (i + 1 > inY1 ? st.heightScale : i + 1 - inY); + } else { + scaleCache.yScale = (i + 1 > inY1 ? inY1 - i : 1.0); + } + scaleCache.yPtr = + inputPtr + (batch * st.inHeight * st.inWidth * st.channels + + bound(i, st.inHeight) * st.inWidth * st.channels); + yCaches[k] = scaleCache; } - - outputPtr[0] = sum_0 * scale; - outputPtr[1] = sum_1 * scale; - outputPtr[2] = sum_2 * scale; - } - - // Computes the sum of all x values defined by taken across - // the y offsets and scales defined by y_ptrs and y_scales, for channel c. - // - // Note that is a template parameter to avoid a performance - // penalty from dynamically checking it. - template - static void computePatchSum(float scale, const ImageResizerState& st, - const std::vector>& yPtrs, - const CachedInterpolation& xCache, - float* outputPtr) { - - bool const needsXBounding = xCache.needsBounding; - - auto boundIfNeeded = [needsXBounding](Nd4jLong x, Nd4jLong y) -> Nd4jLong { - return (needsXBounding ? bound(x, y) : (x)); - }; - - const auto numChannels = st.channels; - for (Nd4jLong c = 0; c < numChannels; ++c) { - float sum = 0; - for (size_t i = 0; i < yPtrs.size(); ++i) { - T const* ptr = yPtrs[i].yPtr; - float scaleX = xCache.startScale; - float sumY = static_cast(ptr[numChannels * boundIfNeeded(xCache.start, st.inWidth) + c]) * scaleX; - if (xCache.start + 1 != xCache.end) { - for (Nd4jLong x = xCache.start + 1; x < xCache.end - 1; ++x) { - sumY += static_cast( - ptr[numChannels * boundIfNeeded(x, st.inWidth) + c]); - } - scaleX = xCache.endMinusOneScale; - sumY += static_cast(ptr[numChannels * boundIfNeeded(xCache.end - 1, st.inWidth) + c]) * scaleX; - } - sum += sumY * yPtrs[i].yScale; - } - outputPtr[c] = sum * scale; - } + float* output = + outputPtr + (batch * st.outHeight + y) * st.channels * st.outWidth; + + if (st.channels == 3) { + for (Nd4jLong x = 0; x < st.outWidth; ++x) { + const CachedInterpolation& xCache = caches[x]; + computePatchSumOf3Channels(scale, st, yCaches, xCache, output); + output += st.channels; + } + } else { + for (Nd4jLong x = 0; x < st.outWidth; ++x) { + const CachedInterpolation& xCache = caches[x]; + computePatchSum(scale, st, yCaches, xCache, output); + output += st.channels; + } } - - - - template - static void resizeArea(ImageResizerState const& st, std::vector const& caches, NDArray const* input, NDArray* output) { - T const* inputPtr = input->bufferAsT(); - float scale = 1.f / (st.heightScale * st.widthScale); - auto outputPtr = output->bufferAsT(); // output is always float. TO DO: provide another float types also with template declaration - - auto batchProcess = PRAGMA_THREADS_FOR { - for (auto batch = start; batch < stop; batch++) { - for (auto y = 0; y < st.outHeight; ++y) { - const float inY = y * st.heightScale; - const float inY1 = (y + 1) * st.heightScale; - // The start and end height indices of all the cells that could - // contribute to the target cell. - const Nd4jLong yStart = math::nd4j_floor(inY); - const Nd4jLong yEnd = math::nd4j_ceil(inY1); - - std::vector> yCaches; - auto cacheLen = yEnd - yStart; - if (cacheLen) { - yCaches.resize(cacheLen); - }; - - for (auto i = yStart, k = 0LL; i < yEnd; ++i, ++k) { - ScaleCache scaleCache; - if (i < inY) { - scaleCache.yScale = (i + 1 > inY1 ? st.heightScale : i + 1 - inY); - } else { - scaleCache.yScale = (i + 1 > inY1 ? inY1 - i : 1.0); - } - scaleCache.yPtr = inputPtr + (batch * st.inHeight * st.inWidth * st.channels + - bound(i, st.inHeight) * st.inWidth * st.channels); - yCaches[k] = scaleCache; - } - float* output = outputPtr + (batch * st.outHeight + y) * st.channels * st.outWidth; - - if (st.channels == 3) { - for (Nd4jLong x = 0; x < st.outWidth; ++x) { - const CachedInterpolation &xCache = caches[x]; - computePatchSumOf3Channels(scale, st, yCaches, xCache, output); - output += st.channels; - } - } else { - for (Nd4jLong x = 0; x < st.outWidth; ++x) { - const CachedInterpolation &xCache = caches[x]; - computePatchSum(scale, st, yCaches, xCache, output); - output += st.channels; - } - } - } - } - }; - samediff::Threads::parallel_tad(batchProcess, 0, st.batchSize, 1); + } } + }; + samediff::Threads::parallel_tad(batchProcess, 0, st.batchSize, 1); +} - template - int resizeAreaFunctor_(sd::LaunchContext* context, NDArray const* image, int const width, int const height, - bool const alignCorners, NDArray* output) { - ImageResizerState st(alignCorners, false); // Create resize info - auto res = st.validateAndCalculateOutputSize(image, width, height); - if (Status::OK() == res) { - std::vector xCached(st.outWidth); - auto cachingProcedure = PRAGMA_THREADS_FOR { - for (auto x = start; x < stop; x++) { - auto &xCache = xCached[x]; - const float inX = x * st.widthScale; - const float inX1 = (x + 1) * st.widthScale; - - Nd4jLong v = math::nd4j_floor(inX); - xCache.start = v; - xCache.startScale = - v < inX ? (v + 1 > inX1 ? st.widthScale : v + 1 - inX) : (v + 1 > inX1 ? inX1 - v - : 1.f); - v = math::nd4j_ceil(inX1); - xCache.end = v--; - xCache.endMinusOneScale = - v < inX ? (v + 1 > inX1 ? st.widthScale : v + 1 - inX) : (v + 1 > inX1 ? inX1 - v - : 1.f); - xCache.needsBounding = bound(xCache.start, st.inWidth) != xCache.start || - bound(xCache.end - 1, st.inWidth) != (xCache.end - 1); - - } - }; - samediff::Threads::parallel_for(cachingProcedure, 0, xCached.size(), 1); +template +int resizeAreaFunctor_(sd::LaunchContext* context, NDArray const* image, + int const width, int const height, + bool const alignCorners, NDArray* output) { + ImageResizerState st(alignCorners, false); // Create resize info + auto res = st.validateAndCalculateOutputSize(image, width, height); + if (Status::OK() == res) { + std::vector xCached(st.outWidth); + auto cachingProcedure = PRAGMA_THREADS_FOR { + for (auto x = start; x < stop; x++) { + auto& xCache = xCached[x]; + const float inX = x * st.widthScale; + const float inX1 = (x + 1) * st.widthScale; + + Nd4jLong v = math::nd4j_floor(inX); + xCache.start = v; + xCache.startScale = v < inX + ? (v + 1 > inX1 ? st.widthScale : v + 1 - inX) + : (v + 1 > inX1 ? inX1 - v : 1.f); + v = math::nd4j_ceil(inX1); + xCache.end = v--; + xCache.endMinusOneScale = + v < inX ? (v + 1 > inX1 ? st.widthScale : v + 1 - inX) + : (v + 1 > inX1 ? inX1 - v : 1.f); + xCache.needsBounding = + bound(xCache.start, st.inWidth) != xCache.start || + bound(xCache.end - 1, st.inWidth) != (xCache.end - 1); + } + }; + samediff::Threads::parallel_for(cachingProcedure, 0, xCached.size(), 1); - resizeArea(st, xCached, image, output); - } - return res; - } + resizeArea(st, xCached, image, output); + } + return res; +} - int resizeAreaFunctor(sd::LaunchContext * context, NDArray const* image, int const width, int const height, bool const alignCorners, NDArray* output) { - BUILD_SINGLE_SELECTOR(image->dataType(), return resizeAreaFunctor_, (context, image, width, height, alignCorners, output), NUMERIC_TYPES); - } +int resizeAreaFunctor(sd::LaunchContext* context, NDArray const* image, + int const width, int const height, + bool const alignCorners, NDArray* output) { + BUILD_SINGLE_SELECTOR(image->dataType(), return resizeAreaFunctor_, + (context, image, width, height, alignCorners, output), + NUMERIC_TYPES); +} /** * resize as TF v.2.x implemented (with preserve aspect ratio and antialias flags routines @@ -1355,9 +1441,10 @@ namespace helpers { return resizeBicubicFunctor(context, image, width, height, alignCorners, false, output); case kResizeArea: return resizeAreaFunctor(context, image, width, height, alignCorners, output); + default: + nd4j_printf("helper::resizeImagesFunctor: Wrong resize method %i\n", (int)method); + return Status::CODE(ND4J_STATUS_BAD_INPUT, "helper::resizeImagesFunctor: Wrong resize method"); } - nd4j_printf("helper::resizeImagesFunctor: Wrong resize method %i\n", (int)method); - return Status::CODE(ND4J_STATUS_BAD_INPUT, "helper::resizeImagesFunctor: Wrong resize method"); } // ------------------------------------------------------------------------------------------------------------------ // int resizeFunctor(sd::LaunchContext * context, NDArray const* image, int const width, int const height, @@ -1376,7 +1463,6 @@ namespace helpers { return Status::CODE(ND4J_STATUS_BAD_INPUT, "helper::resizeFunctor: Wrong resize method"); } - -} -} -} \ No newline at end of file +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/image_suppression.cpp b/libnd4j/include/ops/declarable/helpers/cpu/image_suppression.cpp index c3ad42db3181..741292a3cd6a 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/image_suppression.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/image_suppression.cpp @@ -18,8 +18,9 @@ // @author sgazeos@gmail.com // -#include #include +#include + #include #include #include @@ -28,235 +29,288 @@ namespace sd { namespace ops { namespace helpers { - template - static void nonMaxSuppressionV2_(NDArray* boxes, NDArray* scales, int maxSize, double overlapThreshold, - double scoreThreshold, NDArray* output) { - std::vector indices(scales->lengthOf()); - std::iota(indices.begin(), indices.end(), 0); - auto actualIndicesCount = indices.size(); - for (auto e = 0; e < scales->lengthOf(); e++) { - if (scales->e(e) < (float)scoreThreshold) { - indices[e] = -1; - actualIndicesCount--; - } - } - std::sort(indices.begin(), indices.end(), [scales](int i, int j) {return i >= 0 && j >=0?scales->e(i) > scales->e(j):(i > j);}); - -// std::vector selected(output->lengthOf()); - std::vector selectedIndices(output->lengthOf(), 0); - auto needToSuppressWithThreshold = [] (NDArray& boxes, int previousIndex, int nextIndex, T threshold) -> bool { - if (previousIndex < 0 || nextIndex < 0) return true; - T minYPrev = sd::math::nd4j_min(boxes.t(previousIndex, 0), boxes.t(previousIndex, 2)); - T minXPrev = sd::math::nd4j_min(boxes.t(previousIndex, 1), boxes.t(previousIndex, 3)); - T maxYPrev = sd::math::nd4j_max(boxes.t(previousIndex, 0), boxes.t(previousIndex, 2)); - T maxXPrev = sd::math::nd4j_max(boxes.t(previousIndex, 1), boxes.t(previousIndex, 3)); - T minYNext = sd::math::nd4j_min(boxes.t(nextIndex, 0), boxes.t(nextIndex, 2)); - T minXNext = sd::math::nd4j_min(boxes.t(nextIndex, 1), boxes.t(nextIndex, 3)); - T maxYNext = sd::math::nd4j_max(boxes.t(nextIndex, 0), boxes.t(nextIndex, 2)); - T maxXNext = sd::math::nd4j_max(boxes.t(nextIndex, 1), boxes.t(nextIndex, 3)); - T areaPrev = (maxYPrev - minYPrev) * (maxXPrev - minXPrev); - T areaNext = (maxYNext - minYNext) * (maxXNext - minXNext); - - if (areaNext <= T(0.f) || areaPrev <= T(0.f)) return false; - - T minIntersectionY = sd::math::nd4j_max(minYPrev, minYNext); - T minIntersectionX = sd::math::nd4j_max(minXPrev, minXNext); - T maxIntersectionY = sd::math::nd4j_min(maxYPrev, maxYNext); - T maxIntersectionX = sd::math::nd4j_min(maxXPrev, maxXNext); - T intersectionArea = - sd::math::nd4j_max(T(maxIntersectionY - minIntersectionY), T(0.0f)) * - sd::math::nd4j_max(T(maxIntersectionX - minIntersectionX), T(0.0f)); - T intersectionValue = intersectionArea / (areaPrev + areaNext - intersectionArea); - return intersectionValue > threshold; - - }; -// int numSelected = 0; - int numBoxes = actualIndicesCount; //boxes->sizeAt(0); - int numSelected = 0; - - for (int i = 0; i < numBoxes; ++i) { - bool shouldSelect = numSelected < output->lengthOf(); - - // FIXME: add parallelism here - for (int j = numSelected - 1; j >= 0; --j) { - if (shouldSelect) - if (needToSuppressWithThreshold(*boxes, indices[i], indices[selectedIndices[j]], T(overlapThreshold))) { - shouldSelect = false; - } - } - if (shouldSelect) { - output->p(numSelected, indices[i]); - selectedIndices[numSelected++] = i; - } - } +template +static void nonMaxSuppressionV2_(NDArray* boxes, NDArray* scales, int maxSize, + double overlapThreshold, double scoreThreshold, + NDArray* output) { + std::vector indices(scales->lengthOf()); + std::iota(indices.begin(), indices.end(), 0); + auto actualIndicesCount = indices.size(); + for (auto e = 0; e < scales->lengthOf(); e++) { + if (scales->e(e) < (float)scoreThreshold) { + indices[e] = -1; + actualIndicesCount--; } -//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - // Return intersection-over-union overlap between boxes i and j - template - static inline T similirityV3_(NDArray const& boxes, Nd4jLong i, Nd4jLong j) { - const T zero = static_cast(0.f); - const T yminI = math::nd4j_min(boxes.t(i, 0), boxes.t(i, 2)); - const T xminI = math::nd4j_min(boxes.t(i, 1), boxes.t(i, 3)); - const T ymaxI = math::nd4j_max(boxes.t(i, 0), boxes.t(i, 2)); - const T xmaxI = math::nd4j_max(boxes.t(i, 1), boxes.t(i, 3)); - const T yminJ = math::nd4j_min(boxes.t(j, 0), boxes.t(j, 2)); - const T xminJ = math::nd4j_min(boxes.t(j, 1), boxes.t(j, 3)); - const T ymaxJ = math::nd4j_max(boxes.t(j, 0), boxes.t(j, 2)); - const T xmaxJ = math::nd4j_max(boxes.t(j, 1), boxes.t(j, 3)); - const T areaI = (ymaxI - yminI) * (xmaxI - xminI); - const T areaJ = (ymaxJ - yminJ) * (xmaxJ - xminJ); - if (areaI <= zero || areaJ <= zero) { - return zero; + } + std::sort(indices.begin(), indices.end(), [scales](int i, int j) { + return i >= 0 && j >= 0 ? scales->e(i) > scales->e(j) : (i > j); + }); + + // std::vector selected(output->lengthOf()); + std::vector selectedIndices(output->lengthOf(), 0); + auto needToSuppressWithThreshold = [](NDArray& boxes, int previousIndex, + int nextIndex, T threshold) -> bool { + if (previousIndex < 0 || nextIndex < 0) return true; + T minYPrev = sd::math::nd4j_min(boxes.t(previousIndex, 0), + boxes.t(previousIndex, 2)); + T minXPrev = sd::math::nd4j_min(boxes.t(previousIndex, 1), + boxes.t(previousIndex, 3)); + T maxYPrev = sd::math::nd4j_max(boxes.t(previousIndex, 0), + boxes.t(previousIndex, 2)); + T maxXPrev = sd::math::nd4j_max(boxes.t(previousIndex, 1), + boxes.t(previousIndex, 3)); + T minYNext = + sd::math::nd4j_min(boxes.t(nextIndex, 0), boxes.t(nextIndex, 2)); + T minXNext = + sd::math::nd4j_min(boxes.t(nextIndex, 1), boxes.t(nextIndex, 3)); + T maxYNext = + sd::math::nd4j_max(boxes.t(nextIndex, 0), boxes.t(nextIndex, 2)); + T maxXNext = + sd::math::nd4j_max(boxes.t(nextIndex, 1), boxes.t(nextIndex, 3)); + T areaPrev = (maxYPrev - minYPrev) * (maxXPrev - minXPrev); + T areaNext = (maxYNext - minYNext) * (maxXNext - minXNext); + + if (areaNext <= T(0.f) || areaPrev <= T(0.f)) return false; + + T minIntersectionY = sd::math::nd4j_max(minYPrev, minYNext); + T minIntersectionX = sd::math::nd4j_max(minXPrev, minXNext); + T maxIntersectionY = sd::math::nd4j_min(maxYPrev, maxYNext); + T maxIntersectionX = sd::math::nd4j_min(maxXPrev, maxXNext); + T intersectionArea = + sd::math::nd4j_max(T(maxIntersectionY - minIntersectionY), T(0.0f)) * + sd::math::nd4j_max(T(maxIntersectionX - minIntersectionX), T(0.0f)); + T intersectionValue = + intersectionArea / (areaPrev + areaNext - intersectionArea); + return intersectionValue > threshold; + }; + // int numSelected = 0; + int numBoxes = actualIndicesCount; // boxes->sizeAt(0); + int numSelected = 0; + + for (int i = 0; i < numBoxes; ++i) { + bool shouldSelect = numSelected < output->lengthOf(); + + // FIXME: add parallelism here + for (int j = numSelected - 1; j >= 0; --j) { + if (shouldSelect) + if (needToSuppressWithThreshold(*boxes, indices[i], + indices[selectedIndices[j]], + T(overlapThreshold))) { + shouldSelect = false; } - const T intersectionYmin = math::nd4j_max(yminI, yminJ); - const T intersectionXmin = math::nd4j_max(xminI, xminJ); - const T intersectionYmax = math::nd4j_min(ymaxI, ymaxJ); - const T intersectionXmax = math::nd4j_min(xmaxI, xmaxJ); - const T intersectionY = intersectionYmax - intersectionYmin; - const T intersectionX = intersectionXmax - intersectionXmin; - const T intersectionArea = math::nd4j_max(intersectionY, zero) * math::nd4j_max(intersectionX, zero); - return intersectionArea / (areaI + areaJ - intersectionArea); - } - - template - static inline T similiratyOverlaps_(NDArray const& boxes, Nd4jLong i, Nd4jLong j) { - return boxes.t(i, j); - } - - typedef NDArray (*SimiliratyFunc)(NDArray const& boxes, Nd4jLong i, Nd4jLong j); - - static NDArray similiratyOverlaps(NDArray const& boxes, Nd4jLong i, Nd4jLong j) { - NDArray res(boxes.dataType(), boxes.getContext()); // = NDArrayFactory::create(0.); - BUILD_SINGLE_SELECTOR(boxes.dataType(), res = similiratyOverlaps_, (boxes, i, j) , FLOAT_TYPES); - return res; } - - static NDArray similiratyV3(NDArray const& boxes, Nd4jLong i, Nd4jLong j) { - NDArray res(boxes.dataType(), boxes.getContext()); // = NDArrayFactory::create(0.); - BUILD_SINGLE_SELECTOR(boxes.dataType(), res = similirityV3_, (boxes, i, j) , FLOAT_TYPES); - return res; + if (shouldSelect) { + output->p(numSelected, indices[i]); + selectedIndices[numSelected++] = i; } - + } +} //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - template - static Nd4jLong - nonMaxSuppressionGeneric_(sd::LaunchContext* context, NDArray* boxes, NDArray* scores, int outputSize, - float overlapThreshold, float scoreThreshold, NDArray* output, SimiliratyFunc f) { - - auto numBoxes = boxes->sizeAt(0); - T* scoresData = scores->dataBuffer()->primaryAsT(); - - // Data structure for a selection candidate in NMS. - struct Candidate { - int _boxIndex; - T _score; - int _suppressBeginIndex; - }; - - auto cmp = [](const Candidate& bsI, const Candidate& bsJ) -> bool{ - return ((bsI._score == bsJ._score) && (bsI._boxIndex > bsJ._boxIndex)) || - (bsI._score < bsJ._score); - }; - - std::priority_queue, decltype(cmp)> candidatePriorityQueue(cmp); - for (auto i = 0; i < scores->lengthOf(); ++i) { - if ((float)scoresData[i] > (float)scoreThreshold) { - candidatePriorityQueue.emplace(Candidate({i, scoresData[i], 0})); - } - } - - std::vector selected; - T similarity, originalScore; - Candidate nextCandidate; - - while (selected.size() < outputSize && !candidatePriorityQueue.empty()) { - nextCandidate = candidatePriorityQueue.top(); - originalScore = nextCandidate._score; - candidatePriorityQueue.pop(); - - // Overlapping boxes are likely to have similar scores, therefore we - // iterate through the previously selected boxes backwards in order to - // see if `nextCandidate` should be suppressed. We also enforce a property - // that a candidate can be suppressed by another candidate no more than - // once via `suppress_begin_index` which tracks which previously selected - // boxes have already been compared against next_candidate prior to a given - // iteration. These previous selected boxes are then skipped over in the - // following loop. - bool shouldHardSuppress = false; - for (int j = static_cast(selected.size()) - 1; j >= nextCandidate._suppressBeginIndex; --j) { - auto similarityA = f(*boxes, nextCandidate._boxIndex, selected[j]); //boxes->t(nextCandidate._boxIndex, selected[j]); - similarity = similarityA.template t(0); - nextCandidate._score *= T(similarity <= overlapThreshold?1.0:0.); //suppressWeightFunc(similarity); +// Return intersection-over-union overlap between boxes i and j +template +static inline T similirityV3_(NDArray const& boxes, Nd4jLong i, Nd4jLong j) { + const T zero = static_cast(0.f); + const T yminI = math::nd4j_min(boxes.t(i, 0), boxes.t(i, 2)); + const T xminI = math::nd4j_min(boxes.t(i, 1), boxes.t(i, 3)); + const T ymaxI = math::nd4j_max(boxes.t(i, 0), boxes.t(i, 2)); + const T xmaxI = math::nd4j_max(boxes.t(i, 1), boxes.t(i, 3)); + const T yminJ = math::nd4j_min(boxes.t(j, 0), boxes.t(j, 2)); + const T xminJ = math::nd4j_min(boxes.t(j, 1), boxes.t(j, 3)); + const T ymaxJ = math::nd4j_max(boxes.t(j, 0), boxes.t(j, 2)); + const T xmaxJ = math::nd4j_max(boxes.t(j, 1), boxes.t(j, 3)); + const T areaI = (ymaxI - yminI) * (xmaxI - xminI); + const T areaJ = (ymaxJ - yminJ) * (xmaxJ - xminJ); + if (areaI <= zero || areaJ <= zero) { + return zero; + } + const T intersectionYmin = math::nd4j_max(yminI, yminJ); + const T intersectionXmin = math::nd4j_max(xminI, xminJ); + const T intersectionYmax = math::nd4j_min(ymaxI, ymaxJ); + const T intersectionXmax = math::nd4j_min(xmaxI, xmaxJ); + const T intersectionY = intersectionYmax - intersectionYmin; + const T intersectionX = intersectionXmax - intersectionXmin; + const T intersectionArea = + math::nd4j_max(intersectionY, zero) * math::nd4j_max(intersectionX, zero); + return intersectionArea / (areaI + areaJ - intersectionArea); +} - // First decide whether to perform hard suppression - if ((float)similarity >= static_cast(overlapThreshold)) { - shouldHardSuppress = true; - break; - } +template +static inline T similiratyOverlaps_(NDArray const& boxes, Nd4jLong i, + Nd4jLong j) { + return boxes.t(i, j); +} - // If next_candidate survives hard suppression, apply soft suppression - if ((float)nextCandidate._score <= (float)scoreThreshold) break; - } - // If `nextCandidate._score` has not dropped below `scoreThreshold` - // by this point, then we know that we went through all of the previous - // selections and can safely update `suppress_begin_index` to - // `selected.size()`. If on the other hand `next_candidate.score` - // *has* dropped below the score threshold, then since `suppressWeight` - // always returns values in [0, 1], further suppression by items that were - // not covered in the above for loop would not have caused the algorithm - // to select this item. We thus do the same update to - // `suppressBeginIndex`, but really, this element will not be added back - // into the priority queue in the following. - nextCandidate._suppressBeginIndex = selected.size(); +typedef NDArray (*SimiliratyFunc)(NDArray const& boxes, Nd4jLong i, Nd4jLong j); - if (!shouldHardSuppress) { - if (nextCandidate._score == originalScore) { - // Suppression has not occurred, so select next_candidate - selected.push_back(nextCandidate._boxIndex); -// selected_scores.push_back(nextCandidate._score); - } - if ((float)nextCandidate._score > (float)scoreThreshold) { - // Soft suppression has occurred and current score is still greater than - // score_threshold; add next_candidate back onto priority queue. - candidatePriorityQueue.push(nextCandidate); - } - } - } +static NDArray similiratyOverlaps(NDArray const& boxes, Nd4jLong i, + Nd4jLong j) { + NDArray res(boxes.dataType(), + boxes.getContext()); // = NDArrayFactory::create(0.); + BUILD_SINGLE_SELECTOR(boxes.dataType(), res = similiratyOverlaps_, + (boxes, i, j), FLOAT_TYPES); + return res; +} - if (output) { - DataBuffer buf(selected.data(), selected.size() * sizeof(I), DataTypeUtils::fromT()); - output->dataBuffer()->copyBufferFrom(buf, buf.getLenInBytes()); - } +static NDArray similiratyV3(NDArray const& boxes, Nd4jLong i, Nd4jLong j) { + NDArray res(boxes.dataType(), + boxes.getContext()); // = NDArrayFactory::create(0.); + BUILD_SINGLE_SELECTOR(boxes.dataType(), res = similirityV3_, (boxes, i, j), + FLOAT_TYPES); + return res; +} - return (Nd4jLong)selected.size(); +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +template +static Nd4jLong nonMaxSuppressionGeneric_(sd::LaunchContext* context, + NDArray* boxes, NDArray* scores, + int outputSize, + float overlapThreshold, + float scoreThreshold, NDArray* output, + SimiliratyFunc f) { + auto numBoxes = boxes->sizeAt(0); + T* scoresData = scores->dataBuffer()->primaryAsT(); + + // Data structure for a selection candidate in NMS. + struct Candidate { + int _boxIndex; + T _score; + int _suppressBeginIndex; + }; + + auto cmp = [](const Candidate& bsI, const Candidate& bsJ) -> bool { + return ((bsI._score == bsJ._score) && (bsI._boxIndex > bsJ._boxIndex)) || + (bsI._score < bsJ._score); + }; + + std::priority_queue, decltype(cmp)> + candidatePriorityQueue(cmp); + for (auto i = 0; i < scores->lengthOf(); ++i) { + if ((float)scoresData[i] > (float)scoreThreshold) { + candidatePriorityQueue.emplace(Candidate({i, scoresData[i], 0})); } - - Nd4jLong - nonMaxSuppressionGeneric(sd::LaunchContext* context, NDArray* boxes, NDArray* scores, int maxSize, - double overlapThreshold, double scoreThreshold, NDArray* output) { - BUILD_DOUBLE_SELECTOR(boxes->dataType(), output == nullptr?DataType::INT32:output->dataType(), return nonMaxSuppressionGeneric_, (context, boxes, scores, maxSize, overlapThreshold, scoreThreshold, output, similiratyOverlaps), FLOAT_TYPES, INTEGER_TYPES); - return 0; + } + + std::vector selected; + T similarity, originalScore; + Candidate nextCandidate; + + while (selected.size() < outputSize && !candidatePriorityQueue.empty()) { + nextCandidate = candidatePriorityQueue.top(); + originalScore = nextCandidate._score; + candidatePriorityQueue.pop(); + + // Overlapping boxes are likely to have similar scores, therefore we + // iterate through the previously selected boxes backwards in order to + // see if `nextCandidate` should be suppressed. We also enforce a property + // that a candidate can be suppressed by another candidate no more than + // once via `suppress_begin_index` which tracks which previously selected + // boxes have already been compared against next_candidate prior to a given + // iteration. These previous selected boxes are then skipped over in the + // following loop. + bool shouldHardSuppress = false; + for (int j = static_cast(selected.size()) - 1; + j >= nextCandidate._suppressBeginIndex; --j) { + auto similarityA = + f(*boxes, nextCandidate._boxIndex, + selected[j]); // boxes->t(nextCandidate._boxIndex, selected[j]); + similarity = similarityA.template t(0); + nextCandidate._score *= T(similarity <= overlapThreshold + ? 1.0 + : 0.); // suppressWeightFunc(similarity); + + // First decide whether to perform hard suppression + if ((float)similarity >= static_cast(overlapThreshold)) { + shouldHardSuppress = true; + break; + } + + // If next_candidate survives hard suppression, apply soft suppression + if ((float)nextCandidate._score <= (float)scoreThreshold) break; } - - Nd4jLong - nonMaxSuppressionV3(sd::LaunchContext* context, NDArray* boxes, NDArray* scores, int maxSize, - double overlapThreshold, double scoreThreshold, NDArray* output) { - BUILD_DOUBLE_SELECTOR(boxes->dataType(), output == nullptr?DataType::INT32:output->dataType(), return nonMaxSuppressionGeneric_, (context, boxes, scores, maxSize, overlapThreshold, scoreThreshold, output, similiratyV3), FLOAT_TYPES, INTEGER_TYPES); - return 0; + // If `nextCandidate._score` has not dropped below `scoreThreshold` + // by this point, then we know that we went through all of the previous + // selections and can safely update `suppress_begin_index` to + // `selected.size()`. If on the other hand `next_candidate.score` + // *has* dropped below the score threshold, then since `suppressWeight` + // always returns values in [0, 1], further suppression by items that were + // not covered in the above for loop would not have caused the algorithm + // to select this item. We thus do the same update to + // `suppressBeginIndex`, but really, this element will not be added back + // into the priority queue in the following. + nextCandidate._suppressBeginIndex = selected.size(); + + if (!shouldHardSuppress) { + if (nextCandidate._score == originalScore) { + // Suppression has not occurred, so select next_candidate + selected.push_back(nextCandidate._boxIndex); + // selected_scores.push_back(nextCandidate._score); + } + if ((float)nextCandidate._score > (float)scoreThreshold) { + // Soft suppression has occurred and current score is still greater than + // score_threshold; add next_candidate back onto priority queue. + candidatePriorityQueue.push(nextCandidate); + } } + } - BUILD_DOUBLE_TEMPLATE(template Nd4jLong nonMaxSuppressionGeneric_, (sd::LaunchContext* context, NDArray* boxes, NDArray* scores, int maxSize, - float overlapThreshold, float scoreThreshold, NDArray* output, SimiliratyFunc similiratyFunc), FLOAT_TYPES, INTEGER_TYPES); + if (output) { + DataBuffer buf(selected.data(), selected.size() * sizeof(I), + DataTypeUtils::fromT()); + output->dataBuffer()->copyBufferFrom(buf, buf.getLenInBytes()); + } - void - nonMaxSuppression(sd::LaunchContext * context, NDArray* boxes, NDArray* scales, int maxSize, - double overlapThreshold, double scoreThreshold, NDArray* output) { - BUILD_SINGLE_SELECTOR(boxes->dataType(), nonMaxSuppressionV2_, (boxes, scales, maxSize, - overlapThreshold, scoreThreshold, output), NUMERIC_TYPES); - } - BUILD_SINGLE_TEMPLATE(template void nonMaxSuppressionV2_, (NDArray* boxes, NDArray* scales, int maxSize, - double overlapThreshold, double scoreThreshold, NDArray* output), NUMERIC_TYPES); + return (Nd4jLong)selected.size(); +} +Nd4jLong nonMaxSuppressionGeneric(sd::LaunchContext* context, NDArray* boxes, + NDArray* scores, int maxSize, + double overlapThreshold, + double scoreThreshold, NDArray* output) { + BUILD_DOUBLE_SELECTOR( + boxes->dataType(), + output == nullptr ? DataType::INT32 : output->dataType(), + return nonMaxSuppressionGeneric_, + (context, boxes, scores, maxSize, overlapThreshold, scoreThreshold, + output, similiratyOverlaps), + FLOAT_TYPES, INTEGER_TYPES); + return 0; } + +Nd4jLong nonMaxSuppressionV3(sd::LaunchContext* context, NDArray* boxes, + NDArray* scores, int maxSize, + double overlapThreshold, double scoreThreshold, + NDArray* output) { + BUILD_DOUBLE_SELECTOR( + boxes->dataType(), + output == nullptr ? DataType::INT32 : output->dataType(), + return nonMaxSuppressionGeneric_, + (context, boxes, scores, maxSize, overlapThreshold, scoreThreshold, + output, similiratyV3), + FLOAT_TYPES, INTEGER_TYPES); + return 0; +} + +BUILD_DOUBLE_TEMPLATE(template Nd4jLong nonMaxSuppressionGeneric_, + (sd::LaunchContext * context, NDArray* boxes, + NDArray* scores, int maxSize, float overlapThreshold, + float scoreThreshold, NDArray* output, + SimiliratyFunc similiratyFunc), + FLOAT_TYPES, INTEGER_TYPES); + +void nonMaxSuppression(sd::LaunchContext* context, NDArray* boxes, + NDArray* scales, int maxSize, double overlapThreshold, + double scoreThreshold, NDArray* output) { + BUILD_SINGLE_SELECTOR( + boxes->dataType(), nonMaxSuppressionV2_, + (boxes, scales, maxSize, overlapThreshold, scoreThreshold, output), + NUMERIC_TYPES); } -} \ No newline at end of file +BUILD_SINGLE_TEMPLATE(template void nonMaxSuppressionV2_, + (NDArray * boxes, NDArray* scales, int maxSize, + double overlapThreshold, double scoreThreshold, + NDArray* output), + NUMERIC_TYPES); + +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/imagesHelpers.cpp b/libnd4j/include/ops/declarable/helpers/cpu/imagesHelpers.cpp index 108804f38624..4f28757e69e6 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/imagesHelpers.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/imagesHelpers.cpp @@ -19,10 +19,10 @@ // @author AbdelRauf (rauf@konduit.ai) // +#include +#include #include #include -#include -#include namespace sd { namespace ops { @@ -30,259 +30,271 @@ namespace helpers { template static void rgbToGrs_(const NDArray& input, NDArray& output, const int dimC) { - - const T* x = input.bufferAsT(); - T* z = output.bufferAsT(); - const int rank = input.rankOf(); - - if(dimC == rank - 1 && 'c' == input.ordering() && 1 == input.ews() && - 'c' == output.ordering() && 1 == output.ews()){ - - auto func = PRAGMA_THREADS_FOR{ - for (auto i = start; i < stop; i++) { - const auto xStep = i*3; - z[i] = 0.2989f*x[xStep] + 0.5870f*x[xStep + 1] + 0.1140f*x[xStep + 2]; - } - }; - - samediff::Threads::parallel_for(func, 0, output.lengthOf(), 1); - return; - } - - auto func = PRAGMA_THREADS_FOR{ - - int coords[MAX_RANK]; - for (auto i = start; i < stop; i++) { - shape::index2coordsCPU(start, i, output.shapeInfo(), coords); - const auto zOffset = shape::getOffset(output.shapeInfo(), coords); - const auto xOffset0 = shape::getOffset(input.shapeInfo(), coords); - const auto xOffset1 = xOffset0 + input.strideAt(dimC); - const auto xOffset2 = xOffset1 + input.strideAt(dimC); - z[zOffset] = 0.2989f*x[xOffset0] + 0.5870f*x[xOffset1] + 0.1140f*x[xOffset2]; - } + const T* x = input.bufferAsT(); + T* z = output.bufferAsT(); + const int rank = input.rankOf(); + + if (dimC == rank - 1 && 'c' == input.ordering() && 1 == input.ews() && + 'c' == output.ordering() && 1 == output.ews()) { + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + const auto xStep = i * 3; + z[i] = 0.2989f * x[xStep] + 0.5870f * x[xStep + 1] + + 0.1140f * x[xStep + 2]; + } }; samediff::Threads::parallel_for(func, 0, output.lengthOf(), 1); return; + } + + auto func = PRAGMA_THREADS_FOR { + int coords[MAX_RANK]; + for (auto i = start; i < stop; i++) { + shape::index2coordsCPU(start, i, output.shapeInfo(), coords); + const auto zOffset = shape::getOffset(output.shapeInfo(), coords); + const auto xOffset0 = shape::getOffset(input.shapeInfo(), coords); + const auto xOffset1 = xOffset0 + input.strideAt(dimC); + const auto xOffset2 = xOffset1 + input.strideAt(dimC); + z[zOffset] = + 0.2989f * x[xOffset0] + 0.5870f * x[xOffset1] + 0.1140f * x[xOffset2]; + } + }; + + samediff::Threads::parallel_for(func, 0, output.lengthOf(), 1); + return; } -void transformRgbGrs(sd::LaunchContext* context, const NDArray& input, NDArray& output, const int dimC) { - BUILD_SINGLE_SELECTOR(input.dataType(), rgbToGrs_, (input, output, dimC), NUMERIC_TYPES); +void transformRgbGrs(sd::LaunchContext* context, const NDArray& input, + NDArray& output, const int dimC) { + BUILD_SINGLE_SELECTOR(input.dataType(), rgbToGrs_, (input, output, dimC), + NUMERIC_TYPES); } template -FORCEINLINE static void rgbToFromYuv_(const NDArray& input, NDArray& output, const int dimC, Op op) { - - const T* x = input.bufferAsT(); - T* z = output.bufferAsT(); - const int rank = input.rankOf(); - bool bSimple = (dimC == rank - 1 && 'c' == input.ordering() && 1 == input.ews() && - 'c' == output.ordering() && 1 == output.ews()); - - if (bSimple) { - - auto func = PRAGMA_THREADS_FOR{ - for (auto i = start; i < stop; i += increment) { - op(x[i], x[i + 1], x[i + 2], z[i], z[i + 1], z[i + 2]); - } - }; - - samediff::Threads::parallel_for(func, 0, input.lengthOf(), 3); - return; - } - - auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(input.shapeInfo(), dimC); - auto packZ = sd::ConstantTadHelper::getInstance().tadForDimensions(output.shapeInfo(), dimC); - - const Nd4jLong numOfTads = packX.numberOfTads(); - const Nd4jLong xDimCstride = input.stridesOf()[dimC]; - const Nd4jLong zDimCstride = output.stridesOf()[dimC]; - - auto func = PRAGMA_THREADS_FOR{ - for (auto i = start; i < stop; i++) { - const T* xTad = x + packX.platformOffsets()[i]; - T* zTad = z + packZ.platformOffsets()[i]; - op(xTad[0], xTad[xDimCstride], xTad[2 * xDimCstride], zTad[0], zTad[zDimCstride], zTad[2 * zDimCstride]); - } +FORCEINLINE static void rgbToFromYuv_(const NDArray& input, NDArray& output, + const int dimC, Op op) { + const T* x = input.bufferAsT(); + T* z = output.bufferAsT(); + const int rank = input.rankOf(); + bool bSimple = + (dimC == rank - 1 && 'c' == input.ordering() && 1 == input.ews() && + 'c' == output.ordering() && 1 == output.ews()); + + if (bSimple) { + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i += increment) { + op(x[i], x[i + 1], x[i + 2], z[i], z[i + 1], z[i + 2]); + } }; - samediff::Threads::parallel_tad(func, 0, numOfTads); + samediff::Threads::parallel_for(func, 0, input.lengthOf(), 3); return; + } + + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions( + input.shapeInfo(), dimC); + auto packZ = sd::ConstantTadHelper::getInstance().tadForDimensions( + output.shapeInfo(), dimC); + + const Nd4jLong numOfTads = packX.numberOfTads(); + const Nd4jLong xDimCstride = input.stridesOf()[dimC]; + const Nd4jLong zDimCstride = output.stridesOf()[dimC]; + + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + const T* xTad = x + packX.platformOffsets()[i]; + T* zTad = z + packZ.platformOffsets()[i]; + op(xTad[0], xTad[xDimCstride], xTad[2 * xDimCstride], zTad[0], + zTad[zDimCstride], zTad[2 * zDimCstride]); + } + }; + + samediff::Threads::parallel_tad(func, 0, numOfTads); + return; } template -FORCEINLINE static void rgbYuv_(const NDArray& input, NDArray& output, const int dimC) { - auto op = sd::ops::helpers::rgbYuv; - return rgbToFromYuv_(input, output, dimC, op); +FORCEINLINE static void rgbYuv_(const NDArray& input, NDArray& output, + const int dimC) { + auto op = sd::ops::helpers::rgbYuv; + return rgbToFromYuv_(input, output, dimC, op); } -void transformRgbYuv(sd::LaunchContext* context, const NDArray& input, NDArray& output, const int dimC) { - BUILD_SINGLE_SELECTOR(input.dataType(), rgbYuv_, (input, output, dimC), FLOAT_TYPES); +void transformRgbYuv(sd::LaunchContext* context, const NDArray& input, + NDArray& output, const int dimC) { + BUILD_SINGLE_SELECTOR(input.dataType(), rgbYuv_, (input, output, dimC), + FLOAT_TYPES); } template -FORCEINLINE static void yuvRgb_(const NDArray& input, NDArray& output, const int dimC) { - auto op = sd::ops::helpers::yuvRgb; - return rgbToFromYuv_(input, output, dimC, op); +FORCEINLINE static void yuvRgb_(const NDArray& input, NDArray& output, + const int dimC) { + auto op = sd::ops::helpers::yuvRgb; + return rgbToFromYuv_(input, output, dimC, op); } -void transformYuvRgb(sd::LaunchContext* context, const NDArray& input, NDArray& output, const int dimC) { - BUILD_SINGLE_SELECTOR(input.dataType(), yuvRgb_, (input, output, dimC), FLOAT_TYPES); +void transformYuvRgb(sd::LaunchContext* context, const NDArray& input, + NDArray& output, const int dimC) { + BUILD_SINGLE_SELECTOR(input.dataType(), yuvRgb_, (input, output, dimC), + FLOAT_TYPES); } template -FORCEINLINE static void tripleTransformer(const NDArray* input, NDArray* output, const int dimC, Op op) { - - const int rank = input->rankOf(); - - const T* x = input->bufferAsT(); - T* z = output->bufferAsT(); - - if (dimC == rank - 1 && input->ews() == 1 && output->ews() == 1 && input->ordering() == 'c' && output->ordering() == 'c') { - - auto func = PRAGMA_THREADS_FOR{ - for (auto i = start; i < stop; i += increment) { - op(x[i], x[i + 1], x[i + 2], z[i], z[i + 1], z[i + 2]); - } - }; - - samediff::Threads::parallel_for(func, 0, input->lengthOf(), 3); - } - else { - auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(input->shapeInfo(), dimC); - auto packZ = sd::ConstantTadHelper::getInstance().tadForDimensions(output->shapeInfo(), dimC); - - const Nd4jLong numOfTads = packX.numberOfTads(); - const Nd4jLong xDimCstride = input->stridesOf()[dimC]; - const Nd4jLong zDimCstride = output->stridesOf()[dimC]; +FORCEINLINE static void tripleTransformer(const NDArray* input, NDArray* output, + const int dimC, Op op) { + const int rank = input->rankOf(); + + const T* x = input->bufferAsT(); + T* z = output->bufferAsT(); + + if (dimC == rank - 1 && input->ews() == 1 && output->ews() == 1 && + input->ordering() == 'c' && output->ordering() == 'c') { + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i += increment) { + op(x[i], x[i + 1], x[i + 2], z[i], z[i + 1], z[i + 2]); + } + }; - auto func = PRAGMA_THREADS_FOR{ - for (auto i = start; i < stop; i++) { - const T* xTad = x + packX.platformOffsets()[i]; - T* zTad = z + packZ.platformOffsets()[i]; - op(xTad[0], xTad[xDimCstride], xTad[2 * xDimCstride], zTad[0], zTad[zDimCstride], zTad[2 * zDimCstride]); + samediff::Threads::parallel_for(func, 0, input->lengthOf(), 3); + } else { + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions( + input->shapeInfo(), dimC); + auto packZ = sd::ConstantTadHelper::getInstance().tadForDimensions( + output->shapeInfo(), dimC); - } - }; + const Nd4jLong numOfTads = packX.numberOfTads(); + const Nd4jLong xDimCstride = input->stridesOf()[dimC]; + const Nd4jLong zDimCstride = output->stridesOf()[dimC]; + + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + const T* xTad = x + packX.platformOffsets()[i]; + T* zTad = z + packZ.platformOffsets()[i]; + op(xTad[0], xTad[xDimCstride], xTad[2 * xDimCstride], zTad[0], + zTad[zDimCstride], zTad[2 * zDimCstride]); + } + }; - samediff::Threads::parallel_tad(func, 0, numOfTads); - } + samediff::Threads::parallel_tad(func, 0, numOfTads); + } } - template -FORCEINLINE static void tripleTransformer(const NDArray* input, NDArray* output, const int dimC , T (&tr)[3][3] ) { - - const int rank = input->rankOf(); - - const T* x = input->bufferAsT(); - T* z = output->bufferAsT(); - // TODO: Use tensordot or other optimizied helpers to see if we can get better performance. - - if (dimC == rank - 1 && input->ews() == 1 && output->ews() == 1 && input->ordering() == 'c' && output->ordering() == 'c') { +FORCEINLINE static void tripleTransformer(const NDArray* input, NDArray* output, + const int dimC, T (&tr)[3][3]) { + const int rank = input->rankOf(); + + const T* x = input->bufferAsT(); + T* z = output->bufferAsT(); + // TODO: Use tensordot or other optimizied helpers to see if we can get better + // performance. + + if (dimC == rank - 1 && input->ews() == 1 && output->ews() == 1 && + input->ordering() == 'c' && output->ordering() == 'c') { + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i += increment) { + // simple M*v //tr.T*v.T // v * tr //rule: (AB)' =B'A' + // v.shape (1,3) row vector + T x0, x1, x2; + x0 = x[i]; // just additional hint + x1 = x[i + 1]; + x2 = x[i + 2]; + z[i] = x0 * tr[0][0] + x1 * tr[1][0] + x2 * tr[2][0]; + z[i + 1] = x0 * tr[0][1] + x1 * tr[1][1] + x2 * tr[2][1]; + z[i + 2] = x0 * tr[0][2] + x1 * tr[1][2] + x2 * tr[2][2]; + } + }; - auto func = PRAGMA_THREADS_FOR{ - for (auto i = start; i < stop; i += increment) { - //simple M*v //tr.T*v.T // v * tr //rule: (AB)' =B'A' - // v.shape (1,3) row vector - T x0, x1, x2; - x0 = x[i]; //just additional hint - x1 = x[i + 1]; - x2 = x[i + 2]; - z[i] = x0 * tr[0][0] + x1 * tr[1][0] + x2 * tr[2][0]; - z[i+1] = x0 * tr[0][1] + x1 * tr[1][1] + x2 * tr[2][1]; - z[i+2] = x0 * tr[0][2] + x1 * tr[1][2] + x2 * tr[2][2]; + samediff::Threads::parallel_for(func, 0, input->lengthOf(), 3); + } else { + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions( + input->shapeInfo(), dimC); + auto packZ = sd::ConstantTadHelper::getInstance().tadForDimensions( + output->shapeInfo(), dimC); - } - }; + const Nd4jLong numOfTads = packX.numberOfTads(); + const Nd4jLong xDimCstride = input->stridesOf()[dimC]; + const Nd4jLong zDimCstride = output->stridesOf()[dimC]; + + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + const T* xTad = x + packX.platformOffsets()[i]; + T* zTad = z + packZ.platformOffsets()[i]; + // simple M*v //tr.T*v + T x0, x1, x2; + x0 = xTad[0]; + x1 = xTad[xDimCstride]; + x2 = xTad[2 * xDimCstride]; + zTad[0] = x0 * tr[0][0] + x1 * tr[1][0] + x2 * tr[2][0]; + zTad[zDimCstride] = x0 * tr[0][1] + x1 * tr[1][1] + x2 * tr[2][1]; + zTad[2 * zDimCstride] = x0 * tr[0][2] + x1 * tr[1][2] + x2 * tr[2][2]; + } + }; - samediff::Threads::parallel_for(func, 0, input->lengthOf(), 3); - } - else { - auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(input->shapeInfo(), dimC); - auto packZ = sd::ConstantTadHelper::getInstance().tadForDimensions(output->shapeInfo(), dimC); - - const Nd4jLong numOfTads = packX.numberOfTads(); - const Nd4jLong xDimCstride = input->stridesOf()[dimC]; - const Nd4jLong zDimCstride = output->stridesOf()[dimC]; - - auto func = PRAGMA_THREADS_FOR{ - for (auto i = start; i < stop; i++) { - const T* xTad = x + packX.platformOffsets()[i]; - T* zTad = z + packZ.platformOffsets()[i]; - //simple M*v //tr.T*v - T x0, x1, x2; - x0 = xTad[0]; - x1 = xTad[xDimCstride]; - x2 = xTad[2 * xDimCstride]; - zTad[0] = x0 * tr[0][0] + x1 * tr[1][0] + x2 * tr[2][0]; - zTad[zDimCstride] = x0 * tr[0][1] + x1 * tr[1][1] + x2 * tr[2][1]; - zTad[2 * zDimCstride] = x0 * tr[0][2] + x1 * tr[1][2] + x2 * tr[2][2]; - - } - }; - - samediff::Threads::parallel_tad(func, 0, numOfTads); - } + samediff::Threads::parallel_tad(func, 0, numOfTads); + } } - - - template -FORCEINLINE static void hsvRgb(const NDArray* input, NDArray* output, const int dimC) { - auto op = sd::ops::helpers::hsvToRgb; - return tripleTransformer(input, output, dimC, op); +FORCEINLINE static void hsvRgb(const NDArray* input, NDArray* output, + const int dimC) { + auto op = sd::ops::helpers::hsvToRgb; + return tripleTransformer(input, output, dimC, op); } template -FORCEINLINE static void rgbHsv(const NDArray* input, NDArray* output, const int dimC) { - auto op = sd::ops::helpers::rgbToHsv; - return tripleTransformer(input, output, dimC, op); +FORCEINLINE static void rgbHsv(const NDArray* input, NDArray* output, + const int dimC) { + auto op = sd::ops::helpers::rgbToHsv; + return tripleTransformer(input, output, dimC, op); } - template -FORCEINLINE static void rgbYiq(const NDArray* input, NDArray* output, const int dimC) { - T arr[3][3] = { - { (T)0.299, (T)0.59590059, (T)0.2115 }, - { (T)0.587, (T)-0.27455667, (T)-0.52273617 }, - { (T)0.114, (T)-0.32134392, (T)0.31119955 } - }; - return tripleTransformer(input, output, dimC, arr); +FORCEINLINE static void rgbYiq(const NDArray* input, NDArray* output, + const int dimC) { + T arr[3][3] = {{(T)0.299, (T)0.59590059, (T)0.2115}, + {(T)0.587, (T)-0.27455667, (T)-0.52273617}, + {(T)0.114, (T)-0.32134392, (T)0.31119955}}; + return tripleTransformer(input, output, dimC, arr); } template -FORCEINLINE static void yiqRgb(const NDArray* input, NDArray* output, const int dimC) { - //TODO: this operation does not use the clamp operation, so there is a possibility being out of range. - //Justify that it will not be out of range for images data - T arr[3][3] = { - { (T)1, (T)1, (T)1 }, - { (T)0.95598634, (T)-0.27201283, (T)-1.10674021 }, - { (T)0.6208248, (T)-0.64720424, (T)1.70423049 } - }; - return tripleTransformer(input, output, dimC, arr); +FORCEINLINE static void yiqRgb(const NDArray* input, NDArray* output, + const int dimC) { + // TODO: this operation does not use the clamp operation, so there is a + // possibility being out of range. Justify that it will not be out of range + // for images data + T arr[3][3] = {{(T)1, (T)1, (T)1}, + {(T)0.95598634, (T)-0.27201283, (T)-1.10674021}, + {(T)0.6208248, (T)-0.64720424, (T)1.70423049}}; + return tripleTransformer(input, output, dimC, arr); } - - -void transformHsvRgb(sd::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC) { - BUILD_SINGLE_SELECTOR(input->dataType(), hsvRgb, (input, output, dimC), FLOAT_TYPES); +void transformHsvRgb(sd::LaunchContext* context, const NDArray* input, + NDArray* output, const int dimC) { + BUILD_SINGLE_SELECTOR(input->dataType(), hsvRgb, (input, output, dimC), + FLOAT_TYPES); } -void transformRgbHsv(sd::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC) { - BUILD_SINGLE_SELECTOR(input->dataType(), rgbHsv, (input, output, dimC), FLOAT_TYPES); +void transformRgbHsv(sd::LaunchContext* context, const NDArray* input, + NDArray* output, const int dimC) { + BUILD_SINGLE_SELECTOR(input->dataType(), rgbHsv, (input, output, dimC), + FLOAT_TYPES); } -void transformYiqRgb(sd::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC) { - BUILD_SINGLE_SELECTOR(input->dataType(), yiqRgb, (input, output, dimC), FLOAT_TYPES); +void transformYiqRgb(sd::LaunchContext* context, const NDArray* input, + NDArray* output, const int dimC) { + BUILD_SINGLE_SELECTOR(input->dataType(), yiqRgb, (input, output, dimC), + FLOAT_TYPES); } -void transformRgbYiq(sd::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC) { - BUILD_SINGLE_SELECTOR(input->dataType(), rgbYiq, (input, output, dimC), FLOAT_TYPES); +void transformRgbYiq(sd::LaunchContext* context, const NDArray* input, + NDArray* output, const int dimC) { + BUILD_SINGLE_SELECTOR(input->dataType(), rgbYiq, (input, output, dimC), + FLOAT_TYPES); } - -} -} -} \ No newline at end of file +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/invertPermutation.cpp b/libnd4j/include/ops/declarable/helpers/cpu/invertPermutation.cpp index 5325ac282b95..1bff0a9ecea6 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/invertPermutation.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/invertPermutation.cpp @@ -18,34 +18,37 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 20.04.2018 // - -#include #include +#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { //////////////////////////////////////////////////////////////////////// -void invertPermutation(sd::LaunchContext * context, const NDArray& input, NDArray& output) { - - std::set uniqueElems; - const int length = input.lengthOf(); - - for(int i = 0; i < length; ++i) { - - int elem = input.e(i); - - if(!uniqueElems.insert(elem).second) // this operation forbids us to use #pragma omp - throw std::runtime_error("helpers::invertPermutation function: input array contains duplicates !"); - - if(elem < 0 || elem > length - 1) - throw std::runtime_error("helpers::invertPermutation function: element of input array is out of range (0, length-1) !"); - - output.p(elem, i); - } +void invertPermutation(sd::LaunchContext* context, const NDArray& input, + NDArray& output) { + std::set uniqueElems; + const int length = input.lengthOf(); + + for (int i = 0; i < length; ++i) { + int elem = input.e(i); + + if (!uniqueElems.insert(elem) + .second) // this operation forbids us to use #pragma omp + throw std::runtime_error( + "helpers::invertPermutation function: input array contains " + "duplicates !"); + + if (elem < 0 || elem > length - 1) + throw std::runtime_error( + "helpers::invertPermutation function: element of input array is out " + "of range (0, length-1) !"); + + output.p(elem, i); + } } -} -} -} +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/ismax.cpp b/libnd4j/include/ops/declarable/helpers/cpu/ismax.cpp index c2bcb8399bd1..e3e7e0b61608 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/ismax.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/ismax.cpp @@ -19,193 +19,183 @@ // @author raver119@gmail.com // - -#include -#include -#include #include +#include +#include +#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { template -static void ismax_(const NDArray* input, NDArray* output, const std::vector& dimensions) { - - if (input->isVector()) { - int dimensionsLength = dimensions.size(); - int length = input->lengthOf(); - if (!dimensions.empty() && (input->shapeOf())[dimensions[0]] == 1) { - for (int i = 0; i < length; i++) - output->p(i, 1); - } - else { - int eleStride = shape::elementWiseStride(input->shapeInfo()); - if (eleStride == 1) { - int maxIdx = 0; - auto currMax = input->e(0); - if (length < ELEMENT_THRESHOLD) { - - for (int i = 0; i < length; i++) { - if (currMax < input->e(i)) { - currMax = input->e(i); - maxIdx = i; - } - output->p(i, 0); - } - } - else { - - { - int maxIdxLocal = maxIdx; - auto currMaxLocal = currMax; - - for (int i = 0; i < length; i++) { - if (currMaxLocal < input->e(i)) { - currMaxLocal = input->e(i); - maxIdxLocal = i; - } - output->p(i, 0); - } - - PRAGMA_OMP_CRITICAL - { - if (currMax < currMaxLocal) { - currMax = currMaxLocal; - maxIdx = maxIdxLocal; - } - } - } - } - output->p(maxIdx, 1); +static void ismax_(const NDArray* input, NDArray* output, + const std::vector& dimensions) { + if (input->isVector()) { + int dimensionsLength = dimensions.size(); + int length = input->lengthOf(); + if (!dimensions.empty() && (input->shapeOf())[dimensions[0]] == 1) { + for (int i = 0; i < length; i++) output->p(i, 1); + } else { + int eleStride = shape::elementWiseStride(input->shapeInfo()); + if (eleStride == 1) { + int maxIdx = 0; + auto currMax = input->e(0); + if (length < ELEMENT_THRESHOLD) { + for (int i = 0; i < length; i++) { + if (currMax < input->e(i)) { + currMax = input->e(i); + maxIdx = i; } - else { - int maxIdx = 0; - auto currMax = input->e(0); - if (length < ELEMENT_THRESHOLD) { - - for (int i = 0; i < length; i++) { - if (currMax < input->e(i)) { - currMax = input->e(i); - maxIdx = i; - } - output->p(i, 0.f); - } - } - else { - - { - int maxIdxLocal = maxIdx; - auto currMaxLocal = currMax; - for (int i = 0; i < length; i++) { - if (currMaxLocal < input->e(i)) { - currMaxLocal = input->e(i); - maxIdxLocal = i; - } - output->p(i, 0.f); - } - - PRAGMA_OMP_CRITICAL - { - if (currMax < currMaxLocal) { - currMax = currMaxLocal; - maxIdx = maxIdxLocal; - } - } - } - } - output->p(maxIdx, 1); + output->p(i, 0); + } + } else { + { + int maxIdxLocal = maxIdx; + auto currMaxLocal = currMax; + + for (int i = 0; i < length; i++) { + if (currMaxLocal < input->e(i)) { + currMaxLocal = input->e(i); + maxIdxLocal = i; + } + output->p(i, 0); } + + PRAGMA_OMP_CRITICAL { + if (currMax < currMaxLocal) { + currMax = currMaxLocal; + maxIdx = maxIdxLocal; + } + } + } } - } - else { - int dimensionsLength = dimensions.size(); - //int tads = tad.numTads; - //decompose in to several sub tads after - //moving all dimensions (in sorted order) - //to the back. - //permuted version of the input shape info for setting up the tad problem - auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(input->shapeInfo(), const_cast(dimensions.data()), dimensionsLength); - auto tadPackZ = sd::ConstantTadHelper::getInstance().tadForDimensions(output->shapeInfo(), const_cast(dimensions.data()), dimensionsLength); - - - auto tadShapeShapeInfo = tadPack.primaryShapeInfo(); - auto tadOffsets = tadPack.primaryOffsets(); - auto zOfsets = tadPackZ.platformOffsets(); - - int tadLength = shape::length(tadShapeShapeInfo); - int tads = tadPack.numberOfTads(); - - int tadsPerThread = tads / TAD_THRESHOLD; - int num_threads = sd::math::nd4j_max(1, tadsPerThread); - num_threads = sd::math::nd4j_min(num_threads, omp_get_max_threads()); - - auto tadEWS = shape::elementWiseStride(tadShapeShapeInfo); - auto zEWS = shape::elementWiseStride(tadPackZ.primaryShapeInfo()); - - int span = (tads / num_threads) + 8; - - auto func = PRAGMA_THREADS_FOR { - for (auto r = start; r < stop; r++) { - auto rX = const_cast(input)->bufferAsT() + tadOffsets[r]; - auto rZ = output->bufferAsT() + zOfsets[r]; - - auto maxValue = rX[0]; - int maxIdx = 0; - if (tadEWS == 1 && zEWS == 1) { - for (int i = 0; i < tadLength; i++) { - if (rX[i] > maxValue) { - maxIdx = i; - maxValue = rX[i]; - } - } - - PRAGMA_OMP_SIMD - for (int i = 0; i < tadLength; i++) { - rZ[i] = maxIdx == i ? (Z) 1 : (Z) 0; - } - } - else if (tadEWS > 1 && zEWS > 1) { - for (int i = 0; i < tadLength; i++) { - if (rX[i * tadEWS] > maxValue) { - maxIdx = i; - maxValue = rX[i * tadEWS]; - } - } - - PRAGMA_OMP_SIMD - for (int i = 0; i < tadLength; i++) { - rZ[i * zEWS] = maxIdx == i ? (Z) 1 : (Z) 0; - } - } else { - for (int i = 0; i < tadLength; i++) { - auto xOffset = shape::getIndexOffset(i, tadShapeShapeInfo); - if (rX[xOffset] > maxValue) { - maxIdx = i; - maxValue = rX[xOffset]; - } - } - - PRAGMA_OMP_SIMD - for (int i = 0; i < tadLength; i++) { - auto zOffset = shape::getIndexOffset(i, tadPackZ.primaryShapeInfo()); - rZ[zOffset] = maxIdx == i ? (Z) 1 : (Z) 0; - } - } + output->p(maxIdx, 1); + } else { + int maxIdx = 0; + auto currMax = input->e(0); + if (length < ELEMENT_THRESHOLD) { + for (int i = 0; i < length; i++) { + if (currMax < input->e(i)) { + currMax = input->e(i); + maxIdx = i; + } + output->p(i, 0.f); + } + } else { + { + int maxIdxLocal = maxIdx; + auto currMaxLocal = currMax; + for (int i = 0; i < length; i++) { + if (currMaxLocal < input->e(i)) { + currMaxLocal = input->e(i); + maxIdxLocal = i; + } + output->p(i, 0.f); } - }; - samediff::Threads::parallel_tad(func, 0, tads); + PRAGMA_OMP_CRITICAL { + if (currMax < currMaxLocal) { + currMax = currMaxLocal; + maxIdx = maxIdxLocal; + } + } + } + } + output->p(maxIdx, 1); + } } -} - + } else { + int dimensionsLength = dimensions.size(); + // int tads = tad.numTads; + // decompose in to several sub tads after + // moving all dimensions (in sorted order) + // to the back. + // permuted version of the input shape info for setting up the tad problem + auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions( + input->shapeInfo(), const_cast(dimensions.data()), + dimensionsLength); + auto tadPackZ = sd::ConstantTadHelper::getInstance().tadForDimensions( + output->shapeInfo(), const_cast(dimensions.data()), + dimensionsLength); + + auto tadShapeShapeInfo = tadPack.primaryShapeInfo(); + auto tadOffsets = tadPack.primaryOffsets(); + auto zOfsets = tadPackZ.platformOffsets(); + + int tadLength = shape::length(tadShapeShapeInfo); + int tads = tadPack.numberOfTads(); + + int tadsPerThread = tads / TAD_THRESHOLD; + int num_threads = sd::math::nd4j_max(1, tadsPerThread); + num_threads = sd::math::nd4j_min(num_threads, omp_get_max_threads()); + + auto tadEWS = shape::elementWiseStride(tadShapeShapeInfo); + auto zEWS = shape::elementWiseStride(tadPackZ.primaryShapeInfo()); + + int span = (tads / num_threads) + 8; + + auto func = PRAGMA_THREADS_FOR { + for (auto r = start; r < stop; r++) { + auto rX = const_cast(input)->bufferAsT() + tadOffsets[r]; + auto rZ = output->bufferAsT() + zOfsets[r]; + + auto maxValue = rX[0]; + int maxIdx = 0; + if (tadEWS == 1 && zEWS == 1) { + for (int i = 0; i < tadLength; i++) { + if (rX[i] > maxValue) { + maxIdx = i; + maxValue = rX[i]; + } + } + + PRAGMA_OMP_SIMD + for (int i = 0; i < tadLength; i++) { + rZ[i] = maxIdx == i ? (Z)1 : (Z)0; + } + } else if (tadEWS > 1 && zEWS > 1) { + for (int i = 0; i < tadLength; i++) { + if (rX[i * tadEWS] > maxValue) { + maxIdx = i; + maxValue = rX[i * tadEWS]; + } + } + + PRAGMA_OMP_SIMD + for (int i = 0; i < tadLength; i++) { + rZ[i * zEWS] = maxIdx == i ? (Z)1 : (Z)0; + } + } else { + for (int i = 0; i < tadLength; i++) { + auto xOffset = shape::getIndexOffset(i, tadShapeShapeInfo); + if (rX[xOffset] > maxValue) { + maxIdx = i; + maxValue = rX[xOffset]; + } + } + + PRAGMA_OMP_SIMD + for (int i = 0; i < tadLength; i++) { + auto zOffset = + shape::getIndexOffset(i, tadPackZ.primaryShapeInfo()); + rZ[zOffset] = maxIdx == i ? (Z)1 : (Z)0; + } + } + } + }; -void ismax(sd::LaunchContext * context, const NDArray *input, NDArray *output, const std::vector& dimensions) { - BUILD_DOUBLE_SELECTOR(input->dataType(), output->dataType(), ismax_, (input, output, dimensions), LIBND4J_TYPES, LIBND4J_TYPES); + samediff::Threads::parallel_tad(func, 0, tads); + } } - -} -} +void ismax(sd::LaunchContext* context, const NDArray* input, NDArray* output, + const std::vector& dimensions) { + BUILD_DOUBLE_SELECTOR(input->dataType(), output->dataType(), ismax_, + (input, output, dimensions), LIBND4J_TYPES, + LIBND4J_TYPES); } +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/legacy_helper.cpp b/libnd4j/include/ops/declarable/helpers/cpu/legacy_helper.cpp index b2a0e537f70d..45dc14c3eaf5 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/legacy_helper.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/legacy_helper.cpp @@ -18,369 +18,425 @@ // @author GS // -#include #include +#include #include namespace sd { namespace ops { namespace helpers { - template - static void reluDerivative__(NDArray* theFirst, NDArray* theSecond) { - auto functor = LAMBDA_TT(x, y){ - return x > (T) 0.f ? y : T(0.f); - }; +template +static void reluDerivative__(NDArray* theFirst, NDArray* theSecond) { + auto functor = LAMBDA_TT(x, y) { return x > (T)0.f ? y : T(0.f); }; + + theFirst->applyPairwiseLambda(*theSecond, functor, *theFirst); +} + +void reluDerivative(sd::LaunchContext* context, NDArray* theFirst, + NDArray* theSecond) { + BUILD_SINGLE_SELECTOR(theFirst->dataType(), reluDerivative__, + (theFirst, theSecond), FLOAT_TYPES); +} + +template +static void reluDerivative_(NDArray* input, NDArray* epsilon, NDArray* output) { + T zero = (T)0.f; + auto functor = LAMBDA_TT(x, y, zero) { return x > zero ? y : zero; }; + + input->applyPairwiseLambda(*epsilon, functor, *output); + + /* + auto x = input->bufferAsT(); + auto y = epsilon->bufferAsT(); + auto z = output->bufferAsT(); + + int length = input->lengthOf(); + + T zero = (T) 0.f; + + PRAGMA_OMP_PARALLEL_FOR + for (int e = 0; e < length; e++) { + z[e] = x[e] > zero ? y[e] : zero; + } + */ +} + +void reluDerivative(sd::LaunchContext* context, NDArray* theFirst, + NDArray* theSecond, NDArray* theOutput) { + BUILD_SINGLE_SELECTOR(theFirst->dataType(), reluDerivative_, + (theFirst, theSecond, theOutput), FLOAT_TYPES); +} + +template +static void relu6Derivative_(NDArray* input, NDArray* epsilon, + NDArray* output) { + auto functor = LAMBDA_TT(x, y) { + return x > (T)0.f && x < (T)6.f ? y : T(0.f); + }; + + input->applyPairwiseLambda(*epsilon, functor, *output); +} + +void relu6Derivative(sd::LaunchContext* context, NDArray* theFirst, + NDArray* theSecond, NDArray* theOutput) { + BUILD_SINGLE_SELECTOR(theFirst->dataType(), relu6Derivative_, + (theFirst, theSecond, theOutput), FLOAT_TYPES); +} + +template +static void leakyReluDerivative_(NDArray* input, NDArray* epsilon, + NDArray* output, const float alpha) { + const T alphaT = static_cast(alpha); + + auto functor = LAMBDA_TT(x, y, alphaT) { return x < 0 ? alphaT * y : y; }; + + input->applyPairwiseLambda(*epsilon, functor, *output); +} + +void leakyReluDerivative(sd::LaunchContext* context, NDArray* theFirst, + NDArray* theSecond, NDArray* theOutput, + const float alpha) { + BUILD_SINGLE_SELECTOR(theFirst->dataType(), leakyReluDerivative_, + (theFirst, theSecond, theOutput, alpha), FLOAT_TYPES); +} + +template +static void eluDerivative_(NDArray* input, NDArray* epsilon, NDArray* output, + const float alpha) { + const T alphaT = static_cast(alpha); + + auto functor = LAMBDA_TT(x, y, alphaT) { + return y * sd::math::nd4j_eluderivative(x, alphaT); + }; + + input->applyPairwiseLambda(*epsilon, functor, *output); +} + +void eluDerivative(sd::LaunchContext* context, NDArray* theFirst, + NDArray* theSecond, NDArray* theOutput, const float alpha) { + BUILD_SINGLE_SELECTOR(theFirst->dataType(), eluDerivative_, + (theFirst, theSecond, theOutput, alpha), FLOAT_TYPES); +} + +template +static void seluDerivative_(NDArray* input, NDArray* epsilon, NDArray* output) { + auto functor = LAMBDA_TT(x, y) { + return y * simdOps::SELUDerivative::op(x, nullptr); + }; + + input->applyPairwiseLambda(*epsilon, functor, *output); +} + +void seluDerivative(sd::LaunchContext* context, NDArray* theFirst, + NDArray* theSecond, NDArray* theOutput) { + BUILD_SINGLE_SELECTOR(theFirst->dataType(), seluDerivative_, + (theFirst, theSecond, theOutput), FLOAT_TYPES); +} + +template +static void cubeDerivative_(NDArray* input, NDArray* epsilon, NDArray* output) { + auto functor = LAMBDA_TT(x, y) { return y * (3 * x * x); }; + + input->applyPairwiseLambda(*epsilon, functor, *output); +} + +void cubeDerivative(sd::LaunchContext* context, NDArray* theFirst, + NDArray* theSecond, NDArray* theOutput) { + BUILD_SINGLE_SELECTOR(theFirst->dataType(), cubeDerivative_, + (theFirst, theSecond, theOutput), FLOAT_TYPES); +} + +// return (x >= X(0.f) ? y: -y); +template +static void reduceNorm1_(NDArray* input, NDArray* epsilon, NDArray* output) { + auto functor = LAMBDA_TT(x, y) { return x > T(0.f) ? y : -y; }; + + input->applyPairwiseLambda(*epsilon, functor, *output); +} + +void reduceNorm1(sd::LaunchContext* context, NDArray* theFirst, + NDArray* theSecond, NDArray* theOutput) { + BUILD_SINGLE_SELECTOR(theFirst->dataType(), reduceNorm1_, + (theFirst, theSecond, theOutput), FLOAT_TYPES); +} + +//////////////////////////////////////////////////////////////////////// +template +static void sigmCrossEntropy_(NDArray* logits, NDArray* labels, + NDArray* output) { + auto functor = LAMBDA_TT(x, y) { + return sd::math::nd4j_max(x, (T)0.f) - x * y + + sd::math::nd4j_log( + (T)1.f + sd::math::nd4j_exp(-sd::math::nd4j_abs(x))); + }; + + logits->applyPairwiseLambda(*labels, functor, *output); +} + +void sigmCrossEntropy(sd::LaunchContext* context, NDArray* logits, + NDArray* labels, NDArray* output) { + BUILD_SINGLE_SELECTOR(logits->dataType(), sigmCrossEntropy_, + (logits, labels, output), FLOAT_TYPES); +} + +//////////////////////////////////////////////////////////////////////// +template +static void sigmCrossEntropyGrad_(NDArray* logits, NDArray* labels, + NDArray* output) { + // 1 - labels - 1 / (1 + exp(logits)) + auto functor = LAMBDA_TT(x, y) { + if (x <= 0) + return static_cast(1.) - y - + static_cast(1.) / + (static_cast(1.) + sd::math::nd4j_exp(x)); + auto e = sd::math::nd4j_exp(-x); + return static_cast(1.) - y - e / (static_cast(1.) + e); + }; + + logits->applyPairwiseLambda(*labels, functor, *output); +} + +void sigmCrossEntropyGrad(sd::LaunchContext* context, NDArray* logits, + NDArray* labels, NDArray* output) { + BUILD_SINGLE_SELECTOR(logits->dataType(), sigmCrossEntropyGrad_, + (logits, labels, output), FLOAT_TYPES); +} - theFirst->applyPairwiseLambda(*theSecond, functor, *theFirst); - } +//////////////////////////////////////////////////////////////////////// +template +static void tanhDerivative_(NDArray* input, NDArray* epsilon, NDArray* output) { + auto functor = LAMBDA_TT(x, y) { + T th = sd::math::nd4j_tanh(x); + return y * ((T)1.0f - (th * th)); + }; + + input->applyPairwiseLambda(*epsilon, functor, *output); +} - void reluDerivative(sd::LaunchContext * context, NDArray* theFirst, NDArray* theSecond) { - BUILD_SINGLE_SELECTOR(theFirst->dataType(), reluDerivative__, (theFirst, theSecond), FLOAT_TYPES); - } +void tanhDerivative(sd::LaunchContext* context, NDArray* theFirst, + NDArray* theSecond, NDArray* theOutput) { + BUILD_SINGLE_SELECTOR(theFirst->dataType(), tanhDerivative_, + (theFirst, theSecond, theOutput), FLOAT_TYPES); +} + +// return static_cast(d2) * simdOps::HardTanhDerivative::op(d1, nullptr); +template +static void hardTanhDerivative_(NDArray* input, NDArray* epsilon, + NDArray* output) { + auto functor = LAMBDA_TT(x, y) { + T th = sd::math::nd4j_tanh(x); + return y * simdOps::HardTanhDerivative::op(x, nullptr); + }; + + input->applyPairwiseLambda(*epsilon, functor, *output); +} + +void hardTanhDerivative(sd::LaunchContext* context, NDArray* theFirst, + NDArray* theSecond, NDArray* theOutput) { + BUILD_SINGLE_SELECTOR(theFirst->dataType(), hardTanhDerivative_, + (theFirst, theSecond, theOutput), FLOAT_TYPES); +} + +template +static void rationalTanhDerivative_(NDArray* input, NDArray* epsilon, + NDArray* output) { + auto functor = LAMBDA_TT(x, y) { + return y * simdOps::RationalTanhDerivative::op(x, nullptr); + }; + + input->applyPairwiseLambda(*epsilon, functor, *output); +} - template - static void reluDerivative_(NDArray* input, NDArray* epsilon, NDArray* output) { +void rationalTanhDerivative(sd::LaunchContext* context, NDArray* theFirst, + NDArray* theSecond, NDArray* theOutput) { + BUILD_SINGLE_SELECTOR(theFirst->dataType(), rationalTanhDerivative_, + (theFirst, theSecond, theOutput), FLOAT_TYPES); +} - T zero = (T) 0.f; - auto functor = LAMBDA_TT(x, y, zero){ - return x > zero ? y : zero; - }; +template +static void rectifiedTanhDerivative_(NDArray* input, NDArray* epsilon, + NDArray* output) { + auto functor = LAMBDA_TT(x, y) { + return x > (T)0.0f ? y * (sd::math::nd4j_tanhderivative(x)) : (T)0.0f; + }; - input->applyPairwiseLambda(*epsilon, functor, *output); + input->applyPairwiseLambda(*epsilon, functor, *output); +} - /* - auto x = input->bufferAsT(); - auto y = epsilon->bufferAsT(); - auto z = output->bufferAsT(); +void rectifiedTanhDerivative(sd::LaunchContext* context, NDArray* theFirst, + NDArray* theSecond, NDArray* theOutput) { + BUILD_SINGLE_SELECTOR(theFirst->dataType(), rectifiedTanhDerivative_, + (theFirst, theSecond, theOutput), FLOAT_TYPES); +} - int length = input->lengthOf(); +// X f = (X) 1.0f + sd::math::nd4j_abs(d1); +// return (X) d2 * ((X) 1.0f / (f * f)); - T zero = (T) 0.f; +template +static void softSignDerivative_(NDArray* input, NDArray* epsilon, + NDArray* output) { + auto functor = LAMBDA_TT(x, y) { + T ss = (T)1.f + sd::math::nd4j_abs(x); + return y * ((T)1.0f / (ss * ss)); + }; + + input->applyPairwiseLambda(*epsilon, functor, *output); +} - PRAGMA_OMP_PARALLEL_FOR - for (int e = 0; e < length; e++) { - z[e] = x[e] > zero ? y[e] : zero; - } - */ - } +void softSignDerivative(sd::LaunchContext* context, NDArray* theFirst, + NDArray* theSecond, NDArray* theOutput) { + BUILD_SINGLE_SELECTOR(theFirst->dataType(), softSignDerivative_, + (theFirst, theSecond, theOutput), FLOAT_TYPES); +} - void reluDerivative(sd::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { - BUILD_SINGLE_SELECTOR(theFirst->dataType(), reluDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); - } +template +static void softPlusDerivative_(NDArray* input, NDArray* epsilon, + NDArray* output) { + auto functor = LAMBDA_TT(x, y) { + T p = sd::math::nd4j_pow(static_cast(M_E), x); + return y * (p / (p + 1.)); + }; + + input->applyPairwiseLambda(*epsilon, functor, *output); +} - template - static void relu6Derivative_(NDArray* input, NDArray* epsilon, NDArray* output) { - auto functor = LAMBDA_TT(x, y){ - return x > (T)0.f && x < (T)6.f? y : T(0.f); - }; - - input->applyPairwiseLambda(*epsilon, functor, *output); - } - - void relu6Derivative(sd::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { - BUILD_SINGLE_SELECTOR(theFirst->dataType(), relu6Derivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); - } - - template - static void leakyReluDerivative_(NDArray* input, NDArray* epsilon, NDArray* output, const float alpha) { - - const T alphaT = static_cast(alpha); - - auto functor = LAMBDA_TT(x, y, alphaT) { - return x < 0 ? alphaT * y : y; - }; - - input->applyPairwiseLambda(*epsilon, functor, *output); - } - - void leakyReluDerivative(sd::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput, const float alpha) { - BUILD_SINGLE_SELECTOR(theFirst->dataType(), leakyReluDerivative_, (theFirst, theSecond, theOutput, alpha), FLOAT_TYPES); - } - - template - static void eluDerivative_(NDArray* input, NDArray* epsilon, NDArray* output, const float alpha) { - - const T alphaT = static_cast(alpha); - - auto functor = LAMBDA_TT(x, y, alphaT){ - return y * sd::math::nd4j_eluderivative(x, alphaT); - }; - - input->applyPairwiseLambda(*epsilon, functor, *output); - } - - void eluDerivative(sd::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput, const float alpha) { - BUILD_SINGLE_SELECTOR(theFirst->dataType(), eluDerivative_, (theFirst, theSecond, theOutput, alpha), FLOAT_TYPES); - } - - template - static void seluDerivative_(NDArray* input, NDArray* epsilon, NDArray* output) { - auto functor = LAMBDA_TT(x, y){ - return y * simdOps::SELUDerivative::op(x, nullptr); - }; - - input->applyPairwiseLambda(*epsilon, functor, *output); - } - - void seluDerivative(sd::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { - BUILD_SINGLE_SELECTOR(theFirst->dataType(), seluDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); - } - - template - static void cubeDerivative_(NDArray* input, NDArray* epsilon, NDArray* output) { - auto functor = LAMBDA_TT(x, y){ - return y * (3 * x * x); - }; - - input->applyPairwiseLambda(*epsilon, functor, *output); - } - - void cubeDerivative(sd::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { - BUILD_SINGLE_SELECTOR(theFirst->dataType(), cubeDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); - } - - //return (x >= X(0.f) ? y: -y); - template - static void reduceNorm1_(NDArray* input, NDArray* epsilon, NDArray* output) { - auto functor = LAMBDA_TT(x, y){ - return x > T(0.f)? y : -y; - }; - - input->applyPairwiseLambda(*epsilon, functor, *output); - } - - void reduceNorm1(sd::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { - BUILD_SINGLE_SELECTOR(theFirst->dataType(), reduceNorm1_, (theFirst, theSecond, theOutput), FLOAT_TYPES); - } - - //////////////////////////////////////////////////////////////////////// - template - static void sigmCrossEntropy_(NDArray* logits, NDArray* labels, NDArray* output) { - auto functor = LAMBDA_TT(x, y){ - return sd::math::nd4j_max(x, (T)0.f) - x * y + sd::math::nd4j_log((T)1.f + sd::math::nd4j_exp(-sd::math::nd4j_abs(x))); - }; - - logits->applyPairwiseLambda(*labels, functor, *output); - } - - void sigmCrossEntropy(sd::LaunchContext * context, NDArray* logits, NDArray* labels, NDArray* output) { - BUILD_SINGLE_SELECTOR(logits->dataType(), sigmCrossEntropy_, (logits, labels, output), FLOAT_TYPES); - } - - //////////////////////////////////////////////////////////////////////// - template - static void sigmCrossEntropyGrad_(NDArray* logits, NDArray* labels, NDArray* output) { - // 1 - labels - 1 / (1 + exp(logits)) - auto functor = LAMBDA_TT(x, y) { - if(x <= 0) - return static_cast(1.) - y - static_cast(1.) / (static_cast(1.) + sd::math::nd4j_exp(x)); - auto e = sd::math::nd4j_exp(-x); - return static_cast(1.) - y - e / (static_cast(1.) + e); - }; - - logits->applyPairwiseLambda(*labels, functor, *output); - } - - void sigmCrossEntropyGrad(sd::LaunchContext * context, NDArray* logits, NDArray* labels, NDArray* output) { - BUILD_SINGLE_SELECTOR(logits->dataType(), sigmCrossEntropyGrad_, (logits, labels, output), FLOAT_TYPES); - } - - //////////////////////////////////////////////////////////////////////// - template - static void tanhDerivative_(NDArray* input, NDArray* epsilon, NDArray* output) { - auto functor = LAMBDA_TT(x, y){ - T th = sd::math::nd4j_tanh(x); - return y * ((T)1.0f - (th * th)); - }; - - input->applyPairwiseLambda(*epsilon, functor, *output); - } - - void tanhDerivative(sd::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { - BUILD_SINGLE_SELECTOR(theFirst->dataType(), tanhDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); - } - - // return static_cast(d2) * simdOps::HardTanhDerivative::op(d1, nullptr); - template - static void hardTanhDerivative_(NDArray* input, NDArray* epsilon, NDArray* output) { - auto functor = LAMBDA_TT(x, y){ - T th = sd::math::nd4j_tanh(x); - return y * simdOps::HardTanhDerivative::op(x, nullptr); - }; - - input->applyPairwiseLambda(*epsilon, functor, *output); - } - - void hardTanhDerivative(sd::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { - BUILD_SINGLE_SELECTOR(theFirst->dataType(), hardTanhDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); - } - - template - static void rationalTanhDerivative_(NDArray* input, NDArray* epsilon, NDArray* output) { - auto functor = LAMBDA_TT(x, y){ - return y * simdOps::RationalTanhDerivative::op(x, nullptr); - }; - - input->applyPairwiseLambda(*epsilon, functor, *output); - } - - void rationalTanhDerivative(sd::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { - BUILD_SINGLE_SELECTOR(theFirst->dataType(), rationalTanhDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); - } - - template - static void rectifiedTanhDerivative_(NDArray* input, NDArray* epsilon, NDArray* output) { - auto functor = LAMBDA_TT(x, y){ - return x > (T) 0.0f ? y * (sd::math::nd4j_tanhderivative(x)) : (T) 0.0f; - }; - - input->applyPairwiseLambda(*epsilon, functor, *output); - } - - void rectifiedTanhDerivative(sd::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { - BUILD_SINGLE_SELECTOR(theFirst->dataType(), rectifiedTanhDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); - } - - // X f = (X) 1.0f + sd::math::nd4j_abs(d1); - // return (X) d2 * ((X) 1.0f / (f * f)); - - template - static void softSignDerivative_(NDArray* input, NDArray* epsilon, NDArray* output) { - auto functor = LAMBDA_TT(x, y){ - T ss = (T)1.f + sd::math::nd4j_abs(x); - return y * ((T) 1.0f / (ss * ss)); - }; - - input->applyPairwiseLambda(*epsilon, functor, *output); - } - - void softSignDerivative(sd::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { - BUILD_SINGLE_SELECTOR(theFirst->dataType(), softSignDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); - } - - template - static void softPlusDerivative_(NDArray* input, NDArray* epsilon, NDArray* output) { - auto functor = LAMBDA_TT(x, y){ - T p = sd::math::nd4j_pow(static_cast(M_E), x); - return y * (p / (p + 1.)); - }; - - input->applyPairwiseLambda(*epsilon, functor, *output); - } - - void softPlusDerivative(sd::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { - BUILD_SINGLE_SELECTOR(theFirst->dataType(), softPlusDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); - } +void softPlusDerivative(sd::LaunchContext* context, NDArray* theFirst, + NDArray* theSecond, NDArray* theOutput) { + BUILD_SINGLE_SELECTOR(theFirst->dataType(), softPlusDerivative_, + (theFirst, theSecond, theOutput), FLOAT_TYPES); +} /// /// \param theFirst /// \param theSecond /// \param theOutput - template - static void sigmoidDerivative_(NDArray* input, NDArray* epsilon, NDArray* output) { - auto functor = LAMBDA_TT(x, y){ - T s = sd::math::nd4j_sigmoid(x); - return y * (s * ((T) 1.0f - s)); - }; - - input->applyPairwiseLambda(*epsilon, functor, *output); - } - - void sigmoidDerivative(sd::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { - BUILD_SINGLE_SELECTOR(theFirst->dataType(), sigmoidDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); - } - - template - static void hardSigmoidDerivative_(NDArray* input, NDArray* epsilon, NDArray* output) { - auto functor = LAMBDA_TT(x, y){ - return y * simdOps::HardSigmoidDerivative::op(x, nullptr); - }; - - input->applyPairwiseLambda(*epsilon, functor, *output); - } - - void hardSigmoidDerivative(sd::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { - BUILD_SINGLE_SELECTOR(theFirst->dataType(), hardSigmoidDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); - } - - template - static void logSumExp_(NDArray* input, NDArray* axis, NDArray* output) { - // reduce along axis with - NDArray tempInput = input->dup(); - input->applyTransform(transform::Exp, tempInput); - std::vector axisVector; - if (axis != nullptr) { - axisVector.resize(axis->lengthOf()); - for (size_t i = 0; i < axisVector.size(); ++i) - axisVector[i] = axis->e(i); - } - tempInput.reduceAlongDimension(reduce::Sum, *output, axisVector); - output->applyTransform(transform::Log, *output); - } - - template - static void logSumExp_(NDArray* input, NDArray* subtrah, NDArray* axis, NDArray* output) { - // reduce along axis with - NDArray tempInput = input->dup(); - input->applyPairwiseTransform(pairwise::Subtract, *subtrah, tempInput); - tempInput.applyTransform(transform::Exp, tempInput); - - std::vector axisVector; - if (axis != nullptr) { - axisVector.resize(axis->lengthOf()); - for (size_t i = 0; i < axisVector.size(); ++i) - axisVector[i] = axis->e(i); - } - tempInput.reduceAlongDimension(reduce::Sum, *output, axisVector); - output->applyTransform(transform::Log, *output); - } - - void logSumExp(sd::LaunchContext * context, NDArray* input, NDArray* axis, NDArray* output) { - BUILD_SINGLE_SELECTOR(input->dataType(), logSumExp_, (input, axis, output), FLOAT_TYPES); - } - - void logSumExp(sd::LaunchContext * context, NDArray* input, NDArray* subtrah, NDArray* axis, NDArray* output) { - BUILD_SINGLE_SELECTOR(input->dataType(), logSumExp_, (input, subtrah, axis, output), FLOAT_TYPES); - } - -////////////////////////////////////////////////////////////////////////// template -static void weightedCrossEntropyWithLogitsFunctor_(NDArray const* targets, NDArray const* input, NDArray const* weights, NDArray* output) { +static void sigmoidDerivative_(NDArray* input, NDArray* epsilon, + NDArray* output) { + auto functor = LAMBDA_TT(x, y) { + T s = sd::math::nd4j_sigmoid(x); + return y * (s * ((T)1.0f - s)); + }; + + input->applyPairwiseLambda(*epsilon, functor, *output); +} - T posWeight = weights->e(0); +void sigmoidDerivative(sd::LaunchContext* context, NDArray* theFirst, + NDArray* theSecond, NDArray* theOutput) { + BUILD_SINGLE_SELECTOR(theFirst->dataType(), sigmoidDerivative_, + (theFirst, theSecond, theOutput), FLOAT_TYPES); +} - auto mainRoutineT1 = LAMBDA_TT(_x, _z, posWeight) { - T targetWeight = (1. + (posWeight - (T)1.f) * _z); - return (1. - _z) * _x + - targetWeight * (sd::math::nd4j_log((T)1.f + sd::math::nd4j_exp(-sd::math::nd4j_abs(_x))) + - sd::math::nd4j_max(-_x, T(0.f)) - ); - }; +template +static void hardSigmoidDerivative_(NDArray* input, NDArray* epsilon, + NDArray* output) { + auto functor = LAMBDA_TT(x, y) { + return y * simdOps::HardSigmoidDerivative::op(x, nullptr); + }; - auto mainRoutineT2 = LAMBDA_TTT(_x, _z, _w) { - return (((T)1.0 - _z) * _x) + - _w * (sd::math::nd4j_log(T(1.) + sd::math::nd4j_exp(-sd::math::nd4j_abs(_x))) + - sd::math::nd4j_max(-_x, T(0.f))); - }; + input->applyPairwiseLambda(*epsilon, functor, *output); +} +void hardSigmoidDerivative(sd::LaunchContext* context, NDArray* theFirst, + NDArray* theSecond, NDArray* theOutput) { + BUILD_SINGLE_SELECTOR(theFirst->dataType(), hardSigmoidDerivative_, + (theFirst, theSecond, theOutput), FLOAT_TYPES); +} - if (weights->isScalar()) { - const_cast(input)->applyPairwiseLambda(const_cast(*targets), mainRoutineT1, *output); - } - else - { - std::unique_ptr targetVector(new NDArray(*weights)); - targetVector->applyScalar(scalar::Add, -1.f, *targetVector); +template +static void logSumExp_(NDArray* input, NDArray* axis, NDArray* output) { + // reduce along axis with + NDArray tempInput = input->dup(); + input->applyTransform(transform::Exp, tempInput); + std::vector axisVector; + if (axis != nullptr) { + axisVector.resize(axis->lengthOf()); + for (size_t i = 0; i < axisVector.size(); ++i) + axisVector[i] = axis->e(i); + } + tempInput.reduceAlongDimension(reduce::Sum, *output, axisVector); + output->applyTransform(transform::Log, *output); +} - std::unique_ptr targetTensor(new NDArray(*targets)); - *targetTensor = (*targetVector * *targetTensor) + T(1.f); - const_cast(input)->applyTriplewiseLambda(const_cast(*targets), *targetTensor.get(), mainRoutineT2, *output); - } +template +static void logSumExp_(NDArray* input, NDArray* subtrah, NDArray* axis, + NDArray* output) { + // reduce along axis with + NDArray tempInput = input->dup(); + input->applyPairwiseTransform(pairwise::Subtract, *subtrah, tempInput); + tempInput.applyTransform(transform::Exp, tempInput); + + std::vector axisVector; + if (axis != nullptr) { + axisVector.resize(axis->lengthOf()); + for (size_t i = 0; i < axisVector.size(); ++i) + axisVector[i] = axis->e(i); + } + tempInput.reduceAlongDimension(reduce::Sum, *output, axisVector); + output->applyTransform(transform::Log, *output); } -void weightedCrossEntropyWithLogitsFunctor(sd::LaunchContext * context, NDArray const* targets, NDArray const* input, NDArray const* weights, NDArray* output) { - BUILD_SINGLE_SELECTOR(targets->dataType(), weightedCrossEntropyWithLogitsFunctor_, (targets, input, weights, output), FLOAT_TYPES); +void logSumExp(sd::LaunchContext* context, NDArray* input, NDArray* axis, + NDArray* output) { + BUILD_SINGLE_SELECTOR(input->dataType(), logSumExp_, (input, axis, output), + FLOAT_TYPES); } +void logSumExp(sd::LaunchContext* context, NDArray* input, NDArray* subtrah, + NDArray* axis, NDArray* output) { + BUILD_SINGLE_SELECTOR(input->dataType(), logSumExp_, + (input, subtrah, axis, output), FLOAT_TYPES); } + +////////////////////////////////////////////////////////////////////////// +template +static void weightedCrossEntropyWithLogitsFunctor_(NDArray const* targets, + NDArray const* input, + NDArray const* weights, + NDArray* output) { + T posWeight = weights->e(0); + + auto mainRoutineT1 = LAMBDA_TT(_x, _z, posWeight) { + T targetWeight = (1. + (posWeight - (T)1.f) * _z); + return (1. - _z) * _x + + targetWeight * (sd::math::nd4j_log( + (T)1.f + sd::math::nd4j_exp( + -sd::math::nd4j_abs(_x))) + + sd::math::nd4j_max(-_x, T(0.f))); + }; + + auto mainRoutineT2 = LAMBDA_TTT(_x, _z, _w) { + return (((T)1.0 - _z) * _x) + + _w * + (sd::math::nd4j_log( + T(1.) + sd::math::nd4j_exp(-sd::math::nd4j_abs(_x))) + + sd::math::nd4j_max(-_x, T(0.f))); + }; + + if (weights->isScalar()) { + const_cast(input)->applyPairwiseLambda( + const_cast(*targets), mainRoutineT1, *output); + } else { + std::unique_ptr targetVector(new NDArray(*weights)); + targetVector->applyScalar(scalar::Add, -1.f, *targetVector); + + std::unique_ptr targetTensor(new NDArray(*targets)); + *targetTensor = (*targetVector * *targetTensor) + T(1.f); + const_cast(input)->applyTriplewiseLambda( + const_cast(*targets), *targetTensor.get(), mainRoutineT2, + *output); + } } -} \ No newline at end of file + +void weightedCrossEntropyWithLogitsFunctor(sd::LaunchContext* context, + NDArray const* targets, + NDArray const* input, + NDArray const* weights, + NDArray* output) { + BUILD_SINGLE_SELECTOR(targets->dataType(), + weightedCrossEntropyWithLogitsFunctor_, + (targets, input, weights, output), FLOAT_TYPES); +} + +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/lgamma.cpp b/libnd4j/include/ops/declarable/helpers/cpu/lgamma.cpp index 3b71f7ce9742..b6d6b66b2a0f 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/lgamma.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/lgamma.cpp @@ -19,8 +19,8 @@ // @author George A. Shulinok // -#include #include +#include namespace sd { namespace ops { @@ -30,24 +30,23 @@ namespace helpers { // calculate digamma function for array elements template static void lgamma_(NDArray& x, NDArray& z) { - - auto lgammaProc = LAMBDA_T(x_) { - return T(DataTypeUtils::fromT() == DataType::DOUBLE?::lgamma(x_): ::lgammaf(x_)); //math::nd4j_log(math::nd4j_gamma(x)); - }; - - x.applyLambda(lgammaProc, z); + auto lgammaProc = LAMBDA_T(x_) { + return T( + DataTypeUtils::fromT() == DataType::DOUBLE + ? ::lgamma(x_) + : ::lgammaf(x_)); // math::nd4j_log(math::nd4j_gamma(x)); + }; + + x.applyLambda(lgammaProc, z); } void lgamma(sd::LaunchContext* context, NDArray& x, NDArray& z) { - - BUILD_SINGLE_SELECTOR(x.dataType(), lgamma_, (x, z), FLOAT_TYPES); + BUILD_SINGLE_SELECTOR(x.dataType(), lgamma_, (x, z), FLOAT_TYPES); } -BUILD_SINGLE_TEMPLATE(template void lgamma_, (NDArray& x, NDArray& z), FLOAT_TYPES); - - - -} -} -} +BUILD_SINGLE_TEMPLATE(template void lgamma_, (NDArray & x, NDArray& z), + FLOAT_TYPES); +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/lrn.cpp b/libnd4j/include/ops/declarable/helpers/cpu/lrn.cpp index b49f8e61cd63..5a11a9647ec5 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/lrn.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/lrn.cpp @@ -19,314 +19,348 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include +#include #include #include -#include +#include namespace sd { namespace ops { namespace helpers { template -static int lrnFunctor_(sd::graph::Context& block, NDArray* input, NDArray* output, int depth, float bias, float alpha, float beta) { - - nd4j_debug("MKL-DNN is not used for lrn!\n", 0); - - const int rank = input->rankOf(); - - TadPack inTadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(input->shapeInfo(), {rank - 1}); - TadPack outTadPack; - - if(shape::haveSameShapeAndStrides(input->shapeInfo(), output->shapeInfo())) - outTadPack = inTadPack; - else - outTadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(output->shapeInfo(), {rank - 1}); - - const Nd4jLong numOfTads = inTadPack.numberOfTads(); - const Nd4jLong tadLen = input->sizeAt(-1); - - const Nd4jLong* inTadOffsets = inTadPack.primaryOffsets(); - const Nd4jLong* outTadOffsets = outTadPack.primaryOffsets(); - - const Nd4jLong inTadEws = shape::elementWiseStride(inTadPack.primaryShapeInfo()); - const Nd4jLong outTadEws = shape::elementWiseStride(outTadPack.primaryShapeInfo()); - - const T* inBuff = reinterpret_cast(input->buffer()); - T* outBuff = reinterpret_cast(output->buffer()); - - const T tbias = static_cast(bias); - const T tbeta = static_cast(beta); - const T talpha = static_cast(alpha); - - if(inTadEws == 1 && outTadEws == 1) { - - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - const T *x = inBuff + inTadOffsets[i]; - T *y = outBuff + outTadOffsets[i]; - - T prev = 0; - - // calculate squared sum of elements per each j-th element range [j - depth, j + depth + 1] - // we store each squared sum in corresponding element of y array - for (Nd4jLong j = 0; j < tadLen; ++j) { - const uint begin = sd::math::nd4j_max(0, j - depth); - const uint last = depth + j + 1; - const uint end = sd::math::nd4j_min(last, tadLen); - - if (j == 0) { - for (uint s = begin; s < end; ++s) - prev = prev + x[s] * x[s]; - y[j] = prev; - } else if (begin == 0 && last <= tadLen) - y[j] = prev + x[end - 1] * x[end - 1]; - else if (begin > 0 && last <= tadLen) - y[j] = prev + x[end - 1] * x[end - 1] - x[begin - 1] * x[begin - 1]; - else if (begin > 0 && last > tadLen) - y[j] = prev - x[begin - 1] * x[begin - 1]; - else - y[j] = prev; - - if (j != 0) - prev = y[j]; - - y[j] = x[j] / sd::math::nd4j_pow(tbias + alpha * prev, tbeta); - } - } - }; - - samediff::Threads::parallel_tad(func, 0, numOfTads); - } - else { - auto func = PRAGMA_THREADS_FOR { - for (Nd4jLong i = 0; i < numOfTads; ++i) { - const T *x = inBuff + inTadOffsets[i]; - T *y = outBuff + outTadOffsets[i]; - - T prev = 0; - - // calculate squared sum of elements per each j-th element range [j - depth, j + depth + 1] - // we store each squared sum in corresponding element of y array - for (Nd4jLong j = 0; j < tadLen; ++j) { - const uint begin = sd::math::nd4j_max(0, j - depth); - const uint last = depth + j + 1; - const uint end = sd::math::nd4j_min(last, tadLen); - - if (j == 0) { - for (uint s = begin; s < end; ++s) - prev = prev + x[s * inTadEws] * x[s * inTadEws]; - y[j * outTadEws] = prev; - } else if (begin == 0 && last <= tadLen) - y[j * outTadEws] = prev + x[(end - 1) * inTadEws] * x[(end - 1) * inTadEws]; - else if (begin > 0 && last <= tadLen) - y[j * outTadEws] = prev + x[(end - 1) * inTadEws] * x[(end - 1) * inTadEws] - x[(begin - 1) * inTadEws] * x[(begin - 1) * inTadEws]; - else if (begin > 0 && last > tadLen) - y[j * outTadEws] = prev - x[(begin - 1) * inTadEws] * x[(begin - 1) * inTadEws]; - else - y[j * outTadEws] = prev; - - if (j != 0) - prev = y[j * outTadEws]; - - y[j * outTadEws] = x[j * inTadEws] / sd::math::nd4j_pow(tbias + alpha * prev, tbeta); - } - } - }; - - samediff::Threads::parallel_tad(func, 0, numOfTads); - } - return Status::OK(); +static int lrnFunctor_(sd::graph::Context& block, NDArray* input, + NDArray* output, int depth, float bias, float alpha, + float beta) { + nd4j_debug("MKL-DNN is not used for lrn!\n", 0); + + const int rank = input->rankOf(); + + TadPack inTadPack = sd::ConstantTadHelper::getInstance().tadForDimensions( + input->shapeInfo(), rank - 1); + TadPack outTadPack; + + if (shape::haveSameShapeAndStrides(input->shapeInfo(), output->shapeInfo())) + outTadPack = inTadPack; + else + outTadPack = sd::ConstantTadHelper::getInstance().tadForDimensions( + output->shapeInfo(), rank - 1); + + const Nd4jLong numOfTads = inTadPack.numberOfTads(); + const Nd4jLong tadLen = input->sizeAt(-1); + + const Nd4jLong* inTadOffsets = inTadPack.primaryOffsets(); + const Nd4jLong* outTadOffsets = outTadPack.primaryOffsets(); + + const Nd4jLong inTadEws = + shape::elementWiseStride(inTadPack.primaryShapeInfo()); + const Nd4jLong outTadEws = + shape::elementWiseStride(outTadPack.primaryShapeInfo()); + + const T* inBuff = reinterpret_cast(input->buffer()); + T* outBuff = reinterpret_cast(output->buffer()); + + const T tbias = static_cast(bias); + const T tbeta = static_cast(beta); + const T talpha = static_cast(alpha); + + if (inTadEws == 1 && outTadEws == 1) { + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + const T* x = inBuff + inTadOffsets[i]; + T* y = outBuff + outTadOffsets[i]; + + T prev = 0; + + // calculate squared sum of elements per each j-th element range [j - + // depth, j + depth + 1] we store each squared sum in corresponding + // element of y array + for (Nd4jLong j = 0; j < tadLen; ++j) { + const uint begin = sd::math::nd4j_max(0, j - depth); + const uint last = depth + j + 1; + const uint end = sd::math::nd4j_min(last, tadLen); + + if (j == 0) { + for (uint s = begin; s < end; ++s) prev = prev + x[s] * x[s]; + y[j] = prev; + } else if (begin == 0 && last <= tadLen) + y[j] = prev + x[end - 1] * x[end - 1]; + else if (begin > 0 && last <= tadLen) + y[j] = prev + x[end - 1] * x[end - 1] - x[begin - 1] * x[begin - 1]; + else if (begin > 0 && last > tadLen) + y[j] = prev - x[begin - 1] * x[begin - 1]; + else + y[j] = prev; + + if (j != 0) prev = y[j]; + + y[j] = + x[j] / sd::math::nd4j_pow(tbias + alpha * prev, tbeta); + } + } + }; + + samediff::Threads::parallel_tad(func, 0, numOfTads); + } else { + auto func = PRAGMA_THREADS_FOR { + for (Nd4jLong i = 0; i < numOfTads; ++i) { + const T* x = inBuff + inTadOffsets[i]; + T* y = outBuff + outTadOffsets[i]; + + T prev = 0; + + // calculate squared sum of elements per each j-th element range [j - + // depth, j + depth + 1] we store each squared sum in corresponding + // element of y array + for (Nd4jLong j = 0; j < tadLen; ++j) { + const uint begin = sd::math::nd4j_max(0, j - depth); + const uint last = depth + j + 1; + const uint end = sd::math::nd4j_min(last, tadLen); + + if (j == 0) { + for (uint s = begin; s < end; ++s) + prev = prev + x[s * inTadEws] * x[s * inTadEws]; + y[j * outTadEws] = prev; + } else if (begin == 0 && last <= tadLen) + y[j * outTadEws] = + prev + x[(end - 1) * inTadEws] * x[(end - 1) * inTadEws]; + else if (begin > 0 && last <= tadLen) + y[j * outTadEws] = + prev + x[(end - 1) * inTadEws] * x[(end - 1) * inTadEws] - + x[(begin - 1) * inTadEws] * x[(begin - 1) * inTadEws]; + else if (begin > 0 && last > tadLen) + y[j * outTadEws] = + prev - x[(begin - 1) * inTadEws] * x[(begin - 1) * inTadEws]; + else + y[j * outTadEws] = prev; + + if (j != 0) prev = y[j * outTadEws]; + + y[j * outTadEws] = x[j * inTadEws] / sd::math::nd4j_pow( + tbias + alpha * prev, tbeta); + } + } + }; + + samediff::Threads::parallel_tad(func, 0, numOfTads); + } + return Status::OK(); } - -BUILD_SINGLE_TEMPLATE(template int lrnFunctor_, (sd::graph::Context& block, NDArray* input, NDArray* output, int depth, float bias, float alpha, float beta), FLOAT_TYPES); -int lrnFunctor(sd::graph::Context& block, NDArray* input, NDArray* output, int depth, double bias, double alpha, double beta) { - BUILD_SINGLE_SELECTOR(input->dataType(), return lrnFunctor_, (block, input, output, depth, bias, alpha, beta), FLOAT_TYPES); +BUILD_SINGLE_TEMPLATE(template int lrnFunctor_, + (sd::graph::Context & block, NDArray* input, + NDArray* output, int depth, float bias, float alpha, + float beta), + FLOAT_TYPES); + +int lrnFunctor(sd::graph::Context& block, NDArray* input, NDArray* output, + int depth, double bias, double alpha, double beta) { + BUILD_SINGLE_SELECTOR(input->dataType(), return lrnFunctor_, + (block, input, output, depth, bias, alpha, beta), + FLOAT_TYPES); } ////////////////////////////////////////////////////////////////////////// template -static void lrnBP_(const NDArray& input, const NDArray& gradO, NDArray& gradI, const int depth, const float bias, const float alpha, const float beta) { - - const int rank = input.rankOf(); - - TadPack inTadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(input.shapeInfo(), {rank - 1}); - TadPack gradITadPack; - - if(shape::haveSameShapeAndStrides(input.shapeInfo(), gradI.shapeInfo())) - gradITadPack = inTadPack; - else - gradITadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(gradI.shapeInfo(), {rank - 1}); - - const Nd4jLong numOfTads = inTadPack.numberOfTads(); - const Nd4jLong tadLen = input.sizeAt(-1); - - const Nd4jLong* inTadOffsets = inTadPack.primaryOffsets(); - const Nd4jLong* gradITadOffsets = gradITadPack.primaryOffsets(); - - const Nd4jLong inTadEws = shape::elementWiseStride(inTadPack.primaryShapeInfo()); - const Nd4jLong gradITadEws = shape::elementWiseStride(gradITadPack.primaryShapeInfo()); - - const X* inBuff = reinterpret_cast(input.buffer()); - Y* gradIBuff = reinterpret_cast(gradI.buffer()); - - const Y tbias = static_cast(bias); - const Y tbeta = static_cast(beta); - const Y talpha = static_cast(alpha); - const Y coeff = talpha * tbeta; - - if(inTadEws == 1 && gradITadEws == 1) { - - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - const X *x = inBuff + inTadOffsets[i]; - Y *y = gradIBuff + gradITadOffsets[i]; - - // this loop calculates squared sum of elements per each j-th element range [j - depth, j + depth + 1] - // we store each squared sum in corresponding element of y array - for (Nd4jLong j = 0; j < tadLen; ++j) { - const uint begin = sd::math::nd4j_max(0, j - depth); - const uint last = depth + j + 1; - const uint end = sd::math::nd4j_min(last, tadLen); - - if (j == 0) { - y[0] = 0; - for (uint s = begin; s < end; ++s) - y[0] = y[0] + x[s] * x[s]; - } else if (begin == 0 && last <= tadLen) - y[j] = y[j - 1] + x[end - 1] * x[end - 1]; - else if (begin > 0 && last <= tadLen) - y[j] = y[j - 1] + x[end - 1] * x[end - 1] - x[begin - 1] * x[begin - 1]; - else if (begin > 0 && last > tadLen) - y[j] = y[j - 1] - x[begin - 1] * x[begin - 1]; - else - y[j] = y[j - 1]; - } - - Y *factor = new Y[tadLen]; - - Y prev = 0; - // second loop calculates derivatives using information gained in first loop above - for (Nd4jLong j = 0; j < tadLen; ++j) { - const uint begin = sd::math::nd4j_max(0, j - depth); - const uint last = depth + j + 1; - const uint end = sd::math::nd4j_min(last, tadLen); - - Y init = tbias + talpha * y[j]; - - if (j == 0) { - for (uint s = begin; s < end; ++s) { - factor[s] = sd::math::nd4j_pow(tbias + talpha * y[s], -tbeta - 1); - prev = prev + x[s] * factor[s]; - } - y[0] = prev; - } else if (begin == 0 && last <= tadLen) { - factor[end - 1] = sd::math::nd4j_pow(tbias + talpha * y[end - 1], -tbeta - 1); - y[j] = prev + x[end - 1] * factor[end - 1]; - } else if (begin > 0 && last <= tadLen) { - factor[end - 1] = sd::math::nd4j_pow(tbias + talpha * y[end - 1], -tbeta - 1); - y[j] = prev + x[end - 1] * factor[end - 1] - x[begin - 1] * factor[begin - 1]; - } else if (begin > 0 && last > tadLen) - y[j] = prev - x[begin - 1] * factor[begin - 1]; - else - y[j] = prev; - - if (j != 0) - prev = y[j]; - - y[j] = factor[j] * init - 2 * x[j] * coeff * prev; - } - - delete[]factor; +static void lrnBP_(const NDArray& input, const NDArray& gradO, NDArray& gradI, + const int depth, const float bias, const float alpha, + const float beta) { + const int rank = input.rankOf(); + + TadPack inTadPack = sd::ConstantTadHelper::getInstance().tadForDimensions( + input.shapeInfo(), rank - 1); + TadPack gradITadPack; + + if (shape::haveSameShapeAndStrides(input.shapeInfo(), gradI.shapeInfo())) + gradITadPack = inTadPack; + else + gradITadPack = sd::ConstantTadHelper::getInstance().tadForDimensions( + gradI.shapeInfo(), rank - 1); + + const Nd4jLong numOfTads = inTadPack.numberOfTads(); + const Nd4jLong tadLen = input.sizeAt(-1); + + const Nd4jLong* inTadOffsets = inTadPack.primaryOffsets(); + const Nd4jLong* gradITadOffsets = gradITadPack.primaryOffsets(); + + const Nd4jLong inTadEws = + shape::elementWiseStride(inTadPack.primaryShapeInfo()); + const Nd4jLong gradITadEws = + shape::elementWiseStride(gradITadPack.primaryShapeInfo()); + + const X* inBuff = reinterpret_cast(input.buffer()); + Y* gradIBuff = reinterpret_cast(gradI.buffer()); + + const Y tbias = static_cast(bias); + const Y tbeta = static_cast(beta); + const Y talpha = static_cast(alpha); + const Y coeff = talpha * tbeta; + + if (inTadEws == 1 && gradITadEws == 1) { + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + const X* x = inBuff + inTadOffsets[i]; + Y* y = gradIBuff + gradITadOffsets[i]; + + // this loop calculates squared sum of elements per each j-th element + // range [j - depth, j + depth + 1] we store each squared sum in + // corresponding element of y array + for (Nd4jLong j = 0; j < tadLen; ++j) { + const uint begin = sd::math::nd4j_max(0, j - depth); + const uint last = depth + j + 1; + const uint end = sd::math::nd4j_min(last, tadLen); + + if (j == 0) { + y[0] = 0; + for (uint s = begin; s < end; ++s) y[0] = y[0] + x[s] * x[s]; + } else if (begin == 0 && last <= tadLen) + y[j] = y[j - 1] + x[end - 1] * x[end - 1]; + else if (begin > 0 && last <= tadLen) + y[j] = y[j - 1] + x[end - 1] * x[end - 1] - + x[begin - 1] * x[begin - 1]; + else if (begin > 0 && last > tadLen) + y[j] = y[j - 1] - x[begin - 1] * x[begin - 1]; + else + y[j] = y[j - 1]; + } + + Y* factor = new Y[tadLen]; + + Y prev = 0; + // second loop calculates derivatives using information gained in first + // loop above + for (Nd4jLong j = 0; j < tadLen; ++j) { + const uint begin = sd::math::nd4j_max(0, j - depth); + const uint last = depth + j + 1; + const uint end = sd::math::nd4j_min(last, tadLen); + + Y init = tbias + talpha * y[j]; + + if (j == 0) { + for (uint s = begin; s < end; ++s) { + factor[s] = sd::math::nd4j_pow(tbias + talpha * y[s], + -tbeta - 1); + prev = prev + x[s] * factor[s]; } - }; - - samediff::Threads::parallel_tad(func, 0, numOfTads); - } - else { - - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - const X *x = inBuff + inTadOffsets[i]; - Y *y = gradIBuff + gradITadOffsets[i]; - - // this loop calculates squared sum of elements per each j-th element range [j - depth, j + depth + 1] - // we store each squared sum in corresponding element of y array - for (Nd4jLong j = 0; j < tadLen; ++j) { - const uint begin = sd::math::nd4j_max(0, j - depth); - const uint last = depth + j + 1; - const uint end = sd::math::nd4j_min(last, tadLen); - - if (j == 0) { - y[0] = 0; - for (uint s = begin; s < end; ++s) - y[0] = y[0] + x[s * inTadEws] * x[s * inTadEws]; - } else if (begin == 0 && last <= tadLen) - y[j * gradITadEws] = - y[(j - 1) * gradITadEws] + x[(end - 1) * inTadEws] * x[(end - 1) * inTadEws]; - else if (begin > 0 && last <= tadLen) - y[j * gradITadEws] = - y[(j - 1) * gradITadEws] + x[(end - 1) * inTadEws] * x[(end - 1) * inTadEws] - - x[(begin - 1) * inTadEws] * x[(begin - 1) * inTadEws]; - else if (begin > 0 && last > tadLen) - y[j * gradITadEws] = - y[(j - 1) * gradITadEws] - x[(begin - 1) * inTadEws] * x[(begin - 1) * inTadEws]; - else - y[j * gradITadEws] = y[(j - 1) * gradITadEws]; - } - - Y *factor = new Y[tadLen]; - - Y prev = 0; - // second loop calculates derivatives using information gained in first loop above - for (Nd4jLong j = 0; j < tadLen; ++j) { - const uint begin = sd::math::nd4j_max(0, j - depth); - const uint last = depth + j + 1; - const uint end = sd::math::nd4j_min(last, tadLen); - - Y init = tbias + talpha * y[j * gradITadEws]; - - if (j == 0) { - for (uint s = begin; s < end; ++s) { - factor[s] = sd::math::nd4j_pow(tbias + talpha * y[s * gradITadEws], -tbeta - 1); - prev = prev + x[s * inTadEws] * factor[s]; - } - y[0] = prev; - } else if (begin == 0 && last <= tadLen) { - factor[end - 1] = sd::math::nd4j_pow(tbias + talpha * y[(end - 1) * gradITadEws], - -tbeta - 1); - y[j * gradITadEws] = prev + x[(end - 1) * inTadEws] * factor[end - 1]; - } else if (begin > 0 && last <= tadLen) { - factor[end - 1] = sd::math::nd4j_pow(tbias + talpha * y[(end - 1) * gradITadEws], - -tbeta - 1); - y[j * gradITadEws] = prev + x[(end - 1) * inTadEws] * factor[end - 1] - - x[(begin - 1) * inTadEws] * factor[begin - 1]; - } else if (begin > 0 && last > tadLen) - y[j * gradITadEws] = prev - x[(begin - 1) * inTadEws] * factor[begin - 1]; - else - y[j * gradITadEws] = prev; - - if (j != 0) - prev = y[j * gradITadEws]; - - y[j * gradITadEws] = factor[j] * init - 2 * x[j * inTadEws] * coeff * prev; - } - - delete[]factor; + y[0] = prev; + } else if (begin == 0 && last <= tadLen) { + factor[end - 1] = sd::math::nd4j_pow( + tbias + talpha * y[end - 1], -tbeta - 1); + y[j] = prev + x[end - 1] * factor[end - 1]; + } else if (begin > 0 && last <= tadLen) { + factor[end - 1] = sd::math::nd4j_pow( + tbias + talpha * y[end - 1], -tbeta - 1); + y[j] = prev + x[end - 1] * factor[end - 1] - + x[begin - 1] * factor[begin - 1]; + } else if (begin > 0 && last > tadLen) + y[j] = prev - x[begin - 1] * factor[begin - 1]; + else + y[j] = prev; + + if (j != 0) prev = y[j]; + + y[j] = factor[j] * init - 2 * x[j] * coeff * prev; + } + + delete[] factor; + } + }; + + samediff::Threads::parallel_tad(func, 0, numOfTads); + } else { + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + const X* x = inBuff + inTadOffsets[i]; + Y* y = gradIBuff + gradITadOffsets[i]; + + // this loop calculates squared sum of elements per each j-th element + // range [j - depth, j + depth + 1] we store each squared sum in + // corresponding element of y array + for (Nd4jLong j = 0; j < tadLen; ++j) { + const uint begin = sd::math::nd4j_max(0, j - depth); + const uint last = depth + j + 1; + const uint end = sd::math::nd4j_min(last, tadLen); + + if (j == 0) { + y[0] = 0; + for (uint s = begin; s < end; ++s) + y[0] = y[0] + x[s * inTadEws] * x[s * inTadEws]; + } else if (begin == 0 && last <= tadLen) + y[j * gradITadEws] = + y[(j - 1) * gradITadEws] + + x[(end - 1) * inTadEws] * x[(end - 1) * inTadEws]; + else if (begin > 0 && last <= tadLen) + y[j * gradITadEws] = + y[(j - 1) * gradITadEws] + + x[(end - 1) * inTadEws] * x[(end - 1) * inTadEws] - + x[(begin - 1) * inTadEws] * x[(begin - 1) * inTadEws]; + else if (begin > 0 && last > tadLen) + y[j * gradITadEws] = + y[(j - 1) * gradITadEws] - + x[(begin - 1) * inTadEws] * x[(begin - 1) * inTadEws]; + else + y[j * gradITadEws] = y[(j - 1) * gradITadEws]; + } + + Y* factor = new Y[tadLen]; + + Y prev = 0; + // second loop calculates derivatives using information gained in first + // loop above + for (Nd4jLong j = 0; j < tadLen; ++j) { + const uint begin = sd::math::nd4j_max(0, j - depth); + const uint last = depth + j + 1; + const uint end = sd::math::nd4j_min(last, tadLen); + + Y init = tbias + talpha * y[j * gradITadEws]; + + if (j == 0) { + for (uint s = begin; s < end; ++s) { + factor[s] = sd::math::nd4j_pow( + tbias + talpha * y[s * gradITadEws], -tbeta - 1); + prev = prev + x[s * inTadEws] * factor[s]; } - }; - - samediff::Threads::parallel_tad(func, 0, numOfTads); - } - gradI *= gradO; + y[0] = prev; + } else if (begin == 0 && last <= tadLen) { + factor[end - 1] = sd::math::nd4j_pow( + tbias + talpha * y[(end - 1) * gradITadEws], -tbeta - 1); + y[j * gradITadEws] = + prev + x[(end - 1) * inTadEws] * factor[end - 1]; + } else if (begin > 0 && last <= tadLen) { + factor[end - 1] = sd::math::nd4j_pow( + tbias + talpha * y[(end - 1) * gradITadEws], -tbeta - 1); + y[j * gradITadEws] = prev + + x[(end - 1) * inTadEws] * factor[end - 1] - + x[(begin - 1) * inTadEws] * factor[begin - 1]; + } else if (begin > 0 && last > tadLen) + y[j * gradITadEws] = + prev - x[(begin - 1) * inTadEws] * factor[begin - 1]; + else + y[j * gradITadEws] = prev; + + if (j != 0) prev = y[j * gradITadEws]; + + y[j * gradITadEws] = + factor[j] * init - 2 * x[j * inTadEws] * coeff * prev; + } + + delete[] factor; + } + }; + + samediff::Threads::parallel_tad(func, 0, numOfTads); + } + gradI *= gradO; } - -void lrnBP(sd::graph::Context& block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int depth, const float bias, const float alpha, const float beta) { - BUILD_DOUBLE_SELECTOR(input.dataType(), gradO.dataType(), lrnBP_, (input, gradO, gradI, depth, bias, alpha, beta), FLOAT_TYPES, FLOAT_TYPES); +void lrnBP(sd::graph::Context& block, const NDArray& input, + const NDArray& gradO, NDArray& gradI, const int depth, + const float bias, const float alpha, const float beta) { + BUILD_DOUBLE_SELECTOR(input.dataType(), gradO.dataType(), lrnBP_, + (input, gradO, gradI, depth, bias, alpha, beta), + FLOAT_TYPES, FLOAT_TYPES); } -} -} -} +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/lstm.cpp b/libnd4j/include/ops/declarable/helpers/cpu/lstm.cpp index 02d4c985538e..05ad99099f51 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/lstm.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/lstm.cpp @@ -20,227 +20,263 @@ // implementation of operation for LSTM cell with peep hole connections: // http://www.bioinf.jku.at/publications/older/2604.pdf -// S. Hochreiter and J. Schmidhuber. "Long Short-Term Memory". Neural Computation, 9(8):1735-1780, 1997. -// and +// S. Hochreiter and J. Schmidhuber. "Long Short-Term Memory". Neural +// Computation, 9(8):1735-1780, 1997. and // https://research.google.com/pubs/archive/43905.pdf -// Hasim Sak, Andrew Senior, and Francoise Beaufays. "Long short-term memory recurrent neural network architectures for large scale acoustic modeling." INTERSPEECH, 2014. +// Hasim Sak, Andrew Senior, and Francoise Beaufays. "Long short-term memory +// recurrent neural network architectures for large scale acoustic modeling." +// INTERSPEECH, 2014. - -#include +#include +#include #include +#include #include -#include #include -#include +#include +#include + #include -#include -#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { - ////////////////////////////////////////////////////////////////////////// -void lstmCell(sd::LaunchContext * context, const NDArray* xt, const NDArray* ht_1, const NDArray* ct_1, const NDArray* Wx, const NDArray* Wh, const NDArray* Wc, const NDArray* Wp, const NDArray* b, - NDArray* ht, NDArray* ct, const std::vector& params) { - - // xt input [bS x nIn] - // ht_1 previous cell output [bS x numProj], that is at previous time step t-1, in case of projection=false -> numProj=nOut!!! - // ct_1 previous cell state [bS x nOut], that is at previous time step t-1 - - // Wx input-to-hidden weights, [nIn x 4*nOut] - // Wh hidden-to-hidden weights, [numProj x 4*nOut] - // Wc diagonal weights for peephole connections [3*nOut] - // Wp projection weights [nOut x numProj] - // b biases, [4*nOut] - - // ht current cell output [bS x numProj], that is at current time step t - // ct current cell state [bS x nOut], that is at current time step t - - const bool peephole = (bool)params[0]; // if true, provide peephole connections - const bool projection = (bool)params[1]; // if true, then projection is performed, if false then numProj==nOut is mandatory!!!! - double clippingCellValue = params[2]; // clipping value for ct, if it is not equal to zero, then cell state is clipped - double clippingProjValue = params[3]; // clipping value for projected ht, if it is not equal to zero, then projected cell output is clipped - const double forgetBias = params[4]; - - const int bS = xt->sizeAt(0); - const int nIn = xt->sizeAt(1); - const int numProj = ht_1->sizeAt(1); - const int nOut = ct_1->sizeAt(1); - - auto z = mmul(*xt, *Wx) + mmul(*ht_1, *Wh) + *b; // [bS x 4*nOut] + [bS x 4*nOut] + [1 x 4*nOut] = [bS x 4*nOut] - - auto zit = z({0,0, 0,nOut}); // z for input gate, = mmul(Wxi,xt) + mmul(Whi,ht_1) + bi = [bS x nOut] - auto zft = z({0,0, nOut,2*nOut}); // z for forget gate, = mmul(Wxf,xt) + mmul(Whf,ht_1) + bf = [bS x nOut] - auto zct = z({0,0, 2*nOut,3*nOut}); // z for cell state, = mmul(Wxc,xt) + mmul(Whc,ht_1) + bc = [bS x nOut] - auto zot = z({0,0, 3*nOut,4*nOut}); // z for output gate, = mmul(Wxo,xt) + mmul(Who,ht_1) + bo = [bS x nOut] - - if(peephole) { // add peephole connections: z + ct_1*Wc - zit += (*ct_1) * (*Wc)({0, nOut}); // add peephole connections to input gate - zft += (*ct_1) * (*Wc)({nOut, 2*nOut}); // add peephole connections to forget gate - } - - // current sell state = ft*ct_1 + it*tanh(mmul(Wxc,xt) + mmul(Whc,ht_1) + bc - ct->assign( sigmoid(zft + forgetBias) * (*ct_1) + sigmoid(zit) * tanh(zct) ); - - // if clipping value is provided then cell state is clipped by this value prior to the cell output activation - if(clippingCellValue > 0.0) - ct->applyScalar(scalar::LstmClip, clippingCellValue, *ct); - - if(peephole) - zot += (*ct) * (*Wc)({{2*nOut, 3*nOut}}); // add peephole connections to output gate zot + ct*Wc - - // current cell output = ot*tanh(ct) - auto htNoPeepHole = sigmoid(zot) * tanh(*ct); // = [bS x nOut] - - // apply projection - if(projection) { - ht->assign( mmul(htNoPeepHole, *Wp) ); // [bS x nOut] * [ nOut x numProj] = [bS x numProj] - // if clipping projection is provided then projected cell output state is clipped by this value - if(clippingProjValue != 0.) - ht->applyScalar(scalar::LstmClip, clippingProjValue, *ht); - } - else - ht->assign(&htNoPeepHole); +void lstmCell(sd::LaunchContext* context, const NDArray* xt, + const NDArray* ht_1, const NDArray* ct_1, const NDArray* Wx, + const NDArray* Wh, const NDArray* Wc, const NDArray* Wp, + const NDArray* b, NDArray* ht, NDArray* ct, + const std::vector& params) { + // xt input [bS x nIn] + // ht_1 previous cell output [bS x numProj], that is at previous time step + // t-1, in case of projection=false -> numProj=nOut!!! ct_1 previous cell + // state [bS x nOut], that is at previous time step t-1 + + // Wx input-to-hidden weights, [nIn x 4*nOut] + // Wh hidden-to-hidden weights, [numProj x 4*nOut] + // Wc diagonal weights for peephole connections [3*nOut] + // Wp projection weights [nOut x numProj] + // b biases, [4*nOut] + + // ht current cell output [bS x numProj], that is at current time step t + // ct current cell state [bS x nOut], that is at current time step t + + const bool peephole = + (bool)params[0]; // if true, provide peephole connections + const bool projection = + (bool)params[1]; // if true, then projection is performed, if false then + // numProj==nOut is mandatory!!!! + double clippingCellValue = + params[2]; // clipping value for ct, if it is not equal to zero, then + // cell state is clipped + double clippingProjValue = + params[3]; // clipping value for projected ht, if it is not equal to + // zero, then projected cell output is clipped + const double forgetBias = params[4]; + + const int bS = xt->sizeAt(0); + const int nIn = xt->sizeAt(1); + const int numProj = ht_1->sizeAt(1); + const int nOut = ct_1->sizeAt(1); + + auto z = mmul(*xt, *Wx) + mmul(*ht_1, *Wh) + + *b; // [bS x 4*nOut] + [bS x 4*nOut] + [1 x 4*nOut] = [bS x 4*nOut] + + auto zit = z({0, 0, 0, nOut}); // z for input gate, = mmul(Wxi,xt) + + // mmul(Whi,ht_1) + bi = [bS x nOut] + auto zft = z({0, 0, nOut, 2 * nOut}); // z for forget gate, = mmul(Wxf,xt) + + // mmul(Whf,ht_1) + bf = [bS x nOut] + auto zct = + z({0, 0, 2 * nOut, 3 * nOut}); // z for cell state, = mmul(Wxc,xt) + + // mmul(Whc,ht_1) + bc = [bS x nOut] + auto zot = + z({0, 0, 3 * nOut, 4 * nOut}); // z for output gate, = mmul(Wxo,xt) + + // mmul(Who,ht_1) + bo = [bS x nOut] + + if (peephole) { // add peephole connections: z + ct_1*Wc + zit += + (*ct_1) * (*Wc)({0, nOut}); // add peephole connections to input gate + zft += (*ct_1) * + (*Wc)({nOut, 2 * nOut}); // add peephole connections to forget gate + } + + // current sell state = ft*ct_1 + it*tanh(mmul(Wxc,xt) + mmul(Whc,ht_1) + bc + ct->assign(sigmoid(zft + forgetBias) * (*ct_1) + sigmoid(zit) * tanh(zct)); + + // if clipping value is provided then cell state is clipped by this value + // prior to the cell output activation + if (clippingCellValue > 0.0) + ct->applyScalar(scalar::LstmClip, clippingCellValue, *ct); + + if (peephole) + zot += (*ct) * (*Wc)({{2 * nOut, 3 * nOut}}); // add peephole connections + // to output gate zot + ct*Wc + + // current cell output = ot*tanh(ct) + auto htNoPeepHole = sigmoid(zot) * tanh(*ct); // = [bS x nOut] + + // apply projection + if (projection) { + ht->assign(mmul(htNoPeepHole, + *Wp)); // [bS x nOut] * [ nOut x numProj] = [bS x numProj] + // if clipping projection is provided then projected cell output state is + // clipped by this value + if (clippingProjValue != 0.) + ht->applyScalar(scalar::LstmClip, clippingProjValue, *ht); + } else + ht->assign(&htNoPeepHole); } template -static void fusedTanh(NDArray *z, NDArray *i, NDArray *c, const NDArray *cLast, NDArray *f, NDArray *h) { - //cell state = blockInput .* inputGate + prevCellState .* forgetGate - /* - z->applyPairwiseTransform(pairwise::Multiply, i, c, nullptr); //c = z * i - auto temp = (*f) * (*cLast); - *c += temp; //c = (i * z) + (zf * (*cLast)) - c->applyTransform(transform::Tanh, h); //h = tanh(c) - */ - - auto uLen = static_cast(z->lengthOf()); - auto c_ = c->bufferAsT(); - auto z_ = z->bufferAsT(); - auto i_ = i->bufferAsT(); - auto f_ = f->bufferAsT(); - auto cLast_ = cLast->bufferAsT(); - auto h_ = h->bufferAsT(); - - auto func = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) { - c_[e] = z_[e] * i_[e] + (f_[e] * cLast_[e]); - h_[e] = sd::math::nd4j_tanh(c_[e]); - } - }; - - samediff::Threads::parallel_for(func, 0, uLen); -} - -////////////////////////////////////////////////////////////////////////// - -void lstmBlockCell(const NDArray* xt, const NDArray* cLast, const NDArray* yLast, - const NDArray* W, const NDArray* Wci, const NDArray* Wcf, const NDArray* Wco, const NDArray* b, - NDArray* i, NDArray* c, NDArray* f, NDArray* o, NDArray* z, NDArray* h, NDArray* y, const std::vector& params) { - - /* Input arrays: - * 0: xt - input [bS, nIn] at time t - * 1: cLast (cs_prev) - previous cell state [bS, nOut], time t-1 - * 2: yLast (h_prev) - previous output [bS, nOut], time t-1 - * 3: W - Weights - concatenated (input-to-hidden, hidden-to-hidden weights) weights, [(nIn+nOut), 4*nOut] - * 4: Wci - weights - cell peephole (t-1) connections to input modulation gate, [nOut] - * 5: Wcf - weights - cell peephole (t-1) connections to forget gate, [nOut] - * 6: Wco - weights - cell peephole (t) connections to output gate, [nOut] - * 7: b - biases, [4*nOut] - * - * Input integer arguments: - * 0: if not zero, provide peephole connections - * - * Input float arguments: - * 0: the bias added to forget gates in order to reduce the scale of forgetting in the beginning of the training - * 1: clipping value for cell state, if it is not equal to zero, then cell state is clipped - * - * Output arrays: - * 0: i - Input modulation gate activations [bS, nOut] - * 1: c (cs) - Cell state (pre tanh) [bs, nOut] (cs) - * 2: f - Output - forget gate activations [bs, nOut] - * 3: o - Output - output gate activations [bs, nOut] - * 4: z (ci) - Output - block input [bs, nOut] - * 5: h (co) - Cell state, post tanh [bs, nOut] - * 6: y (h) - Current cell output [bS, nOut], time t - */ - const bool peephole = (bool)params[0]; // if true, provide peephole connections - const double forgetBias = params[1]; - const double clippingCellValue = params[2]; // clipping value for ct, if it is not equal to zero, then cell state is clipped - - const int bS = xt->sizeAt(0); - const int nIn = xt->sizeAt(1); - const int nOut = cLast->sizeAt(1); - - //Concat inputs: [xt, yt-1]: concat([bs,nIn],[bs,nOut]) -> [bs, (nIn+nOut)] - NDArray concatOut(xt->ordering(), {xt->sizeAt(0), xt->sizeAt(1) + yLast->sizeAt(1)}, xt->dataType(), xt->getContext()); - helpers::concat(xt->getContext(), {const_cast(xt), const_cast(yLast)}, concatOut, {1}); - - auto m = mmul(concatOut, *W); // mmul: [bs, (nIn+nOut)] * [(nIn+nOut), 4*nOut] = [bs, 4*nOut] - m += (*b); // addiRowVector - - //Note: weights are ordered [inputGate, blockInput, forgetGate, outputGate] to match TF (TF code comments state [i,f,z/ci,o] but behaviour is [i,z,f,o]) - auto zi = m({0,0, 0, nOut}); // z for input modulation gate, [bS, nOut] - auto zz = m({0,0, nOut, 2*nOut}); // z for block input, [bS, nOut] - auto zf = m({0,0, 2*nOut, 3*nOut}); // z for forget gate, [bS, nOut] - auto zo = m({0,0, 3*nOut, 4*nOut}); // z for output gate, [bS, nOut] - - if(peephole) { // add peephole connections: z + ct_1*Wc - zi += (*cLast) * (*Wci); // add peephole connections to input gate - zf += (*cLast) * (*Wcf); // add peephole connections to forget gate - } - - // current sell state = ft*cLast + it*tanh(mmul(Wxc,xt) + mmul(Whc,ht_1) + bc - if(forgetBias != 0.0) - zf += forgetBias; - - PRAGMA_OMP_PARALLEL - PRAGMA_OMP_SINGLE - { - PRAGMA_OMP_TASK - zz.applyTransform(transform::Tanh, *z); //z = tanh(zz) - - PRAGMA_OMP_TASK - zi.applyTransform(transform::Sigmoid, *i); //i = sigmoid(zi) - - PRAGMA_OMP_TASK - zf.applyTransform(transform::Sigmoid, *f); //f = sigmoid(zf); +static void fusedTanh(NDArray* z, NDArray* i, NDArray* c, const NDArray* cLast, + NDArray* f, NDArray* h) { + // cell state = blockInput .* inputGate + prevCellState .* forgetGate + /* + z->applyPairwiseTransform(pairwise::Multiply, i, c, nullptr); //c = z * + i auto temp = (*f) * (*cLast); *c += temp; //c = + (i * z) + (zf * (*cLast)) c->applyTransform(transform::Tanh, h); //h = + tanh(c) + */ + + auto uLen = static_cast(z->lengthOf()); + auto c_ = c->bufferAsT(); + auto z_ = z->bufferAsT(); + auto i_ = i->bufferAsT(); + auto f_ = f->bufferAsT(); + auto cLast_ = cLast->bufferAsT(); + auto h_ = h->bufferAsT(); + + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + c_[e] = z_[e] * i_[e] + (f_[e] * cLast_[e]); + h_[e] = sd::math::nd4j_tanh(c_[e]); } + }; - if (z->ews() == 1 && i->ews() == 1 && c->ews() == 1 && cLast->ews() == 1 && f->ews() == 1 && h->ews() == 1 && - z->ordering() == i->ordering() && z->ordering() == c->ordering() && z->ordering() == cLast->ordering() && z->ordering() == f->ordering() && z->ordering() == h->ordering()) { - //cell state = blockInput .* inputGate + prevCellState .* forgetGate - BUILD_SINGLE_SELECTOR(z->dataType(), fusedTanh, (z, i, c, cLast, f, h), FLOAT_TYPES); - } else { - //cell state = blockInput .* inputGate + prevCellState .* forgetGate - z->applyPairwiseTransform(pairwise::Multiply, *i, *c); //c = z * i - auto temp = (*f) * (*cLast); - *c += temp; //c = (i * z) + (zf * (*cLast)) - c->applyTransform(transform::Tanh, *h); //h = tanh(c) - } - - // if clipping value is provided then cell state is clipped by this value prior to the cell output activation - if(clippingCellValue > 0.0) - c->applyScalar(scalar::LstmClip, clippingCellValue, *c); - - // add peephole connections to output gate zot + ct*Wc - if(peephole) { - auto prod = *c * (*Wco); - zo += prod; - } - - zo.applyTransform(transform::Sigmoid, *o); // o = sigmoid(zo) - - // current cell output = ot*tanh(ct) - c->applyTransform(transform::Tanh, *h); //h = tanh(c) - o->applyPairwiseTransform(pairwise::Multiply, *h, *y); //y = o * h + samediff::Threads::parallel_for(func, 0, uLen); } +////////////////////////////////////////////////////////////////////////// - - -} -} +void lstmBlockCell(const NDArray* xt, const NDArray* cLast, + const NDArray* yLast, const NDArray* W, const NDArray* Wci, + const NDArray* Wcf, const NDArray* Wco, const NDArray* b, + NDArray* i, NDArray* c, NDArray* f, NDArray* o, NDArray* z, + NDArray* h, NDArray* y, const std::vector& params) { + /* Input arrays: + * 0: xt - input [bS, nIn] at time t + * 1: cLast (cs_prev) - previous cell state [bS, nOut], time t-1 + * 2: yLast (h_prev) - previous output [bS, nOut], time t-1 + * 3: W - Weights - concatenated (input-to-hidden, + * hidden-to-hidden weights) weights, [(nIn+nOut), 4*nOut] 4: Wci - weights - + * cell peephole (t-1) connections to input modulation gate, [nOut] 5: Wcf - + * weights - cell peephole (t-1) connections to forget gate, [nOut] 6: Wco - + * weights - cell peephole (t) connections to output gate, [nOut] 7: b - + * biases, [4*nOut] + * + * Input integer arguments: + * 0: if not zero, provide peephole connections + * + * Input float arguments: + * 0: the bias added to forget gates in order to reduce the scale of + * forgetting in the beginning of the training 1: clipping value for cell + * state, if it is not equal to zero, then cell state is clipped + * + * Output arrays: + * 0: i - Input modulation gate activations [bS, nOut] + * 1: c (cs) - Cell state (pre tanh) [bs, nOut] (cs) + * 2: f - Output - forget gate activations [bs, nOut] + * 3: o - Output - output gate activations [bs, nOut] + * 4: z (ci) - Output - block input [bs, nOut] + * 5: h (co) - Cell state, post tanh [bs, nOut] + * 6: y (h) - Current cell output [bS, nOut], time t + */ + const bool peephole = + (bool)params[0]; // if true, provide peephole connections + const double forgetBias = params[1]; + const double clippingCellValue = + params[2]; // clipping value for ct, if it is not equal to zero, then + // cell state is clipped + + const int bS = xt->sizeAt(0); + const int nIn = xt->sizeAt(1); + const int nOut = cLast->sizeAt(1); + + // Concat inputs: [xt, yt-1]: concat([bs,nIn],[bs,nOut]) -> [bs, (nIn+nOut)] + NDArray concatOut(xt->ordering(), + {xt->sizeAt(0), xt->sizeAt(1) + yLast->sizeAt(1)}, + xt->dataType(), xt->getContext()); + helpers::concat(xt->getContext(), + {const_cast(xt), const_cast(yLast)}, + concatOut, 1); + + auto m = + mmul(concatOut, + *W); // mmul: [bs, (nIn+nOut)] * [(nIn+nOut), 4*nOut] = [bs, 4*nOut] + m += (*b); // addiRowVector + + // Note: weights are ordered [inputGate, blockInput, forgetGate, outputGate] + // to match TF (TF code comments state [i,f,z/ci,o] but behaviour is + // [i,z,f,o]) + auto zi = m({0, 0, 0, nOut}); // z for input modulation gate, [bS, nOut] + auto zz = m({0, 0, nOut, 2 * nOut}); // z for block input, [bS, nOut] + auto zf = m({0, 0, 2 * nOut, 3 * nOut}); // z for forget gate, [bS, nOut] + auto zo = m({0, 0, 3 * nOut, 4 * nOut}); // z for output gate, [bS, nOut] + + if (peephole) { // add peephole connections: z + ct_1*Wc + zi += (*cLast) * (*Wci); // add peephole connections to input gate + zf += (*cLast) * (*Wcf); // add peephole connections to forget gate + } + + // current sell state = ft*cLast + it*tanh(mmul(Wxc,xt) + mmul(Whc,ht_1) + bc + if (forgetBias != 0.0) zf += forgetBias; + + PRAGMA_OMP_PARALLEL + PRAGMA_OMP_SINGLE { + PRAGMA_OMP_TASK + zz.applyTransform(transform::Tanh, *z); // z = tanh(zz) + + PRAGMA_OMP_TASK + zi.applyTransform(transform::Sigmoid, *i); // i = sigmoid(zi) + + PRAGMA_OMP_TASK + zf.applyTransform(transform::Sigmoid, *f); // f = sigmoid(zf); + } + + if (z->ews() == 1 && i->ews() == 1 && c->ews() == 1 && cLast->ews() == 1 && + f->ews() == 1 && h->ews() == 1 && z->ordering() == i->ordering() && + z->ordering() == c->ordering() && z->ordering() == cLast->ordering() && + z->ordering() == f->ordering() && z->ordering() == h->ordering()) { + // cell state = blockInput .* inputGate + prevCellState .* forgetGate + BUILD_SINGLE_SELECTOR(z->dataType(), fusedTanh, (z, i, c, cLast, f, h), + FLOAT_TYPES); + } else { + // cell state = blockInput .* inputGate + prevCellState .* forgetGate + z->applyPairwiseTransform(pairwise::Multiply, *i, *c); // c = z * i + auto temp = (*f) * (*cLast); + *c += temp; // c = (i * z) + (zf * (*cLast)) + c->applyTransform(transform::Tanh, *h); // h = tanh(c) + } + + // if clipping value is provided then cell state is clipped by this value + // prior to the cell output activation + if (clippingCellValue > 0.0) + c->applyScalar(scalar::LstmClip, clippingCellValue, *c); + + // add peephole connections to output gate zot + ct*Wc + if (peephole) { + auto prod = *c * (*Wco); + zo += prod; + } + + zo.applyTransform(transform::Sigmoid, *o); // o = sigmoid(zo) + + // current cell output = ot*tanh(ct) + c->applyTransform(transform::Tanh, *h); // h = tanh(c) + o->applyPairwiseTransform(pairwise::Multiply, *h, *y); // y = o * h } +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/lstsq.cpp b/libnd4j/include/ops/declarable/helpers/cpu/lstsq.cpp index 204b05530086..376f72a3512e 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/lstsq.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/lstsq.cpp @@ -17,92 +17,115 @@ // // @author GS // -#include #include #include #include #include - -#include -#include #include +#include #include +#include +#include namespace sd { namespace ops { namespace helpers { - template - static void fillRegularizer(NDArray& ioMatrix, double const value) { - auto lastDims = ioMatrix.allTensorsAlongDimension({-2, -1}); - auto rows = ioMatrix.sizeAt(-2); - //auto cols = ioMatrix.sizeAt(-1); +template +static void fillRegularizer(NDArray& ioMatrix, double const value) { + auto lastDims = ioMatrix.allTensorsAlongDimension({-2, -1}); + auto rows = ioMatrix.sizeAt(-2); + // auto cols = ioMatrix.sizeAt(-1); for (auto x = 0; x < lastDims.size(); x++) { for (auto r = 0; r < rows; r++) { - lastDims[x]->r(r,r) = (T)value; + lastDims[x].r(r,r) = (T)value; } } } - template - int leastSquaresSolveFunctor_(sd::LaunchContext* context, NDArray const* leftInput, NDArray const* rightInput, double const l2Regularizer, bool const fast, NDArray* output) { - NDArray::preparePrimaryUse({output}, {leftInput, rightInput}); - if (fast) { // Cholesky decomposition approach - // Equation for solve A^T * Ax = A^T * b, so - // 1. Computing A2: - auto tAtShape = ShapeUtils::evalShapeForMatmul(leftInput->shapeInfo(), leftInput->shapeInfo(), true, false); - //tAtShape[tAtShape.size() - 2] = output->sizeAt(-2); - NDArray leftOutput('c', tAtShape, output->dataType(), context); - MmulHelper::matmul(leftInput, leftInput, &leftOutput, true, false); // Computing A2 = A^T * A - // 2. Computing B' = A^T * b - auto rightOutput = output->ulike(); - - MmulHelper::matmul(leftInput, rightInput, &rightOutput, true, false); // Computing B' = A^T * b - // 3. due l2Regularizer = 0, skip regularization ( indeed A' = A2 - l2Regularizer * I) - auto regularizer = leftOutput.ulike(); - fillRegularizer(regularizer, l2Regularizer);https://mangapark.net/ -// regularizer *= l2Regularizer; - leftOutput += regularizer; - // 4. Cholesky decomposition -- output matrix is square and lower triangular -// auto leftOutputT = leftOutput.ulike(); - auto err = helpers::cholesky(context, &leftOutput, &leftOutput, true); // inplace decomposition - if (err) return err; - // alternate moment: inverse lower triangular matrix to solve equation A'x = b' => L^Tx = L^-1 * b' - // solve one upper triangular system (to avoid float problems) +template +int leastSquaresSolveFunctor_(sd::LaunchContext* context, + NDArray const* leftInput, + NDArray const* rightInput, + double const l2Regularizer, bool const fast, + NDArray* output) { + NDArray::preparePrimaryUse({output}, {leftInput, rightInput}); + if (fast) { // Cholesky decomposition approach + // Equation for solve A^T * Ax = A^T * b, so + // 1. Computing A2: + auto tAtShape = ShapeUtils::evalShapeForMatmul( + leftInput->shapeInfo(), leftInput->shapeInfo(), true, false); + // tAtShape[tAtShape.size() - 2] = output->sizeAt(-2); + NDArray leftOutput('c', tAtShape, output->dataType(), context); + MmulHelper::matmul(leftInput, leftInput, &leftOutput, true, + false); // Computing A2 = A^T * A + // 2. Computing B' = A^T * b + auto rightOutput = output->ulike(); - // 5. Solve two triangular systems: - auto rightB = rightOutput.ulike(); - helpers::triangularSolveFunctor(context, &leftOutput, &rightOutput, true, false, &rightB); - helpers::adjointMatrix(context, &leftOutput, true, &leftOutput); //.transposei(); - helpers::triangularSolveFunctor(context, &leftOutput, &rightB, false, false, output); - // All done - } - else { // QR decomposition approach - // Equation for solve Rx = Q^T * b, where A = Q * R, where Q - orthogonal matrix, and R - upper triangular - // 1. QR decomposition - auto qShape = leftInput->getShapeAsVector(); - auto rShape = leftInput->getShapeAsVector(); - qShape[leftInput->rankOf() - 1] = leftInput->sizeAt(-2); - - NDArray Q(leftInput->ordering(), qShape, leftInput->dataType(), context);// = leftInput->ulike(); - NDArray R(leftInput->ordering(), rShape, leftInput->dataType(), context); // = rightInput->ulike(); - helpers::qr(context, leftInput, &Q, &R, true); - // 2. b` = Q^t * b: - auto rightOutput = rightInput->ulike(); - MmulHelper::matmul(&Q, rightInput, &rightOutput, true, false); - // 3. Solve triangular system - helpers::triangularSolveFunctor(context, &R, &rightOutput, false, false, output); - } - NDArray::registerPrimaryUse({output}, {leftInput, rightInput}); - return Status::OK(); - } + MmulHelper::matmul(leftInput, rightInput, &rightOutput, true, + false); // Computing B' = A^T * b + // 3. due l2Regularizer = 0, skip regularization ( indeed A' = A2 - + // l2Regularizer * I) + auto regularizer = leftOutput.ulike(); + fillRegularizer(regularizer, l2Regularizer); + https: // mangapark.net/ + // regularizer *= l2Regularizer; + leftOutput += regularizer; + // 4. Cholesky decomposition -- output matrix is square and lower triangular + // auto leftOutputT = leftOutput.ulike(); + auto err = helpers::cholesky(context, &leftOutput, &leftOutput, + true); // inplace decomposition + if (err) return err; + // alternate moment: inverse lower triangular matrix to solve equation A'x = + // b' => L^Tx = L^-1 * b' solve one upper triangular system (to avoid float + // problems) - int leastSquaresSolveFunctor(sd::LaunchContext* context, NDArray const* leftInput, NDArray const* rightInput, double const l2Regularizer, bool const fast, NDArray* output) { - BUILD_SINGLE_SELECTOR(leftInput->dataType(), return leastSquaresSolveFunctor_, (context, leftInput, rightInput, l2Regularizer, fast, output), FLOAT_TYPES); - } + // 5. Solve two triangular systems: + auto rightB = rightOutput.ulike(); + helpers::triangularSolveFunctor(context, &leftOutput, &rightOutput, true, + false, &rightB); + helpers::adjointMatrix(context, &leftOutput, true, + &leftOutput); //.transposei(); + helpers::triangularSolveFunctor(context, &leftOutput, &rightB, false, false, + output); + // All done + } else { // QR decomposition approach + // Equation for solve Rx = Q^T * b, where A = Q * R, where Q - orthogonal + // matrix, and R - upper triangular + // 1. QR decomposition + auto qShape = leftInput->getShapeAsVector(); + auto rShape = leftInput->getShapeAsVector(); + qShape[leftInput->rankOf() - 1] = leftInput->sizeAt(-2); + NDArray Q(leftInput->ordering(), qShape, leftInput->dataType(), + context); // = leftInput->ulike(); + NDArray R(leftInput->ordering(), rShape, leftInput->dataType(), + context); // = rightInput->ulike(); + helpers::qr(context, leftInput, &Q, &R, true); + // 2. b` = Q^t * b: + auto rightOutput = rightInput->ulike(); + MmulHelper::matmul(&Q, rightInput, &rightOutput, true, false); + // 3. Solve triangular system + helpers::triangularSolveFunctor(context, &R, &rightOutput, false, false, + output); + } + NDArray::registerPrimaryUse({output}, {leftInput, rightInput}); + return Status::OK(); } + +int leastSquaresSolveFunctor(sd::LaunchContext* context, + NDArray const* leftInput, + NDArray const* rightInput, + double const l2Regularizer, bool const fast, + NDArray* output) { + BUILD_SINGLE_SELECTOR( + leftInput->dataType(), return leastSquaresSolveFunctor_, + (context, leftInput, rightInput, l2Regularizer, fast, output), + FLOAT_TYPES); } -} + +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/lup.cpp b/libnd4j/include/ops/declarable/helpers/cpu/lup.cpp index 8f45c696b960..d209a10739ac 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/lup.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/lup.cpp @@ -18,599 +18,680 @@ // @author raver119@gmail.com // -#include -#include #include -#include -#include #include +#include +#include +#include namespace sd { namespace ops { namespace helpers { - template - static void swapRows_(NDArray* matrix, int theFirst, int theSecond) { - - if (theFirst != theSecond) - for (int i = 0; i < matrix->columns(); i++) { - math::nd4j_swap(matrix->r(theFirst, i), matrix->r(theSecond, i)); - } - } - BUILD_SINGLE_TEMPLATE(template void swapRows_, (NDArray* matrix, int theFirst, int theSecond), FLOAT_TYPES); - - template - static void swapRows(T* matrixBuf, Nd4jLong const* matrixShape, Nd4jLong theFirst, Nd4jLong theSecond) { - if (theFirst != theSecond) { - auto n = shape::sizeAt(matrixShape, -1); - - auto loop = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - Nd4jLong theFirstPos[] = {theFirst, i}; - Nd4jLong theSecondPos[] = {theSecond, i}; - auto theFirstIndex = shape::getOffset(matrixShape, theFirstPos, 0); - auto theSecondIndex = shape::getOffset(matrixShape, theSecondPos, 0); - math::nd4j_swap(matrixBuf[theFirstIndex], matrixBuf[theSecondIndex]); - } - }; - - samediff::Threads::parallel_tad(loop, 0, n, 1); - } - } - - void swapRows(NDArray* matrix, int theFirst, int theSecond) { - BUILD_SINGLE_SELECTOR(matrix->dataType(), swapRows_, (matrix, theFirst, theSecond), FLOAT_TYPES); +template +static void swapRows_(NDArray* matrix, int theFirst, int theSecond) { + if (theFirst != theSecond) + for (int i = 0; i < matrix->columns(); i++) { + math::nd4j_swap(matrix->r(theFirst, i), matrix->r(theSecond, i)); } +} +BUILD_SINGLE_TEMPLATE(template void swapRows_, + (NDArray * matrix, int theFirst, int theSecond), + FLOAT_TYPES); - template - static void invertLowerMatrix_(NDArray* inputMatrix, NDArray* invertedMatrix) { - int n = inputMatrix->rows(); - invertedMatrix->setIdentity(); +template +static void swapRows(T* matrixBuf, Nd4jLong const* matrixShape, + Nd4jLong theFirst, Nd4jLong theSecond) { + if (theFirst != theSecond) { + auto n = shape::sizeAt(matrixShape, -1); + + auto loop = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + Nd4jLong theFirstPos[] = {theFirst, i}; + Nd4jLong theSecondPos[] = {theSecond, i}; + auto theFirstIndex = shape::getOffset(matrixShape, theFirstPos, 0); + auto theSecondIndex = shape::getOffset(matrixShape, theSecondPos, 0); + math::nd4j_swap(matrixBuf[theFirstIndex], matrixBuf[theSecondIndex]); + } + }; + + samediff::Threads::parallel_tad(loop, 0, n, 1); + } +} - if (inputMatrix->isIdentityMatrix()) return; +void swapRows(NDArray* matrix, int theFirst, int theSecond) { + BUILD_SINGLE_SELECTOR(matrix->dataType(), swapRows_, + (matrix, theFirst, theSecond), FLOAT_TYPES); +} - auto invertDiagonals = PRAGMA_THREADS_FOR { - for (int i = start; i < stop; i += increment) - invertedMatrix->r(i, i) /= inputMatrix->t(i, i); - }; +template +static void invertLowerMatrix_(NDArray* inputMatrix, NDArray* invertedMatrix) { + int n = inputMatrix->rows(); + invertedMatrix->setIdentity(); + + if (inputMatrix->isIdentityMatrix()) return; + + auto invertDiagonals = PRAGMA_THREADS_FOR { + for (int i = start; i < stop; i += increment) + invertedMatrix->r(i, i) /= inputMatrix->t(i, i); + }; + + auto invertSubDiagonals = PRAGMA_THREADS_FOR { + for (int i = start; i < stop; i += increment) + invertedMatrix->r(i, i - 1) -= + (inputMatrix->t(i, i - 1) * invertedMatrix->t(i - 1, i - 1) / + inputMatrix->t(i, i)); + }; + + samediff::Threads::parallel_for(invertDiagonals, 0, n, 1); + samediff::Threads::parallel_for(invertSubDiagonals, 1, n, 1); + + // PRAGMA_OMP_PARALLEL_FOR_SIMD + for (int i = 1; i < n; i++) { + for (int j = 0; j < i - 1; j++) + for (int k = 0; k < i; k++) + invertedMatrix->r(i, j) -= + ((invertedMatrix->t(k, j) * inputMatrix->t(i, k) / + inputMatrix->t(i, i))); + } +} - auto invertSubDiagonals = PRAGMA_THREADS_FOR { - for (int i = start; i < stop; i += increment) - invertedMatrix->r(i, i - 1) -= (inputMatrix->t(i, i - 1) * invertedMatrix->t(i - 1, i - 1) / inputMatrix->t(i, i)); - }; +BUILD_SINGLE_TEMPLATE(template void invertLowerMatrix_, + (NDArray * inputMatrix, NDArray* invertedMatrix); + , FLOAT_TYPES); - samediff::Threads::parallel_for(invertDiagonals, 0, n, 1); - samediff::Threads::parallel_for(invertSubDiagonals, 1, n, 1); +void invertLowerMatrix(NDArray* inputMatrix, NDArray* invertedMatrix) { + BUILD_SINGLE_SELECTOR(inputMatrix->dataType(), invertLowerMatrix_, + (inputMatrix, invertedMatrix), FLOAT_TYPES); +} -// PRAGMA_OMP_PARALLEL_FOR_SIMD - for (int i = 1; i < n; i++) { - for (int j = 0; j < i - 1 ; j++) - for (int k = 0; k < i; k++) - invertedMatrix->r(i, j) -= ((invertedMatrix->t(k, j) * inputMatrix->t(i, k) / inputMatrix->t(i, i))); - } +template +static void _invertUpperMatrix(NDArray* inputMatrix, NDArray* invertedMatrix) { + int n = inputMatrix->rows(); + invertedMatrix->setIdentity(); + + if (inputMatrix->isIdentityMatrix()) { // the inverse for I is I + return; + } + + auto invertDiagonals = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i += increment) + invertedMatrix->r(i, i) /= inputMatrix->t(i, i); + }; + + // PRAGMA_OMP_PARALLEL_FOR_IF(n > + // Environment::getInstance().elementwiseThreshold()) + auto invertUpDiagonals = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i += increment) + invertedMatrix->r(i, i + 1) -= + (inputMatrix->t(i, i + 1) * invertedMatrix->t(i + 1, i + 1) / + inputMatrix->t(i, i)); + }; + + samediff::Threads::parallel_for(invertDiagonals, 0, n, 1); + samediff::Threads::parallel_for(invertUpDiagonals, 0, n - 1, 1); + + // PRAGMA_OMP_PARALLEL_FOR_SIMD + for (auto i = n - 2; i >= 0; i--) { + for (auto j = i + 2; j < n; j++) + for (auto k = i; k < n; k++) + invertedMatrix->r(i, j) -= + ((invertedMatrix->t(k, j) * inputMatrix->t(i, k) / + inputMatrix->t(i, i))); + } +} - } +BUILD_SINGLE_TEMPLATE(template void _invertUpperMatrix, + (NDArray * inputMatrix, NDArray* invertedMatrix); + , FLOAT_TYPES); - BUILD_SINGLE_TEMPLATE(template void invertLowerMatrix_, (NDArray* inputMatrix, NDArray* invertedMatrix);, FLOAT_TYPES); +void invertUpperMatrix(NDArray* inputMatrix, NDArray* invertedMatrix) { + BUILD_SINGLE_SELECTOR(inputMatrix->dataType(), _invertUpperMatrix, + (inputMatrix, invertedMatrix), FLOAT_TYPES); +} - void invertLowerMatrix(NDArray* inputMatrix, NDArray* invertedMatrix) { - BUILD_SINGLE_SELECTOR(inputMatrix->dataType(), invertLowerMatrix_, (inputMatrix, invertedMatrix), FLOAT_TYPES); +template +static NDArray lup_(LaunchContext* context, NDArray* input, NDArray* compound, + NDArray* permutation) { + const int rowNum = input->rows(); + const int columnNum = input->columns(); + + NDArray determinant = NDArrayFactory::create(1.f, context); + NDArray compoundMatrix = input->dup(); // copy + NDArray permutationMatrix( + input, false, context); // has same shape as input and contiguous strides + permutationMatrix.setIdentity(); + + T pivotValue; // = T(0.0); + int pivot; // = -1; + int swapCount = 0; + + for (int i = 0; i < rowNum; i++) { + pivotValue = T(0.0); + pivot = -1; + // PRAGMA_OMP_PARALLEL_FOR //_ARGS(firstprivate(pivot,pivotValue)) + for (int rowCounter = i; rowCounter < rowNum; rowCounter++) { + if (sd::math::nd4j_abs(compoundMatrix.t(rowCounter, i)) > pivotValue) { + pivotValue = sd::math::nd4j_abs(compoundMatrix.t(rowCounter, i)); + pivot = rowCounter; + } } - template - static void _invertUpperMatrix(NDArray* inputMatrix, NDArray* invertedMatrix) { - int n = inputMatrix->rows(); - invertedMatrix->setIdentity(); - - if (inputMatrix->isIdentityMatrix()) { // the inverse for I is I - return; + if (pivotValue > DataTypeUtils::min()) { + swapRows(&compoundMatrix, pivot, i); + swapRows(&permutationMatrix, pivot, i); + if (pivot != i) swapCount++; + + for (int j = i + 1; j < rowNum; j++) { + compoundMatrix.r(j, i) /= compoundMatrix.t(i, i); + // PRAGMA_OMP_PARALLEL_FOR + for (int k = i + 1; k < rowNum; k++) { + compoundMatrix.r(j, k) -= + compoundMatrix.t(j, i) * compoundMatrix.t(i, k); } - - auto invertDiagonals = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i += increment) - invertedMatrix->r(i, i) /= inputMatrix->t(i, i); - }; - - //PRAGMA_OMP_PARALLEL_FOR_IF(n > Environment::getInstance().elementwiseThreshold()) - auto invertUpDiagonals = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i += increment) - invertedMatrix->r(i, i + 1) -= (inputMatrix->t(i, i + 1) * invertedMatrix->t(i + 1, i + 1) / - inputMatrix->t(i, i)); - }; - - samediff::Threads::parallel_for(invertDiagonals, 0, n, 1); - samediff::Threads::parallel_for(invertUpDiagonals, 0, n - 1, 1); - -// PRAGMA_OMP_PARALLEL_FOR_SIMD - for (auto i = n - 2; i >= 0; i--) { - for (auto j = i + 2; j < n; j++) - for (auto k = i; k < n; k++) - invertedMatrix->r(i, j) -= ((invertedMatrix->t(k, j) * inputMatrix->t(i, k) / inputMatrix->t(i, i))); + } + } + } + + for (int e = 0; e < rowNum; e++) { + // nd4j_printf("Compound matrix diag %i %f.\n", e, (*compoundMatrix)(e, e)); + determinant *= compoundMatrix.e(e, e); + } + if (swapCount % 2) determinant = -determinant; + if (compound != nullptr) compound->assign(compoundMatrix); + if (permutation != nullptr) { + auto permutaionVector = NDArrayFactory::create( + 'c', {rowNum}, DataTypeUtils::fromT(), input->getContext()); + for (auto i = 0; i < rowNum; i++) { + for (auto j = 0; j < columnNum; j++) { + if (permutationMatrix.t(i, j) != 0) { + permutaionVector.template r(i) = j; } + } } - - BUILD_SINGLE_TEMPLATE(template void _invertUpperMatrix, (NDArray* inputMatrix, NDArray* invertedMatrix);, FLOAT_TYPES); - - void invertUpperMatrix(NDArray* inputMatrix, NDArray* invertedMatrix) { - BUILD_SINGLE_SELECTOR(inputMatrix->dataType(), _invertUpperMatrix, (inputMatrix, invertedMatrix), FLOAT_TYPES); + if (permutationMatrix.isSameShape(permutation)) + permutation->assign(permutationMatrix); + else if (permutation->isSameShape(permutaionVector)) { + permutation->assign(permutaionVector); } + } + return determinant; +} - - template - static NDArray lup_(LaunchContext *context, NDArray* input, NDArray* compound, NDArray* permutation) { - - const int rowNum = input->rows(); - const int columnNum = input->columns(); - - NDArray determinant = NDArrayFactory::create(1.f, context); - NDArray compoundMatrix = *input; // copy - NDArray permutationMatrix(input, false, context); // has same shape as input and contiguous strides - permutationMatrix.setIdentity(); - - T pivotValue; // = T(0.0); - int pivot; // = -1; - int swapCount = 0; - - for(int i = 0; i < rowNum; i++ ) { - pivotValue = T(0.0); - pivot = -1; - //PRAGMA_OMP_PARALLEL_FOR //_ARGS(firstprivate(pivot,pivotValue)) - for(int rowCounter = i; rowCounter < rowNum; rowCounter++ ) { - if (sd::math::nd4j_abs(compoundMatrix.t(rowCounter, i)) > pivotValue) { - pivotValue = sd::math::nd4j_abs(compoundMatrix.t(rowCounter, i)); - pivot = rowCounter; - } - } - - if( pivotValue > DataTypeUtils::min()) { - swapRows(&compoundMatrix, pivot, i); - swapRows(&permutationMatrix, pivot, i); - if (pivot != i) - swapCount++; - - for( int j = i + 1; j < rowNum; j++ ) { - compoundMatrix.r(j, i) /= compoundMatrix.t(i, i); - //PRAGMA_OMP_PARALLEL_FOR - for( int k = i + 1; k < rowNum; k++ ) { - compoundMatrix.r(j, k) -= compoundMatrix.t(j, i) * compoundMatrix.t(i, k); - } - } - } - } - - for (int e = 0; e < rowNum; e++) { - // nd4j_printf("Compound matrix diag %i %f.\n", e, (*compoundMatrix)(e, e)); - determinant *= compoundMatrix.e(e, e); - } - if (swapCount % 2) determinant = -determinant; - if (compound != nullptr) - compound->assign(compoundMatrix); - if (permutation != nullptr) { - auto permutaionVector = NDArrayFactory::create('c', {rowNum}, DataTypeUtils::fromT(), input->getContext()); - for (auto i = 0; i < rowNum; i++) { - for (auto j = 0; j < columnNum; j++) { - if (permutationMatrix.t(i, j) != 0) { - permutaionVector.template r(i) = j; - } - } - } - if (permutationMatrix.isSameShape(permutation)) - permutation->assign(permutationMatrix); - else if (permutation->isSameShape(permutaionVector)) { - permutation->assign(permutaionVector); - } - } - return determinant; +BUILD_DOUBLE_TEMPLATE(template NDArray lup_, + (LaunchContext * context, NDArray* input, NDArray* output, + NDArray* permutation), + FLOAT_TYPES, INDEXING_TYPES); +/* + * lu decomposition with naive algorithm with partial pivoting + * */ +template +static I argmaxCol(I column, T* compoundBuffer, Nd4jLong const* compoundShape) { + auto rowNum = shape::sizeAt(compoundShape, 0); + Nd4jLong xInitial[] = {column, column}; + auto xInitialIndex = shape::getOffset(compoundShape, xInitial, 0); + auto maxValue = T(0); // sd::math::nd4j_abs(compoundBuffer[xInitialIndex]); + auto result = -1; + // auto loop = PRAGMA_THREADS_FOR { + auto start = column, stop = rowNum, increment = 1; + for (auto rowCounter = start; rowCounter < stop; rowCounter++) { + Nd4jLong xPos[] = {rowCounter, column}; + auto xIndex = shape::getOffset(compoundShape, xPos, 0); + if (sd::math::nd4j_abs(compoundBuffer[xIndex]) > maxValue) { + maxValue = sd::math::nd4j_max(maxValue, + sd::math::nd4j_abs(compoundBuffer[xIndex])); + result = rowCounter; } + } + //}; + // samediff::Threads::parallel_for(loop, column, rowNum, 1); + return result; +} - BUILD_DOUBLE_TEMPLATE(template NDArray lup_, (LaunchContext *context, NDArray* input, NDArray* output, NDArray* permutation), FLOAT_TYPES, INDEXING_TYPES); - /* - * lu decomposition with naive algorithm with partial pivoting - * */ - template - static I argmaxCol(I column, T* compoundBuffer, Nd4jLong const* compoundShape) { - auto rowNum = shape::sizeAt(compoundShape, 0); - Nd4jLong xInitial[] = {column, column}; - auto xInitialIndex = shape::getOffset(compoundShape, xInitial, 0); - auto maxValue = T(0); //sd::math::nd4j_abs(compoundBuffer[xInitialIndex]); - auto result = -1; - //auto loop = PRAGMA_THREADS_FOR { - auto start = column, stop = rowNum, increment = 1; - for (auto rowCounter = start; rowCounter < stop; rowCounter++) { - Nd4jLong xPos[] = {rowCounter, column}; - auto xIndex = shape::getOffset(compoundShape, xPos, 0); - if (sd::math::nd4j_abs(compoundBuffer[xIndex]) > maxValue) { - maxValue = sd::math::nd4j_max(maxValue, sd::math::nd4j_abs(compoundBuffer[xIndex])); - result = rowCounter; - } - } - //}; - //samediff::Threads::parallel_for(loop, column, rowNum, 1); - return result; +template +void processColumns(int currentRow, int rowNum, T* compoundBuf, + Nd4jLong const* compoundShape) { + Nd4jLong xDiag[] = {currentRow, currentRow}; + auto diagIndex = shape::getOffset(compoundShape, xDiag, 0); + auto loop = PRAGMA_THREADS_FOR { + for (auto j = start; j < stop; j++) { + Nd4jLong xRow[] = {j, currentRow}; + auto rowIndex = shape::getOffset(compoundShape, xRow, 0); + compoundBuf[rowIndex] /= compoundBuf[diagIndex]; // output->t(i, i); + for (int k = currentRow + 1; k < rowNum; k++) { + Nd4jLong yRow[] = {j, k}; + Nd4jLong yCol[] = {currentRow, k}; + auto rowIndexY = shape::getOffset(compoundShape, yRow, 0); + auto colIndex = shape::getOffset(compoundShape, yCol, 0); + compoundBuf[rowIndexY] -= compoundBuf[rowIndex] * compoundBuf[colIndex]; + } } + }; + samediff::Threads::parallel_tad(loop, currentRow + 1, rowNum, 1); +} - template - void processColumns(int currentRow, int rowNum, T* compoundBuf, Nd4jLong const* compoundShape) { - Nd4jLong xDiag[] = {currentRow, currentRow}; - auto diagIndex = shape::getOffset(compoundShape, xDiag, 0); - auto loop = PRAGMA_THREADS_FOR { - for (auto j = start; j < stop; j++) { - Nd4jLong xRow[] = {j, currentRow}; - auto rowIndex = shape::getOffset(compoundShape, xRow, 0); - compoundBuf[rowIndex] /= compoundBuf[diagIndex]; //output->t(i, i); - for (int k = currentRow + 1; k < rowNum; k++) { - Nd4jLong yRow[] = {j, k}; - Nd4jLong yCol[] = {currentRow, k}; - auto rowIndexY = shape::getOffset(compoundShape, yRow, 0); - auto colIndex = shape::getOffset(compoundShape, yCol, 0); - compoundBuf[rowIndexY] -= compoundBuf[rowIndex] * compoundBuf[colIndex]; - } - } - }; - samediff::Threads::parallel_tad(loop, currentRow + 1, rowNum, 1); +template +static void doolitleLU(LaunchContext* context, NDArray* compound, + Nd4jLong rowNum) { + auto input = compound->dup(); + compound->nullify(); + + // Decomposing matrix into Upper and Lower + // triangular matrix + for (auto i = 0; i < rowNum; i++) { + // Upper Triangular + for (auto k = i; k < rowNum; k++) { + // Summation of L(i, j) * U(j, k) + int sum = 0; + for (int j = 0; j < i; j++) + sum += compound->t(i, j) * compound->t(j, k); + + // Evaluating U(i, k) + compound->r(i, k) = input.t(i, k) - sum; } - template - static void doolitleLU(LaunchContext* context, NDArray* compound, Nd4jLong rowNum) { - auto input = compound->dup(); - compound->nullify(); - - // Decomposing matrix into Upper and Lower - // triangular matrix - for (auto i = 0; i < rowNum; i++) { - - // Upper Triangular - for (auto k = i; k < rowNum; k++) { - - // Summation of L(i, j) * U(j, k) - int sum = 0; - for (int j = 0; j < i; j++) - sum += compound->t(i,j) * compound->t(j,k); - - // Evaluating U(i, k) - compound->r(i, k) = input.t(i, k) - sum; - } - - // Lower Triangular - for (int k = i + 1; k < rowNum; k++) { - // Summation of L(k, j) * U(j, i) - int sum = 0; - for (int j = 0; j < i; j++) - sum += compound->t(k,j) * compound->t(j, i); - - // Evaluating L(k, i) - compound->r(k, i) = (input.t(k, i) - sum) / compound->t(i,i); - } - } - } + // Lower Triangular + for (int k = i + 1; k < rowNum; k++) { + // Summation of L(k, j) * U(j, i) + int sum = 0; + for (int j = 0; j < i; j++) + sum += compound->t(k, j) * compound->t(j, i); - template - static void luNN_(LaunchContext *context, NDArray* compound, NDArray* permutation, Nd4jLong rowNum) { - - //const int rowNum = compound->rows(); -// const int columnNum = output->columns(); - if (permutation) { // LUP algorithm - permutation->linspace(0); - auto permutationBuf = permutation->bufferAsT(); //dataBuffer()->primaryAsT(); - auto compoundBuf = compound->bufferAsT(); - auto compoundShape = compound->shapeInfo(); - auto permutationShape = permutation->shapeInfo(); - for (auto i = 0; i < rowNum - 1; i++) { - auto pivotIndex = argmaxCol(i, compoundBuf, compoundShape); - if (pivotIndex < 0) { - throw std::runtime_error("helpers::luNN_: input matrix is singular."); - } - math::nd4j_swap(permutationBuf[shape::getIndexOffset(i, permutationShape)], - permutationBuf[shape::getIndexOffset(pivotIndex, permutationShape)]); - swapRows(compoundBuf, compoundShape, i, pivotIndex); - - processColumns(i, rowNum, compoundBuf, compoundShape); - } - } - else { // Doolitle algorithm with LU decomposition - doolitleLU(context, compound, rowNum); - } + // Evaluating L(k, i) + compound->r(k, i) = (input.t(k, i) - sum) / compound->t(i, i); } + } +} - template - static void lu_(LaunchContext * context, NDArray* input, NDArray* output, NDArray* permutationVectors) { - auto n = input->sizeAt(-1); - - output->assign(input); // fill up output tensor with zeros - ResultSet outputs = output->allTensorsAlongDimension({-2, -1}); - ResultSet permutations; - if (permutationVectors) - permutations = permutationVectors->allTensorsAlongDimension({-1}); - - auto loop = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - luNN_(context, outputs.at(i), permutationVectors?permutations.at(i):nullptr, n); - } - }; - samediff::Threads::parallel_for(loop, 0, outputs.size(), 1); +template +static void luNN_(LaunchContext* context, NDArray* compound, + NDArray* permutation, Nd4jLong rowNum) { + // const int rowNum = compound->rows(); + // const int columnNum = output->columns(); + if (permutation) { // LUP algorithm + permutation->linspace(0); + auto permutationBuf = + permutation->bufferAsT(); // dataBuffer()->primaryAsT(); + auto compoundBuf = compound->bufferAsT(); + auto compoundShape = compound->shapeInfo(); + auto permutationShape = permutation->shapeInfo(); + for (auto i = 0; i < rowNum - 1; i++) { + auto pivotIndex = argmaxCol(i, compoundBuf, compoundShape); + if (pivotIndex < 0) { + throw std::runtime_error("helpers::luNN_: input matrix is singular."); + } + math::nd4j_swap( + permutationBuf[shape::getIndexOffset(i, permutationShape)], + permutationBuf[shape::getIndexOffset(pivotIndex, permutationShape)]); + swapRows(compoundBuf, compoundShape, i, pivotIndex); + + processColumns(i, rowNum, compoundBuf, compoundShape); } + } else { // Doolitle algorithm with LU decomposition + doolitleLU(context, compound, rowNum); + } +} - void lu(LaunchContext *context, NDArray* input, NDArray* output, NDArray* permutation) { - BUILD_DOUBLE_SELECTOR(input->dataType(), permutation?permutation->dataType():DataType::INT32, lu_, (context, input, output, permutation), FLOAT_TYPES, INDEXING_TYPES); +template +static void lu_(LaunchContext* context, NDArray* input, NDArray* output, + NDArray* permutationVectors) { + auto n = input->sizeAt(-1); + + output->assign(input); // fill up output tensor with zeros + ResultSet outputs = output->allTensorsAlongDimension({-2, -1}); + ResultSet permutations; + if (permutationVectors) + permutations = permutationVectors->allTensorsAlongDimension({-1}); + + auto loop = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + luNN_(context, &outputs.at(i), + permutationVectors ? &permutations.at(i) : nullptr, n); } + }; + samediff::Threads::parallel_for(loop, 0, outputs.size(), 1); +} -// BUILD_DOUBLE_TEMPLATE(template NDArray lu_, (LaunchContext *context, NDArray* input, NDArray* output, NDArray* permutation), FLOAT_TYPES, INDEXING_TYPES); - - template - static int determinant_(LaunchContext *context, NDArray* input, NDArray* output) { +void lu(LaunchContext* context, NDArray* input, NDArray* output, + NDArray* permutation) { + BUILD_DOUBLE_SELECTOR(input->dataType(), + permutation ? permutation->dataType() : DataType::INT32, + lu_, (context, input, output, permutation), FLOAT_TYPES, + INDEXING_TYPES); +} - Nd4jLong n = input->sizeAt(-1); - Nd4jLong n2 = n * n; +// BUILD_DOUBLE_TEMPLATE(template NDArray lu_, (LaunchContext *context, +// NDArray* input, NDArray* output, NDArray* permutation), FLOAT_TYPES, +// INDEXING_TYPES); - auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, input->dataType(), context); //, block.getWorkspace()); +template +static int determinant_(LaunchContext* context, NDArray* input, + NDArray* output) { + Nd4jLong n = input->sizeAt(-1); + Nd4jLong n2 = n * n; + + auto matrix = + NDArrayFactory::create(input->ordering(), {n, n}, input->dataType(), + context); //, block.workspace()); + + for (int e = 0; e < output->lengthOf(); e++) { + for (int k = e * n2, row = 0; k < (e + 1) * n2; ++k, ++row) + matrix.p(row, input->e(k)); + output->p(e, lup_(context, &matrix, (NDArray*)nullptr, + (NDArray*)nullptr)); + } + + return Status::OK(); +} - for (int e = 0; e < output->lengthOf(); e++) { - for (int k = e * n2, row = 0; k < (e + 1) * n2; ++k, ++row) - matrix.p(row, input->e(k)); - output->p(e, lup_(context, &matrix, (NDArray*)nullptr, (NDArray*)nullptr)); - } +int determinant(sd::LaunchContext* context, NDArray* input, NDArray* output) { + BUILD_SINGLE_SELECTOR(input->dataType(), return determinant_, + (context, input, output), FLOAT_TYPES); +} - return Status::OK(); +template +int logAbsDeterminant_(LaunchContext* context, NDArray* input, + NDArray* output) { + Nd4jLong n = input->sizeAt(-1); + Nd4jLong n2 = n * n; + + NDArray matrix = + NDArrayFactory::create(input->ordering(), {n, n}, input->dataType(), + context); //, block.workspace()); + for (int e = 0; e < output->lengthOf(); e++) { + for (int k = e * n2, row = 0; k < (e + 1) * n2; ++k, ++row) { + matrix.p(row, input->e(k)); } + NDArray det = + lup_(context, &matrix, (NDArray*)nullptr, (NDArray*)nullptr); + if (det.e(0) != 0.f) + output->p(e, sd::math::nd4j_log(sd::math::nd4j_abs(det.t(0)))); + } - int determinant(sd::LaunchContext * context, NDArray* input, NDArray* output) { - BUILD_SINGLE_SELECTOR(input->dataType(), return determinant_, (context, input, output), FLOAT_TYPES); - } + return ND4J_STATUS_OK; +} -template - int logAbsDeterminant_(LaunchContext *context, NDArray* input, NDArray* output) { - - Nd4jLong n = input->sizeAt(-1); - Nd4jLong n2 = n * n; - - NDArray matrix = NDArrayFactory::create(input->ordering(), {n, n}, input->dataType(), context); //, block.getWorkspace()); - for (int e = 0; e < output->lengthOf(); e++) { - for (int k = e * n2, row = 0; k < (e + 1) * n2; ++k, ++row) { - matrix.p(row, input->e(k)); - } - NDArray det = lup_(context, &matrix, (NDArray*)nullptr, (NDArray*)nullptr); - if (det.e(0) != 0.f) - output->p(e, sd::math::nd4j_log(sd::math::nd4j_abs(det.t(0)))); - } +int logAbsDeterminant(sd::LaunchContext* context, NDArray* input, + NDArray* output) { + BUILD_SINGLE_SELECTOR(input->dataType(), return logAbsDeterminant_, + (context, input, output), FLOAT_TYPES); +} - return ND4J_STATUS_OK; +template +static int inverse_(LaunchContext* context, NDArray* input, NDArray* output) { + auto n = input->sizeAt(-1); + auto n2 = n * n; + auto totalCount = output->lengthOf() / n2; + + output->assign(0.f); // fill up output tensor with zeros + auto matrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT(), + context); //, block.workspace()); + auto compound = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT(), + context); //, block.workspace()); + auto permutation = + NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT(), context); + auto lowerMatrix = + NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT(), context); + auto upperMatrix = + NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT(), context); + + for (int e = 0; e < totalCount; e++) { + if (e) matrix.assign(0.f); + + for (int k = e * n2, row = 0; k < (e + 1) * n2; k++) { + matrix.p(row++, input->e(k)); } - - int logAbsDeterminant(sd::LaunchContext * context, NDArray* input, NDArray* output) { - BUILD_SINGLE_SELECTOR(input->dataType(), return logAbsDeterminant_, (context, input, output), FLOAT_TYPES); + T det = lup_(context, &matrix, &compound, &permutation) + .template e(0); + + // FIXME: and how this is going to work on float16? + if (sd::math::nd4j_abs(det) < T(0.000001)) { + nd4j_printf( + "matrix_inverse: The matrix %i has no inverse due determinant is " + "%lf. Quiting...\n", + e, det); + matrix.printIndexedBuffer("Wrong matrix"); + return ND4J_STATUS_VALIDATION; + } + lowerMatrix.setIdentity(); // set up U to identity matrix + for (int k = 1; k < n; + k++) { // and then put all values under main diagonal on to it + for (int j = 0; j < k; j++) + lowerMatrix.template r(k, j) = compound.template t(k, j); + } + upperMatrix.setIdentity(); // set up U to identity matrix + for (int k = 0; k < n; + k++) { // and then put all values under main diagonal on to it + for (int j = k; j < n; j++) + upperMatrix.template r(k, j) = compound.template t(k, j); } + invertUpperMatrix(&upperMatrix, &matrix); - template - static int inverse_(LaunchContext *context, NDArray* input, NDArray* output) { - - auto n = input->sizeAt(-1); - auto n2 = n * n; - auto totalCount = output->lengthOf() / n2; - - output->assign(0.f); // fill up output tensor with zeros - auto matrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT(), context); //, block.getWorkspace()); - auto compound = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT(), context); //, block.getWorkspace()); - auto permutation = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT(), context); - auto lowerMatrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT(), context); - auto upperMatrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT(), context); - - for (int e = 0; e < totalCount; e++) { - if (e) - matrix.assign(0.f); - - for (int k = e * n2, row = 0; k < (e + 1) * n2; k++) { - matrix.p(row++, input->e(k)); - } - T det = lup_(context, &matrix, &compound, &permutation).template e(0); - - // FIXME: and how this is going to work on float16? - if (sd::math::nd4j_abs(det) < T(0.000001)) { - nd4j_printf("matrix_inverse: The matrix %i has no inverse due determinant is %lf. Quiting...\n", e, det); - matrix.printIndexedBuffer("Wrong matrix"); - return ND4J_STATUS_VALIDATION; - } - lowerMatrix.setIdentity(); // set up U to identity matrix - for (int k = 1; k < n; k++) { // and then put all values under main diagonal on to it - for (int j = 0; j < k; j++) - lowerMatrix.template r(k, j) = compound.template t(k, j); - } - upperMatrix.setIdentity(); // set up U to identity matrix - for (int k = 0; k < n; k++) { // and then put all values under main diagonal on to it - for (int j = k; j < n; j++) - upperMatrix.template r(k, j) = compound.template t(k, j); - } - invertUpperMatrix(&upperMatrix, &matrix); - - invertLowerMatrix(&lowerMatrix, &upperMatrix); - - sd::MmulHelper::mmul(&matrix, &upperMatrix, &compound, 1.0, 0.0); - sd::MmulHelper::mmul(&compound, &permutation, &matrix, 1.0, 0.0); - for (int k = e * n2, row = 0; k < (e + 1) * n2; k++) { - output->r(k) = matrix.template t(row++); - } - } + invertLowerMatrix(&lowerMatrix, &upperMatrix); - return Status::OK(); + sd::MmulHelper::mmul(&matrix, &upperMatrix, &compound, 1.0, 0.0); + sd::MmulHelper::mmul(&compound, &permutation, &matrix, 1.0, 0.0); + for (int k = e * n2, row = 0; k < (e + 1) * n2; k++) { + output->r(k) = matrix.template t(row++); } + } - template - static int lowerInverse_(LaunchContext *context, NDArray* input, NDArray* output) { - - auto n = input->sizeAt(-1); - auto n2 = n * n; - auto totalCount = output->lengthOf() / n2; - - output->assign(0.f); // fill up output tensor with zeros - auto matrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT(), context); //, block.getWorkspace()); - auto compound = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT(), context); //, block.getWorkspace()); - auto permutation = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT(), context); - auto lowerMatrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT(), context); - auto upperMatrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT(), context); - -// auto batchLoop = PRAGMA_THREADS_FOR { - for (int e = 0; e < totalCount; e++) { - if (e) - matrix.assign(0.f); - - for (int k = e * n2, row = 0; k < (e + 1) * n2; k++) { - matrix.p(row++, input->e(k)); - } - T det = T(1.f); - for (auto i = 0; i < n; i++) { - det *= matrix. template t(i, i); - } - - // FIXME: and how this is going to work on float16? - if (sd::math::nd4j_abs(det) < T(0.000001)) { - nd4j_printf("matrix_inverse: The matrix %i has no inverse due determinant is %lf. Quiting...\n", e, det); - matrix.printIndexedBuffer("Wrong matrix"); - return ND4J_STATUS_VALIDATION; - } - lowerMatrix.nullify(); - invertLowerMatrix(&matrix, &lowerMatrix); - - for (int k = e * n2, row = 0; k < (e + 1) * n2; k++) { - output->r(k) = lowerMatrix.template t(row++); - } - } + return Status::OK(); +} - return Status::OK(); +template +static int lowerInverse_(LaunchContext* context, NDArray* input, + NDArray* output) { + auto n = input->sizeAt(-1); + auto n2 = n * n; + auto totalCount = output->lengthOf() / n2; + + output->assign(0.f); // fill up output tensor with zeros + auto matrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT(), + context); //, block.workspace()); + auto compound = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT(), + context); //, block.workspace()); + auto permutation = + NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT(), context); + auto lowerMatrix = + NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT(), context); + auto upperMatrix = + NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT(), context); + + // auto batchLoop = PRAGMA_THREADS_FOR { + for (int e = 0; e < totalCount; e++) { + if (e) matrix.assign(0.f); + + for (int k = e * n2, row = 0; k < (e + 1) * n2; k++) { + matrix.p(row++, input->e(k)); } - - template - static int upperInverse_(LaunchContext *context, NDArray* input, NDArray* output) { - - auto n = input->sizeAt(-1); - auto n2 = n * n; - - output->nullify(); // fill up output tensor with zeros -// auto matrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT(), context); //, block.getWorkspace()); -// auto lowerMatrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT(), context); -// auto upperMatrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT(), context); - auto inputPart = input->allTensorsAlongDimension({-2, -1}); - auto outputPart = output->allTensorsAlongDimension({-2, -1}); - auto totalCount = outputPart.size(); //lengthOf() / n2; - for (int e = 0; e < totalCount; e++) { - invertUpperMatrix(inputPart.at(e), outputPart.at(e)); - } - return Status::OK(); + T det = T(1.f); + for (auto i = 0; i < n; i++) { + det *= matrix.template t(i, i); } - int inverse(sd::LaunchContext * context, NDArray* input, NDArray* output) { - BUILD_SINGLE_SELECTOR(input->dataType(), return inverse_, (context, input, output), FLOAT_TYPES); + // FIXME: and how this is going to work on float16? + if (sd::math::nd4j_abs(det) < T(0.000001)) { + nd4j_printf( + "matrix_inverse: The matrix %i has no inverse due determinant is " + "%lf. Quiting...\n", + e, det); + matrix.printIndexedBuffer("Wrong matrix"); + return ND4J_STATUS_VALIDATION; } + lowerMatrix.nullify(); + invertLowerMatrix(&matrix, &lowerMatrix); - int lowerInverseFunctor(sd::LaunchContext * context, NDArray* input, NDArray* output) { - BUILD_SINGLE_SELECTOR(input->dataType(), return lowerInverse_, (context, input, output), FLOAT_TYPES); + for (int k = e * n2, row = 0; k < (e + 1) * n2; k++) { + output->r(k) = lowerMatrix.template t(row++); } + } - int upperInverseFunctor(sd::LaunchContext * context, NDArray* input, NDArray* output) { - BUILD_SINGLE_SELECTOR(input->dataType(), return upperInverse_, (context, input, output), FLOAT_TYPES); - } + return Status::OK(); +} - template - static bool checkCholeskyInput_(sd::LaunchContext * context, NDArray const* input) { - //std::unique_ptr matrix(NDArrayFactory::create_('c', {n, n}, input->dataType())); //, block.getWorkspace()); - ResultSet lastMatrixList = input->allTensorsAlongDimension({input->rankOf() - 2, input->rankOf()-1}); - for (size_t i = 0; i < lastMatrixList.size(); i++) { - auto thisMatrix = lastMatrixList.at(i); - // check for symmetric - for (Nd4jLong r = 0; r < thisMatrix->rows(); r++) - for (Nd4jLong c = 0; c < thisMatrix->columns(); c++) - if (sd::math::nd4j_abs(thisMatrix->e(r, c) - lastMatrixList.at(i)->e(c,r)) > DataTypeUtils::min()) return false; - - NDArray output = NDArrayFactory::create(0., context); - if (ND4J_STATUS_OK != determinant(context, thisMatrix, &output)) return false; - if (output.e(0) <= T(0)) return 0; - NDArray reversedMatrix(*thisMatrix); - if (ND4J_STATUS_OK != inverse(context, thisMatrix, &reversedMatrix)) return false; - if (ND4J_STATUS_OK != determinant(context, &reversedMatrix, &output)) return false; - if (output.e(0) <= T(0)) return 0; +template +static int upperInverse_(LaunchContext* context, NDArray* input, + NDArray* output) { + auto n = input->sizeAt(-1); + auto n2 = n * n; + + output->nullify(); // fill up output tensor with zeros + // auto matrix = NDArrayFactory::create('c', {n, n}, + // DataTypeUtils::fromT(), context); //, block.workspace()); auto + // lowerMatrix = NDArrayFactory::create('c', {n, n}, + // DataTypeUtils::fromT(), context); auto upperMatrix = + // NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT(), + // context); + auto inputPart = input->allTensorsAlongDimension({-2, -1}); + auto outputPart = output->allTensorsAlongDimension({-2, -1}); + auto totalCount = outputPart.size(); // lengthOf() / n2; + for (int e = 0; e < totalCount; e++) { + invertUpperMatrix(&inputPart.at(e), &outputPart.at(e)); + } + return Status::OK(); +} - } +int inverse(sd::LaunchContext* context, NDArray* input, NDArray* output) { + BUILD_SINGLE_SELECTOR(input->dataType(), return inverse_, + (context, input, output), FLOAT_TYPES); +} +int lowerInverseFunctor(sd::LaunchContext* context, NDArray* input, + NDArray* output) { + BUILD_SINGLE_SELECTOR(input->dataType(), return lowerInverse_, + (context, input, output), FLOAT_TYPES); +} - return true; - } +int upperInverseFunctor(sd::LaunchContext* context, NDArray* input, + NDArray* output) { + BUILD_SINGLE_SELECTOR(input->dataType(), return upperInverse_, + (context, input, output), FLOAT_TYPES); +} - bool checkCholeskyInput(sd::LaunchContext * context, NDArray const* input) { - BUILD_SINGLE_SELECTOR(input->dataType(), return checkCholeskyInput_, (context, input), FLOAT_TYPES); - } +template +static bool checkCholeskyInput_(sd::LaunchContext* context, + NDArray const* input) { + // std::unique_ptr matrix(NDArrayFactory::create_('c', {n, n}, + // input->dataType())); //, block.workspace()); + ResultSet lastMatrixList = input->allTensorsAlongDimension( + {input->rankOf() - 2, input->rankOf() - 1}); + for (size_t i = 0; i < lastMatrixList.size(); i++) { + auto thisMatrix = lastMatrixList.at(i); + // check for symmetric + for (Nd4jLong r = 0; r < thisMatrix.rows(); r++) + for (Nd4jLong c = 0; c < thisMatrix.columns(); c++) + if (sd::math::nd4j_abs(thisMatrix.e(r, c) - + lastMatrixList.at(i).e(c, r)) > + DataTypeUtils::min()) + return false; + + NDArray output = NDArrayFactory::create(0., context); + if (ND4J_STATUS_OK != determinant(context, &thisMatrix, &output)) + return false; + if (output.e(0) <= T(0)) return 0; + NDArray reversedMatrix = thisMatrix.dup(); + if (ND4J_STATUS_OK != inverse(context, &thisMatrix, &reversedMatrix)) + return false; + if (ND4J_STATUS_OK != determinant(context, &reversedMatrix, &output)) + return false; + if (output.e(0) <= T(0)) return 0; + } + + return true; +} - template - int cholesky_(LaunchContext *context, NDArray* input, NDArray* output, bool inplace) { - - auto n = input->sizeAt(-1); - auto n2 = n * n; - auto totalCount = output->lengthOf() / n2; - if (!inplace) - output->assign(0.f); // fill up output tensor with zeros only inplace=false - - std::unique_ptr matrix(NDArrayFactory::create_('c', {n, n}, input->dataType(), context)); //, block.getWorkspace()); - std::unique_ptr lowerMatrix(NDArrayFactory::create_('c',{n, n}, input->dataType(), context)); - - for (int e = 0; e < totalCount; e++) { - - // fill up matrix - for (int k = e * n2, l = 0; k < (e + 1) * n2; k++) { - matrix->p(l++, input->e(k)); - } - //if (e) // from the second loop need to zero matrix - lowerMatrix->assign(0.f); - - for (Nd4jLong col = 0; col < n; col++) { - for (Nd4jLong row = 0; row < col; row++) { - T rowSum = 0; - for (Nd4jLong k = 0; k < row; ++k) - rowSum += (lowerMatrix->e(col, k) * lowerMatrix->e(row, k)); - lowerMatrix->p(col, row, (matrix->e(row, col) - rowSum) / lowerMatrix->e(row, row)); - } - T diagonalSum = 0; - for (Nd4jLong k = 0; k < col; ++k) - diagonalSum += lowerMatrix->e(col, k) * lowerMatrix->e(col, k); - lowerMatrix->p(col, col, sd::math::nd4j_sqrt(matrix->e(col, col) - diagonalSum)); - //nd4j_printf("%i: ", col); - //lowerMatrix->printIndexedBuffer("Lower matrix"); - } - for (int k = e * n2, l = 0; k < (e + 1) * n2; k++) { - output->p(k, lowerMatrix->e(l++)); - } - } +bool checkCholeskyInput(sd::LaunchContext* context, NDArray const* input) { + BUILD_SINGLE_SELECTOR(input->dataType(), return checkCholeskyInput_, + (context, input), FLOAT_TYPES); +} - return ND4J_STATUS_OK; +template +int cholesky_(LaunchContext* context, NDArray* input, NDArray* output, + bool inplace) { + auto n = input->sizeAt(-1); + auto n2 = n * n; + auto totalCount = output->lengthOf() / n2; + if (!inplace) + output->assign(0.f); // fill up output tensor with zeros only inplace=false + + std::unique_ptr matrix(NDArrayFactory::create_( + 'c', {n, n}, input->dataType(), context)); //, block.workspace()); + std::unique_ptr lowerMatrix( + NDArrayFactory::create_('c', {n, n}, input->dataType(), context)); + + for (int e = 0; e < totalCount; e++) { + // fill up matrix + for (int k = e * n2, l = 0; k < (e + 1) * n2; k++) { + matrix->p(l++, input->e(k)); } - - int cholesky(sd::LaunchContext * context, NDArray* input, NDArray* output, bool inplace) { - BUILD_SINGLE_SELECTOR(input->dataType(), return cholesky_, (context, input, output, inplace), FLOAT_TYPES); + // if (e) // from the second loop need to zero matrix + lowerMatrix->assign(0.f); + + for (Nd4jLong col = 0; col < n; col++) { + for (Nd4jLong row = 0; row < col; row++) { + T rowSum = 0; + for (Nd4jLong k = 0; k < row; ++k) + rowSum += (lowerMatrix->e(col, k) * lowerMatrix->e(row, k)); + lowerMatrix->p(col, row, + (matrix->e(row, col) - rowSum) / + lowerMatrix->e(row, row)); + } + T diagonalSum = 0; + for (Nd4jLong k = 0; k < col; ++k) + diagonalSum += lowerMatrix->e(col, k) * lowerMatrix->e(col, k); + lowerMatrix->p( + col, col, + sd::math::nd4j_sqrt(matrix->e(col, col) - diagonalSum)); + // nd4j_printf("%i: ", col); + // lowerMatrix->printIndexedBuffer("Lower matrix"); } - - template - int logdetFunctor_(LaunchContext *context, NDArray* input, NDArray* output) { - auto tempOutput = input->dup(); - int res = cholesky_(context, input, &tempOutput, false); - if (res != ND4J_STATUS_OK) - return res; - auto n = input->sizeAt(-1); - auto totalCount = output->lengthOf(); - std::vector d(n); - ResultSet matricies = tempOutput.allTensorsAlongDimension({input->rankOf()-2, input->rankOf() - 1}); - - for (Nd4jLong e = 0; e < totalCount; e++) { - for (size_t i = 0; i < n; ++i) - output->r(e) += sd::math::nd4j_log(sd::math::nd4j_pow(matricies.at(e)->t(i, i), T(2))); - } - return ND4J_STATUS_OK; + for (int k = e * n2, l = 0; k < (e + 1) * n2; k++) { + output->p(k, lowerMatrix->e(l++)); } + } - int logdetFunctor(sd::LaunchContext * context, NDArray* input, NDArray* output) { - BUILD_SINGLE_SELECTOR(input->dataType(), return logdetFunctor_, (context, input, output), FLOAT_TYPES); - } + return ND4J_STATUS_OK; +} - int lup(sd::LaunchContext * context, NDArray* input, NDArray* compound, NDArray* permutation) { - BUILD_DOUBLE_SELECTOR(input->dataType(), permutation->dataType(), lup_, (context, input, compound, permutation), FLOAT_NATIVE, INDEXING_TYPES); - return Status::OK(); - } +int cholesky(sd::LaunchContext* context, NDArray* input, NDArray* output, + bool inplace) { + BUILD_SINGLE_SELECTOR(input->dataType(), return cholesky_, + (context, input, output, inplace), FLOAT_TYPES); +} +template +int logdetFunctor_(LaunchContext* context, NDArray* input, NDArray* output) { + auto tempOutput = input->dup(); + int res = cholesky_(context, input, &tempOutput, false); + if (res != ND4J_STATUS_OK) return res; + auto n = input->sizeAt(-1); + auto totalCount = output->lengthOf(); + std::vector d(n); + ResultSet matricies = tempOutput.allTensorsAlongDimension( + {input->rankOf() - 2, input->rankOf() - 1}); + + for (Nd4jLong e = 0; e < totalCount; e++) { + for (size_t i = 0; i < n; ++i) + output->r(e) += sd::math::nd4j_log( + sd::math::nd4j_pow(matricies.at(e).t(i, i), T(2))); + } + return ND4J_STATUS_OK; } + +int logdetFunctor(sd::LaunchContext* context, NDArray* input, NDArray* output) { + BUILD_SINGLE_SELECTOR(input->dataType(), return logdetFunctor_, + (context, input, output), FLOAT_TYPES); } + +int lup(sd::LaunchContext* context, NDArray* input, NDArray* compound, + NDArray* permutation) { + BUILD_DOUBLE_SELECTOR(input->dataType(), permutation->dataType(), lup_, + (context, input, compound, permutation), FLOAT_NATIVE, + INDEXING_TYPES); + return Status::OK(); } + +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/matrixSetDiag.cpp b/libnd4j/include/ops/declarable/helpers/cpu/matrixSetDiag.cpp index 443048c5676a..07ce55dc9af3 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/matrixSetDiag.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/matrixSetDiag.cpp @@ -19,61 +19,64 @@ // #include -#include #include +#include namespace sd { namespace ops { namespace helpers { - ////////////////////////////////////////////////////////////////////////// -template -void matrixSetDiag_(const NDArray& input, const NDArray& diagonal, NDArray& output, const bool zeroPad) { - - // input and output are the same array (x == z) when zeroPad = true - // xRank = zRank, xRank = yRank + 1 - // xLen = zLen - - const T* x = input.bufferAsT(); - const T* y = diagonal.bufferAsT(); - T* z = output.bufferAsT(); - - const Nd4jLong* xShapeInfo = input.shapeInfo(); - const Nd4jLong* yShapeInfo = diagonal.shapeInfo(); - const Nd4jLong* zShapeInfo = output.shapeInfo(); - - const bool areSameOffsets = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo); // shapes are definitely the same, but strides might not - - const int xRank = input.rankOf(); - const auto xLen = input.lengthOf(); - - auto func = PRAGMA_THREADS_FOR { - - int coords[MAX_RANK]; - - for (Nd4jLong i = 0; i < xLen; ++i) { - - shape::index2coordsCPU(start, i, xShapeInfo, coords); - - const auto xOffset = shape::getOffset(xShapeInfo, coords); - const auto zOffset = areSameOffsets ? xOffset : shape::getOffset(zShapeInfo, coords); - - // condition to be on diagonal of innermost matrix - if (coords[xRank - 2] == coords[xRank - 1]) - z[zOffset] = y[shape::getOffset(yShapeInfo, coords)]; - else - z[zOffset] = zeroPad ? static_cast(0) : x[xOffset]; - } - }; - samediff::Threads::parallel_for(func, 0, xLen); +template +void matrixSetDiag_(const NDArray& input, const NDArray& diagonal, + NDArray& output, const bool zeroPad) { + // input and output are the same array (x == z) when zeroPad = true + // xRank = zRank, xRank = yRank + 1 + // xLen = zLen + + const T* x = input.bufferAsT(); + const T* y = diagonal.bufferAsT(); + T* z = output.bufferAsT(); + + const Nd4jLong* xShapeInfo = input.shapeInfo(); + const Nd4jLong* yShapeInfo = diagonal.shapeInfo(); + const Nd4jLong* zShapeInfo = output.shapeInfo(); + + const bool areSameOffsets = shape::haveSameShapeAndStrides( + xShapeInfo, + zShapeInfo); // shapes are definitely the same, but strides might not + + const int xRank = input.rankOf(); + const auto xLen = input.lengthOf(); + + auto func = PRAGMA_THREADS_FOR { + int coords[MAX_RANK]; + + for (Nd4jLong i = 0; i < xLen; ++i) { + shape::index2coordsCPU(start, i, xShapeInfo, coords); + + const auto xOffset = shape::getOffset(xShapeInfo, coords); + const auto zOffset = + areSameOffsets ? xOffset : shape::getOffset(zShapeInfo, coords); + + // condition to be on diagonal of innermost matrix + if (coords[xRank - 2] == coords[xRank - 1]) + z[zOffset] = y[shape::getOffset(yShapeInfo, coords)]; + else + z[zOffset] = zeroPad ? static_cast(0) : x[xOffset]; + } + }; + samediff::Threads::parallel_for(func, 0, xLen); } ////////////////////////////////////////////////////////////////////////// -void matrixSetDiag(sd::LaunchContext* context, const NDArray& input, const NDArray& diagonal, NDArray& output, const bool zeroPad) { - BUILD_SINGLE_SELECTOR(input.dataType(), matrixSetDiag_, (input, diagonal, output, zeroPad), LIBND4J_TYPES); +void matrixSetDiag(sd::LaunchContext* context, const NDArray& input, + const NDArray& diagonal, NDArray& output, + const bool zeroPad) { + BUILD_SINGLE_SELECTOR(input.dataType(), matrixSetDiag_, + (input, diagonal, output, zeroPad), LIBND4J_TYPES); } -} -} -} \ No newline at end of file +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/matrix_band.cpp b/libnd4j/include/ops/declarable/helpers/cpu/matrix_band.cpp index d83f0dab94de..0e2d5b2ffff0 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/matrix_band.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/matrix_band.cpp @@ -23,51 +23,63 @@ namespace sd { namespace ops { namespace helpers { - template - void matrixBandPart_(NDArray* input, NDArray* output, Nd4jLong lowerBand, Nd4jLong upperBand) { - // TO DO: retrieve all 2D submatricies with last dimensions and process them with given bands - Nd4jLong M = input->sizeAt(-2); - Nd4jLong N = input->sizeAt(-1); - Nd4jLong lastDim = input->rankOf() - 1; - Nd4jLong preLastDim = input->rankOf() - 2; - ResultSet listOut = output->allTensorsAlongDimension({(int)preLastDim, (int)lastDim}); - ResultSet listDiag = input->allTensorsAlongDimension({(int)preLastDim, (int)lastDim}); - for (Nd4jLong e = 0; e < static_cast(listOut.size()); ++e) { - NDArray* inputMatrix = listDiag.at(e); - NDArray* outputMatrix = listOut.at(e); - if (outputMatrix != inputMatrix) // if not inplace - outputMatrix->assign(inputMatrix); - if (lowerBand >= 0) { - for (Nd4jLong row = 0; row < inputMatrix->rows(); ++row) { - for (Nd4jLong col = 0; col < row; ++col) { - if ((row - col) > lowerBand) - outputMatrix->p(row, col, 0.); -// else - // (*outputMatrix)(row, col) = (*inputMatrix)(row, col); - } -// in_band(m, n) = (num_lower < 0 || (m-n) <= num_lower)) && (num_upper < 0 || (n-m) <= num_upper). - } - } - if (upperBand >= 0) { - for (Nd4jLong col = 0; col < inputMatrix->columns(); ++col) { - for (Nd4jLong row = 0; row < col; ++row) { - if ((col - row) > upperBand) - outputMatrix->p(row, col, 0.); -// else - // (*outputMatrix)(row, col) = (*inputMatrix)(row, col); - } -// in_band(m, n) = (num_lower < 0 || (m-n) <= num_lower)) && (num_upper < 0 || (n-m) <= num_upper). - } - - } +template +void matrixBandPart_(NDArray* input, NDArray* output, Nd4jLong lowerBand, + Nd4jLong upperBand) { + // TO DO: retrieve all 2D submatricies with last dimensions and process them + // with given bands + Nd4jLong M = input->sizeAt(-2); + Nd4jLong N = input->sizeAt(-1); + Nd4jLong lastDim = input->rankOf() - 1; + Nd4jLong preLastDim = input->rankOf() - 2; + ResultSet listOut = + output->allTensorsAlongDimension({(int)preLastDim, (int)lastDim}); + ResultSet listDiag = + input->allTensorsAlongDimension({(int)preLastDim, (int)lastDim}); + for (Nd4jLong e = 0; e < static_cast(listOut.size()); ++e) { + auto inputMatrix = listDiag.at(e); + auto outputMatrix = listOut.at(e); + if (outputMatrix.platformBuffer() != + inputMatrix.platformBuffer()) // if not inplace + outputMatrix.assign(inputMatrix); + if (lowerBand >= 0) { + for (Nd4jLong row = 0; row < inputMatrix.rows(); ++row) { + for (Nd4jLong col = 0; col < row; ++col) { + if ((row - col) > lowerBand) outputMatrix.p(row, col, 0.); + // else + // (*outputMatrix)(row, col) = + // (*inputMatrix)(row, col); } + // in_band(m, n) = (num_lower < 0 || (m-n) <= + // num_lower)) && (num_upper < 0 || (n-m) <= + // num_upper). + } } - - void matrixBandPart(sd::LaunchContext * context, NDArray* input, NDArray* output, Nd4jLong lowerBand, Nd4jLong upperBand) { - BUILD_SINGLE_SELECTOR(input->dataType(), matrixBandPart_, (input, output, lowerBand, upperBand), FLOAT_TYPES); + if (upperBand >= 0) { + for (Nd4jLong col = 0; col < inputMatrix.columns(); ++col) { + for (Nd4jLong row = 0; row < col; ++row) { + if ((col - row) > upperBand) outputMatrix.p(row, col, 0.); + // else + // (*outputMatrix)(row, col) = + // (*inputMatrix)(row, col); + } + // in_band(m, n) = (num_lower < 0 || (m-n) <= + // num_lower)) && (num_upper < 0 || (n-m) <= + // num_upper). + } } - BUILD_SINGLE_TEMPLATE(template void matrixBandPart_, (NDArray* input, NDArray* output, Nd4jLong lowerBand, Nd4jLong upperBand), FLOAT_TYPES); -} -} + } } +void matrixBandPart(sd::LaunchContext* context, NDArray* input, NDArray* output, + Nd4jLong lowerBand, Nd4jLong upperBand) { + BUILD_SINGLE_SELECTOR(input->dataType(), matrixBandPart_, + (input, output, lowerBand, upperBand), FLOAT_TYPES); +} +BUILD_SINGLE_TEMPLATE(template void matrixBandPart_, + (NDArray * input, NDArray* output, Nd4jLong lowerBand, + Nd4jLong upperBand), + FLOAT_TYPES); +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/matrix_diag_part.cpp b/libnd4j/include/ops/declarable/helpers/cpu/matrix_diag_part.cpp index 3271dc110cab..137d558bc541 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/matrix_diag_part.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/matrix_diag_part.cpp @@ -19,49 +19,52 @@ // #include -#include -#include #include +#include +#include namespace sd { namespace ops { namespace helpers { - ////////////////////////////////////////////////////////////////////////// // Returns a batched matrix tensor with new batched diagonal values. -// for detailed explanations please take a look on web page: https://www.tensorflow.org/api_docs/python/tf/matrix_set_diag +// for detailed explanations please take a look on web page: +// https://www.tensorflow.org/api_docs/python/tf/matrix_set_diag template int _matrixDiagPart(const NDArray* input, NDArray* output) { + auto listOut = output->allTensorsAlongDimension({output->rankOf() - 1}); + auto listDiag = input->allTensorsAlongDimension( + {input->rankOf() - 2, input->rankOf() - 1}); - auto listOut = output->allTensorsAlongDimension({output->rankOf() - 1}); - auto listDiag = input->allTensorsAlongDimension({input->rankOf() - 2, input->rankOf() - 1}); + if (listOut.size() != listDiag.size()) { + nd4j_printf("matrix_diag_part: Input matrix has wrong shape.", ""); + return ND4J_STATUS_VALIDATION; + } + int lastDimension = sd::math::nd4j_min(input->sizeAt(-2), input->sizeAt(-1)); + // TODO: tune this properlys + int lO = listOut.size(); - if (listOut.size() != listDiag. size()) { - nd4j_printf("matrix_diag_part: Input matrix has wrong shape.", ""); - return ND4J_STATUS_VALIDATION; - } - int lastDimension = sd::math::nd4j_min(input->sizeAt(-2), input->sizeAt(-1)); - // TODO: tune this properlys - int lO = listOut.size(); + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) + for (int j = 0; j < lastDimension; ++j) + listOut.at(i).p(j, listDiag.at(i).e(j, j)); + }; - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) - for (int j = 0; j < lastDimension; ++j) - listOut.at(i)->p(j, listDiag.at(i)->e(j, j)); - }; + samediff::Threads::parallel_tad(func, 0, lO); - samediff::Threads::parallel_tad(func, 0, lO); - - return Status::OK(); + return Status::OK(); } - int matrixDiagPart(sd::LaunchContext * context, const NDArray* input, NDArray* output) { - BUILD_SINGLE_SELECTOR(input->dataType(), return _matrixDiagPart, (input, output), LIBND4J_TYPES); - } +int matrixDiagPart(sd::LaunchContext* context, const NDArray* input, + NDArray* output) { + BUILD_SINGLE_SELECTOR(input->dataType(), return _matrixDiagPart, + (input, output), LIBND4J_TYPES); +} - BUILD_SINGLE_TEMPLATE(template int _matrixDiagPart, (const NDArray* input, NDArray* output), LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template int _matrixDiagPart, + (const NDArray* input, NDArray* output), LIBND4J_TYPES); -} -} -} \ No newline at end of file +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/max_pooling.cpp b/libnd4j/include/ops/declarable/helpers/cpu/max_pooling.cpp index ebb9d53fa7f0..0321cde1cff4 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/max_pooling.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/max_pooling.cpp @@ -18,64 +18,71 @@ // @author raver119@gmail.com // -#include #include - +#include namespace sd { namespace ops { namespace helpers { - template - static void maxPoolingFunctor_(sd::graph::Context& block, NDArray* input, NDArray* values, std::vector const& params, NDArray* indices) { - - int kY = params[0]; - int kX = params[1]; +template +static void maxPoolingFunctor_(sd::graph::Context& block, NDArray* input, + NDArray* values, std::vector const& params, + NDArray* indices) { + int kY = params[0]; + int kX = params[1]; - int sY = params[2]; - int sX = params[3]; + int sY = params[2]; + int sX = params[3]; - int pY = params[4]; - int pX = params[5]; + int pY = params[4]; + int pX = params[5]; - int dY = params[6]; - int dX = params[7]; + int dY = params[6]; + int dX = params[7]; - int oY = 0; - int oX = 0; + int oY = 0; + int oX = 0; - const int bSize = input->sizeAt(0); - const int inD = input->sizeAt(1); - const int inY = input->sizeAt(2); - const int inX = input->sizeAt(3); + const int bSize = input->sizeAt(0); + const int inD = input->sizeAt(1); + const int inY = input->sizeAt(2); + const int inX = input->sizeAt(3); - const bool isSameMode = params[8] != 0; + const bool isSameMode = params[8] != 0; - ConvolutionUtils::calcOutSizePool2D(oY, oX, kY, kX, sY, sX, pY, pX, dY, dX, inY, inX, isSameMode); + ConvolutionUtils::calcOutSizePool2D(oY, oX, kY, kX, sY, sX, pY, pX, dY, dX, + inY, inX, isSameMode); - if (isSameMode) - ConvolutionUtils::calcPadding2D(pY, pX, oY, oX, inY, inX, params[0], params[1], params[2], params[3], params[6], params[7]); + if (isSameMode) + ConvolutionUtils::calcPadding2D(pY, pX, oY, oX, inY, inX, params[0], + params[1], params[2], params[3], params[6], + params[7]); - // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - poolingMode; 9 - divisor; - ConvolutionUtils::pooling2d(block, *input, *values, kY, kX, sY, sX, pY, pX, dY, dX, PoolingType::MAX_POOL, 1); - - if (nullptr != indices) { - // for max_pool_with_argmax - int total = input->lengthOf(); - int part = total / bSize; - - for (int k = 0; k < total; ) - for (int i = 0; i < part; i++) { - indices->p(k++, i); - } - } + // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad + // Height/Width; 6,7 - dilation Height/Width; 8 - poolingMode; 9 - divisor; + ConvolutionUtils::pooling2d(block, *input, *values, kY, kX, sY, sX, pY, pX, + dY, dX, PoolingType::MAX_POOL, 1); - } - - void maxPoolingFunctor(sd::LaunchContext * context, sd::graph::Context& block, NDArray* input, NDArray* values, std::vector const& params, NDArray* indices) { - BUILD_SINGLE_SELECTOR(input->dataType(), maxPoolingFunctor_, (block, input, values, params, indices), LIBND4J_TYPES); - } + if (nullptr != indices) { + // for max_pool_with_argmax + int total = input->lengthOf(); + int part = total / bSize; + for (int k = 0; k < total;) + for (int i = 0; i < part; i++) { + indices->p(k++, i); + } + } } + +void maxPoolingFunctor(sd::LaunchContext* context, sd::graph::Context& block, + NDArray* input, NDArray* values, + std::vector const& params, NDArray* indices) { + BUILD_SINGLE_SELECTOR(input->dataType(), maxPoolingFunctor_, + (block, input, values, params, indices), LIBND4J_TYPES); } -} \ No newline at end of file + +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/merge.cpp b/libnd4j/include/ops/declarable/helpers/cpu/merge.cpp index 2a0c5af95f92..c06b40e5c033 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/merge.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/merge.cpp @@ -20,259 +20,269 @@ // @author Oleh Semeniv (oleg.semeniv@gmail.com) // -#include #include +#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { - ////////////////////////////////////////////////////////////////////////// -template -static void mergeMaxIndex_(const std::vector& inArrs, NDArray& output) { - - const Nd4jLong numArgs = inArrs.size(); - auto x = inArrs[0]; - - auto func = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) { - X max = -DataTypeUtils::max(); - Z idx = static_cast(0); - - for (Nd4jLong i = 0; i < numArgs; i++) { - X v = inArrs[i]->t(e); - if (v > max) { - max = v; - idx = static_cast(i); - } - } - - output.r(e) = static_cast(idx); +template +static void mergeMaxIndex_(const std::vector& inArrs, + NDArray& output) { + const Nd4jLong numArgs = inArrs.size(); + auto x = inArrs[0]; + + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + X max = -DataTypeUtils::max(); + Z idx = static_cast(0); + + for (Nd4jLong i = 0; i < numArgs; i++) { + X v = inArrs[i]->t(e); + if (v > max) { + max = v; + idx = static_cast(i); } - }; + } - samediff::Threads::parallel_for(func, 0, x->lengthOf()); -} + output.r(e) = static_cast( idx); + } + }; -void mergeMaxIndex(sd::LaunchContext * context, const std::vector& inArrs, NDArray& output) { - BUILD_DOUBLE_SELECTOR(inArrs[0]->dataType(), output.dataType(), mergeMaxIndex_, (inArrs, output), LIBND4J_TYPES, INDEXING_TYPES); + samediff::Threads::parallel_for(func, 0, x->lengthOf()); } +void mergeMaxIndex(sd::LaunchContext* context, + const std::vector& inArrs, NDArray& output) { + BUILD_DOUBLE_SELECTOR(inArrs[0]->dataType(), output.dataType(),mergeMaxIndex_, (inArrs, output), + LIBND4J_TYPES, INDEXING_TYPES); +} ////////////////////////////////////////////////////////////////////////// -template -static void mergeMax_(const std::vector& inArrs, NDArray& output) { - - const Nd4jLong numArgs = inArrs.size(); - auto x = inArrs[0]; - - auto func = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) { - T max = -DataTypeUtils::max(); - for (Nd4jLong i = 0; i < numArgs; i++) { - T v = inArrs[i]->e(e); - if (v > max) - max = v; - } - output.p(e, max); - } - }; +template +static void mergeMax_(const std::vector& inArrs, + NDArray& output) { + const Nd4jLong numArgs = inArrs.size(); + auto x = inArrs[0]; + + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + T max = -DataTypeUtils::max(); + for (Nd4jLong i = 0; i < numArgs; i++) { + T v = inArrs[i]->e(e); + if (v > max) max = v; + } + output.p(e, max); + } + }; - samediff::Threads::parallel_for(func, 0, x->lengthOf()); + samediff::Threads::parallel_for(func, 0, x->lengthOf()); } -void mergeMax(sd::LaunchContext * context, const std::vector& inArrs, NDArray& output) { - BUILD_SINGLE_SELECTOR(output.dataType(), mergeMax_, (inArrs, output), LIBND4J_TYPES); +void mergeMax(sd::LaunchContext* context, + const std::vector& inArrs, NDArray& output) { + BUILD_SINGLE_SELECTOR(output.dataType(), mergeMax_, (inArrs, output), + LIBND4J_TYPES); } - ////////////////////////////////////////////////////////////////////////// -template -static void mergeMaxBp_(const std::vector& inArrs, std::vector& outArrs) { - - // outArrs.size() == inArrs.size() - 1 - const Nd4jLong numArgs = outArrs.size(); - // last array is gradient - const auto gradient = inArrs[numArgs]->bufferAsT(); - auto length = inArrs[numArgs]->lengthOf(); +template +static void mergeMaxBp_(const std::vector& inArrs, + std::vector& outArrs) { + // outArrs.size() == inArrs.size() - 1 + const Nd4jLong numArgs = outArrs.size(); + // last array is gradient + const auto gradient = inArrs[numArgs]->bufferAsT(); + auto length = inArrs[numArgs]->lengthOf(); - bool bSameOrderAndEws1 = (1 == inArrs[numArgs]->ews()); + bool bSameOrderAndEws1 = (1 == inArrs[numArgs]->ews()); - if (bSameOrderAndEws1) { - auto gradOrdering = inArrs[numArgs]->ordering(); + if (bSameOrderAndEws1) { + auto gradOrdering = inArrs[numArgs]->ordering(); - for (int i = 0; i < numArgs; ++i) { - bSameOrderAndEws1 &= (gradOrdering == inArrs[i]->ordering()); - bSameOrderAndEws1 &= (1 == inArrs[i]->ews()); - bSameOrderAndEws1 &= (gradOrdering == outArrs[i]->ordering()); - bSameOrderAndEws1 &= (1 == outArrs[i]->ews()); - } - } - - - if(bSameOrderAndEws1){ - auto func = PRAGMA_THREADS_FOR{ - for (auto e = start; e < stop; e++) { - T max = -DataTypeUtils::max(); - Nd4jLong nMaxIndex = 0; - for (Nd4jLong i = 0; i < numArgs; i++) { - const T* v = inArrs[i]->bufferAsT(); - if (v[e] > max) { - max = v[e]; - nMaxIndex = i; - } - } - T* z = outArrs[nMaxIndex]->bufferAsT(); - z[e] = gradient[e]; - } - }; - - samediff::Threads::parallel_for(func, 0, length); - return; - } - - auto gradShape = inArrs[numArgs]->shapeInfo(); - std::vector vbSameShaepeAndStrides(numArgs); for (int i = 0; i < numArgs; ++i) { - vbSameShaepeAndStrides[i] = shape::haveSameShapeAndStrides(gradShape, inArrs[i]->shapeInfo()); + bSameOrderAndEws1 &= (gradOrdering == inArrs[i]->ordering()); + bSameOrderAndEws1 &= (1 == inArrs[i]->ews()); + bSameOrderAndEws1 &= (gradOrdering == outArrs[i]->ordering()); + bSameOrderAndEws1 &= (1 == outArrs[i]->ews()); } + } - auto func = PRAGMA_THREADS_FOR{ - - int coords[MAX_RANK]; - for (auto e = start; e < stop; e++) { - - shape::index2coordsCPU(start, e, gradShape, coords); - - const auto gradOffset = shape::getOffset(gradShape, coords); - - T max = -DataTypeUtils::max(); - Nd4jLong nMaxIndex = 0; - - for (Nd4jLong i = 0; i < numArgs; i++) { - - const auto xOffset = vbSameShaepeAndStrides[i] ? gradOffset : shape::getOffset(inArrs[i]->shapeInfo(), coords); - const T* v = inArrs[i]->bufferAsT(); - if (v[xOffset] > max) { - max = v[xOffset]; - nMaxIndex = i; - } - } - - const auto zOffset = vbSameShaepeAndStrides[nMaxIndex] ? gradOffset : shape::getOffset(outArrs[nMaxIndex]->shapeInfo(), coords); - - T* z = outArrs[nMaxIndex]->bufferAsT(); - z[zOffset] = gradient[gradOffset]; - } + if (bSameOrderAndEws1) { + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + T max = -DataTypeUtils::max(); + Nd4jLong nMaxIndex = 0; + for (Nd4jLong i = 0; i < numArgs; i++) { + const T* v = inArrs[i]->bufferAsT(); + if (v[e] > max) { + max = v[e]; + nMaxIndex = i; + } + } + T* z = outArrs[nMaxIndex]->bufferAsT(); + z[e] = gradient[e]; + } }; samediff::Threads::parallel_for(func, 0, length); return; -} - -void mergeMaxBp(sd::LaunchContext* context, const std::vector& inArrs, std::vector& outArrs) { - BUILD_SINGLE_SELECTOR(outArrs[0]->dataType(), mergeMaxBp_, (inArrs, outArrs), LIBND4J_TYPES); -} + } + + auto gradShape = inArrs[numArgs]->shapeInfo(); + std::vector vbSameShaepeAndStrides(numArgs); + for (int i = 0; i < numArgs; ++i) { + vbSameShaepeAndStrides[i] = + shape::haveSameShapeAndStrides(gradShape, inArrs[i]->shapeInfo()); + } + + auto func = PRAGMA_THREADS_FOR { + int coords[MAX_RANK]; + for (auto e = start; e < stop; e++) { + shape::index2coordsCPU(start, e, gradShape, coords); + + const auto gradOffset = shape::getOffset(gradShape, coords); + + T max = -DataTypeUtils::max(); + Nd4jLong nMaxIndex = 0; + + for (Nd4jLong i = 0; i < numArgs; i++) { + const auto xOffset = + vbSameShaepeAndStrides[i] + ? gradOffset + : shape::getOffset(inArrs[i]->shapeInfo(), coords); + const T* v = inArrs[i]->bufferAsT(); + if (v[xOffset] > max) { + max = v[xOffset]; + nMaxIndex = i; + } + } -////////////////////////////////////////////////////////////////////////// -template -static void mergeAvg_(const std::vector& inArrs, NDArray& output) { - const Nd4jLong numArgs = inArrs.size(); - const T factor = 1.f / numArgs; - auto x = inArrs[0]; + const auto zOffset = + vbSameShaepeAndStrides[nMaxIndex] + ? gradOffset + : shape::getOffset(outArrs[nMaxIndex]->shapeInfo(), coords); - auto func = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) { - T sum = 0.; - for (Nd4jLong i = 0; i < numArgs; i++) { - T v = inArrs[i]->e(e); - sum += v; - } - output.p(e, sum * factor); - } - }; + T* z = outArrs[nMaxIndex]->bufferAsT(); + z[zOffset] = gradient[gradOffset]; + } + }; - samediff::Threads::parallel_for(func, 0, x->lengthOf()); + samediff::Threads::parallel_for(func, 0, length); + return; } -void mergeAvg(sd::LaunchContext * context, const std::vector& inArrs, NDArray& output) { - BUILD_SINGLE_SELECTOR(output.dataType(), mergeAvg_, (inArrs, output), LIBND4J_TYPES); +void mergeMaxBp(sd::LaunchContext* context, + const std::vector& inArrs, + std::vector& outArrs) { + BUILD_SINGLE_SELECTOR(outArrs[0]->dataType(), mergeMaxBp_, (inArrs, outArrs), + LIBND4J_TYPES); } ////////////////////////////////////////////////////////////////////////// -template -static void mergeAvgBp_(const NDArray& gradient, std::vector& outArrs) { - - const Nd4jLong numArgs = outArrs.size(); - - auto func = PRAGMA_THREADS_FOR{ - for (auto e = start; e < stop; e++) { - - T v = gradient.e(e) / numArgs; - - for (Nd4jLong i = 0; i < numArgs; i++) { - outArrs[i]->p(e, v); - } - } - }; +template +static void mergeAvg_(const std::vector& inArrs, + NDArray& output) { + const Nd4jLong numArgs = inArrs.size(); + const T factor = 1.f / numArgs; + auto x = inArrs[0]; + + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + T sum = 0.; + for (Nd4jLong i = 0; i < numArgs; i++) { + T v = inArrs[i]->e(e); + sum += v; + } + output.p(e, sum * factor); + } + }; - samediff::Threads::parallel_for(func, 0, gradient.lengthOf()); + samediff::Threads::parallel_for(func, 0, x->lengthOf()); } -void mergeAvgBp(sd::LaunchContext* context, const NDArray& gradient, std::vector& outArrs) { - BUILD_SINGLE_SELECTOR(gradient.dataType(), mergeAvgBp_, (gradient, outArrs), LIBND4J_TYPES); +void mergeAvg(sd::LaunchContext* context, + const std::vector& inArrs, NDArray& output) { + BUILD_SINGLE_SELECTOR(output.dataType(), mergeAvg_, (inArrs, output), + LIBND4J_TYPES); } - ////////////////////////////////////////////////////////////////////////// -template -static void mergeAdd_(const std::vector& inArrs, NDArray& output) { - - const Nd4jLong numArgs = inArrs.size(); - auto x = inArrs[0]; - - auto func = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) { - T sum = (T) 0.f; - for (Nd4jLong i = 0; i < numArgs; i++) - sum += inArrs[i]->e(e); +template +static void mergeAvgBp_(const NDArray& gradient, + std::vector& outArrs) { + const Nd4jLong numArgs = outArrs.size(); + + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + T v = gradient.e(e) / numArgs; + + for (Nd4jLong i = 0; i < numArgs; i++) { + outArrs[i]->p(e, v); + } + } + }; - output.p(e, sum); - } - }; + samediff::Threads::parallel_for(func, 0, gradient.lengthOf()); +} - samediff::Threads::parallel_for(func, 0, x->lengthOf()); +void mergeAvgBp(sd::LaunchContext* context, const NDArray& gradient, + std::vector& outArrs) { + BUILD_SINGLE_SELECTOR(gradient.dataType(), mergeAvgBp_, (gradient, outArrs), + LIBND4J_TYPES); } - void mergeAdd(sd::LaunchContext * context, const std::vector& inArrs, NDArray& output) { - BUILD_SINGLE_SELECTOR(output.dataType(), mergeAdd_, (inArrs, output), LIBND4J_TYPES); - } ////////////////////////////////////////////////////////////////////////// -template -static void mergeAddBp_(const NDArray& gradient, std::vector& outArrs) { - - const Nd4jLong numArgs = outArrs.size(); - - auto func = PRAGMA_THREADS_FOR{ - for (auto e = start; e < stop; e++) { - - T v = gradient.e(e); - - for (Nd4jLong i = 0; i < numArgs; i++) { - outArrs[i]->p(e, v); - } - } - }; +template +static void mergeAdd_(const std::vector& inArrs, + NDArray& output) { + const Nd4jLong numArgs = inArrs.size(); + auto x = inArrs[0]; + + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + T sum = (T)0.f; + for (Nd4jLong i = 0; i < numArgs; i++) sum += inArrs[i]->e(e); + + output.p(e, sum); + } + }; - samediff::Threads::parallel_for(func, 0, gradient.lengthOf()); + samediff::Threads::parallel_for(func, 0, x->lengthOf()); } - -void mergeAddBp(sd::LaunchContext* context, const NDArray& gradient, std::vector& outArrs) { - BUILD_SINGLE_SELECTOR(gradient.dataType(), mergeAddBp_, (gradient, outArrs), LIBND4J_TYPES); +void mergeAdd(sd::LaunchContext* context, + const std::vector& inArrs, NDArray& output) { + BUILD_SINGLE_SELECTOR(output.dataType(), mergeAdd_, (inArrs, output), + LIBND4J_TYPES); } +////////////////////////////////////////////////////////////////////////// +template +static void mergeAddBp_(const NDArray& gradient, + std::vector& outArrs) { + const Nd4jLong numArgs = outArrs.size(); + + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + T v = gradient.e(e); + + for (Nd4jLong i = 0; i < numArgs; i++) { + outArrs[i]->p(e, v); + } + } + }; + samediff::Threads::parallel_for(func, 0, gradient.lengthOf()); } + +void mergeAddBp(sd::LaunchContext* context, const NDArray& gradient, + std::vector& outArrs) { + BUILD_SINGLE_SELECTOR(gradient.dataType(), mergeAddBp_, (gradient, outArrs), + LIBND4J_TYPES); } -} + +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/meshgrid.cpp b/libnd4j/include/ops/declarable/helpers/cpu/meshgrid.cpp index 336eacf20dfd..7d98880ae34f 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/meshgrid.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/meshgrid.cpp @@ -18,36 +18,33 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 18.04.2018 // - -#include #include +#include + #include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { - ////////////////////////////////////////////////////////////////////////// -void meshgrid(sd::LaunchContext * context, const std::vector& inArrs, const std::vector& outArrs, const bool swapFirst2Dims) { - - const int rank = inArrs.size(); - int inIndices[MAX_RANK]; - std::iota(inIndices, inIndices + rank, 0); - if(swapFirst2Dims && rank > 1) { - inIndices[0] = 1; - inIndices[1] = 0; - } - - for(int i = 0; i < rank; ++i) { - auto list = outArrs[i]->allTensorsAlongDimension({inIndices[i]}); - for(int j = 0; j < list.size(); ++j) - list.at(j)->assign(inArrs[i]); - } -} - -} -} +void meshgrid(sd::LaunchContext* context, const std::vector& inArrs, + const std::vector& outArrs, const bool swapFirst2Dims) { + const int rank = inArrs.size(); + int inIndices[MAX_RANK]; + std::iota(inIndices, inIndices + rank, 0); + if (swapFirst2Dims && rank > 1) { + inIndices[0] = 1; + inIndices[1] = 0; + } + + for (int i = 0; i < rank; ++i) { + auto list = outArrs[i]->allTensorsAlongDimension({inIndices[i]}); + for (int j = 0; j < list.size(); ++j) list.at(j).assign(inArrs[i]); + } } +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/minimax.cpp b/libnd4j/include/ops/declarable/helpers/cpu/minimax.cpp index 6174151d6874..ce7fffb97381 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/minimax.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/minimax.cpp @@ -19,159 +19,160 @@ // #ifndef __MIN_I_MAX_H_HELPERS__ #define __MIN_I_MAX_H_HELPERS__ -#include #include #include +#include namespace sd { namespace ops { namespace helpers { - template - static void minimumBPFunctor_(NDArray* x, NDArray* y, NDArray* epsNext, NDArray* gradX, NDArray* gradY) { - - auto lambdaX = LAMBDA_TTT(_e, _x, _y) { - return _x <= _y ? _e : (T) 0.; - }; - - auto lambdaY = LAMBDA_TTT(_e, _x, _y) { - return _x >= _y ? _e : (T) 0.; - }; - - - if (x->isSameShape(y)) { - // PWT case case - - // X gradient - epsNext->applyTriplewiseLambda(*x, *y, lambdaX, *gradX); - - // Y gradient - epsNext->applyTriplewiseLambda(*x, *y, lambdaY, *gradY); - - } else if (y->isScalar()) { - T s = y->e(0); - auto lambdaS = LAMBDA_TT(_e, _x, s) { - return _x <= s ? _e : (T) 0.; - }; - - // scalar case - auto tmp = epsNext->reduceNumber(reduce::Sum); - if (x <= y) - gradY->assign(tmp); - else - gradY->assign(0.0f); - - epsNext->applyPairwiseLambda(*x, lambdaS, *gradX); - } else { - // broadcast case - - // in this case we want to boost our X and Y shapes to the size of FF pass output (or epsNext, which has the same shape) - auto preX = x->dup(); - auto preY = y->dup(); - - auto targetShape = epsNext->getShapeAsVector(); +template +static void minimumBPFunctor_(NDArray* x, NDArray* y, NDArray* epsNext, + NDArray* gradX, NDArray* gradY) { + auto lambdaX = LAMBDA_TTT(_e, _x, _y) { return _x <= _y ? _e : (T)0.; }; - preX.tileToShape(targetShape, preX); - preY.tileToShape(targetShape, preY); + auto lambdaY = LAMBDA_TTT(_e, _x, _y) { return _x >= _y ? _e : (T)0.; }; - epsNext->applyTriplewiseLambda(preX, preY, lambdaX, preX); - epsNext->applyTriplewiseLambda(preX, preY, lambdaY, preY); + if (x->isSameShape(y)) { + // PWT case case - auto axisX = ShapeUtils::evalBroadcastBackwardAxis(x->shapeInfo(), epsNext->shapeInfo()); - auto axisY = ShapeUtils::evalBroadcastBackwardAxis(y->shapeInfo(), epsNext->shapeInfo()); + // X gradient + epsNext->applyTriplewiseLambda(*x, *y, lambdaX, *gradX); - if (axisX.size() > 0) { - auto sum = preX.reduceAlongDimension(reduce::Sum, axisX); - gradX->assign(sum); - } else - gradX->assign(preX); + // Y gradient + epsNext->applyTriplewiseLambda(*x, *y, lambdaY, *gradY); - if (axisY.size() > 0) { - auto sum = preY.reduceAlongDimension(reduce::Sum, axisY); - gradY->assign(sum); - } else - gradY->assign(preY); - } + } else if (y->isScalar()) { + T s = y->e(0); + auto lambdaS = LAMBDA_TT(_e, _x, s) { return _x <= s ? _e : (T)0.; }; - } - template - void maximumBPFunctor_(NDArray* x, NDArray* y, NDArray* epsNext, NDArray* gradX, NDArray* gradY) { + // scalar case + auto tmp = epsNext->reduceNumber(reduce::Sum); + if (x <= y) + gradY->assign(tmp); + else + gradY->assign(0.0f); - auto lambdaX = LAMBDA_TTT(_e, _x, _y) { - return _x >= _y ? _e : (T) 0.; - }; + epsNext->applyPairwiseLambda(*x, lambdaS, *gradX); + } else { + // broadcast case - auto lambdaY = LAMBDA_TTT(_e, _x, _y) { - return _x <= _y ? _e : (T) 0.; - }; + // in this case we want to boost our X and Y shapes to the size of FF pass + // output (or epsNext, which has the same shape) + auto preX = x->dup(); + auto preY = y->dup(); + auto targetShape = epsNext->getShapeAsVector(); - if (x->isSameShape(y)) { - // PWT case case + preX.tileToShape(targetShape, preX); + preY.tileToShape(targetShape, preY); - // X gradient - epsNext->applyTriplewiseLambda(*x, *y, lambdaX, *gradX); + epsNext->applyTriplewiseLambda(preX, preY, lambdaX, preX); + epsNext->applyTriplewiseLambda(preX, preY, lambdaY, preY); - // Y gradient - epsNext->applyTriplewiseLambda(*x, *y, lambdaY, *gradY); + auto axisX = ShapeUtils::evalBroadcastBackwardAxis(x->shapeInfo(), + epsNext->shapeInfo()); + auto axisY = ShapeUtils::evalBroadcastBackwardAxis(y->shapeInfo(), + epsNext->shapeInfo()); - } else if (y->isScalar()) { - T s = y->e(0); - auto lambdaS = LAMBDA_TT(_e, _x, s) { - return _x >= s ? _e : (T) 0.; - }; - - // scalar case - auto tmp = epsNext->reduceNumber(reduce::Sum); - if (x <= y) - gradY->assign(tmp); - else - gradY->assign(0.0f); - - epsNext->applyPairwiseLambda(*x, lambdaS, *gradX); - } else { - // broadcast case - - // in this case we want to boost our X and Y shapes to the size of FF pass output (or epsNext, which has the same shape) - auto preX = x->dup(); - auto preY = y->dup(); - - auto targetShape = epsNext->getShapeAsVector(); - - preX.tileToShape(targetShape, preX); - preY.tileToShape(targetShape, preY); - - epsNext->applyTriplewiseLambda(preX, preY, lambdaX, preX); - epsNext->applyTriplewiseLambda(preX, preY, lambdaY, preY); - - auto axisX = ShapeUtils::evalBroadcastBackwardAxis(x->shapeInfo(), epsNext->shapeInfo()); - auto axisY = ShapeUtils::evalBroadcastBackwardAxis(y->shapeInfo(), epsNext->shapeInfo()); - - if (axisX.size() > 0) { - auto sum = preX.reduceAlongDimension(reduce::Sum, axisX); - gradX->assign(sum); - } else - gradX->assign(preX); - - if (axisY.size() > 0) { - auto sum = preY.reduceAlongDimension(reduce::Sum, axisY); - gradY->assign(sum); - } else - gradY->assign(preY); - } - } - - void minimumBPFunctor(sd::LaunchContext * context, NDArray* x, NDArray* y, NDArray* epsNext, NDArray* gradX, NDArray* gradY) { - BUILD_SINGLE_SELECTOR(x->dataType(), minimumBPFunctor_, (x, y, epsNext, gradX, gradY), NUMERIC_TYPES); - } - - void maximumBPFunctor(sd::LaunchContext * context, NDArray* x, NDArray* y, NDArray* epsNext, NDArray* gradX, NDArray* gradY) { - BUILD_SINGLE_SELECTOR(x->dataType(), maximumBPFunctor_, (x, y, epsNext, gradX, gradY), NUMERIC_TYPES); - } - BUILD_SINGLE_TEMPLATE(template void minimumBPFunctor_, (NDArray* x, NDArray* y, NDArray* epsNext, NDArray* gradX, NDArray* gradY), NUMERIC_TYPES); - BUILD_SINGLE_TEMPLATE(template void maximumBPFunctor_, (NDArray* x, NDArray* y, NDArray* epsNext, NDArray* gradX, NDArray* gradY), NUMERIC_TYPES); + if (axisX.size() > 0) { + auto sum = preX.reduceAlongDimension(reduce::Sum, axisX); + gradX->assign(sum); + } else + gradX->assign(preX); + if (axisY.size() > 0) { + auto sum = preY.reduceAlongDimension(reduce::Sum, axisY); + gradY->assign(sum); + } else + gradY->assign(preY); + } } +template +void maximumBPFunctor_(NDArray* x, NDArray* y, NDArray* epsNext, NDArray* gradX, + NDArray* gradY) { + auto lambdaX = LAMBDA_TTT(_e, _x, _y) { return _x >= _y ? _e : (T)0.; }; + + auto lambdaY = LAMBDA_TTT(_e, _x, _y) { return _x <= _y ? _e : (T)0.; }; + + if (x->isSameShape(y)) { + // PWT case case + + // X gradient + epsNext->applyTriplewiseLambda(*x, *y, lambdaX, *gradX); + + // Y gradient + epsNext->applyTriplewiseLambda(*x, *y, lambdaY, *gradY); + + } else if (y->isScalar()) { + T s = y->e(0); + auto lambdaS = LAMBDA_TT(_e, _x, s) { return _x >= s ? _e : (T)0.; }; + + // scalar case + auto tmp = epsNext->reduceNumber(reduce::Sum); + if (x <= y) + gradY->assign(tmp); + else + gradY->assign(0.0f); + + epsNext->applyPairwiseLambda(*x, lambdaS, *gradX); + } else { + // broadcast case + + // in this case we want to boost our X and Y shapes to the size of FF pass + // output (or epsNext, which has the same shape) + auto preX = x->dup(); + auto preY = y->dup(); + + auto targetShape = epsNext->getShapeAsVector(); + + preX.tileToShape(targetShape, preX); + preY.tileToShape(targetShape, preY); + + epsNext->applyTriplewiseLambda(preX, preY, lambdaX, preX); + epsNext->applyTriplewiseLambda(preX, preY, lambdaY, preY); + + auto axisX = ShapeUtils::evalBroadcastBackwardAxis(x->shapeInfo(), + epsNext->shapeInfo()); + auto axisY = ShapeUtils::evalBroadcastBackwardAxis(y->shapeInfo(), + epsNext->shapeInfo()); + + if (axisX.size() > 0) { + auto sum = preX.reduceAlongDimension(reduce::Sum, axisX); + gradX->assign(sum); + } else + gradX->assign(preX); + + if (axisY.size() > 0) { + auto sum = preY.reduceAlongDimension(reduce::Sum, axisY); + gradY->assign(sum); + } else + gradY->assign(preY); + } } + +void minimumBPFunctor(sd::LaunchContext* context, NDArray* x, NDArray* y, + NDArray* epsNext, NDArray* gradX, NDArray* gradY) { + BUILD_SINGLE_SELECTOR(x->dataType(), minimumBPFunctor_, + (x, y, epsNext, gradX, gradY), NUMERIC_TYPES); +} + +void maximumBPFunctor(sd::LaunchContext* context, NDArray* x, NDArray* y, + NDArray* epsNext, NDArray* gradX, NDArray* gradY) { + BUILD_SINGLE_SELECTOR(x->dataType(), maximumBPFunctor_, + (x, y, epsNext, gradX, gradY), NUMERIC_TYPES); } +BUILD_SINGLE_TEMPLATE(template void minimumBPFunctor_, + (NDArray * x, NDArray* y, NDArray* epsNext, + NDArray* gradX, NDArray* gradY), + NUMERIC_TYPES); +BUILD_SINGLE_TEMPLATE(template void maximumBPFunctor_, + (NDArray * x, NDArray* y, NDArray* epsNext, + NDArray* gradX, NDArray* gradY), + NUMERIC_TYPES); + +} // namespace helpers +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/helpers/cpu/nth_element.cpp b/libnd4j/include/ops/declarable/helpers/cpu/nth_element.cpp index b9225e40d151..36c3684c6326 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/nth_element.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/nth_element.cpp @@ -18,59 +18,70 @@ // @author sgazeos@gmail.com // -#include -#include -#include -#include #include +#include +#include +#include +#include namespace sd { namespace ops { namespace helpers { - template - void nthElementFunctor_(NDArray* input, Nd4jLong n, NDArray* output, bool reverse) { +template +void nthElementFunctor_(NDArray* input, Nd4jLong n, NDArray* output, + bool reverse) { + NDArray sortedVals(*input); + if (input->isVector()) { + // std::vector data(input->lengthOf()); + // memcpy(&data[0], input->buffer(), sizeof(T) * data.size()); + // size_t l = 0; + // for (size_t l = 0; l < data.size(); ++l) + // data[l] = input->e(l); + // auto nthPos = data.begin(); + // nthPos += n; + // std::nth_element(data.begin(), nthPos, data.end()); + SpecialMethods::sortGeneric(sortedVals.buffer(), sortedVals.shapeInfo(), + reverse); + output->p(0, sortedVals.e(n)); + } else { // rank greater than 1 + std::vector lastDims( + {input->rankOf() - + 1}); // = ShapeUtils::evalDimsToExclude(input->rankOf(), + // {input->rankOf() - 1}); - NDArray sortedVals(*input); - if (input->isVector()) { - //std::vector data(input->lengthOf()); - //memcpy(&data[0], input->buffer(), sizeof(T) * data.size()); - //size_t l = 0; - //for (size_t l = 0; l < data.size(); ++l) - // data[l] = input->e(l); - //auto nthPos = data.begin(); - //nthPos += n; - //std::nth_element(data.begin(), nthPos, data.end()); - SpecialMethods::sortGeneric(sortedVals.buffer(), sortedVals.shapeInfo(), reverse); - output->p(0, sortedVals.e(n)); - } - else { // rank greater than 1 - std::vector lastDims({input->rankOf() - 1});// = ShapeUtils::evalDimsToExclude(input->rankOf(), {input->rankOf() - 1}); + auto pack = sd::ConstantTadHelper::getInstance().tadForDimensions( + sortedVals.shapeInfo(), lastDims); - auto pack = sd::ConstantTadHelper::getInstance().tadForDimensions(sortedVals.shapeInfo(), lastDims); + SpecialMethods::sortTadGeneric(sortedVals.buffer(), + sortedVals.shapeInfo(), lastDims.data(), + lastDims.size(), pack.primaryShapeInfo(), + pack.primaryOffsets(), reverse); - SpecialMethods::sortTadGeneric(sortedVals.buffer(), sortedVals.shapeInfo(), lastDims.data(), lastDims.size(), pack.primaryShapeInfo(), pack.primaryOffsets(), reverse); + ResultSet rows = sortedVals.allTensorsAlongDimension(lastDims); + Nd4jLong oL = output->lengthOf(); - ResultSet rows = sortedVals.allTensorsAlongDimension(lastDims); - Nd4jLong oL = output->lengthOf(); - - auto func = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) { - auto row = rows.at(e); - output->p(e, row->e(n)); - } - }; - - samediff::Threads::parallel_for(func, 0, oL); - } - } - - void nthElementFunctor(sd::LaunchContext *launchContext, NDArray* input, Nd4jLong n, NDArray* output, bool reverse) { - BUILD_SINGLE_SELECTOR(input->dataType(), nthElementFunctor_, (input, n, output, reverse), LIBND4J_TYPES); - - } - BUILD_SINGLE_TEMPLATE(template void nthElementFunctor_, (NDArray* input, Nd4jLong n, NDArray* output, bool reverse), LIBND4J_TYPES); + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + auto row = rows.at(e); + output->p(e, row.e(n)); + } + }; + samediff::Threads::parallel_for(func, 0, oL); + } } + +void nthElementFunctor(sd::LaunchContext* launchContext, NDArray* input, + Nd4jLong n, NDArray* output, bool reverse) { + BUILD_SINGLE_SELECTOR(input->dataType(), nthElementFunctor_, + (input, n, output, reverse), LIBND4J_TYPES); } -} +BUILD_SINGLE_TEMPLATE(template void nthElementFunctor_, + (NDArray * input, Nd4jLong n, NDArray* output, + bool reverse), + LIBND4J_TYPES); + +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/one_hot.cpp b/libnd4j/include/ops/declarable/helpers/cpu/one_hot.cpp index 41a265ca9a97..05f77bb8c6e9 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/one_hot.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/one_hot.cpp @@ -18,86 +18,103 @@ // @author raver119@gmail.com // -#include -#include +#include + #include -#include "../one_hot.h" +#include +#include namespace sd { - namespace ops { - namespace helpers { - template - static void onehot_(void *voutput, Nd4jLong const* zShapeInfo, void const* vindices, Nd4jLong const* iShapeInfo, int axis, double on, double off) { - auto output = reinterpret_cast(voutput); - auto indices = reinterpret_cast(vindices); - - auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(zShapeInfo, {axis}); - - auto iLen = static_cast(shape::length(iShapeInfo)); - auto tLen = static_cast(shape::length(tadPack.primaryShapeInfo())); - auto numTads = static_cast(tadPack.numberOfTads()); - auto tadEws = shape::elementWiseStride(tadPack.primaryShapeInfo()); - - if (iLen != numTads) - throw std::runtime_error("OneHot: number of TADs should be equal to number of indices"); - - if (shape::elementWiseStride(zShapeInfo) != 1 || shape::elementWiseStride(iShapeInfo) != 1) - throw std::runtime_error("OneHot: op expects output and indices to have elementWiseStride to be equal to 1"); - - Z zero = static_cast(off); - Z one = static_cast(on); - - if (tadEws >= 1) { - auto func = PRAGMA_THREADS_FOR { - for (auto e = 0; e < stop; e++) { - auto cO = output + tadPack.primaryOffsets()[e]; - - auto idx = static_cast(indices[e]); - if (idx < 0 || idx >= tLen) { - PRAGMA_OMP_SIMD - for (unsigned int t = 0; t < tLen; t++) { - cO[t * tadEws] = zero; - } - } else { - PRAGMA_OMP_SIMD - for (unsigned int t = 0; t < tLen; t++) { - cO[t * tadEws] = idx == t ? one : zero; - } - } - } - }; - - samediff::Threads::parallel_tad(func, 0, numTads); - } else { - auto func = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) { - auto cO = output + tadPack.primaryOffsets()[e]; - - auto idx = static_cast(indices[e]); - if (idx < 0 || idx >= tLen) { - PRAGMA_OMP_SIMD - for (unsigned int t = 0; t < tLen; t++) { - cO[shape::getIndexOffset(t, tadPack.primaryShapeInfo())] = zero; - } - } else { - PRAGMA_OMP_SIMD - for (unsigned int t = 0; t < tLen; t++) { - cO[shape::getIndexOffset(t, tadPack.primaryShapeInfo())] = idx == t ? one : zero; - } - } - } - }; - - samediff::Threads::parallel_tad(func, 0, numTads); - } - } - - void onehot(const sd::LaunchContext* context, const NDArray *indices, NDArray *output, const uint axis, const uint depth, const double on, const double off) { - auto zType = output->dataType(); - auto iType = indices->dataType(); - - BUILD_DOUBLE_SELECTOR(zType, iType, onehot_, (output->buffer(), output->shapeInfo(), indices->buffer(), indices->shapeInfo(), axis, on, off), LIBND4J_TYPES, LIBND4J_TYPES); - } +namespace ops { +namespace helpers { + +template +static void onehot_(void* voutput, Nd4jLong const* zShapeInfo, + void const* vindices, Nd4jLong const* iShapeInfo, int axis, + double on, double off) { + auto output = reinterpret_cast(voutput); + auto indices = reinterpret_cast(vindices); + + auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions( + zShapeInfo, axis); + + auto iLen = static_cast(shape::length(iShapeInfo)); + auto tLen = + static_cast(shape::length(tadPack.primaryShapeInfo())); + auto numTads = static_cast(tadPack.numberOfTads()); + auto tadEws = shape::elementWiseStride(tadPack.primaryShapeInfo()); + + if (iLen != numTads) + throw std::runtime_error( + "OneHot: number of TADs should be equal to number of indices"); + + if (shape::elementWiseStride(zShapeInfo) != 1 || + shape::elementWiseStride(iShapeInfo) != 1) + throw std::runtime_error( + "OneHot: op expects output and indices to have elementWiseStride to be " + "equal to 1"); + + Z zero = static_cast(off); + Z one = static_cast(on); + + if (tadEws >= 1) { + auto func = PRAGMA_THREADS_FOR { + for (auto e = 0; e < stop; e++) { + auto cO = output + tadPack.primaryOffsets()[e]; + + auto idx = static_cast(indices[e]); + if (idx < 0 || idx >= tLen) { + PRAGMA_OMP_SIMD + for (unsigned int t = 0; t < tLen; t++) { + cO[t * tadEws] = zero; + } + } else { + PRAGMA_OMP_SIMD + for (unsigned int t = 0; t < tLen; t++) { + cO[t * tadEws] = idx == t ? one : zero; + } + } + } + }; + + samediff::Threads::parallel_tad(func, 0, numTads); + } else { + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + auto cO = output + tadPack.primaryOffsets()[e]; + + auto idx = static_cast(indices[e]); + if (idx < 0 || idx >= tLen) { + PRAGMA_OMP_SIMD + for (unsigned int t = 0; t < tLen; t++) { + cO[shape::getIndexOffset(t, tadPack.primaryShapeInfo())] = zero; + } + } else { + PRAGMA_OMP_SIMD + for (unsigned int t = 0; t < tLen; t++) { + cO[shape::getIndexOffset(t, tadPack.primaryShapeInfo())] = + idx == t ? one : zero; + } } - } + } + }; + + samediff::Threads::parallel_tad(func, 0, numTads); + } +} + +void onehot(const sd::LaunchContext* context, const NDArray* indices, + NDArray* output, const uint axis, const uint depth, const double on, + const double off) { + auto zType = output->dataType(); + auto iType = indices->dataType(); + + BUILD_DOUBLE_SELECTOR( + zType, iType, onehot_, + (output->buffer(), output->shapeInfo(), indices->buffer(), + indices->shapeInfo(), axis, on, off), + LIBND4J_TYPES, LIBND4J_TYPES); } +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/pad.cpp b/libnd4j/include/ops/declarable/helpers/cpu/pad.cpp index a0efd44c1d87..003ef27bf30c 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/pad.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/pad.cpp @@ -18,123 +18,119 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 20.04.2018 // - -#include #include +#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { - ////////////////////////////////////////////////////////////////////////// -template -void pad_(const int mode, const NDArray& input, const NDArray& paddings, NDArray& output, const NDArray& padValue) { - - const T* x = input.bufferAsT(); - T* z = output.bufferAsT(); - - const Nd4jLong* xShape = input.shapeOf(); - const Nd4jLong* zShape = output.shapeOf(); - - const int rank = input.rankOf(); // both input and output have the same rank - const int rankMinusOne = rank - 1; - - const auto zLen = output.lengthOf(); - - if(mode == 0) { // CONSTANT case - - const T padVal = padValue.e(0); +template +void pad_(const int mode, const NDArray& input, const NDArray& paddings, + NDArray& output, const NDArray& padValue) { + const T* x = input.bufferAsT(); + T* z = output.bufferAsT(); - auto func = PRAGMA_THREADS_FOR { + const Nd4jLong* xShape = input.shapeOf(); + const Nd4jLong* zShape = output.shapeOf(); - int zCoords[MAX_RANK], xCoords[MAX_RANK]; + const int rank = input.rankOf(); // both input and output have the same rank + const int rankMinusOne = rank - 1; - for (auto i = start; i < stop; i++) { + const auto zLen = output.lengthOf(); - shape::index2coordsCPU(start, i, output.shapeInfo(), zCoords); - const auto zOffset = shape::getOffset(output.shapeInfo(), zCoords); + if (mode == 0) { // CONSTANT case - memcpy(xCoords, zCoords, rank * sizeof(int)); + const T padVal = padValue.e(0); - bool within = true; + auto func = PRAGMA_THREADS_FOR { + int zCoords[MAX_RANK], xCoords[MAX_RANK]; - for (int j = rankMinusOne; j >= 0; --j) { + for (auto i = start; i < stop; i++) { + shape::index2coordsCPU(start, i, output.shapeInfo(), zCoords); + const auto zOffset = shape::getOffset(output.shapeInfo(), zCoords); - if (xShape[j] == zShape[j]) - continue; + memcpy(xCoords, zCoords, rank * sizeof(int)); - const auto left = paddings.e(j, 0); + bool within = true; - if (zCoords[j] < left || zCoords[j] >= left + xShape[j]) { - within = false; - break; - } - else - xCoords[j] = zCoords[j] - left; - } - - if (within) - z[zOffset] = x[shape::getOffset(input.shapeInfo(), xCoords)]; - else - z[zOffset] = padVal; - } - }; + for (int j = rankMinusOne; j >= 0; --j) { + if (xShape[j] == zShape[j]) continue; - samediff::Threads::parallel_tad(func, 0, zLen); - } - else { // REFLECT and SYMMETRIC cases + const auto left = paddings.e(j, 0); - const Nd4jLong shift1 = mode == 1 ? 0 : 1; // REFLECT : SYMMETRIC - const Nd4jLong shift2 = mode == 1 ? 2 : 1; // REFLECT : SYMMETRIC + if (zCoords[j] < left || zCoords[j] >= left + xShape[j]) { + within = false; + break; + } else + xCoords[j] = zCoords[j] - left; + } - auto func = PRAGMA_THREADS_FOR { + if (within) + z[zOffset] = x[shape::getOffset(input.shapeInfo(), xCoords)]; + else + z[zOffset] = padVal; + } + }; - int zCoords[MAX_RANK], xCoords[MAX_RANK]; + samediff::Threads::parallel_tad(func, 0, zLen); + } else { // REFLECT and SYMMETRIC cases - for (auto i = start; i < stop; i++) { + const Nd4jLong shift1 = mode == 1 ? 0 : 1; // REFLECT : SYMMETRIC + const Nd4jLong shift2 = mode == 1 ? 2 : 1; // REFLECT : SYMMETRIC - shape::index2coordsCPU(start, i, output.shapeInfo(), zCoords); - const auto zOffset = shape::getOffset(output.shapeInfo(), zCoords); + auto func = PRAGMA_THREADS_FOR { + int zCoords[MAX_RANK], xCoords[MAX_RANK]; - memcpy(xCoords, zCoords, rank * sizeof(int)); + for (auto i = start; i < stop; i++) { + shape::index2coordsCPU(start, i, output.shapeInfo(), zCoords); + const auto zOffset = shape::getOffset(output.shapeInfo(), zCoords); - for (int j = rankMinusOne; j >= 0; --j) { + memcpy(xCoords, zCoords, rank * sizeof(int)); - if (xShape[j] == zShape[j]) - continue; + for (int j = rankMinusOne; j >= 0; --j) { + if (xShape[j] == zShape[j]) continue; - xCoords[j] = zCoords[j] - paddings.e(j, 0); // are ready to fill middle (within input dimension range) + xCoords[j] = + zCoords[j] - + paddings.e(j, 0); // are ready to fill middle (within + // input dimension range) - if (xCoords[j] < 0) - xCoords[j] = -xCoords[j] - shift1; // means fill from left - else if (xCoords[j] >= xShape[j]) - xCoords[j] = 2 * xShape[j] - xCoords[j] - shift2; // means fill from right - } + if (xCoords[j] < 0) + xCoords[j] = -xCoords[j] - shift1; // means fill from left + else if (xCoords[j] >= xShape[j]) + xCoords[j] = + 2 * xShape[j] - xCoords[j] - shift2; // means fill from right + } - const auto xOffset = shape::getOffset(input.shapeInfo(), xCoords); - z[zOffset] = x[xOffset]; - } - }; + const auto xOffset = shape::getOffset(input.shapeInfo(), xCoords); + z[zOffset] = x[xOffset]; + } + }; - samediff::Threads::parallel_tad(func, 0, zLen); - } + samediff::Threads::parallel_tad(func, 0, zLen); + } } // ////////////////////////////////////////////////////////////////////////// // template -// void pad2_(const int mode, const NDArray& input, const NDArray& paddings, NDArray& output, NDArray const& padValue) { +// void pad2_(const int mode, const NDArray& input, const NDArray& paddings, +// NDArray& output, NDArray const& padValue) { // const int rank = output.rankOf(); // std::vector dimsToExclude(rank); -// std::iota(dimsToExclude.begin(), dimsToExclude.end(), 0); // fill with 0, 1, ... rank-1 +// std::iota(dimsToExclude.begin(), dimsToExclude.end(), 0); // +// fill with 0, 1, ... rank-1 // Nd4jLong numLeft = paddings.e(rank-1,0); // Nd4jLong numRight = paddings.e(rank-1,1); // Nd4jLong inDimSize = input.sizeAt(rank-1); // Nd4jLong outDimSize = output.sizeAt(rank-1); -// std::vector> outIdx = { std::vector(2*rank), {numLeft, numLeft + inDimSize}, {0, numLeft}, {numLeft + inDimSize, outDimSize} }; +// std::vector> outIdx = { +// std::vector(2*rank), {numLeft, numLeft + inDimSize}, {0, +// numLeft}, {numLeft + inDimSize, outDimSize} }; // for(int i = 0; i < rank-1; ++i) { // outIdx[0][2*i] = paddings.e(i, 0); @@ -145,10 +141,12 @@ void pad_(const int mode, const NDArray& input, const NDArray& paddings, NDArray // // ***** populate innermost sub-arrays firstly ***** // // dimsToExclude.pop_back(); -// Nd4jLong startL = mode == 1 ? 1 : 0; // REFLECT or SYMMETRIC -// Nd4jLong startR = mode == 1 ? inDimSize-2 : inDimSize-1; // REFLECT or SYMMETRIC +// Nd4jLong startL = mode == 1 ? 1 : 0; // +// REFLECT or SYMMETRIC Nd4jLong startR = mode == 1 ? inDimSize-2 : +// inDimSize-1; // REFLECT or SYMMETRIC -// Nd4jLong numOfSubArrs = ShapeUtils::getNumOfSubArrs(input.shapeInfo(), dimsToExclude); +// Nd4jLong numOfSubArrs = ShapeUtils::getNumOfSubArrs(input.shapeInfo(), +// dimsToExclude); // NDArray outSubArr0 = output(outIdx[0], true); @@ -171,12 +169,14 @@ void pad_(const int mode, const NDArray& input, const NDArray& paddings, NDArray // temp.assign(padValue); // assign right // } // } -// else { // REFLECT or SYMMETRIC +// else { // REFLECT or SYMMETRIC -// for(Nd4jLong k = numLeft-1, e = startL; k >= 0; --k, ++e) // fill left side +// for(Nd4jLong k = numLeft-1, e = startL; k >= 0; --k, ++e) // +// fill left side // outSubArr1.t(k) = inSubArr.t(e); -// for(Nd4jLong k = numLeft + inDimSize, e = startR; k < outDimSize; ++k, --e) // fill right side +// for(Nd4jLong k = numLeft + inDimSize, e = startR; k < outDimSize; +// ++k, --e) // fill right side // outSubArr1.t(k) = inSubArr.t(e); // } // } @@ -203,13 +203,16 @@ void pad_(const int mode, const NDArray& input, const NDArray& paddings, NDArray // if(mode == 0) { // outIdxOuter[0] = 0; outIdxOuter[1] = numLeft; -// outIdxInner[0] = numLeft + inDimSize; outIdxInner[1] = outDimSize; +// outIdxInner[0] = numLeft + inDimSize; outIdxInner[1] = +// outDimSize; // } -// startL = mode == 1 ? numLeft + 1 : numLeft; // REFLECT or SYMMETRIC -// startR = mode == 1 ? numLeft + inDimSize - 2 : numLeft + inDimSize-1; // REFLECT or SYMMETRIC +// startL = mode == 1 ? numLeft + 1 : numLeft; // REFLECT or SYMMETRIC +// startR = mode == 1 ? numLeft + inDimSize - 2 : numLeft + inDimSize-1; +// // REFLECT or SYMMETRIC -// numOfSubArrs = ShapeUtils::getNumOfSubArrs(output.shapeInfo(), dimsToExclude); +// numOfSubArrs = ShapeUtils::getNumOfSubArrs(output.shapeInfo(), +// dimsToExclude); // PRAGMA_OMP_PARALLEL_FOR_ARGS(firstprivate(outIdxOuter, outIdxInner)) // for(Nd4jLong j = 0; j < numOfSubArrs; ++j) { @@ -220,17 +223,20 @@ void pad_(const int mode, const NDArray& input, const NDArray& paddings, NDArray // if(numLeft != 0) { // NDArray tempO = outSubArr(outIdxOuter); -// tempO.assign(padValue); // assign left +// tempO.assign(padValue); // +// assign left // } // if(numRight != 0) { // NDArray tempI = outSubArr(outIdxInner); -// tempI.assign(padValue); // assign right +// tempI.assign(padValue); // +// assign right // } // } -// else { // REFLECT or SYMMETRIC +// else { // REFLECT or SYMMETRIC -// for(Nd4jLong k = numLeft-1, e = startL; k >= 0; --k, ++e) { // fill left side +// for(Nd4jLong k = numLeft-1, e = startL; k >= 0; --k, ++e) { +// // fill left side // outIdxOuter[0] = k; // outIdxOuter[1] = k+1; // outIdxInner[0] = e; @@ -240,7 +246,8 @@ void pad_(const int mode, const NDArray& input, const NDArray& paddings, NDArray // outSubArrOuter.assign(outSubArrInner); // } -// for(Nd4jLong k = numLeft + inDimSize, e = startR; k < outDimSize; ++k, --e) { // fill right side +// for(Nd4jLong k = numLeft + inDimSize, e = startR; k < +// outDimSize; ++k, --e) { // fill right side // outIdxOuter[0] = k; // outIdxOuter[1] = k+1; // outIdxInner[0] = e; @@ -254,106 +261,115 @@ void pad_(const int mode, const NDArray& input, const NDArray& paddings, NDArray // } // } -void pad(sd::LaunchContext * context, const int mode, const NDArray& input, const NDArray& paddings, NDArray& output, NDArray const& padValue) { - BUILD_SINGLE_SELECTOR(input.dataType(), pad_, (mode, input, paddings, output, padValue), LIBND4J_TYPES); +void pad(sd::LaunchContext* context, const int mode, const NDArray& input, + const NDArray& paddings, NDArray& output, NDArray const& padValue) { + BUILD_SINGLE_SELECTOR(input.dataType(), pad_, + (mode, input, paddings, output, padValue), + LIBND4J_TYPES); } ////////////////////////////////////////////////////////////////////////// -template -static void mirrorPad_(const NDArray& input, const NDArray& paddings, NDArray& output, const int mode) { - - // mode: 0 - REFLECT, else - SYMMETRIC - const int reflBorder = (bool)mode ? 1 : 0; - const int rank = input.rankOf(); - const Nd4jLong outLen = output.lengthOf(); - - if(rank <= 1) { - - const Nd4jLong inLen = input.lengthOf(); - const auto leftSide = paddings.e(0); - const auto leftSideCorrected = leftSide - reflBorder; - const Nd4jLong len = 2*(inLen-1) + leftSide + reflBorder; - - for(int i = 0; i < outLen; ++i) { - - if (i < leftSide) // left side - output.p(i, input.e(leftSideCorrected - i)); - - else if(i >= leftSide && i < leftSide + inLen) // middle - output.p(i, input.e(i - leftSide)); - - else // right side - output.p(i, input.e(len - i)); - } +template +static void mirrorPad_(const NDArray& input, const NDArray& paddings, + NDArray& output, const int mode) { + // mode: 0 - REFLECT, else - SYMMETRIC + const int reflBorder = (bool)mode ? 1 : 0; + const int rank = input.rankOf(); + const Nd4jLong outLen = output.lengthOf(); + + if (rank <= 1) { + const Nd4jLong inLen = input.lengthOf(); + const auto leftSide = paddings.e(0); + const auto leftSideCorrected = leftSide - reflBorder; + const Nd4jLong len = 2 * (inLen - 1) + leftSide + reflBorder; + + for (int i = 0; i < outLen; ++i) { + if (i < leftSide) // left side + output.p(i, input.e(leftSideCorrected - i)); + + else if (i >= leftSide && i < leftSide + inLen) // middle + output.p(i, input.e(i - leftSide)); + + else // right side + output.p(i, input.e(len - i)); } - else { - - auto func = PRAGMA_THREADS_FOR { - - int inIdx[MAX_RANK], outIdx[MAX_RANK]; - - for (auto i = start; i < stop; i++) { + } else { + auto func = PRAGMA_THREADS_FOR { + int inIdx[MAX_RANK], outIdx[MAX_RANK]; - shape::index2coordsCPU(start, i, output.shapeInfo(), outIdx); + for (auto i = start; i < stop; i++) { + shape::index2coordsCPU(start, i, output.shapeInfo(), outIdx); - for (int j = 0; j < rank; ++j) { - const Nd4jLong inLen = input.sizeAt(j); - const auto leftSide = paddings.e(j, 0); - const auto leftSideCorrected = leftSide - reflBorder; - const Nd4jLong len = 2 * (inLen - 1) + leftSide + reflBorder; + for (int j = 0; j < rank; ++j) { + const Nd4jLong inLen = input.sizeAt(j); + const auto leftSide = paddings.e(j, 0); + const auto leftSideCorrected = leftSide - reflBorder; + const Nd4jLong len = 2 * (inLen - 1) + leftSide + reflBorder; - if (outIdx[j] < leftSide) // left side - inIdx[j] = leftSideCorrected - outIdx[j]; + if (outIdx[j] < leftSide) // left side + inIdx[j] = leftSideCorrected - outIdx[j]; - else if (outIdx[j] >= leftSide && outIdx[j] < leftSide + inLen) // middle - inIdx[j] = outIdx[j] - leftSide; + else if (outIdx[j] >= leftSide && + outIdx[j] < leftSide + inLen) // middle + inIdx[j] = outIdx[j] - leftSide; - else // right side - inIdx[j] = len - outIdx[j]; - } + else // right side + inIdx[j] = len - outIdx[j]; + } - auto outOffset = shape::getOffset(output.shapeInfo(), outIdx); - auto inOffset = shape::getOffset(input.shapeInfo(), inIdx); - reinterpret_cast(output.buffer())[outOffset] = reinterpret_cast(input.buffer())[inOffset]; - } - }; + auto outOffset = shape::getOffset(output.shapeInfo(), outIdx); + auto inOffset = shape::getOffset(input.shapeInfo(), inIdx); + reinterpret_cast(output.buffer())[outOffset] = + reinterpret_cast(input.buffer())[inOffset]; + } + }; - samediff::Threads::parallel_for(func, 0, outLen); - } + samediff::Threads::parallel_for(func, 0, outLen); + } } - void mirrorPad(sd::LaunchContext * context, const NDArray& input, const NDArray& paddings, NDArray& output, const int mode) { - BUILD_SINGLE_SELECTOR(input.dataType(), mirrorPad_, (input, paddings, output, mode), LIBND4J_TYPES); - } - - BUILD_SINGLE_TEMPLATE(template void mirrorPad_, (const NDArray& input, const NDArray& paddings, NDArray& output, const int mode), LIBND4J_TYPES); +void mirrorPad(sd::LaunchContext* context, const NDArray& input, + const NDArray& paddings, NDArray& output, const int mode) { + BUILD_SINGLE_SELECTOR(input.dataType(), mirrorPad_, + (input, paddings, output, mode), LIBND4J_TYPES); +} +BUILD_SINGLE_TEMPLATE(template void mirrorPad_, + (const NDArray& input, const NDArray& paddings, + NDArray& output, const int mode), + LIBND4J_TYPES); //////////////////////////////////////////////////////////////////////// /*// initial values of inIdx, outIdx, dim must be equal to zero template -static void recursiveLoopForPad_(const int mode, NDArray& input, const NDArray& paddings, NDArray& output, std::vector dimensions, int dim, int inIdx, int outIdx, NDArray& padValue ) { +static void recursiveLoopForPad_(const int mode, NDArray& input, const NDArray& +paddings, NDArray& output, std::vector dimensions, int dim, int inIdx, int +outIdx, NDArray& padValue ) { int leftOffset; - // dimensions are array of input dimensions, it is sorted in increasing order - // every time at the beginning we erase first element from it (not good idea to use vector for this purpose, but luckily it is small enough) - // then we use this array for tads building, every time while recursion the number of built tads becomes bigger - dimensions.erase(dimensions.begin()); - // build tad basing on output array, also create auxiliary arrays pointing on required output array ranges - shape::TAD tadOut(output.shapeInfo(), dimensions.data(), dimensions.size()); - tadOut.createTadOnlyShapeInfo(); + // dimensions are array of input dimensions, it is sorted in increasing +order + // every time at the beginning we erase first element from it (not good idea +to use vector for this purpose, but luckily it is small enough) + // then we use this array for tads building, every time while recursion the +number of built tads becomes bigger dimensions.erase(dimensions.begin()); + // build tad basing on output array, also create auxiliary arrays pointing +on required output array ranges shape::TAD tadOut(output.shapeInfo(), +dimensions.data(), dimensions.size()); tadOut.createTadOnlyShapeInfo(); tadOut.createOffsets(); - auto subArrOut = NDArray(output.getBuffer(), tadOut.tadOnlyShapeInfo, output.getContext()); - auto subArr = NDArray(output.getBuffer(), tadOut.tadOnlyShapeInfo, output.getContext()); - // build tad basing on input array, also create auxiliary array pointing on required input array range - shape::TAD tadIn(input.shapeInfo(), dimensions.data(), dimensions.size()); - tadIn.createTadOnlyShapeInfo(); + auto subArrOut = NDArray(output.getBuffer(), tadOut.tadOnlyShapeInfo, +output.getContext()); auto subArr = NDArray(output.getBuffer(), +tadOut.tadOnlyShapeInfo, output.getContext()); + // build tad basing on input array, also create auxiliary array pointing on +required input array range shape::TAD tadIn(input.shapeInfo(), +dimensions.data(), dimensions.size()); tadIn.createTadOnlyShapeInfo(); tadIn.createOffsets(); - auto subArrIn = NDArray(input.getBuffer(), tadIn.tadOnlyShapeInfo, output.getContext()); - // these indices take into account recursion and always point to actual tads numbers - if (input.rankOf() > 1 && output.rankOf() > 1) {// only for non-vector cases - outIdx = outIdx * output.sizeAt(dim + 1); - inIdx = inIdx * input.sizeAt(dim + 1); + auto subArrIn = NDArray(input.getBuffer(), tadIn.tadOnlyShapeInfo, +output.getContext()); + // these indices take into account recursion and always point to actual tads +numbers if (input.rankOf() > 1 && output.rankOf() > 1) {// only for non-vector +cases outIdx = outIdx * output.sizeAt(dim + 1); inIdx = inIdx * input.sizeAt(dim ++ 1); } // current input tad number, we add to it unity in a loop int k = -1; @@ -366,15 +382,16 @@ static void recursiveLoopForPad_(const int mode, NDArray& input, const NDArray& // increase input tads number ++k; - // recursion condition allows for the fact that tad can't reduce to scalar - if(dim < input.rankOf() - 2) - recursiveLoopForPad(mode, input, paddings, output, dimensions, dim + 1, inIdx + k, outIdx + i, padValue); - else if (paddings.sizeAt(0) > dim + 1){ - leftOffset = paddings.e(dim + 1, 0); + // recursion condition allows for the fact that tad can't reduce to +scalar if(dim < input.rankOf() - 2) recursiveLoopForPad(mode, input, paddings, +output, dimensions, dim + 1, inIdx + k, outIdx + i, padValue); else if +(paddings.sizeAt(0) > dim + 1){ leftOffset = paddings.e(dim + 1, 0); // shift buffers pointers to actual element position if (output.rankOf() > 1) { - subArrOut.setBuffer(reinterpret_cast(output.getBuffer()) + tadOut.tadOffsets[outIdx + i]); - subArrIn.setBuffer(reinterpret_cast(input.getBuffer()) + tadIn.tadOffsets[inIdx + i - paddings.e(dim, 0)]); + subArrOut.setBuffer(reinterpret_cast(output.getBuffer()) + +tadOut.tadOffsets[outIdx + i]); + subArrIn.setBuffer(reinterpret_cast(input.getBuffer()) + +tadIn.tadOffsets[inIdx + i - paddings.e(dim, 0)]); } else { subArrOut.p(i, subArrIn.e(i - leftOffset)); @@ -383,35 +400,37 @@ static void recursiveLoopForPad_(const int mode, NDArray& input, const NDArray& switch (mode) { case 0: // CONSTANT mode for(int j = 0; j < subArrOut.lengthOf(); ++j) - if(j < leftOffset || j >= (subArrIn.lengthOf() + leftOffset) ) // firstly fill with zeros outer ranges + if(j < leftOffset || j >= (subArrIn.lengthOf() + +leftOffset) ) // firstly fill with zeros outer ranges subArrOut.p(j, (T)0.f); else - subArrOut.p(j, subArrIn.e(j - leftOffset)); // fill middle with elements of input array - break; + subArrOut.p(j, subArrIn.e(j - leftOffset)); +// fill middle with elements of input array break; case 1: // REFLECT mode - for(int j = 1; j <= leftOffset; ++j) // fill firstly left side - subArrOut.p(leftOffset - j, subArrIn.e(j)); - for(int j = 0; j < subArrIn.lengthOf(); ++j) // fill middle + for(int j = 1; j <= leftOffset; ++j) // fill firstly left +side subArrOut.p(leftOffset - j, subArrIn.e(j)); for(int j = 0; j < +subArrIn.lengthOf(); ++j) // fill middle subArrOut.p(leftOffset + j, subArrIn.e(j)); - for(int j = (subArrOut.lengthOf() - leftOffset); j < subArrOut.lengthOf(); ++j) // fill right side - subArrOut.p(j, subArrIn.e(subArrOut.lengthOf() - j - 1)); - break; + for(int j = (subArrOut.lengthOf() - leftOffset); j < +subArrOut.lengthOf(); ++j) // fill right side subArrOut.p(j, +subArrIn.e(subArrOut.lengthOf() - j - 1)); break; case 2: // SYMMETRIC mode - for(int j = 1; j <= leftOffset; ++j) // fill firstly left side - subArrOut.p(leftOffset - j, subArrIn.e(j-1)); - for(int j = 0; j < subArrIn.lengthOf(); ++j) // fill middle + for(int j = 1; j <= leftOffset; ++j) // fill firstly left +side subArrOut.p(leftOffset - j, subArrIn.e(j-1)); for(int j = 0; j < +subArrIn.lengthOf(); ++j) // fill middle subArrOut.p(leftOffset + j, subArrIn.e(j)); - for(int j = (subArrOut.lengthOf() - leftOffset); j < subArrOut.lengthOf(); ++j) // fill right side - subArrOut.p(j, subArrIn.e(subArrOut.lengthOf() - j)); - break; + for(int j = (subArrOut.lengthOf() - leftOffset); j < +subArrOut.lengthOf(); ++j) // fill right side subArrOut.p(j, +subArrIn.e(subArrOut.lengthOf() - j)); break; } } else { if (mode == 0 && input.rankOf() < 2) - subArrOut.p(i, subArrIn.e(i - leftOffset)); // fill middle with elements of input array + subArrOut.p(i, subArrIn.e(i - leftOffset)); // fill middle +with elements of input array } } // populate sub-array formed previously @@ -422,18 +441,19 @@ static void recursiveLoopForPad_(const int mode, NDArray& input, const NDArray& // fill left side with padValue if (output.rankOf() > 1) { subArrOut.setBuffer( - reinterpret_cast(output.getBuffer()) + tadOut.tadOffsets[outIdx + leftOffset - j]); - subArrOut.assign(padValue); + reinterpret_cast(output.getBuffer()) + +tadOut.tadOffsets[outIdx + leftOffset - j]); subArrOut.assign(padValue); } else { subArrOut.p(j - 1, padValue); } } // output.printIndexedBuffer("Output at"); - for(int j = (output.sizeAt(dim) - leftOffset); j < output.sizeAt(dim); ++j) { // fill left side with zeros - if (output.rankOf() > 1) { - subArrOut.setBuffer(reinterpret_cast(output.getBuffer()) + tadOut.tadOffsets[outIdx + j]); - subArrOut.assign(padValue); + for(int j = (output.sizeAt(dim) - leftOffset); j < +output.sizeAt(dim); ++j) { // fill left side with zeros if +(output.rankOf() > 1) { + subArrOut.setBuffer(reinterpret_cast(output.getBuffer()) ++ tadOut.tadOffsets[outIdx + j]); subArrOut.assign(padValue); } else { subArrOut.p(j, padValue); @@ -442,42 +462,54 @@ static void recursiveLoopForPad_(const int mode, NDArray& input, const NDArray& break; case 1: // REFLECT mode - for(int j = 1; j <= leftOffset; ++j) { // fill left side - subArr.setBuffer(reinterpret_cast(output.getBuffer()) + tadOut.tadOffsets[outIdx + leftOffset + j]); - subArrOut.setBuffer(reinterpret_cast(output.getBuffer()) + tadOut.tadOffsets[outIdx + leftOffset - j]); - subArrOut.assign(&subArr); + for(int j = 1; j <= leftOffset; ++j) { // fill left side + subArr.setBuffer(reinterpret_cast(output.getBuffer()) + +tadOut.tadOffsets[outIdx + leftOffset + j]); + subArrOut.setBuffer(reinterpret_cast(output.getBuffer()) + +tadOut.tadOffsets[outIdx + leftOffset - j]); subArrOut.assign(&subArr); } - for(int j = (output.sizeAt(dim) - leftOffset); j < output.sizeAt(dim); ++j) { // fill right side - subArr.setBuffer(reinterpret_cast(output.getBuffer()) + tadOut.tadOffsets[outIdx + output.sizeAt(dim) + leftOffset - 1 - j]); - subArrOut.setBuffer(reinterpret_cast(output.getBuffer()) + tadOut.tadOffsets[outIdx + j]); - subArrOut.assign(&subArr); + for(int j = (output.sizeAt(dim) - leftOffset); j < +output.sizeAt(dim); ++j) { // fill right side + subArr.setBuffer(reinterpret_cast(output.getBuffer()) + +tadOut.tadOffsets[outIdx + output.sizeAt(dim) + leftOffset - 1 - j]); + subArrOut.setBuffer(reinterpret_cast(output.getBuffer()) + +tadOut.tadOffsets[outIdx + j]); subArrOut.assign(&subArr); } break; case 2: // SYMMETRIC mode - for(int j = 1; j <= leftOffset; ++j) { // fill left side - subArr.setBuffer(reinterpret_cast(output.getBuffer()) + tadOut.tadOffsets[outIdx + leftOffset + j - 1]); - subArrOut.setBuffer(reinterpret_cast(output.getBuffer()) + tadOut.tadOffsets[outIdx + leftOffset - j]); - subArrOut.assign(&subArr); + for(int j = 1; j <= leftOffset; ++j) { // fill left side + subArr.setBuffer(reinterpret_cast(output.getBuffer()) + +tadOut.tadOffsets[outIdx + leftOffset + j - 1]); + subArrOut.setBuffer(reinterpret_cast(output.getBuffer()) + +tadOut.tadOffsets[outIdx + leftOffset - j]); subArrOut.assign(&subArr); } - for(int j = (output.sizeAt(dim) - leftOffset); j < output.sizeAt(dim); ++j) { // fill right side - subArr.setBuffer(reinterpret_cast(output.getBuffer()) + tadOut.tadOffsets[outIdx + output.sizeAt(dim) + leftOffset - j]); - subArrOut.setBuffer(reinterpret_cast(output.getBuffer()) + tadOut.tadOffsets[outIdx + j]); - subArrOut.assign(&subArr); + for(int j = (output.sizeAt(dim) - leftOffset); j < +output.sizeAt(dim); ++j) { // fill right side + subArr.setBuffer(reinterpret_cast(output.getBuffer()) + +tadOut.tadOffsets[outIdx + output.sizeAt(dim) + leftOffset - j]); + subArrOut.setBuffer(reinterpret_cast(output.getBuffer()) + +tadOut.tadOffsets[outIdx + j]); subArrOut.assign(&subArr); } break; } } */ /* - void recursiveLoopForPad(const int mode, NDArray& input, const NDArray& paddings, NDArray& output, std::vector dimensions, int dim, int inIdx, int outIdx, NDArray& padValue ) { - BUILD_SINGLE_SELECTOR(input.dataType(), recursiveLoopForPad_, (mode, input, paddings, output, dimensions, dim, inIdx, outIdx, padValue), LIBND4J_TYPES); + void recursiveLoopForPad(const int mode, NDArray& input, const NDArray& + paddings, NDArray& output, std::vector dimensions, int dim, int inIdx, + int outIdx, NDArray& padValue ) { BUILD_SINGLE_SELECTOR(input.dataType(), + recursiveLoopForPad_, (mode, input, paddings, output, dimensions, dim, inIdx, + outIdx, padValue), LIBND4J_TYPES); } - BUILD_SINGLE_TEMPLATE(template void recursiveLoopForPad_, (const int mode, NDArray& input, const NDArray& paddings, NDArray& output, std::vector dimensions, int dim, int inIdx, int outIdx, NDArray& padValue), LIBND4J_TYPES); + BUILD_SINGLE_TEMPLATE(template void recursiveLoopForPad_, (const int mode, + NDArray& input, const NDArray& paddings, NDArray& output, std::vector + dimensions, int dim, int inIdx, int outIdx, NDArray& padValue), + LIBND4J_TYPES); */ -} -} -} +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/percentile.cpp b/libnd4j/include/ops/declarable/helpers/cpu/percentile.cpp index dea46cd69f16..9686c83317c6 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/percentile.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/percentile.cpp @@ -18,70 +18,82 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 17.05.2018 // -#include #include #include +#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { - ////////////////////////////////////////////////////////////////////////// template -static void _percentile(const NDArray& input, NDArray& output, std::vector& axises, const float q, const int interpolation) { - - const int inputRank = input.rankOf(); - - if(axises.empty()) - for(int i=0; i shapeOfSubArr(listOfSubArrs.at(0)->rankOf()); - for(int i=0; ishapeOf()[i]; - - auto flattenedArr = NDArrayFactory::create('c', shapeOfSubArr, input.dataType(), input.getContext()); - const int len = flattenedArr.lengthOf(); - - const float fraction = 1.f - q / 100.; - Nd4jLong position = 0; - - switch(interpolation) { - case 0: // lower - position = static_cast(math::nd4j_ceil((len - 1) * fraction)); - break; - case 1: // higher - position = static_cast(math::nd4j_floor((len - 1) * fraction)); - break; - case 2: // nearest - position = static_cast(math::nd4j_round((len - 1) * fraction)); - break; - } - position = len - position - 1; - - // FIXME: our sort impl should be used instead, so this operation might be implemented as generic - // FIXME: parallelism ! - for(int i=0; i(flattenedArr.buffer()); - flattenedArr.assign(listOfSubArrs.at(i)); - std::sort(buff, buff + len); - output.p(i, flattenedArr.e(position)); - } +static void _percentile(const NDArray& input, NDArray& output, + std::vector& axises, const float q, + const int interpolation) { + const int inputRank = input.rankOf(); + + if (axises.empty()) + for (int i = 0; i < inputRank; ++i) axises.push_back(i); + else + shape::checkDimensions(inputRank, + axises); // check, sort dimensions and remove + // duplicates if they are present + + auto listOfSubArrs = input.allTensorsAlongDimension(axises); + + std::vector shapeOfSubArr(listOfSubArrs.at(0).rankOf()); + for (int i = 0; i < shapeOfSubArr.size(); ++i) + shapeOfSubArr[i] = listOfSubArrs.at(0).shapeOf()[i]; + + auto flattenedArr = NDArrayFactory::create( + 'c', shapeOfSubArr, input.dataType(), input.getContext()); + const int len = flattenedArr.lengthOf(); + + const float fraction = 1.f - q / 100.; + Nd4jLong position = 0; + + switch (interpolation) { + case 0: // lower + position = static_cast( + math::nd4j_ceil((len - 1) * fraction)); + break; + case 1: // higher + position = static_cast( + math::nd4j_floor((len - 1) * fraction)); + break; + case 2: // nearest + position = static_cast( + math::nd4j_round((len - 1) * fraction)); + break; + } + position = len - position - 1; + + // FIXME: our sort impl should be used instead, so this operation might be + // implemented as generic + // FIXME: parallelism ! + for (int i = 0; i < listOfSubArrs.size(); ++i) { + auto buff = reinterpret_cast(flattenedArr.buffer()); + flattenedArr.assign(listOfSubArrs.at(i)); + std::sort(buff, buff + len); + output.p(i, flattenedArr.e(position)); + } } - void percentile(sd::LaunchContext * context, const NDArray& input, NDArray& output, std::vector& axises, const float q, const int interpolation) { - BUILD_SINGLE_SELECTOR(input.dataType(), _percentile, (input, output, axises, q, interpolation), LIBND4J_TYPES); - } +void percentile(sd::LaunchContext* context, const NDArray& input, + NDArray& output, std::vector& axises, const float q, + const int interpolation) { + BUILD_SINGLE_SELECTOR(input.dataType(), _percentile, + (input, output, axises, q, interpolation), + LIBND4J_TYPES); +} - BUILD_SINGLE_TEMPLATE(template void _percentile, (const NDArray& input, NDArray& output, std::vector& axises, const float q, const int interpolation), LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template void _percentile, + (const NDArray& input, NDArray& output, + std::vector& axises, const float q, + const int interpolation), + LIBND4J_TYPES); -} -} -} \ No newline at end of file +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/polyGamma.cpp b/libnd4j/include/ops/declarable/helpers/cpu/polyGamma.cpp index 2c93cee08895..57bf5fa51aa1 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/polyGamma.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/polyGamma.cpp @@ -18,80 +18,82 @@ // Created by Yurii Shyrma on 12.12.2017 // -#include -#include #include #include +#include +#include namespace sd { namespace ops { namespace helpers { - ////////////////////////////////////////////////////////////////////////// // calculate factorial template static FORCEINLINE T getFactorial(const int n) { - if (n < 0) - throw std::runtime_error("factorial is not defined for negative number !"); + if (n < 0) + throw std::runtime_error("factorial is not defined for negative number !"); - if(n==0 || n==1) - return (T)1.f; + if (n == 0 || n == 1) return (T)1.f; - T result = (T)1.f; + T result = (T)1.f; - for(int i = 2; i <= n; ++i) - result *= i; + for (int i = 2; i <= n; ++i) result *= i; - return result; + return result; } ////////////////////////////////////////////////////////////////////////// -// implementation is based on serial representation written in terms of the Hurwitz zeta function as polygamma = (-1)^{n+1} * n! * zeta(n+1, x) +// implementation is based on serial representation written in terms of the +// Hurwitz zeta function as polygamma = (-1)^{n+1} * n! * zeta(n+1, x) template -static FORCEINLINE T polyGammaScalar(sd::LaunchContext * context, const int n, const T x) { - - // if (n < 0) - // throw("polyGamma function: n must be >= 0 !"); +static FORCEINLINE T polyGammaScalar(sd::LaunchContext* context, const int n, + const T x) { + // if (n < 0) + // throw("polyGamma function: n must be >= 0 !"); - // if (x <= (T)0.) - // throw("polyGamma function: x must be > 0 !"); + // if (x <= (T)0.) + // throw("polyGamma function: x must be > 0 !"); - int sign = (n + 1) % 2 ? -1 : 1; - // T factorial = (T)std::tgamma(n + 1); + int sign = (n + 1) % 2 ? -1 : 1; + // T factorial = (T)std::tgamma(n + 1); - return sign * getFactorial(n) * zetaScalar((T)(n + 1), x); + return sign * getFactorial(n) * zetaScalar((T)(n + 1), x); } - ////////////////////////////////////////////////////////////////////////// // calculate polygamma function for arrays template -static void polyGamma_(sd::LaunchContext * context, const NDArray& n, const NDArray& x, NDArray& output) { - - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - const T order = n.e(i); - if(order != static_cast(order)) // if order has fractional part then do not perform calculations and return NAN - output.p(i, std::numeric_limits::quiet_NaN()); - else if (order == 0) // polygamma function of zero order is digamma function - output.p(i, diGammaScalar(x.e(i))); - else - output.p(i, polyGammaScalar(context, order, x.e(i))); - } - }; - samediff::Threads::parallel_for(func, 0, x.lengthOf()); +static void polyGamma_(sd::LaunchContext* context, const NDArray& n, + const NDArray& x, NDArray& output) { + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + const T order = n.e(i); + if (order != + static_cast(order)) // if order has fractional part then do not + // perform calculations and return NAN + output.p(i, std::numeric_limits::quiet_NaN()); + else if (order == + 0) // polygamma function of zero order is digamma function + output.p(i, diGammaScalar(x.e(i))); + else + output.p(i, polyGammaScalar(context, order, x.e(i))); + } + }; + samediff::Threads::parallel_for(func, 0, x.lengthOf()); } - void polyGamma(sd::LaunchContext * context, const NDArray& n, const NDArray& x, NDArray& output) { - BUILD_SINGLE_SELECTOR(x.dataType(), polyGamma_, (context, n, x, output), FLOAT_TYPES); - } - -BUILD_SINGLE_TEMPLATE(template void polyGamma_, (sd::LaunchContext * context, const NDArray& n, const NDArray& x, NDArray& output), FLOAT_TYPES); - - - -} -} +void polyGamma(sd::LaunchContext* context, const NDArray& n, const NDArray& x, + NDArray& output) { + BUILD_SINGLE_SELECTOR(x.dataType(), polyGamma_, (context, n, x, output), + FLOAT_TYPES); } +BUILD_SINGLE_TEMPLATE(template void polyGamma_, + (sd::LaunchContext * context, const NDArray& n, + const NDArray& x, NDArray& output), + FLOAT_TYPES); + +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/prefix.cpp b/libnd4j/include/ops/declarable/helpers/cpu/prefix.cpp index 1afe0355662c..ea1d66040635 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/prefix.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/prefix.cpp @@ -18,118 +18,136 @@ // @author raver119@gmail.com // -#include -#include #include +#include #include +#include namespace sd { - namespace ops { - namespace helpers { - template - static void prefix_(scalar::Ops op, const void* vx, Nd4jLong const* xShapeInfo, void* vz, Nd4jLong const* zShapeInfo, bool exclusive, bool reverse) { - const auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - auto length = shape::length(xShapeInfo); - - T prevSum = op == scalar::Add ? (T) 0 : (T) 1; - T sum = prevSum; - - if (reverse) { - if (shape::elementWiseStride(xShapeInfo) == 1 && shape::elementWiseStride(zShapeInfo) == 1 && - shape::order(xShapeInfo) == 'c' && shape::order(zShapeInfo) == 'c') { - - for (Nd4jLong e = length - 1; e >= 0; --e) { - sum = op == scalar::Add ? simdOps::Add::op(sum, x[e]) : simdOps::Multiply::op(sum, x[e]); - if (!exclusive) - prevSum = sum; - - z[e] = prevSum; - - prevSum = sum; - } - } - else { - - for (Nd4jLong e = length - 1; e >= 0; --e) { - - auto xOffset = shape::getIndexOffset(e, xShapeInfo); - auto zOffset = shape::getIndexOffset(e, zShapeInfo); - sum = op == scalar::Add ? simdOps::Add::op(sum, x[xOffset]) : simdOps::Multiply::op(sum, x[xOffset]); - - if (!exclusive) - prevSum = sum; - - z[zOffset] = prevSum; - prevSum = sum; - } - } - } else { - if (shape::elementWiseStride(xShapeInfo) == 1 && shape::elementWiseStride(zShapeInfo) == 1 && - shape::order(xShapeInfo) == 'c' && shape::order(zShapeInfo) == 'c') { - - for (Nd4jLong e = 0; e < length; e++) { - sum = op == scalar::Add ? simdOps::Add::op(sum, x[e]) : simdOps::Multiply::op(sum, x[e]); - - if (!exclusive) - prevSum = sum; - - z[e] = prevSum; - - prevSum = sum; - } - } - else { - - for (Nd4jLong e = 0; e < length; e++) { - - auto xOffset = shape::getIndexOffset(e, xShapeInfo); - auto zOffset = shape::getIndexOffset(e, zShapeInfo); - sum = op == scalar::Add ? simdOps::Add::op(sum, x[xOffset]) : simdOps::Multiply::op(sum, x[xOffset]); - - if (!exclusive) - prevSum = sum; - - z[zOffset] = prevSum; - prevSum = sum; - } - } - } - }; - - template - static void prefix_(scalar::Ops op, const NDArray* x, NDArray* z, const std::vector& dims, bool exclusive, bool reverse) { - auto xTads = x->allTensorsAlongDimension(dims); - auto zTads = z->allTensorsAlongDimension(dims); - auto t = xTads.size(); - - for (int e = 0; e < t; e++) { - auto tx = xTads.at(e); - auto tz = zTads.at(e); - - prefix_(op, tx->buffer(), tx->shapeInfo(), tz->buffer(), tz->shapeInfo(), exclusive, reverse); - } - }; - - template - static void prefix_(scalar::Ops op, const NDArray* x, NDArray* z, bool exclusive, bool reverse) { - prefix_(op, x->buffer(), x->shapeInfo(), z->buffer(), z->shapeInfo(), exclusive, reverse); - }; - - void prefix(sd::LaunchContext * context, scalar::Ops op, const NDArray* x, NDArray* z, bool exclusive, bool reverse) { - BUILD_SINGLE_SELECTOR(x->dataType(), prefix_, (op, x, z, exclusive, reverse), LIBND4J_TYPES); - } - - void prefix(sd::LaunchContext * context, scalar::Ops op, const NDArray* x, NDArray* z, const std::vector& dims, bool exclusive, bool reverse) { - BUILD_SINGLE_SELECTOR(x->dataType(), prefix_, (op, x, z, dims, exclusive, reverse), LIBND4J_TYPES); - } - - BUILD_SINGLE_TEMPLATE(template void prefix_, (scalar::Ops op, const void* vx, Nd4jLong const* xShapeInfo, void* vz, Nd4jLong const* zShapeInfo, bool exclusive, bool reverse), LIBND4J_TYPES); - BUILD_SINGLE_TEMPLATE(template void prefix_, (scalar::Ops op, const NDArray* x, NDArray* z, const std::vector& dims, bool exclusive, bool reverse), LIBND4J_TYPES); - BUILD_SINGLE_TEMPLATE(template void prefix_, (scalar::Ops op, const NDArray* x, NDArray* z, bool exclusive, bool reverse), LIBND4J_TYPES); - - - - } +namespace ops { +namespace helpers { +template +static void prefix_(scalar::Ops op, const void* vx, Nd4jLong const* xShapeInfo, + void* vz, Nd4jLong const* zShapeInfo, bool exclusive, + bool reverse) { + const auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + auto length = shape::length(xShapeInfo); + + T prevSum = op == scalar::Add ? (T)0 : (T)1; + T sum = prevSum; + + if (reverse) { + if (shape::elementWiseStride(xShapeInfo) == 1 && + shape::elementWiseStride(zShapeInfo) == 1 && + shape::order(xShapeInfo) == 'c' && shape::order(zShapeInfo) == 'c') { + for (Nd4jLong e = length - 1; e >= 0; --e) { + sum = op == scalar::Add ? simdOps::Add::op(sum, x[e]) + : simdOps::Multiply::op(sum, x[e]); + if (!exclusive) prevSum = sum; + + z[e] = prevSum; + + prevSum = sum; + } + } else { + for (Nd4jLong e = length - 1; e >= 0; --e) { + auto xOffset = shape::getIndexOffset(e, xShapeInfo); + auto zOffset = shape::getIndexOffset(e, zShapeInfo); + sum = op == scalar::Add + ? simdOps::Add::op(sum, x[xOffset]) + : simdOps::Multiply::op(sum, x[xOffset]); + + if (!exclusive) prevSum = sum; + + z[zOffset] = prevSum; + prevSum = sum; + } + } + } else { + if (shape::elementWiseStride(xShapeInfo) == 1 && + shape::elementWiseStride(zShapeInfo) == 1 && + shape::order(xShapeInfo) == 'c' && shape::order(zShapeInfo) == 'c') { + for (Nd4jLong e = 0; e < length; e++) { + sum = op == scalar::Add ? simdOps::Add::op(sum, x[e]) + : simdOps::Multiply::op(sum, x[e]); + + if (!exclusive) prevSum = sum; + + z[e] = prevSum; + + prevSum = sum; + } + } else { + for (Nd4jLong e = 0; e < length; e++) { + auto xOffset = shape::getIndexOffset(e, xShapeInfo); + auto zOffset = shape::getIndexOffset(e, zShapeInfo); + sum = op == scalar::Add + ? simdOps::Add::op(sum, x[xOffset]) + : simdOps::Multiply::op(sum, x[xOffset]); + + if (!exclusive) prevSum = sum; + + z[zOffset] = prevSum; + prevSum = sum; + } } -} \ No newline at end of file + } +}; + +template +static void prefix_(scalar::Ops op, const NDArray* x, NDArray* z, + const std::vector& dims, bool exclusive, + bool reverse) { + auto xTads = x->allTensorsAlongDimension(dims); + auto zTads = z->allTensorsAlongDimension(dims); + auto t = xTads.size(); + + for (int e = 0; e < t; e++) { + auto tx = xTads.at(e); + auto tz = zTads.at(e); + + prefix_(op, tx.buffer(), tx.shapeInfo(), tz.buffer(), tz.shapeInfo(), + exclusive, reverse); + } +}; + +template +static void prefix_(scalar::Ops op, const NDArray* x, NDArray* z, + bool exclusive, bool reverse) { + prefix_(op, x->buffer(), x->shapeInfo(), z->buffer(), z->shapeInfo(), + exclusive, reverse); +}; + +void prefix(sd::LaunchContext* context, scalar::Ops op, const NDArray* x, + NDArray* z, bool exclusive, bool reverse) { + BUILD_SINGLE_SELECTOR(x->dataType(), prefix_, (op, x, z, exclusive, reverse), + LIBND4J_TYPES); +} + +void prefix(sd::LaunchContext* context, scalar::Ops op, const NDArray* x, + NDArray* z, const std::vector& dims, bool exclusive, + bool reverse) { + BUILD_SINGLE_SELECTOR(x->dataType(), prefix_, + (op, x, z, dims, exclusive, reverse), LIBND4J_TYPES); +} + +BUILD_SINGLE_TEMPLATE(template void prefix_, + (scalar::Ops op, const void* vx, + Nd4jLong const* xShapeInfo, void* vz, + Nd4jLong const* zShapeInfo, bool exclusive, + bool reverse), + LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template void prefix_, + (scalar::Ops op, const NDArray* x, NDArray* z, + const std::vector& dims, bool exclusive, + bool reverse), + LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template void prefix_, + (scalar::Ops op, const NDArray* x, NDArray* z, + bool exclusive, bool reverse), + LIBND4J_TYPES); + +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/print_variable.cpp b/libnd4j/include/ops/declarable/helpers/cpu/print_variable.cpp index 26a24a5afcc2..7790f2a6f156 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/print_variable.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/print_variable.cpp @@ -21,11 +21,12 @@ #include namespace sd { - namespace ops { - namespace helpers { - void print_special(LaunchContext &ctx, const NDArray &array, const std::string &message) { - array.printIndexedBuffer(message.c_str()); - } - } - } +namespace ops { +namespace helpers { +void print_special(LaunchContext &ctx, const NDArray &array, + const std::string &message) { + array.printIndexedBuffer(message.c_str()); } +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/qr.cpp b/libnd4j/include/ops/declarable/helpers/cpu/qr.cpp index 1f980e553697..45e29354ec65 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/qr.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/qr.cpp @@ -1,5 +1,5 @@ /******************************************************************************* - * Copyright (c) 2019-2020 Konduit K.K. + * Copyright (c) 2019-2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -17,117 +17,132 @@ // // @author George A. Shulinok // -#include -#include -#include #include +#include +#include +#include namespace sd { namespace ops { namespace helpers { - template - NDArray matrixMinor(NDArray& in, Nd4jLong col) { - NDArray m = in.ulike(); - m.setIdentity(); - m({col, m.rows(), col, m.columns()}).assign(in({col, m.rows(), col, m.columns()})); +template +NDArray matrixMinor(NDArray& in, Nd4jLong col) { + NDArray m = in.ulike(); + m.setIdentity(); + m({col, m.rows(), col, m.columns()}) + .assign(in({col, m.rows(), col, m.columns()})); - return m; - } + return m; +} /* m = I - v v^T */ - template - NDArray vmul(NDArray const& v, int n) - { - NDArray res('c', {n,n}, v.dataType(), v.getContext()); // x = matrix_new(n, n); - T const* vBuf = v.getDataBuffer()->primaryAsT(); - T* resBuf = res.dataBuffer()->primaryAsT(); - auto interloop = PRAGMA_THREADS_FOR_2D { - for (auto i = start_x; i < n; i += inc_x) - for (auto j = start_y; j < n; j += inc_y) - resBuf[i * n + j] = -2 * vBuf[i] * vBuf[j] + (i == j ? T(1) : T(0)); - }; - - samediff::Threads::parallel_for(interloop, 0, n, 1, 0, n, 1); - return res; - } - - template - void qrSingle(NDArray* matrix, NDArray* Q, NDArray* R, bool const fullMatricies) { - Nd4jLong M = matrix->sizeAt(-2); - Nd4jLong N = matrix->sizeAt(-1); - auto resQ = fullMatricies?Q->ulike():NDArrayFactory::create(matrix->ordering(), {M,M}, Q->getContext()); - auto resR = fullMatricies?R->ulike():matrix->ulike(); - std::vector q(M); - - NDArray z = *matrix; - NDArray e('c', {M}, DataTypeUtils::fromT(), Q->getContext()); // two internal buffers and scalar for squared norm - - for (Nd4jLong k = 0; k < N && k < M - 1; k++) { // loop for columns, but not further then row number - e.nullify(); - z = matrixMinor(z, k); // minor computing for current column with given matrix z (initally is a input matrix) -// z.printIndexedBuffer("Minor!!!"); - - auto currentColumn = z({0, 0, k, k + 1}); // retrieve k column from z to x buffer - auto norm = currentColumn.reduceAlongDimension(reduce::Norm2, {0}); - if (matrix->t(k,k) > T(0.f)) // negate on positive matrix diagonal element - norm *= T(-1.f);//.applyTransform(transform::Neg, nullptr, nullptr); //t(0) = -norm.t(0); - //e.t(k) = T(1.f); // e - is filled by 0 vector except diagonal element (filled by 1) - //auto tE = e; - //tE *= norm; -// norm.printIndexedBuffer("Norm!!!"); - e.p(k, norm); - e += currentColumn;// e += tE; // e[i] = x[i] + a * e[i] for each i from 0 to n - 1 - auto normE = e.reduceAlongDimension(reduce::Norm2, {0}); - e /= normE; - q[k] = vmul(e, M); - auto qQ = z.ulike(); - MmulHelper::matmul(&q[k], &z, &qQ, false, false); - z = std::move(qQ); - } - resQ.assign(q[0]); // -// MmulHelper::matmul(&q[0], matrix, &resR, false, false); - for (Nd4jLong i = 1; i < N && i < M - 1; i++) { - auto tempResQ = resQ; - MmulHelper::matmul(&q[i], &resQ, &tempResQ, false, false); // use mmulMxM? - resQ = std::move(tempResQ); - } - MmulHelper::matmul(&resQ, matrix, &resR, false, false); - // resR *= -1.f; - resQ.transposei(); - if (fullMatricies) { - Q->assign(resQ); - R->assign(resR); - } - else { - Q->assign(resQ({0,0, 0, N})); - R->assign(resR({0,N, 0, 0})); - } - } - - template - void qr_(NDArray const* input, NDArray* outputQ, NDArray* outputR, bool const fullMatricies) { - Nd4jLong lastDim = input->rankOf() - 1; - Nd4jLong preLastDim = input->rankOf() - 2; - ResultSet listOutQ(outputQ->allTensorsAlongDimension({(int)preLastDim, (int)lastDim})); - ResultSet listOutR(outputR->allTensorsAlongDimension({(int)preLastDim, (int)lastDim})); - ResultSet listInput(input->allTensorsAlongDimension({(int)preLastDim, (int)lastDim})); - auto batching = PRAGMA_THREADS_FOR { - for (auto batch = start; batch < stop; batch++) { - //qr here - qrSingle(listInput.at(batch), listOutQ.at(batch), listOutR.at(batch), fullMatricies); - } - }; - - samediff::Threads::parallel_tad(batching, 0, listOutQ.size(), 1); +template +NDArray vmul(NDArray const& v, int n) { + NDArray res('c', {n, n}, v.dataType(), + v.getContext()); // x = matrix_new(n, n); + T const* vBuf = v.getDataBuffer()->primaryAsT(); + T* resBuf = res.dataBuffer()->primaryAsT(); + auto interloop = PRAGMA_THREADS_FOR_2D { + for (auto i = start_x; i < n; i += inc_x) + for (auto j = start_y; j < n; j += inc_y) + resBuf[i * n + j] = -2 * vBuf[i] * vBuf[j] + (i == j ? T(1) : T(0)); + }; + + samediff::Threads::parallel_for(interloop, 0, n, 1, 0, n, 1); + return res; +} - } +template +void qrSingle(NDArray* matrix, NDArray* Q, NDArray* R, + bool const fullMatricies) { + Nd4jLong M = matrix->sizeAt(-2); + Nd4jLong N = matrix->sizeAt(-1); + auto resQ = fullMatricies ? Q->ulike() + : NDArrayFactory::create( + matrix->ordering(), {M, M}, Q->getContext()); + auto resR = fullMatricies ? R->ulike() : matrix->ulike(); + std::vector q(M); + + NDArray z = matrix->dup(); + NDArray e( + 'c', {M}, DataTypeUtils::fromT(), + Q->getContext()); // two internal buffers and scalar for squared norm + + for (Nd4jLong k = 0; k < N && k < M - 1; + k++) { // loop for columns, but not further then row number + e.nullify(); + z = matrixMinor(z, k); // minor computing for current column with given + // matrix z (initally is a input matrix) + // z.printIndexedBuffer("Minor!!!"); + + auto currentColumn = + z({0, 0, k, k + 1}); // retrieve k column from z to x buffer + auto norm = currentColumn.reduceAlongDimension(reduce::Norm2, {0}); + if (matrix->t(k, k) > + T(0.f)) // negate on positive matrix diagonal element + norm *= T(-1.f); //.applyTransform(transform::Neg, nullptr, nullptr); + ////t(0) = -norm.t(0); + // e.t(k) = T(1.f); // e - is filled by 0 vector except diagonal element + // (filled by 1) auto tE = e; tE *= norm; + // norm.printIndexedBuffer("Norm!!!"); + e.p(k, norm); + e += currentColumn; // e += tE; // e[i] = x[i] + a * e[i] for each i from + // 0 to n - 1 + auto normE = e.reduceAlongDimension(reduce::Norm2, {0}); + e /= normE; + q[k] = vmul(e, M); + auto qQ = z.ulike(); + MmulHelper::matmul(&q[k], &z, &qQ, false, false); + z = std::move(qQ); + } + resQ.assign(q[0]); // + // MmulHelper::matmul(&q[0], matrix, &resR, false, false); + for (Nd4jLong i = 1; i < N && i < M - 1; i++) { + auto tempResQ = resQ.ulike(); + MmulHelper::matmul(&q[i], &resQ, &tempResQ, false, false); // use mmulMxM? + resQ = std::move(tempResQ); + } + MmulHelper::matmul(&resQ, matrix, &resR, false, false); + // resR *= -1.f; + resQ.transposei(); + if (fullMatricies) { + Q->assign(resQ); + R->assign(resR); + } else { + Q->assign(resQ({0, 0, 0, N})); + R->assign(resR({0, N, 0, 0})); + } +} - void qr(sd::LaunchContext* context, NDArray const* input, NDArray* outputQ, NDArray* outputR, bool const fullMatricies) { - BUILD_SINGLE_SELECTOR(input->dataType(), qr_, (input, outputQ, outputR, fullMatricies), FLOAT_TYPES); +template +void qr_(NDArray const* input, NDArray* outputQ, NDArray* outputR, + bool const fullMatricies) { + Nd4jLong lastDim = input->rankOf() - 1; + Nd4jLong preLastDim = input->rankOf() - 2; + ResultSet listOutQ( + outputQ->allTensorsAlongDimension({(int)preLastDim, (int)lastDim})); + ResultSet listOutR( + outputR->allTensorsAlongDimension({(int)preLastDim, (int)lastDim})); + ResultSet listInput( + input->allTensorsAlongDimension({(int)preLastDim, (int)lastDim})); + auto batching = PRAGMA_THREADS_FOR { + for (auto batch = start; batch < stop; batch++) { + // qr here + qrSingle(&listInput.at(batch), &listOutQ.at(batch), + &listOutR.at(batch), fullMatricies); } + }; + samediff::Threads::parallel_tad(batching, 0, listOutQ.size(), 1); } -} + +void qr(sd::LaunchContext* context, NDArray const* input, NDArray* outputQ, + NDArray* outputR, bool const fullMatricies) { + BUILD_SINGLE_SELECTOR(input->dataType(), qr_, + (input, outputQ, outputR, fullMatricies), FLOAT_TYPES); } +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/random.cpp b/libnd4j/include/ops/declarable/helpers/cpu/random.cpp index d96b3017568b..77fc5045152c 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/random.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/random.cpp @@ -22,16 +22,16 @@ //#include #include //#include -#include -#include #include #include +#include +#include namespace sd { namespace ops { namespace helpers { - /** +/** * gammaLess - compute gamma distributed value for shapes (alpha) from 0 to 1 * @tparam T - any float types are acceptable * @param rng - random generator for uniformly vals @@ -110,181 +110,197 @@ namespace helpers { break; } return (decreasedAlpha * normalizedVar / beta); - } - - template - void fillRandomGamma_(LaunchContext* context, graph::RandomGenerator& rng, NDArray* alpha, NDArray* beta, NDArray* output) { - - auto broadcasted = alpha->shapeInfo(); - if (beta != nullptr) { - const Nd4jLong* broadcastedShape = nullptr; - ShapeUtils::evalBroadcastShapeInfo(*alpha, *beta, true, broadcastedShape, context->getWorkspace()); - broadcasted = broadcastedShape; - } - - auto step = shape::length(broadcasted); - auto shift = output->lengthOf() / step; - - auto copyAlpha = alpha; - auto copyBeta = beta; - if (beta != nullptr) { - NDArray alphaBroadcasted(broadcasted, alpha->dataType(), false, context); - NDArray betaBroadcasted(broadcasted, beta->dataType(), false, context); - - copyAlpha = new NDArray(alphaBroadcasted.applyTrueBroadcast(BroadcastOpsTuple::Assign(), *alpha)); - copyBeta = new NDArray(betaBroadcasted.applyTrueBroadcast(BroadcastOpsTuple::Assign(), *beta)); - } - bool directOutput = output->ews() == 1 && output->ordering() == 'c'; - T* outputBuf = output->dataBuffer()->primaryAsT(); - - PRAGMA_OMP_PARALLEL_FOR - for (Nd4jLong k = 0; k < shift; k++) { - auto pos = k * step; - for (Nd4jLong e = 0; e < step; e++) - if (directOutput) { - outputBuf[pos + e] = copyAlpha->t(e) <= 1? gammaLess(rng, copyAlpha->t(e), beta?copyBeta->t(e):T(1.f)):gammaGreat(rng, copyAlpha->t(e), beta?copyBeta->t(e):T(1.f)); - } - else { - output->r(pos + e) = copyAlpha->t(e) <= 1? gammaLess(rng, copyAlpha->t(e), beta?copyBeta->t(e):T(1.f)):gammaGreat(rng, copyAlpha->t(e), beta?copyBeta->t(e):T(1.f)); - } - } + }template +void fillRandomGamma_(LaunchContext* context, graph::RandomGenerator& rng, + NDArray* alpha, NDArray* beta, NDArray* output) { + auto broadcasted = alpha->shapeInfo(); + if (beta != nullptr) { + const Nd4jLong* broadcastedShape = nullptr; + ShapeUtils::evalBroadcastShapeInfo(*alpha, *beta, true, broadcastedShape, + context->getWorkspace()); + broadcasted = broadcastedShape; + } + + auto step = shape::length(broadcasted); + auto shift = output->lengthOf() / step; + + auto copyAlpha = alpha; + auto copyBeta = beta; + if (beta != nullptr) { + NDArray alphaBroadcasted(broadcasted, alpha->dataType(), false, context); + NDArray betaBroadcasted(broadcasted, beta->dataType(), false, context); + + copyAlpha = new NDArray(alphaBroadcasted.applyTrueBroadcast( + BroadcastOpsTuple::Assign(), *alpha)); + copyBeta = new NDArray( + betaBroadcasted.applyTrueBroadcast(BroadcastOpsTuple::Assign(), *beta));} + + bool directOutput = output->ews() == 1 && output->ordering() == 'c'; + T* outputBuf = output->dataBuffer()->primaryAsT(); + + PRAGMA_OMP_PARALLEL_FOR + for (Nd4jLong k = 0; k < shift; k++) { + auto pos = k * step; + + for (Nd4jLong e = 0; e < step; e++) + if (directOutput) { + outputBuf[pos + e] = copyAlpha->t(e) <= 1? gammaLess(rng, copyAlpha->t(e), beta?copyBeta->t(e):T( + 1.f)):gammaGreat(rng,copyAlpha->t(e), beta ?copyBeta->t(e):T(1.f)); + } else { + output->r(pos + e) = copyAlpha->t(e) <= 1? gammaLess(rng, copyAlpha->t(e), beta?copyBeta->t(e):T( + 1.f)):gammaGreat(rng,copyAlpha->t(e), beta ?copyBeta->t(e):T(1.f)); + } + } + + if (beta != nullptr) { + delete copyAlpha; + delete copyBeta; + // delete broadcasted; + } +} - if (beta != nullptr) { - delete copyAlpha; - delete copyBeta; - //delete broadcasted; - } +void fillRandomGamma(LaunchContext* context, graph::RandomGenerator& rng, + NDArray* alpha, NDArray* beta, NDArray* output) { + BUILD_SINGLE_SELECTOR(output->dataType(), fillRandomGamma_, + (context, rng, alpha, beta, output), FLOAT_NATIVE); +} +BUILD_SINGLE_TEMPLATE(template void fillRandomGamma_, + (LaunchContext * context, graph::RandomGenerator& rng, + NDArray* alpha, NDArray* beta, NDArray* output), + FLOAT_NATIVE); + +/* + * algorithm Poisson generator based upon the inversion by sequential +search:[48]:505 init: Let x ← 0, p ← e−λ, s ← p. Generate uniform random number +u in [0,1]. while u > s do: x ← x + 1. p ← p * λ / x. s ← s + p. return x. + * */ +template +void fillRandomPoisson_(LaunchContext* context, graph::RandomGenerator& rng, + NDArray* lambda, NDArray* output) { + auto shift = output->lengthOf() / lambda->lengthOf(); + auto step = lambda->lengthOf(); + T* lambdaBuf = lambda->dataBuffer()->primaryAsT(); + T* outputBuf = output->dataBuffer()->primaryAsT(); + bool directLa = lambda->ews() == 1 && lambda->ordering() == 'c'; + bool directOut = output->ews() == 1 && output->ordering() == 'c'; + PRAGMA_OMP_PARALLEL_FOR + for (Nd4jLong k = 0; k < shift; k++) { + auto pos = k * step; + auto u = rng.relativeT(k, 0., 1.); + for (Nd4jLong e = 0; e < step; e++) { + auto p = math::nd4j_exp(-lambda->t(e)); + auto s = p; + auto x = T(0.f); + while (u > s) { + x += 1.f; + p *= directLa ? lambdaBuf[e] / x : lambda->t(e) / x; + s += p; + } + if (directOut) + outputBuf[pos + e] = x; + else + output->r(pos + e) = x; } + } +} - void fillRandomGamma(LaunchContext* context, graph::RandomGenerator& rng, NDArray* alpha, NDArray* beta, NDArray* output) { - BUILD_SINGLE_SELECTOR(output->dataType(), fillRandomGamma_, (context, rng, alpha, beta, output), FLOAT_NATIVE); - } - BUILD_SINGLE_TEMPLATE(template void fillRandomGamma_, (LaunchContext* context, - graph::RandomGenerator& rng, NDArray* alpha, NDArray* beta, NDArray* output), FLOAT_NATIVE); - - /* - * algorithm Poisson generator based upon the inversion by sequential search:[48]:505 - init: - Let x ← 0, p ← e−λ, s ← p. - Generate uniform random number u in [0,1]. - while u > s do: - x ← x + 1. - p ← p * λ / x. - s ← s + p. - return x. - * */ - template - void fillRandomPoisson_(LaunchContext* context, graph::RandomGenerator& rng, NDArray* lambda, NDArray* output) { - auto shift = output->lengthOf() / lambda->lengthOf(); - auto step = lambda->lengthOf(); - T* lambdaBuf = lambda->dataBuffer()->primaryAsT(); - T* outputBuf = output->dataBuffer()->primaryAsT(); - bool directLa = lambda->ews() == 1 && lambda->ordering() == 'c'; - bool directOut = output->ews() == 1 && output->ordering() == 'c'; - PRAGMA_OMP_PARALLEL_FOR - for (Nd4jLong k = 0; k < shift; k++) { - auto pos = k * step; - auto u = rng.relativeT(k, 0., 1.); - for (Nd4jLong e = 0; e < step; e++) { - auto p = math::nd4j_exp(-lambda->t(e)); - auto s = p; - auto x = T(0.f); - while (u > s) { - x += 1.f; - p *= directLa?lambdaBuf[e]/x:lambda->t(e) / x; - s += p; - } - if (directOut) - outputBuf[pos + e] = x; - else - output->r(pos + e) = x; - } - } +void fillRandomPoisson(LaunchContext* context, graph::RandomGenerator& rng, + NDArray* lambda, NDArray* output) { + BUILD_SINGLE_SELECTOR(output->dataType(), fillRandomPoisson_, + (context, rng, lambda, output), FLOAT_NATIVE); +} +BUILD_SINGLE_TEMPLATE(template void fillRandomPoisson_, + (LaunchContext * context, graph::RandomGenerator& rng, + NDArray* lambda, NDArray* output), + FLOAT_TYPES); + +template +void fillRandomUniform_(LaunchContext* context, graph::RandomGenerator& rng, + NDArray* min, NDArray* max, NDArray* output) { + T minVal = T(0); + T maxVal = DataTypeUtils::max(); + if (min) minVal = min->t(0); + if (max) maxVal = max->t(0); + + if (output->isR()) + RandomLauncher::fillUniform(context, rng, output, minVal, maxVal); + else { + PRAGMA_OMP_PARALLEL_FOR + for (Nd4jLong i = 0; i < output->lengthOf(); i++) { + output->r(i) = rng.relativeT(i, minVal, maxVal); } + } +} - void fillRandomPoisson(LaunchContext* context, graph::RandomGenerator& rng, NDArray* lambda, NDArray* output) { - BUILD_SINGLE_SELECTOR(output->dataType(), fillRandomPoisson_, (context, rng, lambda, output), FLOAT_NATIVE); - } - BUILD_SINGLE_TEMPLATE(template void fillRandomPoisson_, (LaunchContext* context, - graph::RandomGenerator& rng, NDArray* lambda, NDArray* output), FLOAT_TYPES); +void fillRandomUniform(LaunchContext* context, graph::RandomGenerator& rng, + NDArray* min, NDArray* max, NDArray* output) { + BUILD_SINGLE_SELECTOR(output->dataType(), fillRandomUniform_, + (context, rng, min, max, output), NUMERIC_TYPES); +} - template - void fillRandomUniform_(LaunchContext* context, graph::RandomGenerator& rng, NDArray* min, NDArray* max, NDArray* output) { - T minVal = T(0); - T maxVal = DataTypeUtils::max(); - if (min) - minVal = min->t(0); - if (max) - maxVal = max->t(0); - - if (output->isR()) - RandomLauncher::fillUniform(context, rng, output, minVal, maxVal); - else { - PRAGMA_OMP_PARALLEL_FOR - for (Nd4jLong i = 0; i < output->lengthOf(); i++) { - output->r(i) = rng.relativeT(i, minVal, maxVal); - } +// used https://en.wikipedia.org/wiki/Categorical_distribution +// methods: gumbel trick + softmax + argmax +template +void fillRandomMultiNomial_(LaunchContext* context, graph::RandomGenerator& rng, + NDArray& input, NDArray& output, + const Nd4jLong numOfSamples, const int dimC) { + const Tx* x = input.bufferAsT(); + Tz* z = output.bufferAsT(); + + Tx minVal = DataTypeUtils::min(); + Tx maxVal = 1.0; + + auto dimA = (0 == dimC) ? 1 : 0; + const Nd4jLong batchValue = output.sizeAt(dimC); + const Nd4jLong numOfClassX = input.sizeAt(dimA); + + const Nd4jLong zDimAstride = output.stridesOf()[dimA]; + const Nd4jLong xDimAstride = input.stridesOf()[dimA]; + const Nd4jLong zDimCstride = output.stridesOf()[dimC]; + const Nd4jLong xDimCstride = input.stridesOf()[dimC]; + + auto func = PRAGMA_THREADS_FOR_2D { + for (auto nBatchIndex = start_x; nBatchIndex < stop_x; + nBatchIndex += inc_x) { + for (auto nSampleIndexInBatch = start_y; nSampleIndexInBatch < stop_y; + nSampleIndexInBatch += inc_y) { + const Tx* xTad = x + (nBatchIndex * xDimCstride); + Tz* zTad = z + (nBatchIndex * zDimCstride); + Tz& arg = zTad[nSampleIndexInBatch * zDimAstride]; + Tx Max = -minVal; + + auto nSamplesPerBatch = nBatchIndex * numOfClassX * numOfSamples; + auto nClassesPerSample = nSampleIndexInBatch * numOfClassX; + for (Nd4jLong nClass = 0; nClass < numOfClassX; nClass += 1) { + auto nIndex = nSamplesPerBatch + nClassesPerSample + nClass; + auto unifornLog = + sd::math::nd4j_log(-sd::math::nd4j_log( + rng.relativeT(nIndex, minVal, maxVal))); + Tx tValue = (xTad[nClass * xDimAstride] - unifornLog); + if (tValue > Max) { + Max = tValue; + arg = nClass; + } } + } } + }; - void fillRandomUniform(LaunchContext* context, graph::RandomGenerator& rng, NDArray* min, NDArray* max, NDArray* output) { - BUILD_SINGLE_SELECTOR(output->dataType(), fillRandomUniform_, (context, rng, min, max, output), NUMERIC_TYPES); - } - - // used https://en.wikipedia.org/wiki/Categorical_distribution - // methods: gumbel trick + softmax + argmax - template - void fillRandomMultiNomial_(LaunchContext* context, graph::RandomGenerator& rng, NDArray& input, NDArray& output, const Nd4jLong numOfSamples, const int dimC) { - - const Tx* x = input.bufferAsT(); - Tz* z = output.bufferAsT(); - - Tx minVal = DataTypeUtils::min(); - Tx maxVal = 1.0; - - auto dimA = (0 == dimC) ? 1 : 0; - const Nd4jLong batchValue = output.sizeAt(dimC); - const Nd4jLong numOfClassX = input.sizeAt(dimA); - - const Nd4jLong zDimAstride = output.stridesOf()[dimA]; - const Nd4jLong xDimAstride = input.stridesOf()[dimA]; - const Nd4jLong zDimCstride = output.stridesOf()[dimC]; - const Nd4jLong xDimCstride = input.stridesOf()[dimC]; - - auto func = PRAGMA_THREADS_FOR_2D{ - for (auto nBatchIndex = start_x; nBatchIndex < stop_x; nBatchIndex += inc_x) { - for (auto nSampleIndexInBatch = start_y; nSampleIndexInBatch < stop_y; nSampleIndexInBatch += inc_y) { - - const Tx* xTad = x + (nBatchIndex * xDimCstride); - Tz* zTad = z + (nBatchIndex * zDimCstride); - Tz& arg = zTad[nSampleIndexInBatch * zDimAstride]; - Tx Max = -minVal; - - auto nSamplesPerBatch = nBatchIndex * numOfClassX * numOfSamples; - auto nClassesPerSample = nSampleIndexInBatch * numOfClassX; - for (Nd4jLong nClass = 0; nClass < numOfClassX; nClass += 1) { - auto nIndex = nSamplesPerBatch + nClassesPerSample + nClass; - auto unifornLog = sd::math::nd4j_log(-sd::math::nd4j_log(rng.relativeT(nIndex, minVal, maxVal))); - Tx tValue = (xTad[nClass * xDimAstride] - unifornLog); - if (tValue > Max) { - Max = tValue; - arg = nClass; - } - } - } - } - }; - - samediff::Threads::parallel_for(func, 0, batchValue, 1, 0, numOfSamples, 1); - rng.rewindH(output.lengthOf()*numOfClassX); - - return; - } - - void fillRandomMultiNomial(LaunchContext* context, graph::RandomGenerator& rng, NDArray& input, NDArray& output, const Nd4jLong numOfSamples, const int dimC) { - BUILD_DOUBLE_SELECTOR(input.dataType(), output.dataType(), fillRandomMultiNomial_, (context, rng, input, output, numOfSamples, dimC), FLOAT_TYPES, INDEXING_TYPES); - } + samediff::Threads::parallel_for(func, 0, batchValue, 1, 0, numOfSamples, 1); + rng.rewindH(output.lengthOf() * numOfClassX); + return; } + +void fillRandomMultiNomial(LaunchContext* context, graph::RandomGenerator& rng, + NDArray& input, NDArray& output, + const Nd4jLong numOfSamples, const int dimC) { + BUILD_DOUBLE_SELECTOR(input.dataType(), output.dataType(), + fillRandomMultiNomial_, + (context, rng, input, output, numOfSamples, dimC), + FLOAT_TYPES, INDEXING_TYPES); } -} + +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/randomShuffle.cpp b/libnd4j/include/ops/declarable/helpers/cpu/randomShuffle.cpp index 2ffbfc95f1c5..15dbc2fa0fb8 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/randomShuffle.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/randomShuffle.cpp @@ -19,16 +19,15 @@ // implementation is based on following article: // "MergeShuffle: A Very Fast, Parallel Random Permutation Algorithm", https://arxiv.org/abs/1508.03167 - - -#include -#include #include -#include +#include #include +#include -namespace sd { -namespace ops { +#include + +namespace sd { +namespace ops { namespace helpers { ////////////////////////////////////////////////////////////////////////// @@ -114,7 +113,7 @@ static void randomShuffle_(NDArray& input, NDArray& output, sd::graph::RandomGen } }; - auto funcMerge = PRAGMA_THREADS_FOR { + auto funcMerge = PRAGMA_THREADS_FOR { for (int64_t i = start, k = 1; i < stop; i += increment, ++k) { Nd4jLong offset = len * i >> power; @@ -161,7 +160,7 @@ static void randomShuffle_(NDArray& input, NDArray& output, sd::graph::RandomGen for(int i = firstDim - 1; i > 0; --i) { const int j = rng.relativeInt(i) % (i + 1); if(i != j) - subArrsList.at(i)->swapUnsafe(*subArrsList.at(j)); + subArrsList.at(i).swapUnsafe(subArrsList.at(j)); } } else { @@ -178,7 +177,7 @@ static void randomShuffle_(NDArray& input, NDArray& output, sd::graph::RandomGen auto func = PRAGMA_THREADS_FOR { for (auto i = start; i < stop; ++i) - subArrsListOut.at(i)->assign(subArrsListIn.at(indices[i])); + subArrsListOut.at(i).assign(subArrsListIn.at(indices[i])); }; samediff::Threads::parallel_for(func, 0, firstDim); @@ -188,11 +187,12 @@ static void randomShuffle_(NDArray& input, NDArray& output, sd::graph::RandomGen } } -void randomShuffle(sd::LaunchContext * context, NDArray& input, NDArray& output, sd::graph::RandomGenerator& rng, const bool isInplace) { - BUILD_SINGLE_SELECTOR(input.dataType(), randomShuffle_, (input, output, rng, isInplace), LIBND4J_TYPES); -} - -} -} +void randomShuffle(sd::LaunchContext* context, NDArray& input, NDArray& output, + sd::graph::RandomGenerator& rng, const bool isInplace) { + BUILD_SINGLE_SELECTOR(input.dataType(), randomShuffle_, + (input, output, rng, isInplace), LIBND4J_TYPES); } +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/random_crop.cpp b/libnd4j/include/ops/declarable/helpers/cpu/random_crop.cpp index 34be299b7619..7083b9f21f96 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/random_crop.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/random_crop.cpp @@ -20,50 +20,65 @@ #include //#include -#include -#include #include + +#include +#include namespace sd { namespace ops { namespace helpers { - template - static int _randomCropFunctor(graph::Context& context, NDArray* input, NDArray* shape, NDArray* output, int seed) { - graph::RandomGenerator rngX(context.getRng()); - //functions::random::RandomFunction::template execTransform>(rng, output->buffer(), output->shapeInfo(), std::vector({T(0.), shape->e(last)}).data()); - //NativeOpExecutioner::execRandom(random::UniformDistribution, rng, output->buffer(), output->shapeInfo(), std::vector({T(0.), shape->e(last)}).data()); - Nd4jLong last = shape->lengthOf() - 1; +template +static int _randomCropFunctor(graph::Context& context, NDArray* input, + NDArray* shape, NDArray* output, int seed) { + graph::RandomGenerator rngX(context.getRng()); + // functions::random::RandomFunction::template + // execTransform>(rng, output->buffer(), + // output->shapeInfo(), std::vector({T(0.), shape->e(last)}).data()); + // NativeOpExecutioner::execRandom(random::UniformDistribution, rng, + // output->buffer(), output->shapeInfo(), std::vector({T(0.), + // shape->e(last)}).data()); + Nd4jLong last = shape->lengthOf() - 1; - rngX.setSeed(seed); - //functions::random::RandomFunction::template execTransform>(rng, output->buffer(), output->shapeInfo(), std::vector({T(0.), shape->getScalar(last)}).data()); - for (Nd4jLong e = 0; e < output->lengthOf(); ++e) { - output->p(e, rngX.relativeT(e, 0, shape->e(last))); - } - Nd4jLong maxIndex = output->argMax(); - Nd4jLong startPos = output->e(maxIndex); - Nd4jLong lastDim = input->sizeAt(-1); - // nd4j_printf("Before processing: %i %i. Output length %i\n", maxIndex, startPos, output->lengthOf()); - Nd4jLong pos = 0; - Nd4jLong width = startPos + shape->e(last); - if (width >= lastDim) { - startPos -= (width - lastDim); - width = lastDim; - } + rngX.setSeed(seed); + // functions::random::RandomFunction::template + // execTransform>(rng, output->buffer(), + // output->shapeInfo(), std::vector({T(0.), + // shape->getScalar(last)}).data()); + for (Nd4jLong e = 0; e < output->lengthOf(); ++e) { + output->p(e, rngX.relativeT(e, 0, shape->e(last))); + } + Nd4jLong maxIndex = output->argMax(); + Nd4jLong startPos = output->e(maxIndex); + Nd4jLong lastDim = input->sizeAt(-1); + // nd4j_printf("Before processing: %i %i. Output length %i\n", maxIndex, + // startPos, output->lengthOf()); + Nd4jLong pos = 0; + Nd4jLong width = startPos + shape->e(last); + if (width >= lastDim) { + startPos -= (width - lastDim); + width = lastDim; + } - for (Nd4jLong i = 0; i < input->lengthOf(); i += lastDim) { - for (Nd4jLong k = startPos; k < width && pos < output->lengthOf(); k++) { - output->p(pos++, input->e(i + k)); - } - } - return ND4J_STATUS_OK; + for (Nd4jLong i = 0; i < input->lengthOf(); i += lastDim) { + for (Nd4jLong k = startPos; k < width && pos < output->lengthOf(); k++) { + output->p(pos++, input->e(i + k)); } + } + return ND4J_STATUS_OK; +} - int randomCropFunctor(graph::Context& context, NDArray* input, NDArray* shape, NDArray* output, int seed) { - BUILD_SINGLE_SELECTOR(input->dataType(), return _randomCropFunctor, (context, input, shape, output, seed), FLOAT_TYPES); - } +int randomCropFunctor(graph::Context& context, NDArray* input, NDArray* shape, + NDArray* output, int seed) { + BUILD_SINGLE_SELECTOR(input->dataType(), return _randomCropFunctor, + (context, input, shape, output, seed), FLOAT_TYPES); +} - BUILD_SINGLE_TEMPLATE(template int _randomCropFunctor, (graph::Context& context, NDArray* input, NDArray* shape, NDArray* output, int seed), FLOAT_TYPES); +BUILD_SINGLE_TEMPLATE(template int _randomCropFunctor, + (graph::Context & context, NDArray* input, NDArray* shape, + NDArray* output, int seed), + FLOAT_TYPES); -} -} -} \ No newline at end of file +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/range.cpp b/libnd4j/include/ops/declarable/helpers/cpu/range.cpp index eb2cbd76047f..9a82e461a6f9 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/range.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/range.cpp @@ -18,40 +18,41 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 27.08.2018 // - -#include #include +#include namespace sd { namespace ops { namespace helpers { - ////////////////////////////////////////////////////////////////////////// // be careful: outVector must have c-order and ews = 1 !!! template -static void _range(const NDArray& start, const NDArray& delta, NDArray& outVector) { - - const Nd4jLong len = outVector.lengthOf(); - - auto buff = reinterpret_cast(outVector.buffer()); - auto s = start.e(0); - auto d = delta.e(0); - - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) - buff[i] = s + i * d; - }; - samediff::Threads::parallel_for(func, 0, len); +static void _range(const NDArray& start, const NDArray& delta, + NDArray& outVector) { + const Nd4jLong len = outVector.lengthOf(); + + auto buff = reinterpret_cast(outVector.buffer()); + auto s = start.e(0); + auto d = delta.e(0); + + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) buff[i] = s + i * d; + }; + samediff::Threads::parallel_for(func, 0, len); } - void range(sd::LaunchContext * context, const NDArray& start, const NDArray& delta, NDArray& outVector) { - BUILD_SINGLE_SELECTOR(outVector.dataType(), _range, (start, delta, outVector), LIBND4J_TYPES); - } - -BUILD_SINGLE_TEMPLATE(template void _range, (const NDArray& start, const NDArray& delta, NDArray& outVector), LIBND4J_TYPES); +void range(sd::LaunchContext* context, const NDArray& start, + const NDArray& delta, NDArray& outVector) { + BUILD_SINGLE_SELECTOR(outVector.dataType(), _range, (start, delta, outVector), + LIBND4J_TYPES); +} +BUILD_SINGLE_TEMPLATE(template void _range, + (const NDArray& start, const NDArray& delta, + NDArray& outVector), + LIBND4J_TYPES); -} -} -} \ No newline at end of file +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/reverse.cpp b/libnd4j/include/ops/declarable/helpers/cpu/reverse.cpp index bc072682ab35..3dfd04d03a21 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/reverse.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/reverse.cpp @@ -18,200 +18,205 @@ // @author Yurii Shyrma, created on 16.04.2018 // -#include -#include #include #include +#include +#include - -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { template inline void swap(T* arr, Nd4jLong from, Nd4jLong to) { - T tmp = arr[from]; - arr[from] = arr[to]; - arr[to] = tmp; + T tmp = arr[from]; + arr[from] = arr[to]; + arr[to] = tmp; } ///////////////////////////////////////////////////////////////////////////////////// // this legacy op is written by raver119@gmail.com -template -static void reverseArray(sd::LaunchContext * context, void const* vinArr, Nd4jLong const*inShapeBuffer, void *voutArr, Nd4jLong const*outShapeBuffer, int numOfElemsToReverse = 0) { - auto inArr = reinterpret_cast(vinArr); - auto outArr = reinterpret_cast(voutArr); - - Nd4jLong inLength = shape::length(inShapeBuffer); - Nd4jLong outLength = shape::length(outShapeBuffer); - if(numOfElemsToReverse == 0) - numOfElemsToReverse = inLength; - int inEWS = shape::elementWiseStride(inShapeBuffer); - char inOrder = shape::order(inShapeBuffer); - auto sLength = numOfElemsToReverse - 1; - - // two step phase here - if (inArr == outArr) { - if (inEWS == 1) { - auto func = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) { - auto idx = sLength - e; - swap(const_cast(inArr), e, idx); - } - }; - samediff::Threads::parallel_for(func, 0, numOfElemsToReverse / 2); - } - else if (inEWS > 1) { - auto func = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) { - auto idx1 = (sLength - e) * inEWS; - Nd4jLong idx2 = e * inEWS; - swap(const_cast(inArr), idx1, idx2); - } - }; - - samediff::Threads::parallel_for(func, 0, numOfElemsToReverse / 2); - } - else { - - auto func = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) { - auto inOffset = shape::getIndexOffset(e, inShapeBuffer); - auto outOffset = shape::getIndexOffset(sLength - e, inShapeBuffer); - swap(outArr, inOffset, outOffset); - } - }; - - samediff::Threads::parallel_for(func, 0, numOfElemsToReverse / 2); - } - } - else { - // single step phase here - auto outEWS = shape::elementWiseStride(outShapeBuffer); - char outOrder = shape::order(outShapeBuffer); - - if (inEWS == 1 && outEWS == 1 && inOrder == outOrder) { - - auto func = PRAGMA_THREADS_FOR { - for (Nd4jLong e = start; e < stop; e++) - outArr[sLength - e] = inArr[e]; - }; - samediff::Threads::parallel_for(func, 0, numOfElemsToReverse); - - if(inLength != numOfElemsToReverse) { - auto f2 = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) - outArr[e] = inArr[e]; - }; - samediff::Threads::parallel_for(f2, numOfElemsToReverse, inLength); - } - } - else if (inEWS >= 1 && outEWS >= 1 && inOrder == outOrder) { - - auto func = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) - outArr[(sLength - e) * outEWS] = inArr[e * inEWS]; - }; - samediff::Threads::parallel_for(func, 0, numOfElemsToReverse); - - if(inLength != numOfElemsToReverse) { - auto f2 = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) - outArr[e * outEWS] = inArr[e * inEWS]; - }; - samediff::Threads::parallel_for(f2, numOfElemsToReverse, inLength); - } - } - else { - - auto func = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) { - auto inOffset = shape::getIndexOffset(e, inShapeBuffer); - auto outOffset = shape::getIndexOffset(sLength - e, outShapeBuffer); - outArr[outOffset] = inArr[inOffset]; - } - }; - samediff::Threads::parallel_for(func, 0, numOfElemsToReverse); - - if(inLength != numOfElemsToReverse) { - - auto f2 = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) { - auto inOffset = shape::getIndexOffset(e, inShapeBuffer); - auto outOffset = shape::getIndexOffset(e, outShapeBuffer); - outArr[outOffset] = inArr[inOffset]; - } - }; - samediff::Threads::parallel_for(f2, numOfElemsToReverse, inLength); - } - } - } -} - - -/////////////////////////////////////////////////////////////////// template -static void reverseSequence_(sd::LaunchContext * context, const NDArray* input, const NDArray* seqLengths, NDArray* output, int seqDim, const int batchDim){ - - int posOfNonUnityDim = -1; - if(input->isVector() || shape::isLikeVector(input->shapeInfo(), posOfNonUnityDim)) { +static void reverseArray(sd::LaunchContext* context, void const* vinArr, + Nd4jLong const* inShapeBuffer, void* voutArr, + Nd4jLong const* outShapeBuffer, + int numOfElemsToReverse = 0) { + auto inArr = reinterpret_cast(vinArr); + auto outArr = reinterpret_cast(voutArr); + + Nd4jLong inLength = shape::length(inShapeBuffer); + Nd4jLong outLength = shape::length(outShapeBuffer); + if (numOfElemsToReverse == 0) numOfElemsToReverse = inLength; + int inEWS = shape::elementWiseStride(inShapeBuffer); + char inOrder = shape::order(inShapeBuffer); + auto sLength = numOfElemsToReverse - 1; + + // two step phase here + if (inArr == outArr) { + if (inEWS == 1) { + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + auto idx = sLength - e; + swap(const_cast(inArr), e, idx); + } + }; + samediff::Threads::parallel_for(func, 0, numOfElemsToReverse / 2); + } else if (inEWS > 1) { + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + auto idx1 = (sLength - e) * inEWS; + Nd4jLong idx2 = e * inEWS; + swap(const_cast(inArr), idx1, idx2); + } + }; + + samediff::Threads::parallel_for(func, 0, numOfElemsToReverse / 2); + } else { + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + auto inOffset = shape::getIndexOffset(e, inShapeBuffer); + auto outOffset = shape::getIndexOffset(sLength - e, inShapeBuffer); + swap(outArr, inOffset, outOffset); + } + }; - if((seqDim == 0 && input->sizeAt(0) == 1) || (batchDim == posOfNonUnityDim)) - output->assign(input); - else - helpers::reverseArray(context, const_cast(input)->buffer(), const_cast(input)->shapeInfo(), output->buffer(), output->shapeInfo(), seqLengths->e(0)); + samediff::Threads::parallel_for(func, 0, numOfElemsToReverse / 2); } - else { - - if(seqDim > batchDim) - --seqDim; - - std::vector dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {batchDim}); - - auto inSubArrsSet = input->allTensorsAlongDimension(dimensions); - auto outSubArrsSet = output->allTensorsAlongDimension(dimensions); - - for(int i = 0; i < inSubArrsSet.size(); ++i) { - - Nd4jLong numOfElemsToReverse = seqLengths->e(i); - - if(numOfElemsToReverse == 0 || numOfElemsToReverse == 1) { - outSubArrsSet.at(i)->assign(inSubArrsSet.at(i)); - } - else { - auto inInnerSet = inSubArrsSet.at(i)->allTensorsAlongDimension({seqDim}); - auto outInnerSet = outSubArrsSet.at(i)->allTensorsAlongDimension({seqDim}); - for(int j = 0; j < inInnerSet.size(); ++j) - helpers::reverseArray(context, inInnerSet.at(j)->buffer(), inInnerSet.at(j)->shapeInfo(), outInnerSet.at(j)->buffer(), outInnerSet.at(j)->shapeInfo(), numOfElemsToReverse); - } + } else { + // single step phase here + auto outEWS = shape::elementWiseStride(outShapeBuffer); + char outOrder = shape::order(outShapeBuffer); + + if (inEWS == 1 && outEWS == 1 && inOrder == outOrder) { + auto func = PRAGMA_THREADS_FOR { + for (Nd4jLong e = start; e < stop; e++) outArr[sLength - e] = inArr[e]; + }; + samediff::Threads::parallel_for(func, 0, numOfElemsToReverse); + + if (inLength != numOfElemsToReverse) { + auto f2 = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) outArr[e] = inArr[e]; + }; + samediff::Threads::parallel_for(f2, numOfElemsToReverse, inLength); + } + } else if (inEWS >= 1 && outEWS >= 1 && inOrder == outOrder) { + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) + outArr[(sLength - e) * outEWS] = inArr[e * inEWS]; + }; + samediff::Threads::parallel_for(func, 0, numOfElemsToReverse); + + if (inLength != numOfElemsToReverse) { + auto f2 = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) + outArr[e * outEWS] = inArr[e * inEWS]; + }; + samediff::Threads::parallel_for(f2, numOfElemsToReverse, inLength); + } + } else { + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + auto inOffset = shape::getIndexOffset(e, inShapeBuffer); + auto outOffset = shape::getIndexOffset(sLength - e, outShapeBuffer); + outArr[outOffset] = inArr[inOffset]; } + }; + samediff::Threads::parallel_for(func, 0, numOfElemsToReverse); + + if (inLength != numOfElemsToReverse) { + auto f2 = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + auto inOffset = shape::getIndexOffset(e, inShapeBuffer); + auto outOffset = shape::getIndexOffset(e, outShapeBuffer); + outArr[outOffset] = inArr[inOffset]; + } + }; + samediff::Threads::parallel_for(f2, numOfElemsToReverse, inLength); + } } + } } - void reverseSequence(sd::LaunchContext * context, const NDArray* input, const NDArray* seqLengths, NDArray* output, int seqDim, const int batchDim) { - BUILD_SINGLE_SELECTOR(input->dataType(), reverseSequence_, (context, input, seqLengths, output, seqDim, batchDim), LIBND4J_TYPES); - } - -////////////////////////////////////////////////////////////////////////// -void reverse(sd::LaunchContext * context, const NDArray* input, NDArray* output, const std::vector* intArgs) { - - auto listOut = output->allTensorsAlongDimension(*intArgs); - auto listIn = input->allTensorsAlongDimension(*intArgs); - - NDArray *subArrIn, *subArrOut; - - for(int i = 0; i < listIn.size(); ++i) { // listIn.size() = listOut.size() - subArrIn = listIn.at(i); - subArrOut = listOut.at(i); - BUILD_SINGLE_SELECTOR(input->dataType(), helpers::reverseArray, (context, subArrIn->buffer(), subArrIn->shapeInfo(), subArrOut->buffer(), subArrOut->shapeInfo()), LIBND4J_TYPES); +/////////////////////////////////////////////////////////////////// +template +static void reverseSequence_(sd::LaunchContext* context, const NDArray* input, + const NDArray* seqLengths, NDArray* output, + int seqDim, const int batchDim) { + int posOfNonUnityDim = -1; + if (input->isVector() || + shape::isLikeVector(input->shapeInfo(), posOfNonUnityDim)) { + if ((seqDim == 0 && input->sizeAt(0) == 1) || + (batchDim == posOfNonUnityDim)) + output->assign(input); + else + helpers::reverseArray(context, const_cast(input)->buffer(), + const_cast(input)->shapeInfo(), + output->buffer(), output->shapeInfo(), + seqLengths->e(0)); + } else { + if (seqDim > batchDim) --seqDim; + + std::vector dimensions = + ShapeUtils::evalDimsToExclude(input->rankOf(), {batchDim}); + + auto inSubArrsSet = input->allTensorsAlongDimension(dimensions); + auto outSubArrsSet = output->allTensorsAlongDimension(dimensions); + + for (int i = 0; i < inSubArrsSet.size(); ++i) { + Nd4jLong numOfElemsToReverse = seqLengths->e(i); + + if (numOfElemsToReverse == 0 || numOfElemsToReverse == 1) { + outSubArrsSet.at(i).assign(inSubArrsSet.at(i)); + } else { + auto inInnerSet = inSubArrsSet.at(i).allTensorsAlongDimension({seqDim}); + auto outInnerSet = + outSubArrsSet.at(i).allTensorsAlongDimension({seqDim}); + for (int j = 0; j < inInnerSet.size(); ++j) + helpers::reverseArray( + context, inInnerSet.at(j).buffer(), inInnerSet.at(j).shapeInfo(), + outInnerSet.at(j).buffer(), outInnerSet.at(j).shapeInfo(), + numOfElemsToReverse); + } } + } } -BUILD_SINGLE_TEMPLATE(template void reverseSequence_, (sd::LaunchContext * context, const NDArray* input, const NDArray* seqLengths, NDArray* output, int seqDim, const int batchDim), LIBND4J_TYPES); -BUILD_SINGLE_TEMPLATE(template void reverseArray, (sd::LaunchContext * context, void const*inArr, Nd4jLong const*inShapeBuffer, void* outArr, Nd4jLong const* outShapeBuffer, int numOfElemsToReverse), LIBND4J_TYPES); - - -} +void reverseSequence(sd::LaunchContext* context, const NDArray* input, + const NDArray* seqLengths, NDArray* output, int seqDim, + const int batchDim) { + BUILD_SINGLE_SELECTOR(input->dataType(), reverseSequence_, + (context, input, seqLengths, output, seqDim, batchDim), + LIBND4J_TYPES); } + +////////////////////////////////////////////////////////////////////////// +void reverse(sd::LaunchContext* context, const NDArray* input, NDArray* output, + const std::vector* intArgs) { + + auto listOut = output->allTensorsAlongDimension(*intArgs); + auto listIn = input->allTensorsAlongDimension(*intArgs); + + for (int i = 0; i < listIn.size(); ++i) { // listIn.size() = listOut.size() + auto subArrIn = listIn.at(i); + auto subArrOut = listOut.at(i); + BUILD_SINGLE_SELECTOR(input->dataType(), helpers::reverseArray, + (context, subArrIn.buffer(), subArrIn.shapeInfo(), + subArrOut.buffer(), subArrOut.shapeInfo()), + LIBND4J_TYPES); + } } +BUILD_SINGLE_TEMPLATE(template void reverseSequence_, + (sd::LaunchContext * context, const NDArray* input, + const NDArray* seqLengths, NDArray* output, int seqDim, + const int batchDim), + LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template void reverseArray, + (sd::LaunchContext * context, void const* inArr, + Nd4jLong const* inShapeBuffer, void* outArr, + Nd4jLong const* outShapeBuffer, int numOfElemsToReverse), + LIBND4J_TYPES); + +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/roll.cpp b/libnd4j/include/ops/declarable/helpers/cpu/roll.cpp index 2e3d983cd306..ae5616010edd 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/roll.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/roll.cpp @@ -24,138 +24,143 @@ namespace sd { namespace ops { namespace helpers { - template - static void rollFunctorLinear_(NDArray* input, NDArray* output, int shift, bool inplace){ - auto source = input; - if (!inplace) - output->assign(input); - - int fullLen = source->lengthOf(); - int actualShift = shift; // % fullLen; // shift already non-negative then - if (actualShift < 0) { - actualShift -= fullLen * (actualShift / fullLen - 1); - } - else - actualShift %= fullLen; - - if (actualShift) { - int shiftCount = fullLen / actualShift - 1; - int remainShift = fullLen % actualShift; - - // stage 1) swap last actualShift elements with first ones. - //PRAGMA_OMP_PARALLEL_FOR //_IF(actualShift > Environment::getInstance().elementwiseThreshold()) - for (int e = 0; e < actualShift; ++e) { - int sourceIndex = fullLen - actualShift + e; - - auto _e0 = output->e(e); - auto _e1 = output->e(sourceIndex); - - //sd::math::nd4j_swap((*output)(e), (*output)(sourceIndex)); - output->p(e, _e1); - output->p(sourceIndex, _e0); - } - - // stage 2) swap swapped actualShift elements with rest remainShiftCount times. - //PRAGMA_OMP_PARALLEL_FOR //_IF(shiftCount > Environment::getInstance().tadThreshold()) - for (int count = 1; count < shiftCount; ++count) { - for (int e = 0; e < actualShift; ++e) { - int destinationIndex = fullLen - (count + 1) * actualShift + e; - int sourceIndex = fullLen - count * actualShift + e; - - auto _e0 = output->e(destinationIndex); - auto _e1 = output->e(sourceIndex); - - //sd::math::nd4j_swap((*output)(destinationIndex), (*output)(sourceIndex)); - output->p(destinationIndex, _e1); - output->p(sourceIndex, _e0); - } - } - - // stage 3) swap remainer of items. - if (remainShift && shiftCount) - for (int i = actualShift; i < 2 * actualShift; ++i) { - auto _e0 = output->e(i); - auto _e1 = output->e(i + remainShift); - - //sd::math::nd4j_swap((*output)(i), (*output)(i + remainShift)); - - output->p(i, _e1); - output->p(i + remainShift, _e0); - } - } +template +static void rollFunctorLinear_(NDArray* input, NDArray* output, int shift, + bool inplace) { + auto source = input; + if (!inplace) output->assign(input); + + int fullLen = source->lengthOf(); + int actualShift = shift; // % fullLen; // shift already non-negative then + if (actualShift < 0) { + actualShift -= fullLen * (actualShift / fullLen - 1); + } else + actualShift %= fullLen; + + if (actualShift) { + int shiftCount = fullLen / actualShift - 1; + int remainShift = fullLen % actualShift; + + // stage 1) swap last actualShift elements with first ones. + // PRAGMA_OMP_PARALLEL_FOR //_IF(actualShift > + // Environment::getInstance().elementwiseThreshold()) + for (int e = 0; e < actualShift; ++e) { + int sourceIndex = fullLen - actualShift + e; + + auto _e0 = output->e(e); + auto _e1 = output->e(sourceIndex); + + // sd::math::nd4j_swap((*output)(e), (*output)(sourceIndex)); + output->p(e, _e1); + output->p(sourceIndex, _e0); } - void rollFunctorFull(sd::LaunchContext * context, NDArray* input, NDArray* output, std::vector const& shifts, std::vector const& axes, bool inplace){ - - if (!inplace) - output->assign(input); - - auto source = output; //input; - for (size_t i = 0; i < axes.size(); i++) { - int axe = axes[i]; - if (axe == source->rankOf() - 1) {// last dimension - ResultSet listOfTensors = source->allTensorsAlongDimension({axe}); - ResultSet listOfOutTensors = output->allTensorsAlongDimension({axe}); - int fullLen = listOfTensors.size(); - int theShift = shifts[i]; - if (theShift > 0) { - theShift %= fullLen; - } - else { - theShift -= fullLen * (theShift / fullLen - 1); - } - for (int k = 0; k < fullLen; k++) { - rollFunctorLinear(context, listOfTensors.at(k), listOfOutTensors.at(k), theShift, true); - } - } - else { - std::vector dims(source->rankOf() - axe - 1); - for (size_t i = 0; i < dims.size(); ++i) - dims[i] = axe + 1 + i; - - ResultSet listOfTensors = source->allTensorsAlongDimension({dims}); - ResultSet listOfOutTensors = output->allTensorsAlongDimension({dims}); - // - int fullLen = listOfTensors.size(); - int sizeAt = input->sizeAt(axe); - - int theShift = shifts[i]; - - if (theShift > 0) { - theShift %= sizeAt; - } - else { - theShift -= sizeAt * (theShift / sizeAt - 1); - } - - if (theShift) { - for (int dim = 0; dim < fullLen / sizeAt; ++dim) { - for (int e = theShift; e < sizeAt - theShift; ++e) { - auto sourceM = listOfTensors.at(dim * sizeAt + e - theShift); - auto targetM = listOfOutTensors.at(dim * sizeAt + e); - sourceM->swapUnsafe(*targetM); - } - - for (int e = 0; e < theShift; ++e) { - int sourceIndex = dim * sizeAt + sizeAt - theShift + e; - auto sourceM = listOfTensors.at(sourceIndex); - auto targetM = listOfOutTensors.at(dim * sizeAt + e); - - sourceM->swapUnsafe(*targetM); - } - } - } - } -// if (!inplace) -// source = output; - } + // stage 2) swap swapped actualShift elements with rest remainShiftCount + // times. + // PRAGMA_OMP_PARALLEL_FOR //_IF(shiftCount > + // Environment::getInstance().tadThreshold()) + for (int count = 1; count < shiftCount; ++count) { + for (int e = 0; e < actualShift; ++e) { + int destinationIndex = fullLen - (count + 1) * actualShift + e; + int sourceIndex = fullLen - count * actualShift + e; + + auto _e0 = output->e(destinationIndex); + auto _e1 = output->e(sourceIndex); + + // sd::math::nd4j_swap((*output)(destinationIndex), + // (*output)(sourceIndex)); + output->p(destinationIndex, _e1); + output->p(sourceIndex, _e0); + } } - void rollFunctorLinear(sd::LaunchContext * context, NDArray* input, NDArray* output, int shift, bool inplace){ - BUILD_SINGLE_SELECTOR(input->dataType(), rollFunctorLinear_, (input, output, shift, inplace), LIBND4J_TYPES); - } + // stage 3) swap remainer of items. + if (remainShift && shiftCount) + for (int i = actualShift; i < 2 * actualShift; ++i) { + auto _e0 = output->e(i); + auto _e1 = output->e(i + remainShift); + + // sd::math::nd4j_swap((*output)(i), (*output)(i + remainShift)); - BUILD_SINGLE_TEMPLATE(template void rollFunctorLinear_, (NDArray* input, NDArray* output, int shift, bool inplace), LIBND4J_TYPES); + output->p(i, _e1); + output->p(i + remainShift, _e0); + } + } } + +void rollFunctorFull(sd::LaunchContext* context, NDArray* input, + NDArray* output, std::vector const& shifts, + std::vector const& axes, bool inplace) { + if (!inplace) output->assign(input); + + auto source = output; // input; + for (size_t i = 0; i < axes.size(); i++) { + int axe = axes[i]; + if (axe == source->rankOf() - 1) { // last dimension + ResultSet listOfTensors = source->allTensorsAlongDimension({axe}); + ResultSet listOfOutTensors = output->allTensorsAlongDimension({axe}); + int fullLen = listOfTensors.size(); + int theShift = shifts[i]; + if (theShift > 0) { + theShift %= fullLen; + } else { + theShift -= fullLen * (theShift / fullLen - 1); + } + for (int k = 0; k < fullLen; k++) { + rollFunctorLinear(context, &listOfTensors.at(k), + &listOfOutTensors.at(k), theShift, true); + } + } else { + std::vector dims(source->rankOf() - axe - 1); + for (size_t i = 0; i < dims.size(); ++i) dims[i] = axe + 1 + i; + + ResultSet listOfTensors = source->allTensorsAlongDimension({dims}); + ResultSet listOfOutTensors = output->allTensorsAlongDimension({dims}); + // + int fullLen = listOfTensors.size(); + int sizeAt = input->sizeAt(axe); + + int theShift = shifts[i]; + + if (theShift > 0) { + theShift %= sizeAt; + } else { + theShift -= sizeAt * (theShift / sizeAt - 1); + } + + if (theShift) { + for (int dim = 0; dim < fullLen / sizeAt; ++dim) { + for (int e = theShift; e < sizeAt - theShift; ++e) { + auto sourceM = listOfTensors.at(dim * sizeAt + e - theShift); + auto targetM = listOfOutTensors.at(dim * sizeAt + e); + sourceM.swapUnsafe(targetM); + } + + for (int e = 0; e < theShift; ++e) { + int sourceIndex = dim * sizeAt + sizeAt - theShift + e; + auto sourceM = listOfTensors.at(sourceIndex); + auto targetM = listOfOutTensors.at(dim * sizeAt + e); + + sourceM.swapUnsafe(targetM); + } + } + } + } + // if (!inplace) + // source = output; + } } -} \ No newline at end of file + +void rollFunctorLinear(sd::LaunchContext* context, NDArray* input, + NDArray* output, int shift, bool inplace) { + BUILD_SINGLE_SELECTOR(input->dataType(), rollFunctorLinear_, + (input, output, shift, inplace), LIBND4J_TYPES); +} + +BUILD_SINGLE_TEMPLATE(template void rollFunctorLinear_, + (NDArray * input, NDArray* output, int shift, + bool inplace), + LIBND4J_TYPES); +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/s_t_b.cpp b/libnd4j/include/ops/declarable/helpers/cpu/s_t_b.cpp index 99a172c0282d..8b70dad9d863 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/s_t_b.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/s_t_b.cpp @@ -19,399 +19,433 @@ // @author raver119@gmail.com // -#include #include +#include namespace sd { namespace ops { namespace helpers { - ////////////////////////////////////////////////////////////////////////// template -static void batchToSpace_(const NDArray& input, NDArray& output, const uint cropBottom, const uint cropTop, const uint cropLeft, const uint cropRight) { - - // input [bS, H * blockSize, W * blockSize, iC] - // output [bS, H * blockSize - cropBottom - cropTop, W * blockSize - cropLeft - cropRight, iC] - - // if (cropTop = cropBottom = cropRight = cropLeft = 0) shapes are the same - // else: - // oH -> [cropBottom, iH - cropTop] - // oW -> [cropLeft, iH - cropRight] - // xLen > zLen - - const T* x = input.bufferAsT(); - T* z = output.bufferAsT(); +static void batchToSpace_(const NDArray& input, NDArray& output, + const uint cropBottom, const uint cropTop, + const uint cropLeft, const uint cropRight) { + // input [bS, H * blockSize, W * blockSize, iC] + // output [bS, H * blockSize - cropBottom - cropTop, W * blockSize - cropLeft + // - cropRight, iC] + + // if (cropTop = cropBottom = cropRight = cropLeft = 0) shapes are the same + // else: + // oH -> [cropBottom, iH - cropTop] + // oW -> [cropLeft, iH - cropRight] + // xLen > zLen + + const T* x = input.bufferAsT(); + T* z = output.bufferAsT(); + + const int rank = 4; + + const Nd4jLong* xShapeInfo = input.shapeInfo(); + const Nd4jLong* zShapeInfo = output.shapeInfo(); + + const uint bS = xShapeInfo[1]; + const uint iH = xShapeInfo[2]; + const uint iW = xShapeInfo[3]; + const uint iC = xShapeInfo[4]; + + // loop through output array + auto func = PRAGMA_THREADS_FOR_3D { + for (auto b = start_x; b < stop_x; b += inc_x) { + for (auto h = start_y; h < stop_y; h += inc_y) { + for (auto w = start_z; w < stop_z; w += inc_z) { + for (uint c = 0; c < iC; ++c) { + const Nd4jLong xOffset = b * xShapeInfo[5] + h * xShapeInfo[6] + + w * xShapeInfo[7] + c * xShapeInfo[8]; + const Nd4jLong zOffset = + b * zShapeInfo[5] + (h - cropBottom) * zShapeInfo[6] + + (w - cropLeft) * zShapeInfo[7] + c * zShapeInfo[8]; - const int rank = 4; - - const Nd4jLong* xShapeInfo = input.shapeInfo(); - const Nd4jLong* zShapeInfo = output.shapeInfo(); - - const uint bS = xShapeInfo[1]; - const uint iH = xShapeInfo[2]; - const uint iW = xShapeInfo[3]; - const uint iC = xShapeInfo[4]; - - // loop through output array - auto func = PRAGMA_THREADS_FOR_3D { - for (auto b = start_x; b < stop_x; b += inc_x) { - for (auto h = start_y; h < stop_y; h += inc_y) { - for (auto w = start_z; w < stop_z; w += inc_z) { - for (uint c = 0; c < iC; ++c) { - const Nd4jLong xOffset = b * xShapeInfo[5] + h * xShapeInfo[6] + w * xShapeInfo[7] + c * xShapeInfo[8]; - const Nd4jLong zOffset = b * zShapeInfo[5] + (h - cropBottom) * zShapeInfo[6] + (w - cropLeft) * zShapeInfo[7] + c * zShapeInfo[8]; - - z[zOffset] = x[xOffset]; - } - } - } + z[zOffset] = x[xOffset]; + } } - }; + } + } + }; - samediff::Threads::parallel_for(func, 0, bS, 1, cropBottom, iH - cropTop, 1, cropLeft, iW - cropRight, 1); + samediff::Threads::parallel_for(func, 0, bS, 1, cropBottom, iH - cropTop, 1, + cropLeft, iW - cropRight, 1); } -BUILD_SINGLE_TEMPLATE(template void batchToSpace_, (const NDArray& input, NDArray& output, const uint cropBottom, const uint cropTop, const uint cropLeft, const uint cropRight), LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template void batchToSpace_, + (const NDArray& input, NDArray& output, + const uint cropBottom, const uint cropTop, + const uint cropLeft, const uint cropRight), + LIBND4J_TYPES); ////////////////////////////////////////////////////////////////////////// -void batchToSpace(sd::LaunchContext* context, const NDArray& input, NDArray& output, const uint cropBottom, const uint cropTop, const uint cropLeft, const uint cropRight, const uint blockSize) { - - // [bS*blockSize*blockSize, H/blockSize, W/blockSize, iC] is rearranged/permuted to [bS, oH, oW, iC] - // oH = H - cropTop - cropBottom - // oW = W - cropLeft - cropRight - - NDArray inputRearranged0 = input.reshape(input.ordering(), {blockSize, blockSize, output.sizeAt(0), input.sizeAt(1), input.sizeAt(2), input.sizeAt(3)}); - inputRearranged0.permutei({2, 3,0, 4,1, 5}); - - if(input.lengthOf() == output.lengthOf()) - output.assign(inputRearranged0); - else { - NDArray inputRearranged1 = inputRearranged0.reshape(input.ordering(), {output.sizeAt(0), input.sizeAt(1) * blockSize, input.sizeAt(2) * blockSize, input.sizeAt(3)}); - BUILD_SINGLE_SELECTOR(input.dataType(), batchToSpace_, (inputRearranged1, output, cropBottom, cropTop, cropLeft, cropRight), LIBND4J_TYPES); - } +void batchToSpace(sd::LaunchContext* context, const NDArray& input, + NDArray& output, const uint cropBottom, const uint cropTop, + const uint cropLeft, const uint cropRight, + const uint blockSize) { + // [bS*blockSize*blockSize, H/blockSize, W/blockSize, iC] is + // rearranged/permuted to [bS, oH, oW, iC] oH = H - cropTop - cropBottom oW = + // W - cropLeft - cropRight + + NDArray inputRearranged0 = input.reshape( + input.ordering(), {blockSize, blockSize, output.sizeAt(0), + input.sizeAt(1), input.sizeAt(2), input.sizeAt(3)}); + inputRearranged0.permutei({2, 3, 0, 4, 1, 5}); + + if (input.lengthOf() == output.lengthOf()) + output.assign(inputRearranged0); + else { + NDArray inputRearranged1 = inputRearranged0.reshape( + input.ordering(), {output.sizeAt(0), input.sizeAt(1) * blockSize, + input.sizeAt(2) * blockSize, input.sizeAt(3)}); + BUILD_SINGLE_SELECTOR( + input.dataType(), batchToSpace_, + (inputRearranged1, output, cropBottom, cropTop, cropLeft, cropRight), + LIBND4J_TYPES); + } } ////////////////////////////////////////////////////////////////////////// template -static void batchToSpaceND_(const NDArray& input, const NDArray& crop, NDArray& output, const uint numOfSpatialDims) { - - // input [bS, H * blockShape[0], W * blockShape[1], iC] - // output [bS, H * blockShape[0] - cropBottom - cropTop, W * blockShape[1] - cropLeft - cropRight, iC] - - // if (cropTop = cropBottom = cropRight = cropLeft = 0) shapes are the same - // else: - // oH -> [cropBottom, iH - cropTop] - // oW -> [cropLeft, iH - cropRight] - // xLen >= zLen +static void batchToSpaceND_(const NDArray& input, const NDArray& crop, + NDArray& output, const uint numOfSpatialDims) { + // input [bS, H * blockShape[0], W * blockShape[1], iC] + // output [bS, H * blockShape[0] - cropBottom - cropTop, W * blockShape[1] - + // cropLeft - cropRight, iC] - const T* x = input.bufferAsT(); - T* z = output.bufferAsT(); + // if (cropTop = cropBottom = cropRight = cropLeft = 0) shapes are the same + // else: + // oH -> [cropBottom, iH - cropTop] + // oW -> [cropLeft, iH - cropRight] + // xLen >= zLen - const int rank = input.rankOf(); - const Nd4jLong zLen = output.lengthOf(); + const T* x = input.bufferAsT(); + T* z = output.bufferAsT(); - // loop through input array - auto func = PRAGMA_THREADS_FOR { + const int rank = input.rankOf(); + const Nd4jLong zLen = output.lengthOf(); - int zCoords[MAX_RANK], xCoords[MAX_RANK]; + // loop through input array + auto func = PRAGMA_THREADS_FOR { + int zCoords[MAX_RANK], xCoords[MAX_RANK]; - for (auto i = start; i < stop; i++) { + for (auto i = start; i < stop; i++) { + shape::index2coordsCPU(start, i, output.shapeInfo(), zCoords); - shape::index2coordsCPU(start, i, output.shapeInfo(), zCoords); + memcpy(xCoords, zCoords, rank * sizeof(int)); - memcpy(xCoords, zCoords, rank * sizeof(int)); + // evaluate spatial coordinates for x + for (uint j = 1; j <= numOfSpatialDims; ++j) + xCoords[j] += crop.e(j - 1, 0); // add crop left - // evaluate spatial coordinates for x - for (uint j = 1; j <= numOfSpatialDims; ++j) - xCoords[j] += crop.e(j - 1, 0); // add crop left + const auto zOffset = shape::getOffset(output.shapeInfo(), zCoords); + const auto xOffset = shape::getOffset(input.shapeInfo(), xCoords); - const auto zOffset = shape::getOffset(output.shapeInfo(), zCoords); - const auto xOffset = shape::getOffset(input.shapeInfo(), xCoords); - - z[zOffset] = x[xOffset]; - } - }; + z[zOffset] = x[xOffset]; + } + }; - samediff::Threads::parallel_tad(func, 0, zLen); + samediff::Threads::parallel_tad(func, 0, zLen); } -BUILD_SINGLE_TEMPLATE(template void batchToSpaceND_, (const NDArray& input, const NDArray& crop, NDArray& output, const uint numOfSpatialDims), LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template void batchToSpaceND_, + (const NDArray& input, const NDArray& crop, + NDArray& output, const uint numOfSpatialDims), + LIBND4J_TYPES); ////////////////////////////////////////////////////////////////////////// -void batchToSpaceND(sd::LaunchContext* context, const NDArray& input, const NDArray& blockShape, const NDArray& crop, NDArray& output) { +void batchToSpaceND(sd::LaunchContext* context, const NDArray& input, + const NDArray& blockShape, const NDArray& crop, + NDArray& output) { + // 4D example, numOfSpatialDims = 2 - two spatial dimensions + // [bS*blockShape[0]*blockShape[1], iH, iW, iC] is rearranged/permuted to [bS, + // iH*blockShape[0] - cropTop - cropBottom, iW*blockShape[1] - cropLeft - + // cropRight, iC] - // 4D example, numOfSpatialDims = 2 - two spatial dimensions - // [bS*blockShape[0]*blockShape[1], iH, iW, iC] is rearranged/permuted to [bS, iH*blockShape[0] - cropTop - cropBottom, iW*blockShape[1] - cropLeft - cropRight, iC] + const uint rank = input.rankOf(); + const uint numOfSpatialDims = blockShape.sizeAt(0); - const uint rank = input.rankOf(); - const uint numOfSpatialDims = blockShape.sizeAt(0); + //*** construct reshaping std::vector for first reshape of input array ***// - //*** construct reshaping std::vector for first reshape of input array ***// + std::vector temp(numOfSpatialDims + rank); - std::vector temp(numOfSpatialDims + rank); + uint i; + for (i = 0; i < numOfSpatialDims; ++i) temp[i] = blockShape.e(i); + temp[i++] = output.sizeAt(0); + for (uint j = 1; j < rank; ++i, ++j) temp[i] = input.sizeAt(j); - uint i; - for(i = 0; i < numOfSpatialDims; ++i) - temp[i] = blockShape.e(i); - temp[i++] = output.sizeAt(0); - for(uint j = 1; j < rank; ++i, ++j) - temp[i] = input.sizeAt(j); + NDArray inputRearranged0 = input.reshape(input.ordering(), temp); - NDArray inputRearranged0 = input.reshape(input.ordering(), temp); + //*** construct permuting std::vector for permutation of input array ***// - //*** construct permuting std::vector for permutation of input array ***// + temp[0] = numOfSpatialDims; - temp[0] = numOfSpatialDims; + for (i = 1; i <= numOfSpatialDims; ++i) { + temp[2 * i - 1] = numOfSpatialDims + i; + temp[2 * i] = i - 1; + } + for (i = 2 * numOfSpatialDims + 1; i < static_cast(temp.size()); ++i) + temp[i] = i; - for(i = 1; i <= numOfSpatialDims; ++i) { - temp[2*i - 1] = numOfSpatialDims + i; - temp[2*i] = i - 1; - } - for(i = 2 * numOfSpatialDims + 1; i < static_cast(temp.size()); ++i) - temp[i] = i; + inputRearranged0.permutei(temp); - inputRearranged0.permutei(temp); + if (input.lengthOf() == output.lengthOf()) { + output.assign(inputRearranged0); + } else { + //*** construct reshaping std::vector for second reshape of input array + //***// + temp.resize(rank); - if(input.lengthOf() == output.lengthOf()) { - output.assign(inputRearranged0); - } - else { - //*** construct reshaping std::vector for second reshape of input array ***// - - temp.resize(rank); - - temp[0] = output.sizeAt(0); + temp[0] = output.sizeAt(0); - for(i = 1; i < rank; ++i) - temp[i] = (i <= numOfSpatialDims) ? input.sizeAt(i) * blockShape.e(i - 1) : input.sizeAt(i); + for (i = 1; i < rank; ++i) + temp[i] = (i <= numOfSpatialDims) + ? input.sizeAt(i) * blockShape.e(i - 1) + : input.sizeAt(i); - NDArray inputRearranged1 = inputRearranged0.reshape(input.ordering(), temp); + NDArray inputRearranged1 = inputRearranged0.reshape(input.ordering(), temp); - BUILD_SINGLE_SELECTOR(input.dataType(), batchToSpaceND_, (inputRearranged1, crop, output, numOfSpatialDims), LIBND4J_TYPES); - } + BUILD_SINGLE_SELECTOR(input.dataType(), batchToSpaceND_, + (inputRearranged1, crop, output, numOfSpatialDims), + LIBND4J_TYPES); + } } ////////////////////////////////////////////////////////////////////////// template -static void spaceToBatch_(const NDArray& input, NDArray& output, const uint padBottom, const uint padTop, const uint padLeft, const uint padRight) { - - // input [bS, H * blockSize - padBottom - padTop, W * blockSize - padLeft - padRight, iC] - // output [bS, H * blockSize, W * blockSize, iC] - - // if (padTop = padBottom = padRight = padLeft = 0) shapes are the same - // else: - // iH -> [padBottom, oH - padTop] - // iW -> [padLeft, oW - padRight] - // zLen > xLen - - const T* x = input.bufferAsT(); - T* z = output.bufferAsT(); - - const int rank = 4; - - const Nd4jLong* xShapeInfo = input.shapeInfo(); - const Nd4jLong* zShapeInfo = output.shapeInfo(); - - const uint bS = zShapeInfo[1]; - const uint oH = zShapeInfo[2]; - const uint oW = zShapeInfo[3]; - const uint iC = zShapeInfo[4]; - - // loop through output array - auto func = PRAGMA_THREADS_FOR_2D { - for (auto b = start_x; b < stop_x; b += inc_x) { - for (auto h = start_y; h < stop_y; h += inc_y) { - for (uint w = 0; w < oW; ++w) { - for (uint c = 0; c < iC; ++c) { - - const Nd4jLong zOffset = b * zShapeInfo[5] + h * zShapeInfo[6] + w * zShapeInfo[7] + c * zShapeInfo[8]; - - if (h >= padBottom && h < oH - padTop && w >= padLeft && w < oW - padRight) { - const Nd4jLong xOffset = b * xShapeInfo[5] + (h - padBottom) * xShapeInfo[6] + (w - padLeft) * xShapeInfo[7] + c * xShapeInfo[8]; - z[zOffset] = x[xOffset]; - } else - z[zOffset] = 0.f; - } - } - } +static void spaceToBatch_(const NDArray& input, NDArray& output, + const uint padBottom, const uint padTop, + const uint padLeft, const uint padRight) { + // input [bS, H * blockSize - padBottom - padTop, W * blockSize - padLeft - + // padRight, iC] output [bS, H * blockSize, W * blockSize, iC] + + // if (padTop = padBottom = padRight = padLeft = 0) shapes are the same + // else: + // iH -> [padBottom, oH - padTop] + // iW -> [padLeft, oW - padRight] + // zLen > xLen + + const T* x = input.bufferAsT(); + T* z = output.bufferAsT(); + + const int rank = 4; + + const Nd4jLong* xShapeInfo = input.shapeInfo(); + const Nd4jLong* zShapeInfo = output.shapeInfo(); + + const uint bS = zShapeInfo[1]; + const uint oH = zShapeInfo[2]; + const uint oW = zShapeInfo[3]; + const uint iC = zShapeInfo[4]; + + // loop through output array + auto func = PRAGMA_THREADS_FOR_2D { + for (auto b = start_x; b < stop_x; b += inc_x) { + for (auto h = start_y; h < stop_y; h += inc_y) { + for (uint w = 0; w < oW; ++w) { + for (uint c = 0; c < iC; ++c) { + const Nd4jLong zOffset = b * zShapeInfo[5] + h * zShapeInfo[6] + + w * zShapeInfo[7] + c * zShapeInfo[8]; + + if (h >= padBottom && h < oH - padTop && w >= padLeft && + w < oW - padRight) { + const Nd4jLong xOffset = + b * xShapeInfo[5] + (h - padBottom) * xShapeInfo[6] + + (w - padLeft) * xShapeInfo[7] + c * xShapeInfo[8]; + z[zOffset] = x[xOffset]; + } else + z[zOffset] = 0.f; + } } - }; + } + } + }; - samediff::Threads::parallel_for(func, 0, bS, 1, 0, oH, 1); + samediff::Threads::parallel_for(func, 0, bS, 1, 0, oH, 1); } -BUILD_SINGLE_TEMPLATE(template void spaceToBatch_, (const NDArray& input, NDArray& output, const uint padBottom, const uint padTop, const uint padLeft, const uint padRight), LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template void spaceToBatch_, + (const NDArray& input, NDArray& output, + const uint padBottom, const uint padTop, + const uint padLeft, const uint padRight), + LIBND4J_TYPES); ////////////////////////////////////////////////////////////////////////// -void spaceToBatch(sd::LaunchContext* context, const NDArray& input, NDArray& output, const uint padBottom, const uint padTop, const uint padLeft, const uint padRight, const uint blockSize) { - - // [bS, iH, iW, iC] is rearranged/permuted to [bS*blockSize*blockSize, (iH + padBottom + padTop)/blockSize, (iW + padLeft + padRight)/blockSize, iC] - - NDArray outputRearranged0 = output.reshape(output.ordering(), {blockSize, blockSize, input.sizeAt(0), output.sizeAt(1), output.sizeAt(2), output.sizeAt(3)}, false); - outputRearranged0.permutei({2, 3,0, 4,1, 5}); - - if(input.lengthOf() == output.lengthOf()) { - outputRearranged0.assign(input); - } - else { - NDArray outputRearranged1 = outputRearranged0.reshape(output.ordering(), {input.sizeAt(0), output.sizeAt(1) * blockSize, output.sizeAt(2) * blockSize, output.sizeAt(3)}, false); - BUILD_SINGLE_SELECTOR(input.dataType(), spaceToBatch_, (input, outputRearranged1, padBottom, padTop, padLeft, padRight), LIBND4J_TYPES); - - if(output.buffer() != outputRearranged1.buffer()) - outputRearranged0.assign(outputRearranged1); - } +void spaceToBatch(sd::LaunchContext* context, const NDArray& input, + NDArray& output, const uint padBottom, const uint padTop, + const uint padLeft, const uint padRight, + const uint blockSize) { + // [bS, iH, iW, iC] is rearranged/permuted to [bS*blockSize*blockSize, (iH + + // padBottom + padTop)/blockSize, (iW + padLeft + padRight)/blockSize, iC] + + NDArray outputRearranged0 = + output.reshape(output.ordering(), + {blockSize, blockSize, input.sizeAt(0), output.sizeAt(1), + output.sizeAt(2), output.sizeAt(3)}, + false); + outputRearranged0.permutei({2, 3, 0, 4, 1, 5}); + + if (input.lengthOf() == output.lengthOf()) { + outputRearranged0.assign(input); + } else { + NDArray outputRearranged1 = outputRearranged0.reshape( + output.ordering(), + {input.sizeAt(0), output.sizeAt(1) * blockSize, + output.sizeAt(2) * blockSize, output.sizeAt(3)}, + false); + BUILD_SINGLE_SELECTOR( + input.dataType(), spaceToBatch_, + (input, outputRearranged1, padBottom, padTop, padLeft, padRight), + LIBND4J_TYPES); + + if (output.buffer() != outputRearranged1.buffer()) + outputRearranged0.assign(outputRearranged1); + } } - - - - - - - - - - - - - - - - - - ////////////////////////////////////////////////////////////////////////// template -static void spaceToBatchND_(const NDArray& input, const NDArray& padding, NDArray& output, const uint numOfSpatialDims) { - - // 4D example - // input [bS, H * blockShape[0] - padBottom - padTop, W * blockShape[1] - padLeft - padRight, iC] - // output [bS, H * blockShape[0], W * blockShape[1], iC] +static void spaceToBatchND_(const NDArray& input, const NDArray& padding, + NDArray& output, const uint numOfSpatialDims) { + // 4D example + // input [bS, H * blockShape[0] - padBottom - padTop, W * blockShape[1] - + // padLeft - padRight, iC] output [bS, H * blockShape[0], W * blockShape[1], + // iC] - // if (padTop = padBottom = padRight = padLeft = 0) shapes are the same - // else: - // iH -> [padBottom, oH - padTop] - // iW -> [padLeft, oW - padRight] - // zLen > xLen + // if (padTop = padBottom = padRight = padLeft = 0) shapes are the same + // else: + // iH -> [padBottom, oH - padTop] + // iW -> [padLeft, oW - padRight] + // zLen > xLen - const T* x = input.bufferAsT(); - T* z = output.bufferAsT(); + const T* x = input.bufferAsT(); + T* z = output.bufferAsT(); - const int rank = input.rankOf(); - const Nd4jLong zLen = output.lengthOf(); + const int rank = input.rankOf(); + const Nd4jLong zLen = output.lengthOf(); - // loop through output array - auto func = PRAGMA_THREADS_FOR { + // loop through output array + auto func = PRAGMA_THREADS_FOR { + int zCoords[MAX_RANK], xCoords[MAX_RANK]; - int zCoords[MAX_RANK], xCoords[MAX_RANK]; + for (auto i = start; i < stop; i++) { + shape::index2coordsCPU(start, i, output.shapeInfo(), zCoords); - for (auto i = start; i < stop; i++) { + const auto zOffset = shape::getOffset(output.shapeInfo(), zCoords); - shape::index2coordsCPU(start, i, output.shapeInfo(), zCoords); + memcpy(xCoords, zCoords, rank * sizeof(int)); - const auto zOffset = shape::getOffset(output.shapeInfo(), zCoords); + bool within = true; - memcpy(xCoords, zCoords, rank * sizeof(int)); + for (uint j = 1; j <= numOfSpatialDims; ++j) { + const auto padLeft = padding.e(j - 1, 0); + const auto padRight = padding.e(j - 1, 1); - bool within = true; + within &= + zCoords[j] >= padLeft && zCoords[j] < output.sizeAt(j) - padRight; - for (uint j = 1; j <= numOfSpatialDims; ++j) { + if (!within) break; - const auto padLeft = padding.e(j - 1, 0); - const auto padRight = padding.e(j - 1, 1); + xCoords[j] = zCoords[j] - padLeft; // get coordinates for x + } - within &= zCoords[j] >= padLeft && zCoords[j] < output.sizeAt(j) - padRight; - - if (!within) - break; - - xCoords[j] = zCoords[j] - padLeft; // get coordinates for x - } - - if (within) - z[zOffset] = x[shape::getOffset(input.shapeInfo(), xCoords)]; - else - z[zOffset] = 0.f; - } - }; + if (within) + z[zOffset] = x[shape::getOffset(input.shapeInfo(), xCoords)]; + else + z[zOffset] = 0.f; + } + }; - samediff::Threads::parallel_tad(func, 0, zLen); + samediff::Threads::parallel_tad(func, 0, zLen); } -BUILD_SINGLE_TEMPLATE(template void spaceToBatchND_, (const NDArray& input, const NDArray& padding, NDArray& output, const uint numOfSpatialDims), LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template void spaceToBatchND_, + (const NDArray& input, const NDArray& padding, + NDArray& output, const uint numOfSpatialDims), + LIBND4J_TYPES); ////////////////////////////////////////////////////////////////////////// -void spaceToBatchND(sd::LaunchContext* context, const NDArray& input, const NDArray& blockShape, const NDArray& padding, NDArray& output ) { - - // 4D example with two spatial dimensions - // [bS, iH, iW, iC] is rearranged/permuted to [bS*blockShape[0]*blockShape[1], (iH + padBottom + padTop)/blockShape[0], (iW + padLeft + padRight)/blockShape[1], iC] +void spaceToBatchND(sd::LaunchContext* context, const NDArray& input, + const NDArray& blockShape, const NDArray& padding, + NDArray& output) { + // 4D example with two spatial dimensions + // [bS, iH, iW, iC] is rearranged/permuted to [bS*blockShape[0]*blockShape[1], + // (iH + padBottom + padTop)/blockShape[0], (iW + padLeft + + // padRight)/blockShape[1], iC] - const uint rank = input.rankOf(); + const uint rank = input.rankOf(); - const uint numOfSpatialDims = blockShape.sizeAt(0); + const uint numOfSpatialDims = blockShape.sizeAt(0); - //*** construct reshaping std::vector for first reshape of output array ***// - std::vector temp(numOfSpatialDims + rank); + //*** construct reshaping std::vector for first reshape of output array ***// + std::vector temp(numOfSpatialDims + rank); - int i; - for(i = 0; i < numOfSpatialDims; ++i) - temp[i] = blockShape.e(i); - temp[i++] = input.sizeAt(0); - for(int j = 1; j < rank; ++i, ++j) - temp[i] = output.sizeAt(j); + int i; + for (i = 0; i < numOfSpatialDims; ++i) temp[i] = blockShape.e(i); + temp[i++] = input.sizeAt(0); + for (int j = 1; j < rank; ++i, ++j) temp[i] = output.sizeAt(j); - NDArray outputRearranged0 = output.reshape(output.ordering(), temp, false); + NDArray outputRearranged0 = output.reshape(output.ordering(), temp, false); - //*** construct permuting std::vector for permutation of output array ***// + //*** construct permuting std::vector for permutation of output array ***// - temp[0] = numOfSpatialDims; - - for(i = 1; i <= numOfSpatialDims; ++i) { - temp[2*i - 1] = numOfSpatialDims + i; - temp[2*i] = i - 1; - } - for(i = 2 * numOfSpatialDims + 1; i < temp.size(); ++i) - temp[i] = i; + temp[0] = numOfSpatialDims; - outputRearranged0.permutei(temp); + for (i = 1; i <= numOfSpatialDims; ++i) { + temp[2 * i - 1] = numOfSpatialDims + i; + temp[2 * i] = i - 1; + } + for (i = 2 * numOfSpatialDims + 1; i < temp.size(); ++i) temp[i] = i; - // ****** // + outputRearranged0.permutei(temp); - if(input.lengthOf() == output.lengthOf()) { - outputRearranged0.assign(input); - } - else { + // ****** // - //*** construct reshaping std::vector for second reshape of output array ***// - temp.resize(rank); + if (input.lengthOf() == output.lengthOf()) { + outputRearranged0.assign(input); + } else { + //*** construct reshaping std::vector for second reshape of output array + //***// + temp.resize(rank); - temp[0] = input.sizeAt(0); + temp[0] = input.sizeAt(0); - for(i = 1; i < rank; ++i) - temp[i] = (i <= numOfSpatialDims) ? output.sizeAt(i) * blockShape.e(i - 1) : output.sizeAt(i); + for (i = 1; i < rank; ++i) + temp[i] = (i <= numOfSpatialDims) + ? output.sizeAt(i) * blockShape.e(i - 1) + : output.sizeAt(i); - NDArray outputRearranged1 = outputRearranged0.reshape(output.ordering(), temp, false); + NDArray outputRearranged1 = + outputRearranged0.reshape(output.ordering(), temp, false); - BUILD_SINGLE_SELECTOR(input.dataType(), spaceToBatchND_, (input, padding, outputRearranged1, numOfSpatialDims), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR(input.dataType(), spaceToBatchND_, + (input, padding, outputRearranged1, numOfSpatialDims), + LIBND4J_TYPES); - if(output.buffer() != outputRearranged1.buffer()) - outputRearranged0.assign(outputRearranged1); - } + if (output.buffer() != outputRearranged1.buffer()) + outputRearranged0.assign(outputRearranged1); + } } - /* template struct SpaceToBatchHelper { template - static void run(T *ptrSpace, const Nd4jLong *space_shape, const Nd4jLong *space_strides, const Nd4jLong *block_shape, const Nd4jLong *pad_start, const Nd4jLong *block_offsets, T *ptrBatch, const Nd4jLong *batch_shape, const Nd4jLong *batch_strides) { - for (int batch_pos = 0; batch_pos < batch_shape[0]; ++batch_pos) { - const int space_pos = batch_pos * block_shape[0] + block_offsets[0] - pad_start[0]; - if (space_pos >= 0 && space_pos < space_shape[0]) { - SpaceToBatchHelper::run(ptrSpace + space_pos * space_strides[0], space_shape + 1, space_strides + 1, block_shape + 1, pad_start + 1, block_offsets + 1, ptrBatch, batch_shape + 1, batch_strides + 1); - } else { + static void run(T *ptrSpace, const Nd4jLong *space_shape, const Nd4jLong +*space_strides, const Nd4jLong *block_shape, const Nd4jLong *pad_start, const +Nd4jLong *block_offsets, T *ptrBatch, const Nd4jLong *batch_shape, const +Nd4jLong *batch_strides) { for (int batch_pos = 0; batch_pos < batch_shape[0]; +++batch_pos) { const int space_pos = batch_pos * block_shape[0] + +block_offsets[0] - pad_start[0]; if (space_pos >= 0 && space_pos < +space_shape[0]) { SpaceToBatchHelper::run(ptrSpace + space_pos * +space_strides[0], space_shape + 1, space_strides + 1, block_shape + 1, pad_start ++ 1, block_offsets + 1, ptrBatch, batch_shape + 1, batch_strides + 1); } else { if (!B2S) for (int i = 0; i < batch_strides[0]; i++) ptrBatch[i] = (T) 0.f; @@ -425,26 +459,30 @@ void spaceToBatchND(sd::LaunchContext* context, const NDArray& input, const NDAr template struct SpaceToBatchHelper<0, B2S> { template - static void run(T *ptrSpace, const Nd4jLong *space_shape, const Nd4jLong *space_strides, const Nd4jLong *block_shape, const Nd4jLong *pad_start, const Nd4jLong *block_offsets, T *ptrBatch, const Nd4jLong *batch_shape, const Nd4jLong *batch_strides) { - int str = batch_strides[-1]; - for (int i = 0; i < str; i++) - if (B2S) - ptrSpace[i] = ptrBatch[i]; - else - ptrBatch[i] = ptrSpace[i]; + static void run(T *ptrSpace, const Nd4jLong *space_shape, const Nd4jLong +*space_strides, const Nd4jLong *block_shape, const Nd4jLong *pad_start, const +Nd4jLong *block_offsets, T *ptrBatch, const Nd4jLong *batch_shape, const +Nd4jLong *batch_strides) { int str = batch_strides[-1]; for (int i = 0; i < str; +i++) if (B2S) ptrSpace[i] = ptrBatch[i]; else ptrBatch[i] = ptrSpace[i]; } }; template - void _execute(sd::LaunchContext * context, void *vptrSpace, const Nd4jLong *space_shape, const Nd4jLong *space_strides, const Nd4jLong *block_shape, const Nd4jLong *pad_start, const Nd4jLong *block_offsets, void *vptrBatch, const Nd4jLong *batch_shape, const Nd4jLong *batch_strides) { - auto ptrSpace = reinterpret_cast(vptrSpace); - auto ptrBatch = reinterpret_cast(vptrBatch); - SpaceToBatchHelper::run(ptrSpace, space_shape, space_strides, block_shape, pad_start, block_offsets, ptrBatch, batch_shape, batch_strides); + void _execute(sd::LaunchContext * context, void *vptrSpace, const Nd4jLong +*space_shape, const Nd4jLong *space_strides, const Nd4jLong *block_shape, const +Nd4jLong *pad_start, const Nd4jLong *block_offsets, void *vptrBatch, const +Nd4jLong *batch_shape, const Nd4jLong *batch_strides) { auto ptrSpace = +reinterpret_cast(vptrSpace); auto ptrBatch = reinterpret_cast(vptrBatch); SpaceToBatchHelper::run(ptrSpace, +space_shape, space_strides, block_shape, pad_start, block_offsets, ptrBatch, +batch_shape, batch_strides); }; - Nd4jStatus _spaceToBatch(sd::LaunchContext * context, int internal_block_dims, NDArray *input, NDArray *output, std::vector &internal_input_shape, std::vector &internal_output_shape, Nd4jLong *block_shape, Nd4jLong *paddings) { - auto in = input->reshape('c', internal_input_shape); - auto out = output->reshape('c', internal_output_shape); + Nd4jStatus _spaceToBatch(sd::LaunchContext * context, int +internal_block_dims, NDArray *input, NDArray *output, std::vector +&internal_input_shape, std::vector &internal_output_shape, Nd4jLong +*block_shape, Nd4jLong *paddings) { auto in = input->reshape('c', +internal_input_shape); auto out = output->reshape('c', internal_output_shape); switch (internal_block_dims) { case 1: _prepare<1, false>(context, &in, &out, block_shape, paddings); @@ -459,16 +497,19 @@ void spaceToBatchND(sd::LaunchContext* context, const NDArray& input, const NDAr _prepare<4, false>(context, &in, &out, block_shape, paddings); break; default: { - return Status::THROW("SpaceToBatch: Wrong number of internal_block_dims"); + return Status::THROW("SpaceToBatch: Wrong number of +internal_block_dims"); } } return Status::OK(); } - Nd4jStatus _batchToSpace(sd::LaunchContext * context, int internal_block_dims, NDArray *input, NDArray *output, std::vector &internal_input_shape, std::vector &internal_output_shape, Nd4jLong *block_shape, Nd4jLong *crops) { - auto in = input->reshape('c', internal_input_shape); - auto out = output->reshape('c', internal_output_shape); + Nd4jStatus _batchToSpace(sd::LaunchContext * context, int +internal_block_dims, NDArray *input, NDArray *output, std::vector +&internal_input_shape, std::vector &internal_output_shape, Nd4jLong +*block_shape, Nd4jLong *crops) { auto in = input->reshape('c', +internal_input_shape); auto out = output->reshape('c', internal_output_shape); switch (internal_block_dims) { case 1: _prepare<1, true>(context, &in, &out, block_shape, crops); @@ -483,7 +524,8 @@ void spaceToBatchND(sd::LaunchContext* context, const NDArray& input, const NDAr _prepare<4, true>(context, &in, &out, block_shape, crops); break; default: { - return Status::THROW("BatchToSpace: Wrong number of internal_block_dims"); + return Status::THROW("BatchToSpace: Wrong number of +internal_block_dims"); } } @@ -498,12 +540,16 @@ void spaceToBatchND(sd::LaunchContext* context, const NDArray& input, const NDAr #define STB_BOOL (0, false),\ (1, true) - BUILD_TRIPLE_TEMPLATE(template void _execute, (sd::LaunchContext * context, void *ptrSpace, const Nd4jLong *space_shape, const Nd4jLong *space_strides, const Nd4jLong *block_shape, const Nd4jLong *pad_start, const Nd4jLong *block_offsets, void *ptrBatch, const Nd4jLong *batch_shape, const Nd4jLong *batch_strides), LIBND4J_TYPES, STB_DIM, STB_BOOL); + BUILD_TRIPLE_TEMPLATE(template void _execute, (sd::LaunchContext * context, +void *ptrSpace, const Nd4jLong *space_shape, const Nd4jLong *space_strides, +const Nd4jLong *block_shape, const Nd4jLong *pad_start, const Nd4jLong +*block_offsets, void *ptrBatch, const Nd4jLong *batch_shape, const Nd4jLong +*batch_strides), LIBND4J_TYPES, STB_DIM, STB_BOOL); #undef STB_BOOL #undef STB_DIM */ -} -} -} \ No newline at end of file +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/s_t_d.cpp b/libnd4j/include/ops/declarable/helpers/cpu/s_t_d.cpp index b51a4adc9606..dd1e84e4d135 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/s_t_d.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/s_t_d.cpp @@ -18,91 +18,103 @@ // // -#include #include +#include namespace sd { namespace ops { namespace helpers { - template - static void _spaceTodepth_(const NDArray &input, NDArray *output, int block_size, bool isNHWC) { - auto input_ptr = reinterpret_cast(input.buffer()); - auto output_ptr = reinterpret_cast(output->buffer()); - - const int batch_size = input.sizeAt(0); - const int input_depth = isNHWC ? input.sizeAt(3) : input.sizeAt(1); - const int input_height = isNHWC ? input.sizeAt(1) : input.sizeAt(2); - const int input_width = isNHWC ? input.sizeAt(2) : input.sizeAt(3); - - const int output_depth = isNHWC ? output->sizeAt(3) : output->sizeAt(1); - const int output_height = isNHWC ? output->sizeAt(1) : output->sizeAt(2); - const int output_width = isNHWC ? output->sizeAt(2) : output->sizeAt(3); - - const int input_depth_by_output_height = input_depth * output_height; - - const int output_area = output_width * output_height; - const int output_depth_by_output_area = output_depth * output_area; - - - if (isNHWC) { - const int total_count = batch_size * input_height * input_width * input_depth; - - auto func = PRAGMA_THREADS_FOR { - for (auto inp_idx = start; inp_idx < stop; inp_idx++) { - // inp_idx = d + input_depth * (w + input_width * (h + input_height * b)) - const int d = inp_idx % input_depth; - const int inp_idx2 = inp_idx / input_depth; - const int w = inp_idx2 % input_width; - const int inp_idx3 = inp_idx2 / input_width; - const int h = inp_idx3 % input_height; - const int b = inp_idx3 / input_height; - - const int out_h = h / block_size; - const int offset_h = h % block_size; - const int out_w = w / block_size; - const int offset_w = w % block_size; - const int offset_d = (offset_h * block_size + offset_w) * input_depth; - const int out_d = d + offset_d; - - const int out_idx = out_d + output_depth * (out_w + output_width * (out_h + output_height * b)); - *(output_ptr + out_idx) = *(input_ptr + inp_idx); - } - }; - - samediff::Threads::parallel_for(func, 0, total_count); - } else { - const int total_count = batch_size * output_depth_by_output_area; - - auto func = PRAGMA_THREADS_FOR { - for (auto inp_idx = start; inp_idx < stop; inp_idx++) { - const int n_iC_oY_bY_oX = inp_idx / block_size; - const int bX = inp_idx - n_iC_oY_bY_oX * block_size; - - const int n_iC_oY_bY = n_iC_oY_bY_oX / output_width; - const int oX = n_iC_oY_bY_oX - n_iC_oY_bY * output_width; - - const int n_iC_oY = n_iC_oY_bY / block_size; - const int bY = n_iC_oY_bY - n_iC_oY * block_size; - - const int n = n_iC_oY / input_depth_by_output_height; - const int iC_oY = n_iC_oY - n * input_depth_by_output_height; - - const int output_idx = oX + (((n * block_size + bY) * block_size + bX) * input_depth_by_output_height + iC_oY) * output_width; - - *(output_ptr + output_idx) = *(input_ptr + inp_idx); - } - }; - - samediff::Threads::parallel_for(func, 0, total_count); - } - } +template +static void _spaceTodepth_(const NDArray &input, NDArray *output, + int block_size, bool isNHWC) { + auto input_ptr = reinterpret_cast(input.buffer()); + auto output_ptr = reinterpret_cast(output->buffer()); + + const int batch_size = input.sizeAt(0); + const int input_depth = isNHWC ? input.sizeAt(3) : input.sizeAt(1); + const int input_height = isNHWC ? input.sizeAt(1) : input.sizeAt(2); + const int input_width = isNHWC ? input.sizeAt(2) : input.sizeAt(3); + + const int output_depth = isNHWC ? output->sizeAt(3) : output->sizeAt(1); + const int output_height = isNHWC ? output->sizeAt(1) : output->sizeAt(2); + const int output_width = isNHWC ? output->sizeAt(2) : output->sizeAt(3); + + const int input_depth_by_output_height = input_depth * output_height; + + const int output_area = output_width * output_height; + const int output_depth_by_output_area = output_depth * output_area; + + if (isNHWC) { + const int total_count = + batch_size * input_height * input_width * input_depth; + + auto func = PRAGMA_THREADS_FOR { + for (auto inp_idx = start; inp_idx < stop; inp_idx++) { + // inp_idx = d + input_depth * (w + input_width * (h + input_height * + // b)) + const int d = inp_idx % input_depth; + const int inp_idx2 = inp_idx / input_depth; + const int w = inp_idx2 % input_width; + const int inp_idx3 = inp_idx2 / input_width; + const int h = inp_idx3 % input_height; + const int b = inp_idx3 / input_height; + + const int out_h = h / block_size; + const int offset_h = h % block_size; + const int out_w = w / block_size; + const int offset_w = w % block_size; + const int offset_d = (offset_h * block_size + offset_w) * input_depth; + const int out_d = d + offset_d; + + const int out_idx = + out_d + + output_depth * (out_w + output_width * (out_h + output_height * b)); + *(output_ptr + out_idx) = *(input_ptr + inp_idx); + } + }; + + samediff::Threads::parallel_for(func, 0, total_count); + } else { + const int total_count = batch_size * output_depth_by_output_area; + + auto func = PRAGMA_THREADS_FOR { + for (auto inp_idx = start; inp_idx < stop; inp_idx++) { + const int n_iC_oY_bY_oX = inp_idx / block_size; + const int bX = inp_idx - n_iC_oY_bY_oX * block_size; + + const int n_iC_oY_bY = n_iC_oY_bY_oX / output_width; + const int oX = n_iC_oY_bY_oX - n_iC_oY_bY * output_width; + + const int n_iC_oY = n_iC_oY_bY / block_size; + const int bY = n_iC_oY_bY - n_iC_oY * block_size; + + const int n = n_iC_oY / input_depth_by_output_height; + const int iC_oY = n_iC_oY - n * input_depth_by_output_height; + + const int output_idx = oX + (((n * block_size + bY) * block_size + bX) * + input_depth_by_output_height + + iC_oY) * + output_width; + + *(output_ptr + output_idx) = *(input_ptr + inp_idx); + } + }; + + samediff::Threads::parallel_for(func, 0, total_count); + } +} - void _spaceTodepth(sd::LaunchContext * context, const NDArray &input, NDArray *output, int block_size, bool isNHWC) { - BUILD_SINGLE_SELECTOR(input.dataType(), _spaceTodepth_, (input, output, block_size, isNHWC), LIBND4J_TYPES); - } +void _spaceTodepth(sd::LaunchContext *context, const NDArray &input, + NDArray *output, int block_size, bool isNHWC) { + BUILD_SINGLE_SELECTOR(input.dataType(), _spaceTodepth_, + (input, output, block_size, isNHWC), LIBND4J_TYPES); +} - BUILD_SINGLE_TEMPLATE(template void _spaceTodepth_, (const NDArray &input, NDArray *output, int block_size, bool isNHWC), LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template void _spaceTodepth_, + (const NDArray &input, NDArray *output, int block_size, + bool isNHWC), + LIBND4J_TYPES); -} -} -} \ No newline at end of file +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/scatter.cpp b/libnd4j/include/ops/declarable/helpers/cpu/scatter.cpp index 0693406bfe32..d9e82da161ef 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/scatter.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/scatter.cpp @@ -18,178 +18,190 @@ // @author raver119@gmail.com // +#include +#include #include + #include -#include -#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { /////////////////////////////////////////////////////////////////// // x - indices, z - input/output -template -Nd4jLong checkIndices_(const NDArray& indices, const NDArray& output, const int axis) { +template +Nd4jLong checkIndices_(const NDArray& indices, const NDArray& output, + const int axis) { + std::atomic numOfBadIndx{0}; - std::atomic numOfBadIndx{0}; + const auto x = indices.bufferAsT(); - const auto x = indices.bufferAsT(); + const auto xShapeInfo = indices.shapeInfo(); + const auto zShapeInfo = output.shapeInfo(); - const auto xShapeInfo = indices.shapeInfo(); - const auto zShapeInfo = output.shapeInfo(); + const auto xRank = indices.rankOf(); - const auto xRank = indices.rankOf(); + auto func = PRAGMA_THREADS_FOR { + int xCoords[MAX_RANK]; - auto func = PRAGMA_THREADS_FOR { - - int xCoords[MAX_RANK]; - - for (auto i = start; i < stop; i++) { + for (auto i = start; i < stop; i++) { + shape::index2coordsCPU(start, i, xShapeInfo, xCoords); - shape::index2coordsCPU(start, i, xShapeInfo, xCoords); + const Nd4jLong currentInd = x[shape::getOffset(xShapeInfo, xCoords)]; - const Nd4jLong currentInd = x[shape::getOffset(xShapeInfo, xCoords)]; - - if(currentInd >= shape::sizeAt(zShapeInfo, axis == -1 ? xCoords[xRank-1] : axis)) { - printf("checkIndices: out of range element %lld at index %ld \n", currentInd, i); - ++numOfBadIndx; - } - } - }; + if (currentInd >= + shape::sizeAt(zShapeInfo, axis == -1 ? xCoords[xRank - 1] : axis)) { + printf("checkIndices: out of range element %lld at index %ld \n", + currentInd, i); + ++numOfBadIndx; + } + } + }; - samediff::Threads::parallel_for(func, 0, indices.lengthOf()); + samediff::Threads::parallel_for(func, 0, indices.lengthOf()); - return numOfBadIndx; + return numOfBadIndx; } /////////////////////////////////////////////////////////////////// -Nd4jLong checkIndices(sd::LaunchContext *context, const NDArray& indices, const NDArray& output, const int axis) { - - BUILD_SINGLE_SELECTOR(indices.dataType(), return checkIndices_, (indices, output, axis), INDEXING_TYPES); +Nd4jLong checkIndices(sd::LaunchContext* context, const NDArray& indices, + const NDArray& output, const int axis) { + BUILD_SINGLE_SELECTOR(indices.dataType(), return checkIndices_, + (indices, output, axis), INDEXING_TYPES); } /////////////////////////////////////////////////////////////////// -void scatter(sd::LaunchContext *context, pairwise::Ops op, const NDArray& indices, const NDArray& updates, NDArray& output, const bool lock) { - - const int outRank = output.rankOf(); - const int indRank = indices.rankOf(); - const int updRank = updates.rankOf(); - const Nd4jLong indLen = indices.lengthOf(); - - if(outRank == 1) { - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - Nd4jLong idx = indices.e(i); - NDArray out = output({idx, idx + 1}); +void scatter(sd::LaunchContext* context, pairwise::Ops op, + const NDArray& indices, const NDArray& updates, NDArray& output, + const bool lock) { + const int outRank = output.rankOf(); + const int indRank = indices.rankOf(); + const int updRank = updates.rankOf(); + const Nd4jLong indLen = indices.lengthOf(); + + if (outRank == 1) { + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + Nd4jLong idx = indices.e(i); + NDArray out = output({idx, idx + 1}); - out.applyPairwiseTransform(op, updates.e(i)); - } - }; + out.applyPairwiseTransform(op, updates.e(i)); + } + }; - samediff::Threads::parallel_tad(func, 0, indLen, 1, lock ? 1 : sd::Environment::getInstance().maxThreads()); - } - else { // outRank > 1 + samediff::Threads::parallel_tad( + func, 0, indLen, 1, + lock ? 1 : sd::Environment::getInstance().maxThreads()); + } else { // outRank > 1 - int sizeOfDims = indRank; - if(outRank == updRank && indices.isVector()) - sizeOfDims = 1; + int sizeOfDims = indRank; + if (outRank == updRank && indices.isVector()) sizeOfDims = 1; - std::vector dimsToExcludeUpd(sizeOfDims); - std::iota(dimsToExcludeUpd.begin(), dimsToExcludeUpd.end(), 0); + std::vector dimsToExcludeUpd(sizeOfDims); + std::iota(dimsToExcludeUpd.begin(), dimsToExcludeUpd.end(), 0); - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - NDArray outSubArr = output(indices.e(i), std::vector({0})); - NDArray updSubArr = updates(i, dimsToExcludeUpd); + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + NDArray outSubArr = + output(indices.e(i), std::vector({0})); + NDArray updSubArr = updates(i, dimsToExcludeUpd); - outSubArr.applyPairwiseTransform(op, updSubArr); - } - }; + outSubArr.applyPairwiseTransform(op, updSubArr); + } + }; - samediff::Threads::parallel_tad(func, 0, indLen, 1, lock ? 1 : sd::Environment::getInstance().maxThreads()); - } + samediff::Threads::parallel_tad( + func, 0, indLen, 1, + lock ? 1 : sd::Environment::getInstance().maxThreads()); + } } /////////////////////////////////////////////////////////////////// -void scatterND(sd::LaunchContext *context, pairwise::Ops op, const NDArray& indices, const NDArray& updates, NDArray& output, const bool lock) { - - const Nd4jLong indLen = indices.lengthOf(); - const int outRank = output.rankOf(); - const int indRank = indices.rankOf(); - const Nd4jLong indLastDim = indices.sizeAt(-1); - - if(outRank == 1) { - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - Nd4jLong idx = indices.e(i); - NDArray out = output({idx, idx + 1}); +void scatterND(sd::LaunchContext* context, pairwise::Ops op, + const NDArray& indices, const NDArray& updates, NDArray& output, + const bool lock) { + const Nd4jLong indLen = indices.lengthOf(); + const int outRank = output.rankOf(); + const int indRank = indices.rankOf(); + const Nd4jLong indLastDim = indices.sizeAt(-1); + + if (outRank == 1) { + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + Nd4jLong idx = indices.e(i); + NDArray out = output({idx, idx + 1}); - out.applyPairwiseTransform(op, updates.e(i), nullptr); - } - }; + out.applyPairwiseTransform(op, updates.e(i), nullptr); + } + }; - samediff::Threads::parallel_tad(func, 0, indLen, 1, lock ? 1 : sd::Environment::getInstance().maxThreads()); - } - else { - std::vector dimsToExcludeInd = ShapeUtils::evalDimsToExclude(indRank, {indRank-1}); - std::vector dimsToExcludeUpd(indRank - 1); - std::iota(dimsToExcludeUpd.begin(), dimsToExcludeUpd.end(), 0); + samediff::Threads::parallel_tad( + func, 0, indLen, 1, + lock ? 1 : sd::Environment::getInstance().maxThreads()); + } else { + std::vector dimsToExcludeInd = + ShapeUtils::evalDimsToExclude(indRank, {indRank - 1}); + std::vector dimsToExcludeUpd(indRank - 1); + std::iota(dimsToExcludeUpd.begin(), dimsToExcludeUpd.end(), 0); - auto func = PRAGMA_THREADS_FOR { - std::vector idxRangeOut(2*outRank, 0); + auto func = PRAGMA_THREADS_FOR { + std::vector idxRangeOut(2 * outRank, 0); - for (auto i = start; i < stop; i++) { - NDArray indSubArr = indices(i, dimsToExcludeInd); + for (auto i = start; i < stop; i++) { + NDArray indSubArr = indices(i, dimsToExcludeInd); - for (Nd4jLong j = 0; j < indLastDim; ++j) { - idxRangeOut[2 * j] = indSubArr.e(j); - idxRangeOut[2 * j + 1] = idxRangeOut[2 * j] + 1; - } + for (Nd4jLong j = 0; j < indLastDim; ++j) { + idxRangeOut[2 * j] = indSubArr.e(j); + idxRangeOut[2 * j + 1] = idxRangeOut[2 * j] + 1; + } - NDArray outSubArr = output(idxRangeOut); - NDArray updSubArr = updates(i, dimsToExcludeUpd); + NDArray outSubArr = output(idxRangeOut); + NDArray updSubArr = updates(i, dimsToExcludeUpd); - outSubArr.applyPairwiseTransform(op, updSubArr); - } - }; + outSubArr.applyPairwiseTransform(op, updSubArr); + } + }; - samediff::Threads::parallel_tad(func, 0, indLen / indLastDim, 1, lock ? 1 : sd::Environment::getInstance().maxThreads()); - } + samediff::Threads::parallel_tad( + func, 0, indLen / indLastDim, 1, + lock ? 1 : sd::Environment::getInstance().maxThreads()); + } } -void scatterForLoss(sd::LaunchContext *context, const NDArray& indices, NDArray& updates, NDArray& output, const bool calcGrad) { +void scatterForLoss(sd::LaunchContext* context, const NDArray& indices, + NDArray& updates, NDArray& output, const bool calcGrad) { + // shapes of indices and output must be the same + // shape of indices should be the same as updates shape with last dimension + // excluded for example if updates is {a,b,c} then indices should be {a,b} - // shapes of indices and output must be the same - // shape of indices should be the same as updates shape with last dimension excluded - // for example if updates is {a,b,c} then indices should be {a,b} + const Nd4jLong indicesLen = indices.lengthOf(); - const Nd4jLong indicesLen = indices.lengthOf(); + std::vector dimsToExclude = + ShapeUtils::evalDimsToExclude(updates.rankOf(), {-1}); - std::vector dimsToExclude = ShapeUtils::evalDimsToExclude(updates.rankOf(), {-1}); - - if(!calcGrad) { - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - auto subArr = updates(i, dimsToExclude); - output.p(i, subArr.e(indices.e(i))); - } - }; + if (!calcGrad) { + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + auto subArr = updates(i, dimsToExclude); + output.p(i, subArr.e(indices.e(i))); + } + }; - samediff::Threads::parallel_for(func, 0, indicesLen); - } else { - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - auto subArr = updates(i, dimsToExclude); - auto ind = indices.e(i); - subArr.p(ind, subArr.e(ind) - 1.); - } - }; + samediff::Threads::parallel_for(func, 0, indicesLen); + } else { + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + auto subArr = updates(i, dimsToExclude); + auto ind = indices.e(i); + subArr.p(ind, subArr.e(ind) - 1.); + } + }; - samediff::Threads::parallel_for(func, 0, indicesLen); - } + samediff::Threads::parallel_for(func, 0, indicesLen); + } } -} -} -} +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/scatterUpdateAndSimple.cpp b/libnd4j/include/ops/declarable/helpers/cpu/scatterUpdateAndSimple.cpp index fe41c5d43335..042acde058ad 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/scatterUpdateAndSimple.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/scatterUpdateAndSimple.cpp @@ -18,98 +18,103 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 20.04.2018 // -#include -#include #include +#include +#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { ////////////////////////////////////////////////////////////////////////// -void scatterUpdate(sd::LaunchContext * context, NDArray& input, NDArray& updates, const std::vector* intArgs) { - - int opCode = (*intArgs)[0]; - int dimSize = (*intArgs)[1]; - Nd4jLong e; - Nd4jLong limg = 2 + dimSize; - std::vector tadDimensions(dimSize); - for (e = 2; e < limg; e++) - tadDimensions[e-2] = (*intArgs)[e]; - - std::vector dimsToExclude = ShapeUtils::evalDimsToExclude(input.rankOf(), tadDimensions); - - // increasing counter to skip numIndices - e++; - std::vector indices; - for (; e < static_cast(intArgs->size()); e++) - indices.push_back((*intArgs)[e]); - - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - auto inSubArr = input(indices[i], dimsToExclude, true); - auto updSubArr = updates(i, dimsToExclude, true); - - if (inSubArr.lengthOf() != updSubArr.lengthOf()) - continue; - - switch (opCode) { - case 0: - inSubArr.applyPairwiseTransform(pairwise::Add, updSubArr, inSubArr); - break; - case 1: - inSubArr.applyPairwiseTransform(pairwise::Subtract, updSubArr, inSubArr); - break; - case 2: - inSubArr.applyPairwiseTransform(pairwise::Multiply, updSubArr, inSubArr); - break; - case 3: - inSubArr.applyPairwiseTransform(pairwise::Divide, updSubArr, inSubArr); - break; - case 4: - inSubArr.applyPairwiseTransform(pairwise::ReverseSubtract, updSubArr, inSubArr); - break; - case 5: - inSubArr.applyPairwiseTransform(pairwise::ReverseDivide, updSubArr, inSubArr); - break; - case 6: - inSubArr.applyPairwiseTransform(pairwise::CopyPws, updSubArr, inSubArr); - break; - default: - continue; - } - } - }; +void scatterUpdate(sd::LaunchContext* context, NDArray& input, NDArray& updates, + const std::vector* intArgs) { + int opCode = (*intArgs)[0]; + int dimSize = (*intArgs)[1]; + Nd4jLong e; + Nd4jLong limg = 2 + dimSize; + std::vector tadDimensions(dimSize); + for (e = 2; e < limg; e++) tadDimensions[e - 2] = (*intArgs)[e]; + + std::vector dimsToExclude = + ShapeUtils::evalDimsToExclude(input.rankOf(), tadDimensions); + + // increasing counter to skip numIndices + e++; + std::vector indices; + for (; e < static_cast(intArgs->size()); e++) + indices.push_back((*intArgs)[e]); + + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + auto inSubArr = input(indices[i], dimsToExclude, true); + auto updSubArr = updates(i, dimsToExclude, true); + + if (inSubArr.lengthOf() != updSubArr.lengthOf()) continue; + + switch (opCode) { + case 0: + inSubArr.applyPairwiseTransform(pairwise::Add, updSubArr, inSubArr); + break; + case 1: + inSubArr.applyPairwiseTransform(pairwise::Subtract, updSubArr, + inSubArr); + break; + case 2: + inSubArr.applyPairwiseTransform(pairwise::Multiply, updSubArr, + inSubArr); + break; + case 3: + inSubArr.applyPairwiseTransform(pairwise::Divide, updSubArr, + inSubArr); + break; + case 4: + inSubArr.applyPairwiseTransform(pairwise::ReverseSubtract, updSubArr, + inSubArr); + break; + case 5: + inSubArr.applyPairwiseTransform(pairwise::ReverseDivide, updSubArr, + inSubArr); + break; + case 6: + inSubArr.applyPairwiseTransform(pairwise::CopyPws, updSubArr, + inSubArr); + break; + default: + continue; + } + } + }; - samediff::Threads::parallel_tad(func, 0, indices.size()); + samediff::Threads::parallel_tad(func, 0, indices.size()); } - ////////////////////////////////////////////////////////////////////////// -void scatterSimple(sd::LaunchContext * context, const int opId, NDArray& input, const NDArray& updates, const NDArray& indices, const std::vector& dimensions) { - - // updates and indices have same length - const Nd4jLong len = indices.lengthOf(); - - switch (opId) { - - case 6: { // copy - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - auto inSubArr = input(i, dimensions); - inSubArr.p(indices.t(i), updates.e(i)); - } - }; - - samediff::Threads::parallel_for(func, 0, len); +void scatterSimple(sd::LaunchContext* context, const int opId, NDArray& input, + const NDArray& updates, const NDArray& indices, + const std::vector& dimensions) { + // updates and indices have same length + const Nd4jLong len = indices.lengthOf(); + + switch (opId) { + case 6: { // copy + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + auto inSubArr = input(i, dimensions); + inSubArr.p(indices.t(i), updates.e(i)); } - break; + }; - default: - throw std::invalid_argument("helpers::scatterSimple: operation is not implemented for given id !"); - } -} + samediff::Threads::parallel_for(func, 0, len); + } break; + default: + throw std::invalid_argument( + "helpers::scatterSimple: operation is not implemented for given id " + "!"); + } } -} -} + +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/segment.cpp b/libnd4j/include/ops/declarable/helpers/cpu/segment.cpp index 50ff79679b3c..a52a9d5498a8 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/segment.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/segment.cpp @@ -19,1073 +19,1146 @@ // @author GS // -#include -#include #include +#include +#include + #include namespace sd { namespace ops { namespace helpers { - // segment max - template - static void segmentMaxFunctor_(NDArray* input, NDArray* indices, NDArray* output) { - //int numClasses = output->sizeAt(0); - // if input is a vector: (as if in doc sample) - Nd4jLong idx = indices->e(0); - if (input->isVector()) { - T val = input->e(0); - - for (Nd4jLong e = 1; e < indices->lengthOf(); e++) { - if (idx == indices->e(e)) { - // max - val = sd::math::nd4j_max(val, input->t(e)); - } - else { - idx = indices->e(e); - val = input->t(e); - } - output->r(idx) = val; - } - } - else { - std::vector restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - auto listOfTensors = input->allTensorsAlongDimension(restDims); - auto listOfOutTensors = output->allTensorsAlongDimension(restDims); +// segment max +template +static void segmentMaxFunctor_(NDArray* input, NDArray* indices, + NDArray* output) { + // int numClasses = output->sizeAt(0); + // if input is a vector: (as if in doc sample) + Nd4jLong idx = indices->e(0); + if (input->isVector()) { + T val = input->e(0); + + for (Nd4jLong e = 1; e < indices->lengthOf(); e++) { + if (idx == indices->e(e)) { + // max + val = sd::math::nd4j_max(val, input->t(e)); + } else { + idx = indices->e(e); + val = input->t(e); + } + output->r(idx) = val; + } + } else { + std::vector restDims = + ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); + auto listOfTensors = input->allTensorsAlongDimension(restDims); + auto listOfOutTensors = output->allTensorsAlongDimension(restDims); - auto numOfClasses = output->sizeAt(0); // number of classes - std::vector> outputs(numOfClasses); - auto maxT = listOfOutTensors.at(idx); + auto numOfClasses = output->sizeAt(0); // number of classes + std::vector> outputs(numOfClasses); + auto maxT = listOfOutTensors.at(idx); - //int pos = 0; - maxT->assign(listOfTensors.at(0)); + // int pos = 0; + maxT.assign(listOfTensors.at(0)); for (Nd4jLong i = 1; i < indices->lengthOf(); i++) { if (indices->e(i) == idx) { - for (Nd4jLong e = 0; e < maxT->lengthOf(); e++) { - maxT->r(e) = sd::math::nd4j_max(maxT->t(e), listOfTensors.at(i)->t(e)); + for (Nd4jLong e = 0; e < maxT.lengthOf(); e++) { + maxT.r(e) = sd::math::nd4j_max(maxT.t(e), listOfTensors.at(i).t(e)); } } else { idx = indices->e(i); maxT = listOfOutTensors.at(idx); - maxT->assign(listOfTensors.at(i)); + maxT.assign(listOfTensors.at(i)); } } } } - // segmen min - template - static void segmentMinFunctor_(NDArray* input, NDArray* indices, NDArray* output) { - //int numClasses = output->sizeAt(0); - // if input is a vector: (as if in doc sample) - Nd4jLong idx = indices->e(0); - if (input->isVector()) { - T val = input->e(0); - - for (Nd4jLong e = 1; e < indices->lengthOf(); e++) { - if (idx == indices->e(e)) { - // min - val = sd::math::nd4j_min(val, input->t(e)); - } - else { - idx = indices->e(e); - val = input->t(e); - } - output->r(idx) = val; - } - } - else { - auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - - ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); - ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); - - int numOfClasses = output->sizeAt(0); // number of classes - std::vector> outputs(numOfClasses); - auto minT = listOfOutTensors.at(idx); - - int pos = 0; - minT->assign(listOfTensors.at(0)); - - for (Nd4jLong i = 1; i < indices->lengthOf(); i++) { - if (indices->e(i) == idx) { - - for (Nd4jLong e = 0; e < minT->lengthOf(); e++) { - minT->p(e, sd::math::nd4j_min(minT->e(e), listOfTensors.at(i)->e(e))); - } - } - else { - idx = indices->e(i); - minT = listOfOutTensors.at(idx); - minT->assign(listOfTensors.at(i)); - } - } - } - } - - // segmen mean - template - static void segmentMeanFunctor_(NDArray* input, NDArray* indices, NDArray* output) { - int numClasses = output->sizeAt(0); - // if input is a vector: (as if in doc sample) - int idx = indices->e(0); - if (input->isVector()) { - T val = T(0.f); - int count = 0; - - for (Nd4jLong e = 0; e < indices->lengthOf(); e++) { - if (idx == indices->e(e)) { - // mean - val += input->e(e); - count++; - } - else { - output->p(idx, val / count); - idx = indices->e(e); - val = input->e(e); - count = 1; - } - output->p(idx, val / count); - } - } - else { - auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - - auto listOfTensors = input->allTensorsAlongDimension(restDims); - auto listOfOutTensors = output->allTensorsAlongDimension(restDims); - - int numOfClasses = output->sizeAt(0); // number of classes - std::vector> outputs(numOfClasses); - auto meanT = listOfOutTensors.at(idx); - int count = 1; - auto meanV = meanT->dup(); - meanV.assign(listOfTensors.at(0)); - - for (Nd4jLong i = 1; i < indices->lengthOf(); i++) { - if (indices->e(i) == idx) { - auto func = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) { - meanV.p(e, meanV.e(e) + listOfTensors.at(i)->e(e)); - } - }; - samediff::Threads::parallel_for(func, 0, meanT->lengthOf()); - - count++; - } - else { - //meanT->assign(meanV); - meanV.applyScalar(scalar::Divide, count, *meanT); - idx = indices->e(i); - meanT = listOfOutTensors.at(idx); - meanV.assign(listOfTensors.at(i)); - count = 1; - } - meanV.applyScalar(scalar::Divide, count, *meanT); - } - } +// segmen min +template +static void segmentMinFunctor_(NDArray* input, NDArray* indices, + NDArray* output) { + // int numClasses = output->sizeAt(0); + // if input is a vector: (as if in doc sample) + Nd4jLong idx = indices->e(0); + if (input->isVector()) { + T val = input->e(0); + + for (Nd4jLong e = 1; e < indices->lengthOf(); e++) { + if (idx == indices->e(e)) { + // min + val = sd::math::nd4j_min(val, input->t(e)); + } else { + idx = indices->e(e); + val = input->t(e); + } + output->r(idx) = val; } + } else { + auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - template - static void segmentSumFunctor_(NDArray* input, NDArray* indices, NDArray* output) { - int numClasses = output->sizeAt(0); - // if input is a vector: (as if in doc sample) - int idx = indices->e(0); - if (input->isVector()) { - T val = T(0.f); - int count = 0; - for (Nd4jLong e = 0; e < indices->lengthOf(); e++) { - if (idx == indices->e(e)) { - // sum - val += input->t(e); - } - else { - idx = indices->e(e); - val = input->t(e); - } - output->p(idx, val); - } - } - else { - auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - - auto listOfTensors = input->allTensorsAlongDimension(restDims); - auto listOfOutTensors = output->allTensorsAlongDimension(restDims); - - int numOfClasses = output->sizeAt(0); // number of classes - std::vector> outputs(numOfClasses); - auto sumT = listOfOutTensors.at(idx); - - for (Nd4jLong i = 0; i < indices->lengthOf(); i++) { - if (indices->e(i) == idx) { - auto func = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) { - sumT->p(e, sumT->e(e) + listOfTensors.at(i)->e(e)); - } - }; - samediff::Threads::parallel_for(func, 0, sumT->lengthOf()); - } - else { - idx = indices->e(i); - sumT = listOfOutTensors.at(idx); - sumT->assign(listOfTensors.at(i)); - } - } - } - } + ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); + ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); - template - static void segmentProdFunctor_(NDArray* input, NDArray* indices, NDArray* output) { - //int numClasses = output->sizeAt(0); - // if input is a vector: (as if in doc sample) - int idx = indices->e(0); - output->assign(1.f); - if (input->isVector()) { - T val = input->e(0); - int count = 0; - - for (Nd4jLong e = 1; e < indices->lengthOf(); e++) { - if (idx == indices->e(e)) { - // sum - val *= input->e(e); - } - else { - idx = indices->e(e); - val = input->e(e); - } - output->p(idx, val); - } - } - else { - auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); + int numOfClasses = output->sizeAt(0); // number of classes + std::vector> outputs(numOfClasses); + auto minT = listOfOutTensors.at(idx); - auto listOfTensors = input->allTensorsAlongDimension(restDims); - auto listOfOutTensors = output->allTensorsAlongDimension(restDims); + int pos = 0; + minT.assign(listOfTensors.at(0)); - int numOfClasses = output->sizeAt(0); // number of classes - auto sumT = listOfOutTensors.at(idx); - sumT->assign(listOfTensors.at(0)); - for (Nd4jLong i = 1; i < indices->lengthOf(); i++) { - if (indices->e(i) == idx) { - auto func = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) { - sumT->p(e, sumT->e(e) * listOfTensors.at(i)->e(e)); - } - }; - samediff::Threads::parallel_for(func, 0, sumT->lengthOf()); - } - else { - idx = indices->e(i); - sumT = listOfOutTensors.at(idx); - sumT->assign(listOfTensors.at(i)); - } - } + for (Nd4jLong i = 1; i < indices->lengthOf(); i++) { + if (indices->e(i) == idx) { + for (Nd4jLong e = 0; e < minT.lengthOf(); e++) { + minT.p(e, + sd::math::nd4j_min(minT.e(e), listOfTensors.at(i).e(e))); } + } else { + idx = indices->e(i); + minT = listOfOutTensors.at(idx); + minT.assign(listOfTensors.at(i)); + } } + } +} -// template -// static bool segmentIndicesValidate_(NDArray* indices, NDArray& aexpected, NDArray& anOutput) { -// } - - void segmentMaxFunctor(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* output) { - BUILD_SINGLE_SELECTOR(input->dataType(), segmentMaxFunctor_, (input, indices, output), LIBND4J_TYPES); +// segmen mean +template +static void segmentMeanFunctor_(NDArray* input, NDArray* indices, + NDArray* output) { + int numClasses = output->sizeAt(0); + // if input is a vector: (as if in doc sample) + int idx = indices->e(0); + if (input->isVector()) { + T val = T(0.f); + int count = 0; + + for (Nd4jLong e = 0; e < indices->lengthOf(); e++) { + if (idx == indices->e(e)) { + // mean + val += input->e(e); + count++; + } else { + output->p(idx, val / count); + idx = indices->e(e); + val = input->e(e); + count = 1; + } + output->p(idx, val / count); } - - void segmentMinFunctor(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* output) { - BUILD_SINGLE_SELECTOR(input->dataType(), segmentMinFunctor_, (input, indices, output), LIBND4J_TYPES); + } else { + auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); + + auto listOfTensors = input->allTensorsAlongDimension(restDims); + auto listOfOutTensors = output->allTensorsAlongDimension(restDims); + + int numOfClasses = output->sizeAt(0); // number of classes + std::vector> outputs(numOfClasses); + auto meanT = listOfOutTensors.at(idx); + int count = 1; + auto meanV = meanT.dup(); + meanV.assign(listOfTensors.at(0)); + + for (Nd4jLong i = 1; i < indices->lengthOf(); i++) { + if (indices->e(i) == idx) { + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + meanV.p(e, meanV.e(e) + listOfTensors.at(i).e(e)); + } + }; + samediff::Threads::parallel_for(func, 0, meanT.lengthOf()); + + count++; + } else { + // meanT->assign(meanV); + meanV.applyScalar(scalar::Divide, count, meanT); + idx = indices->e(i); + meanT = listOfOutTensors.at(idx); + meanV.assign(listOfTensors.at(i)); + count = 1; + } + meanV.applyScalar(scalar::Divide, count, meanT); } + } +} - void segmentMeanFunctor(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* output) { - BUILD_SINGLE_SELECTOR(input->dataType(), segmentMeanFunctor_, (input, indices, output), LIBND4J_TYPES); +template +static void segmentSumFunctor_(NDArray* input, NDArray* indices, + NDArray* output) { + int numClasses = output->sizeAt(0); + // if input is a vector: (as if in doc sample) + int idx = indices->e(0); + if (input->isVector()) { + T val = T(0.f); + int count = 0; + for (Nd4jLong e = 0; e < indices->lengthOf(); e++) { + if (idx == indices->e(e)) { + // sum + val += input->t(e); + } else { + idx = indices->e(e); + val = input->t(e); + } + output->p(idx, val); } - - void segmentSumFunctor(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* output) { - BUILD_SINGLE_SELECTOR(input->dataType(), segmentSumFunctor_, (input, indices, output), LIBND4J_TYPES); + } else { + auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); + + auto listOfTensors = input->allTensorsAlongDimension(restDims); + auto listOfOutTensors = output->allTensorsAlongDimension(restDims); + + int numOfClasses = output->sizeAt(0); // number of classes + std::vector> outputs(numOfClasses); + auto sumT = listOfOutTensors.at(idx); + + for (Nd4jLong i = 0; i < indices->lengthOf(); i++) { + if (indices->e(i) == idx) { + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + sumT.p(e, sumT.e(e) + listOfTensors.at(i).e(e)); + } + }; + samediff::Threads::parallel_for(func, 0, sumT.lengthOf()); + } else { + idx = indices->e(i); + sumT = listOfOutTensors.at(idx); + sumT.assign(listOfTensors.at(i)); + } } + } +} - void segmentProdFunctor(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* output) { - BUILD_SINGLE_SELECTOR(input->dataType(), segmentProdFunctor_, (input, indices, output), LIBND4J_TYPES); +template +static void segmentProdFunctor_(NDArray* input, NDArray* indices, + NDArray* output) { + // int numClasses = output->sizeAt(0); + // if input is a vector: (as if in doc sample) + int idx = indices->e(0); + output->assign(1.f); + if (input->isVector()) { + T val = input->e(0); + int count = 0; + + for (Nd4jLong e = 1; e < indices->lengthOf(); e++) { + if (idx == indices->e(e)) { + // sum + val *= input->e(e); + } else { + idx = indices->e(e); + val = input->e(e); + } + output->p(idx, val); } - - bool segmentIndicesValidate(sd::LaunchContext * context, NDArray* indices, NDArray& expected, NDArray& output) { - auto val = indices->e(0); - for (Nd4jLong e = 1; e < indices->lengthOf(); e++) { - output = indices->e(e); - if (val.e(0) > output.e(0)) - return false; - val = indices->e(e); - } - - return true; + } else { + auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); + + auto listOfTensors = input->allTensorsAlongDimension(restDims); + auto listOfOutTensors = output->allTensorsAlongDimension(restDims); + + int numOfClasses = output->sizeAt(0); // number of classes + auto sumT = listOfOutTensors.at(idx); + sumT.assign(listOfTensors.at(0)); + for (Nd4jLong i = 1; i < indices->lengthOf(); i++) { + if (indices->e(i) == idx) { + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + sumT.p(e, sumT.e(e) * listOfTensors.at(i).e(e)); + } + }; + samediff::Threads::parallel_for(func, 0, sumT.lengthOf()); + } else { + idx = indices->e(i); + sumT = listOfOutTensors.at(idx); + sumT.assign(listOfTensors.at(i)); + } } + } +} - //BUILD_SINGLE_TEMPLATE(template bool segmentIndicesValidate_, (NDArray*, NDArray&, NDArray&), LIBND4J_TYPES); - BUILD_SINGLE_TEMPLATE(template void segmentProdFunctor_, (NDArray* input, NDArray* indices, NDArray* output), LIBND4J_TYPES); - BUILD_SINGLE_TEMPLATE(template void segmentSumFunctor_, (NDArray* input, NDArray* indices, NDArray* output), LIBND4J_TYPES); - BUILD_SINGLE_TEMPLATE(template void segmentMeanFunctor_, (NDArray* input, NDArray* indices, NDArray* output), LIBND4J_TYPES); - BUILD_SINGLE_TEMPLATE(template void segmentMinFunctor_, (NDArray* input, NDArray* indices, NDArray* output), LIBND4J_TYPES); - BUILD_SINGLE_TEMPLATE(template void segmentMaxFunctor_, (NDArray* input, NDArray* indices, NDArray* output), LIBND4J_TYPES); - // -------------------------------------------------------------------------------------------------------------- // - // Unsorted segment ops - // -------------------------------------------------------------------------------------------------------------- // - - bool unsortedSegmentIndicesValidate(sd::LaunchContext * context, NDArray* indices, Nd4jLong expected, Nd4jLong& output) { - Nd4jLong val = indices->e(0); - - Nd4jLong maxInd = indices->argMax(); - if (indices->e(maxInd) >= expected) { - output = val; - return false; - } - output = expected; - return true; - } +// template +// static bool segmentIndicesValidate_(NDArray* indices, NDArray& aexpected, +// NDArray& anOutput) { +// } - template - static void unsortedSegmentMaxFunctor_(NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { +void segmentMaxFunctor(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* output) { + BUILD_SINGLE_SELECTOR(input->dataType(), segmentMaxFunctor_, + (input, indices, output), LIBND4J_TYPES); +} - // if input is a vector: (as if in doc sample) - //int idx = static_cast((*indices)(0.)); - MAP_IMPL> idxs;//(indices->lengthOf()); - for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) - idxs[indices->e(e)].push_back(e); +void segmentMinFunctor(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* output) { + BUILD_SINGLE_SELECTOR(input->dataType(), segmentMinFunctor_, + (input, indices, output), LIBND4J_TYPES); +} - //std::sort(idxs.begin(), idxs.end()); +void segmentMeanFunctor(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* output) { + BUILD_SINGLE_SELECTOR(input->dataType(), segmentMeanFunctor_, + (input, indices, output), LIBND4J_TYPES); +} - if (input->isVector()) { // 1D case - T maxVal = DataTypeUtils::max(); - output->assign(-maxVal); +void segmentSumFunctor(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* output) { + BUILD_SINGLE_SELECTOR(input->dataType(), segmentSumFunctor_, + (input, indices, output), LIBND4J_TYPES); +} - for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) { - T val = input->e(fi->second.at(0)); - for (Nd4jLong idx = 1; idx < static_cast(fi->second.size()); ++idx) { - val = sd::math::nd4j_max(val, input->e(fi->second.at(idx))); - } - output->p(fi->first, val); - } - } - else { - auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); +void segmentProdFunctor(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* output) { + BUILD_SINGLE_SELECTOR(input->dataType(), segmentProdFunctor_, + (input, indices, output), LIBND4J_TYPES); +} - ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); - ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); +bool segmentIndicesValidate(sd::LaunchContext* context, NDArray* indices, + NDArray& expected, NDArray& output) { + auto val = indices->e(0); + for (Nd4jLong e = 1; e < indices->lengthOf(); e++) { + output = indices->e(e); + if (val.e(0) > output.e(0)) return false; + val = indices->e(e); + } - T maxVal = DataTypeUtils::max(); - output->assign(-maxVal); + return true; +} - for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) { - auto outputT = listOfOutTensors.at(fi->first); - outputT->assign(listOfTensors.at(fi->second.at(0))); - for (Nd4jLong idx = 1; idx < static_cast(fi->second.size()); ++idx) { - auto maxT = listOfTensors.at(fi->second.at(idx)); - for (Nd4jLong e = 0; e < outputT->lengthOf(); ++e) { - T val = sd::math::nd4j_max(maxT->e(e), outputT->e(e)); +// BUILD_SINGLE_TEMPLATE(template bool segmentIndicesValidate_, (NDArray*, +// NDArray&, NDArray&), LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template void segmentProdFunctor_, + (NDArray * input, NDArray* indices, NDArray* output), + LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template void segmentSumFunctor_, + (NDArray * input, NDArray* indices, NDArray* output), + LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template void segmentMeanFunctor_, + (NDArray * input, NDArray* indices, NDArray* output), + LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template void segmentMinFunctor_, + (NDArray * input, NDArray* indices, NDArray* output), + LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template void segmentMaxFunctor_, + (NDArray * input, NDArray* indices, NDArray* output), + LIBND4J_TYPES); +// -------------------------------------------------------------------------------------------------------------- +// // Unsorted segment ops +// -------------------------------------------------------------------------------------------------------------- +// // + +bool unsortedSegmentIndicesValidate(sd::LaunchContext* context, + NDArray* indices, Nd4jLong expected, + Nd4jLong& output) { + Nd4jLong val = indices->e(0); + + Nd4jLong maxInd = indices->argMax(); + if (indices->e(maxInd) >= expected) { + output = val; + return false; + } + output = expected; + return true; +} - outputT->p(e, val); - } - } - } - } - } - void unsortedSegmentMaxFunctor(sd::LaunchContext * context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { - BUILD_SINGLE_SELECTOR(input->dataType(), unsortedSegmentMaxFunctor_, (input, indices, numOfClasses, output), NUMERIC_TYPES); +template +static void unsortedSegmentMaxFunctor_(NDArray* input, NDArray* indices, + Nd4jLong numOfClasses, NDArray* output) { + // if input is a vector: (as if in doc sample) + // int idx = static_cast((*indices)(0.)); + MAP_IMPL> idxs; //(indices->lengthOf()); + for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) + idxs[indices->e(e)].push_back(e); + + // std::sort(idxs.begin(), idxs.end()); + + if (input->isVector()) { // 1D case + T maxVal = DataTypeUtils::max(); + output->assign(-maxVal); + + for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) { + T val = input->e(fi->second.at(0)); + for (Nd4jLong idx = 1; idx < static_cast(fi->second.size()); + ++idx) { + val = sd::math::nd4j_max(val, input->e(fi->second.at(idx))); + } + output->p(fi->first, val); } - BUILD_SINGLE_TEMPLATE(template void unsortedSegmentMaxFunctor_, (NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output), NUMERIC_TYPES); - - template - static void unsortedSegmentMinFunctor_(NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { - // if input is a vector: (as if in doc sample) - //int idx = static_cast((*indices)(0.)); - MAP_IMPL> idxs;//(indices->lengthOf()); - - for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) - idxs[indices->e(e)].push_back(e); + } else { + auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - //std::sort(idxs.begin(), idxs.end()); + ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); + ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); - if (input->isVector()) { // 1D case - T maxVal = DataTypeUtils::max(); - output->assign(maxVal); + T maxVal = DataTypeUtils::max(); + output->assign(-maxVal); - for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) { - T val = input->t(fi->second.at(0)); + for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) { + auto outputT = listOfOutTensors.at(fi->first); + outputT.assign(listOfTensors.at(fi->second.at(0))); + for (Nd4jLong idx = 1; idx < static_cast(fi->second.size()); + ++idx) { + auto maxT = listOfTensors.at(fi->second.at(idx)); + for (Nd4jLong e = 0; e < outputT.lengthOf(); ++e) { + T val = sd::math::nd4j_max(maxT.e(e), outputT.e(e)); - for (size_t idx = 1; idx < fi->second.size(); ++idx) { - val = sd::math::nd4j_min(val, input->t(fi->second.at(idx))); - } - output->r(fi->first) = val; - } + outputT.p(e, val); } - else { - auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); + } + } + } +} +void unsortedSegmentMaxFunctor(sd::LaunchContext* context, NDArray* input, + NDArray* indices, Nd4jLong numOfClasses, + NDArray* output) { + BUILD_SINGLE_SELECTOR(input->dataType(), unsortedSegmentMaxFunctor_, + (input, indices, numOfClasses, output), NUMERIC_TYPES); +} +BUILD_SINGLE_TEMPLATE(template void unsortedSegmentMaxFunctor_, + (NDArray * input, NDArray* indices, Nd4jLong numOfClasses, + NDArray* output), + NUMERIC_TYPES); + +template +static void unsortedSegmentMinFunctor_(NDArray* input, NDArray* indices, + Nd4jLong numOfClasses, NDArray* output) { + // if input is a vector: (as if in doc sample) + // int idx = static_cast((*indices)(0.)); + MAP_IMPL> idxs; //(indices->lengthOf()); + + for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) + idxs[indices->e(e)].push_back(e); + + // std::sort(idxs.begin(), idxs.end()); + + if (input->isVector()) { // 1D case + T maxVal = DataTypeUtils::max(); + output->assign(maxVal); + + for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) { + T val = input->t(fi->second.at(0)); + + for (size_t idx = 1; idx < fi->second.size(); ++idx) { + val = sd::math::nd4j_min(val, input->t(fi->second.at(idx))); + } + output->r(fi->first) = val; + } + } else { + auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); - ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); + ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); + ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); - T maxVal = DataTypeUtils::max(); - output->assign(maxVal); + T maxVal = DataTypeUtils::max(); + output->assign(maxVal); - for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) { - auto outputT = listOfOutTensors.at(fi->first); - outputT->assign(listOfTensors.at(fi->second.at(0))); - for (size_t idx = 1; idx < fi->second.size(); ++idx) { - auto minT = listOfTensors.at(fi->second.at(idx)); + for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) { + auto outputT = listOfOutTensors.at(fi->first); + outputT.assign(listOfTensors.at(fi->second.at(0))); + for (size_t idx = 1; idx < fi->second.size(); ++idx) { + auto minT = listOfTensors.at(fi->second.at(idx)); - for (Nd4jLong e = 0; e < outputT->lengthOf(); ++e) { - outputT->r(e) = sd::math::nd4j_min(minT->t(e), outputT->t(e)); + for (Nd4jLong e = 0; e < outputT.lengthOf(); ++e) { + outputT.r(e) = sd::math::nd4j_min(minT.t(e), outputT.t(e)); } } //outputT->assign(maxT); } } + } - } - void unsortedSegmentMinFunctor(sd::LaunchContext * context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { - BUILD_SINGLE_SELECTOR(input->dataType(), unsortedSegmentMinFunctor_, (input, indices, numOfClasses, output), - NUMERIC_TYPES); - } - - BUILD_SINGLE_TEMPLATE(template void unsortedSegmentMinFunctor_, (NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output), NUMERIC_TYPES); - - void unsortedSegmentMeanFunctor(sd::LaunchContext * context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { - MAP_IMPL> idxs;//(indices->lengthOf()); - for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) - idxs[indices->e(e)].push_back(e); - - //std::sort(idxs.begin(), idxs.end()); - - if (input->isVector()) { // 1D case +void unsortedSegmentMinFunctor(sd::LaunchContext * context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { + BUILD_SINGLE_SELECTOR(input->dataType(), unsortedSegmentMinFunctor_, (input, indices, numOfClasses, output), + NUMERIC_TYPES); +} - for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) { - double sumValue = input->e(fi->second.at(0)); - int loop_size = fi->second.size(); +void unsortedSegmentMeanFunctor(sd::LaunchContext* context, NDArray* input, + NDArray* indices, Nd4jLong numOfClasses, + NDArray* output) { + MAP_IMPL> idxs; //(indices->lengthOf()); + for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) + idxs[indices->e(e)].push_back(e); - // FIXME: parallelism here? - for (size_t idx = 1; idx < loop_size; ++idx) { - sumValue += input->e(fi->second.at(idx)); - } + // std::sort(idxs.begin(), idxs.end()); - output->p(fi->first, sumValue / fi->second.size()); - } - } - else { - auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); + if (input->isVector()) { // 1D case - ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); - ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); + for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) { + double sumValue = input->e(fi->second.at(0)); + int loop_size = fi->second.size(); - // FIXME: parallelism here? - for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) { - auto outputT = listOfOutTensors.at(fi->first); - outputT->assign(listOfTensors.at(fi->second.at(0))); - Nd4jLong loopSize = fi->second.size(); + // FIXME: parallelism here? + for (size_t idx = 1; idx < loop_size; ++idx) { + sumValue += input->e(fi->second.at(idx)); + } - for (Nd4jLong idx = 1; idx < loopSize; ++idx) { - auto current = listOfTensors.at(fi->second.at(idx)); - *outputT += *current; - } - (*outputT) /= double(fi->second.size()); - } - } + output->p(fi->first, sumValue / fi->second.size()); } - - void unsortedSegmentSumFunctor(sd::LaunchContext * context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { - MAP_IMPL> idxs;//(indices->lengthOf()); - for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) - idxs[indices->e(e)].push_back(e); - - if (input->isVector()) { // 1D case - - for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) { - double sumValue = input->e(fi->second.at(0)); - Nd4jLong loop_size = fi->second.size(); - - // FIXME: parallelism here? - for (Nd4jLong idx = 1; idx < loop_size; ++idx) { - sumValue += input->e(fi->second.at(idx)); - } - output->p(fi->first, sumValue); - } - } - else { - auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - - ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); - ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); - - for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) { - auto outputT = listOfOutTensors.at(fi->first); - outputT->assign(listOfTensors.at(fi->second.at(0))); - Nd4jLong loop_size = fi->second.size(); - - // FIXME: parallelism here? - for (Nd4jLong idx = 1; idx < loop_size; ++idx) { - auto current = listOfTensors.at(fi->second.at(idx)); - *(outputT) += *current; - } - //outputT->assign(maxT); - } - } + } else { + auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); + + ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); + ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); + + // FIXME: parallelism here? + for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) { + auto outputT = listOfOutTensors.at(fi->first); + outputT.assign(listOfTensors.at(fi->second.at(0))); + Nd4jLong loopSize = fi->second.size(); + + for (Nd4jLong idx = 1; idx < loopSize; ++idx) { + auto current = listOfTensors.at(fi->second.at(idx)); + outputT += current; + } + (outputT) /= double(fi->second.size()); } + } +} - template - void unsortedSegmentProdFunctor_(NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { - MAP_IMPL> idxs;//(indices->lengthOf()); - for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) - idxs[indices->e(e)].push_back(e); +void unsortedSegmentSumFunctor(sd::LaunchContext* context, NDArray* input, + NDArray* indices, Nd4jLong numOfClasses, + NDArray* output) { + MAP_IMPL> idxs; //(indices->lengthOf()); + for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) + idxs[indices->e(e)].push_back(e); - //std::sort(idxs.begin(), idxs.end()); + if (input->isVector()) { // 1D case - output->assign(1.f); + for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) { + double sumValue = input->e(fi->second.at(0)); + Nd4jLong loop_size = fi->second.size(); - if (input->isVector()) { // 1D case - for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) { - T prodValue = input->e(fi->second.at(0)); - for (size_t idx = 1; idx < fi->second.size(); ++idx) { - prodValue *= input->e(fi->second.at(idx)); - } - output->p(fi->first, prodValue); - } - } - else { - auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - - ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); - ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); - - for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) { - auto outputT = listOfOutTensors.at(fi->first); - outputT->assign(listOfTensors.at(fi->second.at(0))); - for (size_t idx = 1; idx < fi->second.size(); ++idx) { - auto current = listOfTensors.at(fi->second.at(idx)); - - *outputT *= *current; - } - } - } + // FIXME: parallelism here? + for (Nd4jLong idx = 1; idx < loop_size; ++idx) { + sumValue += input->e(fi->second.at(idx)); + } + output->p(fi->first, sumValue); } - - void unsortedSegmentProdFunctor(sd::LaunchContext * context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { - BUILD_SINGLE_SELECTOR(input->dataType(), unsortedSegmentProdFunctor_, (input, indices, numOfClasses, output), NUMERIC_TYPES); + } else { + auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); + + ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); + ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); + + for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) { + auto outputT = listOfOutTensors.at(fi->first); + outputT.assign(listOfTensors.at(fi->second.at(0))); + Nd4jLong loop_size = fi->second.size(); + + // FIXME: parallelism here? + for (Nd4jLong idx = 1; idx < loop_size; ++idx) { + auto current = listOfTensors.at(fi->second.at(idx)); + outputT += current; + } + // outputT->assign(maxT); } - BUILD_SINGLE_TEMPLATE(template void unsortedSegmentProdFunctor_, (NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output), NUMERIC_TYPES); + } +} - void unsortedSegmentSqrtNFunctor(sd::LaunchContext * context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { - MAP_IMPL> idxs;//(indices->lengthOf()); - for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) - idxs[indices->e(e)].push_back(e); +template +void unsortedSegmentProdFunctor_(NDArray* input, NDArray* indices, + Nd4jLong numOfClasses, NDArray* output) { + MAP_IMPL> idxs; //(indices->lengthOf()); + for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) + idxs[indices->e(e)].push_back(e); - //std::sort(idxs.begin(), idxs.end()); + // std::sort(idxs.begin(), idxs.end()); - if (input->isVector()) { // 1D case - for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) { - double sumValue = input->e(fi->second.at(0)); - for (size_t idx = 1; idx < fi->second.size(); ++idx) { - sumValue += input->e(fi->second.at(idx)); - } - output->p(fi->first, sumValue / sd::math::nd4j_sqrt(fi->second.size())); - } - } - else { - auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - - ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); - ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); - - for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) { - auto outputT = listOfOutTensors.at(fi->first); - outputT->assign(listOfTensors.at(fi->second.at(0))); - for (size_t idx = 1; idx < fi->second.size(); ++idx) { - auto current = listOfTensors.at(fi->second.at(idx)); - *outputT += *current; - } - //outputT->assign(maxT); - (*outputT) /= sd::math::nd4j_sqrt(fi->second.size()); - } - } + output->assign(1.f); + + if (input->isVector()) { // 1D case + for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) { + T prodValue = input->e(fi->second.at(0)); + for (size_t idx = 1; idx < fi->second.size(); ++idx) { + prodValue *= input->e(fi->second.at(idx)); + } + output->p(fi->first, prodValue); } + } else { + auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - // -------------------------------------------------------------------------------------------------------------- // - // Backpropagate ops helpers - // -------------------------------------------------------------------------------------------------------------- // - // Sorted backpropagate ops - // - // segment max - template - int segmentMaxFunctorBP_(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) { - //int numOfClasses = gradOut->sizeAt(0); - // if input is a vector: (as if in doc sample) - auto tempRes = gradOut->dup(); - segmentMaxFunctor_(input, indices, &tempRes); - if (input->isVector()) { - Nd4jLong loop_size = input->lengthOf(); - - auto func = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) { - auto classNum = indices->e(e); - if (sd::math::nd4j_abs(tempRes.e(classNum) - input->e(e)) <= T(1.e-6)) - output->p(e, gradOut->e(classNum)); - } - }; - samediff::Threads::parallel_for(func, 0, loop_size); - } - else { - std::vector restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - - ResultSet listOfBPTensors = tempRes.allTensorsAlongDimension(restDims); - ResultSet listOfGradOuts = gradOut->allTensorsAlongDimension(restDims); - ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); - ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); - - //int numOfClasses = tempRes.sizeAt(0); // number of classes - //std::vector> outputs(numOfClasses); - - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - auto classNum = indices->e(i); - auto current = listOfTensors.at(i); - auto currentOut = listOfOutTensors.at(i); - auto currentGradOut = listOfGradOuts.at(classNum); - - for (Nd4jLong e = 0; e < current->lengthOf(); e++) { - if (sd::math::nd4j_abs(listOfBPTensors.at(classNum)->e(e) - current->e(e)) <= T(1.e-6)) - currentOut->p(e, currentGradOut->e(e)); - } - } - }; + ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); + ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); - samediff::Threads::parallel_tad(func, 0, indices->lengthOf()); - } + for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) { + auto outputT = listOfOutTensors.at(fi->first); + outputT.assign(listOfTensors.at(fi->second.at(0))); + for (size_t idx = 1; idx < fi->second.size(); ++idx) { + auto current = listOfTensors.at(fi->second.at(idx)); - return ND4J_STATUS_OK; + outputT *= current; + } } + } +} - int segmentMaxFunctorBP(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) { - BUILD_SINGLE_SELECTOR(output->dataType(), return segmentMaxFunctorBP_, (context, input, indices, gradOut, output), NUMERIC_TYPES); +void unsortedSegmentProdFunctor(sd::LaunchContext* context, NDArray* input, + NDArray* indices, Nd4jLong numOfClasses, + NDArray* output) { + BUILD_SINGLE_SELECTOR(input->dataType(), unsortedSegmentProdFunctor_, + (input, indices, numOfClasses, output), NUMERIC_TYPES); +} +BUILD_SINGLE_TEMPLATE(template void unsortedSegmentProdFunctor_, + (NDArray * input, NDArray* indices, Nd4jLong numOfClasses, + NDArray* output), + NUMERIC_TYPES); + +void unsortedSegmentSqrtNFunctor(sd::LaunchContext* context, NDArray* input, + NDArray* indices, Nd4jLong numOfClasses, + NDArray* output) { + MAP_IMPL> idxs; //(indices->lengthOf()); + for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) + idxs[indices->e(e)].push_back(e); + + // std::sort(idxs.begin(), idxs.end()); + + if (input->isVector()) { // 1D case + for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) { + double sumValue = input->e(fi->second.at(0)); + for (size_t idx = 1; idx < fi->second.size(); ++idx) { + sumValue += input->e(fi->second.at(idx)); + } + output->p(fi->first, sumValue / sd::math::nd4j_sqrt( + fi->second.size())); } - BUILD_SINGLE_TEMPLATE(template int segmentMaxFunctorBP_, (sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output), NUMERIC_TYPES); - - // segmen min - int segmentMinFunctorBP(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) { - NDArray tempRes = gradOut->dup(); - segmentMinFunctor(context, input, indices, &tempRes); - if (input->isVector()) { - auto func = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) { - auto classNum = indices->e(e); - if (sd::math::nd4j_abs(tempRes.e(classNum) - input->e(e)) < 1.e-5) - output->p(e, gradOut->e(classNum)); - } - }; - samediff::Threads::parallel_for(func, 0, input->lengthOf()); - } - else { - auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - - ResultSet listOfBPTensors = tempRes.allTensorsAlongDimension(restDims); - ResultSet listOfGradOuts = gradOut->allTensorsAlongDimension(restDims); - ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); - ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); - - //int numOfClasses = tempRes.sizeAt(0); // number of classes - //std::vector> outputs(numOfClasses); - output->assign(0.); - int pos = 0; - - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - auto classNum = indices->e(i); - auto current = listOfTensors.at(i); - auto currentOut = listOfOutTensors.at(i); - auto currentGradOut = listOfGradOuts.at(classNum); - - for (Nd4jLong e = 0; e < current->lengthOf(); e++) { - if (sd::math::nd4j_abs(listOfBPTensors.at(classNum)->e(e) - current->e(e)) < - 1.e-5) - currentOut->p(e, currentGradOut->e(e)); - } - } - }; - - samediff::Threads::parallel_tad(func, 0, indices->lengthOf()); - } - return ND4J_STATUS_OK; + } else { + auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); + + ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); + ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); + + for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) { + auto outputT = listOfOutTensors.at(fi->first); + outputT.assign(listOfTensors.at(fi->second.at(0))); + for (size_t idx = 1; idx < fi->second.size(); ++idx) { + auto current = listOfTensors.at(fi->second.at(idx)); + outputT += current; + } + // outputT->assign(maxT); + outputT /= sd::math::nd4j_sqrt(fi->second.size()); } + } +} - // segmen mean - int segmentMeanFunctorBP(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) { - int numClasses = output->sizeAt(0); - MAP_IMPL classCount;//(numClasses); - - for (Nd4jLong count = 0; count < numClasses; ++count) { - classCount[count] = 0; - } - - for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) { - classCount[indices->e(e)] ++; +// -------------------------------------------------------------------------------------------------------------- +// // Backpropagate ops helpers +// -------------------------------------------------------------------------------------------------------------- +// // Sorted backpropagate ops +// +// segment max +template +int segmentMaxFunctorBP_(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* gradOut, NDArray* output) { + // int numOfClasses = gradOut->sizeAt(0); + // if input is a vector: (as if in doc sample) + auto tempRes = gradOut->dup(); + segmentMaxFunctor_(input, indices, &tempRes); + if (input->isVector()) { + Nd4jLong loop_size = input->lengthOf(); + + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + auto classNum = indices->e(e); + if (sd::math::nd4j_abs(tempRes.e(classNum) - input->e(e)) <= + T(1.e-6)) + output->p(e, gradOut->e(classNum)); + } + }; + samediff::Threads::parallel_for(func, 0, loop_size); + } else { + std::vector restDims = + ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); + + ResultSet listOfBPTensors = tempRes.allTensorsAlongDimension(restDims); + ResultSet listOfGradOuts = gradOut->allTensorsAlongDimension(restDims); + ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); + ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); + + // int numOfClasses = tempRes.sizeAt(0); // number of classes + // std::vector> outputs(numOfClasses); + + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + auto classNum = indices->e(i); + auto current = listOfTensors.at(i); + auto currentOut = listOfOutTensors.at(i); + auto currentGradOut = listOfGradOuts.at(classNum); + + for (Nd4jLong e = 0; e < current.lengthOf(); e++) { + if (sd::math::nd4j_abs(listOfBPTensors.at(classNum).e(e) - + current.e(e)) <= T(1.e-6)) + currentOut.p(e, currentGradOut.e(e)); } + } + }; - // if input is a vector: (as if in doc sample) - if (input->isVector()) { - for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) { - Nd4jLong classNum = indices->e(e); - output->p(e, gradOut->e(classNum) / classCount[classNum]); - } - } - else { - auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - - ResultSet listOfGradOuts = gradOut->allTensorsAlongDimension(restDims); - ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); - ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); -; - - int pos = 0; - //auto func = [&](uint64_t thread_id, uint64_t start, uint64_t stop, uint64_t increment) -> void { - for (Nd4jLong i = 0; i < indices->lengthOf(); i++) { - auto classNum = indices->e(i); - auto current = listOfTensors.at(i); - auto currentOut = listOfOutTensors.at(i); - auto currentGradOut = listOfGradOuts.at(classNum); - - for (Nd4jLong e = 0; e < current->lengthOf(); e++) { - currentOut->p(e, currentGradOut->e(e) / classCount.at(classNum)); - } - } - //}; + samediff::Threads::parallel_tad(func, 0, indices->lengthOf()); + } - //samediff::Threads::parallel_for(func, 0, indices->lengthOf()); - } - return ND4J_STATUS_OK; - } + return ND4J_STATUS_OK; +} - int segmentSumFunctorBP(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) { -// int numClasses = output->sizeAt(0); - // if input is a vector: (as if in doc sample) - Nd4jLong idx = indices->e(0); - if (input->isVector()) { - for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) { - Nd4jLong classNum = indices->e(e); - output->p(e, gradOut->e(classNum)); - } +int segmentMaxFunctorBP(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* gradOut, NDArray* output) { + BUILD_SINGLE_SELECTOR(output->dataType(), return segmentMaxFunctorBP_, + (context, input, indices, gradOut, output), + NUMERIC_TYPES); +} +BUILD_SINGLE_TEMPLATE(template int segmentMaxFunctorBP_, + (sd::LaunchContext * context, NDArray* input, + NDArray* indices, NDArray* gradOut, NDArray* output), + NUMERIC_TYPES); + +// segmen min +int segmentMinFunctorBP(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* gradOut, NDArray* output) { + NDArray tempRes = gradOut->dup(); + segmentMinFunctor(context, input, indices, &tempRes); + if (input->isVector()) { + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + auto classNum = indices->e(e); + if (sd::math::nd4j_abs(tempRes.e(classNum) - + input->e(e)) < 1.e-5) + output->p(e, gradOut->e(classNum)); + } + }; + samediff::Threads::parallel_for(func, 0, input->lengthOf()); + } else { + auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); + + ResultSet listOfBPTensors = tempRes.allTensorsAlongDimension(restDims); + ResultSet listOfGradOuts = gradOut->allTensorsAlongDimension(restDims); + ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); + ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); + + // int numOfClasses = tempRes.sizeAt(0); // number of classes + // std::vector> outputs(numOfClasses); + output->assign(0.); + int pos = 0; + + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + auto classNum = indices->e(i); + auto current = listOfTensors.at(i); + auto currentOut = listOfOutTensors.at(i); + auto currentGradOut = listOfGradOuts.at(classNum); + + for (Nd4jLong e = 0; e < current.lengthOf(); e++) { + if (sd::math::nd4j_abs(listOfBPTensors.at(classNum).e(e) - + current.e(e)) < 1.e-5) + currentOut.p(e, currentGradOut.e(e)); } - else { - auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); + } + }; - ResultSet listOfGradOuts = gradOut->allTensorsAlongDimension(restDims); - ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); - ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); + samediff::Threads::parallel_tad(func, 0, indices->lengthOf()); + } + return ND4J_STATUS_OK; +} - //auto func = PRAGMA_THREADS_FOR { - for (Nd4jLong i = 0; i < indices->lengthOf(); i++) { - auto classNum = indices->e(i); - auto current = listOfTensors.at(i); - auto currentOut = listOfOutTensors.at(i); - auto currentGradOut = listOfGradOuts.at(classNum); +// segmen mean +int segmentMeanFunctorBP(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* gradOut, NDArray* output) { + int numClasses = output->sizeAt(0); + MAP_IMPL classCount; //(numClasses); + + for (Nd4jLong count = 0; count < numClasses; ++count) { + classCount[count] = 0; + } + + for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) { + classCount[indices->e(e)]++; + } + + // if input is a vector: (as if in doc sample) + if (input->isVector()) { + for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) { + Nd4jLong classNum = indices->e(e); + output->p(e, gradOut->e(classNum) / classCount[classNum]); + } + } else { + auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); + + ResultSet listOfGradOuts = gradOut->allTensorsAlongDimension(restDims); + ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); + ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); + ; + + int pos = 0; + // auto func = [&](uint64_t thread_id, uint64_t start, uint64_t stop, + // uint64_t increment) -> void { + for (Nd4jLong i = 0; i < indices->lengthOf(); i++) { + auto classNum = indices->e(i); + auto current = listOfTensors.at(i); + auto currentOut = listOfOutTensors.at(i); + auto currentGradOut = listOfGradOuts.at(classNum); + + for (Nd4jLong e = 0; e < current.lengthOf(); e++) { + currentOut.p(e, currentGradOut.e(e) / classCount.at(classNum)); + } + } + //}; - currentOut->assign(currentGradOut); - } - //}; + // samediff::Threads::parallel_for(func, 0, indices->lengthOf()); + } + return ND4J_STATUS_OK; +} - //samediff::Threads::parallel_for(func, 0, indices->lengthOf()); - } - return Status::OK(); +int segmentSumFunctorBP(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* gradOut, NDArray* output) { + // int numClasses = output->sizeAt(0); + // if input is a vector: (as if in doc sample) + Nd4jLong idx = indices->e(0); + if (input->isVector()) { + for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) { + Nd4jLong classNum = indices->e(e); + output->p(e, gradOut->e(classNum)); } + } else { + auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - int segmentProdFunctorBP(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) { - auto tempRes = gradOut->dup(); - segmentProdFunctor(context, input, indices, &tempRes); - if (input->isVector()) { - for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) { - Nd4jLong classNum = indices->e(e); - output->p(e, gradOut->e(classNum) * tempRes.e(classNum)/ input->e(e)); - } - } - else { - auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - - ResultSet listOfBPTensors = tempRes.allTensorsAlongDimension(restDims); - ResultSet listOfGradOuts = gradOut->allTensorsAlongDimension(restDims); - ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); - ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); - - //int numOfClasses = tempRes.sizeAt(0); // number of classes - //std::vector> outputs(numOfClasses); - - //auto func = PRAGMA_THREADS_FOR { - for (Nd4jLong i = 0; i < indices->lengthOf(); i++) { - auto classNum = indices->e(i); - auto current = listOfTensors.at(i); - auto currentOut = listOfOutTensors.at(i); - auto currentGradOut = listOfGradOuts.at(classNum); - auto currentFFOut = listOfBPTensors.at(classNum); - - currentOut->assign((*currentFFOut) * (*currentGradOut) / (*current)); - } - //}; + ResultSet listOfGradOuts = gradOut->allTensorsAlongDimension(restDims); + ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); + ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); - //samediff::Threads::parallel_for(func, 0, indices->lengthOf()); - } + // auto func = PRAGMA_THREADS_FOR { + for (Nd4jLong i = 0; i < indices->lengthOf(); i++) { + auto classNum = indices->e(i); + auto current = listOfTensors.at(i); + auto currentOut = listOfOutTensors.at(i); + auto currentGradOut = listOfGradOuts.at(classNum); - return ND4J_STATUS_OK; + currentOut.assign(currentGradOut); } + //}; - // -------------------------------------------------------------------------------------------------------------- // - // Unsorted backpropagate segment ops - // -------------------------------------------------------------------------------------------------------------- // - - template - static int unsortedSegmentMaxFunctorBP_(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { -// int numOfClasses = gradOut->sizeAt(0); - // if input is a vector: (as if in doc sample) - auto tempRes = gradOut->dup(); - unsortedSegmentMaxFunctor(context, input, indices, numOfClasses, &tempRes); - if (input->isVector()) { - - for (Nd4jLong e = 0; e < input->lengthOf(); ++e) { - Nd4jLong classNum = indices->e(e); - if (sd::math::nd4j_abs(tempRes.e(classNum) - input->e(e)) < 1.e-5) - output->p(e, gradOut->e(classNum)); - } - } - else { - auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - - ResultSet listOfBPTensors = tempRes.allTensorsAlongDimension(restDims); - ResultSet listOfGradOuts = gradOut->allTensorsAlongDimension(restDims); - ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); - ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); - - for (Nd4jLong i = 0; i < indices->lengthOf(); i++) { - Nd4jLong classNum = indices->e(i); - NDArray* current = listOfTensors.at(i); - NDArray* currentOut = listOfOutTensors.at(i); - NDArray* currentGradOut = listOfGradOuts.at(classNum); - for (int e = 0; e < current->lengthOf(); e++) { - if (sd::math::nd4j_abs(listOfBPTensors.at(classNum)->e(e) - current->e(e)) < 1.e-5) - currentOut->p(e, currentGradOut->e(e)); - } - } - } + // samediff::Threads::parallel_for(func, 0, indices->lengthOf()); + } + return Status::OK(); +} - return ND4J_STATUS_OK; +int segmentProdFunctorBP(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* gradOut, NDArray* output) { + auto tempRes = gradOut->dup(); + segmentProdFunctor(context, input, indices, &tempRes); + if (input->isVector()) { + for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) { + Nd4jLong classNum = indices->e(e); + output->p(e, gradOut->e(classNum) * tempRes.e(classNum) / + input->e(e)); } - - int unsortedSegmentMaxFunctorBP(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { - BUILD_SINGLE_SELECTOR(output->dataType(), return unsortedSegmentMaxFunctorBP_, (context, input, indices, gradOut, numOfClasses, output), NUMERIC_TYPES); + } else { + auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); + + ResultSet listOfBPTensors = tempRes.allTensorsAlongDimension(restDims); + ResultSet listOfGradOuts = gradOut->allTensorsAlongDimension(restDims); + ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); + ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); + + // int numOfClasses = tempRes.sizeAt(0); // number of classes + // std::vector> outputs(numOfClasses); + + // auto func = PRAGMA_THREADS_FOR { + for (Nd4jLong i = 0; i < indices->lengthOf(); i++) { + auto classNum = indices->e(i); + auto current = listOfTensors.at(i); + auto currentOut = listOfOutTensors.at(i); + auto currentGradOut = listOfGradOuts.at(classNum); + auto currentFFOut = listOfBPTensors.at(classNum); + + currentOut.assign(currentFFOut * currentGradOut / current); } - BUILD_SINGLE_TEMPLATE(template int unsortedSegmentMaxFunctorBP_, (sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output), NUMERIC_TYPES); - - template - static int unsortedSegmentMinFunctorBP_(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { - auto tempRes = gradOut->dup(); - unsortedSegmentMinFunctor(context, input, indices, numOfClasses, &tempRes); - if (input->isVector()) { - - auto func = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) { - auto classNum = indices->e(e); - if (sd::math::nd4j_abs(tempRes.t(classNum) - input->t(e)) < 1.e-6) - output->r(e) = gradOut->t(classNum); - } - }; + //}; - samediff::Threads::parallel_for(func, 0, input->lengthOf()); - } - else { - auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - - ResultSet listOfBPTensors = tempRes.allTensorsAlongDimension(restDims); - ResultSet listOfGradOuts = gradOut->allTensorsAlongDimension(restDims); - ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); - ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); - - //auto func = PRAGMA_THREADS_FOR { - for (Nd4jLong i = 0; i < indices->lengthOf(); i++) { - auto classNum = indices->e(i); - auto current = listOfTensors.at(i); - auto currentOut = listOfOutTensors.at(i); - auto currentGradOut = listOfGradOuts.at(classNum); - - for (Nd4jLong e = 0; e < current->lengthOf(); e++) { - if (sd::math::nd4j_abs(listOfBPTensors.at(classNum)->t(e) - current->t(e)) < 1.e-6) - currentOut->r(e) = currentGradOut->t(e); - } - } - //}; + // samediff::Threads::parallel_for(func, 0, indices->lengthOf()); + } - //samediff::Threads::parallel_for(func, 0, indices->lengthOf()); - } + return ND4J_STATUS_OK; +} - return ND4J_STATUS_OK; +// -------------------------------------------------------------------------------------------------------------- +// // Unsorted backpropagate segment ops +// -------------------------------------------------------------------------------------------------------------- +// // + +template +static int unsortedSegmentMaxFunctorBP_(sd::LaunchContext* context, + NDArray* input, NDArray* indices, + NDArray* gradOut, Nd4jLong numOfClasses, + NDArray* output) { + // int numOfClasses = gradOut->sizeAt(0); + // if input is a vector: (as if in doc sample) + auto tempRes = gradOut->dup(); + unsortedSegmentMaxFunctor(context, input, indices, numOfClasses, &tempRes); + if (input->isVector()) { + for (Nd4jLong e = 0; e < input->lengthOf(); ++e) { + Nd4jLong classNum = indices->e(e); + if (sd::math::nd4j_abs(tempRes.e(classNum) - + input->e(e)) < 1.e-5) + output->p(e, gradOut->e(classNum)); } - - int unsortedSegmentMinFunctorBP(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { - BUILD_SINGLE_SELECTOR(output->dataType(), return unsortedSegmentMinFunctorBP_, (context, input, indices, gradOut, numOfClasses, output), NUMERIC_TYPES); + } else { + auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); + + ResultSet listOfBPTensors = tempRes.allTensorsAlongDimension(restDims); + ResultSet listOfGradOuts = gradOut->allTensorsAlongDimension(restDims); + ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); + ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); + + for (Nd4jLong i = 0; i < indices->lengthOf(); i++) { + Nd4jLong classNum = indices->e(i); + auto current = listOfTensors.at(i); + auto currentOut = listOfOutTensors.at(i); + auto currentGradOut = listOfGradOuts.at(classNum); + for (int e = 0; e < current.lengthOf(); e++) { + if (sd::math::nd4j_abs(listOfBPTensors.at(classNum).e(e) - + current.e(e)) < 1.e-5) + currentOut.p(e, currentGradOut.e(e)); + } } - BUILD_SINGLE_TEMPLATE(template int unsortedSegmentMinFunctorBP_, (sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output), NUMERIC_TYPES); + } - int unsortedSegmentMeanFunctorBP(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { + return ND4J_STATUS_OK; +} - MAP_IMPL classCount;//(numClasses); +int unsortedSegmentMaxFunctorBP(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* gradOut, + Nd4jLong numOfClasses, NDArray* output) { + BUILD_SINGLE_SELECTOR( + output->dataType(), return unsortedSegmentMaxFunctorBP_, + (context, input, indices, gradOut, numOfClasses, output), NUMERIC_TYPES); +} +BUILD_SINGLE_TEMPLATE(template int unsortedSegmentMaxFunctorBP_, + (sd::LaunchContext * context, NDArray* input, + NDArray* indices, NDArray* gradOut, + Nd4jLong numOfClasses, NDArray* output), + NUMERIC_TYPES); + +template +static int unsortedSegmentMinFunctorBP_(sd::LaunchContext* context, + NDArray* input, NDArray* indices, + NDArray* gradOut, Nd4jLong numOfClasses, + NDArray* output) { + auto tempRes = gradOut->dup(); + unsortedSegmentMinFunctor(context, input, indices, numOfClasses, &tempRes); + if (input->isVector()) { + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + auto classNum = indices->e(e); + if (sd::math::nd4j_abs(tempRes.t(classNum) - input->t(e)) < 1.e-6) + output->r(e) = gradOut->t(classNum); + } + }; + + samediff::Threads::parallel_for(func, 0, input->lengthOf()); + } else { + auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); + + ResultSet listOfBPTensors = tempRes.allTensorsAlongDimension(restDims); + ResultSet listOfGradOuts = gradOut->allTensorsAlongDimension(restDims); + ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); + ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); + + // auto func = PRAGMA_THREADS_FOR { + for (Nd4jLong i = 0; i < indices->lengthOf(); i++) { + auto classNum = indices->e(i); + auto current = listOfTensors.at(i); + auto currentOut = listOfOutTensors.at(i); + auto currentGradOut = listOfGradOuts.at(classNum); + + for (Nd4jLong e = 0; e < current.lengthOf(); e++) { + if (sd::math::nd4j_abs(listOfBPTensors.at(classNum).t(e) - current.t(e)) < 1.e-6) + currentOut.r(e) = currentGradOut.t(e); + } + } + //}; - for (Nd4jLong count = 0; count < numOfClasses; ++count) { - classCount[count] = 0; - } + // samediff::Threads::parallel_for(func, 0, indices->lengthOf()); + } - for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) { - classCount[indices->e(e)]++; - } + return ND4J_STATUS_OK; +} - // if input is a vector: (as if in doc sample) - if (input->isVector()) { - for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) { - Nd4jLong classNum = indices->e(e); - output->p(e, gradOut->e(classNum) / classCount[classNum]); - } - } - else { - auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - - ResultSet listOfGradOuts = gradOut->allTensorsAlongDimension(restDims); - ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); - ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); - - for (Nd4jLong i = 0; i < indices->lengthOf(); i++) { - Nd4jLong classNum = indices->e(i); - NDArray* current = listOfTensors.at(i); - NDArray* currentOut = listOfOutTensors.at(i); - NDArray* currentGradOut = listOfGradOuts.at(classNum); - currentOut->assign(*currentGradOut / double(classCount[classNum])); - } - } - return ND4J_STATUS_OK; +int unsortedSegmentMinFunctorBP(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* gradOut, + Nd4jLong numOfClasses, NDArray* output) { + BUILD_SINGLE_SELECTOR( + output->dataType(), return unsortedSegmentMinFunctorBP_, + (context, input, indices, gradOut, numOfClasses, output), NUMERIC_TYPES); +} +BUILD_SINGLE_TEMPLATE(template int unsortedSegmentMinFunctorBP_, + (sd::LaunchContext * context, NDArray* input, + NDArray* indices, NDArray* gradOut, + Nd4jLong numOfClasses, NDArray* output), + NUMERIC_TYPES); + +int unsortedSegmentMeanFunctorBP(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* gradOut, + Nd4jLong numOfClasses, NDArray* output) { + MAP_IMPL classCount; //(numClasses); + + for (Nd4jLong count = 0; count < numOfClasses; ++count) { + classCount[count] = 0; + } + + for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) { + classCount[indices->e(e)]++; + } + + // if input is a vector: (as if in doc sample) + if (input->isVector()) { + for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) { + Nd4jLong classNum = indices->e(e); + output->p(e, gradOut->e(classNum) / classCount[classNum]); } - - int unsortedSegmentSumFunctorBP(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { - - // if input is a vector: (as if in doc sample) - Nd4jLong idx = indices->e(0); - if (input->isVector()) { - for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) { - Nd4jLong classNum = indices->e(e); - output->p(e, gradOut->e(classNum)); - } - } - else { - auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - - ResultSet listOfGradOuts = gradOut->allTensorsAlongDimension(restDims); - ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); - ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); - - //auto func = PRAGMA_THREADS_FOR { - for (Nd4jLong i = 0; i < indices->lengthOf(); i++) { - auto classNum = indices->e(i); - auto currentOut = listOfOutTensors.at(i); - auto currentGradOut = listOfGradOuts.at(classNum); - - currentOut->assign(currentGradOut); - } - //}; - - //samediff::Threads::parallel_for(func, 0, indices->lengthOf()); - } - return Status::OK(); + } else { + auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); + + ResultSet listOfGradOuts = gradOut->allTensorsAlongDimension(restDims); + ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); + ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); + + for (Nd4jLong i = 0; i < indices->lengthOf(); i++) { + Nd4jLong classNum = indices->e(i); + auto current = listOfTensors.at(i); + auto currentOut = listOfOutTensors.at(i); + auto currentGradOut = listOfGradOuts.at(classNum); + currentOut.assign(currentGradOut / double(classCount[classNum])); } + } + return ND4J_STATUS_OK; +} - int unsortedSegmentProdFunctorBP(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { - - auto tempRes = gradOut->dup(); - unsortedSegmentProdFunctor(context, input, indices, numOfClasses, &tempRes); - if (input->isVector()) { - auto func = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) { - auto classNum = indices->e(e); - output->p(e, gradOut->e(classNum) * tempRes.e(classNum) / input->e(e)); - } - }; +int unsortedSegmentSumFunctorBP(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* gradOut, + Nd4jLong numOfClasses, NDArray* output) { + // if input is a vector: (as if in doc sample) + Nd4jLong idx = indices->e(0); + if (input->isVector()) { + for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) { + Nd4jLong classNum = indices->e(e); + output->p(e, gradOut->e(classNum)); + } + } else { + auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - samediff::Threads::parallel_for(func, 0, indices->lengthOf()); - } - else { - auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - - ResultSet listOfBPTensors = tempRes.allTensorsAlongDimension(restDims); - ResultSet listOfGradOuts = gradOut->allTensorsAlongDimension(restDims); - ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); - ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); - - //auto func = PRAGMA_THREADS_FOR { - for (Nd4jLong i = 0; i < indices->lengthOf(); i++) { - auto classNum = indices->e(i); - auto current = listOfTensors.at(i); - auto currentOut = listOfOutTensors.at(i); - auto currentGradOut = listOfGradOuts.at(classNum); - auto currentFFOut = listOfBPTensors.at(classNum); - - currentOut->assign((*currentFFOut) * (*currentGradOut) / (*current)); - } - //}; + ResultSet listOfGradOuts = gradOut->allTensorsAlongDimension(restDims); + ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); + ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); - //samediff::Threads::parallel_for(func, 0, indices->lengthOf()); - } + // auto func = PRAGMA_THREADS_FOR { + for (Nd4jLong i = 0; i < indices->lengthOf(); i++) { + auto classNum = indices->e(i); + auto currentOut = listOfOutTensors.at(i); + auto currentGradOut = listOfGradOuts.at(classNum); - return Status::OK(); + currentOut.assign(currentGradOut); } + //}; -// template - int unsortedSegmentSqrtNFunctorBP(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { - MAP_IMPL classCount;//(numClasses); - - for (Nd4jLong count = 0; count < numOfClasses; ++count) { - classCount[count] = 0; - } + // samediff::Threads::parallel_for(func, 0, indices->lengthOf()); + } + return Status::OK(); +} - for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) { - classCount[indices->e(e)]++; - } +int unsortedSegmentProdFunctorBP(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* gradOut, + Nd4jLong numOfClasses, NDArray* output) { + auto tempRes = gradOut->dup(); + unsortedSegmentProdFunctor(context, input, indices, numOfClasses, &tempRes); + if (input->isVector()) { + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + auto classNum = indices->e(e); + output->p(e, gradOut->e(classNum) * + tempRes.e(classNum) / + input->e(e)); + } + }; + + samediff::Threads::parallel_for(func, 0, indices->lengthOf()); + } else { + auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); + + ResultSet listOfBPTensors = tempRes.allTensorsAlongDimension(restDims); + ResultSet listOfGradOuts = gradOut->allTensorsAlongDimension(restDims); + ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); + ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); + + // auto func = PRAGMA_THREADS_FOR { + for (Nd4jLong i = 0; i < indices->lengthOf(); i++) { + auto classNum = indices->e(i); + auto current = listOfTensors.at(i); + auto currentOut = listOfOutTensors.at(i); + auto currentGradOut = listOfGradOuts.at(classNum); + auto currentFFOut = listOfBPTensors.at(classNum); + + currentOut.assign(currentFFOut * currentGradOut / current); + } + //}; - // if input is a vector: (as if in doc sample) - if (input->isVector()) { - //auto func = PRAGMA_THREADS_FOR { - for (Nd4jLong e = 0; e < indices->lengthOf(); e++) { - auto classNum = indices->e(e); - output->p(e, gradOut->e(classNum) / sd::math::nd4j_sqrt(classCount[classNum])); - } - //}; + // samediff::Threads::parallel_for(func, 0, indices->lengthOf()); + } - //samediff::Threads::parallel_for(func, 0, indices->lengthOf()); - } - else { - auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - - ResultSet listOfGradOuts =gradOut->allTensorsAlongDimension(restDims); - ResultSet listOfTensors =input->allTensorsAlongDimension(restDims); - ResultSet listOfOutTensors =output->allTensorsAlongDimension(restDims); - - //auto func = PRAGMA_THREADS_FOR { - for (Nd4jLong i = 0; i < indices->lengthOf(); i++) { - auto classNum = indices->e(i); - auto current = listOfTensors.at(i); - auto currentOut = listOfOutTensors.at(i); - auto currentGradOut = listOfGradOuts.at(classNum); - - for (int e = 0; e < current->lengthOf(); e++) { - currentOut->p(e, currentGradOut->e(e) / sd::math::nd4j_sqrt(classCount[classNum])); - } - } - //}; + return Status::OK(); +} - //samediff::Threads::parallel_for(func, 0, indices->lengthOf()); - } - return Status::OK(); +// template +int unsortedSegmentSqrtNFunctorBP(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* gradOut, + Nd4jLong numOfClasses, NDArray* output) { + MAP_IMPL classCount; //(numClasses); + + for (Nd4jLong count = 0; count < numOfClasses; ++count) { + classCount[count] = 0; + } + + for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) { + classCount[indices->e(e)]++; + } + + // if input is a vector: (as if in doc sample) + if (input->isVector()) { + // auto func = PRAGMA_THREADS_FOR { + for (Nd4jLong e = 0; e < indices->lengthOf(); e++) { + auto classNum = indices->e(e); + output->p(e, + gradOut->e(classNum) / + sd::math::nd4j_sqrt(classCount[classNum])); } + //}; + + // samediff::Threads::parallel_for(func, 0, indices->lengthOf()); + } else { + auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); + + ResultSet listOfGradOuts = gradOut->allTensorsAlongDimension(restDims); + ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); + ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); + + // auto func = PRAGMA_THREADS_FOR { + for (Nd4jLong i = 0; i < indices->lengthOf(); i++) { + auto classNum = indices->e(i); + auto current = listOfTensors.at(i); + auto currentOut = listOfOutTensors.at(i); + auto currentGradOut = listOfGradOuts.at(classNum); + + for (int e = 0; e < current.lengthOf(); e++) { + currentOut.p( + e, currentGradOut.e(e) / + sd::math::nd4j_sqrt(classCount[classNum])); + } + } + //}; + // samediff::Threads::parallel_for(func, 0, indices->lengthOf()); + } + return Status::OK(); } -} -} + +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/sequence_mask.cpp b/libnd4j/include/ops/declarable/helpers/cpu/sequence_mask.cpp index 3c8ce573e3b0..34c344d8bb64 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/sequence_mask.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/sequence_mask.cpp @@ -18,30 +18,36 @@ // @author GS // -#include #include +#include namespace sd { namespace ops { namespace helpers { - template - static void sequenceMask_(NDArray* input, NDArray* output, int maxIndex) { - auto func = PRAGMA_THREADS_FOR_2D { - for (auto i = start_x; i < stop_x; i += inc_x) - for (auto k = start_y; k < stop_y; k += inc_y) - if (i < input->t(k)) - output->r(k * maxIndex + i) = B(true); //, T(1.0f)); - }; - - samediff::Threads::parallel_for(func, 0, maxIndex, 1, 0, input->lengthOf(), 1); - } +template +static void sequenceMask_(NDArray* input, NDArray* output, int maxIndex) { + auto func = PRAGMA_THREADS_FOR_2D { + for (auto i = start_x; i < stop_x; i += inc_x) + for (auto k = start_y; k < stop_y; k += inc_y) + if (i < input->t(k)) + output->r(k * maxIndex + i) = B(true); //, T(1.0f)); + }; - void sequenceMask(sd::LaunchContext * context, NDArray* input, NDArray* output, int maxIndex) { - BUILD_DOUBLE_SELECTOR(input->dataType(), output->dataType(), sequenceMask_, (input, output, maxIndex), INTEGER_TYPES, LIBND4J_TYPES_EXTENDED); - } - - BUILD_DOUBLE_TEMPLATE(template void sequenceMask_, (NDArray* input, NDArray* output, int maxIndex), INTEGER_TYPES, LIBND4J_TYPES_EXTENDED); + samediff::Threads::parallel_for(func, 0, maxIndex, 1, 0, input->lengthOf(), + 1); } + +void sequenceMask(sd::LaunchContext* context, NDArray* input, NDArray* output, + int maxIndex) { + BUILD_DOUBLE_SELECTOR(input->dataType(), output->dataType(), sequenceMask_, + (input, output, maxIndex), INTEGER_TYPES, + LIBND4J_TYPES_EXTENDED); } -} \ No newline at end of file + +BUILD_DOUBLE_TEMPLATE(template void sequenceMask_, + (NDArray * input, NDArray* output, int maxIndex), + INTEGER_TYPES, LIBND4J_TYPES_EXTENDED); +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/sg_cb.cpp b/libnd4j/include/ops/declarable/helpers/cpu/sg_cb.cpp index 07cbca04ea5c..0eceb0127c1d 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/sg_cb.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/sg_cb.cpp @@ -18,609 +18,724 @@ // @author raver119@gmail.com // +#include #include #include -#include #define HS_MAX_EXP 6.0f namespace sd { - namespace ops { - namespace helpers { - template - void hSoftmax_(void *vsyn0, void *vsyn1, void *vexpTable, void *vneu1e, double alpha, int vectorLength, int code, int expLength, bool isInference) { - auto syn0 = reinterpret_cast(vsyn0); - auto syn1 = reinterpret_cast(vsyn1); - auto expTable = reinterpret_cast(vexpTable); - auto neu1e = reinterpret_cast(vneu1e); - - T dot(0.0f); - T g(0.0f); - T f(0.0f); - - // dot - for (int e = 0; e < vectorLength; e++) { - dot += syn0[e] * syn1[e]; - } - - // gradient - if (dot < (T) - HS_MAX_EXP || dot >= (T) HS_MAX_EXP) - return; - - - int idx = static_cast((dot + HS_MAX_EXP) * ((float) expLength / HS_MAX_EXP / 2.0f)); - - if (idx >= expLength || idx < 0) - return; - - f = expTable[idx]; - g = (static_cast(1.0f) - static_cast(code) - f) * (T) alpha; - - // axpy1 - - for (int e = 0; e < vectorLength; e++) { - neu1e[e] = g * syn1[e] + neu1e[e]; - } - - // axpy2 - if (!isInference) { - for (int e = 0; e < vectorLength; e++) { - syn1[e] = g * syn0[e] + syn1[e]; - } - } - } - - template - void nSampling_(void *vsyn0, void *vsyn1Neg, void *vexpTable, void *vneu1e, double alpha, int vectorLength, int code, int expLength, bool isInference) { - auto syn0 = reinterpret_cast(vsyn0); - auto syn1Neg = reinterpret_cast(vsyn1Neg); - auto expTable = reinterpret_cast(vexpTable); - auto neu1e = reinterpret_cast(vneu1e); - - T dot = (T) 0.0f; - T g = (T) 0.0f; - - for (int e = 0; e < vectorLength; e++) { - dot += syn0[e] * syn1Neg[e]; - } - - if (dot > HS_MAX_EXP) - g = (code - 1) * alpha; - else if (dot < (T) - HS_MAX_EXP) - g = (code - 0) * alpha; - else { - int idx = (int) ((dot + (T) HS_MAX_EXP) * ((T) expLength / HS_MAX_EXP / 2.0)); - if (idx >= expLength) - return; - - if (idx < 0) - return; - - g = ((T) code - expTable[idx]) * alpha; - } - - // axpy1 - for (int e = 0; e < vectorLength; e++) { - neu1e[e] = g * syn1Neg[e] + neu1e[e]; - } - - // axpy2 - if (!isInference) { - for (int e = 0; e < vectorLength; e++) { - syn1Neg[e] = g * syn0[e] + syn1Neg[e]; - } - } - } - - template - void cbow_(void *vsyn0, void *vsyn1, void *vsyn1Neg, void *vexpTable, void *vnegTable, void *vinfVector, int target, int ngStarter, int *context, int *lockedWords, int *indices, int8_t *codes, double alpha, Nd4jLong randomValue, const int contextWidth, const int hsRounds, const int nsRounds, const int vocabSize, const int vectorLength, const int expLength, const int negLength, const int numLabels, const bool trainWords) { - auto syn0 = reinterpret_cast(vsyn0); - auto syn1 = reinterpret_cast(vsyn1); - auto syn1Neg = reinterpret_cast(vsyn1Neg); - auto expTable = reinterpret_cast(vexpTable); - auto negTable = reinterpret_cast(vnegTable); - auto infVector = reinterpret_cast(vinfVector); - - auto neu1 = new T[vectorLength]; - auto neu1e = new T[vectorLength]; - memset(neu1, 0, vectorLength * sizeof(T)); - memset(neu1e, 0, vectorLength * sizeof(T)); - - // building neu1 for current window - for (int c = 0; c < contextWidth; c++) { - if (context[c] >= vocabSize) - throw std::runtime_error("Bad context 4"); - - T *syn0word = syn0 + (context[c] * vectorLength); - - for (int i = 0; i < vectorLength; i++) { - neu1[i] += syn0word[i]; - } - } - - // for inference we add additional inference vector - if (infVector != nullptr) { - - for (int i = 0; i < vectorLength; i++) { - neu1[i] += infVector[i]; - } - } - - - // average neu1 - if (contextWidth > 0) { - for (int i = 0; i < vectorLength; i++) { - neu1[i] /= contextWidth + (infVector != nullptr ? 1 : 0); - } - } - - // softmax round - if (hsRounds > 0) { - for (int i = 0; i < hsRounds; i++) { - if (indices[i] < 0 || indices[i] >= vocabSize) - throw std::runtime_error("Bad context 5"); - - hSoftmax_(neu1, syn1 + (indices[i] * vectorLength), expTable, neu1e, alpha, vectorLength, codes[i], expLength, infVector != nullptr); - } - } - - auto nsStarter = ngStarter; - auto irow = nsStarter; - if (nsRounds > 0) { - for (int r = 0; r < nsRounds + 1; r++) { - if (r == 0) { - // target is known in advance - } else { - randomValue = randomValue * (unsigned long long) 25214903917 + 11; - auto idx = sd::math::nd4j_abs((randomValue >> 16) % negLength); - irow = idx >= negLength ? -1 : static_cast(negTable[idx]); - - if (irow < 0 || irow >= vocabSize) irow = randomValue % (vocabSize - 1) + 1; - if (irow == nsStarter) - continue; - } - - nSampling_(neu1, syn1Neg + (irow * vectorLength), expTable, neu1e, alpha, vectorLength, r == 0 ? 1 : 0, expLength, infVector != nullptr); - } - } - - // if we don't train words - we skip start of idxSyn0 - int starter = trainWords == 1 ? 0 : contextWidth - numLabels; - - // propagate neu1e -> syn0 - if (infVector == nullptr) { - for (int c = starter; c < contextWidth; c++) { - if (lockedWords[c] == 1) - continue; - - T *syn0word = syn0 + (context[c] * vectorLength); - - for (int i = 0; i < vectorLength; i++) { - syn0word[i] += neu1e[i]; - } - } - } else { - - for (int i = 0; i < vectorLength; i++) { - infVector[i] += neu1e[i]; - } - } - - - delete[] neu1; - delete[] neu1e; - } - BUILD_SINGLE_TEMPLATE(template void cbow_, (void *syn0, void *syn1, void *syn1Neg, void *expTable, void *vnegTable, void *vinfVector, int target, int ngStarter, int *context, int *lockedWords, int *indices, int8_t *codes, double alpha, Nd4jLong randomValue, const int contextWidth, const int hsRounds, const int nsRounds, const int vocabSize, const int vectorLength, const int expLength, const int negLength, const int numLabels, const bool trainWords), FLOAT_TYPES); - - - template - void skipgram_(void *vsyn0, void *vsyn1, void *vsyn1Neg, void *vexpTable, void *vnegTable, void *vinfVector, int target, int ngStarter, int *indices, int8_t *codes, double alpha, Nd4jLong randomValue, const int hsRounds, const int nsRounds, const int vocabSize, const int vectorLength, const int expLength, const int negLength) { - auto syn0 = reinterpret_cast(vsyn0); - auto syn1 = reinterpret_cast(vsyn1); - auto syn1Neg = reinterpret_cast(vsyn1Neg); - auto expTable = reinterpret_cast(vexpTable); - auto negTable = reinterpret_cast(vnegTable); - auto infVector = reinterpret_cast(vinfVector); - - auto neu1e = new T[vectorLength]; - memset(neu1e, 0, vectorLength * sizeof(T)); - - // hierarchic softmax goes first (if enabled) - auto syn0row = infVector != nullptr ? infVector : syn0 + (target * vectorLength); - auto irow = 0; - if (hsRounds > 0) { - for (int r = 0; r < hsRounds; r++) { - irow = indices[r]; - if (irow < 0 || irow >= vocabSize) - break; - - hSoftmax_(syn0row, syn1 + (irow * vectorLength), expTable, neu1e, alpha, vectorLength, codes[r], expLength, infVector != nullptr); - } - } - - // negative sampling goes second (if enabled) - auto nsStarter = ngStarter; - irow = nsStarter; - if (nsRounds > 0) { - for (int r = 0; r < nsRounds + 1; r++) { - if (r == 0) { - // target is known in advance - } else { - randomValue = randomValue * (unsigned long long) 25214903917 + 11; - auto idx = sd::math::nd4j_abs((randomValue >> 16) % negLength); - irow = idx >= negLength ? -1 : static_cast(negTable[idx]); - - if (irow < 0 || irow >= vocabSize) irow = randomValue % (vocabSize - 1) + 1; - if (irow == nsStarter) - continue; - } - - nSampling_(syn0row, syn1Neg + (irow * vectorLength), expTable, neu1e, alpha, vectorLength, r == 0 ? 1 : 0, expLength, infVector != nullptr); - } - } - - if (infVector == nullptr) { - for (int e = 0; e < vectorLength; e++) { - syn0row[e] += neu1e[e]; - } - } else { - for (int e = 0; e < vectorLength; e++) { - infVector[e] += neu1e[e]; - } - } - - delete[] neu1e; - } - BUILD_SINGLE_TEMPLATE(template void skipgram_, (void *syn0, void *syn1, void *syn1Neg, void *expTable, void *vnegTable, void *vinfVector, int target, int ngStarter, int *indices, int8_t *codes, double alpha, Nd4jLong randomValue, const int hsRounds, const int nsRounds, const int vocabSize, const int vectorLength, const int expLength, const int negLength), FLOAT_TYPES); - - int binarySearch(const int *haystack, const int needle, const int totalElements) { - int firstIndex = 0; - int lastIndex = totalElements - 1; - int halfIndex = sd::math::nd4j_floor((lastIndex + firstIndex) / (float) 2); - - while(haystack[halfIndex] != needle && firstIndex < lastIndex) { - if (needle < haystack[halfIndex]) { - lastIndex = halfIndex - 1; - } else if (needle > haystack[halfIndex]) { - firstIndex = halfIndex + 1; - } - halfIndex = sd::math::nd4j_floor((lastIndex + firstIndex) / (float) 2); - } - - return (haystack[halfIndex] == needle) ? halfIndex : -1; - } - - template - static void do_update(const int target, const int rowIndex, const int count, T *syn0, T *neu1t, const int vectorLength) { - - auto syn0row = syn0 + (target * vectorLength); - auto neu1e = neu1t + (rowIndex * vectorLength); - for (int e = 0; e< vectorLength; e++) - syn0row[e] += neu1e[e] / count; - } - - template - static void do_positive(const int target, const int postive, T* syn0, T* syn1Neg, T* expTable, T* neu1e, const double alpha, const int vectorLength, const int expLength) { - //nd4j_printf("Target: [%i]; Positive: [%i]; TID: [%i];\n", target, postive, omp_get_thread_num()); - nSampling_(syn0, syn1Neg, expTable, neu1e, alpha, vectorLength, 1, expLength, false); - } - - template - static void do_negative(int target, int positive, T* syn0, T* syn1Neg, T* expTable, T* negTable, T* neu1e, int *sStarters, const double alpha, const unsigned long long rv, const int vocabSize, const int vectorLength, const int expLength, const int negLength, const int nsRounds, const int numThreads, const int numTargets) { - int irow = 0; - unsigned long long randomValue = rv; - for (int r = 0; r < nsRounds; r++) { - randomValue = sd::math::nd4j_abs(randomValue * (unsigned long long) 25214903917 + 11); - auto idx = sd::math::nd4j_abs((randomValue >> 16) % negLength); - irow = idx >= negLength ? -1 : static_cast(negTable[idx]); - - if (irow < 0 || irow >= vocabSize) - irow = randomValue % (vocabSize - 1) + 1; - - if (irow == positive) - continue; - - // we shift irow here to guarantee independence - - int dim = irow % numThreads; - if (dim != omp_get_thread_num()) { - irow += (numThreads - dim + omp_get_thread_num()); - - // roll back to nearest affilated word - while (irow >= vocabSize) - irow -= numThreads; - - // if this row was processed as first step somewhere - skip it - if (binarySearch(sStarters, irow, numTargets) > 0) { - r--; - continue; - } - } - - - nSampling_(syn0, syn1Neg + (irow * vectorLength), expTable, neu1e, alpha, vectorLength, 0, expLength, false); - } - } - - template - void skipgramBatchExec_(NDArray &s0, NDArray &s1, NDArray &s1n, void *vexpTable, void *vnegTable, void *vinfVector, NDArray &targets, NDArray &negStarters, NDArray &indices, NDArray &codes, NDArray &lr, NDArray &nextRandom, const int nsRounds, const int vocabSize, const int vectorLength, const int expLength, const int negLength, const bool preciseMode, const int numThreads) { - //auto syn0 = reinterpret_cast(vsyn0); - //auto syn1 = reinterpret_cast(vsyn1); - //auto syn1Neg = reinterpret_cast(vsyn1Neg); - const auto expTable = reinterpret_cast(vexpTable); - const auto negTable = reinterpret_cast(vnegTable); - const auto infVector = reinterpret_cast(vinfVector); - - //const auto numThreads = omp_get_max_threads(); - const auto idxShift = indices.isEmpty() ? 0 : indices.sizeAt(1); - const auto hsRounds = codes.isEmpty() ? 0 : codes.sizeAt(1); - - // regular mode provides 0 guarantees for reproducibility - auto numTargets = targets.lengthOf(); - auto bTarget = targets.bufferAsT(); - auto bIndices = indices.bufferAsT(); - auto bCodes = codes.bufferAsT(); - - auto func = PRAGMA_THREADS_FOR { - T sneu1e[600]; - - for (auto t = start; t < stop; t++) { - T *neu1e = vectorLength <= 600 ? sneu1e : new T[vectorLength]; - memset(neu1e, 0, vectorLength * sizeof(T)); - - auto target = bTarget[t]; - auto alpha = lr.e(t); - unsigned long long randomValue = nextRandom.e(t); - - auto syn0row = reinterpret_cast(s0.bufferWithOffset(target * vectorLength)); - - if (hsRounds > 0) { - int irow = 0; - auto cShift = t * idxShift; - - for (Nd4jLong e = 0; e < hsRounds; e++) { - irow = bIndices[e + cShift]; - if (irow < 0 || irow >= vocabSize) - continue; - - auto syn1row = s1.bufferWithOffset(irow * vectorLength); - auto code = bCodes[e + cShift]; - - //nd4j_printf("syn0: [%i]; syn1: [%i]; code: [%i]\n", target, irow, code); - hSoftmax_(syn0row, syn1row, expTable, neu1e, alpha, vectorLength, code, - expLength, false); - } - } - - - if (nsRounds > 0) { - int irow = negStarters.e(t); - int nsStarter = irow; - for (int r = 0; r < nsRounds + 1; r++) { - if (r == 0) { - // target is known in advance - } else { - randomValue = randomValue * (unsigned long long) 25214903917 + 11; - auto idx = sd::math::nd4j_abs((randomValue >> 16) % negLength); - irow = idx >= negLength ? -1 : static_cast(negTable[idx]); - - if (irow < 0 || irow >= vocabSize) - irow = randomValue % (vocabSize - 1) + 1; - - if (irow == nsStarter) - continue; - } - - nSampling_(syn0row, s1n.bufferWithOffset(irow * vectorLength), expTable, neu1e, - alpha, vectorLength, r == 0 ? 1 : 0, expLength, infVector != nullptr); - } - } - - for (int e = 0; e < vectorLength; e++) - syn0row[e] += neu1e[e]; - - // optionally release temp arrays - if (vectorLength > 600) - delete[] neu1e; - } - }; - - samediff::Threads::parallel_tad(func, 0, numTargets, 1, numThreads); - } - BUILD_SINGLE_TEMPLATE(template void skipgramBatchExec_, (NDArray &s0, NDArray &s1, NDArray &s1n, void *vexpTable, void *vnegTable, void *vinfVector, NDArray &targets, NDArray &negStarters, NDArray &indices, NDArray &codes, NDArray &lr, NDArray &nextRandom, const int nsRounds, const int vocabSize, const int vectorLength, const int expLength, const int negLength, const bool preciseMode, const int numThreads), FLOAT_TYPES); - - - template - void cbowBatchExec_(NDArray &s0, NDArray &s1, NDArray &s1n, void *vexpTable, void *vnegTable, void *vinfVector, NDArray &context, NDArray &lockedWords, NDArray &targets, NDArray &negStarters, NDArray &indices, NDArray &codes, NDArray &lr, NDArray &nextRandom, NDArray &nLabels, const int nsRounds, const int vocabSize, const int vectorLength, const int expLength, const int negLength, const bool trainWords, const int numThreads) { - const auto syn0 = s0.bufferAsT(); - const auto syn1 = s1.bufferAsT(); - const auto syn1Neg = s1n.bufferAsT(); - - const auto expTable = reinterpret_cast(vexpTable); - const auto negTable = reinterpret_cast(vnegTable); - const auto infVector = reinterpret_cast(vinfVector); - - //const auto numThreads = omp_get_max_threads(); - const auto idxShift = indices.isEmpty() ? 0 : indices.sizeAt(1); - const auto hsRounds = codes.isEmpty() ? 0 : codes.sizeAt(1); - const auto numTargets = context.sizeAt(0); - const int contextWidth = context.sizeAt(1); - - const auto bContext = context.bufferAsT(); - const auto bLocker = lockedWords.bufferAsT(); - const auto bIndices = indices.bufferAsT(); - const auto bCodes = codes.bufferAsT(); - const auto bStarters = negStarters.bufferAsT(); - const auto numIndices = indices.isEmpty() ? 0 : indices.sizeAt(1); - - auto func = PRAGMA_THREADS_FOR { - T sneu1[600]; - T sneu1e[600]; - - for (auto e = start; e < stop; e++) { - T *neu1 = vectorLength <= 600 ? sneu1 : new T[vectorLength]; - T *neu1e = vectorLength <= 600 ? sneu1e : new T[vectorLength]; - - // optionally we nullify temp arrays after successful (and on first) cycle - memset(neu1, 0, sizeof(T) * vectorLength); - memset(neu1e, 0, sizeof(T) * vectorLength); - - auto alpha = lr.e(e); - auto numLabels = nLabels.isEmpty() ? 0 : nLabels.e(e); - - int actualContext = 0; - - // building neu1 for current window - for (int c = 0; c < contextWidth; c++) { - // getting next context word - auto cContext = bContext[c + (e * contextWidth)]; - - // skipping padded values - if (cContext < 0) - continue; - - if (cContext >= vocabSize) - throw std::runtime_error("ContextID can't be >= vocab size"); - - T *syn0word = syn0 + (cContext * vectorLength); - - for (int i = 0; i < vectorLength; i++) - neu1[i] += syn0word[i]; - - actualContext++; - } - - if (infVector != nullptr) - actualContext++; - - if (actualContext > 1) { - for (int i = 0; i < vectorLength; i++) - neu1[i] /= actualContext; - } - - // hierarchic softmax step - if (!indices.isEmpty()) { - for (Nd4jLong i = 0; i < numIndices; i++) { - const int cIndex = bIndices[(e * numIndices) + i]; - const int cCode = bCodes[(e * numIndices) + i]; - - // we're skipping padded values - if (cIndex < 0) - continue; - - if (cIndex >= vocabSize) - throw std::runtime_error("Index can't be > vocab size"); - - hSoftmax_(neu1, syn1 + (cIndex * vectorLength), expTable, neu1e, alpha, vectorLength, - cCode, expLength, false); - } - } - - // negative sampling step - if (!negStarters.isEmpty() && nsRounds > 0) { - int irow = bStarters[e]; - const int nsStarter = irow; - unsigned long long randomValue = nextRandom.e(e); - - for (int r = 0; r < nsRounds + 1; r++) { - // we're skipping rng on 0 step - if (r != 0) { - randomValue = randomValue * (unsigned long long) 25214903917 + 11; - auto idx = sd::math::nd4j_abs((randomValue >> 16) % negLength); - irow = idx >= negLength ? -1 : static_cast(negTable[idx]); - - if (irow < 0 || irow >= vocabSize) irow = randomValue % (vocabSize - 1) + 1; - if (irow == nsStarter) - continue; - - nSampling_(neu1, s1n.bufferWithOffset(irow * vectorLength), expTable, neu1e, - alpha, vectorLength, r == 0 ? 1 : 0, expLength, infVector != nullptr); - } else { - nSampling_(neu1, s1n.bufferWithOffset(irow * vectorLength), expTable, neu1e, - alpha, vectorLength, r == 0 ? 1 : 0, expLength, infVector != nullptr); - } - - //nd4j_printf("Thread <%i>: syn0: [%i]; s1n: [%i];\n", omp_get_thread_num(), 0, irow); - } - } - - - // if we're skipping labels - int starter = trainWords == 1 ? 0 : contextWidth - numLabels; - - // applying previously averaged results - for (int c = starter; c < contextWidth; c++) { - // getting context - auto cContext = bContext[c + (e * contextWidth)]; - auto cLock = bLocker[c + (e * contextWidth)]; - - // skipping padded values - if (cContext < 0 || cLock == 1) - continue; - - if (cContext >= vocabSize) - throw std::runtime_error("ContextID can't be > vocab size"); - - // one word from context - T *syn0word = syn0 + (cContext * vectorLength); - - for (int i = 0; i < vectorLength; i++) - syn0word[i] += neu1e[i]; - - } - - // optionally release temp arrays - if (vectorLength > 600) { - delete[] neu1; - delete[] neu1e; - } - } - }; - - samediff::Threads::parallel_tad(func, 0, numTargets, 1, numThreads); - } - BUILD_SINGLE_TEMPLATE(template void cbowBatchExec_, (NDArray &s0, NDArray &s1, NDArray &s1n, void *vexpTable, void *vnegTable, void *vinfVector, NDArray &context, NDArray &lockedWords, NDArray &targets, NDArray &negStarters, NDArray &indices, NDArray &codes, NDArray &lr, NDArray &nextRandom, NDArray &nLabels, const int nsRounds, const int vocabSize, const int vectorLength, const int expLength, const int negLength, const bool trainWords, const int numThreads), FLOAT_TYPES); - - void skipgram(NDArray &syn0, NDArray &syn1, NDArray &syn1Neg, NDArray &expTable, NDArray &negTable, NDArray &target, NDArray &ngStarter, int nsRounds, NDArray &indices, NDArray &codes, NDArray &alpha, NDArray &randomValue, NDArray &inferenceVector, const bool preciseMode, const int numWorkers) { - auto xType = syn0.dataType(); - - // single round case - if ((ngStarter.isScalar() && !ngStarter.isEmpty())|| (target.isScalar() && !target.isEmpty())) { - auto hsRounds = codes.lengthOf(); - - BUILD_SINGLE_SELECTOR(xType, skipgram_, (syn0.buffer(), syn1.buffer(), syn1Neg.buffer(), expTable.buffer(), negTable.buffer(), inferenceVector.buffer(), target.isEmpty() ? -1 : target.e(0), ngStarter.isEmpty() ? -1 : ngStarter.e(0), reinterpret_cast(indices.buffer()), reinterpret_cast(codes.buffer()), alpha.e(0), randomValue.e(0), hsRounds, nsRounds, (int) syn0.sizeAt(0), (int) syn0.sizeAt(1), (int) expTable.lengthOf(), (int) negTable.lengthOf()), FLOAT_TYPES); - } else if (ngStarter.isVector() || target.isVector()){ - // batch mode - - BUILD_SINGLE_SELECTOR(xType, skipgramBatchExec_, (syn0, syn1, syn1Neg, expTable.buffer(), negTable.buffer(), nullptr, target, ngStarter, indices, codes, alpha, randomValue, nsRounds, syn0.sizeAt(0), syn0.sizeAt(1), expTable.lengthOf(), negTable.lengthOf(), preciseMode, numWorkers), FLOAT_TYPES); - } else - throw std::runtime_error("SkipGram: target must have rank 0 or 1"); - } - - void cbow(NDArray &syn0, NDArray &syn1, NDArray &syn1Neg, NDArray &expTable, NDArray &negTable, NDArray &target, NDArray &ngStarter, int nsRounds, NDArray &context, NDArray &lockedWords, NDArray &indices, NDArray &codes, NDArray &alpha, NDArray &randomValue, NDArray &numLabels, NDArray &inferenceVector, const bool trainWords, int numWorkers) { - auto xType = syn0.dataType(); - - if ((context.rankOf() == 0 || context.rankOf() == 1) && (indices.rankOf() == 1 || indices.rankOf() == 0)) { - // single round case - /*nd4j_printf("Row exec; ContextWidth: %i; LockedWords: %i; numLabels: %i; Train words: %i\n", (int) context.lengthOf(), (int) lockedWords.lengthOf(), numLabels.isEmpty() ? 0 : numLabels.e(0), (int) trainWords); - if (context.lengthOf() == 2) { - context.printBuffer("context"); - lockedWords.printBuffer("locked"); - codes.printBuffer("codes"); - indices.printBuffer("indices"); - }*/ - - auto hsRounds = codes.lengthOf(); - - BUILD_SINGLE_SELECTOR(xType, cbow_, (syn0.buffer(), syn1.buffer(), syn1Neg.buffer(), expTable.buffer(), negTable.buffer(), inferenceVector.buffer(), target.isEmpty() ? -1 : target.e(0), ngStarter.isEmpty() ? -1 : ngStarter.e(0), reinterpret_cast(context.buffer()), reinterpret_cast(lockedWords.buffer()),reinterpret_cast(indices.buffer()), reinterpret_cast(codes.buffer()), alpha.e( 0), randomValue.e(0), (int) context.lengthOf(), hsRounds, nsRounds, (int) syn0.sizeAt(0), (int) syn0.sizeAt(1), (int) expTable.lengthOf(), (int) negTable.lengthOf(), numLabels.isEmpty() ? 0 : numLabels.e(0), trainWords), FLOAT_TYPES); - } else if (context.rankOf() == 2 && indices.rankOf() == 2) { - // batch mode - //nd4j_printf("Batch exec\n",""); - - BUILD_SINGLE_SELECTOR(xType, cbowBatchExec_, (syn0, syn1, syn1Neg, expTable.buffer(), negTable.buffer(), nullptr, context, lockedWords, target, ngStarter, indices, codes, alpha, randomValue, numLabels, nsRounds, syn0.sizeAt(0), syn0.sizeAt(1), expTable.lengthOf(), negTable.isEmpty() ? 0 : negTable.lengthOf(), trainWords, numWorkers), FLOAT_TYPES); - } else - throw std::runtime_error("CBOW: context must have rank 0/1 or 2"); - } +namespace ops { +namespace helpers { +template +void hSoftmax_(void *vsyn0, void *vsyn1, void *vexpTable, void *vneu1e, + double alpha, int vectorLength, int code, int expLength, + bool isInference) { + auto syn0 = reinterpret_cast(vsyn0); + auto syn1 = reinterpret_cast(vsyn1); + auto expTable = reinterpret_cast(vexpTable); + auto neu1e = reinterpret_cast(vneu1e); + + T dot(0.0f); + T g(0.0f); + T f(0.0f); + + // dot + for (int e = 0; e < vectorLength; e++) { + dot += syn0[e] * syn1[e]; + } + + // gradient + if (dot < (T)-HS_MAX_EXP || dot >= (T)HS_MAX_EXP) return; + + int idx = static_cast((dot + HS_MAX_EXP) * + ((float)expLength / HS_MAX_EXP / 2.0f)); + + if (idx >= expLength || idx < 0) return; + + f = expTable[idx]; + g = (static_cast(1.0f) - static_cast(code) - f) * (T)alpha; + + // axpy1 + + for (int e = 0; e < vectorLength; e++) { + neu1e[e] = g * syn1[e] + neu1e[e]; + } + + // axpy2 + if (!isInference) { + for (int e = 0; e < vectorLength; e++) { + syn1[e] = g * syn0[e] + syn1[e]; + } + } +} + +template +void nSampling_(void *vsyn0, void *vsyn1Neg, void *vexpTable, void *vneu1e, + double alpha, int vectorLength, int code, int expLength, + bool isInference) { + auto syn0 = reinterpret_cast(vsyn0); + auto syn1Neg = reinterpret_cast(vsyn1Neg); + auto expTable = reinterpret_cast(vexpTable); + auto neu1e = reinterpret_cast(vneu1e); + + T dot = (T)0.0f; + T g = (T)0.0f; + + for (int e = 0; e < vectorLength; e++) { + dot += syn0[e] * syn1Neg[e]; + } + + if (dot > HS_MAX_EXP) + g = (code - 1) * alpha; + else if (dot < (T)-HS_MAX_EXP) + g = (code - 0) * alpha; + else { + int idx = (int)((dot + (T)HS_MAX_EXP) * ((T)expLength / HS_MAX_EXP / 2.0)); + if (idx >= expLength) return; + + if (idx < 0) return; + + g = ((T)code - expTable[idx]) * alpha; + } + + // axpy1 + for (int e = 0; e < vectorLength; e++) { + neu1e[e] = g * syn1Neg[e] + neu1e[e]; + } + + // axpy2 + if (!isInference) { + for (int e = 0; e < vectorLength; e++) { + syn1Neg[e] = g * syn0[e] + syn1Neg[e]; + } + } +} + +template +void cbow_(void *vsyn0, void *vsyn1, void *vsyn1Neg, void *vexpTable, + void *vnegTable, void *vinfVector, int target, int ngStarter, + int *context, int *lockedWords, int *indices, int8_t *codes, + double alpha, Nd4jLong randomValue, const int contextWidth, + const int hsRounds, const int nsRounds, const int vocabSize, + const int vectorLength, const int expLength, const int negLength, + const int numLabels, const bool trainWords) { + auto syn0 = reinterpret_cast(vsyn0); + auto syn1 = reinterpret_cast(vsyn1); + auto syn1Neg = reinterpret_cast(vsyn1Neg); + auto expTable = reinterpret_cast(vexpTable); + auto negTable = reinterpret_cast(vnegTable); + auto infVector = reinterpret_cast(vinfVector); + + auto neu1 = new T[vectorLength]; + auto neu1e = new T[vectorLength]; + memset(neu1, 0, vectorLength * sizeof(T)); + memset(neu1e, 0, vectorLength * sizeof(T)); + + // building neu1 for current window + for (int c = 0; c < contextWidth; c++) { + if (context[c] >= vocabSize) throw std::runtime_error("Bad context 4"); + + T *syn0word = syn0 + (context[c] * vectorLength); + + for (int i = 0; i < vectorLength; i++) { + neu1[i] += syn0word[i]; + } + } + + // for inference we add additional inference vector + if (infVector != nullptr) { + for (int i = 0; i < vectorLength; i++) { + neu1[i] += infVector[i]; + } + } + + // average neu1 + if (contextWidth > 0) { + for (int i = 0; i < vectorLength; i++) { + neu1[i] /= contextWidth + (infVector != nullptr ? 1 : 0); + } + } + + // softmax round + if (hsRounds > 0) { + for (int i = 0; i < hsRounds; i++) { + if (indices[i] < 0 || indices[i] >= vocabSize) + throw std::runtime_error("Bad context 5"); + + hSoftmax_(neu1, syn1 + (indices[i] * vectorLength), expTable, neu1e, + alpha, vectorLength, codes[i], expLength, + infVector != nullptr); + } + } + + auto nsStarter = ngStarter; + auto irow = nsStarter; + if (nsRounds > 0) { + for (int r = 0; r < nsRounds + 1; r++) { + if (r == 0) { + // target is known in advance + } else { + randomValue = randomValue * (unsigned long long)25214903917 + 11; + auto idx = + sd::math::nd4j_abs((randomValue >> 16) % negLength); + irow = idx >= negLength ? -1 : static_cast(negTable[idx]); + + if (irow < 0 || irow >= vocabSize) + irow = randomValue % (vocabSize - 1) + 1; + if (irow == nsStarter) continue; + } + + nSampling_(neu1, syn1Neg + (irow * vectorLength), expTable, neu1e, + alpha, vectorLength, r == 0 ? 1 : 0, expLength, + infVector != nullptr); + } + } + + // if we don't train words - we skip start of idxSyn0 + int starter = trainWords == 1 ? 0 : contextWidth - numLabels; + + // propagate neu1e -> syn0 + if (infVector == nullptr) { + for (int c = starter; c < contextWidth; c++) { + if (lockedWords[c] == 1) continue; + + T *syn0word = syn0 + (context[c] * vectorLength); + + for (int i = 0; i < vectorLength; i++) { + syn0word[i] += neu1e[i]; + } + } + } else { + for (int i = 0; i < vectorLength; i++) { + infVector[i] += neu1e[i]; + } + } + + delete[] neu1; + delete[] neu1e; +} +BUILD_SINGLE_TEMPLATE(template void cbow_, + (void *syn0, void *syn1, void *syn1Neg, void *expTable, + void *vnegTable, void *vinfVector, int target, + int ngStarter, int *context, int *lockedWords, + int *indices, int8_t *codes, double alpha, + Nd4jLong randomValue, const int contextWidth, + const int hsRounds, const int nsRounds, + const int vocabSize, const int vectorLength, + const int expLength, const int negLength, + const int numLabels, const bool trainWords), + FLOAT_TYPES); + +template +void skipgram_(void *vsyn0, void *vsyn1, void *vsyn1Neg, void *vexpTable, + void *vnegTable, void *vinfVector, int target, int ngStarter, + int *indices, int8_t *codes, double alpha, Nd4jLong randomValue, + const int hsRounds, const int nsRounds, const int vocabSize, + const int vectorLength, const int expLength, + const int negLength) { + auto syn0 = reinterpret_cast(vsyn0); + auto syn1 = reinterpret_cast(vsyn1); + auto syn1Neg = reinterpret_cast(vsyn1Neg); + auto expTable = reinterpret_cast(vexpTable); + auto negTable = reinterpret_cast(vnegTable); + auto infVector = reinterpret_cast(vinfVector); + + auto neu1e = new T[vectorLength]; + memset(neu1e, 0, vectorLength * sizeof(T)); + + // hierarchic softmax goes first (if enabled) + auto syn0row = + infVector != nullptr ? infVector : syn0 + (target * vectorLength); + auto irow = 0; + if (hsRounds > 0) { + for (int r = 0; r < hsRounds; r++) { + irow = indices[r]; + if (irow < 0 || irow >= vocabSize) break; + + hSoftmax_(syn0row, syn1 + (irow * vectorLength), expTable, neu1e, + alpha, vectorLength, codes[r], expLength, + infVector != nullptr); + } + } + + // negative sampling goes second (if enabled) + auto nsStarter = ngStarter; + irow = nsStarter; + if (nsRounds > 0) { + for (int r = 0; r < nsRounds + 1; r++) { + if (r == 0) { + // target is known in advance + } else { + randomValue = randomValue * (unsigned long long)25214903917 + 11; + auto idx = + sd::math::nd4j_abs((randomValue >> 16) % negLength); + irow = idx >= negLength ? -1 : static_cast(negTable[idx]); + + if (irow < 0 || irow >= vocabSize) + irow = randomValue % (vocabSize - 1) + 1; + if (irow == nsStarter) continue; + } + + nSampling_(syn0row, syn1Neg + (irow * vectorLength), expTable, neu1e, + alpha, vectorLength, r == 0 ? 1 : 0, expLength, + infVector != nullptr); + } + } + + if (infVector == nullptr) { + for (int e = 0; e < vectorLength; e++) { + syn0row[e] += neu1e[e]; + } + } else { + for (int e = 0; e < vectorLength; e++) { + infVector[e] += neu1e[e]; + } + } + + delete[] neu1e; +} +BUILD_SINGLE_TEMPLATE(template void skipgram_, + (void *syn0, void *syn1, void *syn1Neg, void *expTable, + void *vnegTable, void *vinfVector, int target, + int ngStarter, int *indices, int8_t *codes, double alpha, + Nd4jLong randomValue, const int hsRounds, + const int nsRounds, const int vocabSize, + const int vectorLength, const int expLength, + const int negLength), + FLOAT_TYPES); + +int binarySearch(const int *haystack, const int needle, + const int totalElements) { + int firstIndex = 0; + int lastIndex = totalElements - 1; + int halfIndex = + sd::math::nd4j_floor((lastIndex + firstIndex) / (float)2); + + while (haystack[halfIndex] != needle && firstIndex < lastIndex) { + if (needle < haystack[halfIndex]) { + lastIndex = halfIndex - 1; + } else if (needle > haystack[halfIndex]) { + firstIndex = halfIndex + 1; + } + halfIndex = + sd::math::nd4j_floor((lastIndex + firstIndex) / (float)2); + } + + return (haystack[halfIndex] == needle) ? halfIndex : -1; +} + +template +static void do_update(const int target, const int rowIndex, const int count, + T *syn0, T *neu1t, const int vectorLength) { + auto syn0row = syn0 + (target * vectorLength); + auto neu1e = neu1t + (rowIndex * vectorLength); + for (int e = 0; e < vectorLength; e++) syn0row[e] += neu1e[e] / count; +} + +template +static void do_positive(const int target, const int postive, T *syn0, + T *syn1Neg, T *expTable, T *neu1e, const double alpha, + const int vectorLength, const int expLength) { + // nd4j_printf("Target: [%i]; Positive: [%i]; TID: [%i];\n", target, postive, + // omp_get_thread_num()); + nSampling_(syn0, syn1Neg, expTable, neu1e, alpha, vectorLength, 1, + expLength, false); +} + +template +static void do_negative(int target, int positive, T *syn0, T *syn1Neg, + T *expTable, T *negTable, T *neu1e, int *sStarters, + const double alpha, const unsigned long long rv, + const int vocabSize, const int vectorLength, + const int expLength, const int negLength, + const int nsRounds, const int numThreads, + const int numTargets) { + int irow = 0; + unsigned long long randomValue = rv; + for (int r = 0; r < nsRounds; r++) { + randomValue = sd::math::nd4j_abs( + randomValue * (unsigned long long)25214903917 + 11); + auto idx = sd::math::nd4j_abs((randomValue >> 16) % negLength); + irow = idx >= negLength ? -1 : static_cast(negTable[idx]); + + if (irow < 0 || irow >= vocabSize) irow = randomValue % (vocabSize - 1) + 1; + + if (irow == positive) continue; + + // we shift irow here to guarantee independence + + int dim = irow % numThreads; + if (dim != omp_get_thread_num()) { + irow += (numThreads - dim + omp_get_thread_num()); + + // roll back to nearest affilated word + while (irow >= vocabSize) irow -= numThreads; + + // if this row was processed as first step somewhere - skip it + if (binarySearch(sStarters, irow, numTargets) > 0) { + r--; + continue; + } + } + + nSampling_(syn0, syn1Neg + (irow * vectorLength), expTable, neu1e, alpha, + vectorLength, 0, expLength, false); + } +} + +template +void skipgramBatchExec_(NDArray &s0, NDArray &s1, NDArray &s1n, void *vexpTable, + void *vnegTable, void *vinfVector, NDArray &targets, + NDArray &negStarters, NDArray &indices, NDArray &codes, + NDArray &lr, NDArray &nextRandom, const int nsRounds, + const int vocabSize, const int vectorLength, + const int expLength, const int negLength, + const bool preciseMode, const int numThreads) { + // auto syn0 = reinterpret_cast(vsyn0); + // auto syn1 = reinterpret_cast(vsyn1); + // auto syn1Neg = reinterpret_cast(vsyn1Neg); + const auto expTable = reinterpret_cast(vexpTable); + const auto negTable = reinterpret_cast(vnegTable); + const auto infVector = reinterpret_cast(vinfVector); + + // const auto numThreads = omp_get_max_threads(); + const auto idxShift = indices.isEmpty() ? 0 : indices.sizeAt(1); + const auto hsRounds = codes.isEmpty() ? 0 : codes.sizeAt(1); + + // regular mode provides 0 guarantees for reproducibility + auto numTargets = targets.lengthOf(); + auto bTarget = targets.bufferAsT(); + auto bIndices = indices.bufferAsT(); + auto bCodes = codes.bufferAsT(); + + auto func = PRAGMA_THREADS_FOR { + T sneu1e[600]; + + for (auto t = start; t < stop; t++) { + T *neu1e = vectorLength <= 600 ? sneu1e : new T[vectorLength]; + memset(neu1e, 0, vectorLength * sizeof(T)); + + auto target = bTarget[t]; + auto alpha = lr.e(t); + unsigned long long randomValue = nextRandom.e(t); + + auto syn0row = + reinterpret_cast(s0.bufferWithOffset(target * vectorLength)); + + if (hsRounds > 0) { + int irow = 0; + auto cShift = t * idxShift; + + for (Nd4jLong e = 0; e < hsRounds; e++) { + irow = bIndices[e + cShift]; + if (irow < 0 || irow >= vocabSize) continue; + + auto syn1row = s1.bufferWithOffset(irow * vectorLength); + auto code = bCodes[e + cShift]; + + // nd4j_printf("syn0: [%i]; syn1: [%i]; code: [%i]\n", target, irow, + // code); + hSoftmax_(syn0row, syn1row, expTable, neu1e, alpha, vectorLength, + code, expLength, false); + } + } + + if (nsRounds > 0) { + int irow = negStarters.e(t); + int nsStarter = irow; + for (int r = 0; r < nsRounds + 1; r++) { + if (r == 0) { + // target is known in advance + } else { + randomValue = randomValue * (unsigned long long)25214903917 + 11; + auto idx = + sd::math::nd4j_abs((randomValue >> 16) % negLength); + irow = idx >= negLength ? -1 : static_cast(negTable[idx]); + + if (irow < 0 || irow >= vocabSize) + irow = randomValue % (vocabSize - 1) + 1; + + if (irow == nsStarter) continue; + } + + nSampling_(syn0row, s1n.bufferWithOffset(irow * vectorLength), + expTable, neu1e, alpha, vectorLength, r == 0 ? 1 : 0, + expLength, infVector != nullptr); + } + } + + for (int e = 0; e < vectorLength; e++) syn0row[e] += neu1e[e]; + + // optionally release temp arrays + if (vectorLength > 600) delete[] neu1e; + } + }; + + samediff::Threads::parallel_tad(func, 0, numTargets, 1, numThreads); +} +BUILD_SINGLE_TEMPLATE(template void skipgramBatchExec_, + (NDArray & s0, NDArray &s1, NDArray &s1n, void *vexpTable, + void *vnegTable, void *vinfVector, NDArray &targets, + NDArray &negStarters, NDArray &indices, NDArray &codes, + NDArray &lr, NDArray &nextRandom, const int nsRounds, + const int vocabSize, const int vectorLength, + const int expLength, const int negLength, + const bool preciseMode, const int numThreads), + FLOAT_TYPES); + +template +void cbowBatchExec_(NDArray &s0, NDArray &s1, NDArray &s1n, void *vexpTable, + void *vnegTable, void *vinfVector, NDArray &context, + NDArray &lockedWords, NDArray &targets, + NDArray &negStarters, NDArray &indices, NDArray &codes, + NDArray &lr, NDArray &nextRandom, NDArray &nLabels, + const int nsRounds, const int vocabSize, + const int vectorLength, const int expLength, + const int negLength, const bool trainWords, + const int numThreads) { + const auto syn0 = s0.bufferAsT(); + const auto syn1 = s1.bufferAsT(); + const auto syn1Neg = s1n.bufferAsT(); + + const auto expTable = reinterpret_cast(vexpTable); + const auto negTable = reinterpret_cast(vnegTable); + const auto infVector = reinterpret_cast(vinfVector); + + // const auto numThreads = omp_get_max_threads(); + const auto idxShift = indices.isEmpty() ? 0 : indices.sizeAt(1); + const auto hsRounds = codes.isEmpty() ? 0 : codes.sizeAt(1); + const auto numTargets = context.sizeAt(0); + const int contextWidth = context.sizeAt(1); + + const auto bContext = context.bufferAsT(); + const auto bLocker = lockedWords.bufferAsT(); + const auto bIndices = indices.bufferAsT(); + const auto bCodes = codes.bufferAsT(); + const auto bStarters = negStarters.bufferAsT(); + const auto numIndices = indices.isEmpty() ? 0 : indices.sizeAt(1); + + auto func = PRAGMA_THREADS_FOR { + T sneu1[600]; + T sneu1e[600]; + + for (auto e = start; e < stop; e++) { + T *neu1 = vectorLength <= 600 ? sneu1 : new T[vectorLength]; + T *neu1e = vectorLength <= 600 ? sneu1e : new T[vectorLength]; + + // optionally we nullify temp arrays after successful (and on first) cycle + memset(neu1, 0, sizeof(T) * vectorLength); + memset(neu1e, 0, sizeof(T) * vectorLength); + + auto alpha = lr.e(e); + auto numLabels = nLabels.isEmpty() ? 0 : nLabels.e(e); + + int actualContext = 0; + + // building neu1 for current window + for (int c = 0; c < contextWidth; c++) { + // getting next context word + auto cContext = bContext[c + (e * contextWidth)]; + + // skipping padded values + if (cContext < 0) continue; + + if (cContext >= vocabSize) + throw std::runtime_error("ContextID can't be >= vocab size"); + + T *syn0word = syn0 + (cContext * vectorLength); + + for (int i = 0; i < vectorLength; i++) neu1[i] += syn0word[i]; + + actualContext++; + } + + if (infVector != nullptr) actualContext++; + + if (actualContext > 1) { + for (int i = 0; i < vectorLength; i++) neu1[i] /= actualContext; + } + + // hierarchic softmax step + if (!indices.isEmpty()) { + for (Nd4jLong i = 0; i < numIndices; i++) { + const int cIndex = bIndices[(e * numIndices) + i]; + const int cCode = bCodes[(e * numIndices) + i]; + + // we're skipping padded values + if (cIndex < 0) continue; + + if (cIndex >= vocabSize) + throw std::runtime_error("Index can't be > vocab size"); + + hSoftmax_(neu1, syn1 + (cIndex * vectorLength), expTable, neu1e, + alpha, vectorLength, cCode, expLength, false); + } + } + + // negative sampling step + if (!negStarters.isEmpty() && nsRounds > 0) { + int irow = bStarters[e]; + const int nsStarter = irow; + unsigned long long randomValue = nextRandom.e(e); + + for (int r = 0; r < nsRounds + 1; r++) { + // we're skipping rng on 0 step + if (r != 0) { + randomValue = randomValue * (unsigned long long)25214903917 + 11; + auto idx = + sd::math::nd4j_abs((randomValue >> 16) % negLength); + irow = idx >= negLength ? -1 : static_cast(negTable[idx]); + + if (irow < 0 || irow >= vocabSize) + irow = randomValue % (vocabSize - 1) + 1; + if (irow == nsStarter) continue; + + nSampling_(neu1, s1n.bufferWithOffset(irow * vectorLength), + expTable, neu1e, alpha, vectorLength, r == 0 ? 1 : 0, + expLength, infVector != nullptr); + } else { + nSampling_(neu1, s1n.bufferWithOffset(irow * vectorLength), + expTable, neu1e, alpha, vectorLength, r == 0 ? 1 : 0, + expLength, infVector != nullptr); + } + + // nd4j_printf("Thread <%i>: syn0: [%i]; s1n: [%i];\n", + // omp_get_thread_num(), 0, irow); } + } + + // if we're skipping labels + int starter = trainWords == 1 ? 0 : contextWidth - numLabels; + + // applying previously averaged results + for (int c = starter; c < contextWidth; c++) { + // getting context + auto cContext = bContext[c + (e * contextWidth)]; + auto cLock = bLocker[c + (e * contextWidth)]; + + // skipping padded values + if (cContext < 0 || cLock == 1) continue; + + if (cContext >= vocabSize) + throw std::runtime_error("ContextID can't be > vocab size"); + + // one word from context + T *syn0word = syn0 + (cContext * vectorLength); + + for (int i = 0; i < vectorLength; i++) syn0word[i] += neu1e[i]; + } + + // optionally release temp arrays + if (vectorLength > 600) { + delete[] neu1; + delete[] neu1e; + } } -} \ No newline at end of file + }; + + samediff::Threads::parallel_tad(func, 0, numTargets, 1, numThreads); +} +BUILD_SINGLE_TEMPLATE( + template void cbowBatchExec_, + (NDArray & s0, NDArray &s1, NDArray &s1n, void *vexpTable, void *vnegTable, + void *vinfVector, NDArray &context, NDArray &lockedWords, NDArray &targets, + NDArray &negStarters, NDArray &indices, NDArray &codes, NDArray &lr, + NDArray &nextRandom, NDArray &nLabels, const int nsRounds, + const int vocabSize, const int vectorLength, const int expLength, + const int negLength, const bool trainWords, const int numThreads), + FLOAT_TYPES); + +void skipgram(NDArray &syn0, NDArray &syn1, NDArray &syn1Neg, NDArray &expTable, + NDArray &negTable, NDArray &target, NDArray &ngStarter, + int nsRounds, NDArray &indices, NDArray &codes, NDArray &alpha, + NDArray &randomValue, NDArray &inferenceVector, + const bool preciseMode, const int numWorkers) { + auto xType = syn0.dataType(); + + // single round case + if ((ngStarter.isScalar() && !ngStarter.isEmpty()) || + (target.isScalar() && !target.isEmpty())) { + auto hsRounds = codes.lengthOf(); + + BUILD_SINGLE_SELECTOR( + xType, skipgram_, + (syn0.buffer(), syn1.buffer(), syn1Neg.buffer(), expTable.buffer(), + negTable.buffer(), inferenceVector.buffer(), + target.isEmpty() ? -1 : target.e(0), + ngStarter.isEmpty() ? -1 : ngStarter.e(0), + reinterpret_cast(indices.buffer()), + reinterpret_cast(codes.buffer()), alpha.e(0), + randomValue.e(0), hsRounds, nsRounds, (int)syn0.sizeAt(0), + (int)syn0.sizeAt(1), (int)expTable.lengthOf(), + (int)negTable.lengthOf()), + FLOAT_TYPES); + } else if (ngStarter.isVector() || target.isVector()) { + // batch mode + + BUILD_SINGLE_SELECTOR( + xType, skipgramBatchExec_, + (syn0, syn1, syn1Neg, expTable.buffer(), negTable.buffer(), nullptr, + target, ngStarter, indices, codes, alpha, randomValue, nsRounds, + syn0.sizeAt(0), syn0.sizeAt(1), expTable.lengthOf(), + negTable.lengthOf(), preciseMode, numWorkers), + FLOAT_TYPES); + } else + throw std::runtime_error("SkipGram: target must have rank 0 or 1"); +} + +void cbow(NDArray &syn0, NDArray &syn1, NDArray &syn1Neg, NDArray &expTable, + NDArray &negTable, NDArray &target, NDArray &ngStarter, int nsRounds, + NDArray &context, NDArray &lockedWords, NDArray &indices, + NDArray &codes, NDArray &alpha, NDArray &randomValue, + NDArray &numLabels, NDArray &inferenceVector, const bool trainWords, + int numWorkers) { + auto xType = syn0.dataType(); + + if ((context.rankOf() == 0 || context.rankOf() == 1) && + (indices.rankOf() == 1 || indices.rankOf() == 0)) { + // single round case + /*nd4j_printf("Row exec; ContextWidth: %i; LockedWords: %i; numLabels: %i; + Train words: %i\n", (int) context.lengthOf(), (int) lockedWords.lengthOf(), + numLabels.isEmpty() ? 0 : numLabels.e(0), (int) trainWords); if + (context.lengthOf() == 2) { context.printBuffer("context"); + lockedWords.printBuffer("locked"); + codes.printBuffer("codes"); + indices.printBuffer("indices"); + }*/ + + auto hsRounds = codes.lengthOf(); + + BUILD_SINGLE_SELECTOR( + xType, cbow_, + (syn0.buffer(), syn1.buffer(), syn1Neg.buffer(), expTable.buffer(), + negTable.buffer(), inferenceVector.buffer(), + target.isEmpty() ? -1 : target.e(0), + ngStarter.isEmpty() ? -1 : ngStarter.e(0), + reinterpret_cast(context.buffer()), + reinterpret_cast(lockedWords.buffer()), + reinterpret_cast(indices.buffer()), + reinterpret_cast(codes.buffer()), alpha.e(0), + randomValue.e(0), (int)context.lengthOf(), hsRounds, + nsRounds, (int)syn0.sizeAt(0), (int)syn0.sizeAt(1), + (int)expTable.lengthOf(), (int)negTable.lengthOf(), + numLabels.isEmpty() ? 0 : numLabels.e(0), trainWords), + FLOAT_TYPES); + } else if (context.rankOf() == 2 && indices.rankOf() == 2) { + // batch mode + // nd4j_printf("Batch exec\n",""); + + BUILD_SINGLE_SELECTOR( + xType, cbowBatchExec_, + (syn0, syn1, syn1Neg, expTable.buffer(), negTable.buffer(), nullptr, + context, lockedWords, target, ngStarter, indices, codes, alpha, + randomValue, numLabels, nsRounds, syn0.sizeAt(0), syn0.sizeAt(1), + expTable.lengthOf(), negTable.isEmpty() ? 0 : negTable.lengthOf(), + trainWords, numWorkers), + FLOAT_TYPES); + } else + throw std::runtime_error("CBOW: context must have rank 0/1 or 2"); +} +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/shift.cpp b/libnd4j/include/ops/declarable/helpers/cpu/shift.cpp index 9dfeac2ec16e..387bcf54a70b 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/shift.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/shift.cpp @@ -21,61 +21,65 @@ #include namespace sd { - namespace ops { - namespace helpers { - template - void rshift_bits_(LaunchContext* launchContext, NDArray &input, NDArray &output, uint32_t shift) { - auto lambda = LAMBDA_T(x, shift) { - return x >> shift; - }; +namespace ops { +namespace helpers { +template +void rshift_bits_(LaunchContext *launchContext, NDArray &input, NDArray &output, + uint32_t shift) { + auto lambda = LAMBDA_T(x, shift) { return x >> shift; }; - input.applyLambda(lambda, output); - } + input.applyLambda(lambda, output); +} - void rshift_bits(LaunchContext* launchContext, NDArray &x, NDArray &z, uint32_t shift) { - BUILD_SINGLE_SELECTOR(x.dataType(), rshift_bits_, (launchContext, x, z, shift), INTEGER_TYPES); - } +void rshift_bits(LaunchContext *launchContext, NDArray &x, NDArray &z, + uint32_t shift) { + BUILD_SINGLE_SELECTOR(x.dataType(), rshift_bits_, + (launchContext, x, z, shift), INTEGER_TYPES); +} - template - void shift_bits_(LaunchContext* launchContext, NDArray &input, NDArray &output, uint32_t shift) { - auto lambda = LAMBDA_T(x, shift) { - return x << shift; - }; +template +void shift_bits_(LaunchContext *launchContext, NDArray &input, NDArray &output, + uint32_t shift) { + auto lambda = LAMBDA_T(x, shift) { return x << shift; }; - input.applyLambda(lambda, output); - } + input.applyLambda(lambda, output); +} - void shift_bits(LaunchContext* launchContext, NDArray &x, NDArray &z, uint32_t shift) { - BUILD_SINGLE_SELECTOR(x.dataType(), shift_bits_, (launchContext, x, z, shift), INTEGER_TYPES); - } +void shift_bits(LaunchContext *launchContext, NDArray &x, NDArray &z, + uint32_t shift) { + BUILD_SINGLE_SELECTOR(x.dataType(), shift_bits_, (launchContext, x, z, shift), + INTEGER_TYPES); +} - template - void cyclic_rshift_bits_(LaunchContext* launchContext, NDArray &input, NDArray &output, uint32_t shift) { - auto step = (sizeof(T) * 8) - shift; - auto lambda = LAMBDA_T(x, shift, step) { - return x >> shift | x << step; - }; +template +void cyclic_rshift_bits_(LaunchContext *launchContext, NDArray &input, + NDArray &output, uint32_t shift) { + auto step = (sizeof(T) * 8) - shift; + auto lambda = LAMBDA_T(x, shift, step) { return x >> shift | x << step; }; - input.applyLambda(lambda, output); - } + input.applyLambda(lambda, output); +} - void cyclic_rshift_bits(LaunchContext* launchContext, NDArray &x, NDArray &z, uint32_t shift) { - BUILD_SINGLE_SELECTOR(x.dataType(), cyclic_rshift_bits_, (launchContext, x, z, shift), INTEGER_TYPES); - } +void cyclic_rshift_bits(LaunchContext *launchContext, NDArray &x, NDArray &z, + uint32_t shift) { + BUILD_SINGLE_SELECTOR(x.dataType(), cyclic_rshift_bits_, + (launchContext, x, z, shift), INTEGER_TYPES); +} - template - void cyclic_shift_bits_(LaunchContext* launchContext, NDArray &input, NDArray &output, uint32_t shift) { - auto step = (sizeof(T) * 8) - shift; - auto lambda = LAMBDA_T(x, shift, step) { - return x << shift | x >> step; - }; +template +void cyclic_shift_bits_(LaunchContext *launchContext, NDArray &input, + NDArray &output, uint32_t shift) { + auto step = (sizeof(T) * 8) - shift; + auto lambda = LAMBDA_T(x, shift, step) { return x << shift | x >> step; }; - input.applyLambda(lambda, output); - } + input.applyLambda(lambda, output); +} - void cyclic_shift_bits(LaunchContext* launchContext, NDArray &x, NDArray &z, uint32_t shift) { - BUILD_SINGLE_SELECTOR(x.dataType(), cyclic_shift_bits_, (launchContext, x, z, shift), INTEGER_TYPES); - } - } - } -} \ No newline at end of file +void cyclic_shift_bits(LaunchContext *launchContext, NDArray &x, NDArray &z, + uint32_t shift) { + BUILD_SINGLE_SELECTOR(x.dataType(), cyclic_shift_bits_, + (launchContext, x, z, shift), INTEGER_TYPES); +} +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/softmax.cpp b/libnd4j/include/ops/declarable/helpers/cpu/softmax.cpp index 7fd03f8e4d8c..e8c556153e9a 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/softmax.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/softmax.cpp @@ -19,237 +19,239 @@ // @author raver119@gmail.com // -#include +#include +#include #include +#include + #include -#include -#include namespace sd { - namespace ops { - namespace helpers { - - template - static void softMaxForVector_(void const* input, Nd4jLong const* inShapeInfo, void *output, Nd4jLong const* outShapeInfo) { - - auto inBuff = reinterpret_cast(input); - auto outBuff = reinterpret_cast(output); - - T max = -DataTypeUtils::max(); - T sum = 0.; - int inEWS = shape::elementWiseStride(inShapeInfo); - int outEWS = shape::elementWiseStride(outShapeInfo); - int length = shape::length(inShapeInfo); - - if (inEWS >= 1 && outEWS >= 1) { - - if (inEWS == 1 && outEWS == 1) { - - for (int i = 0; i < length; i++) - max = sd::math::nd4j_max(max, inBuff[i]); - - for (int i = 0; i < length; i++) { - outBuff[i] = sd::math::nd4j_exp(inBuff[i] - max); - sum += outBuff[i]; - } - - for (int i = 0; i < length; i++) - outBuff[i] /= sum; - } - else { - - for (int i = 0; i < length; i++) - max = sd::math::nd4j_max(max, inBuff[i * inEWS]); - - for (int i = 0; i < length; i++) { - T r = sd::math::nd4j_exp(inBuff[i * inEWS] - max); - outBuff[i * outEWS] = r; - sum += r; - } - - for (int i = 0; i < length; i++) - outBuff[i * outEWS] /= sum; - } - } - } - - /////////////////////////////////////////////////////////////////// - void softMaxForVector(sd::LaunchContext * context, const NDArray& input, NDArray& output) { - - if(!input.isVector() || !output.isVector()) - throw std::runtime_error("ops::helpers::softMaxForVector function: input and output arrays must be vectors !"); - - auto xType = input.dataType(); - BUILD_SINGLE_SELECTOR(xType, softMaxForVector_, (input.buffer(), input.shapeInfo(), output.buffer(), output.shapeInfo()), FLOAT_TYPES); - } - - template - void softmax_loop(const T* input, T *output, const Nd4jLong * offsets, Nd4jLong numOfSubArrs, uint32_t tadLen); +namespace ops { +namespace helpers { + +template +static void softMaxForVector_(void const* input, Nd4jLong const* inShapeInfo, + void* output, Nd4jLong const* outShapeInfo) { + auto inBuff = reinterpret_cast(input); + auto outBuff = reinterpret_cast(output); + + T max = -DataTypeUtils::max(); + T sum = 0.; + int inEWS = shape::elementWiseStride(inShapeInfo); + int outEWS = shape::elementWiseStride(outShapeInfo); + int length = shape::length(inShapeInfo); + + if (inEWS >= 1 && outEWS >= 1) { + if (inEWS == 1 && outEWS == 1) { + for (int i = 0; i < length; i++) + max = sd::math::nd4j_max(max, inBuff[i]); + + for (int i = 0; i < length; i++) { + outBuff[i] = sd::math::nd4j_exp(inBuff[i] - max); + sum += outBuff[i]; + } + + for (int i = 0; i < length; i++) outBuff[i] /= sum; + } else { + for (int i = 0; i < length; i++) + max = sd::math::nd4j_max(max, inBuff[i * inEWS]); + + for (int i = 0; i < length; i++) { + T r = sd::math::nd4j_exp(inBuff[i * inEWS] - max); + outBuff[i * outEWS] = r; + sum += r; + } + + for (int i = 0; i < length; i++) outBuff[i * outEWS] /= sum; + } + } +} + +/////////////////////////////////////////////////////////////////// +void softMaxForVector(sd::LaunchContext* context, const NDArray& input, + NDArray& output) { + if (!input.isVector() || !output.isVector()) + throw std::runtime_error( + "ops::helpers::softMaxForVector function: input and output arrays must " + "be vectors !"); + + auto xType = input.dataType(); + BUILD_SINGLE_SELECTOR( + xType, softMaxForVector_, + (input.buffer(), input.shapeInfo(), output.buffer(), output.shapeInfo()), + FLOAT_TYPES); +} + +template +void softmax_loop(const T* input, T* output, const Nd4jLong* offsets, + Nd4jLong numOfSubArrs, uint32_t tadLen); #ifdef _OPENMP - template <> - FORCEINLINE void softmax_loop(const float* input, float *output, const Nd4jLong * offsets, Nd4jLong numOfSubArrs, uint32_t tadLen) { +template <> +FORCEINLINE void softmax_loop(const float* input, float* output, + const Nd4jLong* offsets, Nd4jLong numOfSubArrs, + uint32_t tadLen) { #pragma omp parallel for default(shared) - for (Nd4jLong i = 0; i < numOfSubArrs; i++) { - auto inBuff = input + offsets[i]; - auto outBuff = output + offsets[i]; - - float max = -DataTypeUtils::max(); - float sum = 0.f; - -#pragma omp simd reduction(max:max) - for (uint j = 0; j < tadLen; ++j) - max = sd::math::nd4j_max(max, inBuff[j]); - -#pragma omp simd reduction(+:sum) - for (uint j = 0; j < tadLen; ++j) { - float temp = sd::math::nd4j_exp(inBuff[j] - max); - outBuff[j] = temp; - sum += temp; - } - - for (uint j = 0; j < tadLen; ++j) - outBuff[j] /= sum; - } - } -#else - template <> - FORCEINLINE void softmax_loop(const float *input, float *output, const Nd4jLong *offsets, Nd4jLong numOfSubArrs, uint32_t tadLen) { - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - auto inBuff = input + offsets[i]; - auto outBuff = output + offsets[i]; - - float max = -DataTypeUtils::max(); - float sum = 0.f; - - for (uint j = 0; j < tadLen; ++j) - max = sd::math::nd4j_max(max, inBuff[j]); - - for (uint j = 0; j < tadLen; ++j) { - float temp = sd::math::nd4j_exp(inBuff[j] - max); - outBuff[j] = temp; - sum += temp; - } + for (Nd4jLong i = 0; i < numOfSubArrs; i++) { + auto inBuff = input + offsets[i]; + auto outBuff = output + offsets[i]; + + float max = -DataTypeUtils::max(); + float sum = 0.f; + +#pragma omp simd reduction(max : max) + for (uint j = 0; j < tadLen; ++j) + max = sd::math::nd4j_max(max, inBuff[j]); + +#pragma omp simd reduction(+ : sum) + for (uint j = 0; j < tadLen; ++j) { + float temp = sd::math::nd4j_exp(inBuff[j] - max); + outBuff[j] = temp; + sum += temp; + } - for (uint j = 0; j < tadLen; ++j) - outBuff[j] /= sum; - } - }; + for (uint j = 0; j < tadLen; ++j) outBuff[j] /= sum; + } +} +#else +template <> +FORCEINLINE void softmax_loop(const float* input, float* output, + const Nd4jLong* offsets, Nd4jLong numOfSubArrs, + uint32_t tadLen) { + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + auto inBuff = input + offsets[i]; + auto outBuff = output + offsets[i]; + + float max = -DataTypeUtils::max(); + float sum = 0.f; + + for (uint j = 0; j < tadLen; ++j) + max = sd::math::nd4j_max(max, inBuff[j]); + + for (uint j = 0; j < tadLen; ++j) { + float temp = sd::math::nd4j_exp(inBuff[j] - max); + outBuff[j] = temp; + sum += temp; + } + + for (uint j = 0; j < tadLen; ++j) outBuff[j] /= sum; + } + }; - samediff::Threads::parallel_tad(func,0, numOfSubArrs); - } + samediff::Threads::parallel_tad(func, 0, numOfSubArrs); +} #endif +template +FORCEINLINE void softmax_loop(const T* input, T* output, + const Nd4jLong* offsets, Nd4jLong numOfSubArrs, + uint32_t tadLen) { + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + auto inBuff = input + offsets[i]; + auto outBuff = output + offsets[i]; + + T max = -DataTypeUtils::max(); + T sum(0.f); + +#pragma omp simd reduction(maxT : max) + for (uint j = 0; j < tadLen; ++j) + max = sd::math::nd4j_max(max, inBuff[j]); + +#pragma omp simd reduction(sumT : sum) + for (uint j = 0; j < tadLen; ++j) { + T temp = sd::math::nd4j_exp(inBuff[j] - max); + outBuff[j] = temp; + sum += temp; + } + + for (uint j = 0; j < tadLen; ++j) outBuff[j] /= sum; + } + }; - template - FORCEINLINE void softmax_loop(const T *input, T *output, const Nd4jLong *offsets, Nd4jLong numOfSubArrs, uint32_t tadLen) { - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - auto inBuff = input + offsets[i]; - auto outBuff = output + offsets[i]; - - T max = -DataTypeUtils::max(); - T sum(0.f); - -#pragma omp simd reduction(maxT:max) - for (uint j = 0; j < tadLen; ++j) - max = sd::math::nd4j_max(max, inBuff[j]); - -#pragma omp simd reduction(sumT:sum) - for (uint j = 0; j < tadLen; ++j) { - T temp = sd::math::nd4j_exp(inBuff[j] - max); - outBuff[j] = temp; - sum += temp; - } - - for (uint j = 0; j < tadLen; ++j) - outBuff[j] /= sum; - } - }; - - samediff::Threads::parallel_tad(func,0, numOfSubArrs); - } + samediff::Threads::parallel_tad(func, 0, numOfSubArrs); +} ////////////////////////////////////////////////////////////////////////// - template - static void softmax_(sd::LaunchContext * context, const NDArray& input, NDArray& output, const int dimension) { - - const int rank = input.rankOf(); - - if(input.isVector()) { - - if(rank == 1 || input.sizeAt(dimension) != 1) - softMaxForVector_(input.buffer(), input.shapeInfo(), output.buffer(), output.shapeInfo()); - else - output = 1.; - } - else if(input.isSameShapeStrict(output)) { - - TadPack tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(input.shapeInfo(), dimension); - auto tadShapeInfo = tadPack.primaryShapeInfo(); - auto tadOffsets = tadPack.primaryOffsets(); - const uint numOfSubArrs = tadPack.numberOfTads(); - const uint tadLen = shape::length(tadShapeInfo); - - if(shape::elementWiseStride(tadShapeInfo) == 1){ - auto inBuff = input.bufferAsT(); - T *outBuff = output.bufferAsT(); - - softmax_loop(inBuff, outBuff, tadOffsets, numOfSubArrs, tadLen); - } - else { - - uint inShapeInfoCast[MAX_RANK]; - bool canCast = sd::DataTypeUtils::castShapeInfo(tadShapeInfo, inShapeInfoCast); - - auto offsets = new Nd4jLong[tadLen]; - shape::calcOffsets(tadShapeInfo, offsets); - - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - auto inBuff = input.bufferAsT() + tadOffsets[i]; - auto outBuff = output.bufferAsT() + tadOffsets[i]; - - T max = -DataTypeUtils::max(); - T sum = 0.f; - - for (uint j = 0; j < tadLen; ++j) - max = sd::math::nd4j_max(max, inBuff[offsets[j]]); - - for (uint j = 0; j < tadLen; ++j) { - T temp = sd::math::nd4j_exp(inBuff[offsets[j]] - max); - outBuff[offsets[j]] = temp; - sum += temp; - } - - for (uint j = 0; j < tadLen; ++j) - outBuff[offsets[j]] /= sum; - } - }; - - samediff::Threads::parallel_tad(func, 0, numOfSubArrs); - - delete []offsets; - } - } - else { - NDArray max = input.reduceAlongDimension(sd::reduce::Max, {dimension}, true); - input.applyTrueBroadcast(sd::BroadcastOpsTuple::Subtract(), max, output, false); - output.applyTransform(sd::transform::Exp, output); - NDArray sum = output.reduceAlongDimension(sd::reduce::Sum, {dimension}, true); - output /= sum; - } - } - - - /////////////////////////////////////////////////////////////////// - void softmax(sd::LaunchContext * context, const NDArray& input, NDArray& output, const int dimension) { +template +static void softmax_(sd::LaunchContext* context, const NDArray& input, + NDArray& output, const int dimension) { + const int rank = input.rankOf(); + + if (input.isVector()) { + if (rank == 1 || input.sizeAt(dimension) != 1) + softMaxForVector_(input.buffer(), input.shapeInfo(), output.buffer(), + output.shapeInfo()); + else + output = 1.; + } else if (input.isSameShapeStrict(output)) { + TadPack tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions( + input.shapeInfo(), dimension); + auto tadShapeInfo = tadPack.primaryShapeInfo(); + auto tadOffsets = tadPack.primaryOffsets(); + const uint numOfSubArrs = tadPack.numberOfTads(); + const uint tadLen = shape::length(tadShapeInfo); + + if (shape::elementWiseStride(tadShapeInfo) == 1) { + auto inBuff = input.bufferAsT(); + T* outBuff = output.bufferAsT(); + + softmax_loop(inBuff, outBuff, tadOffsets, numOfSubArrs, tadLen); + } else { + uint inShapeInfoCast[MAX_RANK]; + bool canCast = + sd::DataTypeUtils::castShapeInfo(tadShapeInfo, inShapeInfoCast); + + auto offsets = new Nd4jLong[tadLen]; + shape::calcOffsets(tadShapeInfo, offsets); + + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + auto inBuff = input.bufferAsT() + tadOffsets[i]; + auto outBuff = output.bufferAsT() + tadOffsets[i]; + + T max = -DataTypeUtils::max(); + T sum = 0.f; + + for (uint j = 0; j < tadLen; ++j) + max = sd::math::nd4j_max(max, inBuff[offsets[j]]); + + for (uint j = 0; j < tadLen; ++j) { + T temp = sd::math::nd4j_exp(inBuff[offsets[j]] - max); + outBuff[offsets[j]] = temp; + sum += temp; + } + + for (uint j = 0; j < tadLen; ++j) outBuff[offsets[j]] /= sum; + } + }; - BUILD_SINGLE_SELECTOR(input.dataType(), softmax_, (context, input, output, dimension), FLOAT_TYPES); - } + samediff::Threads::parallel_tad(func, 0, numOfSubArrs); - } + delete[] offsets; } -} \ No newline at end of file + } else { + NDArray max = + input.reduceAlongDimension(sd::reduce::Max, {dimension}, true); + input.applyTrueBroadcast(sd::BroadcastOpsTuple::Subtract(), max, output, + false); + output.applyTransform(sd::transform::Exp, output); + NDArray sum = + output.reduceAlongDimension(sd::reduce::Sum, {dimension}, true); + output /= sum; + } +} + +/////////////////////////////////////////////////////////////////// +void softmax(sd::LaunchContext* context, const NDArray& input, NDArray& output, + const int dimension) { + BUILD_SINGLE_SELECTOR(input.dataType(), softmax_, + (context, input, output, dimension), FLOAT_TYPES); +} + +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/solve.cpp b/libnd4j/include/ops/declarable/helpers/cpu/solve.cpp index a0034bb5dcdd..730129e41d99 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/solve.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/solve.cpp @@ -17,33 +17,36 @@ // // @author GS // -#include +#include "../solve.h" + #include #include #include #include +#include -#include "../triangular_solve.h" #include "../lup.h" -#include "../solve.h" +#include "../triangular_solve.h" namespace sd { namespace ops { namespace helpers { -// --------------------------------------------------------------------------------------------------------------------------------------- // - template - static void adjointMatrix_(sd::LaunchContext* context, NDArray const* input, NDArray* output) { - auto inputPart = input->allTensorsAlongDimension({-2, -1}); - auto outputPart = output->allTensorsAlongDimension({-2, -1}); - auto rows = input->sizeAt(-2); - output->assign(input); +// --------------------------------------------------------------------------------------------------------------------------------------- +// // +template +static void adjointMatrix_(sd::LaunchContext* context, NDArray const* input, + NDArray* output) { + auto inputPart = input->allTensorsAlongDimension({-2, -1}); + auto outputPart = output->allTensorsAlongDimension({-2, -1}); + auto rows = input->sizeAt(-2); + output->assign(input); auto batchLoop = PRAGMA_THREADS_FOR { for (auto batch = start; batch < stop; batch++) { for (Nd4jLong r = 0; r < rows; r++) { for (Nd4jLong c = 0; c < r; c++) { - math::nd4j_swap(outputPart[batch]->r(r, c) , outputPart[batch]->r(c, r)); + math::nd4j_swap(outputPart[batch].r(r, c) , outputPart[batch].r(c, r)); } } } @@ -51,23 +54,26 @@ namespace helpers { samediff::Threads::parallel_tad(batchLoop, 0, inputPart.size(), 1); } -// --------------------------------------------------------------------------------------------------------------------------------------- // - template - static int solveFunctor_(sd::LaunchContext * context, NDArray* leftInput, NDArray* rightInput, bool const adjoint, NDArray* output) { - - // stage 1: LU decomposition batched - auto leftOutput = leftInput->ulike(); - auto permuShape = rightInput->getShapeAsVector(); permuShape.pop_back(); - auto permutations = NDArrayFactory::create('c', permuShape, context); - helpers::lu(context, leftInput, &leftOutput, &permutations); - auto P = leftInput->ulike(); //permutations batched matrix - P.nullify(); // to fill up matricies with zeros - auto PPart = P.allTensorsAlongDimension({-2,-1}); - auto permutationsPart = permutations.allTensorsAlongDimension({-1}); +// --------------------------------------------------------------------------------------------------------------------------------------- +// // +template +static int solveFunctor_(sd::LaunchContext* context, NDArray* leftInput, + NDArray* rightInput, bool const adjoint, + NDArray* output) { + // stage 1: LU decomposition batched + auto leftOutput = leftInput->ulike(); + auto permuShape = rightInput->getShapeAsVector(); + permuShape.pop_back(); + auto permutations = NDArrayFactory::create('c', permuShape, context); + helpers::lu(context, leftInput, &leftOutput, &permutations); + auto P = leftInput->ulike(); // permutations batched matrix + P.nullify(); // to fill up matricies with zeros + auto PPart = P.allTensorsAlongDimension({-2, -1}); + auto permutationsPart = permutations.allTensorsAlongDimension({-1}); for (auto batch = 0; batch < permutationsPart.size(); ++batch) { - for (Nd4jLong row = 0; row < PPart[batch]->rows(); ++row) { - PPart[batch]->r(row, permutationsPart[batch]->t(row)) = T(1.f); + for (Nd4jLong row = 0; row < PPart[batch].rows(); ++row) { + PPart[batch].r(row, permutationsPart[batch].t(row)) = T(1.f); } } @@ -77,26 +83,34 @@ namespace helpers { MmulHelper::matmul(&P, rightInput, &rightPermuted, 0, 0); ResultSet leftLowerPart = leftLower.allTensorsAlongDimension({-2, -1}); for (auto i = 0; i < leftLowerPart.size(); i++) { - for (Nd4jLong r = 0; r < leftLowerPart[i]->rows(); r++) - leftLowerPart[i]->r(r,r) = (T)1.f; + for (Nd4jLong r = 0; r < leftLowerPart[i].rows(); r++) + leftLowerPart[i].r(r,r) = (T)1.f; } // stage 2: triangularSolveFunctor for Lower with given b helpers::triangularSolveFunctor(context, &leftLower, &rightPermuted, true, false, &rightOutput); // stage 3: triangularSolveFunctor for Upper with output of previous stage helpers::triangularSolveFunctor(context, &leftOutput, &rightOutput, false, false, output); - return Status::OK(); - } - -// --------------------------------------------------------------------------------------------------------------------------------------- // - int solveFunctor(sd::LaunchContext * context, NDArray* leftInput, NDArray* rightInput, bool const adjoint, NDArray* output) { - BUILD_SINGLE_SELECTOR(leftInput->dataType(), return solveFunctor_, (context, leftInput, rightInput, adjoint, output), FLOAT_TYPES); - } -// --------------------------------------------------------------------------------------------------------------------------------------- // - void adjointMatrix(sd::LaunchContext* context, NDArray const* input, NDArray* output) { - BUILD_SINGLE_SELECTOR(input->dataType(), adjointMatrix_, (context, input, output), FLOAT_TYPES); - } -// --------------------------------------------------------------------------------------------------------------------------------------- // + return Status::OK(); } + +// --------------------------------------------------------------------------------------------------------------------------------------- +// // +int solveFunctor(sd::LaunchContext* context, NDArray* leftInput, + NDArray* rightInput, bool const adjoint, NDArray* output) { + BUILD_SINGLE_SELECTOR(leftInput->dataType(), return solveFunctor_, + (context, leftInput, rightInput, adjoint, output), + FLOAT_TYPES); } +// --------------------------------------------------------------------------------------------------------------------------------------- +// // +void adjointMatrix(sd::LaunchContext* context, NDArray const* input, + NDArray* output) { + BUILD_SINGLE_SELECTOR(input->dataType(), adjointMatrix_, + (context, input, output), FLOAT_TYPES); } +// --------------------------------------------------------------------------------------------------------------------------------------- +// // +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/split.cpp b/libnd4j/include/ops/declarable/helpers/cpu/split.cpp index 48c6c490318d..fd70c0462444 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/split.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/split.cpp @@ -14,118 +14,117 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - // - // @author Oleh Semeniv (oleg.semeniv@gmail.com) - // +// +// @author Oleh Semeniv (oleg.semeniv@gmail.com) +// -#include #include +#include namespace sd { namespace ops { namespace helpers { +////////////////////////////////////////////////////////////////////////// +template +static void split_(const NDArray& input, const std::vector& outArrs, + const int axis) { + uint numSplits = outArrs.size(); - ////////////////////////////////////////////////////////////////////////// - template - static void split_(const NDArray& input, const std::vector& outArrs, const int axis) { - uint numSplits = outArrs.size(); - - const auto sizeofT = input.sizeOfT(); - - auto xBuff = input.bufferAsT(); - - bool luckCase1 = ((axis == 0 && input.ordering() == 'c') || (axis == input.rankOf() - 1 && input.ordering() == 'f')) && input.ews() == 1; - - if (luckCase1) { - for (uint i = 0; i < numSplits; ++i) { - luckCase1 &= outArrs[i]->ordering() == input.ordering() && outArrs[i]->ews() == 1; - if (!luckCase1) - break; - } - } - - if (luckCase1) { - - T* x = const_cast(xBuff); - for (uint i = 0; i < numSplits; ++i) { - const auto memAmountToCopy = outArrs[i]->lengthOf(); - memcpy(outArrs[i]->bufferAsT(), x, memAmountToCopy * sizeofT); - x += memAmountToCopy; - } - return; - } - - const bool isXcontin = input.strideAt(axis) == 1 && input.ordering() == 'c'; - bool areOutsContin = true; - bool allSameOrder = true; - - if (isXcontin) { - for (uint i = 0; i < numSplits; ++i) { - areOutsContin &= outArrs[i]->strideAt(axis) == 1; - allSameOrder &= outArrs[i]->ordering() == input.ordering(); - if (!areOutsContin || !allSameOrder) - break; - } - } - - const bool luckCase2 = isXcontin && areOutsContin && allSameOrder; + const auto sizeofT = input.sizeOfT(); - if (luckCase2) { + auto xBuff = input.bufferAsT(); - const auto xDim = input.sizeAt(axis); + bool luckCase1 = ((axis == 0 && input.ordering() == 'c') || + (axis == input.rankOf() - 1 && input.ordering() == 'f')) && + input.ews() == 1; - for (Nd4jLong i = 0; i < input.lengthOf() / xDim; ++i) { + if (luckCase1) { + for (uint i = 0; i < numSplits; ++i) { + luckCase1 &= + outArrs[i]->ordering() == input.ordering() && outArrs[i]->ews() == 1; + if (!luckCase1) break; + } + } + + if (luckCase1) { + T* x = const_cast(xBuff); + for (uint i = 0; i < numSplits; ++i) { + const auto memAmountToCopy = outArrs[i]->lengthOf(); + memcpy(outArrs[i]->bufferAsT(), x, memAmountToCopy * sizeofT); + x += memAmountToCopy; + } + return; + } + + const bool isXcontin = input.strideAt(axis) == 1 && input.ordering() == 'c'; + bool areOutsContin = true; + bool allSameOrder = true; + + if (isXcontin) { + for (uint i = 0; i < numSplits; ++i) { + areOutsContin &= outArrs[i]->strideAt(axis) == 1; + allSameOrder &= outArrs[i]->ordering() == input.ordering(); + if (!areOutsContin || !allSameOrder) break; + } + } - auto x = xBuff + xDim * i; + const bool luckCase2 = isXcontin && areOutsContin && allSameOrder; - for (uint j = 0; j < numSplits; ++j) { - const auto zDim = outArrs[j]->sizeAt(axis); - T* z = outArrs[j]->bufferAsT() + zDim * i; - memcpy(z, x, zDim * sizeofT); - z += zDim; - x += zDim; - } - } + if (luckCase2) { + const auto xDim = input.sizeAt(axis); - return; - } + for (Nd4jLong i = 0; i < input.lengthOf() / xDim; ++i) { + auto x = xBuff + xDim * i; - uint zDim = outArrs[0]->sizeAt(axis); - // general case + for (uint j = 0; j < numSplits; ++j) { + const auto zDim = outArrs[j]->sizeAt(axis); + T* z = outArrs[j]->bufferAsT() + zDim * i; + memcpy(z, x, zDim * sizeofT); + z += zDim; + x += zDim; + } + } - auto func = PRAGMA_THREADS_FOR{ + return; + } - int coords[MAX_RANK], temp; + uint zDim = outArrs[0]->sizeAt(axis); + // general case - for (auto i = start; i < stop; i += increment) { + auto func = PRAGMA_THREADS_FOR { + int coords[MAX_RANK], temp; - shape::index2coordsCPU(start, i, input.shapeInfo(), coords); - const auto xOffset = shape::getOffset(input.shapeInfo(), coords); + for (auto i = start; i < stop; i += increment) { + shape::index2coordsCPU(start, i, input.shapeInfo(), coords); + const auto xOffset = shape::getOffset(input.shapeInfo(), coords); - uint outArrIdx = 0; + uint outArrIdx = 0; - temp = coords[axis]; + temp = coords[axis]; - while (coords[axis] >= zDim) { - coords[axis] -= zDim; - ++outArrIdx; - } + while (coords[axis] >= zDim) { + coords[axis] -= zDim; + ++outArrIdx; + } - T* z = outArrs[outArrIdx]->bufferAsT(); - const auto zOffset = shape::getOffset(outArrs[outArrIdx]->shapeInfo(), coords); - z[zOffset] = xBuff[xOffset]; + T* z = outArrs[outArrIdx]->bufferAsT(); + const auto zOffset = + shape::getOffset(outArrs[outArrIdx]->shapeInfo(), coords); + z[zOffset] = xBuff[xOffset]; - coords[axis] = temp; - } - }; + coords[axis] = temp; + } + }; - samediff::Threads::parallel_for(func, 0, input.lengthOf()); - } + samediff::Threads::parallel_for(func, 0, input.lengthOf()); +} - void split(sd::LaunchContext* context, const NDArray& input, std::vector& outArrs, const int axis) { - BUILD_SINGLE_SELECTOR(input.dataType(), split_, (input, outArrs, axis), LIBND4J_TYPES); - } - } - } +void split(sd::LaunchContext* context, const NDArray& input, + std::vector& outArrs, const int axis) { + BUILD_SINGLE_SELECTOR(input.dataType(), split_, (input, outArrs, axis), + LIBND4J_TYPES); } +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/sru.cpp b/libnd4j/include/ops/declarable/helpers/cpu/sru.cpp index ecd5ead2be2d..e106750b3f77 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/sru.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/sru.cpp @@ -15,337 +15,366 @@ ******************************************************************************/ // -// implementation of operations for Simple Recurrent Unit: arXiv:1709.02755v2 [cs.CL] 12 Sep 2017 +// implementation of operations for Simple Recurrent Unit: arXiv:1709.02755v2 +// [cs.CL] 12 Sep 2017 // // @author Yurii Shyrma, created on 05.12.2017 // -#include #include -#include #include +#include +#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { ////////////////////////////////////////////////////////////////////////// static FORCEINLINE NDArray activation(const NDArray& arr) { - - // return (const_cast&>(arr)).template transform>(); - auto result = NDArray(&arr, false, arr.getContext()); - (const_cast(arr)).applyTransform(transform::Tanh, result); - return result; + // return (const_cast&>(arr)).template + // transform>(); + auto result = NDArray(&arr, false, arr.getContext()); + (const_cast(arr)).applyTransform(transform::Tanh, result); + return result; } - ////////////////////////////////////////////////////////////////////////// static FORCEINLINE NDArray sigmoid(const NDArray& arr) { - return (const_cast(arr)).transform(transform::Sigmoid); + return (const_cast(arr)).transform(transform::Sigmoid); } - ////////////////////////////////////////////////////////////////////////// -void sruCell(sd::LaunchContext * context, const NDArray* x, const NDArray* c0, const NDArray* w, const NDArray* b, NDArray* h, NDArray* c) { - - // x input [bS x inSize], bS - batch size, inSize - number of features - // c0 previous cell state c [bS x inSize], that is at previous time step t-1 - // w weights [inSize x 3*inSize] - // b biases [2*inSize] +void sruCell(sd::LaunchContext* context, const NDArray* x, const NDArray* c0, + const NDArray* w, const NDArray* b, NDArray* h, NDArray* c) { + // x input [bS x inSize], bS - batch size, inSize - number of features + // c0 previous cell state c [bS x inSize], that is at previous time step t-1 + // w weights [inSize x 3*inSize] + // b biases [2*inSize] - // h current cell output [bS x inSize], that is at current time step t - // c current cell state [bS x inSize], that is at current time step t + // h current cell output [bS x inSize], that is at current time step t + // c current cell state [bS x inSize], that is at current time step t - const int inSize = x->sizeAt(1); // inSize - number of features + const int inSize = x->sizeAt(1); // inSize - number of features - auto z = mmul(*x, *w); // [bS x 3*inSize] + auto z = mmul(*x, *w); // [bS x 3*inSize] - // forget gate = sigmoid(x*Wf + bf) - auto f = sigmoid(z({0,0, inSize, 2*inSize}) + (*b)({0, inSize})); + // forget gate = sigmoid(x*Wf + bf) + auto f = sigmoid(z({0, 0, inSize, 2 * inSize}) + (*b)({0, inSize})); - // reset gate = sigmoid(x*Wr + br) - auto r = sigmoid(z({0,0, 2*inSize, 3*inSize}) + (*b)({inSize, 2*inSize})); + // reset gate = sigmoid(x*Wr + br) + auto r = + sigmoid(z({0, 0, 2 * inSize, 3 * inSize}) + (*b)({inSize, 2 * inSize})); - // ◦ means element-wise product or so called Hadamard product - // current sell state = f◦c0 + (1 - f)◦(x*Wc) - c->assign(f * (*c0) + (1.f - f) * z({0, 0 ,0, inSize}) ); - // *c = f*(*c0 - z({},{0, inSize})) + z({{},{0, inSize}}); + // ◦ means element-wise product or so called Hadamard product + // current sell state = f◦c0 + (1 - f)◦(x*Wc) + c->assign(f * (*c0) + (1.f - f) * z({0, 0, 0, inSize})); + // *c = f*(*c0 - z({},{0, inSize})) + z({{},{0, inSize}}); - // current cell output = r◦activation(c) + (1 - r)◦x - h->assign( r * activation(*c) + (1.f - r) * (*x) ); - // *h = r * (activation(c) - *x) + *x; + // current cell output = r◦activation(c) + (1 - r)◦x + h->assign(r * activation(*c) + (1.f - r) * (*x)); + // *h = r * (activation(c) - *x) + *x; } - ////////////////////////////////////////////////////////////////////////// -void sruTimeLoop(sd::LaunchContext * context, const NDArray* x, const NDArray* c0, const NDArray* w, const NDArray* b, NDArray* h, NDArray* c) { +void sruTimeLoop(sd::LaunchContext* context, const NDArray* x, + const NDArray* c0, const NDArray* w, const NDArray* b, + NDArray* h, NDArray* c) { + // x input [bS x inSize x time] + // c0 initial cell state (at time step = 0) [bS x inSize], + // w weights, [3*inSize x inSize] + // b biases, [2*inSize] - // x input [bS x inSize x time] - // c0 initial cell state (at time step = 0) [bS x inSize], - // w weights, [3*inSize x inSize] - // b biases, [2*inSize] + // h cell outputs [bS x inSize x time] + // c cell states [bS x inSize x time] - // h cell outputs [bS x inSize x time] - // c cell states [bS x inSize x time] + auto wT = w->transpose(); // [3*inSize x inSize] -> [inSize x 3*inSize] - auto wT = w->transpose(); // [3*inSize x inSize] -> [inSize x 3*inSize] + const int time = x->sizeAt(2); - const int time = x->sizeAt(2); + NDArray ct_1(*c0); - NDArray ct_1(*c0); + // loop through time steps + for (int t = 0; t < time; ++t) { + auto xt = (*x)({0, 0, 0, 0, t, t + 1}); + auto ht = (*h)({0, 0, 0, 0, t, t + 1}); + auto ct = (*c)({0, 0, 0, 0, t, t + 1}); - // loop through time steps - for (int t = 0; t < time; ++t) { - - auto xt = (*x)({0,0, 0,0, t,t+1}); - auto ht = (*h)({0,0, 0,0, t,t+1}); - auto ct = (*c)({0,0, 0,0, t,t+1}); - - helpers::sruCell(context, &xt, &ct_1, &wT, b, &ht, &ct); - ct_1.assign(ct); - } + helpers::sruCell(context, &xt, &ct_1, &wT, b, &ht, &ct); + ct_1.assign(ct); + } } ////////////////////////////////////////////////////////////////////////// template -static void sruBI_(NDArray* x, const NDArray* w, const NDArray* b, const NDArray* c0, const NDArray* mask, NDArray* ht, NDArray* ct) { - - // x input 3d tensor [time x bS x 2*K], time - number of time steps, bS - batch size, K - number of features - // w 2d tensor of weights [2*K x 6*K] - // b row of biases with twice length [4*K] - // c0 2d tensor of initial state [bS x 2*K] at time t=0 - // mask optional, 2d tensor of dropout mask [bS x 2*K] - - // ht [time x bS x 2*K] - // ct [time x bS x 2*K] - - const Nd4jLong time = x->sizeAt(0); // time - number of time steps - const Nd4jLong bS = x->sizeAt(1); // bS - batch size - const Nd4jLong K = x->sizeAt(2) / 2; // K - number of features - - // x = x * mask - if(mask) - x->applyBroadcast(broadcast::Multiply, {1, 2}, *mask, *x); // apply mask - - // U = x * w - NDArray wi = mmul(*x, *w); // U [time x bS x 6*K] - - const Nd4jLong d2 = 2*K; - const Nd4jLong ncols = bS*d2; - const Nd4jLong ncolsWi = 3*ncols; - - T* pI = x->bufferAsT(); - T* pWi = wi.bufferAsT(); - T* pBias = const_cast(b)->bufferAsT(); - T* pInit = const_cast(c0)->bufferAsT(); - T* pMask = mask ? const_cast(mask)->bufferAsT() : nullptr; - T* pHt = ht->bufferAsT(); - T* pCt = ct->bufferAsT(); - - auto func = PRAGMA_THREADS_FOR { - for (auto col = start; col < stop; col++) { - const auto colNum = col % d2; - bool flip = colNum >= K; - T maskVal = mask ? *(pMask + col) : T(1); - T cur = *(pInit + col); - T bF = *(pBias + colNum); - T bR = *(pBias + colNum + d2); - T *pWiVal = pWi + 3 * col; - T *pIVal = pI + col; - T *pHtVal = pHt + col; - T *pCtVal = pCt + col; - - if (flip) { - const auto step = (time - 1) * ncols; - pIVal += step; - pHtVal += step; - pCtVal += step; - pWiVal += (time - 1) * ncolsWi; - } - - auto ncolsRev = flip ? -ncols : ncols; - auto ncolsWiRev = flip ? -ncolsWi : ncolsWi; - - for (Nd4jLong t = 0; t < time; ++t) { - // evaluate sigmoids - T ft = (1.) / (1. + sd::math::nd4j_exp(-(pWiVal[1] + bF))); - T rt = (1.) / (1. + sd::math::nd4j_exp(-(pWiVal[2] + bR))); - - cur = (cur - *pWiVal) * ft + *pWiVal; - *pCtVal = cur; - T val = sd::math::nd4j_tanh(cur); - *pHtVal = (val * maskVal - *pIVal) * rt + *pIVal; - - pIVal += ncolsRev; - pWiVal += ncolsWiRev; - pCtVal += ncolsRev; - pHtVal += ncolsRev; - } - } - }; - - samediff::Threads::parallel_tad(func, 0, ncols); +static void sruBI_(NDArray* x, const NDArray* w, const NDArray* b, + const NDArray* c0, const NDArray* mask, NDArray* ht, + NDArray* ct) { + // x input 3d tensor [time x bS x 2*K], time - number of time steps, bS - + // batch size, K - number of features w 2d tensor of weights [2*K x 6*K] + // b row of biases with twice length [4*K] + // c0 2d tensor of initial state [bS x 2*K] at time t=0 + // mask optional, 2d tensor of dropout mask [bS x 2*K] + + // ht [time x bS x 2*K] + // ct [time x bS x 2*K] + + const Nd4jLong time = x->sizeAt(0); // time - number of time steps + const Nd4jLong bS = x->sizeAt(1); // bS - batch size + const Nd4jLong K = x->sizeAt(2) / 2; // K - number of features + + // x = x * mask + if (mask) + x->applyBroadcast(broadcast::Multiply, {1, 2}, *mask, *x); // apply mask + + // U = x * w + NDArray wi = mmul(*x, *w); // U [time x bS x 6*K] + + const Nd4jLong d2 = 2 * K; + const Nd4jLong ncols = bS * d2; + const Nd4jLong ncolsWi = 3 * ncols; + + T* pI = x->bufferAsT(); + T* pWi = wi.bufferAsT(); + T* pBias = const_cast(b)->bufferAsT(); + T* pInit = const_cast(c0)->bufferAsT(); + T* pMask = mask ? const_cast(mask)->bufferAsT() : nullptr; + T* pHt = ht->bufferAsT(); + T* pCt = ct->bufferAsT(); + + auto func = PRAGMA_THREADS_FOR { + for (auto col = start; col < stop; col++) { + const auto colNum = col % d2; + bool flip = colNum >= K; + T maskVal = mask ? *(pMask + col) : T(1); + T cur = *(pInit + col); + T bF = *(pBias + colNum); + T bR = *(pBias + colNum + d2); + T* pWiVal = pWi + 3 * col; + T* pIVal = pI + col; + T* pHtVal = pHt + col; + T* pCtVal = pCt + col; + + if (flip) { + const auto step = (time - 1) * ncols; + pIVal += step; + pHtVal += step; + pCtVal += step; + pWiVal += (time - 1) * ncolsWi; + } + + auto ncolsRev = flip ? -ncols : ncols; + auto ncolsWiRev = flip ? -ncolsWi : ncolsWi; + + for (Nd4jLong t = 0; t < time; ++t) { + // evaluate sigmoids + T ft = (1.) / (1. + sd::math::nd4j_exp(-(pWiVal[1] + bF))); + T rt = (1.) / (1. + sd::math::nd4j_exp(-(pWiVal[2] + bR))); + + cur = (cur - *pWiVal) * ft + *pWiVal; + *pCtVal = cur; + T val = sd::math::nd4j_tanh(cur); + *pHtVal = (val * maskVal - *pIVal) * rt + *pIVal; + + pIVal += ncolsRev; + pWiVal += ncolsWiRev; + pCtVal += ncolsRev; + pHtVal += ncolsRev; + } + } + }; + + samediff::Threads::parallel_tad(func, 0, ncols); } ////////////////////////////////////////////////////////////////////////// template -static void sruBIBP_(NDArray* x, const NDArray* w, const NDArray* b, const NDArray* c0, const NDArray* ct, const NDArray* inGradC0, const NDArray* inGradHt, const NDArray* mask, - NDArray* gradI, NDArray* gradW, NDArray* gradB, NDArray* gradC0) { - - // x input 3d tensor [time x bS x 2*K], time - number of time steps, bS - batch size, K - number of features - // w 2d tensor of weights [2*K x 6*K] - // b row of biases with twice length 4*K] - // c0 2d tensor of initial state [bS x 2*K] at time t=0 - // ct [time x bS x 2*K] - // inGradC0 [bS x 2*K] - // inGradHt [time x bS x 2*K] - // mask optional, 2d tensor of dropout mask [bS x 2*K] - - // gradI [time x bS x 2*K] - // gradW [time x 2*K x 6*K] - // gradB [4*K] - // gradC0 [bS x 2*K] - - const Nd4jLong time = x->sizeAt(0); // time - number of time steps - const Nd4jLong bS = x->sizeAt(1); - const Nd4jLong K = x->sizeAt(2) / 2; - - // x = x * mask - if(mask) - x->applyBroadcast(broadcast::Multiply, {1, 2}, *mask, *x); // apply mask - - // U = x * w - NDArray wi = mmul(*x, *w); // [time x bS x 2*K] * [2*K x 6*K] = [time x bS x 6*K] - NDArray gradBias(x->ordering(), {bS, 4*K}, x->dataType(), x->getContext()); - NDArray gradWi (x->ordering(), {time, bS, 6*K}, x->dataType(), x->getContext()); - - const Nd4jLong d2 = 2*K; - const Nd4jLong ncols = bS*d2; - const Nd4jLong ncolsWi = 3*ncols; - T* pInput = x->bufferAsT(); - T* pWi = wi.bufferAsT(); - T* pBias = const_cast(b)->bufferAsT(); - T* pInit = const_cast(c0)->bufferAsT(); - T* pMask = mask ? const_cast(mask)->bufferAsT() : nullptr; - T* pState = const_cast(ct)->bufferAsT(); - T* pInGradCt = const_cast(inGradC0)->bufferAsT(); - T* pInGradHt = const_cast(inGradHt)->bufferAsT(); - T* pGradWi = gradWi.bufferAsT(); - T* pGradInput = gradI->bufferAsT(); - T* pGradBias = gradBias.bufferAsT(); - T* pGradInit = gradC0->bufferAsT(); - - auto func = PRAGMA_THREADS_FOR { - for (auto col = start; col < stop; col++) { - T gbF = 0.f; - T gbR = 0.f; - const auto colNum = col % d2; - const bool flip = colNum >= K; - T maskVal = mask ? *(pMask + col) : T(1.); - T cur = *(pInGradCt + col); - T bF = *(pBias + colNum); - T bR = *(pBias + colNum + d2); - T *pWiVal = pWi + 3 * col; - T *pInputVal = pInput + col; - T *pStateVal = pState + col; - T *pInGradHtVal = pInGradHt + col; - T *pGradWiVal = pGradWi + 3 * col; - T *pGradInputVal = pGradInput + col; - - if (!flip) { - const auto stepI = (time - 1) * ncols; - const auto stepW = (time - 1) * ncolsWi; - pInputVal += stepI; - pStateVal += stepI; - pInGradHtVal += stepI; - pGradInputVal += stepI; - pWiVal += stepW; - pGradWiVal += stepW; - } - - Nd4jLong ncolsRev = flip ? -ncols : ncols; - Nd4jLong ncolsWiRev = flip ? -ncolsWi : ncolsWi; - - for (Nd4jLong t = 0; t < time; ++t) { - // evaluate sigmoids - T ft = ((T) 1.) / ((T) 1. + sd::math::nd4j_exp(-(*(pWiVal + 1) + bF))); - T rt = ((T) 1.) / ((T) 1. + sd::math::nd4j_exp(-(*(pWiVal + 2) + bR))); - - T val = sd::math::nd4j_tanh(*pStateVal); - T prevVal = (t < time - 1) ? (*(pStateVal - ncolsRev)) : (*(pInit + col)); - // grad wrt input - *pGradInputVal = *pInGradHtVal - (*pInGradHtVal) * rt; - // grad wrt rt, wiR and bR - T grt = (*pInGradHtVal) * (val * maskVal - *pInputVal) * (rt - rt * rt); - *(pGradWiVal + 2) = grt; - gbR += grt; - // grad wrt state - T gradSateVal = (*pInGradHtVal) * maskVal * (rt - rt * val * val) + cur; - // grad wrt wi0 - *pGradWiVal = gradSateVal - gradSateVal * ft; - // grad wrt ft, wi1, and bF - T gft = gradSateVal * (prevVal - *pWiVal) * (ft - ft * ft); - *(pGradWiVal + 1) = gft; - gbF += gft; - // grad wrt c_previous - cur = gradSateVal * ft; - - pInputVal -= ncolsRev; - pWiVal -= ncolsWiRev; - pStateVal -= ncolsRev; - pGradWiVal -= ncolsWiRev; - pGradInputVal -= ncolsRev; - pInGradHtVal -= ncolsRev; - } - *(pGradBias + col) = gbF; - *(pGradBias + col + ncols) = gbR; - *(pGradInit + col) = cur; - } - }; - - samediff::Threads::parallel_tad(func, 0, ncols); - - // gradB - gradBias.reduceAlongDimension(reduce::Sum, *gradB, {0}); // [4*K] - - // gradW - x->permutei({0, 2, 1}); // [time x bS x 2*K] -> [time x 2*K x bS] - MmulHelper::mmul(x, &gradWi, gradW, 1., 0.); // [time x 2*K x bS ] * [time x bS x 6*K] = [time x 2*K x 6*K] -} - - -void sruBI(sd::LaunchContext * context, NDArray* x, const NDArray* w, const NDArray* b, const NDArray* c0, const NDArray* mask, NDArray* ht, NDArray* ct) { - BUILD_SINGLE_SELECTOR(x->dataType(), sruBI_, (x, w, b, c0, mask, ht, ct), FLOAT_TYPES); -} -void sruBIBP(sd::LaunchContext * context, NDArray* x, const NDArray* w, const NDArray* b, const NDArray* c0, const NDArray* ct, const NDArray* inGradC0, const NDArray* inGradH, const NDArray* mask, NDArray* gradI, NDArray* gradW, NDArray* gradB, NDArray* gradC0) { - BUILD_SINGLE_SELECTOR(x->dataType(), sruBIBP_, (x, w, b, c0, ct, inGradC0, inGradH, mask, gradI, gradW, gradB, gradC0), FLOAT_TYPES); -} - +static void sruBIBP_(NDArray* x, const NDArray* w, const NDArray* b, + const NDArray* c0, const NDArray* ct, + const NDArray* inGradC0, const NDArray* inGradHt, + const NDArray* mask, NDArray* gradI, NDArray* gradW, + NDArray* gradB, NDArray* gradC0) { + // x input 3d tensor [time x bS x 2*K], time - number of time steps, bS - + // batch size, K - number of features w 2d tensor of weights [2*K x 6*K] b + // row of biases with twice length 4*K] c0 2d tensor of initial state [bS x + // 2*K] at time t=0 ct [time x bS x 2*K] inGradC0 [bS x 2*K] inGradHt [time x + // bS x 2*K] mask optional, 2d tensor of dropout mask [bS x 2*K] + + // gradI [time x bS x 2*K] + // gradW [time x 2*K x 6*K] + // gradB [4*K] + // gradC0 [bS x 2*K] + + const Nd4jLong time = x->sizeAt(0); // time - number of time steps + const Nd4jLong bS = x->sizeAt(1); + const Nd4jLong K = x->sizeAt(2) / 2; + + // x = x * mask + if (mask) + x->applyBroadcast(broadcast::Multiply, {1, 2}, *mask, *x); // apply mask + + // U = x * w + NDArray wi = + mmul(*x, *w); // [time x bS x 2*K] * [2*K x 6*K] = [time x bS x 6*K] + NDArray gradBias(x->ordering(), {bS, 4 * K}, x->dataType(), x->getContext()); + NDArray gradWi(x->ordering(), {time, bS, 6 * K}, x->dataType(), + x->getContext()); + + const Nd4jLong d2 = 2 * K; + const Nd4jLong ncols = bS * d2; + const Nd4jLong ncolsWi = 3 * ncols; + T* pInput = x->bufferAsT(); + T* pWi = wi.bufferAsT(); + T* pBias = const_cast(b)->bufferAsT(); + T* pInit = const_cast(c0)->bufferAsT(); + T* pMask = mask ? const_cast(mask)->bufferAsT() : nullptr; + T* pState = const_cast(ct)->bufferAsT(); + T* pInGradCt = const_cast(inGradC0)->bufferAsT(); + T* pInGradHt = const_cast(inGradHt)->bufferAsT(); + T* pGradWi = gradWi.bufferAsT(); + T* pGradInput = gradI->bufferAsT(); + T* pGradBias = gradBias.bufferAsT(); + T* pGradInit = gradC0->bufferAsT(); + + auto func = PRAGMA_THREADS_FOR { + for (auto col = start; col < stop; col++) { + T gbF = 0.f; + T gbR = 0.f; + const auto colNum = col % d2; + const bool flip = colNum >= K; + T maskVal = mask ? *(pMask + col) : T(1.); + T cur = *(pInGradCt + col); + T bF = *(pBias + colNum); + T bR = *(pBias + colNum + d2); + T* pWiVal = pWi + 3 * col; + T* pInputVal = pInput + col; + T* pStateVal = pState + col; + T* pInGradHtVal = pInGradHt + col; + T* pGradWiVal = pGradWi + 3 * col; + T* pGradInputVal = pGradInput + col; + + if (!flip) { + const auto stepI = (time - 1) * ncols; + const auto stepW = (time - 1) * ncolsWi; + pInputVal += stepI; + pStateVal += stepI; + pInGradHtVal += stepI; + pGradInputVal += stepI; + pWiVal += stepW; + pGradWiVal += stepW; + } + + Nd4jLong ncolsRev = flip ? -ncols : ncols; + Nd4jLong ncolsWiRev = flip ? -ncolsWi : ncolsWi; + + for (Nd4jLong t = 0; t < time; ++t) { + // evaluate sigmoids + T ft = + ((T)1.) / ((T)1. + sd::math::nd4j_exp(-(*(pWiVal + 1) + bF))); + T rt = + ((T)1.) / ((T)1. + sd::math::nd4j_exp(-(*(pWiVal + 2) + bR))); + + T val = sd::math::nd4j_tanh(*pStateVal); + T prevVal = + (t < time - 1) ? (*(pStateVal - ncolsRev)) : (*(pInit + col)); + // grad wrt input + *pGradInputVal = *pInGradHtVal - (*pInGradHtVal) * rt; + // grad wrt rt, wiR and bR + T grt = (*pInGradHtVal) * (val * maskVal - *pInputVal) * (rt - rt * rt); + *(pGradWiVal + 2) = grt; + gbR += grt; + // grad wrt state + T gradSateVal = (*pInGradHtVal) * maskVal * (rt - rt * val * val) + cur; + // grad wrt wi0 + *pGradWiVal = gradSateVal - gradSateVal * ft; + // grad wrt ft, wi1, and bF + T gft = gradSateVal * (prevVal - *pWiVal) * (ft - ft * ft); + *(pGradWiVal + 1) = gft; + gbF += gft; + // grad wrt c_previous + cur = gradSateVal * ft; + + pInputVal -= ncolsRev; + pWiVal -= ncolsWiRev; + pStateVal -= ncolsRev; + pGradWiVal -= ncolsWiRev; + pGradInputVal -= ncolsRev; + pInGradHtVal -= ncolsRev; + } + *(pGradBias + col) = gbF; + *(pGradBias + col + ncols) = gbR; + *(pGradInit + col) = cur; + } + }; -BUILD_SINGLE_TEMPLATE(template void sruBI_, (NDArray* x, const NDArray* w, const NDArray* b, const NDArray* c0, const NDArray* mask, NDArray* ht, NDArray* ct), FLOAT_TYPES); -BUILD_SINGLE_TEMPLATE(template void sruBIBP_, (NDArray* x, const NDArray* w, const NDArray* b, const NDArray* c0, const NDArray* ct, const NDArray* inGradC0, const NDArray* inGradH, const NDArray* mask, NDArray* gradI, NDArray* gradW, NDArray* gradB, NDArray* gradC0), FLOAT_TYPES); + samediff::Threads::parallel_tad(func, 0, ncols); + // gradB + gradBias.reduceAlongDimension(reduce::Sum, *gradB, {0}); // [4*K] + // gradW + x->permutei({0, 2, 1}); // [time x bS x 2*K] -> [time x 2*K x bS] + MmulHelper::mmul( + x, &gradWi, gradW, 1., + 0.); // [time x 2*K x bS ] * [time x bS x 6*K] = [time x 2*K x 6*K] } + +void sruBI(sd::LaunchContext* context, NDArray* x, const NDArray* w, + const NDArray* b, const NDArray* c0, const NDArray* mask, + NDArray* ht, NDArray* ct) { + BUILD_SINGLE_SELECTOR(x->dataType(), sruBI_, (x, w, b, c0, mask, ht, ct), + FLOAT_TYPES); } +void sruBIBP(sd::LaunchContext* context, NDArray* x, const NDArray* w, + const NDArray* b, const NDArray* c0, const NDArray* ct, + const NDArray* inGradC0, const NDArray* inGradH, + const NDArray* mask, NDArray* gradI, NDArray* gradW, + NDArray* gradB, NDArray* gradC0) { + BUILD_SINGLE_SELECTOR( + x->dataType(), sruBIBP_, + (x, w, b, c0, ct, inGradC0, inGradH, mask, gradI, gradW, gradB, gradC0), + FLOAT_TYPES); } +BUILD_SINGLE_TEMPLATE(template void sruBI_, + (NDArray * x, const NDArray* w, const NDArray* b, + const NDArray* c0, const NDArray* mask, NDArray* ht, + NDArray* ct), + FLOAT_TYPES); +BUILD_SINGLE_TEMPLATE(template void sruBIBP_, + (NDArray * x, const NDArray* w, const NDArray* b, + const NDArray* c0, const NDArray* ct, + const NDArray* inGradC0, const NDArray* inGradH, + const NDArray* mask, NDArray* gradI, NDArray* gradW, + NDArray* gradB, NDArray* gradC0), + FLOAT_TYPES); + +} // namespace helpers +} // namespace ops +} // namespace sd ////////////////////////////////////////////////////////////////////////// // template -// void sruCellBP(const std::vector*>& inArrs, const std::vector*>& outArrs) { - -// NDArray* x = inArrs[0]; // input [bS x inSize], bS - batch size, inSize - number of features -// NDArray* c0 = inArrs[1]; // previous cell state c [bS x inSize], that is at previous time step t-1 -// NDArray* w = inArrs[2]; // weights [inSize x 3*inSize] -// NDArray* b = inArrs[3]; // biases [2*inSize] -// NDArray* dLdC = inArrs[4]; // gradient of the loss func with respect to cell output [bS x inSize] -// NDArray* dLdH = inArrs[5]; // gradient of the loss func with respect to cell state [bS x inSize] - -// NDArray* dLdX = outArrs[0]; // gradient of the loss func with respect to input [bS x inSize], so called epsilon -// NDArray* dLdW = outArrs[1]; // gradient of the loss func with respect to weights [inSize x 3*inSize] -// NDArray* dLdB = outArrs[2]; // gradient of the loss func with respect to biases [2*inSize] -// NDArray* dLdC0 = outArrs[3]; // gradient of the loss func with respect to previous cell state [bS, inSize] +// void sruCellBP(const std::vector*>& inArrs, const +// std::vector*>& outArrs) { + +// NDArray* x = inArrs[0]; // input [bS x inSize], bS - +// batch size, inSize - number of features NDArray* c0 = inArrs[1]; // +// previous cell state c [bS x inSize], that is at previous time step t-1 +// NDArray* w = inArrs[2]; // weights [inSize x +// 3*inSize] NDArray* b = inArrs[3]; // biases +// [2*inSize] NDArray* dLdC = inArrs[4]; // gradient of the +// loss func with respect to cell output [bS x inSize] NDArray* dLdH = +// inArrs[5]; // gradient of the loss func with respect to +// cell state [bS x inSize] + +// NDArray* dLdX = outArrs[0]; // gradient of the loss func +// with respect to input [bS x inSize], so called epsilon NDArray* dLdW +// = outArrs[1]; // gradient of the loss func with respect to +// weights [inSize x 3*inSize] NDArray* dLdB = outArrs[2]; // gradient +// of the loss func with respect to biases [2*inSize] NDArray* dLdC0 = +// outArrs[3]; // gradient of the loss func with respect to +// previous cell state [bS, inSize] // const int inSize = x->sizeAt(1); // inSize - number of features @@ -353,16 +382,17 @@ BUILD_SINGLE_TEMPLATE(template void sruBIBP_, (NDArray* x, const NDArray* w, con // NDArray z = mmul(*x, *w); // [bS x 3*inSize] // // forget gate = sigmoid(x*Wf + bf) -// NDArray f = sigmoid(z({{},{inSize, 2*inSize}}) + (*b)({{0, inSize}})); // [bS, inSize] -// NDArray oneMinusF = 1. - f; +// NDArray f = sigmoid(z({{},{inSize, 2*inSize}}) + (*b)({{0, +// inSize}})); // [bS, inSize] NDArray oneMinusF = 1. - f; // // reset gate = sigmoid(x*Wr + br) -// NDArray r = sigmoid(z({{},{2*inSize, 3*inSize}}) + (*b)({{inSize, 2*inSize}})); // [bS, inSize] -// NDArray oneMinusR = 1. - r; - -// // current sell state = f◦c0 + (1 - f)◦(x*Wc) ---> c->assign( f*(*c0) + ((T)1. - f) * z({{},{0, inSize}}) ); -// // current cell output = r◦activation(c) + (1 - r)◦x ---> h->assign( r*activation(*c) + ((T)1. - r) * (*x) ); +// NDArray r = sigmoid(z({{},{2*inSize, 3*inSize}}) + (*b)({{inSize, +// 2*inSize}})); // [bS, inSize] NDArray oneMinusR = 1. - r; +// // current sell state = f◦c0 + (1 - f)◦(x*Wc) ---> c->assign( +// f*(*c0) + ((T)1. - f) * z({{},{0, inSize}}) ); +// // current cell output = r◦activation(c) + (1 - r)◦x ---> h->assign( +// r*activation(*c) + ((T)1. - r) * (*x) ); // //*********** back propagation ***********// // // dCdC0 = f; @@ -378,27 +408,30 @@ BUILD_SINGLE_TEMPLATE(template void sruBIBP_, (NDArray* x, const NDArray* w, con // // dHdC = r * (1 - tanh*tanh) // NDArray dHdC = r * (1. - tanh * tanh); // // dCdX = dCdX + dCdF*dFdX = (1-f)*Wc + dCdF*Wf -// NDArray dCdX = oneMinusF * (*w)({{},{0, inSize}}) + dCdF * (*w)({{},{inSize, 2*inSize}}); - +// NDArray dCdX = oneMinusF * (*w)({{},{0, inSize}}) + dCdF * +// (*w)({{},{inSize, 2*inSize}}); // // dLdC0 = dLdC * dCdC0 = dLdC * f // dLdC0->assign((*dLdC) * f); - -// // dLdBf = dLdH*dHdBf + dLdC*dCdBf = dLdH*dHdC*dCdBf + dLdC*dCdF*dFdBf = dLdH*dHdC*dCdF*dFdBf + dLdC*dCdF*dFdBf = (dLdH*dHdC + dLdC)*dCdF*dFdBf +// // dLdBf = dLdH*dHdBf + dLdC*dCdBf = dLdH*dHdC*dCdBf + dLdC*dCdF*dFdBf = +// dLdH*dHdC*dCdF*dFdBf + dLdC*dCdF*dFdBf = (dLdH*dHdC + dLdC)*dCdF*dFdBf // (*dLdB)({{0, inSize}}).assign(((*dLdH) * dHdC + *dLdC) * dCdF * dFdBf); // // dLdBr = dLdH * dHdR * dRdBr // (*dLdB)({{inSize, 2*inSize}}).assign((*dLdH) * dHdR * dRdBr) - -// // dLdWc = dLdH*dHdWc + dLdC*dCdWc = dLdH*dHdC*dCdWc + dLdC*dCdWc = (dLdH*dHdC + dLdC) * dCdWc = (dLdH*dHdC + dLdC) * (1-f)*x -// (*dLdW)({{}, {0, inSize}}).assign(((*dLdH) * dHdC + *dLdC) * oneMinusF * (*x)); +// // dLdWc = dLdH*dHdWc + dLdC*dCdWc = dLdH*dHdC*dCdWc + dLdC*dCdWc = +// (dLdH*dHdC + dLdC) * dCdWc = (dLdH*dHdC + dLdC) * (1-f)*x +// (*dLdW)({{}, {0, inSize}}).assign(((*dLdH) * dHdC + *dLdC) * oneMinusF * +// (*x)); // // dLdWf = dLdBf * x // (*dLdW)({{}, {inSize, 2*inSize}}).assign((*dLdB)({{0, inSize}}) * (*x)); // // dLdWr = dLdBr * x -// (*dLdW)({{}, {2*inSize, 3*inSize}}).assign((*dLdB)({{inSize, 2*inSize}}) * (*x)); - +// (*dLdW)({{}, {2*inSize, 3*inSize}}).assign((*dLdB)({{inSize, 2*inSize}}) +// * (*x)); -// // dLdX = dLdH*dHdX + dLdC*dCdX = dLdH*(dHdX + dHdR*dRdX + dHdC*dCdX) + dLdC*dCdF*dFdX = dLdH*(1 - r + dHdR*dRdX + dHdC*dCdX) + dLdC*dCdX -// dLdX->assign((*dLdH) * (oneMinusR + dHdR * (*w)({{},{2*inSize, 3*inSize}}) + dHdC * dCdX) + (*dLdC) * dCdX); +// // dLdX = dLdH*dHdX + dLdC*dCdX = dLdH*(dHdX + dHdR*dRdX + dHdC*dCdX) + +// dLdC*dCdF*dFdX = dLdH*(1 - r + dHdR*dRdX + dHdC*dCdX) + dLdC*dCdX +// dLdX->assign((*dLdH) * (oneMinusR + dHdR * (*w)({{},{2*inSize, +// 3*inSize}}) + dHdC * dCdX) + (*dLdC) * dCdX); // } diff --git a/libnd4j/include/ops/declarable/helpers/cpu/stack.cpp b/libnd4j/include/ops/declarable/helpers/cpu/stack.cpp index 3db322fc81a0..c38f4cca4c4b 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/stack.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/stack.cpp @@ -18,105 +18,111 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include -#include #include #include #include - +#include +#include namespace sd { namespace ops { namespace helpers { - /////////////////////////////////////////////////////////////////// template -static void stack_(const std::vector& inArrs, NDArray& output, const int dim) { - - const int numOfSubArrs = inArrs.size(); - - if(inArrs[0]->rankOf() == 0) { - - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) - output.p(i, inArrs[i]->t(0)); - }; - - samediff::Threads::parallel_for(func, 0, numOfSubArrs); - } - else { - - auto zTadPack = ConstantTadHelper::getInstance().tadForDimensions(output.shapeInfo(), ShapeUtils::evalDimsToExclude(output.rankOf(), {dim})); - auto zTadShapeInfo = zTadPack.primaryShapeInfo(); - - auto func = PRAGMA_THREADS_FOR { - - for (auto i = start; i < stop; i++) { - - void* zBuff = output.bufferWithOffset(zTadPack.primaryOffsets()[i]); - - NativeOpExecutioner::execTransformAny(inArrs[0]->getContext(), transform::Assign, - inArrs[i]->buffer(), inArrs[i]->shapeInfo(), nullptr/*input specialBuffer*/, nullptr/*input special*/, - zBuff, zTadShapeInfo, nullptr/*output specialBuffer*/, nullptr/*output special*/, - nullptr, nullptr, nullptr, false/*allowParallelism*/); - } - }; - - samediff::Threads::parallel_tad(func, 0, numOfSubArrs); - } - +static void stack_(const std::vector& inArrs, NDArray& output, + const int dim) { + const int numOfSubArrs = inArrs.size(); + + if (inArrs[0]->rankOf() == 0) { + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) output.p(i, inArrs[i]->t(0)); + }; + + samediff::Threads::parallel_for(func, 0, numOfSubArrs); + } else { + auto zTadPack = ConstantTadHelper::getInstance().tadForDimensions( + output.shapeInfo(), + ShapeUtils::evalDimsToExclude(output.rankOf(), {dim})); + auto zTadShapeInfo = zTadPack.primaryShapeInfo(); + + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + void* zBuff = output.bufferWithOffset(zTadPack.primaryOffsets()[i]); + + NativeOpExecutioner::execTransformAny( + inArrs[0]->getContext(), transform::Assign, inArrs[i]->buffer(), + inArrs[i]->shapeInfo(), nullptr /*input specialBuffer*/, + nullptr /*input special*/, zBuff, zTadShapeInfo, + nullptr /*output specialBuffer*/, + nullptr /*output special*/, nullptr, nullptr, nullptr, + false /*allowParallelism*/); + } + }; + + samediff::Threads::parallel_tad(func, 0, numOfSubArrs); + } } //////////////////////////////////////////////////////////////////////// -void stack(sd::LaunchContext * context, const std::vector& inArrs, NDArray& output, const int dim) { - BUILD_SINGLE_SELECTOR(output.dataType(), stack_, (inArrs, output, dim), LIBND4J_TYPES); +void stack(sd::LaunchContext* context, + const std::vector& inArrs, NDArray& output, + const int dim) { + BUILD_SINGLE_SELECTOR(output.dataType(), stack_, (inArrs, output, dim), + LIBND4J_TYPES); } -BUILD_SINGLE_TEMPLATE(template void stack_ , (const std::vector& inArrs, NDArray& output, const int dim), LIBND4J_TYPES); - +BUILD_SINGLE_TEMPLATE(template void stack_, + (const std::vector& inArrs, + NDArray& output, const int dim), + LIBND4J_TYPES); /////////////////////////////////////////////////////////////////// template -static void unstack_(const NDArray& input, const std::vector& outArrs, const int dim) { - - const int numOfSubArrs = outArrs.size(); - - if(outArrs[0]->rankOf() == 0) { - - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) - outArrs[i]->p(0, input.t(i)); - }; - - samediff::Threads::parallel_for(func, 0, numOfSubArrs); - } - else { - - auto xTadPack = ConstantTadHelper::getInstance().tadForDimensions(input.shapeInfo(), ShapeUtils::evalDimsToExclude(input.rankOf(), {dim})); - auto xTadShapeInfo = xTadPack.primaryShapeInfo(); - - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - auto xBuff = input.bufferWithOffset(xTadPack.primaryOffsets()[i]); - - NativeOpExecutioner::execTransformAny(input.getContext(), transform::Assign, - xBuff, xTadShapeInfo, nullptr/*input specialBuffer*/, nullptr/*input special*/, - outArrs[i]->buffer(), outArrs[i]->shapeInfo(), nullptr/*output specialBuffer*/, nullptr/*output special*/, - nullptr, nullptr, nullptr, false/*allowParallelism*/); - } - }; - - samediff::Threads::parallel_tad(func, 0, numOfSubArrs); - } +static void unstack_(const NDArray& input, const std::vector& outArrs, + const int dim) { + const int numOfSubArrs = outArrs.size(); + + if (outArrs[0]->rankOf() == 0) { + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) outArrs[i]->p(0, input.t(i)); + }; + + samediff::Threads::parallel_for(func, 0, numOfSubArrs); + } else { + auto xTadPack = ConstantTadHelper::getInstance().tadForDimensions( + input.shapeInfo(), + ShapeUtils::evalDimsToExclude(input.rankOf(), {dim})); + auto xTadShapeInfo = xTadPack.primaryShapeInfo(); + + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + auto xBuff = input.bufferWithOffset(xTadPack.primaryOffsets()[i]); + + NativeOpExecutioner::execTransformAny( + input.getContext(), transform::Assign, xBuff, xTadShapeInfo, + nullptr /*input specialBuffer*/, nullptr /*input special*/, + outArrs[i]->buffer(), outArrs[i]->shapeInfo(), + nullptr /*output specialBuffer*/, + nullptr /*output special*/, nullptr, nullptr, nullptr, + false /*allowParallelism*/); + } + }; + + samediff::Threads::parallel_tad(func, 0, numOfSubArrs); + } } //////////////////////////////////////////////////////////////////////// -void unstack(sd::LaunchContext* context, const NDArray& input, const std::vector& outArrs, const int dim) { - BUILD_SINGLE_SELECTOR(input.dataType(), unstack_, (input, outArrs, dim), LIBND4J_TYPES); -} -BUILD_SINGLE_TEMPLATE(template void unstack_, (const NDArray& input, const std::vector& outArrs, const int dim), LIBND4J_TYPES); - +void unstack(sd::LaunchContext* context, const NDArray& input, + const std::vector& outArrs, const int dim) { + BUILD_SINGLE_SELECTOR(input.dataType(), unstack_, (input, outArrs, dim), + LIBND4J_TYPES); } -} -} - +BUILD_SINGLE_TEMPLATE(template void unstack_, + (const NDArray& input, + const std::vector& outArrs, const int dim), + LIBND4J_TYPES); + +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/svd.cpp b/libnd4j/include/ops/declarable/helpers/cpu/svd.cpp index 6910960ef9ac..c0dfbb9e9f80 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/svd.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/svd.cpp @@ -18,62 +18,63 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 03.01.2018 // -#include #include -#include #include +#include +#include namespace sd { namespace ops { namespace helpers { ////////////////////////////////////////////////////////////////////////// -// svd operation, this function is not method of SVD class, it is standalone function +// svd operation, this function is not method of SVD class, it is standalone +// function template -static void svd_(const NDArray* x, const std::vector& outArrs, const bool fullUV, const bool calcUV, const int switchNum) { - - auto s = outArrs[0]; - auto u = outArrs[1]; - auto v = outArrs[2]; - - const int rank = x->rankOf(); - const int sRank = rank - 1; - - auto listX = x->allTensorsAlongDimension({rank-2, rank-1}); - auto listS = s->allTensorsAlongDimension({sRank-1}); - ResultSet* listU(nullptr), *listV(nullptr); - - if(calcUV) { - listU = new ResultSet(u->allTensorsAlongDimension({rank-2, rank-1})); - listV = new ResultSet(v->allTensorsAlongDimension({rank-2, rank-1})); - } - - for(int i = 0; i < listX.size(); ++i) { - - // NDArray matrix(x->ordering(), {listX.at(i)->sizeAt(0), listX.at(i)->sizeAt(1)}, block.getContext()); - // matrix.assign(listX.at(i)); - helpers::SVD svdObj(*(listX.at(i)), switchNum, calcUV, calcUV, fullUV); - listS.at(i)->assign(svdObj._s); - - if(calcUV) { - listU->at(i)->assign(svdObj._u); - listV->at(i)->assign(svdObj._v); - } +static void svd_(const NDArray* x, const std::vector& outArrs, + const bool fullUV, const bool calcUV, const int switchNum) { + auto s = outArrs[0]; + auto u = outArrs[1]; + auto v = outArrs[2]; + + const int rank = x->rankOf(); + const int sRank = rank - 1; + + auto listX = x->allTensorsAlongDimension({rank - 2, rank - 1}); + auto listS = s->allTensorsAlongDimension({sRank - 1}); + ResultSet *listU(nullptr), *listV(nullptr); + + if (calcUV) { + listU = new ResultSet(u->allTensorsAlongDimension({rank - 2, rank - 1})); + listV = new ResultSet(v->allTensorsAlongDimension({rank - 2, rank - 1})); + } + + for (int i = 0; i < listX.size(); ++i) { + // NDArray matrix(x->ordering(), {listX.at(i)->sizeAt(0), + // listX.at(i)->sizeAt(1)}, block.getContext()); matrix.assign(listX.at(i)); + helpers::SVD svdObj(listX.at(i), switchNum, calcUV, calcUV, fullUV); + listS.at(i).assign(svdObj._s); + + if (calcUV) { + listU->at(i).assign(svdObj._u); + listV->at(i).assign(svdObj._v); } + } - if(calcUV) { - delete listU; - delete listV; - } + if (calcUV) { + delete listU; + delete listV; + } } ////////////////////////////////////////////////////////////////////////// -void svd(sd::LaunchContext * context, const NDArray* x, const std::vector& outArrs, const bool fullUV, const bool calcUV, const int switchNum) { - BUILD_SINGLE_SELECTOR(x->dataType(), svd_, (x, outArrs, fullUV, calcUV, switchNum), FLOAT_TYPES); -} - - -} -} +void svd(sd::LaunchContext* context, const NDArray* x, + const std::vector& outArrs, const bool fullUV, + const bool calcUV, const int switchNum) { + BUILD_SINGLE_SELECTOR(x->dataType(), svd_, + (x, outArrs, fullUV, calcUV, switchNum), FLOAT_TYPES); } +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/tile.cpp b/libnd4j/include/ops/declarable/helpers/cpu/tile.cpp index 4edb9e2a0142..aabc0c56882b 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/tile.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/tile.cpp @@ -18,74 +18,68 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 20.04.2018 // - -#include -#include #include +#include +#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { ////////////////////////////////////////////////////////////////////////// template -static void tileBP_(const NDArray& gradO /*input*/, NDArray& gradI /*output*/, const std::vector reps) { - - T* gradIBuff = reinterpret_cast(gradI.buffer()); - auto gradOBuff = reinterpret_cast(gradO.buffer()); - const Nd4jLong gradILen = gradI.lengthOf(); - const Nd4jLong gradOLen = gradO.lengthOf(); // gradOLen >= gradILen - const Nd4jLong gradIEWS = sd::math::nd4j_abs(gradI.ews()); - const Nd4jLong gradOEWS = gradO.ews(); - - // initial zeroing of gradI content - if(gradIEWS == 1) - memset(gradIBuff, 0, gradILen * sizeof(T)); - else { - //PRAGMA_OMP_PARALLEL_FOR_SIMD - for (Nd4jLong i = 0; i < gradILen * gradIEWS; i += gradIEWS) - gradIBuff[i] = static_cast(0.f); - } - - - if(gradO.ordering() == 'c' && gradOEWS == 1) { - - //PRAGMA_OMP_PARALLEL_FOR_SIMD - for(Nd4jLong i=0; i(idx) + gradOBuff[i]); - } +static void tileBP_(const NDArray& gradO /*input*/, NDArray& gradI /*output*/, + const std::vector reps) { + T* gradIBuff = reinterpret_cast(gradI.buffer()); + auto gradOBuff = reinterpret_cast(gradO.buffer()); + const Nd4jLong gradILen = gradI.lengthOf(); + const Nd4jLong gradOLen = gradO.lengthOf(); // gradOLen >= gradILen + const Nd4jLong gradIEWS = sd::math::nd4j_abs(gradI.ews()); + const Nd4jLong gradOEWS = gradO.ews(); + + // initial zeroing of gradI content + if (gradIEWS == 1) + memset(gradIBuff, 0, gradILen * sizeof(T)); + else { + // PRAGMA_OMP_PARALLEL_FOR_SIMD + for (Nd4jLong i = 0; i < gradILen * gradIEWS; i += gradIEWS) + gradIBuff[i] = static_cast(0.f); + } + + if (gradO.ordering() == 'c' && gradOEWS == 1) { + // PRAGMA_OMP_PARALLEL_FOR_SIMD + for (Nd4jLong i = 0; i < gradOLen; ++i) { + auto idx = shape::subArrayIndex(i, gradO.shapeInfo(), gradI.shapeInfo()); + gradI.p(idx, gradI.e(idx) + gradOBuff[i]); } - else if(gradO.ordering() == 'c' && gradOEWS > 1) { - - //PRAGMA_OMP_PARALLEL_FOR_SIMD - for(Nd4jLong i=0; i(idx) + gradOBuff[i * gradOEWS]); - } + } else if (gradO.ordering() == 'c' && gradOEWS > 1) { + // PRAGMA_OMP_PARALLEL_FOR_SIMD + for (Nd4jLong i = 0; i < gradOLen; ++i) { + auto idx = shape::subArrayIndex(i, gradO.shapeInfo(), gradI.shapeInfo()); + gradI.p(idx, gradI.e(idx) + gradOBuff[i * gradOEWS]); } - else { - - //PRAGMA_OMP_PARALLEL_FOR_SIMD - for(Nd4jLong i=0; i(fidx) + gradOBuff[shape::getIndexOffset(i, gradO.shapeInfo())]); - } + } else { + // PRAGMA_OMP_PARALLEL_FOR_SIMD + for (Nd4jLong i = 0; i < gradOLen; ++i) { + auto fidx = shape::subArrayIndex(i, gradO.shapeInfo(), gradI.shapeInfo()); + gradI.p(fidx, gradI.e(fidx) + + gradOBuff[shape::getIndexOffset(i, gradO.shapeInfo())]); } + } } -void tileBP(sd::LaunchContext * context, const NDArray& gradO /*input*/, NDArray& gradI /*output*/, const std::vector reps) { - BUILD_SINGLE_SELECTOR(gradI.dataType(), tileBP_, (gradO, gradI, reps), FLOAT_TYPES); +void tileBP(sd::LaunchContext* context, const NDArray& gradO /*input*/, + NDArray& gradI /*output*/, const std::vector reps) { + BUILD_SINGLE_SELECTOR(gradI.dataType(), tileBP_, (gradO, gradI, reps), + FLOAT_TYPES); } +BUILD_SINGLE_TEMPLATE(template void tileBP_, + (const NDArray& gradO /*input*/, + NDArray& gradI /*output*/, + const std::vector reps), + FLOAT_TYPES); -BUILD_SINGLE_TEMPLATE(template void tileBP_, (const NDArray& gradO /*input*/, NDArray& gradI /*output*/, const std::vector reps), FLOAT_TYPES); - - - - - -} -} -} +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/toggle_bits.cpp b/libnd4j/include/ops/declarable/helpers/cpu/toggle_bits.cpp index 67b4e0f776d8..e5ab75571229 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/toggle_bits.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/toggle_bits.cpp @@ -18,24 +18,22 @@ // @author raver119@gmail.com // -#include #include +#include namespace sd { - namespace ops { - namespace helpers { - template - void toggle_bits__(NDArray &in, NDArray &out) { - auto lambda = LAMBDA_T(_x) { - return BitwiseUtils::flip_bits(_x); - }; +namespace ops { +namespace helpers { +template +void toggle_bits__(NDArray& in, NDArray& out) { + auto lambda = LAMBDA_T(_x) { return BitwiseUtils::flip_bits(_x); }; - in.applyLambda(lambda, out); - } + in.applyLambda(lambda, out); +} - void __toggle_bits(sd::LaunchContext * context, NDArray& in, NDArray& out) { - BUILD_SINGLE_SELECTOR(in.dataType(), toggle_bits__, (in, out), INTEGER_TYPES); - } - } - } -} \ No newline at end of file +void __toggle_bits(sd::LaunchContext* context, NDArray& in, NDArray& out) { + BUILD_SINGLE_SELECTOR(in.dataType(), toggle_bits__, (in, out), INTEGER_TYPES); +} +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/top_k.cpp b/libnd4j/include/ops/declarable/helpers/cpu/top_k.cpp index 65edeb71b713..9a511f9a3bd7 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/top_k.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/top_k.cpp @@ -18,167 +18,186 @@ // @author raver119@gmail.com // -#include -#include #include #include +#include +#include namespace sd { namespace ops { namespace helpers { - template - static int topKFunctor_(const NDArray* input, NDArray* values, NDArray* indices, const uint k, bool needSort) { - Nd4jLong width = input->sizeAt(-1); - int lastDim = input->rankOf() - 1; -// ----------------------------------------------------------------------------------------------- // -// this assumption is right: -// if (values->lengthOf() != k * lastDimList->size()) { -// nd4j_printf("top_k: something is wrong. %i expected, but %i given.\n", -// values->lengthOf(), k * lastDimList->size()); -// } -// ----------------------------------------------------------------------------------------------- // - std::vector dimsToExclude(input->rankOf() - 1); - for (size_t d = 0; d < dimsToExclude.size(); ++d) - dimsToExclude[d] = d; - - const Nd4jLong numOfSubArrs = ShapeUtils::getNumOfSubArrs(input->shapeInfo(), dimsToExclude); - - if (k == 1) { - for (Nd4jLong e = 0; e < numOfSubArrs; ++e) { - auto trial = (*input)(e, dimsToExclude); - //int maxPos = //lastDimList->at(e)->argMax(); - Nd4jLong maxPos = 0; - //trial.printIndexedBuffer("TRIAL:"); - T maxVal = trial.e(0); - for (Nd4jLong pos = 1; pos < trial.lengthOf(); pos++) - if (maxVal < trial.e(pos)) { - maxPos = pos; - maxVal = trial.e(pos); - } - if (indices) - indices->p(e, maxPos); //topIndex; - if (values) - values->p(e, maxVal); - } - } - else { - int nextPos = 0; - - for (Nd4jLong e = 0; e < numOfSubArrs; ++e) { - auto trial = (*input)(e, dimsToExclude); - - // fill up the first k elements - NDArray topValues = NDArrayFactory::create('c', {k}, input->getContext()); - NDArray sortedVals = NDArrayFactory::create('c', {k}, input->getContext()); - NDArray topIndices = NDArrayFactory::create('c', {k}, input->getContext()); - for (uint pos = 0; pos < k; ++pos) { - topIndices.r(pos) = pos; - topValues.r(pos) = trial.t(pos); - } - //std::vector sortedVals(topValues); - sortedVals.assign(topValues);// = NDArrayFactory::create('c', {k}); - //std::sort(sortedVals.begin(), sortedVals.end()); // sorted in ascending order - SpecialMethods::sortGeneric(sortedVals.buffer(), sortedVals.shapeInfo(), false); - for (Nd4jLong i = static_cast(k); i < width; ++i) { - T val = trial.e(i); - T minTopVal = sortedVals.t(0); - if (minTopVal < val) { // value should be inserted to top k - // only if it is not contained in - T* begin = reinterpret_cast(sortedVals.buffer()); - T* end = begin + k; - bool exists = std::binary_search(begin, end, val); - if (!exists) { - //exchangePos - a distance between begin and minimal existed to be suppressed by val - T* topBegin = reinterpret_cast(topValues.buffer()); - T* topEnd = topBegin + k; - auto exchangePos = std::distance(topBegin, std::find(topBegin, topEnd, sortedVals.t(0))); - topValues.r(exchangePos) = val; //*exchangeIt = val; - topIndices.r(exchangePos) = i; - sortedVals.r(0) = val; // suppress in sorted - //std::sort(sortedVals.begin(), sortedVals.end()); // sorted in ascending order - SpecialMethods::sortGeneric(sortedVals.buffer(), sortedVals.shapeInfo(), false); - } - } - } - if (needSort) { - SpecialMethods::sortGeneric(topValues.buffer(), topValues.shapeInfo(), true); - - for (Nd4jLong j = 0; j < width; j++) - for (uint pos = 0; pos < k; ++pos) - if (topValues.t(pos) == trial.t(j)) - topIndices.r(pos) = j; - } - else { // else sort by indices - std::map sortValsMap; - //std::vector> data(topValues.lengthOf()); - for (Nd4jLong e = 0; e < topValues.lengthOf(); ++e) { - sortValsMap[topIndices.t(e)] = topValues.t(e); - } - - //std::sort(data.begin(), data.end(), [](std::pair const& a, std::pair const& b) { - // return a.first < b.first; - //}); - Nd4jLong e = 0; - for (auto it = sortValsMap.begin(); it != sortValsMap.end(); ++it, e++) { - topIndices.r(e) = it->first; - topValues.r(e) = it->second; - } - - } - if (values) - (*values)(e, dimsToExclude).assign(topValues); - if (indices) - (*indices)(e, dimsToExclude).assign(topIndices); - } - //indices->printIndexedBuffer("Indices as is"); +template +static int topKFunctor_(const NDArray* input, NDArray* values, NDArray* indices, + const uint k, bool needSort) { + Nd4jLong width = input->sizeAt(-1); + int lastDim = input->rankOf() - 1; + // ----------------------------------------------------------------------------------------------- + // // this assumption is right: + // if (values->lengthOf() != k * lastDimList->size()) { + // nd4j_printf("top_k: something is wrong. %i expected, but %i + // given.\n", + // values->lengthOf(), k * lastDimList->size()); + // } + // ----------------------------------------------------------------------------------------------- + // // + std::vector dimsToExclude(input->rankOf() - 1); + for (size_t d = 0; d < dimsToExclude.size(); ++d) dimsToExclude[d] = d; + + const Nd4jLong numOfSubArrs = + ShapeUtils::getNumOfSubArrs(input->shapeInfo(), dimsToExclude); + + if (k == 1) { + for (Nd4jLong e = 0; e < numOfSubArrs; ++e) { + auto trial = (*input)(e, dimsToExclude); + // int maxPos = //lastDimList->at(e)->argMax(); + Nd4jLong maxPos = 0; + // trial.printIndexedBuffer("TRIAL:"); + T maxVal = trial.e(0); + for (Nd4jLong pos = 1; pos < trial.lengthOf(); pos++) + if (maxVal < trial.e(pos)) { + maxPos = pos; + maxVal = trial.e(pos); } - return Status::OK(); - } -// ----------------------------------------------------------------------------------------------- // - - template - static int inTopKFunctor_(sd::LaunchContext* context, const NDArray* input, const NDArray* target, NDArray* result, const uint k) { - - std::vector shapeI(input->rankOf()); - for (int i = 0; i < input->rankOf() - 1; i++) - shapeI[i] = input->sizeAt(i); - shapeI[input->rankOf() - 1] = k; - std::unique_ptr indices(NDArrayFactory::create_(input->ordering(), shapeI, context)); - NDArray* values = nullptr; - int status = topKFunctor(context, input, values, indices.get(), k, true); - result->assign(0); - if (status == ND4J_STATUS_OK) { - auto func = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) { - bool found = false; - for (uint j = 0; j < k; j++) { - if (target->e(e) == indices->e(e * k + j)) { - found = true; - break; - } - } - if (found) - result->p(e, true); - } - }; - - samediff::Threads::parallel_tad(func, 0, target->lengthOf()); - } - return status; - + if (indices) indices->p(e, maxPos); // topIndex; + if (values) values->p(e, maxVal); } - - int topKFunctor(sd::LaunchContext * context, const NDArray* input, NDArray* values, NDArray* indices, const uint k, bool needSort) { - BUILD_SINGLE_SELECTOR(input->dataType(), return topKFunctor_, (input, values, indices, k, needSort), NUMERIC_TYPES); + } else { + int nextPos = 0; + + for (Nd4jLong e = 0; e < numOfSubArrs; ++e) { + auto trial = (*input)(e, dimsToExclude); + + // fill up the first k elements + NDArray topValues = + NDArrayFactory::create('c', {k}, input->getContext()); + NDArray sortedVals = + NDArrayFactory::create('c', {k}, input->getContext()); + NDArray topIndices = + NDArrayFactory::create('c', {k}, input->getContext()); + for (uint pos = 0; pos < k; ++pos) { + topIndices.r(pos) = pos; + topValues.r(pos) = trial.t(pos); + } + // std::vector sortedVals(topValues); + sortedVals.assign(topValues); // = NDArrayFactory::create('c', {k}); + // std::sort(sortedVals.begin(), sortedVals.end()); // sorted in ascending + // order + SpecialMethods::sortGeneric(sortedVals.buffer(), + sortedVals.shapeInfo(), false); + for (Nd4jLong i = static_cast(k); i < width; ++i) { + T val = trial.e(i); + T minTopVal = sortedVals.t(0); + if (minTopVal < val) { // value should be inserted to top k + // only if it is not contained in + T* begin = reinterpret_cast(sortedVals.buffer()); + T* end = begin + k; + bool exists = std::binary_search(begin, end, val); + if (!exists) { + // exchangePos - a distance between begin and minimal existed to be + // suppressed by val + T* topBegin = reinterpret_cast(topValues.buffer()); + T* topEnd = topBegin + k; + auto exchangePos = std::distance( + topBegin, std::find(topBegin, topEnd, sortedVals.t(0))); + topValues.r(exchangePos) = val; //*exchangeIt = val; + topIndices.r(exchangePos) = i; + sortedVals.r(0) = val; // suppress in sorted + // std::sort(sortedVals.begin(), sortedVals.end()); // sorted in + // ascending order + SpecialMethods::sortGeneric(sortedVals.buffer(), + sortedVals.shapeInfo(), false); + } + } + } + if (needSort) { + SpecialMethods::sortGeneric(topValues.buffer(), + topValues.shapeInfo(), true); + + for (Nd4jLong j = 0; j < width; j++) + for (uint pos = 0; pos < k; ++pos) + if (topValues.t(pos) == trial.t(j)) + topIndices.r(pos) = j; + } else { // else sort by indices + std::map sortValsMap; + // std::vector> data(topValues.lengthOf()); + for (Nd4jLong e = 0; e < topValues.lengthOf(); ++e) { + sortValsMap[topIndices.t(e)] = topValues.t(e); } - int inTopKFunctor(sd::LaunchContext * context, const NDArray* input, const NDArray* target, NDArray* result, const uint k) { - BUILD_SINGLE_SELECTOR(input->dataType(), return inTopKFunctor_, (context, input, target, result, k), NUMERIC_TYPES); + // std::sort(data.begin(), data.end(), [](std::pair const& a, + // std::pair const& b) { + // return a.first < b.first; + //}); + Nd4jLong e = 0; + for (auto it = sortValsMap.begin(); it != sortValsMap.end(); + ++it, e++) { + topIndices.r(e) = it->first; + topValues.r(e) = it->second; } + } + if (values) (*values)(e, dimsToExclude).assign(topValues); + if (indices) (*indices)(e, dimsToExclude).assign(topIndices); + } + // indices->printIndexedBuffer("Indices as is"); + } + return Status::OK(); +} +// ----------------------------------------------------------------------------------------------- +// // + +template +static int inTopKFunctor_(sd::LaunchContext* context, const NDArray* input, + const NDArray* target, NDArray* result, + const uint k) { + std::vector shapeI(input->rankOf()); + for (int i = 0; i < input->rankOf() - 1; i++) shapeI[i] = input->sizeAt(i); + shapeI[input->rankOf() - 1] = k; + std::unique_ptr indices( + NDArrayFactory::create_(input->ordering(), shapeI, context)); + NDArray* values = nullptr; + int status = topKFunctor(context, input, values, indices.get(), k, true); + result->assign(0); + if (status == ND4J_STATUS_OK) { + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + bool found = false; + for (uint j = 0; j < k; j++) { + if (target->e(e) == indices->e(e * k + j)) { + found = true; + break; + } + } + if (found) result->p(e, true); + } + }; - BUILD_SINGLE_TEMPLATE(template int topKFunctor_, (const NDArray* input, NDArray* values, NDArray* indices, const uint k, bool needSort), NUMERIC_TYPES); - BUILD_SINGLE_TEMPLATE(template int inTopKFunctor_, (sd::LaunchContext * context, const NDArray* input, const NDArray* target, NDArray* result, const uint k), NUMERIC_TYPES); + samediff::Threads::parallel_tad(func, 0, target->lengthOf()); + } + return status; } + +int topKFunctor(sd::LaunchContext* context, const NDArray* input, + NDArray* values, NDArray* indices, const uint k, + bool needSort) { + BUILD_SINGLE_SELECTOR(input->dataType(), return topKFunctor_, + (input, values, indices, k, needSort), NUMERIC_TYPES); } + +int inTopKFunctor(sd::LaunchContext* context, const NDArray* input, + const NDArray* target, NDArray* result, const uint k) { + BUILD_SINGLE_SELECTOR(input->dataType(), return inTopKFunctor_, + (context, input, target, result, k), NUMERIC_TYPES); } + +BUILD_SINGLE_TEMPLATE(template int topKFunctor_, + (const NDArray* input, NDArray* values, NDArray* indices, + const uint k, bool needSort), + NUMERIC_TYPES); +BUILD_SINGLE_TEMPLATE(template int inTopKFunctor_, + (sd::LaunchContext * context, const NDArray* input, + const NDArray* target, NDArray* result, const uint k), + NUMERIC_TYPES); +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/trace.cpp b/libnd4j/include/ops/declarable/helpers/cpu/trace.cpp index d544fa24eea6..d6353bd72ff5 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/trace.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/trace.cpp @@ -18,30 +18,30 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 20.04.2018 // - -#include #include +#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { ////////////////////////////////////////////////////////////////////////// template static void trace_(const NDArray& input, NDArray& output) { - const int inRank = input.rankOf(); - auto setOfSubArrs = input.allTensorsAlongDimension({inRank-2, inRank-1}); + const int inRank = input.rankOf(); + auto setOfSubArrs = input.allTensorsAlongDimension({inRank - 2, inRank - 1}); - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) - output.p(i, setOfSubArrs.at(i)->getTrace()); - }; - samediff::Threads::parallel_for(func, 0, setOfSubArrs.size()); + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) + output.p(i, setOfSubArrs.at(i).getTrace()); + }; + samediff::Threads::parallel_for(func, 0, setOfSubArrs.size()); } - void trace(sd::LaunchContext * context, const NDArray& input, NDArray& output) { - BUILD_SINGLE_SELECTOR(input.dataType(), trace_, (input, output), LIBND4J_TYPES); - } -} -} +void trace(sd::LaunchContext* context, const NDArray& input, NDArray& output) { + BUILD_SINGLE_SELECTOR(input.dataType(), trace_, (input, output), + LIBND4J_TYPES); } +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/triangular_solve.cpp b/libnd4j/include/ops/declarable/helpers/cpu/triangular_solve.cpp index 86847da16ba9..e9b598acfce6 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/triangular_solve.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/triangular_solve.cpp @@ -17,71 +17,76 @@ // // @author GS // -#include +#include "../triangular_solve.h" + #include #include -#include "../triangular_solve.h" +#include namespace sd { namespace ops { namespace helpers { - /* - * lower triangular process for system of linear equations - * x_1 = b_1/a_1,1 - * x_2 = (b_2 - a_2,1 * x_1) / a_2,2 - * x_3 = (b_3 - a_3,1 * x_1 - a_3,2 * x_2) / a_3,3 - * ... - * x_M = (b_M - a_M,1 * x_1 - ... a_M,M-1 * x_M-1)/ a_M,M - * - * output == x - * a == leftInput - * b == rightInput - * - * */ - template - static void lowerTriangularSolve(sd::LaunchContext * context, NDArray const * leftInput, NDArray const* rightInput, bool const unitsOnDiag, NDArray* output) { - auto rows = leftInput->rows(); - auto cols = rightInput->columns(); - //output->r(0,0) = rightInput->t(0,0) / leftInput->t(0,0); - for (Nd4jLong r = 0; r < rows; r++) { - for (Nd4jLong j = 0; j < cols; j++) { - auto sum = rightInput->t(r, j); - for (Nd4jLong c = 0; c < r; c++) { - sum -= leftInput->t(r, c) * output->t(c, j); - } - output->r(r, j) = unitsOnDiag?sum: sum / leftInput->t(r, r); - } - } +/* + * lower triangular process for system of linear equations + * x_1 = b_1/a_1,1 + * x_2 = (b_2 - a_2,1 * x_1) / a_2,2 + * x_3 = (b_3 - a_3,1 * x_1 - a_3,2 * x_2) / a_3,3 + * ... + * x_M = (b_M - a_M,1 * x_1 - ... a_M,M-1 * x_M-1)/ a_M,M + * + * output == x + * a == leftInput + * b == rightInput + * + * */ +template +static void lowerTriangularSolve(sd::LaunchContext* context, NDArray const * leftInput, + NDArray const* rightInput, bool const unitsOnDiag, + NDArray* output) { + auto rows = leftInput->rows(); + auto cols = rightInput->columns(); + // output->r(0,0) = rightInput->t(0,0) / leftInput->t(0,0); + for (Nd4jLong r = 0; r < rows; r++) { + for (Nd4jLong j = 0; j < cols; j++) { + auto sum = rightInput->t(r, j); + for (Nd4jLong c = 0; c < r; c++) { + sum -= leftInput->t(r, c) * output->t(c, j); + } + output->r(r, j) = unitsOnDiag?sum: sum / leftInput->t(r, r); } + } +} - /* - * upper triangular process for system of linear equations - * x_M = b_M/a_M,M - * x_M-1 = (b_M-1 - a_M-1,M-2 * x_M) / a_M-1,M-1 - * x_M-2 = (b_M-2 - a_M-2,M-3 * x_M-2 - a_M-2,M-1 * x_M) / a_3,3 - * ... - * x_1 = (b_1 - a_1,2 * x_2 - ... a_1,M * x_M)/ a_1,1 - * - * output == x - * a == leftInput - * b == rightInput - * - * */ +/* + * upper triangular process for system of linear equations + * x_M = b_M/a_M,M + * x_M-1 = (b_M-1 - a_M-1,M-2 * x_M) / a_M-1,M-1 + * x_M-2 = (b_M-2 - a_M-2,M-3 * x_M-2 - a_M-2,M-1 * x_M) / a_3,3 + * ... + * x_1 = (b_1 - a_1,2 * x_2 - ... a_1,M * x_M)/ a_1,1 + * + * output == x + * a == leftInput + * b == rightInput + * + * */ - template - static void upperTriangularSolve(sd::LaunchContext* context, NDArray const* leftInput, NDArray const* rightInput, bool const unitsOnDiag, NDArray* output) { - auto rows = leftInput->rows(); - auto cols = rightInput->columns(); - for (Nd4jLong r = rows; r > 0; r--) { - for (Nd4jLong j = 0; j < cols; j++) { - auto sum = rightInput->t(r - 1, j); - for (Nd4jLong c = r; c < rows; c++) { - sum -= leftInput->t(r - 1, c) * output->t(c, j); - } - output->r(r - 1, j) = unitsOnDiag? sum : sum / leftInput->t(r - 1, r - 1); - } - } +template +static void upperTriangularSolve(sd::LaunchContext* context, NDArray const* leftInput, + NDArray const* rightInput, bool const unitsOnDiag, + NDArray* output) { + auto rows = leftInput->rows(); + auto cols = rightInput->columns(); + for (Nd4jLong r = rows; r > 0; r--) { + for (Nd4jLong j = 0; j < cols; j++) { + auto sum = rightInput->t(r - 1, j); + for (Nd4jLong c = r; c < rows; c++) { + sum -= leftInput->t(r - 1, c) * output->t(c, j); + } + output->r(r - 1, j) = unitsOnDiag? sum : sum / leftInput->t(r - 1, r - 1); } + } +} /// triangularSolve2D - 2D implementation of triangularSolveFunctor /// \tparam T - type of NDArray output @@ -103,46 +108,51 @@ namespace helpers { } BUILD_SINGLE_TEMPLATE(template void triangularSolve2D, (sd::LaunchContext* context, NDArray const& leftInput, NDArray const& rightInput, bool const lower, bool const unitsOnDiag, NDArray& output), FLOAT_TYPES); - template - static int triangularSolveFunctor_(sd::LaunchContext * context, NDArray* leftInput, NDArray* rightInput, bool lower, bool adjoint, NDArray* output) { - auto leftPart = leftInput->allTensorsAlongDimension({-2, -1}); - auto rightPart = rightInput->allTensorsAlongDimension({-2, -1}); - auto outputPart = output->allTensorsAlongDimension({-2, -1}); +template +static int triangularSolveFunctor_(sd::LaunchContext* context, + NDArray* leftInput, NDArray* rightInput, + bool lower, bool adjoint, NDArray* output) { + auto leftPart = leftInput->allTensorsAlongDimension({-2, -1}); + auto rightPart = rightInput->allTensorsAlongDimension({-2, -1}); + auto outputPart = output->allTensorsAlongDimension({-2, -1}); - auto batchLoop = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - if (lower) { - lowerTriangularSolve(context, leftPart[i], rightPart[i], false, outputPart[i]); - } else { - upperTriangularSolve(context, leftPart[i], rightPart[i], false, outputPart[i]); - } - } - }; - - samediff::Threads::parallel_tad(batchLoop, 0, leftPart.size(), 1); + auto batchLoop = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + if (lower) { + lowerTriangularSolve(context, &leftPart[i], &rightPart[i], false, + &outputPart[i]); + } else { + upperTriangularSolve(context, &leftPart[i], &rightPart[i], false, + &outputPart[i]); + } + } + }; - return Status::OK(); + samediff::Threads::parallel_tad(batchLoop, 0, leftPart.size(), 1); - } - template - static void adjointTriangularMatrix_(sd::LaunchContext* context, NDArray const* input, bool const lower, NDArray* output) { - auto inputPart = input->allTensorsAlongDimension({-2, -1}); - auto outputPart = output->allTensorsAlongDimension({-2, -1}); - auto cols = input->sizeAt(-1); - auto rows = input->sizeAt(-2); + return Status::OK(); +} +template +static void adjointTriangularMatrix_(sd::LaunchContext* context, + NDArray const* input, bool const lower, + NDArray* output) { + auto inputPart = input->allTensorsAlongDimension({-2, -1}); + auto outputPart = output->allTensorsAlongDimension({-2, -1}); + auto cols = input->sizeAt(-1); + auto rows = input->sizeAt(-2); auto batchLoop = PRAGMA_THREADS_FOR { for (auto batch = start; batch < stop; batch++) { if (!lower) { for (Nd4jLong r = 0; r < rows; r++) { for (Nd4jLong c = 0; c <= r; c++) { - outputPart[batch]->r(r, c) = inputPart[batch]->t(c, r); + outputPart[batch].r(r, c) = inputPart[batch].t(c, r); } } } else { for (Nd4jLong r = 0; r < rows; r++) { for (Nd4jLong c = r; c < cols; c++) { - outputPart[batch]->r(r, c) = inputPart[batch]->t(c, r); + outputPart[batch].r(r, c) = inputPart[batch].t(c, r); } } } @@ -151,13 +161,19 @@ namespace helpers { samediff::Threads::parallel_tad(batchLoop, 0, inputPart.size(), 1); } - int triangularSolveFunctor(sd::LaunchContext * context, NDArray* leftInput, NDArray* rightInput, bool lower, bool adjoint, NDArray* output) { - BUILD_SINGLE_SELECTOR(leftInput->dataType(), return triangularSolveFunctor_, (context, leftInput, rightInput, lower, adjoint, output), FLOAT_NATIVE); - } - - void adjointMatrix(sd::LaunchContext* context, NDArray const* input, bool const lower, NDArray* output) { - BUILD_SINGLE_SELECTOR(input->dataType(), adjointTriangularMatrix_, (context, input, lower, output), FLOAT_NATIVE); - } -} +int triangularSolveFunctor(sd::LaunchContext* context, NDArray* leftInput, + NDArray* rightInput, bool lower, bool adjoint, + NDArray* output) { + BUILD_SINGLE_SELECTOR( + leftInput->dataType(), return triangularSolveFunctor_, + (context, leftInput, rightInput, lower, adjoint, output), FLOAT_NATIVE); } + +void adjointMatrix(sd::LaunchContext* context, NDArray const* input, + bool const lower, NDArray* output) { + BUILD_SINGLE_SELECTOR(input->dataType(), adjointTriangularMatrix_, + (context, input, lower, output), FLOAT_NATIVE); } +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/triu.cpp b/libnd4j/include/ops/declarable/helpers/cpu/triu.cpp index eb2074865e65..e46f350efb38 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/triu.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/triu.cpp @@ -18,39 +18,41 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 20.04.2018 // - -#include #include +#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { - ////////////////////////////////////////////////////////////////////////// template -static void triuBP_(sd::LaunchContext * context, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int diagonal) { - - auto dOdI = NDArray(&gradO); // dO/dI - const_cast(input).fillAsTriangular(0, diagonal, dOdI.sizeAt(-1), dOdI, 'b'); - int dLen = dOdI.lengthOf(); - - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - if (dOdI.t(i) != static_cast(0.f)) - dOdI.r(i) = static_cast(1.f); - } - }; - samediff::Threads::parallel_for(func, 0, dLen); - - // FIXME: !!! - gradI.assign(dOdI * gradO); // chain rule: dLoss/dI = dO/dI * dLoss/dO -} - - void triuBP(sd::LaunchContext * context, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int diagonal) { - BUILD_SINGLE_SELECTOR(gradO.dataType(), triuBP_, (context, input, gradO, gradI, diagonal), LIBND4J_TYPES); +static void triuBP_(sd::LaunchContext* context, const NDArray& input, + const NDArray& gradO, NDArray& gradI, const int diagonal) { + auto dOdI = NDArray(&gradO); // dO/dI + const_cast(input).fillAsTriangular(0, diagonal, dOdI.sizeAt(-1), + dOdI, 'b'); + int dLen = dOdI.lengthOf(); + + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + if (dOdI.t(i) != static_cast(0.f)) + dOdI.r(i) = static_cast(1.f); } + }; + samediff::Threads::parallel_for(func, 0, dLen); + // FIXME: !!! + gradI.assign(dOdI * gradO); // chain rule: dLoss/dI = dO/dI * dLoss/dO } + +void triuBP(sd::LaunchContext* context, const NDArray& input, + const NDArray& gradO, NDArray& gradI, const int diagonal) { + BUILD_SINGLE_SELECTOR(gradO.dataType(), triuBP_, + (context, input, gradO, gradI, diagonal), + LIBND4J_TYPES); } -} + +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/updaterAdaDelta.cpp b/libnd4j/include/ops/declarable/helpers/cpu/updaterAdaDelta.cpp index 78268b2dc12e..118c8fba8f89 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/updaterAdaDelta.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/updaterAdaDelta.cpp @@ -18,10 +18,10 @@ // @author Oleh Semeniv (oleg.semeniv@gmail.com) // -#include #include #include #include +#include namespace sd { namespace ops { @@ -29,80 +29,106 @@ namespace helpers { ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// template -static void adaDeltaUpdater_(const NDArray& gradient, const NDArray& initStateMsg, const NDArray& initStateMsdx, - NDArray& update, NDArray& stateMsg, NDArray& stateMsdx, const double dRho, const double dEpsilon) { - - const T* grad = gradient.bufferAsT(); - const T* initMsg = initStateMsg.bufferAsT(); - const T* initMsdx = initStateMsdx.bufferAsT(); - - T* up = update.bufferAsT(); - T* stMsg = stateMsg.bufferAsT(); - T* stMsdx = stateMsdx.bufferAsT(); - - const T rho = static_cast(dRho); - const T epsilon = static_cast(dEpsilon); - const T rhoT = (1 - rho); - - bool bEws1 = 1 == gradient.ews() && 1 == update.ews() && 1 == stateMsg.ews() && 1 == initStateMsg.ews() && 1 == stateMsdx.ews() && 1 == initStateMsdx.ews(); - bool bSameOrdering = gradient.ordering() == update.ordering() && - update.ordering() == stateMsdx.ordering() && - stateMsdx.ordering() == initStateMsdx.ordering() && - stateMsdx.ordering() == initStateMsg.ordering() && stateMsg.ordering() == initStateMsg.ordering(); - - if (bEws1 && bSameOrdering) { - - auto func = PRAGMA_THREADS_FOR{ - for (auto i = start; i < stop; i++) { - stMsg[i] = rho * initMsg[i] + grad[i] * grad[i] * rhoT; - - up[i] = grad[i] * (sd::math::nd4j_sqrt(initMsdx[i] + epsilon) / sd::math::nd4j_sqrt(stMsg[i] + epsilon)); - - stMsdx[i] = rho * initMsdx[i] + up[i] * up[i] * rhoT; - } - }; - - samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1); - return; - } - - - bool bXZsame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), update.shapeInfo()); - bool bXInMsgSame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), initStateMsg.shapeInfo()); - bool bXStMsgSame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), stateMsg.shapeInfo()); - bool bXInMsdxSame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), initStateMsdx.shapeInfo()); - bool bXStMsdxSame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), stateMsdx.shapeInfo()); - - auto func = PRAGMA_THREADS_FOR{ - - int coords[MAX_RANK]; - for (auto i = start; i < gradient.lengthOf(); i++) { - shape::index2coordsCPU(start, i, gradient.shapeInfo(), coords); - const auto xOffset = shape::getOffset(gradient.shapeInfo(), coords); - const auto zOffset = bXZsame ? xOffset : shape::getOffset(update.shapeInfo(), coords); - const auto initMsgOffset = bXInMsgSame ? xOffset : shape::getOffset(initStateMsg.shapeInfo(), coords); - const auto stMsgOffset = bXStMsgSame ? xOffset : shape::getOffset(stateMsg.shapeInfo(), coords); - const auto initMsdxOffset = bXInMsdxSame ? xOffset : shape::getOffset(initStateMsdx.shapeInfo(), coords); - const auto stMsdxOffset = bXStMsdxSame ? xOffset : shape::getOffset(stateMsdx.shapeInfo(), coords); - - - stMsg[stMsgOffset] = rho * initMsg[initMsgOffset] + grad[xOffset] * grad[xOffset] * rhoT; - - up[zOffset] = grad[xOffset] * (sd::math::nd4j_sqrt(initMsdx[initMsdxOffset] + epsilon) / sd::math::nd4j_sqrt(stMsg[stMsgOffset] + epsilon)); - - stMsdx[stMsdxOffset] = rho * initMsdx[initMsdxOffset] + up[zOffset] * up[zOffset] * rhoT; - } +static void adaDeltaUpdater_(const NDArray& gradient, + const NDArray& initStateMsg, + const NDArray& initStateMsdx, NDArray& update, + NDArray& stateMsg, NDArray& stateMsdx, + const double dRho, const double dEpsilon) { + const T* grad = gradient.bufferAsT(); + const T* initMsg = initStateMsg.bufferAsT(); + const T* initMsdx = initStateMsdx.bufferAsT(); + + T* up = update.bufferAsT(); + T* stMsg = stateMsg.bufferAsT(); + T* stMsdx = stateMsdx.bufferAsT(); + + const T rho = static_cast(dRho); + const T epsilon = static_cast(dEpsilon); + const T rhoT = (1 - rho); + + bool bEws1 = 1 == gradient.ews() && 1 == update.ews() && + 1 == stateMsg.ews() && 1 == initStateMsg.ews() && + 1 == stateMsdx.ews() && 1 == initStateMsdx.ews(); + bool bSameOrdering = gradient.ordering() == update.ordering() && + update.ordering() == stateMsdx.ordering() && + stateMsdx.ordering() == initStateMsdx.ordering() && + stateMsdx.ordering() == initStateMsg.ordering() && + stateMsg.ordering() == initStateMsg.ordering(); + + if (bEws1 && bSameOrdering) { + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + stMsg[i] = rho * initMsg[i] + grad[i] * grad[i] * rhoT; + + up[i] = grad[i] * (sd::math::nd4j_sqrt(initMsdx[i] + epsilon) / + sd::math::nd4j_sqrt(stMsg[i] + epsilon)); + + stMsdx[i] = rho * initMsdx[i] + up[i] * up[i] * rhoT; + } }; samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1); return; -} + } + + bool bXZsame = + shape::haveSameShapeAndStrides(gradient.shapeInfo(), update.shapeInfo()); + bool bXInMsgSame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), + initStateMsg.shapeInfo()); + bool bXStMsgSame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), + stateMsg.shapeInfo()); + bool bXInMsdxSame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), + initStateMsdx.shapeInfo()); + bool bXStMsdxSame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), + stateMsdx.shapeInfo()); + + auto func = PRAGMA_THREADS_FOR { + int coords[MAX_RANK]; + for (auto i = start; i < gradient.lengthOf(); i++) { + shape::index2coordsCPU(start, i, gradient.shapeInfo(), coords); + const auto xOffset = shape::getOffset(gradient.shapeInfo(), coords); + const auto zOffset = + bXZsame ? xOffset : shape::getOffset(update.shapeInfo(), coords); + const auto initMsgOffset = + bXInMsgSame ? xOffset + : shape::getOffset(initStateMsg.shapeInfo(), coords); + const auto stMsgOffset = + bXStMsgSame ? xOffset + : shape::getOffset(stateMsg.shapeInfo(), coords); + const auto initMsdxOffset = + bXInMsdxSame ? xOffset + : shape::getOffset(initStateMsdx.shapeInfo(), coords); + const auto stMsdxOffset = + bXStMsdxSame ? xOffset + : shape::getOffset(stateMsdx.shapeInfo(), coords); + + stMsg[stMsgOffset] = + rho * initMsg[initMsgOffset] + grad[xOffset] * grad[xOffset] * rhoT; + + up[zOffset] = + grad[xOffset] * + (sd::math::nd4j_sqrt(initMsdx[initMsdxOffset] + epsilon) / + sd::math::nd4j_sqrt(stMsg[stMsgOffset] + epsilon)); + + stMsdx[stMsdxOffset] = + rho * initMsdx[initMsdxOffset] + up[zOffset] * up[zOffset] * rhoT; + } + }; -void updaterAdaDelta(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateMsg, const NDArray& initStateMsdx, - NDArray& update, NDArray& stateMsg, NDArray& stateMsdx, const double dRho, const double dEpsilon) { - BUILD_SINGLE_SELECTOR(gradient.dataType(), adaDeltaUpdater_, (gradient, initStateMsg, initStateMsdx, update, stateMsg, stateMsdx, dRho, dEpsilon), FLOAT_TYPES); + samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1); + return; } +void updaterAdaDelta(sd::LaunchContext* context, const NDArray& gradient, + const NDArray& initStateMsg, const NDArray& initStateMsdx, + NDArray& update, NDArray& stateMsg, NDArray& stateMsdx, + const double dRho, const double dEpsilon) { + BUILD_SINGLE_SELECTOR(gradient.dataType(), adaDeltaUpdater_, + (gradient, initStateMsg, initStateMsdx, update, + stateMsg, stateMsdx, dRho, dEpsilon), + FLOAT_TYPES); } -} -} + +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/updaterAdaGrad.cpp b/libnd4j/include/ops/declarable/helpers/cpu/updaterAdaGrad.cpp index e65f34e72bf6..9f80f59dd666 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/updaterAdaGrad.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/updaterAdaGrad.cpp @@ -18,10 +18,10 @@ // @author Oleh Semeniv (oleg.semeniv@gmail.com) // -#include #include #include #include +#include namespace sd { namespace ops { @@ -29,63 +29,75 @@ namespace helpers { ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// template -static void adaGradUpdater_(const NDArray& gradient, const NDArray& initState, NDArray& update, NDArray& stateH, const double dLr, const double dEpsilon) { - - const T* grad = gradient.bufferAsT(); - const T* init = initState.bufferAsT(); - - T* up = update.bufferAsT(); - T* st = stateH.bufferAsT(); - - const T lr = static_cast(dLr); - const T epsilon = static_cast(dEpsilon); - - bool bEws1 = 1 == gradient.ews() && 1 == update.ews() && 1 == stateH.ews() && 1 == initState.ews(); - bool bSameOrdering = gradient.ordering() == update.ordering() && update.ordering() == stateH.ordering() && stateH.ordering() == initState.ordering(); - - if (bEws1 && bSameOrdering) { - - auto func = PRAGMA_THREADS_FOR{ - for (auto i = start; i < stop; i++) { - st[i] = init[i] + grad[i] * grad[i]; - up[i] = (lr * grad[i]) / (math::nd4j_sqrt(st[i]) + epsilon); - } - }; - - samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1); - return; - } - - bool bXZsame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), update.shapeInfo()); - bool bXInSame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), initState.shapeInfo()); - bool bXStSame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), stateH.shapeInfo()); - - auto func = PRAGMA_THREADS_FOR{ - - int coords[MAX_RANK]; - for (auto i = start; i < stop; i++) { - shape::index2coordsCPU(start, i, gradient.shapeInfo(), coords); - - const auto xOffset = shape::getOffset(gradient.shapeInfo(), coords); - - const auto zOffset = bXZsame ? xOffset : shape::getOffset(update.shapeInfo(), coords); - const auto initOffset = bXInSame ? xOffset : shape::getOffset(initState.shapeInfo(), coords); - const auto stOffset = bXStSame ? xOffset : shape::getOffset(stateH.shapeInfo(), coords); - - st[stOffset] = init[initOffset] + grad[xOffset] * grad[xOffset]; - up[zOffset] = (lr * grad[xOffset]) / (math::nd4j_sqrt(st[stOffset]) + epsilon); - } +static void adaGradUpdater_(const NDArray& gradient, const NDArray& initState, + NDArray& update, NDArray& stateH, const double dLr, + const double dEpsilon) { + const T* grad = gradient.bufferAsT(); + const T* init = initState.bufferAsT(); + + T* up = update.bufferAsT(); + T* st = stateH.bufferAsT(); + + const T lr = static_cast(dLr); + const T epsilon = static_cast(dEpsilon); + + bool bEws1 = 1 == gradient.ews() && 1 == update.ews() && 1 == stateH.ews() && + 1 == initState.ews(); + bool bSameOrdering = gradient.ordering() == update.ordering() && + update.ordering() == stateH.ordering() && + stateH.ordering() == initState.ordering(); + + if (bEws1 && bSameOrdering) { + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + st[i] = init[i] + grad[i] * grad[i]; + up[i] = (lr * grad[i]) / (math::nd4j_sqrt(st[i]) + epsilon); + } }; samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1); return; -} + } + + bool bXZsame = + shape::haveSameShapeAndStrides(gradient.shapeInfo(), update.shapeInfo()); + bool bXInSame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), + initState.shapeInfo()); + bool bXStSame = + shape::haveSameShapeAndStrides(gradient.shapeInfo(), stateH.shapeInfo()); + + auto func = PRAGMA_THREADS_FOR { + int coords[MAX_RANK]; + for (auto i = start; i < stop; i++) { + shape::index2coordsCPU(start, i, gradient.shapeInfo(), coords); + + const auto xOffset = shape::getOffset(gradient.shapeInfo(), coords); + + const auto zOffset = + bXZsame ? xOffset : shape::getOffset(update.shapeInfo(), coords); + const auto initOffset = + bXInSame ? xOffset : shape::getOffset(initState.shapeInfo(), coords); + const auto stOffset = + bXStSame ? xOffset : shape::getOffset(stateH.shapeInfo(), coords); + + st[stOffset] = init[initOffset] + grad[xOffset] * grad[xOffset]; + up[zOffset] = (lr * grad[xOffset]) / + (math::nd4j_sqrt(st[stOffset]) + epsilon); + } + }; -void updaterAdaGrad(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initState, NDArray& update, NDArray& stateH, - const double dLr, const double dEpsilon) { - BUILD_SINGLE_SELECTOR(gradient.dataType(), adaGradUpdater_, (gradient, initState, update, stateH, dLr, dEpsilon), FLOAT_TYPES); + samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1); + return; } +void updaterAdaGrad(sd::LaunchContext* context, const NDArray& gradient, + const NDArray& initState, NDArray& update, NDArray& stateH, + const double dLr, const double dEpsilon) { + BUILD_SINGLE_SELECTOR(gradient.dataType(), adaGradUpdater_, + (gradient, initState, update, stateH, dLr, dEpsilon), + FLOAT_TYPES); } -} -} + +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/updaterAdaMax.cpp b/libnd4j/include/ops/declarable/helpers/cpu/updaterAdaMax.cpp index 6c7d0d3225f8..315f5997ae3f 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/updaterAdaMax.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/updaterAdaMax.cpp @@ -18,10 +18,10 @@ // @author Oleh Semeniv (oleg.semeniv@gmail.com) // -#include #include #include #include +#include namespace sd { namespace ops { @@ -29,85 +29,112 @@ namespace helpers { ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// template -static void adaMaxUpdater_(const NDArray& gradient, const NDArray& initStateU, const NDArray& initStateM, NDArray& update, NDArray& stateU, NDArray& stateM, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration) { - - const T* grad = gradient.bufferAsT(); - const T* initU = initStateU.bufferAsT(); - const T* initM = initStateM.bufferAsT(); - - T* up = update.bufferAsT(); - T* stU = stateU.bufferAsT(); - T* stM = stateM.bufferAsT(); - - const T lr = static_cast(dLr); - const T beta1 = static_cast(dBeta1); - const T beta2 = static_cast(dBeta2); - const T epsilon = static_cast(dEpsilon); - const T iteration = static_cast(nIteration); - const T beta1T = sd::math::nd4j_pow(beta1, (iteration + 1)); - T epsilonT = lr / (1.0 - beta1T); - if (sd::math::nd4j_isnan(epsilonT) || 0 == epsilonT || sd::math::nd4j_isinf(epsilonT)) - epsilonT = epsilon; - - - bool bEws1 = 1 == gradient.ews() && 1 == update.ews() && 1 == stateM.ews() && 1 == initStateM.ews() && 1 == stateU.ews() && 1 == initStateU.ews(); - bool bSameOrdering = gradient.ordering() == update.ordering() && - update.ordering() == stateU.ordering() && - stateU.ordering() == initStateU.ordering() && - stateU.ordering() == initStateM.ordering() && stateM.ordering() == initStateM.ordering(); - - if (bEws1 && bSameOrdering) { - - auto func = PRAGMA_THREADS_FOR{ - for (auto i = start; i < stop; i++) { - //m = B_1 * m + (1-B_1)*grad - stM[i] = beta1 * initM[i] + grad[i] * (1 - beta1); - //u = max(B_2 * u, |grad|) - stU[i] = sd::math::nd4j_max((beta2 * initU[i]), sd::math::nd4j_abs(grad[i])) + 1e-32; - - up[i] = stM[i] * epsilonT / stU[i]; - } - }; - - samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1); - return; - } - - bool bXZsame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), update.shapeInfo()); - bool bXInVSame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), initStateU.shapeInfo()); - bool bXStVSame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), stateU.shapeInfo()); - bool bXInMSame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), initStateM.shapeInfo()); - bool bXStMSame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), stateM.shapeInfo()); - - auto func = PRAGMA_THREADS_FOR{ - - int coords[MAX_RANK]; - for (auto i = start; i < stop; i++) { - shape::index2coordsCPU(start, i, gradient.shapeInfo(), coords); - const auto xOffset = shape::getOffset(gradient.shapeInfo(), coords); - const auto zOffset = bXZsame ? xOffset : shape::getOffset(update.shapeInfo(), coords); - const auto initUOffset = bXInVSame ? xOffset : shape::getOffset(initStateU.shapeInfo(), coords); - const auto stUOffset = bXStVSame ? xOffset : shape::getOffset(stateU.shapeInfo(), coords); - const auto initMOffset = bXInMSame ? xOffset : shape::getOffset(initStateM.shapeInfo(), coords); - const auto stMOffset = bXStMSame ? xOffset : shape::getOffset(stateM.shapeInfo(), coords); - - //m = B_1 * m + (1-B_1)*grad - stM[stMOffset] = beta1 * initM[initMOffset] + grad[xOffset] * (1 - beta1); - //u = max(B_2 * u, |grad|) - stU[stUOffset] = sd::math::nd4j_max((beta2 * initU[initUOffset]), sd::math::nd4j_abs(grad[xOffset])) + 1e-32; - - up[zOffset] = stM[stMOffset] * epsilonT / stU[stUOffset]; - } +static void adaMaxUpdater_(const NDArray& gradient, const NDArray& initStateU, + const NDArray& initStateM, NDArray& update, + NDArray& stateU, NDArray& stateM, const double dLr, + const double dBeta1, const double dBeta2, + const double dEpsilon, const int nIteration) { + const T* grad = gradient.bufferAsT(); + const T* initU = initStateU.bufferAsT(); + const T* initM = initStateM.bufferAsT(); + + T* up = update.bufferAsT(); + T* stU = stateU.bufferAsT(); + T* stM = stateM.bufferAsT(); + + const T lr = static_cast(dLr); + const T beta1 = static_cast(dBeta1); + const T beta2 = static_cast(dBeta2); + const T epsilon = static_cast(dEpsilon); + const T iteration = static_cast(nIteration); + const T beta1T = sd::math::nd4j_pow(beta1, (iteration + 1)); + T epsilonT = lr / (1.0 - beta1T); + if (sd::math::nd4j_isnan(epsilonT) || 0 == epsilonT || + sd::math::nd4j_isinf(epsilonT)) + epsilonT = epsilon; + + bool bEws1 = 1 == gradient.ews() && 1 == update.ews() && 1 == stateM.ews() && + 1 == initStateM.ews() && 1 == stateU.ews() && + 1 == initStateU.ews(); + bool bSameOrdering = gradient.ordering() == update.ordering() && + update.ordering() == stateU.ordering() && + stateU.ordering() == initStateU.ordering() && + stateU.ordering() == initStateM.ordering() && + stateM.ordering() == initStateM.ordering(); + + if (bEws1 && bSameOrdering) { + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + // m = B_1 * m + (1-B_1)*grad + stM[i] = beta1 * initM[i] + grad[i] * (1 - beta1); + // u = max(B_2 * u, |grad|) + stU[i] = sd::math::nd4j_max((beta2 * initU[i]), + sd::math::nd4j_abs(grad[i])) + + 1e-32; + + up[i] = stM[i] * epsilonT / stU[i]; + } }; samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1); return; -} + } + + bool bXZsame = + shape::haveSameShapeAndStrides(gradient.shapeInfo(), update.shapeInfo()); + bool bXInVSame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), + initStateU.shapeInfo()); + bool bXStVSame = + shape::haveSameShapeAndStrides(gradient.shapeInfo(), stateU.shapeInfo()); + bool bXInMSame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), + initStateM.shapeInfo()); + bool bXStMSame = + shape::haveSameShapeAndStrides(gradient.shapeInfo(), stateM.shapeInfo()); + + auto func = PRAGMA_THREADS_FOR { + int coords[MAX_RANK]; + for (auto i = start; i < stop; i++) { + shape::index2coordsCPU(start, i, gradient.shapeInfo(), coords); + const auto xOffset = shape::getOffset(gradient.shapeInfo(), coords); + const auto zOffset = + bXZsame ? xOffset : shape::getOffset(update.shapeInfo(), coords); + const auto initUOffset = + bXInVSame ? xOffset + : shape::getOffset(initStateU.shapeInfo(), coords); + const auto stUOffset = + bXStVSame ? xOffset : shape::getOffset(stateU.shapeInfo(), coords); + const auto initMOffset = + bXInMSame ? xOffset + : shape::getOffset(initStateM.shapeInfo(), coords); + const auto stMOffset = + bXStMSame ? xOffset : shape::getOffset(stateM.shapeInfo(), coords); + + // m = B_1 * m + (1-B_1)*grad + stM[stMOffset] = beta1 * initM[initMOffset] + grad[xOffset] * (1 - beta1); + // u = max(B_2 * u, |grad|) + stU[stUOffset] = sd::math::nd4j_max((beta2 * initU[initUOffset]), + sd::math::nd4j_abs(grad[xOffset])) + + 1e-32; + + up[zOffset] = stM[stMOffset] * epsilonT / stU[stUOffset]; + } + }; -void updaterAdaMax(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateU, const NDArray& initStateM, NDArray& update, NDArray& stateU, NDArray& stateM, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration) { - BUILD_SINGLE_SELECTOR(gradient.dataType(), adaMaxUpdater_, (gradient, initStateU, initStateM, update, stateU, stateM, dLr, dBeta1, dBeta2, dEpsilon, nIteration), FLOAT_TYPES); + samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1); + return; } +void updaterAdaMax(sd::LaunchContext* context, const NDArray& gradient, + const NDArray& initStateU, const NDArray& initStateM, + NDArray& update, NDArray& stateU, NDArray& stateM, + const double dLr, const double dBeta1, const double dBeta2, + const double dEpsilon, const int nIteration) { + BUILD_SINGLE_SELECTOR(gradient.dataType(), adaMaxUpdater_, + (gradient, initStateU, initStateM, update, stateU, + stateM, dLr, dBeta1, dBeta2, dEpsilon, nIteration), + FLOAT_TYPES); } -} -} + +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/updaterAdam.cpp b/libnd4j/include/ops/declarable/helpers/cpu/updaterAdam.cpp index 2d670949f748..2d2e5d1fe2a2 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/updaterAdam.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/updaterAdam.cpp @@ -18,10 +18,10 @@ // @author Oleh Semeniv (oleg.semeniv@gmail.com) // -#include #include #include #include +#include namespace sd { namespace ops { @@ -29,85 +29,110 @@ namespace helpers { ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// template -static void adamUpdater_(const NDArray& gradient, const NDArray& initStateU, const NDArray& initStateM, NDArray& update, - NDArray& stateU, NDArray& stateM, const double dLr, const double dBeta1, const double dBeta2, - const double dEpsilon, const int nIteration) { - - const T* grad = gradient.bufferAsT(); - const T* initU = initStateU.bufferAsT(); - const T* initM = initStateM.bufferAsT(); - - T* up = update.bufferAsT(); - T* stU = stateU.bufferAsT(); - T* stM = stateM.bufferAsT(); - - const T lr = static_cast(dLr); - const T beta1 = static_cast(dBeta1); - const T beta2 = static_cast(dBeta2); - const T epsilon = static_cast(dEpsilon); - const T iteration = static_cast(nIteration); - - const T beta1T = sd::math::nd4j_pow(beta1, (iteration + 1)); - const T beta2T = sd::math::nd4j_pow(beta2, (iteration + 1)); - - T epsilonT = lr * sd::math::nd4j_sqrt(1. - beta2T) / (1.0 - beta1T); - if (sd::math::nd4j_isnan(epsilonT) || 0 == epsilonT || sd::math::nd4j_isinf(epsilonT)) - epsilonT = epsilon; - - bool bEws1 = 1 == gradient.ews() && 1 == update.ews() && 1 == stateM.ews() && 1 == initStateM.ews() && 1 == stateU.ews() && 1 == initStateU.ews(); - bool bSameOrdering = gradient.ordering() == update.ordering() && - update.ordering() == stateU.ordering() && - stateU.ordering() == initStateU.ordering() && - stateU.ordering() == initStateM.ordering() && stateM.ordering() == initStateM.ordering(); - - if (bEws1 && bSameOrdering) { - - auto func = PRAGMA_THREADS_FOR{ - for (auto i = start; i < stop; i++) { - stM[i] = beta1 * initM[i] + grad[i] * (1 - beta1); - stU[i] = beta2 * initU[i] + grad[i] * grad[i] * (1 - beta2); - - up[i] = (stM[i] * epsilonT) / (sd::math::nd4j_sqrt(stU[i]) + epsilon); - } - }; - - samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1); - return; - } - - bool bXZsame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), update.shapeInfo()); - bool bXInVSame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), initStateU.shapeInfo()); - bool bXStVSame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), stateU.shapeInfo()); - bool bXInMSame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), initStateM.shapeInfo()); - bool bXStMSame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), stateM.shapeInfo()); - - auto func = PRAGMA_THREADS_FOR{ - - int coords[MAX_RANK]; - for (auto i = start; i < stop; i++) { - shape::index2coordsCPU(start, i, gradient.shapeInfo(), coords); - const auto xOffset = shape::getOffset(gradient.shapeInfo(), coords); - const auto zOffset = bXZsame ? xOffset : shape::getOffset(update.shapeInfo(), coords); - const auto initUOffset = bXInVSame ? xOffset : shape::getOffset(initStateU.shapeInfo(), coords); - const auto stUOffset = bXStVSame ? xOffset : shape::getOffset(stateU.shapeInfo(), coords); - const auto initMOffset = bXInVSame ? xOffset : shape::getOffset(initStateM.shapeInfo(), coords); - const auto stMOffset = bXStMSame ? xOffset : shape::getOffset(stateM.shapeInfo(), coords); - - stM[stMOffset] = beta1 * initM[initMOffset] + grad[xOffset] * (1 - beta1); - stU[stUOffset] = beta2 * initU[initUOffset] + grad[xOffset] * grad[xOffset] * (1 - beta2); - - up[zOffset] = (stM[stMOffset] * epsilonT) / (sd::math::nd4j_sqrt(stU[stUOffset]) + epsilon); - } +static void adamUpdater_(const NDArray& gradient, const NDArray& initStateU, + const NDArray& initStateM, NDArray& update, + NDArray& stateU, NDArray& stateM, const double dLr, + const double dBeta1, const double dBeta2, + const double dEpsilon, const int nIteration) { + const T* grad = gradient.bufferAsT(); + const T* initU = initStateU.bufferAsT(); + const T* initM = initStateM.bufferAsT(); + + T* up = update.bufferAsT(); + T* stU = stateU.bufferAsT(); + T* stM = stateM.bufferAsT(); + + const T lr = static_cast(dLr); + const T beta1 = static_cast(dBeta1); + const T beta2 = static_cast(dBeta2); + const T epsilon = static_cast(dEpsilon); + const T iteration = static_cast(nIteration); + + const T beta1T = sd::math::nd4j_pow(beta1, (iteration + 1)); + const T beta2T = sd::math::nd4j_pow(beta2, (iteration + 1)); + + T epsilonT = lr * sd::math::nd4j_sqrt(1. - beta2T) / (1.0 - beta1T); + if (sd::math::nd4j_isnan(epsilonT) || 0 == epsilonT || + sd::math::nd4j_isinf(epsilonT)) + epsilonT = epsilon; + + bool bEws1 = 1 == gradient.ews() && 1 == update.ews() && 1 == stateM.ews() && + 1 == initStateM.ews() && 1 == stateU.ews() && + 1 == initStateU.ews(); + bool bSameOrdering = gradient.ordering() == update.ordering() && + update.ordering() == stateU.ordering() && + stateU.ordering() == initStateU.ordering() && + stateU.ordering() == initStateM.ordering() && + stateM.ordering() == initStateM.ordering(); + + if (bEws1 && bSameOrdering) { + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + stM[i] = beta1 * initM[i] + grad[i] * (1 - beta1); + stU[i] = beta2 * initU[i] + grad[i] * grad[i] * (1 - beta2); + + up[i] = + (stM[i] * epsilonT) / (sd::math::nd4j_sqrt(stU[i]) + epsilon); + } }; samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1); return; -} + } + + bool bXZsame = + shape::haveSameShapeAndStrides(gradient.shapeInfo(), update.shapeInfo()); + bool bXInVSame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), + initStateU.shapeInfo()); + bool bXStVSame = + shape::haveSameShapeAndStrides(gradient.shapeInfo(), stateU.shapeInfo()); + bool bXInMSame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), + initStateM.shapeInfo()); + bool bXStMSame = + shape::haveSameShapeAndStrides(gradient.shapeInfo(), stateM.shapeInfo()); + + auto func = PRAGMA_THREADS_FOR { + int coords[MAX_RANK]; + for (auto i = start; i < stop; i++) { + shape::index2coordsCPU(start, i, gradient.shapeInfo(), coords); + const auto xOffset = shape::getOffset(gradient.shapeInfo(), coords); + const auto zOffset = + bXZsame ? xOffset : shape::getOffset(update.shapeInfo(), coords); + const auto initUOffset = + bXInVSame ? xOffset + : shape::getOffset(initStateU.shapeInfo(), coords); + const auto stUOffset = + bXStVSame ? xOffset : shape::getOffset(stateU.shapeInfo(), coords); + const auto initMOffset = + bXInVSame ? xOffset + : shape::getOffset(initStateM.shapeInfo(), coords); + const auto stMOffset = + bXStMSame ? xOffset : shape::getOffset(stateM.shapeInfo(), coords); + + stM[stMOffset] = beta1 * initM[initMOffset] + grad[xOffset] * (1 - beta1); + stU[stUOffset] = beta2 * initU[initUOffset] + + grad[xOffset] * grad[xOffset] * (1 - beta2); + + up[zOffset] = (stM[stMOffset] * epsilonT) / + (sd::math::nd4j_sqrt(stU[stUOffset]) + epsilon); + } + }; -void updaterAdam(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateU, const NDArray& initStateM, NDArray& update, NDArray& stateU, NDArray& stateM, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration) { - BUILD_SINGLE_SELECTOR(gradient.dataType(), adamUpdater_, (gradient, initStateU, initStateM, update, stateU, stateM, dLr, dBeta1, dBeta2, dEpsilon, nIteration), FLOAT_TYPES); + samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1); + return; } +void updaterAdam(sd::LaunchContext* context, const NDArray& gradient, + const NDArray& initStateU, const NDArray& initStateM, + NDArray& update, NDArray& stateU, NDArray& stateM, + const double dLr, const double dBeta1, const double dBeta2, + const double dEpsilon, const int nIteration) { + BUILD_SINGLE_SELECTOR(gradient.dataType(), adamUpdater_, + (gradient, initStateU, initStateM, update, stateU, + stateM, dLr, dBeta1, dBeta2, dEpsilon, nIteration), + FLOAT_TYPES); } -} -} + +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/updaterAmsGrad.cpp b/libnd4j/include/ops/declarable/helpers/cpu/updaterAmsGrad.cpp index 7cb05075c24b..cf84bf49bf66 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/updaterAmsGrad.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/updaterAmsGrad.cpp @@ -18,10 +18,10 @@ // @author Oleh Semeniv (oleg.semeniv@gmail.com) // -#include #include #include #include +#include namespace sd { namespace ops { @@ -29,98 +29,134 @@ namespace helpers { ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// template -static void amsGradUpdater_(const NDArray& gradient, const NDArray& initStateV, const NDArray& initStateM, const NDArray& initStateH, - NDArray& update, NDArray& stateV, NDArray& stateM, NDArray& stateH, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration) { - - const T* grad = gradient.bufferAsT(); - const T* initV = initStateV.bufferAsT(); - const T* initM = initStateM.bufferAsT(); - const T* initH = initStateH.bufferAsT(); - - T* up = update.bufferAsT(); - T* stV = stateV.bufferAsT(); - T* stM = stateM.bufferAsT(); - T* stH = stateH.bufferAsT(); - - const T lr = static_cast(dLr); - const T beta1 = static_cast(dBeta1); - const T beta2 = static_cast(dBeta2); - const T epsilon = static_cast(dEpsilon); - const T iteration = static_cast(nIteration); - - T epsilonT = lr * sd::math::nd4j_sqrt(1.0 - sd::math::nd4j_pow(beta2, (iteration + 1))) / (1.0 - sd::math::nd4j_pow(beta1, (iteration + 1))); - - if (sd::math::nd4j_isnan(epsilonT) || 0 == epsilonT || sd::math::nd4j_isinf(epsilonT)) - epsilonT = epsilon; - - const T mbeta1 = (1 - beta1); - const T mbeta2 = (1 - beta2); - - bool bEws1 = 1 == gradient.ews() && 1 == update.ews() && 1 == stateM.ews() && 1 == initStateM.ews() && - 1 == stateV.ews() && 1 == initStateV.ews() && 1 == stateH.ews() && 1 == initStateH.ews(); - bool bSameOrdering = gradient.ordering() == update.ordering() && - update.ordering() == stateV.ordering() && - stateV.ordering() == initStateV.ordering() && - stateV.ordering() == initStateM.ordering() && - stateM.ordering() == initStateM.ordering() && - stateM.ordering() == initStateH.ordering() && stateH.ordering() == initStateH.ordering(); - - if (bEws1 && bSameOrdering) { - - auto func = PRAGMA_THREADS_FOR{ - for (auto i = start; i < stop; i++) { - stM[i] = beta1 * initM[i] + grad[i] * mbeta1; - stV[i] = beta2 * initV[i] + grad[i] * grad[i] * mbeta2; - stH[i] = sd::math::nd4j_max(initH[i], stV[i]); - - up[i] = epsilonT * stM[i] / (sd::math::nd4j_sqrt(stH[i]) + epsilon); - } - }; - - samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1); - return; - } - - bool bXZsame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), update.shapeInfo()); - bool bXInVSame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), initStateV.shapeInfo()); - bool bXStVSame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), stateV.shapeInfo()); - bool bXInMSame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), initStateM.shapeInfo()); - bool bXStMSame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), stateM.shapeInfo()); - bool bXInHSame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), initStateH.shapeInfo()); - bool bXStHSame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), stateH.shapeInfo()); - - auto func = PRAGMA_THREADS_FOR{ - - int coords[MAX_RANK]; - for (auto i = start; i < stop; i++) { - shape::index2coordsCPU(start, i, gradient.shapeInfo(), coords); - const auto xOffset = shape::getOffset(gradient.shapeInfo(), coords); - const auto zOffset = bXZsame ? xOffset : shape::getOffset(update.shapeInfo(), coords); - const auto initVOffset = bXInVSame ? xOffset : shape::getOffset(initStateV.shapeInfo(), coords); - const auto stVOffset = bXStVSame ? xOffset : shape::getOffset(stateV.shapeInfo(), coords); - const auto initMOffset = bXInMSame ? xOffset : shape::getOffset(initStateM.shapeInfo(), coords); - const auto stMOffset = bXStMSame ? xOffset : shape::getOffset(stateM.shapeInfo(), coords); - const auto initHOffset = bXInHSame ? xOffset : shape::getOffset(initStateH.shapeInfo(), coords); - const auto stHOffset = bXStHSame ? xOffset : shape::getOffset(stateH.shapeInfo(), coords); - - stM[stMOffset] = beta1 * initM[initMOffset] + grad[xOffset] * mbeta1; - stV[stVOffset] = beta2 * initV[initVOffset] + grad[xOffset] * grad[xOffset] * mbeta2; - stH[stHOffset] = sd::math::nd4j_max(initH[initHOffset], stV[stVOffset]); - - up[zOffset] = epsilonT * stM[stMOffset] / (sd::math::nd4j_sqrt(stH[stHOffset]) + epsilon); - } +static void amsGradUpdater_(const NDArray& gradient, const NDArray& initStateV, + const NDArray& initStateM, + const NDArray& initStateH, NDArray& update, + NDArray& stateV, NDArray& stateM, NDArray& stateH, + const double dLr, const double dBeta1, + const double dBeta2, const double dEpsilon, + const int nIteration) { + const T* grad = gradient.bufferAsT(); + const T* initV = initStateV.bufferAsT(); + const T* initM = initStateM.bufferAsT(); + const T* initH = initStateH.bufferAsT(); + + T* up = update.bufferAsT(); + T* stV = stateV.bufferAsT(); + T* stM = stateM.bufferAsT(); + T* stH = stateH.bufferAsT(); + + const T lr = static_cast(dLr); + const T beta1 = static_cast(dBeta1); + const T beta2 = static_cast(dBeta2); + const T epsilon = static_cast(dEpsilon); + const T iteration = static_cast(nIteration); + + T epsilonT = lr * + sd::math::nd4j_sqrt( + 1.0 - sd::math::nd4j_pow(beta2, (iteration + 1))) / + (1.0 - sd::math::nd4j_pow(beta1, (iteration + 1))); + + if (sd::math::nd4j_isnan(epsilonT) || 0 == epsilonT || + sd::math::nd4j_isinf(epsilonT)) + epsilonT = epsilon; + + const T mbeta1 = (1 - beta1); + const T mbeta2 = (1 - beta2); + + bool bEws1 = 1 == gradient.ews() && 1 == update.ews() && 1 == stateM.ews() && + 1 == initStateM.ews() && 1 == stateV.ews() && + 1 == initStateV.ews() && 1 == stateH.ews() && + 1 == initStateH.ews(); + bool bSameOrdering = gradient.ordering() == update.ordering() && + update.ordering() == stateV.ordering() && + stateV.ordering() == initStateV.ordering() && + stateV.ordering() == initStateM.ordering() && + stateM.ordering() == initStateM.ordering() && + stateM.ordering() == initStateH.ordering() && + stateH.ordering() == initStateH.ordering(); + + if (bEws1 && bSameOrdering) { + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + stM[i] = beta1 * initM[i] + grad[i] * mbeta1; + stV[i] = beta2 * initV[i] + grad[i] * grad[i] * mbeta2; + stH[i] = sd::math::nd4j_max(initH[i], stV[i]); + + up[i] = + epsilonT * stM[i] / (sd::math::nd4j_sqrt(stH[i]) + epsilon); + } }; samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1); return; -} + } + + bool bXZsame = + shape::haveSameShapeAndStrides(gradient.shapeInfo(), update.shapeInfo()); + bool bXInVSame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), + initStateV.shapeInfo()); + bool bXStVSame = + shape::haveSameShapeAndStrides(gradient.shapeInfo(), stateV.shapeInfo()); + bool bXInMSame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), + initStateM.shapeInfo()); + bool bXStMSame = + shape::haveSameShapeAndStrides(gradient.shapeInfo(), stateM.shapeInfo()); + bool bXInHSame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), + initStateH.shapeInfo()); + bool bXStHSame = + shape::haveSameShapeAndStrides(gradient.shapeInfo(), stateH.shapeInfo()); + + auto func = PRAGMA_THREADS_FOR { + int coords[MAX_RANK]; + for (auto i = start; i < stop; i++) { + shape::index2coordsCPU(start, i, gradient.shapeInfo(), coords); + const auto xOffset = shape::getOffset(gradient.shapeInfo(), coords); + const auto zOffset = + bXZsame ? xOffset : shape::getOffset(update.shapeInfo(), coords); + const auto initVOffset = + bXInVSame ? xOffset + : shape::getOffset(initStateV.shapeInfo(), coords); + const auto stVOffset = + bXStVSame ? xOffset : shape::getOffset(stateV.shapeInfo(), coords); + const auto initMOffset = + bXInMSame ? xOffset + : shape::getOffset(initStateM.shapeInfo(), coords); + const auto stMOffset = + bXStMSame ? xOffset : shape::getOffset(stateM.shapeInfo(), coords); + const auto initHOffset = + bXInHSame ? xOffset + : shape::getOffset(initStateH.shapeInfo(), coords); + const auto stHOffset = + bXStHSame ? xOffset : shape::getOffset(stateH.shapeInfo(), coords); + + stM[stMOffset] = beta1 * initM[initMOffset] + grad[xOffset] * mbeta1; + stV[stVOffset] = + beta2 * initV[initVOffset] + grad[xOffset] * grad[xOffset] * mbeta2; + stH[stHOffset] = sd::math::nd4j_max(initH[initHOffset], stV[stVOffset]); + + up[zOffset] = epsilonT * stM[stMOffset] / + (sd::math::nd4j_sqrt(stH[stHOffset]) + epsilon); + } + }; -void updaterAmsGrad(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateV, const NDArray& initStateM, const NDArray& initStateH, - NDArray& update, NDArray& stateV, NDArray& stateM, NDArray& stateH, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration) { - BUILD_SINGLE_SELECTOR(gradient.dataType(), amsGradUpdater_, (gradient, initStateV, initStateM, initStateH, update, stateV, stateM, stateH, dLr, dBeta1, dBeta2, dEpsilon, nIteration), FLOAT_TYPES); + samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1); + return; } - -} -} +void updaterAmsGrad(sd::LaunchContext* context, const NDArray& gradient, + const NDArray& initStateV, const NDArray& initStateM, + const NDArray& initStateH, NDArray& update, NDArray& stateV, + NDArray& stateM, NDArray& stateH, const double dLr, + const double dBeta1, const double dBeta2, + const double dEpsilon, const int nIteration) { + BUILD_SINGLE_SELECTOR( + gradient.dataType(), amsGradUpdater_, + (gradient, initStateV, initStateM, initStateH, update, stateV, stateM, + stateH, dLr, dBeta1, dBeta2, dEpsilon, nIteration), + FLOAT_TYPES); } + +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/updaterNadam.cpp b/libnd4j/include/ops/declarable/helpers/cpu/updaterNadam.cpp index 40f9c9407789..5d5278be2aff 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/updaterNadam.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/updaterNadam.cpp @@ -18,10 +18,10 @@ // @author Oleh Semeniv (oleg.semeniv@gmail.com) // -#include #include #include #include +#include namespace sd { namespace ops { @@ -29,88 +29,111 @@ namespace helpers { ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// template -static void nadamUpdater_(const NDArray& gradient, const NDArray& initStateV, const NDArray& initStateM, - NDArray& update, NDArray& stateV, NDArray& stateM, const double dLr, const double dBeta1, - const double dBeta2, const double dEpsilon, const int nIteration) { - - const T* grad = gradient.bufferAsT(); - const T* initV = initStateV.bufferAsT(); - const T* initM = initStateM.bufferAsT(); - - T* up = update.bufferAsT(); - T* stV = stateV.bufferAsT(); - T* stM = stateM.bufferAsT(); - - const T lr = static_cast(dLr); - const T beta1 = static_cast(dBeta1); - const T beta2 = static_cast(dBeta2); - const T epsilon = static_cast(dEpsilon); - const T iteration = static_cast(nIteration); - - const T mbeta1T = 1.0 - sd::math::nd4j_pow(beta1, (iteration + 1)); - const T mbeta1 = (1 - beta1); - const T mbeta2 = (1 - beta2); - - bool bEws1 = 1 == gradient.ews() && 1 == update.ews() && 1 == stateM.ews() && 1 == initStateM.ews() && 1 == stateV.ews() && 1 == initStateV.ews(); - bool bSameOrdering = gradient.ordering() == update.ordering() && - update.ordering() == stateV.ordering() && - stateV.ordering() == initStateV.ordering() && - stateV.ordering() == initStateM.ordering() && stateM.ordering() == initStateM.ordering(); - - if (bEws1 && bSameOrdering) { - - auto func = PRAGMA_THREADS_FOR{ - for (auto i = start; i < stop; i++) { - auto oneMinusBeta1Grad = grad[i] * mbeta1; - - stM[i] = beta1 * initM[i] + oneMinusBeta1Grad; - stV[i] = beta2 * initV[i] + grad[i] * grad[i] * mbeta2; - - up[i] = (lr * ((stM[i] * beta1 + oneMinusBeta1Grad) / mbeta1T)) / (sd::math::nd4j_sqrt(stV[i]) + epsilon); - } - }; - - samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1); - return; - } - - bool bXZsame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), update.shapeInfo()); - bool bXInVSame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), initStateV.shapeInfo()); - bool bXStVSame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), stateV.shapeInfo()); - bool bXInMSame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), initStateM.shapeInfo()); - bool bXStMSame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), stateM.shapeInfo()); - - auto func = PRAGMA_THREADS_FOR{ - - int coords[MAX_RANK]; - for (auto i = start; i < stop; i++) { - shape::index2coordsCPU(start, i, gradient.shapeInfo(), coords); - const auto xOffset = shape::getOffset(gradient.shapeInfo(), coords); - const auto zOffset = bXZsame ? xOffset : shape::getOffset(update.shapeInfo(), coords); - const auto initVOffset = bXInVSame ? xOffset : shape::getOffset(initStateV.shapeInfo(), coords); - const auto stVOffset = bXStVSame ? xOffset : shape::getOffset(stateV.shapeInfo(), coords); - const auto initMOffset = bXInMSame ? xOffset : shape::getOffset(initStateM.shapeInfo(), coords); - const auto stMOffset = bXStMSame ? xOffset : shape::getOffset(stateM.shapeInfo(), coords); - - auto oneMinusBeta1Grad = grad[xOffset] * mbeta1; - - stM[stMOffset] = beta1 * initM[initMOffset] + oneMinusBeta1Grad; - stV[stVOffset] = beta2 * initV[initVOffset] + grad[xOffset] * grad[xOffset] * mbeta2; - - up[zOffset] = (lr * ((stM[stMOffset] * beta1 + oneMinusBeta1Grad) / mbeta1T)) / (sd::math::nd4j_sqrt(stV[stVOffset]) + epsilon); - } +static void nadamUpdater_(const NDArray& gradient, const NDArray& initStateV, + const NDArray& initStateM, NDArray& update, + NDArray& stateV, NDArray& stateM, const double dLr, + const double dBeta1, const double dBeta2, + const double dEpsilon, const int nIteration) { + const T* grad = gradient.bufferAsT(); + const T* initV = initStateV.bufferAsT(); + const T* initM = initStateM.bufferAsT(); + + T* up = update.bufferAsT(); + T* stV = stateV.bufferAsT(); + T* stM = stateM.bufferAsT(); + + const T lr = static_cast(dLr); + const T beta1 = static_cast(dBeta1); + const T beta2 = static_cast(dBeta2); + const T epsilon = static_cast(dEpsilon); + const T iteration = static_cast(nIteration); + + const T mbeta1T = 1.0 - sd::math::nd4j_pow(beta1, (iteration + 1)); + const T mbeta1 = (1 - beta1); + const T mbeta2 = (1 - beta2); + + bool bEws1 = 1 == gradient.ews() && 1 == update.ews() && 1 == stateM.ews() && + 1 == initStateM.ews() && 1 == stateV.ews() && + 1 == initStateV.ews(); + bool bSameOrdering = gradient.ordering() == update.ordering() && + update.ordering() == stateV.ordering() && + stateV.ordering() == initStateV.ordering() && + stateV.ordering() == initStateM.ordering() && + stateM.ordering() == initStateM.ordering(); + + if (bEws1 && bSameOrdering) { + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + auto oneMinusBeta1Grad = grad[i] * mbeta1; + + stM[i] = beta1 * initM[i] + oneMinusBeta1Grad; + stV[i] = beta2 * initV[i] + grad[i] * grad[i] * mbeta2; + + up[i] = (lr * ((stM[i] * beta1 + oneMinusBeta1Grad) / mbeta1T)) / + (sd::math::nd4j_sqrt(stV[i]) + epsilon); + } }; samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1); return; -} + } + + bool bXZsame = + shape::haveSameShapeAndStrides(gradient.shapeInfo(), update.shapeInfo()); + bool bXInVSame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), + initStateV.shapeInfo()); + bool bXStVSame = + shape::haveSameShapeAndStrides(gradient.shapeInfo(), stateV.shapeInfo()); + bool bXInMSame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), + initStateM.shapeInfo()); + bool bXStMSame = + shape::haveSameShapeAndStrides(gradient.shapeInfo(), stateM.shapeInfo()); + + auto func = PRAGMA_THREADS_FOR { + int coords[MAX_RANK]; + for (auto i = start; i < stop; i++) { + shape::index2coordsCPU(start, i, gradient.shapeInfo(), coords); + const auto xOffset = shape::getOffset(gradient.shapeInfo(), coords); + const auto zOffset = + bXZsame ? xOffset : shape::getOffset(update.shapeInfo(), coords); + const auto initVOffset = + bXInVSame ? xOffset + : shape::getOffset(initStateV.shapeInfo(), coords); + const auto stVOffset = + bXStVSame ? xOffset : shape::getOffset(stateV.shapeInfo(), coords); + const auto initMOffset = + bXInMSame ? xOffset + : shape::getOffset(initStateM.shapeInfo(), coords); + const auto stMOffset = + bXStMSame ? xOffset : shape::getOffset(stateM.shapeInfo(), coords); + + auto oneMinusBeta1Grad = grad[xOffset] * mbeta1; + + stM[stMOffset] = beta1 * initM[initMOffset] + oneMinusBeta1Grad; + stV[stVOffset] = + beta2 * initV[initVOffset] + grad[xOffset] * grad[xOffset] * mbeta2; + + up[zOffset] = + (lr * ((stM[stMOffset] * beta1 + oneMinusBeta1Grad) / mbeta1T)) / + (sd::math::nd4j_sqrt(stV[stVOffset]) + epsilon); + } + }; -void updaterNadam(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateV, const NDArray& initStateM, - NDArray& update, NDArray& stateV, NDArray& stateM, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration) { - BUILD_SINGLE_SELECTOR(gradient.dataType(), nadamUpdater_, (gradient, initStateV, initStateM, update, stateV, stateM, dLr, dBeta1, dBeta2, dEpsilon, nIteration), FLOAT_TYPES); + samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1); + return; } - -} -} +void updaterNadam(sd::LaunchContext* context, const NDArray& gradient, + const NDArray& initStateV, const NDArray& initStateM, + NDArray& update, NDArray& stateV, NDArray& stateM, + const double dLr, const double dBeta1, const double dBeta2, + const double dEpsilon, const int nIteration) { + BUILD_SINGLE_SELECTOR(gradient.dataType(), nadamUpdater_, + (gradient, initStateV, initStateM, update, stateV, + stateM, dLr, dBeta1, dBeta2, dEpsilon, nIteration), + FLOAT_TYPES); } + +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/updaterNesterovs.cpp b/libnd4j/include/ops/declarable/helpers/cpu/updaterNesterovs.cpp index 1d8bb8d45fb8..68b81826d881 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/updaterNesterovs.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/updaterNesterovs.cpp @@ -18,10 +18,10 @@ // @author Oleh Semeniv (oleg.semeniv@gmail.com) // -#include #include #include #include +#include namespace sd { namespace ops { @@ -29,63 +29,76 @@ namespace helpers { ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// template -static void nesterovsUpdater_(const NDArray& gradient, const NDArray& initState, NDArray& update, NDArray& stateV, const double dLr, const double dMomentum) { +static void nesterovsUpdater_(const NDArray& gradient, const NDArray& initState, + NDArray& update, NDArray& stateV, + const double dLr, const double dMomentum) { + const T* grad = gradient.bufferAsT(); + const T* init = initState.bufferAsT(); - const T* grad = gradient.bufferAsT(); - const T* init = initState.bufferAsT(); + T* up = update.bufferAsT(); + T* st = stateV.bufferAsT(); - T* up = update.bufferAsT(); - T* st = stateV.bufferAsT(); + const T lr = static_cast(dLr); + const T momentum = static_cast(dMomentum); + const T momentumT = (-momentum - 1); - const T lr = static_cast(dLr); - const T momentum = static_cast(dMomentum); - const T momentumT = (-momentum - 1); + bool bEws1 = 1 == gradient.ews() && 1 == update.ews() && 1 == stateV.ews() && + 1 == initState.ews(); + bool bSameOrdering = gradient.ordering() == update.ordering() && + update.ordering() == stateV.ordering() && + stateV.ordering() == initState.ordering(); - bool bEws1 = 1 == gradient.ews() && 1 == update.ews() && 1 == stateV.ews() && 1 == initState.ews(); - bool bSameOrdering = gradient.ordering() == update.ordering() && update.ordering() == stateV.ordering() && stateV.ordering() == initState.ordering(); + if (bEws1 && bSameOrdering) { + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + T prevState = momentum * init[i]; + st[i] = prevState - lr * grad[i]; + up[i] = prevState + momentumT * st[i]; + } + }; - if (bEws1 && bSameOrdering) { - - auto func = PRAGMA_THREADS_FOR{ - for (auto i = start; i < stop; i++) { - T prevState = momentum * init[i]; - st[i] = prevState - lr * grad[i]; - up[i] = prevState + momentumT * st[i]; - } - }; + samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1); + return; + } - samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1); - return; - } - - bool bXZsame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), update.shapeInfo()); - bool bXInSame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), initState.shapeInfo()); - bool bXStSame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), stateV.shapeInfo()); + bool bXZsame = + shape::haveSameShapeAndStrides(gradient.shapeInfo(), update.shapeInfo()); + bool bXInSame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), + initState.shapeInfo()); + bool bXStSame = + shape::haveSameShapeAndStrides(gradient.shapeInfo(), stateV.shapeInfo()); - auto func = PRAGMA_THREADS_FOR{ + auto func = PRAGMA_THREADS_FOR { + int coords[MAX_RANK]; + for (auto i = start; i < stop; i++) { + shape::index2coordsCPU(start, i, gradient.shapeInfo(), coords); + const auto xOffset = shape::getOffset(gradient.shapeInfo(), coords); + const auto zOffset = + bXZsame ? xOffset : shape::getOffset(update.shapeInfo(), coords); + const auto initOffset = + bXInSame ? xOffset : shape::getOffset(initState.shapeInfo(), coords); + const auto stOffset = + bXStSame ? xOffset : shape::getOffset(stateV.shapeInfo(), coords); - int coords[MAX_RANK]; - for (auto i = start; i < stop; i++) { - shape::index2coordsCPU(start, i, gradient.shapeInfo(), coords); - const auto xOffset = shape::getOffset(gradient.shapeInfo(), coords); - const auto zOffset = bXZsame ? xOffset : shape::getOffset(update.shapeInfo(), coords); - const auto initOffset = bXInSame ? xOffset : shape::getOffset(initState.shapeInfo(), coords); - const auto stOffset = bXStSame ? xOffset : shape::getOffset(stateV.shapeInfo(), coords); - - T prevState = momentum * init[initOffset]; - st[stOffset] = prevState - lr * grad[xOffset]; - up[zOffset] = prevState + momentumT * st[stOffset]; - } - }; + T prevState = momentum * init[initOffset]; + st[stOffset] = prevState - lr * grad[xOffset]; + up[zOffset] = prevState + momentumT * st[stOffset]; + } + }; - samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1); - return; + samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1); + return; } -void updaterNesterovs(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initState, NDArray& update, NDArray& stateV, const double dLr, const double dMomentum) { - BUILD_SINGLE_SELECTOR(gradient.dataType(), nesterovsUpdater_, (gradient, initState, update, stateV, dLr, dMomentum), FLOAT_TYPES); +void updaterNesterovs(sd::LaunchContext* context, const NDArray& gradient, + const NDArray& initState, NDArray& update, + NDArray& stateV, const double dLr, + const double dMomentum) { + BUILD_SINGLE_SELECTOR(gradient.dataType(), nesterovsUpdater_, + (gradient, initState, update, stateV, dLr, dMomentum), + FLOAT_TYPES); } -} -} -} +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/updaterRmsProp.cpp b/libnd4j/include/ops/declarable/helpers/cpu/updaterRmsProp.cpp index 473b43cf8c03..8e2ed9d88837 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/updaterRmsProp.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/updaterRmsProp.cpp @@ -18,74 +18,87 @@ // @author Oleh Semeniv (oleg.semeniv@gmail.com) // -#include #include #include #include +#include namespace sd { namespace ops { namespace helpers { template -static void rmsPropUpdater_(const NDArray& gradient, const NDArray& initState, NDArray& update, NDArray& stateG, - const double dLr, const double dRmsDecay, const double dEpsilon) { - - const T* grad = gradient.bufferAsT(); - const T* init = initState.bufferAsT(); +static void rmsPropUpdater_(const NDArray& gradient, const NDArray& initState, + NDArray& update, NDArray& stateG, const double dLr, + const double dRmsDecay, const double dEpsilon) { + const T* grad = gradient.bufferAsT(); + const T* init = initState.bufferAsT(); - T* up = update.bufferAsT(); - T* st = stateG.bufferAsT(); - - const T lr = static_cast(dLr); - const T rmsDecay = static_cast(dRmsDecay); - const T epsilon = static_cast(dEpsilon); - - bool bEws1 = 1 == gradient.ews() && 1 == update.ews() && 1 == stateG.ews() && 1 == initState.ews(); - bool bSameOrdering = gradient.ordering() == update.ordering() && update.ordering() == stateG.ordering() && stateG.ordering() == initState.ordering(); + T* up = update.bufferAsT(); + T* st = stateG.bufferAsT(); - if (bEws1 && bSameOrdering) { - - auto func = PRAGMA_THREADS_FOR{ - for (auto i = start; i < stop; i++) { - st[i] = init[i] * rmsDecay + grad[i] * grad[i] * (1 - rmsDecay) ; - up[i] = (lr * grad[i]) / ( math::nd4j_sqrt(st[i]) + epsilon); - } - }; + const T lr = static_cast(dLr); + const T rmsDecay = static_cast(dRmsDecay); + const T epsilon = static_cast(dEpsilon); - samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1); - return; - } - - bool bXZsame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), update.shapeInfo()); - bool bXInSame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), initState.shapeInfo()); - bool bXStSame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), stateG.shapeInfo()); - - auto func = PRAGMA_THREADS_FOR{ + bool bEws1 = 1 == gradient.ews() && 1 == update.ews() && 1 == stateG.ews() && + 1 == initState.ews(); + bool bSameOrdering = gradient.ordering() == update.ordering() && + update.ordering() == stateG.ordering() && + stateG.ordering() == initState.ordering(); - int coords[MAX_RANK]; - for (auto i = start; i < stop; i++) { - shape::index2coordsCPU(start, i, gradient.shapeInfo(), coords); - const auto xOffset = shape::getOffset(gradient.shapeInfo(), coords); - const auto zOffset = bXZsame ? xOffset : shape::getOffset(update.shapeInfo(), coords); - const auto initOffset = bXInSame ? xOffset : shape::getOffset(initState.shapeInfo(), coords); - const auto stOffset = bXStSame ? xOffset : shape::getOffset(stateG.shapeInfo(), coords); - - st[stOffset] = init[initOffset] * rmsDecay + grad[xOffset] * grad[xOffset] * (1 - rmsDecay) ; - up[zOffset] = (lr * grad[xOffset]) / ( math::nd4j_sqrt(st[stOffset]) + epsilon); - } + if (bEws1 && bSameOrdering) { + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + st[i] = init[i] * rmsDecay + grad[i] * grad[i] * (1 - rmsDecay); + up[i] = (lr * grad[i]) / (math::nd4j_sqrt(st[i]) + epsilon); + } }; samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1); return; -} + } -void updaterRmsProp(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initState, NDArray& update, NDArray& stateG, - const double dLr, const double dRmsDecay, const double dEpsilon) { - BUILD_SINGLE_SELECTOR(gradient.dataType(), rmsPropUpdater_, (gradient, initState, update, stateG, dLr, dRmsDecay, dEpsilon), FLOAT_TYPES); -} + bool bXZsame = + shape::haveSameShapeAndStrides(gradient.shapeInfo(), update.shapeInfo()); + bool bXInSame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), + initState.shapeInfo()); + bool bXStSame = + shape::haveSameShapeAndStrides(gradient.shapeInfo(), stateG.shapeInfo()); + auto func = PRAGMA_THREADS_FOR { + int coords[MAX_RANK]; + for (auto i = start; i < stop; i++) { + shape::index2coordsCPU(start, i, gradient.shapeInfo(), coords); + const auto xOffset = shape::getOffset(gradient.shapeInfo(), coords); + const auto zOffset = + bXZsame ? xOffset : shape::getOffset(update.shapeInfo(), coords); + const auto initOffset = + bXInSame ? xOffset : shape::getOffset(initState.shapeInfo(), coords); + const auto stOffset = + bXStSame ? xOffset : shape::getOffset(stateG.shapeInfo(), coords); + st[stOffset] = init[initOffset] * rmsDecay + + grad[xOffset] * grad[xOffset] * (1 - rmsDecay); + up[zOffset] = (lr * grad[xOffset]) / + (math::nd4j_sqrt(st[stOffset]) + epsilon); + } + }; + + samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1); + return; } + +void updaterRmsProp(sd::LaunchContext* context, const NDArray& gradient, + const NDArray& initState, NDArray& update, NDArray& stateG, + const double dLr, const double dRmsDecay, + const double dEpsilon) { + BUILD_SINGLE_SELECTOR( + gradient.dataType(), rmsPropUpdater_, + (gradient, initState, update, stateG, dLr, dRmsDecay, dEpsilon), + FLOAT_TYPES); } -} + +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/weights.cpp b/libnd4j/include/ops/declarable/helpers/cpu/weights.cpp index ebdfc674b196..da95201836c2 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/weights.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/weights.cpp @@ -24,24 +24,31 @@ namespace sd { namespace ops { namespace helpers { - template - static void adjustWeights_(NDArray* input, NDArray* weights, NDArray* output, int minLength, int maxLength) { - for (Nd4jLong e = 0; e < input->lengthOf(); e++) { - int val = input->e(e); - if (val < maxLength) { - if (weights != nullptr) - output->p(val, output->e(val) + weights->e(e)); - else - output->p(val, output->e(val) + 1); - } - } +template +static void adjustWeights_(NDArray* input, NDArray* weights, NDArray* output, + int minLength, int maxLength) { + for (Nd4jLong e = 0; e < input->lengthOf(); e++) { + int val = input->e(e); + if (val < maxLength) { + if (weights != nullptr) + output->p(val, output->e(val) + weights->e(e)); + else + output->p(val, output->e(val) + 1); } - - void adjustWeights(sd::LaunchContext * context, NDArray* input, NDArray* weights, NDArray* output, int minLength, int maxLength) { - BUILD_SINGLE_SELECTOR(output->dataType(), adjustWeights_, (input, weights, output, minLength, maxLength), LIBND4J_TYPES); - } - - BUILD_SINGLE_TEMPLATE(template void adjustWeights_, (NDArray* input, NDArray* weights, NDArray* output, int minLength, int maxLength), LIBND4J_TYPES); + } } + +void adjustWeights(sd::LaunchContext* context, NDArray* input, NDArray* weights, + NDArray* output, int minLength, int maxLength) { + BUILD_SINGLE_SELECTOR(output->dataType(), adjustWeights_, + (input, weights, output, minLength, maxLength), + LIBND4J_TYPES); } -} \ No newline at end of file + +BUILD_SINGLE_TEMPLATE(template void adjustWeights_, + (NDArray * input, NDArray* weights, NDArray* output, + int minLength, int maxLength), + LIBND4J_TYPES); +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/zeta.cpp b/libnd4j/include/ops/declarable/helpers/cpu/zeta.cpp index d127fc1665c9..e4ca9ff0fe6a 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/zeta.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/zeta.cpp @@ -18,66 +18,66 @@ // Created by Yurii Shyrma on 12.12.2017 // -#include #include +#include namespace sd { namespace ops { namespace helpers { -const int maxIter = 1000000; // max number of loop iterations +const int maxIter = 1000000; // max number of loop iterations ////////////////////////////////////////////////////////////////////////// // slow implementation template static FORCEINLINE T zetaScalarSlow(const T x, const T q) { - - const T precision = (T)1e-7; // function stops the calculation of series when next item is <= precision - - // if (x <= (T)1.) - // throw("zeta function: x must be > 1 !"); - - // if (q <= (T)0.) - // throw("zeta function: q must be > 0 !"); - - T item; - T result = (T)0.; - for(int i = 0; i < maxIter; ++i) { - - item = math::nd4j_pow((q + i),-x); - result += item; - - if(item <= precision) - break; - } - - return result; -} + const T precision = (T)1e-7; // function stops the calculation of series when + // next item is <= precision + // if (x <= (T)1.) + // throw("zeta function: x must be > 1 !"); -////////////////////////////////////////////////////////////////////////// -// calculate the Hurwitz zeta function for arrays -template -static void zeta_(sd::LaunchContext * context, const NDArray& x, const NDArray& q, NDArray &z) { + // if (q <= (T)0.) + // throw("zeta function: q must be > 0 !"); - //auto result = NDArray(&x, false, context); - int xLen = x.lengthOf(); + T item; + T result = (T)0.; + for (int i = 0; i < maxIter; ++i) { + item = math::nd4j_pow((q + i), -x); + result += item; - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) - z.p(i, zetaScalar(x.e(i), q.e(i))); - }; + if (item <= precision) break; + } - samediff::Threads::parallel_for(func, 0, xLen); + return result; } -void zeta(sd::LaunchContext * context, const NDArray& x, const NDArray& q, NDArray& z) { - BUILD_SINGLE_SELECTOR(x.dataType(), zeta_, (context, x, q, z), FLOAT_TYPES); -} +////////////////////////////////////////////////////////////////////////// +// calculate the Hurwitz zeta function for arrays +template +static void zeta_(sd::LaunchContext* context, const NDArray& x, + const NDArray& q, NDArray& z) { + // auto result = NDArray(&x, false, context); + int xLen = x.lengthOf(); -BUILD_SINGLE_TEMPLATE(template void zeta_, (sd::LaunchContext * context, const NDArray& x, const NDArray& q, NDArray& z), FLOAT_TYPES); + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) + z.p(i, zetaScalar(x.e(i), q.e(i))); + }; + samediff::Threads::parallel_for(func, 0, xLen); } + +void zeta(sd::LaunchContext* context, const NDArray& x, const NDArray& q, + NDArray& z) { + BUILD_SINGLE_SELECTOR(x.dataType(), zeta_, (context, x, q, z), FLOAT_TYPES); } -} +BUILD_SINGLE_TEMPLATE(template void zeta_, + (sd::LaunchContext * context, const NDArray& x, + const NDArray& q, NDArray& z), + FLOAT_TYPES); + +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/crop_and_resize.h b/libnd4j/include/ops/declarable/helpers/crop_and_resize.h index cff96f93e5b2..18d9ea16d313 100644 --- a/libnd4j/include/ops/declarable/helpers/crop_and_resize.h +++ b/libnd4j/include/ops/declarable/helpers/crop_and_resize.h @@ -14,7 +14,6 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - // // @author sgazeos@gmail.com // @@ -22,19 +21,23 @@ #ifndef SD_CROP_AND_RESIZE_H #define SD_CROP_AND_RESIZE_H -#include #include - +#include namespace sd { - namespace ops { - namespace helpers { - template - void cropAndResizeFunctor_(NDArray const *images, NDArray const *boxes, NDArray const *indices, NDArray const *cropSize, int method, double extrapolationVal, NDArray *crops); - - void cropAndResizeFunctor(sd::LaunchContext * context, NDArray const* images, NDArray const* boxes, NDArray const* indices, NDArray const* cropSize, int method, double extrapolationVal, NDArray* crops); - } - } -} - -#endif //SD_CROP_AND_RESIZE_H +namespace ops { +namespace helpers { +template +void cropAndResizeFunctor_(NDArray const* images, NDArray const* boxes, + NDArray const* indices, NDArray const* cropSize, + int method, double extrapolationVal, NDArray* crops); + +void cropAndResizeFunctor(sd::LaunchContext* context, NDArray const* images, + NDArray const* boxes, NDArray const* indices, + NDArray const* cropSize, int method, + double extrapolationVal, NDArray* crops); +} // namespace helpers +} // namespace ops +} // namespace sd + +#endif // SD_CROP_AND_RESIZE_H diff --git a/libnd4j/include/ops/declarable/helpers/cross.h b/libnd4j/include/ops/declarable/helpers/cross.h index bd1e2a61dbe9..818613749184 100644 --- a/libnd4j/include/ops/declarable/helpers/cross.h +++ b/libnd4j/include/ops/declarable/helpers/cross.h @@ -18,69 +18,75 @@ // @author raver119@gmail.com // -#include #include +#include namespace sd { namespace ops { namespace helpers { -void crossBatched(sd::LaunchContext * context, NDArray *a, NDArray *b, NDArray *o); - -void FORCEINLINE cross(sd::LaunchContext * context, NDArray *a, NDArray *b, NDArray *o) { - - if (a->isR()) { - auto a0 = a->e(0); - auto a1 = a->e(1); - auto a2 = a->e(2); - - auto b0 = b->e(0); - auto b1 = b->e(1); - auto b2 = b->e(2); - - o->p(Nd4jLong(0L), a1 * b2 - a2 * b1); - o->p(1L, a2 * b0 - a0 * b2); - o->p(2L, a0 * b1 - a1 * b0); - } else { - auto a0 = a->e(0); - auto a1 = a->e(1); - auto a2 = a->e(2); - - auto b0 = b->e(0); - auto b1 = b->e(1); - auto b2 = b->e(2); - - o->p(Nd4jLong(0L), a1 * b2 - a2 * b1); - o->p(1L, a2 * b0 - a0 * b2); - o->p(2L, a0 * b1 - a1 * b0); - } +void crossBatched(sd::LaunchContext *context, NDArray *a, NDArray *b, + NDArray *o); + +void FORCEINLINE cross(sd::LaunchContext *context, NDArray *a, NDArray *b, + NDArray *o) { + if (a->isR()) { + auto a0 = a->e(0); + auto a1 = a->e(1); + auto a2 = a->e(2); + + auto b0 = b->e(0); + auto b1 = b->e(1); + auto b2 = b->e(2); + + o->p(Nd4jLong(0L), a1 * b2 - a2 * b1); + o->p(1L, a2 * b0 - a0 * b2); + o->p(2L, a0 * b1 - a1 * b0); + } else { + auto a0 = a->e(0); + auto a1 = a->e(1); + auto a2 = a->e(2); + + auto b0 = b->e(0); + auto b1 = b->e(1); + auto b2 = b->e(2); + + o->p(Nd4jLong(0L), a1 * b2 - a2 * b1); + o->p(1L, a2 * b0 - a0 * b2); + o->p(2L, a0 * b1 - a1 * b0); + } } - void FORCEINLINE _crossBatched(sd::LaunchContext * context, NDArray *a, NDArray *b, NDArray *o) { - auto a_ = a->reshape(a->ordering(), {-1, 3}); - auto b_ = b->reshape(b->ordering(), {-1, 3}); - auto o_ = o->reshape(o->ordering(), {-1, 3}, false); +void FORCEINLINE _crossBatched(sd::LaunchContext *context, NDArray *a, + NDArray *b, NDArray *o) { + auto a_ = a->reshape(a->ordering(), {-1, 3}); + auto b_ = b->reshape(b->ordering(), {-1, 3}); + auto o_ = o->reshape(o->ordering(), {-1, 3}, false); - auto tadsA = a_.allTensorsAlongDimension({1}); - auto tadsB = b_.allTensorsAlongDimension({1}); - auto tadsO = o_.allTensorsAlongDimension({1}); + auto tadsA = a_.allTensorsAlongDimension({1}); + auto tadsB = b_.allTensorsAlongDimension({1}); + auto tadsO = o_.allTensorsAlongDimension({1}); - int tads = tadsA.size(); + int tads = tadsA.size(); - auto func = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) { - auto a_ = tadsA.at(e); - auto b_ = tadsB.at(e); - auto o_ = tadsO.at(e); + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + auto a_ = tadsA.at(e); + auto b_ = tadsB.at(e); + auto o_ = tadsO.at(e); - helpers::cross(context, a_, b_, o_); - } - }; - - samediff::Threads::parallel_tad(func, 0, tads); + helpers::cross(context, &a_, &b_, &o_); } + }; - void weightedCrossEntropyWithLogitsFunctor(sd::LaunchContext * context, NDArray const* targets, NDArray const* input, NDArray const* weights, NDArray* output); -} + samediff::Threads::parallel_tad(func, 0, tads); } -} \ No newline at end of file + +void weightedCrossEntropyWithLogitsFunctor(sd::LaunchContext *context, + NDArray const *targets, + NDArray const *input, + NDArray const *weights, + NDArray *output); +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/BarnesHutTsne.cu b/libnd4j/include/ops/declarable/helpers/cuda/BarnesHutTsne.cu index 70ff75b96336..afffabe80a7f 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/BarnesHutTsne.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/BarnesHutTsne.cu @@ -24,47 +24,49 @@ namespace sd { namespace ops { namespace helpers { //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// -// count rows kernel - count input pRows and pCols and put result onto pRowCounts -// pRowCounts - array of ints, with length N -// pRows - array of ints with length N, vals from 0 to N-1 -// pCols - array of ints with length < N and vals between 0 and max(pRows) +// count rows kernel - count input pRows and pCols and put result onto +// pRowCounts pRowCounts - array of ints, with length N pRows - array of ints +// with length N, vals from 0 to N-1 pCols - array of ints with length < N and +// vals between 0 and max(pRows) // - static __global__ void countRowsKernel(int* pRowCounts, int const* pRows, int const* pCols, Nd4jLong N) { - auto start = blockIdx.x * blockDim.x; - auto step = blockDim.x * gridDim.x; - for (int n = threadIdx.x + start; n < N; n += step) { - int begin = pRows[n];//->e(n); - int end = pRows[n + 1];//rowP->e(n + 1); - for (int i = begin; i < end; i++) { - bool present = false; - // loop between near pRows - for (int m = pRows[pCols[i]]; m < pRows[pCols[i] + 1]; m++) - if (pCols[m] == n) { // mark index as existed with columns array - present = true; - break; - } +static __global__ void countRowsKernel(int* pRowCounts, int const* pRows, + int const* pCols, Nd4jLong N) { + auto start = blockIdx.x * blockDim.x; + auto step = blockDim.x * gridDim.x; + for (int n = threadIdx.x + start; n < N; n += step) { + int begin = pRows[n]; //->e(n); + int end = pRows[n + 1]; // rowP->e(n + 1); + for (int i = begin; i < end; i++) { + bool present = false; + // loop between near pRows + for (int m = pRows[pCols[i]]; m < pRows[pCols[i] + 1]; m++) + if (pCols[m] == n) { // mark index as existed with columns array + present = true; + break; + } - atomicAdd(&pRowCounts[n], 1); + atomicAdd(&pRowCounts[n], 1); - if (!present) // increment row counter for given index - atomicAdd(&pRowCounts[pCols[i]], 1); - } - } + if (!present) // increment row counter for given index + atomicAdd(&pRowCounts[pCols[i]], 1); } + } +} //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// // row counter caller - Nd4jLong barnes_row_count(const NDArray* rowP, const NDArray* colP, Nd4jLong N, NDArray& rowCounts) { - - int* pRowCounts = reinterpret_cast(rowCounts.specialBuffer()); - int const* pRows = reinterpret_cast(rowP->specialBuffer()); - int const* pCols = reinterpret_cast(colP->specialBuffer()); - auto stream = rowCounts.getContext()->getCudaStream(); - countRowsKernel<<<1, 1, 128, *stream>>>(pRowCounts, pRows, pCols, N); - NDArray numElementsArr = rowCounts.sumNumber(); //reduceAlongDimension(reduce::Sum, {}); - //rowCounts.printBuffer("Row counts"); - auto numElements = numElementsArr.e(0); - return numElements; - } +Nd4jLong barnes_row_count(const NDArray* rowP, const NDArray* colP, Nd4jLong N, + NDArray& rowCounts) { + int* pRowCounts = reinterpret_cast(rowCounts.specialBuffer()); + int const* pRows = reinterpret_cast(rowP->specialBuffer()); + int const* pCols = reinterpret_cast(colP->specialBuffer()); + auto stream = rowCounts.getContext()->getCudaStream(); + countRowsKernel<<<1, 1, 128, *stream>>>(pRowCounts, pRows, pCols, N); + NDArray numElementsArr = + rowCounts.sumNumber(); // reduceAlongDimension(reduce::Sum, {}); + // rowCounts.printBuffer("Row counts"); + auto numElements = numElementsArr.e(0); + return numElements; +} //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// // extend symRowP with pRowCounts array vals @@ -72,18 +74,17 @@ namespace helpers { // symRowP - int array with length N+1 // N - given array length // - static __global__ void fillUpsymRow(int const* pRowCounts, int* symRowP, int N) { - - auto start = blockIdx.x * blockDim.x + threadIdx.x; - auto step = blockDim.x * gridDim.x; - - for (int n = start; n < N + 1; n += step) { // to avoid race condition use shift only for given index - symRowP[n] = 0; - for (int i = 0; i < n; i++) - atomicAdd(&symRowP[n], pRowCounts[i]); - } +static __global__ void fillUpsymRow(int const* pRowCounts, int* symRowP, + int N) { + auto start = blockIdx.x * blockDim.x + threadIdx.x; + auto step = blockDim.x * gridDim.x; - } + for (int n = start; n < N + 1; + n += step) { // to avoid race condition use shift only for given index + symRowP[n] = 0; + for (int i = 0; i < n; i++) atomicAdd(&symRowP[n], pRowCounts[i]); + } +} //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// // symmetrize routine kernel @@ -96,183 +97,225 @@ namespace helpers { // pOutput - result matrix (floats) // N - pRows length // - template - static __global__ void symmetrizeKernel(int const* pRows, int const* pCols, T const* pVals, int* symRowP, int* symColP, int* offset, T* pOutput, int N) { - auto start = blockIdx.x * blockDim.x + threadIdx.x; - auto step = blockDim.x * gridDim.x; +template +static __global__ void symmetrizeKernel(int const* pRows, int const* pCols, + T const* pVals, int* symRowP, + int* symColP, int* offset, T* pOutput, + int N) { + auto start = blockIdx.x * blockDim.x + threadIdx.x; + auto step = blockDim.x * gridDim.x; - for (int n = start; n < N; n += step) { - int begin = pRows[n]; - int bound = pRows[n + 1]; + for (int n = start; n < N; n += step) { + int begin = pRows[n]; + int bound = pRows[n + 1]; - for (int i = begin; i < bound; i++) { - bool present = false; - int colPI = pCols[i]; - int start = pRows[colPI]; - int end = pRows[colPI + 1]; + for (int i = begin; i < bound; i++) { + bool present = false; + int colPI = pCols[i]; + int start = pRows[colPI]; + int end = pRows[colPI + 1]; - for (int m = start; m < end; m++) { - if (pCols[m] == n) { - present = true; - if (n <= colPI) { - symColP[symRowP[n] + offset[n]] = colPI; - symColP[symRowP[colPI] + offset[colPI]] = n; - pOutput[symRowP[n] + offset[n]] = pVals[i] + pVals[m]; - pOutput[symRowP[colPI] + offset[colPI]] = pVals[i] + pVals[m]; - } - } - } + for (int m = start; m < end; m++) { + if (pCols[m] == n) { + present = true; + if (n <= colPI) { + symColP[symRowP[n] + offset[n]] = colPI; + symColP[symRowP[colPI] + offset[colPI]] = n; + pOutput[symRowP[n] + offset[n]] = pVals[i] + pVals[m]; + pOutput[symRowP[colPI] + offset[colPI]] = pVals[i] + pVals[m]; + } + } + } - // If (colP[i], n) is not present, there is no addition involved - if (!present) { - symColP[symRowP[n] + offset[n]] = colPI; - symColP[symRowP[pCols[i]] + offset[colPI]] = n; - pOutput[symRowP[n] + offset[n]] = pVals[i]; - pOutput[symRowP[colPI] + offset[colPI]] = pVals[i]; - } - // Update offsets - if (!present || (present && n <= colPI)) { - atomicAdd(&offset[n], 1); + // If (colP[i], n) is not present, there is no addition involved + if (!present) { + symColP[symRowP[n] + offset[n]] = colPI; + symColP[symRowP[pCols[i]] + offset[colPI]] = n; + pOutput[symRowP[n] + offset[n]] = pVals[i]; + pOutput[symRowP[colPI] + offset[colPI]] = pVals[i]; + } + // Update offsets + if (!present || (present && n <= colPI)) { + atomicAdd(&offset[n], 1); - if (colPI != n) - atomicAdd(&offset[colPI], 1); - } - } - } + if (colPI != n) atomicAdd(&offset[colPI], 1); + } } + } +} //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// // symmetrize algorithm itself // - template - static void barnes_symmetrize_(const NDArray* rowP, const NDArray* colP, const NDArray* valP, Nd4jLong N, NDArray* outputRows, NDArray* outputCols, NDArray* outputVals, NDArray* rowCounts) { - int const* pRows = reinterpret_cast(rowP->specialBuffer()); - int* symRowP = reinterpret_cast(outputRows->specialBuffer()); - int* pRowCounts = reinterpret_cast(rowCounts->specialBuffer()); - auto stream = outputCols->getContext()->getCudaStream(); - // fill up syRowP array - fillUpsymRow<<<1, N, 128, *stream>>>(pRowCounts, symRowP, N); - outputRows->syncToHost(); -// outputRows->printBuffer("output rows"); - int* symColP = reinterpret_cast(outputCols->specialBuffer()); -// outputRows->printBuffer("SymRows are"); - int const* pCols = reinterpret_cast(colP->specialBuffer()); - T const* pVals = reinterpret_cast(valP->specialBuffer()); - T* pOutput = reinterpret_cast(outputVals->specialBuffer()); - //std::vector rowCountsV = rowCounts->getBufferAsVector(); - auto offsetArr = NDArrayFactory::create('c', {N}); - int* offset = reinterpret_cast(offsetArr.specialBuffer()); - // symmetrize itself - symmetrizeKernel<<<1, 1, 1024, *stream>>>(pRows, pCols, pVals, symRowP, symColP, offset, pOutput, N); - } +template +static void barnes_symmetrize_(const NDArray* rowP, const NDArray* colP, + const NDArray* valP, Nd4jLong N, + NDArray* outputRows, NDArray* outputCols, + NDArray* outputVals, NDArray* rowCounts) { + int const* pRows = reinterpret_cast(rowP->specialBuffer()); + int* symRowP = reinterpret_cast(outputRows->specialBuffer()); + int* pRowCounts = reinterpret_cast(rowCounts->specialBuffer()); + auto stream = outputCols->getContext()->getCudaStream(); + // fill up syRowP array + fillUpsymRow<<<1, N, 128, *stream>>>(pRowCounts, symRowP, N); + outputRows->syncToHost(); + // outputRows->printBuffer("output rows"); + int* symColP = reinterpret_cast(outputCols->specialBuffer()); + // outputRows->printBuffer("SymRows are"); + int const* pCols = reinterpret_cast(colP->specialBuffer()); + T const* pVals = reinterpret_cast(valP->specialBuffer()); + T* pOutput = reinterpret_cast(outputVals->specialBuffer()); + // std::vector rowCountsV = rowCounts->getBufferAsVector(); + auto offsetArr = NDArrayFactory::create('c', {N}); + int* offset = reinterpret_cast(offsetArr.specialBuffer()); + // symmetrize itself + symmetrizeKernel<<<1, 1, 1024, *stream>>>(pRows, pCols, pVals, symRowP, + symColP, offset, pOutput, N); +} //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// // symmetrize caller and adoption // - void barnes_symmetrize(const NDArray* rowP, const NDArray* colP, const NDArray* valP, Nd4jLong N, NDArray* outputRows, NDArray* outputCols, NDArray* outputVals, NDArray* rowCounts) { - BUILD_SINGLE_SELECTOR(valP->dataType(), barnes_symmetrize_, (rowP, colP, valP, N, outputRows, outputCols, outputVals, rowCounts), NUMERIC_TYPES); +void barnes_symmetrize(const NDArray* rowP, const NDArray* colP, + const NDArray* valP, Nd4jLong N, NDArray* outputRows, + NDArray* outputCols, NDArray* outputVals, + NDArray* rowCounts) { + BUILD_SINGLE_SELECTOR( + valP->dataType(), barnes_symmetrize_, + (rowP, colP, valP, N, outputRows, outputCols, outputVals, rowCounts), + NUMERIC_TYPES); - *outputVals /= 2.0; - } - BUILD_SINGLE_TEMPLATE(template void barnes_symmetrize_, (const NDArray* rowP, const NDArray* colP, const NDArray* valP, Nd4jLong N, NDArray* outputRows, NDArray* outputCols, NDArray* outputVals, NDArray* rowCounts), NUMERIC_TYPES); + *outputVals /= 2.0; +} +BUILD_SINGLE_TEMPLATE(template void barnes_symmetrize_, + (const NDArray* rowP, const NDArray* colP, + const NDArray* valP, Nd4jLong N, NDArray* outputRows, + NDArray* outputCols, NDArray* outputVals, + NDArray* rowCounts), + NUMERIC_TYPES); //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// // edge forces implementation // - template - static __global__ void edgeForcesKernel(int const* pRows, int const* pCols, T const* dataP, T const* vals, T* outputP, int N, int colCount, int rowSize) { -// std::vector buffer(colCount); +template +static __global__ void edgeForcesKernel(int const* pRows, int const* pCols, + T const* dataP, T const* vals, + T* outputP, int N, int colCount, + int rowSize) { + // std::vector buffer(colCount); - auto start = blockIdx.x * blockDim.x + threadIdx.x; - auto step = blockDim.x * gridDim.x; + auto start = blockIdx.x * blockDim.x + threadIdx.x; + auto step = blockDim.x * gridDim.x; - for (int n = start; n < N; n += step) { - int start = pRows[n]; - int end = pRows[n + 1]; - int shift = n * colCount; - for (int i = start; i < end; i++) { - T const* thisSlice = dataP + pCols[i] * colCount; - T res = 1; + for (int n = start; n < N; n += step) { + int start = pRows[n]; + int end = pRows[n + 1]; + int shift = n * colCount; + for (int i = start; i < end; i++) { + T const* thisSlice = dataP + pCols[i] * colCount; + T res = 1; - for (int k = 0; k < colCount; k++) { - auto valTemp = dataP[shift + k] - thisSlice[k];//thisSlice[k]; - res += valTemp * valTemp; // (dataP[shift + k] * dataP[shift + k] - 2 * dataP[shift + k] * thisSlice[k] + thisSlice[k] * thisSlice[k]) - } - res = vals[i] / res; - for (int k = 0; k < colCount; k++) - math::atomics::nd4j_atomicAdd(&outputP[shift + k], T((dataP[shift + k] - thisSlice[k]) * res)); - } - } + for (int k = 0; k < colCount; k++) { + auto valTemp = dataP[shift + k] - thisSlice[k]; // thisSlice[k]; + res += + valTemp * + valTemp; // (dataP[shift + k] * dataP[shift + k] - 2 * dataP[shift + // + k] * thisSlice[k] + thisSlice[k] * thisSlice[k]) + } + res = vals[i] / res; + for (int k = 0; k < colCount; k++) + math::atomics::nd4j_atomicAdd( + &outputP[shift + k], T((dataP[shift + k] - thisSlice[k]) * res)); } + } +} //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// // edge forces algorithm // - template - static void barnes_edge_forces_(const NDArray* rowP, NDArray const* colP, NDArray const* valP, int N, NDArray const* data, NDArray* output) { - NDArray::prepareSpecialUse({output}, {data, rowP, colP, valP, valP}); - T const* dataP = reinterpret_cast(data->specialBuffer()); - T const* vals = reinterpret_cast(valP->specialBuffer()); - T* outputP = reinterpret_cast(output->specialBuffer()); - int const* pRows = reinterpret_cast(rowP->specialBuffer()); - int const* pCols = reinterpret_cast(colP->specialBuffer()); - int colCount = data->columns(); - //auto shift = 0; - auto rowSize = sizeof(T) * colCount; - auto stream = output->getContext()->getCudaStream(); - edgeForcesKernel<<<1, 128, 1024, *stream>>>(pRows, pCols, dataP, vals, outputP, N, colCount, rowSize); - NDArray::registerSpecialUse({output}, {rowP, colP, valP, data}); - } +template +static void barnes_edge_forces_(const NDArray* rowP, NDArray const* colP, + NDArray const* valP, int N, NDArray const* data, + NDArray* output) { + NDArray::prepareSpecialUse({output}, {data, rowP, colP, valP, valP}); + T const* dataP = reinterpret_cast(data->specialBuffer()); + T const* vals = reinterpret_cast(valP->specialBuffer()); + T* outputP = reinterpret_cast(output->specialBuffer()); + int const* pRows = reinterpret_cast(rowP->specialBuffer()); + int const* pCols = reinterpret_cast(colP->specialBuffer()); + int colCount = data->columns(); + // auto shift = 0; + auto rowSize = sizeof(T) * colCount; + auto stream = output->getContext()->getCudaStream(); + edgeForcesKernel<<<1, 128, 1024, *stream>>>(pRows, pCols, dataP, vals, + outputP, N, colCount, rowSize); + NDArray::registerSpecialUse({output}, {rowP, colP, valP, data}); +} //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// // edge forces caller // - void barnes_edge_forces(const NDArray* rowP, NDArray const* colP, NDArray const* valP, int N, NDArray* output, NDArray const& data) { - // Loop over all edges in the graph - BUILD_SINGLE_SELECTOR(output->dataType(), barnes_edge_forces_, (rowP, colP, valP, N, &data, output), FLOAT_TYPES); - } - BUILD_SINGLE_TEMPLATE(template void barnes_edge_forces_, (const NDArray* rowP, NDArray const* colP, NDArray const* valP, int N, NDArray const* data, NDArray* output), FLOAT_TYPES); +void barnes_edge_forces(const NDArray* rowP, NDArray const* colP, + NDArray const* valP, int N, NDArray* output, + NDArray const& data) { + // Loop over all edges in the graph + BUILD_SINGLE_SELECTOR(output->dataType(), barnes_edge_forces_, + (rowP, colP, valP, N, &data, output), FLOAT_TYPES); +} +BUILD_SINGLE_TEMPLATE(template void barnes_edge_forces_, + (const NDArray* rowP, NDArray const* colP, + NDArray const* valP, int N, NDArray const* data, + NDArray* output), + FLOAT_TYPES); //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// -// gains - run a function T((x + 2.) * sd::math::nd4j_sign(grad) != sd::math::nd4j_sign(eps)) + T(x * 0.8 * sd::math::nd4j_sign(grad) != sd::math::nd4j_sign(eps)); -// for all members in input and put all in output +// gains - run a function T((x + 2.) * sd::math::nd4j_sign(grad) != +// sd::math::nd4j_sign(eps)) + T(x * 0.8 * sd::math::nd4j_sign(grad) +// != sd::math::nd4j_sign(eps)); for all members in input and put all in +// output // - template - void barnes_gains_(NDArray* input, NDArray* gradX, NDArray* epsilon, NDArray* output) { - auto gainsInternal = LAMBDA_TTT(x, grad, eps) { - T res = sd::math::nd4j_sign(grad) != sd::math::nd4j_sign(eps) ? x + T(.2) : x * T(.8); - if(res < .01) res = .01; - return res; - }; +template +void barnes_gains_(NDArray* input, NDArray* gradX, NDArray* epsilon, + NDArray* output) { + auto gainsInternal = LAMBDA_TTT(x, grad, eps) { + T res = sd::math::nd4j_sign(grad) != sd::math::nd4j_sign(eps) + ? x + T(.2) + : x * T(.8); + if (res < .01) res = .01; + return res; + }; - input->applyTriplewiseLambda(*gradX, *epsilon, gainsInternal, *output); - } + input->applyTriplewiseLambda(*gradX, *epsilon, gainsInternal, *output); +} //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// // gains caller - void barnes_gains(NDArray* input, NDArray* gradX, NDArray* epsilon, NDArray* output) { - BUILD_SINGLE_SELECTOR(input->dataType(), barnes_gains_, (input, gradX, epsilon, output), NUMERIC_TYPES); - } - BUILD_SINGLE_TEMPLATE(template void barnes_gains_, (NDArray* input, NDArray* gradX, NDArray* epsilon, NDArray* output), NUMERIC_TYPES); +void barnes_gains(NDArray* input, NDArray* gradX, NDArray* epsilon, + NDArray* output) { + BUILD_SINGLE_SELECTOR(input->dataType(), barnes_gains_, + (input, gradX, epsilon, output), NUMERIC_TYPES); +} +BUILD_SINGLE_TEMPLATE(template void barnes_gains_, + (NDArray * input, NDArray* gradX, NDArray* epsilon, + NDArray* output), + NUMERIC_TYPES); //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// // cell contains - check cells for given point // - bool cell_contains(NDArray* corner, NDArray* width, NDArray* point, Nd4jLong dimension) { - auto cornerMinusWidth = *corner - *width; - auto cornerPlusWidth = *corner + *width; - // executes on host side, so sync all to host memory - cornerMinusWidth.syncToHost(); - cornerPlusWidth.syncToHost(); - for (Nd4jLong i = 0; i < dimension; i++) { - if (cornerMinusWidth.e(i) > point->e(i)) - return false; - if (cornerPlusWidth.e(i) < point->e(i)) - return false; - } +bool cell_contains(NDArray* corner, NDArray* width, NDArray* point, + Nd4jLong dimension) { + auto cornerMinusWidth = *corner - *width; + auto cornerPlusWidth = *corner + *width; + // executes on host side, so sync all to host memory + cornerMinusWidth.syncToHost(); + cornerPlusWidth.syncToHost(); + for (Nd4jLong i = 0; i < dimension; i++) { + if (cornerMinusWidth.e(i) > point->e(i)) return false; + if (cornerPlusWidth.e(i) < point->e(i)) return false; + } - return true; - } -} + return true; } -} - +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/activations.cu b/libnd4j/include/ops/declarable/helpers/cuda/activations.cu index e675342d9242..e2913346706e 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/activations.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/activations.cu @@ -19,593 +19,690 @@ // @author raver119@gmail.com // -#include -#include +#include +#include #include +#include +#include + #include -#include -#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { /////////////////////////////////////////////////////////////////// -template +template __global__ void preluCuda(const void *vx, const Nd4jLong *xShapeInfo, - const void *vy, const Nd4jLong *yShapeInfo, - void *vz) { - - const auto x = reinterpret_cast(vx); - const auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); + const void *vy, const Nd4jLong *yShapeInfo, + void *vz) { + const auto x = reinterpret_cast(vx); + const auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); - __shared__ Nd4jLong xzLen; - __shared__ int xzRank, yRank; + __shared__ Nd4jLong xzLen; + __shared__ int xzRank, yRank; - if (threadIdx.x == 0) { - xzLen = shape::length(xShapeInfo); + if (threadIdx.x == 0) { + xzLen = shape::length(xShapeInfo); - xzRank = shape::rank(xShapeInfo); - yRank = shape::rank(yShapeInfo); - } - __syncthreads(); + xzRank = shape::rank(xShapeInfo); + yRank = shape::rank(yShapeInfo); + } + __syncthreads(); - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - int coords[MAX_RANK]; + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + int coords[MAX_RANK]; - for (int i = tid; i < xzLen; i += blockDim.x * gridDim.x) { - shape::index2coords(i, xShapeInfo, coords); + for (int i = tid; i < xzLen; i += blockDim.x * gridDim.x) { + shape::index2coords(i, xShapeInfo, coords); - const auto xzOffset = shape::getOffset(xShapeInfo, coords); - const auto xVal = x[xzOffset]; + const auto xzOffset = shape::getOffset(xShapeInfo, coords); + const auto xVal = x[xzOffset]; - if(xVal < 0) { - for (uint j = 0; j < yRank; ++j) - if(yShapeInfo[j + 1] == 1) - coords[j + 1] = 0; + if (xVal < 0) { + for (uint j = 0; j < yRank; ++j) + if (yShapeInfo[j + 1] == 1) coords[j + 1] = 0; - z[xzOffset] = xVal * y[shape::getOffset(yShapeInfo, coords + 1)]; - } - else - z[xzOffset] = xVal; - } + z[xzOffset] = xVal * y[shape::getOffset(yShapeInfo, coords + 1)]; + } else + z[xzOffset] = xVal; + } } /////////////////////////////////////////////////////////////////// -template -linkage void preluCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo, void *vz) { - preluCuda<<>>(vx, xShapeInfo, vy, yShapeInfo, vz); +template +linkage void preluCudaLauncher(const int blocksPerGrid, + const int threadsPerBlock, const int sharedMem, + const cudaStream_t *stream, const void *vx, + const Nd4jLong *xShapeInfo, const void *vy, + const Nd4jLong *yShapeInfo, void *vz) { + preluCuda<<>>( + vx, xShapeInfo, vy, yShapeInfo, vz); } /////////////////////////////////////////////////////////////////// -void prelu(sd::LaunchContext * context, const NDArray& input, const NDArray& alpha, NDArray& output) { - - PointersManager manager(context, "prelu"); - - const int threadsPerBlock = 256; - const int blocksPerGrid = 512; - const int sharedMem = 512; - - const auto xType = input.dataType(); - const auto yType = alpha.dataType(); - - NDArray::prepareSpecialUse({&output}, {&input, &alpha}); - BUILD_SINGLE_SELECTOR_TWICE(xType, preluCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), input.specialBuffer(), input.specialShapeInfo(), alpha.specialBuffer(), alpha.specialShapeInfo(), output.specialBuffer()), FLOAT_TYPES); - NDArray::registerSpecialUse({&output}, {&input, &alpha}); - - manager.synchronize(); +void prelu(sd::LaunchContext *context, const NDArray &input, + const NDArray &alpha, NDArray &output) { + PointersManager manager(context, "prelu"); + + const int threadsPerBlock = 256; + const int blocksPerGrid = 512; + const int sharedMem = 512; + + const auto xType = input.dataType(); + const auto yType = alpha.dataType(); + + NDArray::prepareSpecialUse({&output}, {&input, &alpha}); + BUILD_SINGLE_SELECTOR_TWICE( + xType, preluCudaLauncher, + (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), + input.specialBuffer(), input.specialShapeInfo(), alpha.specialBuffer(), + alpha.specialShapeInfo(), output.specialBuffer()), + FLOAT_TYPES); + NDArray::registerSpecialUse({&output}, {&input, &alpha}); + + manager.synchronize(); } /////////////////////////////////////////////////////////////////// -template -__global__ linkage void preluBPCuda(const void *vIn, const Nd4jLong *inShapeInfo, - const void *vAlpha, const Nd4jLong *alphaShapeInfo, - const void *vdLdO, const Nd4jLong *dLdOShapeInfo, - void *vdLdI, const Nd4jLong *dLdIShapeInfo, - void *vdLdA, const Nd4jLong *dLdAShapeInfo) { - - const auto in = reinterpret_cast(vIn); - const auto alpha = reinterpret_cast(vAlpha); - const auto dLdO = reinterpret_cast(vdLdO); - auto dLdI = reinterpret_cast(vdLdI); - auto dLdA = reinterpret_cast(vdLdA); - - __shared__ Nd4jLong inLen, totalThreads; - __shared__ int inRank, alphaRank; - - if (threadIdx.x == 0) { - inLen = shape::length(inShapeInfo); - totalThreads = gridDim.x * blockDim.x; - - inRank = shape::rank(inShapeInfo); - alphaRank = shape::rank(alphaShapeInfo); - } - __syncthreads(); - - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - int coords[MAX_RANK]; - - for (int i = tid; i < inLen; i += totalThreads) { - shape::index2coords(i, inShapeInfo, coords); - - const auto inOffset = shape::getOffset(inShapeInfo, coords); - const auto dLdOOffset = shape::getOffset(dLdOShapeInfo, coords); - const auto dLdIOffset = shape::getOffset(dLdIShapeInfo, coords); - - const auto xVal = in[inOffset]; - const auto grO = dLdO[dLdOOffset]; - - if(xVal < 0) { - - for (uint j = 0; j < alphaRank; ++j) - if(alphaShapeInfo[j + 1] == 1) - coords[j + 1] = 0; - - const auto alphaOffset = shape::getOffset(alphaShapeInfo, coords + 1); - const auto dLdAOffset = shape::getOffset(dLdAShapeInfo, coords + 1); - - dLdI[dLdIOffset] = grO * alpha[alphaOffset]; - - sd::math::atomics::nd4j_atomicAdd(&dLdA[dLdAOffset], static_cast(grO * xVal)); - } - else - dLdI[dLdIOffset] = grO; - } +template +__global__ linkage void preluBPCuda( + const void *vIn, const Nd4jLong *inShapeInfo, const void *vAlpha, + const Nd4jLong *alphaShapeInfo, const void *vdLdO, + const Nd4jLong *dLdOShapeInfo, void *vdLdI, const Nd4jLong *dLdIShapeInfo, + void *vdLdA, const Nd4jLong *dLdAShapeInfo) { + const auto in = reinterpret_cast(vIn); + const auto alpha = reinterpret_cast(vAlpha); + const auto dLdO = reinterpret_cast(vdLdO); + auto dLdI = reinterpret_cast(vdLdI); + auto dLdA = reinterpret_cast(vdLdA); + + __shared__ Nd4jLong inLen, totalThreads; + __shared__ int inRank, alphaRank; + + if (threadIdx.x == 0) { + inLen = shape::length(inShapeInfo); + totalThreads = gridDim.x * blockDim.x; + + inRank = shape::rank(inShapeInfo); + alphaRank = shape::rank(alphaShapeInfo); + } + __syncthreads(); + + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + int coords[MAX_RANK]; + + for (int i = tid; i < inLen; i += totalThreads) { + shape::index2coords(i, inShapeInfo, coords); + + const auto inOffset = shape::getOffset(inShapeInfo, coords); + const auto dLdOOffset = shape::getOffset(dLdOShapeInfo, coords); + const auto dLdIOffset = shape::getOffset(dLdIShapeInfo, coords); + + const auto xVal = in[inOffset]; + const auto grO = dLdO[dLdOOffset]; + + if (xVal < 0) { + for (uint j = 0; j < alphaRank; ++j) + if (alphaShapeInfo[j + 1] == 1) coords[j + 1] = 0; + + const auto alphaOffset = shape::getOffset(alphaShapeInfo, coords + 1); + const auto dLdAOffset = shape::getOffset(dLdAShapeInfo, coords + 1); + + dLdI[dLdIOffset] = grO * alpha[alphaOffset]; + + sd::math::atomics::nd4j_atomicAdd(&dLdA[dLdAOffset], + static_cast(grO * xVal)); + } else + dLdI[dLdIOffset] = grO; + } } ////////////////////////////////////////////////////////////////////////// -template -__host__ linkage void preluBPCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void *vIn, const Nd4jLong *inShapeInfo, const void *vAlpha, const Nd4jLong *alphaShapeInfo, const void *vdLdO, const Nd4jLong *dLdOShapeInfo, void *vdLdI, const Nd4jLong *dLdIShapeInfo, void *vdLdA, const Nd4jLong *dLdAShapeInfo) { - - preluBPCuda<<>>(vIn, inShapeInfo, vAlpha, alphaShapeInfo, vdLdO, dLdOShapeInfo, vdLdI, dLdIShapeInfo, vdLdA, dLdAShapeInfo); +template +__host__ linkage void preluBPCudaLauncher( + const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, + const cudaStream_t *stream, const void *vIn, const Nd4jLong *inShapeInfo, + const void *vAlpha, const Nd4jLong *alphaShapeInfo, const void *vdLdO, + const Nd4jLong *dLdOShapeInfo, void *vdLdI, const Nd4jLong *dLdIShapeInfo, + void *vdLdA, const Nd4jLong *dLdAShapeInfo) { + preluBPCuda<<>>( + vIn, inShapeInfo, vAlpha, alphaShapeInfo, vdLdO, dLdOShapeInfo, vdLdI, + dLdIShapeInfo, vdLdA, dLdAShapeInfo); } ////////////////////////////////////////////////////////////////////////// -void preluBP(sd::LaunchContext* context, const NDArray& input, const NDArray& alpha, const NDArray& dLdO, NDArray& dLdI, NDArray& dLdA) { - dLdA.nullify(); - - PointersManager manager(context, "preluBP"); - - const int threadsPerBlock = 256; - const int blocksPerGrid = 512; - const int sharedMem = 512; - - const auto xType = input.dataType(); - const auto zType = alpha.dataType(); - - NDArray::prepareSpecialUse({&dLdI, &dLdA}, {&input, &alpha, &dLdO}); - BUILD_SINGLE_SELECTOR_TWICE(xType, preluBPCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), input.specialBuffer(), input.specialShapeInfo(), alpha.specialBuffer(), alpha.specialShapeInfo(), dLdO.specialBuffer(), dLdO.specialShapeInfo(), dLdI.specialBuffer(), dLdI.specialShapeInfo(), dLdA.specialBuffer(), dLdA.specialShapeInfo()), FLOAT_TYPES); - NDArray::registerSpecialUse({&dLdI, &dLdA}, {&input, &alpha, &dLdO}); - - manager.synchronize(); +void preluBP(sd::LaunchContext *context, const NDArray &input, + const NDArray &alpha, const NDArray &dLdO, NDArray &dLdI, + NDArray &dLdA) { + dLdA.nullify(); + + PointersManager manager(context, "preluBP"); + + const int threadsPerBlock = 256; + const int blocksPerGrid = 512; + const int sharedMem = 512; + + const auto xType = input.dataType(); + const auto zType = alpha.dataType(); + + NDArray::prepareSpecialUse({&dLdI, &dLdA}, {&input, &alpha, &dLdO}); + BUILD_SINGLE_SELECTOR_TWICE( + xType, preluBPCudaLauncher, + (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), + input.specialBuffer(), input.specialShapeInfo(), alpha.specialBuffer(), + alpha.specialShapeInfo(), dLdO.specialBuffer(), dLdO.specialShapeInfo(), + dLdI.specialBuffer(), dLdI.specialShapeInfo(), dLdA.specialBuffer(), + dLdA.specialShapeInfo()), + FLOAT_TYPES); + NDArray::registerSpecialUse({&dLdI, &dLdA}, {&input, &alpha, &dLdO}); + + manager.synchronize(); } - /////////////////////////////////////////////////////////////////// -template -__device__ void softMaxForVectorCuda(const void *vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo) { - - // logic of this kernel is based on assumption gridDim = 1 - - const auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - - __shared__ Nd4jLong len; - __shared__ int numOfIters; - __shared__ T* shmem; - - if (threadIdx.x == 0) { - extern __shared__ char shared[]; - shmem = reinterpret_cast(shared); - len = shape::length(xShapeInfo); - numOfIters = (len + blockDim.x - 1) / blockDim.x; // ceil (len / blockDim.x) - } - __syncthreads(); - - T temp = -DataTypeUtils::max(); // set start value to compare with at first iteration, FIXME: what if T is unsigned ?? - - // ************ evaluate max element in input array x ************ // - for (int i = 0; i < numOfIters; ++i) { - - const Nd4jLong elemIdx = i * blockDim.x + threadIdx.x; - if(elemIdx < len) { - const Nd4jLong xOffset = shape::getIndexOffset(elemIdx, xShapeInfo); - shmem[threadIdx.x] = (threadIdx.x != 0) ? x[xOffset] : sd::math::nd4j_max(x[xOffset], temp); // take into account max element evaluated on previous iteration and stored in temp - } - else - shmem[threadIdx.x] = -DataTypeUtils::max(); // FIXME: what if T is unsigned ?? - - __syncthreads(); - - for (int s = blockDim.x / 2; s > 0; s /= 2) { - if(threadIdx.x < s) - shmem[threadIdx.x] = sd::math::nd4j_max(shmem[threadIdx.x], shmem[threadIdx.x + s]); - __syncthreads(); - } - - temp = shmem[0]; // save max value calculated at current iteration - } - - const T max = temp; - temp = 0; - - // ************ evaluate value of exp(x[offset] - max) per each element, store it to shared memory shmem ************ // - // at the same evaluate sum of exponents, sum will be stored in shmem[0] - for (int i = 0; i < numOfIters; ++i) { - - const Nd4jLong elemIdx = i * blockDim.x + threadIdx.x; - if(elemIdx < len) { - const Nd4jLong xOffset = shape::getIndexOffset(elemIdx, xShapeInfo); - const Nd4jLong zOffset = shape::getIndexOffset(elemIdx, zShapeInfo); - z[zOffset] = sd::math::nd4j_exp(x[xOffset] - max); - shmem[threadIdx.x] = (threadIdx.x != 0) ? z[zOffset] : (z[zOffset] + temp); // take into account sum element evaluated on previous iteration and stored in temp - } - else - shmem[threadIdx.x] = 0; - - __syncthreads(); - - for (int s = blockDim.x / 2; s > 0; s /= 2) { - if(threadIdx.x < s) - shmem[threadIdx.x] += shmem[threadIdx.x + s]; - __syncthreads(); - } - - temp = shmem[0]; // save sum calculated at current iteration - } - - // ************ evaluate z[offset] / sum ************ // - for (int i = 0; i < numOfIters; ++i) { - const Nd4jLong elemIdx = i * blockDim.x + threadIdx.x; - if(elemIdx >= len) continue; - const Nd4jLong zOffset = shape::getIndexOffset(elemIdx, zShapeInfo); - z[zOffset] /= shmem[0]; - } +template +__device__ void softMaxForVectorCuda(const void *vx, const Nd4jLong *xShapeInfo, + void *vz, const Nd4jLong *zShapeInfo) { + // logic of this kernel is based on assumption gridDim = 1 + + const auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + + __shared__ Nd4jLong len; + __shared__ int numOfIters; + __shared__ T *shmem; + + if (threadIdx.x == 0) { + extern __shared__ char shared[]; + shmem = reinterpret_cast(shared); + len = shape::length(xShapeInfo); + numOfIters = + (len + blockDim.x - 1) / blockDim.x; // ceil (len / blockDim.x) + } + __syncthreads(); + + T temp = + -DataTypeUtils::max(); // set start value to compare with at first + // iteration, FIXME: what if T is unsigned ?? + + // ************ evaluate max element in input array x ************ // + for (int i = 0; i < numOfIters; ++i) { + const Nd4jLong elemIdx = i * blockDim.x + threadIdx.x; + if (elemIdx < len) { + const Nd4jLong xOffset = shape::getIndexOffset(elemIdx, xShapeInfo); + shmem[threadIdx.x] = + (threadIdx.x != 0) + ? x[xOffset] + : sd::math::nd4j_max( + x[xOffset], + temp); // take into account max element evaluated on + // previous iteration and stored in temp + } else + shmem[threadIdx.x] = + -DataTypeUtils::max(); // FIXME: what if T is unsigned ?? + + __syncthreads(); + + for (int s = blockDim.x / 2; s > 0; s /= 2) { + if (threadIdx.x < s) + shmem[threadIdx.x] = + sd::math::nd4j_max(shmem[threadIdx.x], shmem[threadIdx.x + s]); + __syncthreads(); + } + + temp = shmem[0]; // save max value calculated at current iteration + } + + const T max = temp; + temp = 0; + + // ************ evaluate value of exp(x[offset] - max) per each element, store + // it to shared memory shmem ************ // at the same evaluate sum of + // exponents, sum will be stored in shmem[0] + for (int i = 0; i < numOfIters; ++i) { + const Nd4jLong elemIdx = i * blockDim.x + threadIdx.x; + if (elemIdx < len) { + const Nd4jLong xOffset = shape::getIndexOffset(elemIdx, xShapeInfo); + const Nd4jLong zOffset = shape::getIndexOffset(elemIdx, zShapeInfo); + z[zOffset] = sd::math::nd4j_exp(x[xOffset] - max); + shmem[threadIdx.x] = + (threadIdx.x != 0) + ? z[zOffset] + : (z[zOffset] + + temp); // take into account sum element evaluated on previous + // iteration and stored in temp + } else + shmem[threadIdx.x] = 0; + + __syncthreads(); + + for (int s = blockDim.x / 2; s > 0; s /= 2) { + if (threadIdx.x < s) shmem[threadIdx.x] += shmem[threadIdx.x + s]; + __syncthreads(); + } + + temp = shmem[0]; // save sum calculated at current iteration + } + + // ************ evaluate z[offset] / sum ************ // + for (int i = 0; i < numOfIters; ++i) { + const Nd4jLong elemIdx = i * blockDim.x + threadIdx.x; + if (elemIdx >= len) continue; + const Nd4jLong zOffset = shape::getIndexOffset(elemIdx, zShapeInfo); + z[zOffset] /= shmem[0]; + } } -template -__global__ void softMaxForVectorCudaGlobal(const void *vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo) { - - softMaxForVectorCuda(vx, xShapeInfo, vz, zShapeInfo); +template +__global__ void softMaxForVectorCudaGlobal(const void *vx, + const Nd4jLong *xShapeInfo, void *vz, + const Nd4jLong *zShapeInfo) { + softMaxForVectorCuda(vx, xShapeInfo, vz, zShapeInfo); } /////////////////////////////////////////////////////////////////// template -linkage void softMaxForVectorCudaLauncher(const cudaStream_t* stream, const void *vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo) { - - softMaxForVectorCudaGlobal<<<1, MAX_NUM_THREADS / 4 , (MAX_NUM_THREADS / 4) * sizeof(T) + 512, *stream>>>(vx, xShapeInfo, vz, zShapeInfo); +linkage void softMaxForVectorCudaLauncher(const cudaStream_t *stream, + const void *vx, + const Nd4jLong *xShapeInfo, void *vz, + const Nd4jLong *zShapeInfo) { + softMaxForVectorCudaGlobal + <<<1, MAX_NUM_THREADS / 4, (MAX_NUM_THREADS / 4) * sizeof(T) + 512, + *stream>>>(vx, xShapeInfo, vz, zShapeInfo); } /////////////////////////////////////////////////////////////////// -template -__global__ static void softMaxCuda(const void* vx, const Nd4jLong *xTadShapeInfo, const Nd4jLong *xOffsets, - void* vz, const Nd4jLong *zTadShapeInfo, const Nd4jLong *zOffsets) { - - const auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - - const auto* xTad = x + xOffsets[blockIdx.x]; - auto* zTad = z + zOffsets[blockIdx.x]; - - softMaxForVectorCuda(xTad, xTadShapeInfo, zTad, zTadShapeInfo); +template +__global__ static void softMaxCuda(const void *vx, + const Nd4jLong *xTadShapeInfo, + const Nd4jLong *xOffsets, void *vz, + const Nd4jLong *zTadShapeInfo, + const Nd4jLong *zOffsets) { + const auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + + const auto *xTad = x + xOffsets[blockIdx.x]; + auto *zTad = z + zOffsets[blockIdx.x]; + + softMaxForVectorCuda(xTad, xTadShapeInfo, zTad, zTadShapeInfo); } /////////////////////////////////////////////////////////////////// -template -static void softMaxCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, - const void* vx, const Nd4jLong *xTadShapeInfo, const Nd4jLong *xOffsets, - void* vz, const Nd4jLong *zTadShapeInfo, const Nd4jLong *zOffsets) { - - softMaxCuda<<>>(vx, xTadShapeInfo, xOffsets, vz, zTadShapeInfo, zOffsets); +template +static void softMaxCudaLauncher(const int blocksPerGrid, + const int threadsPerBlock, const int sharedMem, + const cudaStream_t *stream, const void *vx, + const Nd4jLong *xTadShapeInfo, + const Nd4jLong *xOffsets, void *vz, + const Nd4jLong *zTadShapeInfo, + const Nd4jLong *zOffsets) { + softMaxCuda<<>>( + vx, xTadShapeInfo, xOffsets, vz, zTadShapeInfo, zOffsets); } - ////////////////////////////////////////////////////////////////////////// -void softmax(sd::LaunchContext * context, const NDArray& input, NDArray& output, const int dimension) { - - if(!input.isActualOnDeviceSide()) input.syncToDevice(); - const int rank = input.rankOf(); - - PointersManager manager(context, "helpers::softmax"); - - if(input.isVector()) { - - if(rank == 1 || input.sizeAt(dimension) != 1) { - NDArray::prepareSpecialUse({&output}, {&input}); - BUILD_SINGLE_SELECTOR(input.dataType(), softMaxForVectorCudaLauncher, (context->getCudaStream(), input.specialBuffer(), input.specialShapeInfo(), output.specialBuffer(), output.specialShapeInfo()), FLOAT_TYPES); - NDArray::registerSpecialUse({&output}, {&input}); - } - else - output = 1.; - } - else { - - auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(input.shapeInfo(), {dimension}); - auto packZ = sd::ConstantTadHelper::getInstance().tadForDimensions(output.shapeInfo(), {dimension}); - - const int threadsPerBlock = MAX_NUM_THREADS / 4; - const int blocksPerGrid = packZ.numberOfTads(); - const int sharedMem = input.sizeOfT() * threadsPerBlock + 512; - - NDArray::prepareSpecialUse({&output}, {&input}); - BUILD_SINGLE_SELECTOR(input.dataType(), softMaxCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), input.specialBuffer(), packX.specialShapeInfo(), packX.specialOffsets(), output.specialBuffer(), packZ.specialShapeInfo(), packZ.specialOffsets()), FLOAT_TYPES); - NDArray::registerSpecialUse({&output}, {&input}); - - // auto maxAlongDim = const_cast(input).reduceAlongDimension(reduce::Max, {dimension}, true); - // (input - maxAlongDim).applyTransform(transform::Exp, &output); // output contains exponents temporarily - // auto sumAlongDim = output.reduceAlongDimension(reduce::Sum, {dimension}, true); - // output /= sumAlongDim; - // input.tickReadDevice(); - } - - - manager.synchronize(); - - output.tickWriteDevice(); +void softmax(sd::LaunchContext *context, const NDArray &input, NDArray &output, + const int dimension) { + if (!input.isActualOnDeviceSide()) input.syncToDevice(); + const int rank = input.rankOf(); + + PointersManager manager(context, "helpers::softmax"); + + if (input.isVector()) { + if (rank == 1 || input.sizeAt(dimension) != 1) { + NDArray::prepareSpecialUse({&output}, {&input}); + BUILD_SINGLE_SELECTOR(input.dataType(), softMaxForVectorCudaLauncher, + (context->getCudaStream(), input.specialBuffer(), + input.specialShapeInfo(), output.specialBuffer(), + output.specialShapeInfo()), + FLOAT_TYPES); + NDArray::registerSpecialUse({&output}, {&input}); + } else + output = 1.; + } else { + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions( + input.shapeInfo(), {dimension}); + auto packZ = sd::ConstantTadHelper::getInstance().tadForDimensions( + output.shapeInfo(), {dimension}); + + const int threadsPerBlock = MAX_NUM_THREADS / 4; + const int blocksPerGrid = packZ.numberOfTads(); + const int sharedMem = input.sizeOfT() * threadsPerBlock + 512; + + NDArray::prepareSpecialUse({&output}, {&input}); + BUILD_SINGLE_SELECTOR( + input.dataType(), softMaxCudaLauncher, + (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), + input.specialBuffer(), packX.specialShapeInfo(), + packX.specialOffsets(), output.specialBuffer(), + packZ.specialShapeInfo(), packZ.specialOffsets()), + FLOAT_TYPES); + NDArray::registerSpecialUse({&output}, {&input}); + + // auto maxAlongDim = + // const_cast(input).reduceAlongDimension(reduce::Max, + // {dimension}, true); (input - maxAlongDim).applyTransform(transform::Exp, + // &output); // output contains exponents temporarily auto sumAlongDim = + // output.reduceAlongDimension(reduce::Sum, {dimension}, true); output /= + // sumAlongDim; input.tickReadDevice(); + } + + manager.synchronize(); + + output.tickWriteDevice(); } /////////////////////////////////////////////////////////////////// -template -__global__ void logSoftMaxForVectorCuda(const void *vx, const Nd4jLong *xzShapeInfo, void *vz) { - - // logic of this kernel is based on assumption gridDim = 1 - - const auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - - __shared__ Nd4jLong len; - __shared__ int numOfIters; - __shared__ T* shmem; - - if (threadIdx.x == 0) { - extern __shared__ char shared[]; - shmem = reinterpret_cast(shared); - len = shape::length(xzShapeInfo); - numOfIters = (len + blockDim.x - 1) / blockDim.x; // ceil (len / blockDim.x) - } - __syncthreads(); - - T temp = -DataTypeUtils::max(); // set start value to compare with at first iteration, FIXME: what if T is unsigned ?? - - // ************ evaluate max element in input array x ************ // - for (int i = 0; i < numOfIters; ++i) { - - const Nd4jLong elemIdx = i * blockDim.x + threadIdx.x; - if(elemIdx < len) { - const Nd4jLong offset = shape::getIndexOffset(elemIdx, xzShapeInfo); - shmem[threadIdx.x] = (threadIdx.x != 0) ? x[offset] : sd::math::nd4j_max(x[offset], temp); // take into account max element evaluated on previous iteration and stored in temp - } - else - shmem[threadIdx.x] = -DataTypeUtils::max(); // FIXME: what if T is unsigned ?? - - __syncthreads(); - - for (int s = blockDim.x / 2; s > 0; s /= 2) { - if(threadIdx.x < s) - shmem[threadIdx.x] = sd::math::nd4j_max(shmem[threadIdx.x], shmem[threadIdx.x + s]); - __syncthreads(); - } - - temp = shmem[0]; // save max value calculated at current iteration - } - - const T max = temp; - temp = 0; - - // ************ evaluate value of exp(x[offset] - max) per each element, store it to shared memory shmem ************ // - // at the same time evaluate sum of exponents, sum will be stored in shmem[0] - for (int i = 0; i < numOfIters; ++i) { - - const Nd4jLong elemIdx = i * blockDim.x + threadIdx.x; - if(elemIdx < len) { - const Nd4jLong offset = shape::getIndexOffset(elemIdx, xzShapeInfo); - z[offset] = sd::math::nd4j_exp(x[offset] - max); - shmem[threadIdx.x] = (threadIdx.x != 0) ? z[offset] : (z[offset] + temp); // take into account sum element evaluated on previous iteration and stored in temp - } - else - shmem[threadIdx.x] = 0; - - __syncthreads(); - - for (int s = blockDim.x / 2; s > 0; s /= 2) { - if(threadIdx.x < s) - shmem[threadIdx.x] += shmem[threadIdx.x + s]; - __syncthreads(); - } - - temp = shmem[0]; // save sum calculated at current iteration - } - - // ************ evaluate log(z[offset] / sum) ************ // - for (int i = 0; i < numOfIters; ++i) { - const Nd4jLong elemIdx = i * blockDim.x + threadIdx.x; - if(elemIdx >= len) continue; - const Nd4jLong offset = shape::getIndexOffset(elemIdx, xzShapeInfo); - z[offset] = sd::math::nd4j_log(z[offset] / shmem[0]); - } +template +__global__ void logSoftMaxForVectorCuda(const void *vx, + const Nd4jLong *xzShapeInfo, void *vz) { + // logic of this kernel is based on assumption gridDim = 1 + + const auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + + __shared__ Nd4jLong len; + __shared__ int numOfIters; + __shared__ T *shmem; + + if (threadIdx.x == 0) { + extern __shared__ char shared[]; + shmem = reinterpret_cast(shared); + len = shape::length(xzShapeInfo); + numOfIters = + (len + blockDim.x - 1) / blockDim.x; // ceil (len / blockDim.x) + } + __syncthreads(); + + T temp = + -DataTypeUtils::max(); // set start value to compare with at first + // iteration, FIXME: what if T is unsigned ?? + + // ************ evaluate max element in input array x ************ // + for (int i = 0; i < numOfIters; ++i) { + const Nd4jLong elemIdx = i * blockDim.x + threadIdx.x; + if (elemIdx < len) { + const Nd4jLong offset = shape::getIndexOffset(elemIdx, xzShapeInfo); + shmem[threadIdx.x] = + (threadIdx.x != 0) + ? x[offset] + : sd::math::nd4j_max( + x[offset], + temp); // take into account max element evaluated on + // previous iteration and stored in temp + } else + shmem[threadIdx.x] = + -DataTypeUtils::max(); // FIXME: what if T is unsigned ?? + + __syncthreads(); + + for (int s = blockDim.x / 2; s > 0; s /= 2) { + if (threadIdx.x < s) + shmem[threadIdx.x] = + sd::math::nd4j_max(shmem[threadIdx.x], shmem[threadIdx.x + s]); + __syncthreads(); + } + + temp = shmem[0]; // save max value calculated at current iteration + } + + const T max = temp; + temp = 0; + + // ************ evaluate value of exp(x[offset] - max) per each element, store + // it to shared memory shmem ************ // at the same time evaluate sum of + // exponents, sum will be stored in shmem[0] + for (int i = 0; i < numOfIters; ++i) { + const Nd4jLong elemIdx = i * blockDim.x + threadIdx.x; + if (elemIdx < len) { + const Nd4jLong offset = shape::getIndexOffset(elemIdx, xzShapeInfo); + z[offset] = sd::math::nd4j_exp(x[offset] - max); + shmem[threadIdx.x] = + (threadIdx.x != 0) + ? z[offset] + : (z[offset] + temp); // take into account sum element evaluated + // on previous iteration and stored in temp + } else + shmem[threadIdx.x] = 0; + + __syncthreads(); + + for (int s = blockDim.x / 2; s > 0; s /= 2) { + if (threadIdx.x < s) shmem[threadIdx.x] += shmem[threadIdx.x + s]; + __syncthreads(); + } + + temp = shmem[0]; // save sum calculated at current iteration + } + + // ************ evaluate log(z[offset] / sum) ************ // + for (int i = 0; i < numOfIters; ++i) { + const Nd4jLong elemIdx = i * blockDim.x + threadIdx.x; + if (elemIdx >= len) continue; + const Nd4jLong offset = shape::getIndexOffset(elemIdx, xzShapeInfo); + z[offset] = sd::math::nd4j_log(z[offset] / shmem[0]); + } } /////////////////////////////////////////////////////////////////// template -linkage void logSoftMaxForVectorCudaLauncher(const cudaStream_t* stream, const void *vx, const Nd4jLong *xzShapeInfo, void *vz) { - - logSoftMaxForVectorCuda<<<1, MAX_NUM_THREADS, MAX_NUM_THREADS * sizeof(T) + 512, *stream>>>(vx, xzShapeInfo, vz); +linkage void logSoftMaxForVectorCudaLauncher(const cudaStream_t *stream, + const void *vx, + const Nd4jLong *xzShapeInfo, + void *vz) { + logSoftMaxForVectorCuda + <<<1, MAX_NUM_THREADS, MAX_NUM_THREADS * sizeof(T) + 512, *stream>>>( + vx, xzShapeInfo, vz); } ////////////////////////////////////////////////////////////////////////// -void logSoftmax(sd::LaunchContext * context, const NDArray& input, NDArray& output, const int dimension) { - - if(!input.isActualOnDeviceSide()) input.syncToDevice(); - const int rank = input.rankOf(); - - if(input.isVector()) { - - if(rank == 1 || input.sizeAt(dimension) != 1) { - BUILD_SINGLE_SELECTOR(input.dataType(), logSoftMaxForVectorCudaLauncher, (context->getCudaStream(), input.specialBuffer(), input.specialShapeInfo(), output.specialBuffer()), FLOAT_TYPES); - input.tickReadDevice(); - } - else - output = 0.; - } - else { - - auto maxAlongDim = const_cast(input).reduceAlongDimension(reduce::Max, {dimension}, true); - (input - maxAlongDim).applyTransform(transform::Exp, output); // output contains exponents temporarily - auto sumAlongDim = output.reduceAlongDimension(reduce::Sum, {dimension}, true); - output /= sumAlongDim; - output.applyTransform(transform::Log, output); - input.tickReadDevice(); - } - - PointersManager manager(context, "helpers::logSoftmax"); - manager.synchronize(); - - output.tickWriteDevice(); +void logSoftmax(sd::LaunchContext *context, const NDArray &input, + NDArray &output, const int dimension) { + if (!input.isActualOnDeviceSide()) input.syncToDevice(); + const int rank = input.rankOf(); + + if (input.isVector()) { + if (rank == 1 || input.sizeAt(dimension) != 1) { + BUILD_SINGLE_SELECTOR(input.dataType(), logSoftMaxForVectorCudaLauncher, + (context->getCudaStream(), input.specialBuffer(), + input.specialShapeInfo(), output.specialBuffer()), + FLOAT_TYPES); + input.tickReadDevice(); + } else + output = 0.; + } else { + auto maxAlongDim = const_cast(input).reduceAlongDimension( + reduce::Max, {dimension}, true); + (input - maxAlongDim) + .applyTransform(transform::Exp, + output); // output contains exponents temporarily + auto sumAlongDim = + output.reduceAlongDimension(reduce::Sum, {dimension}, true); + output /= sumAlongDim; + output.applyTransform(transform::Log, output); + input.tickReadDevice(); + } + + PointersManager manager(context, "helpers::logSoftmax"); + manager.synchronize(); + + output.tickWriteDevice(); } /////////////////////////////////////////////////////////////////// -template -__global__ linkage void softMaxDerivForVectorCuda(const void *vx, const Nd4jLong *xzShapeInfo, void *vz) { - - // logic of this kernel is based on assumption gridDim = 1 - - const auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - - __shared__ Nd4jLong len; - __shared__ int numOfIters; - __shared__ T* shmem; - - if (threadIdx.x == 0) { - extern __shared__ char shared[]; - shmem = reinterpret_cast(shared); - len = shape::length(xzShapeInfo); - numOfIters = (len + blockDim.x - 1) / blockDim.x; // ceil (len / blockDim.x) - } - __syncthreads(); - - T temp = -DataTypeUtils::max(); // set start value to compare with at first iteration, FIXME: what if T is unsigned ?? - - // ************ evaluate max element in input array x ************ // - for (int i = 0; i < numOfIters; ++i) { - - const Nd4jLong elemIdx = i * blockDim.x + threadIdx.x; - if(elemIdx < len) { - const Nd4jLong offset = shape::getIndexOffset(elemIdx, xzShapeInfo); - shmem[threadIdx.x] = (threadIdx.x != 0) ? x[offset] : sd::math::nd4j_max(x[offset], temp); // take into account max element evaluated on previous iteration and stored in temp - } - else - shmem[threadIdx.x] = -DataTypeUtils::max(); // FIXME: what if T is unsigned ?? - - __syncthreads(); - - for (int s = blockDim.x / 2; s > 0; s /= 2) { - if(threadIdx.x < s) - shmem[threadIdx.x] = sd::math::nd4j_max(shmem[threadIdx.x], shmem[threadIdx.x + s]); - __syncthreads(); - } - - temp = shmem[0]; // save max value calculated at current iteration - } - - const T max = temp; - temp = 0; - - // ************ evaluate value of exp(x[offset] - max) per each element, store it to shared memory shmem ************ // - // at the same evaluate sum of exponents, sum will be stored in shmem[0] - for (int i = 0; i < numOfIters; ++i) { - - const Nd4jLong elemIdx = i * blockDim.x + threadIdx.x; - if(elemIdx < len) { - const Nd4jLong offset = shape::getIndexOffset(elemIdx, xzShapeInfo); - z[offset] = sd::math::nd4j_exp(x[offset] - max); - shmem[threadIdx.x] = (threadIdx.x != 0) ? z[offset] : (z[offset] + temp); // take into account sum element evaluated on previous iteration and stored in temp - } - else - shmem[threadIdx.x] = 0; - - __syncthreads(); - - for (int s = blockDim.x / 2; s > 0; s /= 2) { - if(threadIdx.x < s) - shmem[threadIdx.x] += shmem[threadIdx.x + s]; - __syncthreads(); - } - - temp = shmem[0]; // save sum calculated at current iteration - } - - // ************ evaluate (z[offset] / sum) and derivative z[offset] = z[offset] * (1 - z[offset]) ************ // - for (int i = 0; i < numOfIters; ++i) { - const Nd4jLong elemIdx = i * blockDim.x + threadIdx.x; - if(elemIdx >= len) continue; - const Nd4jLong offset = shape::getIndexOffset(elemIdx, xzShapeInfo); - z[offset] /= shmem[0]; - z[offset] *= (1.f - z[offset]); // derivative - } +template +__global__ linkage void softMaxDerivForVectorCuda(const void *vx, + const Nd4jLong *xzShapeInfo, + void *vz) { + // logic of this kernel is based on assumption gridDim = 1 + + const auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + + __shared__ Nd4jLong len; + __shared__ int numOfIters; + __shared__ T *shmem; + + if (threadIdx.x == 0) { + extern __shared__ char shared[]; + shmem = reinterpret_cast(shared); + len = shape::length(xzShapeInfo); + numOfIters = + (len + blockDim.x - 1) / blockDim.x; // ceil (len / blockDim.x) + } + __syncthreads(); + + T temp = + -DataTypeUtils::max(); // set start value to compare with at first + // iteration, FIXME: what if T is unsigned ?? + + // ************ evaluate max element in input array x ************ // + for (int i = 0; i < numOfIters; ++i) { + const Nd4jLong elemIdx = i * blockDim.x + threadIdx.x; + if (elemIdx < len) { + const Nd4jLong offset = shape::getIndexOffset(elemIdx, xzShapeInfo); + shmem[threadIdx.x] = + (threadIdx.x != 0) + ? x[offset] + : sd::math::nd4j_max( + x[offset], + temp); // take into account max element evaluated on + // previous iteration and stored in temp + } else + shmem[threadIdx.x] = + -DataTypeUtils::max(); // FIXME: what if T is unsigned ?? + + __syncthreads(); + + for (int s = blockDim.x / 2; s > 0; s /= 2) { + if (threadIdx.x < s) + shmem[threadIdx.x] = + sd::math::nd4j_max(shmem[threadIdx.x], shmem[threadIdx.x + s]); + __syncthreads(); + } + + temp = shmem[0]; // save max value calculated at current iteration + } + + const T max = temp; + temp = 0; + + // ************ evaluate value of exp(x[offset] - max) per each element, store + // it to shared memory shmem ************ // at the same evaluate sum of + // exponents, sum will be stored in shmem[0] + for (int i = 0; i < numOfIters; ++i) { + const Nd4jLong elemIdx = i * blockDim.x + threadIdx.x; + if (elemIdx < len) { + const Nd4jLong offset = shape::getIndexOffset(elemIdx, xzShapeInfo); + z[offset] = sd::math::nd4j_exp(x[offset] - max); + shmem[threadIdx.x] = + (threadIdx.x != 0) + ? z[offset] + : (z[offset] + temp); // take into account sum element evaluated + // on previous iteration and stored in temp + } else + shmem[threadIdx.x] = 0; + + __syncthreads(); + + for (int s = blockDim.x / 2; s > 0; s /= 2) { + if (threadIdx.x < s) shmem[threadIdx.x] += shmem[threadIdx.x + s]; + __syncthreads(); + } + + temp = shmem[0]; // save sum calculated at current iteration + } + + // ************ evaluate (z[offset] / sum) and derivative z[offset] = + // z[offset] * (1 - z[offset]) ************ // + for (int i = 0; i < numOfIters; ++i) { + const Nd4jLong elemIdx = i * blockDim.x + threadIdx.x; + if (elemIdx >= len) continue; + const Nd4jLong offset = shape::getIndexOffset(elemIdx, xzShapeInfo); + z[offset] /= shmem[0]; + z[offset] *= (1.f - z[offset]); // derivative + } } /////////////////////////////////////////////////////////////////// template -linkage void softMaxDerivForVectorCudaLauncher(const cudaStream_t* stream, const void *vx, const Nd4jLong *xzShapeInfo, void *vz) { - - softMaxDerivForVectorCuda<<<1, MAX_NUM_THREADS, MAX_NUM_THREADS * sizeof(T) + 512, *stream>>>(vx, xzShapeInfo, vz); +linkage void softMaxDerivForVectorCudaLauncher(const cudaStream_t *stream, + const void *vx, + const Nd4jLong *xzShapeInfo, + void *vz) { + softMaxDerivForVectorCuda + <<<1, MAX_NUM_THREADS, MAX_NUM_THREADS * sizeof(T) + 512, *stream>>>( + vx, xzShapeInfo, vz); } /////////////////////////////////////////////////////////////////// -void softmaxDerivative(sd::LaunchContext * context, const NDArray& input, NDArray& output, const int dimension) { - - if(!input.isActualOnDeviceSide()) input.syncToDevice(); - const int rank = input.rankOf(); - int temp; - - if(shape::isCommonVector(input.shapeInfo(), temp)) { - - BUILD_SINGLE_SELECTOR(input.dataType(), softMaxDerivForVectorCudaLauncher, (context->getCudaStream(), input.specialBuffer(), input.specialShapeInfo(), output.specialBuffer()), FLOAT_TYPES); - input.tickReadDevice(); - } - else { - - auto maxAlongDim = const_cast(input).reduceAlongDimension(reduce::Max, {dimension}, true); - (input - maxAlongDim).applyTransform(transform::Exp, output); // output contains exponents temporarily - auto sumAlongDim = output.reduceAlongDimension(reduce::Sum, {dimension}, true); - output /= sumAlongDim; - output *= (1.f - output); // derivative - input.tickReadDevice(); - } - - PointersManager manager(context, "helpers::softmaxDerivative"); - manager.synchronize(); - - output.tickWriteDevice(); +void softmaxDerivative(sd::LaunchContext *context, const NDArray &input, + NDArray &output, const int dimension) { + if (!input.isActualOnDeviceSide()) input.syncToDevice(); + const int rank = input.rankOf(); + int temp; + + if (shape::isCommonVector(input.shapeInfo(), temp)) { + BUILD_SINGLE_SELECTOR(input.dataType(), softMaxDerivForVectorCudaLauncher, + (context->getCudaStream(), input.specialBuffer(), + input.specialShapeInfo(), output.specialBuffer()), + FLOAT_TYPES); + input.tickReadDevice(); + } else { + auto maxAlongDim = const_cast(input).reduceAlongDimension( + reduce::Max, {dimension}, true); + (input - maxAlongDim) + .applyTransform(transform::Exp, + output); // output contains exponents temporarily + auto sumAlongDim = + output.reduceAlongDimension(reduce::Sum, {dimension}, true); + output /= sumAlongDim; + output *= (1.f - output); // derivative + input.tickReadDevice(); + } + + PointersManager manager(context, "helpers::softmaxDerivative"); + manager.synchronize(); + + output.tickWriteDevice(); } +template +linkage void thresholdRelu_(NDArray const &input, double threshold, + NDArray &output) { + auto routine = LAMBDA_T(_x, threshold) { + return _x > (T)threshold ? _x : (T)0.f; + }; + const_cast(input).applyLambda(routine, output); +} - template - linkage void thresholdRelu_(NDArray const& input, double threshold, NDArray& output) { - auto routine = LAMBDA_T(_x, threshold) { - return _x > (T)threshold ? _x: (T)0.f; - }; - const_cast(input).applyLambda(routine, output); - } - - void thresholdRelu(sd::LaunchContext * context, NDArray const& input, double threshold, NDArray& output) { - BUILD_SINGLE_SELECTOR(input.dataType(), thresholdRelu_, (input, threshold, output), FLOAT_TYPES); - } - - template - linkage void thresholdReluDerivative_(NDArray* input, double theta, NDArray* dLdO, NDArray* output) { - auto derivative = LAMBDA_TT(_x, grO, theta) {if (_x > theta) return grO; else return static_cast(0); }; - - input->applyPairwiseLambda(*dLdO, derivative, *output); - } - - void thresholdReluDerivative(sd::LaunchContext * context, NDArray* input, double threshold, NDArray* dLdO, NDArray* output) { - BUILD_SINGLE_SELECTOR(input->dataType(), thresholdReluDerivative_, (input, threshold, dLdO, output), FLOAT_TYPES); - } - +void thresholdRelu(sd::LaunchContext *context, NDArray const &input, + double threshold, NDArray &output) { + BUILD_SINGLE_SELECTOR(input.dataType(), thresholdRelu_, + (input, threshold, output), FLOAT_TYPES); } + +template +linkage void thresholdReluDerivative_(NDArray *input, double theta, + NDArray *dLdO, NDArray *output) { + auto derivative = LAMBDA_TT(_x, grO, theta) { + if (_x > theta) + return grO; + else + return static_cast(0); + }; + + input->applyPairwiseLambda(*dLdO, derivative, *output); } + +void thresholdReluDerivative(sd::LaunchContext *context, NDArray *input, + double threshold, NDArray *dLdO, NDArray *output) { + BUILD_SINGLE_SELECTOR(input->dataType(), thresholdReluDerivative_, + (input, threshold, dLdO, output), FLOAT_TYPES); } +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/addBias.cu b/libnd4j/include/ops/declarable/helpers/cuda/addBias.cu index 18474f2c79ba..e646ef0f6440 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/addBias.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/addBias.cu @@ -18,131 +18,139 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // - -#include #include +#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { ////////////////////////////////////////////////////////////////////// -template -__global__ static void addBiasCuda( const void* vx, const Nd4jLong* xShapeInfo, - const void* vy, const Nd4jLong* yShapeInfo, - void* vz, const Nd4jLong* zShapeInfo, - const bool isNCHW) { - - // bias [oC] - - // if(input_rank == 4) - // input and output have same shapes: [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) - // if(input_rank == 5) - // input and output have same shapes: [bS, oD, oH, oW, oC] (NHWC) or [bS, oD, oC, oH, oW] (NCHW) - - const X* x = reinterpret_cast(vx); - const Y* y = reinterpret_cast(vy); - X* z = reinterpret_cast(vz); - - __shared__ int rank, channelPosition, posOfNonUnityDim; - __shared__ Nd4jLong len, *sharedMem; - __shared__ bool xzSameOffsets, xzAreSame; - - if (threadIdx.x == 0) { - - extern __shared__ unsigned char shmem[]; - sharedMem = reinterpret_cast(shmem); - - rank = shape::rank(xShapeInfo); // xRank == zRank - xzSameOffsets = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo); - len = shape::length(xShapeInfo); - channelPosition = isNCHW ? 1 : rank - 1; // second or last - xzAreSame = x == z; - - shape::isCommonVector(yShapeInfo, posOfNonUnityDim); - } - __syncthreads(); - - auto coords = sharedMem + threadIdx.x * rank; - - for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < len; i += blockDim.x * gridDim.x) { - - shape::index2coords(i, xShapeInfo, coords); - - const auto xOffsets = shape::getOffset(xShapeInfo, coords); - const auto zOffsets = xzSameOffsets ? xOffsets : shape::getOffset(zShapeInfo, coords); - const auto yOffsets = coords[channelPosition] * shape::stride(yShapeInfo)[posOfNonUnityDim]; - - if(xzAreSame) - z[zOffsets] += static_cast(y[yOffsets]); - else - z[zOffsets] = x[xOffsets] + static_cast(y[yOffsets]); - } +template +__global__ static void addBiasCuda(const void* vx, const Nd4jLong* xShapeInfo, + const void* vy, const Nd4jLong* yShapeInfo, + void* vz, const Nd4jLong* zShapeInfo, + const bool isNCHW) { + // bias [oC] + + // if(input_rank == 4) + // input and output have same shapes: [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, + // oW] (NCHW) + // if(input_rank == 5) + // input and output have same shapes: [bS, oD, oH, oW, oC] (NHWC) or [bS, oD, + // oC, oH, oW] (NCHW) + + const X* x = reinterpret_cast(vx); + const Y* y = reinterpret_cast(vy); + X* z = reinterpret_cast(vz); + + __shared__ int rank, channelPosition, posOfNonUnityDim; + __shared__ Nd4jLong len, *sharedMem; + __shared__ bool xzSameOffsets, xzAreSame; + + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sharedMem = reinterpret_cast(shmem); + + rank = shape::rank(xShapeInfo); // xRank == zRank + xzSameOffsets = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo); + len = shape::length(xShapeInfo); + channelPosition = isNCHW ? 1 : rank - 1; // second or last + xzAreSame = x == z; + + shape::isCommonVector(yShapeInfo, posOfNonUnityDim); + } + __syncthreads(); + + auto coords = sharedMem + threadIdx.x * rank; + + for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < len; + i += blockDim.x * gridDim.x) { + shape::index2coords(i, xShapeInfo, coords); + + const auto xOffsets = shape::getOffset(xShapeInfo, coords); + const auto zOffsets = + xzSameOffsets ? xOffsets : shape::getOffset(zShapeInfo, coords); + const auto yOffsets = + coords[channelPosition] * shape::stride(yShapeInfo)[posOfNonUnityDim]; + + if (xzAreSame) + z[zOffsets] += static_cast(y[yOffsets]); + else + z[zOffsets] = x[xOffsets] + static_cast(y[yOffsets]); + } } ////////////////////////////////////////////////////////////////////////// -template -static void addBiasCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, - const void* vx, const Nd4jLong* xShapeInfo, - const void* vy, const Nd4jLong* yShapeInfo, - void* vz, const Nd4jLong* zShapeInfo, - const bool isNCHW) { - - addBiasCuda<<>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, isNCHW); +template +static void addBiasCudaLauncher(const int blocksPerGrid, + const int threadsPerBlock, const int sharedMem, + const cudaStream_t* stream, const void* vx, + const Nd4jLong* xShapeInfo, const void* vy, + const Nd4jLong* yShapeInfo, void* vz, + const Nd4jLong* zShapeInfo, const bool isNCHW) { + addBiasCuda<<>>( + vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, isNCHW); } -template -__global__ static void addBias2DCuda( const void* vx, - const void* vy, - void* vz, - uint32_t blocks, uint32_t length) { +template +__global__ static void addBias2DCuda(const void* vx, const void* vy, void* vz, + uint32_t blocks, uint32_t length) { + auto y = reinterpret_cast(vy); - auto y = reinterpret_cast(vy); + for (uint32_t b = blockIdx.x; b < blocks; b += gridDim.x) { + auto x = reinterpret_cast(vx) + length * b; + auto z = reinterpret_cast(vz) + length * b; - for (uint32_t b = blockIdx.x; b < blocks; b += gridDim.x) { - auto x = reinterpret_cast(vx) + length * b; - auto z = reinterpret_cast(vz) + length * b; - - for (uint32_t e = threadIdx.x; e < length; e += blockDim.x) { - z[e] = x[e] + y[e]; - } + for (uint32_t e = threadIdx.x; e < length; e += blockDim.x) { + z[e] = x[e] + y[e]; } + } } -template -static void addBias2DCudaLauncher(const cudaStream_t *stream, const void* vx, - const void* vy, - void* vz, - uint32_t blocks, uint32_t length) { - - addBias2DCuda<<<256, 1024, 128, *stream>>>(vx, vy, vz, blocks, length); +template +static void addBias2DCudaLauncher(const cudaStream_t* stream, const void* vx, + const void* vy, void* vz, uint32_t blocks, + uint32_t length) { + addBias2DCuda<<<256, 1024, 128, *stream>>>(vx, vy, vz, blocks, length); } ////////////////////////////////////////////////////////////////////////// -void addBias(sd::graph::Context& block, const NDArray& input, const NDArray& bias, NDArray& output, const bool isNCHW) { - - PointersManager manager(block.launchContext(), "addBias"); - NDArray::prepareSpecialUse({&output}, {&input, &bias}); - - if (input.rankOf() == 2 && bias.rankOf() == 1 && input.ordering() == 'c' && output.ordering() == 'c' && input.ews() == 1 && bias.ews() == 1 && input.sizeAt(1) == bias.sizeAt(0)) { - BUILD_DOUBLE_SELECTOR(input.dataType(), bias.dataType(), addBias2DCudaLauncher, - (block.launchContext()->getCudaStream(), input.specialBuffer(), bias.specialBuffer(), output.specialBuffer(), input.sizeAt(0), bias.sizeAt(0)), - FLOAT_TYPES, FLOAT_TYPES); - } else { - // default case - const int threadsPerBlock = MAX_NUM_THREADS / 4; - const int blocksPerGrid = (input.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - const int sharedMem = input.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128; - - - BUILD_DOUBLE_SELECTOR(input.dataType(), bias.dataType(), addBiasCudaLauncher, - (blocksPerGrid, threadsPerBlock, sharedMem, block.launchContext()->getCudaStream(), input.specialBuffer(), input.specialShapeInfo(), bias.specialBuffer(), bias.specialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), isNCHW), - FLOAT_TYPES, FLOAT_TYPES); - } - NDArray::registerSpecialUse({&output}, {&input, &bias}); - manager.synchronize(); +void addBias(sd::graph::Context& block, const NDArray& input, + const NDArray& bias, NDArray& output, const bool isNCHW) { + PointersManager manager(block.launchContext(), "addBias"); + NDArray::prepareSpecialUse({&output}, {&input, &bias}); + + if (input.rankOf() == 2 && bias.rankOf() == 1 && input.ordering() == 'c' && + output.ordering() == 'c' && input.ews() == 1 && bias.ews() == 1 && + input.sizeAt(1) == bias.sizeAt(0)) { + BUILD_DOUBLE_SELECTOR( + input.dataType(), bias.dataType(), addBias2DCudaLauncher, + (block.launchContext()->getCudaStream(), input.specialBuffer(), + bias.specialBuffer(), output.specialBuffer(), input.sizeAt(0), + bias.sizeAt(0)), + FLOAT_TYPES, FLOAT_TYPES); + } else { + // default case + const int threadsPerBlock = MAX_NUM_THREADS / 4; + const int blocksPerGrid = + (input.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + const int sharedMem = + input.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128; + + BUILD_DOUBLE_SELECTOR( + input.dataType(), bias.dataType(), addBiasCudaLauncher, + (blocksPerGrid, threadsPerBlock, sharedMem, + block.launchContext()->getCudaStream(), input.specialBuffer(), + input.specialShapeInfo(), bias.specialBuffer(), + bias.specialShapeInfo(), output.specialBuffer(), + output.specialShapeInfo(), isNCHW), + FLOAT_TYPES, FLOAT_TYPES); + } + NDArray::registerSpecialUse({&output}, {&input, &bias}); + manager.synchronize(); } -} -} -} \ No newline at end of file +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/adjust_hue.cu b/libnd4j/include/ops/declarable/helpers/cuda/adjust_hue.cu index fff4bfb11d05..2839c2bf5d8e 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/adjust_hue.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/adjust_hue.cu @@ -19,91 +19,102 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include #include #include +#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { - /////////////////////////////////////////////////////////////////// template -static void _CUDA_G adjustHueCuda(const void* vx, const Nd4jLong* xShapeInfo, const Nd4jLong* xTadOffsets, - void* vz, const Nd4jLong *zShapeInfo, const Nd4jLong* zTadOffsets, - const Nd4jLong numOfTads, const T delta, const int dimC) { - - const T* x = reinterpret_cast(vx); - T* z = reinterpret_cast(vz); - - __shared__ int rank; - __shared__ Nd4jLong xDimCstride, zDimCstride; - - if (threadIdx.x == 0) { - rank = shape::rank(xShapeInfo); - xDimCstride = shape::stride(xShapeInfo)[dimC]; - zDimCstride = shape::stride(zShapeInfo)[dimC]; - } - __syncthreads(); - - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - - for (Nd4jLong i = tid; i < numOfTads; i += gridDim.x * blockDim.x) { - - const T* xTad = x + xTadOffsets[i]; - T* zTad = z + zTadOffsets[i]; - - T h, s, v; - - rgbToHsv(xTad[0], xTad[xDimCstride], xTad[2 * xDimCstride], h, s, v); - - h += delta ; - if(h > 1) - h -= 1; - else if(h < 0) - h += 1; - - hsvToRgb(h, s, v, zTad[0], zTad[zDimCstride], zTad[2 * zDimCstride]); - } +static void _CUDA_G adjustHueCuda(const void* vx, const Nd4jLong* xShapeInfo, + const Nd4jLong* xTadOffsets, void* vz, + const Nd4jLong* zShapeInfo, + const Nd4jLong* zTadOffsets, + const Nd4jLong numOfTads, const T delta, + const int dimC) { + const T* x = reinterpret_cast(vx); + T* z = reinterpret_cast(vz); + + __shared__ int rank; + __shared__ Nd4jLong xDimCstride, zDimCstride; + + if (threadIdx.x == 0) { + rank = shape::rank(xShapeInfo); + xDimCstride = shape::stride(xShapeInfo)[dimC]; + zDimCstride = shape::stride(zShapeInfo)[dimC]; + } + __syncthreads(); + + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + + for (Nd4jLong i = tid; i < numOfTads; i += gridDim.x * blockDim.x) { + const T* xTad = x + xTadOffsets[i]; + T* zTad = z + zTadOffsets[i]; + + T h, s, v; + + rgbToHsv(xTad[0], xTad[xDimCstride], xTad[2 * xDimCstride], h, s, v); + + h += delta; + if (h > 1) + h -= 1; + else if (h < 0) + h += 1; + + hsvToRgb(h, s, v, zTad[0], zTad[zDimCstride], zTad[2 * zDimCstride]); + } } /////////////////////////////////////////////////////////////////// -template -static _CUDA_H void adjustHueCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, - const void* vx, const Nd4jLong* xShapeInfo, const Nd4jLong* xTadOffsets, - void* vz, const Nd4jLong* zShapeInfo, const Nd4jLong* zTadOffsets, - const Nd4jLong numOfTads, const NDArray* deltaScalarArr, const int dimC) { - - adjustHueCuda<<>>(vx, xShapeInfo, xTadOffsets, vz, zShapeInfo, zTadOffsets, numOfTads, deltaScalarArr->e(0), dimC); +template +static _CUDA_H void adjustHueCudaLauncher( + const int blocksPerGrid, const int threadsPerBlock, + const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo, + const Nd4jLong* xTadOffsets, void* vz, const Nd4jLong* zShapeInfo, + const Nd4jLong* zTadOffsets, const Nd4jLong numOfTads, + const NDArray* deltaScalarArr, const int dimC) { + adjustHueCuda<<>>( + vx, xShapeInfo, xTadOffsets, vz, zShapeInfo, zTadOffsets, numOfTads, + deltaScalarArr->e(0), dimC); } //////////////////////////////////////////////////////////////////////// -void adjustHue(sd::LaunchContext* context, const NDArray *input, const NDArray* deltaScalarArr, NDArray *output, const int dimC) { - - auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(input->shapeInfo(), {dimC}); - auto packZ = sd::ConstantTadHelper::getInstance().tadForDimensions(output->shapeInfo(), {dimC}); - - const Nd4jLong numOfTads = packX.numberOfTads(); - - const int threadsPerBlock = MAX_NUM_THREADS / 2; - const int blocksPerGrid = (numOfTads + threadsPerBlock - 1) / threadsPerBlock; - - PointersManager manager(context, "adjustHue"); - - NDArray::prepareSpecialUse({output}, {input, deltaScalarArr}); - BUILD_SINGLE_SELECTOR(input->dataType(), adjustHueCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), input->specialBuffer(), input->specialShapeInfo(), packX.platformOffsets(), output->specialBuffer(), output->specialShapeInfo(), packZ.platformOffsets(), numOfTads, deltaScalarArr, dimC), FLOAT_TYPES); - NDArray::registerSpecialUse({output}, {input, deltaScalarArr}); - - manager.synchronize(); +void adjustHue(sd::LaunchContext* context, const NDArray* input, + const NDArray* deltaScalarArr, NDArray* output, const int dimC) { + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions( + input->shapeInfo(), {dimC}); + auto packZ = sd::ConstantTadHelper::getInstance().tadForDimensions( + output->shapeInfo(), {dimC}); + + const Nd4jLong numOfTads = packX.numberOfTads(); + + const int threadsPerBlock = MAX_NUM_THREADS / 2; + const int blocksPerGrid = (numOfTads + threadsPerBlock - 1) / threadsPerBlock; + + PointersManager manager(context, "adjustHue"); + + NDArray::prepareSpecialUse({output}, {input, deltaScalarArr}); + BUILD_SINGLE_SELECTOR( + input->dataType(), adjustHueCudaLauncher, + (blocksPerGrid, threadsPerBlock, context->getCudaStream(), + input->specialBuffer(), input->specialShapeInfo(), + packX.platformOffsets(), output->specialBuffer(), + output->specialShapeInfo(), packZ.platformOffsets(), numOfTads, + deltaScalarArr, dimC), + FLOAT_TYPES); + NDArray::registerSpecialUse({output}, {input, deltaScalarArr}); + + manager.synchronize(); } - /* template -static void _CUDA_G adjustHueSingleNHWCKernel(void *xBuffer, Nd4jLong *xShapeInfo, void *zBuffer, Nd4jLong *zShapeInfo, Nd4jLong tuples, float delta) { - int numChannels = 3; - auto tid = threadIdx.x + blockIdx.x * blockDim.x; +static void _CUDA_G adjustHueSingleNHWCKernel(void *xBuffer, Nd4jLong +*xShapeInfo, void *zBuffer, Nd4jLong *zShapeInfo, Nd4jLong tuples, float delta) +{ int numChannels = 3; auto tid = threadIdx.x + blockIdx.x * blockDim.x; auto bIn = reinterpret_cast(xBuffer); auto bOut = reinterpret_cast(zBuffer); @@ -128,10 +139,11 @@ static void _CUDA_G adjustHueSingleNHWCKernel(void *xBuffer, Nd4jLong *xShapeInf } template -static void _CUDA_G adjustHueSingleNCHWKernel(void *xBuffer, Nd4jLong *xTadShapeInfo, Nd4jLong *xOffsets, void *zBuffer, Nd4jLong *zTadShapeInfo, Nd4jLong *zOffsets, Nd4jLong tadLength, Nd4jLong tuples, float delta) { - int numChannels = 3; - auto tid = threadIdx.x + blockIdx.x * blockDim.x; - static const int kChannelRange = 6; +static void _CUDA_G adjustHueSingleNCHWKernel(void *xBuffer, Nd4jLong +*xTadShapeInfo, Nd4jLong *xOffsets, void *zBuffer, Nd4jLong *zTadShapeInfo, +Nd4jLong *zOffsets, Nd4jLong tadLength, Nd4jLong tuples, float delta) { int +numChannels = 3; auto tid = threadIdx.x + blockIdx.x * blockDim.x; static const +int kChannelRange = 6; auto bufferR = reinterpret_cast(xBuffer) + xOffsets[0]; auto bufferG = reinterpret_cast(xBuffer) + xOffsets[1]; @@ -166,56 +178,70 @@ static void _CUDA_G adjustHueSingleNCHWKernel(void *xBuffer, Nd4jLong *xTadShape } template -static void _adjust_hue_single(sd::LaunchContext * context, NDArray *array, NDArray *output, float delta, bool isNHWC) { +static void _adjust_hue_single(sd::LaunchContext * context, NDArray *array, +NDArray *output, float delta, bool isNHWC) { // numChannels is always 3 auto tuples = array->lengthOf() / 3; if (isNHWC) { - adjustHueSingleNHWCKernel<<<256, 256, 1024, *context->getCudaStream()>>>(array->specialBuffer(), array->specialShapeInfo(), output->specialBuffer(), output->special(), tuples, delta); - } else { + adjustHueSingleNHWCKernel<<<256, 256, 1024, +*context->getCudaStream()>>>(array->specialBuffer(), array->specialShapeInfo(), +output->specialBuffer(), output->special(), tuples, delta); } else { // TODO: check this one - auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(array->shapeInfo(), {1, 2}); - auto packZ = sd::ConstantTadHelper::getInstance().tadForDimensions(output->shapeInfo(), {1, 2}); + auto packX = +sd::ConstantTadHelper::getInstance().tadForDimensions(array->shapeInfo(), {1, +2}); auto packZ = +sd::ConstantTadHelper::getInstance().tadForDimensions(output->shapeInfo(), {1, +2}); auto tadLength = shape::length(packX.primaryShapeInfo()); - adjustHueSingleNCHWKernel<<<256, 256, 1024, *context->getCudaStream()>>>(array->specialBuffer(), packX.platformShapeInfo(), packX.platformOffsets(), output->specialBuffer(), packZ.platformShapeInfo(), packZ.platformOffsets(), tadLength, tuples, delta); + adjustHueSingleNCHWKernel<<<256, 256, 1024, +*context->getCudaStream()>>>(array->specialBuffer(), packX.platformShapeInfo(), +packX.platformOffsets(), output->specialBuffer(), packZ.platformShapeInfo(), +packZ.platformOffsets(), tadLength, tuples, delta); } } template -static void _adjust_hue_batch(sd::LaunchContext * context, NDArray *array, NDArray *output, float delta, bool isNHWC) { - auto xType = array->dataType(); +static void _adjust_hue_batch(sd::LaunchContext * context, NDArray *array, +NDArray *output, float delta, bool isNHWC) { auto xType = array->dataType(); // numChannels is always 3 auto tuples = array->lengthOf() / 3; if (isNHWC) { - // in case of nhwc batch, we don't really care about examples: it's still bunch of RGB values - BUILD_SINGLE_SELECTOR(xType, _adjust_hue_single, (context, array, output, delta, isNHWC);, FLOAT_TYPES); - } else { + // in case of nhwc batch, we don't really care about examples: it's +still bunch of RGB values BUILD_SINGLE_SELECTOR(xType, _adjust_hue_single, +(context, array, output, delta, isNHWC);, FLOAT_TYPES); } else { // TODO: check this one - auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(array->shapeInfo(), {0, 2, 3}); - auto packZ = sd::ConstantTadHelper::getInstance().tadForDimensions(output->shapeInfo(), {0, 2, 3}); + auto packX = +sd::ConstantTadHelper::getInstance().tadForDimensions(array->shapeInfo(), {0, +2, 3}); auto packZ = +sd::ConstantTadHelper::getInstance().tadForDimensions(output->shapeInfo(), {0, +2, 3}); auto tadLength = shape::length(packX.primary()); - adjustHueSingleNCHWKernel<<<256, 256, 1024, *context->getCudaStream()>>>(array->specialBuffer(), packX.platformShapeInfo(), packX.platformOffsets(), output->specialBuffer(), packZ.platform(), packZ.platform(), tadLength, tuples, delta); + adjustHueSingleNCHWKernel<<<256, 256, 1024, +*context->getCudaStream()>>>(array->specialBuffer(), packX.platformShapeInfo(), +packX.platformOffsets(), output->specialBuffer(), packZ.platform(), +packZ.platform(), tadLength, tuples, delta); } } -void _adjust_hue(sd::LaunchContext * context, NDArray *array, NDArray *output, NDArray* delta, bool isNHWC) { - auto xType = array->dataType(); +void _adjust_hue(sd::LaunchContext * context, NDArray *array, NDArray *output, +NDArray* delta, bool isNHWC) { auto xType = array->dataType(); float d = delta->e(0); if (array->rankOf() == 4) { - BUILD_SINGLE_SELECTOR(xType, _adjust_hue_batch, (context, array, output, d, isNHWC);, FLOAT_TYPES); - } else { - BUILD_SINGLE_SELECTOR(xType, _adjust_hue_single, (context, array, output, d, isNHWC);, FLOAT_TYPES); + BUILD_SINGLE_SELECTOR(xType, _adjust_hue_batch, (context, array, output, +d, isNHWC);, FLOAT_TYPES); } else { BUILD_SINGLE_SELECTOR(xType, +_adjust_hue_single, (context, array, output, d, isNHWC);, FLOAT_TYPES); } } */ -} -} -} +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/adjust_saturation.cu b/libnd4j/include/ops/declarable/helpers/cuda/adjust_saturation.cu index 36837db2977e..2277e08ab677 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/adjust_saturation.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/adjust_saturation.cu @@ -19,92 +19,102 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include -#include #include #include +#include +#include - -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { - /////////////////////////////////////////////////////////////////// template -static void _CUDA_G adjustSaturationCuda(const void* vx, const Nd4jLong* xShapeInfo, const Nd4jLong* xTadOffsets, - void* vz, const Nd4jLong *zShapeInfo, const Nd4jLong* zTadOffsets, - const Nd4jLong numOfTads, const T factor, const int dimC) { +static void _CUDA_G adjustSaturationCuda( + const void* vx, const Nd4jLong* xShapeInfo, const Nd4jLong* xTadOffsets, + void* vz, const Nd4jLong* zShapeInfo, const Nd4jLong* zTadOffsets, + const Nd4jLong numOfTads, const T factor, const int dimC) { + const T* x = reinterpret_cast(vx); + T* z = reinterpret_cast(vz); - const T* x = reinterpret_cast(vx); - T* z = reinterpret_cast(vz); + __shared__ int rank; + __shared__ Nd4jLong xDimCstride, zDimCstride; - __shared__ int rank; - __shared__ Nd4jLong xDimCstride, zDimCstride; + if (threadIdx.x == 0) { + rank = shape::rank(xShapeInfo); + xDimCstride = shape::stride(xShapeInfo)[dimC]; + zDimCstride = shape::stride(zShapeInfo)[dimC]; + } + __syncthreads(); - if (threadIdx.x == 0) { - rank = shape::rank(xShapeInfo); - xDimCstride = shape::stride(xShapeInfo)[dimC]; - zDimCstride = shape::stride(zShapeInfo)[dimC]; - } - __syncthreads(); - - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - - for (Nd4jLong i = tid; i < numOfTads; i += gridDim.x * blockDim.x) { + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - const T* xTad = x + xTadOffsets[i]; - T* zTad = z + zTadOffsets[i]; + for (Nd4jLong i = tid; i < numOfTads; i += gridDim.x * blockDim.x) { + const T* xTad = x + xTadOffsets[i]; + T* zTad = z + zTadOffsets[i]; - T h, s, v; + T h, s, v; - rgbToHsv(xTad[0], xTad[xDimCstride], xTad[2 * xDimCstride], h, s, v); + rgbToHsv(xTad[0], xTad[xDimCstride], xTad[2 * xDimCstride], h, s, v); - s *= factor; - if(s > 1.f) - s = 1.f; - else if(s < 0.f) - s = 0.f; + s *= factor; + if (s > 1.f) + s = 1.f; + else if (s < 0.f) + s = 0.f; - hsvToRgb(h, s, v, zTad[0], zTad[zDimCstride], zTad[2 * zDimCstride]); - } + hsvToRgb(h, s, v, zTad[0], zTad[zDimCstride], zTad[2 * zDimCstride]); + } } /////////////////////////////////////////////////////////////////// -template -static _CUDA_H void adjustSaturationCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, - const void* vx, const Nd4jLong* xShapeInfo, const Nd4jLong* xTadOffsets, - void* vz, const Nd4jLong* zShapeInfo, const Nd4jLong* zTadOffsets, - const Nd4jLong numOfTads, const NDArray* factorScalarArr, const int dimC) { - - adjustSaturationCuda<<>>(vx, xShapeInfo, xTadOffsets, vz, zShapeInfo, zTadOffsets, numOfTads, factorScalarArr->e(0), dimC); +template +static _CUDA_H void adjustSaturationCudaLauncher( + const int blocksPerGrid, const int threadsPerBlock, + const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo, + const Nd4jLong* xTadOffsets, void* vz, const Nd4jLong* zShapeInfo, + const Nd4jLong* zTadOffsets, const Nd4jLong numOfTads, + const NDArray* factorScalarArr, const int dimC) { + adjustSaturationCuda<<>>( + vx, xShapeInfo, xTadOffsets, vz, zShapeInfo, zTadOffsets, numOfTads, + factorScalarArr->e(0), dimC); } //////////////////////////////////////////////////////////////////////// -void adjustSaturation(sd::LaunchContext* context, const NDArray *input, const NDArray* factorScalarArr, NDArray *output, const int dimC) { - - auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(input->shapeInfo(), {dimC}); - auto packZ = sd::ConstantTadHelper::getInstance().tadForDimensions(output->shapeInfo(), {dimC}); - - const Nd4jLong numOfTads = packX.numberOfTads(); - - const int threadsPerBlock = MAX_NUM_THREADS / 2; - const int blocksPerGrid = (numOfTads + threadsPerBlock - 1) / threadsPerBlock; - - PointersManager manager(context, "adjustSaturation"); - - NDArray::prepareSpecialUse({output}, {input, factorScalarArr}); - BUILD_SINGLE_SELECTOR(input->dataType(), adjustSaturationCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), input->specialBuffer(), input->specialShapeInfo(), packX.platformOffsets(), output->specialBuffer(), output->specialShapeInfo(), packZ.platformOffsets(), numOfTads, factorScalarArr, dimC), FLOAT_TYPES); - NDArray::registerSpecialUse({output}, {input, factorScalarArr}); - - manager.synchronize(); +void adjustSaturation(sd::LaunchContext* context, const NDArray* input, + const NDArray* factorScalarArr, NDArray* output, + const int dimC) { + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions( + input->shapeInfo(), {dimC}); + auto packZ = sd::ConstantTadHelper::getInstance().tadForDimensions( + output->shapeInfo(), {dimC}); + + const Nd4jLong numOfTads = packX.numberOfTads(); + + const int threadsPerBlock = MAX_NUM_THREADS / 2; + const int blocksPerGrid = (numOfTads + threadsPerBlock - 1) / threadsPerBlock; + + PointersManager manager(context, "adjustSaturation"); + + NDArray::prepareSpecialUse({output}, {input, factorScalarArr}); + BUILD_SINGLE_SELECTOR( + input->dataType(), adjustSaturationCudaLauncher, + (blocksPerGrid, threadsPerBlock, context->getCudaStream(), + input->specialBuffer(), input->specialShapeInfo(), + packX.platformOffsets(), output->specialBuffer(), + output->specialShapeInfo(), packZ.platformOffsets(), numOfTads, + factorScalarArr, dimC), + FLOAT_TYPES); + NDArray::registerSpecialUse({output}, {input, factorScalarArr}); + + manager.synchronize(); } /* template -static void _CUDA_G adjustSaturationSingleNHWCKernel(void *xBuffer, Nd4jLong *xShapeInfo, void *zBuffer, Nd4jLong *zShapeInfo, Nd4jLong tuples, float delta) { - int numChannels = 3; - auto tid = threadIdx.x + blockIdx.x * blockDim.x; +static void _CUDA_G adjustSaturationSingleNHWCKernel(void *xBuffer, Nd4jLong +*xShapeInfo, void *zBuffer, Nd4jLong *zShapeInfo, Nd4jLong tuples, float delta) +{ int numChannels = 3; auto tid = threadIdx.x + blockIdx.x * blockDim.x; auto bIn = reinterpret_cast(xBuffer); auto bOut = reinterpret_cast(zBuffer); @@ -117,7 +127,8 @@ static void _CUDA_G adjustSaturationSingleNHWCKernel(void *xBuffer, Nd4jLong *xS T h, s, v; // Convert the RGB color to Hue/V-range. helpers::rgb_to_hsv(i[0], i[1], i[2], &h, &s, &v); - s = sd::math::nd4j_min((T) 1.0f, sd::math::nd4j_max((T) 0.0f, s * delta)); + s = sd::math::nd4j_min((T) 1.0f, sd::math::nd4j_max((T) 0.0f, s * +delta)); // Convert the hue and v-range back into RGB. helpers::hsv_to_rgb(h, s, v, o, o + 1, o + 2); @@ -125,10 +136,11 @@ static void _CUDA_G adjustSaturationSingleNHWCKernel(void *xBuffer, Nd4jLong *xS } template -static void _CUDA_G adjustSaturationSingleNCHWKernel(void *xBuffer, Nd4jLong *xTadShapeInfo, Nd4jLong *xOffsets, void *zBuffer, Nd4jLong *zTadShapeInfo, Nd4jLong *zOffsets, Nd4jLong tadLength, Nd4jLong tuples, float delta) { - int numChannels = 3; - auto tid = threadIdx.x + blockIdx.x * blockDim.x; - static const int kChannelRange = 6; +static void _CUDA_G adjustSaturationSingleNCHWKernel(void *xBuffer, Nd4jLong +*xTadShapeInfo, Nd4jLong *xOffsets, void *zBuffer, Nd4jLong *zTadShapeInfo, +Nd4jLong *zOffsets, Nd4jLong tadLength, Nd4jLong tuples, float delta) { int +numChannels = 3; auto tid = threadIdx.x + blockIdx.x * blockDim.x; static const +int kChannelRange = 6; auto bufferR = reinterpret_cast(xBuffer) + xOffsets[0]; auto bufferG = reinterpret_cast(xBuffer) + xOffsets[1]; @@ -150,62 +162,79 @@ static void _CUDA_G adjustSaturationSingleNCHWKernel(void *xBuffer, Nd4jLong *xT T h, s, v; // Convert the RGB color to Hue/V-range. helpers::rgb_to_hsv(_ri[0], _gi[0], _bi[0], &h, &s, &v); - s = sd::math::nd4j_min((T) 1.0f, sd::math::nd4j_max((T) 0.0f, s * delta)); + s = sd::math::nd4j_min((T) 1.0f, sd::math::nd4j_max((T) 0.0f, s * +delta)); // Convert the hue and v-range back into RGB. helpers::hsv_to_rgb(h, s, v, _ro, _go, _bo); } } template -static void _adjust_saturation_single(sd::LaunchContext * context, NDArray *array, NDArray *output, float delta, bool isNHWC) { +static void _adjust_saturation_single(sd::LaunchContext * context, NDArray +*array, NDArray *output, float delta, bool isNHWC) { // numChannels is always 3 auto tuples = array->lengthOf() / 3; if (isNHWC) { - adjustSaturationSingleNHWCKernel<<<256, 256, 1024, *context->getCudaStream()>>>(array->specialBuffer(), array->specialShapeInfo(), output->specialBuffer(), output->special(), tuples, delta); - } else { - auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(array->shapeInfo(), {1, 2}); - auto packZ = sd::ConstantTadHelper::getInstance().tadForDimensions(output->shapeInfo(), {1, 2}); + adjustSaturationSingleNHWCKernel<<<256, 256, 1024, +*context->getCudaStream()>>>(array->specialBuffer(), array->specialShapeInfo(), +output->specialBuffer(), output->special(), tuples, delta); } else { + auto packX = +sd::ConstantTadHelper::getInstance().tadForDimensions(array->shapeInfo(), {1, +2}); auto packZ = +sd::ConstantTadHelper::getInstance().tadForDimensions(output->shapeInfo(), {1, +2}); auto tadLength = shape::length(packX.primaryShapeInfo()); - adjustSaturationSingleNCHWKernel<<<256, 256, 1024, *context->getCudaStream()>>>(array->specialBuffer(), packX.platformShapeInfo(), packX.platformOffsets(), output->specialBuffer(), packZ.platformShapeInfo(), packZ.platformOffsets(), tadLength, tuples, delta); + adjustSaturationSingleNCHWKernel<<<256, 256, 1024, +*context->getCudaStream()>>>(array->specialBuffer(), packX.platformShapeInfo(), +packX.platformOffsets(), output->specialBuffer(), packZ.platformShapeInfo(), +packZ.platformOffsets(), tadLength, tuples, delta); } } template -static void _adjust_saturation_batch(sd::LaunchContext * context, NDArray *array, NDArray *output, float delta, bool isNHWC) { - auto xType = array->dataType(); +static void _adjust_saturation_batch(sd::LaunchContext * context, NDArray +*array, NDArray *output, float delta, bool isNHWC) { auto xType = +array->dataType(); // numChannels is always 3 auto tuples = array->lengthOf() / 3; if (isNHWC) { - // in case of nhwc batch, we don't really care about examples: it's still bunch of RGB values - BUILD_SINGLE_SELECTOR(xType, _adjust_saturation_single, (context, array, output, delta, isNHWC);, FLOAT_TYPES); - } else { + // in case of nhwc batch, we don't really care about examples: it's +still bunch of RGB values BUILD_SINGLE_SELECTOR(xType, +_adjust_saturation_single, (context, array, output, delta, isNHWC);, +FLOAT_TYPES); } else { // TODO: check this one - auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(array->shapeInfo(), {0, 2, 3}); - auto packZ = sd::ConstantTadHelper::getInstance().tadForDimensions(output->shapeInfo(), {0, 2, 3}); + auto packX = +sd::ConstantTadHelper::getInstance().tadForDimensions(array->shapeInfo(), {0, +2, 3}); auto packZ = +sd::ConstantTadHelper::getInstance().tadForDimensions(output->shapeInfo(), {0, +2, 3}); auto tadLength = shape::length(packX.primary()); - adjustSaturationSingleNCHWKernel<<<256, 256, 1024, *context->getCudaStream()>>>(array->specialBuffer(), packX.platformShapeInfo(), packX.platformOffsets(), output->specialBuffer(), packZ.platform(), packZ.platform(), tadLength, tuples, delta); + adjustSaturationSingleNCHWKernel<<<256, 256, 1024, +*context->getCudaStream()>>>(array->specialBuffer(), packX.platformShapeInfo(), +packX.platformOffsets(), output->specialBuffer(), packZ.platform(), +packZ.platform(), tadLength, tuples, delta); } } -void adjust_saturation(sd::LaunchContext * context, NDArray *array, NDArray *output, NDArray* delta, bool isNHWC) { - auto xType = array->dataType(); +void adjust_saturation(sd::LaunchContext * context, NDArray *array, NDArray +*output, NDArray* delta, bool isNHWC) { auto xType = array->dataType(); float d = delta->e(0); if (array->rankOf() == 4) { - BUILD_SINGLE_SELECTOR(xType, _adjust_saturation_batch, (context, array, output, d, isNHWC);, FLOAT_TYPES); - } else { - BUILD_SINGLE_SELECTOR(xType, _adjust_saturation_single, (context, array, output, d, isNHWC);, FLOAT_TYPES); + BUILD_SINGLE_SELECTOR(xType, _adjust_saturation_batch, (context, array, +output, d, isNHWC);, FLOAT_TYPES); } else { BUILD_SINGLE_SELECTOR(xType, +_adjust_saturation_single, (context, array, output, d, isNHWC);, FLOAT_TYPES); } } */ -} -} -} +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/axis.cu b/libnd4j/include/ops/declarable/helpers/cuda/axis.cu index 1dd00f688fd2..b5fd8249b838 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/axis.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/axis.cu @@ -20,32 +20,31 @@ #include - namespace sd { namespace ops { namespace helpers { - void adjustAxis(Nd4jLong rank, NDArray* axisVector, std::vector& output) { - output.resize(axisVector->lengthOf()); - axisVector->tickReadDevice(); // mark input as read on device - axisVector->syncToHost(); // sync to host - for (int e = 0; e < axisVector->lengthOf(); e++) { - auto ca = axisVector->e(e); - if (ca < 0) // shift values on rank for negative vals - ca += rank; - - output[e] = ca; - } - } - - void adjustAxis(Nd4jLong rank, std::vector &axisVector) { - for (int e = 0; e < axisVector.size(); e++) { - auto a = axisVector[e]; - if (a < 0) // shift vals on rank for negative vals - axisVector[e] = a + rank; - } - } - -} +void adjustAxis(Nd4jLong rank, NDArray* axisVector, std::vector& output) { + output.resize(axisVector->lengthOf()); + axisVector->tickReadDevice(); // mark input as read on device + axisVector->syncToHost(); // sync to host + for (int e = 0; e < axisVector->lengthOf(); e++) { + auto ca = axisVector->e(e); + if (ca < 0) // shift values on rank for negative vals + ca += rank; + + output[e] = ca; + } } + +void adjustAxis(Nd4jLong rank, std::vector& axisVector) { + for (int e = 0; e < axisVector.size(); e++) { + auto a = axisVector[e]; + if (a < 0) // shift vals on rank for negative vals + axisVector[e] = a + rank; + } } + +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/batched_gemm.cu b/libnd4j/include/ops/declarable/helpers/cuda/batched_gemm.cu index 40540f65d3a1..b9678fdaa809 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/batched_gemm.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/batched_gemm.cu @@ -19,14 +19,13 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include #include +#include +#include +#include #include #include #include -#include -#include - namespace sd { namespace ops { @@ -34,145 +33,157 @@ namespace helpers { ////////////////////////////////////////////////////////////////////////////// // bsxMXK x bSxKxN = bSxMxN -void bgemm(const std::vector& vA, const std::vector& vB, std::vector& vC, const NDArray* alphas, const NDArray* betas, int transA, int transB, int M, int N, int K, const int lda, const int ldb, const int ldc) { - - const auto bS = vA.size(); // batch size - - std::vector pA(bS), pB(bS), pC(bS); - - std::vector toDelete; - - for(int i = 0; i < bS; ++i) { - - if(vA[i]->ews() != 1) { - pA[i] = new NDArray(vA[i]->dup('f')); - toDelete.emplace_back(pA[i]); - } - else - pA[i] = vA[i]; - - if(vB[i]->ews() != 1) { - pB[i] = new NDArray(vB[i]->dup('f')); - toDelete.emplace_back(pB[i]); - } - else - pB[i] = vB[i]; - - if(vC[i]->ews() != 1) { - pC[i] = new NDArray(vC[i]->dup('f')); - toDelete.emplace_back(pC[i]); - } - else - pC[i] = vC[i]; - - if(pC[i]->ordering() != 'f') { - auto temp = pA[i]; - pA[i] = new NDArray(pB[i]->permute({1,0})); - pB[i] = new NDArray(temp ->permute({1,0})); - pC[i] = new NDArray(pC[i]->permute({1,0})); - toDelete.push_back(pA[i]); - toDelete.push_back(pB[i]); - toDelete.push_back(pC[i]); - M = pA[i]->sizeAt(0); - K = pA[i]->sizeAt(1); - N = pB[i]->sizeAt(1); - } - - NDArray::prepareSpecialUse ({pC[i]}, {pA[i], pB[i]}); - NDArray::registerSpecialUse({pC[i]}, {pA[i], pB[i]}); - } - - NDArray::prepareSpecialUse ({}, {alphas, betas}); - NDArray::registerSpecialUse({}, {alphas, betas}); - - std::vector pAbuffs(bS), pBbuffs(bS), pCbuffs(bS); - for(int i = 0; i < bS; ++i) { - pAbuffs[i] = pA[i]->specialBuffer(); - pBbuffs[i] = pB[i]->specialBuffer(); - pCbuffs[i] = pC[i]->specialBuffer(); +void bgemm(const std::vector& vA, const std::vector& vB, + std::vector& vC, const NDArray* alphas, + const NDArray* betas, int transA, int transB, int M, int N, int K, + const int lda, const int ldb, const int ldc) { + const auto bS = vA.size(); // batch size + + std::vector pA(bS), pB(bS), pC(bS); + + std::vector toDelete; + + for (int i = 0; i < bS; ++i) { + if (vA[i]->ews() != 1) { + pA[i] = new NDArray(vA[i]->dup('f')); + toDelete.emplace_back(pA[i]); + } else + pA[i] = vA[i]; + + if (vB[i]->ews() != 1) { + pB[i] = new NDArray(vB[i]->dup('f')); + toDelete.emplace_back(pB[i]); + } else + pB[i] = vB[i]; + + if (vC[i]->ews() != 1) { + pC[i] = new NDArray(vC[i]->dup('f')); + toDelete.emplace_back(pC[i]); + } else + pC[i] = vC[i]; + + if (pC[i]->ordering() != 'f') { + auto temp = pA[i]; + pA[i] = new NDArray(pB[i]->permute({1, 0})); + pB[i] = new NDArray(temp->permute({1, 0})); + pC[i] = new NDArray(pC[i]->permute({1, 0})); + toDelete.push_back(pA[i]); + toDelete.push_back(pB[i]); + toDelete.push_back(pC[i]); + M = pA[i]->sizeAt(0); + K = pA[i]->sizeAt(1); + N = pB[i]->sizeAt(1); } - sd::LaunchContext* context = vA[0]->getContext(); - PointersManager manager(context, "helpers::bgemm cuda"); - - const void** aBuffers = reinterpret_cast(manager.replicatePointer(pAbuffs.data(), bS * sizeof(void*))); - const void** bBuffers = reinterpret_cast(manager.replicatePointer(pBbuffs.data(), bS * sizeof(void*))); - void** cBuffers = reinterpret_cast(manager.replicatePointer(pCbuffs.data(), bS * sizeof(void*))); - - // const auto aOrder = pA->ordering(); - // const auto bOrder = pB->ordering(); - - // const bool transA = aOrder != 'f'; - // const bool transB = bOrder != 'f'; - - const cublasOperation_t transAblas = transA == 112 ? CUBLAS_OP_T : CUBLAS_OP_N; - const cublasOperation_t transBblas = transB == 112 ? CUBLAS_OP_T : CUBLAS_OP_N; - - // const int lda = aOrder == 'f' ? M : K; - // const int ldb = bOrder == 'f' ? K : N; - // const int ldc = M; // cOrder == 'f' ? M : N; - - const auto aType = pA[0]->dataType(); - const auto bType = pB[0]->dataType(); - const auto cType = pC[0]->dataType(); - - std::lock_guard lock(*LaunchContext::deviceMutex()); - - auto handle = reinterpret_cast(context->getCublasHandle()); - auto stream = context->getCudaStream(); - - auto status = cublasSetStream_v2(*handle, *stream); - - if (status != CUBLAS_STATUS_SUCCESS) - throw cuda_exception::build("MmulHelper::mmulMxM cuda failed !", status); - - const bool AB(aType == bType), AC(aType == cType), ABC(AB && AC); - - // choose appropriate cuda gemm api depending on data types - if(ABC && aType == DataType::DOUBLE) { - double alpha = alphas->e(0); - double beta = betas->e(0); - status = cublasDgemmBatched(*handle, transAblas, transBblas, M, N, K, &alpha, (const double**)aBuffers, lda, (const double**)bBuffers, ldb, &beta, (double**)cBuffers, ldc, bS); - } - else if(ABC && aType == DataType::FLOAT32) { - float alpha = alphas->e(0); - float beta = betas->e(0); - status = cublasSgemmBatched(*handle, transAblas, transBblas, M, N, K, &alpha, (const float**)aBuffers, lda, (const float**)bBuffers, ldb, &beta, (float**)cBuffers, ldc, bS); - } - else if(ABC && aType == DataType::HALF) { - __half alpha = alphas->e(0); - __half beta = betas->e(0); - status = cublasHgemmBatched(*handle, transAblas, transBblas, M, N, K, &alpha, (const __half**)aBuffers, lda, (const __half**)bBuffers, ldb, &beta, (__half**)cBuffers, ldc, bS); - } - else if(AB && aType == DataType::INT8 && cType == DataType::FLOAT32) { - float alpha = alphas->e(0); - float beta = betas->e(0); - status = cublasGemmBatchedEx(*handle, transAblas, transBblas, M, N, K, &alpha, aBuffers, CUDA_R_8I, lda, bBuffers, CUDA_R_8I, ldb, &beta, cBuffers, CUDA_R_32F, ldc, bS, CUDA_R_32F, CUBLAS_GEMM_DEFAULT); - } - else if(AB && aType == DataType::HALF && cType == DataType::FLOAT32) { - float alpha = alphas->e(0); - float beta = betas->e(0); - status = cublasGemmBatchedEx(*handle, transAblas, transBblas, M, N, K, &alpha, aBuffers, CUDA_R_16F, lda, bBuffers, CUDA_R_16F, ldb, &beta, cBuffers, CUDA_R_32F, ldc, bS, CUDA_R_32F, CUBLAS_GEMM_DEFAULT); - } - else - throw std::runtime_error("batched gemm cuda: this mode is not implemented yet !"); - - if (status != CUBLAS_STATUS_SUCCESS) - throw cuda_exception::build("MmulHelper::mmulMxM cuda failed !", status); - - auto cudaResult = cudaStreamSynchronize(*stream); - if (cudaResult != 0) - throw cuda_exception::build("MmulHelper::mmulMxM cuda failed !", cudaResult); - - for(int i = 0; i < bS; ++i) - if(vC[i]->ews() != 1) - vC[i]->assign(pC[i]); - - for(int i = toDelete.size() - 1; i >= 0; --i) - delete toDelete[i]; -} - -} -} + NDArray::prepareSpecialUse({pC[i]}, {pA[i], pB[i]}); + NDArray::registerSpecialUse({pC[i]}, {pA[i], pB[i]}); + } + + NDArray::prepareSpecialUse({}, {alphas, betas}); + NDArray::registerSpecialUse({}, {alphas, betas}); + + std::vector pAbuffs(bS), pBbuffs(bS), pCbuffs(bS); + for (int i = 0; i < bS; ++i) { + pAbuffs[i] = pA[i]->specialBuffer(); + pBbuffs[i] = pB[i]->specialBuffer(); + pCbuffs[i] = pC[i]->specialBuffer(); + } + + sd::LaunchContext* context = vA[0]->getContext(); + PointersManager manager(context, "helpers::bgemm cuda"); + + const void** aBuffers = reinterpret_cast( + manager.replicatePointer(pAbuffs.data(), bS * sizeof(void*))); + const void** bBuffers = reinterpret_cast( + manager.replicatePointer(pBbuffs.data(), bS * sizeof(void*))); + void** cBuffers = reinterpret_cast( + manager.replicatePointer(pCbuffs.data(), bS * sizeof(void*))); + + // const auto aOrder = pA->ordering(); + // const auto bOrder = pB->ordering(); + + // const bool transA = aOrder != 'f'; + // const bool transB = bOrder != 'f'; + + const cublasOperation_t transAblas = + transA == 112 ? CUBLAS_OP_T : CUBLAS_OP_N; + const cublasOperation_t transBblas = + transB == 112 ? CUBLAS_OP_T : CUBLAS_OP_N; + + // const int lda = aOrder == 'f' ? M : K; + // const int ldb = bOrder == 'f' ? K : N; + // const int ldc = M; // cOrder == 'f' ? M : N; + + const auto aType = pA[0]->dataType(); + const auto bType = pB[0]->dataType(); + const auto cType = pC[0]->dataType(); + + std::lock_guard lock(*LaunchContext::deviceMutex()); + + auto handle = reinterpret_cast(context->getCublasHandle()); + auto stream = context->getCudaStream(); + + auto status = cublasSetStream_v2(*handle, *stream); + + if (status != CUBLAS_STATUS_SUCCESS) + throw cuda_exception::build("MmulHelper::mmulMxM cuda failed !", status); + + const bool AB(aType == bType), AC(aType == cType), ABC(AB && AC); + + // choose appropriate cuda gemm api depending on data types + if (ABC && aType == DataType::DOUBLE) { + double alpha = alphas->e(0); + double beta = betas->e(0); + status = cublasDgemmBatched(*handle, transAblas, transBblas, M, N, K, + &alpha, (const double**)aBuffers, lda, + (const double**)bBuffers, ldb, &beta, + (double**)cBuffers, ldc, bS); + } else if (ABC && aType == DataType::FLOAT32) { + float alpha = alphas->e(0); + float beta = betas->e(0); + status = cublasSgemmBatched(*handle, transAblas, transBblas, M, N, K, + &alpha, (const float**)aBuffers, lda, + (const float**)bBuffers, ldb, &beta, + (float**)cBuffers, ldc, bS); + } else if (ABC && aType == DataType::HALF) { + __half alpha = alphas->e(0); + __half beta = betas->e(0); + status = cublasHgemmBatched(*handle, transAblas, transBblas, M, N, K, + &alpha, (const __half**)aBuffers, lda, + (const __half**)bBuffers, ldb, &beta, + (__half**)cBuffers, ldc, bS); + } else if (AB && aType == DataType::INT8 && cType == DataType::FLOAT32) { + float alpha = alphas->e(0); + float beta = betas->e(0); + status = cublasGemmBatchedEx(*handle, transAblas, transBblas, M, N, K, + &alpha, aBuffers, CUDA_R_8I, lda, bBuffers, + CUDA_R_8I, ldb, &beta, cBuffers, CUDA_R_32F, + ldc, bS, CUDA_R_32F, CUBLAS_GEMM_DEFAULT); + } else if (AB && aType == DataType::HALF && cType == DataType::FLOAT32) { + float alpha = alphas->e(0); + float beta = betas->e(0); + status = cublasGemmBatchedEx(*handle, transAblas, transBblas, M, N, K, + &alpha, aBuffers, CUDA_R_16F, lda, bBuffers, + CUDA_R_16F, ldb, &beta, cBuffers, CUDA_R_32F, + ldc, bS, CUDA_R_32F, CUBLAS_GEMM_DEFAULT); + } else + throw std::runtime_error( + "batched gemm cuda: this mode is not implemented yet !"); + + if (status != CUBLAS_STATUS_SUCCESS) + throw cuda_exception::build("MmulHelper::mmulMxM cuda failed !", status); + + auto cudaResult = cudaStreamSynchronize(*stream); + if (cudaResult != 0) + throw cuda_exception::build("MmulHelper::mmulMxM cuda failed !", + cudaResult); + + for (int i = 0; i < bS; ++i) + if (vC[i]->ews() != 1) vC[i]->assign(pC[i]); + + for (int i = toDelete.size() - 1; i >= 0; --i) delete toDelete[i]; } +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/batchnorm.cu b/libnd4j/include/ops/declarable/helpers/cuda/batchnorm.cu index f7f8bf966f1c..5674315ccde4 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/batchnorm.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/batchnorm.cu @@ -18,29 +18,29 @@ // @author Yurii Shyrma, created on 25.02.2018 // - -#include -#include -#include #include +#include #include +#include +#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { - ////////////////////////////////////////////////////////////////////////// // template -// __global__ static void batchnormCuda(const void* vx, const Nd4jLong* xShapeInfo, -// const void* vMean, const Nd4jLong* meanShapeInfo, -// const void* vVariance, const Nd4jLong* varianceShapeInfo, -// const void* vGamma, const Nd4jLong* gammaShapeInfo, -// const void* vBeta, const Nd4jLong* betaShapeInfo, -// void* vz, const Nd4jLong* zShapeInfo, -// const Nd4jLong* xTadShapeInfo, const Nd4jLong* xTadOffsets, -// const Nd4jLong* zTadShapeInfo, const Nd4jLong* zTadOffsets, -// const T epsilon) { +// __global__ static void batchnormCuda(const void* vx, const Nd4jLong* +// xShapeInfo, const void* vMean, +// const Nd4jLong* meanShapeInfo, +// const void* +// vVariance, const Nd4jLong* varianceShapeInfo, +// const void* vGamma, const Nd4jLong* gammaShapeInfo, +// const void* vBeta, const Nd4jLong* betaShapeInfo, void* vz, const Nd4jLong* +// zShapeInfo, const Nd4jLong* +// xTadShapeInfo, const Nd4jLong* xTadOffsets, +// const Nd4jLong* zTadShapeInfo, const Nd4jLong* zTadOffsets, +// const T epsilon) { // const auto x = reinterpret_cast(vx); // auto z = reinterpret_cast(vz); @@ -49,7 +49,8 @@ namespace helpers { // const auto gamma = reinterpret_cast(vGamma); // const auto beta = reinterpret_cast(vBeta); -// // maxRank = xRank = zRank, minRank = meanRank = varianceRank = gammaRank = betaRank +// // maxRank = xRank = zRank, minRank = meanRank = varianceRank = gammaRank +// = betaRank // __shared__ Nd4jLong minLen, tadLen, totalThreads; // if (threadIdx.x == 0) { @@ -64,10 +65,12 @@ namespace helpers { // for (uint i = tid; i < minLen; i += totalThreads) { -// const auto meanOffset = shape::getIndexOffset(i, meanShapeInfo); +// const auto meanOffset = shape::getIndexOffset(i, +// meanShapeInfo); // const auto varianceOffset = shape::getIndexOffset(i, varianceShapeInfo); -// T sigmaInvGam = 1. / sd::math::nd4j_sqrt(variance[varianceOffset] + epsilon); +// T sigmaInvGam = 1. / sd::math::nd4j_sqrt(variance[varianceOffset] +// + epsilon); // if(gamma != nullptr) // sigmaInvGam *= gamma[shape::getIndexOffset(i, gammaShapeInfo)]; @@ -84,7 +87,8 @@ namespace helpers { // const auto xTadOffset = shape::getIndexOffset(j, xTadShapeInfo); // const auto zTadOffset = shape::getIndexOffset(j, zTadShapeInfo); -// zTad[zTadOffset] = (xTad[xTadOffset] - mean[meanOffset]) * sigmaInvGam; +// zTad[zTadOffset] = (xTad[xTadOffset] - mean[meanOffset]) * +// sigmaInvGam; // if(beta != nullptr) // zTad[zTadOffset] += beta[betaOffset]; @@ -93,145 +97,180 @@ namespace helpers { // } ////////////////////////////////////////////////////////////////////////// -template -__global__ static void batchnormCuda2(const void* vx, const Nd4jLong* xShapeInfo, - const void* vMean, const Nd4jLong* meanShapeInfo, - const void* vVariance, const Nd4jLong* varianceShapeInfo, - const void* vGamma, const Nd4jLong* gammaShapeInfo, - const void* vBeta, const Nd4jLong* betaShapeInfo, - void* vz, const Nd4jLong* zShapeInfo, - const int numDims, const int* dims, - const T epsilon) { - - const auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - const auto mean = reinterpret_cast(vMean); - const auto variance = reinterpret_cast(vVariance); - const auto gamma = reinterpret_cast(vGamma); - const auto beta = reinterpret_cast(vBeta); - - __shared__ int xRank, minRank; // xRank == zRank, minRank = meanRank = varianceRank = gammaRank = betaRank - __shared__ Nd4jLong xLen, totalThreads; // xLen = zLen - - - if (threadIdx.x == 0) { - - totalThreads = gridDim.x * blockDim.x; - - xLen = shape::length(xShapeInfo); - xRank = shape::rank(xShapeInfo); - minRank = shape::rank(meanShapeInfo); +template +__global__ static void batchnormCuda2( + const void* vx, const Nd4jLong* xShapeInfo, const void* vMean, + const Nd4jLong* meanShapeInfo, const void* vVariance, + const Nd4jLong* varianceShapeInfo, const void* vGamma, + const Nd4jLong* gammaShapeInfo, const void* vBeta, + const Nd4jLong* betaShapeInfo, void* vz, const Nd4jLong* zShapeInfo, + const int numDims, const int* dims, const T epsilon) { + const auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + const auto mean = reinterpret_cast(vMean); + const auto variance = reinterpret_cast(vVariance); + const auto gamma = reinterpret_cast(vGamma); + const auto beta = reinterpret_cast(vBeta); + + __shared__ int xRank, minRank; // xRank == zRank, minRank = meanRank = + // varianceRank = gammaRank = betaRank + __shared__ Nd4jLong xLen, totalThreads; // xLen = zLen + + if (threadIdx.x == 0) { + totalThreads = gridDim.x * blockDim.x; + + xLen = shape::length(xShapeInfo); + xRank = shape::rank(xShapeInfo); + minRank = shape::rank(meanShapeInfo); + } + __syncthreads(); + + int coords[MAX_RANK]; + + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + + for (uint i = tid; i < xLen; i += totalThreads) { + shape::index2coords(i, xShapeInfo, coords); + + const auto xOffset = shape::getOffset(xShapeInfo, coords); + const auto zOffset = shape::getOffset(zShapeInfo, coords); + + if (minRank == xRank) { + for (uint i = 0, j = 0; i < xRank; ++i) { + if (j < numDims && i != dims[j]) + coords[i] = 0; + else + ++j; + } + } else // minRank = numDims = 1 in this case + coords[0] = coords[dims[0]]; + + const auto meanOffset = shape::getOffset(meanShapeInfo, coords); + const auto varianceOffset = shape::getOffset(varianceShapeInfo, coords); + + T sigmaInvGam = + 1. / sd::math::nd4j_sqrt(variance[varianceOffset] + epsilon); + + if (gamma != nullptr) { + const auto gammaOffset = shape::getOffset(gammaShapeInfo, coords); + sigmaInvGam *= gamma[gammaOffset]; } - __syncthreads(); - - int coords[MAX_RANK]; - - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - for (uint i = tid; i < xLen; i += totalThreads) { + z[zOffset] = (x[xOffset] - mean[meanOffset]) * sigmaInvGam; - shape::index2coords(i, xShapeInfo, coords); - - const auto xOffset = shape::getOffset(xShapeInfo, coords); - const auto zOffset = shape::getOffset(zShapeInfo, coords); - - if(minRank == xRank) { - for (uint i = 0, j = 0; i < xRank; ++i) { - if(j < numDims && i != dims[j]) - coords[i] = 0; - else - ++j; - } - } - else // minRank = numDims = 1 in this case - coords[0] = coords[dims[0]]; - - const auto meanOffset = shape::getOffset(meanShapeInfo, coords); - const auto varianceOffset = shape::getOffset(varianceShapeInfo, coords); - - T sigmaInvGam = 1. / sd::math::nd4j_sqrt(variance[varianceOffset] + epsilon); - - if(gamma != nullptr) { - const auto gammaOffset = shape::getOffset(gammaShapeInfo, coords); - sigmaInvGam *= gamma[gammaOffset]; - } - - z[zOffset] = (x[xOffset] - mean[meanOffset]) * sigmaInvGam; - - if(beta != nullptr) { - const auto betaOffset = shape::getOffset(betaShapeInfo, coords); - z[zOffset] += beta[betaOffset]; - } + if (beta != nullptr) { + const auto betaOffset = shape::getOffset(betaShapeInfo, coords); + z[zOffset] += beta[betaOffset]; } + } } /////////////////////////////////////////////////////////////////// // template -// __host__ static void batchnormCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, -// const void* vx, const Nd4jLong* xShapeInfo, -// const void* vMean, const Nd4jLong* meanShapeInfo, -// const void* vVariance, const Nd4jLong* varianceShapeInfo, -// const void* vGamma, const Nd4jLong* gammaShapeInfo, -// const void* vBeta, const Nd4jLong* betaShapeInfo, -// void* vz, const Nd4jLong* zShapeInfo, -// const Nd4jLong* xTadShapeInfo, const Nd4jLong* xTadOffsets, -// const Nd4jLong* zTadShapeInfo, const Nd4jLong* zTadOffsets, -// const double epsilon) { - -// batchnormCuda<<>>(vx, xShapeInfo, vMean, meanShapeInfo, vVariance, varianceShapeInfo, vGamma, gammaShapeInfo, vBeta, betaShapeInfo, vz, zShapeInfo, xTadShapeInfo, xTadOffsets, zTadShapeInfo, zTadOffsets, static_cast(epsilon)); +// __host__ static void batchnormCudaLauncher(const int blocksPerGrid, const int +// threadsPerBlock, const cudaStream_t *stream, +// const void* vx, const Nd4jLong* xShapeInfo, +// const void* vMean, const +// Nd4jLong* meanShapeInfo, +// const void* +// vVariance, +// const Nd4jLong* varianceShapeInfo, +// const void* vGamma, const Nd4jLong* +// gammaShapeInfo, const +// void* vBeta, const Nd4jLong* betaShapeInfo, +// void* vz, +// const Nd4jLong* zShapeInfo, +// const Nd4jLong* xTadShapeInfo, const Nd4jLong* xTadOffsets, +// const Nd4jLong* zTadShapeInfo, const Nd4jLong* zTadOffsets, const double +// epsilon) +// { + +// batchnormCuda<<>>(vx, +// xShapeInfo, vMean, meanShapeInfo, vVariance, varianceShapeInfo, vGamma, +// gammaShapeInfo, vBeta, betaShapeInfo, vz, zShapeInfo, xTadShapeInfo, +// xTadOffsets, zTadShapeInfo, zTadOffsets, static_cast(epsilon)); // } /////////////////////////////////////////////////////////////////// -template -__host__ static void batchnormCudaLauncher2(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, - const void* vx, const Nd4jLong* xShapeInfo, - const void* vMean, const Nd4jLong* meanShapeInfo, - const void* vVariance, const Nd4jLong* varianceShapeInfo, - const void* vGamma, const Nd4jLong* gammaShapeInfo, - const void* vBeta, const Nd4jLong* betaShapeInfo, - void* vz, const Nd4jLong* zShapeInfo, - const int numDims, const int* dims, - const double epsilon) { - - batchnormCuda2<<>>(vx, xShapeInfo, vMean, meanShapeInfo, vVariance, varianceShapeInfo, vGamma, gammaShapeInfo, vBeta, betaShapeInfo, vz, zShapeInfo, numDims, dims, static_cast(epsilon)); +template +__host__ static void batchnormCudaLauncher2( + const int blocksPerGrid, const int threadsPerBlock, + const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo, + const void* vMean, const Nd4jLong* meanShapeInfo, const void* vVariance, + const Nd4jLong* varianceShapeInfo, const void* vGamma, + const Nd4jLong* gammaShapeInfo, const void* vBeta, + const Nd4jLong* betaShapeInfo, void* vz, const Nd4jLong* zShapeInfo, + const int numDims, const int* dims, const double epsilon) { + batchnormCuda2<<>>( + vx, xShapeInfo, vMean, meanShapeInfo, vVariance, varianceShapeInfo, + vGamma, gammaShapeInfo, vBeta, betaShapeInfo, vz, zShapeInfo, numDims, + dims, static_cast(epsilon)); } ////////////////////////////////////////////////////////////////////////// -void batchnorm(const NDArray* input, const NDArray* mean, const NDArray* variance, const NDArray* gamma, const NDArray* beta, NDArray* output, const std::vector& axes, const double epsilon) { - - // std::vector dimsToExclude = ShapeUtils::evalDimsToExclude(input->rankOf(), axes); - - // auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(input->shapeInfo(), dimsToExclude); - // auto packZ = sd::ConstantTadHelper::getInstance().tadForDimensions(output->shapeInfo(), dimsToExclude); - - // const int threadsPerBlock = MAX_NUM_THREADS / 2; - // const int blocksPerGrid = (mean->lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - - // PointersManager manager(input->getContext(), "batchnorm"); - - // NDArray::prepareSpecialUse({output}, {input, mean, variance, gamma, beta}); - // BUILD_SINGLE_SELECTOR(input->dataType(), batchnormCudaLauncher, (blocksPerGrid, threadsPerBlock, input->getContext()->getCudaStream(), input->specialBuffer(), input->specialShapeInfo(), mean->specialBuffer(), mean->specialShapeInfo(), variance->specialBuffer(), variance->specialShapeInfo(), gamma ? gamma->specialBuffer() : nullptr, gamma ? gamma->specialShapeInfo() : nullptr, beta ? beta->specialBuffer() : nullptr, beta ? beta->specialShapeInfo() : nullptr, output->specialBuffer(), output->special(), packX.platformShapeInfo(), packX.platformOffsets(), packZ.platform(), packZ.platform(), epsilon), FLOAT_TYPES); - // NDArray::registerSpecialUse({output}, {input, mean, variance, gamma, beta}); - - // manager.synchronize(); - - - const int threadsPerBlock = MAX_NUM_THREADS / 2; - const int blocksPerGrid = (input->lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - - PointersManager manager(input->getContext(), "batchnorm"); - - const int* dims = reinterpret_cast(manager.replicatePointer(axes.data(), axes.size() * sizeof(int))); - - NDArray::prepareSpecialUse({output}, {input, mean, variance, gamma, beta}); - BUILD_SINGLE_SELECTOR(input->dataType(), batchnormCudaLauncher2, (blocksPerGrid, threadsPerBlock, input->getContext()->getCudaStream(), input->specialBuffer(), input->specialShapeInfo(), mean->specialBuffer(), mean->specialShapeInfo(), variance->specialBuffer(), variance->specialShapeInfo(), gamma ? gamma->specialBuffer() : nullptr, gamma ? gamma->specialShapeInfo() : nullptr, beta ? beta->specialBuffer() : nullptr, beta ? beta->specialShapeInfo() : nullptr, output->specialBuffer(), output->specialShapeInfo(), axes.size(), dims, epsilon), FLOAT_TYPES); - NDArray::registerSpecialUse({output}, {input, mean, variance, gamma, beta}); - - manager.synchronize(); -} - - -} -} +void batchnorm(const NDArray* input, const NDArray* mean, + const NDArray* variance, const NDArray* gamma, + const NDArray* beta, NDArray* output, + const std::vector& axes, const double epsilon) { + // std::vector dimsToExclude = + // ShapeUtils::evalDimsToExclude(input->rankOf(), axes); + + // auto packX = + // sd::ConstantTadHelper::getInstance().tadForDimensions(input->shapeInfo(), + // dimsToExclude); + // auto packZ = + // sd::ConstantTadHelper::getInstance().tadForDimensions(output->shapeInfo(), + // dimsToExclude); + + // const int threadsPerBlock = MAX_NUM_THREADS / 2; + // const int blocksPerGrid = (mean->lengthOf() + threadsPerBlock - 1) / + // threadsPerBlock; + + // PointersManager manager(input->getContext(), "batchnorm"); + + // NDArray::prepareSpecialUse({output}, {input, mean, variance, gamma, + // beta}); BUILD_SINGLE_SELECTOR(input->dataType(), batchnormCudaLauncher, + // (blocksPerGrid, threadsPerBlock, input->getContext()->getCudaStream(), + // input->specialBuffer(), input->specialShapeInfo(), + // mean->specialBuffer(), mean->specialShapeInfo(), + // variance->specialBuffer(), variance->specialShapeInfo(), gamma ? + // gamma->specialBuffer() : nullptr, gamma ? gamma->specialShapeInfo() : + // nullptr, beta ? beta->specialBuffer() : nullptr, beta ? + // beta->specialShapeInfo() : nullptr, output->specialBuffer(), + // output->special(), packX.platformShapeInfo(), + // packX.platformOffsets(), packZ.platform(), + // packZ.platform(), epsilon), FLOAT_TYPES); + // NDArray::registerSpecialUse({output}, {input, mean, variance, gamma, + // beta}); + + // manager.synchronize(); + + const int threadsPerBlock = MAX_NUM_THREADS / 2; + const int blocksPerGrid = + (input->lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + + PointersManager manager(input->getContext(), "batchnorm"); + + const int* dims = reinterpret_cast( + manager.replicatePointer(axes.data(), axes.size() * sizeof(int))); + + NDArray::prepareSpecialUse({output}, {input, mean, variance, gamma, beta}); + BUILD_SINGLE_SELECTOR( + input->dataType(), batchnormCudaLauncher2, + (blocksPerGrid, threadsPerBlock, input->getContext()->getCudaStream(), + input->specialBuffer(), input->specialShapeInfo(), mean->specialBuffer(), + mean->specialShapeInfo(), variance->specialBuffer(), + variance->specialShapeInfo(), gamma ? gamma->specialBuffer() : nullptr, + gamma ? gamma->specialShapeInfo() : nullptr, + beta ? beta->specialBuffer() : nullptr, + beta ? beta->specialShapeInfo() : nullptr, output->specialBuffer(), + output->specialShapeInfo(), axes.size(), dims, epsilon), + FLOAT_TYPES); + NDArray::registerSpecialUse({output}, {input, mean, variance, gamma, beta}); + + manager.synchronize(); } +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/betaInc.cu b/libnd4j/include/ops/declarable/helpers/cuda/betaInc.cu index a18ec1fda8f7..fb7805405688 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/betaInc.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/betaInc.cu @@ -18,177 +18,179 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include #include -#include #include +#include + +#include namespace sd { namespace ops { namespace helpers { - /////////////////////////////////////////////////////////////////// // modified Lentz’s algorithm for continued fractions, -// reference: Lentz, W.J. 1976, “Generating Bessel Functions in Mie Scattering Calculations Using Continued Fractions,” +// reference: Lentz, W.J. 1976, “Generating Bessel Functions in Mie Scattering +// Calculations Using Continued Fractions,” template __device__ T continuedFractionCuda(const T a, const T b, const T x) { - - extern __shared__ unsigned char shmem[]; - T* coeffs = reinterpret_cast(shmem); - - const T min = DataTypeUtils::min() / DataTypeUtils::eps(); - const T aPlusb = a + b; - T val, aPlus2i; - - T t2 = coeffs[1]; - T t1 = coeffs[0]; - if(math::nd4j_abs(t1) < min) - t1 = min; - t1 = static_cast(1) / t1; - T result = t1; - - for(uint i = 1; i <= maxIter; ++i) { - - const uint i2 = 2*i; - aPlus2i = a + static_cast(i2); - - // t1 - t1 = static_cast(1) + coeffs[i2] * t1; - if(math::nd4j_abs(t1) < min) - t1 = min; - t1 = static_cast(1) / t1; - // t2 - t2 = static_cast(1) + coeffs[i2] / t2; - if(math::nd4j_abs(t2) < min) - t2 = min; - // result - result *= t2 * t1; - // t1 - t1 = static_cast(1) + coeffs[i2 + 1] * t1; - if(math::nd4j_abs(t1) < min) - t1 = min; - t1 = static_cast(1) / t1; - // t2 - t2 = static_cast(1) + coeffs[i2 + 1] / t2; - if(math::nd4j_abs(t2) < min) - t2 = min; - // result - val = t2 * t1; - result *= val; - - // condition to stop loop - if(math::nd4j_abs(val - static_cast(1)) <= DataTypeUtils::eps()) - return result; - } - - return DataTypeUtils::infOrMax(); // no convergence, more iterations is required, return infinity + extern __shared__ unsigned char shmem[]; + T* coeffs = reinterpret_cast(shmem); + + const T min = DataTypeUtils::min() / DataTypeUtils::eps(); + const T aPlusb = a + b; + T val, aPlus2i; + + T t2 = coeffs[1]; + T t1 = coeffs[0]; + if (math::nd4j_abs(t1) < min) t1 = min; + t1 = static_cast(1) / t1; + T result = t1; + + for (uint i = 1; i <= maxIter; ++i) { + const uint i2 = 2 * i; + aPlus2i = a + static_cast(i2); + + // t1 + t1 = static_cast(1) + coeffs[i2] * t1; + if (math::nd4j_abs(t1) < min) t1 = min; + t1 = static_cast(1) / t1; + // t2 + t2 = static_cast(1) + coeffs[i2] / t2; + if (math::nd4j_abs(t2) < min) t2 = min; + // result + result *= t2 * t1; + // t1 + t1 = static_cast(1) + coeffs[i2 + 1] * t1; + if (math::nd4j_abs(t1) < min) t1 = min; + t1 = static_cast(1) / t1; + // t2 + t2 = static_cast(1) + coeffs[i2 + 1] / t2; + if (math::nd4j_abs(t2) < min) t2 = min; + // result + val = t2 * t1; + result *= val; + + // condition to stop loop + if (math::nd4j_abs(val - static_cast(1)) <= DataTypeUtils::eps()) + return result; + } + + return DataTypeUtils::infOrMax(); // no convergence, more iterations is + // required, return infinity } /////////////////////////////////////////////////////////////////// -template +template __global__ void betaIncForArrayCuda(const void* va, const Nd4jLong* aShapeInfo, - const void* vb, const Nd4jLong* bShapeInfo, - const void* vx, const Nd4jLong* xShapeInfo, - void* vz, const Nd4jLong* zShapeInfo) { - - extern __shared__ unsigned char shmem[]; - T* sharedMem = reinterpret_cast(shmem); - - const Nd4jLong j = blockIdx.x; // one block per each element - - T& z = *(reinterpret_cast(vz) + shape::getIndexOffset(j, zShapeInfo)); - - __shared__ T a, b, x; - __shared__ bool symmCond; - - if (threadIdx.x == 0) { - - a = *(reinterpret_cast(va) + shape::getIndexOffset(j, aShapeInfo)); - b = *(reinterpret_cast(vb) + shape::getIndexOffset(j, bShapeInfo)); - x = *(reinterpret_cast(vx) + shape::getIndexOffset(j, xShapeInfo)); - - symmCond = x > (a + static_cast(1)) / (a + b + static_cast(2)); - - if(symmCond) { // swap a and b, x = 1 - x - T temp = a; - a = b; - b = temp; - x = static_cast(1) - x; - } - + const void* vb, const Nd4jLong* bShapeInfo, + const void* vx, const Nd4jLong* xShapeInfo, + void* vz, const Nd4jLong* zShapeInfo) { + extern __shared__ unsigned char shmem[]; + T* sharedMem = reinterpret_cast(shmem); + + const Nd4jLong j = blockIdx.x; // one block per each element + + T& z = *(reinterpret_cast(vz) + shape::getIndexOffset(j, zShapeInfo)); + + __shared__ T a, b, x; + __shared__ bool symmCond; + + if (threadIdx.x == 0) { + a = *(reinterpret_cast(va) + + shape::getIndexOffset(j, aShapeInfo)); + b = *(reinterpret_cast(vb) + + shape::getIndexOffset(j, bShapeInfo)); + x = *(reinterpret_cast(vx) + + shape::getIndexOffset(j, xShapeInfo)); + + symmCond = x > (a + static_cast(1)) / (a + b + static_cast(2)); + + if (symmCond) { // swap a and b, x = 1 - x + T temp = a; + a = b; + b = temp; + x = static_cast(1) - x; } - __syncthreads(); - - // t^{n-1} * (1 - t)^{n-1} is symmetric function with respect to x = 0.5 - if(a == b && x == static_cast(0.5)) { - z = static_cast(0.5); - return; - } - - if (x == static_cast(0) || x == static_cast(1)) { - z = symmCond ? static_cast(1) - x : x; - return; - } - - // calculate two coefficients per thread - if(threadIdx.x != 0) { - - const int i = threadIdx.x; - const T aPlus2i = a + 2*i; - sharedMem[2*i] = i * (b - i) * x / ((aPlus2i - static_cast(1)) * aPlus2i); - sharedMem[2*i + 1] = -(a + i) * (a + b + i) * x / ((aPlus2i + static_cast(1)) * aPlus2i); - } - - __syncthreads(); - - if(threadIdx.x == 0) { - - const T gammaPart = lgamma(a) + lgamma(b) - lgamma(a + b); - const T front = math::nd4j_exp(math::nd4j_log(x) * a + math::nd4j_log(1.f - x) * b - gammaPart); - - sharedMem[0] = static_cast(1) - (a + b) * x / (a + static_cast(1)); - sharedMem[1] = static_cast(1); - - z = front * continuedFractionCuda(a, b, x) / a; - - if(symmCond) // symmetry relation - z = static_cast(1) - z; - } + } + __syncthreads(); + + // t^{n-1} * (1 - t)^{n-1} is symmetric function with respect to x = 0.5 + if (a == b && x == static_cast(0.5)) { + z = static_cast(0.5); + return; + } + + if (x == static_cast(0) || x == static_cast(1)) { + z = symmCond ? static_cast(1) - x : x; + return; + } + + // calculate two coefficients per thread + if (threadIdx.x != 0) { + const int i = threadIdx.x; + const T aPlus2i = a + 2 * i; + sharedMem[2 * i] = + i * (b - i) * x / ((aPlus2i - static_cast(1)) * aPlus2i); + sharedMem[2 * i + 1] = + -(a + i) * (a + b + i) * x / ((aPlus2i + static_cast(1)) * aPlus2i); + } + + __syncthreads(); + + if (threadIdx.x == 0) { + const T gammaPart = lgamma(a) + lgamma(b) - lgamma(a + b); + const T front = + math::nd4j_exp(math::nd4j_log(x) * a + + math::nd4j_log(1.f - x) * b - gammaPart); + + sharedMem[0] = static_cast(1) - (a + b) * x / (a + static_cast(1)); + sharedMem[1] = static_cast(1); + + z = front * continuedFractionCuda(a, b, x) / a; + + if (symmCond) // symmetry relation + z = static_cast(1) - z; + } } /////////////////////////////////////////////////////////////////// -template -static void betaIncForArrayCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, - const void* va, const Nd4jLong* aShapeInfo, - const void* vb, const Nd4jLong* bShapeInfo, - const void* vx, const Nd4jLong* xShapeInfo, - void* vz, const Nd4jLong* zShapeInfo) { - - betaIncForArrayCuda<<>>(va, aShapeInfo, vb, bShapeInfo, vx, xShapeInfo, vz, zShapeInfo); +template +static void betaIncForArrayCudaLauncher( + const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, + const cudaStream_t* stream, const void* va, const Nd4jLong* aShapeInfo, + const void* vb, const Nd4jLong* bShapeInfo, const void* vx, + const Nd4jLong* xShapeInfo, void* vz, const Nd4jLong* zShapeInfo) { + betaIncForArrayCuda + <<>>( + va, aShapeInfo, vb, bShapeInfo, vx, xShapeInfo, vz, zShapeInfo); } /////////////////////////////////////////////////////////////////// // overload betaInc for arrays, shapes of a, b and x must be the same !!! -void betaInc(sd::LaunchContext* context, const NDArray& a, const NDArray& b, const NDArray& x, NDArray& output) { - - const int threadsPerBlock = maxIter; - const int blocksPerGrid = output.lengthOf(); - const int sharedMem = 2 * output.sizeOfT() * threadsPerBlock + 128; - - const auto xType = x.dataType(); - - PointersManager manager(context, "betaInc"); - - NDArray::prepareSpecialUse({&output}, {&a, &b, &x}); - BUILD_SINGLE_SELECTOR(xType, betaIncForArrayCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), a.specialBuffer(), a.specialShapeInfo(), b.specialBuffer(), b.specialShapeInfo(), x.specialBuffer(), x.specialShapeInfo(), output.specialBuffer(), output.specialShapeInfo()), FLOAT_TYPES); - NDArray::registerSpecialUse({&output}, {&a, &b, &x}); - - manager.synchronize(); -} - - -} -} +void betaInc(sd::LaunchContext* context, const NDArray& a, const NDArray& b, + const NDArray& x, NDArray& output) { + const int threadsPerBlock = maxIter; + const int blocksPerGrid = output.lengthOf(); + const int sharedMem = 2 * output.sizeOfT() * threadsPerBlock + 128; + + const auto xType = x.dataType(); + + PointersManager manager(context, "betaInc"); + + NDArray::prepareSpecialUse({&output}, {&a, &b, &x}); + BUILD_SINGLE_SELECTOR( + xType, betaIncForArrayCudaLauncher, + (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), + a.specialBuffer(), a.specialShapeInfo(), b.specialBuffer(), + b.specialShapeInfo(), x.specialBuffer(), x.specialShapeInfo(), + output.specialBuffer(), output.specialShapeInfo()), + FLOAT_TYPES); + NDArray::registerSpecialUse({&output}, {&a, &b, &x}); + + manager.synchronize(); } +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/col2im.cu b/libnd4j/include/ops/declarable/helpers/cuda/col2im.cu index 62f60cc733b2..5525584da992 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/col2im.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/col2im.cu @@ -19,188 +19,210 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include #include +#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { ////////////////////////////////////////////////////////////////////////// -// columns [bS, iC, kH, kW, oH, oW] to be de-convoluted to image [bS, iC, iH, iW] +// columns [bS, iC, kH, kW, oH, oW] to be de-convoluted to image [bS, iC, iH, +// iW] template -static __global__ void col2imCuda(const void* columns, const Nd4jLong* colShapeInfo, void* image, const Nd4jLong* imShapeInfo, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW) { - - const T* col = reinterpret_cast(columns); - T* im = reinterpret_cast(image); +static __global__ void col2imCuda(const void* columns, + const Nd4jLong* colShapeInfo, void* image, + const Nd4jLong* imShapeInfo, const int sH, + const int sW, const int pH, const int pW, + const int dH, const int dW) { + const T* col = reinterpret_cast(columns); + T* im = reinterpret_cast(image); - __shared__ uint kH, kW, oH, oW, *sharedMem; - __shared__ Nd4jLong imLen; + __shared__ uint kH, kW, oH, oW, *sharedMem; + __shared__ Nd4jLong imLen; - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - sharedMem = reinterpret_cast(shmem); + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sharedMem = reinterpret_cast(shmem); - kH = dH * (colShapeInfo[3] - 1) + 1; - kW = dW * (colShapeInfo[4] - 1) + 1; + kH = dH * (colShapeInfo[3] - 1) + 1; + kW = dW * (colShapeInfo[4] - 1) + 1; - oH = colShapeInfo[5]; - oW = colShapeInfo[6]; + oH = colShapeInfo[5]; + oW = colShapeInfo[6]; - imLen = shape::length(imShapeInfo); - } - __syncthreads(); - - auto coords = sharedMem + threadIdx.x * 6; + imLen = shape::length(imShapeInfo); + } + __syncthreads(); - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + auto coords = sharedMem + threadIdx.x * 6; - for (Nd4jLong i = tid; i < imLen; i += gridDim.x * blockDim.x) { + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - shape::index2coords(i, imShapeInfo, coords); + for (Nd4jLong i = tid; i < imLen; i += gridDim.x * blockDim.x) { + shape::index2coords(i, imShapeInfo, coords); - const auto imOffset = shape::getOffset(imShapeInfo, coords); + const auto imOffset = shape::getOffset(imShapeInfo, coords); - const auto bSiCoffset = coords[0] * colShapeInfo[7] + coords[1] * colShapeInfo[8]; + const auto bSiCoffset = + coords[0] * colShapeInfo[7] + coords[1] * colShapeInfo[8]; - const uint imH = coords[2] + pH; - const uint imW = coords[3] + pW; + const uint imH = coords[2] + pH; + const uint imW = coords[3] + pW; - const uint colHstart = (imH < kH) ? 0 : (imH - kH) / sH + 1; - const uint colWstart = (imW < kW) ? 0 : (imW - kW) / sW + 1; + const uint colHstart = (imH < kH) ? 0 : (imH - kH) / sH + 1; + const uint colWstart = (imW < kW) ? 0 : (imW - kW) / sW + 1; - const uint colHend = sd::math::nd4j_min(imH / sH + 1, oH); - const uint colWend = sd::math::nd4j_min(imW / sW + 1, oW); + const uint colHend = sd::math::nd4j_min(imH / sH + 1, oH); + const uint colWend = sd::math::nd4j_min(imW / sW + 1, oW); - T val = 0; + T val = 0; - for(coords[4] = colHstart; coords[4] < colHend; ++coords[4]) { - coords[2] = imH - coords[4] * sH; - if(coords[2] % dH != 0) continue; + for (coords[4] = colHstart; coords[4] < colHend; ++coords[4]) { + coords[2] = imH - coords[4] * sH; + if (coords[2] % dH != 0) continue; - for(coords[5] = colWstart; coords[5] < colWend; ++coords[5]) { - coords[3] = imW - coords[5] * sW; - if(coords[3] % dW != 0) continue; + for (coords[5] = colWstart; coords[5] < colWend; ++coords[5]) { + coords[3] = imW - coords[5] * sW; + if (coords[3] % dW != 0) continue; - val += col[bSiCoffset + (coords[2]/dH)*colShapeInfo[9] + (coords[3]/dW)*colShapeInfo[10] + coords[4]*colShapeInfo[11] + coords[5]*colShapeInfo[12]]; - } - } - im[imOffset] = val; + val += col[bSiCoffset + (coords[2] / dH) * colShapeInfo[9] + + (coords[3] / dW) * colShapeInfo[10] + + coords[4] * colShapeInfo[11] + coords[5] * colShapeInfo[12]]; + } } + im[imOffset] = val; + } } //////////////////////////////////////////////////////////////////////// -// columns [bS, iC, kH, kW, oH, oW] to be de-convoluted to image [bS, iC, iH, iW] -template -__global__ static void col2imCuda2(const void *columns, void *image, const Nd4jLong *colShapeInfo, const Nd4jLong *imShapeInfo, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW) { - - const auto col = reinterpret_cast(columns); - auto im = reinterpret_cast(image); - - auto colShape = shape::shapeOf(const_cast(colShapeInfo)); - auto colStride = shape::stride(const_cast(colShapeInfo)); - - int colStride0 = colStride[0]; - int colStride1 = colStride[1]; - int colStride2 = colStride[2]; - int colStride3 = colStride[3]; - int colStride4 = colStride[4]; - int colStride5 = colStride[5]; - - int kH = colShape[2]; - int kW = colShape[3]; - - auto imShape = shape::shapeOf(const_cast(imShapeInfo)); - auto imOrder = shape::order(const_cast(imShapeInfo)); - auto imStride = shape::stride(const_cast(imShapeInfo)); - - int bS = imShape[0]; - int iC = imShape[1]; - int iH = imShape[2]; - int iW = imShape[3]; - - int oH = colShape[4];//(iH + 2 * pH - kH) / sW + 1; - int oW = colShape[5];//(iW + 2 * pW - kW) / sH + 1; - - int n = bS * iC * iH * iW; - - //Effective kernel size, accounting for dilation - int kHeff = kH + (kH - 1) * (dH - 1); - int kWeff = kW + (kW - 1) * (dW - 1); - - for (int i = (blockDim.x * blockIdx.x) + threadIdx.x; i < n; i += blockDim.x * gridDim.x) { - T val = 0; - - int w_im = i % iW + pW; - int h_im = (i / iW) % iH + pH; - int c_im = i / (iW * iH); - int b = c_im / iC; - int c = c_im % iC; - - // compute the start and end of the output - // These are the indexes for dimensions ??? in the 6d col matrix - int w_col_start = (w_im < kWeff) ? 0 : (w_im - kWeff) / sW + 1; - int w_col_end = sd::math::nd4j_min(w_im / sW + 1, oW); - - int h_col_start = (h_im < kHeff) ? 0 : (h_im - kHeff) / sH + 1; - int h_col_end = sd::math::nd4j_min(h_im / sH + 1, oH); - - //Iterate over col entries in the 6d array... these are added up - for (int colH = h_col_start; colH < h_col_end; colH += 1) { - for (int colW = w_col_start; colW < w_col_end; colW += 1) { - int kRow = (h_im - colH * sH); - int kCol = (w_im - colW * sW); - - if(kRow % dH == 0 && kCol % dW == 0){ - kRow /= dH; - kCol /= dW; - - int data_col_index = b * colStride0 + c * colStride1 + kRow * colStride2 + kCol * colStride3 + colH * colStride4 + colW * colStride5; - val += col[data_col_index]; - } - } - } - - int i_f = 0; - int i_c = i; - for (int dim = 3; dim >= 0; dim--) { - i_f += (i_c % imShape[dim]) * imStride[dim]; - i_c = i_c / imShape[dim]; - } - - im[i_f] = val; - } +// columns [bS, iC, kH, kW, oH, oW] to be de-convoluted to image [bS, iC, iH, +// iW] +template +__global__ static void col2imCuda2(const void* columns, void* image, + const Nd4jLong* colShapeInfo, + const Nd4jLong* imShapeInfo, const int sH, + const int sW, const int pH, const int pW, + const int dH, const int dW) { + const auto col = reinterpret_cast(columns); + auto im = reinterpret_cast(image); + + auto colShape = shape::shapeOf(const_cast(colShapeInfo)); + auto colStride = shape::stride(const_cast(colShapeInfo)); + + int colStride0 = colStride[0]; + int colStride1 = colStride[1]; + int colStride2 = colStride[2]; + int colStride3 = colStride[3]; + int colStride4 = colStride[4]; + int colStride5 = colStride[5]; + + int kH = colShape[2]; + int kW = colShape[3]; + + auto imShape = shape::shapeOf(const_cast(imShapeInfo)); + auto imOrder = shape::order(const_cast(imShapeInfo)); + auto imStride = shape::stride(const_cast(imShapeInfo)); + + int bS = imShape[0]; + int iC = imShape[1]; + int iH = imShape[2]; + int iW = imShape[3]; + + int oH = colShape[4]; //(iH + 2 * pH - kH) / sW + 1; + int oW = colShape[5]; //(iW + 2 * pW - kW) / sH + 1; + + int n = bS * iC * iH * iW; + + // Effective kernel size, accounting for dilation + int kHeff = kH + (kH - 1) * (dH - 1); + int kWeff = kW + (kW - 1) * (dW - 1); + + for (int i = (blockDim.x * blockIdx.x) + threadIdx.x; i < n; + i += blockDim.x * gridDim.x) { + T val = 0; + + int w_im = i % iW + pW; + int h_im = (i / iW) % iH + pH; + int c_im = i / (iW * iH); + int b = c_im / iC; + int c = c_im % iC; + + // compute the start and end of the output + // These are the indexes for dimensions ??? in the 6d col matrix + int w_col_start = (w_im < kWeff) ? 0 : (w_im - kWeff) / sW + 1; + int w_col_end = sd::math::nd4j_min(w_im / sW + 1, oW); + + int h_col_start = (h_im < kHeff) ? 0 : (h_im - kHeff) / sH + 1; + int h_col_end = sd::math::nd4j_min(h_im / sH + 1, oH); + + // Iterate over col entries in the 6d array... these are added up + for (int colH = h_col_start; colH < h_col_end; colH += 1) { + for (int colW = w_col_start; colW < w_col_end; colW += 1) { + int kRow = (h_im - colH * sH); + int kCol = (w_im - colW * sW); + + if (kRow % dH == 0 && kCol % dW == 0) { + kRow /= dH; + kCol /= dW; + + int data_col_index = b * colStride0 + c * colStride1 + + kRow * colStride2 + kCol * colStride3 + + colH * colStride4 + colW * colStride5; + val += col[data_col_index]; + } + } + } + + int i_f = 0; + int i_c = i; + for (int dim = 3; dim >= 0; dim--) { + i_f += (i_c % imShape[dim]) * imStride[dim]; + i_c = i_c / imShape[dim]; + } + + im[i_f] = val; + } } ////////////////////////////////////////////////////////////////////////// template -static void col2imCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, - const void* columns, const Nd4jLong* colShapeInfo, - void* image, const Nd4jLong* imShapeInfo, - const int sH, const int sW, const int pH, const int pW, const int dH, const int dW) { - - // col2imCuda2<<<512, 512, 1024, *stream>>>(columns, image, colShapeInfo, imShapeInfo, sH, sW, pH, pW, dH, dW); - col2imCuda<<>>(columns, colShapeInfo, image, imShapeInfo, sH, sW, pH, pW, dH, dW); +static void col2imCudaLauncher(const int blocksPerGrid, + const int threadsPerBlock, const int sharedMem, + const cudaStream_t* stream, const void* columns, + const Nd4jLong* colShapeInfo, void* image, + const Nd4jLong* imShapeInfo, const int sH, + const int sW, const int pH, const int pW, + const int dH, const int dW) { + // col2imCuda2<<<512, 512, 1024, *stream>>>(columns, image, colShapeInfo, + // imShapeInfo, sH, sW, pH, pW, dH, dW); + col2imCuda<<>>( + columns, colShapeInfo, image, imShapeInfo, sH, sW, pH, pW, dH, dW); } ////////////////////////////////////////////////////////////////////////// -void col2im(sd::LaunchContext& context, const NDArray& col, NDArray& im, const int sH, const int sW, const int pH, const int pW, const int iH, const int iW, const int dH, const int dW) { - - PointersManager manager(&context, "col2im"); - - const int threadsPerBlock = MAX_NUM_THREADS / 2; - const int blocksPerGrid = (im.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - const int sharedMem = col.rankOf() * sizeof(uint) * threadsPerBlock + 256; - - NDArray::prepareSpecialUse({&im}, {&col}); - BUILD_SINGLE_SELECTOR(im.dataType(), col2imCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context.getCudaStream(), col.specialBuffer(), col.specialShapeInfo(), im.specialBuffer(), im.specialShapeInfo(), sH, sW, pH, pW, dH, dW), FLOAT_TYPES); - NDArray::registerSpecialUse({&im}, {&col}); - - manager.synchronize(); +void col2im(sd::LaunchContext& context, const NDArray& col, NDArray& im, + const int sH, const int sW, const int pH, const int pW, + const int iH, const int iW, const int dH, const int dW) { + PointersManager manager(&context, "col2im"); + + const int threadsPerBlock = MAX_NUM_THREADS / 2; + const int blocksPerGrid = + (im.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + const int sharedMem = col.rankOf() * sizeof(uint) * threadsPerBlock + 256; + + NDArray::prepareSpecialUse({&im}, {&col}); + BUILD_SINGLE_SELECTOR( + im.dataType(), col2imCudaLauncher, + (blocksPerGrid, threadsPerBlock, sharedMem, context.getCudaStream(), + col.specialBuffer(), col.specialShapeInfo(), im.specialBuffer(), + im.specialShapeInfo(), sH, sW, pH, pW, dH, dW), + FLOAT_TYPES); + NDArray::registerSpecialUse({&im}, {&col}); + + manager.synchronize(); } - - -} -} -} \ No newline at end of file +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/compare_elem.cu b/libnd4j/include/ops/declarable/helpers/cuda/compare_elem.cu index 8d0bede625e5..dc757716e6d2 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/compare_elem.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/compare_elem.cu @@ -17,119 +17,131 @@ #include namespace sd { - namespace ops { - namespace helpers { - - template - static _CUDA_G void comparator(void *vx, const Nd4jLong *xShapeInfo, Nd4jLong length, const bool isStrict, void *reductionBuffer, bool *z) { - auto x = reinterpret_cast(vx); - auto reduction = reinterpret_cast(reductionBuffer); - - extern __shared__ uint32_t shared[]; - auto tid = threadIdx.x + blockIdx.x * blockDim.x; - - shared[threadIdx.x] = 0; - - // each thread will compare 2 elements: E and E+1 - for (int e = tid; e < length - 1; e += blockDim.x * gridDim.x) { - auto val0 = x[shape::getIndexOffset(e, xShapeInfo)]; - auto val1 = x[shape::getIndexOffset(e+1, xShapeInfo)]; - - bool v = false; - if (isStrict) - v = val1 > val0; - else - v = val1 >= val0; - - // store comparison result in shared memory - shared[threadIdx.x] += v ? 0 : 1; - } - __syncthreads(); - - // aggregate sums in shared memory - for (uint activeThreads = blockDim.x / 2; activeThreads > 0; activeThreads /= 2) { - if (threadIdx.x < activeThreads) - shared[threadIdx.x] += shared[threadIdx.x + activeThreads]; - __syncthreads(); - } - - - // store over the grid if we have more than 1 block - if (gridDim.x > 1) { - - auto tc = reinterpret_cast(reductionBuffer); - __shared__ bool amLast; - - tid = threadIdx.x; - if (threadIdx.x == 0) - reduction[blockIdx.x] = shared[0]; - - __threadfence(); - __syncthreads(); - - if (threadIdx.x == 0) { - unsigned int ticket = atomicInc(&tc[16384], gridDim.x); - amLast = (ticket == gridDim.x - 1); - } - - __syncthreads(); - - if (amLast) { - tc[16384] = 0; - shared[threadIdx.x] = 0; +namespace ops { +namespace helpers { + +template +static _CUDA_G void comparator(void *vx, const Nd4jLong *xShapeInfo, + Nd4jLong length, const bool isStrict, + void *reductionBuffer, bool *z) { + auto x = reinterpret_cast(vx); + auto reduction = reinterpret_cast(reductionBuffer); + + extern __shared__ uint32_t shared[]; + auto tid = threadIdx.x + blockIdx.x * blockDim.x; + + shared[threadIdx.x] = 0; + + // each thread will compare 2 elements: E and E+1 + for (int e = tid; e < length - 1; e += blockDim.x * gridDim.x) { + auto val0 = x[shape::getIndexOffset(e, xShapeInfo)]; + auto val1 = x[shape::getIndexOffset(e + 1, xShapeInfo)]; + + bool v = false; + if (isStrict) + v = val1 > val0; + else + v = val1 >= val0; + + // store comparison result in shared memory + shared[threadIdx.x] += v ? 0 : 1; + } + __syncthreads(); + + // aggregate sums in shared memory + for (uint activeThreads = blockDim.x / 2; activeThreads > 0; + activeThreads /= 2) { + if (threadIdx.x < activeThreads) + shared[threadIdx.x] += shared[threadIdx.x + activeThreads]; + __syncthreads(); + } + + // store over the grid if we have more than 1 block + if (gridDim.x > 1) { + auto tc = reinterpret_cast(reductionBuffer); + __shared__ bool amLast; + + tid = threadIdx.x; + if (threadIdx.x == 0) reduction[blockIdx.x] = shared[0]; + + __threadfence(); + __syncthreads(); + + if (threadIdx.x == 0) { + unsigned int ticket = atomicInc(&tc[16384], gridDim.x); + amLast = (ticket == gridDim.x - 1); + } - for (int i = threadIdx.x; i < gridDim.x; i += blockDim.x) - shared[threadIdx.x] += reduction[i]; + __syncthreads(); - __syncthreads(); + if (amLast) { + tc[16384] = 0; + shared[threadIdx.x] = 0; - for (uint activeThreads = blockDim.x / 2; activeThreads > 0; activeThreads /= 2) { - if (threadIdx.x < activeThreads) - shared[threadIdx.x] += shared[threadIdx.x + activeThreads]; - __syncthreads(); - } + for (int i = threadIdx.x; i < gridDim.x; i += blockDim.x) + shared[threadIdx.x] += reduction[i]; - __syncthreads(); + __syncthreads(); - if (threadIdx.x == 0) { - z[0] = shared[0] == 0; - } - } - } - else { - // if we have only 1 block, we just store results right away - if (threadIdx.x == 0) { - auto tc = reinterpret_cast(reductionBuffer); - tc[16384] = 0; - z[0] = shared[0] == 0; - } - } - } + for (uint activeThreads = blockDim.x / 2; activeThreads > 0; + activeThreads /= 2) { + if (threadIdx.x < activeThreads) + shared[threadIdx.x] += shared[threadIdx.x + activeThreads]; + __syncthreads(); + } - template - static void _compare_elem(sd::LaunchContext * context, NDArray *input, bool isStrictlyIncreasing, bool& output) { - auto z = NDArrayFactory::create(false, context); + __syncthreads(); - const int numThreads = 256; - const int numBlocks = sd::math::nd4j_min(128, sd::math::nd4j_max(1, input->lengthOf() / numThreads)); + if (threadIdx.x == 0) { + z[0] = shared[0] == 0; + } + } + } else { + // if we have only 1 block, we just store results right away + if (threadIdx.x == 0) { + auto tc = reinterpret_cast(reductionBuffer); + tc[16384] = 0; + z[0] = shared[0] == 0; + } + } +} - comparator<<getCudaStream()>>>(input->specialBuffer(), input->specialShapeInfo(), input->lengthOf(), isStrictlyIncreasing, context->getReductionPointer(), reinterpret_cast(z.specialBuffer())); +template +static void _compare_elem(sd::LaunchContext *context, NDArray *input, + bool isStrictlyIncreasing, bool &output) { + auto z = NDArrayFactory::create(false, context); - z.tickWriteDevice(); - sd::DebugHelper::checkErrorCode(context->getCudaStream(), "is_strictly_increasing"); + const int numThreads = 256; + const int numBlocks = sd::math::nd4j_min( + 128, sd::math::nd4j_max(1, input->lengthOf() / numThreads)); - output = z.e(0); - } + comparator<<getCudaStream()>>>( + input->specialBuffer(), input->specialShapeInfo(), input->lengthOf(), + isStrictlyIncreasing, context->getReductionPointer(), + reinterpret_cast(z.specialBuffer())); - void compare_elem(sd::LaunchContext * context, NDArray *input, bool isStrictlyIncreasing, bool& output) { - auto xType = input->dataType(); - input->syncToDevice(); + z.tickWriteDevice(); + sd::DebugHelper::checkErrorCode(context->getCudaStream(), + "is_strictly_increasing"); - BUILD_SINGLE_SELECTOR(xType, _compare_elem, (context, input, isStrictlyIncreasing, output), LIBND4J_TYPES); - } + output = z.e(0); +} +void compare_elem(sd::LaunchContext *context, NDArray *input, + bool isStrictlyIncreasing, bool &output) { + auto xType = input->dataType(); + input->syncToDevice(); - BUILD_SINGLE_TEMPLATE(template void _compare_elem, (sd::LaunchContext * context, NDArray *A, bool isStrictlyIncreasing, bool& output);, LIBND4J_TYPES); - } - } + BUILD_SINGLE_SELECTOR(xType, _compare_elem, + (context, input, isStrictlyIncreasing, output), + LIBND4J_TYPES); } + +BUILD_SINGLE_TEMPLATE(template void _compare_elem, + (sd::LaunchContext * context, NDArray *A, + bool isStrictlyIncreasing, bool &output); + , LIBND4J_TYPES); +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/compression/compression.cu b/libnd4j/include/ops/declarable/helpers/cuda/compression/compression.cu index 5de20c57f852..8503f0a9d4f2 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/compression/compression.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/compression/compression.cu @@ -17,50 +17,60 @@ // // @author sgazeos@gmail.com // -#include -#include #include +#include +#include namespace sd { namespace ops { namespace helpers { - void decodeBitmap(sd::LaunchContext* context, const NDArray* input, NDArray* output) { - auto stream = context->getCudaStream(); - NDArray::prepareSpecialUse({output}, {input}); +void decodeBitmap(sd::LaunchContext* context, const NDArray* input, + NDArray* output) { + auto stream = context->getCudaStream(); + NDArray::prepareSpecialUse({output}, {input}); - dim3 launchDims(512, 512, 16384); - auto xType = output->dataType(); - BUILD_SINGLE_SELECTOR(xType, cudaDecodeBitmapGeneric, (launchDims, stream, input->specialBuffer(), output->lengthOf(), output->specialBuffer()), FLOAT_TYPES); + dim3 launchDims(512, 512, 16384); + auto xType = output->dataType(); + BUILD_SINGLE_SELECTOR(xType, cudaDecodeBitmapGeneric, + (launchDims, stream, input->specialBuffer(), + output->lengthOf(), output->specialBuffer()), + FLOAT_TYPES); - sd::DebugHelper::checkErrorCode(stream, "decodeBitmapFloat(...) failed"); + sd::DebugHelper::checkErrorCode(stream, "decodeBitmapFloat(...) failed"); - NDArray::registerSpecialUse({output}, {input}); - } + NDArray::registerSpecialUse({output}, {input}); +} - Nd4jLong encodeBitmap(sd::LaunchContext* context, NDArray* input, NDArray* output, float threshold) { - auto stream = LaunchContext::defaultContext()->getCudaStream(); - int *resultPointer = reinterpret_cast(LaunchContext::defaultContext()->getScalarPointer()); - int *reductionPointer = reinterpret_cast(LaunchContext::defaultContext()->getReductionPointer()); +Nd4jLong encodeBitmap(sd::LaunchContext* context, NDArray* input, + NDArray* output, float threshold) { + auto stream = LaunchContext::defaultContext()->getCudaStream(); + int* resultPointer = reinterpret_cast( + LaunchContext::defaultContext()->getScalarPointer()); + int* reductionPointer = reinterpret_cast( + LaunchContext::defaultContext()->getReductionPointer()); - // nullify result pointer before use - resultPointer[0] = 0; + // nullify result pointer before use + resultPointer[0] = 0; - NDArray::prepareSpecialUse({},{output, input}); + NDArray::prepareSpecialUse({}, {output, input}); - dim3 launchDims(512, 512, 32768); - auto xType = input->dataType(); - BUILD_SINGLE_SELECTOR(xType, cudaEncodeBitmapGeneric, - (launchDims, stream, input->specialBuffer(), input->lengthOf(), reinterpret_cast(output->specialBuffer()), resultPointer, reductionPointer, threshold), - FLOAT_TYPES); + dim3 launchDims(512, 512, 32768); + auto xType = input->dataType(); + BUILD_SINGLE_SELECTOR( + xType, cudaEncodeBitmapGeneric, + (launchDims, stream, input->specialBuffer(), input->lengthOf(), + reinterpret_cast(output->specialBuffer()), resultPointer, + reductionPointer, threshold), + FLOAT_TYPES); - sd::DebugHelper::checkErrorCode(stream, "encodeBitmapFloat(...) failed"); + sd::DebugHelper::checkErrorCode(stream, "encodeBitmapFloat(...) failed"); - Nd4jLong dZ = (Nd4jLong) resultPointer[0]; - resultPointer[0] = 0; + Nd4jLong dZ = (Nd4jLong)resultPointer[0]; + resultPointer[0] = 0; - NDArray::registerSpecialUse({output, input}, {}); - return dZ; - } -} -} + NDArray::registerSpecialUse({output, input}, {}); + return dZ; } +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/compression/threshold.cu b/libnd4j/include/ops/declarable/helpers/cuda/compression/threshold.cu index bbe6d688102d..d20a3465d28a 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/compression/threshold.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/compression/threshold.cu @@ -18,215 +18,242 @@ // @author raver119@gmail.com // -#include -#include #include +#include +#include + #include namespace sd { - namespace ops { - namespace helpers { - void prescanArrayRecursive(int** g_scanBlockSums, int *dZ, int *dX, int numElements, int level) { - auto stream = LaunchContext::defaultContext()->getCudaStream(); - - - int blockSize = 512; // max size of the thread blocks - int numBlocks = sd::math::nd4j_max(1, static_cast(ceil(static_cast(numElements) / (2.f * blockSize)))); - int numThreads; - - if (numBlocks > 1) - numThreads = blockSize; - else if (sd::isPowerOfTwo(numElements)) - numThreads = numElements / 2; - else - numThreads = sd::floorPow2(numElements); - - numThreads = sd::math::nd4j_max(1, numThreads); - - int numEltsPerBlock = numThreads * 2; - - - - // if this is a non-power-of-2 array, the last block will be non-full - // compute the smallest power of 2 able to compute its scan. - int numEltsLastBlock = - numElements - (numBlocks-1) * numEltsPerBlock; - int numThreadsLastBlock = sd::math::nd4j_max(1, numEltsLastBlock / 2); - int np2LastBlock = 0; - int sharedMemLastBlock = 0; - - if (numEltsLastBlock != numEltsPerBlock) { - np2LastBlock = 1; - - if(!isPowerOfTwo(numEltsLastBlock)) - numThreadsLastBlock = floorPow2(numEltsLastBlock); - - unsigned int extraSpace = (2 * numThreadsLastBlock) / NUM_BANKS; - sharedMemLastBlock = sizeof(int) * (2 * numThreadsLastBlock + extraSpace); - } - - // padding space is used to avoid shared memory bank conflicts - int extraSpace = numEltsPerBlock / NUM_BANKS; - int sharedMemSize = sizeof(int) * (numEltsPerBlock + extraSpace); - - // setup execution parameters - // if NP2, we process the last block separately - dim3 grid(sd::math::nd4j_max(1, numBlocks - np2LastBlock), 1, 1); - dim3 threads(numThreads, 1, 1); - dim3 gridOnes(1, 1, 1); - dim3 threadsOnes(numThreadsLastBlock, 1, 1); - - if (sharedMemSize < 2048) - sharedMemSize = 2048; - - if (sharedMemLastBlock < 2048) - sharedMemLastBlock = 2048; - - // execute the scan - if (numBlocks > 1) { - sd::prescanLauncher(grid, threads, sharedMemSize, stream, dZ, dX, g_scanBlockSums[level], numThreads * 2, 0, 0); - if (np2LastBlock) { - sd::prescanLauncher(gridOnes, threadsOnes, sharedMemLastBlock, stream, dZ, dX, g_scanBlockSums[level], numEltsLastBlock, numBlocks - 1, numElements - numEltsLastBlock); - } - - // After scanning all the sub-blocks, we are mostly done. But now we - // need to take all of the last values of the sub-blocks and scan those. - // This will give us a new value that must be sdded to each block to - // get the final results. - // recursive (CPU) call - prescanArrayRecursive(g_scanBlockSums, g_scanBlockSums[level], g_scanBlockSums[level], numBlocks, level+1); - - sd::uniformAdd<<>>(dZ, g_scanBlockSums[level], numElements - numEltsLastBlock, 0, 0); - - if (np2LastBlock) { - sd::uniformAdd<<<1, numThreadsLastBlock, 1024, *stream>>>(dZ, g_scanBlockSums[level], numEltsLastBlock, numBlocks - 1, numElements - numEltsLastBlock); - } - } else if (isPowerOfTwo(numElements)) { - sd::prescanLauncher(grid, threads, sharedMemSize, stream, dZ, dX, 0, numThreads * 2, 0, 0); - } else { - sd::prescanLauncher(grid, threads, sharedMemSize, stream, dZ, dX, 0, numElements, 0, 0); - } - } - - static void encodeThresholdP2Int_(void **prs, int *dx, Nd4jLong N, int *dz) { - auto stream = LaunchContext::defaultContext()->getCudaStream(); - - prescanArrayRecursive(reinterpret_cast(prs), dz, dx + 1, (int) N, 0); - sd::DebugHelper::checkErrorCode(stream, "encodeThresholdP2Int(...) failed"); - } - - static void encodeThresholdP3_(void *dx, const Nd4jLong *hXShapeInfo, int *offsets, Nd4jLong N, int *dz){ - auto stream = LaunchContext::defaultContext()->getCudaStream(); - - int blockSize = 512; - int numBlocks = N / blockSize + (N % blockSize ? 1 : 0); - - dim3 launchDims(numBlocks, blockSize, 8192); - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - BUILD_SINGLE_SELECTOR(xType, encoderKernelP3Generic, (launchDims, stream, dx, offsets, N, dz), FLOAT_TYPES); - - sd::DebugHelper::checkErrorCode(stream, "encodeThresholdP3Float(...) failed"); - } - - - static NDArray thresholdEstimate_(const NDArray &updates, const float threshold) { - const int numThreads = 512; - const int numBlocks = updates.lengthOf() / numThreads + (updates.lengthOf() % numThreads ? 1 : 0); - - auto tmp = NDArrayFactory::create('c', {numBlocks + 1}); +namespace ops { +namespace helpers { +void prescanArrayRecursive(int **g_scanBlockSums, int *dZ, int *dX, + int numElements, int level) { + auto stream = LaunchContext::defaultContext()->getCudaStream(); + + int blockSize = 512; // max size of the thread blocks + int numBlocks = sd::math::nd4j_max( + 1, static_cast( + ceil(static_cast(numElements) / (2.f * blockSize)))); + int numThreads; + + if (numBlocks > 1) + numThreads = blockSize; + else if (sd::isPowerOfTwo(numElements)) + numThreads = numElements / 2; + else + numThreads = sd::floorPow2(numElements); + + numThreads = sd::math::nd4j_max(1, numThreads); + + int numEltsPerBlock = numThreads * 2; + + // if this is a non-power-of-2 array, the last block will be non-full + // compute the smallest power of 2 able to compute its scan. + int numEltsLastBlock = numElements - (numBlocks - 1) * numEltsPerBlock; + int numThreadsLastBlock = sd::math::nd4j_max(1, numEltsLastBlock / 2); + int np2LastBlock = 0; + int sharedMemLastBlock = 0; + + if (numEltsLastBlock != numEltsPerBlock) { + np2LastBlock = 1; + + if (!isPowerOfTwo(numEltsLastBlock)) + numThreadsLastBlock = floorPow2(numEltsLastBlock); + + unsigned int extraSpace = (2 * numThreadsLastBlock) / NUM_BANKS; + sharedMemLastBlock = sizeof(int) * (2 * numThreadsLastBlock + extraSpace); + } + + // padding space is used to avoid shared memory bank conflicts + int extraSpace = numEltsPerBlock / NUM_BANKS; + int sharedMemSize = sizeof(int) * (numEltsPerBlock + extraSpace); + + // setup execution parameters + // if NP2, we process the last block separately + dim3 grid(sd::math::nd4j_max(1, numBlocks - np2LastBlock), 1, 1); + dim3 threads(numThreads, 1, 1); + dim3 gridOnes(1, 1, 1); + dim3 threadsOnes(numThreadsLastBlock, 1, 1); + + if (sharedMemSize < 2048) sharedMemSize = 2048; + + if (sharedMemLastBlock < 2048) sharedMemLastBlock = 2048; + + // execute the scan + if (numBlocks > 1) { + sd::prescanLauncher(grid, threads, sharedMemSize, stream, dZ, + dX, g_scanBlockSums[level], numThreads * 2, + 0, 0); + if (np2LastBlock) { + sd::prescanLauncher(gridOnes, threadsOnes, sharedMemLastBlock, + stream, dZ, dX, g_scanBlockSums[level], + numEltsLastBlock, numBlocks - 1, + numElements - numEltsLastBlock); + } - dim3 launchDims(numBlocks, numThreads, 1024); - auto xType = updates.dataType(); + // After scanning all the sub-blocks, we are mostly done. But now we + // need to take all of the last values of the sub-blocks and scan those. + // This will give us a new value that must be sdded to each block to + // get the final results. + // recursive (CPU) call + prescanArrayRecursive(g_scanBlockSums, g_scanBlockSums[level], + g_scanBlockSums[level], numBlocks, level + 1); + + sd::uniformAdd<<>>( + dZ, g_scanBlockSums[level], numElements - numEltsLastBlock, 0, 0); + + if (np2LastBlock) { + sd::uniformAdd<<<1, numThreadsLastBlock, 1024, *stream>>>( + dZ, g_scanBlockSums[level], numEltsLastBlock, numBlocks - 1, + numElements - numEltsLastBlock); + } + } else if (isPowerOfTwo(numElements)) { + sd::prescanLauncher(grid, threads, sharedMemSize, stream, dZ, + dX, 0, numThreads * 2, 0, 0); + } else { + sd::prescanLauncher(grid, threads, sharedMemSize, stream, dZ, + dX, 0, numElements, 0, 0); + } - NDArray::prepareSpecialUse({&tmp}, {&updates}); - BUILD_SINGLE_SELECTOR(xType, encoderKernelP1Generic, (launchDims, LaunchContext::defaultContext()->getCudaStream(), updates.specialBuffer(), updates.lengthOf(), tmp.specialBuffer(), threshold), FLOAT_TYPES); - NDArray::registerSpecialUse({&tmp}, {&updates}); - return std::move(tmp); - } +} - int32_t thresholdEstimate(const NDArray &updates, const float threshold) { - return thresholdEstimate_(updates, threshold).e(0); - } - - void thresholdEncode(NDArray &updates, NDArray &encoded, float threshold) { - // we need these blocks in order to know, how many "updates" will be processed by each GPU block - auto blocks = thresholdEstimate_(updates, threshold); - - const int numThreads = 512; - const int numBlocks = updates.lengthOf() / numThreads + (updates.lengthOf() % numThreads ? 1 : 0); - - const int prefixThreads = 512; - int numElts = numBlocks; - int level = 0; - - // here we just calculate number of sumBlock arrays - do { - int numPrefixBlocks = sd::math::nd4j_max(1, sd::math::nd4j_ceil((float) numElts / (2.0f * prefixThreads))); - if (numPrefixBlocks > 1) { - level++; - } - numElts = numPrefixBlocks; - } while (numElts > 1); - - - std::vector tempArrays(level); - std::vector pointers(level); - - level = 0; - numElts = numBlocks; +static void encodeThresholdP2Int_(void **prs, int *dx, Nd4jLong N, int *dz) { + auto stream = LaunchContext::defaultContext()->getCudaStream(); - do { - int numPrefixBlocks = sd::math::nd4j_max(1, sd::math::nd4j_ceil((float) numElts / (2.0f * prefixThreads))); - if (numPrefixBlocks > 1) { - tempArrays[level] = std::move(NDArrayFactory::create('c', {numPrefixBlocks})); - pointers[level] = tempArrays[level].specialBuffer();; - level++; - } - numElts = numPrefixBlocks; - } while (numElts > 1); + prescanArrayRecursive(reinterpret_cast(prs), dz, dx + 1, (int)N, 0); + sd::DebugHelper::checkErrorCode(stream, "encodeThresholdP2Int(...) failed"); +} - PointersManager pm(LaunchContext::defaultContext(), "thresholdEncode"); - auto offsets = NDArrayFactory::create('c', {numBlocks}); +static void encodeThresholdP3_(void *dx, const Nd4jLong *hXShapeInfo, + int *offsets, Nd4jLong N, int *dz) { + auto stream = LaunchContext::defaultContext()->getCudaStream(); - // we want to check, if we're hiting external limit on number of encoded elements - auto numMatches = blocks.e(0); - if (numMatches > encoded.lengthOf() - 4) { - blocks.p(0, encoded.lengthOf() - 4); - blocks.syncToDevice(); - } + int blockSize = 512; + int numBlocks = N / blockSize + (N % blockSize ? 1 : 0); - NDArray::prepareSpecialUse({}, {&encoded, &updates}); + dim3 launchDims(numBlocks, blockSize, 8192); + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + BUILD_SINGLE_SELECTOR(xType, encoderKernelP3Generic, + (launchDims, stream, dx, offsets, N, dz), FLOAT_TYPES); - // filling offsets - encodeThresholdP2Int_(reinterpret_cast(pointers.data()), - reinterpret_cast(blocks.specialBuffer()), - numBlocks, - reinterpret_cast(offsets.specialBuffer())); + sd::DebugHelper::checkErrorCode(stream, "encodeThresholdP3Float(...) failed"); +} - NDArray::registerSpecialUse({&blocks, &offsets}, {}); - pm.synchronize(); +static NDArray thresholdEstimate_(const NDArray &updates, + const float threshold) { + const int numThreads = 512; + const int numBlocks = updates.lengthOf() / numThreads + + (updates.lengthOf() % numThreads ? 1 : 0); + auto tmp = NDArrayFactory::create('c', {numBlocks + 1}); - encodeThresholdP3_(updates.specialBuffer(), - updates.shapeInfo(), - reinterpret_cast(offsets.specialBuffer()), - updates.lengthOf(), - reinterpret_cast(encoded.specialBuffer())); + dim3 launchDims(numBlocks, numThreads, 1024); + auto xType = updates.dataType(); - pm.synchronize(); + NDArray::prepareSpecialUse({&tmp}, {&updates}); + BUILD_SINGLE_SELECTOR( + xType, encoderKernelP1Generic, + (launchDims, LaunchContext::defaultContext()->getCudaStream(), + updates.specialBuffer(), updates.lengthOf(), tmp.specialBuffer(), + threshold), + FLOAT_TYPES); + NDArray::registerSpecialUse({&tmp}, {&updates}); - NDArray::registerSpecialUse({&encoded, &updates}, {}); - } + return std::move(tmp); +} - void thresholdDecode(const NDArray &encoded, NDArray &updates) { - dim3 launchDims(128, 512, 512); - auto xType = updates.dataType(); +int32_t thresholdEstimate(const NDArray &updates, const float threshold) { + return thresholdEstimate_(updates, threshold).e(0); +} - NDArray::prepareSpecialUse({&updates}, {&encoded}); - BUILD_SINGLE_SELECTOR(xType, decoderKernelGeneric, (launchDims, LaunchContext::defaultContext()->getCudaStream(), encoded.specialBuffer(), updates.lengthOf(), updates.specialBuffer()), FLOAT_TYPES); - NDArray::registerSpecialUse({&updates}, {&encoded}); - } - } +void thresholdEncode(NDArray &updates, NDArray &encoded, float threshold) { + // we need these blocks in order to know, how many "updates" will be processed + // by each GPU block + auto blocks = thresholdEstimate_(updates, threshold); + + const int numThreads = 512; + const int numBlocks = updates.lengthOf() / numThreads + + (updates.lengthOf() % numThreads ? 1 : 0); + + const int prefixThreads = 512; + int numElts = numBlocks; + int level = 0; + + // here we just calculate number of sumBlock arrays + do { + int numPrefixBlocks = sd::math::nd4j_max( + 1, sd::math::nd4j_ceil((float)numElts / + (2.0f * prefixThreads))); + if (numPrefixBlocks > 1) { + level++; } + numElts = numPrefixBlocks; + } while (numElts > 1); + + std::vector tempArrays(level); + std::vector pointers(level); + + level = 0; + numElts = numBlocks; + + do { + int numPrefixBlocks = sd::math::nd4j_max( + 1, sd::math::nd4j_ceil((float)numElts / + (2.0f * prefixThreads))); + if (numPrefixBlocks > 1) { + tempArrays[level] = + std::move(NDArrayFactory::create('c', {numPrefixBlocks})); + pointers[level] = tempArrays[level].specialBuffer();; + level++;} + numElts = numPrefixBlocks; + } while (numElts > 1); + + PointersManager pm(LaunchContext::defaultContext(), "thresholdEncode"); + auto offsets = NDArrayFactory::create('c', {numBlocks}); + + // we want to check, if we're hiting external limit on number of encoded + // elements + auto numMatches = blocks.e(0); + if (numMatches > encoded.lengthOf() - 4) { + blocks.p(0, encoded.lengthOf() - 4); + blocks.syncToDevice(); + } + + NDArray::prepareSpecialUse({}, {&encoded, &updates}); + + // filling offsets + encodeThresholdP2Int_(reinterpret_cast(pointers.data()), + reinterpret_cast(blocks.specialBuffer()), + numBlocks, + reinterpret_cast(offsets.specialBuffer())); + + NDArray::registerSpecialUse({&blocks, &offsets}, {}); + pm.synchronize(); + + encodeThresholdP3_(updates.specialBuffer(), updates.shapeInfo(), + reinterpret_cast(offsets.specialBuffer()), + updates.lengthOf(), + reinterpret_cast(encoded.specialBuffer())); + + pm.synchronize(); + + NDArray::registerSpecialUse({&encoded, &updates}, {}); +} + +void thresholdDecode(const NDArray &encoded, NDArray &updates) { + dim3 launchDims(128, 512, 512); + auto xType = updates.dataType(); + + NDArray::prepareSpecialUse({&updates}, {&encoded}); + BUILD_SINGLE_SELECTOR( + xType, decoderKernelGeneric, + (launchDims, LaunchContext::defaultContext()->getCudaStream(), + encoded.specialBuffer(), updates.lengthOf(), updates.specialBuffer()), + FLOAT_TYPES); + NDArray::registerSpecialUse({&updates}, {&encoded}); } +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/concat.cu b/libnd4j/include/ops/declarable/helpers/cuda/concat.cu index 400c25f880cc..9872e441effb 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/concat.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/concat.cu @@ -18,176 +18,196 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 20.04.2018 // - -#include -#include -#include -#include #include -#include +#include #include -#include #include +#include +#include +#include +#include -namespace sd { -namespace ops { -namespace helpers { +#include +namespace sd { +namespace ops { +namespace helpers { /////////////////////////////////////////////////////////////////// -template -__global__ static void concatCuda(void* pVx, void* pxShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const int axis) { +template +__global__ static void concatCuda(void* pVx, void* pxShapeInfo, void* vz, + const Nd4jLong* zShapeInfo, const int axis) { + T* z = reinterpret_cast(vz); + __shared__ Nd4jLong zLen, totalThreads; + __shared__ int rank; - T* z = reinterpret_cast(vz); - __shared__ Nd4jLong zLen, totalThreads; - __shared__ int rank; + if (threadIdx.x == 0) { + zLen = shape::length(zShapeInfo); + rank = shape::rank(zShapeInfo); + totalThreads = gridDim.x * blockDim.x; + } + __syncthreads(); - if (threadIdx.x == 0) { - zLen = shape::length(zShapeInfo); - rank = shape::rank(zShapeInfo); - totalThreads = gridDim.x * blockDim.x; - } - __syncthreads(); - - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - int coords[MAX_RANK]; + int coords[MAX_RANK]; - for (Nd4jLong i = tid; i < zLen; i += totalThreads) { - shape::index2coords(i, zShapeInfo, coords); + for (Nd4jLong i = tid; i < zLen; i += totalThreads) { + shape::index2coords(i, zShapeInfo, coords); - const auto zOffset = shape::getOffset(zShapeInfo, coords); + const auto zOffset = shape::getOffset(zShapeInfo, coords); - int inArrIdx = 0; - Nd4jLong *xShapeInfo = reinterpret_cast(pxShapeInfo)[inArrIdx]; + int inArrIdx = 0; + Nd4jLong* xShapeInfo = reinterpret_cast(pxShapeInfo)[inArrIdx]; - while (coords[axis] >= xShapeInfo[axis + 1]) { - coords[axis] -= xShapeInfo[axis + 1]; - xShapeInfo = reinterpret_cast(pxShapeInfo)[++inArrIdx]; - } + while (coords[axis] >= xShapeInfo[axis + 1]) { + coords[axis] -= xShapeInfo[axis + 1]; + xShapeInfo = reinterpret_cast(pxShapeInfo)[++inArrIdx]; + } - const auto *x = reinterpret_cast(reinterpret_cast(pVx)[inArrIdx]); - const auto xOffset = shape::getOffset(xShapeInfo, coords); + const auto* x = + reinterpret_cast(reinterpret_cast(pVx)[inArrIdx]); + const auto xOffset = shape::getOffset(xShapeInfo, coords); - z[zOffset] = x[xOffset]; - } + z[zOffset] = x[xOffset]; + } } /////////////////////////////////////////////////////////////////// -template -__host__ static void concatCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, - void* pVx, void* pxShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const int axis) { - - concatCuda<<>>(pVx, pxShapeInfo, vz, zShapeInfo, axis); +template +__host__ static void concatCudaLauncher( + const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, + const cudaStream_t* stream, void* pVx, void* pxShapeInfo, void* vz, + const Nd4jLong* zShapeInfo, const int axis) { + concatCuda<<>>( + pVx, pxShapeInfo, vz, zShapeInfo, axis); } ////////////////////////////////////////////////////////////////////////// -void concat(sd::LaunchContext * context, const std::vector& inArrs, NDArray& output, const int axis) { - - const int numOfInArrs = inArrs.size(); - const auto sizeofT = output.sizeOfT(); - - NDArray::prepareSpecialUse({&output}, inArrs); - - bool luckCase1 = ((axis == 0 && output.ordering() == 'c') || (axis == output.rankOf() - 1 && output.ordering() == 'f')) && output.ews() == 1; - - if(luckCase1) { - for (uint i = 0; i < numOfInArrs; ++i) { - luckCase1 &= inArrs[i]->ordering() == output.ordering() && inArrs[i]->ews() == 1; - if(!luckCase1) - break; - } +void concat(sd::LaunchContext* context, + const std::vector& inArrs, NDArray& output, + const int axis) { + const int numOfInArrs = inArrs.size(); + const auto sizeofT = output.sizeOfT(); + + NDArray::prepareSpecialUse({&output}, inArrs); + + bool luckCase1 = + ((axis == 0 && output.ordering() == 'c') || + (axis == output.rankOf() - 1 && output.ordering() == 'f')) && + output.ews() == 1; + + if (luckCase1) { + for (uint i = 0; i < numOfInArrs; ++i) { + luckCase1 &= + inArrs[i]->ordering() == output.ordering() && inArrs[i]->ews() == 1; + if (!luckCase1) break; } + } - if(luckCase1) { // for example {1,10} + {2,10} + {3,10} = {6, 10} order c; or {10,1} + {10,2} + {10,3} = {10, 6} order f + if (luckCase1) { // for example {1,10} + {2,10} + {3,10} = {6, 10} order c; + // or {10,1} + {10,2} + {10,3} = {10, 6} order f - void* z = static_cast(output.specialBuffer()); + void* z = static_cast(output.specialBuffer()); - for (uint i = 0; i < numOfInArrs; ++i) { - const auto memAmountToCopy = inArrs[i]->lengthOf() * sizeofT; - cudaMemcpyAsync(z, reinterpret_cast(inArrs[i]->specialBuffer()), memAmountToCopy, cudaMemcpyDeviceToDevice, *context->getCudaStream()); - z = static_cast(z) + memAmountToCopy; - } + for (uint i = 0; i < numOfInArrs; ++i) { + const auto memAmountToCopy = inArrs[i]->lengthOf() * sizeofT; + cudaMemcpyAsync( + z, reinterpret_cast(inArrs[i]->specialBuffer()), + memAmountToCopy, cudaMemcpyDeviceToDevice, *context->getCudaStream()); + z = static_cast(z) + memAmountToCopy; + } - if(cudaStreamSynchronize(*context->getCudaStream()) != 0) - throw std::runtime_error("concat cuda: luckCase1 failed!"); + if (cudaStreamSynchronize(*context->getCudaStream()) != 0) + throw std::runtime_error("concat cuda: luckCase1 failed!"); - for(int i = 0; i < numOfInArrs; ++i) - inArrs[i]->tickReadDevice(); - output.tickWriteDevice(); + for (int i = 0; i < numOfInArrs; ++i) inArrs[i]->tickReadDevice(); + output.tickWriteDevice(); - return; - } + return; + } - // const bool isZcontin = output.strideAt(axis) == 1; - // bool areInputsContin = true; - // bool allSameOrder = true; - // std::vector strideOfContigStride(numOfInArrs); + // const bool isZcontin = output.strideAt(axis) == 1; + // bool areInputsContin = true; + // bool allSameOrder = true; + // std::vector strideOfContigStride(numOfInArrs); - // if(isZcontin) { + // if(isZcontin) { - // for (uint i = 0; i < inArrs.size(); ++i) { + // for (uint i = 0; i < inArrs.size(); ++i) { - // areInputsContin &= inArrs[i]->strideAt(axis) == 1; - // allSameOrder &= output.ordering() == inArrs[i]->ordering(); - // if(!areInputsContin || !allSameOrder) - // break; + // areInputsContin &= inArrs[i]->strideAt(axis) == 1; + // allSameOrder &= output.ordering() == inArrs[i]->ordering(); + // if(!areInputsContin || !allSameOrder) + // break; - // strideOfContigStride[i] = shape::strideOverContigAxis(axis, inArrs[i]->shapeInfo()); - // } - // } + // strideOfContigStride[i] = shape::strideOverContigAxis(axis, + // inArrs[i]->shapeInfo()); + // } + // } - // const bool luckCase2 = isZcontin && areInputsContin && allSameOrder; + // const bool luckCase2 = isZcontin && areInputsContin && allSameOrder; - // if(luckCase2) { // for example {2,1,3} + {2,5,3} + {2,10,3} = {2,16,3}, here axis 1 shoud have stride = 1 for all inputs arrays and output array + // if(luckCase2) { // for example {2,1,3} + {2,5,3} + {2,10,3} = {2,16,3}, + // here axis 1 shoud have stride = 1 for all inputs arrays and output array - // const auto zStep = shape::strideOverContigAxis(axis, output.shapeInfo()); + // const auto zStep = shape::strideOverContigAxis(axis, + // output.shapeInfo()); - // for (uint i = 0; i < output.lengthOf() / output.sizeAt(axis); ++i) { + // for (uint i = 0; i < output.lengthOf() / output.sizeAt(axis); ++i) { - // const auto iShift = i * sizeofT; - // void* z = static_cast(output.specialBuffer()) + zStep * iShift; + // const auto iShift = i * sizeofT; + // void* z = static_cast(output.specialBuffer()) + zStep * + // iShift; - // for (uint j = 0; j < numOfInArrs; ++j) { - // const auto xDim = inArrs[j]->sizeAt(axis); - // void* x = static_cast(inArrs[j]->specialBuffer()) + strideOfContigStride[j] * iShift; - // const auto memSizeToCopy = xDim * sizeofT; - // cudaMemcpyAsync(z, x, memSizeToCopy, cudaMemcpyDeviceToDevice, *context->getCudaStream()); - // z = static_cast(z) + memSizeToCopy; - // } - // } + // for (uint j = 0; j < numOfInArrs; ++j) { + // const auto xDim = inArrs[j]->sizeAt(axis); + // void* x = static_cast(inArrs[j]->specialBuffer()) + + // strideOfContigStride[j] * iShift; const auto memSizeToCopy = + // xDim * sizeofT; cudaMemcpyAsync(z, x, memSizeToCopy, + // cudaMemcpyDeviceToDevice, *context->getCudaStream()); z = + // static_cast(z) + memSizeToCopy; + // } + // } - // if(cudaStreamSynchronize(*context->getCudaStream()) != 0) - // throw std::runtime_error("concat cuda: luckCase2 failed!"); - // } - // else { // general (slower) case + // if(cudaStreamSynchronize(*context->getCudaStream()) != 0) + // throw std::runtime_error("concat cuda: luckCase2 failed!"); + // } + // else { // general (slower) case - const int threadsPerBlock = MAX_NUM_THREADS / 2; - const int blocksPerGrid = (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - const int sharedMem = 256; + const int threadsPerBlock = MAX_NUM_THREADS / 2; + const int blocksPerGrid = (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + const int sharedMem = 256; - // prepare arrays of pointers on buffers and shapes - std::vector hInBuffers(numOfInArrs); - std::vector hInShapeInfo(numOfInArrs); + // prepare arrays of pointers on buffers and shapes + std::vector hInBuffers(numOfInArrs); + std::vector hInShapeInfo(numOfInArrs); - for(int i = 0; i < numOfInArrs; ++i) { - hInBuffers[i] = inArrs[i]->specialBuffer(); - hInShapeInfo[i] = inArrs[i]->specialShapeInfo(); - } + for (int i = 0; i < numOfInArrs; ++i) { + hInBuffers[i] = inArrs[i]->specialBuffer(); + hInShapeInfo[i] = inArrs[i]->specialShapeInfo(); + } - PointersManager manager(context, "helpers::concat"); + PointersManager manager(context, "helpers::concat"); - void* dInBuffers = manager.replicatePointer(hInBuffers.data(), hInBuffers.size() * sizeof(void*)); - void* dInShapeInfo = manager.replicatePointer(hInShapeInfo.data(), hInShapeInfo.size() * sizeof(Nd4jLong*)); + void* dInBuffers = manager.replicatePointer( + hInBuffers.data(), hInBuffers.size() * sizeof(void*)); + void* dInShapeInfo = manager.replicatePointer( + hInShapeInfo.data(), hInShapeInfo.size() * sizeof(Nd4jLong*)); - BUILD_SINGLE_SELECTOR(inArrs[0]->dataType(), concatCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), dInBuffers, dInShapeInfo, output.specialBuffer(), output.specialShapeInfo(), axis), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR( + inArrs[0]->dataType(), concatCudaLauncher, + (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), + dInBuffers, dInShapeInfo, output.specialBuffer(), + output.specialShapeInfo(), axis), + LIBND4J_TYPES); - manager.synchronize(); - // } + manager.synchronize(); + // } - NDArray::registerSpecialUse({&output}, inArrs); + NDArray::registerSpecialUse({&output}, inArrs); } -} -} -} \ No newline at end of file +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/confusion.cu b/libnd4j/include/ops/declarable/helpers/cuda/confusion.cu index fd676ba83603..a34ccb8b856b 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/confusion.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/confusion.cu @@ -18,105 +18,135 @@ // @author GS // -#include #include -#include -#include #include +#include +#include +#include namespace sd { namespace ops { namespace helpers { - template - __global__ static void copyBuffers(Nd4jLong* destination, void const* source, Nd4jLong bufferLength) { - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - const auto step = gridDim.x * blockDim.x; - for (int t = tid; t < bufferLength; t += step) { - destination[t] = static_cast(reinterpret_cast(source)[t]); - } - } - - template - __global__ static void confusionFunctorKernel(Nd4jLong* labelsBuffer, Nd4jLong* predictionBuffer, Nd4jLong bufferLength, void const* weightsBuffer, void* outputBuffer, const Nd4jLong* tadShape, const Nd4jLong* tadOffsets) { - __shared__ int arrIdx, blocksPerArr; - __shared__ T *z; - __shared__ T const* w; - __shared__ Nd4jLong *zShapeInfo, *xShapeInfo, arrLen; - - if (threadIdx.x == 0) { - z = reinterpret_cast(outputBuffer); - w = reinterpret_cast(weightsBuffer); - arrLen = shape::length(tadShape); - } - __syncthreads(); - - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - const auto step = gridDim.x * blockDim.x; - for (int t = tid; t < bufferLength; t += step) { - auto label = labelsBuffer[t]; //->e(j); - auto pred = predictionBuffer[t]; //->e(j); - auto tZ = z + tadOffsets[label]; - T val = (weightsBuffer == nullptr ? (T)1.0f : w[t]); - - auto idx = shape::getIndexOffset(pred, tadShape); - tZ[idx] = val; - } - } - - template - void _confusionFunctor(sd::LaunchContext * context, NDArray* labels, NDArray* predictions, NDArray* weights, NDArray* output) { - auto stream = context->getCudaStream(); - - auto pack = sd::ConstantTadHelper::getInstance().tadForDimensions(output->shapeInfo(), 1); - - PointersManager manager(context, "helpers::confusion"); - - Nd4jLong* labelsLongBuffer = labels->dataType() == sd::DataType::INT64?(Nd4jLong*)labels->specialBuffer():nullptr; - Nd4jLong* predictionLongBuffer = predictions->dataType() == sd::DataType::INT64?(Nd4jLong*)predictions->specialBuffer():nullptr; - - if (labelsLongBuffer == nullptr) { - auto err = cudaMalloc(&labelsLongBuffer, labels->lengthOf() * sizeof(Nd4jLong)); - if (err != 0) - throw sd::cuda_exception::build("Cannot allocate memory for labels long buffer", err); - // copy with type conversion - copyBuffers<<<256, 512, 1024, *stream>>>(labelsLongBuffer, labels->specialBuffer(), labels->lengthOf()); - } - - if (predictionLongBuffer == nullptr) { - auto err = cudaMalloc(&predictionLongBuffer, predictions->lengthOf() * sizeof(Nd4jLong)); - if (err != 0) - throw sd::cuda_exception::build("Cannot allocate memory for predictions long buffer", err); - // copy with type conversion - copyBuffers<<<256, 512, 1024, *stream>>>(predictionLongBuffer, predictions->specialBuffer(), predictions->lengthOf()); - } - - auto bufferLength = labels->lengthOf(); - dim3 launchDims(32, 32, 1024); - confusionFunctorKernel<<>>(labelsLongBuffer, predictionLongBuffer, bufferLength, weights != nullptr? weights->specialBuffer():nullptr, output->specialBuffer(), pack.specialShapeInfo(), pack.specialOffsets()); - - manager.synchronize(); - - if (predictionLongBuffer != predictions->specialBuffer()) { - cudaError_t err = cudaFree(predictionLongBuffer); - if (err != 0) - throw sd::cuda_exception::build("Cannot deallocate memory for predictions long buffer", err); - } - - if (labelsLongBuffer != labels->specialBuffer()) { - cudaError_t err = cudaFree(labelsLongBuffer); - if (err != 0) - throw sd::cuda_exception::build("Cannot deallocate memory for labels long buffer", err); - } - } - - void confusionFunctor(sd::LaunchContext * context, NDArray* labels, NDArray* predictions, NDArray* weights, NDArray* output) { - auto xType = predictions->dataType(); - auto zType = output->dataType(); // weights can be null - NDArray::prepareSpecialUse({output}, {labels, predictions, weights}); - BUILD_DOUBLE_SELECTOR(xType, zType, _confusionFunctor, (context, labels, predictions, weights, output), INDEXING_TYPES, NUMERIC_TYPES); - NDArray::registerSpecialUse({output}, {labels, predictions, weights}); - } +template +__global__ static void copyBuffers(Nd4jLong* destination, void const* source, + Nd4jLong bufferLength) { + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + const auto step = gridDim.x * blockDim.x; + for (int t = tid; t < bufferLength; t += step) { + destination[t] = + static_cast(reinterpret_cast(source)[t]); + } } + +template +__global__ static void confusionFunctorKernel( + Nd4jLong* labelsBuffer, Nd4jLong* predictionBuffer, Nd4jLong bufferLength, + void const* weightsBuffer, void* outputBuffer, const Nd4jLong* tadShape, + const Nd4jLong* tadOffsets) { + __shared__ int arrIdx, blocksPerArr; + __shared__ T* z; + __shared__ T const* w; + __shared__ Nd4jLong *zShapeInfo, *xShapeInfo, arrLen; + + if (threadIdx.x == 0) { + z = reinterpret_cast(outputBuffer); + w = reinterpret_cast(weightsBuffer); + arrLen = shape::length(tadShape); + } + __syncthreads(); + + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + const auto step = gridDim.x * blockDim.x; + for (int t = tid; t < bufferLength; t += step) { + auto label = labelsBuffer[t]; //->e(j); + auto pred = predictionBuffer[t]; //->e(j); + auto tZ = z + tadOffsets[label]; + T val = (weightsBuffer == nullptr ? (T)1.0f : w[t]); + + auto idx = shape::getIndexOffset(pred, tadShape); + tZ[idx] = val; + } +} + +template +void _confusionFunctor(sd::LaunchContext* context, NDArray* labels, + NDArray* predictions, NDArray* weights, + NDArray* output) { + auto stream = context->getCudaStream(); + + auto pack = sd::ConstantTadHelper::getInstance().tadForDimensions( + output->shapeInfo(), 1); + + PointersManager manager(context, "helpers::confusion"); + + Nd4jLong* labelsLongBuffer = labels->dataType() == sd::DataType::INT64 + ? (Nd4jLong*)labels->specialBuffer() + : nullptr; + Nd4jLong* predictionLongBuffer = + predictions->dataType() == sd::DataType::INT64 + ? (Nd4jLong*)predictions->specialBuffer() + : nullptr; + + if (labelsLongBuffer == nullptr) { + auto err = + cudaMalloc(&labelsLongBuffer, labels->lengthOf() * sizeof(Nd4jLong)); + if (err != 0) + throw sd::cuda_exception::build( + "Cannot allocate memory for labels long buffer", err); + // copy with type conversion + copyBuffers<<<256, 512, 1024, *stream>>>( + labelsLongBuffer, labels->specialBuffer(), labels->lengthOf()); + } + + if (predictionLongBuffer == nullptr) { + auto err = cudaMalloc(&predictionLongBuffer, + predictions->lengthOf() * sizeof(Nd4jLong)); + if (err != 0) + throw sd::cuda_exception::build( + "Cannot allocate memory for predictions long buffer", err); + // copy with type conversion + copyBuffers<<<256, 512, 1024, *stream>>>(predictionLongBuffer, + predictions->specialBuffer(), + predictions->lengthOf()); + } + + auto bufferLength = labels->lengthOf(); + dim3 launchDims(32, 32, 1024); + confusionFunctorKernel + <<>>( + labelsLongBuffer, predictionLongBuffer, bufferLength, + weights != nullptr ? weights->specialBuffer() : nullptr, + output->specialBuffer(), pack.specialShapeInfo(), + pack.specialOffsets()); + + manager.synchronize(); + + if (predictionLongBuffer != predictions->specialBuffer()) { + cudaError_t err = cudaFree(predictionLongBuffer); + if (err != 0) + throw sd::cuda_exception::build( + "Cannot deallocate memory for predictions long buffer", err); + } + + if (labelsLongBuffer != labels->specialBuffer()) { + cudaError_t err = cudaFree(labelsLongBuffer); + if (err != 0) + throw sd::cuda_exception::build( + "Cannot deallocate memory for labels long buffer", err); + } +} + +void confusionFunctor(sd::LaunchContext* context, NDArray* labels, + NDArray* predictions, NDArray* weights, NDArray* output) { + auto xType = predictions->dataType(); + auto zType = output->dataType(); // weights can be null + NDArray::prepareSpecialUse({output}, {labels, predictions, weights}); + BUILD_DOUBLE_SELECTOR(xType, zType, _confusionFunctor, + (context, labels, predictions, weights, output), + INDEXING_TYPES, NUMERIC_TYPES); + NDArray::registerSpecialUse({output}, {labels, predictions, weights}); } -} \ No newline at end of file +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/convolutions_col2vol.cu b/libnd4j/include/ops/declarable/helpers/cuda/convolutions_col2vol.cu index 80df76c91ae0..2f6b31688859 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/convolutions_col2vol.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/convolutions_col2vol.cu @@ -19,113 +19,136 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include #include #include +#include namespace sd { -namespace ops { +namespace ops { ////////////////////////////////////////////////////////////////////////// -// columns [bS, iC, kD, kH, kW, oD, oH, oW] to be de-convoluted to volume [bS, iC, iD, iH, iW] +// columns [bS, iC, kD, kH, kW, oD, oH, oW] to be de-convoluted to volume [bS, +// iC, iD, iH, iW] template -static __global__ void col2volCuda(const void* columns, const Nd4jLong* colShapeInfo, void* volume, const Nd4jLong* volShapeInfo, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW) { - - const T* col = reinterpret_cast(columns); - T* vol = reinterpret_cast(volume); - - __shared__ uint kD, kH, kW, oD, oH, oW, *sharedMem; - __shared__ Nd4jLong volLen; +static __global__ void col2volCuda(const void* columns, + const Nd4jLong* colShapeInfo, void* volume, + const Nd4jLong* volShapeInfo, const int sD, + const int sH, const int sW, const int pD, + const int pH, const int pW, const int dD, + const int dH, const int dW) { + const T* col = reinterpret_cast(columns); + T* vol = reinterpret_cast(volume); - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - sharedMem = reinterpret_cast(shmem); + __shared__ uint kD, kH, kW, oD, oH, oW, *sharedMem; + __shared__ Nd4jLong volLen; - oD = colShapeInfo[6]; - oH = colShapeInfo[7]; - oW = colShapeInfo[8]; + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sharedMem = reinterpret_cast(shmem); - kD = dD * (colShapeInfo[3] - 1) + 1; - kH = dH * (colShapeInfo[4] - 1) + 1; - kW = dW * (colShapeInfo[5] - 1) + 1; + oD = colShapeInfo[6]; + oH = colShapeInfo[7]; + oW = colShapeInfo[8]; - volLen = shape::length(volShapeInfo); - } - __syncthreads(); - - auto coords = sharedMem + threadIdx.x * 8; + kD = dD * (colShapeInfo[3] - 1) + 1; + kH = dH * (colShapeInfo[4] - 1) + 1; + kW = dW * (colShapeInfo[5] - 1) + 1; - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + volLen = shape::length(volShapeInfo); + } + __syncthreads(); - for (Nd4jLong i = tid; i < volLen; i += gridDim.x * blockDim.x) { + auto coords = sharedMem + threadIdx.x * 8; - shape::index2coords(i, volShapeInfo, coords); + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - const auto volOffset = shape::getOffset(volShapeInfo, coords); + for (Nd4jLong i = tid; i < volLen; i += gridDim.x * blockDim.x) { + shape::index2coords(i, volShapeInfo, coords); - const auto bSiCoffset = coords[0] * colShapeInfo[9] + coords[1] * colShapeInfo[10]; + const auto volOffset = shape::getOffset(volShapeInfo, coords); - const uint imD = coords[2] + pD; - const uint imH = coords[3] + pH; - const uint imW = coords[4] + pW; + const auto bSiCoffset = + coords[0] * colShapeInfo[9] + coords[1] * colShapeInfo[10]; - const uint colDstart = (imD < kD) ? 0 : (imD - kD) / sD + 1; - const uint colHstart = (imH < kH) ? 0 : (imH - kH) / sH + 1; - const uint colWstart = (imW < kW) ? 0 : (imW - kW) / sW + 1; + const uint imD = coords[2] + pD; + const uint imH = coords[3] + pH; + const uint imW = coords[4] + pW; - const uint colDend = sd::math::nd4j_min(imD / sD + 1, oD); - const uint colHend = sd::math::nd4j_min(imH / sH + 1, oH); - const uint colWend = sd::math::nd4j_min(imW / sW + 1, oW); + const uint colDstart = (imD < kD) ? 0 : (imD - kD) / sD + 1; + const uint colHstart = (imH < kH) ? 0 : (imH - kH) / sH + 1; + const uint colWstart = (imW < kW) ? 0 : (imW - kW) / sW + 1; - T val = 0; + const uint colDend = sd::math::nd4j_min(imD / sD + 1, oD); + const uint colHend = sd::math::nd4j_min(imH / sH + 1, oH); + const uint colWend = sd::math::nd4j_min(imW / sW + 1, oW); - for(uint colD = colDstart; colD < colDend; ++colD) { - coords[2] = imD - colD * sD; - if(coords[2] % dD != 0) continue; + T val = 0; - for(uint colH = colHstart; colH < colHend; ++colH) { - coords[3] = imH - colH * sH; - if(coords[3] % dH != 0) continue; + for (uint colD = colDstart; colD < colDend; ++colD) { + coords[2] = imD - colD * sD; + if (coords[2] % dD != 0) continue; - for(uint colW = colWstart; colW < colWend; ++colW) { - coords[4] = imW - colW * sW; - if(coords[4] % dW != 0) continue; + for (uint colH = colHstart; colH < colHend; ++colH) { + coords[3] = imH - colH * sH; + if (coords[3] % dH != 0) continue; - val += col[bSiCoffset + (coords[2]/dD)*colShapeInfo[11] + (coords[3]/dH)*colShapeInfo[12] + (coords[4]/dW)*colShapeInfo[13] + colD*colShapeInfo[14] + colH*colShapeInfo[15] + colW*colShapeInfo[16]]; + for (uint colW = colWstart; colW < colWend; ++colW) { + coords[4] = imW - colW * sW; + if (coords[4] % dW != 0) continue; - } - } + val += col[bSiCoffset + (coords[2] / dD) * colShapeInfo[11] + + (coords[3] / dH) * colShapeInfo[12] + + (coords[4] / dW) * colShapeInfo[13] + + colD * colShapeInfo[14] + colH * colShapeInfo[15] + + colW * colShapeInfo[16]]; } - - vol[volOffset] = val; + } } + + vol[volOffset] = val; + } } ////////////////////////////////////////////////////////////////////////// template -static void col2volCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, - const void* columns, const Nd4jLong* colShapeInfo, - void* volume, const Nd4jLong* volShapeInfo, - const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW) { - - col2volCuda<<>>(columns, colShapeInfo, volume, volShapeInfo, sD, sH, sW, pD, pH, pW, dD, dH, dW); +static void col2volCudaLauncher(const int blocksPerGrid, + const int threadsPerBlock, const int sharedMem, + const cudaStream_t* stream, const void* columns, + const Nd4jLong* colShapeInfo, void* volume, + const Nd4jLong* volShapeInfo, const int sD, + const int sH, const int sW, const int pD, + const int pH, const int pW, const int dD, + const int dH, const int dW) { + col2volCuda<<>>( + columns, colShapeInfo, volume, volShapeInfo, sD, sH, sW, pD, pH, pW, dD, + dH, dW); } ////////////////////////////////////////////////////////////////////////// -void ConvolutionUtils::col2vol(sd::graph::Context& block, const NDArray& col, NDArray& vol, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW) { - - PointersManager manager(block.launchContext(), "col2vol"); - - const int threadsPerBlock = MAX_NUM_THREADS / 4; - const int blocksPerGrid = (vol.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - const int sharedMem = col.rankOf() * sizeof(uint) * threadsPerBlock + 256; - - NDArray::prepareSpecialUse({&vol}, {&col}); - BUILD_SINGLE_SELECTOR(vol.dataType(), col2volCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, block.launchContext()->getCudaStream(), col.specialBuffer(), col.specialShapeInfo(), vol.specialBuffer(), vol.specialShapeInfo(), sD, sH, sW, pD, pH, pW, dD, dH, dW), FLOAT_TYPES); - NDArray::registerSpecialUse({&vol}, {&col}); - - manager.synchronize(); +void ConvolutionUtils::col2vol(sd::graph::Context& block, const NDArray& col, + NDArray& vol, const int sD, const int sH, + const int sW, const int pD, const int pH, + const int pW, const int dD, const int dH, + const int dW) { + PointersManager manager(block.launchContext(), "col2vol"); + + const int threadsPerBlock = MAX_NUM_THREADS / 4; + const int blocksPerGrid = + (vol.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + const int sharedMem = col.rankOf() * sizeof(uint) * threadsPerBlock + 256; + + NDArray::prepareSpecialUse({&vol}, {&col}); + BUILD_SINGLE_SELECTOR( + vol.dataType(), col2volCudaLauncher, + (blocksPerGrid, threadsPerBlock, sharedMem, + block.launchContext()->getCudaStream(), col.specialBuffer(), + col.specialShapeInfo(), vol.specialBuffer(), vol.specialShapeInfo(), sD, + sH, sW, pD, pH, pW, dD, dH, dW), + FLOAT_TYPES); + NDArray::registerSpecialUse({&vol}, {&col}); + + manager.synchronize(); } -} -} +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/convolutions_conv2d.cu b/libnd4j/include/ops/declarable/helpers/cuda/convolutions_conv2d.cu index 494ce4a815a6..ce4be2c3928e 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/convolutions_conv2d.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/convolutions_conv2d.cu @@ -19,87 +19,113 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include -#include -#include -#include #include #include +#include +#include +#include +#include namespace sd { -namespace ops { +namespace ops { ////////////////////////////////////////////////////////////////////////// template -static void conv2d_(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { - - // input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - // weights [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] - // bias [oC] - // output [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) - - // kH filter(kernel) height - // kW filter(kernel) width - // sH strides height - // sW strides width - // pH paddings height - // pW paddings width - // dH dilations height - // dW dilations width - // paddingMode 0-VALID, 1-SAME - // isNCHW 1-NCHW, 0-NHWC - - int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; - int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); - - ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode); - - std::vector permutForOutput; - - if(isNCHW) - permutForOutput = {0, 3, 1, 2}; // [bS, oH, oW, oC] -> [bS, oC, oH, oW] - else - input = new NDArray(input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] if NHWC - - std::vector wAxes; - if(0 == wFormat) - wAxes = {0, 1, 2}; - else if(1 == wFormat) - wAxes = {2, 3, 1}; - else - wAxes = {1, 2, 3}; - - NDArray col('c', {bS, oH, oW, kH, kW, iC}, input->dataType(), input->getContext()); - NDArray colP = col.permute({0, 5, 3, 4, 1, 2}); // {bS, iC, kH, kW, oH, oW} - NDArray mmulResult('f', {bS*oH*oW, oC}, output->dataType(), output->getContext()); - - //----- calculation of output -----// - auto ctx = block.launchContext(); - helpers::im2col(*ctx, *input, colP, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext())); // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW] - MmulHelper::tensorDot(&col, weights, &mmulResult, {3,4,5}, wAxes, {}); // [bS, oH, oW, kH, kW, iC] x [kH, kW, iC, oC] = [bS, oH, oW, oC] - - //----- assign outTemp to output -----// - if(isNCHW) { - mmulResult.reshapei({bS, oH, oW, oC}); - mmulResult.permutei(permutForOutput); - } - output->assign(mmulResult); - - //----- add biases if required -----// - if(bias) - // output->applyBroadcast(broadcast::Add, {indIOioC}, bias); - helpers::addBias(block, *output, *bias, *output, isNCHW); - - if(!isNCHW) - delete input; - +static void conv2d_(sd::graph::Context& block, const NDArray* input, + const NDArray* weights, const NDArray* bias, + NDArray* output, const int kH, const int kW, const int sH, + const int sW, int pH, int pW, const int dH, const int dW, + const int paddingMode, const int isNCHW, + const int wFormat) { + // input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + // weights [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] + // bias [oC] + // output [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) + + // kH filter(kernel) height + // kW filter(kernel) width + // sH strides height + // sW strides width + // pH paddings height + // pW paddings width + // dH dilations height + // dW dilations width + // paddingMode 0-VALID, 1-SAME + // isNCHW 1-NCHW, 0-NHWC + + int bS, iC, iH, iW, oC, oH, + oW; // batch size, input channels, input height/width, output channels, + // output height/width; + int indIOioC, indIiH, indWoC, indWiC, indWkH, + indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, + indIiH, indWiC, indWoC, indWkH, indOoH); + + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, + dW, paddingMode); + + std::vector permutForOutput; + + if (isNCHW) + permutForOutput = {0, 3, 1, 2}; // [bS, oH, oW, oC] -> [bS, oC, oH, oW] + else + input = new NDArray(input->permute( + {0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] if NHWC + + std::vector wAxes; + if (0 == wFormat) + wAxes = {0, 1, 2}; + else if (1 == wFormat) + wAxes = {2, 3, 1}; + else + wAxes = {1, 2, 3}; + + NDArray col('c', {bS, oH, oW, kH, kW, iC}, input->dataType(), + input->getContext()); + NDArray colP = col.permute({0, 5, 3, 4, 1, 2}); // {bS, iC, kH, kW, oH, oW} + NDArray mmulResult('f', {bS * oH * oW, oC}, output->dataType(), + output->getContext()); + + //----- calculation of output -----// + auto ctx = block.launchContext(); + helpers::im2col( + *ctx, *input, colP, kH, kW, sH, sW, pH, pW, dH, dW, + NDArrayFactory::create( + 0.f, input->getContext())); // [bS, iC, iH, iW] is convoluted to [bS, + // iC, kH, kW, oH, oW] + MmulHelper::tensorDot( + &col, weights, &mmulResult, {3, 4, 5}, wAxes, + {}); // [bS, oH, oW, kH, kW, iC] x [kH, kW, iC, oC] = [bS, oH, oW, oC] + + //----- assign outTemp to output -----// + if (isNCHW) { + mmulResult.reshapei({bS, oH, oW, oC}); + mmulResult.permutei(permutForOutput); + } + output->assign(mmulResult); + + //----- add biases if required -----// + if (bias) + // output->applyBroadcast(broadcast::Add, {indIOioC}, bias); + helpers::addBias(block, *output, *bias, *output, isNCHW); + + if (!isNCHW) delete input; } ////////////////////////////////////////////////////////////////////////// -void ConvolutionUtils::conv2d(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { - BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), conv2d_, (block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat), FLOAT_TYPES); +void ConvolutionUtils::conv2d(sd::graph::Context& block, const NDArray* input, + const NDArray* weights, const NDArray* bias, + NDArray* output, const int kH, const int kW, + const int sH, const int sW, int pH, int pW, + const int dH, const int dW, const int paddingMode, + const int isNCHW, const int wFormat) { + BUILD_SINGLE_SELECTOR_TWICE( + input->dataType(), conv2d_, + (block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, + paddingMode, isNCHW, wFormat), + FLOAT_TYPES); } -} -} +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/convolutions_conv2dBP.cu b/libnd4j/include/ops/declarable/helpers/cuda/convolutions_conv2dBP.cu index dbf4ee39012f..7317e0009fb0 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/convolutions_conv2dBP.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/convolutions_conv2dBP.cu @@ -19,107 +19,144 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include -#include -#include -#include #include #include +#include +#include +#include +#include namespace sd { -namespace ops { +namespace ops { ////////////////////////////////////////////////////////////////////////// template -static void conv2dBP_(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { - - // input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - // weights [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] - // bias [oC] - // gradO [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next - - // gradI [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon - // gradW [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] - // gradB [oC] - - // kH filter(kernel) height - // kW filter(kernel) width - // sH strides height - // sW strides width - // pH paddings height - // pW paddings width - // dH dilations height - // dW dilations width - // paddingMode 0-VALID, 1-SAME - // isNCHW 0-NHWC, 1-NCHW - - int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; - int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); - - ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode); - - std::vector gradOaxesForDot; - - if(!isNCHW) { - gradOaxesForDot = {0, 1, 2}; // bS, oH, oW - input = new NDArray(input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] - gradI = new NDArray(gradI->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] - } else { - gradOaxesForDot = {0, 2, 3}; // bS, oH, oW - } - - std::vector wPermut, colPermut; - if(0 == wFormat) { - wPermut = {2, 0, 1, 3}; - colPermut = {2, 3, 1, 0, 4, 5}; - } - else if(1 == wFormat) { - wPermut = {1, 2, 3, 0}; - colPermut = {1, 2, 3, 0, 4, 5}; - } - else { - wPermut = {3, 1, 2, 0}; - colPermut = {2, 3, 1, 0, 4, 5}; - } - - NDArray columns(input->ordering(), {bS, iC, kH, kW, oH, oW}, input->dataType(), input->getContext()); - - // ----- calculation of gradW ----- // - if(gradW) { - auto ctx = block.launchContext(); - helpers::im2col(*ctx, *input, columns, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext())); // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW] - sd::MmulHelper::tensorDot(&columns, gradO, gradW, {0,4,5}, gradOaxesForDot, wPermut); // [bS, iC, kH, kW, oH, oW] x [bS, oH, oW, oC]/[bS, oC, oH, oW] = [iC, kH, kW, oC] - } - - // ----- calculation of gradB ----- // - if(gradB) { - NDArray* gradBR = gradB; - if(gradB->rankOf() == 2) - gradBR = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()})); - gradO->reduceAlongDimension(reduce::Sum, *gradBR, gradOaxesForDot, false); // sum over bS, oH, oW - if(gradBR != gradB) - delete gradBR; - } - - //----- calculation of gradI -----// - // [kH, kW, iC, oC] x [bS, oH, oW, oC]/[bS, oC, oH, oW] = [kH, kW, iC, bS, oH, oW] - // [oC, iC, kH, kW] x [bS, oH, oW, oC]/[bS, oC, oH, oW] = [iC, kH, kW, bS, oH, oW] - // [oC, kH, kW, iC] x [bS, oH, oW, oC]/[bS, oC, oH, oW] = [kH, kW, iC, bS, oH, oW] - sd::MmulHelper::tensorDot(weights, gradO, &columns, {indWoC}, {indIOioC}, colPermut); // [kH, kW, iC, oC]/[oC, iC, kH, kW]] x [bS, oH, oW, oC]/[bS, oC, oH, oW] = [kH, kW, iC, bS, oH, oW] - - helpers::col2im(*block.launchContext(), columns, *gradI, sH, sW, pH, pW, iH, iW, dH, dW); // [bS, iC, kH, kW, oH, oW] is de-convoluted to [bS, iC, iH, iW] - - if(!isNCHW) { - delete input; - delete gradI; - } +static void conv2dBP_(sd::graph::Context& block, const NDArray* input, + const NDArray* weights, const NDArray* bias, + const NDArray* gradO, NDArray* gradI, NDArray* gradW, + NDArray* gradB, const int kH, const int kW, const int sH, + const int sW, int pH, int pW, const int dH, const int dW, + const int paddingMode, const int isNCHW, + const int wFormat) { + // input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + // weights [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] + // bias [oC] + // gradO [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next + + // gradI [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon + // gradW [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] + // gradB [oC] + + // kH filter(kernel) height + // kW filter(kernel) width + // sH strides height + // sW strides width + // pH paddings height + // pW paddings width + // dH dilations height + // dW dilations width + // paddingMode 0-VALID, 1-SAME + // isNCHW 0-NHWC, 1-NCHW + + int bS, iC, iH, iW, oC, oH, + oW; // batch size, input channels, input height/width, output channels, + // output height/width; + int indIOioC, indIiH, indWoC, indWiC, indWkH, + indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, + indIiH, indWiC, indWoC, indWkH, indOoH); + + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, + dW, paddingMode); + + std::vector gradOaxesForDot; + + if (!isNCHW) { + gradOaxesForDot = {0, 1, 2}; // bS, oH, oW + input = new NDArray( + input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] + gradI = new NDArray( + gradI->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] + } else { + gradOaxesForDot = {0, 2, 3}; // bS, oH, oW + } + + std::vector wPermut, colPermut; + if (0 == wFormat) { + wPermut = {2, 0, 1, 3}; + colPermut = {2, 3, 1, 0, 4, 5}; + } else if (1 == wFormat) { + wPermut = {1, 2, 3, 0}; + colPermut = {1, 2, 3, 0, 4, 5}; + } else { + wPermut = {3, 1, 2, 0}; + colPermut = {2, 3, 1, 0, 4, 5}; + } + + NDArray columns(input->ordering(), {bS, iC, kH, kW, oH, oW}, + input->dataType(), input->getContext()); + + // ----- calculation of gradW ----- // + if (gradW) { + auto ctx = block.launchContext(); + helpers::im2col( + *ctx, *input, columns, kH, kW, sH, sW, pH, pW, dH, dW, + NDArrayFactory::create( + 0.f, input->getContext())); // [bS, iC, iH, iW] is convoluted to + // [bS, iC, kH, kW, oH, oW] + sd::MmulHelper::tensorDot( + &columns, gradO, gradW, {0, 4, 5}, gradOaxesForDot, + wPermut); // [bS, iC, kH, kW, oH, oW] x [bS, oH, oW, oC]/[bS, oC, oH, + // oW] = [iC, kH, kW, oC] + } + + // ----- calculation of gradB ----- // + if (gradB) { + NDArray* gradBR = gradB; + if (gradB->rankOf() == 2) + gradBR = new NDArray( + gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()})); + gradO->reduceAlongDimension(reduce::Sum, *gradBR, gradOaxesForDot, + false); // sum over bS, oH, oW + if (gradBR != gradB) delete gradBR; + } + + //----- calculation of gradI -----// + // [kH, kW, iC, oC] x [bS, oH, oW, oC]/[bS, oC, oH, oW] = [kH, kW, iC, bS, oH, + // oW] [oC, iC, kH, kW] x [bS, oH, oW, oC]/[bS, oC, oH, oW] = [iC, kH, kW, bS, + // oH, oW] [oC, kH, kW, iC] x [bS, oH, oW, oC]/[bS, oC, oH, oW] = [kH, kW, iC, + // bS, oH, oW] + sd::MmulHelper::tensorDot( + weights, gradO, &columns, {indWoC}, {indIOioC}, + colPermut); // [kH, kW, iC, oC]/[oC, iC, kH, kW]] x [bS, oH, oW, oC]/[bS, + // oC, oH, oW] = [kH, kW, iC, bS, oH, oW] + + helpers::col2im( + *block.launchContext(), columns, *gradI, sH, sW, pH, pW, iH, iW, dH, + dW); // [bS, iC, kH, kW, oH, oW] is de-convoluted to [bS, iC, iH, iW] + + if (!isNCHW) { + delete input; + delete gradI; + } } ////////////////////////////////////////////////////////////////////////// -void ConvolutionUtils::conv2dBP(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { - BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), conv2dBP_, (block, input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat), FLOAT_TYPES); +void ConvolutionUtils::conv2dBP(sd::graph::Context& block, const NDArray* input, + const NDArray* weights, const NDArray* bias, + const NDArray* gradO, NDArray* gradI, + NDArray* gradW, NDArray* gradB, const int kH, + const int kW, const int sH, const int sW, + int pH, int pW, const int dH, const int dW, + const int paddingMode, const int isNCHW, + const int wFormat) { + BUILD_SINGLE_SELECTOR_TWICE( + input->dataType(), conv2dBP_, + (block, input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, + pH, pW, dH, dW, paddingMode, isNCHW, wFormat), + FLOAT_TYPES); } -} -} +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/convolutions_depthwiseConv2d.cu b/libnd4j/include/ops/declarable/helpers/cuda/convolutions_depthwiseConv2d.cu index bbf5d5892617..42f29c050cdc 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/convolutions_depthwiseConv2d.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/convolutions_depthwiseConv2d.cu @@ -19,83 +19,119 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include -#include -#include -#include #include #include +#include +#include +#include +#include namespace sd { -namespace ops { +namespace ops { ////////////////////////////////////////////////////////////////////////// template -static void depthwiseConv2d_(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { - - // input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - // weights [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] - // bias [oC] = iC*mC - // output [bS, oH, oW, iC*mC] (NHWC) or [bS, iC*mC, oH, oW] (NCHW) - - // kH filter(kernel) height - // kW filter(kernel) width - // sH strides height - // sW strides width - // pH paddings height - // pW paddings width - // dH dilations height - // dW dilations width - // paddingMode 0-VALID, 1-SAME - // isNCHW 0-NCHW, 1-NHWC - - int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width - int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); - mC = weights->sizeAt(indWmC); // channels multiplier - - std::vector> modifColumns = {{1,0,4,5,2,3}, {iC,bS*oH*oW,kH*kW}}; // [bS,iC,kH,kW,oH,oW] -> [iC,bS,oH,oW,kH,kW] -> [iC,bS*oH*oW,kH*kW] - std::vector> modifOutput, modifWeights; - std::vector outReShape; - - if(!isNCHW) { - outReShape = {bS, oH, oW, iC, mC}; // [bS,oH,oW,iC*mC] -> [bS,oH,oW,iC,mC] - modifOutput = {{3,0,1,2,4},{iC, bS*oH*oW, mC}}; // [bS,oH,oW,iC,mC] -> [iC,bS,oH,oW,mC] -> [iC,bS*oH*oW,mC] - input = new NDArray(input->permute({0, 3, 1, 2})); // [bS,iH,iW,iC] -> [bS,iC,iH,iW] - } - else { - outReShape = {bS, iC, mC, oH, oW}; // [bS,iC*mC,oH,oW] -> [bS,iC,mC,oH,oW] - modifOutput = {{1,0,3,4,2},{iC, bS*oH*oW, mC}}; // [bS,iC,mC,oH,oW] -> [iC,bS,oH,oW,mC] -> [iC,bS*oH*oW,mC] - } - - if(0 == wFormat) - modifWeights = {{2,0,1,3},{iC,kH*kW,mC}}; - else if(1 == wFormat) - modifWeights = {{1,2,3,0},{iC,kH*kW,mC}}; - else - modifWeights = {{3,1,2,0},{iC,kH*kW,mC}}; - - if(paddingMode == 1) // SAME - ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); - - NDArray columns(input->ordering(), {bS, iC, kH, kW, oH, oW}, input->dataType(), input->getContext()); - NDArray outputReshaped = output->reshape(output->ordering(), outReShape, false); - - helpers::im2col(*output->getContext(), *input, columns, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext())); // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW] - MmulHelper::tensorDot(&columns, weights, &outputReshaped, modifColumns, modifWeights, modifOutput); // [iC, bS*oH*oW, kW*kH] x [iC, kH*kW, mC] = [iC, bS*oH*oW, mC] - - if(bias) - // output->applyBroadcast(broadcast::Add, {indIOioC}, bias); - helpers::addBias(block, *output, *bias, *output, isNCHW); - - if(!isNCHW) - delete input; +static void depthwiseConv2d_(sd::graph::Context& block, const NDArray* input, + const NDArray* weights, const NDArray* bias, + NDArray* output, const int kH, const int kW, + const int sH, const int sW, int pH, int pW, + const int dH, const int dW, const int paddingMode, + const int isNCHW, const int wFormat) { + // input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + // weights [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] + // bias [oC] = iC*mC + // output [bS, oH, oW, iC*mC] (NHWC) or [bS, iC*mC, oH, oW] (NCHW) + + // kH filter(kernel) height + // kW filter(kernel) width + // sH strides height + // sW strides width + // pH paddings height + // pW paddings width + // dH dilations height + // dW dilations width + // paddingMode 0-VALID, 1-SAME + // isNCHW 0-NCHW, 1-NHWC + + int bS, iC, iH, iW, mC, oC, oH, + oW; // batch size, input channels, input height/width, channels + // multiplier(oC = iC*mC), output channels, output height/width + int indIOioC, indIiH, indWmC, indWiC, indWkH, + indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, + indIiH, indWiC, indWmC, indWkH, indOoH); + mC = weights->sizeAt(indWmC); // channels multiplier + + std::vector> modifColumns = { + {1, 0, 4, 5, 2, 3}, + {iC, bS * oH * oW, + kH * kW}}; // [bS,iC,kH,kW,oH,oW] -> [iC,bS,oH,oW,kH,kW] -> + // [iC,bS*oH*oW,kH*kW] + std::vector> modifOutput, modifWeights; + std::vector outReShape; + + if (!isNCHW) { + outReShape = {bS, oH, oW, iC, mC}; // [bS,oH,oW,iC*mC] -> [bS,oH,oW,iC,mC] + modifOutput = { + {3, 0, 1, 2, 4}, + {iC, bS * oH * oW, + mC}}; // [bS,oH,oW,iC,mC] -> [iC,bS,oH,oW,mC] -> [iC,bS*oH*oW,mC] + input = new NDArray( + input->permute({0, 3, 1, 2})); // [bS,iH,iW,iC] -> [bS,iC,iH,iW] + } else { + outReShape = {bS, iC, mC, oH, oW}; // [bS,iC*mC,oH,oW] -> [bS,iC,mC,oH,oW] + modifOutput = { + {1, 0, 3, 4, 2}, + {iC, bS * oH * oW, + mC}}; // [bS,iC,mC,oH,oW] -> [iC,bS,oH,oW,mC] -> [iC,bS*oH*oW,mC] + } + + if (0 == wFormat) + modifWeights = {{2, 0, 1, 3}, {iC, kH * kW, mC}}; + else if (1 == wFormat) + modifWeights = {{1, 2, 3, 0}, {iC, kH * kW, mC}}; + else + modifWeights = {{3, 1, 2, 0}, {iC, kH * kW, mC}}; + + if (paddingMode == 1) // SAME + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, + dW); + + NDArray columns(input->ordering(), {bS, iC, kH, kW, oH, oW}, + input->dataType(), input->getContext()); + NDArray outputReshaped = + output->reshape(output->ordering(), outReShape, false); + + helpers::im2col( + *output->getContext(), *input, columns, kH, kW, sH, sW, pH, pW, dH, dW, + NDArrayFactory::create( + 0.f, input->getContext())); // [bS, iC, iH, iW] is convoluted to [bS, + // iC, kH, kW, oH, oW] + MmulHelper::tensorDot(&columns, weights, &outputReshaped, modifColumns, + modifWeights, + modifOutput); // [iC, bS*oH*oW, kW*kH] x [iC, kH*kW, + // mC] = [iC, bS*oH*oW, mC] + + if (bias) + // output->applyBroadcast(broadcast::Add, {indIOioC}, bias); + helpers::addBias(block, *output, *bias, *output, isNCHW); + + if (!isNCHW) delete input; } ////////////////////////////////////////////////////////////////////////// -void ConvolutionUtils::depthwiseConv2d(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { - BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), depthwiseConv2d_, (block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat), FLOAT_TYPES); +void ConvolutionUtils::depthwiseConv2d( + sd::graph::Context& block, const NDArray* input, const NDArray* weights, + const NDArray* bias, NDArray* output, const int kH, const int kW, + const int sH, const int sW, int pH, int pW, const int dH, const int dW, + const int paddingMode, const int isNCHW, const int wFormat) { + BUILD_SINGLE_SELECTOR_TWICE( + input->dataType(), depthwiseConv2d_, + (block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, + paddingMode, isNCHW, wFormat), + FLOAT_TYPES); } -} -} +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/convolutions_depthwiseConv2dBP.cu b/libnd4j/include/ops/declarable/helpers/cuda/convolutions_depthwiseConv2dBP.cu index b06af61665dd..545ae0daa5e9 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/convolutions_depthwiseConv2dBP.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/convolutions_depthwiseConv2dBP.cu @@ -19,102 +19,155 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include -#include -#include #include #include +#include +#include +#include namespace sd { -namespace ops { +namespace ops { ////////////////////////////////////////////////////////////////////////// template -static void depthwiseConv2dBP_(const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { - - // input [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW) - // weights [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] - // bias [oC] = [iC*mC] - // gradO [bS, oH, oW, oC] (NDHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next - // gradI [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW), epsilon - // gradW [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] - // gradB [oC] - - // kH filter(kernel) height - // kW filter(kernel) width - // sH strides height - // sW strides width - // pH paddings height - // pW paddings width - // dH dilations height - // dW dilations width - // paddingMode 0-VALID, 1-SAME - // isNCHW 0-NHWC, 1-NCHW - - int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width - int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); - mC = weights->sizeAt(indWmC); // channels multiplier - - std::vector> modifColumns = {{1,2,3,0,4,5}, {iC, kH*kW, bS*oH*oW}}; // [bS,iC,kH,kW,oH,oW] -> [iC, kH*kW, bS*oH*oW] - std::vector> modifGradO1, modifGradO2, modifWeights; - std::vector gradOreShape; - - if(!isNCHW) { - gradOreShape = {bS, oH, oW, iC, mC}; // [bS,oH,oW,iC*mC] -> [bS,oH,oW,iC,mC] - modifGradO1 = {{3,0,1,2,4},{iC, bS*oH*oW, mC}}; // [bS,oH,oW,iC,mC] -> [iC,bS,oH,oW,mC] -> [iC,bS*oH*oW,mC] - modifGradO2 = {{3,0,1,2},{iC, mC, bS*oH*oW}}; // [bS,oH,oW,iC*mC] -> [iC*mC,bS,oH,oW] -> [iC,mC,bS*oH*oW] - input = new NDArray(input->permute({0, 3, 1, 2})); // [bS,iH,iW,iC] -> [bS,iC,iH,iW] - gradI = new NDArray(gradI->permute({0, 3, 1, 2})); // [bS,iH,iW,iC] -> [bS,iC,iH,iW] - } - else { - gradOreShape = {bS, iC, mC, oH, oW}; // [bS,iC*mC,oH,oW] -> [bS,iC,mC,oH,oW] - modifGradO1 = {{1,0,3,4,2},{iC, bS*oH*oW, mC}}; // [bS,iC,mC,oH,oW] -> [iC,bS,oH,oW,mC] -> [iC,bS*oH*oW,mC] - modifGradO2 = {{1,0,2,3},{iC, mC, bS*oH*oW}}; // [bS,iC*mC,oH,oW] -> [iC*mC,bS,oH,oW] -> [iC,mC,bS*oH*oW] - } - - if(0 == wFormat) - modifWeights = {{2,0,1,3},{iC,kH*kW,mC}}; - else if(1 == wFormat) - modifWeights = {{1,2,3,0},{iC,kH*kW,mC}}; - else - modifWeights = {{3,1,2,0},{iC,kH*kW,mC}}; - - if(paddingMode == 1) // SAME - ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); - - NDArray columns(input->ordering(), {bS, iC, kH, kW, oH, oW}, input->dataType(), input->getContext()); - NDArray gradOreshaped = gradO->reshape(gradO->ordering(), gradOreShape); - - // ----- calculation of gradW and gradB ----- // - - helpers::im2col(*input->getContext(), *input, columns, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext())); // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW] - sd::MmulHelper::tensorDot(&columns, &gradOreshaped, gradW, modifColumns, modifGradO1, modifWeights); // [iC, kW*kH, bS*oH*oW] x [iC, bS*oH*oW, mC] = [iC, kH*kW, mC] - - // ----- calculation of gradB ----- // - if(gradB) { - NDArray* gradBR = gradB; - if(gradB->rankOf() == 2) - gradBR = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()})); - gradO->reduceAlongDimension(reduce::Sum, *gradBR, {0,indOoH,indOoH+1}, false); // sum over bS, oH, oW - if(gradBR != gradB) - delete gradBR; - } - - //----- calculation of gradI -----// - sd::MmulHelper::tensorDot(weights, gradO, &columns, modifWeights, modifGradO2, modifColumns); // [iC, kH*kW, mC] x [iC, mC, bS*oH*oW] = [iC, kW*kH, bS*oH*oW] - helpers::col2im(*input->getContext(), columns, *gradI, sH, sW, pH, pW, iH, iW, dH, dW); // [bS, iC, kH, kW, oH, oW] is de-convoluted to [bS, iC, iH, iW] - - if(!isNCHW) { - delete input; - delete gradI; - } +static void depthwiseConv2dBP_(const NDArray* input, const NDArray* weights, + const NDArray* bias, const NDArray* gradO, + NDArray* gradI, NDArray* gradW, NDArray* gradB, + const int kH, const int kW, const int sH, + const int sW, int pH, int pW, const int dH, + const int dW, const int paddingMode, + const int isNCHW, const int wFormat) { + // input [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW) + // weights [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] + // bias [oC] = [iC*mC] + // gradO [bS, oH, oW, oC] (NDHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next + // gradI [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW), epsilon + // gradW [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] + // gradB [oC] + + // kH filter(kernel) height + // kW filter(kernel) width + // sH strides height + // sW strides width + // pH paddings height + // pW paddings width + // dH dilations height + // dW dilations width + // paddingMode 0-VALID, 1-SAME + // isNCHW 0-NHWC, 1-NCHW + + int bS, iC, iH, iW, mC, oC, oH, + oW; // batch size, input channels, input height/width, channels + // multiplier(oC = iC*mC), output channels, output height/width + int indIOioC, indIiH, indWmC, indWiC, indWkH, + indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, + indIiH, indWiC, indWmC, indWkH, indOoH); + mC = weights->sizeAt(indWmC); // channels multiplier + + std::vector> modifColumns = { + {1, 2, 3, 0, 4, 5}, + {iC, kH * kW, + bS * oH * oW}}; // [bS,iC,kH,kW,oH,oW] -> [iC, kH*kW, bS*oH*oW] + std::vector> modifGradO1, modifGradO2, modifWeights; + std::vector gradOreShape; + + if (!isNCHW) { + gradOreShape = {bS, oH, oW, iC, + mC}; // [bS,oH,oW,iC*mC] -> [bS,oH,oW,iC,mC] + modifGradO1 = { + {3, 0, 1, 2, 4}, + {iC, bS * oH * oW, + mC}}; // [bS,oH,oW,iC,mC] -> [iC,bS,oH,oW,mC] -> [iC,bS*oH*oW,mC] + modifGradO2 = { + {3, 0, 1, 2}, + {iC, mC, + bS * oH * + oW}}; // [bS,oH,oW,iC*mC] -> [iC*mC,bS,oH,oW] -> [iC,mC,bS*oH*oW] + input = new NDArray( + input->permute({0, 3, 1, 2})); // [bS,iH,iW,iC] -> [bS,iC,iH,iW] + gradI = new NDArray( + gradI->permute({0, 3, 1, 2})); // [bS,iH,iW,iC] -> [bS,iC,iH,iW] + } else { + gradOreShape = {bS, iC, mC, oH, + oW}; // [bS,iC*mC,oH,oW] -> [bS,iC,mC,oH,oW] + modifGradO1 = { + {1, 0, 3, 4, 2}, + {iC, bS * oH * oW, + mC}}; // [bS,iC,mC,oH,oW] -> [iC,bS,oH,oW,mC] -> [iC,bS*oH*oW,mC] + modifGradO2 = { + {1, 0, 2, 3}, + {iC, mC, + bS * oH * + oW}}; // [bS,iC*mC,oH,oW] -> [iC*mC,bS,oH,oW] -> [iC,mC,bS*oH*oW] + } + + if (0 == wFormat) + modifWeights = {{2, 0, 1, 3}, {iC, kH * kW, mC}}; + else if (1 == wFormat) + modifWeights = {{1, 2, 3, 0}, {iC, kH * kW, mC}}; + else + modifWeights = {{3, 1, 2, 0}, {iC, kH * kW, mC}}; + + if (paddingMode == 1) // SAME + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, + dW); + + NDArray columns(input->ordering(), {bS, iC, kH, kW, oH, oW}, + input->dataType(), input->getContext()); + NDArray gradOreshaped = gradO->reshape(gradO->ordering(), gradOreShape); + + // ----- calculation of gradW and gradB ----- // + + helpers::im2col( + *input->getContext(), *input, columns, kH, kW, sH, sW, pH, pW, dH, dW, + NDArrayFactory::create( + 0.f, input->getContext())); // [bS, iC, iH, iW] is convoluted to [bS, + // iC, kH, kW, oH, oW] + sd::MmulHelper::tensorDot(&columns, &gradOreshaped, gradW, modifColumns, + modifGradO1, + modifWeights); // [iC, kW*kH, bS*oH*oW] x [iC, + // bS*oH*oW, mC] = [iC, kH*kW, mC] + + // ----- calculation of gradB ----- // + if (gradB) { + NDArray* gradBR = gradB; + if (gradB->rankOf() == 2) + gradBR = new NDArray( + gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()})); + gradO->reduceAlongDimension(reduce::Sum, *gradBR, {0, indOoH, indOoH + 1}, + false); // sum over bS, oH, oW + if (gradBR != gradB) delete gradBR; + } + + //----- calculation of gradI -----// + sd::MmulHelper::tensorDot(weights, gradO, &columns, modifWeights, modifGradO2, + modifColumns); // [iC, kH*kW, mC] x [iC, mC, + // bS*oH*oW] = [iC, kW*kH, bS*oH*oW] + helpers::col2im( + *input->getContext(), columns, *gradI, sH, sW, pH, pW, iH, iW, dH, + dW); // [bS, iC, kH, kW, oH, oW] is de-convoluted to [bS, iC, iH, iW] + + if (!isNCHW) { + delete input; + delete gradI; + } } ////////////////////////////////////////////////////////////////////////// -void ConvolutionUtils::depthwiseConv2dBP(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { - BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), depthwiseConv2dBP_, (input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat), FLOAT_TYPES); +void ConvolutionUtils::depthwiseConv2dBP( + sd::graph::Context& block, const NDArray* input, const NDArray* weights, + const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, + NDArray* gradB, const int kH, const int kW, const int sH, const int sW, + int pH, int pW, const int dH, const int dW, const int paddingMode, + const int isNCHW, const int wFormat) { + BUILD_SINGLE_SELECTOR_TWICE( + input->dataType(), depthwiseConv2dBP_, + (input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, + dH, dW, paddingMode, isNCHW, wFormat), + FLOAT_TYPES); } -} -} +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/convolutions_pooling2d.cu b/libnd4j/include/ops/declarable/helpers/cuda/convolutions_pooling2d.cu index c146be7bff2b..1ad9a69a29cf 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/convolutions_pooling2d.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/convolutions_pooling2d.cu @@ -19,324 +19,381 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include #include #include #include +#include namespace sd { -namespace ops { - +namespace ops { ////////////////////////////////////////////////////////////////////////// template -static __global__ void avgPooling2dCuda(const void *vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int extraParam0) { - - // input is [bS, iC, iH, iW] - // output is [bS, iC, oH, oW] - - const auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); +static __global__ void avgPooling2dCuda( + const void *vx, const Nd4jLong *xShapeInfo, void *vz, + const Nd4jLong *zShapeInfo, const int kH, const int kW, const int sH, + const int sW, const int pH, const int pW, const int dH, const int dW, + const int extraParam0) { + // input is [bS, iC, iH, iW] + // output is [bS, iC, oH, oW] + + const auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + + __shared__ int bS, iC, oH, oW, iH, iW, strideB, strideC, strideY, strideX, + strideOB, strideOC, strideOY, strideOX, length, kHEff, kWEff; + + if (threadIdx.x == 0) { + bS = shape::sizeAt(xShapeInfo, 0); + iC = shape::sizeAt(xShapeInfo, 1); + oH = shape::sizeAt(zShapeInfo, 2); + oW = shape::sizeAt(zShapeInfo, 3); + iH = shape::sizeAt(xShapeInfo, 2); + iW = shape::sizeAt(xShapeInfo, 3); + + strideB = shape::stride(xShapeInfo)[0]; + strideC = shape::stride(xShapeInfo)[1]; + strideY = shape::stride(xShapeInfo)[2]; + strideX = shape::stride(xShapeInfo)[3]; + + strideOB = shape::stride(zShapeInfo)[0]; + strideOC = shape::stride(zShapeInfo)[1]; + strideOY = shape::stride(zShapeInfo)[2]; + strideOX = shape::stride(zShapeInfo)[3]; + + length = shape::length(zShapeInfo); + + // Replace kernel H/W with *effective* kernel H/W accounting for dilatyon + kHEff = kH + (kH - 1) * (dH - 1); + kWEff = kW + (kW - 1) * (dW - 1); + } + __syncthreads(); + + int tid = blockIdx.x * blockDim.x + threadIdx.x; + + for (int index = tid; index < length; index += blockDim.x * gridDim.x) { + const int pw = index % oW; + const int ph = (index / oW) % oH; + const int c = (index / oW / oH) % iC; + const int n = index / oW / oH / iC; + + int hstart = sH * ph - pH; + int wstart = sW * pw - pW; + int hend = hstart + kHEff; + int wend = wstart + kWEff; + + if (hstart < 0) { + int f = sd::math::nd4j_ceil((Z)-hstart / (Z)dH); + hstart += f * dH; + } + if (wstart < 0) { + int f = sd::math::nd4j_ceil((Z)-wstart / (Z)dW); + wstart += f * dW; + } + if (hend > iH) { + int f = sd::math::nd4j_ceil((Z)(hend - iH) / (Z)dH); + hend -= f * dH; + } + if (wend > iW) { + int f = sd::math::nd4j_ceil((Z)(wend - iW) / (Z)dW); + wend -= f * dW; + } - __shared__ int bS, iC, oH, oW, iH, iW, strideB, strideC, strideY, strideX, strideOB, strideOC, strideOY, strideOX, length, kHEff, kWEff; + // Accounts for dilation + int pool_size = + sd::math::nd4j_ceil((double)(hend - hstart) / (double)dH) * + sd::math::nd4j_ceil((double)(wend - wstart) / (double)dW); - if (threadIdx.x == 0) { - bS = shape::sizeAt(xShapeInfo, 0); - iC = shape::sizeAt(xShapeInfo, 1); - oH = shape::sizeAt(zShapeInfo, 2); - oW = shape::sizeAt(zShapeInfo, 3); - iH = shape::sizeAt(xShapeInfo, 2); - iW = shape::sizeAt(xShapeInfo, 3); + Z sum = 0.0f; - strideB = shape::stride(xShapeInfo)[0]; - strideC = shape::stride(xShapeInfo)[1]; - strideY = shape::stride(xShapeInfo)[2]; - strideX = shape::stride(xShapeInfo)[3]; + const X *inSlice = x + (n * strideB + c * strideC); - strideOB = shape::stride(zShapeInfo)[0]; - strideOC = shape::stride(zShapeInfo)[1]; - strideOY = shape::stride(zShapeInfo)[2]; - strideOX = shape::stride(zShapeInfo)[3]; + for (int h = hstart; h < hend; h += dH) + for (int w = wstart; w < wend; w += dW) + sum += static_cast(inSlice[h * strideY + w * strideX]); - length = shape::length(zShapeInfo); + int divide_factor = pool_size; // Case 0: exclude padding + if (extraParam0 == 1) // Case 1: include padding + divide_factor = kH * kW; - //Replace kernel H/W with *effective* kernel H/W accounting for dilatyon - kHEff = kH + (kH-1)*(dH-1); - kWEff = kW + (kW-1)*(dW-1); - } - __syncthreads(); - - int tid = blockIdx.x * blockDim.x + threadIdx.x; - - for (int index = tid; index < length; index += blockDim.x * gridDim.x) { - - const int pw = index % oW; - const int ph = (index / oW) % oH; - const int c = (index / oW / oH) % iC; - const int n = index / oW / oH / iC; - - int hstart = sH * ph - pH; - int wstart = sW * pw - pW; - int hend = hstart + kHEff; - int wend = wstart + kWEff; - - if(hstart < 0){ - int f = sd::math::nd4j_ceil((Z) -hstart / (Z)dH); - hstart += f * dH; - } - if(wstart < 0){ - int f = sd::math::nd4j_ceil((Z) -wstart / (Z) dW); - wstart += f * dW; - } - if(hend > iH){ - int f = sd::math::nd4j_ceil((Z) (hend-iH) / (Z) dH); - hend -= f * dH; - } - if(wend > iW){ - int f = sd::math::nd4j_ceil((Z) (wend-iW) / (Z) dW); - wend -= f * dW; - } - - //Accounts for dilation - int pool_size = sd::math::nd4j_ceil((double) (hend-hstart) / (double) dH) * sd::math::nd4j_ceil((double) (wend-wstart) / (double) dW); - - Z sum = 0.0f; - - const X *inSlice = x + (n * strideB + c * strideC); - - for (int h = hstart; h < hend; h += dH) - for (int w = wstart; w < wend; w += dW) - sum += static_cast(inSlice[h * strideY + w * strideX]); - - int divide_factor = pool_size; //Case 0: exclude padding - if (extraParam0 == 1) //Case 1: include padding - divide_factor = kH * kW; - - z[n * strideOB + c * strideOC + pw * strideOX + ph * strideOY] = sum / static_cast(divide_factor); - } + z[n * strideOB + c * strideOC + pw * strideOX + ph * strideOY] = + sum / static_cast(divide_factor); + } } ////////////////////////////////////////////////////////////////////////// template -static void avgPooling2dCudaLauncher(sd::LaunchContext & block, const void *vx, const Nd4jLong *vxShapeInfo, void *vz, const Nd4jLong *vzShapeInfo, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int extraParam0) { - avgPooling2dCuda<<<512, 512, 4192, *block.getCudaStream()>>>(vx, vxShapeInfo, vz, vzShapeInfo, kH, kW, sH, sW, pH, pW, dH, dW, extraParam0); +static void avgPooling2dCudaLauncher(sd::LaunchContext &block, const void *vx, + const Nd4jLong *vxShapeInfo, void *vz, + const Nd4jLong *vzShapeInfo, const int kH, + const int kW, const int sH, const int sW, + const int pH, const int pW, const int dH, + const int dW, const int extraParam0) { + avgPooling2dCuda<<<512, 512, 4192, *block.getCudaStream()>>>( + vx, vxShapeInfo, vz, vzShapeInfo, kH, kW, sH, sW, pH, pW, dH, dW, + extraParam0); } ////////////////////////////////////////////////////////////////////////// template -static __global__ void pnormPooling2dCuda(const void *vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int extraParam0) { - - // input is [bS, iC, iH, iW] - // output is [bS, iC, oH, oW] - - const auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - - __shared__ int bS, iC, oH, oW, iH, iW, strideB, strideC, strideY, strideX, strideOB, strideOC, strideOY, strideOX, length, kHEff, kWEff; - __shared__ bool fOrder; - - if (threadIdx.x == 0) { - bS = shape::sizeAt(xShapeInfo, 0); - iC = shape::sizeAt(xShapeInfo, 1); - oH = shape::sizeAt(zShapeInfo, 2); - oW = shape::sizeAt(zShapeInfo, 3); - iH = shape::sizeAt(xShapeInfo, 2); - iW = shape::sizeAt(xShapeInfo, 3); +static __global__ void pnormPooling2dCuda( + const void *vx, const Nd4jLong *xShapeInfo, void *vz, + const Nd4jLong *zShapeInfo, const int kH, const int kW, const int sH, + const int sW, const int pH, const int pW, const int dH, const int dW, + const int extraParam0) { + // input is [bS, iC, iH, iW] + // output is [bS, iC, oH, oW] + + const auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + + __shared__ int bS, iC, oH, oW, iH, iW, strideB, strideC, strideY, strideX, + strideOB, strideOC, strideOY, strideOX, length, kHEff, kWEff; + __shared__ bool fOrder; + + if (threadIdx.x == 0) { + bS = shape::sizeAt(xShapeInfo, 0); + iC = shape::sizeAt(xShapeInfo, 1); + oH = shape::sizeAt(zShapeInfo, 2); + oW = shape::sizeAt(zShapeInfo, 3); + iH = shape::sizeAt(xShapeInfo, 2); + iW = shape::sizeAt(xShapeInfo, 3); + + strideB = shape::stride(xShapeInfo)[0]; + strideC = shape::stride(xShapeInfo)[1]; + strideY = shape::stride(xShapeInfo)[2]; + strideX = shape::stride(xShapeInfo)[3]; + + strideOB = shape::stride(zShapeInfo)[0]; + strideOC = shape::stride(zShapeInfo)[1]; + strideOY = shape::stride(zShapeInfo)[2]; + strideOX = shape::stride(zShapeInfo)[3]; + + length = shape::length(zShapeInfo); + + // Replace kernel H/W with *effective* kernel H/W accounting for dilatyon + kHEff = kH + (kH - 1) * (dH - 1); + kWEff = kW + (kW - 1) * (dW - 1); + } + __syncthreads(); + + int tid = blockIdx.x * blockDim.x + threadIdx.x; + + for (int index = tid; index < length; index += blockDim.x * gridDim.x) { + const int pw = index % oW; + const int ph = (index / oW) % oH; + const int c = (index / oW / oH) % iC; + const int n = index / oW / oH / iC; + + int hstart = sH * ph - pH; + int wstart = sW * pw - pW; + int hend = hstart + kHEff; + int wend = wstart + kWEff; + + if (hstart < 0) { + int f = sd::math::nd4j_ceil((Z)-hstart / (Z)dH); + hstart += f * dH; + } + if (wstart < 0) { + int f = sd::math::nd4j_ceil((Z)-wstart / (Z)dW); + wstart += f * dW; + } + if (hend > iH) { + int f = sd::math::nd4j_ceil((Z)(hend - iH) / (Z)dH); + hend -= f * dH; + } + if (wend > iW) { + int f = sd::math::nd4j_ceil((Z)(wend - iW) / (Z)dW); + wend -= f * dW; + } + // Accounts for dilation + int pool_size = + sd::math::nd4j_ceil((double)(hend - hstart) / (double)dH) * + sd::math::nd4j_ceil((double)(wend - wstart) / (double)dW); - strideB = shape::stride(xShapeInfo)[0]; - strideC = shape::stride(xShapeInfo)[1]; - strideY = shape::stride(xShapeInfo)[2]; - strideX = shape::stride(xShapeInfo)[3]; + Z sum = 0.f; - strideOB = shape::stride(zShapeInfo)[0]; - strideOC = shape::stride(zShapeInfo)[1]; - strideOY = shape::stride(zShapeInfo)[2]; - strideOX = shape::stride(zShapeInfo)[3]; + const X *inSlice = x + (n * strideB + c * strideC); - length = shape::length(zShapeInfo); + for (int h = hstart; h < hend; h += dH) + for (int w = wstart; w < wend; w += dW) + sum += sd::math::nd4j_pow( + static_cast( + sd::math::nd4j_abs(inSlice[h * strideY + w * strideX])), + extraParam0); - //Replace kernel H/W with *effective* kernel H/W accounting for dilatyon - kHEff = kH + (kH-1)*(dH-1); - kWEff = kW + (kW-1)*(dW-1); - } - __syncthreads(); - - int tid = blockIdx.x * blockDim.x + threadIdx.x; - - for (int index = tid; index < length; index += blockDim.x * gridDim.x) { - - const int pw = index % oW; - const int ph = (index / oW) % oH; - const int c = (index / oW / oH) % iC; - const int n = index / oW / oH / iC; - - int hstart = sH * ph - pH; - int wstart = sW * pw - pW; - int hend = hstart + kHEff; - int wend = wstart + kWEff; - - if (hstart < 0) { - int f = sd::math::nd4j_ceil((Z) -hstart / (Z) dH); - hstart += f * dH; - } - if (wstart < 0) { - int f = sd::math::nd4j_ceil((Z) -wstart / (Z) dW); - wstart += f * dW; - } - if (hend > iH) { - int f = sd::math::nd4j_ceil((Z) (hend - iH) / (Z) dH); - hend -= f * dH; - } - if (wend > iW) { - int f = sd::math::nd4j_ceil((Z) (wend - iW) / (Z) dW); - wend -= f * dW; - } - //Accounts for dilation - int pool_size = sd::math::nd4j_ceil((double) (hend - hstart) / (double) dH) * - sd::math::nd4j_ceil((double) (wend - wstart) / (double) dW); - - Z sum = 0.f; - - const X *inSlice = x + (n * strideB + c * strideC); - - for (int h = hstart; h < hend; h += dH) - for (int w = wstart; w < wend; w += dW) - sum += sd::math::nd4j_pow(static_cast(sd::math::nd4j_abs(inSlice[h * strideY + w * strideX])), extraParam0); - - z[n * strideOB + c * strideOC + pw * strideOX + ph * strideOY] = sd::math::nd4j_pow(sum, (Z) 1.0f / extraParam0); - } + z[n * strideOB + c * strideOC + pw * strideOX + ph * strideOY] = + sd::math::nd4j_pow(sum, (Z)1.0f / extraParam0); + } } ////////////////////////////////////////////////////////////////////////// template -static void pnormPooling2dCudaLauncher(sd::LaunchContext & block, const void *vx, const Nd4jLong *vxShapeInfo, void *vz, const Nd4jLong *vzShapeInfo, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int extraParam0) { - pnormPooling2dCuda<<<512, 512, 4192, *block.getCudaStream()>>>(vx, vxShapeInfo, vz, vzShapeInfo, kH, kW, sH, sW, pH, pW, dH, dW, extraParam0); +static void pnormPooling2dCudaLauncher(sd::LaunchContext &block, const void *vx, + const Nd4jLong *vxShapeInfo, void *vz, + const Nd4jLong *vzShapeInfo, + const int kH, const int kW, const int sH, + const int sW, const int pH, const int pW, + const int dH, const int dW, + const int extraParam0) { + pnormPooling2dCuda<<<512, 512, 4192, *block.getCudaStream()>>>( + vx, vxShapeInfo, vz, vzShapeInfo, kH, kW, sH, sW, pH, pW, dH, dW, + extraParam0); } ////////////////////////////////////////////////////////////////////////// template -static __global__ void maxPooling2dCuda(const void *vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int extraParam0) { - - // input is [bS, iC, iH, iW] - // output is [bS, iC, oH, oW] - - const auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - - __shared__ int bS, iC, oH, oW, iH, iW, strideB, strideC, strideY, strideX, strideOB, strideOC, strideOY, strideOX, length, kHEff, kWEff; - __shared__ bool fOrder; - - if (threadIdx.x == 0) { - bS = shape::sizeAt(xShapeInfo, 0); - iC = shape::sizeAt(xShapeInfo, 1); - oH = shape::sizeAt(zShapeInfo, 2); - oW = shape::sizeAt(zShapeInfo, 3); - iH = shape::sizeAt(xShapeInfo, 2); - iW = shape::sizeAt(xShapeInfo, 3); - - strideB = shape::stride(xShapeInfo)[0]; - strideC = shape::stride(xShapeInfo)[1]; - strideY = shape::stride(xShapeInfo)[2]; - strideX = shape::stride(xShapeInfo)[3]; +static __global__ void maxPooling2dCuda( + const void *vx, const Nd4jLong *xShapeInfo, void *vz, + const Nd4jLong *zShapeInfo, const int kH, const int kW, const int sH, + const int sW, const int pH, const int pW, const int dH, const int dW, + const int extraParam0) { + // input is [bS, iC, iH, iW] + // output is [bS, iC, oH, oW] + + const auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + + __shared__ int bS, iC, oH, oW, iH, iW, strideB, strideC, strideY, strideX, + strideOB, strideOC, strideOY, strideOX, length, kHEff, kWEff; + __shared__ bool fOrder; + + if (threadIdx.x == 0) { + bS = shape::sizeAt(xShapeInfo, 0); + iC = shape::sizeAt(xShapeInfo, 1); + oH = shape::sizeAt(zShapeInfo, 2); + oW = shape::sizeAt(zShapeInfo, 3); + iH = shape::sizeAt(xShapeInfo, 2); + iW = shape::sizeAt(xShapeInfo, 3); + + strideB = shape::stride(xShapeInfo)[0]; + strideC = shape::stride(xShapeInfo)[1]; + strideY = shape::stride(xShapeInfo)[2]; + strideX = shape::stride(xShapeInfo)[3]; + + strideOB = shape::stride(zShapeInfo)[0]; + strideOC = shape::stride(zShapeInfo)[1]; + strideOY = shape::stride(zShapeInfo)[2]; + strideOX = shape::stride(zShapeInfo)[3]; + + length = shape::length(zShapeInfo); + + // Replace kernel H/W with *effective* kernel H/W accounting for dilatyon + kHEff = kH + (kH - 1) * (dH - 1); + kWEff = kW + (kW - 1) * (dW - 1); + } + __syncthreads(); + + int tid = blockIdx.x * blockDim.x + threadIdx.x; + + for (int index = tid; index < length; index += blockDim.x * gridDim.x) { + const int pw = index % oW; + const int ph = (index / oW) % oH; + const int c = (index / oW / oH) % iC; + const int n = index / oW / oH / iC; + + int hstart = sH * ph - pH; + int wstart = sW * pw - pW; + int hend = hstart + kHEff; + int wend = wstart + kWEff; + + if (hstart < 0) { + int f = sd::math::nd4j_ceil((Z)-hstart / (Z)dH); + hstart += f * dH; + } + if (wstart < 0) { + int f = sd::math::nd4j_ceil((Z)-wstart / (Z)dW); + wstart += f * dW; + } + if (hend > iH) { + int f = sd::math::nd4j_ceil((Z)(hend - iH) / (Z)dH); + hend -= f * dH; + } + if (wend > iW) { + int f = sd::math::nd4j_ceil((Z)(wend - iW) / (Z)dW); + wend -= f * dW; + } + // Accounts for dilation + int pool_size = + sd::math::nd4j_ceil((double)(hend - hstart) / (double)dH) * + sd::math::nd4j_ceil((double)(wend - wstart) / (double)dW); - strideOB = shape::stride(zShapeInfo)[0]; - strideOC = shape::stride(zShapeInfo)[1]; - strideOY = shape::stride(zShapeInfo)[2]; - strideOX = shape::stride(zShapeInfo)[3]; + Z max = -sd::DataTypeUtils::max(); - length = shape::length(zShapeInfo); + const X *inSlice = x + (n * strideB + c * strideC); - //Replace kernel H/W with *effective* kernel H/W accounting for dilatyon - kHEff = kH + (kH-1)*(dH-1); - kWEff = kW + (kW-1)*(dW-1); - } - __syncthreads(); - - int tid = blockIdx.x * blockDim.x + threadIdx.x; - - for (int index = tid; index < length; index += blockDim.x * gridDim.x) { - - const int pw = index % oW; - const int ph = (index / oW) % oH; - const int c = (index / oW / oH) % iC; - const int n = index / oW / oH / iC; - - int hstart = sH * ph - pH; - int wstart = sW * pw - pW; - int hend = hstart + kHEff; - int wend = wstart + kWEff; - - if(hstart < 0){ - int f = sd::math::nd4j_ceil((Z) -hstart / (Z)dH); - hstart += f * dH; - } - if(wstart < 0){ - int f = sd::math::nd4j_ceil((Z) -wstart / (Z) dW); - wstart += f * dW; - } - if(hend > iH){ - int f = sd::math::nd4j_ceil((Z) (hend-iH) / (Z) dH); - hend -= f * dH; - } - if(wend > iW){ - int f = sd::math::nd4j_ceil((Z) (wend-iW) / (Z) dW); - wend -= f * dW; - } - //Accounts for dilation - int pool_size = sd::math::nd4j_ceil((double) (hend-hstart) / (double) dH) * sd::math::nd4j_ceil((double) (wend-wstart) / (double) dW); - - Z max = -sd::DataTypeUtils::max(); - - const X *inSlice = x + (n * strideB + c * strideC); - - for (int h = hstart; h < hend; h += dH) { - for (int w = wstart; w < wend; w += dW) { - Z v = static_cast(inSlice[h * strideY + w * strideX]); - if (v > max) - max = v; - } - } - - z[n * strideOB + c * strideOC + pw * strideOX + ph * strideOY] = max; + for (int h = hstart; h < hend; h += dH) { + for (int w = wstart; w < wend; w += dW) { + Z v = static_cast(inSlice[h * strideY + w * strideX]); + if (v > max) max = v; + } } + + z[n * strideOB + c * strideOC + pw * strideOX + ph * strideOY] = max; + } } ////////////////////////////////////////////////////////////////////////// template -static void maxPooling2dCudaLauncher(sd::LaunchContext & block, const void *vx, const Nd4jLong *vxShapeInfo, void *vz, const Nd4jLong *vzShapeInfo, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int extraParam0) { - maxPooling2dCuda<<<512, 512, 4192, *block.getCudaStream()>>>(vx, vxShapeInfo, vz, vzShapeInfo, kH, kW, sH, sW, pH, pW, dH, dW, extraParam0); +static void maxPooling2dCudaLauncher(sd::LaunchContext &block, const void *vx, + const Nd4jLong *vxShapeInfo, void *vz, + const Nd4jLong *vzShapeInfo, const int kH, + const int kW, const int sH, const int sW, + const int pH, const int pW, const int dH, + const int dW, const int extraParam0) { + maxPooling2dCuda<<<512, 512, 4192, *block.getCudaStream()>>>( + vx, vxShapeInfo, vz, vzShapeInfo, kH, kW, sH, sW, pH, pW, dH, dW, + extraParam0); } ////////////////////////////////////////////////////////////////////////// -void ConvolutionUtils::pooling2d(sd::graph::Context& block, const NDArray& input, NDArray& output, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const PoolingType poolingMode, const int extraParam0) { - - if(!input.isActualOnDeviceSide()) input.syncToDevice(); - - switch (poolingMode) { - - case MAX_POOL: { - BUILD_SINGLE_SELECTOR_TWICE(input.dataType(), maxPooling2dCudaLauncher, (*block.launchContext(), input.specialBuffer(), input.specialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), kH, kW, sH, sW, pH, pW, dH, dW, extraParam0), FLOAT_TYPES); - } - break; - case AVG_POOL: { - BUILD_SINGLE_SELECTOR_TWICE(input.dataType(), avgPooling2dCudaLauncher, (*block.launchContext(), input.specialBuffer(), input.specialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), kH, kW, sH, sW, pH, pW, dH, dW, extraParam0), FLOAT_TYPES); - } - break; - case PNORM_POOL: { - BUILD_SINGLE_SELECTOR_TWICE(input.dataType(), pnormPooling2dCudaLauncher, (*block.launchContext(), input.specialBuffer(), input.specialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), kH, kW, sH, sW, pH, pW, dH, dW, extraParam0), FLOAT_TYPES); - } - break; - default: - throw std::runtime_error("Pooling2D: Unknown PoolingType used"); - } - - output.tickWriteDevice(); - input.tickReadDevice(); - - auto result = cudaStreamSynchronize(*block.launchContext()->getCudaStream()); - if (result != 0) - throw cuda_exception::build("Pooling2D failed", result); +void ConvolutionUtils::pooling2d(sd::graph::Context &block, + const NDArray &input, NDArray &output, + const int kH, const int kW, const int sH, + const int sW, const int pH, const int pW, + const int dH, const int dW, + const PoolingType poolingMode, + const int extraParam0) { + if (!input.isActualOnDeviceSide()) input.syncToDevice(); + + switch (poolingMode) { + case MAX_POOL: { + BUILD_SINGLE_SELECTOR_TWICE( + input.dataType(), maxPooling2dCudaLauncher, + (*block.launchContext(), input.specialBuffer(), + input.specialShapeInfo(), output.specialBuffer(), + output.specialShapeInfo(), kH, kW, sH, sW, pH, pW, dH, dW, + extraParam0), + FLOAT_TYPES); + } break; + case AVG_POOL: { + BUILD_SINGLE_SELECTOR_TWICE( + input.dataType(), avgPooling2dCudaLauncher, + (*block.launchContext(), input.specialBuffer(), + input.specialShapeInfo(), output.specialBuffer(), + output.specialShapeInfo(), kH, kW, sH, sW, pH, pW, dH, dW, + extraParam0), + FLOAT_TYPES); + } break; + case PNORM_POOL: { + BUILD_SINGLE_SELECTOR_TWICE( + input.dataType(), pnormPooling2dCudaLauncher, + (*block.launchContext(), input.specialBuffer(), + input.specialShapeInfo(), output.specialBuffer(), + output.specialShapeInfo(), kH, kW, sH, sW, pH, pW, dH, dW, + extraParam0), + FLOAT_TYPES); + } break; + default: + throw std::runtime_error("Pooling2D: Unknown PoolingType used"); + } + + output.tickWriteDevice(); + input.tickReadDevice(); + + auto result = cudaStreamSynchronize(*block.launchContext()->getCudaStream()); + if (result != 0) throw cuda_exception::build("Pooling2D failed", result); } -} -} +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/convolutions_pooling2dBP.cu b/libnd4j/include/ops/declarable/helpers/cuda/convolutions_pooling2dBP.cu index 62f4787ddc02..398fd872f6c5 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/convolutions_pooling2dBP.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/convolutions_pooling2dBP.cu @@ -19,170 +19,189 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include #include #include +#include namespace sd { -namespace ops { +namespace ops { ////////////////////////////////////////////////////////////////////////// template -__global__ static void pooling2dBPCuda(const void* vx, const Nd4jLong* xShapeInfo, const void* vy, const Nd4jLong* yShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int poolingMode, const int extraParam0) { - - // x: input [bS, iC, iH, iW] - // y: gradO [bS, iC, oH, oW] - // z: gradI [bS, iC, iH, iW] -> gradI is output in this function - - const T* x = reinterpret_cast(vx); - const T* y = reinterpret_cast(vy); - T* z = reinterpret_cast(vz); - - Nd4jLong coord2, coord3; - __shared__ int rank, kHeff, kWeff, iH, iW, kProd; - __shared__ Nd4jLong yLen, *sharedMem; - - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - sharedMem = reinterpret_cast(shmem); - - yLen = shape::length(yShapeInfo); - rank = 4; - - kHeff = kH + (kH - 1) * (dH - 1); - kWeff = kW + (kW - 1) * (dW - 1); - - iH = xShapeInfo[3]; - iW = xShapeInfo[4]; - - kProd = kH * kW; - } - __syncthreads(); - - const auto yInd = threadIdx.x + blockIdx.x * blockDim.x; - - if(yInd >= yLen) - return; - - auto coords = sharedMem + threadIdx.x * rank; - - shape::index2coords(yInd, yShapeInfo, coords); - - const auto yOffset = shape::getOffset(yShapeInfo, coords); - - int hstart = coords[2] * sH - pH; - int wstart = coords[3] * sW - pW; - int hend = hstart + kHeff; - int wend = wstart + kWeff; - if(hstart < 0) - hstart += dH * ((-hstart + dH - 1) / dH); - if(wstart < 0) - wstart += dW * ((-wstart + dW - 1) / dW); - if(hend > iH) - hend -= dH * ((hend - iH + dH - 1) / dH); - if(wend > iW) - wend -= dW * ((wend - iW + dW - 1) / dW); - - - switch (poolingMode) { - - /*** max ***/ - case 0: { - coord2 = hstart; - coord3 = wstart; - - T max = -DataTypeUtils::max(); - for (coords[2] = hstart; coords[2] < hend; coords[2] += dH) { - for (coords[3] = wstart; coords[3] < wend; coords[3] += dW){ - T val = x[shape::getOffset(xShapeInfo, coords)]; - if (val > max) { - max = val; - coord2 = coords[2]; - coord3 = coords[3]; - } - } - } - coords[2] = coord2; - coords[3] = coord3; - auto zOffset = shape::getOffset(zShapeInfo, coords); - sd::math::atomics::nd4j_atomicAdd(&z[zOffset], y[yOffset]); - //z[zOffset] += y[yOffset]; +__global__ static void pooling2dBPCuda( + const void* vx, const Nd4jLong* xShapeInfo, const void* vy, + const Nd4jLong* yShapeInfo, void* vz, const Nd4jLong* zShapeInfo, + const int kH, const int kW, const int sH, const int sW, const int pH, + const int pW, const int dH, const int dW, const int poolingMode, + const int extraParam0) { + // x: input [bS, iC, iH, iW] + // y: gradO [bS, iC, oH, oW] + // z: gradI [bS, iC, iH, iW] -> gradI is output in this function + + const T* x = reinterpret_cast(vx); + const T* y = reinterpret_cast(vy); + T* z = reinterpret_cast(vz); + + Nd4jLong coord2, coord3; + __shared__ int rank, kHeff, kWeff, iH, iW, kProd; + __shared__ Nd4jLong yLen, *sharedMem; + + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sharedMem = reinterpret_cast(shmem); + + yLen = shape::length(yShapeInfo); + rank = 4; + + kHeff = kH + (kH - 1) * (dH - 1); + kWeff = kW + (kW - 1) * (dW - 1); + + iH = xShapeInfo[3]; + iW = xShapeInfo[4]; + + kProd = kH * kW; + } + __syncthreads(); + + const auto yInd = threadIdx.x + blockIdx.x * blockDim.x; + + if (yInd >= yLen) return; + + auto coords = sharedMem + threadIdx.x * rank; + + shape::index2coords(yInd, yShapeInfo, coords); + + const auto yOffset = shape::getOffset(yShapeInfo, coords); + + int hstart = coords[2] * sH - pH; + int wstart = coords[3] * sW - pW; + int hend = hstart + kHeff; + int wend = wstart + kWeff; + if (hstart < 0) hstart += dH * ((-hstart + dH - 1) / dH); + if (wstart < 0) wstart += dW * ((-wstart + dW - 1) / dW); + if (hend > iH) hend -= dH * ((hend - iH + dH - 1) / dH); + if (wend > iW) wend -= dW * ((wend - iW + dW - 1) / dW); + + switch (poolingMode) { + /*** max ***/ + case 0: { + coord2 = hstart; + coord3 = wstart; + + T max = -DataTypeUtils::max(); + for (coords[2] = hstart; coords[2] < hend; coords[2] += dH) { + for (coords[3] = wstart; coords[3] < wend; coords[3] += dW) { + T val = x[shape::getOffset(xShapeInfo, coords)]; + if (val > max) { + max = val; + coord2 = coords[2]; + coord3 = coords[3]; + } } - break; - - /*** avg ***/ - case 1: { - - T val = y[yOffset]; - - if (extraParam0 == 0) //Exclude padding - val /= sd::math::nd4j_ceil(static_cast(hend - hstart) / static_cast(dH)) * sd::math::nd4j_ceil(static_cast(wend - wstart) / static_cast(dW)); //Accounts for dilation - else if (extraParam0 == 1) //Include padding - val /= kProd; - - for (coords[2] = hstart; coords[2] < hend; coords[2] += dH) - for (coords[3] = wstart; coords[3] < wend; coords[3] += dW) - sd::math::atomics::nd4j_atomicAdd(&z[shape::getOffset(zShapeInfo, coords)], val); - } - break; - - /*** pnorm ***/ - case 2: { - - T sum = static_cast(0.); - T val = y[yOffset]; - - for (coords[2] = hstart; coords[2] < hend; coords[2] += dH) - for (coords[3] = wstart; coords[3] < wend; coords[3] += dW) - sum += sd::math::nd4j_pow(sd::math::nd4j_abs(x[shape::getOffset(xShapeInfo, coords)]), extraParam0); - - val *= sd::math::nd4j_pow(sum, ((T)1.f - extraParam0) / extraParam0); - - for (coords[2] = hstart; coords[2] < hend; coords[2] += dH) { - for (coords[3] = wstart; coords[3] < wend; coords[3] += dW) { - const auto xOffset = shape::getOffset(xShapeInfo, coords); - const auto zOffset = shape::getOffset(zShapeInfo, coords); - sd::math::atomics::nd4j_atomicAdd(&z[zOffset], val * sd::math::nd4j_pow(sd::math::nd4j_abs(x[xOffset]), extraParam0 - 1.f) * sd::math::nd4j_sgn(x[xOffset])); - } - } + } + coords[2] = coord2; + coords[3] = coord3; + auto zOffset = shape::getOffset(zShapeInfo, coords); + sd::math::atomics::nd4j_atomicAdd(&z[zOffset], y[yOffset]); + // z[zOffset] += y[yOffset]; + } break; + + /*** avg ***/ + case 1: { + T val = y[yOffset]; + + if (extraParam0 == 0) // Exclude padding + val /= + sd::math::nd4j_ceil(static_cast(hend - hstart) / + static_cast(dH)) * + sd::math::nd4j_ceil( + static_cast(wend - wstart) / + static_cast(dW)); // Accounts for dilation + else if (extraParam0 == 1) // Include padding + val /= kProd; + + for (coords[2] = hstart; coords[2] < hend; coords[2] += dH) + for (coords[3] = wstart; coords[3] < wend; coords[3] += dW) + sd::math::atomics::nd4j_atomicAdd( + &z[shape::getOffset(zShapeInfo, coords)], val); + } break; + + /*** pnorm ***/ + case 2: { + T sum = static_cast(0.); + T val = y[yOffset]; + + for (coords[2] = hstart; coords[2] < hend; coords[2] += dH) + for (coords[3] = wstart; coords[3] < wend; coords[3] += dW) + sum += sd::math::nd4j_pow( + sd::math::nd4j_abs(x[shape::getOffset(xShapeInfo, coords)]), + extraParam0); + + val *= sd::math::nd4j_pow(sum, + ((T)1.f - extraParam0) / extraParam0); + + for (coords[2] = hstart; coords[2] < hend; coords[2] += dH) { + for (coords[3] = wstart; coords[3] < wend; coords[3] += dW) { + const auto xOffset = shape::getOffset(xShapeInfo, coords); + const auto zOffset = shape::getOffset(zShapeInfo, coords); + sd::math::atomics::nd4j_atomicAdd( + &z[zOffset], + val * + sd::math::nd4j_pow(sd::math::nd4j_abs(x[xOffset]), + extraParam0 - 1.f) * + sd::math::nd4j_sgn(x[xOffset])); } - break; - } + } + } break; + } } ////////////////////////////////////////////////////////////////////////// template -static void pooling2dBPCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, - const void* vx, const Nd4jLong* xShapeInfo, - const void* vy, const Nd4jLong* yShapeInfo, - void* vz, const Nd4jLong* zShapeInfo, - const int kH, const int kW, - const int sH, const int sW, - const int pH, const int pW, - const int dH, const int dW, - const int poolingMode, const int extraParam0) { - - pooling2dBPCuda<<>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0); +static void pooling2dBPCudaLauncher( + const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, + const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo, + const void* vy, const Nd4jLong* yShapeInfo, void* vz, + const Nd4jLong* zShapeInfo, const int kH, const int kW, const int sH, + const int sW, const int pH, const int pW, const int dH, const int dW, + const int poolingMode, const int extraParam0) { + pooling2dBPCuda<<>>( + vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, kH, kW, sH, sW, pH, pW, + dH, dW, poolingMode, extraParam0); } ////////////////////////////////////////////////////////////////////////// -void ConvolutionUtils::pooling2dBP(sd::graph::Context& block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int poolingMode, const int extraParam0) { - - // initial zeroing of gradI - gradI.nullify(); - - PointersManager manager(block.launchContext(), "pooling2dBP"); - - const int threadsPerBlock = 256; - const int blocksPerGrid = (gradO.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - const int sharedMem = gradO.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128; - - NDArray::prepareSpecialUse({&gradI}, {&input, &gradO}); - BUILD_SINGLE_SELECTOR(input.dataType(), pooling2dBPCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, block.launchContext()->getCudaStream(), input.specialBuffer(), input.specialShapeInfo(), gradO.specialBuffer(), gradO.specialShapeInfo(), gradI.specialBuffer(), gradI.specialShapeInfo(), kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0), FLOAT_TYPES); - NDArray::registerSpecialUse({&gradI}, {&input, &gradO}); - - manager.synchronize(); +void ConvolutionUtils::pooling2dBP(sd::graph::Context& block, + const NDArray& input, const NDArray& gradO, + NDArray& gradI, const int kH, const int kW, + const int sH, const int sW, const int pH, + const int pW, const int dH, const int dW, + const int poolingMode, + const int extraParam0) { + // initial zeroing of gradI + gradI.nullify(); + + PointersManager manager(block.launchContext(), "pooling2dBP"); + + const int threadsPerBlock = 256; + const int blocksPerGrid = + (gradO.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + const int sharedMem = + gradO.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128; + + NDArray::prepareSpecialUse({&gradI}, {&input, &gradO}); + BUILD_SINGLE_SELECTOR(input.dataType(), pooling2dBPCudaLauncher, + (blocksPerGrid, threadsPerBlock, sharedMem, + block.launchContext()->getCudaStream(), + input.specialBuffer(), input.specialShapeInfo(), + gradO.specialBuffer(), gradO.specialShapeInfo(), + gradI.specialBuffer(), gradI.specialShapeInfo(), kH, + kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0), + FLOAT_TYPES); + NDArray::registerSpecialUse({&gradI}, {&input, &gradO}); + + manager.synchronize(); } -} -} \ No newline at end of file +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/convolutions_pooling3d.cu b/libnd4j/include/ops/declarable/helpers/cuda/convolutions_pooling3d.cu index 0a3bfc9b6119..5f69a30ba8c3 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/convolutions_pooling3d.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/convolutions_pooling3d.cu @@ -19,163 +19,177 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include #include #include +#include namespace sd { -namespace ops { - +namespace ops { ////////////////////////////////////////////////////////////////////////// template -__global__ static void pooling3dCuda(const void* vx, const Nd4jLong* xShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0) { - - // x input is [bS, iC, iD, iH, iW] - // z output is [bS, iC, oD, oH, oW] - - const T* x = reinterpret_cast(vx); - T* z = reinterpret_cast(vz); - - __shared__ int rank, kDeff, kHeff, kWeff, iD, iH, iW, kProd; - __shared__ Nd4jLong zLen, *sharedMem; - - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - sharedMem = reinterpret_cast(shmem); - - zLen = shape::length(zShapeInfo); - rank = 5; - - kDeff = kD + (kD - 1) * (dD - 1); - kHeff = kH + (kH - 1) * (dH - 1); - kWeff = kW + (kW - 1) * (dW - 1); - - iD = xShapeInfo[3]; - iH = xShapeInfo[4]; - iW = xShapeInfo[5]; - - kProd = kD * kH * kW; - } - __syncthreads(); - - const auto zInd = threadIdx.x + blockIdx.x * blockDim.x; - - if(zInd >= zLen) - return; - - auto coords = sharedMem + threadIdx.x * rank; - - shape::index2coords(zInd, zShapeInfo, coords); - - const auto zOffset = shape::getOffset(zShapeInfo, coords); - - int dstart = coords[2] * sD - pD; - int hstart = coords[3] * sH - pH; - int wstart = coords[4] * sW - pW; - int dend = dstart + kDeff; - int hend = hstart + kHeff; - int wend = wstart + kWeff; - - if(dstart < 0) - dstart += dD * ((-dstart + dD - 1) / dD); - if(hstart < 0) - hstart += dH * ((-hstart + dH - 1) / dH); - if(wstart < 0) - wstart += dW * ((-wstart + dW - 1) / dW); - if(dend > iD) - dend -= dD * ((dend - iD + dD - 1) / dD); - if(hend > iH) - hend -= dH * ((hend - iH + dH - 1) / dH); - if(wend > iW) - wend -= dW * ((wend - iW + dW - 1) / dW); - - - switch (poolingMode) { - - /*** max ***/ - case 0: { - T max = -DataTypeUtils::max(); - for (coords[2] = dstart; coords[2] < dend; coords[2] += dD) { - for (coords[3] = hstart; coords[3] < hend; coords[3] += dH){ - for (coords[4] = wstart; coords[4] < wend; coords[4] += dW) { - T val = x[shape::getOffset(xShapeInfo, coords)]; - if (val > max) - max = val; - } - } - } - z[zOffset] = max; - } - break; - - /*** avg ***/ - case 1: { - T sum = static_cast(0.); - for (coords[2] = dstart; coords[2] < dend; coords[2] += dD) - for (coords[3] = hstart; coords[3] < hend; coords[3] += dH) - for (coords[4] = wstart; coords[4] < wend; coords[4] += dW) - sum += x[shape::getOffset(xShapeInfo, coords)]; - - if (extraParam0 == 0) { //Exclude padding - uint a = (dend - dstart) / dD + ((dend - dstart) % dD == 0 ? 0 : 1); - uint b = (hend - hstart) / dH + ((hend - hstart) % dH == 0 ? 0 : 1); - uint c = (wend - wstart) / dW + ((wend - wstart) % dW == 0 ? 0 : 1); - sum /= static_cast(a * b * c); // /= sd::math::nd4j_ceil(static_cast(dend - dstart) / static_cast(dD)) * sd::math::nd4j_ceil(static_cast(hend - hstart) / static_cast(dH)) * sd::math::nd4j_ceil(static_cast(wend - wstart) / static_cast(dW)); //Accounts for dilation - } - else if (extraParam0 == 1) //Include padding - sum /= kProd; - - z[zOffset] = sum; - } - break; - - /*** pnorm ***/ - case 2: { - T sum = static_cast(0.); - for (coords[2] = dstart; coords[2] < dend; coords[2] += dD) - for (coords[3] = hstart; coords[3] < hend; coords[3] += dH) - for (coords[4] = wstart; coords[4] < wend; coords[4] += dW) - sum += sd::math::nd4j_pow(sd::math::nd4j_abs(x[shape::getOffset(xShapeInfo, coords)]), extraParam0); - - sum = sd::math::nd4j_pow(sum, (T) 1.f / extraParam0); - - z[zOffset] = sum; +__global__ static void pooling3dCuda(const void* vx, const Nd4jLong* xShapeInfo, + void* vz, const Nd4jLong* zShapeInfo, + const int kD, const int kH, const int kW, + const int sD, const int sH, const int sW, + const int pD, const int pH, const int pW, + const int dD, const int dH, const int dW, + const int poolingMode, + const int extraParam0) { + // x input is [bS, iC, iD, iH, iW] + // z output is [bS, iC, oD, oH, oW] + + const T* x = reinterpret_cast(vx); + T* z = reinterpret_cast(vz); + + __shared__ int rank, kDeff, kHeff, kWeff, iD, iH, iW, kProd; + __shared__ Nd4jLong zLen, *sharedMem; + + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sharedMem = reinterpret_cast(shmem); + + zLen = shape::length(zShapeInfo); + rank = 5; + + kDeff = kD + (kD - 1) * (dD - 1); + kHeff = kH + (kH - 1) * (dH - 1); + kWeff = kW + (kW - 1) * (dW - 1); + + iD = xShapeInfo[3]; + iH = xShapeInfo[4]; + iW = xShapeInfo[5]; + + kProd = kD * kH * kW; + } + __syncthreads(); + + const auto zInd = threadIdx.x + blockIdx.x * blockDim.x; + + if (zInd >= zLen) return; + + auto coords = sharedMem + threadIdx.x * rank; + + shape::index2coords(zInd, zShapeInfo, coords); + + const auto zOffset = shape::getOffset(zShapeInfo, coords); + + int dstart = coords[2] * sD - pD; + int hstart = coords[3] * sH - pH; + int wstart = coords[4] * sW - pW; + int dend = dstart + kDeff; + int hend = hstart + kHeff; + int wend = wstart + kWeff; + + if (dstart < 0) dstart += dD * ((-dstart + dD - 1) / dD); + if (hstart < 0) hstart += dH * ((-hstart + dH - 1) / dH); + if (wstart < 0) wstart += dW * ((-wstart + dW - 1) / dW); + if (dend > iD) dend -= dD * ((dend - iD + dD - 1) / dD); + if (hend > iH) hend -= dH * ((hend - iH + dH - 1) / dH); + if (wend > iW) wend -= dW * ((wend - iW + dW - 1) / dW); + + switch (poolingMode) { + /*** max ***/ + case 0: { + T max = -DataTypeUtils::max(); + for (coords[2] = dstart; coords[2] < dend; coords[2] += dD) { + for (coords[3] = hstart; coords[3] < hend; coords[3] += dH) { + for (coords[4] = wstart; coords[4] < wend; coords[4] += dW) { + T val = x[shape::getOffset(xShapeInfo, coords)]; + if (val > max) max = val; + } } - break; - } + } + z[zOffset] = max; + } break; + + /*** avg ***/ + case 1: { + T sum = static_cast(0.); + for (coords[2] = dstart; coords[2] < dend; coords[2] += dD) + for (coords[3] = hstart; coords[3] < hend; coords[3] += dH) + for (coords[4] = wstart; coords[4] < wend; coords[4] += dW) + sum += x[shape::getOffset(xShapeInfo, coords)]; + + if (extraParam0 == 0) { // Exclude padding + uint a = (dend - dstart) / dD + ((dend - dstart) % dD == 0 ? 0 : 1); + uint b = (hend - hstart) / dH + ((hend - hstart) % dH == 0 ? 0 : 1); + uint c = (wend - wstart) / dW + ((wend - wstart) % dW == 0 ? 0 : 1); + sum /= static_cast( + a * b * + c); // /= sd::math::nd4j_ceil(static_cast(dend - + // dstart) / static_cast(dD)) * + // sd::math::nd4j_ceil(static_cast(hend - + // hstart) / static_cast(dH)) * + // sd::math::nd4j_ceil(static_cast(wend - + // wstart) / static_cast(dW)); //Accounts for + // dilation + } else if (extraParam0 == 1) // Include padding + sum /= kProd; + + z[zOffset] = sum; + } break; + + /*** pnorm ***/ + case 2: { + T sum = static_cast(0.); + for (coords[2] = dstart; coords[2] < dend; coords[2] += dD) + for (coords[3] = hstart; coords[3] < hend; coords[3] += dH) + for (coords[4] = wstart; coords[4] < wend; coords[4] += dW) + sum += sd::math::nd4j_pow( + sd::math::nd4j_abs(x[shape::getOffset(xShapeInfo, coords)]), + extraParam0); + + sum = sd::math::nd4j_pow(sum, (T)1.f / extraParam0); + + z[zOffset] = sum; + } break; + } } ////////////////////////////////////////////////////////////////////////// template -static void pooling3dCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, - const void* vx, const Nd4jLong* xShapeInfo, - void* vz, const Nd4jLong* zShapeInfo, - const int kD, const int kH, const int kW, - const int sD, const int sH, const int sW, - const int pD, const int pH, const int pW, - const int dD, const int dH, const int dW, - const int poolingMode, const int extraParam0) { - - pooling3dCuda<<>>(vx, xShapeInfo, vz, zShapeInfo, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode, extraParam0); +static void pooling3dCudaLauncher( + const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, + const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo, + void* vz, const Nd4jLong* zShapeInfo, const int kD, const int kH, + const int kW, const int sD, const int sH, const int sW, const int pD, + const int pH, const int pW, const int dD, const int dH, const int dW, + const int poolingMode, const int extraParam0) { + pooling3dCuda<<>>( + vx, xShapeInfo, vz, zShapeInfo, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, + dH, dW, poolingMode, extraParam0); } ////////////////////////////////////////////////////////////////////////// -void ConvolutionUtils::pooling3d(sd::graph::Context& block, const NDArray& input, NDArray& output, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0) { - - PointersManager manager(block.launchContext(), "pooling3d"); - - const int threadsPerBlock = MAX_NUM_THREADS / 2; - const int blocksPerGrid = (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - const int sharedMem = output.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128; - - NDArray::prepareSpecialUse({&output}, {&input}); - BUILD_SINGLE_SELECTOR(input.dataType(), pooling3dCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, block.launchContext()->getCudaStream(), input.specialBuffer(), input.specialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode, extraParam0), FLOAT_TYPES); - NDArray::registerSpecialUse({&output}, {&input}); - - manager.synchronize(); +void ConvolutionUtils::pooling3d(sd::graph::Context& block, + const NDArray& input, NDArray& output, + const int kD, const int kH, const int kW, + const int sD, const int sH, const int sW, + const int pD, const int pH, const int pW, + const int dD, const int dH, const int dW, + const int poolingMode, const int extraParam0) { + PointersManager manager(block.launchContext(), "pooling3d"); + + const int threadsPerBlock = MAX_NUM_THREADS / 2; + const int blocksPerGrid = + (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + const int sharedMem = + output.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128; + + NDArray::prepareSpecialUse({&output}, {&input}); + BUILD_SINGLE_SELECTOR( + input.dataType(), pooling3dCudaLauncher, + (blocksPerGrid, threadsPerBlock, sharedMem, + block.launchContext()->getCudaStream(), input.specialBuffer(), + input.specialShapeInfo(), output.specialBuffer(), + output.specialShapeInfo(), kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, + dW, poolingMode, extraParam0), + FLOAT_TYPES); + NDArray::registerSpecialUse({&output}, {&input}); + + manager.synchronize(); } - -} -} +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/convolutions_pooling3dBP.cu b/libnd4j/include/ops/declarable/helpers/cuda/convolutions_pooling3dBP.cu index fd78bb80beb0..70fa1a43e0a0 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/convolutions_pooling3dBP.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/convolutions_pooling3dBP.cu @@ -19,184 +19,205 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include #include #include +#include namespace sd { -namespace ops { +namespace ops { ////////////////////////////////////////////////////////////////////////// template -__global__ static void pooling3dBPCuda(const void* vx, const Nd4jLong* xShapeInfo, const void* vy, const Nd4jLong* yShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0) { - - // x: input [bS, iC, iD, iH, iW] - // y: gradO [bS, iC, oD, oH, oW] - // z: gradI [bS, iC, iD, iH, iW] -> gradI is output in this function - - - const T* x = reinterpret_cast(vx); - const T* y = reinterpret_cast(vy); - T* z = reinterpret_cast(vz); - - Nd4jLong coord2, coord3, coord4; - __shared__ int rank, kDeff, kHeff, kWeff, iD, iH, iW, kProd; - __shared__ Nd4jLong yLen, *sharedMem; - - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - sharedMem = reinterpret_cast(shmem); - - yLen = shape::length(yShapeInfo); - rank = 5; - - kDeff = kD + (kD - 1) * (dD - 1); - kHeff = kH + (kH - 1) * (dH - 1); - kWeff = kW + (kW - 1) * (dW - 1); - - iD = xShapeInfo[3]; - iH = xShapeInfo[4]; - iW = xShapeInfo[5]; - - kProd = kD * kH * kW; - } - __syncthreads(); - - const auto yInd = threadIdx.x + blockIdx.x * blockDim.x; - - if(yInd >= yLen) - return; - - auto coords = sharedMem + threadIdx.x * rank; - - shape::index2coords(yInd, yShapeInfo, coords); - - const auto yOffset = shape::getOffset(yShapeInfo, coords); - - int dstart = coords[2] * sD - pD; - int hstart = coords[3] * sH - pH; - int wstart = coords[4] * sW - pW; - int dend = dstart + kDeff; - int hend = hstart + kHeff; - int wend = wstart + kWeff; - - if(dstart < 0) - dstart += dD * ((-dstart + dD - 1) / dD); - if(hstart < 0) - hstart += dH * ((-hstart + dH - 1) / dH); - if(wstart < 0) - wstart += dW * ((-wstart + dW - 1) / dW); - if(dend > iD) - dend -= dD * ((dend - iD + dD - 1) / dD); - if(hend > iH) - hend -= dH * ((hend - iH + dH - 1) / dH); - if(wend > iW) - wend -= dW * ((wend - iW + dW - 1) / dW); - - - switch (poolingMode) { - - /*** max ***/ - case 0: { - - T max = -DataTypeUtils::max(); - for (coords[2] = dstart; coords[2] < dend; coords[2] += dD) { - for (coords[3] = hstart; coords[3] < hend; coords[3] += dH){ - for (coords[4] = wstart; coords[4] < wend; coords[4] += dW) { - T val = x[shape::getOffset(xShapeInfo, coords)]; - if (val > max) { - max = val; - coord2 = coords[2]; - coord3 = coords[3]; - coord4 = coords[4]; - } - } - } +__global__ static void pooling3dBPCuda( + const void* vx, const Nd4jLong* xShapeInfo, const void* vy, + const Nd4jLong* yShapeInfo, void* vz, const Nd4jLong* zShapeInfo, + const int kD, const int kH, const int kW, const int sD, const int sH, + const int sW, const int pD, const int pH, const int pW, const int dD, + const int dH, const int dW, const int poolingMode, const int extraParam0) { + // x: input [bS, iC, iD, iH, iW] + // y: gradO [bS, iC, oD, oH, oW] + // z: gradI [bS, iC, iD, iH, iW] -> gradI is output in this function + + const T* x = reinterpret_cast(vx); + const T* y = reinterpret_cast(vy); + T* z = reinterpret_cast(vz); + + Nd4jLong coord2, coord3, coord4; + __shared__ int rank, kDeff, kHeff, kWeff, iD, iH, iW, kProd; + __shared__ Nd4jLong yLen, *sharedMem; + + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sharedMem = reinterpret_cast(shmem); + + yLen = shape::length(yShapeInfo); + rank = 5; + + kDeff = kD + (kD - 1) * (dD - 1); + kHeff = kH + (kH - 1) * (dH - 1); + kWeff = kW + (kW - 1) * (dW - 1); + + iD = xShapeInfo[3]; + iH = xShapeInfo[4]; + iW = xShapeInfo[5]; + + kProd = kD * kH * kW; + } + __syncthreads(); + + const auto yInd = threadIdx.x + blockIdx.x * blockDim.x; + + if (yInd >= yLen) return; + + auto coords = sharedMem + threadIdx.x * rank; + + shape::index2coords(yInd, yShapeInfo, coords); + + const auto yOffset = shape::getOffset(yShapeInfo, coords); + + int dstart = coords[2] * sD - pD; + int hstart = coords[3] * sH - pH; + int wstart = coords[4] * sW - pW; + int dend = dstart + kDeff; + int hend = hstart + kHeff; + int wend = wstart + kWeff; + + if (dstart < 0) dstart += dD * ((-dstart + dD - 1) / dD); + if (hstart < 0) hstart += dH * ((-hstart + dH - 1) / dH); + if (wstart < 0) wstart += dW * ((-wstart + dW - 1) / dW); + if (dend > iD) dend -= dD * ((dend - iD + dD - 1) / dD); + if (hend > iH) hend -= dH * ((hend - iH + dH - 1) / dH); + if (wend > iW) wend -= dW * ((wend - iW + dW - 1) / dW); + + switch (poolingMode) { + /*** max ***/ + case 0: { + T max = -DataTypeUtils::max(); + for (coords[2] = dstart; coords[2] < dend; coords[2] += dD) { + for (coords[3] = hstart; coords[3] < hend; coords[3] += dH) { + for (coords[4] = wstart; coords[4] < wend; coords[4] += dW) { + T val = x[shape::getOffset(xShapeInfo, coords)]; + if (val > max) { + max = val; + coord2 = coords[2]; + coord3 = coords[3]; + coord4 = coords[4]; } - coords[2] = coord2; - coords[3] = coord3; - coords[4] = coord4; - sd::math::atomics::nd4j_atomicAdd(&z[shape::getOffset(zShapeInfo, coords)], y[yOffset]); - } - break; - - /*** avg ***/ - case 1: { - - T val = y[yOffset]; - - if (extraParam0 == 0) //Exclude padding - val /= sd::math::nd4j_ceil(static_cast(dend - dstart) / static_cast(dD)) * sd::math::nd4j_ceil(static_cast(hend - hstart) / static_cast(dH)) * sd::math::nd4j_ceil(static_cast(wend - wstart) / static_cast(dW)); //Accounts for dilation - else if (extraParam0 == 1) //Include padding - val /= kProd; - - for (coords[2] = dstart; coords[2] < dend; coords[2] += dD) - for (coords[3] = hstart; coords[3] < hend; coords[3] += dH) - for (coords[4] = wstart; coords[4] < wend; coords[4] += dW) - sd::math::atomics::nd4j_atomicAdd(&z[shape::getOffset(zShapeInfo, coords)], val); + } } - break; - - /*** pnorm ***/ - case 2: { - - T sum = static_cast(0.); - T val = y[yOffset]; - - for (coords[2] = dstart; coords[2] < dend; coords[2] += dD) - for (coords[3] = hstart; coords[3] < hend; coords[3] += dH) - for (coords[4] = wstart; coords[4] < wend; coords[4] += dW) - sum += sd::math::nd4j_pow(sd::math::nd4j_abs(x[shape::getOffset(xShapeInfo, coords)]), extraParam0); - - val *= sd::math::nd4j_pow(sum, ((T)1.f - extraParam0) / extraParam0); - - for (coords[2] = dstart; coords[2] < dend; coords[2] += dD) { - for (coords[3] = hstart; coords[3] < hend; coords[3] += dH) { - for (coords[4] = wstart; coords[4] < wend; coords[4] += dW) { - const auto xOffset = shape::getOffset(xShapeInfo, coords); - const auto zOffset = shape::getOffset(zShapeInfo, coords); - sd::math::atomics::nd4j_atomicAdd(&z[zOffset], val * sd::math::nd4j_pow(sd::math::nd4j_abs(x[xOffset]), extraParam0 - 1.f) * sd::math::nd4j_sgn(x[xOffset])); - } - } - } + } + coords[2] = coord2; + coords[3] = coord3; + coords[4] = coord4; + sd::math::atomics::nd4j_atomicAdd( + &z[shape::getOffset(zShapeInfo, coords)], y[yOffset]); + } break; + + /*** avg ***/ + case 1: { + T val = y[yOffset]; + + if (extraParam0 == 0) // Exclude padding + val /= + sd::math::nd4j_ceil(static_cast(dend - dstart) / + static_cast(dD)) * + sd::math::nd4j_ceil(static_cast(hend - hstart) / + static_cast(dH)) * + sd::math::nd4j_ceil( + static_cast(wend - wstart) / + static_cast(dW)); // Accounts for dilation + else if (extraParam0 == 1) // Include padding + val /= kProd; + + for (coords[2] = dstart; coords[2] < dend; coords[2] += dD) + for (coords[3] = hstart; coords[3] < hend; coords[3] += dH) + for (coords[4] = wstart; coords[4] < wend; coords[4] += dW) + sd::math::atomics::nd4j_atomicAdd( + &z[shape::getOffset(zShapeInfo, coords)], val); + } break; + + /*** pnorm ***/ + case 2: { + T sum = static_cast(0.); + T val = y[yOffset]; + + for (coords[2] = dstart; coords[2] < dend; coords[2] += dD) + for (coords[3] = hstart; coords[3] < hend; coords[3] += dH) + for (coords[4] = wstart; coords[4] < wend; coords[4] += dW) + sum += sd::math::nd4j_pow( + sd::math::nd4j_abs(x[shape::getOffset(xShapeInfo, coords)]), + extraParam0); + + val *= sd::math::nd4j_pow(sum, + ((T)1.f - extraParam0) / extraParam0); + + for (coords[2] = dstart; coords[2] < dend; coords[2] += dD) { + for (coords[3] = hstart; coords[3] < hend; coords[3] += dH) { + for (coords[4] = wstart; coords[4] < wend; coords[4] += dW) { + const auto xOffset = shape::getOffset(xShapeInfo, coords); + const auto zOffset = shape::getOffset(zShapeInfo, coords); + sd::math::atomics::nd4j_atomicAdd( + &z[zOffset], + val * + sd::math::nd4j_pow( + sd::math::nd4j_abs(x[xOffset]), extraParam0 - 1.f) * + sd::math::nd4j_sgn(x[xOffset])); + } } - break; - } + } + } break; + } } ////////////////////////////////////////////////////////////////////////// template -static void pooling3dBPCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, - const void* vx, const Nd4jLong* xShapeInfo, - const void* vy, const Nd4jLong* yShapeInfo, - void* vz, const Nd4jLong* zShapeInfo, - const int kD, const int kH, const int kW, - const int sD, const int sH, const int sW, - const int pD, const int pH, const int pW, - const int dD, const int dH, const int dW, - const int poolingMode, const int extraParam0) { - - pooling3dBPCuda<<>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode, extraParam0); +static void pooling3dBPCudaLauncher( + const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, + const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo, + const void* vy, const Nd4jLong* yShapeInfo, void* vz, + const Nd4jLong* zShapeInfo, const int kD, const int kH, const int kW, + const int sD, const int sH, const int sW, const int pD, const int pH, + const int pW, const int dD, const int dH, const int dW, + const int poolingMode, const int extraParam0) { + pooling3dBPCuda<<>>( + vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, kD, kH, kW, sD, sH, sW, + pD, pH, pW, dD, dH, dW, poolingMode, extraParam0); } ////////////////////////////////////////////////////////////////////////// -void ConvolutionUtils::pooling3dBP(sd::graph::Context& block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0) { - - // initial zeroing of gradI - gradI.nullify(); - - PointersManager manager(block.launchContext(), "pooling3dBP"); - - const int threadsPerBlock = MAX_NUM_THREADS / 2; - const int blocksPerGrid = (gradO.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - const int sharedMem = gradO.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128; - - NDArray::prepareSpecialUse({&gradI}, {&input, &gradO}); - BUILD_SINGLE_SELECTOR(input.dataType(), pooling3dBPCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, block.launchContext()->getCudaStream(), input.specialBuffer(), input.specialShapeInfo(), gradO.specialBuffer(), gradO.specialShapeInfo(), gradI.specialBuffer(), gradI.specialShapeInfo(), kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode, extraParam0), FLOAT_TYPES); - NDArray::registerSpecialUse({&gradI}, {&input, &gradO}); - - manager.synchronize(); +void ConvolutionUtils::pooling3dBP(sd::graph::Context& block, + const NDArray& input, const NDArray& gradO, + NDArray& gradI, const int kD, const int kH, + const int kW, const int sD, const int sH, + const int sW, const int pD, const int pH, + const int pW, const int dD, const int dH, + const int dW, const int poolingMode, + const int extraParam0) { + // initial zeroing of gradI + gradI.nullify(); + + PointersManager manager(block.launchContext(), "pooling3dBP"); + + const int threadsPerBlock = MAX_NUM_THREADS / 2; + const int blocksPerGrid = + (gradO.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + const int sharedMem = + gradO.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128; + + NDArray::prepareSpecialUse({&gradI}, {&input, &gradO}); + BUILD_SINGLE_SELECTOR( + input.dataType(), pooling3dBPCudaLauncher, + (blocksPerGrid, threadsPerBlock, sharedMem, + block.launchContext()->getCudaStream(), input.specialBuffer(), + input.specialShapeInfo(), gradO.specialBuffer(), + gradO.specialShapeInfo(), gradI.specialBuffer(), + gradI.specialShapeInfo(), kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, + poolingMode, extraParam0), + FLOAT_TYPES); + NDArray::registerSpecialUse({&gradI}, {&input, &gradO}); + + manager.synchronize(); } -} -} +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/convolutions_sconv2d.cu b/libnd4j/include/ops/declarable/helpers/cuda/convolutions_sconv2d.cu index 3a9ed5364071..cd22dc728343 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/convolutions_sconv2d.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/convolutions_sconv2d.cu @@ -22,52 +22,81 @@ #include namespace sd { -namespace ops { +namespace ops { ////////////////////////////////////////////////////////////////////////// template -static void sconv2d_(sd::graph::Context& block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { +static void sconv2d_(sd::graph::Context& block, const NDArray* input, + const NDArray* weightsDepth, const NDArray* weightsPoint, + const NDArray* bias, NDArray* output, const int kH, + const int kW, const int sH, const int sW, int pH, int pW, + const int dH, const int dW, const int paddingMode, + const int isNCHW, const int wFormat) { + // input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + // weightsDepth [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] + // weightsPoint [1, 1, iC*mC, oC], [oC, iC*mC, 1, 1], [oC, 1, 1, iC*mC] + // bias [oC], oC = iC*mC if weightsPoint=nullptr + // output is [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) - // input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - // weightsDepth [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] - // weightsPoint [1, 1, iC*mC, oC], [oC, iC*mC, 1, 1], [oC, 1, 1, iC*mC] - // bias [oC], oC = iC*mC if weightsPoint=nullptr - // output is [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) + // kH filter(kernel) height + // kW filter(kernel) width + // sH strides height + // sW strides width + // pH paddings height + // pW paddings width + // dH dilations height + // dW dilations width + // paddingMode 0-VALID, 1-SAME + // isNCHW 1-NCHW, 0-NHWC - // kH filter(kernel) height - // kW filter(kernel) width - // sH strides height - // sW strides width - // pH paddings height - // pW paddings width - // dH dilations height - // dW dilations width - // paddingMode 0-VALID, 1-SAME - // isNCHW 1-NCHW, 0-NHWC + int bS, iC, iH, iW, mC, oC, oH, + oW; // batch size, input channels, input height/width, channels + // multiplier, output channels, output height/width + int indIOioC, indIiH, indWmC, indWiC, indWkH, + indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, + indIiH, indWiC, indWmC, indWkH, indOoH); + mC = weightsDepth->sizeAt(indWmC); // channels multiplier - int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier, output channels, output height/width - int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); - mC = weightsDepth->sizeAt(indWmC); // channels multiplier + NDArray* outputDepth = output; + if (weightsPoint) // if pointwise convolution is expected + outputDepth = + new NDArray(output->ordering(), + !isNCHW ? std::vector({bS, oH, oW, iC * mC}) + : std::vector({bS, iC * mC, oH, oW}), + input->dataType(), input->getContext()); - NDArray* outputDepth = output; - if(weightsPoint) // if pointwise convolution is expected - outputDepth = new NDArray(output->ordering(), !isNCHW ? std::vector({bS, oH, oW, iC*mC}) : std::vector({bS, iC*mC, oH, oW}), input->dataType(), input->getContext()); + // ----- perform depthwise convolution (if weightsPoint is absent then oC = + // iC*mC) ----- // + ConvolutionUtils::depthwiseConv2d( + block, input, weightsDepth, weightsPoint ? nullptr : bias, outputDepth, + kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat); - // ----- perform depthwise convolution (if weightsPoint is absent then oC = iC*mC) ----- // - ConvolutionUtils::depthwiseConv2d(block, input, weightsDepth, weightsPoint ? nullptr : bias, outputDepth, kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, isNCHW, wFormat); - - // ----- perform pointwise convolution (oH = iH, oW = iW) ----- // - if (weightsPoint) { - ConvolutionUtils::conv2d(block, outputDepth, weightsPoint, bias, output, 1,1, 1,1, 0,0, 1,1, paddingMode, isNCHW, wFormat); // in this case oH=iH, oW=iW - delete outputDepth; - } + // ----- perform pointwise convolution (oH = iH, oW = iW) ----- // + if (weightsPoint) { + ConvolutionUtils::conv2d(block, outputDepth, weightsPoint, bias, output, 1, + 1, 1, 1, 0, 0, 1, 1, paddingMode, isNCHW, + wFormat); // in this case oH=iH, oW=iW + delete outputDepth; + } } ////////////////////////////////////////////////////////////////////////// -void ConvolutionUtils::sconv2d(sd::graph::Context& block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { - BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), sconv2d_, (block, input, weightsDepth, weightsPoint, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat), FLOAT_TYPES); +void ConvolutionUtils::sconv2d(sd::graph::Context& block, const NDArray* input, + const NDArray* weightsDepth, + const NDArray* weightsPoint, const NDArray* bias, + NDArray* output, const int kH, const int kW, + const int sH, const int sW, int pH, int pW, + const int dH, const int dW, + const int paddingMode, const int isNCHW, + const int wFormat) { + BUILD_SINGLE_SELECTOR_TWICE( + input->dataType(), sconv2d_, + (block, input, weightsDepth, weightsPoint, bias, output, kH, kW, sH, sW, + pH, pW, dH, dW, paddingMode, isNCHW, wFormat), + FLOAT_TYPES); } -} -} +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/convolutions_upsampling2d.cu b/libnd4j/include/ops/declarable/helpers/cuda/convolutions_upsampling2d.cu index ee1fa892416f..988e1fc310e6 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/convolutions_upsampling2d.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/convolutions_upsampling2d.cu @@ -19,79 +19,93 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include #include +#include namespace sd { -namespace ops { +namespace ops { ////////////////////////////////////////////////////////////////////////// template -__global__ static void upsampling2dCuda(const void* vx, const Nd4jLong* xShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const int factorH, const int factorW, const bool isNCHW) { - - // x has shape [bS, iC, iH, iW] (NCHW) or [bS, iH, iW, iC] (NHWC) - // z has shape [bS, iC, factorH*iH, factorW*iW ] (NCHW) or [bS, factorH*iH, factorW*iW, iC] (NHWC) +__global__ static void upsampling2dCuda(const void* vx, + const Nd4jLong* xShapeInfo, void* vz, + const Nd4jLong* zShapeInfo, + const int factorH, const int factorW, + const bool isNCHW) { + // x has shape [bS, iC, iH, iW] (NCHW) or [bS, iH, iW, iC] (NHWC) + // z has shape [bS, iC, factorH*iH, factorW*iW ] (NCHW) or [bS, factorH*iH, + // factorW*iW, iC] (NHWC) - const T* x = reinterpret_cast(vx); - T* z = reinterpret_cast(vz); + const T* x = reinterpret_cast(vx); + T* z = reinterpret_cast(vz); - __shared__ int rank, dimIH; - __shared__ Nd4jLong zLen, *sharedMem; + __shared__ int rank, dimIH; + __shared__ Nd4jLong zLen, *sharedMem; - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - sharedMem = reinterpret_cast(shmem); + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sharedMem = reinterpret_cast(shmem); - dimIH = isNCHW ? 2 : 1; - zLen = shape::length(zShapeInfo); - rank = 4; - } - __syncthreads(); + dimIH = isNCHW ? 2 : 1; + zLen = shape::length(zShapeInfo); + rank = 4; + } + __syncthreads(); - const auto zInd = threadIdx.x + blockIdx.x * blockDim.x; + const auto zInd = threadIdx.x + blockIdx.x * blockDim.x; - if(zInd >= zLen) - return; + if (zInd >= zLen) return; - auto coords = sharedMem + threadIdx.x * rank; + auto coords = sharedMem + threadIdx.x * rank; - shape::index2coords(zInd, zShapeInfo, coords); + shape::index2coords(zInd, zShapeInfo, coords); - const auto zOffset = shape::getOffset(zShapeInfo, coords); + const auto zOffset = shape::getOffset(zShapeInfo, coords); - coords[dimIH] /= factorH; - coords[dimIH + 1] /= factorW; + coords[dimIH] /= factorH; + coords[dimIH + 1] /= factorW; - const auto xOffset = shape::getOffset(xShapeInfo, coords); + const auto xOffset = shape::getOffset(xShapeInfo, coords); - z[zOffset] = x[xOffset]; + z[zOffset] = x[xOffset]; } ////////////////////////////////////////////////////////////////////////// template -static void upsampling2dCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, - const void* vx, const Nd4jLong* xShapeInfo, - void* vz, const Nd4jLong* zShapeInfo, - const int factorH, const int factorW, const bool isNCHW) { - - upsampling2dCuda<<>>(vx, xShapeInfo, vz, zShapeInfo, factorH, factorW, isNCHW); +static void upsampling2dCudaLauncher( + const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, + const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo, + void* vz, const Nd4jLong* zShapeInfo, const int factorH, const int factorW, + const bool isNCHW) { + upsampling2dCuda<<>>( + vx, xShapeInfo, vz, zShapeInfo, factorH, factorW, isNCHW); } ////////////////////////////////////////////////////////////////////////// -void ConvolutionUtils::upsampling2d(sd::graph::Context& block, const NDArray& input, NDArray& output, const int factorH, const int factorW, const bool isNCHW) { - - PointersManager manager(block.launchContext(), "upsampling2d"); - - const int threadsPerBlock = MAX_NUM_THREADS / 2; - const int blocksPerGrid = (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - const int sharedMem = output.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128; - - NDArray::prepareSpecialUse({&output}, {&input}); - BUILD_SINGLE_SELECTOR(input.dataType(), upsampling2dCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, block.launchContext()->getCudaStream(), input.specialBuffer(), input.specialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), factorH, factorW, isNCHW), FLOAT_TYPES); - NDArray::registerSpecialUse({&output}, {&input}); - - manager.synchronize(); +void ConvolutionUtils::upsampling2d(sd::graph::Context& block, + const NDArray& input, NDArray& output, + const int factorH, const int factorW, + const bool isNCHW) { + PointersManager manager(block.launchContext(), "upsampling2d"); + + const int threadsPerBlock = MAX_NUM_THREADS / 2; + const int blocksPerGrid = + (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + const int sharedMem = + output.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128; + + NDArray::prepareSpecialUse({&output}, {&input}); + BUILD_SINGLE_SELECTOR( + input.dataType(), upsampling2dCudaLauncher, + (blocksPerGrid, threadsPerBlock, sharedMem, + block.launchContext()->getCudaStream(), input.specialBuffer(), + input.specialShapeInfo(), output.specialBuffer(), + output.specialShapeInfo(), factorH, factorW, isNCHW), + FLOAT_TYPES); + NDArray::registerSpecialUse({&output}, {&input}); + + manager.synchronize(); } -} -} +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/convolutions_upsampling2dBP.cu b/libnd4j/include/ops/declarable/helpers/cuda/convolutions_upsampling2dBP.cu index c6864c48a109..519380c2e35e 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/convolutions_upsampling2dBP.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/convolutions_upsampling2dBP.cu @@ -19,85 +19,98 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include #include +#include namespace sd { -namespace ops { +namespace ops { ////////////////////////////////////////////////////////////////////////// template -__global__ static void upsampling2dBPCuda(const void* vx, const Nd4jLong* xShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const bool isNCHW) { - - // x (gradO) has shape [bS, iC, factorH*iH, factorW*iW ] (NCHW) or [bS, factorH*iH, factorW*iW, iC] (NHWC) - // z (gradI) has shape [bS, iC, iH, iW] (NCHW) or [bS, iH, iW, iC] (NHWC) +__global__ static void upsampling2dBPCuda(const void* vx, + const Nd4jLong* xShapeInfo, void* vz, + const Nd4jLong* zShapeInfo, + const bool isNCHW) { + // x (gradO) has shape [bS, iC, factorH*iH, factorW*iW ] (NCHW) or [bS, + // factorH*iH, factorW*iW, iC] (NHWC) z (gradI) has shape [bS, iC, iH, iW] + // (NCHW) or [bS, iH, iW, iC] (NHWC) - const T* x = reinterpret_cast(vx); - T* z = reinterpret_cast(vz); + const T* x = reinterpret_cast(vx); + T* z = reinterpret_cast(vz); - __shared__ int rank, dimIH; - __shared__ uint factorH, factorW; - __shared__ Nd4jLong zLen, *sharedMem; + __shared__ int rank, dimIH; + __shared__ uint factorH, factorW; + __shared__ Nd4jLong zLen, *sharedMem; - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - sharedMem = reinterpret_cast(shmem); + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sharedMem = reinterpret_cast(shmem); - dimIH = isNCHW ? 2 : 1; - zLen = shape::length(zShapeInfo); - rank = 4; + dimIH = isNCHW ? 2 : 1; + zLen = shape::length(zShapeInfo); + rank = 4; - factorH = xShapeInfo[dimIH + 1] / zShapeInfo[dimIH + 1]; - factorW = xShapeInfo[dimIH + 2] / zShapeInfo[dimIH + 2]; - } - __syncthreads(); + factorH = xShapeInfo[dimIH + 1] / zShapeInfo[dimIH + 1]; + factorW = xShapeInfo[dimIH + 2] / zShapeInfo[dimIH + 2]; + } + __syncthreads(); - const auto zInd = threadIdx.x + blockIdx.x * blockDim.x; + const auto zInd = threadIdx.x + blockIdx.x * blockDim.x; - if(zInd >= zLen) - return; + if (zInd >= zLen) return; - auto coords = sharedMem + threadIdx.x * rank; + auto coords = sharedMem + threadIdx.x * rank; - shape::index2coords(zInd, zShapeInfo, coords); + shape::index2coords(zInd, zShapeInfo, coords); - const auto zOffset = shape::getOffset(zShapeInfo, coords); + const auto zOffset = shape::getOffset(zShapeInfo, coords); - z[zOffset] = 0; + z[zOffset] = 0; - const Nd4jLong zCoord2 = coords[dimIH] * factorH; - const Nd4jLong zCoord3 = coords[dimIH + 1] * factorW; + const Nd4jLong zCoord2 = coords[dimIH] * factorH; + const Nd4jLong zCoord3 = coords[dimIH + 1] * factorW; - for(coords[dimIH] = zCoord2; coords[dimIH] < zCoord2 + factorH; ++coords[dimIH]) - for(coords[dimIH + 1] = zCoord3; coords[dimIH + 1] < zCoord3 + factorW; ++coords[dimIH + 1]) - z[zOffset] += x[shape::getOffset(xShapeInfo, coords)]; + for (coords[dimIH] = zCoord2; coords[dimIH] < zCoord2 + factorH; + ++coords[dimIH]) + for (coords[dimIH + 1] = zCoord3; coords[dimIH + 1] < zCoord3 + factorW; + ++coords[dimIH + 1]) + z[zOffset] += x[shape::getOffset(xShapeInfo, coords)]; } ////////////////////////////////////////////////////////////////////////// template -static void upsampling2dBPCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, - const void* vx, const Nd4jLong* xShapeInfo, - void* vz, const Nd4jLong* zShapeInfo, - const bool isNCHW) { - - upsampling2dBPCuda<<>>(vx, xShapeInfo, vz, zShapeInfo, isNCHW); +static void upsampling2dBPCudaLauncher( + const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, + const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo, + void* vz, const Nd4jLong* zShapeInfo, const bool isNCHW) { + upsampling2dBPCuda<<>>( + vx, xShapeInfo, vz, zShapeInfo, isNCHW); } ////////////////////////////////////////////////////////////////////////// -void ConvolutionUtils::upsampling2dBP(sd::graph::Context& block, const NDArray& gradO, NDArray& gradI, const bool isNCHW) { - - PointersManager manager(block.launchContext(), "upsampling2d_bp"); - - const int threadsPerBlock = MAX_NUM_THREADS / 2; - const int blocksPerGrid = (gradI.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - const int sharedMem = gradI.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128; - - NDArray::prepareSpecialUse({&gradI}, {&gradO}); - BUILD_SINGLE_SELECTOR(gradI.dataType(), upsampling2dBPCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, block.launchContext()->getCudaStream(), gradO.specialBuffer(), gradO.specialShapeInfo(), gradI.specialBuffer(), gradI.specialShapeInfo(), isNCHW), FLOAT_TYPES); - NDArray::registerSpecialUse({&gradI}, {&gradO}); - - manager.synchronize(); +void ConvolutionUtils::upsampling2dBP(sd::graph::Context& block, + const NDArray& gradO, NDArray& gradI, + const bool isNCHW) { + PointersManager manager(block.launchContext(), "upsampling2d_bp"); + + const int threadsPerBlock = MAX_NUM_THREADS / 2; + const int blocksPerGrid = + (gradI.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + const int sharedMem = + gradI.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128; + + NDArray::prepareSpecialUse({&gradI}, {&gradO}); + BUILD_SINGLE_SELECTOR( + gradI.dataType(), upsampling2dBPCudaLauncher, + (blocksPerGrid, threadsPerBlock, sharedMem, + block.launchContext()->getCudaStream(), gradO.specialBuffer(), + gradO.specialShapeInfo(), gradI.specialBuffer(), + gradI.specialShapeInfo(), isNCHW), + FLOAT_TYPES); + NDArray::registerSpecialUse({&gradI}, {&gradO}); + + manager.synchronize(); } -} -} \ No newline at end of file +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/convolutions_upsampling3d.cu b/libnd4j/include/ops/declarable/helpers/cuda/convolutions_upsampling3d.cu index 1acb4307f3d4..f774c7475edf 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/convolutions_upsampling3d.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/convolutions_upsampling3d.cu @@ -19,80 +19,94 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include #include +#include namespace sd { -namespace ops { +namespace ops { ////////////////////////////////////////////////////////////////////////// template -__global__ static void upsampling3dCuda(const void* vx, const Nd4jLong* xShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const int factorD, const int factorH, const int factorW, const bool isNCDHW) { - - // x has shape [bS, iC, iD, iH, iW] (NCDHW) or [bS, iD, iH, iW, iC] (NDHWC) - // z has shape [bS, iC, factorD*iD, factorH*iH, factorW*iW ] (NCDHW) or [bS, factorD*iD, factorH*iH, factorW*iW, iC] (NDHWC) +__global__ static void upsampling3dCuda(const void* vx, + const Nd4jLong* xShapeInfo, void* vz, + const Nd4jLong* zShapeInfo, + const int factorD, const int factorH, + const int factorW, const bool isNCDHW) { + // x has shape [bS, iC, iD, iH, iW] (NCDHW) or [bS, iD, iH, iW, iC] (NDHWC) + // z has shape [bS, iC, factorD*iD, factorH*iH, factorW*iW ] (NCDHW) or [bS, + // factorD*iD, factorH*iH, factorW*iW, iC] (NDHWC) - const T* x = reinterpret_cast(vx); - T* z = reinterpret_cast(vz); + const T* x = reinterpret_cast(vx); + T* z = reinterpret_cast(vz); - __shared__ int rank, dimID; - __shared__ Nd4jLong zLen, *sharedMem; + __shared__ int rank, dimID; + __shared__ Nd4jLong zLen, *sharedMem; - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - sharedMem = reinterpret_cast(shmem); + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sharedMem = reinterpret_cast(shmem); - dimID = isNCDHW ? 2 : 1; - zLen = shape::length(zShapeInfo); - rank = 5; - } - __syncthreads(); + dimID = isNCDHW ? 2 : 1; + zLen = shape::length(zShapeInfo); + rank = 5; + } + __syncthreads(); - const auto zInd = threadIdx.x + blockIdx.x * blockDim.x; + const auto zInd = threadIdx.x + blockIdx.x * blockDim.x; - if(zInd >= zLen) - return; + if (zInd >= zLen) return; - auto coords = sharedMem + threadIdx.x * rank; + auto coords = sharedMem + threadIdx.x * rank; - shape::index2coords(zInd, zShapeInfo, coords); + shape::index2coords(zInd, zShapeInfo, coords); - const auto zOffset = shape::getOffset(zShapeInfo, coords); + const auto zOffset = shape::getOffset(zShapeInfo, coords); - coords[dimID] /= factorD; - coords[dimID + 1] /= factorH; - coords[dimID + 2] /= factorW; + coords[dimID] /= factorD; + coords[dimID + 1] /= factorH; + coords[dimID + 2] /= factorW; - const auto xOffset = shape::getOffset(xShapeInfo, coords); + const auto xOffset = shape::getOffset(xShapeInfo, coords); - z[zOffset] = x[xOffset]; + z[zOffset] = x[xOffset]; } ////////////////////////////////////////////////////////////////////////// template -static void upsampling3dCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, - const void* vx, const Nd4jLong* xShapeInfo, - void* vz, const Nd4jLong* zShapeInfo, - const int factorD, const int factorH, const int factorW, const bool isNCDHW) { - - upsampling3dCuda<<>>(vx, xShapeInfo, vz, zShapeInfo, factorD, factorH, factorW, isNCDHW); +static void upsampling3dCudaLauncher( + const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, + const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo, + void* vz, const Nd4jLong* zShapeInfo, const int factorD, const int factorH, + const int factorW, const bool isNCDHW) { + upsampling3dCuda<<>>( + vx, xShapeInfo, vz, zShapeInfo, factorD, factorH, factorW, isNCDHW); } ////////////////////////////////////////////////////////////////////////// -void ConvolutionUtils::upsampling3d(sd::graph::Context& block, const NDArray& input, NDArray& output, const int factorD, const int factorH, const int factorW, const bool isNCDHW) { - - PointersManager manager(block.launchContext(), "upsampling3d"); - - const int threadsPerBlock = MAX_NUM_THREADS / 2; - const int blocksPerGrid = (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - const int sharedMem = output.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128; - - NDArray::prepareSpecialUse({&output}, {&input}); - BUILD_SINGLE_SELECTOR(input.dataType(), upsampling3dCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, block.launchContext()->getCudaStream(), input.specialBuffer(), input.specialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), factorD, factorH, factorW, isNCDHW), FLOAT_TYPES); - NDArray::registerSpecialUse({&output}, {&input}); - - manager.synchronize(); +void ConvolutionUtils::upsampling3d(sd::graph::Context& block, + const NDArray& input, NDArray& output, + const int factorD, const int factorH, + const int factorW, const bool isNCDHW) { + PointersManager manager(block.launchContext(), "upsampling3d"); + + const int threadsPerBlock = MAX_NUM_THREADS / 2; + const int blocksPerGrid = + (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + const int sharedMem = + output.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128; + + NDArray::prepareSpecialUse({&output}, {&input}); + BUILD_SINGLE_SELECTOR( + input.dataType(), upsampling3dCudaLauncher, + (blocksPerGrid, threadsPerBlock, sharedMem, + block.launchContext()->getCudaStream(), input.specialBuffer(), + input.specialShapeInfo(), output.specialBuffer(), + output.specialShapeInfo(), factorD, factorH, factorW, isNCDHW), + FLOAT_TYPES); + NDArray::registerSpecialUse({&output}, {&input}); + + manager.synchronize(); } -} -} +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/convolutions_upsampling3dBP.cu b/libnd4j/include/ops/declarable/helpers/cuda/convolutions_upsampling3dBP.cu index 5a1e08c07d1f..9d98eac3967f 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/convolutions_upsampling3dBP.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/convolutions_upsampling3dBP.cu @@ -19,89 +19,102 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include #include +#include namespace sd { -namespace ops { +namespace ops { ////////////////////////////////////////////////////////////////////////// template -__global__ static void upsampling3dBPCuda(const void* vx, const Nd4jLong* xShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const bool isNCDHW) { - - // x (gradO) has shape [bS, iC, iD, iH, iW] (NCDHW) or [bS, iD, iH, iW, iC] (NDHWC) - // z (gradI) has shape [bS, iC, factorD*iD, factorH*iH, factorW*iW ] (NCDHW) or [bS, factorD*iD, factorH*iH, factorW*iW, iC] (NDHWC) +__global__ static void upsampling3dBPCuda(const void* vx, + const Nd4jLong* xShapeInfo, void* vz, + const Nd4jLong* zShapeInfo, + const bool isNCDHW) { + // x (gradO) has shape [bS, iC, iD, iH, iW] (NCDHW) or [bS, iD, iH, iW, iC] + // (NDHWC) z (gradI) has shape [bS, iC, factorD*iD, factorH*iH, factorW*iW ] + // (NCDHW) or [bS, factorD*iD, factorH*iH, factorW*iW, iC] (NDHWC) - const T* x = reinterpret_cast(vx); - T* z = reinterpret_cast(vz); + const T* x = reinterpret_cast(vx); + T* z = reinterpret_cast(vz); - __shared__ int rank, dimID; - __shared__ uint factorD, factorH, factorW; - __shared__ Nd4jLong zLen, *sharedMem; + __shared__ int rank, dimID; + __shared__ uint factorD, factorH, factorW; + __shared__ Nd4jLong zLen, *sharedMem; - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - sharedMem = reinterpret_cast(shmem); + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sharedMem = reinterpret_cast(shmem); - dimID = isNCDHW ? 2 : 1; - zLen = shape::length(zShapeInfo); - rank = 5; + dimID = isNCDHW ? 2 : 1; + zLen = shape::length(zShapeInfo); + rank = 5; - factorD = xShapeInfo[dimID + 1] / zShapeInfo[dimID + 1]; - factorH = xShapeInfo[dimID + 2] / zShapeInfo[dimID + 2]; - factorW = xShapeInfo[dimID + 3] / zShapeInfo[dimID + 3]; - } - __syncthreads(); + factorD = xShapeInfo[dimID + 1] / zShapeInfo[dimID + 1]; + factorH = xShapeInfo[dimID + 2] / zShapeInfo[dimID + 2]; + factorW = xShapeInfo[dimID + 3] / zShapeInfo[dimID + 3]; + } + __syncthreads(); - const auto zInd = threadIdx.x + blockIdx.x * blockDim.x; + const auto zInd = threadIdx.x + blockIdx.x * blockDim.x; - if(zInd >= zLen) - return; + if (zInd >= zLen) return; - auto coords = sharedMem + threadIdx.x * rank; + auto coords = sharedMem + threadIdx.x * rank; - shape::index2coords(zInd, zShapeInfo, coords); + shape::index2coords(zInd, zShapeInfo, coords); - const auto zOffset = shape::getOffset(zShapeInfo, coords); + const auto zOffset = shape::getOffset(zShapeInfo, coords); - z[zOffset] = 0; + z[zOffset] = 0; - const Nd4jLong zCoord2 = coords[dimID] * factorD; - const Nd4jLong zCoord3 = coords[dimID + 1] * factorH; - const Nd4jLong zCoord4 = coords[dimID + 2] * factorW; + const Nd4jLong zCoord2 = coords[dimID] * factorD; + const Nd4jLong zCoord3 = coords[dimID + 1] * factorH; + const Nd4jLong zCoord4 = coords[dimID + 2] * factorW; - for(coords[dimID] = zCoord2; coords[dimID] < zCoord2 + factorD; ++coords[dimID]) - for(coords[dimID + 1] = zCoord3; coords[dimID + 1] < zCoord3 + factorH; ++coords[dimID + 1]) - for(coords[dimID + 2] = zCoord4; coords[dimID + 2] < zCoord4 + factorW; ++coords[dimID + 2]) - z[zOffset] += x[shape::getOffset(xShapeInfo, coords)]; + for (coords[dimID] = zCoord2; coords[dimID] < zCoord2 + factorD; + ++coords[dimID]) + for (coords[dimID + 1] = zCoord3; coords[dimID + 1] < zCoord3 + factorH; + ++coords[dimID + 1]) + for (coords[dimID + 2] = zCoord4; coords[dimID + 2] < zCoord4 + factorW; + ++coords[dimID + 2]) + z[zOffset] += x[shape::getOffset(xShapeInfo, coords)]; } ////////////////////////////////////////////////////////////////////////// template -static void upsampling3dBPCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, - const void* vx, const Nd4jLong* xShapeInfo, - void* vz, const Nd4jLong* zShapeInfo, - const bool isNCDHW) { - - upsampling3dBPCuda<<>>(vx, xShapeInfo, vz, zShapeInfo, isNCDHW); +static void upsampling3dBPCudaLauncher( + const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, + const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo, + void* vz, const Nd4jLong* zShapeInfo, const bool isNCDHW) { + upsampling3dBPCuda<<>>( + vx, xShapeInfo, vz, zShapeInfo, isNCDHW); } ////////////////////////////////////////////////////////////////////////// -void ConvolutionUtils::upsampling3dBP(sd::graph::Context& block, const NDArray& gradO, NDArray& gradI, const bool isNCDHW) { - - PointersManager manager(block.launchContext(), "upsampling3d_bp"); - - const int threadsPerBlock = MAX_NUM_THREADS / 2; - const int blocksPerGrid = (gradI.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - const int sharedMem = gradI.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128; - - NDArray::prepareSpecialUse({&gradI}, {&gradO}); - BUILD_SINGLE_SELECTOR(gradI.dataType(), upsampling3dBPCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, block.launchContext()->getCudaStream(), gradO.specialBuffer(), gradO.specialShapeInfo(), gradI.specialBuffer(), gradI.specialShapeInfo(), isNCDHW), FLOAT_TYPES); - NDArray::registerSpecialUse({&gradI}, {&gradO}); - - manager.synchronize(); +void ConvolutionUtils::upsampling3dBP(sd::graph::Context& block, + const NDArray& gradO, NDArray& gradI, + const bool isNCDHW) { + PointersManager manager(block.launchContext(), "upsampling3d_bp"); + + const int threadsPerBlock = MAX_NUM_THREADS / 2; + const int blocksPerGrid = + (gradI.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + const int sharedMem = + gradI.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128; + + NDArray::prepareSpecialUse({&gradI}, {&gradO}); + BUILD_SINGLE_SELECTOR( + gradI.dataType(), upsampling3dBPCudaLauncher, + (blocksPerGrid, threadsPerBlock, sharedMem, + block.launchContext()->getCudaStream(), gradO.specialBuffer(), + gradO.specialShapeInfo(), gradI.specialBuffer(), + gradI.specialShapeInfo(), isNCDHW), + FLOAT_TYPES); + NDArray::registerSpecialUse({&gradI}, {&gradO}); + + manager.synchronize(); } - -} -} \ No newline at end of file +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/convolutions_vol2col.cu b/libnd4j/include/ops/declarable/helpers/cuda/convolutions_vol2col.cu index c2c5fb3efb86..12d987a47377 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/convolutions_vol2col.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/convolutions_vol2col.cu @@ -19,93 +19,120 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include #include +#include namespace sd { -namespace ops { +namespace ops { ////////////////////////////////////////////////////////////////////////// -// vol [bS, iC, iD, iH, iW] is convoluted to col [bS, iC, kD, kH, kW, oD, oH, oW] +// vol [bS, iC, iD, iH, iW] is convoluted to col [bS, iC, kD, kH, kW, oD, oH, +// oW] template -static __global__ void vol2colCuda(const void* volume, const Nd4jLong* volShapeInfo, void* columns, const Nd4jLong* colShapeInfo, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW) { - - const T* vol = reinterpret_cast(volume); - T* col = reinterpret_cast(columns); - - __shared__ int colRank, volRank; - __shared__ Nd4jLong colLen, iD, iH, iW, *sharedMem; - - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - sharedMem = reinterpret_cast(shmem); - - volRank = 5; - colRank = 8; - - colLen = shape::length(colShapeInfo); - - iD = volShapeInfo[3]; - iH = volShapeInfo[4]; - iW = volShapeInfo[5]; - } - __syncthreads(); - - const auto colInd = threadIdx.x + blockIdx.x * blockDim.x; - - if(colInd >= colLen) - return; - - auto coords = sharedMem + threadIdx.x * colRank; - - shape::index2coords(colInd, colShapeInfo, coords); - - // const auto colW = coords[7]; - // const auto colH = coords[6]; - // const auto colD = coords[5]; - // const auto kCol = coords[4]; - // const auto kRow = coords[3]; - // const auto kDep = coords[2]; - // const auto c = coords[1]; - // const auto b = coords[0]; - - const auto colOffset = shape::getOffset(colShapeInfo, coords); - - coords[2] = -pD + coords[2] * dD + coords[5] * sD; // const auto volDep = (-pD + kDep * dD) + colD * sD; - coords[3] = -pH + coords[3] * dH + coords[6] * sH; // const auto volRow = (-pH + kRow * dH) + colH * sH; - coords[4] = -pW + coords[4] * dW + coords[7] * sW; // const auto volCol = (-pW + kCol * dW) + colW * sW; - - if (static_cast(coords[2]) >= static_cast(iD) || static_cast(coords[3]) >= static_cast(iH) || static_cast(coords[4]) >= static_cast(iW)) - col[colOffset] = static_cast(0.); - else - col[colOffset] = vol[shape::getOffset(volShapeInfo, coords)]; +static __global__ void vol2colCuda(const void* volume, + const Nd4jLong* volShapeInfo, void* columns, + const Nd4jLong* colShapeInfo, const int sD, + const int sH, const int sW, const int pD, + const int pH, const int pW, const int dD, + const int dH, const int dW) { + const T* vol = reinterpret_cast(volume); + T* col = reinterpret_cast(columns); + + __shared__ int colRank, volRank; + __shared__ Nd4jLong colLen, iD, iH, iW, *sharedMem; + + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sharedMem = reinterpret_cast(shmem); + + volRank = 5; + colRank = 8; + + colLen = shape::length(colShapeInfo); + + iD = volShapeInfo[3]; + iH = volShapeInfo[4]; + iW = volShapeInfo[5]; + } + __syncthreads(); + + const auto colInd = threadIdx.x + blockIdx.x * blockDim.x; + + if (colInd >= colLen) return; + + auto coords = sharedMem + threadIdx.x * colRank; + + shape::index2coords(colInd, colShapeInfo, coords); + + // const auto colW = coords[7]; + // const auto colH = coords[6]; + // const auto colD = coords[5]; + // const auto kCol = coords[4]; + // const auto kRow = coords[3]; + // const auto kDep = coords[2]; + // const auto c = coords[1]; + // const auto b = coords[0]; + + const auto colOffset = shape::getOffset(colShapeInfo, coords); + + coords[2] = + -pD + coords[2] * dD + + coords[5] * sD; // const auto volDep = (-pD + kDep * dD) + colD * sD; + coords[3] = + -pH + coords[3] * dH + + coords[6] * sH; // const auto volRow = (-pH + kRow * dH) + colH * sH; + coords[4] = + -pW + coords[4] * dW + + coords[7] * sW; // const auto volCol = (-pW + kCol * dW) + colW * sW; + + if (static_cast(coords[2]) >= static_cast(iD) || + static_cast(coords[3]) >= static_cast(iH) || + static_cast(coords[4]) >= static_cast(iW)) + col[colOffset] = static_cast(0.); + else + col[colOffset] = vol[shape::getOffset(volShapeInfo, coords)]; } ////////////////////////////////////////////////////////////////////////// template -static void vol2colCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, - const void* volume, const Nd4jLong* volShapeInfo, - void* columns, const Nd4jLong* colShapeInfo, - const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW) { - - vol2colCuda<<>>(volume, volShapeInfo, columns, colShapeInfo, sD, sH, sW, pD, pH, pW, dD, dH, dW); +static void vol2colCudaLauncher(const int blocksPerGrid, + const int threadsPerBlock, const int sharedMem, + const cudaStream_t* stream, const void* volume, + const Nd4jLong* volShapeInfo, void* columns, + const Nd4jLong* colShapeInfo, const int sD, + const int sH, const int sW, const int pD, + const int pH, const int pW, const int dD, + const int dH, const int dW) { + vol2colCuda<<>>( + volume, volShapeInfo, columns, colShapeInfo, sD, sH, sW, pD, pH, pW, dD, + dH, dW); } ////////////////////////////////////////////////////////////////////////// -void ConvolutionUtils::vol2col(sd::graph::Context& block, const NDArray& vol, NDArray& col, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW) { - - PointersManager manager(block.launchContext(), "vol2col"); - - const int threadsPerBlock = MAX_NUM_THREADS / 4; - const int blocksPerGrid = (col.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - const int sharedMem = col.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128; - - NDArray::prepareSpecialUse({&col}, {&vol}); - BUILD_SINGLE_SELECTOR(vol.dataType(), vol2colCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, block.launchContext()->getCudaStream(), vol.specialBuffer(), vol.specialShapeInfo(), col.specialBuffer(), col.specialShapeInfo(), sD, sH, sW, pD, pH, pW, dD, dH, dW), FLOAT_TYPES); - NDArray::registerSpecialUse({&col}, {&vol}); - - manager.synchronize(); +void ConvolutionUtils::vol2col(sd::graph::Context& block, const NDArray& vol, + NDArray& col, const int sD, const int sH, + const int sW, const int pD, const int pH, + const int pW, const int dD, const int dH, + const int dW) { + PointersManager manager(block.launchContext(), "vol2col"); + + const int threadsPerBlock = MAX_NUM_THREADS / 4; + const int blocksPerGrid = + (col.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + const int sharedMem = col.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128; + + NDArray::prepareSpecialUse({&col}, {&vol}); + BUILD_SINGLE_SELECTOR( + vol.dataType(), vol2colCudaLauncher, + (blocksPerGrid, threadsPerBlock, sharedMem, + block.launchContext()->getCudaStream(), vol.specialBuffer(), + vol.specialShapeInfo(), col.specialBuffer(), col.specialShapeInfo(), sD, + sH, sW, pD, pH, pW, dD, dH, dW), + FLOAT_TYPES); + NDArray::registerSpecialUse({&col}, {&vol}); + + manager.synchronize(); } -} -} \ No newline at end of file +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/cross.cu b/libnd4j/include/ops/declarable/helpers/cuda/cross.cu index 8de4f65fd690..ad7893ae8936 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/cross.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/cross.cu @@ -18,105 +18,114 @@ // @author Yurii Shyrma, created on 10.06.2019 // - -#include #include +#include - -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { ////////////////////////////////////////////////////////////////////////// -template +template __global__ static void crossCuda(const void* vx, const Nd4jLong* xShapeInfo, const void* vy, const Nd4jLong* yShapeInfo, - void* vz, const Nd4jLong* zShapeInfo) { + void* vz, const Nd4jLong* zShapeInfo) { + __shared__ const T* x; + __shared__ const T* y; + __shared__ T* z; + __shared__ int rank, *sharedMem; + __shared__ Nd4jLong lenWithoutLastDim, totalThreads; - __shared__ const T* x; - __shared__ const T* y; - __shared__ T* z; - __shared__ int rank, *sharedMem; - __shared__ Nd4jLong lenWithoutLastDim, totalThreads; + if (threadIdx.x == 0) { + x = reinterpret_cast(vx); + y = reinterpret_cast(vy); + z = reinterpret_cast(vz); - if (threadIdx.x == 0) { - x = reinterpret_cast(vx); - y = reinterpret_cast(vy); - z = reinterpret_cast(vz); + extern __shared__ unsigned char shmem[]; + sharedMem = reinterpret_cast(shmem); + totalThreads = gridDim.x * blockDim.x; - extern __shared__ unsigned char shmem[]; - sharedMem = reinterpret_cast(shmem); - totalThreads = gridDim.x * blockDim.x; + rank = shape::rank(xShapeInfo); + lenWithoutLastDim = shape::length(xShapeInfo) / + xShapeInfo[rank]; // shape::length(xShapeInfo) / 3; + } + __syncthreads(); - rank = shape::rank(xShapeInfo); - lenWithoutLastDim = shape::length(xShapeInfo) / xShapeInfo[rank]; // shape::length(xShapeInfo) / 3; - } - __syncthreads(); + auto coords = sharedMem + threadIdx.x * rank; + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - auto coords = sharedMem + threadIdx.x * rank; - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + for (uint i = tid; i < lenWithoutLastDim; i += totalThreads) { + shape::index2coords(i, rank - 1, xShapeInfo + 1, coords); - for (uint i = tid; i < lenWithoutLastDim; i += totalThreads) { + coords[rank - 1] = 0; - shape::index2coords(i, rank - 1, xShapeInfo + 1, coords); + auto xOffset = shape::getOffset(xShapeInfo, coords); + auto yOffset = shape::getOffset(yShapeInfo, coords); - coords[rank - 1] = 0; + const auto x0 = x[xOffset]; + const auto y0 = y[yOffset]; - auto xOffset = shape::getOffset(xShapeInfo, coords); - auto yOffset = shape::getOffset(yShapeInfo, coords); + xOffset += shape::stride(const_cast(xShapeInfo))[rank - 1]; + yOffset += shape::stride(const_cast(yShapeInfo))[rank - 1]; - const auto x0 = x[xOffset]; - const auto y0 = y[yOffset]; + const auto x1 = x[xOffset]; + const auto y1 = y[yOffset]; - xOffset += shape::stride(const_cast(xShapeInfo))[rank - 1]; - yOffset += shape::stride(const_cast(yShapeInfo))[rank - 1]; + xOffset += shape::stride(const_cast(xShapeInfo))[rank - 1]; + yOffset += shape::stride(const_cast(yShapeInfo))[rank - 1]; - const auto x1 = x[xOffset]; - const auto y1 = y[yOffset]; + const auto x2 = x[xOffset]; + const auto y2 = y[yOffset]; - xOffset += shape::stride(const_cast(xShapeInfo))[rank - 1]; - yOffset += shape::stride(const_cast(yShapeInfo))[rank - 1]; + auto zOffset = shape::getOffset(zShapeInfo, coords); + z[zOffset] = x1 * y2 - x2 * y1; - const auto x2 = x[xOffset]; - const auto y2 = y[yOffset]; + zOffset += shape::stride(const_cast(zShapeInfo))[rank - 1]; + z[zOffset] = x2 * y0 - x0 * y2; - auto zOffset = shape::getOffset(zShapeInfo, coords); - z[zOffset] = x1 * y2 - x2 * y1; - - zOffset += shape::stride(const_cast(zShapeInfo))[rank - 1]; - z[zOffset] = x2 * y0 - x0 * y2; - - zOffset += shape::stride(const_cast(zShapeInfo))[rank - 1]; - z[zOffset] = x0 * y1 - x1 * y0; - } + zOffset += shape::stride(const_cast(zShapeInfo))[rank - 1]; + z[zOffset] = x0 * y1 - x1 * y0; + } } -template -__host__ static void crossCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, - const void* vx, const Nd4jLong* xShapeInfo, - const void* vy, const Nd4jLong* yShapeInfo, - void* vz, const Nd4jLong* zShapeInfo) { - - crossCuda<<>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo); +template +__host__ static void crossCudaLauncher( + const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, + const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo, + const void* vy, const Nd4jLong* yShapeInfo, void* vz, + const Nd4jLong* zShapeInfo) { + crossCuda<<>>( + vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo); } -BUILD_SINGLE_TEMPLATE(template void crossCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void* vx, const Nd4jLong* xShapeInfo, const void* vy, const Nd4jLong* yShapeInfo, void* vz, const Nd4jLong* zShapeInfo), NUMERIC_TYPES); - - -void crossBatched(sd::LaunchContext* context, NDArray *x, NDArray *y, NDArray *z) { - - const int threadsPerBlock = MAX_NUM_THREADS / 4; - const int blocksPerGrid = (x->lengthOf() / x->sizeAt(-1) + threadsPerBlock - 1) / threadsPerBlock; - const int sharedMem = sizeof(int) * threadsPerBlock * x->rankOf() + 128; - - PointersManager manager(context, "cross"); - - NDArray::prepareSpecialUse({z}, {x, y}); - BUILD_SINGLE_SELECTOR(x->dataType(), crossCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), x->specialBuffer(), x->specialShapeInfo(), y->specialBuffer(), y->specialShapeInfo(), z->specialBuffer(), z->specialShapeInfo()), NUMERIC_TYPES); - NDArray::registerSpecialUse({z}, {x, y}); - - manager.synchronize(); +BUILD_SINGLE_TEMPLATE(template void crossCudaLauncher, + (const int blocksPerGrid, const int threadsPerBlock, + const int sharedMem, const cudaStream_t* stream, + const void* vx, const Nd4jLong* xShapeInfo, + const void* vy, const Nd4jLong* yShapeInfo, void* vz, + const Nd4jLong* zShapeInfo), + NUMERIC_TYPES); + +void crossBatched(sd::LaunchContext* context, NDArray* x, NDArray* y, + NDArray* z) { + const int threadsPerBlock = MAX_NUM_THREADS / 4; + const int blocksPerGrid = + (x->lengthOf() / x->sizeAt(-1) + threadsPerBlock - 1) / threadsPerBlock; + const int sharedMem = sizeof(int) * threadsPerBlock * x->rankOf() + 128; + + PointersManager manager(context, "cross"); + + NDArray::prepareSpecialUse({z}, {x, y}); + BUILD_SINGLE_SELECTOR( + x->dataType(), crossCudaLauncher, + (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), + x->specialBuffer(), x->specialShapeInfo(), y->specialBuffer(), + y->specialShapeInfo(), z->specialBuffer(), z->specialShapeInfo()), + NUMERIC_TYPES); + NDArray::registerSpecialUse({z}, {x, y}); + + manager.synchronize(); } -} -} -} \ No newline at end of file +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/d_t_s.cu b/libnd4j/include/ops/declarable/helpers/cuda/d_t_s.cu index 35d8bf0335ae..c9c87db35147 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/d_t_s.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/d_t_s.cu @@ -24,82 +24,106 @@ namespace sd { namespace ops { namespace helpers { - template - static _CUDA_G void depthToSpaceKernel(const void *vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo, const int block_size, const bool isNHWC) { - auto input_ptr = reinterpret_cast(vx); - auto output_ptr = reinterpret_cast(vz); - - const int batch_size = shape::sizeAt(xShapeInfo, 0); - const int input_depth = isNHWC ? shape::sizeAt(xShapeInfo, 3) : shape::sizeAt(xShapeInfo, 1); - const int input_height = isNHWC ? shape::sizeAt(xShapeInfo, 1) : shape::sizeAt(xShapeInfo, 2); - const int input_width = isNHWC ? shape::sizeAt(xShapeInfo, 2) : shape::sizeAt(xShapeInfo, 3); - - const int output_depth = isNHWC ? shape::sizeAt(zShapeInfo, 3) : shape::sizeAt(zShapeInfo, 1); - const int output_height = isNHWC ? shape::sizeAt(zShapeInfo, 1) : shape::sizeAt(zShapeInfo, 2); - const int output_width = isNHWC ? shape::sizeAt(zShapeInfo, 2) : shape::sizeAt(zShapeInfo, 3); - - const int input_area = input_width * input_height; - const int input_depth_by_input_area = input_depth * input_area; - const int output_depth_by_input_height = output_depth * input_height; - - auto tid = threadIdx.x + blockIdx.x * blockDim.x; - - if (isNHWC) { - const int total_count = batch_size * output_height * output_width * output_depth; - for (int out_idx = tid; out_idx < total_count; out_idx += blockDim.x * gridDim.x) { - const int d = out_idx % output_depth; - const int out_idx2 = out_idx / output_depth; - const int w = out_idx2 % output_width; - const int out_idx3 = out_idx2 / output_width; - const int h = out_idx3 % output_height; - const int b = out_idx3 / output_height; - - const int in_h = h / block_size; - const int offset_h = h % block_size; - const int in_w = w / block_size; - const int offset_w = w % block_size; - const int offset_d = (offset_h * block_size + offset_w) * output_depth; - const int in_d = d + offset_d; - const int inp_idx = in_d + input_depth * (in_w + input_width * (in_h + input_height * b)); - (output_ptr + out_idx)[0] = (input_ptr + inp_idx)[0]; - } - } else { - const int total_count = batch_size * input_depth_by_input_area; - - for (int input_idx = tid; input_idx < total_count; input_idx += blockDim.x * gridDim.x) { - const int n_bY_bX_oC_iY = input_idx / input_width; - const int iX = input_idx - n_bY_bX_oC_iY * input_width; - - const int n_bY_bX = n_bY_bX_oC_iY / output_depth_by_input_height; - const int oC_iY = n_bY_bX_oC_iY - n_bY_bX * output_depth_by_input_height; - - const int n_bY = n_bY_bX / block_size; - const int bX = n_bY_bX - n_bY * block_size; - - const int n = n_bY / block_size; - const int bY = n_bY - n * block_size; - - const int output_idx = bX + block_size * (iX + input_width * (bY + block_size * (oC_iY + n * output_depth_by_input_height))); - - (output_ptr + output_idx)[0] = (input_ptr + input_idx)[0]; - } - } +template +static _CUDA_G void depthToSpaceKernel(const void *vx, + const Nd4jLong *xShapeInfo, void *vz, + const Nd4jLong *zShapeInfo, + const int block_size, + const bool isNHWC) { + auto input_ptr = reinterpret_cast(vx); + auto output_ptr = reinterpret_cast(vz); + + const int batch_size = shape::sizeAt(xShapeInfo, 0); + const int input_depth = + isNHWC ? shape::sizeAt(xShapeInfo, 3) : shape::sizeAt(xShapeInfo, 1); + const int input_height = + isNHWC ? shape::sizeAt(xShapeInfo, 1) : shape::sizeAt(xShapeInfo, 2); + const int input_width = + isNHWC ? shape::sizeAt(xShapeInfo, 2) : shape::sizeAt(xShapeInfo, 3); + + const int output_depth = + isNHWC ? shape::sizeAt(zShapeInfo, 3) : shape::sizeAt(zShapeInfo, 1); + const int output_height = + isNHWC ? shape::sizeAt(zShapeInfo, 1) : shape::sizeAt(zShapeInfo, 2); + const int output_width = + isNHWC ? shape::sizeAt(zShapeInfo, 2) : shape::sizeAt(zShapeInfo, 3); + + const int input_area = input_width * input_height; + const int input_depth_by_input_area = input_depth * input_area; + const int output_depth_by_input_height = output_depth * input_height; + + auto tid = threadIdx.x + blockIdx.x * blockDim.x; + + if (isNHWC) { + const int total_count = + batch_size * output_height * output_width * output_depth; + for (int out_idx = tid; out_idx < total_count; + out_idx += blockDim.x * gridDim.x) { + const int d = out_idx % output_depth; + const int out_idx2 = out_idx / output_depth; + const int w = out_idx2 % output_width; + const int out_idx3 = out_idx2 / output_width; + const int h = out_idx3 % output_height; + const int b = out_idx3 / output_height; + + const int in_h = h / block_size; + const int offset_h = h % block_size; + const int in_w = w / block_size; + const int offset_w = w % block_size; + const int offset_d = (offset_h * block_size + offset_w) * output_depth; + const int in_d = d + offset_d; + const int inp_idx = + in_d + input_depth * (in_w + input_width * (in_h + input_height * b)); + (output_ptr + out_idx)[0] = (input_ptr + inp_idx)[0]; } + } else { + const int total_count = batch_size * input_depth_by_input_area; + for (int input_idx = tid; input_idx < total_count; + input_idx += blockDim.x * gridDim.x) { + const int n_bY_bX_oC_iY = input_idx / input_width; + const int iX = input_idx - n_bY_bX_oC_iY * input_width; - template - static void __depthToSpace(sd::LaunchContext * context, const NDArray &input, NDArray *output, int block_size, bool isNHWC) { - depthToSpaceKernel<<<512, 512, 1024, *context->getCudaStream()>>>(input.specialBuffer(), input.specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), block_size, isNHWC); - } + const int n_bY_bX = n_bY_bX_oC_iY / output_depth_by_input_height; + const int oC_iY = n_bY_bX_oC_iY - n_bY_bX * output_depth_by_input_height; + + const int n_bY = n_bY_bX / block_size; + const int bX = n_bY_bX - n_bY * block_size; - void _depthToSpace(sd::LaunchContext * context, const NDArray &input, NDArray *output, int block_size, bool isNHWC) { - auto xType = input.dataType(); + const int n = n_bY / block_size; + const int bY = n_bY - n * block_size; - NDArray::prepareSpecialUse({output}, {&input}); + const int output_idx = + bX + + block_size * + (iX + input_width * + (bY + block_size * + (oC_iY + n * output_depth_by_input_height))); - BUILD_SINGLE_SELECTOR(xType, __depthToSpace, (context, input, output, block_size, isNHWC), LIBND4J_TYPES); - NDArray::registerSpecialUse({output}, {&input}); + (output_ptr + output_idx)[0] = (input_ptr + input_idx)[0]; } + } +} + +template +static void __depthToSpace(sd::LaunchContext *context, const NDArray &input, + NDArray *output, int block_size, bool isNHWC) { + depthToSpaceKernel<<<512, 512, 1024, *context->getCudaStream()>>>( + input.specialBuffer(), input.specialShapeInfo(), output->specialBuffer(), + output->specialShapeInfo(), block_size, isNHWC); } + +void _depthToSpace(sd::LaunchContext *context, const NDArray &input, + NDArray *output, int block_size, bool isNHWC) { + auto xType = input.dataType(); + + NDArray::prepareSpecialUse({output}, {&input}); + + BUILD_SINGLE_SELECTOR(xType, __depthToSpace, + (context, input, output, block_size, isNHWC), + LIBND4J_TYPES); + NDArray::registerSpecialUse({output}, {&input}); } -} \ No newline at end of file +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/diGamma.cu b/libnd4j/include/ops/declarable/helpers/cuda/diGamma.cu index ff217bdb6c80..c475d3ea58e8 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/diGamma.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/diGamma.cu @@ -19,60 +19,72 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include #include +#include namespace sd { namespace ops { namespace helpers { /////////////////////////////////////////////////////////////////// -template +template __global__ static void diGammaCuda(const void *vx, const Nd4jLong *xShapeInfo, - void *vz, const Nd4jLong *zShapeInfo) { - - const auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - - __shared__ Nd4jLong len; - __shared__ bool sameOffset; - - if (threadIdx.x == 0) { - len = shape::length(xShapeInfo); - sameOffset = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo); - } - __syncthreads(); - - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < len; i += gridDim.x * blockDim.x) { - - const auto xOffset = shape::getIndexOffset(i, xShapeInfo); - const auto zOffset = sameOffset ? xOffset : shape::getIndexOffset(i, zShapeInfo); - - z[zOffset] = diGammaScalar(x[xOffset]); - } + void *vz, const Nd4jLong *zShapeInfo) { + const auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + + __shared__ Nd4jLong len; + __shared__ bool sameOffset; + + if (threadIdx.x == 0) { + len = shape::length(xShapeInfo); + sameOffset = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo); + } + __syncthreads(); + + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < len; + i += gridDim.x * blockDim.x) { + const auto xOffset = shape::getIndexOffset(i, xShapeInfo); + const auto zOffset = + sameOffset ? xOffset : shape::getIndexOffset(i, zShapeInfo); + + z[zOffset] = diGammaScalar(x[xOffset]); + } } /////////////////////////////////////////////////////////////////// -template -static void diGammaCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo) { - - diGammaCuda<<>>(vx, xShapeInfo, vz, zShapeInfo); +template +static void diGammaCudaLauncher(const int blocksPerGrid, + const int threadsPerBlock, + const cudaStream_t *stream, const void *vx, + const Nd4jLong *xShapeInfo, void *vz, + const Nd4jLong *zShapeInfo) { + diGammaCuda<<>>( + vx, xShapeInfo, vz, zShapeInfo); } /////////////////////////////////////////////////////////////////// -void diGamma(sd::LaunchContext* context, const NDArray& x, NDArray& z) { - - int threadsPerBlock = MAX_NUM_THREADS / 2; - int blocksPerGrid = (z.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - - NDArray::prepareSpecialUse({&z}, {&x}); - BUILD_SINGLE_SELECTOR(x.dataType(), diGammaCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), x.specialBuffer(), x.specialShapeInfo(), z.specialBuffer(), z.specialShapeInfo()), FLOAT_TYPES); - NDArray::registerSpecialUse({&z}, {&x}); +void diGamma(sd::LaunchContext *context, const NDArray &x, NDArray &z) { + int threadsPerBlock = MAX_NUM_THREADS / 2; + int blocksPerGrid = (z.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + + NDArray::prepareSpecialUse({&z}, {&x}); + BUILD_SINGLE_SELECTOR( + x.dataType(), diGammaCudaLauncher, + (blocksPerGrid, threadsPerBlock, context->getCudaStream(), + x.specialBuffer(), x.specialShapeInfo(), z.specialBuffer(), + z.specialShapeInfo()), + FLOAT_TYPES); + NDArray::registerSpecialUse({&z}, {&x}); } -BUILD_SINGLE_TEMPLATE(template void diGammaCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo), FLOAT_TYPES); - -} -} -} +BUILD_SINGLE_TEMPLATE(template void diGammaCudaLauncher, + (const int blocksPerGrid, const int threadsPerBlock, + const cudaStream_t *stream, const void *vx, + const Nd4jLong *xShapeInfo, void *vz, + const Nd4jLong *zShapeInfo), + FLOAT_TYPES); +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/diag.cu b/libnd4j/include/ops/declarable/helpers/cuda/diag.cu index f011f409526a..95bbd5b9893a 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/diag.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/diag.cu @@ -28,110 +28,130 @@ namespace helpers { // diag functor cuda kernel // outputBuffer - output tensor buffer // outputShape - output tensor shape -// inputBuffer - input tensor buffer - this tensor should be placed on diagonal position of output -// inputShape - input tensor shape -// inputLength - length for input tensor +// inputBuffer - input tensor buffer - this tensor should be placed on diagonal +// position of output inputShape - input tensor shape inputLength - length for +// input tensor // template -static __global__ void diagFunctorKernel(void* outputBuffer, const Nd4jLong* outputShape, void const* inputBuffer, const Nd4jLong* inputShape, Nd4jLong inputLength) { - __shared__ T *z; - __shared__ T const* x; - __shared__ Nd4jLong outputLength; - - if (threadIdx.x == 0) { - z = reinterpret_cast(outputBuffer); - x = reinterpret_cast(inputBuffer); - - outputLength = shape::length(outputShape); - } - __syncthreads(); - - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - const auto step = gridDim.x * blockDim.x; - - for (int t = tid; t < inputLength; t += step) { // for all vals in input, put all on diagonal position to output - z[shape::getIndexOffset(t * (inputLength + 1), outputShape)] = x[shape::getIndexOffset(t, inputShape)]; //tX]; - } - +static __global__ void diagFunctorKernel(void* outputBuffer, + const Nd4jLong* outputShape, + void const* inputBuffer, + const Nd4jLong* inputShape, + Nd4jLong inputLength) { + __shared__ T* z; + __shared__ T const* x; + __shared__ Nd4jLong outputLength; + + if (threadIdx.x == 0) { + z = reinterpret_cast(outputBuffer); + x = reinterpret_cast(inputBuffer); + + outputLength = shape::length(outputShape); + } + __syncthreads(); + + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + const auto step = gridDim.x * blockDim.x; + + for (int t = tid; t < inputLength; + t += + step) { // for all vals in input, put all on diagonal position to output + z[shape::getIndexOffset(t * (inputLength + 1), outputShape)] = + x[shape::getIndexOffset(t, inputShape)]; // tX]; + } } //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// // diag part functor cuda kernel // outputBuffer - output tensor buffer - linear sequence of diagonal values // outputShape - output tensor shape -// inputBuffer - input tensor buffer - this tensor should be placed on diagonal position of output -// inputShape - input tensor shape -// outputLength - given length of output -// inputLength - given length for input tensor +// inputBuffer - input tensor buffer - this tensor should be placed on diagonal +// position of output inputShape - input tensor shape outputLength - given +// length of output inputLength - given length for input tensor // - template - static __global__ void diagPartFunctorKernel(void* outputBuffer, const Nd4jLong* outputShape, void const* inputBuffer, const Nd4jLong* inputShape, Nd4jLong outputLength, Nd4jLong inputLength) { - __shared__ T *z; - __shared__ T const* x; - - if (threadIdx.x == 0) { - z = reinterpret_cast(outputBuffer); - x = reinterpret_cast(inputBuffer); - - } - __syncthreads(); - - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - const auto step = gridDim.x * blockDim.x; - Nd4jLong i = threadIdx.x * (outputLength + 1); // pos to diagonal value - for (int t = tid; t < outputLength && i < inputLength; t += step) { // loop by output, but input matrix may not be square - // put diagonal val from input onto output - z[shape::getIndexOffset(t, outputShape)] = x[shape::getIndexOffset(i, inputShape)]; - i += outputLength + 1; // shift to next diagonal value - } - } +template +static __global__ void diagPartFunctorKernel( + void* outputBuffer, const Nd4jLong* outputShape, void const* inputBuffer, + const Nd4jLong* inputShape, Nd4jLong outputLength, Nd4jLong inputLength) { + __shared__ T* z; + __shared__ T const* x; + + if (threadIdx.x == 0) { + z = reinterpret_cast(outputBuffer); + x = reinterpret_cast(inputBuffer); + } + __syncthreads(); + + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + const auto step = gridDim.x * blockDim.x; + Nd4jLong i = threadIdx.x * (outputLength + 1); // pos to diagonal value + for (int t = tid; t < outputLength && i < inputLength; + t += step) { // loop by output, but input matrix may not be square + // put diagonal val from input onto output + z[shape::getIndexOffset(t, outputShape)] = + x[shape::getIndexOffset(i, inputShape)]; + i += outputLength + 1; // shift to next diagonal value + } +} ////////////////////////////////////////////////////////////////////////// // Returns a batched matrix tensor with new batched diagonal values. -// for detailed explanations please take a look on web page: https://www.tensorflow.org/api_docs/python/tf/matrix_set_diag - template - static void _diagFunctor(sd::LaunchContext * context, const NDArray* input, NDArray* output) { - auto stream = context->getCudaStream(); - auto inputLength = input->lengthOf(); - dim3 launchDims(256, 512, 8192); - if (!input->isActualOnDeviceSide()) - input->syncToDevice(); - diagFunctorKernel<<>>(output->specialBuffer(), output->specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), inputLength); - } +// for detailed explanations please take a look on web page: +// https://www.tensorflow.org/api_docs/python/tf/matrix_set_diag +template +static void _diagFunctor(sd::LaunchContext* context, const NDArray* input, + NDArray* output) { + auto stream = context->getCudaStream(); + auto inputLength = input->lengthOf(); + dim3 launchDims(256, 512, 8192); + if (!input->isActualOnDeviceSide()) input->syncToDevice(); + diagFunctorKernel<<>>( + output->specialBuffer(), output->specialShapeInfo(), + input->specialBuffer(), input->specialShapeInfo(), inputLength); +} //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// // diagFunctor - caller for diag functor processor - void diagFunctor(sd::LaunchContext * context, const NDArray* input, NDArray* output) { - auto xType = input->dataType(); +void diagFunctor(sd::LaunchContext* context, const NDArray* input, + NDArray* output) { + auto xType = input->dataType(); - BUILD_SINGLE_SELECTOR(xType, _diagFunctor, (context, input, output), LIBND4J_TYPES); - } + BUILD_SINGLE_SELECTOR(xType, _diagFunctor, (context, input, output), + LIBND4J_TYPES); +} - BUILD_SINGLE_TEMPLATE(template void _diagFunctor, (sd::LaunchContext * context, const NDArray* input, NDArray* output);, LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template void _diagFunctor, + (sd::LaunchContext * context, const NDArray* input, + NDArray* output); + , LIBND4J_TYPES); //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// // diagPartFunctor - caller for diag part functor kernel - template - void _diagPartFunctor(sd::LaunchContext * context, NDArray const* input, NDArray* output) { - const int outLen = output->lengthOf(); - const int inLen = input->lengthOf(); - auto stream = context->getCudaStream(); - - dim3 launchDims(256, 512, 8192); - if (!input->isActualOnDeviceSide()) - input->syncToDevice(); - - diagPartFunctorKernel<<>>(output->specialBuffer(), output->specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), outLen, inLen); - } +template +void _diagPartFunctor(sd::LaunchContext* context, NDArray const* input, + NDArray* output) { + const int outLen = output->lengthOf(); + const int inLen = input->lengthOf(); + auto stream = context->getCudaStream(); + + dim3 launchDims(256, 512, 8192); + if (!input->isActualOnDeviceSide()) input->syncToDevice(); + + diagPartFunctorKernel + <<>>( + output->specialBuffer(), output->specialShapeInfo(), + input->specialBuffer(), input->specialShapeInfo(), outLen, inLen); +} //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// // diagPartFunctor - caller for diag part functor processor - void diagPartFunctor(sd::LaunchContext * context, NDArray const* input, NDArray* output) { - auto zType = output->dataType(); - BUILD_SINGLE_SELECTOR(zType, _diagPartFunctor, (context, input, output), NUMERIC_TYPES); - - } - +void diagPartFunctor(sd::LaunchContext* context, NDArray const* input, + NDArray* output) { + auto zType = output->dataType(); + BUILD_SINGLE_SELECTOR(zType, _diagPartFunctor, (context, input, output), + NUMERIC_TYPES); } -} -} \ No newline at end of file + +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/dilation2d.cu b/libnd4j/include/ops/declarable/helpers/cuda/dilation2d.cu index 0d25552c93a5..2b7cbcc98afa 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/dilation2d.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/dilation2d.cu @@ -18,117 +18,125 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include #include #include +#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { ////////////////////////////////////////////////////////////////////////// template -__global__ static void dilation2dCuda(const void* vx, const Nd4jLong* xShapeInfo, - const void* vy, const Nd4jLong* yShapeInfo, - void* vz, const Nd4jLong* zShapeInfo, - const int sH, const int sW, - const int pH, const int pW, - const int dH, const int dW) { +__global__ static void dilation2dCuda(const void* vx, + const Nd4jLong* xShapeInfo, + const void* vy, + const Nd4jLong* yShapeInfo, void* vz, + const Nd4jLong* zShapeInfo, const int sH, + const int sW, const int pH, const int pW, + const int dH, const int dW) { + // x [bS, iH, iW, iC] + // y [kH, kW, iC] + // z [bS, oH, oW, iC] - // x [bS, iH, iW, iC] - // y [kH, kW, iC] - // z [bS, oH, oW, iC] + const X* x = reinterpret_cast(vx); + const X* y = reinterpret_cast(vy); + Z* z = reinterpret_cast(vz); - const X* x = reinterpret_cast(vx); - const X* y = reinterpret_cast(vy); - Z* z = reinterpret_cast(vz); + __shared__ int xzRank, yRank, *sharedMem; + __shared__ uint iH, iW, kH, kW; + __shared__ Nd4jLong zLen; - __shared__ int xzRank, yRank, *sharedMem; - __shared__ uint iH, iW, kH, kW; - __shared__ Nd4jLong zLen; + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sharedMem = reinterpret_cast(shmem); - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - sharedMem = reinterpret_cast(shmem); + zLen = shape::length(zShapeInfo); - zLen = shape::length(zShapeInfo); + xzRank = shape::rank(xShapeInfo); + yRank = shape::rank(yShapeInfo); - xzRank = shape::rank(xShapeInfo); - yRank = shape::rank(yShapeInfo); + iH = xShapeInfo[2]; + iW = xShapeInfo[3]; - iH = xShapeInfo[2]; - iW = xShapeInfo[3]; - - kH = yShapeInfo[1]; - kW = yShapeInfo[2]; - } - __syncthreads(); + kH = yShapeInfo[1]; + kW = yShapeInfo[2]; + } + __syncthreads(); - const auto zInd = threadIdx.x + blockIdx.x * blockDim.x; + const auto zInd = threadIdx.x + blockIdx.x * blockDim.x; - if(zInd >= zLen) - return; + if (zInd >= zLen) return; - auto xzCoords = sharedMem + threadIdx.x * (xzRank + yRank); - auto yCoords = xzCoords + xzRank; + auto xzCoords = sharedMem + threadIdx.x * (xzRank + yRank); + auto yCoords = xzCoords + xzRank; - shape::index2coords(zInd, zShapeInfo, xzCoords); + shape::index2coords(zInd, zShapeInfo, xzCoords); - const auto zOffset = shape::getOffset(zShapeInfo, xzCoords); + const auto zOffset = shape::getOffset(zShapeInfo, xzCoords); - yCoords[2] = xzCoords[3]; // iC coordinate is same for x, y and z + yCoords[2] = xzCoords[3]; // iC coordinate is same for x, y and z - const auto oh = xzCoords[1]; - const auto ow = xzCoords[2]; + const auto oh = xzCoords[1]; + const auto ow = xzCoords[2]; - X max = -DataTypeUtils::max(); + X max = -DataTypeUtils::max(); - for (yCoords[0] = 0; yCoords[0] < kH; ++yCoords[0]) { - xzCoords[1] = oh * sH - pH + yCoords[0] * dH; - if (xzCoords[1] < 0 || xzCoords[1] >= iH) continue; + for (yCoords[0] = 0; yCoords[0] < kH; ++yCoords[0]) { + xzCoords[1] = oh * sH - pH + yCoords[0] * dH; + if (xzCoords[1] < 0 || xzCoords[1] >= iH) continue; - for (yCoords[1] = 0; yCoords[1] < kW; ++yCoords[1]) { - xzCoords[2] = ow * sW - pW + yCoords[1] * dW; - if(xzCoords[2] < 0 || xzCoords[2] >= iW) continue; + for (yCoords[1] = 0; yCoords[1] < kW; ++yCoords[1]) { + xzCoords[2] = ow * sW - pW + yCoords[1] * dW; + if (xzCoords[2] < 0 || xzCoords[2] >= iW) continue; - const X val = x[shape::getOffset(xShapeInfo, xzCoords)] + y[shape::getOffset(yShapeInfo, yCoords)]; - if (val > max) - max = val; - } - } + const X val = x[shape::getOffset(xShapeInfo, xzCoords)] + + y[shape::getOffset(yShapeInfo, yCoords)]; + if (val > max) max = val; + } + } - z[zOffset] = static_cast(max); + z[zOffset] = static_cast(max); } ////////////////////////////////////////////////////////////////////////// template -static void dilation2dCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, - const void* vx, const Nd4jLong* xShapeInfo, - const void* vy, const Nd4jLong* yShapeInfo, - void* vz, const Nd4jLong* zShapeInfo, - const int sH, const int sW, - const int pH, const int pW, - const int dH, const int dW) { - - dilation2dCuda<<>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, sH, sW, pH, pW, dH, dW); +static void dilation2dCudaLauncher( + const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, + const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo, + const void* vy, const Nd4jLong* yShapeInfo, void* vz, + const Nd4jLong* zShapeInfo, const int sH, const int sW, const int pH, + const int pW, const int dH, const int dW) { + dilation2dCuda<<>>( + vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, sH, sW, pH, pW, dH, dW); } -void dilation2d(sd::LaunchContext* context, NDArray *input, NDArray *weights, NDArray *output, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW) { - - PointersManager manager(context, "dilation2d"); - - const int threadsPerBlock = MAX_NUM_THREADS / 2; - const int blocksPerGrid = (output->lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - const int sharedMem = (weights->rankOf() + output->rankOf()) * sizeof(int) * threadsPerBlock + 128; - - NDArray::prepareSpecialUse({output}, {input, weights}); - BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), dilation2dCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), input->specialBuffer(), input->specialShapeInfo(), weights->specialBuffer(), weights->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), sH, sW, pH, pW, dH, dW), FLOAT_TYPES); - NDArray::registerSpecialUse({output}, {input, weights}); - - manager.synchronize(); +void dilation2d(sd::LaunchContext* context, NDArray* input, NDArray* weights, + NDArray* output, const int sH, const int sW, const int pH, + const int pW, const int dH, const int dW) { + PointersManager manager(context, "dilation2d"); + + const int threadsPerBlock = MAX_NUM_THREADS / 2; + const int blocksPerGrid = + (output->lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + const int sharedMem = + (weights->rankOf() + output->rankOf()) * sizeof(int) * threadsPerBlock + + 128; + + NDArray::prepareSpecialUse({output}, {input, weights}); + BUILD_SINGLE_SELECTOR_TWICE( + input->dataType(), dilation2dCudaLauncher, + (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), + input->specialBuffer(), input->specialShapeInfo(), + weights->specialBuffer(), weights->specialShapeInfo(), + output->specialBuffer(), output->specialShapeInfo(), sH, sW, pH, pW, dH, + dW), + FLOAT_TYPES); + NDArray::registerSpecialUse({output}, {input, weights}); + + manager.synchronize(); } - -} -} -} +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/dropout.cu b/libnd4j/include/ops/declarable/helpers/cuda/dropout.cu index 4e0fdb377e10..ef14d1e46780 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/dropout.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/dropout.cu @@ -18,258 +18,345 @@ // @author raver119@gmail.com // -#include +#include #include -#include +#include + #include -#include +#include namespace sd { namespace ops { namespace helpers { - template - static __global__ void dropoutSimpleKernel(void const* inputBuf, Nd4jLong const* inputShape, void* outputBuf, Nd4jLong const* outputShape, double probVal, int inLen, sd::graph::RandomGenerator* nodeRng) { - auto tid = blockIdx.x * blockDim.x + threadIdx.x; - auto step = blockDim.x * gridDim.x; - T const* input = reinterpret_cast(inputBuf); - T* output = reinterpret_cast(outputBuf); - - // trivial idea: loop through all elements, get independent probability for each element to be nullified - for (Nd4jLong e = 0; e < inLen; ++e) { - T val = nodeRng->relativeT(e, T(0.f), T(1.f)); - - // if probability is ok - we're saving scaled value - if (double(val) < probVal) - output[shape::getIndexOffset(e, outputShape)] = T(input[shape::getIndexOffset(e, inputShape)] / probVal); - } - } +template +static __global__ void dropoutSimpleKernel( + void const* inputBuf, Nd4jLong const* inputShape, void* outputBuf, + Nd4jLong const* outputShape, double probVal, int inLen, + sd::graph::RandomGenerator* nodeRng) { + auto tid = blockIdx.x * blockDim.x + threadIdx.x; + auto step = blockDim.x * gridDim.x; + T const* input = reinterpret_cast(inputBuf); + T* output = reinterpret_cast(outputBuf); + + // trivial idea: loop through all elements, get independent probability for + // each element to be nullified + for (Nd4jLong e = 0; e < inLen; ++e) { + T val = nodeRng->relativeT(e, T(0.f), T(1.f)); + + // if probability is ok - we're saving scaled value + if (double(val) < probVal) + output[shape::getIndexOffset(e, outputShape)] = + T(input[shape::getIndexOffset(e, inputShape)] / probVal); + } +} - template - static void dropoutSimple(sd::LaunchContext* context, NDArray const* input, NDArray* output, double probValue, int seed) { - sd::graph::RandomGenerator nodeRng(3019L, seed); - int inLen = input->lengthOf(); - sd::graph::RandomGenerator* dRandom; - auto stream = context->getCudaStream(); - NDArray::prepareSpecialUse({output}, {input}); - - auto err = cudaMalloc(&dRandom, sizeof(sd::graph::RandomGenerator)); - if (err) { - throw cuda_exception::build("helpers::dropoutSimple: Cannot allocate device memory for random generator.", err); - } - err = cudaMemcpy(dRandom, &nodeRng, sizeof(sd::graph::RandomGenerator), cudaMemcpyHostToDevice); - if (err) { - throw cuda_exception::build("helpers::dropoutSimple: Cannot set up device memory for random generator.", err); - } - - dropoutSimpleKernel<<<128, 256, 1024, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), probValue, inLen, dRandom); - err = cudaFree(dRandom); - if (err) { - throw cuda_exception::build("helpers::dropoutSimple: Cannot deallocate device memory for random generator.", err); - } - NDArray::registerSpecialUse({output}, {input}); - } +template +static void dropoutSimple(sd::LaunchContext* context, NDArray const* input, + NDArray* output, double probValue, int seed) { + sd::graph::RandomGenerator nodeRng(3019L, seed); + int inLen = input->lengthOf(); + sd::graph::RandomGenerator* dRandom; + auto stream = context->getCudaStream(); + NDArray::prepareSpecialUse({output}, {input}); + + auto err = cudaMalloc(&dRandom, sizeof(sd::graph::RandomGenerator)); + if (err) { + throw cuda_exception::build( + "helpers::dropoutSimple: Cannot allocate device memory for random " + "generator.", + err); + } + err = cudaMemcpy(dRandom, &nodeRng, sizeof(sd::graph::RandomGenerator), + cudaMemcpyHostToDevice); + if (err) { + throw cuda_exception::build( + "helpers::dropoutSimple: Cannot set up device memory for random " + "generator.", + err); + } + + dropoutSimpleKernel<<<128, 256, 1024, *stream>>>( + input->specialBuffer(), input->specialShapeInfo(), + output->specialBuffer(), output->specialShapeInfo(), probValue, inLen, + dRandom); + err = cudaFree(dRandom); + if (err) { + throw cuda_exception::build( + "helpers::dropoutSimple: Cannot deallocate device memory for random " + "generator.", + err); + } + NDArray::registerSpecialUse({output}, {input}); +} - template - int _dropOutFunctor(graph::Context& context, NDArray* input, NDArray* output, NDArray* reduceShape, int seed, double probValue) { - - if (reduceShape == nullptr){ - dropoutSimple(context.launchContext(), input, output, probValue, seed); - } - else { - REQUIRE_TRUE(reduceShape->lengthOf() <= input->rankOf(), 0, "dropout: Noise shape should be fittable to input"); - - std::vector dims(reduceShape->lengthOf()); - reduceShape->syncToHost(); // to ensure that follows are actual - bool fit = true; - - for( int i = 0; i < dims.size(); i++ ) { - if (fit) { - dims[i] = reduceShape->e(i); - for (int e = 0; e < input->rankOf(); ++e) - if (fit) - if (input->sizeAt(e) % dims[i]) { - fit = false; - } - } +template +int _dropOutFunctor(graph::Context& context, NDArray* input, NDArray* output, + NDArray* reduceShape, int seed, double probValue) { + if (reduceShape == nullptr) { + dropoutSimple(context.launchContext(), input, output, probValue, seed); + } else { + REQUIRE_TRUE(reduceShape->lengthOf() <= input->rankOf(), 0, + "dropout: Noise shape should be fittable to input"); + + std::vector dims(reduceShape->lengthOf()); + reduceShape->syncToHost(); // to ensure that follows are actual + bool fit = true; + + for (int i = 0; i < dims.size(); i++) { + if (fit) { + dims[i] = reduceShape->e(i); + for (int e = 0; e < input->rankOf(); ++e) + if (fit) + if (input->sizeAt(e) % dims[i]) { + fit = false; } - - // check dims to fit input - REQUIRE_TRUE(fit, 0, "dropout: Noise shape should fit to input rank."); - std::unique_ptr chunk(new NDArray('c', dims, output->dataType(), context.launchContext())); - chunk->assign(1.f); - - dropoutSimple(context.launchContext(), chunk.get(), chunk.get(), probValue, seed); - // broadcast chunk to full matrix - std::unique_ptr dropOutMultiplier(new NDArray(*input)); - dropOutMultiplier->assign(1.f); - - *dropOutMultiplier += *chunk; - - // FIXME: we could do this in one step, aren't we? - output->assign(*input * *dropOutMultiplier); //input->applyPairwiseTransform(pairwise::Multiply, dropOutMultiplier.get(), output, nullptr); - } - - return Status::OK(); + } } - int dropOutFunctor(graph::Context& context, NDArray* input, NDArray* output, NDArray* reduceShape, int seed, double probValue) { - auto xType = input->dataType(); - NDArray::prepareSpecialUse({output}, {input}); - - BUILD_SINGLE_SELECTOR(xType, return _dropOutFunctor, (context, input, output, reduceShape, seed, probValue), FLOAT_TYPES); + // check dims to fit input + REQUIRE_TRUE(fit, 0, "dropout: Noise shape should fit to input rank."); + std::unique_ptr chunk( + new NDArray('c', dims, output->dataType(), context.launchContext())); + chunk->assign(1.f); - NDArray::registerSpecialUse({output}, {input}); - } + dropoutSimple(context.launchContext(), chunk.get(), chunk.get(), + probValue, seed); + // broadcast chunk to full matrix + std::unique_ptr dropOutMultiplier(new NDArray(*input)); + dropOutMultiplier->assign(1.f); -/////////////////////////////////// backrpopagations /////////////////////////////////////////////// - template - static __global__ void dropoutBPKernel(void* outputBuf, Nd4jLong const* outputShape, void* gradOutBuf, Nd4jLong const* gradOutShape, double probValue) { - __shared__ T* output; - __shared__ T* input; - __shared__ int len; + *dropOutMultiplier += *chunk; - if (threadIdx.x == 0) { - len = shape::length(outputShape); - output = reinterpret_cast(outputBuf); - input = reinterpret_cast(gradOutBuf); - } - __syncthreads(); + // FIXME: we could do this in one step, aren't we? + output->assign( + *input * + *dropOutMultiplier); // input->applyPairwiseTransform(pairwise::Multiply, + // dropOutMultiplier.get(), output, nullptr); + } - auto tid = blockIdx.x * blockDim.x + threadIdx.x; - auto step = blockDim.x * gridDim.x; - - for (int e = tid; e < len; e += step) { - const auto zOffset = shape::getIndexOffset(e, outputShape); - - // if probability was non-zero on FF step, we'll scale grads back - if (output[zOffset] != T(0.)) - output[zOffset] = T(input[shape::getIndexOffset(e, gradOutShape)] / probValue); - - } - } - template - static int dropOutFunctorBP_(graph::Context& context, NDArray* input, NDArray* gradOut, NDArray* output, NDArray* reduceShape, int seed, double probValue) { - // we're making additional FF run to see how probabilities played out with given seeds - int res = dropOutFunctor(context, input, output, reduceShape, seed, probValue); - auto stream = context.launchContext()->getCudaStream(); + return Status::OK(); +} - NDArray::prepareSpecialUse({output}, {input, gradOut}); +int dropOutFunctor(graph::Context& context, NDArray* input, NDArray* output, + NDArray* reduceShape, int seed, double probValue) { + auto xType = input->dataType(); + NDArray::prepareSpecialUse({output}, {input}); - if (ND4J_STATUS_OK == res) - dropoutBPKernel<<<128, 256, 1024, *stream>>>(output->specialBuffer(), output->specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(), probValue); + BUILD_SINGLE_SELECTOR(xType, return _dropOutFunctor, + (context, input, output, reduceShape, seed, probValue), + FLOAT_TYPES); - NDArray::registerSpecialUse({output}, {input, gradOut}); + NDArray::registerSpecialUse({output}, {input}); +} - return res; - } +/////////////////////////////////// backrpopagations +////////////////////////////////////////////////// +template +static __global__ void dropoutBPKernel(void* outputBuf, + Nd4jLong const* outputShape, + void* gradOutBuf, + Nd4jLong const* gradOutShape, + double probValue) { + __shared__ T* output; + __shared__ T* input; + __shared__ int len; + + if (threadIdx.x == 0) { + len = shape::length(outputShape); + output = reinterpret_cast(outputBuf); + input = reinterpret_cast(gradOutBuf); + } + __syncthreads(); + + auto tid = blockIdx.x * blockDim.x + threadIdx.x; + auto step = blockDim.x * gridDim.x; + + for (int e = tid; e < len; e += step) { + const auto zOffset = shape::getIndexOffset(e, outputShape); + + // if probability was non-zero on FF step, we'll scale grads back + if (output[zOffset] != T(0.)) + output[zOffset] = + T(input[shape::getIndexOffset(e, gradOutShape)] / probValue); + } +} +template +static int dropOutFunctorBP_(graph::Context& context, NDArray* input, + NDArray* gradOut, NDArray* output, + NDArray* reduceShape, int seed, double probValue) { + // we're making additional FF run to see how probabilities played out with + // given seeds + int res = + dropOutFunctor(context, input, output, reduceShape, seed, probValue); + auto stream = context.launchContext()->getCudaStream(); + + NDArray::prepareSpecialUse({output}, {input, gradOut}); + + if (ND4J_STATUS_OK == res) + dropoutBPKernel<<<128, 256, 1024, *stream>>>( + output->specialBuffer(), output->specialShapeInfo(), + gradOut->specialBuffer(), gradOut->specialShapeInfo(), probValue); + + NDArray::registerSpecialUse({output}, {input, gradOut}); + + return res; +} - template - static __global__ void alphaDropoutSimpleKernel(void const* inputBuf, Nd4jLong const* inputShape, void* outputBuf, Nd4jLong const* outputShape, double probValue, double alpha, double alpha1, double beta, int inLen, sd::graph::RandomGenerator* nodeRng) { - auto tid = blockIdx.x * blockDim.x + threadIdx.x; - auto step = blockDim.x * gridDim.x; - T const* input = reinterpret_cast(inputBuf); - T* output = reinterpret_cast(outputBuf); - - for (auto e = tid; e < inLen; e += step) { - T val = nodeRng->relativeT(e, T(0.f), T(1.f)); - T xVal = input[shape::getIndexOffset(e, inputShape)]; - output[shape::getIndexOffset(e, outputShape)] = (val >= T(probValue) ? T(alpha * beta + alpha1) : T(alpha * (double)xVal + alpha1)); - } - } - template - static void alphaDropoutSimple(sd::LaunchContext* context, NDArray const* input, NDArray* output, int seed, double probValue, double alpha, double alpha1, double beta) { - sd::graph::RandomGenerator nodeRng(3019L, seed), *dRandom; - auto stream = context->getCudaStream(); - auto err = cudaMalloc(&dRandom, sizeof(sd::graph::RandomGenerator)); - NDArray::prepareSpecialUse({output}, {input}); - if (err) { - throw cuda_exception::build("helpers::alphaDropoutSimple: Cannot allocate device memory for random generator.", err); - } - err = cudaMemcpy(dRandom, &nodeRng, sizeof(sd::graph::RandomGenerator), cudaMemcpyHostToDevice); - if (err) { - throw cuda_exception::build("helpers::alphaDropoutSimple: Cannot set up device memory for random generator.", err); - } - - alphaDropoutSimpleKernel<<<128, 256, 1024, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), probValue, alpha, alpha1, beta, output->lengthOf(), dRandom); - - err = cudaFree(dRandom); - if (err) { - throw cuda_exception::build("helpers::alphaDropoutSimple: Cannot deallocate device memory for random generator.", err); - } - NDArray::registerSpecialUse({output}, {input}); - } +template +static __global__ void alphaDropoutSimpleKernel( + void const* inputBuf, Nd4jLong const* inputShape, void* outputBuf, + Nd4jLong const* outputShape, double probValue, double alpha, double alpha1, + double beta, int inLen, sd::graph::RandomGenerator* nodeRng) { + auto tid = blockIdx.x * blockDim.x + threadIdx.x; + auto step = blockDim.x * gridDim.x; + T const* input = reinterpret_cast(inputBuf); + T* output = reinterpret_cast(outputBuf); + + for (auto e = tid; e < inLen; e += step) { + T val = nodeRng->relativeT(e, T(0.f), T(1.f)); + T xVal = input[shape::getIndexOffset(e, inputShape)]; + output[shape::getIndexOffset(e, outputShape)] = + (val >= T(probValue) ? T(alpha * beta + alpha1) + : T(alpha * (double)xVal + alpha1)); + } +} +template +static void alphaDropoutSimple(sd::LaunchContext* context, NDArray const* input, + NDArray* output, int seed, double probValue, + double alpha, double alpha1, double beta) { + sd::graph::RandomGenerator nodeRng(3019L, seed), *dRandom; + auto stream = context->getCudaStream(); + auto err = cudaMalloc(&dRandom, sizeof(sd::graph::RandomGenerator)); + NDArray::prepareSpecialUse({output}, {input}); + if (err) { + throw cuda_exception::build( + "helpers::alphaDropoutSimple: Cannot allocate device memory for random " + "generator.", + err); + } + err = cudaMemcpy(dRandom, &nodeRng, sizeof(sd::graph::RandomGenerator), + cudaMemcpyHostToDevice); + if (err) { + throw cuda_exception::build( + "helpers::alphaDropoutSimple: Cannot set up device memory for random " + "generator.", + err); + } + + alphaDropoutSimpleKernel<<<128, 256, 1024, *stream>>>( + input->specialBuffer(), input->specialShapeInfo(), + output->specialBuffer(), output->specialShapeInfo(), probValue, alpha, + alpha1, beta, output->lengthOf(), dRandom); + + err = cudaFree(dRandom); + if (err) { + throw cuda_exception::build( + "helpers::alphaDropoutSimple: Cannot deallocate device memory for " + "random generator.", + err); + } + NDArray::registerSpecialUse({output}, {input}); +} - template - static int alphaDropOutFunctor_(graph::Context& context, NDArray* input, NDArray* output, - NDArray* reduceShape, int seed, double probValue, double alpha, double alpha1, double beta) { - - if (reduceShape == nullptr){ - alphaDropoutSimple(context.launchContext(), input, output, seed, probValue, alpha, alpha1, beta); - } - else { - REQUIRE_TRUE(reduceShape->lengthOf() <= input->rankOf(), 0, "dropout: Noise shape should be fittable to input"); - - std::vector dims(reduceShape->lengthOf()); - reduceShape->syncToHost(); // to ensure that follows are actual - bool fit = true; - - for( int i = 0; i < dims.size(); i++ ) { - if (fit) { - dims[i] = reduceShape->e(i); - for (int e = 0; e < input->rankOf(); ++e) - if (fit) - if (input->sizeAt(e) % dims[i]) { - fit = false; - } - } +template +static int alphaDropOutFunctor_(graph::Context& context, NDArray* input, + NDArray* output, NDArray* reduceShape, int seed, + double probValue, double alpha, double alpha1, + double beta) { + if (reduceShape == nullptr) { + alphaDropoutSimple(context.launchContext(), input, output, seed, + probValue, alpha, alpha1, beta); + } else { + REQUIRE_TRUE(reduceShape->lengthOf() <= input->rankOf(), 0, + "dropout: Noise shape should be fittable to input"); + + std::vector dims(reduceShape->lengthOf()); + reduceShape->syncToHost(); // to ensure that follows are actual + bool fit = true; + + for (int i = 0; i < dims.size(); i++) { + if (fit) { + dims[i] = reduceShape->e(i); + for (int e = 0; e < input->rankOf(); ++e) + if (fit) + if (input->sizeAt(e) % dims[i]) { + fit = false; } + } + } - // check dims to fit input - REQUIRE_TRUE(fit, 0, "alpha_dropout: Noise shape should fit to input rank."); - std::unique_ptr chunk(new NDArray('c', dims, output->dataType(), context.launchContext())); - chunk->assign(1.f); - - alphaDropoutSimple(context.launchContext(), chunk.get(), chunk.get(), seed, probValue, alpha, alpha1, beta); - - // broadcast chunk to full matrix - std::unique_ptr dropOutMultiplier(new NDArray(*input)); - dropOutMultiplier->assign(1.f); - - *dropOutMultiplier += *chunk; + // check dims to fit input + REQUIRE_TRUE(fit, 0, + "alpha_dropout: Noise shape should fit to input rank."); + std::unique_ptr chunk( + new NDArray('c', dims, output->dataType(), context.launchContext())); + chunk->assign(1.f); - output->assign(*input * *dropOutMultiplier); //input->applyPairwiseTransform(pairwise::Multiply, dropOutMultiplier.get(), output, nullptr); - } + alphaDropoutSimple(context.launchContext(), chunk.get(), chunk.get(), + seed, probValue, alpha, alpha1, beta); + // broadcast chunk to full matrix + std::unique_ptr dropOutMultiplier(new NDArray(*input)); + dropOutMultiplier->assign(1.f); - return Status::OK(); - } + *dropOutMultiplier += *chunk; - template - int alphaDropOutFunctorBP_(graph::Context& context, NDArray* input, NDArray* gradOut, NDArray* output, - NDArray* reduceShape, int seed, double probValue, double alpha, double alpha1, double beta) { - - int res = alphaDropOutFunctor(context, input, output, reduceShape, seed, probValue, alpha, alpha1, beta); - if (res == ND4J_STATUS_OK) { - // FIXME: can we make it single-loop? - (*output) *= alpha; - (*output) *= (*gradOut); //->applyPairwiseTransform(gradOut, output, nullptr); - } - return res; - } + output->assign( + *input * + *dropOutMultiplier); // input->applyPairwiseTransform(pairwise::Multiply, + // dropOutMultiplier.get(), output, nullptr); + } - int dropOutFunctorBP(graph::Context& context, NDArray* input, NDArray* gradOut, NDArray* output, NDArray* reduceShape, int seed, double probValue) { - BUILD_SINGLE_SELECTOR(context.dataType(), return dropOutFunctorBP_, (context, input, gradOut, output, reduceShape, seed, probValue), FLOAT_TYPES); - } + return Status::OK(); +} - int alphaDropOutFunctor(graph::Context& context, NDArray* input, NDArray* output, NDArray* reduceShape, int seed, double probValue, double alpha, double alpha1, double beta) { - BUILD_SINGLE_SELECTOR(context.dataType(), return alphaDropOutFunctor_, (context, input, output, reduceShape, seed, probValue, alpha, alpha1, beta), FLOAT_TYPES); - } +template +int alphaDropOutFunctorBP_(graph::Context& context, NDArray* input, + NDArray* gradOut, NDArray* output, + NDArray* reduceShape, int seed, double probValue, + double alpha, double alpha1, double beta) { + int res = alphaDropOutFunctor(context, input, output, reduceShape, seed, + probValue, alpha, alpha1, beta); + if (res == ND4J_STATUS_OK) { + // FIXME: can we make it single-loop? + (*output) *= alpha; + (*output) *= + (*gradOut); //->applyPairwiseTransform(gradOut, + // output, nullptr); + } + return res; +} - int alphaDropOutFunctorBP(graph::Context& context, NDArray* input, NDArray* gradOut, NDArray* output, NDArray* reduceShape, int seed, double probValue, double alpha, double alpha1, double beta) { - BUILD_SINGLE_SELECTOR(context.dataType(), return alphaDropOutFunctorBP_, (context, input, gradOut, output, reduceShape, seed, probValue, alpha, alpha1, beta), FLOAT_TYPES); - } +int dropOutFunctorBP(graph::Context& context, NDArray* input, NDArray* gradOut, + NDArray* output, NDArray* reduceShape, int seed, + double probValue) { + BUILD_SINGLE_SELECTOR( + output->dataType(), return dropOutFunctorBP_, + (context, input, gradOut, output, reduceShape, seed, probValue), + FLOAT_TYPES); +} +int alphaDropOutFunctor(graph::Context& context, NDArray* input, + NDArray* output, NDArray* reduceShape, int seed, + double probValue, double alpha, double alpha1, + double beta) { + BUILD_SINGLE_SELECTOR(output->dataType(), return alphaDropOutFunctor_, + (context, input, output, reduceShape, seed, probValue, + alpha, alpha1, beta), + FLOAT_TYPES); } + +int alphaDropOutFunctorBP(graph::Context& context, NDArray* input, + NDArray* gradOut, NDArray* output, + NDArray* reduceShape, int seed, double probValue, + double alpha, double alpha1, double beta) { + BUILD_SINGLE_SELECTOR(output->dataType(), return alphaDropOutFunctorBP_, + (context, input, gradOut, output, reduceShape, seed, + probValue, alpha, alpha1, beta), + FLOAT_TYPES); } -} \ No newline at end of file + +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/dynamic.cu b/libnd4j/include/ops/declarable/helpers/cuda/dynamic.cu index bce7316efc25..e009a9280105 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/dynamic.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/dynamic.cu @@ -17,354 +17,429 @@ // // @author raver119@gmail.com // -#include -#include #include +#include +#include namespace sd { - namespace ops { - namespace helpers { - - - template - static _CUDA_G void dynamicPartitionScalarKernel(const void *vx, const Nd4jLong *xShapeInfo, const void *vi, const Nd4jLong *iShapeInfo, void **vz, Nd4jLong **zShapeInfos, const Nd4jLong numOutputs) { - auto x = reinterpret_cast(vx); - auto i = reinterpret_cast(vi); - auto xLength = shape::length(xShapeInfo); - auto iLength = shape::length(iShapeInfo); - - extern __shared__ char shmem[]; - __shared__ Y *rawIndices; - __shared__ Y *trueIndices; - - if (threadIdx.x == 0) { - rawIndices = reinterpret_cast(shmem); - trueIndices = rawIndices + blockDim.x; - } - __syncthreads(); - - // we run things in blocks, 1 partition per block of threads - for (Nd4jLong o = blockIdx.x; o < numOutputs; o += gridDim.x) { - auto z = reinterpret_cast(vz[o]); - - auto zShapeInfo = zShapeInfos[o]; - auto zLength = shape::length(zShapeInfo); - - // iLimit should be multiple of blockDim.x - auto iLimit = iLength <= blockDim.x ? blockDim.x : (iLength + (blockDim.x - (iLength % blockDim.x))); - int cnt = 0; - - for (Nd4jLong e = threadIdx.x; e < iLimit; e += blockDim.x) { - // load set of indices into shared memory - if (e < iLength) - rawIndices[threadIdx.x] = i[shape::getIndexOffset(e, iShapeInfo)]; - __syncthreads(); - - // now we need to find out where our actual updates will be mapped - // TODO: this can be improved obviously, by using prefix-sum like approach - if (threadIdx.x == 0) { - for (int f = 0; f < blockDim.x; f++) { - if (rawIndices[f] == static_cast(o)) - trueIndices[f] = cnt++; - else - trueIndices[f] = -1; - } - } - __syncthreads(); - - - // doing actual update - if (e < iLength) - if (trueIndices[threadIdx.x] >= 0) { - z[trueIndices[threadIdx.x]] = x[shape::getIndexOffset(e, xShapeInfo)]; - } - - __syncthreads(); - } - } - } - - template - static _CUDA_G void dynamicPartitionTadKernel(const void *vx, const Nd4jLong *xTadShapeInfo, const Nd4jLong *xTadOffsets, Nd4jLong xLength, const void *vindices, const Nd4jLong *iShapeInfo, Nd4jLong iLength, void **vz, Nd4jLong **zTadShapeInfos, Nd4jLong **zTadOffsets, Nd4jLong numOutputs) { - auto x = reinterpret_cast(vx); - auto indices = reinterpret_cast(vindices); - - // we run things in blocks, 1 partition per block of threads - for (int i = blockIdx.x; i < numOutputs; i += gridDim.x) { - auto z = reinterpret_cast(vz[i]); - - // each thread has own counter for partitions - int outCnt = 0; - - for (Nd4jLong e = 0; e < iLength; e++) { - if (indices[shape::getIndexOffset(e, iShapeInfo)] == i) { - auto dx = x + xTadOffsets[e]; - auto dz = z + zTadOffsets[i][outCnt++]; - - for (int f = threadIdx.x; f < xLength; f += blockDim.x) { - dz[shape::getIndexOffset(f, zTadShapeInfos[i])] = dx[shape::getIndexOffset(f, xTadShapeInfo)]; - } - } - } - } - } - - template - static void _dynamicPartitionFunctor(sd::LaunchContext * context, NDArray const* input, NDArray const* indices, std::vector& outputList) { - std::vector> outputs(outputList.size()); - int sourceDimsLen = input->rankOf() - indices->rankOf(); - - unsigned int outSize = outputList.size(); - - PointersManager pm(context, "dynamicPartition"); - - if (sourceDimsLen) { // non-linear case - std::vector sourceDims(sourceDimsLen); - - for (int i = sourceDimsLen; i > 0; i--) - sourceDims[sourceDimsLen - i] = input->rankOf() - i; - //compute tad array for given dimensions - auto packX = ConstantTadHelper::getInstance().tadForDimensions(input->shapeInfo(), sourceDims); - - std::vector outBuffers(outSize); - std::vector tadShapes(outSize); - std::vector tadOffsets(outSize); - std::vector numTads(outSize); - // fill up dimensions array for before kernel - for (unsigned int i = 0; i < outSize; i++) { - outputs[i].first = outputList[i]; - std::vector outDims(outputs[i].first->rankOf() - 1); - - int r = outputs[i].first->rankOf(); - - for (int k = 1; k < r; k++) - outDims[k - 1] = k; - - auto packZ = ConstantTadHelper::getInstance().tadForDimensions(outputList.at(i)->shapeInfo(), outDims); - - outBuffers[i] = outputList.at(i)->specialBuffer(); - tadShapes[i] = packZ.platformShapeInfo(); - tadOffsets[i] = packZ.platformOffsets(); - } +namespace ops { +namespace helpers { + +template +static _CUDA_G void dynamicPartitionScalarKernel( + const void *vx, const Nd4jLong *xShapeInfo, const void *vi, + const Nd4jLong *iShapeInfo, void **vz, Nd4jLong **zShapeInfos, + const Nd4jLong numOutputs) { + auto x = reinterpret_cast(vx); + auto i = reinterpret_cast(vi); + auto xLength = shape::length(xShapeInfo); + auto iLength = shape::length(iShapeInfo); + + extern __shared__ char shmem[]; + __shared__ Y *rawIndices; + __shared__ Y *trueIndices; + + if (threadIdx.x == 0) { + rawIndices = reinterpret_cast(shmem); + trueIndices = rawIndices + blockDim.x; + } + __syncthreads(); + + // we run things in blocks, 1 partition per block of threads + for (Nd4jLong o = blockIdx.x; o < numOutputs; o += gridDim.x) { + auto z = reinterpret_cast(vz[o]); + + auto zShapeInfo = zShapeInfos[o]; + auto zLength = shape::length(zShapeInfo); + + // iLimit should be multiple of blockDim.x + auto iLimit = iLength <= blockDim.x + ? blockDim.x + : (iLength + (blockDim.x - (iLength % blockDim.x))); + int cnt = 0; + + for (Nd4jLong e = threadIdx.x; e < iLimit; e += blockDim.x) { + // load set of indices into shared memory + if (e < iLength) + rawIndices[threadIdx.x] = i[shape::getIndexOffset(e, iShapeInfo)]; + __syncthreads(); + + // now we need to find out where our actual updates will be mapped + // TODO: this can be improved obviously, by using prefix-sum like approach + if (threadIdx.x == 0) { + for (int f = 0; f < blockDim.x; f++) { + if (rawIndices[f] == static_cast(o)) + trueIndices[f] = cnt++; + else + trueIndices[f] = -1; + } + } + __syncthreads(); - // we copy pointers to device - auto dOutBuffers = reinterpret_cast(pm.replicatePointer(outBuffers.data(), outBuffers.size() * sizeof(void *))); - auto dOutTadShapes = reinterpret_cast(pm.replicatePointer(tadShapes.data(), tadShapes.size() * sizeof(Nd4jLong *))); - auto dOutTadOffsets = reinterpret_cast(pm.replicatePointer(tadOffsets.data(), tadOffsets.size() * sizeof(Nd4jLong *))); - // run kernel on device - dynamicPartitionTadKernel<<<256, 256, 1024, *context->getCudaStream()>>>(input->specialBuffer(), packX.platformShapeInfo(), packX.platformOffsets(), shape::length(packX.primaryShapeInfo()), indices->specialBuffer(), indices->specialShapeInfo(), indices->lengthOf(), dOutBuffers, dOutTadShapes, dOutTadOffsets, outSize); - - } else { // linear case - auto numThreads = 256; - auto shmemSize = numThreads * sizeof(Y) * 2 + 1024; - - std::vector outBuffers; - std::vector outShapes; - - for (auto v:outputList) { - outBuffers.emplace_back(v->specialBuffer()); - outShapes.emplace_back(v->specialShapeInfo()); - } + // doing actual update + if (e < iLength) + if (trueIndices[threadIdx.x] >= 0) { + z[trueIndices[threadIdx.x]] = x[shape::getIndexOffset(e, xShapeInfo)]; + } - auto dOutBuffers = reinterpret_cast(pm.replicatePointer(outBuffers.data(), outBuffers.size() * sizeof(void *))); - auto dOutShapes = reinterpret_cast(pm.replicatePointer(outShapes.data(), outShapes.size() * sizeof(Nd4jLong *))); + __syncthreads(); + } + } +} - dynamicPartitionScalarKernel<<<256, numThreads, shmemSize, *context->getCudaStream()>>>(input->specialBuffer(), input->specialShapeInfo(), indices->specialBuffer(), indices->specialShapeInfo(), dOutBuffers, dOutShapes, outSize); - } +template +static _CUDA_G void dynamicPartitionTadKernel( + const void *vx, const Nd4jLong *xTadShapeInfo, const Nd4jLong *xTadOffsets, + Nd4jLong xLength, const void *vindices, const Nd4jLong *iShapeInfo, + Nd4jLong iLength, void **vz, Nd4jLong **zTadShapeInfos, + Nd4jLong **zTadOffsets, Nd4jLong numOutputs) { + auto x = reinterpret_cast(vx); + auto indices = reinterpret_cast(vindices); + + // we run things in blocks, 1 partition per block of threads + for (int i = blockIdx.x; i < numOutputs; i += gridDim.x) { + auto z = reinterpret_cast(vz[i]); + + // each thread has own counter for partitions + int outCnt = 0; + + for (Nd4jLong e = 0; e < iLength; e++) { + if (indices[shape::getIndexOffset(e, iShapeInfo)] == i) { + auto dx = x + xTadOffsets[e]; + auto dz = z + zTadOffsets[i][outCnt++]; + + for (int f = threadIdx.x; f < xLength; f += blockDim.x) { + dz[shape::getIndexOffset(f, zTadShapeInfos[i])] = + dx[shape::getIndexOffset(f, xTadShapeInfo)]; + } + } + } + } +} - pm.synchronize(); - } +template +static void _dynamicPartitionFunctor(sd::LaunchContext *context, + NDArray const *input, + NDArray const *indices, + std::vector &outputList) { + std::vector> outputs(outputList.size()); + int sourceDimsLen = input->rankOf() - indices->rankOf(); + unsigned int outSize = outputList.size(); - template - static _CUDA_G void dynamicStitchScalarKernel(void **vx, Nd4jLong **xShapeInfos, void **vindices, Nd4jLong **iShapeInfos, int inputSize, void *vz, const Nd4jLong *zShapeInfo, Nd4jLong zLength) { - auto z = reinterpret_cast(vz); + PointersManager pm(context, "dynamicPartition"); - for (int e = blockIdx.x; e < inputSize; e += gridDim.x) { - auto x = reinterpret_cast(vx[e]); - auto indices = reinterpret_cast(vindices[e]); + if (sourceDimsLen) { // non-linear case + std::vector sourceDims(sourceDimsLen); - auto xShapeInfo = xShapeInfos[e]; - auto iShapeInfo = iShapeInfos[e]; + for (int i = sourceDimsLen; i > 0; i--) + sourceDims[sourceDimsLen - i] = input->rankOf() - i; + // compute tad array for given dimensions + auto packX = ConstantTadHelper::getInstance().tadForDimensions( + input->shapeInfo(), sourceDims); - auto iLength = shape::length(iShapeInfo); + std::vector outBuffers(outSize); + std::vector tadShapes(outSize); + std::vector tadOffsets(outSize); + std::vector numTads(outSize); + // fill up dimensions array for before kernel + for (unsigned int i = 0; i < outSize; i++) { + outputs[i].first = outputList[i]; + std::vector outDims(outputs[i].first->rankOf() - 1); - for (int i = threadIdx.x; i < iLength; i += blockDim.x) { - auto idx = indices[shape::getIndexOffset(i, iShapeInfo)]; - if (idx >= 0 && idx < zLength) - z[shape::getIndexOffset(idx, zShapeInfo)] = x[shape::getIndexOffset(i, xShapeInfo)]; - } - } - } + int r = outputs[i].first->rankOf(); - template - static _CUDA_G void dynamicStitchTadKernel(void **vx, Nd4jLong **xTadShapeInfos, Nd4jLong **xTadOffsets, void **vindices, Nd4jLong **iShapeInfos, int inputSize, void *vz, const Nd4jLong *zTadShapeInfo, const Nd4jLong *zTadOffsets) { - auto bz = reinterpret_cast(vz); + for (int k = 1; k < r; k++) outDims[k - 1] = k; - for (int e = blockIdx.x; e < inputSize; e += gridDim.x) { - auto indices = reinterpret_cast(vindices[e]); - auto iShapeInfo = iShapeInfos[e]; + auto packZ = ConstantTadHelper::getInstance().tadForDimensions( + outputList.at(i)->shapeInfo(), outDims); - if (shape::isEmpty(iShapeInfo)) - continue; + outBuffers[i] = outputList.at(i)->specialBuffer(); + tadShapes[i] = packZ.platformShapeInfo(); + tadOffsets[i] = packZ.platformOffsets(); + } - auto iLength = shape::length(iShapeInfo); - auto zLength = shape::length(zTadShapeInfo); + // we copy pointers to device + auto dOutBuffers = reinterpret_cast(pm.replicatePointer( + outBuffers.data(), outBuffers.size() * sizeof(void *))); + auto dOutTadShapes = reinterpret_cast(pm.replicatePointer( + tadShapes.data(), tadShapes.size() * sizeof(Nd4jLong *))); + auto dOutTadOffsets = reinterpret_cast(pm.replicatePointer( + tadOffsets.data(), tadOffsets.size() * sizeof(Nd4jLong *))); + // run kernel on device + dynamicPartitionTadKernel + <<<256, 256, 1024, *context->getCudaStream()>>>( + input->specialBuffer(), packX.platformShapeInfo(), + packX.platformOffsets(), shape::length(packX.primaryShapeInfo()), + indices->specialBuffer(), indices->specialShapeInfo(), + indices->lengthOf(), dOutBuffers, dOutTadShapes, dOutTadOffsets, + outSize); + + } else { // linear case + auto numThreads = 256; + auto shmemSize = numThreads * sizeof(Y) * 2 + 1024; + + std::vector outBuffers; + std::vector outShapes; + + for (auto v : outputList) { + outBuffers.emplace_back(v->specialBuffer()); + outShapes.emplace_back(v->specialShapeInfo()); + } - auto xShapeInfo = xTadShapeInfos[e]; - auto xLength = shape::length(xShapeInfo); + auto dOutBuffers = reinterpret_cast(pm.replicatePointer( + outBuffers.data(), outBuffers.size() * sizeof(void *))); + auto dOutShapes = reinterpret_cast(pm.replicatePointer( + outShapes.data(), outShapes.size() * sizeof(Nd4jLong *))); - for (int i = 0; i < iLength; i++) { - auto idx = indices[shape::getIndexOffset(i, iShapeInfo)]; + dynamicPartitionScalarKernel + <<<256, numThreads, shmemSize, *context->getCudaStream()>>>( + input->specialBuffer(), input->specialShapeInfo(), + indices->specialBuffer(), indices->specialShapeInfo(), dOutBuffers, + dOutShapes, outSize); + } - auto z = bz + zTadOffsets[idx]; - auto x = reinterpret_cast(vx[e]) + xTadOffsets[e][i]; + pm.synchronize(); +} - for (int f = threadIdx.x; f < zLength; f += blockDim.x) { - z[shape::getIndexOffset(f, zTadShapeInfo)] = x[shape::getIndexOffset(f, xShapeInfo)]; - } +template +static _CUDA_G void dynamicStitchScalarKernel( + void **vx, Nd4jLong **xShapeInfos, void **vindices, Nd4jLong **iShapeInfos, + int inputSize, void *vz, const Nd4jLong *zShapeInfo, Nd4jLong zLength) { + auto z = reinterpret_cast(vz); - __syncthreads(); - } - } - } + for (int e = blockIdx.x; e < inputSize; e += gridDim.x) { + auto x = reinterpret_cast(vx[e]); + auto indices = reinterpret_cast(vindices[e]); - template - static int _dynamicStitchFunctor(sd::LaunchContext * context, std::vector const& inputs, std::vector const& indices, NDArray* output){ + auto xShapeInfo = xShapeInfos[e]; + auto iShapeInfo = iShapeInfos[e]; - int inputSize = inputs.size(); + auto iLength = shape::length(iShapeInfo); - PointersManager pm(context, "dynamicStitch"); + for (int i = threadIdx.x; i < iLength; i += blockDim.x) { + auto idx = indices[shape::getIndexOffset(i, iShapeInfo)]; + if (idx >= 0 && idx < zLength) + z[shape::getIndexOffset(idx, zShapeInfo)] = + x[shape::getIndexOffset(i, xShapeInfo)]; + } + } +} - if (output->isVector()) { - std::vector inputBuffers(inputSize); - std::vector inputShapes(inputSize); - std::vector indicesBuffers(inputSize); - std::vector indicesShapes(inputSize); +template +static _CUDA_G void dynamicStitchTadKernel( + void **vx, Nd4jLong **xTadShapeInfos, Nd4jLong **xTadOffsets, + void **vindices, Nd4jLong **iShapeInfos, int inputSize, void *vz, + const Nd4jLong *zTadShapeInfo, const Nd4jLong *zTadOffsets) { + auto bz = reinterpret_cast(vz); - for (int e = 0; e < inputSize; e++) { - inputBuffers[e] = inputs.at(e)->specialBuffer(); - indicesBuffers[e] = indices.at(e)->specialBuffer(); + for (int e = blockIdx.x; e < inputSize; e += gridDim.x) { + auto indices = reinterpret_cast(vindices[e]); + auto iShapeInfo = iShapeInfos[e]; - inputShapes[e] = inputs.at(e)->specialShapeInfo(); - indicesShapes[e] = indices.at(e)->specialShapeInfo(); - } + if (shape::isEmpty(iShapeInfo)) continue; - // copying pointers to buffers to device - auto dInputBuffers = reinterpret_cast(pm.replicatePointer(inputBuffers.data(), inputSize * sizeof(void *))); - auto dIndicesBuffers = reinterpret_cast(pm.replicatePointer(indicesBuffers.data(), inputSize * sizeof(void *))); - auto dInputShapes = reinterpret_cast(pm.replicatePointer(inputShapes.data(), inputSize * sizeof(Nd4jLong *))); - auto dIndicesShapes = reinterpret_cast(pm.replicatePointer(indicesShapes.data(), inputSize * sizeof(Nd4jLong *))); + auto iLength = shape::length(iShapeInfo); + auto zLength = shape::length(zTadShapeInfo); - dynamicStitchScalarKernel<<<256, 256, 1024, *context->getCudaStream()>>>(dInputBuffers, dInputShapes, dIndicesBuffers, dIndicesShapes, inputSize, output->specialBuffer(), output->specialShapeInfo(), output->lengthOf()); - } else { - std::vector restDims(output->rankOf() - 1); - for (int i = restDims.size(); i > 0; i--) - restDims[restDims.size() - i] = output->rankOf() - i; + auto xShapeInfo = xTadShapeInfos[e]; + auto xLength = shape::length(xShapeInfo); - auto packZ = ConstantTadHelper::getInstance().tadForDimensions(output->shapeInfo(), restDims); + for (int i = 0; i < iLength; i++) { + auto idx = indices[shape::getIndexOffset(i, iShapeInfo)]; - std::vector inputBuffers(inputSize); - std::vector inputTadShapes(inputSize); - std::vector inputTadOffsets(inputSize); + auto z = bz + zTadOffsets[idx]; + auto x = reinterpret_cast(vx[e]) + xTadOffsets[e][i]; - std::vector indicesBuffers(inputSize); - std::vector indicesShapes(inputSize); + for (int f = threadIdx.x; f < zLength; f += blockDim.x) { + z[shape::getIndexOffset(f, zTadShapeInfo)] = + x[shape::getIndexOffset(f, xShapeInfo)]; + } - for (int e = 0; e < inputSize; e++) { - std::vector sourceDims(inputs[e]->rankOf() - indices[e]->rankOf()); - for (int i = sourceDims.size(); i > 0; i--) - sourceDims[sourceDims.size() - i] = inputs[e]->rankOf() - i; + __syncthreads(); + } + } +} - auto packX = ConstantTadHelper::getInstance().tadForDimensions(inputs[e]->shapeInfo(), sourceDims); +template +static int _dynamicStitchFunctor(sd::LaunchContext *context, + std::vector const &inputs, + std::vector const &indices, + NDArray *output) { + int inputSize = inputs.size(); - indicesBuffers[e] = indices[e]->specialBuffer(); - indicesShapes[e] = indices[e]->specialShapeInfo(); + PointersManager pm(context, "dynamicStitch"); - inputBuffers[e] = inputs[e]->specialBuffer(); - inputTadShapes[e] = packX.platformShapeInfo(); - inputTadOffsets[e] = packX.platformOffsets(); - } + if (output->isVector()) { + std::vector inputBuffers(inputSize); + std::vector inputShapes(inputSize); + std::vector indicesBuffers(inputSize); + std::vector indicesShapes(inputSize); - // copying pointers to buffers to device - auto dInputBuffers = reinterpret_cast(pm.replicatePointer(inputBuffers.data(), inputSize * sizeof(void *))); - auto dInputTadShapes = reinterpret_cast(pm.replicatePointer(inputTadShapes.data(), inputSize * sizeof(Nd4jLong *))); - auto dInputTadOffsets = reinterpret_cast(pm.replicatePointer(inputTadOffsets.data(), inputSize * sizeof(Nd4jLong *))); + for (int e = 0; e < inputSize; e++) { + inputBuffers[e] = inputs.at(e)->specialBuffer(); + indicesBuffers[e] = indices.at(e)->specialBuffer(); - auto dIndicesBuffers = reinterpret_cast(pm.replicatePointer(indicesBuffers.data(), inputSize * sizeof(void *))); - auto dIndicesShapes = reinterpret_cast(pm.replicatePointer(indicesShapes.data(), inputSize * sizeof(Nd4jLong *))); + inputShapes[e] = inputs.at(e)->specialShapeInfo(); + indicesShapes[e] = indices.at(e)->specialShapeInfo(); + } - dynamicStitchTadKernel<<<256, 256, 1024, *context->getCudaStream()>>>(dInputBuffers, dInputTadShapes, dInputTadOffsets, dIndicesBuffers, dIndicesShapes, inputSize, output->specialBuffer(), packZ.platformShapeInfo(), packZ.platformOffsets()); - } + // copying pointers to buffers to device + auto dInputBuffers = reinterpret_cast( + pm.replicatePointer(inputBuffers.data(), inputSize * sizeof(void *))); + auto dIndicesBuffers = reinterpret_cast( + pm.replicatePointer(indicesBuffers.data(), inputSize * sizeof(void *))); + auto dInputShapes = reinterpret_cast(pm.replicatePointer( + inputShapes.data(), inputSize * sizeof(Nd4jLong *))); + auto dIndicesShapes = reinterpret_cast(pm.replicatePointer( + indicesShapes.data(), inputSize * sizeof(Nd4jLong *))); + + dynamicStitchScalarKernel + <<<256, 256, 1024, *context->getCudaStream()>>>( + dInputBuffers, dInputShapes, dIndicesBuffers, dIndicesShapes, + inputSize, output->specialBuffer(), output->specialShapeInfo(), + output->lengthOf()); + } else { + std::vector restDims(output->rankOf() - 1); + for (int i = restDims.size(); i > 0; i--) + restDims[restDims.size() - i] = output->rankOf() - i; + + auto packZ = ConstantTadHelper::getInstance().tadForDimensions( + output->shapeInfo(), restDims); + + std::vector inputBuffers(inputSize); + std::vector inputTadShapes(inputSize); + std::vector inputTadOffsets(inputSize); + + std::vector indicesBuffers(inputSize); + std::vector indicesShapes(inputSize); + + for (int e = 0; e < inputSize; e++) { + std::vector sourceDims(inputs[e]->rankOf() - indices[e]->rankOf()); + for (int i = sourceDims.size(); i > 0; i--) + sourceDims[sourceDims.size() - i] = inputs[e]->rankOf() - i; + + auto packX = ConstantTadHelper::getInstance().tadForDimensions( + inputs[e]->shapeInfo(), sourceDims); + + indicesBuffers[e] = indices[e]->specialBuffer(); + indicesShapes[e] = indices[e]->specialShapeInfo(); + + inputBuffers[e] = inputs[e]->specialBuffer(); + inputTadShapes[e] = packX.platformShapeInfo(); + inputTadOffsets[e] = packX.platformOffsets(); + } - pm.synchronize(); + // copying pointers to buffers to device + auto dInputBuffers = reinterpret_cast( + pm.replicatePointer(inputBuffers.data(), inputSize * sizeof(void *))); + auto dInputTadShapes = reinterpret_cast(pm.replicatePointer( + inputTadShapes.data(), inputSize * sizeof(Nd4jLong *))); + auto dInputTadOffsets = reinterpret_cast(pm.replicatePointer( + inputTadOffsets.data(), inputSize * sizeof(Nd4jLong *))); - return Status::OK(); - } + auto dIndicesBuffers = reinterpret_cast( + pm.replicatePointer(indicesBuffers.data(), inputSize * sizeof(void *))); + auto dIndicesShapes = reinterpret_cast(pm.replicatePointer( + indicesShapes.data(), inputSize * sizeof(Nd4jLong *))); - template - static void _dynamicPartitionFunctorBP(NDArray const* input, NDArray const* indices, std::vector const& inputGradientList, std::vector& outputList) { + dynamicStitchTadKernel<<<256, 256, 1024, *context->getCudaStream()>>>( + dInputBuffers, dInputTadShapes, dInputTadOffsets, dIndicesBuffers, + dIndicesShapes, inputSize, output->specialBuffer(), + packZ.platformShapeInfo(), packZ.platformOffsets()); + } - } + pm.synchronize(); - void dynamicPartitionFunctor(sd::LaunchContext * context, NDArray const* input, NDArray const* indices, std::vector& outputList) { - auto xType = input->dataType(); - auto yType = indices->dataType(); + return Status::OK(); +} - NDArray::prepareSpecialUse({}, {indices, input}); +template +static void _dynamicPartitionFunctorBP( + NDArray const *input, NDArray const *indices, + std::vector const &inputGradientList, + std::vector &outputList) {} - BUILD_DOUBLE_SELECTOR(xType, yType, _dynamicPartitionFunctor, (context, input, indices, outputList), NUMERIC_TYPES, INDEXING_TYPES); +void dynamicPartitionFunctor(sd::LaunchContext *context, NDArray const *input, + NDArray const *indices, + std::vector &outputList) { + auto xType = input->dataType(); + auto yType = indices->dataType(); - NDArray::registerSpecialUse({}, {indices, input}); + NDArray::prepareSpecialUse({}, {indices, input}); - // TODO: it would be nice to have NDArray::registerSpecialUse signature that accepts something else beyond initializer_list - for (auto v:outputList) { - v->tickWriteDevice(); - } - } + BUILD_DOUBLE_SELECTOR(xType, yType, _dynamicPartitionFunctor, + (context, input, indices, outputList), NUMERIC_TYPES, + INDEXING_TYPES); - template - static int _dynamicStitchFunctorBP(std::vector const& inputs, std::vector const& indices, NDArray const* gradInput, std::vector& outputList){ - throw std::runtime_error("Not umplemented yet"); - } + NDArray::registerSpecialUse({}, {indices, input}); - int dynamicStitchFunctor(sd::LaunchContext * context, std::vector const& inputs, std::vector const& indices, NDArray* output){ - auto xType = inputs.at(0)->dataType(); - auto yType = indices.at(0)->dataType(); + // TODO: it would be nice to have NDArray::registerSpecialUse signature that + // accepts something else beyond initializer_list + for (auto v : outputList) { + v->tickWriteDevice(); + } +} - for (auto v:indices) { - v->syncToDevice(); - v->tickReadDevice(); - } +template +static int _dynamicStitchFunctorBP(std::vector const &inputs, + std::vector const &indices, + NDArray const *gradInput, + std::vector &outputList) { + throw std::runtime_error("Not umplemented yet"); +} - for (auto v:inputs) { - v->syncToDevice(); - v->tickReadDevice(); - } +int dynamicStitchFunctor(sd::LaunchContext *context, + std::vector const &inputs, + std::vector const &indices, + NDArray *output) { + auto xType = inputs.at(0)->dataType(); + auto yType = indices.at(0)->dataType(); - NDArray::prepareSpecialUse({output}, {}); + for (auto v : indices) { + v->syncToDevice(); + v->tickReadDevice(); + } + for (auto v : inputs) { + v->syncToDevice(); + v->tickReadDevice(); + } - BUILD_DOUBLE_SELECTOR(xType, yType, _dynamicStitchFunctor, (context, inputs, indices, output), NUMERIC_TYPES, INDEXING_TYPES); + NDArray::prepareSpecialUse({output}, {}); - NDArray::registerSpecialUse({output}, {}); + BUILD_DOUBLE_SELECTOR(xType, yType, _dynamicStitchFunctor, + (context, inputs, indices, output), NUMERIC_TYPES, + INDEXING_TYPES); - return Status::OK(); - } + NDArray::registerSpecialUse({output}, {}); - int dynamicStitchFunctorBP(sd::LaunchContext * context, std::vector const& inputs, std::vector const& indices, NDArray const* gradInput, std::vector& outputList) { - auto xType = inputs.at(0)->dataType(); + return Status::OK(); +} - BUILD_SINGLE_SELECTOR(xType, return _dynamicStitchFunctorBP, (inputs, indices, gradInput, outputList), NUMERIC_TYPES); - } +int dynamicStitchFunctorBP(sd::LaunchContext *context, + std::vector const &inputs, + std::vector const &indices, + NDArray const *gradInput, + std::vector &outputList) { + auto xType = inputs.at(0)->dataType(); - void dynamicPartitionFunctorBP(sd::LaunchContext * context, NDArray const* input, NDArray const* indices, std::vector const& inputGradientList, std::vector& outputList) { - auto xType = input->dataType(); + BUILD_SINGLE_SELECTOR(xType, return _dynamicStitchFunctorBP, + (inputs, indices, gradInput, outputList), + NUMERIC_TYPES); +} - BUILD_SINGLE_SELECTOR(xType, _dynamicPartitionFunctorBP, (input, indices, inputGradientList, outputList), NUMERIC_TYPES); - } +void dynamicPartitionFunctorBP(sd::LaunchContext *context, NDArray const *input, + NDArray const *indices, + std::vector const &inputGradientList, + std::vector &outputList) { + auto xType = input->dataType(); - } - } + BUILD_SINGLE_SELECTOR(xType, _dynamicPartitionFunctorBP, + (input, indices, inputGradientList, outputList), + NUMERIC_TYPES); } +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/extract_patches.cu b/libnd4j/include/ops/declarable/helpers/cuda/extract_patches.cu index e1c50687982a..a3fb65780a96 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/extract_patches.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/extract_patches.cu @@ -19,11 +19,12 @@ // @author sgazeos@gmail.com // -#include +#include #include #include +#include + #include -#include namespace sd { namespace ops { @@ -46,101 +47,121 @@ namespace helpers { // - outTadShape - output TAD shape // - outputOffsets - output TAD offsets // - template - static __global__ void globalExtractPatchesKernel(bool theSame, int batchCount, int sizeRow, int sizeCol, int rowDim, int colDim, int outRowDim, int outColDim, int strideRow, int strideCol, int rateRow, int rateCol, int rowCast, int colCast, int lastDim, const T* input, const Nd4jLong* patchShape, const Nd4jLong* inputOffsets, T* output, const Nd4jLong* outTadShape, const Nd4jLong* outputOffsets) { - - auto start = threadIdx.x + blockIdx.x * blockDim.x; - - auto step = blockDim.x * gridDim.x; - // batch input by 3 last dims and extrapole input onto output with outColDim/outRowDim - for (Nd4jLong batch = start; batch < batchCount; batch += step) { - auto patch = input + inputOffsets[batch];// listOfMatricies->at(batch); - auto outMatrix = output + outputOffsets[batch]; //listOfOutputs->at(batch); - - for (Nd4jLong i = 0; i < outRowDim; i++) { - for (Nd4jLong j = 0; j < outColDim; j++) { - Nd4jLong pos = 0; - auto rowStart = i * strideRow - (theSame?rowCast:0); - auto colStart = j * strideCol - (theSame?colCast:0); - auto rowEnd = rowStart + sizeRow * rateRow; - auto colEnd = colStart + sizeCol * rateCol; - if (!theSame) { - rowEnd = math::nd4j_min(rowStart + sizeRow * rateRow, Nd4jLong (rowDim)); - colEnd = math::nd4j_min(colStart + sizeCol * rateCol, Nd4jLong (colDim)); - } - - for (auto row = rowStart; row < rowEnd; row += rateRow) { - for (auto col = colStart; col < colEnd; col += rateCol) { - for (auto pixel = 0; pixel < lastDim; pixel++) { - Nd4jLong zPos[] = {i, j, pos}; - Nd4jLong xPos[] = {row, col, pixel}; - bool setUp = - (theSame && row >= 0 && col >= 0 && row < rowDim && col < colDim) || (!theSame); - - if (setUp) { // VALID or SAME cases - outMatrix[shape::getOffset(outTadShape, zPos)] = patch[shape::getOffset(patchShape, xPos)]; - } - pos++; - } - } - } - } - } +template +static __global__ void globalExtractPatchesKernel( + bool theSame, int batchCount, int sizeRow, int sizeCol, int rowDim, + int colDim, int outRowDim, int outColDim, int strideRow, int strideCol, + int rateRow, int rateCol, int rowCast, int colCast, int lastDim, + const T* input, const Nd4jLong* patchShape, const Nd4jLong* inputOffsets, + T* output, const Nd4jLong* outTadShape, const Nd4jLong* outputOffsets) { + auto start = threadIdx.x + blockIdx.x * blockDim.x; + + auto step = blockDim.x * gridDim.x; + // batch input by 3 last dims and extrapole input onto output with + // outColDim/outRowDim + for (Nd4jLong batch = start; batch < batchCount; batch += step) { + auto patch = input + inputOffsets[batch]; // listOfMatricies->at(batch); + auto outMatrix = + output + outputOffsets[batch]; // listOfOutputs->at(batch); + + for (Nd4jLong i = 0; i < outRowDim; i++) { + for (Nd4jLong j = 0; j < outColDim; j++) { + Nd4jLong pos = 0; + auto rowStart = i * strideRow - (theSame ? rowCast : 0); + auto colStart = j * strideCol - (theSame ? colCast : 0); + auto rowEnd = rowStart + sizeRow * rateRow; + auto colEnd = colStart + sizeCol * rateCol; + if (!theSame) { + rowEnd = + math::nd4j_min(rowStart + sizeRow * rateRow, Nd4jLong(rowDim)); + colEnd = + math::nd4j_min(colStart + sizeCol * rateCol, Nd4jLong(colDim)); } + for (auto row = rowStart; row < rowEnd; row += rateRow) { + for (auto col = colStart; col < colEnd; col += rateCol) { + for (auto pixel = 0; pixel < lastDim; pixel++) { + Nd4jLong zPos[] = {i, j, pos}; + Nd4jLong xPos[] = {row, col, pixel}; + bool setUp = (theSame && row >= 0 && col >= 0 && row < rowDim && + col < colDim) || + (!theSame); + + if (setUp) { // VALID or SAME cases + outMatrix[shape::getOffset(outTadShape, zPos)] = + patch[shape::getOffset(patchShape, xPos)]; + } + pos++; + } + } + } + } } + } +} //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - template - static void _extractPatches(sd::LaunchContext * context, NDArray* images, NDArray* output, int sizeRow, int sizeCol, int strideRow, int strideCol, int rateRow, int rateCol, bool theSame){ - NDArray::prepareSpecialUse({output}, {images}); - std::vector restDims({1, 2, 3}); // the first and the last dims - // 3D matricies - 2D matricies of vectors (if last dim is greater than 1) - //int e = 0; - const int ksizeRowsEffective = sizeRow + (sizeRow - 1) * (rateRow - 1); - const int ksizeColsEffective = sizeCol + (sizeCol - 1) * (rateCol - 1); - const int ksize = ksizeRowsEffective * ksizeColsEffective; - Nd4jLong lastDim = images->sizeAt(3); - Nd4jLong outLastDim = output->sizeAt(3); - Nd4jLong rowDim = images->sizeAt(1); - Nd4jLong colDim = images->sizeAt(2); - Nd4jLong outRowDim = output->sizeAt(1); - Nd4jLong outColDim = output->sizeAt(2); - auto rowCast = 1; - auto colCast = 1; - // validate shifts - if (sizeRow * rateRow < 3) - rowCast = 0; - if (sizeCol * rateCol < 3) - colCast = 0; - - auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(images->shapeInfo(), restDims.data(), restDims.size()); - auto packZ = sd::ConstantTadHelper::getInstance().tadForDimensions(output->shapeInfo(), restDims.data(), restDims.size()); - int batchCount = packX.numberOfTads(); - - PointersManager manager(context, "helpers::extractPatches"); - - auto stream = context->getCudaStream(); - auto imagesBuffer = reinterpret_cast(images->specialBuffer()); - auto outputBuffer = reinterpret_cast(output->specialBuffer()); - - globalExtractPatchesKernel<<<128, 128, 1024, *stream>>>(theSame, batchCount, sizeRow, sizeCol, - rowDim, colDim, outRowDim, outColDim, strideRow, strideCol, rateRow, rateCol, rowCast, colCast, lastDim, - imagesBuffer, packX.specialShapeInfo(), packX.specialOffsets(), outputBuffer, packZ.specialShapeInfo(), - packZ.specialOffsets()); - - manager.synchronize(); - NDArray::registerSpecialUse({output}, {images}); - } - BUILD_SINGLE_TEMPLATE(template void _extractPatches, (sd::LaunchContext * context, NDArray* input, NDArray* output, int sizeRow, int sizeCol, int stradeRow, int stradeCol, int rateRow, int rateCol, bool theSame), LIBND4J_TYPES); - - - - void extractPatches(sd::LaunchContext * context, NDArray* images, NDArray* output, int sizeRow, int sizeCol, int stradeRow, int stradeCol, int rateRow, int rateCol, bool theSame){ - auto xType = images->dataType(); - - BUILD_SINGLE_SELECTOR(xType, _extractPatches, (context, images, output, sizeRow, sizeCol, stradeRow, stradeCol, rateRow, rateCol, theSame), LIBND4J_TYPES); - } +template +static void _extractPatches(sd::LaunchContext* context, NDArray* images, + NDArray* output, int sizeRow, int sizeCol, + int strideRow, int strideCol, int rateRow, + int rateCol, bool theSame) { + NDArray::prepareSpecialUse({output}, {images}); + std::vector restDims({1, 2, 3}); // the first and the last dims + // 3D matricies - 2D matricies of vectors (if last dim is greater than 1) + // int e = 0; + const int ksizeRowsEffective = sizeRow + (sizeRow - 1) * (rateRow - 1); + const int ksizeColsEffective = sizeCol + (sizeCol - 1) * (rateCol - 1); + const int ksize = ksizeRowsEffective * ksizeColsEffective; + Nd4jLong lastDim = images->sizeAt(3); + Nd4jLong outLastDim = output->sizeAt(3); + Nd4jLong rowDim = images->sizeAt(1); + Nd4jLong colDim = images->sizeAt(2); + Nd4jLong outRowDim = output->sizeAt(1); + Nd4jLong outColDim = output->sizeAt(2); + auto rowCast = 1; + auto colCast = 1; + // validate shifts + if (sizeRow * rateRow < 3) rowCast = 0; + if (sizeCol * rateCol < 3) colCast = 0; + + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions( + images->shapeInfo(), restDims.data(), restDims.size()); + auto packZ = sd::ConstantTadHelper::getInstance().tadForDimensions( + output->shapeInfo(), restDims.data(), restDims.size()); + int batchCount = packX.numberOfTads(); + + PointersManager manager(context, "helpers::extractPatches"); + + auto stream = context->getCudaStream(); + auto imagesBuffer = reinterpret_cast(images->specialBuffer()); + auto outputBuffer = reinterpret_cast(output->specialBuffer()); + + globalExtractPatchesKernel<<<128, 128, 1024, *stream>>>( + theSame, batchCount, sizeRow, sizeCol, rowDim, colDim, outRowDim, + outColDim, strideRow, strideCol, rateRow, rateCol, rowCast, colCast, + lastDim, imagesBuffer, packX.specialShapeInfo(), packX.specialOffsets(), + outputBuffer, packZ.specialShapeInfo(), packZ.specialOffsets()); + + manager.synchronize(); + NDArray::registerSpecialUse({output}, {images}); } +BUILD_SINGLE_TEMPLATE(template void _extractPatches, + (sd::LaunchContext * context, NDArray* input, + NDArray* output, int sizeRow, int sizeCol, int stradeRow, + int stradeCol, int rateRow, int rateCol, bool theSame), + LIBND4J_TYPES); + +void extractPatches(sd::LaunchContext* context, NDArray* images, + NDArray* output, int sizeRow, int sizeCol, int stradeRow, + int stradeCol, int rateRow, int rateCol, bool theSame) { + auto xType = images->dataType(); + + BUILD_SINGLE_SELECTOR(xType, _extractPatches, + (context, images, output, sizeRow, sizeCol, stradeRow, + stradeCol, rateRow, rateCol, theSame), + LIBND4J_TYPES); } -} \ No newline at end of file +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/fake_quantization.cu b/libnd4j/include/ops/declarable/helpers/cuda/fake_quantization.cu index 7fcd71dba4c0..2db249d5a014 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/fake_quantization.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/fake_quantization.cu @@ -18,8 +18,8 @@ // @author sgazeos@gmail.com // -#include #include +#include namespace sd { namespace ops { @@ -33,104 +33,116 @@ namespace helpers { // narrowed - shrink is true // output - output tensor // - template - static __host__ __device__ void - nudge(T min, T max, int quantMin, int quantMax, T* scale, T* nudgedMin, T* nudgedMax) { - T quantMaxF = static_cast(quantMax); - T quantMinF = static_cast(quantMin); - *scale = (max - min) / (quantMaxF - quantMinF); - auto zeroPointFromMin = quantMinF - min / *scale; - uint16_t const nudgedZeroPoint = [zeroPointFromMin, quantMin, quantMax, quantMaxF, quantMinF] { - if (zeroPointFromMin < quantMinF) { - return static_cast(quantMin); - } - if (zeroPointFromMin > quantMaxF) { - return static_cast(quantMax); - } - return sd::math::nd4j_round(zeroPointFromMin); - }(); - *nudgedMax = (quantMaxF - static_cast(nudgedZeroPoint)) * (*scale); - *nudgedMin = (quantMinF - static_cast(nudgedZeroPoint)) * (*scale); +template +static __host__ __device__ void nudge(T min, T max, int quantMin, int quantMax, + T* scale, T* nudgedMin, T* nudgedMax) { + T quantMaxF = static_cast(quantMax); + T quantMinF = static_cast(quantMin); + *scale = (max - min) / (quantMaxF - quantMinF); + auto zeroPointFromMin = quantMinF - min / *scale; + uint16_t const nudgedZeroPoint = [zeroPointFromMin, quantMin, quantMax, + quantMaxF, quantMinF] { + if (zeroPointFromMin < quantMinF) { + return static_cast(quantMin); } - - template - void fakeQuantWithMinMaxVars_(NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output) { - int lowIntBound = narrowed?1:0; - int upperIntBound = (1 << numBits) - 1; - min->syncToHost(); // these are scalars, so nothing much happened - max->syncToHost(); - T scale, nudgedMin, nudgedMax; - nudge(min->t(0), max->t(0), lowIntBound, upperIntBound, &scale, &nudgedMin, &nudgedMax); - - auto wiseMinMaxAndSoOn = LAMBDA_T(x, nudgedMin, nudgedMax, scale) { - T val = x; - if (x < nudgedMin) { - val = nudgedMin; - } - else if (x > nudgedMax) { - val = nudgedMax; - } - else - val = x; - return (math::nd4j_floor((val - nudgedMin) / scale + T(0.5)) * scale + nudgedMin); - }; - - input->applyLambda(wiseMinMaxAndSoOn, *output); + if (zeroPointFromMin > quantMaxF) { + return static_cast(quantMax); } + return sd::math::nd4j_round(zeroPointFromMin); + }(); + *nudgedMax = (quantMaxF - static_cast(nudgedZeroPoint)) * (*scale); + *nudgedMin = (quantMinF - static_cast(nudgedZeroPoint)) * (*scale); +} - template - static __global__ void fakeQuantWithMinMaxKernel(const T* input, const Nd4jLong* inputShape, - T* min, T* max, - int lowIntBound, int upperIntBound, Nd4jLong channels, - T* output, const Nd4jLong* outputShape, - Nd4jLong length) { - __shared__ int block; - if (threadIdx.x == 0) { - block = length / channels; // to loop with last dimension as block - } - __syncthreads(); +template +void fakeQuantWithMinMaxVars_(NDArray* input, NDArray* min, NDArray* max, + int numBits, bool narrowed, NDArray* output) { + int lowIntBound = narrowed ? 1 : 0; + int upperIntBound = (1 << numBits) - 1; + min->syncToHost(); // these are scalars, so nothing much happened + max->syncToHost(); + T scale, nudgedMin, nudgedMax; + nudge(min->t(0), max->t(0), lowIntBound, upperIntBound, &scale, + &nudgedMin, &nudgedMax); - for (auto i = blockIdx.x; i < (int)channels; i += gridDim.x) { - T scale, nudgedMin, nudgedMax; - nudge(min[i], max[i], lowIntBound, upperIntBound, &scale, &nudgedMin, &nudgedMax); - // loop over blocks to quantization between nudged min and max - for (auto b = threadIdx.x; b < block; b += blockDim.x) { - T val = input[shape::getIndexOffset(b * channels + i, inputShape)]; - if (val < nudgedMin) { - val = nudgedMin; - } else if (val > nudgedMax) { - val = nudgedMax; - } - output[shape::getIndexOffset(b * channels + i, outputShape)] = - (math::nd4j_floor((val - nudgedMin) / scale + T(0.5f)) * scale + nudgedMin); - }; - } - } + auto wiseMinMaxAndSoOn = LAMBDA_T(x, nudgedMin, nudgedMax, scale) { + T val = x; + if (x < nudgedMin) { + val = nudgedMin; + } else if (x > nudgedMax) { + val = nudgedMax; + } else + val = x; + return (math::nd4j_floor((val - nudgedMin) / scale + T(0.5)) * scale + + nudgedMin); + }; - template - void fakeQuantWithMinMaxVarsPerChannel_(LaunchContext* context, NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output) { - int lowIntBound = narrowed?1:0; - int upperIntBound = (1 << numBits) - 1; - auto channels = min->lengthOf(); - auto length = input->lengthOf(); - NDArray::prepareSpecialUse({output}, {min, max, input}); - auto stream = context->getCudaStream(); - T* inputBuf = input->dataBuffer()->specialAsT(); - T* outputBuf = output->dataBuffer()->specialAsT(); - T* minBuf = min->dataBuffer()->specialAsT(); - T* maxBuf = max->dataBuffer()->specialAsT(); - fakeQuantWithMinMaxKernel<<<128, 256, 256, *stream>>>(inputBuf, input->specialShapeInfo(), - minBuf, maxBuf, lowIntBound, upperIntBound, channels, outputBuf, output->specialShapeInfo(), length); - NDArray::registerSpecialUse({output}, {min, max, input}); + input->applyLambda(wiseMinMaxAndSoOn, *output); +} - } +template +static __global__ void fakeQuantWithMinMaxKernel( + const T* input, const Nd4jLong* inputShape, T* min, T* max, int lowIntBound, + int upperIntBound, Nd4jLong channels, T* output, + const Nd4jLong* outputShape, Nd4jLong length) { + __shared__ int block; + if (threadIdx.x == 0) { + block = length / channels; // to loop with last dimension as block + } + __syncthreads(); - void fakeQuantWithMinMaxVars(NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output) { - BUILD_SINGLE_SELECTOR(input->dataType(), fakeQuantWithMinMaxVars_, (input, min, max, numBits, narrowed, output), FLOAT_TYPES); - } - void fakeQuantWithMinMaxVarsPerChannel(LaunchContext* context, NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output) { - BUILD_SINGLE_SELECTOR(input->dataType(), fakeQuantWithMinMaxVarsPerChannel_, (context, input, min, max, numBits, narrowed, output), FLOAT_TYPES); - } + for (auto i = blockIdx.x; i < (int)channels; i += gridDim.x) { + T scale, nudgedMin, nudgedMax; + nudge(min[i], max[i], lowIntBound, upperIntBound, &scale, &nudgedMin, + &nudgedMax); + // loop over blocks to quantization between nudged min and max + for (auto b = threadIdx.x; b < block; b += blockDim.x) { + T val = input[shape::getIndexOffset(b * channels + i, inputShape)]; + if (val < nudgedMin) { + val = nudgedMin; + } else if (val > nudgedMax) { + val = nudgedMax; + } + output[shape::getIndexOffset(b * channels + i, outputShape)] = + (math::nd4j_floor((val - nudgedMin) / scale + T(0.5f)) * scale + + nudgedMin); + }; + } } + +template +void fakeQuantWithMinMaxVarsPerChannel_(LaunchContext* context, NDArray* input, + NDArray* min, NDArray* max, int numBits, + bool narrowed, NDArray* output) { + int lowIntBound = narrowed ? 1 : 0; + int upperIntBound = (1 << numBits) - 1; + auto channels = min->lengthOf(); + auto length = input->lengthOf(); + NDArray::prepareSpecialUse({output}, {min, max, input}); + auto stream = context->getCudaStream(); + T* inputBuf = input->dataBuffer()->specialAsT(); + T* outputBuf = output->dataBuffer()->specialAsT(); + T* minBuf = min->dataBuffer()->specialAsT(); + T* maxBuf = max->dataBuffer()->specialAsT(); + fakeQuantWithMinMaxKernel<<<128, 256, 256, *stream>>>( + inputBuf, input->specialShapeInfo(), minBuf, maxBuf, lowIntBound, + upperIntBound, channels, outputBuf, output->specialShapeInfo(), length); + NDArray::registerSpecialUse({output}, {min, max, input}); +} + +void fakeQuantWithMinMaxVars(NDArray* input, NDArray* min, NDArray* max, + int numBits, bool narrowed, NDArray* output) { + BUILD_SINGLE_SELECTOR(input->dataType(), fakeQuantWithMinMaxVars_, + (input, min, max, numBits, narrowed, output), + FLOAT_TYPES); } +void fakeQuantWithMinMaxVarsPerChannel(LaunchContext* context, NDArray* input, + NDArray* min, NDArray* max, int numBits, + bool narrowed, NDArray* output) { + BUILD_SINGLE_SELECTOR(input->dataType(), fakeQuantWithMinMaxVarsPerChannel_, + (context, input, min, max, numBits, narrowed, output), + FLOAT_TYPES); } +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/flatten.cu b/libnd4j/include/ops/declarable/helpers/cuda/flatten.cu index aa2ff8297ca8..5dbaba63a7c2 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/flatten.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/flatten.cu @@ -18,68 +18,75 @@ // @author raver119@gmail.com // -#include #include +#include namespace sd { - namespace ops { - namespace helpers { - template - void _CUDA_G flattenKernel(void **xBuffers, Nd4jLong **xShapeInfos, Nd4jLong *offsets, Nd4jLong numInputs, void *zBuffer, const Nd4jLong *zShapeInfo, char order) { - - int xCoord[MAX_RANK]; - - // each block of threads works on 1 input array - for (Nd4jLong e = blockIdx.x; e < numInputs; e += gridDim.x) { - auto z = reinterpret_cast(zBuffer) + offsets[e]; - - auto xBuffer = reinterpret_cast(xBuffers[e]); - auto xShapeInfo = xShapeInfos[e]; - auto xLength = shape::length(xShapeInfo); - - // each element of this input array has own place within common output array - for (uint i = threadIdx.x; i < xLength; i += blockDim.x) - z[i] = xBuffer[getIndexOffsetOrdered(i, xShapeInfo, order)]; - } - } - - template - void flatten_(sd::LaunchContext *context, std::vector &inputs, NDArray *output, char order) { - PointersManager pm(context, "flatten"); - - std::vector hdBuffers(inputs.size()); - std::vector hOffsets(inputs.size()); - std::vector hdShapes(inputs.size()); - Nd4jLong cOffset = 0; - - // calculating offsets in output - for (int e = 0; e < inputs.size(); e++) { - hOffsets[e] = cOffset; - cOffset += inputs[e]->lengthOf(); - - hdBuffers[e] = inputs[e]->specialBuffer(); - hdShapes[e] = inputs[e]->specialShapeInfo(); - } - - // copying pointers to device - auto dBuffers = (void **) pm.replicatePointer(hdBuffers.data(), inputs.size() * sizeof(void*)); - auto dShapes = (Nd4jLong **)pm.replicatePointer(hdShapes.data(), inputs.size() * sizeof(Nd4jLong*)); - auto dOffsets = (Nd4jLong *) pm.replicatePointer(hOffsets.data(), inputs.size() * sizeof(Nd4jLong)); - - - flattenKernel<<<256, 512, 8192, *context->getCudaStream()>>>(dBuffers, dShapes, dOffsets, inputs.size(), output->specialBuffer(), output->specialShapeInfo(), order); - - pm.synchronize(); - } - - void flatten(sd::LaunchContext *context, std::vector &inputs, NDArray *output, char order) { - // FIXME: we want NDArrayFactory::prepareSpecialUse here eventually - for (auto v:inputs) - v->syncToDevice(); - - BUILD_SINGLE_SELECTOR(output->dataType(), flatten_, (context, inputs, output, order), LIBND4J_TYPES); - NDArray::registerSpecialUse({output}, {}); - } - } - } -} \ No newline at end of file +namespace ops { +namespace helpers { +template +void _CUDA_G flattenKernel(void **xBuffers, Nd4jLong **xShapeInfos, + Nd4jLong *offsets, Nd4jLong numInputs, void *zBuffer, + const Nd4jLong *zShapeInfo, char order) { + int xCoord[MAX_RANK]; + + // each block of threads works on 1 input array + for (Nd4jLong e = blockIdx.x; e < numInputs; e += gridDim.x) { + auto z = reinterpret_cast(zBuffer) + offsets[e]; + + auto xBuffer = reinterpret_cast(xBuffers[e]); + auto xShapeInfo = xShapeInfos[e]; + auto xLength = shape::length(xShapeInfo); + + // each element of this input array has own place within common output array + for (uint i = threadIdx.x; i < xLength; i += blockDim.x) + z[i] = xBuffer[getIndexOffsetOrdered(i, xShapeInfo, order)]; + } +} + +template +void flatten_(sd::LaunchContext *context, std::vector &inputs, + NDArray *output, char order) { + PointersManager pm(context, "flatten"); + + std::vector hdBuffers(inputs.size()); + std::vector hOffsets(inputs.size()); + std::vector hdShapes(inputs.size()); + Nd4jLong cOffset = 0; + + // calculating offsets in output + for (int e = 0; e < inputs.size(); e++) { + hOffsets[e] = cOffset; + cOffset += inputs[e]->lengthOf(); + + hdBuffers[e] = inputs[e]->specialBuffer(); + hdShapes[e] = inputs[e]->specialShapeInfo(); + } + + // copying pointers to device + auto dBuffers = (void **)pm.replicatePointer(hdBuffers.data(), + inputs.size() * sizeof(void *)); + auto dShapes = (Nd4jLong **)pm.replicatePointer( + hdShapes.data(), inputs.size() * sizeof(Nd4jLong *)); + auto dOffsets = (Nd4jLong *)pm.replicatePointer( + hOffsets.data(), inputs.size() * sizeof(Nd4jLong)); + + flattenKernel<<<256, 512, 8192, *context->getCudaStream()>>>( + dBuffers, dShapes, dOffsets, inputs.size(), output->specialBuffer(), + output->specialShapeInfo(), order); + + pm.synchronize(); +} + +void flatten(sd::LaunchContext *context, std::vector &inputs, + NDArray *output, char order) { + // FIXME: we want NDArrayFactory::prepareSpecialUse here eventually + for (auto v : inputs) v->syncToDevice(); + + BUILD_SINGLE_SELECTOR(output->dataType(), flatten_, + (context, inputs, output, order), LIBND4J_TYPES); + NDArray::registerSpecialUse({output}, {}); +} +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/gather.cu b/libnd4j/include/ops/declarable/helpers/cuda/gather.cu index 26778aa63165..efd8fd119074 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/gather.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/gather.cu @@ -18,166 +18,183 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 07.03.2019 // - -#include -#include #include #include +#include -namespace sd { -namespace ops { -namespace helpers { - - template - __global__ static void gatherCudaLinearKernel(const void* vx, const Nd4jLong* xShapeInfo, const void* vy, const Nd4jLong* yShapeInfo, - void* vz, const Nd4jLong* zShapeInfo) { +#include +namespace sd { +namespace ops { +namespace helpers { - __shared__ const X* x; - __shared__ const Y* y; - __shared__ X* z; - __shared__ Nd4jLong xLen, yLen, zLen; +template +__global__ static void gatherCudaLinearKernel( + const void* vx, const Nd4jLong* xShapeInfo, const void* vy, + const Nd4jLong* yShapeInfo, void* vz, const Nd4jLong* zShapeInfo) { + __shared__ const X* x; + __shared__ const Y* y; + __shared__ X* z; + __shared__ Nd4jLong xLen, yLen, zLen; + + if (threadIdx.x == 0) { + x = reinterpret_cast(vx); + z = reinterpret_cast(vz); + y = reinterpret_cast(vy); + xLen = shape::length(xShapeInfo); + yLen = shape::length(yShapeInfo); + zLen = shape::length(zShapeInfo); + } + __syncthreads(); + // const Nd4jLong zLen = shape::length(zShapeInfo); + auto start = blockIdx.x * blockDim.x + threadIdx.x; + auto step = blockDim.x * gridDim.x; + + for (int j = start; j < zLen; j += step) { + auto zIndex = shape::getIndexOffset(j, zShapeInfo); + auto yIndex = shape::getIndexOffset(j, yShapeInfo); + auto xIndex = shape::getIndexOffset(y[yIndex], xShapeInfo); + z[zIndex] = x[xIndex]; + } +} +////////////////////////////////////////////////////////////////////// +template +__global__ static void gatherCuda(const int numOfSubArrs, const void* vx, + const Nd4jLong* xShapeInfo, + const Nd4jLong* xOffsets, const void* vy, + const Nd4jLong* yShapeInfo, void* vz, + const Nd4jLong* zShapeInfo, + const Nd4jLong* zOffsets) { + const Y* y = reinterpret_cast(vy); + __shared__ const X* x; + __shared__ X* z; + + const Nd4jLong len = shape::length(xShapeInfo); + // const Nd4jLong zLen = shape::length(zShapeInfo); + for (int i = blockIdx.x; i < numOfSubArrs; i += gridDim.x) { if (threadIdx.x == 0) { - x = reinterpret_cast(vx); - z = reinterpret_cast(vz); - y = reinterpret_cast(vy); - xLen = shape::length(xShapeInfo); - yLen = shape::length(yShapeInfo); - zLen = shape::length(zShapeInfo); + x = reinterpret_cast(vx) + + xOffsets[y[shape::getIndexOffset(i, yShapeInfo)]]; + z = reinterpret_cast(vz) + zOffsets[i]; } __syncthreads(); - //const Nd4jLong zLen = shape::length(zShapeInfo); - auto start = blockIdx.x * blockDim.x + threadIdx.x; - auto step = blockDim.x * gridDim.x; - - for (int j = start; j < zLen; j += step) { - auto zIndex = shape::getIndexOffset(j, zShapeInfo); - auto yIndex = shape::getIndexOffset(j, yShapeInfo); - auto xIndex = shape::getIndexOffset(y[yIndex], xShapeInfo); - z[zIndex] = x[xIndex]; - } -} -////////////////////////////////////////////////////////////////////// -template -__global__ static void gatherCuda(const int numOfSubArrs, - const void* vx, const Nd4jLong* xShapeInfo, const Nd4jLong* xOffsets, - const void* vy, const Nd4jLong* yShapeInfo, - void* vz, const Nd4jLong* zShapeInfo, const Nd4jLong* zOffsets) { - - const Y* y = reinterpret_cast(vy); - __shared__ const X* x; - __shared__ X* z; - - const Nd4jLong len = shape::length(xShapeInfo); - //const Nd4jLong zLen = shape::length(zShapeInfo); - for (int i = blockIdx.x; i < numOfSubArrs; i += gridDim.x) { - - if (threadIdx.x == 0) { - x = reinterpret_cast(vx) + xOffsets[y[shape::getIndexOffset(i, yShapeInfo)]]; - z = reinterpret_cast(vz) + zOffsets[i]; - } - __syncthreads(); - - for (int j = threadIdx.x; j < len; j += blockDim.x) { - auto zIndex = shape::getIndexOffset(j, zShapeInfo); - auto xIndex = shape::getIndexOffset(j, xShapeInfo); - z[zIndex] = x[xIndex]; - } - __syncthreads(); + for (int j = threadIdx.x; j < len; j += blockDim.x) { + auto zIndex = shape::getIndexOffset(j, zShapeInfo); + auto xIndex = shape::getIndexOffset(j, xShapeInfo); + z[zIndex] = x[xIndex]; } + __syncthreads(); + } } -template -__host__ static void gatherCudaLinear(const cudaStream_t *stream, const void* vx, const Nd4jLong* xShapeInfo, const void* vy, const Nd4jLong* yShapeInfo, - void* vz, const Nd4jLong* zShapeInfo) { - gatherCudaLinearKernel<<<128, 256, 1024, *stream>>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo); +template +__host__ static void gatherCudaLinear(const cudaStream_t* stream, + const void* vx, + const Nd4jLong* xShapeInfo, + const void* vy, + const Nd4jLong* yShapeInfo, void* vz, + const Nd4jLong* zShapeInfo) { + gatherCudaLinearKernel<<<128, 256, 1024, *stream>>>( + vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo); } ////////////////////////////////////////////////////////////////////// -template -__host__ static void gatherCudaLauncher(const cudaStream_t *stream, const int numOfSubArrs, - const void* vx, const Nd4jLong* xShapeInfo, const Nd4jLong* xOffsets, - const void* vy, const Nd4jLong* yShapeInfo, - void* vz, const Nd4jLong* zShapeInfo, const Nd4jLong* zOffsets) { - gatherCuda<<>>(numOfSubArrs, vx, xShapeInfo, xOffsets, vy, yShapeInfo, vz, zShapeInfo, zOffsets); +template +__host__ static void gatherCudaLauncher( + const cudaStream_t* stream, const int numOfSubArrs, const void* vx, + const Nd4jLong* xShapeInfo, const Nd4jLong* xOffsets, const void* vy, + const Nd4jLong* yShapeInfo, void* vz, const Nd4jLong* zShapeInfo, + const Nd4jLong* zOffsets) { + gatherCuda<<>>( + numOfSubArrs, vx, xShapeInfo, xOffsets, vy, yShapeInfo, vz, zShapeInfo, + zOffsets); } ////////////////////////////////////////////////////////////////////// -void gather(sd::LaunchContext * context, const NDArray* input, const NDArray* indices, NDArray* output, const std::vector& intArgs) { - - const int inputRank = input->rankOf(); - const int numOfIntArgs = intArgs.size(); - - int axis = numOfIntArgs > 0 ? intArgs[0] : 0; - if(axis < 0) - axis += inputRank; - - if (indices == nullptr && numOfIntArgs == 2) { // scalar case - output->assign((*input)(intArgs[1], {axis})); +void gather(sd::LaunchContext* context, const NDArray* input, + const NDArray* indices, NDArray* output, + const std::vector& intArgs) { + const int inputRank = input->rankOf(); + const int numOfIntArgs = intArgs.size(); + + int axis = numOfIntArgs > 0 ? intArgs[0] : 0; + if (axis < 0) axis += inputRank; + + if (indices == nullptr && numOfIntArgs == 2) { // scalar case + output->assign((*input)(intArgs[1], {axis})); + } else if (indices != nullptr && indices->isScalar()) { + if (input->rankOf() <= 1) { // For scalar indices, rank 0 or 1 input: can't + // do tensor along dimension 0 as this is whole + // array... instead, we want to get a scalar + auto idx = indices->e(0); + auto scalarNDArray = input->e(idx); + output->assign(scalarNDArray); + } else { + NDArray inSubArr = (*input)(indices->e(0), {axis}); + output->assign(inSubArr); } - else if (indices != nullptr && indices->isScalar()) { - - if(input->rankOf() <= 1) { //For scalar indices, rank 0 or 1 input: can't do tensor along dimension 0 as this is whole array... instead, we want to get a scalar - auto idx = indices->e(0); - auto scalarNDArray = input->e(idx); - output->assign(scalarNDArray); - } - else { - NDArray inSubArr = (*input)(indices->e(0), {axis}); - output->assign(inSubArr); - } + } else { + NDArray* pIndices = const_cast(indices); + if (indices == nullptr) + pIndices = + new NDArray(input->ordering(), {numOfIntArgs - 1}, + std::vector(intArgs.begin() + 1, intArgs.end()), + DataType::INT64, input->getContext()); + + std::vector dimsOut(pIndices->rankOf()); + std::iota(dimsOut.begin(), dimsOut.end(), + axis); // fill with axis, axis+1, ... axis+pIndices->rankOf()-1 + + const Nd4jLong numOfSubArrs = pIndices->lengthOf(); + + Nd4jLong *outSubArrShapeInfo(nullptr), *inSubArrShapeInfo(nullptr), + *outSubArrOffsets(nullptr), *inSubArrOffsets(nullptr); + input->getSubArrShapeAndOffsets({axis}, inSubArrShapeInfo, inSubArrOffsets); + output->getSubArrShapeAndOffsets(dimsOut, outSubArrShapeInfo, + outSubArrOffsets); + if (output->rankOf() > 1) { + PointersManager manager(context, "gather"); + auto xShapeInfo = reinterpret_cast(manager.replicatePointer( + inSubArrShapeInfo, shape::shapeInfoByteLength(inSubArrShapeInfo))); + auto zShapeInfo = reinterpret_cast(manager.replicatePointer( + outSubArrShapeInfo, shape::shapeInfoByteLength(outSubArrShapeInfo))); + auto xOffsets = reinterpret_cast(manager.replicatePointer( + inSubArrOffsets, + (input->lengthOf() / shape::length(inSubArrShapeInfo)) * + sizeof(Nd4jLong))); + auto zOffsets = reinterpret_cast(manager.replicatePointer( + outSubArrOffsets, + (output->lengthOf() / shape::length(outSubArrShapeInfo)) * + sizeof(Nd4jLong))); + + NDArray::prepareSpecialUse({output}, {input, pIndices}); + BUILD_DOUBLE_SELECTOR( + input->dataType(), pIndices->dataType(), gatherCudaLauncher, + (context->getCudaStream(), numOfSubArrs, input->specialBuffer(), + xShapeInfo, xOffsets, pIndices->specialBuffer(), + pIndices->specialShapeInfo(), output->specialBuffer(), zShapeInfo, + zOffsets), + LIBND4J_TYPES, INDEXING_TYPES); + NDArray::registerSpecialUse({output}, {input, pIndices}); + manager.synchronize(); + } else { + NDArray::prepareSpecialUse({output}, {input, pIndices}); + BUILD_DOUBLE_SELECTOR( + input->dataType(), pIndices->dataType(), gatherCudaLinear, + (context->getCudaStream(), input->specialBuffer(), + input->specialShapeInfo(), pIndices->specialBuffer(), + pIndices->specialShapeInfo(), output->specialBuffer(), + output->specialShapeInfo()), + LIBND4J_TYPES, INDEXING_TYPES); + NDArray::registerSpecialUse({output}, {input, pIndices}); } - else { - - NDArray* pIndices = const_cast(indices); - if(indices == nullptr) - pIndices = new NDArray(input->ordering(), {numOfIntArgs-1}, std::vector(intArgs.begin() + 1, intArgs.end()), DataType::INT64, input->getContext()); - - std::vector dimsOut(pIndices->rankOf()); - std::iota(dimsOut.begin(), dimsOut.end(), axis); // fill with axis, axis+1, ... axis+pIndices->rankOf()-1 - - const Nd4jLong numOfSubArrs = pIndices->lengthOf(); - - Nd4jLong *outSubArrShapeInfo(nullptr), *inSubArrShapeInfo(nullptr), *outSubArrOffsets(nullptr), *inSubArrOffsets(nullptr); - input-> getSubArrShapeAndOffsets({axis}, inSubArrShapeInfo, inSubArrOffsets); - output->getSubArrShapeAndOffsets(dimsOut, outSubArrShapeInfo, outSubArrOffsets); - if (output->rankOf() > 1) { - PointersManager manager(context, "gather"); - auto xShapeInfo = reinterpret_cast(manager.replicatePointer(inSubArrShapeInfo, - shape::shapeInfoByteLength( - inSubArrShapeInfo))); - auto zShapeInfo = reinterpret_cast(manager.replicatePointer(outSubArrShapeInfo, - shape::shapeInfoByteLength( - outSubArrShapeInfo))); - auto xOffsets = reinterpret_cast(manager.replicatePointer(inSubArrOffsets, (input->lengthOf() / - shape::length( - inSubArrShapeInfo)) * - sizeof(Nd4jLong))); - auto zOffsets = reinterpret_cast(manager.replicatePointer(outSubArrOffsets, - (output->lengthOf() / - shape::length(outSubArrShapeInfo)) * - sizeof(Nd4jLong))); - - NDArray::prepareSpecialUse({output}, {input, pIndices}); - BUILD_DOUBLE_SELECTOR(input->dataType(), pIndices->dataType(), gatherCudaLauncher, (context->getCudaStream(), numOfSubArrs, input->specialBuffer(), xShapeInfo, xOffsets, pIndices->specialBuffer(), pIndices->specialShapeInfo(), output->specialBuffer(), zShapeInfo, zOffsets), LIBND4J_TYPES, INDEXING_TYPES); - NDArray::registerSpecialUse({output}, {input, pIndices}); - manager.synchronize(); - } - else { - NDArray::prepareSpecialUse({output}, {input, pIndices}); - BUILD_DOUBLE_SELECTOR(input->dataType(), pIndices->dataType(), gatherCudaLinear, (context->getCudaStream(), input->specialBuffer(), input->specialShapeInfo(), pIndices->specialBuffer(), pIndices->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo()), LIBND4J_TYPES, INDEXING_TYPES); - NDArray::registerSpecialUse({output}, {input, pIndices}); - - } - - if(indices == nullptr) - delete pIndices; - } + if (indices == nullptr) delete pIndices; + } } -} -} -} \ No newline at end of file +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/gather_nd.cu b/libnd4j/include/ops/declarable/helpers/cuda/gather_nd.cu index d72f3e1bc126..1ab8aeba2790 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/gather_nd.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/gather_nd.cu @@ -18,128 +18,136 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 20.04.2018 // - -#include -#include -#include -#include #include -#include +#include #include -#include #include +#include +#include +#include +#include + +#include namespace sd { - namespace ops { - namespace helpers { - /////////////////////////////////////////////////////////////////// +namespace ops { +namespace helpers { +/////////////////////////////////////////////////////////////////// // x - input, y - indices, z - output - template - __global__ static void gatherNDCuda(const void *vx, const Nd4jLong *xShapeInfo, - const void *vy, const Nd4jLong *yShapeInfo, - void *vz, const Nd4jLong *zShapeInfo) { - - const auto x = reinterpret_cast(vx); - const auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); +template +__global__ static void gatherNDCuda(const void *vx, const Nd4jLong *xShapeInfo, + const void *vy, const Nd4jLong *yShapeInfo, + void *vz, const Nd4jLong *zShapeInfo) { + const auto x = reinterpret_cast(vx); + const auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); - __shared__ int xRank, yRank, zRank, maxRank, yLastDim; - __shared__ Nd4jLong zLen, totalThreads, *sharedMem; + __shared__ int xRank, yRank, zRank, maxRank, yLastDim; + __shared__ Nd4jLong zLen, totalThreads, *sharedMem; - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - sharedMem = reinterpret_cast(shmem); + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sharedMem = reinterpret_cast(shmem); - xRank = shape::rank(xShapeInfo); - yRank = shape::rank(yShapeInfo); - zRank = shape::rank(zShapeInfo); - maxRank = sd::math::nd4j_max(yRank, sd::math::nd4j_max(xRank, zRank)); + xRank = shape::rank(xShapeInfo); + yRank = shape::rank(yShapeInfo); + zRank = shape::rank(zShapeInfo); + maxRank = + sd::math::nd4j_max(yRank, sd::math::nd4j_max(xRank, zRank)); - zLen = shape::length(zShapeInfo); - yLastDim = yShapeInfo[yRank]; + zLen = shape::length(zShapeInfo); + yLastDim = yShapeInfo[yRank]; - totalThreads = gridDim.x * blockDim.x; - } - __syncthreads(); + totalThreads = gridDim.x * blockDim.x; + } + __syncthreads(); - auto coord = sharedMem + threadIdx.x * maxRank; + auto coord = sharedMem + threadIdx.x * maxRank; - Nd4jLong *zCoordStart, *xCoordStart; + Nd4jLong *zCoordStart, *xCoordStart; - if(yLastDim == xRank) { - zCoordStart = coord; - xCoordStart = coord; - } - if(zRank >= xRank) { - zCoordStart = coord; - xCoordStart = coord + zRank - xRank; - } - else { - zCoordStart = coord + xRank - zRank; - xCoordStart = coord; - } + if (yLastDim == xRank) { + zCoordStart = coord; + xCoordStart = coord; + } + if (zRank >= xRank) { + zCoordStart = coord; + xCoordStart = coord + zRank - xRank; + } else { + zCoordStart = coord + xRank - zRank; + xCoordStart = coord; + } - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - for (Nd4jLong i = tid; i < zLen; i += totalThreads) { + for (Nd4jLong i = tid; i < zLen; i += totalThreads) { + shape::index2coords(i, zShapeInfo, zCoordStart); - shape::index2coords(i, zShapeInfo, zCoordStart); + const auto zOffset = shape::getOffset(zShapeInfo, zCoordStart); - const auto zOffset = shape::getOffset(zShapeInfo, zCoordStart); + // last y coordinate + int coordToRestore; + if (yLastDim != xRank) + coordToRestore = static_cast(zCoordStart[yRank - 1]); - // last y coordinate - int coordToRestore; - if(yLastDim != xRank) - coordToRestore = static_cast(zCoordStart[yRank - 1]); + zCoordStart[yRank - 1] = 0; // last y coordinate + const auto yOffset = shape::getOffset(yShapeInfo, zCoordStart); - zCoordStart[yRank - 1] = 0; // last y coordinate - const auto yOffset = shape::getOffset(yShapeInfo, zCoordStart); + // restore z coordinate + if (yLastDim != xRank) zCoordStart[yRank - 1] = coordToRestore; - //restore z coordinate - if(yLastDim != xRank) - zCoordStart[yRank - 1] = coordToRestore; + // construct coordinates for x + for (uint j = 0; j < yLastDim; ++j) + xCoordStart[j] = y[yOffset + j * yShapeInfo[2 * yRank]]; // last stride - // construct coordinates for x - for(uint j = 0; j < yLastDim; ++j) - xCoordStart[j] = y[yOffset + j * yShapeInfo[2 * yRank]]; // last stride + const auto xOffset = shape::getOffset(xShapeInfo, xCoordStart); - const auto xOffset = shape::getOffset(xShapeInfo, xCoordStart); - - z[zOffset] = x[xOffset]; - // printf("z[%lld] = x[%lld] = %f\n", zOffset, xOffset, (float) z[zOffset]); - } - } + z[zOffset] = x[xOffset]; + // printf("z[%lld] = x[%lld] = %f\n", zOffset, xOffset, (float) z[zOffset]); + } +} /////////////////////////////////////////////////////////////////// - template - static void gatherNDCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, - const void *vx, const Nd4jLong *xShapeInfo, - const void *vy, const Nd4jLong *yShapeInfo, - void *vz, const Nd4jLong *zShapeInfo) { - - gatherNDCuda<<>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo); - } +template +static void gatherNDCudaLauncher(const int blocksPerGrid, + const int threadsPerBlock, const int sharedMem, + const cudaStream_t *stream, const void *vx, + const Nd4jLong *xShapeInfo, const void *vy, + const Nd4jLong *yShapeInfo, void *vz, + const Nd4jLong *zShapeInfo) { + gatherNDCuda<<>>( + vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo); +} /////////////////////////////////////////////////////////////////// - void gatherND(sd::LaunchContext * context, NDArray& input, NDArray& indices, NDArray& output) { - - const int maxRank = sd::math::nd4j_max(indices.rankOf(), sd::math::nd4j_max(input.rankOf(), output.rankOf())); - - const int threadsPerBlock = 256; - const int blocksPerGrid = (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - const int sharedMem = 8 * threadsPerBlock * maxRank + 128; - - const auto xType = input.dataType(); - const auto yType = indices.dataType(); - - PointersManager manager(context, "gatherND"); - - NDArray::prepareSpecialUse({&output}, {&input, &indices}); - BUILD_DOUBLE_SELECTOR(xType, yType, gatherNDCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), input.specialBuffer(), input.specialShapeInfo(), indices.specialBuffer(), indices.specialShapeInfo(), output.specialBuffer(), output.specialShapeInfo()), LIBND4J_TYPES, INDEXING_TYPES); - NDArray::registerSpecialUse({&output}, {&input, &indices}); - - manager.synchronize(); - } - } - } -} \ No newline at end of file +void gatherND(sd::LaunchContext *context, NDArray &input, NDArray &indices, + NDArray &output) { + const int maxRank = sd::math::nd4j_max( + indices.rankOf(), + sd::math::nd4j_max(input.rankOf(), output.rankOf())); + + const int threadsPerBlock = 256; + const int blocksPerGrid = + (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + const int sharedMem = 8 * threadsPerBlock * maxRank + 128; + + const auto xType = input.dataType(); + const auto yType = indices.dataType(); + + PointersManager manager(context, "gatherND"); + + NDArray::prepareSpecialUse({&output}, {&input, &indices}); + BUILD_DOUBLE_SELECTOR( + xType, yType, gatherNDCudaLauncher, + (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), + input.specialBuffer(), input.specialShapeInfo(), indices.specialBuffer(), + indices.specialShapeInfo(), output.specialBuffer(), + output.specialShapeInfo()), + LIBND4J_TYPES, INDEXING_TYPES); + NDArray::registerSpecialUse({&output}, {&input, &indices}); + + manager.synchronize(); +} +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/gradient.cu b/libnd4j/include/ops/declarable/helpers/cuda/gradient.cu index f165d88b7e78..2fa6b4eb33a1 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/gradient.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/gradient.cu @@ -25,19 +25,23 @@ namespace sd { namespace ops { namespace helpers { template -void applyGradientDescent_(LaunchContext* context, NDArray* input, NDArray* step, double weight, NDArray* output) { - // classic one - auto lambda = LAMBDA_TT(_x, _y, weight) { - return _x - (_y * weight); - }; +void applyGradientDescent_(LaunchContext* context, NDArray* input, + NDArray* step, double weight, NDArray* output) { + // classic one + auto lambda = LAMBDA_TT(_x, _y, weight) { return _x - (_y * weight); }; - input->applyPairwiseLambda(*step, lambda, *output); + input->applyPairwiseLambda(*step, lambda, *output); } -void applyGradientDescent(sd::LaunchContext* context, NDArray* input, NDArray* step, double weight, NDArray* output) { - BUILD_SINGLE_SELECTOR(input->dataType(), applyGradientDescent_, (context, input, step, weight, output), FLOAT_TYPES); -} -BUILD_SINGLE_TEMPLATE(template void applyGradientDescent_, (LaunchContext* context, NDArray* input, NDArray* step, double weight, NDArray* output), FLOAT_TYPES); -} -} +void applyGradientDescent(sd::LaunchContext* context, NDArray* input, + NDArray* step, double weight, NDArray* output) { + BUILD_SINGLE_SELECTOR(input->dataType(), applyGradientDescent_, + (context, input, step, weight, output), FLOAT_TYPES); } +BUILD_SINGLE_TEMPLATE(template void applyGradientDescent_, + (LaunchContext * context, NDArray* input, NDArray* step, + double weight, NDArray* output), + FLOAT_TYPES); +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/hamming.cu b/libnd4j/include/ops/declarable/helpers/cuda/hamming.cu index e3fdd94116ea..67905940dbbf 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/hamming.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/hamming.cu @@ -18,78 +18,91 @@ // @author raver119@gmail.com // -#include #include +#include namespace sd { - namespace ops { - namespace helpers { - template - static _CUDA_G void _hammingKernel(const void *vx, const Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo, void *vz, void *reductionBuffer, Nd4jLong length) { - auto x = reinterpret_cast(vx); - auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); - - __shared__ Nd4jLong *shared; - - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - shared = reinterpret_cast(shmem); - } - __syncthreads(); - - // we want to nullify temporary memory before accumulating intermediate results - shared[threadIdx.x] = 0; - - auto tid = threadIdx.x + blockIdx.x * blockDim.x; - for (Nd4jLong e = tid; e < length; e += blockDim.x * gridDim.x) { - auto _x = static_cast(x[shape::getIndexOffset(e, xShapeInfo)]); - auto _y = static_cast(y[shape::getIndexOffset(e, yShapeInfo)]); - - // we save intermediate result into shared memory - shared[threadIdx.x] += __popcll(_x ^ _y); - } - __syncthreads(); - - // now we accumulate values - auto numItems = sd::math::nd4j_min(blockDim.x, length); - auto floorPow2 = numItems; - if (floorPow2 & (floorPow2 - 1)) { - - while (floorPow2 & (floorPow2 - 1)) - floorPow2 &= floorPow2 - 1; - - if (threadIdx.x >= floorPow2) - shared[threadIdx.x - floorPow2] = shared[threadIdx.x - floorPow2] + shared[threadIdx.x]; - - __syncthreads(); - } - __syncthreads(); - - for (Nd4jLong activeThreads = floorPow2 >> 1; activeThreads; activeThreads >>= 1) { - if (threadIdx.x < activeThreads && threadIdx.x + activeThreads < numItems) - shared[threadIdx.x] = shared[threadIdx.x] + shared[threadIdx.x + activeThreads]; - - __syncthreads(); - } - __syncthreads(); - - // FIXME: do we really want atomicAdd on global memory here - // and store them to output - if (threadIdx.x == 0 && shared[0] > 0) - sd::math::atomics::nd4j_atomicAdd(&z[0], static_cast(shared[threadIdx.x])); - } - - template - static void _hamming(LaunchContext *context, NDArray &x, NDArray &y, NDArray &z) { - _hammingKernel<<<256, 256, 256 * sizeof(Nd4jLong) + 256, *context->getCudaStream()>>>(x.specialBuffer(), x.specialShapeInfo(), y.specialBuffer(), y.specialShapeInfo(), z.specialBuffer(), nullptr, x.lengthOf()); - } - - void hamming(LaunchContext *context, NDArray &x, NDArray &y, NDArray &output) { - NDArray::prepareSpecialUse({&output}, {&x, &y}); - BUILD_DOUBLE_SELECTOR(x.dataType(), output.dataType(), _hamming, (context, x, y, output), INTEGER_TYPES, INDEXING_TYPES); - NDArray::registerSpecialUse({&output}, {&x, &y}); - } - } - } -} \ No newline at end of file +namespace ops { +namespace helpers { +template +static _CUDA_G void _hammingKernel(const void *vx, const Nd4jLong *xShapeInfo, + const void *vy, const Nd4jLong *yShapeInfo, + void *vz, void *reductionBuffer, + Nd4jLong length) { + auto x = reinterpret_cast(vx); + auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + + __shared__ Nd4jLong *shared; + + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + shared = reinterpret_cast(shmem); + } + __syncthreads(); + + // we want to nullify temporary memory before accumulating intermediate + // results + shared[threadIdx.x] = 0; + + auto tid = threadIdx.x + blockIdx.x * blockDim.x; + for (Nd4jLong e = tid; e < length; e += blockDim.x * gridDim.x) { + auto _x = static_cast( + x[shape::getIndexOffset(e, xShapeInfo)]); + auto _y = static_cast( + y[shape::getIndexOffset(e, yShapeInfo)]); + + // we save intermediate result into shared memory + shared[threadIdx.x] += __popcll(_x ^ _y); + } + __syncthreads(); + + // now we accumulate values + auto numItems = sd::math::nd4j_min(blockDim.x, length); + auto floorPow2 = numItems; + if (floorPow2 & (floorPow2 - 1)) { + while (floorPow2 & (floorPow2 - 1)) floorPow2 &= floorPow2 - 1; + + if (threadIdx.x >= floorPow2) + shared[threadIdx.x - floorPow2] = + shared[threadIdx.x - floorPow2] + shared[threadIdx.x]; + + __syncthreads(); + } + __syncthreads(); + + for (Nd4jLong activeThreads = floorPow2 >> 1; activeThreads; + activeThreads >>= 1) { + if (threadIdx.x < activeThreads && threadIdx.x + activeThreads < numItems) + shared[threadIdx.x] = + shared[threadIdx.x] + shared[threadIdx.x + activeThreads]; + + __syncthreads(); + } + __syncthreads(); + + // FIXME: do we really want atomicAdd on global memory here + // and store them to output + if (threadIdx.x == 0 && shared[0] > 0) + sd::math::atomics::nd4j_atomicAdd(&z[0], + static_cast(shared[threadIdx.x])); +} + +template +static void _hamming(LaunchContext *context, NDArray &x, NDArray &y, + NDArray &z) { + _hammingKernel + <<<256, 256, 256 * sizeof(Nd4jLong) + 256, *context->getCudaStream()>>>( + x.specialBuffer(), x.specialShapeInfo(), y.specialBuffer(), + y.specialShapeInfo(), z.specialBuffer(), nullptr, x.lengthOf()); +} + +void hamming(LaunchContext *context, NDArray &x, NDArray &y, NDArray &output) { + NDArray::prepareSpecialUse({&output}, {&x, &y}); + BUILD_DOUBLE_SELECTOR(x.dataType(), output.dataType(), _hamming, + (context, x, y, output), INTEGER_TYPES, INDEXING_TYPES); + NDArray::registerSpecialUse({&output}, {&x, &y}); +} +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/hashcode.cu b/libnd4j/include/ops/declarable/helpers/cuda/hashcode.cu index 1c4ca9152345..a66893b6cc8f 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/hashcode.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/hashcode.cu @@ -20,112 +20,127 @@ #include - namespace sd { - namespace ops { - namespace helpers { - template - static __global__ void splitBufferToChuncks(T* buffer, Nd4jLong* tempBuffer, Nd4jLong numBlocks, Nd4jLong blockSize, Nd4jLong length) { - - for (int b = blockIdx.x * blockDim.x + threadIdx.x; b < numBlocks; b += gridDim.x*blockDim.x) { - auto blockBuffer = buffer + b * numBlocks; - - Nd4jLong r = 1LL; - for (int e = 0; e < blockSize && e + (b * numBlocks) < length; e++) { - auto v = longBytes(blockBuffer[e]); - r = 31LL * r + v; - } - - tempBuffer[b] = r; - } - } - - template - static __global__ void internalHash(Nd4jLong* tempBuffer, Nd4jLong* tempResult, Nd4jLong numBlocks, Nd4jLong blockSize, Nd4jLong lastLength) { - - for (int b = blockIdx.x * blockDim.x + threadIdx.x; b < numBlocks; b += gridDim.x * blockDim.x) { - auto blockBuffer = tempBuffer + b * numBlocks; - Nd4jLong r = 1LL; - - for (Nd4jLong e = 0; e < blockSize && e + (b * numBlocks) < lastLength; e++) { - auto v = longBytes(blockBuffer[e]); - r = 31LL * r + v; - } - - tempResult[b] = r; - - } - - } - - - static __global__ void lastStep(Nd4jLong* resultBuf, Nd4jLong* tempBufferA, Nd4jLong* tempResult, Nd4jLong length, Nd4jLong blockSize) { - if (threadIdx.x == 0) { - if (length <= blockSize) - *resultBuf = *tempBufferA; - else - *resultBuf = *tempResult; - } - } - - template - void hashCode_(LaunchContext *context, NDArray &array, NDArray &result) { - auto blockSize = 32; - auto stream = context->getCudaStream(); - array.syncToDevice(); - - NDArray::prepareSpecialUse({&result}, {&array}); - auto length = array.lengthOf(); - int numBlocks = length / blockSize + ((length % blockSize == 0) ? 0 : 1); - auto tempA = NDArrayFactory::create('c', {numBlocks}, context); - auto tempB = NDArrayFactory::create('c', { numBlocks / blockSize + 1}, context); - - auto buffer = reinterpret_cast(array.specialBuffer()); //bufferAsT(); - auto tempBufferA = reinterpret_cast(tempA.specialBuffer()); //bufferAsT(); - auto tempBufferB = reinterpret_cast(tempB.specialBuffer()); //bufferAsT(); - - // default buffer is the first one, because it might be the last one in case of small arrays (< blockSize) - auto tempBuffer = tempBufferA; - auto tempResult = tempBufferB; - - // we divide array into 32 element chunks, and store intermediate results once - splitBufferToChuncks<<>>(buffer, tempBuffer, numBlocks, blockSize, length); - - // we replace pointer with intermediate one, and repeat only one chunk left - int iterationCount = 0; - while (numBlocks > 1) { - int lastLength = numBlocks; - numBlocks = lastLength / blockSize + ((lastLength % blockSize == 0) ? 0 : 1); - - - internalHash<<>>(tempBuffer, tempResult, numBlocks, blockSize, lastLength); - - - iterationCount++; - // swapping buffers - if (iterationCount % 2 == 0) { - tempBuffer = tempBufferA; - tempResult = tempBufferB; - } else { - tempBuffer = tempBufferB; - tempResult = tempBufferA; - } - } - - lastStep<<<1,1,128, *stream>>>(reinterpret_cast(result.specialBuffer()), tempBufferA, tempResult, length, blockSize); -// tempA.syncToHost(); -// tempB.syncToHost(); -// result.assign((length <= blockSize?tempA.e(0) : tempB.e(0))); - - NDArray::registerSpecialUse({&result}, {&array}); - } - - void hashCode(LaunchContext *context, NDArray &array, NDArray &result) { - BUILD_SINGLE_SELECTOR(array.dataType(), hashCode_, (context, array, result), LIBND4J_TYPES); - } +namespace ops { +namespace helpers { +template +static __global__ void splitBufferToChuncks(T* buffer, Nd4jLong* tempBuffer, + Nd4jLong numBlocks, + Nd4jLong blockSize, + Nd4jLong length) { + for (int b = blockIdx.x * blockDim.x + threadIdx.x; b < numBlocks; + b += gridDim.x * blockDim.x) { + auto blockBuffer = buffer + b * numBlocks; + + Nd4jLong r = 1LL; + for (int e = 0; e < blockSize && e + (b * numBlocks) < length; e++) { + auto v = longBytes(blockBuffer[e]); + r = 31LL * r + v; + } + + tempBuffer[b] = r; + } +} + +template +static __global__ void internalHash(Nd4jLong* tempBuffer, Nd4jLong* tempResult, + Nd4jLong numBlocks, Nd4jLong blockSize, + Nd4jLong lastLength) { + for (int b = blockIdx.x * blockDim.x + threadIdx.x; b < numBlocks; + b += gridDim.x * blockDim.x) { + auto blockBuffer = tempBuffer + b * numBlocks; + Nd4jLong r = 1LL; + + for (Nd4jLong e = 0; e < blockSize && e + (b * numBlocks) < lastLength; + e++) { + auto v = longBytes(blockBuffer[e]); + r = 31LL * r + v; + } + + tempResult[b] = r; + } +} + +static __global__ void lastStep(Nd4jLong* resultBuf, Nd4jLong* tempBufferA, + Nd4jLong* tempResult, Nd4jLong length, + Nd4jLong blockSize) { + if (threadIdx.x == 0) { + if (length <= blockSize) + *resultBuf = *tempBufferA; + else + *resultBuf = *tempResult; + } +} - BUILD_SINGLE_TEMPLATE(template void hashCode_, (LaunchContext* context, NDArray& array, NDArray& result), LIBND4J_TYPES); - } +template +void hashCode_(LaunchContext* context, NDArray& array, NDArray& result) { + auto blockSize = 32; + auto stream = context->getCudaStream(); + array.syncToDevice(); + + NDArray::prepareSpecialUse({&result}, {&array}); + auto length = array.lengthOf(); + int numBlocks = length / blockSize + ((length % blockSize == 0) ? 0 : 1); + auto tempA = NDArrayFactory::create('c', {numBlocks}, context); + auto tempB = NDArrayFactory::create( + 'c', {numBlocks / blockSize + 1}, context); + + auto buffer = reinterpret_cast(array.specialBuffer()); // bufferAsT(); + auto tempBufferA = reinterpret_cast( + tempA.specialBuffer()); // bufferAsT(); + auto tempBufferB = reinterpret_cast( + tempB.specialBuffer()); // bufferAsT(); + + // default buffer is the first one, because it might be the last one in case + // of small arrays (< blockSize) + auto tempBuffer = tempBufferA; + auto tempResult = tempBufferB; + + // we divide array into 32 element chunks, and store intermediate results once + splitBufferToChuncks<<>>( + buffer, tempBuffer, numBlocks, blockSize, length); + + // we replace pointer with intermediate one, and repeat only one chunk left + int iterationCount = 0; + while (numBlocks > 1) { + int lastLength = numBlocks; + numBlocks = + lastLength / blockSize + ((lastLength % blockSize == 0) ? 0 : 1); + + internalHash<<>>( + tempBuffer, tempResult, numBlocks, blockSize, lastLength); + + iterationCount++; + // swapping buffers + if (iterationCount % 2 == 0) { + tempBuffer = tempBufferA; + tempResult = tempBufferB; + } else { + tempBuffer = tempBufferB; + tempResult = tempBufferA; } + } + + lastStep<<<1, 1, 128, *stream>>>( + reinterpret_cast(result.specialBuffer()), tempBufferA, + tempResult, length, blockSize); + // tempA.syncToHost(); + // tempB.syncToHost(); + // result.assign((length <= blockSize?tempA.e(0) : + // tempB.e(0))); + + NDArray::registerSpecialUse({&result}, {&array}); +} + +void hashCode(LaunchContext* context, NDArray& array, NDArray& result) { + BUILD_SINGLE_SELECTOR(array.dataType(), hashCode_, (context, array, result), + LIBND4J_TYPES); } +BUILD_SINGLE_TEMPLATE(template void hashCode_, + (LaunchContext * context, NDArray& array, + NDArray& result), + LIBND4J_TYPES); +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/histogram.cu b/libnd4j/include/ops/declarable/helpers/cuda/histogram.cu index 6d7310fc0c20..dbdd24d65048 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/histogram.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/histogram.cu @@ -18,118 +18,137 @@ // @author raver119@gmail.com // -#include #include +#include namespace sd { - namespace ops { - namespace helpers { - template - void _CUDA_G histogramKernel(void *xBuffer, const Nd4jLong *xShapeInfo, void *zBuffer, const Nd4jLong *zShapeInfo, void *allocationPointer, void *reductionPointer, Nd4jLong numBins, X* min_val, X* max_val) { - int tid = blockIdx.x * blockDim.x + threadIdx.x; - auto dx = reinterpret_cast(xBuffer); - auto result = reinterpret_cast(zBuffer); - - __shared__ Z *bins; - __shared__ int length; - __shared__ Z *reductor; - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - bins = (Z *) shmem; - reductor = ((Z *) allocationPointer) + (numBins * blockIdx.x); - - length = shape::length(xShapeInfo); - } - __syncthreads(); - - X binSize = X((*max_val - *min_val) / numBins); - - // nullify bins - for (int e = threadIdx.x; e < numBins; e += blockDim.x) { - bins[e] = (Z) 0; - } - __syncthreads(); - - for (int e = tid; e < length; e += blockDim.x * gridDim.x) { - int idx = int((dx[e] - *min_val) / binSize); - idx = math::nd4j_max(idx, 0); //atomicMax(&idx, 0);//atomicMax(&idx, 0); - idx = math::nd4j_min(idx, int(numBins - 1)); //atomicMin(&idx, int(numBins - 1)); - sd::math::atomics::nd4j_atomicAdd(&bins[idx], (Z)1); - } - __syncthreads(); - // at this point all bins in shared memory are calculated, so we aggregate them now via threadfence trick - - // transfer shared memory to reduction memory - if (gridDim.x > 1) { - unsigned int *tc = (unsigned int *)reductionPointer; - __shared__ bool amLast; - - for (int e = threadIdx.x; e < numBins; e += blockDim.x) { - reductor[e] = bins[e]; - } - __threadfence(); - __syncthreads(); - - if (threadIdx.x == 0) { - unsigned int ticket = atomicInc(&tc[16384], gridDim.x); - amLast = (ticket == gridDim.x - 1); - } - __syncthreads(); - - if (amLast) { - tc[16384] = 0; - - // nullify shared memory for future accumulation - for (int e = threadIdx.x; e < numBins; e += blockDim.x) { - bins[e] = (Z) 0; - } - - // accumulate reduced bins - for (int r = 0; r < gridDim.x; r++) { - Z *ptrBuf = ((Z *)allocationPointer) + (r * numBins); - - for (int e = threadIdx.x; e < numBins; e += blockDim.x) { - math::atomics::nd4j_atomicAdd(&bins[e], ptrBuf[e]); - } - } - __syncthreads(); - - // write them out to Z - for (int e = threadIdx.x; e < numBins; e += blockDim.x) { - result[e] = bins[e]; - } - } - } else { - // if there's only 1 block - just write away data - for (int e = threadIdx.x; e < numBins; e += blockDim.x) { - result[e] = bins[e]; - } - } - } - - template - static void histogram_(sd::LaunchContext *context, void *xBuffer, const Nd4jLong *xShapeInfo, const Nd4jLong *dxShapeInfo, void *zBuffer, const Nd4jLong *zShapeInfo, Nd4jLong numBins, void* min_val, void* max_val) { - int numThreads = 256; - int numBlocks = sd::math::nd4j_max(256, sd::math::nd4j_min(1, shape::length(xShapeInfo) / numThreads)); - int workspaceSize = numBlocks * numBins; - auto tmp = NDArrayFactory::create('c', {workspaceSize}, context); - - histogramKernel<<getCudaStream()>>>(xBuffer, dxShapeInfo, zBuffer, zShapeInfo, tmp.specialBuffer(), context->getReductionPointer(), numBins, reinterpret_cast(min_val), reinterpret_cast(max_val)); - - cudaStreamSynchronize(*context->getCudaStream()); - } - - void histogramHelper(sd::LaunchContext *context, NDArray &input, NDArray &output) { - Nd4jLong numBins = output.lengthOf(); - NDArray::registerSpecialUse({&output}, {&input}); - - auto min_val = input.reduceNumber(reduce::SameOps::Min); - auto max_val = input.reduceNumber(reduce::SameOps::Max); -// min_val.printIndexedBuffer("MIN"); -// max_val.printIndexedBuffer("MAX"); - BUILD_DOUBLE_SELECTOR(input.dataType(), output.dataType(), histogram_, (context, input.specialBuffer(), input.shapeInfo(), input.specialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), numBins, min_val.specialBuffer(), max_val.specialBuffer()), LIBND4J_TYPES, INTEGER_TYPES); - NDArray::registerSpecialUse({&output}, {&input}); - } +namespace ops { +namespace helpers { +template +void _CUDA_G histogramKernel(void *xBuffer, const Nd4jLong *xShapeInfo, + void *zBuffer, const Nd4jLong *zShapeInfo, + void *allocationPointer, void *reductionPointer, + Nd4jLong numBins, X *min_val, X *max_val) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + auto dx = reinterpret_cast(xBuffer); + auto result = reinterpret_cast(zBuffer); + + __shared__ Z *bins; + __shared__ int length; + __shared__ Z *reductor; + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + bins = (Z *)shmem; + reductor = ((Z *)allocationPointer) + (numBins * blockIdx.x); + + length = shape::length(xShapeInfo); + } + __syncthreads(); + + X binSize = X((*max_val - *min_val) / numBins); + + // nullify bins + for (int e = threadIdx.x; e < numBins; e += blockDim.x) { + bins[e] = (Z)0; + } + __syncthreads(); + + for (int e = tid; e < length; e += blockDim.x * gridDim.x) { + int idx = int((dx[e] - *min_val) / binSize); + idx = math::nd4j_max(idx, 0); // atomicMax(&idx, 0);//atomicMax(&idx, 0); + idx = math::nd4j_min( + idx, int(numBins - 1)); // atomicMin(&idx, int(numBins - 1)); + sd::math::atomics::nd4j_atomicAdd(&bins[idx], (Z)1); + } + __syncthreads(); + // at this point all bins in shared memory are calculated, so we aggregate + // them now via threadfence trick + + // transfer shared memory to reduction memory + if (gridDim.x > 1) { + unsigned int *tc = (unsigned int *)reductionPointer; + __shared__ bool amLast; + + for (int e = threadIdx.x; e < numBins; e += blockDim.x) { + reductor[e] = bins[e]; + } + __threadfence(); + __syncthreads(); + + if (threadIdx.x == 0) { + unsigned int ticket = atomicInc(&tc[16384], gridDim.x); + amLast = (ticket == gridDim.x - 1); + } + __syncthreads(); + + if (amLast) { + tc[16384] = 0; + + // nullify shared memory for future accumulation + for (int e = threadIdx.x; e < numBins; e += blockDim.x) { + bins[e] = (Z)0; + } + + // accumulate reduced bins + for (int r = 0; r < gridDim.x; r++) { + Z *ptrBuf = ((Z *)allocationPointer) + (r * numBins); + + for (int e = threadIdx.x; e < numBins; e += blockDim.x) { + math::atomics::nd4j_atomicAdd(&bins[e], ptrBuf[e]); } + } + __syncthreads(); + + // write them out to Z + for (int e = threadIdx.x; e < numBins; e += blockDim.x) { + result[e] = bins[e]; + } + } + } else { + // if there's only 1 block - just write away data + for (int e = threadIdx.x; e < numBins; e += blockDim.x) { + result[e] = bins[e]; } -} \ No newline at end of file + } +} + +template +static void histogram_(sd::LaunchContext *context, void *xBuffer, + const Nd4jLong *xShapeInfo, const Nd4jLong *dxShapeInfo, + void *zBuffer, const Nd4jLong *zShapeInfo, + Nd4jLong numBins, void *min_val, void *max_val) { + int numThreads = 256; + int numBlocks = sd::math::nd4j_max( + 256, sd::math::nd4j_min(1, shape::length(xShapeInfo) / numThreads)); + int workspaceSize = numBlocks * numBins; + auto tmp = NDArrayFactory::create('c', {workspaceSize}, context); + + histogramKernel + <<getCudaStream()>>>( + xBuffer, dxShapeInfo, zBuffer, zShapeInfo, tmp.specialBuffer(), + context->getReductionPointer(), numBins, + reinterpret_cast(min_val), reinterpret_cast(max_val)); + + cudaStreamSynchronize(*context->getCudaStream()); +} + +void histogramHelper(sd::LaunchContext *context, NDArray &input, + NDArray &output) { + Nd4jLong numBins = output.lengthOf(); + NDArray::registerSpecialUse({&output}, {&input}); + + auto min_val = input.reduceNumber(reduce::SameOps::Min); + auto max_val = input.reduceNumber(reduce::SameOps::Max); + // min_val.printIndexedBuffer("MIN"); + // max_val.printIndexedBuffer("MAX"); + BUILD_DOUBLE_SELECTOR(input.dataType(), output.dataType(), histogram_, + (context, input.specialBuffer(), input.shapeInfo(), + input.specialShapeInfo(), output.specialBuffer(), + output.specialShapeInfo(), numBins, + min_val.specialBuffer(), max_val.specialBuffer()), + LIBND4J_TYPES, INTEGER_TYPES); + NDArray::registerSpecialUse({&output}, {&input}); +} +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/histogramFixedWidth.cu b/libnd4j/include/ops/declarable/helpers/cuda/histogramFixedWidth.cu index c6041b33b0e5..fa8357aedc26 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/histogramFixedWidth.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/histogramFixedWidth.cu @@ -18,104 +18,115 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 31.08.2018 // -#include #include #include +#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { /////////////////////////////////////////////////////////////////// -template -__global__ static void histogramFixedWidthCuda( const void* vx, const Nd4jLong* xShapeInfo, - void* vz, const Nd4jLong* zShapeInfo, - const X leftEdge, const X rightEdge) { - - const auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - - __shared__ Nd4jLong xLen, zLen, totalThreads, nbins; - __shared__ X binWidth, secondEdge, lastButOneEdge; +template +__global__ static void histogramFixedWidthCuda( + const void* vx, const Nd4jLong* xShapeInfo, void* vz, + const Nd4jLong* zShapeInfo, const X leftEdge, const X rightEdge) { + const auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); - if (threadIdx.x == 0) { + __shared__ Nd4jLong xLen, zLen, totalThreads, nbins; + __shared__ X binWidth, secondEdge, lastButOneEdge; - xLen = shape::length(xShapeInfo); - nbins = shape::length(zShapeInfo); // nbins = zLen - totalThreads = gridDim.x * blockDim.x; + if (threadIdx.x == 0) { + xLen = shape::length(xShapeInfo); + nbins = shape::length(zShapeInfo); // nbins = zLen + totalThreads = gridDim.x * blockDim.x; - binWidth = (rightEdge - leftEdge ) / nbins; - secondEdge = leftEdge + binWidth; - lastButOneEdge = rightEdge - binWidth; - } + binWidth = (rightEdge - leftEdge) / nbins; + secondEdge = leftEdge + binWidth; + lastButOneEdge = rightEdge - binWidth; + } - __syncthreads(); + __syncthreads(); - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - for (Nd4jLong i = tid; i < xLen; i += totalThreads) { + for (Nd4jLong i = tid; i < xLen; i += totalThreads) { + const X value = x[shape::getIndexOffset(i, xShapeInfo)]; - const X value = x[shape::getIndexOffset(i, xShapeInfo)]; + Nd4jLong zIndex; - Nd4jLong zIndex; + if (value < secondEdge) + zIndex = 0; + else if (value >= lastButOneEdge) + zIndex = nbins - 1; + else + zIndex = static_cast((value - leftEdge) / binWidth); - if(value < secondEdge) - zIndex = 0; - else if(value >= lastButOneEdge) - zIndex = nbins - 1; - else - zIndex = static_cast((value - leftEdge) / binWidth); - - sd::math::atomics::nd4j_atomicAdd(&z[shape::getIndexOffset(zIndex, zShapeInfo)], 1); - } + sd::math::atomics::nd4j_atomicAdd( + &z[shape::getIndexOffset(zIndex, zShapeInfo)], 1); + } } /////////////////////////////////////////////////////////////////// -template -__host__ static void histogramFixedWidthCudaLauncher(const cudaStream_t *stream, const NDArray& input, const NDArray& range, NDArray& output) { - - const X leftEdge = range.e(0); - const X rightEdge = range.e(1); - - histogramFixedWidthCuda<<<256, 256, 1024, *stream>>>(input.specialBuffer(), input.specialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), leftEdge, rightEdge); +template +__host__ static void histogramFixedWidthCudaLauncher(const cudaStream_t* stream, + const NDArray& input, + const NDArray& range, + NDArray& output) { + const X leftEdge = range.e(0); + const X rightEdge = range.e(1); + + histogramFixedWidthCuda<<<256, 256, 1024, *stream>>>( + input.specialBuffer(), input.specialShapeInfo(), output.specialBuffer(), + output.specialShapeInfo(), leftEdge, rightEdge); } //////////////////////////////////////////////////////////////////////// -void histogramFixedWidth(sd::LaunchContext* context, const NDArray& input, const NDArray& range, NDArray& output) { +void histogramFixedWidth(sd::LaunchContext* context, const NDArray& input, + const NDArray& range, NDArray& output) { + // firstly initialize output with zeros + output.nullify(); - // firstly initialize output with zeros - output.nullify(); + PointersManager manager(context, "histogramFixedWidth"); - PointersManager manager(context, "histogramFixedWidth"); + NDArray::prepareSpecialUse({&output}, {&input}); + BUILD_DOUBLE_SELECTOR(input.dataType(), output.dataType(), + histogramFixedWidthCudaLauncher, + (context->getCudaStream(), input, range, output), + LIBND4J_TYPES, INDEXING_TYPES); + NDArray::registerSpecialUse({&output}, {&input}); - NDArray::prepareSpecialUse({&output}, {&input}); - BUILD_DOUBLE_SELECTOR(input.dataType(), output.dataType(), histogramFixedWidthCudaLauncher, (context->getCudaStream(), input, range, output), LIBND4J_TYPES, INDEXING_TYPES); - NDArray::registerSpecialUse({&output}, {&input}); - - manager.synchronize(); + manager.synchronize(); } - // template -// __global__ static void copyBuffers(Nd4jLong* destination, void const* source, Nd4jLong* sourceShape, Nd4jLong bufferLength) { +// __global__ static void copyBuffers(Nd4jLong* destination, void const* +// source, Nd4jLong* sourceShape, Nd4jLong bufferLength) { // const auto tid = blockIdx.x * gridDim.x + threadIdx.x; // const auto step = gridDim.x * blockDim.x; // for (int t = tid; t < bufferLength; t += step) { -// destination[t] = reinterpret_cast(source)[shape::getIndexOffset(t, sourceShape)]; +// destination[t] = reinterpret_cast(source)[shape::getIndexOffset(t, sourceShape)]; // } // } // template -// __global__ static void returnBuffers(void* destination, Nd4jLong const* source, Nd4jLong* destinationShape, Nd4jLong bufferLength) { +// __global__ static void returnBuffers(void* destination, Nd4jLong const* +// source, Nd4jLong* destinationShape, Nd4jLong bufferLength) { // const auto tid = blockIdx.x * gridDim.x + threadIdx.x; // const auto step = gridDim.x * blockDim.x; // for (int t = tid; t < bufferLength; t += step) { -// reinterpret_cast(destination)[shape::getIndexOffset(t, destinationShape)] = source[t]; +// reinterpret_cast(destination)[shape::getIndexOffset(t, +// destinationShape)] = source[t]; // } // } // template -// static __global__ void histogramFixedWidthKernel(void* outputBuffer, Nd4jLong outputLength, void const* inputBuffer, Nd4jLong* inputShape, Nd4jLong inputLength, double const leftEdge, double binWidth, double secondEdge, double lastButOneEdge) { +// static __global__ void histogramFixedWidthKernel(void* outputBuffer, +// Nd4jLong outputLength, void const* inputBuffer, Nd4jLong* inputShape, +// Nd4jLong inputLength, double const leftEdge, double binWidth, double +// secondEdge, double lastButOneEdge) { // __shared__ T const* x; // __shared__ Nd4jLong* z; // output buffer @@ -131,7 +142,8 @@ void histogramFixedWidth(sd::LaunchContext* context, const NDArray& input, const // for(auto i = tid; i < inputLength; i += step) { // const T value = x[shape::getIndexOffset(i, inputShape)]; -// Nd4jLong currInd = static_cast((value - leftEdge) / binWidth); +// Nd4jLong currInd = static_cast((value - leftEdge) / +// binWidth); // if(value < secondEdge) // currInd = 0; @@ -141,9 +153,9 @@ void histogramFixedWidth(sd::LaunchContext* context, const NDArray& input, const // } // } - // template -// void histogramFixedWidth_(sd::LaunchContext * context, const NDArray& input, const NDArray& range, NDArray& output) { +// void histogramFixedWidth_(sd::LaunchContext * context, const NDArray& +// input, const NDArray& range, NDArray& output) { // const int nbins = output.lengthOf(); // auto stream = context->getCudaStream(); // // firstly initialize output with zeros @@ -161,16 +173,23 @@ void histogramFixedWidth(sd::LaunchContext* context, const NDArray& input, const // const double secondEdge = leftEdge + binWidth; // double lastButOneEdge = rightEdge - binWidth; // Nd4jLong* outputBuffer; -// cudaError_t err = cudaMalloc(&outputBuffer, output.lengthOf() * sizeof(Nd4jLong)); -// if (err != 0) -// throw cuda_exception::build("helpers::histogramFixedWidth: Cannot allocate memory for output", err); -// copyBuffers<<<256, 512, 8192, *stream>>>(outputBuffer, output.specialBuffer(), output.special(), output.lengthOf()); -// histogramFixedWidthKernel<<<256, 512, 8192, *stream>>>(outputBuffer, output.lengthOf(), input.specialBuffer(), input.special(), input.lengthOf(), leftEdge, binWidth, secondEdge, lastButOneEdge); -// returnBuffers<<<256, 512, 8192, *stream>>>(output.specialBuffer(), outputBuffer, output.special(), output.lengthOf()); +// cudaError_t err = cudaMalloc(&outputBuffer, output.lengthOf() * +// sizeof(Nd4jLong)); if (err != 0) +// throw cuda_exception::build("helpers::histogramFixedWidth: Cannot +// allocate memory for output", err); +// copyBuffers<<<256, 512, 8192, *stream>>>(outputBuffer, +// output.specialBuffer(), output.special(), +// output.lengthOf()); histogramFixedWidthKernel<<<256, 512, 8192, +// *stream>>>(outputBuffer, output.lengthOf(), input.specialBuffer(), +// input.special(), input.lengthOf(), leftEdge, binWidth, +// secondEdge, lastButOneEdge); returnBuffers<<<256, 512, +// 8192, *stream>>>(output.specialBuffer(), outputBuffer, +// output.special(), output.lengthOf()); // //cudaSyncStream(*stream); // err = cudaFree(outputBuffer); // if (err != 0) -// throw cuda_exception::build("helpers::histogramFixedWidth: Cannot deallocate memory for output buffer", err); +// throw cuda_exception::build("helpers::histogramFixedWidth: Cannot +// deallocate memory for output buffer", err); // output.tickWriteDevice(); // //#pragma omp parallel for schedule(guided) // // for(Nd4jLong i = 0; i < input.lengthOf(); ++i) { @@ -182,20 +201,27 @@ void histogramFixedWidth(sd::LaunchContext* context, const NDArray& input, const // // output.p(0, output.e(0) + 1); // // else if(value >= lastButOneEdge) // //#pragma omp critical -// // output.p(nbins-1, output.e(nbins-1) + 1); +// // output.p(nbins-1, output.e(nbins-1) + +// 1); // // else { -// // Nd4jLong currInd = static_cast((value - leftEdge) / binWidth); +// // Nd4jLong currInd = static_cast((value - leftEdge) +// / binWidth); // //#pragma omp critical -// // output.p(currInd, output.e(currInd) + 1); +// // output.p(currInd, output.e(currInd) + +// 1); // // } // // } // } -// void histogramFixedWidth(sd::LaunchContext * context, const NDArray& input, const NDArray& range, NDArray& output) { -// BUILD_SINGLE_SELECTOR(input.dataType(), histogramFixedWidth_, (context, input, range, output), LIBND4J_TYPES); +// void histogramFixedWidth(sd::LaunchContext * context, const NDArray& +// input, const NDArray& range, NDArray& output) { +// BUILD_SINGLE_SELECTOR(input.dataType(), histogramFixedWidth_, +// (context, input, range, output), LIBND4J_TYPES); // } -// BUILD_SINGLE_TEMPLATE(template void histogramFixedWidth_, (sd::LaunchContext * context, const NDArray& input, const NDArray& range, NDArray& output), LIBND4J_TYPES); +// BUILD_SINGLE_TEMPLATE(template void histogramFixedWidth_, +// (sd::LaunchContext * context, const NDArray& input, const NDArray& range, +// NDArray& output), LIBND4J_TYPES); -} -} -} \ No newline at end of file +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/im2col.cu b/libnd4j/include/ops/declarable/helpers/cuda/im2col.cu index 08f5959e82c5..ad2992fb7df3 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/im2col.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/im2col.cu @@ -18,91 +18,104 @@ // Created by raver119 on 30.11.17. // -#include #include +#include namespace sd { namespace ops { namespace helpers { - ////////////////////////////////////////////////////////////////////////// // input [bS, iC, iH, iW] is convoluted to output [bS, iC, kH, kW, oH, oW] template __global__ static void im2colCuda(const void *image, void *columns, - const Nd4jLong *imShapeInfo, const Nd4jLong *colShapeInfo, - const int sH, const int sW, - const int pH, const int pW, + const Nd4jLong *imShapeInfo, + const Nd4jLong *colShapeInfo, const int sH, + const int sW, const int pH, const int pW, const int dH, const int dW, const double zeroPadValD) { + T zeroPadVal = + static_cast(zeroPadValD); // Value to use when value is padding. + // Usually 0 but not always + const auto im = reinterpret_cast(image); + auto col = reinterpret_cast(columns); - T zeroPadVal = static_cast(zeroPadValD); //Value to use when value is padding. Usually 0 but not always - const auto im = reinterpret_cast(image); - auto col = reinterpret_cast(columns); - - __shared__ Nd4jLong colLen, iH, iW; - __shared__ int imRank, colRank, *sharedMem; + __shared__ Nd4jLong colLen, iH, iW; + __shared__ int imRank, colRank, *sharedMem; - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - sharedMem = reinterpret_cast(shmem); + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sharedMem = reinterpret_cast(shmem); - colRank = 6; - imRank = 4; + colRank = 6; + imRank = 4; - colLen = shape::length(colShapeInfo); + colLen = shape::length(colShapeInfo); - iH = imShapeInfo[3]; - iW = imShapeInfo[4]; - } - __syncthreads(); + iH = imShapeInfo[3]; + iW = imShapeInfo[4]; + } + __syncthreads(); - const auto colInd = threadIdx.x + blockIdx.x * blockDim.x; + const auto colInd = threadIdx.x + blockIdx.x * blockDim.x; - if(colInd >= colLen) - return; + if (colInd >= colLen) return; - auto coords = sharedMem + threadIdx.x * colRank; + auto coords = sharedMem + threadIdx.x * colRank; - shape::index2coords(colInd, colShapeInfo, coords); + shape::index2coords(colInd, colShapeInfo, coords); - const auto colOffset = shape::getOffset(colShapeInfo, coords); + const auto colOffset = shape::getOffset(colShapeInfo, coords); - coords[2] = (-pH + coords[2] * dH) + coords[4] * sH; // imH - coords[3] = (-pW + coords[3] * dW) + coords[5] * sW; // imW + coords[2] = (-pH + coords[2] * dH) + coords[4] * sH; // imH + coords[3] = (-pW + coords[3] * dW) + coords[5] * sW; // imW - if (static_cast(coords[2]) >= static_cast(iH) || static_cast(coords[3]) >= static_cast(iW)) - col[colOffset] = zeroPadVal; - else - col[colOffset] = im[shape::getOffset(imShapeInfo, coords)]; + if (static_cast(coords[2]) >= static_cast(iH) || + static_cast(coords[3]) >= static_cast(iW)) + col[colOffset] = zeroPadVal; + else + col[colOffset] = im[shape::getOffset(imShapeInfo, coords)]; } - ////////////////////////////////////////////////////////////////////////// template -static void im2colCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, sd::LaunchContext & context, const void *image, void *columns, const Nd4jLong *imShapeInfo, const Nd4jLong *colShapeInfo, int sH, int sW, int pH, int pW, int dH, int dW, double zeroPadVal) { - im2colCuda<<>>(image, columns, imShapeInfo, colShapeInfo, sH, sW, pH, pW, dH, dW, zeroPadVal); +static void im2colCudaLauncher(const int blocksPerGrid, + const int threadsPerBlock, + sd::LaunchContext &context, const void *image, + void *columns, const Nd4jLong *imShapeInfo, + const Nd4jLong *colShapeInfo, int sH, int sW, + int pH, int pW, int dH, int dW, + double zeroPadVal) { + im2colCuda<<>>(image, columns, imShapeInfo, + colShapeInfo, sH, sW, pH, pW, dH, + dW, zeroPadVal); } ////////////////////////////////////////////////////////////////////////// -void im2col(sd::LaunchContext& context, const NDArray& image, NDArray& columns, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const NDArray& arrZeroPadVal) { - - PointersManager manager(&context, "im2col"); - - const int threadsPerBlock = 512; - const int blocksPerGrid = (columns.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - - NDArray::prepareSpecialUse({&columns}, {&image}); - BUILD_SINGLE_SELECTOR(columns.dataType(), im2colCudaLauncher, (blocksPerGrid, threadsPerBlock, context, image.specialBuffer(), columns.specialBuffer(), image.specialShapeInfo(), columns.specialShapeInfo(), sH, sW, pH, pW, dH, dW, arrZeroPadVal.e(0)), FLOAT_TYPES); - NDArray::registerSpecialUse({&columns}, {&image}); - - manager.synchronize(); +void im2col(sd::LaunchContext &context, const NDArray &image, NDArray &columns, + const int kH, const int kW, const int sH, const int sW, + const int pH, const int pW, const int dH, const int dW, + const NDArray &arrZeroPadVal) { + PointersManager manager(&context, "im2col"); + + const int threadsPerBlock = 512; + const int blocksPerGrid = + (columns.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + + NDArray::prepareSpecialUse({&columns}, {&image}); + BUILD_SINGLE_SELECTOR(columns.dataType(), im2colCudaLauncher, + (blocksPerGrid, threadsPerBlock, context, + image.specialBuffer(), columns.specialBuffer(), + image.specialShapeInfo(), columns.specialShapeInfo(), + sH, sW, pH, pW, dH, dW, arrZeroPadVal.e(0)), + FLOAT_TYPES); + NDArray::registerSpecialUse({&columns}, {&image}); + + manager.synchronize(); } - - - - -} -} -} \ No newline at end of file +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/image_draw_bounding_boxes.cu b/libnd4j/include/ops/declarable/helpers/cuda/image_draw_bounding_boxes.cu index 47319f10041e..0a5c4be6becf 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/image_draw_bounding_boxes.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/image_draw_bounding_boxes.cu @@ -17,164 +17,193 @@ // // @author sgazeos@gmail.com // -#include #include +#include namespace sd { namespace ops { namespace helpers { - typedef NDArray ColorTable_t; - static NDArray DefaultColorTable(int depth, sd::LaunchContext* context) { - //std::vector> colorTable; - const Nd4jLong kDefaultTableLength = 10; - const Nd4jLong kDefaultChannelLength = 4; - NDArray colorTable('c', {kDefaultTableLength, kDefaultChannelLength}, { - 1,1,0,1, // yellow - 0, 0, 1, 1, // 1: blue - 1, 0, 0, 1, // 2: red - 0, 1, 0, 1, // 3: lime - 0.5, 0, 0.5, 1, // 4: purple - 0.5, 0.5, 0, 1, // 5: olive - 0.5, 0, 0, 1, // 6: maroon - 0, 0, 0.5, 1, // 7: navy blue - 0, 1, 1, 1, // 8: aqua - 1, 0, 1, 1 // 9: fuchsia - }, DataType::FLOAT32, context); - - if (depth == 1) { - colorTable.assign(1.f); // all to white when black and white colors - } - return colorTable; - } - - template - static __global__ void drawBoundingBoxesKernel(T const* images, const Nd4jLong* imagesShape, - float const* boxes, const Nd4jLong* boxesShape, - float const* colorTable, const Nd4jLong* colorTableShape, - T* output, const Nd4jLong* outputShape, - Nd4jLong batchSize, Nd4jLong width, Nd4jLong height, - Nd4jLong channels, Nd4jLong boxSize, Nd4jLong colorTableLen) { - - for (auto batch = blockIdx.x; batch < (int)batchSize; batch += gridDim.x) { // loop by batch - for (auto boxIndex = 0; boxIndex < boxSize; ++boxIndex) { - // box with shape - //auto internalBox = &boxes[b * colorSetSize * 4 + c * 4];//(*boxes)(b, {0})(c, {0});//internalBoxes->at(c); - auto colorIndex = boxIndex % colorTableLen;//colorSet->at(c); -// auto rowStart = sd::math::nd4j_max(Nd4jLong (0), Nd4jLong ((height - 1) * internalBox[0])); -// auto rowEnd = sd::math::nd4j_min(Nd4jLong (height - 1), Nd4jLong ((height - 1) * internalBox[2])); -// auto colStart = sd::math::nd4j_max(Nd4jLong (0), Nd4jLong ((width - 1) * internalBox[1])); -// auto colEnd = sd::math::nd4j_min(Nd4jLong(width - 1), Nd4jLong ((width - 1) * internalBox[3])); - Nd4jLong indices0[] = {batch, boxIndex, 0}; - Nd4jLong indices1[] = {batch, boxIndex, 1}; - Nd4jLong indices2[] = {batch, boxIndex, 2}; - Nd4jLong indices3[] = {batch, boxIndex, 3}; - auto rowStart = Nd4jLong ((height - 1) * boxes[shape::getOffset(boxesShape, indices0, 0)]); - auto rowStartBound = sd::math::nd4j_max(Nd4jLong (0), rowStart); - auto rowEnd = Nd4jLong ((height - 1) * boxes[shape::getOffset(boxesShape, indices2, 0)]); - auto rowEndBound = sd::math::nd4j_min(Nd4jLong (height - 1), rowEnd); - auto colStart = Nd4jLong ((width - 1) * boxes[shape::getOffset(boxesShape, indices1, 0)]); - auto colStartBound = sd::math::nd4j_max(Nd4jLong (0), colStart); - auto colEnd = Nd4jLong ((width - 1) * boxes[shape::getOffset(boxesShape, indices3, 0)]); - auto colEndBound = sd::math::nd4j_min(Nd4jLong(width - 1), colEnd); - if (rowStart > rowEnd || colStart > colEnd) { -// printf("helpers::drawBoundingBoxesFunctor: Bounding box (%lld, %lld, %lld, %lld) is inverted " -// "and will not be drawn\n", rowStart, colStart, rowEnd, colEnd); - continue; - } - if (rowStart >= height || rowEnd < 0 || colStart >= width || - colEnd < 0) { -// printf("helpers::drawBoundingBoxesFunctor: Bounding box (%lld, %lld, %lld, %lld) is completely " -// "outside the image and not be drawn\n", rowStart, colStart, rowEnd, colEnd); - continue; - } - - // Draw upper line - if (rowStart >= 0) { - for (auto j = colStartBound + threadIdx.x; j <= colEndBound; j += blockDim.x) - for (auto c = 0; c < channels; c++) { - Nd4jLong zPos[] = {batch, rowStart, j, c}; - Nd4jLong cPos[] = {colorIndex, c}; - auto cIndex = shape::getOffset(colorTableShape, cPos, 0); - auto zIndex = shape::getOffset(outputShape, zPos, 0); - output[zIndex] = (T)colorTable[cIndex]; - } - } - // Draw bottom line. - if (rowEnd < height) { - for (auto j = colStartBound + threadIdx.x; j <= colEndBound; j += blockDim.x) - for (auto c = 0; c < channels; c++) { - Nd4jLong zPos[] = {batch, rowEnd, j, c}; - Nd4jLong cPos[] = {colorIndex, c}; - auto cIndex = shape::getOffset(colorTableShape, cPos, 0); - auto zIndex = shape::getOffset(outputShape, zPos, 0); - output[zIndex] = (T)colorTable[cIndex]; - } - } +typedef NDArray ColorTable_t; +static NDArray DefaultColorTable(int depth, sd::LaunchContext* context) { + // std::vector> colorTable; + const Nd4jLong kDefaultTableLength = 10; + const Nd4jLong kDefaultChannelLength = 4; + NDArray colorTable('c', {kDefaultTableLength, kDefaultChannelLength}, + { + 1, 1, 0, 1, // yellow + 0, 0, 1, 1, // 1: blue + 1, 0, 0, 1, // 2: red + 0, 1, 0, 1, // 3: lime + 0.5, 0, 0.5, 1, // 4: purple + 0.5, 0.5, 0, 1, // 5: olive + 0.5, 0, 0, 1, // 6: maroon + 0, 0, 0.5, 1, // 7: navy blue + 0, 1, 1, 1, // 8: aqua + 1, 0, 1, 1 // 9: fuchsia + }, + DataType::FLOAT32, context); - // Draw left line. - if (colStart >= 0) { - for (auto i = rowStartBound + threadIdx.x; i <= rowEndBound; i += blockDim.x) - for (auto c = 0; c < channels; c++) { - Nd4jLong zPos[] = {batch, i, colStart, c}; - Nd4jLong cPos[] = {colorIndex, c}; - auto cIndex = shape::getOffset(colorTableShape, cPos, 0); - auto zIndex = shape::getOffset(outputShape, zPos, 0); - output[zIndex] = (T)colorTable[cIndex]; - } - } - // Draw right line. - if (colEnd < width) { - for (auto i = rowStartBound + threadIdx.x; i <= rowEndBound; i += blockDim.x) - for (auto c = 0; c < channels; c++) { - Nd4jLong zPos[] = {batch, i, colEnd, c}; - Nd4jLong cPos[] = {colorIndex, c}; - auto cIndex = shape::getOffset(colorTableShape, cPos, 0); - auto zIndex = shape::getOffset(outputShape, zPos, 0); - output[zIndex] = (T)colorTable[cIndex]; - } - } - } - } + if (depth == 1) { + colorTable.assign(1.f); // all to white when black and white colors + } + return colorTable; +} - } +template +static __global__ void drawBoundingBoxesKernel( + T const* images, const Nd4jLong* imagesShape, float const* boxes, + const Nd4jLong* boxesShape, float const* colorTable, + const Nd4jLong* colorTableShape, T* output, const Nd4jLong* outputShape, + Nd4jLong batchSize, Nd4jLong width, Nd4jLong height, Nd4jLong channels, + Nd4jLong boxSize, Nd4jLong colorTableLen) { + for (auto batch = blockIdx.x; batch < (int)batchSize; + batch += gridDim.x) { // loop by batch + for (auto boxIndex = 0; boxIndex < boxSize; ++boxIndex) { + // box with shape + // auto internalBox = &boxes[b * colorSetSize * 4 + c * 4];//(*boxes)(b, + // {0})(c, {0});//internalBoxes->at(c); + auto colorIndex = boxIndex % colorTableLen; // colorSet->at(c); + // auto rowStart = sd::math::nd4j_max(Nd4jLong (0), + // Nd4jLong ((height - 1) * internalBox[0])); auto rowEnd = + // sd::math::nd4j_min(Nd4jLong (height - 1), Nd4jLong + // ((height - 1) * internalBox[2])); auto colStart = + // sd::math::nd4j_max(Nd4jLong (0), Nd4jLong ((width - 1) * + // internalBox[1])); auto colEnd = + // sd::math::nd4j_min(Nd4jLong(width - 1), Nd4jLong ((width + // - 1) * internalBox[3])); + Nd4jLong indices0[] = {batch, boxIndex, 0}; + Nd4jLong indices1[] = {batch, boxIndex, 1}; + Nd4jLong indices2[] = {batch, boxIndex, 2}; + Nd4jLong indices3[] = {batch, boxIndex, 3}; + auto rowStart = Nd4jLong( + (height - 1) * boxes[shape::getOffset(boxesShape, indices0, 0)]); + auto rowStartBound = sd::math::nd4j_max(Nd4jLong(0), rowStart); + auto rowEnd = Nd4jLong((height - 1) * + boxes[shape::getOffset(boxesShape, indices2, 0)]); + auto rowEndBound = sd::math::nd4j_min(Nd4jLong(height - 1), rowEnd); + auto colStart = Nd4jLong( + (width - 1) * boxes[shape::getOffset(boxesShape, indices1, 0)]); + auto colStartBound = sd::math::nd4j_max(Nd4jLong(0), colStart); + auto colEnd = Nd4jLong((width - 1) * + boxes[shape::getOffset(boxesShape, indices3, 0)]); + auto colEndBound = sd::math::nd4j_min(Nd4jLong(width - 1), colEnd); + if (rowStart > rowEnd || colStart > colEnd) { + // printf("helpers::drawBoundingBoxesFunctor: + // Bounding box (%lld, %lld, %lld, %lld) is inverted + // " + // "and will not be drawn\n", rowStart, + // colStart, rowEnd, colEnd); + continue; + } + if (rowStart >= height || rowEnd < 0 || colStart >= width || colEnd < 0) { + // printf("helpers::drawBoundingBoxesFunctor: + // Bounding box (%lld, %lld, %lld, %lld) is + // completely " + // "outside the image and not be + // drawn\n", rowStart, colStart, rowEnd, + // colEnd); + continue; + } - template - void drawBoundingBoxesH(sd::LaunchContext* context, NDArray const* images, NDArray const* boxes, NDArray const* colors, NDArray* output) { - auto batchSize = images->sizeAt(0); - auto height = images->sizeAt(1); - auto width = images->sizeAt(2); - auto channels = images->sizeAt(3); - auto stream = context->getCudaStream(); - auto boxSize = boxes->sizeAt(1); - NDArray colorsTable = DefaultColorTable(channels, context); - if ((colors != nullptr && colors->lengthOf() > 0)) { - colorsTable = *colors; - } + // Draw upper line + if (rowStart >= 0) { + for (auto j = colStartBound + threadIdx.x; j <= colEndBound; + j += blockDim.x) + for (auto c = 0; c < channels; c++) { + Nd4jLong zPos[] = {batch, rowStart, j, c}; + Nd4jLong cPos[] = {colorIndex, c}; + auto cIndex = shape::getOffset(colorTableShape, cPos, 0); + auto zIndex = shape::getOffset(outputShape, zPos, 0); + output[zIndex] = (T)colorTable[cIndex]; + } + } + // Draw bottom line. + if (rowEnd < height) { + for (auto j = colStartBound + threadIdx.x; j <= colEndBound; + j += blockDim.x) + for (auto c = 0; c < channels; c++) { + Nd4jLong zPos[] = {batch, rowEnd, j, c}; + Nd4jLong cPos[] = {colorIndex, c}; + auto cIndex = shape::getOffset(colorTableShape, cPos, 0); + auto zIndex = shape::getOffset(outputShape, zPos, 0); + output[zIndex] = (T)colorTable[cIndex]; + } + } - auto imagesBuf = images->getDataBuffer()->specialAsT(); - auto boxesBuf = boxes->getDataBuffer()->specialAsT(); // boxes should be float32 - auto colorsTableBuf = colorsTable.getDataBuffer()->specialAsT(); // color table is float32 - auto outputBuf = output->dataBuffer()->specialAsT(); - drawBoundingBoxesKernel<<<128, 128, 1024, *stream>>>(imagesBuf, images->specialShapeInfo(), - boxesBuf, boxes->specialShapeInfo(), colorsTableBuf, colorsTable.specialShapeInfo(), - outputBuf, output->specialShapeInfo(), batchSize, width, height, channels, boxSize, colorsTable.lengthOf()); + // Draw left line. + if (colStart >= 0) { + for (auto i = rowStartBound + threadIdx.x; i <= rowEndBound; + i += blockDim.x) + for (auto c = 0; c < channels; c++) { + Nd4jLong zPos[] = {batch, i, colStart, c}; + Nd4jLong cPos[] = {colorIndex, c}; + auto cIndex = shape::getOffset(colorTableShape, cPos, 0); + auto zIndex = shape::getOffset(outputShape, zPos, 0); + output[zIndex] = (T)colorTable[cIndex]; + } + } + // Draw right line. + if (colEnd < width) { + for (auto i = rowStartBound + threadIdx.x; i <= rowEndBound; + i += blockDim.x) + for (auto c = 0; c < channels; c++) { + Nd4jLong zPos[] = {batch, i, colEnd, c}; + Nd4jLong cPos[] = {colorIndex, c}; + auto cIndex = shape::getOffset(colorTableShape, cPos, 0); + auto zIndex = shape::getOffset(outputShape, zPos, 0); + output[zIndex] = (T)colorTable[cIndex]; + } + } } + } +} - void drawBoundingBoxesFunctor(sd::LaunchContext * context, NDArray* images, NDArray* boxes, NDArray* colors, NDArray* output) { - // images - batch of 3D images with BW (last dim = 1), RGB (last dim = 3) or RGBA (last dim = 4) channel set - // boxes - batch of 2D bounds with last dim (y_start, x_start, y_end, x_end) to compute i and j as - // floor((height - 1 ) * y_start) => rowStart, floor((height - 1) * y_end) => rowEnd - // floor((width - 1 ) * x_start) => colStart, floor((width - 1) * x_end) => colEnd - // height = images->sizeAt(1), width = images->sizeAt(2) - // colors - colors for each box given - // set up color for each box as frame - NDArray::prepareSpecialUse({output}, {images, boxes, colors}); - output->assign(images); - BUILD_SINGLE_SELECTOR(output->dataType(), drawBoundingBoxesH, (context, images, boxes, colors, output), FLOAT_TYPES); - NDArray::registerSpecialUse({output}, {images, boxes, colors}); - } +template +void drawBoundingBoxesH(sd::LaunchContext* context, NDArray const* images, + NDArray const* boxes, NDArray const* colors, + NDArray* output) { + auto batchSize = images->sizeAt(0); + auto height = images->sizeAt(1); + auto width = images->sizeAt(2); + auto channels = images->sizeAt(3); + auto stream = context->getCudaStream(); + auto boxSize = boxes->sizeAt(1); + NDArray colorsTable = DefaultColorTable(channels, context); + if ((colors != nullptr && colors->lengthOf() > 0)) { + colorsTable = *colors; + } + auto imagesBuf = images->getDataBuffer()->specialAsT(); + auto boxesBuf = + boxes->getDataBuffer()->specialAsT(); // boxes should be float32 + auto colorsTableBuf = colorsTable.getDataBuffer() + ->specialAsT(); // color table is float32 + auto outputBuf = output->dataBuffer()->specialAsT(); + drawBoundingBoxesKernel<<<128, 128, 1024, *stream>>>( + imagesBuf, images->specialShapeInfo(), boxesBuf, + boxes->specialShapeInfo(), colorsTableBuf, colorsTable.specialShapeInfo(), + outputBuf, output->specialShapeInfo(), batchSize, width, height, channels, + boxSize, colorsTable.lengthOf()); } + +void drawBoundingBoxesFunctor(sd::LaunchContext* context, NDArray* images, + NDArray* boxes, NDArray* colors, + NDArray* output) { + // images - batch of 3D images with BW (last dim = 1), RGB (last dim = 3) or + // RGBA (last dim = 4) channel set boxes - batch of 2D bounds with last dim + // (y_start, x_start, y_end, x_end) to compute i and j as floor((height - 1 ) + // * y_start) => rowStart, floor((height - 1) * y_end) => rowEnd floor((width + // - 1 ) * x_start) => colStart, floor((width - 1) * x_end) => colEnd height = + // images->sizeAt(1), width = images->sizeAt(2) colors - colors for each box + // given set up color for each box as frame + NDArray::prepareSpecialUse({output}, {images, boxes, colors}); + output->assign(images); + BUILD_SINGLE_SELECTOR(output->dataType(), drawBoundingBoxesH, + (context, images, boxes, colors, output), FLOAT_TYPES); + NDArray::registerSpecialUse({output}, {images, boxes, colors}); } -} + +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/image_resize.cu b/libnd4j/include/ops/declarable/helpers/cuda/image_resize.cu index 3365d5d62e94..cc349d98bc6d 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/image_resize.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/image_resize.cu @@ -33,52 +33,51 @@ limitations under the License. // @author George A. Shulinok // -#include #include #include +#include namespace sd { namespace ops { namespace helpers { - struct BilinearInterpolationData { - Nd4jLong bottomIndex; // Lower source index used in the interpolation - Nd4jLong topIndex; // Upper source index used in the interpolation - // 1-D linear iterpolation scale (see: - // https://en.wikipedia.org/wiki/Bilinear_interpolation) - double interpolarValue; - }; +struct BilinearInterpolationData { + Nd4jLong bottomIndex; // Lower source index used in the interpolation + Nd4jLong topIndex; // Upper source index used in the interpolation + // 1-D linear iterpolation scale (see: + // https://en.wikipedia.org/wiki/Bilinear_interpolation) + double interpolarValue; +}; // Older incorrect scaling method that causes all resizes to have a slight // translation leading to inconsistent results. For example, a flip then a // resize gives different results then a resize then a flip. - struct LegacyScaler { - _CUDA_HD LegacyScaler(){}; - inline _CUDA_HD float operator()(const int x, const float scale) const { - return static_cast(x) * scale; - } - }; +struct LegacyScaler { + _CUDA_HD LegacyScaler(){}; + inline _CUDA_HD float operator()(const int x, const float scale) const { + return static_cast(x) * scale; + } +}; // Half pixel scaler scales assuming that the pixel centers are at 0.5, i.e. the // floating point coordinates of the top,left pixel is 0.5,0.5. - struct HalfPixelScaler { - _CUDA_HD HalfPixelScaler(){}; - inline _CUDA_HD float operator()(const int x, const float scale) const { - // Note that we subtract 0.5 from the return value, as the existing bilinear - // sampling code etc assumes pixels are in the old coordinate system. - return (static_cast(x) + 0.5f) * scale - 0.5f; - } - }; - - - // Utility functions - // calculateResizeScale determines the float scaling factor. - inline float calculateResizeScale(Nd4jLong inSize, Nd4jLong outSize, - bool alignCorners) { - return (alignCorners && outSize > 1) - ? (inSize - 1) / static_cast(outSize - 1) - : inSize / static_cast(outSize); - } +struct HalfPixelScaler { + _CUDA_HD HalfPixelScaler(){}; + inline _CUDA_HD float operator()(const int x, const float scale) const { + // Note that we subtract 0.5 from the return value, as the existing bilinear + // sampling code etc assumes pixels are in the old coordinate system. + return (static_cast(x) + 0.5f) * scale - 0.5f; + } +}; + +// Utility functions +// calculateResizeScale determines the float scaling factor. +inline float calculateResizeScale(Nd4jLong inSize, Nd4jLong outSize, + bool alignCorners) { + return (alignCorners && outSize > 1) + ? (inSize - 1) / static_cast(outSize - 1) + : inSize / static_cast(outSize); +} //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// // computeInterpolationWeights kernel @@ -87,1264 +86,1508 @@ namespace helpers { // scale - input scale // interporationData - result // - template - static __global__ void computeInterpolationWeights(Nd4jLong outSize, - Nd4jLong inSize, - double scale, - Nd4jLong channels, - BilinearInterpolationData* interpolationData) { - interpolationData[outSize].bottomIndex = 0; - interpolationData[outSize].topIndex = 0; - auto tid = blockIdx.x * blockDim.x + threadIdx.x; - auto step = blockDim.x * gridDim.x; - Scaler scaler; - for (Nd4jLong i = outSize - tid; i >= 0; i -= step) { - double in = scaler(i, scale); -// interpolationData[i].bottomIndex = static_cast(in); -// interpolationData[i].topIndex = sd::math::nd4j_min(interpolationData[i].bottomIndex + 1, inSize - 1); -// interpolationData[i].interpolarValue = in - interpolationData[i].bottomIndex; - double const in_f = sd::math::p_floor(in); - double const in_c = sd::math::p_ceil(in); - interpolationData[i].bottomIndex = sd::math::nd4j_max(static_cast(in_f), (Nd4jLong)0LL);//static_cast(in); - interpolationData[i].topIndex = sd::math::nd4j_min(static_cast(in_c), inSize - 1); - interpolationData[i].interpolarValue = in - in_f; - - if (channels) { - math::atomics::nd4j_atomicMul(&interpolationData[i].bottomIndex, channels); - math::atomics::nd4j_atomicMul(&interpolationData[i].topIndex, channels); - } - } +template +static __global__ void computeInterpolationWeights( + Nd4jLong outSize, Nd4jLong inSize, double scale, Nd4jLong channels, + BilinearInterpolationData* interpolationData) { + interpolationData[outSize].bottomIndex = 0; + interpolationData[outSize].topIndex = 0; + auto tid = blockIdx.x * blockDim.x + threadIdx.x; + auto step = blockDim.x * gridDim.x; + Scaler scaler; + for (Nd4jLong i = outSize - tid; i >= 0; i -= step) { + double in = scaler(i, scale); + // interpolationData[i].bottomIndex = static_cast(in); + // interpolationData[i].topIndex = + // sd::math::nd4j_min(interpolationData[i].bottomIndex + 1, + // inSize - 1); interpolationData[i].interpolarValue = in - + // interpolationData[i].bottomIndex; + double const in_f = sd::math::p_floor(in); + double const in_c = sd::math::p_ceil(in); + interpolationData[i].bottomIndex = + sd::math::nd4j_max(static_cast(in_f), + (Nd4jLong)0LL); // static_cast(in); + interpolationData[i].topIndex = + sd::math::nd4j_min(static_cast(in_c), inSize - 1); + interpolationData[i].interpolarValue = in - in_f; + + if (channels) { + math::atomics::nd4j_atomicMul(&interpolationData[i].bottomIndex, + channels); + math::atomics::nd4j_atomicMul(&interpolationData[i].topIndex, channels); } + } +} //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// // resize image with bilinear interpolation algorithm // - static void resizeImage(sd::LaunchContext* context, NDArray const* images, Nd4jLong batchSize, Nd4jLong inHeight, Nd4jLong inWidth, Nd4jLong outHeight, - Nd4jLong outWidth, Nd4jLong channels, - BilinearInterpolationData* xs_, - BilinearInterpolationData* ys_, - NDArray* output); +static void resizeImage(sd::LaunchContext* context, NDArray const* images, + Nd4jLong batchSize, Nd4jLong inHeight, Nd4jLong inWidth, + Nd4jLong outHeight, Nd4jLong outWidth, + Nd4jLong channels, BilinearInterpolationData* xs_, + BilinearInterpolationData* ys_, NDArray* output); //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// // resize image with bilinear interpolation algorithm kernel // - template - static __global__ void resizeImageKernel(T const* input, Nd4jLong const* inputShape, Z* outputYptr, - Nd4jLong const* outputShape, Nd4jLong batchSize, Nd4jLong outWidth, Nd4jLong outHeight, Nd4jLong channels, - Nd4jLong inRowSize, Nd4jLong outRowSize, Nd4jLong inBatchNumValues, - BilinearInterpolationData* xs_, BilinearInterpolationData* ys_) { - - for (auto batch = blockIdx.x; batch < batchSize; batch += gridDim.x ) { // blockIdx.x as batch index - auto pX = input + batch * inBatchNumValues; - for (Nd4jLong y = threadIdx.x; y < outHeight; y += blockDim.x) { - const T* ys_input_lower_ptr = pX + ys_[y].bottomIndex * inRowSize; - const T* ys_input_upper_ptr = pX + ys_[y].topIndex * inRowSize; - double yVal = ys_[y].interpolarValue; - auto pZ = outputYptr + (batch * outHeight + y) * outRowSize; - for (Nd4jLong x = 0; x < outWidth; x++) { - auto xsBottom = xs_[x].bottomIndex; - auto xsTop = xs_[x].topIndex; - auto xVal = xs_[x].interpolarValue; - // process interpolation for all channels - for (int c = 0; c < channels; c++) { - Z topLeft(ys_input_lower_ptr[xsBottom + c]); - Z topRight(ys_input_lower_ptr[xsTop + c]); - Z bottomLeft(ys_input_upper_ptr[xsBottom + c]); - Z bottomRight(ys_input_upper_ptr[xsTop + c]); - Z top = topLeft + (topRight - topLeft) * xVal; - Z bottom = bottomLeft + (bottomRight - bottomLeft) * xVal; - Z resVal = Z(top + (bottom - top) * yVal); - pZ[x * channels + c] = resVal; - } - } - } +template +static __global__ void resizeImageKernel( + T const* input, Nd4jLong const* inputShape, Z* outputYptr, + Nd4jLong const* outputShape, Nd4jLong batchSize, Nd4jLong outWidth, + Nd4jLong outHeight, Nd4jLong channels, Nd4jLong inRowSize, + Nd4jLong outRowSize, Nd4jLong inBatchNumValues, + BilinearInterpolationData* xs_, BilinearInterpolationData* ys_) { + for (auto batch = blockIdx.x; batch < batchSize; + batch += gridDim.x) { // blockIdx.x as batch index + auto pX = input + batch * inBatchNumValues; + for (Nd4jLong y = threadIdx.x; y < outHeight; y += blockDim.x) { + const T* ys_input_lower_ptr = pX + ys_[y].bottomIndex * inRowSize; + const T* ys_input_upper_ptr = pX + ys_[y].topIndex * inRowSize; + double yVal = ys_[y].interpolarValue; + auto pZ = outputYptr + (batch * outHeight + y) * outRowSize; + for (Nd4jLong x = 0; x < outWidth; x++) { + auto xsBottom = xs_[x].bottomIndex; + auto xsTop = xs_[x].topIndex; + auto xVal = xs_[x].interpolarValue; + // process interpolation for all channels + for (int c = 0; c < channels; c++) { + Z topLeft(ys_input_lower_ptr[xsBottom + c]); + Z topRight(ys_input_lower_ptr[xsTop + c]); + Z bottomLeft(ys_input_upper_ptr[xsBottom + c]); + Z bottomRight(ys_input_upper_ptr[xsTop + c]); + Z top = topLeft + (topRight - topLeft) * xVal; + Z bottom = bottomLeft + (bottomRight - bottomLeft) * xVal; + Z resVal = Z(top + (bottom - top) * yVal); + pZ[x * channels + c] = resVal; } + } } + } +} //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// // resize image with - template - static void resizeImage_(sd::LaunchContext* context, NDArray const* images, Nd4jLong batchSize, Nd4jLong inHeight, Nd4jLong inWidth, Nd4jLong outHeight, - Nd4jLong outWidth, Nd4jLong channels, - BilinearInterpolationData* xs_, - BilinearInterpolationData* ys_, - NDArray* output) { - Nd4jLong inRowSize = inWidth * channels; - Nd4jLong inBatchNumValues = inHeight * inRowSize; - Nd4jLong outRowSize = outWidth * channels; - auto stream = context->getCudaStream(); - T const* pInput = images->getDataBuffer()->specialAsT(); //reinterpret_cast(images->specialBuffer()); // this works only with 'c' direction - F* pOutput = output->dataBuffer()->specialAsT();//reinterpret_cast(output->specialBuffer()); - dim3 batchSizeBlock(batchSize, 1, 1); - dim3 pictureBlock(outHeight, outWidth, channels); - resizeImageKernel<<<256, 256, 256, *stream>>>(pInput, images->specialShapeInfo(), pOutput, - output->specialShapeInfo(), batchSize, outWidth, outHeight, channels, inRowSize, outRowSize, - inBatchNumValues, xs_, ys_); - - auto err = cudaStreamSynchronize(*stream); - if (err != 0) { - throw cuda_exception::build("helpers::resizeImage_: Cannot synchronize kernel execution", err); - } - } +template +static void resizeImage_(sd::LaunchContext* context, NDArray const* images, + Nd4jLong batchSize, Nd4jLong inHeight, + Nd4jLong inWidth, Nd4jLong outHeight, + Nd4jLong outWidth, Nd4jLong channels, + BilinearInterpolationData* xs_, + BilinearInterpolationData* ys_, NDArray* output) { + Nd4jLong inRowSize = inWidth * channels; + Nd4jLong inBatchNumValues = inHeight * inRowSize; + Nd4jLong outRowSize = outWidth * channels; + auto stream = context->getCudaStream(); + T const* pInput = + images->getDataBuffer() + ->specialAsT(); // reinterpret_cast(images->specialBuffer()); // + // this works only with 'c' direction + F* pOutput = + output->dataBuffer() + ->specialAsT(); // reinterpret_cast(output->specialBuffer()); + dim3 batchSizeBlock(batchSize, 1, 1); + dim3 pictureBlock(outHeight, outWidth, channels); + resizeImageKernel<<<256, 256, 256, *stream>>>( + pInput, images->specialShapeInfo(), pOutput, output->specialShapeInfo(), + batchSize, outWidth, outHeight, channels, inRowSize, outRowSize, + inBatchNumValues, xs_, ys_); + + auto err = cudaStreamSynchronize(*stream); + if (err != 0) { + throw cuda_exception::build( + "helpers::resizeImage_: Cannot synchronize kernel execution", err); + } +} //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - template - static int resizeBilinearFunctor_(sd::LaunchContext* context, NDArray const* images, int const width, - int const height, bool const alignCorners, bool const halfPixelCenter, NDArray* output) { - const Nd4jLong batchSize = images->sizeAt(0); - const Nd4jLong inHeight = images->sizeAt(1); - const Nd4jLong inWidth = images->sizeAt(2); - const Nd4jLong channels = images->sizeAt(3); - - const Nd4jLong outHeight = output->sizeAt(1); - const Nd4jLong outWidth = output->sizeAt(2); - - // Handle no-op resizes efficiently. - if (outHeight == inHeight && outWidth == inWidth) { - output->assign(images); - return ND4J_STATUS_OK; - } - - float heightScale = calculateResizeScale(inHeight, outHeight, alignCorners); - float widthScale = calculateResizeScale(inWidth, outWidth, alignCorners); - - BilinearInterpolationData* xs_;// = xs.data(); - BilinearInterpolationData* ys_;// = xs.data(); - - cudaError_t err = cudaMalloc(&xs_, sizeof(BilinearInterpolationData) * (outWidth + 1)); - if (err != 0) { - throw cuda_exception::build("helpers::resize_image: Cannot allocate memory for vertical parts rectangulars", err); - } - - err = cudaMalloc(&ys_, sizeof(BilinearInterpolationData) * (outHeight + 1)); - if (err != 0) { - throw cuda_exception::build("helpers::resize_image: Cannot allocate memory for horizontal parts rectangulars", err); - } - auto stream = context->getCudaStream(); - // Compute the cached interpolation weights on the x and y dimensions. - if (halfPixelCenter) { - computeInterpolationWeights < - HalfPixelScaler ><<<256, 512, 512, *stream>>>(outHeight, inHeight, heightScale, 0, ys_); - computeInterpolationWeights < - HalfPixelScaler ><<<256, 512, 512, *stream>>>(outWidth, inWidth, widthScale, channels, xs_); - } - else { - computeInterpolationWeights < - LegacyScaler ><<<256, 512, 512, *stream>>>(outHeight, inHeight, heightScale, 0, ys_); - computeInterpolationWeights < - LegacyScaler ><<<256, 512, 512, *stream>>>(outWidth, inWidth, widthScale, channels, xs_); - } - printf("Input is %dx%d, Output is %dx%d\n", inHeight, inWidth, outHeight, outWidth); - NDArray::prepareSpecialUse({output}, {images}); - resizeImage_(context, images, batchSize, inHeight, inWidth, outHeight, outWidth, channels, xs_, ys_, output); - err = cudaStreamSynchronize(*stream); - NDArray::registerSpecialUse({output}, {images}); - - err = cudaFree(xs_); - if (err != 0) { - throw cuda_exception::build("helpers::resize_image: Cannot deallocate memory for vertical parts rectangulars", err); - } - - err = cudaFree(ys_); - if (err != 0) { - throw cuda_exception::build("helpers::resize_image: Cannot deallocate memory for horizontical parts rectangulars", err); - } - - return Status::OK(); - } +template +static int resizeBilinearFunctor_(sd::LaunchContext* context, + NDArray const* images, int const width, + int const height, bool const alignCorners, + bool const halfPixelCenter, NDArray* output) { + const Nd4jLong batchSize = images->sizeAt(0); + const Nd4jLong inHeight = images->sizeAt(1); + const Nd4jLong inWidth = images->sizeAt(2); + const Nd4jLong channels = images->sizeAt(3); + + const Nd4jLong outHeight = output->sizeAt(1); + const Nd4jLong outWidth = output->sizeAt(2); + + // Handle no-op resizes efficiently. + if (outHeight == inHeight && outWidth == inWidth) { + output->assign(images); + return ND4J_STATUS_OK; + } + + float heightScale = calculateResizeScale(inHeight, outHeight, alignCorners); + float widthScale = calculateResizeScale(inWidth, outWidth, alignCorners); + + BilinearInterpolationData* xs_; // = xs.data(); + BilinearInterpolationData* ys_; // = xs.data(); + + cudaError_t err = + cudaMalloc(&xs_, sizeof(BilinearInterpolationData) * (outWidth + 1)); + if (err != 0) { + throw cuda_exception::build( + "helpers::resize_image: Cannot allocate memory for vertical parts " + "rectangulars", + err); + } + + err = cudaMalloc(&ys_, sizeof(BilinearInterpolationData) * (outHeight + 1)); + if (err != 0) { + throw cuda_exception::build( + "helpers::resize_image: Cannot allocate memory for horizontal parts " + "rectangulars", + err); + } + auto stream = context->getCudaStream(); + // Compute the cached interpolation weights on the x and y dimensions. + if (halfPixelCenter) { + computeInterpolationWeights + <<<256, 512, 512, *stream>>>(outHeight, inHeight, heightScale, 0, ys_); + computeInterpolationWeights<<<256, 512, 512, *stream>>>( + outWidth, inWidth, widthScale, channels, xs_); + } else { + computeInterpolationWeights + <<<256, 512, 512, *stream>>>(outHeight, inHeight, heightScale, 0, ys_); + computeInterpolationWeights<<<256, 512, 512, *stream>>>( + outWidth, inWidth, widthScale, channels, xs_); + } + printf("Input is %dx%d, Output is %dx%d\n", inHeight, inWidth, outHeight, + outWidth); + NDArray::prepareSpecialUse({output}, {images}); + resizeImage_(context, images, batchSize, inHeight, inWidth, outHeight, + outWidth, channels, xs_, ys_, output); + err = cudaStreamSynchronize(*stream); + NDArray::registerSpecialUse({output}, {images}); + + err = cudaFree(xs_); + if (err != 0) { + throw cuda_exception::build( + "helpers::resize_image: Cannot deallocate memory for vertical parts " + "rectangulars", + err); + } + + err = cudaFree(ys_); + if (err != 0) { + throw cuda_exception::build( + "helpers::resize_image: Cannot deallocate memory for horizontical " + "parts rectangulars", + err); + } + + return Status::OK(); +} //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// // resize by interpolation nearest neighbor algorithm kernel // - template - static __global__ void resizeNeighborKernel(T const* input, Nd4jLong const* inputShape, T* output, Nd4jLong const* outputShape, - Nd4jLong batchSize, Nd4jLong inWidth, Nd4jLong inHeight, Nd4jLong outWidth, Nd4jLong outHeight, Nd4jLong channels, double widthScale, double heightScale, bool alignCorners, bool halfPixelCenters) { - - //for (int b = blockIdx.x; b < batchSize; b += gridDim.x) - if (blockIdx.x < batchSize) - { - auto b = blockIdx.x; - for (int y = threadIdx.x; y < outHeight; y += blockDim.x) { - auto posY = alignCorners ? static_cast(sd::math::p_round(halfPixelCenters?((float)y + 0.5f) * heightScale:(float)y * heightScale)) : static_cast(sd::math::p_floor( - halfPixelCenters?((float)y + 0.5f) * heightScale:(float)y * heightScale)); - Nd4jLong inY = sd::math::nd4j_min(posY, inHeight - 1); - if (halfPixelCenters) { - inY = sd::math::nd4j_max(0LL, inY); - } - - for (int x = threadIdx.y; x < outWidth; x += blockDim.y) { - auto posX = alignCorners ? static_cast(sd::math::p_round(halfPixelCenters?((float)x + 0.5f) * widthScale:(float)x * widthScale)) : static_cast(sd::math::p_floor( - halfPixelCenters?((float)x + 0.5f) * widthScale:(float)x * widthScale)); - Nd4jLong inX = sd::math::nd4j_min(posX, inWidth - 1); - if (halfPixelCenters) { - inX = sd::math::nd4j_max(0LL, inX); - } - - auto start = blockIdx.z * blockDim.z + threadIdx.z; - auto step = blockDim.z * gridDim.z; - - for (Nd4jLong e = start; e < channels; e += step) { - Nd4jLong posX[] = {b, inY, inX, e}; - Nd4jLong posZ[] = {b, y, x, e}; - auto xIndex = shape::getOffset(inputShape, posX); - auto zIndex = shape::getOffset(outputShape, posZ); - output[zIndex] = input[xIndex]; - } - } - } +template +static __global__ void resizeNeighborKernel( + T const* input, Nd4jLong const* inputShape, T* output, + Nd4jLong const* outputShape, Nd4jLong batchSize, Nd4jLong inWidth, + Nd4jLong inHeight, Nd4jLong outWidth, Nd4jLong outHeight, Nd4jLong channels, + double widthScale, double heightScale, bool alignCorners, + bool halfPixelCenters) { + // for (int b = blockIdx.x; b < batchSize; b += gridDim.x) + if (blockIdx.x < batchSize) { + auto b = blockIdx.x; + for (int y = threadIdx.x; y < outHeight; y += blockDim.x) { + auto posY = alignCorners + ? static_cast(sd::math::p_round( + halfPixelCenters ? ((float)y + 0.5f) * heightScale + : (float)y * heightScale)) + : static_cast(sd::math::p_floor( + halfPixelCenters ? ((float)y + 0.5f) * heightScale + : (float)y * heightScale)); + Nd4jLong inY = sd::math::nd4j_min(posY, inHeight - 1); + if (halfPixelCenters) { + inY = sd::math::nd4j_max(0LL, inY); + } + + for (int x = threadIdx.y; x < outWidth; x += blockDim.y) { + auto posX = alignCorners + ? static_cast(sd::math::p_round( + halfPixelCenters ? ((float)x + 0.5f) * widthScale + : (float)x * widthScale)) + : static_cast(sd::math::p_floor( + halfPixelCenters ? ((float)x + 0.5f) * widthScale + : (float)x * widthScale)); + Nd4jLong inX = sd::math::nd4j_min(posX, inWidth - 1); + if (halfPixelCenters) { + inX = sd::math::nd4j_max(0LL, inX); } + auto start = blockIdx.z * blockDim.z + threadIdx.z; + auto step = blockDim.z * gridDim.z; + + for (Nd4jLong e = start; e < channels; e += step) { + Nd4jLong posX[] = {b, inY, inX, e}; + Nd4jLong posZ[] = {b, y, x, e}; + auto xIndex = shape::getOffset(inputShape, posX); + auto zIndex = shape::getOffset(outputShape, posZ); + output[zIndex] = input[xIndex]; + } + } } + } +} //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// // resizeNeighborFunctor - main algorithm by nearest neighbor // - template - int resizeNeighborFunctor_(sd::LaunchContext* context, NDArray const* images, int const width, int const height, - bool const alignCorners, bool const halfPixelCenters, NDArray* output) { - const Nd4jLong batchSize = images->sizeAt(0); - const Nd4jLong inHeight = images->sizeAt(1); - const Nd4jLong inWidth = images->sizeAt(2); - const Nd4jLong channels = images->sizeAt(3); - - const Nd4jLong outHeight = output->sizeAt(1); - const Nd4jLong outWidth = output->sizeAt(2); - - // Handle no-op resizes efficiently. - if (outHeight == inHeight && outWidth == inWidth) { - output->assign(images); - return ND4J_STATUS_OK; - } - -// if ((alignCorners && inHeight < 2) || (inHeight < 1) || (outHeight < 1) || (alignCorners && outHeight < 2) || -// (alignCorners && inWidth < 2) || (inWidth < 1) || (outWidth < 1) || (center && outWidth < 2)) { -// // wrong input data -// nd4j_printf("image.resize_nearest_neighbor: Wrong input or output size to resize\n", ""); -// return ND4J_STATUS_BAD_ARGUMENTS; -// } -// float heightScale = alignCorners ? (inHeight - 1.f) / float(outHeight - 1.f) : (inHeight / float(outHeight)); -// float widthScale = alignCorners ? (inWidth - 1.f) / float(outWidth - 1.f) : (inWidth / float(outWidth)); - float heightScale = calculateResizeScale(inHeight, outHeight, alignCorners); - float widthScale = calculateResizeScale(inWidth, outWidth, alignCorners); - - auto imagesBuffer = images->getDataBuffer()->specialAsT();//reinterpret_cast(images->specialBuffer()); - auto outputBuffer = output->dataBuffer()->specialAsT();//reinterpret_cast(output->specialBuffer()); - auto stream = context->getCudaStream(); - - NDArray::prepareSpecialUse({output}, {images}); - resizeNeighborKernel<<>>(imagesBuffer, images->specialShapeInfo(), outputBuffer, output->specialShapeInfo(), - batchSize, inWidth, inHeight, outWidth, outHeight, channels, widthScale, heightScale, alignCorners, halfPixelCenters); - NDArray::registerSpecialUse({output}, {images}); - - return Status::OK(); - } +template +int resizeNeighborFunctor_(sd::LaunchContext* context, NDArray const* images, + int const width, int const height, + bool const alignCorners, bool const halfPixelCenters, + NDArray* output) { + const Nd4jLong batchSize = images->sizeAt(0); + const Nd4jLong inHeight = images->sizeAt(1); + const Nd4jLong inWidth = images->sizeAt(2); + const Nd4jLong channels = images->sizeAt(3); + + const Nd4jLong outHeight = output->sizeAt(1); + const Nd4jLong outWidth = output->sizeAt(2); + + // Handle no-op resizes efficiently. + if (outHeight == inHeight && outWidth == inWidth) { + output->assign(images); + return ND4J_STATUS_OK; + } + + // if ((alignCorners && inHeight < 2) || (inHeight < 1) || (outHeight < + // 1) || (alignCorners && outHeight < 2) || + // (alignCorners && inWidth < 2) || (inWidth < 1) || (outWidth < 1) + // || (center && outWidth < 2)) { + // // wrong input data + // nd4j_printf("image.resize_nearest_neighbor: Wrong input or + // output size to resize\n", ""); return ND4J_STATUS_BAD_ARGUMENTS; + // } + // float heightScale = alignCorners ? (inHeight - 1.f) / + // float(outHeight - 1.f) : (inHeight / float(outHeight)); float + // widthScale = alignCorners ? (inWidth - 1.f) / float(outWidth - 1.f) + // : (inWidth / float(outWidth)); + float heightScale = calculateResizeScale(inHeight, outHeight, alignCorners); + float widthScale = calculateResizeScale(inWidth, outWidth, alignCorners); + + auto imagesBuffer = + images->getDataBuffer() + ->specialAsT(); // reinterpret_cast(images->specialBuffer()); + auto outputBuffer = + output->dataBuffer() + ->specialAsT(); // reinterpret_cast(output->specialBuffer()); + auto stream = context->getCudaStream(); + + NDArray::prepareSpecialUse({output}, {images}); + resizeNeighborKernel<<>>( + imagesBuffer, images->specialShapeInfo(), outputBuffer, + output->specialShapeInfo(), batchSize, inWidth, inHeight, outWidth, + outHeight, channels, widthScale, heightScale, alignCorners, + halfPixelCenters); + NDArray::registerSpecialUse({output}, {images}); + + return Status::OK(); +} //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// // resizeImage - resize bilinear algorithm caller // - void resizeImage(sd::LaunchContext* context, NDArray const* images, Nd4jLong batchSize, Nd4jLong inHeight, - Nd4jLong inWidth, Nd4jLong outHeight, Nd4jLong outWidth, Nd4jLong channels, BilinearInterpolationData* xs_, - BilinearInterpolationData* ys_, NDArray* output) { - BUILD_DOUBLE_SELECTOR(images->dataType(), output->dataType(), - resizeImage_, (context, images, batchSize, inHeight, inWidth, outHeight, outWidth, channels, - xs_, ys_, output), NUMERIC_TYPES, FLOAT_TYPES); - } +void resizeImage(sd::LaunchContext* context, NDArray const* images, + Nd4jLong batchSize, Nd4jLong inHeight, Nd4jLong inWidth, + Nd4jLong outHeight, Nd4jLong outWidth, Nd4jLong channels, + BilinearInterpolationData* xs_, BilinearInterpolationData* ys_, + NDArray* output) { + BUILD_DOUBLE_SELECTOR(images->dataType(), output->dataType(), resizeImage_, + (context, images, batchSize, inHeight, inWidth, + outHeight, outWidth, channels, xs_, ys_, output), + NUMERIC_TYPES, FLOAT_TYPES); +} - BUILD_DOUBLE_TEMPLATE(template void resizeImage_,(sd::LaunchContext* context, NDArray const* images, - Nd4jLong batchSize, Nd4jLong inHeight, Nd4jLong inWidth, Nd4jLong outHeight, Nd4jLong outWidth, - Nd4jLong channels, BilinearInterpolationData* xs_, BilinearInterpolationData* ys_, NDArray* output), - NUMERIC_TYPES, FLOAT_TYPES); +BUILD_DOUBLE_TEMPLATE(template void resizeImage_, + (sd::LaunchContext * context, NDArray const* images, + Nd4jLong batchSize, Nd4jLong inHeight, Nd4jLong inWidth, + Nd4jLong outHeight, Nd4jLong outWidth, Nd4jLong channels, + BilinearInterpolationData* xs_, + BilinearInterpolationData* ys_, NDArray* output), + NUMERIC_TYPES, FLOAT_TYPES); //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - int resizeBilinearFunctor(sd::LaunchContext* context, NDArray const* images, int width, int height, - bool const alignCorners, bool const halfPixelCenter, NDArray* output) { - BUILD_DOUBLE_SELECTOR(images->dataType(), output->dataType(), return resizeBilinearFunctor_, (context, images, - width, height, alignCorners, halfPixelCenter, output), NUMERIC_TYPES, FLOAT_TYPES); - } -// BUILD_SINGLE_TEMPLATE(template int resizeBilinearFunctor_, (sd::LaunchContext* context, -// NDArray const* images, int const width, int const height, bool const alignCorners, -// bool const halfPixelCenter, NDArray* output), LIBND4J_TYPES); +int resizeBilinearFunctor(sd::LaunchContext* context, NDArray const* images, + int width, int height, bool const alignCorners, + bool const halfPixelCenter, NDArray* output) { + BUILD_DOUBLE_SELECTOR( + images->dataType(), output->dataType(), return resizeBilinearFunctor_, + (context, images, width, height, alignCorners, halfPixelCenter, output), + NUMERIC_TYPES, FLOAT_TYPES); +} +// BUILD_SINGLE_TEMPLATE(template int resizeBilinearFunctor_, +// (sd::LaunchContext* context, +// NDArray const* images, int const width, int const height, bool +// const alignCorners, bool const halfPixelCenter, NDArray* output), +// LIBND4J_TYPES); //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - int resizeNeighborFunctor(sd::LaunchContext* context, NDArray const* images, int const width, int const height, - bool const alignCorners, bool const halfPixelCenter, NDArray* output) { - BUILD_SINGLE_SELECTOR(images->dataType(), return resizeNeighborFunctor_, - (context, images, width, height, alignCorners, halfPixelCenter, output), LIBND4J_TYPES); - } -// BUILD_SINGLE_TEMPLATE(template int resizeNeighborFunctor_, (sd::LaunchContext* context, NDArray const* images, -// int width, int height, bool const alignCorners, bool const halfPixelCenter, NDArray* output), LIBND4J_TYPES); +int resizeNeighborFunctor(sd::LaunchContext* context, NDArray const* images, + int const width, int const height, + bool const alignCorners, bool const halfPixelCenter, + NDArray* output) { + BUILD_SINGLE_SELECTOR( + images->dataType(), return resizeNeighborFunctor_, + (context, images, width, height, alignCorners, halfPixelCenter, output), + LIBND4J_TYPES); +} +// BUILD_SINGLE_TEMPLATE(template int resizeNeighborFunctor_, +// (sd::LaunchContext* context, NDArray const* images, +// int width, int height, bool const alignCorners, bool const +// halfPixelCenter, NDArray* output), LIBND4J_TYPES); //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// // Bicubic interpolation //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - struct ImageResizerState { - explicit ImageResizerState(bool alignCorners, bool halfPixelCenters) - : _alignCorners(alignCorners), - _halfPixelCenters(halfPixelCenters) {} - - // ValidateAndCalculateOutputSize checks the bounds on the input tensors - // and requested size, sets up some of the resizing state such as the - // heightScale and widthScale, and calculates the output size. - // If any of these operations fails, it sets an error status in - // the context, which the caller must check. - int validateAndCalculateOutputSize(NDArray const* input, int const width, int const height) { - // - batchSize = input->sizeAt(0);//.dim_size(0); - outHeight = height; - outWidth = width; //internal::SubtleMustCopy(Svec(1)); - inHeight = static_cast(input->sizeAt(1)); - inWidth = static_cast(input->sizeAt(2)); - channels = input->sizeAt(3); //.dim_size(3); - heightScale = calculateResizeScale(inHeight, outHeight, _alignCorners); - widthScale = calculateResizeScale(inWidth, outWidth, _alignCorners); - - // Guard against overflows - if (ceilf((outHeight - 1) * heightScale) > static_cast(DataTypeUtils::max())) { - nd4j_printf("resize_bicubic: Upper overflow occurs for resize height (%f)\n", ceilf((outHeight - 1) * heightScale)); - return Status::CODE(ND4J_STATUS_BAD_INPUT, "resize_bicubic: Upper overflow occurs for resize height"); - } - if (ceilf((outWidth - 1) * heightScale) > static_cast(DataTypeUtils::max())) { - nd4j_printf("resize_bicubic: Upper overflow occurs for resize height (%f)\n", ceilf((outHeight - 1) * heightScale)); - return Status::CODE(ND4J_STATUS_BAD_INPUT, "resize_bicubic: Upper overflow occurs for resize width"); - } - - return Status::OK(); - } - - // Calculates all the required variables, and allocates the output. - int validateAndCreateOutput(NDArray const* input, int const width, int const height) { - return validateAndCalculateOutputSize(input, width, height); - } - - Nd4jLong batchSize; - Nd4jLong outHeight; - Nd4jLong outWidth; - Nd4jLong inHeight; - Nd4jLong inWidth; - Nd4jLong channels; - float heightScale; - float widthScale; - NDArray* output = nullptr; - cudaStream_t* stream; - private: - bool _alignCorners; - bool _halfPixelCenters; - }; - - struct WeightsAndIndices { - float _weight0; - float _weight1; - float _weight2; - float _weight3; - Nd4jLong _index0; - Nd4jLong _index1; - Nd4jLong _index2; - Nd4jLong _index3; - - int _advance; // advance value. - }; - - class CachedInterpolationCalculator { - public: - _CUDA_HD CachedInterpolationCalculator() : _indexes{-1, -1, -1, -1} {} - - // Advances iteration. Returns the number of values that should be copied from - // the current point to the next point. The copying should always be done by - // copying the last values from the old point to the first - // values of the new point. - inline _CUDA_HD int Advance(const Nd4jLong x0, const Nd4jLong x1, const Nd4jLong x2, - const Nd4jLong x3) { - // We use 2 hands and walk through, copying from one to another where - // we already have values. - // Invariant, new_indicies_hand <= cached_values_hand - const Nd4jLong new_x_indices[4] = {x0, x1, x2, x3}; - int cachedValuesHand = 0; - int newIndiciesHand = 0; - while (cachedValuesHand < 4) { - if (_indexes[cachedValuesHand] == new_x_indices[newIndiciesHand]) { - if (newIndiciesHand < cachedValuesHand) { - _indexes[newIndiciesHand] = _indexes[cachedValuesHand]; - } - newIndiciesHand++; - } - cachedValuesHand++; - } - switch (newIndiciesHand) { - case 0: - _indexes[0] = x0; - case 1: - _indexes[1] = x1; - case 2: - _indexes[2] = x2; - case 3: - _indexes[3] = x3; - break; - } - return newIndiciesHand; - } - - private: - Nd4jLong _indexes[4]; - }; - - - static __global__ void initCoefTableKernel(const double a, float* table, Nd4jLong tableSize) { - auto start = blockIdx.x * blockDim.x + threadIdx.x; - auto step = blockDim.x * gridDim.x; - for (int i = start; i <= tableSize; i += step) { - float x = i * 1.0 / tableSize; - table[i * 2] = ((a + 2) * x - (a + 3)) * x * x + 1; - x += 1.0; - table[i * 2 + 1] = ((a * x - 5 * a) * x + 8 * a) * x - 4 * a; - } +struct ImageResizerState { + explicit ImageResizerState(bool alignCorners, bool halfPixelCenters) + : _alignCorners(alignCorners), _halfPixelCenters(halfPixelCenters) {} + + // ValidateAndCalculateOutputSize checks the bounds on the input tensors + // and requested size, sets up some of the resizing state such as the + // heightScale and widthScale, and calculates the output size. + // If any of these operations fails, it sets an error status in + // the context, which the caller must check. + int validateAndCalculateOutputSize(NDArray const* input, int const width, + int const height) { + // + batchSize = input->sizeAt(0); //.dim_size(0); + outHeight = height; + outWidth = width; // internal::SubtleMustCopy(Svec(1)); + inHeight = static_cast(input->sizeAt(1)); + inWidth = static_cast(input->sizeAt(2)); + channels = input->sizeAt(3); //.dim_size(3); + heightScale = calculateResizeScale(inHeight, outHeight, _alignCorners); + widthScale = calculateResizeScale(inWidth, outWidth, _alignCorners); + + // Guard against overflows + if (ceilf((outHeight - 1) * heightScale) > + static_cast(DataTypeUtils::max())) { + nd4j_printf( + "resize_bicubic: Upper overflow occurs for resize height (%f)\n", + ceilf((outHeight - 1) * heightScale)); + return Status::CODE( + ND4J_STATUS_BAD_INPUT, + "resize_bicubic: Upper overflow occurs for resize height"); + } + if (ceilf((outWidth - 1) * heightScale) > + static_cast(DataTypeUtils::max())) { + nd4j_printf( + "resize_bicubic: Upper overflow occurs for resize height (%f)\n", + ceilf((outHeight - 1) * heightScale)); + return Status::CODE( + ND4J_STATUS_BAD_INPUT, + "resize_bicubic: Upper overflow occurs for resize width"); } - static const Nd4jLong kTableSize = (1 << 10); - float* initCoeffsTable(const double a, cudaStream_t* stream) { - // Allocate and initialize coefficients table using Bicubic - // convolution algorithm. - // https://en.wikipedia.org/wiki/Bicubic_interpolation - float* coeffs_table; // = new float[(kTableSize + 1) * 2]; - auto err = cudaMalloc(&coeffs_table, sizeof(float) * ((kTableSize + 1) * 2)); - if (err != 0) { - throw cuda_exception::build("helpers::initCoeffsTable: Cannot allocate memory for vertical parts rectangulars", err); + return Status::OK(); + } + + // Calculates all the required variables, and allocates the output. + int validateAndCreateOutput(NDArray const* input, int const width, + int const height) { + return validateAndCalculateOutputSize(input, width, height); + } + + Nd4jLong batchSize; + Nd4jLong outHeight; + Nd4jLong outWidth; + Nd4jLong inHeight; + Nd4jLong inWidth; + Nd4jLong channels; + float heightScale; + float widthScale; + NDArray* output = nullptr; + cudaStream_t* stream; + + private: + bool _alignCorners; + bool _halfPixelCenters; +}; + +struct WeightsAndIndices { + float _weight0; + float _weight1; + float _weight2; + float _weight3; + Nd4jLong _index0; + Nd4jLong _index1; + Nd4jLong _index2; + Nd4jLong _index3; + + int _advance; // advance value. +}; + +class CachedInterpolationCalculator { + public: + _CUDA_HD CachedInterpolationCalculator() : _indexes{-1, -1, -1, -1} {} + + // Advances iteration. Returns the number of values that should be copied from + // the current point to the next point. The copying should always be done by + // copying the last values from the old point to the first + // values of the new point. + inline _CUDA_HD int Advance(const Nd4jLong x0, const Nd4jLong x1, + const Nd4jLong x2, const Nd4jLong x3) { + // We use 2 hands and walk through, copying from one to another where + // we already have values. + // Invariant, new_indicies_hand <= cached_values_hand + const Nd4jLong new_x_indices[4] = {x0, x1, x2, x3}; + int cachedValuesHand = 0; + int newIndiciesHand = 0; + while (cachedValuesHand < 4) { + if (_indexes[cachedValuesHand] == new_x_indices[newIndiciesHand]) { + if (newIndiciesHand < cachedValuesHand) { + _indexes[newIndiciesHand] = _indexes[cachedValuesHand]; } - - - initCoefTableKernel<<<128,128,128, *stream>>>(a, coeffs_table, kTableSize); - err = cudaStreamSynchronize(*stream); - if (err != 0) { - throw cuda_exception::build("helpers::initCoeffsTable: Cannot syncronize kernel", err); - } - - return coeffs_table; + newIndiciesHand++; + } + cachedValuesHand++; + } + switch (newIndiciesHand) { + case 0: + _indexes[0] = x0; + case 1: + _indexes[1] = x1; + case 2: + _indexes[2] = x2; + case 3: + _indexes[3] = x3; + break; } + return newIndiciesHand; + } + + private: + Nd4jLong _indexes[4]; +}; + +static __global__ void initCoefTableKernel(const double a, float* table, + Nd4jLong tableSize) { + auto start = blockIdx.x * blockDim.x + threadIdx.x; + auto step = blockDim.x * gridDim.x; + for (int i = start; i <= tableSize; i += step) { + float x = i * 1.0 / tableSize; + table[i * 2] = ((a + 2) * x - (a + 3)) * x * x + 1; + x += 1.0; + table[i * 2 + 1] = ((a * x - 5 * a) * x + 8 * a) * x - 4 * a; + } +} + +static const Nd4jLong kTableSize = (1 << 10); +float* initCoeffsTable(const double a, cudaStream_t* stream) { + // Allocate and initialize coefficients table using Bicubic + // convolution algorithm. + // https://en.wikipedia.org/wiki/Bicubic_interpolation + float* coeffs_table; // = new float[(kTableSize + 1) * 2]; + auto err = cudaMalloc(&coeffs_table, sizeof(float) * ((kTableSize + 1) * 2)); + if (err != 0) { + throw cuda_exception::build( + "helpers::initCoeffsTable: Cannot allocate memory for vertical parts " + "rectangulars", + err); + } + + initCoefTableKernel<<<128, 128, 128, *stream>>>(a, coeffs_table, kTableSize); + err = cudaStreamSynchronize(*stream); + if (err != 0) { + throw cuda_exception::build( + "helpers::initCoeffsTable: Cannot syncronize kernel", err); + } + + return coeffs_table; +} // _CUDA_HD const float* getCoeffsTable(const bool use_keys_cubic) { // // Static so that we initialize it on first use // if (use_keys_cubic) { // // http://ieeexplore.ieee.org/document/1163711/ -// // R. G. Keys. Cubic convolution interpolation for digital image -// // processing. IEEE Transactions on Acoustics, Speech, and Signal +// // R. G. Keys. Cubic convolution interpolation for digital +// image +// // processing. IEEE Transactions on Acoustics, Speech, and +// Signal // // Processing, 29(6):1153–1160, 1981. -// //static const float* coeffs_table = initCoeffsTable(-0.5f, stream); -// return sCoeffsTableHalf; +// //static const float* coeffs_table = initCoeffsTable(-0.5f, +// stream); return sCoeffsTableHalf; // } else { -// //static const float* coeffs_table = initCoeffsTable(-0.75f, stream); -// return sCoeffsTableThreeFourth; +// //static const float* coeffs_table = initCoeffsTable(-0.75f, +// stream); return sCoeffsTableThreeFourth; // } // } - inline _CUDA_HD Nd4jLong bound(Nd4jLong val, Nd4jLong limit) { - return math::nd4j_min(limit - 1ll, math::nd4j_max(Nd4jLong{0}, val)); - } - +inline _CUDA_HD Nd4jLong bound(Nd4jLong val, Nd4jLong limit) { + return math::nd4j_min(limit - 1ll, math::nd4j_max(Nd4jLong{0}, val)); +} - template - inline _CUDA_HD float interpolate1D(const float weight0, const float weight1, const float weight2, const float weight3, - const T value0, const T value1, const T value2, const T value3) { - return static_cast(value0) * weight0 + - static_cast(value1) * weight1 + - static_cast(value2) * weight2 + - static_cast(value3) * weight3; - } +template +inline _CUDA_HD float interpolate1D(const float weight0, const float weight1, + const float weight2, const float weight3, + const T value0, const T value1, + const T value2, const T value3) { + return static_cast(value0) * weight0 + + static_cast(value1) * weight1 + + static_cast(value2) * weight2 + + static_cast(value3) * weight3; +} // Compute the 1D interpolation for a given X index using the y_weights - static _CUDA_HD float compute(float values[4], const float xW0, const float xW1, const float xW2, const float xW3) { - return interpolate1D(xW0, xW1, xW2, xW3, values[0], values[1],values[2], values[3]); - } - - +static _CUDA_HD float compute(float values[4], const float xW0, const float xW1, + const float xW2, const float xW3) { + return interpolate1D(xW0, xW1, xW2, xW3, values[0], values[1], values[2], + values[3]); +} - template - inline _CUDA_HD void getWeightsAndIndices(float const* coeffs_table, const float scale, const Nd4jLong out_loc, const Nd4jLong limit, WeightsAndIndices* out) { - const Scaler scaler; - const float in_loc_f = scaler(out_loc, scale); - const Nd4jLong in_loc = math::nd4j_floor(in_loc_f); - const float delta = in_loc_f - in_loc; - const Nd4jLong offset = math::nd4j_round(delta * kTableSize); - //const float* coeffs_table = getCoeffsTable(use_keys_cubic); - if (use_keys_cubic) { - // The legacy code placed more weight on the edge pixels, since bounding - // the set of inputs to sample could cause an edge pixel to be repeated. - // Here we change the behavior at borders to match that used by the - // scale_and_translate_op, where sampling locations outside the image have - // their weight set to 0, and the weights are renormalized so that their sum - // is 1.0. - out->_index0 = bound(in_loc - 1, limit); - out->_weight0 = - (out->_index0 == in_loc - 1 ? coeffs_table[offset * 2 + 1] : 0.0f); - out->_index1 = bound(in_loc, limit); - out->_weight1 = (out->_index1 == in_loc ? coeffs_table[offset * 2] : 0.0f); - out->_index2 = bound(in_loc + 1, limit); - out->_weight2 = - (out->_index2 == in_loc + 1 ? coeffs_table[(kTableSize - offset) * 2] - : 0.0f); - out->_index3 = bound(in_loc + 2, limit); - out->_weight3 = (out->_index3 == in_loc + 2 - ? coeffs_table[(kTableSize - offset) * 2 + 1] - : 0.0f); - - const float weight_sum = - out->_weight0 + out->_weight1 + out->_weight2 + out->_weight3; - if (math::nd4j_abs(weight_sum) >= 1000.0f * DataTypeUtils::min()) { - const float one_over_weight_sum = 1.0f / weight_sum; - out->_weight0 *= one_over_weight_sum; - out->_weight1 *= one_over_weight_sum; - out->_weight2 *= one_over_weight_sum; - out->_weight3 *= one_over_weight_sum; - } - } else { - out->_weight0 = coeffs_table[offset * 2 + 1]; - out->_weight1 = coeffs_table[offset * 2]; - out->_weight2 = coeffs_table[(kTableSize - offset) * 2]; - out->_weight3 = coeffs_table[(kTableSize - offset) * 2 + 1]; - out->_index0 = bound(in_loc - 1, limit); - out->_index1 = bound(in_loc, limit); - out->_index2 = bound(in_loc + 1, limit); - out->_index3 = bound(in_loc + 2, limit); - } +template +inline _CUDA_HD void getWeightsAndIndices(float const* coeffs_table, + const float scale, + const Nd4jLong out_loc, + const Nd4jLong limit, + WeightsAndIndices* out) { + const Scaler scaler; + const float in_loc_f = scaler(out_loc, scale); + const Nd4jLong in_loc = math::nd4j_floor(in_loc_f); + const float delta = in_loc_f - in_loc; + const Nd4jLong offset = math::nd4j_round(delta * kTableSize); + // const float* coeffs_table = getCoeffsTable(use_keys_cubic); + if (use_keys_cubic) { + // The legacy code placed more weight on the edge pixels, since bounding + // the set of inputs to sample could cause an edge pixel to be repeated. + // Here we change the behavior at borders to match that used by the + // scale_and_translate_op, where sampling locations outside the image have + // their weight set to 0, and the weights are renormalized so that their sum + // is 1.0. + out->_index0 = bound(in_loc - 1, limit); + out->_weight0 = + (out->_index0 == in_loc - 1 ? coeffs_table[offset * 2 + 1] : 0.0f); + out->_index1 = bound(in_loc, limit); + out->_weight1 = (out->_index1 == in_loc ? coeffs_table[offset * 2] : 0.0f); + out->_index2 = bound(in_loc + 1, limit); + out->_weight2 = + (out->_index2 == in_loc + 1 ? coeffs_table[(kTableSize - offset) * 2] + : 0.0f); + out->_index3 = bound(in_loc + 2, limit); + out->_weight3 = (out->_index3 == in_loc + 2 + ? coeffs_table[(kTableSize - offset) * 2 + 1] + : 0.0f); + + const float weight_sum = + out->_weight0 + out->_weight1 + out->_weight2 + out->_weight3; + if (math::nd4j_abs(weight_sum) >= 1000.0f * DataTypeUtils::min()) { + const float one_over_weight_sum = 1.0f / weight_sum; + out->_weight0 *= one_over_weight_sum; + out->_weight1 *= one_over_weight_sum; + out->_weight2 *= one_over_weight_sum; + out->_weight3 *= one_over_weight_sum; } + } else { + out->_weight0 = coeffs_table[offset * 2 + 1]; + out->_weight1 = coeffs_table[offset * 2]; + out->_weight2 = coeffs_table[(kTableSize - offset) * 2]; + out->_weight3 = coeffs_table[(kTableSize - offset) * 2 + 1]; + out->_index0 = bound(in_loc - 1, limit); + out->_index1 = bound(in_loc, limit); + out->_index2 = bound(in_loc + 1, limit); + out->_index3 = bound(in_loc + 2, limit); + } +} - static __global__ void accumulateChannelsKernel(WeightsAndIndices* pXWais, Nd4jLong outWidth, Nd4jLong channels) { - auto start = blockIdx.x * blockDim.x + threadIdx.x; - auto step = blockDim.x * gridDim.x; - - for (auto x = start; x < outWidth; x += step) { - pXWais[x]._index0 *= channels; - pXWais[x]._index1 *= channels; - pXWais[x]._index2 *= channels; - pXWais[x]._index3 *= channels; - } - } +static __global__ void accumulateChannelsKernel(WeightsAndIndices* pXWais, + Nd4jLong outWidth, + Nd4jLong channels) { + auto start = blockIdx.x * blockDim.x + threadIdx.x; + auto step = blockDim.x * gridDim.x; + + for (auto x = start; x < outWidth; x += step) { + pXWais[x]._index0 *= channels; + pXWais[x]._index1 *= channels; + pXWais[x]._index2 *= channels; + pXWais[x]._index3 *= channels; + } +} - static __global__ void advaceWeightsAndIndicesKernel(float const* cacheTable, CachedInterpolationCalculator* calc, WeightsAndIndices* pXWais, Nd4jLong inWidth, float widthScale, - Nd4jLong outWidth, Nd4jLong channels, bool halfPixelCenters) { - auto start = blockIdx.x * blockDim.x + threadIdx.x; - auto step = blockDim.x * gridDim.x; - - for (auto x = start; x < outWidth; x += step) { - if (halfPixelCenters) - getWeightsAndIndices(cacheTable, widthScale, x, inWidth, &pXWais[x]); - else - getWeightsAndIndices(cacheTable, widthScale, x, inWidth, &pXWais[x]); - pXWais[x]._advance = calc->Advance(pXWais[x]._index0, pXWais[x]._index1, pXWais[x]._index2, pXWais[x]._index3); - } - } - // resizerState and xWais are device allocated - static void computeXWeightsAndIndices(float const* coeffsTable, const ImageResizerState& resizerState, - const bool halfPixelCenters, - WeightsAndIndices* pXWais) { - - auto stream = resizerState.stream; - auto outWidth = resizerState.outWidth; - CachedInterpolationCalculator calc; // = new CachedInterpolationCalculator; - CachedInterpolationCalculator* pCalcD; - auto err = cudaMalloc(&pCalcD, sizeof(CachedInterpolationCalculator)); - if (err != 0) { - cuda_exception::build("helpers::computeXWeightsAndIndices: Cannot allocated device memory for interpolate calculator", err); - } - err = cudaMemcpyAsync(pCalcD, &calc, sizeof(CachedInterpolationCalculator), cudaMemcpyHostToDevice, *stream); - if (err != 0) { - cuda_exception::build("helpers::computeXWeightsAndIndices: Cannot set up device memory for interpolate calculator", err); - } +static __global__ void advaceWeightsAndIndicesKernel( + float const* cacheTable, CachedInterpolationCalculator* calc, + WeightsAndIndices* pXWais, Nd4jLong inWidth, float widthScale, + Nd4jLong outWidth, Nd4jLong channels, bool halfPixelCenters) { + auto start = blockIdx.x * blockDim.x + threadIdx.x; + auto step = blockDim.x * gridDim.x; + + for (auto x = start; x < outWidth; x += step) { + if (halfPixelCenters) + getWeightsAndIndices(cacheTable, widthScale, x, + inWidth, &pXWais[x]); + else + getWeightsAndIndices(cacheTable, widthScale, x, + inWidth, &pXWais[x]); + pXWais[x]._advance = calc->Advance(pXWais[x]._index0, pXWais[x]._index1, + pXWais[x]._index2, pXWais[x]._index3); + } +} +// resizerState and xWais are device allocated +static void computeXWeightsAndIndices(float const* coeffsTable, + const ImageResizerState& resizerState, + const bool halfPixelCenters, + WeightsAndIndices* pXWais) { + auto stream = resizerState.stream; + auto outWidth = resizerState.outWidth; + CachedInterpolationCalculator calc; // = new CachedInterpolationCalculator; + CachedInterpolationCalculator* pCalcD; + auto err = cudaMalloc(&pCalcD, sizeof(CachedInterpolationCalculator)); + if (err != 0) { + cuda_exception::build( + "helpers::computeXWeightsAndIndices: Cannot allocated device memory " + "for interpolate calculator", + err); + } + err = cudaMemcpyAsync(pCalcD, &calc, sizeof(CachedInterpolationCalculator), + cudaMemcpyHostToDevice, *stream); + if (err != 0) { + cuda_exception::build( + "helpers::computeXWeightsAndIndices: Cannot set up device memory for " + "interpolate calculator", + err); + } + + advaceWeightsAndIndicesKernel<<<128, 128, 128, *stream>>>( + coeffsTable, pCalcD, pXWais, resizerState.inWidth, + resizerState.widthScale, outWidth, resizerState.channels, + halfPixelCenters); + err = cudaFree(pCalcD); + if (err != 0) { + cuda_exception::build( + "helpers::computeXWeightsAndIndices: Cannot deallocated device memory " + "for interpolate calculator", + err); + } + err = cudaStreamSynchronize(*stream); + if (err != 0) { + cuda_exception::build( + "helpers::computeXWeightsAndIndices: Cannot synchronize stream after " + "advance weights and indicers", + err); + } + // Scale the values so they can be used as offsets into buffers. + accumulateChannelsKernel<<<128, 128, 512, *stream>>>(pXWais, outWidth, + resizerState.channels); + err = cudaStreamSynchronize(*stream); + if (err != 0) { + cuda_exception::build( + "helpers::computeXWeightsAndIndices: Cannot synchronize stream after " + "accumulate channels", + err); + } +} - advaceWeightsAndIndicesKernel<<<128, 128, 128, *stream>>>(coeffsTable, pCalcD, pXWais, resizerState.inWidth, resizerState.widthScale, outWidth, resizerState.channels, halfPixelCenters); - err = cudaFree(pCalcD); - if (err != 0) { - cuda_exception::build("helpers::computeXWeightsAndIndices: Cannot deallocated device memory for interpolate calculator", err); - } - err = cudaStreamSynchronize(*stream); - if (err != 0) { - cuda_exception::build("helpers::computeXWeightsAndIndices: Cannot synchronize stream after advance weights and indicers", err); - } - // Scale the values so they can be used as offsets into buffers. - accumulateChannelsKernel<<<128, 128, 512, *stream>>>(pXWais, outWidth, resizerState.channels); - err = cudaStreamSynchronize(*stream); - if (err != 0) { - cuda_exception::build("helpers::computeXWeightsAndIndices: Cannot synchronize stream after accumulate channels", err); - } +template +static _CUDA_HD FORCEINLINE float computeYInterpolation( + int which, int channelNum, const WeightsAndIndices& yWai, const T* pY0, + const T* pY1, const T* pY2, const T* pY3, const WeightsAndIndices& xWai) { + int xIndex; + switch (which) { + case 0: + xIndex = xWai._index0; + break; + case 1: + xIndex = xWai._index1; + break; + case 2: + xIndex = xWai._index2; + break; + default: + xIndex = xWai._index3; + break; + } + const Nd4jLong pt_index = xIndex + channelNum; + return interpolate1D(yWai._weight0, yWai._weight1, yWai._weight2, + yWai._weight3, pY0[pt_index], pY1[pt_index], + pY2[pt_index], pY3[pt_index]); +} - } +template +static __global__ void bicubicInterpolateWithCachingKernel( + float const* cachedTable, T const* inputPtr, + ImageResizerState* pResizerState, WeightsAndIndices* xWais, + bool halfPixelCenters, Nd4jLong inBatchWidth, Nd4jLong inRowWidth, + float* outputPtr) { + // auto numChannels = pResizerState->channels; + + for (Nd4jLong b = blockIdx.x; b < pResizerState->batchSize; b += gridDim.x) { + auto pInput = inputPtr + b * inBatchWidth; + float* cachedValue; + for (Nd4jLong y = threadIdx.x; y < pResizerState->outHeight; + y += blockDim.x) { + if (threadIdx.x == 0) { + extern __shared__ char sharedChar[]; + cachedValue = reinterpret_cast(sharedChar); + } + auto pos = (b * pResizerState->outHeight + y) * pResizerState->outWidth * + pResizerState->channels; + auto pOutput = &outputPtr[pos]; + struct WeightsAndIndices yWai; + if (halfPixelCenters) { + getWeightsAndIndices( + cachedTable, pResizerState->heightScale, y, pResizerState->inHeight, + &yWai); + } else { + getWeightsAndIndices( + cachedTable, pResizerState->heightScale, y, pResizerState->inHeight, + &yWai); + } + // Make pointers represent offsets of data in inputBPtr. + const T* y_ptr_0 = pInput + yWai._index0 * inRowWidth; + const T* y_ptr_1 = pInput + yWai._index1 * inRowWidth; + const T* y_ptr_2 = pInput + yWai._index2 * inRowWidth; + const T* y_ptr_3 = pInput + yWai._index3 * inRowWidth; + + if (pResizerState->channels == 3) { + // Manually unroll case of 3 channels. + float cached_value_0[4] = {0}; + float cached_value_1[4] = {0}; + float cached_value_2[4] = {0}; + for (Nd4jLong x = 0; x < pResizerState->outWidth; ++x) { + const WeightsAndIndices& xWai = xWais[x]; + // Shift values in cached_value_* to fill first '_advance' values. + switch (xWai._advance) { + case 3: + cached_value_0[0] = cached_value_0[1]; + cached_value_0[1] = cached_value_0[2]; + cached_value_0[2] = cached_value_0[3]; + cached_value_1[0] = cached_value_1[1]; + cached_value_1[1] = cached_value_1[2]; + cached_value_1[2] = cached_value_1[3]; + cached_value_2[0] = cached_value_2[1]; + cached_value_2[1] = cached_value_2[2]; + cached_value_2[2] = cached_value_2[3]; + break; + case 2: + cached_value_0[0] = cached_value_0[2]; + cached_value_0[1] = cached_value_0[3]; + cached_value_1[0] = cached_value_1[2]; + cached_value_1[1] = cached_value_1[3]; + cached_value_2[0] = cached_value_2[2]; + cached_value_2[1] = cached_value_2[3]; + break; + case 1: { + cached_value_0[0] = cached_value_0[3]; + cached_value_1[0] = cached_value_1[3]; + cached_value_2[0] = cached_value_2[3]; + break; + } + } - template - static _CUDA_HD FORCEINLINE float computeYInterpolation( - int which, int channelNum, const WeightsAndIndices& yWai, - const T* pY0, const T* pY1, const T* pY2, const T* pY3, - const WeightsAndIndices& xWai) { - int xIndex; - switch (which) { + // Set the remaining '4-_advance' values by computing. + switch (xWai._advance) { case 0: - xIndex = xWai._index0; - break; + cached_value_0[0] = computeYInterpolation( + 0, 0, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); + cached_value_1[0] = computeYInterpolation( + 0, 1, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); + cached_value_2[0] = computeYInterpolation( + 0, 2, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); case 1: - xIndex = xWai._index1; - break; + cached_value_0[1] = computeYInterpolation( + 1, 0, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); + cached_value_1[1] = computeYInterpolation( + 1, 1, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); + cached_value_2[1] = computeYInterpolation( + 1, 2, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); case 2: - xIndex = xWai._index2; - break; - default: - xIndex = xWai._index3; - break; + cached_value_0[2] = computeYInterpolation( + 2, 0, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); + cached_value_1[2] = computeYInterpolation( + 2, 1, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); + cached_value_2[2] = computeYInterpolation( + 2, 2, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); + case 3: + cached_value_0[3] = computeYInterpolation( + 3, 0, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); + cached_value_1[3] = computeYInterpolation( + 3, 1, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); + cached_value_2[3] = computeYInterpolation( + 3, 2, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); + // break; + } + pOutput[x * pResizerState->channels + 0] = + compute(cached_value_0, xWai._weight0, xWai._weight1, + xWai._weight2, xWai._weight3); + pOutput[x * pResizerState->channels + 1] = + compute(cached_value_1, xWai._weight0, xWai._weight1, + xWai._weight2, xWai._weight3); + pOutput[x * pResizerState->channels + 2] = + compute(cached_value_2, xWai._weight0, xWai._weight1, + xWai._weight2, xWai._weight3); } - const Nd4jLong pt_index = xIndex + channelNum; - return interpolate1D(yWai._weight0, yWai._weight1, yWai._weight2, - yWai._weight3, pY0[pt_index], pY1[pt_index], - pY2[pt_index], pY3[pt_index]); - } - - template - static __global__ void bicubicInterpolateWithCachingKernel(float const* cachedTable, T const* inputPtr, ImageResizerState* pResizerState, WeightsAndIndices* xWais, bool halfPixelCenters, Nd4jLong inBatchWidth, Nd4jLong inRowWidth, float* outputPtr) { -// auto numChannels = pResizerState->channels; - - for (Nd4jLong b = blockIdx.x; b < pResizerState->batchSize; b += gridDim.x) { - auto pInput = inputPtr + b * inBatchWidth; - float* cachedValue; - for (Nd4jLong y = threadIdx.x; y < pResizerState->outHeight; y += blockDim.x) { - if (threadIdx.x == 0) { - extern __shared__ char sharedChar[]; - cachedValue = reinterpret_cast(sharedChar); - } - auto pos = (b * pResizerState->outHeight + y) * pResizerState->outWidth * pResizerState->channels; - auto pOutput = &outputPtr[pos]; - struct WeightsAndIndices yWai; - if (halfPixelCenters) { - getWeightsAndIndices(cachedTable, pResizerState->heightScale, y, pResizerState->inHeight, &yWai); - } else { - getWeightsAndIndices(cachedTable, pResizerState->heightScale, y, pResizerState->inHeight, &yWai); - } - // Make pointers represent offsets of data in inputBPtr. - const T* y_ptr_0 = pInput + yWai._index0 * inRowWidth; - const T* y_ptr_1 = pInput + yWai._index1 * inRowWidth; - const T* y_ptr_2 = pInput + yWai._index2 * inRowWidth; - const T* y_ptr_3 = pInput + yWai._index3 * inRowWidth; - - if (pResizerState->channels == 3) { - // Manually unroll case of 3 channels. - float cached_value_0[4] = {0}; - float cached_value_1[4] = {0}; - float cached_value_2[4] = {0}; - for (Nd4jLong x = 0; x < pResizerState->outWidth; ++x) { - const WeightsAndIndices& xWai = xWais[x]; - // Shift values in cached_value_* to fill first '_advance' values. - switch (xWai._advance) { - case 3: - cached_value_0[0] = cached_value_0[1]; - cached_value_0[1] = cached_value_0[2]; - cached_value_0[2] = cached_value_0[3]; - cached_value_1[0] = cached_value_1[1]; - cached_value_1[1] = cached_value_1[2]; - cached_value_1[2] = cached_value_1[3]; - cached_value_2[0] = cached_value_2[1]; - cached_value_2[1] = cached_value_2[2]; - cached_value_2[2] = cached_value_2[3]; - break; - case 2: - cached_value_0[0] = cached_value_0[2]; - cached_value_0[1] = cached_value_0[3]; - cached_value_1[0] = cached_value_1[2]; - cached_value_1[1] = cached_value_1[3]; - cached_value_2[0] = cached_value_2[2]; - cached_value_2[1] = cached_value_2[3]; - break; - case 1: { - cached_value_0[0] = cached_value_0[3]; - cached_value_1[0] = cached_value_1[3]; - cached_value_2[0] = cached_value_2[3]; - break; - } - } - - // Set the remaining '4-_advance' values by computing. - switch (xWai._advance) { - case 0: - cached_value_0[0] = computeYInterpolation(0, 0, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); - cached_value_1[0] = computeYInterpolation(0, 1, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); - cached_value_2[0] = computeYInterpolation(0, 2, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); - case 1: - cached_value_0[1] = computeYInterpolation(1, 0, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); - cached_value_1[1] = computeYInterpolation(1, 1, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); - cached_value_2[1] = computeYInterpolation(1, 2, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); - case 2: - cached_value_0[2] = computeYInterpolation(2, 0, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); - cached_value_1[2] = computeYInterpolation(2, 1, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); - cached_value_2[2] = computeYInterpolation(2, 2, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); - case 3: - cached_value_0[3] = computeYInterpolation(3, 0, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); - cached_value_1[3] = computeYInterpolation(3, 1, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); - cached_value_2[3] = computeYInterpolation(3, 2, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); - // break; - } - pOutput[x * pResizerState->channels + 0] = compute(cached_value_0, xWai._weight0, xWai._weight1, - xWai._weight2, xWai._weight3); - pOutput[x * pResizerState->channels + 1] = compute(cached_value_1, xWai._weight0, xWai._weight1, - xWai._weight2, xWai._weight3); - pOutput[x * pResizerState->channels + 2] = compute(cached_value_2, xWai._weight0, xWai._weight1, - xWai._weight2, xWai._weight3); - } - } else { - for (Nd4jLong x = 0; x < pResizerState->outWidth; ++x) { - const WeightsAndIndices& xWai = xWais[x]; - // Shift values in cachedValue to fill first '_advance' values. - switch (xWai._advance) { - case 3: - for (Nd4jLong c = 0; c < pResizerState->channels; ++c) { - cachedValue[4 * c + 0] = cachedValue[4 * c + 1]; - cachedValue[4 * c + 1] = cachedValue[4 * c + 2]; - cachedValue[4 * c + 2] = cachedValue[4 * c + 3]; - } - break; - case 2: - for (Nd4jLong c = 0; c < pResizerState->channels; ++c) { - cachedValue[4 * c + 0] = cachedValue[4 * c + 2]; - cachedValue[4 * c + 1] = cachedValue[4 * c + 3]; - } - break; - case 1: { - for (Nd4jLong c = 0; c < pResizerState->channels; ++c) { - cachedValue[4 * c + 0] = cachedValue[4 * c + 3]; - } - break; - } - } - - // Set the remaining '4-_advance' values by computing. - switch (xWai._advance) { - case 0: - for (Nd4jLong c = 0; c < pResizerState->channels; ++c) { - cachedValue[4 * c + 0] = computeYInterpolation(0, c, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); - } - case 1: - for (Nd4jLong c = 0; c < pResizerState->channels; ++c) { - cachedValue[4 * c + 1] = computeYInterpolation(1, c, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); - } - case 2: - for (Nd4jLong c = 0; c < pResizerState->channels; ++c) { - cachedValue[4 * c + 2] = computeYInterpolation(2, c, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); - } - case 3: - for (Nd4jLong c = 0; c < pResizerState->channels; ++c) { - cachedValue[4 * c + 3] = computeYInterpolation(3, c, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); - } - // break; - } - for (Nd4jLong c = 0; c < pResizerState->channels; ++c) { - pOutput[x * pResizerState->channels + c] = compute(&cachedValue[4 * c], xWai._weight0, xWai._weight1, xWai._weight2, xWai._weight3); - } - } - } + } else { + for (Nd4jLong x = 0; x < pResizerState->outWidth; ++x) { + const WeightsAndIndices& xWai = xWais[x]; + // Shift values in cachedValue to fill first '_advance' values. + switch (xWai._advance) { + case 3: + for (Nd4jLong c = 0; c < pResizerState->channels; ++c) { + cachedValue[4 * c + 0] = cachedValue[4 * c + 1]; + cachedValue[4 * c + 1] = cachedValue[4 * c + 2]; + cachedValue[4 * c + 2] = cachedValue[4 * c + 3]; + } + break; + case 2: + for (Nd4jLong c = 0; c < pResizerState->channels; ++c) { + cachedValue[4 * c + 0] = cachedValue[4 * c + 2]; + cachedValue[4 * c + 1] = cachedValue[4 * c + 3]; + } + break; + case 1: { + for (Nd4jLong c = 0; c < pResizerState->channels; ++c) { + cachedValue[4 * c + 0] = cachedValue[4 * c + 3]; + } + break; } - } - - } - - - template - static void - bicubicInterpolateWithCaching(NDArray const* image, ImageResizerState const& resizerState, bool const halfPixelCenters, NDArray* output) { - const auto numChannels = resizerState.channels; - const Nd4jLong inRowWidth = resizerState.inWidth * numChannels; - const Nd4jLong inBatchWidth = resizerState.inHeight * inRowWidth; - - auto stream = resizerState.stream; //output->getContext()->getCudaStream(); - ImageResizerState* resizerStateD; - auto err = cudaMalloc(&resizerStateD, sizeof(ImageResizerState)); - if (err != 0) { - throw cuda_exception::build("helpers::bicubicInterpolateWithCaching: Cannot allocate memory for resizerState", err); - } - err = cudaMemcpyAsync(resizerStateD, &resizerState, sizeof(ImageResizerState), cudaMemcpyHostToDevice, *stream); - if (err != 0) { - throw cuda_exception::build("helpers::bicubicInterpolateWithCaching: Cannot set up memory for resizerState", err); - } - -// float* cachedValue = nullptr; -// size_t cachedSize = sizeof(float) * (numChannels == 3 ? 0 : 4 * numChannels); -// if (cachedSize) { -// err = cudaMalloc(reinterpret_cast(&cachedValue), cachedSize); -// if (err != 0) { -// throw cuda_exception::build( -// "helpers::bicubicInterpolateWithCaching: Cannot allocate memory for cached values", err); -// } -// err = cudaMemset(cachedValue, 0, cachedSize); -// if (err != 0) { -// throw cuda_exception::build( -// "helpers::bicubicInterpolateWithCaching: Cannot set up memory for cached values", err); -// } -// } - - WeightsAndIndices* xWais; //(resizerState.outWidth); - err = cudaMalloc(&xWais, sizeof(WeightsAndIndices) * resizerState.outWidth); - if (err != 0) { - throw cuda_exception::build("helpers::bicubicInterpolateWithCaching: Cannot allocate memory for weights and indices", err); - } + } - auto coeffsTable = halfPixelCenters?initCoeffsTable(-0.5, stream): initCoeffsTable(-0.75, stream); - if (err != 0) { - throw cuda_exception::build("helpers::bicubicInterpolateWithCaching: computeXWeigtsAndInidces finished with error", err); - } - computeXWeightsAndIndices(coeffsTable, resizerState, halfPixelCenters, xWais); - err = cudaStreamQuery(*stream); - if (err != 0) { - throw cuda_exception::build("helpers::bicubicInterpolateWithCaching: computeXWeigtsAndInidces finished with error", err); - } - const T* pInput = image->getDataBuffer()->specialAsT(); - float* pOutput = output->dataBuffer()->specialAsT(); //_data.data(); - bicubicInterpolateWithCachingKernel<<<128, 1, 512, *stream>>>(coeffsTable, pInput, - resizerStateD, xWais, halfPixelCenters, inBatchWidth, inRowWidth, pOutput); - err = cudaStreamSynchronize(*stream); - if (err != 0) { - throw cuda_exception::build("helpers::bicubicInterpolateWithCaching: Kernels finished with error", err); + // Set the remaining '4-_advance' values by computing. + switch (xWai._advance) { + case 0: + for (Nd4jLong c = 0; c < pResizerState->channels; ++c) { + cachedValue[4 * c + 0] = computeYInterpolation( + 0, c, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); + } + case 1: + for (Nd4jLong c = 0; c < pResizerState->channels; ++c) { + cachedValue[4 * c + 1] = computeYInterpolation( + 1, c, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); + } + case 2: + for (Nd4jLong c = 0; c < pResizerState->channels; ++c) { + cachedValue[4 * c + 2] = computeYInterpolation( + 2, c, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); + } + case 3: + for (Nd4jLong c = 0; c < pResizerState->channels; ++c) { + cachedValue[4 * c + 3] = computeYInterpolation( + 3, c, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); + } + // break; + } + for (Nd4jLong c = 0; c < pResizerState->channels; ++c) { + pOutput[x * pResizerState->channels + c] = + compute(&cachedValue[4 * c], xWai._weight0, xWai._weight1, + xWai._weight2, xWai._weight3); + } } + } + } + } +} - err = cudaFree(resizerStateD); - if (err != 0) { - throw cuda_exception::build("helpers::bicubicInterpolateWithCaching: Cannot deallocate memory for resizerState", err); - } -// if (cachedSize) -// err = cudaFree(cachedValue); -// if (err != 0) { -// throw cuda_exception::build("helpers::bicubicInterpolateWithCaching: Cannot deallocate memory for cached values", err); -// } +template +static void bicubicInterpolateWithCaching(NDArray const* image, + ImageResizerState const& resizerState, + bool const halfPixelCenters, + NDArray* output) { + const auto numChannels = resizerState.channels; + const Nd4jLong inRowWidth = resizerState.inWidth * numChannels; + const Nd4jLong inBatchWidth = resizerState.inHeight * inRowWidth; + + auto stream = resizerState.stream; // output->getContext()->getCudaStream(); + ImageResizerState* resizerStateD; + auto err = cudaMalloc(&resizerStateD, sizeof(ImageResizerState)); + if (err != 0) { + throw cuda_exception::build( + "helpers::bicubicInterpolateWithCaching: Cannot allocate memory for " + "resizerState", + err); + } + err = cudaMemcpyAsync(resizerStateD, &resizerState, sizeof(ImageResizerState), + cudaMemcpyHostToDevice, *stream); + if (err != 0) { + throw cuda_exception::build( + "helpers::bicubicInterpolateWithCaching: Cannot set up memory for " + "resizerState", + err); + } + + // float* cachedValue = nullptr; + // size_t cachedSize = sizeof(float) * (numChannels == 3 ? 0 : 4 * + // numChannels); if (cachedSize) { + // err = cudaMalloc(reinterpret_cast(&cachedValue), + // cachedSize); if (err != 0) { + // throw cuda_exception::build( + // "helpers::bicubicInterpolateWithCaching: Cannot + // allocate memory for cached values", err); + // } + // err = cudaMemset(cachedValue, 0, cachedSize); + // if (err != 0) { + // throw cuda_exception::build( + // "helpers::bicubicInterpolateWithCaching: Cannot set + // up memory for cached values", err); + // } + // } + + WeightsAndIndices* xWais; //(resizerState.outWidth); + err = cudaMalloc(&xWais, sizeof(WeightsAndIndices) * resizerState.outWidth); + if (err != 0) { + throw cuda_exception::build( + "helpers::bicubicInterpolateWithCaching: Cannot allocate memory for " + "weights and indices", + err); + } + + auto coeffsTable = halfPixelCenters ? initCoeffsTable(-0.5, stream) + : initCoeffsTable(-0.75, stream); + if (err != 0) { + throw cuda_exception::build( + "helpers::bicubicInterpolateWithCaching: computeXWeigtsAndInidces " + "finished with error", + err); + } + computeXWeightsAndIndices(coeffsTable, resizerState, halfPixelCenters, xWais); + err = cudaStreamQuery(*stream); + if (err != 0) { + throw cuda_exception::build( + "helpers::bicubicInterpolateWithCaching: computeXWeigtsAndInidces " + "finished with error", + err); + } + const T* pInput = image->getDataBuffer()->specialAsT(); + float* pOutput = output->dataBuffer()->specialAsT(); //_data.data(); + bicubicInterpolateWithCachingKernel<<<128, 1, 512, *stream>>>( + coeffsTable, pInput, resizerStateD, xWais, halfPixelCenters, inBatchWidth, + inRowWidth, pOutput); + err = cudaStreamSynchronize(*stream); + if (err != 0) { + throw cuda_exception::build( + "helpers::bicubicInterpolateWithCaching: Kernels finished with error", + err); + } + + err = cudaFree(resizerStateD); + if (err != 0) { + throw cuda_exception::build( + "helpers::bicubicInterpolateWithCaching: Cannot deallocate memory for " + "resizerState", + err); + } + // if (cachedSize) + // err = cudaFree(cachedValue); + // if (err != 0) { + // throw + // cuda_exception::build("helpers::bicubicInterpolateWithCaching: + // Cannot deallocate memory for cached values", err); + // } + + err = cudaFree(xWais); + if (err != 0) { + throw cuda_exception::build( + "helpers::bicubicInterpolateWithCaching: Cannot deallocate memory for " + "weights and indices", + err); + } + + err = cudaFree(coeffsTable); + if (err != 0) { + throw cuda_exception::build( + "helpers::bicubicInterpolateWithCaching: Cannot deallocate memory for " + "coefficients table", + err); + } +} +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +template +int resizeBicubicFunctor_(sd::LaunchContext* context, NDArray const* image, + int width, int height, bool preserveAspectRatio, + bool antialias, NDArray* output) { + return Status::OK(); +} - err = cudaFree(xWais); - if (err != 0) { - throw cuda_exception::build("helpers::bicubicInterpolateWithCaching: Cannot deallocate memory for weights and indices", err); - } +int resizeBicubicFunctor(sd::LaunchContext* context, NDArray const* image, + int width, int height, bool preserveAspectRatio, + bool antialias, NDArray* output) { + BUILD_SINGLE_SELECTOR( + image->dataType(), return resizeBicubicFunctor_, + (context, image, width, height, preserveAspectRatio, antialias, output), + NUMERIC_TYPES); +} +BUILD_SINGLE_TEMPLATE(template int resizeBicubicFunctor_, + (sd::LaunchContext * context, NDArray const* image, + int width, int height, bool preserveAspectRatio, + bool antialias, NDArray* output), + NUMERIC_TYPES); +// ------------------------------------------------------------------------------------------------------------------ +// // +struct CachedInterpolation { + Nd4jLong start; + Nd4jLong end; + float startScale; + float endMinusOneScale; + bool needsBounding; +}; + +static __global__ void fillInterpolationCache(CachedInterpolation* xCached, + Nd4jLong cacheLen, + Nd4jLong inWidth, + float widthScale) { + auto start = blockIdx.x * blockDim.x + threadIdx.x; + auto increment = blockDim.x * gridDim.x; + + for (auto x = start; x < cacheLen; x += increment) { + auto& xCache = xCached[x]; + const float inX = x * widthScale; + const float inX1 = (x + 1) * widthScale; + + Nd4jLong v = math::nd4j_floor(inX); + xCache.start = v; + xCache.startScale = v < inX ? (v + 1 > inX1 ? widthScale : v + 1 - inX) + : (v + 1 > inX1 ? inX1 - v : 1.f); + v = math::nd4j_ceil(inX1); + xCache.end = v--; + xCache.endMinusOneScale = v < inX + ? (v + 1 > inX1 ? widthScale : v + 1 - inX) + : (v + 1 > inX1 ? inX1 - v : 1.f); + xCache.needsBounding = bound(xCache.start, inWidth) != xCache.start || + bound(xCache.end - 1, inWidth) != (xCache.end - 1); + } +} - err = cudaFree(coeffsTable); - if (err != 0) { - throw cuda_exception::build("helpers::bicubicInterpolateWithCaching: Cannot deallocate memory for coefficients table", err); - } +// ------------------------------------------------------------------------------------------------------------------ +// // +template +struct ScaleCache { + float yScale; + T const* yPtr; +}; +// Computes the sum of all x values defined by taken across +// the y offsets and scales defined by y_ptrs and y_scales, for channel c. +// +// Note that is a template parameter to avoid a performance +// penalty from dynamically checking it. +template +static __device__ void computePatchSumOf3Channels( + float scale, const ImageResizerState& st, ScaleCache const* yScaleCache, + Nd4jLong ptrsLen, const CachedInterpolation& xCache, float* outputPtr) { + bool const needsXBounding = xCache.needsBounding; + + auto boundIfNeeded = [needsXBounding](Nd4jLong x, Nd4jLong y) -> Nd4jLong { + return (needsXBounding ? bound(x, y) : (x)); + }; + + float sum_0 = 0; + float sum_1 = 0; + float sum_2 = 0; + for (int i = 0; i < ptrsLen; ++i) { + const T* ptr = yScaleCache[i].yPtr; + float scaleX = xCache.startScale; + Nd4jLong offset = 3 * boundIfNeeded(xCache.start, st.inWidth); + float sum_y_0 = static_cast(ptr[offset + 0]) * scaleX; + float sum_y_1 = static_cast(ptr[offset + 1]) * scaleX; + float sum_y_2 = static_cast(ptr[offset + 2]) * scaleX; + + if (xCache.start + 1 != xCache.end) { + for (Nd4jLong x = xCache.start + 1; x < xCache.end - 1; ++x) { + Nd4jLong offset = 3 * boundIfNeeded(x, st.inWidth); + sum_y_0 += static_cast(ptr[offset + 0]); + sum_y_1 += static_cast(ptr[offset + 1]); + sum_y_2 += static_cast(ptr[offset + 2]); + } + scaleX = xCache.endMinusOneScale; + offset = st.channels * boundIfNeeded(xCache.end - 1, st.inWidth); + sum_y_0 += static_cast(ptr[offset + 0]) * scaleX; + sum_y_1 += static_cast(ptr[offset + 1]) * scaleX; + sum_y_2 += static_cast(ptr[offset + 2]) * scaleX; } -//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - template - int resizeBicubicFunctor_(sd::LaunchContext * context, NDArray const* image, int width, int height, - bool preserveAspectRatio, bool antialias, NDArray* output) { - return Status::OK(); - } + sum_0 += sum_y_0 * yScaleCache[i].yScale; + sum_1 += sum_y_1 * yScaleCache[i].yScale; + sum_2 += sum_y_2 * yScaleCache[i].yScale; + } + + outputPtr[0] = sum_0 * scale; + outputPtr[1] = sum_1 * scale; + outputPtr[2] = sum_2 * scale; +} - int resizeBicubicFunctor(sd::LaunchContext * context, NDArray const* image, int width, int height, - bool preserveAspectRatio, bool antialias, NDArray* output) { - BUILD_SINGLE_SELECTOR(image->dataType(), return resizeBicubicFunctor_, (context, image, - width, height, preserveAspectRatio, antialias, output), NUMERIC_TYPES); - } - BUILD_SINGLE_TEMPLATE(template int resizeBicubicFunctor_, (sd::LaunchContext * context, NDArray const* image, int width, int height, - bool preserveAspectRatio, bool antialias, NDArray* output), NUMERIC_TYPES); -// ------------------------------------------------------------------------------------------------------------------ // - struct CachedInterpolation { - Nd4jLong start; - Nd4jLong end; - float startScale; - float endMinusOneScale; - bool needsBounding; - }; - - static __global__ void fillInterpolationCache(CachedInterpolation* xCached, Nd4jLong cacheLen, Nd4jLong inWidth, float widthScale) { - auto start = blockIdx.x * blockDim.x + threadIdx.x; - auto increment = blockDim.x * gridDim.x; - - for (auto x = start; x < cacheLen; x += increment) { - auto& xCache = xCached[x]; - const float inX = x * widthScale; - const float inX1 = (x + 1) * widthScale; - - Nd4jLong v = math::nd4j_floor(inX); - xCache.start = v; - xCache.startScale = v < inX ? (v + 1 > inX1 ? widthScale : v + 1 - inX) : (v + 1 > inX1 ? inX1 - v : 1.f); - v = math::nd4j_ceil(inX1); - xCache.end = v--; - xCache.endMinusOneScale = v < inX ? (v + 1 > inX1 ? widthScale : v + 1 - inX) : (v + 1 > inX1 ? inX1 - v : 1.f); - xCache.needsBounding = bound(xCache.start, inWidth) != xCache.start || bound(xCache.end - 1, inWidth) != (xCache.end - 1); +// Computes the sum of all x values defined by taken across +// the y offsets and scales defined by y_ptrs and y_scales, for channel c. +// +// Note that is a template parameter to avoid a performance +// penalty from dynamically checking it. +template +static __device__ void computePatchSum(float scale, const ImageResizerState& st, + ScaleCache const* yScaleCache, + Nd4jLong ptrsLen, + const CachedInterpolation& xCache, + float* outputPtr) { + bool const needsXBounding = xCache.needsBounding; + + auto boundIfNeeded = [needsXBounding](Nd4jLong x, Nd4jLong y) -> Nd4jLong { + return (needsXBounding ? bound(x, y) : (x)); + }; + + const auto numChannels = st.channels; + for (Nd4jLong c = 0; c < numChannels; ++c) { + float sum = 0; + for (int i = 0; i < ptrsLen; ++i) { + T const* ptr = yScaleCache[i].yPtr; + float scaleX = xCache.startScale; + float sumY = + static_cast( + ptr[numChannels * boundIfNeeded(xCache.start, st.inWidth) + c]) * + scaleX; + if (xCache.start + 1 != xCache.end) { + for (Nd4jLong x = xCache.start + 1; x < xCache.end - 1; ++x) { + sumY += static_cast( + ptr[numChannels * boundIfNeeded(x, st.inWidth) + c]); } + scaleX = xCache.endMinusOneScale; + sumY += + static_cast( + ptr[numChannels * boundIfNeeded(xCache.end - 1, st.inWidth) + + c]) * + scaleX; + } + sum += sumY * yScaleCache[i].yScale; } + outputPtr[c] = sum * scale; + } +} -// ------------------------------------------------------------------------------------------------------------------ // - template - struct ScaleCache { - float yScale; - T const* yPtr; - }; - - // Computes the sum of all x values defined by taken across - // the y offsets and scales defined by y_ptrs and y_scales, for channel c. - // - // Note that is a template parameter to avoid a performance - // penalty from dynamically checking it. - template - static __device__ void computePatchSumOf3Channels(float scale, - const ImageResizerState& st, - ScaleCache const* yScaleCache, - Nd4jLong ptrsLen, - const CachedInterpolation& xCache, - float* outputPtr) { - - bool const needsXBounding = xCache.needsBounding; - - auto boundIfNeeded = [needsXBounding](Nd4jLong x, Nd4jLong y) -> Nd4jLong { - return (needsXBounding ? bound(x, y) : (x)); - }; - - float sum_0 = 0; - float sum_1 = 0; - float sum_2 = 0; - for (int i = 0; i < ptrsLen; ++i) { - const T* ptr = yScaleCache[i].yPtr; - float scaleX = xCache.startScale; - Nd4jLong offset = 3 * boundIfNeeded(xCache.start, st.inWidth); - float sum_y_0 = static_cast(ptr[offset + 0]) * scaleX; - float sum_y_1 = static_cast(ptr[offset + 1]) * scaleX; - float sum_y_2 = static_cast(ptr[offset + 2]) * scaleX; - - if (xCache.start + 1 != xCache.end) { - for (Nd4jLong x = xCache.start + 1; x < xCache.end - 1; ++x) { - Nd4jLong offset = 3 * boundIfNeeded(x, st.inWidth); - sum_y_0 += static_cast(ptr[offset + 0]); - sum_y_1 += static_cast(ptr[offset + 1]); - sum_y_2 += static_cast(ptr[offset + 2]); - } - scaleX = xCache.endMinusOneScale; - offset = st.channels * boundIfNeeded(xCache.end - 1, st.inWidth); - sum_y_0 += static_cast(ptr[offset + 0]) * scaleX; - sum_y_1 += static_cast(ptr[offset + 1]) * scaleX; - sum_y_2 += static_cast(ptr[offset + 2]) * scaleX; - } - sum_0 += sum_y_0 * yScaleCache[i].yScale; - sum_1 += sum_y_1 * yScaleCache[i].yScale; - sum_2 += sum_y_2 * yScaleCache[i].yScale; +template +static __global__ void resizeAreaKernel( + ImageResizerState const* pSt, CachedInterpolation const* caches, + float scale, T const* inputPtr, Nd4jLong const* inputShape, + float* outputPtr, Nd4jLong const* outputShape, + ScaleCache* cachePool) { // batch * outWidth * outHeight + + for (auto batch = blockIdx.x; batch < pSt->batchSize; batch += gridDim.x) { + for (auto y = threadIdx.x; y < pSt->outHeight; y += blockDim.x) { + const float inY = y * pSt->heightScale; + const float inY1 = (y + 1) * pSt->heightScale; + // The start and end height indices of all the cells that could + // contribute to the target cell. + const Nd4jLong yStart = math::nd4j_floor(inY); + const Nd4jLong yEnd = math::nd4j_ceil(inY1); + auto scalesDim = yEnd - yStart; + auto yScaleCache = cachePool + (batch * pSt->outHeight + y) * pSt->outWidth; + + // auto startPtr = sharedPtr + y * scalesDim * sizeof(float); + // float* yScales = yScalesShare + y * sizeof(float) * + // scalesDim;//reinterpret_cast(startPtr); //shared + y * + // scalesDim + // * y + scalesDim * sizeof(T const *) [scalesDim]; T const** yPtrs = + // yPtrsShare + y * sizeof(T const*) * scalesDim; //[scalesDim]; yPtrs = + // reinterpret_cast(sharedBuf); + float* output = outputPtr + (batch * pSt->outHeight + y) * pSt->channels * + pSt->outWidth; + // int k = 0; + for (Nd4jLong i = yStart, k = 0; i < yEnd; ++i, ++k) { + float scaleY; + if (i < inY) { + scaleY = (i + 1 > inY1 ? pSt->heightScale : i + 1 - inY); + } else { + scaleY = (i + 1 > inY1 ? inY1 - i : 1.0); } - - outputPtr[0] = sum_0 * scale; - outputPtr[1] = sum_1 * scale; - outputPtr[2] = sum_2 * scale; - } - - // Computes the sum of all x values defined by taken across - // the y offsets and scales defined by y_ptrs and y_scales, for channel c. - // - // Note that is a template parameter to avoid a performance - // penalty from dynamically checking it. - template - static __device__ void computePatchSum(float scale, const ImageResizerState& st, - ScaleCache const* yScaleCache, Nd4jLong ptrsLen, - const CachedInterpolation& xCache, - float* outputPtr) { - - bool const needsXBounding = xCache.needsBounding; - - auto boundIfNeeded = [needsXBounding](Nd4jLong x, Nd4jLong y) -> Nd4jLong { - return (needsXBounding ? bound(x, y) : (x)); - }; - - const auto numChannels = st.channels; - for (Nd4jLong c = 0; c < numChannels; ++c) { - float sum = 0; - for (int i = 0; i < ptrsLen; ++i) { - T const* ptr = yScaleCache[i].yPtr; - float scaleX = xCache.startScale; - float sumY = static_cast(ptr[numChannels * boundIfNeeded(xCache.start, st.inWidth) + c]) * scaleX; - if (xCache.start + 1 != xCache.end) { - for (Nd4jLong x = xCache.start + 1; x < xCache.end - 1; ++x) { - sumY += static_cast( - ptr[numChannels * boundIfNeeded(x, st.inWidth) + c]); - } - scaleX = xCache.endMinusOneScale; - sumY += static_cast(ptr[numChannels * boundIfNeeded(xCache.end - 1, st.inWidth) + c]) * scaleX; - } - sum += sumY * yScaleCache[i].yScale; - } - outputPtr[c] = sum * scale; + yScaleCache[k].yScale = scaleY; + yScaleCache[k].yPtr = + inputPtr + (batch * pSt->inHeight * pSt->inWidth * pSt->channels + + bound(i, pSt->inHeight) * pSt->inWidth * pSt->channels); + } + + if (pSt->channels == 3) { + for (Nd4jLong x = 0; x < pSt->outWidth; ++x) { + const CachedInterpolation& xCache = caches[x]; + computePatchSumOf3Channels(scale, *pSt, yScaleCache, scalesDim, + xCache, output); + output += pSt->channels; } - } - - template - static __global__ void resizeAreaKernel(ImageResizerState const* pSt, CachedInterpolation const* caches, float scale, - T const* inputPtr, Nd4jLong const* inputShape, float* outputPtr, Nd4jLong const* outputShape, ScaleCache* cachePool) { //batch * outWidth * outHeight - - for (auto batch = blockIdx.x; batch < pSt->batchSize; batch += gridDim.x) { - for (auto y = threadIdx.x; y < pSt->outHeight; y += blockDim.x) { - const float inY = y * pSt->heightScale; - const float inY1 = (y + 1) * pSt->heightScale; - // The start and end height indices of all the cells that could - // contribute to the target cell. - const Nd4jLong yStart = math::nd4j_floor(inY); - const Nd4jLong yEnd = math::nd4j_ceil(inY1); - auto scalesDim = yEnd - yStart; - auto yScaleCache = cachePool + (batch * pSt->outHeight + y) * pSt->outWidth; - - //auto startPtr = sharedPtr + y * scalesDim * sizeof(float); - //float* yScales = yScalesShare + y * sizeof(float) * scalesDim;//reinterpret_cast(startPtr); //shared + y * scalesDim * y + scalesDim * sizeof(T const *) [scalesDim]; - //T const** yPtrs = yPtrsShare + y * sizeof(T const*) * scalesDim; //[scalesDim]; - //yPtrs = reinterpret_cast(sharedBuf); - float* output = outputPtr + (batch * pSt->outHeight + y) * pSt->channels * pSt->outWidth; - //int k = 0; - for (Nd4jLong i = yStart, k = 0; i < yEnd; ++i, ++k) { - float scaleY; - if (i < inY) { - scaleY = (i + 1 > inY1 ? pSt->heightScale : i + 1 - inY); - } else { - scaleY = (i + 1 > inY1 ? inY1 - i : 1.0); - } - yScaleCache[k].yScale = scaleY; - yScaleCache[k].yPtr = inputPtr + (batch * pSt->inHeight * pSt->inWidth * pSt->channels + bound(i, pSt->inHeight) * pSt->inWidth * pSt->channels); - } - - if (pSt->channels == 3) { - for (Nd4jLong x = 0; x < pSt->outWidth; ++x) { - const CachedInterpolation& xCache = caches[x]; - computePatchSumOf3Channels(scale, *pSt, yScaleCache, scalesDim, xCache, output); - output += pSt->channels; - } - } else { - for (Nd4jLong x = 0; x < pSt->outWidth; ++x) { - const CachedInterpolation &xCache = caches[x]; - computePatchSum(scale, *pSt, yScaleCache, scalesDim, xCache, output); - output += pSt->channels; - } - } - } + } else { + for (Nd4jLong x = 0; x < pSt->outWidth; ++x) { + const CachedInterpolation& xCache = caches[x]; + computePatchSum(scale, *pSt, yScaleCache, scalesDim, xCache, + output); + output += pSt->channels; } + } } + } +} - template - static void resizeArea(cudaStream_t* stream, ImageResizerState const& st, CachedInterpolation* cache, - NDArray const* input, NDArray* output) { - - T const* inputPtr = reinterpret_cast(input->specialBuffer()); -// float* yScales; -// T const** yPtrs; - float scale = 1.f / (st.heightScale * st.widthScale); - auto outputPtr = reinterpret_cast(output->specialBuffer()); // output is always float. TO DO: provide another float types also with template declaration - ImageResizerState* pSt; - auto err = cudaMalloc(&pSt, sizeof(ImageResizerState)); - if (err != 0) { +template +static void resizeArea(cudaStream_t* stream, ImageResizerState const& st, + CachedInterpolation* cache, NDArray const* input, + NDArray* output) { + T const* inputPtr = reinterpret_cast(input->specialBuffer()); + // float* yScales; + // T const** yPtrs; + float scale = 1.f / (st.heightScale * st.widthScale); + auto outputPtr = reinterpret_cast( + output->specialBuffer()); // output is always float. TO DO: provide + // another float types also with template + // declaration + ImageResizerState* pSt; + auto err = cudaMalloc(&pSt, sizeof(ImageResizerState)); + if (err != 0) { throw cuda_exception::build("helpers::resizeArea: Cannot allocate memory for ImageResizerState", err); - } - - err = cudaMemcpyAsync(pSt, &st, sizeof(ImageResizerState), cudaMemcpyHostToDevice, *stream); - if (err != 0) { + }err = cudaMemcpyAsync(pSt, &st, sizeof(ImageResizerState), + cudaMemcpyHostToDevice, *stream); + if (err != 0) { throw cuda_exception::build("helpers::resizeArea: Cannot copy to device memory", err); - } - ScaleCache* cachePool; - auto cachePoolSize = sizeof(ScaleCache) * st.batchSize * st.outWidth * st.outHeight; - err = cudaMalloc(&cachePool, cachePoolSize); + }ScaleCache* cachePool; + auto cachePoolSize = sizeof(ScaleCache) * st.batchSize * + st.outWidth * st.outHeight; + err = cudaMalloc(&cachePool, cachePoolSize); if (err != 0) { throw cuda_exception::build("helpers::resizeArea: Cannot allocate memory for cache", err); } - resizeAreaKernel<<<128, 128, 2048, *stream>>>(pSt, cache, scale, inputPtr, input->specialShapeInfo(), outputPtr, - output->specialShapeInfo(), cachePool); - err = cudaStreamSynchronize(*stream); - if (err != 0) { + resizeAreaKernel<<<128, 128, 2048, *stream>>>( + pSt, cache, scale, inputPtr, input->specialShapeInfo(), outputPtr, + output->specialShapeInfo(), cachePool); + err = cudaStreamSynchronize(*stream); + if (err != 0) { throw cuda_exception::build("helpers::resizeArea: An error occured with kernel running", err); - } - err = cudaFree(cachePool); - if (err != 0) { + }err = cudaFree(cachePool); + if (err != 0) { throw cuda_exception::build("helpers::resizeArea: Cannot deallocate memory for cache", err); - } - err = cudaFree(pSt); - if (err != 0) { + }err = cudaFree(pSt); +if (err != 0) { throw cuda_exception::build("helpers::resizeArea: Cannot deallocate memory for ImageResizeState", err); - } - } -// ------------------------------------------------------------------------------------------------------------------ // - template - int resizeAreaFunctor_(sd::LaunchContext* context, NDArray const* image, int const width, int const height, - bool const alignCorners, NDArray* output) { - - ImageResizerState st(alignCorners, false); // Create resize info - auto res = st.validateAndCalculateOutputSize(image, width, height); - auto stream = context->getCudaStream(); - if (Status::OK() == res) { - CachedInterpolation* xCached; - //(st.outWidth); - auto err = cudaMalloc(&xCached, sizeof(CachedInterpolation) * st.outWidth); - if (err != 0) { + }} +// ------------------------------------------------------------------------------------------------------------------ +// // +template +int resizeAreaFunctor_(sd::LaunchContext* context, NDArray const* image, + int const width, int const height, + bool const alignCorners, NDArray* output) { + ImageResizerState st(alignCorners, false); // Create resize info + auto res = st.validateAndCalculateOutputSize(image, width, height); + auto stream = context->getCudaStream(); + if (Status::OK() == res) { + CachedInterpolation* xCached; + //(st.outWidth); + auto err = cudaMalloc(&xCached, sizeof(CachedInterpolation) * st.outWidth);if (err != 0) { throw cuda_exception::build("helpers::resizeAreaFunctor_: Cannot allocate memory for cached interpolations", err); } - NDArray::prepareSpecialUse({output}, {image}); - fillInterpolationCache<<<128, 128, 256, *stream>>>(xCached, st.outWidth, st.inWidth, st.widthScale); - resizeArea(stream, st, xCached, image, output); - err = cudaStreamSynchronize(*stream); - if (err != 0) { + NDArray::prepareSpecialUse({output}, {image}); + fillInterpolationCache<<<128, 128, 256, *stream>>>( + xCached, st.outWidth, st.inWidth, st.widthScale); + resizeArea(stream, st, xCached, image, output); + err = cudaStreamSynchronize(*stream); + if (err != 0) { throw cuda_exception::build("helpers::resizeAreaFunctor_: Error occured when kernel was running", err); } err = cudaFree(xCached); if (err != 0) { throw cuda_exception::build("helpers::resizeAreaFunctor_: Cannot deallocate memory for cached interpolations", err); } - NDArray::registerSpecialUse({output}, {image}); - } + NDArray::registerSpecialUse({output}, {image}); + } - return res; - } - int resizeAreaFunctor(sd::LaunchContext * context, NDArray const* image, int const width, int const height, - bool const alignCorners, NDArray* output) { - BUILD_SINGLE_SELECTOR(image->dataType(), return resizeAreaFunctor_, (context, image, width, height, alignCorners, output), NUMERIC_TYPES); - } + return res; +} +int resizeAreaFunctor(sd::LaunchContext* context, NDArray const* image, + int const width, int const height, + bool const alignCorners, NDArray* output) { + BUILD_SINGLE_SELECTOR(image->dataType(), return resizeAreaFunctor_, + (context, image, width, height, alignCorners, output), + NUMERIC_TYPES); +} -// ------------------------------------------------------------------------------------------------------------------ // -// simplified bicubic resize without antialiasing +// ------------------------------------------------------------------------------------------------------------------ +// // simplified bicubic resize without antialiasing // - template - int resizeBicubicFunctorA_(sd::LaunchContext * context, NDArray const* image, int width, int height, - bool const alignCorners, bool const halfPixelCenters, NDArray* output) { - - ImageResizerState st(alignCorners, halfPixelCenters); // align_corners, half_pixel_align - st.stream = context->getCudaStream(); - NDArray::prepareSpecialUse({output}, {image}); - int res = st.validateAndCreateOutput(image, width, height); - if (res == Status::OK()) - bicubicInterpolateWithCaching(image, st, halfPixelCenters, output); - NDArray::registerSpecialUse({output}, {image}); - return res; - } +template +int resizeBicubicFunctorA_(sd::LaunchContext* context, NDArray const* image, + int width, int height, bool const alignCorners, + bool const halfPixelCenters, NDArray* output) { + ImageResizerState st(alignCorners, + halfPixelCenters); // align_corners, half_pixel_align + st.stream = context->getCudaStream(); + NDArray::prepareSpecialUse({output}, {image}); + int res = st.validateAndCreateOutput(image, width, height); + if (res == Status::OK()) + bicubicInterpolateWithCaching(image, st, halfPixelCenters, output); + NDArray::registerSpecialUse({output}, {image}); + return res; +} - int resizeBicubicFunctorA(sd::LaunchContext * context, NDArray const* image, int width, int height, - bool const alignCorners, bool const halfPixelCenters, NDArray* output) { - BUILD_SINGLE_SELECTOR(image->dataType(), return resizeBicubicFunctorA_, (context, - image, width, height, alignCorners, halfPixelCenters, output), NUMERIC_TYPES); - } - BUILD_SINGLE_TEMPLATE(template int resizeBicubicFunctorA_, (sd::LaunchContext * context, - NDArray const* image, int width, int height, bool const alignCorners, bool const halfPixelCenters, NDArray* output), NUMERIC_TYPES); - - -// ------------------------------------------------------------------------------------------------------------------ // - int resizeImagesFunctor(sd::LaunchContext * context, NDArray const* image, int const width, int const height, - ImageResizeMethods method, bool alignCorners, NDArray* output) { - switch (method) { - case kResizeBilinear: - return resizeBilinearFunctor(context, image, width, height, alignCorners, false, output); - case kResizeNearest: - return resizeNeighborFunctor(context, image, width, height, alignCorners, false, output); - case kResizeBicubic: - return resizeBicubicFunctor(context, image, width, height, alignCorners, false, output); - case kResizeArea: - return resizeAreaFunctor(context, image, width, height, alignCorners, output); - default: - throw std::runtime_error("helper::resizeImagesFunctor: Wrong resize method."); - } +int resizeBicubicFunctorA(sd::LaunchContext* context, NDArray const* image, + int width, int height, bool const alignCorners, + bool const halfPixelCenters, NDArray* output) { + BUILD_SINGLE_SELECTOR( + image->dataType(), return resizeBicubicFunctorA_, + (context, image, width, height, alignCorners, halfPixelCenters, output), + NUMERIC_TYPES); +} +BUILD_SINGLE_TEMPLATE(template int resizeBicubicFunctorA_, + (sd::LaunchContext * context, NDArray const* image, + int width, int height, bool const alignCorners, + bool const halfPixelCenters, NDArray* output), + NUMERIC_TYPES); + + +// ------------------------------------------------------------------------------------------------------------------// +int resizeImagesFunctor(sd::LaunchContext* context, NDArray const* image, int const width, + int const height, ImageResizeMethods method, + bool alignCorners, NDArray* output) { + switch (method) { + case kResizeBilinear: + return resizeBilinearFunctor(context, image, width, height, alignCorners, false, + output); + + case kResizeNearest: + return resizeNeighborFunctor(context, image, width, height, alignCorners, false, + output); + + case kResizeBicubic: + return resizeBicubicFunctor(context, image, width, height, + alignCorners, false, output); + + case kResizeArea: + return resizeAreaFunctor(context, image, width, height, alignCorners, output); + default: + throw std::runtime_error("helper::resizeImagesFunctor: Wrong resize method."); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +// --------------------------------------------------------------------------------------------------------------- +// // Crop and Resize helper implementation +// -------------------------------------------------------------------------------------------------------------- +// // cropAndResize kernel type of input(images) and output should be the same +// +template +static __global__ void cropAndResizeKernel( + T const* images, Nd4jLong const* imagesShape, Z const* boxes, + Nd4jLong const* boxesShape, I const* indices, Nd4jLong const* indexShape, + I const* cropSize, Nd4jLong const* cropShape, int method, + double extrapolationVal, T* output, Nd4jLong const* outputShape, + int numBoxes, int cropHeight, int cropWidth, int batchSize, int imageHeight, + int imageWidth, int depth) { + for (int b = blockIdx.x; b < numBoxes; b += gridDim.x) { + Nd4jLong x1Pos[] = {b, 1}; + Nd4jLong y1Pos[] = {b, 0}; + Nd4jLong y2Pos[] = {b, 2}; + Nd4jLong x2Pos[] = {b, 3}; + Z y1 = boxes[shape::getOffset(boxesShape, y1Pos)]; //->t(b, 0)]; + Z x1 = boxes[shape::getOffset(boxesShape, x1Pos)]; + Z y2 = boxes[shape::getOffset(boxesShape, y2Pos)]; + Z x2 = boxes[shape::getOffset(boxesShape, x2Pos)]; + + int bIn = indices[b]; + if (bIn >= batchSize) { + continue; } - //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - // --------------------------------------------------------------------------------------------------------------- // - // Crop and Resize helper implementation - // -------------------------------------------------------------------------------------------------------------- // - // cropAndResize kernel type of input(images) and output should be the same - // - template - static __global__ void cropAndResizeKernel(T const *images, Nd4jLong const* imagesShape, Z const* boxes, Nd4jLong const* boxesShape, - I const* indices, Nd4jLong const* indexShape, I const* cropSize, Nd4jLong const* cropShape, int method, - double extrapolationVal, T* output, Nd4jLong const* outputShape, int numBoxes, int cropHeight, int cropWidth, - int batchSize, int imageHeight, int imageWidth, int depth) { - - for (int b = blockIdx.x; b < numBoxes; b += gridDim.x) - { - Nd4jLong x1Pos[] = {b, 1}; - Nd4jLong y1Pos[] = {b, 0}; - Nd4jLong y2Pos[] = {b, 2}; - Nd4jLong x2Pos[] = {b, 3}; - Z y1 = boxes[shape::getOffset(boxesShape, y1Pos)];//->t(b, 0)]; - Z x1 = boxes[shape::getOffset(boxesShape, x1Pos)]; - Z y2 = boxes[shape::getOffset(boxesShape, y2Pos)]; - Z x2 = boxes[shape::getOffset(boxesShape, x2Pos)]; - - int bIn = indices[b]; - if (bIn >= batchSize) { - continue; + Z heightScale = (cropHeight > 1) + ? (y2 - y1) * (imageHeight - 1) / Z(cropHeight - 1) + : Z(0); + Z widthScale = (cropWidth > 1) + ? (x2 - x1) * (imageWidth - 1) / Z(cropWidth - 1) + : Z(0); + + for (int y = threadIdx.x; y < cropHeight; y += blockDim.x) { + const float inY = (cropHeight > 1) + ? y1 * (imageHeight - 1) + y * heightScale + : 0.5 * (y1 + y2) * (imageHeight - 1); + if (inY < 0 || inY > imageHeight - 1) { + for (int x = threadIdx.y; x < cropWidth; x += blockDim.y) { + auto start = blockIdx.z * blockDim.x + threadIdx.z; + auto step = blockDim.z * gridDim.z; + for (int d = start; d < depth; d += step) { + Nd4jLong zPos[] = {b, y, x, d}; + auto zIndex = shape::getOffset(outputShape, zPos); + output[zIndex] = (Z)extrapolationVal; + // crops->p(b, y, x, d, extrapolationVal); + } + } + continue; + } + + if (method == 0 /* bilinear */) { + const int topYIndex = sd::math::p_floor(inY); + const int bottomYIndex = sd::math::p_ceil(inY); + const float y_lerp = inY - topYIndex; + + for (int x = 0; x < cropWidth; ++x) { + const float in_x = (cropWidth > 1) + ? x1 * (imageWidth - 1) + x * widthScale + : 0.5 * (x1 + x2) * (imageWidth - 1); + if (in_x < 0 || in_x > imageWidth - 1) { + auto start = blockIdx.z * blockDim.x + threadIdx.z; + auto step = blockDim.z * gridDim.z; + for (int d = start; d < depth; d += step) { + Nd4jLong zPos[] = {b, y, x, d}; + auto zIndex = shape::getOffset(outputShape, zPos); + output[zIndex] = (Z)extrapolationVal; + // crops->p(b, y, x, d, + // extrapolationVal); } - - Z heightScale = (cropHeight > 1) ? (y2 - y1) * (imageHeight - 1) / Z(cropHeight - 1) : Z(0); - Z widthScale = (cropWidth > 1) ? (x2 - x1) * (imageWidth - 1) / Z(cropWidth - 1) : Z(0); - - for (int y = threadIdx.x; y < cropHeight; y += blockDim.x) { - const float inY = (cropHeight > 1) - ? y1 * (imageHeight - 1) + y * heightScale - : 0.5 * (y1 + y2) * (imageHeight - 1); - if (inY < 0 || inY > imageHeight - 1) { - for (int x = threadIdx.y; x < cropWidth; x += blockDim.y) { - auto start = blockIdx.z * blockDim.x + threadIdx.z; - auto step = blockDim.z * gridDim.z; - for (int d = start; d < depth; d += step) { - Nd4jLong zPos[] = {b, y, x, d}; - auto zIndex = shape::getOffset(outputShape, zPos); - output[zIndex] = (Z)extrapolationVal; - //crops->p(b, y, x, d, extrapolationVal); - } - } - continue; - } - - if (method == 0 /* bilinear */) { - const int topYIndex = sd::math::p_floor(inY); - const int bottomYIndex = sd::math::p_ceil(inY); - const float y_lerp = inY - topYIndex; - - for (int x = 0; x < cropWidth; ++x) { - const float in_x = (cropWidth > 1) - ? x1 * (imageWidth - 1) + x * widthScale - : 0.5 * (x1 + x2) * (imageWidth - 1); - if (in_x < 0 || in_x > imageWidth - 1) { - auto start = blockIdx.z * blockDim.x + threadIdx.z; - auto step = blockDim.z * gridDim.z; - for (int d = start; d < depth; d += step) { - Nd4jLong zPos[] = {b, y, x, d}; - auto zIndex = shape::getOffset(outputShape, zPos); - output[zIndex] = (Z)extrapolationVal; -// crops->p(b, y, x, d, extrapolationVal); - } - continue; - } - int left_x_index = math::p_floor(in_x); - int right_x_index = math::p_ceil(in_x); - T x_lerp = in_x - left_x_index; - - auto start = blockIdx.z * blockDim.x + threadIdx.z; - auto step = blockDim.z * gridDim.z; - for (int d = start; d < depth; d += step) { - Nd4jLong topLeftPos[] = {bIn, topYIndex, left_x_index, d}; - Nd4jLong topRightPos[] = {bIn, topYIndex, right_x_index, d}; - Nd4jLong bottomLeftPos[] = {bIn, bottomYIndex, left_x_index, d}; - Nd4jLong bottomRightPos[] = {bIn, bottomYIndex, right_x_index, d}; - const T topLeft(images[shape::getOffset(imagesShape, topLeftPos)]); //->e(bIn, topYIndex, left_x_index, d)); - const T topRight(images[shape::getOffset(imagesShape, topRightPos)]); //->e(bIn, topYIndex, right_x_index, d)); - const T bottomLeft(images[shape::getOffset(imagesShape, bottomLeftPos)]);//->e(bIn, bottomYIndex, left_x_index, d)); - const T bottomRight(images[shape::getOffset(imagesShape, bottomRightPos)]); //->e(bIn, bottomYIndex, right_x_index, d)); - const T top = topLeft + (topRight - topLeft) * x_lerp; - const T bottom = bottomLeft + (bottomRight - bottomLeft) * x_lerp; - Nd4jLong zPos[] = {b, y, x, d}; - auto zIndex = shape::getOffset(outputShape, zPos); - output[zIndex] = Z(top + (bottom - top) * y_lerp); - } - } - } else { // method is "nearest neighbor" - for (int x = 0; x < cropWidth; ++x) { - const float inX = (cropWidth > 1) - ? x1 * (imageWidth - 1) + x * widthScale - : 0.5 * (x1 + x2) * (imageWidth - 1); - if (inX < 0 || inX > imageWidth - 1) { - auto start = blockIdx.z * blockDim.x + threadIdx.z; - auto step = blockDim.z * gridDim.z; - for (int d = start; d < depth; d += step) { - Nd4jLong zPos[] = {b, y, x, d}; - auto zIndex = shape::getOffset(outputShape, zPos); - output[zIndex] = (Z)extrapolationVal; - } - continue; - } - const int closestXIndex = roundf(inX); - const int closestYIndex = roundf(inY); - auto start = blockIdx.z * blockDim.x + threadIdx.z; - auto step = blockDim.z * gridDim.z; - for (int d = start; d < depth; d += step) { - Nd4jLong zPos[] = {b, y, x, d}; - Nd4jLong xPos[] = {bIn, closestYIndex, closestXIndex, d}; - auto zIndex = shape::getOffset(outputShape, zPos); - auto xIndex = shape::getOffset(imagesShape, xPos); - output[zIndex] = images[xIndex]; - } - } - } + continue; + } + int left_x_index = math::p_floor(in_x); + int right_x_index = math::p_ceil(in_x); + T x_lerp = in_x - left_x_index; + + auto start = blockIdx.z * blockDim.x + threadIdx.z; + auto step = blockDim.z * gridDim.z; + for (int d = start; d < depth; d += step) { + Nd4jLong topLeftPos[] = {bIn, topYIndex, left_x_index, d}; + Nd4jLong topRightPos[] = {bIn, topYIndex, right_x_index, d}; + Nd4jLong bottomLeftPos[] = {bIn, bottomYIndex, left_x_index, d}; + Nd4jLong bottomRightPos[] = {bIn, bottomYIndex, right_x_index, d}; + const T topLeft(images[shape::getOffset( + imagesShape, + topLeftPos)]); //->e(bIn, topYIndex, left_x_index, d)); + const T topRight(images[shape::getOffset( + imagesShape, topRightPos)]); //->e(bIn, topYIndex, + // right_x_index, d)); + const T bottomLeft(images[shape::getOffset( + imagesShape, bottomLeftPos)]); //->e(bIn, bottomYIndex, + // left_x_index, d)); + const T bottomRight(images[shape::getOffset( + imagesShape, bottomRightPos)]); //->e(bIn, bottomYIndex, + // right_x_index, d)); + const T top = topLeft + (topRight - topLeft) * x_lerp; + const T bottom = bottomLeft + (bottomRight - bottomLeft) * x_lerp; + Nd4jLong zPos[] = {b, y, x, d}; + auto zIndex = shape::getOffset(outputShape, zPos); + output[zIndex] = Z(top + (bottom - top) * y_lerp); + } + } + } else { // method is "nearest neighbor" + for (int x = 0; x < cropWidth; ++x) { + const float inX = (cropWidth > 1) + ? x1 * (imageWidth - 1) + x * widthScale + : 0.5 * (x1 + x2) * (imageWidth - 1); + if (inX < 0 || inX > imageWidth - 1) { + auto start = blockIdx.z * blockDim.x + threadIdx.z; + auto step = blockDim.z * gridDim.z; + for (int d = start; d < depth; d += step) { + Nd4jLong zPos[] = {b, y, x, d}; + auto zIndex = shape::getOffset(outputShape, zPos); + output[zIndex] = (Z)extrapolationVal; } + continue; + } + const int closestXIndex = roundf(inX); + const int closestYIndex = roundf(inY); + auto start = blockIdx.z * blockDim.x + threadIdx.z; + auto step = blockDim.z * gridDim.z; + for (int d = start; d < depth; d += step) { + Nd4jLong zPos[] = {b, y, x, d}; + Nd4jLong xPos[] = {bIn, closestYIndex, closestXIndex, d}; + auto zIndex = shape::getOffset(outputShape, zPos); + auto xIndex = shape::getOffset(imagesShape, xPos); + output[zIndex] = images[xIndex]; + } } - + } } + } +} //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// // cropAndResizeFunctor main algorithm @@ -1357,43 +1600,59 @@ namespace helpers { // extrapolationVal - double value of extrapolation // crops - output (4D tensor - [batch, outWidth, outHeight, pixels]) // - template - void cropAndResizeFunctor_(sd::LaunchContext* context, NDArray const *images, NDArray const *boxes, NDArray const *indices, - NDArray const *cropSize, int method, double extrapolationVal, NDArray *crops) { - const int batchSize = images->sizeAt(0); - const int imageHeight = images->sizeAt(1); - const int imageWidth = images->sizeAt(2); - - const int numBoxes = crops->sizeAt(0); - const int cropHeight = crops->sizeAt(1); - const int cropWidth = crops->sizeAt(2); - const int depth = crops->sizeAt(3); - auto stream = context->getCudaStream(); - T const* imagesBuf = reinterpret_cast(images->specialBuffer()); - Z const* boxesBuf = reinterpret_cast(boxes->specialBuffer()); - I const* indexBuf = reinterpret_cast(indices->specialBuffer()); - I const* cropSizes = reinterpret_cast(cropSize->specialBuffer()); - T* outBuf = reinterpret_cast(crops->specialBuffer()); - - int threadsPerBlock = math::nd4j_max(imageHeight * imageWidth, cropHeight * cropWidth); - if(threadsPerBlock > MAX_NUM_THREADS/4) - threadsPerBlock = MAX_NUM_THREADS/4; - - NDArray::prepareSpecialUse({crops}, {images, boxes, indices, cropSize}); - cropAndResizeKernel<<>>(imagesBuf, images->specialShapeInfo(), boxesBuf, boxes->specialShapeInfo(), indexBuf, indices->specialShapeInfo(), - cropSizes, cropSize->specialShapeInfo(), method, extrapolationVal, outBuf, crops->specialShapeInfo(), numBoxes, cropHeight, cropWidth, batchSize, imageHeight, imageWidth, depth); - NDArray::registerSpecialUse({crops}, {images, boxes, indices, cropSize}); - } +template +void cropAndResizeFunctor_(sd::LaunchContext* context, NDArray const* images, + NDArray const* boxes, NDArray const* indices, + NDArray const* cropSize, int method, + double extrapolationVal, NDArray* crops) { + const int batchSize = images->sizeAt(0); + const int imageHeight = images->sizeAt(1); + const int imageWidth = images->sizeAt(2); + + const int numBoxes = crops->sizeAt(0); + const int cropHeight = crops->sizeAt(1); + const int cropWidth = crops->sizeAt(2); + const int depth = crops->sizeAt(3); + auto stream = context->getCudaStream(); + T const* imagesBuf = reinterpret_cast(images->specialBuffer()); + Z const* boxesBuf = reinterpret_cast(boxes->specialBuffer()); + I const* indexBuf = reinterpret_cast(indices->specialBuffer()); + I const* cropSizes = reinterpret_cast(cropSize->specialBuffer()); + T* outBuf = reinterpret_cast(crops->specialBuffer()); + + int threadsPerBlock = + math::nd4j_max(imageHeight * imageWidth, cropHeight * cropWidth); + if (threadsPerBlock > MAX_NUM_THREADS / 4) + threadsPerBlock = MAX_NUM_THREADS / 4; + + NDArray::prepareSpecialUse({crops}, {images, boxes, indices, cropSize}); + cropAndResizeKernel<<>>( + imagesBuf, images->specialShapeInfo(), boxesBuf, + boxes->specialShapeInfo(), indexBuf, indices->specialShapeInfo(), + cropSizes, cropSize->specialShapeInfo(), method, extrapolationVal, outBuf, + crops->specialShapeInfo(), numBoxes, cropHeight, cropWidth, batchSize, + imageHeight, imageWidth, depth); + NDArray::registerSpecialUse({crops}, {images, boxes, indices, cropSize}); +} //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - void cropAndResizeFunctor(sd::LaunchContext * context, NDArray const *images, NDArray const *boxes, NDArray const *indices, NDArray const *cropSize, int method, double extrapolationVal, NDArray *crops) { - BUILD_TRIPLE_SELECTOR(images->dataType(), boxes->dataType(), indices->dataType(), cropAndResizeFunctor_, - (context, images, boxes, indices, cropSize, method, extrapolationVal, crops), NUMERIC_TYPES, FLOAT_TYPES, INTEGER_TYPES); - // - } - BUILD_TRIPLE_TEMPLATE(template void cropAndResizeFunctor_, - (sd::LaunchContext * context, NDArray const* images, NDArray const* boxes, NDArray const* indices, NDArray const* cropSize, int method, double extrapolationVal, NDArray* crops), - NUMERIC_TYPES, FLOAT_TYPES, INTEGER_TYPES); -} +void cropAndResizeFunctor(sd::LaunchContext* context, NDArray const* images, + NDArray const* boxes, NDArray const* indices, + NDArray const* cropSize, int method, + double extrapolationVal, NDArray* crops) { + BUILD_TRIPLE_SELECTOR(images->dataType(), boxes->dataType(), + indices->dataType(), cropAndResizeFunctor_, + (context, images, boxes, indices, cropSize, method, + extrapolationVal, crops), + NUMERIC_TYPES, FLOAT_TYPES, INTEGER_TYPES); + // } -} \ No newline at end of file +BUILD_TRIPLE_TEMPLATE(template void cropAndResizeFunctor_, + (sd::LaunchContext * context, NDArray const* images, + NDArray const* boxes, NDArray const* indices, + NDArray const* cropSize, int method, + double extrapolationVal, NDArray* crops), + NUMERIC_TYPES, FLOAT_TYPES, INTEGER_TYPES); +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/image_suppression.cu b/libnd4j/include/ops/declarable/helpers/cuda/image_suppression.cu index 8b7e8ee57074..96825de5968a 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/image_suppression.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/image_suppression.cu @@ -18,10 +18,11 @@ // @author sgazeos@gmail.com // -#include #include -#include #include +#include +#include + #include namespace sd { @@ -37,378 +38,477 @@ namespace helpers { // // return value: true, if threshold is overcome, false otherwise // - template - static __device__ bool needToSuppressWithThreshold(T* boxes, Nd4jLong const* boxesShape, int previousIndex, int nextIndex, T threshold) { - Nd4jLong previous0[] = {previousIndex, 0}; - Nd4jLong previous1[] = {previousIndex, 1}; - Nd4jLong previous2[] = {previousIndex, 2}; - Nd4jLong previous3[] = {previousIndex, 3}; - Nd4jLong next0[] = {nextIndex, 0}; - Nd4jLong next1[] = {nextIndex, 1}; - Nd4jLong next2[] = {nextIndex, 2}; - Nd4jLong next3[] = {nextIndex, 3}; - - // we have rectangle with given max values. Compute vexes of rectangle first - - T minYPrev = sd::math::nd4j_min(boxes[shape::getOffset(boxesShape, previous0)], boxes[shape::getOffset(boxesShape, previous2)]); - T minXPrev = sd::math::nd4j_min(boxes[shape::getOffset(boxesShape, previous1)], boxes[shape::getOffset(boxesShape, previous3)]); - T maxYPrev = sd::math::nd4j_max(boxes[shape::getOffset(boxesShape, previous0)], boxes[shape::getOffset(boxesShape, previous2)]); - T maxXPrev = sd::math::nd4j_max(boxes[shape::getOffset(boxesShape, previous1)], boxes[shape::getOffset(boxesShape, previous3)]); - T minYNext = sd::math::nd4j_min(boxes[shape::getOffset(boxesShape, next0)], boxes[shape::getOffset(boxesShape, next2)]); - T minXNext = sd::math::nd4j_min(boxes[shape::getOffset(boxesShape, next1)], boxes[shape::getOffset(boxesShape, next3)]); - T maxYNext = sd::math::nd4j_max(boxes[shape::getOffset(boxesShape, next0)], boxes[shape::getOffset(boxesShape, next2)]); - T maxXNext = sd::math::nd4j_max(boxes[shape::getOffset(boxesShape, next1)], boxes[shape::getOffset(boxesShape, next3)]); - - // compute areas for comparation - T areaPrev = (maxYPrev - minYPrev) * (maxXPrev - minXPrev); - T areaNext = (maxYNext - minYNext) * (maxXNext - minXNext); - - // of course, areas should be positive - if (areaNext <= T(0.f) || areaPrev <= T(0.f)) return false; - - // compute intersection of rectangles - T minIntersectionY = sd::math::nd4j_max(minYPrev, minYNext); - T minIntersectionX = sd::math::nd4j_max(minXPrev, minXNext); - T maxIntersectionY = sd::math::nd4j_min(maxYPrev, maxYNext); - T maxIntersectionX = sd::math::nd4j_min(maxXPrev, maxXNext); - T intersectionArea = - sd::math::nd4j_max(T(maxIntersectionY - minIntersectionY), T(0.0f)) * - sd::math::nd4j_max(T(maxIntersectionX - minIntersectionX), T(0.0f)); - T intersectionValue = intersectionArea / (areaPrev + areaNext - intersectionArea); - // final check - return intersectionValue > threshold; - } +template +static __device__ bool needToSuppressWithThreshold(T* boxes, + Nd4jLong const* boxesShape, + int previousIndex, + int nextIndex, T threshold) { + Nd4jLong previous0[] = {previousIndex, 0}; + Nd4jLong previous1[] = {previousIndex, 1}; + Nd4jLong previous2[] = {previousIndex, 2}; + Nd4jLong previous3[] = {previousIndex, 3}; + Nd4jLong next0[] = {nextIndex, 0}; + Nd4jLong next1[] = {nextIndex, 1}; + Nd4jLong next2[] = {nextIndex, 2}; + Nd4jLong next3[] = {nextIndex, 3}; + + // we have rectangle with given max values. Compute vexes of rectangle first + + T minYPrev = + sd::math::nd4j_min(boxes[shape::getOffset(boxesShape, previous0)], + boxes[shape::getOffset(boxesShape, previous2)]); + T minXPrev = + sd::math::nd4j_min(boxes[shape::getOffset(boxesShape, previous1)], + boxes[shape::getOffset(boxesShape, previous3)]); + T maxYPrev = + sd::math::nd4j_max(boxes[shape::getOffset(boxesShape, previous0)], + boxes[shape::getOffset(boxesShape, previous2)]); + T maxXPrev = + sd::math::nd4j_max(boxes[shape::getOffset(boxesShape, previous1)], + boxes[shape::getOffset(boxesShape, previous3)]); + T minYNext = sd::math::nd4j_min(boxes[shape::getOffset(boxesShape, next0)], + boxes[shape::getOffset(boxesShape, next2)]); + T minXNext = sd::math::nd4j_min(boxes[shape::getOffset(boxesShape, next1)], + boxes[shape::getOffset(boxesShape, next3)]); + T maxYNext = sd::math::nd4j_max(boxes[shape::getOffset(boxesShape, next0)], + boxes[shape::getOffset(boxesShape, next2)]); + T maxXNext = sd::math::nd4j_max(boxes[shape::getOffset(boxesShape, next1)], + boxes[shape::getOffset(boxesShape, next3)]); + + // compute areas for comparation + T areaPrev = (maxYPrev - minYPrev) * (maxXPrev - minXPrev); + T areaNext = (maxYNext - minYNext) * (maxXNext - minXNext); + + // of course, areas should be positive + if (areaNext <= T(0.f) || areaPrev <= T(0.f)) return false; + + // compute intersection of rectangles + T minIntersectionY = sd::math::nd4j_max(minYPrev, minYNext); + T minIntersectionX = sd::math::nd4j_max(minXPrev, minXNext); + T maxIntersectionY = sd::math::nd4j_min(maxYPrev, maxYNext); + T maxIntersectionX = sd::math::nd4j_min(maxXPrev, maxXNext); + T intersectionArea = + sd::math::nd4j_max(T(maxIntersectionY - minIntersectionY), T(0.0f)) * + sd::math::nd4j_max(T(maxIntersectionX - minIntersectionX), T(0.0f)); + T intersectionValue = + intersectionArea / (areaPrev + areaNext - intersectionArea); + // final check + return intersectionValue > threshold; +} - template - static __device__ T similirityV3(T* boxes, Nd4jLong const* boxesShape, int previousIndex, int nextIndex) { - Nd4jLong previous0[] = {previousIndex, 0}; - Nd4jLong previous1[] = {previousIndex, 1}; - Nd4jLong previous2[] = {previousIndex, 2}; - Nd4jLong previous3[] = {previousIndex, 3}; - Nd4jLong next0[] = {nextIndex, 0}; - Nd4jLong next1[] = {nextIndex, 1}; - Nd4jLong next2[] = {nextIndex, 2}; - Nd4jLong next3[] = {nextIndex, 3}; - - // we have rectangle with given max values. Compute vexes of rectangle first - - T minYPrev = sd::math::nd4j_min(boxes[shape::getOffset(boxesShape, previous0)], boxes[shape::getOffset(boxesShape, previous2)]); - T minXPrev = sd::math::nd4j_min(boxes[shape::getOffset(boxesShape, previous1)], boxes[shape::getOffset(boxesShape, previous3)]); - T maxYPrev = sd::math::nd4j_max(boxes[shape::getOffset(boxesShape, previous0)], boxes[shape::getOffset(boxesShape, previous2)]); - T maxXPrev = sd::math::nd4j_max(boxes[shape::getOffset(boxesShape, previous1)], boxes[shape::getOffset(boxesShape, previous3)]); - T minYNext = sd::math::nd4j_min(boxes[shape::getOffset(boxesShape, next0)], boxes[shape::getOffset(boxesShape, next2)]); - T minXNext = sd::math::nd4j_min(boxes[shape::getOffset(boxesShape, next1)], boxes[shape::getOffset(boxesShape, next3)]); - T maxYNext = sd::math::nd4j_max(boxes[shape::getOffset(boxesShape, next0)], boxes[shape::getOffset(boxesShape, next2)]); - T maxXNext = sd::math::nd4j_max(boxes[shape::getOffset(boxesShape, next1)], boxes[shape::getOffset(boxesShape, next3)]); - - // compute areas for comparation - T areaPrev = (maxYPrev - minYPrev) * (maxXPrev - minXPrev); - T areaNext = (maxYNext - minYNext) * (maxXNext - minXNext); - - // of course, areas should be positive - if (areaNext <= T(0.f) || areaPrev <= T(0.f)) return false; - - // compute intersection of rectangles - T minIntersectionY = sd::math::nd4j_max(minYPrev, minYNext); - T minIntersectionX = sd::math::nd4j_max(minXPrev, minXNext); - T maxIntersectionY = sd::math::nd4j_min(maxYPrev, maxYNext); - T maxIntersectionX = sd::math::nd4j_min(maxXPrev, maxXNext); - T intersectionArea = - sd::math::nd4j_max(T(maxIntersectionY - minIntersectionY), T(0.0f)) * - sd::math::nd4j_max(T(maxIntersectionX - minIntersectionX), T(0.0f)); - T intersectionValue = intersectionArea / (areaPrev + areaNext - intersectionArea); - // final check - return intersectionValue; - } +template +static __device__ T similirityV3(T* boxes, Nd4jLong const* boxesShape, + int previousIndex, int nextIndex) { + Nd4jLong previous0[] = {previousIndex, 0}; + Nd4jLong previous1[] = {previousIndex, 1}; + Nd4jLong previous2[] = {previousIndex, 2}; + Nd4jLong previous3[] = {previousIndex, 3}; + Nd4jLong next0[] = {nextIndex, 0}; + Nd4jLong next1[] = {nextIndex, 1}; + Nd4jLong next2[] = {nextIndex, 2}; + Nd4jLong next3[] = {nextIndex, 3}; + + // we have rectangle with given max values. Compute vexes of rectangle first + + T minYPrev = + sd::math::nd4j_min(boxes[shape::getOffset(boxesShape, previous0)], + boxes[shape::getOffset(boxesShape, previous2)]); + T minXPrev = + sd::math::nd4j_min(boxes[shape::getOffset(boxesShape, previous1)], + boxes[shape::getOffset(boxesShape, previous3)]); + T maxYPrev = + sd::math::nd4j_max(boxes[shape::getOffset(boxesShape, previous0)], + boxes[shape::getOffset(boxesShape, previous2)]); + T maxXPrev = + sd::math::nd4j_max(boxes[shape::getOffset(boxesShape, previous1)], + boxes[shape::getOffset(boxesShape, previous3)]); + T minYNext = sd::math::nd4j_min(boxes[shape::getOffset(boxesShape, next0)], + boxes[shape::getOffset(boxesShape, next2)]); + T minXNext = sd::math::nd4j_min(boxes[shape::getOffset(boxesShape, next1)], + boxes[shape::getOffset(boxesShape, next3)]); + T maxYNext = sd::math::nd4j_max(boxes[shape::getOffset(boxesShape, next0)], + boxes[shape::getOffset(boxesShape, next2)]); + T maxXNext = sd::math::nd4j_max(boxes[shape::getOffset(boxesShape, next1)], + boxes[shape::getOffset(boxesShape, next3)]); + + // compute areas for comparation + T areaPrev = (maxYPrev - minYPrev) * (maxXPrev - minXPrev); + T areaNext = (maxYNext - minYNext) * (maxXNext - minXNext); + + // of course, areas should be positive + if (areaNext <= T(0.f) || areaPrev <= T(0.f)) return false; + + // compute intersection of rectangles + T minIntersectionY = sd::math::nd4j_max(minYPrev, minYNext); + T minIntersectionX = sd::math::nd4j_max(minXPrev, minXNext); + T maxIntersectionY = sd::math::nd4j_min(maxYPrev, maxYNext); + T maxIntersectionX = sd::math::nd4j_min(maxXPrev, maxXNext); + T intersectionArea = + sd::math::nd4j_max(T(maxIntersectionY - minIntersectionY), T(0.0f)) * + sd::math::nd4j_max(T(maxIntersectionX - minIntersectionX), T(0.0f)); + T intersectionValue = + intersectionArea / (areaPrev + areaNext - intersectionArea); + // final check + return intersectionValue; +} - //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// // shouldSelectKernel - compute status for all selected rectangles (boxes) // -// we compute boolean flag as shared uint32 and return it on final only for the first thread +// we compute boolean flag as shared uint32 and return it on final only for the +// first thread // - template - static __global__ void shouldSelectKernel(T* boxesBuf, Nd4jLong const* boxesShape, I* indexBuf, I* selectedIndicesData, double threshold, int numSelected, int i, bool* shouldSelect) { - auto tid = blockIdx.x * blockDim.x + threadIdx.x; - auto step = gridDim.x * blockDim.x; - __shared__ unsigned int shouldSelectShared; - if (threadIdx.x == 0) { - shouldSelectShared = (unsigned int)shouldSelect[0]; - } - __syncthreads(); - for (int j = numSelected - 1 - tid; j >= 0; j -= step) { - if (shouldSelectShared) { - if (needToSuppressWithThreshold(boxesBuf, boxesShape, indexBuf[i], - indexBuf[selectedIndicesData[j]], T(threshold))) - atomicCAS(&shouldSelectShared, 1, 0); // exchange only when need to suppress - } - } - __syncthreads(); - - // final move: collect result - if (threadIdx.x == 0) { - *shouldSelect = shouldSelectShared > 0; - } +template +static __global__ void shouldSelectKernel(T* boxesBuf, + Nd4jLong const* boxesShape, + I* indexBuf, I* selectedIndicesData, + double threshold, int numSelected, + int i, bool* shouldSelect) { + auto tid = blockIdx.x * blockDim.x + threadIdx.x; + auto step = gridDim.x * blockDim.x; + __shared__ unsigned int shouldSelectShared; + if (threadIdx.x == 0) { + shouldSelectShared = (unsigned int)shouldSelect[0]; + } + __syncthreads(); + for (int j = numSelected - 1 - tid; j >= 0; j -= step) { + if (shouldSelectShared) { + if (needToSuppressWithThreshold(boxesBuf, boxesShape, indexBuf[i], + indexBuf[selectedIndicesData[j]], + T(threshold))) + atomicCAS(&shouldSelectShared, 1, + 0); // exchange only when need to suppress } + } + __syncthreads(); + + // final move: collect result + if (threadIdx.x == 0) { + *shouldSelect = shouldSelectShared > 0; + } +} //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// // indices - type depended, indicesLong - type defined (only 64bit integers) // - template - static __global__ void copyIndices(void* indices, void* indicesLong, Nd4jLong len) { - I* indexBuf = reinterpret_cast(indices); - Nd4jLong* srcBuf = reinterpret_cast(indicesLong);; +template +static __global__ void copyIndices(void* indices, void* indicesLong, + Nd4jLong len) { + I* indexBuf = reinterpret_cast(indices); + Nd4jLong* srcBuf = reinterpret_cast(indicesLong); + ; - auto tid = threadIdx.x + blockIdx.x * blockDim.x; - auto step = blockDim.x * gridDim.x; + auto tid = threadIdx.x + blockIdx.x * blockDim.x; + auto step = blockDim.x * gridDim.x; - for (auto i = tid; i < len; i += step) - indexBuf[i] = (I)srcBuf[i]; - } + for (auto i = tid; i < len; i += step) indexBuf[i] = (I)srcBuf[i]; +} - template - static __global__ void suppressScores(T* scores, I* indices, Nd4jLong length, T scoreThreshold) { - auto start = blockIdx.x * blockDim.x; - auto step = gridDim.x * blockDim.x; - - for (auto e = start + threadIdx.x; e < (int)length; e += step) { - if (scores[e] < scoreThreshold) { - scores[e] = scoreThreshold; - indices[e] = -1; - } - else { - indices[e] = I(e); - } - } +template +static __global__ void suppressScores(T* scores, I* indices, Nd4jLong length, + T scoreThreshold) { + auto start = blockIdx.x * blockDim.x; + auto step = gridDim.x * blockDim.x; + + for (auto e = start + threadIdx.x; e < (int)length; e += step) { + if (scores[e] < scoreThreshold) { + scores[e] = scoreThreshold; + indices[e] = -1; + } else { + indices[e] = I(e); } + } +} //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// -// nonMaxSuppressionV2 algorithm - given from TF NonMaxSuppressionV2 implementation +// nonMaxSuppressionV2 algorithm - given from TF NonMaxSuppressionV2 +// implementation // - template - static void nonMaxSuppressionV2_(sd::LaunchContext* context, NDArray* boxes, NDArray* scales, int maxSize, double threshold, double scoreThreshold, NDArray* output) { - auto stream = context->getCudaStream(); - NDArray::prepareSpecialUse({output}, {boxes, scales}); - std::unique_ptr indices(NDArrayFactory::create_('c', {scales->lengthOf()}, context)); // - 1, scales->lengthOf()); //, scales->getContext()); - - NDArray scores(*scales); - Nd4jPointer extras[2] = {nullptr, stream}; - auto indexBuf = indices->dataBuffer()->specialAsT();///reinterpret_cast(indices->specialBuffer()); - auto scoreBuf = scores.dataBuffer()->specialAsT(); - suppressScores<<<128, 128, 128, *stream>>>(scoreBuf, indexBuf, scores.lengthOf(), T(scoreThreshold)); - indices->tickWriteDevice(); - sortByValue(extras, indices->buffer(), indices->shapeInfo(), indices->specialBuffer(), indices->specialShapeInfo(), scores.buffer(), scores.shapeInfo(), scores.specialBuffer(), scores.specialShapeInfo(), true); - indices->tickWriteDevice(); - NDArray selectedIndices = NDArrayFactory::create('c', {output->lengthOf()}, context); - int numSelected = 0; - int numBoxes = boxes->sizeAt(0); - auto boxesBuf = reinterpret_cast(boxes->specialBuffer()); - - auto selectedIndicesData = reinterpret_cast(selectedIndices.specialBuffer()); - auto outputBuf = reinterpret_cast(output->specialBuffer()); - - bool* shouldSelectD; - auto err = cudaMalloc(&shouldSelectD, sizeof(bool)); - if (err) { - throw cuda_exception::build("helpers::nonMaxSuppressionV2: Cannot allocate memory for bool flag", err); - } - for (I i = 0; i < boxes->sizeAt(0); ++i) { - bool shouldSelect = numSelected < output->lengthOf(); - if (shouldSelect) { - err = cudaMemcpy(shouldSelectD, &shouldSelect, sizeof(bool), cudaMemcpyHostToDevice); - if (err) { - throw cuda_exception::build("helpers::nonMaxSuppressionV2: Cannot set up bool flag to device", err); - } - - shouldSelectKernel<<<128, 256, 1024, *stream>>>(boxesBuf, boxes->specialShapeInfo(), indexBuf, selectedIndicesData, threshold, numSelected, i, shouldSelectD); - err = cudaMemcpy(&shouldSelect, shouldSelectD, sizeof(bool), cudaMemcpyDeviceToHost); - if (err) { - throw cuda_exception::build("helpers::nonMaxSuppressionV2: Cannot set up bool flag to host", err); - } - } - - if (shouldSelect) { - cudaMemcpy(reinterpret_cast(output->specialBuffer()) + numSelected, indexBuf + i, sizeof(I), cudaMemcpyDeviceToDevice); - cudaMemcpy(selectedIndicesData + numSelected, &i, sizeof(I), cudaMemcpyHostToDevice); - numSelected++; - } - } - - err = cudaFree(shouldSelectD); - if (err) { - throw cuda_exception::build("helpers::nonMaxSuppressionV2: Cannot deallocate memory for bool flag", err); - } - +template +static void nonMaxSuppressionV2_(sd::LaunchContext* context, NDArray* boxes, + NDArray* scales, int maxSize, double threshold, + double scoreThreshold, NDArray* output) { + auto stream = context->getCudaStream(); + NDArray::prepareSpecialUse({output}, {boxes, scales}); + std::unique_ptr indices(NDArrayFactory::create_( + 'c', {scales->lengthOf()}, + context)); // - 1, scales->lengthOf()); //, scales->getContext()); + + NDArray scores(*scales); + Nd4jPointer extras[2] = {nullptr, stream}; + auto indexBuf = + indices->dataBuffer() + ->specialAsT< + I>(); /// reinterpret_cast(indices->specialBuffer()); + auto scoreBuf = scores.dataBuffer()->specialAsT(); + suppressScores<<<128, 128, 128, *stream>>>( + scoreBuf, indexBuf, scores.lengthOf(), T(scoreThreshold)); + indices->tickWriteDevice(); + sortByValue(extras, indices->buffer(), indices->shapeInfo(), + indices->specialBuffer(), indices->specialShapeInfo(), + scores.buffer(), scores.shapeInfo(), scores.specialBuffer(), + scores.specialShapeInfo(), true); + indices->tickWriteDevice(); + NDArray selectedIndices = + NDArrayFactory::create('c', {output->lengthOf()}, context); + int numSelected = 0; + int numBoxes = boxes->sizeAt(0); + auto boxesBuf = reinterpret_cast(boxes->specialBuffer()); + + auto selectedIndicesData = + reinterpret_cast(selectedIndices.specialBuffer()); + auto outputBuf = reinterpret_cast(output->specialBuffer()); + + bool* shouldSelectD; + auto err = cudaMalloc(&shouldSelectD, sizeof(bool)); + if (err) { + throw cuda_exception::build( + "helpers::nonMaxSuppressionV2: Cannot allocate memory for bool flag", + err); + } + for (I i = 0; i < boxes->sizeAt(0); ++i) { + bool shouldSelect = numSelected < output->lengthOf(); + if (shouldSelect) { + err = cudaMemcpy(shouldSelectD, &shouldSelect, sizeof(bool), + cudaMemcpyHostToDevice); + if (err) { + throw cuda_exception::build( + "helpers::nonMaxSuppressionV2: Cannot set up bool flag to device", + err); + } + + shouldSelectKernel<<<128, 256, 1024, *stream>>>( + boxesBuf, boxes->specialShapeInfo(), indexBuf, selectedIndicesData, + threshold, numSelected, i, shouldSelectD); + err = cudaMemcpy(&shouldSelect, shouldSelectD, sizeof(bool), + cudaMemcpyDeviceToHost); + if (err) { + throw cuda_exception::build( + "helpers::nonMaxSuppressionV2: Cannot set up bool flag to host", + err); + } } -//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - template - static __device__ bool checkOverlapBoxes(T* boxes, Nd4jLong const* shape, T* scores, I* indices, I* selectedIndices, I* startIndices, I selectedSize, I nextCandidateIndex, T overlapThreshold, T scoreThreshold, bool simple) { - bool shouldHardSuppress = false; - T& nextCandidateScore = scores[nextCandidateIndex]; - I selectedIndex = indices[nextCandidateIndex]; - I finish = startIndices[nextCandidateIndex]; - - for (int j = selectedSize; j > finish; --j) { - T boxVal; - if (simple) { - Nd4jLong xPos[] = {selectedIndex, selectedIndices[j - 1]}; - auto xShift = shape::getOffset(shape, xPos, 0); - boxVal = boxes[xShift]; - } - else { - boxVal = similirityV3(boxes, shape, selectedIndex, selectedIndices[j - 1]); - } - if (boxVal > static_cast(overlapThreshold)) - nextCandidateScore = static_cast(0.f); - - // First decide whether to perform hard suppression - if (boxVal >= overlapThreshold) { - shouldHardSuppress = true; - break; - } - - // If nextCandidate survives hard suppression, apply soft suppression - if (nextCandidateScore <= static_cast(scoreThreshold)) break; - } - - return shouldHardSuppress; + if (shouldSelect) { + cudaMemcpy(reinterpret_cast(output->specialBuffer()) + numSelected, + indexBuf + i, sizeof(I), cudaMemcpyDeviceToDevice); + cudaMemcpy(selectedIndicesData + numSelected, &i, sizeof(I), + cudaMemcpyHostToDevice); + numSelected++; } + } + + err = cudaFree(shouldSelectD); + if (err) { + throw cuda_exception::build( + "helpers::nonMaxSuppressionV2: Cannot deallocate memory for bool flag", + err); + } +} + //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - template - static __global__ void - suppressNonMaxOverlapKernel(T* boxes, Nd4jLong const* boxesShape, T* scoresData, I* indices, I* startIndices, Nd4jLong length, I maxOutputLen, - T overlapThreshold, T scoreThreshold, I* output, Nd4jLong const* outputShape, I* outputLength, bool simple) { - - __shared__ I selectedSize; - __shared__ I* tempOutput; - - if (threadIdx.x == 0) { - selectedSize = outputLength?*outputLength:maxOutputLen; - extern __shared__ unsigned char shmem[]; - tempOutput = (I*)shmem; - } - __syncthreads(); - - auto start = blockIdx.x * blockDim.x; - auto step = blockDim.x * gridDim.x; - - for (I nextCandidateIndex = start + threadIdx.x; selectedSize < maxOutputLen && nextCandidateIndex < (I)length; ) { - auto originalScore = scoresData[nextCandidateIndex];//nextCandidate._score; - I nextCandidateBoxIndex = indices[nextCandidateIndex]; - auto selectedSizeMark = selectedSize; - - // skip for cases when index is less than 0 (under score threshold) - if (nextCandidateBoxIndex < 0) { - nextCandidateIndex += step; - continue; - } - // check for overlaps - bool shouldHardSuppress = checkOverlapBoxes(boxes, boxesShape, scoresData, indices, tempOutput, startIndices, selectedSize, - nextCandidateIndex, overlapThreshold, scoreThreshold, simple);//false; - T nextCandidateScore = scoresData[nextCandidateIndex]; - - startIndices[nextCandidateIndex] = selectedSize; - if (!shouldHardSuppress) { - if (nextCandidateScore == originalScore) { - // Suppression has not occurred, so select nextCandidate - if (output) - output[selectedSize] = nextCandidateBoxIndex; - tempOutput[selectedSize] = nextCandidateBoxIndex; - math::atomics::nd4j_atomicAdd(&selectedSize, (I)1); - } - - if (nextCandidateScore > scoreThreshold) { - // Soft suppression has occurred and current score is still greater than - // scoreThreshold; add nextCandidate back onto priority queue. - continue; // in some cases, this index not 0 - } - } - nextCandidateIndex += step; - } - - if (threadIdx.x == 0) { - if (outputLength) - *outputLength = selectedSize; - } +template +static __device__ bool checkOverlapBoxes(T* boxes, Nd4jLong const* shape, + T* scores, I* indices, + I* selectedIndices, I* startIndices, + I selectedSize, I nextCandidateIndex, + T overlapThreshold, T scoreThreshold, + bool simple) { + bool shouldHardSuppress = false; + T& nextCandidateScore = scores[nextCandidateIndex]; + I selectedIndex = indices[nextCandidateIndex]; + I finish = startIndices[nextCandidateIndex]; + + for (int j = selectedSize; j > finish; --j) { + T boxVal; + if (simple) { + Nd4jLong xPos[] = {selectedIndex, selectedIndices[j - 1]}; + auto xShift = shape::getOffset(shape, xPos, 0); + boxVal = boxes[xShift]; + } else { + boxVal = + similirityV3(boxes, shape, selectedIndex, selectedIndices[j - 1]); } + if (boxVal > static_cast(overlapThreshold)) + nextCandidateScore = static_cast(0.f); -//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - template - static Nd4jLong - nonMaxSuppressionGeneric_(sd::LaunchContext* context, NDArray* boxes, NDArray* scores, int outputSize, - double overlapThreshold, double scoreThreshold, NDArray* output, bool simple) { - auto stream = context->getCudaStream(); - if (output) - NDArray::prepareSpecialUse({output}, {boxes, scores}); - else { - if (!boxes->isActualOnDeviceSide()) - boxes->syncToDevice(); - if (!scores->isActualOnDeviceSide()) - scores->syncToDevice(); - } - - NDArray indices = NDArrayFactory::create('c', {scores->lengthOf()}, context); // - 1, scales->lengthOf()); //, scales->getContext()); - NDArray startPositions = NDArrayFactory::create('c', {scores->lengthOf()}, context); - NDArray selectedScores(*scores); - Nd4jPointer extras[2] = {nullptr, stream}; - auto indexBuf = indices.dataBuffer()->specialAsT();///reinterpret_cast(indices->specialBuffer()); - - suppressScores<<<128, 128, 128, *stream>>>(selectedScores.dataBuffer()->specialAsT(), indexBuf, selectedScores.lengthOf(), T(scoreThreshold)); - - sortByValue(extras, indices.buffer(), indices.shapeInfo(), indices.specialBuffer(), indices.specialShapeInfo(), selectedScores.buffer(), selectedScores.shapeInfo(), selectedScores.specialBuffer(), selectedScores.specialShapeInfo(), true); - indices.tickWriteDevice(); - selectedScores.tickWriteDevice(); - - auto scoresData = selectedScores.dataBuffer()->specialAsT();//, numBoxes, scoresData.begin()); - - auto startIndices = startPositions.dataBuffer()->specialAsT(); - I selectedSize = 0; - Nd4jLong res = 0; - if (output) { // this part used when output shape already calculated to fill up values on output - DataBuffer selectedSizeBuf(&selectedSize, sizeof(I), DataTypeUtils::fromT()); - suppressNonMaxOverlapKernel<<<1, 1, 1024, *stream >>> (boxes->dataBuffer()->specialAsT(), - boxes->specialShapeInfo(), scoresData, indexBuf, startIndices, scores->lengthOf(), (I) outputSize, - T(overlapThreshold), T(scoreThreshold), output->dataBuffer()->specialAsT(), output->specialShapeInfo(), - selectedSizeBuf.specialAsT(), simple); - } - else { // this case used on calculation of output shape. Output and output shape shoulde be nullptr. - DataBuffer selectedSizeBuf(&selectedSize, sizeof(I), DataTypeUtils::fromT()); - suppressNonMaxOverlapKernel<<<1, 1, 1024, *stream >>> (boxes->dataBuffer()->specialAsT(), - boxes->specialShapeInfo(), scoresData, indexBuf, startIndices, scores->lengthOf(), (I)outputSize, - T(overlapThreshold), T(scoreThreshold), (I*)nullptr, (Nd4jLong*) nullptr, selectedSizeBuf.specialAsT(), simple); - selectedSizeBuf.syncToPrimary(context, true); - res = *selectedSizeBuf.primaryAsT(); - } - - if (output) - NDArray::registerSpecialUse({output}, {boxes, scores}); - - return res; + // First decide whether to perform hard suppression + if (boxVal >= overlapThreshold) { + shouldHardSuppress = true; + break; } + + // If nextCandidate survives hard suppression, apply soft suppression + if (nextCandidateScore <= static_cast(scoreThreshold)) break; + } + + return shouldHardSuppress; +} //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - void nonMaxSuppression(sd::LaunchContext * context, NDArray* boxes, NDArray* scales, int maxSize, double threshold, double scoreThreshold, NDArray* output) { - BUILD_DOUBLE_SELECTOR(boxes->dataType(), output->dataType(), nonMaxSuppressionV2_, - (context, boxes, scales, maxSize, threshold, scoreThreshold, output), - FLOAT_TYPES, INDEXING_TYPES); +template +static __global__ void suppressNonMaxOverlapKernel( + T* boxes, Nd4jLong const* boxesShape, T* scoresData, I* indices, + I* startIndices, Nd4jLong length, I maxOutputLen, T overlapThreshold, + T scoreThreshold, I* output, Nd4jLong const* outputShape, I* outputLength, + bool simple) { + __shared__ I selectedSize; + __shared__ I* tempOutput; + + if (threadIdx.x == 0) { + selectedSize = outputLength ? *outputLength : maxOutputLen; + extern __shared__ unsigned char shmem[]; + tempOutput = (I*)shmem; + } + __syncthreads(); + + auto start = blockIdx.x * blockDim.x; + auto step = blockDim.x * gridDim.x; + + for (I nextCandidateIndex = start + threadIdx.x; + selectedSize < maxOutputLen && nextCandidateIndex < (I)length;) { + auto originalScore = + scoresData[nextCandidateIndex]; // nextCandidate._score; + I nextCandidateBoxIndex = indices[nextCandidateIndex]; + auto selectedSizeMark = selectedSize; + + // skip for cases when index is less than 0 (under score threshold) + if (nextCandidateBoxIndex < 0) { + nextCandidateIndex += step; + continue; } -//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - - Nd4jLong nonMaxSuppressionGeneric(sd::LaunchContext * context, NDArray* boxes, NDArray* scales, int maxSize, double threshold, double scoreThreshold, NDArray* output) { - BUILD_DOUBLE_SELECTOR(boxes->dataType(), output ? output->dataType():DataType::INT32, return nonMaxSuppressionGeneric_, - (context, boxes, scales, maxSize, threshold, scoreThreshold, output, true), - FLOAT_TYPES, INDEXING_TYPES); - return boxes->sizeAt(0); + // check for overlaps + bool shouldHardSuppress = + checkOverlapBoxes(boxes, boxesShape, scoresData, indices, tempOutput, + startIndices, selectedSize, nextCandidateIndex, + overlapThreshold, scoreThreshold, simple); // false; + T nextCandidateScore = scoresData[nextCandidateIndex]; + + startIndices[nextCandidateIndex] = selectedSize; + if (!shouldHardSuppress) { + if (nextCandidateScore == originalScore) { + // Suppression has not occurred, so select nextCandidate + if (output) output[selectedSize] = nextCandidateBoxIndex; + tempOutput[selectedSize] = nextCandidateBoxIndex; + math::atomics::nd4j_atomicAdd(&selectedSize, (I)1); + } + + if (nextCandidateScore > scoreThreshold) { + // Soft suppression has occurred and current score is still greater than + // scoreThreshold; add nextCandidate back onto priority queue. + continue; // in some cases, this index not 0 + } } + nextCandidateIndex += step; + } - Nd4jLong - nonMaxSuppressionV3(sd::LaunchContext* context, NDArray* boxes, NDArray* scores, int maxSize, - double overlapThreshold, double scoreThreshold, NDArray* output) { - BUILD_DOUBLE_SELECTOR(boxes->dataType(), output ? output->dataType():DataType::INT32, return nonMaxSuppressionGeneric_, - (context, boxes, scores, maxSize, overlapThreshold, scoreThreshold, output, false), - FLOAT_TYPES, INDEXING_TYPES); - return boxes->sizeAt(0); - } + if (threadIdx.x == 0) { + if (outputLength) *outputLength = selectedSize; + } +} +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +template +static Nd4jLong nonMaxSuppressionGeneric_(sd::LaunchContext* context, + NDArray* boxes, NDArray* scores, + int outputSize, + double overlapThreshold, + double scoreThreshold, + NDArray* output, bool simple) { + auto stream = context->getCudaStream(); + if (output) + NDArray::prepareSpecialUse({output}, {boxes, scores}); + else { + if (!boxes->isActualOnDeviceSide()) boxes->syncToDevice(); + if (!scores->isActualOnDeviceSide()) scores->syncToDevice(); + } + + auto indices = NDArrayFactory::create('c', {scores->lengthOf()},context); + auto startPositions = NDArrayFactory::create('c', {scores->lengthOf()}, context); + auto selectedScores = scores->dup(); + Nd4jPointer extras[2] = {nullptr, stream}; + auto indexBuf = indices.dataBuffer()->template specialAsT(); + + suppressScores<<<128, 128, 128, *stream>>>( + selectedScores.dataBuffer()->specialAsT(), indexBuf, + selectedScores.lengthOf(), T(scoreThreshold)); + + sortByValue(extras, indices.buffer(), indices.shapeInfo(), + indices.specialBuffer(), indices.specialShapeInfo(), + selectedScores.buffer(), selectedScores.shapeInfo(), + selectedScores.specialBuffer(), selectedScores.specialShapeInfo(), + true); + indices.tickWriteDevice(); + selectedScores.tickWriteDevice(); + + auto scoresData = selectedScores.dataBuffer()->template specialAsT(); + + auto startIndices = startPositions.dataBuffer()->template specialAsT(); + I selectedSize = 0; + Nd4jLong res = 0; + if (output) { // this part used when output shape already calculated to fill + // up values on output + DataBuffer selectedSizeBuf(&selectedSize, sizeof(I), + DataTypeUtils::fromT()); + suppressNonMaxOverlapKernel<<<1, 1, 1024, *stream>>>( + boxes->dataBuffer()->specialAsT(), boxes->specialShapeInfo(), + scoresData, indexBuf, startIndices, scores->lengthOf(), (I)outputSize, + T(overlapThreshold), T(scoreThreshold), + output->dataBuffer()->specialAsT(), output->specialShapeInfo(), + selectedSizeBuf.specialAsT(), simple); + } else { // this case used on calculation of output shape. Output and output + // shape shoulde be nullptr. + DataBuffer selectedSizeBuf(&selectedSize, sizeof(I), + DataTypeUtils::fromT()); + suppressNonMaxOverlapKernel<<<1, 1, 1024, *stream>>>( + boxes->dataBuffer()->specialAsT(), boxes->specialShapeInfo(), + scoresData, indexBuf, startIndices, scores->lengthOf(), (I)outputSize, + T(overlapThreshold), T(scoreThreshold), (I*)nullptr, (Nd4jLong*)nullptr, + selectedSizeBuf.specialAsT(), simple); + selectedSizeBuf.syncToPrimary(context, true); + res = *selectedSizeBuf.primaryAsT(); + } + + if (output) NDArray::registerSpecialUse({output}, {boxes, scores}); + + return res; +} +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +void nonMaxSuppression(sd::LaunchContext* context, NDArray* boxes, + NDArray* scales, int maxSize, double threshold, + double scoreThreshold, NDArray* output) { + BUILD_DOUBLE_SELECTOR( + boxes->dataType(), output->dataType(), nonMaxSuppressionV2_, + (context, boxes, scales, maxSize, threshold, scoreThreshold, output), + FLOAT_TYPES, INDEXING_TYPES); } +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + +Nd4jLong nonMaxSuppressionGeneric(sd::LaunchContext* context, NDArray* boxes, + NDArray* scales, int maxSize, + double threshold, double scoreThreshold, + NDArray* output) { + BUILD_DOUBLE_SELECTOR(boxes->dataType(), + output ? output->dataType() : DataType::INT32, + return nonMaxSuppressionGeneric_, + (context, boxes, scales, maxSize, threshold, + scoreThreshold, output, true), + FLOAT_TYPES, INDEXING_TYPES); + return boxes->sizeAt(0); } + +Nd4jLong nonMaxSuppressionV3(sd::LaunchContext* context, NDArray* boxes, + NDArray* scores, int maxSize, + double overlapThreshold, double scoreThreshold, + NDArray* output) { + BUILD_DOUBLE_SELECTOR(boxes->dataType(), + output ? output->dataType() : DataType::INT32, + return nonMaxSuppressionGeneric_, + (context, boxes, scores, maxSize, overlapThreshold, + scoreThreshold, output, false), + FLOAT_TYPES, INDEXING_TYPES); + return boxes->sizeAt(0); } + +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/imagesHelpers.cu b/libnd4j/include/ops/declarable/helpers/cuda/imagesHelpers.cu index 749f60c11e99..26a3456a0b79 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/imagesHelpers.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/imagesHelpers.cu @@ -19,407 +19,498 @@ // @author Oleh Semeniv (oleg.semeniv@gmail.com) // -#include -#include #include -#include #include +#include +#include +#include - -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { - /////////////////////////////////////////////////////////////////// -template -__global__ void rgbToYuvCuda(const void* vx, const Nd4jLong* xShapeInfo, const Nd4jLong* xTadOffsets, void* vz, const Nd4jLong *zShapeInfo, const Nd4jLong* zTadOffsets, const Nd4jLong numOfTads, const int dimC) { - - const T* x = reinterpret_cast(vx); - T* z = reinterpret_cast(vz); - - __shared__ int rank; - __shared__ Nd4jLong xDimCstride, zDimCstride; - - if (threadIdx.x == 0) { - rank = shape::rank(xShapeInfo); - xDimCstride = shape::stride(xShapeInfo)[dimC]; - zDimCstride = shape::stride(zShapeInfo)[dimC]; - } - __syncthreads(); - - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - - for (Nd4jLong i = tid; i < numOfTads; i += gridDim.x * blockDim.x) { - const T* xTad = x + xTadOffsets[i]; - T* zTad = z + zTadOffsets[i]; - - rgbYuv(xTad[0], xTad[xDimCstride], xTad[2 * xDimCstride], zTad[0], zTad[zDimCstride], zTad[2 * zDimCstride]); - } - +template +__global__ void rgbToYuvCuda(const void* vx, const Nd4jLong* xShapeInfo, + const Nd4jLong* xTadOffsets, void* vz, + const Nd4jLong* zShapeInfo, + const Nd4jLong* zTadOffsets, + const Nd4jLong numOfTads, const int dimC) { + const T* x = reinterpret_cast(vx); + T* z = reinterpret_cast(vz); + + __shared__ int rank; + __shared__ Nd4jLong xDimCstride, zDimCstride; + + if (threadIdx.x == 0) { + rank = shape::rank(xShapeInfo); + xDimCstride = shape::stride(xShapeInfo)[dimC]; + zDimCstride = shape::stride(zShapeInfo)[dimC]; + } + __syncthreads(); + + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + + for (Nd4jLong i = tid; i < numOfTads; i += gridDim.x * blockDim.x) { + const T* xTad = x + xTadOffsets[i]; + T* zTad = z + zTadOffsets[i]; + + rgbYuv(xTad[0], xTad[xDimCstride], xTad[2 * xDimCstride], zTad[0], + zTad[zDimCstride], zTad[2 * zDimCstride]); + } } /////////////////////////////////////////////////////////////////// -template -linkage void rgbToYuvCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo, const Nd4jLong* xTadOffsets, void* vz, const Nd4jLong* zShapeInfo, const Nd4jLong* zTadOffsets, const Nd4jLong numOfTads, const int dimC) { - - rgbToYuvCuda << > > (vx, xShapeInfo, xTadOffsets, vz, zShapeInfo, zTadOffsets, numOfTads, dimC); +template +linkage void rgbToYuvCudaLauncher( + const int blocksPerGrid, const int threadsPerBlock, + const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo, + const Nd4jLong* xTadOffsets, void* vz, const Nd4jLong* zShapeInfo, + const Nd4jLong* zTadOffsets, const Nd4jLong numOfTads, const int dimC) { + rgbToYuvCuda<<>>( + vx, xShapeInfo, xTadOffsets, vz, zShapeInfo, zTadOffsets, numOfTads, + dimC); } /////////////////////////////////////////////////////////////////// -void transformRgbYuv(sd::LaunchContext* context, const NDArray& input, NDArray& output, const int dimC) { - - auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(input.shapeInfo(), { dimC }); - auto packZ = sd::ConstantTadHelper::getInstance().tadForDimensions(output.shapeInfo(), { dimC }); - - const Nd4jLong numOfTads = packX.numberOfTads(); - - const int threadsPerBlock = MAX_NUM_THREADS / 2; - const int blocksPerGrid = (numOfTads + threadsPerBlock - 1) / threadsPerBlock; - - PointersManager manager(context, "yuv_to_rgb"); - - NDArray::prepareSpecialUse({ &output }, { &input }); - BUILD_SINGLE_SELECTOR(input.dataType(), rgbToYuvCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), input.specialBuffer(), input.specialShapeInfo(), packX.platformOffsets(), output.specialBuffer(), output.specialShapeInfo(), packZ.platformOffsets(), numOfTads, dimC), FLOAT_TYPES); - NDArray::registerSpecialUse({ &output }, { &input }); - - manager.synchronize(); +void transformRgbYuv(sd::LaunchContext* context, const NDArray& input, + NDArray& output, const int dimC) { + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions( + input.shapeInfo(), {dimC}); + auto packZ = sd::ConstantTadHelper::getInstance().tadForDimensions( + output.shapeInfo(), {dimC}); + + const Nd4jLong numOfTads = packX.numberOfTads(); + + const int threadsPerBlock = MAX_NUM_THREADS / 2; + const int blocksPerGrid = (numOfTads + threadsPerBlock - 1) / threadsPerBlock; + + PointersManager manager(context, "yuv_to_rgb"); + + NDArray::prepareSpecialUse({&output}, {&input}); + BUILD_SINGLE_SELECTOR( + input.dataType(), rgbToYuvCudaLauncher, + (blocksPerGrid, threadsPerBlock, context->getCudaStream(), + input.specialBuffer(), input.specialShapeInfo(), packX.platformOffsets(), + output.specialBuffer(), output.specialShapeInfo(), + packZ.platformOffsets(), numOfTads, dimC), + FLOAT_TYPES); + NDArray::registerSpecialUse({&output}, {&input}); + + manager.synchronize(); } /////////////////////////////////////////////////////////////////// -template -__global__ void yuvToRgbCuda(const void* vx, const Nd4jLong* xShapeInfo, const Nd4jLong* xTadOffsets, void* vz, const Nd4jLong *zShapeInfo, const Nd4jLong* zTadOffsets, const Nd4jLong numOfTads, const int dimC) { - - const T* x = reinterpret_cast(vx); - T* z = reinterpret_cast(vz); - - __shared__ int rank; - __shared__ Nd4jLong xDimCstride, zDimCstride; - - if (threadIdx.x == 0) { - rank = shape::rank(xShapeInfo); - xDimCstride = shape::stride(xShapeInfo)[dimC]; - zDimCstride = shape::stride(zShapeInfo)[dimC]; - } - __syncthreads(); - - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - - for (Nd4jLong i = tid; i < numOfTads; i += gridDim.x * blockDim.x) { - const T* xTad = x + xTadOffsets[i]; - T* zTad = z + zTadOffsets[i]; - - yuvRgb(xTad[0], xTad[xDimCstride], xTad[2 * xDimCstride], zTad[0], zTad[zDimCstride], zTad[2 * zDimCstride]); - } - +template +__global__ void yuvToRgbCuda(const void* vx, const Nd4jLong* xShapeInfo, + const Nd4jLong* xTadOffsets, void* vz, + const Nd4jLong* zShapeInfo, + const Nd4jLong* zTadOffsets, + const Nd4jLong numOfTads, const int dimC) { + const T* x = reinterpret_cast(vx); + T* z = reinterpret_cast(vz); + + __shared__ int rank; + __shared__ Nd4jLong xDimCstride, zDimCstride; + + if (threadIdx.x == 0) { + rank = shape::rank(xShapeInfo); + xDimCstride = shape::stride(xShapeInfo)[dimC]; + zDimCstride = shape::stride(zShapeInfo)[dimC]; + } + __syncthreads(); + + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + + for (Nd4jLong i = tid; i < numOfTads; i += gridDim.x * blockDim.x) { + const T* xTad = x + xTadOffsets[i]; + T* zTad = z + zTadOffsets[i]; + + yuvRgb(xTad[0], xTad[xDimCstride], xTad[2 * xDimCstride], zTad[0], + zTad[zDimCstride], zTad[2 * zDimCstride]); + } } /////////////////////////////////////////////////////////////////// -template -linkage void yuvToRgbCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo, const Nd4jLong* xTadOffsets, void* vz, const Nd4jLong* zShapeInfo, const Nd4jLong* zTadOffsets, const Nd4jLong numOfTads, const int dimC) { - - yuvToRgbCuda << > > (vx, xShapeInfo, xTadOffsets, vz, zShapeInfo, zTadOffsets, numOfTads, dimC); +template +linkage void yuvToRgbCudaLauncher( + const int blocksPerGrid, const int threadsPerBlock, + const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo, + const Nd4jLong* xTadOffsets, void* vz, const Nd4jLong* zShapeInfo, + const Nd4jLong* zTadOffsets, const Nd4jLong numOfTads, const int dimC) { + yuvToRgbCuda<<>>( + vx, xShapeInfo, xTadOffsets, vz, zShapeInfo, zTadOffsets, numOfTads, + dimC); } /////////////////////////////////////////////////////////////////// -void transformYuvRgb(sd::LaunchContext* context, const NDArray& input, NDArray& output, const int dimC) { - - auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(input.shapeInfo(), { dimC }); - auto packZ = sd::ConstantTadHelper::getInstance().tadForDimensions(output.shapeInfo(), { dimC }); - - const Nd4jLong numOfTads = packX.numberOfTads(); - - const int threadsPerBlock = MAX_NUM_THREADS / 2; - const int blocksPerGrid = (numOfTads + threadsPerBlock - 1) / threadsPerBlock; - - PointersManager manager(context, "yuv_to_rgb"); - - NDArray::prepareSpecialUse({ &output }, { &input }); - BUILD_SINGLE_SELECTOR(input.dataType(), yuvToRgbCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), input.specialBuffer(), input.specialShapeInfo(), packX.platformOffsets(), output.specialBuffer(), output.specialShapeInfo(), packZ.platformOffsets(), numOfTads, dimC), FLOAT_TYPES); - NDArray::registerSpecialUse({ &output }, { &input }); - - manager.synchronize(); +void transformYuvRgb(sd::LaunchContext* context, const NDArray& input, + NDArray& output, const int dimC) { + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions( + input.shapeInfo(), {dimC}); + auto packZ = sd::ConstantTadHelper::getInstance().tadForDimensions( + output.shapeInfo(), {dimC}); + + const Nd4jLong numOfTads = packX.numberOfTads(); + + const int threadsPerBlock = MAX_NUM_THREADS / 2; + const int blocksPerGrid = (numOfTads + threadsPerBlock - 1) / threadsPerBlock; + + PointersManager manager(context, "yuv_to_rgb"); + + NDArray::prepareSpecialUse({&output}, {&input}); + BUILD_SINGLE_SELECTOR( + input.dataType(), yuvToRgbCudaLauncher, + (blocksPerGrid, threadsPerBlock, context->getCudaStream(), + input.specialBuffer(), input.specialShapeInfo(), packX.platformOffsets(), + output.specialBuffer(), output.specialShapeInfo(), + packZ.platformOffsets(), numOfTads, dimC), + FLOAT_TYPES); + NDArray::registerSpecialUse({&output}, {&input}); + + manager.synchronize(); } /////////////////////////////////////////////////////////////////// // for example xShapeInfo = {2,3,4}, zShapeInfo = {2,1,4} -template -__global__ void rgbToGrsCuda(const void *vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo, const int dimC) { - - const auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - - __shared__ Nd4jLong zLen; - __shared__ int rank, *sharedMem; // xRank == zRank - - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - sharedMem = reinterpret_cast(shmem); - - zLen = shape::length(zShapeInfo); - rank = shape::rank(zShapeInfo); - } - __syncthreads(); - - auto coords = sharedMem + threadIdx.x * rank; - - for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < zLen; i += gridDim.x * blockDim.x) { - - if (dimC == (rank - 1) && 'c' == shape::order(xShapeInfo) && 1 == shape::elementWiseStride(xShapeInfo) && 'c' == shape::order(zShapeInfo) && 1 == shape::elementWiseStride(zShapeInfo)) { - const auto xStep = i*3; - z[i] = 0.2989f * x[xStep] + 0.5870f * x[xStep + 1] + 0.1140f * x[xStep + 2]; - } - else { - - shape::index2coords(i, zShapeInfo, coords); +template +__global__ void rgbToGrsCuda(const void* vx, const Nd4jLong* xShapeInfo, + void* vz, const Nd4jLong* zShapeInfo, + const int dimC) { + const auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + + __shared__ Nd4jLong zLen; + __shared__ int rank, *sharedMem; // xRank == zRank + + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sharedMem = reinterpret_cast(shmem); + + zLen = shape::length(zShapeInfo); + rank = shape::rank(zShapeInfo); + } + __syncthreads(); + + auto coords = sharedMem + threadIdx.x * rank; + + for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < zLen; + i += gridDim.x * blockDim.x) { + if (dimC == (rank - 1) && 'c' == shape::order(xShapeInfo) && + 1 == shape::elementWiseStride(xShapeInfo) && + 'c' == shape::order(zShapeInfo) && + 1 == shape::elementWiseStride(zShapeInfo)) { + const auto xStep = i * 3; + z[i] = + 0.2989f * x[xStep] + 0.5870f * x[xStep + 1] + 0.1140f * x[xStep + 2]; + } else { + shape::index2coords(i, zShapeInfo, coords); - const auto zOffset = shape::getOffset(zShapeInfo, coords); - const auto xOffset0 = shape::getOffset(xShapeInfo, coords); - const auto xOffset1 = xOffset0 + shape::stride(xShapeInfo)[dimC]; - const auto xOffset2 = xOffset1 + shape::stride(xShapeInfo)[dimC]; + const auto zOffset = shape::getOffset(zShapeInfo, coords); + const auto xOffset0 = shape::getOffset(xShapeInfo, coords); + const auto xOffset1 = xOffset0 + shape::stride(xShapeInfo)[dimC]; + const auto xOffset2 = xOffset1 + shape::stride(xShapeInfo)[dimC]; - z[zOffset] = 0.2989f * x[xOffset0] + 0.5870f * x[xOffset1] + 0.1140f * x[xOffset2]; - } - } + z[zOffset] = + 0.2989f * x[xOffset0] + 0.5870f * x[xOffset1] + 0.1140f * x[xOffset2]; + } + } } /////////////////////////////////////////////////////////////////// -template -linkage void rgbToGrsCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo, const int dimC) { - - rgbToGrsCuda<<>>(vx, xShapeInfo, vz, zShapeInfo, dimC); +template +linkage void rgbToGrsCudaLauncher(const int blocksPerGrid, + const int threadsPerBlock, + const int sharedMem, + const cudaStream_t* stream, const void* vx, + const Nd4jLong* xShapeInfo, void* vz, + const Nd4jLong* zShapeInfo, const int dimC) { + rgbToGrsCuda<<>>( + vx, xShapeInfo, vz, zShapeInfo, dimC); } /////////////////////////////////////////////////////////////////// -void transformRgbGrs(sd::LaunchContext* context, const NDArray& input, NDArray& output, const int dimC) { - - PointersManager manager(context, "rgbToGrs"); - - const int threadsPerBlock = MAX_NUM_THREADS / 4; - const int blocksPerGrid = (input.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - const int sharedMem = input.rankOf() * sizeof(int) * threadsPerBlock + 128; - - NDArray::prepareSpecialUse({&output}, {&input}); - BUILD_SINGLE_SELECTOR(input.dataType(), rgbToGrsCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), input.specialBuffer(), input.specialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), dimC), NUMERIC_TYPES); - NDArray::registerSpecialUse({&output}, {&input}); - - manager.synchronize(); +void transformRgbGrs(sd::LaunchContext* context, const NDArray& input, + NDArray& output, const int dimC) { + PointersManager manager(context, "rgbToGrs"); + + const int threadsPerBlock = MAX_NUM_THREADS / 4; + const int blocksPerGrid = + (input.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + const int sharedMem = input.rankOf() * sizeof(int) * threadsPerBlock + 128; + + NDArray::prepareSpecialUse({&output}, {&input}); + BUILD_SINGLE_SELECTOR( + input.dataType(), rgbToGrsCudaLauncher, + (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), + input.specialBuffer(), input.specialShapeInfo(), output.specialBuffer(), + output.specialShapeInfo(), dimC), + NUMERIC_TYPES); + NDArray::registerSpecialUse({&output}, {&input}); + + manager.synchronize(); } - /////////////////////////////////////////////////////////////////// template -static void _CUDA_G rgbToHsvCuda(const void* vx, const Nd4jLong* xShapeInfo, const Nd4jLong* xTadOffsets, - void* vz, const Nd4jLong *zShapeInfo, const Nd4jLong* zTadOffsets, - const Nd4jLong numOfTads, const int dimC) { - - const T* x = reinterpret_cast(vx); - T* z = reinterpret_cast(vz); +static void _CUDA_G rgbToHsvCuda(const void* vx, const Nd4jLong* xShapeInfo, + const Nd4jLong* xTadOffsets, void* vz, + const Nd4jLong* zShapeInfo, + const Nd4jLong* zTadOffsets, + const Nd4jLong numOfTads, const int dimC) { + const T* x = reinterpret_cast(vx); + T* z = reinterpret_cast(vz); - __shared__ int rank; - __shared__ Nd4jLong xDimCstride, zDimCstride; + __shared__ int rank; + __shared__ Nd4jLong xDimCstride, zDimCstride; - if (threadIdx.x == 0) { - rank = shape::rank(xShapeInfo); - xDimCstride = shape::stride(xShapeInfo)[dimC]; - zDimCstride = shape::stride(zShapeInfo)[dimC]; - } - __syncthreads(); + if (threadIdx.x == 0) { + rank = shape::rank(xShapeInfo); + xDimCstride = shape::stride(xShapeInfo)[dimC]; + zDimCstride = shape::stride(zShapeInfo)[dimC]; + } + __syncthreads(); - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - for (Nd4jLong i = tid; i < numOfTads; i += gridDim.x * blockDim.x) { - const T* xTad = x + xTadOffsets[i]; - T* zTad = z + zTadOffsets[i]; + for (Nd4jLong i = tid; i < numOfTads; i += gridDim.x * blockDim.x) { + const T* xTad = x + xTadOffsets[i]; + T* zTad = z + zTadOffsets[i]; - rgbToHsv(xTad[0], xTad[xDimCstride], xTad[2 * xDimCstride], zTad[0], zTad[zDimCstride], zTad[2 * zDimCstride]); - } + rgbToHsv(xTad[0], xTad[xDimCstride], xTad[2 * xDimCstride], zTad[0], + zTad[zDimCstride], zTad[2 * zDimCstride]); + } } /////////////////////////////////////////////////////////////////// template -static void _CUDA_G hsvToRgbCuda(const void* vx, const Nd4jLong* xShapeInfo, const Nd4jLong* xTadOffsets, - void* vz, const Nd4jLong *zShapeInfo, const Nd4jLong* zTadOffsets, +static void _CUDA_G hsvToRgbCuda(const void* vx, const Nd4jLong* xShapeInfo, + const Nd4jLong* xTadOffsets, void* vz, + const Nd4jLong* zShapeInfo, + const Nd4jLong* zTadOffsets, const Nd4jLong numOfTads, const int dimC) { + const T* x = reinterpret_cast(vx); + T* z = reinterpret_cast(vz); - const T* x = reinterpret_cast(vx); - T* z = reinterpret_cast(vz); + __shared__ int rank; + __shared__ Nd4jLong xDimCstride, zDimCstride; - __shared__ int rank; - __shared__ Nd4jLong xDimCstride, zDimCstride; + if (threadIdx.x == 0) { + rank = shape::rank(xShapeInfo); + xDimCstride = shape::stride(xShapeInfo)[dimC]; + zDimCstride = shape::stride(zShapeInfo)[dimC]; + } + __syncthreads(); - if (threadIdx.x == 0) { - rank = shape::rank(xShapeInfo); - xDimCstride = shape::stride(xShapeInfo)[dimC]; - zDimCstride = shape::stride(zShapeInfo)[dimC]; - } - __syncthreads(); - - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - for (Nd4jLong i = tid; i < numOfTads; i += gridDim.x * blockDim.x) { - const T* xTad = x + xTadOffsets[i]; - T* zTad = z + zTadOffsets[i]; + for (Nd4jLong i = tid; i < numOfTads; i += gridDim.x * blockDim.x) { + const T* xTad = x + xTadOffsets[i]; + T* zTad = z + zTadOffsets[i]; - hsvToRgb(xTad[0], xTad[xDimCstride], xTad[2 * xDimCstride], zTad[0], zTad[zDimCstride], zTad[2 * zDimCstride]); - } + hsvToRgb(xTad[0], xTad[xDimCstride], xTad[2 * xDimCstride], zTad[0], + zTad[zDimCstride], zTad[2 * zDimCstride]); + } } /////////////////////////////////////////////////////////////////// -template -static _CUDA_H void hsvToRgbCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, - const void* vx, const Nd4jLong* xShapeInfo, const Nd4jLong* xTadOffsets, - void* vz, const Nd4jLong* zShapeInfo, const Nd4jLong* zTadOffsets, - const Nd4jLong numOfTads, const int dimC) { - - hsvToRgbCuda<<>>(vx, xShapeInfo, xTadOffsets, vz, zShapeInfo, zTadOffsets, numOfTads, dimC); +template +static _CUDA_H void hsvToRgbCudaLauncher( + const int blocksPerGrid, const int threadsPerBlock, + const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo, + const Nd4jLong* xTadOffsets, void* vz, const Nd4jLong* zShapeInfo, + const Nd4jLong* zTadOffsets, const Nd4jLong numOfTads, const int dimC) { + hsvToRgbCuda<<>>( + vx, xShapeInfo, xTadOffsets, vz, zShapeInfo, zTadOffsets, numOfTads, + dimC); } -template -static _CUDA_H void rgbToHsvCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, - const void* vx, const Nd4jLong* xShapeInfo, const Nd4jLong* xTadOffsets, - void* vz, const Nd4jLong* zShapeInfo, const Nd4jLong* zTadOffsets, - const Nd4jLong numOfTads, const int dimC) { - - rgbToHsvCuda<<>>(vx, xShapeInfo, xTadOffsets, vz, zShapeInfo, zTadOffsets, numOfTads, dimC); +template +static _CUDA_H void rgbToHsvCudaLauncher( + const int blocksPerGrid, const int threadsPerBlock, + const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo, + const Nd4jLong* xTadOffsets, void* vz, const Nd4jLong* zShapeInfo, + const Nd4jLong* zTadOffsets, const Nd4jLong numOfTads, const int dimC) { + rgbToHsvCuda<<>>( + vx, xShapeInfo, xTadOffsets, vz, zShapeInfo, zTadOffsets, numOfTads, + dimC); } /////////////////////////////////////////////////////////////////// -void transformHsvRgb(sd::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC) { - - auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(input->shapeInfo(), {dimC}); - auto packZ = sd::ConstantTadHelper::getInstance().tadForDimensions(output->shapeInfo(), {dimC}); - - const Nd4jLong numOfTads = packX.numberOfTads(); - - const int threadsPerBlock = MAX_NUM_THREADS / 2; - const int blocksPerGrid = (numOfTads + threadsPerBlock - 1) / threadsPerBlock; - - PointersManager manager(context, "hsv_to_rgb"); - - NDArray::prepareSpecialUse({output}, {input}); - BUILD_SINGLE_SELECTOR(input->dataType(), hsvToRgbCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), input->specialBuffer(), input->specialShapeInfo(), packX.platformOffsets(), output->specialBuffer(), output->specialShapeInfo(), packZ.platformOffsets(), numOfTads, dimC), FLOAT_TYPES); - NDArray::registerSpecialUse({output}, {input}); - - manager.synchronize(); +void transformHsvRgb(sd::LaunchContext* context, const NDArray* input, + NDArray* output, const int dimC) { + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions( + input->shapeInfo(), {dimC}); + auto packZ = sd::ConstantTadHelper::getInstance().tadForDimensions( + output->shapeInfo(), {dimC}); + + const Nd4jLong numOfTads = packX.numberOfTads(); + + const int threadsPerBlock = MAX_NUM_THREADS / 2; + const int blocksPerGrid = (numOfTads + threadsPerBlock - 1) / threadsPerBlock; + + PointersManager manager(context, "hsv_to_rgb"); + + NDArray::prepareSpecialUse({output}, {input}); + BUILD_SINGLE_SELECTOR( + input->dataType(), hsvToRgbCudaLauncher, + (blocksPerGrid, threadsPerBlock, context->getCudaStream(), + input->specialBuffer(), input->specialShapeInfo(), + packX.platformOffsets(), output->specialBuffer(), + output->specialShapeInfo(), packZ.platformOffsets(), numOfTads, dimC), + FLOAT_TYPES); + NDArray::registerSpecialUse({output}, {input}); + + manager.synchronize(); } /////////////////////////////////////////////////////////////////// -void transformRgbHsv(sd::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC) { - auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(input->shapeInfo(), {dimC}); - auto packZ = sd::ConstantTadHelper::getInstance().tadForDimensions(output->shapeInfo(), {dimC}); - - const Nd4jLong numOfTads = packX.numberOfTads(); - - const int threadsPerBlock = MAX_NUM_THREADS / 2; - const int blocksPerGrid = (numOfTads + threadsPerBlock - 1) / threadsPerBlock; - - PointersManager manager(context, "rgb_to_hsv"); - - NDArray::prepareSpecialUse({output}, {input}); - BUILD_SINGLE_SELECTOR(input->dataType(), rgbToHsvCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), input->specialBuffer(), input->specialShapeInfo(), packX.platformOffsets(), output->specialBuffer(), output->specialShapeInfo(), packZ.platformOffsets(), numOfTads, dimC), FLOAT_TYPES); - NDArray::registerSpecialUse({output}, {input}); - - manager.synchronize(); +void transformRgbHsv(sd::LaunchContext* context, const NDArray* input, + NDArray* output, const int dimC) { + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions( + input->shapeInfo(), {dimC}); + auto packZ = sd::ConstantTadHelper::getInstance().tadForDimensions( + output->shapeInfo(), {dimC}); + + const Nd4jLong numOfTads = packX.numberOfTads(); + + const int threadsPerBlock = MAX_NUM_THREADS / 2; + const int blocksPerGrid = (numOfTads + threadsPerBlock - 1) / threadsPerBlock; + + PointersManager manager(context, "rgb_to_hsv"); + + NDArray::prepareSpecialUse({output}, {input}); + BUILD_SINGLE_SELECTOR( + input->dataType(), rgbToHsvCudaLauncher, + (blocksPerGrid, threadsPerBlock, context->getCudaStream(), + input->specialBuffer(), input->specialShapeInfo(), + packX.platformOffsets(), output->specialBuffer(), + output->specialShapeInfo(), packZ.platformOffsets(), numOfTads, dimC), + FLOAT_TYPES); + NDArray::registerSpecialUse({output}, {input}); + + manager.synchronize(); } -template -__global__ void tripleTransformerCuda(const void *vx, const Nd4jLong *xShapeInfo, const Nd4jLong *xTadShapeInfo, const Nd4jLong *xOffsets, void *vz, const Nd4jLong *zShapeInfo, const Nd4jLong *zTadShapeInfo, const Nd4jLong *zOffsets, const int dimC, int mode, uint64_t numTads) { - const auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - - __shared__ Nd4jLong zLen, *sharedMem; - __shared__ int rank; // xRank == zRank - - float yiqarr[3][3] = { - { 0.299f, 0.59590059f, 0.2115f }, - { 0.587f, -0.27455667f, -0.52273617f }, - { 0.114f, -0.32134392f, 0.31119955f } - }; - - float rgbarr[3][3] = { - { 1.f, 1.f, 1.f }, - { 0.95598634f, -0.27201283f, -1.10674021f }, - { 0.6208248f, -0.64720424f, 1.70423049f } - }; - - auto tr = mode == 1? yiqarr : rgbarr; - - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - sharedMem = reinterpret_cast(shmem); - - zLen = shape::length(zShapeInfo); - rank = shape::rank(zShapeInfo); +template +__global__ void tripleTransformerCuda( + const void* vx, const Nd4jLong* xShapeInfo, const Nd4jLong* xTadShapeInfo, + const Nd4jLong* xOffsets, void* vz, const Nd4jLong* zShapeInfo, + const Nd4jLong* zTadShapeInfo, const Nd4jLong* zOffsets, const int dimC, + int mode, uint64_t numTads) { + const auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + + __shared__ Nd4jLong zLen, *sharedMem; + __shared__ int rank; // xRank == zRank + + float yiqarr[3][3] = {{0.299f, 0.59590059f, 0.2115f}, + {0.587f, -0.27455667f, -0.52273617f}, + {0.114f, -0.32134392f, 0.31119955f}}; + + float rgbarr[3][3] = {{1.f, 1.f, 1.f}, + {0.95598634f, -0.27201283f, -1.10674021f}, + {0.6208248f, -0.64720424f, 1.70423049f}}; + + auto tr = mode == 1 ? yiqarr : rgbarr; + + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sharedMem = reinterpret_cast(shmem); + + zLen = shape::length(zShapeInfo); + rank = shape::rank(zShapeInfo); + } + __syncthreads(); + + Nd4jLong* coords = sharedMem + threadIdx.x * rank; + + if (dimC == (rank - 1) && 'c' == shape::order(xShapeInfo) && + 1 == shape::elementWiseStride(xShapeInfo) && + 'c' == shape::order(zShapeInfo) && + 1 == shape::elementWiseStride(zShapeInfo)) { + for (uint64_t f = blockIdx.x * blockDim.x + threadIdx.x; f < zLen / 3; + f += gridDim.x * blockDim.x) { + auto i = f * 3; + + auto xi0 = x[i]; + auto xi1 = x[i + 1]; + auto xi2 = x[i + 2]; + + for (int e = 0; e < 3; e++) + z[i + e] = xi0 * tr[0][e] + xi1 * tr[1][e] + xi2 * tr[2][e]; } - __syncthreads(); - - Nd4jLong* coords = sharedMem + threadIdx.x * rank; - - if (dimC == (rank - 1) && 'c' == shape::order(xShapeInfo) && 1 == shape::elementWiseStride(xShapeInfo) && 'c' == shape::order(zShapeInfo) && 1 == shape::elementWiseStride(zShapeInfo)) { - for (uint64_t f = blockIdx.x * blockDim.x + threadIdx.x; f < zLen / 3; f += gridDim.x * blockDim.x) { - auto i = f * 3; - - auto xi0 = x[i]; - auto xi1 = x[i+1]; - auto xi2 = x[i+2]; - - for (int e = 0; e < 3; e++) - z[i + e] = xi0 * tr[0][e] + xi1 * tr[1][e] + xi2 * tr[2][e]; - } - } else { - // TAD based case - const Nd4jLong xDimCstride = shape::stride(xShapeInfo)[dimC]; - const Nd4jLong zDimCstride = shape::stride(zShapeInfo)[dimC]; - - for (uint64_t i = blockIdx.x * blockDim.x + threadIdx.x; i < numTads; i += blockDim.x * gridDim.x) { - const T* xTad = x + xOffsets[i]; - T* zTad = z + zOffsets[i]; - - auto xi0 = xTad[0]; - auto xi1 = xTad[xDimCstride]; - auto xi2 = xTad[xDimCstride * 2]; - - for (int e = 0; e < 3; e++) - zTad[zDimCstride * e] = xi0 * tr[0][e] + xi1 * tr[1][e] + xi2 * tr[2][e]; - } + } else { + // TAD based case + const Nd4jLong xDimCstride = shape::stride(xShapeInfo)[dimC]; + const Nd4jLong zDimCstride = shape::stride(zShapeInfo)[dimC]; + + for (uint64_t i = blockIdx.x * blockDim.x + threadIdx.x; i < numTads; + i += blockDim.x * gridDim.x) { + const T* xTad = x + xOffsets[i]; + T* zTad = z + zOffsets[i]; + + auto xi0 = xTad[0]; + auto xi1 = xTad[xDimCstride]; + auto xi2 = xTad[xDimCstride * 2]; + + for (int e = 0; e < 3; e++) + zTad[zDimCstride * e] = + xi0 * tr[0][e] + xi1 * tr[1][e] + xi2 * tr[2][e]; } + } } - template -static void rgbYiq(sd::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC) { - auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(input->shapeInfo(), dimC); - auto packZ = sd::ConstantTadHelper::getInstance().tadForDimensions(output->shapeInfo(), dimC); - - NDArray::prepareSpecialUse({output}, {input}); - return tripleTransformerCuda<<<256, 256, 8192, *context->getCudaStream()>>>(input->specialBuffer(), input->specialShapeInfo(), packX.platformShapeInfo(), packX.platformOffsets(), output->specialBuffer(), output->specialShapeInfo(), packZ.platformShapeInfo(), packZ.platformOffsets(), dimC, 1, packZ.numberOfTads()); - NDArray::registerSpecialUse({output}, {input}); +static void rgbYiq(sd::LaunchContext* context, const NDArray* input, + NDArray* output, const int dimC) { + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions( + input->shapeInfo(), dimC); + auto packZ = sd::ConstantTadHelper::getInstance().tadForDimensions( + output->shapeInfo(), dimC); + + NDArray::prepareSpecialUse({output}, {input}); + return tripleTransformerCuda + <<<256, 256, 8192, *context->getCudaStream()>>>( + input->specialBuffer(), input->specialShapeInfo(), + packX.platformShapeInfo(), packX.platformOffsets(), + output->specialBuffer(), output->specialShapeInfo(), + packZ.platformShapeInfo(), packZ.platformOffsets(), dimC, 1, + packZ.numberOfTads()); + NDArray::registerSpecialUse({output}, {input}); } template -FORCEINLINE static void yiqRgb(sd::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC) { - auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(input->shapeInfo(), dimC); - auto packZ = sd::ConstantTadHelper::getInstance().tadForDimensions(output->shapeInfo(), dimC); - - NDArray::prepareSpecialUse({output}, {input}); - return tripleTransformerCuda<<<256, 256, 8192, *context->getCudaStream()>>>(input->specialBuffer(), input->specialShapeInfo(), packX.platformShapeInfo(), packX.platformOffsets(), output->specialBuffer(), output->specialShapeInfo(), packZ.platformShapeInfo(), packZ.platformOffsets(), dimC, 2, packZ.numberOfTads()); - NDArray::registerSpecialUse({output}, {input}); -} - -void transformYiqRgb(sd::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC) { - BUILD_SINGLE_SELECTOR(input->dataType(), yiqRgb, (context, input, output, dimC), FLOAT_TYPES); +FORCEINLINE static void yiqRgb(sd::LaunchContext* context, const NDArray* input, + NDArray* output, const int dimC) { + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions( + input->shapeInfo(), dimC); + auto packZ = sd::ConstantTadHelper::getInstance().tadForDimensions( + output->shapeInfo(), dimC); + + NDArray::prepareSpecialUse({output}, {input}); + return tripleTransformerCuda + <<<256, 256, 8192, *context->getCudaStream()>>>( + input->specialBuffer(), input->specialShapeInfo(), + packX.platformShapeInfo(), packX.platformOffsets(), + output->specialBuffer(), output->specialShapeInfo(), + packZ.platformShapeInfo(), packZ.platformOffsets(), dimC, 2, + packZ.numberOfTads()); + NDArray::registerSpecialUse({output}, {input}); } -void transformRgbYiq(sd::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC) { - BUILD_SINGLE_SELECTOR(input->dataType(), rgbYiq, (context, input, output, dimC), FLOAT_TYPES); +void transformYiqRgb(sd::LaunchContext* context, const NDArray* input, + NDArray* output, const int dimC) { + BUILD_SINGLE_SELECTOR(input->dataType(), yiqRgb, + (context, input, output, dimC), FLOAT_TYPES); } - - - - -} -} +void transformRgbYiq(sd::LaunchContext* context, const NDArray* input, + NDArray* output, const int dimC) { + BUILD_SINGLE_SELECTOR(input->dataType(), rgbYiq, + (context, input, output, dimC), FLOAT_TYPES); } +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/ismax.cu b/libnd4j/include/ops/declarable/helpers/cuda/ismax.cu index f6e233aab481..946508124a88 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/ismax.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/ismax.cu @@ -19,74 +19,91 @@ // @author raver119@gmail.com // - -#include -#include -#include -#include #include -#include #include +#include +#include +#include +#include +#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { template -static void ismax_(sd::LaunchContext * context, const NDArray* input, NDArray* output, const std::vector& dimensions) { - auto stream = context->getCudaStream(); - - auto xRank = input->rankOf(); - auto zRank = output->rankOf(); - auto xType = input->dataType(); - auto zType = output->dataType(); - input->syncToDevice(); - Nd4jLong* special = nullptr; - PointersManager manager(context, "IsMaxHelper"); - if (dimensions.size() == 0) { - /** - * In case of vector-input for IsMax, it just turns into IndexReduce call + subsequent filler call - */ - auto indexMax = input->applyIndexReduce(indexreduce::IndexMax, dimensions); - auto targetIdx = indexMax.e(0); - - dim3 launchDims(128, 512, 1024); - BUILD_SINGLE_SELECTOR(zType, fillIsMaxGeneric, (launchDims, stream, output->specialBuffer(), output->specialShapeInfo(), output->lengthOf(), targetIdx), LIBND4J_TYPES); - manager.synchronize(); - - } else { - Nd4jLong* hostYShapeInfo = nullptr; - Nd4jLong* hostTShapeInfo = nullptr; - int* dimension = nullptr; - int dimensionLength = dimensions.size(); - std::vector copy(dimensions); - - auto packZ = sd::ConstantTadHelper::getInstance().tadForDimensions(output->shapeInfo(), copy.data(), copy.size()); - - // we launch legacy IndexMax op, to get indices of max values along dimension - auto indexMaxArr = input->applyIndexReduce(indexreduce::IndexMax, dimensions); - - dim3 launchDims(256, 256, 16384); - dimension = (int *) manager.replicatePointer(dimensions.data(), dimensions.size() * sizeof(int)); - - // at this point, all IMax indexes are gathered, and we execute filler - BUILD_SINGLE_SELECTOR(zType, fillDimensionalIsMaxGeneric, (launchDims, stream, indexMaxArr.specialBuffer(), output->specialBuffer(), output->specialShapeInfo(), packZ.specialShapeInfo(), dimension, dimensionLength, packZ.specialOffsets()), LIBND4J_TYPES); - manager.synchronize(); - } +static void ismax_(sd::LaunchContext* context, const NDArray* input, + NDArray* output, const std::vector& dimensions) { + auto stream = context->getCudaStream(); + + auto xRank = input->rankOf(); + auto zRank = output->rankOf(); + auto xType = input->dataType(); + auto zType = output->dataType(); + input->syncToDevice(); + Nd4jLong* special = nullptr; + PointersManager manager(context, "IsMaxHelper"); + if (dimensions.size() == 0) { + /** + * In case of vector-input for IsMax, it just turns into IndexReduce call + + * subsequent filler call + */ + auto indexMax = input->applyIndexReduce(indexreduce::IndexMax, dimensions); + auto targetIdx = indexMax.e(0); + + dim3 launchDims(128, 512, 1024); + BUILD_SINGLE_SELECTOR( + zType, fillIsMaxGeneric, + (launchDims, stream, output->specialBuffer(), + output->specialShapeInfo(), output->lengthOf(), targetIdx), + LIBND4J_TYPES); + manager.synchronize(); + + } else { + Nd4jLong* hostYShapeInfo = nullptr; + Nd4jLong* hostTShapeInfo = nullptr; + int* dimension = nullptr; + int dimensionLength = dimensions.size(); + std::vector copy(dimensions); + + auto packZ = sd::ConstantTadHelper::getInstance().tadForDimensions( + output->shapeInfo(), copy.data(), copy.size()); + + // we launch legacy IndexMax op, to get indices of max values along + // dimension + auto indexMaxArr = + input->applyIndexReduce(indexreduce::IndexMax, dimensions); + + dim3 launchDims(256, 256, 16384); + dimension = (int*)manager.replicatePointer(dimensions.data(), + dimensions.size() * sizeof(int)); + + // at this point, all IMax indexes are gathered, and we execute filler + BUILD_SINGLE_SELECTOR(zType, fillDimensionalIsMaxGeneric, + (launchDims, stream, indexMaxArr.specialBuffer(), + output->specialBuffer(), output->specialShapeInfo(), + packZ.specialShapeInfo(), dimension, dimensionLength, + packZ.specialOffsets()), + LIBND4J_TYPES); + manager.synchronize(); + } } +void ismax(sd::LaunchContext* context, const NDArray* input, NDArray* output, + const std::vector& dimensions) { + NDArray::prepareSpecialUse({output}, {input}); -void ismax(sd::LaunchContext * context, const NDArray *input, NDArray *output, const std::vector& dimensions) { - NDArray::prepareSpecialUse({output}, {input}); - - BUILD_SINGLE_SELECTOR(input->dataType(), ismax_, (context, input, output, dimensions), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR(input->dataType(), ismax_, + (context, input, output, dimensions), LIBND4J_TYPES); - NDArray::registerSpecialUse({output}, {input}); + NDArray::registerSpecialUse({output}, {input}); } -BUILD_SINGLE_TEMPLATE(template void ismax_, (sd::LaunchContext * context, const NDArray *input, NDArray *output, const std::vector& dimensions), LIBND4J_TYPES); - -} -} -} +BUILD_SINGLE_TEMPLATE(template void ismax_, + (sd::LaunchContext * context, const NDArray* input, + NDArray* output, const std::vector& dimensions), + LIBND4J_TYPES); +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/legacy/relu.cu b/libnd4j/include/ops/declarable/helpers/cuda/legacy/relu.cu index ab65ed96b36d..2203f4d07c6b 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/legacy/relu.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/legacy/relu.cu @@ -18,98 +18,108 @@ // @author GS // +#include #include #include -#include #include namespace sd { - namespace ops { - namespace helpers { - - template - linkage void reluDerivative__(NDArray* theFirst, NDArray* theSecond) { - auto functor = LAMBDA_TT(x, y){ - return x > (T) 0.f ? y : T(0.f); - }; - - theFirst->applyPairwiseLambda(*theSecond, functor, *theFirst); - } - - void reluDerivative(sd::LaunchContext * context, NDArray* theFirst, NDArray* theSecond) { - BUILD_SINGLE_SELECTOR(theFirst->dataType(), reluDerivative__, (theFirst, theSecond), FLOAT_TYPES); - } - - template - linkage void reluDerivative_(NDArray* input, NDArray* epsilon, NDArray* output) { - auto functor = LAMBDA_TT(x, y){ - return x > (T)0.f ? y : T(0.f); - }; - - input->applyPairwiseLambda(*epsilon, functor, *output); - } - - void reluDerivative(sd::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { - BUILD_SINGLE_SELECTOR(theFirst->dataType(), reluDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); - } - - template - linkage void relu6Derivative_(NDArray* input, NDArray* epsilon, NDArray* output) { - auto functor = LAMBDA_TT(x, y){ - return x > (T)0.f && x < (T)6.f? y : T(0.f); - }; - - input->applyPairwiseLambda(*epsilon, functor, *output); - } - - void relu6Derivative(sd::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { - BUILD_SINGLE_SELECTOR(theFirst->dataType(), relu6Derivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); - } - - template - linkage void leakyReluDerivative_(NDArray* input, NDArray* epsilon, NDArray* output, const float alpha) { - - const T alphaT = static_cast(alpha); - - auto functor = LAMBDA_TT(x, y, alphaT) { - return x < 0 ? alphaT * y : y; - }; - - input->applyPairwiseLambda(*epsilon, functor, *output); - } - - void leakyReluDerivative(sd::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput, const float alpha) { - BUILD_SINGLE_SELECTOR(theFirst->dataType(), leakyReluDerivative_, (theFirst, theSecond, theOutput, alpha), FLOAT_TYPES); - } - - template - linkage void eluDerivative_(NDArray* input, NDArray* epsilon, NDArray* output, const float alpha) { - - const T alphaT = static_cast(alpha); - - auto functor = LAMBDA_TT(x, y, alphaT){ - return y * sd::math::nd4j_eluderivative(x, alphaT); - }; - - input->applyPairwiseLambda(*epsilon, functor, *output); - } - - void eluDerivative(sd::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput, const float alpha) { - BUILD_SINGLE_SELECTOR(theFirst->dataType(), eluDerivative_, (theFirst, theSecond, theOutput, alpha), FLOAT_TYPES); - } - - template - linkage void seluDerivative_(NDArray* input, NDArray* epsilon, NDArray* output) { - auto functor = LAMBDA_TT(x, y){ - return y * simdOps::SELUDerivative::op(x, nullptr); - }; - - input->applyPairwiseLambda(*epsilon, functor, *output); - } - - void seluDerivative(sd::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { - BUILD_SINGLE_SELECTOR(theFirst->dataType(), seluDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); - } - } - } -} \ No newline at end of file +namespace ops { +namespace helpers { + +template +linkage void reluDerivative__(NDArray* theFirst, NDArray* theSecond) { + auto functor = LAMBDA_TT(x, y) { return x > (T)0.f ? y : T(0.f); }; + + theFirst->applyPairwiseLambda(*theSecond, functor, *theFirst); +} + +void reluDerivative(sd::LaunchContext* context, NDArray* theFirst, + NDArray* theSecond) { + BUILD_SINGLE_SELECTOR(theFirst->dataType(), reluDerivative__, + (theFirst, theSecond), FLOAT_TYPES); +} + +template +linkage void reluDerivative_(NDArray* input, NDArray* epsilon, + NDArray* output) { + auto functor = LAMBDA_TT(x, y) { return x > (T)0.f ? y : T(0.f); }; + + input->applyPairwiseLambda(*epsilon, functor, *output); +} + +void reluDerivative(sd::LaunchContext* context, NDArray* theFirst, + NDArray* theSecond, NDArray* theOutput) { + BUILD_SINGLE_SELECTOR(theFirst->dataType(), reluDerivative_, + (theFirst, theSecond, theOutput), FLOAT_TYPES); +} + +template +linkage void relu6Derivative_(NDArray* input, NDArray* epsilon, + NDArray* output) { + auto functor = LAMBDA_TT(x, y) { + return x > (T)0.f && x < (T)6.f ? y : T(0.f); + }; + + input->applyPairwiseLambda(*epsilon, functor, *output); +} + +void relu6Derivative(sd::LaunchContext* context, NDArray* theFirst, + NDArray* theSecond, NDArray* theOutput) { + BUILD_SINGLE_SELECTOR(theFirst->dataType(), relu6Derivative_, + (theFirst, theSecond, theOutput), FLOAT_TYPES); +} + +template +linkage void leakyReluDerivative_(NDArray* input, NDArray* epsilon, + NDArray* output, const float alpha) { + const T alphaT = static_cast(alpha); + + auto functor = LAMBDA_TT(x, y, alphaT) { return x < 0 ? alphaT * y : y; }; + + input->applyPairwiseLambda(*epsilon, functor, *output); +} + +void leakyReluDerivative(sd::LaunchContext* context, NDArray* theFirst, + NDArray* theSecond, NDArray* theOutput, + const float alpha) { + BUILD_SINGLE_SELECTOR(theFirst->dataType(), leakyReluDerivative_, + (theFirst, theSecond, theOutput, alpha), FLOAT_TYPES); +} + +template +linkage void eluDerivative_(NDArray* input, NDArray* epsilon, NDArray* output, + const float alpha) { + const T alphaT = static_cast(alpha); + + auto functor = LAMBDA_TT(x, y, alphaT) { + return y * sd::math::nd4j_eluderivative(x, alphaT); + }; + + input->applyPairwiseLambda(*epsilon, functor, *output); +} + +void eluDerivative(sd::LaunchContext* context, NDArray* theFirst, + NDArray* theSecond, NDArray* theOutput, const float alpha) { + BUILD_SINGLE_SELECTOR(theFirst->dataType(), eluDerivative_, + (theFirst, theSecond, theOutput, alpha), FLOAT_TYPES); +} + +template +linkage void seluDerivative_(NDArray* input, NDArray* epsilon, + NDArray* output) { + auto functor = LAMBDA_TT(x, y) { + return y * simdOps::SELUDerivative::op(x, nullptr); + }; + + input->applyPairwiseLambda(*epsilon, functor, *output); +} + +void seluDerivative(sd::LaunchContext* context, NDArray* theFirst, + NDArray* theSecond, NDArray* theOutput) { + BUILD_SINGLE_SELECTOR(theFirst->dataType(), seluDerivative_, + (theFirst, theSecond, theOutput), FLOAT_TYPES); +} +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/legacy/tanh.cu b/libnd4j/include/ops/declarable/helpers/cuda/legacy/tanh.cu index 56a57614f4c2..b11e76be7ad2 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/legacy/tanh.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/legacy/tanh.cu @@ -18,69 +18,81 @@ // @author GS // -#include #include +#include #include #include namespace sd { - namespace ops { - namespace helpers { - //////////////////////////////////////////////////////////////////////// - template - linkage void tanhDerivative_(NDArray* input, NDArray* epsilon, NDArray* output) { - auto functor = LAMBDA_TT(x, y){ - T th = sd::math::nd4j_tanh(x); - return y * ((T)1.0f - (th * th)); - }; +namespace ops { +namespace helpers { +//////////////////////////////////////////////////////////////////////// +template +linkage void tanhDerivative_(NDArray* input, NDArray* epsilon, + NDArray* output) { + auto functor = LAMBDA_TT(x, y) { + T th = sd::math::nd4j_tanh(x); + return y * ((T)1.0f - (th * th)); + }; - input->applyPairwiseLambda(*epsilon, functor, *output); - } + input->applyPairwiseLambda(*epsilon, functor, *output); +} - void tanhDerivative(sd::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { - BUILD_SINGLE_SELECTOR(theFirst->dataType(), tanhDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); - } +void tanhDerivative(sd::LaunchContext* context, NDArray* theFirst, + NDArray* theSecond, NDArray* theOutput) { + BUILD_SINGLE_SELECTOR(theFirst->dataType(), tanhDerivative_, + (theFirst, theSecond, theOutput), FLOAT_TYPES); +} - // return static_cast(d2) * simdOps::HardTanhDerivative::op(d1, nullptr); - template - linkage void hardTanhDerivative_(NDArray* input, NDArray* epsilon, NDArray* output) { - auto functor = LAMBDA_TT(x, y){ - T th = sd::math::nd4j_tanh(x); - return y * simdOps::HardTanhDerivative::op(x, nullptr); - }; +// return static_cast(d2) * simdOps::HardTanhDerivative::op(d1, nullptr); +template +linkage void hardTanhDerivative_(NDArray* input, NDArray* epsilon, + NDArray* output) { + auto functor = LAMBDA_TT(x, y) { + T th = sd::math::nd4j_tanh(x); + return y * simdOps::HardTanhDerivative::op(x, nullptr); + }; - input->applyPairwiseLambda(*epsilon, functor, *output); - } + input->applyPairwiseLambda(*epsilon, functor, *output); +} - void hardTanhDerivative(sd::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { - BUILD_SINGLE_SELECTOR(theFirst->dataType(), hardTanhDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); - } +void hardTanhDerivative(sd::LaunchContext* context, NDArray* theFirst, + NDArray* theSecond, NDArray* theOutput) { + BUILD_SINGLE_SELECTOR(theFirst->dataType(), hardTanhDerivative_, + (theFirst, theSecond, theOutput), FLOAT_TYPES); +} - template - linkage void rationalTanhDerivative_(NDArray* input, NDArray* epsilon, NDArray* output) { - auto functor = LAMBDA_TT(x, y){ - return y * simdOps::RationalTanhDerivative::op(x, nullptr); - }; +template +linkage void rationalTanhDerivative_(NDArray* input, NDArray* epsilon, + NDArray* output) { + auto functor = LAMBDA_TT(x, y) { + return y * simdOps::RationalTanhDerivative::op(x, nullptr); + }; - input->applyPairwiseLambda(*epsilon, functor, *output); - } + input->applyPairwiseLambda(*epsilon, functor, *output); +} - void rationalTanhDerivative(sd::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { - BUILD_SINGLE_SELECTOR(theFirst->dataType(), rationalTanhDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); - } +void rationalTanhDerivative(sd::LaunchContext* context, NDArray* theFirst, + NDArray* theSecond, NDArray* theOutput) { + BUILD_SINGLE_SELECTOR(theFirst->dataType(), rationalTanhDerivative_, + (theFirst, theSecond, theOutput), FLOAT_TYPES); +} - template - linkage void rectifiedTanhDerivative_(NDArray* input, NDArray* epsilon, NDArray* output) { - auto functor = LAMBDA_TT(x, y){ - return x > (T) 0.0f ? y * (sd::math::nd4j_tanhderivative(x)) : (T) 0.0f; - }; +template +linkage void rectifiedTanhDerivative_(NDArray* input, NDArray* epsilon, + NDArray* output) { + auto functor = LAMBDA_TT(x, y) { + return x > (T)0.0f ? y * (sd::math::nd4j_tanhderivative(x)) : (T)0.0f; + }; - input->applyPairwiseLambda(*epsilon, functor, *output); - } + input->applyPairwiseLambda(*epsilon, functor, *output); +} - void rectifiedTanhDerivative(sd::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { - BUILD_SINGLE_SELECTOR(theFirst->dataType(), rectifiedTanhDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); - } - } - } -} \ No newline at end of file +void rectifiedTanhDerivative(sd::LaunchContext* context, NDArray* theFirst, + NDArray* theSecond, NDArray* theOutput) { + BUILD_SINGLE_SELECTOR(theFirst->dataType(), rectifiedTanhDerivative_, + (theFirst, theSecond, theOutput), FLOAT_TYPES); +} +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/legacy_helper.cu b/libnd4j/include/ops/declarable/helpers/cuda/legacy_helper.cu index 181050919e06..b729ca137d5e 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/legacy_helper.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/legacy_helper.cu @@ -18,232 +18,271 @@ // @author GS // -#include #include -#include +#include #include +#include namespace sd { namespace ops { namespace helpers { //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - template - linkage void cubeDerivative_(NDArray* input, NDArray* epsilon, NDArray* output) { - auto functor = LAMBDA_TT(x, y){ - return y * (3 * x * x); - }; +template +linkage void cubeDerivative_(NDArray* input, NDArray* epsilon, + NDArray* output) { + auto functor = LAMBDA_TT(x, y) { return y * (3 * x * x); }; - input->applyPairwiseLambda(*epsilon, functor, *output); - } + input->applyPairwiseLambda(*epsilon, functor, *output); +} //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - void cubeDerivative(sd::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { - BUILD_SINGLE_SELECTOR(theFirst->dataType(), cubeDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); - } +void cubeDerivative(sd::LaunchContext* context, NDArray* theFirst, + NDArray* theSecond, NDArray* theOutput) { + BUILD_SINGLE_SELECTOR(theFirst->dataType(), cubeDerivative_, + (theFirst, theSecond, theOutput), FLOAT_TYPES); +} //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - //return (x >= X(0.f) ? y: -y); - template - linkage void reduceNorm1_(NDArray* input, NDArray* epsilon, NDArray* output) { - auto functor = LAMBDA_TT(x, y){ - return x > T(0.f)? y : -y; - }; +// return (x >= X(0.f) ? y: -y); +template +linkage void reduceNorm1_(NDArray* input, NDArray* epsilon, NDArray* output) { + auto functor = LAMBDA_TT(x, y) { return x > T(0.f) ? y : -y; }; - input->applyPairwiseLambda(*epsilon, functor, *output); - } + input->applyPairwiseLambda(*epsilon, functor, *output); +} //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - void reduceNorm1(sd::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { - BUILD_SINGLE_SELECTOR(theFirst->dataType(), reduceNorm1_, (theFirst, theSecond, theOutput), FLOAT_TYPES); - } +void reduceNorm1(sd::LaunchContext* context, NDArray* theFirst, + NDArray* theSecond, NDArray* theOutput) { + BUILD_SINGLE_SELECTOR(theFirst->dataType(), reduceNorm1_, + (theFirst, theSecond, theOutput), FLOAT_TYPES); +} //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - //////////////////////////////////////////////////////////////////////// - template - linkage void sigmCrossEntropy_(NDArray* logits, NDArray* labels, NDArray* output) { - auto functor = LAMBDA_TT(x, y){ - return sd::math::nd4j_max(x, (T)0.f) - x * y + sd::math::nd4j_log((T)1.f + sd::math::nd4j_exp(-sd::math::nd4j_abs(x))); - }; - - logits->applyPairwiseLambda(*labels, functor, *output); - } +//////////////////////////////////////////////////////////////////////// +template +linkage void sigmCrossEntropy_(NDArray* logits, NDArray* labels, + NDArray* output) { + auto functor = LAMBDA_TT(x, y) { + return sd::math::nd4j_max(x, (T)0.f) - x * y + + sd::math::nd4j_log( + (T)1.f + sd::math::nd4j_exp(-sd::math::nd4j_abs(x))); + }; + + logits->applyPairwiseLambda(*labels, functor, *output); +} //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - void sigmCrossEntropy(sd::LaunchContext * context, NDArray* logits, NDArray* labels, NDArray* output) { - BUILD_SINGLE_SELECTOR(logits->dataType(), sigmCrossEntropy_, (logits, labels, output), FLOAT_TYPES); - } +void sigmCrossEntropy(sd::LaunchContext* context, NDArray* logits, + NDArray* labels, NDArray* output) { + BUILD_SINGLE_SELECTOR(logits->dataType(), sigmCrossEntropy_, + (logits, labels, output), FLOAT_TYPES); +} //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - //////////////////////////////////////////////////////////////////////// - template - linkage void sigmCrossEntropyGrad_(NDArray* logits, NDArray* labels, NDArray* output) { - // 1 - labels - 1 / (1 + exp(logits)) - auto functor = LAMBDA_TT(x, y) { - if(x <= 0) - return static_cast(1.) - y - static_cast(1.) / (static_cast(1.) + sd::math::nd4j_exp(x)); - auto e = sd::math::nd4j_exp(-x); - return static_cast(1.) - y - e / (static_cast(1.) + e); - }; - - logits->applyPairwiseLambda(*labels, functor, *output); - } +//////////////////////////////////////////////////////////////////////// +template +linkage void sigmCrossEntropyGrad_(NDArray* logits, NDArray* labels, + NDArray* output) { + // 1 - labels - 1 / (1 + exp(logits)) + auto functor = LAMBDA_TT(x, y) { + if (x <= 0) + return static_cast(1.) - y - + static_cast(1.) / + (static_cast(1.) + sd::math::nd4j_exp(x)); + auto e = sd::math::nd4j_exp(-x); + return static_cast(1.) - y - e / (static_cast(1.) + e); + }; + + logits->applyPairwiseLambda(*labels, functor, *output); +} //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - void sigmCrossEntropyGrad(sd::LaunchContext * context, NDArray* logits, NDArray* labels, NDArray* output) { - BUILD_SINGLE_SELECTOR(logits->dataType(), sigmCrossEntropyGrad_, (logits, labels, output), FLOAT_TYPES); - } +void sigmCrossEntropyGrad(sd::LaunchContext* context, NDArray* logits, + NDArray* labels, NDArray* output) { + BUILD_SINGLE_SELECTOR(logits->dataType(), sigmCrossEntropyGrad_, + (logits, labels, output), FLOAT_TYPES); +} //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - // X f = (X) 1.0f + sd::math::nd4j_abs(d1); - // return (X) d2 * ((X) 1.0f / (f * f)); - // - template - linkage void softSignDerivative_(NDArray* input, NDArray* epsilon, NDArray* output) { - auto functor = LAMBDA_TT(x, y){ - T ss = (T)1.f + sd::math::nd4j_abs(x); - return y * ((T) 1.0f / (ss * ss)); - }; - - input->applyPairwiseLambda(*epsilon, functor, *output); - } +// X f = (X) 1.0f + sd::math::nd4j_abs(d1); +// return (X) d2 * ((X) 1.0f / (f * f)); +// +template +linkage void softSignDerivative_(NDArray* input, NDArray* epsilon, + NDArray* output) { + auto functor = LAMBDA_TT(x, y) { + T ss = (T)1.f + sd::math::nd4j_abs(x); + return y * ((T)1.0f / (ss * ss)); + }; + + input->applyPairwiseLambda(*epsilon, functor, *output); +} //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - void softSignDerivative(sd::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { - BUILD_SINGLE_SELECTOR(theFirst->dataType(), softSignDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); - } +void softSignDerivative(sd::LaunchContext* context, NDArray* theFirst, + NDArray* theSecond, NDArray* theOutput) { + BUILD_SINGLE_SELECTOR(theFirst->dataType(), softSignDerivative_, + (theFirst, theSecond, theOutput), FLOAT_TYPES); +} //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - template - linkage void softPlusDerivative_(NDArray* input, NDArray* epsilon, NDArray* output) { - auto functor = LAMBDA_TT(x, y){ - T p = sd::math::nd4j_pow(static_cast(M_E), x); - return y * (p / (p + 1.)); - }; - - input->applyPairwiseLambda(*epsilon, functor, *output); - } +template +linkage void softPlusDerivative_(NDArray* input, NDArray* epsilon, + NDArray* output) { + auto functor = LAMBDA_TT(x, y) { + T p = sd::math::nd4j_pow(static_cast(M_E), x); + return y * (p / (p + 1.)); + }; + + input->applyPairwiseLambda(*epsilon, functor, *output); +} - void softPlusDerivative(sd::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { - BUILD_SINGLE_SELECTOR(theFirst->dataType(), softPlusDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); - } +void softPlusDerivative(sd::LaunchContext* context, NDArray* theFirst, + NDArray* theSecond, NDArray* theOutput) { + BUILD_SINGLE_SELECTOR(theFirst->dataType(), softPlusDerivative_, + (theFirst, theSecond, theOutput), FLOAT_TYPES); +} //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// /// /// \param input /// \param epsilon /// \param output - template - linkage void sigmoidDerivative_(NDArray* input, NDArray* epsilon, NDArray* output) { - auto functor = LAMBDA_TT(x, y){ - T s = sd::math::nd4j_sigmoid(x); - return y * (s * ((T) 1.0f - s)); - }; - - input->applyPairwiseLambda(*epsilon, functor, *output); - } - - void sigmoidDerivative(sd::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { - BUILD_SINGLE_SELECTOR(theFirst->dataType(), sigmoidDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); - } - - template - linkage void hardSigmoidDerivative_(NDArray* input, NDArray* epsilon, NDArray* output) { - auto functor = LAMBDA_TT(x, y){ - return y * simdOps::HardSigmoidDerivative::op(x, nullptr); - }; - - input->applyPairwiseLambda(*epsilon, functor, *output); - } - - void hardSigmoidDerivative(sd::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { - BUILD_SINGLE_SELECTOR(theFirst->dataType(), hardSigmoidDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); - } - -//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - template - linkage void logSumExp_(NDArray* input, NDArray* axis, NDArray* output) { - // reduce along axis with - NDArray tempInput = input->dup(); - input->applyTransform(transform::Exp, tempInput); - std::vector axisVector; - if (axis != nullptr) { - axisVector.resize(axis->lengthOf()); - for (size_t i = 0; i < axisVector.size(); ++i) - axisVector[i] = axis->e(i); - } - tempInput.reduceAlongDimension(reduce::Sum, *output, axisVector); - output->applyTransform(transform::Log, *output); - } - - template - linkage void logSumExp_(NDArray* input, NDArray* subtrah, NDArray* axis, NDArray* output) { - // reduce along axis with - NDArray tempInput = input->dup(); - input->applyPairwiseTransform(pairwise::Subtract, *subtrah, tempInput); - tempInput.applyTransform(transform::Exp, tempInput); - - std::vector axisVector; - if (axis != nullptr) { - axisVector.resize(axis->lengthOf()); - for (size_t i = 0; i < axisVector.size(); ++i) - axisVector[i] = axis->e(i); - } - tempInput.reduceAlongDimension(reduce::Sum, *output, axisVector); - output->applyTransform(transform::Log, *output); - } - -//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - void logSumExp(sd::LaunchContext * context, NDArray* input, NDArray* axis, NDArray* output) { - BUILD_SINGLE_SELECTOR(input->dataType(), logSumExp_, (input, axis, output), FLOAT_TYPES); - } - -//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - void logSumExp(sd::LaunchContext * context, NDArray* input, NDArray* subtrah, NDArray* axis, NDArray* output) { - BUILD_SINGLE_SELECTOR(input->dataType(), logSumExp_, (input, subtrah, axis, output), FLOAT_TYPES); - } - -//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - template - void weightedCrossEntropyWithLogitsFunctor_(NDArray const* targets, NDArray const* input, NDArray const* weights, NDArray* output) { - - T posWeight = weights->e(0); - - auto mainRoutineT1 = LAMBDA_TT(_x, _z, posWeight) { - T targetWeight = (1. + (posWeight - (T)1.f) * _z); - return (1. - _z) * _x + - targetWeight * (sd::math::nd4j_log((T)1.f + sd::math::nd4j_exp(-sd::math::nd4j_abs(_x))) + - sd::math::nd4j_max(-_x, T(0.f)) - ); - }; - - auto mainRoutineT2 = LAMBDA_TTT(_x, _z, _w) { - return (((T)1.0 - _z) * _x) + - _w * (sd::math::nd4j_log(T(1.) + sd::math::nd4j_exp(-sd::math::nd4j_abs(_x))) + - sd::math::nd4j_max(-_x, T(0.f))); - }; - - - if (weights->isScalar()) { - const_cast(input)->applyPairwiseLambda(const_cast(*targets), mainRoutineT1, *output); - } - else - { - std::unique_ptr targetVector(new NDArray(*weights)); - targetVector->applyScalar(scalar::Add, -1.f, *targetVector); - - std::unique_ptr targetTensor(new NDArray(*targets)); - *targetTensor = (*targetVector * *targetTensor) + T(1.f); - const_cast(input)->applyTriplewiseLambda(const_cast(*targets), *targetTensor.get(), mainRoutineT2, *output); - } - } -//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - void weightedCrossEntropyWithLogitsFunctor(sd::LaunchContext * context, NDArray const* targets, NDArray const* input, NDArray const* weights, NDArray* output) { - NDArray::prepareSpecialUse({output}, {targets, input, weights}); - - BUILD_SINGLE_SELECTOR(targets->dataType(), weightedCrossEntropyWithLogitsFunctor_, (targets, input, weights, output), FLOAT_TYPES); - - NDArray::registerSpecialUse({output}, {targets, input, weights}); - } +template +linkage void sigmoidDerivative_(NDArray* input, NDArray* epsilon, + NDArray* output) { + auto functor = LAMBDA_TT(x, y) { + T s = sd::math::nd4j_sigmoid(x); + return y * (s * ((T)1.0f - s)); + }; + + input->applyPairwiseLambda(*epsilon, functor, *output); +} +void sigmoidDerivative(sd::LaunchContext* context, NDArray* theFirst, + NDArray* theSecond, NDArray* theOutput) { + BUILD_SINGLE_SELECTOR(theFirst->dataType(), sigmoidDerivative_, + (theFirst, theSecond, theOutput), FLOAT_TYPES); } -} -} \ No newline at end of file + +template +linkage void hardSigmoidDerivative_(NDArray* input, NDArray* epsilon, + NDArray* output) { + auto functor = LAMBDA_TT(x, y) { + return y * simdOps::HardSigmoidDerivative::op(x, nullptr); + }; + + input->applyPairwiseLambda(*epsilon, functor, *output); +} + +void hardSigmoidDerivative(sd::LaunchContext* context, NDArray* theFirst, + NDArray* theSecond, NDArray* theOutput) { + BUILD_SINGLE_SELECTOR(theFirst->dataType(), hardSigmoidDerivative_, + (theFirst, theSecond, theOutput), FLOAT_TYPES); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +template +linkage void logSumExp_(NDArray* input, NDArray* axis, NDArray* output) { + // reduce along axis with + NDArray tempInput = input->dup(); + input->applyTransform(transform::Exp, tempInput); + std::vector axisVector; + if (axis != nullptr) { + axisVector.resize(axis->lengthOf()); + for (size_t i = 0; i < axisVector.size(); ++i) + axisVector[i] = axis->e(i); + } + tempInput.reduceAlongDimension(reduce::Sum, *output, axisVector); + output->applyTransform(transform::Log, *output); +} + +template +linkage void logSumExp_(NDArray* input, NDArray* subtrah, NDArray* axis, + NDArray* output) { + // reduce along axis with + NDArray tempInput = input->dup(); + input->applyPairwiseTransform(pairwise::Subtract, *subtrah, tempInput); + tempInput.applyTransform(transform::Exp, tempInput); + + std::vector axisVector; + if (axis != nullptr) { + axisVector.resize(axis->lengthOf()); + for (size_t i = 0; i < axisVector.size(); ++i) + axisVector[i] = axis->e(i); + } + tempInput.reduceAlongDimension(reduce::Sum, *output, axisVector); + output->applyTransform(transform::Log, *output); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +void logSumExp(sd::LaunchContext* context, NDArray* input, NDArray* axis, + NDArray* output) { + BUILD_SINGLE_SELECTOR(input->dataType(), logSumExp_, (input, axis, output), + FLOAT_TYPES); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +void logSumExp(sd::LaunchContext* context, NDArray* input, NDArray* subtrah, + NDArray* axis, NDArray* output) { + BUILD_SINGLE_SELECTOR(input->dataType(), logSumExp_, + (input, subtrah, axis, output), FLOAT_TYPES); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +template +void weightedCrossEntropyWithLogitsFunctor_(NDArray const* targets, + NDArray const* input, + NDArray const* weights, + NDArray* output) { + T posWeight = weights->e(0); + + auto mainRoutineT1 = LAMBDA_TT(_x, _z, posWeight) { + T targetWeight = (1. + (posWeight - (T)1.f) * _z); + return (1. - _z) * _x + + targetWeight * (sd::math::nd4j_log( + (T)1.f + sd::math::nd4j_exp( + -sd::math::nd4j_abs(_x))) + + sd::math::nd4j_max(-_x, T(0.f))); + }; + + auto mainRoutineT2 = LAMBDA_TTT(_x, _z, _w) { + return (((T)1.0 - _z) * _x) + + _w * + (sd::math::nd4j_log( + T(1.) + sd::math::nd4j_exp(-sd::math::nd4j_abs(_x))) + + sd::math::nd4j_max(-_x, T(0.f))); + }; + + if (weights->isScalar()) { + const_cast(input)->applyPairwiseLambda( + const_cast(*targets), mainRoutineT1, *output); + } else { + std::unique_ptr targetVector(new NDArray(*weights)); + targetVector->applyScalar(scalar::Add, -1.f, *targetVector); + + std::unique_ptr targetTensor(new NDArray(*targets)); + *targetTensor = (*targetVector * *targetTensor) + T(1.f); + const_cast(input)->applyTriplewiseLambda( + const_cast(*targets), *targetTensor.get(), mainRoutineT2, + *output); + } +} +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +void weightedCrossEntropyWithLogitsFunctor(sd::LaunchContext* context, + NDArray const* targets, + NDArray const* input, + NDArray const* weights, + NDArray* output) { + NDArray::prepareSpecialUse({output}, {targets, input, weights}); + + BUILD_SINGLE_SELECTOR(targets->dataType(), + weightedCrossEntropyWithLogitsFunctor_, + (targets, input, weights, output), FLOAT_TYPES); + + NDArray::registerSpecialUse({output}, {targets, input, weights}); +} + +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/lgamma.cu b/libnd4j/include/ops/declarable/helpers/cuda/lgamma.cu index 2de455d0f8cf..2828380ab839 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/lgamma.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/lgamma.cu @@ -19,7 +19,7 @@ // @author George A. Shulinok // -#include +#include //#include //#include @@ -31,24 +31,24 @@ namespace helpers { // calculate digamma function for array elements template void lgamma_(NDArray& x, NDArray& z) { - //auto dtype = x.dataType(); - auto lgammaProc = LAMBDA_T(x_, dtype) { - return T(DataTypeUtils::fromT() == DataType::DOUBLE?::lgamma(x_): ::lgammaf(x_)); //math::nd4j_log(math::nd4j_gamma(x)); - }; - - x.applyLambda(lgammaProc, z); + // auto dtype = x.dataType(); + auto lgammaProc = LAMBDA_T(x_, dtype) { + return T( + DataTypeUtils::fromT() == DataType::DOUBLE + ? ::lgamma(x_) + : ::lgammaf(x_)); // math::nd4j_log(math::nd4j_gamma(x)); + }; + + x.applyLambda(lgammaProc, z); } void lgamma(sd::LaunchContext* context, NDArray& x, NDArray& z) { - - BUILD_SINGLE_SELECTOR(x.dataType(), lgamma_, (x, z), FLOAT_TYPES); + BUILD_SINGLE_SELECTOR(x.dataType(), lgamma_, (x, z), FLOAT_TYPES); } -BUILD_SINGLE_TEMPLATE(template void lgamma_, (NDArray& x, NDArray& z), FLOAT_TYPES); - - - -} -} -} +BUILD_SINGLE_TEMPLATE(template void lgamma_, (NDArray & x, NDArray& z), + FLOAT_TYPES); +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/lrn.cu b/libnd4j/include/ops/declarable/helpers/cuda/lrn.cu index 123c06ac570e..3efb0bc89838 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/lrn.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/lrn.cu @@ -18,154 +18,192 @@ // @author raver119@gmail.com // -#include #include #include +#include namespace sd { namespace ops { namespace helpers { - template - static _CUDA_G void lrnKernel(void *vx, Nd4jLong const*xTadShapeInfo, Nd4jLong const*xTadOffsets, void *vz, Nd4jLong const*zTadShapeInfo, Nd4jLong const*zTadOffsets, Nd4jLong numTads, Nd4jLong tadLength, int depth, double bias, double alpha, double beta) { - extern __shared__ char sharedChar[]; - T* shared = reinterpret_cast(sharedChar); - - auto xEws = shape::elementWiseStride(xTadShapeInfo); - auto zEws = shape::elementWiseStride(zTadShapeInfo); - - auto xOrder = shape::order(xTadShapeInfo); - auto zOrder = shape::order(zTadShapeInfo); - - const T tbias = static_cast(bias); - const T tbeta = static_cast(beta); - const T talpha = static_cast(alpha); - - // one block of threads processes 1 example within batch - for (uint i = blockIdx.x; i < numTads; i += gridDim.x) { - auto x = reinterpret_cast(vx) + xTadOffsets[i]; - auto z = reinterpret_cast(vz) + zTadOffsets[i]; - - // load everything into shared memory, so we'll operate on shared memory from now on - shared[threadIdx.x] = x[threadIdx.x * xEws]; - __syncthreads(); - - const uint begin = sd::math::nd4j_max(0, threadIdx.x - depth); - const uint last = depth + threadIdx.x + 1; - const uint end = sd::math::nd4j_min(last, tadLength); - - T prev = 0.; - for (int s = begin; s < end; s++) - prev = prev + shared[s] * shared[s]; - - z[threadIdx.x * zEws] = shared[threadIdx.x] / sd::math::nd4j_pow(tbias + alpha * prev, tbeta); - } - } - - template - static _CUDA_G void lrnBPKernel(void const* vx, Nd4jLong const* xTadShapeInfo, Nd4jLong const* xTadOffsets, void *vz, Nd4jLong const* zTadShapeInfo, Nd4jLong const* zTadOffsets, Nd4jLong numTads, Nd4jLong tadLength, int depth, double bias, double alpha, double beta) { - extern __shared__ char sharedChar[]; - X* sharedX = reinterpret_cast(sharedChar); - Z* sharedY = reinterpret_cast(sharedX + blockDim.x); - - auto xEws = shape::elementWiseStride(xTadShapeInfo); - auto zEws = shape::elementWiseStride(zTadShapeInfo); - - auto xOrder = shape::order(xTadShapeInfo); - auto zOrder = shape::order(zTadShapeInfo); - - const Z tbias = static_cast(bias); - const Z tbeta = static_cast(beta); - const Z talpha = static_cast(alpha); - const Z coeff = talpha * tbeta; - - - - for (uint i = blockIdx.x; i < numTads; i += gridDim.x) { - auto x = reinterpret_cast(vx) + xTadOffsets[i]; - auto z = reinterpret_cast(vz) + zTadOffsets[i]; - - const uint begin = sd::math::nd4j_max(0, threadIdx.x - depth); - const uint last = depth + threadIdx.x + 1; - const uint end = sd::math::nd4j_min(last, tadLength); - - // load everything into shared memory - sharedX[threadIdx.x] = x[threadIdx.x * xEws]; - sharedY[threadIdx.x] = 0.f; - __syncthreads(); - - // we're operating in shared memory - for (int s = begin; s < end; s++) - sharedY[threadIdx.x] = sharedY[threadIdx.x] + sharedX[s] * sharedX[s]; - __syncthreads(); - - Z factor[1024]; - Z init = tbias + talpha * sharedY[threadIdx.x]; - - Z prev = 0.f; - for (uint s = begin; s < end; ++s) { - factor[s] = sd::math::nd4j_pow(tbias + talpha * sharedY[s], -tbeta - 1); - prev = prev + sharedX[s] * factor[s]; - } - - z[threadIdx.x * zEws] = factor[threadIdx.x] * init - 2 * sharedX[threadIdx.x] * coeff * prev; - } - } - - - template - static void lrnBP_(sd::graph::Context& block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int depth, const float bias, const float alpha, const float beta) { - auto rank = input.rankOf(); - auto packX = ConstantTadHelper::getInstance().tadForDimensions(input.shapeInfo(), {rank - 1}); - auto packZ = ConstantTadHelper::getInstance().tadForDimensions(gradI.shapeInfo(), {rank - 1}); - - const auto tadLength = shape::length(packX.primaryShapeInfo()); - const int numBlocks = sd::math::nd4j_min(1024, packX.numberOfTads()); - const int numThreads = tadLength; - - if (tadLength > 1024 || tadLength < 1) - throw std::runtime_error("LRN: tadLength > 1024 isn't implemented yet"); - - lrnBPKernel<<getCudaStream()>>>(input.specialBuffer(), packX.platformShapeInfo(), packX.platformOffsets(), gradI.specialBuffer(), packZ.platformShapeInfo(), packZ.platformOffsets(), packX.numberOfTads(), tadLength, depth, bias, alpha, beta); +template +static _CUDA_G void lrnKernel(void* vx, Nd4jLong const* xTadShapeInfo, + Nd4jLong const* xTadOffsets, void* vz, + Nd4jLong const* zTadShapeInfo, + Nd4jLong const* zTadOffsets, Nd4jLong numTads, + Nd4jLong tadLength, int depth, double bias, + double alpha, double beta) { + extern __shared__ char sharedChar[]; + T* shared = reinterpret_cast(sharedChar); + + auto xEws = shape::elementWiseStride(xTadShapeInfo); + auto zEws = shape::elementWiseStride(zTadShapeInfo); + + auto xOrder = shape::order(xTadShapeInfo); + auto zOrder = shape::order(zTadShapeInfo); + + const T tbias = static_cast(bias); + const T tbeta = static_cast(beta); + const T talpha = static_cast(alpha); + + // one block of threads processes 1 example within batch + for (uint i = blockIdx.x; i < numTads; i += gridDim.x) { + auto x = reinterpret_cast(vx) + xTadOffsets[i]; + auto z = reinterpret_cast(vz) + zTadOffsets[i]; + + // load everything into shared memory, so we'll operate on shared memory + // from now on + shared[threadIdx.x] = x[threadIdx.x * xEws]; + __syncthreads(); + + const uint begin = sd::math::nd4j_max(0, threadIdx.x - depth); + const uint last = depth + threadIdx.x + 1; + const uint end = sd::math::nd4j_min(last, tadLength); + + T prev = 0.; + for (int s = begin; s < end; s++) prev = prev + shared[s] * shared[s]; + + z[threadIdx.x * zEws] = + shared[threadIdx.x] / + sd::math::nd4j_pow(tbias + alpha * prev, tbeta); + } +} - gradI.tickWriteDevice(); - gradI *= gradO; +template +static _CUDA_G void lrnBPKernel(void const* vx, Nd4jLong const* xTadShapeInfo, + Nd4jLong const* xTadOffsets, void* vz, + Nd4jLong const* zTadShapeInfo, + Nd4jLong const* zTadOffsets, Nd4jLong numTads, + Nd4jLong tadLength, int depth, double bias, + double alpha, double beta) { + extern __shared__ char sharedChar[]; + X* sharedX = reinterpret_cast(sharedChar); + Z* sharedY = reinterpret_cast(sharedX + blockDim.x); + + auto xEws = shape::elementWiseStride(xTadShapeInfo); + auto zEws = shape::elementWiseStride(zTadShapeInfo); + + auto xOrder = shape::order(xTadShapeInfo); + auto zOrder = shape::order(zTadShapeInfo); + + const Z tbias = static_cast(bias); + const Z tbeta = static_cast(beta); + const Z talpha = static_cast(alpha); + const Z coeff = talpha * tbeta; + + for (uint i = blockIdx.x; i < numTads; i += gridDim.x) { + auto x = reinterpret_cast(vx) + xTadOffsets[i]; + auto z = reinterpret_cast(vz) + zTadOffsets[i]; + + const uint begin = sd::math::nd4j_max(0, threadIdx.x - depth); + const uint last = depth + threadIdx.x + 1; + const uint end = sd::math::nd4j_min(last, tadLength); + + // load everything into shared memory + sharedX[threadIdx.x] = x[threadIdx.x * xEws]; + sharedY[threadIdx.x] = 0.f; + __syncthreads(); + + // we're operating in shared memory + for (int s = begin; s < end; s++) + sharedY[threadIdx.x] = sharedY[threadIdx.x] + sharedX[s] * sharedX[s]; + __syncthreads(); + + Z factor[1024]; + Z init = tbias + talpha * sharedY[threadIdx.x]; + + Z prev = 0.f; + for (uint s = begin; s < end; ++s) { + factor[s] = + sd::math::nd4j_pow(tbias + talpha * sharedY[s], -tbeta - 1); + prev = prev + sharedX[s] * factor[s]; } - void lrnBP(sd::graph::Context& block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int depth, const float bias, const float alpha, const float beta) { - input.syncToDevice(); - gradO.syncToDevice(); - - BUILD_DOUBLE_SELECTOR(input.dataType(), gradO.dataType(), lrnBP_, (block, input, gradO, gradI, depth, bias, alpha, beta), FLOAT_TYPES, FLOAT_TYPES); + z[threadIdx.x * zEws] = + factor[threadIdx.x] * init - 2 * sharedX[threadIdx.x] * coeff * prev; + } +} - gradI.tickWriteDevice(); - } +template +static void lrnBP_(sd::graph::Context& block, const NDArray& input, + const NDArray& gradO, NDArray& gradI, const int depth, + const float bias, const float alpha, const float beta) { + auto rank = input.rankOf(); + auto packX = ConstantTadHelper::getInstance().tadForDimensions( + input.shapeInfo(), {rank - 1}); + auto packZ = ConstantTadHelper::getInstance().tadForDimensions( + gradI.shapeInfo(), {rank - 1}); + + const auto tadLength = shape::length(packX.primaryShapeInfo()); + const int numBlocks = + sd::math::nd4j_min(1024, packX.numberOfTads()); + const int numThreads = tadLength; + + if (tadLength > 1024 || tadLength < 1) + throw std::runtime_error("LRN: tadLength > 1024 isn't implemented yet"); + + lrnBPKernel<<getCudaStream()>>>( + input.specialBuffer(), packX.platformShapeInfo(), packX.platformOffsets(), + gradI.specialBuffer(), packZ.platformShapeInfo(), packZ.platformOffsets(), + packX.numberOfTads(), tadLength, depth, bias, alpha, beta); + + gradI.tickWriteDevice(); + gradI *= gradO; +} - template - static void lrnFunctor_(sd::graph::Context& block, NDArray* input, NDArray* output, int depth, double bias, double alpha, double beta) { - auto rank = input->rankOf(); - auto packX = ConstantTadHelper::getInstance().tadForDimensions(input->shapeInfo(), {rank - 1}); - auto packZ = ConstantTadHelper::getInstance().tadForDimensions(output->shapeInfo(), {rank - 1}); +void lrnBP(sd::graph::Context& block, const NDArray& input, + const NDArray& gradO, NDArray& gradI, const int depth, + const float bias, const float alpha, const float beta) { + input.syncToDevice(); + gradO.syncToDevice(); - const auto tadLength = shape::length(packX.primaryShapeInfo()); - const int numBlocks = sd::math::nd4j_min(1024, packX.numberOfTads()); - const int numThreads = tadLength; + BUILD_DOUBLE_SELECTOR(input.dataType(), gradO.dataType(), lrnBP_, + (block, input, gradO, gradI, depth, bias, alpha, beta), + FLOAT_TYPES, FLOAT_TYPES); - if (tadLength > 1024 || tadLength < 1) - throw std::runtime_error("LRN: tadLength > 1024 isn't implemented yet"); + gradI.tickWriteDevice(); +} - lrnKernel<<getCudaStream()>>>(input->specialBuffer(), packX.platformShapeInfo(), packX.platformOffsets(), output->specialBuffer(), packZ.platformShapeInfo(), packZ.platformOffsets(), packX.numberOfTads(), tadLength, depth, bias, alpha, beta); - } +template +static void lrnFunctor_(sd::graph::Context& block, NDArray* input, + NDArray* output, int depth, double bias, double alpha, + double beta) { + auto rank = input->rankOf(); + auto packX = ConstantTadHelper::getInstance().tadForDimensions( + input->shapeInfo(), {rank - 1}); + auto packZ = ConstantTadHelper::getInstance().tadForDimensions( + output->shapeInfo(), {rank - 1}); + + const auto tadLength = shape::length(packX.primaryShapeInfo()); + const int numBlocks = + sd::math::nd4j_min(1024, packX.numberOfTads()); + const int numThreads = tadLength; + + if (tadLength > 1024 || tadLength < 1) + throw std::runtime_error("LRN: tadLength > 1024 isn't implemented yet"); + + lrnKernel<<getCudaStream()>>>( + input->specialBuffer(), packX.platformShapeInfo(), + packX.platformOffsets(), output->specialBuffer(), + packZ.platformShapeInfo(), packZ.platformOffsets(), packX.numberOfTads(), + tadLength, depth, bias, alpha, beta); +} - int lrnFunctor(sd::graph::Context& block, NDArray* input, NDArray* output, int depth, double bias, double alpha, double beta) { - input->syncToDevice(); +int lrnFunctor(sd::graph::Context& block, NDArray* input, NDArray* output, + int depth, double bias, double alpha, double beta) { + input->syncToDevice(); - BUILD_SINGLE_SELECTOR(input->dataType(), lrnFunctor_, (block, input, output, depth, bias, alpha, beta), FLOAT_TYPES); + BUILD_SINGLE_SELECTOR(input->dataType(), lrnFunctor_, + (block, input, output, depth, bias, alpha, beta), + FLOAT_TYPES); - output->tickWriteDevice(); + output->tickWriteDevice(); - return Status::OK(); - } -} -} + return Status::OK(); } +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/lstm.cu b/libnd4j/include/ops/declarable/helpers/cuda/lstm.cu index af0c413d6701..e58944c679b7 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/lstm.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/lstm.cu @@ -20,179 +20,212 @@ // implementation of operation for LSTM cell with peep hole connections: // http://www.bioinf.jku.at/publications/older/2604.pdf -// S. Hochreiter and J. Schmidhuber. "Long Short-Term Memory". Neural Computation, 9(8):1735-1780, 1997. -// and +// S. Hochreiter and J. Schmidhuber. "Long Short-Term Memory". Neural +// Computation, 9(8):1735-1780, 1997. and // https://research.google.com/pubs/archive/43905.pdf -// Hasim Sak, Andrew Senior, and Francoise Beaufays. "Long short-term memory recurrent neural network architectures for large scale acoustic modeling." INTERSPEECH, 2014. +// Hasim Sak, Andrew Senior, and Francoise Beaufays. "Long short-term memory +// recurrent neural network architectures for large scale acoustic modeling." +// INTERSPEECH, 2014. - -#include -#include -#include -#include -#include #include +#include +#include +#include +#include +#include + #include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { - - ////////////////////////////////////////////////////////////////////////// -void lstmCell(sd::LaunchContext * context, const NDArray* xt, const NDArray* ht_1, const NDArray* ct_1, const NDArray* Wx, const NDArray* Wh, const NDArray* Wc, const NDArray* Wp, const NDArray* b, - NDArray* ht, NDArray* ct, const std::vector& params) { - - // xt input [bS x nIn] - // ht_1 previous cell output [bS x numProj], that is at previous time step t-1, in case of projection=false -> numProj=nOut!!! - // ct_1 previous cell state [bS x nOut], that is at previous time step t-1 - - // Wx input-to-hidden weights, [nIn x 4*nOut] - // Wh hidden-to-hidden weights, [numProj x 4*nOut] - // Wc diagonal weights for peephole connections [3*nOut] - // Wp projection weights [nOut x numProj] - // b biases, [4*nOut] - - // ht current cell output [bS x numProj], that is at current time step t - // ct current cell state [bS x nOut], that is at current time step t - - const bool peephole = (bool)params[0]; // if true, provide peephole connections - const bool projection = (bool)params[1]; // if true, then projection is performed, if false then numProj==nOut is mandatory!!!! - double clippingCellValue = params[2]; // clipping value for ct, if it is not equal to zero, then cell state is clipped - double clippingProjValue = params[3]; // clipping value for projected ht, if it is not equal to zero, then projected cell output is clipped - const double forgetBias = params[4]; - - const int bS = xt->sizeAt(0); - const int nIn = xt->sizeAt(1); - const int numProj = ht_1->sizeAt(1); - const int nOut = ct_1->sizeAt(1); - - auto z = mmul(*xt, *Wx) + mmul(*ht_1, *Wh) + *b; // [bS x 4*nOut] + [bS x 4*nOut] + [1 x 4*nOut] = [bS x 4*nOut] - - auto zit = z({0,0, 0,nOut}); // z for input gate, = mmul(Wxi,xt) + mmul(Whi,ht_1) + bi = [bS x nOut] - auto zft = z({0,0, nOut,2*nOut}); // z for forget gate, = mmul(Wxf,xt) + mmul(Whf,ht_1) + bf = [bS x nOut] - auto zct = z({0,0, 2*nOut,3*nOut}); // z for cell state, = mmul(Wxc,xt) + mmul(Whc,ht_1) + bc = [bS x nOut] - auto zot = z({0,0, 3*nOut,4*nOut}); // z for output gate, = mmul(Wxo,xt) + mmul(Who,ht_1) + bo = [bS x nOut] - - if(peephole) { // add peephole connections: z + ct_1*Wc - zit += (*ct_1) * (*Wc)({0, nOut}); // add peephole connections to input gate - zft += (*ct_1) * (*Wc)({nOut, 2*nOut}); // add peephole connections to forget gate - } - - // current sell state = ft*ct_1 + it*tanh(mmul(Wxc,xt) + mmul(Whc,ht_1) + bc - ct->assign( sigmoid(zft + forgetBias) * (*ct_1) + sigmoid(zit) * tanh(zct) ); - - // if clipping value is provided then cell state is clipped by this value prior to the cell output activation - if(clippingCellValue > 0.0) - ct->applyScalar(scalar::LstmClip, clippingCellValue, *ct); - - if(peephole) - zot += (*ct) * (*Wc)({{2*nOut, 3*nOut}}); // add peephole connections to output gate zot + ct*Wc - - // current cell output = ot*tanh(ct) - auto htNoPeepHole = sigmoid(zot) * tanh(*ct); // = [bS x nOut] - - // apply projection - if(projection) { - ht->assign( mmul(htNoPeepHole, *Wp) ); // [bS x nOut] * [ nOut x numProj] = [bS x numProj] - // if clipping projection is provided then projected cell output state is clipped by this value - if(clippingProjValue != 0.) - ht->applyScalar(scalar::LstmClip, clippingProjValue, *ht); - } - else - ht->assign(&htNoPeepHole); +void lstmCell(sd::LaunchContext* context, const NDArray* xt, + const NDArray* ht_1, const NDArray* ct_1, const NDArray* Wx, + const NDArray* Wh, const NDArray* Wc, const NDArray* Wp, + const NDArray* b, NDArray* ht, NDArray* ct, + const std::vector& params) { + // xt input [bS x nIn] + // ht_1 previous cell output [bS x numProj], that is at previous time step + // t-1, in case of projection=false -> numProj=nOut!!! ct_1 previous cell + // state [bS x nOut], that is at previous time step t-1 + + // Wx input-to-hidden weights, [nIn x 4*nOut] + // Wh hidden-to-hidden weights, [numProj x 4*nOut] + // Wc diagonal weights for peephole connections [3*nOut] + // Wp projection weights [nOut x numProj] + // b biases, [4*nOut] + + // ht current cell output [bS x numProj], that is at current time step t + // ct current cell state [bS x nOut], that is at current time step t + + const bool peephole = + (bool)params[0]; // if true, provide peephole connections + const bool projection = + (bool)params[1]; // if true, then projection is performed, if false then + // numProj==nOut is mandatory!!!! + double clippingCellValue = + params[2]; // clipping value for ct, if it is not equal to zero, then + // cell state is clipped + double clippingProjValue = + params[3]; // clipping value for projected ht, if it is not equal to + // zero, then projected cell output is clipped + const double forgetBias = params[4]; + + const int bS = xt->sizeAt(0); + const int nIn = xt->sizeAt(1); + const int numProj = ht_1->sizeAt(1); + const int nOut = ct_1->sizeAt(1); + + auto z = mmul(*xt, *Wx) + mmul(*ht_1, *Wh) + + *b; // [bS x 4*nOut] + [bS x 4*nOut] + [1 x 4*nOut] = [bS x 4*nOut] + + auto zit = z({0, 0, 0, nOut}); // z for input gate, = mmul(Wxi,xt) + + // mmul(Whi,ht_1) + bi = [bS x nOut] + auto zft = z({0, 0, nOut, 2 * nOut}); // z for forget gate, = mmul(Wxf,xt) + + // mmul(Whf,ht_1) + bf = [bS x nOut] + auto zct = + z({0, 0, 2 * nOut, 3 * nOut}); // z for cell state, = mmul(Wxc,xt) + + // mmul(Whc,ht_1) + bc = [bS x nOut] + auto zot = + z({0, 0, 3 * nOut, 4 * nOut}); // z for output gate, = mmul(Wxo,xt) + + // mmul(Who,ht_1) + bo = [bS x nOut] + + if (peephole) { // add peephole connections: z + ct_1*Wc + zit += + (*ct_1) * (*Wc)({0, nOut}); // add peephole connections to input gate + zft += (*ct_1) * + (*Wc)({nOut, 2 * nOut}); // add peephole connections to forget gate + } + + // current sell state = ft*ct_1 + it*tanh(mmul(Wxc,xt) + mmul(Whc,ht_1) + bc + ct->assign(sigmoid(zft + forgetBias) * (*ct_1) + sigmoid(zit) * tanh(zct)); + + // if clipping value is provided then cell state is clipped by this value + // prior to the cell output activation + if (clippingCellValue > 0.0) + ct->applyScalar(scalar::LstmClip, clippingCellValue, *ct); + + if (peephole) + zot += (*ct) * (*Wc)({{2 * nOut, 3 * nOut}}); // add peephole connections + // to output gate zot + ct*Wc + + // current cell output = ot*tanh(ct) + auto htNoPeepHole = sigmoid(zot) * tanh(*ct); // = [bS x nOut] + + // apply projection + if (projection) { + ht->assign(mmul(htNoPeepHole, + *Wp)); // [bS x nOut] * [ nOut x numProj] = [bS x numProj] + // if clipping projection is provided then projected cell output state is + // clipped by this value + if (clippingProjValue != 0.) + ht->applyScalar(scalar::LstmClip, clippingProjValue, *ht); + } else + ht->assign(&htNoPeepHole); } - - -void lstmBlockCell(const NDArray* xt, const NDArray* cLast, const NDArray* yLast, - const NDArray* W, const NDArray* Wci, const NDArray* Wcf, const NDArray* Wco, const NDArray* b, - NDArray* i, NDArray* c, NDArray* f, NDArray* o, NDArray* z, NDArray* h, NDArray* y, const std::vector& params) { - /* Input arrays: - * 0: xt - input [bS, nIn] at time t - * 1: cLast (cs_prev) - previous cell state [bS, nOut], time t-1 - * 2: yLast (h_prev) - previous output [bS, nOut], time t-1 - * 3: W - Weights - concatenated (input-to-hidden, hidden-to-hidden weights) weights, [(nIn+nOut), 4*nOut] - * 4: Wci - weights - cell peephole (t-1) connections to input modulation gate, [nOut] - * 5: Wcf - weights - cell peephole (t-1) connections to forget gate, [nOut] - * 6: Wco - weights - cell peephole (t) connections to output gate, [nOut] - * 7: b - biases, [4*nOut] - * - * Input integer arguments: - * 0: if not zero, provide peephole connections - * - * Input float arguments: - * 0: the bias added to forget gates in order to reduce the scale of forgetting in the beginning of the training - * 1: clipping value for cell state, if it is not equal to zero, then cell state is clipped - * - * Output arrays: - * 0: i - Input modulation gate activations [bS, nOut] - * 1: c (cs) - Cell state (pre tanh) [bs, nOut] (cs) - * 2: f - Output - forget gate activations [bs, nOut] - * 3: o - Output - output gate activations [bs, nOut] - * 4: z (ci) - Output - block input [bs, nOut] - * 5: h (co) - Cell state, post tanh [bs, nOut] - * 6: y (h) - Current cell output [bS, nOut], time t - */ - const bool peephole = (bool)params[0]; // if true, provide peephole connections - const double forgetBias = params[1]; - const double clippingCellValue = params[2]; // clipping value for ct, if it is not equal to zero, then cell state is clipped - - const int bS = xt->sizeAt(0); - const int nIn = xt->sizeAt(1); - const int nOut = cLast->sizeAt(1); - - //Concat inputs: [xt, yt-1]: concat([bs,nIn],[bs,nOut]) -> [bs, (nIn+nOut)] - NDArray concatOut(xt->ordering(), {xt->sizeAt(0), xt->sizeAt(1) + yLast->sizeAt(1)}, xt->dataType(), xt->getContext()); - helpers::concat(xt->getContext(), {const_cast(xt), const_cast(yLast)}, concatOut, {1}); - - auto m = mmul(concatOut, *W); // mmul: [bs, (nIn+nOut)] * [(nIn+nOut), 4*nOut] = [bs, 4*nOut] - m += (*b); // addiRowVector - - //Note: weights are ordered [inputGate, blockInput, forgetGate, outputGate] to match TF (TF code comments state [i,f,z/ci,o] but behaviour is [i,z,f,o]) - auto zi = m({0,0, 0, nOut}); // z for input modulation gate, [bS, nOut] - auto zz = m({0,0, nOut, 2*nOut}); // z for block input, [bS, nOut] - auto zf = m({0,0, 2*nOut, 3*nOut}); // z for forget gate, [bS, nOut] - auto zo = m({0,0, 3*nOut, 4*nOut}); // z for output gate, [bS, nOut] - - if(peephole) { // add peephole connections: z + ct_1*Wc - zi += (*cLast) * (*Wci); // add peephole connections to input gate - zf += (*cLast) * (*Wcf); // add peephole connections to forget gate - } - - // current sell state = ft*cLast + it*tanh(mmul(Wxc,xt) + mmul(Whc,ht_1) + bc - if(forgetBias != 0.0) - zf += forgetBias; - - zz.applyTransform(transform::Tanh, *z); //z = tanh(zz) - zi.applyTransform(transform::Sigmoid, *i); //i = sigmoid(zi) - zf.applyTransform(transform::Sigmoid, *f); //f = sigmoid(zf); - - //cell state = blockInput .* inputGate + prevCellState .* forgetGate - z->applyPairwiseTransform(pairwise::Multiply, *i, *c); //c = z * i - auto temp = (*f) * (*cLast); - *c += temp; //c = (i * z) + (zf * (*cLast)) - c->applyTransform(transform::Tanh, *h); //h = tanh(c) - - // if clipping value is provided then cell state is clipped by this value prior to the cell output activation - if(clippingCellValue > 0.0) - c->applyScalar(scalar::LstmClip, clippingCellValue, *c); - - if(peephole) { - // add peephole connections to output gate zot + ct*Wc - auto prod = *c * (*Wco); - zo += prod; - } - zo.applyTransform(transform::Sigmoid, *o); // o = sigmoid(zo) - - // current cell output = ot*tanh(ct) - c->applyTransform(transform::Tanh, *h); //h = tanh(c) - o->applyPairwiseTransform(pairwise::Multiply, *h, *y); //y = o * h -} - - -} -} +void lstmBlockCell(const NDArray* xt, const NDArray* cLast, + const NDArray* yLast, const NDArray* W, const NDArray* Wci, + const NDArray* Wcf, const NDArray* Wco, const NDArray* b, + NDArray* i, NDArray* c, NDArray* f, NDArray* o, NDArray* z, + NDArray* h, NDArray* y, const std::vector& params) { + /* Input arrays: + * 0: xt - input [bS, nIn] at time t + * 1: cLast (cs_prev) - previous cell state [bS, nOut], time t-1 + * 2: yLast (h_prev) - previous output [bS, nOut], time t-1 + * 3: W - Weights - concatenated (input-to-hidden, + * hidden-to-hidden weights) weights, [(nIn+nOut), 4*nOut] 4: Wci - weights - + * cell peephole (t-1) connections to input modulation gate, [nOut] 5: Wcf - + * weights - cell peephole (t-1) connections to forget gate, [nOut] 6: Wco - + * weights - cell peephole (t) connections to output gate, [nOut] 7: b - + * biases, [4*nOut] + * + * Input integer arguments: + * 0: if not zero, provide peephole connections + * + * Input float arguments: + * 0: the bias added to forget gates in order to reduce the scale of + * forgetting in the beginning of the training 1: clipping value for cell + * state, if it is not equal to zero, then cell state is clipped + * + * Output arrays: + * 0: i - Input modulation gate activations [bS, nOut] + * 1: c (cs) - Cell state (pre tanh) [bs, nOut] (cs) + * 2: f - Output - forget gate activations [bs, nOut] + * 3: o - Output - output gate activations [bs, nOut] + * 4: z (ci) - Output - block input [bs, nOut] + * 5: h (co) - Cell state, post tanh [bs, nOut] + * 6: y (h) - Current cell output [bS, nOut], time t + */ + const bool peephole = + (bool)params[0]; // if true, provide peephole connections + const double forgetBias = params[1]; + const double clippingCellValue = + params[2]; // clipping value for ct, if it is not equal to zero, then + // cell state is clipped + + const int bS = xt->sizeAt(0); + const int nIn = xt->sizeAt(1); + const int nOut = cLast->sizeAt(1); + + // Concat inputs: [xt, yt-1]: concat([bs,nIn],[bs,nOut]) -> [bs, (nIn+nOut)] + NDArray concatOut(xt->ordering(), + {xt->sizeAt(0), xt->sizeAt(1) + yLast->sizeAt(1)}, + xt->dataType(), xt->getContext()); + helpers::concat(xt->getContext(), + {const_cast(xt), const_cast(yLast)}, + concatOut, {1}); + + auto m = + mmul(concatOut, + *W); // mmul: [bs, (nIn+nOut)] * [(nIn+nOut), 4*nOut] = [bs, 4*nOut] + m += (*b); // addiRowVector + + // Note: weights are ordered [inputGate, blockInput, forgetGate, outputGate] + // to match TF (TF code comments state [i,f,z/ci,o] but behaviour is + // [i,z,f,o]) + auto zi = m({0, 0, 0, nOut}); // z for input modulation gate, [bS, nOut] + auto zz = m({0, 0, nOut, 2 * nOut}); // z for block input, [bS, nOut] + auto zf = m({0, 0, 2 * nOut, 3 * nOut}); // z for forget gate, [bS, nOut] + auto zo = m({0, 0, 3 * nOut, 4 * nOut}); // z for output gate, [bS, nOut] + + if (peephole) { // add peephole connections: z + ct_1*Wc + zi += (*cLast) * (*Wci); // add peephole connections to input gate + zf += (*cLast) * (*Wcf); // add peephole connections to forget gate + } + + // current sell state = ft*cLast + it*tanh(mmul(Wxc,xt) + mmul(Whc,ht_1) + bc + if (forgetBias != 0.0) zf += forgetBias; + + zz.applyTransform(transform::Tanh, *z); // z = tanh(zz) + zi.applyTransform(transform::Sigmoid, *i); // i = sigmoid(zi) + zf.applyTransform(transform::Sigmoid, *f); // f = sigmoid(zf); + + // cell state = blockInput .* inputGate + prevCellState .* forgetGate + z->applyPairwiseTransform(pairwise::Multiply, *i, *c); // c = z * i + auto temp = (*f) * (*cLast); + *c += temp; // c = (i * z) + (zf * (*cLast)) + c->applyTransform(transform::Tanh, *h); // h = tanh(c) + + // if clipping value is provided then cell state is clipped by this value + // prior to the cell output activation + if (clippingCellValue > 0.0) + c->applyScalar(scalar::LstmClip, clippingCellValue, *c); + + if (peephole) { + // add peephole connections to output gate zot + ct*Wc + auto prod = *c * (*Wco); + zo += prod; + } + zo.applyTransform(transform::Sigmoid, *o); // o = sigmoid(zo) + + // current cell output = ot*tanh(ct) + c->applyTransform(transform::Tanh, *h); // h = tanh(c) + o->applyPairwiseTransform(pairwise::Multiply, *h, *y); // y = o * h } +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/lstsq.cu b/libnd4j/include/ops/declarable/helpers/cuda/lstsq.cu index b28efff80d74..a3746d1473c5 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/lstsq.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/lstsq.cu @@ -17,99 +17,128 @@ // // @author GS // -#include #include +#include #include #include -#include - -#include +#include #include #include -#include +#include +#include namespace sd { namespace ops { namespace helpers { - template - static __global__ void fillRegularizerKernel(T* ioMatrixData, const Nd4jLong* ioMatrixShape, const Nd4jLong* ioMatrixTads, const Nd4jLong* ioMatrixOffsets, Nd4jLong batchSize, Nd4jLong rows, T const value) { - - for (auto x = blockIdx.x; x < batchSize; x += gridDim.x) { - auto z = ioMatrixData + ioMatrixOffsets[x]; - for (auto r = threadIdx.x; r < rows; r += blockDim.x) { - Nd4jLong pos[] = {r,r}; - auto zIndex = shape::getOffset(ioMatrixTads, pos); - z[zIndex] = value; - } - } - +template +static __global__ void fillRegularizerKernel(T* ioMatrixData, + const Nd4jLong* ioMatrixShape, + const Nd4jLong* ioMatrixTads, + const Nd4jLong* ioMatrixOffsets, + Nd4jLong batchSize, Nd4jLong rows, + T const value) { + for (auto x = blockIdx.x; x < batchSize; x += gridDim.x) { + auto z = ioMatrixData + ioMatrixOffsets[x]; + for (auto r = threadIdx.x; r < rows; r += blockDim.x) { + Nd4jLong pos[] = {r, r}; + auto zIndex = shape::getOffset(ioMatrixTads, pos); + z[zIndex] = value; } + } +} - template - static void fillRegularizer(sd::LaunchContext* context, NDArray& ioMatrix, double const value) { - auto lastDimsTads = ConstantTadHelper::getInstance().tadForDimensions(ioMatrix.shapeInfo(), {-2, -1}); - auto stream = context->getCudaStream(); - auto rows = ioMatrix.sizeAt(-2); - //auto cols = ioMatrix.sizeAt(-1); - fillRegularizerKernel<<<256, 256, 128, *stream>>>(ioMatrix.dataBuffer()->specialAsT(), ioMatrix.specialShapeInfo(), lastDimsTads.specialShapeInfo(), lastDimsTads.specialOffsets(), lastDimsTads.numberOfTads(), rows, (T)value); - - } - - template - int leastSquaresSolveFunctor_(sd::LaunchContext* context, NDArray const* leftInput, NDArray const* rightInput, double const l2Regularizer, bool const fast, NDArray* output) { - if (fast) { // Cholesky decomposition approach - // Equation for solve A^T * Ax = A^T * b, so - // 1. Computing A2: - auto tAtShape = ShapeUtils::evalShapeForMatmul(leftInput->shapeInfo(), leftInput->shapeInfo(), true, false); - //tAtShape[tAtShape.size() - 2] = output->sizeAt(-2); - NDArray leftOutput(leftInput->ordering(), tAtShape, output->dataType(), context); - MmulHelper::matmul(leftInput, leftInput, &leftOutput, true, false); // Computing A2 = A^T * A - // 2. Computing B' = A^T * b - auto rightOutput = output->ulike(); - - MmulHelper::matmul(leftInput, rightInput, &rightOutput, true, false); // Computing B' = A^T * b - // 3. Regularization ( indeed A' = A2 - l2Regularizer * I) - if (l2Regularizer != 0.0) { - auto regularizer = leftOutput.ulike(); regularizer.nullify(); - fillRegularizer(context, regularizer, (T)l2Regularizer); - leftOutput += regularizer; - } - - // 4. Cholesky decomposition -- output matrix is square and lower triangular - helpers::cholesky(context, &leftOutput, &leftOutput, true); // inplace decomposition - // 5. Solve two triangular systems: - auto rightB = rightOutput.ulike(); rightB.nullify(); - - helpers::triangularSolveFunctor(context, &leftOutput, &rightOutput, true, false, &rightB); - - helpers::adjointMatrix(context, &leftOutput, true, &leftOutput); - helpers::triangularSolveFunctor(context, &leftOutput, &rightB, false, false, output); - // All done - } - else { // QR decomposition approach - // Equation for solve Rx = Q^T * b, where A = Q * R, where Q - orthogonal matrix, and R - upper triangular - // 1. QR decomposition - auto qShape = leftInput->getShapeAsVector(); - auto rShape = leftInput->getShapeAsVector(); - qShape[leftInput->rankOf() - 1] = leftInput->sizeAt(-2); - - NDArray Q(leftInput->ordering(), qShape, leftInput->dataType(), context);// = leftInput->ulike(); - NDArray R(leftInput->ordering(), rShape, leftInput->dataType(), context); // = rightInput->ulike(); - helpers::qr(context, leftInput, &Q, &R, true); - // 2. b` = Q^t * b: - auto rightOutput = rightInput->ulike(); - MmulHelper::matmul(&Q, rightInput, &rightOutput, true, false); - // 3. Solve triangular system - helpers::triangularSolveFunctor(context, &R, &rightOutput, false, false, output); - } - return Status::OK(); - } +template +static void fillRegularizer(sd::LaunchContext* context, NDArray& ioMatrix, + double const value) { + auto lastDimsTads = ConstantTadHelper::getInstance().tadForDimensions( + ioMatrix.shapeInfo(), {-2, -1}); + auto stream = context->getCudaStream(); + auto rows = ioMatrix.sizeAt(-2); + // auto cols = ioMatrix.sizeAt(-1); + fillRegularizerKernel<<<256, 256, 128, *stream>>>( + ioMatrix.dataBuffer()->specialAsT(), ioMatrix.specialShapeInfo(), + lastDimsTads.specialShapeInfo(), lastDimsTads.specialOffsets(), + lastDimsTads.numberOfTads(), rows, (T)value); +} - int leastSquaresSolveFunctor(sd::LaunchContext* context, NDArray const* leftInput, NDArray const* rightInput, double const l2Regularizer, bool const fast, NDArray* output) { - BUILD_SINGLE_SELECTOR(leftInput->dataType(), return leastSquaresSolveFunctor_, (context, leftInput, rightInput, l2Regularizer, fast, output), FLOAT_TYPES); +template +int leastSquaresSolveFunctor_(sd::LaunchContext* context, + NDArray const* leftInput, + NDArray const* rightInput, + double const l2Regularizer, bool const fast, + NDArray* output) { + if (fast) { // Cholesky decomposition approach + // Equation for solve A^T * Ax = A^T * b, so + // 1. Computing A2: + auto tAtShape = ShapeUtils::evalShapeForMatmul( + leftInput->shapeInfo(), leftInput->shapeInfo(), true, false); + // tAtShape[tAtShape.size() - 2] = output->sizeAt(-2); + NDArray leftOutput(leftInput->ordering(), tAtShape, output->dataType(), + context); + MmulHelper::matmul(leftInput, leftInput, &leftOutput, true, + false); // Computing A2 = A^T * A + // 2. Computing B' = A^T * b + auto rightOutput = output->ulike(); + + MmulHelper::matmul(leftInput, rightInput, &rightOutput, true, + false); // Computing B' = A^T * b + // 3. Regularization ( indeed A' = A2 - l2Regularizer * I) + if (l2Regularizer != 0.0) { + auto regularizer = leftOutput.ulike(); + regularizer.nullify(); + fillRegularizer(context, regularizer, (T)l2Regularizer); + leftOutput += regularizer; } + // 4. Cholesky decomposition -- output matrix is square and lower triangular + helpers::cholesky(context, &leftOutput, &leftOutput, + true); // inplace decomposition + // 5. Solve two triangular systems: + auto rightB = rightOutput.ulike(); + rightB.nullify(); + + helpers::triangularSolveFunctor(context, &leftOutput, &rightOutput, true, + false, &rightB); + + helpers::adjointMatrix(context, &leftOutput, true, &leftOutput); + helpers::triangularSolveFunctor(context, &leftOutput, &rightB, false, false, + output); + // All done + } else { // QR decomposition approach + // Equation for solve Rx = Q^T * b, where A = Q * R, where Q - orthogonal + // matrix, and R - upper triangular + // 1. QR decomposition + auto qShape = leftInput->getShapeAsVector(); + auto rShape = leftInput->getShapeAsVector(); + qShape[leftInput->rankOf() - 1] = leftInput->sizeAt(-2); + + NDArray Q(leftInput->ordering(), qShape, leftInput->dataType(), + context); // = leftInput->ulike(); + NDArray R(leftInput->ordering(), rShape, leftInput->dataType(), + context); // = rightInput->ulike(); + helpers::qr(context, leftInput, &Q, &R, true); + // 2. b` = Q^t * b: + auto rightOutput = rightInput->ulike(); + MmulHelper::matmul(&Q, rightInput, &rightOutput, true, false); + // 3. Solve triangular system + helpers::triangularSolveFunctor(context, &R, &rightOutput, false, false, + output); + } + return Status::OK(); } + +int leastSquaresSolveFunctor(sd::LaunchContext* context, + NDArray const* leftInput, + NDArray const* rightInput, + double const l2Regularizer, bool const fast, + NDArray* output) { + BUILD_SINGLE_SELECTOR( + leftInput->dataType(), return leastSquaresSolveFunctor_, + (context, leftInput, rightInput, l2Regularizer, fast, output), + FLOAT_TYPES); } -} + +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/lup.cu b/libnd4j/include/ops/declarable/helpers/cuda/lup.cu index c59ef9489cc1..241c56980725 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/lup.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/lup.cu @@ -18,12 +18,12 @@ // @author raver119@gmail.com // -#include -#include #include #include #include +#include #include +#include //#include #include @@ -33,978 +33,1075 @@ namespace sd { namespace ops { namespace helpers { -// ------------------------------------------------------------------------------------------------------------------ // +// ------------------------------------------------------------------------------------------------------------------ +// // // invert the second diagonal for lower diagonal matrix - template - static __global__ void - invertKernelLow(void *invertedBuf, const Nd4jLong *invertedShape, const void *inputBuf, const Nd4jLong *inputShape, Nd4jLong n) { - auto inverted = reinterpret_cast(invertedBuf); - auto input = reinterpret_cast(inputBuf); - - auto start = threadIdx.x + blockIdx.x * blockDim.x; - auto step = blockDim.x * gridDim.x; - - for (int i = start + 1; i < n; i += step) { - Nd4jLong pos[] = {i, i - 1}; - Nd4jLong posX[] = {i, i}; - Nd4jLong posY[] = {i - 1, i - 1}; - auto xIndex = shape::getOffset(inputShape, pos); - auto dxIndex = shape::getOffset(inputShape, posX); - auto dyIndex = shape::getOffset(inputShape, posY); - auto zIndex = shape::getOffset(invertedShape, pos); - // invert lower triangular matrix - inverted[zIndex] = -input[xIndex] / (input[dxIndex] * input[dyIndex]); -// math::atomics::nd4j_atomicAdd(&inverted[zIndex], - input[xIndex] * inverted[iIndex] / input[dIndex]); - } - } -// ------------------------------------------------------------------------------------------------------------------ // -// invert diagonal vals to upper diagonal matrix - template - static __global__ void - upvertKernel(void *invertedBuf, const Nd4jLong *invertedShape, const void *inputBuf, const Nd4jLong *inputShape, Nd4jLong n) { - auto inverted = reinterpret_cast(invertedBuf); - auto input = reinterpret_cast(inputBuf); - - auto start = threadIdx.x + blockIdx.x * blockDim.x; - auto step = blockDim.x * gridDim.x; - - for (int i = start; i < n; i += step) { - Nd4jLong pos[] = {i, i}; - auto xIndex = shape::getOffset(inputShape, pos); - auto zIndex = shape::getOffset(invertedShape, pos); - - // invert diagonal elements - inverted[zIndex] /= input[xIndex]; - } - } +template +static __global__ void invertKernelLow(void *invertedBuf, + const Nd4jLong *invertedShape, + const void *inputBuf, + const Nd4jLong *inputShape, Nd4jLong n) { + auto inverted = reinterpret_cast(invertedBuf); + auto input = reinterpret_cast(inputBuf); + + auto start = threadIdx.x + blockIdx.x * blockDim.x; + auto step = blockDim.x * gridDim.x; + + for (int i = start + 1; i < n; i += step) { + Nd4jLong pos[] = {i, i - 1}; + Nd4jLong posX[] = {i, i}; + Nd4jLong posY[] = {i - 1, i - 1}; + auto xIndex = shape::getOffset(inputShape, pos); + auto dxIndex = shape::getOffset(inputShape, posX); + auto dyIndex = shape::getOffset(inputShape, posY); + auto zIndex = shape::getOffset(invertedShape, pos); + // invert lower triangular matrix + inverted[zIndex] = -input[xIndex] / (input[dxIndex] * input[dyIndex]); + // math::atomics::nd4j_atomicAdd(&inverted[zIndex], - + // input[xIndex] * inverted[iIndex] / input[dIndex]); + } +} +// ------------------------------------------------------------------------------------------------------------------ +// // invert diagonal vals to upper diagonal matrix +template +static __global__ void upvertKernel(void *invertedBuf, + const Nd4jLong *invertedShape, + const void *inputBuf, + const Nd4jLong *inputShape, Nd4jLong n) { + auto inverted = reinterpret_cast(invertedBuf); + auto input = reinterpret_cast(inputBuf); + + auto start = threadIdx.x + blockIdx.x * blockDim.x; + auto step = blockDim.x * gridDim.x; + + for (int i = start; i < n; i += step) { + Nd4jLong pos[] = {i, i}; + auto xIndex = shape::getOffset(inputShape, pos); + auto zIndex = shape::getOffset(invertedShape, pos); + + // invert diagonal elements + inverted[zIndex] /= input[xIndex]; + } +} -// ------------------------------------------------------------------------------------------------------------------ // +// ------------------------------------------------------------------------------------------------------------------ +// // // invert upper second diagonal - template - static __global__ void - upvertKernelUp(void *invertedBuf, const Nd4jLong *invertedShape, const void *inputBuf, const Nd4jLong *inputShape, Nd4jLong n) { - - __shared__ T* inverted; - __shared__ const T* input; - if (threadIdx.x == 0) { - inverted = reinterpret_cast(invertedBuf); - input = reinterpret_cast(inputBuf); - } - __syncthreads(); - - auto start = threadIdx.x + blockIdx.x * blockDim.x; - auto step = blockDim.x * gridDim.x; - - for (int i = start; i < n - 1; i += step) { - Nd4jLong pos[] = {i, i + 1}; - Nd4jLong posX[] = {i + 1, i + 1}; - auto xIndex = shape::getOffset(inputShape, pos); - auto iIndex = shape::getOffset(invertedShape, posX); - auto zIndex = shape::getOffset(invertedShape, pos); - // invert upper matrix - math::atomics::nd4j_atomicAdd(&inverted[zIndex], -input[xIndex] * inverted[iIndex]); // / input[yIndex]); - //inputMatrix->t(i, i + 1) * invertedMatrix->t(i + 1, i + 1) / inputMatrix->t(i, i) - } - } - -// ------------------------------------------------------------------------------------------------------------------ // - template - static __global__ void - invertLowKernel(void *invertedBuf, const Nd4jLong *invertedShape, const void *inputBuf, const Nd4jLong *inputShape, Nd4jLong n) { - - auto input = reinterpret_cast(inputBuf); - auto inverted = reinterpret_cast(invertedBuf); - - - auto tid = blockIdx.x * blockDim.x + threadIdx.x; - auto step = gridDim.x * blockDim.x; - - for (int i = tid + 2; i < n; i += step) { - for (int j = i - 2; j >= 0; --j) - for (int k = 0; k < i; k++) { - Nd4jLong posZ[] = {i, j}; - Nd4jLong posY[] = {k, j}; - Nd4jLong posX[] = {i, k}; - Nd4jLong posD[] = {i, i}; - - auto xIndex = shape::getOffset(inputShape, posX); - auto yIndex = shape::getOffset(invertedShape, posY); - auto dIndex = shape::getOffset(inputShape, posD); - auto zIndex = shape::getOffset(invertedShape, posZ); - // invert non-diagonal elements - math::atomics::nd4j_atomicAdd(&inverted[zIndex], -inverted[yIndex] * input[xIndex] / input[dIndex]); - } - } - } - -// ------------------------------------------------------------------------------------------------------------------ // -// Invertion of upper triangular matrix non-diagonal elements when main and second diagonals already processed - template - static __global__ void - invertUpKernel( - void *invertedBuf, const Nd4jLong *invertedShape, - const void *inputBuf, const Nd4jLong *inputShape, - Nd4jLong n) { - - auto inverted = reinterpret_cast(invertedBuf);; - auto input = reinterpret_cast(inputBuf); - - auto tid = blockIdx.x * blockDim.x + threadIdx.x; - auto step = blockDim.x * gridDim.x; - - for (int i = (int)n - tid - 2; i >= 0; i -= step) { - for (int j = i + 2; j < (int)n; j++) - for (int k = i; k < (int)n; k++) { - Nd4jLong posZ[] = {i, j}; - Nd4jLong posY[] = {k, j}; - Nd4jLong posX[] = {i, k}; - // inversion with Joardan Gauss transformation - auto xIndex = shape::getOffset(inputShape, posX); - auto yIndex = shape::getOffset(invertedShape, posY); - auto zIndex = shape::getOffset(invertedShape, posZ); - // invert upper non-diagonal elements - math::atomics::nd4j_atomicAdd(&inverted[zIndex], -inverted[yIndex] * input[xIndex]); - } - } - } - -// ------------------------------------------------------------------------------------------------------------------ // -// procedure to invert lower-triangular matrix. -// In current case lower triangular matrix has main diagonal with general values -// - template - static void invertLowerMatrix_(LaunchContext *context, NDArray *inputMatrix, NDArray *invertedMatrix) { - int n = inputMatrix->rows(); - invertedMatrix->setIdentity(); - - if (inputMatrix->isIdentityMatrix()) return; - - auto stream = context->getCudaStream(); +template +static __global__ void upvertKernelUp(void *invertedBuf, + const Nd4jLong *invertedShape, + const void *inputBuf, + const Nd4jLong *inputShape, Nd4jLong n) { + __shared__ T *inverted; + __shared__ const T *input; + if (threadIdx.x == 0) { + inverted = reinterpret_cast(invertedBuf); + input = reinterpret_cast(inputBuf); + } + __syncthreads(); + + auto start = threadIdx.x + blockIdx.x * blockDim.x; + auto step = blockDim.x * gridDim.x; + + for (int i = start; i < n - 1; i += step) { + Nd4jLong pos[] = {i, i + 1}; + Nd4jLong posX[] = {i + 1, i + 1}; + auto xIndex = shape::getOffset(inputShape, pos); + auto iIndex = shape::getOffset(invertedShape, posX); + auto zIndex = shape::getOffset(invertedShape, pos); + // invert upper matrix + math::atomics::nd4j_atomicAdd( + &inverted[zIndex], + -input[xIndex] * inverted[iIndex]); // / input[yIndex]); + // inputMatrix->t(i, i + 1) * invertedMatrix->t(i + 1, i + 1) / + // inputMatrix->t(i, i) + } +} - // invert lower matrix - // invert main diagonal - upvertKernel<<<1, n, 512, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n); - // invert the second diagonal - invertKernelLow<<<1, n, 512, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n); +// ------------------------------------------------------------------------------------------------------------------ +// // +template +static __global__ void invertLowKernel(void *invertedBuf, + const Nd4jLong *invertedShape, + const void *inputBuf, + const Nd4jLong *inputShape, Nd4jLong n) { + auto input = reinterpret_cast(inputBuf); + auto inverted = reinterpret_cast(invertedBuf); + + auto tid = blockIdx.x * blockDim.x + threadIdx.x; + auto step = gridDim.x * blockDim.x; + + for (int i = tid + 2; i < n; i += step) { + for (int j = i - 2; j >= 0; --j) + for (int k = 0; k < i; k++) { + Nd4jLong posZ[] = {i, j}; + Nd4jLong posY[] = {k, j}; + Nd4jLong posX[] = {i, k}; + Nd4jLong posD[] = {i, i}; + + auto xIndex = shape::getOffset(inputShape, posX); + auto yIndex = shape::getOffset(invertedShape, posY); + auto dIndex = shape::getOffset(inputShape, posD); + auto zIndex = shape::getOffset(invertedShape, posZ); // invert non-diagonal elements - invertLowKernel<<>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n); - } + math::atomics::nd4j_atomicAdd( + &inverted[zIndex], + -inverted[yIndex] * input[xIndex] / input[dIndex]); + } + } +} -// ------------------------------------------------------------------------------------------------------------------ // -// caller for invert lower matrix routine - void invertLowerMatrix(LaunchContext *context, NDArray *inputMatrix, NDArray *invertedMatrix) { - NDArray::prepareSpecialUse({invertedMatrix}, {inputMatrix}); - BUILD_SINGLE_SELECTOR(inputMatrix->dataType(), invertLowerMatrix_, (context, inputMatrix, invertedMatrix), FLOAT_NATIVE); - NDArray::registerSpecialUse({invertedMatrix}, {inputMatrix}); - } +// ------------------------------------------------------------------------------------------------------------------ +// // Invertion of upper triangular matrix non-diagonal elements when main and +// second diagonals already processed +template +static __global__ void invertUpKernel(void *invertedBuf, + const Nd4jLong *invertedShape, + const void *inputBuf, + const Nd4jLong *inputShape, Nd4jLong n) { + auto inverted = reinterpret_cast(invertedBuf); + ; + auto input = reinterpret_cast(inputBuf); + + auto tid = blockIdx.x * blockDim.x + threadIdx.x; + auto step = blockDim.x * gridDim.x; + + for (int i = (int)n - tid - 2; i >= 0; i -= step) { + for (int j = i + 2; j < (int)n; j++) + for (int k = i; k < (int)n; k++) { + Nd4jLong posZ[] = {i, j}; + Nd4jLong posY[] = {k, j}; + Nd4jLong posX[] = {i, k}; + // inversion with Joardan Gauss transformation + auto xIndex = shape::getOffset(inputShape, posX); + auto yIndex = shape::getOffset(invertedShape, posY); + auto zIndex = shape::getOffset(invertedShape, posZ); + // invert upper non-diagonal elements + math::atomics::nd4j_atomicAdd(&inverted[zIndex], + -inverted[yIndex] * input[xIndex]); + } + } +} -// ------------------------------------------------------------------------------------------------------------------ // -// procedure to invert upper-triangular matrix. -// In current case upper triangular matrix has main diagonal with all ones on it. - template - static void invertUpperMatrix_(LaunchContext *context, NDArray* inputMatrix, NDArray* invertedMatrix) { - int n = inputMatrix->rows(); - invertedMatrix->setIdentity(); - auto stream = context->getCudaStream(); - if (inputMatrix->isIdentityMatrix()) { // the inverse for I is I - return; - } +// ------------------------------------------------------------------------------------------------------------------ +// // procedure to invert lower-triangular matrix. In current case lower +// triangular matrix has main diagonal with general values +// +template +static void invertLowerMatrix_(LaunchContext *context, NDArray *inputMatrix, + NDArray *invertedMatrix) { + int n = inputMatrix->rows(); + invertedMatrix->setIdentity(); + + if (inputMatrix->isIdentityMatrix()) return; + + auto stream = context->getCudaStream(); + + // invert lower matrix + // invert main diagonal + upvertKernel<<<1, n, 512, *stream>>>( + invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), + inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n); + // invert the second diagonal + invertKernelLow<<<1, n, 512, *stream>>>( + invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), + inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n); + // invert non-diagonal elements + invertLowKernel<<>>( + invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), + inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n); +} - // invert upper matrix - // invert the second diagonal - upvertKernelUp<<<1, n, 512, *stream >>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), - inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n); +// ------------------------------------------------------------------------------------------------------------------ +// // caller for invert lower matrix routine +void invertLowerMatrix(LaunchContext *context, NDArray *inputMatrix, + NDArray *invertedMatrix) { + NDArray::prepareSpecialUse({invertedMatrix}, {inputMatrix}); + BUILD_SINGLE_SELECTOR(inputMatrix->dataType(), invertLowerMatrix_, + (context, inputMatrix, invertedMatrix), FLOAT_NATIVE); + NDArray::registerSpecialUse({invertedMatrix}, {inputMatrix}); +} - // invert other elements - invertUpKernel<<>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(),inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n); - } +// ------------------------------------------------------------------------------------------------------------------ +// // procedure to invert upper-triangular matrix. In current case upper +// triangular matrix has main diagonal with all ones on it. +template +static void invertUpperMatrix_(LaunchContext *context, NDArray *inputMatrix, + NDArray *invertedMatrix) { + int n = inputMatrix->rows(); + invertedMatrix->setIdentity(); + auto stream = context->getCudaStream(); + if (inputMatrix->isIdentityMatrix()) { // the inverse for I is I + return; + } + + // invert upper matrix + // invert the second diagonal + upvertKernelUp<<<1, n, 512, *stream>>>( + invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), + inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n); + + // invert other elements + invertUpKernel<<>>( + invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), + inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n); +} -// ------------------------------------------------------------------------------------------------------------------ // +// ------------------------------------------------------------------------------------------------------------------ +// // // invertion of upper triangular matrix - runner routine - void invertUpperMatrix(LaunchContext *context, NDArray *inputMatrix, NDArray *invertedMatrix) { - NDArray::prepareSpecialUse({invertedMatrix}, {inputMatrix}); - BUILD_SINGLE_SELECTOR(invertedMatrix->dataType(), invertUpperMatrix_, (context, inputMatrix, invertedMatrix), FLOAT_NATIVE); - NDArray::prepareSpecialUse({invertedMatrix}, {inputMatrix}); - } - -// ------------------------------------------------------------------------------------------------------------------ // - // determinant kernel - accumulation product of all values on the main diagonal - template - static __global__ void determinantKernel(T *compound, T *result, Nd4jLong len) { - auto start = blockIdx.x * blockDim.x + threadIdx.x; - auto step = blockDim.x * gridDim.x; - for (auto i = start; i < len; i += step) { - auto pos = i * len + i; //shape::getOffset(0, shape::shapeOf(shape), shape::stride(shape), di, 2); - // multiply all diagonal elements - math::atomics::nd4j_atomicMul(&result[0], compound[pos]); - } - } - -// ------------------------------------------------------------------------------------------------------------------ // - // determinant logarithm - accumulation sum of all logarithm values on the main diagonal. All in logarithic values - // should be positive - template - static __global__ void determinantLogKernel(T *compound, T *result, Nd4jLong len) { - auto start = blockIdx.x * blockDim.x + threadIdx.x; - auto step = blockDim.x * gridDim.x; - for (auto i = start; i < len; i += step) { - auto pos = i * len + i; //shape::getOffset(0, shape::shapeOf(shape), shape::stride(shape), di, 2); - // sum logs of all diagonal elements - math::atomics::nd4j_atomicAdd(result, math::nd4j_log(math::nd4j_abs(compound[pos]))); - } - } - -// ------------------------------------------------------------------------------------------------------------------ // - // kernel to copy matrix with given shape to compound tensor with given pos - // output - a N-D tensor buffer with rank not less than 2, input - 2D square n x n matrix with n = rowLen - template - static __global__ void - fillMatrix(void *output, const Nd4jLong *outShape, const void *input, const Nd4jLong *inputShape, Nd4jLong pos, Nd4jLong rowLen) { - __shared__ F *matrix; - __shared__ const T *inputBuf; - __shared__ Nd4jLong inputLen; - __shared__ Nd4jLong n2; - - if (threadIdx.x == 0) { - matrix = reinterpret_cast(output); - inputBuf = reinterpret_cast(input); - inputLen = shape::length(inputShape); - n2 = rowLen * rowLen; - } - __syncthreads(); - - auto start = blockIdx.x * blockDim.x + threadIdx.x; - auto step = blockDim.x * gridDim.x; - - for (int k = pos + start, j = start; j < n2; k += step, j += step) { - auto xIndex = shape::getIndexOffset(k, inputShape); - matrix[j] = (F) inputBuf[xIndex]; - } - } +void invertUpperMatrix(LaunchContext *context, NDArray *inputMatrix, + NDArray *invertedMatrix) { + NDArray::prepareSpecialUse({invertedMatrix}, {inputMatrix}); + BUILD_SINGLE_SELECTOR(invertedMatrix->dataType(), invertUpperMatrix_, + (context, inputMatrix, invertedMatrix), FLOAT_NATIVE); + NDArray::prepareSpecialUse({invertedMatrix}, {inputMatrix}); +} -// ------------------------------------------------------------------------------------------------------------------ // -// same as above, but without type conversion - template - static __global__ void - returnMatrix(void *output, const Nd4jLong *outputShape, const void *input, const Nd4jLong *inputShape, Nd4jLong pos, Nd4jLong rowLen) { - __shared__ Nd4jLong outputLen; - __shared__ Nd4jLong n2; - auto matrix = reinterpret_cast(input); - auto outputBuf = reinterpret_cast(output); +// ------------------------------------------------------------------------------------------------------------------ +// // determinant kernel - accumulation product of all values on the main +// diagonal +template +static __global__ void determinantKernel(T *compound, T *result, Nd4jLong len) { + auto start = blockIdx.x * blockDim.x + threadIdx.x; + auto step = blockDim.x * gridDim.x; + for (auto i = start; i < len; i += step) { + auto pos = i * len + i; // shape::getOffset(0, shape::shapeOf(shape), + // shape::stride(shape), di, 2); + // multiply all diagonal elements + math::atomics::nd4j_atomicMul(&result[0], compound[pos]); + } +} - if (threadIdx.x == 0) { +// ------------------------------------------------------------------------------------------------------------------ +// // determinant logarithm - accumulation sum of all logarithm values on the +// main diagonal. All in logarithic values should be positive +template +static __global__ void determinantLogKernel(T *compound, T *result, + Nd4jLong len) { + auto start = blockIdx.x * blockDim.x + threadIdx.x; + auto step = blockDim.x * gridDim.x; + for (auto i = start; i < len; i += step) { + auto pos = i * len + i; // shape::getOffset(0, shape::shapeOf(shape), + // shape::stride(shape), di, 2); + // sum logs of all diagonal elements + math::atomics::nd4j_atomicAdd( + result, math::nd4j_log(math::nd4j_abs(compound[pos]))); + } +} - outputLen = shape::length(inputShape); - n2 = rowLen * rowLen; - } - __syncthreads(); - auto start = blockIdx.x * blockDim.x + threadIdx.x; - auto step = blockDim.x * gridDim.x; +// ------------------------------------------------------------------------------------------------------------------ +// // kernel to copy matrix with given shape to compound tensor with given pos +// output - a N-D tensor buffer with rank not less than 2, input - 2D square n x +// n matrix with n = rowLen +template +static __global__ void fillMatrix(void *output, const Nd4jLong *outShape, + const void *input, const Nd4jLong *inputShape, + Nd4jLong pos, Nd4jLong rowLen) { + __shared__ F *matrix; + __shared__ const T *inputBuf; + __shared__ Nd4jLong inputLen; + __shared__ Nd4jLong n2; + + if (threadIdx.x == 0) { + matrix = reinterpret_cast(output); + inputBuf = reinterpret_cast(input); + inputLen = shape::length(inputShape); + n2 = rowLen * rowLen; + } + __syncthreads(); + + auto start = blockIdx.x * blockDim.x + threadIdx.x; + auto step = blockDim.x * gridDim.x; + + for (int k = pos + start, j = start; j < n2; k += step, j += step) { + auto xIndex = shape::getIndexOffset(k, inputShape); + matrix[j] = (F)inputBuf[xIndex]; + } +} - for (int k = pos + start, j = start; j < n2; k += step, j += step) { - auto zIndex = shape::getIndexOffset(k, outputShape); - outputBuf[zIndex] = matrix[j]; - } - } +// ------------------------------------------------------------------------------------------------------------------ +// // same as above, but without type conversion +template +static __global__ void returnMatrix(void *output, const Nd4jLong *outputShape, + const void *input, + const Nd4jLong *inputShape, Nd4jLong pos, + Nd4jLong rowLen) { + __shared__ Nd4jLong outputLen; + __shared__ Nd4jLong n2; + auto matrix = reinterpret_cast(input); + auto outputBuf = reinterpret_cast(output); + + if (threadIdx.x == 0) { + outputLen = shape::length(inputShape); + n2 = rowLen * rowLen; + } + __syncthreads(); + auto start = blockIdx.x * blockDim.x + threadIdx.x; + auto step = blockDim.x * gridDim.x; + + for (int k = pos + start, j = start; j < n2; k += step, j += step) { + auto zIndex = shape::getIndexOffset(k, outputShape); + outputBuf[zIndex] = matrix[j]; + } +} -// ------------------------------------------------------------------------------------------------------------------ // - // fill up permutaion matrix kernel. Permutation matrix filled with zeros and ones - template - static __global__ void fillUpPermutation(void *output, const Nd4jLong *shape, int *source, int rowNum) { - F *permutation = reinterpret_cast(output); - - auto start = blockIdx.x * blockDim.x + threadIdx.x; - auto step = blockDim.x * gridDim.x; - for (auto i = start; i < rowNum; i += step) { - int val = source[i] - 1; - Nd4jLong posF[] = {i, val}; - auto pos = shape::getOffset(shape, posF); - permutation[pos] = F(1.f); - } - } +// ------------------------------------------------------------------------------------------------------------------ +// // fill up permutaion matrix kernel. Permutation matrix filled with zeros and +// ones +template +static __global__ void fillUpPermutation(void *output, const Nd4jLong *shape, + int *source, int rowNum) { + F *permutation = reinterpret_cast(output); + + auto start = blockIdx.x * blockDim.x + threadIdx.x; + auto step = blockDim.x * gridDim.x; + for (auto i = start; i < rowNum; i += step) { + int val = source[i] - 1; + Nd4jLong posF[] = {i, val}; + auto pos = shape::getOffset(shape, posF); + permutation[pos] = F(1.f); + } +} -// ------------------------------------------------------------------------------------------------------------------ // - // LUP decomposition runner - using CUBLAS SOLVER - // if permutation is given, then using LUP decomposition, LU decomposition otherwise - // L - lower triangular, U - upper triangular, P - permutation matricies - // PA = LU - // - // input - A matrix nxn - // compound - C matrix L + U - I, or main diagonal and lower - L matrix, from the 2nd diagonal - U matrix - template - static void lup_(LaunchContext *context, NDArray *input, NDArray *compound, NDArray *permutation) { - auto stream = context->getCudaStream(); - auto n = input->rows(); - std::lock_guard lock(*LaunchContext::deviceMutex()); - - cusolverDnHandle_t* cusolverH = (cusolverDnHandle_t*)context->getCusolverHandle(); //nullptr; - // create solver handle - cusolverStatus_t status; //cusolverDnCreate(&cusolverH); -// if (CUSOLVER_STATUS_SUCCESS != status) { -// throw cuda_exception::build("Cannot create cuSolver handle", status); -// } - // set solver stream - status = cusolverDnSetStream(*cusolverH, *stream); - if (CUSOLVER_STATUS_SUCCESS != status) { - throw cuda_exception::build("Cannot set up stream for cuda solver", status); +// ------------------------------------------------------------------------------------------------------------------ +// // LUP decomposition runner - using CUBLAS SOLVER if permutation is given, +// then using LUP decomposition, LU decomposition otherwise L - lower +// triangular, U - upper triangular, P - permutation matricies PA = LU +// +// input - A matrix nxn +// compound - C matrix L + U - I, or main diagonal and lower - L matrix, from +// the 2nd diagonal - U matrix +template +static void lup_(LaunchContext *context, NDArray *input, NDArray *compound, + NDArray *permutation) { + auto stream = context->getCudaStream(); + auto n = input->rows(); + std::lock_guard lock(*LaunchContext::deviceMutex()); + + cusolverDnHandle_t *cusolverH = + (cusolverDnHandle_t *)context->getCusolverHandle(); // nullptr; + // create solver handle + cusolverStatus_t status; // cusolverDnCreate(&cusolverH); + // if (CUSOLVER_STATUS_SUCCESS != status) { + // throw cuda_exception::build("Cannot create cuSolver handle", + // status); + // } + // set solver stream + status = cusolverDnSetStream(*cusolverH, *stream); + if (CUSOLVER_STATUS_SUCCESS != status) { + throw cuda_exception::build("Cannot set up stream for cuda solver", status); + } + int lwork = 0; + int *d_info = nullptr; + // allocate memory for permutation vector + auto err = cudaMalloc((void **)&d_info, sizeof(int)); + if (err) { + throw cuda_exception::build( + "helpers::lup_: Cannot allocate memory for solver info buffer", err); + } + + DataType dtype = input->dataType(); + switch (dtype) { // there are two implementations with cublas for LUP + // decomposition - double and float + + case DataType::DOUBLE: { + double *d_work = nullptr; + // compute internal buffer size + double *matrix = reinterpret_cast(input->specialBuffer()); + status = cusolverDnDgetrf_bufferSize(*cusolverH, n, n, matrix, n, &lwork); + if (CUSOLVER_STATUS_SUCCESS != status) { + throw cuda_exception::build( + "helpers::lup_: Cannot create cuSolver handle", status); + } + + err = cudaMalloc((void **)&d_work, sizeof(float) * lwork); + if (err) { + throw cuda_exception::build( + "helpers::lup_: Cannot allocate memory for solver data buffer", + err); + } + + if (permutation == nullptr) { + status = cusolverDnDgetrf(*cusolverH, n, n, matrix, n, d_work, nullptr, + d_info); + + if (status != CUSOLVER_STATUS_SUCCESS) { + throw cuda_exception::build( + "helpers::lup_: LU factorization is failed due ", status); } - int lwork = 0; - int *d_info = nullptr; - // allocate memory for permutation vector - auto err = cudaMalloc((void **) &d_info, sizeof(int)); - if (err) { - throw cuda_exception::build("helpers::lup_: Cannot allocate memory for solver info buffer", err); + } else { + NDArray permutVector('c', {n}, sd::DataType::INT32, context); + int *permutationBuf = permutVector.dataBuffer()->specialAsT(); + status = cusolverDnDgetrf(*cusolverH, n, n, matrix, n, d_work, + permutationBuf, d_info); + if (status != CUSOLVER_STATUS_SUCCESS) { + throw cuda_exception::build( + "helpers::lup_: LU factorization is failed due ", status); } - DataType dtype = input->dataType(); - switch (dtype) { // there are two implementations with cublas for LUP decomposition - double and float - - case DataType::DOUBLE: { - double *d_work = nullptr; - // compute internal buffer size - double *matrix = reinterpret_cast(input->specialBuffer()); - status = cusolverDnDgetrf_bufferSize( - *cusolverH, - n, - n, - matrix, - n, - &lwork); - if (CUSOLVER_STATUS_SUCCESS != status) { - throw cuda_exception::build("helpers::lup_: Cannot create cuSolver handle", status); - } - - err = cudaMalloc((void **) &d_work, sizeof(float) * lwork); - if (err) { - throw cuda_exception::build("helpers::lup_: Cannot allocate memory for solver data buffer", - err); - } - - if (permutation == nullptr) { - status = cusolverDnDgetrf( - *cusolverH, - n, - n, - matrix, - n, - d_work, - nullptr, - d_info); - - if (status != CUSOLVER_STATUS_SUCCESS) { - throw cuda_exception::build("helpers::lup_: LU factorization is failed due ", - status); - } - } - else { - NDArray permutVector('c', {n}, sd::DataType::INT32, context); - int* permutationBuf = permutVector.dataBuffer()->specialAsT(); - status = cusolverDnDgetrf( - *cusolverH, - n, - n, - matrix, - n, - d_work, - permutationBuf, - d_info); - if (status != CUSOLVER_STATUS_SUCCESS) { - throw cuda_exception::build("helpers::lup_: LU factorization is failed due ", - status); - } - - if (permutation->rankOf() == 2) { - fillUpPermutation <<< n, n, 1024, *stream >>> - (permutation->specialBuffer(), permutation->specialShapeInfo(), permutationBuf, n); - } - else { - permutVector.tickWriteDevice(); - input->tickWriteDevice(); - compound->assign(input); - permutation->assign(permutVector); - } - } - err = cudaFree(d_work); - if (err) { - throw cuda_exception::build("helpers::lup_: Cannot deallocate memory for solver data buffer", - err); - } - } - break; - case DataType::FLOAT32: { - float *matrix = reinterpret_cast(input->specialBuffer()); - float *d_work = nullptr; - - status = cusolverDnSgetrf_bufferSize( - *cusolverH, - n, - n, - matrix, - n, - &lwork); - if (CUSOLVER_STATUS_SUCCESS != status) { - throw cuda_exception::build("helpers::lup_: Cannot create cuSolver handle", status); - } - - err = cudaMalloc((void **) &d_work, sizeof(float) * lwork); - if (err) { - throw cuda_exception::build("helpers::lup_: Cannot allocate memory for solver data buffer", - err); - } - - if (permutation == nullptr) - status = cusolverDnSgetrf( - *cusolverH, - n, - n, - matrix, - n, - d_work, - nullptr, - d_info); - else { - NDArray permutVector('c', {n}, DataType::INT32, context); - int *permutationBuf = reinterpret_cast(permutVector.specialBuffer()); - status = cusolverDnSgetrf( - *cusolverH, - n, - n, - matrix, - n, - d_work, - permutationBuf, - d_info); - if (permutation->rankOf() == 2) { - fillUpPermutation <<< n, n, 128, *stream >>> - (permutation->specialBuffer(), permutation->specialShapeInfo(), permutationBuf, n); - permutation->tickWriteDevice(); - } - else { - input->tickWriteDevice(); - compound->assign(input); - permutation->assign(permutVector); - } - } - err = cudaFree(d_work); - if (err) { - throw cuda_exception::build("helpers::lup_: Cannot deallocate memory for solver data buffer", - err); - } - - } - } - if (CUSOLVER_STATUS_SUCCESS != status) { - throw cuda_exception::build("helpers::lup_: Cannot make LU decomposition", status); + if (permutation->rankOf() == 2) { + fillUpPermutation<<>>( + permutation->specialBuffer(), permutation->specialShapeInfo(), + permutationBuf, n); + } else { + permutVector.tickWriteDevice(); + input->tickWriteDevice(); + compound->assign(input); + permutation->assign(permutVector); } - err = cudaFree(d_info); - if (err) { - throw cuda_exception::build("helpers::lup_: Cannot deallocate memory for solver info buffer", err); + } + err = cudaFree(d_work); + if (err) { + throw cuda_exception::build( + "helpers::lup_: Cannot deallocate memory for solver data buffer", + err); + } + } break; + case DataType::FLOAT32: { + float *matrix = reinterpret_cast(input->specialBuffer()); + float *d_work = nullptr; + + status = cusolverDnSgetrf_bufferSize(*cusolverH, n, n, matrix, n, &lwork); + if (CUSOLVER_STATUS_SUCCESS != status) { + throw cuda_exception::build( + "helpers::lup_: Cannot create cuSolver handle", status); + } + + err = cudaMalloc((void **)&d_work, sizeof(float) * lwork); + if (err) { + throw cuda_exception::build( + "helpers::lup_: Cannot allocate memory for solver data buffer", + err); + } + + if (permutation == nullptr) + status = cusolverDnSgetrf(*cusolverH, n, n, matrix, n, d_work, nullptr, + d_info); + else { + NDArray permutVector('c', {n}, DataType::INT32, context); + int *permutationBuf = + reinterpret_cast(permutVector.specialBuffer()); + status = cusolverDnSgetrf(*cusolverH, n, n, matrix, n, d_work, + permutationBuf, d_info); + if (permutation->rankOf() == 2) { + fillUpPermutation<<>>( + permutation->specialBuffer(), permutation->specialShapeInfo(), + permutationBuf, n); + permutation->tickWriteDevice(); + } else { + input->tickWriteDevice(); + compound->assign(input); + permutation->assign(permutVector); } -// cusolverDnDestroy(cusolverH); -// NDArray::registerSpecialUse({input}, {input}); - input->tickWriteDevice(); + } + err = cudaFree(d_work); + if (err) { + throw cuda_exception::build( + "helpers::lup_: Cannot deallocate memory for solver data buffer", + err); + } } -// ------------------------------------------------------------------------------------------------------------------ // - - BUILD_DOUBLE_TEMPLATE(template void lup_,(LaunchContext * context, NDArray * input, NDArray * output, NDArray * permutation), FLOAT_NATIVE, INDEXING_TYPES); - - template - static __device__ void swapRows(T* matrix, const Nd4jLong* shape, Nd4jLong theFirst, Nd4jLong theSecond, Nd4jLong n) { - if (theFirst != theSecond) { - for (auto i = 0; i < n; i++) { - Nd4jLong theFirstPos[] = {theFirst, i}; - Nd4jLong theSecondPos[] = {theSecond, i}; - auto theFirstIndex = shape::getOffset(shape, theFirstPos, 0); - auto theSecondIndex = shape::getOffset(shape, theSecondPos, 0); - math::nd4j_swap(matrix[theFirstIndex], matrix[theSecondIndex]); - } - } + } + if (CUSOLVER_STATUS_SUCCESS != status) { + throw cuda_exception::build("helpers::lup_: Cannot make LU decomposition", + status); + } + err = cudaFree(d_info); + if (err) { + throw cuda_exception::build( + "helpers::lup_: Cannot deallocate memory for solver info buffer", err); + } + // cusolverDnDestroy(cusolverH); + // NDArray::registerSpecialUse({input}, {input}); + input->tickWriteDevice(); +} +// ------------------------------------------------------------------------------------------------------------------ +// // + +BUILD_DOUBLE_TEMPLATE(template void lup_, + (LaunchContext * context, NDArray *input, NDArray *output, + NDArray *permutation), + FLOAT_NATIVE, INDEXING_TYPES); + +template +static __device__ void swapRows(T *matrix, const Nd4jLong *shape, + Nd4jLong theFirst, Nd4jLong theSecond, + Nd4jLong n) { + if (theFirst != theSecond) { + for (auto i = 0; i < n; i++) { + Nd4jLong theFirstPos[] = {theFirst, i}; + Nd4jLong theSecondPos[] = {theSecond, i}; + auto theFirstIndex = shape::getOffset(shape, theFirstPos, 0); + auto theSecondIndex = shape::getOffset(shape, theSecondPos, 0); + math::nd4j_swap(matrix[theFirstIndex], matrix[theSecondIndex]); } + } +} - template - static __device__ void processColumns(Nd4jLong currentRow, Nd4jLong rowNum, T* compoundBuf, const Nd4jLong* compoundShape) { - Nd4jLong xDiag[] = {currentRow, currentRow}; - auto diagIndex = shape::getOffset(compoundShape, xDiag, 0); - for (auto j = currentRow + 1; j < rowNum; j++) { - Nd4jLong xRow[] = {j, currentRow}; - auto rowIndex = shape::getOffset(compoundShape, xRow, 0); - compoundBuf[rowIndex] /= compoundBuf[diagIndex]; //output->t(i, i); - for (auto k = currentRow + 1; k < rowNum; k++) { - Nd4jLong yRow[] = {j, k}; - Nd4jLong yCol[] = {currentRow, k}; - auto rowIndexY = shape::getOffset(compoundShape, yRow, 0); - auto colIndex = shape::getOffset(compoundShape, yCol, 0); - compoundBuf[rowIndexY] -= compoundBuf[rowIndex] * compoundBuf[colIndex]; - } - } +template +static __device__ void processColumns(Nd4jLong currentRow, Nd4jLong rowNum, + T *compoundBuf, + const Nd4jLong *compoundShape) { + Nd4jLong xDiag[] = {currentRow, currentRow}; + auto diagIndex = shape::getOffset(compoundShape, xDiag, 0); + for (auto j = currentRow + 1; j < rowNum; j++) { + Nd4jLong xRow[] = {j, currentRow}; + auto rowIndex = shape::getOffset(compoundShape, xRow, 0); + compoundBuf[rowIndex] /= compoundBuf[diagIndex]; // output->t(i, i); + for (auto k = currentRow + 1; k < rowNum; k++) { + Nd4jLong yRow[] = {j, k}; + Nd4jLong yCol[] = {currentRow, k}; + auto rowIndexY = shape::getOffset(compoundShape, yRow, 0); + auto colIndex = shape::getOffset(compoundShape, yCol, 0); + compoundBuf[rowIndexY] -= compoundBuf[rowIndex] * compoundBuf[colIndex]; } + } +} - template - __device__ Nd4jLong argmaxCol(Nd4jLong column, T* compoundBuffer, const Nd4jLong* compoundShape) { - auto rowNum = shape::sizeAt(compoundShape, 0); - Nd4jLong xInitial[] = {column, column}; - auto xInitialIndex = shape::getOffset(compoundShape, xInitial, 0); - auto maxValue = T(0); //sd::math::nd4j_abs(compoundBuffer[xInitialIndex]); - auto result = -1LL; - - for (auto rowCounter = column; rowCounter < rowNum; rowCounter++) { - Nd4jLong xPos[] = {rowCounter, column}; - auto xIndex = shape::getOffset(compoundShape, xPos, 0); - if (sd::math::nd4j_abs(compoundBuffer[xIndex]) > maxValue) { - maxValue = sd::math::nd4j_max(maxValue, sd::math::nd4j_abs(compoundBuffer[xIndex])); - result = rowCounter; - } - } - return result; +template +__device__ Nd4jLong argmaxCol(Nd4jLong column, T *compoundBuffer, + const Nd4jLong *compoundShape) { + auto rowNum = shape::sizeAt(compoundShape, 0); + Nd4jLong xInitial[] = {column, column}; + auto xInitialIndex = shape::getOffset(compoundShape, xInitial, 0); + auto maxValue = T(0); // sd::math::nd4j_abs(compoundBuffer[xInitialIndex]); + auto result = -1LL; + + for (auto rowCounter = column; rowCounter < rowNum; rowCounter++) { + Nd4jLong xPos[] = {rowCounter, column}; + auto xIndex = shape::getOffset(compoundShape, xPos, 0); + if (sd::math::nd4j_abs(compoundBuffer[xIndex]) > maxValue) { + maxValue = sd::math::nd4j_max(maxValue, + sd::math::nd4j_abs(compoundBuffer[xIndex])); + result = rowCounter; } + } + return result; +} - template - static __device__ int luNN(T* matrix, const Nd4jLong* shape, I* permutation, const Nd4jLong* permuShape, Nd4jLong n) { - - for (auto i = 0; i < n - 1; i++) { - auto pivotIndex = argmaxCol(i, matrix, shape); - if (pivotIndex < 0) { - return -1;//throw std::runtime_error("helpers::luNN_: input matrix is singular."); - } - math::nd4j_swap(permutation[shape::getIndexOffset(i, permuShape)], permutation[shape::getIndexOffset(pivotIndex, permuShape)]); - swapRows(matrix, shape, (Nd4jLong)i, pivotIndex, n); - - processColumns(i, n, matrix, shape); - } - return 0; +template +static __device__ int luNN(T *matrix, const Nd4jLong *shape, I *permutation, + const Nd4jLong *permuShape, Nd4jLong n) { + for (auto i = 0; i < n - 1; i++) { + auto pivotIndex = argmaxCol(i, matrix, shape); + if (pivotIndex < 0) { + return -1; // throw std::runtime_error("helpers::luNN_: input matrix is + // singular."); } + math::nd4j_swap(permutation[shape::getIndexOffset(i, permuShape)], + permutation[shape::getIndexOffset(pivotIndex, permuShape)]); + swapRows(matrix, shape, (Nd4jLong)i, pivotIndex, n); - template - static __global__ void luBatchedKernel( - T* outputBuf, const Nd4jLong* outputShape, - I* permutations, const Nd4jLong* permuShape, - const Nd4jLong* outputTadShape, const Nd4jLong* outputTadOffsets, - const Nd4jLong* permuTadShape, const Nd4jLong* permuTadOffsets, - Nd4jLong batchNum) { - - auto start = blockIdx.x * blockDim.x + threadIdx.x; - auto step = blockDim.x * gridDim.x; - - for (auto b = start; b < batchNum; b += step) { - T* matrix = outputBuf + outputTadOffsets[b]; - I* permutation = permutations + permuTadOffsets[b]; - - if (0 != luNN(matrix, outputTadShape, permutation, permuTadShape, shape::length(permuTadShape))) break; - } - } + processColumns(i, n, matrix, shape); + } + return 0; +} - template - static void lu_(LaunchContext * context, NDArray* input, NDArray* output, NDArray* permutationVectors) { - auto n = input->sizeAt(-1); - auto stream = context->getCudaStream(); - NDArray iota('c', {n}, permutationVectors->dataType(), context);// = NDArrayFactory::create(); // ('c', {n}); - iota.linspace(0); iota.syncToDevice(); - - output->assign(input); // fill up output tensor with zeros -// output->tickWriteDevice(); - permutationVectors->applyTrueBroadcast(sd::BroadcastOpsTuple::Assign(), iota, *permutationVectors, true, nullptr); -// permutationVectors->tickWriteDevice(); - auto tads = ConstantTadHelper::getInstance().tadForDimensions(output->shapeInfo(), {-2, -1}); - auto permutaionTads = ConstantTadHelper::getInstance().tadForDimensions(output->shapeInfo(), {-1}); - auto batchNum = tads.numberOfTads(); - luBatchedKernel<<>>(reinterpret_cast(output->platformBuffer()), - output->specialShapeInfo(), reinterpret_cast(permutationVectors->platformBuffer()), - permutationVectors->specialShapeInfo(), tads.specialShapeInfo(), tads.specialOffsets(), - permutaionTads.specialShapeInfo(), permutaionTads.specialOffsets(), batchNum); - } +template +static __global__ void luBatchedKernel( + T *outputBuf, const Nd4jLong *outputShape, I *permutations, + const Nd4jLong *permuShape, const Nd4jLong *outputTadShape, + const Nd4jLong *outputTadOffsets, const Nd4jLong *permuTadShape, + const Nd4jLong *permuTadOffsets, Nd4jLong batchNum) { + auto start = blockIdx.x * blockDim.x + threadIdx.x; + auto step = blockDim.x * gridDim.x; + + for (auto b = start; b < batchNum; b += step) { + T *matrix = outputBuf + outputTadOffsets[b]; + I *permutation = permutations + permuTadOffsets[b]; + + if (0 != luNN(matrix, outputTadShape, permutation, permuTadShape, + shape::length(permuTadShape))) + break; + } +} - void lu(LaunchContext* context, NDArray* input, NDArray* output, NDArray* permutations) { - NDArray::prepareSpecialUse({output, permutations}, {input}); - BUILD_DOUBLE_SELECTOR(input->dataType(), permutations->dataType(), lu_, (context, input, output, permutations), FLOAT_NATIVE, INDEXING_TYPES); - NDArray::registerSpecialUse({output, permutations}, {input}); - } -// ------------------------------------------------------------------------------------------------------------------ // - template - static int determinant_(sd::LaunchContext *context, NDArray *input, NDArray *output) { - Nd4jLong n = input->sizeAt(-1); - Nd4jLong n2 = n * n; - std::vector dims(); - auto packX = ConstantTadHelper::getInstance().tadForDimensions(input->shapeInfo(), {input->rankOf() - 2, input->rankOf() - 1}); - //auto packZ = ConstantTadHelper::getInstance().tadForDimensions(output->shapeInfo(), {output->rankOf() - 1}); -// DataType dtype = input->dataType(); -// if (dtype != DataType::DOUBLE) -// dtype = DataType::FLOAT32; - auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, DataTypeUtils::fromT(), context); //, block.getWorkspace()); - auto det = NDArrayFactory::create(1, context); - auto stream = context->getCudaStream(); - NDArray::prepareSpecialUse({output}, {input}); - dim3 launchDims(256, 256, 1024); - output->assign(1.f); - for (int e = 0; e < output->lengthOf(); e++) { - Nd4jLong pos = e * n2; -// if (matrix.dataType() == input->dataType()) - fillMatrix<<>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n); -// else -// fillMatrix<<>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->special(), pos, n); - lup_(context, &matrix, nullptr, nullptr); -// else -// lup_(context, &matrix, nullptr, nullptr); - auto offset = shape::getIndexOffset(e, output->shapeInfo()); - auto inputBuf = reinterpret_cast(matrix.specialBuffer()); - auto outputBuf = reinterpret_cast(output->specialBuffer()) + offset; -// if (matrix.dataType() == input->dataType()) - determinantKernel<<< launchDims.x, launchDims.y, launchDims.z, *stream>>>(inputBuf, outputBuf, n); -// else -// determinantKernel<<>> (inputBuf, outputBuf, n); - } - NDArray::registerSpecialUse({output}, {input}); +template +static void lu_(LaunchContext *context, NDArray *input, NDArray *output, + NDArray *permutationVectors) { + auto n = input->sizeAt(-1); + auto stream = context->getCudaStream(); + NDArray iota('c', {n}, permutationVectors->dataType(), + context); // = NDArrayFactory::create(); // ('c', {n}); + iota.linspace(0); + iota.syncToDevice(); + + output->assign(input); // fill up output tensor with zeros + // output->tickWriteDevice(); + permutationVectors->applyTrueBroadcast(sd::BroadcastOpsTuple::Assign(), iota, + *permutationVectors, true, nullptr); + // permutationVectors->tickWriteDevice(); + auto tads = ConstantTadHelper::getInstance().tadForDimensions( + output->shapeInfo(), {-2, -1}); + auto permutaionTads = ConstantTadHelper::getInstance().tadForDimensions( + output->shapeInfo(), {-1}); + auto batchNum = tads.numberOfTads(); + luBatchedKernel<<>>( + reinterpret_cast(output->platformBuffer()), + output->specialShapeInfo(), + reinterpret_cast(permutationVectors->platformBuffer()), + permutationVectors->specialShapeInfo(), tads.specialShapeInfo(), + tads.specialOffsets(), permutaionTads.specialShapeInfo(), + permutaionTads.specialOffsets(), batchNum); +} - return Status::OK(); - } +void lu(LaunchContext *context, NDArray *input, NDArray *output, + NDArray *permutations) { + NDArray::prepareSpecialUse({output, permutations}, {input}); + BUILD_DOUBLE_SELECTOR(input->dataType(), permutations->dataType(), lu_, + (context, input, output, permutations), FLOAT_NATIVE, + INDEXING_TYPES); + NDArray::registerSpecialUse({output, permutations}, {input}); +} +// ------------------------------------------------------------------------------------------------------------------ +// // +template +static int determinant_(sd::LaunchContext *context, NDArray *input, + NDArray *output) { + Nd4jLong n = input->sizeAt(-1); + Nd4jLong n2 = n * n; + std::vector dims(); + auto packX = ConstantTadHelper::getInstance().tadForDimensions( + input->shapeInfo(), {input->rankOf() - 2, input->rankOf() - 1}); + // auto packZ = + // ConstantTadHelper::getInstance().tadForDimensions(output->shapeInfo(), + // {output->rankOf() - 1}); + // DataType dtype = input->dataType(); + // if (dtype != DataType::DOUBLE) + // dtype = DataType::FLOAT32; + auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, + DataTypeUtils::fromT(), + context); //, block.workspace()); + auto det = NDArrayFactory::create(1, context); + auto stream = context->getCudaStream(); + NDArray::prepareSpecialUse({output}, {input}); + dim3 launchDims(256, 256, 1024); + output->assign(1.f); + for (int e = 0; e < output->lengthOf(); e++) { + Nd4jLong pos = e * n2; + // if (matrix.dataType() == input->dataType()) + fillMatrix<<>>( + matrix.specialBuffer(), matrix.specialShapeInfo(), + input->specialBuffer(), input->specialShapeInfo(), pos, n); + // else + // fillMatrix<<>>(matrix.specialBuffer(), + // matrix.specialShapeInfo(), input->specialBuffer(), + // input->special(), pos, n); + lup_(context, &matrix, nullptr, nullptr); + // else + // lup_(context, &matrix, nullptr, nullptr); + auto offset = shape::getIndexOffset(e, output->shapeInfo()); + auto inputBuf = reinterpret_cast(matrix.specialBuffer()); + auto outputBuf = reinterpret_cast(output->specialBuffer()) + offset; + // if (matrix.dataType() == input->dataType()) + determinantKernel<<>>( + inputBuf, outputBuf, n); + // else + // determinantKernel<<>> (inputBuf, outputBuf, n); + } + NDArray::registerSpecialUse({output}, {input}); + + return Status::OK(); +} - int determinant(sd::LaunchContext *context, NDArray *input, NDArray *output) { - NDArray::prepareSpecialUse({output}, {input}); - BUILD_SINGLE_SELECTOR(input->dataType(), return determinant_, (context, input, output), FLOAT_NATIVE); - NDArray::registerSpecialUse({output}, {input}); - } +int determinant(sd::LaunchContext *context, NDArray *input, NDArray *output) { + NDArray::prepareSpecialUse({output}, {input}); + BUILD_SINGLE_SELECTOR(input->dataType(), return determinant_, + (context, input, output), FLOAT_NATIVE); + NDArray::registerSpecialUse({output}, {input}); +} - template - int logAbsDeterminant_(LaunchContext *context, NDArray *input, NDArray *output) { - Nd4jLong n = input->sizeAt(-1); - Nd4jLong n2 = n * n; - std::vector dims(); - auto packX = ConstantTadHelper::getInstance().tadForDimensions(input->shapeInfo(), {input->rankOf() - 2, input->rankOf() - 1}); - //auto packZ = ConstantTadHelper::getInstance().tadForDimensions(output->shapeInfo(), {output->rankOf() - 1}); - DataType dtype = input->dataType(); - if (dtype != DataType::DOUBLE) - dtype = DataType::FLOAT32; - - auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, dtype, context); //, block.getWorkspace()); - auto det = NDArrayFactory::create(1, context); - auto stream = context->getCudaStream(); - NDArray::prepareSpecialUse({output}, {input}); - dim3 launchDims(256, 256, 1024); - output->assign(0.f); - for (int e = 0; e < output->lengthOf(); e++) { - Nd4jLong pos = e * n2; -// if (matrix.dataType() == input->dataType()) - fillMatrix<<>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n); -// else -// fillMatrix<<>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->special(), pos, n); - -// if (matrix.dataType() == input->dataType()) - lup_(context, &matrix, nullptr, nullptr); -// else -// lup_(context, &matrix, nullptr, nullptr); - auto offset = shape::getIndexOffset(e, output->shapeInfo()); - auto inputBuf = reinterpret_cast(matrix.specialBuffer()); - auto outputBuf = reinterpret_cast(output->specialBuffer()) + offset; -// if (matrix.dataType() == input->dataType()) - determinantLogKernel<<>>(inputBuf, outputBuf, n); -// else -// determinantLogKernel<<>> (inputBuf, outputBuf, n); - } - NDArray::registerSpecialUse({output}, {input}); - - return Status::OK(); - - return ND4J_STATUS_OK; - } +template +int logAbsDeterminant_(LaunchContext *context, NDArray *input, + NDArray *output) { + Nd4jLong n = input->sizeAt(-1); + Nd4jLong n2 = n * n; + std::vector dims(); + auto packX = ConstantTadHelper::getInstance().tadForDimensions( + input->shapeInfo(), {input->rankOf() - 2, input->rankOf() - 1}); + // auto packZ = + // ConstantTadHelper::getInstance().tadForDimensions(output->shapeInfo(), + // {output->rankOf() - 1}); + DataType dtype = input->dataType(); + if (dtype != DataType::DOUBLE) dtype = DataType::FLOAT32; + + auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, dtype, + context); //, block.workspace()); + auto det = NDArrayFactory::create(1, context); + auto stream = context->getCudaStream(); + NDArray::prepareSpecialUse({output}, {input}); + dim3 launchDims(256, 256, 1024); + output->assign(0.f); + for (int e = 0; e < output->lengthOf(); e++) { + Nd4jLong pos = e * n2; + // if (matrix.dataType() == input->dataType()) + fillMatrix<<>>( + matrix.specialBuffer(), matrix.specialShapeInfo(), + input->specialBuffer(), input->specialShapeInfo(), pos, n); + // else + // fillMatrix<<>>(matrix.specialBuffer(), + // matrix.specialShapeInfo(), input->specialBuffer(), + // input->special(), pos, n); + + // if (matrix.dataType() == input->dataType()) + lup_(context, &matrix, nullptr, nullptr); + // else + // lup_(context, &matrix, nullptr, nullptr); + auto offset = shape::getIndexOffset(e, output->shapeInfo()); + auto inputBuf = reinterpret_cast(matrix.specialBuffer()); + auto outputBuf = reinterpret_cast(output->specialBuffer()) + offset; + // if (matrix.dataType() == input->dataType()) + determinantLogKernel + <<>>(inputBuf, + outputBuf, n); + // else + // determinantLogKernel<<>> (inputBuf, + // outputBuf, n); + } + NDArray::registerSpecialUse({output}, {input}); + + return Status::OK(); + + return ND4J_STATUS_OK; +} - int logAbsDeterminant(sd::LaunchContext *context, NDArray *input, NDArray *output) { - NDArray::prepareSpecialUse({output}, {input}); - BUILD_SINGLE_SELECTOR(input->dataType(), return logAbsDeterminant_, (context, input, output), FLOAT_NATIVE); - NDArray::registerSpecialUse({output}, {input}); - } +int logAbsDeterminant(sd::LaunchContext *context, NDArray *input, + NDArray *output) { + NDArray::prepareSpecialUse({output}, {input}); + BUILD_SINGLE_SELECTOR(input->dataType(), return logAbsDeterminant_, + (context, input, output), FLOAT_NATIVE); + NDArray::registerSpecialUse({output}, {input}); +} - template - static __global__ void - fillLowerUpperKernel( - void *lowerBuf, const Nd4jLong *lowerShape, - void *upperBuf, const Nd4jLong *upperShape, - void *matrixBuf, const Nd4jLong *matrixShape, - Nd4jLong n) { - - __shared__ T *lowerMatrix; - __shared__ T *upperMatrix; - __shared__ T *matrix; - - if (threadIdx.x == 0) { - lowerMatrix = reinterpret_cast(lowerBuf); - upperMatrix = reinterpret_cast(upperBuf); - matrix = reinterpret_cast(matrixBuf); - } - __syncthreads(); - - for (int k = blockIdx.x; k < n; k += gridDim.x) { // and then put all values under main diagonal on to it - for (int j = threadIdx.x; j < n; j += blockDim.x) { - Nd4jLong posX[] = {k, j}; - Nd4jLong posD[] = {j, j}; - auto xPos = shape::getOffset(lowerShape, posX); - auto yPos = shape::getOffset(upperShape, posX); - auto iPos = shape::getOffset(matrixShape, posX); - auto dPos = shape::getOffset(matrixShape, posD); - if (k >= j) - lowerMatrix[xPos] = matrix[iPos];//(k, j); - else - upperMatrix[yPos] = matrix[iPos]; //k, j); - } - } - } +template +static __global__ void fillLowerUpperKernel( + void *lowerBuf, const Nd4jLong *lowerShape, void *upperBuf, + const Nd4jLong *upperShape, void *matrixBuf, const Nd4jLong *matrixShape, + Nd4jLong n) { + __shared__ T *lowerMatrix; + __shared__ T *upperMatrix; + __shared__ T *matrix; + + if (threadIdx.x == 0) { + lowerMatrix = reinterpret_cast(lowerBuf); + upperMatrix = reinterpret_cast(upperBuf); + matrix = reinterpret_cast(matrixBuf); + } + __syncthreads(); + + for (int k = blockIdx.x; k < n; + k += + gridDim.x) { // and then put all values under main diagonal on to it + for (int j = threadIdx.x; j < n; j += blockDim.x) { + Nd4jLong posX[] = {k, j}; + Nd4jLong posD[] = {j, j}; + auto xPos = shape::getOffset(lowerShape, posX); + auto yPos = shape::getOffset(upperShape, posX); + auto iPos = shape::getOffset(matrixShape, posX); + auto dPos = shape::getOffset(matrixShape, posD); + if (k >= j) + lowerMatrix[xPos] = matrix[iPos]; //(k, j); + else + upperMatrix[yPos] = matrix[iPos]; // k, j); + } + } +} - template - static int inverse_(sd::LaunchContext *context, NDArray *input, NDArray *output) { - auto n = input->sizeAt(-1); - auto n2 = n * n; - auto dtype = DataTypeUtils::fromT(); //input->dataType(); -// if (dtype != DataType::DOUBLE) -// dtype = DataType::FLOAT32; - NDArray matrix = NDArrayFactory::create('c', {n, n}, dtype, context); - NDArray upper = NDArrayFactory::create('c', {n, n}, dtype, context); - NDArray lower = NDArrayFactory::create('c', {n, n}, dtype, context); - NDArray compound = NDArrayFactory::create('c', {n, n}, dtype, context); - NDArray permutation = NDArrayFactory::create('c', {n, n}, dtype, context); - auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(input->shapeInfo(), - {input->rankOf() - 2, - input->rankOf() - 1}); - auto packZ = sd::ConstantTadHelper::getInstance().tadForDimensions(output->shapeInfo(), - {output->rankOf() - 2, - output->rankOf() - 1}); - auto stream = context->getCudaStream(); - - for (auto i = 0LL; i < packX.numberOfTads(); i++) { - fillMatrix<<<1, n2, 1024, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), i * n2, n); - matrix.tickWriteDevice(); - //compound.assign(matrix); -// if (matrix.dataType() == input->dataType()) - lup_(context, &matrix, nullptr, nullptr); - fillLowerUpperKernel<<>>(lower.specialBuffer(), lower.specialShapeInfo(), upper.specialBuffer(), upper.specialShapeInfo(), matrix.specialBuffer(), matrix.specialShapeInfo(), n); - lower.tickWriteDevice(); - upper.tickWriteDevice(); -// lower.printIndexedBuffer("LOWER"); -// upper.printIndexedBuffer("UPPER"); - matrix.assign(0); - invertUpperMatrix(context, &upper, &matrix); // U^{-1} - matrix.tickWriteDevice(); -// matrix.printIndexedBuffer("Upper Inverted"); - compound.assign(0); - invertLowerMatrix(context, &lower, &compound); // L{-1} - compound.tickWriteDevice(); -// compound.printIndexedBuffer("Lower Inverted"); -// matrix.tickWriteDevice(); -// compound.tickWriteDevice(); - sd::MmulHelper::mmul(&matrix, &compound, &upper, 1.0, 0.0); - upper.tickWriteDevice(); -// upper.printIndexedBuffer("Full inverted"); - returnMatrix<<<1, n2, 1024, *stream>>>(output->specialBuffer(), output->specialShapeInfo(), upper.specialBuffer(), upper.specialShapeInfo(), i * n2, n); - } - return Status::OK(); - } +template +static int inverse_(sd::LaunchContext *context, NDArray *input, + NDArray *output) { + auto n = input->sizeAt(-1); + auto n2 = n * n; + auto dtype = DataTypeUtils::fromT(); // input->dataType(); + // if (dtype != DataType::DOUBLE) + // dtype = DataType::FLOAT32; + NDArray matrix = NDArrayFactory::create('c', {n, n}, dtype, context); + NDArray upper = NDArrayFactory::create('c', {n, n}, dtype, context); + NDArray lower = NDArrayFactory::create('c', {n, n}, dtype, context); + NDArray compound = NDArrayFactory::create('c', {n, n}, dtype, context); + NDArray permutation = NDArrayFactory::create('c', {n, n}, dtype, context); + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions( + input->shapeInfo(), {input->rankOf() - 2, input->rankOf() - 1}); + auto packZ = sd::ConstantTadHelper::getInstance().tadForDimensions( + output->shapeInfo(), {output->rankOf() - 2, output->rankOf() - 1}); + auto stream = context->getCudaStream(); + + for (auto i = 0LL; i < packX.numberOfTads(); i++) { + fillMatrix<<<1, n2, 1024, *stream>>>( + matrix.specialBuffer(), matrix.specialShapeInfo(), + input->specialBuffer(), input->specialShapeInfo(), i * n2, n); + matrix.tickWriteDevice(); + // compound.assign(matrix); + // if (matrix.dataType() == input->dataType()) + lup_(context, &matrix, nullptr, nullptr); + fillLowerUpperKernel<<>>( + lower.specialBuffer(), lower.specialShapeInfo(), upper.specialBuffer(), + upper.specialShapeInfo(), matrix.specialBuffer(), + matrix.specialShapeInfo(), n); + lower.tickWriteDevice(); + upper.tickWriteDevice(); + // lower.printIndexedBuffer("LOWER"); + // upper.printIndexedBuffer("UPPER"); + matrix.assign(0); + invertUpperMatrix(context, &upper, &matrix); // U^{-1} + matrix.tickWriteDevice(); + // matrix.printIndexedBuffer("Upper Inverted"); + compound.assign(0); + invertLowerMatrix(context, &lower, &compound); // L{-1} + compound.tickWriteDevice(); + // compound.printIndexedBuffer("Lower Inverted"); + // matrix.tickWriteDevice(); + // compound.tickWriteDevice(); + sd::MmulHelper::mmul(&matrix, &compound, &upper, 1.0, 0.0); + upper.tickWriteDevice(); + // upper.printIndexedBuffer("Full inverted"); + returnMatrix<<<1, n2, 1024, *stream>>>( + output->specialBuffer(), output->specialShapeInfo(), + upper.specialBuffer(), upper.specialShapeInfo(), i * n2, n); + } + return Status::OK(); +} - int inverse(sd::LaunchContext *context, NDArray *input, NDArray *output) { - NDArray::prepareSpecialUse({output}, {input}); - BUILD_SINGLE_SELECTOR(input->dataType(), return inverse_, (context, input, output), FLOAT_NATIVE); - NDArray::registerSpecialUse({output}, {input}); - } +int inverse(sd::LaunchContext *context, NDArray *input, NDArray *output) { + NDArray::prepareSpecialUse({output}, {input}); + BUILD_SINGLE_SELECTOR(input->dataType(), return inverse_, + (context, input, output), FLOAT_NATIVE); + NDArray::registerSpecialUse({output}, {input}); +} - bool checkCholeskyInput(sd::LaunchContext *context, NDArray const *input) { - return true; - } +bool checkCholeskyInput(sd::LaunchContext *context, NDArray const *input) { + return true; +} - template - __global__ void fillBatchKernel(F **dArrayBatch, F *buf, const Nd4jLong *offsets, Nd4jLong batchSize) { - auto start = blockIdx.x * blockDim.x + threadIdx.x; - auto step = blockDim.x * gridDim.x; +template +__global__ void fillBatchKernel(F **dArrayBatch, F *buf, + const Nd4jLong *offsets, Nd4jLong batchSize) { + auto start = blockIdx.x * blockDim.x + threadIdx.x; + auto step = blockDim.x * gridDim.x; - for (auto i = start; i < batchSize; i += step) { - dArrayBatch[i] = buf + offsets[i]; - } - } + for (auto i = start; i < batchSize; i += step) { + dArrayBatch[i] = buf + offsets[i]; + } +} - template - __global__ void - adjustResultsKernel(F *dArray, const Nd4jLong *shape, const Nd4jLong *offsets, Nd4jLong batchSize, Nd4jLong n) { - //auto i = blockIdx.x * blockDim.x + threadIdx.x; - Nd4jLong *shapeOf = shape::shapeOf(shape); - Nd4jLong *strideOf = shape::stride(shape); - - for (auto i = blockIdx.x; i < batchSize; i += gridDim.x) { - auto current = dArray + offsets[i]; - for (auto r = threadIdx.x; r < n; r += blockDim.x) { - for (auto c = r + 1; c < n; c++) { - Nd4jLong posRC[] = {r, c}; - auto pos = r * n + c; //shape::getOffset(0, shapeOf, strideOf, posRC, 2); - current[pos] = 0.; - } - } - } - } +template +__global__ void adjustResultsKernel(F *dArray, const Nd4jLong *shape, + const Nd4jLong *offsets, Nd4jLong batchSize, + Nd4jLong n) { + // auto i = blockIdx.x * blockDim.x + threadIdx.x; + Nd4jLong *shapeOf = shape::shapeOf(shape); + Nd4jLong *strideOf = shape::stride(shape); + + for (auto i = blockIdx.x; i < batchSize; i += gridDim.x) { + auto current = dArray + offsets[i]; + for (auto r = threadIdx.x; r < n; r += blockDim.x) { + for (auto c = r + 1; c < n; c++) { + Nd4jLong posRC[] = {r, c}; + auto pos = + r * n + c; // shape::getOffset(0, shapeOf, strideOf, posRC, 2); + current[pos] = 0.; + } + } + } +} - template - int cholesky__(LaunchContext *context, NDArray *input, NDArray *output, bool inplace) { - if (!inplace) - output->assign(input); - auto tempOutput =output->dup(); - cusolverDnHandle_t handle = nullptr; - auto n = input->sizeAt(-1); - auto n2 = n * n; - NDArray::prepareSpecialUse({output}, {input}); - auto status = cusolverDnCreate(&handle); - if (CUSOLVER_STATUS_SUCCESS != status) { - throw cuda_exception::build("helpers::cholesky_: Cannot create solver handle", status); - } - F **dArrayBatch = nullptr; - auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(tempOutput.shapeInfo(), - {tempOutput.rankOf() - 2, - tempOutput.rankOf() - 1}); - const Nd4jLong batchSize = packX.numberOfTads(); - int *dInfoArray = nullptr; - auto err = cudaMalloc((void **) &dArrayBatch, sizeof(F *) * batchSize); - if (err) { - throw cuda_exception::build("helpers::cholesky_: Cannot allocate memory for solver batch data buffer", - err); - } - err = cudaMalloc((void **) &dInfoArray, sizeof(int) * batchSize); - if (err) { - throw cuda_exception::build("helpers::cholesky_: Cannot allocate memory for solver errors buffer", err); - } - auto stream = context->getCudaStream(); - fillBatchKernel<<<1, batchSize, 128, *stream>>>(dArrayBatch, reinterpret_cast(tempOutput.specialBuffer()), packX.specialOffsets(), batchSize); - - status = cusolverDnSetStream(handle, *stream); - if (CUSOLVER_STATUS_SUCCESS != status) { - throw cuda_exception::build("helpers::cholesky_: Cannot set stream to solver handle", status); - } - const cublasFillMode_t uplo = CUBLAS_FILL_MODE_UPPER; - if (input->dataType() == DataType::DOUBLE) - status = cusolverDnDpotrfBatched( - handle, - uplo, - n, - (double **) dArrayBatch, - n, - dInfoArray, - batchSize); - else - status = cusolverDnSpotrfBatched( - handle, - uplo, - n, - (float **) dArrayBatch, - n, - dInfoArray, - batchSize); - - if (CUSOLVER_STATUS_SUCCESS != status) { - throw cuda_exception::build("helpers::cholesky_: Cholesky factorization failed for batch", status); - } - adjustResultsKernel<<>>(reinterpret_cast(tempOutput.specialBuffer()), packX.specialShapeInfo(), packX.specialOffsets(), batchSize, n); - - err = cudaFree(dArrayBatch); - if (err) { - throw cuda_exception::build("helpers::cholesky_: Cannot deallocate memory for solver batch data buffer", - err); - } - err = cudaFree(dInfoArray); - if (err) { - throw cuda_exception::build("helpers::cholesky_: Cannot allocate memory for solver errors buffer", err); - } - - if (!inplace) - output->assign(tempOutput); - else - input->assign(tempOutput); - - NDArray::registerSpecialUse({output}, {input}); - return Status::OK(); - } +template +int cholesky__(LaunchContext *context, NDArray *input, NDArray *output, + bool inplace) { + if (!inplace) output->assign(input); + auto tempOutput = output->dup(); + cusolverDnHandle_t handle = nullptr; + auto n = input->sizeAt(-1); + auto n2 = n * n; + NDArray::prepareSpecialUse({output}, {input}); + auto status = cusolverDnCreate(&handle); + if (CUSOLVER_STATUS_SUCCESS != status) { + throw cuda_exception::build( + "helpers::cholesky_: Cannot create solver handle", status); + } + F **dArrayBatch = nullptr; + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions( + tempOutput.shapeInfo(), + {tempOutput.rankOf() - 2, tempOutput.rankOf() - 1}); + const Nd4jLong batchSize = packX.numberOfTads(); + int *dInfoArray = nullptr; + auto err = cudaMalloc((void **)&dArrayBatch, sizeof(F *) * batchSize); + if (err) { + throw cuda_exception::build( + "helpers::cholesky_: Cannot allocate memory for solver batch data " + "buffer", + err); + } + err = cudaMalloc((void **)&dInfoArray, sizeof(int) * batchSize); + if (err) { + throw cuda_exception::build( + "helpers::cholesky_: Cannot allocate memory for solver errors buffer", + err); + } + auto stream = context->getCudaStream(); + fillBatchKernel<<<1, batchSize, 128, *stream>>>( + dArrayBatch, reinterpret_cast(tempOutput.specialBuffer()), + packX.specialOffsets(), batchSize); + + status = cusolverDnSetStream(handle, *stream); + if (CUSOLVER_STATUS_SUCCESS != status) { + throw cuda_exception::build( + "helpers::cholesky_: Cannot set stream to solver handle", status); + } + const cublasFillMode_t uplo = CUBLAS_FILL_MODE_UPPER; + if (input->dataType() == DataType::DOUBLE) + status = cusolverDnDpotrfBatched(handle, uplo, n, (double **)dArrayBatch, n, + dInfoArray, batchSize); + else + status = cusolverDnSpotrfBatched(handle, uplo, n, (float **)dArrayBatch, n, + dInfoArray, batchSize); + + if (CUSOLVER_STATUS_SUCCESS != status) { + throw cuda_exception::build( + "helpers::cholesky_: Cholesky factorization failed for batch", status); + } + adjustResultsKernel<<>>( + reinterpret_cast(tempOutput.specialBuffer()), + packX.specialShapeInfo(), packX.specialOffsets(), batchSize, n); + + err = cudaFree(dArrayBatch); + if (err) { + throw cuda_exception::build( + "helpers::cholesky_: Cannot deallocate memory for solver batch data " + "buffer", + err); + } + err = cudaFree(dInfoArray); + if (err) { + throw cuda_exception::build( + "helpers::cholesky_: Cannot allocate memory for solver errors buffer", + err); + } + + if (!inplace) + output->assign(tempOutput); + else + input->assign(tempOutput); + + NDArray::registerSpecialUse({output}, {input}); + return Status::OK(); +} // template - int cholesky_(LaunchContext *context, NDArray *input, NDArray *output, bool inplace) { - NDArray::prepareSpecialUse({output}, {input}); - if (input->dataType() == DataType::DOUBLE) - cholesky__(context, input, output, inplace); - else if (input->dataType() == DataType::FLOAT32) - cholesky__(context, input, output, inplace); - else { - std::unique_ptr tempOutput( - NDArrayFactory::create_('c', input->getShapeAsVector(), DataType::FLOAT32, context)); - tempOutput->assign(input); - cholesky__(context, tempOutput.get(), tempOutput.get(), true); - output->assign(tempOutput.get()); - } - NDArray::registerSpecialUse({output}, {input}); - return Status::OK(); - } +int cholesky_(LaunchContext *context, NDArray *input, NDArray *output, + bool inplace) { + NDArray::prepareSpecialUse({output}, {input}); + if (input->dataType() == DataType::DOUBLE) + cholesky__(context, input, output, inplace); + else if (input->dataType() == DataType::FLOAT32) + cholesky__(context, input, output, inplace); + else { + std::unique_ptr tempOutput(NDArrayFactory::create_( + 'c', input->getShapeAsVector(), DataType::FLOAT32, context)); + tempOutput->assign(input); + cholesky__(context, tempOutput.get(), tempOutput.get(), true); + output->assign(tempOutput.get()); + } + NDArray::registerSpecialUse({output}, {input}); + return Status::OK(); +} - int cholesky(sd::LaunchContext *context, NDArray *input, NDArray *output, bool inplace) { -// BUILD_SINGLE_SELECTOR(input->dataType(), return cholesky_, (context, input, output, inplace), FLOAT_TYPES); - return cholesky_(context, input, output, inplace); - } -// BUILD_SINGLE_TEMPLATE(template int cholesky_, (LaunchContext* context, NDArray* input, NDArray* output, bool inplace), FLOAT_TYPES); - BUILD_SINGLE_TEMPLATE(template int inverse_, (sd::LaunchContext * context, NDArray * input, NDArray * output), - FLOAT_NATIVE); - - template - __global__ void logDetKernel( - const T *inputBuf, const Nd4jLong *inputShape, - Nd4jLong batchNum, - const Nd4jLong *tadShape, const Nd4jLong *tadOffsets, - T *outputBuf, const Nd4jLong *outputShape) { - - __shared__ int n; - if (threadIdx.x == 0) { - n = shape::sizeAt(inputShape, -1); // * shape::sizeAt(inputShape, -1); - } - __syncthreads(); - - auto output = outputBuf; - auto input = inputBuf; - - for (auto i = blockIdx.x; i < batchNum; i += gridDim.x) { - auto current = input + tadOffsets[i]; - - auto zIndex = shape::getIndexOffset(i, outputShape); - for (auto e = threadIdx.x; e < n; e += blockDim.x) { - Nd4jLong diag[] = {e, e}; - auto xIndex = shape::getOffset(tadShape, diag); - math::atomics::nd4j_atomicAdd(&output[zIndex],math::nd4j_log(current[xIndex] * current[xIndex])); - } - } - } +int cholesky(sd::LaunchContext *context, NDArray *input, NDArray *output, + bool inplace) { + // BUILD_SINGLE_SELECTOR(input->dataType(), return cholesky_, (context, + // input, output, inplace), FLOAT_TYPES); + return cholesky_(context, input, output, inplace); +} +// BUILD_SINGLE_TEMPLATE(template int cholesky_, (LaunchContext* context, +// NDArray* input, NDArray* output, bool inplace), FLOAT_TYPES); +BUILD_SINGLE_TEMPLATE(template int inverse_, + (sd::LaunchContext * context, NDArray *input, + NDArray *output), + FLOAT_NATIVE); + +template +__global__ void logDetKernel(const T *inputBuf, const Nd4jLong *inputShape, + Nd4jLong batchNum, const Nd4jLong *tadShape, + const Nd4jLong *tadOffsets, T *outputBuf, + const Nd4jLong *outputShape) { + __shared__ int n; + if (threadIdx.x == 0) { + n = shape::sizeAt(inputShape, -1); // * shape::sizeAt(inputShape, -1); + } + __syncthreads(); + + auto output = outputBuf; + auto input = inputBuf; + + for (auto i = blockIdx.x; i < batchNum; i += gridDim.x) { + auto current = input + tadOffsets[i]; + + auto zIndex = shape::getIndexOffset(i, outputShape); + for (auto e = threadIdx.x; e < n; e += blockDim.x) { + Nd4jLong diag[] = {e, e}; + auto xIndex = shape::getOffset(tadShape, diag); + math::atomics::nd4j_atomicAdd( + &output[zIndex], + math::nd4j_log(current[xIndex] * current[xIndex])); + } + } +} - template - int logdetFunctor_(sd::LaunchContext *context, NDArray *input, NDArray *output) { - NDArray::prepareSpecialUse({output}, {input}); - auto n2 = input->sizeAt(-1) * input->sizeAt(-2); - auto stream = context->getCudaStream(); - NDArray tempOutput(*input); - - cholesky(context, input, &tempOutput, false); - - auto outputBuf = output->dataBuffer()->specialAsT(); //reinterpret_cast(output->specialBuffer()); // + e * n2; // + e * n2; - auto inputBuf = tempOutput.dataBuffer()->specialAsT(); //reinterpret_cast(tempOutput.specialBuffer()); - output->nullify(); - auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(tempOutput.shapeInfo(), - {tempOutput.rankOf() - 2, - tempOutput.rankOf() - 1}); - logDetKernel<<<128, 512, 256, *stream>>>(inputBuf, tempOutput.specialShapeInfo(), - packX.numberOfTads(), packX.specialShapeInfo(), - packX.specialOffsets(), outputBuf, output->specialShapeInfo()); - output->tickWriteDevice(); - NDArray::registerSpecialUse({output}, {input}); - return Status::OK(); - } +template +int logdetFunctor_(sd::LaunchContext *context, NDArray *input, + NDArray *output) { + NDArray::prepareSpecialUse({output}, {input}); + auto n2 = input->sizeAt(-1) * input->sizeAt(-2); + auto stream = context->getCudaStream(); + NDArray tempOutput(*input); + + cholesky(context, input, &tempOutput, false); + + auto outputBuf = + output->dataBuffer() + ->specialAsT(); // reinterpret_cast(output->specialBuffer()); + // // + e * n2; // + e * n2; + auto inputBuf = + tempOutput.dataBuffer() + ->specialAsT< + T>(); // reinterpret_cast(tempOutput.specialBuffer()); + output->nullify(); + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions( + tempOutput.shapeInfo(), + {tempOutput.rankOf() - 2, tempOutput.rankOf() - 1}); + logDetKernel<<<128, 512, 256, *stream>>>( + inputBuf, tempOutput.specialShapeInfo(), packX.numberOfTads(), + packX.specialShapeInfo(), packX.specialOffsets(), outputBuf, + output->specialShapeInfo()); + output->tickWriteDevice(); + NDArray::registerSpecialUse({output}, {input}); + return Status::OK(); +} - int logdetFunctor(sd::LaunchContext *context, NDArray *input, NDArray *output) { - BUILD_SINGLE_SELECTOR(output->dataType(), return logdetFunctor_, (context, input, output), FLOAT_NATIVE); - } +int logdetFunctor(sd::LaunchContext *context, NDArray *input, NDArray *output) { + BUILD_SINGLE_SELECTOR(output->dataType(), return logdetFunctor_, + (context, input, output), FLOAT_NATIVE); +} - /* - * lup - batched input, batched outputs - * */ - int lup(LaunchContext *context, NDArray *input, NDArray *compound, NDArray *permutation) { - BUILD_DOUBLE_SELECTOR(input->dataType(), permutation->dataType(), lup_,(context, input, compound, permutation), FLOAT_NATIVE, INDEXING_TYPES); - return Status::OK(); - } +/* + * lup - batched input, batched outputs + * */ +int lup(LaunchContext *context, NDArray *input, NDArray *compound, + NDArray *permutation) { + BUILD_DOUBLE_SELECTOR(input->dataType(), permutation->dataType(), lup_, + (context, input, compound, permutation), FLOAT_NATIVE, + INDEXING_TYPES); + return Status::OK(); +} // BUILD_SINGLE_TEMPLATE(template int logdetFunctor_, -// (sd::LaunchContext * context, NDArray * input, NDArray * output), FLOAT_NATIVE); - } -} -} +// (sd::LaunchContext * context, NDArray * input, +// NDArray * output), FLOAT_NATIVE); +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/matrixSetDiag.cu b/libnd4j/include/ops/declarable/helpers/cuda/matrixSetDiag.cu index 97124c3dba69..a8dff33b48ae 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/matrixSetDiag.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/matrixSetDiag.cu @@ -19,84 +19,104 @@ // #include -#include #include +#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { /////////////////////////////////////////////////////////////////// -template -__global__ static void matrixSetDiagCuda(const void* vx, const Nd4jLong* xShapeInfo, const void* vy, const Nd4jLong* yShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const bool zeroPad) { - - // x - input, shape [A,B,C] - // y - diagonal, shape [A,B] - // z - output, shape [A,B,C] - // input and output are the same array (x == z) when zeroPad = true - - const auto x = reinterpret_cast(vx); - const auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); - - __shared__ int xRank, *sharedMem; // xRank = zRank, xRank = yRank + 1 - __shared__ Nd4jLong xLen; // xLen = zLen - __shared__ bool areSameOffsets; - - if (threadIdx.x == 0) { - - extern __shared__ unsigned char shmem[]; - sharedMem = reinterpret_cast(shmem); - - areSameOffsets = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo); // shapes are definitely the same, but strides might not - - xRank = shape::rank(xShapeInfo); - xLen = shape::length(xShapeInfo); - } - - __syncthreads(); - - auto coords = sharedMem + threadIdx.x * xRank; // we provide (xRank * sizeof(int) * threadIdx.x) amount of shared memory per each thread - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - - for (Nd4jLong i = tid; i < xLen; i += gridDim.x * blockDim.x) { - - shape::index2coords(i, xShapeInfo, coords); - - const auto xOffset = shape::getOffset(xShapeInfo, coords); - const auto zOffset = areSameOffsets ? xOffset : shape::getOffset(zShapeInfo, coords); - - // condition to be on diagonal of innermost matrix - if(coords[xRank - 2] == coords[xRank - 1]) - z[zOffset] = y[shape::getOffset(yShapeInfo, coords)]; - else - z[zOffset] = zeroPad ? static_cast(0) : x[xOffset]; - } +template +__global__ static void matrixSetDiagCuda(const void* vx, + const Nd4jLong* xShapeInfo, + const void* vy, + const Nd4jLong* yShapeInfo, void* vz, + const Nd4jLong* zShapeInfo, + const bool zeroPad) { + // x - input, shape [A,B,C] + // y - diagonal, shape [A,B] + // z - output, shape [A,B,C] + // input and output are the same array (x == z) when zeroPad = true + + const auto x = reinterpret_cast(vx); + const auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + + __shared__ int xRank, *sharedMem; // xRank = zRank, xRank = yRank + 1 + __shared__ Nd4jLong xLen; // xLen = zLen + __shared__ bool areSameOffsets; + + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sharedMem = reinterpret_cast(shmem); + + areSameOffsets = shape::haveSameShapeAndStrides( + xShapeInfo, + zShapeInfo); // shapes are definitely the same, but strides might not + + xRank = shape::rank(xShapeInfo); + xLen = shape::length(xShapeInfo); + } + + __syncthreads(); + + auto coords = + sharedMem + + threadIdx.x * xRank; // we provide (xRank * sizeof(int) * threadIdx.x) + // amount of shared memory per each thread + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + + for (Nd4jLong i = tid; i < xLen; i += gridDim.x * blockDim.x) { + shape::index2coords(i, xShapeInfo, coords); + + const auto xOffset = shape::getOffset(xShapeInfo, coords); + const auto zOffset = + areSameOffsets ? xOffset : shape::getOffset(zShapeInfo, coords); + + // condition to be on diagonal of innermost matrix + if (coords[xRank - 2] == coords[xRank - 1]) + z[zOffset] = y[shape::getOffset(yShapeInfo, coords)]; + else + z[zOffset] = zeroPad ? static_cast(0) : x[xOffset]; + } } /////////////////////////////////////////////////////////////////// -template -static void matrixSetDiagCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void* vx, const Nd4jLong* xShapeInfo, const void* vy, const Nd4jLong* yShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const bool zeroPad) { - - matrixSetDiagCuda<<>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, zeroPad); +template +static void matrixSetDiagCudaLauncher( + const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, + const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo, + const void* vy, const Nd4jLong* yShapeInfo, void* vz, + const Nd4jLong* zShapeInfo, const bool zeroPad) { + matrixSetDiagCuda<<>>( + vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, zeroPad); } /////////////////////////////////////////////////////////////////// -void matrixSetDiag(sd::LaunchContext* context, const NDArray& input, const NDArray& diagonal, NDArray& output, const bool zeroPad) { - - const int threadsPerBlock = MAX_NUM_THREADS / 2; - const int blocksPerGrid = (input.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - const int sharedMem = threadsPerBlock * sizeof(int) * input.rankOf() + 128; - - PointersManager manager(context, "matrixSetDiag"); - - NDArray::prepareSpecialUse({&output}, {&input, &diagonal}); - BUILD_SINGLE_SELECTOR(input.dataType(), matrixSetDiagCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), input.specialBuffer(), input.specialShapeInfo(), diagonal.specialBuffer(), diagonal.specialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), zeroPad), LIBND4J_TYPES); - NDArray::registerSpecialUse({&output}, {&input, &diagonal}); - - manager.synchronize(); +void matrixSetDiag(sd::LaunchContext* context, const NDArray& input, + const NDArray& diagonal, NDArray& output, + const bool zeroPad) { + const int threadsPerBlock = MAX_NUM_THREADS / 2; + const int blocksPerGrid = + (input.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + const int sharedMem = threadsPerBlock * sizeof(int) * input.rankOf() + 128; + + PointersManager manager(context, "matrixSetDiag"); + + NDArray::prepareSpecialUse({&output}, {&input, &diagonal}); + BUILD_SINGLE_SELECTOR( + input.dataType(), matrixSetDiagCudaLauncher, + (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), + input.specialBuffer(), input.specialShapeInfo(), + diagonal.specialBuffer(), diagonal.specialShapeInfo(), + output.specialBuffer(), output.specialShapeInfo(), zeroPad), + LIBND4J_TYPES); + NDArray::registerSpecialUse({&output}, {&input, &diagonal}); + + manager.synchronize(); } -} -} -} \ No newline at end of file +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/matrix_band.cu b/libnd4j/include/ops/declarable/helpers/cuda/matrix_band.cu index 446d57b27ce4..17c35a4d2fc1 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/matrix_band.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/matrix_band.cu @@ -17,11 +17,11 @@ // // @author George A. Shulinok // -#include -#include #include -#include #include +#include +#include +#include namespace sd { namespace ops { @@ -42,75 +42,87 @@ namespace helpers { // numTads - number of subarrays // inputLength - input subarray length // - template - static __global__ void matrixBandKernel(const void* inputBuffer, const Nd4jLong* inputShape, - void* outputBuffer, const Nd4jLong* outputShape, - Nd4jLong lowerBand, Nd4jLong upperBand, - const Nd4jLong* tadOnlyInputShapeInfo, const Nd4jLong* tadInputOffsets, - const Nd4jLong* tadOnlyOutputShapeInfo, const Nd4jLong* tadOutputOffsets, - Nd4jLong numTads, - Nd4jLong inputLength) { - int totalThreads = blockDim.x; - Nd4jLong rows = shape::sizeAt(inputShape, -2); - Nd4jLong cols = shape::sizeAt(inputShape, -1); - for (Nd4jLong e = blockIdx.x; e < numTads; e += gridDim.x) { - auto yOffset = tadInputOffsets[e]; - auto xOffset = tadOutputOffsets[e]; - for (Nd4jLong i = blockIdx.y; i < rows; i += gridDim.y) { - for (Nd4jLong j = threadIdx.x; j < cols; j += totalThreads) { - Nd4jLong coords[2] = {i, j}; - Nd4jLong tadOffsetOut = shape::getOffset(tadOnlyOutputShapeInfo, coords); - Nd4jLong tadOffsetIn = shape::getOffset(tadOnlyInputShapeInfo, coords); +template +static __global__ void matrixBandKernel( + const void* inputBuffer, const Nd4jLong* inputShape, void* outputBuffer, + const Nd4jLong* outputShape, Nd4jLong lowerBand, Nd4jLong upperBand, + const Nd4jLong* tadOnlyInputShapeInfo, const Nd4jLong* tadInputOffsets, + const Nd4jLong* tadOnlyOutputShapeInfo, const Nd4jLong* tadOutputOffsets, + Nd4jLong numTads, Nd4jLong inputLength) { + int totalThreads = blockDim.x; + Nd4jLong rows = shape::sizeAt(inputShape, -2); + Nd4jLong cols = shape::sizeAt(inputShape, -1); + for (Nd4jLong e = blockIdx.x; e < numTads; e += gridDim.x) { + auto yOffset = tadInputOffsets[e]; + auto xOffset = tadOutputOffsets[e]; + for (Nd4jLong i = blockIdx.y; i < rows; i += gridDim.y) { + for (Nd4jLong j = threadIdx.x; j < cols; j += totalThreads) { + Nd4jLong coords[2] = {i, j}; + Nd4jLong tadOffsetOut = + shape::getOffset(tadOnlyOutputShapeInfo, coords); + Nd4jLong tadOffsetIn = shape::getOffset(tadOnlyInputShapeInfo, coords); - if (i >= j) { // check lower diagonals - if (lowerBand > 0) { - if ((i - j) > lowerBand) - *(reinterpret_cast(outputBuffer) + xOffset + tadOffsetOut) = T(0); - else - *(reinterpret_cast(outputBuffer) + xOffset + tadOffsetOut) = *( - reinterpret_cast(inputBuffer) + yOffset + tadOffsetIn); - } - } else if (j > i) { - if (upperBand > 0) - if ((j - i) > upperBand) - *(reinterpret_cast(outputBuffer) + xOffset + tadOffsetOut) = T(0); - else - *(reinterpret_cast(outputBuffer) + xOffset + tadOffsetOut) = *( - reinterpret_cast(inputBuffer) + yOffset + tadOffsetIn); - } - } - } + if (i >= j) { // check lower diagonals + if (lowerBand > 0) { + if ((i - j) > lowerBand) + *(reinterpret_cast(outputBuffer) + xOffset + tadOffsetOut) = + T(0); + else + *(reinterpret_cast(outputBuffer) + xOffset + tadOffsetOut) = + *(reinterpret_cast(inputBuffer) + yOffset + + tadOffsetIn); + } + } else if (j > i) { + if (upperBand > 0) + if ((j - i) > upperBand) + *(reinterpret_cast(outputBuffer) + xOffset + tadOffsetOut) = + T(0); + else + *(reinterpret_cast(outputBuffer) + xOffset + tadOffsetOut) = + *(reinterpret_cast(inputBuffer) + yOffset + + tadOffsetIn); } - + } } + } +} //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// // matrixBandPart_ - main algorithm caller // - template - void matrixBandPart_(sd::LaunchContext * context, NDArray* input, NDArray* output, Nd4jLong lowerBand, Nd4jLong upperBand) { - dim3 launchDims(256, 512, 8192); - auto stream = context->getCudaStream(); +template +void matrixBandPart_(sd::LaunchContext* context, NDArray* input, + NDArray* output, Nd4jLong lowerBand, Nd4jLong upperBand) { + dim3 launchDims(256, 512, 8192); + auto stream = context->getCudaStream(); - std::vector lastDims({input->rankOf() - 2, input->rankOf() - 1}); - std::vector dimsToExclude = ShapeUtils::evalDimsToExclude(input->rankOf(), lastDims); + std::vector lastDims({input->rankOf() - 2, input->rankOf() - 1}); + std::vector dimsToExclude = + ShapeUtils::evalDimsToExclude(input->rankOf(), lastDims); - auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(input->shapeInfo(), lastDims); - auto packZ = sd::ConstantTadHelper::getInstance().tadForDimensions(output->shapeInfo(), lastDims); + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions( + input->shapeInfo(), lastDims); + auto packZ = sd::ConstantTadHelper::getInstance().tadForDimensions( + output->shapeInfo(), lastDims); - const Nd4jLong numTads = packX.numberOfTads(); + const Nd4jLong numTads = packX.numberOfTads(); - NDArray::prepareSpecialUse({output}, {input}); - matrixBandKernel<<>>(input->specialBuffer(), - input->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), - lowerBand, upperBand, packX.specialShapeInfo(), packX.specialOffsets(), packZ.specialShapeInfo(), packZ.specialOffsets(), numTads, input->lengthOf()); - NDArray::registerSpecialUse({output}, {input}); - } - - //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - void matrixBandPart(sd::LaunchContext * context, NDArray* input, NDArray* output, Nd4jLong lowerBand, Nd4jLong upperBand) { - BUILD_SINGLE_SELECTOR(input->dataType(), matrixBandPart_, (context, input, output, lowerBand, upperBand), FLOAT_TYPES); - } -} -} + NDArray::prepareSpecialUse({output}, {input}); + matrixBandKernel<<>>( + input->specialBuffer(), input->specialShapeInfo(), + output->specialBuffer(), output->specialShapeInfo(), lowerBand, upperBand, + packX.specialShapeInfo(), packX.specialOffsets(), + packZ.specialShapeInfo(), packZ.specialOffsets(), numTads, + input->lengthOf()); + NDArray::registerSpecialUse({output}, {input}); } +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +void matrixBandPart(sd::LaunchContext* context, NDArray* input, NDArray* output, + Nd4jLong lowerBand, Nd4jLong upperBand) { + BUILD_SINGLE_SELECTOR(input->dataType(), matrixBandPart_, + (context, input, output, lowerBand, upperBand), + FLOAT_TYPES); +} +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/matrix_diag_part.cu b/libnd4j/include/ops/declarable/helpers/cuda/matrix_diag_part.cu index b8edcbc26677..2856e3a7d08d 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/matrix_diag_part.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/matrix_diag_part.cu @@ -19,13 +19,12 @@ // #include -#include +#include #include -#include +#include #include #include -#include -#include +#include namespace sd { namespace ops { @@ -33,70 +32,89 @@ namespace helpers { //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// // put diagonals from input batched matricies to output batched vectors - template - static __global__ void matrixDiagPartKernel(void const* inputBuffer, void* outputBuffer, Nd4jLong numTads, Nd4jLong inputLength, - const Nd4jLong* tadOnlyInputShapeInfo, const Nd4jLong *tadInputOffsets, - const Nd4jLong* tadOnlyOutputShapeInfo, const Nd4jLong *tadOutputOffsets) { - int totalThreads = blockDim.x; - for (Nd4jLong i = blockIdx.x; i < numTads; i += gridDim.x) { - auto yOffset = tadInputOffsets[i]; - auto xOffset = tadOutputOffsets[i]; - for (Nd4jLong j = threadIdx.x; j < inputLength; j += totalThreads) { - Nd4jLong coords[2] = {j, j}; - Nd4jLong tadOffset = shape::getOffset(tadOnlyInputShapeInfo, coords); - *(reinterpret_cast(outputBuffer) + xOffset + shape::getIndexOffset(j, tadOnlyOutputShapeInfo)) = *(reinterpret_cast(inputBuffer) + yOffset + tadOffset); - } - } +template +static __global__ void matrixDiagPartKernel( + void const* inputBuffer, void* outputBuffer, Nd4jLong numTads, + Nd4jLong inputLength, const Nd4jLong* tadOnlyInputShapeInfo, + const Nd4jLong* tadInputOffsets, const Nd4jLong* tadOnlyOutputShapeInfo, + const Nd4jLong* tadOutputOffsets) { + int totalThreads = blockDim.x; + for (Nd4jLong i = blockIdx.x; i < numTads; i += gridDim.x) { + auto yOffset = tadInputOffsets[i]; + auto xOffset = tadOutputOffsets[i]; + for (Nd4jLong j = threadIdx.x; j < inputLength; j += totalThreads) { + Nd4jLong coords[2] = {j, j}; + Nd4jLong tadOffset = shape::getOffset(tadOnlyInputShapeInfo, coords); + *(reinterpret_cast(outputBuffer) + xOffset + + shape::getIndexOffset(j, tadOnlyOutputShapeInfo)) = + *(reinterpret_cast(inputBuffer) + yOffset + tadOffset); } + } +} ////////////////////////////////////////////////////////////////////////// // Returns a batched matrix tensor with new batched diagonal values. -// for detailed explanations please take a look on web page: https://www.tensorflow.org/api_docs/python/tf/matrix_set_diag +// for detailed explanations please take a look on web page: +// https://www.tensorflow.org/api_docs/python/tf/matrix_set_diag // - template - int _matrixDiagPart(sd::LaunchContext * context, const NDArray* input, NDArray* output) { - auto stream = context->getCudaStream(); - auto listOut = output->allTensorsAlongDimension({output->rankOf() - 1}); - auto listDiag = input->allTensorsAlongDimension({input->rankOf() - 2, input->rankOf() - 1}); - - if (listOut.size() != listDiag.size()) { - nd4j_printf("matrix_diag_part: Input matrix has wrong shape.", ""); - return ND4J_STATUS_VALIDATION; - } - Nd4jLong lastDimension = sd::math::nd4j_min(input->sizeAt(-2), input->sizeAt(-1)); - - std::vector dimsToExclude = ShapeUtils::evalDimsToExclude(output->rankOf(), {output->rankOf() - 1}); - const Nd4jLong numTads = ShapeUtils::getNumOfSubArrs(input->shapeInfo(), dimsToExclude); //this->tensorsAlongDimension({dimension}); - //printf("Repeat delta %lld, numTads %lld\n", repeatDelta, numTads); - //tadOnlyInputShapeInfo, tadInputOffsets, tadOnlyOutputShapeInfo, tadOutputOffsets; - std::vector outputDims({output->rankOf() - 1}); - std::vector inputDims({input->rankOf() - 2, input->rankOf() - 1}); - auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(input->shapeInfo(), inputDims); - auto packZ = sd::ConstantTadHelper::getInstance().tadForDimensions(output->shapeInfo(), outputDims); - - - if (!output->isActualOnDeviceSide()) - input->syncToDevice(); - - if (!input->isActualOnDeviceSide()) - input->syncToDevice(); - - - dim3 launchDims(256, 512, 8192); - matrixDiagPartKernel<<>>(input->specialBuffer(), output->specialBuffer(), numTads, lastDimension, packX.specialShapeInfo(), packX.specialOffsets(), packZ.specialShapeInfo(), packZ.specialOffsets()); - - return Status::OK(); - } +template +int _matrixDiagPart(sd::LaunchContext* context, const NDArray* input, + NDArray* output) { + auto stream = context->getCudaStream(); + auto listOut = output->allTensorsAlongDimension({output->rankOf() - 1}); + auto listDiag = input->allTensorsAlongDimension( + {input->rankOf() - 2, input->rankOf() - 1}); + + if (listOut.size() != listDiag.size()) { + nd4j_printf("matrix_diag_part: Input matrix has wrong shape.", ""); + return ND4J_STATUS_VALIDATION; + } + Nd4jLong lastDimension = + sd::math::nd4j_min(input->sizeAt(-2), input->sizeAt(-1)); + + std::vector dimsToExclude = + ShapeUtils::evalDimsToExclude(output->rankOf(), {output->rankOf() - 1}); + const Nd4jLong numTads = ShapeUtils::getNumOfSubArrs( + input->shapeInfo(), + dimsToExclude); // this->tensorsAlongDimension({dimension}); + // printf("Repeat delta %lld, numTads %lld\n", repeatDelta, numTads); + // tadOnlyInputShapeInfo, tadInputOffsets, tadOnlyOutputShapeInfo, + // tadOutputOffsets; + std::vector outputDims({output->rankOf() - 1}); + std::vector inputDims({input->rankOf() - 2, input->rankOf() - 1}); + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions( + input->shapeInfo(), inputDims); + auto packZ = sd::ConstantTadHelper::getInstance().tadForDimensions( + output->shapeInfo(), outputDims); + + if (!output->isActualOnDeviceSide()) input->syncToDevice(); + + if (!input->isActualOnDeviceSide()) input->syncToDevice(); + + dim3 launchDims(256, 512, 8192); + matrixDiagPartKernel + <<>>( + input->specialBuffer(), output->specialBuffer(), numTads, + lastDimension, packX.specialShapeInfo(), packX.specialOffsets(), + packZ.specialShapeInfo(), packZ.specialOffsets()); + + return Status::OK(); +} //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// // caller for _matrixDiagPart // - int matrixDiagPart(sd::LaunchContext * context, const NDArray* input, NDArray* output) { - BUILD_SINGLE_SELECTOR(input->dataType(), return _matrixDiagPart, (context, input, output), LIBND4J_TYPES); - } +int matrixDiagPart(sd::LaunchContext* context, const NDArray* input, + NDArray* output) { + BUILD_SINGLE_SELECTOR(input->dataType(), return _matrixDiagPart, + (context, input, output), LIBND4J_TYPES); +} - BUILD_SINGLE_TEMPLATE(template int _matrixDiagPart, (sd::LaunchContext * context, const NDArray* input, NDArray* output), LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template int _matrixDiagPart, + (sd::LaunchContext * context, const NDArray* input, + NDArray* output), + LIBND4J_TYPES); -} -} -} \ No newline at end of file +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/max_pooling.cu b/libnd4j/include/ops/declarable/helpers/cuda/max_pooling.cu index 8c30e510fdf4..4ee2cc0ff189 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/max_pooling.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/max_pooling.cu @@ -18,80 +18,93 @@ // @author raver119@gmail.com // -#include #include - +#include namespace sd { namespace ops { namespace helpers { - template - static _CUDA_G void indicesFiller(void *vz, Nd4jLong const* zShapeInfo, Nd4jLong part, Nd4jLong bSize) { - auto z = reinterpret_cast(vz); +template +static _CUDA_G void indicesFiller(void* vz, Nd4jLong const* zShapeInfo, + Nd4jLong part, Nd4jLong bSize) { + auto z = reinterpret_cast(vz); - for (int b = blockIdx.x; b < bSize; b += gridDim.x) { - for (Nd4jLong e = threadIdx.x; e < part; e += blockDim.x) { - z[shape::getIndexOffset(e + b * part, zShapeInfo)] = static_cast(e); - } - } + for (int b = blockIdx.x; b < bSize; b += gridDim.x) { + for (Nd4jLong e = threadIdx.x; e < part; e += blockDim.x) { + z[shape::getIndexOffset(e + b * part, zShapeInfo)] = static_cast(e); } + } +} - template - static void maxPoolingFunctor_(sd::graph::Context& block, NDArray* input, NDArray* values, std::vector const& params, NDArray* indices) { - int kY = params[0]; - int kX = params[1]; +template +static void maxPoolingFunctor_(sd::graph::Context& block, NDArray* input, + NDArray* values, std::vector const& params, + NDArray* indices) { + int kY = params[0]; + int kX = params[1]; - int sY = params[2]; - int sX = params[3]; + int sY = params[2]; + int sX = params[3]; - int pY = params[4]; - int pX = params[5]; + int pY = params[4]; + int pX = params[5]; - int dY = params[6]; - int dX = params[7]; + int dY = params[6]; + int dX = params[7]; - int oY = 0; - int oX = 0; + int oY = 0; + int oX = 0; - const int bSize = input->sizeAt(0); - const int inD = input->sizeAt(1); - const int inY = input->sizeAt(2); - const int inX = input->sizeAt(3); + const int bSize = input->sizeAt(0); + const int inD = input->sizeAt(1); + const int inY = input->sizeAt(2); + const int inX = input->sizeAt(3); - const bool isSameMode = params[8] != 0; + const bool isSameMode = params[8] != 0; - ConvolutionUtils::calcOutSizePool2D(oY, oX, kY, kX, sY, sX, pY, pX, dY, dX, inY, inX, isSameMode); + ConvolutionUtils::calcOutSizePool2D(oY, oX, kY, kX, sY, sX, pY, pX, dY, dX, + inY, inX, isSameMode); - if (isSameMode) - ConvolutionUtils::calcPadding2D(pY, pX, oY, oX, inY, inX, params[0], params[1], params[2], params[3], params[6], params[7]); + if (isSameMode) + ConvolutionUtils::calcPadding2D(pY, pX, oY, oX, inY, inX, params[0], + params[1], params[2], params[3], params[6], + params[7]); - // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - poolingMode; 9 - divisor; - ConvolutionUtils::pooling2d(block, *input, *values, kY, kX, sY, sX, pY, pX, dY, dX, PoolingType::MAX_POOL, 1); + // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad + // Height/Width; 6,7 - dilation Height/Width; 8 - poolingMode; 9 - divisor; + ConvolutionUtils::pooling2d(block, *input, *values, kY, kX, sY, sX, pY, pX, + dY, dX, PoolingType::MAX_POOL, 1); - if (nullptr != indices) { - // for max_pool_with_argmax - auto total = input->lengthOf(); - auto part = total / bSize; + if (nullptr != indices) { + // for max_pool_with_argmax + auto total = input->lengthOf(); + auto part = total / bSize; - indicesFiller<<<256, 256, 1024, *block.launchContext()->getCudaStream()>>>(indices->specialBuffer(), indices->specialShapeInfo(), part, bSize); + indicesFiller + <<<256, 256, 1024, *block.launchContext()->getCudaStream()>>>( + indices->specialBuffer(), indices->specialShapeInfo(), part, bSize); - /* - for (int k = 0; k < total; ) - for (int i = 0; i < part; i++) { - indices->p(k++, i); - } - */ + /* + for (int k = 0; k < total; ) + for (int i = 0; i < part; i++) { + indices->p(k++, i); } - } - - void maxPoolingFunctor(sd::LaunchContext * context, sd::graph::Context& block, NDArray* input, NDArray* values, std::vector const& params, NDArray* indices) { - NDArray::prepareSpecialUse({values, indices}, {input}); - auto yType = indices == nullptr ? sd::DataType::INT64 : indices->dataType(); - BUILD_DOUBLE_SELECTOR(input->dataType(), yType, maxPoolingFunctor_, (block, input, values, params, indices), LIBND4J_TYPES, INDEXING_TYPES); - NDArray::registerSpecialUse({values, indices}, {input}); - } - + */ + } } + +void maxPoolingFunctor(sd::LaunchContext* context, sd::graph::Context& block, + NDArray* input, NDArray* values, + std::vector const& params, NDArray* indices) { + NDArray::prepareSpecialUse({values, indices}, {input}); + auto yType = indices == nullptr ? sd::DataType::INT64 : indices->dataType(); + BUILD_DOUBLE_SELECTOR(input->dataType(), yType, maxPoolingFunctor_, + (block, input, values, params, indices), LIBND4J_TYPES, + INDEXING_TYPES); + NDArray::registerSpecialUse({values, indices}, {input}); } -} \ No newline at end of file + +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/maximum.cu b/libnd4j/include/ops/declarable/helpers/cuda/maximum.cu index c4c8783ffe38..2177eb0282e4 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/maximum.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/maximum.cu @@ -18,90 +18,87 @@ // @author sgazeos@gmail.com // -#include #include #include - +#include namespace sd { - namespace ops { - namespace helpers { - - template - void maximumBPFunctor_(NDArray* x, NDArray* y, NDArray* epsNext, NDArray* gradX, NDArray* gradY) { - - auto lambdaX = LAMBDA_TTT(_e, _x, _y) { - return _x >= _y ? _e : (T) 0.; - }; +namespace ops { +namespace helpers { - auto lambdaY = LAMBDA_TTT(_e, _x, _y) { - return _x <= _y ? _e : (T) 0.; - }; +template +void maximumBPFunctor_(NDArray* x, NDArray* y, NDArray* epsNext, NDArray* gradX, + NDArray* gradY) { + auto lambdaX = LAMBDA_TTT(_e, _x, _y) { return _x >= _y ? _e : (T)0.; }; + auto lambdaY = LAMBDA_TTT(_e, _x, _y) { return _x <= _y ? _e : (T)0.; }; - if (x->isSameShape(y)) { - // PWT case case + if (x->isSameShape(y)) { + // PWT case case - // X gradient - epsNext->applyTriplewiseLambda(*x, *y, lambdaX, *gradX); + // X gradient + epsNext->applyTriplewiseLambda(*x, *y, lambdaX, *gradX); - // Y gradient - epsNext->applyTriplewiseLambda(*x, *y, lambdaY, *gradY); + // Y gradient + epsNext->applyTriplewiseLambda(*x, *y, lambdaY, *gradY); - } else if (y->isScalar()) { - T s = y->e(0); - auto lambdaS = LAMBDA_TT(_e, _x, s) { - return _x >= s ? _e : (T) 0.; - }; + } else if (y->isScalar()) { + T s = y->e(0); + auto lambdaS = LAMBDA_TT(_e, _x, s) { return _x >= s ? _e : (T)0.; }; - // scalar case - auto tmp = epsNext->reduceNumber(reduce::Sum); - if (x <= y) - gradY->assign(tmp); - else - gradY->assign(0.0f); + // scalar case + auto tmp = epsNext->reduceNumber(reduce::Sum); + if (x <= y) + gradY->assign(tmp); + else + gradY->assign(0.0f); - epsNext->applyPairwiseLambda(*x, lambdaS, *gradX); - } else { - // broadcast case + epsNext->applyPairwiseLambda(*x, lambdaS, *gradX); + } else { + // broadcast case - // in this case we want to boost our X and Y shapes to the size of FF pass output (or epsNext, which has the same shape) - auto preX = x->dup(); - auto preY = y->dup(); + // in this case we want to boost our X and Y shapes to the size of FF pass + // output (or epsNext, which has the same shape) + auto preX = x->dup(); + auto preY = y->dup(); - auto targetShape = epsNext->getShapeAsVector(); + auto targetShape = epsNext->getShapeAsVector(); - preX.tileToShape(targetShape, preX); - preY.tileToShape(targetShape, preY); + preX.tileToShape(targetShape, preX); + preY.tileToShape(targetShape, preY); - epsNext->applyTriplewiseLambda(preX, preY, lambdaX, preX); - epsNext->applyTriplewiseLambda(preX, preY, lambdaY, preY); + epsNext->applyTriplewiseLambda(preX, preY, lambdaX, preX); + epsNext->applyTriplewiseLambda(preX, preY, lambdaY, preY); - auto axisX = ShapeUtils::evalBroadcastBackwardAxis(x->shapeInfo(), epsNext->shapeInfo()); - auto axisY = ShapeUtils::evalBroadcastBackwardAxis(y->shapeInfo(), epsNext->shapeInfo()); + auto axisX = ShapeUtils::evalBroadcastBackwardAxis(x->shapeInfo(), + epsNext->shapeInfo()); + auto axisY = ShapeUtils::evalBroadcastBackwardAxis(y->shapeInfo(), + epsNext->shapeInfo()); - if (axisX.size() > 0) { - auto sum = preX.reduceAlongDimension(reduce::Sum, axisX); - gradX->assign(sum); - } else - gradX->assign(preX); + if (axisX.size() > 0) { + auto sum = preX.reduceAlongDimension(reduce::Sum, axisX); + gradX->assign(sum); + } else + gradX->assign(preX); - if (axisY.size() > 0) { - auto sum = preY.reduceAlongDimension(reduce::Sum, axisY); - gradY->assign(sum); - } else - gradY->assign(preY); - } - } + if (axisY.size() > 0) { + auto sum = preY.reduceAlongDimension(reduce::Sum, axisY); + gradY->assign(sum); + } else + gradY->assign(preY); + } +} - void maximumBPFunctor(sd::LaunchContext * context, NDArray* x, NDArray* y, NDArray* epsNext, NDArray* gradX, NDArray* gradY) { - NDArray::prepareSpecialUse({gradX, gradY}, {x, y, epsNext}); +void maximumBPFunctor(sd::LaunchContext* context, NDArray* x, NDArray* y, + NDArray* epsNext, NDArray* gradX, NDArray* gradY) { + NDArray::prepareSpecialUse({gradX, gradY}, {x, y, epsNext}); - BUILD_SINGLE_SELECTOR(x->dataType(), maximumBPFunctor_, (x, y, epsNext, gradX, gradY), NUMERIC_TYPES); + BUILD_SINGLE_SELECTOR(x->dataType(), maximumBPFunctor_, + (x, y, epsNext, gradX, gradY), NUMERIC_TYPES); - NDArray::registerSpecialUse({gradX, gradY}, {x, y, epsNext}); - } + NDArray::registerSpecialUse({gradX, gradY}, {x, y, epsNext}); +} - } - } -} \ No newline at end of file +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/merge.cu b/libnd4j/include/ops/declarable/helpers/cuda/merge.cu index 3c580ee339eb..383cfb657905 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/merge.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/merge.cu @@ -14,533 +14,599 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - // - // @author Yurii Shyrma (iuriish@yahoo.com), created on 20.04.2018 - // +// +// @author Yurii Shyrma (iuriish@yahoo.com), created on 20.04.2018 +// - -#include -#include -#include -#include #include -#include +#include #include -#include #include +#include +#include +#include +#include + +#include namespace sd { - namespace ops { - namespace helpers { - ////////////////////////////////////////////////////////////////////////// - template - static __global__ void mergeMaxIndexCudaLauncher(void** inArrs, void** inShapes, const int numArrays, void* voutput, const Nd4jLong* outputShape, Nd4jLong length) { - auto output = reinterpret_cast(voutput); - - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - const auto step = gridDim.x * blockDim.x; - - for (Nd4jLong e = tid; e < length; e += step) { - T mVal = -DataTypeUtils::max(); - Z mIdx(0); - - for (int i = 0; i < numArrays; i++) { - auto x = reinterpret_cast(inArrs[i]); - auto xShape = reinterpret_cast(inShapes[i]); - auto val = x[shape::getIndexOffset(e, xShape)];; - if (mVal < val) { - mIdx = static_cast(i); - mVal = val; - } - } - - output[shape::getIndexOffset(e, outputShape)] = mIdx; - } - } - - template - static void mergeMaxIndex_(sd::LaunchContext* context, const std::vector& inArrs, NDArray& output) { - - int nArrSize = static_cast(inArrs.size()); - std::vector inBuffers(nArrSize), inShapes(nArrSize); - - for (int e = 0; e < nArrSize; e++) { - inBuffers[e] = inArrs[e]->specialBuffer(); - inShapes[e] = inArrs[e]->specialShapeInfo(); - } - - PointersManager manager(context, "mergeMaxIndex"); - - auto pInBuffers = reinterpret_cast(manager.replicatePointer(inBuffers.data(), inBuffers.size() * sizeof(void*))); - auto pInShapes = reinterpret_cast(manager.replicatePointer(inShapes.data(), inShapes.size() * sizeof(void*))); - auto length = output.lengthOf(); - - const int threadsPerBlock = MAX_NUM_THREADS / 2; - const int blocksPerGrid = (length + threadsPerBlock - 1) / threadsPerBlock; - - mergeMaxIndexCudaLauncher<<getCudaStream()>>>(pInBuffers, pInShapes, nArrSize, output.specialBuffer(), output.specialShapeInfo(), length); - - manager.synchronize(); - } - - void mergeMaxIndex(sd::LaunchContext* context, const std::vector& inArrs, NDArray& output) { - - NDArray::prepareSpecialUse({ &output }, inArrs); - - BUILD_DOUBLE_SELECTOR(inArrs[0]->dataType(), output.dataType(), mergeMaxIndex_, (context, inArrs, output), LIBND4J_TYPES, INDEXING_TYPES); - - NDArray::registerSpecialUse({ &output }, inArrs); - } - - - ////////////////////////////////////////////////////////////////////////// - template - static __global__ void mergeMaxCudaLauncher(void** inArrs, void** inShapes, const int numArrays, void* voutput, const Nd4jLong* outputShape, Nd4jLong length) { - auto output = reinterpret_cast(voutput); - - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - const auto step = gridDim.x * blockDim.x; - - for (Nd4jLong e = tid; e < length; e += step) { - T mVal = -DataTypeUtils::max(); - - for (int i = 0; i < numArrays; i++) { - auto x = reinterpret_cast(inArrs[i]); - auto xShape = reinterpret_cast(inShapes[i]); - auto val = x[shape::getIndexOffset(e, xShape)];; - if (mVal < val) - mVal = val; - } - - output[shape::getIndexOffset(e, outputShape)] = mVal; - } - } - - template - static void mergeMax_(sd::LaunchContext* context, const std::vector& inArrs, NDArray& output) { - - int nArrsSize = static_cast(inArrs.size()); - - std::vector inBuffers(nArrsSize), inShapes(nArrsSize); - - for (int e = 0; e < nArrsSize; e++) { - inBuffers[e] = inArrs[e]->specialBuffer(); - inShapes[e] = inArrs[e]->specialShapeInfo(); - } - - PointersManager manager(context, "mergeMax"); - - auto pInBuffers = reinterpret_cast(manager.replicatePointer(inBuffers.data(), inBuffers.size() * sizeof(void*))); - auto pInShapes = reinterpret_cast(manager.replicatePointer(inShapes.data(), inShapes.size() * sizeof(void*))); - auto length = output.lengthOf(); - - const int threadsPerBlock = MAX_NUM_THREADS / 2; - const int blocksPerGrid = (length + threadsPerBlock - 1) / threadsPerBlock; - - mergeMaxCudaLauncher<<getCudaStream()>>>(pInBuffers, pInShapes, nArrsSize, output.specialBuffer(), output.specialShapeInfo(), length); - - manager.synchronize(); - } - - void mergeMax(sd::LaunchContext* context, const std::vector& inArrs, NDArray& output) { - - NDArray::prepareSpecialUse({ &output }, inArrs); - - BUILD_SINGLE_SELECTOR(output.dataType(), mergeMax_, (context, inArrs, output), LIBND4J_TYPES); - - NDArray::registerSpecialUse({ &output }, inArrs); - } - - ////////////////////////////////////////////////////////////////////////// - template - static __global__ void mergeMaxBpCudaLauncher( - void** inArrs, void** inShapes, - const void* vgradient, const Nd4jLong* gradientShape, - const int numArrays, - void** outArrs, void** outShapes, - Nd4jLong length, - bool bSameOrderAndEws1) { +namespace ops { +namespace helpers { +////////////////////////////////////////////////////////////////////////// +template +static __global__ void mergeMaxIndexCudaLauncher(void** inArrs, void** inShapes, + const int numArrays, + void* voutput, + const Nd4jLong* outputShape, + Nd4jLong length) { + auto output = reinterpret_cast(voutput); + + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + const auto step = gridDim.x * blockDim.x; + + for (Nd4jLong e = tid; e < length; e += step) { + T mVal = -DataTypeUtils::max(); + Z mIdx(0); + + for (int i = 0; i < numArrays; i++) { + auto x = reinterpret_cast(inArrs[i]); + auto xShape = reinterpret_cast(inShapes[i]); + auto val = x[shape::getIndexOffset(e, xShape)]; + ; + if (mVal < val) { + mIdx = static_cast(i); + mVal = val; + } + } + + output[shape::getIndexOffset(e, outputShape)] = mIdx; + } +} - auto grad = reinterpret_cast(vgradient); +template +static void mergeMaxIndex_(sd::LaunchContext* context, + const std::vector& inArrs, + NDArray& output) { + int nArrSize = static_cast(inArrs.size()); + std::vector inBuffers(nArrSize), inShapes(nArrSize); - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - const auto step = gridDim.x * blockDim.x; + for (int e = 0; e < nArrSize; e++) { + inBuffers[e] = inArrs[e]->specialBuffer(); + inShapes[e] = inArrs[e]->specialShapeInfo(); + } - int coords[MAX_RANK]; + PointersManager manager(context, "mergeMaxIndex"); - for (Nd4jLong e = tid; e < length; e += step) { + auto pInBuffers = reinterpret_cast(manager.replicatePointer( + inBuffers.data(), inBuffers.size() * sizeof(void*))); + auto pInShapes = reinterpret_cast(manager.replicatePointer( + inShapes.data(), inShapes.size() * sizeof(void*))); + auto length = output.lengthOf(); - T mVal = -DataTypeUtils::max(); - int nMaxIndex = 0; - auto xOffset = e, zOffset = e, gradOffset = e; + const int threadsPerBlock = MAX_NUM_THREADS / 2; + const int blocksPerGrid = (length + threadsPerBlock - 1) / threadsPerBlock; + + mergeMaxIndexCudaLauncher + <<getCudaStream()>>>( + pInBuffers, pInShapes, nArrSize, output.specialBuffer(), + output.specialShapeInfo(), length); + + manager.synchronize(); +} - if (!bSameOrderAndEws1) { - shape::index2coords(e, gradientShape, coords); - gradOffset = shape::getOffset(gradientShape, coords); - } +void mergeMaxIndex(sd::LaunchContext* context, + const std::vector& inArrs, NDArray& output) { + NDArray::prepareSpecialUse({&output}, inArrs); - for (int i = 0; i < numArrays; i++) { - auto x = reinterpret_cast(inArrs[i]); + BUILD_DOUBLE_SELECTOR(inArrs[0]->dataType(), output.dataType(), + mergeMaxIndex_, (context, inArrs, output), + LIBND4J_TYPES, INDEXING_TYPES); - if (!bSameOrderAndEws1) { - auto xShape = reinterpret_cast(inShapes[i]); - xOffset = shape::getOffset(xShape, coords); - } + NDArray::registerSpecialUse({&output}, inArrs); +} - auto val = x[xOffset]; - if (mVal < val) { - mVal = val; - nMaxIndex = i; - } - } - - // outputs have to be pre-nullify - if (!bSameOrderAndEws1) { - auto outShape = reinterpret_cast(outShapes[nMaxIndex]); - zOffset = shape::getOffset(outShape, coords); - } +////////////////////////////////////////////////////////////////////////// +template +static __global__ void mergeMaxCudaLauncher(void** inArrs, void** inShapes, + const int numArrays, void* voutput, + const Nd4jLong* outputShape, + Nd4jLong length) { + auto output = reinterpret_cast(voutput); + + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + const auto step = gridDim.x * blockDim.x; + + for (Nd4jLong e = tid; e < length; e += step) { + T mVal = -DataTypeUtils::max(); + + for (int i = 0; i < numArrays; i++) { + auto x = reinterpret_cast(inArrs[i]); + auto xShape = reinterpret_cast(inShapes[i]); + auto val = x[shape::getIndexOffset(e, xShape)]; + ; + if (mVal < val) mVal = val; + } - auto output = reinterpret_cast(outArrs[nMaxIndex]); + output[shape::getIndexOffset(e, outputShape)] = mVal; + } +} - output[zOffset] = grad[gradOffset]; - } - } +template +static void mergeMax_(sd::LaunchContext* context, + const std::vector& inArrs, + NDArray& output) { + int nArrsSize = static_cast(inArrs.size()); - template - static void mergeMaxBp_(sd::LaunchContext* context, const std::vector& inArrs, std::vector& outArrs, int nArrSize, bool bSameOrderAndEws1) { + std::vector inBuffers(nArrsSize), inShapes(nArrsSize); - std::vector inBuffers(nArrSize), inShapes(nArrSize), outBuffers(nArrSize), outShapes(nArrSize); + for (int e = 0; e < nArrsSize; e++) { + inBuffers[e] = inArrs[e]->specialBuffer(); + inShapes[e] = inArrs[e]->specialShapeInfo(); + } - for (int e = 0; e < nArrSize; e++) { - inBuffers[e] = inArrs[e]->specialBuffer(); - inShapes[e] = inArrs[e]->specialShapeInfo(); - outBuffers[e] = outArrs[e]->specialBuffer(); - outShapes[e] = outArrs[e]->specialShapeInfo(); - } + PointersManager manager(context, "mergeMax"); - PointersManager manager(context, "mergeMaxBp"); + auto pInBuffers = reinterpret_cast(manager.replicatePointer( + inBuffers.data(), inBuffers.size() * sizeof(void*))); + auto pInShapes = reinterpret_cast(manager.replicatePointer( + inShapes.data(), inShapes.size() * sizeof(void*))); + auto length = output.lengthOf(); - auto pInBuffers = reinterpret_cast(manager.replicatePointer(inBuffers.data(), inBuffers.size() * sizeof(void*))); - auto pInShapes = reinterpret_cast(manager.replicatePointer(inShapes.data(), inShapes.size() * sizeof(void*))); + const int threadsPerBlock = MAX_NUM_THREADS / 2; + const int blocksPerGrid = (length + threadsPerBlock - 1) / threadsPerBlock; - auto pOutBuffers = reinterpret_cast(manager.replicatePointer(outBuffers.data(), outBuffers.size() * sizeof(void*))); - auto pOutShapes = reinterpret_cast(manager.replicatePointer(outShapes.data(), outShapes.size() * sizeof(void*))); + mergeMaxCudaLauncher + <<getCudaStream()>>>( + pInBuffers, pInShapes, nArrsSize, output.specialBuffer(), + output.specialShapeInfo(), length); - auto length = inArrs[nArrSize]->lengthOf(); + manager.synchronize(); +} - const int threadsPerBlock = MAX_NUM_THREADS / 2; - const int blocksPerGrid = (length + threadsPerBlock - 1) / threadsPerBlock; +void mergeMax(sd::LaunchContext* context, + const std::vector& inArrs, NDArray& output) { + NDArray::prepareSpecialUse({&output}, inArrs); - mergeMaxBpCudaLauncher<<getCudaStream()>>>(pInBuffers, pInShapes, inArrs[nArrSize]->specialBuffer(), - inArrs[nArrSize]->specialShapeInfo(), nArrSize, pOutBuffers, pOutShapes, - length, bSameOrderAndEws1); - - manager.synchronize(); - } + BUILD_SINGLE_SELECTOR(output.dataType(), mergeMax_, (context, inArrs, output), + LIBND4J_TYPES); - void mergeMaxBp(sd::LaunchContext* context, const std::vector& inArrs, std::vector& outArrs) { + NDArray::registerSpecialUse({&output}, inArrs); +} - // not use gradient - int nArrSize = static_cast(inArrs.size() - 1); - - const std::vector& out = reinterpret_cast&>(outArrs); +////////////////////////////////////////////////////////////////////////// +template +static __global__ void mergeMaxBpCudaLauncher( + void** inArrs, void** inShapes, const void* vgradient, + const Nd4jLong* gradientShape, const int numArrays, void** outArrs, + void** outShapes, Nd4jLong length, bool bSameOrderAndEws1) { + auto grad = reinterpret_cast(vgradient); - NDArray::prepareSpecialUse(out, inArrs); + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + const auto step = gridDim.x * blockDim.x; - bool bSameOrderAndEws1 = (1 == inArrs[nArrSize]->ews()); - auto ordering = inArrs[nArrSize]->ordering(); - - for (int i = 0; i < nArrSize; ++i) { - bSameOrderAndEws1 &= (ordering == inArrs[i]->ordering()); - bSameOrderAndEws1 &= (1 == inArrs[i]->ews()); - - bSameOrderAndEws1 &= (ordering == outArrs[i]->ordering()); - bSameOrderAndEws1 &= (1 == outArrs[i]->ews()); - } + int coords[MAX_RANK]; - BUILD_SINGLE_SELECTOR(inArrs[nArrSize]->dataType(), mergeMaxBp_, (context, inArrs, outArrs, nArrSize, bSameOrderAndEws1), LIBND4J_TYPES); + for (Nd4jLong e = tid; e < length; e += step) { + T mVal = -DataTypeUtils::max(); + int nMaxIndex = 0; + auto xOffset = e, zOffset = e, gradOffset = e; - NDArray::registerSpecialUse( out, inArrs ); - } + if (!bSameOrderAndEws1) { + shape::index2coords(e, gradientShape, coords); + gradOffset = shape::getOffset(gradientShape, coords); + } + for (int i = 0; i < numArrays; i++) { + auto x = reinterpret_cast(inArrs[i]); - ////////////////////////////////////////////////////////////////////////// - template - static __global__ void mergeAvgCudaLauncher(void** inArrs, void** inShapes, const int numArrays, void* voutput, const Nd4jLong* outputShape, Nd4jLong length) { - auto output = reinterpret_cast(voutput); + if (!bSameOrderAndEws1) { + auto xShape = reinterpret_cast(inShapes[i]); + xOffset = shape::getOffset(xShape, coords); + } - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - const auto step = gridDim.x * blockDim.x; + auto val = x[xOffset]; + if (mVal < val) { + mVal = val; + nMaxIndex = i; + } + } - for (Nd4jLong e = tid; e < length; e += step) { - T sum(0.0f); + // outputs have to be pre-nullify + if (!bSameOrderAndEws1) { + auto outShape = reinterpret_cast(outShapes[nMaxIndex]); + zOffset = shape::getOffset(outShape, coords); + } - for (int i = 0; i < numArrays; i++) { - auto x = reinterpret_cast(inArrs[i]); - auto xShape = reinterpret_cast(inShapes[i]); + auto output = reinterpret_cast(outArrs[nMaxIndex]); - sum += x[shape::getIndexOffset(e, xShape)]; - } + output[zOffset] = grad[gradOffset]; + } +} - output[shape::getIndexOffset(e, outputShape)] = sum / numArrays; - } - } +template +static void mergeMaxBp_(sd::LaunchContext* context, + const std::vector& inArrs, + std::vector& outArrs, int nArrSize, + bool bSameOrderAndEws1) { + std::vector inBuffers(nArrSize), inShapes(nArrSize), + outBuffers(nArrSize), outShapes(nArrSize); + + for (int e = 0; e < nArrSize; e++) { + inBuffers[e] = inArrs[e]->specialBuffer(); + inShapes[e] = inArrs[e]->specialShapeInfo(); + outBuffers[e] = outArrs[e]->specialBuffer(); + outShapes[e] = outArrs[e]->specialShapeInfo(); + } + + PointersManager manager(context, "mergeMaxBp"); + + auto pInBuffers = reinterpret_cast(manager.replicatePointer( + inBuffers.data(), inBuffers.size() * sizeof(void*))); + auto pInShapes = reinterpret_cast(manager.replicatePointer( + inShapes.data(), inShapes.size() * sizeof(void*))); + + auto pOutBuffers = reinterpret_cast(manager.replicatePointer( + outBuffers.data(), outBuffers.size() * sizeof(void*))); + auto pOutShapes = reinterpret_cast(manager.replicatePointer( + outShapes.data(), outShapes.size() * sizeof(void*))); + + auto length = inArrs[nArrSize]->lengthOf(); + + const int threadsPerBlock = MAX_NUM_THREADS / 2; + const int blocksPerGrid = (length + threadsPerBlock - 1) / threadsPerBlock; + + mergeMaxBpCudaLauncher + <<getCudaStream()>>>( + pInBuffers, pInShapes, inArrs[nArrSize]->specialBuffer(), + inArrs[nArrSize]->specialShapeInfo(), nArrSize, pOutBuffers, + pOutShapes, length, bSameOrderAndEws1); + + manager.synchronize(); +} - template - static void mergeAvg_(sd::LaunchContext* context, const std::vector& inArrs, NDArray& output) { - - std::vector inBuffers(inArrs.size()), inShapes(inArrs.size()); +void mergeMaxBp(sd::LaunchContext* context, + const std::vector& inArrs, + std::vector& outArrs) { + // not use gradient + int nArrSize = static_cast(inArrs.size() - 1); - for (int e = 0; e < inArrs.size(); e++) { - inBuffers[e] = inArrs[e]->specialBuffer(); - inShapes[e] = inArrs[e]->specialShapeInfo(); - } + const std::vector& out = + reinterpret_cast&>(outArrs); - PointersManager manager(context, "mergeAvg"); + NDArray::prepareSpecialUse(out, inArrs); - auto pInBuffers = reinterpret_cast(manager.replicatePointer(inBuffers.data(), inBuffers.size() * sizeof(void*))); - auto pInShapes = reinterpret_cast(manager.replicatePointer(inShapes.data(), inShapes.size() * sizeof(void*))); - auto length = output.lengthOf(); + bool bSameOrderAndEws1 = (1 == inArrs[nArrSize]->ews()); + auto ordering = inArrs[nArrSize]->ordering(); - const int threadsPerBlock = MAX_NUM_THREADS / 2; - const int blocksPerGrid = (length + threadsPerBlock - 1) / threadsPerBlock; + for (int i = 0; i < nArrSize; ++i) { + bSameOrderAndEws1 &= (ordering == inArrs[i]->ordering()); + bSameOrderAndEws1 &= (1 == inArrs[i]->ews()); - mergeAvgCudaLauncher<<getCudaStream()>>>(pInBuffers, pInShapes, (int)inArrs.size(), output.specialBuffer(), output.specialShapeInfo(), length); + bSameOrderAndEws1 &= (ordering == outArrs[i]->ordering()); + bSameOrderAndEws1 &= (1 == outArrs[i]->ews()); + } - manager.synchronize(); - } + BUILD_SINGLE_SELECTOR(inArrs[nArrSize]->dataType(), mergeMaxBp_, + (context, inArrs, outArrs, nArrSize, bSameOrderAndEws1), + LIBND4J_TYPES); - void mergeAvg(sd::LaunchContext* context, const std::vector& inArrs, NDArray& output) { - - NDArray::prepareSpecialUse({ &output }, inArrs); + NDArray::registerSpecialUse(out, inArrs); +} - BUILD_SINGLE_SELECTOR(output.dataType(), mergeAvg_, (context, inArrs, output), FLOAT_TYPES); +////////////////////////////////////////////////////////////////////////// +template +static __global__ void mergeAvgCudaLauncher(void** inArrs, void** inShapes, + const int numArrays, void* voutput, + const Nd4jLong* outputShape, + Nd4jLong length) { + auto output = reinterpret_cast(voutput); - NDArray::registerSpecialUse({ &output }, inArrs); - } - ////////////////////////////////////////////////////////////////////////// - template - static __global__ void mergeAvgBpCudaLauncher( - const void* vgradient, const Nd4jLong* gradientShape, - void** outArrs, void** outShapes, - const int numArrays, - Nd4jLong length, - bool bSameOrderAndEws1) { + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + const auto step = gridDim.x * blockDim.x; - auto grad = reinterpret_cast(vgradient); + for (Nd4jLong e = tid; e < length; e += step) { + T sum(0.0f); - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - const auto step = gridDim.x * blockDim.x; + for (int i = 0; i < numArrays; i++) { + auto x = reinterpret_cast(inArrs[i]); + auto xShape = reinterpret_cast(inShapes[i]); - int coords[MAX_RANK]; + sum += x[shape::getIndexOffset(e, xShape)]; + } - for (Nd4jLong e = tid; e < length; e += step) { + output[shape::getIndexOffset(e, outputShape)] = sum / numArrays; + } +} - auto zOffset = e, gradOffset = e; - if (!bSameOrderAndEws1) { - shape::index2coords(e, gradientShape, coords); - gradOffset = shape::getOffset(gradientShape, coords); - } +template +static void mergeAvg_(sd::LaunchContext* context, + const std::vector& inArrs, + NDArray& output) { + std::vector inBuffers(inArrs.size()), inShapes(inArrs.size()); - for (int i = 0; i < numArrays; i++) { + for (int e = 0; e < inArrs.size(); e++) { + inBuffers[e] = inArrs[e]->specialBuffer(); + inShapes[e] = inArrs[e]->specialShapeInfo(); + } - if (!bSameOrderAndEws1) { - auto outShape = reinterpret_cast(outShapes[i]); - zOffset = shape::getOffset(outShape, coords); - } + PointersManager manager(context, "mergeAvg"); - auto output = reinterpret_cast(outArrs[i]); + auto pInBuffers = reinterpret_cast(manager.replicatePointer( + inBuffers.data(), inBuffers.size() * sizeof(void*))); + auto pInShapes = reinterpret_cast(manager.replicatePointer( + inShapes.data(), inShapes.size() * sizeof(void*))); + auto length = output.lengthOf(); - output[zOffset] = grad[gradOffset] / numArrays; - } - } - } + const int threadsPerBlock = MAX_NUM_THREADS / 2; + const int blocksPerGrid = (length + threadsPerBlock - 1) / threadsPerBlock; - template - static void mergeAvgBp_(sd::LaunchContext* context, const NDArray& gradient, std::vector& outArrs, bool bSameOrderAndEws1) { + mergeAvgCudaLauncher + <<getCudaStream()>>>( + pInBuffers, pInShapes, (int)inArrs.size(), output.specialBuffer(), + output.specialShapeInfo(), length); - int nArrSize = static_cast(outArrs.size()); + manager.synchronize(); +} - std::vector outBuffers(nArrSize), outShapes(nArrSize); +void mergeAvg(sd::LaunchContext* context, + const std::vector& inArrs, NDArray& output) { + NDArray::prepareSpecialUse({&output}, inArrs); - for (int e = 0; e < nArrSize; e++) { - outBuffers[e] = outArrs[e]->specialBuffer(); - outShapes[e] = outArrs[e]->specialShapeInfo(); - } + BUILD_SINGLE_SELECTOR(output.dataType(), mergeAvg_, (context, inArrs, output), + FLOAT_TYPES); - PointersManager manager(context, "mergeAvgBp"); + NDArray::registerSpecialUse({&output}, inArrs); +} +////////////////////////////////////////////////////////////////////////// +template +static __global__ void mergeAvgBpCudaLauncher(const void* vgradient, + const Nd4jLong* gradientShape, + void** outArrs, void** outShapes, + const int numArrays, + Nd4jLong length, + bool bSameOrderAndEws1) { + auto grad = reinterpret_cast(vgradient); + + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + const auto step = gridDim.x * blockDim.x; + + int coords[MAX_RANK]; + + for (Nd4jLong e = tid; e < length; e += step) { + auto zOffset = e, gradOffset = e; + if (!bSameOrderAndEws1) { + shape::index2coords(e, gradientShape, coords); + gradOffset = shape::getOffset(gradientShape, coords); + } - auto pOutBuffers = reinterpret_cast(manager.replicatePointer(outBuffers.data(), outBuffers.size() * sizeof(void*))); - auto pOutShapes = reinterpret_cast(manager.replicatePointer(outShapes.data(), outShapes.size() * sizeof(void*))); + for (int i = 0; i < numArrays; i++) { + if (!bSameOrderAndEws1) { + auto outShape = reinterpret_cast(outShapes[i]); + zOffset = shape::getOffset(outShape, coords); + } - auto length = gradient.lengthOf(); - - const int threadsPerBlock = MAX_NUM_THREADS / 2; - const int blocksPerGrid = (length + threadsPerBlock - 1) / threadsPerBlock; + auto output = reinterpret_cast(outArrs[i]); - mergeAvgBpCudaLauncher<<getCudaStream()>>>(gradient.specialBuffer(), gradient.specialShapeInfo(), - pOutBuffers, pOutShapes, nArrSize, length, bSameOrderAndEws1); + output[zOffset] = grad[gradOffset] / numArrays; + } + } +} - manager.synchronize(); - } +template +static void mergeAvgBp_(sd::LaunchContext* context, const NDArray& gradient, + std::vector& outArrs, + bool bSameOrderAndEws1) { + int nArrSize = static_cast(outArrs.size()); - void mergeAvgBp(sd::LaunchContext* context, const NDArray& gradient, std::vector& outArrs) { + std::vector outBuffers(nArrSize), outShapes(nArrSize); - const std::vector& out = reinterpret_cast&>(outArrs); + for (int e = 0; e < nArrSize; e++) { + outBuffers[e] = outArrs[e]->specialBuffer(); + outShapes[e] = outArrs[e]->specialShapeInfo(); + } - NDArray::prepareSpecialUse( out, { &gradient }); + PointersManager manager(context, "mergeAvgBp"); - bool bSameOrderAndEws1 = (1 == gradient.ews()); - auto ordering = gradient.ordering(); + auto pOutBuffers = reinterpret_cast(manager.replicatePointer( + outBuffers.data(), outBuffers.size() * sizeof(void*))); + auto pOutShapes = reinterpret_cast(manager.replicatePointer( + outShapes.data(), outShapes.size() * sizeof(void*))); - for (const auto& v : outArrs) { - bSameOrderAndEws1 &= (ordering == v->ordering()); - bSameOrderAndEws1 &= (1 == v->ews()); - } + auto length = gradient.lengthOf(); - BUILD_SINGLE_SELECTOR(gradient.dataType(), mergeAvgBp_, (context, gradient, outArrs, bSameOrderAndEws1), LIBND4J_TYPES); + const int threadsPerBlock = MAX_NUM_THREADS / 2; + const int blocksPerGrid = (length + threadsPerBlock - 1) / threadsPerBlock; - NDArray::prepareSpecialUse(out, { &gradient }); - } + mergeAvgBpCudaLauncher + <<getCudaStream()>>>( + gradient.specialBuffer(), gradient.specialShapeInfo(), pOutBuffers, + pOutShapes, nArrSize, length, bSameOrderAndEws1); - ////////////////////////////////////////////////////////////////////////// - template - static __global__ void mergeAddCudaLauncher(void** inArrs, void** inShapes, const int numArrays, void* voutput, const Nd4jLong* outputShape, Nd4jLong length) { - - auto output = reinterpret_cast(voutput); + manager.synchronize(); +} - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - const auto step = gridDim.x * blockDim.x; +void mergeAvgBp(sd::LaunchContext* context, const NDArray& gradient, + std::vector& outArrs) { + const std::vector& out = + reinterpret_cast&>(outArrs); - for (Nd4jLong e = tid; e < length; e += step) { - T sum(0.0f); + NDArray::prepareSpecialUse(out, {&gradient}); - for (int i = 0; i < numArrays; i++) { - auto x = reinterpret_cast(inArrs[i]); - auto xShape = reinterpret_cast(inShapes[i]); + bool bSameOrderAndEws1 = (1 == gradient.ews()); + auto ordering = gradient.ordering(); - sum += x[shape::getIndexOffset(e, xShape)]; - } + for (const auto& v : outArrs) { + bSameOrderAndEws1 &= (ordering == v->ordering()); + bSameOrderAndEws1 &= (1 == v->ews()); + } - output[shape::getIndexOffset(e, outputShape)] = sum; - } - } + BUILD_SINGLE_SELECTOR(gradient.dataType(), mergeAvgBp_, + (context, gradient, outArrs, bSameOrderAndEws1), + LIBND4J_TYPES); - template - static void mergeAdd_(sd::LaunchContext* context, const std::vector& inArrs, NDArray& output) { - - int nArrSize = static_cast(inArrs.size()); - std::vector inBuffers(nArrSize), inShapes(nArrSize); + NDArray::prepareSpecialUse(out, {&gradient}); +} - for (int e = 0; e < nArrSize; e++) { - inBuffers[e] = inArrs[e]->specialBuffer(); - inShapes[e] = inArrs[e]->specialShapeInfo(); - } +////////////////////////////////////////////////////////////////////////// +template +static __global__ void mergeAddCudaLauncher(void** inArrs, void** inShapes, + const int numArrays, void* voutput, + const Nd4jLong* outputShape, + Nd4jLong length) { + auto output = reinterpret_cast(voutput); - PointersManager manager(context, "mergeAdd"); + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + const auto step = gridDim.x * blockDim.x; - auto pInBuffers = reinterpret_cast(manager.replicatePointer(inBuffers.data(), inBuffers.size() * sizeof(void*))); - auto pInShapes = reinterpret_cast(manager.replicatePointer(inShapes.data(), inShapes.size() * sizeof(void*))); - auto length = output.lengthOf(); + for (Nd4jLong e = tid; e < length; e += step) { + T sum(0.0f); - const int threadsPerBlock = MAX_NUM_THREADS / 2; - const int blocksPerGrid = (length + threadsPerBlock - 1) / threadsPerBlock; + for (int i = 0; i < numArrays; i++) { + auto x = reinterpret_cast(inArrs[i]); + auto xShape = reinterpret_cast(inShapes[i]); - mergeAddCudaLauncher<<getCudaStream()>>>(pInBuffers, pInShapes, nArrSize, output.specialBuffer(), output.specialShapeInfo(), length); + sum += x[shape::getIndexOffset(e, xShape)]; + } - manager.synchronize(); - } - BUILD_SINGLE_TEMPLATE(template void mergeAdd_, (sd::LaunchContext* context, const std::vector& inArrs, NDArray& output), NUMERIC_TYPES); + output[shape::getIndexOffset(e, outputShape)] = sum; + } +} - void mergeAdd(sd::LaunchContext* context, const std::vector& inArrs, NDArray& output) { - - NDArray::prepareSpecialUse({ &output }, inArrs); - - BUILD_SINGLE_SELECTOR(output.dataType(), mergeAdd_, (context, inArrs, output), NUMERIC_TYPES); +template +static void mergeAdd_(sd::LaunchContext* context, + const std::vector& inArrs, + NDArray& output) { + int nArrSize = static_cast(inArrs.size()); + std::vector inBuffers(nArrSize), inShapes(nArrSize); - NDArray::registerSpecialUse({ &output }, inArrs); - } + for (int e = 0; e < nArrSize; e++) { + inBuffers[e] = inArrs[e]->specialBuffer(); + inShapes[e] = inArrs[e]->specialShapeInfo(); + } - ////////////////////////////////////////////////////////////////////////// - template - static __global__ void mergeAddBpCudaLauncher(const void* vgradient, const Nd4jLong* gradientShape, void** outArrs, void** outShapes, - const int numArrays, Nd4jLong length, bool bSameOrderAndEws1) { + PointersManager manager(context, "mergeAdd"); - auto grad = reinterpret_cast(vgradient); + auto pInBuffers = reinterpret_cast(manager.replicatePointer( + inBuffers.data(), inBuffers.size() * sizeof(void*))); + auto pInShapes = reinterpret_cast(manager.replicatePointer( + inShapes.data(), inShapes.size() * sizeof(void*))); + auto length = output.lengthOf(); - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - const auto step = gridDim.x * blockDim.x; + const int threadsPerBlock = MAX_NUM_THREADS / 2; + const int blocksPerGrid = (length + threadsPerBlock - 1) / threadsPerBlock; - int coords[MAX_RANK]; + mergeAddCudaLauncher + <<getCudaStream()>>>( + pInBuffers, pInShapes, nArrSize, output.specialBuffer(), + output.specialShapeInfo(), length); - for (Nd4jLong e = tid; e < length; e += step) { + manager.synchronize(); +} +BUILD_SINGLE_TEMPLATE(template void mergeAdd_, + (sd::LaunchContext * context, + const std::vector& inArrs, + NDArray& output), + NUMERIC_TYPES); - auto zOffset = e, gradOffset = e; - if (!bSameOrderAndEws1) { - shape::index2coords(e, gradientShape, coords); - gradOffset = shape::getOffset(gradientShape, coords); - } +void mergeAdd(sd::LaunchContext* context, + const std::vector& inArrs, NDArray& output) { + NDArray::prepareSpecialUse({&output}, inArrs); - for (int i = 0; i < numArrays; i++) { - - if (!bSameOrderAndEws1) { - auto outShape = reinterpret_cast(outShapes[i]); - zOffset = shape::getOffset(outShape, coords); - } + BUILD_SINGLE_SELECTOR(output.dataType(), mergeAdd_, (context, inArrs, output), + NUMERIC_TYPES); - auto output = reinterpret_cast(outArrs[i]); + NDArray::registerSpecialUse({&output}, inArrs); +} - output[zOffset] = grad[gradOffset]; - } - } - } +////////////////////////////////////////////////////////////////////////// +template +static __global__ void mergeAddBpCudaLauncher(const void* vgradient, + const Nd4jLong* gradientShape, + void** outArrs, void** outShapes, + const int numArrays, + Nd4jLong length, + bool bSameOrderAndEws1) { + auto grad = reinterpret_cast(vgradient); + + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + const auto step = gridDim.x * blockDim.x; + + int coords[MAX_RANK]; + + for (Nd4jLong e = tid; e < length; e += step) { + auto zOffset = e, gradOffset = e; + if (!bSameOrderAndEws1) { + shape::index2coords(e, gradientShape, coords); + gradOffset = shape::getOffset(gradientShape, coords); + } - template - static void mergeAddBp_(sd::LaunchContext* context, const NDArray& gradient, std::vector& outArrs, bool bSameOrderAndEws1) { + for (int i = 0; i < numArrays; i++) { + if (!bSameOrderAndEws1) { + auto outShape = reinterpret_cast(outShapes[i]); + zOffset = shape::getOffset(outShape, coords); + } - int nArrSize = static_cast(outArrs.size()); + auto output = reinterpret_cast(outArrs[i]); - std::vector outBuffers(nArrSize), outShapes(nArrSize); + output[zOffset] = grad[gradOffset]; + } + } +} - for (int e = 0; e < nArrSize; e++) { - outBuffers[e] = outArrs[e]->specialBuffer(); - outShapes[e] = outArrs[e]->specialShapeInfo(); - } +template +static void mergeAddBp_(sd::LaunchContext* context, const NDArray& gradient, + std::vector& outArrs, + bool bSameOrderAndEws1) { + int nArrSize = static_cast(outArrs.size()); - PointersManager manager(context, "mergeAddBp"); + std::vector outBuffers(nArrSize), outShapes(nArrSize); - auto pOutBuffers = reinterpret_cast(manager.replicatePointer(outBuffers.data(), outBuffers.size() * sizeof(void*))); - auto pOutShapes = reinterpret_cast(manager.replicatePointer(outShapes.data(), outShapes.size() * sizeof(void*))); + for (int e = 0; e < nArrSize; e++) { + outBuffers[e] = outArrs[e]->specialBuffer(); + outShapes[e] = outArrs[e]->specialShapeInfo(); + } - auto length = gradient.lengthOf(); + PointersManager manager(context, "mergeAddBp"); - const int threadsPerBlock = MAX_NUM_THREADS / 2; - const int blocksPerGrid = (length + threadsPerBlock - 1) / threadsPerBlock; + auto pOutBuffers = reinterpret_cast(manager.replicatePointer( + outBuffers.data(), outBuffers.size() * sizeof(void*))); + auto pOutShapes = reinterpret_cast(manager.replicatePointer( + outShapes.data(), outShapes.size() * sizeof(void*))); - mergeAddBpCudaLauncher<<getCudaStream()>>>(gradient.specialBuffer(), gradient.specialShapeInfo(), - pOutBuffers, pOutShapes, nArrSize, length, bSameOrderAndEws1); + auto length = gradient.lengthOf(); - manager.synchronize(); - } + const int threadsPerBlock = MAX_NUM_THREADS / 2; + const int blocksPerGrid = (length + threadsPerBlock - 1) / threadsPerBlock; - void mergeAddBp(sd::LaunchContext* context, const NDArray& gradient, std::vector& outArrs) { + mergeAddBpCudaLauncher + <<getCudaStream()>>>( + gradient.specialBuffer(), gradient.specialShapeInfo(), pOutBuffers, + pOutShapes, nArrSize, length, bSameOrderAndEws1); - const std::vector& out = reinterpret_cast& >(outArrs); - NDArray::prepareSpecialUse( out, { &gradient }); + manager.synchronize(); +} - bool bSameOrderAndEws1 = (1 == gradient.ews()); - auto ordering = gradient.ordering(); +void mergeAddBp(sd::LaunchContext* context, const NDArray& gradient, + std::vector& outArrs) { + const std::vector& out = + reinterpret_cast&>(outArrs); + NDArray::prepareSpecialUse(out, {&gradient}); - for (const auto& v : outArrs) { - bSameOrderAndEws1 &= (ordering == v->ordering()); - bSameOrderAndEws1 &= (1 == v->ews()); - } + bool bSameOrderAndEws1 = (1 == gradient.ews()); + auto ordering = gradient.ordering(); - BUILD_SINGLE_SELECTOR(gradient.dataType(), mergeAddBp_, (context, gradient, outArrs, bSameOrderAndEws1), LIBND4J_TYPES); + for (const auto& v : outArrs) { + bSameOrderAndEws1 &= (ordering == v->ordering()); + bSameOrderAndEws1 &= (1 == v->ews()); + } - NDArray::prepareSpecialUse( out, { &gradient }); - } + BUILD_SINGLE_SELECTOR(gradient.dataType(), mergeAddBp_, + (context, gradient, outArrs, bSameOrderAndEws1), + LIBND4J_TYPES); - } - } + NDArray::prepareSpecialUse(out, {&gradient}); } + +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/meshgrid.cu b/libnd4j/include/ops/declarable/helpers/cuda/meshgrid.cu index 918dca510d95..eb18064f14ae 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/meshgrid.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/meshgrid.cu @@ -18,131 +18,144 @@ // @author raver119@gmail.com // - -#include -#include -#include #include +#include +#include +#include + #include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { - template - static _CUDA_D void assign_(void *vx, Nd4jLong *xShapeInfo, void *vz, Nd4jLong *zShapeInfo) { - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - - auto tid = threadIdx.x + blockIdx.x * blockDim.x; - - auto xEws = shape::elementWiseStride(xShapeInfo); - auto zEws = shape::elementWiseStride(zShapeInfo); +template +static _CUDA_D void assign_(void *vx, Nd4jLong *xShapeInfo, void *vz, + Nd4jLong *zShapeInfo) { + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); - auto xOrder = shape::order(xShapeInfo); - auto zOrder = shape::order(zShapeInfo); + auto tid = threadIdx.x + blockIdx.x * blockDim.x; - __shared__ Nd4jLong length; + auto xEws = shape::elementWiseStride(xShapeInfo); + auto zEws = shape::elementWiseStride(zShapeInfo); - if (threadIdx.x == 0) { - length = shape::length(xShapeInfo); - } - __syncthreads(); + auto xOrder = shape::order(xShapeInfo); + auto zOrder = shape::order(zShapeInfo); - if (xEws > 0 && zEws > 0 && xOrder == zOrder) { - for (int i = threadIdx.x; i < length; i += blockDim.x) { - z[i * zEws] = x[i * xEws]; - } - } else { - for (int i = threadIdx.x; i < length; i += blockDim.x) { - auto xOffset = shape::getIndexOffset(i, xShapeInfo); - auto zOffset = shape::getIndexOffset(i, zShapeInfo); + __shared__ Nd4jLong length; - z[zOffset] = x[xOffset]; - } - } + if (threadIdx.x == 0) { + length = shape::length(xShapeInfo); + } + __syncthreads(); + if (xEws > 0 && zEws > 0 && xOrder == zOrder) { + for (int i = threadIdx.x; i < length; i += blockDim.x) { + z[i * zEws] = x[i * xEws]; } + } else { + for (int i = threadIdx.x; i < length; i += blockDim.x) { + auto xOffset = shape::getIndexOffset(i, xShapeInfo); + auto zOffset = shape::getIndexOffset(i, zShapeInfo); - template - static _CUDA_G void meshgridKernel(int rank, void **outBuffers, Nd4jLong **tadShapes, Nd4jLong **tadOffsets, Nd4jLong *numTads, void **inBuffers, Nd4jLong **inShapes) { - // for all arrays - for (int i = blockIdx.x; i < rank; i += gridDim.x) { - - // for all tads in this array - for(Nd4jLong j = 0; j < numTads[i]; j++) { - assign_(inBuffers[i], inShapes[i], reinterpret_cast(outBuffers[i]) + tadOffsets[i][j], tadShapes[i]); - } - __syncthreads(); - } + z[zOffset] = x[xOffset]; } + } +} - template - static void meshgrid_(sd::LaunchContext * context, const std::vector& inArrs, const std::vector& outArrs, const bool swapFirst2Dims) { - const int rank = inArrs.size(); - int inIndices[MAX_RANK]; - std::iota(inIndices, inIndices + rank, 0); - if(swapFirst2Dims && rank > 1) { - inIndices[0] = 1; - inIndices[1] = 0; - } - - PointersManager pm(context, "meshgrid"); - std::vector hInBuffers(rank); - std::vector hOutBuffers(rank); - std::vector hInShapes(rank); - - std::vector hOutTadShapes(rank); - std::vector hOutTadOffsets(rank); - - std::vector hNumTads(rank); - - for(int i = 0; i < rank; ++i) { - hInBuffers[i] = inArrs[i]->specialBuffer(); - hInShapes[i] = inArrs[i]->specialShapeInfo(); - - hOutBuffers[i] = outArrs[i]->specialBuffer(); - - - auto pack = ConstantTadHelper::getInstance().tadForDimensions(outArrs[i]->shapeInfo(), {inIndices[i]}); - hOutTadShapes[i] = pack.specialShapeInfo(); - hOutTadOffsets[i] = pack.specialOffsets(); - hNumTads[i] = pack.numberOfTads(); - - - //auto list = outArrs[i]->allTensorsAlongDimension({inIndices[i]}); - //for(int j = 0; j < list->size(); ++j) - // list->at(j)->assign(inArrs[i]); - - //delete list; - } - - auto dInBuffers = reinterpret_cast(pm.replicatePointer(hInBuffers.data(), hInBuffers.size() * sizeof(void *))); - auto dOutBuffers = reinterpret_cast(pm.replicatePointer(hOutBuffers.data(), hOutBuffers.size() * sizeof(void *))); - - - auto dInShapes = reinterpret_cast(pm.replicatePointer(hInShapes.data(), hInShapes.size() * sizeof(Nd4jLong *))); - auto dOutTadShapes = reinterpret_cast(pm.replicatePointer(hOutTadShapes.data(), hOutTadShapes.size() * sizeof(Nd4jLong *))); - auto dOutTadOffsets = reinterpret_cast(pm.replicatePointer(hOutTadOffsets.data(), hOutTadOffsets.size() * sizeof(Nd4jLong *))); - - auto dNumTads = reinterpret_cast(pm.replicatePointer(hNumTads.data(), hNumTads.size() * sizeof(Nd4jLong))); - - - meshgridKernel<<<256, 256, 1024, *context->getCudaStream()>>>(rank, dOutBuffers, dOutTadShapes, dOutTadOffsets, dNumTads, dInBuffers, dInShapes); - - pm.synchronize(); +template +static _CUDA_G void meshgridKernel(int rank, void **outBuffers, + Nd4jLong **tadShapes, Nd4jLong **tadOffsets, + Nd4jLong *numTads, void **inBuffers, + Nd4jLong **inShapes) { + // for all arrays + for (int i = blockIdx.x; i < rank; i += gridDim.x) { + // for all tads in this array + for (Nd4jLong j = 0; j < numTads[i]; j++) { + assign_(inBuffers[i], inShapes[i], + reinterpret_cast(outBuffers[i]) + tadOffsets[i][j], + tadShapes[i]); } + __syncthreads(); + } +} - ////////////////////////////////////////////////////////////////////////// - void meshgrid(sd::LaunchContext * context, const std::vector& inArrs, const std::vector& outArrs, const bool swapFirst2Dims) { - - BUILD_SINGLE_SELECTOR(inArrs.at(0)->dataType(), meshgrid_, (context, inArrs, outArrs, swapFirst2Dims), NUMERIC_TYPES); +template +static void meshgrid_(sd::LaunchContext *context, + const std::vector &inArrs, + const std::vector &outArrs, + const bool swapFirst2Dims) { + const int rank = inArrs.size(); + int inIndices[MAX_RANK]; + std::iota(inIndices, inIndices + rank, 0); + if (swapFirst2Dims && rank > 1) { + inIndices[0] = 1; + inIndices[1] = 0; + } + + PointersManager pm(context, "meshgrid"); + std::vector hInBuffers(rank); + std::vector hOutBuffers(rank); + std::vector hInShapes(rank); + + std::vector hOutTadShapes(rank); + std::vector hOutTadOffsets(rank); + + std::vector hNumTads(rank); + + for (int i = 0; i < rank; ++i) { + hInBuffers[i] = inArrs[i]->specialBuffer(); + hInShapes[i] = inArrs[i]->specialShapeInfo(); + + hOutBuffers[i] = outArrs[i]->specialBuffer(); + + auto pack = ConstantTadHelper::getInstance().tadForDimensions( + outArrs[i]->shapeInfo(), {inIndices[i]}); + hOutTadShapes[i] = pack.specialShapeInfo(); + hOutTadOffsets[i] = pack.specialOffsets(); + hNumTads[i] = pack.numberOfTads(); + + // auto list = outArrs[i]->allTensorsAlongDimension({inIndices[i]}); + // for(int j = 0; j < list->size(); ++j) + // list->at(j)->assign(inArrs[i]); + + // delete list; + } + + auto dInBuffers = reinterpret_cast(pm.replicatePointer( + hInBuffers.data(), hInBuffers.size() * sizeof(void *))); + auto dOutBuffers = reinterpret_cast(pm.replicatePointer( + hOutBuffers.data(), hOutBuffers.size() * sizeof(void *))); + + auto dInShapes = reinterpret_cast(pm.replicatePointer( + hInShapes.data(), hInShapes.size() * sizeof(Nd4jLong *))); + auto dOutTadShapes = reinterpret_cast(pm.replicatePointer( + hOutTadShapes.data(), hOutTadShapes.size() * sizeof(Nd4jLong *))); + auto dOutTadOffsets = reinterpret_cast(pm.replicatePointer( + hOutTadOffsets.data(), hOutTadOffsets.size() * sizeof(Nd4jLong *))); + + auto dNumTads = reinterpret_cast( + pm.replicatePointer(hNumTads.data(), hNumTads.size() * sizeof(Nd4jLong))); + + meshgridKernel<<<256, 256, 1024, *context->getCudaStream()>>>( + rank, dOutBuffers, dOutTadShapes, dOutTadOffsets, dNumTads, dInBuffers, + dInShapes); + + pm.synchronize(); +} - for (auto v:outArrs) - v->tickWriteDevice(); - } +////////////////////////////////////////////////////////////////////////// +void meshgrid(sd::LaunchContext *context, const std::vector &inArrs, + const std::vector &outArrs, + const bool swapFirst2Dims) { + BUILD_SINGLE_SELECTOR(inArrs.at(0)->dataType(), meshgrid_, + (context, inArrs, outArrs, swapFirst2Dims), + NUMERIC_TYPES); -} -} + for (auto v : outArrs) v->tickWriteDevice(); } +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/minimum.cu b/libnd4j/include/ops/declarable/helpers/cuda/minimum.cu index b43bb418eb39..078894e3f382 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/minimum.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/minimum.cu @@ -19,91 +19,88 @@ // #ifndef __MIN_I_MAX_H_HELPERS__ #define __MIN_I_MAX_H_HELPERS__ -#include #include #include +#include namespace sd { - namespace ops { - namespace helpers { - - template - void minimumBPFunctor_(NDArray* x, NDArray* y, NDArray* epsNext, NDArray* gradX, NDArray* gradY) { - - auto lambdaX = LAMBDA_TTT(_e, _x, _y) { - return _x <= _y ? _e : (T) 0.; - }; - - auto lambdaY = LAMBDA_TTT(_e, _x, _y) { - return _x >= _y ? _e : (T) 0.; - }; - - - if (x->isSameShape(y)) { - // PWT case case - - // X gradient - epsNext->applyTriplewiseLambda(*x, *y, lambdaX, *gradX); - - // Y gradient - epsNext->applyTriplewiseLambda(*x, *y, lambdaY, *gradY); - - } else if (y->isScalar()) { - T s = y->e(0); - auto lambdaS = LAMBDA_TT(_e, _x, s) { - return _x <= s ? _e : (T) 0.; - }; - - // scalar case - auto tmp = epsNext->reduceNumber(reduce::Sum); - if (x <= y) - gradY->assign(tmp); - else - gradY->assign(0.0f); - - epsNext->applyPairwiseLambda(*x, lambdaS, *gradX); - } else { - // broadcast case - - // in this case we want to boost our X and Y shapes to the size of FF pass output (or epsNext, which has the same shape) - auto preX = x->dup(); - auto preY = y->dup(); - - auto targetShape = epsNext->getShapeAsVector(); - - preX.tileToShape(targetShape, preX); - preY.tileToShape(targetShape, preY); - - epsNext->applyTriplewiseLambda(preX, preY, lambdaX, preX); - epsNext->applyTriplewiseLambda(preX, preY, lambdaY, preY); - - auto axisX = ShapeUtils::evalBroadcastBackwardAxis(x->shapeInfo(), epsNext->shapeInfo()); - auto axisY = ShapeUtils::evalBroadcastBackwardAxis(y->shapeInfo(), epsNext->shapeInfo()); - - if (axisX.size() > 0) { - auto sum = preX.reduceAlongDimension(reduce::Sum, axisX); - gradX->assign(sum); - } else - gradX->assign(preX); - - if (axisY.size() > 0) { - auto sum = preY.reduceAlongDimension(reduce::Sum, axisY); - gradY->assign(sum); - } else - gradY->assign(preY); - } - - } - - void minimumBPFunctor(sd::LaunchContext * context, NDArray* x, NDArray* y, NDArray* epsNext, NDArray* gradX, NDArray* gradY) { - NDArray::prepareSpecialUse({gradX, gradY}, {x, y, epsNext}); +namespace ops { +namespace helpers { + +template +void minimumBPFunctor_(NDArray* x, NDArray* y, NDArray* epsNext, NDArray* gradX, + NDArray* gradY) { + auto lambdaX = LAMBDA_TTT(_e, _x, _y) { return _x <= _y ? _e : (T)0.; }; + + auto lambdaY = LAMBDA_TTT(_e, _x, _y) { return _x >= _y ? _e : (T)0.; }; + + if (x->isSameShape(y)) { + // PWT case case + + // X gradient + epsNext->applyTriplewiseLambda(*x, *y, lambdaX, *gradX); + + // Y gradient + epsNext->applyTriplewiseLambda(*x, *y, lambdaY, *gradY); + + } else if (y->isScalar()) { + T s = y->e(0); + auto lambdaS = LAMBDA_TT(_e, _x, s) { return _x <= s ? _e : (T)0.; }; + + // scalar case + auto tmp = epsNext->reduceNumber(reduce::Sum); + if (x <= y) + gradY->assign(tmp); + else + gradY->assign(0.0f); + + epsNext->applyPairwiseLambda(*x, lambdaS, *gradX); + } else { + // broadcast case + + // in this case we want to boost our X and Y shapes to the size of FF pass + // output (or epsNext, which has the same shape) + auto preX = x->dup(); + auto preY = y->dup(); + + auto targetShape = epsNext->getShapeAsVector(); + + preX.tileToShape(targetShape, preX); + preY.tileToShape(targetShape, preY); + + epsNext->applyTriplewiseLambda(preX, preY, lambdaX, preX); + epsNext->applyTriplewiseLambda(preX, preY, lambdaY, preY); + + auto axisX = ShapeUtils::evalBroadcastBackwardAxis(x->shapeInfo(), + epsNext->shapeInfo()); + auto axisY = ShapeUtils::evalBroadcastBackwardAxis(y->shapeInfo(), + epsNext->shapeInfo()); + + if (axisX.size() > 0) { + auto sum = preX.reduceAlongDimension(reduce::Sum, axisX); + gradX->assign(sum); + } else + gradX->assign(preX); + + if (axisY.size() > 0) { + auto sum = preY.reduceAlongDimension(reduce::Sum, axisY); + gradY->assign(sum); + } else + gradY->assign(preY); + } +} - BUILD_SINGLE_SELECTOR(x->dataType(), minimumBPFunctor_, (x, y, epsNext, gradX, gradY), NUMERIC_TYPES); +void minimumBPFunctor(sd::LaunchContext* context, NDArray* x, NDArray* y, + NDArray* epsNext, NDArray* gradX, NDArray* gradY) { + NDArray::prepareSpecialUse({gradX, gradY}, {x, y, epsNext}); - NDArray::registerSpecialUse({gradX, gradY}, {x, y, epsNext}); - } + BUILD_SINGLE_SELECTOR(x->dataType(), minimumBPFunctor_, + (x, y, epsNext, gradX, gradY), NUMERIC_TYPES); - } - } + NDArray::registerSpecialUse({gradX, gradY}, {x, y, epsNext}); } + +} // namespace helpers +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/helpers/cuda/nth_element.cu b/libnd4j/include/ops/declarable/helpers/cuda/nth_element.cu index c2f34f9fe0be..5795d6343d33 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/nth_element.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/nth_element.cu @@ -18,72 +18,88 @@ // @author sgazeos@gmail.com // -#include -#include -#include +#include #include +#include +#include #include -#include +#include namespace sd { namespace ops { namespace helpers { - template - static __global__ void fillUpElementKernel(void* outputBuffer, Nd4jLong const* outputShapeInfo, void* inputBuffer, Nd4jLong const* inputShapeInfo, Nd4jLong const* pTadShape, Nd4jLong const* pTadOffsets, Nd4jLong n) { - __shared__ Nd4jLong bufferLength; - - auto z = reinterpret_cast(outputBuffer); - auto x = reinterpret_cast(inputBuffer); - - if (threadIdx.x == 0) - bufferLength = shape::length(outputShapeInfo); - - __syncthreads(); +template +static __global__ void fillUpElementKernel( + void* outputBuffer, Nd4jLong const* outputShapeInfo, void* inputBuffer, + Nd4jLong const* inputShapeInfo, Nd4jLong const* pTadShape, + Nd4jLong const* pTadOffsets, Nd4jLong n) { + __shared__ Nd4jLong bufferLength; - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - const auto step = gridDim.x * blockDim.x; - for (int t = tid; t < bufferLength; t += step) { - auto tX = x + pTadOffsets[t]; - z[shape::getIndexOffset(t, outputShapeInfo)] = tX[shape::getIndexOffset(n, pTadShape)]; //tX]; - } - } + auto z = reinterpret_cast(outputBuffer); + auto x = reinterpret_cast(inputBuffer); - template - void nthElementFunctor_(sd::LaunchContext * context, NDArray* input, Nd4jLong n, NDArray* output, bool reverse) { + if (threadIdx.x == 0) bufferLength = shape::length(outputShapeInfo); - NDArray::prepareSpecialUse({output}, {input}); - NDArray sortedVals(*input); - Nd4jPointer params[2]; - params[0] = context; - params[1] = context->getCudaStream(); - // Nth element in sorted sequence : basic algorithm sort and retrieve nth element in sorted - if (input->isVector()) { - sort(params, nullptr, sortedVals.shapeInfo(), sortedVals.specialBuffer(), sortedVals.specialShapeInfo(), reverse); + __syncthreads(); - cudaMemcpy(reinterpret_cast(output->specialBuffer()), reinterpret_cast(sortedVals.specialBuffer()) + n, sizeof(T), cudaMemcpyDeviceToDevice); - } - else { // rank greater than 1 - std::vector lastDims({input->rankOf() - 1});// = ShapeUtils::evalDimsToExclude(input->rankOf(), {input->rankOf() - 1}); + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + const auto step = gridDim.x * blockDim.x; + for (int t = tid; t < bufferLength; t += step) { + auto tX = x + pTadOffsets[t]; + z[shape::getIndexOffset(t, outputShapeInfo)] = + tX[shape::getIndexOffset(n, pTadShape)]; // tX]; + } +} - auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(sortedVals.shapeInfo(), lastDims); +template +void nthElementFunctor_(sd::LaunchContext* context, NDArray* input, Nd4jLong n, + NDArray* output, bool reverse) { + NDArray::prepareSpecialUse({output}, {input}); + NDArray sortedVals(*input); + Nd4jPointer params[2]; + params[0] = context; + params[1] = context->getCudaStream(); + // Nth element in sorted sequence : basic algorithm sort and retrieve nth + // element in sorted + if (input->isVector()) { + sort(params, nullptr, sortedVals.shapeInfo(), sortedVals.specialBuffer(), + sortedVals.specialShapeInfo(), reverse); - auto pTadShape = packX.specialShapeInfo(); - auto pTadShapeH = packX.primaryShapeInfo(); - auto pTadOffsets = packX.specialOffsets(); - sortTad(params, sortedVals.buffer(), sortedVals.shapeInfo(), sortedVals.specialBuffer(), sortedVals.specialShapeInfo(), lastDims.data(), lastDims.size(), pTadShape, pTadOffsets, reverse); - sortedVals.tickWriteDevice(); - sortedVals.syncToHost(); - auto stream = context->getCudaStream(); - fillUpElementKernel<<<32, 64, 1024, *stream>>>(output->specialBuffer(), output->specialShapeInfo(), sortedVals.specialBuffer(), sortedVals.specialShapeInfo(), pTadShape, pTadOffsets, n); - } - NDArray::registerSpecialUse({output}, {input}); - } - void nthElementFunctor(sd::LaunchContext * context, NDArray* input, Nd4jLong n, NDArray* output, bool reverse) { - BUILD_SINGLE_SELECTOR(input->dataType(), nthElementFunctor_, (context, input, n, output, reverse), LIBND4J_TYPES); + cudaMemcpy(reinterpret_cast(output->specialBuffer()), + reinterpret_cast(sortedVals.specialBuffer()) + n, sizeof(T), + cudaMemcpyDeviceToDevice); + } else { // rank greater than 1 + std::vector lastDims( + {input->rankOf() - + 1}); // = ShapeUtils::evalDimsToExclude(input->rankOf(), + // {input->rankOf() - 1}); - } + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions( + sortedVals.shapeInfo(), lastDims); + auto pTadShape = packX.specialShapeInfo(); + auto pTadShapeH = packX.primaryShapeInfo(); + auto pTadOffsets = packX.specialOffsets(); + sortTad(params, sortedVals.buffer(), sortedVals.shapeInfo(), + sortedVals.specialBuffer(), sortedVals.specialShapeInfo(), + lastDims.data(), lastDims.size(), pTadShape, pTadOffsets, reverse); + sortedVals.tickWriteDevice(); + sortedVals.syncToHost(); + auto stream = context->getCudaStream(); + fillUpElementKernel<<<32, 64, 1024, *stream>>>( + output->specialBuffer(), output->specialShapeInfo(), + sortedVals.specialBuffer(), sortedVals.specialShapeInfo(), pTadShape, + pTadOffsets, n); + } + NDArray::registerSpecialUse({output}, {input}); } +void nthElementFunctor(sd::LaunchContext* context, NDArray* input, Nd4jLong n, + NDArray* output, bool reverse) { + BUILD_SINGLE_SELECTOR(input->dataType(), nthElementFunctor_, + (context, input, n, output, reverse), LIBND4J_TYPES); } -} + +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/one_hot.cu b/libnd4j/include/ops/declarable/helpers/cuda/one_hot.cu index f1520045955f..0ee16f576f90 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/one_hot.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/one_hot.cu @@ -18,93 +18,104 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 30.05.2019 // - -#include -#include -#include -#include #include -#include +#include #include -#include #include +#include +#include +#include +#include +#include -namespace sd { -namespace ops { -namespace helpers { +namespace sd { +namespace ops { +namespace helpers { /////////////////////////////////////////////////////////////////// // x - indices, z - output -template -__global__ static void onehotCuda(const void *vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo, const uint axis, const uint depth, const Z on, const Z off) { - - const auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - - __shared__ int xRank, zRank; - __shared__ Nd4jLong zLen, totalThreads, *sharedMem; - - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - sharedMem = reinterpret_cast(shmem); - xRank = shape::rank(xShapeInfo); - zRank = shape::rank(zShapeInfo); - zLen = shape::length(zShapeInfo); - totalThreads = gridDim.x * blockDim.x; - } - __syncthreads(); - - auto coord = sharedMem + threadIdx.x * zRank; - - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - - for (Nd4jLong i = tid; i < zLen; i += totalThreads) { - - shape::index2coords(i, zShapeInfo, coord); - const auto zOffset = shape::getOffset(zShapeInfo, coord); - const auto depthCoord = coord[axis]; - - for (uint j = axis; j < zRank - 1; ++j) - coord[j] = coord[j + 1]; - - const auto xOffset = shape::getOffset(xShapeInfo, coord); - const Nd4jLong idx = x[xOffset]; - z[zOffset] = depthCoord == idx ? on : off; - } +template +__global__ static void onehotCuda(const void *vx, const Nd4jLong *xShapeInfo, + void *vz, const Nd4jLong *zShapeInfo, + const uint axis, const uint depth, const Z on, + const Z off) { + const auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + + __shared__ int xRank, zRank; + __shared__ Nd4jLong zLen, totalThreads, *sharedMem; + + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sharedMem = reinterpret_cast(shmem); + xRank = shape::rank(xShapeInfo); + zRank = shape::rank(zShapeInfo); + zLen = shape::length(zShapeInfo); + totalThreads = gridDim.x * blockDim.x; + } + __syncthreads(); + + auto coord = sharedMem + threadIdx.x * zRank; + + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + + for (Nd4jLong i = tid; i < zLen; i += totalThreads) { + shape::index2coords(i, zShapeInfo, coord); + const auto zOffset = shape::getOffset(zShapeInfo, coord); + const auto depthCoord = coord[axis]; + + for (uint j = axis; j < zRank - 1; ++j) coord[j] = coord[j + 1]; + + const auto xOffset = shape::getOffset(xShapeInfo, coord); + const Nd4jLong idx = x[xOffset]; + z[zOffset] = depthCoord == idx ? on : off; + } } /////////////////////////////////////////////////////////////////// -template -static void onehotCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, - const void *vx, const Nd4jLong *xShapeInfo, - void *vz, const Nd4jLong *zShapeInfo, - const uint axis, const uint depth, - const double on, const double off) { - - onehotCuda<<>>(vx, xShapeInfo, vz, zShapeInfo, axis, depth, static_cast(on), static_cast(off)); +template +static void onehotCudaLauncher(const int blocksPerGrid, + const int threadsPerBlock, const int sharedMem, + const cudaStream_t *stream, const void *vx, + const Nd4jLong *xShapeInfo, void *vz, + const Nd4jLong *zShapeInfo, const uint axis, + const uint depth, const double on, + const double off) { + onehotCuda<<>>( + vx, xShapeInfo, vz, zShapeInfo, axis, depth, static_cast(on), + static_cast(off)); } /////////////////////////////////////////////////////////////////// -void onehot(const sd::LaunchContext* context, const NDArray *indices, NDArray *output, const uint axis, const uint depth, const double on, const double off) { - - const auto xType = indices->dataType(); - const auto zType = output->dataType(); - - const int threadsPerBlock = MAX_NUM_THREADS / 4; - const int blocksPerGrid = (output->lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - const int sharedMem = threadsPerBlock * sizeof(decltype(*output->shapeInfo())) * output->rankOf() + 128; - - PointersManager manager(context, "onehot"); - - NDArray::prepareSpecialUse({output}, {indices}); - BUILD_DOUBLE_SELECTOR(xType, zType, onehotCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), indices->specialBuffer(), indices->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), axis, depth, on, off), LIBND4J_TYPES, LIBND4J_TYPES); - NDArray::registerSpecialUse({output}, {indices}); - - manager.synchronize(); +void onehot(const sd::LaunchContext *context, const NDArray *indices, + NDArray *output, const uint axis, const uint depth, const double on, + const double off) { + const auto xType = indices->dataType(); + const auto zType = output->dataType(); + + const int threadsPerBlock = MAX_NUM_THREADS / 4; + const int blocksPerGrid = + (output->lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + const int sharedMem = threadsPerBlock * + sizeof(decltype(*output->shapeInfo())) * + output->rankOf() + + 128; + + PointersManager manager(context, "onehot"); + + NDArray::prepareSpecialUse({output}, {indices}); + BUILD_DOUBLE_SELECTOR(xType, zType, onehotCudaLauncher, + (blocksPerGrid, threadsPerBlock, sharedMem, + context->getCudaStream(), indices->specialBuffer(), + indices->specialShapeInfo(), output->specialBuffer(), + output->specialShapeInfo(), axis, depth, on, off), + LIBND4J_TYPES, LIBND4J_TYPES); + NDArray::registerSpecialUse({output}, {indices}); + + manager.synchronize(); } - -} -} -} \ No newline at end of file +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/pad.cu b/libnd4j/include/ops/declarable/helpers/cuda/pad.cu index 842a41ced7cf..5728b48c431d 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/pad.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/pad.cu @@ -18,250 +18,281 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 20.04.2018 // - -#include -#include -#include -#include #include -#include +#include #include -#include #include +#include +#include +#include +#include + +#include namespace sd { - namespace ops { - namespace helpers { +namespace ops { +namespace helpers { /////////////////////////////////////////////////////////////////// // x - input, y - paddings, z - output - template - __global__ static void padCuda(const int mode, - const void *vx, const Nd4jLong *xShapeInfo, - const void *vy, const Nd4jLong *yShapeInfo, - void *vz, const Nd4jLong *zShapeInfo, - const void *vPadVal) { - - const X padVal = *reinterpret_cast(vPadVal); - - const auto x = reinterpret_cast(vx); - const auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); - - __shared__ int rank, rankMinusOne; - __shared__ Nd4jLong zLen, totalThreads, *coords, *xShape, *zShape, shift1, shift2, yStride0; - - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - coords = reinterpret_cast(shmem); - zLen = shape::length(zShapeInfo); - xShape = shape::shapeOf(const_cast(xShapeInfo)); - zShape = shape::shapeOf(const_cast(zShapeInfo)); - yStride0 = shape::stride(const_cast(yShapeInfo))[0]; - rank = shape::rank(xShapeInfo); - zLen = shape::length(zShapeInfo); - rankMinusOne = rank - 1; - totalThreads = gridDim.x * blockDim.x; - shift1 = mode == 1 ? 0 : 1; // REFLECT : SYMMETRIC - shift2 = mode == 1 ? 2 : 1; // REFLECT : SYMMETRIC - } - - __syncthreads(); - - auto xzCoord = coords + threadIdx.x * rank; // we use xzCoord storage both for x and z arrays - - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - - if(mode == 0) { // CONSTANT case - - for (Nd4jLong i = tid; i < zLen; i += totalThreads) { - - shape::index2coords(i, zShapeInfo, xzCoord); - const auto zOffset = shape::getOffset(zShapeInfo, xzCoord); - - bool within = true; - for(int j = rankMinusOne; j >= 0; --j) { - if(xShape[j] == zShape[j]) continue; - const auto left = y[shape::getIndexOffset(yStride0 * j, yShapeInfo)]; - if(xzCoord[j] < left || xzCoord[j] >= left + xShape[j]) {within = false; break;} - else {xzCoord[j] = xzCoord[j] - left;} - } - - if(within) - z[zOffset] = x[shape::getOffset(xShapeInfo, xzCoord)]; - else - z[zOffset] = padVal; - } - } - else { // REFLECT and SYMMETRIC cases - - for (Nd4jLong i = tid; i < zLen; i += totalThreads) { - - shape::index2coords(i, zShapeInfo, xzCoord); - const auto zOffset = shape::getOffset(zShapeInfo, xzCoord); - - for(int j = rankMinusOne; j >= 0; --j) { - - if(xShape[j] == zShape[j]) continue; - xzCoord[j] = xzCoord[j] - y[shape::getIndexOffset(yStride0 * j, yShapeInfo)]; // are ready to fill middle (within input dimension range) - if(xzCoord[j] < 0) xzCoord[j] = -xzCoord[j] - shift1; // means fill from left - else if(xzCoord[j] >= xShape[j]) xzCoord[j] = 2 * xShape[j] - xzCoord[j] - shift2; // means fill from right - } - - const auto xOffset = shape::getOffset(xShapeInfo, xzCoord); - z[zOffset] = x[xOffset]; - } - } - } - -/////////////////////////////////////////////////////////////////// - template - static void padCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, - const int mode, - const void *vx, const Nd4jLong *xShapeInfo, - const void *vy, const Nd4jLong *yShapeInfo, - void *vz, const Nd4jLong *zShapeInfo, - const void* padVal) { +template +__global__ static void padCuda(const int mode, const void* vx, + const Nd4jLong* xShapeInfo, const void* vy, + const Nd4jLong* yShapeInfo, void* vz, + const Nd4jLong* zShapeInfo, + const void* vPadVal) { + const X padVal = *reinterpret_cast(vPadVal); + + const auto x = reinterpret_cast(vx); + const auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + + __shared__ int rank, rankMinusOne; + __shared__ Nd4jLong zLen, totalThreads, *coords, *xShape, *zShape, shift1, + shift2, yStride0; + + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + coords = reinterpret_cast(shmem); + zLen = shape::length(zShapeInfo); + xShape = shape::shapeOf(const_cast(xShapeInfo)); + zShape = shape::shapeOf(const_cast(zShapeInfo)); + yStride0 = shape::stride(const_cast(yShapeInfo))[0]; + rank = shape::rank(xShapeInfo); + zLen = shape::length(zShapeInfo); + rankMinusOne = rank - 1; + totalThreads = gridDim.x * blockDim.x; + shift1 = mode == 1 ? 0 : 1; // REFLECT : SYMMETRIC + shift2 = mode == 1 ? 2 : 1; // REFLECT : SYMMETRIC + } + + __syncthreads(); + + auto xzCoord = + coords + + threadIdx.x * rank; // we use xzCoord storage both for x and z arrays + + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + + if (mode == 0) { // CONSTANT case + + for (Nd4jLong i = tid; i < zLen; i += totalThreads) { + shape::index2coords(i, zShapeInfo, xzCoord); + const auto zOffset = shape::getOffset(zShapeInfo, xzCoord); + + bool within = true; + for (int j = rankMinusOne; j >= 0; --j) { + if (xShape[j] == zShape[j]) continue; + const auto left = y[shape::getIndexOffset(yStride0 * j, yShapeInfo)]; + if (xzCoord[j] < left || xzCoord[j] >= left + xShape[j]) { + within = false; + break; + } else { + xzCoord[j] = xzCoord[j] - left; + } + } - padCuda<<>>(mode, vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, padVal); - } + if (within) + z[zOffset] = x[shape::getOffset(xShapeInfo, xzCoord)]; + else + z[zOffset] = padVal; + } + } else { // REFLECT and SYMMETRIC cases + + for (Nd4jLong i = tid; i < zLen; i += totalThreads) { + shape::index2coords(i, zShapeInfo, xzCoord); + const auto zOffset = shape::getOffset(zShapeInfo, xzCoord); + + for (int j = rankMinusOne; j >= 0; --j) { + if (xShape[j] == zShape[j]) continue; + xzCoord[j] = + xzCoord[j] - + y[shape::getIndexOffset( + yStride0 * j, yShapeInfo)]; // are ready to fill middle (within + // input dimension range) + if (xzCoord[j] < 0) + xzCoord[j] = -xzCoord[j] - shift1; // means fill from left + else if (xzCoord[j] >= xShape[j]) + xzCoord[j] = + 2 * xShape[j] - xzCoord[j] - shift2; // means fill from right + } + + const auto xOffset = shape::getOffset(xShapeInfo, xzCoord); + z[zOffset] = x[xOffset]; + } + } +} /////////////////////////////////////////////////////////////////// - void pad(sd::LaunchContext * context, const int mode, const NDArray& input, const NDArray& paddings, NDArray& output, const NDArray& padValue) { - - PointersManager manager(context, "pad"); - - NDArray::prepareSpecialUse({&output}, {&input, &paddings, &padValue}); - - const int threadsPerBlock = MAX_NUM_THREADS / 4; - const int blocksPerGrid = (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - const int sharedMem = 8 * threadsPerBlock * output.rankOf() + 128; - - const auto xType = input.dataType(); - const auto yType = paddings.dataType(); - - BUILD_DOUBLE_SELECTOR(xType, yType, padCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), mode, input.specialBuffer(), input.specialShapeInfo(), paddings.specialBuffer(), paddings.specialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), padValue.specialBuffer()), LIBND4J_TYPES, INDEXING_TYPES); - - NDArray::registerSpecialUse({&output}, {&input, &paddings, &padValue}); - manager.synchronize(); - } - - - //////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - template - static __global__ void mirrorPadLinearKernel(void const* vx, const Nd4jLong* xShape, void* vz, const Nd4jLong* zShape, Nd4jLong leftSide, Nd4jLong leftSideCorrected, Nd4jLong xLen, Nd4jLong len, Nd4jLong zLen) { - - __shared__ T const* x; - __shared__ T* z; - if (threadIdx.x == 0) { - x = reinterpret_cast(vx); - z = reinterpret_cast(vz); - } - __syncthreads(); - auto start = blockIdx.x * blockDim.x + threadIdx.x; - auto step = blockDim.x * gridDim.x; - - for(int i = start; i < zLen; i+= step) { - auto zIndex = shape::getIndexOffset(i, zShape); - auto xIndex = shape::getIndexOffset(len - i, xShape); - - if (i < leftSide) // left side - xIndex = shape::getIndexOffset(leftSideCorrected - i, xShape); - - else if(i >= leftSide && i < leftSide + xLen) // middle - xIndex = shape::getIndexOffset(i - leftSide, xShape); - -// else // right side -// z[i] = x[len - i]; - z[zIndex] = x[xIndex]; - } - - } - - template - static __global__ void mirrorPadKernel(void const* vx, const Nd4jLong* xShape, void* vz, const Nd4jLong* zShape, Nd4jLong outLen, void const* paddings, const Nd4jLong* paddingShape, int reflBorder) { - - __shared__ F const* x; - __shared__ I const* pads; - __shared__ F* z; - __shared__ Nd4jLong zRank, rank; - __shared__ Nd4jLong* xIdx; - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - xIdx = reinterpret_cast(shmem); - rank = shape::rank(xShape); - - x = reinterpret_cast(vx);// - pads = reinterpret_cast(paddings); - z = reinterpret_cast(vz); - } - __syncthreads(); - auto start = threadIdx.x + blockIdx.x * blockDim.x; - auto step = blockDim.x * gridDim.x; - - for(Nd4jLong i = start; i < outLen; i+= step) { - auto xzCoord = xIdx + threadIdx.x * rank; - //auto zxCoord = xIdx + (threadIdx.x + threadIdx.x % 2 + 1) * rank; - - shape::index2coords(i, zShape, xzCoord); - auto outOffset = shape::getOffset(zShape, xzCoord); -// auto intStep = blockDim.y * gridDim.y; - for(int j = 0; j < rank; j++) { - - const Nd4jLong inLen = shape::sizeAt(xShape, j); - Nd4jLong coords[2] = {j, 0}; - auto padOffset = shape::getOffset(paddingShape, coords); // padding already has rank 2 - const auto leftSide = pads[padOffset]; - const auto leftSideCorrected = leftSide - reflBorder; - const Nd4jLong len = 2 * (inLen - 1) + leftSide + reflBorder; - - if(xzCoord[j] < leftSide) // left side - xzCoord[j] = leftSideCorrected - xzCoord[j]; - - else if(xzCoord[j] >= leftSide && xzCoord[j] < leftSide + inLen) // middle - xzCoord[j] = xzCoord[j] - leftSide; - - else if (len > xzCoord[j]) // right side - xzCoord[j] = len - xzCoord[j]; - else - xzCoord[j] = xzCoord[j] - len; - } - - auto inOffset = shape::getOffset(xShape, xzCoord); - z[outOffset] = x[inOffset]; - } - } - - template - static void mirrorPad_(sd::LaunchContext * context, const NDArray& input, const NDArray& paddings, NDArray& output, const int mode) { - // mode: 0 - REFLECT, else - SYMMETRIC - const int reflBorder = (bool)mode ? 1 : 0; - const int rank = input.rankOf(); - const Nd4jLong outLen = output.lengthOf(); - auto stream = context->getCudaStream(); - NDArray::prepareSpecialUse({&output}, {&input, &paddings}); - - if(rank <= 1) { - - const Nd4jLong inLen = input.lengthOf(); - const auto leftSide = paddings.e(0); - const auto leftSideCorrected = leftSide - reflBorder; - const Nd4jLong len = 2*(inLen-1) + leftSide + reflBorder; - - mirrorPadLinearKernel<<<256, 512, 256, *stream>>>(input.specialBuffer(), input.specialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), leftSide, leftSideCorrected, inLen, len, outLen); - sd::DebugHelper::checkErrorCode(stream, "helpers::mirrorPadLinearKernel(...) failed"); - } - else { - mirrorPadKernel<<<256, 256, 8192, *stream>>>(input.specialBuffer(), input.specialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), outLen, paddings.specialBuffer(), paddings.specialShapeInfo(), reflBorder); - sd::DebugHelper::checkErrorCode(stream, "helpers::mirrorPadKernel(...) failed"); - } - NDArray::registerSpecialUse({&output}, {&input, &paddings}); - } - - void mirrorPad(sd::LaunchContext * context, const NDArray& input, const NDArray& paddings, NDArray& output, const int mode) { - BUILD_DOUBLE_SELECTOR(input.dataType(), paddings.dataType(), mirrorPad_, (context, input, paddings, output, mode), LIBND4J_TYPES, INDEXING_TYPES); - } - +template +static void padCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, + const int sharedMem, const cudaStream_t* stream, + const int mode, const void* vx, + const Nd4jLong* xShapeInfo, const void* vy, + const Nd4jLong* yShapeInfo, void* vz, + const Nd4jLong* zShapeInfo, const void* padVal) { + padCuda<<>>( + mode, vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, padVal); +} - } +/////////////////////////////////////////////////////////////////// +void pad(sd::LaunchContext* context, const int mode, const NDArray& input, + const NDArray& paddings, NDArray& output, const NDArray& padValue) { + PointersManager manager(context, "pad"); + + NDArray::prepareSpecialUse({&output}, {&input, &paddings, &padValue}); + + const int threadsPerBlock = MAX_NUM_THREADS / 4; + const int blocksPerGrid = + (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + const int sharedMem = 8 * threadsPerBlock * output.rankOf() + 128; + + const auto xType = input.dataType(); + const auto yType = paddings.dataType(); + + BUILD_DOUBLE_SELECTOR(xType, yType, padCudaLauncher, + (blocksPerGrid, threadsPerBlock, sharedMem, + context->getCudaStream(), mode, input.specialBuffer(), + input.specialShapeInfo(), paddings.specialBuffer(), + paddings.specialShapeInfo(), output.specialBuffer(), + output.specialShapeInfo(), padValue.specialBuffer()), + LIBND4J_TYPES, INDEXING_TYPES); + + NDArray::registerSpecialUse({&output}, {&input, &paddings, &padValue}); + manager.synchronize(); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +template +static __global__ void mirrorPadLinearKernel( + void const* vx, const Nd4jLong* xShape, void* vz, const Nd4jLong* zShape, + Nd4jLong leftSide, Nd4jLong leftSideCorrected, Nd4jLong xLen, Nd4jLong len, + Nd4jLong zLen) { + __shared__ T const* x; + __shared__ T* z; + if (threadIdx.x == 0) { + x = reinterpret_cast(vx); + z = reinterpret_cast(vz); + } + __syncthreads(); + auto start = blockIdx.x * blockDim.x + threadIdx.x; + auto step = blockDim.x * gridDim.x; + + for (int i = start; i < zLen; i += step) { + auto zIndex = shape::getIndexOffset(i, zShape); + auto xIndex = shape::getIndexOffset(len - i, xShape); + + if (i < leftSide) // left side + xIndex = shape::getIndexOffset(leftSideCorrected - i, xShape); + + else if (i >= leftSide && i < leftSide + xLen) // middle + xIndex = shape::getIndexOffset(i - leftSide, xShape); + + // else // right + // side + // z[i] = x[len - i]; + z[zIndex] = x[xIndex]; + } +} + +template +static __global__ void mirrorPadKernel(void const* vx, const Nd4jLong* xShape, + void* vz, const Nd4jLong* zShape, + Nd4jLong outLen, void const* paddings, + const Nd4jLong* paddingShape, + int reflBorder) { + __shared__ F const* x; + __shared__ I const* pads; + __shared__ F* z; + __shared__ Nd4jLong zRank, rank; + __shared__ Nd4jLong* xIdx; + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + xIdx = reinterpret_cast(shmem); + rank = shape::rank(xShape); + + x = reinterpret_cast(vx); // + pads = reinterpret_cast(paddings); + z = reinterpret_cast(vz); + } + __syncthreads(); + auto start = threadIdx.x + blockIdx.x * blockDim.x; + auto step = blockDim.x * gridDim.x; + + for (Nd4jLong i = start; i < outLen; i += step) { + auto xzCoord = xIdx + threadIdx.x * rank; + // auto zxCoord = xIdx + (threadIdx.x + threadIdx.x % 2 + 1) * rank; + + shape::index2coords(i, zShape, xzCoord); + auto outOffset = shape::getOffset(zShape, xzCoord); + // auto intStep = blockDim.y * gridDim.y; + for (int j = 0; j < rank; j++) { + const Nd4jLong inLen = shape::sizeAt(xShape, j); + Nd4jLong coords[2] = {j, 0}; + auto padOffset = + shape::getOffset(paddingShape, coords); // padding already has rank 2 + const auto leftSide = pads[padOffset]; + const auto leftSideCorrected = leftSide - reflBorder; + const Nd4jLong len = 2 * (inLen - 1) + leftSide + reflBorder; + + if (xzCoord[j] < leftSide) // left side + xzCoord[j] = leftSideCorrected - xzCoord[j]; + + else if (xzCoord[j] >= leftSide && + xzCoord[j] < leftSide + inLen) // middle + xzCoord[j] = xzCoord[j] - leftSide; + + else if (len > xzCoord[j]) // right side + xzCoord[j] = len - xzCoord[j]; + else + xzCoord[j] = xzCoord[j] - len; } -} \ No newline at end of file + + auto inOffset = shape::getOffset(xShape, xzCoord); + z[outOffset] = x[inOffset]; + } +} + +template +static void mirrorPad_(sd::LaunchContext* context, const NDArray& input, + const NDArray& paddings, NDArray& output, + const int mode) { + // mode: 0 - REFLECT, else - SYMMETRIC + const int reflBorder = (bool)mode ? 1 : 0; + const int rank = input.rankOf(); + const Nd4jLong outLen = output.lengthOf(); + auto stream = context->getCudaStream(); + NDArray::prepareSpecialUse({&output}, {&input, &paddings}); + + if (rank <= 1) { + const Nd4jLong inLen = input.lengthOf(); + const auto leftSide = paddings.e(0); + const auto leftSideCorrected = leftSide - reflBorder; + const Nd4jLong len = 2 * (inLen - 1) + leftSide + reflBorder; + + mirrorPadLinearKernel<<<256, 512, 256, *stream>>>( + input.specialBuffer(), input.specialShapeInfo(), output.specialBuffer(), + output.specialShapeInfo(), leftSide, leftSideCorrected, inLen, len, + outLen); + sd::DebugHelper::checkErrorCode( + stream, "helpers::mirrorPadLinearKernel(...) failed"); + } else { + mirrorPadKernel<<<256, 256, 8192, *stream>>>( + input.specialBuffer(), input.specialShapeInfo(), output.specialBuffer(), + output.specialShapeInfo(), outLen, paddings.specialBuffer(), + paddings.specialShapeInfo(), reflBorder); + sd::DebugHelper::checkErrorCode(stream, + "helpers::mirrorPadKernel(...) failed"); + } + NDArray::registerSpecialUse({&output}, {&input, &paddings}); +} + +void mirrorPad(sd::LaunchContext* context, const NDArray& input, + const NDArray& paddings, NDArray& output, const int mode) { + BUILD_DOUBLE_SELECTOR(input.dataType(), paddings.dataType(), mirrorPad_, + (context, input, paddings, output, mode), LIBND4J_TYPES, + INDEXING_TYPES); +} + +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/percentile.cu b/libnd4j/include/ops/declarable/helpers/cuda/percentile.cu index 1bc50fad700c..0fdbe85b8048 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/percentile.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/percentile.cu @@ -19,119 +19,137 @@ // @author raver119@gmail.com // -#include #include +#include #include #include -#include +#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { - template - static _CUDA_G void percentileKernel(void *vx, const Nd4jLong *xTadShapeInfo, const Nd4jLong *xTadOffsets, - const Nd4jLong numTads, const Nd4jLong tadLength, - void *vz, const Nd4jLong *zShapeInfo, const Nd4jLong zLength, - const Nd4jLong position) { - for (int t = blockIdx.x; t < numTads; t += gridDim.x) { - auto x = reinterpret_cast(vx) + xTadOffsets[t]; - auto z = reinterpret_cast(vz); - - - // sort tad - if (tadLength > 1) { - for (int m = 0; m < tadLength; m++) { - if (m % 2 == 0) { - for (int tid = threadIdx.x; tid < tadLength; tid += blockDim.x) { - auto top = 2 * tid + 1; - if (top < tadLength) { - auto t0 = shape::getIndexOffset(top - 1, xTadShapeInfo); - auto t1 = shape::getIndexOffset(top, xTadShapeInfo); - - if (x[t0] > x[t1]) { - //swap values - X dz0 = x[t0]; - x[t0] = x[t1]; - x[t1] = dz0; - } - } - } - } else { - for (int tid = threadIdx.x; tid < tadLength; tid += blockDim.x) { - auto top = 2 * tid + 2; - if (top < tadLength) { - auto t0 = shape::getIndexOffset(top - 1, xTadShapeInfo); - auto t1 = shape::getIndexOffset(top, xTadShapeInfo); - - if (x[t0] > x[t1]) { - //swap values - X dz0 = x[t0]; - x[t0] = x[t1]; - x[t1] = dz0; - } - } - } - } - __syncthreads(); - } +template +static _CUDA_G void percentileKernel(void* vx, const Nd4jLong* xTadShapeInfo, + const Nd4jLong* xTadOffsets, + const Nd4jLong numTads, + const Nd4jLong tadLength, void* vz, + const Nd4jLong* zShapeInfo, + const Nd4jLong zLength, + const Nd4jLong position) { + for (int t = blockIdx.x; t < numTads; t += gridDim.x) { + auto x = reinterpret_cast(vx) + xTadOffsets[t]; + auto z = reinterpret_cast(vz); + + // sort tad + if (tadLength > 1) { + for (int m = 0; m < tadLength; m++) { + if (m % 2 == 0) { + for (int tid = threadIdx.x; tid < tadLength; tid += blockDim.x) { + auto top = 2 * tid + 1; + if (top < tadLength) { + auto t0 = shape::getIndexOffset(top - 1, xTadShapeInfo); + auto t1 = shape::getIndexOffset(top, xTadShapeInfo); + + if (x[t0] > x[t1]) { + // swap values + X dz0 = x[t0]; + x[t0] = x[t1]; + x[t1] = dz0; + } } - - // saving final value - if (threadIdx.x == 0) - z[shape::getIndexOffset(t, zShapeInfo)] = x[shape::getIndexOffset(position, xTadShapeInfo)]; - __syncthreads(); + } + } else { + for (int tid = threadIdx.x; tid < tadLength; tid += blockDim.x) { + auto top = 2 * tid + 2; + if (top < tadLength) { + auto t0 = shape::getIndexOffset(top - 1, xTadShapeInfo); + auto t1 = shape::getIndexOffset(top, xTadShapeInfo); + + if (x[t0] > x[t1]) { + // swap values + X dz0 = x[t0]; + x[t0] = x[t1]; + x[t1] = dz0; + } + } + } } + __syncthreads(); + } } + // saving final value + if (threadIdx.x == 0) + z[shape::getIndexOffset(t, zShapeInfo)] = + x[shape::getIndexOffset(position, xTadShapeInfo)]; + __syncthreads(); + } +} +template +static void _percentile(sd::LaunchContext* context, const NDArray& input, + NDArray& output, std::vector& axis, const float q, + const int interpolation) { + const int inputRank = input.rankOf(); + + if (axis.empty()) + for (int i = 0; i < inputRank; ++i) axis.push_back(i); + else + shape::checkDimensions(inputRank, axis); + + auto tempArray = input.dup(); + auto packX = ConstantTadHelper::getInstance().tadForDimensions( + tempArray.shapeInfo(), axis); + + auto tadLength = shape::length(packX.primaryShapeInfo()); + + const float fraction = 1.f - q / 100.; + Nd4jLong position = 0; + + switch (interpolation) { + case 0: // lower + position = static_cast( + math::nd4j_ceil((tadLength - 1) * fraction)); + break; + case 1: // higher + position = static_cast( + math::nd4j_floor((tadLength - 1) * fraction)); + break; + case 2: // nearest + position = static_cast( + math::nd4j_round((tadLength - 1) * fraction)); + break; + } + position = tadLength - position - 1; + + percentileKernel<<<256, 512, 1024, *context->getCudaStream()>>>( + tempArray.specialBuffer(), packX.platformShapeInfo(), + packX.platformOffsets(), packX.numberOfTads(), tadLength, + output.specialBuffer(), output.specialShapeInfo(), output.lengthOf(), + position); + + sd::DebugHelper::checkErrorCode(context->getCudaStream(), "percentile"); +} - template - static void _percentile(sd::LaunchContext * context, const NDArray& input, NDArray& output, std::vector& axis, const float q, const int interpolation) { - const int inputRank = input.rankOf(); - - if(axis.empty()) - for(int i=0; i(math::nd4j_ceil((tadLength - 1) * fraction)); - break; - case 1: // higher - position = static_cast(math::nd4j_floor((tadLength - 1) * fraction)); - break; - case 2: // nearest - position = static_cast(math::nd4j_round((tadLength - 1) * fraction)); - break; - } - position = tadLength - position - 1; - - percentileKernel<<<256, 512, 1024, *context->getCudaStream()>>>(tempArray.specialBuffer(), packX.platformShapeInfo(), packX.platformOffsets(), packX.numberOfTads(), tadLength, output.specialBuffer(), output.specialShapeInfo(), output.lengthOf(), position); - - sd::DebugHelper::checkErrorCode(context->getCudaStream(), "percentile"); - } - - void percentile(sd::LaunchContext * context, const NDArray& input, NDArray& output, std::vector& axises, const float q, const int interpolation) { - NDArray::prepareSpecialUse({&output}, {&input}); +void percentile(sd::LaunchContext* context, const NDArray& input, + NDArray& output, std::vector& axises, const float q, + const int interpolation) { + NDArray::prepareSpecialUse({&output}, {&input}); - BUILD_SINGLE_SELECTOR(input.dataType(), _percentile, (context, input, output, axises, q, interpolation), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR(input.dataType(), _percentile, + (context, input, output, axises, q, interpolation), + LIBND4J_TYPES); - NDArray::registerSpecialUse({&output}, {&input}); - } + NDArray::registerSpecialUse({&output}, {&input}); +} - BUILD_SINGLE_TEMPLATE(template void _percentile, (sd::LaunchContext * context, const NDArray& input, NDArray& output, std::vector& axises, const float q, const int interpolation), LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template void _percentile, + (sd::LaunchContext * context, const NDArray& input, + NDArray& output, std::vector& axises, const float q, + const int interpolation), + LIBND4J_TYPES); -} -} -} \ No newline at end of file +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/polyGamma.cu b/libnd4j/include/ops/declarable/helpers/cuda/polyGamma.cu index 3e82632e27f8..ccf0cf6e234a 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/polyGamma.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/polyGamma.cu @@ -18,86 +18,98 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 26.04.2019 // -#include -#include #include +#include +#include namespace sd { namespace ops { namespace helpers { /////////////////////////////////////////////////////////////////// -template +template __global__ static void polyGammaCuda(const void *vn, const Nd4jLong *nShapeInfo, - const void *vx, const Nd4jLong *xShapeInfo, - void *vz, const Nd4jLong *zShapeInfo) { - - const auto n = reinterpret_cast(vn); - const auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - - __shared__ Nd4jLong len; - __shared__ bool sameOffsetNX, sameOffsetNZ; - - if (threadIdx.x == 0) { - len = shape::length(nShapeInfo); - sameOffsetNX = shape::haveSameShapeAndStrides(xShapeInfo, nShapeInfo); - sameOffsetNZ = shape::haveSameShapeAndStrides(zShapeInfo, nShapeInfo); - } - __syncthreads(); - - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - const auto totalThreads = gridDim.x * blockDim.x; - - for (int i = tid; i < len; i += totalThreads) { - - const auto nOffset = shape::getIndexOffset(i, nShapeInfo); - const auto xOffset = sameOffsetNX ? nOffset : shape::getIndexOffset(i, xShapeInfo); - const auto zOffset = sameOffsetNZ ? nOffset : shape::getIndexOffset(i, zShapeInfo); - - const T order = n[nOffset]; - - int sign = (static_cast(order) + 1) % 2 ? -1 : 1; - - if(order != static_cast(order)) { - z[zOffset] = DataTypeUtils::nanOrZero(); - } - else if(order == 0) { - z[zOffset] = diGammaScalar(x[xOffset]); - } - else { - T factorial = 1; - for(int i = 2; i <= order; ++i) - factorial *= i; - - z[zOffset] = sign * factorial * zetaScalar(order + 1, x[xOffset]); - } + const void *vx, const Nd4jLong *xShapeInfo, + void *vz, const Nd4jLong *zShapeInfo) { + const auto n = reinterpret_cast(vn); + const auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + + __shared__ Nd4jLong len; + __shared__ bool sameOffsetNX, sameOffsetNZ; + + if (threadIdx.x == 0) { + len = shape::length(nShapeInfo); + sameOffsetNX = shape::haveSameShapeAndStrides(xShapeInfo, nShapeInfo); + sameOffsetNZ = shape::haveSameShapeAndStrides(zShapeInfo, nShapeInfo); + } + __syncthreads(); + + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + const auto totalThreads = gridDim.x * blockDim.x; + + for (int i = tid; i < len; i += totalThreads) { + const auto nOffset = shape::getIndexOffset(i, nShapeInfo); + const auto xOffset = + sameOffsetNX ? nOffset : shape::getIndexOffset(i, xShapeInfo); + const auto zOffset = + sameOffsetNZ ? nOffset : shape::getIndexOffset(i, zShapeInfo); + + const T order = n[nOffset]; + + int sign = (static_cast(order) + 1) % 2 ? -1 : 1; + + if (order != static_cast(order)) { + z[zOffset] = DataTypeUtils::nanOrZero(); + } else if (order == 0) { + z[zOffset] = diGammaScalar(x[xOffset]); + } else { + T factorial = 1; + for (int i = 2; i <= order; ++i) factorial *= i; + + z[zOffset] = sign * factorial * zetaScalar(order + 1, x[xOffset]); } + } } /////////////////////////////////////////////////////////////////// -template -static void polyGammaCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, const void *vn, const Nd4jLong *nShapeInfo, const void *vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo) { - - polyGammaCuda<<>>(vn, nShapeInfo, vx, xShapeInfo, vz, zShapeInfo); +template +static void polyGammaCudaLauncher(const int blocksPerGrid, + const int threadsPerBlock, + const cudaStream_t *stream, const void *vn, + const Nd4jLong *nShapeInfo, const void *vx, + const Nd4jLong *xShapeInfo, void *vz, + const Nd4jLong *zShapeInfo) { + polyGammaCuda<<>>( + vn, nShapeInfo, vx, xShapeInfo, vz, zShapeInfo); } /////////////////////////////////////////////////////////////////// -void polyGamma(sd::LaunchContext * context, const NDArray& n, const NDArray& x, NDArray& z) { +void polyGamma(sd::LaunchContext *context, const NDArray &n, const NDArray &x, + NDArray &z) { + NDArray::prepareSpecialUse({&z}, {&n, &x}); - NDArray::prepareSpecialUse({&z}, {&n, &x}); + int threadsPerBlock = MAX_NUM_THREADS / 2; + int blocksPerGrid = (z.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - int threadsPerBlock = MAX_NUM_THREADS / 2; - int blocksPerGrid = (z.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + BUILD_SINGLE_SELECTOR( + n.dataType(), polyGammaCudaLauncher, + (blocksPerGrid, threadsPerBlock, context->getCudaStream(), + n.specialBuffer(), n.specialShapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), z.specialBuffer(), z.specialShapeInfo()), + FLOAT_TYPES); - BUILD_SINGLE_SELECTOR(n.dataType(), polyGammaCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), n.specialBuffer(), n.specialShapeInfo(), x.specialBuffer(), x.specialShapeInfo(), z.specialBuffer(), z.specialShapeInfo()), FLOAT_TYPES); - - NDArray::registerSpecialUse({&z}, {&n, &x}); -} - -BUILD_SINGLE_TEMPLATE(template void polyGammaCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, const void *vn, const Nd4jLong *nShapeInfo, const void *vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo), FLOAT_TYPES); - -} -} + NDArray::registerSpecialUse({&z}, {&n, &x}); } +BUILD_SINGLE_TEMPLATE(template void polyGammaCudaLauncher, + (const int blocksPerGrid, const int threadsPerBlock, + const cudaStream_t *stream, const void *vn, + const Nd4jLong *nShapeInfo, const void *vx, + const Nd4jLong *xShapeInfo, void *vz, + const Nd4jLong *zShapeInfo), + FLOAT_TYPES); + +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/prefix.cu b/libnd4j/include/ops/declarable/helpers/cuda/prefix.cu index 959b458656c6..bb30db33fe1a 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/prefix.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/prefix.cu @@ -18,11 +18,11 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 12.06.2019 // -#include #include #include #include #include +#include namespace sd { namespace ops { @@ -30,148 +30,162 @@ namespace helpers { /////////////////////////////////////////////////////////////////// template -__global__ static void prefixPerBlockCuda(scalar::Ops op, - const void* vx, const Nd4jLong* xTadShapeInfo, const Nd4jLong* xTadOffsets, - void* vz, const Nd4jLong* zTadShapeInfo, const Nd4jLong* zTadOffsets, - const Nd4jLong numTads, const Nd4jLong tadLen, - const bool exclusive, const bool reverse) { - - __shared__ T *shared, lastElemInChunk; - __shared__ uint numTadChunks, blockDim2; - - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - shared = reinterpret_cast(shmem); - blockDim2 = 2 * blockDim.x; - numTadChunks = (tadLen + blockDim2 - 1) / blockDim2; // ceil +__global__ static void prefixPerBlockCuda( + scalar::Ops op, const void* vx, const Nd4jLong* xTadShapeInfo, + const Nd4jLong* xTadOffsets, void* vz, const Nd4jLong* zTadShapeInfo, + const Nd4jLong* zTadOffsets, const Nd4jLong numTads, const Nd4jLong tadLen, + const bool exclusive, const bool reverse) { + __shared__ T *shared, lastElemInChunk; + __shared__ uint numTadChunks, blockDim2; + + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + shared = reinterpret_cast(shmem); + blockDim2 = 2 * blockDim.x; + numTadChunks = (tadLen + blockDim2 - 1) / blockDim2; // ceil + } + __syncthreads(); + + const auto xTad = reinterpret_cast(vx) + xTadOffsets[blockIdx.x]; + auto zTad = reinterpret_cast(vz) + zTadOffsets[blockIdx.x]; + + Nd4jLong sharedInd(2 * threadIdx.x), leftArrInd, rightArrInd, step; + T xLeft, xRight; + + for (uint i = 0; i < numTadChunks; ++i) { + leftArrInd = sharedInd + i * blockDim2; + rightArrInd = leftArrInd + 1; + + if (reverse) { + if (rightArrInd < tadLen) { + rightArrInd = tadLen - 1 - rightArrInd; + leftArrInd = tadLen - 1 - leftArrInd; + } else if (leftArrInd < tadLen) + leftArrInd = tadLen - 1 - leftArrInd; } - __syncthreads(); - const auto xTad = reinterpret_cast(vx) + xTadOffsets[blockIdx.x]; - auto zTad = reinterpret_cast(vz) + zTadOffsets[blockIdx.x]; - - Nd4jLong sharedInd(2 * threadIdx.x), leftArrInd, rightArrInd, step; - T xLeft, xRight; - - for (uint i = 0; i < numTadChunks; ++i) { - - leftArrInd = sharedInd + i * blockDim2; - rightArrInd = leftArrInd + 1; - - if(reverse) { - if(rightArrInd < tadLen) { - rightArrInd = tadLen - 1 - rightArrInd; - leftArrInd = tadLen - 1 - leftArrInd; - } - else if(leftArrInd < tadLen) - leftArrInd = tadLen - 1 - leftArrInd; - } - - if(leftArrInd < tadLen) - shared[sharedInd] = xLeft = xTad[shape::getIndexOffset(leftArrInd, xTadShapeInfo)]; - // else - // shared[sharedInd] = (op == scalar::Add) ? 0 : 1; - - if(rightArrInd < tadLen) - shared[sharedInd + 1] = xRight = xTad[shape::getIndexOffset(rightArrInd, xTadShapeInfo)]; - // else - // shared[sharedInd + 1] = (op == scalar::Add) ? 0 : 1; - - - step = 1; - - for (uint d = blockDim.x; d > 0; d /= 2) { - - __syncthreads(); - if(threadIdx.x < d) { - uint left = step * (sharedInd + 1) - 1; - uint right = step * (sharedInd + 2) - 1; - shared[right] = (op == scalar::Add) ? (shared[right] + shared[left]) : (shared[right] * shared[left]); - } - step *= 2; - } - - if (threadIdx.x == 0) - shared[blockDim2 - 1] = (op == scalar::Add) ? 0 : 1; - __syncthreads(); - - for (uint d = 1; d < blockDim2; d *= 2) { - - step /= 2; - - __syncthreads(); - if(threadIdx.x < d) { - uint left = step * (sharedInd + 1) - 1; - uint right = step * (sharedInd + 2) - 1; - T temp = shared[left]; - shared[left] = shared[right]; - shared[right] = (op == scalar::Add) ? (shared[right] + temp) : (shared[right] * temp); - } - } - - __syncthreads(); - - if(leftArrInd < tadLen) { - T result = shared[sharedInd]; - if(!exclusive) - result = (op == scalar::Add) ? result + xLeft : result * xLeft; - if(i > 0) - result = (op == scalar::Add) ? result + lastElemInChunk : result * lastElemInChunk; - zTad[shape::getIndexOffset(leftArrInd, zTadShapeInfo)] = result; - } - - if(rightArrInd < tadLen) { - T result = shared[sharedInd + 1]; - if(!exclusive) - result = (op == scalar::Add) ? result + xRight : result * xRight; - if(i > 0) - result = (op == scalar::Add) ? result + lastElemInChunk : result * lastElemInChunk; - if(i < numTadChunks - 1 && threadIdx.x == blockDim.x - 1) // last element in chunk - lastElemInChunk = !exclusive ? result : (op == scalar::Add) ? result + xRight : result * xRight; - zTad[shape::getIndexOffset(rightArrInd, zTadShapeInfo)] = result; - } + if (leftArrInd < tadLen) + shared[sharedInd] = xLeft = + xTad[shape::getIndexOffset(leftArrInd, xTadShapeInfo)]; + // else + // shared[sharedInd] = (op == scalar::Add) ? 0 : 1; + + if (rightArrInd < tadLen) + shared[sharedInd + 1] = xRight = + xTad[shape::getIndexOffset(rightArrInd, xTadShapeInfo)]; + // else + // shared[sharedInd + 1] = (op == scalar::Add) ? 0 : 1; + + step = 1; + + for (uint d = blockDim.x; d > 0; d /= 2) { + __syncthreads(); + if (threadIdx.x < d) { + uint left = step * (sharedInd + 1) - 1; + uint right = step * (sharedInd + 2) - 1; + shared[right] = (op == scalar::Add) ? (shared[right] + shared[left]) + : (shared[right] * shared[left]); + } + step *= 2; } -} - -/////////////////////////////////////////////////////////////////// -template -static void prefixPerBlockCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, - scalar::Ops op, - const void* vx, const Nd4jLong* xTadShapeInfo, const Nd4jLong* xTadOffsets, - void* vz, const Nd4jLong* zTadShapeInfo, const Nd4jLong* zTadOffsets, - const Nd4jLong numTads, const Nd4jLong tadLen, - const bool exclusive, const bool reverse) { - - prefixPerBlockCuda<<>>(op, vx, xTadShapeInfo, xTadOffsets, vz, zTadShapeInfo, zTadOffsets, numTads, tadLen, exclusive, reverse); -} - -/////////////////////////////////////////////////////////////////// -void prefix(sd::LaunchContext * context, scalar::Ops op, const NDArray* x, NDArray* z, const std::vector& dims, bool exclusive, bool reverse) { - - auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(x->shapeInfo(), dims); - auto packZ = sd::ConstantTadHelper::getInstance().tadForDimensions(z->shapeInfo(), dims); - const Nd4jLong numTads = packX.numberOfTads(); - const Nd4jLong tadLen = x->lengthOf() / numTads; + if (threadIdx.x == 0) shared[blockDim2 - 1] = (op == scalar::Add) ? 0 : 1; + __syncthreads(); - const int threadsPerBlock = MAX_NUM_THREADS / 2; - const int blocksPerGrid = numTads; - const int sharedMem = 2 * threadsPerBlock * x->sizeOfT() + 128; + for (uint d = 1; d < blockDim2; d *= 2) { + step /= 2; + + __syncthreads(); + if (threadIdx.x < d) { + uint left = step * (sharedInd + 1) - 1; + uint right = step * (sharedInd + 2) - 1; + T temp = shared[left]; + shared[left] = shared[right]; + shared[right] = (op == scalar::Add) ? (shared[right] + temp) + : (shared[right] * temp); + } + } - PointersManager manager(context, "prefix"); + __syncthreads(); - NDArray::prepareSpecialUse({z}, {x}); - BUILD_SINGLE_SELECTOR(x->dataType(), prefixPerBlockCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), op, x->specialBuffer(), packX.platformShapeInfo(), packX.platformOffsets(), z->specialBuffer(), packZ.platformShapeInfo(), packZ.platformOffsets(), numTads, tadLen, exclusive, reverse), NUMERIC_TYPES); - NDArray::registerSpecialUse({z}, {x}); + if (leftArrInd < tadLen) { + T result = shared[sharedInd]; + if (!exclusive) + result = (op == scalar::Add) ? result + xLeft : result * xLeft; + if (i > 0) + result = (op == scalar::Add) ? result + lastElemInChunk + : result * lastElemInChunk; + zTad[shape::getIndexOffset(leftArrInd, zTadShapeInfo)] = result; + } - manager.synchronize(); + if (rightArrInd < tadLen) { + T result = shared[sharedInd + 1]; + if (!exclusive) + result = (op == scalar::Add) ? result + xRight : result * xRight; + if (i > 0) + result = (op == scalar::Add) ? result + lastElemInChunk + : result * lastElemInChunk; + if (i < numTadChunks - 1 && + threadIdx.x == blockDim.x - 1) // last element in chunk + lastElemInChunk = !exclusive ? result + : (op == scalar::Add) ? result + xRight + : result * xRight; + zTad[shape::getIndexOffset(rightArrInd, zTadShapeInfo)] = result; + } + } } /////////////////////////////////////////////////////////////////// -void prefix(sd::LaunchContext * context, scalar::Ops op, const NDArray* x, NDArray* z, bool exclusive, bool reverse) { - prefix(context, op, x, z, {}, exclusive, reverse); +template +static void prefixPerBlockCudaLauncher( + const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, + const cudaStream_t* stream, scalar::Ops op, const void* vx, + const Nd4jLong* xTadShapeInfo, const Nd4jLong* xTadOffsets, void* vz, + const Nd4jLong* zTadShapeInfo, const Nd4jLong* zTadOffsets, + const Nd4jLong numTads, const Nd4jLong tadLen, const bool exclusive, + const bool reverse) { + prefixPerBlockCuda<<>>( + op, vx, xTadShapeInfo, xTadOffsets, vz, zTadShapeInfo, zTadOffsets, + numTads, tadLen, exclusive, reverse); } +/////////////////////////////////////////////////////////////////// +void prefix(sd::LaunchContext* context, scalar::Ops op, const NDArray* x, + NDArray* z, const std::vector& dims, bool exclusive, + bool reverse) { + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions( + x->shapeInfo(), dims); + auto packZ = sd::ConstantTadHelper::getInstance().tadForDimensions( + z->shapeInfo(), dims); + + const Nd4jLong numTads = packX.numberOfTads(); + const Nd4jLong tadLen = x->lengthOf() / numTads; + + const int threadsPerBlock = MAX_NUM_THREADS / 2; + const int blocksPerGrid = numTads; + const int sharedMem = 2 * threadsPerBlock * x->sizeOfT() + 128; + + PointersManager manager(context, "prefix"); + + NDArray::prepareSpecialUse({z}, {x}); + BUILD_SINGLE_SELECTOR( + x->dataType(), prefixPerBlockCudaLauncher, + (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), op, + x->specialBuffer(), packX.platformShapeInfo(), packX.platformOffsets(), + z->specialBuffer(), packZ.platformShapeInfo(), packZ.platformOffsets(), + numTads, tadLen, exclusive, reverse), + NUMERIC_TYPES); + NDArray::registerSpecialUse({z}, {x}); + + manager.synchronize(); } + +/////////////////////////////////////////////////////////////////// +void prefix(sd::LaunchContext* context, scalar::Ops op, const NDArray* x, + NDArray* z, bool exclusive, bool reverse) { + prefix(context, op, x, z, {}, exclusive, reverse); } -} \ No newline at end of file + +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/print_variable.cu b/libnd4j/include/ops/declarable/helpers/cuda/print_variable.cu index 6733ce642087..2f60ce274562 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/print_variable.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/print_variable.cu @@ -18,44 +18,48 @@ // @author raver119@gmail.com // -#include #include +#include namespace sd { - namespace ops { - namespace helpers { - template - static _CUDA_G void print_device(const void *special, const Nd4jLong *shapeInfo) { - auto length = shape::length(shapeInfo); - auto x = reinterpret_cast(special); - - // TODO: add formatting here - printf("["); - - for (uint64_t e = 0; e < length; e++) { - printf("%f", (float) x[shape::getIndexOffset(e, shapeInfo)]); - - if (e < length - 1) - printf(", "); - } - - printf("]\n"); - } - - template - static _CUDA_H void exec_print_device(LaunchContext &ctx, const void *special, const Nd4jLong *shapeInfo) { - print_device<<<1, 1, 1024, *ctx.getCudaStream()>>>(special, shapeInfo); - } - - void print_special(LaunchContext &ctx, const NDArray &array, const std::string &message) { - NDArray::prepareSpecialUse({}, {&array}); - - PointersManager pm(&ctx, "print_device"); - BUILD_SINGLE_SELECTOR(array.dataType(), exec_print_device, (ctx, array.specialBuffer(), array.specialShapeInfo()), LIBND4J_TYPES) - pm.synchronize(); - - NDArray::registerSpecialUse({}, {&array}); - } - } - } +namespace ops { +namespace helpers { +template +static _CUDA_G void print_device(const void *special, + const Nd4jLong *shapeInfo) { + auto length = shape::length(shapeInfo); + auto x = reinterpret_cast(special); + + // TODO: add formatting here + printf("["); + + for (uint64_t e = 0; e < length; e++) { + printf("%f", (float)x[shape::getIndexOffset(e, shapeInfo)]); + + if (e < length - 1) printf(", "); + } + + printf("]\n"); +} + +template +static _CUDA_H void exec_print_device(LaunchContext &ctx, const void *special, + const Nd4jLong *shapeInfo) { + print_device<<<1, 1, 1024, *ctx.getCudaStream()>>>(special, shapeInfo); +} + +void print_special(LaunchContext &ctx, const NDArray &array, + const std::string &message) { + NDArray::prepareSpecialUse({}, {&array}); + + PointersManager pm(&ctx, "print_device"); + BUILD_SINGLE_SELECTOR(array.dataType(), exec_print_device, + (ctx, array.specialBuffer(), array.specialShapeInfo()), + LIBND4J_TYPES) + pm.synchronize(); + + NDArray::registerSpecialUse({}, {&array}); } +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/qr.cu b/libnd4j/include/ops/declarable/helpers/cuda/qr.cu index e499f21d00fb..762e3f3b14fb 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/qr.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/qr.cu @@ -17,163 +17,187 @@ // // @author George A. Shulinok // -#include #include #include +#include namespace sd { namespace ops { namespace helpers { - template - static __global__ void matrixMinorKernel(T* outBuffer, Nd4jLong* outShape, T* inBuffer, Nd4jLong* inShape, Nd4jLong column, Nd4jLong rows, Nd4jLong columns) { -// auto tid = threadIdx.x + blockDim.x * blockIdx.x; -// auto step = blockDim.x * gridDim.x; -// if (threadIdx.x == 0) { -// for (auto i = tid; i < column; i += step) { -// Nd4jLong diagPos[] = {i, i}; -// auto zIndex = shape::getOffset(outShape, diagPos); -// outBuffer[zIndex] = T(1.f); -// } -// } -// __syncthreads(); - - for (auto i = blockIdx.x; i < rows; i += gridDim.x) - for (auto j = threadIdx.x; j < columns; j += blockDim.x) { - Nd4jLong pos[] = {i,j}; - auto zIndex = shape::getOffset(outShape, pos); - auto xIndex = shape::getOffset(inShape, pos); - if (i < column || j < column) { - outBuffer[zIndex] = i != j?T(0.f):T(1.f); - } - else - outBuffer[zIndex] = inBuffer[xIndex]; //m.t(i,j) = in.t(i,j); - } - - +template +static __global__ void matrixMinorKernel(T* outBuffer, Nd4jLong* outShape, + T* inBuffer, Nd4jLong* inShape, + Nd4jLong column, Nd4jLong rows, + Nd4jLong columns) { + // auto tid = threadIdx.x + blockDim.x * blockIdx.x; + // auto step = blockDim.x * gridDim.x; + // if (threadIdx.x == 0) { + // for (auto i = tid; i < column; i += step) { + // Nd4jLong diagPos[] = {i, i}; + // auto zIndex = shape::getOffset(outShape, diagPos); + // outBuffer[zIndex] = T(1.f); + // } + // } + // __syncthreads(); + + for (auto i = blockIdx.x; i < rows; i += gridDim.x) + for (auto j = threadIdx.x; j < columns; j += blockDim.x) { + Nd4jLong pos[] = {i, j}; + auto zIndex = shape::getOffset(outShape, pos); + auto xIndex = shape::getOffset(inShape, pos); + if (i < column || j < column) { + outBuffer[zIndex] = i != j ? T(0.f) : T(1.f); + } else + outBuffer[zIndex] = inBuffer[xIndex]; // m.t(i,j) = in.t(i,j); } +} - template - NDArray matrixMinor(LaunchContext* context, NDArray& in, Nd4jLong col) { - NDArray m = in.ulike(); - m.setIdentity(); - m({col, m.rows(), col, m.columns()}).assign(in({col, m.rows(), col, m.columns()})); - -// auto stream = context->getCudaStream(); -// matrixMinorKernel<<<128, 128, 256, *stream>>>(m.dataBuffer()->specialAsT(), m.special(), -// matrixMinorKernel<<<128, 128, 256, *stream>>>(m.dataBuffer()->specialAsT(), m.special(), -// reinterpret_cast(in.specialBuffer()), in.special(), col, in.rows(), in.columns()); -// - m.tickWriteDevice(); - return m; - } +template +NDArray matrixMinor(LaunchContext* context, NDArray& in, Nd4jLong col) { + NDArray m = in.ulike(); + m.setIdentity(); + m({col, m.rows(), col, m.columns()}) + .assign(in({col, m.rows(), col, m.columns()})); + + // auto stream = context->getCudaStream(); + // matrixMinorKernel<<<128, 128, 256, + // *stream>>>(m.dataBuffer()->specialAsT(), m.special(), + // matrixMinorKernel<<<128, 128, 256, + // *stream>>>(m.dataBuffer()->specialAsT(), m.special(), + // reinterpret_cast(in.specialBuffer()), + // in.special(), col, in.rows(), in.columns()); + // + m.tickWriteDevice(); + return m; +} /* m = I - v v^T */ - template - static __global__ void vmulKernel(T* resBuf, const Nd4jLong* resShape, T const* vBuff, Nd4jLong const* vShape, Nd4jLong n) { - for (auto i = blockIdx.x; i < n; i += gridDim.x) - for (auto j = threadIdx.x; j < n; j += blockDim.x) { - Nd4jLong posR[] = {i, j}; - auto indexR = shape::getOffset(resShape, posR); - auto indexX = shape::getIndexOffset(i, vShape); - auto indexY = shape::getIndexOffset(j, vShape); - - resBuf[indexR] = T(-2.f) * vBuff[indexX] * vBuff[indexY] + (i != j?T(0.f):T(1.f)); - } - } - - template - NDArray vmul(LaunchContext* context, NDArray const& v, int n) - { - NDArray res('c', {n,n}, v.dataType(), context); // x = matrix_new(n, n); - - auto stream = context->getCudaStream(); - vmulKernel<<<128, 128, 128, *stream>>>(res.dataBuffer()->specialAsT(), res.specialShapeInfo(), - reinterpret_cast(v.specialBuffer()), v.specialShapeInfo(), n); - return res; - } - - template - static bool diagonalIsPositive(NDArray* matrix, Nd4jLong k) { - T hVal; - Nd4jLong pos[] = {k, k}; - auto shift = shape::getOffset(matrix->shapeInfo(), pos); - cudaMemcpy(&hVal, matrix->specialBuffer(), sizeof(T), cudaMemcpyDeviceToHost); - return hVal > T(0.f); +template +static __global__ void vmulKernel(T* resBuf, const Nd4jLong* resShape, + T const* vBuff, Nd4jLong const* vShape, + Nd4jLong n) { + for (auto i = blockIdx.x; i < n; i += gridDim.x) + for (auto j = threadIdx.x; j < n; j += blockDim.x) { + Nd4jLong posR[] = {i, j}; + auto indexR = shape::getOffset(resShape, posR); + auto indexX = shape::getIndexOffset(i, vShape); + auto indexY = shape::getIndexOffset(j, vShape); + + resBuf[indexR] = + T(-2.f) * vBuff[indexX] * vBuff[indexY] + (i != j ? T(0.f) : T(1.f)); } +} - template - void qrSingle(LaunchContext* context, NDArray* matrix, NDArray* Q, NDArray* R, bool const fullMatricies) { - Nd4jLong M = matrix->sizeAt(0); - Nd4jLong N = matrix->sizeAt(1); - auto resQ = fullMatricies?Q->ulike():NDArrayFactory::create(matrix->ordering(), {M,M}, Q->getContext()); - auto resR = fullMatricies?R->ulike():matrix->ulike(); - std::vector q(M); - NDArray z = *matrix; - NDArray e('c', {M}, DataTypeUtils::fromT(), context); // two internal buffers and scalar for squared norm - for (auto k = 0; k < N && k < M - 1; k++) { // loop for columns, but not further then row number - e.nullify(); - z = matrixMinor(context, z, k); // minor computing for current column with given matrix z (initally is a input matrix) - - auto currentColumn = z({0, 0, k, k + 1}); // retrieve k column from z to x buffer - auto norm = currentColumn.reduceAlongDimension(reduce::Norm2, {0}); - if (diagonalIsPositive(matrix, k)) //matrix->t(k,k) > T(0.f)) // negate on positive matrix diagonal element - norm.applyTransform(transform::Neg, norm); // *= -1.f;//-norm.t(0); - - e.p(k, norm); // e - is filled by 0 vector except diagonal element (filled by 1) - e += currentColumn; // e[i] = x[i] + a * e[i] for each i from 0 to n - 1 - auto normE = e.reduceAlongDimension(reduce::Norm2, {0}); - e /= normE; - q[k] = vmul(context, e, M); - auto qQ = z.ulike(); - MmulHelper::matmul(&q[k], &z, &qQ, false, false); - z = std::move(qQ); - } - resQ.assign(q[0]); // -// MmulHelper::matmul(&q[0], matrix, &resR, false, false); - for (int i = 1; i < N && i < M - 1; i++) { - auto tempResQ = resQ; - MmulHelper::matmul(&q[i], &resQ, &tempResQ, false, false); - resQ = std::move(tempResQ); - } - MmulHelper::matmul(&resQ, matrix, &resR, false, false); - // resR *= -1.f; - resQ.transposei(); - - if (fullMatricies) { - Q->assign(resQ); - R->assign(resR); - } - else { - Q->assign(resQ({0, 0, 0, N})); - R->assign(resR({0, N, 0, 0})); - } - } +template +NDArray vmul(LaunchContext* context, NDArray const& v, int n) { + NDArray res('c', {n, n}, v.dataType(), context); // x = matrix_new(n, n); - template - void qr_(LaunchContext* context, NDArray const* input, NDArray* outputQ, NDArray* outputR, bool const fullMatricies) { - Nd4jLong lastDim = input->rankOf() - 1; - Nd4jLong preLastDim = input->rankOf() - 2; - - NDArray::prepareSpecialUse({outputQ, outputR}, {input}); - ResultSet listOutQ(outputQ->allTensorsAlongDimension({(int)preLastDim, (int)lastDim})); - ResultSet listOutR(outputR->allTensorsAlongDimension({(int)preLastDim, (int)lastDim})); - ResultSet listInput(input->allTensorsAlongDimension({(int)preLastDim, (int)lastDim})); - auto start = 0; - auto stop = listInput.size(); - auto increment = 1; - - for (auto batch = start; batch < stop; batch += increment) { - //qr here - qrSingle(context, listInput.at(batch), listOutQ.at(batch), listOutR.at(batch), fullMatricies); - } - NDArray::registerSpecialUse({outputQ, outputR}, {input}); - } + auto stream = context->getCudaStream(); + vmulKernel<<<128, 128, 128, *stream>>>( + res.dataBuffer()->specialAsT(), res.specialShapeInfo(), + reinterpret_cast(v.specialBuffer()), v.specialShapeInfo(), n); + return res; +} - void qr(sd::LaunchContext* context, NDArray const* input, NDArray* outputQ, NDArray* outputR, bool const fullMatricies) { - BUILD_SINGLE_SELECTOR(input->dataType(), qr_, (context, input, outputQ, outputR, fullMatricies), FLOAT_TYPES); - } +template +static bool diagonalIsPositive(NDArray* matrix, Nd4jLong k) { + T hVal; + Nd4jLong pos[] = {k, k}; + auto shift = shape::getOffset(matrix->shapeInfo(), pos); + cudaMemcpy(&hVal, matrix->specialBuffer(), sizeof(T), cudaMemcpyDeviceToHost); + return hVal > T(0.f); +} +template +void qrSingle(LaunchContext* context, NDArray* matrix, NDArray* Q, NDArray* R, + bool const fullMatricies) { + Nd4jLong M = matrix->sizeAt(0); + Nd4jLong N = matrix->sizeAt(1); + auto resQ = fullMatricies ? Q->ulike() + : NDArrayFactory::create( + matrix->ordering(), {M, M}, Q->getContext()); + auto resR = fullMatricies ? R->ulike() : matrix->ulike(); + std::vector q(M); + NDArray z = *matrix; + NDArray e('c', {M}, DataTypeUtils::fromT(), + context); // two internal buffers and scalar for squared norm + for (auto k = 0; k < N && k < M - 1; + k++) { // loop for columns, but not further then row number + e.nullify(); + z = matrixMinor(context, z, + k); // minor computing for current column with given + // matrix z (initally is a input matrix) + + auto currentColumn = + z({0, 0, k, k + 1}); // retrieve k column from z to x buffer + auto norm = currentColumn.reduceAlongDimension(reduce::Norm2, {0}); + if (diagonalIsPositive(matrix, + k)) // matrix->t(k,k) > T(0.f)) // negate on + // positive matrix diagonal element + norm.applyTransform(transform::Neg, norm); // *= -1.f;//-norm.t(0); + + e.p(k, norm); // e - is filled by 0 vector except diagonal element (filled + // by 1) + e += currentColumn; // e[i] = x[i] + a * e[i] for each i from 0 to n - 1 + auto normE = e.reduceAlongDimension(reduce::Norm2, {0}); + e /= normE; + q[k] = vmul(context, e, M); + auto qQ = z.ulike(); + MmulHelper::matmul(&q[k], &z, &qQ, false, false); + z = std::move(qQ); + } + resQ.assign(q[0]); // + // MmulHelper::matmul(&q[0], matrix, &resR, false, false); + for (int i = 1; i < N && i < M - 1; i++) { + auto tempResQ = resQ; + MmulHelper::matmul(&q[i], &resQ, &tempResQ, false, false); + resQ = std::move(tempResQ); + } + MmulHelper::matmul(&resQ, matrix, &resR, false, false); + // resR *= -1.f; + resQ.transposei(); + + if (fullMatricies) { + Q->assign(resQ); + R->assign(resR); + } else { + Q->assign(resQ({0, 0, 0, N})); + R->assign(resR({0, N, 0, 0})); + } } + +template +void qr_(LaunchContext* context, NDArray const* input, NDArray* outputQ, + NDArray* outputR, bool const fullMatricies) { + Nd4jLong lastDim = input->rankOf() - 1; + Nd4jLong preLastDim = input->rankOf() - 2; + + NDArray::prepareSpecialUse({outputQ, outputR}, {input}); + ResultSet listOutQ( + outputQ->allTensorsAlongDimension({(int)preLastDim, (int)lastDim})); + ResultSet listOutR( + outputR->allTensorsAlongDimension({(int)preLastDim, (int)lastDim})); + ResultSet listInput( + input->allTensorsAlongDimension({(int)preLastDim, (int)lastDim})); + auto start = 0; + auto stop = listInput.size(); + auto increment = 1; + + for (auto batch = start; batch < stop; batch += increment) { + // qr here + qrSingle(context, &listInput.at(batch), &listOutQ.at(batch), &listOutR.at(batch), fullMatricies); + } + NDArray::registerSpecialUse({outputQ, outputR}, {input}); } + +void qr(sd::LaunchContext* context, NDArray const* input, NDArray* outputQ, + NDArray* outputR, bool const fullMatricies) { + BUILD_SINGLE_SELECTOR(input->dataType(), qr_, + (context, input, outputQ, outputR, fullMatricies), + FLOAT_TYPES); } + +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/random.cu b/libnd4j/include/ops/declarable/helpers/cuda/random.cu index e13883515710..ddb40c23c87b 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/random.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/random.cu @@ -20,15 +20,16 @@ #include //#include -#include -#include -#include -#include -#include #include #include +#include #include #include +#include +#include + +#include +#include namespace sd { namespace ops { @@ -122,334 +123,374 @@ namespace helpers { return (decreasedAlpha * normalizedVar / beta); } - /* - * fillGammaKernel - fill up output with gamma distributed values - * - * uList - uniformly distributed values set - * uLength - length of uList - * alpha - alpha param - * beta - beta param - * output - distributed output. - * */ - template - static __global__ void fillGammaKernel(T const* uList, Nd4jLong uLength, T const* alpha, const Nd4jLong* alphaShape, - T const* beta, const Nd4jLong* betaShape, T* output, const Nd4jLong* outputShape) { - // fill up - __shared__ Nd4jLong aLength; - __shared__ Nd4jLong outLength; - if (threadIdx.x == 0) { - aLength = shape::length(alphaShape); - outLength = shape::length(outputShape) / aLength; - } - __syncthreads(); - - for (auto k = blockIdx.x; k < (int)outLength; k += gridDim.x) { - auto pos = k * aLength; -// auto u = uList[k]; // this is a vector - //Nd4jLong index = k; - for (auto e = threadIdx.x; e < (int)aLength; e += blockDim.x) { - auto aIndex = shape::getIndexOffset(e, alphaShape); - auto bIndex = betaShape?shape::getIndexOffset(e, betaShape):-1LL; - auto betaV = T(beta != nullptr ? beta[bIndex] : T(1.f)); - auto zIndex = shape::getIndexOffset(e + pos, outputShape); - - output[zIndex] = alpha[aIndex] > T(1.f)?gammaGreat(uList, pos, uLength, alpha[aIndex], betaV):gammaLess(uList, pos, uLength, alpha[aIndex], betaV); - } - } +/* + * fillGammaKernel - fill up output with gamma distributed values + * + * uList - uniformly distributed values set + * uLength - length of uList + * alpha - alpha param + * beta - beta param + * output - distributed output. + * */ +template +static __global__ void fillGammaKernel(T const* uList, Nd4jLong uLength, T const* alpha, + const Nd4jLong* alphaShape, T const* beta, + const Nd4jLong* betaShape, T* output, + const Nd4jLong* outputShape) { + // fill up + __shared__ Nd4jLong aLength; + __shared__ Nd4jLong outLength;if (threadIdx.x == 0) { + aLength = shape::length(alphaShape);outLength = shape::length(outputShape) / aLength; + } + __syncthreads(); + + for (auto k = blockIdx.x; k < (int)outLength; k += gridDim.x) { + auto pos = k * aLength; +// auto u = uList[k]; // this is a vector//Nd4jLong index = k; + for (auto e = threadIdx.x; e < (int)aLength; e += blockDim.x) { + auto aIndex = shape::getIndexOffset(e, alphaShape); + auto bIndex = betaShape ? shape::getIndexOffset(e, betaShape) : -1LL; + auto betaV = T(beta != nullptr ? beta[bIndex] : T(1.f)); + auto zIndex = shape::getIndexOffset(e + pos, outputShape); + + output[zIndex] = alpha[aIndex] > T(1.f)?gammaGreat(uList, pos, uLength, alpha[aIndex], betaV):gammaLess(uList, pos, uLength, alpha[aIndex], betaV); } + } +} - template - static void fillRandomGamma_(LaunchContext* context, graph::RandomGenerator& rng, NDArray* alpha, NDArray* beta, NDArray* output) { - // To fill up output need to broadcast alpha and beta to the same shape and in - const Nd4jLong* broadcasted = nullptr; - if (beta != nullptr) - ShapeUtils::evalBroadcastShapeInfo(*alpha, *beta, true, broadcasted, context->getWorkspace()); - else - broadcasted = alpha->shapeInfo(); - auto step = shape::length(broadcasted); - auto shift = output->lengthOf() * 4LL; // 2-wise greater case for uniform vals - - auto copyAlpha = alpha; - auto copyBeta = beta; - if (beta != nullptr) { - NDArray alphaBroadcasted(broadcasted, alpha->dataType(), true, context); - NDArray betaBroadcasted(broadcasted, beta->dataType(), true, context); - - copyAlpha = new NDArray(alphaBroadcasted.applyTrueBroadcast(BroadcastOpsTuple::Assign(), *alpha)); - copyBeta = new NDArray(betaBroadcasted.applyTrueBroadcast(BroadcastOpsTuple::Assign(), *beta)); -// if (!copyAlpha->isActualOnDevice()) copyAlpha->syncToDevice(); -// if (!copyBeta->isActualOnDevice()) copyBeta->syncToDevice(); - } - - auto stream = context->getCudaStream(); - NDArray uniform = NDArrayFactory::create('c', {shift}, context); - uniform.syncToDevice(); - // fill up uniform with given length - RandomLauncher::fillUniform(context, rng, &uniform, 0.0000000001, 0.9999999999); +template +static void fillRandomGamma_(LaunchContext* context, + graph::RandomGenerator& rng, NDArray* alpha, + NDArray* beta, NDArray* output) { + // To fill up output need to broadcast alpha and beta to the same shape and in + const Nd4jLong* broadcasted = nullptr; + if (beta != nullptr) + ShapeUtils::evalBroadcastShapeInfo(*alpha, *beta, true, broadcasted, + context->getWorkspace()); + else + broadcasted = alpha->shapeInfo(); + auto step = shape::length(broadcasted); + auto shift = output->lengthOf() * 4LL; // 2-wise greater case for uniform vals + + auto copyAlpha = alpha; + auto copyBeta = beta; + if (beta != nullptr) { + NDArray alphaBroadcasted(broadcasted, alpha->dataType(), true, context); + NDArray betaBroadcasted(broadcasted, beta->dataType(), true, context); + + copyAlpha = new NDArray(alphaBroadcasted.applyTrueBroadcast( + BroadcastOpsTuple::Assign(), *alpha)); + copyBeta = new NDArray(betaBroadcasted.applyTrueBroadcast(BroadcastOpsTuple::Assign(), *beta)); + } + + auto stream = context->getCudaStream(); + NDArray uniform = NDArrayFactory::create('c', {shift}, context); + uniform.syncToDevice(); + // fill up uniform with given length + RandomLauncher::fillUniform(context, rng, &uniform, 0.0000000001, 0.9999999999); uniform.syncToDevice(); // uniform.printIndexedBuffer("Uniform"); - fillGammaKernel<<<128, 128, 256, *stream>>>(uniform.dataBuffer()->specialAsT(), shift, - copyAlpha->dataBuffer()->specialAsT(), copyAlpha->specialShapeInfo(), - beta?copyBeta->dataBuffer()->specialAsT():(T const*)nullptr, - beta?copyBeta->specialShapeInfo():(Nd4jLong const*)nullptr, - output->dataBuffer()->specialAsT(), output->specialShapeInfo()); - - if (beta != nullptr) { - delete copyAlpha; - delete copyBeta; - //delete broadcasted; - } + fillGammaKernel<<<128, 128, 256, *stream>>>( + uniform.dataBuffer()->specialAsT(), shift, + copyAlpha->dataBuffer()->specialAsT(), copyAlpha->specialShapeInfo(), + beta ? copyBeta->dataBuffer()->specialAsT() : (T const*)nullptr, + beta ? copyBeta->specialShapeInfo() : (Nd4jLong const*)nullptr, + output->dataBuffer()->specialAsT(), output->specialShapeInfo()); + + if (beta != nullptr) { + delete copyAlpha; + delete copyBeta; + // delete broadcasted; + } +} - } +void fillRandomGamma(LaunchContext* context, graph::RandomGenerator& rng, + NDArray* alpha, NDArray* beta, NDArray* output) { + if (beta) + NDArray::prepareSpecialUse({output}, {alpha, beta}); + else + NDArray::prepareSpecialUse({output}, {alpha}); + BUILD_SINGLE_SELECTOR(output->dataType(), fillRandomGamma_, + (context, rng, alpha, beta, output), FLOAT_NATIVE); + if (beta) + NDArray::registerSpecialUse({output}, {alpha, beta}); + else + NDArray::prepareSpecialUse({output}, {alpha}); +} +BUILD_SINGLE_TEMPLATE(template void fillRandomGamma_, + (LaunchContext * context, graph::RandomGenerator& rng, + NDArray* alpha, NDArray* beta, NDArray* output), + FLOAT_NATIVE); - void fillRandomGamma(LaunchContext* context, graph::RandomGenerator& rng, NDArray* alpha, NDArray* beta, NDArray* output) { - if (beta) - NDArray::prepareSpecialUse({output}, {alpha, beta}); - else - NDArray::prepareSpecialUse({output}, {alpha}); - BUILD_SINGLE_SELECTOR(output->dataType(), fillRandomGamma_, (context, rng, alpha, beta, output), FLOAT_NATIVE); - if (beta) - NDArray::registerSpecialUse({output}, {alpha, beta}); - else - NDArray::prepareSpecialUse({output}, {alpha}); +/* + * algorithm Poisson generator based upon the inversion by sequential search + * +init: + Let x ← 0, p ← e−λ, s ← p. + using uniformly random sequence U (u in U) distributed at [0, 1]. +while u > s do: + x ← x + 1. + p ← p * λ / x. + s ← s + p. +return x. + * */ +template +static __global__ void fillPoissonKernel(T* uList, Nd4jLong uLength, T* lambda, + const Nd4jLong* lambdaShape, T* output, + const Nd4jLong* outputShape) { + __shared__ Nd4jLong step; + + if (threadIdx.x == 0) { + step = shape::length(lambdaShape); + } + __syncthreads(); + + for (auto k = blockIdx.x; k < (int)uLength; k += gridDim.x) { + auto pos = k * step; + auto u = uList[k]; + for (auto e = threadIdx.x; e < step; e += blockDim.x) { + auto p = math::nd4j_exp(-lambda[e]); + auto s = p; + auto x = T(0.f); + auto lIndex = shape::getIndexOffset(e, lambdaShape); + auto zIndex = shape::getIndexOffset(e + pos, outputShape); + while (u > s) { + x += T(1.); + p *= lambda[lIndex] / x; + s += p; + } + output[zIndex] = x; } - BUILD_SINGLE_TEMPLATE(template void fillRandomGamma_, (LaunchContext* context, graph::RandomGenerator& rng, NDArray* alpha, NDArray* beta, NDArray* output), FLOAT_NATIVE); - - - /* - * algorithm Poisson generator based upon the inversion by sequential search - * - init: - Let x ← 0, p ← e−λ, s ← p. - using uniformly random sequence U (u in U) distributed at [0, 1]. - while u > s do: - x ← x + 1. - p ← p * λ / x. - s ← s + p. - return x. - * */ - template - static __global__ void fillPoissonKernel(T* uList, Nd4jLong uLength, T* lambda, const Nd4jLong* lambdaShape, - T* output, const Nd4jLong* outputShape) { + } +} - __shared__ Nd4jLong step; +template +static void fillRandomPoisson_(LaunchContext* context, + graph::RandomGenerator& rng, NDArray* lambda, + NDArray* output) { + auto shift = output->lengthOf() / lambda->lengthOf(); + NDArray uniform('c', {shift}, output->dataType()); + auto stream = context->getCudaStream(); + // fill up uniform with given length + RandomLauncher::fillUniform(context, rng, &uniform, 0., 1.); + fillPoissonKernel<<<128, 256, 128, *stream>>>( + uniform.dataBuffer()->specialAsT(), uniform.lengthOf(), + lambda->dataBuffer()->specialAsT(), lambda->specialShapeInfo(), + output->dataBuffer()->specialAsT(), output->specialShapeInfo()); +} - if (threadIdx.x == 0) { - step = shape::length(lambdaShape); - } - __syncthreads(); - - for (auto k = blockIdx.x; k < (int)uLength; k += gridDim.x) { - auto pos = k * step; - auto u = uList[k]; - for (auto e = threadIdx.x; e < step; e += blockDim.x) { - auto p = math::nd4j_exp(-lambda[e]); - auto s = p; - auto x = T(0.f); - auto lIndex = shape::getIndexOffset(e, lambdaShape); - auto zIndex = shape::getIndexOffset(e + pos, outputShape); - while (u > s) { - x += T(1.); - p *= lambda[lIndex] / x; - s += p; - } - output[zIndex] = x; - } - } - } +void fillRandomPoisson(LaunchContext* context, graph::RandomGenerator& rng, + NDArray* lambda, NDArray* output) { + NDArray::prepareSpecialUse({output}, {lambda}); + BUILD_SINGLE_SELECTOR(output->dataType(), fillRandomPoisson_, + (context, rng, lambda, output), FLOAT_NATIVE); + NDArray::registerSpecialUse({output}, {lambda}); +} - template - static void fillRandomPoisson_(LaunchContext* context, graph::RandomGenerator& rng, NDArray* lambda, NDArray* output) { - auto shift = output->lengthOf() / lambda->lengthOf(); - NDArray uniform('c', {shift}, output->dataType()); - auto stream = context->getCudaStream(); - // fill up uniform with given length - RandomLauncher::fillUniform(context, rng, &uniform, 0., 1.); - fillPoissonKernel<<<128, 256, 128, *stream>>>(uniform.dataBuffer()->specialAsT(), uniform.lengthOf(), - lambda->dataBuffer()->specialAsT(), lambda->specialShapeInfo(), - output->dataBuffer()->specialAsT(), output->specialShapeInfo()); - } +BUILD_SINGLE_TEMPLATE(template void fillRandomPoisson_, + (LaunchContext * context, graph::RandomGenerator& rng, + NDArray* lambda, NDArray* output), + FLOAT_NATIVE); + +template +static __global__ void fillUniformKernel(graph::RandomGenerator* devRng, T from, + T to, T* output, + const Nd4jLong* outputShape) { + auto start = blockIdx.x * blockDim.x + threadIdx.x; + auto step = blockDim.x * gridDim.x; + + __shared__ Nd4jLong outputLen; + + if (0 == threadIdx.x) { + outputLen = shape::length(outputShape); + } + __syncthreads(); + + for (auto i = start; i < outputLen; i += step) { + auto zIndex = shape::getIndexOffset(i, outputShape); + output[zIndex] = devRng->relativeT(i, from, to); + } +} - void fillRandomPoisson(LaunchContext* context, graph::RandomGenerator& rng, NDArray* lambda, NDArray* output) { - NDArray::prepareSpecialUse({output}, {lambda}); - BUILD_SINGLE_SELECTOR(output->dataType(), fillRandomPoisson_, (context, rng, lambda, output), FLOAT_NATIVE); - NDArray::registerSpecialUse({output}, {lambda}); +template +static void fillRandomUniform_(LaunchContext* context, + graph::RandomGenerator& rng, NDArray* min, + NDArray* max, NDArray* output) { + T minVal = T(0); + T maxVal = DataTypeUtils::infOrMax(); + if (min) minVal = min->t(0); + if (max) maxVal = max->t(0); + + if (output->isR()) + RandomLauncher::fillUniform(context, rng, output, minVal, maxVal); + else { + auto stream = context->getCudaStream(); + graph::RandomGenerator* devRng; + auto err = cudaMalloc(&devRng, sizeof(graph::RandomGenerator)); + if (err != 0) { + cuda_exception::build( + "fillRandomUniform_: Cannot allocate device memory for random " + "generator due error", + err); } - BUILD_SINGLE_TEMPLATE(template void fillRandomPoisson_, (LaunchContext* context, graph::RandomGenerator& rng, NDArray* lambda, NDArray* output), FLOAT_NATIVE); - - template - static __global__ void fillUniformKernel(graph::RandomGenerator* devRng, T from, T to, T* output, const Nd4jLong* outputShape) { - auto start = blockIdx.x * blockDim.x + threadIdx.x; - auto step = blockDim.x * gridDim.x; - - __shared__ Nd4jLong outputLen; - - if (0 == threadIdx.x) { - outputLen = shape::length(outputShape); - } - __syncthreads(); - - for (auto i = start; i < outputLen; i += step) { - auto zIndex = shape::getIndexOffset(i, outputShape); - output[zIndex] = devRng->relativeT(i, from, to); - } - + err = cudaMemcpy(devRng, &rng, sizeof(graph::RandomGenerator), + cudaMemcpyHostToDevice); + if (err != 0) { + cuda_exception::build( + "fillRandomUniform_: Cannot copy random generator to device", err); } - - template - static void fillRandomUniform_(LaunchContext* context, graph::RandomGenerator& rng, NDArray* min, NDArray* max, NDArray* output) { - T minVal = T(0); - T maxVal = DataTypeUtils::infOrMax(); - if (min) - minVal = min->t(0); - if (max) - maxVal = max->t(0); - - if (output->isR()) - RandomLauncher::fillUniform(context, rng, output, minVal, maxVal); - else { - auto stream = context->getCudaStream(); - graph::RandomGenerator *devRng; - auto err = cudaMalloc(&devRng, sizeof(graph::RandomGenerator)); - if (err != 0) { - cuda_exception::build("fillRandomUniform_: Cannot allocate device memory for random generator due error", err); - } - - err = cudaMemcpy(devRng, &rng, sizeof(graph::RandomGenerator), cudaMemcpyHostToDevice); - if (err != 0) { - cuda_exception::build("fillRandomUniform_: Cannot copy random generator to device", err); - } - auto outputBuf = output->dataBuffer()->specialAsT(); - auto outputShape = output->specialShapeInfo(); - fillUniformKernel<<<128, 128, 128, *stream>>>(devRng, minVal, maxVal, outputBuf, outputShape); - - err = cudaStreamSynchronize(*stream); - if (err != 0) { - cuda_exception::build("fillRandomUniform_: Cannot successfully finish kernel call", err); - } - - err = cudaFree(devRng); - if (err != 0) { - cuda_exception::build("fillRandomUniform_: Cannot deallocate device memory for random generator", err); - } - } + auto outputBuf = output->dataBuffer()->specialAsT(); + auto outputShape = output->specialShapeInfo(); + fillUniformKernel<<<128, 128, 128, *stream>>>(devRng, minVal, maxVal, + outputBuf, outputShape); + + err = cudaStreamSynchronize(*stream); + if (err != 0) { + cuda_exception::build( + "fillRandomUniform_: Cannot successfully finish kernel call", err); } - void fillRandomUniform(LaunchContext* context, graph::RandomGenerator& rng, NDArray* min, NDArray* max, NDArray* output) { - BUILD_SINGLE_SELECTOR(output->dataType(), fillRandomUniform_, (context, rng, min, max, output), NUMERIC_TYPES); + err = cudaFree(devRng); + if (err != 0) { + cuda_exception::build( + "fillRandomUniform_: Cannot deallocate device memory for random " + "generator", + err); } + } +} + +void fillRandomUniform(LaunchContext* context, graph::RandomGenerator& rng, + NDArray* min, NDArray* max, NDArray* output) { + BUILD_SINGLE_SELECTOR(output->dataType(), fillRandomUniform_, + (context, rng, min, max, output), NUMERIC_TYPES); +} /////////////////////////////////////////////////////////////////// // used https://en.wikipedia.org/wiki/Categorical_distribution // methods: gumbel trick + softmax + argmax -template -__global__ static void fillMultiNomialCuda_(graph::RandomGenerator* devRng, const void* vx, const Nd4jLong* xShapeInfo, - void* vz, const Nd4jLong* zShapeInfo, const Nd4jLong batchValue, - const Nd4jLong numOfSamples, const Nd4jLong numOfClassX, - const Nd4jLong dimA, const X minVal, const X maxVal) { - - - const X* x = reinterpret_cast(vx); - Z* z = reinterpret_cast(vz); - - __shared__ Nd4jLong xDimAstride, zDimAstride, xDimCstride, zDimCstride, dimC; - - if (0 == threadIdx.x) { - dimC = (0 == dimA) ? 1 : 0; - zDimAstride = shape::stride(zShapeInfo)[dimA]; - xDimAstride = shape::stride(xShapeInfo)[dimA]; - zDimCstride = shape::stride(zShapeInfo)[dimC]; - xDimCstride = shape::stride(xShapeInfo)[dimC]; - } - __syncthreads(); - - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - - for (Nd4jLong index = tid; index < batchValue*numOfSamples; index += gridDim.x * blockDim.x) { - - Nd4jLong nBatchIndex = index / numOfSamples; - Nd4jLong nSampleIndexInBatch = index - (nBatchIndex * numOfSamples); - - const X* xTad = x + (nBatchIndex * xDimCstride); - Z* zTad = z + (nBatchIndex * zDimCstride); - Z& arg = zTad[nSampleIndexInBatch * zDimAstride]; - - X Max = -minVal; - Nd4jLong nSamplesPerBatch = nBatchIndex * numOfClassX * numOfSamples; - Nd4jLong nClassPerSamples = nSampleIndexInBatch * numOfClassX; - - for (Nd4jLong nClass = 0; nClass < numOfClassX; nClass++) { - Nd4jLong nIndex = nSamplesPerBatch + nClassPerSamples + nClass; - X tValue = (xTad[nClass * xDimAstride] - sd::math::nd4j_log(-sd::math::nd4j_log(devRng->relativeT(nIndex, minVal, maxVal)))); - if (tValue > Max) { - Max = tValue; - arg = nClass; - } - } +template +__global__ static void fillMultiNomialCuda_( + graph::RandomGenerator* devRng, const void* vx, const Nd4jLong* xShapeInfo, + void* vz, const Nd4jLong* zShapeInfo, const Nd4jLong batchValue, + const Nd4jLong numOfSamples, const Nd4jLong numOfClassX, + const Nd4jLong dimA, const X minVal, const X maxVal) { + const X* x = reinterpret_cast(vx); + Z* z = reinterpret_cast(vz); + + __shared__ Nd4jLong xDimAstride, zDimAstride, xDimCstride, zDimCstride, dimC; + + if (0 == threadIdx.x) { + dimC = (0 == dimA) ? 1 : 0; + zDimAstride = shape::stride(zShapeInfo)[dimA]; + xDimAstride = shape::stride(xShapeInfo)[dimA]; + zDimCstride = shape::stride(zShapeInfo)[dimC]; + xDimCstride = shape::stride(xShapeInfo)[dimC]; + } + __syncthreads(); + + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + + for (Nd4jLong index = tid; index < batchValue * numOfSamples; + index += gridDim.x * blockDim.x) { + Nd4jLong nBatchIndex = index / numOfSamples; + Nd4jLong nSampleIndexInBatch = index - (nBatchIndex * numOfSamples); + + const X* xTad = x + (nBatchIndex * xDimCstride); + Z* zTad = z + (nBatchIndex * zDimCstride); + Z& arg = zTad[nSampleIndexInBatch * zDimAstride]; + + X Max = -minVal; + Nd4jLong nSamplesPerBatch = nBatchIndex * numOfClassX * numOfSamples; + Nd4jLong nClassPerSamples = nSampleIndexInBatch * numOfClassX; + + for (Nd4jLong nClass = 0; nClass < numOfClassX; nClass++) { + Nd4jLong nIndex = nSamplesPerBatch + nClassPerSamples + nClass; + X tValue = (xTad[nClass * xDimAstride] - + sd::math::nd4j_log(-sd::math::nd4j_log( + devRng->relativeT(nIndex, minVal, maxVal)))); + if (tValue > Max) { + Max = tValue; + arg = nClass; + } } + } } ////////////////////////////////////////////////////////////////////////// -template +template __host__ static void fillMultiNomialCudaLauncher( - const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t* stream, - graph::RandomGenerator* devRng, const void* vx, const Nd4jLong* xShapeInfo, - void* vz, const Nd4jLong* zShapeInfo, - const Nd4jLong batchValue, const Nd4jLong numOfSamples, - const Nd4jLong numOfClassX, const Nd4jLong dimA){ - - const X minVal = DataTypeUtils::min(); - const X maxVal = 1.0; - - fillMultiNomialCuda_ <<< blocksPerGrid, threadsPerBlock, 256, * stream >>> ( - devRng, vx, xShapeInfo, vz, zShapeInfo, batchValue, - numOfSamples, numOfClassX, dimA, minVal, maxVal); + const int blocksPerGrid, const int threadsPerBlock, + const cudaStream_t* stream, graph::RandomGenerator* devRng, const void* vx, + const Nd4jLong* xShapeInfo, void* vz, const Nd4jLong* zShapeInfo, + const Nd4jLong batchValue, const Nd4jLong numOfSamples, + const Nd4jLong numOfClassX, const Nd4jLong dimA) { + const X minVal = DataTypeUtils::min(); + const X maxVal = 1.0; + + fillMultiNomialCuda_<<>>( + devRng, vx, xShapeInfo, vz, zShapeInfo, batchValue, numOfSamples, + numOfClassX, dimA, minVal, maxVal); } - -/////////////////////////////////////////////////////////////////// -void fillRandomMultiNomial(LaunchContext* context, graph::RandomGenerator& rng, NDArray& input, NDArray& output, const Nd4jLong numOfSamples, const int dimC) { - - Nd4jLong dimA = (0 == dimC) ? 1 : 0; - - const Nd4jLong batchValue = output.sizeAt(dimC); - const Nd4jLong numOfClassX = input.sizeAt(dimA); - - const int threadsPerBlock = MAX_NUM_THREADS / 2; - const int blocksPerGrid = (batchValue * numOfSamples + threadsPerBlock - 1) / threadsPerBlock; - - PointersManager manager(context, "fillMultinomial"); - graph::RandomGenerator *devRng; - - auto err = cudaMalloc(&devRng, sizeof(graph::RandomGenerator)); - if (err != 0) { - cuda_exception::build("fillRandomMultiNomial: Cannot allocate device memory for random generator due error", err); - } - err = cudaStreamSynchronize(*context->getCudaStream()); - if (err != 0) { - cuda_exception::build("fillRandomMultiNomial: Cannot synchronize stream for random generator due error", err); - } - err = cudaMemcpyAsync(devRng, &rng, sizeof(graph::RandomGenerator), cudaMemcpyHostToDevice, *context->getCudaStream()); - if (err != 0) { - cuda_exception::build("fillRandomMultiNomial: Cannot copy random generator to device", err); - } - - NDArray::prepareSpecialUse({ &output }, { &input }); - BUILD_DOUBLE_SELECTOR(input.dataType(), output.dataType(), fillMultiNomialCudaLauncher, - (blocksPerGrid, threadsPerBlock, context->getCudaStream(), devRng, input.specialBuffer(), - input.specialShapeInfo(), output.specialBuffer(), - output.specialShapeInfo(), batchValue, numOfSamples, - numOfClassX, dimA), FLOAT_TYPES, INDEXING_TYPES); - NDArray::registerSpecialUse({ &output }, { &input }); - manager.synchronize(); - - err = cudaFree(devRng); - if (err != 0) { - cuda_exception::build("fillRandomMultiNomial: Cannot deallocate device memory for random generator", err); - } - rng.rewindH(output.lengthOf() * numOfClassX); - } +/////////////////////////////////////////////////////////////////// +void fillRandomMultiNomial(LaunchContext* context, graph::RandomGenerator& rng, + NDArray& input, NDArray& output, + const Nd4jLong numOfSamples, const int dimC) { + Nd4jLong dimA = (0 == dimC) ? 1 : 0; + + const Nd4jLong batchValue = output.sizeAt(dimC); + const Nd4jLong numOfClassX = input.sizeAt(dimA); + + const int threadsPerBlock = MAX_NUM_THREADS / 2; + const int blocksPerGrid = + (batchValue * numOfSamples + threadsPerBlock - 1) / threadsPerBlock; + + PointersManager manager(context, "fillMultinomial"); + graph::RandomGenerator* devRng; + + auto err = cudaMalloc(&devRng, sizeof(graph::RandomGenerator)); + if (err != 0) { + cuda_exception::build( + "fillRandomMultiNomial: Cannot allocate device memory for random " + "generator due error", + err); + } + err = cudaStreamSynchronize(*context->getCudaStream()); + if (err != 0) { + cuda_exception::build( + "fillRandomMultiNomial: Cannot synchronize stream for random generator " + "due error", + err); + } + err = cudaMemcpyAsync(devRng, &rng, sizeof(graph::RandomGenerator), + cudaMemcpyHostToDevice, *context->getCudaStream()); + if (err != 0) { + cuda_exception::build( + "fillRandomMultiNomial: Cannot copy random generator to device", err); + } + + NDArray::prepareSpecialUse({&output}, {&input}); + BUILD_DOUBLE_SELECTOR( + input.dataType(), output.dataType(), fillMultiNomialCudaLauncher, + (blocksPerGrid, threadsPerBlock, context->getCudaStream(), devRng, + input.specialBuffer(), input.specialShapeInfo(), output.specialBuffer(), + output.specialShapeInfo(), batchValue, numOfSamples, numOfClassX, dimA), + FLOAT_TYPES, INDEXING_TYPES); + NDArray::registerSpecialUse({&output}, {&input}); + manager.synchronize(); + + err = cudaFree(devRng); + if (err != 0) { + cuda_exception::build( + "fillRandomMultiNomial: Cannot deallocate device memory for random " + "generator", + err); + } + rng.rewindH(output.lengthOf() * numOfClassX); } -} -} \ No newline at end of file + +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/randomShuffle.cu b/libnd4j/include/ops/declarable/helpers/cuda/randomShuffle.cu index bb7998e60ef3..29725ac86ab5 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/randomShuffle.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/randomShuffle.cu @@ -187,7 +187,7 @@ static void randomShuffle_(sd::LaunchContext* context, NDArray& input, NDArray& for(int i = firstDim - 1; i > 0; --i) { const int j = rng.relativeInt(i) % (i + 1); if(i != j) - subArrsList.at(i)->swapUnsafe(*subArrsList.at(j)); + subArrsList.at(i).swapUnsafe(subArrsList.at(j)); } } else { @@ -204,7 +204,7 @@ static void randomShuffle_(sd::LaunchContext* context, NDArray& input, NDArray& auto func = PRAGMA_THREADS_FOR { for (auto i = start; i < stop; ++i) - subArrsListOut.at(i)->assign(subArrsListIn.at(indices[i])); + subArrsListOut.at(i).assign(subArrsListIn.at(indices[i])); }; samediff::Threads::parallel_for(func, 0, firstDim); diff --git a/libnd4j/include/ops/declarable/helpers/cuda/random_crop.cu b/libnd4j/include/ops/declarable/helpers/cuda/random_crop.cu index 0489103f9287..3401d1e89250 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/random_crop.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/random_crop.cu @@ -20,24 +20,31 @@ #include //#include -#include -#include #include + +#include +#include namespace sd { namespace ops { namespace helpers { - template - static int _randomCropFunctor(graph::Context& context, NDArray* input, NDArray* shape, NDArray* output, int seed) { - return Status::OK(); - } +template +static int _randomCropFunctor(graph::Context& context, NDArray* input, + NDArray* shape, NDArray* output, int seed) { + return Status::OK(); +} - int randomCropFunctor(graph::Context& context, NDArray* input, NDArray* shape, NDArray* output, int seed) { - BUILD_SINGLE_SELECTOR(input->dataType(), return _randomCropFunctor, (context, input, shape, output, seed), FLOAT_TYPES); - } +int randomCropFunctor(graph::Context& context, NDArray* input, NDArray* shape, + NDArray* output, int seed) { + BUILD_SINGLE_SELECTOR(input->dataType(), return _randomCropFunctor, + (context, input, shape, output, seed), FLOAT_TYPES); +} - BUILD_SINGLE_TEMPLATE(template int _randomCropFunctor, (graph::Context& context, NDArray* input, NDArray* shape, NDArray* output, int seed), FLOAT_TYPES); +BUILD_SINGLE_TEMPLATE(template int _randomCropFunctor, + (graph::Context & context, NDArray* input, NDArray* shape, + NDArray* output, int seed), + FLOAT_TYPES); -} -} -} \ No newline at end of file +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/range.cu b/libnd4j/include/ops/declarable/helpers/cuda/range.cu index e33f95c52fc3..de50a258fa12 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/range.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/range.cu @@ -18,36 +18,40 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 27.08.2018 // - #include namespace sd { namespace ops { namespace helpers { - template - static __global__ void global_range(void *output, Nd4jLong length, T start, T delta) { - auto buff = reinterpret_cast(output); - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - const auto step = gridDim.x * blockDim.x; - - for(Nd4jLong i = tid; i < length; i += step) - buff[i] = start + i * delta; - } - - ////////////////////////////////////////////////////////////////////////// - // be careful: outVector must have c-order and ews = 1 !!! - template - static void _range(sd::LaunchContext * context, const NDArray& start, const NDArray& delta, NDArray& outVector) { - global_range<<<512, 512, 2048, *context->getCudaStream()>>>(outVector.specialBuffer(), outVector.lengthOf(), start.e(0), delta.e(0)); - } - - void range(sd::LaunchContext * context, const NDArray& start, const NDArray& delta, NDArray& outVector) { - NDArray::prepareSpecialUse({&outVector}, {&start, &delta}); - BUILD_SINGLE_SELECTOR(outVector.dataType(), _range, (context, start, delta, outVector), LIBND4J_TYPES); - NDArray::registerSpecialUse({&outVector}, {&start, &delta}); - } +template +static __global__ void global_range(void* output, Nd4jLong length, T start, + T delta) { + auto buff = reinterpret_cast(output); + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + const auto step = gridDim.x * blockDim.x; + for (Nd4jLong i = tid; i < length; i += step) buff[i] = start + i * delta; } + +////////////////////////////////////////////////////////////////////////// +// be careful: outVector must have c-order and ews = 1 !!! +template +static void _range(sd::LaunchContext* context, const NDArray& start, + const NDArray& delta, NDArray& outVector) { + global_range<<<512, 512, 2048, *context->getCudaStream()>>>( + outVector.specialBuffer(), outVector.lengthOf(), start.e(0), + delta.e(0)); } -} \ No newline at end of file + +void range(sd::LaunchContext* context, const NDArray& start, + const NDArray& delta, NDArray& outVector) { + NDArray::prepareSpecialUse({&outVector}, {&start, &delta}); + BUILD_SINGLE_SELECTOR(outVector.dataType(), _range, + (context, start, delta, outVector), LIBND4J_TYPES); + NDArray::registerSpecialUse({&outVector}, {&start, &delta}); +} + +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/reverse.cu b/libnd4j/include/ops/declarable/helpers/cuda/reverse.cu index 6ae1b22a8d77..a16d0dcbf1fc 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/reverse.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/reverse.cu @@ -18,214 +18,257 @@ // @author Yurii Shyrma, created on 16.04.2018 // -#include -#include #include -#include -#include #include +#include +#include +#include +#include - -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { - template - static __global__ void reverseTadKernel(const void* vinput, const Nd4jLong *inputShape, void* voutput, const Nd4jLong *outputShape, const Nd4jLong *inputTadShape, const Nd4jLong *inputTadOffsets, const Nd4jLong *outputTadShape, const Nd4jLong *outputTadOffsets, uint64_t limit, uint64_t numOfElemsToReverse, uint64_t numTads) { - auto input = reinterpret_cast(vinput); - auto output = reinterpret_cast(voutput); - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - const auto step = gridDim.x * blockDim.x; - - // this means that we'll have additional cycle, to move middle element - auto div = numOfElemsToReverse / 2; - auto odd = numOfElemsToReverse % 2 != 0; - auto rlimit = odd ? limit / 2 + 1 : limit / 2; - - // all threads operate in the same input/output space - for (uint64_t e = tid; e < rlimit; e += step) { - // finding out the TAD we're going to process - auto tadId = e / div; - - if (tadId >= numTads) - continue; - - // now finding out element within tad - auto idx = e % div; - - //printf("TID: %i; numTads: %lld; tadLength: %lld; tadId: %i, idx: %lld\n", tid, numTads, numOfElemsToReverse, tadId, idx); - - auto tadInput = input + inputTadOffsets[tadId]; - auto tadOutput = output + outputTadOffsets[tadId]; - - // we're calculating offsets within input TAD - auto fOffset = shape::getIndexOffset(idx, inputTadShape); - auto lOffset = shape::getIndexOffset(numOfElemsToReverse - idx - 1, inputTadShape); - - // now we're storing input values - auto v1 = tadInput[fOffset]; - auto v2 = tadInput[lOffset]; - - // now we're calculating offsets within output TAD - auto zfOffset = shape::getIndexOffset(idx, outputTadShape); - auto zlOffset = shape::getIndexOffset(numOfElemsToReverse - idx - 1, outputTadShape); - - // and saving values to output arrays - tadOutput[zfOffset] = v2; - tadOutput[zlOffset] = v1; - } - - // moving odd element in blocks - if (odd && threadIdx.x == 0) { - for (uint64_t e = blockIdx.x; e < numTads; e += gridDim.x) { - auto tadInput = input + inputTadOffsets[e]; - auto tadOutput = output + outputTadOffsets[e]; - - auto xOffset = shape::getIndexOffset(numOfElemsToReverse / 2, inputTadShape); - auto zOffset = shape::getIndexOffset(numOfElemsToReverse / 2, outputTadShape); - - tadOutput[zOffset] = tadInput[xOffset]; - } - } - - } - - - template - static __global__ void reverseArrayKernel(const void* input, const Nd4jLong *inputShape, void* output, const Nd4jLong *outputShape, Nd4jLong numOfElemsToReverse) { - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - const auto step = gridDim.x * blockDim.x; - __shared__ int linearStatus; - __shared__ const T* inputArr; - __shared__ T* outputArr; - __shared__ char inputOrder, outputOrder; - - if (threadIdx.x == 0) { - linearStatus = (shape::elementWiseStride(inputShape) == shape::elementWiseStride(outputShape)) && (inputOrder == outputOrder)? shape::elementWiseStride(inputShape):0; - - char inputOrder = shape::order(inputShape); - char outputOrder = shape::order(outputShape); - inputArr = reinterpret_cast(input); - outputArr = reinterpret_cast(output); - } - __syncthreads(); - - auto odd = numOfElemsToReverse % 2 != 0; - auto limit = numOfElemsToReverse / 2; - - for (uint64_t e = tid; e < limit; e += step) { - // we're calculating offsets within input array - auto fOffset = shape::getIndexOffset(e, inputShape); - auto lOffset = shape::getIndexOffset(numOfElemsToReverse - e - 1, inputShape); - - // now we're storing input values - auto v1 = inputArr[fOffset]; - auto v2 = inputArr[lOffset]; - - // now we're calculating offsets within output array - auto zfOffset = shape::getIndexOffset(e, outputShape); - auto zlOffset = shape::getIndexOffset(numOfElemsToReverse - e - 1, outputShape); - - // and saving values to output arrays - outputArr[zfOffset] = v2; - outputArr[zlOffset] = v1; - } - - // in case of odd array we'll have to move middle value - if (odd && tid == 0) { - auto xOffset = shape::getIndexOffset(limit, inputShape); - auto zOffset = shape::getIndexOffset(limit, outputShape); - - outputArr[zOffset] = inputArr[xOffset]; - } +template +static __global__ void reverseTadKernel( + const void* vinput, const Nd4jLong* inputShape, void* voutput, + const Nd4jLong* outputShape, const Nd4jLong* inputTadShape, + const Nd4jLong* inputTadOffsets, const Nd4jLong* outputTadShape, + const Nd4jLong* outputTadOffsets, uint64_t limit, + uint64_t numOfElemsToReverse, uint64_t numTads) { + auto input = reinterpret_cast(vinput); + auto output = reinterpret_cast(voutput); + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + const auto step = gridDim.x * blockDim.x; + + // this means that we'll have additional cycle, to move middle element + auto div = numOfElemsToReverse / 2; + auto odd = numOfElemsToReverse % 2 != 0; + auto rlimit = odd ? limit / 2 + 1 : limit / 2; + + // all threads operate in the same input/output space + for (uint64_t e = tid; e < rlimit; e += step) { + // finding out the TAD we're going to process + auto tadId = e / div; + + if (tadId >= numTads) continue; + + // now finding out element within tad + auto idx = e % div; + + // printf("TID: %i; numTads: %lld; tadLength: %lld; tadId: %i, idx: %lld\n", + // tid, numTads, numOfElemsToReverse, tadId, idx); + + auto tadInput = input + inputTadOffsets[tadId]; + auto tadOutput = output + outputTadOffsets[tadId]; + + // we're calculating offsets within input TAD + auto fOffset = shape::getIndexOffset(idx, inputTadShape); + auto lOffset = + shape::getIndexOffset(numOfElemsToReverse - idx - 1, inputTadShape); + + // now we're storing input values + auto v1 = tadInput[fOffset]; + auto v2 = tadInput[lOffset]; + + // now we're calculating offsets within output TAD + auto zfOffset = shape::getIndexOffset(idx, outputTadShape); + auto zlOffset = + shape::getIndexOffset(numOfElemsToReverse - idx - 1, outputTadShape); + + // and saving values to output arrays + tadOutput[zfOffset] = v2; + tadOutput[zlOffset] = v1; + } + + // moving odd element in blocks + if (odd && threadIdx.x == 0) { + for (uint64_t e = blockIdx.x; e < numTads; e += gridDim.x) { + auto tadInput = input + inputTadOffsets[e]; + auto tadOutput = output + outputTadOffsets[e]; + + auto xOffset = + shape::getIndexOffset(numOfElemsToReverse / 2, inputTadShape); + auto zOffset = + shape::getIndexOffset(numOfElemsToReverse / 2, outputTadShape); + + tadOutput[zOffset] = tadInput[xOffset]; } + } +} - template - static void reverseTad(sd::LaunchContext * context, const NDArray* input, NDArray* output, const Nd4jLong *inputTadShape, const Nd4jLong *inputTadOffsets, const Nd4jLong *outputTadShape, const Nd4jLong *outputTadOffsets, uint64_t tadLength) { - auto stream = context->getCudaStream(); - reverseTadKernel<<<256, 512, 8192, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), inputTadShape, inputTadOffsets, outputTadShape, outputTadOffsets, input->lengthOf(), tadLength, input->lengthOf() / tadLength); - } +template +static __global__ void reverseArrayKernel(const void* input, + const Nd4jLong* inputShape, + void* output, + const Nd4jLong* outputShape, + Nd4jLong numOfElemsToReverse) { + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + const auto step = gridDim.x * blockDim.x; + __shared__ int linearStatus; + __shared__ const T* inputArr; + __shared__ T* outputArr; + __shared__ char inputOrder, outputOrder; + + if (threadIdx.x == 0) { + linearStatus = (shape::elementWiseStride(inputShape) == + shape::elementWiseStride(outputShape)) && + (inputOrder == outputOrder) + ? shape::elementWiseStride(inputShape) + : 0; + + char inputOrder = shape::order(inputShape); + char outputOrder = shape::order(outputShape); + inputArr = reinterpret_cast(input); + outputArr = reinterpret_cast(output); + } + __syncthreads(); + + auto odd = numOfElemsToReverse % 2 != 0; + auto limit = numOfElemsToReverse / 2; + + for (uint64_t e = tid; e < limit; e += step) { + // we're calculating offsets within input array + auto fOffset = shape::getIndexOffset(e, inputShape); + auto lOffset = + shape::getIndexOffset(numOfElemsToReverse - e - 1, inputShape); + + // now we're storing input values + auto v1 = inputArr[fOffset]; + auto v2 = inputArr[lOffset]; + + // now we're calculating offsets within output array + auto zfOffset = shape::getIndexOffset(e, outputShape); + auto zlOffset = + shape::getIndexOffset(numOfElemsToReverse - e - 1, outputShape); + + // and saving values to output arrays + outputArr[zfOffset] = v2; + outputArr[zlOffset] = v1; + } + + // in case of odd array we'll have to move middle value + if (odd && tid == 0) { + auto xOffset = shape::getIndexOffset(limit, inputShape); + auto zOffset = shape::getIndexOffset(limit, outputShape); + + outputArr[zOffset] = inputArr[xOffset]; + } +} - template - static void reverseArray(sd::LaunchContext * context, const NDArray* input, NDArray* output, Nd4jLong numOfElemsToReverse) { - auto stream = context->getCudaStream(); - Nd4jLong numOfReverse = numOfElemsToReverse; - if (numOfElemsToReverse == 0) - numOfReverse = input->lengthOf(); +template +static void reverseTad(sd::LaunchContext* context, const NDArray* input, + NDArray* output, const Nd4jLong* inputTadShape, + const Nd4jLong* inputTadOffsets, + const Nd4jLong* outputTadShape, + const Nd4jLong* outputTadOffsets, uint64_t tadLength) { + auto stream = context->getCudaStream(); + reverseTadKernel<<<256, 512, 8192, *stream>>>( + input->specialBuffer(), input->specialShapeInfo(), + output->specialBuffer(), output->specialShapeInfo(), inputTadShape, + inputTadOffsets, outputTadShape, outputTadOffsets, input->lengthOf(), + tadLength, input->lengthOf() / tadLength); +} - reverseArrayKernel<<<256, 512, 8192, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), numOfReverse); - } +template +static void reverseArray(sd::LaunchContext* context, const NDArray* input, + NDArray* output, Nd4jLong numOfElemsToReverse) { + auto stream = context->getCudaStream(); + Nd4jLong numOfReverse = numOfElemsToReverse; + if (numOfElemsToReverse == 0) numOfReverse = input->lengthOf(); + reverseArrayKernel<<<256, 512, 8192, *stream>>>( + input->specialBuffer(), input->specialShapeInfo(), + output->specialBuffer(), output->specialShapeInfo(), numOfReverse); +} - /////////////////////////////////////////////////////////////////// - template - static void reverseSequence_(sd::LaunchContext * context, const NDArray* input, const NDArray* seqLengths, NDArray* output, int seqDim, const int batchDim){ - int posOfNonUnityDim = -1; - seqLengths->syncToHost(); - auto stream = context->getCudaStream(); - - if(input->isVector() || shape::isLikeVector(input->shapeInfo(), posOfNonUnityDim) || seqLengths->lengthOf() == 1) { - int numOfElemsToReverse = seqLengths->e(0); - if((seqDim == 0 && input->sizeAt(0) == 1) || (batchDim == posOfNonUnityDim)) - output->assign(input); - else - reverseArrayKernel<<<256, 512, 8192, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), numOfElemsToReverse);//helpers::reverseArray(context, const_cast(input), output, numOfElemsToReverse); - } - else { - - if(seqDim > batchDim) - --seqDim; - - std::vector dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {batchDim}); - - auto inSubArrsSet = input->allTensorsAlongDimension(dimensions); - auto outSubArrsSet = output->allTensorsAlongDimension(dimensions); - - for(int i = 0; i < inSubArrsSet.size(); ++i) { - - int numOfElemsToReverse = seqLengths->e(i); - - if(numOfElemsToReverse == 0 || numOfElemsToReverse == 1) { - outSubArrsSet.at(i)->assign(inSubArrsSet.at(i)); - } - else { - auto inInnerSet = inSubArrsSet.at(i)->allTensorsAlongDimension({seqDim}); - auto outInnerSet = outSubArrsSet.at(i)->allTensorsAlongDimension({seqDim}); - for(int j = 0; j < inInnerSet.size(); ++j) - reverseArray(context, inInnerSet.at(j), outInnerSet.at(j), numOfElemsToReverse); - } - } - } +/////////////////////////////////////////////////////////////////// +template +static void reverseSequence_(sd::LaunchContext* context, const NDArray* input, + const NDArray* seqLengths, NDArray* output, + int seqDim, const int batchDim) { + int posOfNonUnityDim = -1; + seqLengths->syncToHost(); + auto stream = context->getCudaStream(); + + if (input->isVector() || + shape::isLikeVector(input->shapeInfo(), posOfNonUnityDim) || + seqLengths->lengthOf() == 1) { + int numOfElemsToReverse = seqLengths->e(0); + if ((seqDim == 0 && input->sizeAt(0) == 1) || + (batchDim == posOfNonUnityDim)) + output->assign(input); + else + reverseArrayKernel<<<256, 512, 8192, *stream>>>( + input->specialBuffer(), input->specialShapeInfo(), + output->specialBuffer(), output->specialShapeInfo(), + numOfElemsToReverse); // helpers::reverseArray(context, + // const_cast(input), output, + // numOfElemsToReverse); + } else { + if (seqDim > batchDim) --seqDim; + + std::vector dimensions = + ShapeUtils::evalDimsToExclude(input->rankOf(), {batchDim}); + + auto inSubArrsSet = input->allTensorsAlongDimension(dimensions); + auto outSubArrsSet = output->allTensorsAlongDimension(dimensions); + + for (int i = 0; i < inSubArrsSet.size(); ++i) { + int numOfElemsToReverse = seqLengths->e(i); + + if (numOfElemsToReverse == 0 || numOfElemsToReverse == 1) { + outSubArrsSet.at(i).assign(inSubArrsSet.at(i)); + } else { + auto inInnerSet = + inSubArrsSet.at(i).allTensorsAlongDimension({seqDim}); + auto outInnerSet = + outSubArrsSet.at(i).allTensorsAlongDimension({seqDim}); + for (int j = 0; j < inInnerSet.size(); ++j) + reverseArray(context, &inInnerSet.at(j), &outInnerSet.at(j), numOfElemsToReverse); + } } + } +} - void reverseSequence(sd::LaunchContext * context, const NDArray* input, const NDArray* seqLengths, NDArray* output, int seqDim, const int batchDim) { - NDArray::prepareSpecialUse({output}, {input, seqLengths}); +void reverseSequence(sd::LaunchContext* context, const NDArray* input, + const NDArray* seqLengths, NDArray* output, int seqDim, + const int batchDim) { + NDArray::prepareSpecialUse({output}, {input, seqLengths}); - // if op isn't inplace - copy original data into output array - if (output->specialBuffer() != input->specialBuffer()) - output->assign(input); + // if op isn't inplace - copy original data into output array + if (output->specialBuffer() != input->specialBuffer()) output->assign(input); - BUILD_SINGLE_SELECTOR(input->dataType(), reverseSequence_, (context, input, seqLengths, output, seqDim, batchDim), LIBND4J_TYPES); - NDArray::registerSpecialUse({output}, {input, seqLengths}); - } + BUILD_SINGLE_SELECTOR(input->dataType(), reverseSequence_, + (context, input, seqLengths, output, seqDim, batchDim), + LIBND4J_TYPES); + NDArray::registerSpecialUse({output}, {input, seqLengths}); +} - ////////////////////////////////////////////////////////////////////////// - void reverse(sd::LaunchContext * context, const NDArray* input, NDArray* output, const std::vector* intArgs) { +////////////////////////////////////////////////////////////////////////// +void reverse(sd::LaunchContext* context, const NDArray* input, NDArray* output, + const std::vector* intArgs) { - auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(input->shapeInfo(), *intArgs); - auto packZ = sd::ConstantTadHelper::getInstance().tadForDimensions(output->shapeInfo(), *intArgs); + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(input->shapeInfo(), *intArgs); + auto packZ = sd::ConstantTadHelper::getInstance().tadForDimensions(output->shapeInfo(), *intArgs); - NDArray::prepareSpecialUse({output}, {input}); + NDArray::prepareSpecialUse({output}, {input}); - if (packX.numberOfTads() == 1) { - BUILD_SINGLE_SELECTOR(input->dataType(), reverseArray, (context, input, output, 0), LIBND4J_TYPES); - } else { - BUILD_SINGLE_SELECTOR(input->dataType(), reverseTad, (context, input, output, packX.platformShapeInfo(), packX.platformOffsets(), packZ.platformShapeInfo(), packZ.platformOffsets(), (uint64_t) (input->lengthOf() / packX.numberOfTads())), LIBND4J_TYPES); - } + if (packX.numberOfTads() == 1) { + BUILD_SINGLE_SELECTOR(input->dataType(), reverseArray, (context, input, output, 0), LIBND4J_TYPES); + } else { + BUILD_SINGLE_SELECTOR( + input->dataType(), reverseTad, + (context, input, output, packX.platformShapeInfo(), + packX.platformOffsets(), packZ.platformShapeInfo(), + packZ.platformOffsets(), + (uint64_t)(input->lengthOf() / packX.numberOfTads())), + LIBND4J_TYPES); + } - NDArray::registerSpecialUse({output}, {input}); - } -} -} + NDArray::registerSpecialUse({output}, {input}); } +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/roll.cu b/libnd4j/include/ops/declarable/helpers/cuda/roll.cu index a5149c978eb9..01db45ba6e78 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/roll.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/roll.cu @@ -18,314 +18,373 @@ // @author raver119@gmail.com // -#include #include #include +#include namespace sd { namespace ops { namespace helpers { - template - static void _CUDA_D rollKernelLinearStage1Dev(const void *vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo, Nd4jLong fullLength, int actualShift) { - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); +template +static void _CUDA_D rollKernelLinearStage1Dev( + const void *vx, const Nd4jLong *xShapeInfo, void *vz, + const Nd4jLong *zShapeInfo, Nd4jLong fullLength, int actualShift) { + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); - auto xEws = shape::elementWiseStride(xShapeInfo); - auto zEws = shape::elementWiseStride(zShapeInfo); + auto xEws = shape::elementWiseStride(xShapeInfo); + auto zEws = shape::elementWiseStride(zShapeInfo); - auto xOrder = shape::order(xShapeInfo); - auto zOrder = shape::order(zShapeInfo); + auto xOrder = shape::order(xShapeInfo); + auto zOrder = shape::order(zShapeInfo); - auto tid = threadIdx.x + blockIdx.x * blockDim.x; + auto tid = threadIdx.x + blockIdx.x * blockDim.x; - if (xEws > 0 && zEws > 0 && xOrder == zOrder) { - for (int i = tid; i < actualShift; i += blockDim.x * gridDim.x) { - int sourceIndex = fullLength - actualShift + i; + if (xEws > 0 && zEws > 0 && xOrder == zOrder) { + for (int i = tid; i < actualShift; i += blockDim.x * gridDim.x) { + int sourceIndex = fullLength - actualShift + i; - auto eA = x[sourceIndex * xEws]; - auto eB = x[i * xEws]; + auto eA = x[sourceIndex * xEws]; + auto eB = x[i * xEws]; - z[i * zEws] = eA; - z[sourceIndex * zEws] = eB; - } - } else { - for (int i = tid; i < actualShift; i += blockDim.x * gridDim.x) { - int sourceIndex = fullLength - actualShift + i; + z[i * zEws] = eA; + z[sourceIndex * zEws] = eB; + } + } else { + for (int i = tid; i < actualShift; i += blockDim.x * gridDim.x) { + int sourceIndex = fullLength - actualShift + i; - auto xOffsetA = shape::getIndexOffset(i, xShapeInfo); - auto xOffsetB = shape::getIndexOffset(sourceIndex, xShapeInfo); + auto xOffsetA = shape::getIndexOffset(i, xShapeInfo); + auto xOffsetB = shape::getIndexOffset(sourceIndex, xShapeInfo); - auto zOffsetA = shape::getIndexOffset(i, zShapeInfo); - auto zOffsetB = shape::getIndexOffset(sourceIndex, zShapeInfo); + auto zOffsetA = shape::getIndexOffset(i, zShapeInfo); + auto zOffsetB = shape::getIndexOffset(sourceIndex, zShapeInfo); - auto eA = x[xOffsetA]; - auto eB = x[xOffsetB]; + auto eA = x[xOffsetA]; + auto eB = x[xOffsetB]; - z[zOffsetA] = eB; - z[zOffsetB] = eA; - } - } + z[zOffsetA] = eB; + z[zOffsetB] = eA; } + } +} - template - static void _CUDA_G rollKernelLinearStage1(const void *vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo, Nd4jLong fullLength, int actualShift) { - rollKernelLinearStage1Dev(vx, xShapeInfo, vz, zShapeInfo, fullLength, actualShift); - } +template +static void _CUDA_G rollKernelLinearStage1(const void *vx, + const Nd4jLong *xShapeInfo, void *vz, + const Nd4jLong *zShapeInfo, + Nd4jLong fullLength, + int actualShift) { + rollKernelLinearStage1Dev(vx, xShapeInfo, vz, zShapeInfo, fullLength, + actualShift); +} - template - static void _CUDA_G rollKernelLinearStage2(const void *vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo, Nd4jLong fullLength, int actualShift, int shiftCount) { - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); +template +static void _CUDA_G rollKernelLinearStage2(const void *vx, + const Nd4jLong *xShapeInfo, void *vz, + const Nd4jLong *zShapeInfo, + Nd4jLong fullLength, int actualShift, + int shiftCount) { + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); - auto xEws = shape::elementWiseStride(xShapeInfo); - auto zEws = shape::elementWiseStride(zShapeInfo); + auto xEws = shape::elementWiseStride(xShapeInfo); + auto zEws = shape::elementWiseStride(zShapeInfo); - auto xOrder = shape::order(xShapeInfo); - auto zOrder = shape::order(zShapeInfo); + auto xOrder = shape::order(xShapeInfo); + auto zOrder = shape::order(zShapeInfo); - auto tid = threadIdx.x + blockIdx.x * blockDim.x; + auto tid = threadIdx.x + blockIdx.x * blockDim.x; - if (xEws > 0 && zEws > 0 && xOrder == zOrder) { - for (int count = 1; count < shiftCount; ++count) { - for (int i = tid; i < actualShift; i += blockDim.x * gridDim.x) { - int destinationIndex = fullLength - (count + 1) * actualShift + i; - int sourceIndex = fullLength - count * actualShift + i; + if (xEws > 0 && zEws > 0 && xOrder == zOrder) { + for (int count = 1; count < shiftCount; ++count) { + for (int i = tid; i < actualShift; i += blockDim.x * gridDim.x) { + int destinationIndex = fullLength - (count + 1) * actualShift + i; + int sourceIndex = fullLength - count * actualShift + i; - auto eA = x[sourceIndex * xEws]; - auto eB = x[destinationIndex * xEws]; + auto eA = x[sourceIndex * xEws]; + auto eB = x[destinationIndex * xEws]; - z[destinationIndex * zEws] = eA; - z[sourceIndex * zEws] = eB; - } + z[destinationIndex * zEws] = eA; + z[sourceIndex * zEws] = eB; + } - __syncthreads(); - } - } else { - for (int count = 1; count < shiftCount; ++count) { - for (int i = tid; i < actualShift; i += blockDim.x * gridDim.x) { - int destinationIndex = fullLength - (count + 1) * actualShift + i; - int sourceIndex = fullLength - count * actualShift + i; + __syncthreads(); + } + } else { + for (int count = 1; count < shiftCount; ++count) { + for (int i = tid; i < actualShift; i += blockDim.x * gridDim.x) { + int destinationIndex = fullLength - (count + 1) * actualShift + i; + int sourceIndex = fullLength - count * actualShift + i; - auto xOffsetA = shape::getIndexOffset(destinationIndex, xShapeInfo); - auto xOffsetB = shape::getIndexOffset(sourceIndex, xShapeInfo); + auto xOffsetA = shape::getIndexOffset(destinationIndex, xShapeInfo); + auto xOffsetB = shape::getIndexOffset(sourceIndex, xShapeInfo); - auto zOffsetA = shape::getIndexOffset(destinationIndex, zShapeInfo); - auto zOffsetB = shape::getIndexOffset(sourceIndex, zShapeInfo); + auto zOffsetA = shape::getIndexOffset(destinationIndex, zShapeInfo); + auto zOffsetB = shape::getIndexOffset(sourceIndex, zShapeInfo); - auto eA = x[xOffsetA]; - auto eB = x[xOffsetB]; + auto eA = x[xOffsetA]; + auto eB = x[xOffsetB]; - z[zOffsetA] = eB; - z[zOffsetB] = eA; - } + z[zOffsetA] = eB; + z[zOffsetB] = eA; + } - __syncthreads(); - } - } + __syncthreads(); } + } +} - template - static void _CUDA_G rollKernelLinearStage3(const void *vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo, Nd4jLong fullLength, int actualShift, int remainShift) { - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - - auto xEws = shape::elementWiseStride(xShapeInfo); - auto zEws = shape::elementWiseStride(zShapeInfo); - - auto xOrder = shape::order(xShapeInfo); - auto zOrder = shape::order(zShapeInfo); - - auto tid = threadIdx.x + blockIdx.x * blockDim.x; - - if (xEws > 0 && zEws > 0 && xOrder == zOrder) { - for (int i = tid ; i < actualShift; i += blockDim.x * gridDim.x) { - int remainIdx = i + actualShift; - int sourceIndex = remainIdx + remainShift; +template +static void _CUDA_G rollKernelLinearStage3(const void *vx, + const Nd4jLong *xShapeInfo, void *vz, + const Nd4jLong *zShapeInfo, + Nd4jLong fullLength, int actualShift, + int remainShift) { + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); - auto eA = x[sourceIndex * xEws]; - auto eB = x[remainIdx * xEws]; + auto xEws = shape::elementWiseStride(xShapeInfo); + auto zEws = shape::elementWiseStride(zShapeInfo); - z[remainIdx * zEws] = eA; - z[sourceIndex * zEws] = eB; - } - } else { - for (int i = tid; i < actualShift; i += blockDim.x * gridDim.x) { - int remainIdx = i + actualShift; - int sourceIndex = remainIdx + remainShift; + auto xOrder = shape::order(xShapeInfo); + auto zOrder = shape::order(zShapeInfo); - auto xOffsetA = shape::getIndexOffset(remainIdx, xShapeInfo); - auto xOffsetB = shape::getIndexOffset(sourceIndex, xShapeInfo); + auto tid = threadIdx.x + blockIdx.x * blockDim.x; - auto zOffsetA = shape::getIndexOffset(remainIdx, zShapeInfo); - auto zOffsetB = shape::getIndexOffset(sourceIndex, zShapeInfo); + if (xEws > 0 && zEws > 0 && xOrder == zOrder) { + for (int i = tid; i < actualShift; i += blockDim.x * gridDim.x) { + int remainIdx = i + actualShift; + int sourceIndex = remainIdx + remainShift; - auto eA = x[xOffsetA]; - auto eB = x[xOffsetB]; + auto eA = x[sourceIndex * xEws]; + auto eB = x[remainIdx * xEws]; - z[zOffsetA] = eB; - z[zOffsetB] = eA; - } - } + z[remainIdx * zEws] = eA; + z[sourceIndex * zEws] = eB; } + } else { + for (int i = tid; i < actualShift; i += blockDim.x * gridDim.x) { + int remainIdx = i + actualShift; + int sourceIndex = remainIdx + remainShift; - template - static void _CUDA_D swapTadsKernel(void *vx, void *vz, const Nd4jLong *zShapeInfo, Nd4jLong tadLength) { - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); + auto xOffsetA = shape::getIndexOffset(remainIdx, xShapeInfo); + auto xOffsetB = shape::getIndexOffset(sourceIndex, xShapeInfo); - auto zEws = shape::elementWiseStride(zShapeInfo); + auto zOffsetA = shape::getIndexOffset(remainIdx, zShapeInfo); + auto zOffsetB = shape::getIndexOffset(sourceIndex, zShapeInfo); - auto zOrder = shape::order(zShapeInfo); + auto eA = x[xOffsetA]; + auto eB = x[xOffsetB]; - auto tid = threadIdx.x + blockIdx.x * blockDim.x; - - if (zEws > 0) { - for (int e = threadIdx.x; e < tadLength; e += blockDim.x) { - auto eA = x[e * zEws]; - auto eB = z[e * zEws]; - - x[e * zEws] = eB; - z[e * zEws] = eA; - } - } else { - for (int e = threadIdx.x; e < tadLength; e += blockDim.x) { - auto zOffset = shape::getIndexOffset(e, zShapeInfo); - - auto eA = x[zOffset]; - auto eB = z[zOffset]; - - x[zOffset] = eB; - z[zOffset] = eA; - } - } + z[zOffsetA] = eB; + z[zOffsetB] = eA; } + } +} - template - static void _CUDA_G rollKernelFullAnyDimensionStage1(const void *vx, const Nd4jLong *xTadShapeInfo, const Nd4jLong *xTadOffsets, void *vz, const Nd4jLong *zTadShapeInfo, const Nd4jLong *zTadOffsets, int numTads, Nd4jLong tadLength, int dim, Nd4jLong sizeAt, int theShift) { - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); +template +static void _CUDA_D swapTadsKernel(void *vx, void *vz, + const Nd4jLong *zShapeInfo, + Nd4jLong tadLength) { + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); - for (int e = blockIdx.x + theShift; e < sizeAt - theShift; e += gridDim.x) { - int sourceIndex = dim * sizeAt + e - theShift; - int targetIndex = dim * sizeAt + e; + auto zEws = shape::elementWiseStride(zShapeInfo); - swapTadsKernel(z + xTadOffsets[sourceIndex], z + xTadOffsets[targetIndex], zTadShapeInfo, tadLength); - } - } + auto zOrder = shape::order(zShapeInfo); - template - static void _CUDA_G rollKernelFullAnyDimensionStage2(void *vx, const Nd4jLong *xTadShapeInfo, const Nd4jLong *xTadOffsets, void *vz, const Nd4jLong *zTadShapeInfo, const Nd4jLong *zTadOffsets, int numTads, Nd4jLong tadLength, int dim, Nd4jLong sizeAt, int theShift) { - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); + auto tid = threadIdx.x + blockIdx.x * blockDim.x; - for (int e = blockIdx.x; e < theShift; e += gridDim.x) { - int sourceIndex = dim * sizeAt + sizeAt - theShift + e; - int targetIndex = dim * sizeAt + e; - - swapTadsKernel(z + zTadOffsets[sourceIndex], z + zTadOffsets[targetIndex], zTadShapeInfo, tadLength); - } - } + if (zEws > 0) { + for (int e = threadIdx.x; e < tadLength; e += blockDim.x) { + auto eA = x[e * zEws]; + auto eB = z[e * zEws]; - template - static void rollFunctorFull_(NDArray* input, NDArray* output, std::vector const& shifts, std::vector const& axes, bool inplace){ - if (!inplace) - output->assign(input); - - for (size_t i = 0; i < axes.size(); i++) { - int axe = axes[i]; - if (axe == input->rankOf() - 1) { // last dimension - ResultSet listOfTensors = output->allTensorsAlongDimension({axe}); - ResultSet listOfOutTensors = output->allTensorsAlongDimension({axe}); - int fullLen = listOfTensors.size(); - int theShift = shifts[i]; -// if (theShift > 0) { -// theShift %= fullLen; -// } -// else { -// theShift -= fullLen * (theShift / fullLen - 1); -// } - for (int k = 0; k < fullLen; k++) { - rollFunctorLinear(output->getContext(), listOfTensors.at(k), listOfOutTensors.at(k), theShift, true); - } - } else { - std::vector dims(input->rankOf() - axe - 1); - for (int i = 0; i < dims.size(); ++i) - dims[i] = axe + 1 + i; - - auto packZ = ConstantTadHelper::getInstance().tadForDimensions(output->shapeInfo(), dims); - - int numTads = packZ.numberOfTads(); - int sizeAt = input->sizeAt(axe); - auto tadLength = shape::length(packZ.primaryShapeInfo()); - - int theShift = shifts[i]; - -// if (theShift > 0) -// theShift %= sizeAt; -// else -// theShift -= sizeAt * (theShift / sizeAt - 1); - - if (theShift) { - for (int dim = 0; dim < numTads / sizeAt; ++dim) { - - rollKernelFullAnyDimensionStage1<<<1, 256, 1024, *(output->getContext()->getCudaStream())>>>(output->specialBuffer(), packZ.platformShapeInfo(), packZ.platformOffsets(), output->specialBuffer(), packZ.platformShapeInfo(), packZ.platformOffsets(), numTads, tadLength, dim, sizeAt, theShift); - - rollKernelFullAnyDimensionStage2<<<1, 256, 1024, *(output->getContext()->getCudaStream())>>>(output->specialBuffer(), packZ.platformShapeInfo(), packZ.platformOffsets(), output->specialBuffer(), packZ.platformShapeInfo(), packZ.platformOffsets(), numTads, tadLength, dim, sizeAt, theShift); - } - } - } - } + x[e * zEws] = eB; + z[e * zEws] = eA; } + } else { + for (int e = threadIdx.x; e < tadLength; e += blockDim.x) { + auto zOffset = shape::getIndexOffset(e, zShapeInfo); - template - static void rollFunctorLinear_(NDArray* input, NDArray* output, int shift, bool inplace){ - if (!inplace) - output->assign(input); + auto eA = x[zOffset]; + auto eB = z[zOffset]; - auto fullLen = input->lengthOf(); - int actualShift = shift; // % fullLen; // shift already non-negative then - if (actualShift < 0) { - actualShift -= fullLen * (actualShift / fullLen - 1); - } - else - actualShift %= fullLen; - - if (actualShift) { - int shiftCount = fullLen / actualShift - 1; - int remainShift = fullLen % actualShift; + x[zOffset] = eB; + z[zOffset] = eA; + } + } +} - // stage 1) swap last actualShift elements with first ones. - rollKernelLinearStage1<<<1, 1, 1024, *(output->getContext()->getCudaStream())>>>(output->specialBuffer(), output->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), fullLen, actualShift); +template +static void _CUDA_G rollKernelFullAnyDimensionStage1( + const void *vx, const Nd4jLong *xTadShapeInfo, const Nd4jLong *xTadOffsets, + void *vz, const Nd4jLong *zTadShapeInfo, const Nd4jLong *zTadOffsets, + int numTads, Nd4jLong tadLength, int dim, Nd4jLong sizeAt, int theShift) { + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + + for (int e = blockIdx.x + theShift; e < sizeAt - theShift; e += gridDim.x) { + int sourceIndex = dim * sizeAt + e - theShift; + int targetIndex = dim * sizeAt + e; + + swapTadsKernel(z + xTadOffsets[sourceIndex], + z + xTadOffsets[targetIndex], zTadShapeInfo, tadLength); + } +} - // stage 2) swap swapped actualShift elements with rest remainShiftCount times. - rollKernelLinearStage2<<<1, 1, 1024, *(output->getContext()->getCudaStream())>>>(output->specialBuffer(), output->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), fullLen, actualShift, shiftCount); +template +static void _CUDA_G rollKernelFullAnyDimensionStage2( + void *vx, const Nd4jLong *xTadShapeInfo, const Nd4jLong *xTadOffsets, + void *vz, const Nd4jLong *zTadShapeInfo, const Nd4jLong *zTadOffsets, + int numTads, Nd4jLong tadLength, int dim, Nd4jLong sizeAt, int theShift) { + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + + for (int e = blockIdx.x; e < theShift; e += gridDim.x) { + int sourceIndex = dim * sizeAt + sizeAt - theShift + e; + int targetIndex = dim * sizeAt + e; + + swapTadsKernel(z + zTadOffsets[sourceIndex], + z + zTadOffsets[targetIndex], zTadShapeInfo, tadLength); + } +} - // FIXME: no parallelism here :( - // stage 3) swap remainer of items. - if (remainShift && shiftCount) - rollKernelLinearStage3<<<1, 1, 1024, *(output->getContext()->getCudaStream())>>>(output->specialBuffer(), output->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), fullLen, actualShift, remainShift); +template +static void rollFunctorFull_(NDArray *input, NDArray *output, + std::vector const &shifts, + std::vector const &axes, bool inplace) { + if (!inplace) output->assign(input); + + for (size_t i = 0; i < axes.size(); i++) { + int axe = axes[i]; + if (axe == input->rankOf() - 1) { // last dimension + ResultSet listOfTensors = output->allTensorsAlongDimension({axe}); + ResultSet listOfOutTensors = output->allTensorsAlongDimension({axe}); + int fullLen = listOfTensors.size(); + int theShift = shifts[i]; + // if (theShift > 0) { + // theShift %= fullLen; + // } + // else { + // theShift -= fullLen * (theShift / fullLen - 1); + // } + for (int k = 0; k < fullLen; k++) { + rollFunctorLinear(output->getContext(), &listOfTensors.at(k), &listOfOutTensors.at(k), theShift, true); + } + } else { + std::vector dims(input->rankOf() - axe - 1); + for (int i = 0; i < dims.size(); ++i) dims[i] = axe + 1 + i; + + auto packZ = ConstantTadHelper::getInstance().tadForDimensions( + output->shapeInfo(), dims); + + int numTads = packZ.numberOfTads(); + int sizeAt = input->sizeAt(axe); + auto tadLength = shape::length(packZ.primaryShapeInfo()); + + int theShift = shifts[i]; + + // if (theShift > 0) + // theShift %= sizeAt; + // else + // theShift -= sizeAt * (theShift / sizeAt - 1); + + if (theShift) { + for (int dim = 0; dim < numTads / sizeAt; ++dim) { + rollKernelFullAnyDimensionStage1 + <<<1, 256, 1024, *(output->getContext()->getCudaStream())>>>( + output->specialBuffer(), packZ.platformShapeInfo(), + packZ.platformOffsets(), output->specialBuffer(), + packZ.platformShapeInfo(), packZ.platformOffsets(), numTads, + tadLength, dim, sizeAt, theShift); + + rollKernelFullAnyDimensionStage2 + <<<1, 256, 1024, *(output->getContext()->getCudaStream())>>>( + output->specialBuffer(), packZ.platformShapeInfo(), + packZ.platformOffsets(), output->specialBuffer(), + packZ.platformShapeInfo(), packZ.platformOffsets(), numTads, + tadLength, dim, sizeAt, theShift); } + } } + } +} - void rollFunctorFull(sd::LaunchContext * context, NDArray* input, NDArray* output, std::vector const& shifts, std::vector const& axes, bool inplace){ - input->syncToDevice(); +template +static void rollFunctorLinear_(NDArray *input, NDArray *output, int shift, + bool inplace) { + if (!inplace) output->assign(input); + + auto fullLen = input->lengthOf(); + int actualShift = shift; // % fullLen; // shift already non-negative then + if (actualShift < 0) { + actualShift -= fullLen * (actualShift / fullLen - 1); + } else + actualShift %= fullLen; + + if (actualShift) { + int shiftCount = fullLen / actualShift - 1; + int remainShift = fullLen % actualShift; + + // stage 1) swap last actualShift elements with first ones. + rollKernelLinearStage1 + <<<1, 1, 1024, *(output->getContext()->getCudaStream())>>>( + output->specialBuffer(), output->specialShapeInfo(), + output->specialBuffer(), output->specialShapeInfo(), fullLen, + actualShift); + + // stage 2) swap swapped actualShift elements with rest remainShiftCount + // times. + rollKernelLinearStage2 + <<<1, 1, 1024, *(output->getContext()->getCudaStream())>>>( + output->specialBuffer(), output->specialShapeInfo(), + output->specialBuffer(), output->specialShapeInfo(), fullLen, + actualShift, shiftCount); + + // FIXME: no parallelism here :( + // stage 3) swap remainer of items. + if (remainShift && shiftCount) + rollKernelLinearStage3 + <<<1, 1, 1024, *(output->getContext()->getCudaStream())>>>( + output->specialBuffer(), output->specialShapeInfo(), + output->specialBuffer(), output->specialShapeInfo(), fullLen, + actualShift, remainShift); + } +} - BUILD_SINGLE_SELECTOR(input->dataType(), rollFunctorFull_, (input, output, shifts, axes, inplace), LIBND4J_TYPES); +void rollFunctorFull(sd::LaunchContext *context, NDArray *input, + NDArray *output, std::vector const &shifts, + std::vector const &axes, bool inplace) { + input->syncToDevice(); - output->tickWriteDevice(); - } + BUILD_SINGLE_SELECTOR(input->dataType(), rollFunctorFull_, + (input, output, shifts, axes, inplace), LIBND4J_TYPES); - void rollFunctorLinear(sd::LaunchContext * context, NDArray* input, NDArray* output, int shift, bool inplace){ - input->syncToDevice(); + output->tickWriteDevice(); +} - BUILD_SINGLE_SELECTOR(input->dataType(), rollFunctorLinear_, (input, output, shift, inplace), LIBND4J_TYPES); +void rollFunctorLinear(sd::LaunchContext *context, NDArray *input, + NDArray *output, int shift, bool inplace) { + input->syncToDevice(); - output->tickWriteDevice(); - } + BUILD_SINGLE_SELECTOR(input->dataType(), rollFunctorLinear_, + (input, output, shift, inplace), LIBND4J_TYPES); - BUILD_SINGLE_TEMPLATE(template void rollFunctorLinear_, (NDArray* input, NDArray* output, int shift, bool inplace), LIBND4J_TYPES); - BUILD_SINGLE_TEMPLATE(template void rollFunctorFull_, (NDArray* input, NDArray* output, std::vector const& shifts, std::vector const& axes, bool inplace), LIBND4J_TYPES); + output->tickWriteDevice(); } -} -} \ No newline at end of file + +BUILD_SINGLE_TEMPLATE(template void rollFunctorLinear_, + (NDArray * input, NDArray *output, int shift, + bool inplace), + LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template void rollFunctorFull_, + (NDArray * input, NDArray *output, + std::vector const &shifts, + std::vector const &axes, bool inplace), + LIBND4J_TYPES); +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/s_t_b.cu b/libnd4j/include/ops/declarable/helpers/cuda/s_t_b.cu index 8b7bfb2b5e64..265dc9c46411 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/s_t_b.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/s_t_b.cu @@ -19,486 +19,577 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include #include +#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { - /////////////////////////////////////////////////////////////////// -template -__global__ static void batchToSpaceCuda(const void* vx, const Nd4jLong* xShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const uint cropBottom, const uint cropLeft) { +template +__global__ static void batchToSpaceCuda(const void* vx, + const Nd4jLong* xShapeInfo, void* vz, + const Nd4jLong* zShapeInfo, + const uint cropBottom, + const uint cropLeft) { + // input [bS, H * blockSize, W * blockSize, iC] + // output [bS, H * blockSize - cropBottom - cropTop, W * blockSize - cropLeft + // - cropRight, iC] - // input [bS, H * blockSize, W * blockSize, iC] - // output [bS, H * blockSize - cropBottom - cropTop, W * blockSize - cropLeft - cropRight, iC] + // if (cropTop = cropBottom = cropRight = cropLeft = 0) shapes are the same + // else: + // oH -> [cropBottom, iH - cropTop] + // oW -> [cropLeft, iH - cropRight] + // xLen >= zLen - // if (cropTop = cropBottom = cropRight = cropLeft = 0) shapes are the same - // else: - // oH -> [cropBottom, iH - cropTop] - // oW -> [cropLeft, iH - cropRight] - // xLen >= zLen + const auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); - const auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); + __shared__ int rank, *sharedMem; + __shared__ Nd4jLong zLen; - __shared__ int rank, *sharedMem; - __shared__ Nd4jLong zLen; + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sharedMem = reinterpret_cast(shmem); - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - sharedMem = reinterpret_cast(shmem); + rank = shape::rank(zShapeInfo); + zLen = shape::length(zShapeInfo); + } + __syncthreads(); - rank = shape::rank(zShapeInfo); - zLen = shape::length(zShapeInfo); - } - __syncthreads(); + auto coords = sharedMem + threadIdx.x * rank; - auto coords = sharedMem + threadIdx.x * rank; + const auto i = blockIdx.x * blockDim.x + threadIdx.x; - const auto i = blockIdx.x * blockDim.x + threadIdx.x; + if (i >= zLen) return; - if(i >= zLen) - return; + shape::index2coords(i, zShapeInfo, coords); - shape::index2coords(i, zShapeInfo, coords); + const auto zOffset = shape::getOffset(zShapeInfo, coords); - const auto zOffset = shape::getOffset(zShapeInfo, coords); + coords[1] += cropBottom; + coords[2] += cropLeft; - coords[1] += cropBottom; - coords[2] += cropLeft; - - const auto xOffset = shape::getOffset(xShapeInfo, coords); - - z[zOffset] = x[xOffset]; + const auto xOffset = shape::getOffset(xShapeInfo, coords); + z[zOffset] = x[xOffset]; } /////////////////////////////////////////////////////////////////// -template -static void batchToSpaceCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void* vx, const Nd4jLong* xShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const uint cropBottom, const uint cropLeft) { - - batchToSpaceCuda<<>>(vx, xShapeInfo, vz, zShapeInfo, cropBottom, cropLeft); +template +static void batchToSpaceCudaLauncher( + const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, + const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo, + void* vz, const Nd4jLong* zShapeInfo, const uint cropBottom, + const uint cropLeft) { + batchToSpaceCuda<<>>( + vx, xShapeInfo, vz, zShapeInfo, cropBottom, cropLeft); } -BUILD_SINGLE_TEMPLATE(template void batchToSpaceCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void* vx, const Nd4jLong* xShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const uint cropBottom, const uint cropLeft), LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template void batchToSpaceCudaLauncher, + (const int blocksPerGrid, const int threadsPerBlock, + const int sharedMem, const cudaStream_t* stream, + const void* vx, const Nd4jLong* xShapeInfo, void* vz, + const Nd4jLong* zShapeInfo, const uint cropBottom, + const uint cropLeft), + LIBND4J_TYPES); /////////////////////////////////////////////////////////////////// -void batchToSpace(sd::LaunchContext* context, const NDArray& input, NDArray& output, const uint cropBottom, const uint cropTop, const uint cropLeft, const uint cropRight, const uint blockSize) { - - // [bS*blockSize*blockSize, H/blockSize, W/blockSize, iC] is rearranged/permuted to [bS, oH, oW, iC] - // oH = H - cropTop - cropBottom - // oW = W - cropLeft - cropRight - - NDArray inputRearranged0 = input.reshape(input.ordering(), {blockSize, blockSize, output.sizeAt(0), input.sizeAt(1), input.sizeAt(2), input.sizeAt(3)}); - inputRearranged0.permutei({2, 3,0, 4,1, 5}); - - if(input.lengthOf() == output.lengthOf()) { - - output.assign(inputRearranged0); - } - else { - - NDArray inputRearranged1 = inputRearranged0.reshape(input.ordering(), {output.sizeAt(0), input.sizeAt(1) * blockSize, input.sizeAt(2) * blockSize, input.sizeAt(3)}); - - const int threadsPerBlock = MAX_NUM_THREADS / 2; - const int blocksPerGrid = (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - const int sharedMem = threadsPerBlock * sizeof(int) * output.rankOf() + 128; - - PointersManager manager(context, "batchToSpace"); - - NDArray::prepareSpecialUse({&output}, {&inputRearranged1}); - BUILD_SINGLE_SELECTOR(input.dataType(), batchToSpaceCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), inputRearranged1.specialBuffer(), inputRearranged1.specialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), cropBottom, cropLeft), LIBND4J_TYPES); - NDArray::registerSpecialUse({&output}, {&inputRearranged1}); - - manager.synchronize(); - } +void batchToSpace(sd::LaunchContext* context, const NDArray& input, + NDArray& output, const uint cropBottom, const uint cropTop, + const uint cropLeft, const uint cropRight, + const uint blockSize) { + // [bS*blockSize*blockSize, H/blockSize, W/blockSize, iC] is + // rearranged/permuted to [bS, oH, oW, iC] oH = H - cropTop - cropBottom oW = + // W - cropLeft - cropRight + + NDArray inputRearranged0 = input.reshape( + input.ordering(), {blockSize, blockSize, output.sizeAt(0), + input.sizeAt(1), input.sizeAt(2), input.sizeAt(3)}); + inputRearranged0.permutei({2, 3, 0, 4, 1, 5}); + + if (input.lengthOf() == output.lengthOf()) { + output.assign(inputRearranged0); + } else { + NDArray inputRearranged1 = inputRearranged0.reshape( + input.ordering(), {output.sizeAt(0), input.sizeAt(1) * blockSize, + input.sizeAt(2) * blockSize, input.sizeAt(3)}); + + const int threadsPerBlock = MAX_NUM_THREADS / 2; + const int blocksPerGrid = + (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + const int sharedMem = threadsPerBlock * sizeof(int) * output.rankOf() + 128; + + PointersManager manager(context, "batchToSpace"); + + NDArray::prepareSpecialUse({&output}, {&inputRearranged1}); + BUILD_SINGLE_SELECTOR( + input.dataType(), batchToSpaceCudaLauncher, + (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), + inputRearranged1.specialBuffer(), inputRearranged1.specialShapeInfo(), + output.specialBuffer(), output.specialShapeInfo(), cropBottom, + cropLeft), + LIBND4J_TYPES); + NDArray::registerSpecialUse({&output}, {&inputRearranged1}); + + manager.synchronize(); + } } - - /////////////////////////////////////////////////////////////////// -template -__global__ static void batchToSpaceNDCuda(const void* vx, const Nd4jLong* xShapeInfo, - const void* vy, const Nd4jLong* yShapeInfo, - void* vz, const Nd4jLong* zShapeInfo, +template +__global__ static void batchToSpaceNDCuda(const void* vx, + const Nd4jLong* xShapeInfo, + const void* vy, + const Nd4jLong* yShapeInfo, void* vz, + const Nd4jLong* zShapeInfo, const uint numOfSpatialDims) { + // 4D example, numOfSpatialDims = 2 + // input [bS, H * blockShape[0], W * blockShape[1], iC] + // output [bS, H * blockShape[0] - cropBottom - cropTop, W * blockShape[1] - + // cropLeft - cropRight, iC] - // 4D example, numOfSpatialDims = 2 - // input [bS, H * blockShape[0], W * blockShape[1], iC] - // output [bS, H * blockShape[0] - cropBottom - cropTop, W * blockShape[1] - cropLeft - cropRight, iC] - - // if (cropTop = cropBottom = cropRight = cropLeft = 0) shapes are the same - // else: - // oH -> [cropBottom, iH - cropTop] - // oW -> [cropLeft, iH - cropRight] - // xLen >= zLen + // if (cropTop = cropBottom = cropRight = cropLeft = 0) shapes are the same + // else: + // oH -> [cropBottom, iH - cropTop] + // oW -> [cropLeft, iH - cropRight] + // xLen >= zLen - const auto x = reinterpret_cast(vx); - const auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); + const auto x = reinterpret_cast(vx); + const auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); - __shared__ int rank, *sharedMem; - __shared__ Nd4jLong zLen; + __shared__ int rank, *sharedMem; + __shared__ Nd4jLong zLen; - if (threadIdx.x == 0) { - - extern __shared__ unsigned char shmem[]; - sharedMem = reinterpret_cast(shmem); - - rank = shape::rank(zShapeInfo); - zLen = shape::length(zShapeInfo); - } + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sharedMem = reinterpret_cast(shmem); - __syncthreads(); + rank = shape::rank(zShapeInfo); + zLen = shape::length(zShapeInfo); + } - auto coords = sharedMem + threadIdx.x * rank; + __syncthreads(); - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < zLen; i += gridDim.x * blockDim.x) { + auto coords = sharedMem + threadIdx.x * rank; - shape::index2coords(i, zShapeInfo, coords); + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < zLen; + i += gridDim.x * blockDim.x) { + shape::index2coords(i, zShapeInfo, coords); - const auto zOffset = shape::getOffset(zShapeInfo, coords); + const auto zOffset = shape::getOffset(zShapeInfo, coords); - // evaluate spatial coordinates for x - for(uint j = 1; j <= numOfSpatialDims; ++j) { - const auto yOffset = (j - 1) * yShapeInfo[3]; // yRank = 2, calculate offset manually - coords[j] += y[yOffset]; // add crop left - } + // evaluate spatial coordinates for x + for (uint j = 1; j <= numOfSpatialDims; ++j) { + const auto yOffset = + (j - 1) * yShapeInfo[3]; // yRank = 2, calculate offset manually + coords[j] += y[yOffset]; // add crop left + } - const auto xOffset = shape::getOffset(xShapeInfo, coords); + const auto xOffset = shape::getOffset(xShapeInfo, coords); - z[zOffset] = x[xOffset]; - } + z[zOffset] = x[xOffset]; + } } /////////////////////////////////////////////////////////////////// -template -static void batchToSpaceNDCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void* vx, const Nd4jLong* xShapeInfo, const void* vy, const Nd4jLong* yShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const uint numOfSpatialDims) { - - batchToSpaceNDCuda<<>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, numOfSpatialDims); +template +static void batchToSpaceNDCudaLauncher( + const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, + const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo, + const void* vy, const Nd4jLong* yShapeInfo, void* vz, + const Nd4jLong* zShapeInfo, const uint numOfSpatialDims) { + batchToSpaceNDCuda + <<>>( + vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, numOfSpatialDims); } -BUILD_DOUBLE_TEMPLATE(template void batchToSpaceNDCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void* vx, const Nd4jLong* xShapeInfo, const void* vy, const Nd4jLong* yShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const uint numOfSpatialDims), LIBND4J_TYPES, INTEGER_TYPES); +BUILD_DOUBLE_TEMPLATE(template void batchToSpaceNDCudaLauncher, + (const int blocksPerGrid, const int threadsPerBlock, + const int sharedMem, const cudaStream_t* stream, + const void* vx, const Nd4jLong* xShapeInfo, + const void* vy, const Nd4jLong* yShapeInfo, void* vz, + const Nd4jLong* zShapeInfo, const uint numOfSpatialDims), + LIBND4J_TYPES, INTEGER_TYPES); ////////////////////////////////////////////////////////////////////////// -void batchToSpaceND(sd::LaunchContext* context, const NDArray& input, const NDArray& blockShape, const NDArray& crop, NDArray& output) { - - // 4D example, numOfSpatialDims = 2 - two spatial dimensions - // [bS*blockShape[0]*blockShape[1], iH, iW, iC] is rearranged/permuted to [bS, iH*blockShape[0] - cropTop - cropBottom, iW*blockShape[1] - cropLeft - cropRight, iC] - - const uint rank = input.rankOf(); - const uint numOfSpatialDims = blockShape.sizeAt(0); +void batchToSpaceND(sd::LaunchContext* context, const NDArray& input, + const NDArray& blockShape, const NDArray& crop, + NDArray& output) { + // 4D example, numOfSpatialDims = 2 - two spatial dimensions + // [bS*blockShape[0]*blockShape[1], iH, iW, iC] is rearranged/permuted to [bS, + // iH*blockShape[0] - cropTop - cropBottom, iW*blockShape[1] - cropLeft - + // cropRight, iC] - //*** construct reshaping std::vector for first reshape of input array ***// + const uint rank = input.rankOf(); + const uint numOfSpatialDims = blockShape.sizeAt(0); - std::vector temp(numOfSpatialDims + rank); + //*** construct reshaping std::vector for first reshape of input array ***// - int i; - for(i = 0; i < numOfSpatialDims; ++i) - temp[i] = blockShape.e(i); - temp[i++] = output.sizeAt(0); - for(int j = 1; j < rank; ++i, ++j) - temp[i] = input.sizeAt(j); + std::vector temp(numOfSpatialDims + rank); - NDArray inputRearranged0 = input.reshape(input.ordering(), temp); + int i; + for (i = 0; i < numOfSpatialDims; ++i) temp[i] = blockShape.e(i); + temp[i++] = output.sizeAt(0); + for (int j = 1; j < rank; ++i, ++j) temp[i] = input.sizeAt(j); - //*** construct permuting std::vector for permutation of input array ***// + NDArray inputRearranged0 = input.reshape(input.ordering(), temp); - temp[0] = numOfSpatialDims; + //*** construct permuting std::vector for permutation of input array ***// - for(i = 1; i <= numOfSpatialDims; ++i) { - temp[2*i - 1] = numOfSpatialDims + i; - temp[2*i] = i - 1; - } - for(i = 2 * numOfSpatialDims + 1; i < temp.size(); ++i) - temp[i] = i; - - inputRearranged0.permutei(temp); + temp[0] = numOfSpatialDims; + for (i = 1; i <= numOfSpatialDims; ++i) { + temp[2 * i - 1] = numOfSpatialDims + i; + temp[2 * i] = i - 1; + } + for (i = 2 * numOfSpatialDims + 1; i < temp.size(); ++i) temp[i] = i; - if(input.lengthOf() == output.lengthOf()) { + inputRearranged0.permutei(temp); - output.assign(inputRearranged0); - } - else { - //*** construct reshaping std::vector for second reshape of input array ***// + if (input.lengthOf() == output.lengthOf()) { + output.assign(inputRearranged0); + } else { + //*** construct reshaping std::vector for second reshape of input array + //***// - temp.resize(rank); + temp.resize(rank); - temp[0] = output.sizeAt(0); + temp[0] = output.sizeAt(0); - for(i = 1; i < rank; ++i) - temp[i] = (i <= numOfSpatialDims) ? input.sizeAt(i) * blockShape.e(i - 1) : input.sizeAt(i); + for (i = 1; i < rank; ++i) + temp[i] = (i <= numOfSpatialDims) + ? input.sizeAt(i) * blockShape.e(i - 1) + : input.sizeAt(i); - NDArray inputRearranged1 = inputRearranged0.reshape(input.ordering(), temp); + NDArray inputRearranged1 = inputRearranged0.reshape(input.ordering(), temp); - const int threadsPerBlock = MAX_NUM_THREADS / 4; - const int blocksPerGrid = (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - const int sharedMem = threadsPerBlock * sizeof(int) * output.rankOf() + 128; + const int threadsPerBlock = MAX_NUM_THREADS / 4; + const int blocksPerGrid = + (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + const int sharedMem = threadsPerBlock * sizeof(int) * output.rankOf() + 128; - PointersManager manager(context, "batchToSpaceND"); + PointersManager manager(context, "batchToSpaceND"); - NDArray::prepareSpecialUse({&output}, {&inputRearranged1, &crop}); - BUILD_DOUBLE_SELECTOR(input.dataType(), crop.dataType(), batchToSpaceNDCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), inputRearranged1.specialBuffer(), inputRearranged1.specialShapeInfo(), crop.specialBuffer(), crop.specialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), numOfSpatialDims), LIBND4J_TYPES, INTEGER_TYPES); - NDArray::registerSpecialUse({&output}, {&inputRearranged1, &crop}); + NDArray::prepareSpecialUse({&output}, {&inputRearranged1, &crop}); + BUILD_DOUBLE_SELECTOR( + input.dataType(), crop.dataType(), batchToSpaceNDCudaLauncher, + (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), + inputRearranged1.specialBuffer(), inputRearranged1.specialShapeInfo(), + crop.specialBuffer(), crop.specialShapeInfo(), output.specialBuffer(), + output.specialShapeInfo(), numOfSpatialDims), + LIBND4J_TYPES, INTEGER_TYPES); + NDArray::registerSpecialUse({&output}, {&inputRearranged1, &crop}); - manager.synchronize(); - } + manager.synchronize(); + } } - - /////////////////////////////////////////////////////////////////// -template -__global__ static void spaceToBatchCuda(const void* vx, const Nd4jLong* xShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const uint padBottom, const uint padTop, const uint padLeft, const uint padRight) { - - // input [bS, H * blockSize - padBottom - padTop, W * blockSize - padLeft - padRight, iC] - // output [bs, H * blockSize, W * blockSize, iC] - - // if (padTop = padBottom = padRight = padLeft = 0) shapes are the same - // else: - // iH -> [padBottom, oH - padTop] - // iW -> [padLeft, oW - padRight] - // zLen > xLen +template +__global__ static void spaceToBatchCuda(const void* vx, + const Nd4jLong* xShapeInfo, void* vz, + const Nd4jLong* zShapeInfo, + const uint padBottom, const uint padTop, + const uint padLeft, + const uint padRight) { + // input [bS, H * blockSize - padBottom - padTop, W * blockSize - padLeft - + // padRight, iC] output [bs, H * blockSize, W * blockSize, iC] - const auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); + // if (padTop = padBottom = padRight = padLeft = 0) shapes are the same + // else: + // iH -> [padBottom, oH - padTop] + // iW -> [padLeft, oW - padRight] + // zLen > xLen - __shared__ int rank, *sharedMem; - __shared__ Nd4jLong zLen; + const auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - sharedMem = reinterpret_cast(shmem); + __shared__ int rank, *sharedMem; + __shared__ Nd4jLong zLen; - rank = shape::rank(zShapeInfo); - zLen = shape::length(zShapeInfo); - } - __syncthreads(); + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sharedMem = reinterpret_cast(shmem); - auto coords = sharedMem + threadIdx.x * rank; + rank = shape::rank(zShapeInfo); + zLen = shape::length(zShapeInfo); + } + __syncthreads(); - const auto i = blockIdx.x * blockDim.x + threadIdx.x; + auto coords = sharedMem + threadIdx.x * rank; - if(i >= zLen) - return; + const auto i = blockIdx.x * blockDim.x + threadIdx.x; - shape::index2coords(i, zShapeInfo, coords); + if (i >= zLen) return; - const auto zOffset = shape::getOffset(zShapeInfo, coords); + shape::index2coords(i, zShapeInfo, coords); - if(coords[1] >= padBottom && coords[1] < zShapeInfo[2] - padTop && coords[2] >= padLeft && coords[2] < zShapeInfo[3] - padRight) { + const auto zOffset = shape::getOffset(zShapeInfo, coords); - coords[1] -= padBottom; - coords[2] -= padLeft; + if (coords[1] >= padBottom && coords[1] < zShapeInfo[2] - padTop && + coords[2] >= padLeft && coords[2] < zShapeInfo[3] - padRight) { + coords[1] -= padBottom; + coords[2] -= padLeft; - const auto xOffset = shape::getOffset(xShapeInfo, coords); + const auto xOffset = shape::getOffset(xShapeInfo, coords); - z[zOffset] = x[xOffset]; - } - else - z[zOffset] = 0.f; + z[zOffset] = x[xOffset]; + } else + z[zOffset] = 0.f; } /////////////////////////////////////////////////////////////////// -template -static void spaceToBatchCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void* vx, const Nd4jLong* xShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const uint padBottom, const uint padTop, const uint padLeft, const uint padRight) { - - spaceToBatchCuda<<>>(vx, xShapeInfo, vz, zShapeInfo, padBottom, padTop, padLeft, padRight); +template +static void spaceToBatchCudaLauncher( + const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, + const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo, + void* vz, const Nd4jLong* zShapeInfo, const uint padBottom, + const uint padTop, const uint padLeft, const uint padRight) { + spaceToBatchCuda<<>>( + vx, xShapeInfo, vz, zShapeInfo, padBottom, padTop, padLeft, padRight); } -BUILD_SINGLE_TEMPLATE(template void spaceToBatchCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void* vx, const Nd4jLong* xShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const uint padBottom, const uint padTop, const uint padLeft, const uint padRight), LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template void spaceToBatchCudaLauncher, + (const int blocksPerGrid, const int threadsPerBlock, + const int sharedMem, const cudaStream_t* stream, + const void* vx, const Nd4jLong* xShapeInfo, void* vz, + const Nd4jLong* zShapeInfo, const uint padBottom, + const uint padTop, const uint padLeft, + const uint padRight), + LIBND4J_TYPES); /////////////////////////////////////////////////////////////////// -void spaceToBatch(sd::LaunchContext* context, const NDArray& input, NDArray& output, const uint padBottom, const uint padTop, const uint padLeft, const uint padRight, const uint blockSize) { - - // [bS, iH, iW, iC] is rearranged/permuted to [bS*blockSize*blockSize, (iH + padBottom + padTop)/blockSize, (iW + padLeft + padRight)/blockSize, iC] - - NDArray outputRearranged0 = output.reshape(output.ordering(), {blockSize, blockSize, input.sizeAt(0), output.sizeAt(1), output.sizeAt(2), input.sizeAt(3)}, false); - outputRearranged0.permutei({2, 3,0, 4,1, 5}); - - if(input.lengthOf() == output.lengthOf()) { - - outputRearranged0.assign(input); - } - else { - - NDArray outputRearranged1 = outputRearranged0.reshape(output.ordering(), {input.sizeAt(0), output.sizeAt(1) * blockSize, output.sizeAt(2) * blockSize, input.sizeAt(3)}, false); - - const int threadsPerBlock = MAX_NUM_THREADS / 2; - const int blocksPerGrid = (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - const int sharedMem = threadsPerBlock * sizeof(int) * output.rankOf() + 128; - - PointersManager manager(context, "spaceToBatch"); - - NDArray::prepareSpecialUse({&outputRearranged1}, {&input}); - BUILD_SINGLE_SELECTOR(input.dataType(), spaceToBatchCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), input.specialBuffer(), input.specialShapeInfo(), outputRearranged1.specialBuffer(), outputRearranged1.specialShapeInfo(), padBottom, padTop, padLeft, padRight), LIBND4J_TYPES); - NDArray::registerSpecialUse({&outputRearranged1}, {&input}); - - manager.synchronize(); - - if(output.specialBuffer() != outputRearranged1.specialBuffer()) - outputRearranged0.assign(outputRearranged1); - } +void spaceToBatch(sd::LaunchContext* context, const NDArray& input, + NDArray& output, const uint padBottom, const uint padTop, + const uint padLeft, const uint padRight, + const uint blockSize) { + // [bS, iH, iW, iC] is rearranged/permuted to [bS*blockSize*blockSize, (iH + + // padBottom + padTop)/blockSize, (iW + padLeft + padRight)/blockSize, iC] + + NDArray outputRearranged0 = + output.reshape(output.ordering(), + {blockSize, blockSize, input.sizeAt(0), output.sizeAt(1), + output.sizeAt(2), input.sizeAt(3)}, + false); + outputRearranged0.permutei({2, 3, 0, 4, 1, 5}); + + if (input.lengthOf() == output.lengthOf()) { + outputRearranged0.assign(input); + } else { + NDArray outputRearranged1 = outputRearranged0.reshape( + output.ordering(), + {input.sizeAt(0), output.sizeAt(1) * blockSize, + output.sizeAt(2) * blockSize, input.sizeAt(3)}, + false); + + const int threadsPerBlock = MAX_NUM_THREADS / 2; + const int blocksPerGrid = + (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + const int sharedMem = threadsPerBlock * sizeof(int) * output.rankOf() + 128; + + PointersManager manager(context, "spaceToBatch"); + + NDArray::prepareSpecialUse({&outputRearranged1}, {&input}); + BUILD_SINGLE_SELECTOR( + input.dataType(), spaceToBatchCudaLauncher, + (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), + input.specialBuffer(), input.specialShapeInfo(), + outputRearranged1.specialBuffer(), + outputRearranged1.specialShapeInfo(), padBottom, padTop, padLeft, + padRight), + LIBND4J_TYPES); + NDArray::registerSpecialUse({&outputRearranged1}, {&input}); + + manager.synchronize(); + + if (output.specialBuffer() != outputRearranged1.specialBuffer()) + outputRearranged0.assign(outputRearranged1); + } } /////////////////////////////////////////////////////////////////// -template -__global__ static void spaceToBatchNDCuda(const void* vx, const Nd4jLong* xShapeInfo, - const void* vy, const Nd4jLong* yShapeInfo, - void* vz, const Nd4jLong* zShapeInfo, +template +__global__ static void spaceToBatchNDCuda(const void* vx, + const Nd4jLong* xShapeInfo, + const void* vy, + const Nd4jLong* yShapeInfo, void* vz, + const Nd4jLong* zShapeInfo, const uint numOfSpatialDims) { + // x - input, y - padding, z - output - // x - input, y - padding, z - output - - // 4D example - // input [bS, H * blockShape[0] - padBottom - padTop, W * blockShape[1] - padLeft - padRight, iC] - // output [bS, H * blockShape[0], W * blockShape[1], iC] - - // if (padTop = padBottom = padRight = padLeft = 0) shapes are the same - // else: - // iH -> [padBottom, oH - padTop] - // iW -> [padLeft, oW - padRight] - // zLen > xLen + // 4D example + // input [bS, H * blockShape[0] - padBottom - padTop, W * blockShape[1] - + // padLeft - padRight, iC] output [bS, H * blockShape[0], W * blockShape[1], + // iC] - const auto x = reinterpret_cast(vx); - const auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); + // if (padTop = padBottom = padRight = padLeft = 0) shapes are the same + // else: + // iH -> [padBottom, oH - padTop] + // iW -> [padLeft, oW - padRight] + // zLen > xLen - __shared__ int rank, *sharedMem; // xRank = zRank, yRank = 2; - __shared__ Nd4jLong zLen, totalThreads; + const auto x = reinterpret_cast(vx); + const auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); - if (threadIdx.x == 0) { - - extern __shared__ unsigned char shmem[]; - sharedMem = reinterpret_cast(shmem); - - rank = shape::rank(zShapeInfo); - zLen = shape::length(zShapeInfo); - totalThreads = gridDim.x * blockDim.x; - } + __shared__ int rank, *sharedMem; // xRank = zRank, yRank = 2; + __shared__ Nd4jLong zLen, totalThreads; - __syncthreads(); + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sharedMem = reinterpret_cast(shmem); - auto coords = sharedMem + threadIdx.x * rank; + rank = shape::rank(zShapeInfo); + zLen = shape::length(zShapeInfo); + totalThreads = gridDim.x * blockDim.x; + } - for (int i = blockDim.x * blockIdx.x + threadIdx.x; i < zLen; i += totalThreads) { + __syncthreads(); - shape::index2coords(i, zShapeInfo, coords); + auto coords = sharedMem + threadIdx.x * rank; - const auto zOffset = shape::getOffset(zShapeInfo, coords); + for (int i = blockDim.x * blockIdx.x + threadIdx.x; i < zLen; + i += totalThreads) { + shape::index2coords(i, zShapeInfo, coords); - bool within = true; + const auto zOffset = shape::getOffset(zShapeInfo, coords); - for(uint j = 1; j <= numOfSpatialDims; ++j) { + bool within = true; - // yRank = 2, calculate offset manually - const auto yOffset = (j - 1) * yShapeInfo[3]; - const auto padLeft = y[yOffset]; - const auto padRight = y[yOffset + yShapeInfo[4]]; + for (uint j = 1; j <= numOfSpatialDims; ++j) { + // yRank = 2, calculate offset manually + const auto yOffset = (j - 1) * yShapeInfo[3]; + const auto padLeft = y[yOffset]; + const auto padRight = y[yOffset + yShapeInfo[4]]; - within &= (coords[j] >= padLeft && coords[j] < shape::shapeOf(const_cast(zShapeInfo))[j] - padRight); + within &= + (coords[j] >= padLeft && + coords[j] < + shape::shapeOf(const_cast(zShapeInfo))[j] - padRight); - if(!within) - break; + if (!within) break; - coords[j] -= padLeft; // get coordinates for x - } - - if(within) - z[zOffset] = x[shape::getOffset(xShapeInfo, coords)]; - else - z[zOffset] = 0.f; + coords[j] -= padLeft; // get coordinates for x } + + if (within) + z[zOffset] = x[shape::getOffset(xShapeInfo, coords)]; + else + z[zOffset] = 0.f; + } } /////////////////////////////////////////////////////////////////// -template -static void spaceToBatchNDCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void* vx, const Nd4jLong* xShapeInfo, const void* vy, const Nd4jLong* yShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const uint numOfSpatialDims) { - - spaceToBatchNDCuda<<>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, numOfSpatialDims); +template +static void spaceToBatchNDCudaLauncher( + const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, + const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo, + const void* vy, const Nd4jLong* yShapeInfo, void* vz, + const Nd4jLong* zShapeInfo, const uint numOfSpatialDims) { + spaceToBatchNDCuda + <<>>( + vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, numOfSpatialDims); } -BUILD_DOUBLE_TEMPLATE(template void spaceToBatchNDCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void* vx, const Nd4jLong* xShapeInfo, const void* vy, const Nd4jLong* yShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const uint numOfSpatialDims), LIBND4J_TYPES, INTEGER_TYPES); +BUILD_DOUBLE_TEMPLATE(template void spaceToBatchNDCudaLauncher, + (const int blocksPerGrid, const int threadsPerBlock, + const int sharedMem, const cudaStream_t* stream, + const void* vx, const Nd4jLong* xShapeInfo, + const void* vy, const Nd4jLong* yShapeInfo, void* vz, + const Nd4jLong* zShapeInfo, const uint numOfSpatialDims), + LIBND4J_TYPES, INTEGER_TYPES); ////////////////////////////////////////////////////////////////////////// -void spaceToBatchND(sd::LaunchContext* context, const NDArray& input, const NDArray& blockShape, const NDArray& padding, NDArray& output ) { - - // 4D example with two spatial dimensions - // [bS, iH, iW, iC] is rearranged/permuted to [bS*blockShape[0]*blockShape[1], (iH + padBottom + padTop)/blockShape[0], (iW + padLeft + padRight)/blockShape[1], iC] +void spaceToBatchND(sd::LaunchContext* context, const NDArray& input, + const NDArray& blockShape, const NDArray& padding, + NDArray& output) { + // 4D example with two spatial dimensions + // [bS, iH, iW, iC] is rearranged/permuted to [bS*blockShape[0]*blockShape[1], + // (iH + padBottom + padTop)/blockShape[0], (iW + padLeft + + // padRight)/blockShape[1], iC] - const uint rank = input.rankOf(); + const uint rank = input.rankOf(); - const uint numOfSpatialDims = blockShape.sizeAt(0); + const uint numOfSpatialDims = blockShape.sizeAt(0); - //*** construct reshaping std::vector for first reshape of output array ***// - std::vector temp(numOfSpatialDims + rank); + //*** construct reshaping std::vector for first reshape of output array ***// + std::vector temp(numOfSpatialDims + rank); - int i; - for(i = 0; i < numOfSpatialDims; ++i) - temp[i] = blockShape.e(i); - temp[i++] = input.sizeAt(0); - for(int j = 1; j < rank; ++i, ++j) - temp[i] = output.sizeAt(j); + int i; + for (i = 0; i < numOfSpatialDims; ++i) temp[i] = blockShape.e(i); + temp[i++] = input.sizeAt(0); + for (int j = 1; j < rank; ++i, ++j) temp[i] = output.sizeAt(j); - NDArray outputRearranged0 = output.reshape(output.ordering(), temp, false); + NDArray outputRearranged0 = output.reshape(output.ordering(), temp, false); - //*** construct permuting std::vector for permutation of output array ***// + //*** construct permuting std::vector for permutation of output array ***// - temp[0] = numOfSpatialDims; - - for(i = 1; i <= numOfSpatialDims; ++i) { - temp[2*i - 1] = numOfSpatialDims + i; - temp[2*i] = i - 1; - } - for(i = 2 * numOfSpatialDims + 1; i < temp.size(); ++i) - temp[i] = i; + temp[0] = numOfSpatialDims; - outputRearranged0.permutei(temp); + for (i = 1; i <= numOfSpatialDims; ++i) { + temp[2 * i - 1] = numOfSpatialDims + i; + temp[2 * i] = i - 1; + } + for (i = 2 * numOfSpatialDims + 1; i < temp.size(); ++i) temp[i] = i; - // ****** // + outputRearranged0.permutei(temp); - if(input.lengthOf() == output.lengthOf()) { - outputRearranged0.assign(input); - } - else { + // ****** // - //*** construct reshaping std::vector for second reshape of output array ***// - temp.resize(rank); + if (input.lengthOf() == output.lengthOf()) { + outputRearranged0.assign(input); + } else { + //*** construct reshaping std::vector for second reshape of output array + //***// + temp.resize(rank); - temp[0] = input.sizeAt(0); + temp[0] = input.sizeAt(0); - for(i = 1; i < rank; ++i) - temp[i] = (i <= numOfSpatialDims) ? output.sizeAt(i) * blockShape.e(i - 1) : output.sizeAt(i); + for (i = 1; i < rank; ++i) + temp[i] = (i <= numOfSpatialDims) + ? output.sizeAt(i) * blockShape.e(i - 1) + : output.sizeAt(i); - NDArray outputRearranged1 = outputRearranged0.reshape(output.ordering(), temp, false); + NDArray outputRearranged1 = + outputRearranged0.reshape(output.ordering(), temp, false); - const int threadsPerBlock = MAX_NUM_THREADS / 4; - const int blocksPerGrid = (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - const int sharedMem = threadsPerBlock * sizeof(int) * output.rankOf() + 128; + const int threadsPerBlock = MAX_NUM_THREADS / 4; + const int blocksPerGrid = + (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + const int sharedMem = threadsPerBlock * sizeof(int) * output.rankOf() + 128; - PointersManager manager(context, "spaceToBatchND"); + PointersManager manager(context, "spaceToBatchND"); - NDArray::prepareSpecialUse({&outputRearranged1}, {&input, &padding}); - BUILD_DOUBLE_SELECTOR(input.dataType(), padding.dataType(), spaceToBatchNDCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), input.specialBuffer(), input.specialShapeInfo(), padding.specialBuffer(), padding.specialShapeInfo(), outputRearranged1.specialBuffer(), outputRearranged1.specialShapeInfo(), numOfSpatialDims), LIBND4J_TYPES, INTEGER_TYPES); - NDArray::registerSpecialUse({&outputRearranged1}, {&input, &padding}); + NDArray::prepareSpecialUse({&outputRearranged1}, {&input, &padding}); + BUILD_DOUBLE_SELECTOR( + input.dataType(), padding.dataType(), spaceToBatchNDCudaLauncher, + (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), + input.specialBuffer(), input.specialShapeInfo(), + padding.specialBuffer(), padding.specialShapeInfo(), + outputRearranged1.specialBuffer(), + outputRearranged1.specialShapeInfo(), numOfSpatialDims), + LIBND4J_TYPES, INTEGER_TYPES); + NDArray::registerSpecialUse({&outputRearranged1}, {&input, &padding}); - manager.synchronize(); + manager.synchronize(); - if(output.specialBuffer() != outputRearranged1.specialBuffer()) - outputRearranged0.assign(outputRearranged1); - } + if (output.specialBuffer() != outputRearranged1.specialBuffer()) + outputRearranged0.assign(outputRearranged1); + } } - /* template struct SpaceToBatchHelper { template - static void run(T *ptrSpace, const Nd4jLong *space_shape, const Nd4jLong *space_strides, const Nd4jLong *block_shape, const Nd4jLong *pad_start, const Nd4jLong *block_offsets, T *ptrBatch, const Nd4jLong *batch_shape, const Nd4jLong *batch_strides) { - for (int batch_pos = 0; batch_pos < batch_shape[0]; ++batch_pos) { - const int space_pos = batch_pos * block_shape[0] + block_offsets[0] - pad_start[0]; - if (space_pos >= 0 && space_pos < space_shape[0]) { - SpaceToBatchHelper::run(ptrSpace + space_pos * space_strides[0], space_shape + 1, space_strides + 1, block_shape + 1, pad_start + 1, block_offsets + 1, ptrBatch, batch_shape + 1, batch_strides + 1); - } else { + static void run(T *ptrSpace, const Nd4jLong *space_shape, const Nd4jLong +*space_strides, const Nd4jLong *block_shape, const Nd4jLong *pad_start, const +Nd4jLong *block_offsets, T *ptrBatch, const Nd4jLong *batch_shape, const +Nd4jLong *batch_strides) { for (int batch_pos = 0; batch_pos < batch_shape[0]; +++batch_pos) { const int space_pos = batch_pos * block_shape[0] + +block_offsets[0] - pad_start[0]; if (space_pos >= 0 && space_pos < +space_shape[0]) { SpaceToBatchHelper::run(ptrSpace + space_pos * +space_strides[0], space_shape + 1, space_strides + 1, block_shape + 1, pad_start ++ 1, block_offsets + 1, ptrBatch, batch_shape + 1, batch_strides + 1); } else { if (!B2S) for (int i = 0; i < batch_strides[0]; i++) ptrBatch[i] = (T) 0.f; @@ -512,24 +603,29 @@ void spaceToBatchND(sd::LaunchContext* context, const NDArray& input, const NDAr template struct SpaceToBatchHelper<0, B2S> { template - static void run(T *ptrSpace, const Nd4jLong *space_shape, const Nd4jLong *space_strides, const Nd4jLong *block_shape, const Nd4jLong *pad_start, const Nd4jLong *block_offsets, T *ptrBatch, const Nd4jLong *batch_shape, const Nd4jLong *batch_strides) { - int str = batch_strides[-1]; - for (int i = 0; i < str; i++) - if (B2S) - ptrSpace[i] = ptrBatch[i]; - else - ptrBatch[i] = ptrSpace[i]; + static void run(T *ptrSpace, const Nd4jLong *space_shape, const Nd4jLong +*space_strides, const Nd4jLong *block_shape, const Nd4jLong *pad_start, const +Nd4jLong *block_offsets, T *ptrBatch, const Nd4jLong *batch_shape, const +Nd4jLong *batch_strides) { int str = batch_strides[-1]; for (int i = 0; i < str; +i++) if (B2S) ptrSpace[i] = ptrBatch[i]; else ptrBatch[i] = ptrSpace[i]; } }; template - void _execute(sd::LaunchContext * context, void *vptrSpace, const Nd4jLong *space_shape, const Nd4jLong *space_strides, const Nd4jLong *block_shape, const Nd4jLong *pad_start, const Nd4jLong *block_offsets, void *vptrBatch, const Nd4jLong *batch_shape, const Nd4jLong *batch_strides) { - auto ptrSpace = reinterpret_cast(vptrSpace); - auto ptrBatch = reinterpret_cast(vptrBatch); - SpaceToBatchHelper::run(ptrSpace, space_shape, space_strides, block_shape, pad_start, block_offsets, ptrBatch, batch_shape, batch_strides); + void _execute(sd::LaunchContext * context, void *vptrSpace, const Nd4jLong +*space_shape, const Nd4jLong *space_strides, const Nd4jLong *block_shape, const +Nd4jLong *pad_start, const Nd4jLong *block_offsets, void *vptrBatch, const +Nd4jLong *batch_shape, const Nd4jLong *batch_strides) { auto ptrSpace = +reinterpret_cast(vptrSpace); auto ptrBatch = reinterpret_cast(vptrBatch); SpaceToBatchHelper::run(ptrSpace, +space_shape, space_strides, block_shape, pad_start, block_offsets, ptrBatch, +batch_shape, batch_strides); }; - Nd4jStatus _batchToSpace(sd::LaunchContext * context, int internal_block_dims, NDArray *input, NDArray *output, std::vector &internal_input_shape, std::vector &internal_output_shape, Nd4jLong *block_shape, Nd4jLong *crops) { + Nd4jStatus _batchToSpace(sd::LaunchContext * context, int +internal_block_dims, NDArray *input, NDArray *output, std::vector +&internal_input_shape, std::vector &internal_output_shape, Nd4jLong +*block_shape, Nd4jLong *crops) { return Status::OK(); } @@ -542,12 +638,16 @@ void spaceToBatchND(sd::LaunchContext* context, const NDArray& input, const NDAr #define STB_BOOL (0, false),\ (1, true) - BUILD_TRIPLE_TEMPLATE(template void _execute, (sd::LaunchContext * context, void *ptrSpace, const Nd4jLong *space_shape, const Nd4jLong *space_strides, const Nd4jLong *block_shape, const Nd4jLong *pad_start, const Nd4jLong *block_offsets, void *ptrBatch, const Nd4jLong *batch_shape, const Nd4jLong *batch_strides), LIBND4J_TYPES, STB_DIM, STB_BOOL); + BUILD_TRIPLE_TEMPLATE(template void _execute, (sd::LaunchContext * context, +void *ptrSpace, const Nd4jLong *space_shape, const Nd4jLong *space_strides, +const Nd4jLong *block_shape, const Nd4jLong *pad_start, const Nd4jLong +*block_offsets, void *ptrBatch, const Nd4jLong *batch_shape, const Nd4jLong +*batch_strides), LIBND4J_TYPES, STB_DIM, STB_BOOL); #undef STB_BOOL #undef STB_DIM */ -} -} -} \ No newline at end of file +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/s_t_d.cu b/libnd4j/include/ops/declarable/helpers/cuda/s_t_d.cu index 19a1937ddbe9..1203ce7e5ec6 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/s_t_d.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/s_t_d.cu @@ -23,86 +23,106 @@ namespace sd { namespace ops { namespace helpers { - template - static _CUDA_G void spaceToDepthKernel( - const void *vx, const Nd4jLong *xShapeInfo, - void *vz, const Nd4jLong *zShapeInfo, - const int block_size, - const bool isNHWC) { - auto input_ptr = reinterpret_cast(vx); - auto output_ptr = reinterpret_cast(vz); - - const int batch_size = shape::sizeAt(xShapeInfo, 0); - const int input_depth = isNHWC ? shape::sizeAt(xShapeInfo, 3) : shape::sizeAt(xShapeInfo, 1); - const int input_height = isNHWC ? shape::sizeAt(xShapeInfo, 1) : shape::sizeAt(xShapeInfo, 2); - const int input_width = isNHWC ? shape::sizeAt(xShapeInfo, 2) : shape::sizeAt(xShapeInfo, 3); - - const int output_depth = isNHWC ? shape::sizeAt(zShapeInfo, 3) : shape::sizeAt(zShapeInfo, 1); - const int output_height = isNHWC ? shape::sizeAt(zShapeInfo, 1) : shape::sizeAt(zShapeInfo, 2); - const int output_width = isNHWC ? shape::sizeAt(zShapeInfo, 2) : shape::sizeAt(zShapeInfo, 3); - - const int input_depth_by_output_height = input_depth * output_height; - - const int output_area = output_width * output_height; - const int output_depth_by_output_area = output_depth * output_area; - - auto tid = threadIdx.x + blockIdx.x * blockDim.x; - - if (isNHWC) { - const int total_count = batch_size * input_height * input_width * input_depth; - - for (int inp_idx = tid; inp_idx < total_count; inp_idx += blockDim.x * gridDim.x){ - // inp_idx = d + input_depth * (w + input_width * (h + input_height * b)) - const int d = inp_idx % input_depth; - const int inp_idx2 = inp_idx / input_depth; - const int w = inp_idx2 % input_width; - const int inp_idx3 = inp_idx2 / input_width; - const int h = inp_idx3 % input_height; - const int b = inp_idx3 / input_height; - - const int out_h = h / block_size; - const int offset_h = h % block_size; - const int out_w = w / block_size; - const int offset_w = w % block_size; - const int offset_d = (offset_h * block_size + offset_w) * input_depth; - const int out_d = d + offset_d; - - const int out_idx = out_d + output_depth * (out_w + output_width * (out_h + output_height * b)); - *(output_ptr + out_idx) = *(input_ptr + inp_idx); - } - } else { - const int total_count = batch_size * output_depth_by_output_area; - - for (int inp_idx = tid; inp_idx < total_count; inp_idx += blockDim.x * gridDim.x) { - const int n_iC_oY_bY_oX = inp_idx / block_size; - const int bX = inp_idx - n_iC_oY_bY_oX * block_size; - - const int n_iC_oY_bY = n_iC_oY_bY_oX / output_width; - const int oX = n_iC_oY_bY_oX - n_iC_oY_bY * output_width; - - const int n_iC_oY = n_iC_oY_bY / block_size; - const int bY = n_iC_oY_bY - n_iC_oY * block_size; - - const int n = n_iC_oY / input_depth_by_output_height; - const int iC_oY = n_iC_oY - n * input_depth_by_output_height; - - const int output_idx = oX + (((n * block_size + bY) * block_size + bX) * input_depth_by_output_height + iC_oY) * output_width; - - *(output_ptr + output_idx) = *(input_ptr + inp_idx); - } - } +template +static _CUDA_G void spaceToDepthKernel(const void *vx, + const Nd4jLong *xShapeInfo, void *vz, + const Nd4jLong *zShapeInfo, + const int block_size, + const bool isNHWC) { + auto input_ptr = reinterpret_cast(vx); + auto output_ptr = reinterpret_cast(vz); + + const int batch_size = shape::sizeAt(xShapeInfo, 0); + const int input_depth = + isNHWC ? shape::sizeAt(xShapeInfo, 3) : shape::sizeAt(xShapeInfo, 1); + const int input_height = + isNHWC ? shape::sizeAt(xShapeInfo, 1) : shape::sizeAt(xShapeInfo, 2); + const int input_width = + isNHWC ? shape::sizeAt(xShapeInfo, 2) : shape::sizeAt(xShapeInfo, 3); + + const int output_depth = + isNHWC ? shape::sizeAt(zShapeInfo, 3) : shape::sizeAt(zShapeInfo, 1); + const int output_height = + isNHWC ? shape::sizeAt(zShapeInfo, 1) : shape::sizeAt(zShapeInfo, 2); + const int output_width = + isNHWC ? shape::sizeAt(zShapeInfo, 2) : shape::sizeAt(zShapeInfo, 3); + + const int input_depth_by_output_height = input_depth * output_height; + + const int output_area = output_width * output_height; + const int output_depth_by_output_area = output_depth * output_area; + + auto tid = threadIdx.x + blockIdx.x * blockDim.x; + + if (isNHWC) { + const int total_count = + batch_size * input_height * input_width * input_depth; + + for (int inp_idx = tid; inp_idx < total_count; + inp_idx += blockDim.x * gridDim.x) { + // inp_idx = d + input_depth * (w + input_width * (h + input_height * b)) + const int d = inp_idx % input_depth; + const int inp_idx2 = inp_idx / input_depth; + const int w = inp_idx2 % input_width; + const int inp_idx3 = inp_idx2 / input_width; + const int h = inp_idx3 % input_height; + const int b = inp_idx3 / input_height; + + const int out_h = h / block_size; + const int offset_h = h % block_size; + const int out_w = w / block_size; + const int offset_w = w % block_size; + const int offset_d = (offset_h * block_size + offset_w) * input_depth; + const int out_d = d + offset_d; + + const int out_idx = + out_d + + output_depth * (out_w + output_width * (out_h + output_height * b)); + *(output_ptr + out_idx) = *(input_ptr + inp_idx); } + } else { + const int total_count = batch_size * output_depth_by_output_area; - template - static void _spaceTodepth_(sd::LaunchContext * context, const NDArray &input, NDArray *output, int block_size, bool isNHWC) { - spaceToDepthKernel<<<512, 512, 1024, *context->getCudaStream()>>>(input.specialBuffer(), input.specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), block_size, isNHWC); - } + for (int inp_idx = tid; inp_idx < total_count; + inp_idx += blockDim.x * gridDim.x) { + const int n_iC_oY_bY_oX = inp_idx / block_size; + const int bX = inp_idx - n_iC_oY_bY_oX * block_size; + + const int n_iC_oY_bY = n_iC_oY_bY_oX / output_width; + const int oX = n_iC_oY_bY_oX - n_iC_oY_bY * output_width; + + const int n_iC_oY = n_iC_oY_bY / block_size; + const int bY = n_iC_oY_bY - n_iC_oY * block_size; + + const int n = n_iC_oY / input_depth_by_output_height; + const int iC_oY = n_iC_oY - n * input_depth_by_output_height; - void _spaceTodepth(sd::LaunchContext * context, const NDArray &input, NDArray *output, int block_size, bool isNHWC) { - NDArray::prepareSpecialUse({output}, {&input}); - BUILD_SINGLE_SELECTOR(input.dataType(), _spaceTodepth_, (context, input, output, block_size, isNHWC), LIBND4J_TYPES); - NDArray::registerSpecialUse({output}, {&input}); + const int output_idx = oX + (((n * block_size + bY) * block_size + bX) * + input_depth_by_output_height + + iC_oY) * + output_width; + + *(output_ptr + output_idx) = *(input_ptr + inp_idx); } + } +} + +template +static void _spaceTodepth_(sd::LaunchContext *context, const NDArray &input, + NDArray *output, int block_size, bool isNHWC) { + spaceToDepthKernel<<<512, 512, 1024, *context->getCudaStream()>>>( + input.specialBuffer(), input.specialShapeInfo(), output->specialBuffer(), + output->specialShapeInfo(), block_size, isNHWC); } + +void _spaceTodepth(sd::LaunchContext *context, const NDArray &input, + NDArray *output, int block_size, bool isNHWC) { + NDArray::prepareSpecialUse({output}, {&input}); + BUILD_SINGLE_SELECTOR(input.dataType(), _spaceTodepth_, + (context, input, output, block_size, isNHWC), + LIBND4J_TYPES); + NDArray::registerSpecialUse({output}, {&input}); } -} \ No newline at end of file +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/scatter.cu b/libnd4j/include/ops/declarable/helpers/cuda/scatter.cu index cbe8895b2e13..3904c82509fb 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/scatter.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/scatter.cu @@ -19,698 +19,773 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include -#include -#include -#include #include #include #include +#include +#include +#include -namespace sd { -namespace ops { -namespace helpers { +#include +namespace sd { +namespace ops { +namespace helpers { /////////////////////////////////////////////////////////////////// // x - indices, y - contains number of bad indices, z - input/output -template -__global__ static void checkIndicesCuda(const void *vx, const Nd4jLong *xShapeInfo, Nd4jLong* y, const Nd4jLong *zShapeInfo, const int axis) { - - const auto x = reinterpret_cast(vx); - - __shared__ int xRank, *coords, xLastDim; - __shared__ Nd4jLong xLen, numOfBadIndxPerBlock; +template +__global__ static void checkIndicesCuda(const void *vx, + const Nd4jLong *xShapeInfo, Nd4jLong *y, + const Nd4jLong *zShapeInfo, + const int axis) { + const auto x = reinterpret_cast(vx); - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - coords = reinterpret_cast(shmem); + __shared__ int xRank, *coords, xLastDim; + __shared__ Nd4jLong xLen, numOfBadIndxPerBlock; - xRank = shape::rank(xShapeInfo); - xLen = shape::length(xShapeInfo); + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + coords = reinterpret_cast(shmem); - numOfBadIndxPerBlock = 0; - } - __syncthreads(); + xRank = shape::rank(xShapeInfo); + xLen = shape::length(xShapeInfo); - auto xCoords = coords + threadIdx.x * xRank; + numOfBadIndxPerBlock = 0; + } + __syncthreads(); - for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < xLen; i += gridDim.x * blockDim.x) { + auto xCoords = coords + threadIdx.x * xRank; - shape::index2coords(i, xShapeInfo, xCoords); + for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < xLen; + i += gridDim.x * blockDim.x) { + shape::index2coords(i, xShapeInfo, xCoords); - const Nd4jLong currentInd = x[shape::getOffset(xShapeInfo, xCoords)]; + const Nd4jLong currentInd = x[shape::getOffset(xShapeInfo, xCoords)]; - if(currentInd >= shape::sizeAt(zShapeInfo, axis == -1 ? xCoords[xRank-1] : axis)) { - printf("checkIndices cuda: out of range element %lld at index %lld \n", currentInd, i); - sd::math::atomics::nd4j_atomicAdd(&numOfBadIndxPerBlock, 1); - } + if (currentInd >= + shape::sizeAt(zShapeInfo, axis == -1 ? xCoords[xRank - 1] : axis)) { + printf("checkIndices cuda: out of range element %lld at index %lld \n", + currentInd, i); + sd::math::atomics::nd4j_atomicAdd(&numOfBadIndxPerBlock, 1); } - __syncthreads(); + } + __syncthreads(); - if (threadIdx.x == 0 && numOfBadIndxPerBlock != 0) - sd::math::atomics::nd4j_atomicAdd(y, numOfBadIndxPerBlock); + if (threadIdx.x == 0 && numOfBadIndxPerBlock != 0) + sd::math::atomics::nd4j_atomicAdd(y, numOfBadIndxPerBlock); } /////////////////////////////////////////////////////////////////// -template -static void checkIndicesCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, - const void *vx, const Nd4jLong *xShapeInfo, Nd4jLong* y, const Nd4jLong *zShapeInfo, const int axis) { - - checkIndicesCuda<<>>(vx, xShapeInfo, y, zShapeInfo, axis); +template +static void checkIndicesCudaLauncher( + const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, + const cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, + Nd4jLong *y, const Nd4jLong *zShapeInfo, const int axis) { + checkIndicesCuda<<>>( + vx, xShapeInfo, y, zShapeInfo, axis); } - /////////////////////////////////////////////////////////////////// -Nd4jLong checkIndices(sd::LaunchContext *context, const NDArray& indices, const NDArray& output, const int axis) { - - const int threadsPerBlock = MAX_NUM_THREADS / 2; - const int blocksPerGrid = (indices.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - const int sharedMem = threadsPerBlock * sizeof(int) * indices.rankOf() + 256; +Nd4jLong checkIndices(sd::LaunchContext *context, const NDArray &indices, + const NDArray &output, const int axis) { + const int threadsPerBlock = MAX_NUM_THREADS / 2; + const int blocksPerGrid = + (indices.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + const int sharedMem = threadsPerBlock * sizeof(int) * indices.rankOf() + 256; - const auto xType = indices.dataType(); + const auto xType = indices.dataType(); - PointersManager manager(context, "scatterNDcheckIndices"); + PointersManager manager(context, "scatterNDcheckIndices"); - // scalar, initial value = 0 - NDArray numOfBadIndx(sd::DataType::INT64, context, true); + // scalar, initial value = 0 + NDArray numOfBadIndx(sd::DataType::INT64, context, true); - NDArray::prepareSpecialUse({&numOfBadIndx}, {&indices}); - BUILD_SINGLE_SELECTOR(xType, checkIndicesCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), indices.specialBuffer(), indices.specialShapeInfo(), reinterpret_cast(numOfBadIndx.specialBuffer()), output.specialShapeInfo(), axis), INDEXING_TYPES); - NDArray::registerSpecialUse({&numOfBadIndx}, {&indices}); + NDArray::prepareSpecialUse({&numOfBadIndx}, {&indices}); + BUILD_SINGLE_SELECTOR( + xType, checkIndicesCudaLauncher, + (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), + indices.specialBuffer(), indices.specialShapeInfo(), + reinterpret_cast(numOfBadIndx.specialBuffer()), + output.specialShapeInfo(), axis), + INDEXING_TYPES); + NDArray::registerSpecialUse({&numOfBadIndx}, {&indices}); - manager.synchronize(); + manager.synchronize(); - return numOfBadIndx.t(0); + return numOfBadIndx.t(0); } /////////////////////////////////////////////////////////////////// // x - indices, y - updates, z - input/output -template -__global__ static void scatterLockCuda(const int opCode, - const void *vx, const Nd4jLong *xShapeInfo, - const void *vy, const Nd4jLong *yShapeInfo, - void *vz, const Nd4jLong *zShapeInfo) { - - const auto x = reinterpret_cast(vx); - const auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); - - __shared__ int xRank, yRank, zRank, xNonUnitDim, yNonUnitDim, zNonUnitDim, *coords; - __shared__ Nd4jLong xLen, zLen; - __shared__ bool is1Dcase, xySameStride; - - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - coords = reinterpret_cast(shmem); - - xLen = shape::length(xShapeInfo); - zLen = shape::length(zShapeInfo); - - xRank = shape::rank(xShapeInfo); - yRank = shape::rank(yShapeInfo); - zRank = shape::rank(zShapeInfo); - - xNonUnitDim = yNonUnitDim = zNonUnitDim = 0; - - is1Dcase = (shape::isCommonVector(zShapeInfo, zNonUnitDim) || shape::isScalar(zShapeInfo)) && (shape::isCommonVector(yShapeInfo, yNonUnitDim) || shape::isScalar(yShapeInfo)) && (shape::isCommonVector(xShapeInfo, xNonUnitDim) || shape::isScalar(xShapeInfo)); - - if(is1Dcase) - xySameStride = shape::stride(xShapeInfo)[xNonUnitDim] = shape::stride(yShapeInfo)[yNonUnitDim]; +template +__global__ static void scatterLockCuda(const int opCode, const void *vx, + const Nd4jLong *xShapeInfo, + const void *vy, + const Nd4jLong *yShapeInfo, void *vz, + const Nd4jLong *zShapeInfo) { + const auto x = reinterpret_cast(vx); + const auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + + __shared__ int xRank, yRank, zRank, xNonUnitDim, yNonUnitDim, zNonUnitDim, + *coords; + __shared__ Nd4jLong xLen, zLen; + __shared__ bool is1Dcase, xySameStride; + + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + coords = reinterpret_cast(shmem); + + xLen = shape::length(xShapeInfo); + zLen = shape::length(zShapeInfo); + + xRank = shape::rank(xShapeInfo); + yRank = shape::rank(yShapeInfo); + zRank = shape::rank(zShapeInfo); + + xNonUnitDim = yNonUnitDim = zNonUnitDim = 0; + + is1Dcase = (shape::isCommonVector(zShapeInfo, zNonUnitDim) || + shape::isScalar(zShapeInfo)) && + (shape::isCommonVector(yShapeInfo, yNonUnitDim) || + shape::isScalar(yShapeInfo)) && + (shape::isCommonVector(xShapeInfo, xNonUnitDim) || + shape::isScalar(xShapeInfo)); + + if (is1Dcase) + xySameStride = shape::stride(xShapeInfo)[xNonUnitDim] = + shape::stride(yShapeInfo)[yNonUnitDim]; + } + __syncthreads(); + + Nd4jLong yOffset, zOffset; + int zFirstCoord, *yCoords, *zCoords; + + for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < zLen; + i += gridDim.x * blockDim.x) { + if (!is1Dcase) { + yCoords = coords + threadIdx.x * (yRank + zRank); + zCoords = yCoords + yRank; + shape::index2coords(i, zShapeInfo, zCoords); } - __syncthreads(); - - - Nd4jLong yOffset, zOffset; - int zFirstCoord, *yCoords, *zCoords; - - for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < zLen; i += gridDim.x * blockDim.x) { - - if(!is1Dcase) { - - yCoords = coords + threadIdx.x * (yRank + zRank); - zCoords = yCoords + yRank; - shape::index2coords(i, zShapeInfo, zCoords); - } - - for (Nd4jLong j = 0; j < xLen; ++j) { - - if(is1Dcase) { - - yOffset = j * shape::stride(yShapeInfo)[yNonUnitDim]; - zFirstCoord = x[xySameStride ? yOffset : j * shape::stride(xShapeInfo)[xNonUnitDim]]; - - if(i != zFirstCoord) - continue; - - zOffset = i * shape::stride(zShapeInfo)[zNonUnitDim]; - } - else { - - shape::index2coords(j, xShapeInfo, yCoords); // first xRank coordinates in yCoords are the same for y and x - - zFirstCoord = x[shape::getOffset(xShapeInfo, yCoords)]; - - if(zCoords[0] != zFirstCoord) - continue; - - for (uint k = 0; k < yRank - xRank; ++k) - yCoords[xRank + k] = zCoords[k + 1]; - - yOffset = shape::getOffset(yShapeInfo, yCoords); - zOffset = shape::getOffset(zShapeInfo, zCoords); - } - - switch (opCode) { - case pairwise::Add: - z[zOffset] += y[yOffset]; - break; - case pairwise::Subtract: - z[zOffset] -= y[yOffset]; - break; - case pairwise::Multiply: - z[zOffset] *= y[yOffset]; - break; - case pairwise::Divide: - z[zOffset] /= y[yOffset]; - break; - case pairwise::ReverseSubtract: - z[zOffset] = y[yOffset] - z[zOffset]; - break; - case pairwise::ReverseDivide: - z[zOffset] = y[yOffset] / z[zOffset]; - break; - case pairwise::CopyPws: - z[zOffset] = y[yOffset]; - break; - case pairwise::MaxPairwise: - if(z[zOffset] < y[yOffset]) z[zOffset] = y[yOffset]; - break; - case pairwise::MinPairwise: - if(z[zOffset] > y[yOffset]) z[zOffset] = y[yOffset]; - break; - default: - continue; - } - } + for (Nd4jLong j = 0; j < xLen; ++j) { + if (is1Dcase) { + yOffset = j * shape::stride(yShapeInfo)[yNonUnitDim]; + zFirstCoord = + x[xySameStride ? yOffset + : j * shape::stride(xShapeInfo)[xNonUnitDim]]; + + if (i != zFirstCoord) continue; + + zOffset = i * shape::stride(zShapeInfo)[zNonUnitDim]; + } + + else { + shape::index2coords(j, xShapeInfo, + yCoords); // first xRank coordinates in yCoords are + // the same for y and x + + zFirstCoord = x[shape::getOffset(xShapeInfo, yCoords)]; + + if (zCoords[0] != zFirstCoord) continue; + + for (uint k = 0; k < yRank - xRank; ++k) + yCoords[xRank + k] = zCoords[k + 1]; + + yOffset = shape::getOffset(yShapeInfo, yCoords); + zOffset = shape::getOffset(zShapeInfo, zCoords); + } + + switch (opCode) { + case pairwise::Add: + z[zOffset] += y[yOffset]; + break; + case pairwise::Subtract: + z[zOffset] -= y[yOffset]; + break; + case pairwise::Multiply: + z[zOffset] *= y[yOffset]; + break; + case pairwise::Divide: + z[zOffset] /= y[yOffset]; + break; + case pairwise::ReverseSubtract: + z[zOffset] = y[yOffset] - z[zOffset]; + break; + case pairwise::ReverseDivide: + z[zOffset] = y[yOffset] / z[zOffset]; + break; + case pairwise::CopyPws: + z[zOffset] = y[yOffset]; + break; + case pairwise::MaxPairwise: + if (z[zOffset] < y[yOffset]) z[zOffset] = y[yOffset]; + break; + case pairwise::MinPairwise: + if (z[zOffset] > y[yOffset]) z[zOffset] = y[yOffset]; + break; + default: + continue; + } } + } } - /////////////////////////////////////////////////////////////////// // x - indices, y - updates, z - input/output -template -__global__ static void scatterCuda(const int opCode, - const void *vx, const Nd4jLong *xShapeInfo, - const void *vy, const Nd4jLong *yShapeInfo, - void *vz, const Nd4jLong *zShapeInfo) { - - const auto x = reinterpret_cast(vx); - const auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); - - __shared__ int xRank, yRank, zRank, xNonUnitDim, yNonUnitDim, zNonUnitDim, *coords; - __shared__ Nd4jLong yLen; - __shared__ bool is1Dcase, xySameStride; - - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - coords = reinterpret_cast(shmem); - - yLen = shape::length(yShapeInfo); - - xRank = shape::rank(xShapeInfo); - yRank = shape::rank(yShapeInfo); - zRank = shape::rank(zShapeInfo); - - xNonUnitDim = yNonUnitDim = zNonUnitDim = 0; - - is1Dcase = (shape::isCommonVector(zShapeInfo, zNonUnitDim) || shape::isScalar(zShapeInfo)) && (shape::isCommonVector(yShapeInfo, yNonUnitDim) || shape::isScalar(yShapeInfo)) && (shape::isCommonVector(xShapeInfo, xNonUnitDim) || shape::isScalar(xShapeInfo)); - - if(is1Dcase) - xySameStride = shape::stride(xShapeInfo)[xNonUnitDim] = shape::stride(yShapeInfo)[yNonUnitDim]; +template +__global__ static void scatterCuda(const int opCode, const void *vx, + const Nd4jLong *xShapeInfo, const void *vy, + const Nd4jLong *yShapeInfo, void *vz, + const Nd4jLong *zShapeInfo) { + const auto x = reinterpret_cast(vx); + const auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + + __shared__ int xRank, yRank, zRank, xNonUnitDim, yNonUnitDim, zNonUnitDim, + *coords; + __shared__ Nd4jLong yLen; + __shared__ bool is1Dcase, xySameStride; + + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + coords = reinterpret_cast(shmem); + + yLen = shape::length(yShapeInfo); + + xRank = shape::rank(xShapeInfo); + yRank = shape::rank(yShapeInfo); + zRank = shape::rank(zShapeInfo); + + xNonUnitDim = yNonUnitDim = zNonUnitDim = 0; + + is1Dcase = (shape::isCommonVector(zShapeInfo, zNonUnitDim) || + shape::isScalar(zShapeInfo)) && + (shape::isCommonVector(yShapeInfo, yNonUnitDim) || + shape::isScalar(yShapeInfo)) && + (shape::isCommonVector(xShapeInfo, xNonUnitDim) || + shape::isScalar(xShapeInfo)); + + if (is1Dcase) + xySameStride = shape::stride(xShapeInfo)[xNonUnitDim] = + shape::stride(yShapeInfo)[yNonUnitDim]; + } + __syncthreads(); + + Nd4jLong xOffset, yOffset, zOffset; + int *yCoords, *zCoords; + + if (!is1Dcase) { + yCoords = coords + threadIdx.x * (yRank + zRank); + zCoords = yCoords + yRank; + } + + for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < yLen; + i += gridDim.x * blockDim.x) { + if (is1Dcase) { + yOffset = i * shape::stride(yShapeInfo)[yNonUnitDim]; + zOffset = x[xySameStride ? yOffset + : i * shape::stride(xShapeInfo)[xNonUnitDim]] * + shape::stride(zShapeInfo)[zNonUnitDim]; + } else { + shape::index2coords(i, yShapeInfo, yCoords); + + yOffset = shape::getOffset(yShapeInfo, yCoords); + xOffset = shape::getOffset( + xShapeInfo, yCoords); // first xRank coordinates in yCoords are the + // same for y and x -> for (uint j = 0; j < + // xRank; ++j) xCoords[j] = yCoords[j]; + + zCoords[0] = x[xOffset]; + + for (uint j = 0; j < yRank - xRank; ++j) + zCoords[j + 1] = yCoords[xRank + j]; + + zOffset = shape::getOffset(zShapeInfo, zCoords); } - __syncthreads(); - - - Nd4jLong xOffset, yOffset, zOffset; - int *yCoords, *zCoords; - if(!is1Dcase) { - yCoords = coords + threadIdx.x * (yRank + zRank); - zCoords = yCoords + yRank; - } - - for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < yLen; i += gridDim.x * blockDim.x) { - - if(is1Dcase) { - - yOffset = i * shape::stride(yShapeInfo)[yNonUnitDim]; - zOffset = x[xySameStride ? yOffset : i * shape::stride(xShapeInfo)[xNonUnitDim]] * shape::stride(zShapeInfo)[zNonUnitDim]; - } - else { - shape::index2coords(i, yShapeInfo, yCoords); - - yOffset = shape::getOffset(yShapeInfo, yCoords); - xOffset = shape::getOffset(xShapeInfo, yCoords); // first xRank coordinates in yCoords are the same for y and x -> for (uint j = 0; j < xRank; ++j) xCoords[j] = yCoords[j]; - - zCoords[0] = x[xOffset]; - - for (uint j = 0; j < yRank - xRank; ++j) - zCoords[j + 1] = yCoords[xRank + j]; - - zOffset = shape::getOffset(zShapeInfo, zCoords); - } - - switch (opCode) { - case pairwise::Add: - z[zOffset] += y[yOffset]; - break; - case pairwise::Subtract: - z[zOffset] -= y[yOffset]; - break; - case pairwise::Multiply: - z[zOffset] *= y[yOffset]; - break; - case pairwise::Divide: - z[zOffset] /= y[yOffset]; - break; - case pairwise::ReverseSubtract: - z[zOffset] = y[yOffset] - z[zOffset]; - break; - case pairwise::ReverseDivide: - z[zOffset] = y[yOffset] / z[zOffset]; - break; - case pairwise::CopyPws: - z[zOffset] = y[yOffset]; - break; - case pairwise::MaxPairwise: - if(z[zOffset] < y[yOffset]) z[zOffset] = y[yOffset]; - break; - case pairwise::MinPairwise: - if(z[zOffset] > y[yOffset]) z[zOffset] = y[yOffset]; - break; - default: - continue; - } + switch (opCode) { + case pairwise::Add: + z[zOffset] += y[yOffset]; + break; + case pairwise::Subtract: + z[zOffset] -= y[yOffset]; + break; + case pairwise::Multiply: + z[zOffset] *= y[yOffset]; + break; + case pairwise::Divide: + z[zOffset] /= y[yOffset]; + break; + case pairwise::ReverseSubtract: + z[zOffset] = y[yOffset] - z[zOffset]; + break; + case pairwise::ReverseDivide: + z[zOffset] = y[yOffset] / z[zOffset]; + break; + case pairwise::CopyPws: + z[zOffset] = y[yOffset]; + break; + case pairwise::MaxPairwise: + if (z[zOffset] < y[yOffset]) z[zOffset] = y[yOffset]; + break; + case pairwise::MinPairwise: + if (z[zOffset] > y[yOffset]) z[zOffset] = y[yOffset]; + break; + default: + continue; } + } } /////////////////////////////////////////////////////////////////// -template -static void scatterCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, - const int opCode, +template +static void scatterCudaLauncher(const int blocksPerGrid, + const int threadsPerBlock, const int sharedMem, + const cudaStream_t *stream, const int opCode, const void *vx, const Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo, - void *vz, const Nd4jLong *zShapeInfo, + void *vz, const Nd4jLong *zShapeInfo, const bool lock) { - - if(lock) - scatterLockCuda<<>>(opCode, vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo); - else - scatterCuda<<>>(opCode, vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo); + if (lock) + scatterLockCuda + <<>>( + opCode, vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo); + else + scatterCuda<<>>( + opCode, vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo); } - /////////////////////////////////////////////////////////////////// -void scatter(sd::LaunchContext *context, pairwise::Ops op, const NDArray& indices, const NDArray& updates, NDArray& output, const bool lock) { - - const auto xType = indices.dataType(); - const auto yType = updates.dataType(); - - const int threadsPerBlock = MAX_NUM_THREADS / 4; - const int blocksPerGrid = ((lock ? output.lengthOf() : updates.lengthOf()) + threadsPerBlock - 1) / threadsPerBlock; - const int sharedMem = sizeof(int) * threadsPerBlock * (updates.rankOf() + output.rankOf()) + 256; - - PointersManager manager(context, "scatter"); - - NDArray::prepareSpecialUse({&output}, {&updates, &indices}); - BUILD_DOUBLE_SELECTOR(xType, yType, scatterCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), op, indices.specialBuffer(), indices.specialShapeInfo(), updates.specialBuffer(), updates.specialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), lock), INDEXING_TYPES, GENERIC_NUMERIC_TYPES); - NDArray::registerSpecialUse({&output}, {&updates, &indices}); - - manager.synchronize(); +void scatter(sd::LaunchContext *context, pairwise::Ops op, + const NDArray &indices, const NDArray &updates, NDArray &output, + const bool lock) { + const auto xType = indices.dataType(); + const auto yType = updates.dataType(); + + const int threadsPerBlock = MAX_NUM_THREADS / 4; + const int blocksPerGrid = + ((lock ? output.lengthOf() : updates.lengthOf()) + threadsPerBlock - 1) / + threadsPerBlock; + const int sharedMem = + sizeof(int) * threadsPerBlock * (updates.rankOf() + output.rankOf()) + + 256; + + PointersManager manager(context, "scatter"); + + NDArray::prepareSpecialUse({&output}, {&updates, &indices}); + BUILD_DOUBLE_SELECTOR( + xType, yType, scatterCudaLauncher, + (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), op, + indices.specialBuffer(), indices.specialShapeInfo(), + updates.specialBuffer(), updates.specialShapeInfo(), + output.specialBuffer(), output.specialShapeInfo(), lock), + INDEXING_TYPES, GENERIC_NUMERIC_TYPES); + NDArray::registerSpecialUse({&output}, {&updates, &indices}); + + manager.synchronize(); } /////////////////////////////////////////////////////////////////// // x - indices, y - updates, z - output -template -__global__ static void scatterNDLockCuda(const int opCode, - const void *vx, const Nd4jLong *xShapeInfo, - const void *vy, const Nd4jLong *yShapeInfo, - void *vz, const Nd4jLong *zShapeInfo) { - - const auto x = reinterpret_cast(vx); - const auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); - - __shared__ int xRank, yRank, zRank, biggerXYRank, xLastDim, *coords, xNonUnitDim, yNonUnitDim, zNonUnitDim; - __shared__ Nd4jLong zLen, len; - __shared__ bool is1Dcase; - - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - coords = reinterpret_cast(shmem); - - xRank = shape::rank(xShapeInfo); - yRank = shape::rank(yShapeInfo); - zRank = shape::rank(zShapeInfo); - xLastDim = shape::sizeAt(xShapeInfo, -1); - - biggerXYRank = xRank > yRank ? xRank : yRank; - - xNonUnitDim = yNonUnitDim = zNonUnitDim = 0; - - is1Dcase = (shape::isCommonVector(zShapeInfo, zNonUnitDim) || shape::isScalar(zShapeInfo)) && (shape::isCommonVector(yShapeInfo, yNonUnitDim) || shape::isScalar(yShapeInfo)) && (shape::isCommonVector(xShapeInfo, xNonUnitDim) || shape::isScalar(xShapeInfo)); - - len = is1Dcase ? shape::length(xShapeInfo) : shape::length(xShapeInfo) / xLastDim; - zLen = shape::length(zShapeInfo); - } - __syncthreads(); - - Nd4jLong yOffset, zOffset, xOffset; - int *yCoords, *zCoords; - - if(!is1Dcase) { - yCoords = coords + threadIdx.x * (biggerXYRank + zRank); - zCoords = yCoords + biggerXYRank; - } - - for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < zLen; i += gridDim.x * blockDim.x) { - - if(!is1Dcase) - shape::index2coords(i, zShapeInfo, zCoords); - - for (Nd4jLong j = 0; j < len; ++j) { // if !is1Dcase then we loop through first xRank-1 dimensions of x, that is we exclude last x dimension - - if(is1Dcase) { - - if(x[j * shape::stride(xShapeInfo)[xNonUnitDim]] != i) - continue; - - yOffset = j * shape::stride(yShapeInfo)[yNonUnitDim]; - zOffset = i * shape::stride(zShapeInfo)[zNonUnitDim]; - } - else { - - shape::index2coords(j, xRank-1, shape::shapeOf(const_cast(xShapeInfo)), yCoords); // first xRank-1 coordinates in yCoords are the same for y and x - - // first iteration - yCoords[xRank - 1] = 0; - xOffset = shape::getOffset(xShapeInfo, yCoords); - if(zCoords[0] != x[xOffset]) - continue; - - // rest iterations - bool matched = true; - for (uint k = 1; k < xLastDim; ++k) { - yCoords[xRank - 1] = k; - xOffset += shape::stride(xShapeInfo)[xRank-1]; - if(zCoords[k] != x[xOffset]) { - matched = false; - break; - } - } - - if(!matched) - continue; - - for (uint k = xLastDim; k < zRank; ++k) - yCoords[yRank - zRank + k] = zCoords[k]; - - yOffset = shape::getOffset(yShapeInfo, yCoords); - zOffset = shape::getOffset(zShapeInfo, zCoords); - } - - switch (opCode) { - case pairwise::Add: - z[zOffset] += y[yOffset]; - break; - case pairwise::Subtract: - z[zOffset] -= y[yOffset]; - break; - case pairwise::Multiply: - z[zOffset] *= y[yOffset]; - break; - case pairwise::Divide: - z[zOffset] /= y[yOffset]; - break; - case pairwise::ReverseSubtract: - z[zOffset] = y[yOffset] - z[zOffset]; - break; - case pairwise::ReverseDivide: - z[zOffset] = y[yOffset] / z[zOffset]; - break; - case pairwise::CopyPws: - z[zOffset] = y[yOffset]; - break; - case pairwise::MaxPairwise: - if(z[zOffset] < y[yOffset]) z[zOffset] = y[yOffset]; - break; - case pairwise::MinPairwise: - if(z[zOffset] > y[yOffset]) z[zOffset] = y[yOffset]; - break; - default: - continue; - } +template +__global__ static void scatterNDLockCuda(const int opCode, const void *vx, + const Nd4jLong *xShapeInfo, + const void *vy, + const Nd4jLong *yShapeInfo, void *vz, + const Nd4jLong *zShapeInfo) { + const auto x = reinterpret_cast(vx); + const auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + + __shared__ int xRank, yRank, zRank, biggerXYRank, xLastDim, *coords, + xNonUnitDim, yNonUnitDim, zNonUnitDim; + __shared__ Nd4jLong zLen, len; + __shared__ bool is1Dcase; + + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + coords = reinterpret_cast(shmem); + + xRank = shape::rank(xShapeInfo); + yRank = shape::rank(yShapeInfo); + zRank = shape::rank(zShapeInfo); + xLastDim = shape::sizeAt(xShapeInfo, -1); + + biggerXYRank = xRank > yRank ? xRank : yRank; + + xNonUnitDim = yNonUnitDim = zNonUnitDim = 0; + + is1Dcase = (shape::isCommonVector(zShapeInfo, zNonUnitDim) || + shape::isScalar(zShapeInfo)) && + (shape::isCommonVector(yShapeInfo, yNonUnitDim) || + shape::isScalar(yShapeInfo)) && + (shape::isCommonVector(xShapeInfo, xNonUnitDim) || + shape::isScalar(xShapeInfo)); + + len = is1Dcase ? shape::length(xShapeInfo) + : shape::length(xShapeInfo) / xLastDim; + zLen = shape::length(zShapeInfo); + } + __syncthreads(); + + Nd4jLong yOffset, zOffset, xOffset; + int *yCoords, *zCoords; + + if (!is1Dcase) { + yCoords = coords + threadIdx.x * (biggerXYRank + zRank); + zCoords = yCoords + biggerXYRank; + } + + for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < zLen; + i += gridDim.x * blockDim.x) { + if (!is1Dcase) shape::index2coords(i, zShapeInfo, zCoords); + + for (Nd4jLong j = 0; j < len; + ++j) { // if !is1Dcase then we loop through first xRank-1 dimensions + // of x, that is we exclude last x dimension + + if (is1Dcase) { + if (x[j * shape::stride(xShapeInfo)[xNonUnitDim]] != i) continue; + + yOffset = j * shape::stride(yShapeInfo)[yNonUnitDim]; + zOffset = i * shape::stride(zShapeInfo)[zNonUnitDim]; + } else { + shape::index2coords(j, xRank - 1, + shape::shapeOf(const_cast(xShapeInfo)), + yCoords); // first xRank-1 coordinates in yCoords + // are the same for y and x + + // first iteration + yCoords[xRank - 1] = 0; + xOffset = shape::getOffset(xShapeInfo, yCoords); + if (zCoords[0] != x[xOffset]) continue; + + // rest iterations + bool matched = true; + for (uint k = 1; k < xLastDim; ++k) { + yCoords[xRank - 1] = k; + xOffset += shape::stride(xShapeInfo)[xRank - 1]; + if (zCoords[k] != x[xOffset]) { + matched = false; + break; + } } + + if (!matched) continue; + + for (uint k = xLastDim; k < zRank; ++k) + yCoords[yRank - zRank + k] = zCoords[k]; + + yOffset = shape::getOffset(yShapeInfo, yCoords); + zOffset = shape::getOffset(zShapeInfo, zCoords); + } + + switch (opCode) { + case pairwise::Add: + z[zOffset] += y[yOffset]; + break; + case pairwise::Subtract: + z[zOffset] -= y[yOffset]; + break; + case pairwise::Multiply: + z[zOffset] *= y[yOffset]; + break; + case pairwise::Divide: + z[zOffset] /= y[yOffset]; + break; + case pairwise::ReverseSubtract: + z[zOffset] = y[yOffset] - z[zOffset]; + break; + case pairwise::ReverseDivide: + z[zOffset] = y[yOffset] / z[zOffset]; + break; + case pairwise::CopyPws: + z[zOffset] = y[yOffset]; + break; + case pairwise::MaxPairwise: + if (z[zOffset] < y[yOffset]) z[zOffset] = y[yOffset]; + break; + case pairwise::MinPairwise: + if (z[zOffset] > y[yOffset]) z[zOffset] = y[yOffset]; + break; + default: + continue; + } } + } } /////////////////////////////////////////////////////////////////// // x - indices, y - updates, z - output -template -__global__ static void scatterNDCuda(const int opCode, - const void *vx, const Nd4jLong *xShapeInfo, - const void *vy, const Nd4jLong *yShapeInfo, - void *vz, const Nd4jLong *zShapeInfo) { - - const auto x = reinterpret_cast(vx); - const auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); - - __shared__ int xRank, yRank, zRank, biggerXYRank, xLastDim, *coords, xNonUnitDim, yNonUnitDim, zNonUnitDim; - __shared__ Nd4jLong yLen; - __shared__ bool is1Dcase; - - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - coords = reinterpret_cast(shmem); - - yLen = shape::length(yShapeInfo); - xRank = shape::rank(xShapeInfo); - yRank = shape::rank(yShapeInfo); - zRank = shape::rank(zShapeInfo); - xLastDim = shape::sizeAt(xShapeInfo, -1); - - biggerXYRank = xRank > yRank ? xRank : yRank; - - xNonUnitDim = yNonUnitDim = zNonUnitDim = 0; - - is1Dcase = (shape::isCommonVector(zShapeInfo, zNonUnitDim) || shape::isScalar(zShapeInfo)) && (shape::isCommonVector(yShapeInfo, yNonUnitDim) || shape::isScalar(yShapeInfo)) && (shape::isCommonVector(xShapeInfo, xNonUnitDim) || shape::isScalar(xShapeInfo)); +template +__global__ static void scatterNDCuda(const int opCode, const void *vx, + const Nd4jLong *xShapeInfo, const void *vy, + const Nd4jLong *yShapeInfo, void *vz, + const Nd4jLong *zShapeInfo) { + const auto x = reinterpret_cast(vx); + const auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + + __shared__ int xRank, yRank, zRank, biggerXYRank, xLastDim, *coords, + xNonUnitDim, yNonUnitDim, zNonUnitDim; + __shared__ Nd4jLong yLen; + __shared__ bool is1Dcase; + + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + coords = reinterpret_cast(shmem); + + yLen = shape::length(yShapeInfo); + xRank = shape::rank(xShapeInfo); + yRank = shape::rank(yShapeInfo); + zRank = shape::rank(zShapeInfo); + xLastDim = shape::sizeAt(xShapeInfo, -1); + + biggerXYRank = xRank > yRank ? xRank : yRank; + + xNonUnitDim = yNonUnitDim = zNonUnitDim = 0; + + is1Dcase = (shape::isCommonVector(zShapeInfo, zNonUnitDim) || + shape::isScalar(zShapeInfo)) && + (shape::isCommonVector(yShapeInfo, yNonUnitDim) || + shape::isScalar(yShapeInfo)) && + (shape::isCommonVector(xShapeInfo, xNonUnitDim) || + shape::isScalar(xShapeInfo)); + } + __syncthreads(); + + Nd4jLong yOffset, zOffset; + int *yCoords, *zCoords; + + if (!is1Dcase) { + yCoords = coords + threadIdx.x * (biggerXYRank + zRank); + zCoords = yCoords + biggerXYRank; + } + + for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < yLen; + i += gridDim.x * blockDim.x) { + if (is1Dcase) { + yOffset = i * shape::stride(yShapeInfo)[zNonUnitDim]; + zOffset = x[i * shape::stride(xShapeInfo)[xNonUnitDim]] * + shape::stride(zShapeInfo)[zNonUnitDim]; + } else { + shape::index2coords(i, yShapeInfo, yCoords); + + yOffset = shape::getOffset(yShapeInfo, yCoords); + + if (yRank >= xRank) + zCoords[xLastDim] = + yCoords[xRank - 1]; // saving y coordinate, since it might be + // changed in next instructions + + for (uint j = 0; j < xLastDim; ++j) { // first xRank-1 coordinates in + // yCoords are the same for y and x + yCoords[xRank - 1] = j; + zCoords[j] = x[shape::getOffset(xShapeInfo, yCoords)]; + } + + for (uint j = xLastDim + 1; j < zRank; ++j) + zCoords[j] = yCoords[yRank - zRank + j]; + + zOffset = shape::getOffset(zShapeInfo, zCoords); } - __syncthreads(); - - Nd4jLong yOffset, zOffset; - int *yCoords, *zCoords; - - if(!is1Dcase) { - yCoords = coords + threadIdx.x * (biggerXYRank + zRank); - zCoords = yCoords + biggerXYRank; - } - - for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < yLen; i += gridDim.x * blockDim.x) { - - if(is1Dcase) { - - yOffset = i * shape::stride(yShapeInfo)[zNonUnitDim]; - zOffset = x[i * shape::stride(xShapeInfo)[xNonUnitDim]] * shape::stride(zShapeInfo)[zNonUnitDim]; - } - else { - - shape::index2coords(i, yShapeInfo, yCoords); - - yOffset = shape::getOffset(yShapeInfo, yCoords); - - if(yRank >= xRank) - zCoords[xLastDim] = yCoords[xRank - 1]; // saving y coordinate, since it might be changed in next instructions - - for (uint j = 0; j < xLastDim; ++j) { // first xRank-1 coordinates in yCoords are the same for y and x - yCoords[xRank - 1] = j; - zCoords[j] = x[shape::getOffset(xShapeInfo, yCoords)]; - } - - for (uint j = xLastDim + 1; j < zRank; ++j) - zCoords[j] = yCoords[yRank - zRank + j]; - - zOffset = shape::getOffset(zShapeInfo, zCoords); - } - switch (opCode) { - case pairwise::Add: - z[zOffset] += y[yOffset]; - break; - case pairwise::Subtract: - z[zOffset] -= y[yOffset]; - break; - case pairwise::Multiply: - z[zOffset] *= y[yOffset]; - break; - case pairwise::Divide: - z[zOffset] /= y[yOffset]; - break; - case pairwise::ReverseSubtract: - z[zOffset] = y[yOffset] - z[zOffset]; - break; - case pairwise::ReverseDivide: - z[zOffset] = y[yOffset] / z[zOffset]; - break; - case pairwise::CopyPws: - z[zOffset] = y[yOffset]; - break; - case pairwise::MaxPairwise: - if(z[zOffset] < y[yOffset]) z[zOffset] = y[yOffset]; - break; - case pairwise::MinPairwise: - if(z[zOffset] > y[yOffset]) z[zOffset] = y[yOffset]; - break; - default: - continue; - } + switch (opCode) { + case pairwise::Add: + z[zOffset] += y[yOffset]; + break; + case pairwise::Subtract: + z[zOffset] -= y[yOffset]; + break; + case pairwise::Multiply: + z[zOffset] *= y[yOffset]; + break; + case pairwise::Divide: + z[zOffset] /= y[yOffset]; + break; + case pairwise::ReverseSubtract: + z[zOffset] = y[yOffset] - z[zOffset]; + break; + case pairwise::ReverseDivide: + z[zOffset] = y[yOffset] / z[zOffset]; + break; + case pairwise::CopyPws: + z[zOffset] = y[yOffset]; + break; + case pairwise::MaxPairwise: + if (z[zOffset] < y[yOffset]) z[zOffset] = y[yOffset]; + break; + case pairwise::MinPairwise: + if (z[zOffset] > y[yOffset]) z[zOffset] = y[yOffset]; + break; + default: + continue; } + } } /////////////////////////////////////////////////////////////////// -template -static void scatterNDCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, - const int opCode, - const void *vx, const Nd4jLong *xShapeInfo, - const void *vy, const Nd4jLong *yShapeInfo, - void *vz, const Nd4jLong *zShapeInfo, - const bool lock) { - - if(lock) - scatterNDLockCuda<<>>(opCode, vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo); - else - scatterNDCuda<<>>(opCode, vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo); +template +static void scatterNDCudaLauncher( + const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, + const cudaStream_t *stream, const int opCode, const void *vx, + const Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo, + void *vz, const Nd4jLong *zShapeInfo, const bool lock) { + if (lock) + scatterNDLockCuda + <<>>( + opCode, vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo); + else + scatterNDCuda<<>>( + opCode, vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo); } /////////////////////////////////////////////////////////////////// -void scatterND(sd::LaunchContext *context, pairwise::Ops op, const NDArray& indices, const NDArray& updates, NDArray& output, const bool lock) { - - const int xRank = indices.rankOf(); - const int yRank = updates.rankOf(); - const int zRank = output.rankOf(); - - const int threadsPerBlock = MAX_NUM_THREADS / 4; - const int blocksPerGrid = ((lock ? output.lengthOf() : updates.lengthOf()) + threadsPerBlock - 1) / threadsPerBlock; - const int sharedMem = threadsPerBlock * sizeof(int) * ((yRank > xRank ? yRank : xRank) + zRank) + 256; - - const auto xType = indices.dataType(); - const auto yType = updates.dataType(); - - PointersManager manager(context, "scatterND"); - - NDArray::prepareSpecialUse({&output}, {&updates, &indices}); - BUILD_DOUBLE_SELECTOR(xType, yType, scatterNDCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), op, indices.specialBuffer(), indices.specialShapeInfo(), updates.specialBuffer(), updates.specialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), lock), INDEXING_TYPES, GENERIC_NUMERIC_TYPES); - NDArray::registerSpecialUse({&output}, {&updates, &indices}); - - manager.synchronize(); +void scatterND(sd::LaunchContext *context, pairwise::Ops op, + const NDArray &indices, const NDArray &updates, NDArray &output, + const bool lock) { + const int xRank = indices.rankOf(); + const int yRank = updates.rankOf(); + const int zRank = output.rankOf(); + + const int threadsPerBlock = MAX_NUM_THREADS / 4; + const int blocksPerGrid = + ((lock ? output.lengthOf() : updates.lengthOf()) + threadsPerBlock - 1) / + threadsPerBlock; + const int sharedMem = threadsPerBlock * sizeof(int) * + ((yRank > xRank ? yRank : xRank) + zRank) + + 256; + + const auto xType = indices.dataType(); + const auto yType = updates.dataType(); + + PointersManager manager(context, "scatterND"); + + NDArray::prepareSpecialUse({&output}, {&updates, &indices}); + BUILD_DOUBLE_SELECTOR( + xType, yType, scatterNDCudaLauncher, + (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), op, + indices.specialBuffer(), indices.specialShapeInfo(), + updates.specialBuffer(), updates.specialShapeInfo(), + output.specialBuffer(), output.specialShapeInfo(), lock), + INDEXING_TYPES, GENERIC_NUMERIC_TYPES); + NDArray::registerSpecialUse({&output}, {&updates, &indices}); + + manager.synchronize(); } /////////////////////////////////////////////////////////////////// -template +template __global__ void scatterForLossCuda(const void *vx, const Nd4jLong *xShapeInfo, - void *vy, const Nd4jLong *yShapeInfo, - void *vz, const Nd4jLong *zShapeInfo) { - - const auto x = reinterpret_cast(vx); - auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); + void *vy, const Nd4jLong *yShapeInfo, + void *vz, const Nd4jLong *zShapeInfo) { + const auto x = reinterpret_cast(vx); + auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); - __shared__ Nd4jLong xLen; - __shared__ int xRank, *sharedMem; // xRank = zRank, yRank = xRank + 1 + __shared__ Nd4jLong xLen; + __shared__ int xRank, *sharedMem; // xRank = zRank, yRank = xRank + 1 - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - sharedMem = reinterpret_cast(shmem); + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sharedMem = reinterpret_cast(shmem); - xLen = shape::length(xShapeInfo); - xRank = shape::rank(xShapeInfo); - } - __syncthreads(); + xLen = shape::length(xShapeInfo); + xRank = shape::rank(xShapeInfo); + } + __syncthreads(); - const auto xInd = threadIdx.x + blockIdx.x * blockDim.x; + const auto xInd = threadIdx.x + blockIdx.x * blockDim.x; - if(xInd >= xLen) - return; + if (xInd >= xLen) return; - auto coords = sharedMem + threadIdx.x * (xRank + 1); + auto coords = sharedMem + threadIdx.x * (xRank + 1); - shape::index2coords(xInd, xShapeInfo, coords); + shape::index2coords(xInd, xShapeInfo, coords); - // y last coordinate - coords[xRank] = x[shape::getOffset(xShapeInfo, coords)]; + // y last coordinate + coords[xRank] = x[shape::getOffset(xShapeInfo, coords)]; - const auto yOffset = shape::getOffset(yShapeInfo, coords); + const auto yOffset = shape::getOffset(yShapeInfo, coords); - if(z == nullptr) { // gradient calculation - y[yOffset] -= 1.f; - } - else { - z[shape::getOffset(zShapeInfo, coords)] = y[yOffset]; - } + if (z == nullptr) { // gradient calculation + y[yOffset] -= 1.f; + } else { + z[shape::getOffset(zShapeInfo, coords)] = y[yOffset]; + } } /////////////////////////////////////////////////////////////////// -template -static void scatterForLossCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void *vx, const Nd4jLong* xShapeInfo, void *vy, const Nd4jLong* yShapeInfo, void *vz, const Nd4jLong* zShapeInfo) { - - scatterForLossCuda<<>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo); +template +static void scatterForLossCudaLauncher( + const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, + const cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, + void *vy, const Nd4jLong *yShapeInfo, void *vz, + const Nd4jLong *zShapeInfo) { + scatterForLossCuda + <<>>( + vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo); } /////////////////////////////////////////////////////////////////// -void scatterForLoss(sd::LaunchContext* context, const NDArray& indices, NDArray& updates, NDArray& output, const bool calcGrad) { - // shapes of indices and output must be the same - // shape of indices should be the same as updates shape with last dimension excluded, for example if updates is {a,b,c} then indices should be {a,b} - - PointersManager manager(context, "scatterForLoss"); - - const int threadsPerBlock = MAX_NUM_THREADS / 2; - const int blocksPerGrid = (indices.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - const int sharedMem = updates.rankOf() * sizeof(int) * threadsPerBlock + 128; - - if(calcGrad) { - NDArray::prepareSpecialUse({&updates}, {&indices}); - BUILD_DOUBLE_SELECTOR(indices.dataType(), updates.dataType(), scatterForLossCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), indices.specialBuffer(), indices.specialShapeInfo(), updates.specialBuffer(), updates.specialShapeInfo(), nullptr, nullptr), INDEXING_TYPES, FLOAT_TYPES); - NDArray::registerSpecialUse({&updates}, {&indices}); - } - else { - NDArray::prepareSpecialUse({&output}, {&indices, &updates}); - BUILD_DOUBLE_SELECTOR(indices.dataType(), updates.dataType(), scatterForLossCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), indices.specialBuffer(), indices.specialShapeInfo(), updates.specialBuffer(), updates.specialShapeInfo(), output.specialBuffer(), output.specialShapeInfo()), INDEXING_TYPES, FLOAT_TYPES); - NDArray::registerSpecialUse({&output}, {&indices, &updates}); - } - - manager.synchronize(); -} - -} -} +void scatterForLoss(sd::LaunchContext *context, const NDArray &indices, + NDArray &updates, NDArray &output, const bool calcGrad) { + // shapes of indices and output must be the same + // shape of indices should be the same as updates shape with last dimension + // excluded, for example if updates is {a,b,c} then indices should be {a,b} + + PointersManager manager(context, "scatterForLoss"); + + const int threadsPerBlock = MAX_NUM_THREADS / 2; + const int blocksPerGrid = + (indices.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + const int sharedMem = updates.rankOf() * sizeof(int) * threadsPerBlock + 128; + + if (calcGrad) { + NDArray::prepareSpecialUse({&updates}, {&indices}); + BUILD_DOUBLE_SELECTOR( + indices.dataType(), updates.dataType(), scatterForLossCudaLauncher, + (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), + indices.specialBuffer(), indices.specialShapeInfo(), + updates.specialBuffer(), updates.specialShapeInfo(), nullptr, nullptr), + INDEXING_TYPES, FLOAT_TYPES); + NDArray::registerSpecialUse({&updates}, {&indices}); + } else { + NDArray::prepareSpecialUse({&output}, {&indices, &updates}); + BUILD_DOUBLE_SELECTOR( + indices.dataType(), updates.dataType(), scatterForLossCudaLauncher, + (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), + indices.specialBuffer(), indices.specialShapeInfo(), + updates.specialBuffer(), updates.specialShapeInfo(), + output.specialBuffer(), output.specialShapeInfo()), + INDEXING_TYPES, FLOAT_TYPES); + NDArray::registerSpecialUse({&output}, {&indices, &updates}); + } + + manager.synchronize(); } +} // namespace helpers +} // namespace ops +} // namespace sd /* /////////////////////////////////////////////////////////////////// template -static void scatterLockCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, - const int opCode, - const void* vx, const Nd4jLong *xShapeInfo, - const void* vy, const Nd4jLong *yTadShapeInfo, const Nd4jLong *yOffsets, - void* vz, const Nd4jLong *zTadShapeInfo, const Nd4jLong *zOffsets, - const Nd4jLong xLen, const Nd4jLong yTadLen, const Nd4jLong zTadLen) { - - scatterLockCuda<<>>(opCode, vx, xShapeInfo, vy, yTadShapeInfo, yOffsets, vz, zTadShapeInfo, zOffsets, xLen, yTadLen, zTadLen); +static void scatterLockCudaLauncher(const int blocksPerGrid, const int +threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const int +opCode, const void* vx, const Nd4jLong *xShapeInfo, const void* vy, const +Nd4jLong *yTadShapeInfo, const Nd4jLong *yOffsets, void* vz, const Nd4jLong +*zTadShapeInfo, const Nd4jLong *zOffsets, const Nd4jLong xLen, const Nd4jLong +yTadLen, const Nd4jLong zTadLen) { + + scatterLockCuda<<>>(opCode, vx, xShapeInfo, vy, yTadShapeInfo, yOffsets, vz, +zTadShapeInfo, zOffsets, xLen, yTadLen, zTadLen); } @@ -718,16 +793,18 @@ static void scatterLockCudaLauncher(const int blocksPerGrid, const int threadsPe // x - indices, y - updates, z - input/output template __global__ static void scatterLockCuda(const int opCode, - const void* vx, const Nd4jLong *xShapeInfo, - const void* vy, const Nd4jLong *yTadShapeInfo, const Nd4jLong *yOffsets, - void* vz, const Nd4jLong *zTadShapeInfo, const Nd4jLong *zOffsets, - const Nd4jLong xLen, const Nd4jLong yTadLen, const Nd4jLong zTadLen) { + const void* vx, const Nd4jLong +*xShapeInfo, const void* vy, const Nd4jLong *yTadShapeInfo, const Nd4jLong +*yOffsets, void* vz, const Nd4jLong *zTadShapeInfo, const Nd4jLong *zOffsets, + const Nd4jLong xLen, const Nd4jLong +yTadLen, const Nd4jLong zTadLen) { const int xRank = indices.rankOf(); - std::vector zTadDims = ShapeUtils::evalDimsToExclude(output.rankOf(), {0}); + std::vector zTadDims = +ShapeUtils::evalDimsToExclude(output.rankOf(), {0}); int sizeOfUpdDims = xRank; if(output.rankOf() == updates.rankOf() && indices.isVector()) @@ -736,19 +813,28 @@ __global__ static void scatterLockCuda(const int opCode, std::vector yTadDims(sizeOfUpdDims); std::iota(yTadDims.begin(), yTadDims.end(), 0); - auto packY = sd::ConstantTadHelper::getInstance().tadForDimensions(updates.shapeInfo(), ShapeUtils::evalDimsToExclude(updates.rankOf(), yTadDims)); - auto packZ = sd::ConstantTadHelper::getInstance().tadForDimensions(output.shapeInfo(), zTadDims); + auto packY = +sd::ConstantTadHelper::getInstance().tadForDimensions(updates.shapeInfo(), +ShapeUtils::evalDimsToExclude(updates.rankOf(), yTadDims)); auto packZ = +sd::ConstantTadHelper::getInstance().tadForDimensions(output.shapeInfo(), +zTadDims); const Nd4jLong zTadLen = shape::length(packZ.primaryShapeInfo()); const Nd4jLong yTadLen = shape::length(packY.primaryShapeInfo()); - const auto threadsPerBlock = sd::math::nd4j_max(32, sd::math::nd4j_min(zTadLen, 1024)); - const auto blocksPerGrid = indices.lengthOf(); + const auto threadsPerBlock = sd::math::nd4j_max(32, +sd::math::nd4j_min(zTadLen, 1024)); const auto blocksPerGrid = +indices.lengthOf(); const auto xType = indices.dataType(); const auto yType = updates.dataType(); - BUILD_DOUBLE_SELECTOR(xType, yType, scatterLockCudaLauncher, (blocksPerGrid, threadsPerBlock, 1024, context->getCudaStream(), op, indices.specialBuffer(), indices.specialShapeInfo(), updates.specialBuffer(), packY.specialShapeInfo(), packY.specialOffsets(), output.specialBuffer(), packZ.specialShapeInfo(), packZ.specialOffsets(), indices.lengthOf(), yTadLen, zTadLen), INDEXING_TYPES, GENERIC_NUMERIC_TYPES); + BUILD_DOUBLE_SELECTOR(xType, yType, scatterLockCudaLauncher, +(blocksPerGrid, threadsPerBlock, 1024, context->getCudaStream(), op, +indices.specialBuffer(), indices.specialShapeInfo(), updates.specialBuffer(), +packY.specialShapeInfo(), packY.specialOffsets(), output.specialBuffer(), +packZ.specialShapeInfo(), packZ.specialOffsets(), indices.lengthOf(), yTadLen, +zTadLen), INDEXING_TYPES, GENERIC_NUMERIC_TYPES); @@ -764,12 +850,14 @@ __global__ static void scatterLockCuda(const int opCode, for (int e = 0; e < xLen; e++) { const Nd4jLong zIndex = x[shape::getIndexOffset(e, xShapeInfo)]; - const bool isOwner = zIndex < gridDim.x ? blockIdx.x == zIndex : blockIdx.x == zIndex % gridDim.x; + const bool isOwner = zIndex < gridDim.x ? blockIdx.x == zIndex : +blockIdx.x == zIndex % gridDim.x; if (!isOwner) continue; - if(vectorCase) { // means z_rank = 1 and might be yTadLen != zTadLen in this case + if(vectorCase) { // means z_rank = 1 and might be yTadLen != zTadLen in +this case if(threadIdx.x != 0) continue; @@ -842,13 +930,9 @@ __global__ static void scatterLockCuda(const int opCode, zTad[zOffset] = yTad[yOffset]; break; case pairwise::MaxPairwise: - if(zTad[zOffset] < yTad[yOffset]) zTad[zOffset] = yTad[yOffset]; - break; - case pairwise::MinPairwise: - if(zTad[zOffset] > yTad[yOffset]) zTad[zOffset] = yTad[yOffset]; - break; - default: - continue; + if(zTad[zOffset] < yTad[yOffset]) zTad[zOffset] = +yTad[yOffset]; break; case pairwise::MinPairwise: if(zTad[zOffset] > +yTad[yOffset]) zTad[zOffset] = yTad[yOffset]; break; default: continue; } } } @@ -856,10 +940,11 @@ __global__ static void scatterLockCuda(const int opCode, } template - __global__ static void scatterCuda(const int opCode, const int numOfSubArrs, - void* vx, const Nd4jLong *xShapeInfo, const Nd4jLong *xOffsets, - void* vy, const Nd4jLong *yShapeInfo, const Nd4jLong *yOffsets, - const int* indexes, unsigned int arrLenX, unsigned int arrLenY) { + __global__ static void scatterCuda(const int opCode, const int +numOfSubArrs, void* vx, const Nd4jLong *xShapeInfo, const Nd4jLong *xOffsets, + void* vy, const Nd4jLong +*yShapeInfo, const Nd4jLong *yOffsets, const int* indexes, unsigned int arrLenX, +unsigned int arrLenY) { __shared__ T *x, *y; @@ -868,7 +953,8 @@ __global__ static void scatterLockCuda(const int opCode, for (int e = 0; e < numOfSubArrs; e++) { const auto xIndex = indexes[e]; - const bool isOwner = xIndex < gridDim.x ? blockIdx.x == xIndex : blockIdx.x == xIndex % gridDim.x; + const bool isOwner = xIndex < gridDim.x ? blockIdx.x == +xIndex : blockIdx.x == xIndex % gridDim.x; if (!isOwner) continue; @@ -879,10 +965,11 @@ __global__ static void scatterLockCuda(const int opCode, } __syncthreads(); - for (Nd4jLong i = threadIdx.x; i < arrLenX; i += blockDim.x) { + for (Nd4jLong i = threadIdx.x; i < arrLenX; i += +blockDim.x) { - const auto xOffset = shape::getIndexOffset(i, xShapeInfo); - const auto yOffset = shape::getIndexOffset(i, yShapeInfo); + const auto xOffset = shape::getIndexOffset(i, +xShapeInfo); const auto yOffset = shape::getIndexOffset(i, yShapeInfo); switch (opCode) { case pairwise::Add: @@ -922,9 +1009,9 @@ __global__ static void scatterLockCuda(const int opCode, } __syncthreads(); - for (Nd4jLong i = threadIdx.x; i < arrLenX; i += blockDim.x) { - const auto xOffset = shape::getIndexOffset(i, xShapeInfo); - const auto yOffset = shape::getIndexOffset(i, yShapeInfo); + for (Nd4jLong i = threadIdx.x; i < arrLenX; i += +blockDim.x) { const auto xOffset = shape::getIndexOffset(i, xShapeInfo); const +auto yOffset = shape::getIndexOffset(i, yShapeInfo); switch (opCode) { case pairwise::Add: @@ -959,12 +1046,17 @@ __global__ static void scatterLockCuda(const int opCode, template - void scatter_(sd::LaunchContext *context, pairwise::Ops op, const NDArray& indices, const NDArray& updates, NDArray& output, const bool lock) { + void scatter_(sd::LaunchContext *context, pairwise::Ops op, const +NDArray& indices, const NDArray& updates, NDArray& output, const bool lock) { std::vector dims = {0}; - auto inverted = ShapeUtils::evalDimsToExclude(output.rankOf(), dims); + auto inverted = ShapeUtils::evalDimsToExclude(output.rankOf(), +dims); - auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(output.shapeInfo(), inverted); - auto packY = sd::ConstantTadHelper::getInstance().tadForDimensions(updates.shapeInfo(), inverted); + auto packX = +sd::ConstantTadHelper::getInstance().tadForDimensions(output.shapeInfo(), +inverted); auto packY = +sd::ConstantTadHelper::getInstance().tadForDimensions(updates.shapeInfo(), +inverted); auto psX = packX.specialShapeInfo(); auto psY = packY.special(); @@ -976,17 +1068,23 @@ __global__ static void scatterLockCuda(const int opCode, NDArray::prepareSpecialUse({&output}, {&updates, &indices}); - unsigned int tadLengthX = shape::length(packX.primaryShapeInfo()); - unsigned int tadLengthY = shape::length(packY.primary()); - if (tadLengthX != tadLengthY) - throw std::runtime_error("scatter: Lengths of TADs must be equal"); + unsigned int tadLengthX = +shape::length(packX.primaryShapeInfo()); unsigned int tadLengthY = +shape::length(packY.primary()); if (tadLengthX != tadLengthY) throw +std::runtime_error("scatter: Lengths of TADs must be equal"); - auto blockSize = sd::math::nd4j_max(32, sd::math::nd4j_min(tadLengthX, 1024)); + auto blockSize = sd::math::nd4j_max(32, +sd::math::nd4j_min(tadLengthX, 1024)); if (lock) - scatterCuda<<<512, blockSize, 1024, *context->getCudaStream()>>>(op, indices.lengthOf(), output.specialBuffer(), psX, poX, updates.specialBuffer(), psY, poY, reinterpret_cast(indices.specialBuffer()), tadLengthX, tadLengthY); - else - scatterCuda<<<512, blockSize, 1024, *context->getCudaStream()>>>(op, indices.lengthOf(), output.specialBuffer(), psX, poX, updates.specialBuffer(), psY, poY, reinterpret_cast(indices.specialBuffer()), tadLengthX, tadLengthY); + scatterCuda<<<512, blockSize, 1024, +*context->getCudaStream()>>>(op, indices.lengthOf(), output.specialBuffer(), +psX, poX, updates.specialBuffer(), psY, poY, reinterpret_cast(indices.specialBuffer()), tadLengthX, tadLengthY); else scatterCuda<<<512, blockSize, 1024, *context->getCudaStream()>>>(op, +indices.lengthOf(), output.specialBuffer(), psX, poX, updates.specialBuffer(), +psY, poY, reinterpret_cast(indices.specialBuffer()), tadLengthX, +tadLengthY); NDArray::registerSpecialUse({&output}, {&updates, &indices}); manager.synchronize(); @@ -998,11 +1096,11 @@ __global__ static void scatterLockCuda(const int opCode, // x - indices, y - updates, z - output template __global__ static void scatterNDLockCuda(const int opCode, - const void* vx, const Nd4jLong *xTadShapeInfo, const Nd4jLong *xOffsets, - const void* vy, const Nd4jLong *yTadShapeInfo, const Nd4jLong *yOffsets, - void* vz, const Nd4jLong *zTadShapeInfo, const Nd4jLong *zOffsets, - const Nd4jLong *zShapeInfo, - const Nd4jLong numOfXTads, const Nd4jLong numOfZTads, const Nd4jLong yTadLen) { + const void* vx, const Nd4jLong +*xTadShapeInfo, const Nd4jLong *xOffsets, const void* vy, const Nd4jLong +*yTadShapeInfo, const Nd4jLong *yOffsets, void* vz, const Nd4jLong +*zTadShapeInfo, const Nd4jLong *zOffsets, const Nd4jLong *zShapeInfo, const +Nd4jLong numOfXTads, const Nd4jLong numOfZTads, const Nd4jLong yTadLen) { @@ -1016,17 +1114,23 @@ const int xLastDim = indices.sizeAt(-1); zTadDims[i] = zRank - 1 - j; } - auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(indices.shapeInfo(), {xRank - 1}); - auto packY = sd::ConstantTadHelper::getInstance().tadForDimensions(updates.shapeInfo(), yTadDims); - auto packZ = sd::ConstantTadHelper::getInstance().tadForDimensions(output.shapeInfo(), zTadDims); + auto packX = +sd::ConstantTadHelper::getInstance().tadForDimensions(indices.shapeInfo(), +{xRank - 1}); auto packY = +sd::ConstantTadHelper::getInstance().tadForDimensions(updates.shapeInfo(), +yTadDims); auto packZ = +sd::ConstantTadHelper::getInstance().tadForDimensions(output.shapeInfo(), +zTadDims); const int threadsPerBlock = MAX_NUM_THREADS / 4; const int blocksPerGrid = packZ.numberOfTads(); const int sharedMem = 8 * threadsPerBlock * xLastDim + 128; --------------------------------------------------------------------------- - // zTadLen == yTadLen if numOfZTads > 1, in opposite case z and y are vectors - // numOfXTads == numOfYTads if numOfZTads > 1, in opposite case z and y are vectors + // zTadLen == yTadLen if numOfZTads > 1, in opposite case z and y are +vectors + // numOfXTads == numOfYTads if numOfZTads > 1, in opposite case z and y are +vectors const auto x = reinterpret_cast(vx); const auto y = reinterpret_cast(vy); @@ -1049,11 +1153,14 @@ const int xLastDim = indices.sizeAt(-1); const X* xTad = x + xOffsets[i]; for (uint k = 0; k < xLastDim; ++k) - zTadCoordsPerThread[k] = xTad[shape::getIndexOffset(k, xTadShapeInfo)]; + zTadCoordsPerThread[k] = xTad[shape::getIndexOffset(k, +xTadShapeInfo)]; - const auto zTadIndex = shape::coords2index(xLastDim, zShapeInfo + 1, zTadCoordsPerThread); + const auto zTadIndex = shape::coords2index(xLastDim, zShapeInfo + 1, +zTadCoordsPerThread); - const bool isOwner = zTadIndex < gridDim.x ? blockIdx.x == zTadIndex : blockIdx.x == zTadIndex % gridDim.x; + const bool isOwner = zTadIndex < gridDim.x ? blockIdx.x == zTadIndex : +blockIdx.x == zTadIndex % gridDim.x; if(!isOwner) continue; @@ -1063,8 +1170,8 @@ const int xLastDim = indices.sizeAt(-1); if(threadIdx.x != 0) continue; - const auto yOffset = shape::getIndexOffset(i, yTadShapeInfo); - const auto zOffset = shape::getIndexOffset(zTadIndex, zTadShapeInfo); + const auto yOffset = shape::getIndexOffset(i, yTadShapeInfo); const +auto zOffset = shape::getIndexOffset(zTadIndex, zTadShapeInfo); switch (opCode) { case pairwise::Add: @@ -1130,13 +1237,9 @@ const int xLastDim = indices.sizeAt(-1); zTad[zOffset] = yTad[yOffset]; break; case pairwise::MaxPairwise: - if(zTad[zOffset] < yTad[yOffset]) zTad[zOffset] = yTad[yOffset]; - break; - case pairwise::MinPairwise: - if(zTad[zOffset] > yTad[yOffset]) zTad[zOffset] = yTad[yOffset]; - break; - default: - continue; + if(zTad[zOffset] < yTad[yOffset]) zTad[zOffset] = +yTad[yOffset]; break; case pairwise::MinPairwise: if(zTad[zOffset] > +yTad[yOffset]) zTad[zOffset] = yTad[yOffset]; break; default: continue; } } } @@ -1144,24 +1247,33 @@ const int xLastDim = indices.sizeAt(-1); } */ - // PointersManager manager(&context, "NativeOps::concat"); - // PointersManager::printDevContentOnDev(vx, 2); - // PointersManager::printDevContentOnDev(xShapeInfo, 8); - // PointersManager::printDevContentOnDev(vy, 8); - // PointersManager::printDevContentOnDev(yShapeInfo, 8); - // PointersManager::printDevContentOnDev(zShapeInfo, 8); - - // manager.printDevContentOnHost(indices.specialBuffer(), indices.lengthOf()); - // manager.printDevContentOnHost(indices.special(), shape::shapeInfoLength(indices.rankOf())); - // manager.printDevContentOnHost(updates.specialBuffer(), updates.lengthOf()); - // manager.printDevContentOnHost(updates.special(), shape::shapeInfoLength(updates.rankOf())); - // manager.printDevContentOnHost(output.special(), shape::shapeInfoLength(output.rankOf())); - // printf("!!!!!!!\n"); - // manager.printDevContentOnHost(packX.special(), 2*shape::rank(packX.primary()) + 4); - // manager.printDevContentOnHost(packX.special(), packX.numberOfTads()); - // manager.printDevContentOnHost(packY.special(), 2*shape::rank(packY.primary()) + 4); - // manager.printDevContentOnHost(packY.special(), packY.numberOfTads()); - // manager.printDevContentOnHost(packZ.special(), 2*shape::rank(packZ.primary()) + 4); - // manager.printDevContentOnHost(packZ.special(), packZ.numberOfTads()); - // printf("dddddddd\n"); - // shape::printShapeInfoLinear(packY.primary()); \ No newline at end of file +// PointersManager manager(&context, "NativeOps::concat"); +// PointersManager::printDevContentOnDev(vx, 2); +// PointersManager::printDevContentOnDev(xShapeInfo, 8); +// PointersManager::printDevContentOnDev(vy, 8); +// PointersManager::printDevContentOnDev(yShapeInfo, 8); +// PointersManager::printDevContentOnDev(zShapeInfo, 8); + +// manager.printDevContentOnHost(indices.specialBuffer(), +// indices.lengthOf()); +// manager.printDevContentOnHost(indices.special(), +// shape::shapeInfoLength(indices.rankOf())); +// manager.printDevContentOnHost(updates.specialBuffer(), +// updates.lengthOf()); +// manager.printDevContentOnHost(updates.special(), +// shape::shapeInfoLength(updates.rankOf())); +// manager.printDevContentOnHost(output.special(), +// shape::shapeInfoLength(output.rankOf())); printf("!!!!!!!\n"); +// manager.printDevContentOnHost(packX.special(), +// 2*shape::rank(packX.primary()) + 4); +// manager.printDevContentOnHost(packX.special(), +// packX.numberOfTads()); +// manager.printDevContentOnHost(packY.special(), +// 2*shape::rank(packY.primary()) + 4); +// manager.printDevContentOnHost(packY.special(), +// packY.numberOfTads()); +// manager.printDevContentOnHost(packZ.special(), +// 2*shape::rank(packZ.primary()) + 4); +// manager.printDevContentOnHost(packZ.special(), +// packZ.numberOfTads()); printf("dddddddd\n"); +// shape::printShapeInfoLinear(packY.primary()); \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/scatter_simple.cu b/libnd4j/include/ops/declarable/helpers/cuda/scatter_simple.cu index 3b422a5c2ada..9e0088775819 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/scatter_simple.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/scatter_simple.cu @@ -18,62 +18,75 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 20.04.2018 // - -#include -#include -#include -#include #include -#include +#include #include -#include #include +#include +#include +#include +#include -namespace sd { - namespace ops { - namespace helpers { - template - static _CUDA_G void scatterSimpleKernel(void *vx, const Nd4jLong *xTadShape, const Nd4jLong *xTadOffsets, Nd4jLong xLength, Nd4jLong numTads, const void *vi, const Nd4jLong *iShapeInfo, Nd4jLong iLength, const void *vu, const Nd4jLong *uShapeInfo, Nd4jLong uLength) { - auto u = reinterpret_cast(vu); - auto indices = reinterpret_cast(vi); - - auto tid = threadIdx.x + blockIdx.x * blockDim.x; - for (int i = tid; i < iLength; i += blockDim.x * gridDim.x) { - auto x = reinterpret_cast(vx) + xTadOffsets[i]; - auto idx = indices[shape::getIndexOffset(i, iShapeInfo)]; - - x[shape::getIndexOffset(idx, xTadShape)] = u[shape::getIndexOffset(i, uShapeInfo)]; - } - } - - - template - void scatterSimple_(sd::LaunchContext * context, const int opId, NDArray& input, const NDArray& updates, const NDArray& indices, const std::vector& dimensions) { - - auto dims = ShapeUtils::evalDimsToExclude(input.rankOf(), dimensions); - auto packX = ConstantTadHelper::getInstance().tadForDimensions(input.shapeInfo(), dims); - - auto xLength = shape::length(packX.primaryShapeInfo()); - auto iLength = indices.lengthOf(); - auto uLength = updates.lengthOf(); - - scatterSimpleKernel<<<256, 256, 1024, *context->getCudaStream()>>>(input.specialBuffer(), packX.platformShapeInfo(), packX.platformOffsets(), xLength, packX.numberOfTads(), indices.specialBuffer(), indices.specialShapeInfo(), iLength, updates.specialBuffer(), updates.specialShapeInfo(), uLength); - } - - - void scatterSimple(sd::LaunchContext * context, const int opId, NDArray& input, const NDArray& updates, const NDArray& indices, const std::vector& dimensions) { - auto xType = input.dataType(); - auto yType = indices.dataType(); - - if (opId != 6) - throw std::runtime_error("scatterSimple: only copy op is supported"); - - NDArray::prepareSpecialUse({&input}, {&updates, &indices}); - - BUILD_DOUBLE_SELECTOR(xType, yType, scatterSimple_, (context, opId, input, updates, indices, dimensions), LIBND4J_TYPES, INDEXING_TYPES); +#include - NDArray::registerSpecialUse({&input}, {&updates, &indices}); - } - } - } -} \ No newline at end of file +namespace sd { +namespace ops { +namespace helpers { +template +static _CUDA_G void scatterSimpleKernel( + void* vx, const Nd4jLong* xTadShape, const Nd4jLong* xTadOffsets, + Nd4jLong xLength, Nd4jLong numTads, const void* vi, + const Nd4jLong* iShapeInfo, Nd4jLong iLength, const void* vu, + const Nd4jLong* uShapeInfo, Nd4jLong uLength) { + auto u = reinterpret_cast(vu); + auto indices = reinterpret_cast(vi); + + auto tid = threadIdx.x + blockIdx.x * blockDim.x; + for (int i = tid; i < iLength; i += blockDim.x * gridDim.x) { + auto x = reinterpret_cast(vx) + xTadOffsets[i]; + auto idx = indices[shape::getIndexOffset(i, iShapeInfo)]; + + x[shape::getIndexOffset(idx, xTadShape)] = + u[shape::getIndexOffset(i, uShapeInfo)]; + } +} + +template +void scatterSimple_(sd::LaunchContext* context, const int opId, NDArray& input, + const NDArray& updates, const NDArray& indices, + const std::vector& dimensions) { + auto dims = ShapeUtils::evalDimsToExclude(input.rankOf(), dimensions); + auto packX = ConstantTadHelper::getInstance().tadForDimensions( + input.shapeInfo(), dims); + + auto xLength = shape::length(packX.primaryShapeInfo()); + auto iLength = indices.lengthOf(); + auto uLength = updates.lengthOf(); + + scatterSimpleKernel<<<256, 256, 1024, *context->getCudaStream()>>>( + input.specialBuffer(), packX.platformShapeInfo(), packX.platformOffsets(), + xLength, packX.numberOfTads(), indices.specialBuffer(), + indices.specialShapeInfo(), iLength, updates.specialBuffer(), + updates.specialShapeInfo(), uLength); +} + +void scatterSimple(sd::LaunchContext* context, const int opId, NDArray& input, + const NDArray& updates, const NDArray& indices, + const std::vector& dimensions) { + auto xType = input.dataType(); + auto yType = indices.dataType(); + + if (opId != 6) + throw std::runtime_error("scatterSimple: only copy op is supported"); + + NDArray::prepareSpecialUse({&input}, {&updates, &indices}); + + BUILD_DOUBLE_SELECTOR(xType, yType, scatterSimple_, + (context, opId, input, updates, indices, dimensions), + LIBND4J_TYPES, INDEXING_TYPES); + + NDArray::registerSpecialUse({&input}, {&updates, &indices}); +} +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/scatter_update.cu b/libnd4j/include/ops/declarable/helpers/cuda/scatter_update.cu index 3a3bfef12d93..ac0de2f72d21 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/scatter_update.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/scatter_update.cu @@ -18,115 +18,124 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 20.04.2018 // - -#include -#include -#include -#include #include -#include +#include #include -#include #include +#include +#include +#include +#include -namespace sd { - namespace ops { - namespace helpers { - /////////////////////////////////////////////////////////////////// - template - __global__ static void scatterUpdateCuda(const int opCode, const int numOfInd, - void* vx, const Nd4jLong *xShapeInfo, const Nd4jLong *xOffsets, - void* vy, const Nd4jLong *yShapeInfo, const Nd4jLong *yOffsets, - const int* indexes) { - - __shared__ T *x, *y; - __shared__ Nd4jLong arrLenX, arrLenY; - - for (int e = 0; e < numOfInd; e++ ) { - - const auto xIndex = indexes[e]; - const bool isOwner = xIndex < gridDim.x ? blockIdx.x == xIndex : blockIdx.x == xIndex % gridDim.x; - - if (!isOwner) - continue; - - if (threadIdx.x == 0) { - x = reinterpret_cast(vx) + xOffsets[xIndex]; - y = reinterpret_cast(vy) + yOffsets[e]; - arrLenX = shape::length(xShapeInfo); - arrLenY = shape::length(yShapeInfo); - } - __syncthreads(); - - if (arrLenX != arrLenY) - return; - - for (Nd4jLong i = threadIdx.x; i < arrLenX; i += blockDim.x) { - - const auto xOffset = shape::getIndexOffset(i, xShapeInfo); - const auto yOffset = shape::getIndexOffset(i, yShapeInfo); - - switch (opCode) { - case 0: - x[xOffset] += y[yOffset]; - break; - case 1: - x[xOffset] -= y[yOffset]; - break; - case 2: - x[xOffset] *= y[yOffset]; - break; - case 3: - x[xOffset] /= y[yOffset]; - break; - case 4: - x[xOffset] = y[yOffset] - x[xOffset]; - break; - case 5: - x[xOffset] = y[yOffset] / x[xOffset]; - break; - case 6: - x[xOffset] = y[yOffset]; - break; - default: - continue; - } - } - __syncthreads(); - } - } - - template - __host__ static void scatterUpdateCudaLauncher(const cudaStream_t* stream, const int opCode, const int numOfInd, void* vx, const Nd4jLong *xShapeInfo, const Nd4jLong *xOffsets, void* vy, const Nd4jLong *yShapeInfo, const Nd4jLong *yOffsets, const int* indexes) { - - scatterUpdateCuda<<<512, 256, MAX_NUM_THREADS, *stream>>>(opCode, numOfInd, vx, xShapeInfo, xOffsets, vy, yShapeInfo, yOffsets, indexes); - } +#include +namespace sd { +namespace ops { +namespace helpers { +/////////////////////////////////////////////////////////////////// +template +__global__ static void scatterUpdateCuda(const int opCode, const int numOfInd, + void* vx, const Nd4jLong* xShapeInfo, + const Nd4jLong* xOffsets, void* vy, + const Nd4jLong* yShapeInfo, + const Nd4jLong* yOffsets, + const int* indexes) { + __shared__ T *x, *y; + __shared__ Nd4jLong arrLenX, arrLenY; + + for (int e = 0; e < numOfInd; e++) { + const auto xIndex = indexes[e]; + const bool isOwner = xIndex < gridDim.x ? blockIdx.x == xIndex + : blockIdx.x == xIndex % gridDim.x; + + if (!isOwner) continue; + + if (threadIdx.x == 0) { + x = reinterpret_cast(vx) + xOffsets[xIndex]; + y = reinterpret_cast(vy) + yOffsets[e]; + arrLenX = shape::length(xShapeInfo); + arrLenY = shape::length(yShapeInfo); + } + __syncthreads(); + + if (arrLenX != arrLenY) return; + + for (Nd4jLong i = threadIdx.x; i < arrLenX; i += blockDim.x) { + const auto xOffset = shape::getIndexOffset(i, xShapeInfo); + const auto yOffset = shape::getIndexOffset(i, yShapeInfo); + + switch (opCode) { + case 0: + x[xOffset] += y[yOffset]; + break; + case 1: + x[xOffset] -= y[yOffset]; + break; + case 2: + x[xOffset] *= y[yOffset]; + break; + case 3: + x[xOffset] /= y[yOffset]; + break; + case 4: + x[xOffset] = y[yOffset] - x[xOffset]; + break; + case 5: + x[xOffset] = y[yOffset] / x[xOffset]; + break; + case 6: + x[xOffset] = y[yOffset]; + break; + default: + continue; + } + } + __syncthreads(); + } +} + +template +__host__ static void scatterUpdateCudaLauncher( + const cudaStream_t* stream, const int opCode, const int numOfInd, void* vx, + const Nd4jLong* xShapeInfo, const Nd4jLong* xOffsets, void* vy, + const Nd4jLong* yShapeInfo, const Nd4jLong* yOffsets, const int* indexes) { + scatterUpdateCuda<<<512, 256, MAX_NUM_THREADS, *stream>>>( + opCode, numOfInd, vx, xShapeInfo, xOffsets, vy, yShapeInfo, yOffsets, + indexes); +} ////////////////////////////////////////////////////////////////////////// - void scatterUpdate(sd::LaunchContext* context, NDArray& input, NDArray& updates, const std::vector* intArgs) { - - const int opCode = (*intArgs)[0]; - const int numOfDims = (*intArgs)[1]; - const int numOfInd = (*intArgs)[2 + numOfDims]; - - std::vector tadDimensions(numOfDims); - for (int e = 2; e < 2 + numOfDims; e++) - tadDimensions[e-2] = (*intArgs)[e]; - - auto packX = ConstantTadHelper::getInstance().tadForDimensions(input.shapeInfo(), tadDimensions); - auto packY = ConstantTadHelper::getInstance().tadForDimensions(updates.shapeInfo(), tadDimensions); - - NDArray indices(const_cast(intArgs->data()) + numOfDims + 3, 'c', {numOfInd}, sd::DataType::INT32, context); - - PointersManager manager(context, "scatterUpdate"); - - NDArray::prepareSpecialUse({&input}, {&input, &updates, &indices}); - BUILD_SINGLE_SELECTOR(input.dataType(), scatterUpdateCudaLauncher, (context->getCudaStream(), opCode, numOfInd, input.specialBuffer(), packX.platformShapeInfo(), packX.platformOffsets(), updates.specialBuffer(), packY.platformShapeInfo(), packY.platformOffsets(), reinterpret_cast(indices.specialBuffer())), LIBND4J_TYPES); - NDArray::registerSpecialUse({&input}, {&input, &updates, &indices}); - - manager.synchronize(); - } - } - } -} \ No newline at end of file +void scatterUpdate(sd::LaunchContext* context, NDArray& input, NDArray& updates, + const std::vector* intArgs) { + const int opCode = (*intArgs)[0]; + const int numOfDims = (*intArgs)[1]; + const int numOfInd = (*intArgs)[2 + numOfDims]; + + std::vector tadDimensions(numOfDims); + for (int e = 2; e < 2 + numOfDims; e++) tadDimensions[e - 2] = (*intArgs)[e]; + + auto packX = ConstantTadHelper::getInstance().tadForDimensions( + input.shapeInfo(), tadDimensions); + auto packY = ConstantTadHelper::getInstance().tadForDimensions( + updates.shapeInfo(), tadDimensions); + + NDArray indices(const_cast(intArgs->data()) + numOfDims + 3, 'c', + {numOfInd}, sd::DataType::INT32, context); + + PointersManager manager(context, "scatterUpdate"); + + NDArray::prepareSpecialUse({&input}, {&input, &updates, &indices}); + BUILD_SINGLE_SELECTOR(input.dataType(), scatterUpdateCudaLauncher, + (context->getCudaStream(), opCode, numOfInd, + input.specialBuffer(), packX.platformShapeInfo(), + packX.platformOffsets(), updates.specialBuffer(), + packY.platformShapeInfo(), packY.platformOffsets(), + reinterpret_cast(indices.specialBuffer())), + LIBND4J_TYPES); + NDArray::registerSpecialUse({&input}, {&input, &updates, &indices}); + + manager.synchronize(); +} +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/segment.cu b/libnd4j/include/ops/declarable/helpers/cuda/segment.cu index 60d00fb60f57..f077d8df7ce1 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/segment.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/segment.cu @@ -18,117 +18,149 @@ // @author GS // -#include -#include #include -#include -#include #include -#include #include +#include +#include +#include +#include +#include namespace sd { namespace ops { namespace helpers { - // -------------------------------------------------------------------------------------------------------------- // - // Sorted segments ops implementations - - template - static bool segmentIndicesValidate_(NDArray* indices, NDArray& aexpected, NDArray& aoutput) { - return true; - } - - bool segmentIndicesValidate(sd::LaunchContext* context , NDArray* indices, NDArray& expected, NDArray& output) { - BUILD_DOUBLE_SELECTOR(output.dataType(), indices->dataType(), return segmentIndicesValidate_, (indices, expected, output), NUMERIC_TYPES, INDEXING_TYPES); - } - - // -------------------------------------------------------------------------------------------------------------- // - // Unsorted segment ops functors implementation - // -------------------------------------------------------------------------------------------------------------- // - template - static __global__ void unsortedSegmentIndexValidateKernel(const I* indices, const Nd4jLong* indicesShape, I expected, I* found) { - __shared__ bool onlyTrue; - __shared__ Nd4jLong len; - - if (threadIdx.x == 0) { - onlyTrue = true; - len = shape::length(indicesShape); - } - __syncthreads(); - auto start = threadIdx.x + blockIdx.x * blockDim.x; - auto step = gridDim.x * blockDim.x; - for (int e = start; e < len && onlyTrue; e += step) { - sd::math::atomics::nd4j_atomicMax(found, indices[e]); - if (expected < *found) - onlyTrue = false; - } - } - - template - static bool unsortedSegmentIndicesValidate_(sd::LaunchContext* context , NDArray* indices, Nd4jLong expected, Nd4jLong& output) { - output = expected; - I found = output; - I exp = expected; - auto stream = context->getCudaStream(); - I* devFound; - cudaMalloc(&devFound, sizeof(I)); - cudaMemcpy(devFound, &found, sizeof(I), cudaMemcpyHostToDevice); - unsortedSegmentIndexValidateKernel<<<1, indices->lengthOf(), 128, *stream>>>(reinterpret_cast(indices->specialBuffer()), indices->specialShapeInfo(), exp, devFound); - cudaMemcpy(&found, devFound, sizeof(I), cudaMemcpyDeviceToHost); - cudaFree(devFound); - output = found; - return expected == output; - } - - bool unsortedSegmentIndicesValidate(sd::LaunchContext* context , NDArray* indices, Nd4jLong expected, Nd4jLong& output) { - BUILD_SINGLE_SELECTOR(indices->dataType(), return unsortedSegmentIndicesValidate_, (context, indices, expected, output), INDEXING_TYPES); - } - - // -------------------------------------------------------------------------------------------------------------- // - - // -------------------------------------------------------------------------------------------------------------- // - // fill up segments starts and ends - splitted ordered case - template - static __global__ void fillUpSegmentsKernel(const void* indices, const Nd4jLong* indexShape, int numClasses, int* classesRangesStart, int* classesRangesLenghts) { - __shared__ const I* idxBuf; - __shared__ Nd4jLong idxLen; - __shared__ int* result; - if (threadIdx.x == 0) { - idxBuf = reinterpret_cast(indices); - idxLen = shape::length(indexShape); - } - __syncthreads(); - - auto tid = threadIdx.x + blockDim.x * blockIdx.x; - auto step = blockDim.x * gridDim.x; - - for (auto j = tid; j < idxLen; j += step) { - auto pos = idxBuf[j]; - sd::math::atomics::nd4j_atomicMin(&classesRangesStart[pos], (int)j); - sd::math::atomics::nd4j_atomicAdd(&classesRangesLenghts[pos], 1); - } - } - - // -------------------------------------------------------------------------------------------------------------- // - - template - static void fillUpSegments_(NDArray* indices, Nd4jLong numClasses, NDArray& classesRangesBegs, NDArray& classesRangesLens) { - dim3 dims(numClasses, indices->lengthOf(), numClasses * 32 + 32); - int* begins = reinterpret_cast(classesRangesBegs.specialBuffer()); - int* lengths = reinterpret_cast(classesRangesLens.specialBuffer()); - auto stream = classesRangesBegs.getContext()->getCudaStream(); - fillUpSegmentsKernel<<>>(indices->specialBuffer(), indices->specialShapeInfo(), numClasses, begins, lengths); - } - // -------------------------------------------------------------------------------------------------------------- // - - void fillUpSegments(NDArray* indices, Nd4jLong numClasses, NDArray& classesRangesBegs, NDArray& classesRangesLens) { - BUILD_SINGLE_SELECTOR(indices->dataType(), fillUpSegments_, (indices, numClasses, classesRangesBegs, classesRangesLens), INDEXING_TYPES); - } - // -------------------------------------------------------------------------------------------------------------- // +// -------------------------------------------------------------------------------------------------------------- +// // Sorted segments ops implementations +template +static bool segmentIndicesValidate_(NDArray* indices, NDArray& aexpected, + NDArray& aoutput) { + return true; } + +bool segmentIndicesValidate(sd::LaunchContext* context, NDArray* indices, + NDArray& expected, NDArray& output) { + BUILD_DOUBLE_SELECTOR( + output.dataType(), indices->dataType(), return segmentIndicesValidate_, + (indices, expected, output), NUMERIC_TYPES, INDEXING_TYPES); +} + +// -------------------------------------------------------------------------------------------------------------- +// // Unsorted segment ops functors implementation +// -------------------------------------------------------------------------------------------------------------- +// // +template +static __global__ void unsortedSegmentIndexValidateKernel( + const I* indices, const Nd4jLong* indicesShape, I expected, I* found) { + __shared__ bool onlyTrue; + __shared__ Nd4jLong len; + + if (threadIdx.x == 0) { + onlyTrue = true; + len = shape::length(indicesShape); + } + __syncthreads(); + auto start = threadIdx.x + blockIdx.x * blockDim.x; + auto step = gridDim.x * blockDim.x; + for (int e = start; e < len && onlyTrue; e += step) { + sd::math::atomics::nd4j_atomicMax(found, indices[e]); + if (expected < *found) onlyTrue = false; + } +} + +template +static bool unsortedSegmentIndicesValidate_(sd::LaunchContext* context, + NDArray* indices, Nd4jLong expected, + Nd4jLong& output) { + output = expected; + I found = output; + I exp = expected; + auto stream = context->getCudaStream(); + I* devFound; + cudaMalloc(&devFound, sizeof(I)); + cudaMemcpy(devFound, &found, sizeof(I), cudaMemcpyHostToDevice); + unsortedSegmentIndexValidateKernel + <<<1, indices->lengthOf(), 128, *stream>>>( + reinterpret_cast(indices->specialBuffer()), + indices->specialShapeInfo(), exp, devFound); + cudaMemcpy(&found, devFound, sizeof(I), cudaMemcpyDeviceToHost); + cudaFree(devFound); + output = found; + return expected == output; +} + +bool unsortedSegmentIndicesValidate(sd::LaunchContext* context, + NDArray* indices, Nd4jLong expected, + Nd4jLong& output) { + BUILD_SINGLE_SELECTOR(indices->dataType(), + return unsortedSegmentIndicesValidate_, + (context, indices, expected, output), INDEXING_TYPES); +} + +// -------------------------------------------------------------------------------------------------------------- +// // + +// -------------------------------------------------------------------------------------------------------------- +// // fill up segments starts and ends - splitted ordered case +template +static __global__ void fillUpSegmentsKernel(const void* indices, + const Nd4jLong* indexShape, + int numClasses, + int* classesRangesStart, + int* classesRangesLenghts) { + __shared__ const I* idxBuf; + __shared__ Nd4jLong idxLen; + __shared__ int* result; + if (threadIdx.x == 0) { + idxBuf = reinterpret_cast(indices); + idxLen = shape::length(indexShape); + } + __syncthreads(); + + auto tid = threadIdx.x + blockDim.x * blockIdx.x; + auto step = blockDim.x * gridDim.x; + + for (auto j = tid; j < idxLen; j += step) { + auto pos = idxBuf[j]; + sd::math::atomics::nd4j_atomicMin(&classesRangesStart[pos], (int)j); + sd::math::atomics::nd4j_atomicAdd(&classesRangesLenghts[pos], 1); + } +} + +// -------------------------------------------------------------------------------------------------------------- +// // + +template +static void fillUpSegments_(NDArray* indices, Nd4jLong numClasses, + NDArray& classesRangesBegs, + NDArray& classesRangesLens) { + dim3 dims(numClasses, indices->lengthOf(), numClasses * 32 + 32); + int* begins = reinterpret_cast(classesRangesBegs.specialBuffer()); + int* lengths = reinterpret_cast(classesRangesLens.specialBuffer()); + auto stream = classesRangesBegs.getContext()->getCudaStream(); + fillUpSegmentsKernel<<>>( + indices->specialBuffer(), indices->specialShapeInfo(), numClasses, begins, + lengths); } +// -------------------------------------------------------------------------------------------------------------- +// // + +void fillUpSegments(NDArray* indices, Nd4jLong numClasses, + NDArray& classesRangesBegs, NDArray& classesRangesLens) { + BUILD_SINGLE_SELECTOR( + indices->dataType(), fillUpSegments_, + (indices, numClasses, classesRangesBegs, classesRangesLens), + INDEXING_TYPES); } -// -------------------------------------------------------------------------------------------------------------- // -// -------------------------------------------------------------------------------------------------------------- // +// -------------------------------------------------------------------------------------------------------------- +// // + +} // namespace helpers +} // namespace ops +} // namespace sd +// -------------------------------------------------------------------------------------------------------------- +// // +// -------------------------------------------------------------------------------------------------------------- +// // diff --git a/libnd4j/include/ops/declarable/helpers/cuda/segment_max.cu b/libnd4j/include/ops/declarable/helpers/cuda/segment_max.cu index d623c873494e..6300f01bfa42 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/segment_max.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/segment_max.cu @@ -18,413 +18,509 @@ // @author GS // -#include -#include - #include -#include -#include #include -#include #include +#include +#include +#include +#include +#include namespace sd { - namespace ops { - namespace helpers { - - // -------------------------------------------------------------------------------------------------------------- // - // Segment ops linear kernels - // -------------------------------------------------------------------------------------------------------------- // - - template - static __global__ void - segmentMaxLinearKernel(void *input, Nd4jLong const* inputShape, int *starts, int *lengths, Nd4jLong numOfClasses, - void *output, Nd4jLong const* outputShape) { - __shared__ T *val; - __shared__ Nd4jLong xLen, zLen, zIndex; - __shared__ T *x; - __shared__ T *z; - __shared__ int threadsPerSegment, start, finish; - - auto segment = blockIdx.x; - if (threadIdx.x == 0) { -// threadsPerSegment = (gridDim.x + numOfClasses - 1) / numOfClasses; -// segment = blockIdx.x / threadsPerSegment; - x = reinterpret_cast(input); - z = reinterpret_cast(output); - extern __shared__ unsigned char shmem[]; - val = reinterpret_cast(shmem); - xLen = shape::length(inputShape); - zLen = shape::length(outputShape); - - if (segment < numOfClasses) { - zIndex = shape::getIndexOffset(segment, outputShape); - start = starts[segment]; - finish = start + lengths[segment]; - z[zIndex] = x[shape::getIndexOffset(start, inputShape)]; - val[segment] = z[zIndex]; - } - - } - __syncthreads(); - - for (auto e = start + threadIdx.x + 1; e < finish; e += blockDim.x) { - auto xIndex = shape::getIndexOffset(e, inputShape); - sd::math::atomics::nd4j_atomicMax(&z[zIndex], x[xIndex]); - } - } - // -------------------------------------------------------------------------------------------------------------- // - - template - static __global__ void - unsortedSegmentMaxLinearKernel(void *input, Nd4jLong const* inputShape, void *indices, Nd4jLong const* indicesShape, - int *starts, int *lengths, Nd4jLong numOfClasses, void *output, - Nd4jLong const* outputShape) { - __shared__ T *val; - __shared__ Nd4jLong xLen, zLen, zIndex; - __shared__ T *x; - __shared__ T *z; - __shared__ I *y; //int threadsPerSegment, start, finish; - auto segment = blockIdx.x; - - if (threadIdx.x == 0) { - x = reinterpret_cast(input); - z = reinterpret_cast(output); - y = reinterpret_cast(indices); - xLen = shape::length(inputShape); - zLen = shape::length(outputShape); - - zIndex = shape::getIndexOffset(segment, outputShape); - //start = starts[segment]; - //finish = start + lengths[segment]; - if (lengths[segment] > 0) - z[zIndex] = x[shape::getIndexOffset(starts[segment], inputShape)]; - else - z[zIndex] = -DataTypeUtils::max(); - } - __syncthreads(); - if (lengths[segment] > 0) - for (auto e = threadIdx.x + 1; e < xLen; e += blockDim.x) { - auto xIndex = shape::getIndexOffset(e, inputShape); - auto yIndex = shape::getIndexOffset(e, indicesShape); - if (y[yIndex] == segment) { - sd::math::atomics::nd4j_atomicMax(&z[zIndex], x[xIndex]); - } - } - } - // -------------------------------------------------------------------------------------------------------------- // - template - static __global__ void segmentMaxTadKernel(void* inputBuf, Nd4jLong const* inputShape, Nd4jLong const* inputTads, - Nd4jLong const* inputTadOffsets, I* indices, int* starts, int* lengths, Nd4jLong numOfClasses, void* outputBuf, - Nd4jLong const* outputShape, Nd4jLong const* outputTads, Nd4jLong const* outputTadOffsets, T filler = 0) { - - __shared__ T* val; - __shared__ Nd4jLong len, zIndex, total; - __shared__ T* z; - __shared__ int start, finish; - __shared__ I segment; - - if (threadIdx.x == 0) { - segment = indices[blockIdx.x]; // / threadsPerSegment; - z = reinterpret_cast(outputBuf) + outputTadOffsets[segment]; - len = shape::length(inputTads); - - start = starts[segment]; - finish = start + lengths[segment]; - total = shape::sizeAt(inputShape, 0); - } - __syncthreads(); - - auto idx = blockIdx.x; - if (idx <= total) { - auto x = reinterpret_cast(inputBuf) + inputTadOffsets[idx]; - if (blockIdx.x == start) { - for (auto e = threadIdx.x; e < len; e += blockDim.x) { - auto xIndex = shape::getIndexOffset(e, inputTads); - auto zIndex = shape::getIndexOffset(e, outputTads); - sd::math::atomics::nd4j_atomicMax(&z[zIndex], x[xIndex]); - //z[zIndex] = x[xIndex]; - } - } - else { - for (auto e = threadIdx.x; e < len; e += blockDim.x) { - auto xIndex = shape::getIndexOffset(e, inputTads); - auto zIndex = shape::getIndexOffset(e, outputTads); - if (lengths[segment]) - sd::math::atomics::nd4j_atomicMax(&z[zIndex], x[xIndex]); - } - } - } - } - // -------------------------------------------------------------------------------------------------------------- // - - template - static void segmentMaxFunctor_(LaunchContext* context, NDArray* input, NDArray* indices, NDArray* output) { - //int numClasses = output->sizeAt(0); - // if input is a vector: (as if in doc sample) - //Nd4jLong idx = indices->e(0); - output->assign(-DataTypeUtils::infOrMax()); - auto stream = context->getCudaStream(); - indices->syncToHost(); - Nd4jLong numOfClasses = indices->e(indices->lengthOf() - 1) + 1; - NDArray classesRangesLens = NDArrayFactory::create('c', {numOfClasses}, context); - NDArray classesRangesBegs = NDArrayFactory::create('c', {numOfClasses}, context); - - classesRangesBegs.assign(indices->lengthOf()); - classesRangesLens.assign(0); - dim3 dims(256, 512, 256); - int* begins = reinterpret_cast(classesRangesBegs.specialBuffer()); - int* lengths = reinterpret_cast(classesRangesLens.specialBuffer()); - fillUpSegments(indices, numOfClasses, classesRangesBegs, classesRangesLens); - - NDArray::prepareSpecialUse({output}, {input, indices, &classesRangesBegs, &classesRangesLens}); - - if (input->isVector()) { - - segmentMaxLinearKernel<<lengthOf(), numOfClasses * 32 + 32, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), begins, lengths, numOfClasses, output->specialBuffer(), output->specialShapeInfo()); - } - else { - std::vector dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(input->shapeInfo(), dimensions); - auto packZ = sd::ConstantTadHelper::getInstance().tadForDimensions(output->shapeInfo(), dimensions); - auto inputTads = packX.specialShapeInfo(); - auto inputTadOffsets = packX.specialOffsets(); - auto outputTads = packZ.specialShapeInfo(); - auto outputTadOffsets = packZ.specialOffsets(); - segmentMaxTadKernel<<>>(input->specialBuffer(), input->specialShapeInfo(), inputTads, inputTadOffsets, reinterpret_cast(indices->specialBuffer()), begins, lengths, numOfClasses, output->specialBuffer(), output->specialShapeInfo(), outputTads, outputTadOffsets); - } - NDArray::registerSpecialUse({output}, {input, indices, &classesRangesBegs, &classesRangesLens}); - } - // -------------------------------------------------------------------------------------------------------------- // - void segmentMaxFunctor(sd::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* output) { - NDArray::prepareSpecialUse({output}, {input, indices}); - BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), segmentMaxFunctor_, (context, input, indices, output), NUMERIC_TYPES, INDEXING_TYPES); - NDArray::registerSpecialUse({output}, {input, indices}); - } - // -------------------------------------------------------------------------------------------------------------- // - - template - static void unsortedSegmentMaxFunctor_(sd::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { - auto stream = context->getCudaStream(); -// NDArray classes = NDArrayFactory::create('c', {numOfClasses, 2}); - output->assign(DataTypeUtils::infOrMax()); - - NDArray classesRangesBegs = NDArrayFactory::create('c', {numOfClasses}, context); - NDArray classesRangesLens = NDArrayFactory::create('c', {numOfClasses}, context); -// NDArray row = NDArrayFactory::create('c', {1, 2}, {(int)indices->lengthOf(), (int)0}); -// classes.applyTrueBroadcast(sd::BroadcastOpsTuple::Assign(), row, classes); - classesRangesBegs.assign(indices->lengthOf()); - classesRangesLens.assign(0); - dim3 dims(numOfClasses, indices->lengthOf(), numOfClasses * 32 + 32); -// int* classesBuf = reinterpret_cast(classes.specialBuffer()); - fillUpSegments(indices, numOfClasses, classesRangesBegs, classesRangesLens); - int* begins = reinterpret_cast(classesRangesBegs.specialBuffer()); - int* lengths = reinterpret_cast(classesRangesLens.specialBuffer()); - - if (input->isVector()) { - unsortedSegmentMaxLinearKernel<<>>(input->specialBuffer(), input->specialShapeInfo(), indices->specialBuffer(), indices->specialShapeInfo(), begins, lengths, numOfClasses, output->specialBuffer(), output->specialShapeInfo()); - } - else { - std::vector dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(input->shapeInfo(), dimensions); - auto packZ = sd::ConstantTadHelper::getInstance().tadForDimensions(output->shapeInfo(), dimensions); - auto inputTads = packX.specialShapeInfo(); - auto inputTadOffsets = packX.specialOffsets(); - auto outputTads = packZ.specialShapeInfo(); - auto outputTadOffsets = packZ.specialOffsets(); - dims.x = input->sizeAt(0); - output->assign(-DataTypeUtils::max()); - segmentMaxTadKernel<<>>(input->specialBuffer(), input->specialShapeInfo(), inputTads, inputTadOffsets, reinterpret_cast(indices->specialBuffer()), begins, lengths, numOfClasses, output->specialBuffer(), output->specialShapeInfo(), outputTads, outputTadOffsets); - } - - } - // -------------------------------------------------------------------------------------------------------------- // - void unsortedSegmentMaxFunctor(sd::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { - NDArray::prepareSpecialUse({output}, {input, indices}); - output->nullify(); - BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), unsortedSegmentMaxFunctor_, (context, input, indices, numOfClasses, output), NUMERIC_TYPES, INDEXING_TYPES); - NDArray::registerSpecialUse({output}, {input, indices}); - } - - // -------------------------------------------------------------------------------------------------------------- // - // segment max - // -------------------------------------------------------------------------------------------------------------- // - template - static __global__ void segmentMaxBPLinearKernel(void* inputBuf, Nd4jLong const* inputShape, void* forwardOutput, - Nd4jLong const* forwardShape, void* eps, Nd4jLong const* epsShape, void* indicesBuf, Nd4jLong const* indicesShape, - void* outputBuf, Nd4jLong const* outputShape) { - __shared__ T* x; - __shared__ T* gradIn; - __shared__ T* gradOut; - __shared__ I* y; - __shared__ T* z; - __shared__ Nd4jLong xLen, gradLen; - - if (threadIdx.x == 0) { - xLen = shape::length(inputShape); - x = reinterpret_cast(inputBuf); - y = reinterpret_cast(indicesBuf); - z = reinterpret_cast(outputBuf); - gradIn = reinterpret_cast(forwardOutput); - gradOut = reinterpret_cast(eps); - gradLen = shape::length(epsShape); - } - __syncthreads(); - - auto start = blockIdx.x * blockDim.x + threadIdx.x; - auto step = gridDim.x * blockDim.x; - - for (auto e = start; e < xLen; e += step) { - - auto zOffset = shape::getIndexOffset(e, outputShape); - auto xOffset = shape::getIndexOffset(e, inputShape); - auto yOffset = shape::getIndexOffset(e, indicesShape); - auto classIndex = y[yOffset]; - auto gradOffsetI = shape::getIndexOffset(classIndex, forwardShape); - auto gradOffsetO = shape::getIndexOffset(classIndex, epsShape); - - if (sd::math::nd4j_abs(gradIn[gradOffsetI] - x[xOffset]) <= T(1.e-6)) { - z[zOffset] = gradOut[gradOffsetO]; - } - } - } - - // -------------------------------------------------------------------------------------------------------------- // - template - static __global__ void segmentMaxBPTadKernel(void* inputBuf, Nd4jLong const* inputShape, void* forwardOutput, - Nd4jLong const* forwardShape, void* eps, Nd4jLong const* epsShape, void* indicesBuf, Nd4jLong const* indicesShape, - void* outputBuf, Nd4jLong const* outputShape,Nd4jLong const* inputTad, - Nd4jLong const* inputOffsets, Nd4jLong const* gradInTad, Nd4jLong const* gradInOffsets, - Nd4jLong const* gradOutTad, Nd4jLong const* gradOutOffsets, Nd4jLong const* outTad, - Nd4jLong const* outOffsets) { - __shared__ T* x; - __shared__ T* gradIn; - __shared__ T* gradOut; - __shared__ I* y; - __shared__ T* z; - __shared__ Nd4jLong xLen, yLen, gradLen, currentLen; - - if (threadIdx.x == 0) { - xLen = shape::length(inputShape); - x = reinterpret_cast(inputBuf); - y = reinterpret_cast(indicesBuf); - z = reinterpret_cast(outputBuf); - yLen = shape::length(indicesShape); - gradOut = reinterpret_cast(eps); - gradIn = reinterpret_cast(forwardOutput); - gradLen = shape::length(epsShape); - currentLen = shape::length(outTad); - } - __syncthreads(); - - for (auto i = blockIdx.x; i < yLen; i += gridDim.x) { - auto yIndex = shape::getIndexOffset(i, indicesShape); - auto segment = y[yIndex]; - T* current = x + inputOffsets[i]; - T* currentOut = z + outOffsets[i]; - T* in = gradIn + gradInOffsets[segment]; - T* outGrad = gradOut + gradOutOffsets[segment]; - - for (auto e = threadIdx.x; e < currentLen; e += blockDim.x) { - if (sd::math::nd4j_abs(in[e] - current[e]) <= T(1.e-6)) - currentOut[e] = outGrad[e]; - } - } - } - // -------------------------------------------------------------------------------------------------------------- // - template - int segmentMaxFunctorBP_(sd::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) { - //int numOfClasses = gradOut->sizeAt(0); - // if input is a vector: (as if in doc sample) - auto stream = context->getCudaStream(); - NDArray tempRes(gradOut->ordering(), gradOut->getShapeAsVector(), DataTypeUtils::fromT(), context);//->shapeInfo(), context); - segmentMaxFunctor_(context, input, indices, &tempRes); - NDArray::prepareSpecialUse({output}, {input, indices, gradOut, &tempRes}); - if (input->isVector()) { - Nd4jLong loop_size = input->lengthOf(); - auto numOfClasses = gradOut->lengthOf(); //indices->e(loop_size - 1); - segmentMaxBPLinearKernel<<<1 + gradOut->lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), - tempRes.specialBuffer(), tempRes.specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(), - indices->specialBuffer(), indices->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo()); - } - else { - std::vector dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(input->shapeInfo(), dimensions); - auto packZ = sd::ConstantTadHelper::getInstance().tadForDimensions(output->shapeInfo(), dimensions); - auto packGradIn = sd::ConstantTadHelper::getInstance().tadForDimensions(tempRes.shapeInfo(), dimensions); - auto packGradOut = sd::ConstantTadHelper::getInstance().tadForDimensions(gradOut->shapeInfo(), dimensions); - Nd4jLong const* inputTads = packX.specialShapeInfo(); - Nd4jLong const* inputTadOffsets = packX.specialOffsets(); - Nd4jLong const* outputTads = packZ.specialShapeInfo(); - Nd4jLong const* outputTadOffsets = packZ.specialOffsets(); - Nd4jLong const* gradInTads = packGradIn.specialShapeInfo(); - Nd4jLong const* gradInTadOffsets = packGradIn.specialOffsets(); - Nd4jLong const* gradOutTads = packGradOut.specialShapeInfo(); - Nd4jLong const* gradOutTadOffsets = packGradOut.specialOffsets(); - - segmentMaxBPTadKernel<<lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), - tempRes.specialBuffer(), tempRes.specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(), - indices->specialBuffer(), indices->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), - inputTads, inputTadOffsets, gradInTads, gradInTadOffsets, gradOutTads, gradOutTadOffsets, - outputTads, outputTadOffsets); - } - NDArray::registerSpecialUse({output}, {input, indices, gradOut, &tempRes}); - return Status::OK(); - } - // -------------------------------------------------------------------------------------------------------------- // - int segmentMaxFunctorBP(sd::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) { - NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); - BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return segmentMaxFunctorBP_, (context, input, - indices, gradOut, output), FLOAT_TYPES, INDEXING_TYPES); - NDArray::registerSpecialUse({output}, {input, indices, gradOut}); - } - - // -------------------------------------------------------------------------------------------------------------- // - template - static int unsortedSegmentMaxFunctorBP_(sd::LaunchContext* context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { - //int numOfClasses = gradOut->sizeAt(0); - // if input is a vector: (as if in doc sample) - auto stream = context->getCudaStream(); - NDArray tempRes(gradOut->ordering(), gradOut->getShapeAsVector(), DataTypeUtils::fromT(), context);//->shapeInfo(), context); - unsortedSegmentMaxFunctor_(context, input, indices, numOfClasses, &tempRes); - NDArray::prepareSpecialUse({output}, {input, indices, gradOut, &tempRes}); - if (input->isVector()) { - Nd4jLong loop_size = input->lengthOf(); - auto numOfClasses = gradOut->lengthOf(); //indices->e(loop_size - 1); - segmentMaxBPLinearKernel<<lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), - tempRes.specialBuffer(), tempRes.specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(), - indices->specialBuffer(), indices->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo()); - } - else { - std::vector dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(input->shapeInfo(), dimensions); - auto packZ = sd::ConstantTadHelper::getInstance().tadForDimensions(output->shapeInfo(), dimensions); - auto packGradIn = sd::ConstantTadHelper::getInstance().tadForDimensions(tempRes.shapeInfo(), dimensions); - auto packGradOut = sd::ConstantTadHelper::getInstance().tadForDimensions(gradOut->shapeInfo(), dimensions); - Nd4jLong const* inputTads = packX.specialShapeInfo(); - Nd4jLong const* inputTadOffsets = packX.specialOffsets(); - Nd4jLong const* outputTads = packZ.specialShapeInfo(); - Nd4jLong const* outputTadOffsets = packZ.specialOffsets(); - Nd4jLong const* gradInTads = packGradIn.specialShapeInfo(); - Nd4jLong const* gradInTadOffsets = packGradIn.specialOffsets(); - Nd4jLong const* gradOutTads = packGradOut.specialShapeInfo(); - Nd4jLong const* gradOutTadOffsets = packGradOut.specialOffsets(); - - segmentMaxBPTadKernel<<lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), - tempRes.specialBuffer(), tempRes.specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(), - indices->specialBuffer(), indices->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), - inputTads, inputTadOffsets, gradInTads, gradInTadOffsets, gradOutTads, gradOutTadOffsets, - outputTads, outputTadOffsets); - } - NDArray::registerSpecialUse({output}, {input, indices, gradOut, &tempRes}); - return Status::OK(); - } - // -------------------------------------------------------------------------------------------------------------- // - int unsortedSegmentMaxFunctorBP(sd::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { - NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); - BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return unsortedSegmentMaxFunctorBP_, (context, input, indices, gradOut, numOfClasses, output), FLOAT_TYPES, INDEXING_TYPES); - NDArray::registerSpecialUse({output}, {input, indices, gradOut}); - } - } +namespace ops { +namespace helpers { + +// -------------------------------------------------------------------------------------------------------------- +// // Segment ops linear kernels +// -------------------------------------------------------------------------------------------------------------- +// // + +template +static __global__ void segmentMaxLinearKernel( + void* input, Nd4jLong const* inputShape, int* starts, int* lengths, + Nd4jLong numOfClasses, void* output, Nd4jLong const* outputShape) { + __shared__ T* val; + __shared__ Nd4jLong xLen, zLen, zIndex; + __shared__ T* x; + __shared__ T* z; + __shared__ int threadsPerSegment, start, finish; + + auto segment = blockIdx.x; + if (threadIdx.x == 0) { + // threadsPerSegment = (gridDim.x + numOfClasses - 1) / + // numOfClasses; segment = blockIdx.x / + // threadsPerSegment; + x = reinterpret_cast(input); + z = reinterpret_cast(output); + extern __shared__ unsigned char shmem[]; + val = reinterpret_cast(shmem); + xLen = shape::length(inputShape); + zLen = shape::length(outputShape); + + if (segment < numOfClasses) { + zIndex = shape::getIndexOffset(segment, outputShape); + start = starts[segment]; + finish = start + lengths[segment]; + z[zIndex] = x[shape::getIndexOffset(start, inputShape)]; + val[segment] = z[zIndex]; + } + } + __syncthreads(); + + for (auto e = start + threadIdx.x + 1; e < finish; e += blockDim.x) { + auto xIndex = shape::getIndexOffset(e, inputShape); + sd::math::atomics::nd4j_atomicMax(&z[zIndex], x[xIndex]); + } +} +// -------------------------------------------------------------------------------------------------------------- +// // + +template +static __global__ void unsortedSegmentMaxLinearKernel( + void* input, Nd4jLong const* inputShape, void* indices, + Nd4jLong const* indicesShape, int* starts, int* lengths, + Nd4jLong numOfClasses, void* output, Nd4jLong const* outputShape) { + __shared__ T* val; + __shared__ Nd4jLong xLen, zLen, zIndex; + __shared__ T* x; + __shared__ T* z; + __shared__ I* y; // int threadsPerSegment, start, finish; + auto segment = blockIdx.x; + + if (threadIdx.x == 0) { + x = reinterpret_cast(input); + z = reinterpret_cast(output); + y = reinterpret_cast(indices); + xLen = shape::length(inputShape); + zLen = shape::length(outputShape); + + zIndex = shape::getIndexOffset(segment, outputShape); + // start = starts[segment]; + // finish = start + lengths[segment]; + if (lengths[segment] > 0) + z[zIndex] = x[shape::getIndexOffset(starts[segment], inputShape)]; + else + z[zIndex] = -DataTypeUtils::max(); + } + __syncthreads(); + if (lengths[segment] > 0) + for (auto e = threadIdx.x + 1; e < xLen; e += blockDim.x) { + auto xIndex = shape::getIndexOffset(e, inputShape); + auto yIndex = shape::getIndexOffset(e, indicesShape); + if (y[yIndex] == segment) { + sd::math::atomics::nd4j_atomicMax(&z[zIndex], x[xIndex]); + } + } +} +// -------------------------------------------------------------------------------------------------------------- +// // +template +static __global__ void segmentMaxTadKernel( + void* inputBuf, Nd4jLong const* inputShape, Nd4jLong const* inputTads, + Nd4jLong const* inputTadOffsets, I* indices, int* starts, int* lengths, + Nd4jLong numOfClasses, void* outputBuf, Nd4jLong const* outputShape, + Nd4jLong const* outputTads, Nd4jLong const* outputTadOffsets, + T filler = 0) { + __shared__ T* val; + __shared__ Nd4jLong len, zIndex, total; + __shared__ T* z; + __shared__ int start, finish; + __shared__ I segment; + + if (threadIdx.x == 0) { + segment = indices[blockIdx.x]; // / threadsPerSegment; + z = reinterpret_cast(outputBuf) + outputTadOffsets[segment]; + len = shape::length(inputTads); + + start = starts[segment]; + finish = start + lengths[segment]; + total = shape::sizeAt(inputShape, 0); + } + __syncthreads(); + + auto idx = blockIdx.x; + if (idx <= total) { + auto x = reinterpret_cast(inputBuf) + inputTadOffsets[idx]; + if (blockIdx.x == start) { + for (auto e = threadIdx.x; e < len; e += blockDim.x) { + auto xIndex = shape::getIndexOffset(e, inputTads); + auto zIndex = shape::getIndexOffset(e, outputTads); + sd::math::atomics::nd4j_atomicMax(&z[zIndex], x[xIndex]); + // z[zIndex] = x[xIndex]; + } + } else { + for (auto e = threadIdx.x; e < len; e += blockDim.x) { + auto xIndex = shape::getIndexOffset(e, inputTads); + auto zIndex = shape::getIndexOffset(e, outputTads); + if (lengths[segment]) + sd::math::atomics::nd4j_atomicMax(&z[zIndex], x[xIndex]); + } + } + } +} +// -------------------------------------------------------------------------------------------------------------- +// // + +template +static void segmentMaxFunctor_(LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* output) { + // int numClasses = output->sizeAt(0); + // if input is a vector: (as if in doc sample) + // Nd4jLong idx = indices->e(0); + output->assign(-DataTypeUtils::infOrMax()); + auto stream = context->getCudaStream(); + indices->syncToHost(); + Nd4jLong numOfClasses = indices->e(indices->lengthOf() - 1) + 1; + NDArray classesRangesLens = + NDArrayFactory::create('c', {numOfClasses}, context); + NDArray classesRangesBegs = + NDArrayFactory::create('c', {numOfClasses}, context); + + classesRangesBegs.assign(indices->lengthOf()); + classesRangesLens.assign(0); + dim3 dims(256, 512, 256); + int* begins = reinterpret_cast(classesRangesBegs.specialBuffer()); + int* lengths = reinterpret_cast(classesRangesLens.specialBuffer()); + fillUpSegments(indices, numOfClasses, classesRangesBegs, classesRangesLens); + + NDArray::prepareSpecialUse( + {output}, {input, indices, &classesRangesBegs, &classesRangesLens}); + + if (input->isVector()) { + segmentMaxLinearKernel + <<lengthOf(), numOfClasses * 32 + 32, *stream>>>( + input->specialBuffer(), input->specialShapeInfo(), begins, lengths, + numOfClasses, output->specialBuffer(), output->specialShapeInfo()); + } else { + std::vector dimensions = + ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions( + input->shapeInfo(), dimensions); + auto packZ = sd::ConstantTadHelper::getInstance().tadForDimensions( + output->shapeInfo(), dimensions); + auto inputTads = packX.specialShapeInfo(); + auto inputTadOffsets = packX.specialOffsets(); + auto outputTads = packZ.specialShapeInfo(); + auto outputTadOffsets = packZ.specialOffsets(); + segmentMaxTadKernel<<>>( + input->specialBuffer(), input->specialShapeInfo(), inputTads, + inputTadOffsets, reinterpret_cast(indices->specialBuffer()), begins, + lengths, numOfClasses, output->specialBuffer(), + output->specialShapeInfo(), outputTads, outputTadOffsets); + } + NDArray::registerSpecialUse( + {output}, {input, indices, &classesRangesBegs, &classesRangesLens}); +} +// -------------------------------------------------------------------------------------------------------------- +// // +void segmentMaxFunctor(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* output) { + NDArray::prepareSpecialUse({output}, {input, indices}); + BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), + segmentMaxFunctor_, (context, input, indices, output), + NUMERIC_TYPES, INDEXING_TYPES); + NDArray::registerSpecialUse({output}, {input, indices}); +} +// -------------------------------------------------------------------------------------------------------------- +// // + +template +static void unsortedSegmentMaxFunctor_(sd::LaunchContext* context, + NDArray* input, NDArray* indices, + Nd4jLong numOfClasses, NDArray* output) { + auto stream = context->getCudaStream(); + // NDArray classes = NDArrayFactory::create('c', {numOfClasses, + // 2}); + output->assign(DataTypeUtils::infOrMax()); + + NDArray classesRangesBegs = + NDArrayFactory::create('c', {numOfClasses}, context); + NDArray classesRangesLens = + NDArrayFactory::create('c', {numOfClasses}, context); + // NDArray row = NDArrayFactory::create('c', {1, 2}, + // {(int)indices->lengthOf(), (int)0}); + // classes.applyTrueBroadcast(sd::BroadcastOpsTuple::Assign(), row, + // classes); + classesRangesBegs.assign(indices->lengthOf()); + classesRangesLens.assign(0); + dim3 dims(numOfClasses, indices->lengthOf(), numOfClasses * 32 + 32); + // int* classesBuf = reinterpret_cast(classes.specialBuffer()); + fillUpSegments(indices, numOfClasses, classesRangesBegs, classesRangesLens); + int* begins = reinterpret_cast(classesRangesBegs.specialBuffer()); + int* lengths = reinterpret_cast(classesRangesLens.specialBuffer()); + + if (input->isVector()) { + unsortedSegmentMaxLinearKernel<<>>( + input->specialBuffer(), input->specialShapeInfo(), + indices->specialBuffer(), indices->specialShapeInfo(), begins, lengths, + numOfClasses, output->specialBuffer(), output->specialShapeInfo()); + } else { + std::vector dimensions = + ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions( + input->shapeInfo(), dimensions); + auto packZ = sd::ConstantTadHelper::getInstance().tadForDimensions( + output->shapeInfo(), dimensions); + auto inputTads = packX.specialShapeInfo(); + auto inputTadOffsets = packX.specialOffsets(); + auto outputTads = packZ.specialShapeInfo(); + auto outputTadOffsets = packZ.specialOffsets(); + dims.x = input->sizeAt(0); + output->assign(-DataTypeUtils::max()); + segmentMaxTadKernel<<>>( + input->specialBuffer(), input->specialShapeInfo(), inputTads, + inputTadOffsets, reinterpret_cast(indices->specialBuffer()), begins, + lengths, numOfClasses, output->specialBuffer(), + output->specialShapeInfo(), outputTads, outputTadOffsets); + } +} +// -------------------------------------------------------------------------------------------------------------- +// // +void unsortedSegmentMaxFunctor(sd::LaunchContext* context, NDArray* input, + NDArray* indices, Nd4jLong numOfClasses, + NDArray* output) { + NDArray::prepareSpecialUse({output}, {input, indices}); + output->nullify(); + BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), + unsortedSegmentMaxFunctor_, + (context, input, indices, numOfClasses, output), + NUMERIC_TYPES, INDEXING_TYPES); + NDArray::registerSpecialUse({output}, {input, indices}); +} + +// -------------------------------------------------------------------------------------------------------------- +// // segment max +// -------------------------------------------------------------------------------------------------------------- +// // +template +static __global__ void segmentMaxBPLinearKernel( + void* inputBuf, Nd4jLong const* inputShape, void* forwardOutput, + Nd4jLong const* forwardShape, void* eps, Nd4jLong const* epsShape, + void* indicesBuf, Nd4jLong const* indicesShape, void* outputBuf, + Nd4jLong const* outputShape) { + __shared__ T* x; + __shared__ T* gradIn; + __shared__ T* gradOut; + __shared__ I* y; + __shared__ T* z; + __shared__ Nd4jLong xLen, gradLen; + + if (threadIdx.x == 0) { + xLen = shape::length(inputShape); + x = reinterpret_cast(inputBuf); + y = reinterpret_cast(indicesBuf); + z = reinterpret_cast(outputBuf); + gradIn = reinterpret_cast(forwardOutput); + gradOut = reinterpret_cast(eps); + gradLen = shape::length(epsShape); + } + __syncthreads(); + + auto start = blockIdx.x * blockDim.x + threadIdx.x; + auto step = gridDim.x * blockDim.x; + + for (auto e = start; e < xLen; e += step) { + auto zOffset = shape::getIndexOffset(e, outputShape); + auto xOffset = shape::getIndexOffset(e, inputShape); + auto yOffset = shape::getIndexOffset(e, indicesShape); + auto classIndex = y[yOffset]; + auto gradOffsetI = shape::getIndexOffset(classIndex, forwardShape); + auto gradOffsetO = shape::getIndexOffset(classIndex, epsShape); + + if (sd::math::nd4j_abs(gradIn[gradOffsetI] - x[xOffset]) <= T(1.e-6)) { + z[zOffset] = gradOut[gradOffsetO]; + } + } +} + +// -------------------------------------------------------------------------------------------------------------- +// // +template +static __global__ void segmentMaxBPTadKernel( + void* inputBuf, Nd4jLong const* inputShape, void* forwardOutput, + Nd4jLong const* forwardShape, void* eps, Nd4jLong const* epsShape, + void* indicesBuf, Nd4jLong const* indicesShape, void* outputBuf, + Nd4jLong const* outputShape, Nd4jLong const* inputTad, + Nd4jLong const* inputOffsets, Nd4jLong const* gradInTad, + Nd4jLong const* gradInOffsets, Nd4jLong const* gradOutTad, + Nd4jLong const* gradOutOffsets, Nd4jLong const* outTad, + Nd4jLong const* outOffsets) { + __shared__ T* x; + __shared__ T* gradIn; + __shared__ T* gradOut; + __shared__ I* y; + __shared__ T* z; + __shared__ Nd4jLong xLen, yLen, gradLen, currentLen; + + if (threadIdx.x == 0) { + xLen = shape::length(inputShape); + x = reinterpret_cast(inputBuf); + y = reinterpret_cast(indicesBuf); + z = reinterpret_cast(outputBuf); + yLen = shape::length(indicesShape); + gradOut = reinterpret_cast(eps); + gradIn = reinterpret_cast(forwardOutput); + gradLen = shape::length(epsShape); + currentLen = shape::length(outTad); + } + __syncthreads(); + + for (auto i = blockIdx.x; i < yLen; i += gridDim.x) { + auto yIndex = shape::getIndexOffset(i, indicesShape); + auto segment = y[yIndex]; + T* current = x + inputOffsets[i]; + T* currentOut = z + outOffsets[i]; + T* in = gradIn + gradInOffsets[segment]; + T* outGrad = gradOut + gradOutOffsets[segment]; + + for (auto e = threadIdx.x; e < currentLen; e += blockDim.x) { + if (sd::math::nd4j_abs(in[e] - current[e]) <= T(1.e-6)) + currentOut[e] = outGrad[e]; } -} \ No newline at end of file + } +} +// -------------------------------------------------------------------------------------------------------------- +// // +template +int segmentMaxFunctorBP_(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* gradOut, NDArray* output) { + // int numOfClasses = gradOut->sizeAt(0); + // if input is a vector: (as if in doc sample) + auto stream = context->getCudaStream(); + NDArray tempRes(gradOut->ordering(), gradOut->getShapeAsVector(), + DataTypeUtils::fromT(), + context); //->shapeInfo(), context); + segmentMaxFunctor_(context, input, indices, &tempRes); + NDArray::prepareSpecialUse({output}, {input, indices, gradOut, &tempRes}); + if (input->isVector()) { + Nd4jLong loop_size = input->lengthOf(); + auto numOfClasses = + gradOut->lengthOf(); // indices->e(loop_size - 1); + segmentMaxBPLinearKernel + <<<1 + gradOut->lengthOf(), input->lengthOf(), 256, *stream>>>( + input->specialBuffer(), input->specialShapeInfo(), + tempRes.specialBuffer(), tempRes.specialShapeInfo(), + gradOut->specialBuffer(), gradOut->specialShapeInfo(), + indices->specialBuffer(), indices->specialShapeInfo(), + output->specialBuffer(), output->specialShapeInfo()); + } else { + std::vector dimensions = + ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions( + input->shapeInfo(), dimensions); + auto packZ = sd::ConstantTadHelper::getInstance().tadForDimensions( + output->shapeInfo(), dimensions); + auto packGradIn = sd::ConstantTadHelper::getInstance().tadForDimensions( + tempRes.shapeInfo(), dimensions); + auto packGradOut = sd::ConstantTadHelper::getInstance().tadForDimensions( + gradOut->shapeInfo(), dimensions); + Nd4jLong const* inputTads = packX.specialShapeInfo(); + Nd4jLong const* inputTadOffsets = packX.specialOffsets(); + Nd4jLong const* outputTads = packZ.specialShapeInfo(); + Nd4jLong const* outputTadOffsets = packZ.specialOffsets(); + Nd4jLong const* gradInTads = packGradIn.specialShapeInfo(); + Nd4jLong const* gradInTadOffsets = packGradIn.specialOffsets(); + Nd4jLong const* gradOutTads = packGradOut.specialShapeInfo(); + Nd4jLong const* gradOutTadOffsets = packGradOut.specialOffsets(); + + segmentMaxBPTadKernel + <<lengthOf(), input->lengthOf(), 256, *stream>>>( + input->specialBuffer(), input->specialShapeInfo(), + tempRes.specialBuffer(), tempRes.specialShapeInfo(), + gradOut->specialBuffer(), gradOut->specialShapeInfo(), + indices->specialBuffer(), indices->specialShapeInfo(), + output->specialBuffer(), output->specialShapeInfo(), inputTads, + inputTadOffsets, gradInTads, gradInTadOffsets, gradOutTads, + gradOutTadOffsets, outputTads, outputTadOffsets); + } + NDArray::registerSpecialUse({output}, {input, indices, gradOut, &tempRes}); + return Status::OK(); +} +// -------------------------------------------------------------------------------------------------------------- +// // +int segmentMaxFunctorBP(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* gradOut, NDArray* output) { + NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); + BUILD_DOUBLE_SELECTOR( + output->dataType(), indices->dataType(), return segmentMaxFunctorBP_, + (context, input, indices, gradOut, output), FLOAT_TYPES, INDEXING_TYPES); + NDArray::registerSpecialUse({output}, {input, indices, gradOut}); +} + +// -------------------------------------------------------------------------------------------------------------- +// // +template +static int unsortedSegmentMaxFunctorBP_(sd::LaunchContext* context, + NDArray* input, NDArray* indices, + NDArray* gradOut, Nd4jLong numOfClasses, + NDArray* output) { + // int numOfClasses = gradOut->sizeAt(0); + // if input is a vector: (as if in doc sample) + auto stream = context->getCudaStream(); + NDArray tempRes(gradOut->ordering(), gradOut->getShapeAsVector(), + DataTypeUtils::fromT(), + context); //->shapeInfo(), context); + unsortedSegmentMaxFunctor_(context, input, indices, numOfClasses, + &tempRes); + NDArray::prepareSpecialUse({output}, {input, indices, gradOut, &tempRes}); + if (input->isVector()) { + Nd4jLong loop_size = input->lengthOf(); + auto numOfClasses = + gradOut->lengthOf(); // indices->e(loop_size - 1); + segmentMaxBPLinearKernel + <<lengthOf(), input->lengthOf(), 256, *stream>>>( + input->specialBuffer(), input->specialShapeInfo(), + tempRes.specialBuffer(), tempRes.specialShapeInfo(), + gradOut->specialBuffer(), gradOut->specialShapeInfo(), + indices->specialBuffer(), indices->specialShapeInfo(), + output->specialBuffer(), output->specialShapeInfo()); + } else { + std::vector dimensions = + ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions( + input->shapeInfo(), dimensions); + auto packZ = sd::ConstantTadHelper::getInstance().tadForDimensions( + output->shapeInfo(), dimensions); + auto packGradIn = sd::ConstantTadHelper::getInstance().tadForDimensions( + tempRes.shapeInfo(), dimensions); + auto packGradOut = sd::ConstantTadHelper::getInstance().tadForDimensions( + gradOut->shapeInfo(), dimensions); + Nd4jLong const* inputTads = packX.specialShapeInfo(); + Nd4jLong const* inputTadOffsets = packX.specialOffsets(); + Nd4jLong const* outputTads = packZ.specialShapeInfo(); + Nd4jLong const* outputTadOffsets = packZ.specialOffsets(); + Nd4jLong const* gradInTads = packGradIn.specialShapeInfo(); + Nd4jLong const* gradInTadOffsets = packGradIn.specialOffsets(); + Nd4jLong const* gradOutTads = packGradOut.specialShapeInfo(); + Nd4jLong const* gradOutTadOffsets = packGradOut.specialOffsets(); + + segmentMaxBPTadKernel + <<lengthOf(), input->lengthOf(), 256, *stream>>>( + input->specialBuffer(), input->specialShapeInfo(), + tempRes.specialBuffer(), tempRes.specialShapeInfo(), + gradOut->specialBuffer(), gradOut->specialShapeInfo(), + indices->specialBuffer(), indices->specialShapeInfo(), + output->specialBuffer(), output->specialShapeInfo(), inputTads, + inputTadOffsets, gradInTads, gradInTadOffsets, gradOutTads, + gradOutTadOffsets, outputTads, outputTadOffsets); + } + NDArray::registerSpecialUse({output}, {input, indices, gradOut, &tempRes}); + return Status::OK(); +} +// -------------------------------------------------------------------------------------------------------------- +// // +int unsortedSegmentMaxFunctorBP(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* gradOut, + Nd4jLong numOfClasses, NDArray* output) { + NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); + BUILD_DOUBLE_SELECTOR( + output->dataType(), indices->dataType(), + return unsortedSegmentMaxFunctorBP_, + (context, input, indices, gradOut, numOfClasses, output), FLOAT_TYPES, + INDEXING_TYPES); + NDArray::registerSpecialUse({output}, {input, indices, gradOut}); +} +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/segment_mean.cu b/libnd4j/include/ops/declarable/helpers/cuda/segment_mean.cu index 5ccecf37c6f5..db8cd6aca799 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/segment_mean.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/segment_mean.cu @@ -18,400 +18,502 @@ // @author GS // -#include -#include #include -#include -#include #include -#include #include +#include +#include +#include +#include +#include namespace sd { namespace ops { namespace helpers { - // -------------------------------------------------------------------------------------------------------------- // - // Segment ops linear kernels - // -------------------------------------------------------------------------------------------------------------- // - template - static __global__ void segmentMeanLinearKernel(void* input, Nd4jLong const* inputShape, int* starts, int* lengths, Nd4jLong numOfClasses, void* output, Nd4jLong const* outputShape) { - __shared__ T* val; - __shared__ Nd4jLong xLen, zLen, segment, zIndex; - __shared__ T* x; - __shared__ T* z; - __shared__ int threadsPerSegment, start, finish; - - if (threadIdx.x == 0) { - threadsPerSegment = (gridDim.x + numOfClasses - 1) / numOfClasses; - segment = blockIdx.x / threadsPerSegment; - x = reinterpret_cast(input); - z = reinterpret_cast(output); -// extern __shared__ unsigned char shmem[]; -// val = reinterpret_cast(shmem); - xLen = shape::length(inputShape); - zLen = shape::length(outputShape); - - //[zIndex] = - if (segment < numOfClasses) { - zIndex = shape::getIndexOffset(segment, outputShape); - start = starts[segment]; - finish = start + lengths[segment]; - //val[segment] = ; - z[zIndex] = T(x[shape::getIndexOffset(start, inputShape)] / lengths[segment]); -// val[segment] = z[zIndex]; - } - - } - __syncthreads(); - - for (auto e = start + threadIdx.x + 1; e < finish; e += blockDim.x) { - auto xIndex = shape::getIndexOffset(e, inputShape); - if (lengths[segment]) - sd::math::atomics::nd4j_atomicAdd(&z[zIndex], T(x[xIndex] / lengths[segment])); - } - } - // -------------------------------------------------------------------------------------------------------------- // - template - static __global__ void unsortedSegmentMeanLinearKernel(void* input, Nd4jLong const* inputShape, void* indices, Nd4jLong const* indicesShape, int* starts, int* lengths, Nd4jLong numOfClasses, void* output, Nd4jLong const* outputShape) { - __shared__ T* val; - __shared__ Nd4jLong xLen, zLen, zIndex; - __shared__ T* x; - __shared__ T* z; - __shared__ I* y; //int threadsPerSegment, start, finish; - auto segment = blockIdx.x;// / - if (threadIdx.x == 0) { -// threadsPerSegment = (gridDim.x + numOfClasses - 1) / numOfClasses; -// threadsPerSegment; - x = reinterpret_cast(input); - z = reinterpret_cast(output); - y = reinterpret_cast(indices); -// extern __shared__ unsigned char shmem[]; -// val = reinterpret_cast(shmem); - xLen = shape::length(inputShape); - zLen = shape::length(outputShape); - -// if (segment < numOfClasses) { - zIndex = shape::getIndexOffset(segment, outputShape); - //start = starts[segment]; - //finish = start + lengths[segment]; - if (lengths[segment] > 0) - z[zIndex] = T(x[shape::getIndexOffset(starts[segment], inputShape)] / T(lengths[segment])); - else - z[zIndex] = 0; //DataTypeUtils::max(); -// val[segment] = z[zIndex]; -// } - - } - __syncthreads(); - if (lengths[segment] > 0) - for (auto e = threadIdx.x; e < xLen; e += blockDim.x) { - auto xIndex = shape::getIndexOffset(e, inputShape); - auto yIndex = shape::getIndexOffset(e, indicesShape); - if (y[yIndex] == segment && e != starts[segment]) { - sd::math::atomics::nd4j_atomicAdd(&z[zIndex], T(x[xIndex]/T(lengths[segment]))); - } - } - } - // -------------------------------------------------------------------------------------------------------------- // - // SegmentMean kernel - template - static __global__ void segmentMeanTadKernel(void* inputBuf, Nd4jLong const* inputShape, Nd4jLong const* inputTads, Nd4jLong const* inputTadOffsets, I* indices, int* starts, int* lengths, Nd4jLong numOfClasses, void* outputBuf, Nd4jLong const* outputShape, Nd4jLong const* outputTads, Nd4jLong const* outputTadOffsets) { - __shared__ T* val; - __shared__ Nd4jLong len, zIndex, total; - __shared__ T* z; - __shared__ int threadsPerSegment, start, finish; - auto segment = indices[blockIdx.x]; // / threadsPerSegment; - - if (threadIdx.x == 0) { - z = reinterpret_cast(outputBuf) + outputTadOffsets[segment]; - len = shape::length(inputTads); - start = starts[segment]; - finish = start + lengths[segment]; - total = shape::sizeAt(inputShape, 0); - - } - __syncthreads(); - - auto idx = blockIdx.x; - if (blockIdx.x <= total) { - auto x = reinterpret_cast(inputBuf) + inputTadOffsets[idx]; - if (blockIdx.x == start) { - for (auto e = threadIdx.x; e < len; e += blockDim.x) { - auto xIndex = shape::getIndexOffset(e, inputTads); - auto zIndex = shape::getIndexOffset(e, outputTads); - sd::math::atomics::nd4j_atomicAdd(&z[zIndex], T(x[xIndex]/lengths[segment])); - } - } - else { - for (auto e = threadIdx.x; e < len; e += blockDim.x) { - auto xIndex = shape::getIndexOffset(e, inputTads); - auto zIndex = shape::getIndexOffset(e, outputTads); - if (lengths[segment]) - sd::math::atomics::nd4j_atomicAdd(&z[zIndex], T(x[xIndex]/lengths[segment])); - } - } - } - } - // -------------------------------------------------------------------------------------------------------------- // - // segmen mean - template - static void segmentMeanFunctor_(LaunchContext* context, NDArray* input, NDArray* indices, NDArray* output) { - auto stream = context->getCudaStream(); - Nd4jLong numClasses = indices->e(indices->lengthOf() - 1) + 1; - NDArray classesRangesLens = NDArrayFactory::create('c', {numClasses}, context); - NDArray classesRangesBegs = NDArrayFactory::create('c', {numClasses}, context); - - classesRangesBegs.assign(indices->lengthOf()); - classesRangesLens.assign(0); - NDArray::prepareSpecialUse({output}, {input, indices}); - dim3 dims(numClasses, indices->lengthOf(), numClasses * 32 + 32); - int* begins = reinterpret_cast(classesRangesBegs.specialBuffer()); - int* lengths = reinterpret_cast(classesRangesLens.specialBuffer()); - fillUpSegments(indices, numClasses, classesRangesBegs, classesRangesLens); - - if (input->isVector()) { - segmentMeanLinearKernel<<lengthOf(), numClasses * 32 + 32, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), begins, lengths, numClasses, output->specialBuffer(), output->specialShapeInfo()); - } - else { - std::vector dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(input->shapeInfo(), dimensions); - auto packZ = sd::ConstantTadHelper::getInstance().tadForDimensions(output->shapeInfo(), dimensions); - auto inputTads = packX.specialShapeInfo(); - auto inputTadOffsets = packX.specialOffsets(); - auto outputTads = packZ.specialShapeInfo(); - auto outputTadOffsets = packZ.specialOffsets(); - segmentMeanTadKernel<<sizeAt(0), 512, 2048, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), inputTads, inputTadOffsets, reinterpret_cast(indices->specialBuffer()), begins, lengths, numClasses, output->specialBuffer(), output->specialShapeInfo(), outputTads, outputTadOffsets); - } - NDArray::registerSpecialUse({output}, {input, indices}); - - } - // -------------------------------------------------------------------------------------------------------------- // - void segmentMeanFunctor(sd::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* output) { - NDArray::prepareSpecialUse({output}, {input, indices}); - BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), segmentMeanFunctor_, (context, input, indices, output), NUMERIC_TYPES, INDEXING_TYPES); - NDArray::registerSpecialUse({output}, {input, indices}); - } - - // -------------------------------------------------------------------------------------------------------------- // - template - static void unsortedSegmentMeanFunctor_(sd::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { - auto stream = context->getCudaStream(); -// NDArray classes = NDArrayFactory::create('c', {numOfClasses, 2}); - - NDArray classesRangesBegs = NDArrayFactory::create('c', {numOfClasses}, context); - NDArray classesRangesLens = NDArrayFactory::create('c', {numOfClasses}, context); -// NDArray row = NDArrayFactory::create('c', {1, 2}, {(int)indices->lengthOf(), (int)0}); -// classes.applyTrueBroadcast(sd::BroadcastOpsTuple::Assign(), &row, &classes); - classesRangesBegs.assign(indices->lengthOf()); - classesRangesLens.assign(0); - dim3 dims(numOfClasses, indices->lengthOf(), numOfClasses * 32 + 32); -// int* classesBuf = reinterpret_cast(classes.specialBuffer()); - fillUpSegments(indices, numOfClasses, classesRangesBegs, classesRangesLens); - int* begins = reinterpret_cast(classesRangesBegs.specialBuffer()); - int* lengths = reinterpret_cast(classesRangesLens.specialBuffer()); - - if (input->isVector()) { - unsortedSegmentMeanLinearKernel<<>>(input->specialBuffer(), input->specialShapeInfo(), indices->specialBuffer(), indices->specialShapeInfo(), begins, lengths, numOfClasses, output->specialBuffer(), output->specialShapeInfo()); - } - else { - output->assign(0); - std::vector dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(input->shapeInfo(), dimensions); - auto packZ = sd::ConstantTadHelper::getInstance().tadForDimensions(output->shapeInfo(), dimensions); - Nd4jLong const* inputTads = packX.specialShapeInfo(); - Nd4jLong const* inputTadOffsets = packX.specialOffsets(); - Nd4jLong const* outputTads = packZ.specialShapeInfo(); - Nd4jLong const* outputTadOffsets = packZ.specialOffsets(); - dims.x = input->sizeAt(0); - segmentMeanTadKernel<<>>(input->specialBuffer(), input->specialShapeInfo(), inputTads, inputTadOffsets, reinterpret_cast(indices->specialBuffer()), begins, lengths, numOfClasses, output->specialBuffer(), output->specialShapeInfo(), outputTads, outputTadOffsets); - } - +// -------------------------------------------------------------------------------------------------------------- +// // Segment ops linear kernels +// -------------------------------------------------------------------------------------------------------------- +// // +template +static __global__ void segmentMeanLinearKernel( + void* input, Nd4jLong const* inputShape, int* starts, int* lengths, + Nd4jLong numOfClasses, void* output, Nd4jLong const* outputShape) { + __shared__ T* val; + __shared__ Nd4jLong xLen, zLen, segment, zIndex; + __shared__ T* x; + __shared__ T* z; + __shared__ int threadsPerSegment, start, finish; + + if (threadIdx.x == 0) { + threadsPerSegment = (gridDim.x + numOfClasses - 1) / numOfClasses; + segment = blockIdx.x / threadsPerSegment; + x = reinterpret_cast(input); + z = reinterpret_cast(output); + // extern __shared__ unsigned char shmem[]; + // val = reinterpret_cast(shmem); + xLen = shape::length(inputShape); + zLen = shape::length(outputShape); + + //[zIndex] = + if (segment < numOfClasses) { + zIndex = shape::getIndexOffset(segment, outputShape); + start = starts[segment]; + finish = start + lengths[segment]; + // val[segment] = ; + z[zIndex] = + T(x[shape::getIndexOffset(start, inputShape)] / lengths[segment]); + // val[segment] = z[zIndex]; } - // -------------------------------------------------------------------------------------------------------------- // - void unsortedSegmentMeanFunctor(sd::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { - NDArray::prepareSpecialUse({output}, {input, indices}); - BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), unsortedSegmentMeanFunctor_, (context, input, indices, numOfClasses, output), - NUMERIC_TYPES, INDEXING_TYPES); - NDArray::registerSpecialUse({output}, {input, indices}); - } - - // -------------------------------------------------------------------------------------------------------------- // - template - static __global__ void segmentMeanBPLinearKernel(void* inputBuf, Nd4jLong const* inputShape, void* eps, Nd4jLong const* epsShape, void* indicesBuf, Nd4jLong const* indicesShape, - int* lengths, void* outputBuf, Nd4jLong const* outputShape) { - __shared__ T* x; - __shared__ T* gradIn; - __shared__ T* gradOut; - __shared__ I* y; - __shared__ T* z; - __shared__ Nd4jLong xLen, gradLen; - - if (threadIdx.x == 0) { - xLen = shape::length(inputShape); - x = reinterpret_cast(inputBuf); - y = reinterpret_cast(indicesBuf); - z = reinterpret_cast(outputBuf); - gradOut = reinterpret_cast(eps); - gradLen = shape::length(epsShape); - } - __syncthreads(); - - auto start = blockIdx.x * blockDim.x + threadIdx.x; - auto step = gridDim.x * blockDim.x; - - for (auto e = start; e < xLen; e += step) { - - auto zOffset = shape::getIndexOffset(e, outputShape); - auto xOffset = shape::getIndexOffset(e, inputShape); - auto yOffset = shape::getIndexOffset(e, indicesShape); - auto classIndex = y[yOffset]; - auto gradOffsetO = shape::getIndexOffset(classIndex, epsShape); - - z[zOffset] = T(gradOut[gradOffsetO] / float(lengths[classIndex])); - } - } - // -------------------------------------------------------------------------------------------------------------- // - template - static __global__ void segmentMeanBPTadKernel(void* inputBuf, Nd4jLong const* inputShape, void* eps, Nd4jLong const* epsShape, - void* indicesBuf, Nd4jLong const* indicesShape, int* lengths, void* outputBuf, Nd4jLong const* outputShape,Nd4jLong const* inputTad, - Nd4jLong const* inputOffsets, Nd4jLong const* gradOutTad, Nd4jLong const* gradOutOffsets, Nd4jLong const* outTad, Nd4jLong const* outOffsets) { - __shared__ T* x; - __shared__ T* gradOut; - __shared__ I* y; - __shared__ T* z; - __shared__ Nd4jLong xLen, yLen, gradLen, currentLen; - - if (threadIdx.x == 0) { - xLen = shape::length(inputShape); - x = reinterpret_cast(inputBuf); - y = reinterpret_cast(indicesBuf); - z = reinterpret_cast(outputBuf); - yLen = shape::length(indicesShape); - gradOut = reinterpret_cast(eps); - gradLen = shape::length(epsShape); - currentLen = shape::length(outTad); - } - __syncthreads(); - - for (auto i = blockIdx.x; i < yLen; i += gridDim.x) { -// auto yIndex = shape::getIndexOffset(i, indicesShape); - auto segment = y[i]; //yIndex]; - T* currentOut = z + outOffsets[i]; - T* outGrad = gradOut + gradOutOffsets[segment]; - - for (auto e = threadIdx.x; e < currentLen; e += blockDim.x) { - auto zIndex = shape::getIndexOffset(e, outTad); - auto gradIndex = shape::getIndexOffset(e, gradOutTad); - if (lengths[segment] > 0) - currentOut[zIndex] = T(outGrad[gradIndex] / float(lengths[segment])); - } - } - } - // -------------------------------------------------------------------------------------------------------------- // - // backrop for mean - template - int segmentMeanFunctorBP_(sd::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) { - auto stream = context->getCudaStream(); - NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); - auto numClasses = indices->e(indices->lengthOf() - 1) + 1; - NDArray classesRangesLens = NDArrayFactory::create('c', {numClasses}, context); - NDArray classesRangesBegs = NDArrayFactory::create('c', {numClasses}, context); - - classesRangesBegs.assign(indices->lengthOf()); - classesRangesLens.assign(0); - dim3 dims(numClasses, indices->lengthOf(), numClasses * 32 + 32); - fillUpSegments(indices, numClasses, classesRangesBegs, classesRangesLens); - int* begins = reinterpret_cast(classesRangesBegs.specialBuffer()); - int* lengths = reinterpret_cast(classesRangesLens.specialBuffer()); - - if (input->isVector()) { - Nd4jLong loop_size = input->lengthOf(); - auto numOfClasses = gradOut->lengthOf(); //indices->e(loop_size - 1); - segmentMeanBPLinearKernel<<lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(), - input->specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(), - indices->specialBuffer(), indices->specialShapeInfo(), lengths, output->specialBuffer(), output->specialShapeInfo()); - } - else { - std::vector dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(input->shapeInfo(), dimensions); - auto packZ = sd::ConstantTadHelper::getInstance().tadForDimensions(output->shapeInfo(), dimensions); -// auto packGradIn = sd::ConstantTadHelper::getInstance().tadForDimensions(tempRes.shapeInfo(), dimensions); - auto packGradOut = sd::ConstantTadHelper::getInstance().tadForDimensions(gradOut->shapeInfo(), dimensions); - Nd4jLong const* inputTads = packX.specialShapeInfo(); - Nd4jLong const* inputTadOffsets = packX.specialOffsets(); - Nd4jLong const* outputTads = packZ.specialShapeInfo(); - Nd4jLong const* outputTadOffsets = packZ.specialOffsets(); - Nd4jLong const* gradOutTads = packGradOut.specialShapeInfo(); - Nd4jLong const* gradOutTadOffsets = packGradOut.specialOffsets(); - - segmentMeanBPTadKernel<<lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), - gradOut->specialBuffer(), gradOut->specialShapeInfo(), indices->specialBuffer(), indices->specialShapeInfo(), lengths, - output->specialBuffer(), output->specialShapeInfo(), inputTads, inputTadOffsets, gradOutTads, gradOutTadOffsets, - outputTads, outputTadOffsets); - } - NDArray::registerSpecialUse({output}, {input, indices, gradOut}); - return Status::OK(); + } + __syncthreads(); + + for (auto e = start + threadIdx.x + 1; e < finish; e += blockDim.x) { + auto xIndex = shape::getIndexOffset(e, inputShape); + if (lengths[segment]) + sd::math::atomics::nd4j_atomicAdd(&z[zIndex], + T(x[xIndex] / lengths[segment])); + } +} +// -------------------------------------------------------------------------------------------------------------- +// // +template +static __global__ void unsortedSegmentMeanLinearKernel( + void* input, Nd4jLong const* inputShape, void* indices, + Nd4jLong const* indicesShape, int* starts, int* lengths, + Nd4jLong numOfClasses, void* output, Nd4jLong const* outputShape) { + __shared__ T* val; + __shared__ Nd4jLong xLen, zLen, zIndex; + __shared__ T* x; + __shared__ T* z; + __shared__ I* y; // int threadsPerSegment, start, finish; + auto segment = blockIdx.x; // / + if (threadIdx.x == 0) { + // threadsPerSegment = (gridDim.x + numOfClasses - 1) / + // numOfClasses; threadsPerSegment; + x = reinterpret_cast(input); + z = reinterpret_cast(output); + y = reinterpret_cast(indices); + // extern __shared__ unsigned char shmem[]; + // val = reinterpret_cast(shmem); + xLen = shape::length(inputShape); + zLen = shape::length(outputShape); + + // if (segment < numOfClasses) { + zIndex = shape::getIndexOffset(segment, outputShape); + // start = starts[segment]; + // finish = start + lengths[segment]; + if (lengths[segment] > 0) + z[zIndex] = T(x[shape::getIndexOffset(starts[segment], inputShape)] / + T(lengths[segment])); + else + z[zIndex] = 0; // DataTypeUtils::max(); + // val[segment] = z[zIndex]; + // } + } + __syncthreads(); + if (lengths[segment] > 0) + for (auto e = threadIdx.x; e < xLen; e += blockDim.x) { + auto xIndex = shape::getIndexOffset(e, inputShape); + auto yIndex = shape::getIndexOffset(e, indicesShape); + if (y[yIndex] == segment && e != starts[segment]) { + sd::math::atomics::nd4j_atomicAdd(&z[zIndex], + T(x[xIndex] / T(lengths[segment]))); + } } - // -------------------------------------------------------------------------------------------------------------- // - // segmen mean bp main - int segmentMeanFunctorBP(sd::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) { - NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); - BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return segmentMeanFunctorBP_, (context, input, - indices, gradOut, output), FLOAT_TYPES, INDEXING_TYPES); - NDArray::registerSpecialUse({output}, {input, indices, gradOut}); +} +// -------------------------------------------------------------------------------------------------------------- +// // SegmentMean kernel +template +static __global__ void segmentMeanTadKernel( + void* inputBuf, Nd4jLong const* inputShape, Nd4jLong const* inputTads, + Nd4jLong const* inputTadOffsets, I* indices, int* starts, int* lengths, + Nd4jLong numOfClasses, void* outputBuf, Nd4jLong const* outputShape, + Nd4jLong const* outputTads, Nd4jLong const* outputTadOffsets) { + __shared__ T* val; + __shared__ Nd4jLong len, zIndex, total; + __shared__ T* z; + __shared__ int threadsPerSegment, start, finish; + auto segment = indices[blockIdx.x]; // / threadsPerSegment; + + if (threadIdx.x == 0) { + z = reinterpret_cast(outputBuf) + outputTadOffsets[segment]; + len = shape::length(inputTads); + start = starts[segment]; + finish = start + lengths[segment]; + total = shape::sizeAt(inputShape, 0); + } + __syncthreads(); + + auto idx = blockIdx.x; + if (blockIdx.x <= total) { + auto x = reinterpret_cast(inputBuf) + inputTadOffsets[idx]; + if (blockIdx.x == start) { + for (auto e = threadIdx.x; e < len; e += blockDim.x) { + auto xIndex = shape::getIndexOffset(e, inputTads); + auto zIndex = shape::getIndexOffset(e, outputTads); + sd::math::atomics::nd4j_atomicAdd(&z[zIndex], + T(x[xIndex] / lengths[segment])); + } + } else { + for (auto e = threadIdx.x; e < len; e += blockDim.x) { + auto xIndex = shape::getIndexOffset(e, inputTads); + auto zIndex = shape::getIndexOffset(e, outputTads); + if (lengths[segment]) + sd::math::atomics::nd4j_atomicAdd(&z[zIndex], + T(x[xIndex] / lengths[segment])); + } } - // -------------------------------------------------------------------------------------------------------------- // - - template - static int unsortedSegmentMeanFunctorBP_(sd::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { - auto stream = context->getCudaStream(); - NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); - auto numClasses = indices->e(indices->lengthOf() - 1) + 1; - NDArray classesRangesLens = NDArrayFactory::create('c', {numClasses}, context); - NDArray classesRangesBegs = NDArrayFactory::create('c', {numClasses}, context); - - classesRangesBegs.assign(indices->lengthOf()); - classesRangesLens.assign(0); - dim3 dims(numClasses, indices->lengthOf(), numClasses * 32 + 32); - fillUpSegments(indices, numClasses, classesRangesBegs, classesRangesLens); - int* begins = reinterpret_cast(classesRangesBegs.specialBuffer()); - int* lengths = reinterpret_cast(classesRangesLens.specialBuffer()); + } +} +// -------------------------------------------------------------------------------------------------------------- +// // segmen mean +template +static void segmentMeanFunctor_(LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* output) { + auto stream = context->getCudaStream(); + Nd4jLong numClasses = indices->e(indices->lengthOf() - 1) + 1; + NDArray classesRangesLens = + NDArrayFactory::create('c', {numClasses}, context); + NDArray classesRangesBegs = + NDArrayFactory::create('c', {numClasses}, context); + + classesRangesBegs.assign(indices->lengthOf()); + classesRangesLens.assign(0); + NDArray::prepareSpecialUse({output}, {input, indices}); + dim3 dims(numClasses, indices->lengthOf(), numClasses * 32 + 32); + int* begins = reinterpret_cast(classesRangesBegs.specialBuffer()); + int* lengths = reinterpret_cast(classesRangesLens.specialBuffer()); + fillUpSegments(indices, numClasses, classesRangesBegs, classesRangesLens); + + if (input->isVector()) { + segmentMeanLinearKernel + <<lengthOf(), numClasses * 32 + 32, *stream>>>( + input->specialBuffer(), input->specialShapeInfo(), begins, lengths, + numClasses, output->specialBuffer(), output->specialShapeInfo()); + } else { + std::vector dimensions = + ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions( + input->shapeInfo(), dimensions); + auto packZ = sd::ConstantTadHelper::getInstance().tadForDimensions( + output->shapeInfo(), dimensions); + auto inputTads = packX.specialShapeInfo(); + auto inputTadOffsets = packX.specialOffsets(); + auto outputTads = packZ.specialShapeInfo(); + auto outputTadOffsets = packZ.specialOffsets(); + segmentMeanTadKernel<<sizeAt(0), 512, 2048, *stream>>>( + input->specialBuffer(), input->specialShapeInfo(), inputTads, + inputTadOffsets, reinterpret_cast(indices->specialBuffer()), begins, + lengths, numClasses, output->specialBuffer(), + output->specialShapeInfo(), outputTads, outputTadOffsets); + } + NDArray::registerSpecialUse({output}, {input, indices}); +} +// -------------------------------------------------------------------------------------------------------------- +// // +void segmentMeanFunctor(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* output) { + NDArray::prepareSpecialUse({output}, {input, indices}); + BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), + segmentMeanFunctor_, (context, input, indices, output), + NUMERIC_TYPES, INDEXING_TYPES); + NDArray::registerSpecialUse({output}, {input, indices}); +} - if (input->isVector()) { - Nd4jLong loop_size = input->lengthOf(); - auto numOfClasses = gradOut->lengthOf(); //indices->e(loop_size - 1); - segmentMeanBPLinearKernel<<lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(), - input->specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(), - indices->specialBuffer(), indices->specialShapeInfo(), lengths, output->specialBuffer(), output->specialShapeInfo()); - } - else { - std::vector dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(input->shapeInfo(), dimensions); - auto packZ = sd::ConstantTadHelper::getInstance().tadForDimensions(output->shapeInfo(), dimensions); -// auto packGradIn = sd::ConstantTadHelper::getInstance().tadForDimensions(tempRes.shapeInfo(), dimensions); - auto packGradOut = sd::ConstantTadHelper::getInstance().tadForDimensions(gradOut->shapeInfo(), dimensions); - Nd4jLong const* inputTads = packX.specialShapeInfo(); - Nd4jLong const* inputTadOffsets = packX.specialOffsets(); - Nd4jLong const* outputTads = packZ.specialShapeInfo(); - Nd4jLong const* outputTadOffsets = packZ.specialOffsets(); - Nd4jLong const* gradOutTads = packGradOut.specialShapeInfo(); - Nd4jLong const* gradOutTadOffsets = packGradOut.specialOffsets(); +// -------------------------------------------------------------------------------------------------------------- +// // +template +static void unsortedSegmentMeanFunctor_(sd::LaunchContext* context, + NDArray* input, NDArray* indices, + Nd4jLong numOfClasses, + NDArray* output) { + auto stream = context->getCudaStream(); + // NDArray classes = NDArrayFactory::create('c', {numOfClasses, + // 2}); + + NDArray classesRangesBegs = + NDArrayFactory::create('c', {numOfClasses}, context); + NDArray classesRangesLens = + NDArrayFactory::create('c', {numOfClasses}, context); + // NDArray row = NDArrayFactory::create('c', {1, 2}, + // {(int)indices->lengthOf(), (int)0}); + // classes.applyTrueBroadcast(sd::BroadcastOpsTuple::Assign(), &row, + // &classes); + classesRangesBegs.assign(indices->lengthOf()); + classesRangesLens.assign(0); + dim3 dims(numOfClasses, indices->lengthOf(), numOfClasses * 32 + 32); + // int* classesBuf = reinterpret_cast(classes.specialBuffer()); + fillUpSegments(indices, numOfClasses, classesRangesBegs, classesRangesLens); + int* begins = reinterpret_cast(classesRangesBegs.specialBuffer()); + int* lengths = reinterpret_cast(classesRangesLens.specialBuffer()); + + if (input->isVector()) { + unsortedSegmentMeanLinearKernel<<>>( + input->specialBuffer(), input->specialShapeInfo(), + indices->specialBuffer(), indices->specialShapeInfo(), begins, lengths, + numOfClasses, output->specialBuffer(), output->specialShapeInfo()); + } else { + output->assign(0); + std::vector dimensions = + ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions( + input->shapeInfo(), dimensions); + auto packZ = sd::ConstantTadHelper::getInstance().tadForDimensions( + output->shapeInfo(), dimensions); + Nd4jLong const* inputTads = packX.specialShapeInfo(); + Nd4jLong const* inputTadOffsets = packX.specialOffsets(); + Nd4jLong const* outputTads = packZ.specialShapeInfo(); + Nd4jLong const* outputTadOffsets = packZ.specialOffsets(); + dims.x = input->sizeAt(0); + segmentMeanTadKernel<<>>( + input->specialBuffer(), input->specialShapeInfo(), inputTads, + inputTadOffsets, reinterpret_cast(indices->specialBuffer()), begins, + lengths, numOfClasses, output->specialBuffer(), + output->specialShapeInfo(), outputTads, outputTadOffsets); + } +} +// -------------------------------------------------------------------------------------------------------------- +// // +void unsortedSegmentMeanFunctor(sd::LaunchContext* context, NDArray* input, + NDArray* indices, Nd4jLong numOfClasses, + NDArray* output) { + NDArray::prepareSpecialUse({output}, {input, indices}); + BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), + unsortedSegmentMeanFunctor_, + (context, input, indices, numOfClasses, output), + NUMERIC_TYPES, INDEXING_TYPES); + NDArray::registerSpecialUse({output}, {input, indices}); +} - segmentMeanBPTadKernel<<lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), - gradOut->specialBuffer(), gradOut->specialShapeInfo(), indices->specialBuffer(), indices->specialShapeInfo(), lengths, - output->specialBuffer(), output->specialShapeInfo(), inputTads, inputTadOffsets, gradOutTads, gradOutTadOffsets, - outputTads, outputTadOffsets); - } - NDArray::registerSpecialUse({output}, {input, indices, gradOut}); - return Status::OK(); - } - // -------------------------------------------------------------------------------------------------------------- // - int unsortedSegmentMeanFunctorBP(sd::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { - NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); - BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return unsortedSegmentMeanFunctorBP_, (context, input, indices, gradOut, numOfClasses, output), FLOAT_TYPES, INDEXING_TYPES); - NDArray::registerSpecialUse({output}, {input, indices, gradOut}); +// -------------------------------------------------------------------------------------------------------------- +// // +template +static __global__ void segmentMeanBPLinearKernel( + void* inputBuf, Nd4jLong const* inputShape, void* eps, + Nd4jLong const* epsShape, void* indicesBuf, Nd4jLong const* indicesShape, + int* lengths, void* outputBuf, Nd4jLong const* outputShape) { + __shared__ T* x; + __shared__ T* gradIn; + __shared__ T* gradOut; + __shared__ I* y; + __shared__ T* z; + __shared__ Nd4jLong xLen, gradLen; + + if (threadIdx.x == 0) { + xLen = shape::length(inputShape); + x = reinterpret_cast(inputBuf); + y = reinterpret_cast(indicesBuf); + z = reinterpret_cast(outputBuf); + gradOut = reinterpret_cast(eps); + gradLen = shape::length(epsShape); + } + __syncthreads(); + + auto start = blockIdx.x * blockDim.x + threadIdx.x; + auto step = gridDim.x * blockDim.x; + + for (auto e = start; e < xLen; e += step) { + auto zOffset = shape::getIndexOffset(e, outputShape); + auto xOffset = shape::getIndexOffset(e, inputShape); + auto yOffset = shape::getIndexOffset(e, indicesShape); + auto classIndex = y[yOffset]; + auto gradOffsetO = shape::getIndexOffset(classIndex, epsShape); + + z[zOffset] = T(gradOut[gradOffsetO] / float(lengths[classIndex])); + } +} +// -------------------------------------------------------------------------------------------------------------- +// // +template +static __global__ void segmentMeanBPTadKernel( + void* inputBuf, Nd4jLong const* inputShape, void* eps, + Nd4jLong const* epsShape, void* indicesBuf, Nd4jLong const* indicesShape, + int* lengths, void* outputBuf, Nd4jLong const* outputShape, + Nd4jLong const* inputTad, Nd4jLong const* inputOffsets, + Nd4jLong const* gradOutTad, Nd4jLong const* gradOutOffsets, + Nd4jLong const* outTad, Nd4jLong const* outOffsets) { + __shared__ T* x; + __shared__ T* gradOut; + __shared__ I* y; + __shared__ T* z; + __shared__ Nd4jLong xLen, yLen, gradLen, currentLen; + + if (threadIdx.x == 0) { + xLen = shape::length(inputShape); + x = reinterpret_cast(inputBuf); + y = reinterpret_cast(indicesBuf); + z = reinterpret_cast(outputBuf); + yLen = shape::length(indicesShape); + gradOut = reinterpret_cast(eps); + gradLen = shape::length(epsShape); + currentLen = shape::length(outTad); + } + __syncthreads(); + + for (auto i = blockIdx.x; i < yLen; i += gridDim.x) { + // auto yIndex = shape::getIndexOffset(i, indicesShape); + auto segment = y[i]; // yIndex]; + T* currentOut = z + outOffsets[i]; + T* outGrad = gradOut + gradOutOffsets[segment]; + + for (auto e = threadIdx.x; e < currentLen; e += blockDim.x) { + auto zIndex = shape::getIndexOffset(e, outTad); + auto gradIndex = shape::getIndexOffset(e, gradOutTad); + if (lengths[segment] > 0) + currentOut[zIndex] = T(outGrad[gradIndex] / float(lengths[segment])); } - + } +} +// -------------------------------------------------------------------------------------------------------------- +// // backrop for mean +template +int segmentMeanFunctorBP_(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* gradOut, NDArray* output) { + auto stream = context->getCudaStream(); + NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); + auto numClasses = indices->e(indices->lengthOf() - 1) + 1; + NDArray classesRangesLens = + NDArrayFactory::create('c', {numClasses}, context); + NDArray classesRangesBegs = + NDArrayFactory::create('c', {numClasses}, context); + + classesRangesBegs.assign(indices->lengthOf()); + classesRangesLens.assign(0); + dim3 dims(numClasses, indices->lengthOf(), numClasses * 32 + 32); + fillUpSegments(indices, numClasses, classesRangesBegs, classesRangesLens); + int* begins = reinterpret_cast(classesRangesBegs.specialBuffer()); + int* lengths = reinterpret_cast(classesRangesLens.specialBuffer()); + + if (input->isVector()) { + Nd4jLong loop_size = input->lengthOf(); + auto numOfClasses = + gradOut->lengthOf(); // indices->e(loop_size - 1); + segmentMeanBPLinearKernel + <<lengthOf(), input->lengthOf(), 256, *stream>>>( + input->specialBuffer(), input->specialShapeInfo(), + gradOut->specialBuffer(), gradOut->specialShapeInfo(), + indices->specialBuffer(), indices->specialShapeInfo(), lengths, + output->specialBuffer(), output->specialShapeInfo()); + } else { + std::vector dimensions = + ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions( + input->shapeInfo(), dimensions); + auto packZ = sd::ConstantTadHelper::getInstance().tadForDimensions( + output->shapeInfo(), dimensions); + // auto packGradIn = + // sd::ConstantTadHelper::getInstance().tadForDimensions(tempRes.shapeInfo(), + // dimensions); + auto packGradOut = sd::ConstantTadHelper::getInstance().tadForDimensions( + gradOut->shapeInfo(), dimensions); + Nd4jLong const* inputTads = packX.specialShapeInfo(); + Nd4jLong const* inputTadOffsets = packX.specialOffsets(); + Nd4jLong const* outputTads = packZ.specialShapeInfo(); + Nd4jLong const* outputTadOffsets = packZ.specialOffsets(); + Nd4jLong const* gradOutTads = packGradOut.specialShapeInfo(); + Nd4jLong const* gradOutTadOffsets = packGradOut.specialOffsets(); + + segmentMeanBPTadKernel + <<lengthOf(), input->lengthOf(), 256, *stream>>>( + input->specialBuffer(), input->specialShapeInfo(), + gradOut->specialBuffer(), gradOut->specialShapeInfo(), + indices->specialBuffer(), indices->specialShapeInfo(), lengths, + output->specialBuffer(), output->specialShapeInfo(), inputTads, + inputTadOffsets, gradOutTads, gradOutTadOffsets, outputTads, + outputTadOffsets); + } + NDArray::registerSpecialUse({output}, {input, indices, gradOut}); + return Status::OK(); } +// -------------------------------------------------------------------------------------------------------------- +// // segmen mean bp main +int segmentMeanFunctorBP(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* gradOut, NDArray* output) { + NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); + BUILD_DOUBLE_SELECTOR( + output->dataType(), indices->dataType(), return segmentMeanFunctorBP_, + (context, input, indices, gradOut, output), FLOAT_TYPES, INDEXING_TYPES); + NDArray::registerSpecialUse({output}, {input, indices, gradOut}); } -} \ No newline at end of file +// -------------------------------------------------------------------------------------------------------------- +// // + +template +static int unsortedSegmentMeanFunctorBP_(sd::LaunchContext* context, + NDArray* input, NDArray* indices, + NDArray* gradOut, + Nd4jLong numOfClasses, + NDArray* output) { + auto stream = context->getCudaStream(); + NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); + auto numClasses = indices->e(indices->lengthOf() - 1) + 1; + NDArray classesRangesLens = + NDArrayFactory::create('c', {numClasses}, context); + NDArray classesRangesBegs = + NDArrayFactory::create('c', {numClasses}, context); + + classesRangesBegs.assign(indices->lengthOf()); + classesRangesLens.assign(0); + dim3 dims(numClasses, indices->lengthOf(), numClasses * 32 + 32); + fillUpSegments(indices, numClasses, classesRangesBegs, classesRangesLens); + int* begins = reinterpret_cast(classesRangesBegs.specialBuffer()); + int* lengths = reinterpret_cast(classesRangesLens.specialBuffer()); + + if (input->isVector()) { + Nd4jLong loop_size = input->lengthOf(); + auto numOfClasses = + gradOut->lengthOf(); // indices->e(loop_size - 1); + segmentMeanBPLinearKernel + <<lengthOf(), input->lengthOf(), 256, *stream>>>( + input->specialBuffer(), input->specialShapeInfo(), + gradOut->specialBuffer(), gradOut->specialShapeInfo(), + indices->specialBuffer(), indices->specialShapeInfo(), lengths, + output->specialBuffer(), output->specialShapeInfo()); + } else { + std::vector dimensions = + ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions( + input->shapeInfo(), dimensions); + auto packZ = sd::ConstantTadHelper::getInstance().tadForDimensions( + output->shapeInfo(), dimensions); + // auto packGradIn = + // sd::ConstantTadHelper::getInstance().tadForDimensions(tempRes.shapeInfo(), + // dimensions); + auto packGradOut = sd::ConstantTadHelper::getInstance().tadForDimensions( + gradOut->shapeInfo(), dimensions); + Nd4jLong const* inputTads = packX.specialShapeInfo(); + Nd4jLong const* inputTadOffsets = packX.specialOffsets(); + Nd4jLong const* outputTads = packZ.specialShapeInfo(); + Nd4jLong const* outputTadOffsets = packZ.specialOffsets(); + Nd4jLong const* gradOutTads = packGradOut.specialShapeInfo(); + Nd4jLong const* gradOutTadOffsets = packGradOut.specialOffsets(); + + segmentMeanBPTadKernel + <<lengthOf(), input->lengthOf(), 256, *stream>>>( + input->specialBuffer(), input->specialShapeInfo(), + gradOut->specialBuffer(), gradOut->specialShapeInfo(), + indices->specialBuffer(), indices->specialShapeInfo(), lengths, + output->specialBuffer(), output->specialShapeInfo(), inputTads, + inputTadOffsets, gradOutTads, gradOutTadOffsets, outputTads, + outputTadOffsets); + } + NDArray::registerSpecialUse({output}, {input, indices, gradOut}); + return Status::OK(); +} +// -------------------------------------------------------------------------------------------------------------- +// // +int unsortedSegmentMeanFunctorBP(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* gradOut, + Nd4jLong numOfClasses, NDArray* output) { + NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); + BUILD_DOUBLE_SELECTOR( + output->dataType(), indices->dataType(), + return unsortedSegmentMeanFunctorBP_, + (context, input, indices, gradOut, numOfClasses, output), FLOAT_TYPES, + INDEXING_TYPES); + NDArray::registerSpecialUse({output}, {input, indices, gradOut}); +} + +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/segment_min.cu b/libnd4j/include/ops/declarable/helpers/cuda/segment_min.cu index 9e825c701dbd..d6cd53101e7d 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/segment_min.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/segment_min.cu @@ -18,411 +18,493 @@ // @author GS // -#include -#include #include -#include -#include #include -#include #include +#include +#include +#include +#include +#include namespace sd { namespace ops { namespace helpers { - // -------------------------------------------------------------------------------------------------------------- // - // Segment ops linear kernels - // -------------------------------------------------------------------------------------------------------------- // - - template - static __global__ void - segmentMinLinearKernel(const void *input, const Nd4jLong *inputShape, int *starts, int *lengths, Nd4jLong numOfClasses, - void *output, const Nd4jLong *outputShape) { - __shared__ T *val; - __shared__ Nd4jLong xLen, zLen, zIndex; - __shared__ const T *x; - __shared__ T *z; - __shared__ int threadsPerSegment, start, finish; - - auto segment = blockIdx.x; - if (threadIdx.x == 0) { -// threadsPerSegment = (gridDim.x + numOfClasses - 1) / numOfClasses; -// segment = blockIdx.x / threadsPerSegment; - x = reinterpret_cast(input); - z = reinterpret_cast(output); - extern __shared__ unsigned char shmem[]; - val = reinterpret_cast(shmem); - xLen = shape::length(inputShape); - zLen = shape::length(outputShape); - - if (segment < numOfClasses) { - zIndex = shape::getIndexOffset(segment, outputShape); - start = starts[segment]; - finish = start + lengths[segment]; - z[zIndex] = x[shape::getIndexOffset(start, inputShape)]; - val[segment] = z[zIndex]; - } - - } - __syncthreads(); - - for (auto e = start + threadIdx.x + 1; e < finish; e += blockDim.x) { - auto xIndex = shape::getIndexOffset(e, inputShape); - sd::math::atomics::nd4j_atomicMin(&z[zIndex], x[xIndex]); - } - +// -------------------------------------------------------------------------------------------------------------- +// // Segment ops linear kernels +// -------------------------------------------------------------------------------------------------------------- +// // + +template +static __global__ void segmentMinLinearKernel( + const void* input, const Nd4jLong* inputShape, int* starts, int* lengths, + Nd4jLong numOfClasses, void* output, const Nd4jLong* outputShape) { + __shared__ T* val; + __shared__ Nd4jLong xLen, zLen, zIndex; + __shared__ const T* x; + __shared__ T* z; + __shared__ int threadsPerSegment, start, finish; + + auto segment = blockIdx.x; + if (threadIdx.x == 0) { + // threadsPerSegment = (gridDim.x + numOfClasses - 1) / + // numOfClasses; segment = blockIdx.x / threadsPerSegment; + x = reinterpret_cast(input); + z = reinterpret_cast(output); + extern __shared__ unsigned char shmem[]; + val = reinterpret_cast(shmem); + xLen = shape::length(inputShape); + zLen = shape::length(outputShape); + + if (segment < numOfClasses) { + zIndex = shape::getIndexOffset(segment, outputShape); + start = starts[segment]; + finish = start + lengths[segment]; + z[zIndex] = x[shape::getIndexOffset(start, inputShape)]; + val[segment] = z[zIndex]; } - // -------------------------------------------------------------------------------------------------------------- // - - template - static __global__ void - unsortedSegmentMinLinearKernel(const void *input, const Nd4jLong *inputShape, const void *indices, const Nd4jLong *indicesShape, - int *starts, int *lengths, Nd4jLong numOfClasses, void *output, - const Nd4jLong *outputShape) { - __shared__ - T *val; - __shared__ - Nd4jLong xLen, zLen, segment, zIndex; - __shared__ - const T *x; - __shared__ - T *z; - __shared__ - const I *y; //int threadsPerSegment, start, finish; - - if (threadIdx.x == 0) { - segment = blockIdx.x; - x = reinterpret_cast(input); - z = reinterpret_cast(output); - y = reinterpret_cast(indices); - xLen = shape::length(inputShape); - zLen = shape::length(outputShape); - - zIndex = shape::getIndexOffset(segment, outputShape); - if (lengths[segment] > 0) - z[zIndex] = x[shape::getIndexOffset(starts[segment], inputShape)]; - else - z[zIndex] = DataTypeUtils::max(); + } + __syncthreads(); - } - __syncthreads(); - - if (lengths[segment] > 0) - for (auto e = threadIdx.x + 1; e < xLen; e += blockDim.x) { - auto xIndex = shape::getIndexOffset(e, inputShape); - auto yIndex = shape::getIndexOffset(e, indicesShape); - if (y[yIndex] == segment) { - sd::math::atomics::nd4j_atomicMin(&z[zIndex], x[xIndex]); - } - } + for (auto e = start + threadIdx.x + 1; e < finish; e += blockDim.x) { + auto xIndex = shape::getIndexOffset(e, inputShape); + sd::math::atomics::nd4j_atomicMin(&z[zIndex], x[xIndex]); + } +} +// -------------------------------------------------------------------------------------------------------------- +// // + +template +static __global__ void unsortedSegmentMinLinearKernel( + const void* input, const Nd4jLong* inputShape, const void* indices, + const Nd4jLong* indicesShape, int* starts, int* lengths, + Nd4jLong numOfClasses, void* output, const Nd4jLong* outputShape) { + __shared__ T* val; + __shared__ Nd4jLong xLen, zLen, segment, zIndex; + __shared__ const T* x; + __shared__ T* z; + __shared__ const I* y; // int threadsPerSegment, start, finish; + + if (threadIdx.x == 0) { + segment = blockIdx.x; + x = reinterpret_cast(input); + z = reinterpret_cast(output); + y = reinterpret_cast(indices); + xLen = shape::length(inputShape); + zLen = shape::length(outputShape); + + zIndex = shape::getIndexOffset(segment, outputShape); + if (lengths[segment] > 0) + z[zIndex] = x[shape::getIndexOffset(starts[segment], inputShape)]; + else + z[zIndex] = DataTypeUtils::max(); + } + __syncthreads(); + + if (lengths[segment] > 0) + for (auto e = threadIdx.x + 1; e < xLen; e += blockDim.x) { + auto xIndex = shape::getIndexOffset(e, inputShape); + auto yIndex = shape::getIndexOffset(e, indicesShape); + if (y[yIndex] == segment) { + sd::math::atomics::nd4j_atomicMin(&z[zIndex], x[xIndex]); + } } - // -------------------------------------------------------------------------------------------------------------- // +} +// -------------------------------------------------------------------------------------------------------------- +// // // SegmentMin kernel - template - static __global__ void segmentMinTadKernel(const void* inputBuf, const Nd4jLong* inputShape, const Nd4jLong* inputTads, const Nd4jLong* inputTadOffsets, I* indices, int* starts, int* lengths, Nd4jLong numOfClasses, void* outputBuf, const Nd4jLong* outputShape, const Nd4jLong* outputTads, const Nd4jLong* outputTadOffsets) { - __shared__ T* val; - __shared__ Nd4jLong len, zIndex, total; - __shared__ T* z; - __shared__ int threadsPerSegment, start, finish; - - auto segment = indices[blockIdx.x]; // / threadsPerSegment; - if (threadIdx.x == 0) { - z = reinterpret_cast(outputBuf) + outputTadOffsets[segment]; - len = shape::length(inputTads); - start = starts[segment]; - finish = start + lengths[segment]; - total = shape::sizeAt(inputShape, 0); - - } - __syncthreads(); - - auto idx = blockIdx.x; - if (blockIdx.x <= total) { - auto x = reinterpret_cast(inputBuf) + inputTadOffsets[idx]; - if (blockIdx.x == start) { - for (auto e = threadIdx.x; e < len; e += blockDim.x) { - auto xIndex = shape::getIndexOffset(e, inputTads); - auto zIndex = shape::getIndexOffset(e, outputTads); - sd::math::atomics::nd4j_atomicMin(&z[zIndex], x[xIndex]); - } - } - else { - for (auto e = threadIdx.x; e < len; e += blockDim.x) { - auto xIndex = shape::getIndexOffset(e, inputTads); - auto zIndex = shape::getIndexOffset(e, outputTads); -// if (lengths[indices[idx]]) - sd::math::atomics::nd4j_atomicMin(&z[zIndex], x[xIndex]); - } - } - } - } - // -------------------------------------------------------------------------------------------------------------- // - // segmen min - template - static void segmentMinFunctor_(LaunchContext* context, NDArray* input, NDArray* indices, NDArray* output) { - auto stream = context->getCudaStream(); - Nd4jLong numClasses = indices->e(indices->lengthOf() - 1) + 1; - auto classesRangesLens = NDArrayFactory::create('c', {numClasses}, context); - auto classesRangesBegs = NDArrayFactory::create('c', {numClasses}, context); - output->assign(DataTypeUtils::infOrMax()); - classesRangesBegs.assign(indices->lengthOf()); - classesRangesLens.assign(0); - - fillUpSegments(indices, numClasses, classesRangesBegs, classesRangesLens); - NDArray::prepareSpecialUse({output}, {input, indices, &classesRangesBegs, &classesRangesLens}); - int* begins = reinterpret_cast(classesRangesBegs.specialBuffer()); - int* lengths = reinterpret_cast(classesRangesLens.specialBuffer()); - if (input->isVector()) { - segmentMinLinearKernel<<lengthOf(), numClasses * 32 + 32, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), begins, lengths, numClasses, output->specialBuffer(), output->specialShapeInfo()); - } - else { - std::vector dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(input->shapeInfo(), dimensions); - auto packZ = sd::ConstantTadHelper::getInstance().tadForDimensions(output->shapeInfo(), dimensions); - auto inputTads = packX.specialShapeInfo(); - auto inputTadOffsets = packX.specialOffsets(); - auto outputTads = packZ.specialShapeInfo(); - auto outputTadOffsets = packZ.specialOffsets(); - segmentMinTadKernel<<sizeAt(0), 512, 2048, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), inputTads, inputTadOffsets, reinterpret_cast(indices->specialBuffer()), begins, lengths, numClasses, output->specialBuffer(), output->specialShapeInfo(), outputTads, outputTadOffsets); - - } - NDArray::registerSpecialUse({output}, {input, indices, &classesRangesBegs, &classesRangesLens}); - +template +static __global__ void segmentMinTadKernel( + const void* inputBuf, const Nd4jLong* inputShape, const Nd4jLong* inputTads, + const Nd4jLong* inputTadOffsets, I* indices, int* starts, int* lengths, + Nd4jLong numOfClasses, void* outputBuf, const Nd4jLong* outputShape, + const Nd4jLong* outputTads, const Nd4jLong* outputTadOffsets) { + __shared__ T* val; + __shared__ Nd4jLong len, zIndex, total; + __shared__ T* z; + __shared__ int threadsPerSegment, start, finish; + + auto segment = indices[blockIdx.x]; // / threadsPerSegment; + if (threadIdx.x == 0) { + z = reinterpret_cast(outputBuf) + outputTadOffsets[segment]; + len = shape::length(inputTads); + start = starts[segment]; + finish = start + lengths[segment]; + total = shape::sizeAt(inputShape, 0); + } + __syncthreads(); + + auto idx = blockIdx.x; + if (blockIdx.x <= total) { + auto x = reinterpret_cast(inputBuf) + inputTadOffsets[idx]; + if (blockIdx.x == start) { + for (auto e = threadIdx.x; e < len; e += blockDim.x) { + auto xIndex = shape::getIndexOffset(e, inputTads); + auto zIndex = shape::getIndexOffset(e, outputTads); + sd::math::atomics::nd4j_atomicMin(&z[zIndex], x[xIndex]); + } + } else { + for (auto e = threadIdx.x; e < len; e += blockDim.x) { + auto xIndex = shape::getIndexOffset(e, inputTads); + auto zIndex = shape::getIndexOffset(e, outputTads); + // if (lengths[indices[idx]]) + sd::math::atomics::nd4j_atomicMin(&z[zIndex], x[xIndex]); + } } - // -------------------------------------------------------------------------------------------------------------- // - void segmentMinFunctor(sd::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* output) { - NDArray::prepareSpecialUse({output}, {input, indices}); - output->nullify(); - BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), segmentMinFunctor_, (context, input, indices, output), NUMERIC_TYPES, INDEXING_TYPES); - NDArray::registerSpecialUse({output}, {input, indices}); - } - - // -------------------------------------------------------------------------------------------------------------- // - - template - static void unsortedSegmentMinFunctor_(sd::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { - auto stream = context->getCudaStream(); -// NDArray classes = NDArrayFactory::create('c', {numOfClasses, 2}); - NDArray classesRangesBegs = NDArrayFactory::create('c', {numOfClasses}, context); - NDArray classesRangesLens = NDArrayFactory::create('c', {numOfClasses}, context); -// NDArray row = NDArrayFactory::create('c', {1, 2}, {(int)indices->lengthOf(), (int)0}); -// classes.applyTrueBroadcast(sd::BroadcastOpsTuple::Assign(), &row, &classes); - output->assign(DataTypeUtils::infOrMax()); - classesRangesBegs.assign(indices->lengthOf()); - classesRangesLens.assign(0); - dim3 dims(numOfClasses, indices->lengthOf(), numOfClasses * 32 + 32); -// int* classesBuf = reinterpret_cast(classes.specialBuffer()); - fillUpSegments(indices, numOfClasses, classesRangesBegs, classesRangesLens); - int* begins = reinterpret_cast(classesRangesBegs.specialBuffer()); - int* lengths = reinterpret_cast(classesRangesLens.specialBuffer()); - NDArray::prepareSpecialUse({output}, {input, indices}); - if (input->isVector()) { - unsortedSegmentMinLinearKernel<<>>(input->specialBuffer(), input->specialShapeInfo(), indices->specialBuffer(), indices->specialShapeInfo(), begins, lengths, numOfClasses, output->specialBuffer(), output->specialShapeInfo()); - } - else { - output->assign(DataTypeUtils::max()); - std::vector dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(input->shapeInfo(), dimensions); - auto packZ = sd::ConstantTadHelper::getInstance().tadForDimensions(output->shapeInfo(), dimensions); - auto inputTads = packX.specialShapeInfo(); - auto inputTadOffsets = packX.specialOffsets(); - auto outputTads = packZ.specialShapeInfo(); - auto outputTadOffsets = packZ.specialOffsets(); - dims.x = input->sizeAt(0); - segmentMinTadKernel<<>>(input->specialBuffer(), input->specialShapeInfo(), inputTads, inputTadOffsets, reinterpret_cast(indices->specialBuffer()), begins, lengths, numOfClasses, output->specialBuffer(), output->specialShapeInfo(), outputTads, outputTadOffsets); - } - NDArray::registerSpecialUse({output}, {input, indices}); - - } - // -------------------------------------------------------------------------------------------------------------- // - void unsortedSegmentMinFunctor(sd::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { - NDArray::prepareSpecialUse({output}, {input, indices}); - output->nullify(); - BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), unsortedSegmentMinFunctor_, (context, input, indices, numOfClasses, output), - NUMERIC_TYPES, INDEXING_TYPES); - NDArray::registerSpecialUse({output}, {input, indices}); - } - - template - static __global__ void segmentMinBPLinearKernel(const void* inputBuf, const Nd4jLong* inputShape, void* forwardOutput, - const Nd4jLong* forwardShape, void* eps, const Nd4jLong* epsShape, const void* indicesBuf, const Nd4jLong* indicesShape, - void* outputBuf, const Nd4jLong* outputShape) { - __shared__ const T* x; - __shared__ T* gradIn; - __shared__ T* gradOut; - __shared__ const I* y; - __shared__ T* z; - __shared__ Nd4jLong xLen, gradLen; - - if (threadIdx.x == 0) { - xLen = shape::length(inputShape); - x = reinterpret_cast(inputBuf); - y = reinterpret_cast(indicesBuf); - z = reinterpret_cast(outputBuf); - gradIn = reinterpret_cast(forwardOutput); - gradOut = reinterpret_cast(eps); - gradLen = shape::length(epsShape); - } - __syncthreads(); - - auto start = blockIdx.x * blockDim.x + threadIdx.x; - auto step = gridDim.x * blockDim.x; - - for (auto e = start; e < xLen; e += step) { - - auto zOffset = shape::getIndexOffset(e, outputShape); - auto xOffset = shape::getIndexOffset(e, inputShape); - auto yOffset = shape::getIndexOffset(e, indicesShape); - auto classIndex = y[yOffset]; - auto gradOffsetI = shape::getIndexOffset(classIndex, forwardShape); - auto gradOffsetO = shape::getIndexOffset(classIndex, epsShape); - - if (sd::math::nd4j_abs(gradIn[gradOffsetI] - x[xOffset]) <= T(1.e-6)) { - z[zOffset] = gradOut[gradOffsetO]; - } - } - } - - // -------------------------------------------------------------------------------------------------------------- // - template - static __global__ void segmentMinBPTadKernel(const void* inputBuf, const Nd4jLong* inputShape, void* forwardOutput, - const Nd4jLong* forwardShape, void* eps, const Nd4jLong* epsShape, - const void* indicesBuf, const Nd4jLong* indicesShape, - void* outputBuf, const Nd4jLong* outputShape, - const Nd4jLong* inputTad, const Nd4jLong* inputOffsets, - const Nd4jLong* gradInTad, const Nd4jLong* gradInOffsets, - const Nd4jLong* gradOutTad, const Nd4jLong* gradOutOffsets, - const Nd4jLong* outTad, const Nd4jLong* outOffsets) { - __shared__ const T* x; - __shared__ T* gradIn; - __shared__ T* gradOut; - __shared__ const I* y; - __shared__ T* z; - __shared__ Nd4jLong xLen, yLen, gradLen, currentLen; - - if (threadIdx.x == 0) { - xLen = shape::length(inputShape); - x = reinterpret_cast(inputBuf); - y = reinterpret_cast(indicesBuf); - z = reinterpret_cast(outputBuf); - yLen = shape::length(indicesShape); - gradOut = reinterpret_cast(eps); - gradIn = reinterpret_cast(forwardOutput); - gradLen = shape::length(epsShape); - currentLen = shape::length(outTad); - } - __syncthreads(); + } +} +// -------------------------------------------------------------------------------------------------------------- +// // segmen min +template +static void segmentMinFunctor_(LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* output) { + auto stream = context->getCudaStream(); + Nd4jLong numClasses = indices->e(indices->lengthOf() - 1) + 1; + auto classesRangesLens = + NDArrayFactory::create('c', {numClasses}, context); + auto classesRangesBegs = + NDArrayFactory::create('c', {numClasses}, context); + output->assign(DataTypeUtils::infOrMax()); + classesRangesBegs.assign(indices->lengthOf()); + classesRangesLens.assign(0); + + fillUpSegments(indices, numClasses, classesRangesBegs, classesRangesLens); + NDArray::prepareSpecialUse( + {output}, {input, indices, &classesRangesBegs, &classesRangesLens}); + int* begins = reinterpret_cast(classesRangesBegs.specialBuffer()); + int* lengths = reinterpret_cast(classesRangesLens.specialBuffer()); + if (input->isVector()) { + segmentMinLinearKernel + <<lengthOf(), numClasses * 32 + 32, *stream>>>( + input->specialBuffer(), input->specialShapeInfo(), begins, lengths, + numClasses, output->specialBuffer(), output->specialShapeInfo()); + } else { + std::vector dimensions = + ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions( + input->shapeInfo(), dimensions); + auto packZ = sd::ConstantTadHelper::getInstance().tadForDimensions( + output->shapeInfo(), dimensions); + auto inputTads = packX.specialShapeInfo(); + auto inputTadOffsets = packX.specialOffsets(); + auto outputTads = packZ.specialShapeInfo(); + auto outputTadOffsets = packZ.specialOffsets(); + segmentMinTadKernel<<sizeAt(0), 512, 2048, *stream>>>( + input->specialBuffer(), input->specialShapeInfo(), inputTads, + inputTadOffsets, reinterpret_cast(indices->specialBuffer()), begins, + lengths, numClasses, output->specialBuffer(), + output->specialShapeInfo(), outputTads, outputTadOffsets); + } + NDArray::registerSpecialUse( + {output}, {input, indices, &classesRangesBegs, &classesRangesLens}); +} +// -------------------------------------------------------------------------------------------------------------- +// // +void segmentMinFunctor(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* output) { + NDArray::prepareSpecialUse({output}, {input, indices}); + output->nullify(); + BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), + segmentMinFunctor_, (context, input, indices, output), + NUMERIC_TYPES, INDEXING_TYPES); + NDArray::registerSpecialUse({output}, {input, indices}); +} - for (auto i = blockIdx.x; i < yLen; i += gridDim.x) { - auto yIndex = shape::getIndexOffset(i, indicesShape); - auto segment = y[yIndex]; - auto current = x + inputOffsets[i]; - auto currentOut = z + outOffsets[i]; - auto in = gradIn + gradInOffsets[segment]; - auto outGrad = gradOut + gradOutOffsets[segment]; +// -------------------------------------------------------------------------------------------------------------- +// // + +template +static void unsortedSegmentMinFunctor_(sd::LaunchContext* context, + NDArray* input, NDArray* indices, + Nd4jLong numOfClasses, NDArray* output) { + auto stream = context->getCudaStream(); + // NDArray classes = NDArrayFactory::create('c', {numOfClasses, + // 2}); + NDArray classesRangesBegs = + NDArrayFactory::create('c', {numOfClasses}, context); + NDArray classesRangesLens = + NDArrayFactory::create('c', {numOfClasses}, context); + // NDArray row = NDArrayFactory::create('c', {1, 2}, + // {(int)indices->lengthOf(), (int)0}); + // classes.applyTrueBroadcast(sd::BroadcastOpsTuple::Assign(), &row, + // &classes); + output->assign(DataTypeUtils::infOrMax()); + classesRangesBegs.assign(indices->lengthOf()); + classesRangesLens.assign(0); + dim3 dims(numOfClasses, indices->lengthOf(), numOfClasses * 32 + 32); + // int* classesBuf = reinterpret_cast(classes.specialBuffer()); + fillUpSegments(indices, numOfClasses, classesRangesBegs, classesRangesLens); + int* begins = reinterpret_cast(classesRangesBegs.specialBuffer()); + int* lengths = reinterpret_cast(classesRangesLens.specialBuffer()); + NDArray::prepareSpecialUse({output}, {input, indices}); + if (input->isVector()) { + unsortedSegmentMinLinearKernel<<>>( + input->specialBuffer(), input->specialShapeInfo(), + indices->specialBuffer(), indices->specialShapeInfo(), begins, lengths, + numOfClasses, output->specialBuffer(), output->specialShapeInfo()); + } else { + output->assign(DataTypeUtils::max()); + std::vector dimensions = + ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions( + input->shapeInfo(), dimensions); + auto packZ = sd::ConstantTadHelper::getInstance().tadForDimensions( + output->shapeInfo(), dimensions); + auto inputTads = packX.specialShapeInfo(); + auto inputTadOffsets = packX.specialOffsets(); + auto outputTads = packZ.specialShapeInfo(); + auto outputTadOffsets = packZ.specialOffsets(); + dims.x = input->sizeAt(0); + segmentMinTadKernel<<>>( + input->specialBuffer(), input->specialShapeInfo(), inputTads, + inputTadOffsets, reinterpret_cast(indices->specialBuffer()), begins, + lengths, numOfClasses, output->specialBuffer(), + output->specialShapeInfo(), outputTads, outputTadOffsets); + } + NDArray::registerSpecialUse({output}, {input, indices}); +} +// -------------------------------------------------------------------------------------------------------------- +// // +void unsortedSegmentMinFunctor(sd::LaunchContext* context, NDArray* input, + NDArray* indices, Nd4jLong numOfClasses, + NDArray* output) { + NDArray::prepareSpecialUse({output}, {input, indices}); + output->nullify(); + BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), + unsortedSegmentMinFunctor_, + (context, input, indices, numOfClasses, output), + NUMERIC_TYPES, INDEXING_TYPES); + NDArray::registerSpecialUse({output}, {input, indices}); +} - for (auto e = threadIdx.x; e < currentLen; e += blockDim.x) { - if (sd::math::nd4j_abs(in[e] - current[e]) <= T(1.e-6)) - currentOut[e] = outGrad[e]; - } - } +template +static __global__ void segmentMinBPLinearKernel( + const void* inputBuf, const Nd4jLong* inputShape, void* forwardOutput, + const Nd4jLong* forwardShape, void* eps, const Nd4jLong* epsShape, + const void* indicesBuf, const Nd4jLong* indicesShape, void* outputBuf, + const Nd4jLong* outputShape) { + __shared__ const T* x; + __shared__ T* gradIn; + __shared__ T* gradOut; + __shared__ const I* y; + __shared__ T* z; + __shared__ Nd4jLong xLen, gradLen; + + if (threadIdx.x == 0) { + xLen = shape::length(inputShape); + x = reinterpret_cast(inputBuf); + y = reinterpret_cast(indicesBuf); + z = reinterpret_cast(outputBuf); + gradIn = reinterpret_cast(forwardOutput); + gradOut = reinterpret_cast(eps); + gradLen = shape::length(epsShape); + } + __syncthreads(); + + auto start = blockIdx.x * blockDim.x + threadIdx.x; + auto step = gridDim.x * blockDim.x; + + for (auto e = start; e < xLen; e += step) { + auto zOffset = shape::getIndexOffset(e, outputShape); + auto xOffset = shape::getIndexOffset(e, inputShape); + auto yOffset = shape::getIndexOffset(e, indicesShape); + auto classIndex = y[yOffset]; + auto gradOffsetI = shape::getIndexOffset(classIndex, forwardShape); + auto gradOffsetO = shape::getIndexOffset(classIndex, epsShape); + + if (sd::math::nd4j_abs(gradIn[gradOffsetI] - x[xOffset]) <= T(1.e-6)) { + z[zOffset] = gradOut[gradOffsetO]; } + } +} - // -------------------------------------------------------------------------------------------------------------- // - template - int segmentMinFunctorBP_(sd::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) { - //int numOfClasses = gradOut->sizeAt(0); - // if input is a vector: (as if in doc sample) - auto stream = context->getCudaStream(); - NDArray tempRes(gradOut->ordering(), gradOut->getShapeAsVector(), DataTypeUtils::fromT(), context);//->shapeInfo(), context); - segmentMinFunctor_(context, input, indices, &tempRes); - NDArray::prepareSpecialUse({output}, {input, indices, gradOut, &tempRes}); - if (input->isVector()) { - Nd4jLong loop_size = input->lengthOf(); - auto numOfClasses = gradOut->lengthOf(); //indices->e(loop_size - 1); - - segmentMinBPLinearKernel<<lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), - tempRes.specialBuffer(), tempRes.specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(), - indices->specialBuffer(), indices->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo()); - } - else { - std::vector dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(input->shapeInfo(), dimensions); - auto packZ = sd::ConstantTadHelper::getInstance().tadForDimensions(output->shapeInfo(), dimensions); - auto packGradIn = sd::ConstantTadHelper::getInstance().tadForDimensions(tempRes.shapeInfo(), dimensions); - auto packGradOut = sd::ConstantTadHelper::getInstance().tadForDimensions(gradOut->shapeInfo(), dimensions); - auto inputTads = packX.specialShapeInfo(); - auto inputTadOffsets = packX.specialOffsets(); - auto outputTads = packZ.specialShapeInfo(); - auto outputTadOffsets = packZ.specialOffsets(); - auto gradInTads = packGradIn.specialShapeInfo(); - auto gradInTadOffsets = packGradIn.specialOffsets(); - auto gradOutTads = packGradOut.specialShapeInfo(); - auto gradOutTadOffsets = packGradOut.specialOffsets(); - - segmentMinBPTadKernel<<lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), - tempRes.specialBuffer(), tempRes.specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(), - indices->specialBuffer(), indices->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), - inputTads, inputTadOffsets, gradInTads, gradInTadOffsets, gradOutTads, gradOutTadOffsets, - outputTads, outputTadOffsets); - } - NDArray::registerSpecialUse({output}, {input, indices, gradOut, &tempRes}); - return Status::OK(); - } - // -------------------------------------------------------------------------------------------------------------- // - // segmen min - int segmentMinFunctorBP(sd::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) { - NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); - BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return segmentMinFunctorBP_, (context, input, - indices, gradOut, output), FLOAT_TYPES, INDEXING_TYPES); - NDArray::registerSpecialUse({output}, {input, indices, gradOut}); +// -------------------------------------------------------------------------------------------------------------- +// // +template +static __global__ void segmentMinBPTadKernel( + const void* inputBuf, const Nd4jLong* inputShape, void* forwardOutput, + const Nd4jLong* forwardShape, void* eps, const Nd4jLong* epsShape, + const void* indicesBuf, const Nd4jLong* indicesShape, void* outputBuf, + const Nd4jLong* outputShape, const Nd4jLong* inputTad, + const Nd4jLong* inputOffsets, const Nd4jLong* gradInTad, + const Nd4jLong* gradInOffsets, const Nd4jLong* gradOutTad, + const Nd4jLong* gradOutOffsets, const Nd4jLong* outTad, + const Nd4jLong* outOffsets) { + __shared__ const T* x; + __shared__ T* gradIn; + __shared__ T* gradOut; + __shared__ const I* y; + __shared__ T* z; + __shared__ Nd4jLong xLen, yLen, gradLen, currentLen; + + if (threadIdx.x == 0) { + xLen = shape::length(inputShape); + x = reinterpret_cast(inputBuf); + y = reinterpret_cast(indicesBuf); + z = reinterpret_cast(outputBuf); + yLen = shape::length(indicesShape); + gradOut = reinterpret_cast(eps); + gradIn = reinterpret_cast(forwardOutput); + gradLen = shape::length(epsShape); + currentLen = shape::length(outTad); + } + __syncthreads(); + + for (auto i = blockIdx.x; i < yLen; i += gridDim.x) { + auto yIndex = shape::getIndexOffset(i, indicesShape); + auto segment = y[yIndex]; + auto current = x + inputOffsets[i]; + auto currentOut = z + outOffsets[i]; + auto in = gradIn + gradInOffsets[segment]; + auto outGrad = gradOut + gradOutOffsets[segment]; + + for (auto e = threadIdx.x; e < currentLen; e += blockDim.x) { + if (sd::math::nd4j_abs(in[e] - current[e]) <= T(1.e-6)) + currentOut[e] = outGrad[e]; } + } +} - template - static int unsortedSegmentMinFunctorBP_(sd::LaunchContext* context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { - //int numOfClasses = gradOut->sizeAt(0); - // if input is a vector: (as if in doc sample) - auto stream = context->getCudaStream(); - NDArray tempRes(gradOut->ordering(), gradOut->getShapeAsVector(), DataTypeUtils::fromT(), context);//->shapeInfo(), context); - unsortedSegmentMinFunctor_(context, input, indices, numOfClasses, &tempRes); - NDArray::prepareSpecialUse({output}, {input, indices, gradOut, &tempRes}); - if (input->isVector()) { - Nd4jLong loop_size = input->lengthOf(); - auto numOfClasses = gradOut->lengthOf(); //indices->e(loop_size - 1); - segmentMinBPLinearKernel<<lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), - tempRes.specialBuffer(), tempRes.specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(), - indices->specialBuffer(), indices->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo()); - } - else { - std::vector dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(input->shapeInfo(), dimensions); - auto packZ = sd::ConstantTadHelper::getInstance().tadForDimensions(output->shapeInfo(), dimensions); - auto packGradIn = sd::ConstantTadHelper::getInstance().tadForDimensions(tempRes.shapeInfo(), dimensions); - auto packGradOut = sd::ConstantTadHelper::getInstance().tadForDimensions(gradOut->shapeInfo(), dimensions); - auto inputTads = packX.specialShapeInfo(); - auto inputTadOffsets = packX.specialOffsets(); - auto outputTads = packZ.specialShapeInfo(); - auto outputTadOffsets = packZ.specialOffsets(); - auto gradInTads = packGradIn.specialShapeInfo(); - auto gradInTadOffsets = packGradIn.specialOffsets(); - auto gradOutTads = packGradOut.specialShapeInfo(); - auto gradOutTadOffsets = packGradOut.specialOffsets(); +// -------------------------------------------------------------------------------------------------------------- +// // +template +int segmentMinFunctorBP_(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* gradOut, NDArray* output) { + // int numOfClasses = gradOut->sizeAt(0); + // if input is a vector: (as if in doc sample) + auto stream = context->getCudaStream(); + NDArray tempRes(gradOut->ordering(), gradOut->getShapeAsVector(), + DataTypeUtils::fromT(), + context); //->shapeInfo(), context); + segmentMinFunctor_(context, input, indices, &tempRes); + NDArray::prepareSpecialUse({output}, {input, indices, gradOut, &tempRes}); + if (input->isVector()) { + Nd4jLong loop_size = input->lengthOf(); + auto numOfClasses = + gradOut->lengthOf(); // indices->e(loop_size - 1); + + segmentMinBPLinearKernel + <<lengthOf(), input->lengthOf(), 256, *stream>>>( + input->specialBuffer(), input->specialShapeInfo(), + tempRes.specialBuffer(), tempRes.specialShapeInfo(), + gradOut->specialBuffer(), gradOut->specialShapeInfo(), + indices->specialBuffer(), indices->specialShapeInfo(), + output->specialBuffer(), output->specialShapeInfo()); + } else { + std::vector dimensions = + ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions( + input->shapeInfo(), dimensions); + auto packZ = sd::ConstantTadHelper::getInstance().tadForDimensions( + output->shapeInfo(), dimensions); + auto packGradIn = sd::ConstantTadHelper::getInstance().tadForDimensions( + tempRes.shapeInfo(), dimensions); + auto packGradOut = sd::ConstantTadHelper::getInstance().tadForDimensions( + gradOut->shapeInfo(), dimensions); + auto inputTads = packX.specialShapeInfo(); + auto inputTadOffsets = packX.specialOffsets(); + auto outputTads = packZ.specialShapeInfo(); + auto outputTadOffsets = packZ.specialOffsets(); + auto gradInTads = packGradIn.specialShapeInfo(); + auto gradInTadOffsets = packGradIn.specialOffsets(); + auto gradOutTads = packGradOut.specialShapeInfo(); + auto gradOutTadOffsets = packGradOut.specialOffsets(); + + segmentMinBPTadKernel + <<lengthOf(), input->lengthOf(), 256, *stream>>>( + input->specialBuffer(), input->specialShapeInfo(), + tempRes.specialBuffer(), tempRes.specialShapeInfo(), + gradOut->specialBuffer(), gradOut->specialShapeInfo(), + indices->specialBuffer(), indices->specialShapeInfo(), + output->specialBuffer(), output->specialShapeInfo(), inputTads, + inputTadOffsets, gradInTads, gradInTadOffsets, gradOutTads, + gradOutTadOffsets, outputTads, outputTadOffsets); + } + NDArray::registerSpecialUse({output}, {input, indices, gradOut, &tempRes}); + return Status::OK(); +} +// -------------------------------------------------------------------------------------------------------------- +// // segmen min +int segmentMinFunctorBP(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* gradOut, NDArray* output) { + NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); + BUILD_DOUBLE_SELECTOR( + output->dataType(), indices->dataType(), return segmentMinFunctorBP_, + (context, input, indices, gradOut, output), FLOAT_TYPES, INDEXING_TYPES); + NDArray::registerSpecialUse({output}, {input, indices, gradOut}); +} - segmentMinBPTadKernel<<lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), - tempRes.specialBuffer(), tempRes.specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(), - indices->specialBuffer(), indices->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), - inputTads, inputTadOffsets, gradInTads, gradInTadOffsets, gradOutTads, gradOutTadOffsets, - outputTads, outputTadOffsets); - } - NDArray::registerSpecialUse({output}, {input, indices, gradOut, &tempRes}); - return Status::OK(); - } - // -------------------------------------------------------------------------------------------------------------- // - int unsortedSegmentMinFunctorBP(sd::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { - NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); - BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return unsortedSegmentMinFunctorBP_, (context, input, indices, gradOut, numOfClasses, output), FLOAT_TYPES, INDEXING_TYPES); - NDArray::registerSpecialUse({output}, {input, indices, gradOut}); - } +template +static int unsortedSegmentMinFunctorBP_(sd::LaunchContext* context, + NDArray* input, NDArray* indices, + NDArray* gradOut, Nd4jLong numOfClasses, + NDArray* output) { + // int numOfClasses = gradOut->sizeAt(0); + // if input is a vector: (as if in doc sample) + auto stream = context->getCudaStream(); + NDArray tempRes(gradOut->ordering(), gradOut->getShapeAsVector(), + DataTypeUtils::fromT(), + context); //->shapeInfo(), context); + unsortedSegmentMinFunctor_(context, input, indices, numOfClasses, + &tempRes); + NDArray::prepareSpecialUse({output}, {input, indices, gradOut, &tempRes}); + if (input->isVector()) { + Nd4jLong loop_size = input->lengthOf(); + auto numOfClasses = + gradOut->lengthOf(); // indices->e(loop_size - 1); + segmentMinBPLinearKernel + <<lengthOf(), input->lengthOf(), 256, *stream>>>( + input->specialBuffer(), input->specialShapeInfo(), + tempRes.specialBuffer(), tempRes.specialShapeInfo(), + gradOut->specialBuffer(), gradOut->specialShapeInfo(), + indices->specialBuffer(), indices->specialShapeInfo(), + output->specialBuffer(), output->specialShapeInfo()); + } else { + std::vector dimensions = + ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions( + input->shapeInfo(), dimensions); + auto packZ = sd::ConstantTadHelper::getInstance().tadForDimensions( + output->shapeInfo(), dimensions); + auto packGradIn = sd::ConstantTadHelper::getInstance().tadForDimensions( + tempRes.shapeInfo(), dimensions); + auto packGradOut = sd::ConstantTadHelper::getInstance().tadForDimensions( + gradOut->shapeInfo(), dimensions); + auto inputTads = packX.specialShapeInfo(); + auto inputTadOffsets = packX.specialOffsets(); + auto outputTads = packZ.specialShapeInfo(); + auto outputTadOffsets = packZ.specialOffsets(); + auto gradInTads = packGradIn.specialShapeInfo(); + auto gradInTadOffsets = packGradIn.specialOffsets(); + auto gradOutTads = packGradOut.specialShapeInfo(); + auto gradOutTadOffsets = packGradOut.specialOffsets(); + + segmentMinBPTadKernel + <<lengthOf(), input->lengthOf(), 256, *stream>>>( + input->specialBuffer(), input->specialShapeInfo(), + tempRes.specialBuffer(), tempRes.specialShapeInfo(), + gradOut->specialBuffer(), gradOut->specialShapeInfo(), + indices->specialBuffer(), indices->specialShapeInfo(), + output->specialBuffer(), output->specialShapeInfo(), inputTads, + inputTadOffsets, gradInTads, gradInTadOffsets, gradOutTads, + gradOutTadOffsets, outputTads, outputTadOffsets); + } + NDArray::registerSpecialUse({output}, {input, indices, gradOut, &tempRes}); + return Status::OK(); } +// -------------------------------------------------------------------------------------------------------------- +// // +int unsortedSegmentMinFunctorBP(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* gradOut, + Nd4jLong numOfClasses, NDArray* output) { + NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); + BUILD_DOUBLE_SELECTOR( + output->dataType(), indices->dataType(), + return unsortedSegmentMinFunctorBP_, + (context, input, indices, gradOut, numOfClasses, output), FLOAT_TYPES, + INDEXING_TYPES); + NDArray::registerSpecialUse({output}, {input, indices, gradOut}); } -} \ No newline at end of file +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/segment_prod.cu b/libnd4j/include/ops/declarable/helpers/cuda/segment_prod.cu index 44e07730001b..8c75fff717af 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/segment_prod.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/segment_prod.cu @@ -18,366 +18,462 @@ // @author GS // -#include -#include #include -#include -#include #include -#include #include +#include +#include +#include +#include +#include namespace sd { namespace ops { namespace helpers { - // -------------------------------------------------------------------------------------------------------------- // - // Segment Prod ops linear kernels - // -------------------------------------------------------------------------------------------------------------- // - - template - static __global__ void segmentProdLinearKernel(void* input, Nd4jLong const* inputShape, int* starts, int* lengths, - Nd4jLong numOfClasses, void* output, Nd4jLong const* outputShape) { - - __shared__ Nd4jLong xLen, zLen; - __shared__ T* x; - __shared__ T* z; - - if (threadIdx.x == 0) { - x = reinterpret_cast(input); - z = reinterpret_cast(output); - xLen = shape::length(inputShape); - zLen = shape::length(outputShape); - } - __syncthreads(); - - for(auto segment = blockIdx.x; segment < numOfClasses; segment += gridDim.x) { - auto zIndex = shape::getIndexOffset(segment, outputShape); - auto start = starts[segment]; - auto finish = start + lengths[segment]; - if (lengths[segment] == 0) { - continue; - } - for (auto e = start + threadIdx.x; e < finish; e += blockDim.x) { - auto xIndex = shape::getIndexOffset(e, inputShape); - sd::math::atomics::nd4j_atomicMul(&z[segment], x[xIndex]); - } - } - - } - // -------------------------------------------------------------------------------------------------------------- // - template - static __global__ void unsortedSegmentProdLinearKernel(T* input, Nd4jLong const* inputShape, I* indices, Nd4jLong const* indicesShape, int* starts, int* lengths, Nd4jLong numOfClasses, T* output, Nd4jLong const* outputShape) { - __shared__ Nd4jLong xLen, zLen; - - if (threadIdx.x == 0) { - xLen = shape::length(inputShape); - zLen = shape::length(outputShape); - } - __syncthreads(); - auto start = threadIdx.x + blockIdx.x * blockDim.x; - auto step = blockDim.x * gridDim.x; - for (auto idx = start; idx < xLen; idx += step) { - auto xIndex = shape::getIndexOffset(idx, inputShape); - auto yIndex = shape::getIndexOffset(idx, indicesShape); - auto segment = indices[yIndex]; - auto zIndex = shape::getIndexOffset(segment, outputShape); - if (lengths[segment] == 0) { - continue; - } - sd::math::atomics::nd4j_atomicMul(&output[zIndex], input[xIndex]); - } +// -------------------------------------------------------------------------------------------------------------- +// // Segment Prod ops linear kernels +// -------------------------------------------------------------------------------------------------------------- +// // + +template +static __global__ void segmentProdLinearKernel( + void* input, Nd4jLong const* inputShape, int* starts, int* lengths, + Nd4jLong numOfClasses, void* output, Nd4jLong const* outputShape) { + __shared__ Nd4jLong xLen, zLen; + __shared__ T* x; + __shared__ T* z; + + if (threadIdx.x == 0) { + x = reinterpret_cast(input); + z = reinterpret_cast(output); + xLen = shape::length(inputShape); + zLen = shape::length(outputShape); + } + __syncthreads(); + + for (auto segment = blockIdx.x; segment < numOfClasses; + segment += gridDim.x) { + auto zIndex = shape::getIndexOffset(segment, outputShape); + auto start = starts[segment]; + auto finish = start + lengths[segment]; + if (lengths[segment] == 0) { + continue; } - // -------------------------------------------------------------------------------------------------------------- // - // SegmentProd kernel - template - static __global__ void segmentProdTadKernel(void* inputBuf, Nd4jLong const* inputShape, Nd4jLong const* inputTads, - Nd4jLong const* inputTadOffsets, I* indices, int* starts, int* lengths, Nd4jLong numOfClasses, void* outputBuf, - Nd4jLong const* outputShape, Nd4jLong const* outputTads, Nd4jLong const* outputTadOffsets) { - - __shared__ Nd4jLong len, total; - - if (threadIdx.x == 0) { - total = shape::sizeAt(inputShape, 0); - len = shape::length(inputTads); - } - __syncthreads(); - - for (auto idx = blockIdx.x; idx < total; idx += gridDim.x) { - auto x = reinterpret_cast(inputBuf) + inputTadOffsets[idx]; - auto segment = indices[idx]; // / threadsPerSegment; - auto z = reinterpret_cast(outputBuf) + outputTadOffsets[segment]; - auto start = starts[segment]; - auto finish = start + lengths[segment]; - if (lengths[segment] == 0) continue; - for (auto e = threadIdx.x; e < len; e += blockDim.x) { - auto xIndex = shape::getIndexOffset(e, inputTads); - auto zIndex = shape::getIndexOffset(e, outputTads); - sd::math::atomics::nd4j_atomicMul(&z[zIndex], x[xIndex]); - } - } + for (auto e = start + threadIdx.x; e < finish; e += blockDim.x) { + auto xIndex = shape::getIndexOffset(e, inputShape); + sd::math::atomics::nd4j_atomicMul(&z[segment], x[xIndex]); } - // -------------------------------------------------------------------------------------------------------------- // - - template - static void segmentProdFunctor_(sd::LaunchContext* context, NDArray* input, NDArray* indices, NDArray* output) { - auto stream = context->getCudaStream(); - Nd4jLong numClasses = indices->e(indices->lengthOf() - 1) + 1; - NDArray classesRangesLens = NDArrayFactory::create('c', {numClasses}, context); - NDArray classesRangesBegs = NDArrayFactory::create('c', {numClasses}, context); - output->assign(1); - classesRangesBegs.assign(indices->lengthOf()); - classesRangesLens.assign(0); - - dim3 dims(numClasses, indices->lengthOf(), numClasses * 32 + 32); - fillUpSegments(indices, numClasses, classesRangesBegs, classesRangesLens); - int* begins = reinterpret_cast(classesRangesBegs.specialBuffer()); - int* lengths = reinterpret_cast(classesRangesLens.specialBuffer()); - - if (input->isVector()) { - segmentProdLinearKernel<<<128, 256, 128, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), begins, lengths, numClasses, output->specialBuffer(), output->specialShapeInfo()); - } - else { - std::vector dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(input->shapeInfo(), dimensions); - auto packZ = sd::ConstantTadHelper::getInstance().tadForDimensions(output->shapeInfo(), dimensions); - auto inputTads = packX.specialShapeInfo(); - auto inputTadOffsets = packX.specialOffsets(); - auto outputTads = packZ.specialShapeInfo(); - auto outputTadOffsets = packZ.specialOffsets(); - segmentProdTadKernel<<<128, 512, 2048, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), inputTads, inputTadOffsets, reinterpret_cast(indices->specialBuffer()), begins, lengths, numClasses, output->specialBuffer(), output->specialShapeInfo(), outputTads, outputTadOffsets); - } - + } +} +// -------------------------------------------------------------------------------------------------------------- +// // +template +static __global__ void unsortedSegmentProdLinearKernel( + T* input, Nd4jLong const* inputShape, I* indices, + Nd4jLong const* indicesShape, int* starts, int* lengths, + Nd4jLong numOfClasses, T* output, Nd4jLong const* outputShape) { + __shared__ Nd4jLong xLen, zLen; + + if (threadIdx.x == 0) { + xLen = shape::length(inputShape); + zLen = shape::length(outputShape); + } + __syncthreads(); + auto start = threadIdx.x + blockIdx.x * blockDim.x; + auto step = blockDim.x * gridDim.x; + for (auto idx = start; idx < xLen; idx += step) { + auto xIndex = shape::getIndexOffset(idx, inputShape); + auto yIndex = shape::getIndexOffset(idx, indicesShape); + auto segment = indices[yIndex]; + auto zIndex = shape::getIndexOffset(segment, outputShape); + if (lengths[segment] == 0) { + continue; } - // -------------------------------------------------------------------------------------------------------------- // - void segmentProdFunctor(sd::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* output) { - NDArray::prepareSpecialUse({output}, {input, indices}); - BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), segmentProdFunctor_, (context, input, indices, output), NUMERIC_TYPES, INDEXING_TYPES); - NDArray::registerSpecialUse({output}, {input, indices}); + sd::math::atomics::nd4j_atomicMul(&output[zIndex], input[xIndex]); + } +} +// -------------------------------------------------------------------------------------------------------------- +// // SegmentProd kernel +template +static __global__ void segmentProdTadKernel( + void* inputBuf, Nd4jLong const* inputShape, Nd4jLong const* inputTads, + Nd4jLong const* inputTadOffsets, I* indices, int* starts, int* lengths, + Nd4jLong numOfClasses, void* outputBuf, Nd4jLong const* outputShape, + Nd4jLong const* outputTads, Nd4jLong const* outputTadOffsets) { + __shared__ Nd4jLong len, total; + + if (threadIdx.x == 0) { + total = shape::sizeAt(inputShape, 0); + len = shape::length(inputTads); + } + __syncthreads(); + + for (auto idx = blockIdx.x; idx < total; idx += gridDim.x) { + auto x = reinterpret_cast(inputBuf) + inputTadOffsets[idx]; + auto segment = indices[idx]; // / threadsPerSegment; + auto z = reinterpret_cast(outputBuf) + outputTadOffsets[segment]; + auto start = starts[segment]; + auto finish = start + lengths[segment]; + if (lengths[segment] == 0) continue; + for (auto e = threadIdx.x; e < len; e += blockDim.x) { + auto xIndex = shape::getIndexOffset(e, inputTads); + auto zIndex = shape::getIndexOffset(e, outputTads); + sd::math::atomics::nd4j_atomicMul(&z[zIndex], x[xIndex]); } + } +} +// -------------------------------------------------------------------------------------------------------------- +// // + +template +static void segmentProdFunctor_(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* output) { + auto stream = context->getCudaStream(); + Nd4jLong numClasses = indices->e(indices->lengthOf() - 1) + 1; + NDArray classesRangesLens = + NDArrayFactory::create('c', {numClasses}, context); + NDArray classesRangesBegs = + NDArrayFactory::create('c', {numClasses}, context); + output->assign(1); + classesRangesBegs.assign(indices->lengthOf()); + classesRangesLens.assign(0); + + dim3 dims(numClasses, indices->lengthOf(), numClasses * 32 + 32); + fillUpSegments(indices, numClasses, classesRangesBegs, classesRangesLens); + int* begins = reinterpret_cast(classesRangesBegs.specialBuffer()); + int* lengths = reinterpret_cast(classesRangesLens.specialBuffer()); + + if (input->isVector()) { + segmentProdLinearKernel<<<128, 256, 128, *stream>>>( + input->specialBuffer(), input->specialShapeInfo(), begins, lengths, + numClasses, output->specialBuffer(), output->specialShapeInfo()); + } else { + std::vector dimensions = + ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions( + input->shapeInfo(), dimensions); + auto packZ = sd::ConstantTadHelper::getInstance().tadForDimensions( + output->shapeInfo(), dimensions); + auto inputTads = packX.specialShapeInfo(); + auto inputTadOffsets = packX.specialOffsets(); + auto outputTads = packZ.specialShapeInfo(); + auto outputTadOffsets = packZ.specialOffsets(); + segmentProdTadKernel<<<128, 512, 2048, *stream>>>( + input->specialBuffer(), input->specialShapeInfo(), inputTads, + inputTadOffsets, reinterpret_cast(indices->specialBuffer()), begins, + lengths, numClasses, output->specialBuffer(), + output->specialShapeInfo(), outputTads, outputTadOffsets); + } +} +// -------------------------------------------------------------------------------------------------------------- +// // +void segmentProdFunctor(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* output) { + NDArray::prepareSpecialUse({output}, {input, indices}); + BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), + segmentProdFunctor_, (context, input, indices, output), + NUMERIC_TYPES, INDEXING_TYPES); + NDArray::registerSpecialUse({output}, {input, indices}); +} - // -------------------------------------------------------------------------------------------------------------- // - template - static void unsortedSegmentProdFunctor_(sd::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { - auto stream = context->getCudaStream(); -// NDArray classes = NDArrayFactory::create('c', {numOfClasses, 2}); - NDArray classesRangesBegs = NDArrayFactory::create('c', {numOfClasses}, context); - NDArray classesRangesLens = NDArrayFactory::create('c', {numOfClasses}, context); -// NDArray row = NDArrayFactory::create('c', {1, 2}, {(int)indices->lengthOf(), (int)0}); -// classes.applyTrueBroadcast(sd::BroadcastOpsTuple::Assign(), &row, &classes); - classesRangesBegs.assign(indices->lengthOf()); - classesRangesLens.assign(0); - dim3 dims(numOfClasses, indices->lengthOf(), numOfClasses * 32 + 32); -// int* classesBuf = reinterpret_cast(classes.specialBuffer()); - fillUpSegments(indices, numOfClasses, classesRangesBegs, classesRangesLens); - int* begins = reinterpret_cast(classesRangesBegs.specialBuffer()); - int* lengths = reinterpret_cast(classesRangesLens.specialBuffer()); - output->assign(1); - - if (input->isVector()) { - unsortedSegmentProdLinearKernel<<<128, 256, 256, *stream>>>( - input->dataBuffer()->specialAsT(), input->specialShapeInfo(), - indices->dataBuffer()->specialAsT(), indices->specialShapeInfo(), begins, lengths, numOfClasses, - output->dataBuffer()->specialAsT(), output->specialShapeInfo()); - } - else { - std::vector dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(input->shapeInfo(), dimensions); - auto packZ = sd::ConstantTadHelper::getInstance().tadForDimensions(output->shapeInfo(), dimensions); - auto inputTads = packX.specialShapeInfo(); - auto inputTadOffsets = packX.specialOffsets(); - auto outputTads = packZ.specialShapeInfo(); - auto outputTadOffsets = packZ.specialOffsets(); - dims.x = input->sizeAt(0); - segmentProdTadKernel<<<128, 256, 256, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), inputTads, inputTadOffsets, reinterpret_cast(indices->specialBuffer()), begins, lengths, numOfClasses, output->specialBuffer(), output->specialShapeInfo(), outputTads, outputTadOffsets); - } - - } - // -------------------------------------------------------------------------------------------------------------- // - void unsortedSegmentProdFunctor(sd::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { - NDArray::prepareSpecialUse({output}, {input, indices}); - BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), unsortedSegmentProdFunctor_, (context, input, indices, numOfClasses, output), - NUMERIC_TYPES, INDEXING_TYPES); - NDArray::registerSpecialUse({output}, {input, indices}); - } +// -------------------------------------------------------------------------------------------------------------- +// // +template +static void unsortedSegmentProdFunctor_(sd::LaunchContext* context, + NDArray* input, NDArray* indices, + Nd4jLong numOfClasses, + NDArray* output) { + auto stream = context->getCudaStream(); + // NDArray classes = NDArrayFactory::create('c', {numOfClasses, + // 2}); + NDArray classesRangesBegs = + NDArrayFactory::create('c', {numOfClasses}, context); + NDArray classesRangesLens = + NDArrayFactory::create('c', {numOfClasses}, context); + // NDArray row = NDArrayFactory::create('c', {1, 2}, + // {(int)indices->lengthOf(), (int)0}); + // classes.applyTrueBroadcast(sd::BroadcastOpsTuple::Assign(), &row, + // &classes); + classesRangesBegs.assign(indices->lengthOf()); + classesRangesLens.assign(0); + dim3 dims(numOfClasses, indices->lengthOf(), numOfClasses * 32 + 32); + // int* classesBuf = reinterpret_cast(classes.specialBuffer()); + fillUpSegments(indices, numOfClasses, classesRangesBegs, classesRangesLens); + int* begins = reinterpret_cast(classesRangesBegs.specialBuffer()); + int* lengths = reinterpret_cast(classesRangesLens.specialBuffer()); + output->assign(1); + + if (input->isVector()) { + unsortedSegmentProdLinearKernel<<<128, 256, 256, *stream>>>( + input->dataBuffer()->specialAsT(), input->specialShapeInfo(), + indices->dataBuffer()->specialAsT(), indices->specialShapeInfo(), + begins, lengths, numOfClasses, output->dataBuffer()->specialAsT(), + output->specialShapeInfo()); + } else { + std::vector dimensions = + ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions( + input->shapeInfo(), dimensions); + auto packZ = sd::ConstantTadHelper::getInstance().tadForDimensions( + output->shapeInfo(), dimensions); + auto inputTads = packX.specialShapeInfo(); + auto inputTadOffsets = packX.specialOffsets(); + auto outputTads = packZ.specialShapeInfo(); + auto outputTadOffsets = packZ.specialOffsets(); + dims.x = input->sizeAt(0); + segmentProdTadKernel<<<128, 256, 256, *stream>>>( + input->specialBuffer(), input->specialShapeInfo(), inputTads, + inputTadOffsets, reinterpret_cast(indices->specialBuffer()), begins, + lengths, numOfClasses, output->specialBuffer(), + output->specialShapeInfo(), outputTads, outputTadOffsets); + } +} +// -------------------------------------------------------------------------------------------------------------- +// // +void unsortedSegmentProdFunctor(sd::LaunchContext* context, NDArray* input, + NDArray* indices, Nd4jLong numOfClasses, + NDArray* output) { + NDArray::prepareSpecialUse({output}, {input, indices}); + BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), + unsortedSegmentProdFunctor_, + (context, input, indices, numOfClasses, output), + NUMERIC_TYPES, INDEXING_TYPES); + NDArray::registerSpecialUse({output}, {input, indices}); +} - // -------------------------------------------------------------------------------------------------------------- // - template - static __global__ void segmentProdBPLinearKernel(void* inputBuf, Nd4jLong const* inputShape, void* forwardOutput, - Nd4jLong const* forwardShape, void* eps, Nd4jLong const* epsShape, void* indicesBuf, Nd4jLong const* indicesShape, - void* outputBuf, Nd4jLong const* outputShape) { - __shared__ T* x; - __shared__ T* gradIn; - __shared__ T* gradOut; - __shared__ I* y; - __shared__ T* z; - __shared__ Nd4jLong xLen, gradLen; - - if (threadIdx.x == 0) { - xLen = shape::length(inputShape); - x = reinterpret_cast(inputBuf); - y = reinterpret_cast(indicesBuf); - z = reinterpret_cast(outputBuf); - gradIn = reinterpret_cast(forwardOutput); - gradOut = reinterpret_cast(eps); - gradLen = shape::length(epsShape); - } - __syncthreads(); - - auto start = blockIdx.x * blockDim.x + threadIdx.x; - auto step = gridDim.x * blockDim.x; - - for (auto e = start; e < xLen; e += step) { - - auto zOffset = shape::getIndexOffset(e, outputShape); - auto xOffset = shape::getIndexOffset(e, inputShape); - auto yOffset = shape::getIndexOffset(e, indicesShape); - auto classIndex = y[yOffset]; - auto gradOffsetI = shape::getIndexOffset(classIndex, forwardShape); - auto gradOffsetO = shape::getIndexOffset(classIndex, epsShape); - - z[zOffset] = gradOut[gradOffsetO] * gradIn[gradOffsetI] / x[xOffset]; - } - } - // -------------------------------------------------------------------------------------------------------------- // - template - static __global__ void segmentProdBPTadKernel(void* inputBuf, Nd4jLong const* inputShape, void* forwardOutput, - Nd4jLong const* forwardShape, void* eps, Nd4jLong const* epsShape, void* indicesBuf, Nd4jLong const* indicesShape, - void* outputBuf, Nd4jLong const* outputShape, Nd4jLong const* inputTad, - Nd4jLong const* inputOffsets, Nd4jLong const* gradInTad, Nd4jLong const* gradInOffsets, - Nd4jLong const* gradOutTad, Nd4jLong const* gradOutOffsets, Nd4jLong const* outTad, - Nd4jLong const* outOffsets) { - __shared__ T* x; - __shared__ T* gradIn; - __shared__ T* gradOut; - __shared__ I* y; - __shared__ T* z; - __shared__ Nd4jLong xLen, yLen, gradLen, currentLen; - - if (threadIdx.x == 0) { - xLen = shape::length(inputShape); - x = reinterpret_cast(inputBuf); - y = reinterpret_cast(indicesBuf); - z = reinterpret_cast(outputBuf); - yLen = shape::length(indicesShape); - gradOut = reinterpret_cast(eps); - gradIn = reinterpret_cast(forwardOutput); - gradLen = shape::length(epsShape); - currentLen = shape::length(outTad); - } - __syncthreads(); - - for (auto i = blockIdx.x; i < yLen; i += gridDim.x) { - auto yIndex = shape::getIndexOffset(i, indicesShape); - auto segment = y[yIndex]; - T* current = x + inputOffsets[i]; - T* currentOut = z + outOffsets[i]; - T* in = gradIn + gradInOffsets[segment]; - T* outGrad = gradOut + gradOutOffsets[segment]; - - for (auto e = threadIdx.x; e < currentLen; e += blockDim.x) { - currentOut[e] = outGrad[e] * in[e] / current[e]; - } - } +// -------------------------------------------------------------------------------------------------------------- +// // +template +static __global__ void segmentProdBPLinearKernel( + void* inputBuf, Nd4jLong const* inputShape, void* forwardOutput, + Nd4jLong const* forwardShape, void* eps, Nd4jLong const* epsShape, + void* indicesBuf, Nd4jLong const* indicesShape, void* outputBuf, + Nd4jLong const* outputShape) { + __shared__ T* x; + __shared__ T* gradIn; + __shared__ T* gradOut; + __shared__ I* y; + __shared__ T* z; + __shared__ Nd4jLong xLen, gradLen; + + if (threadIdx.x == 0) { + xLen = shape::length(inputShape); + x = reinterpret_cast(inputBuf); + y = reinterpret_cast(indicesBuf); + z = reinterpret_cast(outputBuf); + gradIn = reinterpret_cast(forwardOutput); + gradOut = reinterpret_cast(eps); + gradLen = shape::length(epsShape); + } + __syncthreads(); + + auto start = blockIdx.x * blockDim.x + threadIdx.x; + auto step = gridDim.x * blockDim.x; + + for (auto e = start; e < xLen; e += step) { + auto zOffset = shape::getIndexOffset(e, outputShape); + auto xOffset = shape::getIndexOffset(e, inputShape); + auto yOffset = shape::getIndexOffset(e, indicesShape); + auto classIndex = y[yOffset]; + auto gradOffsetI = shape::getIndexOffset(classIndex, forwardShape); + auto gradOffsetO = shape::getIndexOffset(classIndex, epsShape); + + z[zOffset] = gradOut[gradOffsetO] * gradIn[gradOffsetI] / x[xOffset]; + } +} +// -------------------------------------------------------------------------------------------------------------- +// // +template +static __global__ void segmentProdBPTadKernel( + void* inputBuf, Nd4jLong const* inputShape, void* forwardOutput, + Nd4jLong const* forwardShape, void* eps, Nd4jLong const* epsShape, + void* indicesBuf, Nd4jLong const* indicesShape, void* outputBuf, + Nd4jLong const* outputShape, Nd4jLong const* inputTad, + Nd4jLong const* inputOffsets, Nd4jLong const* gradInTad, + Nd4jLong const* gradInOffsets, Nd4jLong const* gradOutTad, + Nd4jLong const* gradOutOffsets, Nd4jLong const* outTad, + Nd4jLong const* outOffsets) { + __shared__ T* x; + __shared__ T* gradIn; + __shared__ T* gradOut; + __shared__ I* y; + __shared__ T* z; + __shared__ Nd4jLong xLen, yLen, gradLen, currentLen; + + if (threadIdx.x == 0) { + xLen = shape::length(inputShape); + x = reinterpret_cast(inputBuf); + y = reinterpret_cast(indicesBuf); + z = reinterpret_cast(outputBuf); + yLen = shape::length(indicesShape); + gradOut = reinterpret_cast(eps); + gradIn = reinterpret_cast(forwardOutput); + gradLen = shape::length(epsShape); + currentLen = shape::length(outTad); + } + __syncthreads(); + + for (auto i = blockIdx.x; i < yLen; i += gridDim.x) { + auto yIndex = shape::getIndexOffset(i, indicesShape); + auto segment = y[yIndex]; + T* current = x + inputOffsets[i]; + T* currentOut = z + outOffsets[i]; + T* in = gradIn + gradInOffsets[segment]; + T* outGrad = gradOut + gradOutOffsets[segment]; + + for (auto e = threadIdx.x; e < currentLen; e += blockDim.x) { + currentOut[e] = outGrad[e] * in[e] / current[e]; } + } +} - // -------------------------------------------------------------------------------------------------------------- // - template - int segmentProdFunctorBP_(sd::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) { - auto stream = context->getCudaStream(); - NDArray tempRes(gradOut->ordering(), gradOut->getShapeAsVector(), DataTypeUtils::fromT(), context);//->shapeInfo(), context); - segmentProdFunctor_(context, input, indices, &tempRes); - NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); - if (input->isVector()) { - Nd4jLong loopSize = input->lengthOf(); - auto numOfClasses = gradOut->lengthOf(); //indices->e(loop_size - 1); - segmentProdBPLinearKernel<<lengthOf(), loopSize, 256, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), - tempRes.specialBuffer(), tempRes.specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(), - indices->specialBuffer(), indices->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo()); - } - else { - std::vector dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(input->shapeInfo(), dimensions); - auto packZ = sd::ConstantTadHelper::getInstance().tadForDimensions(output->shapeInfo(), dimensions); - auto packGradIn = sd::ConstantTadHelper::getInstance().tadForDimensions(tempRes.shapeInfo(), dimensions); - auto packGradOut = sd::ConstantTadHelper::getInstance().tadForDimensions(gradOut->shapeInfo(), dimensions); - auto inputTads = packX.specialShapeInfo(); - auto inputTadOffsets = packX.specialOffsets(); - auto outputTads = packZ.specialShapeInfo(); - auto outputTadOffsets = packZ.specialOffsets(); - auto gradInTads = packGradIn.specialShapeInfo(); - auto gradInTadOffsets = packGradIn.specialOffsets(); - auto gradOutTads = packGradOut.specialShapeInfo(); - auto gradOutTadOffsets = packGradOut.specialOffsets(); - - segmentProdBPTadKernel<<lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), - tempRes.specialBuffer(), tempRes.specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(), - indices->specialBuffer(), indices->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), - inputTads, inputTadOffsets, gradInTads, gradInTadOffsets, gradOutTads, gradOutTadOffsets, - outputTads, outputTadOffsets); - } - NDArray::registerSpecialUse({output}, {input, indices, gradOut}); - return Status::OK(); - } +// -------------------------------------------------------------------------------------------------------------- +// // +template +int segmentProdFunctorBP_(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* gradOut, NDArray* output) { + auto stream = context->getCudaStream(); + NDArray tempRes(gradOut->ordering(), gradOut->getShapeAsVector(), + DataTypeUtils::fromT(), + context); //->shapeInfo(), context); + segmentProdFunctor_(context, input, indices, &tempRes); + NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); + if (input->isVector()) { + Nd4jLong loopSize = input->lengthOf(); + auto numOfClasses = + gradOut->lengthOf(); // indices->e(loop_size - 1); + segmentProdBPLinearKernel + <<lengthOf(), loopSize, 256, *stream>>>( + input->specialBuffer(), input->specialShapeInfo(), + tempRes.specialBuffer(), tempRes.specialShapeInfo(), + gradOut->specialBuffer(), gradOut->specialShapeInfo(), + indices->specialBuffer(), indices->specialShapeInfo(), + output->specialBuffer(), output->specialShapeInfo()); + } else { + std::vector dimensions = + ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions( + input->shapeInfo(), dimensions); + auto packZ = sd::ConstantTadHelper::getInstance().tadForDimensions( + output->shapeInfo(), dimensions); + auto packGradIn = sd::ConstantTadHelper::getInstance().tadForDimensions( + tempRes.shapeInfo(), dimensions); + auto packGradOut = sd::ConstantTadHelper::getInstance().tadForDimensions( + gradOut->shapeInfo(), dimensions); + auto inputTads = packX.specialShapeInfo(); + auto inputTadOffsets = packX.specialOffsets(); + auto outputTads = packZ.specialShapeInfo(); + auto outputTadOffsets = packZ.specialOffsets(); + auto gradInTads = packGradIn.specialShapeInfo(); + auto gradInTadOffsets = packGradIn.specialOffsets(); + auto gradOutTads = packGradOut.specialShapeInfo(); + auto gradOutTadOffsets = packGradOut.specialOffsets(); + + segmentProdBPTadKernel + <<lengthOf(), input->lengthOf(), 256, *stream>>>( + input->specialBuffer(), input->specialShapeInfo(), + tempRes.specialBuffer(), tempRes.specialShapeInfo(), + gradOut->specialBuffer(), gradOut->specialShapeInfo(), + indices->specialBuffer(), indices->specialShapeInfo(), + output->specialBuffer(), output->specialShapeInfo(), inputTads, + inputTadOffsets, gradInTads, gradInTadOffsets, gradOutTads, + gradOutTadOffsets, outputTads, outputTadOffsets); + } + NDArray::registerSpecialUse({output}, {input, indices, gradOut}); + return Status::OK(); +} - // -------------------------------------------------------------------------------------------------------------- // +// -------------------------------------------------------------------------------------------------------------- +// // - int segmentProdFunctorBP(sd::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) { - NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); - BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return segmentProdFunctorBP_, (context, input, - indices, gradOut, output), FLOAT_TYPES, INDEXING_TYPES); - NDArray::registerSpecialUse({output}, {input, indices, gradOut}); - } +int segmentProdFunctorBP(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* gradOut, NDArray* output) { + NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); + BUILD_DOUBLE_SELECTOR( + output->dataType(), indices->dataType(), return segmentProdFunctorBP_, + (context, input, indices, gradOut, output), FLOAT_TYPES, INDEXING_TYPES); + NDArray::registerSpecialUse({output}, {input, indices, gradOut}); +} - // -------------------------------------------------------------------------------------------------------------- // - - template - static int unsortedSegmentProdFunctorBP_(sd::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { - auto stream = context->getCudaStream(); - - NDArray tempRes(gradOut->ordering(), gradOut->getShapeAsVector(), DataTypeUtils::fromT(), context);//->shapeInfo(), context); - unsortedSegmentProdFunctor_(context, input, indices, numOfClasses, &tempRes); - NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); - if (input->isVector()) { - Nd4jLong loopSize = input->lengthOf(); - auto numOfClasses = gradOut->lengthOf(); //indices->e(loop_size - 1); - segmentProdBPLinearKernel<<lengthOf(), loopSize, 256, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), - tempRes.specialBuffer(), tempRes.specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(), - indices->specialBuffer(), indices->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo()); - } - else { - std::vector dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(input->shapeInfo(), dimensions); - auto packZ = sd::ConstantTadHelper::getInstance().tadForDimensions(output->shapeInfo(), dimensions); - auto packGradIn = sd::ConstantTadHelper::getInstance().tadForDimensions(tempRes.shapeInfo(), dimensions); - auto packGradOut = sd::ConstantTadHelper::getInstance().tadForDimensions(gradOut->shapeInfo(), dimensions); - auto inputTads = packX.specialShapeInfo(); - auto inputTadOffsets = packX.specialOffsets(); - auto outputTads = packZ.specialShapeInfo(); - auto outputTadOffsets = packZ.specialOffsets(); - auto gradInTads = packGradIn.specialShapeInfo(); - auto gradInTadOffsets = packGradIn.specialOffsets(); - auto gradOutTads = packGradOut.specialShapeInfo(); - auto gradOutTadOffsets = packGradOut.specialOffsets(); - - segmentProdBPTadKernel<<lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), - tempRes.specialBuffer(), tempRes.specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(), - indices->specialBuffer(), indices->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), - inputTads, inputTadOffsets, gradInTads, gradInTadOffsets, gradOutTads, gradOutTadOffsets, - outputTads, outputTadOffsets); - } - NDArray::registerSpecialUse({output}, {input, indices, gradOut}); - return Status::OK(); - } +// -------------------------------------------------------------------------------------------------------------- +// // + +template +static int unsortedSegmentProdFunctorBP_(sd::LaunchContext* context, + NDArray* input, NDArray* indices, + NDArray* gradOut, + Nd4jLong numOfClasses, + NDArray* output) { + auto stream = context->getCudaStream(); + + NDArray tempRes(gradOut->ordering(), gradOut->getShapeAsVector(), + DataTypeUtils::fromT(), + context); //->shapeInfo(), context); + unsortedSegmentProdFunctor_(context, input, indices, numOfClasses, + &tempRes); + NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); + if (input->isVector()) { + Nd4jLong loopSize = input->lengthOf(); + auto numOfClasses = + gradOut->lengthOf(); // indices->e(loop_size - 1); + segmentProdBPLinearKernel + <<lengthOf(), loopSize, 256, *stream>>>( + input->specialBuffer(), input->specialShapeInfo(), + tempRes.specialBuffer(), tempRes.specialShapeInfo(), + gradOut->specialBuffer(), gradOut->specialShapeInfo(), + indices->specialBuffer(), indices->specialShapeInfo(), + output->specialBuffer(), output->specialShapeInfo()); + } else { + std::vector dimensions = + ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions( + input->shapeInfo(), dimensions); + auto packZ = sd::ConstantTadHelper::getInstance().tadForDimensions( + output->shapeInfo(), dimensions); + auto packGradIn = sd::ConstantTadHelper::getInstance().tadForDimensions( + tempRes.shapeInfo(), dimensions); + auto packGradOut = sd::ConstantTadHelper::getInstance().tadForDimensions( + gradOut->shapeInfo(), dimensions); + auto inputTads = packX.specialShapeInfo(); + auto inputTadOffsets = packX.specialOffsets(); + auto outputTads = packZ.specialShapeInfo(); + auto outputTadOffsets = packZ.specialOffsets(); + auto gradInTads = packGradIn.specialShapeInfo(); + auto gradInTadOffsets = packGradIn.specialOffsets(); + auto gradOutTads = packGradOut.specialShapeInfo(); + auto gradOutTadOffsets = packGradOut.specialOffsets(); + + segmentProdBPTadKernel + <<lengthOf(), input->lengthOf(), 256, *stream>>>( + input->specialBuffer(), input->specialShapeInfo(), + tempRes.specialBuffer(), tempRes.specialShapeInfo(), + gradOut->specialBuffer(), gradOut->specialShapeInfo(), + indices->specialBuffer(), indices->specialShapeInfo(), + output->specialBuffer(), output->specialShapeInfo(), inputTads, + inputTadOffsets, gradInTads, gradInTadOffsets, gradOutTads, + gradOutTadOffsets, outputTads, outputTadOffsets); + } + NDArray::registerSpecialUse({output}, {input, indices, gradOut}); + return Status::OK(); +} - // -------------------------------------------------------------------------------------------------------------- // - int unsortedSegmentProdFunctorBP(sd::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { - NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); - BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return unsortedSegmentProdFunctorBP_, (context, input, indices, gradOut, numOfClasses, output), FLOAT_TYPES, INDEXING_TYPES); - NDArray::registerSpecialUse({output}, {input, indices, gradOut}); - } +// -------------------------------------------------------------------------------------------------------------- +// // +int unsortedSegmentProdFunctorBP(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* gradOut, + Nd4jLong numOfClasses, NDArray* output) { + NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); + BUILD_DOUBLE_SELECTOR( + output->dataType(), indices->dataType(), + return unsortedSegmentProdFunctorBP_, + (context, input, indices, gradOut, numOfClasses, output), FLOAT_TYPES, + INDEXING_TYPES); + NDArray::registerSpecialUse({output}, {input, indices, gradOut}); +} - // -------------------------------------------------------------------------------------------------------------- // +// -------------------------------------------------------------------------------------------------------------- +// // -} -} -} +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/segment_sqrtn.cu b/libnd4j/include/ops/declarable/helpers/cuda/segment_sqrtn.cu index 20f2323323d4..ee6fd1e8464d 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/segment_sqrtn.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/segment_sqrtn.cu @@ -18,240 +18,307 @@ // @author GS // -#include -#include #include -#include -#include #include -#include #include +#include +#include +#include +#include +#include namespace sd { namespace ops { namespace helpers { - // -------------------------------------------------------------------------------------------------------------- // - template - static __global__ void unsortedSegmentSqrtNLinearKernel(T* input, Nd4jLong const* inputShape, I* indices, Nd4jLong const* indicesShape, int* starts, int* lengths, Nd4jLong numOfClasses, T* output, Nd4jLong const* outputShape) { - __shared__ Nd4jLong xLen, zLen; - - if (threadIdx.x == 0) { - xLen = shape::length(inputShape); - zLen = shape::length(outputShape); - } - __syncthreads(); +// -------------------------------------------------------------------------------------------------------------- +// // +template +static __global__ void unsortedSegmentSqrtNLinearKernel( + T* input, Nd4jLong const* inputShape, I* indices, + Nd4jLong const* indicesShape, int* starts, int* lengths, + Nd4jLong numOfClasses, T* output, Nd4jLong const* outputShape) { + __shared__ Nd4jLong xLen, zLen; - auto start = threadIdx.x + blockIdx.x * blockDim.x; - auto step = blockDim.x * gridDim.x; + if (threadIdx.x == 0) { + xLen = shape::length(inputShape); + zLen = shape::length(outputShape); + } + __syncthreads(); - for (auto idx = start; idx < xLen; idx += step) { - auto yIndex = shape::getIndexOffset(idx, indicesShape); - auto segment = indices[yIndex]; - auto zIndex = shape::getIndexOffset(segment, outputShape); - if (lengths[segment] == 0) continue; - auto xIndex = shape::getIndexOffset(idx, inputShape); + auto start = threadIdx.x + blockIdx.x * blockDim.x; + auto step = blockDim.x * gridDim.x; - sd::math::atomics::nd4j_atomicAdd(&output[zIndex], input[xIndex] / sd::math::nd4j_sqrt(lengths[segment])); - } - } - // -------------------------------------------------------------------------------------------------------------- // - // SegmentSqrtN kernel - template - static __global__ void segmentSqrtNTadKernel(T* inputBuf, Nd4jLong const* inputShape, Nd4jLong const* inputTads, Nd4jLong const* inputTadOffsets, I* indices, int* starts, int* lengths, Nd4jLong numOfClasses, void* outputBuf, Nd4jLong const* outputShape, Nd4jLong const* outputTads, Nd4jLong const* outputTadOffsets) { + for (auto idx = start; idx < xLen; idx += step) { + auto yIndex = shape::getIndexOffset(idx, indicesShape); + auto segment = indices[yIndex]; + auto zIndex = shape::getIndexOffset(segment, outputShape); + if (lengths[segment] == 0) continue; + auto xIndex = shape::getIndexOffset(idx, inputShape); - __shared__ Nd4jLong len, total; + sd::math::atomics::nd4j_atomicAdd( + &output[zIndex], + input[xIndex] / sd::math::nd4j_sqrt(lengths[segment])); + } +} +// -------------------------------------------------------------------------------------------------------------- +// // SegmentSqrtN kernel +template +static __global__ void segmentSqrtNTadKernel( + T* inputBuf, Nd4jLong const* inputShape, Nd4jLong const* inputTads, + Nd4jLong const* inputTadOffsets, I* indices, int* starts, int* lengths, + Nd4jLong numOfClasses, void* outputBuf, Nd4jLong const* outputShape, + Nd4jLong const* outputTads, Nd4jLong const* outputTadOffsets) { + __shared__ Nd4jLong len, total; - if (threadIdx.x == 0) { - total = shape::sizeAt(inputShape, 0); - len = shape::length(inputTads); - } - __syncthreads(); + if (threadIdx.x == 0) { + total = shape::sizeAt(inputShape, 0); + len = shape::length(inputTads); + } + __syncthreads(); - for (auto idx = blockIdx.x; idx < total; idx += gridDim.x) { - auto segment = indices[idx]; - auto x = inputBuf + inputTadOffsets[idx]; - auto z = reinterpret_cast(outputBuf) + outputTadOffsets[segment]; - auto start = starts[segment]; - auto finish = start + lengths[segment]; + for (auto idx = blockIdx.x; idx < total; idx += gridDim.x) { + auto segment = indices[idx]; + auto x = inputBuf + inputTadOffsets[idx]; + auto z = reinterpret_cast(outputBuf) + outputTadOffsets[segment]; + auto start = starts[segment]; + auto finish = start + lengths[segment]; - for (auto e = threadIdx.x; e < len; e += blockDim.x) { - auto xIndex = shape::getIndexOffset(e, inputTads); - auto zIndex = shape::getIndexOffset(e, outputTads); - sd::math::atomics::nd4j_atomicAdd(&z[zIndex], x[xIndex] / sd::math::nd4j_sqrt(lengths[segment])); - } - } - } - // -------------------------------------------------------------------------------------------------------------- // - template - static void unsortedSegmentSqrtNFunctor_(sd::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { - auto stream = context->getCudaStream(); -// NDArray classes = NDArrayFactory::create('c', {numOfClasses, 2}); - NDArray classesRangesBegs = NDArrayFactory::create('c', {numOfClasses}, context); - NDArray classesRangesLens = NDArrayFactory::create('c', {numOfClasses}, context); -// NDArray row = NDArrayFactory::create('c', {1, 2}, {(int)indices->lengthOf(), (int)0}); -// classes.applyTrueBroadcast(sd::BroadcastOpsTuple::Assign(), &row, &classes); - classesRangesBegs.assign(indices->lengthOf()); - classesRangesLens.assign(0); -// dim3 dims(numOfClasses, indices->lengthOf(), numOfClasses * 32 + 32); - dim3 dims(128, 256, 256); -// int* classesBuf = reinterpret_cast(classes.specialBuffer()); - fillUpSegments(indices, numOfClasses, classesRangesBegs, classesRangesLens); - int* begins = reinterpret_cast(classesRangesBegs.specialBuffer()); - int* lengths = reinterpret_cast(classesRangesLens.specialBuffer()); - output->nullify(); - if (input->isVector()) { - unsortedSegmentSqrtNLinearKernel<<>>( - input->dataBuffer()->specialAsT(), input->specialShapeInfo(), - indices->dataBuffer()->specialAsT(), indices->specialShapeInfo(), begins, lengths, numOfClasses, - output->dataBuffer()->specialAsT(), output->specialShapeInfo()); - } - else { - output->nullify(); - std::vector dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(input->shapeInfo(), dimensions); - auto packZ = sd::ConstantTadHelper::getInstance().tadForDimensions(output->shapeInfo(), dimensions); - auto inputTads = packX.specialShapeInfo(); - auto inputTadOffsets = packX.specialOffsets(); - auto outputTads = packZ.specialShapeInfo(); - auto outputTadOffsets = packZ.specialOffsets(); - dims.x = input->sizeAt(0); - segmentSqrtNTadKernel<<>>( - input->dataBuffer()->specialAsT(), input->specialShapeInfo(), inputTads, inputTadOffsets, indices->dataBuffer()->specialAsT(), - begins, lengths, numOfClasses, output->specialBuffer(), output->specialShapeInfo(), outputTads, outputTadOffsets); - } - } - // -------------------------------------------------------------------------------------------------------------- // - void unsortedSegmentSqrtNFunctor(sd::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { - NDArray::prepareSpecialUse({output}, {input, indices}); - BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), unsortedSegmentSqrtNFunctor_, (context, input, indices, numOfClasses, output), - FLOAT_TYPES, INDEXING_TYPES); - NDArray::registerSpecialUse({output}, {input, indices}); + for (auto e = threadIdx.x; e < len; e += blockDim.x) { + auto xIndex = shape::getIndexOffset(e, inputTads); + auto zIndex = shape::getIndexOffset(e, outputTads); + sd::math::atomics::nd4j_atomicAdd( + &z[zIndex], + x[xIndex] / sd::math::nd4j_sqrt(lengths[segment])); } - // -------------------------------------------------------------------------------------------------------------- // - template - static __global__ void segmentSqrtNBPLinearKernel(void* inputBuf, Nd4jLong const* inputShape, void* eps, Nd4jLong const* epsShape, void* indicesBuf, Nd4jLong const* indicesShape, - int* lengths, void* outputBuf, Nd4jLong const* outputShape) { - __shared__ T* x; - __shared__ T* gradIn; - __shared__ T* gradOut; - __shared__ I* y; - __shared__ T* z; - __shared__ Nd4jLong xLen, gradLen; - - if (threadIdx.x == 0) { - xLen = shape::length(inputShape); - x = reinterpret_cast(inputBuf); - y = reinterpret_cast(indicesBuf); - z = reinterpret_cast(outputBuf); - gradOut = reinterpret_cast(eps); - gradLen = shape::length(epsShape); - } - __syncthreads(); + } +} +// -------------------------------------------------------------------------------------------------------------- +// // +template +static void unsortedSegmentSqrtNFunctor_(sd::LaunchContext* context, + NDArray* input, NDArray* indices, + Nd4jLong numOfClasses, + NDArray* output) { + auto stream = context->getCudaStream(); + // NDArray classes = NDArrayFactory::create('c', {numOfClasses, + // 2}); + NDArray classesRangesBegs = + NDArrayFactory::create('c', {numOfClasses}, context); + NDArray classesRangesLens = + NDArrayFactory::create('c', {numOfClasses}, context); + // NDArray row = NDArrayFactory::create('c', {1, 2}, + // {(int)indices->lengthOf(), (int)0}); + // classes.applyTrueBroadcast(sd::BroadcastOpsTuple::Assign(), &row, + // &classes); + classesRangesBegs.assign(indices->lengthOf()); + classesRangesLens.assign(0); + // dim3 dims(numOfClasses, indices->lengthOf(), numOfClasses * 32 + + // 32); + dim3 dims(128, 256, 256); + // int* classesBuf = reinterpret_cast(classes.specialBuffer()); + fillUpSegments(indices, numOfClasses, classesRangesBegs, classesRangesLens); + int* begins = reinterpret_cast(classesRangesBegs.specialBuffer()); + int* lengths = reinterpret_cast(classesRangesLens.specialBuffer()); + output->nullify(); + if (input->isVector()) { + unsortedSegmentSqrtNLinearKernel<<>>( + input->dataBuffer()->specialAsT(), input->specialShapeInfo(), + indices->dataBuffer()->specialAsT(), indices->specialShapeInfo(), + begins, lengths, numOfClasses, output->dataBuffer()->specialAsT(), + output->specialShapeInfo()); + } else { + output->nullify(); + std::vector dimensions = + ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions( + input->shapeInfo(), dimensions); + auto packZ = sd::ConstantTadHelper::getInstance().tadForDimensions( + output->shapeInfo(), dimensions); + auto inputTads = packX.specialShapeInfo(); + auto inputTadOffsets = packX.specialOffsets(); + auto outputTads = packZ.specialShapeInfo(); + auto outputTadOffsets = packZ.specialOffsets(); + dims.x = input->sizeAt(0); + segmentSqrtNTadKernel<<>>( + input->dataBuffer()->specialAsT(), input->specialShapeInfo(), + inputTads, inputTadOffsets, indices->dataBuffer()->specialAsT(), + begins, lengths, numOfClasses, output->specialBuffer(), + output->specialShapeInfo(), outputTads, outputTadOffsets); + } +} +// -------------------------------------------------------------------------------------------------------------- +// // +void unsortedSegmentSqrtNFunctor(sd::LaunchContext* context, NDArray* input, + NDArray* indices, Nd4jLong numOfClasses, + NDArray* output) { + NDArray::prepareSpecialUse({output}, {input, indices}); + BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), + unsortedSegmentSqrtNFunctor_, + (context, input, indices, numOfClasses, output), + FLOAT_TYPES, INDEXING_TYPES); + NDArray::registerSpecialUse({output}, {input, indices}); +} +// -------------------------------------------------------------------------------------------------------------- +// // +template +static __global__ void segmentSqrtNBPLinearKernel( + void* inputBuf, Nd4jLong const* inputShape, void* eps, + Nd4jLong const* epsShape, void* indicesBuf, Nd4jLong const* indicesShape, + int* lengths, void* outputBuf, Nd4jLong const* outputShape) { + __shared__ T* x; + __shared__ T* gradIn; + __shared__ T* gradOut; + __shared__ I* y; + __shared__ T* z; + __shared__ Nd4jLong xLen, gradLen; - auto start = blockIdx.x * blockDim.x + threadIdx.x; - auto step = gridDim.x * blockDim.x; + if (threadIdx.x == 0) { + xLen = shape::length(inputShape); + x = reinterpret_cast(inputBuf); + y = reinterpret_cast(indicesBuf); + z = reinterpret_cast(outputBuf); + gradOut = reinterpret_cast(eps); + gradLen = shape::length(epsShape); + } + __syncthreads(); - for (auto e = start; e < xLen; e += step) { + auto start = blockIdx.x * blockDim.x + threadIdx.x; + auto step = gridDim.x * blockDim.x; - auto zOffset = shape::getIndexOffset(e, outputShape); - auto xOffset = shape::getIndexOffset(e, inputShape); - auto yOffset = shape::getIndexOffset(e, indicesShape); - auto classIndex = y[yOffset]; - auto gradOffsetO = shape::getIndexOffset(classIndex, epsShape); + for (auto e = start; e < xLen; e += step) { + auto zOffset = shape::getIndexOffset(e, outputShape); + auto xOffset = shape::getIndexOffset(e, inputShape); + auto yOffset = shape::getIndexOffset(e, indicesShape); + auto classIndex = y[yOffset]; + auto gradOffsetO = shape::getIndexOffset(classIndex, epsShape); - z[zOffset] = T(gradOut[gradOffsetO] / math::nd4j_sqrt(lengths[classIndex])); - } - } - // -------------------------------------------------------------------------------------------------------------- // + z[zOffset] = T(gradOut[gradOffsetO] / + math::nd4j_sqrt(lengths[classIndex])); + } +} +// -------------------------------------------------------------------------------------------------------------- +// // - template - static __global__ void segmentSqrtNBPTadKernel(void* inputBuf, Nd4jLong const* inputShape, void* eps, Nd4jLong const* epsShape, - void* indicesBuf, Nd4jLong const* indicesShape, int* lengths, void* outputBuf, Nd4jLong const* outputShape,Nd4jLong const* inputTad, - Nd4jLong const* inputOffsets, Nd4jLong const* gradOutTad, Nd4jLong const* gradOutOffsets, Nd4jLong const* outTad, Nd4jLong const* outOffsets) { - __shared__ T* x; - __shared__ T* gradOut; - __shared__ I* y; - __shared__ T* z; - __shared__ Nd4jLong xLen, yLen, gradLen, currentLen; +template +static __global__ void segmentSqrtNBPTadKernel( + void* inputBuf, Nd4jLong const* inputShape, void* eps, + Nd4jLong const* epsShape, void* indicesBuf, Nd4jLong const* indicesShape, + int* lengths, void* outputBuf, Nd4jLong const* outputShape, + Nd4jLong const* inputTad, Nd4jLong const* inputOffsets, + Nd4jLong const* gradOutTad, Nd4jLong const* gradOutOffsets, + Nd4jLong const* outTad, Nd4jLong const* outOffsets) { + __shared__ T* x; + __shared__ T* gradOut; + __shared__ I* y; + __shared__ T* z; + __shared__ Nd4jLong xLen, yLen, gradLen, currentLen; - if (threadIdx.x == 0) { - xLen = shape::length(inputShape); - x = reinterpret_cast(inputBuf); - y = reinterpret_cast(indicesBuf); - z = reinterpret_cast(outputBuf); - yLen = shape::length(indicesShape); - gradOut = reinterpret_cast(eps); - gradLen = shape::length(epsShape); - currentLen = shape::length(outTad); - } - __syncthreads(); + if (threadIdx.x == 0) { + xLen = shape::length(inputShape); + x = reinterpret_cast(inputBuf); + y = reinterpret_cast(indicesBuf); + z = reinterpret_cast(outputBuf); + yLen = shape::length(indicesShape); + gradOut = reinterpret_cast(eps); + gradLen = shape::length(epsShape); + currentLen = shape::length(outTad); + } + __syncthreads(); - for (auto i = blockIdx.x; i < yLen; i += gridDim.x) { -// auto yIndex = shape::getIndexOffset(i, indicesShape); - auto segment = y[i]; //yIndex]; - T* currentOut = z + outOffsets[i]; - T* outGrad = gradOut + gradOutOffsets[segment]; + for (auto i = blockIdx.x; i < yLen; i += gridDim.x) { + // auto yIndex = shape::getIndexOffset(i, indicesShape); + auto segment = y[i]; // yIndex]; + T* currentOut = z + outOffsets[i]; + T* outGrad = gradOut + gradOutOffsets[segment]; - for (auto e = threadIdx.x; e < currentLen; e += blockDim.x) { - auto zIndex = shape::getIndexOffset(e, outTad); - auto gradIndex = shape::getIndexOffset(e, gradOutTad); - if (lengths[segment] > 0) - currentOut[zIndex] = T(outGrad[gradIndex] / math::nd4j_sqrt(lengths[segment])); - } - } + for (auto e = threadIdx.x; e < currentLen; e += blockDim.x) { + auto zIndex = shape::getIndexOffset(e, outTad); + auto gradIndex = shape::getIndexOffset(e, gradOutTad); + if (lengths[segment] > 0) + currentOut[zIndex] = T(outGrad[gradIndex] / + math::nd4j_sqrt(lengths[segment])); } - // -------------------------------------------------------------------------------------------------------------- // + } +} +// -------------------------------------------------------------------------------------------------------------- +// // - template - static int unsortedSegmentSqrtNFunctorBP_(sd::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { - auto stream = context->getCudaStream(); - NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); - auto numClasses = indices->e(indices->lengthOf() - 1) + 1; - NDArray classesRangesLens = NDArrayFactory::create('c', {numClasses}, context); - NDArray classesRangesBegs = NDArrayFactory::create('c', {numClasses}, context); +template +static int unsortedSegmentSqrtNFunctorBP_(sd::LaunchContext* context, + NDArray* input, NDArray* indices, + NDArray* gradOut, + Nd4jLong numOfClasses, + NDArray* output) { + auto stream = context->getCudaStream(); + NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); + auto numClasses = indices->e(indices->lengthOf() - 1) + 1; + NDArray classesRangesLens = + NDArrayFactory::create('c', {numClasses}, context); + NDArray classesRangesBegs = + NDArrayFactory::create('c', {numClasses}, context); - classesRangesBegs.assign(indices->lengthOf()); - classesRangesLens.assign(0); - dim3 dims(numClasses, indices->lengthOf(), numClasses * 32 + 32); - fillUpSegments(indices, numClasses, classesRangesBegs, classesRangesLens); - int* begins = reinterpret_cast(classesRangesBegs.specialBuffer()); - int* lengths = reinterpret_cast(classesRangesLens.specialBuffer()); + classesRangesBegs.assign(indices->lengthOf()); + classesRangesLens.assign(0); + dim3 dims(numClasses, indices->lengthOf(), numClasses * 32 + 32); + fillUpSegments(indices, numClasses, classesRangesBegs, classesRangesLens); + int* begins = reinterpret_cast(classesRangesBegs.specialBuffer()); + int* lengths = reinterpret_cast(classesRangesLens.specialBuffer()); - if (input->isVector()) { - Nd4jLong loop_size = input->lengthOf(); - auto numOfClasses = gradOut->lengthOf(); //indices->e(loop_size - 1); - segmentSqrtNBPLinearKernel<<lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(), - input->specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(), - indices->specialBuffer(), indices->specialShapeInfo(), lengths, output->specialBuffer(), output->specialShapeInfo()); - } - else { - std::vector dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(input->shapeInfo(), dimensions); - auto packZ = sd::ConstantTadHelper::getInstance().tadForDimensions(output->shapeInfo(), dimensions); -// auto packGradIn = sd::ConstantTadHelper::getInstance().tadForDimensions(tempRes.shapeInfo(), dimensions); - auto packGradOut = sd::ConstantTadHelper::getInstance().tadForDimensions(gradOut->shapeInfo(), dimensions); - auto inputTads = packX.specialShapeInfo(); - auto inputTadOffsets = packX.specialOffsets(); - auto outputTads = packZ.specialShapeInfo(); - auto outputTadOffsets = packZ.specialOffsets(); - auto gradOutTads = packGradOut.specialShapeInfo(); - auto gradOutTadOffsets = packGradOut.specialOffsets(); + if (input->isVector()) { + Nd4jLong loop_size = input->lengthOf(); + auto numOfClasses = + gradOut->lengthOf(); // indices->e(loop_size - 1); + segmentSqrtNBPLinearKernel + <<lengthOf(), input->lengthOf(), 256, *stream>>>( + input->specialBuffer(), input->specialShapeInfo(), + gradOut->specialBuffer(), gradOut->specialShapeInfo(), + indices->specialBuffer(), indices->specialShapeInfo(), lengths, + output->specialBuffer(), output->specialShapeInfo()); + } else { + std::vector dimensions = + ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions( + input->shapeInfo(), dimensions); + auto packZ = sd::ConstantTadHelper::getInstance().tadForDimensions( + output->shapeInfo(), dimensions); + // auto packGradIn = + // sd::ConstantTadHelper::getInstance().tadForDimensions(tempRes.shapeInfo(), + // dimensions); + auto packGradOut = sd::ConstantTadHelper::getInstance().tadForDimensions( + gradOut->shapeInfo(), dimensions); + auto inputTads = packX.specialShapeInfo(); + auto inputTadOffsets = packX.specialOffsets(); + auto outputTads = packZ.specialShapeInfo(); + auto outputTadOffsets = packZ.specialOffsets(); + auto gradOutTads = packGradOut.specialShapeInfo(); + auto gradOutTadOffsets = packGradOut.specialOffsets(); - segmentSqrtNBPTadKernel<<lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), - gradOut->specialBuffer(), gradOut->specialShapeInfo(), indices->specialBuffer(), indices->specialShapeInfo(), lengths, - output->specialBuffer(), output->specialShapeInfo(), inputTads, inputTadOffsets, gradOutTads, gradOutTadOffsets, - outputTads, outputTadOffsets); - } - NDArray::registerSpecialUse({output}, {input, indices, gradOut}); + segmentSqrtNBPTadKernel + <<lengthOf(), input->lengthOf(), 256, *stream>>>( + input->specialBuffer(), input->specialShapeInfo(), + gradOut->specialBuffer(), gradOut->specialShapeInfo(), + indices->specialBuffer(), indices->specialShapeInfo(), lengths, + output->specialBuffer(), output->specialShapeInfo(), inputTads, + inputTadOffsets, gradOutTads, gradOutTadOffsets, outputTads, + outputTadOffsets); + } + NDArray::registerSpecialUse({output}, {input, indices, gradOut}); - return Status::OK(); - } - // -------------------------------------------------------------------------------------------------------------- // - int unsortedSegmentSqrtNFunctorBP(sd::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { - NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); - BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return unsortedSegmentSqrtNFunctorBP_, (context, input, indices, gradOut, numOfClasses, output), FLOAT_TYPES, INDEXING_TYPES); - NDArray::registerSpecialUse({output}, {input, indices, gradOut}); - } + return Status::OK(); } +// -------------------------------------------------------------------------------------------------------------- +// // +int unsortedSegmentSqrtNFunctorBP(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* gradOut, + Nd4jLong numOfClasses, NDArray* output) { + NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); + BUILD_DOUBLE_SELECTOR( + output->dataType(), indices->dataType(), + return unsortedSegmentSqrtNFunctorBP_, + (context, input, indices, gradOut, numOfClasses, output), FLOAT_TYPES, + INDEXING_TYPES); + NDArray::registerSpecialUse({output}, {input, indices, gradOut}); } -} \ No newline at end of file +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/segment_sum.cu b/libnd4j/include/ops/declarable/helpers/cuda/segment_sum.cu index a2050d695edd..80e4fec0174a 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/segment_sum.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/segment_sum.cu @@ -18,393 +18,448 @@ // @author GS // -#include -#include #include -#include -#include #include -#include #include +#include +#include +#include +#include +#include namespace sd { namespace ops { namespace helpers { - // -------------------------------------------------------------------------------------------------------------- // - // Segment ops linear kernels - // -------------------------------------------------------------------------------------------------------------- // - template - static __global__ void - segmentSumLinearKernel( - const void *input, const Nd4jLong *inputShape, - int *starts, int *lengths, Nd4jLong numOfClasses, - void *output, const Nd4jLong *outputShape) { - __shared__ - T *val; - __shared__ - Nd4jLong xLen, zLen, segment, zIndex; - __shared__ - const T *x; - __shared__ - T *z; - __shared__ int threadsPerSegment, start, finish; - - if (threadIdx.x == 0) { - threadsPerSegment = (gridDim.x + numOfClasses - 1) / numOfClasses; - segment = blockIdx.x / threadsPerSegment; - x = reinterpret_cast(input); - z = reinterpret_cast(output); - - xLen = shape::length(inputShape); - zLen = shape::length(outputShape); - - - if (segment < numOfClasses) { - zIndex = shape::getIndexOffset(segment, outputShape); - start = starts[segment]; - finish = start + lengths[segment]; - //val[segment] = ; - z[zIndex] = x[shape::getIndexOffset(start, inputShape)]; - } - - } - __syncthreads(); - - for (auto e = start + threadIdx.x + 1; e < finish; e += blockDim.x) { - auto xIndex = shape::getIndexOffset(e, inputShape); - sd::math::atomics::nd4j_atomicAdd(&z[zIndex], x[xIndex]); - } - } - // -------------------------------------------------------------------------------------------------------------- // - - template - static __global__ void - unsortedSegmentSumLinearKernel( - const void *input, const Nd4jLong *inputShape, - const void *indices, const Nd4jLong *indicesShape, - int *starts, int *lengths, Nd4jLong numOfClasses, - void *output, const Nd4jLong *outputShape) { - __shared__ - T *val; - __shared__ - Nd4jLong xLen, zLen, segment, zIndex; - __shared__ - const T *x; - __shared__ - T *z; - __shared__ - const I *y; //int threadsPerSegment, start, finish; - - if (threadIdx.x == 0) { - segment = blockIdx.x; - x = reinterpret_cast(input); - z = reinterpret_cast(output); - y = reinterpret_cast(indices); - xLen = shape::length(inputShape); - zLen = shape::length(outputShape); - - zIndex = shape::getIndexOffset(segment, outputShape); - if (lengths[segment] > 0) - z[zIndex] = x[shape::getIndexOffset(starts[segment], inputShape)]; - else - z[zIndex] = 0; //DataTypeUtils::max(); - } - __syncthreads(); - - if (lengths[segment] > 0) - for (auto e = threadIdx.x; e < xLen; e += blockDim.x) { - auto xIndex = shape::getIndexOffset(e, inputShape); - auto yIndex = shape::getIndexOffset(e, indicesShape); - if (y[yIndex] == segment && e != starts[segment]) { - sd::math::atomics::nd4j_atomicAdd(&z[zIndex], x[xIndex]); - } - } - } - // -------------------------------------------------------------------------------------------------------------- // - // SegmentSum kernel - template - static __global__ void segmentSumTadKernel( - const void* inputBuf, const Nd4jLong* inputShape, const Nd4jLong* inputTads, const Nd4jLong* inputTadOffsets, - const I* indices, - int* starts, int* lengths, Nd4jLong numOfClasses, - void* outputBuf, const Nd4jLong* outputShape, const Nd4jLong* outputTads, const Nd4jLong* outputTadOffsets) { - __shared__ T* val; - __shared__ Nd4jLong len, zIndex, total; - __shared__ T* z; - __shared__ int start, finish; - - if (threadIdx.x == 0) { - auto segment = indices[blockIdx.x]; // / threadsPerSegment; - z = reinterpret_cast(outputBuf) + outputTadOffsets[segment]; - len = shape::length(inputTads); - start = starts[segment]; - finish = start + lengths[segment]; - total = shape::sizeAt(inputShape, 0); - - } - __syncthreads(); - - auto idx = blockIdx.x; - if (blockIdx.x <= total) { - auto x = reinterpret_cast(inputBuf) + inputTadOffsets[idx]; - if (blockIdx.x == start) { - for (auto e = threadIdx.x; e < len; e += blockDim.x) { - auto xIndex = shape::getIndexOffset(e, inputTads); - auto zIndex = shape::getIndexOffset(e, outputTads); - sd::math::atomics::nd4j_atomicAdd(&z[zIndex], x[xIndex]); - } - } - else { - for (auto e = threadIdx.x; e < len; e += blockDim.x) { - auto xIndex = shape::getIndexOffset(e, inputTads); - auto zIndex = shape::getIndexOffset(e, outputTads); - if (lengths[indices[idx]]) - sd::math::atomics::nd4j_atomicAdd(&z[zIndex], x[xIndex]); - } - } - } - } - // -------------------------------------------------------------------------------------------------------------- // - - template - static void segmentSumFunctor_(sd::LaunchContext* context, NDArray* input, NDArray* indices, NDArray* output) { - auto stream = context->getCudaStream(); - Nd4jLong numClasses = indices->e(indices->lengthOf() - 1) + 1; - NDArray classesRangesLens = NDArrayFactory::create('c', {numClasses}, context); - NDArray classesRangesBegs = NDArrayFactory::create('c', {numClasses}, context); - - classesRangesBegs.assign(indices->lengthOf()); - classesRangesLens.assign(0); - - dim3 dims(numClasses, indices->lengthOf(), numClasses * 32 + 32); - fillUpSegments(indices, numClasses, classesRangesBegs, classesRangesLens); - int* begins = reinterpret_cast(classesRangesBegs.specialBuffer()); - int* lengths = reinterpret_cast(classesRangesLens.specialBuffer()); - - if (input->isVector()) { - segmentSumLinearKernel<<lengthOf(), numClasses * 32 + 32, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), begins, lengths, numClasses, output->specialBuffer(), output->specialShapeInfo()); - } - else { - std::vector dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(input->shapeInfo(), dimensions); - auto packZ = sd::ConstantTadHelper::getInstance().tadForDimensions(output->shapeInfo(), dimensions); - auto inputTads = packX.specialShapeInfo(); - auto inputTadOffsets = packX.specialOffsets(); - auto outputTads = packZ.specialShapeInfo(); - auto outputTadOffsets = packZ.specialOffsets(); - segmentSumTadKernel<<sizeAt(0), 512, 2048, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), inputTads, inputTadOffsets, reinterpret_cast(indices->specialBuffer()), begins, lengths, numClasses, output->specialBuffer(), output->specialShapeInfo(), outputTads, outputTadOffsets); - } - - } - // -------------------------------------------------------------------------------------------------------------- // - void segmentSumFunctor(sd::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* output) { - NDArray::prepareSpecialUse({output}, {input, indices}); - output->nullify(); - BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), segmentSumFunctor_, (context, input, indices, output), NUMERIC_TYPES, INDEXING_TYPES); - NDArray::registerSpecialUse({output}, {input, indices}); - } - - // -------------------------------------------------------------------------------------------------------------- // - template - static void unsortedSegmentSumFunctor_(sd::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { - auto stream = context->getCudaStream(); -// NDArray classes = NDArrayFactory::create('c', {numOfClasses, 2}); - NDArray classesRangesBegs = NDArrayFactory::create('c', {numOfClasses}, context); - NDArray classesRangesLens = NDArrayFactory::create('c', {numOfClasses}, context); -// NDArray row = NDArrayFactory::create('c', {1, 2}, {(int)indices->lengthOf(), (int)0}); -// classes.applyTrueBroadcast(sd::BroadcastOpsTuple::Assign(), &row, &classes); - classesRangesBegs.assign(indices->lengthOf()); - classesRangesLens.assign(0); - dim3 dims(numOfClasses, indices->lengthOf(), (numOfClasses + 1) * 64); -// int* classesBuf = reinterpret_cast(classes.specialBuffer()); - fillUpSegments(indices, numOfClasses, classesRangesBegs, classesRangesLens); - int* begins = reinterpret_cast(classesRangesBegs.specialBuffer()); - int* lengths = reinterpret_cast(classesRangesLens.specialBuffer()); - - if (input->isVector()) { - unsortedSegmentSumLinearKernel<<>>(input->specialBuffer(), input->specialShapeInfo(), indices->specialBuffer(), indices->specialShapeInfo(), begins, lengths, numOfClasses, output->specialBuffer(), output->specialShapeInfo()); - } - else { - output->assign(0); - std::vector dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(input->shapeInfo(), dimensions); - auto packZ = sd::ConstantTadHelper::getInstance().tadForDimensions(output->shapeInfo(), dimensions); - auto inputTads = packX.specialShapeInfo(); - auto inputTadOffsets = packX.specialOffsets(); - auto outputTads = packZ.specialShapeInfo(); - auto outputTadOffsets = packZ.specialOffsets(); - dims.x = input->sizeAt(0); - segmentSumTadKernel<<>>(input->specialBuffer(), input->specialShapeInfo(), inputTads, inputTadOffsets, reinterpret_cast(indices->specialBuffer()), begins, lengths, numOfClasses, output->specialBuffer(), output->specialShapeInfo(), outputTads, outputTadOffsets); - } - +// -------------------------------------------------------------------------------------------------------------- +// // Segment ops linear kernels +// -------------------------------------------------------------------------------------------------------------- +// // +template +static __global__ void segmentSumLinearKernel( + const void* input, const Nd4jLong* inputShape, int* starts, int* lengths, + Nd4jLong numOfClasses, void* output, const Nd4jLong* outputShape) { + __shared__ T* val; + __shared__ Nd4jLong xLen, zLen, segment, zIndex; + __shared__ const T* x; + __shared__ T* z; + __shared__ int threadsPerSegment, start, finish; + + if (threadIdx.x == 0) { + threadsPerSegment = (gridDim.x + numOfClasses - 1) / numOfClasses; + segment = blockIdx.x / threadsPerSegment; + x = reinterpret_cast(input); + z = reinterpret_cast(output); + + xLen = shape::length(inputShape); + zLen = shape::length(outputShape); + + if (segment < numOfClasses) { + zIndex = shape::getIndexOffset(segment, outputShape); + start = starts[segment]; + finish = start + lengths[segment]; + // val[segment] = ; + z[zIndex] = x[shape::getIndexOffset(start, inputShape)]; } - // -------------------------------------------------------------------------------------------------------------- // - void unsortedSegmentSumFunctor(sd::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { - NDArray::prepareSpecialUse({output}, {input, indices}); - output->nullify(); - BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), unsortedSegmentSumFunctor_, (context, input, indices, numOfClasses, output), - NUMERIC_TYPES, INDEXING_TYPES); - NDArray::registerSpecialUse({output}, {input, indices}); - - } - - // -------------------------------------------------------------------------------------------------------------- // - // Backpropagate ops - // -------------------------------------------------------------------------------------------------------------- // - // Sorted sum backpropagate - template - static __global__ void segmentSumBPLinearKernel( - const void* inputBuf, const Nd4jLong* inputShape, - const void* eps, const Nd4jLong* epsShape, - const void* indicesBuf, const Nd4jLong* indicesShape, - void* outputBuf, const Nd4jLong* outputShape) { - auto x = reinterpret_cast(inputBuf); - auto y = reinterpret_cast(indicesBuf); - auto z = reinterpret_cast(outputBuf); - auto gradOut = reinterpret_cast(eps); - __shared__ Nd4jLong xLen, gradLen; - - if (threadIdx.x == 0) { - xLen = shape::length(inputShape); - gradLen = shape::length(epsShape); - } - __syncthreads(); - - auto start = blockIdx.x * blockDim.x + threadIdx.x; - auto step = gridDim.x * blockDim.x; - - for (auto e = start; e < xLen; e += step) { - - auto zOffset = shape::getIndexOffset(e, outputShape); - auto xOffset = shape::getIndexOffset(e, inputShape); - auto yOffset = shape::getIndexOffset(e, indicesShape); - auto classIndex = y[yOffset]; - auto gradOffsetO = shape::getIndexOffset(classIndex, epsShape); - - z[zOffset] = gradOut[gradOffsetO]; - } - } - // -------------------------------------------------------------------------------------------------------------- // - template - static __global__ void segmentSumBPTadKernel( - const void* inputBuf, const Nd4jLong* inputShape, - const void* eps, const Nd4jLong* epsShape, - const void* indicesBuf, const Nd4jLong* indicesShape, - void* outputBuf, const Nd4jLong* outputShape, - const Nd4jLong* inputTad, const Nd4jLong* inputOffsets, - const Nd4jLong* gradOutTad, const Nd4jLong* gradOutOffsets, - const Nd4jLong* outTad, const Nd4jLong* outOffsets) { - __shared__ const T* x; - __shared__ const T* gradOut; - __shared__ const I* y; - __shared__ T* z; - __shared__ Nd4jLong xLen, yLen, gradLen, currentLen; - - if (threadIdx.x == 0) { - xLen = shape::length(inputShape); - x = reinterpret_cast(inputBuf); - y = reinterpret_cast(indicesBuf); - z = reinterpret_cast(outputBuf); - yLen = shape::length(indicesShape); - gradOut = reinterpret_cast(eps); - gradLen = shape::length(epsShape); - currentLen = shape::length(outTad); - } - __syncthreads(); - - for (auto i = blockIdx.x; i < yLen; i += gridDim.x) { - auto yIndex = shape::getIndexOffset(i, indicesShape); - auto segment = y[yIndex]; - auto currentOut = z + outOffsets[i]; - auto outGrad = gradOut + gradOutOffsets[segment]; - - for (auto e = threadIdx.x; e < currentLen; e += blockDim.x) { - currentOut[e] = outGrad[e]; - } - } + } + __syncthreads(); + for (auto e = start + threadIdx.x + 1; e < finish; e += blockDim.x) { + auto xIndex = shape::getIndexOffset(e, inputShape); + sd::math::atomics::nd4j_atomicAdd(&z[zIndex], x[xIndex]); + } +} +// -------------------------------------------------------------------------------------------------------------- +// // + +template +static __global__ void unsortedSegmentSumLinearKernel( + const void* input, const Nd4jLong* inputShape, const void* indices, + const Nd4jLong* indicesShape, int* starts, int* lengths, + Nd4jLong numOfClasses, void* output, const Nd4jLong* outputShape) { + __shared__ T* val; + __shared__ Nd4jLong xLen, zLen, segment, zIndex; + __shared__ const T* x; + __shared__ T* z; + __shared__ const I* y; // int threadsPerSegment, start, finish; + + if (threadIdx.x == 0) { + segment = blockIdx.x; + x = reinterpret_cast(input); + z = reinterpret_cast(output); + y = reinterpret_cast(indices); + xLen = shape::length(inputShape); + zLen = shape::length(outputShape); + + zIndex = shape::getIndexOffset(segment, outputShape); + if (lengths[segment] > 0) + z[zIndex] = x[shape::getIndexOffset(starts[segment], inputShape)]; + else + z[zIndex] = 0; // DataTypeUtils::max(); + } + __syncthreads(); + + if (lengths[segment] > 0) + for (auto e = threadIdx.x; e < xLen; e += blockDim.x) { + auto xIndex = shape::getIndexOffset(e, inputShape); + auto yIndex = shape::getIndexOffset(e, indicesShape); + if (y[yIndex] == segment && e != starts[segment]) { + sd::math::atomics::nd4j_atomicAdd(&z[zIndex], x[xIndex]); + } } - // -------------------------------------------------------------------------------------------------------------- // - template - int segmentSumFunctorBP_(sd::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) { - auto stream = context->getCudaStream(); - NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); - if (input->isVector()) { - Nd4jLong loop_size = input->lengthOf(); - auto numOfClasses = gradOut->lengthOf(); //indices->e(loop_size - 1); - segmentSumBPLinearKernel<<lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(), - input->specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(), - indices->specialBuffer(), indices->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo()); - } - else { - std::vector dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(input->shapeInfo(), dimensions); - auto packZ = sd::ConstantTadHelper::getInstance().tadForDimensions(output->shapeInfo(), dimensions); - auto packGradOut = sd::ConstantTadHelper::getInstance().tadForDimensions(gradOut->shapeInfo(), dimensions); - auto inputTads = packX.specialShapeInfo(); - auto inputTadOffsets = packX.specialOffsets(); - auto outputTads = packZ.specialShapeInfo(); - auto outputTadOffsets = packZ.specialOffsets(); - auto gradOutTads = packGradOut.specialShapeInfo(); - auto gradOutTadOffsets = packGradOut.specialOffsets(); - - segmentSumBPTadKernel<<lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), - gradOut->specialBuffer(), gradOut->specialShapeInfo(), - indices->specialBuffer(), indices->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), - inputTads, inputTadOffsets, gradOutTads, gradOutTadOffsets, - outputTads, outputTadOffsets); - } - NDArray::registerSpecialUse({output}, {input, indices, gradOut}); - return Status::OK(); +} +// -------------------------------------------------------------------------------------------------------------- +// // SegmentSum kernel +template +static __global__ void segmentSumTadKernel( + const void* inputBuf, const Nd4jLong* inputShape, const Nd4jLong* inputTads, + const Nd4jLong* inputTadOffsets, const I* indices, int* starts, + int* lengths, Nd4jLong numOfClasses, void* outputBuf, + const Nd4jLong* outputShape, const Nd4jLong* outputTads, + const Nd4jLong* outputTadOffsets) { + __shared__ T* val; + __shared__ Nd4jLong len, zIndex, total; + __shared__ T* z; + __shared__ int start, finish; + + if (threadIdx.x == 0) { + auto segment = indices[blockIdx.x]; // / threadsPerSegment; + z = reinterpret_cast(outputBuf) + outputTadOffsets[segment]; + len = shape::length(inputTads); + start = starts[segment]; + finish = start + lengths[segment]; + total = shape::sizeAt(inputShape, 0); + } + __syncthreads(); + + auto idx = blockIdx.x; + if (blockIdx.x <= total) { + auto x = reinterpret_cast(inputBuf) + inputTadOffsets[idx]; + if (blockIdx.x == start) { + for (auto e = threadIdx.x; e < len; e += blockDim.x) { + auto xIndex = shape::getIndexOffset(e, inputTads); + auto zIndex = shape::getIndexOffset(e, outputTads); + sd::math::atomics::nd4j_atomicAdd(&z[zIndex], x[xIndex]); + } + } else { + for (auto e = threadIdx.x; e < len; e += blockDim.x) { + auto xIndex = shape::getIndexOffset(e, inputTads); + auto zIndex = shape::getIndexOffset(e, outputTads); + if (lengths[indices[idx]]) + sd::math::atomics::nd4j_atomicAdd(&z[zIndex], x[xIndex]); + } } - // -------------------------------------------------------------------------------------------------------------- // + } +} +// -------------------------------------------------------------------------------------------------------------- +// // + +template +static void segmentSumFunctor_(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* output) { + auto stream = context->getCudaStream(); + Nd4jLong numClasses = indices->e(indices->lengthOf() - 1) + 1; + NDArray classesRangesLens = + NDArrayFactory::create('c', {numClasses}, context); + NDArray classesRangesBegs = + NDArrayFactory::create('c', {numClasses}, context); + + classesRangesBegs.assign(indices->lengthOf()); + classesRangesLens.assign(0); + + dim3 dims(numClasses, indices->lengthOf(), numClasses * 32 + 32); + fillUpSegments(indices, numClasses, classesRangesBegs, classesRangesLens); + int* begins = reinterpret_cast(classesRangesBegs.specialBuffer()); + int* lengths = reinterpret_cast(classesRangesLens.specialBuffer()); + + if (input->isVector()) { + segmentSumLinearKernel + <<lengthOf(), numClasses * 32 + 32, *stream>>>( + input->specialBuffer(), input->specialShapeInfo(), begins, lengths, + numClasses, output->specialBuffer(), output->specialShapeInfo()); + } else { + std::vector dimensions = + ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions( + input->shapeInfo(), dimensions); + auto packZ = sd::ConstantTadHelper::getInstance().tadForDimensions( + output->shapeInfo(), dimensions); + auto inputTads = packX.specialShapeInfo(); + auto inputTadOffsets = packX.specialOffsets(); + auto outputTads = packZ.specialShapeInfo(); + auto outputTadOffsets = packZ.specialOffsets(); + segmentSumTadKernel<<sizeAt(0), 512, 2048, *stream>>>( + input->specialBuffer(), input->specialShapeInfo(), inputTads, + inputTadOffsets, reinterpret_cast(indices->specialBuffer()), begins, + lengths, numClasses, output->specialBuffer(), + output->specialShapeInfo(), outputTads, outputTadOffsets); + } +} +// -------------------------------------------------------------------------------------------------------------- +// // +void segmentSumFunctor(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* output) { + NDArray::prepareSpecialUse({output}, {input, indices}); + output->nullify(); + BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), + segmentSumFunctor_, (context, input, indices, output), + NUMERIC_TYPES, INDEXING_TYPES); + NDArray::registerSpecialUse({output}, {input, indices}); +} - int segmentSumFunctorBP(sd::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) { - NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); - BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return segmentSumFunctorBP_, (context, input, - indices, gradOut, output), FLOAT_TYPES, INDEXING_TYPES); - NDArray::registerSpecialUse({output}, {input, indices, gradOut}); - } +// -------------------------------------------------------------------------------------------------------------- +// // +template +static void unsortedSegmentSumFunctor_(sd::LaunchContext* context, + NDArray* input, NDArray* indices, + Nd4jLong numOfClasses, NDArray* output) { + auto stream = context->getCudaStream(); + // NDArray classes = NDArrayFactory::create('c', {numOfClasses, + // 2}); + NDArray classesRangesBegs = + NDArrayFactory::create('c', {numOfClasses}, context); + NDArray classesRangesLens = + NDArrayFactory::create('c', {numOfClasses}, context); + // NDArray row = NDArrayFactory::create('c', {1, 2}, + // {(int)indices->lengthOf(), (int)0}); + // classes.applyTrueBroadcast(sd::BroadcastOpsTuple::Assign(), &row, + // &classes); + classesRangesBegs.assign(indices->lengthOf()); + classesRangesLens.assign(0); + dim3 dims(numOfClasses, indices->lengthOf(), (numOfClasses + 1) * 64); + // int* classesBuf = reinterpret_cast(classes.specialBuffer()); + fillUpSegments(indices, numOfClasses, classesRangesBegs, classesRangesLens); + int* begins = reinterpret_cast(classesRangesBegs.specialBuffer()); + int* lengths = reinterpret_cast(classesRangesLens.specialBuffer()); + + if (input->isVector()) { + unsortedSegmentSumLinearKernel<<>>( + input->specialBuffer(), input->specialShapeInfo(), + indices->specialBuffer(), indices->specialShapeInfo(), begins, lengths, + numOfClasses, output->specialBuffer(), output->specialShapeInfo()); + } else { + output->assign(0); + std::vector dimensions = + ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions( + input->shapeInfo(), dimensions); + auto packZ = sd::ConstantTadHelper::getInstance().tadForDimensions( + output->shapeInfo(), dimensions); + auto inputTads = packX.specialShapeInfo(); + auto inputTadOffsets = packX.specialOffsets(); + auto outputTads = packZ.specialShapeInfo(); + auto outputTadOffsets = packZ.specialOffsets(); + dims.x = input->sizeAt(0); + segmentSumTadKernel<<>>( + input->specialBuffer(), input->specialShapeInfo(), inputTads, + inputTadOffsets, reinterpret_cast(indices->specialBuffer()), begins, + lengths, numOfClasses, output->specialBuffer(), + output->specialShapeInfo(), outputTads, outputTadOffsets); + } +} +// -------------------------------------------------------------------------------------------------------------- +// // +void unsortedSegmentSumFunctor(sd::LaunchContext* context, NDArray* input, + NDArray* indices, Nd4jLong numOfClasses, + NDArray* output) { + NDArray::prepareSpecialUse({output}, {input, indices}); + output->nullify(); + BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), + unsortedSegmentSumFunctor_, + (context, input, indices, numOfClasses, output), + NUMERIC_TYPES, INDEXING_TYPES); + NDArray::registerSpecialUse({output}, {input, indices}); +} - template - static int unsortedSegmentSumFunctorBP_(sd::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { - auto stream = context->getCudaStream(); - NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); - if (input->isVector()) { - Nd4jLong loop_size = input->lengthOf(); - auto numOfClasses = gradOut->lengthOf(); //indices->e(loop_size - 1); - segmentSumBPLinearKernel<<lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(), - input->specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(), - indices->specialBuffer(), indices->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo()); - } - else { - std::vector dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(input->shapeInfo(), dimensions); - auto packZ = sd::ConstantTadHelper::getInstance().tadForDimensions(output->shapeInfo(), dimensions); - auto packGradOut = sd::ConstantTadHelper::getInstance().tadForDimensions(gradOut->shapeInfo(), dimensions); - auto inputTads = packX.specialShapeInfo(); - auto inputTadOffsets = packX.specialOffsets(); - auto outputTads = packZ.specialShapeInfo(); - auto outputTadOffsets = packZ.specialOffsets(); - auto gradOutTads = packGradOut.specialShapeInfo(); - auto gradOutTadOffsets = packGradOut.specialOffsets(); - - segmentSumBPTadKernel<<lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), - gradOut->specialBuffer(), gradOut->specialShapeInfo(), - indices->specialBuffer(), indices->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), - inputTads, inputTadOffsets, gradOutTads, gradOutTadOffsets, - outputTads, outputTadOffsets); - } - NDArray::registerSpecialUse({output}, {input, indices, gradOut}); - return Status::OK(); - } - // -------------------------------------------------------------------------------------------------------------- // - int unsortedSegmentSumFunctorBP(sd::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { - NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); - BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return unsortedSegmentSumFunctorBP_, (context, input, indices, gradOut, numOfClasses, output), FLOAT_TYPES, INDEXING_TYPES); - NDArray::registerSpecialUse({output}, {input, indices, gradOut}); +// -------------------------------------------------------------------------------------------------------------- +// // Backpropagate ops +// -------------------------------------------------------------------------------------------------------------- +// // Sorted sum backpropagate +template +static __global__ void segmentSumBPLinearKernel( + const void* inputBuf, const Nd4jLong* inputShape, const void* eps, + const Nd4jLong* epsShape, const void* indicesBuf, + const Nd4jLong* indicesShape, void* outputBuf, + const Nd4jLong* outputShape) { + auto x = reinterpret_cast(inputBuf); + auto y = reinterpret_cast(indicesBuf); + auto z = reinterpret_cast(outputBuf); + auto gradOut = reinterpret_cast(eps); + __shared__ Nd4jLong xLen, gradLen; + + if (threadIdx.x == 0) { + xLen = shape::length(inputShape); + gradLen = shape::length(epsShape); + } + __syncthreads(); + + auto start = blockIdx.x * blockDim.x + threadIdx.x; + auto step = gridDim.x * blockDim.x; + + for (auto e = start; e < xLen; e += step) { + auto zOffset = shape::getIndexOffset(e, outputShape); + auto xOffset = shape::getIndexOffset(e, inputShape); + auto yOffset = shape::getIndexOffset(e, indicesShape); + auto classIndex = y[yOffset]; + auto gradOffsetO = shape::getIndexOffset(classIndex, epsShape); + + z[zOffset] = gradOut[gradOffsetO]; + } +} +// -------------------------------------------------------------------------------------------------------------- +// // +template +static __global__ void segmentSumBPTadKernel( + const void* inputBuf, const Nd4jLong* inputShape, const void* eps, + const Nd4jLong* epsShape, const void* indicesBuf, + const Nd4jLong* indicesShape, void* outputBuf, const Nd4jLong* outputShape, + const Nd4jLong* inputTad, const Nd4jLong* inputOffsets, + const Nd4jLong* gradOutTad, const Nd4jLong* gradOutOffsets, + const Nd4jLong* outTad, const Nd4jLong* outOffsets) { + __shared__ const T* x; + __shared__ const T* gradOut; + __shared__ const I* y; + __shared__ T* z; + __shared__ Nd4jLong xLen, yLen, gradLen, currentLen; + + if (threadIdx.x == 0) { + xLen = shape::length(inputShape); + x = reinterpret_cast(inputBuf); + y = reinterpret_cast(indicesBuf); + z = reinterpret_cast(outputBuf); + yLen = shape::length(indicesShape); + gradOut = reinterpret_cast(eps); + gradLen = shape::length(epsShape); + currentLen = shape::length(outTad); + } + __syncthreads(); + + for (auto i = blockIdx.x; i < yLen; i += gridDim.x) { + auto yIndex = shape::getIndexOffset(i, indicesShape); + auto segment = y[yIndex]; + auto currentOut = z + outOffsets[i]; + auto outGrad = gradOut + gradOutOffsets[segment]; + + for (auto e = threadIdx.x; e < currentLen; e += blockDim.x) { + currentOut[e] = outGrad[e]; } + } +} +// -------------------------------------------------------------------------------------------------------------- +// // +template +int segmentSumFunctorBP_(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* gradOut, NDArray* output) { + auto stream = context->getCudaStream(); + NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); + if (input->isVector()) { + Nd4jLong loop_size = input->lengthOf(); + auto numOfClasses = + gradOut->lengthOf(); // indices->e(loop_size - 1); + segmentSumBPLinearKernel + <<lengthOf(), input->lengthOf(), 256, *stream>>>( + input->specialBuffer(), input->specialShapeInfo(), + gradOut->specialBuffer(), gradOut->specialShapeInfo(), + indices->specialBuffer(), indices->specialShapeInfo(), + output->specialBuffer(), output->specialShapeInfo()); + } else { + std::vector dimensions = + ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions( + input->shapeInfo(), dimensions); + auto packZ = sd::ConstantTadHelper::getInstance().tadForDimensions( + output->shapeInfo(), dimensions); + auto packGradOut = sd::ConstantTadHelper::getInstance().tadForDimensions( + gradOut->shapeInfo(), dimensions); + auto inputTads = packX.specialShapeInfo(); + auto inputTadOffsets = packX.specialOffsets(); + auto outputTads = packZ.specialShapeInfo(); + auto outputTadOffsets = packZ.specialOffsets(); + auto gradOutTads = packGradOut.specialShapeInfo(); + auto gradOutTadOffsets = packGradOut.specialOffsets(); + + segmentSumBPTadKernel + <<lengthOf(), input->lengthOf(), 256, *stream>>>( + input->specialBuffer(), input->specialShapeInfo(), + gradOut->specialBuffer(), gradOut->specialShapeInfo(), + indices->specialBuffer(), indices->specialShapeInfo(), + output->specialBuffer(), output->specialShapeInfo(), inputTads, + inputTadOffsets, gradOutTads, gradOutTadOffsets, outputTads, + outputTadOffsets); + } + NDArray::registerSpecialUse({output}, {input, indices, gradOut}); + return Status::OK(); +} +// -------------------------------------------------------------------------------------------------------------- +// // + +int segmentSumFunctorBP(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* gradOut, NDArray* output) { + NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); + BUILD_DOUBLE_SELECTOR( + output->dataType(), indices->dataType(), return segmentSumFunctorBP_, + (context, input, indices, gradOut, output), FLOAT_TYPES, INDEXING_TYPES); + NDArray::registerSpecialUse({output}, {input, indices, gradOut}); +} +template +static int unsortedSegmentSumFunctorBP_(sd::LaunchContext* context, + NDArray* input, NDArray* indices, + NDArray* gradOut, Nd4jLong numOfClasses, + NDArray* output) { + auto stream = context->getCudaStream(); + NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); + if (input->isVector()) { + Nd4jLong loop_size = input->lengthOf(); + auto numOfClasses = + gradOut->lengthOf(); // indices->e(loop_size - 1); + segmentSumBPLinearKernel + <<lengthOf(), input->lengthOf(), 256, *stream>>>( + input->specialBuffer(), input->specialShapeInfo(), + gradOut->specialBuffer(), gradOut->specialShapeInfo(), + indices->specialBuffer(), indices->specialShapeInfo(), + output->specialBuffer(), output->specialShapeInfo()); + } else { + std::vector dimensions = + ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions( + input->shapeInfo(), dimensions); + auto packZ = sd::ConstantTadHelper::getInstance().tadForDimensions( + output->shapeInfo(), dimensions); + auto packGradOut = sd::ConstantTadHelper::getInstance().tadForDimensions( + gradOut->shapeInfo(), dimensions); + auto inputTads = packX.specialShapeInfo(); + auto inputTadOffsets = packX.specialOffsets(); + auto outputTads = packZ.specialShapeInfo(); + auto outputTadOffsets = packZ.specialOffsets(); + auto gradOutTads = packGradOut.specialShapeInfo(); + auto gradOutTadOffsets = packGradOut.specialOffsets(); + + segmentSumBPTadKernel + <<lengthOf(), input->lengthOf(), 256, *stream>>>( + input->specialBuffer(), input->specialShapeInfo(), + gradOut->specialBuffer(), gradOut->specialShapeInfo(), + indices->specialBuffer(), indices->specialShapeInfo(), + output->specialBuffer(), output->specialShapeInfo(), inputTads, + inputTadOffsets, gradOutTads, gradOutTadOffsets, outputTads, + outputTadOffsets); + } + NDArray::registerSpecialUse({output}, {input, indices, gradOut}); + return Status::OK(); } +// -------------------------------------------------------------------------------------------------------------- +// // +int unsortedSegmentSumFunctorBP(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* gradOut, + Nd4jLong numOfClasses, NDArray* output) { + NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); + BUILD_DOUBLE_SELECTOR( + output->dataType(), indices->dataType(), + return unsortedSegmentSumFunctorBP_, + (context, input, indices, gradOut, numOfClasses, output), FLOAT_TYPES, + INDEXING_TYPES); + NDArray::registerSpecialUse({output}, {input, indices, gradOut}); } -} \ No newline at end of file + +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/sequence_mask.cu b/libnd4j/include/ops/declarable/helpers/cuda/sequence_mask.cu index 51b7590c0ad0..bd3a854fc408 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/sequence_mask.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/sequence_mask.cu @@ -24,41 +24,53 @@ namespace sd { namespace ops { namespace helpers { - template - static __global__ void sequenceMaskKernel(const void* inputBuf, const Nd4jLong* inputShape, void* outputBuf, const Nd4jLong* outputShape, int maxIndex) { +template +static __global__ void sequenceMaskKernel(const void* inputBuf, + const Nd4jLong* inputShape, + void* outputBuf, + const Nd4jLong* outputShape, + int maxIndex) { + __shared__ const I* input; + __shared__ B* output; + __shared__ Nd4jLong inputLen, outputLen; + if (threadIdx.x == 0) { + input = reinterpret_cast(inputBuf); + output = reinterpret_cast(outputBuf); + inputLen = shape::length(inputShape); + outputLen = shape::length(outputShape); + } + __syncthreads(); - __shared__ const I* input; - __shared__ B* output; - __shared__ Nd4jLong inputLen, outputLen; - if (threadIdx.x == 0) { - input = reinterpret_cast(inputBuf); - output = reinterpret_cast(outputBuf); - inputLen = shape::length(inputShape); - outputLen = shape::length(outputShape); - } - __syncthreads(); - - for (auto i = blockIdx.x; i < maxIndex; i += gridDim.x) - for(auto k = threadIdx.x; k < inputLen; k += blockDim.x) - if (i < input[shape::getIndexOffset(k, inputShape)]) - output[shape::getIndexOffset(k * maxIndex + i, outputShape)] = B(true); - - } - - template - static void sequenceMask_(LaunchContext* context, NDArray* input, NDArray* output, int maxIndex) { - dim3 launchDims(maxIndex, input->lengthOf(), 128); - NDArray::prepareSpecialUse({output}, {input}); - auto stream = context->getCudaStream(); - sequenceMaskKernel<<>>(input->specialBuffer(), input->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), maxIndex); - NDArray::registerSpecialUse({output}, {input}); - } - - void sequenceMask(sd::LaunchContext * context, NDArray* input, NDArray* output, int maxIndex) { - BUILD_DOUBLE_SELECTOR(input->dataType(), output->dataType(), sequenceMask_, (context, input, output, maxIndex), INTEGER_TYPES, LIBND4J_TYPES_EXTENDED); - } + for (auto i = blockIdx.x; i < maxIndex; i += gridDim.x) + for (auto k = threadIdx.x; k < inputLen; k += blockDim.x) + if (i < input[shape::getIndexOffset(k, inputShape)]) + output[shape::getIndexOffset(k * maxIndex + i, outputShape)] = B(true); +} - BUILD_DOUBLE_TEMPLATE(template void sequenceMask_, (sd::LaunchContext* context, NDArray* input, NDArray* output, int maxIndex), INTEGER_TYPES, LIBND4J_TYPES_EXTENDED); +template +static void sequenceMask_(LaunchContext* context, NDArray* input, + NDArray* output, int maxIndex) { + dim3 launchDims(maxIndex, input->lengthOf(), 128); + NDArray::prepareSpecialUse({output}, {input}); + auto stream = context->getCudaStream(); + sequenceMaskKernel + <<>>( + input->specialBuffer(), input->specialShapeInfo(), + output->specialBuffer(), output->specialShapeInfo(), maxIndex); + NDArray::registerSpecialUse({output}, {input}); } + +void sequenceMask(sd::LaunchContext* context, NDArray* input, NDArray* output, + int maxIndex) { + BUILD_DOUBLE_SELECTOR(input->dataType(), output->dataType(), sequenceMask_, + (context, input, output, maxIndex), INTEGER_TYPES, + LIBND4J_TYPES_EXTENDED); } -} \ No newline at end of file + +BUILD_DOUBLE_TEMPLATE(template void sequenceMask_, + (sd::LaunchContext * context, NDArray* input, + NDArray* output, int maxIndex), + INTEGER_TYPES, LIBND4J_TYPES_EXTENDED); +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/sg_cb.cu b/libnd4j/include/ops/declarable/helpers/cuda/sg_cb.cu index 3957f23d5ada..cbec9bc8d07f 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/sg_cb.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/sg_cb.cu @@ -18,745 +18,936 @@ // @author raver119@gmail.com // -#include -#include #include +#include +#include #define HS_MAX_EXP 6.0f namespace sd { - namespace ops { - namespace helpers { - template - __global__ void hSoftmaxKernel(void *vsyn0, void *vsyn1, void *vexpTable, void *vneu1e, double alpha, int vectorLength, int code, int expLength, bool isInference) { - - auto syn0 = reinterpret_cast(vsyn0); - auto syn1 = reinterpret_cast(vsyn1); - auto expTable = reinterpret_cast(vexpTable); - auto neu1e = reinterpret_cast(vneu1e); - - T dot(0.0f); - T g(0.0f); - T f(0.0f); - - // dot - for (int e = 0; e < vectorLength; e++) { - dot += syn0[e] * syn1[e]; - } - - // gradient - if (dot < (T) - HS_MAX_EXP || dot >= (T) HS_MAX_EXP) - return; - - - int idx = static_cast((dot + HS_MAX_EXP) * ((float) expLength / HS_MAX_EXP / 2.0f)); - - if (idx >= expLength || idx < 0) - return; - - f = expTable[idx]; - g = (static_cast(1.0f) - static_cast(code) - f) * (T) alpha; - - // axpy1 - - for (int e = 0; e < vectorLength; e++) { - neu1e[e] = g * syn1[e] + neu1e[e]; - } - - // axpy2 - if (!isInference) { - for (int e = 0; e < vectorLength; e++) { - syn1[e] = g * syn0[e] + syn1[e]; - } - } - } - - template - void hSoftmax_(void *vsyn0, void *vsyn1, void *vexpTable, void *vneu1e, double alpha, int vectorLength, int code, int expLength, bool isInference, cudaStream_t* stream) { - hSoftmaxKernel<<<1,1,128, *stream>>>(vsyn0, vsyn1, vexpTable, vneu1e, alpha, vectorLength, code, expLength, isInference); - } - - template - __global__ void nSamplingKernel(void *vsyn0, void *vsyn1Neg, void *vexpTable, void *vneu1e, double alpha, int vectorLength, int code, int expLength, bool isInference) { - auto syn0 = reinterpret_cast(vsyn0); - auto syn1Neg = reinterpret_cast(vsyn1Neg); - auto expTable = reinterpret_cast(vexpTable); - auto neu1e = reinterpret_cast(vneu1e); - - T dot = (T) 0.0f; - T g = (T) 0.0f; - - for (int e = 0; e < vectorLength; e++) { - dot += syn0[e] * syn1Neg[e]; - } - - if (dot > HS_MAX_EXP) - g = (code - 1) * alpha; - else if (dot < (T) - HS_MAX_EXP) - g = (code - 0) * alpha; - else { - int idx = (int) ((dot + (T) HS_MAX_EXP) * ((T) expLength / HS_MAX_EXP / 2.0)); - if (idx >= expLength) - return; - - if (idx < 0) - return; - - g = ((T) code - expTable[idx]) * alpha; - } - - // axpy1 - for (int e = 0; e < vectorLength; e++) { - neu1e[e] = g * syn1Neg[e] + neu1e[e]; - } - - // axpy2 - if (!isInference) { - for (int e = 0; e < vectorLength; e++) { - syn1Neg[e] = g * syn0[e] + syn1Neg[e]; - } - } - } - - template - void nSampling_(void *vsyn0, void *vsyn1Neg, void *vexpTable, void *vneu1e, double alpha, int vectorLength, int code, int expLength, bool isInference, cudaStream_t* stream) { - nSamplingKernel<<<1,1,128, *stream>>>(vsyn0, vsyn1Neg, vexpTable, vneu1e, alpha, vectorLength, code, expLength, isInference); - } - - /* - * binarySearch - find element in haystack buffer (haystack - sorted device memory) - * */ - int binarySearch(const int *haystack, const int needle, const int totalElements) { - int firstIndex = 0; - int lastIndex = totalElements - 1; - int halfIndex = sd::math::nd4j_floor((lastIndex + firstIndex) / (float) 2); - - while(haystack[halfIndex] != needle && firstIndex < lastIndex) { - if (needle < haystack[halfIndex]) { - lastIndex = halfIndex - 1; - } else if (needle > haystack[halfIndex]) { - firstIndex = halfIndex + 1; - } - halfIndex = sd::math::nd4j_floor((lastIndex + firstIndex) / (float) 2); - } - - return (haystack[halfIndex] == needle) ? halfIndex : -1; - } - template - __global__ void addInfVectorKernel(T* neu1, T* infVector, int vectorLength) { - auto start = blockIdx.x * blockDim.x + threadIdx.x; - auto step = blockDim.x * gridDim.x; - - for (auto i = start; i < vectorLength; i += step) { - neu1[i] += infVector[i]; - } - } - - template - void skipgram_(NDArray& s0, NDArray& s1, NDArray& s1n, NDArray& expTableV, NDArray& negTableV, NDArray& infV, int target, int ngStarter, NDArray& indices, NDArray& codes, double alpha, Nd4jLong randomValue, const int hsRounds, const int nsRounds) { -// void *vsyn0, void *vsyn1, void *vsyn1Neg, void *vexpTable, void *vnegTable, void *vinfVector, int target, int ngStarter, int *indices, int8_t *codes, double alpha, Nd4jLong randomValue, const int hsRounds, const int nsRounds, const int vocabSize, const int vectorLength, const int expLength, const int negLength) { - auto syn0 = reinterpret_cast(s0.specialBuffer()); - auto syn1 = reinterpret_cast(s1.specialBuffer()); - auto syn1Neg = reinterpret_cast(s1n.specialBuffer()); - auto expTable = reinterpret_cast(expTableV.specialBuffer()); - auto negTable = reinterpret_cast(negTableV.specialBuffer()); - auto infVector = reinterpret_cast(infV.specialBuffer()); - const int vocabSize = s0.sizeAt(0); - const int vectorLength = s0.sizeAt(1); - const int expLength = expTableV.lengthOf(); - const int negLength = negTableV.lengthOf(); - indices.tickReadDevice(); - indices.syncToHost(); - codes.tickReadDevice(); - codes.syncToHost(); - auto stream = s0.getContext()->getCudaStream(); - - T* neu1e; // = new T[vectorLength]; - //memset(neu1e, 0, vectorLength * sizeof(T)); - auto err = cudaMalloc(&neu1e, sizeof(T) * vectorLength); - err = cudaMemset(neu1e, 0, sizeof(T) * vectorLength); - // hierarchic softmax goes first (if enabled) - - auto syn0row = infVector != nullptr ? infVector : syn0 + (target * vectorLength); - auto irow = 0; - if (hsRounds > 0) { - for (int r = 0; r < hsRounds; r++) { - irow = indices.t(r); - if (irow < 0 || irow >= vocabSize) - break; - - hSoftmax_(syn0row, syn1 + (irow * vectorLength), expTable, neu1e, alpha, vectorLength, codes.t(r), expLength, infVector != nullptr, stream); - } - } - - // negative sampling goes second (if enabled) - auto nsStarter = ngStarter; - irow = nsStarter; - if (nsRounds > 0) { - for (int r = 0; r < nsRounds + 1; r++) { - if (r == 0) { - // target is known in advance - } else { - randomValue = randomValue * (unsigned long long) 25214903917 + 11; - auto idx = sd::math::nd4j_abs((randomValue >> 16) % negLength); - irow = idx >= negLength ? -1 : negTableV.e(idx); - - if (irow < 0 || irow >= vocabSize) irow = randomValue % (vocabSize - 1) + 1; - if (irow == nsStarter) - continue; - } - - nSampling_(syn0row, syn1Neg + (irow * vectorLength), expTable, neu1e, alpha, vectorLength, r == 0 ? 1 : 0, expLength, infVector != nullptr, stream); - } - } - - if (infVector == nullptr) { - addInfVectorKernel<<<128, 256, 256, *stream>>>(syn0row, neu1e, vectorLength); - } else { - addInfVectorKernel<<<128, 256, 256, *stream>>>(infVector, neu1e, vectorLength); - } - err = cudaStreamSynchronize(*stream); - if (0 != err) { - throw cuda_exception::build("helpers::skipgram_: Cannot synchronize stream after addInfVectorKernel", err); - } - - err = cudaFree(neu1e); - if (0 != err) { - throw cuda_exception::build("helpers::skipgram_: Cannot deallocate temp memory for lingual net", err); - } - } - BUILD_SINGLE_TEMPLATE(template void skipgram_, (NDArray& syn0, NDArray& syn1, NDArray& syn1Neg, NDArray& expTable, NDArray& negTable, NDArray& infVector, int target, int ngStarter, NDArray& indices, NDArray& codes, double alpha, Nd4jLong randomValue, const int hsRounds, const int nsRounds), FLOAT_TYPES); - - /* - * batched version of skipgram routine - * */ - template - void skipgramBatchExec_(NDArray &s0, NDArray &s1, NDArray &s1n, NDArray& expTableV, NDArray& negTableV, NDArray &targets, NDArray &negStarters, NDArray &indices, NDArray &codes, NDArray &lr, NDArray &nextRandom, const int nsRounds, const bool preciseMode, const int numThreads) { -// (NDArray &s0, NDArray &s1, NDArray &s1n, NDArray& expTable, NDArray& negTable, NDArray& infVector, NDArray& targets, NDArray& negStarters, NDArray& indices, NDArray& codes, NDArray& lr, NDArray& nextRandom, const int nsRounds, const bool preciseMode, const int numThreads) { - //auto syn0 = reinterpret_cast(vsyn0); - //auto syn1 = reinterpret_cast(vsyn1); - //auto syn1Neg = reinterpret_cast(vsyn1Neg); - auto stream = s0.getContext()->getCudaStream(); - negTableV.tickReadDevice(); - negTableV.syncToHost(); - const auto expTable = reinterpret_cast(expTableV.specialBuffer()); - const auto negTable = reinterpret_cast(negTableV.buffer()); - const auto infVector = (T*)nullptr; //reinterpret_cast(infVector.specialBuffer()); - - const int vocabSize = s0.sizeAt(0); - const int vectorLength = s0.sizeAt(1); - const int expLength = expTableV.lengthOf(); - const int negLength = negTableV.lengthOf(); - - //T sneu1e[600]; - - //const auto numThreads = omp_get_max_threads(); - const auto idxShift = indices.isEmpty() ? 0 : indices.sizeAt(1); - const auto hsRounds = codes.isEmpty() ? 0 : codes.sizeAt(1); - - // regular mode provides 0 guarantees for reproducibility - auto numTargets = targets.lengthOf(); - targets.syncToHost(); - indices.syncToHost(); - codes.syncToHost(); - lr.syncToHost(); - nextRandom.syncToHost(); - negStarters.tickReadDevice(); - negStarters.syncToHost(); - auto bTarget = reinterpret_cast(targets.buffer()); //targets.bufferAsT(); - auto bIndices = reinterpret_cast(indices.buffer()); //indices.bufferAsT(); - auto bCodes = reinterpret_cast(codes.buffer()); //codes.bufferAsT(); - -// PRAGMA_OMP_PARALLEL_FOR_ARGS(num_threads(numThreads)) - for (int t = 0; t < numTargets; t++) { - T* neu1e;//lvectorLength <= 600 ? sneu1e : new T[vectorLength]; - auto err = cudaMalloc(&neu1e, vectorLength * sizeof(T)); - err = cudaMemset(neu1e, 0, vectorLength * sizeof(T)); - //memset(neu1e, 0, vectorLength * sizeof(T)); - - auto target = bTarget[t]; - auto alpha = lr.e(t); - unsigned long long randomValue = nextRandom.e(t); - - auto syn0row = reinterpret_cast(s0.specialBuffer()) + (target * vectorLength); - - if (hsRounds > 0) { - int irow = 0; - auto cShift = t * idxShift; - - for (int e = 0; e < hsRounds; e++) { - irow = bIndices[e + cShift]; - if (irow < 0 || irow >= vocabSize) - continue; - - auto syn1row = reinterpret_cast(s1.specialBuffer()) + (irow * vectorLength); - auto code = bCodes[e + cShift]; - - //nd4j_printf("syn0: [%i]; syn1: [%i]; code: [%i]\n", target, irow, code); - hSoftmax_(syn0row, syn1row, expTable, neu1e, alpha, vectorLength, code, expLength, false, stream); - } - } - - - if (nsRounds > 0) { - int irow = negStarters.e(t); - int nsStarter = irow; - for (int r = 0; r < nsRounds + 1; r++) { - if (r == 0) { - // target is known in advance - } else { - randomValue = randomValue * (unsigned long long) 25214903917 + 11; - auto idx = sd::math::nd4j_abs((randomValue >> 16) % negLength); - irow = idx >= negLength ? -1 : static_cast(negTable[idx]); - - if (irow < 0 || irow >= vocabSize) - irow = randomValue % (vocabSize - 1) + 1; - - if (irow == nsStarter) - continue; - } - auto syn1row = reinterpret_cast(s1n.specialBuffer()) + (irow * vectorLength); - - nSampling_(syn0row, syn1row, expTable, neu1e, alpha, vectorLength, r == 0 ? 1 : 0, expLength, false, stream); - } - } - addInfVectorKernel<<<128, 256, 256, *stream>>>(syn0row, neu1e, vectorLength); - err = cudaStreamSynchronize(*stream); - if (0 != err) { - throw cuda_exception::build("helpers::skipgramBatchExec_: Cannot synchronize stream after addInfVectorKernel", err); - } - - // optionally release temp arrays - err = cudaFree(neu1e); - if (err != 0) { - throw cuda_exception::build("helpers::skipgramBatchExec_: Cannot deallocate memory with stage", err); - break; - } -// if (vectorLength > 600) -// delete[] neu1e; - } - } - BUILD_SINGLE_TEMPLATE(template void skipgramBatchExec_, (NDArray &s0, NDArray &s1, NDArray &s1n, NDArray& expTable, NDArray& negTable, NDArray &targets, NDArray &negStarters, NDArray &indices, NDArray &codes, NDArray &lr, NDArray &nextRandom, const int nsRounds, const bool preciseMode, const int numThreads), FLOAT_TYPES); - - void skipgram(NDArray &syn0, NDArray &syn1, NDArray &syn1Neg, NDArray &expTable, NDArray &negTable, - NDArray &target, NDArray &ngStarter, int nsRounds, NDArray &indices, NDArray &codes, NDArray &alpha, NDArray &randomValue, NDArray &inferenceVector, const bool preciseMode, const int numWorkers) { - auto xType = syn0.dataType(); - // single round case - if ((ngStarter.isScalar() && !ngStarter.isEmpty())|| (target.isScalar() && !target.isEmpty())) { - auto hsRounds = codes.lengthOf(); - target.syncToHost(); - ngStarter.syncToHost(); - alpha.syncToHost(); - randomValue.syncToHost(); - - auto targetV = target.isEmpty() ? -1 : target.e(0); - auto starterV = ngStarter.isEmpty() ? -1 : ngStarter.e(0); - auto alphaV = alpha.e(0); - auto randomV = randomValue.e(0); - BUILD_SINGLE_SELECTOR(xType, skipgram_, (syn0, syn1, syn1Neg, expTable, negTable, inferenceVector, targetV, starterV, indices, codes, alphaV, randomV, hsRounds, nsRounds), FLOAT_TYPES); - } else if (ngStarter.isVector() || target.isVector()){ - // batch mode -// NDArray& infVector, NDArray &targets, NDArray &negStarters, NDArray &indices, NDArray &codes, NDArray &lr, NDArray &nextRandom, const int nsRounds, const bool preciseMode, const int numThreads) - BUILD_SINGLE_SELECTOR(xType, skipgramBatchExec_, (syn0, syn1, syn1Neg, expTable, negTable, target, ngStarter, indices, codes, alpha, randomValue, nsRounds, preciseMode, numWorkers), FLOAT_TYPES); - } else - throw std::runtime_error("SkipGram: target must have rank 0 or 1"); - } - - template - static __global__ void checkContextKernel(int* context, T* syn0, T* neu1, int contextWidth, int vectorLength, int vocabSize) { - __shared__ bool hasError; - if (0 == threadIdx.x) { - hasError = false; - } - auto start = blockIdx.x * blockDim.x + threadIdx.x; - auto step = blockDim.x * gridDim.x; - - for (int c = start; c < contextWidth; c += step) { - if (context[c] >= vocabSize) - hasError = true; //throw std::runtime_error("Bad context 4"); - if (!hasError) { - T *syn0word = syn0 + (context[c] * vectorLength); - - for (int i = 0; i < vectorLength; i++) { - neu1[i] += syn0word[i]; - } - } - } - if (threadIdx.x == 0) { - if (hasError) - neu1[0] = DataTypeUtils::infOrMax(); - } - __syncthreads(); - } - - template - __global__ void shiftKernel(T* neu1, T* infVector, int contextWidth, int vectorLength) { - auto start = blockIdx.x * blockDim.x + threadIdx.x; - auto step = blockDim.x * gridDim.x; - - for (int i = start; i < vectorLength; i += step) { - neu1[i] /= contextWidth + int(infVector != nullptr); // ? 1 : 0); - } - } - - template - __global__ void fillUpSynonymsKernel(int starter, int contextWidth, int vectorLength, int* lockedWords, int* context, T* neu1e, T* syn0) { - auto start = threadIdx.x + blockIdx.x * blockDim.x; - auto step = blockDim.x * gridDim.x; - - for (int c = starter + start; c < contextWidth; c += step) { - if (lockedWords[c] == 1) - continue; - - T *syn0word = syn0 + (context[c] * vectorLength); - - for (int i = 0; i < vectorLength; i++) { - syn0word[i] += neu1e[i]; - } - } - } - - template - void cbow_(LaunchContext* lc, void *vsyn0, void *vsyn1, void *vsyn1Neg, void *vexpTable, void *vnegTable, void *vinfVector, int target, int ngStarter, int *context, int *lockedWords, int *indices, int8_t *codes, double alpha, Nd4jLong randomValue, const int contextWidth, const int hsRounds, const int nsRounds, const int vocabSize, const int vectorLength, const int expLength, const int negLength, const int numLabels, const bool trainWords) { - auto syn0 = reinterpret_cast(vsyn0); - auto syn1 = reinterpret_cast(vsyn1); - auto syn1Neg = reinterpret_cast(vsyn1Neg); - auto expTable = reinterpret_cast(vexpTable); - auto negTable = reinterpret_cast(vnegTable); - auto infVector = reinterpret_cast(vinfVector); - auto stream = lc->getCudaStream(); - - T* neu1; // = new T[vectorLength]; - T* neu1e; // = new T[vectorLength]; - size_t buffSize = sizeof(T) * vectorLength; - auto err = cudaMalloc(&neu1, buffSize); - err = cudaMalloc(&neu1e, buffSize); - err = cudaMemset(neu1, 0, buffSize); - err = cudaMemset(neu1e, 0, buffSize); - - // building neu1 for current window - checkContextKernel<<<1,1,128,*stream>>>(context, syn0, neu1, contextWidth, vectorLength, vocabSize); - - T checkVal; - err = cudaMemcpy(&checkVal, neu1, sizeof(T), cudaMemcpyDeviceToHost); - if (DataTypeUtils::infOrMax() == checkVal) - throw std::runtime_error("Bad context 4"); - // for inference we add additional inference vector - if (infVector != nullptr) { - addInfVectorKernel<<<128, 256, 128, *stream>>>(neu1, infVector, vectorLength); - } - - - // average neu1 - if (contextWidth > 0) { - shiftKernel<<<128, 256, 128, *stream>>>(neu1, infVector, contextWidth, vectorLength); - } - - // softmax round - if (hsRounds > 0) { - for (int i = 0; i < hsRounds; i++) { - if (indices[i] < 0 || indices[i] >= vocabSize) - throw std::runtime_error("Bad context 5"); - T* syn1Shifted = syn1 + (indices[i] * vectorLength); - hSoftmax_(neu1, syn1Shifted, expTable, neu1e, alpha, vectorLength, codes[i], expLength, infVector != nullptr, stream); - } - } - - auto nsStarter = ngStarter; - auto irow = nsStarter; - if (nsRounds > 0) { - for (int r = 0; r < nsRounds + 1; r++) { - if (r == 0) { - // target is known in advance - } else { - randomValue = randomValue * (unsigned long long) 25214903917 + 11; - auto idx = sd::math::nd4j_abs((randomValue >> 16) % negLength); - irow = idx >= negLength ? -1 : static_cast(negTable[idx]); - - if (irow < 0 || irow >= vocabSize) irow = randomValue % (vocabSize - 1) + 1; - if (irow == nsStarter) - continue; - } - - nSampling_(neu1, syn1Neg + (irow * vectorLength), expTable, neu1e, alpha, vectorLength, r == 0 ? 1 : 0, expLength, infVector != nullptr, stream); - } - } - - // if we don't train words - we skip start of idxSyn0 - int starter = trainWords == 1 ? 0 : contextWidth - numLabels; - - // propagate neu1e -> syn0 - if (infVector == nullptr) { - fillUpSynonymsKernel<<<1,1,128, *stream>>>(starter, contextWidth, vectorLength, lockedWords, context, neu1e, syn0); - } else { - - for (int i = 0; i < vectorLength; i++) { - infVector[i] += neu1e[i]; - } - } - err = cudaStreamSynchronize(*stream); - if (0 != err) { - throw cuda_exception::build( - "helpers::cbow_: Cannot synchronize stream after kernel executing", err); - } - err = cudaFree(neu1); - if (0 != err) { - throw cuda_exception::build( - "helpers::cbow_: Cannot deallocate memory for synonims table", err); - } - - err = cudaFree(neu1e); - if (0 != err) { - throw cuda_exception::build( - "helpers::cbow_: Cannot deallocate memory for antonims table", err); - } - } - BUILD_SINGLE_TEMPLATE(template void cbow_, (LaunchContext* lc, void *syn0, void *syn1, void *syn1Neg, void *expTable, void *vnegTable, void *vinfVector, int target, int ngStarter, int *context, int *lockedWords, int *indices, int8_t *codes, double alpha, Nd4jLong randomValue, const int contextWidth, const int hsRounds, const int nsRounds, const int vocabSize, const int vectorLength, const int expLength, const int negLength, const int numLabels, const bool trainWords), FLOAT_TYPES); - - template - static __global__ void buildCurrentWindowKernel(int vocabSize, int contextWidth, int vectorLength, int* bContext, T* syn0, T* neu1, int* actualContext, int e) { - // building neu1 for current window - auto start = blockIdx.x * blockDim.x + threadIdx.x; - auto step = blockDim.x * gridDim.x; - - for (int c = start; c < contextWidth; c += step) { - // getting next context word - auto cContext = bContext[c + (e * contextWidth)]; - - // skipping padded values - if (cContext < 0) - continue; - -// if (cContext >= vocabSize) -// throw std::runtime_error("ContextID can't be >= vocab size"); - - T *syn0word = syn0 + (cContext * vectorLength); - - for (int i = 0; i < vectorLength; i++) - neu1[i] += syn0word[i]; - - atomicAdd(actualContext, 1); - } - } - - template - __global__ void arrangeNeuKernel(int vectorLength, T* neu1, T* infVector, int* actualContext) { - auto start = blockIdx.x * blockDim.x + threadIdx.x; - auto step = blockDim.x * gridDim.x; - - for (int i = start; i < vectorLength && *actualContext > 0; i += step) - neu1[i] /= (*actualContext + int(infVector != nullptr)); - } - - template - __global__ void applyShiftKernel(int* bContext, int* bLocker, T* syn0, T* neu1e, int contextWidth, int vectorLength, int e, int starter) { - auto step = blockDim.x * gridDim.x; - auto start = blockDim.x * blockIdx.x + threadIdx.x; - - for (int c = starter + start; c < contextWidth; c += step) { - // getting context - auto cContext = bContext[c + (e * contextWidth)]; - auto cLock = bLocker[c + (e * contextWidth)]; - - // skipping padded values - if (cContext < 0 || cLock == 1) - continue; - -// if (cContext >= vocabSize) -// throw std::runtime_error("ContextID can't be > vocab size"); - - // one word from context - T *syn0word = syn0 + (cContext * vectorLength); - - for (int i = 0; i < vectorLength; i++) - syn0word[i] += neu1e[i]; - - } - } - - template - void cbowBatchExec_(LaunchContext* lc, NDArray &s0, NDArray &s1, NDArray &s1n, void *vexpTable, void *vnegTable, void *vinfVector, NDArray &context, NDArray &lockedWords, NDArray &targets, NDArray &negStarters, NDArray &indices, NDArray &codes, NDArray &lr, NDArray &nextRandom, NDArray &nLabels, const int nsRounds, const int vocabSize, const int vectorLength, const int expLength, const int negLength, const bool trainWords, const int numThreads) { - const auto syn0 = reinterpret_cast(s0.specialBuffer()); //bufferAsT(); - const auto syn1 = reinterpret_cast(s1.specialBuffer()); //bufferAsT(); - const auto syn1Neg = reinterpret_cast(s1n.specialBuffer()); //bufferAsT(); - - const auto expTable = reinterpret_cast(vexpTable); - const auto negTable = reinterpret_cast(vnegTable); - const auto infVector = reinterpret_cast(vinfVector); - - auto stream = lc->getCudaStream(); - - indices.syncToHost(); - codes.syncToHost(); - negStarters.syncToHost(); - context.syncToHost(); - - //const auto numThreads = omp_get_max_threads(); - const auto idxShift = indices.isEmpty() ? 0 : indices.sizeAt(1); - const auto hsRounds = codes.isEmpty() ? 0 : codes.sizeAt(1); - const auto numTargets = context.sizeAt(0); - const int contextWidth = context.sizeAt(1); - //const auto bContext = reinterpret_cast(context.buffer()); //bufferAsT(); - const auto dContext = context.dataBuffer()->specialAsT(); //bufferAsT(); -// const auto bLocker = reinterpret_cast(lockedWords.buffer()); //lockedWords.bufferAsT(); - const auto dLocker = lockedWords.dataBuffer()->specialAsT(); //.specialBuffer()); //lockedWords.bufferAsT(); - const auto bIndices = indices.dataBuffer()->primaryAsT(); //buffer());//AsT(); - const auto bCodes = codes.dataBuffer()->primaryAsT(); //reinterpret_cast(codes.buffer()); //bufferAsT(); - const auto bStarters = negStarters.dataBuffer()->primaryAsT(); //reinterpret_cast(negStarters.buffer()); //AsT(); - const auto numIndices = indices.isEmpty() ? 0 : indices.sizeAt(1); - lr.syncToHost(); - nLabels.syncToHost(); - //PRAGMA_OMP_PARALLEL_FOR_ARGS(num_threads(numThreads) private(sneu1, sneu1e)) - //NDArray neuVector('c', {vectorLength}, DataTypeUtils::fromT()); - // auto neuEVector = neuVector; //NDArrayFactory::create('c', {vectorLength}); - T* neu1; // = reinterpret_cast(neuVector.specialBuffer());// = vectorLength <= 600 ? sneu1 : new T[vectorLength]; - T* neu1e; // = reinterpret_cast(neuVector.specialBuffer()); // = vectorLength <= 600 ? sneu1e : new T[vectorLength]; - auto cerr = cudaMalloc(&neu1, sizeof(T) * vectorLength); - if (cerr) { - throw cuda_exception::build("Cannot allocate temp vector buffer", cerr); - } - cerr = cudaMalloc(&neu1e, sizeof(T) * vectorLength); - if (cerr) { - throw cuda_exception::build("Cannot allocate temp vector buffer", cerr); - } - int* actualContext; - cerr = cudaMalloc(&actualContext, sizeof(int)); - if (cerr) { - throw cuda_exception::build("Cannot allocate counter buffer", cerr); - } - - for (int e = 0; e < numTargets; e++) { - -// auto err = cudaMalloc(&neu1, sizeof(T)* vectorLength); -// q err = cudaMalloc(&neu1e, sizeof(T)*vectorLength); -// -// // optionally we nullify temp arrays after successful (and on first) cycle -// memset(neu1, 0, sizeof(T) * vectorLength); -// memset(neu1e, 0, sizeof(T) * vectorLength); - - auto alpha = lr.e(e); - auto numLabels = nLabels.isEmpty() ? 0 : nLabels.e(e); - -// auto err = cudaMemset(actualContext, 0, sizeof(int)); -// if (err) { -// printf("Cuda error %d\n", err); break; -// } - - buildCurrentWindowKernel<<<1,1,128, *stream>>>(vocabSize, contextWidth, vectorLength, dContext, syn0, neu1, actualContext, e); - arrangeNeuKernel<<<1,1,128, *stream>>>(vectorLength, neu1, infVector, actualContext); - - // hierarchic softmax step - if (!indices.isEmpty()) { - for (int i = 0; i < numIndices; i++) { - const int cIndex = bIndices[(e * numIndices) + i]; - const int cCode = bCodes[(e * numIndices) + i]; - - // we're skipping padded values - if (cIndex < 0) - continue; - - if (cIndex >= vocabSize) - throw std::runtime_error("Index can't be > vocab size"); - - hSoftmax_(neu1, syn1 + (cIndex * vectorLength), expTable, neu1e, alpha, vectorLength, cCode, expLength, false, stream); - } - } - - // negative sampling step - if (!negStarters.isEmpty() && nsRounds > 0) { - int irow = bStarters[e]; - const int nsStarter = irow; - unsigned long long randomValue = nextRandom.e(e); - - for (int r = 0; r < nsRounds + 1; r++) { - // we're skipping rng on 0 step - if (r != 0) { - randomValue = randomValue * (unsigned long long) 25214903917 + 11; - auto idx = sd::math::nd4j_abs((randomValue >> 16) % negLength); - irow = idx >= negLength ? -1 : static_cast(negTable[idx]); - - if (irow < 0 || irow >= vocabSize) irow = randomValue % (vocabSize - 1) + 1; - if (irow == nsStarter) - continue; - - nSampling_(neu1, s1n.bufferWithOffset(irow * vectorLength), expTable, neu1e, alpha, vectorLength, r == 0 ? 1 : 0, expLength, infVector != nullptr, stream); - } else { - nSampling_(neu1, s1n.bufferWithOffset(irow * vectorLength), expTable, neu1e, alpha, vectorLength, r == 0 ? 1 : 0, expLength, infVector != nullptr, stream); - } - - //nd4j_printf("Thread <%i>: syn0: [%i]; s1n: [%i];\n", omp_get_thread_num(), 0, irow); - } - } - - - // if we're skipping labels - int starter = trainWords == 1 ? 0 : contextWidth - numLabels; - - // applying previously averaged results - applyShiftKernel<<<1,1,128, *stream>>>(dContext, dLocker, syn0, neu1e, contextWidth, vectorLength, e, starter); - - // optionally release temp arrays -// if (vectorLength > 600) { -// } - - } - cerr = cudaStreamSynchronize(*stream); - if (cerr) { - throw cuda_exception::build("Cannot syncronize stream before memory deallocation", cerr); - } - - cerr = cudaFree(neu1); - if (cerr) { - throw cuda_exception::build("Cannot deallocate temp buffer1", cerr); - } - cerr = cudaFree(neu1e); - if (cerr) { - throw cuda_exception::build("Cannot deallocate temp buffer1 E", cerr); - } - cerr = cudaFree(actualContext); - if (cerr) { - throw cuda_exception::build("Cannot deallocate temp buffer1", cerr); - } - - } - BUILD_SINGLE_TEMPLATE(template void cbowBatchExec_, (LaunchContext* lc, NDArray &s0, NDArray &s1, NDArray &s1n, void *vexpTable, void *vnegTable, void *vinfVector, NDArray &context, NDArray &lockedWords, NDArray &targets, NDArray &negStarters, NDArray &indices, NDArray &codes, NDArray &lr, NDArray &nextRandom, NDArray &nLabels, const int nsRounds, const int vocabSize, const int vectorLength, const int expLength, const int negLength, const bool trainWords, const int numThreads), FLOAT_TYPES); - - void cbow(NDArray &syn0, NDArray &syn1, NDArray &syn1Neg, NDArray &expTable, NDArray &negTable, NDArray &target, NDArray &ngStarter, int nsRounds, NDArray &context, NDArray &lockedWords, NDArray &indices, NDArray &codes, NDArray &alpha, NDArray &randomValue, NDArray &numLabels, NDArray &inferenceVector, const bool trainWords, int numWorkers) { - auto xType = syn0.dataType(); - auto lc = context.getContext(); - indices.syncToHost(); - NDArray::prepareSpecialUse({&syn0, &syn1, &syn1Neg, &expTable, &negTable, &target, &ngStarter}, {&context, &lockedWords, &indices, &codes, &alpha, &randomValue, &numLabels, &inferenceVector}); - //auto stream = lc->getCudaStream(); - if ((context.rankOf() == 0 || context.rankOf() == 1) && (indices.rankOf() == 1 || indices.rankOf() == 0)) { - // single round case - /*nd4j_printf("Row exec; ContextWidth: %i; LockedWords: %i; numLabels: %i; Train words: %i\n", (int) context.lengthOf(), (int) lockedWords.lengthOf(), numLabels.isEmpty() ? 0 : numLabels.e(0), (int) trainWords); - if (context.lengthOf() == 2) { - context.printBuffer("context"); - lockedWords.printBuffer("locked"); - codes.printBuffer("codes"); - indices.printBuffer("indices"); - }*/ - - auto hsRounds = codes.lengthOf(); - target.syncToHost(); - numLabels.syncToHost(); - target.syncToHost(); - alpha.syncToHost(); - numLabels.syncToHost(); - codes.syncToHost(); - negTable.syncToHost(); - BUILD_SINGLE_SELECTOR(xType, cbow_, (lc, syn0.specialBuffer(), syn1.specialBuffer(), syn1Neg.specialBuffer(), expTable.specialBuffer(), negTable.buffer(), inferenceVector.specialBuffer(), target.isEmpty() ? -1 : target.e(0), ngStarter.isEmpty() ? -1 : ngStarter.e(0), reinterpret_cast(context.specialBuffer()), reinterpret_cast(lockedWords.specialBuffer()),reinterpret_cast(indices.buffer()), reinterpret_cast(codes.buffer()), alpha.e( 0), randomValue.e(0), (int) context.lengthOf(), hsRounds, nsRounds, (int) syn0.sizeAt(0), (int) syn0.sizeAt(1), (int) expTable.lengthOf(), (int) negTable.lengthOf(), numLabels.isEmpty() ? 0 : numLabels.e(0), trainWords), FLOAT_TYPES); - } else if (context.rankOf() == 2 && indices.rankOf() == 2) { - // batch mode - //nd4j_printf("Batch exec\n",""); - - BUILD_SINGLE_SELECTOR(xType, cbowBatchExec_, (lc, syn0, syn1, syn1Neg, expTable.specialBuffer(), negTable.specialBuffer(), nullptr, context, lockedWords, target, ngStarter, indices, codes, alpha, randomValue, numLabels, nsRounds, syn0.sizeAt(0), syn0.sizeAt(1), expTable.lengthOf(), negTable.isEmpty() ? 0 : negTable.lengthOf(), trainWords, numWorkers), FLOAT_TYPES); - } else - throw std::runtime_error("CBOW: context must have rank 0/1 or 2"); - - NDArray::registerSpecialUse({&syn0, &syn1, &syn1Neg, &expTable, &negTable, &target, &ngStarter}, {&context, &lockedWords, &indices, &codes, &alpha, &randomValue, &numLabels, &inferenceVector}); - } +namespace ops { +namespace helpers { +template +__global__ void hSoftmaxKernel(void *vsyn0, void *vsyn1, void *vexpTable, + void *vneu1e, double alpha, int vectorLength, + int code, int expLength, bool isInference) { + auto syn0 = reinterpret_cast(vsyn0); + auto syn1 = reinterpret_cast(vsyn1); + auto expTable = reinterpret_cast(vexpTable); + auto neu1e = reinterpret_cast(vneu1e); + + T dot(0.0f); + T g(0.0f); + T f(0.0f); + + // dot + for (int e = 0; e < vectorLength; e++) { + dot += syn0[e] * syn1[e]; + } + + // gradient + if (dot < (T)-HS_MAX_EXP || dot >= (T)HS_MAX_EXP) return; + + int idx = static_cast((dot + HS_MAX_EXP) * + ((float)expLength / HS_MAX_EXP / 2.0f)); + + if (idx >= expLength || idx < 0) return; + + f = expTable[idx]; + g = (static_cast(1.0f) - static_cast(code) - f) * (T)alpha; + + // axpy1 + + for (int e = 0; e < vectorLength; e++) { + neu1e[e] = g * syn1[e] + neu1e[e]; + } + + // axpy2 + if (!isInference) { + for (int e = 0; e < vectorLength; e++) { + syn1[e] = g * syn0[e] + syn1[e]; + } + } +} + +template +void hSoftmax_(void *vsyn0, void *vsyn1, void *vexpTable, void *vneu1e, + double alpha, int vectorLength, int code, int expLength, + bool isInference, cudaStream_t *stream) { + hSoftmaxKernel<<<1, 1, 128, *stream>>>(vsyn0, vsyn1, vexpTable, vneu1e, + alpha, vectorLength, code, + expLength, isInference); +} + +template +__global__ void nSamplingKernel(void *vsyn0, void *vsyn1Neg, void *vexpTable, + void *vneu1e, double alpha, int vectorLength, + int code, int expLength, bool isInference) { + auto syn0 = reinterpret_cast(vsyn0); + auto syn1Neg = reinterpret_cast(vsyn1Neg); + auto expTable = reinterpret_cast(vexpTable); + auto neu1e = reinterpret_cast(vneu1e); + + T dot = (T)0.0f; + T g = (T)0.0f; + + for (int e = 0; e < vectorLength; e++) { + dot += syn0[e] * syn1Neg[e]; + } + + if (dot > HS_MAX_EXP) + g = (code - 1) * alpha; + else if (dot < (T)-HS_MAX_EXP) + g = (code - 0) * alpha; + else { + int idx = (int)((dot + (T)HS_MAX_EXP) * ((T)expLength / HS_MAX_EXP / 2.0)); + if (idx >= expLength) return; + + if (idx < 0) return; + + g = ((T)code - expTable[idx]) * alpha; + } + + // axpy1 + for (int e = 0; e < vectorLength; e++) { + neu1e[e] = g * syn1Neg[e] + neu1e[e]; + } + + // axpy2 + if (!isInference) { + for (int e = 0; e < vectorLength; e++) { + syn1Neg[e] = g * syn0[e] + syn1Neg[e]; + } + } +} + +template +void nSampling_(void *vsyn0, void *vsyn1Neg, void *vexpTable, void *vneu1e, + double alpha, int vectorLength, int code, int expLength, + bool isInference, cudaStream_t *stream) { + nSamplingKernel<<<1, 1, 128, *stream>>>(vsyn0, vsyn1Neg, vexpTable, vneu1e, + alpha, vectorLength, code, + expLength, isInference); +} + +/* + * binarySearch - find element in haystack buffer (haystack - sorted device + * memory) + * */ +int binarySearch(const int *haystack, const int needle, + const int totalElements) { + int firstIndex = 0; + int lastIndex = totalElements - 1; + int halfIndex = + sd::math::nd4j_floor((lastIndex + firstIndex) / (float)2); + + while (haystack[halfIndex] != needle && firstIndex < lastIndex) { + if (needle < haystack[halfIndex]) { + lastIndex = halfIndex - 1; + } else if (needle > haystack[halfIndex]) { + firstIndex = halfIndex + 1; + } + halfIndex = + sd::math::nd4j_floor((lastIndex + firstIndex) / (float)2); + } + + return (haystack[halfIndex] == needle) ? halfIndex : -1; +} +template +__global__ void addInfVectorKernel(T *neu1, T *infVector, int vectorLength) { + auto start = blockIdx.x * blockDim.x + threadIdx.x; + auto step = blockDim.x * gridDim.x; + + for (auto i = start; i < vectorLength; i += step) { + neu1[i] += infVector[i]; + } +} + +template +void skipgram_(NDArray &s0, NDArray &s1, NDArray &s1n, NDArray &expTableV, + NDArray &negTableV, NDArray &infV, int target, int ngStarter, + NDArray &indices, NDArray &codes, double alpha, + Nd4jLong randomValue, const int hsRounds, const int nsRounds) { + // void *vsyn0, void *vsyn1, void *vsyn1Neg, void + // *vexpTable, void *vnegTable, void *vinfVector, int + // target, int ngStarter, int *indices, int8_t *codes, + // double alpha, Nd4jLong randomValue, const int hsRounds, + // const int nsRounds, const int vocabSize, const int + // vectorLength, const int expLength, const int negLength) + // { + auto syn0 = reinterpret_cast(s0.specialBuffer()); + auto syn1 = reinterpret_cast(s1.specialBuffer()); + auto syn1Neg = reinterpret_cast(s1n.specialBuffer()); + auto expTable = reinterpret_cast(expTableV.specialBuffer()); + auto negTable = reinterpret_cast(negTableV.specialBuffer()); + auto infVector = reinterpret_cast(infV.specialBuffer()); + const int vocabSize = s0.sizeAt(0); + const int vectorLength = s0.sizeAt(1); + const int expLength = expTableV.lengthOf(); + const int negLength = negTableV.lengthOf(); + indices.tickReadDevice(); + indices.syncToHost(); + codes.tickReadDevice(); + codes.syncToHost(); + auto stream = s0.getContext()->getCudaStream(); + + T *neu1e; // = new T[vectorLength]; + // memset(neu1e, 0, vectorLength * sizeof(T)); + auto err = cudaMalloc(&neu1e, sizeof(T) * vectorLength); + err = cudaMemset(neu1e, 0, sizeof(T) * vectorLength); + // hierarchic softmax goes first (if enabled) + + auto syn0row = + infVector != nullptr ? infVector : syn0 + (target * vectorLength); + auto irow = 0; + if (hsRounds > 0) { + for (int r = 0; r < hsRounds; r++) { + irow = indices.t(r); + if (irow < 0 || irow >= vocabSize) break; + + hSoftmax_(syn0row, syn1 + (irow * vectorLength), expTable, neu1e, + alpha, vectorLength, codes.t(r), expLength, + infVector != nullptr, stream); + } + } + + // negative sampling goes second (if enabled) + auto nsStarter = ngStarter; + irow = nsStarter; + if (nsRounds > 0) { + for (int r = 0; r < nsRounds + 1; r++) { + if (r == 0) { + // target is known in advance + } else { + randomValue = randomValue * (unsigned long long)25214903917 + 11; + auto idx = + sd::math::nd4j_abs((randomValue >> 16) % negLength); + irow = idx >= negLength ? -1 : negTableV.e(idx); + + if (irow < 0 || irow >= vocabSize) + irow = randomValue % (vocabSize - 1) + 1; + if (irow == nsStarter) continue; + } + + nSampling_(syn0row, syn1Neg + (irow * vectorLength), expTable, neu1e, + alpha, vectorLength, r == 0 ? 1 : 0, expLength, + infVector != nullptr, stream); + } + } + + if (infVector == nullptr) { + addInfVectorKernel + <<<128, 256, 256, *stream>>>(syn0row, neu1e, vectorLength); + } else { + addInfVectorKernel + <<<128, 256, 256, *stream>>>(infVector, neu1e, vectorLength); + } + err = cudaStreamSynchronize(*stream); + if (0 != err) { + throw cuda_exception::build( + "helpers::skipgram_: Cannot synchronize stream after " + "addInfVectorKernel", + err); + } + + err = cudaFree(neu1e); + if (0 != err) { + throw cuda_exception::build( + "helpers::skipgram_: Cannot deallocate temp memory for lingual net", + err); + } +} +BUILD_SINGLE_TEMPLATE(template void skipgram_, + (NDArray & syn0, NDArray &syn1, NDArray &syn1Neg, + NDArray &expTable, NDArray &negTable, NDArray &infVector, + int target, int ngStarter, NDArray &indices, + NDArray &codes, double alpha, Nd4jLong randomValue, + const int hsRounds, const int nsRounds), + FLOAT_TYPES); + +/* + * batched version of skipgram routine + * */ +template +void skipgramBatchExec_(NDArray &s0, NDArray &s1, NDArray &s1n, + NDArray &expTableV, NDArray &negTableV, + NDArray &targets, NDArray &negStarters, + NDArray &indices, NDArray &codes, NDArray &lr, + NDArray &nextRandom, const int nsRounds, + const bool preciseMode, const int numThreads) { + // (NDArray &s0, NDArray &s1, NDArray &s1n, NDArray& expTable, + // NDArray& negTable, NDArray& infVector, NDArray& targets, + // NDArray& negStarters, NDArray& indices, NDArray& codes, NDArray& + // lr, NDArray& nextRandom, const int nsRounds, const bool + // preciseMode, const int numThreads) { + // auto syn0 = reinterpret_cast(vsyn0); + // auto syn1 = reinterpret_cast(vsyn1); + // auto syn1Neg = reinterpret_cast(vsyn1Neg); + auto stream = s0.getContext()->getCudaStream(); + negTableV.tickReadDevice(); + negTableV.syncToHost(); + const auto expTable = reinterpret_cast(expTableV.specialBuffer()); + const auto negTable = reinterpret_cast(negTableV.buffer()); + const auto infVector = + (T *)nullptr; // reinterpret_cast(infVector.specialBuffer()); + + const int vocabSize = s0.sizeAt(0); + const int vectorLength = s0.sizeAt(1); + const int expLength = expTableV.lengthOf(); + const int negLength = negTableV.lengthOf(); + + // T sneu1e[600]; + + // const auto numThreads = omp_get_max_threads(); + const auto idxShift = indices.isEmpty() ? 0 : indices.sizeAt(1); + const auto hsRounds = codes.isEmpty() ? 0 : codes.sizeAt(1); + + // regular mode provides 0 guarantees for reproducibility + auto numTargets = targets.lengthOf(); + targets.syncToHost(); + indices.syncToHost(); + codes.syncToHost(); + lr.syncToHost(); + nextRandom.syncToHost(); + negStarters.tickReadDevice(); + negStarters.syncToHost(); + auto bTarget = + reinterpret_cast(targets.buffer()); // targets.bufferAsT(); + auto bIndices = + reinterpret_cast(indices.buffer()); // indices.bufferAsT(); + auto bCodes = + reinterpret_cast(codes.buffer()); // codes.bufferAsT(); + + // PRAGMA_OMP_PARALLEL_FOR_ARGS(num_threads(numThreads)) + for (int t = 0; t < numTargets; t++) { + T *neu1e; // lvectorLength <= 600 ? sneu1e : new T[vectorLength]; + auto err = cudaMalloc(&neu1e, vectorLength * sizeof(T)); + err = cudaMemset(neu1e, 0, vectorLength * sizeof(T)); + // memset(neu1e, 0, vectorLength * sizeof(T)); + + auto target = bTarget[t]; + auto alpha = lr.e(t); + unsigned long long randomValue = nextRandom.e(t); + + auto syn0row = + reinterpret_cast(s0.specialBuffer()) + (target * vectorLength); + + if (hsRounds > 0) { + int irow = 0; + auto cShift = t * idxShift; + + for (int e = 0; e < hsRounds; e++) { + irow = bIndices[e + cShift]; + if (irow < 0 || irow >= vocabSize) continue; + + auto syn1row = + reinterpret_cast(s1.specialBuffer()) + (irow * vectorLength); + auto code = bCodes[e + cShift]; + + // nd4j_printf("syn0: [%i]; syn1: [%i]; code: [%i]\n", target, irow, + // code); + hSoftmax_(syn0row, syn1row, expTable, neu1e, alpha, vectorLength, + code, expLength, false, stream); + } + } + + if (nsRounds > 0) { + int irow = negStarters.e(t); + int nsStarter = irow; + for (int r = 0; r < nsRounds + 1; r++) { + if (r == 0) { + // target is known in advance + } else { + randomValue = randomValue * (unsigned long long)25214903917 + 11; + auto idx = + sd::math::nd4j_abs((randomValue >> 16) % negLength); + irow = idx >= negLength ? -1 : static_cast(negTable[idx]); + + if (irow < 0 || irow >= vocabSize) + irow = randomValue % (vocabSize - 1) + 1; + + if (irow == nsStarter) continue; + } + auto syn1row = + reinterpret_cast(s1n.specialBuffer()) + (irow * vectorLength); + nSampling_(syn0row, syn1row, expTable, neu1e, alpha, vectorLength, + r == 0 ? 1 : 0, expLength, false, stream); + } + } + addInfVectorKernel + <<<128, 256, 256, *stream>>>(syn0row, neu1e, vectorLength); + err = cudaStreamSynchronize(*stream); + if (0 != err) { + throw cuda_exception::build( + "helpers::skipgramBatchExec_: Cannot synchronize stream after " + "addInfVectorKernel", + err); + } + + // optionally release temp arrays + err = cudaFree(neu1e); + if (err != 0) { + throw cuda_exception::build( + "helpers::skipgramBatchExec_: Cannot deallocate memory with stage", + err); + break; + } + // if (vectorLength > 600) + // delete[] neu1e; + } +} +BUILD_SINGLE_TEMPLATE(template void skipgramBatchExec_, + (NDArray & s0, NDArray &s1, NDArray &s1n, + NDArray &expTable, NDArray &negTable, NDArray &targets, + NDArray &negStarters, NDArray &indices, NDArray &codes, + NDArray &lr, NDArray &nextRandom, const int nsRounds, + const bool preciseMode, const int numThreads), + FLOAT_TYPES); + +void skipgram(NDArray &syn0, NDArray &syn1, NDArray &syn1Neg, NDArray &expTable, + NDArray &negTable, NDArray &target, NDArray &ngStarter, + int nsRounds, NDArray &indices, NDArray &codes, NDArray &alpha, + NDArray &randomValue, NDArray &inferenceVector, + const bool preciseMode, const int numWorkers) { + auto xType = syn0.dataType(); + // single round case + if ((ngStarter.isScalar() && !ngStarter.isEmpty()) || + (target.isScalar() && !target.isEmpty())) { + auto hsRounds = codes.lengthOf(); + target.syncToHost(); + ngStarter.syncToHost(); + alpha.syncToHost(); + randomValue.syncToHost(); + + auto targetV = target.isEmpty() ? -1 : target.e(0); + auto starterV = ngStarter.isEmpty() ? -1 : ngStarter.e(0); + auto alphaV = alpha.e(0); + auto randomV = randomValue.e(0); + BUILD_SINGLE_SELECTOR( + xType, skipgram_, + (syn0, syn1, syn1Neg, expTable, negTable, inferenceVector, targetV, + starterV, indices, codes, alphaV, randomV, hsRounds, nsRounds), + FLOAT_TYPES); + } else if (ngStarter.isVector() || target.isVector()) { + // batch mode + // NDArray& infVector, NDArray &targets, NDArray + // &negStarters, NDArray &indices, NDArray &codes, + // NDArray &lr, NDArray &nextRandom, const int nsRounds, + // const bool preciseMode, const int numThreads) + BUILD_SINGLE_SELECTOR( + xType, skipgramBatchExec_, + (syn0, syn1, syn1Neg, expTable, negTable, target, ngStarter, indices, + codes, alpha, randomValue, nsRounds, preciseMode, numWorkers), + FLOAT_TYPES); + } else + throw std::runtime_error("SkipGram: target must have rank 0 or 1"); +} + +template +static __global__ void checkContextKernel(int *context, T *syn0, T *neu1, + int contextWidth, int vectorLength, + int vocabSize) { + __shared__ bool hasError; + if (0 == threadIdx.x) { + hasError = false; + } + auto start = blockIdx.x * blockDim.x + threadIdx.x; + auto step = blockDim.x * gridDim.x; + + for (int c = start; c < contextWidth; c += step) { + if (context[c] >= vocabSize) + hasError = true; // throw std::runtime_error("Bad context 4"); + if (!hasError) { + T *syn0word = syn0 + (context[c] * vectorLength); + + for (int i = 0; i < vectorLength; i++) { + neu1[i] += syn0word[i]; + } + } + } + if (threadIdx.x == 0) { + if (hasError) neu1[0] = DataTypeUtils::infOrMax(); + } + __syncthreads(); +} + +template +__global__ void shiftKernel(T *neu1, T *infVector, int contextWidth, + int vectorLength) { + auto start = blockIdx.x * blockDim.x + threadIdx.x; + auto step = blockDim.x * gridDim.x; + + for (int i = start; i < vectorLength; i += step) { + neu1[i] /= contextWidth + int(infVector != nullptr); // ? 1 : 0); + } +} + +template +__global__ void fillUpSynonymsKernel(int starter, int contextWidth, + int vectorLength, int *lockedWords, + int *context, T *neu1e, T *syn0) { + auto start = threadIdx.x + blockIdx.x * blockDim.x; + auto step = blockDim.x * gridDim.x; + + for (int c = starter + start; c < contextWidth; c += step) { + if (lockedWords[c] == 1) continue; + + T *syn0word = syn0 + (context[c] * vectorLength); + + for (int i = 0; i < vectorLength; i++) { + syn0word[i] += neu1e[i]; + } + } +} + +template +void cbow_(LaunchContext *lc, void *vsyn0, void *vsyn1, void *vsyn1Neg, + void *vexpTable, void *vnegTable, void *vinfVector, int target, + int ngStarter, int *context, int *lockedWords, int *indices, + int8_t *codes, double alpha, Nd4jLong randomValue, + const int contextWidth, const int hsRounds, const int nsRounds, + const int vocabSize, const int vectorLength, const int expLength, + const int negLength, const int numLabels, const bool trainWords) { + auto syn0 = reinterpret_cast(vsyn0); + auto syn1 = reinterpret_cast(vsyn1); + auto syn1Neg = reinterpret_cast(vsyn1Neg); + auto expTable = reinterpret_cast(vexpTable); + auto negTable = reinterpret_cast(vnegTable); + auto infVector = reinterpret_cast(vinfVector); + auto stream = lc->getCudaStream(); + + T *neu1; // = new T[vectorLength]; + T *neu1e; // = new T[vectorLength]; + size_t buffSize = sizeof(T) * vectorLength; + auto err = cudaMalloc(&neu1, buffSize); + err = cudaMalloc(&neu1e, buffSize); + err = cudaMemset(neu1, 0, buffSize); + err = cudaMemset(neu1e, 0, buffSize); + + // building neu1 for current window + checkContextKernel<<<1, 1, 128, *stream>>>( + context, syn0, neu1, contextWidth, vectorLength, vocabSize); + + T checkVal; + err = cudaMemcpy(&checkVal, neu1, sizeof(T), cudaMemcpyDeviceToHost); + if (DataTypeUtils::infOrMax() == checkVal) + throw std::runtime_error("Bad context 4"); + // for inference we add additional inference vector + if (infVector != nullptr) { + addInfVectorKernel + <<<128, 256, 128, *stream>>>(neu1, infVector, vectorLength); + } + + // average neu1 + if (contextWidth > 0) { + shiftKernel<<<128, 256, 128, *stream>>>(neu1, infVector, contextWidth, + vectorLength); + } + + // softmax round + if (hsRounds > 0) { + for (int i = 0; i < hsRounds; i++) { + if (indices[i] < 0 || indices[i] >= vocabSize) + throw std::runtime_error("Bad context 5"); + T *syn1Shifted = syn1 + (indices[i] * vectorLength); + hSoftmax_(neu1, syn1Shifted, expTable, neu1e, alpha, vectorLength, + codes[i], expLength, infVector != nullptr, stream); + } + } + + auto nsStarter = ngStarter; + auto irow = nsStarter; + if (nsRounds > 0) { + for (int r = 0; r < nsRounds + 1; r++) { + if (r == 0) { + // target is known in advance + } else { + randomValue = randomValue * (unsigned long long)25214903917 + 11; + auto idx = + sd::math::nd4j_abs((randomValue >> 16) % negLength); + irow = idx >= negLength ? -1 : static_cast(negTable[idx]); + + if (irow < 0 || irow >= vocabSize) + irow = randomValue % (vocabSize - 1) + 1; + if (irow == nsStarter) continue; + } + + nSampling_(neu1, syn1Neg + (irow * vectorLength), expTable, neu1e, + alpha, vectorLength, r == 0 ? 1 : 0, expLength, + infVector != nullptr, stream); + } + } + + // if we don't train words - we skip start of idxSyn0 + int starter = trainWords == 1 ? 0 : contextWidth - numLabels; + + // propagate neu1e -> syn0 + if (infVector == nullptr) { + fillUpSynonymsKernel<<<1, 1, 128, *stream>>>( + starter, contextWidth, vectorLength, lockedWords, context, neu1e, syn0); + } else { + for (int i = 0; i < vectorLength; i++) { + infVector[i] += neu1e[i]; + } + } + err = cudaStreamSynchronize(*stream); + if (0 != err) { + throw cuda_exception::build( + "helpers::cbow_: Cannot synchronize stream after kernel executing", + err); + } + err = cudaFree(neu1); + if (0 != err) { + throw cuda_exception::build( + "helpers::cbow_: Cannot deallocate memory for synonims table", err); + } + + err = cudaFree(neu1e); + if (0 != err) { + throw cuda_exception::build( + "helpers::cbow_: Cannot deallocate memory for antonims table", err); + } +} +BUILD_SINGLE_TEMPLATE( + template void cbow_, + (LaunchContext * lc, void *syn0, void *syn1, void *syn1Neg, void *expTable, + void *vnegTable, void *vinfVector, int target, int ngStarter, int *context, + int *lockedWords, int *indices, int8_t *codes, double alpha, + Nd4jLong randomValue, const int contextWidth, const int hsRounds, + const int nsRounds, const int vocabSize, const int vectorLength, + const int expLength, const int negLength, const int numLabels, + const bool trainWords), + FLOAT_TYPES); + +template +static __global__ void buildCurrentWindowKernel(int vocabSize, int contextWidth, + int vectorLength, int *bContext, + T *syn0, T *neu1, + int *actualContext, int e) { + // building neu1 for current window + auto start = blockIdx.x * blockDim.x + threadIdx.x; + auto step = blockDim.x * gridDim.x; + + for (int c = start; c < contextWidth; c += step) { + // getting next context word + auto cContext = bContext[c + (e * contextWidth)]; + + // skipping padded values + if (cContext < 0) continue; + + // if (cContext >= vocabSize) + // throw std::runtime_error("ContextID can't be >= + // vocab size"); + + T *syn0word = syn0 + (cContext * vectorLength); + + for (int i = 0; i < vectorLength; i++) neu1[i] += syn0word[i]; + + atomicAdd(actualContext, 1); + } +} + +template +__global__ void arrangeNeuKernel(int vectorLength, T *neu1, T *infVector, + int *actualContext) { + auto start = blockIdx.x * blockDim.x + threadIdx.x; + auto step = blockDim.x * gridDim.x; + + for (int i = start; i 0; i += step) + neu1[i] /= (*actualContext + int(infVector != nullptr)); +} + +template +__global__ void applyShiftKernel(int *bContext, int *bLocker, T *syn0, T *neu1e, + int contextWidth, int vectorLength, int e, + int starter) { + auto step = blockDim.x * gridDim.x; + auto start = blockDim.x * blockIdx.x + threadIdx.x; + + for (int c = starter + start; c < contextWidth; c += step) { + // getting context + auto cContext = bContext[c + (e * contextWidth)]; + auto cLock = bLocker[c + (e * contextWidth)]; + + // skipping padded values + if (cContext < 0 || cLock == 1) continue; + + // if (cContext >= vocabSize) + // throw std::runtime_error("ContextID can't be > + // vocab size"); + + // one word from context + T *syn0word = syn0 + (cContext * vectorLength); + + for (int i = 0; i < vectorLength; i++) syn0word[i] += neu1e[i]; + } +} + +template +void cbowBatchExec_(LaunchContext *lc, NDArray &s0, NDArray &s1, NDArray &s1n, + void *vexpTable, void *vnegTable, void *vinfVector, + NDArray &context, NDArray &lockedWords, NDArray &targets, + NDArray &negStarters, NDArray &indices, NDArray &codes, + NDArray &lr, NDArray &nextRandom, NDArray &nLabels, + const int nsRounds, const int vocabSize, + const int vectorLength, const int expLength, + const int negLength, const bool trainWords, + const int numThreads) { + const auto syn0 = + reinterpret_cast(s0.specialBuffer()); // bufferAsT(); + const auto syn1 = + reinterpret_cast(s1.specialBuffer()); // bufferAsT(); + const auto syn1Neg = + reinterpret_cast(s1n.specialBuffer()); // bufferAsT(); + + const auto expTable = reinterpret_cast(vexpTable); + const auto negTable = reinterpret_cast(vnegTable); + const auto infVector = reinterpret_cast(vinfVector); + + auto stream = lc->getCudaStream(); + + indices.syncToHost(); + codes.syncToHost(); + negStarters.syncToHost(); + context.syncToHost(); + + // const auto numThreads = omp_get_max_threads(); + const auto idxShift = indices.isEmpty() ? 0 : indices.sizeAt(1); + const auto hsRounds = codes.isEmpty() ? 0 : codes.sizeAt(1); + const auto numTargets = context.sizeAt(0); + const int contextWidth = context.sizeAt(1); + // const auto bContext = reinterpret_cast(context.buffer()); + // //bufferAsT(); + const auto dContext = + context.dataBuffer()->specialAsT(); // bufferAsT(); + // const auto bLocker = + // reinterpret_cast(lockedWords.buffer()); + // //lockedWords.bufferAsT(); + const auto dLocker = + lockedWords.dataBuffer() + ->specialAsT(); //.specialBuffer()); + ////lockedWords.bufferAsT(); + const auto bIndices = + indices.dataBuffer()->primaryAsT(); // buffer());//AsT(); + const auto bCodes = + codes.dataBuffer() + ->primaryAsT(); // reinterpret_cast(codes.buffer()); + // //bufferAsT(); + const auto bStarters = + negStarters.dataBuffer() + ->primaryAsT(); // reinterpret_cast(negStarters.buffer()); + // //AsT(); + const auto numIndices = indices.isEmpty() ? 0 : indices.sizeAt(1); + lr.syncToHost(); + nLabels.syncToHost(); + // PRAGMA_OMP_PARALLEL_FOR_ARGS(num_threads(numThreads) private(sneu1, + // sneu1e)) NDArray neuVector('c', {vectorLength}, DataTypeUtils::fromT()); + // auto neuEVector = neuVector; //NDArrayFactory::create('c', + // {vectorLength}); + T *neu1; // = reinterpret_cast(neuVector.specialBuffer());// = + // vectorLength <= 600 ? sneu1 : new T[vectorLength]; + T *neu1e; // = reinterpret_cast(neuVector.specialBuffer()); // = + // vectorLength <= 600 ? sneu1e : new T[vectorLength]; + auto cerr = cudaMalloc(&neu1, sizeof(T) * vectorLength); + if (cerr) { + throw cuda_exception::build("Cannot allocate temp vector buffer", cerr); + } + cerr = cudaMalloc(&neu1e, sizeof(T) * vectorLength); + if (cerr) { + throw cuda_exception::build("Cannot allocate temp vector buffer", cerr); + } + int *actualContext; + cerr = cudaMalloc(&actualContext, sizeof(int)); + if (cerr) { + throw cuda_exception::build("Cannot allocate counter buffer", cerr); + } + + for (int e = 0; e < numTargets; e++) { + // auto err = cudaMalloc(&neu1, sizeof(T)* vectorLength); + // q err = cudaMalloc(&neu1e, sizeof(T)*vectorLength); + // + // // optionally we nullify temp arrays after successful + // (and on first) cycle memset(neu1, 0, sizeof(T) * + // vectorLength); memset(neu1e, 0, sizeof(T) * + // vectorLength); + + auto alpha = lr.e(e); + auto numLabels = nLabels.isEmpty() ? 0 : nLabels.e(e); + + // auto err = cudaMemset(actualContext, 0, sizeof(int)); + // if (err) { + // printf("Cuda error %d\n", err); break; + // } + + buildCurrentWindowKernel + <<<1, 1, 128, *stream>>>(vocabSize, contextWidth, vectorLength, + dContext, syn0, neu1, actualContext, e); + arrangeNeuKernel + <<<1, 1, 128, *stream>>>(vectorLength, neu1, infVector, actualContext); + + // hierarchic softmax step + if (!indices.isEmpty()) { + for (int i = 0; i < numIndices; i++) { + const int cIndex = bIndices[(e * numIndices) + i]; + const int cCode = bCodes[(e * numIndices) + i]; + + // we're skipping padded values + if (cIndex < 0) continue; + + if (cIndex >= vocabSize) + throw std::runtime_error("Index can't be > vocab size"); + + hSoftmax_(neu1, syn1 + (cIndex * vectorLength), expTable, neu1e, + alpha, vectorLength, cCode, expLength, false, stream); + } + } + + // negative sampling step + if (!negStarters.isEmpty() && nsRounds > 0) { + int irow = bStarters[e]; + const int nsStarter = irow; + unsigned long long randomValue = nextRandom.e(e); + + for (int r = 0; r < nsRounds + 1; r++) { + // we're skipping rng on 0 step + if (r != 0) { + randomValue = randomValue * (unsigned long long)25214903917 + 11; + auto idx = + sd::math::nd4j_abs((randomValue >> 16) % negLength); + irow = idx >= negLength ? -1 : static_cast(negTable[idx]); + + if (irow < 0 || irow >= vocabSize) + irow = randomValue % (vocabSize - 1) + 1; + if (irow == nsStarter) continue; + + nSampling_(neu1, s1n.bufferWithOffset(irow * vectorLength), + expTable, neu1e, alpha, vectorLength, r == 0 ? 1 : 0, + expLength, infVector != nullptr, stream); + } else { + nSampling_(neu1, s1n.bufferWithOffset(irow * vectorLength), + expTable, neu1e, alpha, vectorLength, r == 0 ? 1 : 0, + expLength, infVector != nullptr, stream); } + + // nd4j_printf("Thread <%i>: syn0: [%i]; s1n: [%i];\n", + // omp_get_thread_num(), 0, irow); + } } -} \ No newline at end of file + + // if we're skipping labels + int starter = trainWords == 1 ? 0 : contextWidth - numLabels; + + // applying previously averaged results + applyShiftKernel<<<1, 1, 128, *stream>>>( + dContext, dLocker, syn0, neu1e, contextWidth, vectorLength, e, starter); + + // optionally release temp arrays + // if (vectorLength > 600) { + // } + } + cerr = cudaStreamSynchronize(*stream); + if (cerr) { + throw cuda_exception::build( + "Cannot syncronize stream before memory deallocation", cerr); + } + + cerr = cudaFree(neu1); + if (cerr) { + throw cuda_exception::build("Cannot deallocate temp buffer1", cerr); + } + cerr = cudaFree(neu1e); + if (cerr) { + throw cuda_exception::build("Cannot deallocate temp buffer1 E", cerr); + } + cerr = cudaFree(actualContext); + if (cerr) { + throw cuda_exception::build("Cannot deallocate temp buffer1", cerr); + } +} +BUILD_SINGLE_TEMPLATE(template void cbowBatchExec_, + (LaunchContext * lc, NDArray &s0, NDArray &s1, + NDArray &s1n, void *vexpTable, void *vnegTable, + void *vinfVector, NDArray &context, NDArray &lockedWords, + NDArray &targets, NDArray &negStarters, NDArray &indices, + NDArray &codes, NDArray &lr, NDArray &nextRandom, + NDArray &nLabels, const int nsRounds, + const int vocabSize, const int vectorLength, + const int expLength, const int negLength, + const bool trainWords, const int numThreads), + FLOAT_TYPES); + +void cbow(NDArray &syn0, NDArray &syn1, NDArray &syn1Neg, NDArray &expTable, + NDArray &negTable, NDArray &target, NDArray &ngStarter, int nsRounds, + NDArray &context, NDArray &lockedWords, NDArray &indices, + NDArray &codes, NDArray &alpha, NDArray &randomValue, + NDArray &numLabels, NDArray &inferenceVector, const bool trainWords, + int numWorkers) { + auto xType = syn0.dataType(); + auto lc = context.getContext(); + indices.syncToHost(); + NDArray::prepareSpecialUse( + {&syn0, &syn1, &syn1Neg, &expTable, &negTable, &target, &ngStarter}, + {&context, &lockedWords, &indices, &codes, &alpha, &randomValue, + &numLabels, &inferenceVector}); + // auto stream = lc->getCudaStream(); + if ((context.rankOf() == 0 || context.rankOf() == 1) && + (indices.rankOf() == 1 || indices.rankOf() == 0)) { + // single round case + /*nd4j_printf("Row exec; ContextWidth: %i; LockedWords: %i; numLabels: %i; + Train words: %i\n", (int) context.lengthOf(), (int) lockedWords.lengthOf(), + numLabels.isEmpty() ? 0 : numLabels.e(0), (int) trainWords); if + (context.lengthOf() == 2) { context.printBuffer("context"); + lockedWords.printBuffer("locked"); + codes.printBuffer("codes"); + indices.printBuffer("indices"); + }*/ + + auto hsRounds = codes.lengthOf(); + target.syncToHost(); + numLabels.syncToHost(); + target.syncToHost(); + alpha.syncToHost(); + numLabels.syncToHost(); + codes.syncToHost(); + negTable.syncToHost(); + BUILD_SINGLE_SELECTOR( + xType, cbow_, + (lc, syn0.specialBuffer(), syn1.specialBuffer(), + syn1Neg.specialBuffer(), expTable.specialBuffer(), negTable.buffer(), + inferenceVector.specialBuffer(), + target.isEmpty() ? -1 : target.e(0), + ngStarter.isEmpty() ? -1 : ngStarter.e(0), + reinterpret_cast(context.specialBuffer()), + reinterpret_cast(lockedWords.specialBuffer()), + reinterpret_cast(indices.buffer()), + reinterpret_cast(codes.buffer()), alpha.e(0), + randomValue.e(0), (int)context.lengthOf(), hsRounds, + nsRounds, (int)syn0.sizeAt(0), (int)syn0.sizeAt(1), + (int)expTable.lengthOf(), (int)negTable.lengthOf(), + numLabels.isEmpty() ? 0 : numLabels.e(0), trainWords), + FLOAT_TYPES); + } else if (context.rankOf() == 2 && indices.rankOf() == 2) { + // batch mode + // nd4j_printf("Batch exec\n",""); + + BUILD_SINGLE_SELECTOR( + xType, cbowBatchExec_, + (lc, syn0, syn1, syn1Neg, expTable.specialBuffer(), + negTable.specialBuffer(), nullptr, context, lockedWords, target, + ngStarter, indices, codes, alpha, randomValue, numLabels, nsRounds, + syn0.sizeAt(0), syn0.sizeAt(1), expTable.lengthOf(), + negTable.isEmpty() ? 0 : negTable.lengthOf(), trainWords, numWorkers), + FLOAT_TYPES); + } else + throw std::runtime_error("CBOW: context must have rank 0/1 or 2"); + + NDArray::registerSpecialUse( + {&syn0, &syn1, &syn1Neg, &expTable, &negTable, &target, &ngStarter}, + {&context, &lockedWords, &indices, &codes, &alpha, &randomValue, + &numLabels, &inferenceVector}); +} + +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/shift.cu b/libnd4j/include/ops/declarable/helpers/cuda/shift.cu index c69285ef29b9..cc241d3eb30d 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/shift.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/shift.cu @@ -21,61 +21,65 @@ #include namespace sd { - namespace ops { - namespace helpers { - template - void rshift_bits_(LaunchContext* launchContext, NDArray &input, NDArray &output, uint32_t shift) { - auto lambda = LAMBDA_T(x, shift) { - return x >> shift; - }; +namespace ops { +namespace helpers { +template +void rshift_bits_(LaunchContext *launchContext, NDArray &input, NDArray &output, + uint32_t shift) { + auto lambda = LAMBDA_T(x, shift) { return x >> shift; }; - input.applyLambda(lambda, output); - } + input.applyLambda(lambda, output); +} - void rshift_bits(LaunchContext* launchContext, NDArray &x, NDArray &z, uint32_t shift) { - BUILD_SINGLE_SELECTOR(x.dataType(), rshift_bits_, (launchContext, x, z, shift), INTEGER_TYPES); - } +void rshift_bits(LaunchContext *launchContext, NDArray &x, NDArray &z, + uint32_t shift) { + BUILD_SINGLE_SELECTOR(x.dataType(), rshift_bits_, + (launchContext, x, z, shift), INTEGER_TYPES); +} - template - void shift_bits_(LaunchContext* launchContext, NDArray &input, NDArray &output, uint32_t shift) { - auto lambda = LAMBDA_T(x, shift) { - return x << shift; - }; +template +void shift_bits_(LaunchContext *launchContext, NDArray &input, NDArray &output, + uint32_t shift) { + auto lambda = LAMBDA_T(x, shift) { return x << shift; }; - input.applyLambda(lambda, output); - } + input.applyLambda(lambda, output); +} - void shift_bits(LaunchContext* launchContext, NDArray &x, NDArray &z, uint32_t shift) { - BUILD_SINGLE_SELECTOR(x.dataType(), shift_bits_, (launchContext, x, z, shift), INTEGER_TYPES); - } +void shift_bits(LaunchContext *launchContext, NDArray &x, NDArray &z, + uint32_t shift) { + BUILD_SINGLE_SELECTOR(x.dataType(), shift_bits_, (launchContext, x, z, shift), + INTEGER_TYPES); +} - template - void cyclic_rshift_bits_(LaunchContext* launchContext, NDArray &input, NDArray &output, uint32_t shift) { - auto step = (sizeof(T) * 8) - shift; - auto lambda = LAMBDA_T(x, shift, step) { - return x >> shift | x << step; - }; +template +void cyclic_rshift_bits_(LaunchContext *launchContext, NDArray &input, + NDArray &output, uint32_t shift) { + auto step = (sizeof(T) * 8) - shift; + auto lambda = LAMBDA_T(x, shift, step) { return x >> shift | x << step; }; - input.applyLambda(lambda, output); - } + input.applyLambda(lambda, output); +} - void cyclic_rshift_bits(LaunchContext* launchContext, NDArray &x, NDArray &z, uint32_t shift) { - BUILD_SINGLE_SELECTOR(x.dataType(), cyclic_rshift_bits_, (launchContext, x, z, shift), INTEGER_TYPES); - } +void cyclic_rshift_bits(LaunchContext *launchContext, NDArray &x, NDArray &z, + uint32_t shift) { + BUILD_SINGLE_SELECTOR(x.dataType(), cyclic_rshift_bits_, + (launchContext, x, z, shift), INTEGER_TYPES); +} - template - void cyclic_shift_bits_(LaunchContext* launchContext, NDArray &input, NDArray &output, uint32_t shift) { - auto step = (sizeof(T) * 8) - shift; - auto lambda = LAMBDA_T(x, shift, step) { - return x << shift | x >> step; - }; +template +void cyclic_shift_bits_(LaunchContext *launchContext, NDArray &input, + NDArray &output, uint32_t shift) { + auto step = (sizeof(T) * 8) - shift; + auto lambda = LAMBDA_T(x, shift, step) { return x << shift | x >> step; }; - input.applyLambda(lambda, output); - } + input.applyLambda(lambda, output); +} - void cyclic_shift_bits(LaunchContext* launchContext, NDArray &x, NDArray &z, uint32_t shift) { - BUILD_SINGLE_SELECTOR(x.dataType(), cyclic_shift_bits_, (launchContext, x, z, shift), INTEGER_TYPES); - } - } - } -} \ No newline at end of file +void cyclic_shift_bits(LaunchContext *launchContext, NDArray &x, NDArray &z, + uint32_t shift) { + BUILD_SINGLE_SELECTOR(x.dataType(), cyclic_shift_bits_, + (launchContext, x, z, shift), INTEGER_TYPES); +} +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/solve.cu b/libnd4j/include/ops/declarable/helpers/cuda/solve.cu index 43ef78c3eb36..e2d8b7ab8de7 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/solve.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/solve.cu @@ -18,123 +18,154 @@ // @author GS // -#include #include #include -#include - #include #include -#include "../triangular_solve.h" +#include +#include + #include "../lup.h" #include "../solve.h" +#include "../triangular_solve.h" namespace sd { - namespace ops { - namespace helpers { - - template - static __global__ void oneOnDiagonalKernel(T* ioBuf, Nd4jLong const* ioShape, Nd4jLong const* tadShape, Nd4jLong const* tadOffsets, Nd4jLong batchNum, Nd4jLong rowNum) { - for (auto i = blockIdx.x; i < batchNum; i += gridDim.x) { - auto matrixPart = ioBuf + tadOffsets[i]; - for (auto j = threadIdx.x; j < rowNum; j += blockDim.x) { - Nd4jLong pos[] = {j, j}; - auto offset = shape::getOffset(tadShape, pos); - - matrixPart[offset] = T(1.f); - } - } - } - - template - static __global__ void restorePermutationsKernel(T* PBuf, Nd4jLong const* PShapeInfo, int const* permutationsBuf, - Nd4jLong const* PTadShapeInfo, Nd4jLong const* PTadSOffsets, Nd4jLong const* permutationsTadShapeInfo, Nd4jLong const* permutationsTadOffsets, Nd4jLong batchNum, Nd4jLong rowNum) { - for (auto batch = blockIdx.x; batch < batchNum; batch += gridDim.x) { - auto permutations = permutationsBuf + permutationsTadOffsets[batch]; - auto P = PBuf + PTadSOffsets[batch]; - - for (auto row = threadIdx.x; row < rowNum; row += blockDim.x) { - //auto posX[] = {row}; - Nd4jLong posZ[] = {row, permutations[row]}; - auto zOffset = shape::getOffset(PTadShapeInfo, posZ); - P[zOffset] = T(1.f); - } - } - } - - template - static int solveFunctor_(sd::LaunchContext * context, NDArray* leftInput, NDArray* rightInput, - bool adjoint, NDArray* output) { - NDArray::prepareSpecialUse({output}, {leftInput, rightInput}); - // stage 1: LU decomposition batched - auto leftOutput = leftInput->ulike(); - auto permuShape = rightInput->getShapeAsVector(); permuShape.pop_back(); - auto permutations = NDArrayFactory::create('c', permuShape, context); - helpers::lu(context, leftInput, &leftOutput, &permutations); - auto leftLower = leftOutput.dup(); - auto rightOutput = rightInput->ulike(); - auto leftLowerTad = ConstantTadHelper::getInstance().tadForDimensions(leftLower.shapeInfo(), {-2, -1}); - auto stream = context->getCudaStream(); - oneOnDiagonalKernel<<<128, 256, 256, *stream>>>(leftLower.dataBuffer()->specialAsT(), leftLower.specialShapeInfo(), leftLowerTad.specialShapeInfo(), leftLowerTad.specialOffsets(), leftLowerTad.numberOfTads(), leftLower.sizeAt(-1)); - auto P = leftOutput.ulike(); P.nullify(); - auto PTad = ConstantTadHelper::getInstance().tadForDimensions(P.shapeInfo(), {-2, -1}); - auto permutationsTad = ConstantTadHelper::getInstance().tadForDimensions(permutations.shapeInfo(), {-1}); - restorePermutationsKernel<<<128, 256, 256, *stream>>>(P.dataBuffer()->specialAsT(), P.specialShapeInfo(), permutations.dataBuffer()->specialAsT(), - PTad.specialShapeInfo(), PTad.specialOffsets(), permutationsTad.specialShapeInfo(), permutationsTad.specialOffsets(), permutationsTad.numberOfTads(), permutations.sizeAt(-1)); - P.tickWriteDevice(); - auto rightPart = rightInput->ulike(); - MmulHelper::matmul(&P, rightInput, &rightPart, 0, 0); - - // stage 2: triangularSolveFunctor for Lower with given b - helpers::triangularSolveFunctor(context, &leftLower, &rightPart, true, false, &rightOutput); - // stage 3: triangularSolveFunctor for Upper with output of previous stage - helpers::triangularSolveFunctor(context, &leftOutput, &rightOutput, false, false, output); - NDArray::registerSpecialUse({output}, {leftInput, rightInput}); - - return Status::OK(); - } - - int solveFunctor(sd::LaunchContext * context, NDArray* leftInput, NDArray* rightInput, bool adjoint, NDArray* output) { - BUILD_SINGLE_SELECTOR(leftInput->dataType(), return solveFunctor_, (context, leftInput, rightInput, adjoint, output), FLOAT_TYPES); - } - - template - static __global__ void adjointKernel(T* output, Nd4jLong batchSize, Nd4jLong rows, Nd4jLong columns, Nd4jLong const* outputTads, - Nd4jLong const* outputOffsets) { - - for (auto b = blockIdx.x; b < batchSize; b += gridDim.x) { - auto outputPart = output + outputOffsets[b]; - for (auto r = threadIdx.x; r < rows; r += blockDim.x) { - for (auto c = threadIdx.y; c < r; c += blockDim.y) { - Nd4jLong zPos[] = {r, c}; - Nd4jLong xPos[] = {c, r}; - auto zIndex = shape::getOffset(outputTads, zPos); - auto xIndex = shape::getOffset(outputTads, xPos); - math::nd4j_swap(outputPart[zIndex], outputPart[xIndex]); - } - } - } - - } - - template - static void adjointMatrix_(sd::LaunchContext* context, NDArray const* input, NDArray* output) { - NDArray::prepareSpecialUse({output}, {input}); - auto inputTads = ConstantTadHelper::getInstance().tadForDimensions(input->shapeInfo(), {-2, -1}); - auto outputTads = ConstantTadHelper::getInstance().tadForDimensions(output->shapeInfo(), {-2, -1}); - auto stream = context->getCudaStream(); - auto outputBuf = reinterpret_cast(output->specialBuffer()); - auto rows = input->sizeAt(-2); - auto columns = input->sizeAt(-1); - output->assign(input); - adjointKernel<<<128, 256, 256, *stream>>>(outputBuf, outputTads.numberOfTads(), rows, columns, outputTads.specialShapeInfo(), outputTads.specialOffsets()); - NDArray::registerSpecialUse({output}, {input}); - } - - void adjointMatrix(sd::LaunchContext* context, NDArray const* input, NDArray* output) { - BUILD_SINGLE_SELECTOR(input->dataType(), adjointMatrix_, (context, input, output), FLOAT_TYPES); - } - - } +namespace ops { +namespace helpers { + +template +static __global__ void oneOnDiagonalKernel(T* ioBuf, Nd4jLong const* ioShape, + Nd4jLong const* tadShape, + Nd4jLong const* tadOffsets, + Nd4jLong batchNum, Nd4jLong rowNum) { + for (auto i = blockIdx.x; i < batchNum; i += gridDim.x) { + auto matrixPart = ioBuf + tadOffsets[i]; + for (auto j = threadIdx.x; j < rowNum; j += blockDim.x) { + Nd4jLong pos[] = {j, j}; + auto offset = shape::getOffset(tadShape, pos); + + matrixPart[offset] = T(1.f); + } + } +} + +template +static __global__ void restorePermutationsKernel( + T* PBuf, Nd4jLong const* PShapeInfo, int const* permutationsBuf, + Nd4jLong const* PTadShapeInfo, Nd4jLong const* PTadSOffsets, + Nd4jLong const* permutationsTadShapeInfo, + Nd4jLong const* permutationsTadOffsets, Nd4jLong batchNum, + Nd4jLong rowNum) { + for (auto batch = blockIdx.x; batch < batchNum; batch += gridDim.x) { + auto permutations = permutationsBuf + permutationsTadOffsets[batch]; + auto P = PBuf + PTadSOffsets[batch]; + + for (auto row = threadIdx.x; row < rowNum; row += blockDim.x) { + // auto posX[] = {row}; + Nd4jLong posZ[] = {row, permutations[row]}; + auto zOffset = shape::getOffset(PTadShapeInfo, posZ); + P[zOffset] = T(1.f); } + } } + +template +static int solveFunctor_(sd::LaunchContext* context, NDArray* leftInput, + NDArray* rightInput, bool adjoint, NDArray* output) { + NDArray::prepareSpecialUse({output}, {leftInput, rightInput}); + // stage 1: LU decomposition batched + auto leftOutput = leftInput->ulike(); + auto permuShape = rightInput->getShapeAsVector(); + permuShape.pop_back(); + auto permutations = NDArrayFactory::create('c', permuShape, context); + helpers::lu(context, leftInput, &leftOutput, &permutations); + auto leftLower = leftOutput.dup(); + auto rightOutput = rightInput->ulike(); + auto leftLowerTad = ConstantTadHelper::getInstance().tadForDimensions( + leftLower.shapeInfo(), {-2, -1}); + auto stream = context->getCudaStream(); + oneOnDiagonalKernel<<<128, 256, 256, *stream>>>( + leftLower.dataBuffer()->specialAsT(), leftLower.specialShapeInfo(), + leftLowerTad.specialShapeInfo(), leftLowerTad.specialOffsets(), + leftLowerTad.numberOfTads(), leftLower.sizeAt(-1)); + auto P = leftOutput.ulike(); + P.nullify(); + auto PTad = ConstantTadHelper::getInstance().tadForDimensions(P.shapeInfo(), + {-2, -1}); + auto permutationsTad = ConstantTadHelper::getInstance().tadForDimensions( + permutations.shapeInfo(), {-1}); + restorePermutationsKernel<<<128, 256, 256, *stream>>>( + P.dataBuffer()->specialAsT(), P.specialShapeInfo(), + permutations.dataBuffer()->specialAsT(), PTad.specialShapeInfo(), + PTad.specialOffsets(), permutationsTad.specialShapeInfo(), + permutationsTad.specialOffsets(), permutationsTad.numberOfTads(), + permutations.sizeAt(-1)); + P.tickWriteDevice(); + auto rightPart = rightInput->ulike(); + MmulHelper::matmul(&P, rightInput, &rightPart, 0, 0); + + // stage 2: triangularSolveFunctor for Lower with given b + helpers::triangularSolveFunctor(context, &leftLower, &rightPart, true, false, + &rightOutput); + // stage 3: triangularSolveFunctor for Upper with output of previous stage + helpers::triangularSolveFunctor(context, &leftOutput, &rightOutput, false, + false, output); + NDArray::registerSpecialUse({output}, {leftInput, rightInput}); + + return Status::OK(); +} + +int solveFunctor(sd::LaunchContext* context, NDArray* leftInput, + NDArray* rightInput, bool adjoint, NDArray* output) { + BUILD_SINGLE_SELECTOR(leftInput->dataType(), return solveFunctor_, + (context, leftInput, rightInput, adjoint, output), + FLOAT_TYPES); +} + +template +static __global__ void adjointKernel(T* output, Nd4jLong batchSize, + Nd4jLong rows, Nd4jLong columns, + Nd4jLong const* outputTads, + Nd4jLong const* outputOffsets) { + for (auto b = blockIdx.x; b < batchSize; b += gridDim.x) { + auto outputPart = output + outputOffsets[b]; + for (auto r = threadIdx.x; r < rows; r += blockDim.x) { + for (auto c = threadIdx.y; c < r; c += blockDim.y) { + Nd4jLong zPos[] = {r, c}; + Nd4jLong xPos[] = {c, r}; + auto zIndex = shape::getOffset(outputTads, zPos); + auto xIndex = shape::getOffset(outputTads, xPos); + math::nd4j_swap(outputPart[zIndex], outputPart[xIndex]); + } + } + } +} + +template +static void adjointMatrix_(sd::LaunchContext* context, NDArray const* input, + NDArray* output) { + NDArray::prepareSpecialUse({output}, {input}); + auto inputTads = ConstantTadHelper::getInstance().tadForDimensions( + input->shapeInfo(), {-2, -1}); + auto outputTads = ConstantTadHelper::getInstance().tadForDimensions( + output->shapeInfo(), {-2, -1}); + auto stream = context->getCudaStream(); + auto outputBuf = reinterpret_cast(output->specialBuffer()); + auto rows = input->sizeAt(-2); + auto columns = input->sizeAt(-1); + output->assign(input); + adjointKernel<<<128, 256, 256, *stream>>>( + outputBuf, outputTads.numberOfTads(), rows, columns, + outputTads.specialShapeInfo(), outputTads.specialOffsets()); + NDArray::registerSpecialUse({output}, {input}); +} + +void adjointMatrix(sd::LaunchContext* context, NDArray const* input, + NDArray* output) { + BUILD_SINGLE_SELECTOR(input->dataType(), adjointMatrix_, + (context, input, output), FLOAT_TYPES); +} + +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/split.cu b/libnd4j/include/ops/declarable/helpers/cuda/split.cu index 19c58b89eba0..71e97993f205 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/split.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/split.cu @@ -19,174 +19,195 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // - -#include -#include -#include -#include #include -#include +#include #include -#include #include +#include +#include +#include +#include -namespace sd { -namespace ops { -namespace helpers { +#include +namespace sd { +namespace ops { +namespace helpers { /////////////////////////////////////////////////////////////////// -template -__global__ static void splitCuda(const void* vx, const Nd4jLong* xShapeInfo, void* pVz, const Nd4jLong* zTadShapeInfo, const int axis) { - - const T* x = reinterpret_cast(vx); +template +__global__ static void splitCuda(const void* vx, const Nd4jLong* xShapeInfo, + void* pVz, const Nd4jLong* zTadShapeInfo, + const int axis) { + const T* x = reinterpret_cast(vx); - __shared__ Nd4jLong xLen, totalThreads; - __shared__ int xRank, zDim; + __shared__ Nd4jLong xLen, totalThreads; + __shared__ int xRank, zDim; - if (threadIdx.x == 0) { - xLen = shape::length(xShapeInfo); - xRank = shape::rank(xShapeInfo); - zDim = shape::shapeOf(zTadShapeInfo)[axis]; // same for all input arrays - totalThreads = gridDim.x * blockDim.x; - } - __syncthreads(); - - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + if (threadIdx.x == 0) { + xLen = shape::length(xShapeInfo); + xRank = shape::rank(xShapeInfo); + zDim = shape::shapeOf(zTadShapeInfo)[axis]; // same for all input arrays + totalThreads = gridDim.x * blockDim.x; + } + __syncthreads(); - int coords[MAX_RANK]; + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - for (uint64_t i = tid; i < xLen; i += totalThreads) { + int coords[MAX_RANK]; - shape::index2coords(i, xShapeInfo, coords); + for (uint64_t i = tid; i < xLen; i += totalThreads) { + shape::index2coords(i, xShapeInfo, coords); - const auto xOffset = shape::getOffset(xShapeInfo, coords); + const auto xOffset = shape::getOffset(xShapeInfo, coords); - auto *z = reinterpret_cast(reinterpret_cast(pVz)[coords[axis] / zDim]); + auto* z = reinterpret_cast( + reinterpret_cast(pVz)[coords[axis] / zDim]); - coords[axis] %= zDim; + coords[axis] %= zDim; - const auto zOffset = shape::getOffset(zTadShapeInfo, coords); + const auto zOffset = shape::getOffset(zTadShapeInfo, coords); - z[zOffset] = x[xOffset]; - } + z[zOffset] = x[xOffset]; + } } /////////////////////////////////////////////////////////////////// -template -__host__ static void splitCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, - const void* vx, const Nd4jLong* xShapeInfo, void* pVz, const Nd4jLong* zTadShapeInfo, const int axis) { - - splitCuda<<>>(vx, xShapeInfo, pVz, zTadShapeInfo, axis); +template +__host__ static void splitCudaLauncher( + const int blocksPerGrid, const int threadsPerBlock, + const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo, + void* pVz, const Nd4jLong* zTadShapeInfo, const int axis) { + splitCuda<<>>( + vx, xShapeInfo, pVz, zTadShapeInfo, axis); } -BUILD_SINGLE_TEMPLATE(template void splitCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, const void* vx, const Nd4jLong* xShapeInfo, void* pVz, const Nd4jLong* zTadShapeInfo, const int axis), LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template void splitCudaLauncher, + (const int blocksPerGrid, const int threadsPerBlock, + const cudaStream_t* stream, const void* vx, + const Nd4jLong* xShapeInfo, void* pVz, + const Nd4jLong* zTadShapeInfo, const int axis), + LIBND4J_TYPES); ////////////////////////////////////////////////////////////////////////// -void split(sd::LaunchContext* context, const NDArray& input, std::vector& outArrs, const int axis) { - - const int numOfSubArrs = outArrs.size(); - const auto sizeofT = input.sizeOfT(); - - for(int i = 0; i < numOfSubArrs; ++i) - outArrs[i]->syncToDevice(); - input.syncToDevice(); - - bool luckCase1 = ((axis == 0 && input.ordering() == 'c') || (axis == input.rankOf() - 1 && input.ordering() == 'f')) && input.ews() == 1; - - if(luckCase1) { - for (uint i = 0; i < numOfSubArrs; ++i) { - luckCase1 &= outArrs[i]->ordering() == input.ordering() && outArrs[i]->ews() == 1; - if(!luckCase1) - break; - } +void split(sd::LaunchContext* context, const NDArray& input, + std::vector& outArrs, const int axis) { + const int numOfSubArrs = outArrs.size(); + const auto sizeofT = input.sizeOfT(); + + for (int i = 0; i < numOfSubArrs; ++i) outArrs[i]->syncToDevice(); + input.syncToDevice(); + + bool luckCase1 = ((axis == 0 && input.ordering() == 'c') || + (axis == input.rankOf() - 1 && input.ordering() == 'f')) && + input.ews() == 1; + + if (luckCase1) { + for (uint i = 0; i < numOfSubArrs; ++i) { + luckCase1 &= + outArrs[i]->ordering() == input.ordering() && outArrs[i]->ews() == 1; + if (!luckCase1) break; } + } - if(luckCase1) { // for example {1,10} + {2,10} + {3,10} = {6, 10} order c; or {10,1} + {10,2} + {10,3} = {10, 6} order f + if (luckCase1) { // for example {1,10} + {2,10} + {3,10} = {6, 10} order c; + // or {10,1} + {10,2} + {10,3} = {10, 6} order f - auto x = static_cast(input.specialBuffer()); + auto x = static_cast(input.specialBuffer()); - for (uint i = 0; i < numOfSubArrs; ++i) { - const auto memAmountToCopy = outArrs[i]->lengthOf() * sizeofT; - cudaMemcpyAsync(static_cast(outArrs[i]->specialBuffer()), x, memAmountToCopy, cudaMemcpyDeviceToDevice, *context->getCudaStream()); - x = static_cast(x) + memAmountToCopy; - } + for (uint i = 0; i < numOfSubArrs; ++i) { + const auto memAmountToCopy = outArrs[i]->lengthOf() * sizeofT; + cudaMemcpyAsync(static_cast(outArrs[i]->specialBuffer()), x, + memAmountToCopy, cudaMemcpyDeviceToDevice, + *context->getCudaStream()); + x = static_cast(x) + memAmountToCopy; + } - if(cudaStreamSynchronize(*context->getCudaStream()) != 0) - throw std::runtime_error("split cuda: luckCase1 failed!"); + if (cudaStreamSynchronize(*context->getCudaStream()) != 0) + throw std::runtime_error("split cuda: luckCase1 failed!"); - for(int i = 0; i < numOfSubArrs; ++i) - outArrs[i]->tickWriteDevice(); - input.tickReadDevice(); + for (int i = 0; i < numOfSubArrs; ++i) outArrs[i]->tickWriteDevice(); + input.tickReadDevice(); - return; - } + return; + } - // const bool isXcontin = input.strideAt(axis) == 1; - // bool areOutputsContin = true; - // bool allSameOrder = true; - // std::vector strideOfContigStride(outArrs.size()); + // const bool isXcontin = input.strideAt(axis) == 1; + // bool areOutputsContin = true; + // bool allSameOrder = true; + // std::vector strideOfContigStride(outArrs.size()); - // if(isXcontin) { + // if(isXcontin) { - // for (uint i = 0; i < outArrs.size(); ++i) { + // for (uint i = 0; i < outArrs.size(); ++i) { - // areOutputsContin &= outArrs[i]->strideAt(axis) == 1; - // allSameOrder &= input.ordering() == outArrs[i]->ordering(); - // if(!areOutputsContin || !allSameOrder) - // break; + // areOutputsContin &= outArrs[i]->strideAt(axis) == 1; + // allSameOrder &= input.ordering() == outArrs[i]->ordering(); + // if(!areOutputsContin || !allSameOrder) + // break; - // strideOfContigStride[i] = shape::strideOverContigAxis(axis, outArrs[i]->shapeInfo()); - // } - // } + // strideOfContigStride[i] = shape::strideOverContigAxis(axis, + // outArrs[i]->shapeInfo()); + // } + // } - // const bool luckCase2 = isXcontin && areOutputsContin && allSameOrder; + // const bool luckCase2 = isXcontin && areOutputsContin && allSameOrder; - // if(luckCase2) { // for example {2,1,3} + {2,5,3} + {2,10,3} = {2,16,3}, here axis 1 shoud have stride = 1 for all inputs arrays and input array + // if(luckCase2) { // for example {2,1,3} + {2,5,3} + {2,10,3} = {2,16,3}, + // here axis 1 shoud have stride = 1 for all inputs arrays and input array - // const auto xStep = shape::strideOverContigAxis(axis, input.shapeInfo()); - // const auto zDim = outArrs[0]->sizeAt(axis); // same for all outArrs + // const auto xStep = shape::strideOverContigAxis(axis, + // input.shapeInfo()); const auto zDim = outArrs[0]->sizeAt(axis); // + // same for all outArrs - // for (uint i = 0; i < input.lengthOf() / input.sizeAt(axis); ++i) { + // for (uint i = 0; i < input.lengthOf() / input.sizeAt(axis); ++i) { - // const auto iShift = i * sizeofT; - // void* x = static_cast(input.specialBuffer()) + xStep * iShift; + // const auto iShift = i * sizeofT; + // void* x = static_cast(input.specialBuffer()) + xStep * + // iShift; - // for (uint j = 0; j < numOfSubArrs; ++j) { - // void* z = static_cast(outArrs[j]->specialBuffer()) + strideOfContigStride[j] * iShift; - // const auto memSizeToCopy = zDim * sizeofT; - // cudaMemcpyAsync(z, x, memSizeToCopy, cudaMemcpyDeviceToDevice, *context->getCudaStream()); - // x = static_cast(x) + memSizeToCopy; - // } - // } + // for (uint j = 0; j < numOfSubArrs; ++j) { + // void* z = static_cast(outArrs[j]->specialBuffer()) + + // strideOfContigStride[j] * iShift; const auto memSizeToCopy = + // zDim * sizeofT; cudaMemcpyAsync(z, x, memSizeToCopy, + // cudaMemcpyDeviceToDevice, *context->getCudaStream()); x = + // static_cast(x) + memSizeToCopy; + // } + // } - // if(cudaStreamSynchronize(*context->getCudaStream()) != 0) - // throw std::runtime_error("split cuda: luckCase2 failed!"); - // } - // else { // general (slower) case + // if(cudaStreamSynchronize(*context->getCudaStream()) != 0) + // throw std::runtime_error("split cuda: luckCase2 failed!"); + // } + // else { // general (slower) case - const int threadsPerBlock = MAX_NUM_THREADS / 2; - const int blocksPerGrid = (input.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + const int threadsPerBlock = MAX_NUM_THREADS / 2; + const int blocksPerGrid = + (input.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - // prepare arrays of pointers on buffers and shapes - std::vector hOutBuffers(numOfSubArrs); + // prepare arrays of pointers on buffers and shapes + std::vector hOutBuffers(numOfSubArrs); - for(int i = 0; i < numOfSubArrs; ++i) - hOutBuffers[i] = outArrs[i]->specialBuffer(); + for (int i = 0; i < numOfSubArrs; ++i) + hOutBuffers[i] = outArrs[i]->specialBuffer(); - PointersManager manager(context, "helpers::split"); + PointersManager manager(context, "helpers::split"); - void* dOutBuffers = manager.replicatePointer(hOutBuffers.data(), hOutBuffers.size() * sizeof(void*)); + void* dOutBuffers = manager.replicatePointer( + hOutBuffers.data(), hOutBuffers.size() * sizeof(void*)); - BUILD_SINGLE_SELECTOR(input.dataType(), splitCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), input.specialBuffer(), input.specialShapeInfo(), dOutBuffers, outArrs[0]->specialShapeInfo(), axis), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR( + input.dataType(), splitCudaLauncher, + (blocksPerGrid, threadsPerBlock, context->getCudaStream(), + input.specialBuffer(), input.specialShapeInfo(), dOutBuffers, + outArrs[0]->specialShapeInfo(), axis), + LIBND4J_TYPES); - manager.synchronize(); - // } + manager.synchronize(); + // } - for(int i = 0; i < numOfSubArrs; ++i) - outArrs[i]->tickWriteDevice(); - input.tickReadDevice(); + for (int i = 0; i < numOfSubArrs; ++i) outArrs[i]->tickWriteDevice(); + input.tickReadDevice(); } -} -} -} \ No newline at end of file +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/sru.cu b/libnd4j/include/ops/declarable/helpers/cuda/sru.cu index b59ac00524b2..27d9beadaa38 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/sru.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/sru.cu @@ -15,521 +15,539 @@ ******************************************************************************/ // -// implementation of operations for Simple Recurrent Unit: arXiv:1709.02755v2 [cs.CL] 12 Sep 2017 +// implementation of operations for Simple Recurrent Unit: arXiv:1709.02755v2 +// [cs.CL] 12 Sep 2017 // // @author Yurii Shyrma, created on 05.12.2017 // -#include #include -#include #include +#include +#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { - ////////////////////////////////////////////////////////////////////////// - static FORCEINLINE NDArray activation(const NDArray& arr) { - // return (const_cast&>(arr)).template transform>(); - auto result = NDArray(&arr, false, arr.getContext()); - (const_cast(arr)).applyTransform(transform::Tanh, result); - return result; - } - - - ////////////////////////////////////////////////////////////////////////// - static FORCEINLINE NDArray sigmoid(const NDArray& arr) { - return (const_cast(arr)).transform(transform::Sigmoid); - } - +////////////////////////////////////////////////////////////////////////// +static FORCEINLINE NDArray activation(const NDArray& arr) { + // return (const_cast&>(arr)).template + // transform>(); + auto result = NDArray(&arr, false, arr.getContext()); + (const_cast(arr)).applyTransform(transform::Tanh, result); + return result; +} ////////////////////////////////////////////////////////////////////////// -void sruCell(sd::LaunchContext * context, const NDArray* x, const NDArray* c0, const NDArray* w, const NDArray* b, NDArray* h, NDArray* c) { +static FORCEINLINE NDArray sigmoid(const NDArray& arr) { + return (const_cast(arr)).transform(transform::Sigmoid); +} - // x input [bS x inSize], bS - batch size, inSize - number of features - // c0 previous cell state c [bS x inSize], that is at previous time step t-1 - // w weights [inSize x 3*inSize] - // b biases [2*inSize] +////////////////////////////////////////////////////////////////////////// +void sruCell(sd::LaunchContext* context, const NDArray* x, const NDArray* c0, + const NDArray* w, const NDArray* b, NDArray* h, NDArray* c) { + // x input [bS x inSize], bS - batch size, inSize - number of features + // c0 previous cell state c [bS x inSize], that is at previous time step t-1 + // w weights [inSize x 3*inSize] + // b biases [2*inSize] - // h current cell output [bS x inSize], that is at current time step t - // c current cell state [bS x inSize], that is at current time step t + // h current cell output [bS x inSize], that is at current time step t + // c current cell state [bS x inSize], that is at current time step t - const int inSize = x->sizeAt(1); // inSize - number of features + const int inSize = x->sizeAt(1); // inSize - number of features - auto z = mmul(*x, *w); // [bS x 3*inSize] + auto z = mmul(*x, *w); // [bS x 3*inSize] - // forget gate = sigmoid(x*Wf + bf) - auto f = sigmoid(z({0,0, inSize, 2*inSize}) + (*b)({0, inSize})); + // forget gate = sigmoid(x*Wf + bf) + auto f = sigmoid(z({0, 0, inSize, 2 * inSize}) + (*b)({0, inSize})); - // reset gate = sigmoid(x*Wr + br) - auto r = sigmoid(z({0,0, 2*inSize, 3*inSize}) + (*b)({inSize, 2*inSize})); + // reset gate = sigmoid(x*Wr + br) + auto r = + sigmoid(z({0, 0, 2 * inSize, 3 * inSize}) + (*b)({inSize, 2 * inSize})); - // ◦ means element-wise product or so called Hadamard product - // current sell state = f◦c0 + (1 - f)◦(x*Wc) - c->assign(f * (*c0) + (1.f - f) * z({0, 0 ,0, inSize}) ); - // *c = f*(*c0 - z({},{0, inSize})) + z({{},{0, inSize}}); + // ◦ means element-wise product or so called Hadamard product + // current sell state = f◦c0 + (1 - f)◦(x*Wc) + c->assign(f * (*c0) + (1.f - f) * z({0, 0, 0, inSize})); + // *c = f*(*c0 - z({},{0, inSize})) + z({{},{0, inSize}}); - // current cell output = r◦activation(c) + (1 - r)◦x - h->assign( r * activation(*c) + (1.f - r) * (*x) ); - // *h = r * (activation(c) - *x) + *x; + // current cell output = r◦activation(c) + (1 - r)◦x + h->assign(r * activation(*c) + (1.f - r) * (*x)); + // *h = r * (activation(c) - *x) + *x; } ////////////////////////////////////////////////////////////////////////// -void sruTimeLoop(sd::LaunchContext * context, const NDArray* x, const NDArray* c0, const NDArray* w, const NDArray* b, NDArray* h, NDArray* c) { - - // x input [bS x inSize x time] - // c0 initial cell state (at time step = 0) [bS x inSize], - // w weights, [3*inSize x inSize] - // b biases, [2*inSize] +void sruTimeLoop(sd::LaunchContext* context, const NDArray* x, + const NDArray* c0, const NDArray* w, const NDArray* b, + NDArray* h, NDArray* c) { + // x input [bS x inSize x time] + // c0 initial cell state (at time step = 0) [bS x inSize], + // w weights, [3*inSize x inSize] + // b biases, [2*inSize] - // h cell outputs [bS x inSize x time] - // c cell states [bS x inSize x time] + // h cell outputs [bS x inSize x time] + // c cell states [bS x inSize x time] - auto wT = w->transpose(); // [3*inSize x inSize] -> [inSize x 3*inSize] + auto wT = w->transpose(); // [3*inSize x inSize] -> [inSize x 3*inSize] - const int time = x->sizeAt(2); + const int time = x->sizeAt(2); - NDArray ct_1(*c0); + NDArray ct_1(*c0); - // loop through time steps - for (int t = 0; t < time; ++t) { + // loop through time steps + for (int t = 0; t < time; ++t) { + auto xt = (*x)({0, 0, 0, 0, t, t + 1}); + auto ht = (*h)({0, 0, 0, 0, t, t + 1}); + auto ct = (*c)({0, 0, 0, 0, t, t + 1}); - auto xt = (*x)({0,0, 0,0, t,t+1}); - auto ht = (*h)({0,0, 0,0, t,t+1}); - auto ct = (*c)({0,0, 0,0, t,t+1}); - - helpers::sruCell(context, &xt, &ct_1, &wT, b, &ht, &ct); - ct_1.assign(ct); - } + helpers::sruCell(context, &xt, &ct_1, &wT, b, &ht, &ct); + ct_1.assign(ct); + } } - ////////////////////////////////////////////////////////////////////////// template -__global__ static void sruBICuda(const void* vx, const Nd4jLong* xShapeInfo, - const void* vwi, const Nd4jLong* wiShapeInfo, - const void* vb, const Nd4jLong* bShapeInfo, - const void* vc0, const Nd4jLong* c0ShapeInfo, - const void* vmask, const Nd4jLong* maskShapeInfo, - void* vht, const Nd4jLong* htShapeInfo, - void* vct, const Nd4jLong* ctShapeInfo) { - // inputs: - // x [time, bS, 2*K] - // wi [time, bS, 6*K], wi = mmul(x, weights); - // b [4*K] - // c0 [bS, 2*K] - // mask [bS, 2*K], optional - - // outputs - // ht [time, bS, 2*K] - // ct [time, bS, 2*K] - - const auto x = reinterpret_cast(vx); - const auto wi = reinterpret_cast(vwi); - const auto b = reinterpret_cast(vb); - const auto c0 = reinterpret_cast(vc0); - const auto mask = reinterpret_cast(vmask); - auto ht = reinterpret_cast(vht); - auto ct = reinterpret_cast(vct); - - const int rank = 3; - - __shared__ int time, K, *sharedMem; - __shared__ Nd4jLong len, totalThreads; - - if (threadIdx.x == 0) { - - extern __shared__ unsigned char shmem[]; - sharedMem = reinterpret_cast(shmem); - - time = xShapeInfo[1]; - K = xShapeInfo[3] / 2; - len = xShapeInfo[2] * xShapeInfo[3]; // 2*K*bS - - totalThreads = gridDim.x * blockDim.x; - } - __syncthreads(); - - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - auto coords = sharedMem + threadIdx.x * rank; - - if(tid >= len) - return; - - shape::index2coords(tid, rank - 1, xShapeInfo + 2, coords + 1); // loop through last two dimensions of x : {bS, 2*K} - - const auto maskOffst = mask ? shape::getOffset(maskShapeInfo, coords + 1) : 0; - const auto c0Offset = shape::getOffset(c0ShapeInfo, coords + 1); - const auto bFOffset = shape::getOffset(bShapeInfo, coords + 2); - const auto bROffset = bFOffset + 2 * K * bShapeInfo[2]; // 2*K*b_stride - - const T maskVal = mask ? mask[maskOffst] : static_cast(1); - const T bF = b[bFOffset]; - const T bR = b[bROffset]; - T c0Val = c0[c0Offset]; - - const bool flip = coords[2] >= K; - - if(flip) - coords[0] = time - 1; - else - coords[0] = 0; - - auto xOffset = shape::getOffset(xShapeInfo, coords); - auto htOffset = shape::getOffset(htShapeInfo, coords); - auto ctOffset = shape::getOffset(ctShapeInfo, coords); - - coords[2] *= 3; - auto wiOffset0 = shape::getOffset(wiShapeInfo, coords); - auto wiOffset1 = wiOffset0 + wiShapeInfo[rank + 3]; // add last stride - auto wiOffset2 = wiOffset1 + wiShapeInfo[rank + 3]; // add last stride - - // time loop - for (uint t = 0; t < time; ++t) { - - // evaluate sigmoids - T ft = (1.f)/(1.f + sd::math::nd4j_exp(-(wi[wiOffset1] + bF))); - T rt = (1.f)/(1.f + sd::math::nd4j_exp(-(wi[wiOffset2] + bR))); - - c0Val = (c0Val - wi[wiOffset0]) * ft + wi[wiOffset0]; - ct[ctOffset] = c0Val; - T val = sd::math::nd4j_tanh(c0Val); - T xVal = x[xOffset]; - ht[htOffset] = (val * maskVal - xVal) * rt + xVal; - - if(flip) { - xOffset -= xShapeInfo[rank + 1]; // first stride, corresponds to time step - htOffset -= htShapeInfo[rank + 1]; - ctOffset -= htShapeInfo[rank + 1]; - wiOffset0 -= wiShapeInfo[rank + 1]; - wiOffset1 -= wiShapeInfo[rank + 1]; - wiOffset2 -= wiShapeInfo[rank + 1]; - } - else { - xOffset += xShapeInfo[rank + 1]; // first stride, corresponds to time step - htOffset += htShapeInfo[rank + 1]; - ctOffset += htShapeInfo[rank + 1]; - wiOffset0 += wiShapeInfo[rank + 1]; - wiOffset1 += wiShapeInfo[rank + 1]; - wiOffset2 += wiShapeInfo[rank + 1]; - } +__global__ static void sruBICuda(const void* vx, const Nd4jLong* xShapeInfo, + const void* vwi, const Nd4jLong* wiShapeInfo, + const void* vb, const Nd4jLong* bShapeInfo, + const void* vc0, const Nd4jLong* c0ShapeInfo, + const void* vmask, + const Nd4jLong* maskShapeInfo, void* vht, + const Nd4jLong* htShapeInfo, void* vct, + const Nd4jLong* ctShapeInfo) { + // inputs: + // x [time, bS, 2*K] + // wi [time, bS, 6*K], wi = mmul(x, weights); + // b [4*K] + // c0 [bS, 2*K] + // mask [bS, 2*K], optional + + // outputs + // ht [time, bS, 2*K] + // ct [time, bS, 2*K] + + const auto x = reinterpret_cast(vx); + const auto wi = reinterpret_cast(vwi); + const auto b = reinterpret_cast(vb); + const auto c0 = reinterpret_cast(vc0); + const auto mask = reinterpret_cast(vmask); + auto ht = reinterpret_cast(vht); + auto ct = reinterpret_cast(vct); + + const int rank = 3; + + __shared__ int time, K, *sharedMem; + __shared__ Nd4jLong len, totalThreads; + + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sharedMem = reinterpret_cast(shmem); + + time = xShapeInfo[1]; + K = xShapeInfo[3] / 2; + len = xShapeInfo[2] * xShapeInfo[3]; // 2*K*bS + + totalThreads = gridDim.x * blockDim.x; + } + __syncthreads(); + + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + auto coords = sharedMem + threadIdx.x * rank; + + if (tid >= len) return; + + shape::index2coords( + tid, rank - 1, xShapeInfo + 2, + coords + 1); // loop through last two dimensions of x : {bS, 2*K} + + const auto maskOffst = mask ? shape::getOffset(maskShapeInfo, coords + 1) : 0; + const auto c0Offset = shape::getOffset(c0ShapeInfo, coords + 1); + const auto bFOffset = shape::getOffset(bShapeInfo, coords + 2); + const auto bROffset = bFOffset + 2 * K * bShapeInfo[2]; // 2*K*b_stride + + const T maskVal = mask ? mask[maskOffst] : static_cast(1); + const T bF = b[bFOffset]; + const T bR = b[bROffset]; + T c0Val = c0[c0Offset]; + + const bool flip = coords[2] >= K; + + if (flip) + coords[0] = time - 1; + else + coords[0] = 0; + + auto xOffset = shape::getOffset(xShapeInfo, coords); + auto htOffset = shape::getOffset(htShapeInfo, coords); + auto ctOffset = shape::getOffset(ctShapeInfo, coords); + + coords[2] *= 3; + auto wiOffset0 = shape::getOffset(wiShapeInfo, coords); + auto wiOffset1 = wiOffset0 + wiShapeInfo[rank + 3]; // add last stride + auto wiOffset2 = wiOffset1 + wiShapeInfo[rank + 3]; // add last stride + + // time loop + for (uint t = 0; t < time; ++t) { + // evaluate sigmoids + T ft = (1.f) / (1.f + sd::math::nd4j_exp(-(wi[wiOffset1] + bF))); + T rt = (1.f) / (1.f + sd::math::nd4j_exp(-(wi[wiOffset2] + bR))); + + c0Val = (c0Val - wi[wiOffset0]) * ft + wi[wiOffset0]; + ct[ctOffset] = c0Val; + T val = sd::math::nd4j_tanh(c0Val); + T xVal = x[xOffset]; + ht[htOffset] = (val * maskVal - xVal) * rt + xVal; + + if (flip) { + xOffset -= + xShapeInfo[rank + 1]; // first stride, corresponds to time step + htOffset -= htShapeInfo[rank + 1]; + ctOffset -= htShapeInfo[rank + 1]; + wiOffset0 -= wiShapeInfo[rank + 1]; + wiOffset1 -= wiShapeInfo[rank + 1]; + wiOffset2 -= wiShapeInfo[rank + 1]; + } else { + xOffset += + xShapeInfo[rank + 1]; // first stride, corresponds to time step + htOffset += htShapeInfo[rank + 1]; + ctOffset += htShapeInfo[rank + 1]; + wiOffset0 += wiShapeInfo[rank + 1]; + wiOffset1 += wiShapeInfo[rank + 1]; + wiOffset2 += wiShapeInfo[rank + 1]; } + } } ////////////////////////////////////////////////////////////////////////// template -static void sruBICudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, - const void* vx, const Nd4jLong* xShapeInfo, - const void* vwi, const Nd4jLong* wiShapeInfo, - const void* vb, const Nd4jLong* bShapeInfo, - const void* vc0, const Nd4jLong* c0ShapeInfo, - const void* vmask, const Nd4jLong* maskShapeInfo, - void* vht, const Nd4jLong* htShapeInfo, - void* vct, const Nd4jLong* ctShapeInfo) { - - sruBICuda<<>>(vx, xShapeInfo, vwi, wiShapeInfo, vb, bShapeInfo, vc0, c0ShapeInfo, vmask, maskShapeInfo, vht, htShapeInfo, vct, ctShapeInfo); +static void sruBICudaLauncher( + const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, + const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo, + const void* vwi, const Nd4jLong* wiShapeInfo, const void* vb, + const Nd4jLong* bShapeInfo, const void* vc0, const Nd4jLong* c0ShapeInfo, + const void* vmask, const Nd4jLong* maskShapeInfo, void* vht, + const Nd4jLong* htShapeInfo, void* vct, const Nd4jLong* ctShapeInfo) { + sruBICuda<<>>( + vx, xShapeInfo, vwi, wiShapeInfo, vb, bShapeInfo, vc0, c0ShapeInfo, vmask, + maskShapeInfo, vht, htShapeInfo, vct, ctShapeInfo); } ////////////////////////////////////////////////////////////////////////// -void sruBI(sd::LaunchContext * context, NDArray* x, const NDArray* w, const NDArray* b, const NDArray* c0, const NDArray* mask, NDArray* ht, NDArray* ct) { - - // x = x * mask - if(mask) - x->applyBroadcast(broadcast::Multiply, {1, 2}, *mask, *x); // apply mask - - // U = x * w - NDArray wi = mmul(*x, *w); // U [time x bS x 6*K] - - PointersManager manager(context, "sru_bi"); - - const int threadsPerBlock = MAX_NUM_THREADS / 4; - const int blocksPerGrid = (x->sizeAt(1) * x->sizeAt(2) + threadsPerBlock - 1) / threadsPerBlock; // loop through last two dimensions of x array -> bS, 2*K - const int sharedMem = threadsPerBlock * sizeof(int) * x->rankOf() + 128; - - NDArray::prepareSpecialUse({ht, ct}, {x, &wi, b, c0, mask}); - BUILD_SINGLE_SELECTOR(x->dataType(), sruBICudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), x->specialBuffer(), x->specialShapeInfo(), wi.specialBuffer(), wi.specialShapeInfo(), b->specialBuffer(), b->specialShapeInfo(), c0->specialBuffer(), c0->specialShapeInfo(), mask ? mask->specialBuffer() : nullptr, mask ? mask->specialShapeInfo() : nullptr, ht->specialBuffer(), ht->specialShapeInfo(), ct->specialBuffer(), ct->specialShapeInfo()), FLOAT_TYPES); - NDArray::registerSpecialUse({ht, ct}, {x, &wi, b, c0, mask}); - - manager.synchronize(); +void sruBI(sd::LaunchContext* context, NDArray* x, const NDArray* w, + const NDArray* b, const NDArray* c0, const NDArray* mask, + NDArray* ht, NDArray* ct) { + // x = x * mask + if (mask) + x->applyBroadcast(broadcast::Multiply, {1, 2}, *mask, *x); // apply mask + + // U = x * w + NDArray wi = mmul(*x, *w); // U [time x bS x 6*K] + + PointersManager manager(context, "sru_bi"); + + const int threadsPerBlock = MAX_NUM_THREADS / 4; + const int blocksPerGrid = + (x->sizeAt(1) * x->sizeAt(2) + threadsPerBlock - 1) / + threadsPerBlock; // loop through last two dimensions of x array -> bS, + // 2*K + const int sharedMem = threadsPerBlock * sizeof(int) * x->rankOf() + 128; + + NDArray::prepareSpecialUse({ht, ct}, {x, &wi, b, c0, mask}); + BUILD_SINGLE_SELECTOR( + x->dataType(), sruBICudaLauncher, + (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), + x->specialBuffer(), x->specialShapeInfo(), wi.specialBuffer(), + wi.specialShapeInfo(), b->specialBuffer(), b->specialShapeInfo(), + c0->specialBuffer(), c0->specialShapeInfo(), + mask ? mask->specialBuffer() : nullptr, + mask ? mask->specialShapeInfo() : nullptr, ht->specialBuffer(), + ht->specialShapeInfo(), ct->specialBuffer(), ct->specialShapeInfo()), + FLOAT_TYPES); + NDArray::registerSpecialUse({ht, ct}, {x, &wi, b, c0, mask}); + + manager.synchronize(); } - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - ////////////////////////////////////////////////////////////////////////// template -__global__ static void sruBIBPCuda(const void* vx, const Nd4jLong* xShapeInfo, - const void* vwi, const Nd4jLong* wiShapeInfo, - const void* vb, const Nd4jLong* bShapeInfo, - const void* vc0, const Nd4jLong* c0ShapeInfo, - const void* vmask, const Nd4jLong* maskShapeInfo, - const void* vct, const Nd4jLong* ctShapeInfo, - const void* vgradHt, const Nd4jLong* gradHtShapeInfo, - const void* vgradCt, const Nd4jLong* gradCtShapeInfo, - void* vgradI, const Nd4jLong* gradIShapeInfo, - void* vgradWi, const Nd4jLong* gradWiShapeInfo, - void* vgradB, const Nd4jLong* gradBShapeInfo, - void* vgradC0, const Nd4jLong* gradC0ShapeInfo) { - // inputs: - // x [time, bS, 2*K] - // wi [time, bS, 6*K], wi = mmul(x, weights); - // b [4*K] - // c0 [bS, 2*K] - // mask [bS, 2*K], optional - // ct [time, bS, 2*K] - // gradHt [time, bS, 2*K] - // gradCt [bS, 2*K] - - // outputs - // gradI [time, bS, 2*K] - // gradWi [time, 2*K, 6*K] - // gradB [bS, 4*K] - // gradC0 [bS, 2*K] - - const auto x = reinterpret_cast(vx); - const auto wi = reinterpret_cast(vwi); - const auto b = reinterpret_cast(vb); - const auto c0 = reinterpret_cast(vc0); - const auto mask = reinterpret_cast(vmask); - const auto ct = reinterpret_cast(vct); - const auto gradHt = reinterpret_cast(vgradHt); - const auto gradCt = reinterpret_cast(vgradCt); - - auto gradI = reinterpret_cast(vgradI); - auto gradWi = reinterpret_cast(vgradWi); - auto gradB = reinterpret_cast(vgradB); - auto gradC0 = reinterpret_cast(vgradC0); - - const int rank = 3; - - __shared__ int time, K, *sharedMem; - __shared__ Nd4jLong len, totalThreads; - - if (threadIdx.x == 0) { - - extern __shared__ unsigned char shmem[]; - sharedMem = reinterpret_cast(shmem); - - time = xShapeInfo[1]; - K = xShapeInfo[3] / 2; - len = xShapeInfo[2] * xShapeInfo[3]; // 2*K*bS - - totalThreads = gridDim.x * blockDim.x; - } - - __syncthreads(); - - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - auto coords = sharedMem + threadIdx.x * rank; - - if(tid >= len) - return; - - shape::index2coords(tid, rank - 1, xShapeInfo + 2, coords + 1); // loop through last two dimensions of x : {bS, 2*K} - - const auto maskOffst = mask ? shape::getOffset(maskShapeInfo, coords + 1) : 0; - const auto c0Offset = shape::getOffset(c0ShapeInfo, coords + 1); - const auto gradCtOffset = shape::getOffset(gradCtShapeInfo, coords + 1); - const auto gradC0Offset = shape::getOffset(gradC0ShapeInfo, coords + 1); - const auto bFOffset = shape::getOffset(bShapeInfo, coords + 2); - const auto bROffset = bFOffset + 2 * K * bShapeInfo[2]; // 2*K*b_stride - // const auto gradBFOffset = shape::getOffset(gradBShapeInfo, coords + 1); - const auto gradBFOffset = coords[1] * gradBShapeInfo[3] / 2 + coords[2] * gradBShapeInfo[4]; - const auto gradBROffset = gradBFOffset + gradBShapeInfo[3]; - - const bool flip = coords[2] >= K; - - if(flip) - coords[0] = 0; +__global__ static void sruBIBPCuda( + const void* vx, const Nd4jLong* xShapeInfo, const void* vwi, + const Nd4jLong* wiShapeInfo, const void* vb, const Nd4jLong* bShapeInfo, + const void* vc0, const Nd4jLong* c0ShapeInfo, const void* vmask, + const Nd4jLong* maskShapeInfo, const void* vct, const Nd4jLong* ctShapeInfo, + const void* vgradHt, const Nd4jLong* gradHtShapeInfo, const void* vgradCt, + const Nd4jLong* gradCtShapeInfo, void* vgradI, + const Nd4jLong* gradIShapeInfo, void* vgradWi, + const Nd4jLong* gradWiShapeInfo, void* vgradB, + const Nd4jLong* gradBShapeInfo, void* vgradC0, + const Nd4jLong* gradC0ShapeInfo) { + // inputs: + // x [time, bS, 2*K] + // wi [time, bS, 6*K], wi = mmul(x, weights); + // b [4*K] + // c0 [bS, 2*K] + // mask [bS, 2*K], optional + // ct [time, bS, 2*K] + // gradHt [time, bS, 2*K] + // gradCt [bS, 2*K] + + // outputs + // gradI [time, bS, 2*K] + // gradWi [time, 2*K, 6*K] + // gradB [bS, 4*K] + // gradC0 [bS, 2*K] + + const auto x = reinterpret_cast(vx); + const auto wi = reinterpret_cast(vwi); + const auto b = reinterpret_cast(vb); + const auto c0 = reinterpret_cast(vc0); + const auto mask = reinterpret_cast(vmask); + const auto ct = reinterpret_cast(vct); + const auto gradHt = reinterpret_cast(vgradHt); + const auto gradCt = reinterpret_cast(vgradCt); + + auto gradI = reinterpret_cast(vgradI); + auto gradWi = reinterpret_cast(vgradWi); + auto gradB = reinterpret_cast(vgradB); + auto gradC0 = reinterpret_cast(vgradC0); + + const int rank = 3; + + __shared__ int time, K, *sharedMem; + __shared__ Nd4jLong len, totalThreads; + + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sharedMem = reinterpret_cast(shmem); + + time = xShapeInfo[1]; + K = xShapeInfo[3] / 2; + len = xShapeInfo[2] * xShapeInfo[3]; // 2*K*bS + + totalThreads = gridDim.x * blockDim.x; + } + + __syncthreads(); + + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + auto coords = sharedMem + threadIdx.x * rank; + + if (tid >= len) return; + + shape::index2coords( + tid, rank - 1, xShapeInfo + 2, + coords + 1); // loop through last two dimensions of x : {bS, 2*K} + + const auto maskOffst = mask ? shape::getOffset(maskShapeInfo, coords + 1) : 0; + const auto c0Offset = shape::getOffset(c0ShapeInfo, coords + 1); + const auto gradCtOffset = shape::getOffset(gradCtShapeInfo, coords + 1); + const auto gradC0Offset = shape::getOffset(gradC0ShapeInfo, coords + 1); + const auto bFOffset = shape::getOffset(bShapeInfo, coords + 2); + const auto bROffset = bFOffset + 2 * K * bShapeInfo[2]; // 2*K*b_stride + // const auto gradBFOffset = shape::getOffset(gradBShapeInfo, coords + 1); + const auto gradBFOffset = + coords[1] * gradBShapeInfo[3] / 2 + coords[2] * gradBShapeInfo[4]; + const auto gradBROffset = gradBFOffset + gradBShapeInfo[3]; + + const bool flip = coords[2] >= K; + + if (flip) + coords[0] = 0; + else + coords[0] = time - 1; + + auto xOffset = shape::getOffset(xShapeInfo, coords); + auto ctOffset = shape::getOffset(ctShapeInfo, coords); + auto gradIOffset = shape::getOffset(gradIShapeInfo, coords); + auto gradHtOffset = shape::getOffset(gradHtShapeInfo, coords); + + coords[2] *= 3; + auto gradWiOffset0 = shape::getOffset(gradWiShapeInfo, coords); + auto gradWiOffset1 = + gradWiOffset0 + gradWiShapeInfo[rank + 3]; // add last stride + auto gradWiOffset2 = + gradWiOffset1 + gradWiShapeInfo[rank + 3]; // add last stride + auto wiOffset0 = shape::getOffset(wiShapeInfo, coords); + auto wiOffset1 = wiOffset0 + wiShapeInfo[rank + 3]; // add last stride + auto wiOffset2 = wiOffset1 + wiShapeInfo[rank + 3]; // add last stride + + const T xVal = x[xOffset]; + const T maskVal = mask ? mask[maskOffst] : static_cast(1); + const T c0Val = c0[c0Offset]; + const T bF = b[bFOffset]; + const T bR = b[bROffset]; + T gradCtVal = gradCt[gradCtOffset]; + T gbF = 0.f; + T gbR = 0.f; + + // time loop + for (uint t = 0; t < time; ++t) { + // evaluate sigmoids + T ft = (1.f) / (1.f + sd::math::nd4j_exp(-(wi[wiOffset1] + bF))); + T rt = (1.f) / (1.f + sd::math::nd4j_exp(-(wi[wiOffset2] + bR))); + + T val = sd::math::nd4j_tanh(ct[ctOffset]); + + T prevVal; + if (t < time - 1) + prevVal = + ct[ctOffset += flip ? ctShapeInfo[rank + 1] : -ctShapeInfo[rank + 1]]; else - coords[0] = time - 1; - - auto xOffset = shape::getOffset(xShapeInfo, coords); - auto ctOffset = shape::getOffset(ctShapeInfo, coords); - auto gradIOffset = shape::getOffset(gradIShapeInfo, coords); - auto gradHtOffset = shape::getOffset(gradHtShapeInfo, coords); - - coords[2] *= 3; - auto gradWiOffset0 = shape::getOffset(gradWiShapeInfo, coords); - auto gradWiOffset1 = gradWiOffset0 + gradWiShapeInfo[rank + 3]; // add last stride - auto gradWiOffset2 = gradWiOffset1 + gradWiShapeInfo[rank + 3]; // add last stride - auto wiOffset0 = shape::getOffset(wiShapeInfo, coords); - auto wiOffset1 = wiOffset0 + wiShapeInfo[rank + 3]; // add last stride - auto wiOffset2 = wiOffset1 + wiShapeInfo[rank + 3]; // add last stride - - const T xVal = x[xOffset]; - const T maskVal = mask ? mask[maskOffst] : static_cast(1); - const T c0Val = c0[c0Offset]; - const T bF = b[bFOffset]; - const T bR = b[bROffset]; - T gradCtVal = gradCt[gradCtOffset]; - T gbF = 0.f; - T gbR = 0.f; - - // time loop - for (uint t = 0; t < time; ++t) { - - // evaluate sigmoids - T ft = (1.f)/(1.f + sd::math::nd4j_exp(-(wi[wiOffset1] + bF))); - T rt = (1.f)/(1.f + sd::math::nd4j_exp(-(wi[wiOffset2] + bR))); - - T val = sd::math::nd4j_tanh(ct[ctOffset]); - - T prevVal; - if(t < time-1) - prevVal = ct[ctOffset += flip ? ctShapeInfo[rank + 1] : -ctShapeInfo[rank + 1]]; - else - prevVal = c0Val; - - // grad wrt input - gradI[gradIOffset] = gradHt[gradHtOffset] - gradHt[gradHtOffset] * rt ; - - // grad wrt rt, wiR and bR - T grt = gradHt[gradHtOffset] * (val * maskVal - x[xOffset]) * (rt - rt * rt); - gradWi[gradWiOffset2] = grt; - gbR += grt; - - // grad wrt state - T gradC0Val = gradHt[gradHtOffset] * maskVal * (rt - rt * val * val) + gradCtVal; - - // grad wrt wi0 - gradWi[gradWiOffset0] = gradC0Val - gradC0Val * ft; - - // grad wrt ft, wi1, and bF - T gft = gradC0Val * (prevVal - wi[wiOffset0]) * (ft - ft * ft); - gradWi[gradWiOffset1] = gft; - gbF += gft; - - // grad wrt c_previous - gradCtVal = gradC0Val * ft; - - if(flip) { - xOffset += xShapeInfo[rank + 1]; // first stride, corresponds to time step - gradHtOffset += gradHtShapeInfo[rank + 1]; - gradIOffset += gradIShapeInfo[rank + 1]; - wiOffset0 += wiShapeInfo[rank + 1]; - wiOffset1 += wiShapeInfo[rank + 1]; - wiOffset2 += wiShapeInfo[rank + 1]; - gradWiOffset0 += gradWiShapeInfo[rank + 1]; - gradWiOffset1 += gradWiShapeInfo[rank + 1]; - gradWiOffset2 += gradWiShapeInfo[rank + 1]; - } - else { - xOffset -= xShapeInfo[rank + 1]; // first stride, corresponds to time step - gradHtOffset -= gradHtShapeInfo[rank + 1]; - gradIOffset -= gradIShapeInfo[rank + 1]; - wiOffset0 -= wiShapeInfo[rank + 1]; - wiOffset1 -= wiShapeInfo[rank + 1]; - wiOffset2 -= wiShapeInfo[rank + 1]; - gradWiOffset0 -= gradWiShapeInfo[rank + 1]; - gradWiOffset1 -= gradWiShapeInfo[rank + 1]; - gradWiOffset2 -= gradWiShapeInfo[rank + 1]; - } + prevVal = c0Val; + + // grad wrt input + gradI[gradIOffset] = gradHt[gradHtOffset] - gradHt[gradHtOffset] * rt; + + // grad wrt rt, wiR and bR + T grt = + gradHt[gradHtOffset] * (val * maskVal - x[xOffset]) * (rt - rt * rt); + gradWi[gradWiOffset2] = grt; + gbR += grt; + + // grad wrt state + T gradC0Val = + gradHt[gradHtOffset] * maskVal * (rt - rt * val * val) + gradCtVal; + + // grad wrt wi0 + gradWi[gradWiOffset0] = gradC0Val - gradC0Val * ft; + + // grad wrt ft, wi1, and bF + T gft = gradC0Val * (prevVal - wi[wiOffset0]) * (ft - ft * ft); + gradWi[gradWiOffset1] = gft; + gbF += gft; + + // grad wrt c_previous + gradCtVal = gradC0Val * ft; + + if (flip) { + xOffset += + xShapeInfo[rank + 1]; // first stride, corresponds to time step + gradHtOffset += gradHtShapeInfo[rank + 1]; + gradIOffset += gradIShapeInfo[rank + 1]; + wiOffset0 += wiShapeInfo[rank + 1]; + wiOffset1 += wiShapeInfo[rank + 1]; + wiOffset2 += wiShapeInfo[rank + 1]; + gradWiOffset0 += gradWiShapeInfo[rank + 1]; + gradWiOffset1 += gradWiShapeInfo[rank + 1]; + gradWiOffset2 += gradWiShapeInfo[rank + 1]; + } else { + xOffset -= + xShapeInfo[rank + 1]; // first stride, corresponds to time step + gradHtOffset -= gradHtShapeInfo[rank + 1]; + gradIOffset -= gradIShapeInfo[rank + 1]; + wiOffset0 -= wiShapeInfo[rank + 1]; + wiOffset1 -= wiShapeInfo[rank + 1]; + wiOffset2 -= wiShapeInfo[rank + 1]; + gradWiOffset0 -= gradWiShapeInfo[rank + 1]; + gradWiOffset1 -= gradWiShapeInfo[rank + 1]; + gradWiOffset2 -= gradWiShapeInfo[rank + 1]; } + } - gradB[gradBFOffset] = gbF; - gradB[gradBROffset] = gbR; - gradC0[gradC0Offset] = gradCtVal; + gradB[gradBFOffset] = gbF; + gradB[gradBROffset] = gbR; + gradC0[gradC0Offset] = gradCtVal; } ////////////////////////////////////////////////////////////////////////// template -static void sruBIBPCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, - const void* vx, const Nd4jLong* xShapeInfo, - const void* vwi, const Nd4jLong* wiShapeInfo, - const void* vb, const Nd4jLong* bShapeInfo, - const void* vc0, const Nd4jLong* c0ShapeInfo, - const void* vmask, const Nd4jLong* maskShapeInfo, - const void* vct, const Nd4jLong* ctShapeInfo, - const void* vgradHt, const Nd4jLong* gradHtShapeInfo, - const void* vgradCt, const Nd4jLong* gradCtShapeInfo, - void* vgradI, const Nd4jLong* gradIShapeInfo, - void* vgradWi, const Nd4jLong* gradWiShapeInfo, - void* vgradB, const Nd4jLong* gradBShapeInfo, - void* vgradC0, const Nd4jLong* gradC0ShapeInfo) { - - sruBIBPCuda<<>>(vx, xShapeInfo, vwi, wiShapeInfo, vb, bShapeInfo, vc0, c0ShapeInfo, vmask, maskShapeInfo, vct, ctShapeInfo, vgradHt, gradHtShapeInfo, vgradCt, gradCtShapeInfo, vgradI, gradIShapeInfo, vgradWi, gradWiShapeInfo, vgradB, gradBShapeInfo, vgradC0, gradC0ShapeInfo); +static void sruBIBPCudaLauncher( + const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, + const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo, + const void* vwi, const Nd4jLong* wiShapeInfo, const void* vb, + const Nd4jLong* bShapeInfo, const void* vc0, const Nd4jLong* c0ShapeInfo, + const void* vmask, const Nd4jLong* maskShapeInfo, const void* vct, + const Nd4jLong* ctShapeInfo, const void* vgradHt, + const Nd4jLong* gradHtShapeInfo, const void* vgradCt, + const Nd4jLong* gradCtShapeInfo, void* vgradI, + const Nd4jLong* gradIShapeInfo, void* vgradWi, + const Nd4jLong* gradWiShapeInfo, void* vgradB, + const Nd4jLong* gradBShapeInfo, void* vgradC0, + const Nd4jLong* gradC0ShapeInfo) { + sruBIBPCuda<<>>( + vx, xShapeInfo, vwi, wiShapeInfo, vb, bShapeInfo, vc0, c0ShapeInfo, vmask, + maskShapeInfo, vct, ctShapeInfo, vgradHt, gradHtShapeInfo, vgradCt, + gradCtShapeInfo, vgradI, gradIShapeInfo, vgradWi, gradWiShapeInfo, vgradB, + gradBShapeInfo, vgradC0, gradC0ShapeInfo); } -BUILD_SINGLE_TEMPLATE(template void sruBIBPCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void* vx, const Nd4jLong* xShapeInfo, const void* vwi, const Nd4jLong* wiShapeInfo, const void* vb, const Nd4jLong* bShapeInfo, const void* vc0, const Nd4jLong* c0ShapeInfo, const void* vmask, const Nd4jLong* maskShapeInfo, const void* vct, const Nd4jLong* ctShapeInfo, const void* vgradHt, const Nd4jLong* gradHtShapeInfo, const void* vgradCt, const Nd4jLong* gradCtShapeInfo, void* vgradI, const Nd4jLong* gradIShapeInfo, void* vgradWi, const Nd4jLong* gradWiShapeInfo, void* vgradB, const Nd4jLong* gradBShapeInfo, void* vgradC0, const Nd4jLong* gradC0ShapeInfo), FLOAT_TYPES); +BUILD_SINGLE_TEMPLATE(template void sruBIBPCudaLauncher, + (const int blocksPerGrid, const int threadsPerBlock, + const int sharedMem, const cudaStream_t* stream, + const void* vx, const Nd4jLong* xShapeInfo, + const void* vwi, const Nd4jLong* wiShapeInfo, + const void* vb, const Nd4jLong* bShapeInfo, + const void* vc0, const Nd4jLong* c0ShapeInfo, + const void* vmask, const Nd4jLong* maskShapeInfo, + const void* vct, const Nd4jLong* ctShapeInfo, + const void* vgradHt, const Nd4jLong* gradHtShapeInfo, + const void* vgradCt, const Nd4jLong* gradCtShapeInfo, + void* vgradI, const Nd4jLong* gradIShapeInfo, + void* vgradWi, const Nd4jLong* gradWiShapeInfo, + void* vgradB, const Nd4jLong* gradBShapeInfo, + void* vgradC0, const Nd4jLong* gradC0ShapeInfo), + FLOAT_TYPES); ////////////////////////////////////////////////////////////////////////// -void sruBIBP(sd::LaunchContext* context, NDArray* x, const NDArray* w, const NDArray* b, const NDArray* c0, const NDArray* ct, - const NDArray* gradCt, const NDArray* gradHt, const NDArray* mask, - NDArray* gradI, NDArray* gradW, NDArray* gradB, NDArray* gradC0) { - - // x = x * mask - if(mask) - x->applyBroadcast(broadcast::Multiply, {1, 2}, *mask, *x); // apply mask - - // U = x * w - NDArray wi = mmul(*x, *w); // U [time x bS x 6*K] - - const int time = x->sizeAt(0); - const int bS = x->sizeAt(1); - const int K = x->sizeAt(2) / 2; - - NDArray gradBias(x->ordering(), {bS, 4*K}, x->dataType(), context); - NDArray gradWi (x->ordering(), {time, bS, 6*K}, x->dataType(), context); - - PointersManager manager(context, "sru_bi_bp"); - - const int threadsPerBlock = MAX_NUM_THREADS / 4; - const int blocksPerGrid = (x->sizeAt(1) * x->sizeAt(2) + threadsPerBlock - 1) / threadsPerBlock; // loop through last two dimensions of x array -> bS, 2*K - const int sharedMem = threadsPerBlock * sizeof(int) * x->rankOf() + 128; - - NDArray::prepareSpecialUse({gradI, &gradWi, &gradBias, gradC0}, {x, &wi, b, c0, ct, gradCt, gradHt, mask}); - BUILD_SINGLE_SELECTOR(x->dataType(), sruBIBPCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), x->specialBuffer(), x->specialShapeInfo(), wi.specialBuffer(), wi.specialShapeInfo(), b->specialBuffer(), b->specialShapeInfo(), c0->specialBuffer(), c0->specialShapeInfo(), mask ? mask->specialBuffer() : nullptr, mask ? mask->specialShapeInfo() : nullptr, ct->specialBuffer(), ct->specialShapeInfo(), gradHt->specialBuffer(), gradHt->specialShapeInfo(), gradCt->specialBuffer(), gradCt->specialShapeInfo(), gradI->specialBuffer(), gradI->specialShapeInfo(), gradWi.specialBuffer(), gradWi.specialShapeInfo(), gradBias.specialBuffer(), gradBias.specialShapeInfo(), gradC0->specialBuffer(), gradC0->specialShapeInfo()), FLOAT_TYPES); - NDArray::registerSpecialUse({gradI, &gradWi, &gradBias, gradC0}, {x, &wi, b, c0, ct, gradCt, gradHt, mask}); - - manager.synchronize(); - - // gradB - gradBias.reduceAlongDimension(reduce::Sum, *gradB, {0}); // [4*K] - - // gradW - x->permutei({0, 2, 1}); // [time, bS, 2*K] -> [time, 2*K, bS] - MmulHelper::mmul(x, &gradWi, gradW, 1., 0.); // [time, 2*K, bS] x [time, bS , 6*K] = [time, 2*K, 6*K] +void sruBIBP(sd::LaunchContext* context, NDArray* x, const NDArray* w, + const NDArray* b, const NDArray* c0, const NDArray* ct, + const NDArray* gradCt, const NDArray* gradHt, const NDArray* mask, + NDArray* gradI, NDArray* gradW, NDArray* gradB, NDArray* gradC0) { + // x = x * mask + if (mask) + x->applyBroadcast(broadcast::Multiply, {1, 2}, *mask, *x); // apply mask + + // U = x * w + NDArray wi = mmul(*x, *w); // U [time x bS x 6*K] + + const int time = x->sizeAt(0); + const int bS = x->sizeAt(1); + const int K = x->sizeAt(2) / 2; + + NDArray gradBias(x->ordering(), {bS, 4 * K}, x->dataType(), context); + NDArray gradWi(x->ordering(), {time, bS, 6 * K}, x->dataType(), context); + + PointersManager manager(context, "sru_bi_bp"); + + const int threadsPerBlock = MAX_NUM_THREADS / 4; + const int blocksPerGrid = + (x->sizeAt(1) * x->sizeAt(2) + threadsPerBlock - 1) / + threadsPerBlock; // loop through last two dimensions of x array -> bS, + // 2*K + const int sharedMem = threadsPerBlock * sizeof(int) * x->rankOf() + 128; + + NDArray::prepareSpecialUse({gradI, &gradWi, &gradBias, gradC0}, + {x, &wi, b, c0, ct, gradCt, gradHt, mask}); + BUILD_SINGLE_SELECTOR( + x->dataType(), sruBIBPCudaLauncher, + (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), + x->specialBuffer(), x->specialShapeInfo(), wi.specialBuffer(), + wi.specialShapeInfo(), b->specialBuffer(), b->specialShapeInfo(), + c0->specialBuffer(), c0->specialShapeInfo(), + mask ? mask->specialBuffer() : nullptr, + mask ? mask->specialShapeInfo() : nullptr, ct->specialBuffer(), + ct->specialShapeInfo(), gradHt->specialBuffer(), + gradHt->specialShapeInfo(), gradCt->specialBuffer(), + gradCt->specialShapeInfo(), gradI->specialBuffer(), + gradI->specialShapeInfo(), gradWi.specialBuffer(), + gradWi.specialShapeInfo(), gradBias.specialBuffer(), + gradBias.specialShapeInfo(), gradC0->specialBuffer(), + gradC0->specialShapeInfo()), + FLOAT_TYPES); + NDArray::registerSpecialUse({gradI, &gradWi, &gradBias, gradC0}, + {x, &wi, b, c0, ct, gradCt, gradHt, mask}); + + manager.synchronize(); + + // gradB + gradBias.reduceAlongDimension(reduce::Sum, *gradB, {0}); // [4*K] + + // gradW + x->permutei({0, 2, 1}); // [time, bS, 2*K] -> [time, 2*K, bS] + MmulHelper::mmul( + x, &gradWi, gradW, 1., + 0.); // [time, 2*K, bS] x [time, bS , 6*K] = [time, 2*K, 6*K] } - -} -} -} \ No newline at end of file +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/stack.cu b/libnd4j/include/ops/declarable/helpers/cuda/stack.cu index 2bb09c3b5274..f057b3725a2b 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/stack.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/stack.cu @@ -18,196 +18,216 @@ // Created by Yurii Shyrma on 02.01.2018 // -#include -#include #include #include -#include -#include #include +#include +#include +#include +#include namespace sd { namespace ops { namespace helpers { - /////////////////////////////////////////////////////////////////// template -static __global__ void stackScalarsCuda(void* pVx, void* vz, const Nd4jLong* zShapeInfo) { +static __global__ void stackScalarsCuda(void* pVx, void* vz, + const Nd4jLong* zShapeInfo) { + T* z = reinterpret_cast(vz); - T* z = reinterpret_cast(vz); + __shared__ Nd4jLong zLen, totalThreads; - __shared__ Nd4jLong zLen, totalThreads; - - if (threadIdx.x == 0) { - zLen = shape::length(zShapeInfo); - totalThreads = gridDim.x * blockDim.x; - } - __syncthreads(); + if (threadIdx.x == 0) { + zLen = shape::length(zShapeInfo); + totalThreads = gridDim.x * blockDim.x; + } + __syncthreads(); - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - for (Nd4jLong i = tid; i < zLen; i += totalThreads) { - - const T *x = reinterpret_cast(reinterpret_cast(pVx)[i]); - z[shape::getIndexOffset(i, zShapeInfo)] = *x; - } + for (Nd4jLong i = tid; i < zLen; i += totalThreads) { + const T* x = reinterpret_cast(reinterpret_cast(pVx)[i]); + z[shape::getIndexOffset(i, zShapeInfo)] = *x; + } } - /////////////////////////////////////////////////////////////////// -template -__host__ static void stackScalarsCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, - void* pVx, void* vz, const Nd4jLong* zShapeInfo) { - - stackScalarsCuda<<>>(pVx, vz, zShapeInfo); +template +__host__ static void stackScalarsCudaLauncher(const int blocksPerGrid, + const int threadsPerBlock, + const cudaStream_t* stream, + void* pVx, void* vz, + const Nd4jLong* zShapeInfo) { + stackScalarsCuda + <<>>(pVx, vz, zShapeInfo); } /////////////////////////////////////////////////////////////////// template -static void stack_(sd::LaunchContext* context, const std::vector& inArrs, NDArray& output, const int dim) { +static void stack_(sd::LaunchContext* context, + const std::vector& inArrs, NDArray& output, + const int dim) { + const int numOfSubArrs = inArrs.size(); - const int numOfSubArrs = inArrs.size(); + NDArray::prepareSpecialUse({&output}, inArrs); - NDArray::prepareSpecialUse({&output}, inArrs); + if (inArrs[0]->rankOf() == 0) { + std::vector hInBuffers(numOfSubArrs); - if(inArrs[0]->rankOf() == 0) { + for (int i = 0; i < numOfSubArrs; ++i) + hInBuffers[i] = inArrs[i]->specialBuffer(); - std::vector hInBuffers(numOfSubArrs); + PointersManager manager(context, "helpers::stack cuda"); - for(int i = 0; i < numOfSubArrs; ++i) - hInBuffers[i] = inArrs[i]->specialBuffer(); + void* dInBuffers = manager.replicatePointer( + hInBuffers.data(), hInBuffers.size() * sizeof(void*)); - PointersManager manager(context, "helpers::stack cuda"); - - void* dInBuffers = manager.replicatePointer(hInBuffers.data(), hInBuffers.size() * sizeof(void*)); - - const int threadsPerBlock = MAX_NUM_THREADS / 2; - const int blocksPerGrid = (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - - stackScalarsCudaLauncher(blocksPerGrid, threadsPerBlock, context->getCudaStream(), dInBuffers, output.specialBuffer(), output.specialShapeInfo()); - - manager.synchronize(); - } - else { + const int threadsPerBlock = MAX_NUM_THREADS / 2; + const int blocksPerGrid = + (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - auto zTadPack = ConstantTadHelper::getInstance().tadForDimensions(output.shapeInfo(), ShapeUtils::evalDimsToExclude(output.rankOf(), {dim})); - auto zTadShapeInfo = zTadPack.primaryShapeInfo(); + stackScalarsCudaLauncher( + blocksPerGrid, threadsPerBlock, context->getCudaStream(), dInBuffers, + output.specialBuffer(), output.specialShapeInfo()); - for (uint i = 0; i < numOfSubArrs; ++i) { + manager.synchronize(); + } else { + auto zTadPack = ConstantTadHelper::getInstance().tadForDimensions( + output.shapeInfo(), + ShapeUtils::evalDimsToExclude(output.rankOf(), {dim})); + auto zTadShapeInfo = zTadPack.primaryShapeInfo(); - void* zBuff = output.specialBufferWithOffset(zTadPack.primaryOffsets()[i]); + for (uint i = 0; i < numOfSubArrs; ++i) { + void* zBuff = + output.specialBufferWithOffset(zTadPack.primaryOffsets()[i]); - NativeOpExecutioner::execTransformAny(context, transform::Assign, - nullptr, inArrs[i]->shapeInfo(), inArrs[i]->specialBuffer(), inArrs[i]->specialShapeInfo(), - nullptr, zTadShapeInfo, zBuff, zTadPack.specialShapeInfo(), - nullptr, nullptr, nullptr, false/*allowParallelism*/); - } + NativeOpExecutioner::execTransformAny( + context, transform::Assign, nullptr, inArrs[i]->shapeInfo(), + inArrs[i]->specialBuffer(), inArrs[i]->specialShapeInfo(), nullptr, + zTadShapeInfo, zBuff, zTadPack.specialShapeInfo(), nullptr, nullptr, + nullptr, false /*allowParallelism*/); } + } - NDArray::registerSpecialUse({&output}, inArrs); + NDArray::registerSpecialUse({&output}, inArrs); } //////////////////////////////////////////////////////////////////////// -void stack(sd::LaunchContext* context, const std::vector& inArrs, NDArray& output, const int dim) { - BUILD_SINGLE_SELECTOR(output.dataType(), stack_, (context, inArrs, output, dim), LIBND4J_TYPES); +void stack(sd::LaunchContext* context, + const std::vector& inArrs, NDArray& output, + const int dim) { + BUILD_SINGLE_SELECTOR(output.dataType(), stack_, + (context, inArrs, output, dim), LIBND4J_TYPES); } -BUILD_SINGLE_TEMPLATE(template void stack_ , (sd::LaunchContext* context, const std::vector& inArrs, NDArray& output, const int dim), LIBND4J_TYPES); - +BUILD_SINGLE_TEMPLATE(template void stack_, + (sd::LaunchContext * context, + const std::vector& inArrs, + NDArray& output, const int dim), + LIBND4J_TYPES); /////////////////////////////////////////////////////////////////// template -static __global__ void unstackScalarsCuda(const void* vx, const Nd4jLong* xShapeInfo, void* pVz) { - - const T* x = reinterpret_cast(vx); - - __shared__ Nd4jLong xLen, totalThreads; +static __global__ void unstackScalarsCuda(const void* vx, + const Nd4jLong* xShapeInfo, + void* pVz) { + const T* x = reinterpret_cast(vx); - if (threadIdx.x == 0) { - xLen = shape::length(xShapeInfo); - totalThreads = gridDim.x * blockDim.x; - } - __syncthreads(); + __shared__ Nd4jLong xLen, totalThreads; - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + if (threadIdx.x == 0) { + xLen = shape::length(xShapeInfo); + totalThreads = gridDim.x * blockDim.x; + } + __syncthreads(); - for (Nd4jLong i = tid; i < xLen; i += totalThreads) { + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - T* z = reinterpret_cast(reinterpret_cast(pVz)[i]); - *z = x[shape::getIndexOffset(i, xShapeInfo)]; - } + for (Nd4jLong i = tid; i < xLen; i += totalThreads) { + T* z = reinterpret_cast(reinterpret_cast(pVz)[i]); + *z = x[shape::getIndexOffset(i, xShapeInfo)]; + } } - /////////////////////////////////////////////////////////////////// -template -__host__ static void unstackScalarsCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, - const void* vx, const Nd4jLong* xShapeInfo, void* pVz) { - - unstackScalarsCuda<<>>(vx, xShapeInfo, pVz); +template +__host__ static void unstackScalarsCudaLauncher(const int blocksPerGrid, + const int threadsPerBlock, + const cudaStream_t* stream, + const void* vx, + const Nd4jLong* xShapeInfo, + void* pVz) { + unstackScalarsCuda + <<>>(vx, xShapeInfo, pVz); } /////////////////////////////////////////////////////////////////// template -static void unstack_(sd::LaunchContext* context, const NDArray& input, const std::vector& outArrs, const int dim) { - - const int numOfSubArrs = outArrs.size(); +static void unstack_(sd::LaunchContext* context, const NDArray& input, + const std::vector& outArrs, const int dim) { + const int numOfSubArrs = outArrs.size(); - // NDArray::prepareSpecialUse(outArrs, {&input}); - input.syncToDevice(); - for (const auto a : outArrs) - a->getDataBuffer()->allocateSpecial(); + // NDArray::prepareSpecialUse(outArrs, {&input}); + input.syncToDevice(); + for (const auto a : outArrs) a->getDataBuffer()->allocateSpecial(); + if (outArrs[0]->rankOf() == 0) { + std::vector hOutBuffers(numOfSubArrs); - if(outArrs[0]->rankOf() == 0) { + for (int i = 0; i < numOfSubArrs; ++i) + hOutBuffers[i] = outArrs[i]->specialBuffer(); - std::vector hOutBuffers(numOfSubArrs); + PointersManager manager(context, "helpers::unstack cuda"); - for(int i = 0; i < numOfSubArrs; ++i) - hOutBuffers[i] = outArrs[i]->specialBuffer(); + void* dOutBuffers = manager.replicatePointer( + hOutBuffers.data(), hOutBuffers.size() * sizeof(void*)); - PointersManager manager(context, "helpers::unstack cuda"); + const int threadsPerBlock = MAX_NUM_THREADS / 2; + const int blocksPerGrid = + (input.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - void* dOutBuffers = manager.replicatePointer(hOutBuffers.data(), hOutBuffers.size() * sizeof(void*)); + unstackScalarsCudaLauncher( + blocksPerGrid, threadsPerBlock, context->getCudaStream(), + input.specialBuffer(), input.specialShapeInfo(), dOutBuffers); - const int threadsPerBlock = MAX_NUM_THREADS / 2; - const int blocksPerGrid = (input.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + manager.synchronize(); + } else { + auto xTadPack = ConstantTadHelper::getInstance().tadForDimensions( + input.shapeInfo(), + ShapeUtils::evalDimsToExclude(input.rankOf(), {dim})); + auto xTadShapeInfo = xTadPack.primaryShapeInfo(); - unstackScalarsCudaLauncher(blocksPerGrid, threadsPerBlock, context->getCudaStream(), input.specialBuffer(), input.specialShapeInfo(), dOutBuffers); + for (uint i = 0; i < numOfSubArrs; ++i) { + auto xBuff = input.specialBufferWithOffset(xTadPack.primaryOffsets()[i]); - manager.synchronize(); + NativeOpExecutioner::execTransformAny( + input.getContext(), transform::Assign, nullptr, xTadShapeInfo, xBuff, + xTadPack.specialShapeInfo(), nullptr, outArrs[i]->shapeInfo(), + outArrs[i]->specialBuffer(), outArrs[i]->specialShapeInfo(), nullptr, + nullptr, nullptr, false /*allowParallelism*/); } - else { - - auto xTadPack = ConstantTadHelper::getInstance().tadForDimensions(input.shapeInfo(), ShapeUtils::evalDimsToExclude(input.rankOf(), {dim})); - auto xTadShapeInfo = xTadPack.primaryShapeInfo(); - - for (uint i = 0; i < numOfSubArrs; ++i) { + } - auto xBuff = input.specialBufferWithOffset(xTadPack.primaryOffsets()[i]); - - NativeOpExecutioner::execTransformAny(input.getContext(), transform::Assign, - nullptr, xTadShapeInfo, xBuff, xTadPack.specialShapeInfo(), - nullptr, outArrs[i]->shapeInfo(), outArrs[i]->specialBuffer(), outArrs[i]->specialShapeInfo(), - nullptr, nullptr, nullptr, false/*allowParallelism*/); - } - } - - // NDArray::registerSpecialUse(outArrs, {&input}); - input.tickReadDevice(); - for (const auto p : outArrs) - p->tickWriteDevice(); + // NDArray::registerSpecialUse(outArrs, {&input}); + input.tickReadDevice(); + for (const auto p : outArrs) p->tickWriteDevice(); } //////////////////////////////////////////////////////////////////////// -void unstack(sd::LaunchContext* context, const NDArray& input, const std::vector& outArrs, const int dim) { - BUILD_SINGLE_SELECTOR(input.dataType(), unstack_, (context, input, outArrs, dim), LIBND4J_TYPES); +void unstack(sd::LaunchContext* context, const NDArray& input, + const std::vector& outArrs, const int dim) { + BUILD_SINGLE_SELECTOR(input.dataType(), unstack_, + (context, input, outArrs, dim), LIBND4J_TYPES); } -BUILD_SINGLE_TEMPLATE(template void unstack_, (sd::LaunchContext* context, const NDArray& input, const std::vector& outArrs, const int dim), LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template void unstack_, + (sd::LaunchContext * context, const NDArray& input, + const std::vector& outArrs, const int dim), + LIBND4J_TYPES); /////////////////////////////////////////////////////////////////// // template -// static __global__ void unstackCuda(const void* vx, const Nd4jLong* xShapeInfo, void* pVz, const Nd4jLong* zTadShapeInfo, const int axis) { +// static __global__ void unstackCuda(const void* vx, const Nd4jLong* +// xShapeInfo, void* pVz, const Nd4jLong* zTadShapeInfo, const int axis) { // const T* x = reinterpret_cast(vx); // __shared__ Nd4jLong xLen, totalThreads; @@ -230,10 +250,11 @@ BUILD_SINGLE_TEMPLATE(template void unstack_, (sd::LaunchContext* context, const // const auto xOffset = shape::getOffset(xShapeInfo, coords); -// T *z = reinterpret_cast(reinterpret_cast(pVz)[coords[axis]]); +// T *z = reinterpret_cast(reinterpret_cast(pVz)[coords[axis]]); -// for (uint j = axis; j < xRank - 1; ++j) // shift coords staring from axis position -// coords[j] = coords[j + 1]; +// for (uint j = axis; j < xRank - 1; ++j) // shift coords staring +// from axis position coords[j] = coords[j + 1]; // const auto zOffset = shape::getOffset(zTadShapeInfo, coords); @@ -243,19 +264,26 @@ BUILD_SINGLE_TEMPLATE(template void unstack_, (sd::LaunchContext* context, const // /////////////////////////////////////////////////////////////////// // template -// __host__ static void unstackCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, -// const void* vx, const Nd4jLong* xShapeInfo, void* pVz, const Nd4jLong* zTadShapeInfo, const int axis) { +// __host__ static void unstackCudaLauncher(const int blocksPerGrid, const int +// threadsPerBlock, const cudaStream_t *stream, +// const void* vx, const Nd4jLong* xShapeInfo, void* pVz, const Nd4jLong* +// zTadShapeInfo, const int axis) { -// unstackCuda<<>>(vx, xShapeInfo, pVz, zTadShapeInfo, axis); +// unstackCuda<<>>(vx, +// xShapeInfo, pVz, zTadShapeInfo, axis); // } -// BUILD_SINGLE_TEMPLATE(template void unstackCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, const void* vx, const Nd4jLong* xShapeInfo, void* pVz, const Nd4jLong* zTadShapeInfo, const int axis), LIBND4J_TYPES); - +// BUILD_SINGLE_TEMPLATE(template void unstackCudaLauncher, (const int +// blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, const +// void* vx, const Nd4jLong* xShapeInfo, void* pVz, const Nd4jLong* +// zTadShapeInfo, const int axis), LIBND4J_TYPES); // /////////////////////////////////////////////////////////////////// -// void unstack(sd::LaunchContext* context, const NDArray& input, const std::vector& outArrs, const int axis) { +// void unstack(sd::LaunchContext* context, const NDArray& input, const +// std::vector& outArrs, const int axis) { // const int threadsPerBlock = MAX_NUM_THREADS / 2; -// const int blocksPerGrid = (input.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; +// const int blocksPerGrid = (input.lengthOf() + threadsPerBlock - 1) / +// threadsPerBlock; // const int numOfSubArrs = outArrs.size(); @@ -266,13 +294,17 @@ BUILD_SINGLE_TEMPLATE(template void unstack_, (sd::LaunchContext* context, const // PointersManager manager(context, "helpers::unstack"); -// void* dOutBuffers = manager.replicatePointer(hOutBuffers.data(), hOutBuffers.size() * sizeof(void*)); +// void* dOutBuffers = manager.replicatePointer(hOutBuffers.data(), +// hOutBuffers.size() * sizeof(void*)); // for(uint i = 0; i < numOfSubArrs; ++i) // outArrs[i]->syncToDevice(); // input.syncToDevice(); -// BUILD_SINGLE_SELECTOR(input.dataType(), unstackCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), input.specialBuffer(), input.specialShapeInfo(), dOutBuffers, outArrs[0]->special(), axis), LIBND4J_TYPES); +// BUILD_SINGLE_SELECTOR(input.dataType(), unstackCudaLauncher, +// (blocksPerGrid, threadsPerBlock, context->getCudaStream(), +// input.specialBuffer(), input.specialShapeInfo(), dOutBuffers, +// outArrs[0]->special(), axis), LIBND4J_TYPES); // manager.synchronize(); @@ -281,10 +313,10 @@ BUILD_SINGLE_TEMPLATE(template void unstack_, (sd::LaunchContext* context, const // input.tickWriteDevice(); // } - // /////////////////////////////////////////////////////////////////// // template -// static __global__ void stackCuda(void* pVx, const Nd4jLong* xTadShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const int axis) { +// static __global__ void stackCuda(void* pVx, const Nd4jLong* xTadShapeInfo, +// void* vz, const Nd4jLong* zShapeInfo, const int axis) { // T* z = reinterpret_cast(vz); @@ -308,10 +340,11 @@ BUILD_SINGLE_TEMPLATE(template void unstack_, (sd::LaunchContext* context, const // const auto zOffset = shape::getOffset(zShapeInfo, coords); -// const T *x = reinterpret_cast(reinterpret_cast(pVx)[coords[axis]]); +// const T *x = reinterpret_cast(reinterpret_cast(pVx)[coords[axis]]); -// for (uint j = axis; j < zRank - 1; ++j) // shift coords staring from axis position -// coords[j] = coords[j + 1]; +// for (uint j = axis; j < zRank - 1; ++j) // shift coords staring +// from axis position coords[j] = coords[j + 1]; // const auto xOffset = shape::getOffset(xTadShapeInfo, coords); @@ -321,19 +354,26 @@ BUILD_SINGLE_TEMPLATE(template void unstack_, (sd::LaunchContext* context, const // /////////////////////////////////////////////////////////////////// // template -// __host__ static void stackCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, -// void* pVx, const Nd4jLong* xTadShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const int axis) { +// __host__ static void stackCudaLauncher(const int blocksPerGrid, const int +// threadsPerBlock, const cudaStream_t *stream, +// void* pVx, const Nd4jLong* xTadShapeInfo, void* vz, const Nd4jLong* +// zShapeInfo, const int axis) { -// stackCuda<<>>(pVx, xTadShapeInfo, vz, zShapeInfo, axis); +// stackCuda<<>>(pVx, +// xTadShapeInfo, vz, zShapeInfo, axis); // } -// BUILD_SINGLE_TEMPLATE(template void stackCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, void* pVx, const Nd4jLong* xTadShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const int axis), LIBND4J_TYPES); - +// BUILD_SINGLE_TEMPLATE(template void stackCudaLauncher, (const int +// blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, void* +// pVx, const Nd4jLong* xTadShapeInfo, void* vz, const Nd4jLong* zShapeInfo, +// const int axis), LIBND4J_TYPES); // /////////////////////////////////////////////////////////////////// -// void stack(sd::LaunchContext* context, const std::vector& inArrs, NDArray& output, const int axis) { +// void stack(sd::LaunchContext* context, const std::vector& +// inArrs, NDArray& output, const int axis) { // const int threadsPerBlock = MAX_NUM_THREADS / 2; -// const int blocksPerGrid = (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; +// const int blocksPerGrid = (output.lengthOf() + threadsPerBlock - 1) / +// threadsPerBlock; // const int numOfSubArrs = inArrs.size(); @@ -344,13 +384,17 @@ BUILD_SINGLE_TEMPLATE(template void unstack_, (sd::LaunchContext* context, const // PointersManager manager(context, "helpers::stack"); -// void* dInBuffers = manager.replicatePointer(hInBuffers.data(), hInBuffers.size() * sizeof(void*)); +// void* dInBuffers = manager.replicatePointer(hInBuffers.data(), +// hInBuffers.size() * sizeof(void*)); // for(uint i = 0; i < numOfSubArrs; ++i) // inArrs[i]->syncToDevice(); // output.syncToDevice(); -// BUILD_SINGLE_SELECTOR(output.dataType(), stackCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), dInBuffers, inArrs[0]->specialShapeInfo(), output.specialBuffer(), output.special(), axis), LIBND4J_TYPES); +// BUILD_SINGLE_SELECTOR(output.dataType(), stackCudaLauncher, +// (blocksPerGrid, threadsPerBlock, context->getCudaStream(), dInBuffers, +// inArrs[0]->specialShapeInfo(), output.specialBuffer(), +// output.special(), axis), LIBND4J_TYPES); // manager.synchronize(); @@ -359,7 +403,6 @@ BUILD_SINGLE_TEMPLATE(template void unstack_, (sd::LaunchContext* context, const // output.tickWriteDevice(); // } -} -} -} - +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/svd.cu b/libnd4j/include/ops/declarable/helpers/cuda/svd.cu index 33dd0251a5fd..bab93a08685e 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/svd.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/svd.cu @@ -18,660 +18,727 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include -#include #include +#include #include #include #include #include +#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { - -// FIXME -> we should optimize these helpers for the case when input matrices have c order (perform transpositions appropriately) +// FIXME -> we should optimize these helpers for the case when input matrices +// have c order (perform transpositions appropriately) template -__global__ static void inverseColumnSignCuda(void* vu, const Nd4jLong* uShapeInfo, void* vv, const Nd4jLong* vShapeInfo) { - - T* u = reinterpret_cast(vu); - T* v = reinterpret_cast(vv); - - __shared__ int rank, uLastButOneColumn, vLastButOneColumn; // uRank = vRank - __shared__ Nd4jLong uLen, vLen; - __shared__ Nd4jLong *sharedMem; - - if (threadIdx.x == 0) { +__global__ static void inverseColumnSignCuda(void* vu, + const Nd4jLong* uShapeInfo, + void* vv, + const Nd4jLong* vShapeInfo) { + T* u = reinterpret_cast(vu); + T* v = reinterpret_cast(vv); - extern __shared__ unsigned char shmem[]; - sharedMem = reinterpret_cast(shmem); + __shared__ int rank, uLastButOneColumn, vLastButOneColumn; // uRank = vRank + __shared__ Nd4jLong uLen, vLen; + __shared__ Nd4jLong* sharedMem; - rank = shape::rank(uShapeInfo); - uLen = shape::length(uShapeInfo); - vLen = shape::length(vShapeInfo); + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sharedMem = reinterpret_cast(shmem); - uLastButOneColumn = uShapeInfo[rank] - 2; - vLastButOneColumn = vShapeInfo[rank - 1] - 2; - } - - __syncthreads(); + rank = shape::rank(uShapeInfo); + uLen = shape::length(uShapeInfo); + vLen = shape::length(vShapeInfo); - const auto ind = threadIdx.x + blockIdx.x * blockDim.x; + uLastButOneColumn = uShapeInfo[rank] - 2; + vLastButOneColumn = vShapeInfo[rank - 1] - 2; + } - auto coords = sharedMem + threadIdx.x * rank; + __syncthreads(); - // u - for (Nd4jLong i = ind; i < uLen; i += gridDim.x * blockDim.x) { + const auto ind = threadIdx.x + blockIdx.x * blockDim.x; - shape::index2coords(i, uShapeInfo, coords); + auto coords = sharedMem + threadIdx.x * rank; - if(coords[rank - 1] == 0 || coords[rank - 1] == uLastButOneColumn) // do not change sign in first and last but one columns - continue; + // u + for (Nd4jLong i = ind; i < uLen; i += gridDim.x * blockDim.x) { + shape::index2coords(i, uShapeInfo, coords); - const auto uOffset = shape::getOffset(uShapeInfo, coords); + if (coords[rank - 1] == 0 || + coords[rank - 1] == uLastButOneColumn) // do not change sign in first + // and last but one columns + continue; - u[uOffset] = -u[uOffset]; - } + const auto uOffset = shape::getOffset(uShapeInfo, coords); - // v - for (Nd4jLong i = ind; i < vLen; i += gridDim.x * blockDim.x) { + u[uOffset] = -u[uOffset]; + } - shape::index2coords(i, vShapeInfo, coords); + // v + for (Nd4jLong i = ind; i < vLen; i += gridDim.x * blockDim.x) { + shape::index2coords(i, vShapeInfo, coords); - if(coords[rank - 2] == 0 || coords[rank - 2] == vLastButOneColumn) // do not change sign in first and last but one columns - continue; + if (coords[rank - 2] == 0 || + coords[rank - 2] == vLastButOneColumn) // do not change sign in first + // and last but one columns + continue; - const auto vOffset = shape::getOffset(vShapeInfo, coords); + const auto vOffset = shape::getOffset(vShapeInfo, coords); - v[vOffset] = -v[vOffset]; - } + v[vOffset] = -v[vOffset]; + } } ////////////////////////////////////////////////////////////////////////// template -static void inverseColumnSignCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, - void* vu, const Nd4jLong* uShapeInfo, - void* vv, const Nd4jLong* vShapeInfo) { - - inverseColumnSignCuda<<>>(vu, uShapeInfo, vv, vShapeInfo); +static void inverseColumnSignCudaLauncher(const int blocksPerGrid, + const int threadsPerBlock, + const int sharedMem, + const cudaStream_t* stream, void* vu, + const Nd4jLong* uShapeInfo, void* vv, + const Nd4jLong* vShapeInfo) { + inverseColumnSignCuda + <<>>(vu, uShapeInfo, + vv, vShapeInfo); } -BUILD_SINGLE_TEMPLATE(template void inverseColumnSignCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t* stream, void* vu, const Nd4jLong* uShapeInfo, void* vv, const Nd4jLong* vShapeInfo), FLOAT_TYPES); +BUILD_SINGLE_TEMPLATE(template void inverseColumnSignCudaLauncher, + (const int blocksPerGrid, const int threadsPerBlock, + const int sharedMem, const cudaStream_t* stream, + void* vu, const Nd4jLong* uShapeInfo, void* vv, + const Nd4jLong* vShapeInfo), + FLOAT_TYPES); ////////////////////////////////////////////////////////////////////////// -static void svdQR(sd::LaunchContext* context, const NDArray* A, NDArray* S, NDArray* U, NDArray* VT, const bool fullUV, const bool calcUV) { - - // since cusa api cusolverDnDgesvd/cusolverDnSgesvd have following constrain on input matrix A: A_rows >= A_columns && A_order = 'f' - // we make this function to have deal with 2 valid cases only: - // 1) A_rows >= A_columns and A_corder = 'f' - // 2) A_rows <= A_columns and A_corder = 'c' - int this case perform transposition to get f order - // if 1) or 2) are not met then throw exception - - // A [m, n] - // S [n] - // U [m, m] or [m, n] if fullUV = false and m > n - // VT [n, n] or [m, n] if fullUV = false and m < n - - if(A->rankOf() != 2) - throw std::runtime_error("svdQR: rank of A array is not equal 2 !"); - - auto m = A->sizeAt(0); - auto n = A->sizeAt(1); - const int minDim = m < n ? m : n; - const char orderA = A->ordering(); - - if(m < n) - throw std::runtime_error("svdQR: due to cuda api input constrains given shape of A array are not valid !"); - - if(std::vector({minDim}) != S->getShapeAsVector()) - throw std::runtime_error("svdQR: wrong shape of S array !"); - - if(calcUV) { - - if(fullUV && std::vector({m,m}) != U->getShapeAsVector()) - throw std::runtime_error("svdQR: wrong shape of U array !"); - else if(!fullUV && std::vector({m,minDim}) != U->getShapeAsVector()) - throw std::runtime_error("svdQR: wrong shape of U array !"); - - if(fullUV && std::vector({n,n}) != VT->getShapeAsVector()) - throw std::runtime_error("svdQR: wrong shape of VT array !"); - else if(!fullUV && std::vector({minDim,n}) != VT->getShapeAsVector()) - throw std::runtime_error("svdQR: wrong shape of VT array !"); - } - - NDArray* pA = const_cast(A); - NDArray* pS = S; - NDArray* pU = U; - NDArray* pVT = VT; - - std::vector toDelete; - - if(pA->ews() != 1 || pA->ordering() == 'c') { - pA = new NDArray(A->dup('f')); - toDelete.push_back(pA); - } - - if(S->ews() != 1) { - pS = new NDArray(S->dup('f')); - toDelete.push_back(pS); +static void svdQR(sd::LaunchContext* context, const NDArray* A, NDArray* S, + NDArray* U, NDArray* VT, const bool fullUV, + const bool calcUV) { + // since cusa api cusolverDnDgesvd/cusolverDnSgesvd have following constrain + // on input matrix A: A_rows >= A_columns && A_order = 'f' we make this + // function to have deal with 2 valid cases only: 1) A_rows >= A_columns and + // A_corder = 'f' 2) A_rows <= A_columns and A_corder = 'c' - int this case + // perform transposition to get f order if 1) or 2) are not met then throw + // exception + + // A [m, n] + // S [n] + // U [m, m] or [m, n] if fullUV = false and m > n + // VT [n, n] or [m, n] if fullUV = false and m < n + + if (A->rankOf() != 2) + throw std::runtime_error("svdQR: rank of A array is not equal 2 !"); + + auto m = A->sizeAt(0); + auto n = A->sizeAt(1); + const int minDim = m < n ? m : n; + const char orderA = A->ordering(); + + if (m < n) + throw std::runtime_error( + "svdQR: due to cuda api input constrains given shape of A array are " + "not valid !"); + + if (std::vector({minDim}) != S->getShapeAsVector()) + throw std::runtime_error("svdQR: wrong shape of S array !"); + + if (calcUV) { + if (fullUV && std::vector({m, m}) != U->getShapeAsVector()) + throw std::runtime_error("svdQR: wrong shape of U array !"); + else if (!fullUV && + std::vector({m, minDim}) != U->getShapeAsVector()) + throw std::runtime_error("svdQR: wrong shape of U array !"); + + if (fullUV && std::vector({n, n}) != VT->getShapeAsVector()) + throw std::runtime_error("svdQR: wrong shape of VT array !"); + else if (!fullUV && + std::vector({minDim, n}) != VT->getShapeAsVector()) + throw std::runtime_error("svdQR: wrong shape of VT array !"); + } + + NDArray* pA = const_cast(A); + NDArray* pS = S; + NDArray* pU = U; + NDArray* pVT = VT; + + std::vector toDelete; + + if (pA->ews() != 1 || pA->ordering() == 'c') { + pA = new NDArray(A->dup('f')); + toDelete.push_back(pA); + } + + if (S->ews() != 1) { + pS = new NDArray(S->dup('f')); + toDelete.push_back(pS); + } + + if (calcUV) { + if (pU->ews() != 1 || pU->ordering() == 'c') { + pU = new NDArray(U->dup('f')); + toDelete.push_back(pU); } - if(calcUV) { - - if(pU->ews() != 1 || pU->ordering() == 'c') { - pU = new NDArray(U->dup('f')); - toDelete.push_back(pU); - } - - if(pVT->ews() != 1 || pVT->ordering() == 'c') { - pVT = new NDArray(VT->dup('f')); - toDelete.push_back(pVT); - } + if (pVT->ews() != 1 || pVT->ordering() == 'c') { + pVT = new NDArray(VT->dup('f')); + toDelete.push_back(pVT); } - - std::lock_guard lock(*LaunchContext::deviceMutex()); - - // create cusolverDn handle - cusolverDnHandle_t* handle = (cusolverDnHandle_t*)context->getCusolverHandle(); //nullptr; - //cusolverStatus_t status = cusolverDnCreate(&handle); - if(handle == nullptr) - throw cuda_exception::build("svdQR: cuda failed !", -1); - - // stream - auto status = cusolverDnSetStream(*handle, *context->getCudaStream()); - if(status != CUSOLVER_STATUS_SUCCESS) - throw cuda_exception::build("svdQR: cuda failed !", status); - - // query working space of SVD - int lwork = 0; - if(A->dataType() == DataType::DOUBLE) - status = cusolverDnDgesvd_bufferSize(*handle, m, n, &lwork); - else if(A->dataType() == DataType::FLOAT32) - status = cusolverDnSgesvd_bufferSize(*handle, m, n, &lwork); + } + + std::lock_guard lock(*LaunchContext::deviceMutex()); + + // create cusolverDn handle + cusolverDnHandle_t* handle = + (cusolverDnHandle_t*)context->getCusolverHandle(); // nullptr; + // cusolverStatus_t status = cusolverDnCreate(&handle); + if (handle == nullptr) + throw cuda_exception::build("svdQR: cuda failed !", -1); + + // stream + auto status = cusolverDnSetStream(*handle, *context->getCudaStream()); + if (status != CUSOLVER_STATUS_SUCCESS) + throw cuda_exception::build("svdQR: cuda failed !", status); + + // query working space of SVD + int lwork = 0; + if (A->dataType() == DataType::DOUBLE) + status = cusolverDnDgesvd_bufferSize(*handle, m, n, &lwork); + else if (A->dataType() == DataType::FLOAT32) + status = cusolverDnSgesvd_bufferSize(*handle, m, n, &lwork); + else + throw std::invalid_argument("svdQR: given data type is unsupported !"); + + if (status != CUSOLVER_STATUS_SUCCESS) + throw cuda_exception::build("svdQR: cuda failed !", status); + + // allocate memory for dWork + void* dWork = nullptr; + cudaError_t status2 = cudaMalloc((void**)&dWork, A->sizeOfT() * lwork); + if (status2 != cudaSuccess) + throw cuda_exception::build("svdQR: cuda failed !", status2); + + signed char jobu, jobvt; + + if (calcUV) { + if (fullUV) + jobu = jobvt = 'A'; else - throw std::invalid_argument("svdQR: given data type is unsupported !"); - - if(status != CUSOLVER_STATUS_SUCCESS) - throw cuda_exception::build("svdQR: cuda failed !", status); - - // allocate memory for dWork - void* dWork = nullptr; - cudaError_t status2 = cudaMalloc((void**)&dWork , A->sizeOfT() * lwork); - if(status2 != cudaSuccess) - throw cuda_exception::build("svdQR: cuda failed !", status2); - - signed char jobu, jobvt; - - if(calcUV) { - if(fullUV) - jobu = jobvt = 'A'; - else - jobu = jobvt = 'S'; - } - else { - jobu = jobvt = 'N'; - } - - int *devInfo = nullptr; - void* rWork = nullptr; - - int lda(m), ldu, ldvt; - - if(calcUV) { - ldu = pU->sizeAt(0); - ldvt = pVT->sizeAt(0); - } - - PointersManager manager(context, "svdQR"); - - NDArray::prepareSpecialUse({pS, pU, pVT}, {pA}); - - // choose appropriate cuda gemm api depending on data types - if(A->dataType() == DataType::DOUBLE) { - status = cusolverDnDgesvd(*handle, jobu, jobvt, m, n, reinterpret_cast(pA->specialBuffer()), lda, reinterpret_cast(pS->specialBuffer()), calcUV ? reinterpret_cast(pU->specialBuffer()) : nullptr, ldu, calcUV ? reinterpret_cast(pVT->specialBuffer()) : nullptr, ldvt, reinterpret_cast(dWork), lwork, reinterpret_cast(rWork), devInfo); - } - else if(A->dataType() == DataType::FLOAT32) { - status = cusolverDnSgesvd(*handle, jobu, jobvt, m, n, reinterpret_cast(pA->specialBuffer()), lda, reinterpret_cast(pS->specialBuffer()), calcUV ? reinterpret_cast(pU->specialBuffer()) : nullptr, ldu, calcUV ? reinterpret_cast(pVT->specialBuffer()) : nullptr, ldvt, reinterpret_cast(dWork), lwork, reinterpret_cast(rWork), devInfo); - } - else - throw std::invalid_argument("svdQR: given data type is unsupported !"); - - if(status != CUSOLVER_STATUS_SUCCESS) - throw cuda_exception::build("svdQR: cuda failed !", status); - - manager.synchronize(); - - NDArray::registerSpecialUse({pS, pU, pVT}, {pA}); - - S->assign(pS); - - if(calcUV) { - U->assign(pU); - VT->assign(pVT); - } - - for (int i = toDelete.size() - 1; i >= 0; --i) - delete toDelete[i]; - - if (devInfo) - cudaFree(devInfo); - if (dWork ) - cudaFree(dWork); - if (rWork) - cudaFree(rWork); - -// if(handle) -// cusolverDnDestroy(handle); - - // cudaDeviceReset(); + jobu = jobvt = 'S'; + } else { + jobu = jobvt = 'N'; + } + + int* devInfo = nullptr; + void* rWork = nullptr; + + int lda(m), ldu, ldvt; + + if (calcUV) { + ldu = pU->sizeAt(0); + ldvt = pVT->sizeAt(0); + } + + PointersManager manager(context, "svdQR"); + + NDArray::prepareSpecialUse({pS, pU, pVT}, {pA}); + + // choose appropriate cuda gemm api depending on data types + if (A->dataType() == DataType::DOUBLE) { + status = cusolverDnDgesvd( + *handle, jobu, jobvt, m, n, + reinterpret_cast(pA->specialBuffer()), lda, + reinterpret_cast(pS->specialBuffer()), + calcUV ? reinterpret_cast(pU->specialBuffer()) : nullptr, ldu, + calcUV ? reinterpret_cast(pVT->specialBuffer()) : nullptr, + ldvt, reinterpret_cast(dWork), lwork, + reinterpret_cast(rWork), devInfo); + } else if (A->dataType() == DataType::FLOAT32) { + status = cusolverDnSgesvd( + *handle, jobu, jobvt, m, n, + reinterpret_cast(pA->specialBuffer()), lda, + reinterpret_cast(pS->specialBuffer()), + calcUV ? reinterpret_cast(pU->specialBuffer()) : nullptr, ldu, + calcUV ? reinterpret_cast(pVT->specialBuffer()) : nullptr, ldvt, + reinterpret_cast(dWork), lwork, reinterpret_cast(rWork), + devInfo); + } else + throw std::invalid_argument("svdQR: given data type is unsupported !"); + + if (status != CUSOLVER_STATUS_SUCCESS) + throw cuda_exception::build("svdQR: cuda failed !", status); + + manager.synchronize(); + + NDArray::registerSpecialUse({pS, pU, pVT}, {pA}); + + S->assign(pS); + + if (calcUV) { + U->assign(pU); + VT->assign(pVT); + } + + for (int i = toDelete.size() - 1; i >= 0; --i) delete toDelete[i]; + + if (devInfo) cudaFree(devInfo); + if (dWork) cudaFree(dWork); + if (rWork) cudaFree(rWork); + + // if(handle) + // cusolverDnDestroy(handle); + + // cudaDeviceReset(); } ////////////////////////////////////////////////////////////////////////// -static void svdJcb(sd::LaunchContext* context, const NDArray* A, NDArray* S, NDArray* U, NDArray* V, const bool fullUV, const bool calcUV) { - - // A [m, n] - // S [n] - // U [m, m] or [m, n] if fullUV = false and m > n - // V [n, n] or [n, m] if fullUV = false and m < n +static void svdJcb(sd::LaunchContext* context, const NDArray* A, NDArray* S, + NDArray* U, NDArray* V, const bool fullUV, + const bool calcUV) { + // A [m, n] + // S [n] + // U [m, m] or [m, n] if fullUV = false and m > n + // V [n, n] or [n, m] if fullUV = false and m < n - if(A->rankOf() != 2) - throw std::runtime_error("svdJcb: rank of A array is not equal 2 !"); + if (A->rankOf() != 2) + throw std::runtime_error("svdJcb: rank of A array is not equal 2 !"); - int m = A->sizeAt(0); - int n = A->sizeAt(1); - const int minDim = m < n ? m : n; + int m = A->sizeAt(0); + int n = A->sizeAt(1); + const int minDim = m < n ? m : n; - if(std::vector({minDim}) != S->getShapeAsVector()) - throw std::runtime_error("svdJcb: wrong shape of S array !"); + if (std::vector({minDim}) != S->getShapeAsVector()) + throw std::runtime_error("svdJcb: wrong shape of S array !"); - if(calcUV) { + if (calcUV) { + if (fullUV && std::vector({m, m}) != U->getShapeAsVector()) + throw std::runtime_error("svdJcb: wrong shape of U array !"); + else if (!fullUV && + std::vector({m, minDim}) != U->getShapeAsVector()) + throw std::runtime_error("svdJcb: wrong shape of U array !"); - if(fullUV && std::vector({m,m}) != U->getShapeAsVector()) - throw std::runtime_error("svdJcb: wrong shape of U array !"); - else if(!fullUV && std::vector({m,minDim}) != U->getShapeAsVector()) - throw std::runtime_error("svdJcb: wrong shape of U array !"); + if (fullUV && std::vector({n, n}) != V->getShapeAsVector()) + throw std::runtime_error("svdJcb: wrong shape of V array !"); + else if (!fullUV && + std::vector({n, minDim}) != V->getShapeAsVector()) + throw std::runtime_error("svdJcb: wrong shape of V array !"); + } - if(fullUV && std::vector({n,n}) != V->getShapeAsVector()) - throw std::runtime_error("svdJcb: wrong shape of V array !"); - else if(!fullUV && std::vector({n,minDim}) != V->getShapeAsVector()) - throw std::runtime_error("svdJcb: wrong shape of V array !"); - } + NDArray* pA = const_cast(A); - NDArray* pA = const_cast(A); + const bool aForder = m == 1 || A->strideAt(0) == 1; + const bool aCorder = n == 1 || A->strideAt(1) == 1; - const bool aForder = m == 1 || A->strideAt(0) == 1; - const bool aCorder = n == 1 || A->strideAt(1) == 1; + const bool transA = !aForder && aCorder; + const bool dupA = !aForder && !aCorder; - const bool transA = !aForder && aCorder; - const bool dupA = !aForder && !aCorder; + std::vector toDelete; - std::vector toDelete; + if (dupA) { + pA = new NDArray(A->dup('f')); + toDelete.push_back(pA); + } - if(dupA) { - pA = new NDArray(A->dup('f')); - toDelete.push_back(pA); - } - - NDArray* pS = S; - - if(S->ews() != 1) { - pS = new NDArray(S->dup('f')); - toDelete.push_back(pS); - } + NDArray* pS = S; - NDArray *pU(nullptr), *pV(nullptr); + if (S->ews() != 1) { + pS = new NDArray(S->dup('f')); + toDelete.push_back(pS); + } - int lda = transA ? pA->strideAt(0) : pA->strideAt(1); - int ldu(transA ? n : m), ldv(transA ? m : n); - bool uForder(true), vForder(true); + NDArray *pU(nullptr), *pV(nullptr); - if(calcUV) { + int lda = transA ? pA->strideAt(0) : pA->strideAt(1); + int ldu(transA ? n : m), ldv(transA ? m : n); + bool uForder(true), vForder(true); - pU = transA ? V : U; - pV = transA ? U : V; + if (calcUV) { + pU = transA ? V : U; + pV = transA ? U : V; - uForder = pU->sizeAt(0) == 1 || pU->strideAt(0) == 1; - vForder = pV->sizeAt(0) == 1 || pV->strideAt(0) == 1; + uForder = pU->sizeAt(0) == 1 || pU->strideAt(0) == 1; + vForder = pV->sizeAt(0) == 1 || pV->strideAt(0) == 1; - if(!uForder) { - pU = new NDArray(pU->dup('f')); - toDelete.push_back(pU); - } - - if(!vForder) { - pV = new NDArray(pV->dup('f')); - toDelete.push_back(pV); - } - - ldu = pU->strideAt(1); - ldv = pV->strideAt(1); - } - - std::lock_guard lock(*LaunchContext::deviceMutex()); - - // create cusolverDn handle - cusolverDnHandle_t* handle = (cusolverDnHandle_t*)context->getCusolverHandle(); - //cusolverStatus_t status = cusolverDnCreate(&handle); - if(handle == nullptr) - throw cuda_exception::build("svdJcb: cuda failed !", -1); - - // stream - auto status = cusolverDnSetStream(*handle, *context->getCudaStream()); - if(status != CUSOLVER_STATUS_SUCCESS) - throw cuda_exception::build("svdJcb: cuda failed !", status); - - // set parameters - gesvdjInfo_t gesvdjParams = nullptr; - status = cusolverDnCreateGesvdjInfo(&gesvdjParams); - if(status != CUSOLVER_STATUS_SUCCESS) - throw cuda_exception::build("svdJcb: cuda failed !", status); - status = cusolverDnXgesvdjSetTolerance(gesvdjParams, 1.e-7); // tolerance - if(status != CUSOLVER_STATUS_SUCCESS) - throw cuda_exception::build("svdJcb: cuda failed !", status); - status = cusolverDnXgesvdjSetMaxSweeps(gesvdjParams, 15); // max_sweeps - if(status != CUSOLVER_STATUS_SUCCESS) - throw cuda_exception::build("svdJcb: cuda failed !", status); - - int *devInfo = nullptr; - const cusolverEigMode_t jobz = calcUV ? CUSOLVER_EIG_MODE_VECTOR : CUSOLVER_EIG_MODE_NOVECTOR; - const int econ = !fullUV; - - if(transA) - math::nd4j_swap(m, n); - - // *** avoid bug in cuda API *** - void* nullPtr = nullptr; - NDArray* arrToAvoidBugInAPI = nullptr; - if(!calcUV && m != n) { - int maxDim = m > n ? m : n; - arrToAvoidBugInAPI = new NDArray('c', {maxDim, maxDim}, pA->dataType(), context); - nullPtr = arrToAvoidBugInAPI->specialBuffer(); + if (!uForder) { + pU = new NDArray(pU->dup('f')); + toDelete.push_back(pU); } - // ****************** - - NDArray::prepareSpecialUse({pS, pU, pV}, {pA}); - - // query working space of SVD - int lwork = 0; - if(A->dataType() == DataType::DOUBLE) - status = cusolverDnDgesvdj_bufferSize(*handle, jobz, econ, m, n, reinterpret_cast(pA->specialBuffer()), lda, reinterpret_cast(pS->specialBuffer()), calcUV ? reinterpret_cast(pU->specialBuffer()) : reinterpret_cast(nullPtr), ldu, calcUV ? reinterpret_cast(pV->specialBuffer()) : reinterpret_cast(nullPtr), ldv, &lwork, gesvdjParams); - else if(A->dataType() == DataType::FLOAT32) - status = cusolverDnSgesvdj_bufferSize(*handle, jobz, econ, m, n, reinterpret_cast(pA->specialBuffer()), lda, reinterpret_cast(pS->specialBuffer()), calcUV ? reinterpret_cast(pU->specialBuffer()) : reinterpret_cast(nullPtr), ldu, calcUV ? reinterpret_cast(pV->specialBuffer()) : reinterpret_cast(nullPtr), ldv, &lwork, gesvdjParams); - else - throw std::invalid_argument("svdJcb: given data type is unsupported !"); - - if(status != CUSOLVER_STATUS_SUCCESS) - throw cuda_exception::build("svdJcb: cuda failed !", status); - // allocate memory dWork - void* dWork = nullptr; - auto status2 = cudaMalloc((void**)&dWork , A->sizeOfT() * lwork); - if(status2 != cudaSuccess) - throw cuda_exception::build("svdJcb: cuda failed !", status2); - - PointersManager manager(context, "svdJcb"); - - // choose appropriate cuda gemm api depending on data types - if(A->dataType() == DataType::DOUBLE) { - status = cusolverDnDgesvdj(*handle, jobz, econ, m, n, reinterpret_cast(pA->specialBuffer()), lda, reinterpret_cast(pS->specialBuffer()), calcUV ? reinterpret_cast(pU->specialBuffer()) : reinterpret_cast(nullPtr), ldu, calcUV ? reinterpret_cast(pV->specialBuffer()) : reinterpret_cast(nullPtr), ldv, reinterpret_cast(dWork), lwork, devInfo, gesvdjParams); + if (!vForder) { + pV = new NDArray(pV->dup('f')); + toDelete.push_back(pV); } - else if(A->dataType() == DataType::FLOAT32) { - status = cusolverDnSgesvdj(*handle, jobz, econ, m, n, reinterpret_cast(pA->specialBuffer()), lda, reinterpret_cast(pS->specialBuffer()), calcUV ? reinterpret_cast(pU->specialBuffer()) : reinterpret_cast(nullPtr), ldu, calcUV ? reinterpret_cast(pV->specialBuffer()) : reinterpret_cast(nullPtr), ldv, reinterpret_cast(dWork), lwork, devInfo, gesvdjParams); - } - else - throw std::invalid_argument("svdJcb: given data type is unsupported !"); - - if(status != CUSOLVER_STATUS_SUCCESS) - throw cuda_exception::build("svdJcb: cuda failed !", status); - - manager.synchronize(); - - NDArray::registerSpecialUse({pS, pU, pV}, {pA}); - - if(S->ews() != 1) - S->assign(pS); - if(calcUV) { - - if(!uForder) - U->assign(transA ? pV : pU); - if(!vForder) - V->assign(transA ? pU : pV); - } - - if(!calcUV && m != n) - delete arrToAvoidBugInAPI; - - for (int i = toDelete.size() - 1; i >= 0; --i) - delete toDelete[i]; - - if (devInfo) - cudaFree(devInfo); - if (dWork ) - cudaFree(dWork); -// if(handle) -// cusolverDnDestroy(handle); - if(gesvdjParams) - cusolverDnDestroyGesvdjInfo(gesvdjParams); - - // cudaDeviceReset(); + ldu = pU->strideAt(1); + ldv = pV->strideAt(1); + } + + std::lock_guard lock(*LaunchContext::deviceMutex()); + + // create cusolverDn handle + cusolverDnHandle_t* handle = + (cusolverDnHandle_t*)context->getCusolverHandle(); + // cusolverStatus_t status = cusolverDnCreate(&handle); + if (handle == nullptr) + throw cuda_exception::build("svdJcb: cuda failed !", -1); + + // stream + auto status = cusolverDnSetStream(*handle, *context->getCudaStream()); + if (status != CUSOLVER_STATUS_SUCCESS) + throw cuda_exception::build("svdJcb: cuda failed !", status); + + // set parameters + gesvdjInfo_t gesvdjParams = nullptr; + status = cusolverDnCreateGesvdjInfo(&gesvdjParams); + if (status != CUSOLVER_STATUS_SUCCESS) + throw cuda_exception::build("svdJcb: cuda failed !", status); + status = cusolverDnXgesvdjSetTolerance(gesvdjParams, 1.e-7); // tolerance + if (status != CUSOLVER_STATUS_SUCCESS) + throw cuda_exception::build("svdJcb: cuda failed !", status); + status = cusolverDnXgesvdjSetMaxSweeps(gesvdjParams, 15); // max_sweeps + if (status != CUSOLVER_STATUS_SUCCESS) + throw cuda_exception::build("svdJcb: cuda failed !", status); + + int* devInfo = nullptr; + const cusolverEigMode_t jobz = + calcUV ? CUSOLVER_EIG_MODE_VECTOR : CUSOLVER_EIG_MODE_NOVECTOR; + const int econ = !fullUV; + + if (transA) math::nd4j_swap(m, n); + + // *** avoid bug in cuda API *** + void* nullPtr = nullptr; + NDArray* arrToAvoidBugInAPI = nullptr; + if (!calcUV && m != n) { + int maxDim = m > n ? m : n; + arrToAvoidBugInAPI = + new NDArray('c', {maxDim, maxDim}, pA->dataType(), context); + nullPtr = arrToAvoidBugInAPI->specialBuffer(); + } + // ****************** + + NDArray::prepareSpecialUse({pS, pU, pV}, {pA}); + + // query working space of SVD + int lwork = 0; + if (A->dataType() == DataType::DOUBLE) + status = cusolverDnDgesvdj_bufferSize( + *handle, jobz, econ, m, n, + reinterpret_cast(pA->specialBuffer()), lda, + reinterpret_cast(pS->specialBuffer()), + calcUV ? reinterpret_cast(pU->specialBuffer()) + : reinterpret_cast(nullPtr), + ldu, + calcUV ? reinterpret_cast(pV->specialBuffer()) + : reinterpret_cast(nullPtr), + ldv, &lwork, gesvdjParams); + else if (A->dataType() == DataType::FLOAT32) + status = cusolverDnSgesvdj_bufferSize( + *handle, jobz, econ, m, n, + reinterpret_cast(pA->specialBuffer()), lda, + reinterpret_cast(pS->specialBuffer()), + calcUV ? reinterpret_cast(pU->specialBuffer()) + : reinterpret_cast(nullPtr), + ldu, + calcUV ? reinterpret_cast(pV->specialBuffer()) + : reinterpret_cast(nullPtr), + ldv, &lwork, gesvdjParams); + else + throw std::invalid_argument("svdJcb: given data type is unsupported !"); + + if (status != CUSOLVER_STATUS_SUCCESS) + throw cuda_exception::build("svdJcb: cuda failed !", status); + + // allocate memory dWork + void* dWork = nullptr; + auto status2 = cudaMalloc((void**)&dWork, A->sizeOfT() * lwork); + if (status2 != cudaSuccess) + throw cuda_exception::build("svdJcb: cuda failed !", status2); + + PointersManager manager(context, "svdJcb"); + + // choose appropriate cuda gemm api depending on data types + if (A->dataType() == DataType::DOUBLE) { + status = cusolverDnDgesvdj( + *handle, jobz, econ, m, n, + reinterpret_cast(pA->specialBuffer()), lda, + reinterpret_cast(pS->specialBuffer()), + calcUV ? reinterpret_cast(pU->specialBuffer()) + : reinterpret_cast(nullPtr), + ldu, + calcUV ? reinterpret_cast(pV->specialBuffer()) + : reinterpret_cast(nullPtr), + ldv, reinterpret_cast(dWork), lwork, devInfo, gesvdjParams); + } else if (A->dataType() == DataType::FLOAT32) { + status = cusolverDnSgesvdj( + *handle, jobz, econ, m, n, + reinterpret_cast(pA->specialBuffer()), lda, + reinterpret_cast(pS->specialBuffer()), + calcUV ? reinterpret_cast(pU->specialBuffer()) + : reinterpret_cast(nullPtr), + ldu, + calcUV ? reinterpret_cast(pV->specialBuffer()) + : reinterpret_cast(nullPtr), + ldv, reinterpret_cast(dWork), lwork, devInfo, gesvdjParams); + } else + throw std::invalid_argument("svdJcb: given data type is unsupported !"); + + if (status != CUSOLVER_STATUS_SUCCESS) + throw cuda_exception::build("svdJcb: cuda failed !", status); + + manager.synchronize(); + + NDArray::registerSpecialUse({pS, pU, pV}, {pA}); + + if (S->ews() != 1) S->assign(pS); + + if (calcUV) { + if (!uForder) U->assign(transA ? pV : pU); + if (!vForder) V->assign(transA ? pU : pV); + } + + if (!calcUV && m != n) delete arrToAvoidBugInAPI; + + for (int i = toDelete.size() - 1; i >= 0; --i) delete toDelete[i]; + + if (devInfo) cudaFree(devInfo); + if (dWork) cudaFree(dWork); + // if(handle) + // cusolverDnDestroy(handle); + if (gesvdjParams) cusolverDnDestroyGesvdjInfo(gesvdjParams); + + // cudaDeviceReset(); } ////////////////////////////////////////////////////////////////////////// -static void svdBatched(sd::LaunchContext* context, const NDArray* A, NDArray* S, NDArray* U, NDArray* V, const bool fullUV, const bool calcUV) { - - // A [..., m, n] - // S [..., n] - // U [..., m, m] or [..., m, n] if fullUV = false and m > n - // V [..., n, n] or [..., n, m] if fullUV = false and m < n - - auto m = A->sizeAt(-2); - auto n = A->sizeAt(-1); - const int minDim = m < n ? m : n; - const Nd4jLong bS = A->lengthOf() / (m * n); - - if(m > 32 || n > 32) - throw std::runtime_error("svdBatched: numbers of rows and columns should be <= 32 !"); - - if(minDim != S->sizeAt(-1)) - throw std::runtime_error("svdBatched: wrong shape of S array !"); - - if(calcUV) { - - if(U->sizeAt(-2) != m) - throw std::runtime_error("svdBatched: wrong shape of U array !"); - if(U->sizeAt(-1) != (fullUV ? m : minDim)) - throw std::runtime_error("svdBatched: wrong shape of U array !"); - if(U->lengthOf() / (U->sizeAt(-2) * U->sizeAt(-1)) != bS) - throw std::runtime_error("svdBatched: wrong shape of U array !"); - - if(V->sizeAt(-2) != n) - throw std::runtime_error("svdBatched: wrong shape of V array !"); - if(V->sizeAt(-1) != (fullUV ? n : minDim)) - throw std::runtime_error("svdBatched: wrong shape of V array !"); - if(V->lengthOf() / (V->sizeAt(-2) * V->sizeAt(-1)) != bS) - throw std::runtime_error("svdBatched: wrong shape of V array !"); - } - - NDArray* pA = const_cast(A); - NDArray* pS = S; - NDArray* pU = U; - NDArray* pV = V; - - std::vector toDelete; - - if(pA->ews() != 1 || pA->ordering() == 'c') { - pA = new NDArray(A->dup('f')); - toDelete.push_back(pA); - } - - if(S->ews() != 1) { - pS = new NDArray(S->dup('f')); - toDelete.push_back(pS); - } - - if(calcUV) { - - if(pU->ews() != 1 || pU->ordering() == 'c') { - pU = new NDArray(U->dup('f')); - toDelete.push_back(pU); - } - - if(pV->ews() != 1 || pV->ordering() == 'c') { - pV = new NDArray(V->dup('f')); - toDelete.push_back(pV); - } - } - - // create cusolverDn handle - cusolverDnHandle_t handle = nullptr; - cusolverStatus_t status = cusolverDnCreate(&handle); - if(status != CUSOLVER_STATUS_SUCCESS) - throw cuda_exception::build("svdBatched: cuda failed !", status); - - // stream - status = cusolverDnSetStream(handle, *context->getCudaStream()); - if(status != CUSOLVER_STATUS_SUCCESS) - throw cuda_exception::build("svdBatched: cuda failed !", status); - - // set parameters - gesvdjInfo_t gesvdjParams = nullptr; - status = cusolverDnCreateGesvdjInfo(&gesvdjParams); - if(status != CUSOLVER_STATUS_SUCCESS) - throw cuda_exception::build("svdBatched: cuda failed !", status); - status = cusolverDnXgesvdjSetTolerance(gesvdjParams, 1.e-7); // tolerance - if(status != CUSOLVER_STATUS_SUCCESS) - throw cuda_exception::build("svdBatched: cuda failed !", status); - status = cusolverDnXgesvdjSetMaxSweeps(gesvdjParams, 15); // max_sweeps - if(status != CUSOLVER_STATUS_SUCCESS) - throw cuda_exception::build("svdBatched: cuda failed !", status); - - // devInfo - int *devInfo = nullptr; - auto status2 = cudaMalloc((void**)&devInfo, sizeof(int) * bS); - if(status2 != cudaSuccess) - throw cuda_exception::build("svdBatched: cuda failed !", status2); - status2 = cudaDeviceSynchronize(); - if(status2 != cudaSuccess) - throw cuda_exception::build("svdJcb: cuda failed !", status2); - - const cusolverEigMode_t jobz = calcUV ? CUSOLVER_EIG_MODE_VECTOR : CUSOLVER_EIG_MODE_NOVECTOR; - - int lda(m), ldu, ldv; - - if(calcUV) { - ldu = pU->sizeAt(-2); - ldv = pV->sizeAt(-2); +static void svdBatched(sd::LaunchContext* context, const NDArray* A, NDArray* S, + NDArray* U, NDArray* V, const bool fullUV, + const bool calcUV) { + // A [..., m, n] + // S [..., n] + // U [..., m, m] or [..., m, n] if fullUV = false and m > n + // V [..., n, n] or [..., n, m] if fullUV = false and m < n + + auto m = A->sizeAt(-2); + auto n = A->sizeAt(-1); + const int minDim = m < n ? m : n; + const Nd4jLong bS = A->lengthOf() / (m * n); + + if (m > 32 || n > 32) + throw std::runtime_error( + "svdBatched: numbers of rows and columns should be <= 32 !"); + + if (minDim != S->sizeAt(-1)) + throw std::runtime_error("svdBatched: wrong shape of S array !"); + + if (calcUV) { + if (U->sizeAt(-2) != m) + throw std::runtime_error("svdBatched: wrong shape of U array !"); + if (U->sizeAt(-1) != (fullUV ? m : minDim)) + throw std::runtime_error("svdBatched: wrong shape of U array !"); + if (U->lengthOf() / (U->sizeAt(-2) * U->sizeAt(-1)) != bS) + throw std::runtime_error("svdBatched: wrong shape of U array !"); + + if (V->sizeAt(-2) != n) + throw std::runtime_error("svdBatched: wrong shape of V array !"); + if (V->sizeAt(-1) != (fullUV ? n : minDim)) + throw std::runtime_error("svdBatched: wrong shape of V array !"); + if (V->lengthOf() / (V->sizeAt(-2) * V->sizeAt(-1)) != bS) + throw std::runtime_error("svdBatched: wrong shape of V array !"); + } + + NDArray* pA = const_cast(A); + NDArray* pS = S; + NDArray* pU = U; + NDArray* pV = V; + + std::vector toDelete; + + if (pA->ews() != 1 || pA->ordering() == 'c') { + pA = new NDArray(A->dup('f')); + toDelete.push_back(pA); + } + + if (S->ews() != 1) { + pS = new NDArray(S->dup('f')); + toDelete.push_back(pS); + } + + if (calcUV) { + if (pU->ews() != 1 || pU->ordering() == 'c') { + pU = new NDArray(U->dup('f')); + toDelete.push_back(pU); } - // Ak (i,j) = A[i + 5*j + 25*k] - - // query working space of SVD - int lwork = 0; - if(A->dataType() == DataType::DOUBLE) - status = cusolverDnDgesvdjBatched_bufferSize(handle, jobz, m, n, reinterpret_cast(pA->specialBuffer()), lda, reinterpret_cast(pS->specialBuffer()), calcUV ? reinterpret_cast(pU->specialBuffer()) : nullptr, ldu, calcUV ? reinterpret_cast(pV->specialBuffer()) : nullptr, ldv, &lwork, gesvdjParams, bS); - else if(A->dataType() == DataType::FLOAT32) - status = cusolverDnSgesvdjBatched_bufferSize(handle, jobz, m, n, reinterpret_cast(pA->specialBuffer()), lda, reinterpret_cast(pS->specialBuffer()), calcUV ? reinterpret_cast(pU->specialBuffer()) : nullptr, ldu, calcUV ? reinterpret_cast(pV->specialBuffer()) : nullptr, ldv, &lwork, gesvdjParams, bS); - else - throw std::invalid_argument("svdBatched: given data type is unsupported !"); - - if(status != CUSOLVER_STATUS_SUCCESS) - throw cuda_exception::build("svdBatched: cuda failed !", status); - - // allocate memory dWork - void* dWork = nullptr; - status2 = cudaMalloc((void**)&dWork , A->sizeOfT() * lwork); - if(status2 != cudaSuccess) - throw cuda_exception::build("svdBatched: cuda failed !", status2); - status2 = cudaDeviceSynchronize(); - if(status2 != cudaSuccess) - throw cuda_exception::build("svdBatched: cuda failed !", status2); - - PointersManager manager(context, "svdBatched"); - - NDArray::prepareSpecialUse({pS, pU, pV}, {pA}); - - // choose appropriate cuda gemm api depending on data types - if(A->dataType() == DataType::DOUBLE) { - status = cusolverDnDgesvdjBatched(handle, jobz, m, n, reinterpret_cast(pA->specialBuffer()), lda, reinterpret_cast(pS->specialBuffer()), calcUV ? reinterpret_cast(pU->specialBuffer()) : nullptr, ldu, calcUV ? reinterpret_cast(pV->specialBuffer()) : nullptr, ldv, reinterpret_cast(dWork), lwork, devInfo, gesvdjParams, bS); + if (pV->ews() != 1 || pV->ordering() == 'c') { + pV = new NDArray(V->dup('f')); + toDelete.push_back(pV); } - else if(A->dataType() == DataType::FLOAT32) { - status = cusolverDnSgesvdjBatched(handle, jobz, m, n, reinterpret_cast(pA->specialBuffer()), lda, reinterpret_cast(pS->specialBuffer()), calcUV ? reinterpret_cast(pU->specialBuffer()) : nullptr, ldu, calcUV ? reinterpret_cast(pV->specialBuffer()) : nullptr, ldv, reinterpret_cast(dWork), lwork, devInfo, gesvdjParams, bS); - } - else - throw std::invalid_argument("svdBatched: given data type is unsupported !"); - - if(status != CUSOLVER_STATUS_SUCCESS) - throw cuda_exception::build("svdBatched: cuda failed !", status); - - manager.synchronize(); - - NDArray::registerSpecialUse({pS, pU, pV}, {pA}); - - S->assign(pS); - - if(calcUV) { - U->assign(pU); - V->assign(pV); - } - - for (int i = toDelete.size() - 1; i >= 0; --i) - delete toDelete[i]; - - if (devInfo) - cudaFree(devInfo); - if (dWork ) - cudaFree(dWork); - if(handle) - cusolverDnDestroy(handle); - if(gesvdjParams) - cusolverDnDestroyGesvdjInfo(gesvdjParams); - - // cudaDeviceReset(); + } + + // create cusolverDn handle + cusolverDnHandle_t handle = nullptr; + cusolverStatus_t status = cusolverDnCreate(&handle); + if (status != CUSOLVER_STATUS_SUCCESS) + throw cuda_exception::build("svdBatched: cuda failed !", status); + + // stream + status = cusolverDnSetStream(handle, *context->getCudaStream()); + if (status != CUSOLVER_STATUS_SUCCESS) + throw cuda_exception::build("svdBatched: cuda failed !", status); + + // set parameters + gesvdjInfo_t gesvdjParams = nullptr; + status = cusolverDnCreateGesvdjInfo(&gesvdjParams); + if (status != CUSOLVER_STATUS_SUCCESS) + throw cuda_exception::build("svdBatched: cuda failed !", status); + status = cusolverDnXgesvdjSetTolerance(gesvdjParams, 1.e-7); // tolerance + if (status != CUSOLVER_STATUS_SUCCESS) + throw cuda_exception::build("svdBatched: cuda failed !", status); + status = cusolverDnXgesvdjSetMaxSweeps(gesvdjParams, 15); // max_sweeps + if (status != CUSOLVER_STATUS_SUCCESS) + throw cuda_exception::build("svdBatched: cuda failed !", status); + + // devInfo + int* devInfo = nullptr; + auto status2 = cudaMalloc((void**)&devInfo, sizeof(int) * bS); + if (status2 != cudaSuccess) + throw cuda_exception::build("svdBatched: cuda failed !", status2); + status2 = cudaDeviceSynchronize(); + if (status2 != cudaSuccess) + throw cuda_exception::build("svdJcb: cuda failed !", status2); + + const cusolverEigMode_t jobz = + calcUV ? CUSOLVER_EIG_MODE_VECTOR : CUSOLVER_EIG_MODE_NOVECTOR; + + int lda(m), ldu, ldv; + + if (calcUV) { + ldu = pU->sizeAt(-2); + ldv = pV->sizeAt(-2); + } + + // Ak (i,j) = A[i + 5*j + 25*k] + + // query working space of SVD + int lwork = 0; + if (A->dataType() == DataType::DOUBLE) + status = cusolverDnDgesvdjBatched_bufferSize( + handle, jobz, m, n, reinterpret_cast(pA->specialBuffer()), lda, + reinterpret_cast(pS->specialBuffer()), + calcUV ? reinterpret_cast(pU->specialBuffer()) : nullptr, ldu, + calcUV ? reinterpret_cast(pV->specialBuffer()) : nullptr, ldv, + &lwork, gesvdjParams, bS); + else if (A->dataType() == DataType::FLOAT32) + status = cusolverDnSgesvdjBatched_bufferSize( + handle, jobz, m, n, reinterpret_cast(pA->specialBuffer()), lda, + reinterpret_cast(pS->specialBuffer()), + calcUV ? reinterpret_cast(pU->specialBuffer()) : nullptr, ldu, + calcUV ? reinterpret_cast(pV->specialBuffer()) : nullptr, ldv, + &lwork, gesvdjParams, bS); + else + throw std::invalid_argument("svdBatched: given data type is unsupported !"); + + if (status != CUSOLVER_STATUS_SUCCESS) + throw cuda_exception::build("svdBatched: cuda failed !", status); + + // allocate memory dWork + void* dWork = nullptr; + status2 = cudaMalloc((void**)&dWork, A->sizeOfT() * lwork); + if (status2 != cudaSuccess) + throw cuda_exception::build("svdBatched: cuda failed !", status2); + status2 = cudaDeviceSynchronize(); + if (status2 != cudaSuccess) + throw cuda_exception::build("svdBatched: cuda failed !", status2); + + PointersManager manager(context, "svdBatched"); + + NDArray::prepareSpecialUse({pS, pU, pV}, {pA}); + + // choose appropriate cuda gemm api depending on data types + if (A->dataType() == DataType::DOUBLE) { + status = cusolverDnDgesvdjBatched( + handle, jobz, m, n, reinterpret_cast(pA->specialBuffer()), lda, + reinterpret_cast(pS->specialBuffer()), + calcUV ? reinterpret_cast(pU->specialBuffer()) : nullptr, ldu, + calcUV ? reinterpret_cast(pV->specialBuffer()) : nullptr, ldv, + reinterpret_cast(dWork), lwork, devInfo, gesvdjParams, bS); + } else if (A->dataType() == DataType::FLOAT32) { + status = cusolverDnSgesvdjBatched( + handle, jobz, m, n, reinterpret_cast(pA->specialBuffer()), lda, + reinterpret_cast(pS->specialBuffer()), + calcUV ? reinterpret_cast(pU->specialBuffer()) : nullptr, ldu, + calcUV ? reinterpret_cast(pV->specialBuffer()) : nullptr, ldv, + reinterpret_cast(dWork), lwork, devInfo, gesvdjParams, bS); + } else + throw std::invalid_argument("svdBatched: given data type is unsupported !"); + + if (status != CUSOLVER_STATUS_SUCCESS) + throw cuda_exception::build("svdBatched: cuda failed !", status); + + manager.synchronize(); + + NDArray::registerSpecialUse({pS, pU, pV}, {pA}); + + S->assign(pS); + + if (calcUV) { + U->assign(pU); + V->assign(pV); + } + + for (int i = toDelete.size() - 1; i >= 0; --i) delete toDelete[i]; + + if (devInfo) cudaFree(devInfo); + if (dWork) cudaFree(dWork); + if (handle) cusolverDnDestroy(handle); + if (gesvdjParams) cusolverDnDestroyGesvdjInfo(gesvdjParams); + + // cudaDeviceReset(); } //////////////////////////////////////////////////////////////////// -void svd(sd::LaunchContext* context, const NDArray* x, const std::vector& outArrs, const bool fullUV, const bool calcUV, const int switchNum) { - - NDArray* S = outArrs[0]; - NDArray* U = outArrs[1]; - // NDArray VT = outArrs[2]->transpose(); - NDArray* V = outArrs[2]; - - NDArray::prepareSpecialUse({S, U, V}, {x}); - - if(x->rankOf() == 2) { - // svdQR(context, x, S, U, VT, fullUV, calcUV); - svdJcb(context, x, S, U, V, fullUV, calcUV); +void svd(sd::LaunchContext* context, const NDArray* x, + const std::vector& outArrs, const bool fullUV, + const bool calcUV, const int switchNum) { + NDArray* S = outArrs[0]; + NDArray* U = outArrs[1]; + // NDArray VT = outArrs[2]->transpose(); + NDArray* V = outArrs[2]; + + NDArray::prepareSpecialUse({S, U, V}, {x}); + + if (x->rankOf() == 2) { + // svdQR(context, x, S, U, VT, fullUV, calcUV); + svdJcb(context, x, S, U, V, fullUV, calcUV); + } else { + // svdBatched(context, *x, *S, *U, *V, fullUV, calcUV); + + ResultSet *tadsU(nullptr), *tadsV(nullptr); + + auto tadsX = + x->allTensorsAlongDimension({x->rankOf() - 2, x->rankOf() - 1}); + auto tadsS = S->allTensorsAlongDimension({S->rankOf() - 1}); + + if (calcUV) { + tadsU = new ResultSet( + U->allTensorsAlongDimension({U->rankOf() - 2, U->rankOf() - 1})); + tadsV = new ResultSet( + V->allTensorsAlongDimension({V->rankOf() - 2, V->rankOf() - 1})); } - else { - // svdBatched(context, *x, *S, *U, *V, fullUV, calcUV); + for (int i = 0; i < tadsX.size(); ++i) + svdJcb(context, &tadsX.at(i), &tadsS.at(i), calcUV ? &tadsU->at(i) : nullptr, calcUV ? &tadsV->at(i) : nullptr, fullUV, calcUV); - ResultSet *tadsU(nullptr), *tadsV(nullptr); - - auto tadsX = x->allTensorsAlongDimension({x->rankOf() - 2, x->rankOf() - 1}); - auto tadsS = S->allTensorsAlongDimension({S->rankOf() - 1}); - - if(calcUV) { - tadsU = new ResultSet(U->allTensorsAlongDimension({U->rankOf() - 2, U->rankOf() - 1})); - tadsV = new ResultSet(V->allTensorsAlongDimension({V->rankOf() - 2, V->rankOf() - 1})); - } - - for (int i = 0; i < tadsX.size(); ++i) - svdJcb(context, tadsX.at(i), tadsS.at(i), calcUV ? tadsU->at(i) : nullptr, calcUV ? tadsV->at(i) : nullptr, fullUV, calcUV); - - if(calcUV) { - delete tadsU; - delete tadsV; - } + if (calcUV) { + delete tadsU; + delete tadsV; } + } - NDArray::registerSpecialUse({S, U, V}, {x}); + NDArray::registerSpecialUse({S, U, V}, {x}); } - -} -} -} \ No newline at end of file +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/toggle_bits.cu b/libnd4j/include/ops/declarable/helpers/cuda/toggle_bits.cu index 2138e1188d8b..43501948e282 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/toggle_bits.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/toggle_bits.cu @@ -18,25 +18,26 @@ // @author raver119@gmail.com // -#include #include +#include namespace sd { - namespace ops { - namespace helpers { - template - void toggle_bits__(NDArray &in, NDArray &out) { - auto lambda = LAMBDA_T(_x) { - return ~_x;//eUtils::flip_bits(_x); - }; +namespace ops { +namespace helpers { +template +void toggle_bits__(NDArray &in, NDArray &out) { + auto lambda = LAMBDA_T(_x) { + return ~_x; // eUtils::flip_bits(_x); + }; - in.applyLambda(lambda, out); - } - BUILD_SINGLE_TEMPLATE(template void toggle_bits__, (NDArray &in, NDArray &out), INTEGER_TYPES); + in.applyLambda(lambda, out); +} +BUILD_SINGLE_TEMPLATE(template void toggle_bits__, (NDArray & in, NDArray &out), + INTEGER_TYPES); - void __toggle_bits(sd::LaunchContext * context, NDArray& in, NDArray& out) { - BUILD_SINGLE_SELECTOR(in.dataType(), toggle_bits__, (in, out), INTEGER_TYPES); - } - } - } -} \ No newline at end of file +void __toggle_bits(sd::LaunchContext *context, NDArray &in, NDArray &out) { + BUILD_SINGLE_SELECTOR(in.dataType(), toggle_bits__, (in, out), INTEGER_TYPES); +} +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/top_k.cu b/libnd4j/include/ops/declarable/helpers/cuda/top_k.cu index 61aefa255f82..3fc327af0f77 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/top_k.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/top_k.cu @@ -18,265 +18,308 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include -#include #include +#include +#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { ////////////////////////////////////////////////////////////////////////// -template +template __global__ static void inTopKCuda(const void* vx, const Nd4jLong* xShapeInfo, const void* vy, const Nd4jLong* yShapeInfo, - void* vz, const Nd4jLong* zShapeInfo, - const Nd4jLong* xTadShapeInfo, const Nd4jLong* xTadOffsets, - const uint k) { - - - const auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); - - __shared__ uint* sharedMem; - __shared__ X elemToCompare; - __shared__ const X* xTad; - __shared__ Nd4jLong idx, xTadLen; - - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - sharedMem = reinterpret_cast(shmem); - - xTadLen = shape::length(xTadShapeInfo); - - xTad = reinterpret_cast(vx) + xTadOffsets[blockIdx.x]; - idx = y[shape::getIndexOffset(blockIdx.x, yShapeInfo)]; // shape::length(yShapeInfo) == numTads - elemToCompare = xTad[shape::getIndexOffset(idx, xTadShapeInfo)]; - } - - __syncthreads(); - - sharedMem[threadIdx.x] = 0; - for (Nd4jLong i = threadIdx.x; i < xTadLen; i += blockDim.x) - if(elemToCompare < xTad[shape::getIndexOffset(i, xTadShapeInfo)]) - ++sharedMem[threadIdx.x]; - + void* vz, const Nd4jLong* zShapeInfo, + const Nd4jLong* xTadShapeInfo, + const Nd4jLong* xTadOffsets, const uint k) { + const auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + + __shared__ uint* sharedMem; + __shared__ X elemToCompare; + __shared__ const X* xTad; + __shared__ Nd4jLong idx, xTadLen; + + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sharedMem = reinterpret_cast(shmem); + + xTadLen = shape::length(xTadShapeInfo); + + xTad = reinterpret_cast(vx) + xTadOffsets[blockIdx.x]; + idx = y[shape::getIndexOffset( + blockIdx.x, yShapeInfo)]; // shape::length(yShapeInfo) == numTads + elemToCompare = xTad[shape::getIndexOffset(idx, xTadShapeInfo)]; + } + + __syncthreads(); + + sharedMem[threadIdx.x] = 0; + for (Nd4jLong i = threadIdx.x; i < xTadLen; i += blockDim.x) + if (elemToCompare < xTad[shape::getIndexOffset(i, xTadShapeInfo)]) + ++sharedMem[threadIdx.x]; + + __syncthreads(); + + // aggregate sum + for (uint activeThreads = blockDim.x / 2; activeThreads > 0; + activeThreads /= 2) { + if (threadIdx.x < activeThreads) + sharedMem[threadIdx.x] += sharedMem[threadIdx.x + activeThreads]; __syncthreads(); + } - // aggregate sum - for (uint activeThreads = blockDim.x / 2; activeThreads > 0; activeThreads /= 2) { - if (threadIdx.x < activeThreads) - sharedMem[threadIdx.x] += sharedMem[threadIdx.x + activeThreads]; - __syncthreads(); - } - - if (threadIdx.x == 0) - z[shape::getIndexOffset(blockIdx.x, zShapeInfo)] = *sharedMem < k; + if (threadIdx.x == 0) + z[shape::getIndexOffset(blockIdx.x, zShapeInfo)] = *sharedMem < k; } /////////////////////////////////////////////////////////////////// -template -static void inTopKCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, - const void *vx, const Nd4jLong *xShapeInfo, - const void *vy, const Nd4jLong *yShapeInfo, - void *vz, const Nd4jLong *zShapeInfo, - const Nd4jLong* xTadShapeInfo, const Nd4jLong* xTadOffsets, - const uint k) { - - inTopKCuda<<>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, xTadShapeInfo, xTadOffsets, k); +template +static void inTopKCudaLauncher(const int blocksPerGrid, + const int threadsPerBlock, const int sharedMem, + const cudaStream_t* stream, const void* vx, + const Nd4jLong* xShapeInfo, const void* vy, + const Nd4jLong* yShapeInfo, void* vz, + const Nd4jLong* zShapeInfo, + const Nd4jLong* xTadShapeInfo, + const Nd4jLong* xTadOffsets, const uint k) { + inTopKCuda<<>>( + vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, xTadShapeInfo, + xTadOffsets, k); } /////////////////////////////////////////////////////////////////// -int inTopKFunctor(sd::LaunchContext * context, const NDArray* predictions, const NDArray* targets, NDArray* output, const uint k) { - - PointersManager manager(context, "in_top_k"); - - const auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(predictions->shapeInfo(), {1}); - - const int threadsPerBlock = MAX_NUM_THREADS; - const int blocksPerGrid = static_cast(packX.numberOfTads()); - const int sharedMem = sizeof(uint) * threadsPerBlock + 128; - - const auto xType = predictions->dataType(); - const auto yType = targets->dataType(); - - NDArray::prepareSpecialUse({output}, {predictions, targets}); - BUILD_DOUBLE_SELECTOR(xType, yType, inTopKCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), predictions->specialBuffer(), predictions->specialShapeInfo(), targets->specialBuffer(), targets->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), packX.specialShapeInfo(), packX.specialOffsets(), k), FLOAT_TYPES, INDEXING_TYPES); - NDArray::registerSpecialUse({output}, {predictions, targets}); - - manager.synchronize(); - - return Status::OK(); +int inTopKFunctor(sd::LaunchContext* context, const NDArray* predictions, + const NDArray* targets, NDArray* output, const uint k) { + PointersManager manager(context, "in_top_k"); + + const auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions( + predictions->shapeInfo(), {1}); + + const int threadsPerBlock = MAX_NUM_THREADS; + const int blocksPerGrid = static_cast(packX.numberOfTads()); + const int sharedMem = sizeof(uint) * threadsPerBlock + 128; + + const auto xType = predictions->dataType(); + const auto yType = targets->dataType(); + + NDArray::prepareSpecialUse({output}, {predictions, targets}); + BUILD_DOUBLE_SELECTOR( + xType, yType, inTopKCudaLauncher, + (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), + predictions->specialBuffer(), predictions->specialShapeInfo(), + targets->specialBuffer(), targets->specialShapeInfo(), + output->specialBuffer(), output->specialShapeInfo(), + packX.specialShapeInfo(), packX.specialOffsets(), k), + FLOAT_TYPES, INDEXING_TYPES); + NDArray::registerSpecialUse({output}, {predictions, targets}); + + manager.synchronize(); + + return Status::OK(); } - template - static _CUDA_G void topValuesMover(void const* vx, Nd4jLong const* xTadShapeInfo, Nd4jLong const* xTadOffsets, void const* vi, Nd4jLong const* iTadShapeInfo, Nd4jLong const* iTadOffsets, void *vz, Nd4jLong const* zTadShapeInfo, Nd4jLong const* zTadOffsets, Nd4jLong tadLength, int numTads, int k) { - for (int t = blockIdx.x; t < numTads; t += gridDim.x) { - auto x = reinterpret_cast(vx) + xTadOffsets[t]; - auto i = reinterpret_cast(vi) + iTadOffsets[t]; - auto z = reinterpret_cast(vz) + zTadOffsets[t]; - - for (int e = threadIdx.x; e < k; e += blockDim.x) { - auto idx = i[shape::getIndexOffset(e, iTadShapeInfo)]; - - z[shape::getIndexOffset(e, zTadShapeInfo)] = x[shape::getIndexOffset(idx, xTadShapeInfo)]; - } - } +template +static _CUDA_G void topValuesMover( + void const* vx, Nd4jLong const* xTadShapeInfo, Nd4jLong const* xTadOffsets, + void const* vi, Nd4jLong const* iTadShapeInfo, Nd4jLong const* iTadOffsets, + void* vz, Nd4jLong const* zTadShapeInfo, Nd4jLong const* zTadOffsets, + Nd4jLong tadLength, int numTads, int k) { + for (int t = blockIdx.x; t < numTads; t += gridDim.x) { + auto x = reinterpret_cast(vx) + xTadOffsets[t]; + auto i = reinterpret_cast(vi) + iTadOffsets[t]; + auto z = reinterpret_cast(vz) + zTadOffsets[t]; + + for (int e = threadIdx.x; e < k; e += blockDim.x) { + auto idx = i[shape::getIndexOffset(e, iTadShapeInfo)]; + + z[shape::getIndexOffset(e, zTadShapeInfo)] = + x[shape::getIndexOffset(idx, xTadShapeInfo)]; } + } +} - - template - static _CUDA_G void indicesAlongDimension(void const* vx, Nd4jLong const* xTadShapeInfo, Nd4jLong const* xTadOffsets, void* vi, Nd4jLong const* iTadShapeInfo, Nd4jLong const* iTadOffsets, void *vz, Nd4jLong const* zTadShapeInfo, Nd4jLong const* zTadOffsets, Nd4jLong tadLength, int numTads, int k, int scanWidth, bool needSort) { - extern __shared__ char _shmem[]; - - X* tempValues = reinterpret_cast(_shmem) + threadIdx.x * scanWidth; - Y* tempIndices = reinterpret_cast(reinterpret_cast(_shmem) + blockDim.x * scanWidth) + threadIdx.x * scanWidth; - - __shared__ X localMaximum; - if (threadIdx.x == 0) - localMaximum = -DataTypeUtils::max(); - __syncthreads(); - - for (int t = blockIdx.x; t < numTads; t += gridDim.x) { - auto x = reinterpret_cast(vx) + xTadOffsets[t]; - auto i = reinterpret_cast(vi) + iTadOffsets[t]; - auto z = reinterpret_cast(vz) + zTadOffsets[t]; - - // we'll do multiple reads here - for (int p = 0; p < k; p += scanWidth) { - - // resetting temporary storage - for (int p = 0; p < scanWidth; p++) { - tempValues[p] = -DataTypeUtils::max(); - tempIndices[p] = DataTypeUtils::max(); - } - - // local max values/indices - for (int e = threadIdx.x; e < tadLength; e++) { - auto value = x[shape::getIndexOffset(e, xTadShapeInfo)]; - - // we'll compare this value to current stored ones - for (int f = 0; f < scanWidth; f++) { - if (value > tempValues[f] && (p == 0 || value < localMaximum)) { - tempValues[f] = value; - tempIndices[f] = e; - } - } - } - __syncthreads(); - - // at this point we have local part ready for merge and define global maximum for this iteration, and local maximum for next iteration - for (uint activeThreads = blockDim.x / 2; activeThreads > 0; activeThreads /= 2) { - if (threadIdx.x < activeThreads) { - if (tempValues[0] < tempValues[0 + activeThreads * scanWidth]) { - tempValues[0] = tempValues[0 + activeThreads * scanWidth]; - tempIndices[0] = tempIndices[0 + activeThreads * scanWidth]; - } - } - __syncthreads(); - } - __syncthreads(); - - // at this point we know local minimum for next iteration - if (threadIdx.x == 0) { - localMaximum = tempValues[scanWidth - 1]; - z[shape::getIndexOffset(p, zTadShapeInfo)] = tempValues[scanWidth - 1]; - i[shape::getIndexOffset(p, iTadShapeInfo)] = tempIndices[scanWidth - 1]; - } - __syncthreads(); - } - - __syncthreads(); - if (!needSort) { - // if we don't need sort, we need to return values based on their indices (ascending) - for (int m = 0; m < k; m++) { - if (m % 2 == 0) { - for (int tid = threadIdx.x; tid < k; tid += blockDim.x) { - auto top = 2 * tid + 1; - if (top < k) { - auto t0 = shape::getIndexOffset(top - 1, iTadShapeInfo); - auto t1 = shape::getIndexOffset(top, iTadShapeInfo); - - if (i[t0] > i[t1]) { - // swap indices first - Y di0 = i[t0]; - i[t0] = i[t1]; - i[t1] = di0; - - //swap values next - - X dz0 = z[t0]; - z[t0] = z[t1]; - z[t1] = dz0; - } - } - } - } else { - for (int tid = threadIdx.x; tid < k; tid += blockDim.x) { - auto top = 2 * tid + 2; - if (top < k) { - auto t0 = shape::getIndexOffset(top - 1, iTadShapeInfo); - auto t1 = shape::getIndexOffset(top, iTadShapeInfo); - - if (i[t0] > i[t1]) { - // swap indices first - Y di0 = i[t0]; - i[t0] = i[t1]; - i[t1] = di0; - - //swap values next - - X dz0 = z[t0]; - z[t0] = z[t1]; - z[t1] = dz0; - } - } - } - } - __syncthreads(); - } - } +template +static _CUDA_G void indicesAlongDimension( + void const* vx, Nd4jLong const* xTadShapeInfo, Nd4jLong const* xTadOffsets, + void* vi, Nd4jLong const* iTadShapeInfo, Nd4jLong const* iTadOffsets, + void* vz, Nd4jLong const* zTadShapeInfo, Nd4jLong const* zTadOffsets, + Nd4jLong tadLength, int numTads, int k, int scanWidth, bool needSort) { + extern __shared__ char _shmem[]; + + X* tempValues = reinterpret_cast(_shmem) + threadIdx.x * scanWidth; + Y* tempIndices = reinterpret_cast(reinterpret_cast(_shmem) + + blockDim.x * scanWidth) + + threadIdx.x * scanWidth; + + __shared__ X localMaximum; + if (threadIdx.x == 0) localMaximum = -DataTypeUtils::max(); + __syncthreads(); + + for (int t = blockIdx.x; t < numTads; t += gridDim.x) { + auto x = reinterpret_cast(vx) + xTadOffsets[t]; + auto i = reinterpret_cast(vi) + iTadOffsets[t]; + auto z = reinterpret_cast(vz) + zTadOffsets[t]; + + // we'll do multiple reads here + for (int p = 0; p < k; p += scanWidth) { + // resetting temporary storage + for (int p = 0; p < scanWidth; p++) { + tempValues[p] = -DataTypeUtils::max(); + tempIndices[p] = DataTypeUtils::max(); + } + + // local max values/indices + for (int e = threadIdx.x; e < tadLength; e++) { + auto value = x[shape::getIndexOffset(e, xTadShapeInfo)]; + + // we'll compare this value to current stored ones + for (int f = 0; f < scanWidth; f++) { + if (value > tempValues[f] && (p == 0 || value < localMaximum)) { + tempValues[f] = value; + tempIndices[f] = e; + } + } + } + __syncthreads(); + + // at this point we have local part ready for merge and define global + // maximum for this iteration, and local maximum for next iteration + for (uint activeThreads = blockDim.x / 2; activeThreads > 0; + activeThreads /= 2) { + if (threadIdx.x < activeThreads) { + if (tempValues[0] < tempValues[0 + activeThreads * scanWidth]) { + tempValues[0] = tempValues[0 + activeThreads * scanWidth]; + tempIndices[0] = tempIndices[0 + activeThreads * scanWidth]; + } } + __syncthreads(); + } + __syncthreads(); + + // at this point we know local minimum for next iteration + if (threadIdx.x == 0) { + localMaximum = tempValues[scanWidth - 1]; + z[shape::getIndexOffset(p, zTadShapeInfo)] = tempValues[scanWidth - 1]; + i[shape::getIndexOffset(p, iTadShapeInfo)] = tempIndices[scanWidth - 1]; + } + __syncthreads(); } - - template - static int topKFunctor_(sd::LaunchContext * context, const NDArray* input, NDArray* values, NDArray* indices, const uint k, bool needSort) { - - auto packX = ConstantTadHelper::getInstance().tadForDimensions(input->shapeInfo(), {input->rankOf() - 1}); - auto packI = ConstantTadHelper::getInstance().tadForDimensions(indices->shapeInfo(), {input->rankOf() - 1}); - auto packZ = ConstantTadHelper::getInstance().tadForDimensions(values->shapeInfo(), {input->rankOf() - 1}); - - auto tadLength = shape::length(packX.primaryShapeInfo()); - - // we get top K values first - if (k == 1) { - input->applyIndexReduce(indexreduce::IndexMax, *indices, {input->rankOf() - 1}); - - // copy values on specified indices - topValuesMover<<<256, 256, 1024, *context->getCudaStream()>>>(input->specialBuffer(), packX.platformShapeInfo(), packX.platformOffsets(), indices->specialBuffer(), packI.platformShapeInfo(), packI.platformOffsets(), values->specialBuffer(), packZ.platformShapeInfo(), packZ.platformOffsets(), tadLength, packX.numberOfTads(), k); + __syncthreads(); + if (!needSort) { + // if we don't need sort, we need to return values based on their indices + // (ascending) + for (int m = 0; m < k; m++) { + if (m % 2 == 0) { + for (int tid = threadIdx.x; tid < k; tid += blockDim.x) { + auto top = 2 * tid + 1; + if (top < k) { + auto t0 = shape::getIndexOffset(top - 1, iTadShapeInfo); + auto t1 = shape::getIndexOffset(top, iTadShapeInfo); + + if (i[t0] > i[t1]) { + // swap indices first + Y di0 = i[t0]; + i[t0] = i[t1]; + i[t1] = di0; + + // swap values next + + X dz0 = z[t0]; + z[t0] = z[t1]; + z[t1] = dz0; + } + } + } } else { - int scanWidth = 1; - int numTreads = 256; - int shMemSize = (numTreads * sizeof(X) * scanWidth) + (numTreads * sizeof(Y) * scanWidth) + 512; - - indicesAlongDimension<<<256, numTreads, shMemSize, *context->getCudaStream()>>>(input->specialBuffer(), packX.platformShapeInfo(), packX.platformOffsets(), indices->specialBuffer(), packI.platformShapeInfo(), packI.platformOffsets(), values->specialBuffer(), packZ.platformShapeInfo(), packZ.platformOffsets(), tadLength, packX.numberOfTads(), k, scanWidth, needSort); + for (int tid = threadIdx.x; tid < k; tid += blockDim.x) { + auto top = 2 * tid + 2; + if (top < k) { + auto t0 = shape::getIndexOffset(top - 1, iTadShapeInfo); + auto t1 = shape::getIndexOffset(top, iTadShapeInfo); + + if (i[t0] > i[t1]) { + // swap indices first + Y di0 = i[t0]; + i[t0] = i[t1]; + i[t1] = di0; + + // swap values next + + X dz0 = z[t0]; + z[t0] = z[t1]; + z[t1] = dz0; + } + } + } } - - return Status::OK(); + __syncthreads(); + } } + } +} - int topKFunctor(sd::LaunchContext * context, const NDArray* input, NDArray* values, NDArray* indices, const uint k, bool needSort) { - input->syncToDevice(); +template +static int topKFunctor_(sd::LaunchContext* context, const NDArray* input, + NDArray* values, NDArray* indices, const uint k, + bool needSort) { + auto packX = ConstantTadHelper::getInstance().tadForDimensions( + input->shapeInfo(), {input->rankOf() - 1}); + auto packI = ConstantTadHelper::getInstance().tadForDimensions( + indices->shapeInfo(), {input->rankOf() - 1}); + auto packZ = ConstantTadHelper::getInstance().tadForDimensions( + values->shapeInfo(), {input->rankOf() - 1}); + + auto tadLength = shape::length(packX.primaryShapeInfo()); + + // we get top K values first + if (k == 1) { + input->applyIndexReduce(indexreduce::IndexMax, *indices, + {input->rankOf() - 1}); + + // copy values on specified indices + topValuesMover<<<256, 256, 1024, *context->getCudaStream()>>>( + input->specialBuffer(), packX.platformShapeInfo(), + packX.platformOffsets(), indices->specialBuffer(), + packI.platformShapeInfo(), packI.platformOffsets(), + values->specialBuffer(), packZ.platformShapeInfo(), + packZ.platformOffsets(), tadLength, packX.numberOfTads(), k); + } else { + int scanWidth = 1; + int numTreads = 256; + int shMemSize = (numTreads * sizeof(X) * scanWidth) + + (numTreads * sizeof(Y) * scanWidth) + 512; + + indicesAlongDimension + <<<256, numTreads, shMemSize, *context->getCudaStream()>>>( + input->specialBuffer(), packX.platformShapeInfo(), + packX.platformOffsets(), indices->specialBuffer(), + packI.platformShapeInfo(), packI.platformOffsets(), + values->specialBuffer(), packZ.platformShapeInfo(), + packZ.platformOffsets(), tadLength, packX.numberOfTads(), k, + scanWidth, needSort); + } + + return Status::OK(); +} - BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), topKFunctor_, (context, input, values, indices, k, needSort), LIBND4J_TYPES, INDEXING_TYPES); +int topKFunctor(sd::LaunchContext* context, const NDArray* input, + NDArray* values, NDArray* indices, const uint k, + bool needSort) { + input->syncToDevice(); - values->tickWriteDevice(); - indices->tickWriteDevice(); + BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), topKFunctor_, + (context, input, values, indices, k, needSort), + LIBND4J_TYPES, INDEXING_TYPES); - return Status::OK(); - } + values->tickWriteDevice(); + indices->tickWriteDevice(); + return Status::OK(); } -} -} \ No newline at end of file + +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/transforms.cu b/libnd4j/include/ops/declarable/helpers/cuda/transforms.cu index 80e0e08581a6..b74a0055898e 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/transforms.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/transforms.cu @@ -19,294 +19,330 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 20.04.2018 // - -#include -#include -#include -#include #include -#include +#include #include -#include #include +#include +#include +#include +#include + +#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { /////////////////////////////////////////////////////////////////// -template -__global__ static void invertPermutationCuda(const void* vx, const Nd4jLong* xShapeInfo, void* vz, const Nd4jLong* zShapeInfo) { - - const T* x = reinterpret_cast(vx); - T* z = reinterpret_cast(vz); - - __shared__ Nd4jLong len, totalThreads; - - if (threadIdx.x == 0) { - - len = shape::length(xShapeInfo); - totalThreads = gridDim.x * blockDim.x; - } - - __syncthreads(); - - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - - for (Nd4jLong i = tid; i < len; i += totalThreads) { - - const auto xOffset = shape::getIndexOffset(i, xShapeInfo); - const Nd4jLong index = x[xOffset]; - const auto zOffset = shape::getIndexOffset(index, zShapeInfo); - z[zOffset] = i; - } +template +__global__ static void invertPermutationCuda(const void* vx, + const Nd4jLong* xShapeInfo, + void* vz, + const Nd4jLong* zShapeInfo) { + const T* x = reinterpret_cast(vx); + T* z = reinterpret_cast(vz); + + __shared__ Nd4jLong len, totalThreads; + + if (threadIdx.x == 0) { + len = shape::length(xShapeInfo); + totalThreads = gridDim.x * blockDim.x; + } + + __syncthreads(); + + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + + for (Nd4jLong i = tid; i < len; i += totalThreads) { + const auto xOffset = shape::getIndexOffset(i, xShapeInfo); + const Nd4jLong index = x[xOffset]; + const auto zOffset = shape::getIndexOffset(index, zShapeInfo); + z[zOffset] = i; + } } /////////////////////////////////////////////////////////////////// -template -__host__ static void invertPermutationCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, - const void* vx, const Nd4jLong* xShapeInfo, void* vz, const Nd4jLong* zShapeInfo) { - - invertPermutationCuda<<>>(vx, xShapeInfo, vz, zShapeInfo); +template +__host__ static void invertPermutationCudaLauncher( + const int blocksPerGrid, const int threadsPerBlock, + const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo, + void* vz, const Nd4jLong* zShapeInfo) { + invertPermutationCuda<<>>( + vx, xShapeInfo, vz, zShapeInfo); } //////////////////////////////////////////////////////////////////////// -void invertPermutation(sd::LaunchContext* context, const NDArray& input, NDArray& output) { - - const int threadsPerBlock = MAX_NUM_THREADS; - const int blocksPerGrid = (input.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - - PointersManager manager(context, "invertPermutation"); - - NDArray::prepareSpecialUse({&output}, {&input}); - BUILD_SINGLE_SELECTOR(input.dataType(), invertPermutationCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), input.specialBuffer(), input.specialShapeInfo(), output.specialBuffer(), output.specialShapeInfo()), LIBND4J_TYPES); - NDArray::registerSpecialUse({&output}, {&input}); - - manager.synchronize(); +void invertPermutation(sd::LaunchContext* context, const NDArray& input, + NDArray& output) { + const int threadsPerBlock = MAX_NUM_THREADS; + const int blocksPerGrid = + (input.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + + PointersManager manager(context, "invertPermutation"); + + NDArray::prepareSpecialUse({&output}, {&input}); + BUILD_SINGLE_SELECTOR( + input.dataType(), invertPermutationCudaLauncher, + (blocksPerGrid, threadsPerBlock, context->getCudaStream(), + input.specialBuffer(), input.specialShapeInfo(), output.specialBuffer(), + output.specialShapeInfo()), + LIBND4J_TYPES); + NDArray::registerSpecialUse({&output}, {&input}); + + manager.synchronize(); } ////////////////////////////////////////////////////////////////////////// -template -__global__ static void traceCuda(const void* vx, const Nd4jLong* xShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const uint diagLen) { - - const auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - - __shared__ T* sharedMem; - __shared__ int xRank, zRank, *coordsMem; // xRank = zRank + 2 - __shared__ Nd4jLong xLen, zLen; - - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - sharedMem = reinterpret_cast(shmem); - coordsMem = reinterpret_cast(shmem + blockDim.x * sizeof(T)); - - xRank = shape::rank(xShapeInfo); - zRank = shape::rank(zShapeInfo); - xLen = shape::length(xShapeInfo); - zLen = shape::length(zShapeInfo); // corresponds to number of matrices - +template +__global__ static void traceCuda(const void* vx, const Nd4jLong* xShapeInfo, + void* vz, const Nd4jLong* zShapeInfo, + const uint diagLen) { + const auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + + __shared__ T* sharedMem; + __shared__ int xRank, zRank, *coordsMem; // xRank = zRank + 2 + __shared__ Nd4jLong xLen, zLen; + + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sharedMem = reinterpret_cast(shmem); + coordsMem = reinterpret_cast(shmem + blockDim.x * sizeof(T)); + + xRank = shape::rank(xShapeInfo); + zRank = shape::rank(zShapeInfo); + xLen = shape::length(xShapeInfo); + zLen = shape::length(zShapeInfo); // corresponds to number of matrices + } + __syncthreads(); + + auto coords = coordsMem + threadIdx.x * xRank; + + for (uint m = blockIdx.x; m < zLen; + m += + gridDim.x) { // one block per each element of z, that is per each matrix + + shape::index2coords(m, zShapeInfo, coords); + const auto zOffset = shape::getOffset(zShapeInfo, coords); + + sharedMem[threadIdx.x] = 0; + + for (uint i = threadIdx.x; i < diagLen; i += blockDim.x) { + coords[zRank] = coords[zRank + 1] = i; + const auto xOffset = shape::getOffset(xShapeInfo, coords); + sharedMem[threadIdx.x] += x[xOffset]; } - __syncthreads(); - - auto coords = coordsMem + threadIdx.x * xRank; - - for (uint m = blockIdx.x; m < zLen; m += gridDim.x) { // one block per each element of z, that is per each matrix - - shape::index2coords(m, zShapeInfo, coords); - const auto zOffset = shape::getOffset(zShapeInfo, coords); - - sharedMem[threadIdx.x] = 0; - for (uint i = threadIdx.x; i < diagLen; i += blockDim.x) { - - coords[zRank] = coords[zRank + 1] = i; - const auto xOffset = shape::getOffset(xShapeInfo, coords); - sharedMem[threadIdx.x] += x[xOffset]; - } - - __syncthreads(); - - // aggregate sum - for (Nd4jLong activeThreads = blockDim.x / 2; activeThreads > 0; activeThreads /= 2) { - if (threadIdx.x < activeThreads) - sharedMem[threadIdx.x] += sharedMem[threadIdx.x + activeThreads]; - __syncthreads(); - } + __syncthreads(); - if (threadIdx.x == 0) - z[zOffset] = *sharedMem; - __syncthreads(); + // aggregate sum + for (Nd4jLong activeThreads = blockDim.x / 2; activeThreads > 0; + activeThreads /= 2) { + if (threadIdx.x < activeThreads) + sharedMem[threadIdx.x] += sharedMem[threadIdx.x + activeThreads]; + __syncthreads(); } + if (threadIdx.x == 0) z[zOffset] = *sharedMem; + __syncthreads(); + } } /////////////////////////////////////////////////////////////////// -template -static void traceCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, - const void *vx, const Nd4jLong *xShapeInfo, - void *vz, const Nd4jLong *zShapeInfo, - const uint diagLen) { - - traceCuda<<>>(vx, xShapeInfo, vz, zShapeInfo, diagLen); +template +static void traceCudaLauncher(const int blocksPerGrid, + const int threadsPerBlock, const int sharedMem, + const cudaStream_t* stream, const void* vx, + const Nd4jLong* xShapeInfo, void* vz, + const Nd4jLong* zShapeInfo, const uint diagLen) { + traceCuda<<>>( + vx, xShapeInfo, vz, zShapeInfo, diagLen); } - /////////////////////////////////////////////////////////////////// void trace(sd::LaunchContext* context, const NDArray& input, NDArray& output) { - - PointersManager manager(context, "trace"); - - const uint diagLen = input.sizeAt(-1) < input.sizeAt(-2) ? input.sizeAt(-1) : input.sizeAt(-2); - const int threadsPerBlock = MAX_NUM_THREADS / 4; - const int blocksPerGrid = (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - const int sharedMem = threadsPerBlock * (sizeof(int) * input.rankOf() + input.sizeOfT()) + 128; - - NDArray::prepareSpecialUse({&output}, {&input}); - BUILD_SINGLE_SELECTOR(input.dataType(), traceCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), input.specialBuffer(), input.specialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), diagLen), LIBND4J_TYPES); - NDArray::registerSpecialUse({&output}, {&input}); - - manager.synchronize(); + PointersManager manager(context, "trace"); + + const uint diagLen = + input.sizeAt(-1) < input.sizeAt(-2) ? input.sizeAt(-1) : input.sizeAt(-2); + const int threadsPerBlock = MAX_NUM_THREADS / 4; + const int blocksPerGrid = + (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + const int sharedMem = + threadsPerBlock * (sizeof(int) * input.rankOf() + input.sizeOfT()) + 128; + + NDArray::prepareSpecialUse({&output}, {&input}); + BUILD_SINGLE_SELECTOR( + input.dataType(), traceCudaLauncher, + (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), + input.specialBuffer(), input.specialShapeInfo(), output.specialBuffer(), + output.specialShapeInfo(), diagLen), + LIBND4J_TYPES); + NDArray::registerSpecialUse({&output}, {&input}); + + manager.synchronize(); } /////////////////////////////////////////////////////////////////// -template -__global__ static void triuBPCuda(const void* vx, const Nd4jLong* xShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const int diag) { - - // x and z have same shapes - const auto x = reinterpret_cast(vx); // gradO - auto z = reinterpret_cast(vz); // gradI - - __shared__ int rank, areSameOffsets, *sharedMem; // xRank = zRank - __shared__ Nd4jLong len, totalThreads; // xLen = zLen - - if (threadIdx.x == 0) { - - extern __shared__ unsigned char shmem[]; - sharedMem = reinterpret_cast(shmem); - areSameOffsets = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo); - rank = shape::rank(xShapeInfo); - len = shape::length(zShapeInfo); - totalThreads = gridDim.x * blockDim.x; - } - - __syncthreads(); - - auto coords = sharedMem + threadIdx.x * rank; - - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - - for (Nd4jLong i = tid; i < len; i += totalThreads) { - - shape::index2coords(i, zShapeInfo, coords); - - const auto zOffset = shape::getOffset(zShapeInfo, coords); - - if((coords[rank - 2] + diag > coords[rank - 1])) // row + diag > col - z[zOffset] = 0; - else - z[zOffset] = x[areSameOffsets ? zOffset : shape::getOffset(xShapeInfo, coords)]; - } +template +__global__ static void triuBPCuda(const void* vx, const Nd4jLong* xShapeInfo, + void* vz, const Nd4jLong* zShapeInfo, + const int diag) { + // x and z have same shapes + const auto x = reinterpret_cast(vx); // gradO + auto z = reinterpret_cast(vz); // gradI + + __shared__ int rank, areSameOffsets, *sharedMem; // xRank = zRank + __shared__ Nd4jLong len, totalThreads; // xLen = zLen + + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sharedMem = reinterpret_cast(shmem); + areSameOffsets = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo); + rank = shape::rank(xShapeInfo); + len = shape::length(zShapeInfo); + totalThreads = gridDim.x * blockDim.x; + } + + __syncthreads(); + + auto coords = sharedMem + threadIdx.x * rank; + + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + + for (Nd4jLong i = tid; i < len; i += totalThreads) { + shape::index2coords(i, zShapeInfo, coords); + + const auto zOffset = shape::getOffset(zShapeInfo, coords); + + if ((coords[rank - 2] + diag > coords[rank - 1])) // row + diag > col + z[zOffset] = 0; + else + z[zOffset] = + x[areSameOffsets ? zOffset : shape::getOffset(xShapeInfo, coords)]; + } } /////////////////////////////////////////////////////////////////// -template -static void triuBPCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void* vx, const Nd4jLong* xShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const int diag) { - - triuBPCuda<<>>(vx, xShapeInfo, vz, zShapeInfo, diag); +template +static void triuBPCudaLauncher(const int blocksPerGrid, + const int threadsPerBlock, const int sharedMem, + const cudaStream_t* stream, const void* vx, + const Nd4jLong* xShapeInfo, void* vz, + const Nd4jLong* zShapeInfo, const int diag) { + triuBPCuda<<>>( + vx, xShapeInfo, vz, zShapeInfo, diag); } /////////////////////////////////////////////////////////////////// -void triuBP(sd::LaunchContext* context, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int diagonal) { - - const int threadsPerBlock = MAX_NUM_THREADS / 4; - const int blocksPerGrid = (gradO.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - const int sharedMem = threadsPerBlock * sizeof(int) * gradO.rankOf() + 128; - - PointersManager manager(context, "triuBP"); - - NDArray::prepareSpecialUse({&gradI}, {&gradO}); - BUILD_SINGLE_SELECTOR(gradI.dataType(), triuBPCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), gradO.specialBuffer(), gradO.specialShapeInfo(), gradI.specialBuffer(), gradI.specialShapeInfo(), diagonal), LIBND4J_TYPES); - NDArray::registerSpecialUse({&gradI}, {&gradO}); - - manager.synchronize(); +void triuBP(sd::LaunchContext* context, const NDArray& input, + const NDArray& gradO, NDArray& gradI, const int diagonal) { + const int threadsPerBlock = MAX_NUM_THREADS / 4; + const int blocksPerGrid = + (gradO.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + const int sharedMem = threadsPerBlock * sizeof(int) * gradO.rankOf() + 128; + + PointersManager manager(context, "triuBP"); + + NDArray::prepareSpecialUse({&gradI}, {&gradO}); + BUILD_SINGLE_SELECTOR( + gradI.dataType(), triuBPCudaLauncher, + (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), + gradO.specialBuffer(), gradO.specialShapeInfo(), gradI.specialBuffer(), + gradI.specialShapeInfo(), diagonal), + LIBND4J_TYPES); + NDArray::registerSpecialUse({&gradI}, {&gradO}); + + manager.synchronize(); } /////////////////////////////////////////////////////////////////// -template -__global__ static void tileBPCuda(const void* vx, const Nd4jLong* xShapeInfo, void* vz, const Nd4jLong* zShapeInfo, Nd4jLong* globMem) { +template +__global__ static void tileBPCuda(const void* vx, const Nd4jLong* xShapeInfo, + void* vz, const Nd4jLong* zShapeInfo, + Nd4jLong* globMem) { + // x and z have same shapes + const auto x = reinterpret_cast(vx); // gradO + auto z = reinterpret_cast(vz); // gradI - // x and z have same shapes - const auto x = reinterpret_cast(vx); // gradO - auto z = reinterpret_cast(vz); // gradI + __shared__ int xRank, zRank, *sharedMem; // xRank >= zRank + __shared__ Nd4jLong numOfXOffsets, zLen, totalThreads; // xLen >= zLen - __shared__ int xRank, zRank, *sharedMem; // xRank >= zRank - __shared__ Nd4jLong numOfXOffsets, zLen, totalThreads; // xLen >= zLen + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sharedMem = reinterpret_cast(shmem); - if (threadIdx.x == 0) { + xRank = shape::rank(zShapeInfo); + zLen = shape::length(zShapeInfo); + numOfXOffsets = shape::length(xShapeInfo) / zLen; - extern __shared__ unsigned char shmem[]; - sharedMem = reinterpret_cast(shmem); + totalThreads = gridDim.x * blockDim.x; + } - xRank = shape::rank(zShapeInfo); - zLen = shape::length(zShapeInfo); - numOfXOffsets = shape::length(xShapeInfo) / zLen; + __syncthreads(); - totalThreads = gridDim.x * blockDim.x; - } + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - __syncthreads(); + auto memBuff = sharedMem + threadIdx.x * 2 * xRank; + auto xOffsets = globMem + tid * numOfXOffsets; - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + for (Nd4jLong i = tid; i < zLen; i += totalThreads) { + const auto zOffset = shape::getIndexOffset(i, zShapeInfo); - auto memBuff = sharedMem + threadIdx.x * 2 * xRank; - auto xOffsets = globMem + tid * numOfXOffsets; + shape::outerArrayOffsets(xOffsets, i, xShapeInfo, zShapeInfo, memBuff); - for (Nd4jLong i = tid; i < zLen; i += totalThreads) { - - const auto zOffset = shape::getIndexOffset(i, zShapeInfo); - - shape::outerArrayOffsets(xOffsets, i, xShapeInfo, zShapeInfo, memBuff); - - z[zOffset] = x[xOffsets[0]]; // first offset - for (Nd4jLong j = 1; j < numOfXOffsets; ++j) // rest offsets - z[zOffset] += x[xOffsets[j]]; - } + z[zOffset] = x[xOffsets[0]]; // first offset + for (Nd4jLong j = 1; j < numOfXOffsets; ++j) // rest offsets + z[zOffset] += x[xOffsets[j]]; + } } /////////////////////////////////////////////////////////////////// -template -static void tileBPCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void* vx, const Nd4jLong* xShapeInfo, void* vz, const Nd4jLong* zShapeInfo, Nd4jLong* globMem) { - - tileBPCuda<<>>(vx, xShapeInfo, vz, zShapeInfo, globMem); +template +static void tileBPCudaLauncher(const int blocksPerGrid, + const int threadsPerBlock, const int sharedMem, + const cudaStream_t* stream, const void* vx, + const Nd4jLong* xShapeInfo, void* vz, + const Nd4jLong* zShapeInfo, Nd4jLong* globMem) { + tileBPCuda<<>>( + vx, xShapeInfo, vz, zShapeInfo, globMem); } - ////////////////////////////////////////////////////////////////////////// -void tileBP(sd::LaunchContext * context, const NDArray& gradO /*input*/, NDArray& gradI /*output*/, const std::vector reps) { - - NDArray memBuff('c', gradO.getShapeAsVector(), sd::DataType::INT64, context); // empty auxiliary array for storing device memory which will be used in kernel calculations - - const int threadsPerBlock = MAX_NUM_THREADS / 4; - const int blocksPerGrid = (gradI.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - const int sharedMem = threadsPerBlock * sizeof(int) * 2 * gradO.rankOf() + 128; - - PointersManager manager(context, "tileBP"); - - NDArray::prepareSpecialUse({&gradI}, {&gradO, &memBuff}); - BUILD_SINGLE_SELECTOR(gradI.dataType(), tileBPCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), gradO.specialBuffer(), gradO.specialShapeInfo(), gradI.specialBuffer(), gradI.specialShapeInfo(), reinterpret_cast(memBuff.specialBuffer())), FLOAT_TYPES); - NDArray::registerSpecialUse({&gradI}, {&gradO, &memBuff}); - - manager.synchronize(); +void tileBP(sd::LaunchContext* context, const NDArray& gradO /*input*/, + NDArray& gradI /*output*/, const std::vector reps) { + NDArray memBuff('c', gradO.getShapeAsVector(), sd::DataType::INT64, + context); // empty auxiliary array for storing device memory + // which will be used in kernel calculations + + const int threadsPerBlock = MAX_NUM_THREADS / 4; + const int blocksPerGrid = + (gradI.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + const int sharedMem = + threadsPerBlock * sizeof(int) * 2 * gradO.rankOf() + 128; + + PointersManager manager(context, "tileBP"); + + NDArray::prepareSpecialUse({&gradI}, {&gradO, &memBuff}); + BUILD_SINGLE_SELECTOR( + gradI.dataType(), tileBPCudaLauncher, + (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), + gradO.specialBuffer(), gradO.specialShapeInfo(), gradI.specialBuffer(), + gradI.specialShapeInfo(), + reinterpret_cast(memBuff.specialBuffer())), + FLOAT_TYPES); + NDArray::registerSpecialUse({&gradI}, {&gradO, &memBuff}); + + manager.synchronize(); } ////////////////////////////////////////////////////////////////////////// void eye(sd::LaunchContext * context, NDArray& output) { - - output.setIdentity(); + output.setIdentity(); } -} -} -} +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/triangular_solve.cu b/libnd4j/include/ops/declarable/helpers/cuda/triangular_solve.cu index e77bb4e19b69..8884ffdce8c6 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/triangular_solve.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/triangular_solve.cu @@ -18,150 +18,161 @@ // @author GS // -#include #include #include #include +#include + #include "../triangular_solve.h" namespace sd { - namespace ops { - namespace helpers { - /* - * lower triangular process for system of linear equations - * x_1 = b_1/a_1,1 - * x_2 = (b_2 - a_2,1 * x_1) / a_2,2 - * x_3 = (b_3 - a_3,1 * x_1 - a_3,2 * x_2) / a_3,3 - * ... - * x_M = (b_M - a_M,1 * x_1 - ... a_M,M-1 * x_M-1)/ a_M,M - * - * output == x - * a == leftInput - * b == rightInput - * - * */ - template - static _CUDA_HD void lowerTriangularSolve(T const* leftInput, Nd4jLong const* leftInputShape, - T const* rightInput, Nd4jLong const* rightInputShape, - bool const unitOnDiag, T* output, const Nd4jLong* outputShape, - Nd4jLong rows, Nd4jLong cols) { - - for (auto r = 0; r < rows; r++) { - for (auto j = 0; j < cols; j++) { - Nd4jLong posY[] = {r, j}; - Nd4jLong posX[] = {r, r}; - auto xIndex = shape::getOffset(leftInputShape, posX, 0); - auto yIndex = shape::getOffset(rightInputShape, posY, 0); - auto zIndex = shape::getOffset(outputShape, posY, 0); - - auto sum = rightInput[yIndex]; - for (auto c = 0; c < r; c++) { - Nd4jLong posZ[] = {c, j}; - Nd4jLong pos[] = {r, c}; - auto xcIndex = shape::getOffset(leftInputShape, pos, 0); - auto zcIndex = shape::getOffset(outputShape, posZ, 0); - sum -= leftInput[xcIndex] * output[zcIndex]; - } - output[zIndex] = unitOnDiag?sum:sum / leftInput[xIndex]; - } - } - } - - /* - * upper triangular process for system of linear equations - * x_M = b_M/a_M,M - * x_M-1 = (b_M-1 - a_M-1,M-2 * x_M) / a_M-1,M-1 - * x_M-2 = (b_M-2 - a_M-2,M-3 * x_M-2 - a_M-2,M-1 * x_M) / a_3,3 - * ... - * x_1 = (b_1 - a_1,2 * x_2 - ... a_1,M * x_M)/ a_1,1 - * - * output == x - * a == leftInput - * b == rightInput - * - * */ - - template - static _CUDA_HD void upperTriangularSolve(T const* leftInput, Nd4jLong const* leftInputShape, - T const* rightInput, Nd4jLong const* rightInputShape, bool const unitOnDiag, T* output, - const Nd4jLong* outputShape, Nd4jLong rows, Nd4jLong cols) { - - for (auto r = rows; r > 0; r--) { - for (auto j = 0; j < cols; j++) { - Nd4jLong posY[] = {r - 1, j}; - Nd4jLong posX[] = {r - 1, r - 1}; - auto xIndex = shape::getOffset(leftInputShape, posX, 0); - auto yIndex = shape::getOffset(rightInputShape, posY, 0); - auto zIndex = shape::getOffset(outputShape, posY, 0); - auto sum = rightInput[yIndex]; - for (auto c = r; c < rows; c++) { - Nd4jLong posZ[] = {c, j}; - Nd4jLong pos[] = {r - 1, c}; - auto zcIndex = shape::getOffset(outputShape, posZ, 0); - auto xcIndex = shape::getOffset(leftInputShape, pos, 0); - sum -= leftInput[xcIndex] * output[zcIndex]; - } - output[zIndex] = unitOnDiag?sum:sum / leftInput[xIndex]; - } - } - } - - template - static __global__ void triangularSolveKernel(T const* leftInput, Nd4jLong const* leftPartShape, - T const* rightInput, Nd4jLong const* rightPartShape, bool const lower, bool const unitsOnDiag, T* output, - const Nd4jLong* outputShape, const Nd4jLong* tadLeftShape, const Nd4jLong* tadLeftOffset, const Nd4jLong* tadRightShape, - const Nd4jLong* tadRightOffset, const Nd4jLong* tadOutputShape, const Nd4jLong* tadOutputOffset, Nd4jLong batchNum) { - - __shared__ Nd4jLong rows; - __shared__ Nd4jLong cols; +namespace ops { +namespace helpers { +/* + * lower triangular process for system of linear equations + * x_1 = b_1/a_1,1 + * x_2 = (b_2 - a_2,1 * x_1) / a_2,2 + * x_3 = (b_3 - a_3,1 * x_1 - a_3,2 * x_2) / a_3,3 + * ... + * x_M = (b_M - a_M,1 * x_1 - ... a_M,M-1 * x_M-1)/ a_M,M + * + * output == x + * a == leftInput + * b == rightInput + * + * */ +template +static _CUDA_HD void lowerTriangularSolve( + T const* leftInput, Nd4jLong const* leftInputShape, T const* rightInput, + Nd4jLong const* rightInputShape, bool const unitOnDiag, T* output, + const Nd4jLong* outputShape, Nd4jLong rows, Nd4jLong cols) { + for (auto r = 0; r < rows; r++) { + for (auto j = 0; j < cols; j++) { + Nd4jLong posY[] = {r, j}; + Nd4jLong posX[] = {r, r}; + auto xIndex = shape::getOffset(leftInputShape, posX, 0); + auto yIndex = shape::getOffset(rightInputShape, posY, 0); + auto zIndex = shape::getOffset(outputShape, posY, 0); + + auto sum = rightInput[yIndex]; + for (auto c = 0; c < r; c++) { + Nd4jLong posZ[] = {c, j}; + Nd4jLong pos[] = {r, c}; + auto xcIndex = shape::getOffset(leftInputShape, pos, 0); + auto zcIndex = shape::getOffset(outputShape, posZ, 0); + sum -= leftInput[xcIndex] * output[zcIndex]; + } + output[zIndex] = unitOnDiag?sum:sum / leftInput[xIndex]; + } + } +} - if (threadIdx.x == 0) { - rows = shape::sizeAt(leftPartShape, -2); - cols = shape::sizeAt(rightPartShape, -1); - } - __syncthreads(); - - auto start = blockIdx.x * blockDim.x + threadIdx.x; - auto stop = batchNum; - auto increment = blockDim.x * gridDim.x; - - for (auto i = start; i < stop; i += increment) { - auto pLeftPart = leftInput + tadLeftOffset[i]; - auto pRightPart = rightInput + tadRightOffset[i]; - auto pOutputPart = output + tadOutputOffset[i]; - if (lower) { - lowerTriangularSolve(pLeftPart, tadLeftShape, pRightPart, tadRightShape, unitsOnDiag, pOutputPart, tadOutputShape, rows, cols); - } else { - upperTriangularSolve(pLeftPart, tadLeftShape, pRightPart, tadRightShape, unitsOnDiag, pOutputPart, tadOutputShape, rows, cols); - } - } - } +/* + * upper triangular process for system of linear equations + * x_M = b_M/a_M,M + * x_M-1 = (b_M-1 - a_M-1,M-2 * x_M) / a_M-1,M-1 + * x_M-2 = (b_M-2 - a_M-2,M-3 * x_M-2 - a_M-2,M-1 * x_M) / a_3,3 + * ... + * x_1 = (b_1 - a_1,2 * x_2 - ... a_1,M * x_M)/ a_1,1 + * + * output == x + * a == leftInput + * b == rightInput + * + * */ + +template +static _CUDA_HD void upperTriangularSolve( + T const* leftInput, Nd4jLong const* leftInputShape, T const* rightInput, + Nd4jLong const* rightInputShape, bool const unitOnDiag, T* output, + const Nd4jLong* outputShape, Nd4jLong rows, Nd4jLong cols) { + for (auto r = rows; r > 0; r--) { + for (auto j = 0; j < cols; j++) { + Nd4jLong posY[] = {r - 1, j}; + Nd4jLong posX[] = {r - 1, r - 1}; + auto xIndex = shape::getOffset(leftInputShape, posX, 0); + auto yIndex = shape::getOffset(rightInputShape, posY, 0); + auto zIndex = shape::getOffset(outputShape, posY, 0); + auto sum = rightInput[yIndex]; + for (auto c = r; c < rows; c++) { + Nd4jLong posZ[] = {c, j}; + Nd4jLong pos[] = {r - 1, c}; + auto zcIndex = shape::getOffset(outputShape, posZ, 0); + auto xcIndex = shape::getOffset(leftInputShape, pos, 0); + sum -= leftInput[xcIndex] * output[zcIndex]; + } + output[zIndex] = unitOnDiag?sum:sum / leftInput[xIndex]; + } + } +} - template - static int triangularSolveFunctor_(sd::LaunchContext * context, NDArray* leftInput, NDArray* rightInput, - bool lower, bool unitsOnDiag, NDArray* output) { - NDArray::prepareSpecialUse({output}, {leftInput, rightInput}); - auto leftTads = ConstantTadHelper::getInstance().tadForDimensions(leftInput->shapeInfo(), {-2, -1}); - auto rightTads = ConstantTadHelper::getInstance().tadForDimensions(rightInput->shapeInfo(), {-2, -1}); - auto outputTads = ConstantTadHelper::getInstance().tadForDimensions(output->shapeInfo(), {-2, -1}); - - auto stream = context->getCudaStream(); - T const* leftBuf = reinterpret_cast(leftInput->specialBuffer()); - T const* rightBuf = reinterpret_cast(rightInput->specialBuffer()); - T* outputBuf = reinterpret_cast(output->specialBuffer()); - triangularSolveKernel<<<128, 128, 256, *stream>>>(leftBuf, leftInput->specialShapeInfo(), - rightBuf, rightInput->specialShapeInfo(), lower, unitsOnDiag, outputBuf, output->specialShapeInfo(), - leftTads.specialShapeInfo(), leftTads.specialOffsets(), rightTads.specialShapeInfo(), - rightTads.specialOffsets(), outputTads.specialShapeInfo(), outputTads.specialOffsets(), - leftTads.numberOfTads()); - - NDArray::registerSpecialUse({output}, {leftInput, rightInput}); - - return Status::OK(); +template +static __global__ void triangularSolveKernel( + T const* leftInput, Nd4jLong const* leftPartShape, T const* rightInput, + Nd4jLong const* rightPartShape, bool const lower, bool const unitsOnDiag, + T* output, const Nd4jLong* outputShape, const Nd4jLong* tadLeftShape, const + Nd4jLong * tadLeftOffset, const Nd4jLong* tadRightShape, + const Nd4jLong* tadRightOffset, const Nd4jLong* tadOutputShape, + const Nd4jLong* tadOutputOffset, Nd4jLong batchNum) { + __shared__ Nd4jLong rows; + __shared__ Nd4jLong cols; + + if (threadIdx.x == 0) { + rows = shape::sizeAt(leftPartShape, -2); + cols = shape::sizeAt(rightPartShape, -1); + } + __syncthreads(); + + auto start = blockIdx.x * blockDim.x + threadIdx.x; + auto stop = batchNum; + auto increment = blockDim.x * gridDim.x; + + for (auto i = start; i < stop; i += increment) { + auto pLeftPart = leftInput + tadLeftOffset[i]; + auto pRightPart = rightInput + tadRightOffset[i]; + auto pOutputPart = output + tadOutputOffset[i]; + if (lower) { + lowerTriangularSolve(pLeftPart, tadLeftShape, pRightPart, + tadRightShape, unitsOnDiag, pOutputPart, + tadOutputShape, rows, cols); + } else { + upperTriangularSolve(pLeftPart, tadLeftShape, pRightPart, + tadRightShape, unitsOnDiag, pOutputPart, + tadOutputShape, rows, cols); + } + } +} - } +template +static int triangularSolveFunctor_(sd::LaunchContext* context, + NDArray* leftInput, NDArray* rightInput, + bool lower, bool unitsOnDiag, NDArray* output) { + NDArray::prepareSpecialUse({output}, {leftInput, rightInput}); + auto leftTads = ConstantTadHelper::getInstance().tadForDimensions( + leftInput->shapeInfo(), {-2, -1}); + auto rightTads = ConstantTadHelper::getInstance().tadForDimensions( + rightInput->shapeInfo(), {-2, -1}); + auto outputTads = ConstantTadHelper::getInstance().tadForDimensions( + output->shapeInfo(), {-2, -1}); + + auto stream = context->getCudaStream(); + T const* leftBuf = reinterpret_cast(leftInput->specialBuffer()); + T const* rightBuf = reinterpret_cast(rightInput->specialBuffer()); + T* outputBuf = reinterpret_cast(output->specialBuffer()); + triangularSolveKernel<<<128, 128, 256, *stream>>>( + leftBuf, leftInput->specialShapeInfo(), rightBuf, + rightInput->specialShapeInfo(), lower, unitsOnDiag, outputBuf, + output->specialShapeInfo(), leftTads.specialShapeInfo(), + leftTads.specialOffsets(), rightTads.specialShapeInfo(), + rightTads.specialOffsets(), outputTads.specialShapeInfo(), + outputTads.specialOffsets(), leftTads.numberOfTads()); + + NDArray::registerSpecialUse({output}, {leftInput, rightInput}); + + return Status::OK(); +} - /// triangularSolve2D - 2D implementation of triangularSolveFunctor +/// triangularSolve2D - 2D implementation of triangularSolveFunctor /// \tparam T - type of NDArray output /// \param context - launch context pointer /// \param leftInput - T matrix of equation Tx = b @@ -189,78 +200,87 @@ namespace sd { // output.syncToDevice(); } BUILD_SINGLE_TEMPLATE(template void triangularSolve2D, (sd::LaunchContext* context, NDArray const& leftInput, NDArray const& rightInput, bool const lower, bool const unitsOnDiag, NDArray& output), FLOAT_TYPES); -// template void triangularSolve2D(sd::LaunchContext* context, NDArray const& leftInput, NDArray const& rightInput, bool const lower, bool const unitsOnDiag, NDArray& output); -// template void triangularSolve2D(sd::LaunchContext* context, NDArray const& leftInput, NDArray const& rightInput, bool const lower, bool const unitsOnDiag, NDArray& output); -// template void triangularSolve2D(sd::LaunchContext* context, NDArray const& leftInput, NDArray const& rightInput, bool const lower, bool const unitsOnDiag, NDArray& output); -// template void triangularSolve2D(sd::LaunchContext* context, NDArray const& leftInput, NDArray const& rightInput, bool const lower, bool const unitsOnDiag, NDArray& output); - - int triangularSolveFunctor(sd::LaunchContext * context, NDArray* leftInput, NDArray* rightInput, bool lower, bool unitsOnDiag, NDArray* output) { - BUILD_SINGLE_SELECTOR(leftInput->dataType(), return triangularSolveFunctor_, (context, leftInput, rightInput, lower, unitsOnDiag, output), FLOAT_NATIVE); - } - template - static __global__ void upperAdjointKernel(T const* input, T* output, - Nd4jLong batchSize, Nd4jLong rows, Nd4jLong columns, - Nd4jLong const* inputTads, Nd4jLong const* inputOffsets, Nd4jLong const* outputTads, Nd4jLong const* outputOffsets) { - - for (auto b = blockIdx.x; b < batchSize; b += gridDim.x) { - auto inputPart = input + inputOffsets[b]; - auto outputPart = output + outputOffsets[b]; - for (auto r = threadIdx.x; r < rows; r += blockDim.x) { - for (auto c = threadIdx.y; c <= r; c += blockDim.y) { - Nd4jLong zPos[] = {r, c}; - Nd4jLong xPos[] = {c, r}; - auto zIndex = shape::getOffset(outputTads, zPos); - auto xIndex = shape::getOffset(inputTads, xPos); - outputPart[zIndex] = inputPart[xIndex]; - } - } - } +int triangularSolveFunctor(sd::LaunchContext* context, NDArray* leftInput, + NDArray* rightInput, bool lower, bool unitsOnDiag, + NDArray* output) { + BUILD_SINGLE_SELECTOR( + leftInput->dataType(), return triangularSolveFunctor_, + (context, leftInput, rightInput, lower, unitsOnDiag, output), FLOAT_NATIVE); +} - } +template +static __global__ void upperAdjointKernel( + T const* input, T* output, Nd4jLong batchSize, Nd4jLong rows, + Nd4jLong columns, Nd4jLong const* inputTads, Nd4jLong const* inputOffsets, + Nd4jLong const* outputTads, Nd4jLong const* outputOffsets) { + for (auto b = blockIdx.x; b < batchSize; b += gridDim.x) { + auto inputPart = input + inputOffsets[b]; + auto outputPart = output + outputOffsets[b]; + for (auto r = threadIdx.x; r < rows; r += blockDim.x) { + for (auto c = threadIdx.y; c <= r; c += blockDim.y) { + Nd4jLong zPos[] = {r, c}; + Nd4jLong xPos[] = {c, r}; + auto zIndex = shape::getOffset(outputTads, zPos); + auto xIndex = shape::getOffset(inputTads, xPos); + outputPart[zIndex] = inputPart[xIndex]; + } + } + } +} - template - static __global__ void lowerAdjointKernel(T const* input, T* output, - Nd4jLong batchSize, Nd4jLong rows, Nd4jLong columns, - Nd4jLong const* inputTads, Nd4jLong const* inputOffsets, Nd4jLong const* outputTads, Nd4jLong const* outputOffsets) { - - for (auto b = blockIdx.x; b < batchSize; b += gridDim.x) { - auto inputPart = input + inputOffsets[b]; - auto outputPart = output + outputOffsets[b]; - for (auto r = threadIdx.x; r < rows; r += blockDim.x) { - for (auto c = r + threadIdx.y; c < columns; c += blockDim.y) { - Nd4jLong zPos[] = {r, c}; - Nd4jLong xPos[] = {c, r}; - auto zIndex = shape::getOffset(outputTads, zPos); - auto xIndex = shape::getOffset(inputTads, xPos); - outputPart[zIndex] = inputPart[xIndex]; - } - } - } - } +template +static __global__ void lowerAdjointKernel( + T const* input, T* output, Nd4jLong batchSize, Nd4jLong rows, + Nd4jLong columns, Nd4jLong const* inputTads, Nd4jLong const* inputOffsets, + Nd4jLong const* outputTads, Nd4jLong const* outputOffsets) { + for (auto b = blockIdx.x; b < batchSize; b += gridDim.x) { + auto inputPart = input + inputOffsets[b]; + auto outputPart = output + outputOffsets[b]; + for (auto r = threadIdx.x; r < rows; r += blockDim.x) { + for (auto c = r + threadIdx.y; c < columns; c += blockDim.y) { + Nd4jLong zPos[] = {r, c}; + Nd4jLong xPos[] = {c, r}; + auto zIndex = shape::getOffset(outputTads, zPos); + auto xIndex = shape::getOffset(inputTads, xPos); + outputPart[zIndex] = inputPart[xIndex]; + } + } + } +} - template - static void adjointTriangularMatrix_(sd::LaunchContext* context, NDArray const* input, bool const lower, - NDArray* output) { - - auto inputTads = ConstantTadHelper::getInstance().tadForDimensions(input->shapeInfo(), {-2, -1}); - auto outputTads = ConstantTadHelper::getInstance().tadForDimensions(output->shapeInfo(), {-2, -1}); - auto stream = context->getCudaStream(); - auto inputBuf = reinterpret_cast(input->specialBuffer()); - auto outputBuf = reinterpret_cast(output->specialBuffer()); - auto rows = input->sizeAt(-2); - auto columns = input->sizeAt(-1); - - if (lower) { - lowerAdjointKernel<<<128, 256, 256, *stream>>>(inputBuf, outputBuf, outputTads.numberOfTads(), rows, columns, inputTads.specialShapeInfo(), inputTads.specialOffsets(), outputTads.specialShapeInfo(), outputTads.specialOffsets()); - } else { - upperAdjointKernel<<<128, 256, 256, *stream>>>(inputBuf, outputBuf, outputTads.numberOfTads(), rows, columns, inputTads.specialShapeInfo(), inputTads.specialOffsets(), outputTads.specialShapeInfo(), outputTads.specialOffsets()); - } - } +template +static void adjointTriangularMatrix_(sd::LaunchContext* context, + NDArray const* input, bool const lower, + NDArray* output) { + auto inputTads = ConstantTadHelper::getInstance().tadForDimensions( + input->shapeInfo(), {-2, -1}); + auto outputTads = ConstantTadHelper::getInstance().tadForDimensions( + output->shapeInfo(), {-2, -1}); + auto stream = context->getCudaStream(); + auto inputBuf = reinterpret_cast(input->specialBuffer()); + auto outputBuf = reinterpret_cast(output->specialBuffer()); + auto rows = input->sizeAt(-2); + auto columns = input->sizeAt(-1); + + if (lower) { + lowerAdjointKernel<<<128, 256, 256, *stream>>>( + inputBuf, outputBuf, outputTads.numberOfTads(), rows, columns, + inputTads.specialShapeInfo(), inputTads.specialOffsets(), + outputTads.specialShapeInfo(), outputTads.specialOffsets()); + } else { + upperAdjointKernel<<<128, 256, 256, *stream>>>( + inputBuf, outputBuf, outputTads.numberOfTads(), rows, columns, + inputTads.specialShapeInfo(), inputTads.specialOffsets(), + outputTads.specialShapeInfo(), outputTads.specialOffsets()); + } +} - void adjointMatrix(sd::LaunchContext* context, NDArray const* input, bool const lower, NDArray* output) { - BUILD_SINGLE_SELECTOR(input->dataType(), adjointTriangularMatrix_, (context, input, lower, output), FLOAT_NATIVE); - } +void adjointMatrix(sd::LaunchContext* context, NDArray const* input, + bool const lower, NDArray* output) { + BUILD_SINGLE_SELECTOR(input->dataType(), adjointTriangularMatrix_, + (context, input, lower, output), FLOAT_NATIVE); +} /* ////////////////////////////////////////////////////////////////////////// diff --git a/libnd4j/include/ops/declarable/helpers/cuda/updaterAdaDelta.cu b/libnd4j/include/ops/declarable/helpers/cuda/updaterAdaDelta.cu index c096c4294cd1..dae449321d78 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/updaterAdaDelta.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/updaterAdaDelta.cu @@ -18,112 +18,143 @@ // @author Oleh Semeniv (oleg.semeniv@gmail.com) // -#include -#include #include #include #include +#include +#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { /////////////////////////////////////////////////////////////////// -template -__global__ void adaDeltaUpdaterCuda(const void* vx, const Nd4jLong* xShapeInfo, const void* vinMsg, const Nd4jLong* inMsgShapeInfo, - const void* vinMsdx, const Nd4jLong* inMsdxShapeInfo, void* vz, const Nd4jLong* zShapeInfo, void* vstMsg, - const Nd4jLong* stMsgShapeInfo, void* vstMsdx, const Nd4jLong* stMsdxShapeInfo, const T rho, const T epsilon) { - - const auto grad = reinterpret_cast(vx); - const auto initMsg= reinterpret_cast(vinMsg); - const auto initMsdx = reinterpret_cast(vinMsdx); - - auto up = reinterpret_cast(vz); - auto stMsg = reinterpret_cast(vstMsg); - auto stMsdx = reinterpret_cast(vstMsdx); - - __shared__ Nd4jLong xLen; - __shared__ T rhoT; - __shared__ bool bEWS, bOrdering, bXZsame, bXInMsgSame, bXStMsgSame, bXInMsdxSame, bXStMsdxSame; - - if (threadIdx.x == 0) { - xLen = shape::length(xShapeInfo); - - rhoT = (1 - rho); - - bEWS = 1 == shape::elementWiseStride(xShapeInfo) && 1 == shape::elementWiseStride(zShapeInfo) && - 1 == shape::elementWiseStride(stMsgShapeInfo) && 1 == shape::elementWiseStride(inMsgShapeInfo) && - 1 == shape::elementWiseStride(stMsdxShapeInfo) && 1 == shape::elementWiseStride(inMsdxShapeInfo); - bOrdering = shape::order(xShapeInfo) == shape::order(zShapeInfo) && shape::order(zShapeInfo) == shape::order(stMsgShapeInfo) && - shape::order(stMsgShapeInfo) == shape::order(inMsgShapeInfo) && shape::order(inMsgShapeInfo) == shape::order(stMsdxShapeInfo) && - shape::order(stMsdxShapeInfo) == shape::order(inMsdxShapeInfo); - - bXZsame = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo); - bXInMsgSame = shape::haveSameShapeAndStrides(xShapeInfo, inMsgShapeInfo); - bXStMsgSame = shape::haveSameShapeAndStrides(xShapeInfo, stMsgShapeInfo); - bXInMsdxSame = shape::haveSameShapeAndStrides(xShapeInfo, inMsdxShapeInfo); - bXStMsdxSame = shape::haveSameShapeAndStrides(xShapeInfo, stMsdxShapeInfo); +template +__global__ void adaDeltaUpdaterCuda( + const void* vx, const Nd4jLong* xShapeInfo, const void* vinMsg, + const Nd4jLong* inMsgShapeInfo, const void* vinMsdx, + const Nd4jLong* inMsdxShapeInfo, void* vz, const Nd4jLong* zShapeInfo, + void* vstMsg, const Nd4jLong* stMsgShapeInfo, void* vstMsdx, + const Nd4jLong* stMsdxShapeInfo, const T rho, const T epsilon) { + const auto grad = reinterpret_cast(vx); + const auto initMsg = reinterpret_cast(vinMsg); + const auto initMsdx = reinterpret_cast(vinMsdx); + + auto up = reinterpret_cast(vz); + auto stMsg = reinterpret_cast(vstMsg); + auto stMsdx = reinterpret_cast(vstMsdx); + + __shared__ Nd4jLong xLen; + __shared__ T rhoT; + __shared__ bool bEWS, bOrdering, bXZsame, bXInMsgSame, bXStMsgSame, + bXInMsdxSame, bXStMsdxSame; + + if (threadIdx.x == 0) { + xLen = shape::length(xShapeInfo); + + rhoT = (1 - rho); + + bEWS = 1 == shape::elementWiseStride(xShapeInfo) && + 1 == shape::elementWiseStride(zShapeInfo) && + 1 == shape::elementWiseStride(stMsgShapeInfo) && + 1 == shape::elementWiseStride(inMsgShapeInfo) && + 1 == shape::elementWiseStride(stMsdxShapeInfo) && + 1 == shape::elementWiseStride(inMsdxShapeInfo); + bOrdering = shape::order(xShapeInfo) == shape::order(zShapeInfo) && + shape::order(zShapeInfo) == shape::order(stMsgShapeInfo) && + shape::order(stMsgShapeInfo) == shape::order(inMsgShapeInfo) && + shape::order(inMsgShapeInfo) == shape::order(stMsdxShapeInfo) && + shape::order(stMsdxShapeInfo) == shape::order(inMsdxShapeInfo); + + bXZsame = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo); + bXInMsgSame = shape::haveSameShapeAndStrides(xShapeInfo, inMsgShapeInfo); + bXStMsgSame = shape::haveSameShapeAndStrides(xShapeInfo, stMsgShapeInfo); + bXInMsdxSame = shape::haveSameShapeAndStrides(xShapeInfo, inMsdxShapeInfo); + bXStMsdxSame = shape::haveSameShapeAndStrides(xShapeInfo, stMsdxShapeInfo); + } + __syncthreads(); + + int coords[MAX_RANK]; + + for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < xLen; + i += gridDim.x * blockDim.x) { + auto xOffset = i, zOffset = i, initMsgOffset = i, initMsdxOffset = i, + stMsgOffset = i, stMsdxOffset = i; + + if (!bEWS || !bOrdering) { + shape::index2coords(i, xShapeInfo, coords); + xOffset = shape::getOffset(xShapeInfo, coords); + zOffset = bXZsame ? xOffset : shape::getOffset(zShapeInfo, coords); + initMsgOffset = + bXInMsgSame ? xOffset : shape::getOffset(inMsgShapeInfo, coords); + stMsgOffset = + bXStMsgSame ? xOffset : shape::getOffset(stMsgShapeInfo, coords); + initMsdxOffset = + bXInMsdxSame ? xOffset : shape::getOffset(inMsdxShapeInfo, coords); + stMsdxOffset = + bXStMsdxSame ? xOffset : shape::getOffset(stMsdxShapeInfo, coords); } - __syncthreads(); - - int coords[MAX_RANK]; - for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < xLen; i += gridDim.x * blockDim.x) { + stMsg[stMsgOffset] = + rho * initMsg[initMsgOffset] + grad[xOffset] * grad[xOffset] * rhoT; - auto xOffset = i, zOffset = i, initMsgOffset = i, initMsdxOffset = i, stMsgOffset = i, stMsdxOffset = i; + up[zOffset] = + grad[xOffset] * + (sd::math::nd4j_sqrt(initMsdx[initMsdxOffset] + epsilon) / + sd::math::nd4j_sqrt(stMsg[stMsgOffset] + epsilon)); - if (!bEWS || !bOrdering){ - - shape::index2coords(i, xShapeInfo, coords); - xOffset = shape::getOffset(xShapeInfo, coords); - zOffset = bXZsame ? xOffset : shape::getOffset(zShapeInfo, coords); - initMsgOffset = bXInMsgSame ? xOffset : shape::getOffset(inMsgShapeInfo, coords); - stMsgOffset = bXStMsgSame ? xOffset : shape::getOffset(stMsgShapeInfo, coords); - initMsdxOffset = bXInMsdxSame ? xOffset : shape::getOffset(inMsdxShapeInfo, coords); - stMsdxOffset = bXStMsdxSame ? xOffset : shape::getOffset(stMsdxShapeInfo, coords); - } - - stMsg[stMsgOffset] = rho * initMsg[initMsgOffset] + grad[xOffset] * grad[xOffset] * rhoT; - - up[zOffset] = grad[xOffset] * (sd::math::nd4j_sqrt(initMsdx[initMsdxOffset] + epsilon) / sd::math::nd4j_sqrt(stMsg[stMsgOffset] + epsilon)); - - stMsdx[stMsdxOffset] = rho * initMsdx[initMsdxOffset] + up[zOffset] * up[zOffset] * rhoT; - } + stMsdx[stMsdxOffset] = + rho * initMsdx[initMsdxOffset] + up[zOffset] * up[zOffset] * rhoT; + } } /////////////////////////////////////////////////////////////////// -template -linkage void adaDeltaUpdaterCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo, - const void* vinMsg, const Nd4jLong* inMsgShapeInfo, const void* vinMsdx, const Nd4jLong* inMsdxShapeInfo, - void* vz, const Nd4jLong* zShapeInfo, void* vstMsg, const Nd4jLong* stMsgShapeInfo, - void* vstMsdx, const Nd4jLong* stMsdxShapeInfo, const double dRho, const double dEpsilon) { - - const T rho = static_cast(dRho); - const T epsilon = static_cast(dEpsilon); - - adaDeltaUpdaterCuda<<>>(vx, xShapeInfo, vinMsg, inMsgShapeInfo, - vinMsdx, inMsdxShapeInfo, vz, zShapeInfo, vstMsg, stMsgShapeInfo, vstMsdx, stMsdxShapeInfo, rho, epsilon); +template +linkage void adaDeltaUpdaterCudaLauncher( + const int blocksPerGrid, const int threadsPerBlock, + const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo, + const void* vinMsg, const Nd4jLong* inMsgShapeInfo, const void* vinMsdx, + const Nd4jLong* inMsdxShapeInfo, void* vz, const Nd4jLong* zShapeInfo, + void* vstMsg, const Nd4jLong* stMsgShapeInfo, void* vstMsdx, + const Nd4jLong* stMsdxShapeInfo, const double dRho, const double dEpsilon) { + const T rho = static_cast(dRho); + const T epsilon = static_cast(dEpsilon); + + adaDeltaUpdaterCuda<<>>( + vx, xShapeInfo, vinMsg, inMsgShapeInfo, vinMsdx, inMsdxShapeInfo, vz, + zShapeInfo, vstMsg, stMsgShapeInfo, vstMsdx, stMsdxShapeInfo, rho, + epsilon); } /////////////////////////////////////////////////////////////////// -void updaterAdaDelta(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateMsg, const NDArray& initStateMsdx, - NDArray& update, NDArray& stateMsg, NDArray& stateMsdx, const double dRho, const double dEpsilon) { - - PointersManager manager(context, "adaDeltaUpdater"); - - const int threadsPerBlock = MAX_NUM_THREADS / 4; - const int blocksPerGrid = (gradient.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - - NDArray::prepareSpecialUse({ &update, &stateMsg, &stateMsdx }, { &gradient, &initStateMsg, &initStateMsdx }); - BUILD_SINGLE_SELECTOR(gradient.dataType(), adaDeltaUpdaterCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), gradient.specialBuffer(), gradient.specialShapeInfo(), - initStateMsg.specialBuffer(), initStateMsg.specialShapeInfo(), initStateMsdx.specialBuffer(), initStateMsdx.specialShapeInfo(), - update.specialBuffer(), update.specialShapeInfo(),stateMsg.specialBuffer(), stateMsg.specialShapeInfo(), - stateMsdx.specialBuffer(), stateMsdx.specialShapeInfo(), dRho, dEpsilon), FLOAT_TYPES); - NDArray::registerSpecialUse({ &update, &stateMsg, &stateMsdx }, { &gradient, &initStateMsg, &initStateMsdx }); - - manager.synchronize(); +void updaterAdaDelta(sd::LaunchContext* context, const NDArray& gradient, + const NDArray& initStateMsg, const NDArray& initStateMsdx, + NDArray& update, NDArray& stateMsg, NDArray& stateMsdx, + const double dRho, const double dEpsilon) { + PointersManager manager(context, "adaDeltaUpdater"); + + const int threadsPerBlock = MAX_NUM_THREADS / 4; + const int blocksPerGrid = + (gradient.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + + NDArray::prepareSpecialUse({&update, &stateMsg, &stateMsdx}, + {&gradient, &initStateMsg, &initStateMsdx}); + BUILD_SINGLE_SELECTOR( + gradient.dataType(), adaDeltaUpdaterCudaLauncher, + (blocksPerGrid, threadsPerBlock, context->getCudaStream(), + gradient.specialBuffer(), gradient.specialShapeInfo(), + initStateMsg.specialBuffer(), initStateMsg.specialShapeInfo(), + initStateMsdx.specialBuffer(), initStateMsdx.specialShapeInfo(), + update.specialBuffer(), update.specialShapeInfo(), + stateMsg.specialBuffer(), stateMsg.specialShapeInfo(), + stateMsdx.specialBuffer(), stateMsdx.specialShapeInfo(), dRho, dEpsilon), + FLOAT_TYPES); + NDArray::registerSpecialUse({&update, &stateMsg, &stateMsdx}, + {&gradient, &initStateMsg, &initStateMsdx}); + + manager.synchronize(); } -} -} -} +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/updaterAdaGrad.cu b/libnd4j/include/ops/declarable/helpers/cuda/updaterAdaGrad.cu index 50a43986c433..e40aa5a60804 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/updaterAdaGrad.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/updaterAdaGrad.cu @@ -18,100 +18,109 @@ // @author Oleh Semeniv (oleg.semeniv@gmail.com) // -#include -#include #include #include #include +#include +#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { /////////////////////////////////////////////////////////////////// -template -__global__ void adaGradUpdaterCuda(const void* vx, const Nd4jLong* xShapeInfo, const void* vin, const Nd4jLong* inShapeInfo, - void* vz, const Nd4jLong* zShapeInfo, void* vst, const Nd4jLong* stShapeInfo, +template +__global__ void adaGradUpdaterCuda(const void* vx, const Nd4jLong* xShapeInfo, + const void* vin, const Nd4jLong* inShapeInfo, + void* vz, const Nd4jLong* zShapeInfo, + void* vst, const Nd4jLong* stShapeInfo, const T lr, const T epsilon) { - - const auto x = reinterpret_cast(vx); - const auto init = reinterpret_cast(vin); - - auto up = reinterpret_cast(vz); - auto st = reinterpret_cast(vst); - - __shared__ bool bEWS, bOrdering, bXZsame, bXInSame, bXStSame; - __shared__ Nd4jLong xLen; - - if (threadIdx.x == 0) { - xLen = shape::length(xShapeInfo); - - bEWS = 1 == shape::elementWiseStride(xShapeInfo) && 1 == shape::elementWiseStride(zShapeInfo) && - 1 == shape::elementWiseStride(stShapeInfo) && 1 == shape::elementWiseStride(inShapeInfo); - bOrdering = shape::order(xShapeInfo) == shape::order(zShapeInfo) && shape::order(xShapeInfo) == shape::order(stShapeInfo) && - shape::order(xShapeInfo) == shape::order(inShapeInfo); - - bXZsame = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo); - bXInSame = shape::haveSameShapeAndStrides(xShapeInfo, inShapeInfo); - bXStSame = shape::haveSameShapeAndStrides(xShapeInfo, stShapeInfo); + const auto x = reinterpret_cast(vx); + const auto init = reinterpret_cast(vin); + + auto up = reinterpret_cast(vz); + auto st = reinterpret_cast(vst); + + __shared__ bool bEWS, bOrdering, bXZsame, bXInSame, bXStSame; + __shared__ Nd4jLong xLen; + + if (threadIdx.x == 0) { + xLen = shape::length(xShapeInfo); + + bEWS = 1 == shape::elementWiseStride(xShapeInfo) && + 1 == shape::elementWiseStride(zShapeInfo) && + 1 == shape::elementWiseStride(stShapeInfo) && + 1 == shape::elementWiseStride(inShapeInfo); + bOrdering = shape::order(xShapeInfo) == shape::order(zShapeInfo) && + shape::order(xShapeInfo) == shape::order(stShapeInfo) && + shape::order(xShapeInfo) == shape::order(inShapeInfo); + + bXZsame = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo); + bXInSame = shape::haveSameShapeAndStrides(xShapeInfo, inShapeInfo); + bXStSame = shape::haveSameShapeAndStrides(xShapeInfo, stShapeInfo); + } + __syncthreads(); + + int coords[MAX_RANK]; + + for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < xLen; + i += gridDim.x * blockDim.x) { + auto xOffset = i, zOffset = i, initOffset = i, stOffset = i; + + if (!bEWS || !bOrdering) { + shape::index2coords(i, xShapeInfo, coords); + xOffset = shape::getOffset(xShapeInfo, coords); + zOffset = bXZsame ? xOffset : shape::getOffset(zShapeInfo, coords); + initOffset = bXInSame ? xOffset : shape::getOffset(inShapeInfo, coords); + stOffset = bXStSame ? xOffset : shape::getOffset(stShapeInfo, coords); } - __syncthreads(); - - int coords[MAX_RANK]; - - for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < xLen; i += gridDim.x * blockDim.x) { - - auto xOffset = i, zOffset = i, initOffset = i, stOffset = i; - - if (!bEWS || !bOrdering) { - - shape::index2coords(i, xShapeInfo, coords); - xOffset = shape::getOffset(xShapeInfo, coords); - zOffset = bXZsame ? xOffset : shape::getOffset(zShapeInfo, coords); - initOffset = bXInSame ? xOffset : shape::getOffset(inShapeInfo, coords); - stOffset = bXStSame ? xOffset : shape::getOffset(stShapeInfo, coords); - } - st[stOffset] = init[initOffset] + x[xOffset] * x[xOffset]; - up[zOffset] = (lr * x[xOffset]) / (math::nd4j_sqrt(st[stOffset]) + epsilon); - - } + st[stOffset] = init[initOffset] + x[xOffset] * x[xOffset]; + up[zOffset] = + (lr * x[xOffset]) / (math::nd4j_sqrt(st[stOffset]) + epsilon); + } } /////////////////////////////////////////////////////////////////// -template -linkage void adaGradUpdaterCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t* stream, - const void* vx, const Nd4jLong* xShapeInfo, const void* vin, const Nd4jLong* inShapeInfo, - void* vz, const Nd4jLong* zShapeInfo, void* vst, const Nd4jLong* stShapeInfo, - const double dLr, const double dEpsilon) { - - const T lr = static_cast(dLr); - const T epsilon = static_cast(dEpsilon); - - adaGradUpdaterCuda<<>>(vx, xShapeInfo, vin, inShapeInfo, - vz, zShapeInfo, vst, stShapeInfo, lr, epsilon); +template +linkage void adaGradUpdaterCudaLauncher( + const int blocksPerGrid, const int threadsPerBlock, + const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo, + const void* vin, const Nd4jLong* inShapeInfo, void* vz, + const Nd4jLong* zShapeInfo, void* vst, const Nd4jLong* stShapeInfo, + const double dLr, const double dEpsilon) { + const T lr = static_cast(dLr); + const T epsilon = static_cast(dEpsilon); + + adaGradUpdaterCuda<<>>( + vx, xShapeInfo, vin, inShapeInfo, vz, zShapeInfo, vst, stShapeInfo, lr, + epsilon); } /////////////////////////////////////////////////////////////////// -void updaterAdaGrad(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initState, - NDArray& update, NDArray& stateH, const double dLr, const double dEpsilon) { - - PointersManager manager(context, "adaGradUpdater"); - - const int threadsPerBlock = MAX_NUM_THREADS / 4; - const int blocksPerGrid = (gradient.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - - NDArray::prepareSpecialUse({ &update, &stateH }, { &gradient, &initState }); - BUILD_SINGLE_SELECTOR(gradient.dataType(), adaGradUpdaterCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), - gradient.specialBuffer(), gradient.specialShapeInfo(), - initState.specialBuffer(), initState.specialShapeInfo(), - update.specialBuffer(), update.specialShapeInfo(), - stateH.specialBuffer(), stateH.specialShapeInfo(), dLr, dEpsilon), FLOAT_TYPES); - NDArray::registerSpecialUse({ &update, &stateH }, { &gradient, &initState }); - - manager.synchronize(); +void updaterAdaGrad(sd::LaunchContext* context, const NDArray& gradient, + const NDArray& initState, NDArray& update, NDArray& stateH, + const double dLr, const double dEpsilon) { + PointersManager manager(context, "adaGradUpdater"); + + const int threadsPerBlock = MAX_NUM_THREADS / 4; + const int blocksPerGrid = + (gradient.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + + NDArray::prepareSpecialUse({&update, &stateH}, {&gradient, &initState}); + BUILD_SINGLE_SELECTOR( + gradient.dataType(), adaGradUpdaterCudaLauncher, + (blocksPerGrid, threadsPerBlock, context->getCudaStream(), + gradient.specialBuffer(), gradient.specialShapeInfo(), + initState.specialBuffer(), initState.specialShapeInfo(), + update.specialBuffer(), update.specialShapeInfo(), + stateH.specialBuffer(), stateH.specialShapeInfo(), dLr, dEpsilon), + FLOAT_TYPES); + NDArray::registerSpecialUse({&update, &stateH}, {&gradient, &initState}); + + manager.synchronize(); } -} -} -} +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/updaterAdaMax.cu b/libnd4j/include/ops/declarable/helpers/cuda/updaterAdaMax.cu index 09301d05a770..38a45d0d84c9 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/updaterAdaMax.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/updaterAdaMax.cu @@ -18,125 +18,150 @@ // @author Oleh Semeniv (oleg.semeniv@gmail.com) // -#include -#include #include #include #include +#include +#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { /////////////////////////////////////////////////////////////////// -template -__global__ void adaMaxUpdaterCuda(const void* vx, const Nd4jLong* xShapeInfo, const void* vinv, const Nd4jLong* invShapeInfo, - const void* vinm, const Nd4jLong* inmShapeInfo, void* vz, const Nd4jLong* zShapeInfo, - void* vstV, const Nd4jLong* stvShapeInfo, void* vstM, const Nd4jLong* stmShapeInfo, - const T lr, const T beta1, const T beta2, const T epsilon, const T iteration) { - - const auto grad = reinterpret_cast(vx); - const auto initU = reinterpret_cast(vinv); - const auto initM = reinterpret_cast(vinm); - - auto up = reinterpret_cast(vz); - auto stU = reinterpret_cast(vstV); - auto stM = reinterpret_cast(vstM); - - __shared__ Nd4jLong xLen; - __shared__ T beta1T, epsilonT; - __shared__ bool bEWS, bOrdering, bXZsame, bXInUSame, bXStUSame, bXInMSame, bXStMSame; - - if (threadIdx.x == 0) { - xLen = shape::length(xShapeInfo); - beta1T = sd::math::nd4j_pow(beta1, (iteration + 1) ); - - epsilonT = lr / (1.0 - beta1T); - if (sd::math::nd4j_isnan(epsilonT) || 0 == epsilonT || sd::math::nd4j_isinf(epsilonT)) - epsilonT = epsilon; - - bEWS = 1 == shape::elementWiseStride(xShapeInfo) && 1 == shape::elementWiseStride(zShapeInfo) && - 1 == shape::elementWiseStride(stmShapeInfo) && 1 == shape::elementWiseStride(inmShapeInfo) && - 1 == shape::elementWiseStride(stvShapeInfo) && 1 == shape::elementWiseStride(invShapeInfo); - bOrdering = shape::order(xShapeInfo) == shape::order(zShapeInfo) && shape::order(xShapeInfo) == shape::order(stmShapeInfo) && - shape::order(xShapeInfo) == shape::order(inmShapeInfo) && shape::order(xShapeInfo) == shape::order(invShapeInfo) && - shape::order(xShapeInfo) == shape::order(stvShapeInfo); - - bXZsame = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo); - bXInUSame = shape::haveSameShapeAndStrides(xShapeInfo, invShapeInfo); - bXStUSame = shape::haveSameShapeAndStrides(xShapeInfo, stvShapeInfo); - bXInMSame = shape::haveSameShapeAndStrides(xShapeInfo, inmShapeInfo); - bXStMSame = shape::haveSameShapeAndStrides(xShapeInfo, stmShapeInfo); +template +__global__ void adaMaxUpdaterCuda( + const void* vx, const Nd4jLong* xShapeInfo, const void* vinv, + const Nd4jLong* invShapeInfo, const void* vinm, + const Nd4jLong* inmShapeInfo, void* vz, const Nd4jLong* zShapeInfo, + void* vstV, const Nd4jLong* stvShapeInfo, void* vstM, + const Nd4jLong* stmShapeInfo, const T lr, const T beta1, const T beta2, + const T epsilon, const T iteration) { + const auto grad = reinterpret_cast(vx); + const auto initU = reinterpret_cast(vinv); + const auto initM = reinterpret_cast(vinm); + + auto up = reinterpret_cast(vz); + auto stU = reinterpret_cast(vstV); + auto stM = reinterpret_cast(vstM); + + __shared__ Nd4jLong xLen; + __shared__ T beta1T, epsilonT; + __shared__ bool bEWS, bOrdering, bXZsame, bXInUSame, bXStUSame, bXInMSame, + bXStMSame; + + if (threadIdx.x == 0) { + xLen = shape::length(xShapeInfo); + beta1T = sd::math::nd4j_pow(beta1, (iteration + 1)); + + epsilonT = lr / (1.0 - beta1T); + if (sd::math::nd4j_isnan(epsilonT) || 0 == epsilonT || + sd::math::nd4j_isinf(epsilonT)) + epsilonT = epsilon; + + bEWS = 1 == shape::elementWiseStride(xShapeInfo) && + 1 == shape::elementWiseStride(zShapeInfo) && + 1 == shape::elementWiseStride(stmShapeInfo) && + 1 == shape::elementWiseStride(inmShapeInfo) && + 1 == shape::elementWiseStride(stvShapeInfo) && + 1 == shape::elementWiseStride(invShapeInfo); + bOrdering = shape::order(xShapeInfo) == shape::order(zShapeInfo) && + shape::order(xShapeInfo) == shape::order(stmShapeInfo) && + shape::order(xShapeInfo) == shape::order(inmShapeInfo) && + shape::order(xShapeInfo) == shape::order(invShapeInfo) && + shape::order(xShapeInfo) == shape::order(stvShapeInfo); + + bXZsame = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo); + bXInUSame = shape::haveSameShapeAndStrides(xShapeInfo, invShapeInfo); + bXStUSame = shape::haveSameShapeAndStrides(xShapeInfo, stvShapeInfo); + bXInMSame = shape::haveSameShapeAndStrides(xShapeInfo, inmShapeInfo); + bXStMSame = shape::haveSameShapeAndStrides(xShapeInfo, stmShapeInfo); + } + __syncthreads(); + + int coords[MAX_RANK]; + + for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < xLen; + i += gridDim.x * blockDim.x) { + auto xOffset = i, zOffset = i, initMOffset = i, initUOffset = i, + stMOffset = i, stUOffset = i; + + if (!bEWS || !bOrdering) { + shape::index2coords(i, xShapeInfo, coords); + xOffset = shape::getOffset(xShapeInfo, coords); + zOffset = bXZsame ? xOffset : shape::getOffset(zShapeInfo, coords); + initUOffset = + bXInUSame ? xOffset : shape::getOffset(invShapeInfo, coords); + stUOffset = bXStUSame ? xOffset : shape::getOffset(stvShapeInfo, coords); + initMOffset = + bXInMSame ? xOffset : shape::getOffset(inmShapeInfo, coords); + stMOffset = bXStMSame ? xOffset : shape::getOffset(stmShapeInfo, coords); } - __syncthreads(); - - int coords[MAX_RANK]; - - for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < xLen; i += gridDim.x * blockDim.x) { - - - auto xOffset = i, zOffset = i, initMOffset = i, initUOffset = i, stMOffset = i, stUOffset = i; - - if (!bEWS || !bOrdering) { - shape::index2coords(i, xShapeInfo, coords); - xOffset = shape::getOffset(xShapeInfo, coords); - zOffset = bXZsame ? xOffset : shape::getOffset(zShapeInfo, coords); - initUOffset = bXInUSame ? xOffset : shape::getOffset(invShapeInfo, coords); - stUOffset = bXStUSame ? xOffset : shape::getOffset(stvShapeInfo, coords); - initMOffset = bXInMSame ? xOffset : shape::getOffset(inmShapeInfo, coords); - stMOffset = bXStMSame ? xOffset : shape::getOffset(stmShapeInfo, coords); - } + // m = B_1 * m + (1-B_1)*grad + stM[stMOffset] = beta1 * initM[initMOffset] + grad[xOffset] * (1 - beta1); + // u = max(B_2 * u, |grad|) + stU[stUOffset] = sd::math::nd4j_max((beta2 * initU[initUOffset]), + sd::math::nd4j_abs(grad[xOffset])) + + 1e-32; - //m = B_1 * m + (1-B_1)*grad - stM[stMOffset] = beta1 * initM[initMOffset] + grad[xOffset] * (1 - beta1); - //u = max(B_2 * u, |grad|) - stU[stUOffset] = sd::math::nd4j_max( (beta2* initU[initUOffset]), sd::math::nd4j_abs(grad[xOffset])) + 1e-32; - - up[zOffset] = (stM[stMOffset] * epsilonT) / stU[stUOffset]; - } + up[zOffset] = (stM[stMOffset] * epsilonT) / stU[stUOffset]; + } } /////////////////////////////////////////////////////////////////// -template -linkage void adaMaxUpdaterCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo, - const void* vinv, const Nd4jLong* invShapeInfo, const void* vinm, const Nd4jLong* inmShapeInfo, - void* vz, const Nd4jLong* zShapeInfo, void* vstV, const Nd4jLong* stvShapeInfo, - void* vstM, const Nd4jLong* stmShapeInfo, const double dLr, - const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration) { - - const T lr = static_cast(dLr); - const T beta1 = static_cast(dBeta1); - const T beta2 = static_cast(dBeta2); - const T epsilon = static_cast(dEpsilon); - const T iteration = static_cast(nIteration); - - adaMaxUpdaterCuda<<>>(vx, xShapeInfo, vinv, invShapeInfo, vinm, inmShapeInfo, vz, - zShapeInfo, vstV, stvShapeInfo, vstM, stmShapeInfo, lr, beta1, beta2, epsilon, iteration); +template +linkage void adaMaxUpdaterCudaLauncher( + const int blocksPerGrid, const int threadsPerBlock, + const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo, + const void* vinv, const Nd4jLong* invShapeInfo, const void* vinm, + const Nd4jLong* inmShapeInfo, void* vz, const Nd4jLong* zShapeInfo, + void* vstV, const Nd4jLong* stvShapeInfo, void* vstM, + const Nd4jLong* stmShapeInfo, const double dLr, const double dBeta1, + const double dBeta2, const double dEpsilon, const int nIteration) { + const T lr = static_cast(dLr); + const T beta1 = static_cast(dBeta1); + const T beta2 = static_cast(dBeta2); + const T epsilon = static_cast(dEpsilon); + const T iteration = static_cast(nIteration); + + adaMaxUpdaterCuda<<>>( + vx, xShapeInfo, vinv, invShapeInfo, vinm, inmShapeInfo, vz, zShapeInfo, + vstV, stvShapeInfo, vstM, stmShapeInfo, lr, beta1, beta2, epsilon, + iteration); } /////////////////////////////////////////////////////////////////// -void updaterAdaMax(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateU, const NDArray& initStateM, - NDArray& update, NDArray& stateU, NDArray& stateM, const double dLr, const double dBeta1, - const double dBeta2, const double dEpsilon, const int nIteration) { - - PointersManager manager(context, "adaMaxUpdater"); - - const int threadsPerBlock = MAX_NUM_THREADS / 4; - const int blocksPerGrid = (gradient.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - - NDArray::prepareSpecialUse({ &update, &stateU, &stateM }, { &gradient, &initStateU, &initStateM }); - BUILD_SINGLE_SELECTOR(gradient.dataType(), adaMaxUpdaterCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), - gradient.specialBuffer(), gradient.specialShapeInfo(), initStateU.specialBuffer(), - initStateU.specialShapeInfo(), initStateM.specialBuffer(), initStateM.specialShapeInfo(), - update.specialBuffer(), update.specialShapeInfo(), stateU.specialBuffer(), - stateU.specialShapeInfo(), stateM.specialBuffer(), stateM.specialShapeInfo(), - dLr, dBeta1, dBeta2, dEpsilon, nIteration ), FLOAT_TYPES); - NDArray::registerSpecialUse({ &update, &stateU, &stateM }, { &gradient, &initStateU, &initStateM }); - - manager.synchronize(); +void updaterAdaMax(sd::LaunchContext* context, const NDArray& gradient, + const NDArray& initStateU, const NDArray& initStateM, + NDArray& update, NDArray& stateU, NDArray& stateM, + const double dLr, const double dBeta1, const double dBeta2, + const double dEpsilon, const int nIteration) { + PointersManager manager(context, "adaMaxUpdater"); + + const int threadsPerBlock = MAX_NUM_THREADS / 4; + const int blocksPerGrid = + (gradient.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + + NDArray::prepareSpecialUse({&update, &stateU, &stateM}, + {&gradient, &initStateU, &initStateM}); + BUILD_SINGLE_SELECTOR( + gradient.dataType(), adaMaxUpdaterCudaLauncher, + (blocksPerGrid, threadsPerBlock, context->getCudaStream(), + gradient.specialBuffer(), gradient.specialShapeInfo(), + initStateU.specialBuffer(), initStateU.specialShapeInfo(), + initStateM.specialBuffer(), initStateM.specialShapeInfo(), + update.specialBuffer(), update.specialShapeInfo(), + stateU.specialBuffer(), stateU.specialShapeInfo(), + stateM.specialBuffer(), stateM.specialShapeInfo(), dLr, dBeta1, dBeta2, + dEpsilon, nIteration), + FLOAT_TYPES); + NDArray::registerSpecialUse({&update, &stateU, &stateM}, + {&gradient, &initStateU, &initStateM}); + + manager.synchronize(); } -} -} -} +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/updaterAdam.cu b/libnd4j/include/ops/declarable/helpers/cuda/updaterAdam.cu index 91d79809c707..112cfabe9df8 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/updaterAdam.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/updaterAdam.cu @@ -18,122 +18,152 @@ // @author Oleh Semeniv (oleg.semeniv@gmail.com) // -#include -#include #include #include #include +#include +#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { /////////////////////////////////////////////////////////////////// -template -__global__ void adamUpdaterCuda(const void* vx, const Nd4jLong* xShapeInfo, const void* vinv, const Nd4jLong* invShapeInfo, const void* vinm, - const Nd4jLong* inmShapeInfo, void* vz, const Nd4jLong* zShapeInfo, void* vstV, - const Nd4jLong* stvShapeInfo, void* vstM, const Nd4jLong* stmShapeInfo, - const T lr, const T beta1, const T beta2, const T epsilon, const T iteration) { - - const auto grad = reinterpret_cast(vx); - const auto initU = reinterpret_cast(vinv); - const auto initM = reinterpret_cast(vinm); - - auto up = reinterpret_cast(vz); - auto stU = reinterpret_cast(vstV); - auto stM = reinterpret_cast(vstM); - - __shared__ Nd4jLong xLen; - __shared__ T epsilonT; - __shared__ bool bEWS, bOrdering, bXZsame, bXInUSame, bXStUSame, bXInMSame, bXStMSame; - - if (threadIdx.x == 0) { - xLen = shape::length(xShapeInfo); - - T beta1T = sd::math::nd4j_pow(beta1, (iteration + 1)); - T beta2T = sd::math::nd4j_pow(beta2, (iteration + 1)); - - epsilonT = lr * sd::math::nd4j_sqrt(1. - beta2T) / (1.0 - beta1T); - if (sd::math::nd4j_isnan(epsilonT) || 0 == epsilonT || sd::math::nd4j_isinf(epsilonT)) - epsilonT = epsilon; - - bEWS = 1 == shape::elementWiseStride(xShapeInfo) && 1 == shape::elementWiseStride(zShapeInfo) && - 1 == shape::elementWiseStride(stmShapeInfo) && 1 == shape::elementWiseStride(inmShapeInfo) && - 1 == shape::elementWiseStride(stvShapeInfo) && 1 == shape::elementWiseStride(invShapeInfo); - bOrdering = shape::order(xShapeInfo) == shape::order(zShapeInfo) && shape::order(zShapeInfo) == shape::order(stmShapeInfo) && - shape::order(stmShapeInfo) == shape::order(inmShapeInfo) && shape::order(inmShapeInfo) == shape::order(stvShapeInfo) && - shape::order(stvShapeInfo) == shape::order(invShapeInfo); - - bXZsame = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo); - bXInUSame = shape::haveSameShapeAndStrides(xShapeInfo, invShapeInfo); - bXStUSame = shape::haveSameShapeAndStrides(xShapeInfo, stvShapeInfo); - bXInMSame = shape::haveSameShapeAndStrides(xShapeInfo, inmShapeInfo); - bXStMSame = shape::haveSameShapeAndStrides(xShapeInfo, stmShapeInfo); +template +__global__ void adamUpdaterCuda(const void* vx, const Nd4jLong* xShapeInfo, + const void* vinv, const Nd4jLong* invShapeInfo, + const void* vinm, const Nd4jLong* inmShapeInfo, + void* vz, const Nd4jLong* zShapeInfo, + void* vstV, const Nd4jLong* stvShapeInfo, + void* vstM, const Nd4jLong* stmShapeInfo, + const T lr, const T beta1, const T beta2, + const T epsilon, const T iteration) { + const auto grad = reinterpret_cast(vx); + const auto initU = reinterpret_cast(vinv); + const auto initM = reinterpret_cast(vinm); + + auto up = reinterpret_cast(vz); + auto stU = reinterpret_cast(vstV); + auto stM = reinterpret_cast(vstM); + + __shared__ Nd4jLong xLen; + __shared__ T epsilonT; + __shared__ bool bEWS, bOrdering, bXZsame, bXInUSame, bXStUSame, bXInMSame, + bXStMSame; + + if (threadIdx.x == 0) { + xLen = shape::length(xShapeInfo); + + T beta1T = sd::math::nd4j_pow(beta1, (iteration + 1)); + T beta2T = sd::math::nd4j_pow(beta2, (iteration + 1)); + + epsilonT = lr * sd::math::nd4j_sqrt(1. - beta2T) / (1.0 - beta1T); + if (sd::math::nd4j_isnan(epsilonT) || 0 == epsilonT || + sd::math::nd4j_isinf(epsilonT)) + epsilonT = epsilon; + + bEWS = 1 == shape::elementWiseStride(xShapeInfo) && + 1 == shape::elementWiseStride(zShapeInfo) && + 1 == shape::elementWiseStride(stmShapeInfo) && + 1 == shape::elementWiseStride(inmShapeInfo) && + 1 == shape::elementWiseStride(stvShapeInfo) && + 1 == shape::elementWiseStride(invShapeInfo); + bOrdering = shape::order(xShapeInfo) == shape::order(zShapeInfo) && + shape::order(zShapeInfo) == shape::order(stmShapeInfo) && + shape::order(stmShapeInfo) == shape::order(inmShapeInfo) && + shape::order(inmShapeInfo) == shape::order(stvShapeInfo) && + shape::order(stvShapeInfo) == shape::order(invShapeInfo); + + bXZsame = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo); + bXInUSame = shape::haveSameShapeAndStrides(xShapeInfo, invShapeInfo); + bXStUSame = shape::haveSameShapeAndStrides(xShapeInfo, stvShapeInfo); + bXInMSame = shape::haveSameShapeAndStrides(xShapeInfo, inmShapeInfo); + bXStMSame = shape::haveSameShapeAndStrides(xShapeInfo, stmShapeInfo); + } + __syncthreads(); + + int coords[MAX_RANK]; + + for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < xLen; + i += gridDim.x * blockDim.x) { + auto xOffset = i, zOffset = i, initMOffset = i, initUOffset = i, + stMOffset = i, stUOffset = i; + + if (!bEWS || !bOrdering) { + shape::index2coords(i, xShapeInfo, coords); + xOffset = shape::getOffset(xShapeInfo, coords); + zOffset = bXZsame ? xOffset : shape::getOffset(zShapeInfo, coords); + initUOffset = + bXInUSame ? xOffset : shape::getOffset(invShapeInfo, coords); + stUOffset = bXStUSame ? xOffset : shape::getOffset(stvShapeInfo, coords); + initMOffset = + bXInMSame ? xOffset : shape::getOffset(inmShapeInfo, coords); + stMOffset = bXStMSame ? xOffset : shape::getOffset(stmShapeInfo, coords); } - __syncthreads(); - - int coords[MAX_RANK]; - for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < xLen; i += gridDim.x * blockDim.x) { + stM[stMOffset] = beta1 * initM[initMOffset] + grad[xOffset] * (1 - beta1); + stU[stUOffset] = beta2 * initU[initUOffset] + + grad[xOffset] * grad[xOffset] * (1 - beta2); - auto xOffset = i, zOffset = i, initMOffset = i, initUOffset = i, stMOffset = i, stUOffset = i; - - if (!bEWS || !bOrdering){ - - shape::index2coords(i, xShapeInfo, coords); - xOffset = shape::getOffset(xShapeInfo, coords); - zOffset = bXZsame ? xOffset : shape::getOffset(zShapeInfo, coords); - initUOffset = bXInUSame ? xOffset : shape::getOffset(invShapeInfo, coords); - stUOffset = bXStUSame ? xOffset : shape::getOffset(stvShapeInfo, coords); - initMOffset = bXInMSame ? xOffset : shape::getOffset(inmShapeInfo, coords); - stMOffset = bXStMSame ? xOffset : shape::getOffset(stmShapeInfo, coords); - } - - stM[stMOffset] = beta1 * initM[initMOffset] + grad[xOffset] * (1 - beta1); - stU[stUOffset] = beta2 * initU[initUOffset] + grad[xOffset] * grad[xOffset] * (1 - beta2); - - up[zOffset] = (stM[stMOffset] * epsilonT) / ( sd::math::nd4j_sqrt(stU[stUOffset]) + epsilon); - } + up[zOffset] = (stM[stMOffset] * epsilonT) / + (sd::math::nd4j_sqrt(stU[stUOffset]) + epsilon); + } } /////////////////////////////////////////////////////////////////// -template -linkage void adamUpdaterCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo, - const void* vinv, const Nd4jLong* invShapeInfo, const void* vinm, const Nd4jLong* inmShapeInfo, - void* vz, const Nd4jLong* zShapeInfo, void* vstV, const Nd4jLong* stvShapeInfo, - void* vstM, const Nd4jLong* stmShapeInfo, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration) { - - const T lr = static_cast(dLr); - const T beta1 = static_cast(dBeta1); - const T beta2 = static_cast(dBeta2); - const T epsilon = static_cast(dEpsilon); - const T iteration = static_cast(nIteration); - adamUpdaterCuda<<>>(vx, xShapeInfo, vinv, invShapeInfo, vinm, inmShapeInfo, - vz, zShapeInfo, vstV, stvShapeInfo, vstM, stmShapeInfo, lr, beta1, beta2, epsilon, iteration); +template +linkage void adamUpdaterCudaLauncher( + const int blocksPerGrid, const int threadsPerBlock, + const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo, + const void* vinv, const Nd4jLong* invShapeInfo, const void* vinm, + const Nd4jLong* inmShapeInfo, void* vz, const Nd4jLong* zShapeInfo, + void* vstV, const Nd4jLong* stvShapeInfo, void* vstM, + const Nd4jLong* stmShapeInfo, const double dLr, const double dBeta1, + const double dBeta2, const double dEpsilon, const int nIteration) { + const T lr = static_cast(dLr); + const T beta1 = static_cast(dBeta1); + const T beta2 = static_cast(dBeta2); + const T epsilon = static_cast(dEpsilon); + const T iteration = static_cast(nIteration); + adamUpdaterCuda<<>>( + vx, xShapeInfo, vinv, invShapeInfo, vinm, inmShapeInfo, vz, zShapeInfo, + vstV, stvShapeInfo, vstM, stmShapeInfo, lr, beta1, beta2, epsilon, + iteration); } /////////////////////////////////////////////////////////////////// -void updaterAdam(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateU, const NDArray& initStateM, - NDArray& update, NDArray& stateU, NDArray& stateM, const double dLr, const double dBeta1, const double dBeta2, +void updaterAdam(sd::LaunchContext* context, const NDArray& gradient, + const NDArray& initStateU, const NDArray& initStateM, + NDArray& update, NDArray& stateU, NDArray& stateM, + const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration) { - - PointersManager manager(context, "adamUpdater"); - - const int threadsPerBlock = MAX_NUM_THREADS / 4; - const int blocksPerGrid = (gradient.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - - NDArray::prepareSpecialUse({ &update, &stateU, &stateM }, { &gradient, &initStateU, &initStateM }); - - BUILD_SINGLE_SELECTOR(gradient.dataType(), adamUpdaterCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), gradient.specialBuffer(), gradient.specialShapeInfo(), - initStateU.specialBuffer(), initStateU.specialShapeInfo(), initStateM.specialBuffer(), initStateM.specialShapeInfo(), - update.specialBuffer(), update.specialShapeInfo(), stateU.specialBuffer(), stateU.specialShapeInfo(), - stateM.specialBuffer(), stateM.specialShapeInfo(), dLr, dBeta1, dBeta2, dEpsilon, nIteration), FLOAT_TYPES); - - NDArray::registerSpecialUse({ &update, &stateU, &stateM }, { &gradient, &initStateU, &initStateM }); - - manager.synchronize(); + PointersManager manager(context, "adamUpdater"); + + const int threadsPerBlock = MAX_NUM_THREADS / 4; + const int blocksPerGrid = + (gradient.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + + NDArray::prepareSpecialUse({&update, &stateU, &stateM}, + {&gradient, &initStateU, &initStateM}); + + BUILD_SINGLE_SELECTOR( + gradient.dataType(), adamUpdaterCudaLauncher, + (blocksPerGrid, threadsPerBlock, context->getCudaStream(), + gradient.specialBuffer(), gradient.specialShapeInfo(), + initStateU.specialBuffer(), initStateU.specialShapeInfo(), + initStateM.specialBuffer(), initStateM.specialShapeInfo(), + update.specialBuffer(), update.specialShapeInfo(), + stateU.specialBuffer(), stateU.specialShapeInfo(), + stateM.specialBuffer(), stateM.specialShapeInfo(), dLr, dBeta1, dBeta2, + dEpsilon, nIteration), + FLOAT_TYPES); + + NDArray::registerSpecialUse({&update, &stateU, &stateM}, + {&gradient, &initStateU, &initStateM}); + + manager.synchronize(); } -} -} -} +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/updaterAmsGrad.cu b/libnd4j/include/ops/declarable/helpers/cuda/updaterAmsGrad.cu index ff3bc1e4bed8..17d55d5054bd 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/updaterAmsGrad.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/updaterAmsGrad.cu @@ -18,135 +18,176 @@ // @author Oleh Semeniv (oleg.semeniv@gmail.com) // -#include -#include #include #include #include +#include +#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { /////////////////////////////////////////////////////////////////// -template -__global__ void amsGradUpdaterCuda(const void* vx, const Nd4jLong* xShapeInfo, const void* vinv, const Nd4jLong* invShapeInfo, - const void* vinm, const Nd4jLong* inmShapeInfo, const void* vinh, const Nd4jLong* inhShapeInfo, - void* vz, const Nd4jLong* zShapeInfo, void* vstV, const Nd4jLong* stvShapeInfo, void* vstM, - const Nd4jLong* stmShapeInfo, void* vstH, const Nd4jLong* sthShapeInfo, - const T lr, const T beta1, const T beta2, const T epsilon, const T iteration) { - - const auto grad = reinterpret_cast(vx); - const auto initV = reinterpret_cast(vinv); - const auto initM = reinterpret_cast(vinm); - const auto initH = reinterpret_cast(vinh); - - auto up = reinterpret_cast(vz); - auto stV = reinterpret_cast(vstV); - auto stM = reinterpret_cast(vstM); - auto stH = reinterpret_cast(vstH); - - __shared__ Nd4jLong xLen; - __shared__ T mbeta1, mbeta2, epsilonT; - __shared__ bool bEWS, bOrdering, bXZsame, bXInUSame, bXStUSame, bXInMSame, bXStMSame, bXInHSame, bXStHSame; - - if (threadIdx.x == 0) { - xLen = shape::length(xShapeInfo); - - epsilonT = lr * sd::math::nd4j_sqrt(1.0 - sd::math::nd4j_pow(beta2, (iteration + 1))) / (1.0 - sd::math::nd4j_pow(beta1, (iteration + 1))); - - if (sd::math::nd4j_isnan(epsilonT) || 0 == epsilonT || sd::math::nd4j_isinf(epsilonT)) - epsilonT = epsilon; - - mbeta1 = (1 - beta1); - mbeta2 = (1 - beta2); - - bEWS = 1 == shape::elementWiseStride(xShapeInfo) && 1 == shape::elementWiseStride(zShapeInfo) && - 1 == shape::elementWiseStride(stmShapeInfo) && 1 == shape::elementWiseStride(inmShapeInfo) && - 1 == shape::elementWiseStride(stvShapeInfo) && 1 == shape::elementWiseStride(invShapeInfo) && - 1 == shape::elementWiseStride(sthShapeInfo) && 1 == shape::elementWiseStride(inhShapeInfo); - - bOrdering = shape::order(xShapeInfo) == shape::order(zShapeInfo) && shape::order(zShapeInfo) == shape::order(stmShapeInfo) && - shape::order(stmShapeInfo) == shape::order(inmShapeInfo) && shape::order(inmShapeInfo) == shape::order(stvShapeInfo) && - shape::order(stvShapeInfo) == shape::order(invShapeInfo) && shape::order(invShapeInfo) == shape::order(sthShapeInfo) && - shape::order(sthShapeInfo) == shape::order(inhShapeInfo); - - bXZsame = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo); - bXInUSame = shape::haveSameShapeAndStrides(xShapeInfo, invShapeInfo); - bXStUSame = shape::haveSameShapeAndStrides(xShapeInfo, stvShapeInfo); - bXInMSame = shape::haveSameShapeAndStrides(xShapeInfo, inmShapeInfo); - bXStMSame = shape::haveSameShapeAndStrides(xShapeInfo, stmShapeInfo); - bXInHSame = shape::haveSameShapeAndStrides(xShapeInfo, inhShapeInfo); - bXStHSame = shape::haveSameShapeAndStrides(xShapeInfo, sthShapeInfo); +template +__global__ void amsGradUpdaterCuda( + const void* vx, const Nd4jLong* xShapeInfo, const void* vinv, + const Nd4jLong* invShapeInfo, const void* vinm, + const Nd4jLong* inmShapeInfo, const void* vinh, + const Nd4jLong* inhShapeInfo, void* vz, const Nd4jLong* zShapeInfo, + void* vstV, const Nd4jLong* stvShapeInfo, void* vstM, + const Nd4jLong* stmShapeInfo, void* vstH, const Nd4jLong* sthShapeInfo, + const T lr, const T beta1, const T beta2, const T epsilon, + const T iteration) { + const auto grad = reinterpret_cast(vx); + const auto initV = reinterpret_cast(vinv); + const auto initM = reinterpret_cast(vinm); + const auto initH = reinterpret_cast(vinh); + + auto up = reinterpret_cast(vz); + auto stV = reinterpret_cast(vstV); + auto stM = reinterpret_cast(vstM); + auto stH = reinterpret_cast(vstH); + + __shared__ Nd4jLong xLen; + __shared__ T mbeta1, mbeta2, epsilonT; + __shared__ bool bEWS, bOrdering, bXZsame, bXInUSame, bXStUSame, bXInMSame, + bXStMSame, bXInHSame, bXStHSame; + + if (threadIdx.x == 0) { + xLen = shape::length(xShapeInfo); + + epsilonT = lr * + sd::math::nd4j_sqrt( + 1.0 - sd::math::nd4j_pow(beta2, (iteration + 1))) / + (1.0 - sd::math::nd4j_pow(beta1, (iteration + 1))); + + if (sd::math::nd4j_isnan(epsilonT) || 0 == epsilonT || + sd::math::nd4j_isinf(epsilonT)) + epsilonT = epsilon; + + mbeta1 = (1 - beta1); + mbeta2 = (1 - beta2); + + bEWS = 1 == shape::elementWiseStride(xShapeInfo) && + 1 == shape::elementWiseStride(zShapeInfo) && + 1 == shape::elementWiseStride(stmShapeInfo) && + 1 == shape::elementWiseStride(inmShapeInfo) && + 1 == shape::elementWiseStride(stvShapeInfo) && + 1 == shape::elementWiseStride(invShapeInfo) && + 1 == shape::elementWiseStride(sthShapeInfo) && + 1 == shape::elementWiseStride(inhShapeInfo); + + bOrdering = shape::order(xShapeInfo) == shape::order(zShapeInfo) && + shape::order(zShapeInfo) == shape::order(stmShapeInfo) && + shape::order(stmShapeInfo) == shape::order(inmShapeInfo) && + shape::order(inmShapeInfo) == shape::order(stvShapeInfo) && + shape::order(stvShapeInfo) == shape::order(invShapeInfo) && + shape::order(invShapeInfo) == shape::order(sthShapeInfo) && + shape::order(sthShapeInfo) == shape::order(inhShapeInfo); + + bXZsame = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo); + bXInUSame = shape::haveSameShapeAndStrides(xShapeInfo, invShapeInfo); + bXStUSame = shape::haveSameShapeAndStrides(xShapeInfo, stvShapeInfo); + bXInMSame = shape::haveSameShapeAndStrides(xShapeInfo, inmShapeInfo); + bXStMSame = shape::haveSameShapeAndStrides(xShapeInfo, stmShapeInfo); + bXInHSame = shape::haveSameShapeAndStrides(xShapeInfo, inhShapeInfo); + bXStHSame = shape::haveSameShapeAndStrides(xShapeInfo, sthShapeInfo); + } + __syncthreads(); + + int coords[MAX_RANK]; + + for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < xLen; + i += gridDim.x * blockDim.x) { + auto xOffset = i, zOffset = i, initMOffset = i, initVOffset = i, + initHOffset = i, stMOffset = i, stVOffset = i, stHOffset = i; + + if (!bEWS || !bOrdering) { + shape::index2coords(i, xShapeInfo, coords); + xOffset = shape::getOffset(xShapeInfo, coords); + zOffset = bXZsame ? xOffset : shape::getOffset(zShapeInfo, coords); + initMOffset = + bXInMSame ? xOffset : shape::getOffset(inmShapeInfo, coords); + stMOffset = bXStMSame ? xOffset : shape::getOffset(stmShapeInfo, coords); + initVOffset = + bXInUSame ? xOffset : shape::getOffset(invShapeInfo, coords); + stVOffset = bXStUSame ? xOffset : shape::getOffset(stvShapeInfo, coords); + initHOffset = + bXInHSame ? xOffset : shape::getOffset(inhShapeInfo, coords); + stHOffset = bXStHSame ? xOffset : shape::getOffset(sthShapeInfo, coords); } - __syncthreads(); - - int coords[MAX_RANK]; - - for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < xLen; i += gridDim.x * blockDim.x) { - - auto xOffset = i, zOffset = i, initMOffset = i, initVOffset = i, initHOffset = i, stMOffset = i, stVOffset = i, stHOffset = i; - - if (!bEWS || !bOrdering){ - - shape::index2coords(i, xShapeInfo, coords); - xOffset = shape::getOffset(xShapeInfo, coords); - zOffset = bXZsame ? xOffset : shape::getOffset(zShapeInfo, coords); - initMOffset = bXInMSame ? xOffset : shape::getOffset(inmShapeInfo, coords); - stMOffset = bXStMSame ? xOffset : shape::getOffset(stmShapeInfo, coords); - initVOffset = bXInUSame ? xOffset : shape::getOffset(invShapeInfo, coords); - stVOffset = bXStUSame ? xOffset : shape::getOffset(stvShapeInfo, coords); - initHOffset = bXInHSame ? xOffset : shape::getOffset(inhShapeInfo, coords); - stHOffset = bXStHSame ? xOffset : shape::getOffset(sthShapeInfo, coords); - } - stM[stMOffset] = beta1 * initM[initMOffset] + grad[xOffset] * mbeta1; - stV[stVOffset] = beta2 * initV[initVOffset] + grad[xOffset] * grad[xOffset] * mbeta2; - stH[stHOffset] = sd::math::nd4j_max(initH[initHOffset], stV[stVOffset]); + stM[stMOffset] = beta1 * initM[initMOffset] + grad[xOffset] * mbeta1; + stV[stVOffset] = + beta2 * initV[initVOffset] + grad[xOffset] * grad[xOffset] * mbeta2; + stH[stHOffset] = sd::math::nd4j_max(initH[initHOffset], stV[stVOffset]); - up[zOffset] = epsilonT * stM[stMOffset] / (sd::math::nd4j_sqrt(stH[stHOffset]) + epsilon); - } + up[zOffset] = epsilonT * stM[stMOffset] / + (sd::math::nd4j_sqrt(stH[stHOffset]) + epsilon); + } } /////////////////////////////////////////////////////////////////// -template -linkage void amsGradUpdaterCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo, - const void* vinv, const Nd4jLong* invShapeInfo, const void* vinm, const Nd4jLong* inmShapeInfo, - const void* vinh, const Nd4jLong* inhShapeInfo, void* vz, const Nd4jLong* zShapeInfo, - void* vstV, const Nd4jLong* stvShapeInfo, void* vstM, const Nd4jLong* stmShapeInfo, - void* vstH, const Nd4jLong* sthShapeInfo, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration) { - - const T lr = static_cast(dLr); - const T beta1 = static_cast(dBeta1); - const T beta2 = static_cast(dBeta2); - const T epsilon = static_cast(dEpsilon); - const T iteration = static_cast(nIteration); - - amsGradUpdaterCuda<<>>(vx, xShapeInfo, vinv, invShapeInfo, vinm, inmShapeInfo, - vinh, inhShapeInfo, vz, zShapeInfo, vstV, stvShapeInfo, vstM, stmShapeInfo, vstH, sthShapeInfo, lr, beta1, beta2, epsilon, iteration); +template +linkage void amsGradUpdaterCudaLauncher( + const int blocksPerGrid, const int threadsPerBlock, + const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo, + const void* vinv, const Nd4jLong* invShapeInfo, const void* vinm, + const Nd4jLong* inmShapeInfo, const void* vinh, + const Nd4jLong* inhShapeInfo, void* vz, const Nd4jLong* zShapeInfo, + void* vstV, const Nd4jLong* stvShapeInfo, void* vstM, + const Nd4jLong* stmShapeInfo, void* vstH, const Nd4jLong* sthShapeInfo, + const double dLr, const double dBeta1, const double dBeta2, + const double dEpsilon, const int nIteration) { + const T lr = static_cast(dLr); + const T beta1 = static_cast(dBeta1); + const T beta2 = static_cast(dBeta2); + const T epsilon = static_cast(dEpsilon); + const T iteration = static_cast(nIteration); + + amsGradUpdaterCuda<<>>( + vx, xShapeInfo, vinv, invShapeInfo, vinm, inmShapeInfo, vinh, + inhShapeInfo, vz, zShapeInfo, vstV, stvShapeInfo, vstM, stmShapeInfo, + vstH, sthShapeInfo, lr, beta1, beta2, epsilon, iteration); } /////////////////////////////////////////////////////////////////// -void updaterAmsGrad(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateV, const NDArray& initStateM, const NDArray& initStateH, - NDArray& update, NDArray& stateV, NDArray& stateM, NDArray& stateH, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration) { - - PointersManager manager(context, "amsGradUpdater"); - - const int threadsPerBlock = MAX_NUM_THREADS / 4; - const int blocksPerGrid = (gradient.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - - NDArray::prepareSpecialUse({ &update, &stateV, &stateM, &stateH }, { &gradient, &initStateV, &initStateM, &initStateH }); - BUILD_SINGLE_SELECTOR(gradient.dataType(), amsGradUpdaterCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), gradient.specialBuffer(), gradient.specialShapeInfo(), - initStateV.specialBuffer(), initStateV.specialShapeInfo(), initStateM.specialBuffer(), initStateM.specialShapeInfo(), - initStateH.specialBuffer(), initStateH.specialShapeInfo(), update.specialBuffer(), update.specialShapeInfo(), - stateV.specialBuffer(), stateV.specialShapeInfo(), stateM.specialBuffer(), stateM.specialShapeInfo(), - stateH.specialBuffer(), stateH.specialShapeInfo(), dLr, dBeta1, dBeta2, dEpsilon, nIteration), FLOAT_TYPES); - NDArray::registerSpecialUse({ &update, &stateV, &stateM , &stateH }, { &gradient, &initStateV, &initStateM, &initStateH }); - - manager.synchronize(); +void updaterAmsGrad(sd::LaunchContext* context, const NDArray& gradient, + const NDArray& initStateV, const NDArray& initStateM, + const NDArray& initStateH, NDArray& update, NDArray& stateV, + NDArray& stateM, NDArray& stateH, const double dLr, + const double dBeta1, const double dBeta2, + const double dEpsilon, const int nIteration) { + PointersManager manager(context, "amsGradUpdater"); + + const int threadsPerBlock = MAX_NUM_THREADS / 4; + const int blocksPerGrid = + (gradient.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + + NDArray::prepareSpecialUse( + {&update, &stateV, &stateM, &stateH}, + {&gradient, &initStateV, &initStateM, &initStateH}); + BUILD_SINGLE_SELECTOR( + gradient.dataType(), amsGradUpdaterCudaLauncher, + (blocksPerGrid, threadsPerBlock, context->getCudaStream(), + gradient.specialBuffer(), gradient.specialShapeInfo(), + initStateV.specialBuffer(), initStateV.specialShapeInfo(), + initStateM.specialBuffer(), initStateM.specialShapeInfo(), + initStateH.specialBuffer(), initStateH.specialShapeInfo(), + update.specialBuffer(), update.specialShapeInfo(), + stateV.specialBuffer(), stateV.specialShapeInfo(), + stateM.specialBuffer(), stateM.specialShapeInfo(), + stateH.specialBuffer(), stateH.specialShapeInfo(), dLr, dBeta1, dBeta2, + dEpsilon, nIteration), + FLOAT_TYPES); + NDArray::registerSpecialUse( + {&update, &stateV, &stateM, &stateH}, + {&gradient, &initStateV, &initStateM, &initStateH}); + + manager.synchronize(); } - -} -} -} +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/updaterNadam.cu b/libnd4j/include/ops/declarable/helpers/cuda/updaterNadam.cu index 141ed27dbfbe..32bd086cad99 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/updaterNadam.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/updaterNadam.cu @@ -18,120 +18,150 @@ // @author Oleh Semeniv (oleg.semeniv@gmail.com) // -#include -#include #include #include #include +#include +#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { /////////////////////////////////////////////////////////////////// -template -__global__ void nadamUpdaterCuda(const void* vx, const Nd4jLong* xShapeInfo, const void* vinv, const Nd4jLong* invShapeInfo, - const void* vinm, const Nd4jLong* inmShapeInfo, void* vz, const Nd4jLong* zShapeInfo, - void* vstV, const Nd4jLong* stvShapeInfo, void* vstM, const Nd4jLong* stmShapeInfo, - const T lr, const T beta1, const T beta2, const T epsilon, const T iteration) { - - const auto grad = reinterpret_cast(vx); - const auto initV = reinterpret_cast(vinv); - const auto initM = reinterpret_cast(vinm); - - auto up = reinterpret_cast(vz); - auto stV = reinterpret_cast(vstV); - auto stM = reinterpret_cast(vstM); - - __shared__ Nd4jLong xLen; - __shared__ T mbeta1T, mbeta1, mbeta2; - __shared__ bool bEWS, bOrdering, bXZsame, bXInUSame, bXStUSame, bXInMSame, bXStMSame; - - if (threadIdx.x == 0) { - xLen = shape::length(xShapeInfo); - - mbeta1T = 1.0 - sd::math::nd4j_pow(beta1, (iteration + 1)); - mbeta1 = (1 - beta1); - mbeta2 = (1 - beta2); - - bEWS = 1 == shape::elementWiseStride(xShapeInfo) && 1 == shape::elementWiseStride(zShapeInfo) && - 1 == shape::elementWiseStride(stmShapeInfo) && 1 == shape::elementWiseStride(inmShapeInfo) && - 1 == shape::elementWiseStride(stvShapeInfo) && 1 == shape::elementWiseStride(invShapeInfo); - bOrdering = shape::order(xShapeInfo) == shape::order(zShapeInfo) && shape::order(zShapeInfo) == shape::order(stmShapeInfo) && - shape::order(stmShapeInfo) == shape::order(inmShapeInfo) && shape::order(inmShapeInfo) == shape::order(stvShapeInfo) && - shape::order(stvShapeInfo) == shape::order(invShapeInfo); - - bXZsame = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo); - bXInUSame = shape::haveSameShapeAndStrides(xShapeInfo, invShapeInfo); - bXStUSame = shape::haveSameShapeAndStrides(xShapeInfo, stvShapeInfo); - bXInMSame = shape::haveSameShapeAndStrides(xShapeInfo, inmShapeInfo); - bXStMSame = shape::haveSameShapeAndStrides(xShapeInfo, stmShapeInfo); +template +__global__ void nadamUpdaterCuda(const void* vx, const Nd4jLong* xShapeInfo, + const void* vinv, const Nd4jLong* invShapeInfo, + const void* vinm, const Nd4jLong* inmShapeInfo, + void* vz, const Nd4jLong* zShapeInfo, + void* vstV, const Nd4jLong* stvShapeInfo, + void* vstM, const Nd4jLong* stmShapeInfo, + const T lr, const T beta1, const T beta2, + const T epsilon, const T iteration) { + const auto grad = reinterpret_cast(vx); + const auto initV = reinterpret_cast(vinv); + const auto initM = reinterpret_cast(vinm); + + auto up = reinterpret_cast(vz); + auto stV = reinterpret_cast(vstV); + auto stM = reinterpret_cast(vstM); + + __shared__ Nd4jLong xLen; + __shared__ T mbeta1T, mbeta1, mbeta2; + __shared__ bool bEWS, bOrdering, bXZsame, bXInUSame, bXStUSame, bXInMSame, + bXStMSame; + + if (threadIdx.x == 0) { + xLen = shape::length(xShapeInfo); + + mbeta1T = 1.0 - sd::math::nd4j_pow(beta1, (iteration + 1)); + mbeta1 = (1 - beta1); + mbeta2 = (1 - beta2); + + bEWS = 1 == shape::elementWiseStride(xShapeInfo) && + 1 == shape::elementWiseStride(zShapeInfo) && + 1 == shape::elementWiseStride(stmShapeInfo) && + 1 == shape::elementWiseStride(inmShapeInfo) && + 1 == shape::elementWiseStride(stvShapeInfo) && + 1 == shape::elementWiseStride(invShapeInfo); + bOrdering = shape::order(xShapeInfo) == shape::order(zShapeInfo) && + shape::order(zShapeInfo) == shape::order(stmShapeInfo) && + shape::order(stmShapeInfo) == shape::order(inmShapeInfo) && + shape::order(inmShapeInfo) == shape::order(stvShapeInfo) && + shape::order(stvShapeInfo) == shape::order(invShapeInfo); + + bXZsame = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo); + bXInUSame = shape::haveSameShapeAndStrides(xShapeInfo, invShapeInfo); + bXStUSame = shape::haveSameShapeAndStrides(xShapeInfo, stvShapeInfo); + bXInMSame = shape::haveSameShapeAndStrides(xShapeInfo, inmShapeInfo); + bXStMSame = shape::haveSameShapeAndStrides(xShapeInfo, stmShapeInfo); + } + __syncthreads(); + + int coords[MAX_RANK]; + + for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < xLen; + i += gridDim.x * blockDim.x) { + auto xOffset = i, zOffset = i, initMOffset = i, initUOffset = i, + stMOffset = i, stUOffset = i; + + if (!bEWS || !bOrdering) { + shape::index2coords(i, xShapeInfo, coords); + xOffset = shape::getOffset(xShapeInfo, coords); + zOffset = bXZsame ? xOffset : shape::getOffset(zShapeInfo, coords); + initUOffset = + bXInUSame ? xOffset : shape::getOffset(invShapeInfo, coords); + stUOffset = bXStUSame ? xOffset : shape::getOffset(stvShapeInfo, coords); + initMOffset = + bXInMSame ? xOffset : shape::getOffset(inmShapeInfo, coords); + stMOffset = bXStMSame ? xOffset : shape::getOffset(stmShapeInfo, coords); } - __syncthreads(); - - int coords[MAX_RANK]; - - for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < xLen; i += gridDim.x * blockDim.x) { - - auto xOffset = i, zOffset = i, initMOffset = i, initUOffset = i, stMOffset = i, stUOffset = i; - - if (!bEWS || !bOrdering){ - - shape::index2coords(i, xShapeInfo, coords); - xOffset = shape::getOffset(xShapeInfo, coords); - zOffset = bXZsame ? xOffset : shape::getOffset(zShapeInfo, coords); - initUOffset = bXInUSame ? xOffset : shape::getOffset(invShapeInfo, coords); - stUOffset = bXStUSame ? xOffset : shape::getOffset(stvShapeInfo, coords); - initMOffset = bXInMSame ? xOffset : shape::getOffset(inmShapeInfo, coords); - stMOffset = bXStMSame ? xOffset : shape::getOffset(stmShapeInfo, coords); - } - auto oneMinusBeta1Grad = grad[xOffset] * mbeta1; + auto oneMinusBeta1Grad = grad[xOffset] * mbeta1; - stM[stMOffset] = beta1 * initM[initMOffset] + oneMinusBeta1Grad; - stV[stUOffset] = beta2 * initV[initUOffset] + grad[xOffset] * grad[xOffset] * mbeta2; + stM[stMOffset] = beta1 * initM[initMOffset] + oneMinusBeta1Grad; + stV[stUOffset] = + beta2 * initV[initUOffset] + grad[xOffset] * grad[xOffset] * mbeta2; - up[zOffset] = (lr * ((stM[stMOffset] * beta1 + oneMinusBeta1Grad) / mbeta1T)) / (sd::math::nd4j_sqrt(stV[stUOffset]) + epsilon); - } + up[zOffset] = + (lr * ((stM[stMOffset] * beta1 + oneMinusBeta1Grad) / mbeta1T)) / + (sd::math::nd4j_sqrt(stV[stUOffset]) + epsilon); + } } /////////////////////////////////////////////////////////////////// -template -linkage void nadamUpdaterCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo, - const void* vinv, const Nd4jLong* invShapeInfo, const void* vinm, const Nd4jLong* inmShapeInfo, - void* vz, const Nd4jLong* zShapeInfo, void* vstV, const Nd4jLong* stvShapeInfo, void* vstM, - const Nd4jLong* stmShapeInfo, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration) { - - const T lr = static_cast(dLr); - const T beta1 = static_cast(dBeta1); - const T beta2 = static_cast(dBeta2); - const T epsilon = static_cast(dEpsilon); - const T iteration = static_cast(nIteration); - - nadamUpdaterCuda<<>>(vx, xShapeInfo, vinv, invShapeInfo, vinm, inmShapeInfo, - vz, zShapeInfo, vstV, stvShapeInfo, vstM, stmShapeInfo, lr, beta1, beta2, epsilon, iteration); +template +linkage void nadamUpdaterCudaLauncher( + const int blocksPerGrid, const int threadsPerBlock, + const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo, + const void* vinv, const Nd4jLong* invShapeInfo, const void* vinm, + const Nd4jLong* inmShapeInfo, void* vz, const Nd4jLong* zShapeInfo, + void* vstV, const Nd4jLong* stvShapeInfo, void* vstM, + const Nd4jLong* stmShapeInfo, const double dLr, const double dBeta1, + const double dBeta2, const double dEpsilon, const int nIteration) { + const T lr = static_cast(dLr); + const T beta1 = static_cast(dBeta1); + const T beta2 = static_cast(dBeta2); + const T epsilon = static_cast(dEpsilon); + const T iteration = static_cast(nIteration); + + nadamUpdaterCuda<<>>( + vx, xShapeInfo, vinv, invShapeInfo, vinm, inmShapeInfo, vz, zShapeInfo, + vstV, stvShapeInfo, vstM, stmShapeInfo, lr, beta1, beta2, epsilon, + iteration); } /////////////////////////////////////////////////////////////////// -void updaterNadam(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateV, const NDArray& initStateM, - NDArray& update, NDArray& stateV, NDArray& stateM, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration) { - - PointersManager manager(context, "nadamUpdater"); - - const int threadsPerBlock = MAX_NUM_THREADS / 4; - const int blocksPerGrid = (gradient.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - - NDArray::prepareSpecialUse({ &update, &stateV, &stateM }, { &gradient, &initStateV, &initStateM }); - BUILD_SINGLE_SELECTOR(gradient.dataType(), nadamUpdaterCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), gradient.specialBuffer(), gradient.specialShapeInfo(), - initStateV.specialBuffer(), initStateV.specialShapeInfo(), initStateM.specialBuffer(), initStateM.specialShapeInfo(), - update.specialBuffer(), update.specialShapeInfo(), stateV.specialBuffer(), stateV.specialShapeInfo(), - stateM.specialBuffer(), stateM.specialShapeInfo(), dLr, dBeta1, dBeta2, dEpsilon, nIteration), FLOAT_TYPES); - NDArray::registerSpecialUse({ &update, &stateV, &stateM }, { &gradient, &initStateV, &initStateM }); - - manager.synchronize(); +void updaterNadam(sd::LaunchContext* context, const NDArray& gradient, + const NDArray& initStateV, const NDArray& initStateM, + NDArray& update, NDArray& stateV, NDArray& stateM, + const double dLr, const double dBeta1, const double dBeta2, + const double dEpsilon, const int nIteration) { + PointersManager manager(context, "nadamUpdater"); + + const int threadsPerBlock = MAX_NUM_THREADS / 4; + const int blocksPerGrid = + (gradient.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + + NDArray::prepareSpecialUse({&update, &stateV, &stateM}, + {&gradient, &initStateV, &initStateM}); + BUILD_SINGLE_SELECTOR( + gradient.dataType(), nadamUpdaterCudaLauncher, + (blocksPerGrid, threadsPerBlock, context->getCudaStream(), + gradient.specialBuffer(), gradient.specialShapeInfo(), + initStateV.specialBuffer(), initStateV.specialShapeInfo(), + initStateM.specialBuffer(), initStateM.specialShapeInfo(), + update.specialBuffer(), update.specialShapeInfo(), + stateV.specialBuffer(), stateV.specialShapeInfo(), + stateM.specialBuffer(), stateM.specialShapeInfo(), dLr, dBeta1, dBeta2, + dEpsilon, nIteration), + FLOAT_TYPES); + NDArray::registerSpecialUse({&update, &stateV, &stateM}, + {&gradient, &initStateV, &initStateM}); + + manager.synchronize(); } - -} -} -} +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/updaterNesterovs.cu b/libnd4j/include/ops/declarable/helpers/cuda/updaterNesterovs.cu index 75e1f593884b..b09139ed8453 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/updaterNesterovs.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/updaterNesterovs.cu @@ -18,100 +18,111 @@ // @author Oleh Semeniv (oleg.semeniv@gmail.com) // -#include -#include #include #include #include +#include +#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { - /////////////////////////////////////////////////////////////////// -template -__global__ void nesterovsUpdaterCuda(const void* vx, const Nd4jLong* xShapeInfo, const void* vin, const Nd4jLong* inShapeInfo, - void* vz, const Nd4jLong* zShapeInfo, void* vst, const Nd4jLong* stShapeInfo, const T lr, const T momentum) { - - const auto grad = reinterpret_cast(vx); - const auto init = reinterpret_cast(vin); - auto up = reinterpret_cast(vz); - auto st = reinterpret_cast(vst); - - __shared__ Nd4jLong xLen; - __shared__ T momentumT; - __shared__ bool bEWS, bOrdering, bXZsame, bXInSame, bXStSame; - - if (threadIdx.x == 0) { - xLen = shape::length(xShapeInfo); - momentumT = (-momentum - 1); - - bEWS = 1 == shape::elementWiseStride(xShapeInfo) && 1 == shape::elementWiseStride(zShapeInfo) && - 1 == shape::elementWiseStride(stShapeInfo) && 1 == shape::elementWiseStride(inShapeInfo); - bOrdering = shape::order(xShapeInfo) == shape::order(zShapeInfo) && shape::order(xShapeInfo) == shape::order(inShapeInfo) && - shape::order(xShapeInfo) == shape::order(stShapeInfo); - - bXZsame = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo); - bXInSame = shape::haveSameShapeAndStrides(xShapeInfo, inShapeInfo); - bXStSame = shape::haveSameShapeAndStrides(xShapeInfo, stShapeInfo); +template +__global__ void nesterovsUpdaterCuda(const void* vx, const Nd4jLong* xShapeInfo, + const void* vin, + const Nd4jLong* inShapeInfo, void* vz, + const Nd4jLong* zShapeInfo, void* vst, + const Nd4jLong* stShapeInfo, const T lr, + const T momentum) { + const auto grad = reinterpret_cast(vx); + const auto init = reinterpret_cast(vin); + auto up = reinterpret_cast(vz); + auto st = reinterpret_cast(vst); + + __shared__ Nd4jLong xLen; + __shared__ T momentumT; + __shared__ bool bEWS, bOrdering, bXZsame, bXInSame, bXStSame; + + if (threadIdx.x == 0) { + xLen = shape::length(xShapeInfo); + momentumT = (-momentum - 1); + + bEWS = 1 == shape::elementWiseStride(xShapeInfo) && + 1 == shape::elementWiseStride(zShapeInfo) && + 1 == shape::elementWiseStride(stShapeInfo) && + 1 == shape::elementWiseStride(inShapeInfo); + bOrdering = shape::order(xShapeInfo) == shape::order(zShapeInfo) && + shape::order(xShapeInfo) == shape::order(inShapeInfo) && + shape::order(xShapeInfo) == shape::order(stShapeInfo); + + bXZsame = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo); + bXInSame = shape::haveSameShapeAndStrides(xShapeInfo, inShapeInfo); + bXStSame = shape::haveSameShapeAndStrides(xShapeInfo, stShapeInfo); + } + __syncthreads(); + + int coords[MAX_RANK]; + + for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < xLen; + i += gridDim.x * blockDim.x) { + auto xOffset = i, zOffset = i, initOffset = i, stOffset = i; + + if (!bEWS || !bOrdering) { + shape::index2coords(i, xShapeInfo, coords); + xOffset = shape::getOffset(xShapeInfo, coords); + zOffset = bXZsame ? xOffset : shape::getOffset(zShapeInfo, coords); + initOffset = bXInSame ? xOffset : shape::getOffset(inShapeInfo, coords); + stOffset = bXStSame ? xOffset : shape::getOffset(stShapeInfo, coords); } - __syncthreads(); - - int coords[MAX_RANK]; - - for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < xLen; i += gridDim.x * blockDim.x) { - - auto xOffset = i, zOffset = i, initOffset = i, stOffset = i; - - if (!bEWS || !bOrdering) { - - shape::index2coords(i, xShapeInfo, coords); - xOffset = shape::getOffset(xShapeInfo, coords); - zOffset = bXZsame ? xOffset : shape::getOffset(zShapeInfo, coords); - initOffset = bXInSame ? xOffset : shape::getOffset(inShapeInfo, coords); - stOffset = bXStSame ? xOffset : shape::getOffset(stShapeInfo, coords); - } - T prevState = momentum * init[initOffset]; - st[stOffset] = prevState - lr * grad[xOffset]; - up[zOffset] = prevState + momentumT * st[stOffset]; - } + T prevState = momentum * init[initOffset]; + st[stOffset] = prevState - lr * grad[xOffset]; + up[zOffset] = prevState + momentumT * st[stOffset]; + } } /////////////////////////////////////////////////////////////////// -template -linkage void nesterovsUpdaterCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t* stream, - const void* vx, const Nd4jLong* xShapeInfo, const void* vin, const Nd4jLong* inShapeInfo, - void* vz, const Nd4jLong* zShapeInfo, void* vst, const Nd4jLong* stShapeInfo, - const double dLr, const double dMomentum) { - - const T lr = static_cast(dLr); - const T momentum = static_cast(dMomentum); - nesterovsUpdaterCuda<<>>(vx, xShapeInfo, vin, inShapeInfo, - vz, zShapeInfo, vst, stShapeInfo, lr, momentum); +template +linkage void nesterovsUpdaterCudaLauncher( + const int blocksPerGrid, const int threadsPerBlock, + const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo, + const void* vin, const Nd4jLong* inShapeInfo, void* vz, + const Nd4jLong* zShapeInfo, void* vst, const Nd4jLong* stShapeInfo, + const double dLr, const double dMomentum) { + const T lr = static_cast(dLr); + const T momentum = static_cast(dMomentum); + nesterovsUpdaterCuda<<>>( + vx, xShapeInfo, vin, inShapeInfo, vz, zShapeInfo, vst, stShapeInfo, lr, + momentum); } /////////////////////////////////////////////////////////////////// -void updaterNesterovs(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initState, - NDArray& update, NDArray& stateV, const double dLr, const double dMomentum) { - - PointersManager manager(context, "nesterovsUpdater"); - - const int threadsPerBlock = MAX_NUM_THREADS / 4; - const int blocksPerGrid = (gradient.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - - NDArray::prepareSpecialUse({ &update, &stateV }, { &gradient, &initState }); - BUILD_SINGLE_SELECTOR(gradient.dataType(), nesterovsUpdaterCudaLauncher, (blocksPerGrid, threadsPerBlock, - context->getCudaStream(), gradient.specialBuffer(), gradient.specialShapeInfo(), - initState.specialBuffer(), initState.specialShapeInfo(), - update.specialBuffer(), update.specialShapeInfo(), - stateV.specialBuffer(), stateV.specialShapeInfo(), dLr, dMomentum), FLOAT_TYPES); - NDArray::registerSpecialUse({ &update, &stateV }, { &gradient, &initState }); - - manager.synchronize(); +void updaterNesterovs(sd::LaunchContext* context, const NDArray& gradient, + const NDArray& initState, NDArray& update, + NDArray& stateV, const double dLr, + const double dMomentum) { + PointersManager manager(context, "nesterovsUpdater"); + + const int threadsPerBlock = MAX_NUM_THREADS / 4; + const int blocksPerGrid = + (gradient.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + + NDArray::prepareSpecialUse({&update, &stateV}, {&gradient, &initState}); + BUILD_SINGLE_SELECTOR( + gradient.dataType(), nesterovsUpdaterCudaLauncher, + (blocksPerGrid, threadsPerBlock, context->getCudaStream(), + gradient.specialBuffer(), gradient.specialShapeInfo(), + initState.specialBuffer(), initState.specialShapeInfo(), + update.specialBuffer(), update.specialShapeInfo(), + stateV.specialBuffer(), stateV.specialShapeInfo(), dLr, dMomentum), + FLOAT_TYPES); + NDArray::registerSpecialUse({&update, &stateV}, {&gradient, &initState}); + + manager.synchronize(); } -} -} -} +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/updaterRmsProp.cu b/libnd4j/include/ops/declarable/helpers/cuda/updaterRmsProp.cu index 26f7253d2dea..f72092b8703c 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/updaterRmsProp.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/updaterRmsProp.cu @@ -18,104 +18,115 @@ // @author Oleh Semeniv (oleg.semeniv@gmail.com) // -#include -#include #include #include #include +#include +#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { /////////////////////////////////////////////////////////////////// -template -__global__ void rmsPropUpdaterCuda(const void *vx, const Nd4jLong *xShapeInfo, const void *vin, const Nd4jLong *inShapeInfo, - void *vz, const Nd4jLong *zShapeInfo, void* vst, const Nd4jLong* stShapeInfo, - const T lr, const T rmsDecay, const T epsilon) { - - const auto x = reinterpret_cast(vx); - const auto init = reinterpret_cast(vin); - - auto up = reinterpret_cast(vz); - auto st = reinterpret_cast(vst); - - __shared__ Nd4jLong xLen; - __shared__ bool bEWS, bOrdering, bXZsame, bXInSame, bXStSame; - - if (threadIdx.x == 0) { - - xLen = shape::length(xShapeInfo); - - bEWS = 1 == shape::elementWiseStride(xShapeInfo) && 1 == shape::elementWiseStride(zShapeInfo) && - 1 == shape::elementWiseStride(stShapeInfo) && 1 == shape::elementWiseStride(inShapeInfo); - - bOrdering = shape::order(zShapeInfo) == shape::order(xShapeInfo) && shape::order(xShapeInfo) == shape::order(stShapeInfo) && - shape::order(xShapeInfo) == shape::order(inShapeInfo); - bXZsame = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo); - bXInSame = shape::haveSameShapeAndStrides(xShapeInfo, inShapeInfo); - bXStSame = shape::haveSameShapeAndStrides(xShapeInfo, stShapeInfo); +template +__global__ void rmsPropUpdaterCuda(const void *vx, const Nd4jLong *xShapeInfo, + const void *vin, const Nd4jLong *inShapeInfo, + void *vz, const Nd4jLong *zShapeInfo, + void *vst, const Nd4jLong *stShapeInfo, + const T lr, const T rmsDecay, + const T epsilon) { + const auto x = reinterpret_cast(vx); + const auto init = reinterpret_cast(vin); + + auto up = reinterpret_cast(vz); + auto st = reinterpret_cast(vst); + + __shared__ Nd4jLong xLen; + __shared__ bool bEWS, bOrdering, bXZsame, bXInSame, bXStSame; + + if (threadIdx.x == 0) { + xLen = shape::length(xShapeInfo); + + bEWS = 1 == shape::elementWiseStride(xShapeInfo) && + 1 == shape::elementWiseStride(zShapeInfo) && + 1 == shape::elementWiseStride(stShapeInfo) && + 1 == shape::elementWiseStride(inShapeInfo); + + bOrdering = shape::order(zShapeInfo) == shape::order(xShapeInfo) && + shape::order(xShapeInfo) == shape::order(stShapeInfo) && + shape::order(xShapeInfo) == shape::order(inShapeInfo); + bXZsame = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo); + bXInSame = shape::haveSameShapeAndStrides(xShapeInfo, inShapeInfo); + bXStSame = shape::haveSameShapeAndStrides(xShapeInfo, stShapeInfo); + } + __syncthreads(); + + int coords[MAX_RANK]; + + for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < xLen; + i += gridDim.x * blockDim.x) { + auto xOffset = i, zOffset = i, initOffset = i, stOffset = i; + + if (!bEWS || !bOrdering) { + shape::index2coords(i, xShapeInfo, coords); + xOffset = shape::getOffset(xShapeInfo, coords); + zOffset = bXZsame ? xOffset : shape::getOffset(zShapeInfo, coords); + initOffset = bXInSame ? xOffset : shape::getOffset(inShapeInfo, coords); + stOffset = bXStSame ? xOffset : shape::getOffset(stShapeInfo, coords); } - __syncthreads(); - - int coords[MAX_RANK]; - - for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < xLen; i += gridDim.x * blockDim.x) { - - auto xOffset = i, zOffset = i, initOffset = i, stOffset = i; - if (!bEWS || !bOrdering) { - - shape::index2coords(i, xShapeInfo, coords); - xOffset = shape::getOffset(xShapeInfo, coords); - zOffset = bXZsame ? xOffset : shape::getOffset(zShapeInfo, coords); - initOffset = bXInSame ? xOffset : shape::getOffset(inShapeInfo, coords); - stOffset = bXStSame ? xOffset : shape::getOffset(stShapeInfo, coords); - } - - st[stOffset] = init[initOffset] * rmsDecay + x[xOffset] * x[xOffset] * (1 - rmsDecay) ; - up[zOffset] = (lr * x[xOffset]) / ( math::nd4j_sqrt(st[stOffset]) + epsilon); - } + st[stOffset] = + init[initOffset] * rmsDecay + x[xOffset] * x[xOffset] * (1 - rmsDecay); + up[zOffset] = + (lr * x[xOffset]) / (math::nd4j_sqrt(st[stOffset]) + epsilon); + } } /////////////////////////////////////////////////////////////////// -template -linkage void rmsPropUpdaterCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, - const void *vx, const Nd4jLong *xShapeInfo, const void *vin, const Nd4jLong *inShapeInfo, - void *vz, const Nd4jLong *zShapeInfo, void* vst, const Nd4jLong* stShapeInfo, - const double dLr, const double dRmsDecay, const double dEpsilon) { - - const T lr = static_cast(dLr); - const T rmsDecay = static_cast(dRmsDecay); - const T epsilon = static_cast(dEpsilon); - - rmsPropUpdaterCuda<<>>(vx, xShapeInfo, vin, inShapeInfo, - vz, zShapeInfo, vst, stShapeInfo, lr, rmsDecay, epsilon); +template +linkage void rmsPropUpdaterCudaLauncher( + const int blocksPerGrid, const int threadsPerBlock, + const cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, + const void *vin, const Nd4jLong *inShapeInfo, void *vz, + const Nd4jLong *zShapeInfo, void *vst, const Nd4jLong *stShapeInfo, + const double dLr, const double dRmsDecay, const double dEpsilon) { + const T lr = static_cast(dLr); + const T rmsDecay = static_cast(dRmsDecay); + const T epsilon = static_cast(dEpsilon); + + rmsPropUpdaterCuda<<>>( + vx, xShapeInfo, vin, inShapeInfo, vz, zShapeInfo, vst, stShapeInfo, lr, + rmsDecay, epsilon); } /////////////////////////////////////////////////////////////////// -void updaterRmsProp(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initState, NDArray& update, NDArray& stateG, - const double dLr, const double dRmsDecay, const double dEpsilon) { - - PointersManager manager(context, "rmsPropUpdater"); - - const int threadsPerBlock = MAX_NUM_THREADS / 4; - const int blocksPerGrid = (gradient.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - - NDArray::prepareSpecialUse({&update, &stateG}, {&gradient, &initState }); - - BUILD_SINGLE_SELECTOR(gradient.dataType(), rmsPropUpdaterCudaLauncher, (blocksPerGrid, threadsPerBlock, - context->getCudaStream(), gradient.specialBuffer(), gradient.specialShapeInfo(), - initState.specialBuffer(), initState.specialShapeInfo(), - update.specialBuffer(), update.specialShapeInfo(), - stateG.specialBuffer(), stateG.specialShapeInfo(), - dLr, dRmsDecay, dEpsilon ), FLOAT_TYPES); - - NDArray::registerSpecialUse({&update, &stateG}, {&gradient, &initState}); - - manager.synchronize(); +void updaterRmsProp(sd::LaunchContext *context, const NDArray &gradient, + const NDArray &initState, NDArray &update, NDArray &stateG, + const double dLr, const double dRmsDecay, + const double dEpsilon) { + PointersManager manager(context, "rmsPropUpdater"); + + const int threadsPerBlock = MAX_NUM_THREADS / 4; + const int blocksPerGrid = + (gradient.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + + NDArray::prepareSpecialUse({&update, &stateG}, {&gradient, &initState}); + + BUILD_SINGLE_SELECTOR(gradient.dataType(), rmsPropUpdaterCudaLauncher, + (blocksPerGrid, threadsPerBlock, + context->getCudaStream(), gradient.specialBuffer(), + gradient.specialShapeInfo(), initState.specialBuffer(), + initState.specialShapeInfo(), update.specialBuffer(), + update.specialShapeInfo(), stateG.specialBuffer(), + stateG.specialShapeInfo(), dLr, dRmsDecay, dEpsilon), + FLOAT_TYPES); + + NDArray::registerSpecialUse({&update, &stateG}, {&gradient, &initState}); + + manager.synchronize(); } -} -} -} +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/weights.cu b/libnd4j/include/ops/declarable/helpers/cuda/weights.cu index 1620820a5896..9b5b01df8fef 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/weights.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/weights.cu @@ -24,96 +24,117 @@ namespace sd { namespace ops { namespace helpers { - - - template - static __device__ void adjustWeightsKernelD(void* inputBuffer, Nd4jLong const* inputShape, - void* weightsBuffer, Nd4jLong const* weightsShape, - void* outputBuffer, Nd4jLong inputLength, - Nd4jLong outputLength, int val) { - // typedef Nd4jLong T; - auto tid = threadIdx.x; - //int threadCount = gridDim.x * blockDim.x; - __shared__ T* outputPart; - __shared__ Nd4jLong offset; - //for (int e = 0; e < inputLength; e++) { - for (Nd4jLong e = tid; e < inputLength; e += blockDim.x) { - - Nd4jLong xOffset = shape::getIndexOffset(e, inputShape); - int current = *(reinterpret_cast(inputBuffer) + xOffset); - if (current == val) { - //printf("%lld\n", xOffset); - //Nd4jLong zOffset = shape::getIndexOffset(val, outputShape); - if (weightsBuffer != nullptr) { - Nd4jLong yOffset = shape::getIndexOffset(e, weightsShape); - //atomicAdd(); - //*reinterpret_cast(outputBuffer) += reinterpret_cast(weightsBuffer)[yOffset]; - sd::math::atomics::nd4j_atomicAdd(reinterpret_cast(outputBuffer), reinterpret_cast(weightsBuffer)[yOffset]); //output->p(val, output->e(val) + 1); -// atomicAdd(reinterpret_cast(outputBuffer), reinterpret_cast(weightsBuffer)[yOffset]); //output->p(val, output->e(val) + 1); - } - else { - //*reinterpret_cast(outputBuffer) += int(1); - //printf("outputBuffer[0] = %d\n", static_cast(*(reinterpret_cast(outputBuffer)))); - sd::math::atomics::nd4j_atomicAdd(reinterpret_cast(outputBuffer), T(1)); //output->p(val, output->e(val) + 1); -// atomicAdd(reinterpret_cast(outputBuffer), int(1)); //output->p(val, output->e(val) + 1); - // printf("outputBuffer[%ld] = %d\n", zOffset, static_cast(*(reinterpret_cast(outputBuffer) + zOffset))); - } - //printf("xOffset is %ld, zOffset is %ld\n", xOffset, zOffset); - } - } -// if (threadIdx.x + offset < outputLength) -// reinterpret_cast(outputBuffer)[threadIdx.x + offset] = outputPart[threadIdx.x]; +template +static __device__ void adjustWeightsKernelD( + void* inputBuffer, Nd4jLong const* inputShape, void* weightsBuffer, + Nd4jLong const* weightsShape, void* outputBuffer, Nd4jLong inputLength, + Nd4jLong outputLength, int val) { + // typedef Nd4jLong T; + auto tid = threadIdx.x; + // int threadCount = gridDim.x * blockDim.x; + __shared__ T* outputPart; + __shared__ Nd4jLong offset; + // for (int e = 0; e < inputLength; e++) { + for (Nd4jLong e = tid; e < inputLength; e += blockDim.x) { + Nd4jLong xOffset = shape::getIndexOffset(e, inputShape); + int current = *(reinterpret_cast(inputBuffer) + xOffset); + if (current == val) { + // printf("%lld\n", xOffset); + // Nd4jLong zOffset = shape::getIndexOffset(val, outputShape); + if (weightsBuffer != nullptr) { + Nd4jLong yOffset = shape::getIndexOffset(e, weightsShape); + // atomicAdd(); + //*reinterpret_cast(outputBuffer) += reinterpret_cast(weightsBuffer)[yOffset]; + sd::math::atomics::nd4j_atomicAdd( + reinterpret_cast(outputBuffer), + reinterpret_cast( + weightsBuffer)[yOffset]); // output->p(val, output->e(val) + + // 1); + // atomicAdd(reinterpret_cast(outputBuffer), + // reinterpret_cast(weightsBuffer)[yOffset]); + // //output->p(val, output->e(val) + 1); + } else { + //*reinterpret_cast(outputBuffer) += int(1); + // printf("outputBuffer[0] = %d\n", + // static_cast(*(reinterpret_cast(outputBuffer)))); + sd::math::atomics::nd4j_atomicAdd( + reinterpret_cast(outputBuffer), + T(1)); // output->p(val, output->e(val) + 1); + // atomicAdd(reinterpret_cast(outputBuffer), + // int(1)); //output->p(val, output->e(val) + 1); + // printf("outputBuffer[%ld] = %d\n", zOffset, + // static_cast(*(reinterpret_cast(outputBuffer) + + // zOffset))); + } + // printf("xOffset is %ld, zOffset is %ld\n", xOffset, zOffset); } + } + // if (threadIdx.x + offset < outputLength) + // reinterpret_cast(outputBuffer)[threadIdx.x + offset] = + // outputPart[threadIdx.x]; +} - template - static __global__ void adjustWeightsKernel(void* inputBuffer, Nd4jLong const* inputShape, - void* weightsBuffer, Nd4jLong const* weightsShape, - void* outputBuffer, Nd4jLong const* outputShape, - int minLength, int maxLength) { - - //auto tid = blockIdx.x * blockDim.x + threadIdx.x; // * blockDim.x; // + threadIdx.x; - int threadCount = gridDim.x * blockDim.x; - Nd4jLong inputLength = shape::length(inputShape); - - Nd4jLong outputLength = shape::length(outputShape); - Nd4jLong borderLen = 1; - - for (Nd4jLong e = blockIdx.x; e < outputLength; e += threadCount) { - //if (blockIdx.x < outputLength) { - //if (e + threadCount < outputLength) { - Nd4jLong zOffset = shape::getIndexOffset(e, outputShape); - //printf("%d %d %d\n", blockIdx.x, blockDim.x, threadIdx.x); - //Nd4jLong borderLen = 1; - T* outputBufferZ = reinterpret_cast(outputBuffer) + zOffset; - adjustWeightsKernelD(inputBuffer, inputShape, weightsBuffer, weightsShape, (void*)outputBufferZ, - inputLength, outputLength, (int)zOffset); - - } - } +template +static __global__ void adjustWeightsKernel( + void* inputBuffer, Nd4jLong const* inputShape, void* weightsBuffer, + Nd4jLong const* weightsShape, void* outputBuffer, + Nd4jLong const* outputShape, int minLength, int maxLength) { + // auto tid = blockIdx.x * blockDim.x + threadIdx.x; // * blockDim.x; // + + // threadIdx.x; + int threadCount = gridDim.x * blockDim.x; + Nd4jLong inputLength = shape::length(inputShape); - template - static void adjustWeights_(sd::LaunchContext * context, NDArray* input, NDArray* weights, NDArray* output, int minLength, int maxLength) { -// for (int e = 0; e < input->lengthOf(); e++) { -// int val = input->e(e); -// if (val < maxLength) { -// if (weights != nullptr) -// output->p(val, output->e(val) + weights->e(e)); -// else -// output->p(val, output->e(val) + 1); -// } -// } - dim3 launchDims(256, 512, 8192); - auto stream = context->getCudaStream(); - adjustWeightsKernel<<>>(input->specialBuffer(), - input->specialShapeInfo(), weights?weights->specialBuffer():nullptr, weights?weights->specialShapeInfo():nullptr, - output->specialBuffer(), output->specialShapeInfo(), minLength, maxLength); - } + Nd4jLong outputLength = shape::length(outputShape); + Nd4jLong borderLen = 1; - void adjustWeights(sd::LaunchContext * context, NDArray* input, NDArray* weights, NDArray* output, int minLength, int maxLength) { - BUILD_SINGLE_SELECTOR(output->dataType(), adjustWeights_, (context, input, weights, output, minLength, maxLength), GENERIC_NUMERIC_TYPES); - } + for (Nd4jLong e = blockIdx.x; e < outputLength; e += threadCount) { + // if (blockIdx.x < outputLength) { + // if (e + threadCount < outputLength) { + Nd4jLong zOffset = shape::getIndexOffset(e, outputShape); + // printf("%d %d %d\n", blockIdx.x, blockDim.x, threadIdx.x); + // Nd4jLong borderLen = 1; + T* outputBufferZ = reinterpret_cast(outputBuffer) + zOffset; + adjustWeightsKernelD(inputBuffer, inputShape, weightsBuffer, + weightsShape, (void*)outputBufferZ, inputLength, + outputLength, (int)zOffset); + } +} - BUILD_SINGLE_TEMPLATE(template void adjustWeights_, (sd::LaunchContext * context, NDArray* input, NDArray* weights, NDArray* output, int minLength, int maxLength), GENERIC_NUMERIC_TYPES); +template +static void adjustWeights_(sd::LaunchContext* context, NDArray* input, + NDArray* weights, NDArray* output, int minLength, + int maxLength) { + // for (int e = 0; e < input->lengthOf(); e++) { + // int val = input->e(e); + // if (val < maxLength) { + // if (weights != nullptr) + // output->p(val, output->e(val) + weights->e(e)); + // else + // output->p(val, output->e(val) + 1); + // } + // } + dim3 launchDims(256, 512, 8192); + auto stream = context->getCudaStream(); + adjustWeightsKernel<<>>( + input->specialBuffer(), input->specialShapeInfo(), + weights ? weights->specialBuffer() : nullptr, + weights ? weights->specialShapeInfo() : nullptr, output->specialBuffer(), + output->specialShapeInfo(), minLength, maxLength); } + +void adjustWeights(sd::LaunchContext* context, NDArray* input, NDArray* weights, + NDArray* output, int minLength, int maxLength) { + BUILD_SINGLE_SELECTOR(output->dataType(), adjustWeights_, + (context, input, weights, output, minLength, maxLength), + GENERIC_NUMERIC_TYPES); } -} \ No newline at end of file + +BUILD_SINGLE_TEMPLATE(template void adjustWeights_, + (sd::LaunchContext * context, NDArray* input, + NDArray* weights, NDArray* output, int minLength, + int maxLength), + GENERIC_NUMERIC_TYPES); +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/zeta.cu b/libnd4j/include/ops/declarable/helpers/cuda/zeta.cu index 660c49325845..dd397ac1e02f 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/zeta.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/zeta.cu @@ -18,68 +18,77 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 26.04.2019 // -#include +#include namespace sd { namespace ops { namespace helpers { - /////////////////////////////////////////////////////////////////// -template +template __global__ static void zetaCuda(const void *vx, const Nd4jLong *xShapeInfo, const void *vq, const Nd4jLong *qShapeInfo, - void *vz, const Nd4jLong *zShapeInfo) { - - const auto x = reinterpret_cast(vx); - const auto q = reinterpret_cast(vq); - auto z = reinterpret_cast(vz); - - __shared__ Nd4jLong len; + void *vz, const Nd4jLong *zShapeInfo) { + const auto x = reinterpret_cast(vx); + const auto q = reinterpret_cast(vq); + auto z = reinterpret_cast(vz); - if (threadIdx.x == 0) - len = shape::length(xShapeInfo); - __syncthreads(); + __shared__ Nd4jLong len; - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - const auto totalThreads = gridDim.x * blockDim.x; + if (threadIdx.x == 0) len = shape::length(xShapeInfo); + __syncthreads(); - for (int i = tid; i < len; i += totalThreads) { + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + const auto totalThreads = gridDim.x * blockDim.x; - const auto xOffset = shape::getIndexOffset(i, xShapeInfo); - const auto qOffset = shape::getIndexOffset(i, qShapeInfo); - const auto zOffset = shape::getIndexOffset(i, zShapeInfo); + for (int i = tid; i < len; i += totalThreads) { + const auto xOffset = shape::getIndexOffset(i, xShapeInfo); + const auto qOffset = shape::getIndexOffset(i, qShapeInfo); + const auto zOffset = shape::getIndexOffset(i, zShapeInfo); - z[zOffset] = zetaScalar(x[xOffset], q[qOffset]); - } + z[zOffset] = zetaScalar(x[xOffset], q[qOffset]); + } } /////////////////////////////////////////////////////////////////// -template -static void zetaCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, const void *vq, const Nd4jLong *qShapeInfo, void *vz, const Nd4jLong *zShapeInfo) { - - zetaCuda<<>>(vx, xShapeInfo, vq, qShapeInfo, vz, zShapeInfo); +template +static void zetaCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, + const cudaStream_t *stream, const void *vx, + const Nd4jLong *xShapeInfo, const void *vq, + const Nd4jLong *qShapeInfo, void *vz, + const Nd4jLong *zShapeInfo) { + zetaCuda<<>>( + vx, xShapeInfo, vq, qShapeInfo, vz, zShapeInfo); } -void zeta(sd::LaunchContext * context, const NDArray& x, const NDArray& q, NDArray& z) { - - if(!x.isActualOnDeviceSide()) x.syncToDevice(); - if(!q.isActualOnDeviceSide()) q.syncToDevice(); +void zeta(sd::LaunchContext *context, const NDArray &x, const NDArray &q, + NDArray &z) { + if (!x.isActualOnDeviceSide()) x.syncToDevice(); + if (!q.isActualOnDeviceSide()) q.syncToDevice(); - int threadsPerBlock = MAX_NUM_THREADS / 2; - int blocksPerGrid = (z.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + int threadsPerBlock = MAX_NUM_THREADS / 2; + int blocksPerGrid = (z.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - BUILD_SINGLE_SELECTOR(x.dataType(), zetaCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), x.specialBuffer(), x.specialShapeInfo(), q.specialBuffer(), q.specialShapeInfo(), z.specialBuffer(), z.specialShapeInfo()), FLOAT_TYPES); + BUILD_SINGLE_SELECTOR( + x.dataType(), zetaCudaLauncher, + (blocksPerGrid, threadsPerBlock, context->getCudaStream(), + x.specialBuffer(), x.specialShapeInfo(), q.specialBuffer(), + q.specialShapeInfo(), z.specialBuffer(), z.specialShapeInfo()), + FLOAT_TYPES); - x.tickReadHost(); - q.tickReadHost(); - z.tickWriteDevice(); -} - -BUILD_SINGLE_TEMPLATE(template void zetaCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, const void *vq, const Nd4jLong *qShapeInfo, void *vz, const Nd4jLong *zShapeInfo), FLOAT_TYPES); - - -} -} + x.tickReadHost(); + q.tickReadHost(); + z.tickWriteDevice(); } +BUILD_SINGLE_TEMPLATE(template void zetaCudaLauncher, + (const int blocksPerGrid, const int threadsPerBlock, + const cudaStream_t *stream, const void *vx, + const Nd4jLong *xShapeInfo, const void *vq, + const Nd4jLong *qShapeInfo, void *vz, + const Nd4jLong *zShapeInfo), + FLOAT_TYPES); + +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/d_t_s.h b/libnd4j/include/ops/declarable/helpers/d_t_s.h index e5ac58e5aa24..b66d12b6ce3f 100644 --- a/libnd4j/include/ops/declarable/helpers/d_t_s.h +++ b/libnd4j/include/ops/declarable/helpers/d_t_s.h @@ -18,14 +18,15 @@ // @author raver119@gmail.com // -#include #include +#include namespace sd { namespace ops { namespace helpers { - void _depthToSpace(sd::LaunchContext * context, const NDArray &input, NDArray *output, int block_size, bool isNHWC); -} +void _depthToSpace(sd::LaunchContext *context, const NDArray &input, + NDArray *output, int block_size, bool isNHWC); } -} \ No newline at end of file +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/diag.h b/libnd4j/include/ops/declarable/helpers/diag.h index af84eec01fc7..a216c9457454 100644 --- a/libnd4j/include/ops/declarable/helpers/diag.h +++ b/libnd4j/include/ops/declarable/helpers/diag.h @@ -19,17 +19,19 @@ // #ifndef __DIAG_H_HELPERS__ #define __DIAG_H_HELPERS__ -#include #include +#include namespace sd { namespace ops { namespace helpers { - void diagFunctor(sd::LaunchContext * context, NDArray const* input, NDArray* output); - void diagPartFunctor(sd::LaunchContext * context, NDArray const* input, NDArray* output); +void diagFunctor(sd::LaunchContext* context, NDArray const* input, + NDArray* output); +void diagPartFunctor(sd::LaunchContext* context, NDArray const* input, + NDArray* output); -} -} -} +} // namespace helpers +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/helpers/dilation2d.h b/libnd4j/include/ops/declarable/helpers/dilation2d.h index 281a2f26a347..1198a6b59d66 100644 --- a/libnd4j/include/ops/declarable/helpers/dilation2d.h +++ b/libnd4j/include/ops/declarable/helpers/dilation2d.h @@ -20,68 +20,77 @@ #include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { ////////////////////////////////////////////////////////////////////// -void dilation2d(sd::LaunchContext* context, NDArray *input, NDArray *weights, NDArray *output, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW); +void dilation2d(sd::LaunchContext *context, NDArray *input, NDArray *weights, + NDArray *output, const int sH, const int sW, const int pH, + const int pW, const int dH, const int dW); ////////////////////////////////////////////////////////////////////// -FORCEINLINE Nd4jStatus outputSize(sd::LaunchContext * context, const int inSize, const int k, const int d, const int s, bool isSameMode, int *outSize, int *padding_before, int *padding_after) { - if (s <= 0) - return Status::THROW("Dilation2D: Stride must be > 0"); - - if (d < 1) - return Status::THROW("Dilation2D: Dilation rate must be >= 1"); - - int kEff = (k - 1) * d + 1; - if (isSameMode) { - *outSize = (inSize + s - 1) / s; - const int padding_needed = sd::math::nd4j_max(0, (*outSize - 1) * s + kEff -inSize); - - *padding_before = padding_needed / 2; - *padding_after = padding_needed - *padding_before; - } else { - *outSize = (inSize - kEff + s) / s; - *padding_before = *padding_after = 0; - } - - if (*outSize < 0) - return Status::THROW("Dilation2D: outSize has negative value"); - - return Status::OK(); +FORCEINLINE Nd4jStatus outputSize(sd::LaunchContext *context, const int inSize, + const int k, const int d, const int s, + bool isSameMode, int *outSize, + int *padding_before, int *padding_after) { + if (s <= 0) return Status::THROW("Dilation2D: Stride must be > 0"); + + if (d < 1) return Status::THROW("Dilation2D: Dilation rate must be >= 1"); + + int kEff = (k - 1) * d + 1; + if (isSameMode) { + *outSize = (inSize + s - 1) / s; + const int padding_needed = + sd::math::nd4j_max(0, (*outSize - 1) * s + kEff - inSize); + + *padding_before = padding_needed / 2; + *padding_after = padding_needed - *padding_before; + } else { + *outSize = (inSize - kEff + s) / s; + *padding_before = *padding_after = 0; + } + + if (*outSize < 0) + return Status::THROW("Dilation2D: outSize has negative value"); + + return Status::OK(); } ////////////////////////////////////////////////////////////////////// -FORCEINLINE Nd4jStatus dilation_hw(sd::LaunchContext * context, Nd4jLong const* in, Nd4jLong const* wh, std::vector &strides, std::vector &rates, bool isSameMode, int *sH, int *sW, int *pH, int *pW, int *dH, int *dW, int *oH, int *oW) { - const int iH = shape::sizeAt(in, 1); - const int iW = shape::sizeAt(in, 2); - const int iC = shape::sizeAt(in, 3); - - *sH = strides[1]; - *sW = strides[2]; - *dH = rates[1]; - *dW = rates[2]; - - const int kH = shape::sizeAt(wh, 0); - const int kW = shape::sizeAt(wh, 1); - - const int kHeff = kH + (kH - 1) * (*dH - 1); - const int kWeff = kW + (kW - 1) * (*dW - 1); - - int padding_after_unusedA, padding_after_unusedB; - if (outputSize(context, iH, kHeff, 1, *sH, isSameMode, oH, pH, &padding_after_unusedA) != Status::OK()) - return Status::THROW("Dilation2D: bad height"); - - if (outputSize(context, iW, kWeff, 1, *sW, isSameMode, oW, pW, &padding_after_unusedA) != Status::OK()) - return Status::THROW("Dilation2D: bad width"); - - return Status::OK(); +FORCEINLINE Nd4jStatus dilation_hw(sd::LaunchContext *context, + Nd4jLong const *in, Nd4jLong const *wh, + std::vector &strides, + std::vector &rates, bool isSameMode, + int *sH, int *sW, int *pH, int *pW, int *dH, + int *dW, int *oH, int *oW) { + const int iH = shape::sizeAt(in, 1); + const int iW = shape::sizeAt(in, 2); + const int iC = shape::sizeAt(in, 3); + + *sH = strides[1]; + *sW = strides[2]; + *dH = rates[1]; + *dW = rates[2]; + + const int kH = shape::sizeAt(wh, 0); + const int kW = shape::sizeAt(wh, 1); + + const int kHeff = kH + (kH - 1) * (*dH - 1); + const int kWeff = kW + (kW - 1) * (*dW - 1); + + int padding_after_unusedA, padding_after_unusedB; + if (outputSize(context, iH, kHeff, 1, *sH, isSameMode, oH, pH, + &padding_after_unusedA) != Status::OK()) + return Status::THROW("Dilation2D: bad height"); + + if (outputSize(context, iW, kWeff, 1, *sW, isSameMode, oW, pW, + &padding_after_unusedA) != Status::OK()) + return Status::THROW("Dilation2D: bad width"); + + return Status::OK(); } - - -} -} -} \ No newline at end of file +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/dropout.h b/libnd4j/include/ops/declarable/helpers/dropout.h index 052b68f33f6b..4a4c1c423ee3 100644 --- a/libnd4j/include/ops/declarable/helpers/dropout.h +++ b/libnd4j/include/ops/declarable/helpers/dropout.h @@ -19,20 +19,29 @@ // #ifndef __DROP_OUT_HELPERS__ #define __DROP_OUT_HELPERS__ -#include #include #include +#include namespace sd { namespace ops { namespace helpers { - int dropOutFunctor(graph::Context& context, NDArray* input, NDArray* output, NDArray* reduceShape, int seed, double probValue); - int dropOutFunctorBP(graph::Context& context, NDArray* input, NDArray* gradOut, NDArray* output, NDArray* reduceShape, int seed, double probValue); - int alphaDropOutFunctor(graph::Context& context, NDArray* input, NDArray* output, NDArray* reduceShape, int seed, double probValue, double alpha, double alpha1, double beta); - int alphaDropOutFunctorBP(graph::Context& context, NDArray* input, NDArray* gradOut, NDArray* output, NDArray* reduceShape, int seed, double probValue, double alpha, double alpha1, double beta); +int dropOutFunctor(graph::Context& context, NDArray* input, NDArray* output, + NDArray* reduceShape, int seed, double probValue); +int dropOutFunctorBP(graph::Context& context, NDArray* input, NDArray* gradOut, + NDArray* output, NDArray* reduceShape, int seed, + double probValue); +int alphaDropOutFunctor(graph::Context& context, NDArray* input, + NDArray* output, NDArray* reduceShape, int seed, + double probValue, double alpha, double alpha1, + double beta); +int alphaDropOutFunctorBP(graph::Context& context, NDArray* input, + NDArray* gradOut, NDArray* output, + NDArray* reduceShape, int seed, double probValue, + double alpha, double alpha1, double beta); -} -} -} +} // namespace helpers +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/helpers/dynamic.h b/libnd4j/include/ops/declarable/helpers/dynamic.h index 29452cb8ff38..75c6ee21dbf1 100644 --- a/libnd4j/include/ops/declarable/helpers/dynamic.h +++ b/libnd4j/include/ops/declarable/helpers/dynamic.h @@ -20,21 +20,32 @@ #ifndef __DYNAMIC_H_HELPERS__ #define __DYNAMIC_H_HELPERS__ -#include #include +#include namespace sd { - namespace ops { - namespace helpers { +namespace ops { +namespace helpers { - void dynamicPartitionFunctor(sd::LaunchContext * context, NDArray const* input, NDArray const* indices, std::vector& outputList); +void dynamicPartitionFunctor(sd::LaunchContext* context, NDArray const* input, + NDArray const* indices, + std::vector& outputList); - int dynamicStitchFunctor(sd::LaunchContext * context, std::vector const& inputs, std::vector const& indices, NDArray* output); +int dynamicStitchFunctor(sd::LaunchContext* context, + std::vector const& inputs, + std::vector const& indices, NDArray* output); - void dynamicPartitionFunctorBP(sd::LaunchContext * context, NDArray const* input, NDArray const* indices, std::vector const& gradientInputList, std::vector& outputList); +void dynamicPartitionFunctorBP(sd::LaunchContext* context, NDArray const* input, + NDArray const* indices, + std::vector const& gradientInputList, + std::vector& outputList); - int dynamicStitchFunctorBP(sd::LaunchContext * context, std::vector const& inputs, std::vector const& indices, NDArray const* gradientInput, std::vector& outputList); - } - } -} +int dynamicStitchFunctorBP(sd::LaunchContext* context, + std::vector const& inputs, + std::vector const& indices, + NDArray const* gradientInput, + std::vector& outputList); +} // namespace helpers +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/helpers/extract_patches.h b/libnd4j/include/ops/declarable/helpers/extract_patches.h index 63d5e94f4846..1c64745d00c5 100644 --- a/libnd4j/include/ops/declarable/helpers/extract_patches.h +++ b/libnd4j/include/ops/declarable/helpers/extract_patches.h @@ -19,16 +19,18 @@ // #ifndef __EXTRACT_PATCHES_H_HELPERS__ #define __EXTRACT_PATCHES_H_HELPERS__ -#include #include +#include namespace sd { namespace ops { namespace helpers { - void extractPatches(sd::LaunchContext * context, NDArray* images, NDArray* output, int sizeRow, int sizeCol, int stradeRow, int stradeCol, int rateRow, int rateCol, bool theSame); +void extractPatches(sd::LaunchContext* context, NDArray* images, + NDArray* output, int sizeRow, int sizeCol, int stradeRow, + int stradeCol, int rateRow, int rateCol, bool theSame); } -} -} +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/helpers/fake_quantization.h b/libnd4j/include/ops/declarable/helpers/fake_quantization.h index b5f4dff00225..90d42a07e1b8 100644 --- a/libnd4j/include/ops/declarable/helpers/fake_quantization.h +++ b/libnd4j/include/ops/declarable/helpers/fake_quantization.h @@ -19,16 +19,19 @@ // #ifndef __FAKE_QUANTIZATION_H_HELPERS__ #define __FAKE_QUANTIZATION_H_HELPERS__ -#include #include +#include namespace sd { namespace ops { namespace helpers { - void fakeQuantWithMinMaxVars(NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output); - void fakeQuantWithMinMaxVarsPerChannel(LaunchContext* context, NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output); -} -} -} +void fakeQuantWithMinMaxVars(NDArray* input, NDArray* min, NDArray* max, + int numBits, bool narrowed, NDArray* output); +void fakeQuantWithMinMaxVarsPerChannel(LaunchContext* context, NDArray* input, + NDArray* min, NDArray* max, int numBits, + bool narrowed, NDArray* output); +} // namespace helpers +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/helpers/flatten.h b/libnd4j/include/ops/declarable/helpers/flatten.h index da6253dfa0d9..8240d7ffa053 100644 --- a/libnd4j/include/ops/declarable/helpers/flatten.h +++ b/libnd4j/include/ops/declarable/helpers/flatten.h @@ -18,51 +18,48 @@ // @author raver119@gmail.com // -#ifndef DEV_TESTS_FLATTEN_H -#define DEV_TESTS_FLATTEN_H +#ifndef SD_FLATTEN_H +#define SD_FLATTEN_H -#include #include -namespace sd { -namespace ops { -namespace helpers { +#include +namespace sd { +namespace ops { +namespace helpers { ////////////////////////////////////////////////////////////////////// -void flatten(sd::LaunchContext *context, std::vector &inputs, NDArray *output, char order); - +void flatten(sd::LaunchContext *context, std::vector &inputs, + NDArray *output, char order); ////////////////////////////////////////////////////////////////////// -INLINEDEF _CUDA_HD Nd4jLong getIndexOffsetOrdered(Nd4jLong index, const Nd4jLong *shapeInfo, const char order) { - - Nd4jLong offset = 0; - - if (order == 'c') { - - for(uint i = shapeInfo[0]; i > 1; --i) { - offset += (index % shapeInfo[i]) * shapeInfo[i + shapeInfo[0]]; - index /= shapeInfo[i]; - } - - offset += index * shapeInfo[1 + shapeInfo[0]]; // last iteration +INLINEDEF _CUDA_HD Nd4jLong getIndexOffsetOrdered(Nd4jLong index, + const Nd4jLong *shapeInfo, + const char order) { + Nd4jLong offset = 0; + + if (order == 'c') { + for (uint i = shapeInfo[0]; i > 1; --i) { + offset += (index % shapeInfo[i]) * shapeInfo[i + shapeInfo[0]]; + index /= shapeInfo[i]; } - else { - - for(uint i = 1; i < shapeInfo[0]; ++i) { - offset += (index % shapeInfo[i]) * shapeInfo[i + shapeInfo[0]]; - index /= shapeInfo[i]; - } - offset += index * shapeInfo[2 * shapeInfo[0]]; // last iteration + offset += index * shapeInfo[1 + shapeInfo[0]]; // last iteration + } else { + for (uint i = 1; i < shapeInfo[0]; ++i) { + offset += (index % shapeInfo[i]) * shapeInfo[i + shapeInfo[0]]; + index /= shapeInfo[i]; } - return offset; -} - + offset += index * shapeInfo[2 * shapeInfo[0]]; // last iteration + } + return offset; } -} -} -#endif //DEV_TESTS_FLATTEN_H +} // namespace helpers +} // namespace ops +} // namespace sd + +#endif // SD_FLATTEN_H diff --git a/libnd4j/include/ops/declarable/helpers/gammaMathFunc.h b/libnd4j/include/ops/declarable/helpers/gammaMathFunc.h index 2f99f3777c31..8ed853a43173 100644 --- a/libnd4j/include/ops/declarable/helpers/gammaMathFunc.h +++ b/libnd4j/include/ops/declarable/helpers/gammaMathFunc.h @@ -23,78 +23,98 @@ #define LIBND4J_GAMMAMATHFUNC_H #include + #include "array/NDArray.h" namespace sd { namespace ops { namespace helpers { - // calculate the digamma function for each element for array - void diGamma(sd::LaunchContext* context, const NDArray& x, NDArray& z); - - // calculate the polygamma function - void polyGamma(sd::LaunchContext* context, const NDArray& n, const NDArray& x, NDArray& z); - - // calculate the digamma function for one element - // implementation is based on serial representation written in terms of the Hurwitz zeta function as polygamma = (-1)^{n+1} * n! * zeta(n+1, x) - template - _CUDA_HD T diGammaScalar(T x) { - - const int xInt = static_cast(x); - - // negative and zero - if(x <= 0) { - if(x == xInt) // integer - return DataTypeUtils::infOrMax(); - else - return diGammaScalar(1 - x) - M_PI / sd::math::nd4j_tan(M_PI * x); // use reflection formula psi(1-x) = psi(x) + pi*cot(pi*x) - } - - // positive integer - if(x == xInt && xInt <= 20) { // psi(n) = -Euler_Mascheroni_const + sum_from_k=1_to_n-1( 1/k ), for n = 1,2,3,...inf, we use this formula only for n <= 20 to avoid time consuming sum calculation for bigger n - T result = -0.577215664901532; - for (uint i = 1; i <= xInt - 1; ++i) { - result += static_cast(1) / i; - } - return result; - } - - // positive half-integer - if(x - xInt == 0.5 && xInt <= 20) { // psi(n+0.5) = -Euler_Mascheroni_const - 2*ln(2) + sum_from_k=1_to_n( 2/(2*k-1) ) , for n = 1,2,3,...inf, we use this formula only for n <= 20 to avoid time consuming sum calculation for bigger n - T result = -0.577215664901532 - 2 * sd::math::nd4j_log(2); - for (uint i = 1; i <= xInt; ++i) { - result += static_cast(2) / (2*i - 1); - } - return result; - } - - // positive, smaller then 5; we should use number > 5 in order to have satisfactory accuracy in asymptotic expansion - if(x < 5) - return diGammaScalar(1 + x) - static_cast(1) / x; // recurrence formula psi(x) = psi(x+1) - 1/x. - - // *** other positive **** // - - // truncated expansion formula (from wiki) - // psi(x) = log(x) - 1/(2*x) - 1/(12*x^2) + 1/(120*x^4) - 1/(252*x^6) + 1/(240*x^8) - 5/(660*x^10) + 691/(32760*x^12) - 1/(12*x^14) + ... - - if(x >= (sizeof(T) > 4 ? 1.e16 : 1.e8)) // if x is too big take into account only log(x) - return sd::math::nd4j_log(x); - - // coefficients used in truncated asymptotic expansion formula - const T coeffs[7] = {-(T)1/12, (T)1/120, -(T)1/252, (T)1/240, -(T)5/660, (T)691/32760, -(T)1/12}; - // const T coeffs[7] = {-0.0833333333333333, 0.00833333333333333, -0.00396825396825397, 0.00416666666666667, -0.00757575757575758, 0.0210927960927961, -0.0833333333333333}; - - const T x2Inv = static_cast(1) / (x * x); - T result = 0; - - for (int i = 6; i >= 0; --i) - result = (result + coeffs[i]) * x2Inv; - return result + sd::math::nd4j_log(x) - static_cast(0.5) / x; - } - -} -} +// calculate the digamma function for each element for array +void diGamma(sd::LaunchContext* context, const NDArray& x, NDArray& z); + +// calculate the polygamma function +void polyGamma(sd::LaunchContext* context, const NDArray& n, const NDArray& x, + NDArray& z); + +// calculate the digamma function for one element +// implementation is based on serial representation written in terms of the +// Hurwitz zeta function as polygamma = (-1)^{n+1} * n! * zeta(n+1, x) +template +_CUDA_HD T diGammaScalar(T x) { + const int xInt = static_cast(x); + + // negative and zero + if (x <= 0) { + if (x == xInt) // integer + return DataTypeUtils::infOrMax(); + else + return diGammaScalar(1 - x) - + M_PI / sd::math::nd4j_tan( + M_PI * x); // use reflection formula psi(1-x) = psi(x) + // + pi*cot(pi*x) + } + + // positive integer + if (x == xInt && + xInt <= + 20) { // psi(n) = -Euler_Mascheroni_const + sum_from_k=1_to_n-1( 1/k + // ), for n = 1,2,3,...inf, we use this formula only for n <= + // 20 to avoid time consuming sum calculation for bigger n + T result = -0.577215664901532; + for (uint i = 1; i <= xInt - 1; ++i) { + result += static_cast(1) / i; + } + return result; + } + + // positive half-integer + if (x - xInt == 0.5 && + xInt <= 20) { // psi(n+0.5) = -Euler_Mascheroni_const - 2*ln(2) + + // sum_from_k=1_to_n( 2/(2*k-1) ) , for n = 1,2,3,...inf, + // we use this formula only for n <= 20 to avoid time + // consuming sum calculation for bigger n + T result = -0.577215664901532 - 2 * sd::math::nd4j_log(2); + for (uint i = 1; i <= xInt; ++i) { + result += static_cast(2) / (2 * i - 1); + } + return result; + } + + // positive, smaller then 5; we should use number > 5 in order to have + // satisfactory accuracy in asymptotic expansion + if (x < 5) + return diGammaScalar(1 + x) - + static_cast(1) / + x; // recurrence formula psi(x) = psi(x+1) - 1/x. + + // *** other positive **** // + + // truncated expansion formula (from wiki) + // psi(x) = log(x) - 1/(2*x) - 1/(12*x^2) + 1/(120*x^4) - 1/(252*x^6) + + // 1/(240*x^8) - 5/(660*x^10) + 691/(32760*x^12) - 1/(12*x^14) + ... + + if (x >= (sizeof(T) > 4 + ? 1.e16 + : 1.e8)) // if x is too big take into account only log(x) + return sd::math::nd4j_log(x); + + // coefficients used in truncated asymptotic expansion formula + const T coeffs[7] = {-(T)1 / 12, (T)1 / 120, -(T)1 / 252, (T)1 / 240, + -(T)5 / 660, (T)691 / 32760, -(T)1 / 12}; + // const T coeffs[7] = {-0.0833333333333333, 0.00833333333333333, + // -0.00396825396825397, 0.00416666666666667, -0.00757575757575758, + // 0.0210927960927961, -0.0833333333333333}; + + const T x2Inv = static_cast(1) / (x * x); + T result = 0; + + for (int i = 6; i >= 0; --i) result = (result + coeffs[i]) * x2Inv; + return result + sd::math::nd4j_log(x) - static_cast(0.5) / x; } +} // namespace helpers +} // namespace ops +} // namespace sd -#endif //LIBND4J_GAMMAMATHFUNC_H +#endif // LIBND4J_GAMMAMATHFUNC_H diff --git a/libnd4j/include/ops/declarable/helpers/gather.h b/libnd4j/include/ops/declarable/helpers/gather.h index f3870838576c..bb324558627d 100644 --- a/libnd4j/include/ops/declarable/helpers/gather.h +++ b/libnd4j/include/ops/declarable/helpers/gather.h @@ -26,11 +26,13 @@ namespace sd { namespace ops { namespace helpers { - - void gather(sd::LaunchContext * context, const NDArray* input, const NDArray* indices, NDArray* output, const std::vector& intArgs); + +void gather(sd::LaunchContext* context, const NDArray* input, + const NDArray* indices, NDArray* output, + const std::vector& intArgs); } -} -} +} // namespace ops +} // namespace sd -#endif //LIBND4J_GATHER_H +#endif // LIBND4J_GATHER_H diff --git a/libnd4j/include/ops/declarable/helpers/gradient.h b/libnd4j/include/ops/declarable/helpers/gradient.h index 583396cf345f..b10298b0c7aa 100644 --- a/libnd4j/include/ops/declarable/helpers/gradient.h +++ b/libnd4j/include/ops/declarable/helpers/gradient.h @@ -19,19 +19,20 @@ // #ifndef __GRADIENT_H_HELPERS__ #define __GRADIENT_H_HELPERS__ -#include #include +#include namespace sd { namespace ops { namespace helpers { - /* - * applyGradientDescent: calculate z = x - y * w. - * */ - void applyGradientDescent(sd::LaunchContext* context, NDArray* input, NDArray* step, double weight, NDArray* output); +/* + * applyGradientDescent: calculate z = x - y * w. + * */ +void applyGradientDescent(sd::LaunchContext* context, NDArray* input, + NDArray* step, double weight, NDArray* output); -} -} -} +} // namespace helpers +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/helpers/gru.h b/libnd4j/include/ops/declarable/helpers/gru.h index 9e98e4046616..afc2d9854058 100644 --- a/libnd4j/include/ops/declarable/helpers/gru.h +++ b/libnd4j/include/ops/declarable/helpers/gru.h @@ -27,33 +27,37 @@ namespace sd { namespace ops { namespace helpers { - void gruCell(sd::LaunchContext * context, const NDArray* x, const NDArray* hLast, const NDArray* Wru, const NDArray* Wc, - const NDArray* bru, const NDArray* bc, - NDArray* r, NDArray* u, NDArray* c, NDArray* h); - - void gruCell(sd::LaunchContext * context, const NDArray* x, const NDArray* hLast, const NDArray* Wru, const NDArray* Wc, const NDArray* b, - NDArray* gates, NDArray* h); - - void gruTimeLoop(sd::LaunchContext * context, const NDArray* x, const NDArray* h0, const NDArray* Wx, const NDArray* Wh, const NDArray* b, NDArray* h); - - void gruCellBp(sd::LaunchContext* context, - const NDArray* x, const NDArray* hLast, - const NDArray* W, const NDArray* Wc, const NDArray* b, const NDArray* bc, - const NDArray* dLdr, const NDArray* dLdu, const NDArray* dLdc, const NDArray* dLdh, - NDArray* dLdx, NDArray* dLdhLast, - NDArray* dLdW, NDArray* dLdWc, - NDArray* dLdb, NDArray* dLdbc); - - void gruCellBp(sd::LaunchContext* context, - const NDArray* x, const NDArray* hI, const NDArray* Wx, const NDArray* Wh, const NDArray* b, const NDArray* dLdh, const NDArray* gates, - NDArray* dLdx, NDArray* dLdhI, NDArray* dLdWx, NDArray* dLdWh, NDArray* dLdb); - - void gruTimeLoopBp(sd::LaunchContext * context, - const NDArray* x, const NDArray* hI, const NDArray* Wx, const NDArray* Wh, const NDArray* b, const NDArray* dLdh, - NDArray* dLdx, NDArray* dLdhI, NDArray* dLdWx, NDArray* dLdWh, NDArray* dLdb); -} -} -} - - -#endif //LIBND4J_GRU_H \ No newline at end of file +void gruCell(sd::LaunchContext* context, const NDArray* x, const NDArray* hLast, + const NDArray* Wru, const NDArray* Wc, const NDArray* bru, + const NDArray* bc, NDArray* r, NDArray* u, NDArray* c, NDArray* h); + +void gruCell(sd::LaunchContext* context, const NDArray* x, const NDArray* hLast, + const NDArray* Wru, const NDArray* Wc, const NDArray* b, + NDArray* gates, NDArray* h); + +void gruTimeLoop(sd::LaunchContext* context, const NDArray* x, + const NDArray* h0, const NDArray* Wx, const NDArray* Wh, + const NDArray* b, NDArray* h); + +void gruCellBp(sd::LaunchContext* context, const NDArray* x, + const NDArray* hLast, const NDArray* W, const NDArray* Wc, + const NDArray* b, const NDArray* bc, const NDArray* dLdr, + const NDArray* dLdu, const NDArray* dLdc, const NDArray* dLdh, + NDArray* dLdx, NDArray* dLdhLast, NDArray* dLdW, NDArray* dLdWc, + NDArray* dLdb, NDArray* dLdbc); + +void gruCellBp(sd::LaunchContext* context, const NDArray* x, const NDArray* hI, + const NDArray* Wx, const NDArray* Wh, const NDArray* b, + const NDArray* dLdh, const NDArray* gates, NDArray* dLdx, + NDArray* dLdhI, NDArray* dLdWx, NDArray* dLdWh, NDArray* dLdb); + +void gruTimeLoopBp(sd::LaunchContext* context, const NDArray* x, + const NDArray* hI, const NDArray* Wx, const NDArray* Wh, + const NDArray* b, const NDArray* dLdh, NDArray* dLdx, + NDArray* dLdhI, NDArray* dLdWx, NDArray* dLdWh, + NDArray* dLdb); +} // namespace helpers +} // namespace ops +} // namespace sd + +#endif // LIBND4J_GRU_H \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/hamming.h b/libnd4j/include/ops/declarable/helpers/hamming.h index 6450d788240e..77d0aa6df35b 100644 --- a/libnd4j/include/ops/declarable/helpers/hamming.h +++ b/libnd4j/include/ops/declarable/helpers/hamming.h @@ -21,12 +21,15 @@ #ifndef SD_HAMMING_H #define SD_HAMMING_H +#include +#include + namespace sd { - namespace ops { - namespace helpers { - void hamming(LaunchContext *context, NDArray &x, NDArray &y, NDArray &output); - } - } +namespace ops { +namespace helpers { +void hamming(LaunchContext *context, NDArray &x, NDArray &y, NDArray &output); } +} // namespace ops +} // namespace sd -#endif //DEV_TESTS_HAMMING_H +#endif // SD_HAMMING_H diff --git a/libnd4j/include/ops/declarable/helpers/hashcode.h b/libnd4j/include/ops/declarable/helpers/hashcode.h index 730249d1aa66..583317eb61d3 100644 --- a/libnd4j/include/ops/declarable/helpers/hashcode.h +++ b/libnd4j/include/ops/declarable/helpers/hashcode.h @@ -18,53 +18,52 @@ // @author raver119@gmail.com // -#ifndef DEV_TESTS_HASHCODE_H -#define DEV_TESTS_HASHCODE_H +#ifndef SD_HASHCODE_H +#define SD_HASHCODE_H #include "helpers.h" namespace sd { - namespace ops { - namespace helpers { - template - FORCEINLINE _CUDA_HD Nd4jLong longBytes(T value); +namespace ops { +namespace helpers { +template +FORCEINLINE _CUDA_HD Nd4jLong longBytes(T value); - template <> - FORCEINLINE _CUDA_HD Nd4jLong longBytes(float value) { - int intie = *(int *)&value; - return static_cast(intie); - } - - template <> - FORCEINLINE _CUDA_HD Nd4jLong longBytes(double value) { - Nd4jLong longie = *(Nd4jLong *)&value; - return longie; - } - - template <> - FORCEINLINE _CUDA_HD Nd4jLong longBytes(float16 value) { - return longBytes((float) value); - } +template <> +FORCEINLINE _CUDA_HD Nd4jLong longBytes(float value) { + int intie = *(int *)&value; + return static_cast(intie); +} - template <> - FORCEINLINE _CUDA_HD Nd4jLong longBytes(Nd4jLong value) { - return value; - } +template <> +FORCEINLINE _CUDA_HD Nd4jLong longBytes(double value) { + Nd4jLong longie = *(Nd4jLong *)&value; + return longie; +} - template <> - FORCEINLINE _CUDA_HD Nd4jLong longBytes(bfloat16 value) { - return longBytes((float) value); - } +template <> +FORCEINLINE _CUDA_HD Nd4jLong longBytes(float16 value) { + return longBytes((float)value); +} - template - FORCEINLINE _CUDA_HD Nd4jLong longBytes(T value) { - return longBytes((Nd4jLong) value); - } +template <> +FORCEINLINE _CUDA_HD Nd4jLong longBytes(Nd4jLong value) { + return value; +} +template <> +FORCEINLINE _CUDA_HD Nd4jLong longBytes(bfloat16 value) { + return longBytes((float)value); +} - void hashCode(LaunchContext *context, NDArray &array, NDArray &result); - } - } +template +FORCEINLINE _CUDA_HD Nd4jLong longBytes(T value) { + return longBytes((Nd4jLong)value); } -#endif //DEV_TESTS_HASHCODE_H +void hashCode(LaunchContext *context, NDArray &array, NDArray &result); +} // namespace helpers +} // namespace ops +} // namespace sd + +#endif // SD_HASHCODE_H diff --git a/libnd4j/include/ops/declarable/helpers/helpers.h b/libnd4j/include/ops/declarable/helpers/helpers.h index c36387e6e57b..6651757329cb 100644 --- a/libnd4j/include/ops/declarable/helpers/helpers.h +++ b/libnd4j/include/ops/declarable/helpers/helpers.h @@ -21,29 +21,28 @@ #ifndef LIBND4J_OPS_HELPERS_H #define LIBND4J_OPS_HELPERS_H -#include -#include +#include +#include #include +#include +#include +#include +#include #include #include -#include -#include -#include + #include -#include -#include +#include #ifdef __CUDACC__ #include -#include -#include #include +#include +#include #include #include #include -#include - -#endif // CUDACC +#endif // CUDACC -#endif // LIBND4J_HELPERS_H +#endif // LIBND4J_HELPERS_H diff --git a/libnd4j/include/ops/declarable/helpers/histogram.h b/libnd4j/include/ops/declarable/helpers/histogram.h index b9738ef07a49..909b521e8774 100644 --- a/libnd4j/include/ops/declarable/helpers/histogram.h +++ b/libnd4j/include/ops/declarable/helpers/histogram.h @@ -24,11 +24,12 @@ #include namespace sd { - namespace ops { - namespace helpers { - void histogramHelper(sd::LaunchContext *context, NDArray &input, NDArray &output); - } - } +namespace ops { +namespace helpers { +void histogramHelper(sd::LaunchContext *context, NDArray &input, + NDArray &output); } +} // namespace ops +} // namespace sd -#endif //DEV_TESTS_HISTOGRAM_H +#endif // SD_HISTOGRAM_H diff --git a/libnd4j/include/ops/declarable/helpers/histogramFixedWidth.h b/libnd4j/include/ops/declarable/helpers/histogramFixedWidth.h index 40ba6ffecfff..065fc001dbbd 100644 --- a/libnd4j/include/ops/declarable/helpers/histogramFixedWidth.h +++ b/libnd4j/include/ops/declarable/helpers/histogramFixedWidth.h @@ -27,11 +27,11 @@ namespace sd { namespace ops { namespace helpers { -void histogramFixedWidth(sd::LaunchContext * context, const NDArray& input, const NDArray& range, NDArray& output); +void histogramFixedWidth(sd::LaunchContext* context, const NDArray& input, + const NDArray& range, NDArray& output); - -} -} } +} // namespace ops +} // namespace sd -#endif //LIBND4J_HELPERS_HISTOGRAMFIXEDWIDTH_H +#endif // LIBND4J_HELPERS_HISTOGRAMFIXEDWIDTH_H diff --git a/libnd4j/include/ops/declarable/helpers/im2col.h b/libnd4j/include/ops/declarable/helpers/im2col.h index 6b61535f96ed..7c2cdfcaada7 100644 --- a/libnd4j/include/ops/declarable/helpers/im2col.h +++ b/libnd4j/include/ops/declarable/helpers/im2col.h @@ -27,9 +27,12 @@ namespace sd { namespace ops { namespace helpers { - ND4J_EXPORT void im2col(sd::LaunchContext & context, const NDArray& im, NDArray& col, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const NDArray& arrZeroPadVal); -} -} +SD_EXPORT void im2col(sd::LaunchContext& context, const NDArray& im, + NDArray& col, const int kH, const int kW, const int sH, + const int sW, const int pH, const int pW, const int dH, + const int dW, const NDArray& arrZeroPadVal); } +} // namespace ops +} // namespace sd -#endif //LIBND4J_HELPERS_H +#endif // LIBND4J_HELPERS_H diff --git a/libnd4j/include/ops/declarable/helpers/image_draw_bounding_boxes.h b/libnd4j/include/ops/declarable/helpers/image_draw_bounding_boxes.h index 758a02e31e34..cdb4303f4081 100644 --- a/libnd4j/include/ops/declarable/helpers/image_draw_bounding_boxes.h +++ b/libnd4j/include/ops/declarable/helpers/image_draw_bounding_boxes.h @@ -19,16 +19,17 @@ // #ifndef __IMAGE_DRAW_BOUNDING_BOXES_H_HELPERS__ #define __IMAGE_DRAW_BOUNDING_BOXES_H_HELPERS__ -#include #include +#include namespace sd { namespace ops { namespace helpers { - void drawBoundingBoxesFunctor(sd::LaunchContext * context, NDArray* images, NDArray* boxes, NDArray* colors, NDArray* output); +void drawBoundingBoxesFunctor(sd::LaunchContext* context, NDArray* images, + NDArray* boxes, NDArray* colors, NDArray* output); } -} -} +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/helpers/image_resize.h b/libnd4j/include/ops/declarable/helpers/image_resize.h index bd9e10b58200..5aac4ec47872 100644 --- a/libnd4j/include/ops/declarable/helpers/image_resize.h +++ b/libnd4j/include/ops/declarable/helpers/image_resize.h @@ -20,44 +20,53 @@ // #ifndef __IMAGE_RESIZE_HELPERS__ #define __IMAGE_RESIZE_HELPERS__ -#include #include +#include namespace sd { namespace ops { namespace helpers { - enum ImageResizeMethods { - kResizeBilinear = 0, // as java require +enum ImageResizeMethods { + kResizeBilinear = 0, // as java require kResizeNearest, - kResizeBicubic, - kResizeArea, - kResizeGaussian, - kResizeLanczos3, - kResizeLanczos5, - kResizeMitchellcubic, - kResizeFirst = kResizeBilinear, + kResizeBicubic, + kResizeArea, + kResizeGaussian,kResizeLanczos3, + kResizeLanczos5, + kResizeMitchellcubic, + kResizeFirst = kResizeBilinear, kResizeLast = kResizeMitchellcubic, - kResizeOldLast = kResizeArea - }; + kResizeOldLast =kResizeArea +}; - int resizeBilinearFunctor(sd::LaunchContext * context, NDArray const* image, int const width, int const height, - bool const alignCorners, bool const halfPixelCenter, NDArray* output); - int resizeNeighborFunctor(sd::LaunchContext * context, NDArray const* image, int const width, int const height, - bool const alignCorners, bool const halfPixelCenter, NDArray* output); - int resizeBicubicFunctor(sd::LaunchContext * context, NDArray const* image, int const width, int const height, - bool preserveAspectRatio, bool antialias, NDArray* output); - int resizeBicubicFunctorA(sd::LaunchContext * context, NDArray const* image, int const width, int const height, - bool const alignCorners, bool const halfPixelAlign, NDArray* output); - int resizeAreaFunctor(sd::LaunchContext * context, NDArray const* image, int const width, int const height, - bool const alignCorners, NDArray* output); +int resizeBilinearFunctor(sd::LaunchContext* context, NDArray const* image, + int const width, int const height, + bool const alignCorners, bool const halfPixelCenter, + NDArray* output); +int resizeNeighborFunctor(sd::LaunchContext* context, NDArray const* image, + int const width, int const height, + bool const alignCorners, bool const halfPixelCenter, + NDArray* output); +int resizeBicubicFunctor(sd::LaunchContext* context, NDArray const* image, + int const width, int const height, + bool preserveAspectRatio, bool antialias, + NDArray* output); +int resizeBicubicFunctorA(sd::LaunchContext* context, NDArray const* image, + int const width, int const height, + bool const alignCorners, bool const halfPixelAlign, + NDArray* output); +int resizeAreaFunctor(sd::LaunchContext* context, NDArray const* image, + int const width, int const height, + bool const alignCorners, NDArray* output); - int resizeFunctor(sd::LaunchContext * context, NDArray const* image, int const width, int const height, - ImageResizeMethods method, bool antialias, NDArray* output); +int resizeFunctor(sd::LaunchContext* context, NDArray const* image, + int const width, int const height, ImageResizeMethods method, + bool antialias, NDArray* output); int resizeImagesFunctor(sd::LaunchContext * context, NDArray const* image, int const width, int const height, ImageResizeMethods method, bool alignCorners, NDArray* output); -} -} -} +} // namespace helpers +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/helpers/image_suppression.h b/libnd4j/include/ops/declarable/helpers/image_suppression.h index a8d2027b8708..139a70b4cb45 100644 --- a/libnd4j/include/ops/declarable/helpers/image_suppression.h +++ b/libnd4j/include/ops/declarable/helpers/image_suppression.h @@ -19,21 +19,26 @@ // #ifndef __IMAGE_SUPPRESSION_H_HELPERS__ #define __IMAGE_SUPPRESSION_H_HELPERS__ -#include #include +#include namespace sd { namespace ops { namespace helpers { - void nonMaxSuppression(sd::LaunchContext * context, NDArray* boxes, NDArray* scales, int maxSize, - double overlapThreshold, double scoreThreshold, NDArray* output); - Nd4jLong nonMaxSuppressionV3(sd::LaunchContext * context, NDArray* boxes, NDArray* scales, int maxSize, - double overlapThreshold, double scoreThreshold, NDArray* output); - Nd4jLong nonMaxSuppressionGeneric(sd::LaunchContext* context, NDArray* boxes, NDArray* scores, int maxSize, - double overlapThreshold, double scoreThreshold, NDArray* output); +void nonMaxSuppression(sd::LaunchContext* context, NDArray* boxes, + NDArray* scales, int maxSize, double overlapThreshold, + double scoreThreshold, NDArray* output); +Nd4jLong nonMaxSuppressionV3(sd::LaunchContext* context, NDArray* boxes, + NDArray* scales, int maxSize, + double overlapThreshold, double scoreThreshold, + NDArray* output); +Nd4jLong nonMaxSuppressionGeneric(sd::LaunchContext* context, NDArray* boxes, + NDArray* scores, int maxSize, + double overlapThreshold, + double scoreThreshold, NDArray* output); -} -} -} +} // namespace helpers +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/helpers/imagesHelpers.h b/libnd4j/include/ops/declarable/helpers/imagesHelpers.h index 3a1666c7aca0..25fb3d3e8a73 100644 --- a/libnd4j/include/ops/declarable/helpers/imagesHelpers.h +++ b/libnd4j/include/ops/declarable/helpers/imagesHelpers.h @@ -16,7 +16,7 @@ // // @author Oleh Semeniv (oleg.semeniv@gmail.com) -// +// // // @author AbdelRauf (rauf@konduit.ai) // @@ -24,27 +24,34 @@ #ifndef LIBND4J_HELPERS_IMAGES_H #define LIBND4J_HELPERS_IMAGES_H -#include -#include #include +#include +#include namespace sd { namespace ops { namespace helpers { - void transformRgbGrs(sd::LaunchContext* context, const NDArray& input, NDArray& output, const int dimC); +void transformRgbGrs(sd::LaunchContext* context, const NDArray& input, + NDArray& output, const int dimC); - void transformHsvRgb(sd::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC); +void transformHsvRgb(sd::LaunchContext* context, const NDArray* input, + NDArray* output, const int dimC); - void transformRgbHsv(sd::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC); - void transformYuvRgb(sd::LaunchContext* context, const NDArray& input, NDArray& output, const int dimC); - void transformRgbYuv(sd::LaunchContext* context, const NDArray& input, NDArray& output, const int dimC); +void transformRgbHsv(sd::LaunchContext* context, const NDArray* input, + NDArray* output, const int dimC); +void transformYuvRgb(sd::LaunchContext* context, const NDArray& input, + NDArray& output, const int dimC); +void transformRgbYuv(sd::LaunchContext* context, const NDArray& input, + NDArray& output, const int dimC); - void transformYiqRgb(sd::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC); +void transformYiqRgb(sd::LaunchContext* context, const NDArray* input, + NDArray* output, const int dimC); - void transformRgbYiq(sd::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC); -} -} -} +void transformRgbYiq(sd::LaunchContext* context, const NDArray* input, + NDArray* output, const int dimC); +} // namespace helpers +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/impl/choose.cpp b/libnd4j/include/ops/declarable/helpers/impl/choose.cpp index 2f574a52edc4..5c39c63068cc 100644 --- a/libnd4j/include/ops/declarable/helpers/impl/choose.cpp +++ b/libnd4j/include/ops/declarable/helpers/impl/choose.cpp @@ -18,137 +18,137 @@ // @author sgazeos@gmail.com // -#include #include +#include #include namespace sd { namespace ops { namespace helpers { - template - static sd::NDArray* processCondition_(int mode,sd::NDArray *arg, sd::NDArray *comp, sd::NDArray& compScalar); - - template - static T processElementCondition(int mode,T d1,T d2); - - - template - sd::NDArray* processCondition_(int mode,sd::NDArray *arg, sd::NDArray *comp, sd::NDArray *output, sd::NDArray *numResult, sd::NDArray& compScalar) { - - //Convert to straight ndarray based on input - - int numResults = 0; - if(comp != nullptr) { - if (comp->isScalar()) { - //Other input for compare could be an ndarray or a secondary scalar - //for comparison -// sd::NDArray arg1 = *arg; -// sd::NDArray comp1 = *comp; - for (Nd4jLong i = 0; i < arg->lengthOf(); i++) { - T result2 = processElementCondition(mode, arg->e(i), comp->e(0)); - if(result2 > static_cast(0)) { - if (output != nullptr) - output->p(numResults, arg->e(i)); - numResults++; - } - } - } else { - // REQUIRE_TRUE(comp.isSameShape(arg)); - //Other input for compare could be an ndarray or a secondary scalar - //for comparison - sd::NDArray arg1 = *arg; - for (Nd4jLong i = 0; i < arg->lengthOf(); i++) { - T result2 = processElementCondition(mode, arg->e(i), comp->e(i)); - if(result2 > static_cast(0)) { - if (output != nullptr) - output->p(numResults, arg->e(i)); - numResults++; - } - } - } - +template +static sd::NDArray* processCondition_(int mode, sd::NDArray* arg, + sd::NDArray* comp, + sd::NDArray& compScalar); + +template +static T processElementCondition(int mode, T d1, T d2); + +template +sd::NDArray* processCondition_(int mode, sd::NDArray* arg, sd::NDArray* comp, + sd::NDArray* output, sd::NDArray* numResult, + sd::NDArray& compScalar) { + // Convert to straight ndarray based on input + + int numResults = 0; + if (comp != nullptr) { + if (comp->isScalar()) { + // Other input for compare could be an ndarray or a secondary scalar + // for comparison + // sd::NDArray arg1 = *arg; + // sd::NDArray comp1 = *comp; + for (Nd4jLong i = 0; i < arg->lengthOf(); i++) { + T result2 = processElementCondition(mode, arg->e(i), comp->e(0)); + if (result2 > static_cast(0)) { + if (output != nullptr) output->p(numResults, arg->e(i)); + numResults++; } - else { - // sd::NDArray arg1 = *arg; - //Other input for compare could be an ndarray or a secondary scalar - //for comparison - for (Nd4jLong i = 0; i < arg->lengthOf(); i++) { - T result2 = processElementCondition(mode, arg->e(i), compScalar.e(0)); - if(result2 > static_cast(0)) { - if (output != nullptr) - output->p(numResults, arg->e(i)); - numResults++; - } - } + } + } else { + // REQUIRE_TRUE(comp.isSameShape(arg)); + // Other input for compare could be an ndarray or a secondary scalar + // for comparison + sd::NDArray arg1 = *arg; + for (Nd4jLong i = 0; i < arg->lengthOf(); i++) { + T result2 = processElementCondition(mode, arg->e(i), comp->e(i)); + if (result2 > static_cast(0)) { + if (output != nullptr) output->p(numResults, arg->e(i)); + numResults++; } - - if(numResult != nullptr) - numResult->p(0,numResults); - - return output; + } } - sd::NDArray* processCondition(sd::LaunchContext * context, int mode,sd::NDArray *arg, sd::NDArray *comp, sd::NDArray *output, sd::NDArray *numResult, sd::NDArray& compScalar) { - arg->syncToHost(); + } else { + // sd::NDArray arg1 = *arg; + // Other input for compare could be an ndarray or a secondary scalar + // for comparison + for (Nd4jLong i = 0; i < arg->lengthOf(); i++) { + T result2 = + processElementCondition(mode, arg->e(i), compScalar.e(0)); + if (result2 > static_cast(0)) { + if (output != nullptr) output->p(numResults, arg->e(i)); + numResults++; + } + } + } - if (comp != nullptr) - comp->syncToHost(); + if (numResult != nullptr) numResult->p(0, numResults); - if (output != nullptr) - output->syncToHost(); + return output; +} - if (numResult != nullptr) - numResult->syncToHost(); +sd::NDArray* processCondition(sd::LaunchContext* context, int mode, + sd::NDArray* arg, sd::NDArray* comp, + sd::NDArray* output, sd::NDArray* numResult, + sd::NDArray& compScalar) { + arg->syncToHost(); - compScalar.syncToHost(); + if (comp != nullptr) comp->syncToHost(); - BUILD_SINGLE_SELECTOR(arg->dataType(), return processCondition_, (mode, arg, comp, output, numResult, compScalar), FLOAT_TYPES); + if (output != nullptr) output->syncToHost(); - arg->syncToDevice(); + if (numResult != nullptr) numResult->syncToHost(); - if (comp != nullptr) - comp->syncToDevice(); + compScalar.syncToHost(); - if (output != nullptr) - output->syncToDevice(); + BUILD_SINGLE_SELECTOR(arg->dataType(), return processCondition_, + (mode, arg, comp, output, numResult, compScalar), + FLOAT_TYPES); - if (numResult != nullptr) - numResult->syncToDevice(); - - compScalar.syncToDevice(); + arg->syncToDevice(); - } - BUILD_SINGLE_TEMPLATE(template NDArray* processCondition_, (int mode,sd::NDArray *arg, sd::NDArray *comp, sd::NDArray *output, sd::NDArray *numResult, sd::NDArray& compScalar), FLOAT_TYPES); + if (comp != nullptr) comp->syncToDevice(); - template - T processElementCondition(int mode,T d1,T d2) { - T modePointer = (T ) mode; - T input[3] = {d2, (T) EPS, (T) mode}; - T res = simdOps::MatchCondition::op(d1, input); - return res; + if (output != nullptr) output->syncToDevice(); - } + if (numResult != nullptr) numResult->syncToDevice(); - void chooseFunctorArray(sd::LaunchContext * context, NDArray* arg, NDArray* comp, int mode, NDArray* result, NDArray* numResults) { - if(arg->isScalar() || comp->isScalar()) { - if(arg->isScalar()) { - processCondition(context, mode,comp,nullptr,result,numResults, *arg); - } - else { - processCondition(context, mode,arg,nullptr,result,numResults, *comp); - } - } - else { - auto zero = NDArrayFactory::create(0); - processCondition(context, mode,arg,comp,result,numResults, zero); - } - } + compScalar.syncToDevice(); +} +BUILD_SINGLE_TEMPLATE(template NDArray* processCondition_, + (int mode, sd::NDArray* arg, sd::NDArray* comp, + sd::NDArray* output, sd::NDArray* numResult, + sd::NDArray& compScalar), + FLOAT_TYPES); + +template +T processElementCondition(int mode, T d1, T d2) { + T modePointer = (T)mode; + T input[3] = {d2, (T)EPS, (T)mode}; + T res = simdOps::MatchCondition::op(d1, input); + return res; +} - void chooseFunctorScalar(sd::LaunchContext * context, NDArray* arg, double scalar, int mode, NDArray* result, NDArray* numResults) { - auto scalarA = NDArrayFactory::create(scalar); - processCondition(context, mode, arg, nullptr,result, numResults, scalarA); +void chooseFunctorArray(sd::LaunchContext* context, NDArray* arg, NDArray* comp, + int mode, NDArray* result, NDArray* numResults) { + if (arg->isScalar() || comp->isScalar()) { + if (arg->isScalar()) { + processCondition(context, mode, comp, nullptr, result, numResults, *arg); + } else { + processCondition(context, mode, arg, nullptr, result, numResults, *comp); } - -} + } else { + auto zero = NDArrayFactory::create(0); + processCondition(context, mode, arg, comp, result, numResults, zero); + } } + +void chooseFunctorScalar(sd::LaunchContext* context, NDArray* arg, + double scalar, int mode, NDArray* result, + NDArray* numResults) { + auto scalarA = NDArrayFactory::create(scalar); + processCondition(context, mode, arg, nullptr, result, numResults, scalarA); } + +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/impl/gru.cpp b/libnd4j/include/ops/declarable/helpers/impl/gru.cpp index 277188428cb1..98c5f191ab5e 100644 --- a/libnd4j/include/ops/declarable/helpers/impl/gru.cpp +++ b/libnd4j/include/ops/declarable/helpers/impl/gru.cpp @@ -7,9 +7,9 @@ * https://www.apache.org/licenses/LICENSE-2.0. * * Unless required by applicable law or agreed to in writing, software - * dnIntributed under the License nIn dnIntributed on an "AS nIn" BASnIn, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permnInsions and limitations + * dnIntributed under the License nIn dnIntributed on an "AS nIn" BASnIn, + *WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See + *the License for the specific language governing permnInsions and limitations * under the License. * * SPDX-License-Identifier: Apache-2.0 @@ -21,526 +21,554 @@ // implementation of gated Recurrent Unit cell // (cf. https://arxiv.org/abs/1406.1078). -// Kyunghyun Cho, Bart van Merrienboer, Caglar Gulcehre, Dzmitry Bahdanau, Fethi Bougares, Holger Schwenk, Yoshua Bengio -// "Learning Phrase Representations using RNN Encoder-Decoder for StatnIntical Machine Translation" +// Kyunghyun Cho, Bart van Merrienboer, Caglar Gulcehre, Dzmitry Bahdanau, Fethi +// Bougares, Holger Schwenk, Yoshua Bengio "Learning Phrase Representations +// using RNN Encoder-Decoder for StatnIntical Machine Translation" - -#include +#include #include +#include #include -#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { ////////////////////////////////////////////////////////////////////////// -void gruCell(sd::LaunchContext * context, const NDArray* x, const NDArray* hI, const NDArray* W, const NDArray* Wc, - const NDArray* b, const NDArray* bc, - NDArray* r, NDArray* u, NDArray* c, NDArray* h) { - - //Inputs: - // x input [bS, nIn], nIn - input size - // hI previous cell output [bS, nOut], that is at previous time step t-1, nOut - number of units - // W RU weights - [nIn+nOut, 2*nOut] - reset and update gates - // Wc C weights - [nIn+nOut, nOut] - cell gate - // b r and u biases, [2*nOut] - reset and update gates - // bc c biases, [nOut] - cell gate - - //Outputs: - // r Reset gate output [bS, nOut] - // u Update gate output [bS, nOut] - // c Cell gate output [bS, nOut] - // h current cell output [bS, nOut] - - /***************************************************************************************/ - /************************ THIS IS NOT OPTIMAZED CODE ***********************************/ - /** however it is more math-friendly and convenient for backprop formulas derivation) **/ - - const int bS = x->sizeAt(0); - const int nIn = x->sizeAt(1); - const int nOut = hI->sizeAt(1); - - NDArray Wrx = (*W)({0,nIn, 0,nOut}); // [nIn, nOut] - NDArray Wux = (*W)({0,nIn, nOut,2*nOut}); // [nIn, nOut] - NDArray Wrh = (*W)({nIn,nIn+nOut, 0,nOut}); // [nOut, nOut] - NDArray Wuh = (*W)({nIn,nIn+nOut, nOut,2*nOut}); // [nOut, nOut] - - NDArray Wcx = (*Wc)({0,nIn, 0,0}); // reset cell weights [nIn, nOut] - NDArray Wch = (*Wc)({nIn,nIn+nOut, 0,0}); // updates cell weights [nOut, nOut] - - NDArray br = (*b)({0, nOut}); // [nOut] - NDArray bu = (*b)({nOut, 2*nOut}); // [nOut] - - // × means matrix multipication - // * means element-wise product or so called Hadamard product - - // reset gate - r->assign(mmul(*x, Wrx) + mmul(*hI, Wrh) + br); // [bS, nIn] × [nIn, nOut] + [bS, nOut] × [nOut, nOut] + [nOut] = [bS, nOut] - r->applyTransform(transform::Sigmoid, *r); - - // update gate - u->assign(mmul(*x, Wux) + mmul(*hI, Wuh) + bu); // [bS, nIn] × [nIn, nOut] + [bS, nOut] × [nOut, nOut] + [nOut] = [bS, nOut] - u->applyTransform(transform::Sigmoid, *u); - - // cell gate c = activation(x × Wcx + (r * hlast) × Wch + bc) - c->assign(mmul(*x, Wcx) + mmul(*r * *hI, Wch) + *bc); // [bS, nIn] × [nIn, nOut] + [bS, nOut] × [nOut, nOut] + [nOut] = [bS, nOut] - c->applyTransform(transform::Tanh, *c); - - // cell output - h->assign(*u * *hI + (1.f - *u) * *c); +void gruCell(sd::LaunchContext* context, const NDArray* x, const NDArray* hI, + const NDArray* W, const NDArray* Wc, const NDArray* b, + const NDArray* bc, NDArray* r, NDArray* u, NDArray* c, + NDArray* h) { + // Inputs: + // x input [bS, nIn], nIn - input size + // hI previous cell output [bS, nOut], that is at previous time step + // t-1, nOut - number of units W RU weights - [nIn+nOut, 2*nOut] - + // reset and update gates Wc C weights - [nIn+nOut, nOut] - cell gate b + // r and u biases, [2*nOut] - reset and update gates bc c biases, [nOut] + // - cell gate + + // Outputs: + // r Reset gate output [bS, nOut] + // u Update gate output [bS, nOut] + // c Cell gate output [bS, nOut] + // h current cell output [bS, nOut] + + /***************************************************************************************/ + /************************ THIS IS NOT OPTIMAZED CODE + * ***********************************/ + /** however it is more math-friendly and convenient for backprop formulas + * derivation) **/ + + const int bS = x->sizeAt(0); + const int nIn = x->sizeAt(1); + const int nOut = hI->sizeAt(1); + + NDArray Wrx = (*W)({0, nIn, 0, nOut}); // [nIn, nOut] + NDArray Wux = (*W)({0, nIn, nOut, 2 * nOut}); // [nIn, nOut] + NDArray Wrh = (*W)({nIn, nIn + nOut, 0, nOut}); // [nOut, nOut] + NDArray Wuh = (*W)({nIn, nIn + nOut, nOut, 2 * nOut}); // [nOut, nOut] + + NDArray Wcx = (*Wc)({0, nIn, 0, 0}); // reset cell weights [nIn, nOut] + NDArray Wch = + (*Wc)({nIn, nIn + nOut, 0, 0}); // updates cell weights [nOut, nOut] + + NDArray br = (*b)({0, nOut}); // [nOut] + NDArray bu = (*b)({nOut, 2 * nOut}); // [nOut] + + // × means matrix multipication + // * means element-wise product or so called Hadamard product + + // reset gate + r->assign(mmul(*x, Wrx) + mmul(*hI, Wrh) + + br); // [bS, nIn] × [nIn, nOut] + [bS, nOut] × [nOut, nOut] + + // [nOut] = [bS, nOut] + r->applyTransform(transform::Sigmoid, *r); + + // update gate + u->assign(mmul(*x, Wux) + mmul(*hI, Wuh) + + bu); // [bS, nIn] × [nIn, nOut] + [bS, nOut] × [nOut, nOut] + + // [nOut] = [bS, nOut] + u->applyTransform(transform::Sigmoid, *u); + + // cell gate c = activation(x × Wcx + (r * hlast) × Wch + bc) + c->assign(mmul(*x, Wcx) + mmul(*r * *hI, Wch) + + *bc); // [bS, nIn] × [nIn, nOut] + [bS, nOut] × [nOut, nOut] + + // [nOut] = [bS, nOut] + c->applyTransform(transform::Tanh, *c); + + // cell output + h->assign(*u * *hI + (1.f - *u) * *c); } ////////////////////////////////////////////////////////////////////////// -void gruCell(sd::LaunchContext * context, const NDArray* x, const NDArray* hI, const NDArray* Wx, const NDArray* Wh, const NDArray* b, +void gruCell(sd::LaunchContext* context, const NDArray* x, const NDArray* hI, + const NDArray* Wx, const NDArray* Wh, const NDArray* b, NDArray* gates, NDArray* h) { - - //Inputs: - // x input [bS, nIn] - // hI previous cell output [bS, nOut], that is at previous time step t-1 - // Wx weights for x - [nIn, 3*nOut] - // Wh weights for h - [nOut, 3*nOut] - // b biases [3*nOut] - - // 3*nOut means following sequence: reset, update, cell - - //Outputs: - // gates [bS, 3*nOut] = reset gate [bS, nOut] + update gate [bS, nOut] + cell gate [bS, nOut] - // h current cell output [bS, nOut] - - // formulas: - // zr = x × Wxr + hI × Whr + br - // zu = x × Wxu + hI × Whu + bu - // r = sigmoid(zr) - // u = sigmoid(zu) - // zc = x × Wxc + (r * hI) × Whc + bc - // c = tanh(zc) - // h = (1-u)*c + u*hI - - const int bS = x->sizeAt(0); - const int nIn = x->sizeAt(1); - const int nOut = hI->sizeAt(1); - - NDArray temp = gates->ulike(); - MmulHelper::mmul(x, Wx, &temp); // [bS, nIn] × [nIn, 3*nOut] = [bS, 3*nOut] - temp += *b; - - MmulHelper::mmul(hI, Wh, gates); // [bS, nOut] × [nOut, 3*nOut] = [bS, 3*nOut] - - NDArray ru = (*gates)({0,0, 0,2*nOut}); // [bS, 2*nOut] - - NDArray r = (*gates)({0,0, 0,nOut}); // [bS, nOut] - NDArray u = (*gates)({0,0, nOut,2*nOut}); // [bS, nOut] - NDArray c = (*gates)({0,0, 2*nOut,3*nOut}); // [bS, nOut] - - // reset and update gates - ru += temp({0,0, 0,2*nOut}); - ru.applyTransform(transform::Sigmoid, ru); - - // cell gate - c.assign(c*r + temp({0,0, 2*nOut, 3*nOut})); - c.applyTransform(transform::Tanh, c); - - // cell output - h->assign(u * *hI + (1.f - u) * c); + // Inputs: + // x input [bS, nIn] + // hI previous cell output [bS, nOut], that is at previous time step + // t-1 Wx weights for x - [nIn, 3*nOut] Wh weights for h - [nOut, + // 3*nOut] b biases [3*nOut] + + // 3*nOut means following sequence: reset, update, cell + + // Outputs: + // gates [bS, 3*nOut] = reset gate [bS, nOut] + update gate [bS, nOut] + + // cell gate [bS, nOut] h current cell output [bS, nOut] + + // formulas: + // zr = x × Wxr + hI × Whr + br + // zu = x × Wxu + hI × Whu + bu + // r = sigmoid(zr) + // u = sigmoid(zu) + // zc = x × Wxc + (r * hI) × Whc + bc + // c = tanh(zc) + // h = (1-u)*c + u*hI + + const int bS = x->sizeAt(0); + const int nIn = x->sizeAt(1); + const int nOut = hI->sizeAt(1); + + NDArray temp = gates->ulike(); + MmulHelper::mmul(x, Wx, &temp); // [bS, nIn] × [nIn, 3*nOut] = [bS, 3*nOut] + temp += *b; + + MmulHelper::mmul(hI, Wh, + gates); // [bS, nOut] × [nOut, 3*nOut] = [bS, 3*nOut] + + NDArray ru = (*gates)({0, 0, 0, 2 * nOut}); // [bS, 2*nOut] + + NDArray r = (*gates)({0, 0, 0, nOut}); // [bS, nOut] + NDArray u = (*gates)({0, 0, nOut, 2 * nOut}); // [bS, nOut] + NDArray c = (*gates)({0, 0, 2 * nOut, 3 * nOut}); // [bS, nOut] + + // reset and update gates + ru += temp({0, 0, 0, 2 * nOut}); + ru.applyTransform(transform::Sigmoid, ru); + + // cell gate + c.assign(c * r + temp({0, 0, 2 * nOut, 3 * nOut})); + c.applyTransform(transform::Tanh, c); + + // cell output + h->assign(u * *hI + (1.f - u) * c); } ////////////////////////////////////////////////////////////////////////// -void gruTimeLoop(sd::LaunchContext * context, const NDArray* x, const NDArray* hI, const NDArray* Wx, const NDArray* Wh, const NDArray* b, NDArray* h) { - - // sL means time steps - - // x input [sL, bS, nIn] - // hI initial cell output (at time step = 0) [bS, nOut] - // Wx input-to-hidden weights, [nIn, 3*nOut] - // Wh hidden-to-hidden weights, [nOut, 3*nOut] - // b biases, [3*nOut] - - // h cell outputs at each time step [sL, bS, nOut] - - const int sL = x->sizeAt(0); - const int bS = x->sizeAt(1); - const int nOut = hI->sizeAt(1); - - NDArray gates(h->ordering(), {bS, 3*nOut}, h->dataType(), context); - - auto xSet = x->allTensorsAlongDimension({1,2}); // sub-arrays with shape [bS, nIn] - auto hSet = h->allTensorsAlongDimension({1,2}); // sub-arrays with shape [bS, nOut] - - // time loop - for (int t = 0; t < sL; ++t) - gruCell(context, xSet.at(t), t == 0 ? hI : hSet.at(t-1), Wx, Wh, b, &gates, hSet.at(t)); +void gruTimeLoop(sd::LaunchContext* context, const NDArray* x, + const NDArray* hI, const NDArray* Wx, const NDArray* Wh, + const NDArray* b, NDArray* h) { + // sL means time steps + + // x input [sL, bS, nIn] + // hI initial cell output (at time step = 0) [bS, nOut] + // Wx input-to-hidden weights, [nIn, 3*nOut] + // Wh hidden-to-hidden weights, [nOut, 3*nOut] + // b biases, [3*nOut] + + // h cell outputs at each time step [sL, bS, nOut] + + const int sL = x->sizeAt(0); + const int bS = x->sizeAt(1); + const int nOut = hI->sizeAt(1); + + NDArray gates(h->ordering(), {bS, 3 * nOut}, h->dataType(), context); + + auto xSet = + x->allTensorsAlongDimension({1, 2}); // sub-arrays with shape [bS, nIn] + auto hSet = + h->allTensorsAlongDimension({1, 2}); // sub-arrays with shape [bS, nOut] + + // time loop + for (int t = 0; t < sL; ++t) + gruCell(context, &xSet.at(t), t == 0 ? hI : &hSet.at(t - 1), Wx, Wh, b, + &gates, &hSet.at(t)); } ////////////////////////////////////////////////////////////////////////// -void gruCellBp(sd::LaunchContext* context, - const NDArray* x, const NDArray* hLast, - const NDArray* W, const NDArray* Wc, const NDArray* b, const NDArray* bc, - const NDArray* dLdr, const NDArray* dLdu, const NDArray* dLdc, const NDArray* dLdh, - NDArray* dLdx, NDArray* dLdhLast, - NDArray* dLdW, NDArray* dLdWc, - NDArray* dLdb, NDArray* dLdbc) { - - //Inputs: - // x input [bS, iS] - // hLast previous cell output [bS, nU], that is at previous time step t-1 - // W weights - [iS+nU, 2*nU] - reset and update gates - // Wc C weights - [iS+nU, nU] - cell gate - // b r and u biases, [2*nU] - reset and update gates - // bc c biases, [nU] - cell gate - // dLdr gradient wrt reset gate, [bS, nU] - // dLdu gradient wrt update gate, [bS, nU] - // dLdc gradient wrt cell state, [bS, nU] - // dLdh gradient wrt current cell output, [bS, nU] - - //Outputs: - // dLdx gradient wrt x, [bS, iS], - // dLdhLast gradient wrt hLast, [bS, nU] - // dLdW gradient wrt W, [iS+nU, 2*nU] - // dLdWc gradient wrt Wc, [iS+nU, nU] - // dLdb gradient wrt bru [2*nU] - // dLdbc gradient wrt bc [nU] - - // * means element-wise product or so called Hadamard product - // × means matrix multiplication - - /************************************************************************************************/ - /******************************* THIS IS NOT OPTIMAZED CODE *************************************/ - /*** aim is to have math-readable code in order to keep track of backprop formulas derivation ***/ - - const int bS = x->sizeAt(0); - const int iS = x->sizeAt(1); - const int nU = hLast->sizeAt(1); - - NDArray xT = x->transpose(); // [iS, bS] - NDArray hLastT = hLast->transpose(); // [nU, bS] - - NDArray Wrx = (*W)({0,iS, 0,nU}); // [iS, nU] - NDArray Wux = (*W)({0,iS, nU,2*nU}); // [iS, nU] - NDArray Wrh = (*W)({iS,iS+nU, 0,nU}); // [nU, nU] - NDArray Wuh = (*W)({iS,iS+nU, nU,2*nU}); // [nU, nU] - - NDArray Wcx = (*Wc)({0,iS, 0,0}); // reset cell weights [iS, nU] - NDArray Wch = (*Wc)({iS,iS+nU, 0,0}); // updates cell weights [nU, nU] - - NDArray br = (*b)({0, nU}); // [nU] - NDArray bu = (*b)({nU, 2*nU}); // [nU] - - NDArray WrxT = Wrx.transpose(); // [nU, iS] - NDArray WuxT = Wux.transpose(); // [nU, iS] - NDArray WrhT = Wrh.transpose(); // [nU, nU] - NDArray WuhT = Wuh.transpose(); // [nU, nU] - - NDArray WcxT = Wcx.transpose(); // [nU, iS] - NDArray WchT = Wch.transpose(); // [nU, nU] - - NDArray dLdWrx = (*dLdW)({0,iS, 0,nU}); // [iS, nU] - NDArray dLdWux = (*dLdW)({0,iS, nU,2*nU}); // [iS, nU] - NDArray dLdWrh = (*dLdW)({iS,iS+nU, 0,nU}); // [nU, nU] - NDArray dLdWuh = (*dLdW)({iS,iS+nU, nU,2*nU}); // [nU, nU] - - NDArray dLdWcx = (*dLdWc)({0,iS, 0,0}); // [iS, nU] - NDArray dLdWch = (*dLdWc)({iS,iS+nU, 0,0}); // [nU, nU] - - NDArray dLdbr = (*dLdb)({0, nU}); // [nU] - NDArray dLdbu = (*dLdb)({nU, 2*nU}); // [nU] - - - // ***** feed forward step ***** // - - // reset gate - NDArray r = mmul(*x, Wrx) + mmul(*hLast, Wrh) + br; // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU] - r.applyTransform(transform::Sigmoid, r); - - // update gate - NDArray u = mmul(*x, Wux) + mmul(*hLast, Wuh) + bu; // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU] - u.applyTransform(transform::Sigmoid, u); - - // cell gate c = activation(x×Wcx + (r*hlast)×Wcu + bc) - NDArray c = mmul(*x, Wcx) + mmul(r * *hLast, Wch) + *bc; // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU] - c.applyTransform(transform::Tanh, c); - - // h = (1 - u) * c + u * hPrev - - - // ***** back prop step ***** // - - // notations: - // Zr = x × Wrx + hLast × Wrh + br - // Zu = x × Wux + hLast × Wuh + bu - // Sr = sigmoid(Zr) - // Su = sigmoid(Zu) - // Zc = x × Wcx + (r * hlast) × Wch + bc - - - // dLdx = dLdh * dhdx = dLdh * (dhdu * dudx + dhdc * dcdx) = (dLdh * dhdu) * dudx + (dLdh * dhdc) * dcdx = dLdu * dudx + dLdc * dcdx - // = dLdx_u + dLdx_c - // dLdx_u = dLdu * dudx = dLdu * dudZu * dZudx = |dZudx = ... × WuxT| = (dLdu * dudZu) × WuxT - // dLdx_c = dLdc * dcdx = dLdc * dcdZc * (dZcdx + dZcdr * drdx) = dLdc * dcdZc * dZcdx + dLdc * dcdZc * dZcdr * drdx = dLdx_c0 + dLdx_c1 - // dLdx_c0 = dLdc * dcdZc * dZcdx = |dZcdx = ... × WcxT| = (dLdc * dcdZc) × WcxT - // dZcdr = (... * hLast) × WchT - // dLdc * dcdZc * dZcdr = dLdr = (dLdc * dcdZc * hLast) × WchT - // drdx = drdZr * dZrdx - // dZrdx = ... × WrxT - // dLdx_c1 = dLdc * dcdZc * dZcdr * drdx = dLdr * drdx = (dLdr * drdZr) × WrxT - // finally dLdx = dLdx_u + dLdx_c0 + dLdx_c1 = (dLdu * dudZu) × WuxT + (dLdc * dcdZc) × WcxT + (dLdr * drdZr) × WrxT - - - // dLdhLast = dLdh * (dhdhLast + dhdu * dudhLast + dhdc * dcdhLast) = dLdh * dhdhLast + dLdu * dudhLast + dLdc * dcdhLast - // = dLdhLast_h + dLdhLast_u + dLdhLast_c - // dLdhLast_h = dLdh * dhdhLas = dLdh * u - // dLdhLast_u = dLdu * dudhLast = |dudhLast = dudZu * dZudhLast , dZudhLast = ... × WuhT| = (dLdu * dudZu) × WuhT - // dLdhLast_c = dLdc * dcdhLast = dLdc * (dcdZc * dZcdhLast + dcdZc * dZcdr * drdhLast) = - // = dLdc * dcdZc * dZcdhLast + dLdc * dcdZc * dZcdr * drdhLast = - // = dLdc * dcdZc * dZcdhLast + dLdr * drdhLast = dLdhLast_c0 + dLdhLast_c1 - // dLdhLast_c0 = dLdc * dcdZc * dZcdhLast = |dZcdhLast = (... * r) × WchT| = (dLdc * dcdZc * r) × WchT - // dLdhLast_c1 = dLdr * drdhLast = |drdhLast = drdZr * dZrdhLast, dZrdhLast = ... × WrhT| = (dLdr * drdZr) × WrhT - // finally dLdhLast = dLdhLast_h + dLdhLast_u + dLdhLast_c0 + dLdhLast_c1 = - // = dLdh * u + (dLdu * dudZu) × WuhT + (dLdc * dcdZc * r) × WchT + (dLdr * drdZr) × WrhT - - - // dLdWrx = dLdh * dhdWrx = (dLdh * dhdc) * dcdWrx = dLdc * dcdZc * dZcdWrx = dLdc * dcdZc * dZcdr * drdWrx = - // = dLdc * dcdZc * dZcdr * drdZr * dZrdWrx = dLdr * drdZr * dZrdWrx - // dZrdWrx = xT × ... - // finally dLdWrx = xT × (dLdr * drdZr) - - - // dLdWrh = dLdh * dhdWrh = (dLdh * dhdc) * dcdWrh = dLdc * dcdZc * dZcdWrh = dLdc * dcdZc * dZcdr * drdWrh = - // = dLdc * dcdZc * dZcdr * drdZr * dZrdWrh = dLdr * drdZr * dZrdWrh - // dZrdWrh = hLastT × ... - // finally dLdWrh = hLastT × (dLdr * drdZr) - - - // dLdWux = dLdh * dhdWux = (dLdh * dhdu) * dudWux = dLdu * dudZu * dZudWux - // dZudWux = xT × ... - // dLdu * dudZu * dZudWux = xT × (dLdu * dudZu) - - - // dLdWuh = dLdh * dhdWuh = (dLdh * dhdu) * dudWuh = dLdh * dhdu * dudZu * dZudWuh = dLdu * dudZu * dZudWuh - // dZudWuh = hLastT × ... - // finally dLdWuh = hLastT × (dLdu * dudZu) - - - // dLdWcx = dLdh * dhdWcx = dLdh * dhdc * dcdWcx = (dLdh * dhdc) * dcdZc * dZcdWcx = dLdc * dcdZc * dZcdWcx - // dZcdWcx = xT × ... - // finally dLdWcx = xT × (dLdc * dcdZc) - - - // dLdWch = dLdh * dhdWch = dLdh * dhdc * dcdWch = (dLdh * dhdc) * dcdZc * dZcdWch = dLdc * dcdZc * dZcdWch - // dZcdWch = (r*hLast)^T × ... - // finally dLdWch = (r*hLast)^T × (dLdc * dcdZc) - - - // dLdbr = dLdh * dhdbr = (dLdh * dhdc) * dcdbr = dLdc * dcdbr = dLdc * dcdZc * dZcdbr = dLdc * dcdZc * dZcdr * drdbr = - // = dLdr * drdZr * dZrdbr - // dZrdbr = 1 - // finally dLdbr = dLdr * drdZr - - - // dLdbu = dLdh * dhdbu = (dLdh * dhdu) * dudbu = dLdu * dudZu * dZudbu - // dZudbu = 1 - // finally dLdbu = dLdu * dudZu - - - // dLdbc = dLdh * dhdbc = (dLdh * dhdc) * dcdbc = dLdc * dcdZc * dZcdbc - // dZcdbc = 1 - // finally dLdbc = dLdc * dcdZc - - NDArray dhdc = 1.f - u; // [bS, nU] - NDArray dhdu = *hLast - c; // [bS, nU] - NDArray dudZu = u * dhdc; // [bS, nU] - NDArray drdZr = r * (1.f - r); // [bS, nU] - NDArray dcdZc = 1.f - c * c; // [bS, nU] - NDArray dLdZc = *dLdc * dcdZc; // [bS, nU] - NDArray dLdZu = *dLdu * dudZu; // [bS, nU] - NDArray dLdZr = *dLdr * drdZr; // [bS, nU] - - // NDArray dLdc = *dLdh * dhdc; // [bS, nU] - // NDArray dLdu = *dLdh * dhdu; // [bS, nU] - // NDArray dLdr = mmul(dLdc * dcdZc * *hLast, WchT); // [bS, nU] - - dLdx->assign(mmul(dLdZu, WuxT) + mmul(dLdZc, WcxT) + mmul(dLdZr, WrxT)); // [bS, iS] - - dLdhLast->assign(*dLdh * u + mmul(dLdZu, WuhT) + mmul(dLdZc * r, WchT) + mmul(dLdZr, WrhT)); // [bS, nU] - - dLdWrx.assign(mmul(xT, dLdZr)); // [iS, bS] × [bS, nU] = [iS, nU] - dLdWrh.assign(mmul(hLastT, dLdZr)); // [nU, bS] × [bS, nU] = [nU, nU] - dLdWux.assign(mmul(xT, dLdZu)); // [iS, bS] × [bS, nU] = [iS, nU] - dLdWuh.assign(mmul(hLastT, dLdZu)); // [nU, bS] × [bS, nU] = [nU, nU] - - dLdWcx.assign(mmul(xT, dLdZc)); // [iS, bS] × [bS, nU] = [iS, nU] - dLdWch.assign(mmul((r * *hLast).transpose(), dLdZc)); // [nU, bS] × [bS, nU] = [nU, nU] - - dLdbr.assign(dLdZr.reduceAlongDimension(reduce::Sum, {0})); // [nU] - dLdbu.assign(dLdZu.reduceAlongDimension(reduce::Sum, {0})); // [nU] - - dLdbc->assign(dLdZc.reduceAlongDimension(reduce::Sum, {0})); // [nU] +void gruCellBp(sd::LaunchContext* context, const NDArray* x, + const NDArray* hLast, const NDArray* W, const NDArray* Wc, + const NDArray* b, const NDArray* bc, const NDArray* dLdr, + const NDArray* dLdu, const NDArray* dLdc, const NDArray* dLdh, + NDArray* dLdx, NDArray* dLdhLast, NDArray* dLdW, NDArray* dLdWc, + NDArray* dLdb, NDArray* dLdbc) { + // Inputs: + // x input [bS, iS] + // hLast previous cell output [bS, nU], that is at previous time + // step t-1 W weights - [iS+nU, 2*nU] - reset and update gates Wc + // C weights - [iS+nU, nU] - cell gate b r and u biases, [2*nU] - + // reset and update gates bc c biases, [nU] - cell gate dLdr + // gradient wrt reset gate, [bS, nU] dLdu gradient wrt update gate, + // [bS, nU] dLdc gradient wrt cell state, [bS, nU] dLdh gradient wrt + // current cell output, [bS, nU] + + // Outputs: + // dLdx gradient wrt x, [bS, iS], + // dLdhLast gradient wrt hLast, [bS, nU] + // dLdW gradient wrt W, [iS+nU, 2*nU] + // dLdWc gradient wrt Wc, [iS+nU, nU] + // dLdb gradient wrt bru [2*nU] + // dLdbc gradient wrt bc [nU] + + // * means element-wise product or so called Hadamard product + // × means matrix multiplication + + /************************************************************************************************/ + /******************************* THIS IS NOT OPTIMAZED CODE + * *************************************/ + /*** aim is to have math-readable code in order to keep track of backprop + * formulas derivation ***/ + + const int bS = x->sizeAt(0); + const int iS = x->sizeAt(1); + const int nU = hLast->sizeAt(1); + + NDArray xT = x->transpose(); // [iS, bS] + NDArray hLastT = hLast->transpose(); // [nU, bS] + + NDArray Wrx = (*W)({0, iS, 0, nU}); // [iS, nU] + NDArray Wux = (*W)({0, iS, nU, 2 * nU}); // [iS, nU] + NDArray Wrh = (*W)({iS, iS + nU, 0, nU}); // [nU, nU] + NDArray Wuh = (*W)({iS, iS + nU, nU, 2 * nU}); // [nU, nU] + + NDArray Wcx = (*Wc)({0, iS, 0, 0}); // reset cell weights [iS, nU] + NDArray Wch = (*Wc)({iS, iS + nU, 0, 0}); // updates cell weights [nU, nU] + + NDArray br = (*b)({0, nU}); // [nU] + NDArray bu = (*b)({nU, 2 * nU}); // [nU] + + NDArray WrxT = Wrx.transpose(); // [nU, iS] + NDArray WuxT = Wux.transpose(); // [nU, iS] + NDArray WrhT = Wrh.transpose(); // [nU, nU] + NDArray WuhT = Wuh.transpose(); // [nU, nU] + + NDArray WcxT = Wcx.transpose(); // [nU, iS] + NDArray WchT = Wch.transpose(); // [nU, nU] + + NDArray dLdWrx = (*dLdW)({0, iS, 0, nU}); // [iS, nU] + NDArray dLdWux = (*dLdW)({0, iS, nU, 2 * nU}); // [iS, nU] + NDArray dLdWrh = (*dLdW)({iS, iS + nU, 0, nU}); // [nU, nU] + NDArray dLdWuh = (*dLdW)({iS, iS + nU, nU, 2 * nU}); // [nU, nU] + + NDArray dLdWcx = (*dLdWc)({0, iS, 0, 0}); // [iS, nU] + NDArray dLdWch = (*dLdWc)({iS, iS + nU, 0, 0}); // [nU, nU] + + NDArray dLdbr = (*dLdb)({0, nU}); // [nU] + NDArray dLdbu = (*dLdb)({nU, 2 * nU}); // [nU] + + // ***** feed forward step ***** // + + // reset gate + NDArray r = + mmul(*x, Wrx) + mmul(*hLast, Wrh) + + br; // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU] + r.applyTransform(transform::Sigmoid, r); + + // update gate + NDArray u = + mmul(*x, Wux) + mmul(*hLast, Wuh) + + bu; // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU] + u.applyTransform(transform::Sigmoid, u); + + // cell gate c = activation(x×Wcx + (r*hlast)×Wcu + bc) + NDArray c = + mmul(*x, Wcx) + mmul(r * *hLast, Wch) + + *bc; // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU] + c.applyTransform(transform::Tanh, c); + + // h = (1 - u) * c + u * hPrev + + // ***** back prop step ***** // + + // notations: + // Zr = x × Wrx + hLast × Wrh + br + // Zu = x × Wux + hLast × Wuh + bu + // Sr = sigmoid(Zr) + // Su = sigmoid(Zu) + // Zc = x × Wcx + (r * hlast) × Wch + bc + + // dLdx = dLdh * dhdx = dLdh * (dhdu * dudx + dhdc * dcdx) = (dLdh * dhdu) * + // dudx + (dLdh * dhdc) * dcdx = dLdu * dudx + dLdc * dcdx + // = dLdx_u + dLdx_c + // dLdx_u = dLdu * dudx = dLdu * dudZu * dZudx = |dZudx = ... × WuxT| = (dLdu + // * dudZu) × WuxT dLdx_c = dLdc * dcdx = dLdc * dcdZc * (dZcdx + dZcdr * + // drdx) = dLdc * dcdZc * dZcdx + dLdc * dcdZc * dZcdr * drdx = dLdx_c0 + + // dLdx_c1 dLdx_c0 = dLdc * dcdZc * dZcdx = |dZcdx = ... × WcxT| = (dLdc * + // dcdZc) × WcxT dZcdr = (... * hLast) × WchT dLdc * dcdZc * dZcdr = dLdr = + // (dLdc * dcdZc * hLast) × WchT drdx = drdZr * dZrdx dZrdx = ... × WrxT + // dLdx_c1 = dLdc * dcdZc * dZcdr * drdx = dLdr * drdx = (dLdr * drdZr) × WrxT + // finally dLdx = dLdx_u + dLdx_c0 + dLdx_c1 = (dLdu * dudZu) × WuxT + (dLdc * + // dcdZc) × WcxT + (dLdr * drdZr) × WrxT + + // dLdhLast = dLdh * (dhdhLast + dhdu * dudhLast + dhdc * dcdhLast) = dLdh + // * dhdhLast + dLdu * dudhLast + dLdc * dcdhLast + // = dLdhLast_h + dLdhLast_u + dLdhLast_c + // dLdhLast_h = dLdh * dhdhLas = dLdh * u + // dLdhLast_u = dLdu * dudhLast = |dudhLast = dudZu * dZudhLast , dZudhLast = + // ... × WuhT| = (dLdu * dudZu) × WuhT dLdhLast_c = dLdc * dcdhLast = dLdc * + // (dcdZc * dZcdhLast + dcdZc * dZcdr * drdhLast) = + // = dLdc * dcdZc * dZcdhLast + dLdc * dcdZc * dZcdr * drdhLast = + // = dLdc * dcdZc * dZcdhLast + dLdr * drdhLast = dLdhLast_c0 + + // dLdhLast_c1 + // dLdhLast_c0 = dLdc * dcdZc * dZcdhLast = |dZcdhLast = (... * r) × WchT| = + // (dLdc * dcdZc * r) × WchT dLdhLast_c1 = dLdr * drdhLast = |drdhLast = + // drdZr * dZrdhLast, dZrdhLast = ... × WrhT| = (dLdr * drdZr) × WrhT finally + // dLdhLast = dLdhLast_h + dLdhLast_u + dLdhLast_c0 + dLdhLast_c1 = + // = dLdh * u + (dLdu * dudZu) × WuhT + (dLdc * dcdZc * r) × + // WchT + (dLdr * drdZr) × WrhT + + // dLdWrx = dLdh * dhdWrx = (dLdh * dhdc) * dcdWrx = dLdc * dcdZc * dZcdWrx = + // dLdc * dcdZc * dZcdr * drdWrx = + // = dLdc * dcdZc * dZcdr * drdZr * dZrdWrx = dLdr * drdZr * dZrdWrx + // dZrdWrx = xT × ... + // finally dLdWrx = xT × (dLdr * drdZr) + + // dLdWrh = dLdh * dhdWrh = (dLdh * dhdc) * dcdWrh = dLdc * dcdZc * dZcdWrh = + // dLdc * dcdZc * dZcdr * drdWrh = + // = dLdc * dcdZc * dZcdr * drdZr * dZrdWrh = dLdr * drdZr * dZrdWrh + // dZrdWrh = hLastT × ... + // finally dLdWrh = hLastT × (dLdr * drdZr) + + // dLdWux = dLdh * dhdWux = (dLdh * dhdu) * dudWux = dLdu * dudZu * dZudWux + // dZudWux = xT × ... + // dLdu * dudZu * dZudWux = xT × (dLdu * dudZu) + + // dLdWuh = dLdh * dhdWuh = (dLdh * dhdu) * dudWuh = dLdh * dhdu * dudZu * + // dZudWuh = dLdu * dudZu * dZudWuh dZudWuh = hLastT × ... finally dLdWuh = + // hLastT × (dLdu * dudZu) + + // dLdWcx = dLdh * dhdWcx = dLdh * dhdc * dcdWcx = (dLdh * dhdc) * dcdZc * + // dZcdWcx = dLdc * dcdZc * dZcdWcx dZcdWcx = xT × ... finally dLdWcx = xT × + // (dLdc * dcdZc) + + // dLdWch = dLdh * dhdWch = dLdh * dhdc * dcdWch = (dLdh * dhdc) * dcdZc * + // dZcdWch = dLdc * dcdZc * dZcdWch dZcdWch = (r*hLast)^T × ... finally dLdWch + // = (r*hLast)^T × (dLdc * dcdZc) + + // dLdbr = dLdh * dhdbr = (dLdh * dhdc) * dcdbr = dLdc * dcdbr = dLdc * dcdZc + // * dZcdbr = dLdc * dcdZc * dZcdr * drdbr = + // = dLdr * drdZr * dZrdbr + // dZrdbr = 1 + // finally dLdbr = dLdr * drdZr + + // dLdbu = dLdh * dhdbu = (dLdh * dhdu) * dudbu = dLdu * dudZu * dZudbu + // dZudbu = 1 + // finally dLdbu = dLdu * dudZu + + // dLdbc = dLdh * dhdbc = (dLdh * dhdc) * dcdbc = dLdc * dcdZc * dZcdbc + // dZcdbc = 1 + // finally dLdbc = dLdc * dcdZc + + NDArray dhdc = 1.f - u; // [bS, nU] + NDArray dhdu = *hLast - c; // [bS, nU] + NDArray dudZu = u * dhdc; // [bS, nU] + NDArray drdZr = r * (1.f - r); // [bS, nU] + NDArray dcdZc = 1.f - c * c; // [bS, nU] + NDArray dLdZc = *dLdc * dcdZc; // [bS, nU] + NDArray dLdZu = *dLdu * dudZu; // [bS, nU] + NDArray dLdZr = *dLdr * drdZr; // [bS, nU] + + // NDArray dLdc = *dLdh * dhdc; // [bS, nU] + // NDArray dLdu = *dLdh * dhdu; // [bS, nU] + // NDArray dLdr = mmul(dLdc * dcdZc * *hLast, WchT); // [bS, nU] + + dLdx->assign(mmul(dLdZu, WuxT) + mmul(dLdZc, WcxT) + + mmul(dLdZr, WrxT)); // [bS, iS] + + dLdhLast->assign(*dLdh * u + mmul(dLdZu, WuhT) + mmul(dLdZc * r, WchT) + + mmul(dLdZr, WrhT)); // [bS, nU] + + dLdWrx.assign(mmul(xT, dLdZr)); // [iS, bS] × [bS, nU] = [iS, nU] + dLdWrh.assign(mmul(hLastT, dLdZr)); // [nU, bS] × [bS, nU] = [nU, nU] + dLdWux.assign(mmul(xT, dLdZu)); // [iS, bS] × [bS, nU] = [iS, nU] + dLdWuh.assign(mmul(hLastT, dLdZu)); // [nU, bS] × [bS, nU] = [nU, nU] + + dLdWcx.assign(mmul(xT, dLdZc)); // [iS, bS] × [bS, nU] = [iS, nU] + dLdWch.assign( + mmul((r * *hLast).transpose(), dLdZc)); // [nU, bS] × [bS, nU] = [nU, nU] + + dLdbr.assign(dLdZr.reduceAlongDimension(reduce::Sum, {0})); // [nU] + dLdbu.assign(dLdZu.reduceAlongDimension(reduce::Sum, {0})); // [nU] + + dLdbc->assign(dLdZc.reduceAlongDimension(reduce::Sum, {0})); // [nU] } - ////////////////////////////////////////////////////////////////////////// -void gruCellBp(sd::LaunchContext* context, - const NDArray* x, const NDArray* hI, const NDArray* Wx, const NDArray* Wh, const NDArray* b, const NDArray* dLdh, const NDArray* gates, - NDArray* dLdx, NDArray* dLdhI, NDArray* dLdWx, NDArray* dLdWh, NDArray* dLdb) { - - //Inputs: - // x input [bS, nIn] - // hI previous cell output [bS, nOut], that nIn at previous time step t-1 - // Wx input-to-hidden weights - [nIn, 3*nOut] - // Wh hidden-to-hidden weights - [nOut, 3*nOut] - // b biases, [3*nOut] - reset and update gates - // dLdh gradient vs. ff output, [bS, nOut] +void gruCellBp(sd::LaunchContext* context, const NDArray* x, const NDArray* hI, + const NDArray* Wx, const NDArray* Wh, const NDArray* b, + const NDArray* dLdh, const NDArray* gates, NDArray* dLdx, + NDArray* dLdhI, NDArray* dLdWx, NDArray* dLdWh, NDArray* dLdb) { + // Inputs: + // x input [bS, nIn] + // hI previous cell output [bS, nOut], that nIn at previous time + // step t-1 Wx input-to-hidden weights - [nIn, 3*nOut] Wh + // hidden-to-hidden weights - [nOut, 3*nOut] b biases, [3*nOut] - + // reset and update gates dLdh gradient vs. ff output, [bS, nOut] - //Outputs: - // dLdx gradient vs. x, [bS, nIn], - // dLdhI gradient vs. hI, [bS, nOut] - // dLdWx gradient vs. W, [nIn, 3*nOut] - // dLdWh gradient vs. Wc, [nOut, 3*nOut] - // dLdb gradient vs. b [3*nOut] + // Outputs: + // dLdx gradient vs. x, [bS, nIn], + // dLdhI gradient vs. hI, [bS, nOut] + // dLdWx gradient vs. W, [nIn, 3*nOut] + // dLdWh gradient vs. Wc, [nOut, 3*nOut] + // dLdb gradient vs. b [3*nOut] - // 3*nOut means following sequence: reset, update, cell + // 3*nOut means following sequence: reset, update, cell - // * means element-wnIne product or so called Hadamard product - // × means matrix multiplication + // * means element-wnIne product or so called Hadamard product + // × means matrix multiplication - // formulas: - // zr = x × Wxr + hI × Whr + br - // zu = x × Wxu + hI × Whu + bu - // r = sigmoid(zr) - // u = sigmoid(zu) - // zc = x × Wxc + (r * hI) × Whc + bc - // c = tanh(zc) - // h = (1-u)*c + u*hI + // formulas: + // zr = x × Wxr + hI × Whr + br + // zu = x × Wxu + hI × Whu + bu + // r = sigmoid(zr) + // u = sigmoid(zu) + // zc = x × Wxc + (r * hI) × Whc + bc + // c = tanh(zc) + // h = (1-u)*c + u*hI - // dLdhI += dLdh; [bS, nOut] + // dLdhI += dLdh; [bS, nOut] + // dhdc = 1 - u [bS, nOut] dhdu = -c + hI [bS, nOut] - // dhdc = 1 - u [bS, nOut] - // dhdu = -c + hI [bS, nOut] + // dcdzc = 1 - c*c; [bS, nOut] dudzu = u*(1-u) [bS, nOut] drdzr = r(1-r) [bS, + // nOut] - // dcdzc = 1 - c*c; [bS, nOut] - // dudzu = u*(1-u) [bS, nOut] - // drdzr = r(1-r) [bS, nOut] + // dzcdr = (...*hI × WhcT) [bS, nOut] - // dzcdr = (...*hI × WhcT) [bS, nOut] + // dLdzr = dLdh*dhdc*dcdzc*dzcdr*drdzr = (dLdzc*hI*r(1-r) × WhcT); [bS, nOut] + // dLdzu = dLdh*dhdu*dudzu = dLdh*(hI-c)*u*(1-u) [bS, nOut] dLdzc = + // dLdh*dhdc*dcdzc = dLdh*(1-u)*(1-c*c) [bS, nOut] - // dLdzr = dLdh*dhdc*dcdzc*dzcdr*drdzr = (dLdzc*hI*r(1-r) × WhcT); [bS, nOut] - // dLdzu = dLdh*dhdu*dudzu = dLdh*(hI-c)*u*(1-u) [bS, nOut] - // dLdzc = dLdh*dhdc*dcdzc = dLdh*(1-u)*(1-c*c) [bS, nOut] + // dLdx = dLdzr × WxrT + dLdzu × WxuT + dLdzc × WxcT, [bs, nOut] × [nOut, + // nIn] + ... = [bS, nIn] - // dLdx = dLdzr × WxrT + dLdzu × WxuT + dLdzc × WxcT, [bs, nOut] × [nOut, nIn] + ... = [bS, nIn] + // dLdhI = dLdzr × WhrT + dLdzu × WhuT + dLdzc × WhcT, [bs, nOut] × [nOut, + // nOut] + ... = [bS, nOut] - // dLdhI = dLdzr × WhrT + dLdzu × WhuT + dLdzc × WhcT, [bs, nOut] × [nOut, nOut] + ... = [bS, nOut] + // dLdWxr = xT × dLdzr [nIn, bS] x [bS, nOut] = [nIn, + // nOut] dLdWxu = xT × dLdzu [nIn, bS] x [bS, nOut] = + // [nIn, nOut] dLdWxc = xT × dLdzc [nIn, bS] x [bS, + // nOut] = [nIn, nOut] - // dLdWxr = xT × dLdzr [nIn, bS] x [bS, nOut] = [nIn, nOut] - // dLdWxu = xT × dLdzu [nIn, bS] x [bS, nOut] = [nIn, nOut] - // dLdWxc = xT × dLdzc [nIn, bS] x [bS, nOut] = [nIn, nOut] + // dLdWhr = xT × dLdzr [nOut, bS] x [bS, nOut] = + // [nOut, nOut] dLdWhu = xT × dLdzu [nOut, bS] x [bS, + // nOut] = [nOut, nOut] dLdWhc = (r*hI)T × dLdzc [nOut, + // bS] x [bS, nOut] = [nOut, nOut] - // dLdWhr = xT × dLdzr [nOut, bS] x [bS, nOut] = [nOut, nOut] - // dLdWhu = xT × dLdzu [nOut, bS] x [bS, nOut] = [nOut, nOut] - // dLdWhc = (r*hI)T × dLdzc [nOut, bS] x [bS, nOut] = [nOut, nOut] + // dLdbr = dLdzr.reduce_sum_along_0_axis [bS, nOut] -> reduce -> [nOut] + // dLdbu = dLdzu.reduce_sum_along_0_axis [bS, nOut] -> reduce -> [nOut] + // dLdbc = dLdzc.reduce_sum_along_0_axis [bS, nOut] -> reduce -> [nOut] - // dLdbr = dLdzr.reduce_sum_along_0_axis [bS, nOut] -> reduce -> [nOut] - // dLdbu = dLdzu.reduce_sum_along_0_axis [bS, nOut] -> reduce -> [nOut] - // dLdbc = dLdzc.reduce_sum_along_0_axis [bS, nOut] -> reduce -> [nOut] + const int nOut = hI->sizeAt(1); - const int nOut = hI->sizeAt(1); + NDArray dLdz = gates->ulike(); // [bS, 3*nOut] - NDArray dLdz = gates->ulike(); // [bS, 3*nOut] + NDArray dLdzru = dLdz({0, 0, 0, 2 * nOut}); // [bS, 2*nOut] - NDArray dLdzru = dLdz({0,0, 0,2*nOut}); // [bS, 2*nOut] + NDArray dLdzr = dLdz({0, 0, 0, nOut}); // [bS, nOut] + NDArray dLdzu = dLdz({0, 0, nOut, 2 * nOut}); // [bS, nOut] + NDArray dLdzc = dLdz({0, 0, 2 * nOut, 3 * nOut}); // [bS, nOut] - NDArray dLdzr = dLdz({0,0, 0,nOut}); // [bS, nOut] - NDArray dLdzu = dLdz({0,0, nOut,2*nOut}); // [bS, nOut] - NDArray dLdzc = dLdz({0,0, 2*nOut,3*nOut}); // [bS, nOut] + NDArray r = (*gates)({0, 0, 0, nOut}); // [bS, nOut] + NDArray u = (*gates)({0, 0, nOut, 2 * nOut}); // [bS, nOut] + NDArray c = (*gates)({0, 0, 2 * nOut, 3 * nOut}); // [bS, nOut] - NDArray r = (*gates)({0,0, 0,nOut}); // [bS, nOut] - NDArray u = (*gates)({0,0, nOut,2*nOut}); // [bS, nOut] - NDArray c = (*gates)({0,0, 2*nOut,3*nOut}); // [bS, nOut] + NDArray WhcT = (*Wh)({0, 0, 2 * nOut, 3 * nOut}).transpose(); - NDArray WhcT = (*Wh)({0,0, 2*nOut,3*nOut}).transpose(); + if (dLdh) *dLdhI += *dLdh; - if(dLdh) - *dLdhI += *dLdh; + NDArray temp1 = 1 - u; // [bS, nOut] - NDArray temp1 = 1 - u; // [bS, nOut] + // dLdzc + dLdzc.assign(*dLdhI * temp1 * (1 - c * c)); // [bS, nOut] - // dLdzc - dLdzc.assign(*dLdhI * temp1 * (1-c*c)); // [bS, nOut] + // dLdzu + dLdzu.assign(*dLdhI * (*hI - c) * u * temp1); // [bS, nOut] - // dLdzu - dLdzu.assign(*dLdhI * (*hI - c) * u * temp1); // [bS, nOut] + // dLdzr + NDArray temp2 = dLdzc * (*hI) * r * (1 - r); + MmulHelper::mmul(&temp2, &WhcT, + &dLdzr); // [bS, nOut] x [nOut, nOut] = [bS, nOut] - // dLdzr - NDArray temp2 = dLdzc * (*hI) * r *(1-r); - MmulHelper::mmul(&temp2, &WhcT, &dLdzr); // [bS, nOut] x [nOut, nOut] = [bS, nOut] + // dLdx + NDArray WxT = Wx->transpose(); + MmulHelper::mmul(&dLdz, &WxT, + dLdx); // [bS, 3*nOut] x [3*nOut, nIn] = [bS, nIn] - // dLdx - NDArray WxT = Wx->transpose(); - MmulHelper::mmul(&dLdz, &WxT, dLdx); // [bS, 3*nOut] x [3*nOut, nIn] = [bS, nIn] + // dLdWx + *dLdWx += + mmul(x->transpose(), dLdz); // [nIn, bS] x [bS, 3*nOut] = [nIn, 3*nOut] - // dLdWx - *dLdWx += mmul(x->transpose(), dLdz); // [nIn, bS] x [bS, 3*nOut] = [nIn, 3*nOut] + // dLdb + *dLdb += dLdz.reduceAlongDimension( + reduce::Sum, {0}); // [bS, 3*nOut] -> reduce -> [3*nOut]; - // dLdb - *dLdb += dLdz.reduceAlongDimension(reduce::Sum, {0}); // [bS, 3*nOut] -> reduce -> [3*nOut]; + dLdzc *= r; - dLdzc *= r; + // dLdhI + NDArray WhT = Wh->transpose(); + dLdhI->assign(*dLdhI * u + + mmul(dLdz, WhT)); // [bS, 3*nOut] x [3*nOut, nOut] = [bS, nOut] - // dLdhI - NDArray WhT = Wh->transpose(); - dLdhI->assign(*dLdhI*u + mmul(dLdz, WhT)); // [bS, 3*nOut] x [3*nOut, nOut] = [bS, nOut] - - // dLdWr - *dLdWh += mmul(hI->transpose(), dLdz); // [nOut, bS] x [bS, 3*nOut] = [nOut, 3*nOut] + // dLdWr + *dLdWh += mmul(hI->transpose(), + dLdz); // [nOut, bS] x [bS, 3*nOut] = [nOut, 3*nOut] } - ////////////////////////////////////////////////////////////////////////// -void gruTimeLoopBp(sd::LaunchContext * context, - const NDArray* x, const NDArray* hI, const NDArray* Wx, const NDArray* Wh, const NDArray* b, const NDArray* dLdh, - NDArray* dLdx, NDArray* dLdhI, NDArray* dLdWx, NDArray* dLdWh, NDArray* dLdb) { - // sL means time steps - - // x input [sL, bS, nIn] - // hI initial cell output (at time step = 0) [bS, nOut] - // Wx input-to-hidden weights, [nIn, 3*nOut] - // Wh hidden-to-hidden weights, [nOut, 3*nOut] - // b biases, [3*nOut] - // dLdh gradient vs. ff output, [sL, bS, nOut] - - // dLdx gradient vs. x, [sL, bS, nIn], - // dLdhI gradient vs. hI, [bS, nOut] - // dLdWx gradient vs. W, [nIn, 3*nOut] - // dLdWh gradient vs. Wc, [nOut, 3*nOut] - // dLdb gradient vs. b [3*nOut] - - const int sL = x->sizeAt(0); - const int bS = x->sizeAt(1); - const int nOut = hI->sizeAt(1); - - NDArray gates(x->ordering(), {sL, bS, 3*nOut}, dLdh->dataType(), x->getContext()); - NDArray h(x->ordering(), {sL+1, bS, nOut}, dLdh->dataType(), x->getContext()); - - auto xSet = x->allTensorsAlongDimension({1,2}); // sub-arrays with shape [bS, nIn] - auto dLdhSet = dLdh->allTensorsAlongDimension({1,2}); // sub-arrays with shape [bS, nOut] - auto hSet = h.allTensorsAlongDimension({1,2}); // sub-arrays with shape [bS, nOut] - auto gatesSet = gates.allTensorsAlongDimension({1,2}); // sub-arrays with shape [bS, nOut] - auto dLdxSet = dLdx->allTensorsAlongDimension({1,2}); // sub-arrays with shape [bS, nIn] - - hSet.at(0)->assign(hI); - - // forward time loop - for (int t = 0; t < sL; ++t) - gruCell(context, xSet.at(t), hSet.at(t), Wx, Wh, b, gatesSet.at(t), hSet.at(t+1)); - - // backward time loop - for (int t = sL-1; t >= 0; --t) - gruCellBp(context, xSet.at(t), hSet.at(t), Wx, Wh, b, dLdhSet.at(t), gatesSet.at(t), - dLdxSet.at(t), dLdhI, dLdWx, dLdWh, dLdb); +void gruTimeLoopBp(sd::LaunchContext* context, const NDArray* x, + const NDArray* hI, const NDArray* Wx, const NDArray* Wh, + const NDArray* b, const NDArray* dLdh, NDArray* dLdx, + NDArray* dLdhI, NDArray* dLdWx, NDArray* dLdWh, + NDArray* dLdb) { + // sL means time steps + + // x input [sL, bS, nIn] + // hI initial cell output (at time step = 0) [bS, nOut] + // Wx input-to-hidden weights, [nIn, 3*nOut] + // Wh hidden-to-hidden weights, [nOut, 3*nOut] + // b biases, [3*nOut] + // dLdh gradient vs. ff output, [sL, bS, nOut] + + // dLdx gradient vs. x, [sL, bS, nIn], + // dLdhI gradient vs. hI, [bS, nOut] + // dLdWx gradient vs. W, [nIn, 3*nOut] + // dLdWh gradient vs. Wc, [nOut, 3*nOut] + // dLdb gradient vs. b [3*nOut] + + const int sL = x->sizeAt(0); + const int bS = x->sizeAt(1); + const int nOut = hI->sizeAt(1); + + NDArray gates(x->ordering(), {sL, bS, 3 * nOut}, dLdh->dataType(), + x->getContext()); + NDArray h(x->ordering(), {sL + 1, bS, nOut}, dLdh->dataType(), + x->getContext()); + + auto xSet = + x->allTensorsAlongDimension({1, 2}); // sub-arrays with shape [bS, nIn] + auto dLdhSet = dLdh->allTensorsAlongDimension( + {1, 2}); // sub-arrays with shape [bS, nOut] + auto hSet = + h.allTensorsAlongDimension({1, 2}); // sub-arrays with shape [bS, nOut] + auto gatesSet = gates.allTensorsAlongDimension( + {1, 2}); // sub-arrays with shape [bS, nOut] + auto dLdxSet = dLdx->allTensorsAlongDimension( + {1, 2}); // sub-arrays with shape [bS, nIn] + + hSet.at(0).assign(hI); + + // forward time loop + for (int t = 0; t < sL; ++t) + gruCell(context, &xSet.at(t), &hSet.at(t), Wx, Wh, b, &gatesSet.at(t), + &hSet.at(t + 1)); + + // backward time loop + for (int t = sL - 1; t >= 0; --t) + gruCellBp(context, &xSet.at(t), &hSet.at(t), Wx, Wh, b, &dLdhSet.at(t), + &gatesSet.at(t), &dLdxSet.at(t), dLdhI, dLdWx, dLdWh, dLdb); } - -} -} -} +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/impl/knn_mindistance.cpp b/libnd4j/include/ops/declarable/helpers/impl/knn_mindistance.cpp index 1a61587a3582..f7187265d896 100644 --- a/libnd4j/include/ops/declarable/helpers/impl/knn_mindistance.cpp +++ b/libnd4j/include/ops/declarable/helpers/impl/knn_mindistance.cpp @@ -21,42 +21,47 @@ #include namespace sd { - namespace ops { - namespace helpers { - template - void mindistance_(const void* vinput, const void *vlow, const void *vhigh, int32_t length, void *vout) { - auto input = reinterpret_cast(vinput); - auto low = reinterpret_cast(vlow); - auto high = reinterpret_cast(vhigh); - auto output = reinterpret_cast(vout); - - T res = 0.0f; - T po = 2.f; - T o = 1.f; - -#pragma omp simd reduction(sumT:res) - for (auto e = 0; e < length; e++) { - T p = input[e]; - T l = low[e]; - T h = high[e]; - if (!(l <= p || h <= p)) { - if (p < l) - res += sd::math::nd4j_pow((p - o), po); - else - res += sd::math::nd4j_pow((p - h), po); - } - } - - output[0] = sd::math::nd4j_pow(res, (T) 0.5f); - } - - void knn_mindistance(const NDArray &input, const NDArray &lowest, const NDArray &highest, NDArray &output) { - NDArray::preparePrimaryUse({&output}, {&input, &lowest, &highest}); - - BUILD_SINGLE_SELECTOR(input.dataType(), mindistance_, (input.buffer(), lowest.buffer(), highest.buffer(), input.lengthOf(), output.buffer()), FLOAT_TYPES); - - NDArray::registerPrimaryUse({&output}, {&input, &lowest, &highest}); - } - } +namespace ops { +namespace helpers { +template +void mindistance_(const void *vinput, const void *vlow, const void *vhigh, + int32_t length, void *vout) { + auto input = reinterpret_cast(vinput); + auto low = reinterpret_cast(vlow); + auto high = reinterpret_cast(vhigh); + auto output = reinterpret_cast(vout); + + T res = 0.0f; + T po = 2.f; + T o = 1.f; + +#pragma omp simd reduction(sumT : res) + for (auto e = 0; e < length; e++) { + T p = input[e]; + T l = low[e]; + T h = high[e]; + if (!(l <= p || h <= p)) { + if (p < l) + res += sd::math::nd4j_pow((p - o), po); + else + res += sd::math::nd4j_pow((p - h), po); } -} \ No newline at end of file + } + + output[0] = sd::math::nd4j_pow(res, (T)0.5f); +} + +void knn_mindistance(const NDArray &input, const NDArray &lowest, + const NDArray &highest, NDArray &output) { + NDArray::preparePrimaryUse({&output}, {&input, &lowest, &highest}); + + BUILD_SINGLE_SELECTOR(input.dataType(), mindistance_, + (input.buffer(), lowest.buffer(), highest.buffer(), + input.lengthOf(), output.buffer()), + FLOAT_TYPES); + + NDArray::registerPrimaryUse({&output}, {&input, &lowest, &highest}); +} +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/impl/listdiff.cpp b/libnd4j/include/ops/declarable/helpers/impl/listdiff.cpp index 6d937dc0fbb2..e975f61f4d99 100644 --- a/libnd4j/include/ops/declarable/helpers/impl/listdiff.cpp +++ b/libnd4j/include/ops/declarable/helpers/impl/listdiff.cpp @@ -19,105 +19,122 @@ // #include + #include //#include namespace sd { namespace ops { namespace helpers { - template - static Nd4jLong listDiffCount_(NDArray* values, NDArray* keep) { - Nd4jLong saved = 0L; - for (Nd4jLong e = 0; e < values->lengthOf(); e++) { - auto v = values->e(e); - ExtraArguments extras({v, 0.0, 10.0}); - auto idx = keep->indexReduceNumber(indexreduce::FirstIndex, &extras); - auto index = idx.e(0); - if (index < 0) - saved++; - } - return saved; - } +template +static Nd4jLong listDiffCount_(NDArray* values, NDArray* keep) { + Nd4jLong saved = 0L; + for (Nd4jLong e = 0; e < values->lengthOf(); e++) { + auto v = values->e(e); + ExtraArguments extras({v, 0.0, 10.0}); + auto idx = keep->indexReduceNumber(indexreduce::FirstIndex, &extras); + auto index = idx.e(0); + if (index < 0) saved++; + } + return saved; +} - Nd4jLong listDiffCount(sd::LaunchContext * context, NDArray* values, NDArray* keep) { - auto xType = values->dataType(); +Nd4jLong listDiffCount(sd::LaunchContext* context, NDArray* values, + NDArray* keep) { + auto xType = values->dataType(); - NDArray::preparePrimaryUse({},{values, keep}); + NDArray::preparePrimaryUse({}, {values, keep}); - BUILD_SINGLE_SELECTOR(xType, return listDiffCount_, (values, keep), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR(xType, return listDiffCount_, (values, keep), + LIBND4J_TYPES); - NDArray::registerPrimaryUse({},{values, keep}); - } + NDArray::registerPrimaryUse({}, {values, keep}); +} - BUILD_SINGLE_TEMPLATE(template Nd4jLong listDiffCount_, (NDArray* values, NDArray* keep);, LIBND4J_TYPES); - - template - static int listDiffFunctor_(NDArray* values, NDArray* keep, NDArray* output1, NDArray* output2) { - - std::vector saved; - std::vector indices; - - for (Nd4jLong e = 0; e < values->lengthOf(); e++) { - auto v = values->e(e); - ExtraArguments extras({v, 0.0, 10.0}); - NDArray idxScalar = keep->indexReduceNumber(indexreduce::FirstIndex, &extras); - Nd4jLong idx = idxScalar.e(0); - - if (idx < 0) { - saved.emplace_back(v); - indices.emplace_back(e); - } - } - - - if (saved.size() == 0) { -// if (sd::ops::conditionHelper(__FILE__, __LINE__, false, 0, "ListDiff: search returned no results") != 0) - nd4j_printf("ListDiff: search returned no results", ""); - throw std::invalid_argument("Op validation failed"); - } else { - auto z0 = output1;//OUTPUT_VARIABLE(0); //new NDArray('c', {(int) saved.size()}); - auto z1 = output2; //OUTPUT_VARIABLE(1); //new NDArray('c', {(int) saved.size()}); - - if (z0->lengthOf() != saved.size()) { - nd4j_printf("ListDiff: output/actual size mismatch", ""); - throw std::invalid_argument("Op validation failed"); - } - - if (z1->lengthOf() != saved.size()) { - nd4j_printf("ListDiff: output/actual indices size mismatch", ""); - throw std::invalid_argument("Op validation failed"); - } - memcpy(z0->buffer(), saved.data(), saved.size() * sizeof(T)); - for (int e = 0; e < indices.size(); e++) { - z1->p(e, indices[e]); - } - } - return Status::OK(); +BUILD_SINGLE_TEMPLATE(template Nd4jLong listDiffCount_, + (NDArray * values, NDArray* keep); + , LIBND4J_TYPES); + +template +static int listDiffFunctor_(NDArray* values, NDArray* keep, NDArray* output1, + NDArray* output2) { + std::vector saved; + std::vector indices; + + for (Nd4jLong e = 0; e < values->lengthOf(); e++) { + auto v = values->e(e); + ExtraArguments extras({v, 0.0, 10.0}); + NDArray idxScalar = + keep->indexReduceNumber(indexreduce::FirstIndex, &extras); + Nd4jLong idx = idxScalar.e(0); + + if (idx < 0) { + saved.emplace_back(v); + indices.emplace_back(e); + } + } + + if (saved.size() == 0) { + // if (sd::ops::conditionHelper(__FILE__, __LINE__, false, 0, + // "ListDiff: search returned no results") != 0) + nd4j_printf("ListDiff: search returned no results", ""); + throw std::invalid_argument("Op validation failed"); + } else { + auto z0 = output1; // OUTPUT_VARIABLE(0); //new NDArray('c', {(int) + // saved.size()}); + auto z1 = output2; // OUTPUT_VARIABLE(1); //new NDArray('c', {(int) + // saved.size()}); + + if (z0->lengthOf() != saved.size()) { + nd4j_printf("ListDiff: output/actual size mismatch", ""); + throw std::invalid_argument("Op validation failed"); } - int listDiffFunctor(sd::LaunchContext * context, NDArray* values, NDArray* keep, NDArray* output1, NDArray* output2) { - auto xType = values->dataType(); - - NDArray::preparePrimaryUse({output1, output2}, {values, keep}); + if (z1->lengthOf() != saved.size()) { + nd4j_printf("ListDiff: output/actual indices size mismatch", ""); + throw std::invalid_argument("Op validation failed"); + } + memcpy(z0->buffer(), saved.data(), saved.size() * sizeof(T)); + for (int e = 0; e < indices.size(); e++) { + z1->p(e, indices[e]); + } + } + return Status::OK(); +} - int result = 0; +int listDiffFunctor(sd::LaunchContext* context, NDArray* values, NDArray* keep, + NDArray* output1, NDArray* output2) { + auto xType = values->dataType(); - if (DataTypeUtils::isR(xType)) { - BUILD_SINGLE_SELECTOR(xType, result = listDiffFunctor_, (values, keep, output1, output2), FLOAT_TYPES); - } else if (DataTypeUtils::isZ(xType)) { - BUILD_SINGLE_SELECTOR(xType, result = listDiffFunctor_, (values, keep, output1, output2), INTEGER_TYPES); - } else { - throw std::runtime_error("ListDiff: Only integer and floating point data types are supported"); - } + NDArray::preparePrimaryUse({output1, output2}, {values, keep}); - NDArray::registerPrimaryUse({output1, output2}, {values, keep}); + int result = 0; - return result; - } + if (DataTypeUtils::isR(xType)) { + BUILD_SINGLE_SELECTOR(xType, result = listDiffFunctor_, + (values, keep, output1, output2), FLOAT_TYPES); + } else if (DataTypeUtils::isZ(xType)) { + BUILD_SINGLE_SELECTOR(xType, result = listDiffFunctor_, + (values, keep, output1, output2), INTEGER_TYPES); + } else { + throw std::runtime_error( + "ListDiff: Only integer and floating point data types are supported"); + } - BUILD_SINGLE_TEMPLATE(template int listDiffFunctor_, (NDArray* values, NDArray* keep, NDArray* output1, NDArray* output2);, FLOAT_TYPES); - BUILD_SINGLE_TEMPLATE(template int listDiffFunctor_, (NDArray* values, NDArray* keep, NDArray* output1, NDArray* output2);, INTEGER_TYPES); + NDArray::registerPrimaryUse({output1, output2}, {values, keep}); + return result; } -} -} \ No newline at end of file + +BUILD_SINGLE_TEMPLATE(template int listDiffFunctor_, + (NDArray * values, NDArray* keep, NDArray* output1, + NDArray* output2); + , FLOAT_TYPES); +BUILD_SINGLE_TEMPLATE(template int listDiffFunctor_, + (NDArray * values, NDArray* keep, NDArray* output1, + NDArray* output2); + , INTEGER_TYPES); + +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/impl/lstm.cpp b/libnd4j/include/ops/declarable/helpers/impl/lstm.cpp index 4ab585e26f83..5bb3942790f3 100644 --- a/libnd4j/include/ops/declarable/helpers/impl/lstm.cpp +++ b/libnd4j/include/ops/declarable/helpers/impl/lstm.cpp @@ -20,122 +20,127 @@ // implementation of operation for LSTM cell with peep hole connections: // http://www.bioinf.jku.at/publications/older/2604.pdf -// S. Hochreiter and J. Schmidhuber. "Long Short-Term Memory". Neural Computation, 9(8):1735-1780, 1997. -// and +// S. Hochreiter and J. Schmidhuber. "Long Short-Term Memory". Neural +// Computation, 9(8):1735-1780, 1997. and // https://research.google.com/pubs/archive/43905.pdf -// Hasim Sak, Andrew Senior, and Francoise Beaufays. "Long short-term memory recurrent neural network architectures for large scale acoustic modeling." INTERSPEECH, 2014. +// Hasim Sak, Andrew Senior, and Francoise Beaufays. "Long short-term memory +// recurrent neural network architectures for large scale acoustic modeling." +// INTERSPEECH, 2014. - -#include +#include #include +#include #include -#include #include -#include +#include +#include + #include -#include namespace sd { - namespace ops { - namespace helpers { +namespace ops { +namespace helpers { ///////////////////////////////////////////////////////////////////////////// - void lstmBlockTimeLoop(const NDArray* maxSeqLength, const NDArray* xSeq, const NDArray* c0, const NDArray* y0, - const NDArray* W, const NDArray* Wci, const NDArray* Wcf, const NDArray* Wco, const NDArray* b, - const NDArray* iSeq, const NDArray* cSeq, const NDArray* fSeq, const NDArray* oSeq, const NDArray* zSeq, - const NDArray* hSeq, const NDArray* ySeq, const std::vector& params, const int dataFormat){ - - int seqLen, bS, nIn, nOut; - - if(dataFormat == 0) { - seqLen = xSeq->sizeAt(0); - bS = xSeq->sizeAt(1); - nIn = xSeq->sizeAt(2); - nOut = iSeq->sizeAt(2); - } - else if(dataFormat == 1) { - seqLen = xSeq->sizeAt(2); - bS = xSeq->sizeAt(0); - nIn = xSeq->sizeAt(1); - nOut = iSeq->sizeAt(1); - } - else if(dataFormat == 2) { - seqLen = xSeq->sizeAt(1); - bS = xSeq->sizeAt(0); - nIn = xSeq->sizeAt(2); - nOut = iSeq->sizeAt(2); - } - - const std::vector inSliceShape({bS,nIn}); - const std::vector outSliceShape({bS,nOut}); - - auto c_t1 = const_cast(c0); - auto y_t1 = const_cast(y0); - - // loop through time steps - for (int t = 0; t < seqLen; ++t) { - - auto xt = timeSubset(xSeq, t, dataFormat); - - auto it = timeSubset(iSeq, t, dataFormat); - auto ct = timeSubset(cSeq, t, dataFormat); - auto ft = timeSubset(fSeq, t, dataFormat); - auto ot = timeSubset(oSeq, t, dataFormat); - auto zt = timeSubset(zSeq, t, dataFormat); - auto ht = timeSubset(hSeq, t, dataFormat); - auto yt = timeSubset(ySeq, t, dataFormat); - - helpers::lstmBlockCell(&xt, c_t1, y_t1, W, Wci, Wcf, Wco, b, &it, &ct, &ft, &ot, &zt, &ht, &yt, params); - - if(t != 0) { - delete c_t1; - delete y_t1; - } - - if(t < seqLen - 1) { - c_t1 = new NDArray(std::move(ct)); - y_t1 = new NDArray(std::move(yt)); - } - } - } - - - - ////////////////////////////////////////////////////////////////////////// - void lstmTimeLoop(sd::LaunchContext * context, const NDArray* x, const NDArray* h0, const NDArray* c0, const NDArray* Wx, const NDArray* Wh, const NDArray* Wc, const NDArray* Wp, const NDArray* b, - NDArray* h, NDArray* c, const std::vector& params) { - - // x input [time x bS x nIn] - // h0 initial cell output (at time step = 0) [bS x numProj], in case of projection=false -> numProj == numUnits !!! - // c0 initial cell state (at time step = 0) [bS x numUnits], - - // Wx input-to-hidden weights, [nIn x 4*numUnits] - // Wh hidden-to-hidden weights, [numProj x 4*numUnits] - // Wc diagonal weights for peephole connections [3*numUnits] - // Wp projection weights [numUnits x numProj] - // b biases, [4*numUnits] - - // h cell outputs [time x bS x numProj], that is per each time step - // c cell states [time x bS x numUnits] that is per each time step - - const int time = x->sizeAt(0); - - NDArray currentH(*h0); - NDArray currentC(*c0); - - // loop through time steps - for (int t = 0; t < time; ++t) { - auto xt = (*x)({t,t+1, 0,0, 0,0}); - auto ht = (*h)({t,t+1, 0,0, 0,0}); - auto ct = (*c)({t,t+1, 0,0, 0,0}); - - helpers::lstmCell(context, &xt,¤tH,¤tC, Wx,Wh,Wc,Wp, b, &ht, &ct, params); - currentH.assign(ht); - currentC.assign(ct); - } - } - - - } +void lstmBlockTimeLoop(const NDArray* maxSeqLength, const NDArray* xSeq, + const NDArray* c0, const NDArray* y0, const NDArray* W, + const NDArray* Wci, const NDArray* Wcf, + const NDArray* Wco, const NDArray* b, + const NDArray* iSeq, const NDArray* cSeq, + const NDArray* fSeq, const NDArray* oSeq, + const NDArray* zSeq, const NDArray* hSeq, + const NDArray* ySeq, const std::vector& params, + const int dataFormat) { + int seqLen, bS, nIn, nOut; + + if (dataFormat == 0) { + seqLen = xSeq->sizeAt(0); + bS = xSeq->sizeAt(1); + nIn = xSeq->sizeAt(2); + nOut = iSeq->sizeAt(2); + } else if (dataFormat == 1) { + seqLen = xSeq->sizeAt(2); + bS = xSeq->sizeAt(0); + nIn = xSeq->sizeAt(1); + nOut = iSeq->sizeAt(1); + } else if (dataFormat == 2) { + seqLen = xSeq->sizeAt(1); + bS = xSeq->sizeAt(0); + nIn = xSeq->sizeAt(2); + nOut = iSeq->sizeAt(2); + } + + const std::vector inSliceShape({bS, nIn}); + const std::vector outSliceShape({bS, nOut}); + + auto c_t1 = const_cast(c0); + auto y_t1 = const_cast(y0); + + // loop through time steps + for (int t = 0; t < seqLen; ++t) { + auto xt = timeSubset(xSeq, t, dataFormat); + + auto it = timeSubset(iSeq, t, dataFormat); + auto ct = timeSubset(cSeq, t, dataFormat); + auto ft = timeSubset(fSeq, t, dataFormat); + auto ot = timeSubset(oSeq, t, dataFormat); + auto zt = timeSubset(zSeq, t, dataFormat); + auto ht = timeSubset(hSeq, t, dataFormat); + auto yt = timeSubset(ySeq, t, dataFormat); + + helpers::lstmBlockCell(&xt, c_t1, y_t1, W, Wci, Wcf, Wco, b, &it, &ct, &ft, + &ot, &zt, &ht, &yt, params); + + if (t != 0) { + delete c_t1; + delete y_t1; + } + + if (t < seqLen - 1) { + c_t1 = new NDArray(std::move(ct)); + y_t1 = new NDArray(std::move(yt)); } -} \ No newline at end of file + } +} + +////////////////////////////////////////////////////////////////////////// +void lstmTimeLoop(sd::LaunchContext* context, const NDArray* x, + const NDArray* h0, const NDArray* c0, const NDArray* Wx, + const NDArray* Wh, const NDArray* Wc, const NDArray* Wp, + const NDArray* b, NDArray* h, NDArray* c, + const std::vector& params) { + // x input [time x bS x nIn] + // h0 initial cell output (at time step = 0) [bS x numProj], in case of + // projection=false -> numProj == numUnits !!! c0 initial cell state (at time + // step = 0) [bS x numUnits], + + // Wx input-to-hidden weights, [nIn x 4*numUnits] + // Wh hidden-to-hidden weights, [numProj x 4*numUnits] + // Wc diagonal weights for peephole connections [3*numUnits] + // Wp projection weights [numUnits x numProj] + // b biases, [4*numUnits] + + // h cell outputs [time x bS x numProj], that is per each time step + // c cell states [time x bS x numUnits] that is per each time step + + const int time = x->sizeAt(0); + + NDArray currentH(*h0); + NDArray currentC(*c0); + + // loop through time steps + for (int t = 0; t < time; ++t) { + auto xt = (*x)({t, t + 1, 0, 0, 0, 0}); + auto ht = (*h)({t, t + 1, 0, 0, 0, 0}); + auto ct = (*c)({t, t + 1, 0, 0, 0, 0}); + + helpers::lstmCell(context, &xt, ¤tH, ¤tC, Wx, Wh, Wc, Wp, b, + &ht, &ct, params); + currentH.assign(ht); + currentC.assign(ct); + } +} + +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/impl/lstmLayer.cpp b/libnd4j/include/ops/declarable/helpers/impl/lstmLayer.cpp index bffd13128069..5ec64180f023 100644 --- a/libnd4j/include/ops/declarable/helpers/impl/lstmLayer.cpp +++ b/libnd4j/include/ops/declarable/helpers/impl/lstmLayer.cpp @@ -21,17 +21,18 @@ // implementation of operation for LSTM cell with peep hole connections: // http://www.bioinf.jku.at/publications/older/2604.pdf -// S. Hochreiter and J. Schmidhuber. "Long Short-Term Memory". Neural Computation, 9(8):1735-1780, 1997. -// and +// S. Hochreiter and J. Schmidhuber. "Long Short-Term Memory". Neural +// Computation, 9(8):1735-1780, 1997. and // https://research.google.com/pubs/archive/43905.pdf -// Hasim Sak, Andrew Senior, and Francoise Beaufays. "Long short-term memory recurrent neural network architectures for large scale acoustic modeling." INTERSPEECH, 2014. +// Hasim Sak, Andrew Senior, and Francoise Beaufays. "Long short-term memory +// recurrent neural network architectures for large scale acoustic modeling." +// INTERSPEECH, 2014. - -#include #include -#include -#include #include +#include +#include +#include // #include // #include // #include @@ -39,1176 +40,1378 @@ // #include // #include - -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { ////////////////////////////////////////////////////////////////////////// -static void applyActivation(const NDArray& x, const int opId, const float alpha, const float beta, NDArray& z) { - - switch (opId) { - case 0: - (const_cast(x)).applyTransform(transform::Tanh, z); - break; - case 1: - (const_cast(x)).applyScalar(scalar::RELU, 0, z); - break; - case 2: - (const_cast(x)).applyTransform(transform::Sigmoid, z); - break; - case 3: { - ExtraArguments args({ static_cast(alpha), static_cast(beta)}); - (const_cast(x)).applyTransform(transform::Affine, z, &args); - break; - } - case 4: - (const_cast(x)).applyScalar(scalar::LeakyRELU, alpha, z); - break; - case 5: - thresholdRelu(x.getContext(), x, alpha, z); - break; - case 6: { - ExtraArguments args({ static_cast(alpha), static_cast(beta)}); - (const_cast(x)).applyTransform(transform::ScaledTanh, z, &args); - break; - } - case 7: - (const_cast(x)).applyTransform(transform::HardSigmoid, z); - break; - case 8: - (const_cast(x)).applyScalar(scalar::ELU, alpha, z); - break; - case 9: - (const_cast(x)).applyTransform(transform::SoftSign, z); - break; - case 10: - (const_cast(x)).applyTransform(transform::SoftPlus, z); - break; - default: - throw std::invalid_argument("LSTM_LAYER operation: wrong id number of activation !"); +static void applyActivation(const NDArray& x, const int opId, const float alpha, + const float beta, NDArray& z) { + switch (opId) { + case 0: + (const_cast(x)).applyTransform(transform::Tanh, z); + break; + case 1: + (const_cast(x)).applyScalar(scalar::RELU, 0, z); + break; + case 2: + (const_cast(x)).applyTransform(transform::Sigmoid, z); + break; + case 3: { + ExtraArguments args( + {static_cast(alpha), static_cast(beta)}); + (const_cast(x)).applyTransform(transform::Affine, z, &args); + break; } + case 4: + (const_cast(x)).applyScalar(scalar::LeakyRELU, alpha, z); + break; + case 5: + thresholdRelu(x.getContext(), x, alpha, z); + break; + case 6: { + ExtraArguments args( + {static_cast(alpha), static_cast(beta)}); + (const_cast(x)).applyTransform(transform::ScaledTanh, z, &args); + break; + } + case 7: + (const_cast(x)).applyTransform(transform::HardSigmoid, z); + break; + case 8: + (const_cast(x)).applyScalar(scalar::ELU, alpha, z); + break; + case 9: + (const_cast(x)).applyTransform(transform::SoftSign, z); + break; + case 10: + (const_cast(x)).applyTransform(transform::SoftPlus, z); + break; + default: + throw std::invalid_argument( + "LSTM_LAYER operation: wrong id number of activation !"); + } } ////////////////////////////////////////////////////////////////////////// -static void activationDeriv(const NDArray& x, const int opId, const float alpha, const float beta, NDArray& z) { - - switch (opId) { - case 0: - (const_cast(x)).applyTransform(transform::TanhDerivative, z); - break; - case 1: - (const_cast(x)).applyScalar(scalar::RELUDerivative, 0, z); - break; - case 2: - (const_cast(x)).applyTransform(transform::SigmoidDerivative, z); - break; - case 3: { - z = alpha; - break; - } - case 4: - (const_cast(x)).applyScalar(scalar::LeakyRELUDerivative, alpha, z); - break; - case 5: - (const_cast(x)).applyScalar(scalar::RELUDerivative, alpha, z); - break; - case 6: { - auto func = PRAGMA_THREADS_FOR { - for(Nd4jLong i = start; i < stop; ++i) { - auto val = beta * x.e(i); - z.p(i, alpha * beta * (1.f - sd::math::nd4j_tanh(val) * sd::math::nd4j_tanh(val))); - } - }; - samediff::Threads::parallel_for(func, 0, x.lengthOf()); - break; +static void activationDeriv(const NDArray& x, const int opId, const float alpha, + const float beta, NDArray& z) { + switch (opId) { + case 0: + (const_cast(x)).applyTransform(transform::TanhDerivative, z); + break; + case 1: + (const_cast(x)) + .applyScalar(scalar::RELUDerivative, 0, z); + break; + case 2: + (const_cast(x)).applyTransform(transform::SigmoidDerivative, z); + break; + case 3: { + z = alpha; + break; + } + case 4: + (const_cast(x)) + .applyScalar(scalar::LeakyRELUDerivative, alpha, z); + break; + case 5: + (const_cast(x)) + .applyScalar(scalar::RELUDerivative, alpha, z); + break; + case 6: { + auto func = PRAGMA_THREADS_FOR { + for (Nd4jLong i = start; i < stop; ++i) { + auto val = beta * x.e(i); + z.p(i, alpha * beta * + (1.f - sd::math::nd4j_tanh(val) * + sd::math::nd4j_tanh(val))); } - case 7: - (const_cast(x)).applyTransform(transform::HardSigmoidDerivative, z); - break; - case 8: - (const_cast(x)).applyScalar(scalar::ELUDerivative, alpha, z); - break; - case 9: - (const_cast(x)).applyTransform(transform::SoftSignDerivative, z); - break; - case 10: { - auto func = PRAGMA_THREADS_FOR { - for(Nd4jLong i = start; i < stop; ++i) { - auto val = sd::math::nd4j_exp(x.e(i)); - z.p(i, val / (1.f + val)); - } - }; - samediff::Threads::parallel_for(func, 0, x.lengthOf()); - break; + }; + samediff::Threads::parallel_for(func, 0, x.lengthOf()); + break; + } + case 7: + (const_cast(x)) + .applyTransform(transform::HardSigmoidDerivative, z); + break; + case 8: + (const_cast(x)) + .applyScalar(scalar::ELUDerivative, alpha, z); + break; + case 9: + (const_cast(x)) + .applyTransform(transform::SoftSignDerivative, z); + break; + case 10: { + auto func = PRAGMA_THREADS_FOR { + for (Nd4jLong i = start; i < stop; ++i) { + auto val = sd::math::nd4j_exp(x.e(i)); + z.p(i, val / (1.f + val)); } - default: - throw std::invalid_argument("LSTM_LAYER operation: wrong id number of activation !"); + }; + samediff::Threads::parallel_for(func, 0, x.lengthOf()); + break; } + default: + throw std::invalid_argument( + "LSTM_LAYER operation: wrong id number of activation !"); + } } ////////////////////////////////////////////////////////////////////////// -// FIXME - derivative undefined when not-clipped c has element/elements equal to -clipVal or clipVal -static void clipDeriv(const float clipVal, const NDArray& c, NDArray& z0, NDArray& z1, NDArray& z2, NDArray& z3) { - - if(clipVal == 0) - return; - - auto func = PRAGMA_THREADS_FOR { - for(Nd4jLong i = start; i < stop; ++i) { - const auto val = c.e(i); - if(val == -clipVal || val == clipVal) { - z0.p(i, 0.f); - z1.p(i, 0.f); - z2.p(i, 0.f); - z3.p(i, 0.f); - } - } - }; - samediff::Threads::parallel_for(func, 0, c.lengthOf()); +// FIXME - derivative undefined when not-clipped c has element/elements equal to +// -clipVal or clipVal +static void clipDeriv(const float clipVal, const NDArray& c, NDArray& z0, + NDArray& z1, NDArray& z2, NDArray& z3) { + if (clipVal == 0) return; + + auto func = PRAGMA_THREADS_FOR { + for (Nd4jLong i = start; i < stop; ++i) { + const auto val = c.e(i); + if (val == -clipVal || val == clipVal) { + z0.p(i, 0.f); + z1.p(i, 0.f); + z2.p(i, 0.f); + z3.p(i, 0.f); + } + } + }; + samediff::Threads::parallel_for(func, 0, c.lengthOf()); } ////////////////////////////////////////////////////////////////////////// -static NDArray tensorAlongTimeBatchDims(const NDArray& arr, const int dataFormat, const int t1, const int t2, const int b1, const int b2) { - - if(dataFormat == 0 || dataFormat == 3) - return arr({t1,t2, b1,b2, 0,0}); // TNS: [sL, bS, nIn] +static NDArray tensorAlongTimeBatchDims(const NDArray& arr, + const int dataFormat, const int t1, + const int t2, const int b1, + const int b2) { + if (dataFormat == 0 || dataFormat == 3) + return arr({t1, t2, b1, b2, 0, 0}); // TNS: [sL, bS, nIn] - if(dataFormat == 1) - return arr({b1,b2, t1,t2, 0,0}); // NTS: [bS, sL ,nIn] + if (dataFormat == 1) + return arr({b1, b2, t1, t2, 0, 0}); // NTS: [bS, sL ,nIn] - return arr({b1,b2, 0,0, t1,t2}); // NST: [bS, nIn, sL] + return arr({b1, b2, 0, 0, t1, t2}); // NST: [bS, nIn, sL] } ////////////////////////////////////////////////////////////////////////// -static FORCEINLINE int getBatchTimeTotalIndex(const int dataFormat, const int sL, const int bS, const int t, const int b) { - - if(dataFormat == 0 || dataFormat == 3) - return t * bS + b; // TNS: shape [sL, bS, nIn] +static FORCEINLINE int getBatchTimeTotalIndex(const int dataFormat, + const int sL, const int bS, + const int t, const int b) { + if (dataFormat == 0 || dataFormat == 3) + return t * bS + b; // TNS: shape [sL, bS, nIn] - return b * sL + t; // NTS, NST: shape [bS, sL, nIn], [bS, nIn, sL] + return b * sL + t; // NTS, NST: shape [bS, sL, nIn], [bS, nIn, sL] } - ////////////////////////////////////////////////////////////////////////// void lstmLayerCell(const NDArray* x, const NDArray* Wx, const NDArray* Wr, - const NDArray* b, const NDArray* hI, const NDArray* cI, const NDArray* Wp, - const std::vector& params, + const NDArray* b, const NDArray* hI, const NDArray* cI, + const NDArray* Wp, const std::vector& params, NDArray* h, NDArray* c) { - - // * -> means element-wise multiplication - // × -> means matrix multiplication - - /************************ THIS IS NOT OPTIMAZED CODE ***********************************/ - /** the objective is to provide math-readable code **/ - - // equations (no peephole connections) - // it = σ(Wxi × xt + Wri × ht-1 + bi) - // ft = σ(Wxf × xt + Wrf × ht-1 + bf) - // c't = tanh(Wxc × xt + Wrc × ht-1 + bc) - // ct = ft * ct-1 + it * c't - // ot = σ(Wxo × xt + Wro × ht-1 + bo) - // ht = ot * tanh(ct) - - // equations (peephole connections are present) - // it = σ(Wxi × xt + Wri × ht-1 + Wpi * ct-1 + bi) - // ft = σ(Wxf × xt + Wrf × ht-1 + Wpf * ct-1 + bf) - // c't = tanh(Wxc × xt + Wrc × ht-1 + bc) - // ct = ft * ct-1 + it * c't - // ot = σ(Wxo × xt + Wro × ht-1 + Wpo * ct + bo) - // ht = ot * tanh(ct) - - - // IDs for activations: 0=tanh, 1=relu, 2=sigmoid, 3=affine, 4=leaky relu, 5= thresholded relu, 6=scaled tanh, 7=hard sigmoid, 8=ELU, 9=softsign, 10=softplus - - // params[0] - dataFormat, ignore - // params[1] - directionMode, ignore - // params[2] - cell clipping value, if it = 0 then do not apply clipping - - // params[3] - activation ID for input (i), forget (f) and output (o) gates - // params[4] - alpha value for gates activation - // params[5] - beta value for gates activation - - // params[6] - activation ID for cell state (c) - // params[7] - alpha value for cell state activation - // params[8] - beta value for cell state activation - - // params[9] - activation ID for output (h) - // params[10] - alpha value for output activation - // params[11] - beta value for output activation - - // INPUTS: - // x - current input at time t, [bS, nIn] or [nIn] if seqLen != nullptr - // Wx - input weights [nIn, 4*nOut] - // Wr - recurrent weights [nOut, 4*nOut] - // b - biases [4*nOut], optional, may be nullptr - // hI - (ht-1) previous (initial) output at time t-1, optional may be nullptr, [bS, nOut] or [nOut] if seqLen != nullptr - // cI - (ct-1) previous (initial) cell state at time t-1, optional may be nullptr, [bS, nOut] or [nOut] if seqLen != nullptr - // Wp - peephole weights [3*nOut], optional, may be nullptr - - // OUTPUTS: - // h - current output, that is at current time step t, [bS, nOut] or [nOut] if seqLen != nullptr - // c - current cell state, that is at current time step t, [bS, nOut] or [nOut] if seqLen != nullptr - - // !!! dimension 4*nOut implies order it, ft, c't, ot - // !!! dimension 3*nOut implies order it, ft, ot - - const Nd4jLong nOut = Wx->sizeAt(-1) / 4; - - auto z = mmul(*x, *Wx) + mmul(*hI, *Wr); // [bs, nIn] * [nIn, 4*nOut] + [bs, nOut] * [nOut, 4*nOut] = [bS, 4*nOut] - //or [nIn] * [nIn, 4*nOut] + [nOut] * [nOut, 4*nOut] = [4*nOut] - - // add biases if they are given - if(b != nullptr) - z += *b; // broadcast [bS, 4*nOut](or[4*nOut]) + [4*nOut] = [bS, 4*nOut] - - auto zi = x->rankOf() == 1 ? z({0, nOut}) : z({0,0, 0, nOut}); // input gate it, [bS, nOut](or[nOut]) - auto zf = x->rankOf() == 1 ? z({nOut, 2*nOut}) : z({0,0, nOut, 2*nOut}); // forget gate ft, [bS, nOut](or[nOut]) - auto zg = x->rankOf() == 1 ? z({2*nOut, 3*nOut}) : z({0,0, 2*nOut, 3*nOut}); // cell gate c't, [bS, nOut](or[nOut]) - auto zo = x->rankOf() == 1 ? z({3*nOut, 4*nOut}) : z({0,0, 3*nOut, 4*nOut}); // output gate ot, [bS, nOut](or[nOut]) - - // peephole connections for input and forget gates - if(Wp != nullptr) { - zi += *cI * (*Wp)({0, nOut}); // broadcast: [bS, nOut] + [bS, nOut] * [nOut] = [bS, nOut](or[nOut]) - zf += *cI * (*Wp)({nOut, 2*nOut}); // broadcast: [bS, nOut] + [bS, nOut] * [nOut] = [bS, nOut](or[nOut]) - } - - applyActivation(zi, params[3], params[4], params[5], zi); // inplace - applyActivation(zf, params[3], params[4], params[5], zf); // inplace - applyActivation(zg, params[6], params[7], params[8], zg); // inplace - - c->assign(zf * *cI + zi * zg); // [bS, nOut] * [bS, nOut] + [bS, nOut] * [bS, nOut] = [bS, nOut](or[nOut]) - - // if clipping value is non-zero then cell state is clipped by this value prior to the cell output activation - if(params[2] != 0) - c->applyScalar(scalar::LstmClip, params[2], *c); - - // peephole connections for output gate - if(Wp != nullptr) - zo += *c * (*Wp)({2*nOut, 3*nOut}); // broadcast: [bS, nOut] + [bS, nOut] * [nOut] = [bS, nOut](or[nOut]) - - applyActivation(zo, params[3], params[4], params[5], zo); - - applyActivation(*c, params[9], params[10], params[11], *h); - *h *= zo; // [bS, nOut] * [bS, nOut](or[nOut]) + // * -> means element-wise multiplication + // × -> means matrix multiplication + + /************************ THIS IS NOT OPTIMAZED CODE + * ***********************************/ + /** the objective is to provide math-readable code **/ + + // equations (no peephole connections) + // it = σ(Wxi × xt + Wri × ht-1 + bi) + // ft = σ(Wxf × xt + Wrf × ht-1 + bf) + // c't = tanh(Wxc × xt + Wrc × ht-1 + bc) + // ct = ft * ct-1 + it * c't + // ot = σ(Wxo × xt + Wro × ht-1 + bo) + // ht = ot * tanh(ct) + + // equations (peephole connections are present) + // it = σ(Wxi × xt + Wri × ht-1 + Wpi * ct-1 + bi) + // ft = σ(Wxf × xt + Wrf × ht-1 + Wpf * ct-1 + bf) + // c't = tanh(Wxc × xt + Wrc × ht-1 + bc) + // ct = ft * ct-1 + it * c't + // ot = σ(Wxo × xt + Wro × ht-1 + Wpo * ct + bo) + // ht = ot * tanh(ct) + + // IDs for activations: 0=tanh, 1=relu, 2=sigmoid, 3=affine, 4=leaky relu, 5= + // thresholded relu, 6=scaled tanh, 7=hard sigmoid, 8=ELU, 9=softsign, + // 10=softplus + + // params[0] - dataFormat, ignore + // params[1] - directionMode, ignore + // params[2] - cell clipping value, if it = 0 then do not apply clipping + + // params[3] - activation ID for input (i), forget (f) and output (o) gates + // params[4] - alpha value for gates activation + // params[5] - beta value for gates activation + + // params[6] - activation ID for cell state (c) + // params[7] - alpha value for cell state activation + // params[8] - beta value for cell state activation + + // params[9] - activation ID for output (h) + // params[10] - alpha value for output activation + // params[11] - beta value for output activation + + // INPUTS: + // x - current input at time t, [bS, nIn] or [nIn] if seqLen != nullptr + // Wx - input weights [nIn, 4*nOut] + // Wr - recurrent weights [nOut, 4*nOut] + // b - biases [4*nOut], optional, may be nullptr + // hI - (ht-1) previous (initial) output at time t-1, optional may be nullptr, + // [bS, nOut] or [nOut] if seqLen != nullptr cI - (ct-1) previous (initial) + // cell state at time t-1, optional may be nullptr, [bS, nOut] or [nOut] if + // seqLen != nullptr Wp - peephole weights [3*nOut], optional, may be nullptr + + // OUTPUTS: + // h - current output, that is at current time step t, [bS, nOut] or [nOut] if + // seqLen != nullptr c - current cell state, that is at current time step t, + // [bS, nOut] or [nOut] if seqLen != nullptr + + // !!! dimension 4*nOut implies order it, ft, c't, ot + // !!! dimension 3*nOut implies order it, ft, ot + + const Nd4jLong nOut = Wx->sizeAt(-1) / 4; + + auto z = mmul(*x, *Wx) + + mmul(*hI, *Wr); // [bs, nIn] * [nIn, 4*nOut] + [bs, nOut] * [nOut, + // 4*nOut] = [bS, 4*nOut] + // or [nIn] * [nIn, 4*nOut] + [nOut] * [nOut, + // 4*nOut] = [4*nOut] + + // add biases if they are given + if (b != nullptr) + z += *b; // broadcast [bS, 4*nOut](or[4*nOut]) + [4*nOut] = [bS, 4*nOut] + + auto zi = x->rankOf() == 1 + ? z({0, nOut}) + : z({0, 0, 0, nOut}); // input gate it, [bS, nOut](or[nOut]) + auto zf = + x->rankOf() == 1 + ? z({nOut, 2 * nOut}) + : z({0, 0, nOut, 2 * nOut}); // forget gate ft, [bS, nOut](or[nOut]) + auto zg = x->rankOf() == 1 + ? z({2 * nOut, 3 * nOut}) + : z({0, 0, 2 * nOut, + 3 * nOut}); // cell gate c't, [bS, nOut](or[nOut]) + auto zo = x->rankOf() == 1 + ? z({3 * nOut, 4 * nOut}) + : z({0, 0, 3 * nOut, + 4 * nOut}); // output gate ot, [bS, nOut](or[nOut]) + + // peephole connections for input and forget gates + if (Wp != nullptr) { + zi += *cI * (*Wp)({0, nOut}); // broadcast: [bS, nOut] + [bS, nOut] * + // [nOut] = [bS, nOut](or[nOut]) + zf += *cI * (*Wp)({nOut, 2 * nOut}); // broadcast: [bS, nOut] + [bS, nOut] + // * [nOut] = [bS, nOut](or[nOut]) + } + + applyActivation(zi, params[3], params[4], params[5], zi); // inplace + applyActivation(zf, params[3], params[4], params[5], zf); // inplace + applyActivation(zg, params[6], params[7], params[8], zg); // inplace + + c->assign(zf * *cI + zi * zg); // [bS, nOut] * [bS, nOut] + [bS, nOut] * [bS, + // nOut] = [bS, nOut](or[nOut]) + + // if clipping value is non-zero then cell state is clipped by this value + // prior to the cell output activation + if (params[2] != 0) c->applyScalar(scalar::LstmClip, params[2], *c); + + // peephole connections for output gate + if (Wp != nullptr) + zo += + *c * (*Wp)({2 * nOut, 3 * nOut}); // broadcast: [bS, nOut] + [bS, nOut] + // * [nOut] = [bS, nOut](or[nOut]) + + applyActivation(zo, params[3], params[4], params[5], zo); + + applyActivation(*c, params[9], params[10], params[11], *h); + *h *= zo; // [bS, nOut] * [bS, nOut](or[nOut]) } - ////////////////////////////////////////////////////////////////////////// // this auxiliary ff should be running before backprop void lstmLayerCell(const NDArray* x, const NDArray* Wx, const NDArray* Wr, - const NDArray* b, const NDArray* hI, const NDArray* cI, const NDArray* Wp, - const std::vector& params, + const NDArray* b, const NDArray* hI, const NDArray* cI, + const NDArray* Wp, const std::vector& params, NDArray* z, NDArray* a, NDArray* h, NDArray* c) { - - // z - zi, zf, zg, zo - // a - i, f, g, o - - const Nd4jLong nOut = Wx->sizeAt(-1) / 4; - - z->assign(mmul(*x, *Wx) + mmul(*hI, *Wr)); // [bs, nIn] * [nIn, 4*nOut] + [bs, nOut] * [nOut, 4*nOut] = [bS, 4*nOut] - //or [nIn] * [nIn, 4*nOut] + [nOut] * [nOut, 4*nOut] = [4*nOut] - // add biases if they are given - if(b != nullptr) - *z += *b; // broadcast [bS, 4*nOut](or[4*nOut]) + [4*nOut] = [bS, 4*nOut] - - auto zi = x->rankOf() == 1 ? (*z)({0, nOut}) : (*z)({0,0, 0, nOut}); // input gate it, [bS, nOut](or[nOut]) - auto zf = x->rankOf() == 1 ? (*z)({nOut, 2*nOut}) : (*z)({0,0, nOut, 2*nOut}); // forget gate ft, [bS, nOut](or[nOut]) - auto zg = x->rankOf() == 1 ? (*z)({2*nOut, 3*nOut}) : (*z)({0,0, 2*nOut, 3*nOut}); // cell gate c't, [bS, nOut](or[nOut]) - auto zo = x->rankOf() == 1 ? (*z)({3*nOut, 4*nOut}) : (*z)({0,0, 3*nOut, 4*nOut}); // output gate ot, [bS, nOut](or[nOut]) - - auto i = x->rankOf() == 1 ? (*a)({0, nOut}) : (*a)({0,0, 0, nOut}); // input gate it, [bS, nOut](or[nOut]) - auto f = x->rankOf() == 1 ? (*a)({nOut, 2*nOut}) : (*a)({0,0, nOut, 2*nOut}); // forget gate ft, [bS, nOut](or[nOut]) - auto g = x->rankOf() == 1 ? (*a)({2*nOut, 3*nOut}) : (*a)({0,0, 2*nOut, 3*nOut}); // cell gate c't, [bS, nOut](or[nOut]) - auto o = x->rankOf() == 1 ? (*a)({3*nOut, 4*nOut}) : (*a)({0,0, 3*nOut, 4*nOut}); // output gate ot, [bS, nOut](or[nOut]) - - // peephole connections for input and forget gates - if(Wp != nullptr) { - zi += *cI * (*Wp)({0, nOut}); // broadcast: [bS, nOut] + [bS, nOut] * [nOut] = [bS, nOut](or[nOut]) - zf += *cI * (*Wp)({nOut, 2*nOut}); // broadcast: [bS, nOut] + [bS, nOut] * [nOut] = [bS, nOut](or[nOut]) - } - - applyActivation(zi, params[3], params[4], params[5], i); - applyActivation(zf, params[3], params[4], params[5], f); - applyActivation(zg, params[6], params[7], params[8], g); - - c->assign(f * *cI + i * g); // [bS, nOut] * [bS, nOut] + [bS, nOut] * [bS, nOut] = [bS, nOut](or[nOut]) - - // if clipping value is non-zero then cell state is clipped by this value prior to the cell output activation - if(params[2] != 0) - c->applyScalar(scalar::LstmClip, params[2], *c); - - // peephole connections for output gate - if(Wp != nullptr) - zo += *c * (*Wp)({2*nOut, 3*nOut}); // broadcast: [bS, nOut] + [bS, nOut] * [nOut] = [bS, nOut](or[nOut]) - - applyActivation(zo, params[3], params[4], params[5], o); - - applyActivation(*c, params[9], params[10], params[11], *h); - *h *= o; // [bS, nOut] * [bS, nOut](or[nOut]) + // z - zi, zf, zg, zo + // a - i, f, g, o + + const Nd4jLong nOut = Wx->sizeAt(-1) / 4; + + z->assign(mmul(*x, *Wx) + + mmul(*hI, *Wr)); // [bs, nIn] * [nIn, 4*nOut] + [bs, nOut] * + // [nOut, 4*nOut] = [bS, 4*nOut] + // or [nIn] * [nIn, 4*nOut] + [nOut] * [nOut, + // 4*nOut] = [4*nOut] + // add biases if they are given + if (b != nullptr) + *z += *b; // broadcast [bS, 4*nOut](or[4*nOut]) + [4*nOut] = [bS, 4*nOut] + + auto zi = x->rankOf() == 1 + ? (*z)({0, nOut}) + : (*z)({0, 0, 0, nOut}); // input gate it, [bS, nOut](or[nOut]) + auto zf = x->rankOf() == 1 + ? (*z)({nOut, 2 * nOut}) + : (*z)({0, 0, nOut, + 2 * nOut}); // forget gate ft, [bS, nOut](or[nOut]) + auto zg = x->rankOf() == 1 + ? (*z)({2 * nOut, 3 * nOut}) + : (*z)({0, 0, 2 * nOut, + 3 * nOut}); // cell gate c't, [bS, nOut](or[nOut]) + auto zo = x->rankOf() == 1 + ? (*z)({3 * nOut, 4 * nOut}) + : (*z)({0, 0, 3 * nOut, + 4 * nOut}); // output gate ot, [bS, nOut](or[nOut]) + + auto i = x->rankOf() == 1 + ? (*a)({0, nOut}) + : (*a)({0, 0, 0, nOut}); // input gate it, [bS, nOut](or[nOut]) + auto f = x->rankOf() == 1 + ? (*a)({nOut, 2 * nOut}) + : (*a)({0, 0, nOut, + 2 * nOut}); // forget gate ft, [bS, nOut](or[nOut]) + auto g = x->rankOf() == 1 + ? (*a)({2 * nOut, 3 * nOut}) + : (*a)({0, 0, 2 * nOut, + 3 * nOut}); // cell gate c't, [bS, nOut](or[nOut]) + auto o = x->rankOf() == 1 + ? (*a)({3 * nOut, 4 * nOut}) + : (*a)({0, 0, 3 * nOut, + 4 * nOut}); // output gate ot, [bS, nOut](or[nOut]) + + // peephole connections for input and forget gates + if (Wp != nullptr) { + zi += *cI * (*Wp)({0, nOut}); // broadcast: [bS, nOut] + [bS, nOut] * + // [nOut] = [bS, nOut](or[nOut]) + zf += *cI * (*Wp)({nOut, 2 * nOut}); // broadcast: [bS, nOut] + [bS, nOut] + // * [nOut] = [bS, nOut](or[nOut]) + } + + applyActivation(zi, params[3], params[4], params[5], i); + applyActivation(zf, params[3], params[4], params[5], f); + applyActivation(zg, params[6], params[7], params[8], g); + + c->assign(f * *cI + i * g); // [bS, nOut] * [bS, nOut] + [bS, nOut] * [bS, + // nOut] = [bS, nOut](or[nOut]) + + // if clipping value is non-zero then cell state is clipped by this value + // prior to the cell output activation + if (params[2] != 0) c->applyScalar(scalar::LstmClip, params[2], *c); + + // peephole connections for output gate + if (Wp != nullptr) + zo += + *c * (*Wp)({2 * nOut, 3 * nOut}); // broadcast: [bS, nOut] + [bS, nOut] + // * [nOut] = [bS, nOut](or[nOut]) + + applyActivation(zo, params[3], params[4], params[5], o); + + applyActivation(*c, params[9], params[10], params[11], *h); + *h *= o; // [bS, nOut] * [bS, nOut](or[nOut]) } - ////////////////////////////////////////////////////////////////////////// -void lstmLayerCellBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, const NDArray* b, const NDArray* hI, const NDArray* cI, const NDArray* Wp, - const NDArray* dLdh, const NDArray* dLdhL, const NDArray* dLdcL, - const NDArray* z, const NDArray* a, const NDArray* c, const std::vector& params, - NDArray* dLdx, NDArray* dLdWx, NDArray* dLdWr, NDArray* dLdhI, NDArray* dLdcI, NDArray* dLdb, NDArray* dLdWp) { - - /************************ THIS IS NOT OPTIMAZED CODE ***********************************/ - /** the objective is to provide math-readable code **/ - - // equations (no peephole connections) - // zi = x × Wxi + hI × Wri + bi - // zf = x × Wxf + hI × Wrf + bf - // zg = x × Wxg + hI × Wrg + bg - // zo = x × Wxo + hI × Wro + bo - // i = act(zi) - // f = act(zf) - // g = actC(zg) - // o = act(zo) - // c = clip(f * cI + i * g) - // h = o * actH(c) - - // equations (peephole connections are present) - // zi = x × Wxi + hI × Wri + cI * Wpi + bi - // zf = x × Wxf + hI × Wrf + cI * Wpf + bf - // zg = x × Wxg + hI × Wrg + bg - // zo = x × Wxo + hI × Wro + c * Wpo + bo - // i = act(zi) - // f = act(zf) - // g = actC(zg) - // o = act(zo) - // c = clip(f * cI + i * g) - // h = o * actH(c) - - // IDs for activations: 0=tanh, 1=relu, 2=sigmoid, 3=affine, 4=leaky relu, 5= thresholded relu, 6=scaled tanh, 7=hard sigmoid, 8=ELU, 9=softsign, 10=softplus - - // params[0] - dataFormat, ignore - // params[1] - directionMode, ignore - // params[2] - cell clipping value, if it = 0 then do not apply clipping - - // params[3] - activation ID for input (i), forget (f) and output (o) gates - // params[4] - alpha value for gates activation - // params[5] - beta value for gates activation - - // params[6] - activation ID for cell state (c) - // params[7] - alpha value for cell state activation - // params[8] - beta value for cell state activation - - // params[9] - activation ID for output (h) - // params[10] - alpha value for output activation - // params[11] - beta value for output activation - - // INPUTS: - // x - current input at time t, [bS, nIn] or [nIn] if seqLen != nullptr - // Wx - input weights [nIn, 4*nOut] - // Wr - recurrent weights [nOut, 4*nOut] - // b - biases [4*nOut], optional, may be nullptr - // hI - (ht-1) previous (initial) output at time t-1, [bS, nOut] or [nOut] if seqLen != nullptr - // cI - (ct-1) previous (initial) cell state at time t-1, [bS, nOut] or [nOut] if seqLen != nullptr - // Wp - peephole weights [3*nOut], optional, may be nullptr - // dLdh - loss derivative with respect to h at each time step, [bS, nOut] or [nOut] if seqLen != nullptr - // dLdhL - loss derivative with respect to h at last time step, [bS, nOut] or [nOut] if seqLen != nullptr - // dLdcL - loss derivative with respect to c at last time step, [bS, nOut] or [nOut] if seqLen != nullptr - // z - zi,zf,zg,zo taken from ff outputs to reduce amount of calculations in bp, [bS, 4*nOut] - // a - i,f,g,o taken from ff outputs to reduce amount of calculations in bp, [bS, 4*nOut] - // c - taken from ff outputs to reduce amount of calculations in bp, [bS, nOut] - - // OUTPUTS: - // dLdx - loss derivative with respect to x, [bS, nIn] or [nIn] if seqLen != nullptr - // dLdWx - loss derivative with respect to Wx, [nIn, 4*nOut] - // dLdWr - loss derivative with respect to Wr, [nOut, 4*nOut] - // dLdb - loss derivative with respect to b, optional, may be nullptr, [4*nOut] - // dLdhI - loss derivative with respect to hI, optional may be nullptr, [bS, nOut] or [nOut] if seqLen != nullptr - // dLdcI - loss derivative with respect to cI, optional may be nullptr, [bS, nOut] or [nOut] if seqLen != nullptr - // dLdWp - loss derivative with respect to Wp, optional, may be nullptr, [3*nOut] - - // !!! dimension 4*nOut implies order i, f, g, o - // !!! dimension 3*nOut implies order i, f, o - - // dhdc = o*tanhDeriv + Wp ? tanh(c)*dodzo*dzodc : 0 [bS, nOut] - // dcdcI = f + Wp ? dcdzi*dzidcI + dcdzf*dzfdcI : 0 [bS, nOut] - - // dLdhI += dLdh; [bS, nOut] - // dLdcI += dLdhI * dhdc; [bS, nOut] - - // dLdzi = dLdcI*dcdi*didzi; [bS, nOut](or[nOut]) - // dLdzf = dLdcI*dcdf*dfdzf; [bS, nOut](or[nOut]) - // dLdzg = dLdcI*dcdg*dgdzg; [bS, nOut](or[nOut]) - // dLdzo = dLdhI*dhdo*dodzo; [bS, nOut](or[nOut]) - - // dLdx = dLdzi×WxiT + dLdzf×WxfT + dLdzg×WxgT + dLdzo×WxoT, [bS, nIn] - // dLdhI = dLdzi×WriT + dLdzf×WrfT + dLdzg×WrgT + dLdzo×WroT, [bS, nOut] - // dLdcI = dLdcI*dcdcI, [bS, nOut] - - // dLdWxi = xT×dLdzi [nIn, bS] x [bS, nOut] = [nIn, nOut] - // dLdWxf = xT×dLdzf [nIn, bS] x [bS, nOut] = [nIn, nOut] - // dLdWxg = xT×dLdzg [nIn, bS] x [bS, nOut] = [nIn, nOut] - // dLdWxo = xT×dLdzo [nIn, bS] x [bS, nOut] = [nIn, nOut] - - // dLdWri = hIT×dLdzi [nOut, bS] x [bS, nOut] = [nOut, nOut] - // dLdWrf = hIT×dLdzf [nOut, bS] x [bS, nOut] = [nOut, nOut] - // dLdWrg = hIT×dLdzg [nOut, bS] x [bS, nOut] = [nOut, nOut] - // dLdWro = hIT×dLdzo [nOut, bS] x [bS, nOut] = [nOut, nOut] - - // dLdbi = dLdzi.reduce_sum_along_0_axis [bS, nOut] -> reduce -> [nOut] - // dLdbf = dLdzf.reduce_sum_along_0_axis [bS, nOut] -> reduce -> [nOut] - // dLdbg = dLdzg.reduce_sum_along_0_axis [bS, nOut] -> reduce -> [nOut] - // dLdbo = dLdzo.reduce_sum_along_0_axis [bS, nOut] -> reduce -> [nOut] - - // dLdWpi = (dLdzi*cI).reduce_sum_along_0_axis [bS, nOut] -> reduce -> [nOut] - // dLdWpf = (dLdzf*cI).reduce_sum_along_0_axis [bS, nOut] -> reduce -> [nOut] - // dLdWpo = (dLdzo*c) .reduce_sum_along_0_axis [bS, nOut] -> reduce -> [nOut] - - const Nd4jLong nOut = Wx->sizeAt(-1) / 4; - const Nd4jLong nIn = x->sizeAt(-1); - - NDArray zi = x->rankOf() == 1 ? (*z)({0, nOut}) : (*z)({0,0, 0, nOut}); // input gate i, [bS, nOut](or[nOut]) - NDArray zf = x->rankOf() == 1 ? (*z)({nOut, 2*nOut}) : (*z)({0,0, nOut, 2*nOut}); // forget gate f, [bS, nOut](or[nOut]) - NDArray zg = x->rankOf() == 1 ? (*z)({2*nOut, 3*nOut}) : (*z)({0,0, 2*nOut, 3*nOut}); // cell gate g, [bS, nOut](or[nOut]) - NDArray zo = x->rankOf() == 1 ? (*z)({3*nOut, 4*nOut}) : (*z)({0,0, 3*nOut, 4*nOut}); // output gate o, [bS, nOut](or[nOut]) - - NDArray i = x->rankOf() == 1 ? (*a)({0, nOut}) : (*a)({0,0, 0, nOut}); // input gate i, [bS, nOut](or[nOut]) - NDArray f = x->rankOf() == 1 ? (*a)({nOut, 2*nOut}) : (*a)({0,0, nOut, 2*nOut}); // forget gate f, [bS, nOut](or[nOut]) - NDArray g = x->rankOf() == 1 ? (*a)({2*nOut, 3*nOut}) : (*a)({0,0, 2*nOut, 3*nOut}); // cell gate g, [bS, nOut](or[nOut]) - NDArray o = x->rankOf() == 1 ? (*a)({3*nOut, 4*nOut}) : (*a)({0,0, 3*nOut, 4*nOut}); // output gate o, [bS, nOut](or[nOut]) - - NDArray dLdz = z->ulike(); // [bS, 4*nOut](or[4*nOut]) - NDArray dLdzi = x->rankOf() == 1 ? dLdz({0, nOut}) : dLdz({0,0, 0, nOut}); - NDArray dLdzf = x->rankOf() == 1 ? dLdz({nOut, 2*nOut}) : dLdz({0,0, nOut, 2*nOut}); - NDArray dLdzg = x->rankOf() == 1 ? dLdz({2*nOut, 3*nOut}) : dLdz({0,0, 2*nOut, 3*nOut}); - NDArray dLdzo = x->rankOf() == 1 ? dLdz({3*nOut, 4*nOut}) : dLdz({0,0, 3*nOut, 4*nOut}); - - // dcdzi = dcdi*didzi, [bS, nOut](or[nOut]) - activationDeriv(zi, params[3], params[4], params[5], dLdzi); // didzi, inplace - dLdzi *= g; // dcdi = g*clipDeriv - - // dcdzf = dcdf*dfdzf, [bS, nOut](or[nOut]) - activationDeriv(zf, params[3], params[4], params[5], dLdzf); // dfdzf, inplace - dLdzf *= *cI; // dcdf = cI*clipDeriv - - // dcdzg = dcde*dedzg, [bS, nOut](or[nOut]) - activationDeriv(zg, params[6], params[7], params[8], dLdzg); // dgdzg, inplace - dLdzg *= i; // dcdf = i*clipDeriv - - // dhdzo = dhdo*dodzo = actH(c)*dodzo, [bS, nOut](or[nOut]) - activationDeriv(zo, params[3], params[4], params[5], dLdzo); - NDArray temp = dLdzo.ulike(); - applyActivation(*c, params[9], params[10], params[11], temp); // actH(c), inplace - dLdzo *= temp; - - // dcdcI - NDArray dcdcI = f.dup(); // dcdcI = f*clipDeriv [bS, nOut](or[nOut]) - - // take into account possible deposit from clipping derivative - clipDeriv(params[2], *c, dLdzi, dLdzf, dLdzg, dcdcI); - - // dhdc - NDArray dhdc = c->ulike(); - activationDeriv(*c, params[9], params[10], params[11], dhdc); // [bS, nOut] - dhdc *= o; - - if(Wp) { - dhdc += dLdzo*(*Wp)({2*nOut, 3*nOut}); - dcdcI += dLdzi*(*Wp)({0, nOut}) + dLdzf*(*Wp)({nOut, 2*nOut}); // broadcast [bS, nOut] * nOut + ... - } - - if(dLdh) - *dLdhI += *dLdh; - if(dLdhL) - *dLdhI += *dLdhL; - if(dLdcL) - *dLdcI += *dLdcL; - - *dLdcI += *dLdhI * dhdc; - - dLdzi *= *dLdcI; // [bS, nOut](or[nOut]) - dLdzf *= *dLdcI; // [bS, nOut](or[nOut]) - dLdzg *= *dLdcI; // [bS, nOut](or[nOut]) - dLdzo *= *dLdhI; // [bS, nOut](or[nOut]) - - // dLdx - NDArray WxT = Wx->transpose(); - MmulHelper::mmul(&dLdz, &WxT, dLdx); // [bS, 4*nOut] x [4*nOut, nIn] (or [4*nOut] x [4*nOut, nIn]) = [bS, nIn] ( or[nIn] ) - - // dLdhI - NDArray WrT = Wr->transpose(); - MmulHelper::mmul(&dLdz, &WrT, dLdhI); // [bS, 4*nOut] x [4*nOut, nOut] (or [4*nOut] x [4*nOut, nOut]) = [bS, nOut] ( or[nOut] ) - - // dLdcI - dLdcI->assign(*dLdcI*dcdcI); // [bS, nOut](or[nOut]) - - if(x->rankOf() == 1) { - - NDArray xT = x->reshape(x->ordering(),{nIn, 1}); // [nIn] -> [nIn, 1] - NDArray hIT = hI->reshape(hI->ordering(),{nOut, 1}); // [nOut] -> [nOut, 1] - NDArray dLdzR = dLdz.reshape(dLdz.ordering(), {1, 4*nOut}); // [nOut] -> [1, 4*nOut] - - // dLdWx - *dLdWx += mmul(xT, dLdzR); // [nIn, 1] x [1, 4*nOut] = [nIn, 4*nOut] - - // dLdWr - *dLdWr += mmul(hIT, dLdzR); // [nOut, 1] x [1, 4*nOut] = [nOut, 4*nOut] - } - else { - - // dLdWx - *dLdWx += mmul(x->transpose(), dLdz); // [nIn, bS] x [bS, 4*nOut] = [nIn, 4*nOut] - - // dLdWr - *dLdWr += mmul(hI->transpose(), dLdz); // [nOut, bS] x [bS, 4*nOut] = [nOut, 4*nOut] - } - - // dLdb - if(b && x->rankOf() == 1) - *dLdb += dLdz; // [4*nOut] - else if(b) - *dLdb += dLdz.reduceAlongDimension(reduce::Sum, {0}); // [bS, 4*nOut] -> reduce -> [4*nOut]; - - // dLdWp - if(Wp && x->rankOf() == 1) { - (*dLdWp)({ 0,nOut}) += std::move(dLdzi)*(*cI); // [nOut] - (*dLdWp)({ nOut,2*nOut}) += std::move(dLdzf)*(*cI); // [nOut] - (*dLdWp)({2*nOut,3*nOut}) += std::move(dLdzo)*(*c); // [nOut] - } - else if(Wp) { - NDArray temp(Wp->ordering(), {nOut}, Wp->dataType(), Wp->getContext()); - (std::move(dLdzi)*(*cI)).reduceAlongDimension(reduce::Sum, temp, {0}); // [bS, nOut] -> reduce -> [nOut] - (*dLdWp)({0,nOut}) += temp; - (std::move(dLdzf)*(*cI)).reduceAlongDimension(reduce::Sum, temp, {0}); // [bS, nOut] -> reduce -> [nOut] - (*dLdWp)({nOut,2*nOut}) += temp; - (std::move(dLdzo)*(*c)).reduceAlongDimension(reduce::Sum, temp, {0}); // [bS, nOut] -> reduce -> [nOut] - (*dLdWp)({2*nOut,3*nOut}) += temp; - } +void lstmLayerCellBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, + const NDArray* b, const NDArray* hI, const NDArray* cI, + const NDArray* Wp, const NDArray* dLdh, + const NDArray* dLdhL, const NDArray* dLdcL, + const NDArray* z, const NDArray* a, const NDArray* c, + const std::vector& params, NDArray* dLdx, + NDArray* dLdWx, NDArray* dLdWr, NDArray* dLdhI, + NDArray* dLdcI, NDArray* dLdb, NDArray* dLdWp) { + /************************ THIS IS NOT OPTIMAZED CODE + * ***********************************/ + /** the objective is to provide math-readable code **/ + + // equations (no peephole connections) + // zi = x × Wxi + hI × Wri + bi + // zf = x × Wxf + hI × Wrf + bf + // zg = x × Wxg + hI × Wrg + bg + // zo = x × Wxo + hI × Wro + bo + // i = act(zi) + // f = act(zf) + // g = actC(zg) + // o = act(zo) + // c = clip(f * cI + i * g) + // h = o * actH(c) + + // equations (peephole connections are present) + // zi = x × Wxi + hI × Wri + cI * Wpi + bi + // zf = x × Wxf + hI × Wrf + cI * Wpf + bf + // zg = x × Wxg + hI × Wrg + bg + // zo = x × Wxo + hI × Wro + c * Wpo + bo + // i = act(zi) + // f = act(zf) + // g = actC(zg) + // o = act(zo) + // c = clip(f * cI + i * g) + // h = o * actH(c) + + // IDs for activations: 0=tanh, 1=relu, 2=sigmoid, 3=affine, 4=leaky relu, 5= + // thresholded relu, 6=scaled tanh, 7=hard sigmoid, 8=ELU, 9=softsign, + // 10=softplus + + // params[0] - dataFormat, ignore + // params[1] - directionMode, ignore + // params[2] - cell clipping value, if it = 0 then do not apply clipping + + // params[3] - activation ID for input (i), forget (f) and output (o) gates + // params[4] - alpha value for gates activation + // params[5] - beta value for gates activation + + // params[6] - activation ID for cell state (c) + // params[7] - alpha value for cell state activation + // params[8] - beta value for cell state activation + + // params[9] - activation ID for output (h) + // params[10] - alpha value for output activation + // params[11] - beta value for output activation + + // INPUTS: + // x - current input at time t, [bS, nIn] or [nIn] if seqLen != nullptr + // Wx - input weights [nIn, 4*nOut] + // Wr - recurrent weights [nOut, 4*nOut] + // b - biases [4*nOut], optional, may be nullptr + // hI - (ht-1) previous (initial) output at time t-1, [bS, nOut] or [nOut] + // if seqLen != nullptr cI - (ct-1) previous (initial) cell state at time + // t-1, [bS, nOut] or [nOut] if seqLen != nullptr Wp - peephole weights + // [3*nOut], optional, may be nullptr dLdh - loss derivative with respect to + // h at each time step, [bS, nOut] or [nOut] if seqLen != nullptr dLdhL - loss + // derivative with respect to h at last time step, [bS, nOut] or [nOut] if + // seqLen != nullptr dLdcL - loss derivative with respect to c at last time + // step, [bS, nOut] or [nOut] if seqLen != nullptr z - zi,zf,zg,zo taken + // from ff outputs to reduce amount of calculations in bp, [bS, 4*nOut] a - + // i,f,g,o taken from ff outputs to reduce amount of calculations in bp, [bS, + // 4*nOut] c - taken from ff outputs to reduce amount of calculations in + // bp, [bS, nOut] + + // OUTPUTS: + // dLdx - loss derivative with respect to x, [bS, nIn] or [nIn] if seqLen != + // nullptr dLdWx - loss derivative with respect to Wx, [nIn, 4*nOut] dLdWr - + // loss derivative with respect to Wr, [nOut, 4*nOut] dLdb - loss derivative + // with respect to b, optional, may be nullptr, [4*nOut] dLdhI - loss + // derivative with respect to hI, optional may be nullptr, [bS, nOut] or + // [nOut] if seqLen != nullptr dLdcI - loss derivative with respect to cI, + // optional may be nullptr, [bS, nOut] or [nOut] if seqLen != nullptr dLdWp - + // loss derivative with respect to Wp, optional, may be nullptr, [3*nOut] + + // !!! dimension 4*nOut implies order i, f, g, o + // !!! dimension 3*nOut implies order i, f, o + + // dhdc = o*tanhDeriv + Wp ? tanh(c)*dodzo*dzodc : 0 [bS, nOut] + // dcdcI = f + Wp ? dcdzi*dzidcI + dcdzf*dzfdcI : 0 [bS, nOut] + + // dLdhI += dLdh; [bS, nOut] + // dLdcI += dLdhI * dhdc; [bS, nOut] + + // dLdzi = dLdcI*dcdi*didzi; [bS, nOut](or[nOut]) + // dLdzf = dLdcI*dcdf*dfdzf; [bS, nOut](or[nOut]) + // dLdzg = dLdcI*dcdg*dgdzg; [bS, nOut](or[nOut]) + // dLdzo = dLdhI*dhdo*dodzo; [bS, nOut](or[nOut]) + + // dLdx = dLdzi×WxiT + dLdzf×WxfT + dLdzg×WxgT + dLdzo×WxoT, [bS, nIn] + // dLdhI = dLdzi×WriT + dLdzf×WrfT + dLdzg×WrgT + dLdzo×WroT, [bS, nOut] + // dLdcI = dLdcI*dcdcI, [bS, nOut] + + // dLdWxi = xT×dLdzi [nIn, bS] x [bS, nOut] = [nIn, + // nOut] dLdWxf = xT×dLdzf [nIn, bS] x [bS, nOut] = + // [nIn, nOut] dLdWxg = xT×dLdzg [nIn, bS] x [bS, + // nOut] = [nIn, nOut] dLdWxo = xT×dLdzo [nIn, bS] + // x [bS, nOut] = [nIn, nOut] + + // dLdWri = hIT×dLdzi [nOut, bS] x [bS, nOut] = + // [nOut, nOut] dLdWrf = hIT×dLdzf [nOut, bS] x [bS, + // nOut] = [nOut, nOut] dLdWrg = hIT×dLdzg [nOut, + // bS] x [bS, nOut] = [nOut, nOut] dLdWro = hIT×dLdzo [nOut, bS] x [bS, nOut] + // = [nOut, nOut] + + // dLdbi = dLdzi.reduce_sum_along_0_axis [bS, nOut] -> reduce -> [nOut] + // dLdbf = dLdzf.reduce_sum_along_0_axis [bS, nOut] -> reduce -> [nOut] + // dLdbg = dLdzg.reduce_sum_along_0_axis [bS, nOut] -> reduce -> [nOut] + // dLdbo = dLdzo.reduce_sum_along_0_axis [bS, nOut] -> reduce -> [nOut] + + // dLdWpi = (dLdzi*cI).reduce_sum_along_0_axis [bS, nOut] -> reduce -> + // [nOut] dLdWpf = (dLdzf*cI).reduce_sum_along_0_axis [bS, nOut] -> + // reduce -> [nOut] dLdWpo = (dLdzo*c) .reduce_sum_along_0_axis [bS, + // nOut] -> reduce -> [nOut] + + const Nd4jLong nOut = Wx->sizeAt(-1) / 4; + const Nd4jLong nIn = x->sizeAt(-1); + + NDArray zi = + x->rankOf() == 1 + ? (*z)({0, nOut}) + : (*z)({0, 0, 0, nOut}); // input gate i, [bS, nOut](or[nOut]) + NDArray zf = x->rankOf() == 1 + ? (*z)({nOut, 2 * nOut}) + : (*z)({0, 0, nOut, + 2 * nOut}); // forget gate f, [bS, nOut](or[nOut]) + NDArray zg = x->rankOf() == 1 + ? (*z)({2 * nOut, 3 * nOut}) + : (*z)({0, 0, 2 * nOut, + 3 * nOut}); // cell gate g, [bS, nOut](or[nOut]) + NDArray zo = x->rankOf() == 1 + ? (*z)({3 * nOut, 4 * nOut}) + : (*z)({0, 0, 3 * nOut, + 4 * nOut}); // output gate o, [bS, nOut](or[nOut]) + + NDArray i = + x->rankOf() == 1 + ? (*a)({0, nOut}) + : (*a)({0, 0, 0, nOut}); // input gate i, [bS, nOut](or[nOut]) + NDArray f = x->rankOf() == 1 + ? (*a)({nOut, 2 * nOut}) + : (*a)({0, 0, nOut, + 2 * nOut}); // forget gate f, [bS, nOut](or[nOut]) + NDArray g = x->rankOf() == 1 + ? (*a)({2 * nOut, 3 * nOut}) + : (*a)({0, 0, 2 * nOut, + 3 * nOut}); // cell gate g, [bS, nOut](or[nOut]) + NDArray o = x->rankOf() == 1 + ? (*a)({3 * nOut, 4 * nOut}) + : (*a)({0, 0, 3 * nOut, + 4 * nOut}); // output gate o, [bS, nOut](or[nOut]) + + NDArray dLdz = z->ulike(); // [bS, 4*nOut](or[4*nOut]) + NDArray dLdzi = x->rankOf() == 1 ? dLdz({0, nOut}) : dLdz({0, 0, 0, nOut}); + NDArray dLdzf = + x->rankOf() == 1 ? dLdz({nOut, 2 * nOut}) : dLdz({0, 0, nOut, 2 * nOut}); + NDArray dLdzg = x->rankOf() == 1 ? dLdz({2 * nOut, 3 * nOut}) + : dLdz({0, 0, 2 * nOut, 3 * nOut}); + NDArray dLdzo = x->rankOf() == 1 ? dLdz({3 * nOut, 4 * nOut}) + : dLdz({0, 0, 3 * nOut, 4 * nOut}); + + // dcdzi = dcdi*didzi, [bS, nOut](or[nOut]) + activationDeriv(zi, params[3], params[4], params[5], + dLdzi); // didzi, inplace + dLdzi *= g; // dcdi = g*clipDeriv + + // dcdzf = dcdf*dfdzf, [bS, nOut](or[nOut]) + activationDeriv(zf, params[3], params[4], params[5], + dLdzf); // dfdzf, inplace + dLdzf *= *cI; // dcdf = cI*clipDeriv + + // dcdzg = dcde*dedzg, [bS, nOut](or[nOut]) + activationDeriv(zg, params[6], params[7], params[8], + dLdzg); // dgdzg, inplace + dLdzg *= i; // dcdf = i*clipDeriv + + // dhdzo = dhdo*dodzo = actH(c)*dodzo, [bS, nOut](or[nOut]) + activationDeriv(zo, params[3], params[4], params[5], dLdzo); + NDArray temp = dLdzo.ulike(); + applyActivation(*c, params[9], params[10], params[11], + temp); // actH(c), inplace + dLdzo *= temp; + + // dcdcI + NDArray dcdcI = f.dup(); // dcdcI = f*clipDeriv [bS, nOut](or[nOut]) + + // take into account possible deposit from clipping derivative + clipDeriv(params[2], *c, dLdzi, dLdzf, dLdzg, dcdcI); + + // dhdc + NDArray dhdc = c->ulike(); + activationDeriv(*c, params[9], params[10], params[11], dhdc); // [bS, nOut] + dhdc *= o; + + if (Wp) { + dhdc += dLdzo * (*Wp)({2 * nOut, 3 * nOut}); + dcdcI += + dLdzi * (*Wp)({0, nOut}) + + dLdzf * (*Wp)({nOut, 2 * nOut}); // broadcast [bS, nOut] * nOut + ... + } + + if (dLdh) *dLdhI += *dLdh; + if (dLdhL) *dLdhI += *dLdhL; + if (dLdcL) *dLdcI += *dLdcL; + + *dLdcI += *dLdhI * dhdc; + + dLdzi *= *dLdcI; // [bS, nOut](or[nOut]) + dLdzf *= *dLdcI; // [bS, nOut](or[nOut]) + dLdzg *= *dLdcI; // [bS, nOut](or[nOut]) + dLdzo *= *dLdhI; // [bS, nOut](or[nOut]) + + // dLdx + NDArray WxT = Wx->transpose(); + MmulHelper::mmul(&dLdz, &WxT, + dLdx); // [bS, 4*nOut] x [4*nOut, nIn] (or [4*nOut] x + // [4*nOut, nIn]) = [bS, nIn] ( or[nIn] ) + + // dLdhI + NDArray WrT = Wr->transpose(); + MmulHelper::mmul(&dLdz, &WrT, + dLdhI); // [bS, 4*nOut] x [4*nOut, nOut] (or [4*nOut] x + // [4*nOut, nOut]) = [bS, nOut] ( or[nOut] ) + + // dLdcI + dLdcI->assign(*dLdcI * dcdcI); // [bS, nOut](or[nOut]) + + if (x->rankOf() == 1) { + NDArray xT = x->reshape(x->ordering(), {nIn, 1}); // [nIn] -> [nIn, 1] + NDArray hIT = + hI->reshape(hI->ordering(), {nOut, 1}); // [nOut] -> [nOut, 1] + NDArray dLdzR = + dLdz.reshape(dLdz.ordering(), {1, 4 * nOut}); // [nOut] -> [1, 4*nOut] + + // dLdWx + *dLdWx += mmul(xT, dLdzR); // [nIn, 1] x [1, 4*nOut] = [nIn, 4*nOut] + + // dLdWr + *dLdWr += mmul(hIT, dLdzR); // [nOut, 1] x [1, 4*nOut] = [nOut, 4*nOut] + } else { + // dLdWx + *dLdWx += + mmul(x->transpose(), dLdz); // [nIn, bS] x [bS, 4*nOut] = [nIn, 4*nOut] + + // dLdWr + *dLdWr += mmul(hI->transpose(), + dLdz); // [nOut, bS] x [bS, 4*nOut] = [nOut, 4*nOut] + } + + // dLdb + if (b && x->rankOf() == 1) + *dLdb += dLdz; // [4*nOut] + else if (b) + *dLdb += dLdz.reduceAlongDimension( + reduce::Sum, {0}); // [bS, 4*nOut] -> reduce -> [4*nOut]; + + // dLdWp + if (Wp && x->rankOf() == 1) { + (*dLdWp)({0, nOut}) += std::move(dLdzi) * (*cI); // [nOut] + (*dLdWp)({nOut, 2 * nOut}) += std::move(dLdzf) * (*cI); // [nOut] + (*dLdWp)({2 * nOut, 3 * nOut}) += std::move(dLdzo) * (*c); // [nOut] + } else if (Wp) { + NDArray temp(Wp->ordering(), {nOut}, Wp->dataType(), Wp->getContext()); + // [bS, nOut] -> reduce -> [nOut] + (std::move(dLdzi) * (*cI)).reduceAlongDimension(reduce::Sum, temp, {0}); + (*dLdWp)({0, nOut}) += temp; + + // [bS, nOut] -> reduce -> [nOut] + (std::move(dLdzf) * (*cI)).reduceAlongDimension(reduce::Sum, temp, {0}); + (*dLdWp)({nOut, 2 * nOut}) += temp; + + // [bS, nOut] -> reduce -> [nOut] + (std::move(dLdzo) * (*c)).reduceAlongDimension(reduce::Sum, temp, {0}); + (*dLdWp)({2 * nOut, 3 * nOut}) += temp; + } } ////////////////////////////////////////////////////////////////////////// void lstmLayerTimeLoop(const NDArray* x, const NDArray* Wx, const NDArray* Wr, - const NDArray* b, const NDArray* seqLen, const NDArray* hI, const NDArray* cI, const NDArray* Wp, - const std::vector& params, - const bool forward, + const NDArray* b, const NDArray* seqLen, + const NDArray* hI, const NDArray* cI, const NDArray* Wp, + const std::vector& params, const bool forward, NDArray* h, NDArray* hL, NDArray* cL) { - - // INPUTS: - // x - current input [sL, bS, nIn], [bS, sL, nIn], [bS, nIn, sL], - // Wx - input weights [nIn, 4*nOut] - // Wr - recurrent weights [nOut, 4*nOut] - // b - biases [4*nOut], optional, may be nullptr - // seqLen - [bS], optional, may be nullptr - // hI - initial output [bS, nOut], optional, may be nullptr - // cI - initial cell state at time t-1 [bS, nOut], optional, may be nullptr - // Wp - peephole weights [3*nOut], optional, may be nullptr - - // OUTPUTS: - // h - output [sL, bS, nOut], [bS, sL, nOut], [bS, nOut, sL], optional, may be nullptr - // hL - output at last step [bS, nOut], optional, may be nullptr - // cL - cell state at last step [bS, nOut], optional, may be nullptr - - // params = {dataFormat, directionMode, cellClip, gateAct, gateAlpha, gateBeta, cellAct, cellAlpha, cellBeta, outAct, outAlpha, outBeta}; - // dataFormat: 0,3 = [sL, bS, nIn], 1 = [bS, sL ,nIn], 2 = [bS, nIn, sL] - - const int dataFormat = params[0]; - const int directionMode = params[1]; - - const Nd4jLong sL = dataFormat == 3 ? x->sizeAt(0) : x->sizeAt(dataFormat); - const Nd4jLong bS = dataFormat == 1 || dataFormat == 2 ? x->sizeAt(0) : x->sizeAt(1); - const Nd4jLong nOut = Wx->sizeAt(-1) / 4; - - const std::vector shapeOut = {bS, nOut}; - - const auto type = h ? h->dataType() : (hL ? hL->dataType() : cL->dataType()); - - auto h0 = const_cast(hI); - if(!hI) { - h0 = new NDArray(x->ordering(), shapeOut, type, x->getContext()); - h0->nullify(); - } - - auto c0 = const_cast(cI); - if(!cI) { - c0 = new NDArray(x->ordering(), shapeOut, type, x->getContext()); - c0->nullify(); - } - - auto ct = cL; - if(!cL) - ct = new NDArray(x->ordering(), shapeOut, type, x->getContext()); - - auto ht = hL; - if(!h && !hL) - ht = new NDArray(x->ordering(), shapeOut, type, x->getContext()); - - // create sets of required (depends on seqLen presence) sub-arrays - std::vector dims; - ResultSet *xSet(nullptr), *hSet(nullptr), *h0Set(nullptr), *c0Set(nullptr), *htSet(nullptr), *ctSet(nullptr); - - if(!seqLen) { - - dims = ShapeUtils::evalDimsToExclude(x->rankOf(), {dataFormat < 3 ? dataFormat : 0}); // points on bS and nIn/nOut axes - - xSet = new ResultSet(x->allTensorsAlongDimension(dims)); // sub-arrays with shape [bS, nIn] - if(h) - hSet = new ResultSet(h->allTensorsAlongDimension(dims)); // sub-arrays with shape [bS, nOut] - } - else { - - dims = dataFormat == 2 ? std::vector({1}) : std::vector({2}); // points on nIn/nOut axis - - xSet = new ResultSet(x->allTensorsAlongDimension(dims)); // sub-arrays with shape [nIn] - h0Set = new ResultSet(h0->allTensorsAlongDimension({1})); // sub-arrays with shape [nOut] - c0Set = new ResultSet(c0->allTensorsAlongDimension({1})); // sub-arrays with shape [nOut] - ctSet = new ResultSet(ct->allTensorsAlongDimension({1})); // sub-arrays with shape [nOut] - if(h) - hSet = new ResultSet(h->allTensorsAlongDimension(dims)); // sub-arrays with shape [nOut] - if(ht) - htSet = new ResultSet(ht->allTensorsAlongDimension({1})); // sub-arrays with shape [nOut] - } - - // loops - if(forward) { - - if(!seqLen) { - - if(!h) { // seqLen and h are absent - - lstmLayerCell(xSet->at(0), Wx, Wr, b, h0, c0, Wp, params, ht, ct); // first time step - for (Nd4jLong t = 1; t < sL; ++t) - lstmLayerCell(xSet->at(t), Wx, Wr, b, ht, ct, Wp, params, ht, ct); // rest time steps - } - else { // seqLen is absent and h is present - - lstmLayerCell(xSet->at(0), Wx, Wr, b, h0, c0, Wp, params, hSet->at(0), ct); // first time step - for (Nd4jLong t = 1; t < sL; ++t) - lstmLayerCell(xSet->at(t), Wx, Wr, b, hSet->at(t - 1), ct, Wp, params, hSet->at(t), ct); // rest time steps - - if(hL) - hL->assign(hSet->at(sL - 1)); // assign last output to hL if it is not nullptr - } + // INPUTS: + // x - current input [sL, bS, nIn], [bS, sL, nIn], [bS, nIn, sL], + // Wx - input weights [nIn, 4*nOut] + // Wr - recurrent weights [nOut, 4*nOut] + // b - biases [4*nOut], optional, may be nullptr + // seqLen - [bS], optional, may be nullptr + // hI - initial output [bS, nOut], optional, may be nullptr + // cI - initial cell state at time t-1 [bS, nOut], optional, may be nullptr + // Wp - peephole weights [3*nOut], optional, may be nullptr + + // OUTPUTS: + // h - output [sL, bS, nOut], [bS, sL, nOut], [bS, nOut, sL], optional, may + // be nullptr hL - output at last step [bS, nOut], optional, may be nullptr cL + // - cell state at last step [bS, nOut], optional, may be nullptr + + // params = {dataFormat, directionMode, cellClip, gateAct, gateAlpha, + // gateBeta, cellAct, cellAlpha, cellBeta, outAct, outAlpha, outBeta}; + // dataFormat: 0,3 = [sL, bS, nIn], 1 = [bS, sL ,nIn], 2 = [bS, nIn, sL] + + const int dataFormat = params[0]; + const int directionMode = params[1]; + + const Nd4jLong sL = dataFormat == 3 ? x->sizeAt(0) : x->sizeAt(dataFormat); + const Nd4jLong bS = + dataFormat == 1 || dataFormat == 2 ? x->sizeAt(0) : x->sizeAt(1); + const Nd4jLong nOut = Wx->sizeAt(-1) / 4; + + const std::vector shapeOut = {bS, nOut}; + + const auto type = h ? h->dataType() : (hL ? hL->dataType() : cL->dataType()); + + auto h0 = const_cast(hI); + if (!hI) { + h0 = new NDArray(x->ordering(), shapeOut, type, x->getContext()); + h0->nullify(); + } + + auto c0 = const_cast(cI); + if (!cI) { + c0 = new NDArray(x->ordering(), shapeOut, type, x->getContext()); + c0->nullify(); + } + + auto ct = cL; + if (!cL) ct = new NDArray(x->ordering(), shapeOut, type, x->getContext()); + + auto ht = hL; + if (!h && !hL) + ht = new NDArray(x->ordering(), shapeOut, type, x->getContext()); + + // create sets of required (depends on seqLen presence) sub-arrays + std::vector dims; + ResultSet *xSet(nullptr), *hSet(nullptr), *h0Set(nullptr), *c0Set(nullptr), + *htSet(nullptr), *ctSet(nullptr); + + if (!seqLen) { + dims = ShapeUtils::evalDimsToExclude( + x->rankOf(), + {dataFormat < 3 ? dataFormat : 0}); // points on bS and nIn/nOut axes + + xSet = new ResultSet( + x->allTensorsAlongDimension(dims)); // sub-arrays with shape [bS, nIn] + if (h) + hSet = new ResultSet(h->allTensorsAlongDimension( + dims)); // sub-arrays with shape [bS, nOut] + } else { + dims = dataFormat == 2 ? std::vector({1}) + : std::vector({2}); // points on nIn/nOut axis + + xSet = new ResultSet( + x->allTensorsAlongDimension(dims)); // sub-arrays with shape [nIn] + h0Set = new ResultSet( + h0->allTensorsAlongDimension({1})); // sub-arrays with shape [nOut] + c0Set = new ResultSet( + c0->allTensorsAlongDimension({1})); // sub-arrays with shape [nOut] + ctSet = new ResultSet( + ct->allTensorsAlongDimension({1})); // sub-arrays with shape [nOut] + if (h) + hSet = new ResultSet( + h->allTensorsAlongDimension(dims)); // sub-arrays with shape [nOut] + if (ht) + htSet = new ResultSet( + ht->allTensorsAlongDimension({1})); // sub-arrays with shape [nOut] + } + + // loops + if (forward) { + if (!seqLen) { + if (!h) { // seqLen and h are absent + + lstmLayerCell(&xSet->at(0), Wx, Wr, b, h0, c0, Wp, params, ht, + ct); // first time step + for (Nd4jLong t = 1; t < sL; ++t) + lstmLayerCell(&xSet->at(t), Wx, Wr, b, ht, ct, Wp, params, ht, + ct); // rest time steps + } else { // seqLen is absent and h is present + + lstmLayerCell(&xSet->at(0), Wx, Wr, b, h0, c0, Wp, params, &hSet->at(0), + ct); // first time step + for (Nd4jLong t = 1; t < sL; ++t) + lstmLayerCell(&xSet->at(t), Wx, Wr, b, &hSet->at(t - 1), ct, Wp, + params, &hSet->at(t), ct); // rest time steps + + if (hL) + hL->assign(hSet->at( + sL - 1)); // assign last output to hL if it is not nullptr + } + } else { + if (!h) { // seqLen is present and h is absent + + for (Nd4jLong e = 0; e < bS; ++e) { + const int limit = seqLen->e(e); + + if (limit == 0) { + if (cL) ctSet->at(e).nullify(); + if (hL) htSet->at(e).nullify(); + continue; + } + + auto ind = getBatchTimeTotalIndex(dataFormat, sL, bS, 0, e); + lstmLayerCell(&xSet->at(ind), Wx, Wr, b, &h0Set->at(e), &c0Set->at(e), + Wp, params, &htSet->at(e), + &ctSet->at(e)); // first time step + + for (int t = 1; t < limit; ++t) { + ind = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e); + lstmLayerCell(&xSet->at(ind), Wx, Wr, b, &htSet->at(e), + &ctSet->at(e), Wp, params, &htSet->at(e), + &ctSet->at(e)); // rest time steps + } } - else { - - if(!h) { // seqLen is present and h is absent - - for (Nd4jLong e = 0; e < bS; ++e) { - - const int limit = seqLen->e(e); - - if(limit == 0) { - if(cL) - ctSet->at(e)->nullify(); - if(hL) - htSet->at(e)->nullify(); - continue; - } - - auto ind = getBatchTimeTotalIndex(dataFormat, sL, bS, 0, e); - lstmLayerCell(xSet->at(ind), Wx, Wr, b, h0Set->at(e), c0Set->at(e), Wp, params, htSet->at(e), ctSet->at(e)); // first time step + } else { // seqLen and h are present - for (int t = 1; t < limit; ++t) { - ind = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e); - lstmLayerCell(xSet->at(ind), Wx, Wr, b, htSet->at(e), ctSet->at(e), Wp, params, htSet->at(e), ctSet->at(e)); // rest time steps - } - } - } - else { // seqLen and h are present + for (Nd4jLong e = 0; e < bS; ++e) { + int limit = seqLen->e(e); - for (Nd4jLong e = 0; e < bS; ++e) { + if (limit == 0) { + tensorAlongTimeBatchDims(*h, dataFormat, 0, 0, e, e + 1) + .nullify(); // nullify for given e and whole time range - int limit = seqLen->e(e); + if (cL) ctSet->at(e).nullify(); + if (hL) htSet->at(e).nullify(); - if(limit == 0) { + continue; + } - tensorAlongTimeBatchDims(*h, dataFormat, 0,0, e,e+1).nullify(); // nullify for given e and whole time range + auto indPrev = getBatchTimeTotalIndex(dataFormat, sL, bS, 0, e); + lstmLayerCell(&xSet->at(indPrev), Wx, Wr, b, &h0Set->at(e), + &c0Set->at(e), Wp, params, &hSet->at(indPrev), + &ctSet->at(e)); // first time step - if(cL) - ctSet->at(e)->nullify(); - if(hL) - htSet->at(e)->nullify(); + for (int t = 1; t < limit; ++t) { + auto indCurr = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e); + lstmLayerCell(&xSet->at(indCurr), Wx, Wr, b, &hSet->at(indPrev), + &ctSet->at(e), Wp, params, &hSet->at(indCurr), + &ctSet->at(e)); // rest time steps + indPrev = indCurr; + } - continue; - } + if (hL) + htSet->at(e).assign(hSet->at( + indPrev)); // assign last output to hL if hL is not nullptr - auto indPrev = getBatchTimeTotalIndex(dataFormat, sL, bS, 0, e); - lstmLayerCell(xSet->at(indPrev), Wx, Wr, b, h0Set->at(e), c0Set->at(e), Wp, params, hSet->at(indPrev), ctSet->at(e)); // first time step - - for (int t = 1; t < limit; ++t) { - auto indCurr = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e); - lstmLayerCell(xSet->at(indCurr), Wx, Wr, b, hSet->at(indPrev), ctSet->at(e), Wp, params, hSet->at(indCurr), ctSet->at(e)); // rest time steps - indPrev = indCurr; - } - - if(hL) - htSet->at(e)->assign(hSet->at(indPrev)); // assign last output to hL if hL is not nullptr - - if(limit != sL) - tensorAlongTimeBatchDims(*h, dataFormat, limit,sL, e,e+1).nullify(); // nullify for given e and time range [limit, sL) - } - } + if (limit != sL) + tensorAlongTimeBatchDims(*h, dataFormat, limit, sL, e, e + 1) + .nullify(); // nullify for given e and time range [limit, sL) } + } } - else { // backward - - if(!seqLen) { - - if(!h) { // seqLen and h are absent - - lstmLayerCell(xSet->at(sL - 1), Wx, Wr, b, h0, c0, Wp, params, ht, ct); // first time step - for (Nd4jLong t = sL - 2; t >= 0; --t) - lstmLayerCell(xSet->at(t), Wx, Wr, b, ht, ct, Wp, params, ht, ct); // rest time steps - } - else { // seqLen is absent and h is present - - lstmLayerCell(xSet->at(sL - 1), Wx, Wr, b, h0, c0, Wp, params, hSet->at(sL - 1), ct); // first time step - for (Nd4jLong t = sL - 2; t >= 0; --t) - lstmLayerCell(xSet->at(t), Wx, Wr, b, hSet->at(t + 1), ct, Wp, params, hSet->at(t), ct); // rest time steps - - if(hL) - hL->assign(hSet->at(0)); // assign last output to hL if it is not nullptr - } + } else { // backward + + if (!seqLen) { + if (!h) { // seqLen and h are absent + + lstmLayerCell(&xSet->at(sL - 1), Wx, Wr, b, h0, c0, Wp, params, ht, + ct); // first time step + for (Nd4jLong t = sL - 2; t >= 0; --t) + lstmLayerCell(&xSet->at(t), Wx, Wr, b, ht, ct, Wp, params, ht, + ct); // rest time steps + } else { // seqLen is absent and h is present + + lstmLayerCell(&xSet->at(sL - 1), Wx, Wr, b, h0, c0, Wp, params, + &hSet->at(sL - 1), ct); // first time step + for (Nd4jLong t = sL - 2; t >= 0; --t) + lstmLayerCell(&xSet->at(t), Wx, Wr, b, &hSet->at(t + 1), ct, Wp, + params, &hSet->at(t), ct); // rest time steps + + if (hL) + hL->assign( + hSet->at(0)); // assign last output to hL if it is not nullptr + } + } else if (directionMode == 1) { // only backward, no bidirectional mode + + if (!h) { // h is absent and seqLen is present + + for (Nd4jLong e = 0; e < bS; ++e) { + const int limit = seqLen->e(e); + + if (limit == 0) { + if (cL) ctSet->at(e).nullify(); + if (hL) htSet->at(e).nullify(); + continue; + } + + auto ind = getBatchTimeTotalIndex(dataFormat, sL, bS, sL - 1, e); + lstmLayerCell(&xSet->at(ind), Wx, Wr, b, &h0Set->at(e), &c0Set->at(e), + Wp, params, &htSet->at(e), + &ctSet->at(e)); // first time step + + for (Nd4jLong t = sL - 2; t >= sL - limit; --t) { + ind = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e); + lstmLayerCell(&xSet->at(ind), Wx, Wr, b, &htSet->at(e), + &ctSet->at(e), Wp, params, &htSet->at(e), + &ctSet->at(e)); // rest time steps + } } - else if(directionMode == 1) { // only backward, no bidirectional mode - - if(!h) { // h is absent and seqLen is present - - for (Nd4jLong e = 0; e < bS; ++e) { + } else { // seqLen and h are present - const int limit = seqLen->e(e); + for (Nd4jLong e = 0; e < bS; ++e) { + int limit = seqLen->e(e); - if(limit == 0) { - if(cL) - ctSet->at(e)->nullify(); - if(hL) - htSet->at(e)->nullify(); - continue; - } + if (limit == 0) { + tensorAlongTimeBatchDims(*h, dataFormat, 0, 0, e, e + 1) + .nullify(); // nullify for given e and whole time range - auto ind = getBatchTimeTotalIndex(dataFormat, sL, bS, sL - 1, e); - lstmLayerCell(xSet->at(ind), Wx, Wr, b, h0Set->at(e), c0Set->at(e), Wp, params, htSet->at(e), ctSet->at(e)); // first time step + if (cL) ctSet->at(e).nullify(); + if (hL) htSet->at(e).nullify(); - for (Nd4jLong t = sL - 2; t >= sL - limit; --t) { - ind = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e); - lstmLayerCell(xSet->at(ind), Wx, Wr, b, htSet->at(e), ctSet->at(e), Wp, params, htSet->at(e), ctSet->at(e)); // rest time steps - } - } - } - else { // seqLen and h are present + continue; + } - for (Nd4jLong e = 0; e < bS; ++e) { + auto indPrev = getBatchTimeTotalIndex(dataFormat, sL, bS, sL - 1, e); + lstmLayerCell(&xSet->at(indPrev), Wx, Wr, b, &h0Set->at(e), + &c0Set->at(e), Wp, params, &hSet->at(indPrev), + &ctSet->at(e)); // first time step - int limit = seqLen->e(e); + for (Nd4jLong t = sL - 2; t >= sL - limit; --t) { + auto indCurr = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e); + lstmLayerCell(&xSet->at(indCurr), Wx, Wr, b, &hSet->at(indPrev), + &ctSet->at(e), Wp, params, &hSet->at(indCurr), + &ctSet->at(e)); // rest time steps + indPrev = indCurr; + } - if(limit == 0) { + if (hL) + htSet->at(e).assign(hSet->at( + indPrev)); // assign last output to hL if it is not nullptr - tensorAlongTimeBatchDims(*h, dataFormat, 0,0, e,e+1).nullify(); // nullify for given e and whole time range - - if(cL) - ctSet->at(e)->nullify(); - if(hL) - htSet->at(e)->nullify(); - - continue; - } - - auto indPrev = getBatchTimeTotalIndex(dataFormat, sL, bS, sL - 1, e); - lstmLayerCell(xSet->at(indPrev), Wx, Wr, b, h0Set->at(e), c0Set->at(e), Wp, params, hSet->at(indPrev), ctSet->at(e)); // first time step - - for (Nd4jLong t = sL - 2; t >= sL - limit; --t) { - auto indCurr = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e); - lstmLayerCell(xSet->at(indCurr), Wx, Wr, b, hSet->at(indPrev), ctSet->at(e), Wp, params, hSet->at(indCurr), ctSet->at(e)); // rest time steps - indPrev = indCurr; - } - - if(hL) - htSet->at(e)->assign(hSet->at(indPrev)); // assign last output to hL if it is not nullptr - - if(limit != sL) - tensorAlongTimeBatchDims(*h, dataFormat, 0,sL-limit, e,e+1).nullify(); // nullify for given e and time range [limit, sL) - } - } + if (limit != sL) + tensorAlongTimeBatchDims(*h, dataFormat, 0, sL - limit, e, e + 1) + .nullify(); // nullify for given e and time range [limit, sL) } - else { // backward in bidirectional mode - - if(!h) { // h is absent and seqLen is present - - for (Nd4jLong e = 0; e < bS; ++e) { - - const int limit = seqLen->e(e); - - if(limit == 0) { - if(cL) - ctSet->at(e)->nullify(); - if(hL) - htSet->at(e)->nullify(); - continue; - } - - auto ind = getBatchTimeTotalIndex(dataFormat, sL, bS, limit - 1, e); - lstmLayerCell(xSet->at(ind), Wx, Wr, b, h0Set->at(e), c0Set->at(e), Wp, params, htSet->at(e), ctSet->at(e)); // first time step - - for (int t = limit - 2; t >= 0; --t) { - ind = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e); - lstmLayerCell(xSet->at(ind), Wx, Wr, b, htSet->at(e), ctSet->at(e), Wp, params, htSet->at(e), ctSet->at(e)); // rest time steps - } - } - } - else { // seqLen and h are present - - for (Nd4jLong e = 0; e < bS; ++e) { - - int limit = seqLen->e(e); - - if(limit == 0) { - - tensorAlongTimeBatchDims(*h, dataFormat, 0,0, e,e+1).nullify(); // nullify for given e and whole time range - - if(cL) - ctSet->at(e)->nullify(); - if(hL) - htSet->at(e)->nullify(); - - continue; - } - - auto indPrev = getBatchTimeTotalIndex(dataFormat, sL, bS, limit - 1, e); - lstmLayerCell(xSet->at(indPrev), Wx, Wr, b, h0Set->at(e), c0Set->at(e), Wp, params, hSet->at(indPrev), ctSet->at(e)); // first time step - - for (int t = limit - 2; t >= 0; --t) { - auto indCurr = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e); - lstmLayerCell(xSet->at(indCurr), Wx, Wr, b, hSet->at(indPrev), ctSet->at(e), Wp, params, hSet->at(indCurr), ctSet->at(e)); // rest time steps - indPrev = indCurr; - } - - if(hL) - htSet->at(e)->assign(hSet->at(indPrev)); // assign last output to hL if it is not nullptr - - if(limit != sL) - tensorAlongTimeBatchDims(*h, dataFormat, limit,sL, e,e+1).nullify(); // nullify for given e and time range [limit, sL) - } - } + } + } else { // backward in bidirectional mode + + if (!h) { // h is absent and seqLen is present + + for (Nd4jLong e = 0; e < bS; ++e) { + const int limit = seqLen->e(e); + + if (limit == 0) { + if (cL) ctSet->at(e).nullify(); + if (hL) htSet->at(e).nullify(); + continue; + } + + auto ind = getBatchTimeTotalIndex(dataFormat, sL, bS, limit - 1, e); + lstmLayerCell(&xSet->at(ind), Wx, Wr, b, &h0Set->at(e), &c0Set->at(e), + Wp, params, &htSet->at(e), + &ctSet->at(e)); // first time step + + for (int t = limit - 2; t >= 0; --t) { + ind = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e); + lstmLayerCell(&xSet->at(ind), Wx, Wr, b, &htSet->at(e), + &ctSet->at(e), Wp, params, &htSet->at(e), + &ctSet->at(e)); // rest time steps + } + } + } else { // seqLen and h are present + + for (Nd4jLong e = 0; e < bS; ++e) { + int limit = seqLen->e(e); + + if (limit == 0) { + tensorAlongTimeBatchDims(*h, dataFormat, 0, 0, e, e + 1) + .nullify(); // nullify for given e and whole time range + + if (cL) ctSet->at(e).nullify(); + if (hL) htSet->at(e).nullify(); + + continue; + } + + auto indPrev = + getBatchTimeTotalIndex(dataFormat, sL, bS, limit - 1, e); + lstmLayerCell(&xSet->at(indPrev), Wx, Wr, b, &h0Set->at(e), + &c0Set->at(e), Wp, params, &hSet->at(indPrev), + &ctSet->at(e)); // first time step + + for (int t = limit - 2; t >= 0; --t) { + auto indCurr = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e); + lstmLayerCell(&xSet->at(indCurr), Wx, Wr, b, &hSet->at(indPrev), + &ctSet->at(e), Wp, params, &hSet->at(indCurr), + &ctSet->at(e)); // rest time steps + indPrev = indCurr; + } + + if (hL) + htSet->at(e).assign(hSet->at( + indPrev)); // assign last output to hL if it is not nullptr + + if (limit != sL) + tensorAlongTimeBatchDims(*h, dataFormat, limit, sL, e, e + 1) + .nullify(); // nullify for given e and time range [limit, sL) } + } } - - delete xSet; - delete hSet; - delete h0Set; - delete c0Set; - delete htSet; - delete ctSet; - - if(!hI) - delete h0; - if(!cI) - delete c0; - if(!cL) - delete ct; - if(!h && !hL) - delete ht; + } + + delete xSet; + delete hSet; + delete h0Set; + delete c0Set; + delete htSet; + delete ctSet; + + if (!hI) delete h0; + if (!cI) delete c0; + if (!cL) delete ct; + if (!h && !hL) delete ht; } - ////////////////////////////////////////////////////////////////////////// void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, - const NDArray* b, const NDArray* seqLen, NDArray* hI, NDArray* cI, const NDArray* Wp, - const NDArray* dLdh, const NDArray* dLdhL, const NDArray* dLdcL, + const NDArray* b, const NDArray* seqLen, NDArray* hI, + NDArray* cI, const NDArray* Wp, const NDArray* dLdh, + const NDArray* dLdhL, const NDArray* dLdcL, const std::vector& params, const bool forward, - NDArray* dLdx, NDArray* dLdWx, NDArray* dLdWr, NDArray* dLdb, NDArray* dLdhI, NDArray* dLdcI, NDArray* dLdWp) { - - // INPUTS: - // x - current input [sL, bS, nIn], [bS, sL, nIn], [bS, nIn, sL], - // Wx - input weights [nIn, 4*nOut] - // Wr - recurrent weights [nOut, 4*nOut] - // b - biases [4*nOut], optional, may be nullptr - // seqLen - [bS], optional, may be nullptr - // hI - initial output [bS, nOut], optional, may be nullptr - // cI - initial cell state at time t-1 [bS, nOut], optional, may be nullptr - // Wp - peephole weights [3*nOut], optional, may be nullptr - // dLdh - gradient vs. output [sL, bS, nOut], [bS, sL, nOut], [bS, nOut, sL], optional, may be nullptr - // dLdhL - gradient vs. output at last time step [bS, nOut], optional, may be nullptr - // dLdcL - gradient vs. cell state at last time step [bS, nOut], optional, may be nullptr - - // OUTPUTS: - // dLdx - gradient vs. input [sL, bS, nIn], [bS, sL, nIn], [bS, nIn, sL] - // dLdWx - gradient vs. input weights [nIn, 4*nOut] - // dLdWr - gradient vs. recurrent weights [nOut, 4*nOut] - // dLdb - gradient vs. biases [4*nOut], optional, may be nullptr - // dLdhI - gradient vs. initial output [bS, nOut], optional, may be nullptr - // dLdcI - gradient vs. initial cell state at time t-1 [bS, nOut], optional, may be nullptr - // dLdWp - gradient vs. peephole weights [3*nOut], optional, may be nullptr - - // params = {dataFormat, directionMode, cellClip, gateAct, gateAlpha, gateBeta, cellAct, cellAlpha, cellBeta, outAct, outAlpha, outBeta}; - // dataFormat: 0,3 = [sL, bS, nIn], 1 = [bS, sL ,nIn], 2 = [bS, nIn, sL] - - const int dataFormat = params[0]; - const int directionMode = params[1]; - - const int sL = dataFormat == 3 ? x->sizeAt(0) : x->sizeAt(dataFormat); - const int bS = dataFormat == 1 || dataFormat == 2 ? x->sizeAt(0) : x->sizeAt(1); - const int nOut = Wx->sizeAt(-1) / 4; - - const auto type = dLdh ? dLdh->dataType() : (dLdhL ? dLdhL->dataType() : dLdcL->dataType()); - - auto dLdh0 = dLdhI; - if(!hI) - dLdh0 = new NDArray(x->ordering(), {bS, nOut}, type, x->getContext()); // this constructor nullifies array automatically - - auto dLdc0 = dLdcI; - if(!cI) - dLdc0 = new NDArray(x->ordering(), {bS, nOut}, type, x->getContext()); // this constructor nullifies array automatically - - NDArray z(x->ordering(), {sL, bS, 4*nOut}, type, x->getContext()); - NDArray a = z.ulike(); - NDArray h(x->ordering(), {sL+1, bS, nOut}, type, x->getContext()); - NDArray c = h.ulike(); - - // create sets of required (depends on seqLen presence) sub-arrays - std::vector dims; - ResultSet *xSet(nullptr), *dLdxSet(nullptr), *hSet(nullptr), *cSet(nullptr), *zSet(nullptr), *aSet(nullptr), *dLdhSet(nullptr), - *dLdh0Set(nullptr), *dLdc0Set(nullptr), *dLdhLSet(nullptr), *dLdcLSet(nullptr), *hISet(nullptr), *cISet(nullptr); - - if(!seqLen) { - - dims = ShapeUtils::evalDimsToExclude(x->rankOf(), {dataFormat < 3 ? dataFormat : 0}); // points on [bS, nIn/nOut] - - xSet = new ResultSet(x->allTensorsAlongDimension(dims)); // sub-arrays with shape [bS, nIn] - dLdxSet = new ResultSet(dLdx->allTensorsAlongDimension(dims)); // sub-arrays with shape [bS, nIn] - hSet = new ResultSet(h.allTensorsAlongDimension({1, 2})); // sub-arrays with shape [bS, nOut] - cSet = new ResultSet(c.allTensorsAlongDimension({1, 2})); // sub-arrays with shape [bS, nOut] - zSet = new ResultSet(z.allTensorsAlongDimension({1, 2})); // sub-arrays with shape [bS, 4*nOut] - aSet = new ResultSet(a.allTensorsAlongDimension({1, 2})); // sub-arrays with shape [bS, 4*nOut] - if(dLdh) - dLdhSet = new ResultSet(dLdh->allTensorsAlongDimension(dims)); // sub-arrays with shape [bS, nOut] - } - else { - - dims = dataFormat == 2 ? std::vector({1}) : std::vector({2}); // points on nIn/nOut axis - - xSet = new ResultSet(x->allTensorsAlongDimension(dims)); // sub-arrays with shape [nIn] - dLdxSet = new ResultSet(dLdx->allTensorsAlongDimension(dims)); // sub-arrays with shape [nIn] - hSet = new ResultSet(h.allTensorsAlongDimension({2})); // sub-arrays with shape [nOut] - cSet = new ResultSet(c.allTensorsAlongDimension({2})); // sub-arrays with shape [nOut] - zSet = new ResultSet(z.allTensorsAlongDimension({2})); // sub-arrays with shape [4*nOut] - aSet = new ResultSet(a.allTensorsAlongDimension({2})); // sub-arrays with shape [4*nOut] - - if(hI) - hISet = new ResultSet(hI->allTensorsAlongDimension({1})); // sub-arrays with shape [nOut] - if(cI) - cISet = new ResultSet(cI->allTensorsAlongDimension({1})); // sub-arrays with shape [nOut] - - dLdh0Set = new ResultSet(dLdh0->allTensorsAlongDimension({1})); // sub-arrays with shape [nOut] - dLdc0Set = new ResultSet(dLdc0->allTensorsAlongDimension({1})); // sub-arrays with shape [nOut] - - if(dLdh) - dLdhSet = new ResultSet(dLdh->allTensorsAlongDimension(dims)); // sub-arrays with shape [nOut] - if(dLdhL) - dLdhLSet = new ResultSet(dLdhL->allTensorsAlongDimension({1})); // sub-arrays with shape [nOut] - if(dLdcL) - dLdcLSet = new ResultSet(dLdcL->allTensorsAlongDimension({1})); // sub-arrays with shape [nOut] - } - - - // loops - if(forward) { - - if(!seqLen) { // seqLen is absent - - if(hI) - hSet->at(0)->assign(hI); - else - hSet->at(0)->nullify(); - if(cI) - cSet->at(0)->assign(cI); - else - cSet->at(0)->nullify(); - - // ff - for (int t = 0; t < sL; ++t) - lstmLayerCell(xSet->at(t), Wx, Wr, b, hSet->at(t), cSet->at(t), Wp, params, zSet->at(t), aSet->at(t), hSet->at(t+1), cSet->at(t+1)); - - // bp - for (int t = sL-1; t >= 0; --t) { - const NDArray* dLdhh = dLdh ? dLdhSet->at(t) : nullptr; - const NDArray* dLdhhL = (t == sL-1 && dLdhL) ? dLdhL : nullptr; - const NDArray* dLdccL = (t == sL-1 && dLdcL) ? dLdcL : nullptr; - lstmLayerCellBp(xSet->at(t), Wx, Wr, b, hSet->at(t), cSet->at(t), Wp, dLdhh, dLdhhL, dLdccL, - zSet->at(t), aSet->at(t), cSet->at(t+1), params, dLdxSet->at(t), dLdWx, dLdWr, dLdh0, dLdc0, dLdb, dLdWp); - } + NDArray* dLdx, NDArray* dLdWx, NDArray* dLdWr, + NDArray* dLdb, NDArray* dLdhI, NDArray* dLdcI, + NDArray* dLdWp) { + // INPUTS: + // x - current input [sL, bS, nIn], [bS, sL, nIn], [bS, nIn, sL], + // Wx - input weights [nIn, 4*nOut] + // Wr - recurrent weights [nOut, 4*nOut] + // b - biases [4*nOut], optional, may be nullptr + // seqLen - [bS], optional, may be nullptr + // hI - initial output [bS, nOut], optional, may be nullptr + // cI - initial cell state at time t-1 [bS, nOut], optional, may be nullptr + // Wp - peephole weights [3*nOut], optional, may be nullptr + // dLdh - gradient vs. output [sL, bS, nOut], [bS, sL, nOut], [bS, nOut, + // sL], optional, may be nullptr dLdhL - gradient vs. output at last time step + // [bS, nOut], optional, may be nullptr dLdcL - gradient vs. cell state at + // last time step [bS, nOut], optional, may be nullptr + + // OUTPUTS: + // dLdx - gradient vs. input [sL, bS, nIn], [bS, sL, nIn], [bS, nIn, sL] + // dLdWx - gradient vs. input weights [nIn, 4*nOut] + // dLdWr - gradient vs. recurrent weights [nOut, 4*nOut] + // dLdb - gradient vs. biases [4*nOut], optional, may be nullptr + // dLdhI - gradient vs. initial output [bS, nOut], optional, may be nullptr + // dLdcI - gradient vs. initial cell state at time t-1 [bS, nOut], optional, + // may be nullptr dLdWp - gradient vs. peephole weights [3*nOut], optional, + // may be nullptr + + // params = {dataFormat, directionMode, cellClip, gateAct, gateAlpha, + // gateBeta, cellAct, cellAlpha, cellBeta, outAct, outAlpha, outBeta}; + // dataFormat: 0,3 = [sL, bS, nIn], 1 = [bS, sL ,nIn], 2 = [bS, nIn, sL] + + const int dataFormat = params[0]; + const int directionMode = params[1]; + + const int sL = dataFormat == 3 ? x->sizeAt(0) : x->sizeAt(dataFormat); + const int bS = + dataFormat == 1 || dataFormat == 2 ? x->sizeAt(0) : x->sizeAt(1); + const int nOut = Wx->sizeAt(-1) / 4; + + const auto type = + dLdh ? dLdh->dataType() : (dLdhL ? dLdhL->dataType() : dLdcL->dataType()); + + auto dLdh0 = dLdhI; + if (!hI) + dLdh0 = new NDArray( + x->ordering(), {bS, nOut}, type, + x->getContext()); // this constructor nullifies array automatically + + auto dLdc0 = dLdcI; + if (!cI) + dLdc0 = new NDArray( + x->ordering(), {bS, nOut}, type, + x->getContext()); // this constructor nullifies array automatically + + NDArray z(x->ordering(), {sL, bS, 4 * nOut}, type, x->getContext()); + NDArray a = z.ulike(); + NDArray h(x->ordering(), {sL + 1, bS, nOut}, type, x->getContext()); + NDArray c = h.ulike(); + + // create sets of required (depends on seqLen presence) sub-arrays + std::vector dims; + ResultSet *xSet(nullptr), *dLdxSet(nullptr), *hSet(nullptr), *cSet(nullptr), + *zSet(nullptr), *aSet(nullptr), *dLdhSet(nullptr), *dLdh0Set(nullptr), + *dLdc0Set(nullptr), *dLdhLSet(nullptr), *dLdcLSet(nullptr), + *hISet(nullptr), *cISet(nullptr); + + if (!seqLen) { + dims = ShapeUtils::evalDimsToExclude( + x->rankOf(), + {dataFormat < 3 ? dataFormat : 0}); // points on [bS, nIn/nOut] + + xSet = new ResultSet( + x->allTensorsAlongDimension(dims)); // sub-arrays with shape [bS, nIn] + dLdxSet = new ResultSet(dLdx->allTensorsAlongDimension( + dims)); // sub-arrays with shape [bS, nIn] + hSet = new ResultSet(h.allTensorsAlongDimension( + {1, 2})); // sub-arrays with shape [bS, nOut] + cSet = new ResultSet(c.allTensorsAlongDimension( + {1, 2})); // sub-arrays with shape [bS, nOut] + zSet = new ResultSet(z.allTensorsAlongDimension( + {1, 2})); // sub-arrays with shape [bS, 4*nOut] + aSet = new ResultSet(a.allTensorsAlongDimension( + {1, 2})); // sub-arrays with shape [bS, 4*nOut] + if (dLdh) + dLdhSet = new ResultSet(dLdh->allTensorsAlongDimension( + dims)); // sub-arrays with shape [bS, nOut] + } else { + dims = dataFormat == 2 ? std::vector({1}) + : std::vector({2}); // points on nIn/nOut axis + + xSet = new ResultSet( + x->allTensorsAlongDimension(dims)); // sub-arrays with shape [nIn] + dLdxSet = new ResultSet( + dLdx->allTensorsAlongDimension(dims)); // sub-arrays with shape [nIn] + hSet = new ResultSet( + h.allTensorsAlongDimension({2})); // sub-arrays with shape [nOut] + cSet = new ResultSet( + c.allTensorsAlongDimension({2})); // sub-arrays with shape [nOut] + zSet = new ResultSet( + z.allTensorsAlongDimension({2})); // sub-arrays with shape [4*nOut] + aSet = new ResultSet( + a.allTensorsAlongDimension({2})); // sub-arrays with shape [4*nOut] + + if (hI) + hISet = new ResultSet( + hI->allTensorsAlongDimension({1})); // sub-arrays with shape [nOut] + if (cI) + cISet = new ResultSet( + cI->allTensorsAlongDimension({1})); // sub-arrays with shape [nOut] + + dLdh0Set = new ResultSet( + dLdh0->allTensorsAlongDimension({1})); // sub-arrays with shape [nOut] + dLdc0Set = new ResultSet( + dLdc0->allTensorsAlongDimension({1})); // sub-arrays with shape [nOut] + + if (dLdh) + dLdhSet = new ResultSet(dLdh->allTensorsAlongDimension( + dims)); // sub-arrays with shape [nOut] + if (dLdhL) + dLdhLSet = new ResultSet(dLdhL->allTensorsAlongDimension( + {1})); // sub-arrays with shape [nOut] + if (dLdcL) + dLdcLSet = new ResultSet(dLdcL->allTensorsAlongDimension( + {1})); // sub-arrays with shape [nOut] + } + + // loops + if (forward) { + if (!seqLen) { // seqLen is absent + + if (hI) + hSet->at(0).assign(hI); + else + hSet->at(0).nullify(); + if (cI) + cSet->at(0).assign(cI); + else + cSet->at(0).nullify(); + + // ff + for (int t = 0; t < sL; ++t) + lstmLayerCell(&xSet->at(t), Wx, Wr, b, &hSet->at(t), &cSet->at(t), Wp, + params, &zSet->at(t), &aSet->at(t), &hSet->at(t + 1), + &cSet->at(t + 1)); + + // bp + for (int t = sL - 1; t >= 0; --t) { + const NDArray* dLdhh = dLdh ? &dLdhSet->at(t) : nullptr; + const NDArray* dLdhhL = (t == sL - 1 && dLdhL) ? dLdhL : nullptr; + const NDArray* dLdccL = (t == sL - 1 && dLdcL) ? dLdcL : nullptr; + lstmLayerCellBp(&xSet->at(t), Wx, Wr, b, &hSet->at(t), &cSet->at(t), Wp, + dLdhh, dLdhhL, dLdccL, &zSet->at(t), &aSet->at(t), + &cSet->at(t + 1), params, &dLdxSet->at(t), dLdWx, dLdWr, + dLdh0, dLdc0, dLdb, dLdWp); + } + } else { // seqLen is present + + for (int e = 0; e < bS; ++e) { + const int limit = seqLen->e(e); + + if (limit == 0) { + tensorAlongTimeBatchDims(*dLdx, dataFormat, 0, 0, e, e + 1) + .nullify(); // nullify for given e and whole time range + continue; } - else { // seqLen is present - - for (int e = 0; e < bS; ++e) { - - const int limit = seqLen->e(e); - - if(limit == 0) { - tensorAlongTimeBatchDims(*dLdx, dataFormat, 0,0, e,e+1).nullify(); // nullify for given e and whole time range - continue; - } - - if(hI) - hSet->at(e)->assign(hISet->at(e)); - else - hSet->at(e)->nullify(); - if(cI) - cSet->at(e)->assign(cISet->at(e)); - else - cSet->at(e)->nullify(); - - // ff - for (int t = 0; t < limit; ++t) - lstmLayerCell(xSet->at(getBatchTimeTotalIndex(dataFormat, sL, bS, t, e)), Wx, Wr, b, hSet->at(t*bS + e), cSet->at(t*bS + e), Wp, params, - zSet->at(t*bS + e), aSet->at(t*bS + e), hSet->at((t+1)*bS + e), cSet->at((t+1)*bS + e)); - - // bp - for (int t = limit-1; t >= 0; --t) { - const auto ind = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e); - const NDArray* dLdhh = dLdh ? dLdhSet->at(ind) : nullptr; - const NDArray* dLdhhL = (t == limit-1 && dLdhL) ? dLdhLSet->at(e) : nullptr; - const NDArray* dLdccL = (t == limit-1 && dLdcL) ? dLdcLSet->at(e) : nullptr; - lstmLayerCellBp(xSet->at(ind), Wx, Wr, b, hSet->at(t*bS + e), cSet->at(t*bS + e), Wp, dLdhh, dLdhhL, dLdccL, - zSet->at(t*bS + e), aSet->at(t*bS + e), cSet->at((t+1)*bS + e), params, dLdxSet->at(ind), dLdWx, dLdWr, - dLdh0Set->at(e), dLdc0Set->at(e), dLdb, dLdWp); - } - - if(limit != sL) - tensorAlongTimeBatchDims(*dLdx, dataFormat, limit,sL, e,e+1).nullify(); // nullify for given e and time range [limit, sL) - } + + if (hI) + hSet->at(e).assign(hISet->at(e)); + else + hSet->at(e).nullify(); + if (cI) + cSet->at(e).assign(cISet->at(e)); + else + cSet->at(e).nullify(); + + // ff + for (int t = 0; t < limit; ++t) + lstmLayerCell( + &xSet->at(getBatchTimeTotalIndex(dataFormat, sL, bS, t, e)), Wx, + Wr, b, &hSet->at(t * bS + e), &cSet->at(t * bS + e), Wp, params, + &zSet->at(t * bS + e), &aSet->at(t * bS + e), + &hSet->at((t + 1) * bS + e), &cSet->at((t + 1) * bS + e)); + + // bp + for (int t = limit - 1; t >= 0; --t) { + const auto ind = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e); + const NDArray* dLdhh = dLdh ? &dLdhSet->at(ind) : nullptr; + const NDArray* dLdhhL = + (t == limit - 1 && dLdhL) ? &dLdhLSet->at(e) : nullptr; + const NDArray* dLdccL = + (t == limit - 1 && dLdcL) ? &dLdcLSet->at(e) : nullptr; + lstmLayerCellBp(&xSet->at(ind), Wx, Wr, b, &hSet->at(t * bS + e), + &cSet->at(t * bS + e), Wp, dLdhh, dLdhhL, dLdccL, + &zSet->at(t * bS + e), &aSet->at(t * bS + e), + &cSet->at((t + 1) * bS + e), params, + &dLdxSet->at(ind), dLdWx, dLdWr, &dLdh0Set->at(e), + &dLdc0Set->at(e), dLdb, dLdWp); } + + if (limit != sL) + tensorAlongTimeBatchDims(*dLdx, dataFormat, limit, sL, e, e + 1) + .nullify(); // nullify for given e and time range [limit, sL) + } } - else { // backward or bidirectional - - if(!seqLen) { // backward or bidirectional, seqLen is absent - - if(hI) - hSet->at(sL)->assign(hI); - else - hSet->at(sL)->nullify(); - if(cI) - cSet->at(sL)->assign(cI); - else - cSet->at(sL)->nullify(); - - // ff - for (int t = sL-1; t >= 0; --t) - lstmLayerCell(xSet->at(t), Wx, Wr, b, hSet->at(t+1), cSet->at(t+1), Wp, params, zSet->at(t), aSet->at(t), hSet->at(t), cSet->at(t)); - - // bp - for (int t = 0; t < sL; ++t) { - const NDArray* dLdhh = dLdh ? dLdhSet->at(t) : nullptr; - const NDArray* dLdhhL = (t == 0 && dLdhL) ? dLdhL : nullptr; - const NDArray* dLdccL = (t == 0 && dLdcL) ? dLdcL : nullptr; - lstmLayerCellBp(xSet->at(t), Wx, Wr, b, hSet->at(t+1), cSet->at(t+1), Wp, dLdhh, dLdhhL, dLdccL, - zSet->at(t), aSet->at(t), cSet->at(t), params, dLdxSet->at(t), dLdWx, dLdWr, dLdh0, dLdc0, dLdb, dLdWp); - } - } - else if(directionMode == 1) { // backward, seqLen is present - - for (int e = 0; e < bS; ++e) { - - const int limit = seqLen->e(e); - - if(limit == 0) { - tensorAlongTimeBatchDims(*dLdx, dataFormat, 0,0, e,e+1).nullify(); // nullify for given e and whole time range - continue; - } - - if(hI) - hSet->at(sL*bS + e)->assign(hISet->at(e)); - else - hSet->at(sL*bS + e)->nullify(); - if(cI) - cSet->at(sL*bS + e)->assign(cISet->at(e)); - else - cSet->at(sL*bS + e)->nullify(); - - // ff - for (int t = sL - 1; t >= sL-limit; --t) - lstmLayerCell(xSet->at(getBatchTimeTotalIndex(dataFormat, sL, bS, t, e)), Wx, Wr, b, hSet->at((t+1)*bS + e), cSet->at((t+1)*bS + e), Wp, params, - zSet->at(t*bS + e), aSet->at(t*bS + e), hSet->at(t*bS + e), cSet->at(t*bS + e)); - - // bp - for (int t = sL-limit; t < sL; ++t) { - const auto ind = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e); - const NDArray* dLdhh = dLdh ? dLdhSet->at(ind) : nullptr; - const NDArray* dLdhhL = (t == sL-limit && dLdhL) ? dLdhLSet->at(e) : nullptr; - const NDArray* dLdccL = (t == sL-limit && dLdcL) ? dLdcLSet->at(e) : nullptr; - lstmLayerCellBp(xSet->at(ind), Wx, Wr, b, hSet->at((t+1)*bS + e), cSet->at((t+1)*bS + e), Wp, dLdhh, dLdhhL, dLdccL, - zSet->at(t*bS + e), aSet->at(t*bS + e), cSet->at(t*bS + e), params, dLdxSet->at(ind), dLdWx, dLdWr, - dLdh0Set->at(e), dLdc0Set->at(e), dLdb, dLdWp); - } - - if(limit != sL) - tensorAlongTimeBatchDims(*dLdx, dataFormat, 0,sL-limit, e,e+1).nullify(); // nullify for given e and time range [limit, sL) - } + } else { // backward or bidirectional + + if (!seqLen) { // backward or bidirectional, seqLen is absent + + if (hI) + hSet->at(sL).assign(hI); + else + hSet->at(sL).nullify(); + if (cI) + cSet->at(sL).assign(cI); + else + cSet->at(sL).nullify(); + + // ff + for (int t = sL - 1; t >= 0; --t) + lstmLayerCell(&xSet->at(t), Wx, Wr, b, &hSet->at(t + 1), + &cSet->at(t + 1), Wp, params, &zSet->at(t), &aSet->at(t), + &hSet->at(t), &cSet->at(t)); + + // bp + for (int t = 0; t < sL; ++t) { + const NDArray* dLdhh = dLdh ? &dLdhSet->at(t) : nullptr; + const NDArray* dLdhhL = (t == 0 && dLdhL) ? dLdhL : nullptr; + const NDArray* dLdccL = (t == 0 && dLdcL) ? dLdcL : nullptr; + lstmLayerCellBp( + &xSet->at(t), Wx, Wr, b, &hSet->at(t + 1), &cSet->at(t + 1), Wp, + dLdhh, dLdhhL, dLdccL, &zSet->at(t), &aSet->at(t), &cSet->at(t), + params, &dLdxSet->at(t), dLdWx, dLdWr, dLdh0, dLdc0, dLdb, dLdWp); + } + } else if (directionMode == 1) { // backward, seqLen is present + + for (int e = 0; e < bS; ++e) { + const int limit = seqLen->e(e); + + if (limit == 0) { + tensorAlongTimeBatchDims(*dLdx, dataFormat, 0, 0, e, e + 1) + .nullify(); // nullify for given e and whole time range + continue; } - else { // bidirectional mode, seqLen is present - - for (int e = 0; e < bS; ++e) { - - const int limit = seqLen->e(e); - - if(limit == 0) { - tensorAlongTimeBatchDims(*dLdx, dataFormat, 0,0, e,e+1).nullify(); // nullify for given e and whole time range - continue; - } - - if(hI) - h({limit,limit+1, e,e+1, 0,0}).assign(hISet->at(e)); - else - h({limit,limit+1, e,e+1, 0,0}).nullify(); - if(cI) - c({limit,limit+1, e,e+1, 0,0}).assign(cISet->at(e)); - else - c({limit,limit+1, e,e+1, 0,0}).nullify(); - - // ff - for (int t = limit - 1; t >= 0; --t) - lstmLayerCell(xSet->at(getBatchTimeTotalIndex(dataFormat, sL, bS, t, e)), Wx, Wr, b, hSet->at((t+1)*bS + e), cSet->at((t+1)*bS + e), Wp, params, - zSet->at(t*bS + e), aSet->at(t*bS + e), hSet->at(t*bS + e), cSet->at(t*bS + e)); - - // bp - for (int t = 0; t < limit; ++t) { - const auto ind = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e); - const NDArray* dLdhh = dLdh ? dLdhSet->at(ind) : nullptr; - const NDArray* dLdhhL = (t == 0 && dLdhL) ? dLdhLSet->at(e) : nullptr; - const NDArray* dLdccL = (t == 0 && dLdcL) ? dLdcLSet->at(e) : nullptr; - lstmLayerCellBp(xSet->at(ind), Wx, Wr, b, hSet->at((t+1)*bS + e), cSet->at((t+1)*bS + e), Wp, dLdhh, dLdhhL, dLdccL, - zSet->at(t*bS + e), aSet->at(t*bS + e), cSet->at(t*bS + e), params, dLdxSet->at(ind), dLdWx, dLdWr, - dLdh0Set->at(e), dLdc0Set->at(e), dLdb, dLdWp); - } - - if(limit != sL) - tensorAlongTimeBatchDims(*dLdx, dataFormat, limit,sL, e,e+1).nullify(); // nullify for given e and time range [limit, sL) - } + + if (hI) + hSet->at(sL * bS + e).assign(hISet->at(e)); + else + hSet->at(sL * bS + e).nullify(); + if (cI) + cSet->at(sL * bS + e).assign(cISet->at(e)); + else + cSet->at(sL * bS + e).nullify(); + + // ff + for (int t = sL - 1; t >= sL - limit; --t) + lstmLayerCell( + &xSet->at(getBatchTimeTotalIndex(dataFormat, sL, bS, t, e)), Wx, + Wr, b, &hSet->at((t + 1) * bS + e), &cSet->at((t + 1) * bS + e), + Wp, params, &zSet->at(t * bS + e), &aSet->at(t * bS + e), + &hSet->at(t * bS + e), &cSet->at(t * bS + e)); + + // bp + for (int t = sL - limit; t < sL; ++t) { + const auto ind = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e); + const NDArray* dLdhh = dLdh ? &dLdhSet->at(ind) : nullptr; + const NDArray* dLdhhL = + (t == sL - limit && dLdhL) ? &dLdhLSet->at(e) : nullptr; + const NDArray* dLdccL = + (t == sL - limit && dLdcL) ? &dLdcLSet->at(e) : nullptr; + lstmLayerCellBp( + &xSet->at(ind), Wx, Wr, b, &hSet->at((t + 1) * bS + e), + &cSet->at((t + 1) * bS + e), Wp, dLdhh, dLdhhL, dLdccL, + &zSet->at(t * bS + e), &aSet->at(t * bS + e), + &cSet->at(t * bS + e), params, &dLdxSet->at(ind), dLdWx, dLdWr, + &dLdh0Set->at(e), &dLdc0Set->at(e), dLdb, dLdWp); } - } - delete xSet; delete dLdxSet; delete hSet; delete cSet; delete aSet; delete zSet; - delete dLdhSet; delete dLdh0Set; delete dLdc0Set; delete dLdhLSet; delete dLdcLSet; delete hISet; delete cISet; + if (limit != sL) + tensorAlongTimeBatchDims(*dLdx, dataFormat, 0, sL - limit, e, e + 1) + .nullify(); // nullify for given e and time range [limit, sL) + } + } else { // bidirectional mode, seqLen is present - if(!hI) - delete dLdh0; - if(!cI) - delete dLdc0; -} + for (int e = 0; e < bS; ++e) { + const int limit = seqLen->e(e); + if (limit == 0) { + tensorAlongTimeBatchDims(*dLdx, dataFormat, 0, 0, e, e + 1) + .nullify(); // nullify for given e and whole time range + continue; + } -} -} -} + if (hI) + h({limit, limit + 1, e, e + 1, 0, 0}).assign(hISet->at(e)); + else + h({limit, limit + 1, e, e + 1, 0, 0}).nullify(); + if (cI) + c({limit, limit + 1, e, e + 1, 0, 0}).assign(cISet->at(e)); + else + c({limit, limit + 1, e, e + 1, 0, 0}).nullify(); + + // ff + for (int t = limit - 1; t >= 0; --t) + lstmLayerCell( + &xSet->at(getBatchTimeTotalIndex(dataFormat, sL, bS, t, e)), Wx, + Wr, b, &hSet->at((t + 1) * bS + e), &cSet->at((t + 1) * bS + e), + Wp, params, &zSet->at(t * bS + e), &aSet->at(t * bS + e), + &hSet->at(t * bS + e), &cSet->at(t * bS + e)); + + // bp + for (int t = 0; t < limit; ++t) { + const auto ind = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e); + const NDArray* dLdhh = dLdh ? &dLdhSet->at(ind) : nullptr; + const NDArray* dLdhhL = + (t == 0 && dLdhL) ? &dLdhLSet->at(e) : nullptr; + const NDArray* dLdccL = + (t == 0 && dLdcL) ? &dLdcLSet->at(e) : nullptr; + lstmLayerCellBp( + &xSet->at(ind), Wx, Wr, b, &hSet->at((t + 1) * bS + e), + &cSet->at((t + 1) * bS + e), Wp, dLdhh, dLdhhL, dLdccL, + &zSet->at(t * bS + e), &aSet->at(t * bS + e), + &cSet->at(t * bS + e), params, &dLdxSet->at(ind), dLdWx, dLdWr, + &dLdh0Set->at(e), &dLdc0Set->at(e), dLdb, dLdWp); + } + if (limit != sL) + tensorAlongTimeBatchDims(*dLdx, dataFormat, limit, sL, e, e + 1) + .nullify(); // nullify for given e and time range [limit, sL) + } + } + } + + delete xSet; + delete dLdxSet; + delete hSet; + delete cSet; + delete aSet; + delete zSet; + delete dLdhSet; + delete dLdh0Set; + delete dLdc0Set; + delete dLdhLSet; + delete dLdcLSet; + delete hISet; + delete cISet; + + if (!hI) delete dLdh0; + if (!cI) delete dLdc0; +} +} // namespace helpers +} // namespace ops +} // namespace sd ////////////////////////////////////////////////////////////////////////// -// void lstmLayerCellBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, -// const NDArray* b, NDArray* hI, NDArray* cI, const NDArray* Wp, const NDArray* dLdh, -// const std::vector& params, const bool firstIter, - -// NDArray* dhIdcI, NDArray* dhIdWx, NDArray* dcIdWx, NDArray* dhIdWr, NDArray* dcIdWr, -// NDArray* dhIdb, NDArray* dcIdb, NDArray* dhIdWp, NDArray* dcIdWp, -// NDArray* dLdx, NDArray* dLdWx, NDArray* dLdWr, NDArray* dLdhI, NDArray* dLdcI, NDArray* dLdb, NDArray* dLdWp) { - -// /************************ THIS IS NOT OPTIMAZED CODE ***********************************/ +// void lstmLayerCellBp(const NDArray* x, const NDArray* Wx, const NDArray* +// Wr, +// const NDArray* b, NDArray* hI, NDArray* +// cI, const NDArray* Wp, const NDArray* dLdh, const +// std::vector& params, const bool firstIter, + +// NDArray* dhIdcI, NDArray* dhIdWx, NDArray* dcIdWx, +// NDArray* dhIdWr, NDArray* dcIdWr, NDArray* dhIdb, +// NDArray* dcIdb, NDArray* dhIdWp, NDArray* dcIdWp, +// NDArray* dLdx, NDArray* dLdWx, NDArray* dLdWr, NDArray* +// dLdhI, NDArray* dLdcI, NDArray* dLdb, NDArray* dLdWp) { + +// /************************ THIS IS NOT OPTIMAZED CODE +// ***********************************/ // /** the objective is to provide math-readable code **/ // // equations (no peephole connections) @@ -1235,13 +1438,16 @@ void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, // // c = clip(f * cI + i * g) // // h = o * actH(c) -// // IDs for activations: 0=tanh, 1=relu, 2=sigmoid, 3=affine, 4=leaky relu, 5= thresholded relu, 6=scaled tanh, 7=hard sigmoid, 8=ELU, 9=softsign, 10=softplus +// // IDs for activations: 0=tanh, 1=relu, 2=sigmoid, 3=affine, 4=leaky +// relu, 5= thresholded relu, 6=scaled tanh, 7=hard sigmoid, 8=ELU, +// 9=softsign, 10=softplus // // params[0] - dataFormat, ignore // // params[1] - directionMode, ignore // // params[2] - cell clipping value, if it = 0 then do not apply clipping -// // params[3] - activation ID for input (i), forget (f) and output (o) gates +// // params[3] - activation ID for input (i), forget (f) and output (o) +// gates // // params[4] - alpha value for gates activation // // params[5] - beta value for gates activation @@ -1254,32 +1460,50 @@ void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, // // params[11] - beta value for output activation // // INPUTS: -// // x - current input at time t, [bS, nIn] or [nIn] if seqLen != nullptr +// // x - current input at time t, [bS, nIn] or [nIn] if seqLen != +// nullptr // // Wx - input weights [nIn, 4*nOut] // // Wr - recurrent weights [nOut, 4*nOut] // // b - biases [4*nOut], optional, may be nullptr -// // hI - (ht-1) previous (initial) output at time t-1, [bS, nOut] or [nOut] if seqLen != nullptr -// // cI - (ct-1) previous (initial) cell state at time t-1, [bS, nOut] or [nOut] if seqLen != nullptr +// // hI - (ht-1) previous (initial) output at time t-1, [bS, nOut] or +// [nOut] if seqLen != nullptr +// // cI - (ct-1) previous (initial) cell state at time t-1, [bS, nOut] or +// [nOut] if seqLen != nullptr // // Wp - peephole weights [3*nOut], optional, may be nullptr -// // dLdh - loss derivative with respect to h, [bS, nOut] or [nOut] if seqLen != nullptr -// // dhIdcI - derivative from previous time step, [bS, nOut] or [nOut] if seqLen != nullptr -// // dhIdWx - derivative from previous time step (Jacobian), [nIn, 4*nOut, bS, nOut] or [nIn, 4*nOut, nOut] if seqLen != nullptr -// // dcIdWx - derivative from previous time step (Jacobian), [nIn, 4*nOut, bS, nOut] or [nIn, 4*nOut, nOut] if seqLen != nullptr -// // dhIdWr - derivative from previous time step (Jacobian), [nOut, 4*nOut, bS, nOut] or [nOut, 4*nOut, nOut] if seqLen != nullptr -// // dcIdWr - derivative from previous time step (Jacobian), [nOut, 4*nOut, bS, nOut] or [nOut, 4*nOut, nOut] if seqLen != nullptr -// // dcIdWp - derivative from previous time step, [3*nOut], optional, may be nullptr -// // dhIdWp - derivative from previous time step, [3*nOut], optional, may be nullptr -// // dcIdb - derivative from previous time step, [4*nOut], optional, may be nullptr -// // dhIdb - derivative from previous time step, [4*nOut], optional, may be nullptr +// // dLdh - loss derivative with respect to h, [bS, nOut] or [nOut] if +// seqLen != nullptr +// // dhIdcI - derivative from previous time step, [bS, nOut] or [nOut] if +// seqLen != nullptr +// // dhIdWx - derivative from previous time step (Jacobian), [nIn, 4*nOut, +// bS, nOut] or [nIn, 4*nOut, nOut] if seqLen != nullptr +// // dcIdWx - derivative from previous time step (Jacobian), [nIn, 4*nOut, +// bS, nOut] or [nIn, 4*nOut, nOut] if seqLen != nullptr +// // dhIdWr - derivative from previous time step (Jacobian), [nOut, 4*nOut, +// bS, nOut] or [nOut, 4*nOut, nOut] if seqLen != nullptr +// // dcIdWr - derivative from previous time step (Jacobian), [nOut, 4*nOut, +// bS, nOut] or [nOut, 4*nOut, nOut] if seqLen != nullptr +// // dcIdWp - derivative from previous time step, [3*nOut], optional, may +// be nullptr +// // dhIdWp - derivative from previous time step, [3*nOut], optional, may +// be nullptr +// // dcIdb - derivative from previous time step, [4*nOut], optional, may +// be nullptr +// // dhIdb - derivative from previous time step, [4*nOut], optional, may +// be nullptr // // OUTPUTS: -// // dLdx - loss derivative with respect to x, [bS, nIn] or [nIn] if seqLen != nullptr +// // dLdx - loss derivative with respect to x, [bS, nIn] or [nIn] if +// seqLen != nullptr // // dLdWx - loss derivative with respect to Wx, [nIn, 4*nOut] // // dLdWr - loss derivative with respect to Wr, [nOut, 4*nOut] -// // dLdb - loss derivative with respect to b, optional, may be nullptr, [4*nOut] -// // dLdhI - loss derivative with respect to hI, optional may be nullptr, [bS, nOut] or [nOut] if seqLen != nullptr -// // dLdcI - loss derivative with respect to cI, optional may be nullptr, [bS, nOut] or [nOut] if seqLen != nullptr -// // dLdWp - loss derivative with respect to Wp, optional, may be nullptr, [3*nOut] +// // dLdb - loss derivative with respect to b, optional, may be nullptr, +// [4*nOut] +// // dLdhI - loss derivative with respect to hI, optional may be nullptr, +// [bS, nOut] or [nOut] if seqLen != nullptr +// // dLdcI - loss derivative with respect to cI, optional may be nullptr, +// [bS, nOut] or [nOut] if seqLen != nullptr +// // dLdWp - loss derivative with respect to Wp, optional, may be nullptr, +// [3*nOut] // // !!! dimension 4*nOut implies order i, f, g, o // // !!! dimension 3*nOut implies order i, f, o @@ -1302,52 +1526,100 @@ void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, // // dhIdcI = dhdc_from_previous_time_step -// // dLdx = iFactor×WxiT + fFactor×WxfT + eFactor×WxgT + oFactor×WxoT, [bS, nIn] -// // dLdhI = iFactor×WriT + fFactor×WrfT + eFactor×WrgT + oFactor×WroT, [bS, nOut] -// // dLdcI = factor*tempC + dLdhI * dhIdcI, dhIdcI=0 if firstIter, [bS, nOut] - -// // dcdWxi(dcIdWxi) = dcdzi*dzidWxi + tempIFE*dhIdWxi + tempC*dcIdWxi, dcIdWxi=dhIdWxi= 0 if firstIter, [nIn, nOut, bS, nOut] -// // dcdWxf(dcIdWxf) = dcdzf*dzfdWxf + tempIFE*dhIdWxf + tempC*dcIdWxf, dcIdWxf=dhIdWxf= 0 if firstIter, [nIn, nOut, bS, nOut] -// // dcdWxg(dcIdWxg) = dcdzg*dzgdWxg + tempIFE*dhIdWxg + tempC*dcIdWxg, dcIdWxg=dhIdWxg= 0 if firstIter, [nIn, nOut, bS, nOut] -// // dcdWxo(dcIdWxo) = 0 + tempIFE*dhIdWxo + tempC*dcIdWxo; dcIdWxo=dhIdWxo= 0 if firstIter, [nIn, nOut, bS, nOut] - -// // dhdWxi(dhIdWxi) = 0 + dhdc*dcdWxi + tempO*dhIdWxi, dhIdWxi= 0 if firstIter, [nIn, nOut, bS, nOut] -// // dhdWxf(dhIdWxf) = 0 + dhdc*dcdWxf + tempO*dhIdWxf, dhIdWxf= 0 if firstIter, [nIn, nOut, bS, nOut] -// // dhdWxg(dhIdWxg) = 0 + dhdc*dcdWxg + tempO*dhIdWxg, dhIdWxg= 0 if firstIter, [nIn, nOut, bS, nOut] -// // dhdWxo(dhIdWxo) = dhdzo*dzodWxo + dhdc*dcdWxo + tempO*dhIdWxo, dhIdWxo= 0 if firstIter, [nIn, nOut, bS, nOut] - -// // dhdWri(dhIdWri) = 0 + dhdc*dcdWri + tempO*dhIdWri, dhIdWri= 0 if firstIter, [nOut, nOut, bS, nOut] -// // dhdWrf(dhIdWrf) = 0 + dhdc*dcdWrf + tempO*dhIdWrf, dhIdWrf= 0 if firstIter, [nOut, nOut, bS, nOut] -// // dhdWrg(dhIdWrg) = 0 + dhdc*dcdWrg + tempO*dhIdWrg, dhIdWrg= 0 if firstIter, [nOut, nOut, bS, nOut] -// // dhdWro(dhIdWro) = dhdzo*dzodWro + dhdc*dcdWro + tempO*dhIdWro, dhIdWro= 0 if firstIter, [nOut, nOut, bS, nOut] - -// // dcdWri(dcIdWri) = dcdzi*dzidWri + tempIFE*dhIdWri + tempC*dcIdWri, dcIdWri=dhIdWri= 0 if firstIter, [nOut, nOut, bS, nOut] -// // dcdWrf(dcIdWrf) = dcdzf*dzfdWrf + tempIFE*dhIdWrf + tempC*dcIdWrf, dcIdWri=dhIdWri= 0 if firstIter, [nOut, nOut, bS, nOut] -// // dcdWrg(dcIdWrg) = dcdzg*dzgdWrg + tempIFE*dhIdWrg + tempC*dcIdWrg, dcIdWri=dhIdWri= 0 if firstIter, [nOut, nOut, bS, nOut] -// // dcdWro(dcIdWro) = 0 + tempIFE*dhIdWro + tempC*dcIdWro; dcIdWro=dhIdWro= 0 if firstIter, [nOut, nOut, bS, nOut] - -// // dcIdWpi = (dcdzi*cI + tempIFE*dhIdWpi + tempC*dcIdWpi).reduceALongFirstDim, dcIdWpi=dhIdWpi= 0 if firstIter, [bS, nOut]->reduce->[bS] -// // dcIdWpf = (dcdzf*cI + tempIFE*dhIdWpf + tempC*dcIdWpf).reduceALongFirstDim, dcIdWpf=dhIdWpf= 0 if firstIter, [bS, nOut]->reduce->[bS] -// // dcIdWpo = (0 + tempIFE*dhIdWpo + tempC*dcIdWpo).reduceALongFirstDim, dcIdWpo=dhIdWpo= 0 if firstIter, [bS, nOut]->reduce->[bS] - -// // dhdWpi(dhIdWpi) =( 0 + dhdc*dcdWpi + tempO*dhIdWpi).reduceALongFirstDim, dhIdWpi= 0 if firstIter, [bS, nOut]->reduce->[bS] -// // dhdWpf(dhIdWpf) =( 0 + dhdc*dcdWpf + tempO*dhIdWpf).reduceALongFirstDim, dhIdWpf= 0 if firstIter, [bS, nOut]->reduce->[bS] -// // dhdWpo(dhIdWpo) =(dhdzo*c + dhdc*dcdWpo + tempO*dhIdWpo).reduceALongFirstDim, dhIdWpo= 0 if firstIter, [bS, nOut]->reduce->[bS] - -// // dcdbi(dcIdbi) = (dcdzi + tempIFE*dhIdbi + tempC*dcIdbi).reduceALongFirstDim, dcIdbi=dhIdbi= 0 if firstIter, [bS, nOut]->reduce->[bS] -// // dcdbf(dcIdbf) = (dcdzf + tempIFE*dhIdbf + tempC*dcIdbf).reduceALongFirstDim, dcIdbf=dhIdbf= 0 if firstIter, [bS, nOut]->reduce->[bS] -// // dcdbg(dcIdbg) = (dcdzg + tempIFE*dhIdbg + tempC*dcIdbg).reduceALongFirstDim, dcIdbg=dhIdbg= 0 if firstIter, [bS, nOut]->reduce->[bS] -// // dcdbo(dcIdbo) = ( 0 + tempIFE*dhIdbo + tempC*dcIdbo).reduceALongFirstDim; dcIdbo=dhIdbo= 0 if firstIter, [bS, nOut]->reduce->[bS] - -// // dhdbi(dhIdbi) = ( 0 + dhdc*dcdbi + tempO*dhIdbi).reduceALongFirstDim, dhIdbi= 0 if firstIter, [bS, nOut]->reduce->[bS] -// // dhdbf(dhIdbf) = ( 0 + dhdc*dcdbf + tempO*dhIdbf).reduceALongFirstDim, dhIdbf= 0 if firstIter, [bS, nOut]->reduce->[bS] -// // dhdbg(dhIdbg) = ( 0 + dhdc*dcdbg + tempO*dhIdbg).reduceALongFirstDim, dhIdbg= 0 if firstIter, [bS, nOut]->reduce->[bS] -// // dhdbo(dhIdbo) = (dhdzo + dhdc*dcdbo + tempO*dhIdbo).reduceALongFirstDim, dhIdbo= 0 if firstIter, [bS, nOut]->reduce->[bS] +// // dLdx = iFactor×WxiT + fFactor×WxfT + eFactor×WxgT + oFactor×WxoT, +// [bS, nIn] +// // dLdhI = iFactor×WriT + fFactor×WrfT + eFactor×WrgT + oFactor×WroT, +// [bS, nOut] +// // dLdcI = factor*tempC + dLdhI * dhIdcI, dhIdcI=0 if firstIter, [bS, +// nOut] + +// // dcdWxi(dcIdWxi) = dcdzi*dzidWxi + tempIFE*dhIdWxi + tempC*dcIdWxi, +// dcIdWxi=dhIdWxi= 0 if firstIter, [nIn, nOut, bS, nOut] +// // dcdWxf(dcIdWxf) = dcdzf*dzfdWxf + tempIFE*dhIdWxf + tempC*dcIdWxf, +// dcIdWxf=dhIdWxf= 0 if firstIter, [nIn, nOut, bS, nOut] +// // dcdWxg(dcIdWxg) = dcdzg*dzgdWxg + tempIFE*dhIdWxg + tempC*dcIdWxg, +// dcIdWxg=dhIdWxg= 0 if firstIter, [nIn, nOut, bS, nOut] +// // dcdWxo(dcIdWxo) = 0 + tempIFE*dhIdWxo + tempC*dcIdWxo; +// dcIdWxo=dhIdWxo= 0 if firstIter, [nIn, nOut, bS, nOut] + +// // dhdWxi(dhIdWxi) = 0 + dhdc*dcdWxi + tempO*dhIdWxi, +// dhIdWxi= 0 if firstIter, [nIn, nOut, bS, nOut] +// // dhdWxf(dhIdWxf) = 0 + dhdc*dcdWxf + tempO*dhIdWxf, +// dhIdWxf= 0 if firstIter, [nIn, nOut, bS, nOut] +// // dhdWxg(dhIdWxg) = 0 + dhdc*dcdWxg + tempO*dhIdWxg, +// dhIdWxg= 0 if firstIter, [nIn, nOut, bS, nOut] +// // dhdWxo(dhIdWxo) = dhdzo*dzodWxo + dhdc*dcdWxo + tempO*dhIdWxo, +// dhIdWxo= 0 if firstIter, [nIn, nOut, bS, nOut] + +// // dhdWri(dhIdWri) = 0 + dhdc*dcdWri + tempO*dhIdWri, +// dhIdWri= 0 if firstIter, [nOut, nOut, bS, nOut] +// // dhdWrf(dhIdWrf) = 0 + dhdc*dcdWrf + tempO*dhIdWrf, +// dhIdWrf= 0 if firstIter, [nOut, nOut, bS, nOut] +// // dhdWrg(dhIdWrg) = 0 + dhdc*dcdWrg + tempO*dhIdWrg, +// dhIdWrg= 0 if firstIter, [nOut, nOut, bS, nOut] +// // dhdWro(dhIdWro) = dhdzo*dzodWro + dhdc*dcdWro + tempO*dhIdWro, +// dhIdWro= 0 if firstIter, [nOut, nOut, bS, nOut] + +// // dcdWri(dcIdWri) = dcdzi*dzidWri + tempIFE*dhIdWri + tempC*dcIdWri, +// dcIdWri=dhIdWri= 0 if firstIter, [nOut, nOut, bS, nOut] +// // dcdWrf(dcIdWrf) = dcdzf*dzfdWrf + tempIFE*dhIdWrf + tempC*dcIdWrf, +// dcIdWri=dhIdWri= 0 if firstIter, [nOut, nOut, bS, nOut] +// // dcdWrg(dcIdWrg) = dcdzg*dzgdWrg + tempIFE*dhIdWrg + tempC*dcIdWrg, +// dcIdWri=dhIdWri= 0 if firstIter, [nOut, nOut, bS, nOut] +// // dcdWro(dcIdWro) = 0 + tempIFE*dhIdWro + tempC*dcIdWro; +// dcIdWro=dhIdWro= 0 if firstIter, [nOut, nOut, bS, nOut] + +// // dcIdWpi = (dcdzi*cI + tempIFE*dhIdWpi + +// tempC*dcIdWpi).reduceALongFirstDim, dcIdWpi=dhIdWpi= 0 if firstIter, +// [bS, nOut]->reduce->[bS] +// // dcIdWpf = (dcdzf*cI + tempIFE*dhIdWpf + +// tempC*dcIdWpf).reduceALongFirstDim, dcIdWpf=dhIdWpf= 0 if firstIter, +// [bS, nOut]->reduce->[bS] +// // dcIdWpo = (0 + tempIFE*dhIdWpo + +// tempC*dcIdWpo).reduceALongFirstDim, dcIdWpo=dhIdWpo= 0 if firstIter, +// [bS, nOut]->reduce->[bS] + +// // dhdWpi(dhIdWpi) =( 0 + dhdc*dcdWpi + +// tempO*dhIdWpi).reduceALongFirstDim, dhIdWpi= 0 if firstIter, +// [bS, nOut]->reduce->[bS] +// // dhdWpf(dhIdWpf) =( 0 + dhdc*dcdWpf + +// tempO*dhIdWpf).reduceALongFirstDim, dhIdWpf= 0 if firstIter, +// [bS, nOut]->reduce->[bS] +// // dhdWpo(dhIdWpo) =(dhdzo*c + dhdc*dcdWpo + +// tempO*dhIdWpo).reduceALongFirstDim, dhIdWpo= 0 if firstIter, +// [bS, nOut]->reduce->[bS] + +// // dcdbi(dcIdbi) = (dcdzi + tempIFE*dhIdbi + +// tempC*dcIdbi).reduceALongFirstDim, dcIdbi=dhIdbi= 0 if +// firstIter, [bS, nOut]->reduce->[bS] +// // dcdbf(dcIdbf) = (dcdzf + tempIFE*dhIdbf + +// tempC*dcIdbf).reduceALongFirstDim, dcIdbf=dhIdbf= 0 if +// firstIter, [bS, nOut]->reduce->[bS] +// // dcdbg(dcIdbg) = (dcdzg + tempIFE*dhIdbg + +// tempC*dcIdbg).reduceALongFirstDim, dcIdbg=dhIdbg= 0 if +// firstIter, [bS, nOut]->reduce->[bS] +// // dcdbo(dcIdbo) = ( 0 + tempIFE*dhIdbo + +// tempC*dcIdbo).reduceALongFirstDim; dcIdbo=dhIdbo= 0 if +// firstIter, [bS, nOut]->reduce->[bS] + +// // dhdbi(dhIdbi) = ( 0 + dhdc*dcdbi + +// tempO*dhIdbi).reduceALongFirstDim, dhIdbi= 0 if firstIter, [bS, +// nOut]->reduce->[bS] +// // dhdbf(dhIdbf) = ( 0 + dhdc*dcdbf + +// tempO*dhIdbf).reduceALongFirstDim, dhIdbf= 0 if firstIter, [bS, +// nOut]->reduce->[bS] +// // dhdbg(dhIdbg) = ( 0 + dhdc*dcdbg + +// tempO*dhIdbg).reduceALongFirstDim, dhIdbg= 0 if firstIter, [bS, +// nOut]->reduce->[bS] +// // dhdbo(dhIdbo) = (dhdzo + dhdc*dcdbo + +// tempO*dhIdbo).reduceALongFirstDim, dhIdbo= 0 if firstIter, [bS, +// nOut]->reduce->[bS] // const Nd4jLong nOut = Wx->sizeAt(-1) / 4; -// NDArray *Wpi(nullptr), *Wpf(nullptr), *Wpo(nullptr), *dcIdWpi(nullptr), *dcIdWpf(nullptr), *dcIdWpo(nullptr), *dhIdWpi(nullptr), *dhIdWpf(nullptr), *dhIdWpo(nullptr); -// if(Wp) { +// NDArray *Wpi(nullptr), *Wpf(nullptr), *Wpo(nullptr), *dcIdWpi(nullptr), +// *dcIdWpf(nullptr), *dcIdWpo(nullptr), *dhIdWpi(nullptr), +// *dhIdWpf(nullptr), *dhIdWpo(nullptr); if(Wp) { // Wpi = new NDArray((*Wp)({0, nOut})); // Wpf = new NDArray((*Wp)({nOut, 2*nOut})); // Wpo = new NDArray((*Wp)({2*nOut, 3*nOut})); @@ -1359,8 +1631,9 @@ void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, // dcIdWpo = new NDArray((*dcIdWp)({2*nOut, 3*nOut})); // } -// NDArray *dcIdbi(nullptr), *dcIdbf(nullptr), *dcIdbg(nullptr), *dcIdbo(nullptr), *dhIdbi(nullptr), *dhIdbf(nullptr), *dhIdbg(nullptr), *dhIdbo(nullptr); -// if(b) { +// NDArray *dcIdbi(nullptr), *dcIdbf(nullptr), *dcIdbg(nullptr), +// *dcIdbo(nullptr), *dhIdbi(nullptr), *dhIdbf(nullptr), *dhIdbg(nullptr), +// *dhIdbo(nullptr); if(b) { // dhIdbi = new NDArray((*dhIdb)({0, nOut})); // dhIdbf = new NDArray((*dhIdb)({nOut, 2*nOut})); // dhIdbg = new NDArray((*dhIdb)({2*nOut, 3*nOut})); @@ -1371,53 +1644,91 @@ void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, // dcIdbo = new NDArray((*dcIdb)({3*nOut, 4*nOut})); // } -// NDArray dhIdWxi = x->rankOf() == 1 ? (*dhIdWx)({0,0, 0,nOut, 0,0}) : (*dhIdWx)({0,0, 0,nOut, 0,0, 0,0}); // [nIn, nOut, bS, nOut] or [nIn, nOut, nOut] if seqLen != nullptr -// NDArray dhIdWxf = x->rankOf() == 1 ? (*dhIdWx)({0,0, nOut,2*nOut, 0,0}) : (*dhIdWx)({0,0, nOut,2*nOut, 0,0, 0,0}); // [nIn, nOut, bS, nOut] or [nIn, nOut, nOut] if seqLen != nullptr -// NDArray dhIdWxg = x->rankOf() == 1 ? (*dhIdWx)({0,0, 2*nOut,3*nOut, 0,0}) : (*dhIdWx)({0,0, 2*nOut,3*nOut, 0,0, 0,0}); // [nIn, nOut, bS, nOut] or [nIn, nOut, nOut] if seqLen != nullptr -// NDArray dhIdWxo = x->rankOf() == 1 ? (*dhIdWx)({0,0, 3*nOut,4*nOut, 0,0}) : (*dhIdWx)({0,0, 3*nOut,4*nOut, 0,0, 0,0}); // [nIn, nOut, bS, nOut] or [nIn, nOut, nOut] if seqLen != nullptr - -// NDArray dhIdWri = x->rankOf() == 1 ? (*dhIdWr)({0,0, 0,nOut, 0,0}) : (*dhIdWr)({0,0, 0,nOut, 0,0, 0,0}); // [nOut, nOut, bS, nOut] or [nOut, nOut, nOut] if seqLen != nullptr -// NDArray dhIdWrf = x->rankOf() == 1 ? (*dhIdWr)({0,0, nOut,2*nOut, 0,0}) : (*dhIdWr)({0,0, nOut,2*nOut, 0,0, 0,0}); // [nOut, nOut, bS, nOut] or [nOut, nOut, nOut] if seqLen != nullptr -// NDArray dhIdWrg = x->rankOf() == 1 ? (*dhIdWr)({0,0, 2*nOut,3*nOut, 0,0}) : (*dhIdWr)({0,0, 2*nOut,3*nOut, 0,0, 0,0}); // [nOut, nOut, bS, nOut] or [nOut, nOut, nOut] if seqLen != nullptr -// NDArray dhIdWro = x->rankOf() == 1 ? (*dhIdWr)({0,0, 3*nOut,4*nOut, 0,0}) : (*dhIdWr)({0,0, 3*nOut,4*nOut, 0,0, 0,0}); // [nOut, nOut, bS, nOut] or [nOut, nOut, nOut] if seqLen != nullptr - -// NDArray dcIdWxi = x->rankOf() == 1 ? (*dcIdWx)({0,0, 0,nOut, 0,0}) : (*dcIdWx)({0,0, 0,nOut, 0,0, 0,0}); // [nIn, nOut, bS, nOut] or [nIn, nOut, nOut] if seqLen != nullptr -// NDArray dcIdWxf = x->rankOf() == 1 ? (*dcIdWx)({0,0, nOut,2*nOut, 0,0}) : (*dcIdWx)({0,0, nOut,2*nOut, 0,0, 0,0}); // [nIn, nOut, bS, nOut] or [nIn, nOut, nOut] if seqLen != nullptr -// NDArray dcIdWxg = x->rankOf() == 1 ? (*dcIdWx)({0,0, 2*nOut,3*nOut, 0,0}) : (*dcIdWx)({0,0, 2*nOut,3*nOut, 0,0, 0,0}); // [nIn, nOut, bS, nOut] or [nIn, nOut, nOut] if seqLen != nullptr -// NDArray dcIdWxo = x->rankOf() == 1 ? (*dcIdWx)({0,0, 3*nOut,4*nOut, 0,0}) : (*dcIdWx)({0,0, 3*nOut,4*nOut, 0,0, 0,0}); // [nIn, nOut, bS, nOut] or [nIn, nOut, nOut] if seqLen != nullptr - -// NDArray dcIdWri = x->rankOf() == 1 ? (*dcIdWr)({0,0, 0,nOut, 0,0}) : (*dcIdWr)({0,0, 0,nOut, 0,0, 0,0}); // [nOut, nOut, bS, nOut] or [nOut, nOut, nOut] if seqLen != nullptr -// NDArray dcIdWrf = x->rankOf() == 1 ? (*dcIdWr)({0,0, nOut,2*nOut, 0,0}) : (*dcIdWr)({0,0, nOut,2*nOut, 0,0, 0,0}); // [nOut, nOut, bS, nOut] or [nOut, nOut, nOut] if seqLen != nullptr -// NDArray dcIdWrg = x->rankOf() == 1 ? (*dcIdWr)({0,0, 2*nOut,3*nOut, 0,0}) : (*dcIdWr)({0,0, 2*nOut,3*nOut, 0,0, 0,0}); // [nOut, nOut, bS, nOut] or [nOut, nOut, nOut] if seqLen != nullptr -// NDArray dcIdWro = x->rankOf() == 1 ? (*dcIdWr)({0,0, 3*nOut,4*nOut, 0,0}) : (*dcIdWr)({0,0, 3*nOut,4*nOut, 0,0, 0,0}); // [nOut, nOut, bS, nOut] or [nOut, nOut, nOut] if seqLen != nullptr - -// NDArray WxiT = (*Wx)({0,0, 0, nOut}).transpose(); // [nOut, nIn] -// NDArray WxfT = (*Wx)({0,0, nOut, 2*nOut}).transpose(); // [nOut, nIn] -// NDArray WxgT = (*Wx)({0,0, 2*nOut,3*nOut}).transpose(); // [nOut, nIn] -// NDArray WxoT = (*Wx)({0,0, 3*nOut,4*nOut}).transpose(); // [nOut, nIn] - -// NDArray WriT = (*Wr)({0,0, 0, nOut}).transpose(); // [nOut, nOut] -// NDArray WrfT = (*Wr)({0,0, nOut, 2*nOut}).transpose(); // [nOut, nOut] -// NDArray WrgT = (*Wr)({0,0, 2*nOut,3*nOut}).transpose(); // [nOut, nOut] -// NDArray WroT = (*Wr)({0,0, 3*nOut,4*nOut}).transpose(); // [nOut, nOut] +// NDArray dhIdWxi = x->rankOf() == 1 ? (*dhIdWx)({0,0, 0,nOut, 0,0}) : +// (*dhIdWx)({0,0, 0,nOut, 0,0, 0,0}); // [nIn, nOut, bS, nOut] +// or [nIn, nOut, nOut] if seqLen != nullptr NDArray dhIdWxf = x->rankOf() +// == 1 ? (*dhIdWx)({0,0, nOut,2*nOut, 0,0}) : (*dhIdWx)({0,0, +// nOut,2*nOut, 0,0, 0,0}); // [nIn, nOut, bS, nOut] or [nIn, nOut, +// nOut] if seqLen != nullptr NDArray dhIdWxg = x->rankOf() == 1 ? +// (*dhIdWx)({0,0, 2*nOut,3*nOut, 0,0}) : (*dhIdWx)({0,0, 2*nOut,3*nOut, +// 0,0, 0,0}); // [nIn, nOut, bS, nOut] or [nIn, nOut, nOut] if seqLen != +// nullptr NDArray dhIdWxo = x->rankOf() == 1 ? (*dhIdWx)({0,0, +// 3*nOut,4*nOut, 0,0}) : (*dhIdWx)({0,0, 3*nOut,4*nOut, 0,0, 0,0}); // +// [nIn, nOut, bS, nOut] or [nIn, nOut, nOut] if seqLen != nullptr + +// NDArray dhIdWri = x->rankOf() == 1 ? (*dhIdWr)({0,0, 0,nOut, 0,0}) : +// (*dhIdWr)({0,0, 0,nOut, 0,0, 0,0}); // [nOut, nOut, bS, nOut] +// or [nOut, nOut, nOut] if seqLen != nullptr NDArray dhIdWrf = x->rankOf() +// == 1 ? (*dhIdWr)({0,0, nOut,2*nOut, 0,0}) : (*dhIdWr)({0,0, +// nOut,2*nOut, 0,0, 0,0}); // [nOut, nOut, bS, nOut] or [nOut, nOut, +// nOut] if seqLen != nullptr NDArray dhIdWrg = x->rankOf() == 1 ? +// (*dhIdWr)({0,0, 2*nOut,3*nOut, 0,0}) : (*dhIdWr)({0,0, 2*nOut,3*nOut, +// 0,0, 0,0}); // [nOut, nOut, bS, nOut] or [nOut, nOut, nOut] if seqLen +// != nullptr NDArray dhIdWro = x->rankOf() == 1 ? (*dhIdWr)({0,0, +// 3*nOut,4*nOut, 0,0}) : (*dhIdWr)({0,0, 3*nOut,4*nOut, 0,0, 0,0}); // +// [nOut, nOut, bS, nOut] or [nOut, nOut, nOut] if seqLen != nullptr + +// NDArray dcIdWxi = x->rankOf() == 1 ? (*dcIdWx)({0,0, 0,nOut, 0,0}) : +// (*dcIdWx)({0,0, 0,nOut, 0,0, 0,0}); // [nIn, nOut, bS, nOut] +// or [nIn, nOut, nOut] if seqLen != nullptr NDArray dcIdWxf = x->rankOf() +// == 1 ? (*dcIdWx)({0,0, nOut,2*nOut, 0,0}) : (*dcIdWx)({0,0, +// nOut,2*nOut, 0,0, 0,0}); // [nIn, nOut, bS, nOut] or [nIn, nOut, +// nOut] if seqLen != nullptr NDArray dcIdWxg = x->rankOf() == 1 ? +// (*dcIdWx)({0,0, 2*nOut,3*nOut, 0,0}) : (*dcIdWx)({0,0, 2*nOut,3*nOut, +// 0,0, 0,0}); // [nIn, nOut, bS, nOut] or [nIn, nOut, nOut] if seqLen != +// nullptr NDArray dcIdWxo = x->rankOf() == 1 ? (*dcIdWx)({0,0, +// 3*nOut,4*nOut, 0,0}) : (*dcIdWx)({0,0, 3*nOut,4*nOut, 0,0, 0,0}); // +// [nIn, nOut, bS, nOut] or [nIn, nOut, nOut] if seqLen != nullptr + +// NDArray dcIdWri = x->rankOf() == 1 ? (*dcIdWr)({0,0, 0,nOut, 0,0}) : +// (*dcIdWr)({0,0, 0,nOut, 0,0, 0,0}); // [nOut, nOut, bS, nOut] +// or [nOut, nOut, nOut] if seqLen != nullptr NDArray dcIdWrf = x->rankOf() +// == 1 ? (*dcIdWr)({0,0, nOut,2*nOut, 0,0}) : (*dcIdWr)({0,0, +// nOut,2*nOut, 0,0, 0,0}); // [nOut, nOut, bS, nOut] or [nOut, nOut, +// nOut] if seqLen != nullptr NDArray dcIdWrg = x->rankOf() == 1 ? +// (*dcIdWr)({0,0, 2*nOut,3*nOut, 0,0}) : (*dcIdWr)({0,0, 2*nOut,3*nOut, +// 0,0, 0,0}); // [nOut, nOut, bS, nOut] or [nOut, nOut, nOut] if seqLen +// != nullptr NDArray dcIdWro = x->rankOf() == 1 ? (*dcIdWr)({0,0, +// 3*nOut,4*nOut, 0,0}) : (*dcIdWr)({0,0, 3*nOut,4*nOut, 0,0, 0,0}); // +// [nOut, nOut, bS, nOut] or [nOut, nOut, nOut] if seqLen != nullptr + +// NDArray WxiT = (*Wx)({0,0, 0, nOut}).transpose(); // +// [nOut, nIn] NDArray WxfT = (*Wx)({0,0, nOut, 2*nOut}).transpose(); // +// [nOut, nIn] NDArray WxgT = (*Wx)({0,0, 2*nOut,3*nOut}).transpose(); // +// [nOut, nIn] NDArray WxoT = (*Wx)({0,0, 3*nOut,4*nOut}).transpose(); // +// [nOut, nIn] + +// NDArray WriT = (*Wr)({0,0, 0, nOut}).transpose(); // +// [nOut, nOut] NDArray WrfT = (*Wr)({0,0, nOut, 2*nOut}).transpose(); // +// [nOut, nOut] NDArray WrgT = (*Wr)({0,0, 2*nOut,3*nOut}).transpose(); // +// [nOut, nOut] NDArray WroT = (*Wr)({0,0, 3*nOut,4*nOut}).transpose(); // +// [nOut, nOut] // // ***** feed forward step ***** // -// auto z = mmul(*x, *Wx) + mmul(*hI, *Wr); // [bs, nIn] * [nIn, 4*nOut] + [bs, nOut] * [nOut, 4*nOut] = [bS, 4*nOut] -// //or [nIn] * [nIn, 4*nOut] + [nOut] * [nOut, 4*nOut] = [4*nOut] +// auto z = mmul(*x, *Wx) + mmul(*hI, *Wr); // [bs, nIn] * [nIn, +// 4*nOut] + [bs, nOut] * [nOut, 4*nOut] = [bS, 4*nOut] +// //or [nIn] * [nIn, 4*nOut] + +// [nOut] * [nOut, 4*nOut] = +// [4*nOut] // // add biases if they are given // if(b) -// z += *b; // broadcast [bS, 4*nOut] + [4*nOut] = [bS, 4*nOut](or[4*nOut]) +// z += *b; // broadcast [bS, 4*nOut] +// + [4*nOut] = [bS, 4*nOut](or[4*nOut]) -// auto zi = x->rankOf() == 1 ? z({0, nOut}) : z({0,0, 0, nOut}); // input gate i, [bS, nOut](or[nOut]) -// auto zf = x->rankOf() == 1 ? z({nOut, 2*nOut}) : z({0,0, nOut, 2*nOut}); // forget gate f, [bS, nOut](or[nOut]) -// auto zg = x->rankOf() == 1 ? z({2*nOut, 3*nOut}) : z({0,0, 2*nOut, 3*nOut}); // cell gate g, [bS, nOut](or[nOut]) -// auto zo = x->rankOf() == 1 ? z({3*nOut, 4*nOut}) : z({0,0, 3*nOut, 4*nOut}); // output gate o, [bS, nOut](or[nOut]) +// auto zi = x->rankOf() == 1 ? z({0, nOut}) : z({0,0, 0, nOut}); // +// input gate i, [bS, nOut](or[nOut]) auto zf = x->rankOf() == 1 ? z({nOut, +// 2*nOut}) : z({0,0, nOut, 2*nOut}); // forget gate f, [bS, +// nOut](or[nOut]) auto zg = x->rankOf() == 1 ? z({2*nOut, 3*nOut}) : +// z({0,0, 2*nOut, 3*nOut}); // cell gate g, [bS, nOut](or[nOut]) +// auto zo = x->rankOf() == 1 ? z({3*nOut, 4*nOut}) : z({0,0, 3*nOut, +// 4*nOut}); // output gate o, [bS, nOut](or[nOut]) // // peephole connections for input and forget gates // if(Wp) { -// zi += *cI * *Wpi; // broadcast: [bS, nOut] + [bS, nOut] * [nOut] = [bS, nOut](or[nOut]) -// zf += *cI * *Wpf; // broadcast: [bS, nOut] + [bS, nOut] * [nOut] = [bS, nOut](or[nOut]) +// zi += *cI * *Wpi; // broadcast: [bS, nOut] + [bS, nOut] * [nOut] +// = [bS, nOut](or[nOut]) zf += *cI * *Wpf; // broadcast: [bS, +// nOut] + [bS, nOut] * [nOut] = [bS, nOut](or[nOut]) // } // NDArray i = zi.ulike(); // [bS, nOut] @@ -1427,69 +1738,79 @@ void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, // applyActivation(zf, params[3], params[4], params[5], f); // applyActivation(zg, params[6], params[7], params[8], g); -// NDArray c = f * *cI + i * g; // [bS, nOut] * [bS, nOut] + [bS, nOut] * [bS, nOut] = [bS, nOut](or[nOut]) +// NDArray c = f * *cI + i * g; // [bS, nOut] * [bS, nOut] + [bS, +// nOut] * [bS, nOut] = [bS, nOut](or[nOut]) -// // if clipping value is non-zero then cell state is clipped by this value prior to the cell output activation -// if(params[2] != 0) +// // if clipping value is non-zero then cell state is clipped by this value +// prior to the cell output activation if(params[2] != 0) // c.applyScalar(scalar::LstmClip, params[2], c); // // peephole connections for output gate // if(Wp) -// zo += c * *Wpo; // broadcast: [bS, nOut] + [bS, nOut] * [nOut] = [bS, nOut](or[nOut]) +// zo += c * *Wpo; // broadcast: [bS, nOut] + [bS, nOut] * [nOut] = +// [bS, nOut](or[nOut]) // NDArray o = zo.ulike(); // [bS, nOut](or[nOut]) // applyActivation(zo, params[3], params[4], params[5], o); // // ***** back prop step ***** // -// NDArray dWxJacobian = mmulJacobianWeightsDeriv(nOut, *x); // [nIn, nOut, bS, nOut] (or [nIn, nOut, nOut]) -// NDArray dWrJacobian = mmulJacobianWeightsDeriv(nOut, *hI); // [nOut, nOut, bS, nOut] (or [nOut, nOut, nOut]) +// NDArray dWxJacobian = mmulJacobianWeightsDeriv(nOut, *x); // [nIn, +// nOut, bS, nOut] (or [nIn, nOut, nOut]) NDArray dWrJacobian = +// mmulJacobianWeightsDeriv(nOut, *hI); // [nOut, nOut, bS, nOut] (or +// [nOut, nOut, nOut]) // // dodzo -// NDArray dodzo = zo.ulike(); // [bS, nOut](or[nOut]) -// activationDeriv(zo, params[3], params[4], params[5], dodzo); +// NDArray dodzo = zo.ulike(); // [bS, nOut](or[nOut]) activationDeriv(zo, +// params[3], params[4], params[5], dodzo); // // dhdzo = dhdo*dodzo = actH(c)*dodzo -// NDArray dhdzo = zo.ulike(); // [bS, nOut](or[nOut]) -// applyActivation(c, params[9], params[10], params[11], dhdzo); // actH(c) +// NDArray dhdzo = zo.ulike(); // [bS, nOut](or[nOut]) applyActivation(c, +// params[9], params[10], params[11], dhdzo); // actH(c) // hI->assign(o*dhdzo); // dhdzo *= dodzo; // // dcdzi = dcdi*didzi -// NDArray dcdzi = zi.ulike(); // [bS, nOut](or[nOut]) -// activationDeriv(zi, params[3], params[4], params[5], dcdzi); // didzi -// dcdzi *= g; // dcdi = g*clipDeriv +// NDArray dcdzi = zi.ulike(); // [bS, nOut](or[nOut]) activationDeriv(zi, +// params[3], params[4], params[5], dcdzi); // didzi dcdzi *= g; +// // dcdi = g*clipDeriv // // dcdzf = dcdf*dfdzf -// NDArray dcdzf = zf.ulike(); // [bS, nOut](or[nOut]) -// activationDeriv(zf, params[3], params[4], params[5], dcdzf); // dfdzf -// dcdzf *= *cI; // dcdf = cI*clipDeriv +// NDArray dcdzf = zf.ulike(); // [bS, nOut](or[nOut]) activationDeriv(zf, +// params[3], params[4], params[5], dcdzf); // dfdzf dcdzf *= +// *cI; // dcdf = +// cI*clipDeriv // // dcdzg = dcde*dedzg -// NDArray dcdzg = zg.ulike(); // [bS, nOut](or[nOut]) -// activationDeriv(zg, params[6], params[7], params[8], dcdzg); // dedzg -// dcdzg *= i; // dcdf = i*clipDeriv +// NDArray dcdzg = zg.ulike(); // [bS, nOut](or[nOut]) activationDeriv(zg, +// params[6], params[7], params[8], dcdzg); // dedzg dcdzg *= i; +// // dcdf = i*clipDeriv // // dcdcI -// NDArray dcdcI = f.dup(); // [bS, nOut](or[nOut]) +// NDArray dcdcI = f.dup(); // [bS, nOut](or[nOut]) // // take into account possible deposit from clipping derivative // clipDeriv(params[2], c, dcdzi, dcdzf, dcdzg, dcdcI); // // dzodc -// NDArray* dzodc = Wpo; // [nOut], should be [bS, nOut] actually, however it will be broadcasted appropriately in future calcus (element-wise multiplication) +// NDArray* dzodc = Wpo; // [nOut], should be [bS, +// nOut] actually, however it will be broadcasted appropriately in future +// calcus (element-wise multiplication) // // dzidcI -// NDArray* dzidcI = Wpi; // [nOut], should be [bS, nOut] actually, however it will be broadcasted appropriately in future calcus (element-wise multiplication) +// NDArray* dzidcI = Wpi; // [nOut], should be [bS, +// nOut] actually, however it will be broadcasted appropriately in future +// calcus (element-wise multiplication) // // dzfdcI -// NDArray* dzfdcI = Wpf; // [nOut], should be [bS, nOut] actually, however it will be broadcasted appropriately in future calcus (element-wise multiplication) +// NDArray* dzfdcI = Wpf; // [nOut], should be [bS, +// nOut] actually, however it will be broadcasted appropriately in future +// calcus (element-wise multiplication) // // dhdc // NDArray dhdc = c.ulike(); -// activationDeriv(c, params[9], params[10], params[11], dhdc); // [bS, nOut] -// dhdc *= o; -// if(Wp) +// activationDeriv(c, params[9], params[10], params[11], dhdc); // [bS, +// nOut] dhdc *= o; if(Wp) // dhdc += dhdzo* *dzodc; // NDArray factor = *dLdh * dhdc; @@ -1504,25 +1825,33 @@ void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, // tempC += dcdzi*(*dzidcI) + dcdzf*(*dzfdcI); // // dLdx -// dLdx->assign(mmul(iFactor, WxiT) + mmul(fFactor, WxfT) + mmul(eFactor, WxgT) + mmul(oFactor, WxoT)); // [bS, nIn](or[nOut]) +// dLdx->assign(mmul(iFactor, WxiT) + mmul(fFactor, WxfT) + mmul(eFactor, +// WxgT) + mmul(oFactor, WxoT)); // [bS, nIn](or[nOut]) // // NDArray temp = c.ulike(); -// // applyActivation(c, params[9], params[10], params[11], temp); // actH(c) -// // dLdx->assign(mmul(o*(1-temp*temp)*g*i*(1-i), WxiT) + mmul(o*(1-temp*temp)*(*cI)*f*(1-f), WxfT) + mmul(o*(1-temp*temp)*i*g*(1-g), WxgT) + mmul(temp*o*(1-o), WxoT)); // [bS, nIn](or[nOut]) +// // applyActivation(c, params[9], params[10], params[11], temp); // +// actH(c) +// // dLdx->assign(mmul(o*(1-temp*temp)*g*i*(1-i), WxiT) + +// mmul(o*(1-temp*temp)*(*cI)*f*(1-f), WxfT) + +// mmul(o*(1-temp*temp)*i*g*(1-g), WxgT) + mmul(temp*o*(1-o), WxoT)); // +// [bS, nIn](or[nOut]) // // dLdhI // NDArray* dLdhII = dLdhI; // if(dLdcI && !dLdhI) // dLdhII = new NDArray(dLdcI->ulike()); -// dLdhII->assign(mmul(iFactor, WriT) + mmul(fFactor, WrfT) + mmul(eFactor, WrgT) + mmul(oFactor, WroT)); // [bS, nOut](or[nOut]) +// dLdhII->assign(mmul(iFactor, WriT) + mmul(fFactor, WrfT) + mmul(eFactor, +// WrgT) + mmul(oFactor, WroT)); // [bS, nOut](or[nOut]) // if(firstIter) { // // dLdcI // if(dLdcI) -// dLdcI->assign(factor*tempC); // [bS, nOut](or[nOut]) +// dLdcI->assign(factor*tempC); // [bS, +// nOut](or[nOut]) // // dcIdWxi(dcdWxi) -// dcIdWxi.assign(dcdzi*dWxJacobian); // broadcast [bS, nOut] * [nIn, nOut, bS, nOut] (or [nOut] * [nIn, nOut, nOut]); +// dcIdWxi.assign(dcdzi*dWxJacobian); // broadcast [bS, nOut] * +// [nIn, nOut, bS, nOut] (or [nOut] * [nIn, nOut, nOut]); // // dcIdWxf(dcdWxf) // dcIdWxf.assign(dcdzf*dWxJacobian); // // dcIdWxg(dcdWxg) @@ -1531,7 +1860,8 @@ void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, // dcIdWxo.nullify(); // // dhIdWxi -// dhIdWxi.assign(dhdc*dcIdWxi); // broadcast [bS, nOut] * [nIn, nOut, bS, nOut] (or [nOut] * [nIn, nOut, nOut]); +// dhIdWxi.assign(dhdc*dcIdWxi); // broadcast [bS, nOut] * +// [nIn, nOut, bS, nOut] (or [nOut] * [nIn, nOut, nOut]); // // dhIdWxf // dhIdWxf.assign(dhdc*dcIdWxf); // // dhIdWxg @@ -1540,7 +1870,8 @@ void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, // dhIdWxo.assign(dhdzo*dWxJacobian /*+ 0 */); // // dcIdWri(dcdWri) -// dcIdWri.assign(dcdzi*dWrJacobian); // broadcast [bS, nOut] * [nOut, nOut, bS, nOut](or [nOut] * [nIn, nOut, nOut]);; +// dcIdWri.assign(dcdzi*dWrJacobian); // broadcast [bS, nOut] * +// [nOut, nOut, bS, nOut](or [nOut] * [nIn, nOut, nOut]);; // // dcIdWrf(dcdWrf) // dcIdWrf.assign(dcdzf*dWrJacobian); // // dcIdWrg(dcdWrg) @@ -1549,7 +1880,8 @@ void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, // dcIdWro.nullify(); // // dhIdWri -// dhIdWri.assign(dhdc*dcIdWri); // broadcast [bS, nOut] * [nIn, nOut, bS, nOut] (or [nOut] * [nIn, nOut, nOut]); +// dhIdWri.assign(dhdc*dcIdWri); // broadcast [bS, nOut] * +// [nIn, nOut, bS, nOut] (or [nOut] * [nIn, nOut, nOut]); // // dhIdWrf // dhIdWrf.assign(dhdc*dcIdWrf); // // dhIdWrg @@ -1574,18 +1906,23 @@ void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, // } // else if(Wp) { // // dcIdWpi -// (dcdzi*(*cI)).reduceAlongDimension(reduce::Sum, *dcIdWpi, {0}); // [bS, nOut]->reduce->[nOut] +// (dcdzi*(*cI)).reduceAlongDimension(reduce::Sum, *dcIdWpi, {0}); +// // [bS, nOut]->reduce->[nOut] // // dcIdWpf -// (dcdzf*(*cI)).reduceAlongDimension(reduce::Sum, *dcIdWpf, {0}); // [bS, nOut]->reduce->[nOut] +// (dcdzf*(*cI)).reduceAlongDimension(reduce::Sum, *dcIdWpf, {0}); +// // [bS, nOut]->reduce->[nOut] // // dcIdWpo -// dcIdWpo->nullify(); // [nOut] +// dcIdWpo->nullify(); // [nOut] // // dhIdWpi -// (*dLdh*dhdc*(dcdzi*(*cI))).reduceAlongDimension(reduce::Sum, *dhIdWpi, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] +// (*dLdh*dhdc*(dcdzi*(*cI))).reduceAlongDimension(reduce::Sum, +// *dhIdWpi, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] // // dhIdWpf -// (*dLdh*dhdc*(dcdzf*(*cI))).reduceAlongDimension(reduce::Sum, *dhIdWpf, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] +// (*dLdh*dhdc*(dcdzf*(*cI))).reduceAlongDimension(reduce::Sum, +// *dhIdWpf, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] // // dhIdWpo -// (*dLdh*dhdzo*c /* +0*/).reduceAlongDimension(reduce::Sum, *dhIdWpo, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] +// (*dLdh*dhdzo*c /* +0*/).reduceAlongDimension(reduce::Sum, +// *dhIdWpo, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] // } // if(b && x->rankOf() == 1) { @@ -1610,36 +1947,45 @@ void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, // } // else if(b) { // // dcIdbi -// dcdzi.reduceAlongDimension(reduce::Sum, *dcIdbi, {0}); // [bS, nOut]->reduce->[nOut] +// dcdzi.reduceAlongDimension(reduce::Sum, *dcIdbi, {0}); // +// [bS, nOut]->reduce->[nOut] // // dcIdbf -// dcdzf.reduceAlongDimension(reduce::Sum, *dcIdbf, {0}); // [bS, nOut]->reduce->[nOut] +// dcdzf.reduceAlongDimension(reduce::Sum, *dcIdbf, {0}); // +// [bS, nOut]->reduce->[nOut] // // dcIdbg -// dcdzg.reduceAlongDimension(reduce::Sum, *dcIdbg, {0}); // [bS, nOut]->reduce->[nOut] +// dcdzg.reduceAlongDimension(reduce::Sum, *dcIdbg, {0}); // +// [bS, nOut]->reduce->[nOut] // // dcIdbo // dcIdbo->nullify(); // [nOut] // //dhIdbi -// (*dLdh*dhdc*dcdzi).reduceAlongDimension(reduce::Sum, *dhIdbi, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] +// (*dLdh*dhdc*dcdzi).reduceAlongDimension(reduce::Sum, *dhIdbi, +// {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] // //dhIdbf -// (*dLdh*dhdc*dcdzf).reduceAlongDimension(reduce::Sum, *dhIdbf, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] +// (*dLdh*dhdc*dcdzf).reduceAlongDimension(reduce::Sum, *dhIdbf, +// {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] // //dhIdbg -// (*dLdh*dhdc*(*dcIdbg)).reduceAlongDimension(reduce::Sum, *dhIdbg, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] +// (*dLdh*dhdc*(*dcIdbg)).reduceAlongDimension(reduce::Sum, *dhIdbg, +// {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] // //dhIdbo -// (*dLdh*dhdzo).reduceAlongDimension(reduce::Sum, *dhIdbo, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] +// (*dLdh*dhdzo).reduceAlongDimension(reduce::Sum, *dhIdbo, {0}); // +// ([bS, nOut] * [nOut])->reduce->[nOut] // } // } // else { -// NDArray tempIFE = mmul(dcdzi, WriT) + mmul(dcdzf, WrfT) + mmul(dcdzg, WrgT); -// NDArray tempO = mmul(dhdzo, WroT); +// NDArray tempIFE = mmul(dcdzi, WriT) + mmul(dcdzf, WrfT) + mmul(dcdzg, +// WrgT); NDArray tempO = mmul(dhdzo, WroT); // // dLdcI // if(dLdcI) // dLdcI->assign(factor*tempC + (*dLdhII)*(*dhIdcI)); // // dcIdWxi(dcdWxi) -// dcIdWxi.assign(dcdzi*dWxJacobian + tempIFE*dhIdWxi + tempC*dcIdWxi); // broadcast [bS, nOut] * [nIn, nOut, bS, nOut](or [nOut] * [nIn, nOut, nOut]); +// dcIdWxi.assign(dcdzi*dWxJacobian + tempIFE*dhIdWxi + tempC*dcIdWxi); +// // broadcast [bS, nOut] * [nIn, nOut, bS, nOut](or [nOut] * [nIn, +// nOut, nOut]); // // dcIdWxf(dcdWxf) // dcIdWxf.assign(dcdzf*dWxJacobian + tempIFE*dhIdWxf + tempC*dcIdWxf); // // dcIdWxg(dcdWxg) @@ -1648,7 +1994,8 @@ void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, // dcIdWxo.assign(/* 0 + */tempIFE * dhIdWxo + tempC*dcIdWxo); // // dhIdWxi -// dhIdWxi.assign(dhdc*dcIdWxi + tempO*dhIdWxi); // broadcast [bS, nOut] * [nIn, nOut, bS, nOut](or [nOut] * [nIn, nOut, nOut]); +// dhIdWxi.assign(dhdc*dcIdWxi + tempO*dhIdWxi); // broadcast [bS, nOut] +// * [nIn, nOut, bS, nOut](or [nOut] * [nIn, nOut, nOut]); // // dhIdWxf // dhIdWxf.assign(dhdc*dcIdWxf + tempO*dhIdWxf); // // dhIdWxg @@ -1657,7 +2004,9 @@ void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, // dhIdWxo.assign(dhdzo*dWxJacobian + dhdc*dcIdWxo + tempO*dhIdWxo); // // dcIdWri(dcdWri) -// dcIdWri.assign(dcdzi*dWrJacobian + tempIFE*dhIdWri + tempC*dcIdWri); // broadcast [bS, nOut] * [nOut, nOut, bS, nOut](or [nOut] * [nIn, nOut, nOut]); +// dcIdWri.assign(dcdzi*dWrJacobian + tempIFE*dhIdWri + tempC*dcIdWri); +// // broadcast [bS, nOut] * [nOut, nOut, bS, nOut](or [nOut] * [nIn, +// nOut, nOut]); // // dcIdWrf(dcdWrf) // dcIdWrf.assign(dcdzf*dWrJacobian + tempIFE*dhIdWrf + tempC*dcIdWrf); // // dcIdWrg(dcdWrg) @@ -1666,7 +2015,8 @@ void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, // dcIdWro.assign(/* 0 + */tempIFE * dhIdWro + tempC*dcIdWro); // // dhIdWri -// dhIdWri.assign(dhdc*dcIdWri + tempO*dhIdWri); // broadcast [bS, nOut] * [nOut, nOut, bS, nOut](or [nOut] * [nIn, nOut, nOut]); +// dhIdWri.assign(dhdc*dcIdWri + tempO*dhIdWri); // broadcast [bS, nOut] +// * [nOut, nOut, bS, nOut](or [nOut] * [nIn, nOut, nOut]); // // dhIdWrf // dhIdWrf.assign(dhdc*dcIdWrf + tempO*dhIdWrf); // // dhIdWrg @@ -1676,99 +2026,150 @@ void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, // if(Wp && x->rankOf() == 1) { // // dcIdWpi -// dcIdWpi->assign(dcdzi*(*cI) + tempIFE*(*dhIdWpi) + tempC*(*dcIdWpi)); // [nOut] * [nOut] +// dcIdWpi->assign(dcdzi*(*cI) + tempIFE*(*dhIdWpi) + +// tempC*(*dcIdWpi)); // [nOut] * [nOut] // // dcIdWpf -// dcIdWpf->assign(dcdzf*(*cI) + tempIFE*(*dhIdWpf) + tempC*(*dcIdWpf)); // [nOut] * [nOut] +// dcIdWpf->assign(dcdzf*(*cI) + tempIFE*(*dhIdWpf) + +// tempC*(*dcIdWpf)); // [nOut] * [nOut] // // dcIdWpo -// dcIdWpo->assign(/* 0 + */ tempIFE*(*dhIdWpo) + tempC*(*dcIdWpo)); // [nOut] * [nOut] +// dcIdWpo->assign(/* 0 + */ tempIFE*(*dhIdWpo) + +// tempC*(*dcIdWpo)); // [nOut] * [nOut] // // dhdWpi -// dhIdWpi->assign(dhdc*(*dcIdWpi) + tempO*(*dhIdWpi)); // [nOut] * [nOut] +// dhIdWpi->assign(dhdc*(*dcIdWpi) + tempO*(*dhIdWpi)); // [nOut] * +// [nOut] // // dhdWpf -// dhIdWpf->assign(dhdc*(*dcIdWpf) + tempO*(*dhIdWpf)); // [nOut] * [nOut] +// dhIdWpf->assign(dhdc*(*dcIdWpf) + tempO*(*dhIdWpf)); // [nOut] * +// [nOut] // // dhdWpo -// dhIdWpo->assign(dhdzo*c + dhdc*(*dcIdWpo) + tempO*(*dhIdWpo)); // [nOut] * [nOut] +// dhIdWpo->assign(dhdzo*c + dhdc*(*dcIdWpo) + tempO*(*dhIdWpo)); // +// [nOut] * [nOut] // } // else if(Wp) { // // dcIdWpi -// (dcdzi*(*cI) + tempIFE*(*dhIdWpi) + tempC*(*dcIdWpi)).reduceAlongDimension(reduce::Sum, *dcIdWpi, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] +// (dcdzi*(*cI) + tempIFE*(*dhIdWpi) + +// tempC*(*dcIdWpi)).reduceAlongDimension(reduce::Sum, *dcIdWpi, +// {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] // // dcIdWpf -// (dcdzf*(*cI) + tempIFE*(*dhIdWpf) + tempC*(*dcIdWpf)).reduceAlongDimension(reduce::Sum, *dcIdWpf, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] +// (dcdzf*(*cI) + tempIFE*(*dhIdWpf) + +// tempC*(*dcIdWpf)).reduceAlongDimension(reduce::Sum, *dcIdWpf, +// {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] // // dcIdWpo -// (/* 0 + */ tempIFE*(*dhIdWpo) + tempC*(*dcIdWpo)).reduceAlongDimension(reduce::Sum, *dcIdWpo, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] +// (/* 0 + */ tempIFE*(*dhIdWpo) + +// tempC*(*dcIdWpo)).reduceAlongDimension(reduce::Sum, *dcIdWpo, +// {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] // // dhIdWpi -// (dhdc*(*dcIdWpi) + tempO*(*dhIdWpi)).reduceAlongDimension(reduce::Sum, *dhIdWpi, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] +// (dhdc*(*dcIdWpi) + +// tempO*(*dhIdWpi)).reduceAlongDimension(reduce::Sum, *dhIdWpi, +// {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] // // dhIdWpf -// (dhdc*(*dcIdWpf) + tempO*(*dhIdWpf)).reduceAlongDimension(reduce::Sum, *dhIdWpf, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] +// (dhdc*(*dcIdWpf) + +// tempO*(*dhIdWpf)).reduceAlongDimension(reduce::Sum, *dhIdWpf, +// {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] // // dhIdWpo -// (dhdzo*c + dhdc*(*dcIdWpo) + tempO*(*dhIdWpo)).reduceAlongDimension(reduce::Sum, *dhIdWpo, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] +// (dhdzo*c + dhdc*(*dcIdWpo) + +// tempO*(*dhIdWpo)).reduceAlongDimension(reduce::Sum, *dhIdWpo, +// {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] // } // if(b && x->rankOf() == 1) { // // dcIdbi -// dcIdbi->assign(dcdzi + tempIFE*(*dhIdbi) + tempC*(*dcIdbi)); // [nOut] +// dcIdbi->assign(dcdzi + tempIFE*(*dhIdbi) + tempC*(*dcIdbi)); // +// [nOut] // // dcIdbf -// dcIdbf->assign(dcdzf + tempIFE*(*dhIdbf) + tempC*(*dcIdbf)); // [nOut] +// dcIdbf->assign(dcdzf + tempIFE*(*dhIdbf) + tempC*(*dcIdbf)); // +// [nOut] // // dcIdbg -// dcIdbg->assign(dcdzg + tempIFE*(*dhIdbg) + tempC*(*dcIdbg)); // [nOut] +// dcIdbg->assign(dcdzg + tempIFE*(*dhIdbg) + tempC*(*dcIdbg)); // +// [nOut] // // dcIdbo -// dcIdbo->assign(/*0+*/ tempIFE*(*dhIdbo) + tempC*(*dcIdbo)); // [nOut] +// dcIdbo->assign(/*0+*/ tempIFE*(*dhIdbo) + tempC*(*dcIdbo)); // +// [nOut] // //dhIdbi -// dhIdbi->assign(dhdc*(*dcIdbi) + tempO*(*dhIdbi)); // [nOut] +// dhIdbi->assign(dhdc*(*dcIdbi) + tempO*(*dhIdbi)); // +// [nOut] // //dhIdbf -// dhIdbf->assign(dhdc*(*dcIdbf) + tempO*(*dhIdbf)); // [nOut] +// dhIdbf->assign(dhdc*(*dcIdbf) + tempO*(*dhIdbf)); // +// [nOut] // //dhIdbg -// dhIdbg->assign(dhdc*(*dcIdbg) + tempO*(*dhIdbg)); // [nOut] +// dhIdbg->assign(dhdc*(*dcIdbg) + tempO*(*dhIdbg)); // +// [nOut] // //dhIdbo -// dhIdbo->assign(dhdzo + dhdc*(*dcIdbo) + tempO*(*dhIdbo)); // [nOut] +// dhIdbo->assign(dhdzo + dhdc*(*dcIdbo) + tempO*(*dhIdbo)); // +// [nOut] // } // else if(b) { // // dcIdbi -// (dcdzi + tempIFE*(*dhIdbi) + tempC*(*dcIdbi)).reduceAlongDimension(reduce::Sum, *dcIdbi, {0}); // [bS, nOut]->reduce->[nOut] +// (dcdzi + tempIFE*(*dhIdbi) + +// tempC*(*dcIdbi)).reduceAlongDimension(reduce::Sum, *dcIdbi, {0}); +// // [bS, nOut]->reduce->[nOut] // // dcIdbf -// (dcdzf + tempIFE*(*dhIdbf) + tempC*(*dcIdbf)).reduceAlongDimension(reduce::Sum, *dcIdbf, {0}); // [bS, nOut]->reduce->[nOut] +// (dcdzf + tempIFE*(*dhIdbf) + +// tempC*(*dcIdbf)).reduceAlongDimension(reduce::Sum, *dcIdbf, {0}); +// // [bS, nOut]->reduce->[nOut] // // dcIdbg -// (dcdzg + tempIFE*(*dhIdbg) + tempC*(*dcIdbg)).reduceAlongDimension(reduce::Sum, *dcIdbg, {0}); // [bS, nOut]->reduce->[nOut] +// (dcdzg + tempIFE*(*dhIdbg) + +// tempC*(*dcIdbg)).reduceAlongDimension(reduce::Sum, *dcIdbg, {0}); +// // [bS, nOut]->reduce->[nOut] // // dcIdbo -// (/*0+*/ tempIFE*(*dhIdbo) + tempC*(*dcIdbo)).reduceAlongDimension(reduce::Sum, *dcIdbo, {0}); // [bS, nOut]->reduce->[nOut] +// (/*0+*/ tempIFE*(*dhIdbo) + +// tempC*(*dcIdbo)).reduceAlongDimension(reduce::Sum, *dcIdbo, {0}); +// // [bS, nOut]->reduce->[nOut] // //dhIdbi -// (dhdc*(*dcIdbi) + tempO*(*dhIdbi)).reduceAlongDimension(reduce::Sum, *dhIdbi, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] +// (dhdc*(*dcIdbi) + +// tempO*(*dhIdbi)).reduceAlongDimension(reduce::Sum, *dhIdbi, {0}); +// // ([bS, nOut] * [nOut])->reduce->[nOut] // //dhIdbf -// (dhdc*(*dcIdbf) + tempO*(*dhIdbf)).reduceAlongDimension(reduce::Sum, *dhIdbf, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] +// (dhdc*(*dcIdbf) + +// tempO*(*dhIdbf)).reduceAlongDimension(reduce::Sum, *dhIdbf, {0}); +// // ([bS, nOut] * [nOut])->reduce->[nOut] // //dhIdbg -// (dhdc*(*dcIdbg) + tempO*(*dhIdbg)).reduceAlongDimension(reduce::Sum, *dhIdbg, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] +// (dhdc*(*dcIdbg) + +// tempO*(*dhIdbg)).reduceAlongDimension(reduce::Sum, *dhIdbg, {0}); +// // ([bS, nOut] * [nOut])->reduce->[nOut] // //dhIdbo -// (dhdzo + dhdc*(*dcIdbo) + tempO*(*dhIdbo)).reduceAlongDimension(reduce::Sum, *dhIdbo, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] +// (dhdzo + dhdc*(*dcIdbo) + +// tempO*(*dhIdbo)).reduceAlongDimension(reduce::Sum, *dhIdbo, {0}); +// // ([bS, nOut] * [nOut])->reduce->[nOut] // } // } -// const std::vector dimsToExclude = x->rankOf() == 1 ? std::vector({2}) : std::vector({2, 3}); +// const std::vector dimsToExclude = x->rankOf() == 1 ? +// std::vector({2}) : std::vector({2, 3}); // // dLdWxi, dLdWxf, dLdWxg, dLdWxo -// (*dLdh*(*dhIdWx)).reduceAlongDimension(reduce::Sum, *dLdWx, dimsToExclude); +// (*dLdh*(*dhIdWx)).reduceAlongDimension(reduce::Sum, *dLdWx, +// dimsToExclude); // // dLdWri, dLdWrf, dLdWrg, dLdWro -// (*dLdh*(*dhIdWr)).reduceAlongDimension(reduce::Sum, *dLdWr, dimsToExclude); +// (*dLdh*(*dhIdWr)).reduceAlongDimension(reduce::Sum, *dLdWr, +// dimsToExclude); // // dLdWpi, dLdWpf, dLdWpo // if(Wp) { // if(x->rankOf() == 1) { -// (*dLdWp)({0, nOut}).assign(*dLdh*(*dhIdWpi)); // [nOut] * [nOut] -// (*dLdWp)({nOut, 2*nOut}).assign(*dLdh*(*dhIdWpf)); // [nOut] * [nOut] -// (*dLdWp)({2*nOut, 3*nOut}).assign(*dLdh*(*dhIdWpo)); // [nOut] * [nOut] +// (*dLdWp)({0, nOut}).assign(*dLdh*(*dhIdWpi)); // [nOut] +// * [nOut] +// (*dLdWp)({nOut, 2*nOut}).assign(*dLdh*(*dhIdWpf)); // [nOut] +// * [nOut] +// (*dLdWp)({2*nOut, 3*nOut}).assign(*dLdh*(*dhIdWpo)); // [nOut] +// * [nOut] // } // else { // // NDArray temp1 = (*dLdWp)({0, nOut}); // // NDArray temp2 = (*dLdWp)({nOut, 2*nOut}); // // NDArray temp3 = (*dLdWp)({2*nOut, 3*nOut}); -// // dhIdWpi->reduceAlongDimension(reduce::Sum, temp1, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] -// // dhIdWpf->reduceAlongDimension(reduce::Sum, temp2, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] -// // dhIdWpo->reduceAlongDimension(reduce::Sum, temp3, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] +// // dhIdWpi->reduceAlongDimension(reduce::Sum, temp1, {0}); // +// ([bS, nOut] * [nOut])->reduce->[nOut] +// // dhIdWpf->reduceAlongDimension(reduce::Sum, temp2, {0}); // +// ([bS, nOut] * [nOut])->reduce->[nOut] +// // dhIdWpo->reduceAlongDimension(reduce::Sum, temp3, {0}); // +// ([bS, nOut] * [nOut])->reduce->[nOut] // (*dLdWp)({0, nOut}).assign(dhIdWpi); // (*dLdWp)({nOut, 2*nOut}).assign(dhIdWpf); // (*dLdWp)({2*nOut, 3*nOut}).assign(dhIdWpo); @@ -1778,20 +2179,28 @@ void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, // // dLdbi, dLdbf, dLdbg, dLdbo // if(b) { // if(x->rankOf() == 1) { -// (*dLdb)({0, nOut}).assign(*dLdh*(*dhIdbi)); // [nOut] * [nOut] -// (*dLdb)({nOut, 2*nOut}).assign(*dLdh*(*dhIdbf)); // [nOut] * [nOut] -// (*dLdb)({2*nOut, 3*nOut}).assign(*dLdh*(*dhIdbg)); // [nOut] * [nOut] -// (*dLdb)({3*nOut, 4*nOut}).assign(*dLdh*(*dhIdbo)); // [nOut] * [nOut] +// (*dLdb)({0, nOut}).assign(*dLdh*(*dhIdbi)); // [nOut] * +// [nOut] +// (*dLdb)({nOut, 2*nOut}).assign(*dLdh*(*dhIdbf)); // [nOut] * +// [nOut] +// (*dLdb)({2*nOut, 3*nOut}).assign(*dLdh*(*dhIdbg)); // [nOut] * +// [nOut] +// (*dLdb)({3*nOut, 4*nOut}).assign(*dLdh*(*dhIdbo)); // [nOut] * +// [nOut] // } // else { // // NDArray temp1 = (*dLdb)({0, nOut}); // // NDArray temp2 = (*dLdb)({nOut, 2*nOut}); // // NDArray temp3 = (*dLdb)({2*nOut, 3*nOut}); // // NDArray temp4 = (*dLdb)({3*nOut, 4*nOut}); -// // (*dLdh*(*dhIdbi)).reduceAlongDimension(reduce::Sum, temp1, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] -// // (*dLdh*(*dhIdbf)).reduceAlongDimension(reduce::Sum, temp2, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] -// // (*dLdh*(*dhIdbg)).reduceAlongDimension(reduce::Sum, temp3, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] -// // (*dLdh*(*dhIdbo)).reduceAlongDimension(reduce::Sum, temp3, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] +// // (*dLdh*(*dhIdbi)).reduceAlongDimension(reduce::Sum, temp1, +// {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] +// // (*dLdh*(*dhIdbf)).reduceAlongDimension(reduce::Sum, temp2, +// {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] +// // (*dLdh*(*dhIdbg)).reduceAlongDimension(reduce::Sum, temp3, +// {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] +// // (*dLdh*(*dhIdbo)).reduceAlongDimension(reduce::Sum, temp3, +// {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] // (*dLdb)({0, nOut}).assign(dhIdbi); // (*dLdb)({nOut, 2*nOut}).assign(dhIdbf); // (*dLdb)({2*nOut, 3*nOut}).assign(dhIdbg); @@ -1808,9 +2217,11 @@ void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, // if(dLdcI && !dLdhI) // delete dLdhII; // if(Wp) { -// delete Wpi; delete Wpf; delete Wpo; delete dcIdWpi; delete dcIdWpf; delete dcIdWpo; delete dhIdWpi; delete dhIdWpf; delete dhIdWpo; +// delete Wpi; delete Wpf; delete Wpo; delete dcIdWpi; delete dcIdWpf; +// delete dcIdWpo; delete dhIdWpi; delete dhIdWpf; delete dhIdWpo; // } // if(b) { -// delete dcIdbi; delete dcIdbf; delete dcIdbg; delete dcIdbo; delete dhIdbi; delete dhIdbf; delete dhIdbg; delete dhIdbo; +// delete dcIdbi; delete dcIdbf; delete dcIdbg; delete dcIdbo; delete +// dhIdbi; delete dhIdbf; delete dhIdbg; delete dhIdbo; // } // } diff --git a/libnd4j/include/ops/declarable/helpers/impl/multiUnique.cpp b/libnd4j/include/ops/declarable/helpers/impl/multiUnique.cpp index 5989f5246f67..710199b1ab3f 100644 --- a/libnd4j/include/ops/declarable/helpers/impl/multiUnique.cpp +++ b/libnd4j/include/ops/declarable/helpers/impl/multiUnique.cpp @@ -18,50 +18,53 @@ // @author sgazeos@gmail.com // -#include #include +#include namespace sd { namespace ops { namespace helpers { //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - bool multiUnique(std::vector const& inputList, sd::memory::Workspace *workspace) { - Nd4jLong length = 0; - std::vector reshaped(inputList.size()); - int pos = 0; - Nd4jLong axis = 0; - Context cContext(1); - for (auto array: inputList) { - if (array->dataType() != sd::DataType::INT32) - throw std::runtime_error("multiUnique: this op support INT32 data type only."); - - reshaped[pos] = array->reshape(array->ordering(), {-1}); - cContext.setInputArray(pos, &reshaped[pos]); +bool multiUnique(std::vector const& inputList, + sd::memory::Workspace* workspace) { + Nd4jLong length = 0; + std::vector reshaped(inputList.size()); + int pos = 0; + Nd4jLong axis = 0; + Context cContext(1); + for (auto array : inputList) { + if (array->dataType() != sd::DataType::INT32) + throw std::runtime_error( + "multiUnique: this op support INT32 data type only."); - length += array->lengthOf(); - pos++; - } - NDArray arrayFull('c', {length}, sd::DataType::INT32, inputList[0]->getContext()); - cContext.setOutputArray(0, &arrayFull); - cContext.setIArguments(&axis, 1); + reshaped[pos] = array->reshape(array->ordering(), {-1}); + cContext.setInputArray(pos, reshaped[pos]); - sd::ops::concat opConcat; - auto cResult = opConcat.execute(&cContext); - if (Status::OK() != cResult) - throw std::runtime_error("multiUnique: cannot execute concat op properly."); + length += array->lengthOf(); + pos++; + } + NDArray arrayFull('c', {length}, sd::DataType::INT32, + inputList[0]->getContext()); + cContext.setOutputArray(0, arrayFull); + cContext.setIArguments(&axis, 1); - sd::ops::unique opUnique; - auto uResult = opUnique.evaluate({&arrayFull}); - if (Status::OK() != uResult.status()) - throw std::runtime_error("multiUnique: cannot execute unique op properly."); + sd::ops::concat opConcat; + auto cResult = opConcat.execute(&cContext); + if (Status::OK() != cResult) + throw std::runtime_error("multiUnique: cannot execute concat op properly."); - auto uniqueVals = uResult.at(0); + sd::ops::unique opUnique; + auto uResult = opUnique.evaluate({&arrayFull}); + if (Status::OK() != uResult.status()) + throw std::runtime_error("multiUnique: cannot execute unique op properly."); - bool res = uniqueVals->lengthOf() == arrayFull.lengthOf(); + auto uniqueVals = uResult.at(0); - return res; - } + bool res = uniqueVals.lengthOf() == arrayFull.lengthOf(); + return res; } -} -} + +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/impl/rnn.cpp b/libnd4j/include/ops/declarable/helpers/impl/rnn.cpp index f910f07ed101..60bdcc20fc13 100644 --- a/libnd4j/include/ops/declarable/helpers/impl/rnn.cpp +++ b/libnd4j/include/ops/declarable/helpers/impl/rnn.cpp @@ -18,80 +18,84 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 16.04.2018 // -// function nnCell implements an Elman RNN cell: output = activation(Wx*x + bx + Wh*ht + bh) +// function nnCell implements an Elman RNN cell: output = activation(Wx*x + bx +// + Wh*ht + bh) -#include #include +#include - -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { - ////////////////////////////////////////////////////////////////////////// -void rnnCell(sd::LaunchContext * context, const NDArray* xt, const NDArray* Wx, const NDArray* Wh, const NDArray* b, const NDArray* hPrev, NDArray* ht) { - - // xt input [bS x iS] - // Wx input-to-hidden weights, [iS x nU] - // Wh hidden-to-hidden weights, [nU x nU] - // b biases, [2*nU]: {0, nU} are input-to-hidden biases and {nU, 2*nU} are hidden-to-hidden biases - // hPrev previous cell output [bS x nU], that is at previous time step t-1, in case of projection=false -> nU=nU!!! - - const int nU = hPrev->sizeAt(1); - - // ht is current cell output [bS x nU], that is at current time step t - ht->assign(mmul(*xt, *Wx) + (*b)({{0, nU}}) + mmul(*hPrev, *Wh) + (*b)({{nU, 2*nU}})); // [bS x nU] + [nU] + [bS x nU] + [nU] = [bS x nU] - ht->applyTransform(transform::Tanh, *ht); +void rnnCell(sd::LaunchContext* context, const NDArray* xt, const NDArray* Wx, + const NDArray* Wh, const NDArray* b, const NDArray* hPrev, + NDArray* ht) { + // xt input [bS x iS] + // Wx input-to-hidden weights, [iS x nU] + // Wh hidden-to-hidden weights, [nU x nU] + // b biases, [2*nU]: {0, nU} are input-to-hidden biases and {nU, 2*nU} are + // hidden-to-hidden biases hPrev previous cell output [bS x nU], that is at + // previous time step t-1, in case of projection=false -> nU=nU!!! + + const int nU = hPrev->sizeAt(1); + + // ht is current cell output [bS x nU], that is at current time step t + ht->assign( + mmul(*xt, *Wx) + (*b)({{0, nU}}) + mmul(*hPrev, *Wh) + + (*b)({{nU, + 2 * nU}})); // [bS x nU] + [nU] + [bS x nU] + [nU] = [bS x nU] + ht->applyTransform(transform::Tanh, *ht); } ////////////////////////////////////////////////////////////////////////// -void rnnTimeLoop(sd::LaunchContext * context, const NDArray* x, const NDArray* Wx, const NDArray* Wh, const NDArray* b, const NDArray* h0, const NDArray* maxTimeStep, NDArray* h, NDArray* hFinal) { - - // x input [time x bS x iS] - // Wx input-to-hidden weights, [iS x nU] - // Wh hidden-to-hidden weights, [nU x nU] - // b biases for, [2*nU] - - // h0 initial cell output (at time step = 0) [bS x nU] - // maxTimeStep vector [bS] containing integer values within [0,time), each element of this vector set max time step per each input in batch, this means there are no calculations for time >= maxTimeStep - - const int time = x->sizeAt(0); - const int bS = x->sizeAt(1); - - // at first time step - if(h0) - hFinal->assign(h0); - else - *hFinal = 0.; - - BlasHelper::getInstance(); // to avoid memory leak in pragma parallel loops - // loop through batch of inputs - for (int e = 0; e < bS; ++e) { - // loop through time steps - for (int t = 0; t < time; ++t) { - - int maxStep = maxTimeStep ? maxTimeStep->e(e) : time; - - auto xt = (*x)({t,t+1, e,e+1, 0,0}, true); - auto ht = (*h)({t,t+1, e,e+1, 0,0}, true); - auto hPrev = (*hFinal)({e,e+1, 0,0}, true); // previous state - - if(t >= maxStep) { - ht = 0.; - if(maxStep != 0) - hPrev.assign((*h)({maxStep-1,maxStep, e,e+1, 0,0})); - } - else { - helpers::rnnCell(context, &xt, Wx, Wh, b, &hPrev, &ht); - hPrev.assign(ht); - } - } +void rnnTimeLoop(sd::LaunchContext* context, const NDArray* x, + const NDArray* Wx, const NDArray* Wh, const NDArray* b, + const NDArray* h0, const NDArray* maxTimeStep, NDArray* h, + NDArray* hFinal) { + // x input [time x bS x iS] + // Wx input-to-hidden weights, [iS x nU] + // Wh hidden-to-hidden weights, [nU x nU] + // b biases for, [2*nU] + + // h0 initial cell output (at time step = 0) [bS x nU] + // maxTimeStep vector [bS] containing integer values within [0,time), each + // element of this vector set max time step per each input in batch, this + // means there are no calculations for time >= maxTimeStep + + const int time = x->sizeAt(0); + const int bS = x->sizeAt(1); + + // at first time step + if (h0) + hFinal->assign(h0); + else + *hFinal = 0.; + + BlasHelper::getInstance(); // to avoid memory leak in pragma parallel loops + // loop through batch of inputs + for (int e = 0; e < bS; ++e) { + // loop through time steps + for (int t = 0; t < time; ++t) { + int maxStep = maxTimeStep ? maxTimeStep->e(e) : time; + + auto xt = (*x)({t, t + 1, e, e + 1, 0, 0}, true); + auto ht = (*h)({t, t + 1, e, e + 1, 0, 0}, true); + auto hPrev = (*hFinal)({e, e + 1, 0, 0}, true); // previous state + + if (t >= maxStep) { + ht = 0.; + if (maxStep != 0) + hPrev.assign((*h)({maxStep - 1, maxStep, e, e + 1, 0, 0})); + } else { + helpers::rnnCell(context, &xt, Wx, Wh, b, &hPrev, &ht); + hPrev.assign(ht); + } } + } } - -} -} -} - +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/impl/sparse_to_dense.cpp b/libnd4j/include/ops/declarable/helpers/impl/sparse_to_dense.cpp index 4baa36d652bb..20d34e275eb9 100644 --- a/libnd4j/include/ops/declarable/helpers/impl/sparse_to_dense.cpp +++ b/libnd4j/include/ops/declarable/helpers/impl/sparse_to_dense.cpp @@ -18,107 +18,112 @@ // @author raver119@gmail.com // - -#include -#include #include +#include +#include namespace sd { - namespace ops { - namespace helpers { - template - static void fill_(const void *vvalues, const void *vindices, void *voutput, const Nd4jLong *zShapeInfo, uint8_t rank, uint64_t length) { - auto values = reinterpret_cast(vvalues); - auto indices = reinterpret_cast(vindices); - auto output = reinterpret_cast(voutput); - - int coords[MAX_RANK]; - uint64_t pos = 0; - for (uint64_t e = 0L; e < length; e++) { - // indices come in blocks - for (uint8_t p = 0; p < rank; p++) { - coords[p] = indices[pos++]; - } - - // fill output at given coords with sparse value - output[shape::getOffset(zShapeInfo, coords)] = values[e]; - } - - } - - void compat_sparse_to_dense(const NDArray &values, const NDArray &indices, NDArray *def, NDArray &output) { - // make sure host buffer is updated - values.syncToHost(); - indices.syncToHost(); +namespace ops { +namespace helpers { +template +static void fill_(const void *vvalues, const void *vindices, void *voutput, + const Nd4jLong *zShapeInfo, uint8_t rank, uint64_t length) { + auto values = reinterpret_cast(vvalues); + auto indices = reinterpret_cast(vindices); + auto output = reinterpret_cast(voutput); + + int coords[MAX_RANK]; + uint64_t pos = 0; + for (uint64_t e = 0L; e < length; e++) { + // indices come in blocks + for (uint8_t p = 0; p < rank; p++) { + coords[p] = indices[pos++]; + } + + // fill output at given coords with sparse value + output[shape::getOffset(zShapeInfo, coords)] = values[e]; + } +} + +void compat_sparse_to_dense(const NDArray &values, const NDArray &indices, + NDArray *def, NDArray &output) { + // make sure host buffer is updated + values.syncToHost(); + indices.syncToHost(); output.syncToHost(); - auto rank = output.rankOf(); - - if (output.isS()) { - // string case is not so trivial, since elements might, and probably will, have different sizes - auto numValues = values.lengthOf(); - auto numElements = output.lengthOf(); - - // first of all we calculate final buffer sizes and offsets - auto defaultLength = def == nullptr ? 0 : StringUtils::byteLength(*def); - auto valuesLength = StringUtils::byteLength(values); - auto bufferLength = defaultLength * (output.lengthOf() - numValues) + valuesLength; - auto headerLength = ShapeUtils::stringBufferHeaderRequirements(numElements); - - // now we make sure our output buffer can hold results - output.dataBuffer()->expand( bufferLength + headerLength); - - std::vector outputCoords(rank); - std::vector valueCoords(rank); - - auto offsetsBuffer = output.bufferAsT(); - auto dataBuffer = reinterpret_cast(offsetsBuffer + output.lengthOf()); - - offsetsBuffer[0] = 0; - - // getting initial value coords - for (int e = 0; e < rank; e++) - valueCoords[e] = indices.e(e); - - // write results individually - for (Nd4jLong e = 0; e < numElements; e++) { - auto vIndex = shape::coords2index(output.shapeInfo(), valueCoords.data()); - auto cLength = 0L; - std::string str; - if (vIndex == e) { - // we're writing down sparse value here - str = values.e(e); - } else { - // we're writing down default value if it exists - if (def != nullptr) - str = def->e(0); - else - str = ""; - } - - // TODO: make it unicode compliant - memcpy(&dataBuffer[offsetsBuffer[e]], str.c_str(), str.length()); - - // writing down offset - offsetsBuffer[e+1] = cLength; - } - } else { - // numeric case is trivial, since all elements have equal sizes - - // write out default values, if they are present - if (def != nullptr) { - output.assign(def); - - // make sure output is synced back - output.syncToHost(); - } - - // write out values - BUILD_DOUBLE_SELECTOR(values.dataType(), indices.dataType(), fill_, (values.buffer(), indices.buffer(), output.buffer(), output.shapeInfo(), rank, values.lengthOf()), LIBND4J_TYPES, INDEXING_TYPES); - } - // copy back to device, if there's any - output.syncToDevice(); - } - } + auto rank = output.rankOf(); + + if (output.isS()) { + // string case is not so trivial, since elements might, and probably will, + // have different sizes + auto numValues = values.lengthOf(); + auto numElements = output.lengthOf(); + + // first of all we calculate final buffer sizes and offsets + auto defaultLength = def == nullptr ? 0 : StringUtils::byteLength(*def); + auto valuesLength = StringUtils::byteLength(values); + auto bufferLength = + defaultLength * (output.lengthOf() - numValues) + valuesLength; + auto headerLength = ShapeUtils::stringBufferHeaderRequirements(numElements); + + // now we make sure our output buffer can hold results + output.dataBuffer()->expand(bufferLength + headerLength); + + std::vector outputCoords(rank); + std::vector valueCoords(rank); + + auto offsetsBuffer = output.bufferAsT(); + auto dataBuffer = + reinterpret_cast(offsetsBuffer + output.lengthOf()); + + offsetsBuffer[0] = 0; + + // getting initial value coords + for (int e = 0; e < rank; e++) valueCoords[e] = indices.e(e); + + // write results individually + for (Nd4jLong e = 0; e < numElements; e++) { + auto vIndex = shape::coords2index(output.shapeInfo(), valueCoords.data()); + auto cLength = 0L; + std::string str; + if (vIndex == e) { + // we're writing down sparse value here + str = values.e(e); + } else { + // we're writing down default value if it exists + if (def != nullptr) + str = def->e(0); + else + str = ""; + } + + // TODO: make it unicode compliant + memcpy(&dataBuffer[offsetsBuffer[e]], str.c_str(), str.length()); + + // writing down offset + offsetsBuffer[e + 1] = cLength; + } + } else { + // numeric case is trivial, since all elements have equal sizes + + // write out default values, if they are present + if (def != nullptr) { + output.assign(def); + + // make sure output is synced back + output.syncToHost(); } -} \ No newline at end of file + + // write out values + BUILD_DOUBLE_SELECTOR(values.dataType(), indices.dataType(), fill_, + (values.buffer(), indices.buffer(), output.buffer(), + output.shapeInfo(), rank, values.lengthOf()), + LIBND4J_TYPES, INDEXING_TYPES); + } + // copy back to device, if there's any + output.syncToDevice(); +} +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/impl/sqrtm.cpp b/libnd4j/include/ops/declarable/helpers/impl/sqrtm.cpp index b8cc6d8ac29b..55d3f97d3bb7 100644 --- a/libnd4j/include/ops/declarable/helpers/impl/sqrtm.cpp +++ b/libnd4j/include/ops/declarable/helpers/impl/sqrtm.cpp @@ -40,23 +40,27 @@ static void sqrtm_(const NDArray* x, NDArray* z) { auto listX = x->allTensorsAlongDimension({-2, -1}); auto listZ = z->allTensorsAlongDimension({-2, -1}); +#ifndef __CUDABLAS__ auto func = PRAGMA_THREADS_FOR { for (auto i = start; i < stop; i++) - ops::helpers::Sqrtm::calc(*listX.at(i), *listZ.at(i)); + ops::helpers::Sqrtm::calc(listX.at(i), listZ.at(i)); }; samediff::Threads::parallel_tad(func, 0, listX.size()); +#else + for (auto i = 0; i < listX.size(); i++) + ops::helpers::Sqrtm::calc(listX.at(i), listZ.at(i)); +#endif } } ////////////////////////////////////////////////////////////////////////// void sqrtm(sd::LaunchContext* context, const NDArray* x, NDArray* z) { - - x->syncToHost(); - BUILD_SINGLE_SELECTOR(z->dataType(), sqrtm_, (x, z), FLOAT_TYPES); - z->syncToDevice(); + x->syncToHost(); + BUILD_SINGLE_SELECTOR(z->dataType(), sqrtm_, (x, z), FLOAT_TYPES); + z->syncToDevice(); } diff --git a/libnd4j/include/ops/declarable/helpers/impl/unique.cpp b/libnd4j/include/ops/declarable/helpers/impl/unique.cpp index c67b713c282d..03d9dbc3e831 100644 --- a/libnd4j/include/ops/declarable/helpers/impl/unique.cpp +++ b/libnd4j/include/ops/declarable/helpers/impl/unique.cpp @@ -18,93 +18,98 @@ // @author sgazeos@gmail.com // -#include -#include #include +#include #include +#include namespace sd { namespace ops { namespace helpers { - template - static Nd4jLong uniqueCount_(NDArray* input) { - Nd4jLong count = 0; +template +static Nd4jLong uniqueCount_(NDArray* input) { + Nd4jLong count = 0; - std::vector values; + std::vector values; - for (Nd4jLong e = 0; e < input->lengthOf(); e++) { - T v = input->e(e); - if (std::find(values.begin(), values.end(), v) == values.end()) { - values.push_back(v); - count++; - } - } - return count; + for (Nd4jLong e = 0; e < input->lengthOf(); e++) { + T v = input->e(e); + if (std::find(values.begin(), values.end(), v) == values.end()) { + values.push_back(v); + count++; } + } + return count; +} + +Nd4jLong uniqueCount(sd::LaunchContext* context, NDArray* input) { + BUILD_SINGLE_SELECTOR(input->dataType(), return uniqueCount_, (input), + LIBND4J_TYPES); +} - Nd4jLong uniqueCount(sd::LaunchContext * context, NDArray* input) { - BUILD_SINGLE_SELECTOR(input->dataType(), return uniqueCount_, (input), LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template Nd4jLong uniqueCount_, (NDArray * input), + LIBND4J_TYPES); + +template +static Nd4jStatus uniqueFunctor_(NDArray* input, NDArray* values, + NDArray* indices, NDArray* counts) { + std::vector valuesVector; + MAP_IMPL indicesMap; + MAP_IMPL countsMap; + + for (Nd4jLong e = 0; e < input->lengthOf(); e++) { + T v = input->e(e); + if (std::find(valuesVector.begin(), valuesVector.end(), v) == + valuesVector.end()) { + valuesVector.push_back(v); + indicesMap[v] = e; + countsMap[v] = 1; + } else { + countsMap[v]++; } + } - BUILD_SINGLE_TEMPLATE(template Nd4jLong uniqueCount_, (NDArray* input), LIBND4J_TYPES); - - template - static Nd4jStatus uniqueFunctor_(NDArray* input, NDArray* values, NDArray* indices, NDArray* counts) { - - std::vector valuesVector; - MAP_IMPL indicesMap; - MAP_IMPL countsMap; - - for (Nd4jLong e = 0; e < input->lengthOf(); e++) { - T v = input->e(e); - if (std::find(valuesVector.begin(), valuesVector.end(), v) == valuesVector.end()) { - valuesVector.push_back(v); - indicesMap[v] = e; - countsMap[v] = 1; - } - else { - countsMap[v]++; - } - } - - auto func = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) { - values->p(e, static_cast(valuesVector[e])); - if (counts != nullptr) - counts->p(e, countsMap[valuesVector[e]]); - } - }; - samediff::Threads::parallel_for(func, 0, values->lengthOf()); - - for (Nd4jLong e = 0; e < indices->lengthOf(); e++) { - auto posI = std::find(valuesVector.begin(), valuesVector.end(), input->e(e)); - auto dist = std::distance(valuesVector.begin(), posI); - indices->p(e, Nd4jLong(dist));//indicesMap[(*input)(e)]; - } - - return Status::OK(); + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + values->p(e, static_cast(valuesVector[e])); + if (counts != nullptr) counts->p(e, countsMap[valuesVector[e]]); } + }; + samediff::Threads::parallel_for(func, 0, values->lengthOf()); - Nd4jStatus uniqueFunctor(sd::LaunchContext * context, NDArray* input, NDArray* values, NDArray* indices, NDArray* counts) { - input->syncToHost(); - values->syncToHost(); - indices->syncToHost(); + for (Nd4jLong e = 0; e < indices->lengthOf(); e++) { + auto posI = + std::find(valuesVector.begin(), valuesVector.end(), input->e(e)); + auto dist = std::distance(valuesVector.begin(), posI); + indices->p(e, Nd4jLong(dist)); // indicesMap[(*input)(e)]; + } - if (counts != nullptr) - counts->syncToHost(); + return Status::OK(); +} - BUILD_SINGLE_SELECTOR(input->dataType(), return uniqueFunctor_,(input, values, indices, counts), LIBND4J_TYPES); +Nd4jStatus uniqueFunctor(sd::LaunchContext* context, NDArray* input, + NDArray* values, NDArray* indices, NDArray* counts) { + input->syncToHost(); + values->syncToHost(); + indices->syncToHost(); - input->syncToDevice(); - values->syncToDevice(); - indices->syncToDevice(); + if (counts != nullptr) counts->syncToHost(); - if (counts != nullptr) - counts->syncToDevice(); - } + BUILD_SINGLE_SELECTOR(input->dataType(), return uniqueFunctor_, + (input, values, indices, counts), LIBND4J_TYPES); - BUILD_SINGLE_TEMPLATE(template Nd4jStatus uniqueFunctor_, (NDArray* input, NDArray* values, NDArray* indices, NDArray* counts), LIBND4J_TYPES); -} + input->syncToDevice(); + values->syncToDevice(); + indices->syncToDevice(); + + if (counts != nullptr) counts->syncToDevice(); } -} \ No newline at end of file + +BUILD_SINGLE_TEMPLATE(template Nd4jStatus uniqueFunctor_, + (NDArray * input, NDArray* values, NDArray* indices, + NDArray* counts), + LIBND4J_TYPES); +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/impl/where.cpp b/libnd4j/include/ops/declarable/helpers/impl/where.cpp index b2d758673237..ba360f919dc3 100644 --- a/libnd4j/include/ops/declarable/helpers/impl/where.cpp +++ b/libnd4j/include/ops/declarable/helpers/impl/where.cpp @@ -18,45 +18,49 @@ // Created by raver119 on 24/09/18. // -#include #include +#include namespace sd { - namespace ops { - namespace helpers { - template - static void __where(NDArray &condition, NDArray& output, memory::Workspace *workspace) { - NDArrayList list(0, true); - int cnt = 0; - - int idx[MAX_RANK]; - - for (Nd4jLong e = 0; e < condition.lengthOf(); e++) { - - shape::index2coordsCPU(0, e, condition.shapeInfo(), idx); - - auto offset = shape::getOffset(condition.shapeInfo(), idx); - - if (condition.e(offset)) { - auto array = NDArrayFactory::create_('c', {1, condition.rankOf()}, output.dataType(), output.getContext()); - for (int f = 0; f < condition.rankOf(); f++) - array->p(f, (T) idx[f]); - - list.write(cnt++, array); - } - } - - auto s = list.stack(); - output.assign(s); - delete s; - } - BUILD_SINGLE_TEMPLATE(template void __where,(NDArray &condition, NDArray& output, memory::Workspace *workspace), LIBND4J_TYPES); - - void _where(sd::LaunchContext * context, NDArray &condition, NDArray& output, memory::Workspace *workspace) { - condition.syncToHost(); - BUILD_SINGLE_SELECTOR(output.dataType(), __where, (condition, output, workspace), LIBND4J_TYPES); - output.syncToDevice(); - } - } +namespace ops { +namespace helpers { +template +static void __where(NDArray &condition, NDArray &output, + memory::Workspace *workspace) { + NDArrayList list(0, true); + int cnt = 0; + + int idx[MAX_RANK]; + + for (Nd4jLong e = 0; e < condition.lengthOf(); e++) { + shape::index2coordsCPU(0, e, condition.shapeInfo(), idx); + + auto offset = shape::getOffset(condition.shapeInfo(), idx); + + if (condition.e(offset)) { + auto array = NDArrayFactory::create( + 'c', {1, condition.rankOf()}, output.dataType(), output.getContext()); + for (int f = 0; f < condition.rankOf(); f++) array.p(f, (T)idx[f]); + + list.write(cnt++, array); } + } + + auto s = list.stack(); + output.assign(s); +} +BUILD_SINGLE_TEMPLATE(template void __where, + (NDArray & condition, NDArray &output, + memory::Workspace *workspace), + LIBND4J_TYPES); + +void _where(sd::LaunchContext *context, NDArray &condition, NDArray &output, + memory::Workspace *workspace) { + condition.syncToHost(); + BUILD_SINGLE_SELECTOR(output.dataType(), __where, + (condition, output, workspace), LIBND4J_TYPES); + output.syncToDevice(); } +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/ismax.h b/libnd4j/include/ops/declarable/helpers/ismax.h index 6052b362484b..32b322ea6aaf 100644 --- a/libnd4j/include/ops/declarable/helpers/ismax.h +++ b/libnd4j/include/ops/declarable/helpers/ismax.h @@ -24,15 +24,15 @@ #include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { - void ismax(sd::LaunchContext * context, const NDArray* input, NDArray* output, const std::vector& dimensions); +void ismax(sd::LaunchContext* context, const NDArray* input, NDArray* output, + const std::vector& dimensions); } -} -} - +} // namespace ops +} // namespace sd -#endif //LIBND4J_LSTM_H +#endif // LIBND4J_LSTM_H diff --git a/libnd4j/include/ops/declarable/helpers/knn.h b/libnd4j/include/ops/declarable/helpers/knn.h index 3a3494a121b3..695803e1fadc 100644 --- a/libnd4j/include/ops/declarable/helpers/knn.h +++ b/libnd4j/include/ops/declarable/helpers/knn.h @@ -24,11 +24,12 @@ #include namespace sd { - namespace ops { - namespace helpers { - void knn_mindistance(const NDArray &input, const NDArray &lowest, const NDArray &highest, NDArray &output); - } - } +namespace ops { +namespace helpers { +void knn_mindistance(const NDArray &input, const NDArray &lowest, + const NDArray &highest, NDArray &output); } +} // namespace ops +} // namespace sd -#endif //SAMEDIFF_KNN_H +#endif // SAMEDIFF_KNN_H diff --git a/libnd4j/include/ops/declarable/helpers/legacy_helpers.h b/libnd4j/include/ops/declarable/helpers/legacy_helpers.h index e3191425d936..84a9ff5d48b7 100644 --- a/libnd4j/include/ops/declarable/helpers/legacy_helpers.h +++ b/libnd4j/include/ops/declarable/helpers/legacy_helpers.h @@ -25,46 +25,77 @@ namespace sd { namespace ops { namespace helpers { /* - FORCEINLINE void reluDerivative(NDArray* theFirst, NDArray const* theSecond); - FORCEINLINE void reluDerivative(NDArray* theFirst, NDArray* theSecond, NDArray* theOutput); - FORCEINLINE void relu6Derivative(NDArray* theFirst, NDArray* theSecond, NDArray* theOutput); - FORCEINLINE void leakyReluDerivative(NDArray* theFirst, NDArray* theSecond, NDArray* theOutput); - FORCEINLINE void eluDerivative(NDArray* theFirst, NDArray* theSecond, NDArray* theOutput); - FORCEINLINE void seluDerivative(NDArray* theFirst, NDArray* theSecond, NDArray* theOutput); - FORCEINLINE void cubeDerivative(NDArray* theFirst, NDArray* theSecond, NDArray* theOutput); - FORCEINLINE void reduceNorm1(NDArray* theFirst, NDArray* theSecond, NDArray* theOutput); - FORCEINLINE void sxeLossWithLogits(NDArray* theFirst, NDArray* theSecond, NDArray* theOutput); - FORCEINLINE void tanhDerivative(NDArray* theFirst, NDArray* theSecond, NDArray* theOutput); - FORCEINLINE void hardTanhDerivative(NDArray* theFirst, NDArray* theSecond, NDArray* theOutput); - FORCEINLINE void rationalTanhDerivative(NDArray* theFirst, NDArray* theSecond, NDArray* theOutput); - FORCEINLINE void rectifiedTanhDerivative(NDArray* theFirst, NDArray* theSecond, NDArray* theOutput); - FORCEINLINE void softSignDerivative(NDArray* theFirst, NDArray* theSecond, NDArray* theOutput); - FORCEINLINE void softPlusDerivative(NDArray* theFirst, NDArray* theSecond, NDArray* theOutput); - FORCEINLINE void sigmoidDerivative(NDArray* theFirst, NDArray* theSecond, NDArray* theOutput); - FORCEINLINE void hardSigmoidDerivative(NDArray* theFirst, NDArray* theSecond, NDArray* theOutput); + FORCEINLINE void reluDerivative(NDArray* theFirst, NDArray const* + theSecond); FORCEINLINE void reluDerivative(NDArray* theFirst, NDArray* + theSecond, NDArray* theOutput); FORCEINLINE void relu6Derivative(NDArray* + theFirst, NDArray* theSecond, NDArray* theOutput); FORCEINLINE void + leakyReluDerivative(NDArray* theFirst, NDArray* theSecond, NDArray* + theOutput); FORCEINLINE void eluDerivative(NDArray* theFirst, NDArray* + theSecond, NDArray* theOutput); FORCEINLINE void seluDerivative(NDArray* + theFirst, NDArray* theSecond, NDArray* theOutput); FORCEINLINE void + cubeDerivative(NDArray* theFirst, NDArray* theSecond, NDArray* theOutput); + FORCEINLINE void reduceNorm1(NDArray* theFirst, NDArray* theSecond, NDArray* + theOutput); FORCEINLINE void sxeLossWithLogits(NDArray* theFirst, NDArray* + theSecond, NDArray* theOutput); FORCEINLINE void tanhDerivative(NDArray* + theFirst, NDArray* theSecond, NDArray* theOutput); FORCEINLINE void + hardTanhDerivative(NDArray* theFirst, NDArray* theSecond, NDArray* + theOutput); FORCEINLINE void rationalTanhDerivative(NDArray* theFirst, + NDArray* theSecond, NDArray* theOutput); FORCEINLINE void + rectifiedTanhDerivative(NDArray* theFirst, NDArray* theSecond, NDArray* + theOutput); FORCEINLINE void softSignDerivative(NDArray* theFirst, NDArray* + theSecond, NDArray* theOutput); FORCEINLINE void softPlusDerivative(NDArray* + theFirst, NDArray* theSecond, NDArray* theOutput); FORCEINLINE void + sigmoidDerivative(NDArray* theFirst, NDArray* theSecond, NDArray* theOutput); + FORCEINLINE void hardSigmoidDerivative(NDArray* theFirst, NDArray* + theSecond, NDArray* theOutput); */ - void reluDerivative(sd::LaunchContext * context, NDArray* theFirst, NDArray* theSecond); - void reluDerivative(sd::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput); - void relu6Derivative(sd::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput); - void leakyReluDerivative(sd::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput, const float alpha); - void eluDerivative(sd::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput, const float alpha); - void seluDerivative(sd::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput); - void cubeDerivative(sd::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput); - void reduceNorm1(sd::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput); - void sigmCrossEntropy(sd::LaunchContext * context, NDArray* logits, NDArray* lablels, NDArray* theOutput); - void sigmCrossEntropyGrad(sd::LaunchContext * context, NDArray* logits, NDArray* lablels, NDArray* theOutput); - void tanhDerivative(sd::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput); - void hardTanhDerivative(sd::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput); - void rationalTanhDerivative(sd::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput); - void rectifiedTanhDerivative(sd::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput); - void softSignDerivative(sd::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput); - void softPlusDerivative(sd::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput); - void sigmoidDerivative(sd::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput); - void hardSigmoidDerivative(sd::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput); - void logSumExp(sd::LaunchContext * context, NDArray* input, NDArray* axis, NDArray* output); - void logSumExp(sd::LaunchContext * context, NDArray* input, NDArray* subtrah, NDArray* axis, NDArray* output); - void weightedCrossEntropyWithLogitsFunctor(sd::LaunchContext * context, NDArray const* targets, NDArray const* input, NDArray const* weights, NDArray* output); -} -} -} +void reluDerivative(sd::LaunchContext* context, NDArray* theFirst, + NDArray* theSecond); +void reluDerivative(sd::LaunchContext* context, NDArray* theFirst, + NDArray* theSecond, NDArray* theOutput); +void relu6Derivative(sd::LaunchContext* context, NDArray* theFirst, + NDArray* theSecond, NDArray* theOutput); +void leakyReluDerivative(sd::LaunchContext* context, NDArray* theFirst, + NDArray* theSecond, NDArray* theOutput, + const float alpha); +void eluDerivative(sd::LaunchContext* context, NDArray* theFirst, + NDArray* theSecond, NDArray* theOutput, const float alpha); +void seluDerivative(sd::LaunchContext* context, NDArray* theFirst, + NDArray* theSecond, NDArray* theOutput); +void cubeDerivative(sd::LaunchContext* context, NDArray* theFirst, + NDArray* theSecond, NDArray* theOutput); +void reduceNorm1(sd::LaunchContext* context, NDArray* theFirst, + NDArray* theSecond, NDArray* theOutput); +void sigmCrossEntropy(sd::LaunchContext* context, NDArray* logits, + NDArray* lablels, NDArray* theOutput); +void sigmCrossEntropyGrad(sd::LaunchContext* context, NDArray* logits, + NDArray* lablels, NDArray* theOutput); +void tanhDerivative(sd::LaunchContext* context, NDArray* theFirst, + NDArray* theSecond, NDArray* theOutput); +void hardTanhDerivative(sd::LaunchContext* context, NDArray* theFirst, + NDArray* theSecond, NDArray* theOutput); +void rationalTanhDerivative(sd::LaunchContext* context, NDArray* theFirst, + NDArray* theSecond, NDArray* theOutput); +void rectifiedTanhDerivative(sd::LaunchContext* context, NDArray* theFirst, + NDArray* theSecond, NDArray* theOutput); +void softSignDerivative(sd::LaunchContext* context, NDArray* theFirst, + NDArray* theSecond, NDArray* theOutput); +void softPlusDerivative(sd::LaunchContext* context, NDArray* theFirst, + NDArray* theSecond, NDArray* theOutput); +void sigmoidDerivative(sd::LaunchContext* context, NDArray* theFirst, + NDArray* theSecond, NDArray* theOutput); +void hardSigmoidDerivative(sd::LaunchContext* context, NDArray* theFirst, + NDArray* theSecond, NDArray* theOutput); +void logSumExp(sd::LaunchContext* context, NDArray* input, NDArray* axis, + NDArray* output); +void logSumExp(sd::LaunchContext* context, NDArray* input, NDArray* subtrah, + NDArray* axis, NDArray* output); +void weightedCrossEntropyWithLogitsFunctor(sd::LaunchContext* context, + NDArray const* targets, + NDArray const* input, + NDArray const* weights, + NDArray* output); +} // namespace helpers +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/helpers/lgamma.h b/libnd4j/include/ops/declarable/helpers/lgamma.h index 184e33556777..507cad1587f3 100644 --- a/libnd4j/include/ops/declarable/helpers/lgamma.h +++ b/libnd4j/include/ops/declarable/helpers/lgamma.h @@ -23,18 +23,18 @@ #define __LIBND4J_L_GAMMA__H__ #include + #include "array/NDArray.h" namespace sd { namespace ops { namespace helpers { - // calculate the digamma function for each element for array - void lgamma(sd::LaunchContext* context, NDArray& x, NDArray& z); - -} -} -} +// calculate the digamma function for each element for array +void lgamma(sd::LaunchContext* context, NDArray& x, NDArray& z); +} // namespace helpers +} // namespace ops +} // namespace sd -#endif //__LIBND4J_L_GAMMA__H__ +#endif //__LIBND4J_L_GAMMA__H__ diff --git a/libnd4j/include/ops/declarable/helpers/listdiff.h b/libnd4j/include/ops/declarable/helpers/listdiff.h index 227eccac8f6a..4348fbbfd390 100644 --- a/libnd4j/include/ops/declarable/helpers/listdiff.h +++ b/libnd4j/include/ops/declarable/helpers/listdiff.h @@ -26,9 +26,11 @@ namespace sd { namespace ops { namespace helpers { - int listDiffFunctor(sd::LaunchContext * context, NDArray* values, NDArray* keep, NDArray* output1, NDArray* output2); - Nd4jLong listDiffCount(sd::LaunchContext * context, NDArray* values, NDArray* keep); -} -} -} +int listDiffFunctor(sd::LaunchContext* context, NDArray* values, NDArray* keep, + NDArray* output1, NDArray* output2); +Nd4jLong listDiffCount(sd::LaunchContext* context, NDArray* values, + NDArray* keep); +} // namespace helpers +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/helpers/lrn.h b/libnd4j/include/ops/declarable/helpers/lrn.h index f8c9089c7e53..24a06e30695b 100644 --- a/libnd4j/include/ops/declarable/helpers/lrn.h +++ b/libnd4j/include/ops/declarable/helpers/lrn.h @@ -19,19 +19,22 @@ // #ifndef __LRN_H_HELPERS__ #define __LRN_H_HELPERS__ -#include #include #include +#include namespace sd { namespace ops { namespace helpers { - int lrnFunctor(sd::graph::Context& block, NDArray* input, NDArray* output, int depth, double bias, double alpha, double beta); +int lrnFunctor(sd::graph::Context& block, NDArray* input, NDArray* output, + int depth, double bias, double alpha, double beta); - void lrnBP(sd::graph::Context& block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int depth, const float bias, const float alpha, const float beta); +void lrnBP(sd::graph::Context& block, const NDArray& input, + const NDArray& gradO, NDArray& gradI, const int depth, + const float bias, const float alpha, const float beta); -} -} -} +} // namespace helpers +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/helpers/lstm.h b/libnd4j/include/ops/declarable/helpers/lstm.h index 6eb5886f3f5f..eabd15aba331 100644 --- a/libnd4j/include/ops/declarable/helpers/lstm.h +++ b/libnd4j/include/ops/declarable/helpers/lstm.h @@ -23,57 +23,64 @@ #include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { - ////////////////////////////////////////////////////////////////////////// - static FORCEINLINE NDArray sigmoid(const NDArray& arr) { - return (const_cast(arr)).transform(transform::Sigmoid); - } - - static FORCEINLINE void sigmoidInplace(const NDArray& arr) { - (const_cast(arr)).applyTransform(transform::Sigmoid, const_cast(arr)); - } - ////////////////////////////////////////////////////////////////////////// - static FORCEINLINE NDArray tanh(const NDArray& arr) { - return (const_cast(arr)).transform(transform::Tanh); - } +static FORCEINLINE NDArray sigmoid(const NDArray& arr) { + return (const_cast(arr)).transform(transform::Sigmoid); +} - static FORCEINLINE void tanhInplace(const NDArray& arr) { - (const_cast(arr)).applyTransform(transform::Tanh, const_cast(arr)); - } +static FORCEINLINE void sigmoidInplace(const NDArray& arr) { + (const_cast(arr)) + .applyTransform(transform::Sigmoid, const_cast(arr)); +} ////////////////////////////////////////////////////////////////////////// - static NDArray timeSubset(const NDArray* arr, const int t, const int dataFormat){ - - if(dataFormat == 0) { // TNS: shape [timeLength, numExamples, inOutSize] - return (*arr)({t,t+1, 0,0, 0,0}); - } - else if(dataFormat == 1) { //NST: shape [numExamples, inOutSize, timeLength] - return (*arr)({0,0, 0,0, t,t+1}); - } - else { //NTS: shape [numExamples, timeLength, inOutSize] - TF "time_major=false" layout - return (*arr)({0,0, t,t+1, 0,0}); - } - } - - void lstmCell(sd::LaunchContext * context, const NDArray* xt, const NDArray* ht_1, const NDArray* ct_1, const NDArray* Wx, const NDArray* Wh, const NDArray* Wc, const NDArray* Wp, const NDArray* b, - NDArray* ht, NDArray* ct, const std::vector& params); - - void lstmTimeLoop(sd::LaunchContext * context, const NDArray* x, const NDArray* h0, const NDArray* c0, const NDArray* Wx, const NDArray* Wh, const NDArray* Wc, const NDArray* Wp, const NDArray* b, - NDArray* h, NDArray* c, const std::vector& params); - - void lstmBlockCell(const NDArray* xt, const NDArray* cLast, const NDArray* yLast, - const NDArray* W, const NDArray* Wci, const NDArray* Wcf, const NDArray* Wco, const NDArray* b, - NDArray* i, NDArray* c, NDArray* f, NDArray* o, NDArray* z, NDArray* h, NDArray* y, const std::vector& params); - - - -} +static FORCEINLINE NDArray tanh(const NDArray& arr) { + return (const_cast(arr)).transform(transform::Tanh); } + +static FORCEINLINE void tanhInplace(const NDArray& arr) { + (const_cast(arr)) + .applyTransform(transform::Tanh, const_cast(arr)); } +////////////////////////////////////////////////////////////////////////// +static NDArray timeSubset(const NDArray* arr, const int t, + const int dataFormat) { + if (dataFormat == 0) { // TNS: shape [timeLength, numExamples, inOutSize] + return (*arr)({t, t + 1, 0, 0, 0, 0}); + } else if (dataFormat == + 1) { // NST: shape [numExamples, inOutSize, timeLength] + return (*arr)({0, 0, 0, 0, t, t + 1}); + } else { // NTS: shape [numExamples, timeLength, inOutSize] - TF + // "time_major=false" layout + return (*arr)({0, 0, t, t + 1, 0, 0}); + } +} -#endif //LIBND4J_LSTM_H +void lstmCell(sd::LaunchContext* context, const NDArray* xt, + const NDArray* ht_1, const NDArray* ct_1, const NDArray* Wx, + const NDArray* Wh, const NDArray* Wc, const NDArray* Wp, + const NDArray* b, NDArray* ht, NDArray* ct, + const std::vector& params); + +void lstmTimeLoop(sd::LaunchContext* context, const NDArray* x, + const NDArray* h0, const NDArray* c0, const NDArray* Wx, + const NDArray* Wh, const NDArray* Wc, const NDArray* Wp, + const NDArray* b, NDArray* h, NDArray* c, + const std::vector& params); + +void lstmBlockCell(const NDArray* xt, const NDArray* cLast, + const NDArray* yLast, const NDArray* W, const NDArray* Wci, + const NDArray* Wcf, const NDArray* Wco, const NDArray* b, + NDArray* i, NDArray* c, NDArray* f, NDArray* o, NDArray* z, + NDArray* h, NDArray* y, const std::vector& params); + +} // namespace helpers +} // namespace ops +} // namespace sd + +#endif // LIBND4J_LSTM_H diff --git a/libnd4j/include/ops/declarable/helpers/lstmBlock.h b/libnd4j/include/ops/declarable/helpers/lstmBlock.h index 7df9bb795c0d..9487e5c6295c 100644 --- a/libnd4j/include/ops/declarable/helpers/lstmBlock.h +++ b/libnd4j/include/ops/declarable/helpers/lstmBlock.h @@ -23,23 +23,28 @@ #include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { - - void lstmBlockCell(const NDArray* xt, const NDArray* cLast, const NDArray* yLast, - const NDArray* W, const NDArray* Wci, const NDArray* Wcf, const NDArray* Wco, const NDArray* b, - NDArray* i, NDArray* c, NDArray* f, NDArray* o, NDArray* z, NDArray* h, NDArray* y, const std::vector& params); - - void lstmBlockTimeLoop(const NDArray* maxSeqLength, const NDArray* xSeq, const NDArray* c0, const NDArray* y0, - const NDArray* W, const NDArray* Wci, const NDArray* Wcf, const NDArray* Wco, const NDArray* b, - const NDArray* iSeq, const NDArray* cSeq, const NDArray* fSeq, const NDArray* oSeq, const NDArray* zSeq, - const NDArray* hSeq, const NDArray* ySeq, const std::vector& params, const int dataFormat); - -} -} -} - - -#endif //LIBND4J_LSTM_H +void lstmBlockCell(const NDArray* xt, const NDArray* cLast, + const NDArray* yLast, const NDArray* W, const NDArray* Wci, + const NDArray* Wcf, const NDArray* Wco, const NDArray* b, + NDArray* i, NDArray* c, NDArray* f, NDArray* o, NDArray* z, + NDArray* h, NDArray* y, const std::vector& params); + +void lstmBlockTimeLoop(const NDArray* maxSeqLength, const NDArray* xSeq, + const NDArray* c0, const NDArray* y0, const NDArray* W, + const NDArray* Wci, const NDArray* Wcf, + const NDArray* Wco, const NDArray* b, + const NDArray* iSeq, const NDArray* cSeq, + const NDArray* fSeq, const NDArray* oSeq, + const NDArray* zSeq, const NDArray* hSeq, + const NDArray* ySeq, const std::vector& params, + const int dataFormat); + +} // namespace helpers +} // namespace ops +} // namespace sd + +#endif // LIBND4J_LSTM_H diff --git a/libnd4j/include/ops/declarable/helpers/lstmLayer.h b/libnd4j/include/ops/declarable/helpers/lstmLayer.h index 29c434865ea7..6ba4f2e660b3 100644 --- a/libnd4j/include/ops/declarable/helpers/lstmLayer.h +++ b/libnd4j/include/ops/declarable/helpers/lstmLayer.h @@ -23,48 +23,59 @@ #include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { ////////////////////////////////////////////////////////////////////////// -void ND4J_EXPORT lstmLayerCell(const NDArray* x, const NDArray* Wx, const NDArray* Wr, - const NDArray* b, const NDArray* hI, const NDArray* cI, const NDArray* Wp, - const std::vector& params, - NDArray* h, NDArray* c); +void SD_EXPORT lstmLayerCell(const NDArray* x, const NDArray* Wx, + const NDArray* Wr, const NDArray* b, + const NDArray* hI, const NDArray* cI, + const NDArray* Wp, + const std::vector& params, NDArray* h, + NDArray* c); ////////////////////////////////////////////////////////////////////////// // this auxiliary ff should be running before backprop -void ND4J_EXPORT lstmLayerCell(const NDArray* x, const NDArray* Wx, const NDArray* Wr, - const NDArray* b, const NDArray* hI, const NDArray* cI, const NDArray* Wp, - const std::vector& params, - NDArray* z, NDArray* a, NDArray* h, NDArray* c); +void SD_EXPORT lstmLayerCell(const NDArray* x, const NDArray* Wx, + const NDArray* Wr, const NDArray* b, + const NDArray* hI, const NDArray* cI, + const NDArray* Wp, + const std::vector& params, NDArray* z, + NDArray* a, NDArray* h, NDArray* c); ////////////////////////////////////////////////////////////////////////// -void ND4J_EXPORT lstmLayerCellBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, const NDArray* b, const NDArray* hI, const NDArray* cI, const NDArray* Wp, - const NDArray* dLdh, const NDArray* dLdhL, const NDArray* dLdcL, - const NDArray* z, const NDArray* a, const NDArray* c, const std::vector& params, - NDArray* dLdx, NDArray* dLdWx, NDArray* dLdWr, NDArray* dLdhI, NDArray* dLdcI, NDArray* dLdb, NDArray* dLdWp); - +void SD_EXPORT lstmLayerCellBp(const NDArray* x, const NDArray* Wx, + const NDArray* Wr, const NDArray* b, + const NDArray* hI, const NDArray* cI, + const NDArray* Wp, const NDArray* dLdh, + const NDArray* dLdhL, const NDArray* dLdcL, + const NDArray* z, const NDArray* a, + const NDArray* c, + const std::vector& params, NDArray* dLdx, + NDArray* dLdWx, NDArray* dLdWr, NDArray* dLdhI, + NDArray* dLdcI, NDArray* dLdb, NDArray* dLdWp); ////////////////////////////////////////////////////////////////////////// -void ND4J_EXPORT lstmLayerTimeLoop(const NDArray* x, const NDArray* Wx, const NDArray* Wr, - const NDArray* b, const NDArray* seqLen, const NDArray* hI, const NDArray* cI, const NDArray* Wp, - const std::vector& params, - const bool forward, - NDArray* h, NDArray* hL, NDArray* cL); +void SD_EXPORT lstmLayerTimeLoop(const NDArray* x, const NDArray* Wx, + const NDArray* Wr, const NDArray* b, + const NDArray* seqLen, const NDArray* hI, + const NDArray* cI, const NDArray* Wp, + const std::vector& params, + const bool forward, NDArray* h, NDArray* hL, + NDArray* cL); ////////////////////////////////////////////////////////////////////////// -void ND4J_EXPORT lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, - const NDArray* b, const NDArray* seqLen, NDArray* hI, NDArray* cI, const NDArray* Wp, - const NDArray* dLdh, const NDArray* dLdhL, const NDArray* dLdcL, - const std::vector& params, const bool forward, - NDArray* dLdx, NDArray* dLdWx, NDArray* dLdWr, NDArray* dLdb, NDArray* dLdhI, NDArray* dLdcI, NDArray* dLdWp); - - -} -} -} +void SD_EXPORT lstmLayerTimeLoopBp( + const NDArray* x, const NDArray* Wx, const NDArray* Wr, const NDArray* b, + const NDArray* seqLen, NDArray* hI, NDArray* cI, const NDArray* Wp, + const NDArray* dLdh, const NDArray* dLdhL, const NDArray* dLdcL, + const std::vector& params, const bool forward, NDArray* dLdx, + NDArray* dLdWx, NDArray* dLdWr, NDArray* dLdb, NDArray* dLdhI, + NDArray* dLdcI, NDArray* dLdWp); +} // namespace helpers +} // namespace ops +} // namespace sd -#endif //LIBND4J_LSTMLAYER_H +#endif // LIBND4J_LSTMLAYER_H diff --git a/libnd4j/include/ops/declarable/helpers/lstsq.h b/libnd4j/include/ops/declarable/helpers/lstsq.h index 9cc6293837a0..de2322bb70c4 100644 --- a/libnd4j/include/ops/declarable/helpers/lstsq.h +++ b/libnd4j/include/ops/declarable/helpers/lstsq.h @@ -20,15 +20,19 @@ #ifndef __LST_SQ_SOLVE__H_HELPERS__ #define __LST_SQ_SOLVE__H_HELPERS__ -#include #include +#include namespace sd { namespace ops { namespace helpers { - int leastSquaresSolveFunctor(sd::LaunchContext* context, NDArray const* leftInput, NDArray const* rightInput, double const l2Regularizer, bool const fast, NDArray* output); -} -} +int leastSquaresSolveFunctor(sd::LaunchContext* context, + NDArray const* leftInput, + NDArray const* rightInput, + double const l2Regularizer, bool const fast, + NDArray* output); } +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/helpers/lup.h b/libnd4j/include/ops/declarable/helpers/lup.h index 1e58c2e3fe83..d3aa99919730 100644 --- a/libnd4j/include/ops/declarable/helpers/lup.h +++ b/libnd4j/include/ops/declarable/helpers/lup.h @@ -19,26 +19,32 @@ // #ifndef __LUP_H_HELPERS__ #define __LUP_H_HELPERS__ -#include #include +#include namespace sd { namespace ops { namespace helpers { - int lup(sd::LaunchContext* context, NDArray* input, NDArray* lu, NDArray* permutation); - void lu(sd::LaunchContext *context, NDArray* input, NDArray* output, NDArray* permutation); - int determinant(sd::LaunchContext * context, NDArray* input, NDArray* output); - int logAbsDeterminant(sd::LaunchContext * context, NDArray* input, NDArray* output); +int lup(sd::LaunchContext* context, NDArray* input, NDArray* lu, + NDArray* permutation); +void lu(sd::LaunchContext* context, NDArray* input, NDArray* output, + NDArray* permutation); +int determinant(sd::LaunchContext* context, NDArray* input, NDArray* output); +int logAbsDeterminant(sd::LaunchContext* context, NDArray* input, + NDArray* output); - int inverse(sd::LaunchContext * context, NDArray* input, NDArray* output); - int upperInverseFunctor(sd::LaunchContext* context, NDArray* input, NDArray* output); - int lowerInverseFunctor(sd::LaunchContext* context, NDArray* input, NDArray* output); +int inverse(sd::LaunchContext* context, NDArray* input, NDArray* output); +int upperInverseFunctor(sd::LaunchContext* context, NDArray* input, + NDArray* output); +int lowerInverseFunctor(sd::LaunchContext* context, NDArray* input, + NDArray* output); - bool checkCholeskyInput(sd::LaunchContext * context, NDArray const* input); - int cholesky(sd::LaunchContext * context, NDArray* input, NDArray* output, bool inplace = false); - int logdetFunctor(sd::LaunchContext * context, NDArray* input, NDArray* output); -} -} -} +bool checkCholeskyInput(sd::LaunchContext* context, NDArray const* input); +int cholesky(sd::LaunchContext* context, NDArray* input, NDArray* output, + bool inplace = false); +int logdetFunctor(sd::LaunchContext* context, NDArray* input, NDArray* output); +} // namespace helpers +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/helpers/matmul.h b/libnd4j/include/ops/declarable/helpers/matmul.h index 2cce162d81b3..5051cc24f163 100644 --- a/libnd4j/include/ops/declarable/helpers/matmul.h +++ b/libnd4j/include/ops/declarable/helpers/matmul.h @@ -24,12 +24,13 @@ #include namespace sd { - namespace ops { - namespace helpers { +namespace ops { +namespace helpers { - void _matmul(sd::LaunchContext * context, NDArray *A, NDArray *B, NDArray *C, int transA, int transB, double alpha = 1., double beta = 0.); - } - } +void _matmul(sd::LaunchContext *context, NDArray *A, NDArray *B, NDArray *C, + int transA, int transB, double alpha = 1., double beta = 0.); } +} // namespace ops +} // namespace sd -#endif //LIBND4J_MATMUL_H +#endif // LIBND4J_MATMUL_H diff --git a/libnd4j/include/ops/declarable/helpers/matrixSetDiag.h b/libnd4j/include/ops/declarable/helpers/matrixSetDiag.h index 332c3134bac1..caf708675971 100644 --- a/libnd4j/include/ops/declarable/helpers/matrixSetDiag.h +++ b/libnd4j/include/ops/declarable/helpers/matrixSetDiag.h @@ -22,17 +22,19 @@ #define LIBND4J_MATRIXSETDIAG_H #include + #include "array/NDArray.h" namespace sd { namespace ops { namespace helpers { - void matrixSetDiag(sd::LaunchContext* context, const NDArray& input, const NDArray& diagonal, NDArray& output, const bool zeroPad); +void matrixSetDiag(sd::LaunchContext* context, const NDArray& input, + const NDArray& diagonal, NDArray& output, + const bool zeroPad); } -} -} - +} // namespace ops +} // namespace sd -#endif //LIBND4J_MATRIXSETDIAG_H +#endif // LIBND4J_MATRIXSETDIAG_H diff --git a/libnd4j/include/ops/declarable/helpers/matrix_band.h b/libnd4j/include/ops/declarable/helpers/matrix_band.h index f997e4d56573..8aadd8bd47a1 100644 --- a/libnd4j/include/ops/declarable/helpers/matrix_band.h +++ b/libnd4j/include/ops/declarable/helpers/matrix_band.h @@ -19,17 +19,17 @@ // #ifndef __MATRIX_BAND_H_HELPERS__ #define __MATRIX_BAND_H_HELPERS__ -#include #include +#include namespace sd { namespace ops { namespace helpers { - void matrixBandPart(sd::LaunchContext * context, NDArray* input, NDArray* output, Nd4jLong lowerBand, Nd4jLong upperBand); - +void matrixBandPart(sd::LaunchContext* context, NDArray* input, NDArray* output, + Nd4jLong lowerBand, Nd4jLong upperBand); } -} -} +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/helpers/matrix_diag_part.h b/libnd4j/include/ops/declarable/helpers/matrix_diag_part.h index fd25c636c92a..01f3b40d4d02 100644 --- a/libnd4j/include/ops/declarable/helpers/matrix_diag_part.h +++ b/libnd4j/include/ops/declarable/helpers/matrix_diag_part.h @@ -19,16 +19,17 @@ // #ifndef __MATRIX_DIAG_PART_HELPERS__ #define __MATRIX_DIAG_PART_HELPERS__ -#include #include +#include namespace sd { namespace ops { namespace helpers { - int matrixDiagPart(sd::LaunchContext * context, NDArray const* input, NDArray* output); +int matrixDiagPart(sd::LaunchContext* context, NDArray const* input, + NDArray* output); } -} -} +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/helpers/max_pooling.h b/libnd4j/include/ops/declarable/helpers/max_pooling.h index a3750798b752..f713e27545be 100644 --- a/libnd4j/include/ops/declarable/helpers/max_pooling.h +++ b/libnd4j/include/ops/declarable/helpers/max_pooling.h @@ -19,16 +19,18 @@ // #ifndef __MAX_POOLING_HELPERS__ #define __MAX_POOLING_HELPERS__ -#include #include #include +#include namespace sd { namespace ops { namespace helpers { - void maxPoolingFunctor(sd::LaunchContext * context, sd::graph::Context& block, NDArray* input, NDArray* values, std::vector const& params, NDArray* indices); -} -} +void maxPoolingFunctor(sd::LaunchContext* context, sd::graph::Context& block, + NDArray* input, NDArray* values, + std::vector const& params, NDArray* indices); } +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/helpers/meshgrid.h b/libnd4j/include/ops/declarable/helpers/meshgrid.h index e6c38502940a..66723f2ed03c 100644 --- a/libnd4j/include/ops/declarable/helpers/meshgrid.h +++ b/libnd4j/include/ops/declarable/helpers/meshgrid.h @@ -23,15 +23,15 @@ #include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { - void meshgrid(sd::LaunchContext * context, const std::vector& inArrs, const std::vector& outArrs, const bool swapFirst2Dims); +void meshgrid(sd::LaunchContext* context, const std::vector& inArrs, + const std::vector& outArrs, const bool swapFirst2Dims); } -} -} - +} // namespace ops +} // namespace sd -#endif //LIBND4J_SRU_H +#endif // LIBND4J_SRU_H diff --git a/libnd4j/include/ops/declarable/helpers/minimax.h b/libnd4j/include/ops/declarable/helpers/minimax.h index f619a20f667e..58c0215629bc 100644 --- a/libnd4j/include/ops/declarable/helpers/minimax.h +++ b/libnd4j/include/ops/declarable/helpers/minimax.h @@ -19,17 +19,19 @@ // #ifndef __MIN_I_MAX_H_HELPERS__ #define __MIN_I_MAX_H_HELPERS__ -#include #include +#include namespace sd { namespace ops { namespace helpers { - void minimumBPFunctor(sd::LaunchContext * context, NDArray* x, NDArray* y, NDArray* epsNext, NDArray* gradX, NDArray* gradY); - void maximumBPFunctor(sd::LaunchContext * context, NDArray* x, NDArray* y, NDArray* epsNext, NDArray* gradX, NDArray* gradY); +void minimumBPFunctor(sd::LaunchContext* context, NDArray* x, NDArray* y, + NDArray* epsNext, NDArray* gradX, NDArray* gradY); +void maximumBPFunctor(sd::LaunchContext* context, NDArray* x, NDArray* y, + NDArray* epsNext, NDArray* gradX, NDArray* gradY); -} -} -} +} // namespace helpers +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/helpers/multiUnique.h b/libnd4j/include/ops/declarable/helpers/multiUnique.h index 3119901c12f1..d915ed7db0c3 100644 --- a/libnd4j/include/ops/declarable/helpers/multiUnique.h +++ b/libnd4j/include/ops/declarable/helpers/multiUnique.h @@ -19,16 +19,17 @@ // #ifndef __MULTI_UNIQUE_H_HELPERS__ #define __MULTI_UNIQUE_H_HELPERS__ -#include #include +#include namespace sd { namespace ops { namespace helpers { - ND4J_EXPORT bool multiUnique(std::vector const& inputList, sd::memory::Workspace* workspace = nullptr); +SD_EXPORT bool multiUnique(std::vector const& inputList, + sd::memory::Workspace* workspace = nullptr); } -} -} +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/helpers/nth_element.h b/libnd4j/include/ops/declarable/helpers/nth_element.h index 1a2c28719af2..fac411d1833d 100644 --- a/libnd4j/include/ops/declarable/helpers/nth_element.h +++ b/libnd4j/include/ops/declarable/helpers/nth_element.h @@ -19,16 +19,17 @@ // #ifndef __NTH_ELEMENT__H_HELPERS__ #define __NTH_ELEMENT__H_HELPERS__ -#include #include +#include namespace sd { namespace ops { namespace helpers { - void nthElementFunctor(sd::LaunchContext * context, NDArray* input, Nd4jLong n, NDArray* output, bool reverse); +void nthElementFunctor(sd::LaunchContext* context, NDArray* input, Nd4jLong n, + NDArray* output, bool reverse); } -} -} +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/helpers/one_hot.h b/libnd4j/include/ops/declarable/helpers/one_hot.h index 1fefd7c446ed..7b9f2b997ec9 100644 --- a/libnd4j/include/ops/declarable/helpers/one_hot.h +++ b/libnd4j/include/ops/declarable/helpers/one_hot.h @@ -18,20 +18,22 @@ // @author raver119@gmail.com // -#ifndef DEV_TESTS_ONE_HOT_H -#define DEV_TESTS_ONE_HOT_H +#ifndef SD_ONE_HOT_H +#define SD_ONE_HOT_H -#include #include +#include -namespace sd { -namespace ops { -namespace helpers { +namespace sd { +namespace ops { +namespace helpers { - void onehot(const sd::LaunchContext* context, const NDArray *indices, NDArray *output, const uint axis, const uint depth, const double on, const double off); +void onehot(const sd::LaunchContext *context, const NDArray *indices, + NDArray *output, const uint axis, const uint depth, const double on, + const double off); } -} -} +} // namespace ops +} // namespace sd -#endif //DEV_TESTS_ONE_HOT_H +#endif // SD_ONE_HOT_H diff --git a/libnd4j/include/ops/declarable/helpers/percentile.h b/libnd4j/include/ops/declarable/helpers/percentile.h index 5eddb42b2907..a0044ccfb6a2 100644 --- a/libnd4j/include/ops/declarable/helpers/percentile.h +++ b/libnd4j/include/ops/declarable/helpers/percentile.h @@ -22,18 +22,19 @@ #define LIBND4J_PERCENTILE_H #include + #include "array/NDArray.h" namespace sd { namespace ops { namespace helpers { - void percentile(sd::LaunchContext * context, const NDArray& input, NDArray& output, std::vector& axises, const float q, const int interpolation); - +void percentile(sd::LaunchContext* context, const NDArray& input, + NDArray& output, std::vector& axises, const float q, + const int interpolation); } -} -} - +} // namespace ops +} // namespace sd -#endif //LIBND4J_PERCENTILE_H \ No newline at end of file +#endif // LIBND4J_PERCENTILE_H \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/prefix.h b/libnd4j/include/ops/declarable/helpers/prefix.h index 757c5c94f72f..dd810a17a8d8 100644 --- a/libnd4j/include/ops/declarable/helpers/prefix.h +++ b/libnd4j/include/ops/declarable/helpers/prefix.h @@ -21,22 +21,28 @@ #ifndef LIBND4J_PREFIX_HELPER_H #define LIBND4J_PREFIX_HELPER_H +#include #include #include + #include -#include namespace sd { - namespace ops { - namespace helpers { - // template - // void prefix(sd::LaunchContext * context, sd::scalar::Ops op, void* x, Nd4jLong *xShapeInfo, void* z, Nd4jLong* zShapeInfo, bool exclusive, bool reverse); +namespace ops { +namespace helpers { +// template +// void prefix(sd::LaunchContext * context, sd::scalar::Ops op, void* x, +// Nd4jLong *xShapeInfo, void* z, Nd4jLong* zShapeInfo, bool exclusive, bool +// reverse); - void prefix(sd::LaunchContext* context, sd::scalar::Ops op, const NDArray* x, NDArray* z, bool exclusive, bool reverse); +void prefix(sd::LaunchContext* context, sd::scalar::Ops op, const NDArray* x, + NDArray* z, bool exclusive, bool reverse); - void prefix(sd::LaunchContext* context, sd::scalar::Ops op, const NDArray* x, NDArray* z, const std::vector& dims, bool exclusive, bool reverse); - } - } -} +void prefix(sd::LaunchContext* context, sd::scalar::Ops op, const NDArray* x, + NDArray* z, const std::vector& dims, bool exclusive, + bool reverse); +} // namespace helpers +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/print_variable.h b/libnd4j/include/ops/declarable/helpers/print_variable.h index 46cf4ee01bc1..d04a1cc7ce19 100644 --- a/libnd4j/include/ops/declarable/helpers/print_variable.h +++ b/libnd4j/include/ops/declarable/helpers/print_variable.h @@ -24,11 +24,12 @@ #include namespace sd { - namespace ops { - namespace helpers { - void print_special(LaunchContext &ctx, const NDArray &array, const std::string &message = {}); - } - } +namespace ops { +namespace helpers { +void print_special(LaunchContext &ctx, const NDArray &array, + const std::string &message = {}); } +} // namespace ops +} // namespace sd -#endif //LIBND4J_PRINT_VARIABLE_H +#endif // LIBND4J_PRINT_VARIABLE_H diff --git a/libnd4j/include/ops/declarable/helpers/qr.h b/libnd4j/include/ops/declarable/helpers/qr.h index 05de6ca4049e..115d374e62c8 100644 --- a/libnd4j/include/ops/declarable/helpers/qr.h +++ b/libnd4j/include/ops/declarable/helpers/qr.h @@ -19,17 +19,17 @@ // #ifndef __QR__H_HELPERS__ #define __QR__H_HELPERS__ -#include #include +#include namespace sd { namespace ops { namespace helpers { - void qr(sd::LaunchContext * context, NDArray const* input, NDArray* outputQ, NDArray* outputR, bool const fullMatricies); - +void qr(sd::LaunchContext* context, NDArray const* input, NDArray* outputQ, + NDArray* outputR, bool const fullMatricies); } -} -} +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/helpers/random.h b/libnd4j/include/ops/declarable/helpers/random.h index 5ee75e141fc7..dbe93151e360 100644 --- a/libnd4j/include/ops/declarable/helpers/random.h +++ b/libnd4j/include/ops/declarable/helpers/random.h @@ -22,20 +22,25 @@ // #ifndef __RANDOM_HELPERS__ #define __RANDOM_HELPERS__ -#include #include -#include #include +#include +#include namespace sd { namespace ops { namespace helpers { - void fillRandomGamma(LaunchContext* context, graph::RandomGenerator& rng, NDArray* alpha, NDArray* beta, NDArray* output); - void fillRandomPoisson(LaunchContext* context, graph::RandomGenerator& rng, NDArray* lambda, NDArray* output); - void fillRandomUniform(LaunchContext* context, graph::RandomGenerator& rng, NDArray* min, NDArray* max, NDArray* output); - void fillRandomMultiNomial(LaunchContext* context, graph::RandomGenerator& rng, NDArray& input, NDArray& output, const Nd4jLong numOfSamples, const int dimC); -} -} -} +void fillRandomGamma(LaunchContext* context, graph::RandomGenerator& rng, + NDArray* alpha, NDArray* beta, NDArray* output); +void fillRandomPoisson(LaunchContext* context, graph::RandomGenerator& rng, + NDArray* lambda, NDArray* output); +void fillRandomUniform(LaunchContext* context, graph::RandomGenerator& rng, + NDArray* min, NDArray* max, NDArray* output); +void fillRandomMultiNomial(LaunchContext* context, graph::RandomGenerator& rng, + NDArray& input, NDArray& output, + const Nd4jLong numOfSamples, const int dimC); +} // namespace helpers +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/helpers/random_crop.h b/libnd4j/include/ops/declarable/helpers/random_crop.h index f4d36a850b19..bf1730c66791 100644 --- a/libnd4j/include/ops/declarable/helpers/random_crop.h +++ b/libnd4j/include/ops/declarable/helpers/random_crop.h @@ -19,18 +19,19 @@ // #ifndef __RANDOM_CROP_HELPERS__ #define __RANDOM_CROP_HELPERS__ -#include #include -#include #include +#include +#include namespace sd { namespace ops { namespace helpers { - int randomCropFunctor(graph::Context& context, NDArray* input, NDArray* shape, NDArray* output, int seed); +int randomCropFunctor(graph::Context& context, NDArray* input, NDArray* shape, + NDArray* output, int seed); } -} -} +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/helpers/range.h b/libnd4j/include/ops/declarable/helpers/range.h index 13155fd70529..fc5e6ae85b6d 100644 --- a/libnd4j/include/ops/declarable/helpers/range.h +++ b/libnd4j/include/ops/declarable/helpers/range.h @@ -23,16 +23,16 @@ #include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { - // be careful: outVector must have c-order and ews = 1 !!! - void range(sd::LaunchContext * context, const NDArray& start, const NDArray& delta, NDArray& outVector); +// be careful: outVector must have c-order and ews = 1 !!! +void range(sd::LaunchContext* context, const NDArray& start, + const NDArray& delta, NDArray& outVector); -} -} -} +} // namespace helpers +} // namespace ops +} // namespace sd - -#endif //LIBND4J_RANGE_H +#endif // LIBND4J_RANGE_H diff --git a/libnd4j/include/ops/declarable/helpers/reverse.h b/libnd4j/include/ops/declarable/helpers/reverse.h index c44327bb0783..3b0ba60c6cd6 100644 --- a/libnd4j/include/ops/declarable/helpers/reverse.h +++ b/libnd4j/include/ops/declarable/helpers/reverse.h @@ -23,19 +23,19 @@ #include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { void reverseSequence(sd::LaunchContext * context, const NDArray* input, const NDArray* seqLengths, NDArray* output, int seqDim, const int batchDim); void reverse(sd::LaunchContext * context, const NDArray* input, NDArray* output, const std::vector* intArgs); +void reverse(sd::LaunchContext* context, const NDArray* input, NDArray* output, + const std::vector* intArgs, bool isBackProp); +} // namespace helpers +} // namespace ops +} // namespace sd -} -} -} - - -#endif //LIBND4J_REVERSESEQUENCE_H +#endif // LIBND4J_REVERSESEQUENCE_H diff --git a/libnd4j/include/ops/declarable/helpers/rnn.h b/libnd4j/include/ops/declarable/helpers/rnn.h index 32f49fe2e317..885d3be8bb77 100644 --- a/libnd4j/include/ops/declarable/helpers/rnn.h +++ b/libnd4j/include/ops/declarable/helpers/rnn.h @@ -23,18 +23,21 @@ #include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { +void rnnCell(sd::LaunchContext* context, const NDArray* xt, const NDArray* Wx, + const NDArray* Wh, const NDArray* b, const NDArray* ht_1, + NDArray* ht); - void rnnCell(sd::LaunchContext * context, const NDArray* xt, const NDArray* Wx, const NDArray* Wh, const NDArray* b, const NDArray* ht_1, NDArray* ht); +void rnnTimeLoop(sd::LaunchContext* context, const NDArray* x, + const NDArray* Wx, const NDArray* Wh, const NDArray* b, + const NDArray* h0, const NDArray* maxTimeStep, NDArray* h, + NDArray* hFinal); - void rnnTimeLoop(sd::LaunchContext * context, const NDArray* x, const NDArray* Wx, const NDArray* Wh, const NDArray* b, const NDArray* h0, const NDArray* maxTimeStep, NDArray* h, NDArray* hFinal); +} // namespace helpers +} // namespace ops +} // namespace sd -} -} -} - - -#endif //LIBND4J_RNN_H +#endif // LIBND4J_RNN_H diff --git a/libnd4j/include/ops/declarable/helpers/roll.h b/libnd4j/include/ops/declarable/helpers/roll.h index 3e637dbc43ea..715e1e40e4a2 100644 --- a/libnd4j/include/ops/declarable/helpers/roll.h +++ b/libnd4j/include/ops/declarable/helpers/roll.h @@ -24,10 +24,13 @@ namespace sd { namespace ops { namespace helpers { - void rollFunctorLinear(sd::LaunchContext * context, NDArray* input, NDArray* output, int shift, bool inplace = false); +void rollFunctorLinear(sd::LaunchContext* context, NDArray* input, + NDArray* output, int shift, bool inplace = false); - void rollFunctorFull(sd::LaunchContext * context, NDArray* input, NDArray* output, std::vector const& shifts, std::vector const& axes, bool inplace = false); -} -} -} +void rollFunctorFull(sd::LaunchContext* context, NDArray* input, + NDArray* output, std::vector const& shifts, + std::vector const& axes, bool inplace = false); +} // namespace helpers +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/helpers/s_t_b.h b/libnd4j/include/ops/declarable/helpers/s_t_b.h index 1147f05ab92d..360aa68dd457 100644 --- a/libnd4j/include/ops/declarable/helpers/s_t_b.h +++ b/libnd4j/include/ops/declarable/helpers/s_t_b.h @@ -27,23 +27,38 @@ namespace sd { namespace ops { namespace helpers { - void batchToSpace(sd::LaunchContext* context, const NDArray& input, NDArray& output, const uint cropBottom, const uint cropTop, const uint cropLeft, const uint cropRight, const uint blockSize); +void batchToSpace(sd::LaunchContext* context, const NDArray& input, + NDArray& output, const uint cropBottom, const uint cropTop, + const uint cropLeft, const uint cropRight, + const uint blockSize); - void spaceToBatch(sd::LaunchContext* context, const NDArray& input, NDArray& output, const uint padBottom, const uint padTop, const uint padLeft, const uint padRight, const uint blockSize); +void spaceToBatch(sd::LaunchContext* context, const NDArray& input, + NDArray& output, const uint padBottom, const uint padTop, + const uint padLeft, const uint padRight, + const uint blockSize); - void spaceToBatchND(sd::LaunchContext* context, const NDArray& input, const NDArray& blockShape, const NDArray& padding, NDArray& output); +void spaceToBatchND(sd::LaunchContext* context, const NDArray& input, + const NDArray& blockShape, const NDArray& padding, + NDArray& output); - void batchToSpaceND(sd::LaunchContext* context, const NDArray& input, const NDArray& blockShape, const NDArray& crop, NDArray& output); +void batchToSpaceND(sd::LaunchContext* context, const NDArray& input, + const NDArray& blockShape, const NDArray& crop, + NDArray& output); /* // this method MUST be platform-specific template - void _execute(sd::LaunchContext * context, void *ptrSpace, const Nd4jLong *space_shape, const Nd4jLong *space_strides, const Nd4jLong *block_shape, const Nd4jLong *pad_start, const Nd4jLong *block_offsets, void *ptrBatch, const Nd4jLong *batch_shape, const Nd4jLong *batch_strides); + void _execute(sd::LaunchContext * context, void *ptrSpace, const Nd4jLong +*space_shape, const Nd4jLong *space_strides, const Nd4jLong *block_shape, const +Nd4jLong *pad_start, const Nd4jLong *block_offsets, void *ptrBatch, const +Nd4jLong *batch_shape, const Nd4jLong *batch_strides); template - FORCEINLINE void _prepare(sd::LaunchContext * context, NDArray * space, NDArray *batch, const Nd4jLong block_array[NUM_BLOCK_DIMS], const Nd4jLong padding_array[NUM_BLOCK_DIMS * 2]) { + FORCEINLINE void _prepare(sd::LaunchContext * context, NDArray * space, +NDArray *batch, const Nd4jLong block_array[NUM_BLOCK_DIMS], const Nd4jLong +padding_array[NUM_BLOCK_DIMS * 2]) { Nd4jLong pad_start[NUM_BLOCK_DIMS]; Nd4jLong block_shape[NUM_BLOCK_DIMS]; @@ -69,26 +84,38 @@ namespace helpers { const Nd4jLong space_b = batch_b % space_size; Nd4jLong block_index = batch_b / space_size; Nd4jLong block_offsets[NUM_BLOCK_DIMS]; - for (Nd4jLong block_dim = NUM_BLOCK_DIMS - 1; block_dim >= 0; --block_dim) { - block_offsets[block_dim] = block_dim > 0 ? block_index % block_shape[block_dim] : block_index; - block_index /= block_shape[block_dim]; + for (Nd4jLong block_dim = NUM_BLOCK_DIMS - 1; block_dim >= 0; +--block_dim) { block_offsets[block_dim] = block_dim > 0 ? block_index % +block_shape[block_dim] : block_index; block_index /= block_shape[block_dim]; } Nd4jLong space_offset = space_b * space_strides[0]; Nd4jLong batch_offset = batch_b * batch_strides[0]; auto xType = space->dataType(); - //_execute(space->buffer() + space_offset, space_shape, &space_strides[1], block_shape, pad_start, block_offsets, batch->buffer() + batch_offset, batch_shape, &batch_strides[1]); - BUILD_SINGLE_PARTIAL_SELECTOR(xType, _execute<, (NUM_BLOCK_DIMS, B2S>(context, space->bufferWithOffset(space_offset), space_shape, &space_strides[1], block_shape, pad_start, block_offsets, batch->bufferWithOffset(batch_offset), batch_shape, &batch_strides[1])), LIBND4J_TYPES); + //_execute(space->buffer() + space_offset, +space_shape, &space_strides[1], block_shape, pad_start, block_offsets, +batch->buffer() + batch_offset, batch_shape, &batch_strides[1]); + BUILD_SINGLE_PARTIAL_SELECTOR(xType, _execute<, (NUM_BLOCK_DIMS, +B2S>(context, space->bufferWithOffset(space_offset), space_shape, +&space_strides[1], block_shape, pad_start, block_offsets, +batch->bufferWithOffset(batch_offset), batch_shape, &batch_strides[1])), +LIBND4J_TYPES); } }; - Nd4jStatus _spaceToBatch(sd::LaunchContext * context, int internal_block_dims, NDArray *input, NDArray *output, std::vector &internal_input_shape, std::vector &internal_output_shape, Nd4jLong *block_shape, Nd4jLong *paddings); + Nd4jStatus _spaceToBatch(sd::LaunchContext * context, int +internal_block_dims, NDArray *input, NDArray *output, std::vector +&internal_input_shape, std::vector &internal_output_shape, Nd4jLong +*block_shape, Nd4jLong *paddings); - Nd4jStatus _batchToSpace(sd::LaunchContext * context, int internal_block_dims, NDArray *input, NDArray *output, std::vector &internal_input_shape, std::vector &internal_output_shape, Nd4jLong *block_shape, Nd4jLong *crops); + Nd4jStatus _batchToSpace(sd::LaunchContext * context, int +internal_block_dims, NDArray *input, NDArray *output, std::vector +&internal_input_shape, std::vector &internal_output_shape, Nd4jLong +*block_shape, Nd4jLong *crops); */ -} -} -} +} // namespace helpers +} // namespace ops +} // namespace sd -#endif //LIBND4J_S_T_B_H +#endif // LIBND4J_S_T_B_H diff --git a/libnd4j/include/ops/declarable/helpers/s_t_d.h b/libnd4j/include/ops/declarable/helpers/s_t_d.h index 7ef500f03a29..21f588fd05d1 100644 --- a/libnd4j/include/ops/declarable/helpers/s_t_d.h +++ b/libnd4j/include/ops/declarable/helpers/s_t_d.h @@ -18,13 +18,14 @@ // @author raver119@gmail.com // -#include #include +#include namespace sd { namespace ops { namespace helpers { - void _spaceTodepth(sd::LaunchContext * context, const NDArray &input, NDArray *output, int block_size, bool isNHWC); -} +void _spaceTodepth(sd::LaunchContext *context, const NDArray &input, + NDArray *output, int block_size, bool isNHWC); } -} \ No newline at end of file +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/scatter.h b/libnd4j/include/ops/declarable/helpers/scatter.h index 6e456ff9728f..e888260a1f37 100644 --- a/libnd4j/include/ops/declarable/helpers/scatter.h +++ b/libnd4j/include/ops/declarable/helpers/scatter.h @@ -18,23 +18,29 @@ // @author raver119@gmail.com // -#ifndef DEV_TESTS_SCATTER_H -#define DEV_TESTS_SCATTER_H +#ifndef SD_SCATTER_H +#define SD_SCATTER_H #include namespace sd { - namespace ops { - namespace helpers { - void scatter(sd::LaunchContext* context, pairwise::Ops op, const NDArray& indices, const NDArray& updates, NDArray& output, const bool lock); - - void scatterND(sd::LaunchContext* context, pairwise::Ops op, const NDArray& indices, const NDArray& updates, NDArray& output, const bool lock); - - void scatterForLoss(sd::LaunchContext* context, const NDArray& indices, NDArray& updates, NDArray& output, const bool calcGrad); - - Nd4jLong checkIndices(sd::LaunchContext *context, const NDArray& indices, const NDArray& output, const int axis = -1); - } - } -} - -#endif //DEV_TESTS_SCATTER_H +namespace ops { +namespace helpers { +void scatter(sd::LaunchContext* context, pairwise::Ops op, + const NDArray& indices, const NDArray& updates, NDArray& output, + const bool lock); + +void scatterND(sd::LaunchContext* context, pairwise::Ops op, + const NDArray& indices, const NDArray& updates, NDArray& output, + const bool lock); + +void scatterForLoss(sd::LaunchContext* context, const NDArray& indices, + NDArray& updates, NDArray& output, const bool calcGrad); + +Nd4jLong checkIndices(sd::LaunchContext* context, const NDArray& indices, + const NDArray& output, const int axis = -1); +} // namespace helpers +} // namespace ops +} // namespace sd + +#endif // SD_SCATTER_H diff --git a/libnd4j/include/ops/declarable/helpers/segment.h b/libnd4j/include/ops/declarable/helpers/segment.h index 2433313ffbaf..900b80ba0c1a 100644 --- a/libnd4j/include/ops/declarable/helpers/segment.h +++ b/libnd4j/include/ops/declarable/helpers/segment.h @@ -17,66 +17,104 @@ // // @author sgazeos@gmail.com // @brief helpers fuctions for segment_* ops (segment_max, segment_min, etc.) -// @brief helpers fuctions for unsorted_segment_* ops (unsorted_segment_max, etc.) +// @brief helpers fuctions for unsorted_segment_* ops (unsorted_segment_max, +// etc.) // #ifndef __SEGMENT_HELPERS__ #define __SEGMENT_HELPERS__ -#include #include +#include namespace sd { namespace ops { namespace helpers { - bool segmentIndicesValidate(sd::LaunchContext * context, NDArray* indices, NDArray& expected, NDArray& output); +bool segmentIndicesValidate(sd::LaunchContext* context, NDArray* indices, + NDArray& expected, NDArray& output); - bool unsortedSegmentIndicesValidate(sd::LaunchContext * context, NDArray* indices, Nd4jLong numOfClasses, Nd4jLong& output); +bool unsortedSegmentIndicesValidate(sd::LaunchContext* context, + NDArray* indices, Nd4jLong numOfClasses, + Nd4jLong& output); - void segmentMaxFunctor(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* output); +void segmentMaxFunctor(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* output); - void segmentMinFunctor(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* output); +void segmentMinFunctor(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* output); - void segmentMeanFunctor(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* output); +void segmentMeanFunctor(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* output); - void segmentSumFunctor(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* output); +void segmentSumFunctor(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* output); - void segmentProdFunctor(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* output); +void segmentProdFunctor(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* output); - void unsortedSegmentSqrtNFunctor(sd::LaunchContext * context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output); +void unsortedSegmentSqrtNFunctor(sd::LaunchContext* context, NDArray* input, + NDArray* indices, Nd4jLong numOfClasses, + NDArray* output); - void unsortedSegmentMaxFunctor(sd::LaunchContext * context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output); +void unsortedSegmentMaxFunctor(sd::LaunchContext* context, NDArray* input, + NDArray* indices, Nd4jLong numOfClasses, + NDArray* output); - void unsortedSegmentMinFunctor(sd::LaunchContext * context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output); +void unsortedSegmentMinFunctor(sd::LaunchContext* context, NDArray* input, + NDArray* indices, Nd4jLong numOfClasses, + NDArray* output); - void unsortedSegmentMeanFunctor(sd::LaunchContext * context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output); +void unsortedSegmentMeanFunctor(sd::LaunchContext* context, NDArray* input, + NDArray* indices, Nd4jLong numOfClasses, + NDArray* output); - void unsortedSegmentSumFunctor(sd::LaunchContext * context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output); +void unsortedSegmentSumFunctor(sd::LaunchContext* context, NDArray* input, + NDArray* indices, Nd4jLong numOfClasses, + NDArray* output); - void unsortedSegmentProdFunctor(sd::LaunchContext * context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output); +void unsortedSegmentProdFunctor(sd::LaunchContext* context, NDArray* input, + NDArray* indices, Nd4jLong numOfClasses, + NDArray* output); - int segmentMaxFunctorBP(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output); +int segmentMaxFunctorBP(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* gradOut, NDArray* output); - int segmentMinFunctorBP(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output); +int segmentMinFunctorBP(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* gradOut, NDArray* output); - int segmentMeanFunctorBP(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output); +int segmentMeanFunctorBP(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* gradOut, NDArray* output); - int segmentSumFunctorBP(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output); +int segmentSumFunctorBP(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* gradOut, NDArray* output); - int segmentProdFunctorBP(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output); +int segmentProdFunctorBP(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* gradOut, NDArray* output); - int unsortedSegmentSqrtNFunctorBP(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output); +int unsortedSegmentSqrtNFunctorBP(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* gradOut, + Nd4jLong numOfClasses, NDArray* output); - int unsortedSegmentMaxFunctorBP(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output); +int unsortedSegmentMaxFunctorBP(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* gradOut, + Nd4jLong numOfClasses, NDArray* output); - int unsortedSegmentMinFunctorBP(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output); +int unsortedSegmentMinFunctorBP(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* gradOut, + Nd4jLong numOfClasses, NDArray* output); - int unsortedSegmentMeanFunctorBP(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output); +int unsortedSegmentMeanFunctorBP(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* gradOut, + Nd4jLong numOfClasses, NDArray* output); - int unsortedSegmentSumFunctorBP(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output); +int unsortedSegmentSumFunctorBP(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* gradOut, + Nd4jLong numOfClasses, NDArray* output); - int unsortedSegmentProdFunctorBP(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output); +int unsortedSegmentProdFunctorBP(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* gradOut, + Nd4jLong numOfClasses, NDArray* output); -} -} -} +} // namespace helpers +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/helpers/segment_common.h b/libnd4j/include/ops/declarable/helpers/segment_common.h index b0a92b8b33e1..aa67893e9ed7 100644 --- a/libnd4j/include/ops/declarable/helpers/segment_common.h +++ b/libnd4j/include/ops/declarable/helpers/segment_common.h @@ -16,21 +16,23 @@ // // @author sgazeos@gmail.com -// @brief helpers common fuctions for segment_* ops (segment_max, segment_min, etc.) -// @brief helpers common fuctions for unsorted_segment_* ops (unsorted_segment_max, etc.) +// @brief helpers common fuctions for segment_* ops (segment_max, segment_min, +// etc.) +// @brief helpers common fuctions for unsorted_segment_* ops +// (unsorted_segment_max, etc.) // #ifndef __SEGMENT_COMMON_HELPERS__ #define __SEGMENT_COMMON_HELPERS__ -#include #include +#include namespace sd { namespace ops { namespace helpers { - void fillUpSegments(NDArray* indices, Nd4jLong numClasses, NDArray& classesRangesBegs, NDArray& classesRangesLens); - +void fillUpSegments(NDArray* indices, Nd4jLong numClasses, + NDArray& classesRangesBegs, NDArray& classesRangesLens); } -} -} +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/helpers/sequence_mask.h b/libnd4j/include/ops/declarable/helpers/sequence_mask.h index 491640b9e769..2b97dcca580d 100644 --- a/libnd4j/include/ops/declarable/helpers/sequence_mask.h +++ b/libnd4j/include/ops/declarable/helpers/sequence_mask.h @@ -19,16 +19,17 @@ // #ifndef __SEQUENCE_MASK_HELPERS__ #define __SEQUENCE_MASK_HELPERS__ -#include #include +#include namespace sd { namespace ops { namespace helpers { - void sequenceMask(sd::LaunchContext * context, NDArray* input, NDArray* output, int maxIndex); +void sequenceMask(sd::LaunchContext* context, NDArray* input, NDArray* output, + int maxIndex); } -} -} +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/helpers/sg_cb.h b/libnd4j/include/ops/declarable/helpers/sg_cb.h index 6b0824a81a4b..9279c2d6e6db 100644 --- a/libnd4j/include/ops/declarable/helpers/sg_cb.h +++ b/libnd4j/include/ops/declarable/helpers/sg_cb.h @@ -18,23 +18,33 @@ // @author raver119@gmail.com // -#ifndef DEV_TESTS_SG_CB_H -#define DEV_TESTS_SG_CB_H +#ifndef SD_SG_CB_H +#define SD_SG_CB_H +#include #include #include -#include namespace sd { - namespace ops { - namespace helpers { - void skipgram(NDArray &syn0, NDArray &syn1, NDArray &syn1Neg, NDArray &expTable, NDArray &negTable, NDArray &target, NDArray &ngStarter, int nsRounds, NDArray &indices, NDArray &codes, NDArray &alpha, NDArray &randomValue, NDArray &inferenceVector, const bool preciseMode, const int numWorkers); +namespace ops { +namespace helpers { +void skipgram(NDArray &syn0, NDArray &syn1, NDArray &syn1Neg, NDArray &expTable, + NDArray &negTable, NDArray &target, NDArray &ngStarter, + int nsRounds, NDArray &indices, NDArray &codes, NDArray &alpha, + NDArray &randomValue, NDArray &inferenceVector, + const bool preciseMode, const int numWorkers); - void cbow(NDArray &syn0, NDArray &syn1, NDArray &syn1Neg, NDArray &expTable, NDArray &negTable, NDArray &target, NDArray &ngStarter, int nsRounds, NDArray &context, NDArray &lockedWords, NDArray &indices, NDArray &codes, NDArray &alpha, NDArray &randomValue, NDArray &numLabels, NDArray &inferenceVector, const bool trainWords, const int numWorkers); +void cbow(NDArray &syn0, NDArray &syn1, NDArray &syn1Neg, NDArray &expTable, + NDArray &negTable, NDArray &target, NDArray &ngStarter, int nsRounds, + NDArray &context, NDArray &lockedWords, NDArray &indices, + NDArray &codes, NDArray &alpha, NDArray &randomValue, + NDArray &numLabels, NDArray &inferenceVector, const bool trainWords, + const int numWorkers); - int binarySearch(const int *haystack, const int needle, const int totalElements); - } - } -} +int binarySearch(const int *haystack, const int needle, + const int totalElements); +} // namespace helpers +} // namespace ops +} // namespace sd -#endif //DEV_TESTS_SG_CB_H +#endif // SD_SG_CB_H diff --git a/libnd4j/include/ops/declarable/helpers/shift.h b/libnd4j/include/ops/declarable/helpers/shift.h index f1b21741ca57..6dfe541d1fc0 100644 --- a/libnd4j/include/ops/declarable/helpers/shift.h +++ b/libnd4j/include/ops/declarable/helpers/shift.h @@ -18,25 +18,29 @@ // @author raver119@gmail.com // -#ifndef DEV_TESTS_SHIFT_H -#define DEV_TESTS_SHIFT_H +#ifndef SD_SHIFT_H +#define SD_SHIFT_H +#include #include #include -#include namespace sd { - namespace ops { - namespace helpers { - void rshift_bits(LaunchContext* launchContext, NDArray &x, NDArray &z, uint32_t shift); +namespace ops { +namespace helpers { +void rshift_bits(LaunchContext *launchContext, NDArray &x, NDArray &z, + uint32_t shift); - void shift_bits(LaunchContext* launchContext, NDArray &x, NDArray &z, uint32_t shift); +void shift_bits(LaunchContext *launchContext, NDArray &x, NDArray &z, + uint32_t shift); - void cyclic_rshift_bits(LaunchContext* launchContext, NDArray &x, NDArray &z, uint32_t shift); +void cyclic_rshift_bits(LaunchContext *launchContext, NDArray &x, NDArray &z, + uint32_t shift); - void cyclic_shift_bits(LaunchContext* launchContext, NDArray &x, NDArray &z, uint32_t shift); - } - } -} +void cyclic_shift_bits(LaunchContext *launchContext, NDArray &x, NDArray &z, + uint32_t shift); +} // namespace helpers +} // namespace ops +} // namespace sd -#endif //DEV_TESTS_SHIFT_H +#endif // SD_SHIFT_H diff --git a/libnd4j/include/ops/declarable/helpers/solve.h b/libnd4j/include/ops/declarable/helpers/solve.h index 17234f313877..ee0bca0c4820 100644 --- a/libnd4j/include/ops/declarable/helpers/solve.h +++ b/libnd4j/include/ops/declarable/helpers/solve.h @@ -19,16 +19,18 @@ // #ifndef __SOLVE__H_HELPERS__ #define __SOLVE__H_HELPERS__ -#include #include +#include namespace sd { namespace ops { namespace helpers { - int solveFunctor(sd::LaunchContext* context, NDArray* leftInput, NDArray* rightInput, bool adjoint, NDArray* output); - void adjointMatrix(sd::LaunchContext* context, NDArray const* input, NDArray* output); -} -} -} +int solveFunctor(sd::LaunchContext* context, NDArray* leftInput, + NDArray* rightInput, bool adjoint, NDArray* output); +void adjointMatrix(sd::LaunchContext* context, NDArray const* input, + NDArray* output); +} // namespace helpers +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/helpers/sparse_to_dense.h b/libnd4j/include/ops/declarable/helpers/sparse_to_dense.h index 541621257548..ecb66ff9f9ef 100644 --- a/libnd4j/include/ops/declarable/helpers/sparse_to_dense.h +++ b/libnd4j/include/ops/declarable/helpers/sparse_to_dense.h @@ -24,11 +24,12 @@ #include namespace sd { - namespace ops { - namespace helpers { - void compat_sparse_to_dense(const NDArray &values, const NDArray &indices, NDArray *def, NDArray &output); - } - } +namespace ops { +namespace helpers { +void compat_sparse_to_dense(const NDArray &values, const NDArray &indices, + NDArray *def, NDArray &output); } +} // namespace ops +} // namespace sd -#endif //SAMEDIFF_SPARSE_TO_DENSE_H +#endif // SAMEDIFF_SPARSE_TO_DENSE_H diff --git a/libnd4j/include/ops/declarable/helpers/sru.h b/libnd4j/include/ops/declarable/helpers/sru.h index 639247278512..a998ef8a345a 100644 --- a/libnd4j/include/ops/declarable/helpers/sru.h +++ b/libnd4j/include/ops/declarable/helpers/sru.h @@ -23,25 +23,28 @@ #include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { +void sruCell(sd::LaunchContext* context, const NDArray* x, const NDArray* c0, + const NDArray* w, const NDArray* b, NDArray* h, NDArray* c); - void sruCell(sd::LaunchContext * context, const NDArray* x, const NDArray* c0, const NDArray* w, const NDArray* b, NDArray* h, NDArray* c); +void sruTimeLoop(sd::LaunchContext* context, const NDArray* x, + const NDArray* c0, const NDArray* w, const NDArray* b, + NDArray* h, NDArray* c); - void sruTimeLoop(sd::LaunchContext * context, const NDArray* x, const NDArray* c0, const NDArray* w, const NDArray* b, NDArray* h, NDArray* c); +void sruBI(sd::LaunchContext* context, NDArray* x, const NDArray* w, + const NDArray* b, const NDArray* c0, const NDArray* mask, + NDArray* ht, NDArray* ct); - void sruBI(sd::LaunchContext * context, NDArray* x, const NDArray* w, const NDArray* b, const NDArray* c0, const NDArray* mask, NDArray* ht, NDArray* ct); +void sruBIBP(sd::LaunchContext* context, NDArray* x, const NDArray* w, + const NDArray* b, const NDArray* c0, const NDArray* ct, + const NDArray* inGradC0, const NDArray* inGradH, + const NDArray* mask, NDArray* gradI, NDArray* gradWeights, + NDArray* gradB, NDArray* gradC0); +} // namespace helpers +} // namespace ops +} // namespace sd - void sruBIBP(sd::LaunchContext * context, NDArray* x, const NDArray* w, const NDArray* b, const NDArray* c0, const NDArray* ct, const NDArray* inGradC0, const NDArray* inGradH, const NDArray* mask, - NDArray* gradI, NDArray* gradWeights, NDArray* gradB, NDArray* gradC0); -} -} -} - - -#endif //LIBND4J_SRU_H - - - \ No newline at end of file +#endif // LIBND4J_SRU_H diff --git a/libnd4j/include/ops/declarable/helpers/stack.h b/libnd4j/include/ops/declarable/helpers/stack.h index 0ab486a5d33c..af694cc57448 100644 --- a/libnd4j/include/ops/declarable/helpers/stack.h +++ b/libnd4j/include/ops/declarable/helpers/stack.h @@ -21,20 +21,21 @@ #ifndef LIBND4J_STACK_H #define LIBND4J_STACK_H -#include #include +#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { -void stack (sd::LaunchContext* context, const std::vector& inArrs, NDArray& outArr, const int dim); -void unstack(sd::LaunchContext* context, const NDArray& input, const std::vector& outArrs, const int dim); - - -} -} -} +void stack(sd::LaunchContext* context, + const std::vector& inArrs, NDArray& outArr, + const int dim); +void unstack(sd::LaunchContext* context, const NDArray& input, + const std::vector& outArrs, const int dim); +} // namespace helpers +} // namespace ops +} // namespace sd -#endif //LIBND4J_STACK_H +#endif // LIBND4J_STACK_H diff --git a/libnd4j/include/ops/declarable/helpers/svd.h b/libnd4j/include/ops/declarable/helpers/svd.h index 027807191897..14843d09d9a6 100644 --- a/libnd4j/include/ops/declarable/helpers/svd.h +++ b/libnd4j/include/ops/declarable/helpers/svd.h @@ -22,19 +22,22 @@ #define LIBND4J_SVD_HELPER_H #include + #include "array/NDArray.h" -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { ////////////////////////////////////////////////////////////////////////// -// svd operation, this function is not method of SVD class, it is standalone function -void svd(sd::LaunchContext* context, const NDArray* x, const std::vector& outArrs, const bool fullUV, const bool calcUV, const int switchNum); - +// svd operation, this function is not method of SVD class, it is standalone +// function +void svd(sd::LaunchContext* context, const NDArray* x, + const std::vector& outArrs, const bool fullUV, + const bool calcUV, const int switchNum); -} -} -} +} // namespace helpers +} // namespace ops +} // namespace sd -#endif //LIBND4J_SVD_HELPER_H +#endif // LIBND4J_SVD_HELPER_H diff --git a/libnd4j/include/ops/declarable/helpers/threshold.h b/libnd4j/include/ops/declarable/helpers/threshold.h index 21ac0c8208a0..f77e0efb057c 100644 --- a/libnd4j/include/ops/declarable/helpers/threshold.h +++ b/libnd4j/include/ops/declarable/helpers/threshold.h @@ -24,14 +24,14 @@ #include namespace sd { - namespace ops { - namespace helpers { - int32_t thresholdEstimate(const NDArray &updates, float threshold); +namespace ops { +namespace helpers { +int32_t thresholdEstimate(const NDArray &updates, float threshold); - void thresholdEncode(NDArray &updates, NDArray &encoded, float threshold); - void thresholdDecode(const NDArray &encoded, NDArray &updates); - } - } -} +void thresholdEncode(NDArray &updates, NDArray &encoded, float threshold); +void thresholdDecode(const NDArray &encoded, NDArray &updates); +} // namespace helpers +} // namespace ops +} // namespace sd -#endif //SD_THRESHOLD_H +#endif // SD_THRESHOLD_H diff --git a/libnd4j/include/ops/declarable/helpers/toggle_bits.h b/libnd4j/include/ops/declarable/helpers/toggle_bits.h index 6d8ffe44af3d..763c494e3784 100644 --- a/libnd4j/include/ops/declarable/helpers/toggle_bits.h +++ b/libnd4j/include/ops/declarable/helpers/toggle_bits.h @@ -20,18 +20,19 @@ #include -#ifndef DEV_TESTS_TOGGLE_BITS_H -#define DEV_TESTS_TOGGLE_BITS_H +#ifndef SD_TOGGLE_BITS_H +#define SD_TOGGLE_BITS_H namespace sd { - namespace ops { - namespace helpers { - template - static void toggle_bits__(sd::LaunchContext * context, NDArray& in, NDArray& out); +namespace ops { +namespace helpers { +template +static void toggle_bits__(sd::LaunchContext* context, NDArray& in, + NDArray& out); - void __toggle_bits(sd::LaunchContext * context, NDArray& in, NDArray& out); - } - } -} +void __toggle_bits(sd::LaunchContext* context, NDArray& in, NDArray& out); +} // namespace helpers +} // namespace ops +} // namespace sd -#endif //DEV_TESTS_TOGGLE_BITS_H +#endif // SD_TOGGLE_BITS_H diff --git a/libnd4j/include/ops/declarable/helpers/top_k.h b/libnd4j/include/ops/declarable/helpers/top_k.h index 6a459f925308..4c93d4d5bf85 100644 --- a/libnd4j/include/ops/declarable/helpers/top_k.h +++ b/libnd4j/include/ops/declarable/helpers/top_k.h @@ -19,18 +19,20 @@ // #ifndef __TOP_K_HELPERS__ #define __TOP_K_HELPERS__ -#include #include +#include namespace sd { namespace ops { namespace helpers { - int topKFunctor(sd::LaunchContext * context, const NDArray* input, NDArray* values, NDArray* indices, const uint k, bool needSort); +int topKFunctor(sd::LaunchContext* context, const NDArray* input, + NDArray* values, NDArray* indices, const uint k, bool needSort); - int inTopKFunctor(sd::LaunchContext * context, const NDArray* predictions, const NDArray* targets, NDArray* output, const uint k); +int inTopKFunctor(sd::LaunchContext* context, const NDArray* predictions, + const NDArray* targets, NDArray* output, const uint k); -} -} -} +} // namespace helpers +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/helpers/transforms.h b/libnd4j/include/ops/declarable/helpers/transforms.h index d20b98e6c665..2a61a06d88b9 100644 --- a/libnd4j/include/ops/declarable/helpers/transforms.h +++ b/libnd4j/include/ops/declarable/helpers/transforms.h @@ -21,70 +21,104 @@ #ifndef LIBND4J_TRANSFORMS_H #define LIBND4J_TRANSFORMS_H -#include -#include #include +#include +#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { - void triuBP(sd::LaunchContext * context, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int diagonal); - - void trace(sd::LaunchContext * context, const NDArray& input, NDArray& output); +void triuBP(sd::LaunchContext* context, const NDArray& input, + const NDArray& gradO, NDArray& gradI, const int diagonal); - void randomShuffle(sd::LaunchContext * context, NDArray& input, NDArray& output, sd::graph::RandomGenerator& rng, const bool isInplace); +void trace(sd::LaunchContext* context, const NDArray& input, NDArray& output); - // auxiliary function which serves for recursion purpose and is used in pad operation - // void recursiveLoopForPad(const int mode, NDArray& input, const NDArray& paddings, NDArray& output, std::vector dimensions, int dim, int inIdx, int outIdx, NDArray& padValue); +void randomShuffle(sd::LaunchContext* context, NDArray& input, NDArray& output, + sd::graph::RandomGenerator& rng, const bool isInplace); - void pad(sd::LaunchContext * context, const int mode, const NDArray& input, const NDArray& paddings, NDArray& output, NDArray const& padValue); +// auxiliary function which serves for recursion purpose and is used in pad +// operation void recursiveLoopForPad(const int mode, NDArray& input, const +// NDArray& paddings, NDArray& output, std::vector dimensions, int dim, int +// inIdx, int outIdx, NDArray& padValue); - void invertPermutation(sd::LaunchContext * context, const NDArray& input, NDArray& output); +void pad(sd::LaunchContext* context, const int mode, const NDArray& input, + const NDArray& paddings, NDArray& output, NDArray const& padValue); - void gatherND(sd::LaunchContext * context, NDArray& input, NDArray& indices, NDArray& output); +void invertPermutation(sd::LaunchContext* context, const NDArray& input, + NDArray& output); - void gather(sd::LaunchContext * context, NDArray* input, const NDArray* indices, NDArray* output, const std::vector& intArgs); +void gatherND(sd::LaunchContext* context, NDArray& input, NDArray& indices, + NDArray& output); - void eye(sd::LaunchContext * context, NDArray& output); +void gather(sd::LaunchContext* context, NDArray* input, const NDArray* indices, + NDArray* output, const std::vector& intArgs); - void scatterUpdate(sd::LaunchContext * context, NDArray& operand, NDArray& updates, const std::vector* intArgs); +void eye(sd::LaunchContext* context, NDArray& output); - void scatterSimple(sd::LaunchContext * context, const int opId, NDArray& input, const NDArray& updates, const NDArray& indices, const std::vector& dimensions); +void scatterUpdate(sd::LaunchContext* context, NDArray& operand, + NDArray& updates, const std::vector* intArgs); - void mergeMaxIndex(sd::LaunchContext * context, const std::vector& inArrs, NDArray& output); +void scatterSimple(sd::LaunchContext* context, const int opId, NDArray& input, + const NDArray& updates, const NDArray& indices, + const std::vector& dimensions); - void mergeMax(sd::LaunchContext * context, const std::vector& inArrs, NDArray& output); - void mergeMaxBp(sd::LaunchContext* context, const std::vector& inArrs, std::vector& outArrs); +void mergeMaxIndex(sd::LaunchContext* context, + const std::vector& inArrs, NDArray& output); - void mergeAvg(sd::LaunchContext * context, const std::vector& inArrs, NDArray& output); - void mergeAvgBp(sd::LaunchContext* context, const NDArray& gradient, std::vector& outArrs); +void mergeMax(sd::LaunchContext* context, + const std::vector& inArrs, NDArray& output); +void mergeMaxBp(sd::LaunchContext* context, + const std::vector& inArrs, + std::vector& outArrs); - void mergeAdd(sd::LaunchContext * context, const std::vector& inArrs, NDArray& output); - void mergeAddBp(sd::LaunchContext* context, const NDArray& gradient, std::vector& outArrs); +void mergeAvg(sd::LaunchContext* context, + const std::vector& inArrs, NDArray& output); +void mergeAvgBp(sd::LaunchContext* context, const NDArray& gradient, + std::vector& outArrs); - void clipByNorm(sd::LaunchContext * context, NDArray& input, NDArray& output, const std::vector& dimensions, const NDArray& clipNorm, const bool isInplace, const bool useAverage); +void mergeAdd(sd::LaunchContext* context, + const std::vector& inArrs, NDArray& output); +void mergeAddBp(sd::LaunchContext* context, const NDArray& gradient, + std::vector& outArrs); - void clipByGlobalNorm(sd::LaunchContext * context, std::vector const& inputs, double clipNorm, sd::memory::Workspace* workspace, std::vector& outputs, bool isInplace); +void clipByNorm(sd::LaunchContext* context, NDArray& input, NDArray& output, + const std::vector& dimensions, const NDArray& clipNorm, + const bool isInplace, const bool useAverage); - void clipByNormBp(sd::LaunchContext * context, const NDArray& input, const NDArray& gradO, NDArray& gradI /*output*/, const std::vector& dimensions, const NDArray& clipNorm, const bool useAverage); +void clipByGlobalNorm(sd::LaunchContext* context, + std::vector const& inputs, double clipNorm, + sd::memory::Workspace* workspace, + std::vector& outputs, bool isInplace); - void clipByAveragedNorm(sd::LaunchContext * context, NDArray& input, NDArray& output, const std::vector& dimensions, const NDArray& clipNorm, const bool isInplace); +void clipByNormBp(sd::LaunchContext* context, const NDArray& input, + const NDArray& gradO, NDArray& gradI /*output*/, + const std::vector& dimensions, const NDArray& clipNorm, const bool useAverage); - void mirrorPad(sd::LaunchContext * context, const NDArray& input, const NDArray& paddings, NDArray& output, const int mode); +void clipByAveragedNorm(sd::LaunchContext* context, NDArray& input, NDArray& output, + const std::vector& dimensions, const NDArray& clipNorm, + const bool isInplace); - void clipByValue(sd::LaunchContext * context, NDArray& input, double leftBound, double rightBound, NDArray& output); +void mirrorPad(sd::LaunchContext* context, const NDArray& input, + const NDArray& paddings, NDArray& output, const int mode); - void mirrorPad(sd::LaunchContext * context, const NDArray& input, const NDArray& paddings, NDArray& output, const int mode); +void clipByValue(sd::LaunchContext* context, NDArray& input, double leftBound, + double rightBound, NDArray& output); - void concat(sd::LaunchContext * context, const std::vector& inArrs, NDArray& output, const int axis); +void mirrorPad(sd::LaunchContext* context, const NDArray& input, + const NDArray& paddings, NDArray& output, const int mode); - void tileBP(sd::LaunchContext * context, const NDArray& gradO /*input*/, NDArray& gradI /*output*/, const std::vector reps); +void concat(sd::LaunchContext* context, + const std::vector& inArrs, NDArray& output, + const int axis); - void split(sd::LaunchContext* context, const NDArray& input, std::vector& outArrs, const int axis); -} -} -} +void tileBP(sd::LaunchContext* context, const NDArray& gradO /*input*/, + NDArray& gradI /*output*/, const std::vector reps); +void split(sd::LaunchContext* context, const NDArray& input, + std::vector& outArrs, const int axis); +} // namespace helpers +} // namespace ops +} // namespace sd -#endif //LIBND4J_TRANSFORMS_H +#endif // LIBND4J_TRANSFORMS_H diff --git a/libnd4j/include/ops/declarable/helpers/triangular_solve.h b/libnd4j/include/ops/declarable/helpers/triangular_solve.h index 94e0198afc40..44fe095c471b 100644 --- a/libnd4j/include/ops/declarable/helpers/triangular_solve.h +++ b/libnd4j/include/ops/declarable/helpers/triangular_solve.h @@ -19,18 +19,21 @@ // #ifndef __TRIANGULAR_SOLVE__H_HELPERS__ #define __TRIANGULAR_SOLVE__H_HELPERS__ -#include #include +#include namespace sd { namespace ops { namespace helpers { - int triangularSolveFunctor(sd::LaunchContext* context, NDArray* leftInput, NDArray* rightInput, bool lower, bool unitsOnDiag, NDArray* output); +int triangularSolveFunctor(sd::LaunchContext* context, NDArray* leftInput, + NDArray* rightInput, bool lower, bool unitsOnDiag, NDArray* output); template - void triangularSolve2D(sd::LaunchContext* context, const NDArray& leftInput, const NDArray& rightInput, const bool lower, const bool unitsOnDiag, NDArray& output); - void adjointMatrix(sd::LaunchContext* context, NDArray const* input, bool const lower, NDArray* output); -} -} -} + void triangularSolve2D(sd::LaunchContext* context, const NDArray& leftInput, const NDArray& rightInput, const bool lower, const bool unitsOnDiag, + NDArray& output); +void adjointMatrix(sd::LaunchContext* context, NDArray const* input, + bool const lower, NDArray* output); +} // namespace helpers +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/helpers/unique.h b/libnd4j/include/ops/declarable/helpers/unique.h index 8898be585cd9..74071451df16 100644 --- a/libnd4j/include/ops/declarable/helpers/unique.h +++ b/libnd4j/include/ops/declarable/helpers/unique.h @@ -20,18 +20,19 @@ #ifndef __UNIQUE_H_HELPERS__ #define __UNIQUE_H_HELPERS__ -#include #include +#include namespace sd { namespace ops { namespace helpers { - Nd4jLong uniqueCount(sd::LaunchContext * context, NDArray* input); +Nd4jLong uniqueCount(sd::LaunchContext* context, NDArray* input); - Nd4jStatus uniqueFunctor(sd::LaunchContext * context, NDArray* input, NDArray* values, NDArray* indices, NDArray* counts); +Nd4jStatus uniqueFunctor(sd::LaunchContext* context, NDArray* input, + NDArray* values, NDArray* indices, NDArray* counts); -} -} -} +} // namespace helpers +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/helpers/updatersHelpers.h b/libnd4j/include/ops/declarable/helpers/updatersHelpers.h index 5bd89b487484..4bff0431152f 100644 --- a/libnd4j/include/ops/declarable/helpers/updatersHelpers.h +++ b/libnd4j/include/ops/declarable/helpers/updatersHelpers.h @@ -21,24 +21,52 @@ #ifndef LIBND4J_UPDATER_RMS_PROM_H #define LIBND4J_UPDATER_RMS_PROM_H -#include #include +#include namespace sd { namespace ops { namespace helpers { - void updaterRmsProp(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initState, NDArray& update, NDArray& stateG, const double dLr, const double dRmsDecay, const double dEpsilon); - void updaterAdaGrad(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initState, NDArray& update, NDArray& stateH, const double dLr, const double dEpsilon); - void updaterNesterovs(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initState, NDArray& update, NDArray& stateV, const double dLr, const double bMomentum); - void updaterAdaMax(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateU, const NDArray& initStateM, NDArray& update, NDArray& stateU, NDArray& stateM, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration); - void updaterAdam(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateU, const NDArray& initStateM, NDArray& update, NDArray& stateU, NDArray& stateM, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration); - void updaterAdaDelta(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateMsg, const NDArray& initStateMsdx, NDArray& update, NDArray& stateMsg, NDArray& stateMsdx, const double dRho, const double dEpsilon); - void updaterNadam(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateV, const NDArray& initStateM, NDArray& update, NDArray& stateV, NDArray& stateM, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration); - void updaterAmsGrad(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateV, const NDArray& initStateM, const NDArray& initStateH, NDArray& update, NDArray& stateV, NDArray& stateM, NDArray& stateH, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration); +void updaterRmsProp(sd::LaunchContext* context, const NDArray& gradient, + const NDArray& initState, NDArray& update, NDArray& stateG, + const double dLr, const double dRmsDecay, + const double dEpsilon); +void updaterAdaGrad(sd::LaunchContext* context, const NDArray& gradient, + const NDArray& initState, NDArray& update, NDArray& stateH, + const double dLr, const double dEpsilon); +void updaterNesterovs(sd::LaunchContext* context, const NDArray& gradient, + const NDArray& initState, NDArray& update, + NDArray& stateV, const double dLr, + const double bMomentum); +void updaterAdaMax(sd::LaunchContext* context, const NDArray& gradient, + const NDArray& initStateU, const NDArray& initStateM, + NDArray& update, NDArray& stateU, NDArray& stateM, + const double dLr, const double dBeta1, const double dBeta2, + const double dEpsilon, const int nIteration); +void updaterAdam(sd::LaunchContext* context, const NDArray& gradient, + const NDArray& initStateU, const NDArray& initStateM, + NDArray& update, NDArray& stateU, NDArray& stateM, + const double dLr, const double dBeta1, const double dBeta2, + const double dEpsilon, const int nIteration); +void updaterAdaDelta(sd::LaunchContext* context, const NDArray& gradient, + const NDArray& initStateMsg, const NDArray& initStateMsdx, + NDArray& update, NDArray& stateMsg, NDArray& stateMsdx, + const double dRho, const double dEpsilon); +void updaterNadam(sd::LaunchContext* context, const NDArray& gradient, + const NDArray& initStateV, const NDArray& initStateM, + NDArray& update, NDArray& stateV, NDArray& stateM, + const double dLr, const double dBeta1, const double dBeta2, + const double dEpsilon, const int nIteration); +void updaterAmsGrad(sd::LaunchContext* context, const NDArray& gradient, + const NDArray& initStateV, const NDArray& initStateM, + const NDArray& initStateH, NDArray& update, NDArray& stateV, + NDArray& stateM, NDArray& stateH, const double dLr, + const double dBeta1, const double dBeta2, + const double dEpsilon, const int nIteration); -} -} -} +} // namespace helpers +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/weights.h b/libnd4j/include/ops/declarable/helpers/weights.h index 66246b641049..88f293601cf4 100644 --- a/libnd4j/include/ops/declarable/helpers/weights.h +++ b/libnd4j/include/ops/declarable/helpers/weights.h @@ -19,16 +19,17 @@ // #ifndef __WEIGHTS_H_HELPERS__ #define __WEIGHTS_H_HELPERS__ -#include #include +#include namespace sd { namespace ops { namespace helpers { - void adjustWeights(sd::LaunchContext * context, NDArray* input, NDArray* weights, NDArray* output, int minLength, int maxLength); +void adjustWeights(sd::LaunchContext* context, NDArray* input, NDArray* weights, + NDArray* output, int minLength, int maxLength); } -} -} +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/helpers/where.h b/libnd4j/include/ops/declarable/helpers/where.h index 2c958846246d..1538c61e8309 100644 --- a/libnd4j/include/ops/declarable/helpers/where.h +++ b/libnd4j/include/ops/declarable/helpers/where.h @@ -18,17 +18,18 @@ // Created by raver119 on 24/09/18. // -#ifndef DEV_TESTS_WHERE_H -#define DEV_TESTS_WHERE_H +#ifndef SD_WHERE_H +#define SD_WHERE_H #include namespace sd { - namespace ops { - namespace helpers { - void _where(sd::LaunchContext * context, NDArray &condition, NDArray& output, memory::Workspace *workspace); - } - } +namespace ops { +namespace helpers { +void _where(sd::LaunchContext *context, NDArray &condition, NDArray &output, + memory::Workspace *workspace); } +} // namespace ops +} // namespace sd -#endif //DEV_TESTS_WHERE_H +#endif // SD_WHERE_H diff --git a/libnd4j/include/ops/declarable/helpers/zeta.h b/libnd4j/include/ops/declarable/helpers/zeta.h index 7aee45f1c1ef..c3bf280bf66f 100644 --- a/libnd4j/include/ops/declarable/helpers/zeta.h +++ b/libnd4j/include/ops/declarable/helpers/zeta.h @@ -22,79 +22,84 @@ #define LIBND4J_ZETA_H #include + #include "array/NDArray.h" namespace sd { namespace ops { namespace helpers { - - // calculate the Hurwitz zeta function for arrays - void zeta(sd::LaunchContext * context, const NDArray& x, const NDArray& q, NDArray& output); - - - - // calculate the Hurwitz zeta function for scalars - // fast implementation, it is based on Euler-Maclaurin summation formula - template - _CUDA_HD T zetaScalar(const T x, const T q) { - - const T machep = 1.11022302462515654042e-16; - - // FIXME: @raver119 - // expansion coeffZetaicients for Euler-Maclaurin summation formula (2k)! / B2k, where B2k are Bernoulli numbers - const T coeffZeta[] = { 12.0,-720.0,30240.0,-1209600.0,47900160.0,-1.8924375803183791606e9,7.47242496e10,-2.950130727918164224e12, 1.1646782814350067249e14, -4.5979787224074726105e15, 1.8152105401943546773e17, -7.1661652561756670113e18}; - - // if (x <= (T)1.) - // throw("zeta function: x must be > 1 !"); - - // if (q <= (T)0.) - // throw("zeta function: q must be > 0 !"); - - T a, b(0.), k, s, t, w; - - s = math::nd4j_pow(q, -x); - a = q; - int i = 0; - - while(i < 9 || a <= (T)9.) { - i += 1; - a += (T)1.0; - b = math::nd4j_pow(a, -x); - s += b; - if(math::nd4j_abs(b / s) < (T)machep) - return s; - } - - w = a; - s += b * (w / (x - (T)1.) - (T)0.5); - a = (T)1.; - k = (T)0.; - - for(i = 0; i < 12; ++i) { - a *= x + k; - b /= w; - t = a * b / coeffZeta[i]; - s += t; - t = math::nd4j_abs(t / s); - - if(t < (T)machep) - return s; - - k += (T)1.f; - a *= x + k; - b /= w; - k += (T)1.f; - } - - return s; - } - - - -} -} +// calculate the Hurwitz zeta function for arrays +void zeta(sd::LaunchContext* context, const NDArray& x, const NDArray& q, + NDArray& output); + +// calculate the Hurwitz zeta function for scalars +// fast implementation, it is based on Euler-Maclaurin summation formula +template +_CUDA_HD T zetaScalar(const T x, const T q) { + const T machep = 1.11022302462515654042e-16; + + // FIXME: @raver119 + // expansion coeffZetaicients for Euler-Maclaurin summation formula (2k)! / + // B2k, where B2k are Bernoulli numbers + const T coeffZeta[] = {12.0, + -720.0, + 30240.0, + -1209600.0, + 47900160.0, + -1.8924375803183791606e9, + 7.47242496e10, + -2.950130727918164224e12, + 1.1646782814350067249e14, + -4.5979787224074726105e15, + 1.8152105401943546773e17, + -7.1661652561756670113e18}; + + // if (x <= (T)1.) + // throw("zeta function: x must be > 1 !"); + + // if (q <= (T)0.) + // throw("zeta function: q must be > 0 !"); + + T a, b(0.), k, s, t, w; + + s = math::nd4j_pow(q, -x); + a = q; + int i = 0; + + while (i < 9 || a <= (T)9.) { + i += 1; + a += (T)1.0; + b = math::nd4j_pow(a, -x); + s += b; + if (math::nd4j_abs(b / s) < (T)machep) return s; + } + + w = a; + s += b * (w / (x - (T)1.) - (T)0.5); + a = (T)1.; + k = (T)0.; + + for (i = 0; i < 12; ++i) { + a *= x + k; + b /= w; + t = a * b / coeffZeta[i]; + s += t; + t = math::nd4j_abs(t / s); + + if (t < (T)machep) return s; + + k += (T)1.f; + a *= x + k; + b /= w; + k += (T)1.f; + } + + return s; } +} // namespace helpers +} // namespace ops +} // namespace sd -#endif //LIBND4J_ZETA_H +#endif // LIBND4J_ZETA_H diff --git a/libnd4j/include/ops/declarable/impl/BooleanOp.cpp b/libnd4j/include/ops/declarable/impl/BooleanOp.cpp index 07960497ab2a..65946ff62c22 100644 --- a/libnd4j/include/ops/declarable/impl/BooleanOp.cpp +++ b/libnd4j/include/ops/declarable/impl/BooleanOp.cpp @@ -19,119 +19,126 @@ // #include "ops/declarable/BooleanOp.h" -#include -#include + #include +#include +#include + namespace sd { - namespace ops { - BooleanOp::BooleanOp(const char *name, int numInputs, bool scalar) : DeclarableOp::DeclarableOp(name, numInputs, scalar) { - // - } - - /** - * Output shape of any BooleanOp is ALWAYS scalar - */ - ShapeList *BooleanOp::calculateOutputShape(ShapeList *inputShape, sd::graph::Context &block) { - return SHAPELIST(ConstantShapeHelper::getInstance().scalarShapeInfo(DataType::BOOL)); - } - - bool BooleanOp::verify(sd::graph::Context &block) { - // check if scalar or not - - // validation? - - Nd4jStatus status = this->validateNonEmptyInput(block); - if (status != ND4J_STATUS_OK) { - nd4j_printf("Inputs should be not empty for BooleanOps",""); - throw std::runtime_error("Bad inputs"); - } - - status = this->validateAndExecute(block); - if (status == ND4J_STATUS_TRUE) - return true; - else if (status == ND4J_STATUS_FALSE) - return false; - else { - nd4j_printf("Got error %i during [%s] evaluation: ", (int) status, this->getOpDescriptor()->getOpName()->c_str()); - throw std::runtime_error("Internal error"); - } - } - - bool BooleanOp::prepareOutputs(Context& ctx) { - - auto variableSpace = ctx.getVariableSpace(); - if (ctx.isFastPath()) - return true; - - for (int e = 0; e < this->getOpDescriptor()->getNumberOfOutputs(); e++) { - std::pair pair(ctx.nodeId(), e); - - if (!variableSpace->hasVariable(pair)) - variableSpace->putVariable(pair, new Variable()); - - auto var = ctx.variable(pair); - - if (!var->hasNDArray()) { - var->setNDArray(NDArrayFactory::create_(false, ctx.launchContext())); - var->markRemovable(true); - } - } - - return true; - } - - Nd4jStatus sd::ops::BooleanOp::execute(Context* block) { - - // basic validation: ensure inputs are set - REQUIRE_OK(this->validateNonEmptyInput(*block)); - - // ensure number of IArgs, TArgs match our expectations - REQUIRE_OK(this->validateArguments(*block)); - - // this method will allocate output NDArrays for this op - this->prepareOutputs(*block); - - auto timeStart = std::chrono::system_clock::now(); - - Nd4jStatus status = this->validateAndExecute(*block); - - auto timeEnd = std::chrono::system_clock::now(); - auto outerTime = std::chrono::duration_cast (timeEnd - timeStart).count(); - block->setInnerTime(outerTime); - - // basically we're should be putting 0.0 as FALSE, and any non-0.0 value will be treated as TRUE - std::pair p(block->nodeId(), 0); - auto var = block->isFastPath() ? block->fastpath_out()[0] : block->variable(p)->getNDArray(); - var->p(Nd4jLong(0), status == ND4J_STATUS_TRUE ? 1.0f : 0.0f); - - // for CPU backend that's nop, but for CUDA-like archs this will update special buffer - var->syncToDevice(); - - if (status == ND4J_STATUS_FALSE || status == ND4J_STATUS_TRUE) - return ND4J_STATUS_OK; - - nd4j_printf("%s: node_%i got unexpected result instead of boolean: [%i]\n", this->getOpName()->c_str(), block->nodeId(), status); - return ND4J_STATUS_KERNEL_FAILURE; - } - - bool BooleanOp::verify(const std::vector &args) { - VariableSpace variableSpace; - - int cnt = -1; - std::vector in; - for (auto v: args) { - auto var = new Variable(v); - var->markRemovable(false); - in.push_back(cnt); - variableSpace.putVariable(cnt--, var); - } +namespace ops { +BooleanOp::BooleanOp(const char *name, int numInputs, bool scalar) + : DeclarableOp::DeclarableOp(name, numInputs, scalar) { + // +} - Context block(1, &variableSpace, false); - block.fillInputs(in); +/** + * Output shape of any BooleanOp is ALWAYS scalar + */ +ShapeList *BooleanOp::calculateOutputShape(ShapeList *inputShape, + sd::graph::Context &block) { + return SHAPELIST( + ConstantShapeHelper::getInstance().scalarShapeInfo(DataType::BOOL)); +} - return this->verify(block); - } +bool BooleanOp::verify(sd::graph::Context &block) { + // check if scalar or not + + // validation? + + Nd4jStatus status = this->validateNonEmptyInput(block); + if (status != ND4J_STATUS_OK) { + nd4j_printf("Inputs should be not empty for BooleanOps", ""); + throw std::runtime_error("Bad inputs"); + } + + status = this->validateAndExecute(block); + if (status == ND4J_STATUS_TRUE) + return true; + else if (status == ND4J_STATUS_FALSE) + return false; + else { + nd4j_printf("Got error %i during [%s] evaluation: ", (int)status, + this->getOpDescriptor()->getOpName()->c_str()); + throw std::runtime_error("Internal error"); + } +} + +bool BooleanOp::prepareOutputs(Context &ctx) { + auto variableSpace = ctx.getVariableSpace(); + if (ctx.isFastPath()) return true; + + for (int e = 0; e < this->getOpDescriptor()->getNumberOfOutputs(); e++) { + std::pair pair(ctx.nodeId(), e); + + if (!variableSpace->hasVariable(pair)) + variableSpace->putVariable(pair, std::make_shared()); + + auto var = ctx.variable(pair); + + if (!var->hasNDArray()) { + var->setNDArray(std::make_shared( + NDArrayFactory::create(false, ctx.launchContext()))); } + } + + return true; } +Nd4jStatus sd::ops::BooleanOp::execute(Context *block) { + // basic validation: ensure inputs are set + REQUIRE_OK(this->validateNonEmptyInput(*block)); + + // ensure number of IArgs, TArgs match our expectations + REQUIRE_OK(this->validateArguments(*block)); + + // this method will allocate output NDArrays for this op + this->prepareOutputs(*block); + + auto timeStart = std::chrono::system_clock::now(); + + Nd4jStatus status = this->validateAndExecute(*block); + + auto timeEnd = std::chrono::system_clock::now(); + auto outerTime = + std::chrono::duration_cast(timeEnd - timeStart) + .count(); + block->setInnerTime(outerTime); + + // basically we're should be putting 0.0 as FALSE, and any non-0.0 value will + // be treated as TRUE + std::pair p(block->nodeId(), 0); + auto var = block->isFastPath() ? block->fastpath_out()[0] + : block->variable(p)->getNDArray(); + var->p(Nd4jLong(0), status == ND4J_STATUS_TRUE ? 1.0f : 0.0f); + + // for CPU backend that's nop, but for CUDA-like archs this will update + // special buffer + var->syncToDevice(); + + if (status == ND4J_STATUS_FALSE || status == ND4J_STATUS_TRUE) + return ND4J_STATUS_OK; + + nd4j_printf("%s: node_%i got unexpected result instead of boolean: [%i]\n", + this->getOpName().c_str(), block->nodeId(), status); + return ND4J_STATUS_KERNEL_FAILURE; +} + +bool BooleanOp::verify(const std::vector &args) { + VariableSpace variableSpace; + + int cnt = -1; + std::vector in; + for (auto v : args) { + auto var = std::make_shared(*v, "", cnt); + in.emplace_back(cnt); + variableSpace.putVariable(cnt--, var); + } + + Context block(1, &variableSpace, false); + block.fillInputs(in); + + return this->verify(block); +} +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/impl/BroadcastableBoolOp.cpp b/libnd4j/include/ops/declarable/impl/BroadcastableBoolOp.cpp index 634236d35393..59072e653cc3 100644 --- a/libnd4j/include/ops/declarable/impl/BroadcastableBoolOp.cpp +++ b/libnd4j/include/ops/declarable/impl/BroadcastableBoolOp.cpp @@ -18,55 +18,69 @@ // Created by raver on 6/6/2018. // +#include +#include #include #include -#include -#include namespace sd { - namespace ops { - BroadcastableBoolOp::BroadcastableBoolOp(const char *name, int numTArgs, int numIArgs) : DeclarableCustomOp::DeclarableCustomOp(2, 1, name, false, numTArgs, numIArgs) { - // - } +namespace ops { +BroadcastableBoolOp::BroadcastableBoolOp(const char *name, int numTArgs, + int numIArgs) + : DeclarableCustomOp::DeclarableCustomOp(2, 1, name, false, numTArgs, + numIArgs) { + // +} - ShapeList *BroadcastableBoolOp::calculateOutputShape(ShapeList *inputShape, sd::graph::Context &block) { - auto shapeList = SHAPELIST(); - auto x = inputShape->at(0); - auto y = inputShape->at(1); - sd::DataType dtype = sd::DataType::BOOL; +ShapeList *BroadcastableBoolOp::calculateOutputShape( + ShapeList *inputShape, sd::graph::Context &block) { + auto shapeList = SHAPELIST(); + auto x = inputShape->at(0); + auto y = inputShape->at(1); + sd::DataType dtype = sd::DataType::BOOL; - if(shape::isEmpty(x) || shape::isEmpty(y)) { - // this is edge case, [3, 4] + [] = [] - if ((shape::isEmpty(x) && shape::rank(x) == 0) || (shape::isEmpty(y) && shape::rank(y) == 0)) { - shapeList->push_back(ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor::emptyDescriptor(dtype))); - return shapeList; - } - - const Nd4jLong *newshape = nullptr; - ShapeUtils::evalBroadcastShapeInfo(x, y, true, newshape, block.workspace()); - shapeList->push_back(ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(newshape, dtype))); - } else if (shape::isScalar(x) && shape::isScalar(y)) { - if (shape::rank(x) >= shape::rank(y)) { - shapeList->push_back(ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(x, dtype))); - } else { - shapeList->push_back(ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(y, dtype))); - } - } else if (shape::equalsSoft(x, y)) { - shapeList->push_back(ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(x, dtype))); - } else if (shape::isScalar(x) && !shape::isScalar(y)) { - shapeList->push_back(ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(y, dtype))); - } else if (!shape::isScalar(x) && shape::isScalar(y)) { - shapeList->push_back(ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(x, dtype))); - } else if (ShapeUtils::areShapesBroadcastable(x, y)) { - const Nd4jLong *newshape = nullptr; - ShapeUtils::evalBroadcastShapeInfo(x, y, true, newshape, block.workspace()); - shapeList->push_back(ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(newshape, dtype))); - } else { - // in this case we'll throw exception later - shapeList->push_back(ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(x, dtype))); - } + if (shape::isEmpty(x) || shape::isEmpty(y)) { + // this is edge case, [3, 4] + [] = [] + if ((shape::isEmpty(x) && shape::rank(x) == 0) || + (shape::isEmpty(y) && shape::rank(y) == 0)) { + shapeList->push_back(ConstantShapeHelper::getInstance().createShapeInfo( + ShapeDescriptor::emptyDescriptor(dtype))); + return shapeList; + } - return shapeList; - } + const Nd4jLong *newshape = nullptr; + ShapeUtils::evalBroadcastShapeInfo(x, y, true, newshape, block.workspace()); + shapeList->push_back(ConstantShapeHelper::getInstance().createShapeInfo( + ShapeDescriptor(newshape, dtype))); + } else if (shape::isScalar(x) && shape::isScalar(y)) { + if (shape::rank(x) >= shape::rank(y)) { + shapeList->push_back(ConstantShapeHelper::getInstance().createShapeInfo( + ShapeDescriptor(x, dtype))); + } else { + shapeList->push_back(ConstantShapeHelper::getInstance().createShapeInfo( + ShapeDescriptor(y, dtype))); } -} \ No newline at end of file + } else if (shape::equalsSoft(x, y)) { + shapeList->push_back(ConstantShapeHelper::getInstance().createShapeInfo( + ShapeDescriptor(x, dtype))); + } else if (shape::isScalar(x) && !shape::isScalar(y)) { + shapeList->push_back(ConstantShapeHelper::getInstance().createShapeInfo( + ShapeDescriptor(y, dtype))); + } else if (!shape::isScalar(x) && shape::isScalar(y)) { + shapeList->push_back(ConstantShapeHelper::getInstance().createShapeInfo( + ShapeDescriptor(x, dtype))); + } else if (ShapeUtils::areShapesBroadcastable(x, y)) { + const Nd4jLong *newshape = nullptr; + ShapeUtils::evalBroadcastShapeInfo(x, y, true, newshape, block.workspace()); + shapeList->push_back(ConstantShapeHelper::getInstance().createShapeInfo( + ShapeDescriptor(newshape, dtype))); + } else { + // in this case we'll throw exception later + shapeList->push_back(ConstantShapeHelper::getInstance().createShapeInfo( + ShapeDescriptor(x, dtype))); + } + + return shapeList; +} +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/impl/BroadcastableOp.cpp b/libnd4j/include/ops/declarable/impl/BroadcastableOp.cpp index 4611d49cb6f5..fc5ab1c3ebee 100644 --- a/libnd4j/include/ops/declarable/impl/BroadcastableOp.cpp +++ b/libnd4j/include/ops/declarable/impl/BroadcastableOp.cpp @@ -18,69 +18,81 @@ // Created by raver on 6/6/2018. // +#include +#include #include #include -#include -#include namespace sd { - namespace ops { - BroadcastableOp::BroadcastableOp(const char *name, int numTArgs, int numIArgs) : DeclarableCustomOp::DeclarableCustomOp(2, 1, name, false, numTArgs, numIArgs) { - // - } - - ShapeList *BroadcastableOp::calculateOutputShape(ShapeList *inputShape, sd::graph::Context &block) { - auto shapeList = SHAPELIST(); - auto x = inputShape->at(0); - auto y = inputShape->at(1); - auto outputs = _descriptor->getOutputTypesForOutput(0); - sd::DataType dtype = block.dataType(0); - if (block.dataType(0) != sd::DataType::BOOL && !(outputs.size() == 1 && outputs[0] == sd::DataType::BOOL)) { - if (Environment::getInstance().isExperimentalBuild()) { - if (shape::length(y) > shape::length(x)) { - dtype = DataTypeUtils::pickPairwiseResultType(y, x); - } else { - dtype = DataTypeUtils::pickPairwiseResultType(x, y); - } - } else { - dtype = ArrayOptions::dataType(x); - } - } else - dtype = sd::DataType::BOOL; - - if(shape::isEmpty(x) || shape::isEmpty(y)) { - // this is edge case, [3, 4] + [] = [] - if ((shape::isEmpty(x) && shape::rank(x) == 0) || (shape::isEmpty(y) && shape::rank(y) == 0)) { - shapeList->push_back(ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor::emptyDescriptor(dtype))); - return shapeList; - } +namespace ops { +BroadcastableOp::BroadcastableOp(const char *name, int numTArgs, int numIArgs) + : DeclarableCustomOp::DeclarableCustomOp(2, 1, name, false, numTArgs, + numIArgs) { + // +} +ShapeList *BroadcastableOp::calculateOutputShape(ShapeList *inputShape, + sd::graph::Context &block) { + auto shapeList = SHAPELIST(); + auto x = inputShape->at(0); + auto y = inputShape->at(1); + auto outputs = _descriptor->getOutputTypesForOutput(0); + sd::DataType dtype; + if (!(outputs.size() == 1 && outputs[0] == sd::DataType::BOOL)) { + if (Environment::getInstance().isExperimentalBuild()) { + if (shape::length(y) > shape::length(x)) { + dtype = DataTypeUtils::pickPairwiseResultType(y, x); + } else { + dtype = DataTypeUtils::pickPairwiseResultType(x, y); + } + } else { + dtype = ArrayOptions::dataType(x); + } + } else + dtype = sd::DataType::BOOL; - const Nd4jLong *newshape = nullptr; - ShapeUtils::evalBroadcastShapeInfo(x, y, true, newshape, block.workspace()); - shapeList->push_back(ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(newshape, dtype))); - } else if (shape::isScalar(x) && shape::isScalar(y)) { - if (shape::rank(x) >= shape::rank(y)) { - shapeList->push_back(ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(x, dtype))); - } else { - shapeList->push_back(ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(y, dtype))); - } - } else if (shape::equalsSoft(x, y)) { - shapeList->push_back(ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(x, dtype))); - } else if (shape::isScalar(x) && !shape::isScalar(y)) { - shapeList->push_back(ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(y, dtype))); - } else if (!shape::isScalar(x) && shape::isScalar(y)) { - shapeList->push_back(ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(x, dtype))); - } else if (ShapeUtils::areShapesBroadcastable(x, y)) { - const Nd4jLong *newshape = nullptr; - ShapeUtils::evalBroadcastShapeInfo(x, y, true, newshape, block.workspace()); - shapeList->push_back(ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(newshape, dtype))); - } else { - // in this case we'll throw exception later - shapeList->push_back(ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(x, dtype))); - } + if (shape::isEmpty(x) || shape::isEmpty(y)) { + // this is edge case, [3, 4] + [] = [] + if ((shape::isEmpty(x) && shape::rank(x) == 0) || + (shape::isEmpty(y) && shape::rank(y) == 0)) { + shapeList->push_back(ConstantShapeHelper::getInstance().createShapeInfo( + ShapeDescriptor::emptyDescriptor(dtype))); + return shapeList; + } - return shapeList; - } + const Nd4jLong *newshape = nullptr; + ShapeUtils::evalBroadcastShapeInfo(x, y, true, newshape, block.workspace()); + shapeList->push_back(ConstantShapeHelper::getInstance().createShapeInfo( + ShapeDescriptor(newshape, dtype))); + } else if (shape::isScalar(x) && shape::isScalar(y)) { + if (shape::rank(x) >= shape::rank(y)) { + shapeList->push_back(ConstantShapeHelper::getInstance().createShapeInfo( + ShapeDescriptor(x, dtype))); + } else { + shapeList->push_back(ConstantShapeHelper::getInstance().createShapeInfo( + ShapeDescriptor(y, dtype))); } -} \ No newline at end of file + } else if (shape::equalsSoft(x, y)) { + shapeList->push_back(ConstantShapeHelper::getInstance().createShapeInfo( + ShapeDescriptor(x, dtype))); + } else if (shape::isScalar(x) && !shape::isScalar(y)) { + shapeList->push_back(ConstantShapeHelper::getInstance().createShapeInfo( + ShapeDescriptor(y, dtype))); + } else if (!shape::isScalar(x) && shape::isScalar(y)) { + shapeList->push_back(ConstantShapeHelper::getInstance().createShapeInfo( + ShapeDescriptor(x, dtype))); + } else if (ShapeUtils::areShapesBroadcastable(x, y)) { + const Nd4jLong *newshape = nullptr; + ShapeUtils::evalBroadcastShapeInfo(x, y, true, newshape, block.workspace()); + shapeList->push_back(ConstantShapeHelper::getInstance().createShapeInfo( + ShapeDescriptor(newshape, dtype))); + } else { + // in this case we'll throw exception later + shapeList->push_back(ConstantShapeHelper::getInstance().createShapeInfo( + ShapeDescriptor(x, dtype))); + } + + return shapeList; +} +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/impl/DeclarableCustomOp.cpp b/libnd4j/include/ops/declarable/impl/DeclarableCustomOp.cpp index d6227af0cd91..4169b055eca7 100644 --- a/libnd4j/include/ops/declarable/impl/DeclarableCustomOp.cpp +++ b/libnd4j/include/ops/declarable/impl/DeclarableCustomOp.cpp @@ -22,9 +22,13 @@ #include namespace sd { - namespace ops { - DeclarableCustomOp::DeclarableCustomOp(int numInputs, int numOutputs, const char *opName, bool allowsInplace, int tArgs, int iArgs) : sd::ops::DeclarableOp(numInputs, numOutputs, opName, allowsInplace, tArgs, iArgs) { - // - } - } -} \ No newline at end of file +namespace ops { +DeclarableCustomOp::DeclarableCustomOp(int numInputs, int numOutputs, + const char *opName, bool allowsInplace, + int tArgs, int iArgs) + : sd::ops::DeclarableOp(numInputs, numOutputs, opName, allowsInplace, tArgs, + iArgs) { + // +} +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/impl/DeclarableListOp.cpp b/libnd4j/include/ops/declarable/impl/DeclarableListOp.cpp index d70355038c50..364b6a3fd4aa 100644 --- a/libnd4j/include/ops/declarable/impl/DeclarableListOp.cpp +++ b/libnd4j/include/ops/declarable/impl/DeclarableListOp.cpp @@ -18,142 +18,137 @@ // @author raver119@gmail.com // -#include -#include #include #include #include +#include +#include namespace sd { - namespace ops { - DeclarableListOp::DeclarableListOp(int numInputs, int numOutputs, const char* opName, int tArgs, int iArgs) : DeclarableOp::DeclarableOp(numInputs, numOutputs, opName, false, tArgs, iArgs) { - // This kind of operations work with sets: NDArrayList - this->getOpDescriptor()->setInputType(InputType_NUMERIC_SET); - } +namespace ops { +DeclarableListOp::DeclarableListOp(int numInputs, int numOutputs, + const char* opName, int tArgs, int iArgs) + : DeclarableOp::DeclarableOp(numInputs, numOutputs, opName, false, tArgs, + iArgs) { + // This kind of operations work with sets: NDArrayList + this->getOpDescriptor()->setInputType(InputType_NUMERIC_SET); +} /* template void DeclarableListOp::execute(Block& block) { // } */ - /** - * This method just outputs scalar buffer - * - * @tparam T - * @param inputShape - * @param block - * @return - */ - ShapeList* DeclarableListOp::calculateOutputShape(ShapeList* inputShape, sd::graph::Context& block) { - // TODO: ensure this method isn't ever called - - auto newShape = ConstantShapeHelper::getInstance().createShapeInfo(block.dataType(), 'c', {1, 1}); - return SHAPELIST(newShape); - } - - sd::NDArray* sd::ops::DeclarableListOp::getZ(Context& block, int inputId) { - //nd4j_printf("wow\n",""); - return nullptr; - } - - void DeclarableListOp::setupResult(NDArray* array, Context& block) { - block.pushNDArrayToVariableSpace(block.getNodeId(), 0, array); - } - - void DeclarableListOp::setupResultList(NDArrayList* arrayList, Context& block) { - block.pushNDArrayListToVariableSpace(block.getNodeId(), 0, arrayList); - } - - ResultSet DeclarableListOp::execute(NDArrayList* list, std::initializer_list inputs, std::initializer_list tArgs, std::initializer_list iArgs) { - std::vector ins(inputs); - std::vector tas(tArgs); - std::vector ias(iArgs); - return this->execute(list, ins, tas, ias); - } - - Nd4jStatus DeclarableListOp::execute(Context* block) { - if (block == nullptr) - throw std::invalid_argument("Block is NULL"); - - nd4j_debug("Executing list op: [%s]\n", this->getOpName()->c_str()); - - // ensure number of IArgs, TArgs match our expectations - REQUIRE_OK(this->validateArguments(*block)); - - // we shouldn't call for this in ListOp - //this->prepareOutputs(*block); - - auto timeStart = std::chrono::system_clock::now(); - - Nd4jStatus status = this->validateAndExecute(*block); - - auto timeEnd = std::chrono::system_clock::now(); - auto outerTime = std::chrono::duration_cast (timeEnd - timeStart).count(); - block->setInnerTime(outerTime); - - return status; - } - - ResultSet DeclarableListOp::execute(NDArrayList* list, std::vector& inputs, std::vector& tArgs, std::vector& iArgs) { - VariableSpace varSpace; - int nodeId = 119; - - // should be never used in practice, since in-graph NDArrayList should have id set - int cnt = -1; - std::vector in; - if (list != nullptr) { - if (list->id().first == 0) - list->id().first = -1; - - auto listVar = new Variable(nullptr, nullptr, -119, 0); - listVar->setNDArrayList(list); - varSpace.putVariable(-1, listVar); - in.push_back(-1); - cnt--; - } - - - for (auto v: inputs) { - auto var = new Variable(v); - var->markRemovable(false); - in.push_back(cnt); - varSpace.putVariable(cnt--, var); - } - - Context block(1, &varSpace, false); - block.fillInputs(in); - - for (int e = 0; e < tArgs.size(); e++) - block.getTArguments()->emplace_back(tArgs.at(e)); - - - for (int e = 0; e < iArgs.size(); e++) - block.getIArguments()->emplace_back(iArgs.at(e)); - - - Nd4jStatus result = this->validateAndExecute(block); - ResultSet res; - res.setStatus(result); - - for (int e = 0; e < DataTypeUtils::max(); e++) { - std::pair pair(1, e); - if (varSpace.hasVariable(pair)) { - auto var = varSpace.getVariable(pair); - if (var->hasNDArray()) { - auto arr = var->getNDArray(); - if (arr->isAttached()) { - auto d = arr->detach(); - res.push_back(d); - } else { - var->markRemovable(false); - res.push_back(arr); - } - } - } else - break; - } - - return res; +/** + * This method just outputs scalar buffer + * + * @tparam T + * @param inputShape + * @param block + * @return + */ +ShapeList* DeclarableListOp::calculateOutputShape(ShapeList* inputShape, + sd::graph::Context& block) { + // TODO: ensure this method isn't ever called + + auto newShape = ConstantShapeHelper::getInstance().createShapeInfo( + DataType::FLOAT32, 'c', {1, 1}); + return SHAPELIST(newShape); +} + +sd::NDArray* sd::ops::DeclarableListOp::getZ(Context& block, int inputId) { + // nd4j_printf("wow\n",""); + return nullptr; +} + +void DeclarableListOp::setupResult(const NDArray& array, Context& block) { + block.pushNDArrayToVariableSpace(block.getNodeId(), 0, array); +} + +void DeclarableListOp::setupResultList(const NDArrayList& arrayList, + Context& block) { + block.pushNDArrayListToVariableSpace(block.getNodeId(), 0, arrayList); +} + +Nd4jStatus DeclarableListOp::execute(Context* block) { + if (block == nullptr) throw std::invalid_argument("Block is NULL"); + + nd4j_debug("Executing list op: [%s]\n", this->getOpName().c_str()); + + // ensure number of IArgs, TArgs match our expectations + REQUIRE_OK(this->validateArguments(*block)); + + // we shouldn't call for this in ListOp + // this->prepareOutputs(*block); + + auto timeStart = std::chrono::system_clock::now(); + + Nd4jStatus status = this->validateAndExecute(*block); + + auto timeEnd = std::chrono::system_clock::now(); + auto outerTime = + std::chrono::duration_cast(timeEnd - timeStart) + .count(); + block->setInnerTime(outerTime); + + return status; +} + +ResultSet DeclarableListOp::execute(const NDArrayList& list, + const std::vector& inputs, + const std::vector& tArgs, + const std::vector& iArgs) { + VariableSpace varSpace; + int nodeId = 119; + + // should be never used in practice, since in-graph NDArrayList should have id + // set + int cnt = -1; + std::vector in; + + // first input must be our NDArrayList, except create_list op. it creates list + // itself. + if (getOpName() != "create_list") { + auto listVar = std::make_shared(list, "", cnt); + varSpace.putVariable(cnt, listVar); + in.push_back(cnt--); + } + + for (auto v : inputs) { + auto var = std::make_shared(*v, "", cnt); + in.push_back(cnt); + varSpace.putVariable(cnt--, var); + } + + Context block(1, &varSpace, false); + block.fillInputs(in); + + for (int e = 0; e < tArgs.size(); e++) block.appendT(tArgs.at(e)); + + for (int e = 0; e < iArgs.size(); e++) block.appendI(iArgs.at(e)); + + Nd4jStatus result = this->validateAndExecute(block); + ResultSet res; + res.setStatus(result); + + for (int e = 0; e < DataTypeUtils::max(); e++) { + std::pair pair(1, e); + if (varSpace.hasVariable(pair)) { + auto var = varSpace.getVariable(pair); + if (var->hasNDArray()) { + auto arr = var->getNDArray(); + if (arr->isAttached()) { + res.push_back(arr->detach()); + } else { + var->markRemovable(false); + res.push_back(*arr); } - } -} \ No newline at end of file + } + } else + break; + } + + return res; +} +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp b/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp index cd8d0bdd8657..a7e452d0ce0a 100644 --- a/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp +++ b/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp @@ -18,1123 +18,1216 @@ // @author raver119@gmail.com // -#include -#include -#include #include +#include #include +#include #include -#include -#include +#include #include +#include +#include + #include namespace sd { - namespace ops { - Nd4jStatus conditionHelper(const char *file, int line, int condition, int argNumber, const char *format, ...) { - if (!condition) { - va_list args; - - printf("Error at [%s:%i:%i]:\n", file, line, argNumber); - va_start(args, format); - vprintf(format, args); - va_end(args); - printf("\n"); - fflush(stdout); - - return ND4J_STATUS_BAD_PARAMS; - } - return ND4J_STATUS_OK; - } - - DeclarableOp::DeclarableOp() { - // no-op - } - - DeclarableOp::DeclarableOp(const char *name, bool isLogical) { - _descriptor = new OpDescriptor(name, isLogical); - _name = name; - } - - DeclarableOp::DeclarableOp(const char *name, int numInputs, bool scalar) { - _descriptor = new OpDescriptor(numInputs, name, scalar); - _name = name; - } - - DeclarableOp::DeclarableOp(int numInputs, int numOutputs, const char *opName, bool allowsInplace) { - _descriptor = new OpDescriptor(numInputs, numOutputs, opName, allowsInplace); - _name = opName; - } - - DeclarableOp::DeclarableOp(int numInputs, int numOutputs, const char *opName, bool allowsInplace, bool divergent) { - _descriptor = new OpDescriptor(numInputs, numOutputs, opName, allowsInplace, divergent); - _name = opName; - } - - DeclarableOp::DeclarableOp(int numInputs, int numOutputs, const char *opName, bool allowsInplace, int tArgs, int iArgs) { - _descriptor = new OpDescriptor(numInputs, numOutputs, opName, allowsInplace, tArgs, iArgs); - _name = opName; - } - - DeclarableOp::~DeclarableOp() { - if (_descriptor != nullptr) - delete _descriptor; +namespace ops { +Nd4jStatus conditionHelper(const char *file, int line, int condition, + int argNumber, const char *format, ...) { + if (!condition) { + va_list args; + + printf("Error at [%s:%i:%i]:\n", file, line, argNumber); + va_start(args, format); + vprintf(format, args); + va_end(args); + printf("\n"); + fflush(stdout); + + return ND4J_STATUS_BAD_PARAMS; + } + return ND4J_STATUS_OK; +} + +DeclarableOp::DeclarableOp() { + // no-op +} + +DeclarableOp::DeclarableOp(const char *name, bool isLogical) { + _descriptor = new OpDescriptor(name, isLogical); + _name = name; +} + +DeclarableOp::DeclarableOp(const char *name, int numInputs, bool scalar) { + _descriptor = new OpDescriptor(numInputs, name, scalar); + _name = name; +} + +DeclarableOp::DeclarableOp(int numInputs, int numOutputs, const char *opName, + bool allowsInplace) { + _descriptor = new OpDescriptor(numInputs, numOutputs, opName, allowsInplace); + _name = opName; +} + +DeclarableOp::DeclarableOp(int numInputs, int numOutputs, const char *opName, + bool allowsInplace, bool divergent) { + _descriptor = + new OpDescriptor(numInputs, numOutputs, opName, allowsInplace, divergent); + _name = opName; +} + +DeclarableOp::DeclarableOp(int numInputs, int numOutputs, const char *opName, + bool allowsInplace, int tArgs, int iArgs) { + _descriptor = new OpDescriptor(numInputs, numOutputs, opName, allowsInplace, + tArgs, iArgs); + _name = opName; +} + +DeclarableOp::~DeclarableOp() { + if (_descriptor != nullptr) delete _descriptor; + + if (_scalar != nullptr) delete _scalar; +} + +OpDescriptor *DeclarableOp::getOpDescriptor() { return _descriptor; } + +const std::string &DeclarableOp::getOpName() const { + return *_descriptor->getOpName(); +} + +Nd4jLong DeclarableOp::getOpHash() const { return _descriptor->getHash(); } + +sd::NDArray *sd::ops::DeclarableOp::getNullifiedZ(Context &block, int inputId) { + auto result = getZ(block, inputId); + if (result != nullptr && !block.isInplace()) result->nullify(); + + return result; +} + +sd::NDArray *sd::ops::DeclarableOp::getZ(Context &ctx, int inputId) { + NDArray *z = nullptr; + + if (ctx.isFastPath()) { + if (ctx.fastpath_out().size() <= inputId) { + if (ctx.isInplace()) { + z = ctx.fastpath_in()[inputId].get(); + } else + throw std::runtime_error("fastpath_out: unresolved output array"); + } else { + z = ctx.fastpath_out()[inputId].get(); + } + } else { + std::pair pair(ctx.nodeId(), inputId); + + if (ctx.isInplace()) { + auto vz = ctx.variable(inputId)->getNDArray(); + z = vz.get(); + + // hypothetically it's possible to have no variable. chances are low, but + // who knows. let's just create it for now + if (!ctx.getVariableSpace()->hasVariable(pair)) { + auto var = std::make_shared(); + ctx.getVariableSpace()->putVariable(pair, var); + } + + // now we're saving input array as output array + auto var = ctx.getVariableSpace()->getVariable(pair); + var->markRemovable(false); + var->setNDArray(vz); + } else if (!ctx.isInplace()) { + auto var = ctx.variable(pair); + if (var->getNDArray() != nullptr && var->getNDArray()->nonNull()) { + z = var->getNDArray().get(); + } else { + nd4j_printf("Can't get Z variable for node_%i!\n", ctx.nodeId()); + } + } else { + nd4j_printf("BOOM!\n", ""); + throw std::runtime_error("Boom!"); + } + } - if (_scalar != nullptr) - delete _scalar; - } + if (z != nullptr && z->undefined()) return nullptr; - OpDescriptor* DeclarableOp::getOpDescriptor() { - return _descriptor; - } + return z; +} - std::string *DeclarableOp::getOpName() { - return _descriptor->getOpName(); - } - - Nd4jLong DeclarableOp::getOpHash() { - return _descriptor->getHash(); - } +int DeclarableOp::prepareOutputs(Context &ctx) { + auto workspace = ctx.workspace(); + GraphProfile *prof = nullptr; + NodeProfile *node = nullptr; + std::chrono::time_point inputEnd, inputStart, + shapeStart, shapeEnd, arrayStart, arrayEnd; + bool canUseFastPath = true; - sd::NDArray* sd::ops::DeclarableOp::getNullifiedZ(Context& block, int inputId) { - auto result = getZ(block, inputId); - if (result != nullptr && !block.isInplace()) - result->nullify(); + auto fp = ctx.isFastPath(); - return result; + if (Environment::getInstance().isProfiling()) { + /* + if (ctx.getVariableSpace() != nullptr && ctx.getVariableSpace()->flowPath() + != nullptr) { prof = ctx.getVariableSpace()->flowPath()->profile(); node = + prof->nodeById(ctx.nodeId()); + } + */ + throw std::runtime_error( + "DeclarableOp::prepareOutputs - Not implemented yet"); + } + + if (ctx.isInplace()) { + if (Environment::getInstance().isProfiling() && node != nullptr) { + if (fp) { + // + } else { + for (auto p : ctx.inputs()) { + auto var = ctx.variable(p); + if (var->variableType() == VariableType::NDARRAY) { + auto array = var->getNDArray().get(); + + node->addInputShape(array->shapeInfo()); + node->addOutputShape(array->shapeInfo()); + } } + } + } - - sd::NDArray* sd::ops::DeclarableOp::getZ(Context& ctx, int inputId) { - NDArray* z = nullptr; - - if (ctx.isFastPath()) { - if (ctx.fastpath_out().size() <= inputId) { - if (ctx.isInplace()) { - z = ctx.fastpath_in()[inputId]; - } else - throw std::runtime_error("fastpath_out: unresolved output array"); - } else { - z = ctx.fastpath_out()[inputId]; - } + // if that's not fp, we can still propagate inputs and outputs + if (!fp) { + int cnt = 0; + auto id = ctx.nodeId(); + auto vs = ctx.getVariableSpace(); + for (auto p : ctx.inputs()) { + auto var = ctx.variable(p); + if (var->variableType() == VariableType::NDARRAY) { + auto array = var->getNDArray(); + ctx.setInputArray(cnt, array); + ctx.setOutputArray(cnt, array); + + // in case of this override we might need to update outputs in the + // Graph VariableSpace as well + if (vs != nullptr) { + if (vs->hasVariable(id, cnt)) { + auto v2 = vs->getVariable(id, cnt); + if (!v2->hasNDArray()) { + v2->setNDArray(array); + v2->markRemovable(false); + } } else { - std::pair pair(ctx.nodeId(), inputId); - - if (ctx.isInplace()) { - z = ctx.variable(inputId)->getNDArray(); - - // hypothetically it's possible to have no variable. chances are low, but who knows. let's just create it for now - if (!ctx.getVariableSpace()->hasVariable(pair)) { - auto var = new Variable(); - ctx.getVariableSpace()->putVariable(pair, var); - } - - // now we're saving input array as output array - auto var = ctx.getVariableSpace()->getVariable(pair); - var->markRemovable(false); - var->setNDArray(z); - } else if (!ctx.isInplace()) { - auto var = ctx.variable(pair); - if (var->getNDArray() != nullptr && var->getNDArray()->nonNull()) { - z = var->getNDArray(); - } else { - nd4j_printf("Can't get Z variable for node_%i!\n", ctx.nodeId()); - } - } else { - nd4j_printf("BOOM!\n", ""); - throw std::runtime_error("Boom!"); - } + auto v2 = vs->putVariable(id, cnt, array); + v2->markRemovable(false); } + } - return z; + cnt++; + } else { + canUseFastPath = false; } + } + } - int sd::ops::DeclarableOp::prepareOutputs(Context &ctx) { - auto workspace = ctx.getWorkspace(); - GraphProfile *prof = nullptr; - NodeProfile *node = nullptr; - std::chrono::time_point inputEnd, inputStart, shapeStart, shapeEnd, arrayStart, arrayEnd; - bool canUseFastPath = true; - - auto fp = ctx.isFastPath(); - - if (Environment::getInstance().isProfiling()) { - if (ctx.getVariableSpace() != nullptr && ctx.getVariableSpace()->flowPath() != nullptr) { - prof = ctx.getVariableSpace()->flowPath()->profile(); - node = prof->nodeById(ctx.nodeId()); - } - } - - if (ctx.isInplace()) { - if (Environment::getInstance().isProfiling() && node != nullptr) { - if (fp) { - // - } else { - for (auto p: *ctx.inputs()) { - auto var = ctx.variable(p); - if (var->variableType() == VariableType::NDARRAY) { - NDArray *array = var->getNDArray(); - - node->addInputShape(array->shapeInfo()); - node->addOutputShape(array->shapeInfo()); - } - } - } - } - - // if that's not fp, we can still propagate inputs and outputs - if (!fp) { - int cnt = 0; - auto id = ctx.nodeId(); - auto vs = ctx.getVariableSpace(); - for (auto p: *ctx.inputs()) { - auto var = ctx.variable(p); - if (var->variableType() == VariableType::NDARRAY) { - NDArray *array = var->getNDArray(); - ctx.setInputArray(cnt, array); - ctx.setOutputArray(cnt, array); - - - // in case of this override we might need to update outputs in the Graph VariableSpace as well - if (vs != nullptr) { - if (vs->hasVariable(id, cnt)) { - auto v2 = vs->getVariable(id, cnt); - if (!v2->hasNDArray()) { - v2->setNDArray(array); - v2->markRemovable(false); - - } - } else { - auto v2 = vs->putVariable(id, cnt, array); - v2->markRemovable(false); - } - } - - cnt++; - } else { - canUseFastPath = false; - } - } - } - - if (!canUseFastPath) - ctx.forbidFastPath(true); - - // do nothing, getZ result will do the trick - return static_cast(ctx.width()); - } else { - // if op is not inplace - we should pre-allocate arrays - ShapeList inSha; - int results = 0; - - if (Environment::getInstance().isProfiling() && node != nullptr) - inputStart = std::chrono::system_clock::now(); - - int cntIn = 0; - // we build list of input shapes - if (fp) { - for (const auto p:ctx.fastpath_in()) { - inSha.push_back(p == nullptr ? nullptr : p->shapeInfo()); - } - } else { - int arrCnt = 0; - for (auto p: *ctx.inputs()) { - auto var = ctx.variable(p); - if (var->variableType() == VariableType::NDARRAY) { - NDArray *array = var->getNDArray(); - if (array == nullptr) - throw unresolved_input_exception::build("Variable wasn't resolved prior shape calculation", p); - - inSha.push_back(array->shapeInfo()); - - // we're also filling ctx with arrays - if (canUseFastPath) - ctx.setInputArray(arrCnt++, array); - } else { - canUseFastPath = false; - } - cntIn++; - } - } - - // if we override shape function, we'll return size of fastPath - if (fp && ctx.shapeFunctionOverride()) { - return (int) ctx.fastpath_out().size(); - } - - // optionally saving input time - if (Environment::getInstance().isProfiling() && node != nullptr) { - inputEnd = std::chrono::system_clock::now(); - auto inputTime = std::chrono::duration_cast(inputEnd - inputStart).count(); - node->setInputTime(inputTime); - - // saving output shapes in profile - for (int e = 0; e < inSha.size(); e++) - node->addInputShape(inSha.at(e)); - - shapeStart = std::chrono::system_clock::now(); - } - - auto outSha = this->calculateOutputShape(&inSha, ctx); - results = outSha->size(); - - // optionally saving shapeTime - if (Environment::getInstance().isProfiling() && node != nullptr) { - shapeEnd = std::chrono::system_clock::now(); - auto prepTime = std::chrono::duration_cast(shapeEnd - shapeStart).count(); - node->setShapeFunctionTime(prepTime); - - // saving output shapes in profile - for (int e = 0; e < outSha->size(); e++) - node->addOutputShape(outSha->at(e)); - - arrayStart = std::chrono::system_clock::now(); - } - - int cnt = 0; - - for (auto out: *outSha->asVector()) { - if (!fp) { - // we need to check, if Z is really needed - std::pair pair(ctx.nodeId(), cnt++); - - if (!ctx.isValueAvailable(pair.second)) { - if (Environment::getInstance().isDebugAndVerbose()) - shape::printShapeInfoLinear("Going to create variable with shape", out); - - // we're creating non-initialized array here - auto outArr = new NDArray(out, true, ctx.launchContext(), false); - - ctx.pushNDArrayToVariableSpace(pair, outArr); - - if (canUseFastPath) - ctx.setOutputArray(pair.second, outArr); - } else { - // validate/compare shapes here. existent vs provided in outSha - auto var = ctx.variable(pair); - auto shape = var->getNDArray()->shapeInfo(); - - if (canUseFastPath) - ctx.setOutputArray(pair.second, var->getNDArray()); - - if (!shape::equalsSoft(out, shape) || shape::isEmpty(out) != shape::isEmpty(shape)) { - auto eShape = ShapeUtils::shapeAsString(out); - auto aShape = ShapeUtils::shapeAsString(shape); - - //outSha->destroy(); - delete outSha; - - nd4j_printf("Expected vs provided shapes mismatch %s vs %s at index %i\n", eShape.c_str(), aShape.c_str(), pair.second); - throw std::runtime_error("Expected vs provided shapes mismatch"); - } - - /* - * FIXME: we want to uncomment this eventually, and check data types equality - //checking out data type equality - if (ArrayOptions::dataType(out) != ArrayOptions::dataType(shape)) { - std::string msg = "Provided array [" + StringUtils::valueToString(pair.second) + "] has unexpected data type"; - throw sd::datatype_exception::build(msg, ArrayOptions::dataType(out), ArrayOptions::dataType(shape)); - } - */ - } - } else { - auto fout = ctx.fastpath_out(); - auto idx = cnt++; - if (fout.size() <= idx) { - // array doesnt exist - auto outArr = new NDArray(out, true, ctx.launchContext()); - ctx.setOutputArray(idx, outArr, true); - } else { - auto array = fout[idx]; - // checking out shape equality - if (!shape::equalsSoft(out, array->shapeInfo()) || shape::isEmpty(out) != array->isEmpty()) { - auto eShape = ShapeUtils::shapeAsString(out); - auto aShape = ShapeUtils::shapeAsString(array->shapeInfo()); - - //outSha->destroy(); - delete outSha; - - nd4j_printf("Expected vs provided shape mismatch %s vs %s at index %i\n", eShape.c_str(), aShape.c_str(), idx); - throw std::runtime_error("Expected vs provided shape mismatch"); - } - } - } - } - - if (!canUseFastPath) - ctx.forbidFastPath(true); - - delete outSha; - - // saving arrayTime - if (Environment::getInstance().isProfiling() && node != nullptr) { - arrayEnd = std::chrono::system_clock::now(); - auto arrayTime = std::chrono::duration_cast(arrayEnd - arrayStart).count(); - node->setArrayTime(arrayTime); - } - - return results; - } + if (!canUseFastPath) ctx.forbidFastPath(true); + + // do nothing, getZ result will do the trick + return static_cast(ctx.width()); + } else { + // if op is not inplace - we should pre-allocate arrays + ShapeList inSha; + int results = 0; + + if (Environment::getInstance().isProfiling() && node != nullptr) + inputStart = std::chrono::system_clock::now(); + + int cntIn = 0; + // we build list of input shapes + if (fp) { + for (const auto p : ctx.fastpath_in()) { + inSha.push_back(p == nullptr ? nullptr : p->shapeInfo()); + } + } else { + int arrCnt = 0; + for (auto p : ctx.inputs()) { + auto var = ctx.variable(p); + if (var->variableType() == VariableType::NDARRAY) { + auto array = var->getNDArray(); + if (array.get() == nullptr) + throw unresolved_input_exception::build( + "Variable wasn't resolved prior shape calculation", p); + + inSha.push_back(array->shapeInfo()); + + // we're also filling ctx with arrays + if (canUseFastPath) ctx.setInputArray(arrCnt++, array); + } else { + canUseFastPath = false; } + cntIn++; + } + } - void sd::ops::DeclarableOp::storeResult(Context &block, int outputNumber, NDArray* array) { - this->storeResult(block, outputNumber, *array); - } + // if we override shape function, we'll return size of fastPath + if (fp && ctx.shapeFunctionOverride()) { + return (int)ctx.fastpath_out().size(); + } - void sd::ops::DeclarableOp::storeResult(sd::graph::Context &ctx, int outputNumber, NDArray& array) { - ctx.pushNDArrayToVariableSpace(ctx.nodeId(), outputNumber, &array, !ctx.isInplace()); - } + // optionally saving input time + if (Environment::getInstance().isProfiling() && node != nullptr) { + inputEnd = std::chrono::system_clock::now(); + auto inputTime = std::chrono::duration_cast( + inputEnd - inputStart) + .count(); + node->setInputTime(inputTime); - bool sd::ops::DeclarableOp::allocateResult(Context& block, Nd4jLong* shape) { - auto var = block.variable(block.getNodeId(), 0); + // saving output shapes in profile + for (int e = 0; e < inSha.size(); e++) node->addInputShape(inSha.at(e)); - auto workspace = block.getWorkspace(); + shapeStart = std::chrono::system_clock::now(); + } - Nd4jLong len = shape::length(shape); - Nd4jLong* __shape; - ALLOCATE(__shape, workspace, shape::shapeInfoLength(shape), Nd4jLong); //new int[shape[0] * 2 + 4]; + auto outSha = this->calculateOutputShape(&inSha, ctx); + results = outSha->size(); - memcpy(__shape, shape, shape::shapeInfoByteLength(shape)); + // optionally saving shapeTime + if (Environment::getInstance().isProfiling() && node != nullptr) { + shapeEnd = std::chrono::system_clock::now(); + auto prepTime = std::chrono::duration_cast( + shapeEnd - shapeStart) + .count(); + node->setShapeFunctionTime(prepTime); - // if that's first run - we probably have nothing here - if (var->getNDArray() == nullptr) { + // saving output shapes in profile + for (int e = 0; e < outSha->size(); e++) + node->addOutputShape(outSha->at(e)); - std::shared_ptr buffer = std::make_shared(len * sizeof(int8_t), ArrayOptions::dataType(__shape), workspace); - var->setNDArray(new NDArray(buffer, ShapeDescriptor(__shape), block.launchContext())); - } - else if(var->getNDArray()->lengthOf() != len) { - // if length not match - lets reallocate array - delete var->getNDArray(); - std::shared_ptr buffer = std::make_shared(len * sizeof(int8_t), ArrayOptions::dataType(__shape), workspace); - var->setNDArray(new NDArray(buffer, ShapeDescriptor(__shape), block.launchContext())); - } + arrayStart = std::chrono::system_clock::now(); + } - return true; + int cnt = 0; + + for (auto out : *outSha->asVector()) { + if (!fp) { + // we need to check, if Z is really needed + std::pair pair(ctx.nodeId(), cnt++); + + if (!ctx.isValueAvailable(ctx.name(), ctx.nodeId(), pair.second)) { + if (Environment::getInstance().isDebugAndVerbose()) + shape::printShapeInfoLinear("Going to create variable with shape", + out); + + // we're creating non-initialized array here + NDArray outArr(out, true, ctx.launchContext(), false); + + ctx.pushNDArrayToVariableSpace(pair, outArr); + + if (canUseFastPath) ctx.setOutputArray(pair.second, outArr); + } else { + // validate/compare shapes here. existent vs provided in outSha + auto var = ctx.variable(pair); + auto shape = var->getNDArray()->shapeInfo(); + + if (canUseFastPath) + ctx.setOutputArray(pair.second, var->getNDArray()); + + if (!shape::equalsSoft(out, shape) || + shape::isEmpty(out) != shape::isEmpty(shape)) { + auto eShape = ShapeUtils::shapeAsString(out); + auto aShape = ShapeUtils::shapeAsString(shape); + + // outSha->destroy(); + delete outSha; + + nd4j_printf( + "Expected vs provided shapes mismatch %s vs %s at index %i\n", + eShape.c_str(), aShape.c_str(), pair.second); + throw std::runtime_error("Expected vs provided shapes mismatch"); + } + + /* + * FIXME: we want to uncomment this eventually, and check data types + equality + //checking out data type equality + if (ArrayOptions::dataType(out) != ArrayOptions::dataType(shape)) { + std::string msg = "Provided array [" + + StringUtils::valueToString(pair.second) + "] has unexpected data + type"; throw sd::datatype_exception::build(msg, + ArrayOptions::dataType(out), ArrayOptions::dataType(shape)); + } + */ } - - - bool sd::ops::DeclarableOp::allocateResult(Context& block, std::initializer_list& shape, char order) { - auto var = block.variable(block.getNodeId(), 0); - auto workspace = block.getWorkspace(); - - Nd4jLong len = shape::length(shape); - // if that's first run - we probably have nothing here - if (var->getNDArray() == nullptr) { - var->setNDArray(new NDArray(order, shape, block.dataType(), block.launchContext())); - } else if(var->getNDArray()->lengthOf() != len) { - // if length not match - lets reallocate array - delete var->getNDArray(); - var->setNDArray(new NDArray(order, shape, block.dataType(), block.launchContext())); - } - - return true; + } else { + auto fout = ctx.fastpath_out(); + auto idx = cnt++; + if (fout.size() <= idx) { + // array doesnt exist + auto outArr = + std::make_shared(out, true, ctx.launchContext()); + ctx.setOutputArray(idx, outArr); + } else { + auto array = fout[idx]; + // checking out shape equality + if (!shape::equalsSoft(out, array->shapeInfo()) || + shape::isEmpty(out) != array->isEmpty()) { + auto eShape = ShapeUtils::shapeAsString(out); + auto aShape = ShapeUtils::shapeAsString(array->shapeInfo()); + + // outSha->destroy(); + delete outSha; + + nd4j_printf( + "Expected vs provided shape mismatch %s vs %s at index %i\n", + eShape.c_str(), aShape.c_str(), idx); + throw std::runtime_error("Expected vs provided shape mismatch"); + } } + } + } - Nd4jStatus sd::ops::DeclarableOp::validateDataTypes(Context& block) { - _registrator.lock(); - if (!_registered) { - _registered = true; - this->registerTypes(); - } - _registrator.unlock(); - - // rolling over inputs first - int cnt = 0, inT = 0; - std::vector inputTypes(block.width()); - if (block.isFastPath()) { - for (auto array: block.fastpath_in()) { - if (array == nullptr) - continue; - - inputTypes[inT++] = array->dataType(); - if (!_descriptor->checkInputMatch(cnt, array->dataType())) { - auto ctype = DataTypeUtils::asString(array->dataType()); - nd4j_printf("Op [%s] failed check for input [%i], DataType: [%s]\n", - _descriptor->getOpName()->data(), cnt, ctype.c_str()); - return ND4J_STATUS_BAD_ARGUMENTS; - } - cnt++; - } - } else { - for (auto &p: *(block.inputs())) { - auto var = block.variable(p); - - // we're not checking validity, if ANY types were explicitly allowed - //if (block.dataType(cnt) == sd::DataType::ANY) - // continue; - - // only validating non-null variables - if (var != nullptr && var->hasNDArray()) { - auto array = var->getNDArray(); - - inputTypes[inT++] = array->dataType(); - if (!_descriptor->checkInputMatch(cnt, array->dataType())) { - auto ctype = DataTypeUtils::asString(array->dataType()); - nd4j_printf("Op [%s] failed check for input [%i], DataType: [%s]\n", - _descriptor->getOpName()->data(), cnt, ctype.c_str()); - return ND4J_STATUS_BAD_ARGUMENTS; - } - } - - cnt++; - } - } + if (!canUseFastPath) ctx.forbidFastPath(true); - if (block.isFastPath()) { - int index = 0; - for (auto array: block.fastpath_out()) { - if (array == nullptr) - continue; - - auto cType = array->dataType(); - - if (_descriptor->isSameMode()) { - - if (index >= block.width()) { - if (block.fastpath_in().size() == 0) - continue; - - auto ia = block.fastpath_in()[0]; - - if (ia->dataType() != cType) { - auto t = DataTypeUtils::asString(cType); - nd4j_printf("Op [%s] failed check for output [%i], DataType: [%s]\n", - _descriptor->getOpName()->data(), index, t.c_str()); - return ND4J_STATUS_BAD_ARGUMENTS; - } - } else { - // for same mode, output type must be the same as input type - auto ia = block.fastpath_in()[index]; - - if (ia->dataType() != cType) { - auto t = DataTypeUtils::asString(cType); - nd4j_printf("Op [%s] failed check for output [%i], DataType: [%s]\n", - _descriptor->getOpName()->data(), index, t.c_str()); - return ND4J_STATUS_BAD_ARGUMENTS; - } - } - } else if (_descriptor->isInherit(index)) { - // in inherit mode, output type must be the same as one of input types - if (std::find(inputTypes.begin(), inputTypes.end(), cType) == inputTypes.end()) { - auto t = DataTypeUtils::asString(cType); - nd4j_printf("Op [%s] failed check for output [%i], DataType: [%s].\n", - _descriptor->getOpName()->data(), index, t.c_str()); - return ND4J_STATUS_BAD_ARGUMENTS; - } - - } else if (!_descriptor->checkOutputMatch(index, cType)) { - auto t = DataTypeUtils::asString(cType); - nd4j_printf("Op [%s] failed check for output [%i], DataType: [%s];\n", - _descriptor->getOpName()->data(), index, t.c_str()); - return ND4J_STATUS_BAD_ARGUMENTS; - } - index++; - } - } else { - // checking optionally available outputs - auto varSpace = block.getVariableSpace(); - for (int index = 0; index < DataTypeUtils::max(); index++) { - if (varSpace != nullptr && varSpace->hasVariable(block.nodeId(), index)) { - auto var = block.variable(block.nodeId(), index); - - // only validating non-null variables - if (var != nullptr && var->hasNDArray()) { - auto array = var->getNDArray(); - auto cType = array->dataType(); - - if (_descriptor->isSameMode()) { - - if (index >= block.width()) { - if (block.width() == 0) - continue; - - auto iv = block.variable(0); - - if (iv->getNDArray()->dataType() != cType) { - auto t = DataTypeUtils::asString(cType); - nd4j_printf("Op [%s] failed check for output [%i], DataType: [%s]\n", - _descriptor->getOpName()->data(), index, t.c_str()); - return ND4J_STATUS_BAD_ARGUMENTS; - } - } else { - // for same mode, output type must be the same as input type - auto iv = block.variable(index); - - if (iv->getNDArray()->dataType() != cType) { - auto t = DataTypeUtils::asString(cType); - nd4j_printf("Op [%s] failed check for output [%i], DataType: [%s]\n", - _descriptor->getOpName()->data(), index, t.c_str()); - return ND4J_STATUS_BAD_ARGUMENTS; - } - } - } else if (_descriptor->isInherit(index)) { - // in inherit mode, output type must be the same as one of input types - if (std::find(inputTypes.begin(), inputTypes.end(), cType) == inputTypes.end()) { - auto t = DataTypeUtils::asString(cType); - nd4j_printf("Op [%s] failed check for output [%i], DataType: [%s].\n", - _descriptor->getOpName()->data(), index, t.c_str()); - return ND4J_STATUS_BAD_ARGUMENTS; - } - - } else if (!_descriptor->checkOutputMatch(index, cType)) { - auto t = DataTypeUtils::asString(cType); - nd4j_printf("Op [%s] failed check for output [%i], DataType: [%s];\n", - _descriptor->getOpName()->data(), index, t.c_str()); - return ND4J_STATUS_BAD_ARGUMENTS; - } - } - } else - break; - } - } + delete outSha; + // saving arrayTime + if (Environment::getInstance().isProfiling() && node != nullptr) { + arrayEnd = std::chrono::system_clock::now(); + auto arrayTime = std::chrono::duration_cast( + arrayEnd - arrayStart) + .count(); + node->setArrayTime(arrayTime); + } - return ND4J_STATUS_OK; + return results; + } +} + +void sd::ops::DeclarableOp::storeResult(Context &block, int outputNumber, + NDArray *array) { + this->storeResult(block, outputNumber, *array); +} + +void sd::ops::DeclarableOp::storeResult(sd::graph::Context &ctx, + int outputNumber, NDArray &array) { + ctx.pushNDArrayToVariableSpace(ctx.nodeId(), outputNumber, array); +} + +bool sd::ops::DeclarableOp::allocateResult(Context &block, Nd4jLong *shape) { + auto var = block.variable(block.getNodeId(), 0); + + auto workspace = block.workspace(); + + Nd4jLong len = shape::length(shape); + Nd4jLong *__shape; + ALLOCATE(__shape, workspace, shape::shapeInfoLength(shape), + Nd4jLong); // new int[shape[0] * 2 + 4]; + + memcpy(__shape, shape, shape::shapeInfoByteLength(shape)); + + // if that's first run - we probably have nothing here + if (var->getNDArray() == nullptr) { + std::shared_ptr buffer = std::make_shared( + len * sizeof(int8_t), ArrayOptions::dataType(__shape), workspace); + var->setNDArray(std::make_shared(buffer, ShapeDescriptor(__shape), + block.launchContext())); + } else if (var->getNDArray()->lengthOf() != len) { + // if length not match - lets reallocate array + std::shared_ptr buffer = std::make_shared( + len * sizeof(int8_t), ArrayOptions::dataType(__shape), workspace); + var->setNDArray(std::make_shared(buffer, ShapeDescriptor(__shape), + block.launchContext())); + } + + return true; +} + +bool sd::ops::DeclarableOp::allocateResult( + Context &block, std::initializer_list &shape, char order) { + auto var = block.variable(block.getNodeId(), 0); + auto workspace = block.workspace(); + + Nd4jLong len = shape::length(shape); + // if that's first run - we probably have nothing here + if (var->getNDArray() == nullptr) { + var->setNDArray(std::make_shared(order, shape, DataType::FLOAT32, + block.launchContext())); + } else if (var->getNDArray()->lengthOf() != len) { + // if length not match - lets reallocate array + var->setNDArray(std::make_shared(order, shape, DataType::FLOAT32, + block.launchContext())); + } + + return true; +} + +Nd4jStatus sd::ops::DeclarableOp::validateDataTypes(Context &block) { + _registrator.lock(); + if (!_registered) { + _registered = true; + this->registerTypes(); + } + _registrator.unlock(); + + // rolling over inputs first + int cnt = 0, inT = 0; + std::vector inputTypes(block.width()); + if (block.isFastPath()) { + for (auto array : block.fastpath_in()) { + if (array == nullptr) continue; + + inputTypes[inT++] = array->dataType(); + if (!_descriptor->checkInputMatch(cnt, array->dataType())) { + auto ctype = DataTypeUtils::asString(array->dataType()); + nd4j_printf("Op [%s] failed check for input [%i], DataType: [%s]\n", + _descriptor->getOpName()->data(), cnt, ctype.c_str()); + return ND4J_STATUS_BAD_ARGUMENTS; + } + cnt++; + } + } else { + for (auto &p : block.inputs()) { + auto var = block.variable(p); + + // we're not checking validity, if ANY types were explicitly allowed + // if (block.dataType(cnt) == sd::DataType::ANY) + // continue; + + // only validating non-null variables + if (var != nullptr && var->hasNDArray()) { + auto array = var->getNDArray(); + + inputTypes[inT++] = array->dataType(); + if (!_descriptor->checkInputMatch(cnt, array->dataType())) { + auto ctype = DataTypeUtils::asString(array->dataType()); + nd4j_printf("Op [%s] failed check for input [%i], DataType: [%s]\n", + _descriptor->getOpName()->data(), cnt, ctype.c_str()); + return ND4J_STATUS_BAD_ARGUMENTS; } + } - Nd4jStatus sd::ops::DeclarableOp::execute(Context* block) { - nd4j_debug("Executing op: [%s]\n", this->getOpName()->c_str()); - - std::chrono::time_point timeEnter, timeStart, timeEnd; - Nd4jLong prepTime, outerTime; - - Nd4jLong memoryBefore = block->workspace() == nullptr ? 0L : block->workspace()->getSpilledSize() + block->workspace()->getUsedSize(); - if (Environment::getInstance().isProfiling()) - timeEnter = std::chrono::system_clock::now(); - - // basic validation: ensure inputs are set - REQUIRE_OK(this->validateNonEmptyInput(*block)); - - // ensure number of IArgs, TArgs match our expectations - REQUIRE_OK(this->validateArguments(*block)); - - // validating data types for inputs and (optionally) outputs - REQUIRE_OK(this->validateDataTypes(*block)); - - - // this method will allocate output NDArrays for this op - auto numOutputs = this->prepareOutputs(*block); - - if (Environment::getInstance().isProfiling()) { - timeStart = std::chrono::system_clock::now(); - prepTime = std::chrono::duration_cast(timeStart - timeEnter).count(); - } - - - Nd4jStatus status; - bool hasHelper = false; - - // platform helpers use might be forbidden for various reasons, so we'll check it out first - if (block->helpersAllowed() && sd::Environment::getInstance().helpersAllowed()) { - // if we have platform-specific helper for this op - invoke it - if (OpRegistrator::getInstance().hasHelper(this->getOpHash(), block->engine())) { - auto helper = OpRegistrator::getInstance().getPlatformHelper(this->getOpHash(), block->engine()); - if (helper->isUsable(*block)) { - status = helper->invokeHelper(*block); - hasHelper = true; - } - } - } - - // if we don't have platform-specific helper - invoke generic implementation - if (!hasHelper) - status = this->validateAndExecute(*block); - - // optionally saving execution time - if (Environment::getInstance().isProfiling()) { - timeEnd = std::chrono::system_clock::now(); - outerTime = std::chrono::duration_cast(timeEnd - timeStart).count(); - block->setInnerTime(outerTime); - } - - if (Environment::getInstance().isProfiling() && block->getVariableSpace() != nullptr) { - auto fp = block->getVariableSpace()->flowPath(); - if (fp != nullptr) { - auto p = fp->profile(); - if (p != nullptr) { - Nd4jLong memoryAfter = block->workspace() == nullptr ? 0L : block->workspace()->getSpilledSize() + block->workspace()->getUsedSize(); - Nd4jLong memoryUsed = memoryAfter - memoryBefore; - p->nodeById(block->nodeId())->setPreparationTime(prepTime); - p->nodeById(block->nodeId())->setExecutionTime(outerTime); - p->nodeById(block->nodeId())->setTotalSize(memoryUsed); - } - } - } - - - // now we print out all outputs for this node - if (sd::Environment::getInstance().isDebugAndVerbose()) { - auto vs = block->getVariableSpace(); - - for (int e = 0; e < numOutputs; e++) { - // if given output index doesn't exist - we're done - - if (!block->isFastPath()) { - if (!vs->hasVariable(block->nodeId(), e)) - break; - } else { - // we have to check either in or out stack, depending on isInplace() - if (block->isInplace()) { - if (block->fastpath_in().size() <= e) - break; - } else { - if (block->fastpath_out().size() <= e) - break; - } - } - - auto array = block->isFastPath() ? block->isInplace() ? block->fastpath_in()[e] : block->fastpath_out()[e] : vs->getVariable(block->nodeId(), e)->getNDArray(); - - auto shape = ShapeUtils::shapeAsString(array); - auto first = array->isEmpty() ? std::string("Empty NDArray") : array->asString(32); - auto type = DataTypeUtils::asString(array->dataType()); - - nd4j_printf("node_%i:%i result shape: %s; dtype: %s; first values %s\n", block->nodeId(), e, shape.c_str(), type.c_str(), first.c_str()); - } - } - - return status; + cnt++; + } + } + + if (block.isFastPath()) { + int index = 0; + for (auto array : block.fastpath_out()) { + if (array == nullptr) continue; + + auto cType = array->dataType(); + + if (_descriptor->isSameMode()) { + if (index >= block.width()) { + if (block.fastpath_in().size() == 0) continue; + + auto ia = block.fastpath_in()[0]; + + if (ia->dataType() != cType) { + auto t = DataTypeUtils::asString(cType); + nd4j_printf( + "Op [%s] failed check for output [%i], DataType: [%s]\n", + _descriptor->getOpName()->data(), index, t.c_str()); + return ND4J_STATUS_BAD_ARGUMENTS; + } + } else { + // for same mode, output type must be the same as input type + auto ia = block.fastpath_in()[index]; + + if (ia->dataType() != cType) { + auto t = DataTypeUtils::asString(cType); + nd4j_printf( + "Op [%s] failed check for output [%i], DataType: [%s]\n", + _descriptor->getOpName()->data(), index, t.c_str()); + return ND4J_STATUS_BAD_ARGUMENTS; + } } - - void DeclarableOp::overwriteResult(Context &block, int outputIdx, NDArray *array) { - throw std::runtime_error("Overwrite result used!"); - //block.pushNDArrayToVariableSpace(block.nodeId(), outputIdx, array); - /* - auto varSpace = block.getVariableSpace(); - if (varSpace->hasVariable(block.getNodeId(), outputIdx)) { - auto var = varSpace->getVariable(block.getNodeId(), outputIdx); - if (var->getNDArray() != nullptr && var->isRemovable()) - delete var->getNDArray(); - - var->setNDArray(array); - var->markRemovable(true); - } else { - auto var = new Variable(array, nullptr, block.getNodeId(), outputIdx); - varSpace->putVariable(block.getNodeId(), outputIdx, var); - } - */ + } else if (_descriptor->isInherit(index)) { + // in inherit mode, output type must be the same as one of input types + if (std::find(inputTypes.begin(), inputTypes.end(), cType) == + inputTypes.end()) { + auto t = DataTypeUtils::asString(cType); + nd4j_printf("Op [%s] failed check for output [%i], DataType: [%s].\n", + _descriptor->getOpName()->data(), index, t.c_str()); + return ND4J_STATUS_BAD_ARGUMENTS; } - void DeclarableOp::overwriteResult(Context &block, int outputIdx, NDArrayList *list) { - throw std::runtime_error("Overwrite result used!"); - //block.pushNDArrayListToVariableSpace(block.nodeId(), outputIdx, list); - /* - auto varSpace = block.getVariableSpace(); - if (varSpace->hasVariable(block.getNodeId(), outputIdx)) { - auto var = varSpace->getVariable(block.getNodeId(), outputIdx); - var->setNDArrayList(list); + } else if (!_descriptor->checkOutputMatch(index, cType)) { + auto t = DataTypeUtils::asString(cType); + nd4j_printf("Op [%s] failed check for output [%i], DataType: [%s];\n", + _descriptor->getOpName()->data(), index, t.c_str()); + return ND4J_STATUS_BAD_ARGUMENTS; + } + index++; + } + } else { + // checking optionally available outputs + auto varSpace = block.getVariableSpace(); + for (int index = 0; index < DataTypeUtils::max(); index++) { + if (varSpace != nullptr && varSpace->hasVariable(block.nodeId(), index)) { + auto var = block.variable(block.nodeId(), index); + + // only validating non-null variables + if (var != nullptr && var->hasNDArray()) { + auto array = var->getNDArray(); + auto cType = array->dataType(); + + if (_descriptor->isSameMode()) { + if (index >= block.width()) { + if (block.width() == 0) continue; + + auto iv = block.variable(0); + + if (iv->getNDArray()->dataType() != cType) { + auto t = DataTypeUtils::asString(cType); + nd4j_printf( + "Op [%s] failed check for output [%i], DataType: [%s]\n", + _descriptor->getOpName()->data(), index, t.c_str()); + return ND4J_STATUS_BAD_ARGUMENTS; + } } else { - auto var = new Variable(nullptr, nullptr, block.getNodeId(), outputIdx); - var->setNDArrayList(list); - varSpace->putVariable(block.getNodeId(), outputIdx, var); - } - */ - } - - Nd4jStatus sd::ops::DeclarableOp::validateArguments(Context& block) { - /* - * We're checking number of T and I arguments. If number of args is finite number - we check strict equality - * If number of args is variable (-1), but variables MUST be present - we check for non-zero number of arguments - */ - if (_descriptor->getNumberOfTArgs() > 0) { - if ((int) block.getTArguments()->size() < _descriptor->getNumberOfTArgs()) { - nd4j_printf("%s: %i T args expected, but %i received\n", this->getOpName()->c_str(), _descriptor->getNumberOfTArgs(), block.getTArguments()->size()); - return ND4J_STATUS_BAD_PARAMS; - } - } else - if (_descriptor->getNumberOfTArgs() == -1) - if (block.getTArguments()->size() == 0) { - nd4j_printf("%s: Number of T arguments should be positive number, but got 0 arguments\n", this->getOpName()->c_str()); - return ND4J_STATUS_BAD_PARAMS; - } - - if (_descriptor->getNumberOfIArgs() > 0) { - if ((int) block.getIArguments()->size() < _descriptor->getNumberOfIArgs()) { - nd4j_printf("%s: %i int args expected, but %i received\n", this->getOpName()->c_str(), _descriptor->getNumberOfIArgs(), block.getIArguments()->size()); - return ND4J_STATUS_BAD_PARAMS; - } - } else - if (_descriptor->getNumberOfIArgs() == -1) - if (block.getIArguments()->size() == 0) { - nd4j_printf("%s: Number of Integer arguments should be positive number, but got 0 arguments\n", this->getOpName()->c_str()); - return ND4J_STATUS_BAD_PARAMS; - } - - - return ND4J_STATUS_OK; - } - - Nd4jStatus sd::ops::DeclarableOp::validateInputDimensions(Context& block, int rank) { - if (block.width() == 0) - return ND4J_STATUS_OK; - - for (auto p: *block.inputs()) { - auto v = block.variable(p); - NDArray *aV = v->getNDArray(); - - if (aV == nullptr) - return ND4J_STATUS_BAD_INPUT; - - if (aV->rankOf() != rank) - return ND4J_STATUS_BAD_DIMENSIONS; - } - - return ND4J_STATUS_OK; - } - - Nd4jStatus sd::ops::DeclarableOp::validateInput2D(Context& block) { - return validateInputDimensions(block, 2); - } - - Nd4jStatus sd::ops::DeclarableOp::validateInput3D(Context& block) { - return validateInputDimensions(block, 3); - } - - Nd4jStatus sd::ops::DeclarableOp::validateInput4D(Context& block) { - return validateInputDimensions(block, 4); - } - - Nd4jStatus sd::ops::DeclarableOp::validateNonEmptyInput(Context& block) { - if (this->getOpDescriptor()->getNumberOfInputs() == -2 || this->getOpDescriptor()->getNumberOfInputs() == 0) - return Status::OK(); - - if (block.width() < 1) { - nd4j_printf("%s: no operands provided for the op", this->getOpName()->c_str()); - return ND4J_STATUS_BAD_INPUT; + // for same mode, output type must be the same as input type + auto iv = block.variable(index); + + if (iv->getNDArray()->dataType() != cType) { + auto t = DataTypeUtils::asString(cType); + nd4j_printf( + "Op [%s] failed check for output [%i], DataType: [%s]\n", + _descriptor->getOpName()->data(), index, t.c_str()); + return ND4J_STATUS_BAD_ARGUMENTS; + } } - - - int cnt = 0; - for (auto p: *block.inputs()) { - auto v = block.variable(p); - if (v == nullptr) { - if (this->getOpName() != nullptr) { - nd4j_printf("Node [%i:<%s>]: Variable [%i] (%i:%i) is NULL\n", block.getNodeId(), this->getOpName()->c_str(), cnt, p.first, p.second); - } else { - nd4j_printf("Node [%i:]: Variable [%i] (%i:%i) is NULL\n", block.getNodeId(), cnt, p.first, p.second); - } - return ND4J_STATUS_BAD_INPUT; - } - - if (v->variableType() == VariableType::NDARRAY) { - NDArray *aV = v->getNDArray(); - - // if array is empty intentionally - we're ok with that - if (v->hasNDArray() && v->isEmpty()) - continue; - - if (aV == nullptr || !aV->nonNull()) { - if (this->getOpName() != nullptr) { - nd4j_printf("Node [%i:<%s>]: NDArray [%i] (%i:%i) is NULL\n", block.getNodeId(), this->getOpName()->c_str(), cnt, p.first, p.second); - } else { - nd4j_printf("Node [%i:]: NDArray [%i] (%i:%i) is NULL\n", block.getNodeId(), cnt, p.first, p.second); - } - return ND4J_STATUS_BAD_INPUT; - } - } - - cnt++; - } - - return ND4J_STATUS_OK; - } - - Nd4jStatus sd::ops::DeclarableOp::validateOrdersMatch(Context& block) { - if (block.width() == 0) - return ND4J_STATUS_OK; - - NDArray *a0 = block.variable(0)->getNDArray(); - for (auto p: *block.inputs()) { - auto v = block.variable(p); - NDArray *aV = v->getNDArray(); - if (a0->ordering() != aV->ordering()) - return ND4J_STATUS_BAD_ORDER; - } - - return ND4J_STATUS_OK; - } - - Nd4jStatus sd::ops::DeclarableOp::execute(sd::graph::RandomGenerator& rng, const std::vector& inputs, const std::vector& outputs, const std::vector& tArgs, const std::vector& iArgs, const std::vector& bArgs, const std::vector& dArgs, bool isInplace, sd::DataType type) { - VariableSpace variableSpace; - FlowPath fp; - variableSpace.setFlowPath(&fp); - - int cnt = -1; - std::vector in; - for (auto v: inputs) { - if (v == nullptr) - continue; - - auto var = new Variable(v); - var->markRemovable(false); - in.push_back(cnt); - variableSpace.putVariable(cnt--, var); - } - - int et = 0; - for (auto v: outputs) { - auto var = new Variable(v); - var->markRemovable(false); - std::pair pair(1, et++); - variableSpace.putVariable(pair, var); + } else if (_descriptor->isInherit(index)) { + // in inherit mode, output type must be the same as one of input + // types + if (std::find(inputTypes.begin(), inputTypes.end(), cType) == + inputTypes.end()) { + auto t = DataTypeUtils::asString(cType); + nd4j_printf( + "Op [%s] failed check for output [%i], DataType: [%s].\n", + _descriptor->getOpName()->data(), index, t.c_str()); + return ND4J_STATUS_BAD_ARGUMENTS; } - Context block(1, &variableSpace, false); - block.fillInputs(in); - block.markInplace(isInplace); - block.setDataType(0, type); - - // we need this line for tests basically - //if (rng != nullptr) - block.setRng(rng); - - for (int e = 0; e < tArgs.size(); e++) - block.getTArguments()->emplace_back(tArgs.at(e)); - - // FIXME: iargs should be Nd4jLong - for (int e = 0; e < iArgs.size(); e++) - block.getIArguments()->emplace_back(static_cast(iArgs.at(e))); - - for (int e = 0; e < bArgs.size(); e++) - block.getBArguments()->push_back(static_cast(bArgs.at(e))); - - for (int e = 0; e < dArgs.size(); e++) - block.getDArguments()->push_back(dArgs.at(e)); - - Nd4jStatus result = this->execute(&block); - - return result; - } - - Nd4jStatus DeclarableOp::execute(const std::vector &inputs, const std::vector &outputs) { - return execute(inputs, outputs, std::vector(), std::vector(), std::vector(), std::vector()); - } - - template <> - Nd4jStatus DeclarableOp::execute(const std::vector &inputs, const std::vector &outputs, std::initializer_list tArgs) { - return execute(inputs, outputs, tArgs, std::vector(), std::vector(), std::vector()); - } - - template <> - Nd4jStatus DeclarableOp::execute(const std::vector &inputs, const std::vector &outputs, std::initializer_list dArgs) { - return execute(inputs, outputs, std::vector(), std::vector(), std::vector(), dArgs); + } else if (!_descriptor->checkOutputMatch(index, cType)) { + auto t = DataTypeUtils::asString(cType); + nd4j_printf( + "Op [%s] failed check for output [%i], DataType: [%s];\n", + _descriptor->getOpName()->data(), index, t.c_str()); + return ND4J_STATUS_BAD_ARGUMENTS; + } } - - template <> - Nd4jStatus DeclarableOp::execute(const std::vector &inputs, const std::vector &outputs, std::initializer_list tArgs) { - std::vector realArgs; - for (auto v:tArgs) - realArgs.emplace_back(v); - - return execute(inputs, outputs, realArgs, std::vector(), std::vector(), std::vector()); - } - - template <> - Nd4jStatus DeclarableOp::execute(const std::vector &inputs, const std::vector &outputs, std::initializer_list iArgs) { - return execute(inputs, outputs, std::vector(), iArgs, std::vector(), std::vector()); - } - - template <> - Nd4jStatus DeclarableOp::execute(const std::vector &inputs, const std::vector &outputs, std::initializer_list iArgs) { - std::vector realArgs; - for (auto v:iArgs) - realArgs.emplace_back(v); - - return execute(inputs, outputs, std::vector(), realArgs, std::vector(), std::vector()); - } - - template <> - Nd4jStatus DeclarableOp::execute(const std::vector &inputs, const std::vector &outputs, std::initializer_list bArgs) { - return execute(inputs, outputs, std::vector(), std::vector(), bArgs, std::vector()); - } - - Nd4jStatus DeclarableOp::execute(const std::vector &inputs, const std::vector &outputs, const std::vector &tArgs, const std::vector &iArgs, const std::vector &bArgs, const std::vector &dArgs, bool isInplace) { - Context ctx(1); - - for (int e = 0; e < inputs.size(); e++) { - ctx.setInputArray(e, inputs[e]); - } - - for (int e = 0; e < outputs.size(); e++) { - ctx.setOutputArray(e, outputs[e]); - } - - - if (isInplace) - ctx.markInplace(isInplace); - - ctx.setIArguments(iArgs); - ctx.setTArguments(tArgs); - ctx.setBArguments(bArgs); - ctx.setDArguments(dArgs); - - return execute(&ctx); - } - - sd::ResultSet DeclarableOp::evaluate(const std::vector &inputs) { - return evaluate(inputs, std::vector(), std::vector(), std::vector(), std::vector()); - } - - template <> - sd::ResultSet DeclarableOp::evaluate(const std::vector &inputs, std::initializer_list iArgs) { - std::vector realArgs; - for (auto v:iArgs) - realArgs.emplace_back(v); - - return evaluate(inputs, std::vector(), realArgs, std::vector(), std::vector()); - } - - template <> - sd::ResultSet DeclarableOp::evaluate(const std::vector &inputs, std::initializer_list iArgs) { - return evaluate(inputs, std::vector(), iArgs, std::vector(), std::vector()); - } - - template <> - sd::ResultSet DeclarableOp::evaluate(const std::vector &inputs, std::initializer_list tArgs) { - std::vector realArgs; - for (auto v:tArgs) - realArgs.emplace_back(v); - - return evaluate(inputs, realArgs, std::vector(), std::vector(), std::vector()); - } - - template <> - sd::ResultSet DeclarableOp::evaluate(const std::vector &inputs, std::initializer_list tArgs) { - return evaluate(inputs, tArgs, std::vector(), std::vector(), std::vector()); - } - - template <> - sd::ResultSet DeclarableOp::evaluate(const std::vector &inputs, std::initializer_list bArgs) { - return evaluate(inputs, std::vector(), std::vector(), bArgs, std::vector()); - } - - template <> - sd::ResultSet DeclarableOp::evaluate(const std::vector &inputs, std::initializer_list bArgs) { - return evaluate(inputs, std::vector(), std::vector(), std::vector(), bArgs); - } - - sd::ResultSet DeclarableOp::evaluate(const std::vector &inputs, const std::vector &tArgs, const std::vector &iArgs, const std::vector &bArgs, const std::vector &dArgs, bool isInplace) { - VariableSpace variableSpace; - //ResultSet arrayList; - FlowPath fp; - variableSpace.setFlowPath(&fp); - - int cnt = -1; - std::vector in; - for (auto v: inputs) { - if (v == nullptr) - continue; - - auto var = new Variable(v); - var->markRemovable(false); - in.push_back(cnt); - variableSpace.putVariable(cnt--, var); - } - - Context block(1, &variableSpace, false); - block.setDataType(0, sd::DataType::FLOAT32); - block.fillInputs(in); - block.markInplace(isInplace); - // block.setRNG(ProviderRNG::getInstance().getRNG()); - - for (int e = 0; e < tArgs.size(); e++) - block.getTArguments()->emplace_back(tArgs.at(e)); - - for (int e = 0; e < iArgs.size(); e++) - block.getIArguments()->emplace_back(iArgs.at(e)); - - for (int e = 0; e < bArgs.size(); e++) - block.getBArguments()->push_back(bArgs.at(e)); - - for (int e = 0; e < dArgs.size(); e++) - block.getDArguments()->push_back(dArgs.at(e)); - - Nd4jStatus status = this->execute(&block); - ResultSet arrayList; - if (isInplace) - arrayList.setNonRemovable(); - - arrayList.setStatus(status); - if (status != ND4J_STATUS_OK) - return arrayList; - - if (!isInplace) { - for (int e = 0; e < DataTypeUtils::max(); e++) { - std::pair pair(1, e); - if (variableSpace.hasVariable(pair)) { - auto var = variableSpace.getVariable(pair); - auto arr = var->getNDArray(); - if (!arr->isAttached()) { - var->markRemovable(false); - arr->setContext(sd::LaunchContext::defaultContext()); - arrayList.push_back(arr); - } else { - arrayList.push_back(arr->detach()); - } - } else - break; - } - } else { - for (auto v:inputs) { - arrayList.push_back(v); - } - } - - return arrayList; - } - - sd::ResultSet sd::ops::DeclarableOp::execute(const sd::OpArgsHolder& holder, bool isInplace) { - // FIXME: add DArgs to OpArgsHolder - return evaluate(holder.getInArrs(), holder.getTArgs(), holder.getIArgs(), holder.getBArgs(), std::vector(), isInplace); + } else + break; + } + } + + return ND4J_STATUS_OK; +} + +Nd4jStatus DeclarableOp::execute(Context *block) { + nd4j_debug("Executing op: [%s]\n", this->getOpName().c_str()); + + std::chrono::time_point timeEnter, timeStart, + timeEnd; + Nd4jLong prepTime, outerTime; + + Nd4jLong memoryBefore = block->workspace() == nullptr + ? 0L + : block->workspace()->getSpilledSize() + + block->workspace()->getUsedSize(); + if (Environment::getInstance().isProfiling()) + timeEnter = std::chrono::system_clock::now(); + + // make sure we're not trying to call non-inpace op inplace + if (block->isInplace() && !this->getOpDescriptor()->allowsInplace()) + throw std::runtime_error( + "DeclarableOp::execute - trying to execute non-inplace op as inplace"); + + // basic validation: ensure inputs are set + REQUIRE_OK(this->validateNonEmptyInput(*block)); + + // ensure number of IArgs, TArgs match our expectations + REQUIRE_OK(this->validateArguments(*block)); + + // validating data types for inputs and (optionally) outputs + REQUIRE_OK(this->validateDataTypes(*block)); + + // this method will allocate output NDArrays for this op + auto numOutputs = this->prepareOutputs(*block); + + if (Environment::getInstance().isProfiling()) { + timeStart = std::chrono::system_clock::now(); + prepTime = std::chrono::duration_cast(timeStart - + timeEnter) + .count(); + } + + Nd4jStatus status; + bool hasHelper = false; + + // platform helpers use might be forbidden for various reasons, so we'll check + // it out first + if (block->helpersAllowed() && + sd::Environment::getInstance().helpersAllowed()) { + // if we have platform-specific helper for this op - invoke it + if (OpRegistrator::getInstance().hasHelper(this->getOpHash(), + block->engine())) { + auto helper = OpRegistrator::getInstance().getPlatformHelper( + this->getOpHash(), block->engine()); + if (helper->isUsable(*block)) { + status = helper->invokeHelper(*block); + hasHelper = true; + } + } + } + + // if we don't have platform-specific helper - invoke generic implementation + if (!hasHelper) status = this->validateAndExecute(*block); + + // optionally saving execution time + if (Environment::getInstance().isProfiling()) { + timeEnd = std::chrono::system_clock::now(); + outerTime = std::chrono::duration_cast(timeEnd - + timeStart) + .count(); + block->setInnerTime(outerTime); + } + + if (Environment::getInstance().isProfiling() && + block->getVariableSpace() != nullptr) { + /* + auto fp = block->getVariableSpace()->flowPath(); + if (fp != nullptr) { + auto p = fp->profile(); + if (p != nullptr) { + Nd4jLong memoryAfter = block->workspace() == nullptr ? 0L : + block->workspace()->getSpilledSize() + block->workspace()->getUsedSize(); + Nd4jLong memoryUsed = memoryAfter - memoryBefore; + p->nodeById(block->nodeId())->setPreparationTime(prepTime); + p->nodeById(block->nodeId())->setExecutionTime(outerTime); + p->nodeById(block->nodeId())->setTotalSize(memoryUsed); } - - Nd4jStatus sd::ops::DeclarableOp::validateInputDimensionsMatch(Context& block) { - if (block.width() == 0) - return ND4J_STATUS_OK; - - NDArray *a0 = block.array(0); - for (int e = 1; e < block.width(); e++) { - auto aV = block.array(e); - if (!shape::equalsSoft(a0->shapeInfo(), aV->shapeInfo())) - return ND4J_STATUS_BAD_DIMENSIONS; - } - - return ND4J_STATUS_OK; + } + */ + throw std::runtime_error("DeclarableOp::execute - Not implemented yet"); + } + + // now we print out all outputs for this node + if (sd::Environment::getInstance().isDebugAndVerbose()) { + auto vs = block->getVariableSpace(); + + for (int e = 0; e < numOutputs; e++) { + // if given output index doesn't exist - we're done + + if (!block->isFastPath()) { + if (!vs->hasVariable(block->nodeId(), e)) break; + } else { + // we have to check either in or out stack, depending on isInplace() + if (block->isInplace()) { + if (block->fastpath_in().size() <= e) break; + } else { + if (block->fastpath_out().size() <= e) break; } + } - Nd4jStatus sd::ops::DeclarableOp::validateInputLengthMatch(Context& block) { - if (block.width() == 0) - return ND4J_STATUS_OK; + auto array = block->isFastPath() + ? block->isInplace() ? block->fastpath_in()[e] + : block->fastpath_out()[e] + : vs->getVariable(block->nodeId(), e)->getNDArray(); + auto shape = ShapeUtils::shapeAsString(array.get()); + auto first = + array->isEmpty() ? std::string("Empty NDArray") : array->asString(32); + auto type = DataTypeUtils::asString(array->dataType()); - Nd4jLong l0 = block.array(0)->lengthOf(); - for (uint32_t e = 0; e < block.width(); e++) { - if (l0 != block.array(e)->lengthOf()) - return ND4J_STATUS_BAD_LENGTH; - } + nd4j_printf("node_%i:%i result shape: %s; dtype: %s; first values %s\n", + block->nodeId(), e, shape.c_str(), type.c_str(), + first.c_str()); + } + } + + return status; +} + +void DeclarableOp::overwriteResult(Context &block, int outputIdx, + NDArray *array) { + throw std::runtime_error("Overwrite result used!"); + // block.pushNDArrayToVariableSpace(block.nodeId(), outputIdx, array); + /* + auto varSpace = block.getVariableSpace(); + if (varSpace->hasVariable(block.getNodeId(), outputIdx)) { + auto var = varSpace->getVariable(block.getNodeId(), outputIdx); + if (var->getNDArray() != nullptr && var->isRemovable()) + delete var->getNDArray(); + + var->setNDArray(array); + var->markRemovable(true); + } else { + auto var = new Variable(array, nullptr, block.getNodeId(), outputIdx); + varSpace->putVariable(block.getNodeId(), outputIdx, var); + } + */ +} + +void DeclarableOp::overwriteResult(Context &block, int outputIdx, + NDArrayList *list) { + throw std::runtime_error("Overwrite result used!"); + // block.pushNDArrayListToVariableSpace(block.nodeId(), outputIdx, list); + /* + auto varSpace = block.getVariableSpace(); + if (varSpace->hasVariable(block.getNodeId(), outputIdx)) { + auto var = varSpace->getVariable(block.getNodeId(), outputIdx); + var->setNDArrayList(list); + } else { + auto var = new Variable(nullptr, nullptr, block.getNodeId(), outputIdx); + var->setNDArrayList(list); + varSpace->putVariable(block.getNodeId(), outputIdx, var); + } + */ +} + +Nd4jStatus sd::ops::DeclarableOp::validateArguments(Context &block) { + /* + * We're checking number of T and I arguments. If number of args is finite + * number - we check strict equality If number of args is variable (-1), but + * variables MUST be present - we check for non-zero number of arguments + */ + if (_descriptor->getNumberOfTArgs() > 0) { + if ((int)block.numT() < _descriptor->getNumberOfTArgs()) { + nd4j_printf("%s: %i T args expected, but %i received\n", + this->getOpName().c_str(), _descriptor->getNumberOfTArgs(), + block.numT()); + return ND4J_STATUS_BAD_PARAMS; + } + } else if (_descriptor->getNumberOfTArgs() == -1) + if (block.numT() == 0) { + nd4j_printf( + "%s: Number of T arguments should be positive number, but got 0 " + "arguments\n", + this->getOpName().c_str()); + return ND4J_STATUS_BAD_PARAMS; + } - return ND4J_STATUS_OK; - } + if (_descriptor->getNumberOfIArgs() > 0) { + if ((int)block.numI() < _descriptor->getNumberOfIArgs()) { + nd4j_printf("%s: %i int args expected, but %i received\n", + this->getOpName().c_str(), _descriptor->getNumberOfIArgs(), + block.numI()); + return ND4J_STATUS_BAD_PARAMS; + } + } else if (_descriptor->getNumberOfIArgs() == -1) + if (block.numI() == 0) { + nd4j_printf( + "%s: Number of Integer arguments should be positive number, but got " + "0 arguments\n", + this->getOpName().c_str()); + return ND4J_STATUS_BAD_PARAMS; + } - samediff::EmptyHandling DeclarableOp::emptyHandling() { - return samediff::EmptyHandling::EMPTY_SKIP; - } + return ND4J_STATUS_OK; +} + +Nd4jStatus sd::ops::DeclarableOp::validateInputDimensions(Context &block, + int rank) { + if (block.width() == 0) return ND4J_STATUS_OK; + + for (auto p : block.inputs()) { + auto v = block.variable(p); + NDArray *aV = v->getNDArray().get(); + + if (aV == nullptr) return ND4J_STATUS_BAD_INPUT; + + if (aV->rankOf() != rank) return ND4J_STATUS_BAD_DIMENSIONS; + } + + return ND4J_STATUS_OK; +} + +Nd4jStatus sd::ops::DeclarableOp::validateInput2D(Context &block) { + return validateInputDimensions(block, 2); +} + +Nd4jStatus sd::ops::DeclarableOp::validateInput3D(Context &block) { + return validateInputDimensions(block, 3); +} + +Nd4jStatus sd::ops::DeclarableOp::validateInput4D(Context &block) { + return validateInputDimensions(block, 4); +} + +Nd4jStatus sd::ops::DeclarableOp::validateNonEmptyInput(Context &block) { + if (this->getOpDescriptor()->getNumberOfInputs() == -2 || + this->getOpDescriptor()->getNumberOfInputs() == 0) + return Status::OK(); + + if (block.width() < 1) { + nd4j_printf("%s: no operands provided for the op", + this->getOpName().c_str()); + return ND4J_STATUS_BAD_INPUT; + } + + int cnt = 0; + for (auto p : block.inputs()) { + auto v = block.variable(p); + if (v == nullptr) { + if (!this->getOpName().empty()) { + nd4j_printf("Node [%i:<%s>]: Variable [%i] (%i:%i) is NULL\n", + block.getNodeId(), this->getOpName().c_str(), cnt, p.first, + p.second); + } else { + nd4j_printf("Node [%i:]: Variable [%i] (%i:%i) is NULL\n", + block.getNodeId(), cnt, p.first, p.second); + } + return ND4J_STATUS_BAD_INPUT; + } - void DeclarableOp::registerTypes() { - this->getOpDescriptor()->setSameMode(true); + if (v->variableType() == VariableType::NDARRAY) { + // if array is empty intentionally - we're ok with that + if (v->hasNDArray() && v->isEmpty()) { + continue; + } + + NDArray *aV = v->getNDArray().get(); + + if (aV == nullptr || !aV->nonNull()) { + if (!this->getOpName().empty()) { + nd4j_printf("Node [%i:<%s>]: NDArray [%i] (%i:%i) is NULL\n", + block.getNodeId(), this->getOpName().c_str(), cnt, + p.first, p.second); + } else { + nd4j_printf("Node [%i:]: NDArray [%i] (%i:%i) is NULL\n", + block.getNodeId(), cnt, p.first, p.second); } + return ND4J_STATUS_BAD_INPUT; + } + } - /* - template - int* sd::ops::DeclarableOp::calculateOutputShape(int* inputShape, sd::graph::Block& block) { - // default implementation suits transform, so just returns the same shape - - int* newshape; - ALLOCATE(newshape, block.getWorkspace(), shape::shapeInfoLength(inputShape), int); - memcpy(newshape, inputShape, shape::shapeInfoByteLength(inputShape)); - - return newshape; + cnt++; + } + + return ND4J_STATUS_OK; +} + +Nd4jStatus sd::ops::DeclarableOp::validateOrdersMatch(Context &block) { + if (block.width() == 0) return ND4J_STATUS_OK; + + NDArray *a0 = block.variable(0)->getNDArray().get(); + for (auto p : block.inputs()) { + auto v = block.variable(p); + NDArray *aV = v->getNDArray().get(); + if (a0->ordering() != aV->ordering()) return ND4J_STATUS_BAD_ORDER; + } + + return ND4J_STATUS_OK; +} + +Nd4jStatus sd::ops::DeclarableOp::execute( + sd::graph::RandomGenerator &rng, const std::vector &inputs, + const std::vector &outputs, const std::vector &tArgs, + const std::vector &iArgs, const std::vector &bArgs, + const std::vector &dArgs, bool isInplace, sd::DataType type) { + VariableSpace variableSpace; + FlowPath fp; + // variableSpace.setFlowPath(&fp); + + int cnt = -1; + std::vector in; + for (auto v : inputs) { + if (v == nullptr) continue; + + auto var = std::make_shared(*v, "", cnt); + + in.push_back(cnt); + variableSpace.putVariable(cnt--, var); + } + + int et = 0; + for (auto v : outputs) { + std::pair pair(1, et++); + auto var = std::make_shared(*v, "", pair.first, pair.second); + + variableSpace.putVariable(pair, var); + } + + Context block(1, &variableSpace, false); + block.fillInputs(in); + block.markInplace(isInplace); + + // we need this line for tests basically + // if (rng != nullptr) + block.setRng(rng); + + for (int e = 0; e < tArgs.size(); e++) block.appendT(tArgs.at(e)); + + // FIXME: iargs should be Nd4jLong + for (int e = 0; e < iArgs.size(); e++) + block.appendI(static_cast(iArgs.at(e))); + + for (int e = 0; e < bArgs.size(); e++) + block.appendB(static_cast(bArgs.at(e))); + + for (int e = 0; e < dArgs.size(); e++) block.appendD(dArgs.at(e)); + + Nd4jStatus result = this->execute(&block); + + return result; +} + +Nd4jStatus DeclarableOp::execute(const std::vector &inputs, + const std::vector &outputs) { + return execute(inputs, outputs, std::vector(), + std::vector(), std::vector(), + std::vector()); +} + +template <> +Nd4jStatus DeclarableOp::execute(const std::vector &inputs, + const std::vector &outputs, + std::initializer_list tArgs) { + return execute(inputs, outputs, tArgs, std::vector(), + std::vector(), std::vector()); +} + +template <> +Nd4jStatus DeclarableOp::execute(const std::vector &inputs, + const std::vector &outputs, + std::initializer_list dArgs) { + return execute(inputs, outputs, std::vector(), + std::vector(), std::vector(), dArgs); +} + +template <> +Nd4jStatus DeclarableOp::execute(const std::vector &inputs, + const std::vector &outputs, + std::initializer_list tArgs) { + std::vector realArgs; + for (auto v : tArgs) realArgs.emplace_back(v); + + return execute(inputs, outputs, realArgs, std::vector(), + std::vector(), std::vector()); +} + +template <> +Nd4jStatus DeclarableOp::execute(const std::vector &inputs, + const std::vector &outputs, + std::initializer_list iArgs) { + return execute(inputs, outputs, std::vector(), iArgs, + std::vector(), std::vector()); +} + +template <> +Nd4jStatus DeclarableOp::execute(const std::vector &inputs, + const std::vector &outputs, + std::initializer_list iArgs) { + std::vector realArgs; + for (auto v : iArgs) realArgs.emplace_back(v); + + return execute(inputs, outputs, std::vector(), realArgs, + std::vector(), std::vector()); +} + +template <> +Nd4jStatus DeclarableOp::execute(const std::vector &inputs, + const std::vector &outputs, + std::initializer_list bArgs) { + return execute(inputs, outputs, std::vector(), + std::vector(), bArgs, std::vector()); +} + +Nd4jStatus DeclarableOp::execute(const std::vector &inputs, + const std::vector &outputs, + const std::vector &tArgs, + const std::vector &iArgs, + const std::vector &bArgs, + const std::vector &dArgs, + bool isInplace) { + Context ctx(1); + + for (int e = 0; e < inputs.size(); e++) { + ctx.setInputArray(e, inputs[e] == nullptr ? NDArray() : *inputs[e]); + } + + for (int e = 0; e < outputs.size(); e++) { + ctx.setOutputArray(e, outputs[e] == nullptr ? NDArray() : *outputs[e]); + } + + if (isInplace) ctx.markInplace(isInplace); + + ctx.setIArguments(iArgs); + ctx.setTArguments(tArgs); + ctx.setBArguments(bArgs); + ctx.setDArguments(dArgs); + + return execute(&ctx); +} + +sd::ResultSet DeclarableOp::evaluate(const std::vector &inputs) { + return evaluate(inputs, std::vector(), std::vector(), + std::vector(), std::vector()); +} + +template <> +sd::ResultSet DeclarableOp::evaluate(const std::vector &inputs, + std::initializer_list iArgs) { + std::vector realArgs; + for (auto v : iArgs) realArgs.emplace_back(v); + + return evaluate(inputs, std::vector(), realArgs, std::vector(), + std::vector()); +} + +template <> +sd::ResultSet DeclarableOp::evaluate(const std::vector &inputs, + std::initializer_list iArgs) { + return evaluate(inputs, std::vector(), iArgs, std::vector(), + std::vector()); +} + +template <> +sd::ResultSet DeclarableOp::evaluate(const std::vector &inputs, + std::initializer_list tArgs) { + std::vector realArgs; + for (auto v : tArgs) realArgs.emplace_back(v); + + return evaluate(inputs, realArgs, std::vector(), + std::vector(), std::vector()); +} + +template <> +sd::ResultSet DeclarableOp::evaluate(const std::vector &inputs, + std::initializer_list tArgs) { + return evaluate(inputs, tArgs, std::vector(), std::vector(), + std::vector()); +} + +template <> +sd::ResultSet DeclarableOp::evaluate(const std::vector &inputs, + std::initializer_list bArgs) { + return evaluate(inputs, std::vector(), std::vector(), bArgs, + std::vector()); +} + +template <> +sd::ResultSet DeclarableOp::evaluate( + const std::vector &inputs, + std::initializer_list bArgs) { + return evaluate(inputs, std::vector(), std::vector(), + std::vector(), bArgs); +} + +sd::ResultSet DeclarableOp::evaluate(const std::vector &inputs, + const std::vector &tArgs, + const std::vector &iArgs, + const std::vector &bArgs, + const std::vector &dArgs, + bool isInplace) { + VariableSpace variableSpace; + // ResultSet arrayList; + FlowPath fp; + // variableSpace.setFlowPath(&fp); + + int cnt = -1; + std::vector in; + for (auto v : inputs) { + if (v == nullptr) continue; + + auto var = std::make_shared(*v, "", cnt, 0); + var->markRemovable(false); + in.push_back(cnt); + variableSpace.putVariable(cnt--, var); + } + + Context block(1, &variableSpace, false); + block.fillInputs(in); + block.markInplace(isInplace); + // block.setRNG(ProviderRNG::getInstance().getRNG()); + + for (int e = 0; e < tArgs.size(); e++) block.appendT(tArgs.at(e)); + + for (int e = 0; e < iArgs.size(); e++) block.appendI(iArgs.at(e)); + + for (int e = 0; e < bArgs.size(); e++) block.appendB(bArgs.at(e)); + + for (int e = 0; e < dArgs.size(); e++) block.appendD(dArgs.at(e)); + + Nd4jStatus status = this->execute(&block); + ResultSet arrayList; + if (isInplace) arrayList.setNonRemovable(); + + arrayList.setStatus(status); + if (status != ND4J_STATUS_OK) return arrayList; + + if (!isInplace) { + for (int e = 0; e < DataTypeUtils::max(); e++) { + std::pair pair(1, e); + if (variableSpace.hasVariable(pair)) { + auto var = variableSpace.getVariable(pair); + auto arr = var->getNDArray(); + if (!arr->isAttached()) { + var->markRemovable(false); + arr->setContext(sd::LaunchContext::defaultContext()); + arrayList.push_back(*arr.get()); + } else { + arrayList.push_back(arr->detach()); } - */ + } else + break; + } + } else { + for (auto v : inputs) { + arrayList.push_back(*v); } -} \ No newline at end of file + } + + return arrayList; +} + +sd::ResultSet sd::ops::DeclarableOp::execute(const sd::OpArgsHolder &holder, + bool isInplace) { + // FIXME: add DArgs to OpArgsHolder + return evaluate(holder.getInArrs(), holder.getTArgs(), holder.getIArgs(), + holder.getBArgs(), std::vector(), isInplace); +} + +Nd4jStatus sd::ops::DeclarableOp::validateInputDimensionsMatch(Context &block) { + if (block.width() == 0) return ND4J_STATUS_OK; + + NDArray *a0 = block.array(0).get(); + for (int e = 1; e < block.width(); e++) { + auto aV = block.array(e); + if (!shape::equalsSoft(a0->shapeInfo(), aV->shapeInfo())) + return ND4J_STATUS_BAD_DIMENSIONS; + } + + return ND4J_STATUS_OK; +} + +Nd4jStatus sd::ops::DeclarableOp::validateInputLengthMatch(Context &block) { + if (block.width() == 0) return ND4J_STATUS_OK; + + Nd4jLong l0 = block.array(0)->lengthOf(); + for (uint32_t e = 0; e < block.width(); e++) { + if (l0 != block.array(e)->lengthOf()) return ND4J_STATUS_BAD_LENGTH; + } + + return ND4J_STATUS_OK; +} + +samediff::EmptyHandling DeclarableOp::emptyHandling() { + return samediff::EmptyHandling::EMPTY_SKIP; +} + +void DeclarableOp::registerTypes() { + this->getOpDescriptor()->setSameMode(true); +} + +/* +template +int* sd::ops::DeclarableOp::calculateOutputShape(int* inputShape, +sd::graph::Block& block) { + // default implementation suits transform, so just returns the same shape + + int* newshape; + ALLOCATE(newshape, block.workspace(), shape::shapeInfoLength(inputShape), +int); memcpy(newshape, inputShape, shape::shapeInfoByteLength(inputShape)); + + return newshape; +} +*/ +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/impl/DeclarableReductionOp.cpp b/libnd4j/include/ops/declarable/impl/DeclarableReductionOp.cpp index 2dd281991260..cbdd2c63a962 100644 --- a/libnd4j/include/ops/declarable/impl/DeclarableReductionOp.cpp +++ b/libnd4j/include/ops/declarable/impl/DeclarableReductionOp.cpp @@ -18,46 +18,51 @@ // Created by raver119 on 07.10.2017. // -#include -#include -#include -#include -#include #include +#include +#include +#include +#include +#include namespace sd { - namespace ops { - DeclarableReductionOp::DeclarableReductionOp(int numInputs, int numOutputs, const char *opName, bool allowsInplace, int tArgs, int iArgs) : sd::ops::DeclarableOp(numInputs, numOutputs, opName, allowsInplace, tArgs, iArgs) { - // - } +namespace ops { +DeclarableReductionOp::DeclarableReductionOp(int numInputs, int numOutputs, + const char* opName, + bool allowsInplace, int tArgs, + int iArgs) + : sd::ops::DeclarableOp(numInputs, numOutputs, opName, allowsInplace, tArgs, + iArgs) { + // +} - sd::ShapeList* DeclarableReductionOp::calculateOutputShape(sd::ShapeList* inputShape, sd::graph::Context& block) { - // int numDims = INT_ARG(0); - std::vector dims; - if (inputShape->size() > 1) { - // the second argument is axis - auto axis = INPUT_VARIABLE(1); - for (int e = 0; e < axis->lengthOf(); e++) - dims.push_back(axis->e(e)); - } - else if (block.getIArguments()->size()) - for (int e = 0; e < block.getIArguments()->size(); e++) - dims.push_back(INT_ARG(e)); - else if (block.getAxis()->size()) { - dims = *block.getAxis(); //.push_back(axis->e(e)); - } +sd::ShapeList* DeclarableReductionOp::calculateOutputShape( + sd::ShapeList* inputShape, sd::graph::Context& block) { + // int numDims = INT_ARG(0); + std::vector dims; + if (inputShape->size() > 1) { + // the second argument is axis + auto axis = INPUT_VARIABLE(1); + for (int e = 0; e < axis->lengthOf(); e++) dims.push_back(axis->e(e)); + } else if (block.numI()) + for (int e = 0; e < block.numI(); e++) dims.push_back(INT_ARG(e)); + else if (block.getAxis().size()) { + dims = block.getAxis(); //.push_back(axis->e(e)); + } - if (dims.size() > 1) - std::sort(dims.begin(), dims.end()); + if (dims.size() > 1) std::sort(dims.begin(), dims.end()); - // special case - output is scalar - if (dims.size() == 0 || (dims.size() == 1 && dims.at(0) == sd::DataTypeUtils::max())) { - auto newShape = ConstantShapeHelper::getInstance().scalarShapeInfo(block.dataType()); - return SHAPELIST(newShape); - } + // special case - output is scalar + if (dims.size() == 0 || + (dims.size() == 1 && dims.at(0) == sd::DataTypeUtils::max())) { + auto newShape = + ConstantShapeHelper::getInstance().scalarShapeInfo(DataType::FLOAT32); + return SHAPELIST(newShape); + } - auto newShape = ShapeUtils::evalReduceShapeInfo('c', dims, inputShape->at(0), false, false, block.getWorkspace()); - return SHAPELIST(newShape); - } - } + auto newShape = ShapeUtils::evalReduceShapeInfo( + 'c', dims, inputShape->at(0), false, false, block.workspace()); + return SHAPELIST(newShape); } +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/impl/LegacyBroadcastBoolOp.cpp b/libnd4j/include/ops/declarable/impl/LegacyBroadcastBoolOp.cpp index a171ff3394b1..8ec05e96ea85 100644 --- a/libnd4j/include/ops/declarable/impl/LegacyBroadcastBoolOp.cpp +++ b/libnd4j/include/ops/declarable/impl/LegacyBroadcastBoolOp.cpp @@ -18,80 +18,117 @@ // Created by raver119 on 17.10.2017. // -#include +#include +#include #include #include -#include -#include - +#include namespace sd { - namespace ops { - Nd4jStatus LegacyBroadcastBoolOp::validateAndExecute(Context &block) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - - auto z = OUTPUT_VARIABLE(0); - - std::vector dims(*block.getIArguments()); - if (dims.size() > 0) - std::sort(dims.begin(), dims.end()); - - NDArray::prepareSpecialUse({z}, {x, y}); - - int opNum = block.opNum() < 0 ? this->_opNum : block.opNum(); - - auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(x->shapeInfo(), dims); - - PointersManager manager(block.launchContext(), "LegacyBroadcastBoolOp"); - auto pTadShape = Environment::getInstance().isCPU() ? packX.primaryShapeInfo() : packX.specialShapeInfo(); //(Nd4jLong *) manager.replicatePointer(tad.tadOnlyShapeInfo, shape::shapeInfoByteLength(tad.tadOnlyShapeInfo)); - auto pTadOffsets = Environment::getInstance().isCPU() ? packX.primaryOffsets() : packX.specialOffsets(); //(Nd4jLong *) manager.replicatePointer(tad.tadOffsets, tad.numTads * sizeof(Nd4jLong)); - - REQUIRE_TRUE(shape::length(packX.primaryShapeInfo()) == y->lengthOf(), 0, "Length of broadcast TAD should be equal to length of Y operand, but got [%i] vs [%i]", (int) shape::length(packX.primaryShapeInfo()), (int) y->lengthOf()); - - if (x == z) - NativeOpExecutioner::execBroadcast(block.launchContext(), opNum, x->buffer(), x->shapeInfo(), x->specialBuffer(), x->specialShapeInfo(), - y->buffer(), y->shapeInfo(), y->specialBuffer(), y->specialShapeInfo(), - z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), - dims.data(), dims.size(), pTadShape, pTadOffsets, pTadShape, pTadOffsets); - else { - // this is rare, but possible use case - X and Z might have different shapes/strides/orders. In this case we prepare and pass separate TAD info - - auto packZ = sd::ConstantTadHelper::getInstance().tadForDimensions(z->shapeInfo(), dims); - - auto zTadShape = Environment::getInstance().isCPU() ? packZ.primaryShapeInfo() : packZ.specialShapeInfo(); //(Nd4jLong *) manager.replicatePointer(tadZ.tadOnlyShapeInfo, shape::shapeInfoByteLength(tadZ.tadOnlyShapeInfo)); - auto zTadOffsets = Environment::getInstance().isCPU() ? packZ.primaryOffsets() : packZ.specialOffsets(); //(Nd4jLong *) manager.replicatePointer(tadZ.tadOffsets, tadZ.numTads * sizeof(Nd4jLong)); - - NativeOpExecutioner::execBroadcast(block.launchContext(), opNum, x->buffer(), x->shapeInfo(), x->specialBuffer(), x->specialShapeInfo(), - y->buffer(), y->shapeInfo(), y->specialBuffer(), y->specialShapeInfo(), - z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), - dims.data(), dims.size(), pTadShape, pTadOffsets, zTadShape, zTadOffsets); - } - - manager.synchronize(); - STORE_RESULT(*z); - - return Status::OK(); - } +namespace ops { + +Nd4jStatus LegacyBroadcastBoolOp::validateAndExecute(Context &block) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + + auto z = OUTPUT_VARIABLE(0); + + std::vector dims(block.getIArguments()); + if (dims.size() > 0) std::sort(dims.begin(), dims.end()); + + NDArray::prepareSpecialUse({z}, {x, y}); + + int opNum = block.opNum() < 0 ? this->_opNum : block.opNum(); + + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions( + x->shapeInfo(), dims); + + PointersManager manager(block.launchContext(), "LegacyBroadcastBoolOp"); + auto pTadShape = + Environment::getInstance().isCPU() + ? packX.primaryShapeInfo() + : packX + .specialShapeInfo(); //(Nd4jLong *) + // manager.replicatePointer(tad.tadOnlyShapeInfo, + // shape::shapeInfoByteLength(tad.tadOnlyShapeInfo)); + auto pTadOffsets = + Environment::getInstance().isCPU() + ? packX.primaryOffsets() + : packX.specialOffsets(); //(Nd4jLong *) + // manager.replicatePointer(tad.tadOffsets, + // tad.numTads * sizeof(Nd4jLong)); + + REQUIRE_TRUE(shape::length(packX.primaryShapeInfo()) == y->lengthOf(), 0, + "Length of broadcast TAD should be equal to length of Y " + "operand, but got [%i] vs [%i]", + (int)shape::length(packX.primaryShapeInfo()), + (int)y->lengthOf()); + + if (x == z) + NativeOpExecutioner::execBroadcast( + block.launchContext(), opNum, x->buffer(), x->shapeInfo(), + x->specialBuffer(), x->specialShapeInfo(), y->buffer(), y->shapeInfo(), + y->specialBuffer(), y->specialShapeInfo(), z->buffer(), z->shapeInfo(), + z->specialBuffer(), z->specialShapeInfo(), dims.data(), dims.size(), + pTadShape, pTadOffsets, pTadShape, pTadOffsets); + else { + // this is rare, but possible use case - X and Z might have different + // shapes/strides/orders. In this case we prepare and pass separate TAD info + + auto packZ = sd::ConstantTadHelper::getInstance().tadForDimensions( + z->shapeInfo(), dims); + + auto zTadShape = + Environment::getInstance().isCPU() + ? packZ.primaryShapeInfo() + : packZ + .specialShapeInfo(); //(Nd4jLong *) + // manager.replicatePointer(tadZ.tadOnlyShapeInfo, + // shape::shapeInfoByteLength(tadZ.tadOnlyShapeInfo)); + auto zTadOffsets = + Environment::getInstance().isCPU() + ? packZ.primaryOffsets() + : packZ + .specialOffsets(); //(Nd4jLong *) + // manager.replicatePointer(tadZ.tadOffsets, + // tadZ.numTads * sizeof(Nd4jLong)); + + NativeOpExecutioner::execBroadcast( + block.launchContext(), opNum, x->buffer(), x->shapeInfo(), + x->specialBuffer(), x->specialShapeInfo(), y->buffer(), y->shapeInfo(), + y->specialBuffer(), y->specialShapeInfo(), z->buffer(), z->shapeInfo(), + z->specialBuffer(), z->specialShapeInfo(), dims.data(), dims.size(), + pTadShape, pTadOffsets, zTadShape, zTadOffsets); + } + + manager.synchronize(); + STORE_RESULT(*z); + + return Status::OK(); +} - LegacyBroadcastBoolOp::LegacyBroadcastBoolOp() : LegacyOp::LegacyOp(2) { - // - } +LegacyBroadcastBoolOp::LegacyBroadcastBoolOp() : LegacyOp::LegacyOp(2) { + // +} - LegacyBroadcastBoolOp::LegacyBroadcastBoolOp(int opNum) : LegacyOp::LegacyOp(2, opNum) { - // - } +LegacyBroadcastBoolOp::LegacyBroadcastBoolOp(int opNum) + : LegacyOp::LegacyOp(2, opNum) { + // +} - LegacyOp* LegacyBroadcastBoolOp::clone() { - return new LegacyBroadcastBoolOp(this->_opNum); - } +LegacyOp *LegacyBroadcastBoolOp::clone() { + return new LegacyBroadcastBoolOp(this->_opNum); +} - /** - * If external NDArray wasn't specified - the same shape is returned by all broadcast ops. - */ - ShapeList* LegacyBroadcastBoolOp::calculateOutputShape(ShapeList *inputShape, sd::graph::Context &block) { - auto inShape = inputShape->at(0); - return SHAPELIST(ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(inShape, DataType::BOOL))); - } - } +/** + * If external NDArray wasn't specified - the same shape is returned by all + * broadcast ops. + */ +ShapeList *LegacyBroadcastBoolOp::calculateOutputShape( + ShapeList *inputShape, sd::graph::Context &block) { + auto inShape = inputShape->at(0); + return SHAPELIST(ConstantShapeHelper::getInstance().createShapeInfo( + ShapeDescriptor(inShape, DataType::BOOL))); } +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/impl/LegacyBroadcastOp.cpp b/libnd4j/include/ops/declarable/impl/LegacyBroadcastOp.cpp index c47cc904089a..8faeed3646ee 100644 --- a/libnd4j/include/ops/declarable/impl/LegacyBroadcastOp.cpp +++ b/libnd4j/include/ops/declarable/impl/LegacyBroadcastOp.cpp @@ -18,91 +18,127 @@ // Created by raver119 on 17.10.2017. // -#include +#include +#include +#include #include +#include #include -#include -#include -#include namespace sd { - namespace ops { - Nd4jStatus LegacyBroadcastOp::validateAndExecute(Context &block) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - - auto z = OUTPUT_VARIABLE(0); - - NDArray::prepareSpecialUse({z}, {x, y}); - - std::vector dims(*block.getAxis()); - if (dims.size() == 0 && block.width() > 2) { - auto axis = INPUT_VARIABLE(2); - helpers::adjustAxis(x->rankOf(), axis, dims); - //dims = ShapeUtils::convertAxisToTadTarget(z->rankOf(), dims); - } - if (dims.size() > 0) - std::sort(dims.begin(), dims.end()); - - - int opNum = block.opNum() < 0 ? this->_opNum : block.opNum(); - - auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(x->shapeInfo(), dims); - - auto tadLen = shape::length(packX.primaryShapeInfo()); - REQUIRE_TRUE(tadLen == y->lengthOf(), 0, "Length of broadcast TAD should be equal to length of Y operand, but got [%i] vs [%i]",tadLen, (int) y->lengthOf()); - - PointersManager manager(block.launchContext(),"LegacyBroadcastOp"); - auto pTadShape = Environment::getInstance().isCPU() ? packX.primaryShapeInfo() : packX.specialShapeInfo(); //(Nd4jLong *) manager.replicatePointer(tad.tadOnlyShapeInfo, shape::shapeInfoByteLength(tad.tadOnlyShapeInfo)); - auto pTadOffsets = Environment::getInstance().isCPU() ? packX.primaryOffsets() : packX.specialOffsets(); //(Nd4jLong *) manager.replicatePointer(tad.tadOffsets, tad.numTads * sizeof(Nd4jLong)); - - if (x == z) - NativeOpExecutioner::execBroadcast(block.launchContext(), opNum, x->buffer(), x->shapeInfo(), x->specialBuffer(), x->specialShapeInfo(), - y->buffer(), y->shapeInfo(), y->specialBuffer(), y->specialShapeInfo(), - z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), dims.data(), dims.size(), pTadShape, pTadOffsets, pTadShape, pTadOffsets); - else { - // this is rare, but possible use case - X and Z might have different shapes/strides/orders. In this case we prepare and pass separate TAD info - auto packZ = sd::ConstantTadHelper::getInstance().tadForDimensions(z->shapeInfo(), dims); - - auto zTadShape = Environment::getInstance().isCPU() ? packZ.primaryShapeInfo() : packZ.specialShapeInfo(); //(Nd4jLong *) manager.replicatePointer(tadZ.tadOnlyShapeInfo, shape::shapeInfoByteLength(tadZ.tadOnlyShapeInfo)); - auto zTadOffsets = Environment::getInstance().isCPU() ? packZ.primaryOffsets() : packZ.specialOffsets(); //(Nd4jLong *) manager.replicatePointer(tadZ.tadOffsets, tadZ.numTads * sizeof(Nd4jLong)); - - NativeOpExecutioner::execBroadcast(block.launchContext(), opNum, x->buffer(), x->shapeInfo(), x->specialBuffer(), x->specialShapeInfo(), - y->buffer(), y->shapeInfo(), y->specialBuffer(), y->specialShapeInfo(), - z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), - dims.data(), dims.size(), pTadShape, pTadOffsets, zTadShape, zTadOffsets); - } - - manager.synchronize(); - STORE_RESULT(*z); - - return Status::OK(); - } - - LegacyBroadcastOp::LegacyBroadcastOp() : LegacyOp::LegacyOp(2) { - // - } - - LegacyBroadcastOp::LegacyBroadcastOp(int opNum) : LegacyOp::LegacyOp(2, opNum) { - // - } +namespace ops { + +Nd4jStatus LegacyBroadcastOp::validateAndExecute(Context &block) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + + auto z = OUTPUT_VARIABLE(0); + + NDArray::prepareSpecialUse({z}, {x, y}); + + std::vector dims(block.getAxis()); + if (dims.size() == 0 && block.width() > 2) { + auto axis = INPUT_VARIABLE(2); + helpers::adjustAxis(x->rankOf(), axis, dims); + // dims = ShapeUtils::convertAxisToTadTarget(z->rankOf(), dims); + } + if (dims.size() > 0) std::sort(dims.begin(), dims.end()); + + int opNum = block.opNum() < 0 ? this->_opNum : block.opNum(); + + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions( + x->shapeInfo(), dims); + + auto tadLen = shape::length(packX.primaryShapeInfo()); + REQUIRE_TRUE(tadLen == y->lengthOf(), 0, + "Length of broadcast TAD should be equal to length of Y " + "operand, but got [%i] vs [%i]", + tadLen, (int)y->lengthOf()); + + PointersManager manager(block.launchContext(), "LegacyBroadcastOp"); + auto pTadShape = + Environment::getInstance().isCPU() + ? packX.primaryShapeInfo() + : packX + .specialShapeInfo(); //(Nd4jLong *) + // manager.replicatePointer(tad.tadOnlyShapeInfo, + // shape::shapeInfoByteLength(tad.tadOnlyShapeInfo)); + auto pTadOffsets = + Environment::getInstance().isCPU() + ? packX.primaryOffsets() + : packX.specialOffsets(); //(Nd4jLong *) + // manager.replicatePointer(tad.tadOffsets, + // tad.numTads * sizeof(Nd4jLong)); + + if (x == z) + NativeOpExecutioner::execBroadcast( + block.launchContext(), opNum, x->buffer(), x->shapeInfo(), + x->specialBuffer(), x->specialShapeInfo(), y->buffer(), y->shapeInfo(), + y->specialBuffer(), y->specialShapeInfo(), z->buffer(), z->shapeInfo(), + z->specialBuffer(), z->specialShapeInfo(), dims.data(), dims.size(), + pTadShape, pTadOffsets, pTadShape, pTadOffsets); + else { + // this is rare, but possible use case - X and Z might have different + // shapes/strides/orders. In this case we prepare and pass separate TAD info + auto packZ = sd::ConstantTadHelper::getInstance().tadForDimensions( + z->shapeInfo(), dims); + + auto zTadShape = + Environment::getInstance().isCPU() + ? packZ.primaryShapeInfo() + : packZ + .specialShapeInfo(); //(Nd4jLong *) + // manager.replicatePointer(tadZ.tadOnlyShapeInfo, + // shape::shapeInfoByteLength(tadZ.tadOnlyShapeInfo)); + auto zTadOffsets = + Environment::getInstance().isCPU() + ? packZ.primaryOffsets() + : packZ + .specialOffsets(); //(Nd4jLong *) + // manager.replicatePointer(tadZ.tadOffsets, + // tadZ.numTads * sizeof(Nd4jLong)); + + NativeOpExecutioner::execBroadcast( + block.launchContext(), opNum, x->buffer(), x->shapeInfo(), + x->specialBuffer(), x->specialShapeInfo(), y->buffer(), y->shapeInfo(), + y->specialBuffer(), y->specialShapeInfo(), z->buffer(), z->shapeInfo(), + z->specialBuffer(), z->specialShapeInfo(), dims.data(), dims.size(), + pTadShape, pTadOffsets, zTadShape, zTadOffsets); + } + + manager.synchronize(); + STORE_RESULT(*z); + + return Status::OK(); +} - LegacyOp* LegacyBroadcastOp::clone() { - return new LegacyBroadcastOp(this->_opNum); - } +LegacyBroadcastOp::LegacyBroadcastOp() : LegacyOp::LegacyOp(2) { + // +} - /** - * If external NDArray wasn't specified - the same shape is returned by all broadcast ops. - */ - ShapeList* LegacyBroadcastOp::calculateOutputShape(ShapeList *inputShape, sd::graph::Context &block) { - auto inShape = inputShape->at(0); +LegacyBroadcastOp::LegacyBroadcastOp(int opNum) : LegacyOp::LegacyOp(2, opNum) { + // +} - // FIXME: remove memcpy - Nd4jLong *newShape; - ALLOCATE(newShape, block.getWorkspace(), shape::shapeInfoLength(inShape), Nd4jLong); - memcpy(newShape, inShape, shape::shapeInfoByteLength(inShape)); +LegacyOp *LegacyBroadcastOp::clone() { + return new LegacyBroadcastOp(this->_opNum); +} - return SHAPELIST(CONSTANT(newShape)); - } - } +/** + * If external NDArray wasn't specified - the same shape is returned by all + * broadcast ops. + */ +ShapeList *LegacyBroadcastOp::calculateOutputShape(ShapeList *inputShape, + sd::graph::Context &block) { + auto inShape = inputShape->at(0); + + // FIXME: remove memcpy + Nd4jLong *newShape; + ALLOCATE(newShape, block.workspace(), shape::shapeInfoLength(inShape), + Nd4jLong); + memcpy(newShape, inShape, shape::shapeInfoByteLength(inShape)); + + return SHAPELIST(CONSTANT(newShape)); } +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/impl/LegacyIndexReduceOp.cpp b/libnd4j/include/ops/declarable/impl/LegacyIndexReduceOp.cpp index a9e8475c0a63..b6a3fc606209 100644 --- a/libnd4j/include/ops/declarable/impl/LegacyIndexReduceOp.cpp +++ b/libnd4j/include/ops/declarable/impl/LegacyIndexReduceOp.cpp @@ -18,180 +18,195 @@ // Created by raver119 on 16.10.2017. // -#include -#include -#include #include #include - +#include +#include +#include namespace sd { - namespace ops { - LegacyIndexReduceOp::LegacyIndexReduceOp() : LegacyOp::LegacyOp(1){ - // - } - - LegacyIndexReduceOp::LegacyIndexReduceOp(int opNum) : LegacyOp::LegacyOp(1, opNum) { - // - } - - LegacyOp* LegacyIndexReduceOp::clone() { - return new LegacyIndexReduceOp(this->_opNum); - } - - ShapeList *LegacyIndexReduceOp::calculateOutputShape(ShapeList *inputShape, sd::graph::Context &block) { - auto inShape = inputShape->at(0); - - if (block.getAxis()->size() == 0 && block.width() == 1) { - Nd4jLong *newShape; - // in this case we just return scalar - ALLOCATE(newShape, block.getWorkspace(), shape::shapeInfoLength(2), Nd4jLong); - newShape[0] = 2; - newShape[1] = 1; - newShape[2] = 1; - newShape[3] = 1; - newShape[4] = 1; - newShape[6] = 1; - newShape[7] = 99; - - auto result = ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(newShape, DataType::INT64)); - RELEASE(newShape, block.getWorkspace()); - return SHAPELIST(result); - } else if (block.getAxis()->size()){ - // in this case we're building proper shape for reduction - auto array = INPUT_VARIABLE(0); //new NDArray(nullptr, inShape, block.getWorkspace()); - - auto newShape = ShapeUtils::evalReduceShapeInfo('c', *block.getAxis(), *array, DataType::INT64, false, true, block.workspace()); - return SHAPELIST(newShape); - } - else { - bool allAxes = false; - auto indices = INPUT_VARIABLE(1); - Nd4jLong rank = shape::rank(inShape); - if (indices->lengthOf() == rank) - allAxes = true; - - std::vector axis(indices->lengthOf()); - for (int e = 0; e < indices->lengthOf(); e++) { - // lol otherwise we segfault on macOS - int f = indices->e(e); - axis[e] = f >= 0 ? f : f += rank; - } - if (allAxes){ - Nd4jLong *newShape; - // in this case we just return scalar - ALLOCATE(newShape, block.getWorkspace(), shape::shapeInfoLength(2), Nd4jLong); - newShape[0] = 2; - newShape[1] = 1; - newShape[2] = 1; - newShape[3] = 1; - newShape[4] = 1; - newShape[6] = 1; - newShape[7] = 99; - - auto result = ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(newShape, DataType::INT64)); - RELEASE(newShape, block.getWorkspace()); - return SHAPELIST(result); - } else { - // in this case we're building proper shape for reduction - auto array = INPUT_VARIABLE(0); //new NDArray(nullptr, inShape, block.getWorkspace()); - return SHAPELIST(ShapeUtils::evalReduceShapeInfo('c', axis, *array, DataType::INT64, false, true, block.workspace())); - } - } - } - - /** - * For all reductions rules are simple: either you return scalar, or you return reduced NDArray. - * It solely depends on input shape, and requested dimensions - */ - Nd4jStatus LegacyIndexReduceOp::validateAndExecute(Context &block) { - auto x = INPUT_VARIABLE(0); - auto z = OUTPUT_VARIABLE(0); - - NDArray::prepareSpecialUse({z}, {x}); - - if (z->dataType() != INT64) { - throw std::runtime_error("IndexReduce operations require output to be INT64"); - } - - int opNum = block.opNum() < 0 ? this->_opNum : block.opNum(); - - bool allAxes = false; - - ExtraArguments extras(*block.getTArguments()); - PointersManager manager(block.launchContext(), "LegacyIndexReduceOp"); - - if (block.width() == 1) { - if (block.getAxis()->size() == 0) { - // scalar - NativeOpExecutioner::execIndexReduceScalar(block.launchContext(), opNum, x->buffer(), x->shapeInfo(), - x->specialBuffer(), x->specialShapeInfo(), - extras.argumentsAsT(x->dataType()), - z->buffer(), z->shapeInfo(), - z->specialBuffer(), z->specialShapeInfo()); - } else { - // TAD - std::vector dims(block.getAxis()->size()); - for (size_t e = 0; e < dims.size(); e++) { - auto axe = block.getAxis()->at(e); - dims[e] = axe < 0 ? axe + x->rankOf(): axe; - } - if (dims.size() > 1) - std::sort(dims.begin(), dims.end()); - - auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(x->shapeInfo(), dims); - - NativeOpExecutioner::execIndexReduce(block.launchContext(), opNum, x->buffer(), x->shapeInfo(), - x->specialBuffer(), x->specialShapeInfo(), - extras.argumentsAsT(x->dataType()), - reinterpret_cast(z->buffer()), z->shapeInfo(), - z->specialBuffer(), z->specialShapeInfo(), - nullptr, (int) dims.size(), - Environment::getInstance().isCPU() ? tadPack.primaryShapeInfo() : tadPack.specialShapeInfo(), Environment::getInstance().isCPU() ? tadPack.primaryOffsets() : tadPack.specialOffsets()); - } - } else { - // TF mode - auto indices = INPUT_VARIABLE(1); - if (indices->lengthOf() == x->rankOf()) - allAxes = true; - - std::vector axis(indices->lengthOf()); - for (int e = 0; e < indices->lengthOf(); e++) { - // lol otherwise we segfault on macOS - int f = indices->e(e); - axis[e] = f >= 0 ? f : f += x->rankOf(); - } - - if (allAxes) { - NativeOpExecutioner::execIndexReduceScalar(block.launchContext(), opNum, x->buffer(), x->shapeInfo(), - x->specialBuffer(), x->specialShapeInfo(), - extras.argumentsAsT(x->dataType()), - z->buffer(), z->shapeInfo(), z->specialBuffer(), - z->specialShapeInfo()); - - } else { - if (indices->lengthOf() > 1) - std::sort(axis.begin(), axis.end()); - - REQUIRE_TRUE(axis.size() > 0, 0, "Some dimensions required for reduction!"); - - auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(x->shapeInfo(), axis); - - NativeOpExecutioner::execIndexReduce(block.launchContext(), opNum, - x->buffer(), x->shapeInfo(), x->specialBuffer(), x->specialShapeInfo(), - extras.argumentsAsT(x->dataType()), - reinterpret_cast(z->buffer()), - z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), - nullptr, (int) axis.size(), - Environment::getInstance().isCPU() ? tadPack.primaryShapeInfo() : tadPack.specialShapeInfo(), - Environment::getInstance().isCPU() ? tadPack.primaryOffsets() : tadPack.specialOffsets()); - } - } - - manager.synchronize(); - STORE_RESULT(*z); - - return Status::OK(); - } +namespace ops { + +LegacyIndexReduceOp::LegacyIndexReduceOp() : LegacyOp::LegacyOp(1) { + // +} + +LegacyIndexReduceOp::LegacyIndexReduceOp(int opNum) + : LegacyOp::LegacyOp(1, opNum) { + // +} + +LegacyOp *LegacyIndexReduceOp::clone() { + return new LegacyIndexReduceOp(this->_opNum); +} + +ShapeList *LegacyIndexReduceOp::calculateOutputShape( + ShapeList *inputShape, sd::graph::Context &block) { + auto inShape = inputShape->at(0); + + if (block.getAxis().size() == 0 && block.width() == 1) { + Nd4jLong *newShape; + // in this case we just return scalar + ALLOCATE(newShape, block.workspace(), shape::shapeInfoLength(2), Nd4jLong); + newShape[0] = 2; + newShape[1] = 1; + newShape[2] = 1; + newShape[3] = 1; + newShape[4] = 1; + newShape[6] = 1; + newShape[7] = 99; + + auto result = ConstantShapeHelper::getInstance().createShapeInfo( + ShapeDescriptor(newShape, DataType::INT64)); + RELEASE(newShape, block.workspace()); + return SHAPELIST(result); + } else if (block.getAxis().size()) { + // in this case we're building proper shape for reduction + auto array = + INPUT_VARIABLE(0); // new NDArray(nullptr, inShape, block.workspace()); + + auto newShape = ShapeUtils::evalReduceShapeInfo( + 'c', block.getAxis(), *array, DataType::INT64, false, true, + block.workspace()); + return SHAPELIST(newShape); + } else { + bool allAxes = false; + auto indices = INPUT_VARIABLE(1); + Nd4jLong rank = shape::rank(inShape); + if (indices->lengthOf() == rank) allAxes = true; + + std::vector axis(indices->lengthOf()); + for (int e = 0; e < indices->lengthOf(); e++) { + // lol otherwise we segfault on macOS + int f = indices->e(e); + axis[e] = f >= 0 ? f : f += rank; } -} \ No newline at end of file + if (allAxes) { + Nd4jLong *newShape; + // in this case we just return scalar + ALLOCATE(newShape, block.workspace(), shape::shapeInfoLength(2), + Nd4jLong); + newShape[0] = 2; + newShape[1] = 1; + newShape[2] = 1; + newShape[3] = 1; + newShape[4] = 1; + newShape[6] = 1; + newShape[7] = 99; + + auto result = ConstantShapeHelper::getInstance().createShapeInfo( + ShapeDescriptor(newShape, DataType::INT64)); + RELEASE(newShape, block.workspace()); + return SHAPELIST(result); + } else { + // in this case we're building proper shape for reduction + auto array = INPUT_VARIABLE( + 0); // new NDArray(nullptr, inShape, block.workspace()); + return SHAPELIST(ShapeUtils::evalReduceShapeInfo( + 'c', axis, *array, DataType::INT64, false, true, block.workspace())); + } + } +} + +/** + * For all reductions rules are simple: either you return scalar, or you + * return reduced NDArray. It solely depends on input shape, and requested + * dimensions + */ +Nd4jStatus LegacyIndexReduceOp::validateAndExecute(Context &block) { + auto x = INPUT_VARIABLE(0); + auto z = OUTPUT_VARIABLE(0); + + NDArray::prepareSpecialUse({z}, {x}); + + if (z->dataType() != INT64) { + throw std::runtime_error( + "IndexReduce operations require output to be INT64"); + } + + int opNum = block.opNum() < 0 ? this->_opNum : block.opNum(); + + bool allAxes = false; + + ExtraArguments extras(block.getTArguments()); + PointersManager manager(block.launchContext(), "LegacyIndexReduceOp"); + + if (block.width() == 1) { + if (block.getAxis().size() == 0) { + // scalar + NativeOpExecutioner::execIndexReduceScalar( + block.launchContext(), opNum, x->buffer(), x->shapeInfo(), + x->specialBuffer(), x->specialShapeInfo(), + extras.argumentsAsT(x->dataType()), z->buffer(), z->shapeInfo(), + z->specialBuffer(), z->specialShapeInfo()); + } else { + // TAD + std::vector dims(block.getAxis().size()); + for (size_t e = 0; e < dims.size(); e++) { + auto axe = block.getAxis().at(e); + dims[e] = axe < 0 ? axe + x->rankOf() : axe; + } + if (dims.size() > 1) std::sort(dims.begin(), dims.end()); + + auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions( + x->shapeInfo(), dims); + + NativeOpExecutioner::execIndexReduce( + block.launchContext(), opNum, x->buffer(), x->shapeInfo(), + x->specialBuffer(), x->specialShapeInfo(), + extras.argumentsAsT(x->dataType()), + reinterpret_cast(z->buffer()), z->shapeInfo(), + z->specialBuffer(), z->specialShapeInfo(), nullptr, (int)dims.size(), + Environment::getInstance().isCPU() ? tadPack.primaryShapeInfo() + : tadPack.specialShapeInfo(), + Environment::getInstance().isCPU() ? tadPack.primaryOffsets() + : tadPack.specialOffsets()); + } + } else { + // TF mode + auto indices = INPUT_VARIABLE(1); + if (indices->lengthOf() == x->rankOf()) allAxes = true; + + std::vector axis(indices->lengthOf()); + for (int e = 0; e < indices->lengthOf(); e++) { + // lol otherwise we segfault on macOS + int f = indices->e(e); + axis[e] = f >= 0 ? f : f += x->rankOf(); + } + + if (allAxes) { + NativeOpExecutioner::execIndexReduceScalar( + block.launchContext(), opNum, x->buffer(), x->shapeInfo(), + x->specialBuffer(), x->specialShapeInfo(), + extras.argumentsAsT(x->dataType()), z->buffer(), z->shapeInfo(), + z->specialBuffer(), z->specialShapeInfo()); + + } else { + if (indices->lengthOf() > 1) std::sort(axis.begin(), axis.end()); + + REQUIRE_TRUE(axis.size() > 0, 0, + "Some dimensions required for reduction!"); + + auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions( + x->shapeInfo(), axis); + + NativeOpExecutioner::execIndexReduce( + block.launchContext(), opNum, x->buffer(), x->shapeInfo(), + x->specialBuffer(), x->specialShapeInfo(), + extras.argumentsAsT(x->dataType()), + reinterpret_cast(z->buffer()), z->shapeInfo(), + z->specialBuffer(), z->specialShapeInfo(), nullptr, (int)axis.size(), + Environment::getInstance().isCPU() ? tadPack.primaryShapeInfo() + : tadPack.specialShapeInfo(), + Environment::getInstance().isCPU() ? tadPack.primaryOffsets() + : tadPack.specialOffsets()); + } + } + + manager.synchronize(); + STORE_RESULT(*z); + + return Status::OK(); +} +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/impl/LegacyOp.cpp b/libnd4j/include/ops/declarable/impl/LegacyOp.cpp index c179488dffb4..e5afc53cf39c 100644 --- a/libnd4j/include/ops/declarable/impl/LegacyOp.cpp +++ b/libnd4j/include/ops/declarable/impl/LegacyOp.cpp @@ -20,16 +20,45 @@ #include - namespace sd { - namespace ops { - LegacyOp::LegacyOp(int numInputs) : DeclarableOp::DeclarableOp(numInputs , 1, "LegacyOp", false) { - _numInputs = numInputs; - } - - LegacyOp::LegacyOp(int numInputs, int opNum) : DeclarableOp::DeclarableOp(numInputs , 1, "LegacyOp", false) { - _opNum = opNum; - _numInputs = numInputs; - } - } -} \ No newline at end of file +namespace ops { +LegacyOp::LegacyOp(int numInputs) + : DeclarableOp::DeclarableOp(numInputs, 1, "LegacyOp", false) { + _numInputs = numInputs; +} + +LegacyOp::LegacyOp(int numInputs, int opNum) + : DeclarableOp::DeclarableOp(numInputs, 1, "LegacyOp", false) { + _opNum = opNum; + _numInputs = numInputs; +} + +LegacyOp::LegacyOp(const LegacyOp &other) noexcept { + _numInputs = other._numInputs; + _opNum = other._opNum; +} + +LegacyOp &LegacyOp::operator=(const LegacyOp &other) noexcept { + if (this == &other) return *this; + + _numInputs = other._numInputs; + _opNum = other._opNum; + + return *this; +} + +LegacyOp::LegacyOp(LegacyOp &&other) noexcept { + _numInputs = other._numInputs; + _opNum = other._opNum; +} + +LegacyOp &LegacyOp::operator=(LegacyOp &&other) noexcept { + if (this == &other) return *this; + + _numInputs = other._numInputs; + _opNum = other._opNum; + + return *this; +} +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/impl/LegacyPairwiseTransformBoolOp.cpp b/libnd4j/include/ops/declarable/impl/LegacyPairwiseTransformBoolOp.cpp index 8b6e1406ea0e..d5b76e4cda7b 100644 --- a/libnd4j/include/ops/declarable/impl/LegacyPairwiseTransformBoolOp.cpp +++ b/libnd4j/include/ops/declarable/impl/LegacyPairwiseTransformBoolOp.cpp @@ -21,53 +21,65 @@ #include #include - namespace sd { - namespace ops { - LegacyPairwiseTransformBoolOp::LegacyPairwiseTransformBoolOp() : LegacyOp::LegacyOp(2) { - // just a no-op - } +namespace ops { + +LegacyPairwiseTransformBoolOp::LegacyPairwiseTransformBoolOp() + : LegacyOp::LegacyOp(2) { + // just a no-op +} - LegacyPairwiseTransformBoolOp::LegacyPairwiseTransformBoolOp(int opNum) : LegacyOp::LegacyOp(2, opNum) { - // just a no-op - } +LegacyPairwiseTransformBoolOp::LegacyPairwiseTransformBoolOp(int opNum) + : LegacyOp::LegacyOp(2, opNum) { + // just a no-op +} - LegacyOp* LegacyPairwiseTransformBoolOp::clone() { - return new LegacyPairwiseTransformBoolOp(this->_opNum); - } +LegacyOp *LegacyPairwiseTransformBoolOp::clone() { + return new LegacyPairwiseTransformBoolOp(this->_opNum); +} - Nd4jStatus LegacyPairwiseTransformBoolOp::validateAndExecute(Context &block) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); +Nd4jStatus LegacyPairwiseTransformBoolOp::validateAndExecute(Context &block) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); - NDArray::prepareSpecialUse({z}, {x, y}); + NDArray::prepareSpecialUse({z}, {x, y}); - if (!x->isSameShape(y)) - REQUIRE_TRUE(x->isSameShape(y) || y->isScalar(), 0, "Node_%i: For Pairwise transforms shapes of both operands should be equal but got %s vs %s", block.getNodeId(), ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str()); + if (!x->isSameShape(y)) + REQUIRE_TRUE(x->isSameShape(y) || y->isScalar(), 0, + "Node_%i: For Pairwise transforms shapes of both operands " + "should be equal but got %s vs %s", + block.getNodeId(), ShapeUtils::shapeAsString(x).c_str(), + ShapeUtils::shapeAsString(y).c_str()); - int opNum = block.opNum() < 0 ? this->_opNum : block.opNum(); + int opNum = block.opNum() < 0 ? this->_opNum : block.opNum(); - ExtraArguments extras(*block.getTArguments()); - PointersManager manager(block.launchContext(), "LegacyPairwiseTransformBoolOp"); + ExtraArguments extras(block.getTArguments()); + PointersManager manager(block.launchContext(), + "LegacyPairwiseTransformBoolOp"); - NativeOpExecutioner::execPairwiseTransform(block.launchContext(), opNum, x->buffer(), x->shapeInfo(), x->specialBuffer(), x->specialShapeInfo(), - y->buffer(), y->shapeInfo(), y->specialBuffer(), y->specialShapeInfo(), - z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), - extras.argumentsAsT(x->dataType())); + NativeOpExecutioner::execPairwiseTransform( + block.launchContext(), opNum, x->buffer(), x->shapeInfo(), + x->specialBuffer(), x->specialShapeInfo(), y->buffer(), y->shapeInfo(), + y->specialBuffer(), y->specialShapeInfo(), z->buffer(), z->shapeInfo(), + z->specialBuffer(), z->specialShapeInfo(), + extras.argumentsAsT(x->dataType())); - manager.synchronize(); - STORE_RESULT(*z); + manager.synchronize(); + STORE_RESULT(*z); - return Status::OK(); - } + return Status::OK(); +} - /** - * Output shape of PWT operations always the same as input[0] shape, no exclusions. - */ - ShapeList *LegacyPairwiseTransformBoolOp::calculateOutputShape(ShapeList *inputShape, sd::graph::Context &block) { - auto inShape = inputShape->at(0); - return SHAPELIST(ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(inShape, DataType::BOOL))); - } - } -} \ No newline at end of file +/** + * Output shape of PWT operations always the same as input[0] shape, no + * exclusions. + */ +ShapeList *LegacyPairwiseTransformBoolOp::calculateOutputShape( + ShapeList *inputShape, sd::graph::Context &block) { + auto inShape = inputShape->at(0); + return SHAPELIST(ConstantShapeHelper::getInstance().createShapeInfo( + ShapeDescriptor(inShape, DataType::BOOL))); +} +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/impl/LegacyPairwiseTransformOp.cpp b/libnd4j/include/ops/declarable/impl/LegacyPairwiseTransformOp.cpp index 877d2d73d12b..92111c80a1fa 100644 --- a/libnd4j/include/ops/declarable/impl/LegacyPairwiseTransformOp.cpp +++ b/libnd4j/include/ops/declarable/impl/LegacyPairwiseTransformOp.cpp @@ -21,57 +21,65 @@ #include #include - namespace sd { - namespace ops { - LegacyPairwiseTransformOp::LegacyPairwiseTransformOp() : LegacyOp::LegacyOp(2) { - this->getOpDescriptor()->allowInplace(true); - } - - LegacyPairwiseTransformOp::LegacyPairwiseTransformOp(int opNum) : LegacyOp::LegacyOp(2, opNum) { - this->getOpDescriptor()->allowInplace(true); - } - - LegacyOp* LegacyPairwiseTransformOp::clone() { - return new LegacyPairwiseTransformOp(this->_opNum); - } - - Nd4jStatus LegacyPairwiseTransformOp::validateAndExecute(Context &block) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); - - NDArray::prepareSpecialUse({z}, {x, y}); - - if (!x->isSameShape(y)) - REQUIRE_TRUE(x->isSameShape(y) || y->isScalar(), 0, "Node_%i: For Pairwise transforms shapes of both operands should be equal but got %s vs %s", block.getNodeId(), ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str()); - - int opNum = block.opNum() < 0 ? this->_opNum : block.opNum(); - - ExtraArguments extras(*block.getTArguments()); - PointersManager manager(block.launchContext(), "LegacyPairwiseTransformOp"); - - NativeOpExecutioner::execPairwiseTransform(block.launchContext(), opNum, x->buffer(), x->shapeInfo(), x->specialBuffer(), x->specialShapeInfo(), - y->buffer(), y->shapeInfo(), y->specialBuffer(), y->specialShapeInfo(), - z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), - extras.argumentsAsT(z->dataType())); - - manager.synchronize(); - STORE_RESULT(*z); - - return Status::OK(); - } - - /** - * Output shape of PWT operations always the same as input[0] shape, no exclusions. - */ - ShapeList *LegacyPairwiseTransformOp::calculateOutputShape(ShapeList *inputShape, sd::graph::Context &block) { - auto inShape = inputShape->at(0); - - Nd4jLong *newShape; - COPY_SHAPE(inShape, newShape); - - return SHAPELIST(CONSTANT(newShape)); - } - } -} \ No newline at end of file +namespace ops { +LegacyPairwiseTransformOp::LegacyPairwiseTransformOp() : LegacyOp::LegacyOp(2) { + this->getOpDescriptor()->allowInplace(true); +} + +LegacyPairwiseTransformOp::LegacyPairwiseTransformOp(int opNum) + : LegacyOp::LegacyOp(2, opNum) { + this->getOpDescriptor()->allowInplace(true); +} + +LegacyOp *LegacyPairwiseTransformOp::clone() { + return new LegacyPairwiseTransformOp(this->_opNum); +} + +Nd4jStatus LegacyPairwiseTransformOp::validateAndExecute(Context &block) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); + + NDArray::prepareSpecialUse({z}, {x, y}); + + if (!x->isSameShape(y)) + REQUIRE_TRUE(x->isSameShape(y) || y->isScalar(), 0, + "Node_%i: For Pairwise transforms shapes of both operands " + "should be equal but got %s vs %s", + block.getNodeId(), ShapeUtils::shapeAsString(x).c_str(), + ShapeUtils::shapeAsString(y).c_str()); + + int opNum = block.opNum() < 0 ? this->_opNum : block.opNum(); + + ExtraArguments extras(block.getTArguments()); + PointersManager manager(block.launchContext(), "LegacyPairwiseTransformOp"); + + NativeOpExecutioner::execPairwiseTransform( + block.launchContext(), opNum, x->buffer(), x->shapeInfo(), + x->specialBuffer(), x->specialShapeInfo(), y->buffer(), y->shapeInfo(), + y->specialBuffer(), y->specialShapeInfo(), z->buffer(), z->shapeInfo(), + z->specialBuffer(), z->specialShapeInfo(), + extras.argumentsAsT(z->dataType())); + + manager.synchronize(); + STORE_RESULT(*z); + + return Status::OK(); +} + +/** + * Output shape of PWT operations always the same as input[0] shape, no + * exclusions. + */ +ShapeList *LegacyPairwiseTransformOp::calculateOutputShape( + ShapeList *inputShape, sd::graph::Context &block) { + auto inShape = inputShape->at(0); + + Nd4jLong *newShape; + COPY_SHAPE(inShape, newShape); + + return SHAPELIST(CONSTANT(newShape)); +} +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/impl/LegacyRandomOp.cpp b/libnd4j/include/ops/declarable/impl/LegacyRandomOp.cpp index 09c0a054a4f1..e54d0d937fb1 100644 --- a/libnd4j/include/ops/declarable/impl/LegacyRandomOp.cpp +++ b/libnd4j/include/ops/declarable/impl/LegacyRandomOp.cpp @@ -18,431 +18,468 @@ // Created by raver119 on 16.10.2017. // -#include -#include -#include #include #include +#include +#include #include +#include namespace sd { - namespace ops { - LegacyRandomOp::LegacyRandomOp() : LegacyOp::LegacyOp(1) { - // just a no-op - } - - LegacyRandomOp::LegacyRandomOp(int opNum) : LegacyOp::LegacyOp(1, opNum) { - // just a no-op - } - - LegacyOp* LegacyRandomOp::clone() { - return new LegacyRandomOp(this->_opNum); - } - - template - Nd4jStatus LegacyRandomOp::validateAndExecute_(Context &block) { - auto input = INPUT_VARIABLE(0); - - int opNum = block.opNum() < 0 ? this->_opNum : block.opNum(); - - /* - (0, randomOps::UniformDistribution) ,\ - (1, randomOps::DropOut) ,\ - (2, randomOps::DropOutInverted) ,\ - (3, randomOps::ProbablisticMerge) ,\ - (4, randomOps::Linspace) ,\ - (5, randomOps::Choice) ,\ - (6, randomOps::GaussianDistribution) ,\ - (7, randomOps::BernoulliDistribution) ,\ - (8, randomOps::BinomialDistribution),\ - (9, randomOps::BinomialDistributionEx),\ - (10, randomOps::LogNormalDistribution) ,\ - (11, randomOps::TruncatedNormalDistribution) ,\ - (12, randomOps::AlphaDropOut) - */ - switch(opNum) { - case sd::random::UniformDistribution: { - // uniform distribution - T from, to; - if (block.width() > 2) { - auto arg1 = INPUT_VARIABLE(1); - auto arg2 = INPUT_VARIABLE(2); - REQUIRE_TRUE(arg1->isScalar(), 0, "Uniform: Second argument must be scalar"); - REQUIRE_TRUE(arg2->isScalar(), 0, "Uniform: Third argument must be scalar"); - - from = arg1->e(0); - to = arg2->e(0); - } else if (block.getTArguments()->size() == 2) { - from = T_ARG(0); - to = T_ARG(1); - } else { - REQUIRE_TRUE(false, 0, "Uniform requires either TArgs or 3 arguments to be present"); - } - - auto z = OUTPUT_VARIABLE(0); //NDArrayFactory::create_('c', shape, block.getWorkspace()); - - RandomLauncher::fillUniform(block.launchContext(), block.randomGenerator(), z, from, to); - - // FIXME: - //OVERWRITE_RESULT(z); - } - break; - case sd::random::DropOut: { - auto z = OUTPUT_VARIABLE(0); - - T prob; - if (block.width() > 1) { - auto arg = INPUT_VARIABLE(1); - REQUIRE_TRUE(arg->isScalar(), 0, "DropOut: Second argument must be scalar"); - - prob = arg->e(0); - } else if (block.getTArguments()->size() > 0) { - prob = T_ARG(0); - } else { - REQUIRE_TRUE(false, 0, "DropOut requires either TArgs or second argument to be present"); - } - - if (!block.isInplace()) - z->assign(input); - - RandomLauncher::applyDropOut(block.launchContext(), block.randomGenerator(), z, prob); - } - break; - case sd::random::DropOutInverted: { - auto z = OUTPUT_VARIABLE(0); - sd::ops::dropout op; - return op.execute(&block); - } - break; - case sd::random::GaussianDistribution: { - // gaussian distribution - T mean, stdev; - if (block.width() > 2) { - auto arg1 = INPUT_VARIABLE(1); - auto arg2 = INPUT_VARIABLE(2); - REQUIRE_TRUE(arg1->isScalar(), 0, "Gaussian: Second argument must be scalar"); - REQUIRE_TRUE(arg2->isScalar(), 0, "Gaussian: Third argument must be scalar"); - - mean = arg1->e(0); - stdev = arg2->e(0); - } else if (block.getTArguments()->size() == 2) { - mean = T_ARG(0); - stdev = T_ARG(1); - } else { - REQUIRE_TRUE(false, 0, "Gaussian requires either TArgs or 3 arguments to be present"); - } - - REQUIRE_TRUE(input->isVector(), 0, "Gaussian requires pure shape as first argument"); - - std::vector shape(input->lengthOf()); - for (int e = 0; e < input->lengthOf(); e++) - shape[e] = input->e(e); - - auto z = OUTPUT_VARIABLE(0);//NDArrayFactory::create_('c', shape, block.getWorkspace()); - - RandomLauncher::fillGaussian(block.launchContext(), block.randomGenerator(), z, mean, stdev); - - // FIXME: !! - //OVERWRITE_RESULT(z); - } - break; - case sd::random::BernoulliDistribution: { - // bernoulli distribution - T prob; - if (block.width() > 1) { - auto arg1 = INPUT_VARIABLE(1); - REQUIRE_TRUE(arg1->isScalar(), 0, "Bernoulli: Second argument must be scalar"); - - prob = arg1->e(0); - } else if (block.getTArguments()->size() > 0) { - prob = T_ARG(0); - } else { - REQUIRE_TRUE(false, 0, "Bernoulli requires either 1 TArg or 2 arguments to be present"); - } - - REQUIRE_TRUE(input->isVector(), 0, "Bernoulli requires pure shape as first argument"); - - std::vector shape(input->lengthOf()); - for (int e = 0; e < input->lengthOf(); e++) - shape[e] = input->e(e); - - auto z = OUTPUT_VARIABLE(0); // NDArrayFactory::create_('c', shape, block.getWorkspace()); - - RandomLauncher::fillBernoulli(block.launchContext(), block.randomGenerator(), z, prob); - - // FIXME: - //OVERWRITE_RESULT(z); - } - break; - case sd::random::BinomialDistributionEx: { - // BinomialEx distribution - T prob; - int trials; - if (block.width() > 2) { - auto arg1 = INPUT_VARIABLE(1); - auto arg2 = INPUT_VARIABLE(2); - REQUIRE_TRUE(arg1->isScalar(), 0, "Binomial: Second argument must be scalar"); - REQUIRE_TRUE(arg2->isScalar(), 0, "Binomial: Third argument must be scalar"); - - trials = arg1->e(0); - prob = arg2->e(0); - } else if (block.getTArguments()->size() == 1 && block.getIArguments()->size() == 1) { - trials = INT_ARG(0); - prob = T_ARG(0); - } else { - REQUIRE_TRUE(false, 0, "Binomial requires either TArgs/IArgs or 3 arguments to be present"); - } - - REQUIRE_TRUE(input->isVector(), 0, "Binomial requires pure shape as first argument"); - - std::vector shape(input->lengthOf()); - for (int e = 0; e < input->lengthOf(); e++) - shape[e] = input->e(e); - - auto z = OUTPUT_VARIABLE(0);//NDArrayFactory::create_('c', shape, block.getWorkspace()); - - RandomLauncher::fillBinomial(block.launchContext(), block.randomGenerator(), z, trials, prob); - - // FIXME: !!! - //OVERWRITE_RESULT(z); - } - break; - case sd::random::LogNormalDistribution: { - // lognorm distribution - T mean, stdev; - if (block.width() > 2) { - auto arg1 = INPUT_VARIABLE(1); - auto arg2 = INPUT_VARIABLE(2); - REQUIRE_TRUE(arg1->isScalar(), 0, "LogNormal: Second argument must be scalar"); - REQUIRE_TRUE(arg2->isScalar(), 0, "LogNormal: Third argument must be scalar"); - - mean = arg1->e(0); - stdev = arg2->e(0); - } else if (block.getTArguments()->size() == 2) { - mean = T_ARG(0); - stdev = T_ARG(1); - } else { - REQUIRE_TRUE(false, 0, "LogNormal requires either TArgs or 3 arguments to be present"); - } - - REQUIRE_TRUE(input->isVector(), 0, "LogNormal requires pure shape as first argument"); - - std::vector shape(input->lengthOf()); - for (int e = 0; e < input->lengthOf(); e++) - shape[e] = input->e(e); - - auto z = OUTPUT_VARIABLE(0);//NDArrayFactory::create_('c', shape, block.getWorkspace()); - - RandomLauncher::fillLogNormal(block.launchContext(), block.randomGenerator(), z, mean, stdev); - - // FIXME: !! - //OVERWRITE_RESULT(z); - } - break; - case sd::random::TruncatedNormalDistribution: { - // truncated norm distribution - T mean, stdev; - if (block.width() > 2) { - auto arg1 = INPUT_VARIABLE(1); - auto arg2 = INPUT_VARIABLE(2); - REQUIRE_TRUE(arg1->isScalar(), 0, "TruncatedNormal: Second argument must be scalar"); - REQUIRE_TRUE(arg2->isScalar(), 0, "TruncatedNormal: Third argument must be scalar"); - - mean = arg1->e(0); - stdev = arg2->e(0); - } else if (block.getTArguments()->size() == 2) { - mean = T_ARG(0); - stdev = T_ARG(1); - } else { - REQUIRE_TRUE(false, 0, "TruncatedNormal requires either TArgs or 3 arguments to be present"); - } - - REQUIRE_TRUE(input->isVector(), 0, "TruncatedNormal requires pure shape as first argument"); - - std::vector shape(input->lengthOf()); - for (int e = 0; e < input->lengthOf(); e++) - shape[e] = input->e(e); - - auto z = OUTPUT_VARIABLE(0); // NDArrayFactory::create_('c', shape, block.getWorkspace()); - - RandomLauncher::fillTruncatedNormal(block.launchContext(), block.randomGenerator(), z, mean, stdev); - } - break; - case sd::random::AlphaDropOut: { - auto z = OUTPUT_VARIABLE(0); - - T prob, a, b, pa; - if (block.width() > 4) { - auto arg1 = INPUT_VARIABLE(1); - auto arg2 = INPUT_VARIABLE(2); - auto arg3 = INPUT_VARIABLE(3); - auto arg4 = INPUT_VARIABLE(4); - REQUIRE_TRUE(arg1->isScalar(), 0, "AlphaDropOut: Second argument must be scalar"); - REQUIRE_TRUE(arg2->isScalar(), 0, "AlphaDropOut: Third argument must be scalar"); - REQUIRE_TRUE(arg3->isScalar(), 0, "AlphaDropOut: Fourth argument must be scalar"); - REQUIRE_TRUE(arg4->isScalar(), 0, "AlphaDropOut: Fifth argument must be scalar"); - - prob = arg1->e(0); - a = arg2->e(0); - b = arg3->e(0); - pa = arg4->e(0); - } else if (block.getTArguments()->size() == 4) { - prob = T_ARG(0); - a = T_ARG(1); - b = T_ARG(2); - pa = T_ARG(3); - } else { - REQUIRE_TRUE(false, 0, "AlphaDropOut requires either TArgs or 5 arguments to be present"); - } - - if (!block.isInplace()) - z->assign(input); - - RandomLauncher::applyAlphaDropOut(block.launchContext(), block.randomGenerator(), z, prob, a, b, pa); - } - break; - case sd::random::Linspace: { - auto z = OUTPUT_VARIABLE(0); - auto start = INPUT_VARIABLE(0); - auto finish = INPUT_VARIABLE(1); - auto numOfElements = INPUT_VARIABLE(2); - - z->linspace(start->e(0), (finish->e(0) - start->e(0)) / (numOfElements->e(0) - 1.)); - } - break; - default: { - nd4j_printf("Unknown random op requested: [%i]\n", opNum); - return ND4J_STATUS_KERNEL_FAILURE; - } - } - - return Status::OK(); - } - - Nd4jStatus LegacyRandomOp::validateAndExecute(Context &block) { -// REQUIRE_TRUE(block.getRNG() != nullptr, 0, "RNG should be provided for LegacyRandomOp, but got NULL instead at node_%i", block.nodeId()) - - auto z = OUTPUT_VARIABLE(0); - BUILD_SINGLE_SELECTOR(z->dataType(), return validateAndExecute_, (block), FLOAT_TYPES); - } - - /** - * For transform operations, output shape always equals to input shape. With just a few exclusions, like im2col and col2im. - * But these ops already have CustomOp implementations. - * - */ - ShapeList *LegacyRandomOp::calculateOutputShape(ShapeList *inputShape, sd::graph::Context &block) { - auto inShape = inputShape->at(0); - auto xType = ArrayOptions::dataType(inShape); - Nd4jLong *newShape; - if (DataTypeUtils::isR(xType)) { - COPY_SHAPE(inShape, newShape); - - return SHAPELIST(CONSTANT(newShape)); - } else if (DataTypeUtils::isZ(xType)) { - auto zShapeArr = INPUT_VARIABLE(0); - auto zShapeVector = zShapeArr->asVectorT(); - auto dtype = block.dataType(); - - return SHAPELIST(ConstantShapeHelper::getInstance().createShapeInfo(dtype, 'c', zShapeVector)); - } else - throw std::runtime_error("LegacyRandomOp: Unknown input data type!"); - } - - Nd4jStatus LegacyRandomOp::execute(Context* block) { - return DeclarableOp::execute(block); - } - - sd::ResultSet LegacyRandomOp::execute(sd::graph::RandomGenerator& rng, std::initializer_list inputs, std::initializer_list tArgs, std::initializer_list iArgs, bool isInplace) { - std::vector ins(inputs); - std::vector tas(tArgs); - std::vector ias(iArgs); - return this->execute(rng, ins, tas, ias, isInplace); - } - - sd::ResultSet LegacyRandomOp::execute(sd::graph::RandomGenerator& rng, std::vector& inputs, std::vector& tArgs, std::vector& iArgs, bool isInplace) { - VariableSpace variableSpace; - ResultSet arrayList; - //ResultSet arrayList; - - if (isInplace) - arrayList.setNonRemovable(); - - int cnt = -1; - std::vector in; - for (auto v: inputs) { - if (v == nullptr) - continue; - - auto var = new Variable(v); - var->markRemovable(false); - in.push_back(cnt); - variableSpace.putVariable(cnt--, var); - } - - Context block(1, &variableSpace, false); - // FIX ME: implement setRng method - block.setRng(rng); - block.fillInputs(in); - block.markInplace(isInplace); - - for (int e = 0; e < tArgs.size(); e++) - block.getTArguments()->emplace_back(tArgs.at(e)); - - - for (int e = 0; e < iArgs.size(); e++) - block.getIArguments()->emplace_back(iArgs.at(e)); - - Nd4jStatus status = this->execute(&block); - arrayList.setStatus(status); - if (status != ND4J_STATUS_OK) - return arrayList; - - - for (int e = 0; e < DataTypeUtils::max(); e++) { - std::pair pair(1, e); - if (variableSpace.hasVariable(pair)) { - auto var = variableSpace.getVariable(pair); - auto arr = var->getNDArray(); - if (!arr->isAttached()) { - var->markRemovable(false); - arrayList.push_back(arr); - } else { - arrayList.push_back(arr->detach()); - } - } else - break; - } - - return arrayList; - } - - Nd4jStatus LegacyRandomOp::validateDataTypes(Context& block) { - if (block.isFastPath()) { - // in this case we'll roll through pre-defined outputs - auto fpo = block.fastpath_out(); - for (auto v:fpo) { - if (v != nullptr) { - if (!v->isR()) - return ND4J_STATUS_BAD_ARGUMENTS; - } - } - } else { - std::pair pair(block.nodeId(), 0); - if (block.getVariableSpace()->hasVariable(pair)) { - auto var = block.variable(pair); - if (!var->hasNDArray()) - return ND4J_STATUS_BAD_ARGUMENTS; - - auto arr = var->getNDArray(); - if (!arr->isR()) - return ND4J_STATUS_BAD_ARGUMENTS; - } - } - - return Status::OK(); - } - - BUILD_SINGLE_TEMPLATE(template Nd4jStatus LegacyRandomOp::validateAndExecute_, (Context&), FLOAT_TYPES); +namespace ops { + +LegacyRandomOp::LegacyRandomOp() : LegacyOp::LegacyOp(1) { + // just a no-op +} + +LegacyRandomOp::LegacyRandomOp(int opNum) : LegacyOp::LegacyOp(1, opNum) { + // just a no-op +} + +LegacyOp* LegacyRandomOp::clone() { return new LegacyRandomOp(this->_opNum); } + +template +Nd4jStatus LegacyRandomOp::validateAndExecute_(Context& block) { + auto input = INPUT_VARIABLE(0); + + int opNum = block.opNum() < 0 ? this->_opNum : block.opNum(); + + /* + (0, randomOps::UniformDistribution) ,\ + (1, randomOps::DropOut) ,\ + (2, randomOps::DropOutInverted) ,\ + (3, randomOps::ProbablisticMerge) ,\ + (4, randomOps::Linspace) ,\ + (5, randomOps::Choice) ,\ + (6, randomOps::GaussianDistribution) ,\ + (7, randomOps::BernoulliDistribution) ,\ + (8, randomOps::BinomialDistribution),\ + (9, randomOps::BinomialDistributionEx),\ + (10, randomOps::LogNormalDistribution) ,\ + (11, randomOps::TruncatedNormalDistribution) ,\ + (12, randomOps::AlphaDropOut) + */ + switch (opNum) { + case sd::random::UniformDistribution: { + // uniform distribution + T from, to; + if (block.width() > 2) { + auto arg1 = INPUT_VARIABLE(1); + auto arg2 = INPUT_VARIABLE(2); + REQUIRE_TRUE(arg1->isScalar(), 0, + "Uniform: Second argument must be scalar"); + REQUIRE_TRUE(arg2->isScalar(), 0, + "Uniform: Third argument must be scalar"); + + from = arg1->e(0); + to = arg2->e(0); + } else if (block.numT() == 2) { + from = T_ARG(0); + to = T_ARG(1); + } else { + REQUIRE_TRUE( + false, 0, + "Uniform requires either TArgs or 3 arguments to be present"); + } + + auto z = OUTPUT_VARIABLE( + 0); // NDArrayFactory::create_('c', shape, block.workspace()); + + RandomLauncher::fillUniform(block.launchContext(), + block.randomGenerator(), z, from, to); + + // FIXME: + // OVERWRITE_RESULT(z); + } break; + case sd::random::DropOut: { + auto z = OUTPUT_VARIABLE(0); + + T prob; + if (block.width() > 1) { + auto arg = INPUT_VARIABLE(1); + REQUIRE_TRUE(arg->isScalar(), 0, + "DropOut: Second argument must be scalar"); + + prob = arg->e(0); + } else if (block.numT() > 0) { + prob = T_ARG(0); + } else { + REQUIRE_TRUE( + false, 0, + "DropOut requires either TArgs or second argument to be present"); + } + + if (!block.isInplace()) z->assign(input); + + RandomLauncher::applyDropOut(block.launchContext(), + block.randomGenerator(), z, prob); + } break; + case sd::random::DropOutInverted: { + auto z = OUTPUT_VARIABLE(0); + sd::ops::dropout op; + return op.execute(&block); + } break; + case sd::random::GaussianDistribution: { + // gaussian distribution + T mean, stdev; + if (block.width() > 2) { + auto arg1 = INPUT_VARIABLE(1); + auto arg2 = INPUT_VARIABLE(2); + REQUIRE_TRUE(arg1->isScalar(), 0, + "Gaussian: Second argument must be scalar"); + REQUIRE_TRUE(arg2->isScalar(), 0, + "Gaussian: Third argument must be scalar"); + + mean = arg1->e(0); + stdev = arg2->e(0); + } else if (block.numT() == 2) { + mean = T_ARG(0); + stdev = T_ARG(1); + } else { + REQUIRE_TRUE( + false, 0, + "Gaussian requires either TArgs or 3 arguments to be present"); + } + + REQUIRE_TRUE(input->isVector(), 0, + "Gaussian requires pure shape as first argument"); + + std::vector shape(input->lengthOf()); + for (int e = 0; e < input->lengthOf(); e++) + shape[e] = input->e(e); + + auto z = OUTPUT_VARIABLE( + 0); // NDArrayFactory::create_('c', shape, block.workspace()); + + RandomLauncher::fillGaussian(block.launchContext(), + block.randomGenerator(), z, mean, stdev); + + // FIXME: !! + // OVERWRITE_RESULT(z); + } break; + case sd::random::BernoulliDistribution: { + // bernoulli distribution + T prob; + if (block.width() > 1) { + auto arg1 = INPUT_VARIABLE(1); + REQUIRE_TRUE(arg1->isScalar(), 0, + "Bernoulli: Second argument must be scalar"); + + prob = arg1->e(0); + } else if (block.numT() > 0) { + prob = T_ARG(0); + } else { + REQUIRE_TRUE( + false, 0, + "Bernoulli requires either 1 TArg or 2 arguments to be present"); + } + + REQUIRE_TRUE(input->isVector(), 0, + "Bernoulli requires pure shape as first argument"); + + std::vector shape(input->lengthOf()); + for (int e = 0; e < input->lengthOf(); e++) + shape[e] = input->e(e); + + auto z = OUTPUT_VARIABLE( + 0); // NDArrayFactory::create_('c', shape, block.workspace()); + + RandomLauncher::fillBernoulli(block.launchContext(), + block.randomGenerator(), z, prob); + + // FIXME: + // OVERWRITE_RESULT(z); + } break; + case sd::random::BinomialDistributionEx: { + // BinomialEx distribution + T prob; + int trials; + if (block.width() > 2) { + auto arg1 = INPUT_VARIABLE(1); + auto arg2 = INPUT_VARIABLE(2); + REQUIRE_TRUE(arg1->isScalar(), 0, + "Binomial: Second argument must be scalar"); + REQUIRE_TRUE(arg2->isScalar(), 0, + "Binomial: Third argument must be scalar"); + + trials = arg1->e(0); + prob = arg2->e(0); + } else if (block.numT() == 1 && block.numI() == 1) { + trials = INT_ARG(0); + prob = T_ARG(0); + } else { + REQUIRE_TRUE(false, 0, + "Binomial requires either TArgs/IArgs or 3 arguments to " + "be present"); + } + + REQUIRE_TRUE(input->isVector(), 0, + "Binomial requires pure shape as first argument"); + + std::vector shape(input->lengthOf()); + for (int e = 0; e < input->lengthOf(); e++) + shape[e] = input->e(e); + + auto z = OUTPUT_VARIABLE( + 0); // NDArrayFactory::create_('c', shape, block.workspace()); + + RandomLauncher::fillBinomial(block.launchContext(), + block.randomGenerator(), z, trials, prob); + + // FIXME: !!! + // OVERWRITE_RESULT(z); + } break; + case sd::random::LogNormalDistribution: { + // lognorm distribution + T mean, stdev; + if (block.width() > 2) { + auto arg1 = INPUT_VARIABLE(1); + auto arg2 = INPUT_VARIABLE(2); + REQUIRE_TRUE(arg1->isScalar(), 0, + "LogNormal: Second argument must be scalar"); + REQUIRE_TRUE(arg2->isScalar(), 0, + "LogNormal: Third argument must be scalar"); + + mean = arg1->e(0); + stdev = arg2->e(0); + } else if (block.numT() == 2) { + mean = T_ARG(0); + stdev = T_ARG(1); + } else { + REQUIRE_TRUE( + false, 0, + "LogNormal requires either TArgs or 3 arguments to be present"); + } + + REQUIRE_TRUE(input->isVector(), 0, + "LogNormal requires pure shape as first argument"); + + std::vector shape(input->lengthOf()); + for (int e = 0; e < input->lengthOf(); e++) + shape[e] = input->e(e); + + auto z = OUTPUT_VARIABLE( + 0); // NDArrayFactory::create_('c', shape, block.workspace()); + + RandomLauncher::fillLogNormal(block.launchContext(), + block.randomGenerator(), z, mean, stdev); + + // FIXME: !! + // OVERWRITE_RESULT(z); + } break; + case sd::random::TruncatedNormalDistribution: { + // truncated norm distribution + T mean, stdev; + if (block.width() > 2) { + auto arg1 = INPUT_VARIABLE(1); + auto arg2 = INPUT_VARIABLE(2); + REQUIRE_TRUE(arg1->isScalar(), 0, + "TruncatedNormal: Second argument must be scalar"); + REQUIRE_TRUE(arg2->isScalar(), 0, + "TruncatedNormal: Third argument must be scalar"); + + mean = arg1->e(0); + stdev = arg2->e(0); + } else if (block.numT() == 2) { + mean = T_ARG(0); + stdev = T_ARG(1); + } else { + REQUIRE_TRUE(false, 0, + "TruncatedNormal requires either TArgs or 3 arguments to " + "be present"); + } + + REQUIRE_TRUE(input->isVector(), 0, + "TruncatedNormal requires pure shape as first argument"); + + std::vector shape(input->lengthOf()); + for (int e = 0; e < input->lengthOf(); e++) + shape[e] = input->e(e); + + auto z = OUTPUT_VARIABLE( + 0); // NDArrayFactory::create_('c', shape, block.workspace()); + + RandomLauncher::fillTruncatedNormal( + block.launchContext(), block.randomGenerator(), z, mean, stdev); + } break; + case sd::random::AlphaDropOut: { + auto z = OUTPUT_VARIABLE(0); + + T prob, a, b, pa; + if (block.width() > 4) { + auto arg1 = INPUT_VARIABLE(1); + auto arg2 = INPUT_VARIABLE(2); + auto arg3 = INPUT_VARIABLE(3); + auto arg4 = INPUT_VARIABLE(4); + REQUIRE_TRUE(arg1->isScalar(), 0, + "AlphaDropOut: Second argument must be scalar"); + REQUIRE_TRUE(arg2->isScalar(), 0, + "AlphaDropOut: Third argument must be scalar"); + REQUIRE_TRUE(arg3->isScalar(), 0, + "AlphaDropOut: Fourth argument must be scalar"); + REQUIRE_TRUE(arg4->isScalar(), 0, + "AlphaDropOut: Fifth argument must be scalar"); + + prob = arg1->e(0); + a = arg2->e(0); + b = arg3->e(0); + pa = arg4->e(0); + } else if (block.numT() == 4) { + prob = T_ARG(0); + a = T_ARG(1); + b = T_ARG(2); + pa = T_ARG(3); + } else { + REQUIRE_TRUE( + false, 0, + "AlphaDropOut requires either TArgs or 5 arguments to be present"); + } + + if (!block.isInplace()) z->assign(input); + + RandomLauncher::applyAlphaDropOut( + block.launchContext(), block.randomGenerator(), z, prob, a, b, pa); + } break; + case sd::random::Linspace: { + auto z = OUTPUT_VARIABLE(0); + auto start = INPUT_VARIABLE(0); + auto finish = INPUT_VARIABLE(1); + auto numOfElements = INPUT_VARIABLE(2); + + z->linspace(start->e(0), + (finish->e(0) - start->e(0)) / + (numOfElements->e(0) - 1.)); + } break; + default: { + nd4j_printf("Unknown random op requested: [%i]\n", opNum); + return ND4J_STATUS_KERNEL_FAILURE; + } + } + + return Status::OK(); +} + +Nd4jStatus LegacyRandomOp::validateAndExecute(Context& block) { + // REQUIRE_TRUE(block.getRNG() != nullptr, 0, "RNG should be + // provided for LegacyRandomOp, but got NULL instead at node_%i", + // block.nodeId()) + + auto z = OUTPUT_VARIABLE(0); + BUILD_SINGLE_SELECTOR(z->dataType(), return validateAndExecute_, (block), + FLOAT_TYPES); +} + +/** + * For transform operations, output shape always equals to input shape. With + * just a few exclusions, like im2col and col2im. But these ops already have + * CustomOp implementations. + * + */ +ShapeList* LegacyRandomOp::calculateOutputShape(ShapeList* inputShape, + sd::graph::Context& block) { + auto inShape = inputShape->at(0); + auto xType = ArrayOptions::dataType(inShape); + Nd4jLong* newShape; + if (DataTypeUtils::isR(xType)) { + COPY_SHAPE(inShape, newShape); + + return SHAPELIST(CONSTANT(newShape)); + } else if (DataTypeUtils::isZ(xType)) { + auto zShapeArr = INPUT_VARIABLE(0); + auto zShapeVector = zShapeArr->asVectorT(); + auto dtype = block.numD() > 0 ? D_ARG(0) : sd::DataType::FLOAT32; + + return SHAPELIST(ConstantShapeHelper::getInstance().createShapeInfo( + dtype, 'c', zShapeVector)); + } else + throw std::runtime_error("LegacyRandomOp: Unknown input data type!"); +} + +Nd4jStatus LegacyRandomOp::execute(Context* block) { + return DeclarableOp::execute(block); +} + +sd::ResultSet LegacyRandomOp::execute(sd::graph::RandomGenerator& rng, + const std::vector& inputs, + const std::vector& tArgs, + const std::vector& iArgs, + const std::vector& dArgs, + bool isInplace) { + VariableSpace variableSpace; + ResultSet arrayList; + // ResultSet arrayList; + + if (isInplace) arrayList.setNonRemovable(); + + int cnt = -1; + std::vector in; + for (auto v : inputs) { + if (v == nullptr) continue; + + auto var = std::make_shared(*v, "", cnt); + + in.push_back(cnt); + variableSpace.putVariable(cnt--, var); + } + + Context block(1, &variableSpace, false); + // FIX ME: implement setRng method + block.setRng(rng); + block.fillInputs(in); + block.markInplace(isInplace); + + for (int e = 0; e < tArgs.size(); e++) block.appendT(tArgs.at(e)); + + for (int e = 0; e < iArgs.size(); e++) block.appendI(iArgs.at(e)); + + for (int e = 0; e < dArgs.size(); e++) block.appendD(dArgs.at(e)); + + Nd4jStatus status = this->execute(&block); + arrayList.setStatus(status); + if (status != ND4J_STATUS_OK) return arrayList; + + for (int e = 0; e < DataTypeUtils::max(); e++) { + std::pair pair(1, e); + if (variableSpace.hasVariable(pair)) { + auto var = variableSpace.getVariable(pair); + auto arr = var->getNDArray(); + if (!arr->isAttached()) { + var->markRemovable(false); + arrayList.push_back(*arr.get()); + } else { + arrayList.push_back(arr->detach()); + } + } else + break; + } + + return arrayList; +} + +Nd4jStatus LegacyRandomOp::validateDataTypes(Context& block) { + if (block.isFastPath()) { + // in this case we'll roll through pre-defined outputs + auto fpo = block.fastpath_out(); + for (auto v : fpo) { + if (v != nullptr) { + if (!v->isR()) return ND4J_STATUS_BAD_ARGUMENTS; + } } + } else { + std::pair pair(block.nodeId(), 0); + if (block.getVariableSpace()->hasVariable(pair)) { + auto var = block.variable(pair); + if (!var->hasNDArray()) return ND4J_STATUS_BAD_ARGUMENTS; + + auto arr = var->getNDArray(); + if (!arr->isR()) return ND4J_STATUS_BAD_ARGUMENTS; + } + } + + return Status::OK(); } + +BUILD_SINGLE_TEMPLATE(template Nd4jStatus LegacyRandomOp::validateAndExecute_, + (Context&), FLOAT_TYPES); +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/impl/LegacyReduce3Op.cpp b/libnd4j/include/ops/declarable/impl/LegacyReduce3Op.cpp index 700e0dba955c..67c2ff4a9bdc 100644 --- a/libnd4j/include/ops/declarable/impl/LegacyReduce3Op.cpp +++ b/libnd4j/include/ops/declarable/impl/LegacyReduce3Op.cpp @@ -18,106 +18,142 @@ // Created by raver119 on 17.10.2017. // -#include +#include +#include #include #include -#include -#include +#include namespace sd { - namespace ops { - Nd4jStatus LegacyReduce3Op::validateAndExecute(Context &block) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); - - NDArray::prepareSpecialUse({z}, {x, y}); - - int opNum = block.opNum() < 0 ? this->_opNum : block.opNum(); - - nd4j_debug("Executing LegacyReduce3Op: [%i]\n", opNum); - - ExtraArguments extras(*block.getTArguments()); - PointersManager manager(block.launchContext(), "LegacyReduce3Op"); - - if (x->isSameShape(y) && (block.getIArguments()->size() == 0 || (block.getIArguments()->size() == 1 && INT_ARG(0) == sd::DataTypeUtils::max()))) { - // reduce3 to scalar - NativeOpExecutioner::execReduce3Scalar(block.launchContext(), opNum, x->buffer(), x->shapeInfo(), x->specialBuffer(), x->specialShapeInfo(), - extras.argumentsAsT(z->dataType()), - y->buffer(), y->shapeInfo(), y->specialBuffer(), y->specialShapeInfo(), - z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo()); - } else { - std::vector dims(*block.getAxis()); - for (int e = 0; e < dims.size(); e++) - if (dims[e] < 0) - dims[e] += x->rankOf(); - - auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(x->shapeInfo(), dims); - auto packZ = sd::ConstantTadHelper::getInstance().tadForDimensions(z->shapeInfo(), dims); - - REQUIRE_TRUE(dims.size() > 0, 0, "Some dimensions requuired for reduction!"); - - auto xTadShape = Environment::getInstance().isCPU() ? packX.primaryShapeInfo() : packX.specialShapeInfo(); //(Nd4jLong *) manager.replicatePointer(tadX.tadOnlyShapeInfo, shape::shapeInfoByteLength(tadX.tadOnlyShapeInfo)); - auto xTadOffsets = Environment::getInstance().isCPU() ? packX.primaryOffsets() : packX.specialOffsets(); //(Nd4jLong *) manager.replicatePointer(tadX.tadOffsets, tadX.numTads * sizeof(Nd4jLong)); - - auto yTadShape = Environment::getInstance().isCPU() ? packZ.primaryShapeInfo() : packZ.specialOffsets(); //(Nd4jLong *) manager.replicatePointer(tadY.tadOnlyShapeInfo, shape::shapeInfoByteLength(tadY.tadOnlyShapeInfo)); - auto yTadOffsets = Environment::getInstance().isCPU() ? packZ.primaryOffsets() : packZ.specialOffsets(); //(Nd4jLong *) manager.replicatePointer(tadY.tadOffsets, tadY.numTads * sizeof(Nd4jLong)); - - NativeOpExecutioner::execReduce3(block.launchContext(), opNum, x->buffer(), x->shapeInfo(), x->specialBuffer(), x->specialShapeInfo(), - extras.argumentsAsT(z->dataType()), - y->buffer(), y->shapeInfo(), y->specialBuffer(), y->specialShapeInfo(), - z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), - dims.data(), dims.size(), xTadShape, xTadOffsets, yTadShape, yTadOffsets); - } - - manager.synchronize(); - STORE_RESULT(*z); - - return Status::OK(); - } - - LegacyReduce3Op::LegacyReduce3Op() : LegacyOp::LegacyOp(2) { - // - } - - LegacyReduce3Op::LegacyReduce3Op(int opNum) : LegacyOp::LegacyOp(2, opNum) { - // - } - - LegacyOp* LegacyReduce3Op::clone() { - return new LegacyReduce3Op(this->_opNum); - } - - /** - * For all reductions rules are simple: either you return scalar, or you return reduced NDArray. - * It solely depends on input shape, and requested dimensions - */ - ShapeList *LegacyReduce3Op::calculateOutputShape(ShapeList *inputShape, sd::graph::Context &block) { - auto xShape = inputShape->at(0); - auto yShape = inputShape->at(1); - - Nd4jLong *zShape = nullptr; - - if (shape::equalsSoft(xShape, yShape) && (block.getIArguments()->size() == 0 || (block.getIArguments()->size() == 1 && INT_ARG(0) == sd::DataTypeUtils::max()))) { - // reduce3 to scalar case - ALLOCATE(zShape, block.getWorkspace(), shape::shapeInfoLength(2), Nd4jLong); - zShape[0] = 2; - zShape[1] = 1; - zShape[2] = 1; - zShape[3] = 1; - zShape[4] = 1; - zShape[5] = 0; - zShape[6] = 1; - zShape[7] = 99; - } else { - auto array = new NDArray(nullptr, xShape, block.launchContext()); - - xShape = ShapeUtils::evalReduceShapeInfo('c', *block.getIArguments(), *array, false, true); - - delete array; - } - - return SHAPELIST(zShape); - } - } -} \ No newline at end of file +namespace ops { + +Nd4jStatus LegacyReduce3Op::validateAndExecute(Context &block) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); + + NDArray::prepareSpecialUse({z}, {x, y}); + + int opNum = block.opNum() < 0 ? this->_opNum : block.opNum(); + + nd4j_debug("Executing LegacyReduce3Op: [%i]\n", opNum); + + ExtraArguments extras(block.getTArguments()); + PointersManager manager(block.launchContext(), "LegacyReduce3Op"); + + if (x->isSameShape(y) && + (block.numI() == 0 || + (block.numI() == 1 && INT_ARG(0) == sd::DataTypeUtils::max()))) { + // reduce3 to scalar + NativeOpExecutioner::execReduce3Scalar( + block.launchContext(), opNum, x->buffer(), x->shapeInfo(), + x->specialBuffer(), x->specialShapeInfo(), + extras.argumentsAsT(z->dataType()), y->buffer(), y->shapeInfo(), + y->specialBuffer(), y->specialShapeInfo(), z->buffer(), z->shapeInfo(), + z->specialBuffer(), z->specialShapeInfo()); + } else { + std::vector dims(block.getAxis()); + for (int e = 0; e < dims.size(); e++) + if (dims[e] < 0) dims[e] += x->rankOf(); + + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions( + x->shapeInfo(), dims); + auto packZ = sd::ConstantTadHelper::getInstance().tadForDimensions( + z->shapeInfo(), dims); + + REQUIRE_TRUE(dims.size() > 0, 0, + "Some dimensions requuired for reduction!"); + + auto xTadShape = + Environment::getInstance().isCPU() + ? packX.primaryShapeInfo() + : packX + .specialShapeInfo(); //(Nd4jLong *) + // manager.replicatePointer(tadX.tadOnlyShapeInfo, + // shape::shapeInfoByteLength(tadX.tadOnlyShapeInfo)); + auto xTadOffsets = + Environment::getInstance().isCPU() + ? packX.primaryOffsets() + : packX + .specialOffsets(); //(Nd4jLong *) + // manager.replicatePointer(tadX.tadOffsets, + // tadX.numTads * sizeof(Nd4jLong)); + + auto yTadShape = + Environment::getInstance().isCPU() + ? packZ.primaryShapeInfo() + : packZ + .specialOffsets(); //(Nd4jLong *) + // manager.replicatePointer(tadY.tadOnlyShapeInfo, + // shape::shapeInfoByteLength(tadY.tadOnlyShapeInfo)); + auto yTadOffsets = + Environment::getInstance().isCPU() + ? packZ.primaryOffsets() + : packZ + .specialOffsets(); //(Nd4jLong *) + // manager.replicatePointer(tadY.tadOffsets, + // tadY.numTads * sizeof(Nd4jLong)); + + NativeOpExecutioner::execReduce3( + block.launchContext(), opNum, x->buffer(), x->shapeInfo(), + x->specialBuffer(), x->specialShapeInfo(), + extras.argumentsAsT(z->dataType()), y->buffer(), y->shapeInfo(), + y->specialBuffer(), y->specialShapeInfo(), z->buffer(), z->shapeInfo(), + z->specialBuffer(), z->specialShapeInfo(), dims.data(), dims.size(), + xTadShape, xTadOffsets, yTadShape, yTadOffsets); + } + + manager.synchronize(); + STORE_RESULT(*z); + + return Status::OK(); +} + +LegacyReduce3Op::LegacyReduce3Op() : LegacyOp::LegacyOp(2) { + // +} + +LegacyReduce3Op::LegacyReduce3Op(int opNum) : LegacyOp::LegacyOp(2, opNum) { + // +} + +LegacyOp *LegacyReduce3Op::clone() { return new LegacyReduce3Op(this->_opNum); } + +/** + * For all reductions rules are simple: either you return scalar, or you + * return reduced NDArray. It solely depends on input shape, and requested + * dimensions + */ +ShapeList *LegacyReduce3Op::calculateOutputShape(ShapeList *inputShape, + sd::graph::Context &block) { + auto xShape = inputShape->at(0); + auto yShape = inputShape->at(1); + + Nd4jLong *zShape = nullptr; + + if (shape::equalsSoft(xShape, yShape) && + (block.numI() == 0 || + (block.numI() == 1 && INT_ARG(0) == sd::DataTypeUtils::max()))) { + // reduce3 to scalar case + ALLOCATE(zShape, block.workspace(), shape::shapeInfoLength(2), Nd4jLong); + zShape[0] = 2; + zShape[1] = 1; + zShape[2] = 1; + zShape[3] = 1; + zShape[4] = 1; + zShape[5] = 0; + zShape[6] = 1; + zShape[7] = 99; + } else { + auto array = new NDArray(nullptr, xShape, block.launchContext()); + + xShape = ShapeUtils::evalReduceShapeInfo('c', block.getIArguments(), *array, + false, true); + + delete array; + } + + return SHAPELIST(zShape); +} +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/impl/LegacyReduceBoolOp.cpp b/libnd4j/include/ops/declarable/impl/LegacyReduceBoolOp.cpp index e16e71619065..1ce1943eba0d 100644 --- a/libnd4j/include/ops/declarable/impl/LegacyReduceBoolOp.cpp +++ b/libnd4j/include/ops/declarable/impl/LegacyReduceBoolOp.cpp @@ -18,134 +18,178 @@ // Created by raver119 on 16.10.2017. // -#include -#include -#include +#include #include #include -#include +#include +#include +#include namespace sd { - namespace ops { - LegacyReduceBoolOp::LegacyReduceBoolOp() : LegacyOp::LegacyOp(1) { - // - } - - LegacyReduceBoolOp::LegacyReduceBoolOp(int opNum) : LegacyOp::LegacyOp(1, opNum) { - //this->_opNum = opNum; - } - - LegacyOp* LegacyReduceBoolOp::clone() { - return new LegacyReduceBoolOp(this->_opNum); - } - - Nd4jStatus LegacyReduceBoolOp::validateAndExecute(Context &block) { - auto x = INPUT_VARIABLE(0); - - auto z = OUTPUT_VARIABLE(0); - - NDArray::prepareSpecialUse({z}, {x}); - - int opNum = block.opNum() < 0 ? this->_opNum : block.opNum(); - nd4j_debug("Executing LegacyReduceFloatOp: [%i]\n", opNum); - - auto axis = *block.getAxis(); - - bool allAxes = false; - - ExtraArguments extras(*block.getTArguments()); - PointersManager manager(block.launchContext(),"LegacyReduceBoolOp"); - - if (block.width() == 1) { - if (axis.size() == x->rankOf()) - allAxes = true; - - if ((axis.empty()) || - (axis.size() == 1 && axis[0] == sd::DataTypeUtils::max()) || allAxes) { - // scalar - NativeOpExecutioner::execReduceBoolScalar(block.launchContext(), opNum, x->buffer(), x->shapeInfo(), x->specialBuffer(), x->specialShapeInfo(), - extras.argumentsAsT(x->dataType()), z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo()); - } else { - // TAD - std::vector dims(axis); - - for (int e = 0; e < dims.size(); e++) - if (dims[e] < 0) - dims[e] += x->rankOf(); - - REQUIRE_TRUE(dims.size() > 0, 0, "Some dimensions required for reduction!"); - - auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(x->shapeInfo(), dims); - - auto pTadShape = Environment::getInstance().isCPU() ? packX.primaryShapeInfo() : packX.specialShapeInfo(); //manager.replicatePointer(tad.tadOnlyShapeInfo, shape::shapeInfoByteLength(tad.tadOnlyShapeInfo)); - auto pTadOffsets = Environment::getInstance().isCPU() ? packX.primaryOffsets() : packX.specialOffsets(); //manager.replicatePointer(tad.tadOffsets, tad.numTads * sizeof(Nd4jLong)); - - NativeOpExecutioner::execReduceBool(block.launchContext(), opNum, x->buffer(), x->shapeInfo(), x->specialBuffer(), x->specialShapeInfo(), - extras.argumentsAsT(x->dataType()), - z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), - dims.data(), (int) dims.size(), reinterpret_cast(pTadShape), reinterpret_cast(pTadOffsets)); - } - - STORE_RESULT(*z); - } else { - auto indices = INPUT_VARIABLE(1); - if (indices->lengthOf() == x->rankOf()) - allAxes = true; - - //indices->printIndexedBuffer("indices"); - - std::vector dims(indices->lengthOf()); - for (Nd4jLong e = 0; e < indices->lengthOf(); e++) { - // lol otherwise we segfault on macOS - int f = indices->e(e); - dims[e] = f >= 0 ? f : f += x->rankOf(); - } - - if ((block.getIArguments()->size() == 1 && INT_ARG(0) == sd::DataTypeUtils::max()) || allAxes) { - // scalar - NativeOpExecutioner::execReduceBoolScalar(block.launchContext(), opNum, x->buffer(), x->shapeInfo(), x->specialBuffer(), x->specialShapeInfo(), extras.argumentsAsT(x->dataType()), z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo()); - } else { - // TAD - if (indices->lengthOf() > 1) - std::sort(dims.begin(), dims.end()); +namespace ops { + +LegacyReduceBoolOp::LegacyReduceBoolOp() : LegacyOp::LegacyOp(1) { + // +} + +LegacyReduceBoolOp::LegacyReduceBoolOp(int opNum) + : LegacyOp::LegacyOp(1, opNum) { + // this->_opNum = opNum; +} + +LegacyOp *LegacyReduceBoolOp::clone() { + return new LegacyReduceBoolOp(this->_opNum); +} + +Nd4jStatus LegacyReduceBoolOp::validateAndExecute(Context &block) { + auto x = INPUT_VARIABLE(0); + + auto z = OUTPUT_VARIABLE(0); + + NDArray::prepareSpecialUse({z}, {x}); + + int opNum = block.opNum() < 0 ? this->_opNum : block.opNum(); + nd4j_debug("Executing LegacyReduceFloatOp: [%i]\n", opNum); + + auto axis = block.getAxis(); + + bool allAxes = false; + + ExtraArguments extras(block.getTArguments()); + PointersManager manager(block.launchContext(), "LegacyReduceBoolOp"); + + if (block.width() == 1) { + if (axis.size() == x->rankOf()) allAxes = true; + + if ((axis.empty()) || + (axis.size() == 1 && axis[0] == sd::DataTypeUtils::max()) || + allAxes) { + // scalar + NativeOpExecutioner::execReduceBoolScalar( + block.launchContext(), opNum, x->buffer(), x->shapeInfo(), + x->specialBuffer(), x->specialShapeInfo(), + extras.argumentsAsT(x->dataType()), z->buffer(), z->shapeInfo(), + z->specialBuffer(), z->specialShapeInfo()); + } else { + // TAD + std::vector dims(axis); + + for (int e = 0; e < dims.size(); e++) + if (dims[e] < 0) dims[e] += x->rankOf(); + + REQUIRE_TRUE(dims.size() > 0, 0, + "Some dimensions required for reduction!"); + + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions( + x->shapeInfo(), dims); + + auto pTadShape = + Environment::getInstance().isCPU() + ? packX.primaryShapeInfo() + : packX + .specialShapeInfo(); // manager.replicatePointer(tad.tadOnlyShapeInfo, + // shape::shapeInfoByteLength(tad.tadOnlyShapeInfo)); + auto pTadOffsets = + Environment::getInstance().isCPU() + ? packX.primaryOffsets() + : packX + .specialOffsets(); // manager.replicatePointer(tad.tadOffsets, + // tad.numTads * sizeof(Nd4jLong)); + + NativeOpExecutioner::execReduceBool( + block.launchContext(), opNum, x->buffer(), x->shapeInfo(), + x->specialBuffer(), x->specialShapeInfo(), + extras.argumentsAsT(x->dataType()), z->buffer(), z->shapeInfo(), + z->specialBuffer(), z->specialShapeInfo(), dims.data(), + (int)dims.size(), reinterpret_cast(pTadShape), + reinterpret_cast(pTadOffsets)); + } - REQUIRE_TRUE(dims.size() > 0, 0, "Some dimensions required for reduction!"); + STORE_RESULT(*z); + } else { + auto indices = INPUT_VARIABLE(1); + if (indices->lengthOf() == x->rankOf()) allAxes = true; - auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(x->shapeInfo(), dims); + // indices->printIndexedBuffer("indices"); - auto pTadShape = Environment::getInstance().isCPU() ? packX.primaryShapeInfo() : packX.specialShapeInfo(); //(Nd4jLong *) manager.replicatePointer(tad.tadOnlyShapeInfo, shape::shapeInfoByteLength(tad.tadOnlyShapeInfo)); - auto pTadOffsets = Environment::getInstance().isCPU() ? packX.primaryOffsets() : packX.specialOffsets(); //(Nd4jLong *) manager.replicatePointer(tad.tadOffsets, tad.numTads * sizeof(Nd4jLong)); + std::vector dims(indices->lengthOf()); + for (Nd4jLong e = 0; e < indices->lengthOf(); e++) { + // lol otherwise we segfault on macOS + int f = indices->e(e); + dims[e] = f >= 0 ? f : f += x->rankOf(); + } - NativeOpExecutioner::execReduceBool(block.launchContext(), opNum, x->buffer(), x->shapeInfo(), x->specialBuffer(), x->specialShapeInfo(), extras.argumentsAsT(x->dataType()), - z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), dims.data(), (int) dims.size(), pTadShape, pTadOffsets); - } - } + if ((block.numI() == 1 && INT_ARG(0) == sd::DataTypeUtils::max()) || + allAxes) { + // scalar + NativeOpExecutioner::execReduceBoolScalar( + block.launchContext(), opNum, x->buffer(), x->shapeInfo(), + x->specialBuffer(), x->specialShapeInfo(), + extras.argumentsAsT(x->dataType()), z->buffer(), z->shapeInfo(), + z->specialBuffer(), z->specialShapeInfo()); + } else { + // TAD + if (indices->lengthOf() > 1) std::sort(dims.begin(), dims.end()); + + REQUIRE_TRUE(dims.size() > 0, 0, + "Some dimensions required for reduction!"); + + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions( + x->shapeInfo(), dims); + + auto pTadShape = + Environment::getInstance().isCPU() + ? packX.primaryShapeInfo() + : packX + .specialShapeInfo(); //(Nd4jLong *) + // manager.replicatePointer(tad.tadOnlyShapeInfo, + // shape::shapeInfoByteLength(tad.tadOnlyShapeInfo)); + auto pTadOffsets = + Environment::getInstance().isCPU() + ? packX.primaryOffsets() + : packX + .specialOffsets(); //(Nd4jLong *) + // manager.replicatePointer(tad.tadOffsets, + // tad.numTads * sizeof(Nd4jLong)); + + NativeOpExecutioner::execReduceBool( + block.launchContext(), opNum, x->buffer(), x->shapeInfo(), + x->specialBuffer(), x->specialShapeInfo(), + extras.argumentsAsT(x->dataType()), z->buffer(), z->shapeInfo(), + z->specialBuffer(), z->specialShapeInfo(), dims.data(), + (int)dims.size(), pTadShape, pTadOffsets); + } + } - manager.synchronize(); - return Status::OK(); - } + manager.synchronize(); + return Status::OK(); +} - /** - * For all reductions rules are simple: either you return scalar, or you return reduced NDArray. - * It solely depends on input shape, and requested dimensions - */ - ShapeList *LegacyReduceBoolOp::calculateOutputShape(ShapeList *inputShape, sd::graph::Context &block) { - auto inShape = inputShape->at(0); +/** + * For all reductions rules are simple: either you return scalar, or you + * return reduced NDArray. It solely depends on input shape, and requested + * dimensions + */ +ShapeList *LegacyReduceBoolOp::calculateOutputShape(ShapeList *inputShape, + sd::graph::Context &block) { + auto inShape = inputShape->at(0); - Nd4jLong *newShape; + Nd4jLong *newShape; - bool allAxes = false; + bool allAxes = false; - auto keepDims = block.numB() > 0 ? B_ARG(0) : false; - auto newFormat = block.numB() > 1 ? B_ARG(1) : true; + auto keepDims = block.numB() > 0 ? B_ARG(0) : false; + auto newFormat = block.numB() > 1 ? B_ARG(1) : true; - auto axis = block.width() > 1 ? INPUT_VARIABLE(1)->asVectorT() : *block.getAxis(); + auto axis = + block.width() > 1 ? INPUT_VARIABLE(1)->asVectorT() : block.getAxis(); - if (axis.size() == shape::rank(inShape)) - allAxes = true; + if (axis.size() == shape::rank(inShape)) allAxes = true; - // in this case we're building proper shape for reduction - return SHAPELIST(ShapeUtils::evalReduceShapeInfo(shape::order(inShape), axis, inShape, DataType::BOOL, keepDims, !newFormat, block.workspace())); - } - } -} \ No newline at end of file + // in this case we're building proper shape for reduction + return SHAPELIST(ShapeUtils::evalReduceShapeInfo( + shape::order(inShape), axis, inShape, DataType::BOOL, keepDims, + !newFormat, block.workspace())); +} +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/impl/LegacyReduceFloatOp.cpp b/libnd4j/include/ops/declarable/impl/LegacyReduceFloatOp.cpp index a0ff14858823..7ddd8b962bba 100644 --- a/libnd4j/include/ops/declarable/impl/LegacyReduceFloatOp.cpp +++ b/libnd4j/include/ops/declarable/impl/LegacyReduceFloatOp.cpp @@ -18,135 +18,175 @@ // Created by raver119 on 16.10.2017. // -#include -#include -#include +#include #include #include -#include +#include +#include +#include namespace sd { - namespace ops { - LegacyReduceFloatOp::LegacyReduceFloatOp() : LegacyOp::LegacyOp(1) { - // - } - - LegacyReduceFloatOp::LegacyReduceFloatOp(int opNum) : LegacyOp::LegacyOp(1, opNum) { - //this->_opNum = opNum; - } - - LegacyOp* LegacyReduceFloatOp::clone() { - return new LegacyReduceFloatOp(this->_opNum); - } - - Nd4jStatus LegacyReduceFloatOp::validateAndExecute(Context &block) { - auto x = INPUT_VARIABLE(0); - - auto z = OUTPUT_VARIABLE(0); - - NDArray::prepareSpecialUse({z}, {x}); - - int opNum = block.opNum() < 0 ? this->_opNum : block.opNum(); - nd4j_debug("Executing LegacyReduceFloatOp: [%i]\n", opNum); - - bool allAxes = false; - auto axis = *block.getAxis(); - - ExtraArguments extras(*block.getTArguments()); - PointersManager manager(block.launchContext(), "LegacyReduceFloatOp"); - - if (block.width() == 1) { - - if (axis.size() == x->rankOf()) - allAxes = true; - - // _axis.(block.getIArguments()->size() == 0) || - // (block.getIArguments()->size() == 1 && INT_ARG(0) == sd::DataTypeUtils::max()) - if (block.getAxis()->empty() || allAxes) { - // scalar - NativeOpExecutioner::execReduceFloatScalar(block.launchContext(), opNum, x->buffer(), x->shapeInfo(), x->specialBuffer(), x->specialShapeInfo(), - extras.argumentsAsT(z->dataType()), z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo()); - } else { - // TAD - std::vector dims(*block.getAxis()); - - for (int e = 0; e < dims.size(); e++) - if (dims[e] < 0) - dims[e] += x->rankOf(); - - REQUIRE_TRUE(dims.size() > 0, 0, "Some dimensions required for reduction!"); - - auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(x->shapeInfo(), dims); - - auto pTadShape = Environment::getInstance().isCPU() ? packX.primaryShapeInfo() : packX.specialShapeInfo(); //manager.replicatePointer(tad.tadOnlyShapeInfo, shape::shapeInfoByteLength(tad.tadOnlyShapeInfo)); - auto pTadOffsets = Environment::getInstance().isCPU() ? packX.primaryOffsets() : packX.specialOffsets(); //manager.replicatePointer(tad.tadOffsets, tad.numTads * sizeof(Nd4jLong)); - - NativeOpExecutioner::execReduceFloat(block.launchContext(), opNum, x->buffer(), x->shapeInfo(), x->specialBuffer(), x->specialShapeInfo(), - extras.argumentsAsT(z->dataType()), z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), - dims.data(), (int) dims.size(), reinterpret_cast(pTadShape), reinterpret_cast(pTadOffsets)); - - } - - STORE_RESULT(*z); - } else { - auto indices = INPUT_VARIABLE(1); - if (indices->lengthOf() == x->rankOf()) - allAxes = true; - - //indices->printIndexedBuffer("indices"); - - std::vector dims(indices->lengthOf()); - for (int e = 0; e < indices->lengthOf(); e++) { - // lol otherwise we segfault on macOS - int f = indices->e(e); - dims[e] = f >= 0 ? f : f += x->rankOf(); - } - - if ((block.getIArguments()->size() == 1 && INT_ARG(0) == sd::DataTypeUtils::max()) || allAxes) { - // scalar - NativeOpExecutioner::execReduceFloatScalar(block.launchContext(), opNum, x->buffer(), x->shapeInfo(), x->specialBuffer(), x->specialShapeInfo(), extras.argumentsAsT(x->dataType()), z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo()); - } else { - // TAD - REQUIRE_TRUE(dims.size() > 0, 0, "Some dimensions required for reduction!"); - - auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(x->shapeInfo(), dims); +namespace ops { +LegacyReduceFloatOp::LegacyReduceFloatOp() : LegacyOp::LegacyOp(1) { + // +} + +LegacyReduceFloatOp::LegacyReduceFloatOp(int opNum) + : LegacyOp::LegacyOp(1, opNum) { + // this->_opNum = opNum; +} + +LegacyOp *LegacyReduceFloatOp::clone() { + return new LegacyReduceFloatOp(this->_opNum); +} + +Nd4jStatus LegacyReduceFloatOp::validateAndExecute(Context &block) { + auto x = INPUT_VARIABLE(0); + + auto z = OUTPUT_VARIABLE(0); + + NDArray::prepareSpecialUse({z}, {x}); + + int opNum = block.opNum() < 0 ? this->_opNum : block.opNum(); + nd4j_debug("Executing LegacyReduceFloatOp: [%i]\n", opNum); + + bool allAxes = false; + auto axis = block.getAxis(); + + ExtraArguments extras(block.getTArguments()); + PointersManager manager(block.launchContext(), "LegacyReduceFloatOp"); + + if (block.width() == 1) { + if (axis.size() == x->rankOf()) allAxes = true; + + // _axis.(block.getIArguments()->size() == 0) || + // (block.getIArguments()->size() == 1 && INT_ARG(0) == + // sd::DataTypeUtils::max()) + if (block.getAxis().empty() || allAxes) { + // scalar + NativeOpExecutioner::execReduceFloatScalar( + block.launchContext(), opNum, x->buffer(), x->shapeInfo(), + x->specialBuffer(), x->specialShapeInfo(), + extras.argumentsAsT(z->dataType()), z->buffer(), z->shapeInfo(), + z->specialBuffer(), z->specialShapeInfo()); + } else { + // TAD + std::vector dims(block.getAxis()); + + for (int e = 0; e < dims.size(); e++) + if (dims[e] < 0) dims[e] += x->rankOf(); + + REQUIRE_TRUE(dims.size() > 0, 0, + "Some dimensions required for reduction!"); + + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions( + x->shapeInfo(), dims); + + auto pTadShape = + Environment::getInstance().isCPU() + ? packX.primaryShapeInfo() + : packX + .specialShapeInfo(); // manager.replicatePointer(tad.tadOnlyShapeInfo, + // shape::shapeInfoByteLength(tad.tadOnlyShapeInfo)); + auto pTadOffsets = + Environment::getInstance().isCPU() + ? packX.primaryOffsets() + : packX + .specialOffsets(); // manager.replicatePointer(tad.tadOffsets, + // tad.numTads * sizeof(Nd4jLong)); + + NativeOpExecutioner::execReduceFloat( + block.launchContext(), opNum, x->buffer(), x->shapeInfo(), + x->specialBuffer(), x->specialShapeInfo(), + extras.argumentsAsT(z->dataType()), z->buffer(), z->shapeInfo(), + z->specialBuffer(), z->specialShapeInfo(), dims.data(), + (int)dims.size(), reinterpret_cast(pTadShape), + reinterpret_cast(pTadOffsets)); + } - auto pTadShape = Environment::getInstance().isCPU() ? packX.primaryShapeInfo() : packX.specialShapeInfo(); //(Nd4jLong *) manager.replicatePointer(tad.tadOnlyShapeInfo, shape::shapeInfoByteLength(tad.tadOnlyShapeInfo)); - auto pTadOffsets = Environment::getInstance().isCPU() ? packX.primaryOffsets() : packX.specialOffsets(); //(Nd4jLong *) manager.replicatePointer(tad.tadOffsets, tad.numTads * sizeof(Nd4jLong)); + STORE_RESULT(*z); + } else { + auto indices = INPUT_VARIABLE(1); + if (indices->lengthOf() == x->rankOf()) allAxes = true; - NativeOpExecutioner::execReduceFloat(block.launchContext(), opNum, x->buffer(), x->shapeInfo(), x->specialBuffer(), x->specialShapeInfo(), - extras.argumentsAsT(z->dataType()), z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), - dims.data(), (int) dims.size(), pTadShape, pTadOffsets); + // indices->printIndexedBuffer("indices"); + std::vector dims(indices->lengthOf()); + for (int e = 0; e < indices->lengthOf(); e++) { + // lol otherwise we segfault on macOS + int f = indices->e(e); + dims[e] = f >= 0 ? f : f += x->rankOf(); + } - } - } + if ((block.numI() == 1 && INT_ARG(0) == sd::DataTypeUtils::max()) || + allAxes) { + // scalar + NativeOpExecutioner::execReduceFloatScalar( + block.launchContext(), opNum, x->buffer(), x->shapeInfo(), + x->specialBuffer(), x->specialShapeInfo(), + extras.argumentsAsT(x->dataType()), z->buffer(), z->shapeInfo(), + z->specialBuffer(), z->specialShapeInfo()); + } else { + // TAD + REQUIRE_TRUE(dims.size() > 0, 0, + "Some dimensions required for reduction!"); + + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions( + x->shapeInfo(), dims); + + auto pTadShape = + Environment::getInstance().isCPU() + ? packX.primaryShapeInfo() + : packX + .specialShapeInfo(); //(Nd4jLong *) + // manager.replicatePointer(tad.tadOnlyShapeInfo, + // shape::shapeInfoByteLength(tad.tadOnlyShapeInfo)); + auto pTadOffsets = + Environment::getInstance().isCPU() + ? packX.primaryOffsets() + : packX + .specialOffsets(); //(Nd4jLong *) + // manager.replicatePointer(tad.tadOffsets, + // tad.numTads * sizeof(Nd4jLong)); + + NativeOpExecutioner::execReduceFloat( + block.launchContext(), opNum, x->buffer(), x->shapeInfo(), + x->specialBuffer(), x->specialShapeInfo(), + extras.argumentsAsT(z->dataType()), z->buffer(), z->shapeInfo(), + z->specialBuffer(), z->specialShapeInfo(), dims.data(), + (int)dims.size(), pTadShape, pTadOffsets); + } + } - manager.synchronize(); - return Status::OK(); - } + manager.synchronize(); + return Status::OK(); +} - /** - * For all reductions rules are simple: either you return scalar, or you return reduced NDArray. - * It solely depends on input shape, and requested dimensions - */ - ShapeList *LegacyReduceFloatOp::calculateOutputShape(ShapeList *inputShape, sd::graph::Context &block) { - auto inShape = inputShape->at(0); +/** + * For all reductions rules are simple: either you return scalar, or you + * return reduced NDArray. It solely depends on input shape, and requested + * dimensions + */ +ShapeList *LegacyReduceFloatOp::calculateOutputShape( + ShapeList *inputShape, sd::graph::Context &block) { + auto inShape = inputShape->at(0); - bool allAxes = false; + bool allAxes = false; - auto keepDims = block.numB() > 0 ? B_ARG(0) : false; - auto newFormat = block.numB() > 1 ? B_ARG(1) : true; + auto keepDims = block.numB() > 0 ? B_ARG(0) : false; + auto newFormat = block.numB() > 1 ? B_ARG(1) : true; - auto axis = block.width() > 1 ? INPUT_VARIABLE(1)->asVectorT() : *block.getAxis(); + auto axis = + block.width() > 1 ? INPUT_VARIABLE(1)->asVectorT() : block.getAxis(); - if (axis.size() == shape::rank(inShape)) - allAxes = true; + if (axis.size() == shape::rank(inShape)) allAxes = true; - // in this case we're building proper shape for reduction - auto newShape = ShapeUtils::evalReduceShapeInfo(shape::order(inShape), axis, inShape, keepDims, !newFormat, block.workspace()); + // in this case we're building proper shape for reduction + auto newShape = + ShapeUtils::evalReduceShapeInfo(shape::order(inShape), axis, inShape, + keepDims, !newFormat, block.workspace()); - return SHAPELIST(newShape); - } - } -} \ No newline at end of file + return SHAPELIST(newShape); +} +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/impl/LegacyReduceLongOp.cpp b/libnd4j/include/ops/declarable/impl/LegacyReduceLongOp.cpp index f5007ff039ef..93e7d1b05ce8 100644 --- a/libnd4j/include/ops/declarable/impl/LegacyReduceLongOp.cpp +++ b/libnd4j/include/ops/declarable/impl/LegacyReduceLongOp.cpp @@ -18,135 +18,177 @@ // Created by raver119 on 16.10.2017. // -#include -#include -#include +#include #include #include -#include +#include +#include +#include namespace sd { - namespace ops { - LegacyReduceLongOp::LegacyReduceLongOp() : LegacyOp::LegacyOp(1) { - // - } - - LegacyReduceLongOp::LegacyReduceLongOp(int opNum) : LegacyOp::LegacyOp(1, opNum) { - //this->_opNum = opNum; - } - - LegacyOp* LegacyReduceLongOp::clone() { - return new LegacyReduceLongOp(this->_opNum); - } - - Nd4jStatus LegacyReduceLongOp::validateAndExecute(Context &block) { - auto x = INPUT_VARIABLE(0); - - auto z = OUTPUT_VARIABLE(0); - - NDArray::prepareSpecialUse({z}, {x}); - - int opNum = block.opNum() < 0 ? this->_opNum : block.opNum(); - nd4j_debug("Executing LegacyReduceFloatOp: [%i]\n", opNum); - - auto axis = *block.getAxis(); - bool allAxes = false; - - ExtraArguments extras(*block.getTArguments()); - PointersManager manager(block.launchContext(),"LegacyReduceLongOp"); - - if (block.width() == 1) { - - if (axis.size() == x->rankOf()) - allAxes = true; - - if ((axis.empty()) || - (axis.size() == 1 && axis[0] == sd::DataTypeUtils::max()) || allAxes) { - // scalar - NativeOpExecutioner::execReduceLongScalar(block.launchContext(), opNum, x->buffer(), x->shapeInfo(), x->specialBuffer(), x->specialShapeInfo(), - extras.argumentsAsT(x->dataType()), z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo()); - } else { - // TAD - std::vector dims(axis); - - for (int e = 0; e < dims.size(); e++) - if (dims[e] < 0) - dims[e] += x->rankOf(); - - if (dims.size() > 1) - std::sort(dims.begin(), dims.end()); - - REQUIRE_TRUE(dims.size() > 0, 0, "Some dimensions required for reduction!"); - - auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(x->shapeInfo(), dims); - - auto pTadShape = Environment::getInstance().isCPU() ? packX.primaryShapeInfo() : packX.specialShapeInfo(); //(Nd4jLong *) manager.replicatePointer(tad.tadOnlyShapeInfo, shape::shapeInfoByteLength(tad.tadOnlyShapeInfo)); - auto pTadOffsets = Environment::getInstance().isCPU() ? packX.primaryOffsets() : packX.specialOffsets(); //(Nd4jLong *) manager.replicatePointer(tad.tadOffsets, tad.numTads * sizeof(Nd4jLong)); - - NativeOpExecutioner::execReduceLong(block.launchContext(), opNum, x->buffer(), x->shapeInfo(), x->specialBuffer(), x->specialShapeInfo(), - extras.argumentsAsT(x->dataType()), - z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), - dims.data(), (int) dims.size(), pTadShape, pTadOffsets); - } - - STORE_RESULT(*z); - } else { - auto indices = INPUT_VARIABLE(1); - if (indices->lengthOf() == x->rankOf()) - allAxes = true; - - //indices->printIndexedBuffer("indices"); - - std::vector dims(indices->lengthOf()); - for (int e = 0; e < indices->lengthOf(); e++) { - // lol otherwise we segfault on macOS - int f = indices->e(e); - dims[e] = f >= 0 ? f : f += x->rankOf(); - } - - if ((block.getIArguments()->size() == 1 && INT_ARG(0) == sd::DataTypeUtils::max()) || allAxes) { - // scalar - NativeOpExecutioner::execReduceLongScalar(block.launchContext(), opNum, x->buffer(), x->shapeInfo(), x->specialBuffer(), x->specialShapeInfo(), extras.argumentsAsT(x->dataType()), z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo()); - } else { - // TAD - REQUIRE_TRUE(dims.size() > 0, 0, "Some dimensions required for reduction!"); +namespace ops { +LegacyReduceLongOp::LegacyReduceLongOp() : LegacyOp::LegacyOp(1) { + // +} + +LegacyReduceLongOp::LegacyReduceLongOp(int opNum) + : LegacyOp::LegacyOp(1, opNum) { + // this->_opNum = opNum; +} + +LegacyOp *LegacyReduceLongOp::clone() { + return new LegacyReduceLongOp(this->_opNum); +} + +Nd4jStatus LegacyReduceLongOp::validateAndExecute(Context &block) { + auto x = INPUT_VARIABLE(0); + + auto z = OUTPUT_VARIABLE(0); + + NDArray::prepareSpecialUse({z}, {x}); + + int opNum = block.opNum() < 0 ? this->_opNum : block.opNum(); + nd4j_debug("Executing LegacyReduceFloatOp: [%i]\n", opNum); + + auto axis = block.getAxis(); + bool allAxes = false; + + ExtraArguments extras(block.getTArguments()); + PointersManager manager(block.launchContext(), "LegacyReduceLongOp"); + + if (block.width() == 1) { + if (axis.size() == x->rankOf()) allAxes = true; + + if ((axis.empty()) || + (axis.size() == 1 && axis[0] == sd::DataTypeUtils::max()) || + allAxes) { + // scalar + NativeOpExecutioner::execReduceLongScalar( + block.launchContext(), opNum, x->buffer(), x->shapeInfo(), + x->specialBuffer(), x->specialShapeInfo(), + extras.argumentsAsT(x->dataType()), z->buffer(), z->shapeInfo(), + z->specialBuffer(), z->specialShapeInfo()); + } else { + // TAD + std::vector dims(axis); + + for (int e = 0; e < dims.size(); e++) + if (dims[e] < 0) dims[e] += x->rankOf(); + + if (dims.size() > 1) std::sort(dims.begin(), dims.end()); + + REQUIRE_TRUE(dims.size() > 0, 0, + "Some dimensions required for reduction!"); + + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions( + x->shapeInfo(), dims); + + auto pTadShape = + Environment::getInstance().isCPU() + ? packX.primaryShapeInfo() + : packX + .specialShapeInfo(); //(Nd4jLong *) + // manager.replicatePointer(tad.tadOnlyShapeInfo, + // shape::shapeInfoByteLength(tad.tadOnlyShapeInfo)); + auto pTadOffsets = + Environment::getInstance().isCPU() + ? packX.primaryOffsets() + : packX + .specialOffsets(); //(Nd4jLong *) + // manager.replicatePointer(tad.tadOffsets, + // tad.numTads * sizeof(Nd4jLong)); + + NativeOpExecutioner::execReduceLong( + block.launchContext(), opNum, x->buffer(), x->shapeInfo(), + x->specialBuffer(), x->specialShapeInfo(), + extras.argumentsAsT(x->dataType()), z->buffer(), z->shapeInfo(), + z->specialBuffer(), z->specialShapeInfo(), dims.data(), + (int)dims.size(), pTadShape, pTadOffsets); + } - auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(x->shapeInfo(), dims); + STORE_RESULT(*z); + } else { + auto indices = INPUT_VARIABLE(1); + if (indices->lengthOf() == x->rankOf()) allAxes = true; - auto pTadShape = Environment::getInstance().isCPU() ? packX.primaryShapeInfo() : packX.specialShapeInfo(); //(Nd4jLong *) manager.replicatePointer(tad.tadOnlyShapeInfo, shape::shapeInfoByteLength(tad.tadOnlyShapeInfo)); - auto pTadOffsets = Environment::getInstance().isCPU() ? packX.primaryOffsets() : packX.specialOffsets(); //(Nd4jLong *) manager.replicatePointer(tad.tadOffsets, tad.numTads * sizeof(Nd4jLong)); + // indices->printIndexedBuffer("indices"); - NativeOpExecutioner::execReduceLong(block.launchContext(), opNum, x->buffer(), x->shapeInfo(), x->specialBuffer(), x->specialShapeInfo(), extras.argumentsAsT(x->dataType()), - z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), dims.data(), (int) dims.size(), pTadShape, pTadOffsets); + std::vector dims(indices->lengthOf()); + for (int e = 0; e < indices->lengthOf(); e++) { + // lol otherwise we segfault on macOS + int f = indices->e(e); + dims[e] = f >= 0 ? f : f += x->rankOf(); + } - } - } + if ((block.numI() == 1 && INT_ARG(0) == sd::DataTypeUtils::max()) || + allAxes) { + // scalar + NativeOpExecutioner::execReduceLongScalar( + block.launchContext(), opNum, x->buffer(), x->shapeInfo(), + x->specialBuffer(), x->specialShapeInfo(), + extras.argumentsAsT(x->dataType()), z->buffer(), z->shapeInfo(), + z->specialBuffer(), z->specialShapeInfo()); + } else { + // TAD + REQUIRE_TRUE(dims.size() > 0, 0, + "Some dimensions required for reduction!"); + + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions( + x->shapeInfo(), dims); + + auto pTadShape = + Environment::getInstance().isCPU() + ? packX.primaryShapeInfo() + : packX + .specialShapeInfo(); //(Nd4jLong *) + // manager.replicatePointer(tad.tadOnlyShapeInfo, + // shape::shapeInfoByteLength(tad.tadOnlyShapeInfo)); + auto pTadOffsets = + Environment::getInstance().isCPU() + ? packX.primaryOffsets() + : packX + .specialOffsets(); //(Nd4jLong *) + // manager.replicatePointer(tad.tadOffsets, + // tad.numTads * sizeof(Nd4jLong)); + + NativeOpExecutioner::execReduceLong( + block.launchContext(), opNum, x->buffer(), x->shapeInfo(), + x->specialBuffer(), x->specialShapeInfo(), + extras.argumentsAsT(x->dataType()), z->buffer(), z->shapeInfo(), + z->specialBuffer(), z->specialShapeInfo(), dims.data(), + (int)dims.size(), pTadShape, pTadOffsets); + } + } - manager.synchronize(); - return Status::OK(); - } + manager.synchronize(); + return Status::OK(); +} - /** - * For all reductions rules are simple: either you return scalar, or you return reduced NDArray. - * It solely depends on input shape, and requested dimensions - */ - ShapeList *LegacyReduceLongOp::calculateOutputShape(ShapeList *inputShape, sd::graph::Context &block) { - auto inShape = inputShape->at(0); +/** + * For all reductions rules are simple: either you return scalar, or you + * return reduced NDArray. It solely depends on input shape, and requested + * dimensions + */ +ShapeList *LegacyReduceLongOp::calculateOutputShape(ShapeList *inputShape, + sd::graph::Context &block) { + auto inShape = inputShape->at(0); - Nd4jLong *newShape; + Nd4jLong *newShape; - bool allAxes = false; + bool allAxes = false; - auto keepDims = block.numB() > 0 ? B_ARG(0) : false; - auto newFormat = block.numB() > 1 ? B_ARG(1) : true; + auto keepDims = block.numB() > 0 ? B_ARG(0) : false; + auto newFormat = block.numB() > 1 ? B_ARG(1) : true; - auto axis = block.width() > 1 ? INPUT_VARIABLE(1)->asVectorT() : *block.getAxis(); + auto axis = + block.width() > 1 ? INPUT_VARIABLE(1)->asVectorT() : block.getAxis(); - if (axis.size() == shape::rank(inShape)) - allAxes = true; + if (axis.size() == shape::rank(inShape)) allAxes = true; - // in this case we're building proper shape for reduction - return SHAPELIST(ShapeUtils::evalReduceShapeInfo(shape::order(inShape), axis, inShape, DataType::INT64, keepDims, !newFormat, block.workspace())); - } - } -} \ No newline at end of file + // in this case we're building proper shape for reduction + return SHAPELIST(ShapeUtils::evalReduceShapeInfo( + shape::order(inShape), axis, inShape, DataType::INT64, keepDims, + !newFormat, block.workspace())); +} +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/impl/LegacyReduceOp.cpp b/libnd4j/include/ops/declarable/impl/LegacyReduceOp.cpp index e6c3dd63b295..363dbc294847 100644 --- a/libnd4j/include/ops/declarable/impl/LegacyReduceOp.cpp +++ b/libnd4j/include/ops/declarable/impl/LegacyReduceOp.cpp @@ -18,170 +18,186 @@ // Created by raver119 on 16.10.2017. // -#include -#include #include +#include +#include #ifdef LEGACY_REDUCE_SAME_ONLY namespace sd { - namespace ops { - LegacyReduceOp::LegacyReduceOp() : LegacyOp::LegacyOp(1) { - // - } - - LegacyReduceOp::LegacyReduceOp(int opNum) : LegacyOp::LegacyOp(1, opNum) { - //this->_opNum = opNum; - } - - LegacyOp* LegacyReduceOp::clone() { - return new LegacyReduceOp(this->_opNum); - } - - Nd4jStatus LegacyReduceOp::validateAndExecute(Context &block) { - auto x = INPUT_VARIABLE(0); - - - int opNum = block.opNum() < 0 ? this->_opNum : block.opNum(); - nd4j_debug("Executing LegacyReduceOp: [%i]\n", opNum); - - bool allAxes = false; - - if (block.width() == 1) { - auto z = OUTPUT_VARIABLE(0); +namespace ops { +LegacyReduceOp::LegacyReduceOp() : LegacyOp::LegacyOp(1) { + // +} - if (block.getIArguments()->size() == x->rankOf()) - allAxes = true; +LegacyReduceOp::LegacyReduceOp(int opNum) : LegacyOp::LegacyOp(1, opNum) { + // this->_opNum = opNum; +} - if ((block.getIArguments()->size() == 0) || - (block.getIArguments()->size() == 1 && INT_ARG(0) == MAX_INT) || allAxes) { - // scalar - NativeOpExcutioner::execReduceFloatScalar(opNum, x->buffer(), x->shapeInfo(), block.getTArguments()->data(), z->buffer(), z->shapeInfo()); - } else { - // TAD - std::vector dims(*block.getIArguments()); +LegacyOp *LegacyReduceOp::clone() { return new LegacyReduceOp(this->_opNum); } - for (int e = 0; e < dims.size(); e++) - if (dims[e] < 0) - dims[e] += x->rankOf(); +Nd4jStatus LegacyReduceOp::validateAndExecute(Context &block) { + auto x = INPUT_VARIABLE(0); - std::sort(dims.begin(), dims.end()); + int opNum = block.opNum() < 0 ? this->_opNum : block.opNum(); + nd4j_debug("Executing LegacyReduceOp: [%i]\n", opNum); - REQUIRE_TRUE(dims.size() > 0, 0, "Some dimensions required for reduction!"); + bool allAxes = false; - shape::TAD tad(x->shapeInfo(), dims.data(), dims.size()); - tad.createTadOnlyShapeInfo(); - tad.createOffsets(); + if (block.width() == 1) { + auto z = OUTPUT_VARIABLE(0); - NativeOpExcutioner::execReduceFloat(opNum, x->buffer(), x->shapeInfo(), block.getTArguments()->data(), z->buffer(), z->shapeInfo(), dims.data(), (int) dims.size(), tad.tadOnlyShapeInfo, tad.tadOffsets); - } + if (block.getIArguments()->size() == x->rankOf()) allAxes = true; - STORE_RESULT(*z); - } else { - auto indices = INPUT_VARIABLE(1); - if (indices->lengthOf() == x->rankOf()) - allAxes = true; + if ((block.getIArguments()->size() == 0) || + (block.getIArguments()->size() == 1 && INT_ARG(0) == MAX_INT) || + allAxes) { + // scalar + NativeOpExcutioner::execReduceFloatScalar( + opNum, x->buffer(), x->shapeInfo(), block.getTArguments()->data(), + z->buffer(), z->shapeInfo()); + } else { + // TAD + std::vector dims(*block.getIArguments()); - //indices->printIndexedBuffer("indices"); + for (int e = 0; e < dims.size(); e++) + if (dims[e] < 0) dims[e] += x->rankOf(); - std::vector axis(indices->lengthOf()); - for (int e = 0; e < indices->lengthOf(); e++) { - // lol otherwise we segfault on macOS - int f = indices->e(e); - axis[e] = f >= 0 ? f : f += x->rankOf(); - } + std::sort(dims.begin(), dims.end()); - if ((block.getIArguments()->size() == 1 && INT_ARG(0) == MAX_INT) || allAxes) { - auto z = OUTPUT_VARIABLE(0); + REQUIRE_TRUE(dims.size() > 0, 0, + "Some dimensions required for reduction!"); - auto b = x->buffer(); - auto s = x->shapeInfo(); - auto e = block.numT() > 0 ? block.getTArguments()->data() : nullptr; + shape::TAD tad(x->shapeInfo(), dims.data(), dims.size()); + tad.createTadOnlyShapeInfo(); + tad.createOffsets(); - //x->printIndexedBuffer("x"); + NativeOpExcutioner::execReduceFloat( + opNum, x->buffer(), x->shapeInfo(), block.getTArguments()->data(), + z->buffer(), z->shapeInfo(), dims.data(), (int)dims.size(), + tad.tadOnlyShapeInfo, tad.tadOffsets); + } - // scalar - NativeOpExcutioner::execReduceFloatScalar(opNum, b, s, e, z->buffer(), z->shapeInfo()); - } else { - // TAD - if (indices->lengthOf() > 1) - std::sort(axis.begin(), axis.end()); + STORE_RESULT(*z); + } else { + auto indices = INPUT_VARIABLE(1); + if (indices->lengthOf() == x->rankOf()) allAxes = true; - REQUIRE_TRUE(axis.size() > 0, 0, "Some dimensions required for reduction!"); + // indices->printIndexedBuffer("indices"); - shape::TAD tad(x->shapeInfo(), axis.data(), axis.size()); - tad.createTadOnlyShapeInfo(); - tad.createOffsets(); + std::vector axis(indices->lengthOf()); + for (int e = 0; e < indices->lengthOf(); e++) { + // lol otherwise we segfault on macOS + int f = indices->e(e); + axis[e] = f >= 0 ? f : f += x->rankOf(); + } - auto newShape = ShapeUtils::evalReduceShapeInfo(x->ordering(), axis, *x); - auto z = new NDArray(newShape, x->getWorkspace()); + if ((block.getIArguments()->size() == 1 && INT_ARG(0) == MAX_INT) || + allAxes) { + auto z = OUTPUT_VARIABLE(0); + + auto b = x->buffer(); + auto s = x->shapeInfo(); + auto e = block.numT() > 0 ? block.getTArguments()->data() : nullptr; + + // x->printIndexedBuffer("x"); + + // scalar + NativeOpExcutioner::execReduceFloatScalar(opNum, b, s, e, z->buffer(), + z->shapeInfo()); + } else { + // TAD + if (indices->lengthOf() > 1) std::sort(axis.begin(), axis.end()); + + REQUIRE_TRUE(axis.size() > 0, 0, + "Some dimensions required for reduction!"); + + shape::TAD tad(x->shapeInfo(), axis.data(), axis.size()); + tad.createTadOnlyShapeInfo(); + tad.createOffsets(); + + auto newShape = ShapeUtils::evalReduceShapeInfo(x->ordering(), axis, *x); + auto z = new NDArray(newShape, x->getWorkspace()); + + NativeOpExcutioner::execReduceFloat( + opNum, x->buffer(), x->shapeInfo(), block.getTArguments()->data(), + z->buffer(), z->shapeInfo(), axis.data(), (int)axis.size(), + tad.tadOnlyShapeInfo, tad.tadOffsets); + + // keepDims processing, for TF compatibility + if (block.getIArguments()->size() > 0 && + block.getIArguments()->at(0) == 1) { + // z->printShapeInfo("z shape before"); + std::vector newshape(z->getShapeAsVector()); + for (int e = 0; e < axis.size(); e++) { + auto a = axis.at(e); + newshape.insert(newshape.begin() + a, 1); + } + z->reshapei(z->ordering(), newshape); + // z->printShapeInfo("z shape after"); + } - NativeOpExcutioner::execReduceFloat(opNum, x->buffer(), x->shapeInfo(), block.getTArguments()->data(), z->buffer(), z->shapeInfo(), axis.data(), (int) axis.size(), tad.tadOnlyShapeInfo, tad.tadOffsets); + OVERWRITE_RESULT(z); + } + } + return ND4J_STATUS_OK; +} - // keepDims processing, for TF compatibility - if (block.getIArguments()->size() > 0 && block.getIArguments()->at(0) == 1) { - // z->printShapeInfo("z shape before"); - std::vector newshape(z->getShapeAsVector()); - for (int e = 0; e < axis.size(); e++) { - auto a = axis.at(e); - newshape.insert(newshape.begin() + a, 1); - } - z->reshapei(z->ordering(), newshape); - // z->printShapeInfo("z shape after"); - } +/** + * For all reductions rules are simple: either you return scalar, or you + * return reduced NDArray. It solely depends on input shape, and requested + * dimensions + */ +ShapeList *LegacyReduceOp::calculateOutputShape(ShapeList *inputShape, + sd::graph::Context &block) { + auto inShape = inputShape->at(0); + + Nd4jLong *newShape; + + bool allAxes = false; + + if (block.getIArguments()->size() == shape::rank(inShape)) allAxes = true; + + if (block.getIArguments()->size() == 0 || + (block.getIArguments()->size() == 1 && INT_ARG(0) == MAX_INT) || + allAxes) { + if (block.getIArguments()->size() > 0 && + block.getIArguments()->at(0) == 1) { + // in this case we just return legacy scalar + ALLOCATE(newShape, block.workspace(), shape::shapeInfoLength(2), + Nd4jLong); + newShape[0] = 2; + newShape[1] = 1; + newShape[2] = 1; + newShape[3] = 1; + newShape[4] = 1; + newShape[5] = 0; + newShape[6] = 1; + newShape[7] = 99; + // ArrayOptions::setDataType(newShape, block.dataType() == + // DataType::BOOL?block.dataType():ArrayOptions::dataType(inShape)); + } else { + ALLOCATE(newShape, block.workspace(), shape::shapeInfoLength(0), + Nd4jLong); + newShape[0] = 0; + newShape[1] = 0; + newShape[2] = 1; + newShape[3] = 99; + // ArrayOptions::setDataType(newShape, block.dataType() == + // DataType::BOOL?block.dataType():ArrayOptions::dataType(inShape)); + } + } else { + // in this case we're building proper shape for reduction + auto array = new NDArray(nullptr, inShape, block.workspace()); - OVERWRITE_RESULT(z); - } - } + newShape = ShapeUtils::evalReduceShapeInfo(shape::order(inShape), + *block.getIArguments(), *array, + false, false, block.workspace()); - return ND4J_STATUS_OK; - } + delete array; + } - /** - * For all reductions rules are simple: either you return scalar, or you return reduced NDArray. - * It solely depends on input shape, and requested dimensions - */ - ShapeList *LegacyReduceOp::calculateOutputShape(ShapeList *inputShape, sd::graph::Context &block) { - auto inShape = inputShape->at(0); - - Nd4jLong *newShape; - - bool allAxes = false; - - if (block.getIArguments()->size() == shape::rank(inShape)) - allAxes = true; - - if (block.getIArguments()->size() == 0 || (block.getIArguments()->size() == 1 && INT_ARG(0) == MAX_INT) || allAxes) { - if (block.getIArguments()->size() > 0 && block.getIArguments()->at(0) == 1) { - // in this case we just return legacy scalar - ALLOCATE(newShape, block.getWorkspace(), shape::shapeInfoLength(2), Nd4jLong); - newShape[0] = 2; - newShape[1] = 1; - newShape[2] = 1; - newShape[3] = 1; - newShape[4] = 1; - newShape[5] = 0; - newShape[6] = 1; - newShape[7] = 99; - //ArrayOptions::setDataType(newShape, block.dataType() == DataType::BOOL?block.dataType():ArrayOptions::dataType(inShape)); - } else { - ALLOCATE(newShape, block.getWorkspace(), shape::shapeInfoLength(0), Nd4jLong); - newShape[0] = 0; - newShape[1] = 0; - newShape[2] = 1; - newShape[3] = 99; - //ArrayOptions::setDataType(newShape, block.dataType() == DataType::BOOL?block.dataType():ArrayOptions::dataType(inShape)); - } - } else { - // in this case we're building proper shape for reduction - auto array = new NDArray(nullptr, inShape, block.getWorkspace()); - - newShape = ShapeUtils::evalReduceShapeInfo(shape::order(inShape), *block.getIArguments(), *array, false, false, block.workspace()); - - delete array; - } - - return SHAPELIST(newShape); - } - } + return SHAPELIST(newShape); } +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/impl/LegacyReduceSameOp.cpp b/libnd4j/include/ops/declarable/impl/LegacyReduceSameOp.cpp index 299d19f14833..17a392447763 100644 --- a/libnd4j/include/ops/declarable/impl/LegacyReduceSameOp.cpp +++ b/libnd4j/include/ops/declarable/impl/LegacyReduceSameOp.cpp @@ -18,131 +18,174 @@ // Created by raver119 on 16.10.2017. // -#include -#include -#include +#include #include #include -#include +#include +#include +#include namespace sd { - namespace ops { - LegacyReduceSameOp::LegacyReduceSameOp() : LegacyOp::LegacyOp(1) { - // - } - - LegacyReduceSameOp::LegacyReduceSameOp(int opNum) : LegacyOp::LegacyOp(1, opNum) { - //this->_opNum = opNum; - } - - LegacyOp* LegacyReduceSameOp::clone() { - return new LegacyReduceSameOp(this->_opNum); - } - - Nd4jStatus LegacyReduceSameOp::validateAndExecute(Context &block) { - auto x = INPUT_VARIABLE(0); - - auto z = OUTPUT_VARIABLE(0); - - NDArray::prepareSpecialUse({z}, {x}); - - int opNum = block.opNum() < 0 ? this->_opNum : block.opNum(); - nd4j_debug("Executing LegacyReduceSameOp: [%i]\n", opNum); - - auto axis = *block.getAxis(); - bool allAxes = false; - - ExtraArguments extras(*block.getTArguments()); - PointersManager manager(block.launchContext(), "LegacyReduceSameOp"); - - if (block.width() == 1) { - if (axis.size() == x->rankOf()) - allAxes = true; - - if (axis.empty() || allAxes) { - // scalar - NativeOpExecutioner::execReduceSameScalar(block.launchContext(), opNum, x->buffer(), x->shapeInfo(), x->specialBuffer(), x->specialShapeInfo(), - extras.argumentsAsT(z->dataType()), z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo()); - } else { - // TAD - std::vector dims(axis); - - for (int e = 0; e < dims.size(); e++) - if (dims[e] < 0) - dims[e] += x->rankOf(); - - REQUIRE_TRUE(dims.size() > 0, 0, "Some dimensions required for reduction!"); - - auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(x->shapeInfo(), dims); - - auto pTadShape = Environment::getInstance().isCPU() ? packX.primaryShapeInfo() : packX.specialShapeInfo(); //(Nd4jLong *) manager.replicatePointer(tad.tadOnlyShapeInfo, shape::shapeInfoByteLength(tad.tadOnlyShapeInfo)); - auto pTadOffsets = Environment::getInstance().isCPU() ? packX.primaryOffsets() : packX.specialOffsets(); //(Nd4jLong *) manager.replicatePointer(tad.tadOffsets, tad.numTads * sizeof(Nd4jLong)); - - NativeOpExecutioner::execReduceSame(block.launchContext(), opNum, x->buffer(), x->shapeInfo(), x->specialBuffer(), x->specialShapeInfo(), - extras.argumentsAsT(z->dataType()), - z->buffer(), z->shapeInfo(), x->specialBuffer(), x->specialShapeInfo(), - dims.data(), (int) dims.size(), pTadShape, pTadOffsets); - } - - STORE_RESULT(*z); - } else { - auto indices = INPUT_VARIABLE(1); - if (indices->lengthOf() == x->rankOf()) - allAxes = true; - - //indices->printIndexedBuffer("indices"); - - std::vector dims(indices->lengthOf()); - for (int e = 0; e < indices->lengthOf(); e++) { - // lol otherwise we segfault on macOS - int f = indices->e(e); - dims[e] = f >= 0 ? f : f += x->rankOf(); - } +namespace ops { +LegacyReduceSameOp::LegacyReduceSameOp() : LegacyOp::LegacyOp(1) { + // +} + +LegacyReduceSameOp::LegacyReduceSameOp(int opNum) + : LegacyOp::LegacyOp(1, opNum) { + // this->_opNum = opNum; +} + +LegacyOp *LegacyReduceSameOp::clone() { + return new LegacyReduceSameOp(this->_opNum); +} + +Nd4jStatus LegacyReduceSameOp::validateAndExecute(Context &block) { + auto x = INPUT_VARIABLE(0); + + auto z = OUTPUT_VARIABLE(0); + + NDArray::prepareSpecialUse({z}, {x}); + + int opNum = block.opNum() < 0 ? this->_opNum : block.opNum(); + nd4j_debug("Executing LegacyReduceSameOp: [%i]\n", opNum); + + auto axis = block.getAxis(); + bool allAxes = false; + + ExtraArguments extras(block.getTArguments()); + PointersManager manager(block.launchContext(), "LegacyReduceSameOp"); + + if (block.width() == 1) { + if (axis.size() == x->rankOf()) allAxes = true; + + if (axis.empty() || allAxes) { + // scalar + NativeOpExecutioner::execReduceSameScalar( + block.launchContext(), opNum, x->buffer(), x->shapeInfo(), + x->specialBuffer(), x->specialShapeInfo(), + extras.argumentsAsT(z->dataType()), z->buffer(), z->shapeInfo(), + z->specialBuffer(), z->specialShapeInfo()); + } else { + // TAD + std::vector dims(axis); + + for (int e = 0; e < dims.size(); e++) + if (dims[e] < 0) dims[e] += x->rankOf(); + + REQUIRE_TRUE(dims.size() > 0, 0, + "Some dimensions required for reduction!"); + + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions( + x->shapeInfo(), dims); + + auto pTadShape = + Environment::getInstance().isCPU() + ? packX.primaryShapeInfo() + : packX + .specialShapeInfo(); //(Nd4jLong *) + // manager.replicatePointer(tad.tadOnlyShapeInfo, + // shape::shapeInfoByteLength(tad.tadOnlyShapeInfo)); + auto pTadOffsets = + Environment::getInstance().isCPU() + ? packX.primaryOffsets() + : packX + .specialOffsets(); //(Nd4jLong *) + // manager.replicatePointer(tad.tadOffsets, + // tad.numTads * sizeof(Nd4jLong)); + + NativeOpExecutioner::execReduceSame( + block.launchContext(), opNum, x->buffer(), x->shapeInfo(), + x->specialBuffer(), x->specialShapeInfo(), + extras.argumentsAsT(z->dataType()), z->buffer(), z->shapeInfo(), + x->specialBuffer(), x->specialShapeInfo(), dims.data(), + (int)dims.size(), pTadShape, pTadOffsets); + } - if ((block.getIArguments()->size() == 1 && INT_ARG(0) == sd::DataTypeUtils::max()) || allAxes) { - // scalar - NativeOpExecutioner::execReduceSameScalar(block.launchContext(), opNum, x->buffer(), x->shapeInfo(), x->specialBuffer(), x->specialShapeInfo(), extras.argumentsAsT(z->dataType()), z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo()); - } else { - // TAD - REQUIRE_TRUE(dims.size() > 0, 0, "Some dimensions required for reduction!"); + STORE_RESULT(*z); + } else { + auto indices = INPUT_VARIABLE(1); + if (indices->lengthOf() == x->rankOf()) allAxes = true; - auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(x->shapeInfo(), dims); + // indices->printIndexedBuffer("indices"); - auto pTadShape = Environment::getInstance().isCPU() ? packX.primaryShapeInfo() : packX.specialShapeInfo(); //(Nd4jLong *) manager.replicatePointer(tad.tadOnlyShapeInfo, shape::shapeInfoByteLength(tad.tadOnlyShapeInfo)); - auto pTadOffsets = Environment::getInstance().isCPU() ? packX.primaryOffsets() : packX.specialOffsets(); //(Nd4jLong *) manager.replicatePointer(tad.tadOffsets, tad.numTads * sizeof(Nd4jLong)); + std::vector dims(indices->lengthOf()); + for (int e = 0; e < indices->lengthOf(); e++) { + // lol otherwise we segfault on macOS + int f = indices->e(e); + dims[e] = f >= 0 ? f : f += x->rankOf(); + } - NativeOpExecutioner::execReduceSame(block.launchContext(), opNum, x->buffer(), x->shapeInfo(), x->specialBuffer(), x->specialShapeInfo(), - extras.argumentsAsT(z->dataType()), z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), - dims.data(), (int) dims.size(), pTadShape, pTadOffsets); - } - } + if ((block.numI() == 1 && INT_ARG(0) == sd::DataTypeUtils::max()) || + allAxes) { + // scalar + NativeOpExecutioner::execReduceSameScalar( + block.launchContext(), opNum, x->buffer(), x->shapeInfo(), + x->specialBuffer(), x->specialShapeInfo(), + extras.argumentsAsT(z->dataType()), z->buffer(), z->shapeInfo(), + z->specialBuffer(), z->specialShapeInfo()); + } else { + // TAD + REQUIRE_TRUE(dims.size() > 0, 0, + "Some dimensions required for reduction!"); + + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions( + x->shapeInfo(), dims); + + auto pTadShape = + Environment::getInstance().isCPU() + ? packX.primaryShapeInfo() + : packX + .specialShapeInfo(); //(Nd4jLong *) + // manager.replicatePointer(tad.tadOnlyShapeInfo, + // shape::shapeInfoByteLength(tad.tadOnlyShapeInfo)); + auto pTadOffsets = + Environment::getInstance().isCPU() + ? packX.primaryOffsets() + : packX + .specialOffsets(); //(Nd4jLong *) + // manager.replicatePointer(tad.tadOffsets, + // tad.numTads * sizeof(Nd4jLong)); + + NativeOpExecutioner::execReduceSame( + block.launchContext(), opNum, x->buffer(), x->shapeInfo(), + x->specialBuffer(), x->specialShapeInfo(), + extras.argumentsAsT(z->dataType()), z->buffer(), z->shapeInfo(), + z->specialBuffer(), z->specialShapeInfo(), dims.data(), + (int)dims.size(), pTadShape, pTadOffsets); + } + } - manager.synchronize(); + manager.synchronize(); - return Status::OK(); - } + return Status::OK(); +} - /** - * For all reductions rules are simple: either you return scalar, or you return reduced NDArray. - * It solely depends on input shape, and requested dimensions - */ - ShapeList *LegacyReduceSameOp::calculateOutputShape(ShapeList *inputShape, sd::graph::Context &block) { - auto inShape = inputShape->at(0); +/** + * For all reductions rules are simple: either you return scalar, or you + * return reduced NDArray. It solely depends on input shape, and requested + * dimensions + */ +ShapeList *LegacyReduceSameOp::calculateOutputShape(ShapeList *inputShape, + sd::graph::Context &block) { + auto inShape = inputShape->at(0); - bool allAxes = false; + bool allAxes = false; - auto keepDims = block.numB() > 0 ? B_ARG(0) : false; - auto newFormat = block.numB() > 1 ? B_ARG(1) : true; + auto keepDims = block.numB() > 0 ? B_ARG(0) : false; + auto newFormat = block.numB() > 1 ? B_ARG(1) : true; - auto axis = block.width() > 1 ? INPUT_VARIABLE(1)->asVectorT() : *block.getAxis(); + auto axis = + block.width() > 1 ? INPUT_VARIABLE(1)->asVectorT() : block.getAxis(); - if (axis.size() == shape::rank(inShape)) - allAxes = true; + if (axis.size() == shape::rank(inShape)) allAxes = true; - // in this case we're building proper shape for reduction - auto newShape = ShapeUtils::evalReduceShapeInfo(shape::order(inShape), axis, inShape, keepDims, !newFormat, block.workspace()); + // in this case we're building proper shape for reduction + auto newShape = + ShapeUtils::evalReduceShapeInfo(shape::order(inShape), axis, inShape, + keepDims, !newFormat, block.workspace()); - return SHAPELIST(newShape); - } - } -} \ No newline at end of file + return SHAPELIST(newShape); +} +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/impl/LegacyScalarBoolOp.cpp b/libnd4j/include/ops/declarable/impl/LegacyScalarBoolOp.cpp index abfd84efbe88..29c2efa1ab6d 100644 --- a/libnd4j/include/ops/declarable/impl/LegacyScalarBoolOp.cpp +++ b/libnd4j/include/ops/declarable/impl/LegacyScalarBoolOp.cpp @@ -18,70 +18,87 @@ // Created by raver119 on 16.10.2017. // -#include #include #include - +#include namespace sd { - namespace ops { - LegacyScalarBoolOp::LegacyScalarBoolOp() : LegacyOp::LegacyOp(1) { - // no-op - } - - LegacyScalarBoolOp::LegacyScalarBoolOp(int opNum) : LegacyOp::LegacyOp(1, opNum){ - // no-op - } - - LegacyOp* LegacyScalarBoolOp::clone() { - return new LegacyScalarBoolOp(this->_opNum, *this->_scalar); - } - - LegacyScalarBoolOp::LegacyScalarBoolOp(int opNum, NDArray &scalar) : LegacyOp::LegacyOp(1, opNum){ - _scalar = new NDArray(scalar.dup(scalar.ordering())); - } - - ShapeList *LegacyScalarBoolOp::calculateOutputShape(ShapeList *inputShape, sd::graph::Context &block) { - auto inShape = inputShape->at(0); - - Nd4jLong *newShape; - COPY_SHAPE(inShape, newShape); - - return SHAPELIST(CONSTANT(newShape)); - } - - Nd4jStatus LegacyScalarBoolOp::validateAndExecute(Context &block) { - auto x = INPUT_VARIABLE(0); - auto z = OUTPUT_VARIABLE(0); - - int opNum = block.opNum() < 0 ? this->_opNum : block.opNum(); - - ExtraArguments extras(*block.getTArguments()); - PointersManager manager(block.launchContext(), "LegacyScalarBoolOp"); - - if (block.width() > 1) { - auto y = INPUT_VARIABLE(1); - - NDArray::prepareSpecialUse({z}, {x, y}); - - NativeOpExecutioner::execScalarBool(block.launchContext(), opNum, x->buffer(), x->shapeInfo(), x->specialBuffer(), x->specialShapeInfo(), z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), y->buffer(), y->shapeInfo(), y->specialBuffer(), y->specialShapeInfo(), extras.argumentsAsT(x->dataType())); - } else if (block.getTArguments()->size() > 0) { - auto y = NDArrayFactory::create(T_ARG(0), block.launchContext()); - - NDArray::prepareSpecialUse({z}, {x, &y}); - - NativeOpExecutioner::execScalarBool(block.launchContext(), opNum, x->buffer(), x->shapeInfo(), x->specialBuffer(), x->specialShapeInfo(),z->buffer(), z->shapeInfo(),z->specialBuffer(), z->specialShapeInfo(), y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), extras.argumentsAsT(x->dataType(), 1)); - - manager.synchronize(); - } else { - NDArray::prepareSpecialUse({z}, {x, _scalar}); - - NativeOpExecutioner::execScalarBool(block.launchContext(), opNum, x->buffer(), x->shapeInfo(), x->specialBuffer(), x->specialShapeInfo(),z->buffer(), z->shapeInfo(),z->specialBuffer(), z->specialShapeInfo(), _scalar->buffer(), _scalar->shapeInfo(), _scalar->specialBuffer(), _scalar->specialShapeInfo(), extras.argumentsAsT(x->dataType())); - } - manager.synchronize(); - STORE_RESULT(*z); - - return Status::OK(); - } - } -} \ No newline at end of file +namespace ops { +LegacyScalarBoolOp::LegacyScalarBoolOp() : LegacyOp::LegacyOp(1) { + // no-op +} + +LegacyScalarBoolOp::LegacyScalarBoolOp(int opNum) + : LegacyOp::LegacyOp(1, opNum) { + // no-op +} + +LegacyOp *LegacyScalarBoolOp::clone() { + return new LegacyScalarBoolOp(this->_opNum, *this->_scalar); +} + +LegacyScalarBoolOp::LegacyScalarBoolOp(int opNum, NDArray &scalar) + : LegacyOp::LegacyOp(1, opNum) { + _scalar = new NDArray(scalar.dup(scalar.ordering())); +} + +ShapeList *LegacyScalarBoolOp::calculateOutputShape(ShapeList *inputShape, + sd::graph::Context &block) { + auto inShape = inputShape->at(0); + + Nd4jLong *newShape; + COPY_SHAPE(inShape, newShape); + + return SHAPELIST(CONSTANT(newShape)); +} + +Nd4jStatus LegacyScalarBoolOp::validateAndExecute(Context &block) { + auto x = INPUT_VARIABLE(0); + auto z = OUTPUT_VARIABLE(0); + + int opNum = block.opNum() < 0 ? this->_opNum : block.opNum(); + + ExtraArguments extras(block.getTArguments()); + PointersManager manager(block.launchContext(), "LegacyScalarBoolOp"); + + if (block.width() > 1) { + auto y = INPUT_VARIABLE(1); + + NDArray::prepareSpecialUse({z}, {x, y}); + + NativeOpExecutioner::execScalarBool( + block.launchContext(), opNum, x->buffer(), x->shapeInfo(), + x->specialBuffer(), x->specialShapeInfo(), z->buffer(), z->shapeInfo(), + z->specialBuffer(), z->specialShapeInfo(), y->buffer(), y->shapeInfo(), + y->specialBuffer(), y->specialShapeInfo(), + extras.argumentsAsT(x->dataType())); + } else if (block.numT() > 0) { + auto y = NDArrayFactory::create(T_ARG(0), block.launchContext()); + + NDArray::prepareSpecialUse({z}, {x, &y}); + + NativeOpExecutioner::execScalarBool( + block.launchContext(), opNum, x->buffer(), x->shapeInfo(), + x->specialBuffer(), x->specialShapeInfo(), z->buffer(), z->shapeInfo(), + z->specialBuffer(), z->specialShapeInfo(), y.buffer(), y.shapeInfo(), + y.specialBuffer(), y.specialShapeInfo(), + extras.argumentsAsT(x->dataType(), 1)); + + manager.synchronize(); + } else { + NDArray::prepareSpecialUse({z}, {x, _scalar}); + + NativeOpExecutioner::execScalarBool( + block.launchContext(), opNum, x->buffer(), x->shapeInfo(), + x->specialBuffer(), x->specialShapeInfo(), z->buffer(), z->shapeInfo(), + z->specialBuffer(), z->specialShapeInfo(), _scalar->buffer(), + _scalar->shapeInfo(), _scalar->specialBuffer(), + _scalar->specialShapeInfo(), extras.argumentsAsT(x->dataType())); + } + manager.synchronize(); + STORE_RESULT(*z); + + return Status::OK(); +} +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/impl/LegacyScalarOp.cpp b/libnd4j/include/ops/declarable/impl/LegacyScalarOp.cpp index 0c700b88b797..875e6d21b8f0 100644 --- a/libnd4j/include/ops/declarable/impl/LegacyScalarOp.cpp +++ b/libnd4j/include/ops/declarable/impl/LegacyScalarOp.cpp @@ -18,72 +18,88 @@ // Created by raver119 on 16.10.2017. // -#include #include #include - +#include namespace sd { - namespace ops { - LegacyScalarOp::LegacyScalarOp() : LegacyOp::LegacyOp(1) { - this->getOpDescriptor()->allowInplace(true); - } - - LegacyScalarOp::LegacyScalarOp(int opNum) : LegacyOp::LegacyOp(1, opNum){ - this->getOpDescriptor()->allowInplace(true); - } - - LegacyOp* LegacyScalarOp::clone() { - return new LegacyScalarOp(this->_opNum, *this->_scalar); - } - - LegacyScalarOp::LegacyScalarOp(int opNum, NDArray &scalar) : LegacyOp::LegacyOp(1, opNum){ - _scalar = new NDArray(scalar.dup(scalar.ordering())); - } - - ShapeList *LegacyScalarOp::calculateOutputShape(ShapeList *inputShape, sd::graph::Context &block) { - auto inShape = inputShape->at(0); - - Nd4jLong *newShape; - COPY_SHAPE(inShape, newShape); - - return SHAPELIST(CONSTANT(newShape)); - } - - Nd4jStatus LegacyScalarOp::validateAndExecute(Context &block) { - auto x = INPUT_VARIABLE(0); - auto z = OUTPUT_VARIABLE(0); - - int opNum = block.opNum() < 0 ? this->_opNum : block.opNum(); - - ExtraArguments extras(*block.getTArguments()); - PointersManager manager(block.launchContext(), "LegacyScalarOp"); - - if (block.width() > 1) { - auto y = INPUT_VARIABLE(1); - - NDArray::prepareSpecialUse({z}, {x, y}); - - NativeOpExecutioner::execScalar(block.launchContext(), opNum, x->buffer(), x->shapeInfo(), x->specialBuffer(), x->specialShapeInfo(), z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), y->buffer(), y->shapeInfo(), y->specialBuffer(), y->specialShapeInfo(), extras.argumentsAsT(z->dataType())); - - NDArray::registerSpecialUse({z}, {x, y}); - } else if (block.getTArguments()->size() > 0) { - auto y = NDArrayFactory::create(x->dataType(), T_ARG(0), block.launchContext()); - - x->applyScalarArr(static_cast(opNum), y, *z); - // NDArray::prepareSpecialUse({z}, {x, &y}); - // NativeOpExecutioner::execScalar(block.launchContext(), opNum, x->buffer(), x->shapeInfo(), x->specialBuffer(), x->specialShapeInfo(), z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), y.buffer(), y.shapeInfo(), y.specialBuffer(), y.special(), extras.argumentsAsT(z->dataType(), 1)); - - manager.synchronize(); - } else { - NDArray::prepareSpecialUse({z}, {x, _scalar}); - - NativeOpExecutioner::execScalar(block.launchContext(), opNum, x->buffer(), x->shapeInfo(), x->specialBuffer(), x->specialShapeInfo(), z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), _scalar->buffer(), _scalar->shapeInfo(), _scalar->specialBuffer(), _scalar->specialShapeInfo(), extras.argumentsAsT(z->dataType())); - - NDArray::registerSpecialUse({z}, {x, _scalar}); - } - - return Status::OK(); - } - } -} \ No newline at end of file +namespace ops { +LegacyScalarOp::LegacyScalarOp() : LegacyOp::LegacyOp(1) { + this->getOpDescriptor()->allowInplace(true); +} + +LegacyScalarOp::LegacyScalarOp(int opNum) : LegacyOp::LegacyOp(1, opNum) { + this->getOpDescriptor()->allowInplace(true); +} + +LegacyOp *LegacyScalarOp::clone() { + return new LegacyScalarOp(this->_opNum, *this->_scalar); +} + +LegacyScalarOp::LegacyScalarOp(int opNum, NDArray &scalar) + : LegacyOp::LegacyOp(1, opNum) { + _scalar = new NDArray(scalar.dup(scalar.ordering())); +} + +ShapeList *LegacyScalarOp::calculateOutputShape(ShapeList *inputShape, + sd::graph::Context &block) { + auto inShape = inputShape->at(0); + + Nd4jLong *newShape; + COPY_SHAPE(inShape, newShape); + + return SHAPELIST(CONSTANT(newShape)); +} + +Nd4jStatus LegacyScalarOp::validateAndExecute(Context &block) { + auto x = INPUT_VARIABLE(0); + auto z = OUTPUT_VARIABLE(0); + + int opNum = block.opNum() < 0 ? this->_opNum : block.opNum(); + + ExtraArguments extras(block.getTArguments()); + PointersManager manager(block.launchContext(), "LegacyScalarOp"); + + if (block.width() > 1) { + auto y = INPUT_VARIABLE(1); + + NDArray::prepareSpecialUse({z}, {x, y}); + + NativeOpExecutioner::execScalar( + block.launchContext(), opNum, x->buffer(), x->shapeInfo(), + x->specialBuffer(), x->specialShapeInfo(), z->buffer(), z->shapeInfo(), + z->specialBuffer(), z->specialShapeInfo(), y->buffer(), y->shapeInfo(), + y->specialBuffer(), y->specialShapeInfo(), + extras.argumentsAsT(z->dataType())); + + NDArray::registerSpecialUse({z}, {x, y}); + } else if (block.numT() > 0) { + auto y = + NDArrayFactory::create(x->dataType(), T_ARG(0), block.launchContext()); + + x->applyScalarArr(static_cast(opNum), y, *z); + // NDArray::prepareSpecialUse({z}, {x, &y}); + // NativeOpExecutioner::execScalar(block.launchContext(), opNum, + // x->buffer(), x->shapeInfo(), x->specialBuffer(), x->specialShapeInfo(), + // z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), + // y.buffer(), y.shapeInfo(), y.specialBuffer(), y.special(), + // extras.argumentsAsT(z->dataType(), 1)); + + manager.synchronize(); + } else { + NDArray::prepareSpecialUse({z}, {x, _scalar}); + + NativeOpExecutioner::execScalar( + block.launchContext(), opNum, x->buffer(), x->shapeInfo(), + x->specialBuffer(), x->specialShapeInfo(), z->buffer(), z->shapeInfo(), + z->specialBuffer(), z->specialShapeInfo(), _scalar->buffer(), + _scalar->shapeInfo(), _scalar->specialBuffer(), + _scalar->specialShapeInfo(), extras.argumentsAsT(z->dataType())); + + NDArray::registerSpecialUse({z}, {x, _scalar}); + } + + return Status::OK(); +} +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/impl/LegacyStatsOp.cpp b/libnd4j/include/ops/declarable/impl/LegacyStatsOp.cpp index 4a60064b5df8..cb362ea0369a 100644 --- a/libnd4j/include/ops/declarable/impl/LegacyStatsOp.cpp +++ b/libnd4j/include/ops/declarable/impl/LegacyStatsOp.cpp @@ -18,103 +18,126 @@ // Created by raver119 on 17.10.2017. // -#include +#include +#include #include #include -#include -#include - +#include namespace sd { - namespace ops { - Nd4jStatus LegacyStatsOp::validateAndExecute(Context &block) { - auto x = INPUT_VARIABLE(0); - auto z = OUTPUT_VARIABLE(0); - - NDArray::prepareSpecialUse({z}, {x}); - - // we assume that opNuk is either stored in block, or was provided via op constructor - int opNum = block.opNum() < 0 ? this->_opNum : block.opNum(); - - // bias goes as first argument, unlike all other reductions - bool biasCorrected = false; - if (block.getIArguments()->size() > 0) - biasCorrected = INT_ARG(0) > 0; - - ExtraArguments extras(*block.getTArguments()); - PointersManager manager(block.launchContext(),"LegacyStatsOp"); - - if (block.getIArguments()->size() == 1 || (block.getIArguments()->size() == 2 && INT_ARG(1) == sd::DataTypeUtils::max())) { - // scalar - NativeOpExecutioner::execSummaryStatsScalar(block.launchContext(), opNum, x->buffer(), x->shapeInfo(), x->specialBuffer(), x->specialShapeInfo(), - extras.argumentsAsT(z->dataType()), z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), biasCorrected); - } else { - // dimensions for TAD - // we should skip first argument here, because it's addressing bias correction - std::vector dims(*block.getIArguments()); - for (int e = 0; e < dims.size(); e++) - if (dims[e] < 0) - dims[e] += x->rankOf(); - - REQUIRE_TRUE(dims.size() > 0, 0, "Some dimensions requuired for reduction!"); - - auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(x->shapeInfo(), dims); - - auto pTadShape = Environment::getInstance().isCPU() ? packX.primaryShapeInfo() : packX.specialShapeInfo(); //(Nd4jLong *) manager.replicatePointer(tad.tadOnlyShapeInfo, shape::shapeInfoByteLength(tad.tadOnlyShapeInfo)); - auto pTadOffsets = Environment::getInstance().isCPU() ? packX.primaryOffsets() : packX.specialOffsets(); //(Nd4jLong *) manager.replicatePointer(tad.tadOffsets, tad.numTads * sizeof(Nd4jLong)); - - NativeOpExecutioner::execSummaryStats(block.launchContext(), opNum, x->buffer(), x->shapeInfo(), x->specialBuffer(), x->specialShapeInfo(), extras.argumentsAsT(z->dataType()), - z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), dims.data(), (int) dims.size(), pTadShape, pTadOffsets, biasCorrected); - } - - manager.synchronize(); - STORE_RESULT(*z); - - return Status::OK(); - } - - LegacyStatsOp::LegacyStatsOp() : LegacyOp::LegacyOp(1) { - // - } - - LegacyStatsOp::LegacyStatsOp(int opNum) : LegacyOp::LegacyOp(1, opNum) { - // - } - - LegacyOp* LegacyStatsOp::clone() { - return new LegacyStatsOp(this->_opNum); - } - - /** - * For all reductions rules are simple: either you return scalar, or you return reduced NDArray. - * It solely depends on input shape, and requested dimensions - */ - ShapeList *LegacyStatsOp::calculateOutputShape(ShapeList *inputShape, sd::graph::Context &block) { - auto inShape = inputShape->at(0); - - Nd4jLong *newShape; - if (block.getIArguments()->size() == 0 || (block.getIArguments()->size() == 1 && INT_ARG(0) == sd::DataTypeUtils::max())) { - // in this case we just return scalar - ALLOCATE(newShape, block.getWorkspace(), shape::shapeInfoLength(2), Nd4jLong); - newShape[0] = 2; - newShape[1] = 1; - newShape[2] = 1; - newShape[3] = 1; - newShape[4] = 1; - newShape[5] = 0; - newShape[6] = 1; - newShape[7] = 99; - } else { - // in this case we're building proper shape for reduction - auto array = new NDArray(nullptr, inShape, block.launchContext()); - - auto newShape = ShapeUtils::evalReduceShapeInfo('c', *block.getIArguments(), *array, false, true); - - delete array; - return SHAPELIST(newShape); - } - - return SHAPELIST(CONSTANT(newShape)); - } - } -} \ No newline at end of file +namespace ops { +Nd4jStatus LegacyStatsOp::validateAndExecute(Context &block) { + auto x = INPUT_VARIABLE(0); + auto z = OUTPUT_VARIABLE(0); + + NDArray::prepareSpecialUse({z}, {x}); + + // we assume that opNuk is either stored in block, or was provided via op + // constructor + int opNum = block.opNum() < 0 ? this->_opNum : block.opNum(); + + // bias goes as first argument, unlike all other reductions + bool biasCorrected = false; + if (block.numI() > 0) biasCorrected = INT_ARG(0) > 0; + + ExtraArguments extras(block.getTArguments()); + PointersManager manager(block.launchContext(), "LegacyStatsOp"); + + if (block.numI() == 1 || + (block.numI() == 2 && INT_ARG(1) == sd::DataTypeUtils::max())) { + // scalar + NativeOpExecutioner::execSummaryStatsScalar( + block.launchContext(), opNum, x->buffer(), x->shapeInfo(), + x->specialBuffer(), x->specialShapeInfo(), + extras.argumentsAsT(z->dataType()), z->buffer(), z->shapeInfo(), + z->specialBuffer(), z->specialShapeInfo(), biasCorrected); + } else { + // dimensions for TAD + // we should skip first argument here, because it's addressing bias + // correction + std::vector dims(block.getIArguments()); + for (int e = 0; e < dims.size(); e++) + if (dims[e] < 0) dims[e] += x->rankOf(); + + REQUIRE_TRUE(dims.size() > 0, 0, + "Some dimensions requuired for reduction!"); + + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions( + x->shapeInfo(), dims); + + auto pTadShape = + Environment::getInstance().isCPU() + ? packX.primaryShapeInfo() + : packX + .specialShapeInfo(); //(Nd4jLong *) + // manager.replicatePointer(tad.tadOnlyShapeInfo, + // shape::shapeInfoByteLength(tad.tadOnlyShapeInfo)); + auto pTadOffsets = + Environment::getInstance().isCPU() + ? packX.primaryOffsets() + : packX + .specialOffsets(); //(Nd4jLong *) + // manager.replicatePointer(tad.tadOffsets, + // tad.numTads * sizeof(Nd4jLong)); + + NativeOpExecutioner::execSummaryStats( + block.launchContext(), opNum, x->buffer(), x->shapeInfo(), + x->specialBuffer(), x->specialShapeInfo(), + extras.argumentsAsT(z->dataType()), z->buffer(), z->shapeInfo(), + z->specialBuffer(), z->specialShapeInfo(), dims.data(), + (int)dims.size(), pTadShape, pTadOffsets, biasCorrected); + } + + manager.synchronize(); + STORE_RESULT(*z); + + return Status::OK(); +} + +LegacyStatsOp::LegacyStatsOp() : LegacyOp::LegacyOp(1) { + // +} + +LegacyStatsOp::LegacyStatsOp(int opNum) : LegacyOp::LegacyOp(1, opNum) { + // +} + +LegacyOp *LegacyStatsOp::clone() { return new LegacyStatsOp(this->_opNum); } + +/** + * For all reductions rules are simple: either you return scalar, or you + * return reduced NDArray. It solely depends on input shape, and requested + * dimensions + */ +ShapeList *LegacyStatsOp::calculateOutputShape(ShapeList *inputShape, + sd::graph::Context &block) { + auto inShape = inputShape->at(0); + + Nd4jLong *newShape; + if (block.numI() == 0 || + (block.numI() == 1 && INT_ARG(0) == sd::DataTypeUtils::max())) { + // in this case we just return scalar + ALLOCATE(newShape, block.workspace(), shape::shapeInfoLength(2), Nd4jLong); + newShape[0] = 2; + newShape[1] = 1; + newShape[2] = 1; + newShape[3] = 1; + newShape[4] = 1; + newShape[5] = 0; + newShape[6] = 1; + newShape[7] = 99; + } else { + // in this case we're building proper shape for reduction + auto array = new NDArray(nullptr, inShape, block.launchContext()); + + auto newShape = ShapeUtils::evalReduceShapeInfo('c', block.getIArguments(), + *array, false, true); + + delete array; + return SHAPELIST(newShape); + } + + return SHAPELIST(CONSTANT(newShape)); +} +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/impl/LegacyTransformAnyOp.cpp b/libnd4j/include/ops/declarable/impl/LegacyTransformAnyOp.cpp index dde8ce9e9103..b8edd4eac4f5 100644 --- a/libnd4j/include/ops/declarable/impl/LegacyTransformAnyOp.cpp +++ b/libnd4j/include/ops/declarable/impl/LegacyTransformAnyOp.cpp @@ -18,57 +18,61 @@ // Created by raver119 on 16.10.2017. // -#include - #include - +#include namespace sd { - namespace ops { - LegacyTransformAnyOp::LegacyTransformAnyOp() : LegacyOp::LegacyOp(1) { - // just a no-op - } +namespace ops { +LegacyTransformAnyOp::LegacyTransformAnyOp() : LegacyOp::LegacyOp(1) { + // just a no-op +} - LegacyTransformAnyOp::LegacyTransformAnyOp(int opNum) : LegacyOp::LegacyOp(1, opNum) { - // just a no-op - } +LegacyTransformAnyOp::LegacyTransformAnyOp(int opNum) + : LegacyOp::LegacyOp(1, opNum) { + // just a no-op +} - LegacyOp* LegacyTransformAnyOp::clone() { - return new LegacyTransformAnyOp(this->_opNum); - } +LegacyOp *LegacyTransformAnyOp::clone() { + return new LegacyTransformAnyOp(this->_opNum); +} - Nd4jStatus LegacyTransformAnyOp::validateAndExecute(Context &block) { - auto input = INPUT_VARIABLE(0); - auto z = OUTPUT_VARIABLE(0); +Nd4jStatus LegacyTransformAnyOp::validateAndExecute(Context &block) { + auto input = INPUT_VARIABLE(0); + auto z = OUTPUT_VARIABLE(0); - NDArray::prepareSpecialUse({z}, {input}); + NDArray::prepareSpecialUse({z}, {input}); - int opNum = block.opNum() < 0 ? this->_opNum : block.opNum(); + int opNum = block.opNum() < 0 ? this->_opNum : block.opNum(); - ExtraArguments extras(*block.getTArguments()); - PointersManager manager(block.launchContext(),"LegacyTransformAnyOp"); + ExtraArguments extras(block.getTArguments()); + PointersManager manager(block.launchContext(), "LegacyTransformAnyOp"); - NativeOpExecutioner::execTransformAny(block.launchContext(), opNum, input->buffer(), input->shapeInfo(), input->specialBuffer(), input->specialShapeInfo(), - z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), extras.argumentsAsT(z->dataType()), nullptr, nullptr); + NativeOpExecutioner::execTransformAny( + block.launchContext(), opNum, input->buffer(), input->shapeInfo(), + input->specialBuffer(), input->specialShapeInfo(), z->buffer(), + z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), + extras.argumentsAsT(z->dataType()), nullptr, nullptr); - manager.synchronize(); - STORE_RESULT(*z); + manager.synchronize(); + STORE_RESULT(*z); - return Status::OK(); - } + return Status::OK(); +} - /** - * For transform operations, output shape always equals to input shape. With just a few exclusions, like im2col and col2im. - * But these ops already have CustomOp implementations. - * - */ - ShapeList *LegacyTransformAnyOp::calculateOutputShape(ShapeList *inputShape, sd::graph::Context &block) { - auto inShape = inputShape->at(0); - - Nd4jLong *newShape; - COPY_SHAPE(inShape, newShape); - - return SHAPELIST(CONSTANT(newShape)); - } - } -} \ No newline at end of file +/** + * For transform operations, output shape always equals to input shape. With + * just a few exclusions, like im2col and col2im. But these ops already have + * CustomOp implementations. + * + */ +ShapeList *LegacyTransformAnyOp::calculateOutputShape( + ShapeList *inputShape, sd::graph::Context &block) { + auto inShape = inputShape->at(0); + + Nd4jLong *newShape; + COPY_SHAPE(inShape, newShape); + + return SHAPELIST(CONSTANT(newShape)); +} +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/impl/LegacyTransformBoolOp.cpp b/libnd4j/include/ops/declarable/impl/LegacyTransformBoolOp.cpp index 3bf4f1ff4a0c..b13126483ff2 100644 --- a/libnd4j/include/ops/declarable/impl/LegacyTransformBoolOp.cpp +++ b/libnd4j/include/ops/declarable/impl/LegacyTransformBoolOp.cpp @@ -18,54 +18,58 @@ // Created by raver119 on 16.10.2017. // -#include - #include - +#include namespace sd { - namespace ops { - LegacyTransformBoolOp::LegacyTransformBoolOp() : LegacyOp::LegacyOp(1) { - // just a no-op - } +namespace ops { +LegacyTransformBoolOp::LegacyTransformBoolOp() : LegacyOp::LegacyOp(1) { + // just a no-op +} - LegacyTransformBoolOp::LegacyTransformBoolOp(int opNum) : LegacyOp::LegacyOp(1, opNum) { - // just a no-op - } +LegacyTransformBoolOp::LegacyTransformBoolOp(int opNum) + : LegacyOp::LegacyOp(1, opNum) { + // just a no-op +} - LegacyOp* LegacyTransformBoolOp::clone() { - return new LegacyTransformBoolOp(this->_opNum); - } +LegacyOp *LegacyTransformBoolOp::clone() { + return new LegacyTransformBoolOp(this->_opNum); +} - Nd4jStatus LegacyTransformBoolOp::validateAndExecute(Context &block) { - auto input = INPUT_VARIABLE(0); - auto z = OUTPUT_VARIABLE(0); +Nd4jStatus LegacyTransformBoolOp::validateAndExecute(Context &block) { + auto input = INPUT_VARIABLE(0); + auto z = OUTPUT_VARIABLE(0); - NDArray::prepareSpecialUse({z}, {input}); + NDArray::prepareSpecialUse({z}, {input}); - int opNum = block.opNum() < 0 ? this->_opNum : block.opNum(); + int opNum = block.opNum() < 0 ? this->_opNum : block.opNum(); - ExtraArguments extras(*block.getTArguments()); - PointersManager manager(block.launchContext(),"LegacyTransformBoolOp"); + ExtraArguments extras(block.getTArguments()); + PointersManager manager(block.launchContext(), "LegacyTransformBoolOp"); - NativeOpExecutioner::execTransformBool(block.launchContext(), opNum, input->buffer(), input->shapeInfo(), input->specialBuffer(), input->specialShapeInfo(), - z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), - extras.argumentsAsT(input->dataType()), nullptr, nullptr); + NativeOpExecutioner::execTransformBool( + block.launchContext(), opNum, input->buffer(), input->shapeInfo(), + input->specialBuffer(), input->specialShapeInfo(), z->buffer(), + z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), + extras.argumentsAsT(input->dataType()), nullptr, nullptr); - manager.synchronize(); - STORE_RESULT(*z); + manager.synchronize(); + STORE_RESULT(*z); - return Status::OK(); - } + return Status::OK(); +} - /** - * For transform operations, output shape always equals to input shape. With just a few exclusions, like im2col and col2im. - * But these ops already have CustomOp implementations. - * - */ - ShapeList *LegacyTransformBoolOp::calculateOutputShape(ShapeList *inputShape, sd::graph::Context &block) { - auto inShape = inputShape->at(0); - return SHAPELIST(ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(inShape, DataType::BOOL))); - } - } -} \ No newline at end of file +/** + * For transform operations, output shape always equals to input shape. With + * just a few exclusions, like im2col and col2im. But these ops already have + * CustomOp implementations. + * + */ +ShapeList *LegacyTransformBoolOp::calculateOutputShape( + ShapeList *inputShape, sd::graph::Context &block) { + auto inShape = inputShape->at(0); + return SHAPELIST(ConstantShapeHelper::getInstance().createShapeInfo( + ShapeDescriptor(inShape, DataType::BOOL))); +} +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/impl/LegacyTransformFloatOp.cpp b/libnd4j/include/ops/declarable/impl/LegacyTransformFloatOp.cpp index f25ba00fef13..5896f96f51a5 100644 --- a/libnd4j/include/ops/declarable/impl/LegacyTransformFloatOp.cpp +++ b/libnd4j/include/ops/declarable/impl/LegacyTransformFloatOp.cpp @@ -18,57 +18,61 @@ // Created by raver119 on 16.10.2017. // -#include - #include - +#include namespace sd { - namespace ops { - LegacyTransformFloatOp::LegacyTransformFloatOp() : LegacyOp::LegacyOp(1) { - // just a no-op - } +namespace ops { +LegacyTransformFloatOp::LegacyTransformFloatOp() : LegacyOp::LegacyOp(1) { + // just a no-op +} - LegacyTransformFloatOp::LegacyTransformFloatOp(int opNum) : LegacyOp::LegacyOp(1, opNum) { - // just a no-op - } +LegacyTransformFloatOp::LegacyTransformFloatOp(int opNum) + : LegacyOp::LegacyOp(1, opNum) { + // just a no-op +} - LegacyOp* LegacyTransformFloatOp::clone() { - return new LegacyTransformFloatOp(this->_opNum); - } +LegacyOp *LegacyTransformFloatOp::clone() { + return new LegacyTransformFloatOp(this->_opNum); +} - Nd4jStatus LegacyTransformFloatOp::validateAndExecute(Context &block) { - auto input = INPUT_VARIABLE(0); - auto z = OUTPUT_VARIABLE(0); +Nd4jStatus LegacyTransformFloatOp::validateAndExecute(Context &block) { + auto input = INPUT_VARIABLE(0); + auto z = OUTPUT_VARIABLE(0); - NDArray::prepareSpecialUse({z}, {input}); + NDArray::prepareSpecialUse({z}, {input}); - int opNum = block.opNum() < 0 ? this->_opNum : block.opNum(); + int opNum = block.opNum() < 0 ? this->_opNum : block.opNum(); - ExtraArguments extras(*block.getTArguments()); - PointersManager manager(block.launchContext(), "LegacyTransformFloatOp"); + ExtraArguments extras(block.getTArguments()); + PointersManager manager(block.launchContext(), "LegacyTransformFloatOp"); - NativeOpExecutioner::execTransformFloat(block.launchContext(), opNum, input->buffer(), input->shapeInfo(), input->specialBuffer(), input->specialShapeInfo(), - z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), extras.argumentsAsT(z->dataType()), nullptr, nullptr); + NativeOpExecutioner::execTransformFloat( + block.launchContext(), opNum, input->buffer(), input->shapeInfo(), + input->specialBuffer(), input->specialShapeInfo(), z->buffer(), + z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), + extras.argumentsAsT(z->dataType()), nullptr, nullptr); - manager.synchronize(); - STORE_RESULT(*z); + manager.synchronize(); + STORE_RESULT(*z); - return Status::OK(); - } + return Status::OK(); +} - /** - * For transform operations, output shape always equals to input shape. With just a few exclusions, like im2col and col2im. - * But these ops already have CustomOp implementations. - * - */ - ShapeList *LegacyTransformFloatOp::calculateOutputShape(ShapeList *inputShape, sd::graph::Context &block) { - auto inShape = inputShape->at(0); - - Nd4jLong *newShape; - COPY_SHAPE(inShape, newShape); - - return SHAPELIST(CONSTANT(newShape)); - } - } -} \ No newline at end of file +/** + * For transform operations, output shape always equals to input shape. With + * just a few exclusions, like im2col and col2im. But these ops already have + * CustomOp implementations. + * + */ +ShapeList *LegacyTransformFloatOp::calculateOutputShape( + ShapeList *inputShape, sd::graph::Context &block) { + auto inShape = inputShape->at(0); + + Nd4jLong *newShape; + COPY_SHAPE(inShape, newShape); + + return SHAPELIST(CONSTANT(newShape)); +} +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/impl/LegacyTransformOp.cpp b/libnd4j/include/ops/declarable/impl/LegacyTransformOp.cpp index d0a8f7604d52..1ebefda35482 100644 --- a/libnd4j/include/ops/declarable/impl/LegacyTransformOp.cpp +++ b/libnd4j/include/ops/declarable/impl/LegacyTransformOp.cpp @@ -18,51 +18,54 @@ // Created by raver119 on 16.10.2017. // -#include - #include +#include #ifdef ONLY_SAME_TRANSFORM namespace sd { - namespace ops { - LegacyTransformOp::LegacyTransformOp() : LegacyOp::LegacyOp(1) { - // just a no-op - } +namespace ops { +LegacyTransformOp::LegacyTransformOp() : LegacyOp::LegacyOp(1) { + // just a no-op +} - LegacyTransformOp::LegacyTransformOp(int opNum) : LegacyOp::LegacyOp(1, opNum) { - // just a no-op - } +LegacyTransformOp::LegacyTransformOp(int opNum) : LegacyOp::LegacyOp(1, opNum) { + // just a no-op +} - LegacyOp* LegacyTransformOp::clone() { - return new LegacyTransformOp(this->_opNum); - } +LegacyOp *LegacyTransformOp::clone() { + return new LegacyTransformOp(this->_opNum); +} - Nd4jStatus LegacyTransformOp::validateAndExecute(Context &block) { - auto input = INPUT_VARIABLE(0); - auto z = OUTPUT_VARIABLE(0); +Nd4jStatus LegacyTransformOp::validateAndExecute(Context &block) { + auto input = INPUT_VARIABLE(0); + auto z = OUTPUT_VARIABLE(0); - int opNum = block.opNum() < 0 ? this->_opNum : block.opNum(); + int opNum = block.opNum() < 0 ? this->_opNum : block.opNum(); - NativeOpExcutioner::execTransformSame(opNum, input->buffer(), input->shapeInfo(), z->buffer(), z->shapeInfo(), block.getTArguments()->data(), nullptr, nullptr); + NativeOpExcutioner::execTransformSame( + opNum, input->buffer(), input->shapeInfo(), z->buffer(), z->shapeInfo(), + block.getTArguments()->data(), nullptr, nullptr); - STORE_RESULT(*z); + STORE_RESULT(*z); - return ND4J_STATUS_OK; - } + return ND4J_STATUS_OK; +} - /** - * For transform operations, output shape always equals to input shape. With just a few exclusions, like im2col and col2im. - * But these ops already have CustomOp implementations. - * - */ - ShapeList *LegacyTransformOp::calculateOutputShape(ShapeList *inputShape, sd::graph::Context &block) { - auto inShape = inputShape->at(0); +/** + * For transform operations, output shape always equals to input shape. With + * just a few exclusions, like im2col and col2im. But these ops already have + * CustomOp implementations. + * + */ +ShapeList *LegacyTransformOp::calculateOutputShape(ShapeList *inputShape, + sd::graph::Context &block) { + auto inShape = inputShape->at(0); - Nd4jLong *newShape; - COPY_SHAPE(inShape, newShape); + Nd4jLong *newShape; + COPY_SHAPE(inShape, newShape); - return SHAPELIST(CONSTANT(newShape)); - } - } + return SHAPELIST(CONSTANT(newShape)); } +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/impl/LegacyTransformSameOp.cpp b/libnd4j/include/ops/declarable/impl/LegacyTransformSameOp.cpp index 02a69da6be69..c353627fa3ff 100644 --- a/libnd4j/include/ops/declarable/impl/LegacyTransformSameOp.cpp +++ b/libnd4j/include/ops/declarable/impl/LegacyTransformSameOp.cpp @@ -18,57 +18,61 @@ // Created by raver119 on 16.10.2017. // -#include - #include - +#include namespace sd { - namespace ops { - LegacyTransformSameOp::LegacyTransformSameOp() : LegacyOp::LegacyOp(1) { - this->getOpDescriptor()->allowInplace(true); - } +namespace ops { +LegacyTransformSameOp::LegacyTransformSameOp() : LegacyOp::LegacyOp(1) { + this->getOpDescriptor()->allowInplace(true); +} - LegacyTransformSameOp::LegacyTransformSameOp(int opNum) : LegacyOp::LegacyOp(1, opNum) { - this->getOpDescriptor()->allowInplace(true); - } +LegacyTransformSameOp::LegacyTransformSameOp(int opNum) + : LegacyOp::LegacyOp(1, opNum) { + this->getOpDescriptor()->allowInplace(true); +} - LegacyOp* LegacyTransformSameOp::clone() { - return new LegacyTransformSameOp(this->_opNum); - } +LegacyOp *LegacyTransformSameOp::clone() { + return new LegacyTransformSameOp(this->_opNum); +} - Nd4jStatus LegacyTransformSameOp::validateAndExecute(Context &block) { - auto input = INPUT_VARIABLE(0); - auto z = OUTPUT_VARIABLE(0); +Nd4jStatus LegacyTransformSameOp::validateAndExecute(Context &block) { + auto input = INPUT_VARIABLE(0); + auto z = OUTPUT_VARIABLE(0); - NDArray::prepareSpecialUse({z}, {input}); + NDArray::prepareSpecialUse({z}, {input}); - int opNum = block.opNum() < 0 ? this->_opNum : block.opNum(); + int opNum = block.opNum() < 0 ? this->_opNum : block.opNum(); - ExtraArguments extras(*block.getTArguments()); - PointersManager manager(block.launchContext(), "LegacyTransformSameOp"); + ExtraArguments extras(block.getTArguments()); + PointersManager manager(block.launchContext(), "LegacyTransformSameOp"); - NativeOpExecutioner::execTransformSame(block.launchContext(), opNum, input->buffer(), input->shapeInfo(), input->specialBuffer(), input->specialShapeInfo(), - z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), extras.argumentsAsT(z->dataType()), nullptr, nullptr); + NativeOpExecutioner::execTransformSame( + block.launchContext(), opNum, input->buffer(), input->shapeInfo(), + input->specialBuffer(), input->specialShapeInfo(), z->buffer(), + z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), + extras.argumentsAsT(z->dataType()), nullptr, nullptr); - manager.synchronize(); - STORE_RESULT(*z); + manager.synchronize(); + STORE_RESULT(*z); - return Status::OK(); - } + return Status::OK(); +} - /** - * For transform operations, output shape always equals to input shape. With just a few exclusions, like im2col and col2im. - * But these ops already have CustomOp implementations. - * - */ - ShapeList *LegacyTransformSameOp::calculateOutputShape(ShapeList *inputShape, sd::graph::Context &block) { - auto inShape = inputShape->at(0); - - Nd4jLong *newShape; - COPY_SHAPE(inShape, newShape); - - return SHAPELIST(CONSTANT(newShape)); - } - } -} \ No newline at end of file +/** + * For transform operations, output shape always equals to input shape. With + * just a few exclusions, like im2col and col2im. But these ops already have + * CustomOp implementations. + * + */ +ShapeList *LegacyTransformSameOp::calculateOutputShape( + ShapeList *inputShape, sd::graph::Context &block) { + auto inShape = inputShape->at(0); + + Nd4jLong *newShape; + COPY_SHAPE(inShape, newShape); + + return SHAPELIST(CONSTANT(newShape)); +} +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/impl/LegacyTransformStrictOp.cpp b/libnd4j/include/ops/declarable/impl/LegacyTransformStrictOp.cpp index 2093e3aab80d..baa3386afe12 100644 --- a/libnd4j/include/ops/declarable/impl/LegacyTransformStrictOp.cpp +++ b/libnd4j/include/ops/declarable/impl/LegacyTransformStrictOp.cpp @@ -18,56 +18,61 @@ // Created by raver119 on 16.10.2017. // -#include - #include - +#include namespace sd { - namespace ops { - LegacyTransformStrictOp::LegacyTransformStrictOp() : LegacyOp::LegacyOp(1) { - this->getOpDescriptor()->allowInplace(true); - } +namespace ops { +LegacyTransformStrictOp::LegacyTransformStrictOp() : LegacyOp::LegacyOp(1) { + this->getOpDescriptor()->allowInplace(true); +} - LegacyTransformStrictOp::LegacyTransformStrictOp(int opNum) : LegacyOp::LegacyOp(1, opNum) { - this->getOpDescriptor()->allowInplace(true); - } +LegacyTransformStrictOp::LegacyTransformStrictOp(int opNum) + : LegacyOp::LegacyOp(1, opNum) { + this->getOpDescriptor()->allowInplace(true); +} - LegacyOp* LegacyTransformStrictOp::clone() { - return new LegacyTransformStrictOp(this->_opNum); - } +LegacyOp *LegacyTransformStrictOp::clone() { + return new LegacyTransformStrictOp(this->_opNum); +} - Nd4jStatus LegacyTransformStrictOp::validateAndExecute(Context &block) { - auto input = INPUT_VARIABLE(0); - auto z = OUTPUT_VARIABLE(0); +Nd4jStatus LegacyTransformStrictOp::validateAndExecute(Context &block) { + auto input = INPUT_VARIABLE(0); + auto z = OUTPUT_VARIABLE(0); - NDArray::prepareSpecialUse({z}, {input}); + NDArray::prepareSpecialUse({z}, {input}); - int opNum = block.opNum() < 0 ? this->_opNum : block.opNum(); + int opNum = block.opNum() < 0 ? this->_opNum : block.opNum(); - ExtraArguments extras(*block.getTArguments()); - PointersManager manager(block.launchContext(), "LegacyTransformStrictOp"); + ExtraArguments extras(block.getTArguments()); + PointersManager manager(block.launchContext(), "LegacyTransformStrictOp"); - NativeOpExecutioner::execTransformStrict(block.launchContext(), opNum, input->buffer(), input->shapeInfo(), input->specialBuffer(), input->specialShapeInfo(), z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), extras.argumentsAsT(z->dataType()), nullptr, nullptr); + NativeOpExecutioner::execTransformStrict( + block.launchContext(), opNum, input->buffer(), input->shapeInfo(), + input->specialBuffer(), input->specialShapeInfo(), z->buffer(), + z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), + extras.argumentsAsT(z->dataType()), nullptr, nullptr); - manager.synchronize(); - STORE_RESULT(*z); + manager.synchronize(); + STORE_RESULT(*z); - return Status::OK(); - } + return Status::OK(); +} - /** - * For transform operations, output shape always equals to input shape. With just a few exclusions, like im2col and col2im. - * But these ops already have CustomOp implementations. - * - */ - ShapeList *LegacyTransformStrictOp::calculateOutputShape(ShapeList *inputShape, sd::graph::Context &block) { - auto inShape = inputShape->at(0); - - Nd4jLong *newShape; - COPY_SHAPE(inShape, newShape); - - return SHAPELIST(CONSTANT(newShape)); - } - } -} \ No newline at end of file +/** + * For transform operations, output shape always equals to input shape. With + * just a few exclusions, like im2col and col2im. But these ops already have + * CustomOp implementations. + * + */ +ShapeList *LegacyTransformStrictOp::calculateOutputShape( + ShapeList *inputShape, sd::graph::Context &block) { + auto inShape = inputShape->at(0); + + Nd4jLong *newShape; + COPY_SHAPE(inShape, newShape); + + return SHAPELIST(CONSTANT(newShape)); +} +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/impl/LogicOp.cpp b/libnd4j/include/ops/declarable/impl/LogicOp.cpp index ae24b5631c9c..a90698a136b9 100644 --- a/libnd4j/include/ops/declarable/impl/LogicOp.cpp +++ b/libnd4j/include/ops/declarable/impl/LogicOp.cpp @@ -21,20 +21,21 @@ #include "ops/declarable/LogicOp.h" namespace sd { - namespace ops { - LogicOp::LogicOp(const char *name) : DeclarableOp::DeclarableOp(name, true) { - // just using DeclarableOp constructor - //this->_descriptor-> - } +namespace ops { +LogicOp::LogicOp(const char *name) : DeclarableOp::DeclarableOp(name, true) { + // just using DeclarableOp constructor + // this->_descriptor-> +} - Nd4jStatus LogicOp::validateAndExecute(sd::graph::Context &block) { - nd4j_logger("WARNING: LogicOps should NOT be ever called\n", ""); - return ND4J_STATUS_BAD_INPUT; - } +Nd4jStatus LogicOp::validateAndExecute(sd::graph::Context &block) { + nd4j_logger("WARNING: LogicOps should NOT be ever called\n", ""); + return ND4J_STATUS_BAD_INPUT; +} - ShapeList* LogicOp::calculateOutputShape(ShapeList *inputShape, sd::graph::Context &block) { - // FIXME: we probably want these ops to evaluate scopes - return SHAPELIST(); - } - } -} \ No newline at end of file +ShapeList *LogicOp::calculateOutputShape(ShapeList *inputShape, + sd::graph::Context &block) { + // FIXME: we probably want these ops to evaluate scopes + return SHAPELIST(); +} +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/impl/OpDescriptor.cpp b/libnd4j/include/ops/declarable/impl/OpDescriptor.cpp index 84c1bc291e8e..9c698e6aa892 100644 --- a/libnd4j/include/ops/declarable/impl/OpDescriptor.cpp +++ b/libnd4j/include/ops/declarable/impl/OpDescriptor.cpp @@ -21,281 +21,275 @@ #include namespace sd { - namespace ops { - - OpDescriptor::OpDescriptor(const char * opName, bool isLogic) { - _logic = isLogic; - _opName = opName; - } - - OpDescriptor::OpDescriptor(int numInputs, const char * opName, bool isScalar) { - _numInputs = numInputs; - _numOutputs = 1; - - _opName = opName; - _hash = sd::ops::HashHelper::getInstance().getLongHash(_opName); - _opClass = sd::graph::OpClass_CONDITIONAL; - - _scalar = isScalar; - } - - OpDescriptor::OpDescriptor(int numInputs, std::string opName, bool isScalar) { - _numInputs = numInputs; - _numOutputs = 1; - - _opName = opName; - _hash = sd::ops::HashHelper::getInstance().getLongHash(_opName); - _opClass = sd::graph::OpClass_CONDITIONAL; - - _scalar = isScalar; - } - - void OpDescriptor::allowInplace(bool reallyAllow){ - _allowsInplace = reallyAllow; - } - - bool OpDescriptor::operator==(const OpDescriptor& other) const { - if (_hash == -1 && other._hash == -1) - return this->_opNum == other._opNum; - else - return this->_hash == other._hash; - } - - OpDescriptor::OpDescriptor(int numInputs, int numOutputs, std::string opName, bool allowsInplace) : OpDescriptor::OpDescriptor(numInputs, numOutputs, opName.c_str(), allowsInplace) { - // - } - - void OpDescriptor::setHash(Nd4jLong hash) { - _hash = hash; - } - - // default constructor - OpDescriptor::OpDescriptor(int numInputs, int numOutputs, const char *opName, bool allowsInplace) { - _numInputs = numInputs; - _numOutputs = numOutputs; - - std::string tmp(opName); - _opName = tmp; - _allowsInplace = allowsInplace; - _hash = sd::ops::HashHelper::getInstance().getLongHash(tmp); - _divergent = false; - - // just default value - _opClass = sd::graph::OpClass_TRANSFORM; - } - - // constructor for configurable op - OpDescriptor::OpDescriptor(int numInputs, int numOutputs, const char *opName, bool allowsInplace, int tArgs, int iArgs) : OpDescriptor::OpDescriptor(numInputs, numOutputs, opName, allowsInplace) { - _tArgs = tArgs; - _iArgs = iArgs; - } - - // constructor for non-configurable divergent op - OpDescriptor::OpDescriptor(int numInputs, int numOutputs, std::string opName, bool allowsInplace, bool divergent) : OpDescriptor::OpDescriptor(numInputs, numOutputs, opName.c_str(), allowsInplace, divergent) { - - } - - // constructor for non-configurable divergent op - OpDescriptor::OpDescriptor(int numInputs, int numOutputs, const char *opName, bool allowsInplace, bool divergent) : OpDescriptor::OpDescriptor(numInputs, numOutputs, opName, allowsInplace) { - _divergent = divergent; - } - - // constructor for configurable divergent op - OpDescriptor::OpDescriptor(int numInputs, int numOutputs, const char *opName, bool allowsInplace, bool divergent, int tArgs, int iArgs) : OpDescriptor(numInputs, numOutputs, opName, allowsInplace, tArgs, iArgs) { - _divergent = divergent; - } - - // default destructor - OpDescriptor::~OpDescriptor() { - // - } - - int OpDescriptor::getNumberOfTArgs() { - return _tArgs; - } - - int OpDescriptor::getNumberOfIArgs() { - return _iArgs; - } - - int OpDescriptor::getNumberOfInputs() { - return _numInputs; - } - - Nd4jLong OpDescriptor::getHash() { - return _hash; - } - - int OpDescriptor::getNumberOfOutputs() { - return _numOutputs; - } - - std::string * OpDescriptor::getOpName() { - return &_opName; - } - - bool OpDescriptor::isDivergent() { - return _divergent; - } - - void OpDescriptor::setOpNum(int opNum) { - _opNum = opNum; - } - - bool OpDescriptor::allowsInplace() { - return _allowsInplace; - } - - int OpDescriptor::getOpNum() { - return _opNum; - } - - OpDescriptor* OpDescriptor::setInputType(const InputType type) { - _inputType = type; - return this; - } - - InputType OpDescriptor::inputType() { - return _inputType; - } - - OpDescriptor* OpDescriptor::setAllowedInputTypes(const std::initializer_list &dtypes) { - _allowedIns = dtypes; - return this; - } - - OpDescriptor* OpDescriptor::setAllowedOutputTypes(const std::initializer_list &dtypes) { - _allowedOuts = dtypes; - return this; - } - - OpDescriptor* OpDescriptor::allowOverride(bool allowOverride) { - _dtypeOverride = allowOverride; - return this; - } - - OpDescriptor* OpDescriptor::setAllowedInputTypes(const sd::DataType dtype) { - _allowedIns.clear(); - _allowedIns.emplace_back(dtype); - return this; - } - - OpDescriptor* OpDescriptor::setAllowedOutputTypes(const sd::DataType dtype) { - _allowedOuts.clear(); - _allowedOuts.emplace_back(dtype); - return this; - } - - OpDescriptor* OpDescriptor::setInputType(const int idx, const sd::DataType dtype) { - _inputTypes[idx] = { dtype }; - return this; - } - - OpDescriptor* OpDescriptor::setOutputType(const int idx, const sd::DataType dtype) { - _outputTypes[idx] = { dtype }; - return this; - } - - OpDescriptor* OpDescriptor::setSameMode(const bool reallySame) { - _sameMode = reallySame; - return this; - } - - OpDescriptor* OpDescriptor::setAllowedInputTypes(int index, const std::vector &dtype) { - _inputTypes[index] = dtype; - return this; - } - - OpDescriptor* OpDescriptor::setAllowedOutputTypes(int index, const std::vector &dtype) { - _outputTypes[index] = dtype; - return this; - } - - OpDescriptor* OpDescriptor::setAllowedInputTypes(int index, sd::DataType dtype) { - if (_inputTypes.count(index) == 0) - _inputTypes[index] = {dtype}; - else - _inputTypes[index].emplace_back(dtype); - - return this; - } - - OpDescriptor* OpDescriptor::setAllowedOutputTypes(int index, sd::DataType dtype) { - if (_outputTypes.count(index) == 0) - _outputTypes[index] = {dtype}; - else - _outputTypes[index].emplace_back(dtype); - - return this; - } - - bool OpDescriptor::checkDataTypesMatch(sd::DataType needle, std::vector &haystack) const { - // if haystack is empty - INHERIT is occurs - any type is perfect? - if (haystack.empty()) - return true; - - // first we're checking for direct input type match - if (std::find(haystack.begin(), haystack.end(), needle) == haystack.end()) { - - // if direct input match failed - we're checking for ANY as allowed input - if (std::find(haystack.begin(), haystack.end(), sd::DataType::ANY) == haystack.end()) - return false; - else - return true; - } else { - return true; - } - } - - bool OpDescriptor::checkInputMatch(int index, sd::DataType dataType) { - // we check for per-input types first - if (_inputTypes.empty() || _inputTypes.count(index) == 0) { - // checking global input types - return checkDataTypesMatch(dataType, _allowedIns); - } else { - // checking data type for specified input - auto allowed = _inputTypes[index]; - return checkDataTypesMatch(dataType, allowed); - } - return true; - } - - bool OpDescriptor::checkOutputMatch(int index, sd::DataType dataType) { - // we check for per-output types first - if (_outputTypes.empty() || _outputTypes.count(index) == 0) { - - // checking global output types - return checkDataTypesMatch(dataType, _allowedOuts); - } else { - // checking data type for specified output - auto allowed = _outputTypes[index]; - return checkDataTypesMatch(dataType, allowed); - } - return true; - } - - bool OpDescriptor::isSameMode() { - return _sameMode; - } - - bool OpDescriptor::isInherit(int index) { - if (std::find(_allowedOuts.begin(), _allowedOuts.end(), sd::DataType::INHERIT) != _allowedOuts.end()) - return true; - if (_outputTypes.count(index) > 0) { - auto vec = _outputTypes[index]; - - if (std::find(vec.begin(), vec.end(), sd::DataType::INHERIT) != vec.end()) - return true; - } - - return false; - } - - std::vector OpDescriptor::getOutputTypesForOutput(int index) { - if (_outputTypes.count(index) > 0) - return _outputTypes.at(index); - else - return std::vector(); - } - } -} \ No newline at end of file +namespace ops { + +OpDescriptor::OpDescriptor(const char* opName, bool isLogic) { + _logic = isLogic; + _opName = opName; +} + +OpDescriptor::OpDescriptor(int numInputs, const char* opName, bool isScalar) { + _numInputs = numInputs; + _numOutputs = 1; + + _opName = opName; + _hash = sd::ops::HashHelper::getInstance().getLongHash(_opName); + _opClass = sd::graph::OpClass_CONDITIONAL; + + _scalar = isScalar; +} + +OpDescriptor::OpDescriptor(int numInputs, std::string opName, bool isScalar) { + _numInputs = numInputs; + _numOutputs = 1; + + _opName = opName; + _hash = sd::ops::HashHelper::getInstance().getLongHash(_opName); + _opClass = sd::graph::OpClass_CONDITIONAL; + + _scalar = isScalar; +} + +void OpDescriptor::allowInplace(bool reallyAllow) { + _allowsInplace = reallyAllow; +} + +bool OpDescriptor::operator==(const OpDescriptor& other) const { + if (_hash == -1 && other._hash == -1) + return this->_opNum == other._opNum; + else + return this->_hash == other._hash; +} + +OpDescriptor::OpDescriptor(int numInputs, int numOutputs, std::string opName, + bool allowsInplace) + : OpDescriptor::OpDescriptor(numInputs, numOutputs, opName.c_str(), + allowsInplace) { + // +} + +void OpDescriptor::setHash(Nd4jLong hash) { _hash = hash; } + +// default constructor +OpDescriptor::OpDescriptor(int numInputs, int numOutputs, const char* opName, + bool allowsInplace) { + _numInputs = numInputs; + _numOutputs = numOutputs; + + std::string tmp(opName); + _opName = tmp; + _allowsInplace = allowsInplace; + _hash = sd::ops::HashHelper::getInstance().getLongHash(tmp); + _divergent = false; + + // just default value + _opClass = sd::graph::OpClass_TRANSFORM; +} + +// constructor for configurable op +OpDescriptor::OpDescriptor(int numInputs, int numOutputs, const char* opName, + bool allowsInplace, int tArgs, int iArgs) + : OpDescriptor::OpDescriptor(numInputs, numOutputs, opName, allowsInplace) { + _tArgs = tArgs; + _iArgs = iArgs; +} + +// constructor for non-configurable divergent op +OpDescriptor::OpDescriptor(int numInputs, int numOutputs, std::string opName, + bool allowsInplace, bool divergent) + : OpDescriptor::OpDescriptor(numInputs, numOutputs, opName.c_str(), + allowsInplace, divergent) {} + +// constructor for non-configurable divergent op +OpDescriptor::OpDescriptor(int numInputs, int numOutputs, const char* opName, + bool allowsInplace, bool divergent) + : OpDescriptor::OpDescriptor(numInputs, numOutputs, opName, allowsInplace) { + _divergent = divergent; +} + +// constructor for configurable divergent op +OpDescriptor::OpDescriptor(int numInputs, int numOutputs, const char* opName, + bool allowsInplace, bool divergent, int tArgs, + int iArgs) + : OpDescriptor(numInputs, numOutputs, opName, allowsInplace, tArgs, iArgs) { + _divergent = divergent; +} + +// default destructor +OpDescriptor::~OpDescriptor() { + // +} + +int OpDescriptor::getNumberOfTArgs() { return _tArgs; } + +int OpDescriptor::getNumberOfIArgs() { return _iArgs; } + +int OpDescriptor::getNumberOfInputs() { return _numInputs; } + +Nd4jLong OpDescriptor::getHash() { return _hash; } + +int OpDescriptor::getNumberOfOutputs() { return _numOutputs; } + +std::string* OpDescriptor::getOpName() { return &_opName; } + +bool OpDescriptor::isDivergent() { return _divergent; } + +void OpDescriptor::setOpNum(int opNum) { _opNum = opNum; } + +bool OpDescriptor::allowsInplace() { return _allowsInplace; } + +int OpDescriptor::getOpNum() { return _opNum; } + +OpDescriptor* OpDescriptor::setInputType(const InputType type) { + _inputType = type; + return this; +} + +InputType OpDescriptor::inputType() { return _inputType; } + +OpDescriptor* OpDescriptor::setAllowedInputTypes( + const std::initializer_list& dtypes) { + _allowedIns = dtypes; + return this; +} + +OpDescriptor* OpDescriptor::setAllowedOutputTypes( + const std::initializer_list& dtypes) { + _allowedOuts = dtypes; + return this; +} + +OpDescriptor* OpDescriptor::allowOverride(bool allowOverride) { + _dtypeOverride = allowOverride; + return this; +} + +OpDescriptor* OpDescriptor::setAllowedInputTypes(const sd::DataType dtype) { + _allowedIns.clear(); + _allowedIns.emplace_back(dtype); + return this; +} + +OpDescriptor* OpDescriptor::setAllowedOutputTypes(const sd::DataType dtype) { + _allowedOuts.clear(); + _allowedOuts.emplace_back(dtype); + return this; +} + +OpDescriptor* OpDescriptor::setInputType(const int idx, + const sd::DataType dtype) { + _inputTypes[idx] = {dtype}; + return this; +} + +OpDescriptor* OpDescriptor::setOutputType(const int idx, + const sd::DataType dtype) { + _outputTypes[idx] = {dtype}; + return this; +} + +OpDescriptor* OpDescriptor::setSameMode(const bool reallySame) { + _sameMode = reallySame; + return this; +} + +OpDescriptor* OpDescriptor::setAllowedInputTypes( + int index, const std::vector& dtype) { + _inputTypes[index] = dtype; + return this; +} + +OpDescriptor* OpDescriptor::setAllowedOutputTypes( + int index, const std::vector& dtype) { + _outputTypes[index] = dtype; + return this; +} + +OpDescriptor* OpDescriptor::setAllowedInputTypes(int index, + sd::DataType dtype) { + if (_inputTypes.count(index) == 0) + _inputTypes[index] = {dtype}; + else + _inputTypes[index].emplace_back(dtype); + + return this; +} + +OpDescriptor* OpDescriptor::setAllowedOutputTypes(int index, + sd::DataType dtype) { + if (_outputTypes.count(index) == 0) + _outputTypes[index] = {dtype}; + else + _outputTypes[index].emplace_back(dtype); + + return this; +} + +bool OpDescriptor::checkDataTypesMatch( + sd::DataType needle, std::vector& haystack) const { + // if haystack is empty - INHERIT is occurs - any type is perfect? + if (haystack.empty()) return true; + + // first we're checking for direct input type match + if (std::find(haystack.begin(), haystack.end(), needle) == haystack.end()) { + // if direct input match failed - we're checking for ANY as allowed input + if (std::find(haystack.begin(), haystack.end(), sd::DataType::ANY) == + haystack.end()) + return false; + else + return true; + } else { + return true; + } +} + +bool OpDescriptor::checkInputMatch(int index, sd::DataType dataType) { + // we check for per-input types first + if (_inputTypes.empty() || _inputTypes.count(index) == 0) { + // checking global input types + return checkDataTypesMatch(dataType, _allowedIns); + } else { + // checking data type for specified input + auto allowed = _inputTypes[index]; + return checkDataTypesMatch(dataType, allowed); + } + return true; +} + +bool OpDescriptor::checkOutputMatch(int index, sd::DataType dataType) { + // we check for per-output types first + if (_outputTypes.empty() || _outputTypes.count(index) == 0) { + // checking global output types + return checkDataTypesMatch(dataType, _allowedOuts); + } else { + // checking data type for specified output + auto allowed = _outputTypes[index]; + return checkDataTypesMatch(dataType, allowed); + } + return true; +} + +bool OpDescriptor::isSameMode() { return _sameMode; } + +bool OpDescriptor::isInherit(int index) { + if (std::find(_allowedOuts.begin(), _allowedOuts.end(), + sd::DataType::INHERIT) != _allowedOuts.end()) + return true; + if (_outputTypes.count(index) > 0) { + auto vec = _outputTypes[index]; + + if (std::find(vec.begin(), vec.end(), sd::DataType::INHERIT) != vec.end()) + return true; + } + + return false; +} + +std::vector OpDescriptor::getOutputTypesForOutput(int index) { + if (_outputTypes.count(index) > 0) + return _outputTypes.at(index); + else + return std::vector(); +} +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/impl/OpRegistrator.cpp b/libnd4j/include/ops/declarable/impl/OpRegistrator.cpp index 327cb0482caa..6a6fc45e6774 100644 --- a/libnd4j/include/ops/declarable/impl/OpRegistrator.cpp +++ b/libnd4j/include/ops/declarable/impl/OpRegistrator.cpp @@ -18,252 +18,263 @@ // Created by raver119 on 07.10.2017. // - - #include + #include namespace sd { - namespace ops { +namespace ops { - /////////////////////////////// +/////////////////////////////// - template - __registrator::__registrator() { - auto ptr = new OpName(); - OpRegistrator::getInstance().registerOperation(ptr); - } +template +__registrator::__registrator() { + auto ptr = new OpName(); + OpRegistrator::getInstance().registerOperation(ptr); +} +template +__registratorSynonym::__registratorSynonym(const char* name, + const char* oname) { + auto ptr = reinterpret_cast( + OpRegistrator::getInstance().getOperation(oname)); + if (ptr == nullptr) { + std::string newName(name); + std::string oldName(oname); + + OpRegistrator::getInstance().updateMSVC( + sd::ops::HashHelper::getInstance().getLongHash(newName), oldName); + return; + } + OpRegistrator::getInstance().registerOperation(name, ptr); +} - template - __registratorSynonym::__registratorSynonym(const char *name, const char *oname) { - auto ptr = reinterpret_cast(OpRegistrator::getInstance().getOperation(oname)); - if (ptr == nullptr) { - std::string newName(name); - std::string oldName(oname); +/////////////////////////////// - OpRegistrator::getInstance().updateMSVC(sd::ops::HashHelper::getInstance().getLongHash(newName), oldName); - return; - } - OpRegistrator::getInstance().registerOperation(name, ptr); - } +OpRegistrator& OpRegistrator::getInstance() { + static OpRegistrator instance; - /////////////////////////////// + return instance; +} +void OpRegistrator::updateMSVC(Nd4jLong newHash, std::string& oldName) { + std::pair pair(newHash, oldName); + _msvc.insert(pair); +} - OpRegistrator& OpRegistrator::getInstance() { - static OpRegistrator instance; - return instance; - } +template +std::string OpRegistrator::local_to_string(T value) { + // create an output string stream + std::ostringstream os; + // throw the value into the string stream + os << value; - void OpRegistrator::updateMSVC(Nd4jLong newHash, std::string& oldName) { - std::pair pair(newHash, oldName); - _msvc.insert(pair); - } + // convert the string stream into a string and return + return os.str(); +} - template - std::string OpRegistrator::local_to_string(T value) { - //create an output string stream - std::ostringstream os ; +template <> +std::string OpRegistrator::local_to_string(int value) { + // create an output string stream + std::ostringstream os; - //throw the value into the string stream - os << value ; + // throw the value into the string stream + os << value; - //convert the string stream into a string and return - return os.str() ; - } + // convert the string stream into a string and return + return os.str(); +} - template <> - std::string OpRegistrator::local_to_string(int value) { - //create an output string stream - std::ostringstream os ; +void OpRegistrator::sigIntHandler(int sig) { - //throw the value into the string stream - os << value ; +} - //convert the string stream into a string and return - return os.str() ; - } +void OpRegistrator::exitHandler() { - void OpRegistrator::sigIntHandler(int sig) { +} - } +void OpRegistrator::sigSegVHandler(int sig) { - void OpRegistrator::exitHandler() { +} - } +OpRegistrator::~OpRegistrator() { +#ifndef _RELEASE + _msvc.clear(); - void OpRegistrator::sigSegVHandler(int sig) { + for (auto x : _uniqueH) delete x; - } + _uniqueH.clear(); - OpRegistrator::~OpRegistrator() { -#ifndef _RELEASE - _msvc.clear(); + _declarablesD.clear(); - for (auto x : _uniqueD) - delete x; + _declarablesLD.clear(); +#endif +} - for (auto x: _uniqueH) - delete x; +const char* OpRegistrator::getAllCustomOperations() { + _locker.lock(); + + if (!isInit) { + for (MAP_IMPL>::iterator + it = _declarablesD.begin(); + it != _declarablesD.end(); ++it) { + std::string op = + it->first + ":" + + local_to_string(it->second->getOpDescriptor()->getHash()) + ":" + + local_to_string(it->second->getOpDescriptor()->getNumberOfInputs()) + + ":" + + local_to_string(it->second->getOpDescriptor()->getNumberOfOutputs()) + + ":" + + local_to_string(it->second->getOpDescriptor()->allowsInplace()) + + ":" + + local_to_string(it->second->getOpDescriptor()->getNumberOfTArgs()) + + ":" + + local_to_string(it->second->getOpDescriptor()->getNumberOfIArgs()) + + ":" + ";"; + _opsList += op; + } - _uniqueD.clear(); + isInit = true; + } - _uniqueH.clear(); + _locker.unlock(); - _declarablesD.clear(); + return _opsList.c_str(); +} - _declarablesLD.clear(); -#endif - } - - const char * OpRegistrator::getAllCustomOperations() { - _locker.lock(); - - if (!isInit) { - for (MAP_IMPL::iterator it=_declarablesD.begin(); it!=_declarablesD.end(); ++it) { - std::string op = it->first + ":" - + local_to_string(it->second->getOpDescriptor()->getHash()) + ":" - + local_to_string(it->second->getOpDescriptor()->getNumberOfInputs()) + ":" - + local_to_string(it->second->getOpDescriptor()->getNumberOfOutputs()) + ":" - + local_to_string(it->second->getOpDescriptor()->allowsInplace()) + ":" - + local_to_string(it->second->getOpDescriptor()->getNumberOfTArgs()) + ":" - + local_to_string(it->second->getOpDescriptor()->getNumberOfIArgs()) + ":" - + ";" ; - _opsList += op; - } - - isInit = true; - } - - _locker.unlock(); - - return _opsList.c_str(); - } - - - bool OpRegistrator::registerOperation(const char* name, sd::ops::DeclarableOp* op) { - std::string str(name); - std::pair pair(str, op); - _declarablesD.insert(pair); - - auto hash = sd::ops::HashHelper::getInstance().getLongHash(str); - std::pair pair2(hash, op); - _declarablesLD.insert(pair2); - return true; - } - - /** - * This method registers operation - * - * @param op - */ - bool OpRegistrator::registerOperation(sd::ops::DeclarableOp *op) { - _uniqueD.emplace_back(op); - return registerOperation(op->getOpName()->c_str(), op); - } - - void OpRegistrator::registerHelper(sd::ops::platforms::PlatformHelper* op) { - std::pair p = {op->hash(), op->engine()}; - if (_helpersLH.count(p) > 0) - throw std::runtime_error("Tried to double register PlatformHelper"); - - _uniqueH.emplace_back(op); - - nd4j_debug("Adding helper for op \"%s\": [%lld - %i]\n", op->name().c_str(), op->hash(), (int) op->engine()); - - std::pair, sd::ops::platforms::PlatformHelper*> pair({op->name(), op->engine()}, op); - _helpersH.insert(pair); - - std::pair, sd::ops::platforms::PlatformHelper*> pair2(p, op); - _helpersLH.insert(pair2); - } - - sd::ops::DeclarableOp* OpRegistrator::getOperation(const char *name) { - std::string str(name); - return getOperation(str); - } - - /** - * This method returns registered Op by name - * - * @param name - * @return - */ - sd::ops::DeclarableOp *OpRegistrator::getOperation(Nd4jLong hash) { - if (!_declarablesLD.count(hash)) { - if (!_msvc.count(hash)) { - nd4j_printf("Unknown D operation requested by hash: [%lld]\n", hash); - return nullptr; - } else { - _locker.lock(); - - auto str = _msvc.at(hash); - auto op = _declarablesD.at(str); - auto oHash = op->getOpDescriptor()->getHash(); - - std::pair pair(oHash, op); - _declarablesLD.insert(pair); - - _locker.unlock(); - } - } - - return _declarablesLD.at(hash); - } - - sd::ops::DeclarableOp *OpRegistrator::getOperation(std::string& name) { - if (!_declarablesD.count(name)) { - nd4j_debug("Unknown operation requested: [%s]\n", name.c_str()); - return nullptr; - } - - return _declarablesD.at(name); - } - - sd::ops::platforms::PlatformHelper* OpRegistrator::getPlatformHelper(Nd4jLong hash, samediff::Engine engine) { - std::pair p = {hash, engine}; - if (_helpersLH.count(p) == 0) - throw std::runtime_error("Requested helper can't be found"); - - return _helpersLH[p]; - } - - bool OpRegistrator::hasHelper(Nd4jLong hash, samediff::Engine engine) { - std::pair p = {hash, engine}; - return _helpersLH.count(p) > 0; - } - - int OpRegistrator::numberOfOperations() { - return (int) _declarablesLD.size(); - } - - std::vector OpRegistrator::getAllHashes() { - std::vector result; - - for (auto &v:_declarablesLD) { - result.emplace_back(v.first); - } - - return result; - } - } +bool OpRegistrator::hasOperation(const std::string& opName) const { + return _declarablesD.count(opName) > 0; } -namespace std { - size_t hash>::operator()(const std::pair& k) const { - using std::hash; - auto res = std::hash()(k.first); - res ^= std::hash()((int) k.second) + 0x9e3779b9 + (res << 6) + (res >> 2); - return res; - } +bool OpRegistrator::hasOperation(Nd4jLong opName) const { + return _declarablesLD.count(opName) > 0; +} - size_t hash>::operator()(const std::pair& k) const { - using std::hash; - auto res = std::hash()(k.first); - res ^= std::hash()((int) k.second) + 0x9e3779b9 + (res << 6) + (res >> 2); - return res; +bool OpRegistrator::registerOperation( + const std::string& opName, std::shared_ptr op) { + std::pair> pair(opName, + op); + _declarablesD.insert(pair); + + auto hash = sd::ops::HashHelper::getInstance().getLongHash(opName); + std::pair> pair2(hash, op); + _declarablesLD.insert(pair2); + return true; +} + +/** + * This method registers operation + * + * @param op + */ +bool OpRegistrator::registerOperation( + std::shared_ptr op) { + return registerOperation(op->getOpName(), op); +} + +void OpRegistrator::registerHelper(sd::ops::platforms::PlatformHelper* op) { + std::pair p = {op->hash(), op->engine()}; + if (_helpersLH.count(p) > 0) + throw std::runtime_error("Tried to double register PlatformHelper"); + + _uniqueH.emplace_back(op); + + nd4j_debug("Adding helper for op \"%s\": [%lld - %i]\n", op->name().c_str(), + op->hash(), (int)op->engine()); + + std::pair, + sd::ops::platforms::PlatformHelper*> + pair({op->name(), op->engine()}, op); + _helpersH.insert(pair); + + std::pair, + sd::ops::platforms::PlatformHelper*> + pair2(p, op); + _helpersLH.insert(pair2); +} + +/** + * This method returns registered Op by name + * + * @param name + * @return + */ +std::shared_ptr OpRegistrator::getOperation( + Nd4jLong hash) { + if (!_declarablesLD.count(hash)) { + if (!_msvc.count(hash)) { + nd4j_printf("Unknown D operation requested by hash: [%lld]\n", hash); + return nullptr; + } else { + std::lock_guard lock(_locker); + + auto str = _msvc.at(hash); + auto op = _declarablesD.at(str); + auto oHash = op->getOpDescriptor()->getHash(); + + std::pair> pair(oHash, + op); + _declarablesLD.insert(pair); } + } + + return _declarablesLD.at(hash); } +std::shared_ptr OpRegistrator::getOperation( + const std::string& name) { + if (!_declarablesD.count(name)) { + nd4j_debug("Unknown operation requested: [%s]\n", name.c_str()); + return nullptr; + } + + return _declarablesD.at(name); +} + +sd::ops::platforms::PlatformHelper* OpRegistrator::getPlatformHelper( + Nd4jLong hash, samediff::Engine engine) { + std::pair p = {hash, engine}; + if (_helpersLH.count(p) == 0) + throw std::runtime_error("Requested helper can't be found"); + + return _helpersLH[p]; +} + +bool OpRegistrator::hasHelper(Nd4jLong hash, samediff::Engine engine) { + std::pair p = {hash, engine}; + return _helpersLH.count(p) > 0; +} + +int OpRegistrator::numberOfOperations() { return (int)_declarablesLD.size(); } + +std::vector OpRegistrator::getAllHashes() { + std::vector result; + + for (auto& v : _declarablesLD) { + result.emplace_back(v.first); + } + + return result; +} + +} // namespace ops +} // namespace sd + +namespace std { +size_t hash>::operator()( + const std::pair& k) const { + using std::hash; + auto res = std::hash()(k.first); + res ^= std::hash()((int)k.second) + 0x9e3779b9 + (res << 6) + (res >> 2); + return res; +} + +size_t hash>::operator()( + const std::pair& k) const { + using std::hash; + auto res = std::hash()(k.first); + res ^= std::hash()((int)k.second) + 0x9e3779b9 + (res << 6) + (res >> 2); + return res; +} +} // namespace std diff --git a/libnd4j/include/ops/declarable/impl/OpTuple.cpp b/libnd4j/include/ops/declarable/impl/OpTuple.cpp index fc43739e8b2d..c64ae1a755cb 100644 --- a/libnd4j/include/ops/declarable/impl/OpTuple.cpp +++ b/libnd4j/include/ops/declarable/impl/OpTuple.cpp @@ -20,38 +20,40 @@ #include "ops/declarable/OpTuple.h" -sd::ops::OpTuple::OpTuple(const char *opName) { - _opName = opName; -} - -sd::ops::OpTuple::OpTuple(const char *opName, std::initializer_list &&inputs, std::initializer_list &&tArgs, std::initializer_list &&iArgs) { - _opName = opName; - _inputs = inputs; - _iArgs = iArgs; - _tArgs = tArgs; +sd::ops::OpTuple::OpTuple(const char *opName) { _opName = opName; } + +sd::ops::OpTuple::OpTuple(const char *opName, + std::initializer_list &&inputs, + std::initializer_list &&tArgs, + std::initializer_list &&iArgs) { + _opName = opName; + _inputs = inputs; + _iArgs = iArgs; + _tArgs = tArgs; } sd::ops::OpTuple::~OpTuple() { - for (auto v: _inputs) - delete v; + for (auto v : _inputs) delete v; } sd::ops::OpTuple *sd::ops::OpTuple::addInput(sd::NDArray *array) { - _inputs.emplace_back(array); - return this; + _inputs.emplace_back(array); + return this; } sd::ops::OpTuple *sd::ops::OpTuple::addOutput(sd::NDArray *array) { - _outputs.emplace_back(array); - return this; + _outputs.emplace_back(array); + return this; } -sd::ops::OpTuple *sd::ops::OpTuple::setTArgs(std::initializer_list tArgs) { - _tArgs = tArgs; - return this; +sd::ops::OpTuple *sd::ops::OpTuple::setTArgs( + std::initializer_list tArgs) { + _tArgs = tArgs; + return this; } -sd::ops::OpTuple *sd::ops::OpTuple::setIArgs(std::initializer_list iArgs) { - _iArgs = iArgs; - return this; +sd::ops::OpTuple *sd::ops::OpTuple::setIArgs( + std::initializer_list iArgs) { + _iArgs = iArgs; + return this; } diff --git a/libnd4j/include/ops/declarable/impl/PlatformHelper.cpp b/libnd4j/include/ops/declarable/impl/PlatformHelper.cpp index 245626c09990..f935dbe52c53 100644 --- a/libnd4j/include/ops/declarable/impl/PlatformHelper.cpp +++ b/libnd4j/include/ops/declarable/impl/PlatformHelper.cpp @@ -19,81 +19,79 @@ // #include "../PlatformHelper.h" + #include namespace sd { - namespace ops { - namespace platforms { - PlatformHelper::PlatformHelper(const char *name, samediff::Engine engine) { - // we just store name/hash of target operation - _name = std::string(name); - _hash = HashHelper::getInstance().getLongHash(_name); - _engine = engine; - } - - sd::NDArray* PlatformHelper::getNullifiedZ(graph::Context& block, int inputId) { - auto result = getZ(block, inputId); - if (result != nullptr && !block.isInplace()) - result->nullify(); - - return result; - } - - sd::NDArray* PlatformHelper::getZ(graph::Context &ctx, int inputId) { - NDArray *z = nullptr; - - if (ctx.isFastPath()) { - if (ctx.fastpath_out().size() <= inputId) { - if (ctx.isInplace()) { - z = ctx.fastpath_in()[inputId]; - } else - throw std::runtime_error("fastpath_out: unresolved output array"); - } else { - z = ctx.fastpath_out()[inputId]; - } - } else { - std::pair pair(ctx.nodeId(), inputId); - - if (ctx.isInplace()) { - z = ctx.variable(inputId)->getNDArray(); - - // hypothetically it's possible to have no variable. chances are low, but who knows. let's just create it for now - if (!ctx.getVariableSpace()->hasVariable(pair)) { - auto var = new graph::Variable(); - ctx.getVariableSpace()->putVariable(pair, var); - } - - // now we're saving input array as output array - auto var = ctx.getVariableSpace()->getVariable(pair); - var->markRemovable(false); - var->setNDArray(z); - } else if (!ctx.isInplace()) { - auto var = ctx.variable(pair); - if (var->getNDArray() != nullptr && var->getNDArray()->nonNull()) { - z = var->getNDArray(); - } else { - nd4j_printf("Can't get Z variable for node_%i!\n", ctx.nodeId()); - } - } else { - nd4j_printf("BOOM!\n", ""); - throw std::runtime_error("Boom!"); - } - } - - return z; - } - - samediff::Engine PlatformHelper::engine() { - return _engine; - } - - std::string PlatformHelper::name() { - return _name; - } - - Nd4jLong PlatformHelper::hash() { - return _hash; - } - } +namespace ops { +namespace platforms { +PlatformHelper::PlatformHelper(const char* name, samediff::Engine engine) { + // we just store name/hash of target operation + _name = std::string(name); + _hash = HashHelper::getInstance().getLongHash(_name); + _engine = engine; +} + +sd::NDArray* PlatformHelper::getNullifiedZ(graph::Context& block, int inputId) { + auto result = getZ(block, inputId); + if (result != nullptr && result->undefined()) return nullptr; + + if (result != nullptr && !block.isInplace()) result->nullify(); + + return result; +} + +sd::NDArray* PlatformHelper::getZ(graph::Context& ctx, int inputId) { + NDArray* z = nullptr; + + if (ctx.isFastPath()) { + if (ctx.fastpath_out().size() <= inputId) { + if (ctx.isInplace()) { + z = ctx.fastpath_in()[inputId].get(); + } else + throw std::runtime_error("fastpath_out: unresolved output array"); + } else { + z = ctx.fastpath_out()[inputId].get(); + } + } else { + std::pair pair(ctx.nodeId(), inputId); + + if (ctx.isInplace()) { + auto vz = ctx.variable(inputId)->getNDArray(); + z = vz.get(); + + // hypothetically it's possible to have no variable. chances are low, but + // who knows. let's just create it for now + if (!ctx.getVariableSpace()->hasVariable(pair)) { + auto var = std::make_shared(); + ctx.getVariableSpace()->putVariable(pair, var); + } + + // now we're saving input array as output array + auto var = ctx.getVariableSpace()->getVariable(pair); + var->markRemovable(false); + var->setNDArray(vz); + } else if (!ctx.isInplace()) { + auto var = ctx.variable(pair); + if (var->getNDArray() != nullptr && var->getNDArray()->nonNull()) { + z = var->getNDArray().get(); + } else { + nd4j_printf("Can't get Z variable for node_%i!\n", ctx.nodeId()); + } + } else { + nd4j_printf("BOOM!\n", ""); + throw std::runtime_error("Boom!"); } -} \ No newline at end of file + } + + return z; +} + +samediff::Engine PlatformHelper::engine() { return _engine; } + +std::string PlatformHelper::name() { return _name; } + +Nd4jLong PlatformHelper::hash() { return _hash; } +} // namespace platforms +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/platform/cudnn/avgpool2d.cu b/libnd4j/include/ops/declarable/platform/cudnn/avgpool2d.cu index eb213f4c2ae6..4a4deb642f94 100644 --- a/libnd4j/include/ops/declarable/platform/cudnn/avgpool2d.cu +++ b/libnd4j/include/ops/declarable/platform/cudnn/avgpool2d.cu @@ -18,121 +18,170 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // +#include #include "cudnnUtils.h" -#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace platforms { - ////////////////////////////////////////////////////////////////////////// PLATFORM_IMPL(avgpool2d, ENGINE_CUDA) { - - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; - const auto kH = INT_ARG(0); - const auto kW = INT_ARG(1); - const auto sH = INT_ARG(2); - const auto sW = INT_ARG(3); - auto pH = INT_ARG(4); - auto pW = INT_ARG(5); - const auto dH = INT_ARG(6); - const auto dW = INT_ARG(7); - const auto paddingMode = static_cast(INT_ARG(8)); - const auto extraParam0 = INT_ARG(9); - const int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC - - REQUIRE_TRUE(input->rankOf() == 4, 0, "AVGPOOL2D CUDNN op: input should have rank of 4, but got %i instead", input->rankOf()); - REQUIRE_TRUE(dH != 0 && dW != 0, 0, "AVGPOOL2D CUDNN op: dilation must not be zero, but got instead {%i, %i}", dH, dW); - - int oH = 0; - int oW = 0; - - const int iH = static_cast(isNCHW ? input->sizeAt(2) : input->sizeAt(1)); - const int iW = static_cast(isNCHW ? input->sizeAt(3) : input->sizeAt(2)); - - ConvolutionUtils::calcOutSizePool2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, paddingMode); - - if (paddingMode) - ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); - - const cudnnPoolingMode_t mode = (extraParam0 == 0) ? CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING : CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING; - - pooling2dCUDNN(block.launchContext(), input, output, kH, kW, sH, sW, pH, pW, dH, dW, isNCHW, mode); - - return Status::OK(); + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad + // Height/Width; 6,7 - dilation Height/Width; 8 - same mode; + const auto kH = INT_ARG(0); + const auto kW = INT_ARG(1); + const auto sH = INT_ARG(2); + const auto sW = INT_ARG(3); + auto pH = INT_ARG(4); + auto pW = INT_ARG(5); + const auto dH = INT_ARG(6); + const auto dW = INT_ARG(7); + const auto paddingMode = static_cast(INT_ARG(8)); + const auto extraParam0 = INT_ARG(9); + const int isNCHW = block.numI() > 10 + ? !INT_ARG(10) + : 1; // INT_ARG(10): 0-NCHW, 1-NHWC + + REQUIRE_TRUE( + input->rankOf() == 4, 0, + "AVGPOOL2D CUDNN op: input should have rank of 4, but got %i instead", + input->rankOf()); + REQUIRE_TRUE( + dH != 0 && dW != 0, 0, + "AVGPOOL2D CUDNN op: dilation must not be zero, but got instead {%i, %i}", + dH, dW); + + int oH = 0; + int oW = 0; + + const int iH = static_cast(isNCHW ? input->sizeAt(2) : input->sizeAt(1)); + const int iW = static_cast(isNCHW ? input->sizeAt(3) : input->sizeAt(2)); + + ConvolutionUtils::calcOutSizePool2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, + iH, iW, paddingMode); + + if (paddingMode) + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, + dW); + + const cudnnPoolingMode_t mode = + (extraParam0 == 0) ? CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING + : CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING; + + pooling2dCUDNN(block.launchContext(), input, output, kH, kW, sH, sW, pH, pW, + dH, dW, isNCHW, mode); + + return Status::OK(); } ////////////////////////////////////////////////////////////////////////// PLATFORM_CHECK(avgpool2d, ENGINE_CUDA) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - const auto goodType = input->dataType() == DataType::DOUBLE || input->dataType() == DataType::FLOAT32 || input->dataType() == DataType::HALF || input->dataType() == DataType::INT32; + const auto goodType = input->dataType() == DataType::DOUBLE || + input->dataType() == DataType::FLOAT32 || + input->dataType() == DataType::HALF || + input->dataType() == DataType::INT32; - return goodType && input->dataType() == output->dataType(); + return goodType && input->dataType() == output->dataType(); } ////////////////////////////////////////////////////////////////////////// PLATFORM_IMPL(avgpool2d_bp, ENGINE_CUDA) { - - auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - auto gradO = INPUT_VARIABLE(1); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next - auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon - - const auto kH = INT_ARG(0); // filter(kernel) height - const auto kW = INT_ARG(1); // filter(kernel) width - const auto sH = INT_ARG(2); // strides height - const auto sW = INT_ARG(3); // strides width - auto pH = INT_ARG(4); // paddings height - auto pW = INT_ARG(5); // paddings width - const auto dH = INT_ARG(6); // dilations height - const auto dW = INT_ARG(7); // dilations width - const auto paddingMode = INT_ARG(8); // 0-VALID, 1-SAME - const auto extraParam0 = INT_ARG(9); - const auto isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC - - REQUIRE_TRUE(input->rankOf() == 4, 0, "AVGPOOL2D_BP CUDNN op: input should have rank of 4, but got %i instead", input->rankOf()); - REQUIRE_TRUE(dH != 0 && dW != 0, 0, "AVGPOOL2D_BP CUDNN op: dilation must not be zero, but got instead {%i, %i}", dH, dW); - - int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; - int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, 0, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); - - std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oH,oW, 0,indIOioC,indIiH,indIiH+1}); - std::vector expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iH,iW, 0,indIOioC,indIiH,indIiH+1}); - REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "AVGPOOL2D_BP CUDNN op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); - REQUIRE_TRUE(gradI->isSameShape(expectedGradIShape), 0, "AVGPOOL2D_BP CUDNN op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradIShape).c_str(), ShapeUtils::shapeAsString(gradI).c_str()); - - if(paddingMode) // SAME - ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); - - const cudnnPoolingMode_t mode = (extraParam0 == 0) ? CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING : CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING; - - pooling2dBpCUDNN(block.launchContext(), input, gradO, gradI, kH, kW, sH, sW, pH, pW, dH, dW, isNCHW, mode); - - return Status::OK(); + auto input = + INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + auto gradO = INPUT_VARIABLE( + 1); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next + auto gradI = OUTPUT_VARIABLE( + 0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon + + const auto kH = INT_ARG(0); // filter(kernel) height + const auto kW = INT_ARG(1); // filter(kernel) width + const auto sH = INT_ARG(2); // strides height + const auto sW = INT_ARG(3); // strides width + auto pH = INT_ARG(4); // paddings height + auto pW = INT_ARG(5); // paddings width + const auto dH = INT_ARG(6); // dilations height + const auto dW = INT_ARG(7); // dilations width + const auto paddingMode = INT_ARG(8); // 0-VALID, 1-SAME + const auto extraParam0 = INT_ARG(9); + const auto isNCHW = block.numI() > 10 + ? !INT_ARG(10) + : 1; // INT_ARG(10): 0-NCHW, 1-NHWC + + REQUIRE_TRUE( + input->rankOf() == 4, 0, + "AVGPOOL2D_BP CUDNN op: input should have rank of 4, but got %i instead", + input->rankOf()); + REQUIRE_TRUE(dH != 0 && dW != 0, 0, + "AVGPOOL2D_BP CUDNN op: dilation must not be zero, but got " + "instead {%i, %i}", + dH, dW); + + int bS, iC, iH, iW, oC, oH, + oW; // batch size, input channels, input height/width, output channels, + // output height/width; + int indIOioC, indIiH, indWoC, indWiC, indWkH, + indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, 0, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, + indWiC, indWoC, indWkH, indOoH); + + std::vector expectedGradOShape = + ShapeUtils::composeShapeUsingDimsAndIdx( + {bS, iC, oH, oW, 0, indIOioC, indIiH, indIiH + 1}); + std::vector expectedGradIShape = + ShapeUtils::composeShapeUsingDimsAndIdx( + {bS, iC, iH, iW, 0, indIOioC, indIiH, indIiH + 1}); + REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, + "AVGPOOL2D_BP CUDNN op: wrong shape of output's gradients array " + "(next epsilon), expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedGradOShape).c_str(), + ShapeUtils::shapeAsString(gradO).c_str()); + REQUIRE_TRUE(gradI->isSameShape(expectedGradIShape), 0, + "AVGPOOL2D_BP CUDNN op: wrong shape of input's gradients array " + "(epsilon), expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedGradIShape).c_str(), + ShapeUtils::shapeAsString(gradI).c_str()); + + if (paddingMode) // SAME + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, + dW); + + const cudnnPoolingMode_t mode = + (extraParam0 == 0) ? CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING + : CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING; + + pooling2dBpCUDNN(block.launchContext(), input, gradO, gradI, kH, kW, sH, sW, + pH, pW, dH, dW, isNCHW, mode); + + return Status::OK(); } PLATFORM_CHECK(avgpool2d_bp, ENGINE_CUDA) { - - auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - auto gradO = INPUT_VARIABLE(1); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next - auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon - - const auto goodType = input->dataType() == DataType::DOUBLE || input->dataType() == DataType::FLOAT32 || input->dataType() == DataType::HALF || input->dataType() == DataType::INT32; - - return goodType && (input->dataType() == gradO->dataType()) - && (input->dataType() == gradI->dataType()) - && shape::haveSameShapeAndStrides(input->shapeInfo(), gradI->shapeInfo()); + auto input = + INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + auto gradO = INPUT_VARIABLE( + 1); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next + auto gradI = OUTPUT_VARIABLE( + 0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon + + const auto goodType = input->dataType() == DataType::DOUBLE || + input->dataType() == DataType::FLOAT32 || + input->dataType() == DataType::HALF || + input->dataType() == DataType::INT32; + + return goodType && (input->dataType() == gradO->dataType()) && + (input->dataType() == gradI->dataType()) && + shape::haveSameShapeAndStrides(input->shapeInfo(), gradI->shapeInfo()); } - -} -} -} +} // namespace platforms +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/platform/cudnn/avgpool3d.cu b/libnd4j/include/ops/declarable/platform/cudnn/avgpool3d.cu index da2fdbc098f7..15e1dd975161 100644 --- a/libnd4j/include/ops/declarable/platform/cudnn/avgpool3d.cu +++ b/libnd4j/include/ops/declarable/platform/cudnn/avgpool3d.cu @@ -18,127 +18,188 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // +#include #include "cudnnUtils.h" -#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace platforms { - ////////////////////////////////////////////////////////////////////////// PLATFORM_IMPL(avgpool3dnew, ENGINE_CUDA) { - - auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) - auto output = OUTPUT_VARIABLE(0); // [bS, oD, oH, oW, iC] (NDHWC) or [bS, iC, oD, oH, oW] (NCDHW) - - int kD = INT_ARG(0); // filter(kernel) depth - int kH = INT_ARG(1); // filter(kernel) height - int kW = INT_ARG(2); // filter(kernel) width - int sD = INT_ARG(3); // strides depth - int sH = INT_ARG(4); // strides height - int sW = INT_ARG(5); // strides width - int pD = INT_ARG(6); // paddings depth - int pH = INT_ARG(7); // paddings height - int pW = INT_ARG(8); // paddings width - int dD = INT_ARG(9); // dilations depth - int dH = INT_ARG(10); // dilations height - int dW = INT_ARG(11); // dilations width - int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID - int extraParam0 = INT_ARG(13); - int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1; // 0-NCDHW, 1-NDHWC - - REQUIRE_TRUE(input->rankOf() == 5, 0, "AVGPOOL3DNEW CUDNN OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf()); - REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "AVGPOOL3DNEW CUDNN OP: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW); - - int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; - int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, 0, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); - - std::vector expectedOutputShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); - REQUIRE_TRUE(output->isSameShape(expectedOutputShape), 0, "AVGPOOL3DNEW CUDNN OP: wrong shape of output array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedOutputShape).c_str(), ShapeUtils::shapeAsString(output).c_str()); - - if(paddingMode) // SAME - ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW); - - const cudnnPoolingMode_t mode = (extraParam0 == 0) ? CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING : CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING; - - pooling3dCUDNN(block.launchContext(), input, output, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isNCDHW, mode); - - return Status::OK(); + auto input = INPUT_VARIABLE( + 0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) + auto output = OUTPUT_VARIABLE( + 0); // [bS, oD, oH, oW, iC] (NDHWC) or [bS, iC, oD, oH, oW] (NCDHW) + + int kD = INT_ARG(0); // filter(kernel) depth + int kH = INT_ARG(1); // filter(kernel) height + int kW = INT_ARG(2); // filter(kernel) width + int sD = INT_ARG(3); // strides depth + int sH = INT_ARG(4); // strides height + int sW = INT_ARG(5); // strides width + int pD = INT_ARG(6); // paddings depth + int pH = INT_ARG(7); // paddings height + int pW = INT_ARG(8); // paddings width + int dD = INT_ARG(9); // dilations depth + int dH = INT_ARG(10); // dilations height + int dW = INT_ARG(11); // dilations width + int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID + int extraParam0 = INT_ARG(13); + int isNCDHW = block.numI() > 14 ? !INT_ARG(14) + : 1; // 0-NCDHW, 1-NDHWC + + REQUIRE_TRUE(input->rankOf() == 5, 0, + "AVGPOOL3DNEW CUDNN OP: rank of input array must be equal to 5, " + "but got %i instead !", + input->rankOf()); + REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, + "AVGPOOL3DNEW CUDNN OP: dilation must not be zero, but got " + "instead {%i, %i, %i}", + dD, dH, dW); + + int bS, iC, iD, iH, iW, oC, oD, oH, + oW; // batch size, input channels, input depth/height/width, output + // channels, output depth/height/width; + int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv3d( + isNCDHW, 0, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, + indIOioD, indWiC, indWoC, indWkD); + + std::vector expectedOutputShape = + ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, oD, oH, oW, 0, indIOioC, + indIOioD, indIOioD + 1, + indIOioD + 2}); + REQUIRE_TRUE(output->isSameShape(expectedOutputShape), 0, + "AVGPOOL3DNEW CUDNN OP: wrong shape of output array, expected " + "is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedOutputShape).c_str(), + ShapeUtils::shapeAsString(output).c_str()); + + if (paddingMode) // SAME + ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, + kW, sD, sH, sW, dD, dH, dW); + + const cudnnPoolingMode_t mode = + (extraParam0 == 0) ? CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING + : CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING; + + pooling3dCUDNN(block.launchContext(), input, output, kD, kH, kW, sD, sH, sW, + pD, pH, pW, dD, dH, dW, isNCDHW, mode); + + return Status::OK(); } ////////////////////////////////////////////////////////////////////////// PLATFORM_CHECK(avgpool3dnew, ENGINE_CUDA) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - const auto goodType = input->dataType() == DataType::DOUBLE || input->dataType() == DataType::FLOAT32 || input->dataType() == DataType::HALF || input->dataType() == DataType::INT32; + const auto goodType = input->dataType() == DataType::DOUBLE || + input->dataType() == DataType::FLOAT32 || + input->dataType() == DataType::HALF || + input->dataType() == DataType::INT32; - return goodType && input->dataType() == output->dataType(); + return goodType && input->dataType() == output->dataType(); } ////////////////////////////////////////////////////////////////////////// PLATFORM_IMPL(avgpool3dnew_bp, ENGINE_CUDA) { - - auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) - auto gradO = INPUT_VARIABLE(1); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next - auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon - - const int kD = INT_ARG(0); // filter(kernel) depth - const int kH = INT_ARG(1); // filter(kernel) height - const int kW = INT_ARG(2); // filter(kernel) width - const int sD = INT_ARG(3); // strides depth - const int sH = INT_ARG(4); // strides height - const int sW = INT_ARG(5); // strides width - int pD = INT_ARG(6); // paddings depth - int pH = INT_ARG(7); // paddings height - int pW = INT_ARG(8); // paddings width - const int dD = INT_ARG(9); // dilations depth - const int dH = INT_ARG(10); // dilations height - const int dW = INT_ARG(11); // dilations width - const int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID - const int extraParam0 = INT_ARG(13); // define what divisor to use while averaging - const int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1; // 0-NCDHW, 1-NDHWC - - REQUIRE_TRUE(input->rankOf() == 5, 0, "AVGPOOL3DNEW_BP CUDNN OP: input should have rank of 5, but got %i instead", input->rankOf()); - REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "AVGPOOL3DNEW_BP CUDNN OP: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW); - - int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; - int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, 0, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); - - std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); - std::vector expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iD,iH,iW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); - REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "AVGPOOL3DNEW_BP CUDNN: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); - REQUIRE_TRUE(gradI->isSameShape(expectedGradIShape), 0, "AVGPOOL3DNEW_BP CUDNN: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradIShape).c_str(), ShapeUtils::shapeAsString(gradI).c_str()); - - if(isSameMode) // SAME - ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW); - - const cudnnPoolingMode_t mode = (extraParam0 == 0) ? CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING : CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING; - - pooling3dBpCUDNN(block.launchContext(), input, gradO, gradI, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isNCDHW, mode); - - return Status::OK(); + auto input = INPUT_VARIABLE( + 0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) + auto gradO = INPUT_VARIABLE(1); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, + // oD, oH, oW] (NCDHW), epsilon_next + auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, + // iD, iH, iW] (NCDHW), epsilon + + const int kD = INT_ARG(0); // filter(kernel) depth + const int kH = INT_ARG(1); // filter(kernel) height + const int kW = INT_ARG(2); // filter(kernel) width + const int sD = INT_ARG(3); // strides depth + const int sH = INT_ARG(4); // strides height + const int sW = INT_ARG(5); // strides width + int pD = INT_ARG(6); // paddings depth + int pH = INT_ARG(7); // paddings height + int pW = INT_ARG(8); // paddings width + const int dD = INT_ARG(9); // dilations depth + const int dH = INT_ARG(10); // dilations height + const int dW = INT_ARG(11); // dilations width + const int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID + const int extraParam0 = + INT_ARG(13); // define what divisor to use while averaging + const int isNCDHW = block.numI() > 14 + ? !INT_ARG(14) + : 1; // 0-NCDHW, 1-NDHWC + + REQUIRE_TRUE(input->rankOf() == 5, 0, + "AVGPOOL3DNEW_BP CUDNN OP: input should have rank of 5, but got " + "%i instead", + input->rankOf()); + REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, + "AVGPOOL3DNEW_BP CUDNN OP: dilation must not be zero, but got " + "instead {%i, %i, %i}", + dD, dH, dW); + + int bS, iC, iD, iH, iW, oC, oD, oH, + oW; // batch size, input channels, input depth/height/width, output + // channels, output depth/height/width; + int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv3d( + isNCDHW, 0, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, + indIOioD, indWiC, indWoC, indWkD); + + std::vector expectedGradOShape = + ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, oD, oH, oW, 0, indIOioC, + indIOioD, indIOioD + 1, + indIOioD + 2}); + std::vector expectedGradIShape = + ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, iD, iH, iW, 0, indIOioC, + indIOioD, indIOioD + 1, + indIOioD + 2}); + REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, + "AVGPOOL3DNEW_BP CUDNN: wrong shape of output's gradients array " + "(next epsilon), expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedGradOShape).c_str(), + ShapeUtils::shapeAsString(gradO).c_str()); + REQUIRE_TRUE(gradI->isSameShape(expectedGradIShape), 0, + "AVGPOOL3DNEW_BP CUDNN: wrong shape of input's gradients array " + "(epsilon), expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedGradIShape).c_str(), + ShapeUtils::shapeAsString(gradI).c_str()); + + if (isSameMode) // SAME + ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, + kW, sD, sH, sW, dD, dH, dW); + + const cudnnPoolingMode_t mode = + (extraParam0 == 0) ? CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING + : CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING; + + pooling3dBpCUDNN(block.launchContext(), input, gradO, gradI, kD, kH, kW, sD, + sH, sW, pD, pH, pW, dD, dH, dW, isNCDHW, mode); + + return Status::OK(); } PLATFORM_CHECK(avgpool3dnew_bp, ENGINE_CUDA) { - - auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) - auto gradO = INPUT_VARIABLE(1); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next - auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon - - const auto goodType = input->dataType() == DataType::DOUBLE || input->dataType() == DataType::FLOAT32 || input->dataType() == DataType::HALF || input->dataType() == DataType::INT32; - - return goodType && (input->dataType() == gradO->dataType()) - && (input->dataType() == gradI->dataType()) - && shape::haveSameShapeAndStrides(input->shapeInfo(), gradI->shapeInfo()); + auto input = INPUT_VARIABLE( + 0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) + auto gradO = INPUT_VARIABLE(1); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, + // oD, oH, oW] (NCDHW), epsilon_next + auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, + // iD, iH, iW] (NCDHW), epsilon + + const auto goodType = input->dataType() == DataType::DOUBLE || + input->dataType() == DataType::FLOAT32 || + input->dataType() == DataType::HALF || + input->dataType() == DataType::INT32; + + return goodType && (input->dataType() == gradO->dataType()) && + (input->dataType() == gradI->dataType()) && + shape::haveSameShapeAndStrides(input->shapeInfo(), gradI->shapeInfo()); } - -} -} -} +} // namespace platforms +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/platform/cudnn/batchnorm.cu b/libnd4j/include/ops/declarable/platform/cudnn/batchnorm.cu index 7568ba47a330..c011a28af0ec 100644 --- a/libnd4j/include/ops/declarable/platform/cudnn/batchnorm.cu +++ b/libnd4j/include/ops/declarable/platform/cudnn/batchnorm.cu @@ -18,551 +18,716 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // +#include #include "cudnnUtils.h" -#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace platforms { ////////////////////////////////////////////////////////////////////////// -static void batchnormCUDNN(const LaunchContext* context, - const NDArray* input, const NDArray* mean, const NDArray* variance, - const NDArray* gamma, const NDArray* beta, - NDArray* output, - const double epsilon, const bool isSpatialMode) { - - - // input, output -> 4D:nchw, 5D:ncdhw - // mean, variance, gamma, beta -> 1xCx1x1 for 4D and 1xCx1x1x1 for 5D for BATCHNORM_MODE_SPATIAL mode - // -> 1xCxHxW for 4D and 1xCxDxHxW for 5D for BATCHNORM_MODE_PER_ACTIVATION mode - - const cudnnDataType_t dataType = cudnnDataType(input->dataType()); - - const int xRank = input->rankOf(); - - auto handle = reinterpret_cast(context->getCuDnnHandle()); - cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream()); - if (err != 0) throw sd::cuda_exception::build("conv2dCUDNN: can't set stream for cuDNN", err); - - const std::vector xShape = input->getShapeAsVectorInt(); // input and output have same shapes - - std::vector paramsShape, paramsStrides; // mean, variance, gamma and beta have same shapes - if(isSpatialMode) { // 1xCx1x1 - const int iC = mean->lengthOf(); - const int stride0 = mean->strideAt(0); - paramsShape = xRank == 4 ? std::vector({1, iC, 1, 1}) : std::vector({1, iC, 1, 1, 1}); - paramsStrides = xRank == 4 ? std::vector({iC*stride0, stride0, 1, 1}) : std::vector({iC*stride0, stride0, 1, 1, 1}); - } - else { - paramsShape = mean->getShapeAsVectorInt(); - paramsStrides = xRank == 4 ? std::vector({(int)mean->strideAt(0), (int)mean->strideAt(1), (int)mean->strideAt(2), (int)mean->strideAt(3)}) : std::vector({(int)mean->strideAt(0), (int)mean->strideAt(1), (int)mean->strideAt(2), (int)mean->strideAt(3), (int)mean->strideAt(4)}); - } - - std::vector xStrides = {(int)input->strideAt(0), (int)input->strideAt(1), (int)input->strideAt(2), (int)input->strideAt(3)}; - std::vector zStrides = {(int)output->strideAt(0), (int)output->strideAt(1), (int)output->strideAt(2), (int)output->strideAt(3)}; - - if(xRank > 4) { // 5D - xStrides.push_back((int)input->strideAt(4)); - zStrides.push_back((int)output->strideAt(4)); - } - - cudnnTensorFormat_t format = CUDNN_TENSOR_NCHW; - - // input descriptor - cudnnTensorDescriptor_t x; - cudnnCreateTensorDescriptor(&x); - if(input->ews() == 1) - err = cudnnSetTensorNdDescriptorEx(x, format, dataType, xRank, xShape.data()); - else - err = cudnnSetTensorNdDescriptor(x, dataType, xRank, xShape.data(), xStrides.data()); - if (err != 0) throw sd::cuda_exception::build("batchnormCUDNN: cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for input failed", err); - - // output descriptor - cudnnTensorDescriptor_t z; - cudnnCreateTensorDescriptor(&z); - if(output->ews() == 1) - err = cudnnSetTensorNdDescriptorEx(z, format, dataType, xRank, xShape.data()); - else - err = cudnnSetTensorNdDescriptor(z, dataType, xRank, xShape.data(), zStrides.data()); - if (err != 0) throw sd::cuda_exception::build("batchnormCUDNN: cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for output failed", err); - - // mean, variance, gamma and beta descriptor, the same descriptor for all of them - cudnnTensorDescriptor_t params; - cudnnCreateTensorDescriptor(¶ms); - if(mean->ews() == 1) - err = cudnnSetTensorNdDescriptorEx(params, format, dataType, xRank, paramsShape.data()); - else - err = cudnnSetTensorNdDescriptor(params, dataType, xRank, paramsShape.data(), paramsStrides.data()); - if (err != 0) throw sd::cuda_exception::build("batchnormCUDNN: cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for mean/variance/gamma/beta failed", err); - - // provide scaling parameters - const float alpha32(1), beta32(0); - const double alpha64(1), beta64(0); - const void* ptrAlpha = output->sizeOfT() <= 4 ? reinterpret_cast(&alpha32) : reinterpret_cast(&alpha64); - const void* ptrBeta = output->sizeOfT() <= 4 ? reinterpret_cast(&beta32) : reinterpret_cast(&beta64); - - NDArray::prepareSpecialUse({output}, {input, mean, variance, gamma, beta}); - - // calculations - err = cudnnBatchNormalizationForwardInference(*handle, isSpatialMode ? CUDNN_BATCHNORM_SPATIAL : CUDNN_BATCHNORM_PER_ACTIVATION, - ptrAlpha, ptrBeta, - x, input->specialBuffer(), - z, output->specialBuffer(), - params, - gamma->specialBuffer(), beta->specialBuffer(), - mean->specialBuffer(), variance->specialBuffer(), epsilon); - - if (err != 0) throw sd::cuda_exception::build("batchnormCUDNN: cudnnBatchNormalizationForwardInference failed", err); - - auto cudaErr = cudaStreamSynchronize(*context->getCudaStream()); - if (cudaErr != 0) - throw cuda_exception::build("batchnormCUDNN: cudaStreamSynchronize failed !", cudaErr); - - NDArray::registerSpecialUse({output}, {input, mean, variance, gamma, beta}); +static void batchnormCUDNN(const LaunchContext* context, const NDArray* input, + const NDArray* mean, const NDArray* variance, + const NDArray* gamma, const NDArray* beta, + NDArray* output, const double epsilon, + const bool isSpatialMode) { + // input, output -> 4D:nchw, 5D:ncdhw + // mean, variance, gamma, beta -> 1xCx1x1 for 4D and 1xCx1x1x1 for 5D for + // BATCHNORM_MODE_SPATIAL mode + // -> 1xCxHxW for 4D and 1xCxDxHxW for 5D for + // BATCHNORM_MODE_PER_ACTIVATION mode + + const cudnnDataType_t dataType = cudnnDataType(input->dataType()); + + const int xRank = input->rankOf(); + + auto handle = reinterpret_cast(context->getCuDnnHandle()); + cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream()); + if (err != 0) + throw sd::cuda_exception::build("conv2dCUDNN: can't set stream for cuDNN", + err); + + const std::vector xShape = + input->getShapeAsVectorInt(); // input and output have same shapes + + std::vector paramsShape, + paramsStrides; // mean, variance, gamma and beta have same shapes + if (isSpatialMode) { // 1xCx1x1 + const int iC = mean->lengthOf(); + const int stride0 = mean->strideAt(0); + paramsShape = xRank == 4 ? std::vector({1, iC, 1, 1}) + : std::vector({1, iC, 1, 1, 1}); + paramsStrides = xRank == 4 + ? std::vector({iC * stride0, stride0, 1, 1}) + : std::vector({iC * stride0, stride0, 1, 1, 1}); + } else { + paramsShape = mean->getShapeAsVectorInt(); + paramsStrides = + xRank == 4 + ? std::vector({(int)mean->strideAt(0), (int)mean->strideAt(1), + (int)mean->strideAt(2), (int)mean->strideAt(3)}) + : std::vector({(int)mean->strideAt(0), (int)mean->strideAt(1), + (int)mean->strideAt(2), (int)mean->strideAt(3), + (int)mean->strideAt(4)}); + } + + std::vector xStrides = {(int)input->strideAt(0), (int)input->strideAt(1), + (int)input->strideAt(2), + (int)input->strideAt(3)}; + std::vector zStrides = { + (int)output->strideAt(0), (int)output->strideAt(1), + (int)output->strideAt(2), (int)output->strideAt(3)}; + + if (xRank > 4) { // 5D + xStrides.push_back((int)input->strideAt(4)); + zStrides.push_back((int)output->strideAt(4)); + } + + cudnnTensorFormat_t format = CUDNN_TENSOR_NCHW; + + // input descriptor + cudnnTensorDescriptor_t x; + cudnnCreateTensorDescriptor(&x); + if (input->ews() == 1) + err = + cudnnSetTensorNdDescriptorEx(x, format, dataType, xRank, xShape.data()); + else + err = cudnnSetTensorNdDescriptor(x, dataType, xRank, xShape.data(), + xStrides.data()); + if (err != 0) + throw sd::cuda_exception::build( + "batchnormCUDNN: " + "cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for input " + "failed", + err); + + // output descriptor + cudnnTensorDescriptor_t z; + cudnnCreateTensorDescriptor(&z); + if (output->ews() == 1) + err = + cudnnSetTensorNdDescriptorEx(z, format, dataType, xRank, xShape.data()); + else + err = cudnnSetTensorNdDescriptor(z, dataType, xRank, xShape.data(), + zStrides.data()); + if (err != 0) + throw sd::cuda_exception::build( + "batchnormCUDNN: " + "cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for output " + "failed", + err); + + // mean, variance, gamma and beta descriptor, the same descriptor for all of + // them + cudnnTensorDescriptor_t params; + cudnnCreateTensorDescriptor(¶ms); + if (mean->ews() == 1) + err = cudnnSetTensorNdDescriptorEx(params, format, dataType, xRank, + paramsShape.data()); + else + err = cudnnSetTensorNdDescriptor(params, dataType, xRank, + paramsShape.data(), paramsStrides.data()); + if (err != 0) + throw sd::cuda_exception::build( + "batchnormCUDNN: " + "cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for " + "mean/variance/gamma/beta failed", + err); + + // provide scaling parameters + const float alpha32(1), beta32(0); + const double alpha64(1), beta64(0); + const void* ptrAlpha = output->sizeOfT() <= 4 + ? reinterpret_cast(&alpha32) + : reinterpret_cast(&alpha64); + const void* ptrBeta = output->sizeOfT() <= 4 + ? reinterpret_cast(&beta32) + : reinterpret_cast(&beta64); + + NDArray::prepareSpecialUse({output}, {input, mean, variance, gamma, beta}); + + // calculations + err = cudnnBatchNormalizationForwardInference( + *handle, + isSpatialMode ? CUDNN_BATCHNORM_SPATIAL : CUDNN_BATCHNORM_PER_ACTIVATION, + ptrAlpha, ptrBeta, x, input->specialBuffer(), z, output->specialBuffer(), + params, gamma->specialBuffer(), beta->specialBuffer(), + mean->specialBuffer(), variance->specialBuffer(), epsilon); + + if (err != 0) + throw sd::cuda_exception::build( + "batchnormCUDNN: cudnnBatchNormalizationForwardInference failed", err); + + auto cudaErr = cudaStreamSynchronize(*context->getCudaStream()); + if (cudaErr != 0) + throw cuda_exception::build( + "batchnormCUDNN: cudaStreamSynchronize failed !", cudaErr); + + NDArray::registerSpecialUse({output}, {input, mean, variance, gamma, beta}); } ////////////////////////////////////////////////////////////////////////// -static void batchnormBpCUDNN(const LaunchContext* context, - const NDArray* input, const NDArray* mean, const NDArray* variance, const NDArray* gamma, const NDArray* gradO, - NDArray* gradI, NDArray* gradG, NDArray* gradB, - const double epsilon, const bool isSpatialMode) { - - // input, gradO, gradI -> 4D:nchw, 5D:ncdhw - // mean, variance, gamma, beta, gradM, gradV, gradG, gradB -> 1xCx1x1 for 4D and 1xCx1x1x1 for 5D for BATCHNORM_MODE_SPATIAL mode - // -> 1xCxHxW for 4D and 1xCxDxHxW for 5D for BATCHNORM_MODE_PER_ACTIVATION mode - - const cudnnDataType_t dataType = cudnnDataType(input->dataType()); - - const int xRank = input->rankOf(); - - auto handle = reinterpret_cast(context->getCuDnnHandle()); - cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream()); - if (err != 0) throw sd::cuda_exception::build("batchnormBpCUDNN: can't set stream for cuDNN", err); - - const std::vector xShape = input->getShapeAsVectorInt(); // input and output have same shapes - - std::vector paramsShape, paramsStrides; // mean, variance, gamma and beta have same shapes - if(isSpatialMode) { // 1xCx1x1 - const int iC = mean->lengthOf(); - const int stride0 = mean->strideAt(0); - paramsShape = xRank == 4 ? std::vector({1, iC, 1, 1}) : std::vector({1, iC, 1, 1, 1}); - paramsStrides = xRank == 4 ? std::vector({iC*stride0, stride0, 1, 1}) : std::vector({iC*stride0, stride0, 1, 1, 1}); - } - else { - paramsShape = mean->getShapeAsVectorInt(); - paramsStrides = xRank == 4 ? std::vector({(int)mean->strideAt(0), (int)mean->strideAt(1), (int)mean->strideAt(2), (int)mean->strideAt(3)}) : std::vector({(int)mean->strideAt(0), (int)mean->strideAt(1), (int)mean->strideAt(2), (int)mean->strideAt(3), (int)mean->strideAt(4)}); - } - - std::vector xStrides = {(int)input->strideAt(0), (int)input->strideAt(1), (int)input->strideAt(2), (int)input->strideAt(3)}; - std::vector dxStrides = {(int)gradI->strideAt(0), (int)gradI->strideAt(1), (int)gradI->strideAt(2), (int)gradI->strideAt(3)}; - std::vector dzStrides = {(int)gradO->strideAt(0), (int)gradO->strideAt(1), (int)gradO->strideAt(2), (int)gradO->strideAt(3)}; - - if(xRank > 4) { // 5D - xStrides.push_back((int)input->strideAt(4)); - dxStrides.push_back((int)gradI->strideAt(4)); - dzStrides.push_back((int)gradO->strideAt(4)); - } - - cudnnTensorFormat_t format = CUDNN_TENSOR_NCHW; - - // input descriptor - cudnnTensorDescriptor_t x; - cudnnCreateTensorDescriptor(&x); - if(input->ews() == 1) - err = cudnnSetTensorNdDescriptorEx(x, format, dataType, xRank, xShape.data()); - else - err = cudnnSetTensorNdDescriptor(x, dataType, xRank, xShape.data(), xStrides.data()); - if (err != 0) throw sd::cuda_exception::build("batchnormBpCUDNN: cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for input failed", err); - - // gradO descriptor - cudnnTensorDescriptor_t dz; - cudnnCreateTensorDescriptor(&dz); - if(gradO->ews() == 1) - err = cudnnSetTensorNdDescriptorEx(dz, format, dataType, xRank, xShape.data()); - else - err = cudnnSetTensorNdDescriptor(dz, dataType, xRank, xShape.data(), dzStrides.data()); - if (err != 0) throw sd::cuda_exception::build("batchnormBpCUDNN: cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for gradO failed", err); - - // gradI descriptor - cudnnTensorDescriptor_t dx; - cudnnCreateTensorDescriptor(&dx); - if(input->ews() == 1) - err = cudnnSetTensorNdDescriptorEx(dx, format, dataType, xRank, xShape.data()); - else - err = cudnnSetTensorNdDescriptor(dx, dataType, xRank, xShape.data(), dxStrides.data()); - if (err != 0) throw sd::cuda_exception::build("batchnormBpCUDNN: cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for gradI failed", err); - - // mean, variance, gamma, gradG and gradB descriptor, the same descriptor for all of them - cudnnTensorDescriptor_t params; - cudnnCreateTensorDescriptor(¶ms); - if(mean->ews() == 1) - err = cudnnSetTensorNdDescriptorEx(params, format, dataType, xRank, paramsShape.data()); - else - err = cudnnSetTensorNdDescriptor(params, dataType, xRank, paramsShape.data(), paramsStrides.data()); - if (err != 0) throw sd::cuda_exception::build("batchnormBpCUDNN: cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for mean/variance/gamma/gradG/gradB failed", err); - - // provide scaling parameters - const float alpha32(1), beta32(0); - double alpha64(1), beta64(0); - const void* ptrAlpha = input->sizeOfT() <= 4 ? reinterpret_cast(&alpha32) : reinterpret_cast(&alpha64); - const void* ptrBeta = input->sizeOfT() <= 4 ? reinterpret_cast(&beta32) : reinterpret_cast(&beta64); - - NDArray::prepareSpecialUse({gradI, gradG, gradB}, {input, mean, variance, gamma, gradO}); - - // calculations - // TODO: we can use cache here - err = cudnnBatchNormalizationBackward(*handle, isSpatialMode ? CUDNN_BATCHNORM_SPATIAL : CUDNN_BATCHNORM_PER_ACTIVATION, - ptrAlpha, ptrBeta, ptrAlpha, ptrBeta, - x, input->specialBuffer(), - dz, gradO->specialBuffer(), - dx, gradI->specialBuffer(), - params, - gamma->specialBuffer(), gradG->specialBuffer(), gradB->specialBuffer(), - epsilon, - nullptr/*mean->specialBuffer()*/, nullptr/*variance->specialBuffer()*/); - - if (err != 0) throw sd::cuda_exception::build("batchnormBpCUDNN: cudnnBatchNormalizationBackward failed", err); - - auto cudaErr = cudaStreamSynchronize(*context->getCudaStream()); - if (cudaErr != 0) - throw cuda_exception::build("batchnormBpCUDNN: cudaStreamSynchronize failed !", cudaErr); - - NDArray::registerSpecialUse({gradI, gradG, gradB}, {input, mean, variance, gamma, gradO}); +static void batchnormBpCUDNN(const LaunchContext* context, const NDArray* input, + const NDArray* mean, const NDArray* variance, + const NDArray* gamma, const NDArray* gradO, + NDArray* gradI, NDArray* gradG, NDArray* gradB, + const double epsilon, const bool isSpatialMode) { + // input, gradO, gradI -> 4D:nchw, 5D:ncdhw + // mean, variance, gamma, beta, gradM, gradV, gradG, gradB -> 1xCx1x1 for 4D + // and 1xCx1x1x1 for 5D for BATCHNORM_MODE_SPATIAL mode + // -> 1xCxHxW for 4D + // and 1xCxDxHxW for + // 5D for + // BATCHNORM_MODE_PER_ACTIVATION + // mode + + const cudnnDataType_t dataType = cudnnDataType(input->dataType()); + + const int xRank = input->rankOf(); + + auto handle = reinterpret_cast(context->getCuDnnHandle()); + cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream()); + if (err != 0) + throw sd::cuda_exception::build( + "batchnormBpCUDNN: can't set stream for cuDNN", err); + + const std::vector xShape = + input->getShapeAsVectorInt(); // input and output have same shapes + + std::vector paramsShape, + paramsStrides; // mean, variance, gamma and beta have same shapes + if (isSpatialMode) { // 1xCx1x1 + const int iC = mean->lengthOf(); + const int stride0 = mean->strideAt(0); + paramsShape = xRank == 4 ? std::vector({1, iC, 1, 1}) + : std::vector({1, iC, 1, 1, 1}); + paramsStrides = xRank == 4 + ? std::vector({iC * stride0, stride0, 1, 1}) + : std::vector({iC * stride0, stride0, 1, 1, 1}); + } else { + paramsShape = mean->getShapeAsVectorInt(); + paramsStrides = + xRank == 4 + ? std::vector({(int)mean->strideAt(0), (int)mean->strideAt(1), + (int)mean->strideAt(2), (int)mean->strideAt(3)}) + : std::vector({(int)mean->strideAt(0), (int)mean->strideAt(1), + (int)mean->strideAt(2), (int)mean->strideAt(3), + (int)mean->strideAt(4)}); + } + + std::vector xStrides = {(int)input->strideAt(0), (int)input->strideAt(1), + (int)input->strideAt(2), + (int)input->strideAt(3)}; + std::vector dxStrides = { + (int)gradI->strideAt(0), (int)gradI->strideAt(1), (int)gradI->strideAt(2), + (int)gradI->strideAt(3)}; + std::vector dzStrides = { + (int)gradO->strideAt(0), (int)gradO->strideAt(1), (int)gradO->strideAt(2), + (int)gradO->strideAt(3)}; + + if (xRank > 4) { // 5D + xStrides.push_back((int)input->strideAt(4)); + dxStrides.push_back((int)gradI->strideAt(4)); + dzStrides.push_back((int)gradO->strideAt(4)); + } + + cudnnTensorFormat_t format = CUDNN_TENSOR_NCHW; + + // input descriptor + cudnnTensorDescriptor_t x; + cudnnCreateTensorDescriptor(&x); + if (input->ews() == 1) + err = + cudnnSetTensorNdDescriptorEx(x, format, dataType, xRank, xShape.data()); + else + err = cudnnSetTensorNdDescriptor(x, dataType, xRank, xShape.data(), + xStrides.data()); + if (err != 0) + throw sd::cuda_exception::build( + "batchnormBpCUDNN: " + "cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for input " + "failed", + err); + + // gradO descriptor + cudnnTensorDescriptor_t dz; + cudnnCreateTensorDescriptor(&dz); + if (gradO->ews() == 1) + err = cudnnSetTensorNdDescriptorEx(dz, format, dataType, xRank, + xShape.data()); + else + err = cudnnSetTensorNdDescriptor(dz, dataType, xRank, xShape.data(), + dzStrides.data()); + if (err != 0) + throw sd::cuda_exception::build( + "batchnormBpCUDNN: " + "cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for gradO " + "failed", + err); + + // gradI descriptor + cudnnTensorDescriptor_t dx; + cudnnCreateTensorDescriptor(&dx); + if (input->ews() == 1) + err = cudnnSetTensorNdDescriptorEx(dx, format, dataType, xRank, + xShape.data()); + else + err = cudnnSetTensorNdDescriptor(dx, dataType, xRank, xShape.data(), + dxStrides.data()); + if (err != 0) + throw sd::cuda_exception::build( + "batchnormBpCUDNN: " + "cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for gradI " + "failed", + err); + + // mean, variance, gamma, gradG and gradB descriptor, the same descriptor for + // all of them + cudnnTensorDescriptor_t params; + cudnnCreateTensorDescriptor(¶ms); + if (mean->ews() == 1) + err = cudnnSetTensorNdDescriptorEx(params, format, dataType, xRank, + paramsShape.data()); + else + err = cudnnSetTensorNdDescriptor(params, dataType, xRank, + paramsShape.data(), paramsStrides.data()); + if (err != 0) + throw sd::cuda_exception::build( + "batchnormBpCUDNN: " + "cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for " + "mean/variance/gamma/gradG/gradB failed", + err); + + // provide scaling parameters + const float alpha32(1), beta32(0); + double alpha64(1), beta64(0); + const void* ptrAlpha = input->sizeOfT() <= 4 + ? reinterpret_cast(&alpha32) + : reinterpret_cast(&alpha64); + const void* ptrBeta = input->sizeOfT() <= 4 + ? reinterpret_cast(&beta32) + : reinterpret_cast(&beta64); + + NDArray::prepareSpecialUse({gradI, gradG, gradB}, + {input, mean, variance, gamma, gradO}); + + // calculations + // TODO: we can use cache here + err = cudnnBatchNormalizationBackward( + *handle, + isSpatialMode ? CUDNN_BATCHNORM_SPATIAL : CUDNN_BATCHNORM_PER_ACTIVATION, + ptrAlpha, ptrBeta, ptrAlpha, ptrBeta, x, input->specialBuffer(), dz, + gradO->specialBuffer(), dx, gradI->specialBuffer(), params, + gamma->specialBuffer(), gradG->specialBuffer(), gradB->specialBuffer(), + epsilon, nullptr /*mean->specialBuffer()*/, + nullptr /*variance->specialBuffer()*/); + + if (err != 0) + throw sd::cuda_exception::build( + "batchnormBpCUDNN: cudnnBatchNormalizationBackward failed", err); + + auto cudaErr = cudaStreamSynchronize(*context->getCudaStream()); + if (cudaErr != 0) + throw cuda_exception::build( + "batchnormBpCUDNN: cudaStreamSynchronize failed !", cudaErr); + + NDArray::registerSpecialUse({gradI, gradG, gradB}, + {input, mean, variance, gamma, gradO}); } - ////////////////////////////////////////////////////////////////////////// PLATFORM_IMPL(batchnorm, ENGINE_CUDA) { - - auto input = INPUT_VARIABLE(0); - auto mean = INPUT_VARIABLE(1); - auto variance = INPUT_VARIABLE(2); - NDArray* gamma = nullptr; - NDArray* beta = nullptr; - - auto output = OUTPUT_VARIABLE(0); - - const bool applyScale = (bool)INT_ARG(0); - const bool applyOffset = (bool)INT_ARG(1); - const double epsilon = T_ARG(0); - - if(applyScale) - gamma = INPUT_VARIABLE(3); - if(applyOffset) - beta = INPUT_VARIABLE(3 + (int)applyScale); - - const int numOfIntArgs = block.getIArguments()->size(); - const int inRank = input->rankOf(); - - // get axes args to normalize input array over - std::vector axes; - if(numOfIntArgs > 2) - for(int i = 2; i < numOfIntArgs; ++i) - axes.push_back(INT_ARG(i)); - else - axes.push_back(inRank-1); // default dimension to reduce along is last dimension - - const int numOfAxes = axes.size(); - REQUIRE_TRUE(numOfAxes <= inRank, 0, "BATCHNORM CUDNN op: too big number of input axes to normalize over, expected number should be less or equal to rank of input array, but got %i and %i correspondingly !", numOfAxes, inRank); - - // evaluate expected shape for mean, variance and gamma. These 3 arrays should have identical shapes - // for example if input shape is {2,3,4,5,6} and axes = {1,3}, then expected shape would be {1,3,1,5,1}, and if axes = {3}, then expected shape would be {5} - std::vector expShape; - if(numOfAxes == 1) - expShape.push_back(input->sizeAt(axes[0])); - else { // get, for example, something like {1, inputDim1, 1, inputDim3, 1} if axes = {1, 3} - expShape = std::vector(inRank, 1); - for(uint i = 0; i < numOfAxes; ++i) - expShape[axes[i]] = input->sizeAt(axes[i]); - } - - REQUIRE_TRUE(mean->isSameShape(expShape) , 0, "BATCHNORM CUDNN op: wrong shape of mean array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(mean).c_str()); - REQUIRE_TRUE(variance->isSameShape(expShape), 0, "BATCHNORM CUDNN op: wrong shape of variance array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(variance).c_str()); - if(gamma) - REQUIRE_TRUE(gamma->isSameShape(expShape), 0, "BATCHNORM CUDNN op: wrong shape of gamma array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(gamma).c_str()); - if(beta) - REQUIRE_TRUE(beta->isSameShape(expShape), 0, "BATCHNORM CUDNN op: wrong shape of beta array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(beta).c_str()); - - // types of all input arrays should be the same - for(int i = 1; i < block.width(); ++i) - REQUIRE_TRUE(INPUT_VARIABLE(0)->dataType() == INPUT_VARIABLE(i)->dataType(), 0, "BATCHNORM CUDNN op: types of all input arrays should be the same !"); - - // cudnn supports NCHW format only - const bool needPermut = axes.size() == 1 && mean->lengthOf() == input->sizeAt(-1); - - if(needPermut) { // if NHWC - std::vector perm = inRank == 4 ? std::vector({0, 3, 1, 2}) : std::vector({0, 4, 1, 2, 3}); // NHWC -> NCHW - input = new NDArray(input->permute(perm)); - output = new NDArray(output->permute(perm)); - } - - // cudnn requires gamma and beta to be non-nullptr - if(!applyScale) { - gamma = new NDArray(mean); - *gamma = 1; - } - if(!applyOffset) { - beta = new NDArray(mean); - *beta = 0; - } - - // calculations - batchnormCUDNN(block.launchContext(), input, mean, variance, gamma, beta, output, epsilon, axes.size() == 1); - - if(needPermut) { - delete input; - delete output; - } - - if(!applyScale) - delete gamma; - - if(!applyOffset) - delete beta; - - return Status::OK(); + auto input = INPUT_VARIABLE(0); + auto mean = INPUT_VARIABLE(1); + auto variance = INPUT_VARIABLE(2); + NDArray* gamma = nullptr; + NDArray* beta = nullptr; + + auto output = OUTPUT_VARIABLE(0); + + const bool applyScale = (bool)INT_ARG(0); + const bool applyOffset = (bool)INT_ARG(1); + const double epsilon = T_ARG(0); + + if (applyScale) gamma = INPUT_VARIABLE(3); + if (applyOffset) beta = INPUT_VARIABLE(3 + (int)applyScale); + + const int numOfIntArgs = block.numI(); + const int inRank = input->rankOf(); + + // get axes args to normalize input array over + std::vector axes; + if (numOfIntArgs > 2) + for (int i = 2; i < numOfIntArgs; ++i) axes.push_back(INT_ARG(i)); + else + axes.push_back(inRank - + 1); // default dimension to reduce along is last dimension + + const int numOfAxes = axes.size(); + REQUIRE_TRUE(numOfAxes <= inRank, 0, + "BATCHNORM CUDNN op: too big number of input axes to normalize " + "over, expected number should be less or equal to rank of input " + "array, but got %i and %i correspondingly !", + numOfAxes, inRank); + + // evaluate expected shape for mean, variance and gamma. These 3 arrays should + // have identical shapes for example if input shape is {2,3,4,5,6} and axes = + // {1,3}, then expected shape would be {1,3,1,5,1}, and if axes = {3}, then + // expected shape would be {5} + std::vector expShape; + if (numOfAxes == 1) + expShape.push_back(input->sizeAt(axes[0])); + else { // get, for example, something like {1, inputDim1, 1, inputDim3, 1} if + // axes = {1, 3} + expShape = std::vector(inRank, 1); + for (uint i = 0; i < numOfAxes; ++i) + expShape[axes[i]] = input->sizeAt(axes[i]); + } + + REQUIRE_TRUE(mean->isSameShape(expShape), 0, + "BATCHNORM CUDNN op: wrong shape of mean array, expected is %s, " + "but got %s instead !", + ShapeUtils::shapeAsString(expShape).c_str(), + ShapeUtils::shapeAsString(mean).c_str()); + REQUIRE_TRUE(variance->isSameShape(expShape), 0, + "BATCHNORM CUDNN op: wrong shape of variance array, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString(expShape).c_str(), + ShapeUtils::shapeAsString(variance).c_str()); + if (gamma) + REQUIRE_TRUE(gamma->isSameShape(expShape), 0, + "BATCHNORM CUDNN op: wrong shape of gamma array, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString(expShape).c_str(), + ShapeUtils::shapeAsString(gamma).c_str()); + if (beta) + REQUIRE_TRUE(beta->isSameShape(expShape), 0, + "BATCHNORM CUDNN op: wrong shape of beta array, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString(expShape).c_str(), + ShapeUtils::shapeAsString(beta).c_str()); + + // types of all input arrays should be the same + for (int i = 1; i < block.width(); ++i) + REQUIRE_TRUE( + INPUT_VARIABLE(0)->dataType() == INPUT_VARIABLE(i)->dataType(), 0, + "BATCHNORM CUDNN op: types of all input arrays should be the same !"); + + // cudnn supports NCHW format only + const bool needPermut = + axes.size() == 1 && mean->lengthOf() == input->sizeAt(-1); + + if (needPermut) { // if NHWC + std::vector perm = + inRank == 4 ? std::vector({0, 3, 1, 2}) + : std::vector({0, 4, 1, 2, 3}); // NHWC -> NCHW + input = new NDArray(input->permute(perm)); + output = new NDArray(output->permute(perm)); + } + + // cudnn requires gamma and beta to be non-nullptr + if (!applyScale) { + gamma = new NDArray(mean); + *gamma = 1; + } + if (!applyOffset) { + beta = new NDArray(mean); + *beta = 0; + } + + // calculations + batchnormCUDNN(block.launchContext(), input, mean, variance, gamma, beta, + output, epsilon, axes.size() == 1); + + if (needPermut) { + delete input; + delete output; + } + + if (!applyScale) delete gamma; + + if (!applyOffset) delete beta; + + return Status::OK(); } ////////////////////////////////////////////////////////////////////////// PLATFORM_CHECK(batchnorm, ENGINE_CUDA) { - - const bool applyScale = (bool)INT_ARG(0); - const bool applyOffset = (bool)INT_ARG(1); - - NDArray* input = INPUT_VARIABLE(0); - NDArray* mean = INPUT_VARIABLE(1); - NDArray* variance = INPUT_VARIABLE(2); - NDArray* gamma = applyScale ? INPUT_VARIABLE(3) : nullptr; - NDArray* beta = applyOffset ? INPUT_VARIABLE(3 + (int)applyScale) : nullptr; - - const int numOfIntArgs = block.getIArguments()->size(); - const int xRank = input->rankOf(); - - // *********************************** // - if(xRank != 4 && xRank != 5) - return false; - - // *********************************** // - const bool badType = input->dataType() != DataType::DOUBLE && input->dataType() != DataType::FLOAT32 && input->dataType() != DataType::HALF; - if(badType) - return false; - - // *********************************** // - // get axes args to normalize input array over - std::vector axes; - if(numOfIntArgs > 2) - for(int i = 2; i < numOfIntArgs; ++i) - axes.push_back(INT_ARG(i)); - else - axes.push_back(xRank-1); // default dimension to reduce along is last dimension - - if(axes.size() != 1 && axes.size() != 3 && axes.size() != 4) - return false; - - // *********************************** // - bool allParamsHaveSameShapeAndStrides = shape::haveSameShapeAndStrides(mean->shapeInfo(), variance->shapeInfo()); - if(gamma) - allParamsHaveSameShapeAndStrides &= shape::haveSameShapeAndStrides(mean->shapeInfo(), gamma->shapeInfo()); - if(beta) - allParamsHaveSameShapeAndStrides &= shape::haveSameShapeAndStrides(mean->shapeInfo(), beta->shapeInfo()); - - if(!allParamsHaveSameShapeAndStrides) - return false; - - // *********************************** // - bool isFormatGood = false; - if(axes.size() == 1) - isFormatGood = mean->lengthOf() == input->sizeAt(1) || mean->lengthOf() == input->sizeAt(-1); // mean [C] - else { - auto inputShapeModif = input->getShapeAsVector(); // [dim0,dim1,dim2,dim3] 4D or [dim0,dim1,dim2,dim3,dim4] - inputShapeModif[0] = 1; - isFormatGood = mean->isSameShape(inputShapeModif); // mean [1,dim1,dim2,dim3] 4D or [1,dim1,dim2,dim3,dim4] - } - if(!isFormatGood) - return false; - - return true; + const bool applyScale = (bool)INT_ARG(0); + const bool applyOffset = (bool)INT_ARG(1); + + NDArray* input = INPUT_VARIABLE(0); + NDArray* mean = INPUT_VARIABLE(1); + NDArray* variance = INPUT_VARIABLE(2); + NDArray* gamma = applyScale ? INPUT_VARIABLE(3) : nullptr; + NDArray* beta = applyOffset ? INPUT_VARIABLE(3 + (int)applyScale) : nullptr; + + const int numOfIntArgs = block.numI(); + const int xRank = input->rankOf(); + + // *********************************** // + if (xRank != 4 && xRank != 5) return false; + + // *********************************** // + const bool badType = input->dataType() != DataType::DOUBLE && + input->dataType() != DataType::FLOAT32 && + input->dataType() != DataType::HALF; + if (badType) return false; + + // *********************************** // + // get axes args to normalize input array over + std::vector axes; + if (numOfIntArgs > 2) + for (int i = 2; i < numOfIntArgs; ++i) axes.push_back(INT_ARG(i)); + else + axes.push_back(xRank - + 1); // default dimension to reduce along is last dimension + + if (axes.size() != 1 && axes.size() != 3 && axes.size() != 4) return false; + + // *********************************** // + bool allParamsHaveSameShapeAndStrides = + shape::haveSameShapeAndStrides(mean->shapeInfo(), variance->shapeInfo()); + if (gamma) + allParamsHaveSameShapeAndStrides &= + shape::haveSameShapeAndStrides(mean->shapeInfo(), gamma->shapeInfo()); + if (beta) + allParamsHaveSameShapeAndStrides &= + shape::haveSameShapeAndStrides(mean->shapeInfo(), beta->shapeInfo()); + + if (!allParamsHaveSameShapeAndStrides) return false; + + // *********************************** // + bool isFormatGood = false; + if (axes.size() == 1) + isFormatGood = mean->lengthOf() == input->sizeAt(1) || + mean->lengthOf() == input->sizeAt(-1); // mean [C] + else { + auto inputShapeModif = + input->getShapeAsVector(); // [dim0,dim1,dim2,dim3] 4D or + // [dim0,dim1,dim2,dim3,dim4] + inputShapeModif[0] = 1; + isFormatGood = + mean->isSameShape(inputShapeModif); // mean [1,dim1,dim2,dim3] 4D or + // [1,dim1,dim2,dim3,dim4] + } + if (!isFormatGood) return false; + + return true; } ////////////////////////////////////////////////////////////////////////// PLATFORM_IMPL(batchnorm_bp, ENGINE_CUDA) { - - NDArray* input = INPUT_VARIABLE(0); - NDArray* mean = INPUT_VARIABLE(1); - NDArray* variance = INPUT_VARIABLE(2); - NDArray* gamma = nullptr; - NDArray* beta = nullptr; - NDArray* gradO = INPUT_VARIABLE(block.width() - 1); // next epsilon - - NDArray* gradI = OUTPUT_VARIABLE(0); - NDArray* gradM = OUTPUT_VARIABLE(1); - NDArray* gradV = OUTPUT_VARIABLE(2); - NDArray* gradG = nullptr; - NDArray* gradB = nullptr; - - const bool applyScale = (bool)INT_ARG(0); - const bool applyOffset = (bool)INT_ARG(1); - const float epsilon = T_ARG(0); - - if(applyScale) { - gamma = INPUT_VARIABLE(3); - gradG = OUTPUT_VARIABLE(3); - } - if(applyOffset) { - beta = INPUT_VARIABLE(3 + (int)applyScale); - gradB = OUTPUT_VARIABLE(3 + (int)applyScale); - } - - const int numOfIntArgs = block.getIArguments()->size(); - const int inRank = input->rankOf(); - - // get axes args to normalize input array over - std::vector axes; - if(numOfIntArgs > 2) - for(int i = 2; i < numOfIntArgs; ++i) - axes.push_back(INT_ARG(i)); - else - axes.push_back(inRank-1); // default dimension to reduce along is last dimension - - const int numOfAxes = axes.size(); - REQUIRE_TRUE(numOfAxes <= inRank, 0, "BATCHNORM_BP CUDNN op: too big number of input axes to normalize over, expected number should be less or equal to rank of input array, but got %i and %i correspondingly !", numOfAxes, inRank); - - // evaluate expected shape for mean, variance and gamma. These 3 arrays should have identical shapes - // for example if input shape is {2,3,4,5,6} and axes = {1,3}, then expected shape would be {1,3,1,5,1}, and if axes = {3}, then expected shape would be {5} - std::vector expShape; - if(numOfAxes == 1) - expShape.push_back(input->sizeAt(axes[0])); - else { // get, for example, something like {1, inputDim1, 1, inputDim3, 1} if axes = {1, 3} - expShape = std::vector(inRank, 1); - for(uint i = 0; i < numOfAxes; ++i) - expShape[axes[i]] = input->sizeAt(axes[i]); - } - - REQUIRE_TRUE(mean->isSameShape(expShape), 0, "BATCHNORM_BP CUDNN op: wrong shape of mean array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(mean).c_str()); - REQUIRE_TRUE(variance->isSameShape(expShape), 0, "BATCHNORM_BP CUDNN op: wrong shape of variance array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(variance).c_str()); - if(gamma) - REQUIRE_TRUE(gamma->isSameShape(expShape), 0, "BATCHNORM_BP CUDNN op: wrong shape of gamma array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(gamma).c_str()); - if(beta) - REQUIRE_TRUE(beta->isSameShape(expShape), 0, "BATCHNORM_BP CUDNN op: wrong shape of beta array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(beta).c_str()); - - REQUIRE_TRUE(input->isSameShape(gradO), 0, "BATCHNORM_BP CUDNN op: wrong shape of output gradients array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(input).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); - - // types of all input arrays should be the same (except gradO) - for(int i = 1; i < block.width() - 2; ++i) - REQUIRE_TRUE(INPUT_VARIABLE(0)->dataType() == INPUT_VARIABLE(i)->dataType(), 0, "BATCHNORM_BP CUDNN op: types of arrays (input, mean, variance, gamma, beta) should be the same !"); - - // cudnn supports NCHW format only - const bool needPermut = axes.size() == 1 && mean->lengthOf() != input->sizeAt(1); - - if(needPermut) { // if NHWC - std::vector perm = inRank == 4 ? std::vector({0, 3, 1, 2}) : std::vector({0, 4, 1, 2, 3}); // NHWC -> NCHW - input = new NDArray(input->permute(perm)); - gradO = new NDArray(gradO->permute(perm)); - gradI = new NDArray(gradI->permute(perm)); - } - - // cudnn requires gamma, gradG, gradB to be non-nullptr - if(!applyScale) { - gamma = new NDArray(mean); - gradG = new NDArray(mean); - *gamma = 1; - } - if(!applyOffset) - gradB = new NDArray(mean); - - // calculations - batchnormBpCUDNN(block.launchContext(), input, mean, variance, gamma, gradO, gradI, gradG, gradB, epsilon, axes.size() == 1); - - *gradM = 0; // put zeros so far - *gradV = 0; // put zeros so far - - if(needPermut) { - delete input; - delete gradO; - delete gradI; - } - - if(!applyScale) { - delete gamma; - delete gradG; - } - - if(!applyOffset) - delete gradB; - - return Status::OK(); - + NDArray* input = INPUT_VARIABLE(0); + NDArray* mean = INPUT_VARIABLE(1); + NDArray* variance = INPUT_VARIABLE(2); + NDArray* gamma = nullptr; + NDArray* beta = nullptr; + NDArray* gradO = INPUT_VARIABLE(block.width() - 1); // next epsilon + + NDArray* gradI = OUTPUT_VARIABLE(0); + NDArray* gradM = OUTPUT_VARIABLE(1); + NDArray* gradV = OUTPUT_VARIABLE(2); + NDArray* gradG = nullptr; + NDArray* gradB = nullptr; + + const bool applyScale = (bool)INT_ARG(0); + const bool applyOffset = (bool)INT_ARG(1); + const float epsilon = T_ARG(0); + + if (applyScale) { + gamma = INPUT_VARIABLE(3); + gradG = OUTPUT_VARIABLE(3); + } + if (applyOffset) { + beta = INPUT_VARIABLE(3 + (int)applyScale); + gradB = OUTPUT_VARIABLE(3 + (int)applyScale); + } + + const int numOfIntArgs = block.numI(); + const int inRank = input->rankOf(); + + // get axes args to normalize input array over + std::vector axes; + if (numOfIntArgs > 2) + for (int i = 2; i < numOfIntArgs; ++i) axes.push_back(INT_ARG(i)); + else + axes.push_back(inRank - + 1); // default dimension to reduce along is last dimension + + const int numOfAxes = axes.size(); + REQUIRE_TRUE(numOfAxes <= inRank, 0, + "BATCHNORM_BP CUDNN op: too big number of input axes to " + "normalize over, expected number should be less or equal to " + "rank of input array, but got %i and %i correspondingly !", + numOfAxes, inRank); + + // evaluate expected shape for mean, variance and gamma. These 3 arrays should + // have identical shapes for example if input shape is {2,3,4,5,6} and axes = + // {1,3}, then expected shape would be {1,3,1,5,1}, and if axes = {3}, then + // expected shape would be {5} + std::vector expShape; + if (numOfAxes == 1) + expShape.push_back(input->sizeAt(axes[0])); + else { // get, for example, something like {1, inputDim1, 1, inputDim3, 1} if + // axes = {1, 3} + expShape = std::vector(inRank, 1); + for (uint i = 0; i < numOfAxes; ++i) + expShape[axes[i]] = input->sizeAt(axes[i]); + } + + REQUIRE_TRUE(mean->isSameShape(expShape), 0, + "BATCHNORM_BP CUDNN op: wrong shape of mean array, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString(expShape).c_str(), + ShapeUtils::shapeAsString(mean).c_str()); + REQUIRE_TRUE(variance->isSameShape(expShape), 0, + "BATCHNORM_BP CUDNN op: wrong shape of variance array, expected " + "is %s, but got %s instead !", + ShapeUtils::shapeAsString(expShape).c_str(), + ShapeUtils::shapeAsString(variance).c_str()); + if (gamma) + REQUIRE_TRUE(gamma->isSameShape(expShape), 0, + "BATCHNORM_BP CUDNN op: wrong shape of gamma array, expected " + "is %s, but got %s instead !", + ShapeUtils::shapeAsString(expShape).c_str(), + ShapeUtils::shapeAsString(gamma).c_str()); + if (beta) + REQUIRE_TRUE(beta->isSameShape(expShape), 0, + "BATCHNORM_BP CUDNN op: wrong shape of beta array, expected " + "is %s, but got %s instead !", + ShapeUtils::shapeAsString(expShape).c_str(), + ShapeUtils::shapeAsString(beta).c_str()); + + REQUIRE_TRUE(input->isSameShape(gradO), 0, + "BATCHNORM_BP CUDNN op: wrong shape of output gradients array, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(input).c_str(), + ShapeUtils::shapeAsString(gradO).c_str()); + + // types of all input arrays should be the same (except gradO) + for (int i = 1; i < block.width() - 2; ++i) + REQUIRE_TRUE(INPUT_VARIABLE(0)->dataType() == INPUT_VARIABLE(i)->dataType(), + 0, + "BATCHNORM_BP CUDNN op: types of arrays (input, mean, " + "variance, gamma, beta) should be the same !"); + + // cudnn supports NCHW format only + const bool needPermut = + axes.size() == 1 && mean->lengthOf() != input->sizeAt(1); + + if (needPermut) { // if NHWC + std::vector perm = + inRank == 4 ? std::vector({0, 3, 1, 2}) + : std::vector({0, 4, 1, 2, 3}); // NHWC -> NCHW + input = new NDArray(input->permute(perm)); + gradO = new NDArray(gradO->permute(perm)); + gradI = new NDArray(gradI->permute(perm)); + } + + // cudnn requires gamma, gradG, gradB to be non-nullptr + if (!applyScale) { + gamma = new NDArray(mean); + gradG = new NDArray(mean); + *gamma = 1; + } + if (!applyOffset) gradB = new NDArray(mean); + + // calculations + batchnormBpCUDNN(block.launchContext(), input, mean, variance, gamma, gradO, + gradI, gradG, gradB, epsilon, axes.size() == 1); + + *gradM = 0; // put zeros so far + *gradV = 0; // put zeros so far + + if (needPermut) { + delete input; + delete gradO; + delete gradI; + } + + if (!applyScale) { + delete gamma; + delete gradG; + } + + if (!applyOffset) delete gradB; + + return Status::OK(); } PLATFORM_CHECK(batchnorm_bp, ENGINE_CUDA) { - - NDArray* input = INPUT_VARIABLE(0); - NDArray* mean = INPUT_VARIABLE(1); - NDArray* variance = INPUT_VARIABLE(2); - NDArray* gamma = nullptr; - NDArray* beta = nullptr; - NDArray* gradO = INPUT_VARIABLE(block.width() - 1); // next epsilon - - NDArray* gradI = OUTPUT_VARIABLE(0); - NDArray* gradM = OUTPUT_VARIABLE(1); - NDArray* gradV = OUTPUT_VARIABLE(2); - NDArray* gradG = nullptr; - NDArray* gradB = nullptr; - - const int numOfIntArgs = block.getIArguments()->size(); - const int xRank = input->rankOf(); - - // *********************************** // - if(xRank != 4 && xRank != 5) - return false; - - // *********************************** // - const bool badType = input->dataType() != DataType::DOUBLE && input->dataType() != DataType::FLOAT32 && input->dataType() != DataType::HALF; - if(badType) - return false; - - // *********************************** // - // get axes args to normalize input array over - std::vector axes; - if(numOfIntArgs > 2) - for(int i = 2; i < numOfIntArgs; ++i) - axes.push_back(INT_ARG(i)); - else - axes.push_back(xRank-1); // default dimension to reduce along is last dimension - - if(axes.size() != 1 && axes.size() != 3 && axes.size() != 4) - return false; - - // *********************************** // - bool allParamsHaveSameShapeAndStrides = shape::haveSameShapeAndStrides(mean->shapeInfo(), variance->shapeInfo()); - if(gamma) - allParamsHaveSameShapeAndStrides &= shape::haveSameShapeAndStrides(mean->shapeInfo(), gamma->shapeInfo()); - if(gradG) - allParamsHaveSameShapeAndStrides &= shape::haveSameShapeAndStrides(mean->shapeInfo(), gradG->shapeInfo()); - if(gradB) - allParamsHaveSameShapeAndStrides &= shape::haveSameShapeAndStrides(mean->shapeInfo(), gradB->shapeInfo()); - - if(!allParamsHaveSameShapeAndStrides) - return false; - - // *********************************** // - bool isFormatGood = false; - if(axes.size() == 1) - isFormatGood = mean->lengthOf() == input->sizeAt(1) || mean->lengthOf() == input->sizeAt(-1); // mean [C] - else { - auto inputShapeModif = input->getShapeAsVector(); // [dim0,dim1,dim2,dim3] 4D or [dim0,dim1,dim2,dim3,dim4] - inputShapeModif[0] = 1; - isFormatGood = mean->isSameShape(inputShapeModif); // mean [1,dim1,dim2,dim3] 4D or [1,dim1,dim2,dim3,dim4] - } - if(!isFormatGood) - return false; - - return true; + NDArray* input = INPUT_VARIABLE(0); + NDArray* mean = INPUT_VARIABLE(1); + NDArray* variance = INPUT_VARIABLE(2); + NDArray* gamma = nullptr; + NDArray* beta = nullptr; + NDArray* gradO = INPUT_VARIABLE(block.width() - 1); // next epsilon + + NDArray* gradI = OUTPUT_VARIABLE(0); + NDArray* gradM = OUTPUT_VARIABLE(1); + NDArray* gradV = OUTPUT_VARIABLE(2); + NDArray* gradG = nullptr; + NDArray* gradB = nullptr; + + const int numOfIntArgs = block.numI(); + const int xRank = input->rankOf(); + + // *********************************** // + if (xRank != 4 && xRank != 5) return false; + + // *********************************** // + const bool badType = input->dataType() != DataType::DOUBLE && + input->dataType() != DataType::FLOAT32 && + input->dataType() != DataType::HALF; + if (badType) return false; + + // *********************************** // + // get axes args to normalize input array over + std::vector axes; + if (numOfIntArgs > 2) + for (int i = 2; i < numOfIntArgs; ++i) axes.push_back(INT_ARG(i)); + else + axes.push_back(xRank - + 1); // default dimension to reduce along is last dimension + + if (axes.size() != 1 && axes.size() != 3 && axes.size() != 4) return false; + + // *********************************** // + bool allParamsHaveSameShapeAndStrides = + shape::haveSameShapeAndStrides(mean->shapeInfo(), variance->shapeInfo()); + if (gamma) + allParamsHaveSameShapeAndStrides &= + shape::haveSameShapeAndStrides(mean->shapeInfo(), gamma->shapeInfo()); + if (gradG) + allParamsHaveSameShapeAndStrides &= + shape::haveSameShapeAndStrides(mean->shapeInfo(), gradG->shapeInfo()); + if (gradB) + allParamsHaveSameShapeAndStrides &= + shape::haveSameShapeAndStrides(mean->shapeInfo(), gradB->shapeInfo()); + + if (!allParamsHaveSameShapeAndStrides) return false; + + // *********************************** // + bool isFormatGood = false; + if (axes.size() == 1) + isFormatGood = mean->lengthOf() == input->sizeAt(1) || + mean->lengthOf() == input->sizeAt(-1); // mean [C] + else { + auto inputShapeModif = + input->getShapeAsVector(); // [dim0,dim1,dim2,dim3] 4D or + // [dim0,dim1,dim2,dim3,dim4] + inputShapeModif[0] = 1; + isFormatGood = + mean->isSameShape(inputShapeModif); // mean [1,dim1,dim2,dim3] 4D or + // [1,dim1,dim2,dim3,dim4] + } + if (!isFormatGood) return false; + + return true; } - -} -} -} +} // namespace platforms +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/platform/cudnn/conv2d.cu b/libnd4j/include/ops/declarable/platform/cudnn/conv2d.cu index a77faf6f73a4..7bad4e2b0ca8 100644 --- a/libnd4j/include/ops/declarable/platform/cudnn/conv2d.cu +++ b/libnd4j/include/ops/declarable/platform/cudnn/conv2d.cu @@ -19,520 +19,823 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // +#include #include "cudnnUtils.h" -#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace platforms { ////////////////////////////////////////////////////////////////////////// -static void conv2dCUDNN(const LaunchContext* context, - const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, - const int kH, const int kW, - const int sH, const int sW, - const int pH, const int pW, - const int dH, const int dW, - const int paddingMode, const bool isNCHW, const int wFormat) { - - // cudnn support only two formats for weights {oC,iC,kH,kW} and {oC,kH,kW,iC} - - int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; - int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); - - auto handle = reinterpret_cast(context->getCuDnnHandle()); - cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream()); - if (err != 0) throw sd::cuda_exception::build("conv2dCUDNN: can't set stream for cuDNN", err); - - cudnnTensorFormat_t format = isNCHW ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC; - cudnnTensorFormat_t formatW = 0 == wFormat ? format : (1 == wFormat ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC); - - // input descriptor - cudnnTensorDescriptor_t x; - cudnnCreateTensorDescriptor(&x); - if(input->ews() == 1 && input->ordering() == 'c') - err = cudnnSetTensor4dDescriptor(x, format, cudnnDataType(input->dataType()), bS, iC, iH, iW); - else - err = cudnnSetTensor4dDescriptorEx(x, cudnnDataType(input->dataType()), bS, iC, iH, iW, input->strideAt(0), input->strideAt(indIOioC), input->strideAt(indIiH), input->strideAt(indIiH + 1)); - if (err != 0) throw sd::cuda_exception::build("conv2dCUDNN: cudnnSetTensor4dDescriptor/cudnnSetTensor4dDescriptorEx for input failed", err); - - // weights descriptor - cudnnFilterDescriptor_t w; - cudnnCreateFilterDescriptor(&w); - err = cudnnSetFilter4dDescriptor(w, cudnnDataType(weights->dataType()), formatW, oC, iC, kH, kW); - if(err != 0) throw sd::cuda_exception::build("conv2dCUDNN: cudnnSetFilter4dDescriptor failed", err); - - // output descriptor - cudnnTensorDescriptor_t z; - cudnnCreateTensorDescriptor(&z); - if(output->ews() == 1 && output->ordering() == 'c') - err = cudnnSetTensor4dDescriptor(z, format, cudnnDataType(output->dataType()), bS, oC, oH, oW); - else - err = cudnnSetTensor4dDescriptorEx(z, cudnnDataType(output->dataType()), bS, oC, oH, oW, output->strideAt(0), output->strideAt(indIOioC), output->strideAt(indOoH), output->strideAt(indOoH + 1)); - if (err != 0) throw sd::cuda_exception::build("conv2dCUDNN: cudnnSetTensor4dDescriptor/cudnnSetTensor4dDescriptorEx for output failed", err); - - // description of convolution - cudnnConvolutionDescriptor_t conv; - cudnnCreateConvolutionDescriptor(&conv); - err = cudnnSetConvolution2dDescriptor(conv, pH, pW, sH, sW, dH, dW, CUDNN_CROSS_CORRELATION, cudnnDataType(output->dataType())); - if (err != 0) throw sd::cuda_exception::build("conv2dCUDNN: cudnnSetConvolution2dDescriptor failed", err); - - // algorithm description - cudnnConvolutionFwdAlgo_t algo; - err = cudnnGetConvolutionForwardAlgorithm(*handle, x, w, conv, z, CUDNN_CONVOLUTION_FWD_PREFER_FASTEST, 0, &algo); - if (err != 0) throw sd::cuda_exception::build("conv2dCUDNN: cudnnGetConvolutionForwardAlgorithm failed", err); - - - // allocate auxiliary device memory, abbreviation ws means workspace - size_t wsSize; - err = cudnnGetConvolutionForwardWorkspaceSize(*handle, x, w, conv, z, algo, &wsSize); - if (err != 0) throw sd::cuda_exception::build("conv2dCUDNN: cudnnGetConvolutionForwardWorkspaceSize failed", err); - void* wsData; - auto cudaErr = cudaMalloc(&wsData, wsSize); - if (cudaErr != 0) throw sd::cuda_exception::build("conv2dCUDNN: cudaMalloc for auxiliary workspace memory failed", cudaErr); - - // provide scaling parameters - const float alpha32(1), beta32(0); - const double alpha64(1), beta64(0); - const void* alpha = output->sizeOfT() <= 4 ? reinterpret_cast(&alpha32) : reinterpret_cast(&alpha64); - const void* beta = output->sizeOfT() <= 4 ? reinterpret_cast(&beta32) : reinterpret_cast(&beta64); - - NDArray::prepareSpecialUse({output}, {input, weights, bias}); - - // run calculation - err = cudnnConvolutionForward(*handle, alpha, x, input->specialBuffer(), w, weights->specialBuffer(), conv, algo, wsData, wsSize, beta, z, output->specialBuffer()); - if (err != 0) throw sd::cuda_exception::build("conv2dCUDNN: cudnnConvolutionForward failed", err); - - // add bias if it is present - if (bias != nullptr) { - cudnnTensorDescriptor_t b; - cudnnCreateTensorDescriptor(&b); - // err = cudnnSetTensor4dDescriptor(b, format, cudnnDataType(bias->dataType()), 1, isNCHW ? bias->lengthOf() : 1, 1, isNCHW ? 1: bias->lengthOf()); - err = cudnnSetTensor4dDescriptor(b, CUDNN_TENSOR_NCHW, cudnnDataType(bias->dataType()), 1, oC, 1, 1); - if (err != 0) throw sd::cuda_exception::build("conv2dCUDNN: cudnnSetTensor4dDescriptor for bias failed", err); - err = cudnnAddTensor(*handle, alpha, b, bias->specialBuffer(), alpha, z, output->specialBuffer()); - if (err != 0) throw sd::cuda_exception::build("conv2dCUDNN: cudnnAddTensor bias failed", err); - } - - // cudaErr = cudaStreamSynchronize(*context->getCudaStream()); - // if (cudaErr != 0) - // throw cuda_exception::build("conv2dCUDNN: cudaStreamSynchronize failed !", cudaErr); - - cudaErr = cudaFree(wsData); - if (cudaErr != 0) throw sd::cuda_exception::build("conv2dCUDNN: cudaFree for auxiliary workspace memory failed", cudaErr); - - NDArray::registerSpecialUse({output}, {input, weights, bias}); +static void conv2dCUDNN(const LaunchContext* context, const NDArray* input, + const NDArray* weights, const NDArray* bias, + NDArray* output, const int kH, const int kW, + const int sH, const int sW, const int pH, const int pW, + const int dH, const int dW, const int paddingMode, + const bool isNCHW, const int wFormat) { + // cudnn support only two formats for weights {oC,iC,kH,kW} and {oC,kH,kW,iC} + + int bS, iC, iH, iW, oC, oH, + oW; // batch size, input channels, input height/width, output channels, + // output height/width; + int indIOioC, indIiH, indWoC, indWiC, indWkH, + indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, + indIiH, indWiC, indWoC, indWkH, indOoH); + + auto handle = reinterpret_cast(context->getCuDnnHandle()); + cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream()); + if (err != 0) + throw sd::cuda_exception::build("conv2dCUDNN: can't set stream for cuDNN", + err); + + cudnnTensorFormat_t format = isNCHW ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC; + cudnnTensorFormat_t formatW = + 0 == wFormat ? format + : (1 == wFormat ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC); + + // input descriptor + cudnnTensorDescriptor_t x; + cudnnCreateTensorDescriptor(&x); + if (input->ews() == 1 && input->ordering() == 'c') + err = cudnnSetTensor4dDescriptor( + x, format, cudnnDataType(input->dataType()), bS, iC, iH, iW); + else + err = cudnnSetTensor4dDescriptorEx( + x, cudnnDataType(input->dataType()), bS, iC, iH, iW, input->strideAt(0), + input->strideAt(indIOioC), input->strideAt(indIiH), + input->strideAt(indIiH + 1)); + if (err != 0) + throw sd::cuda_exception::build( + "conv2dCUDNN: cudnnSetTensor4dDescriptor/cudnnSetTensor4dDescriptorEx " + "for input failed", + err); + + // weights descriptor + cudnnFilterDescriptor_t w; + cudnnCreateFilterDescriptor(&w); + err = cudnnSetFilter4dDescriptor(w, cudnnDataType(weights->dataType()), + formatW, oC, iC, kH, kW); + if (err != 0) + throw sd::cuda_exception::build( + "conv2dCUDNN: cudnnSetFilter4dDescriptor failed", err); + + // output descriptor + cudnnTensorDescriptor_t z; + cudnnCreateTensorDescriptor(&z); + if (output->ews() == 1 && output->ordering() == 'c') + err = cudnnSetTensor4dDescriptor( + z, format, cudnnDataType(output->dataType()), bS, oC, oH, oW); + else + err = cudnnSetTensor4dDescriptorEx( + z, cudnnDataType(output->dataType()), bS, oC, oH, oW, + output->strideAt(0), output->strideAt(indIOioC), + output->strideAt(indOoH), output->strideAt(indOoH + 1)); + if (err != 0) + throw sd::cuda_exception::build( + "conv2dCUDNN: cudnnSetTensor4dDescriptor/cudnnSetTensor4dDescriptorEx " + "for output failed", + err); + + // description of convolution + cudnnConvolutionDescriptor_t conv; + cudnnCreateConvolutionDescriptor(&conv); + err = cudnnSetConvolution2dDescriptor(conv, pH, pW, sH, sW, dH, dW, + CUDNN_CROSS_CORRELATION, + cudnnDataType(output->dataType())); + if (err != 0) + throw sd::cuda_exception::build( + "conv2dCUDNN: cudnnSetConvolution2dDescriptor failed", err); + + // algorithm description + cudnnConvolutionFwdAlgo_t algo; + err = cudnnGetConvolutionForwardAlgorithm( + *handle, x, w, conv, z, CUDNN_CONVOLUTION_FWD_PREFER_FASTEST, 0, &algo); + if (err != 0) + throw sd::cuda_exception::build( + "conv2dCUDNN: cudnnGetConvolutionForwardAlgorithm failed", err); + + // allocate auxiliary device memory, abbreviation ws means workspace + size_t wsSize; + err = cudnnGetConvolutionForwardWorkspaceSize(*handle, x, w, conv, z, algo, + &wsSize); + if (err != 0) + throw sd::cuda_exception::build( + "conv2dCUDNN: cudnnGetConvolutionForwardWorkspaceSize failed", err); + void* wsData; + auto cudaErr = cudaMalloc(&wsData, wsSize); + if (cudaErr != 0) + throw sd::cuda_exception::build( + "conv2dCUDNN: cudaMalloc for auxiliary workspace memory failed", + cudaErr); + + // provide scaling parameters + const float alpha32(1), beta32(0); + const double alpha64(1), beta64(0); + const void* alpha = output->sizeOfT() <= 4 + ? reinterpret_cast(&alpha32) + : reinterpret_cast(&alpha64); + const void* beta = output->sizeOfT() <= 4 + ? reinterpret_cast(&beta32) + : reinterpret_cast(&beta64); + + NDArray::prepareSpecialUse({output}, {input, weights, bias}); + + // run calculation + err = cudnnConvolutionForward(*handle, alpha, x, input->specialBuffer(), w, + weights->specialBuffer(), conv, algo, wsData, + wsSize, beta, z, output->specialBuffer()); + if (err != 0) + throw sd::cuda_exception::build( + "conv2dCUDNN: cudnnConvolutionForward failed", err); + + // add bias if it is present + if (bias != nullptr) { + cudnnTensorDescriptor_t b; + cudnnCreateTensorDescriptor(&b); + // err = cudnnSetTensor4dDescriptor(b, format, + // cudnnDataType(bias->dataType()), 1, isNCHW ? bias->lengthOf() : 1, 1, + // isNCHW ? 1: bias->lengthOf()); + err = cudnnSetTensor4dDescriptor( + b, CUDNN_TENSOR_NCHW, cudnnDataType(bias->dataType()), 1, oC, 1, 1); + if (err != 0) + throw sd::cuda_exception::build( + "conv2dCUDNN: cudnnSetTensor4dDescriptor for bias failed", err); + err = cudnnAddTensor(*handle, alpha, b, bias->specialBuffer(), alpha, z, + output->specialBuffer()); + if (err != 0) + throw sd::cuda_exception::build("conv2dCUDNN: cudnnAddTensor bias failed", + err); + } + + // cudaErr = cudaStreamSynchronize(*context->getCudaStream()); + // if (cudaErr != 0) + // throw cuda_exception::build("conv2dCUDNN: cudaStreamSynchronize failed + // !", cudaErr); + + cudaErr = cudaFree(wsData); + if (cudaErr != 0) + throw sd::cuda_exception::build( + "conv2dCUDNN: cudaFree for auxiliary workspace memory failed", cudaErr); + + NDArray::registerSpecialUse({output}, {input, weights, bias}); } ////////////////////////////////////////////////////////////////////////// -static void conv2dBpCUDNN(const LaunchContext* context, - const NDArray* input, const NDArray* weights, const NDArray* gradO, +static void conv2dBpCUDNN(const LaunchContext* context, const NDArray* input, + const NDArray* weights, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, - const int kH, const int kW, - const int sH, const int sW, - const int pH, const int pW, - const int dH, const int dW, - const int paddingMode, const bool isNCHW, const int wFormat) { - - int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; - int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); - - auto handle = reinterpret_cast(context->getCuDnnHandle()); - cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream()); - if (err != 0) throw sd::cuda_exception::build("conv2dBpCUDNN: can't set stream for cuDNN", err); - - cudnnTensorFormat_t format = isNCHW ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC; - cudnnTensorFormat_t formatW = 0 == wFormat ? format : (1 == wFormat ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC); - - // input descriptor - cudnnTensorDescriptor_t x; - cudnnCreateTensorDescriptor(&x); - if(input->ews() == 1 && input->ordering() == 'c') - err = cudnnSetTensor4dDescriptor(x, format, cudnnDataType(input->dataType()), bS, iC, iH, iW); - else - err = cudnnSetTensor4dDescriptorEx(x, cudnnDataType(input->dataType()), bS, iC, iH, iW, input->strideAt(0), input->strideAt(indIOioC), input->strideAt(indIiH), input->strideAt(indIiH + 1)); - if (err != 0) throw sd::cuda_exception::build("conv2dBpCUDNN: cudnnSetTensor4dDescriptor/cudnnSetTensor4dDescriptorEx for input failed", err); - - // gradO descriptor - cudnnTensorDescriptor_t dz; - cudnnCreateTensorDescriptor(&dz); - if(gradO->ews() == 1 && gradO->ordering() == 'c') - err = cudnnSetTensor4dDescriptor(dz, format, cudnnDataType(gradO->dataType()), bS, oC, oH, oW); - else - err = cudnnSetTensor4dDescriptorEx(dz, cudnnDataType(gradO->dataType()), bS, oC, oH, oW, gradO->strideAt(0), gradO->strideAt(indIOioC), gradO->strideAt(indOoH), gradO->strideAt(indOoH + 1)); - if (err != 0) throw sd::cuda_exception::build("conv2dBpCUDNN: cudnnSetTensor4dDescriptor/cudnnSetTensor4dDescriptorEx for gradO failed", err); - - // gradI descriptor - cudnnTensorDescriptor_t dx; - cudnnCreateTensorDescriptor(&dx); - if(gradI->ews() == 1 && gradI->ordering() == 'c') - err = cudnnSetTensor4dDescriptor(dx, format, cudnnDataType(gradI->dataType()), bS, iC, iH, iW); - else - err = cudnnSetTensor4dDescriptorEx(dx, cudnnDataType(gradI->dataType()), bS, iC, iH, iW, gradI->strideAt(0), gradI->strideAt(indIOioC), gradI->strideAt(indIiH), gradI->strideAt(indIiH + 1)); - if (err != 0) throw sd::cuda_exception::build("conv2dBpCUDNN: cudnnSetTensor4dDescriptor/cudnnSetTensor4dDescriptorEx for gradI failed", err); - - // gradW descriptor - cudnnFilterDescriptor_t dw; - cudnnCreateFilterDescriptor(&dw); - err = cudnnSetFilter4dDescriptor(dw, cudnnDataType(gradW->dataType()), formatW, oC, iC, kH, kW); - if(err != 0) throw sd::cuda_exception::build("conv2dBpCUDNN: cudnnSetFilter4dDescriptor gradW failed", err); - - // description of convolution - cudnnConvolutionDescriptor_t conv; - cudnnCreateConvolutionDescriptor(&conv); - err = cudnnSetConvolution2dDescriptor(conv, pH, pW, sH, sW, dH, dW, CUDNN_CROSS_CORRELATION, cudnnDataType(gradO->dataType())); - if (err != 0) throw sd::cuda_exception::build("conv2dBpCUDNN: cudnnSetConvolution2dDescriptor failed", err); - - // gradW algorithm description - cudnnConvolutionBwdFilterAlgo_t algoGradW; - err = cudnnGetConvolutionBackwardFilterAlgorithm(*handle, x, dz, conv, dw, CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST, 0, &algoGradW); - if (err != 0) throw sd::cuda_exception::build("conv2dBpCUDNN: cudnnGetConvolutionBackwardFilterAlgorithm failed", err); - - // gradI algorithm description - cudnnConvolutionBwdDataAlgo_t algoGradI; - err = cudnnGetConvolutionBackwardDataAlgorithm(*handle, dw, dz, conv, x, CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST, 0, &algoGradI); - if (err != 0) throw sd::cuda_exception::build("conv2dBpCUDNN: cudnnGetConvolutionBackwardDataAlgorithm failed", err); - - // allocate auxiliary device memory for gradW calculation, abbreviation ws means workspace - size_t wsGradWSize; - err = cudnnGetConvolutionBackwardFilterWorkspaceSize(*handle, x, dz, conv, dw, algoGradW, &wsGradWSize); - if (err != 0) throw sd::cuda_exception::build("conv2dBpCUDNN: cudnnGetConvolutionBackwardFilterWorkspaceSize failed", err); - void* wsGradWData; - auto cudaErr = cudaMalloc(&wsGradWData, wsGradWSize); - if (cudaErr != 0) throw sd::cuda_exception::build("conv2dBpCUDNN: cudaMalloc for auxiliary workspace memory wsGradWData failed", cudaErr); - - // allocate auxiliary device memory for gradI calculation, abbreviation ws means workspace - size_t wsGradISize; - err = cudnnGetConvolutionBackwardDataWorkspaceSize(*handle, dw, dz, conv, dx, algoGradI, &wsGradISize); - if (err != 0) throw sd::cuda_exception::build("conv2dBpCUDNN: cudnnGetConvolutionBackwardDataWorkspaceSize failed", err); - void* wsGradIData; - cudaErr = cudaMalloc(&wsGradIData, wsGradISize); - if (cudaErr != 0) throw sd::cuda_exception::build("conv2dBpCUDNN: cudaMalloc for auxiliary workspace memory wsGradIData failed", cudaErr); - - // provide scaling parameters - const float alpha32(1), beta32(0); - const double alpha64(1), beta64(0); - const void* alpha = gradO->sizeOfT() <= 4 ? reinterpret_cast(&alpha32) : reinterpret_cast(&alpha64); - const void* beta = gradO->sizeOfT() <= 4 ? reinterpret_cast(&beta32) : reinterpret_cast(&beta64); - - NDArray::prepareSpecialUse({gradI, gradW, gradB}, {input, weights, gradO}); - - // run calculation for gradB (if not nullptr) - if(gradB != nullptr) { - cudnnTensorDescriptor_t db; - cudnnCreateTensorDescriptor(&db); - // err = cudnnSetTensor4dDescriptor(db, format, cudnnDataType(gradB->dataType()), 1, isNCHW ? gradB->lengthOf() : 1, 1, isNCHW ? 1: gradB->lengthOf()); - err = cudnnSetTensor4dDescriptor(db, CUDNN_TENSOR_NCHW, cudnnDataType(gradB->dataType()), 1, oC, 1, 1); - if (err != 0) throw sd::cuda_exception::build("conv2dBpCUDNN: cudnnSetTensor4dDescriptor for gradB failed", err); - - err = cudnnConvolutionBackwardBias(*handle, alpha, dz, gradO->specialBuffer(), beta, db, gradB->specialBuffer()); - if (err != 0) throw sd::cuda_exception::build("conv2dBpCUDNN: cudnnConvolutionBackwardBias failed", err); - } - - // run calculation for gradW - err = cudnnConvolutionBackwardFilter(*handle, alpha, x, input->specialBuffer(), dz, gradO->specialBuffer(), conv, algoGradW, wsGradWData, wsGradWSize, beta, dw, gradW->specialBuffer()); - if (err != 0) throw sd::cuda_exception::build("conv2dBpCUDNN: cudnnConvolutionBackwardFilter failed", err); - - // run calculation for gradI - err = cudnnConvolutionBackwardData(*handle, alpha, dw, weights->specialBuffer(), dz, gradO->specialBuffer(), conv, algoGradI, wsGradIData, wsGradISize, beta, dx, gradI->specialBuffer()); - if (err != 0) throw sd::cuda_exception::build("conv2dBpCUDNN: cudnnConvolutionBackwardData failed", err); - - // cudaErr = cudaStreamSynchronize(*context->getCudaStream()); - // if (cudaErr != 0) - // throw cuda_exception::build("conv2dBpCUDNN: cudaStreamSynchronize failed !", cudaErr); - - cudaErr = cudaFree(wsGradWData); - if (cudaErr != 0) throw sd::cuda_exception::build("conv2dBpCUDNN: cudaFree for auxiliary workspace memory wsGradWData failed", cudaErr); - cudaErr = cudaFree(wsGradIData); - if (cudaErr != 0) throw sd::cuda_exception::build("conv2dBpCUDNN: cudaFree for auxiliary workspace memory wsGradIData failed", cudaErr); - - NDArray::registerSpecialUse({gradI, gradW, gradB}, {input, weights, gradO}); + const int kH, const int kW, const int sH, + const int sW, const int pH, const int pW, + const int dH, const int dW, const int paddingMode, + const bool isNCHW, const int wFormat) { + int bS, iC, iH, iW, oC, oH, + oW; // batch size, input channels, input height/width, output channels, + // output height/width; + int indIOioC, indIiH, indWoC, indWiC, indWkH, + indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, + indIiH, indWiC, indWoC, indWkH, indOoH); + + auto handle = reinterpret_cast(context->getCuDnnHandle()); + cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream()); + if (err != 0) + throw sd::cuda_exception::build("conv2dBpCUDNN: can't set stream for cuDNN", + err); + + cudnnTensorFormat_t format = isNCHW ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC; + cudnnTensorFormat_t formatW = + 0 == wFormat ? format + : (1 == wFormat ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC); + + // input descriptor + cudnnTensorDescriptor_t x; + cudnnCreateTensorDescriptor(&x); + if (input->ews() == 1 && input->ordering() == 'c') + err = cudnnSetTensor4dDescriptor( + x, format, cudnnDataType(input->dataType()), bS, iC, iH, iW); + else + err = cudnnSetTensor4dDescriptorEx( + x, cudnnDataType(input->dataType()), bS, iC, iH, iW, input->strideAt(0), + input->strideAt(indIOioC), input->strideAt(indIiH), + input->strideAt(indIiH + 1)); + if (err != 0) + throw sd::cuda_exception::build( + "conv2dBpCUDNN: " + "cudnnSetTensor4dDescriptor/cudnnSetTensor4dDescriptorEx for input " + "failed", + err); + + // gradO descriptor + cudnnTensorDescriptor_t dz; + cudnnCreateTensorDescriptor(&dz); + if (gradO->ews() == 1 && gradO->ordering() == 'c') + err = cudnnSetTensor4dDescriptor( + dz, format, cudnnDataType(gradO->dataType()), bS, oC, oH, oW); + else + err = cudnnSetTensor4dDescriptorEx( + dz, cudnnDataType(gradO->dataType()), bS, oC, oH, oW, + gradO->strideAt(0), gradO->strideAt(indIOioC), gradO->strideAt(indOoH), + gradO->strideAt(indOoH + 1)); + if (err != 0) + throw sd::cuda_exception::build( + "conv2dBpCUDNN: " + "cudnnSetTensor4dDescriptor/cudnnSetTensor4dDescriptorEx for gradO " + "failed", + err); + + // gradI descriptor + cudnnTensorDescriptor_t dx; + cudnnCreateTensorDescriptor(&dx); + if (gradI->ews() == 1 && gradI->ordering() == 'c') + err = cudnnSetTensor4dDescriptor( + dx, format, cudnnDataType(gradI->dataType()), bS, iC, iH, iW); + else + err = cudnnSetTensor4dDescriptorEx( + dx, cudnnDataType(gradI->dataType()), bS, iC, iH, iW, + gradI->strideAt(0), gradI->strideAt(indIOioC), gradI->strideAt(indIiH), + gradI->strideAt(indIiH + 1)); + if (err != 0) + throw sd::cuda_exception::build( + "conv2dBpCUDNN: " + "cudnnSetTensor4dDescriptor/cudnnSetTensor4dDescriptorEx for gradI " + "failed", + err); + + // gradW descriptor + cudnnFilterDescriptor_t dw; + cudnnCreateFilterDescriptor(&dw); + err = cudnnSetFilter4dDescriptor(dw, cudnnDataType(gradW->dataType()), + formatW, oC, iC, kH, kW); + if (err != 0) + throw sd::cuda_exception::build( + "conv2dBpCUDNN: cudnnSetFilter4dDescriptor gradW failed", err); + + // description of convolution + cudnnConvolutionDescriptor_t conv; + cudnnCreateConvolutionDescriptor(&conv); + err = cudnnSetConvolution2dDescriptor(conv, pH, pW, sH, sW, dH, dW, + CUDNN_CROSS_CORRELATION, + cudnnDataType(gradO->dataType())); + if (err != 0) + throw sd::cuda_exception::build( + "conv2dBpCUDNN: cudnnSetConvolution2dDescriptor failed", err); + + // gradW algorithm description + cudnnConvolutionBwdFilterAlgo_t algoGradW; + err = cudnnGetConvolutionBackwardFilterAlgorithm( + *handle, x, dz, conv, dw, CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST, 0, + &algoGradW); + if (err != 0) + throw sd::cuda_exception::build( + "conv2dBpCUDNN: cudnnGetConvolutionBackwardFilterAlgorithm failed", + err); + + // gradI algorithm description + cudnnConvolutionBwdDataAlgo_t algoGradI; + err = cudnnGetConvolutionBackwardDataAlgorithm( + *handle, dw, dz, conv, x, CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST, 0, + &algoGradI); + if (err != 0) + throw sd::cuda_exception::build( + "conv2dBpCUDNN: cudnnGetConvolutionBackwardDataAlgorithm failed", err); + + // allocate auxiliary device memory for gradW calculation, abbreviation ws + // means workspace + size_t wsGradWSize; + err = cudnnGetConvolutionBackwardFilterWorkspaceSize(*handle, x, dz, conv, dw, + algoGradW, &wsGradWSize); + if (err != 0) + throw sd::cuda_exception::build( + "conv2dBpCUDNN: cudnnGetConvolutionBackwardFilterWorkspaceSize failed", + err); + void* wsGradWData; + auto cudaErr = cudaMalloc(&wsGradWData, wsGradWSize); + if (cudaErr != 0) + throw sd::cuda_exception::build( + "conv2dBpCUDNN: cudaMalloc for auxiliary workspace memory wsGradWData " + "failed", + cudaErr); + + // allocate auxiliary device memory for gradI calculation, abbreviation ws + // means workspace + size_t wsGradISize; + err = cudnnGetConvolutionBackwardDataWorkspaceSize(*handle, dw, dz, conv, dx, + algoGradI, &wsGradISize); + if (err != 0) + throw sd::cuda_exception::build( + "conv2dBpCUDNN: cudnnGetConvolutionBackwardDataWorkspaceSize failed", + err); + void* wsGradIData; + cudaErr = cudaMalloc(&wsGradIData, wsGradISize); + if (cudaErr != 0) + throw sd::cuda_exception::build( + "conv2dBpCUDNN: cudaMalloc for auxiliary workspace memory wsGradIData " + "failed", + cudaErr); + + // provide scaling parameters + const float alpha32(1), beta32(0); + const double alpha64(1), beta64(0); + const void* alpha = gradO->sizeOfT() <= 4 + ? reinterpret_cast(&alpha32) + : reinterpret_cast(&alpha64); + const void* beta = gradO->sizeOfT() <= 4 + ? reinterpret_cast(&beta32) + : reinterpret_cast(&beta64); + + NDArray::prepareSpecialUse({gradI, gradW, gradB}, {input, weights, gradO}); + + // run calculation for gradB (if not nullptr) + if (gradB != nullptr) { + cudnnTensorDescriptor_t db; + cudnnCreateTensorDescriptor(&db); + // err = cudnnSetTensor4dDescriptor(db, format, + // cudnnDataType(gradB->dataType()), 1, isNCHW ? gradB->lengthOf() : 1, 1, + // isNCHW ? 1: gradB->lengthOf()); + err = cudnnSetTensor4dDescriptor( + db, CUDNN_TENSOR_NCHW, cudnnDataType(gradB->dataType()), 1, oC, 1, 1); + if (err != 0) + throw sd::cuda_exception::build( + "conv2dBpCUDNN: cudnnSetTensor4dDescriptor for gradB failed", err); + + err = + cudnnConvolutionBackwardBias(*handle, alpha, dz, gradO->specialBuffer(), + beta, db, gradB->specialBuffer()); + if (err != 0) + throw sd::cuda_exception::build( + "conv2dBpCUDNN: cudnnConvolutionBackwardBias failed", err); + } + + // run calculation for gradW + err = cudnnConvolutionBackwardFilter( + *handle, alpha, x, input->specialBuffer(), dz, gradO->specialBuffer(), + conv, algoGradW, wsGradWData, wsGradWSize, beta, dw, + gradW->specialBuffer()); + if (err != 0) + throw sd::cuda_exception::build( + "conv2dBpCUDNN: cudnnConvolutionBackwardFilter failed", err); + + // run calculation for gradI + err = cudnnConvolutionBackwardData( + *handle, alpha, dw, weights->specialBuffer(), dz, gradO->specialBuffer(), + conv, algoGradI, wsGradIData, wsGradISize, beta, dx, + gradI->specialBuffer()); + if (err != 0) + throw sd::cuda_exception::build( + "conv2dBpCUDNN: cudnnConvolutionBackwardData failed", err); + + // cudaErr = cudaStreamSynchronize(*context->getCudaStream()); + // if (cudaErr != 0) + // throw cuda_exception::build("conv2dBpCUDNN: cudaStreamSynchronize + // failed !", cudaErr); + + cudaErr = cudaFree(wsGradWData); + if (cudaErr != 0) + throw sd::cuda_exception::build( + "conv2dBpCUDNN: cudaFree for auxiliary workspace memory wsGradWData " + "failed", + cudaErr); + cudaErr = cudaFree(wsGradIData); + if (cudaErr != 0) + throw sd::cuda_exception::build( + "conv2dBpCUDNN: cudaFree for auxiliary workspace memory wsGradIData " + "failed", + cudaErr); + + NDArray::registerSpecialUse({gradI, gradW, gradB}, {input, weights, gradO}); } ////////////////////////////////////////////////////////////////////////// PLATFORM_IMPL(conv2d, ENGINE_CUDA) { - - auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] - auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] - - auto output = OUTPUT_VARIABLE(0); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) - - int sH = INT_ARG(2); // strides height - int sW = INT_ARG(3); // strides width - int pH = INT_ARG(4); // paddings height - int pW = INT_ARG(5); // paddings width - int dH = INT_ARG(6); // dilations height - int dW = INT_ARG(7); // dilations width - int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME - bool isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC - int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - [oC, kH, kW, iC] - - int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(weights->sizeAt(0)); // filter(kernel) height - int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast(weights->sizeAt(1)); // filter(kernel) width - - REQUIRE_TRUE(input->rankOf() == 4, 0, "CUSTOM CONV2D CUDNN OP: rank of input array must be equal to 4, but got %i instead !", input->rankOf()); - REQUIRE_TRUE(weights->rankOf() == 4, 0, "CUSTOM CONV2D CUDNN OP: rank of weights array must be equal to 4, but got %i instead !", weights->rankOf()); - - int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; - int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); - - ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode); - - std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, oC); - REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV2D CUDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); - if (bias) { - REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV2D CUDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); - REQUIRE_TRUE((bias->rankOf() == 1 && bias->strideAt(0) == 1) || (bias->rankOf() == 2 && bias->sizeAt(0) == 1 && bias->strideAt(1) == 1) || (bias->rankOf() == 2 && bias->sizeAt(1) == 1 && bias->strideAt(0) == 1), 0, "CUSTOM CONV2D CUDNN OP: bias array should be contiguous in memory !"); - } - - NDArray* newWeights = weights; // cudnn support only two formats {oC,iC,kH,kW} and {oC,kH,kW,iC} - if(0 == wFormat) { - newWeights = new NDArray(weights->ordering(), isNCHW ? std::vector({oC, iC, kH, kW}) : std::vector({oC, kH, kW, iC}), weights->dataType(), weights->getContext()); - newWeights->assign(weights->permute(isNCHW ? std::vector({3,2,0,1}) : std::vector({3,0,1,2}))); // (kH, kW, iC, oC --> oC, iC, kH, kW) or (kH, kW, iC, oC --> oC, kH, kW, iC) - } - - NDArray* newInput = input; - NDArray* newGradI = nullptr; - if(paddingMode == 1) // in same paddingMode cudnn doesn't support asymmetric left/right top/bottopm paddings - checkConv2dCUDNNPadAsymmetric(newInput, newGradI, iH, iW, oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, isNCHW); - - conv2dCUDNN(block.launchContext(), newInput, newWeights, bias, output, kH,kW,sH,sW,pH,pW,dH,dW, paddingMode, isNCHW, wFormat); - - if(newInput != input) - delete newInput; - - if(0 == wFormat) - delete newWeights; - - return Status::OK(); + auto input = + INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + auto weights = INPUT_VARIABLE( + 1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] + auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] + + auto output = + OUTPUT_VARIABLE(0); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) + + int sH = INT_ARG(2); // strides height + int sW = INT_ARG(3); // strides width + int pH = INT_ARG(4); // paddings height + int pW = INT_ARG(5); // paddings width + int dH = INT_ARG(6); // dilations height + int dW = INT_ARG(7); // dilations width + int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME + bool isNCHW = block.numI() > 9 + ? !INT_ARG(9) + : 1; // INT_ARG(9): 0-NCHW, 1-NHWC + int wFormat = block.numI() > 10 + ? INT_ARG(10) + : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - + // [oC, kH, kW, iC] + + int kH = INT_ARG(0) > 0 + ? INT_ARG(0) + : static_cast(weights->sizeAt(0)); // filter(kernel) height + int kW = INT_ARG(1) > 0 + ? INT_ARG(1) + : static_cast(weights->sizeAt(1)); // filter(kernel) width + + REQUIRE_TRUE(input->rankOf() == 4, 0, + "CUSTOM CONV2D CUDNN OP: rank of input array must be equal to " + "4, but got %i instead !", + input->rankOf()); + REQUIRE_TRUE(weights->rankOf() == 4, 0, + "CUSTOM CONV2D CUDNN OP: rank of weights array must be equal to " + "4, but got %i instead !", + weights->rankOf()); + + int bS, iC, iH, iW, oC, oH, + oW; // batch size, input channels, input height/width, output channels, + // output height/width; + int indIOioC, indIiH, indWoC, indWiC, indWkH, + indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, + indIiH, indWiC, indWoC, indWkH, indOoH); + + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, + dW, paddingMode); + + std::vector expectedWeightsShape = + ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, oC); + REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, + "CUSTOM CONV2D CUDNN OP: wrong shape of weights array, expected " + "is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), + ShapeUtils::shapeAsString(weights).c_str()); + if (bias) { + REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, + "CUSTOM CONV2D CUDNN OP: wrong shape of array with biases, " + "expected rank, length: <=2, %i, but got %i, %i instead !", + oC, bias->rankOf(), bias->lengthOf()); + REQUIRE_TRUE( + (bias->rankOf() == 1 && bias->strideAt(0) == 1) || + (bias->rankOf() == 2 && bias->sizeAt(0) == 1 && + bias->strideAt(1) == 1) || + (bias->rankOf() == 2 && bias->sizeAt(1) == 1 && + bias->strideAt(0) == 1), + 0, + "CUSTOM CONV2D CUDNN OP: bias array should be contiguous in memory !"); + } + + NDArray* newWeights = weights; // cudnn support only two formats + // {oC,iC,kH,kW} and {oC,kH,kW,iC} + if (0 == wFormat) { + newWeights = new NDArray(weights->ordering(), + isNCHW ? std::vector({oC, iC, kH, kW}) + : std::vector({oC, kH, kW, iC}), + weights->dataType(), weights->getContext()); + newWeights->assign(weights->permute( + isNCHW + ? std::vector({3, 2, 0, 1}) + : std::vector( + {3, 0, 1, 2}))); // (kH, kW, iC, oC --> oC, iC, kH, kW) or + // (kH, kW, iC, oC --> oC, kH, kW, iC) + } + + NDArray* newInput = input; + NDArray* newGradI = nullptr; + if (paddingMode == 1) // in same paddingMode cudnn doesn't support asymmetric + // left/right top/bottopm paddings + checkConv2dCUDNNPadAsymmetric(newInput, newGradI, iH, iW, oH, oW, kH, kW, + sH, sW, pH, pW, dH, dW, isNCHW); + + conv2dCUDNN(block.launchContext(), newInput, newWeights, bias, output, kH, kW, + sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat); + + if (newInput != input) delete newInput; + + if (0 == wFormat) delete newWeights; + + return Status::OK(); } ////////////////////////////////////////////////////////////////////////// PLATFORM_CHECK(conv2d, ENGINE_CUDA) { - - auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC] always - auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] - - const int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME, 2-CAUSAL - - const bool badInputType = input->dataType() != DataType::DOUBLE && input->dataType() != DataType::FLOAT32 && input->dataType() != DataType::HALF; - const bool badWeightsType = weights->dataType() != DataType::DOUBLE && weights->dataType() != DataType::FLOAT32 && weights->dataType() != DataType::HALF; - const bool badBiasType = bias == nullptr ? false : (bias->dataType() != DataType::DOUBLE && bias->dataType() != DataType::FLOAT32 && bias->dataType() != DataType::HALF); - - return paddingMode != 2 && !badInputType && !badWeightsType && !badBiasType; + auto input = + INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC] always + auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] + + const int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME, 2-CAUSAL + + const bool badInputType = input->dataType() != DataType::DOUBLE && + input->dataType() != DataType::FLOAT32 && + input->dataType() != DataType::HALF; + const bool badWeightsType = weights->dataType() != DataType::DOUBLE && + weights->dataType() != DataType::FLOAT32 && + weights->dataType() != DataType::HALF; + const bool badBiasType = bias == nullptr + ? false + : (bias->dataType() != DataType::DOUBLE && + bias->dataType() != DataType::FLOAT32 && + bias->dataType() != DataType::HALF); + + return paddingMode != 2 && !badInputType && !badWeightsType && !badBiasType; } ////////////////////////////////////////////////////////////////////////// PLATFORM_IMPL(conv2d_bp, ENGINE_CUDA) { + auto input = + INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + auto weights = INPUT_VARIABLE( + 1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] + auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] + auto gradO = block.width() > 3 + ? INPUT_VARIABLE(3) + : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, + // oH, oW] (NCHW), epsilon_next + + auto gradI = OUTPUT_VARIABLE( + 0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon + auto gradW = OUTPUT_VARIABLE( + 1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] + auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC] + + int kH = INT_ARG(0); // filter(kernel) height + int kW = INT_ARG(1); // filter(kernel) width + int sH = INT_ARG(2); // strides height + int sW = INT_ARG(3); // strides width + int pH = INT_ARG(4); // paddings height + int pW = INT_ARG(5); // paddings width + int dH = INT_ARG(6); // dilations height + int dW = INT_ARG(7); // dilations width + int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME + int isNCHW = block.numI() > 9 + ? !INT_ARG(9) + : 1; // INT_ARG(9): 0-NCHW, 1-NHWC + int wFormat = block.numI() > 10 + ? INT_ARG(10) + : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - + // [oC, kH, kW, iC] + + REQUIRE_TRUE(input->rankOf() == 4, 0, + "CUSTOM CONV2D_BP CUDNN OP: rank of input array must be equal " + "to 4, but got %i instead !", + input->rankOf()); + REQUIRE_TRUE(weights->rankOf() == 4, 0, + "CUSTOM CONV2D_BP CUDNN OP: rank of weights array must be equal " + "to 4, but got %i instead !", + weights->rankOf()); + REQUIRE_TRUE(gradO->rankOf() == 4, 0, + "CUSTOM CONV2D_BP CUDNN OP: rank of output's gradients (next " + "epsilon) array must be equal to 4, but got %i instead !", + gradO->rankOf()); + + int bS, iC, iH, iW, oC, oH, + oW; // batch size, input channels, input height/width, output channels, + // output height/width; + int indIOioC, indIiH, indWoC, indWiC, indWkH, + indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, + indIiH, indWiC, indWoC, indWkH, indOoH); + + int trueoH, trueoW; // true output height, width + ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, + dH, dW, iH, iW, paddingMode); + + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, + dW, paddingMode); + + std::vector expectedGradOShape = + ShapeUtils::composeShapeUsingDimsAndIdx( + {bS, oC, trueoH, trueoW, 0, indIOioC, indOoH, indOoH + 1}); + std::vector expectedWeightsShape = + ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, oC); + REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, + "CUSTOM CONV2D_BP CUDNN OP: wrong shape of output gradients " + "(next epsilon) array, expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedGradOShape).c_str(), + ShapeUtils::shapeAsString(gradO).c_str()); + REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, + "CUSTOM CONV2D_BP CUDNN OP: wrong shape of weights array, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), + ShapeUtils::shapeAsString(weights).c_str()); + if (bias) + REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, + "CUSTOM CONV2D_BP CUDNN OP: wrong shape of array with biases, " + "expected rank, length: <=2, %i, but got %i, %i instead !", + oC, bias->rankOf(), bias->lengthOf()); + + NDArray *newWeights = weights, + *newGradW = gradW; // cudnn support only two formats {oC,iC,kH,kW} + // and {oC,kH,kW,iC} + if (0 == wFormat) { + newGradW = new NDArray(gradW->ordering(), + isNCHW ? std::vector({oC, iC, kH, kW}) + : std::vector({oC, kH, kW, iC}), + gradW->dataType(), gradW->getContext()); + newWeights = new NDArray(weights->ordering(), + isNCHW ? std::vector({oC, iC, kH, kW}) + : std::vector({oC, kH, kW, iC}), + weights->dataType(), weights->getContext()); + newWeights->assign(weights->permute( + isNCHW + ? std::vector({3, 2, 0, 1}) + : std::vector( + {3, 0, 1, 2}))); // (kH, kW, iC, oC --> oC, iC, kH, kW) or + // (kH, kW, iC, oC --> oC, kH, kW, iC) + } + + NDArray* newInput = input; + NDArray* newGradI = gradI; + if (paddingMode == 1) // in same paddingMode cudnn doesn't support asymmetric + // left/right top/bottopm paddings + checkConv2dCUDNNPadAsymmetric(newInput, newGradI, iH, iW, oH, oW, kH, kW, + sH, sW, pH, pW, dH, dW, isNCHW); + + conv2dBpCUDNN(block.launchContext(), newInput, newWeights, gradO, newGradI, + newGradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, + isNCHW, wFormat); + + if (0 == wFormat) { + newGradW->permutei( + isNCHW ? std::vector({2, 3, 1, 0}) + : std::vector( + {1, 2, 3, 0})); // (oC, iC, kH, kW --> kH, kW, iC, oC) or + // (oC, kH, kW, iC --> kH, kW, iC, oC) + gradW->assign(newGradW); + } + + if (newInput != input) { + if (isNCHW) + gradI->assign( + (*newGradI)({0, 0, 0, 0, 0, gradI->sizeAt(2), 0, gradI->sizeAt(3)})); + else + gradI->assign( + (*newGradI)({0, 0, 0, gradI->sizeAt(1), 0, gradI->sizeAt(2), 0, 0})); - auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] - auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] - auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next - - auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon - auto gradW = OUTPUT_VARIABLE(1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] - auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC] - - int kH = INT_ARG(0); // filter(kernel) height - int kW = INT_ARG(1); // filter(kernel) width - int sH = INT_ARG(2); // strides height - int sW = INT_ARG(3); // strides width - int pH = INT_ARG(4); // paddings height - int pW = INT_ARG(5); // paddings width - int dH = INT_ARG(6); // dilations height - int dW = INT_ARG(7); // dilations width - int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME - int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC - int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - [oC, kH, kW, iC] - - REQUIRE_TRUE(input->rankOf() == 4, 0, "CUSTOM CONV2D_BP CUDNN OP: rank of input array must be equal to 4, but got %i instead !", input->rankOf()); - REQUIRE_TRUE(weights->rankOf() == 4, 0, "CUSTOM CONV2D_BP CUDNN OP: rank of weights array must be equal to 4, but got %i instead !", weights->rankOf()); - REQUIRE_TRUE(gradO->rankOf() == 4, 0, "CUSTOM CONV2D_BP CUDNN OP: rank of output's gradients (next epsilon) array must be equal to 4, but got %i instead !", gradO->rankOf()); - - int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; - int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); - - int trueoH, trueoW; // true output height, width - ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, paddingMode); - - ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode); - - std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1}); - std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, oC); - REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM CONV2D_BP CUDNN OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); - REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV2D_BP CUDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); - if(bias) - REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV2D_BP CUDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); - - NDArray *newWeights = weights, *newGradW = gradW; // cudnn support only two formats {oC,iC,kH,kW} and {oC,kH,kW,iC} - if(0 == wFormat) { - newGradW = new NDArray(gradW->ordering(), isNCHW ? std::vector({oC, iC, kH, kW}) : std::vector({oC, kH, kW, iC}), gradW->dataType(), gradW->getContext()); - newWeights = new NDArray(weights->ordering(), isNCHW ? std::vector({oC, iC, kH, kW}) : std::vector({oC, kH, kW, iC}), weights->dataType(), weights->getContext()); - newWeights->assign(weights->permute(isNCHW ? std::vector({3,2,0,1}) : std::vector({3,0,1,2}))); // (kH, kW, iC, oC --> oC, iC, kH, kW) or (kH, kW, iC, oC --> oC, kH, kW, iC) - } - - NDArray* newInput = input; - NDArray* newGradI = gradI; - if(paddingMode == 1) // in same paddingMode cudnn doesn't support asymmetric left/right top/bottopm paddings - checkConv2dCUDNNPadAsymmetric(newInput, newGradI, iH, iW, oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, isNCHW); - - conv2dBpCUDNN(block.launchContext(), newInput, newWeights, gradO, newGradI, newGradW, gradB, kH,kW,sH,sW,pH,pW,dH,dW,paddingMode,isNCHW,wFormat); - - if(0 == wFormat) { - newGradW->permutei(isNCHW ? std::vector({2,3,1,0}) : std::vector({1,2,3,0})); // (oC, iC, kH, kW --> kH, kW, iC, oC) or (oC, kH, kW, iC --> kH, kW, iC, oC) - gradW->assign(newGradW); - } - - if(newInput != input) { - - if(isNCHW) - gradI->assign((*newGradI)({0,0, 0,0, 0,gradI->sizeAt(2), 0,gradI->sizeAt(3)})); - else - gradI->assign((*newGradI)({0,0, 0,gradI->sizeAt(1), 0,gradI->sizeAt(2), 0,0})); - - delete newInput; - delete newGradI; - } - - if(0 == wFormat) { - delete newWeights; - delete newGradW; - } - - return Status::OK(); -} - -PLATFORM_CHECK(conv2d_bp, ENGINE_CUDA) { - - auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC] always - auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] - auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next - - const int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME, 2-CAUSAL - const int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC + delete newInput; + delete newGradI; + } - const bool badInputType = input->dataType() != DataType::DOUBLE && input->dataType() != DataType::FLOAT32 && input->dataType() != DataType::HALF; - const bool badWeightsType = weights->dataType() != DataType::DOUBLE && weights->dataType() != DataType::FLOAT32 && weights->dataType() != DataType::HALF; - const bool badGradOType = gradO->dataType() != DataType::DOUBLE && gradO->dataType() != DataType::FLOAT32 && gradO->dataType() != DataType::HALF; - const bool badBiasType = bias == nullptr ? false : (bias->dataType() != DataType::DOUBLE && bias->dataType() != DataType::FLOAT32 && bias->dataType() != DataType::HALF); + if (0 == wFormat) { + delete newWeights; + delete newGradW; + } - return isNCHW && paddingMode != 2 && !badInputType && !badWeightsType && !badGradOType && !badBiasType; + return Status::OK(); } - - - - - +PLATFORM_CHECK(conv2d_bp, ENGINE_CUDA) { + auto input = + INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC] always + auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] + auto gradO = block.width() > 3 + ? INPUT_VARIABLE(3) + : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, + // oH, oW] (NCHW), epsilon_next + + const int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME, 2-CAUSAL + const int isNCHW = block.numI() > 9 + ? !INT_ARG(9) + : 1; // INT_ARG(9): 0-NCHW, 1-NHWC + + const bool badInputType = input->dataType() != DataType::DOUBLE && + input->dataType() != DataType::FLOAT32 && + input->dataType() != DataType::HALF; + const bool badWeightsType = weights->dataType() != DataType::DOUBLE && + weights->dataType() != DataType::FLOAT32 && + weights->dataType() != DataType::HALF; + const bool badGradOType = gradO->dataType() != DataType::DOUBLE && + gradO->dataType() != DataType::FLOAT32 && + gradO->dataType() != DataType::HALF; + const bool badBiasType = bias == nullptr + ? false + : (bias->dataType() != DataType::DOUBLE && + bias->dataType() != DataType::FLOAT32 && + bias->dataType() != DataType::HALF); + + return isNCHW && paddingMode != 2 && !badInputType && !badWeightsType && + !badGradOType && !badBiasType; +} // PLATFORM_IMPL(conv2d, ENGINE_CUDA) { -// auto handle = reinterpret_cast(block.launchContext()->getCuDnnHandle()); -// auto res = cudnnSetStream(*handle, *block.launchContext()->getCudaStream()); -// if (res != 0) +// auto handle = reinterpret_cast(block.launchContext()->getCuDnnHandle()); auto res = +// cudnnSetStream(*handle, *block.launchContext()->getCudaStream()); if (res +// != 0) // throw sd::cuda_exception::build("Can't set stream for cuDNN", res); -// auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) -// auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC] always -// auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] +// auto input = INPUT_VARIABLE(0); // +// [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) auto weights = +// INPUT_VARIABLE(1); // [kH, kW, iC, oC] +// always auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // +// [oC] -// auto output = OUTPUT_VARIABLE(0); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) +// auto output = OUTPUT_VARIABLE(0); // +// [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) // NDArray::prepareSpecialUse({output}, {input, weights, bias}); -// int sH = INT_ARG(2); // strides height -// int sW = INT_ARG(3); // strides width -// int pH = INT_ARG(4); // paddings height -// int pW = INT_ARG(5); // paddings width -// int dH = INT_ARG(6); // dilations height -// int dW = INT_ARG(7); // dilations width -// int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME -// bool isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC - -// int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(weights->sizeAt(0)); // filter(kernel) height -// int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast(weights->sizeAt(1)); // filter(kernel) width - -// int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; -// int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes -// ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); -// ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, isSameMode); +// int sH = INT_ARG(2); // strides height int sW = INT_ARG(3); // strides +// width int pH = INT_ARG(4); // paddings height int pW = INT_ARG(5); // +// paddings width int dH = INT_ARG(6); // dilations height int dW = +// INT_ARG(7); // +// dilations width int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME bool +// isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // +// INT_ARG(9): 0-NCHW, 1-NHWC + +// int kH = INT_ARG(0) > 0 ? INT_ARG(0) : +// static_cast(weights->sizeAt(0)); // filter(kernel) height int kW = +// INT_ARG(1) > 0 ? INT_ARG(1) : static_cast(weights->sizeAt(1)); // +// filter(kernel) width + +// int bS, iC, iH, iW, oC, oH, oW; // batch +// size, input channels, input height/width, output channels, output +// height/width; int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // +// corresponding indexes ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, +// *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, +// indWoC, indWkH, indOoH); ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, +// iH, iW, kH, kW, sH, sW, dH, dW, isSameMode); // auto dtype = cudnnDataType(input->dataType()); - // cudnnTensorDescriptor_t src; // cudnnCreateTensorDescriptor(&src); -// res = cudnnSetTensor4dDescriptorEx(src, dtype, input->sizeAt(0), input->sizeAt(1), input->sizeAt(2), input->sizeAt(3), input->strideAt(0), input->strideAt(1), input->strideAt(2), input->strideAt(3)); -// if (res != 0) -// throw sd::cuda_exception::build("cudnnSetTensor4dDescriptorEx src failed", res); +// res = cudnnSetTensor4dDescriptorEx(src, dtype, input->sizeAt(0), +// input->sizeAt(1), input->sizeAt(2), input->sizeAt(3), input->strideAt(0), +// input->strideAt(1), input->strideAt(2), input->strideAt(3)); if (res != +// 0) +// throw sd::cuda_exception::build("cudnnSetTensor4dDescriptorEx src +// failed", res); // // TODO: we definitely want NHWC here as well // cudnnFilterDescriptor_t wght; // cudnnCreateFilterDescriptor(&wght); -// res = cudnnSetFilter4dDescriptor(wght, dtype, CUDNN_TENSOR_NCHW, oC, iC, kH, kW); -// if (res != 0) -// throw sd::cuda_exception::build("cudnnSetFilter4dDescriptor failed", res); +// res = cudnnSetFilter4dDescriptor(wght, dtype, CUDNN_TENSOR_NCHW, oC, iC, +// kH, kW); if (res != 0) +// throw sd::cuda_exception::build("cudnnSetFilter4dDescriptor failed", +// res); // cudnnConvolutionDescriptor_t cdc; // cudnnCreateConvolutionDescriptor(&cdc); -// res = cudnnSetConvolution2dDescriptor(cdc, pH, pW, sH, sW, dH, dW, CUDNN_CROSS_CORRELATION, dtype); -// if (res != 0) -// throw sd::cuda_exception::build("cudnnSetConvolution2dDescriptor failed", res); +// res = cudnnSetConvolution2dDescriptor(cdc, pH, pW, sH, sW, dH, dW, +// CUDNN_CROSS_CORRELATION, dtype); if (res != 0) +// throw sd::cuda_exception::build("cudnnSetConvolution2dDescriptor +// failed", res); // cudnnTensorDescriptor_t dst; // cudnnCreateTensorDescriptor(&dst); -// res = cudnnSetTensor4dDescriptorEx(dst, dtype, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2), output->sizeAt(3), output->strideAt(0), output->strideAt(1), output->strideAt(2), output->strideAt(3)); -// if (res != 0) -// throw sd::cuda_exception::build("cudnnSetTensor4dDescriptorEx dst failed", res); - -// // TODO: workspace algorithms are supposed to be faster, so we should use it here if we have enough memory -// cudnnConvolutionFwdAlgo_t algo; -// res = cudnnGetConvolutionForwardAlgorithm(*handle, src, wght, cdc, dst, CUDNN_CONVOLUTION_FWD_NO_WORKSPACE, 0, &algo); -// if (res != 0) -// throw sd::cuda_exception::build("cudnnGetConvolutionForwardAlgorithm failed", res); +// res = cudnnSetTensor4dDescriptorEx(dst, dtype, output->sizeAt(0), +// output->sizeAt(1), output->sizeAt(2), output->sizeAt(3), +// output->strideAt(0), output->strideAt(1), output->strideAt(2), +// output->strideAt(3)); if (res != 0) +// throw sd::cuda_exception::build("cudnnSetTensor4dDescriptorEx dst +// failed", res); + +// // TODO: workspace algorithms are supposed to be faster, so we should use +// it here if we have enough memory cudnnConvolutionFwdAlgo_t algo; res = +// cudnnGetConvolutionForwardAlgorithm(*handle, src, wght, cdc, dst, +// CUDNN_CONVOLUTION_FWD_NO_WORKSPACE, 0, &algo); if (res != 0) +// throw sd::cuda_exception::build("cudnnGetConvolutionForwardAlgorithm +// failed", res); // // TODO: should be float if dtype is half/float, and double otherwise // float alpha = 1.0f; // float beta = 0.0f; -// res = cudnnConvolutionForward(*handle, &alpha, src, input->specialBuffer(), wght, weights->specialBuffer(), cdc, algo, nullptr, 0, &beta, dst, output->specialBuffer()); -// if (res != 0) -// throw sd::cuda_exception::build("cudnnConvolutionForward failed", res); - +// res = cudnnConvolutionForward(*handle, &alpha, src, +// input->specialBuffer(), wght, weights->specialBuffer(), cdc, algo, +// nullptr, 0, &beta, dst, output->specialBuffer()); if (res != 0) +// throw sd::cuda_exception::build("cudnnConvolutionForward failed", +// res); // if (bias != nullptr) { // cudnnTensorDescriptor_t bs; // cudnnCreateTensorDescriptor(&bs); // if (isNCHW) { -// res = cudnnSetTensor4dDescriptor(bs, CUDNN_TENSOR_NCHW, dtype, 1, bias->lengthOf(), 1, 1); -// if (res != 0) -// throw sd::cuda_exception::build("cudnnSetTensor4dDescriptorEx bias NHWC failed", res); +// res = cudnnSetTensor4dDescriptor(bs, CUDNN_TENSOR_NCHW, dtype, 1, +// bias->lengthOf(), 1, 1); if (res != 0) +// throw sd::cuda_exception::build("cudnnSetTensor4dDescriptorEx +// bias NHWC failed", res); // } else { -// res = cudnnSetTensor4dDescriptor(bs, CUDNN_TENSOR_NHWC, dtype, 1, 1, 1, bias->lengthOf()); -// if (res != 0) -// throw sd::cuda_exception::build("cudnnSetTensor4dDescriptorEx bias NHWC failed", res); +// res = cudnnSetTensor4dDescriptor(bs, CUDNN_TENSOR_NHWC, dtype, 1, +// 1, 1, bias->lengthOf()); if (res != 0) +// throw sd::cuda_exception::build("cudnnSetTensor4dDescriptorEx +// bias NHWC failed", res); // } -// res = cudnnAddTensor(*handle, &alpha, bs, bias->specialBuffer(), &alpha, dst, output->specialBuffer()); -// if (res != 0) +// res = cudnnAddTensor(*handle, &alpha, bs, bias->specialBuffer(), +// &alpha, dst, output->specialBuffer()); if (res != 0) // throw sd::cuda_exception::build("cudnnAddTensor failed", res); // } - // NDArray::registerSpecialUse({output}, {input, weights, bias}); // return Status::OK(); // } - -} -} -} +} // namespace platforms +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/platform/cudnn/conv3d.cu b/libnd4j/include/ops/declarable/platform/cudnn/conv3d.cu index 693ebeefa053..dfd63d957eee 100644 --- a/libnd4j/include/ops/declarable/platform/cudnn/conv3d.cu +++ b/libnd4j/include/ops/declarable/platform/cudnn/conv3d.cu @@ -19,453 +19,758 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // +#include #include "cudnnUtils.h" -#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace platforms { ////////////////////////////////////////////////////////////////////////// -static void conv3dCUDNN(const LaunchContext* context, - const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, - const int kD, const int kH, const int kW, - const int sD, const int sH, const int sW, - const int pD, const int pH, const int pW, - const int dD, const int dH, const int dW, - const int paddingMode, const bool isNCDHW, const int wFormat) { - - // cudnn support only one format for weights {oC,iC,kD,kH,kW} - - const int numDims = 5; - - int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; - int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, wFormat, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); - - auto handle = reinterpret_cast(context->getCuDnnHandle()); - cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream()); - if (err != 0) throw sd::cuda_exception::build("conv3dCUDNN: can't set stream for cuDNN", err); - - const std::vector pads = {pD, pH, pW}; - const std::vector filtStrides = {sD, sH, sW}; - const std::vector dilations = {dD, dH, dW}; - - const std::vector xShape = {bS, iC, iD, iH, iW}; - const std::vector zShape = {bS, oC, oD, oH, oW}; - const std::vector wShape = {oC, iC, kD, kH, kW}; - const std::vector bShape = {1, oC, 1, 1, 1}; // {1, (isNCDHW ? oC : 1), 1, 1, (isNCDHW ? 1 : oC)}; - - const std::vector xStrides = {(int)input->strideAt(0), (int)input->strideAt(1), (int)input->strideAt(2), (int)input->strideAt(3), (int)input->strideAt(4)}; - const std::vector zStrides = {(int)output->strideAt(0), (int)output->strideAt(1), (int)output->strideAt(2), (int)output->strideAt(3), (int)output->strideAt(4)}; - - cudnnTensorFormat_t format = isNCDHW ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC; - - // input descriptor - cudnnTensorDescriptor_t x; - cudnnCreateTensorDescriptor(&x); - if(input->ews() == 1) - err = cudnnSetTensorNdDescriptorEx(x, format, cudnnDataType(input->dataType()), numDims, xShape.data()); - else - err = cudnnSetTensorNdDescriptor(x, cudnnDataType(input->dataType()), numDims, xShape.data(), xStrides.data()); - if (err != 0) throw sd::cuda_exception::build("conv3dCUDNN: cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for input failed", err); - - // weights descriptor - cudnnFilterDescriptor_t w; - cudnnCreateFilterDescriptor(&w); - err = cudnnSetFilterNdDescriptor(w, cudnnDataType(weights->dataType()), CUDNN_TENSOR_NCHW, numDims, wShape.data()); - if(err != 0) throw sd::cuda_exception::build("conv3dCUDNN: cudnnSetFilterNdDescriptor failed", err); - - // output descriptor - cudnnTensorDescriptor_t z; - cudnnCreateTensorDescriptor(&z); - if(output->ews() == 1) - err = cudnnSetTensorNdDescriptorEx(z, format, cudnnDataType(output->dataType()), numDims, zShape.data()); - else - err = cudnnSetTensorNdDescriptor(z, cudnnDataType(output->dataType()), numDims, zShape.data(), zStrides.data()); - if (err != 0) throw sd::cuda_exception::build("conv3dCUDNN: cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for output failed", err); - - // description of convolution - cudnnConvolutionDescriptor_t conv; - cudnnCreateConvolutionDescriptor(&conv); - err = cudnnSetConvolutionNdDescriptor(conv, numDims-2, pads.data(), filtStrides.data(), dilations.data(), CUDNN_CROSS_CORRELATION, cudnnDataType(output->dataType())); - if (err != 0) throw sd::cuda_exception::build("conv3dCUDNN: cudnnSetConvolutionNdDescriptor failed", err); - - // algorithm description - cudnnConvolutionFwdAlgo_t algo; - err = cudnnGetConvolutionForwardAlgorithm(*handle, x, w, conv, z, CUDNN_CONVOLUTION_FWD_PREFER_FASTEST, 0, &algo); - if (err != 0) throw sd::cuda_exception::build("conv3dCUDNN: cudnnGetConvolutionForwardAlgorithm failed", err); - - // allocate auxiliary device memory, abbreviation ws means workspace - size_t wsSize; - err = cudnnGetConvolutionForwardWorkspaceSize(*handle, x, w, conv, z, algo, &wsSize); - if (err != 0) throw sd::cuda_exception::build("conv3dCUDNN: cudnnGetConvolutionForwardWorkspaceSize failed", err); - void* wsData; - auto cudaErr = cudaMalloc(&wsData, wsSize); - if (cudaErr != 0) throw sd::cuda_exception::build("conv3dCUDNN: cudaMalloc for auxiliary workspace memory failed", cudaErr); - - // provide scaling parameters - const float alpha32(1), beta32(0); - const double alpha64(1), beta64(0); - const void* alpha = output->sizeOfT() <= 4 ? reinterpret_cast(&alpha32) : reinterpret_cast(&alpha64); - const void* beta = output->sizeOfT() <= 4 ? reinterpret_cast(&beta32) : reinterpret_cast(&beta64); - - NDArray::prepareSpecialUse({output}, {input, weights, bias}); - - // run calculation - err = cudnnConvolutionForward(*handle, alpha, x, input->specialBuffer(), w, weights->specialBuffer(), conv, algo, wsData, wsSize, beta, z, output->specialBuffer()); - if (err != 0) throw sd::cuda_exception::build("conv3dCUDNN: cudnnConvolutionForward failed", err); - - // add bias if it is present - if (bias != nullptr) { - - cudnnTensorDescriptor_t b; - cudnnCreateTensorDescriptor(&b); - err = cudnnSetTensorNdDescriptorEx(b, /*format*/CUDNN_TENSOR_NCHW, cudnnDataType(bias->dataType()), numDims, bShape.data()); - if (err != 0) throw sd::cuda_exception::build("conv3dCUDNN: cudnnSetTensorNdDescriptor for bias failed", err); - err = cudnnAddTensor(*handle, alpha, b, bias->specialBuffer(), alpha, z, output->specialBuffer()); - if (err != 0) throw sd::cuda_exception::build("conv3dCUDNN: cudnnAddTensor bias failed", err); - } - - // cudaErr = cudaStreamSynchronize(*context->getCudaStream()); - // if (cudaErr != 0) - // throw cuda_exception::build("conv3dCUDNN: cudaStreamSynchronize failed !", cudaErr); - - cudaErr = cudaFree(wsData); - if (cudaErr != 0) throw sd::cuda_exception::build("conv3dCUDNN: cudaFree for auxiliary workspace memory failed", cudaErr); - - NDArray::registerSpecialUse({output}, {input, weights, bias}); +static void conv3dCUDNN(const LaunchContext* context, const NDArray* input, + const NDArray* weights, const NDArray* bias, + NDArray* output, const int kD, const int kH, + const int kW, const int sD, const int sH, const int sW, + const int pD, const int pH, const int pW, const int dD, + const int dH, const int dW, const int paddingMode, + const bool isNCDHW, const int wFormat) { + // cudnn support only one format for weights {oC,iC,kD,kH,kW} + + const int numDims = 5; + + int bS, iC, iD, iH, iW, oC, oD, oH, + oW; // batch size, input channels, input depth/height/width, output + // channels, output depth/height/width; + int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv3d( + isNCDHW, wFormat, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, + indIOioC, indIOioD, indWiC, indWoC, indWkD); + + auto handle = reinterpret_cast(context->getCuDnnHandle()); + cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream()); + if (err != 0) + throw sd::cuda_exception::build("conv3dCUDNN: can't set stream for cuDNN", + err); + + const std::vector pads = {pD, pH, pW}; + const std::vector filtStrides = {sD, sH, sW}; + const std::vector dilations = {dD, dH, dW}; + + const std::vector xShape = {bS, iC, iD, iH, iW}; + const std::vector zShape = {bS, oC, oD, oH, oW}; + const std::vector wShape = {oC, iC, kD, kH, kW}; + const std::vector bShape = { + 1, oC, 1, 1, 1}; // {1, (isNCDHW ? oC : 1), 1, 1, (isNCDHW ? 1 : oC)}; + + const std::vector xStrides = { + (int)input->strideAt(0), (int)input->strideAt(1), (int)input->strideAt(2), + (int)input->strideAt(3), (int)input->strideAt(4)}; + const std::vector zStrides = { + (int)output->strideAt(0), (int)output->strideAt(1), + (int)output->strideAt(2), (int)output->strideAt(3), + (int)output->strideAt(4)}; + + cudnnTensorFormat_t format = isNCDHW ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC; + + // input descriptor + cudnnTensorDescriptor_t x; + cudnnCreateTensorDescriptor(&x); + if (input->ews() == 1) + err = cudnnSetTensorNdDescriptorEx( + x, format, cudnnDataType(input->dataType()), numDims, xShape.data()); + else + err = cudnnSetTensorNdDescriptor(x, cudnnDataType(input->dataType()), + numDims, xShape.data(), xStrides.data()); + if (err != 0) + throw sd::cuda_exception::build( + "conv3dCUDNN: cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx " + "for input failed", + err); + + // weights descriptor + cudnnFilterDescriptor_t w; + cudnnCreateFilterDescriptor(&w); + err = cudnnSetFilterNdDescriptor(w, cudnnDataType(weights->dataType()), + CUDNN_TENSOR_NCHW, numDims, wShape.data()); + if (err != 0) + throw sd::cuda_exception::build( + "conv3dCUDNN: cudnnSetFilterNdDescriptor failed", err); + + // output descriptor + cudnnTensorDescriptor_t z; + cudnnCreateTensorDescriptor(&z); + if (output->ews() == 1) + err = cudnnSetTensorNdDescriptorEx( + z, format, cudnnDataType(output->dataType()), numDims, zShape.data()); + else + err = cudnnSetTensorNdDescriptor(z, cudnnDataType(output->dataType()), + numDims, zShape.data(), zStrides.data()); + if (err != 0) + throw sd::cuda_exception::build( + "conv3dCUDNN: cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx " + "for output failed", + err); + + // description of convolution + cudnnConvolutionDescriptor_t conv; + cudnnCreateConvolutionDescriptor(&conv); + err = cudnnSetConvolutionNdDescriptor( + conv, numDims - 2, pads.data(), filtStrides.data(), dilations.data(), + CUDNN_CROSS_CORRELATION, cudnnDataType(output->dataType())); + if (err != 0) + throw sd::cuda_exception::build( + "conv3dCUDNN: cudnnSetConvolutionNdDescriptor failed", err); + + // algorithm description + cudnnConvolutionFwdAlgo_t algo; + err = cudnnGetConvolutionForwardAlgorithm( + *handle, x, w, conv, z, CUDNN_CONVOLUTION_FWD_PREFER_FASTEST, 0, &algo); + if (err != 0) + throw sd::cuda_exception::build( + "conv3dCUDNN: cudnnGetConvolutionForwardAlgorithm failed", err); + + // allocate auxiliary device memory, abbreviation ws means workspace + size_t wsSize; + err = cudnnGetConvolutionForwardWorkspaceSize(*handle, x, w, conv, z, algo, + &wsSize); + if (err != 0) + throw sd::cuda_exception::build( + "conv3dCUDNN: cudnnGetConvolutionForwardWorkspaceSize failed", err); + void* wsData; + auto cudaErr = cudaMalloc(&wsData, wsSize); + if (cudaErr != 0) + throw sd::cuda_exception::build( + "conv3dCUDNN: cudaMalloc for auxiliary workspace memory failed", + cudaErr); + + // provide scaling parameters + const float alpha32(1), beta32(0); + const double alpha64(1), beta64(0); + const void* alpha = output->sizeOfT() <= 4 + ? reinterpret_cast(&alpha32) + : reinterpret_cast(&alpha64); + const void* beta = output->sizeOfT() <= 4 + ? reinterpret_cast(&beta32) + : reinterpret_cast(&beta64); + + NDArray::prepareSpecialUse({output}, {input, weights, bias}); + + // run calculation + err = cudnnConvolutionForward(*handle, alpha, x, input->specialBuffer(), w, + weights->specialBuffer(), conv, algo, wsData, + wsSize, beta, z, output->specialBuffer()); + if (err != 0) + throw sd::cuda_exception::build( + "conv3dCUDNN: cudnnConvolutionForward failed", err); + + // add bias if it is present + if (bias != nullptr) { + cudnnTensorDescriptor_t b; + cudnnCreateTensorDescriptor(&b); + err = cudnnSetTensorNdDescriptorEx(b, /*format*/ CUDNN_TENSOR_NCHW, + cudnnDataType(bias->dataType()), numDims, + bShape.data()); + if (err != 0) + throw sd::cuda_exception::build( + "conv3dCUDNN: cudnnSetTensorNdDescriptor for bias failed", err); + err = cudnnAddTensor(*handle, alpha, b, bias->specialBuffer(), alpha, z, + output->specialBuffer()); + if (err != 0) + throw sd::cuda_exception::build("conv3dCUDNN: cudnnAddTensor bias failed", + err); + } + + // cudaErr = cudaStreamSynchronize(*context->getCudaStream()); + // if (cudaErr != 0) + // throw cuda_exception::build("conv3dCUDNN: cudaStreamSynchronize failed + // !", cudaErr); + + cudaErr = cudaFree(wsData); + if (cudaErr != 0) + throw sd::cuda_exception::build( + "conv3dCUDNN: cudaFree for auxiliary workspace memory failed", cudaErr); + + NDArray::registerSpecialUse({output}, {input, weights, bias}); } ////////////////////////////////////////////////////////////////////////// -static void conv3dBpCUDNN(const LaunchContext* context, - const NDArray* input, const NDArray* weights, const NDArray* gradO, +static void conv3dBpCUDNN(const LaunchContext* context, const NDArray* input, + const NDArray* weights, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, - const int paddingMode, const bool isNCDHW, const int wFormat) { - - // cudnn supports only two formats {oC,iC,kD,kH,kW} and {oC,kD,kH,kW,iC} for weights/gradW - - const int numDims = 5; - - int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; - int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, wFormat, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); - - auto handle = reinterpret_cast(context->getCuDnnHandle()); - cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream()); - if (err != 0) throw sd::cuda_exception::build("conv3dBpCUDNN: can't set stream for cuDNN", err); - - const std::vector pads = {pD, pH, pW}; - const std::vector filtStrides = {sD, sH, sW}; - const std::vector dilations = {dD, dH, dW}; - - const std::vector xShape = {bS, iC, iD, iH, iW}; - const std::vector dzShape = {bS, oC, oD, oH, oW}; - const std::vector wShape = {oC, iC, kD, kH, kW}; - const std::vector dbShape = {1, (int)(isNCDHW ? oC : 1), 1, 1, (int)(isNCDHW ? 1 : oC)}; - - const std::vector xStrides = {(int)input->strideAt(0), (int)input->strideAt(1), (int)input->strideAt(2), (int)input->strideAt(3), (int)input->strideAt(4)}; - const std::vector dxStrides = {(int)gradI->strideAt(0), (int)gradI->strideAt(1), (int)gradI->strideAt(2), (int)gradI->strideAt(3), (int)gradI->strideAt(4)}; - const std::vector dzStrides = {(int)gradO->strideAt(0), (int)gradO->strideAt(1), (int)gradO->strideAt(2), (int)gradO->strideAt(3), (int)gradO->strideAt(4)}; - - cudnnTensorFormat_t format = isNCDHW ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC; - cudnnTensorFormat_t formatW = 0 == wFormat ? format : (1 == wFormat ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC); - - // input descriptor - cudnnTensorDescriptor_t x; - cudnnCreateTensorDescriptor(&x); - if(input->ews() == 1) - err = cudnnSetTensorNdDescriptorEx(x, format, cudnnDataType(input->dataType()), numDims, xShape.data()); - else - err = cudnnSetTensorNdDescriptor(x, cudnnDataType(input->dataType()), numDims, xShape.data(), xStrides.data()); - if (err != 0) throw sd::cuda_exception::build("conv3dBpCUDNN: cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for input failed", err); - - // gradO descriptor - cudnnTensorDescriptor_t dz; - cudnnCreateTensorDescriptor(&dz); - if(gradO->ews() == 1) - err = cudnnSetTensorNdDescriptorEx(dz, format, cudnnDataType(gradO->dataType()), numDims, dzShape.data()); - else - err = cudnnSetTensorNdDescriptor(dz, cudnnDataType(gradO->dataType()), numDims, dzShape.data(), dzStrides.data()); - if (err != 0) throw sd::cuda_exception::build("conv3dBpCUDNN: cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for gradO failed", err); - - // gradI descriptor - cudnnTensorDescriptor_t dx; - cudnnCreateTensorDescriptor(&dx); - if(gradI->ews() == 1) - err = cudnnSetTensorNdDescriptorEx(dx, format, cudnnDataType(gradI->dataType()), numDims, xShape.data()); - else - err = cudnnSetTensorNdDescriptor(dx, cudnnDataType(gradI->dataType()), numDims, xShape.data(), dxStrides.data()); - if (err != 0) throw sd::cuda_exception::build("conv3dBpCUDNN: cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for gradI failed", err); - - // gradW descriptor - cudnnFilterDescriptor_t dw; - cudnnCreateFilterDescriptor(&dw); - err = cudnnSetFilterNdDescriptor(dw, cudnnDataType(gradW->dataType()), formatW, numDims, wShape.data()); - if(err != 0) throw sd::cuda_exception::build("conv3dBpCUDNN: cudnnSetFilterNdDescriptor failed", err); - - // description of convolution - cudnnConvolutionDescriptor_t conv; - cudnnCreateConvolutionDescriptor(&conv); - err = cudnnSetConvolutionNdDescriptor(conv, numDims-2, pads.data(), filtStrides.data(), dilations.data(), CUDNN_CROSS_CORRELATION, cudnnDataType(gradO->dataType())); - if (err != 0) throw sd::cuda_exception::build("conv3dBpCUDNN: cudnnSetConvolutionNdDescriptor failed", err); - - // gradW algorithm description - cudnnConvolutionBwdFilterAlgo_t algoGradW; - err = cudnnGetConvolutionBackwardFilterAlgorithm(*handle, x, dz, conv, dw, CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST, 0, &algoGradW); - if (err != 0) throw sd::cuda_exception::build("conv3dBpCUDNN: cudnnGetConvolutionBackwardFilterAlgorithm failed", err); - - // gradI algorithm description - cudnnConvolutionBwdDataAlgo_t algoGradI; - err = cudnnGetConvolutionBackwardDataAlgorithm(*handle, dw, dz, conv, x, CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST, 0, &algoGradI); - if (err != 0) throw sd::cuda_exception::build("conv3dBpCUDNN: cudnnGetConvolutionBackwardDataAlgorithm failed", err); - - // allocate auxiliary device memory for gradW calculation, abbreviation ws means workspace - size_t wsGradWSize; - err = cudnnGetConvolutionBackwardFilterWorkspaceSize(*handle, x, dz, conv, dw, algoGradW, &wsGradWSize); - if (err != 0) throw sd::cuda_exception::build("conv3dBpCUDNN: cudnnGetConvolutionBackwardFilterWorkspaceSize failed", err); - void* wsGradWData; - auto cudaErr = cudaMalloc(&wsGradWData, wsGradWSize); - if (cudaErr != 0) throw sd::cuda_exception::build("conv3dBpCUDNN: cudaMalloc for auxiliary workspace memory wsGradWData failed", cudaErr); - - // allocate auxiliary device memory for gradI calculation, abbreviation ws means workspace - size_t wsGradISize; - err = cudnnGetConvolutionBackwardDataWorkspaceSize(*handle, dw, dz, conv, dx, algoGradI, &wsGradISize); - if (err != 0) throw sd::cuda_exception::build("conv3dBpCUDNN: cudnnGetConvolutionBackwardDataWorkspaceSize failed", err); - void* wsGradIData; - cudaErr = cudaMalloc(&wsGradIData, wsGradISize); - if (cudaErr != 0) throw sd::cuda_exception::build("conv3dBpCUDNN: cudaMalloc for auxiliary workspace memory wsGradIData failed", cudaErr); - - // provide scaling parameters - const float alpha32(1), beta32(0); - const double alpha64(1), beta64(0); - const void* alpha = gradO->sizeOfT() <= 4 ? reinterpret_cast(&alpha32) : reinterpret_cast(&alpha64); - const void* beta = gradO->sizeOfT() <= 4 ? reinterpret_cast(&beta32) : reinterpret_cast(&beta64); - - NDArray::prepareSpecialUse({gradI, gradW, gradB}, {input, weights, gradO}); - - // run calculation for gradB (if not nullptr) - if(gradB != nullptr) { - - cudnnTensorDescriptor_t db; - cudnnCreateTensorDescriptor(&db); - err = cudnnSetTensorNdDescriptorEx(db, format, cudnnDataType(gradB->dataType()), numDims, dbShape.data()); - if (err != 0) throw sd::cuda_exception::build("conv3dBpCUDNN: cudnnSetTensorNdDescriptor for gradB failed", err); - - err = cudnnConvolutionBackwardBias(*handle, alpha, dz, gradO->specialBuffer(), beta, db, gradB->specialBuffer()); - if (err != 0) throw sd::cuda_exception::build("conv3dBpCUDNN: cudnnConvolutionBackwardBias failed", err); - } - - // run calculation for gradW - err = cudnnConvolutionBackwardFilter(*handle, alpha, x, input->specialBuffer(), dz, gradO->specialBuffer(), conv, algoGradW, wsGradWData, wsGradWSize, beta, dw, gradW->specialBuffer()); - if (err != 0) throw sd::cuda_exception::build("conv3dBpCUDNN: cudnnConvolutionBackwardFilter failed", err); - - // run calculation for gradI - err = cudnnConvolutionBackwardData(*handle, alpha, dw, weights->specialBuffer(), dz, gradO->specialBuffer(), conv, algoGradI, wsGradIData, wsGradISize, beta, dx, gradI->specialBuffer()); - if (err != 0) throw sd::cuda_exception::build("conv3dBpCUDNN: cudnnConvolutionBackwardData failed", err); - - // cudaErr = cudaStreamSynchronize(*context->getCudaStream()); - // if (cudaErr != 0) - // throw cuda_exception::build("conv3dBpCUDNN: cudaStreamSynchronize failed !", cudaErr); - - cudaErr = cudaFree(wsGradWData); - if (cudaErr != 0) throw sd::cuda_exception::build("conv3dBpCUDNN: cudaFree for auxiliary workspace memory wsGradWData failed", cudaErr); - cudaErr = cudaFree(wsGradIData); - if (cudaErr != 0) throw sd::cuda_exception::build("conv3dBpCUDNN: cudaFree for auxiliary workspace memory wsGradIData failed", cudaErr); - - NDArray::registerSpecialUse({gradI, gradW, gradB}, {input, weights, gradO}); + const int paddingMode, const bool isNCDHW, + const int wFormat) { + // cudnn supports only two formats {oC,iC,kD,kH,kW} and {oC,kD,kH,kW,iC} for + // weights/gradW + + const int numDims = 5; + + int bS, iC, iD, iH, iW, oC, oD, oH, + oW; // batch size, input channels, input depth/height/width, output + // channels, output depth/height/width; + int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv3d( + isNCDHW, wFormat, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, + indIOioC, indIOioD, indWiC, indWoC, indWkD); + + auto handle = reinterpret_cast(context->getCuDnnHandle()); + cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream()); + if (err != 0) + throw sd::cuda_exception::build("conv3dBpCUDNN: can't set stream for cuDNN", + err); + + const std::vector pads = {pD, pH, pW}; + const std::vector filtStrides = {sD, sH, sW}; + const std::vector dilations = {dD, dH, dW}; + + const std::vector xShape = {bS, iC, iD, iH, iW}; + const std::vector dzShape = {bS, oC, oD, oH, oW}; + const std::vector wShape = {oC, iC, kD, kH, kW}; + const std::vector dbShape = {1, (int)(isNCDHW ? oC : 1), 1, 1, + (int)(isNCDHW ? 1 : oC)}; + + const std::vector xStrides = { + (int)input->strideAt(0), (int)input->strideAt(1), (int)input->strideAt(2), + (int)input->strideAt(3), (int)input->strideAt(4)}; + const std::vector dxStrides = { + (int)gradI->strideAt(0), (int)gradI->strideAt(1), (int)gradI->strideAt(2), + (int)gradI->strideAt(3), (int)gradI->strideAt(4)}; + const std::vector dzStrides = { + (int)gradO->strideAt(0), (int)gradO->strideAt(1), (int)gradO->strideAt(2), + (int)gradO->strideAt(3), (int)gradO->strideAt(4)}; + + cudnnTensorFormat_t format = isNCDHW ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC; + cudnnTensorFormat_t formatW = + 0 == wFormat ? format + : (1 == wFormat ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC); + + // input descriptor + cudnnTensorDescriptor_t x; + cudnnCreateTensorDescriptor(&x); + if (input->ews() == 1) + err = cudnnSetTensorNdDescriptorEx( + x, format, cudnnDataType(input->dataType()), numDims, xShape.data()); + else + err = cudnnSetTensorNdDescriptor(x, cudnnDataType(input->dataType()), + numDims, xShape.data(), xStrides.data()); + if (err != 0) + throw sd::cuda_exception::build( + "conv3dBpCUDNN: " + "cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for input " + "failed", + err); + + // gradO descriptor + cudnnTensorDescriptor_t dz; + cudnnCreateTensorDescriptor(&dz); + if (gradO->ews() == 1) + err = cudnnSetTensorNdDescriptorEx( + dz, format, cudnnDataType(gradO->dataType()), numDims, dzShape.data()); + else + err = cudnnSetTensorNdDescriptor(dz, cudnnDataType(gradO->dataType()), + numDims, dzShape.data(), dzStrides.data()); + if (err != 0) + throw sd::cuda_exception::build( + "conv3dBpCUDNN: " + "cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for gradO " + "failed", + err); + + // gradI descriptor + cudnnTensorDescriptor_t dx; + cudnnCreateTensorDescriptor(&dx); + if (gradI->ews() == 1) + err = cudnnSetTensorNdDescriptorEx( + dx, format, cudnnDataType(gradI->dataType()), numDims, xShape.data()); + else + err = cudnnSetTensorNdDescriptor(dx, cudnnDataType(gradI->dataType()), + numDims, xShape.data(), dxStrides.data()); + if (err != 0) + throw sd::cuda_exception::build( + "conv3dBpCUDNN: " + "cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for gradI " + "failed", + err); + + // gradW descriptor + cudnnFilterDescriptor_t dw; + cudnnCreateFilterDescriptor(&dw); + err = cudnnSetFilterNdDescriptor(dw, cudnnDataType(gradW->dataType()), + formatW, numDims, wShape.data()); + if (err != 0) + throw sd::cuda_exception::build( + "conv3dBpCUDNN: cudnnSetFilterNdDescriptor failed", err); + + // description of convolution + cudnnConvolutionDescriptor_t conv; + cudnnCreateConvolutionDescriptor(&conv); + err = cudnnSetConvolutionNdDescriptor( + conv, numDims - 2, pads.data(), filtStrides.data(), dilations.data(), + CUDNN_CROSS_CORRELATION, cudnnDataType(gradO->dataType())); + if (err != 0) + throw sd::cuda_exception::build( + "conv3dBpCUDNN: cudnnSetConvolutionNdDescriptor failed", err); + + // gradW algorithm description + cudnnConvolutionBwdFilterAlgo_t algoGradW; + err = cudnnGetConvolutionBackwardFilterAlgorithm( + *handle, x, dz, conv, dw, CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST, 0, + &algoGradW); + if (err != 0) + throw sd::cuda_exception::build( + "conv3dBpCUDNN: cudnnGetConvolutionBackwardFilterAlgorithm failed", + err); + + // gradI algorithm description + cudnnConvolutionBwdDataAlgo_t algoGradI; + err = cudnnGetConvolutionBackwardDataAlgorithm( + *handle, dw, dz, conv, x, CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST, 0, + &algoGradI); + if (err != 0) + throw sd::cuda_exception::build( + "conv3dBpCUDNN: cudnnGetConvolutionBackwardDataAlgorithm failed", err); + + // allocate auxiliary device memory for gradW calculation, abbreviation ws + // means workspace + size_t wsGradWSize; + err = cudnnGetConvolutionBackwardFilterWorkspaceSize(*handle, x, dz, conv, dw, + algoGradW, &wsGradWSize); + if (err != 0) + throw sd::cuda_exception::build( + "conv3dBpCUDNN: cudnnGetConvolutionBackwardFilterWorkspaceSize failed", + err); + void* wsGradWData; + auto cudaErr = cudaMalloc(&wsGradWData, wsGradWSize); + if (cudaErr != 0) + throw sd::cuda_exception::build( + "conv3dBpCUDNN: cudaMalloc for auxiliary workspace memory wsGradWData " + "failed", + cudaErr); + + // allocate auxiliary device memory for gradI calculation, abbreviation ws + // means workspace + size_t wsGradISize; + err = cudnnGetConvolutionBackwardDataWorkspaceSize(*handle, dw, dz, conv, dx, + algoGradI, &wsGradISize); + if (err != 0) + throw sd::cuda_exception::build( + "conv3dBpCUDNN: cudnnGetConvolutionBackwardDataWorkspaceSize failed", + err); + void* wsGradIData; + cudaErr = cudaMalloc(&wsGradIData, wsGradISize); + if (cudaErr != 0) + throw sd::cuda_exception::build( + "conv3dBpCUDNN: cudaMalloc for auxiliary workspace memory wsGradIData " + "failed", + cudaErr); + + // provide scaling parameters + const float alpha32(1), beta32(0); + const double alpha64(1), beta64(0); + const void* alpha = gradO->sizeOfT() <= 4 + ? reinterpret_cast(&alpha32) + : reinterpret_cast(&alpha64); + const void* beta = gradO->sizeOfT() <= 4 + ? reinterpret_cast(&beta32) + : reinterpret_cast(&beta64); + + NDArray::prepareSpecialUse({gradI, gradW, gradB}, {input, weights, gradO}); + + // run calculation for gradB (if not nullptr) + if (gradB != nullptr) { + cudnnTensorDescriptor_t db; + cudnnCreateTensorDescriptor(&db); + err = cudnnSetTensorNdDescriptorEx( + db, format, cudnnDataType(gradB->dataType()), numDims, dbShape.data()); + if (err != 0) + throw sd::cuda_exception::build( + "conv3dBpCUDNN: cudnnSetTensorNdDescriptor for gradB failed", err); + + err = + cudnnConvolutionBackwardBias(*handle, alpha, dz, gradO->specialBuffer(), + beta, db, gradB->specialBuffer()); + if (err != 0) + throw sd::cuda_exception::build( + "conv3dBpCUDNN: cudnnConvolutionBackwardBias failed", err); + } + + // run calculation for gradW + err = cudnnConvolutionBackwardFilter( + *handle, alpha, x, input->specialBuffer(), dz, gradO->specialBuffer(), + conv, algoGradW, wsGradWData, wsGradWSize, beta, dw, + gradW->specialBuffer()); + if (err != 0) + throw sd::cuda_exception::build( + "conv3dBpCUDNN: cudnnConvolutionBackwardFilter failed", err); + + // run calculation for gradI + err = cudnnConvolutionBackwardData( + *handle, alpha, dw, weights->specialBuffer(), dz, gradO->specialBuffer(), + conv, algoGradI, wsGradIData, wsGradISize, beta, dx, + gradI->specialBuffer()); + if (err != 0) + throw sd::cuda_exception::build( + "conv3dBpCUDNN: cudnnConvolutionBackwardData failed", err); + + // cudaErr = cudaStreamSynchronize(*context->getCudaStream()); + // if (cudaErr != 0) + // throw cuda_exception::build("conv3dBpCUDNN: cudaStreamSynchronize + // failed !", cudaErr); + + cudaErr = cudaFree(wsGradWData); + if (cudaErr != 0) + throw sd::cuda_exception::build( + "conv3dBpCUDNN: cudaFree for auxiliary workspace memory wsGradWData " + "failed", + cudaErr); + cudaErr = cudaFree(wsGradIData); + if (cudaErr != 0) + throw sd::cuda_exception::build( + "conv3dBpCUDNN: cudaFree for auxiliary workspace memory wsGradIData " + "failed", + cudaErr); + + NDArray::registerSpecialUse({gradI, gradW, gradB}, {input, weights, gradO}); } ////////////////////////////////////////////////////////////////////////// PLATFORM_IMPL(conv3dnew, ENGINE_CUDA) { - - auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) - auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC] - auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] - auto output = OUTPUT_VARIABLE(0); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW) - - REQUIRE_TRUE(input->rankOf() == 5, 0, "CONV3D CUDNN OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf()); - REQUIRE_TRUE(weights->rankOf() == 5, 0, "CONV3D CUDNN OP: rank of weights array must be equal to 5, but got %i instead !", weights->rankOf()); - - int kD = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(weights->sizeAt(0));// filter(kernel) depth - int kH = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast(weights->sizeAt(1));// filter(kernel) height - int kW = INT_ARG(2) > 0 ? INT_ARG(2) : static_cast(weights->sizeAt(2));// filter(kernel) width - int sD = INT_ARG(3); // strides depth - int sH = INT_ARG(4); // strides height - int sW = INT_ARG(5); // strides width - int pD = INT_ARG(6); // paddings depth - int pH = INT_ARG(7); // paddings height - int pW = INT_ARG(8); // paddings width - int dD = INT_ARG(9); // dilations depth - int dH = INT_ARG(10); // dilations height - int dW = INT_ARG(11); // dilations width - int paddingMode = INT_ARG(12); // 0-SAME, 1-VALID - int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW - int wFormat = block.getIArguments()->size() > 14 ? INT_ARG(14) : 0; // 0-[kD, kH, kW, iC, oC], 1-[oC, iC, kD, kH, kW], 2-[oC, kD, kH, kW, iC] - - REQUIRE_TRUE(paddingMode < 2, 0, "CONV3D CUDNN OP: causal padding mode (paddingMode = 2) is not allowed for this operation !"); - - int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; - int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, wFormat, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); - - ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW, paddingMode); - - std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, iC, oC); - REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CONV3D CUDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); - if (bias) - REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CONV3D CUDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); - - NDArray* newWeights = weights; // cudnn support only one format {oC,iC,kD,kH,kW} - if(1 != wFormat) { - newWeights = new NDArray(weights->ordering(), {oC, iC, kD, kH, kW}, weights->dataType(), weights->getContext()); - newWeights->assign(weights->permute(0 == wFormat ? std::vector({4,3,0,1,2}) : std::vector({0,4,1,2,3}))); // kD, kH, kW, iC, oC --> oC, iC, kD, kH, kW or oC, kD, kH, kW, iC --> oC, iC, kD, kH, kW - } - - NDArray* newInput = input; - NDArray* newGradI = nullptr; - if(paddingMode == 1) // in same paddingMode cudnn doesn't support asymmetric left/right top/bottopm paddings - checkConv3dCUDNNPadAsymmetric(newInput, newGradI, iD, iH, iW, oD, oH, oW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isNCDHW); - - conv3dCUDNN(block.launchContext(), newInput, newWeights, bias, output, kD,kH,kW,sD,sH,sW,pD,pH,pW,dD,dH,dW, paddingMode, isNCDHW, wFormat); - - if(newInput != input) - delete newInput; - - if(1 != wFormat) - delete newWeights; - - return Status::OK(); + auto input = INPUT_VARIABLE( + 0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) + auto weights = INPUT_VARIABLE( + 1); // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC] + auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] + auto output = OUTPUT_VARIABLE( + 0); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW) + + REQUIRE_TRUE(input->rankOf() == 5, 0, + "CONV3D CUDNN OP: rank of input array must be equal to 5, but " + "got %i instead !", + input->rankOf()); + REQUIRE_TRUE(weights->rankOf() == 5, 0, + "CONV3D CUDNN OP: rank of weights array must be equal to 5, but " + "got %i instead !", + weights->rankOf()); + + int kD = INT_ARG(0) > 0 + ? INT_ARG(0) + : static_cast(weights->sizeAt(0)); // filter(kernel) depth + int kH = INT_ARG(1) > 0 + ? INT_ARG(1) + : static_cast(weights->sizeAt(1)); // filter(kernel) height + int kW = INT_ARG(2) > 0 + ? INT_ARG(2) + : static_cast(weights->sizeAt(2)); // filter(kernel) width + int sD = INT_ARG(3); // strides depth + int sH = INT_ARG(4); // strides height + int sW = INT_ARG(5); // strides width + int pD = INT_ARG(6); // paddings depth + int pH = INT_ARG(7); // paddings height + int pW = INT_ARG(8); // paddings width + int dD = INT_ARG(9); // dilations depth + int dH = INT_ARG(10); // dilations height + int dW = INT_ARG(11); // dilations width + int paddingMode = INT_ARG(12); // 0-SAME, 1-VALID + int isNCDHW = block.numI() > 13 + ? !INT_ARG(13) + : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW + int wFormat = block.numI() > 14 + ? INT_ARG(14) + : 0; // 0-[kD, kH, kW, iC, oC], 1-[oC, iC, kD, kH, kW], + // 2-[oC, kD, kH, kW, iC] + + REQUIRE_TRUE(paddingMode < 2, 0, + "CONV3D CUDNN OP: causal padding mode (paddingMode = 2) is not " + "allowed for this operation !"); + + int bS, iC, iD, iH, iW, oC, oD, oH, + oW; // batch size, input channels, input depth/height/width, output + // channels, output depth/height/width; + int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv3d( + isNCDHW, wFormat, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, + indIOioC, indIOioD, indWiC, indWoC, indWkD); + + ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, + kW, sD, sH, sW, dD, dH, dW, paddingMode); + + std::vector expectedWeightsShape = + ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, iC, oC); + REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, + "CONV3D CUDNN OP: wrong shape of weights array, expected is %s, " + "but got %s instead !", + ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), + ShapeUtils::shapeAsString(weights).c_str()); + if (bias) + REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, + "CONV3D CUDNN OP: wrong shape of array with biases, expected " + "rank, length: <=2, %i, but got %i, %i instead !", + oC, bias->rankOf(), bias->lengthOf()); + + NDArray* newWeights = + weights; // cudnn support only one format {oC,iC,kD,kH,kW} + if (1 != wFormat) { + newWeights = new NDArray(weights->ordering(), {oC, iC, kD, kH, kW}, + weights->dataType(), weights->getContext()); + newWeights->assign(weights->permute( + 0 == wFormat + ? std::vector({4, 3, 0, 1, 2}) + : std::vector( + {0, 4, 1, 2, + 3}))); // kD, kH, kW, iC, oC --> oC, iC, kD, kH, kW or + // oC, kD, kH, kW, iC --> oC, iC, kD, kH, kW + } + + NDArray* newInput = input; + NDArray* newGradI = nullptr; + if (paddingMode == 1) // in same paddingMode cudnn doesn't support asymmetric + // left/right top/bottopm paddings + checkConv3dCUDNNPadAsymmetric(newInput, newGradI, iD, iH, iW, oD, oH, oW, + kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, + dW, isNCDHW); + + conv3dCUDNN(block.launchContext(), newInput, newWeights, bias, output, kD, kH, + kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, paddingMode, isNCDHW, + wFormat); + + if (newInput != input) delete newInput; + + if (1 != wFormat) delete newWeights; + + return Status::OK(); } ////////////////////////////////////////////////////////////////////////// PLATFORM_CHECK(conv3dnew, ENGINE_CUDA) { - - auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) - auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC] - auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] - - int paddingMode = INT_ARG(12); // 0-SAME, 1-VALID - - const bool badInputType = input->dataType() != DataType::DOUBLE && input->dataType() != DataType::FLOAT32 && input->dataType() != DataType::HALF; - const bool badWeightsType = weights->dataType() != DataType::DOUBLE && weights->dataType() != DataType::FLOAT32 && weights->dataType() != DataType::HALF; - const bool badBiasType = bias == nullptr ? false : (bias->dataType() != DataType::DOUBLE && bias->dataType() != DataType::FLOAT32 && bias->dataType() != DataType::HALF); - - return paddingMode != 2 && !badInputType && !badWeightsType && !badBiasType; + auto input = INPUT_VARIABLE( + 0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) + auto weights = INPUT_VARIABLE( + 1); // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC] + auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] + + int paddingMode = INT_ARG(12); // 0-SAME, 1-VALID + + const bool badInputType = input->dataType() != DataType::DOUBLE && + input->dataType() != DataType::FLOAT32 && + input->dataType() != DataType::HALF; + const bool badWeightsType = weights->dataType() != DataType::DOUBLE && + weights->dataType() != DataType::FLOAT32 && + weights->dataType() != DataType::HALF; + const bool badBiasType = bias == nullptr + ? false + : (bias->dataType() != DataType::DOUBLE && + bias->dataType() != DataType::FLOAT32 && + bias->dataType() != DataType::HALF); + + return paddingMode != 2 && !badInputType && !badWeightsType && !badBiasType; } ////////////////////////////////////////////////////////////////////////// PLATFORM_IMPL(conv3dnew_bp, ENGINE_CUDA) { + auto input = INPUT_VARIABLE( + 0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) + auto weights = INPUT_VARIABLE( + 1); // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC] + auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] + auto gradO = + block.width() > 3 + ? INPUT_VARIABLE(3) + : INPUT_VARIABLE(2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, + // oH, oW] (NCDHW), epsilon_next + + auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, + // iD, iH, iW] (NCDHW), epsilon + auto gradW = OUTPUT_VARIABLE( + 1); // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC] + auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC] + + REQUIRE_TRUE(input->rankOf() == 5, 0, + "CONV3D_BP CUDNN OP: rank of input array must be equal to 5, " + "but got %i instead !", + input->rankOf()); + REQUIRE_TRUE(weights->rankOf() == 5, 0, + "CONV3D_BP CUDNN OP: rank of weights array must be equal to 5, " + "but got %i instead !", + weights->rankOf()); + REQUIRE_TRUE(gradO->rankOf() == 5, 0, + "CONV3D_BP CUDNN OP: rank of output gradients (next epsilon) " + "array must be equal to 5, but got %i instead !", + gradO->rankOf()); + + int kD = INT_ARG(0) > 0 + ? INT_ARG(0) + : static_cast(weights->sizeAt(0)); // filter(kernel) depth + int kH = INT_ARG(1) > 0 + ? INT_ARG(1) + : static_cast(weights->sizeAt(1)); // filter(kernel) height + int kW = INT_ARG(2) > 0 + ? INT_ARG(2) + : static_cast(weights->sizeAt(2)); // filter(kernel) width + int sD = INT_ARG(3); // strides depth + int sH = INT_ARG(4); // strides height + int sW = INT_ARG(5); // strides width + int pD = INT_ARG(6); // paddings depth + int pH = INT_ARG(7); // paddings height + int pW = INT_ARG(8); // paddings width + int dD = INT_ARG(9); // dilations depth + int dH = INT_ARG(10); // dilations height + int dW = INT_ARG(11); // dilations width + int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID + int isNCDHW = block.numI() > 13 + ? !INT_ARG(13) + : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW + int wFormat = block.numI() > 14 + ? INT_ARG(14) + : 0; // 0-[kD, kH, kW, iC, oC], 1-[oC, iC, kD, kH, kW], + // 2-[oC, kD, kH, kW, iC] + + int bS, iC, iD, iH, iW, oC, oD, oH, + oW; // batch size, input channels, input depth/height/width, output + // channels, output depth/height/width; + int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv3d( + isNCDHW, wFormat, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, + indIOioC, indIOioD, indWiC, indWoC, indWkD); + + int trueoD, trueoH, trueoW; // true output depth/height/width + ConvolutionUtils::calcOutSizePool3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, + sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, + iW, paddingMode); + + REQUIRE_TRUE(paddingMode < 2, 0, + "CONV3D_BP CUDNN OP: causal padding mode (paddingMode = 2) is " + "not allowed for this operation !"); + + std::vector expectedGradOShape = + ShapeUtils::composeShapeUsingDimsAndIdx({bS, oC, trueoD, trueoH, trueoW, + 0, indIOioC, indIOioD, + indIOioD + 1, indIOioD + 2}); + std::vector expectedWeightsShape = + ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, iC, oC); + REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, + "CONV3D_BP CUDNN OP: wrong shape of output gradients (next " + "epsilon) array, expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedGradOShape).c_str(), + ShapeUtils::shapeAsString(gradO).c_str()); + REQUIRE_TRUE(gradW->isSameShape(expectedWeightsShape), 0, + "CONV3D_BP CUDNN OP: wrong shape of weights array, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), + ShapeUtils::shapeAsString(weights).c_str()); + if (bias) + REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, + "CONV3D_BP CUDNN OP: wrong shape of array with biases, " + "expected rank, length: <=2, %i, but got %i, %i instead !", + oC, bias->rankOf(), bias->lengthOf()); + + ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, + kW, sD, sH, sW, dD, dH, dW, paddingMode); + + NDArray *newWeights = weights, + *newGradW = gradW; // cudnn support only two formats {oC,iC,kD,kH,kW} + // and {oC,kD,kH,kW,iC} + if (0 == wFormat) { + newGradW = + new NDArray(gradW->ordering(), + isNCDHW ? std::vector({oC, iC, kD, kH, kW}) + : std::vector({oC, kD, kH, kW, iC}), + gradW->dataType(), gradW->getContext()); + newWeights = + new NDArray(weights->ordering(), + isNCDHW ? std::vector({oC, iC, kD, kH, kW}) + : std::vector({oC, kD, kH, kW, iC}), + weights->dataType(), weights->getContext()); + newWeights->assign(weights->permute( + isNCDHW + ? std::vector({4, 3, 0, 1, 2}) + : std::vector( + {4, 0, 1, 2, + 3}))); // (kD, kH, kW, iC, oC --> oC, iC, kD, kH, kW) or + // (kD, kH, kW, iC, oC --> oC, kD, kH, kW, iC) + } + + NDArray* newInput = input; + NDArray* newGradI = gradI; + if (paddingMode == 1) // in same paddingMode cudnn doesn't support asymmetric + // left/right top/bottopm paddings + checkConv3dCUDNNPadAsymmetric(newInput, newGradI, iD, iH, iW, oD, oH, oW, + kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, + dW, isNCDHW); + + conv3dBpCUDNN(block.launchContext(), newInput, newWeights, gradO, newGradI, + newGradW, gradB, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, + paddingMode, isNCDHW, wFormat); + + if (0 == wFormat) { + newGradW->permutei( + isNCDHW ? std::vector({2, 3, 4, 1, 0}) + : std::vector( + {1, 2, 3, 4, + 0})); // (oC, iC, kD, kH, kW --> kD, kH, kW, iC, oC) or + // (oC, kD, kH, kW, iC --> kD, kH, kW, iC, oC) + gradW->assign(newGradW); + } + + if (newInput != input) { + if (isNCDHW) + gradI->assign((*newGradI)({0, 0, 0, 0, 0, gradI->sizeAt(2), 0, + gradI->sizeAt(3), 0, gradI->sizeAt(4)})); + else + gradI->assign((*newGradI)({0, 0, 0, gradI->sizeAt(1), 0, gradI->sizeAt(2), + 0, gradI->sizeAt(3), 0, 0})); - auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) - auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC] - auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] - auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next - - auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon - auto gradW = OUTPUT_VARIABLE(1); // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC] - auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC] - - REQUIRE_TRUE(input->rankOf() == 5, 0, "CONV3D_BP CUDNN OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf()); - REQUIRE_TRUE(weights->rankOf() == 5, 0, "CONV3D_BP CUDNN OP: rank of weights array must be equal to 5, but got %i instead !", weights->rankOf()); - REQUIRE_TRUE(gradO->rankOf() == 5, 0, "CONV3D_BP CUDNN OP: rank of output gradients (next epsilon) array must be equal to 5, but got %i instead !", gradO->rankOf()); - - int kD = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(weights->sizeAt(0));// filter(kernel) depth - int kH = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast(weights->sizeAt(1));// filter(kernel) height - int kW = INT_ARG(2) > 0 ? INT_ARG(2) : static_cast(weights->sizeAt(2));// filter(kernel) width - int sD = INT_ARG(3); // strides depth - int sH = INT_ARG(4); // strides height - int sW = INT_ARG(5); // strides width - int pD = INT_ARG(6); // paddings depth - int pH = INT_ARG(7); // paddings height - int pW = INT_ARG(8); // paddings width - int dD = INT_ARG(9); // dilations depth - int dH = INT_ARG(10); // dilations height - int dW = INT_ARG(11); // dilations width - int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID - int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW - int wFormat = block.getIArguments()->size() > 14 ? INT_ARG(14) : 0; // 0-[kD, kH, kW, iC, oC], 1-[oC, iC, kD, kH, kW], 2-[oC, kD, kH, kW, iC] - - int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; - int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, wFormat, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); - - int trueoD, trueoH, trueoW; // true output depth/height/width - ConvolutionUtils::calcOutSizePool3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, paddingMode); - - REQUIRE_TRUE(paddingMode < 2, 0, "CONV3D_BP CUDNN OP: causal padding mode (paddingMode = 2) is not allowed for this operation !"); - - std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoD,trueoH,trueoW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); - std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, iC, oC); - REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CONV3D_BP CUDNN OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); - REQUIRE_TRUE(gradW->isSameShape(expectedWeightsShape), 0, "CONV3D_BP CUDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); - if(bias) - REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CONV3D_BP CUDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); - - ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW, paddingMode); - - NDArray *newWeights = weights, *newGradW = gradW; // cudnn support only two formats {oC,iC,kD,kH,kW} and {oC,kD,kH,kW,iC} - if(0 == wFormat) { - newGradW = new NDArray(gradW->ordering(), isNCDHW ? std::vector({oC, iC, kD, kH, kW}) : std::vector({oC, kD, kH, kW, iC}), gradW->dataType(), gradW->getContext()); - newWeights = new NDArray(weights->ordering(), isNCDHW ? std::vector({oC, iC, kD, kH, kW}) : std::vector({oC, kD, kH, kW, iC}), weights->dataType(), weights->getContext()); - newWeights->assign(weights->permute(isNCDHW ? std::vector({4,3,0,1,2}) : std::vector({4,0,1,2,3}))); // (kD, kH, kW, iC, oC --> oC, iC, kD, kH, kW) or (kD, kH, kW, iC, oC --> oC, kD, kH, kW, iC) - } - - NDArray* newInput = input; - NDArray* newGradI = gradI; - if(paddingMode == 1) // in same paddingMode cudnn doesn't support asymmetric left/right top/bottopm paddings - checkConv3dCUDNNPadAsymmetric(newInput, newGradI, iD, iH, iW, oD, oH, oW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isNCDHW); - - conv3dBpCUDNN(block.launchContext(), newInput, newWeights, gradO, newGradI, newGradW, gradB, kD,kH,kW,sD,sH,sW,pD,pH,pW,dD,dH,dW,paddingMode,isNCDHW,wFormat); - - if(0 == wFormat) { - newGradW->permutei(isNCDHW ? std::vector({2,3,4,1,0}) : std::vector({1,2,3,4,0})); // (oC, iC, kD, kH, kW --> kD, kH, kW, iC, oC) or (oC, kD, kH, kW, iC --> kD, kH, kW, iC, oC) - gradW->assign(newGradW); - } - - - if(newInput != input) { - - if(isNCDHW) - gradI->assign((*newGradI)({0,0, 0,0, 0,gradI->sizeAt(2), 0,gradI->sizeAt(3), 0,gradI->sizeAt(4)})); - else - gradI->assign((*newGradI)({0,0, 0,gradI->sizeAt(1), 0,gradI->sizeAt(2), 0,gradI->sizeAt(3), 0,0})); - - delete newInput; - delete newGradI; - } - - if(0 == wFormat) { - delete newWeights; - delete newGradW; - } - - return Status::OK(); -} - -PLATFORM_CHECK(conv3dnew_bp, ENGINE_CUDA) { - - auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) - auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC] - auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] - auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next - - int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID - int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW + delete newInput; + delete newGradI; + } - const bool badInputType = input->dataType() != DataType::DOUBLE && input->dataType() != DataType::FLOAT32 && input->dataType() != DataType::HALF; - const bool badWeightsType = weights->dataType() != DataType::DOUBLE && weights->dataType() != DataType::FLOAT32 && weights->dataType() != DataType::HALF; - const bool badGradOType = gradO->dataType() != DataType::DOUBLE && gradO->dataType() != DataType::FLOAT32 && gradO->dataType() != DataType::HALF; - const bool badBiasType = bias == nullptr ? false : (bias->dataType() != DataType::DOUBLE && bias->dataType() != DataType::FLOAT32 && bias->dataType() != DataType::HALF); + if (0 == wFormat) { + delete newWeights; + delete newGradW; + } - return isNCDHW && paddingMode != 2 && !badInputType && !badWeightsType && !badGradOType && !badBiasType; + return Status::OK(); } +PLATFORM_CHECK(conv3dnew_bp, ENGINE_CUDA) { + auto input = INPUT_VARIABLE( + 0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) + auto weights = INPUT_VARIABLE( + 1); // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC] + auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] + auto gradO = + block.width() > 3 + ? INPUT_VARIABLE(3) + : INPUT_VARIABLE(2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, + // oH, oW] (NCDHW), epsilon_next + + int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID + int isNCDHW = block.numI() > 13 + ? !INT_ARG(13) + : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW + + const bool badInputType = input->dataType() != DataType::DOUBLE && + input->dataType() != DataType::FLOAT32 && + input->dataType() != DataType::HALF; + const bool badWeightsType = weights->dataType() != DataType::DOUBLE && + weights->dataType() != DataType::FLOAT32 && + weights->dataType() != DataType::HALF; + const bool badGradOType = gradO->dataType() != DataType::DOUBLE && + gradO->dataType() != DataType::FLOAT32 && + gradO->dataType() != DataType::HALF; + const bool badBiasType = bias == nullptr + ? false + : (bias->dataType() != DataType::DOUBLE && + bias->dataType() != DataType::FLOAT32 && + bias->dataType() != DataType::HALF); + + return isNCDHW && paddingMode != 2 && !badInputType && !badWeightsType && + !badGradOType && !badBiasType; } -} -} + +} // namespace platforms +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/platform/cudnn/cudnnUtils.cu b/libnd4j/include/ops/declarable/platform/cudnn/cudnnUtils.cu index 54f8a1f3bbca..821ad71fc841 100644 --- a/libnd4j/include/ops/declarable/platform/cudnn/cudnnUtils.cu +++ b/libnd4j/include/ops/declarable/platform/cudnn/cudnnUtils.cu @@ -18,395 +18,533 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // +#include #include "cudnnUtils.h" -#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace platforms { ////////////////////////////////////////////////////////////////////////// -void checkConv2dCUDNNPadAsymmetric(NDArray* &input, NDArray* &gradI, - const int iH, const int iW, - const int oH, const int oW, - const int kH, const int kW, - const int sH, const int sW, - const int pH, const int pW, - const int dH, const int dW, - const bool isNCHW) { +void checkConv2dCUDNNPadAsymmetric(NDArray*& input, NDArray*& gradI, + const int iH, const int iW, const int oH, + const int oW, const int kH, const int kW, + const int sH, const int sW, const int pH, + const int pW, const int dH, const int dW, + const bool isNCHW) { + const auto pHsum = ((oH - 1) * sH + ((kH - 1) * dH + 1) - iH); + const auto pWsum = ((oW - 1) * sW + ((kW - 1) * dW + 1) - iW); - const auto pHsum = ((oH - 1) * sH + ((kH - 1) * dH + 1) - iH); - const auto pWsum = ((oW - 1) * sW + ((kW - 1) * dW + 1) - iW); + const bool isPHasymm = pH != (pHsum - pH); + const bool isPWasymm = pW != (pWsum - pW); - const bool isPHasymm = pH != (pHsum - pH); - const bool isPWasymm = pW != (pWsum - pW); + if (!isPHasymm && !isPWasymm) return; - if(!isPHasymm && !isPWasymm) - return; + std::vector newShape = input->getShapeAsVector(); - std::vector newShape = input->getShapeAsVector(); + const int iHposition = isNCHW ? 2 : 1; - const int iHposition = isNCHW ? 2 : 1; + if (isPHasymm) newShape[iHposition] += 1; + if (isPWasymm) newShape[iHposition + 1] += 1; - if(isPHasymm) - newShape[iHposition] += 1; - if(isPWasymm) - newShape[iHposition + 1] += 1; + NDArray* newInput = new NDArray(input->ordering(), newShape, + input->dataType(), input->getContext()); - NDArray* newInput = new NDArray(input->ordering(), newShape, input->dataType(), input->getContext()); + if (isNCHW) + (*newInput)({0, 0, 0, 0, 0, input->sizeAt(2), 0, input->sizeAt(3)}) + .assign(input); + else + (*newInput)({0, 0, 0, input->sizeAt(1), 0, input->sizeAt(2), 0, 0}) + .assign(input); - if(isNCHW) - (*newInput)({0,0, 0,0, 0,input->sizeAt(2), 0,input->sizeAt(3)}).assign(input); - else - (*newInput)({0,0, 0,input->sizeAt(1), 0,input->sizeAt(2), 0,0}).assign(input); + input = newInput; - input = newInput; - - if(gradI != nullptr) - gradI = new NDArray(gradI->ordering(), newShape, gradI->dataType(), gradI->getContext()); + if (gradI != nullptr) + gradI = new NDArray(gradI->ordering(), newShape, gradI->dataType(), + gradI->getContext()); } - ////////////////////////////////////////////////////////////////////////// -void checkConv3dCUDNNPadAsymmetric(NDArray* &input, NDArray* &gradI, - const int iD, const int iH, const int iW, - const int oD, const int oH, const int oW, - const int kD, const int kH, const int kW, - const int sD, const int sH, const int sW, - const int pD, const int pH, const int pW, - const int dD, const int dH, const int dW, - const bool isNCDHW) { - - const auto pDsum = ((oD - 1) * sD + ((kD - 1) * dD + 1) - iD); - const auto pHsum = ((oH - 1) * sH + ((kH - 1) * dH + 1) - iH); - const auto pWsum = ((oW - 1) * sW + ((kW - 1) * dW + 1) - iW); - - const bool isPDasymm = pD != (pDsum - pD); - const bool isPHasymm = pH != (pHsum - pH); - const bool isPWasymm = pW != (pWsum - pW); - - if(!isPDasymm && !isPHasymm && !isPWasymm) - return; - - std::vector newShape = input->getShapeAsVector(); - - const int iDposition = isNCDHW ? 2 : 1; - - if(isPDasymm) - newShape[iDposition] += 1; - if(isPHasymm) - newShape[iDposition + 1] += 1; - if(isPWasymm) - newShape[iDposition + 2] += 1; - - NDArray* newInput = new NDArray(input->ordering(), newShape, input->dataType(), input->getContext()); - - if(isNCDHW) - (*newInput)({0,0, 0,0, 0,input->sizeAt(2), 0,input->sizeAt(3), 0,input->sizeAt(4)}).assign(input); - else - (*newInput)({0,0, 0,input->sizeAt(1), 0,input->sizeAt(2), 0,input->sizeAt(3), 0,0}).assign(input); - - input = newInput; +void checkConv3dCUDNNPadAsymmetric(NDArray*& input, NDArray*& gradI, + const int iD, const int iH, const int iW, + const int oD, const int oH, const int oW, + const int kD, const int kH, const int kW, + const int sD, const int sH, const int sW, + const int pD, const int pH, const int pW, + const int dD, const int dH, const int dW, + const bool isNCDHW) { + const auto pDsum = ((oD - 1) * sD + ((kD - 1) * dD + 1) - iD); + const auto pHsum = ((oH - 1) * sH + ((kH - 1) * dH + 1) - iH); + const auto pWsum = ((oW - 1) * sW + ((kW - 1) * dW + 1) - iW); + + const bool isPDasymm = pD != (pDsum - pD); + const bool isPHasymm = pH != (pHsum - pH); + const bool isPWasymm = pW != (pWsum - pW); + + if (!isPDasymm && !isPHasymm && !isPWasymm) return; + + std::vector newShape = input->getShapeAsVector(); + + const int iDposition = isNCDHW ? 2 : 1; + + if (isPDasymm) newShape[iDposition] += 1; + if (isPHasymm) newShape[iDposition + 1] += 1; + if (isPWasymm) newShape[iDposition + 2] += 1; + + NDArray* newInput = new NDArray(input->ordering(), newShape, + input->dataType(), input->getContext()); + + if (isNCDHW) + (*newInput)({0, 0, 0, 0, 0, input->sizeAt(2), 0, input->sizeAt(3), 0, + input->sizeAt(4)}) + .assign(input); + else + (*newInput)({0, 0, 0, input->sizeAt(1), 0, input->sizeAt(2), 0, + input->sizeAt(3), 0, 0}) + .assign(input); + + input = newInput; + + if (gradI != nullptr) + gradI = new NDArray(gradI->ordering(), newShape, gradI->dataType(), + gradI->getContext()); +} - if(gradI != nullptr) - gradI = new NDArray(gradI->ordering(), newShape, gradI->dataType(), gradI->getContext()); +////////////////////////////////////////////////////////////////////////// +void pooling2dCUDNN(const LaunchContext* context, const NDArray* input, + NDArray* output, const int kH, const int kW, const int sH, + const int sW, const int pH, const int pW, const int dH, + const int dW, const bool isNCHW, + const cudnnPoolingMode_t mode) { + int bS, iC, iH, iW, oC, oH, + oW; // batch size, input channels, input height/width, output channels, + // output height/width; + int indIOioC, indIiH, indWoC, indWiC, indWkH, + indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, 0, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, + indWiC, indWoC, indWkH, indOoH); + + auto handle = reinterpret_cast(context->getCuDnnHandle()); + cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream()); + if (err != 0) + throw sd::cuda_exception::build( + "pooling2dCUDNN: can't set stream for cuDNN", err); + + cudnnTensorFormat_t format = isNCHW ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC; + + // input descriptor + cudnnTensorDescriptor_t x; + cudnnCreateTensorDescriptor(&x); + if (input->ews() == 1 && input->ordering() == 'c') + err = cudnnSetTensor4dDescriptor( + x, format, cudnnDataType(input->dataType()), bS, iC, iH, iW); + else + err = cudnnSetTensor4dDescriptorEx( + x, cudnnDataType(input->dataType()), bS, iC, iH, iW, input->strideAt(0), + input->strideAt(indIOioC), input->strideAt(indIiH), + input->strideAt(indIiH + 1)); + if (err != 0) + throw sd::cuda_exception::build( + "pooling2dCUDNN: " + "cudnnSetTensor4dDescriptor/cudnnSetTensor4dDescriptorEx for input " + "failed", + err); + + // output descriptor + cudnnTensorDescriptor_t z; + cudnnCreateTensorDescriptor(&z); + if (output->ews() == 1 && output->ordering() == 'c') + err = cudnnSetTensor4dDescriptor( + z, format, cudnnDataType(output->dataType()), bS, oC, oH, oW); + else + err = cudnnSetTensor4dDescriptorEx( + z, cudnnDataType(output->dataType()), bS, oC, oH, oW, + output->strideAt(0), output->strideAt(indIOioC), + output->strideAt(indOoH), output->strideAt(indOoH + 1)); + if (err != 0) + throw sd::cuda_exception::build( + "pooling2dCUDNN: " + "cudnnSetTensor4dDescriptor/cudnnSetTensor4dDescriptorEx for output " + "failed", + err); + + // description of pooling + cudnnPoolingDescriptor_t pooling; + cudnnCreatePoolingDescriptor(&pooling); + err = cudnnSetPooling2dDescriptor(pooling, mode, CUDNN_PROPAGATE_NAN, kH, kW, + pH, pW, sH, sW); + if (err != 0) + throw sd::cuda_exception::build( + "pooling2dCUDNN: cudnnSetPooling2dDescriptor failed", err); + + // provide scaling parameters + const float alpha32(1), beta32(0); + const double alpha64(1), beta64(0); + const void* alpha = output->sizeOfT() <= 4 + ? reinterpret_cast(&alpha32) + : reinterpret_cast(&alpha64); + const void* beta = output->sizeOfT() <= 4 + ? reinterpret_cast(&beta32) + : reinterpret_cast(&beta64); + + NDArray::prepareSpecialUse({output}, {input}); + + // run calculation + err = cudnnPoolingForward(*handle, pooling, alpha, x, input->specialBuffer(), + beta, z, output->specialBuffer()); + if (err != 0) + throw sd::cuda_exception::build( + "pooling2dCUDNN: cudnnPoolingForward failed", err); + + auto cudaErr = cudaStreamSynchronize(*context->getCudaStream()); + if (cudaErr != 0) + throw cuda_exception::build( + "pooling2dCUDNN: cudaStreamSynchronize failed !", cudaErr); + + NDArray::registerSpecialUse({output}, {input}); } ////////////////////////////////////////////////////////////////////////// -void pooling2dCUDNN(const LaunchContext* context, - const NDArray* input, NDArray* output, - const int kH, const int kW, - const int sH, const int sW, - const int pH, const int pW, - const int dH, const int dW, - const bool isNCHW, const cudnnPoolingMode_t mode) { - - int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; - int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, 0, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); - - auto handle = reinterpret_cast(context->getCuDnnHandle()); - cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream()); - if (err != 0) throw sd::cuda_exception::build("pooling2dCUDNN: can't set stream for cuDNN", err); - - cudnnTensorFormat_t format = isNCHW ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC; - - // input descriptor - cudnnTensorDescriptor_t x; - cudnnCreateTensorDescriptor(&x); - if(input->ews() == 1 && input->ordering() == 'c') - err = cudnnSetTensor4dDescriptor(x, format, cudnnDataType(input->dataType()), bS, iC, iH, iW); - else - err = cudnnSetTensor4dDescriptorEx(x, cudnnDataType(input->dataType()), bS, iC, iH, iW, input->strideAt(0), input->strideAt(indIOioC), input->strideAt(indIiH), input->strideAt(indIiH + 1)); - if (err != 0) throw sd::cuda_exception::build("pooling2dCUDNN: cudnnSetTensor4dDescriptor/cudnnSetTensor4dDescriptorEx for input failed", err); - - // output descriptor - cudnnTensorDescriptor_t z; - cudnnCreateTensorDescriptor(&z); - if(output->ews() == 1 && output->ordering() == 'c') - err = cudnnSetTensor4dDescriptor(z, format, cudnnDataType(output->dataType()), bS, oC, oH, oW); - else - err = cudnnSetTensor4dDescriptorEx(z, cudnnDataType(output->dataType()), bS, oC, oH, oW, output->strideAt(0), output->strideAt(indIOioC), output->strideAt(indOoH), output->strideAt(indOoH + 1)); - if (err != 0) throw sd::cuda_exception::build("pooling2dCUDNN: cudnnSetTensor4dDescriptor/cudnnSetTensor4dDescriptorEx for output failed", err); - - // description of pooling - cudnnPoolingDescriptor_t pooling; - cudnnCreatePoolingDescriptor(&pooling); - err = cudnnSetPooling2dDescriptor(pooling, mode, CUDNN_PROPAGATE_NAN, kH, kW, pH, pW, sH, sW); - if (err != 0) throw sd::cuda_exception::build("pooling2dCUDNN: cudnnSetPooling2dDescriptor failed", err); - - // provide scaling parameters - const float alpha32(1), beta32(0); - const double alpha64(1), beta64(0); - const void* alpha = output->sizeOfT() <= 4 ? reinterpret_cast(&alpha32) : reinterpret_cast(&alpha64); - const void* beta = output->sizeOfT() <= 4 ? reinterpret_cast(&beta32) : reinterpret_cast(&beta64); - - NDArray::prepareSpecialUse({output}, {input}); - - // run calculation - err = cudnnPoolingForward(*handle, pooling, alpha, x, input->specialBuffer(), beta, z, output->specialBuffer()); - if (err != 0) throw sd::cuda_exception::build("pooling2dCUDNN: cudnnPoolingForward failed", err); - - auto cudaErr = cudaStreamSynchronize(*context->getCudaStream()); - if (cudaErr != 0) - throw cuda_exception::build("pooling2dCUDNN: cudaStreamSynchronize failed !", cudaErr); - - NDArray::registerSpecialUse({output}, {input}); +void pooling2dBpCUDNN(const LaunchContext* context, const NDArray* input, + const NDArray* gradO, NDArray* gradI, const int kH, + const int kW, const int sH, const int sW, const int pH, + const int pW, const int dH, const int dW, + const bool isNCHW, const cudnnPoolingMode_t mode) { + int bS, iC, iH, iW, oC, oH, + oW; // batch size, input channels, input height/width, output channels, + // output height/width; + int indIOioC, indIiH, indWoC, indWiC, indWkH, + indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, 0, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, + indWiC, indWoC, indWkH, indOoH); + + auto handle = reinterpret_cast(context->getCuDnnHandle()); + cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream()); + if (err != 0) + throw sd::cuda_exception::build( + "pooling2dBpCUDNN: can't set stream for cuDNN", err); + + cudnnTensorFormat_t format = isNCHW ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC; + + // input and gradI descriptor + cudnnTensorDescriptor_t x; + cudnnCreateTensorDescriptor(&x); + if (input->ews() == 1 && input->ordering() == 'c') + err = cudnnSetTensor4dDescriptor( + x, format, cudnnDataType(input->dataType()), bS, iC, iH, iW); + else + err = cudnnSetTensor4dDescriptorEx( + x, cudnnDataType(input->dataType()), bS, iC, iH, iW, input->strideAt(0), + input->strideAt(indIOioC), input->strideAt(indIiH), + input->strideAt(indIiH + 1)); + if (err != 0) + throw sd::cuda_exception::build( + "pooling2dBpCUDNN: " + "cudnnSetTensor4dDescriptor/cudnnSetTensor4dDescriptorEx for " + "input/gradI failed", + err); + + // gradO descriptor + cudnnTensorDescriptor_t dz; + cudnnCreateTensorDescriptor(&dz); + if (gradO->ews() == 1 && gradO->ordering() == 'c') + err = cudnnSetTensor4dDescriptor( + dz, format, cudnnDataType(gradO->dataType()), bS, oC, oH, oW); + else + err = cudnnSetTensor4dDescriptorEx( + dz, cudnnDataType(gradO->dataType()), bS, oC, oH, oW, + gradO->strideAt(0), gradO->strideAt(indIOioC), gradO->strideAt(indOoH), + gradO->strideAt(indOoH + 1)); + if (err != 0) + throw sd::cuda_exception::build( + "pooling2dBpCUDNN: " + "cudnnSetTensor4dDescriptor/cudnnSetTensor4dDescriptorEx for gradO " + "failed", + err); + + // description of pooling + cudnnPoolingDescriptor_t pooling; + cudnnCreatePoolingDescriptor(&pooling); + err = cudnnSetPooling2dDescriptor(pooling, mode, CUDNN_PROPAGATE_NAN, kH, kW, + pH, pW, sH, sW); + if (err != 0) + throw sd::cuda_exception::build( + "pooling2dBpCUDNN: cudnnSetPooling2dDescriptor failed", err); + + // provide scaling parameters + const float alpha32(1), beta32(0); + const double alpha64(1), beta64(0); + const void* alpha = gradO->sizeOfT() <= 4 + ? reinterpret_cast(&alpha32) + : reinterpret_cast(&alpha64); + const void* beta = gradO->sizeOfT() <= 4 + ? reinterpret_cast(&beta32) + : reinterpret_cast(&beta64); + + NDArray::prepareSpecialUse({gradI}, {input, gradO}); + + // run calculation for gradI + err = cudnnPoolingBackward(*handle, pooling, alpha, dz, + gradO->specialBuffer(), dz, gradO->specialBuffer(), + x, input->specialBuffer(), beta, x, + gradI->specialBuffer()); + if (err != 0) + throw sd::cuda_exception::build( + "pooling2dBpCUDNN: cudnnPoolingBackward failed", err); + + auto cudaErr = cudaStreamSynchronize(*context->getCudaStream()); + if (cudaErr != 0) + throw cuda_exception::build( + "pooling2dBpCUDNN: cudaStreamSynchronize failed !", cudaErr); + + NDArray::registerSpecialUse({gradI}, {input, gradO}); } ////////////////////////////////////////////////////////////////////////// -void pooling2dBpCUDNN(const LaunchContext* context, - const NDArray* input, const NDArray* gradO, - NDArray* gradI, - const int kH, const int kW, - const int sH, const int sW, - const int pH, const int pW, - const int dH, const int dW, - const bool isNCHW, const cudnnPoolingMode_t mode) { - - int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; - int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, 0, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); - - auto handle = reinterpret_cast(context->getCuDnnHandle()); - cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream()); - if (err != 0) throw sd::cuda_exception::build("pooling2dBpCUDNN: can't set stream for cuDNN", err); - - cudnnTensorFormat_t format = isNCHW ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC; - - // input and gradI descriptor - cudnnTensorDescriptor_t x; - cudnnCreateTensorDescriptor(&x); - if(input->ews() == 1 && input->ordering() == 'c') - err = cudnnSetTensor4dDescriptor(x, format, cudnnDataType(input->dataType()), bS, iC, iH, iW); - else - err = cudnnSetTensor4dDescriptorEx(x, cudnnDataType(input->dataType()), bS, iC, iH, iW, input->strideAt(0), input->strideAt(indIOioC), input->strideAt(indIiH), input->strideAt(indIiH + 1)); - if (err != 0) throw sd::cuda_exception::build("pooling2dBpCUDNN: cudnnSetTensor4dDescriptor/cudnnSetTensor4dDescriptorEx for input/gradI failed", err); - - // gradO descriptor - cudnnTensorDescriptor_t dz; - cudnnCreateTensorDescriptor(&dz); - if(gradO->ews() == 1 && gradO->ordering() == 'c') - err = cudnnSetTensor4dDescriptor(dz, format, cudnnDataType(gradO->dataType()), bS, oC, oH, oW); - else - err = cudnnSetTensor4dDescriptorEx(dz, cudnnDataType(gradO->dataType()), bS, oC, oH, oW, gradO->strideAt(0), gradO->strideAt(indIOioC), gradO->strideAt(indOoH), gradO->strideAt(indOoH + 1)); - if (err != 0) throw sd::cuda_exception::build("pooling2dBpCUDNN: cudnnSetTensor4dDescriptor/cudnnSetTensor4dDescriptorEx for gradO failed", err); - - // description of pooling - cudnnPoolingDescriptor_t pooling; - cudnnCreatePoolingDescriptor(&pooling); - err = cudnnSetPooling2dDescriptor(pooling, mode, CUDNN_PROPAGATE_NAN, kH, kW, pH, pW, sH, sW); - if (err != 0) throw sd::cuda_exception::build("pooling2dBpCUDNN: cudnnSetPooling2dDescriptor failed", err); - - // provide scaling parameters - const float alpha32(1), beta32(0); - const double alpha64(1), beta64(0); - const void* alpha = gradO->sizeOfT() <= 4 ? reinterpret_cast(&alpha32) : reinterpret_cast(&alpha64); - const void* beta = gradO->sizeOfT() <= 4 ? reinterpret_cast(&beta32) : reinterpret_cast(&beta64); +void pooling3dCUDNN(const LaunchContext* context, const NDArray* input, + NDArray* output, const int kD, const int kH, const int kW, + const int sD, const int sH, const int sW, const int pD, + const int pH, const int pW, const int dD, const int dH, + const int dW, const bool isNCDHW, + const cudnnPoolingMode_t mode) { + auto handle = reinterpret_cast(context->getCuDnnHandle()); + cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream()); + if (err != 0) + throw sd::cuda_exception::build( + "pooling3dCUDNN: can't set stream for cuDNN", err); + + const int numDims = 5; + + int bS, iC, iD, iH, iW, oC, oD, oH, + oW; // batch size, input channels, input depth/height/width, output + // channels, output depth/height/width; + int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv3d( + isNCDHW, 0, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, + indIOioD, indWiC, indWoC, indWkD); + + const int pSizes[] = {pD, pH, pW}; + const int sSizes[] = {sD, sH, sW}; + const int kSizes[] = {kD, kH, kW}; + + const int xShape[] = {bS, iC, iD, iH, iW}; + const int zShape[] = {bS, oC, oD, oH, oW}; + + const int xStrides[] = {(int)input->strideAt(0), (int)input->strideAt(1), + (int)input->strideAt(2), (int)input->strideAt(3), + (int)input->strideAt(4)}; + const int zStrides[] = {(int)output->strideAt(0), (int)output->strideAt(1), + (int)output->strideAt(2), (int)output->strideAt(3), + (int)output->strideAt(4)}; + + cudnnTensorFormat_t format = isNCDHW ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC; + + // input descriptor + cudnnTensorDescriptor_t x; + cudnnCreateTensorDescriptor(&x); + if (input->ews() == 1 && input->ordering() == 'c') + err = cudnnSetTensorNdDescriptorEx( + x, format, cudnnDataType(input->dataType()), numDims, xShape); + else + err = cudnnSetTensorNdDescriptor(x, cudnnDataType(input->dataType()), + numDims, xShape, xStrides); + if (err != 0) + throw sd::cuda_exception::build( + "pooling3dCUDNN: " + "cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for input " + "failed", + err); + + // output descriptor + cudnnTensorDescriptor_t z; + cudnnCreateTensorDescriptor(&z); + if (output->ews() == 1 && output->ordering() == 'c') + err = cudnnSetTensorNdDescriptorEx( + z, format, cudnnDataType(output->dataType()), numDims, zShape); + else + err = cudnnSetTensorNdDescriptor(z, cudnnDataType(output->dataType()), + numDims, zShape, zStrides); + if (err != 0) + throw sd::cuda_exception::build( + "pooling3dCUDNN: " + "cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for output " + "failed", + err); + + // description of pooling + cudnnPoolingDescriptor_t pooling; + cudnnCreatePoolingDescriptor(&pooling); + err = cudnnSetPoolingNdDescriptor(pooling, mode, CUDNN_PROPAGATE_NAN, + numDims - 2, kSizes, pSizes, sSizes); + if (err != 0) + throw sd::cuda_exception::build( + "pooling3dCUDNN: cudnnSetPoolingNdDescriptor failed", err); + + // provide scaling parameters + const float alpha32(1), beta32(0); + const double alpha64(1), beta64(0); + const void* alpha = output->sizeOfT() <= 4 + ? reinterpret_cast(&alpha32) + : reinterpret_cast(&alpha64); + const void* beta = output->sizeOfT() <= 4 + ? reinterpret_cast(&beta32) + : reinterpret_cast(&beta64); + + NDArray::prepareSpecialUse({output}, {input}); + + // run calculation + err = cudnnPoolingForward(*handle, pooling, alpha, x, input->specialBuffer(), + beta, z, output->specialBuffer()); + if (err != 0) + throw sd::cuda_exception::build( + "pooling3dCUDNN: cudnnPoolingForward failed", err); + + auto cudaErr = cudaStreamSynchronize(*context->getCudaStream()); + if (cudaErr != 0) + throw cuda_exception::build( + "pooling3dCUDNN: cudaStreamSynchronize failed !", cudaErr); + + NDArray::registerSpecialUse({output}, {input}); +} +////////////////////////////////////////////////////////////////////////// +void pooling3dBpCUDNN(const LaunchContext* context, const NDArray* input, + const NDArray* gradO, NDArray* gradI, const int kD, + const int kH, const int kW, const int sD, const int sH, + const int sW, const int pD, const int pH, const int pW, + const int dD, const int dH, const int dW, + const bool isNCDHW, const cudnnPoolingMode_t mode) { + auto handle = reinterpret_cast(context->getCuDnnHandle()); + cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream()); + if (err != 0) + throw sd::cuda_exception::build( + "pooling3dBpCUDNN: can't set stream for cuDNN", err); + + const int numDims = 5; + + int bS, iC, iD, iH, iW, oC, oD, oH, + oW; // batch size, input channels, input depth/height/width, output + // channels, output depth/height/width; + int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv3d( + isNCDHW, 0, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, + indIOioD, indWiC, indWoC, indWkD); + + const int pSizes[] = {pD, pH, pW}; + const int sSizes[] = {sD, sH, sW}; + const int kSizes[] = {kD, kH, kW}; + + const int xShape[] = {bS, iC, iD, iH, iW}; + const int dzShape[] = {bS, oC, oD, oH, oW}; + + const int xStrides[] = {(int)input->strideAt(0), (int)input->strideAt(1), + (int)input->strideAt(2), (int)input->strideAt(3), + (int)input->strideAt(4)}; + const int dzStrides[] = {(int)gradO->strideAt(0), (int)gradO->strideAt(1), + (int)gradO->strideAt(2), (int)gradO->strideAt(3), + (int)gradO->strideAt(4)}; + + cudnnTensorFormat_t format = isNCDHW ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC; + + // input and gradI descriptor + cudnnTensorDescriptor_t x; + cudnnCreateTensorDescriptor(&x); + if (input->ews() == 1 && input->ordering() == 'c') + err = cudnnSetTensorNdDescriptorEx( + x, format, cudnnDataType(input->dataType()), numDims, xShape); + else + err = cudnnSetTensorNdDescriptor(x, cudnnDataType(input->dataType()), + numDims, xShape, xStrides); + if (err != 0) + throw sd::cuda_exception::build( + "pooling3dBpCUDNN: " + "cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for " + "input/gradI failed", + err); + + // gradO descriptor + cudnnTensorDescriptor_t dz; + cudnnCreateTensorDescriptor(&dz); + if (gradO->ews() == 1 && gradO->ordering() == 'c') + err = cudnnSetTensorNdDescriptorEx( + dz, format, cudnnDataType(gradO->dataType()), numDims, dzShape); + else + err = cudnnSetTensorNdDescriptor(dz, cudnnDataType(gradO->dataType()), + numDims, dzShape, dzStrides); + if (err != 0) + throw sd::cuda_exception::build( + "pooling3dBpCUDNN: " + "cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for gradO " + "failed", + err); + + // description of pooling + cudnnPoolingDescriptor_t pooling; + cudnnCreatePoolingDescriptor(&pooling); + err = cudnnSetPoolingNdDescriptor(pooling, mode, CUDNN_PROPAGATE_NAN, + numDims - 2, kSizes, pSizes, sSizes); + if (err != 0) + throw sd::cuda_exception::build( + "pooling3dBpCUDNN: cudnnSetPoolingNdDescriptor failed", err); + + // provide scaling parameters + const float alpha32(1), beta32(0); + const double alpha64(1), beta64(0); + const void* alpha = gradO->sizeOfT() <= 4 + ? reinterpret_cast(&alpha32) + : reinterpret_cast(&alpha64); + const void* beta = gradO->sizeOfT() <= 4 + ? reinterpret_cast(&beta32) + : reinterpret_cast(&beta64); + + // cudnn maxpool2d_bp api requires ff output as one of input arguments + if (mode == CUDNN_POOLING_MAX) { + NDArray temp(gradO); + + NDArray::prepareSpecialUse({gradI}, {input, gradO, &temp}); + + // run ff calculation + err = + cudnnPoolingForward(*handle, pooling, alpha, x, input->specialBuffer(), + beta, dz, temp.specialBuffer()); + if (err != 0) + throw sd::cuda_exception::build( + "pooling3dCUDNN: cudnnPoolingForward failed", err); + + // run bp calculation for gradI + err = cudnnPoolingBackward(*handle, pooling, alpha, dz, + temp.specialBuffer(), dz, gradO->specialBuffer(), + x, input->specialBuffer(), beta, x, + gradI->specialBuffer()); + if (err != 0) + throw sd::cuda_exception::build( + "pooling2dBpCUDNN: cudnnPoolingBackward failed", err); + + NDArray::registerSpecialUse({gradI}, {input, gradO, &temp}); + } else { NDArray::prepareSpecialUse({gradI}, {input, gradO}); - // run calculation for gradI - err = cudnnPoolingBackward(*handle, pooling, alpha, dz, gradO->specialBuffer(), dz, gradO->specialBuffer(), x, input->specialBuffer(), beta, x, gradI->specialBuffer()); - if (err != 0) throw sd::cuda_exception::build("pooling2dBpCUDNN: cudnnPoolingBackward failed", err); - - auto cudaErr = cudaStreamSynchronize(*context->getCudaStream()); - if (cudaErr != 0) - throw cuda_exception::build("pooling2dBpCUDNN: cudaStreamSynchronize failed !", cudaErr); + // run bp calculation for gradI + err = cudnnPoolingBackward( + *handle, pooling, alpha, dz, gradO->specialBuffer(), dz, + gradO->specialBuffer(), x, input->specialBuffer(), beta, x, + gradI->specialBuffer()); + if (err != 0) + throw sd::cuda_exception::build( + "pooling2dBpCUDNN: cudnnPoolingBackward failed", err); NDArray::registerSpecialUse({gradI}, {input, gradO}); -} + } -////////////////////////////////////////////////////////////////////////// -void pooling3dCUDNN(const LaunchContext* context, - const NDArray* input, NDArray* output, - const int kD, const int kH, const int kW, - const int sD, const int sH, const int sW, - const int pD, const int pH, const int pW, - const int dD, const int dH, const int dW, - const bool isNCDHW, const cudnnPoolingMode_t mode) { - - auto handle = reinterpret_cast(context->getCuDnnHandle()); - cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream()); - if (err != 0) throw sd::cuda_exception::build("pooling3dCUDNN: can't set stream for cuDNN", err); - - const int numDims = 5; - - int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; - int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, 0, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); - - const int pSizes[] = {pD, pH, pW}; - const int sSizes[] = {sD, sH, sW}; - const int kSizes[] = {kD, kH, kW}; - - const int xShape[] = {bS, iC, iD, iH, iW}; - const int zShape[] = {bS, oC, oD, oH, oW}; - - const int xStrides[] = {(int)input->strideAt(0), (int)input->strideAt(1), (int)input->strideAt(2), (int)input->strideAt(3), (int)input->strideAt(4)}; - const int zStrides[] = {(int)output->strideAt(0), (int)output->strideAt(1), (int)output->strideAt(2), (int)output->strideAt(3), (int)output->strideAt(4)}; - - cudnnTensorFormat_t format = isNCDHW ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC; - - // input descriptor - cudnnTensorDescriptor_t x; - cudnnCreateTensorDescriptor(&x); - if(input->ews() == 1 && input->ordering() == 'c') - err = cudnnSetTensorNdDescriptorEx(x, format, cudnnDataType(input->dataType()), numDims, xShape); - else - err = cudnnSetTensorNdDescriptor(x, cudnnDataType(input->dataType()), numDims, xShape, xStrides); - if (err != 0) throw sd::cuda_exception::build("pooling3dCUDNN: cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for input failed", err); - - // output descriptor - cudnnTensorDescriptor_t z; - cudnnCreateTensorDescriptor(&z); - if(output->ews() == 1 && output->ordering() == 'c') - err = cudnnSetTensorNdDescriptorEx(z, format, cudnnDataType(output->dataType()), numDims, zShape); - else - err = cudnnSetTensorNdDescriptor(z, cudnnDataType(output->dataType()), numDims, zShape, zStrides); - if (err != 0) throw sd::cuda_exception::build("pooling3dCUDNN: cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for output failed", err); - - // description of pooling - cudnnPoolingDescriptor_t pooling; - cudnnCreatePoolingDescriptor(&pooling); - err = cudnnSetPoolingNdDescriptor(pooling, mode, CUDNN_PROPAGATE_NAN, numDims - 2, kSizes, pSizes, sSizes); - if (err != 0) throw sd::cuda_exception::build("pooling3dCUDNN: cudnnSetPoolingNdDescriptor failed", err); - - // provide scaling parameters - const float alpha32(1), beta32(0); - const double alpha64(1), beta64(0); - const void* alpha = output->sizeOfT() <= 4 ? reinterpret_cast(&alpha32) : reinterpret_cast(&alpha64); - const void* beta = output->sizeOfT() <= 4 ? reinterpret_cast(&beta32) : reinterpret_cast(&beta64); - - NDArray::prepareSpecialUse({output}, {input}); - - // run calculation - err = cudnnPoolingForward(*handle, pooling, alpha, x, input->specialBuffer(), beta, z, output->specialBuffer()); - if (err != 0) throw sd::cuda_exception::build("pooling3dCUDNN: cudnnPoolingForward failed", err); - - auto cudaErr = cudaStreamSynchronize(*context->getCudaStream()); - if (cudaErr != 0) - throw cuda_exception::build("pooling3dCUDNN: cudaStreamSynchronize failed !", cudaErr); - - NDArray::registerSpecialUse({output}, {input}); + auto cudaErr = cudaStreamSynchronize(*context->getCudaStream()); + if (cudaErr != 0) + throw cuda_exception::build( + "pooling3dBpCUDNN: cudaStreamSynchronize failed !", cudaErr); } -////////////////////////////////////////////////////////////////////////// -void pooling3dBpCUDNN(const LaunchContext* context, - const NDArray* input, const NDArray* gradO, - NDArray* gradI, - const int kD, const int kH, const int kW, - const int sD, const int sH, const int sW, - const int pD, const int pH, const int pW, - const int dD, const int dH, const int dW, - const bool isNCDHW, const cudnnPoolingMode_t mode) { - - auto handle = reinterpret_cast(context->getCuDnnHandle()); - cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream()); - if (err != 0) throw sd::cuda_exception::build("pooling3dBpCUDNN: can't set stream for cuDNN", err); - - const int numDims = 5; - - int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; - int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, 0, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); - - const int pSizes[] = {pD, pH, pW}; - const int sSizes[] = {sD, sH, sW}; - const int kSizes[] = {kD, kH, kW}; - - const int xShape[] = {bS, iC, iD, iH, iW}; - const int dzShape[] = {bS, oC, oD, oH, oW}; - - const int xStrides[] = {(int)input->strideAt(0), (int)input->strideAt(1), (int)input->strideAt(2), (int)input->strideAt(3), (int)input->strideAt(4)}; - const int dzStrides[] = {(int)gradO->strideAt(0), (int)gradO->strideAt(1), (int)gradO->strideAt(2), (int)gradO->strideAt(3), (int)gradO->strideAt(4)}; - - cudnnTensorFormat_t format = isNCDHW ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC; - - // input and gradI descriptor - cudnnTensorDescriptor_t x; - cudnnCreateTensorDescriptor(&x); - if(input->ews() == 1 && input->ordering() == 'c') - err = cudnnSetTensorNdDescriptorEx(x, format, cudnnDataType(input->dataType()), numDims, xShape); - else - err = cudnnSetTensorNdDescriptor(x, cudnnDataType(input->dataType()), numDims, xShape, xStrides); - if (err != 0) throw sd::cuda_exception::build("pooling3dBpCUDNN: cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for input/gradI failed", err); - - // gradO descriptor - cudnnTensorDescriptor_t dz; - cudnnCreateTensorDescriptor(&dz); - if(gradO->ews() == 1 && gradO->ordering() == 'c') - err = cudnnSetTensorNdDescriptorEx(dz, format, cudnnDataType(gradO->dataType()), numDims, dzShape); - else - err = cudnnSetTensorNdDescriptor(dz, cudnnDataType(gradO->dataType()), numDims, dzShape, dzStrides); - if (err != 0) throw sd::cuda_exception::build("pooling3dBpCUDNN: cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for gradO failed", err); - - // description of pooling - cudnnPoolingDescriptor_t pooling; - cudnnCreatePoolingDescriptor(&pooling); - err = cudnnSetPoolingNdDescriptor(pooling, mode, CUDNN_PROPAGATE_NAN, numDims - 2, kSizes, pSizes, sSizes); - if (err != 0) throw sd::cuda_exception::build("pooling3dBpCUDNN: cudnnSetPoolingNdDescriptor failed", err); - - // provide scaling parameters - const float alpha32(1), beta32(0); - const double alpha64(1), beta64(0); - const void* alpha = gradO->sizeOfT() <= 4 ? reinterpret_cast(&alpha32) : reinterpret_cast(&alpha64); - const void* beta = gradO->sizeOfT() <= 4 ? reinterpret_cast(&beta32) : reinterpret_cast(&beta64); - - // cudnn maxpool2d_bp api requires ff output as one of input arguments - if(mode == CUDNN_POOLING_MAX) { - - NDArray temp(gradO); - - NDArray::prepareSpecialUse({gradI}, {input, gradO, &temp}); - - // run ff calculation - err = cudnnPoolingForward(*handle, pooling, alpha, x, input->specialBuffer(), beta, dz, temp.specialBuffer()); - if (err != 0) throw sd::cuda_exception::build("pooling3dCUDNN: cudnnPoolingForward failed", err); - - // run bp calculation for gradI - err = cudnnPoolingBackward(*handle, pooling, alpha, dz, temp.specialBuffer(), dz, gradO->specialBuffer(), x, input->specialBuffer(), beta, x, gradI->specialBuffer()); - if (err != 0) throw sd::cuda_exception::build("pooling2dBpCUDNN: cudnnPoolingBackward failed", err); - - NDArray::registerSpecialUse({gradI}, {input, gradO, &temp}); - } - else { - - NDArray::prepareSpecialUse({gradI}, {input, gradO}); - - // run bp calculation for gradI - err = cudnnPoolingBackward(*handle, pooling, alpha, dz, gradO->specialBuffer(), dz, gradO->specialBuffer(), x, input->specialBuffer(), beta, x, gradI->specialBuffer()); - if (err != 0) throw sd::cuda_exception::build("pooling2dBpCUDNN: cudnnPoolingBackward failed", err); - - NDArray::registerSpecialUse({gradI}, {input, gradO}); - } - - auto cudaErr = cudaStreamSynchronize(*context->getCudaStream()); - if (cudaErr != 0) - throw cuda_exception::build("pooling3dBpCUDNN: cudaStreamSynchronize failed !", cudaErr); -} - -} -} -} +} // namespace platforms +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/platform/cudnn/cudnnUtils.h b/libnd4j/include/ops/declarable/platform/cudnn/cudnnUtils.h index 3379979a32b9..9d015abaec36 100644 --- a/libnd4j/include/ops/declarable/platform/cudnn/cudnnUtils.h +++ b/libnd4j/include/ops/declarable/platform/cudnn/cudnnUtils.h @@ -21,121 +21,110 @@ #ifndef SD_CUDNNUTILS_H #define SD_CUDNNUTILS_H -#include -#include -#include +#include #include #include +#include +#include #include +#include -#include - -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace platforms { - DECLARE_PLATFORM(conv2d, ENGINE_CUDA); - DECLARE_PLATFORM(conv2d_bp, ENGINE_CUDA); +DECLARE_PLATFORM(conv2d, ENGINE_CUDA); +DECLARE_PLATFORM(conv2d_bp, ENGINE_CUDA); - DECLARE_PLATFORM(conv3dnew, ENGINE_CUDA); - DECLARE_PLATFORM(conv3dnew_bp, ENGINE_CUDA); +DECLARE_PLATFORM(conv3dnew, ENGINE_CUDA); +DECLARE_PLATFORM(conv3dnew_bp, ENGINE_CUDA); - DECLARE_PLATFORM(depthwise_conv2d, ENGINE_CUDA); - DECLARE_PLATFORM(depthwise_conv2d_bp, ENGINE_CUDA); +DECLARE_PLATFORM(depthwise_conv2d, ENGINE_CUDA); +DECLARE_PLATFORM(depthwise_conv2d_bp, ENGINE_CUDA); - DECLARE_PLATFORM(batchnorm, ENGINE_CUDA); - DECLARE_PLATFORM(batchnorm_bp, ENGINE_CUDA); +DECLARE_PLATFORM(batchnorm, ENGINE_CUDA); +DECLARE_PLATFORM(batchnorm_bp, ENGINE_CUDA); - DECLARE_PLATFORM(avgpool2d, ENGINE_CUDA); - DECLARE_PLATFORM(avgpool2d_bp, ENGINE_CUDA); +DECLARE_PLATFORM(avgpool2d, ENGINE_CUDA); +DECLARE_PLATFORM(avgpool2d_bp, ENGINE_CUDA); - DECLARE_PLATFORM(maxpool2d, ENGINE_CUDA); - DECLARE_PLATFORM(maxpool2d_bp, ENGINE_CUDA); +DECLARE_PLATFORM(maxpool2d, ENGINE_CUDA); +DECLARE_PLATFORM(maxpool2d_bp, ENGINE_CUDA); - DECLARE_PLATFORM(avgpool3dnew, ENGINE_CUDA); - DECLARE_PLATFORM(avgpool3dnew_bp, ENGINE_CUDA); +DECLARE_PLATFORM(avgpool3dnew, ENGINE_CUDA); +DECLARE_PLATFORM(avgpool3dnew_bp, ENGINE_CUDA); - DECLARE_PLATFORM(maxpool3dnew, ENGINE_CUDA); - DECLARE_PLATFORM(maxpool3dnew_bp, ENGINE_CUDA); +DECLARE_PLATFORM(maxpool3dnew, ENGINE_CUDA); +DECLARE_PLATFORM(maxpool3dnew_bp, ENGINE_CUDA); ////////////////////////////////////////////////////////////////////////// FORCEINLINE cudnnDataType_t cudnnDataType(sd::DataType dataType) { - switch (dataType) { - case sd::DataType::FLOAT32: - return CUDNN_DATA_FLOAT; - case sd::DataType::DOUBLE: - return CUDNN_DATA_DOUBLE; - case sd::DataType::HALF: - return CUDNN_DATA_HALF; - case sd::DataType::INT32: - return CUDNN_DATA_INT32; - case sd::DataType::INT8: - return CUDNN_DATA_INT8; - default: - throw datatype_exception::build("Unsupported data type", dataType); - } + switch (dataType) { + case sd::DataType::FLOAT32: + return CUDNN_DATA_FLOAT; + case sd::DataType::DOUBLE: + return CUDNN_DATA_DOUBLE; + case sd::DataType::HALF: + return CUDNN_DATA_HALF; + case sd::DataType::INT32: + return CUDNN_DATA_INT32; + case sd::DataType::INT8: + return CUDNN_DATA_INT8; + default: + throw datatype_exception::build("Unsupported data type", dataType); + } } ////////////////////////////////////////////////////////////////////////// -void checkConv2dCUDNNPadAsymmetric(NDArray* &input, NDArray* &gradI, - const int iH, const int iW, - const int oH, const int oW, - const int kH, const int kW, - const int sH, const int sW, - const int pH, const int pW, - const int dH, const int dW, - const bool isNCHW); +void checkConv2dCUDNNPadAsymmetric(NDArray*& input, NDArray*& gradI, + const int iH, const int iW, const int oH, + const int oW, const int kH, const int kW, + const int sH, const int sW, const int pH, + const int pW, const int dH, const int dW, + const bool isNCHW); ////////////////////////////////////////////////////////////////////////// -void checkConv3dCUDNNPadAsymmetric(NDArray* &input, NDArray* &gradI, - const int iD, const int iH, const int iW, - const int oD, const int oH, const int oW, - const int kD, const int kH, const int kW, - const int sD, const int sH, const int sW, - const int pD, const int pH, const int pW, - const int dD, const int dH, const int dW, - const bool isNCDHW); +void checkConv3dCUDNNPadAsymmetric(NDArray*& input, NDArray*& gradI, + const int iD, const int iH, const int iW, + const int oD, const int oH, const int oW, + const int kD, const int kH, const int kW, + const int sD, const int sH, const int sW, + const int pD, const int pH, const int pW, + const int dD, const int dH, const int dW, + const bool isNCDHW); ////////////////////////////////////////////////////////////////////////// -void pooling2dCUDNN(const LaunchContext* context, - const NDArray* input, NDArray* output, - const int kH, const int kW, - const int sH, const int sW, - const int pH, const int pW, - const int dH, const int dW, - const bool isNCHW, const cudnnPoolingMode_t mode); +void pooling2dCUDNN(const LaunchContext* context, const NDArray* input, + NDArray* output, const int kH, const int kW, const int sH, + const int sW, const int pH, const int pW, const int dH, + const int dW, const bool isNCHW, + const cudnnPoolingMode_t mode); ////////////////////////////////////////////////////////////////////////// -void pooling2dBpCUDNN(const LaunchContext* context, - const NDArray* input, const NDArray* gradO, - NDArray* gradI, - const int kH, const int kW, - const int sH, const int sW, - const int pH, const int pW, - const int dH, const int dW, +void pooling2dBpCUDNN(const LaunchContext* context, const NDArray* input, + const NDArray* gradO, NDArray* gradI, const int kH, + const int kW, const int sH, const int sW, const int pH, + const int pW, const int dH, const int dW, const bool isNCHW, const cudnnPoolingMode_t mode); ////////////////////////////////////////////////////////////////////////// -void pooling3dCUDNN(const LaunchContext* context, - const NDArray* input, NDArray* output, - const int kD, const int kH, const int kW, - const int sD, const int sH, const int sW, - const int pD, const int pH, const int pW, - const int dD, const int dH, const int dW, - const bool isNCDHW, const cudnnPoolingMode_t mode); +void pooling3dCUDNN(const LaunchContext* context, const NDArray* input, + NDArray* output, const int kD, const int kH, const int kW, + const int sD, const int sH, const int sW, const int pD, + const int pH, const int pW, const int dD, const int dH, + const int dW, const bool isNCDHW, + const cudnnPoolingMode_t mode); ////////////////////////////////////////////////////////////////////////// -void pooling3dBpCUDNN(const LaunchContext* context, - const NDArray* input, const NDArray* gradO, - NDArray* gradI, - const int kD, const int kH, const int kW, - const int sD, const int sH, const int sW, - const int pD, const int pH, const int pW, - const int dD, const int dH, const int dW, - const bool isNCDHW, const cudnnPoolingMode_t mode); - -} -} -} - -#endif //SD_CUDNNUTILS_H +void pooling3dBpCUDNN(const LaunchContext* context, const NDArray* input, + const NDArray* gradO, NDArray* gradI, const int kD, + const int kH, const int kW, const int sD, const int sH, + const int sW, const int pD, const int pH, const int pW, + const int dD, const int dH, const int dW, + const bool isNCDHW, const cudnnPoolingMode_t mode); + +} // namespace platforms +} // namespace ops +} // namespace sd + +#endif // SD_CUDNNUTILS_H diff --git a/libnd4j/include/ops/declarable/platform/cudnn/depthwiseConv2d.cu b/libnd4j/include/ops/declarable/platform/cudnn/depthwiseConv2d.cu index c268961ce840..8a09c6f33d57 100644 --- a/libnd4j/include/ops/declarable/platform/cudnn/depthwiseConv2d.cu +++ b/libnd4j/include/ops/declarable/platform/cudnn/depthwiseConv2d.cu @@ -18,453 +18,755 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // +#include #include "cudnnUtils.h" -#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace platforms { - ////////////////////////////////////////////////////////////////////////// static void depthwiseConv2dCUDNN(const LaunchContext* context, - const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, - const int kH, const int kW, - const int sH, const int sW, - const int pH, const int pW, - const int dH, const int dW, - const int paddingMode, const bool isNCHW) { - - // cudnn supports only following case: mC = 1, oC = iC (groupCount == iC) - - // input [bS, iC, iH, iW] nchw or [bS, iH, iW, iC] nhwc - // weights [iC, mC, kH, kW] - // bias [oC], may be nullptr - // output [bS, oC, oH, oW] nchw or [bS, oH, oW, oC] nhwc - // oC = iC*mC - - int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; - int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, 0, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); - mC = weights->sizeAt(1); - - auto handle = reinterpret_cast(context->getCuDnnHandle()); - cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream()); - if (err != 0) throw sd::cuda_exception::build("depthwiseConv2dCUDNN: can't set stream for cuDNN", err); - - cudnnTensorFormat_t format = isNCHW ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC; - - // input descriptor - cudnnTensorDescriptor_t x; - cudnnCreateTensorDescriptor(&x); - if(input->ews() == 1 && input->ordering() == 'c') - err = cudnnSetTensor4dDescriptor(x, format, cudnnDataType(input->dataType()), bS, iC, iH, iW); - else - err = cudnnSetTensor4dDescriptorEx(x, cudnnDataType(input->dataType()), bS, iC, iH, iW, input->strideAt(0), input->strideAt(indIOioC), input->strideAt(indIiH), input->strideAt(indIiH + 1)); - if (err != 0) throw sd::cuda_exception::build("depthwiseConv2dCUDNN: cudnnSetTensor4dDescriptor/cudnnSetTensor4dDescriptorEx for input failed", err); - - // weights descriptor - cudnnFilterDescriptor_t w; - cudnnCreateFilterDescriptor(&w); - err = cudnnSetFilter4dDescriptor(w, cudnnDataType(weights->dataType()), CUDNN_TENSOR_NCHW, iC, mC, kH, kW); - if(err != 0) throw sd::cuda_exception::build("depthwiseConv2dCUDNN: cudnnSetFilter4dDescriptor failed", err); - - // output descriptor - cudnnTensorDescriptor_t z; - cudnnCreateTensorDescriptor(&z); - if(output->ews() == 1 && output->ordering() == 'c') - err = cudnnSetTensor4dDescriptor(z, format, cudnnDataType(output->dataType()), bS, oC, oH, oW); - else - err = cudnnSetTensor4dDescriptorEx(z, cudnnDataType(output->dataType()), bS, oC, oH, oW, output->strideAt(0), output->strideAt(indIOioC), output->strideAt(indOoH), output->strideAt(indOoH + 1)); - if (err != 0) throw sd::cuda_exception::build("depthwiseConv2dCUDNN: cudnnSetTensor4dDescriptor/cudnnSetTensor4dDescriptorEx for output failed", err); - - // description of convolution - cudnnConvolutionDescriptor_t conv; - cudnnCreateConvolutionDescriptor(&conv); - err = cudnnSetConvolution2dDescriptor(conv, pH, pW, sH, sW, dH, dW, CUDNN_CROSS_CORRELATION, cudnnDataType(output->dataType())); - if (err != 0) throw sd::cuda_exception::build("depthwiseConv2dCUDNN: cudnnSetConvolution2dDescriptor failed", err); - err = cudnnSetConvolutionGroupCount(conv, iC); // set number of groups (depthwise mode) in description of convolution, groupCount == iC - if (err != 0) throw sd::cuda_exception::build("depthwiseConv2dCUDNN: cudnnSetConvolutionGroupCount failed", err); - - // algorithm description - cudnnConvolutionFwdAlgo_t algo; - err = cudnnGetConvolutionForwardAlgorithm(*handle, x, w, conv, z, CUDNN_CONVOLUTION_FWD_PREFER_FASTEST, 0, &algo); - if (err != 0) throw sd::cuda_exception::build("depthwiseConv2dCUDNN: cudnnGetConvolutionForwardAlgorithm failed", err); - - // allocate auxiliary device memory, abbreviation ws means workspace - size_t wsSize; - err = cudnnGetConvolutionForwardWorkspaceSize(*handle, x, w, conv, z, algo, &wsSize); - if (err != 0) throw sd::cuda_exception::build("depthwiseConv2dCUDNN: cudnnGetConvolutionForwardWorkspaceSize failed", err); - void* wsData; - auto cudaErr = cudaMalloc(&wsData, wsSize); - if (cudaErr != 0) throw sd::cuda_exception::build("depthwiseConv2dCUDNN: cudaMalloc for auxiliary workspace memory failed", cudaErr); - - // provide scaling parameters - const float alpha32(1), beta32(0); - const double alpha64(1), beta64(0); - const void* alpha = output->sizeOfT() <= 4 ? reinterpret_cast(&alpha32) : reinterpret_cast(&alpha64); - const void* beta = output->sizeOfT() <= 4 ? reinterpret_cast(&beta32) : reinterpret_cast(&beta64); - - NDArray::prepareSpecialUse({output}, {input, weights, bias}); - - // run calculation - err = cudnnConvolutionForward(*handle, alpha, x, input->specialBuffer(), w, weights->specialBuffer(), conv, algo, wsData, wsSize, beta, z, output->specialBuffer()); - if (err != 0) throw sd::cuda_exception::build("depthwiseConv2dCUDNN: cudnnConvolutionForward failed", err); - - // add bias if it is present - if (bias != nullptr) { - - cudnnTensorDescriptor_t b; - cudnnCreateTensorDescriptor(&b); - // err = cudnnSetTensor4dDescriptor(b, format, cudnnDataType(bias->dataType()), 1, isNCHW ? bias->lengthOf() : 1, 1, isNCHW ? 1: bias->lengthOf()); - err = cudnnSetTensor4dDescriptor(b, CUDNN_TENSOR_NCHW, cudnnDataType(bias->dataType()), 1, oC, 1, 1); - if (err != 0) throw sd::cuda_exception::build("depthwiseConv2dCUDNN: cudnnSetTensor4dDescriptor for bias failed", err); - err = cudnnAddTensor(*handle, alpha, b, bias->specialBuffer(), alpha, z, output->specialBuffer()); - if (err != 0) throw sd::cuda_exception::build("depthwiseConv2dCUDNN: cudnnAddTensor bias failed", err); - } - - // cudaErr = cudaStreamSynchronize(*context->getCudaStream()); - // if (cudaErr != 0) - // throw cuda_exception::build("depthwiseConv2dCUDNN: cudaStreamSynchronize failed !", cudaErr); - - cudaErr = cudaFree(wsData); - if (cudaErr != 0) throw sd::cuda_exception::build("depthwiseConv2dCUDNN: cudaFree for auxiliary workspace memory failed", cudaErr); - - NDArray::registerSpecialUse({output}, {input, weights, bias}); + const NDArray* input, const NDArray* weights, + const NDArray* bias, NDArray* output, + const int kH, const int kW, const int sH, + const int sW, const int pH, const int pW, + const int dH, const int dW, + const int paddingMode, const bool isNCHW) { + // cudnn supports only following case: mC = 1, oC = iC (groupCount == iC) + + // input [bS, iC, iH, iW] nchw or [bS, iH, iW, iC] nhwc + // weights [iC, mC, kH, kW] + // bias [oC], may be nullptr + // output [bS, oC, oH, oW] nchw or [bS, oH, oW, oC] nhwc + // oC = iC*mC + + int bS, iC, iH, iW, mC, oC, oH, + oW; // batch size, input channels, input height/width, output channels, + // output height/width; + int indIOioC, indIiH, indWmC, indWiC, indWkH, + indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, 0, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, + indWiC, indWmC, indWkH, indOoH); + mC = weights->sizeAt(1); + + auto handle = reinterpret_cast(context->getCuDnnHandle()); + cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream()); + if (err != 0) + throw sd::cuda_exception::build( + "depthwiseConv2dCUDNN: can't set stream for cuDNN", err); + + cudnnTensorFormat_t format = isNCHW ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC; + + // input descriptor + cudnnTensorDescriptor_t x; + cudnnCreateTensorDescriptor(&x); + if (input->ews() == 1 && input->ordering() == 'c') + err = cudnnSetTensor4dDescriptor( + x, format, cudnnDataType(input->dataType()), bS, iC, iH, iW); + else + err = cudnnSetTensor4dDescriptorEx( + x, cudnnDataType(input->dataType()), bS, iC, iH, iW, input->strideAt(0), + input->strideAt(indIOioC), input->strideAt(indIiH), + input->strideAt(indIiH + 1)); + if (err != 0) + throw sd::cuda_exception::build( + "depthwiseConv2dCUDNN: " + "cudnnSetTensor4dDescriptor/cudnnSetTensor4dDescriptorEx for input " + "failed", + err); + + // weights descriptor + cudnnFilterDescriptor_t w; + cudnnCreateFilterDescriptor(&w); + err = cudnnSetFilter4dDescriptor(w, cudnnDataType(weights->dataType()), + CUDNN_TENSOR_NCHW, iC, mC, kH, kW); + if (err != 0) + throw sd::cuda_exception::build( + "depthwiseConv2dCUDNN: cudnnSetFilter4dDescriptor failed", err); + + // output descriptor + cudnnTensorDescriptor_t z; + cudnnCreateTensorDescriptor(&z); + if (output->ews() == 1 && output->ordering() == 'c') + err = cudnnSetTensor4dDescriptor( + z, format, cudnnDataType(output->dataType()), bS, oC, oH, oW); + else + err = cudnnSetTensor4dDescriptorEx( + z, cudnnDataType(output->dataType()), bS, oC, oH, oW, + output->strideAt(0), output->strideAt(indIOioC), + output->strideAt(indOoH), output->strideAt(indOoH + 1)); + if (err != 0) + throw sd::cuda_exception::build( + "depthwiseConv2dCUDNN: " + "cudnnSetTensor4dDescriptor/cudnnSetTensor4dDescriptorEx for output " + "failed", + err); + + // description of convolution + cudnnConvolutionDescriptor_t conv; + cudnnCreateConvolutionDescriptor(&conv); + err = cudnnSetConvolution2dDescriptor(conv, pH, pW, sH, sW, dH, dW, + CUDNN_CROSS_CORRELATION, + cudnnDataType(output->dataType())); + if (err != 0) + throw sd::cuda_exception::build( + "depthwiseConv2dCUDNN: cudnnSetConvolution2dDescriptor failed", err); + err = cudnnSetConvolutionGroupCount( + conv, iC); // set number of groups (depthwise mode) in description of + // convolution, groupCount == iC + if (err != 0) + throw sd::cuda_exception::build( + "depthwiseConv2dCUDNN: cudnnSetConvolutionGroupCount failed", err); + + // algorithm description + cudnnConvolutionFwdAlgo_t algo; + err = cudnnGetConvolutionForwardAlgorithm( + *handle, x, w, conv, z, CUDNN_CONVOLUTION_FWD_PREFER_FASTEST, 0, &algo); + if (err != 0) + throw sd::cuda_exception::build( + "depthwiseConv2dCUDNN: cudnnGetConvolutionForwardAlgorithm failed", + err); + + // allocate auxiliary device memory, abbreviation ws means workspace + size_t wsSize; + err = cudnnGetConvolutionForwardWorkspaceSize(*handle, x, w, conv, z, algo, + &wsSize); + if (err != 0) + throw sd::cuda_exception::build( + "depthwiseConv2dCUDNN: cudnnGetConvolutionForwardWorkspaceSize failed", + err); + void* wsData; + auto cudaErr = cudaMalloc(&wsData, wsSize); + if (cudaErr != 0) + throw sd::cuda_exception::build( + "depthwiseConv2dCUDNN: cudaMalloc for auxiliary workspace memory " + "failed", + cudaErr); + + // provide scaling parameters + const float alpha32(1), beta32(0); + const double alpha64(1), beta64(0); + const void* alpha = output->sizeOfT() <= 4 + ? reinterpret_cast(&alpha32) + : reinterpret_cast(&alpha64); + const void* beta = output->sizeOfT() <= 4 + ? reinterpret_cast(&beta32) + : reinterpret_cast(&beta64); + + NDArray::prepareSpecialUse({output}, {input, weights, bias}); + + // run calculation + err = cudnnConvolutionForward(*handle, alpha, x, input->specialBuffer(), w, + weights->specialBuffer(), conv, algo, wsData, + wsSize, beta, z, output->specialBuffer()); + if (err != 0) + throw sd::cuda_exception::build( + "depthwiseConv2dCUDNN: cudnnConvolutionForward failed", err); + + // add bias if it is present + if (bias != nullptr) { + cudnnTensorDescriptor_t b; + cudnnCreateTensorDescriptor(&b); + // err = cudnnSetTensor4dDescriptor(b, format, + // cudnnDataType(bias->dataType()), 1, isNCHW ? bias->lengthOf() : 1, 1, + // isNCHW ? 1: bias->lengthOf()); + err = cudnnSetTensor4dDescriptor( + b, CUDNN_TENSOR_NCHW, cudnnDataType(bias->dataType()), 1, oC, 1, 1); + if (err != 0) + throw sd::cuda_exception::build( + "depthwiseConv2dCUDNN: cudnnSetTensor4dDescriptor for bias failed", + err); + err = cudnnAddTensor(*handle, alpha, b, bias->specialBuffer(), alpha, z, + output->specialBuffer()); + if (err != 0) + throw sd::cuda_exception::build( + "depthwiseConv2dCUDNN: cudnnAddTensor bias failed", err); + } + + // cudaErr = cudaStreamSynchronize(*context->getCudaStream()); + // if (cudaErr != 0) + // throw cuda_exception::build("depthwiseConv2dCUDNN: + // cudaStreamSynchronize failed !", cudaErr); + + cudaErr = cudaFree(wsData); + if (cudaErr != 0) + throw sd::cuda_exception::build( + "depthwiseConv2dCUDNN: cudaFree for auxiliary workspace memory failed", + cudaErr); + + NDArray::registerSpecialUse({output}, {input, weights, bias}); } ////////////////////////////////////////////////////////////////////////// static void depthwiseConv2dBpCUDNN(const LaunchContext* context, - const NDArray* input, const NDArray* weights, const NDArray* gradO, - NDArray* gradI, NDArray* gradW, NDArray* gradB, - const int kH, const int kW, - const int sH, const int sW, - const int pH, const int pW, - const int dH, const int dW, - const int paddingMode, const bool isNCHW) { - - // cudnn supports only following case: mC = 1, oC = iC (groupCount == iC) - - // input, gradI [bS, iC, iH, iW] nchw or [bS, iH, iW, iC] nhwc - // weights, gradW [iC, mC, kH, kW] - // gradB [oC], may be nullptr - // gradO [bS, oC, oH, oW] nchw or [bS, oH, oW, oC] nhwc - // oC = iC*mC - - int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; - int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, 0, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); - mC = weights->sizeAt(1); - - auto handle = reinterpret_cast(context->getCuDnnHandle()); - cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream()); - if (err != 0) throw sd::cuda_exception::build("depthwiseConv2dBpCUDNN: can't set stream for cuDNN", err); - - cudnnTensorFormat_t format = isNCHW ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC; - - // input descriptor - cudnnTensorDescriptor_t x; - cudnnCreateTensorDescriptor(&x); - if(input->ews() == 1 && input->ordering() == 'c') - err = cudnnSetTensor4dDescriptor(x, format, cudnnDataType(input->dataType()), bS, iC, iH, iW); - else - err = cudnnSetTensor4dDescriptorEx(x, cudnnDataType(input->dataType()), bS, iC, iH, iW, input->strideAt(0), input->strideAt(indIOioC), input->strideAt(indIiH), input->strideAt(indIiH + 1)); - if (err != 0) throw sd::cuda_exception::build("depthwiseConv2dBpCUDNN: cudnnSetTensor4dDescriptor/cudnnSetTensor4dDescriptorEx for input failed", err); - - // gradO descriptor - cudnnTensorDescriptor_t dz; - cudnnCreateTensorDescriptor(&dz); - if(gradO->ews() == 1 && gradO->ordering() == 'c') - err = cudnnSetTensor4dDescriptor(dz, format, cudnnDataType(gradO->dataType()), bS, oC, oH, oW); - else - err = cudnnSetTensor4dDescriptorEx(dz, cudnnDataType(gradO->dataType()), bS, oC, oH, oW, gradO->strideAt(0), gradO->strideAt(indIOioC), gradO->strideAt(indOoH), gradO->strideAt(indOoH + 1)); - if (err != 0) throw sd::cuda_exception::build("depthwiseConv2dBpCUDNN: cudnnSetTensor4dDescriptor/cudnnSetTensor4dDescriptorEx for gradO failed", err); - - // gradI descriptor - cudnnTensorDescriptor_t dx; - cudnnCreateTensorDescriptor(&dx); - if(gradI->ews() == 1 && gradI->ordering() == 'c') - err = cudnnSetTensor4dDescriptor(dx, format, cudnnDataType(gradI->dataType()), bS, iC, iH, iW); - else - err = cudnnSetTensor4dDescriptorEx(dx, cudnnDataType(gradI->dataType()), bS, iC, iH, iW, gradI->strideAt(0), gradI->strideAt(indIOioC), gradI->strideAt(indIiH), gradI->strideAt(indIiH + 1)); - if (err != 0) throw sd::cuda_exception::build("depthwiseConv2dBpCUDNN: cudnnSetTensor4dDescriptor/cudnnSetTensor4dDescriptorEx for gradI failed", err); - - // gradW descriptor - cudnnFilterDescriptor_t dw; - cudnnCreateFilterDescriptor(&dw); - err = cudnnSetFilter4dDescriptor(dw, cudnnDataType(gradW->dataType()), CUDNN_TENSOR_NCHW, iC, mC, kH, kW); - if(err != 0) throw sd::cuda_exception::build("depthwiseConv2dBpCUDNN: cudnnSetFilter4dDescriptor gradW failed", err); - - // description of convolution - cudnnConvolutionDescriptor_t conv; - cudnnCreateConvolutionDescriptor(&conv); - err = cudnnSetConvolution2dDescriptor(conv, pH, pW, sH, sW, dH, dW, CUDNN_CROSS_CORRELATION, cudnnDataType(gradO->dataType())); - if (err != 0) throw sd::cuda_exception::build("depthwiseConv2dBpCUDNN: cudnnSetConvolution2dDescriptor failed", err); - err = cudnnSetConvolutionGroupCount(conv, iC); // set number of groups (depthwise mode) in description of convolution, groupCount == iC - if (err != 0) throw sd::cuda_exception::build("depthwiseConv2dBpCUDNN: cudnnSetConvolutionGroupCount failed", err); - - // gradW algorithm description - cudnnConvolutionBwdFilterAlgo_t algoGradW; - err = cudnnGetConvolutionBackwardFilterAlgorithm(*handle, x, dz, conv, dw, CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST, 0, &algoGradW); - if (err != 0) throw sd::cuda_exception::build("depthwiseConv2dBpCUDNN: cudnnGetConvolutionBackwardFilterAlgorithm failed", err); - - // gradI algorithm description - cudnnConvolutionBwdDataAlgo_t algoGradI; - err = cudnnGetConvolutionBackwardDataAlgorithm(*handle, dw, dz, conv, x, CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST, 0, &algoGradI); - if (err != 0) throw sd::cuda_exception::build("depthwiseConv2dBpCUDNN: cudnnGetConvolutionBackwardDataAlgorithm failed", err); - - // allocate auxiliary device memory for gradW calculation, abbreviation ws means workspace - size_t wsGradWSize; - err = cudnnGetConvolutionBackwardFilterWorkspaceSize(*handle, x, dz, conv, dw, algoGradW, &wsGradWSize); - if (err != 0) throw sd::cuda_exception::build("depthwiseConv2dBpCUDNN: cudnnGetConvolutionBackwardFilterWorkspaceSize failed", err); - void* wsGradWData; - auto cudaErr = cudaMalloc(&wsGradWData, wsGradWSize); - if (cudaErr != 0) throw sd::cuda_exception::build("depthwiseConv2dBpCUDNN: cudaMalloc for auxiliary workspace memory wsGradWData failed", cudaErr); - - // allocate auxiliary device memory for gradI calculation, abbreviation ws means workspace - size_t wsGradISize; - err = cudnnGetConvolutionBackwardDataWorkspaceSize(*handle, dw, dz, conv, dx, algoGradI, &wsGradISize); - if (err != 0) throw sd::cuda_exception::build("depthwiseConv2dBpCUDNN: cudnnGetConvolutionBackwardDataWorkspaceSize failed", err); - void* wsGradIData; - cudaErr = cudaMalloc(&wsGradIData, wsGradISize); - if (cudaErr != 0) throw sd::cuda_exception::build("depthwiseConv2dBpCUDNN: cudaMalloc for auxiliary workspace memory wsGradIData failed", cudaErr); - - // provide scaling parameters - const float alpha32(1), beta32(0); - const double alpha64(1), beta64(0); - const void* alpha = gradO->sizeOfT() <= 4 ? reinterpret_cast(&alpha32) : reinterpret_cast(&alpha64); - const void* beta = gradO->sizeOfT() <= 4 ? reinterpret_cast(&beta32) : reinterpret_cast(&beta64); - - NDArray::prepareSpecialUse({gradI, gradW, gradB}, {input, weights, gradO}); - - // run calculation for gradB (if not nullptr) - if(gradB != nullptr) { - cudnnTensorDescriptor_t db; - cudnnCreateTensorDescriptor(&db); - // err = cudnnSetTensor4dDescriptor(db, format, cudnnDataType(gradB->dataType()), 1, isNCHW ? gradB->lengthOf() : 1, 1, isNCHW ? 1: gradB->lengthOf()); - err = cudnnSetTensor4dDescriptor(db, CUDNN_TENSOR_NCHW, cudnnDataType(gradB->dataType()), 1, oC, 1, 1); - if (err != 0) throw sd::cuda_exception::build("depthwiseConv2dBpCUDNN: cudnnSetTensor4dDescriptor for gradB failed", err); - - err = cudnnConvolutionBackwardBias(*handle, alpha, dz, gradO->specialBuffer(), beta, db, gradB->specialBuffer()); - if (err != 0) throw sd::cuda_exception::build("depthwiseConv2dBpCUDNN: cudnnConvolutionBackwardBias failed", err); - } - - // run calculation for gradW - err = cudnnConvolutionBackwardFilter(*handle, alpha, x, input->specialBuffer(), dz, gradO->specialBuffer(), conv, algoGradW, wsGradWData, wsGradWSize, beta, dw, gradW->specialBuffer()); - if (err != 0) throw sd::cuda_exception::build("depthwiseConv2dBpCUDNN: cudnnConvolutionBackwardFilter failed", err); - - // run calculation for gradI - err = cudnnConvolutionBackwardData(*handle, alpha, dw, weights->specialBuffer(), dz, gradO->specialBuffer(), conv, algoGradI, wsGradIData, wsGradISize, beta, dx, gradI->specialBuffer()); - if (err != 0) throw sd::cuda_exception::build("depthwiseConv2dBpCUDNN: cudnnConvolutionBackwardData failed", err); - - // cudaErr = cudaStreamSynchronize(*context->getCudaStream()); - // if (cudaErr != 0) - // throw cuda_exception::build("depthwiseConv2dBpCUDNN: cudaStreamSynchronize failed !", cudaErr); - - cudaErr = cudaFree(wsGradWData); - if (cudaErr != 0) throw sd::cuda_exception::build("depthwiseConv2dBpCUDNN: cudaFree for auxiliary workspace memory wsGradWData failed", cudaErr); - cudaErr = cudaFree(wsGradIData); - if (cudaErr != 0) throw sd::cuda_exception::build("depthwiseConv2dBpCUDNN: cudaFree for auxiliary workspace memory wsGradIData failed", cudaErr); - - NDArray::registerSpecialUse({gradI, gradW, gradB}, {input, weights, gradO}); + const NDArray* input, const NDArray* weights, + const NDArray* gradO, NDArray* gradI, + NDArray* gradW, NDArray* gradB, const int kH, + const int kW, const int sH, const int sW, + const int pH, const int pW, const int dH, + const int dW, const int paddingMode, + const bool isNCHW) { + // cudnn supports only following case: mC = 1, oC = iC (groupCount == iC) + + // input, gradI [bS, iC, iH, iW] nchw or [bS, iH, iW, iC] nhwc + // weights, gradW [iC, mC, kH, kW] + // gradB [oC], may be nullptr + // gradO [bS, oC, oH, oW] nchw or [bS, oH, oW, oC] nhwc + // oC = iC*mC + + int bS, iC, iH, iW, mC, oC, oH, + oW; // batch size, input channels, input height/width, output channels, + // output height/width; + int indIOioC, indIiH, indWmC, indWiC, indWkH, + indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, 0, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, + indWiC, indWmC, indWkH, indOoH); + mC = weights->sizeAt(1); + + auto handle = reinterpret_cast(context->getCuDnnHandle()); + cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream()); + if (err != 0) + throw sd::cuda_exception::build( + "depthwiseConv2dBpCUDNN: can't set stream for cuDNN", err); + + cudnnTensorFormat_t format = isNCHW ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC; + + // input descriptor + cudnnTensorDescriptor_t x; + cudnnCreateTensorDescriptor(&x); + if (input->ews() == 1 && input->ordering() == 'c') + err = cudnnSetTensor4dDescriptor( + x, format, cudnnDataType(input->dataType()), bS, iC, iH, iW); + else + err = cudnnSetTensor4dDescriptorEx( + x, cudnnDataType(input->dataType()), bS, iC, iH, iW, input->strideAt(0), + input->strideAt(indIOioC), input->strideAt(indIiH), + input->strideAt(indIiH + 1)); + if (err != 0) + throw sd::cuda_exception::build( + "depthwiseConv2dBpCUDNN: " + "cudnnSetTensor4dDescriptor/cudnnSetTensor4dDescriptorEx for input " + "failed", + err); + + // gradO descriptor + cudnnTensorDescriptor_t dz; + cudnnCreateTensorDescriptor(&dz); + if (gradO->ews() == 1 && gradO->ordering() == 'c') + err = cudnnSetTensor4dDescriptor( + dz, format, cudnnDataType(gradO->dataType()), bS, oC, oH, oW); + else + err = cudnnSetTensor4dDescriptorEx( + dz, cudnnDataType(gradO->dataType()), bS, oC, oH, oW, + gradO->strideAt(0), gradO->strideAt(indIOioC), gradO->strideAt(indOoH), + gradO->strideAt(indOoH + 1)); + if (err != 0) + throw sd::cuda_exception::build( + "depthwiseConv2dBpCUDNN: " + "cudnnSetTensor4dDescriptor/cudnnSetTensor4dDescriptorEx for gradO " + "failed", + err); + + // gradI descriptor + cudnnTensorDescriptor_t dx; + cudnnCreateTensorDescriptor(&dx); + if (gradI->ews() == 1 && gradI->ordering() == 'c') + err = cudnnSetTensor4dDescriptor( + dx, format, cudnnDataType(gradI->dataType()), bS, iC, iH, iW); + else + err = cudnnSetTensor4dDescriptorEx( + dx, cudnnDataType(gradI->dataType()), bS, iC, iH, iW, + gradI->strideAt(0), gradI->strideAt(indIOioC), gradI->strideAt(indIiH), + gradI->strideAt(indIiH + 1)); + if (err != 0) + throw sd::cuda_exception::build( + "depthwiseConv2dBpCUDNN: " + "cudnnSetTensor4dDescriptor/cudnnSetTensor4dDescriptorEx for gradI " + "failed", + err); + + // gradW descriptor + cudnnFilterDescriptor_t dw; + cudnnCreateFilterDescriptor(&dw); + err = cudnnSetFilter4dDescriptor(dw, cudnnDataType(gradW->dataType()), + CUDNN_TENSOR_NCHW, iC, mC, kH, kW); + if (err != 0) + throw sd::cuda_exception::build( + "depthwiseConv2dBpCUDNN: cudnnSetFilter4dDescriptor gradW failed", err); + + // description of convolution + cudnnConvolutionDescriptor_t conv; + cudnnCreateConvolutionDescriptor(&conv); + err = cudnnSetConvolution2dDescriptor(conv, pH, pW, sH, sW, dH, dW, + CUDNN_CROSS_CORRELATION, + cudnnDataType(gradO->dataType())); + if (err != 0) + throw sd::cuda_exception::build( + "depthwiseConv2dBpCUDNN: cudnnSetConvolution2dDescriptor failed", err); + err = cudnnSetConvolutionGroupCount( + conv, iC); // set number of groups (depthwise mode) in description of + // convolution, groupCount == iC + if (err != 0) + throw sd::cuda_exception::build( + "depthwiseConv2dBpCUDNN: cudnnSetConvolutionGroupCount failed", err); + + // gradW algorithm description + cudnnConvolutionBwdFilterAlgo_t algoGradW; + err = cudnnGetConvolutionBackwardFilterAlgorithm( + *handle, x, dz, conv, dw, CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST, 0, + &algoGradW); + if (err != 0) + throw sd::cuda_exception::build( + "depthwiseConv2dBpCUDNN: cudnnGetConvolutionBackwardFilterAlgorithm " + "failed", + err); + + // gradI algorithm description + cudnnConvolutionBwdDataAlgo_t algoGradI; + err = cudnnGetConvolutionBackwardDataAlgorithm( + *handle, dw, dz, conv, x, CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST, 0, + &algoGradI); + if (err != 0) + throw sd::cuda_exception::build( + "depthwiseConv2dBpCUDNN: cudnnGetConvolutionBackwardDataAlgorithm " + "failed", + err); + + // allocate auxiliary device memory for gradW calculation, abbreviation ws + // means workspace + size_t wsGradWSize; + err = cudnnGetConvolutionBackwardFilterWorkspaceSize(*handle, x, dz, conv, dw, + algoGradW, &wsGradWSize); + if (err != 0) + throw sd::cuda_exception::build( + "depthwiseConv2dBpCUDNN: " + "cudnnGetConvolutionBackwardFilterWorkspaceSize failed", + err); + void* wsGradWData; + auto cudaErr = cudaMalloc(&wsGradWData, wsGradWSize); + if (cudaErr != 0) + throw sd::cuda_exception::build( + "depthwiseConv2dBpCUDNN: cudaMalloc for auxiliary workspace memory " + "wsGradWData failed", + cudaErr); + + // allocate auxiliary device memory for gradI calculation, abbreviation ws + // means workspace + size_t wsGradISize; + err = cudnnGetConvolutionBackwardDataWorkspaceSize(*handle, dw, dz, conv, dx, + algoGradI, &wsGradISize); + if (err != 0) + throw sd::cuda_exception::build( + "depthwiseConv2dBpCUDNN: cudnnGetConvolutionBackwardDataWorkspaceSize " + "failed", + err); + void* wsGradIData; + cudaErr = cudaMalloc(&wsGradIData, wsGradISize); + if (cudaErr != 0) + throw sd::cuda_exception::build( + "depthwiseConv2dBpCUDNN: cudaMalloc for auxiliary workspace memory " + "wsGradIData failed", + cudaErr); + + // provide scaling parameters + const float alpha32(1), beta32(0); + const double alpha64(1), beta64(0); + const void* alpha = gradO->sizeOfT() <= 4 + ? reinterpret_cast(&alpha32) + : reinterpret_cast(&alpha64); + const void* beta = gradO->sizeOfT() <= 4 + ? reinterpret_cast(&beta32) + : reinterpret_cast(&beta64); + + NDArray::prepareSpecialUse({gradI, gradW, gradB}, {input, weights, gradO}); + + // run calculation for gradB (if not nullptr) + if (gradB != nullptr) { + cudnnTensorDescriptor_t db; + cudnnCreateTensorDescriptor(&db); + // err = cudnnSetTensor4dDescriptor(db, format, + // cudnnDataType(gradB->dataType()), 1, isNCHW ? gradB->lengthOf() : 1, 1, + // isNCHW ? 1: gradB->lengthOf()); + err = cudnnSetTensor4dDescriptor( + db, CUDNN_TENSOR_NCHW, cudnnDataType(gradB->dataType()), 1, oC, 1, 1); + if (err != 0) + throw sd::cuda_exception::build( + "depthwiseConv2dBpCUDNN: cudnnSetTensor4dDescriptor for gradB failed", + err); + + err = + cudnnConvolutionBackwardBias(*handle, alpha, dz, gradO->specialBuffer(), + beta, db, gradB->specialBuffer()); + if (err != 0) + throw sd::cuda_exception::build( + "depthwiseConv2dBpCUDNN: cudnnConvolutionBackwardBias failed", err); + } + + // run calculation for gradW + err = cudnnConvolutionBackwardFilter( + *handle, alpha, x, input->specialBuffer(), dz, gradO->specialBuffer(), + conv, algoGradW, wsGradWData, wsGradWSize, beta, dw, + gradW->specialBuffer()); + if (err != 0) + throw sd::cuda_exception::build( + "depthwiseConv2dBpCUDNN: cudnnConvolutionBackwardFilter failed", err); + + // run calculation for gradI + err = cudnnConvolutionBackwardData( + *handle, alpha, dw, weights->specialBuffer(), dz, gradO->specialBuffer(), + conv, algoGradI, wsGradIData, wsGradISize, beta, dx, + gradI->specialBuffer()); + if (err != 0) + throw sd::cuda_exception::build( + "depthwiseConv2dBpCUDNN: cudnnConvolutionBackwardData failed", err); + + // cudaErr = cudaStreamSynchronize(*context->getCudaStream()); + // if (cudaErr != 0) + // throw cuda_exception::build("depthwiseConv2dBpCUDNN: + // cudaStreamSynchronize failed !", cudaErr); + + cudaErr = cudaFree(wsGradWData); + if (cudaErr != 0) + throw sd::cuda_exception::build( + "depthwiseConv2dBpCUDNN: cudaFree for auxiliary workspace memory " + "wsGradWData failed", + cudaErr); + cudaErr = cudaFree(wsGradIData); + if (cudaErr != 0) + throw sd::cuda_exception::build( + "depthwiseConv2dBpCUDNN: cudaFree for auxiliary workspace memory " + "wsGradIData failed", + cudaErr); + + NDArray::registerSpecialUse({gradI, gradW, gradB}, {input, weights, gradO}); } ////////////////////////////////////////////////////////////////////////// PLATFORM_IMPL(depthwise_conv2d, ENGINE_CUDA) { - - auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] - auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] = iC*mC - - auto output = OUTPUT_VARIABLE(0); // [bS, oH, oW, iC*mC] (NHWC) or [bS, iC*mC, oH, oW] (NCHW) - - REQUIRE_TRUE(input->rankOf() == 4, 0, "DEPTHWISECONV2D CUDNN OP: rank of input array must be equal to 4, but got %i instead !", input->rankOf()); - REQUIRE_TRUE(weights->rankOf() == 4, 0, "DEPTHWISECONV2D CUDNN OP: rank of weights array must be equal to 4, but got %i instead !", weights->rankOf()); - - int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(weights->sizeAt(0));// filter(kernel) height - int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast(weights->sizeAt(1));// filter(kernel) width - int sH = INT_ARG(2); // strides height - int sW = INT_ARG(3); // strides width - int pH = INT_ARG(4); // paddings height - int pW = INT_ARG(5); // paddings width - int dH = INT_ARG(6); // dilations height - int dW = INT_ARG(7); // dilations width - int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME - int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC - int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - [mC, kH, kW, iC] - - int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width - int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); - mC = weights->sizeAt(indWmC); // channels multiplier - - ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode); - - std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, mC); - REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "DEPTHWISECONV2D CUDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); - REQUIRE_TRUE(output->sizeAt(indIOioC) == iC*mC, 0, "DEPTHWISECONV2D CUDNN OP: the output_channels must be equal to input_channels * channels_multiplier = %i !", iC*mC); - if (bias) - REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "DEPTHWISECONV2D CUDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); - - std::vector wPermut; // cudnn support format {oC, iC/groupCount, kH, kW} only, mC = 1, oC = iC (groupCount == iC) that is {iC, mC, kH, kW} in our case - if(0 == wFormat) - wPermut = {2,3,0,1}; // kH, kW, iC, mC -> iC, mC, kH, kW - else if(1 == wFormat) - wPermut = {1,0,2,3}; // mC, iC, kH, kW -> iC, mC, kH, kW - else - wPermut = {3,0,1,2}; // mC, kH, kW, iC -> iC, mC, kH, kW - - NDArray* newWeights = new NDArray(weights->ordering(), {iC, mC, kH, kW}, weights->dataType(), weights->getContext()); - newWeights->assign(weights->permute(wPermut)); - - NDArray* newInput = input; - NDArray* newGradI = nullptr; - if(paddingMode == 1) // in same paddingMode cudnn doesn't support asymmetric left/right top/bottopm paddings - checkConv2dCUDNNPadAsymmetric(newInput, newGradI, iH, iW, oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, isNCHW); - - depthwiseConv2dCUDNN(block.launchContext(), newInput, newWeights, bias, output, kH,kW,sH,sW,pH,pW,dH,dW, paddingMode, isNCHW); - - if(newInput != input) - delete newInput; - - delete newWeights; - - return Status::OK(); + auto input = + INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + auto weights = INPUT_VARIABLE( + 1); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] + auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] = iC*mC + + auto output = OUTPUT_VARIABLE( + 0); // [bS, oH, oW, iC*mC] (NHWC) or [bS, iC*mC, oH, oW] (NCHW) + + REQUIRE_TRUE(input->rankOf() == 4, 0, + "DEPTHWISECONV2D CUDNN OP: rank of input array must be equal to " + "4, but got %i instead !", + input->rankOf()); + REQUIRE_TRUE(weights->rankOf() == 4, 0, + "DEPTHWISECONV2D CUDNN OP: rank of weights array must be equal " + "to 4, but got %i instead !", + weights->rankOf()); + + int kH = INT_ARG(0) > 0 + ? INT_ARG(0) + : static_cast(weights->sizeAt(0)); // filter(kernel) height + int kW = INT_ARG(1) > 0 + ? INT_ARG(1) + : static_cast(weights->sizeAt(1)); // filter(kernel) width + int sH = INT_ARG(2); // strides height + int sW = INT_ARG(3); // strides width + int pH = INT_ARG(4); // paddings height + int pW = INT_ARG(5); // paddings width + int dH = INT_ARG(6); // dilations height + int dW = INT_ARG(7); // dilations width + int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME + int isNCHW = block.numI() > 9 + ? !INT_ARG(9) + : 1; // INT_ARG(9): 0-NCHW, 1-NHWC + int wFormat = block.numI() > 10 + ? INT_ARG(10) + : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - + // [mC, kH, kW, iC] + + int bS, iC, iH, iW, mC, oC, oH, + oW; // batch size, input channels, input height/width, channels + // multiplier(oC = iC*mC), output channels, output height/width + int indIOioC, indIiH, indWmC, indWiC, indWkH, + indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, + indIiH, indWiC, indWmC, indWkH, indOoH); + mC = weights->sizeAt(indWmC); // channels multiplier + + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, + dW, paddingMode); + + std::vector expectedWeightsShape = + ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, mC); + REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, + "DEPTHWISECONV2D CUDNN OP: wrong shape of weights array, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), + ShapeUtils::shapeAsString(weights).c_str()); + REQUIRE_TRUE(output->sizeAt(indIOioC) == iC * mC, 0, + "DEPTHWISECONV2D CUDNN OP: the output_channels must be equal to " + "input_channels * channels_multiplier = %i !", + iC * mC); + if (bias) + REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, + "DEPTHWISECONV2D CUDNN OP: wrong shape of array with biases, " + "expected rank, length: <=2, %i, but got %i, %i instead !", + oC, bias->rankOf(), bias->lengthOf()); + + std::vector wPermut; // cudnn support format {oC, iC/groupCount, kH, kW} + // only, mC = 1, oC = iC (groupCount == iC) that is + // {iC, mC, kH, kW} in our case + if (0 == wFormat) + wPermut = {2, 3, 0, 1}; // kH, kW, iC, mC -> iC, mC, kH, kW + else if (1 == wFormat) + wPermut = {1, 0, 2, 3}; // mC, iC, kH, kW -> iC, mC, kH, kW + else + wPermut = {3, 0, 1, 2}; // mC, kH, kW, iC -> iC, mC, kH, kW + + NDArray* newWeights = new NDArray(weights->ordering(), {iC, mC, kH, kW}, + weights->dataType(), weights->getContext()); + newWeights->assign(weights->permute(wPermut)); + + NDArray* newInput = input; + NDArray* newGradI = nullptr; + if (paddingMode == 1) // in same paddingMode cudnn doesn't support asymmetric + // left/right top/bottopm paddings + checkConv2dCUDNNPadAsymmetric(newInput, newGradI, iH, iW, oH, oW, kH, kW, + sH, sW, pH, pW, dH, dW, isNCHW); + + depthwiseConv2dCUDNN(block.launchContext(), newInput, newWeights, bias, + output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, + isNCHW); + + if (newInput != input) delete newInput; + + delete newWeights; + + return Status::OK(); } ////////////////////////////////////////////////////////////////////////// PLATFORM_CHECK(depthwise_conv2d, ENGINE_CUDA) { - - auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] - auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] = iC*mC - - const int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME, 2-CAUSAL - const int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - [mC, kH, kW, iC] - - const int mC = weights->sizeAt(0 == wFormat ? 3 : 0); - - const bool badInputType = input->dataType() != DataType::DOUBLE && input->dataType() != DataType::FLOAT32 && input->dataType() != DataType::HALF; - const bool badWeightsType = weights->dataType() != DataType::DOUBLE && weights->dataType() != DataType::FLOAT32 && weights->dataType() != DataType::HALF; - const bool badBiasType = bias == nullptr ? false : (bias->dataType() != DataType::DOUBLE && bias->dataType() != DataType::FLOAT32 && bias->dataType() != DataType::HALF); - - return mC == 1 && paddingMode != 2 && !badInputType && !badWeightsType && !badBiasType; + auto input = + INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + auto weights = INPUT_VARIABLE( + 1); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] + auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] = iC*mC + + const int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME, 2-CAUSAL + const int wFormat = block.numI() > 10 + ? INT_ARG(10) + : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 + // - [mC, kH, kW, iC] + + const int mC = weights->sizeAt(0 == wFormat ? 3 : 0); + + const bool badInputType = input->dataType() != DataType::DOUBLE && + input->dataType() != DataType::FLOAT32 && + input->dataType() != DataType::HALF; + const bool badWeightsType = weights->dataType() != DataType::DOUBLE && + weights->dataType() != DataType::FLOAT32 && + weights->dataType() != DataType::HALF; + const bool badBiasType = bias == nullptr + ? false + : (bias->dataType() != DataType::DOUBLE && + bias->dataType() != DataType::FLOAT32 && + bias->dataType() != DataType::HALF); + + return mC == 1 && paddingMode != 2 && !badInputType && !badWeightsType && + !badBiasType; } ////////////////////////////////////////////////////////////////////////// PLATFORM_IMPL(depthwise_conv2d_bp, ENGINE_CUDA) { + auto input = INPUT_VARIABLE( + 0); // [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW) + auto weights = INPUT_VARIABLE( + 1); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] + auto bias = + block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] = [iC*mC] + auto gradO = block.width() > 3 + ? INPUT_VARIABLE(3) + : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NDHWC) or [bS, oC, + // oH, oW] (NCDHW), epsilon_next + + auto gradI = OUTPUT_VARIABLE( + 0); // [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW), epsilon + auto gradW = OUTPUT_VARIABLE( + 1); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] + auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC] + + REQUIRE_TRUE(input->rankOf() == 4, 0, + "DEPTHWISECONV2D_BP CUDNN OP: rank of input array must be equal " + "to 4, but got %i instead !", + input->rankOf()); + REQUIRE_TRUE(weights->rankOf() == 4, 0, + "DEPTHWISECONV2D_BP CUDNN OP: rank of weights array must be " + "equal to 4, but got %i instead !", + weights->rankOf()); + REQUIRE_TRUE(gradO->rankOf() == 4, 0, + "DEPTHWISECONV2D_BP CUDNN OP: rank of output gradients (next " + "epsilon) array must be equal to 4, but got %i instead !", + gradO->rankOf()); + + int kH = INT_ARG(0) > 0 + ? INT_ARG(0) + : static_cast(weights->sizeAt(0)); // filter(kernel) height + int kW = INT_ARG(1) > 0 + ? INT_ARG(1) + : static_cast(weights->sizeAt(1)); // filter(kernel) width + int sH = INT_ARG(2); // strides height + int sW = INT_ARG(3); // strides width + int pH = INT_ARG(4); // paddings height + int pW = INT_ARG(5); // paddings width + int dH = INT_ARG(6); // dilations height + int dW = INT_ARG(7); // dilations width + int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME + int isNCHW = block.numI() > 9 + ? !INT_ARG(9) + : 1; // INT_ARG(9): 1-NHWC, 0-NCHW + int wFormat = block.numI() > 10 + ? INT_ARG(10) + : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - + // [mC, kH, kW, iC] + + int bS, iC, iH, iW, mC, oC, oH, + oW; // batch size, input channels, input height/width, channels + // multiplier(oC = iC*mC), output channels, output height/width + int indIOioC, indIiH, indWmC, indWiC, indWkH, + indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, + indIiH, indWiC, indWmC, indWkH, indOoH); + mC = weights->sizeAt(indWmC); // channels multiplier + + int trueoH, trueoW; // correct output height, width + ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, + dH, dW, iH, iW, paddingMode); + + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, + dW, paddingMode); + + std::vector expectedGradOShape = + ShapeUtils::composeShapeUsingDimsAndIdx( + {bS, oC, trueoH, trueoW, 0, indIOioC, indOoH, indOoH + 1}); + std::vector expectedWeightsShape = + ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, mC); + REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, + "DEPTHWISECONV2D_BP CUDNN OP: wrong shape of output gradients " + "(next epsilon) array, expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedGradOShape).c_str(), + ShapeUtils::shapeAsString(gradO).c_str()); + REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, + "DEPTHWISECONV2D_BP CUDNN OP: wrong shape of weights array, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), + ShapeUtils::shapeAsString(weights).c_str()); + if (bias) + REQUIRE_TRUE( + bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, + "DEPTHWISECONV2D_BP CUDNN OP: wrong shape of array with biases, " + "expected rank, length: <=2, %i, but got %i, %i instead !", + oC, bias->rankOf(), bias->lengthOf()); + + std::vector wPermut, + gradWPermut; // cudnn support format {oC, iC/groupCount, kH, kW} only, mC + // = 1, oC = iC (groupCount == iC) that is {iC, mC, kH, kW} + if (0 == wFormat) { + wPermut = {2, 3, 0, 1}; // kH, kW, iC, mC -> iC, mC, kH, kW + gradWPermut = {2, 3, 0, 1}; // iC, mC, kH, kW -> kH, kW, iC, mC + } else if (1 == wFormat) { + wPermut = {1, 0, 2, 3}; // mC, iC, kH, kW -> iC, mC, kH, kW + gradWPermut = {1, 0, 2, 3}; // iC, mC, kH, kW -> mC, iC, kH, kW + } else { + wPermut = {3, 0, 1, 2}; // mC, kH, kW, iC -> iC, mC, kH, kW + gradWPermut = {1, 2, 3, 0}; // iC, mC, kH, kW -> mC, kH, kW, iC + } + + NDArray* newGradW = new NDArray(gradW->ordering(), {iC, mC, kH, kW}, + gradW->dataType(), gradW->getContext()); + NDArray* newWeights = new NDArray(weights->ordering(), {iC, mC, kH, kW}, + weights->dataType(), weights->getContext()); + + newWeights->assign(weights->permute(wPermut)); + + NDArray* newInput = input; + NDArray* newGradI = gradI; + if (paddingMode == 1) // in same paddingMode cudnn doesn't support asymmetric + // left/right top/bottopm paddings + checkConv2dCUDNNPadAsymmetric(newInput, newGradI, iH, iW, oH, oW, kH, kW, + sH, sW, pH, pW, dH, dW, isNCHW); + + depthwiseConv2dBpCUDNN(block.launchContext(), newInput, newWeights, gradO, + newGradI, newGradW, gradB, kH, kW, sH, sW, pH, pW, dH, + dW, paddingMode, isNCHW); + + newGradW->permutei(gradWPermut); + gradW->assign(newGradW); + + if (newInput != input) { + if (isNCHW) + gradI->assign( + (*newGradI)({0, 0, 0, 0, 0, gradI->sizeAt(2), 0, gradI->sizeAt(3)})); + else + gradI->assign( + (*newGradI)({0, 0, 0, gradI->sizeAt(1), 0, gradI->sizeAt(2), 0, 0})); - auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW) - auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] - auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] = [iC*mC] - auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NDHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next - - auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW), epsilon - auto gradW = OUTPUT_VARIABLE(1); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] - auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC] - - REQUIRE_TRUE(input->rankOf() == 4, 0, "DEPTHWISECONV2D_BP CUDNN OP: rank of input array must be equal to 4, but got %i instead !", input->rankOf()); - REQUIRE_TRUE(weights->rankOf() == 4, 0, "DEPTHWISECONV2D_BP CUDNN OP: rank of weights array must be equal to 4, but got %i instead !", weights->rankOf()); - REQUIRE_TRUE(gradO->rankOf() == 4, 0, "DEPTHWISECONV2D_BP CUDNN OP: rank of output gradients (next epsilon) array must be equal to 4, but got %i instead !", gradO->rankOf()); - - int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(weights->sizeAt(0));// filter(kernel) height - int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast(weights->sizeAt(1));// filter(kernel) width - int sH = INT_ARG(2); // strides height - int sW = INT_ARG(3); // strides width - int pH = INT_ARG(4); // paddings height - int pW = INT_ARG(5); // paddings width - int dH = INT_ARG(6); // dilations height - int dW = INT_ARG(7); // dilations width - int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME - int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW - int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - [mC, kH, kW, iC] - - int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width - int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); - mC = weights->sizeAt(indWmC); // channels multiplier - - int trueoH, trueoW; // correct output height, width - ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, paddingMode); - - ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode); - - std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1}); - std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, mC); - REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "DEPTHWISECONV2D_BP CUDNN OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); - REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "DEPTHWISECONV2D_BP CUDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); - if(bias) - REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "DEPTHWISECONV2D_BP CUDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); - - std::vector wPermut, gradWPermut; // cudnn support format {oC, iC/groupCount, kH, kW} only, mC = 1, oC = iC (groupCount == iC) that is {iC, mC, kH, kW} - if(0 == wFormat) { - wPermut = {2,3,0,1}; // kH, kW, iC, mC -> iC, mC, kH, kW - gradWPermut = {2,3,0,1}; // iC, mC, kH, kW -> kH, kW, iC, mC - } - else if(1 == wFormat) { - wPermut = {1,0,2,3}; // mC, iC, kH, kW -> iC, mC, kH, kW - gradWPermut = {1,0,2,3}; // iC, mC, kH, kW -> mC, iC, kH, kW - } - else { - wPermut = {3,0,1,2}; // mC, kH, kW, iC -> iC, mC, kH, kW - gradWPermut = {1,2,3,0}; // iC, mC, kH, kW -> mC, kH, kW, iC - } - - NDArray* newGradW = new NDArray(gradW->ordering(), {iC, mC, kH, kW}, gradW->dataType(), gradW->getContext()); - NDArray* newWeights = new NDArray(weights->ordering(), {iC, mC, kH, kW}, weights->dataType(), weights->getContext()); - - newWeights->assign(weights->permute(wPermut)); - - NDArray* newInput = input; - NDArray* newGradI = gradI; - if(paddingMode == 1) // in same paddingMode cudnn doesn't support asymmetric left/right top/bottopm paddings - checkConv2dCUDNNPadAsymmetric(newInput, newGradI, iH, iW, oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, isNCHW); - - depthwiseConv2dBpCUDNN(block.launchContext(), newInput, newWeights, gradO, newGradI, newGradW, gradB, kH,kW,sH,sW,pH,pW,dH,dW,paddingMode,isNCHW); - - newGradW->permutei(gradWPermut); - gradW->assign(newGradW); - - if(newInput != input) { - - if(isNCHW) - gradI->assign((*newGradI)({0,0, 0,0, 0,gradI->sizeAt(2), 0,gradI->sizeAt(3)})); - else - gradI->assign((*newGradI)({0,0, 0,gradI->sizeAt(1), 0,gradI->sizeAt(2), 0,0})); - - delete newInput; - delete newGradI; - } - - delete newWeights; - delete newGradW; - - return Status::OK(); -} - -PLATFORM_CHECK(depthwise_conv2d_bp, ENGINE_CUDA) { - - auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW) - auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] - auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] = [iC*mC] - auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NDHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next - - const int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME, 2-CAUSAL - const int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC - const int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - [mC, kH, kW, iC] - - const int mC = weights->sizeAt(0 == wFormat ? 3 : 0); + delete newInput; + delete newGradI; + } - const bool badInputType = input->dataType() != DataType::DOUBLE && input->dataType() != DataType::FLOAT32 && input->dataType() != DataType::HALF; - const bool badWeightsType = weights->dataType() != DataType::DOUBLE && weights->dataType() != DataType::FLOAT32 && weights->dataType() != DataType::HALF; - const bool badGradOType = gradO->dataType() != DataType::DOUBLE && gradO->dataType() != DataType::FLOAT32 && gradO->dataType() != DataType::HALF; - const bool badBiasType = bias == nullptr ? false : (bias->dataType() != DataType::DOUBLE && bias->dataType() != DataType::FLOAT32 && bias->dataType() != DataType::HALF); + delete newWeights; + delete newGradW; - return mC == 1 && isNCHW && paddingMode != 2 && !badInputType && !badWeightsType && !badGradOType && !badBiasType; + return Status::OK(); } - -} -} +PLATFORM_CHECK(depthwise_conv2d_bp, ENGINE_CUDA) { + auto input = INPUT_VARIABLE( + 0); // [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW) + auto weights = INPUT_VARIABLE( + 1); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] + auto bias = + block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] = [iC*mC] + auto gradO = block.width() > 3 + ? INPUT_VARIABLE(3) + : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NDHWC) or [bS, oC, + // oH, oW] (NCDHW), epsilon_next + + const int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME, 2-CAUSAL + const int isNCHW = block.numI() > 9 + ? !INT_ARG(9) + : 1; // INT_ARG(9): 0-NCHW, 1-NHWC + const int wFormat = block.numI() > 10 + ? INT_ARG(10) + : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 + // - [mC, kH, kW, iC] + + const int mC = weights->sizeAt(0 == wFormat ? 3 : 0); + + const bool badInputType = input->dataType() != DataType::DOUBLE && + input->dataType() != DataType::FLOAT32 && + input->dataType() != DataType::HALF; + const bool badWeightsType = weights->dataType() != DataType::DOUBLE && + weights->dataType() != DataType::FLOAT32 && + weights->dataType() != DataType::HALF; + const bool badGradOType = gradO->dataType() != DataType::DOUBLE && + gradO->dataType() != DataType::FLOAT32 && + gradO->dataType() != DataType::HALF; + const bool badBiasType = bias == nullptr + ? false + : (bias->dataType() != DataType::DOUBLE && + bias->dataType() != DataType::FLOAT32 && + bias->dataType() != DataType::HALF); + + return mC == 1 && isNCHW && paddingMode != 2 && !badInputType && + !badWeightsType && !badGradOType && !badBiasType; } + +} // namespace platforms +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/platform/cudnn/maxpool2d.cu b/libnd4j/include/ops/declarable/platform/cudnn/maxpool2d.cu index 5bb646f57c47..50504fd29bc1 100644 --- a/libnd4j/include/ops/declarable/platform/cudnn/maxpool2d.cu +++ b/libnd4j/include/ops/declarable/platform/cudnn/maxpool2d.cu @@ -18,115 +18,160 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // +#include #include "cudnnUtils.h" -#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace platforms { - ////////////////////////////////////////////////////////////////////////// PLATFORM_IMPL(maxpool2d, ENGINE_CUDA) { - - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - paddingModee; - const auto kH = INT_ARG(0); - const auto kW = INT_ARG(1); - const auto sH = INT_ARG(2); - const auto sW = INT_ARG(3); - auto pH = INT_ARG(4); - auto pW = INT_ARG(5); - const auto dH = INT_ARG(6); - const auto dW = INT_ARG(7); - const auto paddingMode = static_cast(INT_ARG(8)); - const int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC - - REQUIRE_TRUE(input->rankOf() == 4, 0, "MAXPOOL2D CUDNN op: input should have rank of 4, but got %i instead", input->rankOf()); - REQUIRE_TRUE(dH != 0 && dW != 0, 0, "MAXPOOL2D CUDNN op: dilation must not be zero, but got instead {%i, %i}", dH, dW); - - int oH = 0; - int oW = 0; - - const int iH = static_cast(isNCHW ? input->sizeAt(2) : input->sizeAt(1)); - const int iW = static_cast(isNCHW ? input->sizeAt(3) : input->sizeAt(2)); - - ConvolutionUtils::calcOutSizePool2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, paddingMode); - - if (paddingMode) - ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); - - pooling2dCUDNN(block.launchContext(), input, output, kH, kW, sH, sW, pH, pW, dH, dW, isNCHW, CUDNN_POOLING_MAX); - - return Status::OK(); + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad + // Height/Width; 6,7 - dilation Height/Width; 8 - paddingModee; + const auto kH = INT_ARG(0); + const auto kW = INT_ARG(1); + const auto sH = INT_ARG(2); + const auto sW = INT_ARG(3); + auto pH = INT_ARG(4); + auto pW = INT_ARG(5); + const auto dH = INT_ARG(6); + const auto dW = INT_ARG(7); + const auto paddingMode = static_cast(INT_ARG(8)); + const int isNCHW = block.numI() > 10 + ? !INT_ARG(10) + : 1; // INT_ARG(10): 0-NCHW, 1-NHWC + + REQUIRE_TRUE( + input->rankOf() == 4, 0, + "MAXPOOL2D CUDNN op: input should have rank of 4, but got %i instead", + input->rankOf()); + REQUIRE_TRUE( + dH != 0 && dW != 0, 0, + "MAXPOOL2D CUDNN op: dilation must not be zero, but got instead {%i, %i}", + dH, dW); + + int oH = 0; + int oW = 0; + + const int iH = static_cast(isNCHW ? input->sizeAt(2) : input->sizeAt(1)); + const int iW = static_cast(isNCHW ? input->sizeAt(3) : input->sizeAt(2)); + + ConvolutionUtils::calcOutSizePool2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, + iH, iW, paddingMode); + + if (paddingMode) + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, + dW); + + pooling2dCUDNN(block.launchContext(), input, output, kH, kW, sH, sW, pH, pW, + dH, dW, isNCHW, CUDNN_POOLING_MAX); + + return Status::OK(); } ////////////////////////////////////////////////////////////////////////// PLATFORM_CHECK(maxpool2d, ENGINE_CUDA) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - const auto goodType = input->dataType() == DataType::DOUBLE || input->dataType() == DataType::FLOAT32 || input->dataType() == DataType::HALF || input->dataType() == DataType::INT32; + const auto goodType = input->dataType() == DataType::DOUBLE || + input->dataType() == DataType::FLOAT32 || + input->dataType() == DataType::HALF || + input->dataType() == DataType::INT32; - return goodType && input->dataType() == output->dataType(); + return goodType && input->dataType() == output->dataType(); } ////////////////////////////////////////////////////////////////////////// PLATFORM_IMPL(maxpool2d_bp, ENGINE_CUDA) { - - auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - auto gradO = INPUT_VARIABLE(1); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next - auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon - - const auto kH = INT_ARG(0); // filter(kernel) height - const auto kW = INT_ARG(1); // filter(kernel) width - const auto sH = INT_ARG(2); // strides height - const auto sW = INT_ARG(3); // strides width - auto pH = INT_ARG(4); // paddings height - auto pW = INT_ARG(5); // paddings width - const auto dH = INT_ARG(6); // dilations height - const auto dW = INT_ARG(7); // dilations width - const auto paddingMode = INT_ARG(8); // 0-VALID, 1-SAME - const auto isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC - - REQUIRE_TRUE(input->rankOf() == 4, 0, "MAXPOOL2D_BP CUDNN op: input should have rank of 4, but got %i instead", input->rankOf()); - REQUIRE_TRUE(dH != 0 && dW != 0, 0, "MAXPOOL2D_BP CUDNN op: dilation must not be zero, but got instead {%i, %i}", dH, dW); - - int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; - int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, 0, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); - - std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oH,oW, 0,indIOioC,indIiH,indIiH+1}); - std::vector expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iH,iW, 0,indIOioC,indIiH,indIiH+1}); - REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "MAXPOOL2D_BP CUDNN op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); - REQUIRE_TRUE(gradI->isSameShape(expectedGradIShape), 0, "MAXPOOL2D_BP CUDNN op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradIShape).c_str(), ShapeUtils::shapeAsString(gradI).c_str()); - - if(paddingMode) // SAME - ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); - - pooling2dBpCUDNN(block.launchContext(), input, gradO, gradI, kH, kW, sH, sW, pH, pW, dH, dW, isNCHW, CUDNN_POOLING_MAX); - - return Status::OK(); + auto input = + INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + auto gradO = INPUT_VARIABLE( + 1); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next + auto gradI = OUTPUT_VARIABLE( + 0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon + + const auto kH = INT_ARG(0); // filter(kernel) height + const auto kW = INT_ARG(1); // filter(kernel) width + const auto sH = INT_ARG(2); // strides height + const auto sW = INT_ARG(3); // strides width + auto pH = INT_ARG(4); // paddings height + auto pW = INT_ARG(5); // paddings width + const auto dH = INT_ARG(6); // dilations height + const auto dW = INT_ARG(7); // dilations width + const auto paddingMode = INT_ARG(8); // 0-VALID, 1-SAME + const auto isNCHW = block.numI() > 10 + ? !INT_ARG(10) + : 1; // INT_ARG(10): 0-NCHW, 1-NHWC + + REQUIRE_TRUE( + input->rankOf() == 4, 0, + "MAXPOOL2D_BP CUDNN op: input should have rank of 4, but got %i instead", + input->rankOf()); + REQUIRE_TRUE(dH != 0 && dW != 0, 0, + "MAXPOOL2D_BP CUDNN op: dilation must not be zero, but got " + "instead {%i, %i}", + dH, dW); + + int bS, iC, iH, iW, oC, oH, + oW; // batch size, input channels, input height/width, output channels, + // output height/width; + int indIOioC, indIiH, indWoC, indWiC, indWkH, + indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, 0, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, + indWiC, indWoC, indWkH, indOoH); + + std::vector expectedGradOShape = + ShapeUtils::composeShapeUsingDimsAndIdx( + {bS, iC, oH, oW, 0, indIOioC, indIiH, indIiH + 1}); + std::vector expectedGradIShape = + ShapeUtils::composeShapeUsingDimsAndIdx( + {bS, iC, iH, iW, 0, indIOioC, indIiH, indIiH + 1}); + REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, + "MAXPOOL2D_BP CUDNN op: wrong shape of output's gradients array " + "(next epsilon), expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedGradOShape).c_str(), + ShapeUtils::shapeAsString(gradO).c_str()); + REQUIRE_TRUE(gradI->isSameShape(expectedGradIShape), 0, + "MAXPOOL2D_BP CUDNN op: wrong shape of input's gradients array " + "(epsilon), expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedGradIShape).c_str(), + ShapeUtils::shapeAsString(gradI).c_str()); + + if (paddingMode) // SAME + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, + dW); + + pooling2dBpCUDNN(block.launchContext(), input, gradO, gradI, kH, kW, sH, sW, + pH, pW, dH, dW, isNCHW, CUDNN_POOLING_MAX); + + return Status::OK(); } PLATFORM_CHECK(maxpool2d_bp, ENGINE_CUDA) { - - auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - auto gradO = INPUT_VARIABLE(1); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next - auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon - - const auto goodType = input->dataType() == DataType::DOUBLE || input->dataType() == DataType::FLOAT32 || input->dataType() == DataType::HALF || input->dataType() == DataType::INT32; - - return goodType && (input->dataType() == gradO->dataType()) - && (input->dataType() == gradI->dataType()) - && shape::haveSameShapeAndStrides(input->shapeInfo(), gradI->shapeInfo()); + auto input = + INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + auto gradO = INPUT_VARIABLE( + 1); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next + auto gradI = OUTPUT_VARIABLE( + 0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon + + const auto goodType = input->dataType() == DataType::DOUBLE || + input->dataType() == DataType::FLOAT32 || + input->dataType() == DataType::HALF || + input->dataType() == DataType::INT32; + + return goodType && (input->dataType() == gradO->dataType()) && + (input->dataType() == gradI->dataType()) && + shape::haveSameShapeAndStrides(input->shapeInfo(), gradI->shapeInfo()); } - -} -} -} +} // namespace platforms +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/platform/cudnn/maxpool3d.cu b/libnd4j/include/ops/declarable/platform/cudnn/maxpool3d.cu index f7b9c8b50278..ab5ff55f28d6 100644 --- a/libnd4j/include/ops/declarable/platform/cudnn/maxpool3d.cu +++ b/libnd4j/include/ops/declarable/platform/cudnn/maxpool3d.cu @@ -18,123 +18,180 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // +#include #include "cudnnUtils.h" -#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace platforms { - ////////////////////////////////////////////////////////////////////////// PLATFORM_IMPL(maxpool3dnew, ENGINE_CUDA) { - - auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) - auto output = OUTPUT_VARIABLE(0); // [bS, oD, oH, oW, iC] (NDHWC) or [bS, iC, oD, oH, oW] (NCDHW) - - int kD = INT_ARG(0); // filter(kernel) depth - int kH = INT_ARG(1); // filter(kernel) height - int kW = INT_ARG(2); // filter(kernel) width - int sD = INT_ARG(3); // strides depth - int sH = INT_ARG(4); // strides height - int sW = INT_ARG(5); // strides width - int pD = INT_ARG(6); // paddings depth - int pH = INT_ARG(7); // paddings height - int pW = INT_ARG(8); // paddings width - int dD = INT_ARG(9); // dilations depth - int dH = INT_ARG(10); // dilations height - int dW = INT_ARG(11); // dilations width - int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID - // int extraParam0 = INT_ARG(13); - int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1; // 0-NCDHW, 1-NDHWC - - REQUIRE_TRUE(input->rankOf() == 5, 0, "MAXPOOL3DNEW CUDNN OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf()); - REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "MAXPOOL3DNEW CUDNN OP: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW); - - int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; - int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, 0, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); - - std::vector expectedOutputShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); - REQUIRE_TRUE(output->isSameShape(expectedOutputShape), 0, "MAXPOOL3DNEW CUDNN OP: wrong shape of output array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedOutputShape).c_str(), ShapeUtils::shapeAsString(output).c_str()); - - if(paddingMode) // SAME - ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW); - - pooling3dCUDNN(block.launchContext(), input, output, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isNCDHW, CUDNN_POOLING_MAX); - - return Status::OK(); + auto input = INPUT_VARIABLE( + 0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) + auto output = OUTPUT_VARIABLE( + 0); // [bS, oD, oH, oW, iC] (NDHWC) or [bS, iC, oD, oH, oW] (NCDHW) + + int kD = INT_ARG(0); // filter(kernel) depth + int kH = INT_ARG(1); // filter(kernel) height + int kW = INT_ARG(2); // filter(kernel) width + int sD = INT_ARG(3); // strides depth + int sH = INT_ARG(4); // strides height + int sW = INT_ARG(5); // strides width + int pD = INT_ARG(6); // paddings depth + int pH = INT_ARG(7); // paddings height + int pW = INT_ARG(8); // paddings width + int dD = INT_ARG(9); // dilations depth + int dH = INT_ARG(10); // dilations height + int dW = INT_ARG(11); // dilations width + int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID + // int extraParam0 = INT_ARG(13); + int isNCDHW = block.numI() > 14 ? !INT_ARG(14) + : 1; // 0-NCDHW, 1-NDHWC + + REQUIRE_TRUE(input->rankOf() == 5, 0, + "MAXPOOL3DNEW CUDNN OP: rank of input array must be equal to 5, " + "but got %i instead !", + input->rankOf()); + REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, + "MAXPOOL3DNEW CUDNN OP: dilation must not be zero, but got " + "instead {%i, %i, %i}", + dD, dH, dW); + + int bS, iC, iD, iH, iW, oC, oD, oH, + oW; // batch size, input channels, input depth/height/width, output + // channels, output depth/height/width; + int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv3d( + isNCDHW, 0, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, + indIOioD, indWiC, indWoC, indWkD); + + std::vector expectedOutputShape = + ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, oD, oH, oW, 0, indIOioC, + indIOioD, indIOioD + 1, + indIOioD + 2}); + REQUIRE_TRUE(output->isSameShape(expectedOutputShape), 0, + "MAXPOOL3DNEW CUDNN OP: wrong shape of output array, expected " + "is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedOutputShape).c_str(), + ShapeUtils::shapeAsString(output).c_str()); + + if (paddingMode) // SAME + ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, + kW, sD, sH, sW, dD, dH, dW); + + pooling3dCUDNN(block.launchContext(), input, output, kD, kH, kW, sD, sH, sW, + pD, pH, pW, dD, dH, dW, isNCDHW, CUDNN_POOLING_MAX); + + return Status::OK(); } ////////////////////////////////////////////////////////////////////////// PLATFORM_CHECK(maxpool3dnew, ENGINE_CUDA) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); + const auto goodType = input->dataType() == DataType::DOUBLE || + input->dataType() == DataType::FLOAT32 || + input->dataType() == DataType::HALF || + input->dataType() == DataType::INT32; - const auto goodType = input->dataType() == DataType::DOUBLE || input->dataType() == DataType::FLOAT32 || input->dataType() == DataType::HALF || input->dataType() == DataType::INT32; - - return goodType && input->dataType() == output->dataType(); + return goodType && input->dataType() == output->dataType(); } ////////////////////////////////////////////////////////////////////////// PLATFORM_IMPL(maxpool3dnew_bp, ENGINE_CUDA) { - - auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) - auto gradO = INPUT_VARIABLE(1); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next - auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon - - const int kD = INT_ARG(0); // filter(kernel) depth - const int kH = INT_ARG(1); // filter(kernel) height - const int kW = INT_ARG(2); // filter(kernel) width - const int sD = INT_ARG(3); // strides depth - const int sH = INT_ARG(4); // strides height - const int sW = INT_ARG(5); // strides width - int pD = INT_ARG(6); // paddings depth - int pH = INT_ARG(7); // paddings height - int pW = INT_ARG(8); // paddings width - const int dD = INT_ARG(9); // dilations depth - const int dH = INT_ARG(10); // dilations height - const int dW = INT_ARG(11); // dilations width - const int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID - // const int extraParam0 = INT_ARG(13); // define what divisor to use while averaging - const int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1; // 0-NCDHW, 1-NDHWC - - REQUIRE_TRUE(input->rankOf() == 5, 0, "MAXPOOL3DNEW_BP CUDNN OP: input should have rank of 5, but got %i instead", input->rankOf()); - REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "MAXPOOL3DNEW_BP CUDNN OP: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW); - - int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; - int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, 0, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); - - std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); - std::vector expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iD,iH,iW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); - REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "MAXPOOL3DNEW_BP CUDNN: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); - REQUIRE_TRUE(gradI->isSameShape(expectedGradIShape), 0, "MAXPOOL3DNEW_BP CUDNN: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradIShape).c_str(), ShapeUtils::shapeAsString(gradI).c_str()); - - if(isSameMode) // SAME - ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW); - - pooling3dBpCUDNN(block.launchContext(), input, gradO, gradI, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isNCDHW, CUDNN_POOLING_MAX); - - return Status::OK(); + auto input = INPUT_VARIABLE( + 0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) + auto gradO = INPUT_VARIABLE(1); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, + // oD, oH, oW] (NCDHW), epsilon_next + auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, + // iD, iH, iW] (NCDHW), epsilon + + const int kD = INT_ARG(0); // filter(kernel) depth + const int kH = INT_ARG(1); // filter(kernel) height + const int kW = INT_ARG(2); // filter(kernel) width + const int sD = INT_ARG(3); // strides depth + const int sH = INT_ARG(4); // strides height + const int sW = INT_ARG(5); // strides width + int pD = INT_ARG(6); // paddings depth + int pH = INT_ARG(7); // paddings height + int pW = INT_ARG(8); // paddings width + const int dD = INT_ARG(9); // dilations depth + const int dH = INT_ARG(10); // dilations height + const int dW = INT_ARG(11); // dilations width + const int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID + // const int extraParam0 = INT_ARG(13); // define what divisor to use while + // averaging + const int isNCDHW = block.numI() > 14 + ? !INT_ARG(14) + : 1; // 0-NCDHW, 1-NDHWC + + REQUIRE_TRUE(input->rankOf() == 5, 0, + "MAXPOOL3DNEW_BP CUDNN OP: input should have rank of 5, but got " + "%i instead", + input->rankOf()); + REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, + "MAXPOOL3DNEW_BP CUDNN OP: dilation must not be zero, but got " + "instead {%i, %i, %i}", + dD, dH, dW); + + int bS, iC, iD, iH, iW, oC, oD, oH, + oW; // batch size, input channels, input depth/height/width, output + // channels, output depth/height/width; + int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv3d( + isNCDHW, 0, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, + indIOioD, indWiC, indWoC, indWkD); + + std::vector expectedGradOShape = + ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, oD, oH, oW, 0, indIOioC, + indIOioD, indIOioD + 1, + indIOioD + 2}); + std::vector expectedGradIShape = + ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, iD, iH, iW, 0, indIOioC, + indIOioD, indIOioD + 1, + indIOioD + 2}); + REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, + "MAXPOOL3DNEW_BP CUDNN: wrong shape of output's gradients array " + "(next epsilon), expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedGradOShape).c_str(), + ShapeUtils::shapeAsString(gradO).c_str()); + REQUIRE_TRUE(gradI->isSameShape(expectedGradIShape), 0, + "MAXPOOL3DNEW_BP CUDNN: wrong shape of input's gradients array " + "(epsilon), expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedGradIShape).c_str(), + ShapeUtils::shapeAsString(gradI).c_str()); + + if (isSameMode) // SAME + ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, + kW, sD, sH, sW, dD, dH, dW); + + pooling3dBpCUDNN(block.launchContext(), input, gradO, gradI, kD, kH, kW, sD, + sH, sW, pD, pH, pW, dD, dH, dW, isNCDHW, CUDNN_POOLING_MAX); + + return Status::OK(); } PLATFORM_CHECK(maxpool3dnew_bp, ENGINE_CUDA) { - - auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) - auto gradO = INPUT_VARIABLE(1); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next - auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon - - const auto goodType = input->dataType() == DataType::DOUBLE || input->dataType() == DataType::FLOAT32 || input->dataType() == DataType::HALF || input->dataType() == DataType::INT32; - - return goodType && (input->dataType() == gradO->dataType()) - && (input->dataType() == gradI->dataType()) - && shape::haveSameShapeAndStrides(input->shapeInfo(), gradI->shapeInfo()); + auto input = INPUT_VARIABLE( + 0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) + auto gradO = INPUT_VARIABLE(1); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, + // oD, oH, oW] (NCDHW), epsilon_next + auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, + // iD, iH, iW] (NCDHW), epsilon + + const auto goodType = input->dataType() == DataType::DOUBLE || + input->dataType() == DataType::FLOAT32 || + input->dataType() == DataType::HALF || + input->dataType() == DataType::INT32; + + return goodType && (input->dataType() == gradO->dataType()) && + (input->dataType() == gradI->dataType()) && + shape::haveSameShapeAndStrides(input->shapeInfo(), gradI->shapeInfo()); } - -} -} -} +} // namespace platforms +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling2d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling2d.cpp index 4adab2dfef49..fd6b6cb8be81 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling2d.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling2d.cpp @@ -20,115 +20,151 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include +#include #include +#include +#include #include -#include #include "mkldnnUtils.h" -#include using namespace dnnl; using namespace samediff; -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace platforms { ////////////////////////////////////////////////////////////////////////// PLATFORM_IMPL(avgpool2d, ENGINE_CPU) { - - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; - - const auto kH = INT_ARG(0); - const auto kW = INT_ARG(1); - const auto sH = INT_ARG(2); - const auto sW = INT_ARG(3); - auto pH = INT_ARG(4); - auto pW = INT_ARG(5); - const auto dH = INT_ARG(6); - const auto dW = INT_ARG(7); - const auto paddingMode = INT_ARG(8); - const auto extraParam0 = INT_ARG(9); - const int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC - - REQUIRE_TRUE(input->rankOf() == 4, 0, "AVGPOOL2D MKLDNN op: input should have rank of 4, but got %i instead", input->rankOf()); - REQUIRE_TRUE(dH != 0 && dW != 0, 0, "AVGPOOL2D MKLDNN op: dilation must not be zero, but got instead {%i, %i}", dH, dW); - - int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; - int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, 0, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); - - if (paddingMode) - ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); - - auto mode = (extraParam0 == 0) ? algorithm::pooling_avg_exclude_padding : algorithm::pooling_avg_include_padding; - - mkldnnUtils::poolingMKLDNN(input, output, 0,kH,kW, 0,sH,sW, 0,pH,pW, isNCHW, mode); - - return Status::OK(); + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad + // Height/Width; 6,7 - dilation Height/Width; 8 - same mode; + + const auto kH = INT_ARG(0); + const auto kW = INT_ARG(1); + const auto sH = INT_ARG(2); + const auto sW = INT_ARG(3); + auto pH = INT_ARG(4); + auto pW = INT_ARG(5); + const auto dH = INT_ARG(6); + const auto dW = INT_ARG(7); + const auto paddingMode = INT_ARG(8); + const auto extraParam0 = INT_ARG(9); + const int isNCHW = + block.numI() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC + + REQUIRE_TRUE( + input->rankOf() == 4, 0, + "AVGPOOL2D MKLDNN op: input should have rank of 4, but got %i instead", + input->rankOf()); + REQUIRE_TRUE(dH != 0 && dW != 0, 0, + "AVGPOOL2D MKLDNN op: dilation must not be zero, but got " + "instead {%i, %i}", + dH, dW); + + int bS, iC, iH, iW, oC, oH, + oW; // batch size, input channels, input height/width, output channels, + // output height/width; + int indIOioC, indIiH, indWoC, indWiC, indWkH, + indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, 0, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, + indWiC, indWoC, indWkH, indOoH); + + if (paddingMode) + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, + dW); + + auto mode = (extraParam0 == 0) ? algorithm::pooling_avg_exclude_padding + : algorithm::pooling_avg_include_padding; + + mkldnnUtils::poolingMKLDNN(input, output, 0, kH, kW, 0, sH, sW, 0, pH, pW, + isNCHW, mode); + + return Status::OK(); } ////////////////////////////////////////////////////////////////////////// PLATFORM_CHECK(avgpool2d, ENGINE_CPU) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - return block.isUseMKLDNN() && sd::MKLDNNStream::isSupported({input, output}); + return block.isUseMKLDNN() && sd::MKLDNNStream::isSupported({input, output}); } ////////////////////////////////////////////////////////////////////////// PLATFORM_IMPL(avgpool2d_bp, ENGINE_CPU) { - - auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - auto gradO = INPUT_VARIABLE(1); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next - auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon - - int kH = INT_ARG(0); // filter(kernel) height - int kW = INT_ARG(1); // filter(kernel) width - int sH = INT_ARG(2); // strides height - int sW = INT_ARG(3); // strides width - int pH = INT_ARG(4); // paddings height - int pW = INT_ARG(5); // paddings width - int dH = INT_ARG(6); // dilations height - int dW = INT_ARG(7); // dilations width - int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME - int extraParam0 = INT_ARG(9); - int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC - - REQUIRE_TRUE(input->rankOf() == 4, 0, "AVGPOOL2D_BP MKLDNN op: input should have rank of 4, but got %i instead", input->rankOf()); - REQUIRE_TRUE(dH != 0 && dW != 0, 0, "AVGPOOL2D_BP MKLDNN op: dilation must not be zero, but got instead {%i, %i}", dH, dW); - - int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; - int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, 0, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); - - std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oH,oW, 0,indIOioC,indIiH,indIiH+1}); - REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "AVGPOOL2D_BP MKLDNN op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); - - if(paddingMode) // SAME - ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); - - auto mode = (extraParam0 == 0) ? algorithm::pooling_avg_exclude_padding : algorithm::pooling_avg_include_padding; - - mkldnnUtils::poolingBpMKLDNN(input, gradO, gradI, 0,kH,kW, 0,sH,sW, 0,pH,pW, isNCHW, mode); - - return Status::OK(); + auto input = + INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + auto gradO = INPUT_VARIABLE( + 1); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next + auto gradI = OUTPUT_VARIABLE( + 0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon + + int kH = INT_ARG(0); // filter(kernel) height + int kW = INT_ARG(1); // filter(kernel) width + int sH = INT_ARG(2); // strides height + int sW = INT_ARG(3); // strides width + int pH = INT_ARG(4); // paddings height + int pW = INT_ARG(5); // paddings width + int dH = INT_ARG(6); // dilations height + int dW = INT_ARG(7); // dilations width + int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME + int extraParam0 = INT_ARG(9); + int isNCHW = + block.numI() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC + + REQUIRE_TRUE( + input->rankOf() == 4, 0, + "AVGPOOL2D_BP MKLDNN op: input should have rank of 4, but got %i instead", + input->rankOf()); + REQUIRE_TRUE(dH != 0 && dW != 0, 0, + "AVGPOOL2D_BP MKLDNN op: dilation must not be zero, but got " + "instead {%i, %i}", + dH, dW); + + int bS, iC, iH, iW, oC, oH, + oW; // batch size, input channels, input height/width, output channels, + // output height/width; + int indIOioC, indIiH, indWoC, indWiC, indWkH, + indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, 0, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, + indWiC, indWoC, indWkH, indOoH); + + std::vector expectedGradOShape = + ShapeUtils::composeShapeUsingDimsAndIdx( + {bS, iC, oH, oW, 0, indIOioC, indIiH, indIiH + 1}); + REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, + "AVGPOOL2D_BP MKLDNN op: wrong shape of output's gradients " + "array (next epsilon), expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedGradOShape).c_str(), + ShapeUtils::shapeAsString(gradO).c_str()); + + if (paddingMode) // SAME + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, + dW); + + auto mode = (extraParam0 == 0) ? algorithm::pooling_avg_exclude_padding + : algorithm::pooling_avg_include_padding; + + mkldnnUtils::poolingBpMKLDNN(input, gradO, gradI, 0, kH, kW, 0, sH, sW, 0, pH, + pW, isNCHW, mode); + + return Status::OK(); } ////////////////////////////////////////////////////////////////////////// PLATFORM_CHECK(avgpool2d_bp, ENGINE_CPU) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - return block.isUseMKLDNN() && sd::MKLDNNStream::isSupported({input, output}); + return block.isUseMKLDNN() && sd::MKLDNNStream::isSupported({input, output}); } - -} -} -} \ No newline at end of file +} // namespace platforms +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling3d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling3d.cpp index 96110bd295d5..a22f3abf6886 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling3d.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling3d.cpp @@ -20,120 +20,155 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include +#include #include +#include +#include #include -#include #include "mkldnnUtils.h" -#include using namespace dnnl; -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace platforms { ////////////////////////////////////////////////////////////////////// PLATFORM_IMPL(avgpool3dnew, ENGINE_CPU) { - - auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) - auto output = OUTPUT_VARIABLE(0); // [bS, oD, oH, oW, iC] (NDHWC) or [bS, iC, oD, oH, oW] (NCDHW) - - int kD = INT_ARG(0); // filter(kernel) depth - int kH = INT_ARG(1); // filter(kernel) height - int kW = INT_ARG(2); // filter(kernel) width - int sD = INT_ARG(3); // strides depth - int sH = INT_ARG(4); // strides height - int sW = INT_ARG(5); // strides width - int pD = INT_ARG(6); // paddings depth - int pH = INT_ARG(7); // paddings height - int pW = INT_ARG(8); // paddings width - int dD = INT_ARG(9); // dilations depth - int dH = INT_ARG(10); // dilations height - int dW = INT_ARG(11); // dilations width - int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID - int extraParam0 = INT_ARG(13); - int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1; // 0-NCDHW, 1-NDHWC - - REQUIRE_TRUE(input->rankOf() == 5, 0, "AVGPOOL3DNEW MKLDNN OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf()); - REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "AVGPOOL3DNEW MKLDNN OP: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW); - - int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; - int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, 0, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); - - if(paddingMode) // SAME - ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW); - - auto mode = (extraParam0 == 0) ? algorithm::pooling_avg_exclude_padding : algorithm::pooling_avg_include_padding; - - mkldnnUtils::poolingMKLDNN(input, output, kD,kH,kW, sD,sH,sW, pD,pH,pW, isNCDHW, mode); - - return Status::OK(); + auto input = INPUT_VARIABLE( + 0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) + auto output = OUTPUT_VARIABLE( + 0); // [bS, oD, oH, oW, iC] (NDHWC) or [bS, iC, oD, oH, oW] (NCDHW) + + int kD = INT_ARG(0); // filter(kernel) depth + int kH = INT_ARG(1); // filter(kernel) height + int kW = INT_ARG(2); // filter(kernel) width + int sD = INT_ARG(3); // strides depth + int sH = INT_ARG(4); // strides height + int sW = INT_ARG(5); // strides width + int pD = INT_ARG(6); // paddings depth + int pH = INT_ARG(7); // paddings height + int pW = INT_ARG(8); // paddings width + int dD = INT_ARG(9); // dilations depth + int dH = INT_ARG(10); // dilations height + int dW = INT_ARG(11); // dilations width + int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID + int extraParam0 = INT_ARG(13); + int isNCDHW = block.numI() > 14 ? !INT_ARG(14) : 1; // 0-NCDHW, 1-NDHWC + + REQUIRE_TRUE(input->rankOf() == 5, 0, + "AVGPOOL3DNEW MKLDNN OP: rank of input array must be equal to " + "5, but got %i instead !", + input->rankOf()); + REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, + "AVGPOOL3DNEW MKLDNN OP: dilation must not be zero, but got " + "instead {%i, %i, %i}", + dD, dH, dW); + + int bS, iC, iD, iH, iW, oC, oD, oH, + oW; // batch size, input channels, input depth/height/width, output + // channels, output depth/height/width; + int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv3d( + isNCDHW, 0, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, + indIOioD, indWiC, indWoC, indWkD); + + if (paddingMode) // SAME + ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, + kW, sD, sH, sW, dD, dH, dW); + + auto mode = (extraParam0 == 0) ? algorithm::pooling_avg_exclude_padding + : algorithm::pooling_avg_include_padding; + + mkldnnUtils::poolingMKLDNN(input, output, kD, kH, kW, sD, sH, sW, pD, pH, pW, + isNCDHW, mode); + + return Status::OK(); } ////////////////////////////////////////////////////////////////////// PLATFORM_CHECK(avgpool3dnew, ENGINE_CPU) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - return block.isUseMKLDNN() && sd::MKLDNNStream::isSupported({input, output}); + return block.isUseMKLDNN() && sd::MKLDNNStream::isSupported({input, output}); } ////////////////////////////////////////////////////////////////////// PLATFORM_IMPL(avgpool3dnew_bp, ENGINE_CPU) { - - auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) - auto gradO = INPUT_VARIABLE(1); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next - auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon - - const int kD = INT_ARG(0); // filter(kernel) depth - const int kH = INT_ARG(1); // filter(kernel) height - const int kW = INT_ARG(2); // filter(kernel) width - const int sD = INT_ARG(3); // strides depth - const int sH = INT_ARG(4); // strides height - const int sW = INT_ARG(5); // strides width - int pD = INT_ARG(6); // paddings depth - int pH = INT_ARG(7); // paddings height - int pW = INT_ARG(8); // paddings width - const int dD = INT_ARG(9); // dilations depth - const int dH = INT_ARG(10); // dilations height - const int dW = INT_ARG(11); // dilations width - const int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID - const int extraParam0 = INT_ARG(13); // define what divisor to use while averaging - const int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1; // 0-NCDHW, 1-NDHWC - - REQUIRE_TRUE(input->rankOf() == 5, 0, "AVGPOOL3DNEW_BP MKLDNN op: input should have rank of 5, but got %i instead", input->rankOf()); - REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "AVGPOOL3DNEW_BP MKLDNN op: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW); - - int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; - int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, 0, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); - - std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); - REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "AVGPOOL3DNEW_BP MKLDNN op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); - - if(paddingMode) // SAME - ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW); - - auto mode = (extraParam0 == 0) ? algorithm::pooling_avg_exclude_padding : algorithm::pooling_avg_include_padding; - - mkldnnUtils::poolingBpMKLDNN(input, gradO, gradI, kD,kH,kW, sD,sH,sW, pD,pH,pW, isNCDHW, mode); - - return Status::OK(); + auto input = INPUT_VARIABLE( + 0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) + auto gradO = INPUT_VARIABLE(1); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, + // oD, oH, oW] (NCDHW), epsilon_next + auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, + // iD, iH, iW] (NCDHW), epsilon + + const int kD = INT_ARG(0); // filter(kernel) depth + const int kH = INT_ARG(1); // filter(kernel) height + const int kW = INT_ARG(2); // filter(kernel) width + const int sD = INT_ARG(3); // strides depth + const int sH = INT_ARG(4); // strides height + const int sW = INT_ARG(5); // strides width + int pD = INT_ARG(6); // paddings depth + int pH = INT_ARG(7); // paddings height + int pW = INT_ARG(8); // paddings width + const int dD = INT_ARG(9); // dilations depth + const int dH = INT_ARG(10); // dilations height + const int dW = INT_ARG(11); // dilations width + const int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID + const int extraParam0 = + INT_ARG(13); // define what divisor to use while averaging + const int isNCDHW = block.numI() > 14 ? !INT_ARG(14) : 1; // 0-NCDHW, 1-NDHWC + + REQUIRE_TRUE(input->rankOf() == 5, 0, + "AVGPOOL3DNEW_BP MKLDNN op: input should have rank of 5, but " + "got %i instead", + input->rankOf()); + REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, + "AVGPOOL3DNEW_BP MKLDNN op: dilation must not be zero, but got " + "instead {%i, %i, %i}", + dD, dH, dW); + + int bS, iC, iD, iH, iW, oC, oD, oH, + oW; // batch size, input channels, input depth/height/width, output + // channels, output depth/height/width; + int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv3d( + isNCDHW, 0, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, + indIOioD, indWiC, indWoC, indWkD); + + std::vector expectedGradOShape = + ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, oD, oH, oW, 0, indIOioC, + indIOioD, indIOioD + 1, + indIOioD + 2}); + REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, + "AVGPOOL3DNEW_BP MKLDNN op: wrong shape of output's gradients " + "array (next epsilon), expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedGradOShape).c_str(), + ShapeUtils::shapeAsString(gradO).c_str()); + + if (paddingMode) // SAME + ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, + kW, sD, sH, sW, dD, dH, dW); + + auto mode = (extraParam0 == 0) ? algorithm::pooling_avg_exclude_padding + : algorithm::pooling_avg_include_padding; + + mkldnnUtils::poolingBpMKLDNN(input, gradO, gradI, kD, kH, kW, sD, sH, sW, pD, + pH, pW, isNCDHW, mode); + + return Status::OK(); } ////////////////////////////////////////////////////////////////////// PLATFORM_CHECK(avgpool3dnew_bp, ENGINE_CPU) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - return block.isUseMKLDNN() && sd::MKLDNNStream::isSupported({input, output}); + return block.isUseMKLDNN() && sd::MKLDNNStream::isSupported({input, output}); } - -} -} -} \ No newline at end of file +} // namespace platforms +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/batchnorm.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/batchnorm.cpp index 6e0b1685a51e..3c526d9e39e1 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/batchnorm.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/batchnorm.cpp @@ -21,427 +21,490 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include +#include +#include #include +#include +#include #include -#include #include "mkldnnUtils.h" -#include -#include - -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace platforms { - ////////////////////////////////////////////////////////////////////////// -static void batchnormMKLDNN(const NDArray* x, const NDArray* mean, const NDArray* variance, const NDArray* weights, NDArray* z, - const float epsilon, const bool isNCHW) { - - // unfortunately mkl dnn doesn't support any format (dnnl::memory::format_tag::any) for x - - // x -> 2D:nc, 4D:nchw/nhwc, 5D:ncdhw/ndhwc - // mean -> 1D [c] - // variance -> 1D [c] - // weights 2D [2, c], weights({0,1, 0,0}) contains gamma and weights({1,2, 0,0}) contains beta - // z(output) - same shape as x - - const int xRank = x->rankOf(); - - // input type - dnnl::memory::data_type type = dnnl::memory::data_type::f32; - - // indicate whether gamma or/and beta are given - auto flags = dnnl::normalization_flags::use_global_stats; // don't calculate the mean and variance for each mini-batch - if (weights != nullptr) - flags |= dnnl::normalization_flags::use_scale_shift; - - dnnl::memory::dims dims; - dnnl::memory::format_tag format; - - const int indHW = isNCHW ? 2 : 1; - const int bS = x->sizeAt(0); - const int iC = isNCHW ? x->sizeAt(1) : x->sizeAt(-1); - - int iD, iH, iW; - - if(xRank == 2) { - dims = {bS, iC}; - format = dnnl::memory::format_tag::nc; - } - else if(xRank == 4) { - iH = x->sizeAt(indHW); - iW = x->sizeAt(indHW + 1); - dims = {bS, iC, iH, iW}; - format = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; - } - else { // xRank = 5 - iD = x->sizeAt(indHW); - iH = x->sizeAt(indHW + 1); - iW = x->sizeAt(indHW + 2); - dims = {bS, iC, iD, iH, iW}; - format = isNCHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc; - } - - // memory descriptors for arrays - - // x - dnnl::memory::desc x_mkl_md = dnnl::memory::desc(dims, type, format); - dnnl::memory::desc x_user_md = dnnl::memory::desc(dims, type, format); +static void batchnormMKLDNN(const NDArray* x, const NDArray* mean, + const NDArray* variance, const NDArray* weights, + NDArray* z, const float epsilon, + const bool isNCHW) { + // unfortunately mkl dnn doesn't support any format + // (dnnl::memory::format_tag::any) for x + + // x -> 2D:nc, 4D:nchw/nhwc, 5D:ncdhw/ndhwc + // mean -> 1D [c] + // variance -> 1D [c] + // weights 2D [2, c], weights({0,1, 0,0}) contains gamma and weights({1,2, + // 0,0}) contains beta z(output) - same shape as x + + const int xRank = x->rankOf(); + + // input type + dnnl::memory::data_type type = dnnl::memory::data_type::f32; + + // indicate whether gamma or/and beta are given + auto flags = + dnnl::normalization_flags::use_global_stats; // don't calculate the mean + // and variance for each + // mini-batch + if (weights != nullptr) flags |= dnnl::normalization_flags::use_scale_shift; + + dnnl::memory::dims dims; + dnnl::memory::format_tag format; + + const int indHW = isNCHW ? 2 : 1; + const int bS = x->sizeAt(0); + const int iC = isNCHW ? x->sizeAt(1) : x->sizeAt(-1); + + int iD, iH, iW; + + if (xRank == 2) { + dims = {bS, iC}; + format = dnnl::memory::format_tag::nc; + } else if (xRank == 4) { + iH = x->sizeAt(indHW); + iW = x->sizeAt(indHW + 1); + dims = {bS, iC, iH, iW}; + format = isNCHW ? dnnl::memory::format_tag::nchw + : dnnl::memory::format_tag::nhwc; + } else { // xRank = 5 + iD = x->sizeAt(indHW); + iH = x->sizeAt(indHW + 1); + iW = x->sizeAt(indHW + 2); + dims = {bS, iC, iD, iH, iW}; + format = isNCHW ? dnnl::memory::format_tag::ncdhw + : dnnl::memory::format_tag::ndhwc; + } + + // memory descriptors for arrays + + // x + dnnl::memory::desc x_mkl_md = dnnl::memory::desc(dims, type, format); + dnnl::memory::desc x_user_md = dnnl::memory::desc(dims, type, format); mkldnnUtils::setBlockStrides(*x, x_user_md); - // z, output - dnnl::memory::desc z_mkl_md = dnnl::memory::desc(dims, type, dnnl::memory::format_tag::any); - dnnl::memory::desc z_user_md = dnnl::memory::desc(dims, type, format); - mkldnnUtils::setBlockStrides(*z, z_user_md); - auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); + // z, output + dnnl::memory::desc z_mkl_md = + dnnl::memory::desc(dims, type, dnnl::memory::format_tag::any); + dnnl::memory::desc z_user_md = dnnl::memory::desc(dims, type, format); - // batchnorm forward description - dnnl::batch_normalization_forward::desc op_ff_desc(dnnl::prop_kind::forward_inference, x_mkl_md, epsilon, flags); - dnnl::batch_normalization_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine); + mkldnnUtils::setBlockStrides(*z, z_user_md); - // arguments (memory buffers) necessary for calculations - std::unordered_map args; + auto engine = + mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - dnnl::stream stream(engine); + // batchnorm forward description + dnnl::batch_normalization_forward::desc op_ff_desc( + dnnl::prop_kind::forward_inference, x_mkl_md, epsilon, flags); + dnnl::batch_normalization_forward::primitive_desc op_ff_prim_desc(op_ff_desc, + engine); - // provide memory and check whether reorder is required + // arguments (memory buffers) necessary for calculations + std::unordered_map args; - // x - mkldnnUtils::loadDataToMklStream(*x, engine, stream, x_user_md, op_ff_prim_desc.src_desc(), args[DNNL_ARG_SRC]); + dnnl::stream stream(engine); - // z - auto z_user_mem = mkldnnUtils::loadDataToMklStream(*z, engine, stream, z_user_md, op_ff_prim_desc.dst_desc(), args[DNNL_ARG_DST]); + // provide memory and check whether reorder is required - // mean - auto mean_mkl_mem = dnnl::memory(op_ff_prim_desc.mean_desc(), engine, const_cast(mean->buffer())); - args[DNNL_ARG_MEAN] = mean_mkl_mem; + // x + mkldnnUtils::loadDataToMklStream(*x, engine, stream, x_user_md, + op_ff_prim_desc.src_desc(), + args[DNNL_ARG_SRC]); - // variance - auto var_mkl_mem = dnnl::memory(op_ff_prim_desc.variance_desc(), engine, const_cast(variance->buffer())); - args[DNNL_ARG_VARIANCE] = var_mkl_mem; + // z + auto z_user_mem = mkldnnUtils::loadDataToMklStream(*z, engine, stream, z_user_md, op_ff_prim_desc.dst_desc(), + args[DNNL_ARG_DST] ); - // gamma and beta (and their gradients) if they are present - if(weights != nullptr) { + // mean + auto mean_mkl_mem = dnnl::memory(op_ff_prim_desc.mean_desc(), engine, + const_cast(mean->buffer())); + args[DNNL_ARG_MEAN] = mean_mkl_mem; - auto w_mkl_mem = dnnl::memory(op_ff_prim_desc.weights_desc(), engine, const_cast(weights->buffer())); - args[DNNL_ARG_WEIGHTS] = w_mkl_mem; - } + // variance + auto var_mkl_mem = dnnl::memory(op_ff_prim_desc.variance_desc(), engine, + const_cast(variance->buffer())); + args[DNNL_ARG_VARIANCE] = var_mkl_mem; - // run calculations - dnnl::batch_normalization_forward(op_ff_prim_desc).execute(stream, args); + // gamma and beta (and their gradients) if they are present + if (weights != nullptr) { + auto w_mkl_mem = dnnl::memory(op_ff_prim_desc.weights_desc(), engine, + const_cast(weights->buffer())); + args[DNNL_ARG_WEIGHTS] = w_mkl_mem; + } - // reorder outputs if necessary - if (op_ff_prim_desc.dst_desc() != z_user_mem.get_desc()) - dnnl::reorder(args[DNNL_ARG_DST], z_user_mem).execute(stream, args[DNNL_ARG_DST], z_user_mem); + // run calculations + dnnl::batch_normalization_forward(op_ff_prim_desc).execute(stream, args); - stream.wait(); + // reorder outputs if necessary + if (op_ff_prim_desc.dst_desc() != z_user_mem.get_desc()) + dnnl::reorder(args[DNNL_ARG_DST], z_user_mem).execute(stream, args[DNNL_ARG_DST], z_user_mem); - // shape::printArray(z_mkl_mem.map_data(),8); -} + stream.wait(); + // shape::printArray(z_mkl_mem.map_data(),8); +} ////////////////////////////////////////////////////////////////////////// -static void batchnormBpMKLDNN(const NDArray* x, const NDArray* mean, const NDArray* variance, const NDArray &dLdO, const NDArray* weights, - NDArray* dLdI, NDArray* dLdW, const float epsilon, const bool isNCHW) { - - // unfortunately mkl dnn doesn't support any format (dnnl::memory::format_tag::any) for x - - // x -> 2D:nc, 4D:nchw/nhwc, 5D:ncdhw/ndhwc - // mean -> 1D [c] - // variance -> 1D [c] - // dLdO - same shape as x - // weights 2D [2, c], weights({0,1, 0,0}) contains gamma and weights({1,2, 0,0}) contains beta - // dLdI - same shape as x - // dLdW - same shape as weights, dLdW({0,1, 0,0}) contains grad_gamma and dLdW({1,2, 0,0}) contains grad_beta - - const int xRank = x->rankOf(); - - // input type - dnnl::memory::data_type type = dnnl::memory::data_type::f32; - - // indicate whether gamma or/and beta are given - auto flags = dnnl::normalization_flags::use_global_stats; // don't calculate the mean and variance for each mini-batch - if (weights != nullptr) - flags |= dnnl::normalization_flags::use_scale_shift; - - dnnl::memory::dims dims; - dnnl::memory::format_tag format; - - const int indHW = isNCHW ? 2 : 1; - const int bS = x->sizeAt(0); - const int iC = isNCHW ? x->sizeAt(1) : x->sizeAt(-1); - - int iD, iH, iW; - - if(xRank == 2) { - dims = {bS, iC}; - format = dnnl::memory::format_tag::nc; - } - else if(xRank == 4) { - iH = x->sizeAt(indHW); - iW = x->sizeAt(indHW + 1); - dims = {bS, iC, iH, iW}; - format = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; - } - else { // xRank = 5 - iD = x->sizeAt(indHW); - iH = x->sizeAt(indHW + 1); - iW = x->sizeAt(indHW + 2); - dims = {bS, iC, iD, iH, iW}; - format = isNCHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc; - } - - // memory descriptors for arrays - - // x - dnnl::memory::desc x_mkl_md = dnnl::memory::desc(dims, type, format); - dnnl::memory::desc x_user_md = dnnl::memory::desc(dims, type, format); - mkldnnUtils::setBlockStrides(*x, x_user_md); - - // dLdO - dnnl::memory::desc dLdO_mkl_md = dnnl::memory::desc(dims, type, dnnl::memory::format_tag::any); - dnnl::memory::desc dLdO_user_md = dnnl::memory::desc(dims, type, format); - mkldnnUtils::setBlockStrides(dLdO, dLdO_user_md); - - // dLdI - dnnl::memory::desc dLdI_mkl_md = dnnl::memory::desc(dims, type, dnnl::memory::format_tag::any); - dnnl::memory::desc dLdI_user_md = dnnl::memory::desc(dims, type, format); - mkldnnUtils::setBlockStrides(*dLdI, dLdI_user_md); - - auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - - // batchnorm forward description - dnnl::batch_normalization_forward::desc op_ff_desc(dnnl::prop_kind::forward_inference, x_mkl_md, epsilon, flags); - dnnl::batch_normalization_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine); - - // batchnorm backprop description - dnnl::batch_normalization_backward::desc op_bp_desc(dnnl::prop_kind::backward, dLdO_mkl_md, x_mkl_md, epsilon, flags); - dnnl::batch_normalization_backward::primitive_desc op_bp_prim_desc(op_bp_desc, engine, op_ff_prim_desc); - - // arguments (memory buffers) necessary for calculations - std::unordered_map args; - - dnnl::stream stream(engine); - - // provide memory and check whether reorder is required - - // x - mkldnnUtils::loadDataToMklStream(*x, engine, stream, x_user_md, op_bp_prim_desc.src_desc(), args[DNNL_ARG_SRC]); - - // dLdO - mkldnnUtils::loadDataToMklStream(dLdO, engine, stream, dLdO_user_md, op_bp_prim_desc.diff_dst_desc(), args[DNNL_ARG_DIFF_DST]); - - // mean - auto mean_mkl_mem = dnnl::memory(op_bp_prim_desc.mean_desc(), engine, const_cast(mean->buffer())); - args[DNNL_ARG_MEAN] = mean_mkl_mem; - - // variance - auto var_mkl_mem = dnnl::memory(op_bp_prim_desc.variance_desc(), engine, const_cast(variance->buffer())); - args[DNNL_ARG_VARIANCE] = var_mkl_mem; - - // dLdI - auto dLdI_user_mem = mkldnnUtils::loadDataToMklStream(*dLdI, engine, stream, dLdI_user_md, op_bp_prim_desc.diff_src_desc(), args[DNNL_ARG_DIFF_SRC]); - - // gamma and beta (and their gradients) if they are present - if(weights != nullptr) { - - auto w_mkl_mem = dnnl::memory(op_bp_prim_desc.weights_desc(), engine, const_cast(weights->buffer())); - args[DNNL_ARG_WEIGHTS] = w_mkl_mem; - - auto dLdW_mkl_mem = dnnl::memory(op_bp_prim_desc.weights_desc(), engine, dLdW->buffer()); - args[DNNL_ARG_DIFF_WEIGHTS] = dLdW_mkl_mem; - } - - // run calculations - dnnl::batch_normalization_backward(op_bp_prim_desc).execute(stream, args); - - // reorder outputs if necessary - if (op_bp_prim_desc.diff_src_desc() != dLdI_user_mem.get_desc()) - dnnl::reorder(args[DNNL_ARG_DIFF_SRC], dLdI_user_mem).execute(stream, args[DNNL_ARG_DIFF_SRC], dLdI_user_mem); - - stream.wait(); - - // shape::printArray(dLdI_mkl_mem.map_data(),8); - - // notations: - // f = g * (gamma * ((x - m) / (v + eps)^0.5) + beta) -> means dLdO * ff_output - // g = dLdO - // stdInv = 1 / (v + eps)^0.5 - // N - batch size (product of spatial dimensions) - - // formula for full derivative with respect to input (x) - // dLdI = dfdx + dfdm*dmdx + dfdv*(dvdm*dmdx + dvdx) - - // !!! MKL CALCULATES ONLY FIRST TERM dfdx, SO WE SHOULD CALCULATE TERM (dfdm*dmdx + dfdv*(dvdm*dmdx + dvdx)) BY OURSELF !!! - - // dfdm = -gamma*stdInv*g_sum; - // dmdx = 1/N; - // dvdx = 2 * (x - m) / N - // dvdm = -2 * [(x - m)]_sum / N - // dfdv = -0.5 * [g*(x - m)]_sum * stdInv^3, drop gamma here for calc convenience - - // finally: - // dLdI = dfdm / N + (2/N) * dfdv * (dvdm/2 + (x - m)) - // dLdI = gamma * ( stdInv * -g_sum/N + (2/N) * dfdv * (dvdm/2 + (x - m)) ) - - std::vector axes = isNCHW ? std::vector{1} : std::vector{xRank - 1}; - const auto excludedAxes = ShapeUtils::evalDimsToExclude(x->rankOf(), axes); - - // inversed batch size 1 / N - const auto Ninv = 1.f * mean->lengthOf() / x->lengthOf(); - - // x - mean - NDArray xMinusMean(x); // empty array with same shape as x - const_cast(x)->applyBroadcast(sd::broadcast::Subtract, axes, *mean, xMinusMean); - - // stdInv - NDArray stdInv = *variance + epsilon; - stdInv.applyTransform(transform::Reciprocal, stdInv); // 1 / (variance + epsilon) - stdInv.applyTransform(transform::Sqrt, stdInv); // 1 / (variance + epsilon)^0.5 - - // dfdm / N - auto dfdm = dLdO.reduceAlongDimension(sd::reduce::Sum, excludedAxes); - dfdm *= stdInv; - dfdm *= -Ninv; - - // dvdm / 2 - NDArray dvdm(mean); // empty array with same shape as mean - xMinusMean.reduceAlongDimension(sd::reduce::Sum, dvdm, excludedAxes); - dvdm *= -Ninv; - - // (2/N)*dfdv - NDArray dfdv(variance); // empty array with same shape as variance - (xMinusMean * dLdO).reduceAlongDimension(sd::reduce::Sum, dfdv, excludedAxes); - dfdv *= stdInv*stdInv*stdInv; - dfdv *= -Ninv; - - // dvdm/2 + (x - m) - xMinusMean.applyBroadcast(sd::broadcast::Add, axes, dvdm, xMinusMean); - // dfdv * (dvdm/2 + (x - m)) - xMinusMean.applyBroadcast(sd::broadcast::Multiply, axes, dfdv, xMinusMean); - // add dfdm / N - xMinusMean.applyBroadcast(sd::broadcast::Add, axes, dfdm, xMinusMean); - // * gamma - auto gamma = (*weights)({0,1, 0,0}); - xMinusMean.applyBroadcast(sd::broadcast::Multiply, axes, gamma, xMinusMean); - - *dLdI += xMinusMean; +static void batchnormBpMKLDNN(const NDArray* x, const NDArray* mean, + const NDArray* variance, + const NDArray& dLdO, const NDArray* weights, + NDArray* dLdI, NDArray* dLdW, + const float epsilon, const bool isNCHW) { + // unfortunately mkl dnn doesn't support any format + // (dnnl::memory::format_tag::any) for x + + // x -> 2D:nc, 4D:nchw/nhwc, 5D:ncdhw/ndhwc + // mean -> 1D [c] + // variance -> 1D [c] + // dLdO - same shape as x + // weights 2D [2, c], weights({0,1, 0,0}) contains gamma and weights({1,2, + // 0,0}) contains beta dLdI - same shape as x dLdW - same shape as weights, + // dLdW({0,1, 0,0}) contains grad_gamma and dLdW({1,2, 0,0}) contains + // grad_beta + + const int xRank = x->rankOf(); + + // input type + dnnl::memory::data_type type = dnnl::memory::data_type::f32; + + // indicate whether gamma or/and beta are given + auto flags = + dnnl::normalization_flags::use_global_stats; // don't calculate the mean + // and variance for each + // mini-batch + if (weights != nullptr) flags |= dnnl::normalization_flags::use_scale_shift; + + dnnl::memory::dims dims; + dnnl::memory::format_tag format; + + const int indHW = isNCHW ? 2 : 1; + const int bS = x->sizeAt(0); + const int iC = isNCHW ? x->sizeAt(1) : x->sizeAt(-1); + + int iD, iH, iW; + + if (xRank == 2) { + dims = {bS, iC}; + format = dnnl::memory::format_tag::nc; + } else if (xRank == 4) { + iH = x->sizeAt(indHW); + iW = x->sizeAt(indHW + 1); + dims = {bS, iC, iH, iW}; + format = isNCHW ? dnnl::memory::format_tag::nchw + : dnnl::memory::format_tag::nhwc; + } else { // xRank = 5 + iD = x->sizeAt(indHW); + iH = x->sizeAt(indHW + 1); + iW = x->sizeAt(indHW + 2); + dims = {bS, iC, iD, iH, iW}; + format = isNCHW ? dnnl::memory::format_tag::ncdhw + : dnnl::memory::format_tag::ndhwc; + } + + // memory descriptors for arrays + + // x + dnnl::memory::desc x_mkl_md = dnnl::memory::desc(dims, type, format); + dnnl::memory::desc x_user_md = dnnl::memory::desc(dims, type, format); + + mkldnnUtils::setBlockStrides(*x, x_user_md); + + // dLdO + dnnl::memory::desc dLdO_mkl_md = + dnnl::memory::desc(dims, type, dnnl::memory::format_tag::any); + dnnl::memory::desc dLdO_user_md = dnnl::memory::desc(dims, type, format); + + mkldnnUtils::setBlockStrides(dLdO, dLdO_user_md); + + // dLdI + dnnl::memory::desc dLdI_mkl_md = + dnnl::memory::desc(dims, type, dnnl::memory::format_tag::any); + dnnl::memory::desc dLdI_user_md = dnnl::memory::desc(dims, type, format); + + mkldnnUtils::setBlockStrides(*dLdI, dLdI_user_md); + + auto engine = + mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); + + // batchnorm forward description + dnnl::batch_normalization_forward::desc op_ff_desc( + dnnl::prop_kind::forward_inference, x_mkl_md, epsilon, flags); + dnnl::batch_normalization_forward::primitive_desc op_ff_prim_desc(op_ff_desc, + engine); + + // batchnorm backprop description + dnnl::batch_normalization_backward::desc op_bp_desc( + dnnl::prop_kind::backward, dLdO_mkl_md, x_mkl_md, epsilon, flags); + dnnl::batch_normalization_backward::primitive_desc op_bp_prim_desc( + op_bp_desc, engine, op_ff_prim_desc); + + // arguments (memory buffers) necessary for calculations + std::unordered_map args; + + dnnl::stream stream(engine); + + // provide memory and check whether reorder is required + + // x + mkldnnUtils::loadDataToMklStream(*x, engine, stream, x_user_md, + op_bp_prim_desc.src_desc(), + args[DNNL_ARG_SRC]); + + // dLdO + mkldnnUtils::loadDataToMklStream(dLdO, engine, stream, dLdO_user_md, + op_bp_prim_desc.diff_dst_desc(), + args[DNNL_ARG_DIFF_DST]); + + // mean + auto mean_mkl_mem = dnnl::memory(op_bp_prim_desc.mean_desc(), engine, + const_cast(mean->buffer())); + args[DNNL_ARG_MEAN] = mean_mkl_mem; + + // variance + auto var_mkl_mem = dnnl::memory(op_bp_prim_desc.variance_desc(), engine, + const_cast(variance->buffer())); + args[DNNL_ARG_VARIANCE] = var_mkl_mem; + + // dLdI + auto dLdI_user_mem = mkldnnUtils::loadDataToMklStream(*dLdI, engine, stream, dLdI_user_md, + op_bp_prim_desc.diff_src_desc() , + args[DNNL_ARG_DIFF_SRC] ); + + // gamma and beta (and their gradients) if they are present + if (weights != nullptr) { + auto w_mkl_mem = dnnl::memory(op_bp_prim_desc.weights_desc(), engine, + const_cast(weights->buffer())); + args[DNNL_ARG_WEIGHTS] = w_mkl_mem; + + auto dLdW_mkl_mem = + dnnl::memory(op_bp_prim_desc.weights_desc(), engine, dLdW->buffer()); + args[DNNL_ARG_DIFF_WEIGHTS] = dLdW_mkl_mem; + } + + // run calculations + dnnl::batch_normalization_backward(op_bp_prim_desc).execute(stream, args); + + // reorder outputs if necessary + if (op_bp_prim_desc.diff_src_desc() != dLdI_user_mem.get_desc()) + dnnl::reorder(args[DNNL_ARG_DIFF_SRC], dLdI_user_mem) + .execute(stream, args[DNNL_ARG_DIFF_SRC], dLdI_user_mem); + + stream.wait(); + + // shape::printArray(dLdI_mkl_mem.map_data(),8); + + // notations: + // f = g * (gamma * ((x - m) / (v + eps)^0.5) + beta) -> means dLdO * + // ff_output g = dLdO stdInv = 1 / (v + eps)^0.5 N - batch size (product of + // spatial dimensions) + + // formula for full derivative with respect to input (x) + // dLdI = dfdx + dfdm*dmdx + dfdv*(dvdm*dmdx + dvdx) + + // !!! MKL CALCULATES ONLY FIRST TERM dfdx, SO WE SHOULD CALCULATE TERM + // (dfdm*dmdx + dfdv*(dvdm*dmdx + dvdx)) BY OURSELF !!! + + // dfdm = -gamma*stdInv*g_sum; + // dmdx = 1/N; + // dvdx = 2 * (x - m) / N + // dvdm = -2 * [(x - m)]_sum / N + // dfdv = -0.5 * [g*(x - m)]_sum * stdInv^3, drop gamma here for calc + // convenience + + // finally: + // dLdI = dfdm / N + (2/N) * dfdv * (dvdm/2 + (x - m)) + // dLdI = gamma * ( stdInv * -g_sum/N + (2/N) * dfdv * (dvdm/2 + (x - m)) ) + + std::vector axes = + isNCHW ? std::vector{1} : std::vector{xRank - 1}; + const auto excludedAxes = ShapeUtils::evalDimsToExclude(x->rankOf(), axes); + + // inversed batch size 1 / N + const auto Ninv = 1.f * mean->lengthOf() / x->lengthOf(); + + // x - mean + NDArray xMinusMean(x); // empty array with same shape as x + const_cast(x)->applyBroadcast(sd::broadcast::Subtract, axes, *mean, + xMinusMean); + + // stdInv + NDArray stdInv = *variance + epsilon; + stdInv.applyTransform(transform::Reciprocal, + stdInv); // 1 / (variance + epsilon) + stdInv.applyTransform(transform::Sqrt, + stdInv); // 1 / (variance + epsilon)^0.5 + + // dfdm / N + auto dfdm = dLdO.reduceAlongDimension(sd::reduce::Sum, excludedAxes); + dfdm *= stdInv; + dfdm *= -Ninv; + + // dvdm / 2 + NDArray dvdm(mean); // empty array with same shape as mean + xMinusMean.reduceAlongDimension(sd::reduce::Sum, dvdm, excludedAxes); + dvdm *= -Ninv; + + // (2/N)*dfdv + NDArray dfdv(variance); // empty array with same shape as variance + (xMinusMean * dLdO).reduceAlongDimension(sd::reduce::Sum, dfdv, excludedAxes); + dfdv *= stdInv * stdInv * stdInv; + dfdv *= -Ninv; + + // dvdm/2 + (x - m) + xMinusMean.applyBroadcast(sd::broadcast::Add, axes, dvdm, xMinusMean); + // dfdv * (dvdm/2 + (x - m)) + xMinusMean.applyBroadcast(sd::broadcast::Multiply, axes, dfdv, xMinusMean); + // add dfdm / N + xMinusMean.applyBroadcast(sd::broadcast::Add, axes, dfdm, xMinusMean); + // * gamma + auto gamma = (*weights)({0, 1, 0, 0}); + xMinusMean.applyBroadcast(sd::broadcast::Multiply, axes, gamma, xMinusMean); + + *dLdI += xMinusMean; } PLATFORM_IMPL(batchnorm, ENGINE_CPU) { - - auto input = INPUT_VARIABLE(0); // 2D:nc, 4D:nchw/nhwc, 5D:ncdhw/ndhwc - auto mean = INPUT_VARIABLE(1); // [c] - auto variance = INPUT_VARIABLE(2); // [c] - NDArray* gamma = nullptr; // [c] - NDArray* beta = nullptr; // [c] - - auto output = OUTPUT_VARIABLE(0); // same shape as input - - const bool applyScale = (bool)INT_ARG(0); - const bool applyOffset = (bool)INT_ARG(1); - const double epsilon = T_ARG(0); - - if(applyScale) - gamma = INPUT_VARIABLE(3); - if(applyOffset) - beta = INPUT_VARIABLE(3 + (int)applyScale); - - const int numOfIntArgs = block.getIArguments()->size(); - const int inRank = input->rankOf(); - - // get axes args to normalize input array over - std::vector axes; - if(numOfIntArgs > 2) - for(int i = 2; i < numOfIntArgs; ++i) - axes.push_back(INT_ARG(i)); + auto input = INPUT_VARIABLE(0); // 2D:nc, 4D:nchw/nhwc, 5D:ncdhw/ndhwc + auto mean = INPUT_VARIABLE(1); // [c] + auto variance = INPUT_VARIABLE(2); // [c] + NDArray* gamma = nullptr; // [c] + NDArray* beta = nullptr; // [c] + + auto output = OUTPUT_VARIABLE(0); // same shape as input + + const bool applyScale = (bool)INT_ARG(0); + const bool applyOffset = (bool)INT_ARG(1); + const double epsilon = T_ARG(0); + + if (applyScale) gamma = INPUT_VARIABLE(3); + if (applyOffset) beta = INPUT_VARIABLE(3 + (int)applyScale); + + const int numOfIntArgs = block.numI(); + const int inRank = input->rankOf(); + + // get axes args to normalize input array over + std::vector axes; + if (numOfIntArgs > 2) + for (int i = 2; i < numOfIntArgs; ++i) axes.push_back(INT_ARG(i)); + else + axes.push_back(inRank - + 1); // default dimension to reduce along is last dimension + + const int numOfAxes = axes.size(); + REQUIRE_TRUE(numOfAxes == 1, 0, + "BATCHNORM_MKLDNN op: mkl dnn library supports only one axis " + "which represents channel dimension, but got %i axes instead!", + numOfAxes); + REQUIRE_TRUE(inRank == 2 || inRank == 4 || inRank == 5, 0, + "BATCHNORM_MKLDNN op: possible values for rank of input array " + "are 2, 4 or 5, but got %i instead!", + inRank); + REQUIRE_TRUE(mean->rankOf() == 1 && mean->sizeAt(0) == input->sizeAt(axes[0]), + 0, + "BATCHNORM_MKLDNN op: wrong shape of mean array, expected is " + "[%lld], but got %s instead !", + input->sizeAt(axes[0]), ShapeUtils::shapeAsString(mean).c_str()); + REQUIRE_TRUE( + variance->rankOf() == 1 && variance->sizeAt(0) == input->sizeAt(axes[0]), + 0, + "BATCHNORM_MKLDNN op: wrong shape of variance array, expected is [%lld], " + "but got %s instead !", + input->sizeAt(axes[0]), ShapeUtils::shapeAsString(variance).c_str()); + if (gamma != nullptr) + REQUIRE_TRUE( + gamma->rankOf() == 1 && gamma->sizeAt(0) == input->sizeAt(axes[0]), 0, + "BATCHNORM_MKLDNN op: wrong shape of gamma array, expected is [%lld], " + "but got %s instead !", + input->sizeAt(axes[0]), ShapeUtils::shapeAsString(gamma).c_str()); + if (beta != nullptr) + REQUIRE_TRUE( + beta->rankOf() == 1 && beta->sizeAt(0) == input->sizeAt(axes[0]), 0, + "BATCHNORM_MKLDNN op: wrong shape of beta array, expected is [%lld], " + "but got %s instead !", + input->sizeAt(axes[0]), ShapeUtils::shapeAsString(beta).c_str()); + + // types of all input arrays should be the same (except dLdO) + for (int i = 1; i < block.width() - 1; ++i) + REQUIRE_TRUE( + INPUT_VARIABLE(0)->dataType() == INPUT_VARIABLE(i)->dataType(), 0, + "BATCHNORM_MKLDNN op: types of all input arrays should be the same !"); + + NDArray* weights = nullptr; + + if (applyScale || applyOffset) { + weights = new NDArray(input->ordering(), {2, input->sizeAt(axes[0])}, + input->dataType()); + + if (applyScale) + (*weights)({0, 1, 0, 0}).assign(gamma); else - axes.push_back(inRank-1); // default dimension to reduce along is last dimension - - const int numOfAxes = axes.size(); - REQUIRE_TRUE(numOfAxes == 1, 0, "BATCHNORM_MKLDNN op: mkl dnn library supports only one axis which represents channel dimension, but got %i axes instead!", numOfAxes); - REQUIRE_TRUE(inRank == 2 || inRank == 4 || inRank == 5, 0, "BATCHNORM_MKLDNN op: possible values for rank of input array are 2, 4 or 5, but got %i instead!", inRank); - REQUIRE_TRUE(mean->rankOf() == 1 && mean->sizeAt(0) == input->sizeAt(axes[0]), 0, "BATCHNORM_MKLDNN op: wrong shape of mean array, expected is [%lld], but got %s instead !", input->sizeAt(axes[0]), ShapeUtils::shapeAsString(mean).c_str()); - REQUIRE_TRUE(variance->rankOf() == 1 && variance->sizeAt(0) == input->sizeAt(axes[0]), 0, "BATCHNORM_MKLDNN op: wrong shape of variance array, expected is [%lld], but got %s instead !", input->sizeAt(axes[0]), ShapeUtils::shapeAsString(variance).c_str()); - if(gamma != nullptr) - REQUIRE_TRUE(gamma->rankOf() == 1 && gamma->sizeAt(0) == input->sizeAt(axes[0]), 0, "BATCHNORM_MKLDNN op: wrong shape of gamma array, expected is [%lld], but got %s instead !", input->sizeAt(axes[0]), ShapeUtils::shapeAsString(gamma).c_str()); - if(beta != nullptr) - REQUIRE_TRUE(beta->rankOf() == 1 && beta->sizeAt(0) == input->sizeAt(axes[0]), 0, "BATCHNORM_MKLDNN op: wrong shape of beta array, expected is [%lld], but got %s instead !", input->sizeAt(axes[0]), ShapeUtils::shapeAsString(beta).c_str()); - - // types of all input arrays should be the same (except dLdO) - for(int i = 1; i < block.width() - 1; ++i) - REQUIRE_TRUE(INPUT_VARIABLE(0)->dataType() == INPUT_VARIABLE(i)->dataType(), 0, "BATCHNORM_MKLDNN op: types of all input arrays should be the same !"); - - - NDArray *weights = nullptr; - - if(applyScale || applyOffset) { - - weights = new NDArray(input->ordering(), {2, input->sizeAt(axes[0])}, input->dataType()); - - if(applyScale) - (*weights)({0,1, 0,0}).assign(gamma); - else - (*weights)({0,1, 0,0}).assign(1); - if(applyOffset) - (*weights)({1,2, 0,0}).assign(beta); - else - (*weights)({1,2, 0,0}).assign(0); - } + (*weights)({0, 1, 0, 0}).assign(1); + if (applyOffset) + (*weights)({1, 2, 0, 0}).assign(beta); + else + (*weights)({1, 2, 0, 0}).assign(0); + } - const bool isNCHW = !(axes[0] == inRank - 1 && inRank > 2); + const bool isNCHW = !(axes[0] == inRank - 1 && inRank > 2); - batchnormMKLDNN(input, mean, variance, weights, output, epsilon, isNCHW); + batchnormMKLDNN(input, mean, variance, weights, output, epsilon, isNCHW); - delete weights; + delete weights; - return Status::OK(); + return Status::OK(); } ////////////////////////////////////////////////////////////////////////// PLATFORM_CHECK(batchnorm, ENGINE_CPU) { - - auto input = INPUT_VARIABLE(0); // 2D:nc, 4D:nchw/nhwc, 5D:ncdhw/ndhwc - auto mean = INPUT_VARIABLE(1); // [c] - auto variance = INPUT_VARIABLE(2); // [c] - NDArray* gamma = nullptr; // [c] - NDArray* beta = nullptr; // [c] - - auto output = OUTPUT_VARIABLE(0); // same shape as input - - const bool applyScale = (bool)INT_ARG(0); - const bool applyOffset = (bool)INT_ARG(1); - - if(applyScale) - gamma = INPUT_VARIABLE(3); - if(applyOffset) - beta = INPUT_VARIABLE(3 + (int)applyScale); - - - const int numOfIntArgs = block.getIArguments()->size(); - std::vector axes; - if(numOfIntArgs > 2) - for(int i = 2; i < numOfIntArgs; ++i) - axes.push_back(INT_ARG(i)); - else - axes.push_back(input->rankOf()-1); // default dimension to reduce along is last dimension - - DataType inputType = input->dataType(); - DataType meanType = mean->dataType(); - DataType varType = variance->dataType(); - DataType gammaType = gamma != nullptr ? gamma->dataType() : DataType::FLOAT32; - DataType betaType = beta != nullptr ? beta->dataType() : DataType::FLOAT32; - DataType outType = output->dataType(); - - const int inRank = input->rankOf(); - - return block.isUseMKLDNN() && axes.size() == 1 && (axes[0] == 1 || axes[0] == inRank - 1) && (inRank == 2 || inRank == 4 || inRank == 5) && - (inputType == DataType::FLOAT32 && meanType == DataType::FLOAT32 && varType == DataType::FLOAT32 && - gammaType == DataType::FLOAT32 && betaType == DataType::FLOAT32 && outType == DataType::FLOAT32); + auto input = INPUT_VARIABLE(0); // 2D:nc, 4D:nchw/nhwc, 5D:ncdhw/ndhwc + auto mean = INPUT_VARIABLE(1); // [c] + auto variance = INPUT_VARIABLE(2); // [c] + NDArray* gamma = nullptr; // [c] + NDArray* beta = nullptr; // [c] + + auto output = OUTPUT_VARIABLE(0); // same shape as input + + const bool applyScale = (bool)INT_ARG(0); + const bool applyOffset = (bool)INT_ARG(1); + + if (applyScale) gamma = INPUT_VARIABLE(3); + if (applyOffset) beta = INPUT_VARIABLE(3 + (int)applyScale); + + const int numOfIntArgs = block.numI(); + std::vector axes; + if (numOfIntArgs > 2) + for (int i = 2; i < numOfIntArgs; ++i) axes.push_back(INT_ARG(i)); + else + axes.push_back(input->rankOf() - + 1); // default dimension to reduce along is last dimension + + DataType inputType = input->dataType(); + DataType meanType = mean->dataType(); + DataType varType = variance->dataType(); + DataType gammaType = gamma != nullptr ? gamma->dataType() : DataType::FLOAT32; + DataType betaType = beta != nullptr ? beta->dataType() : DataType::FLOAT32; + DataType outType = output->dataType(); + + const int inRank = input->rankOf(); + + return block.isUseMKLDNN() && axes.size() == 1 && + (axes[0] == 1 || axes[0] == inRank - 1) && + (inRank == 2 || inRank == 4 || inRank == 5) && + (inputType == DataType::FLOAT32 && meanType == DataType::FLOAT32 && + varType == DataType::FLOAT32 && gammaType == DataType::FLOAT32 && + betaType == DataType::FLOAT32 && outType == DataType::FLOAT32); } ////////////////////////////////////////////////////////////////////////// @@ -472,42 +535,53 @@ PLATFORM_CHECK(batchnorm, ENGINE_CPU) { // axes.push_back(input->rankOf() - 1); // std::vector shape({2, mean->lengthOf()}); -// NDArray weights = NDArrayFactory::create('c', shape, block.launchContext()); -// weights({0, 1, 0, 0}).assign(1.0f); -// weights({1, 2, 0, 0}).assign(0.0f); +// NDArray weights = NDArrayFactory::create('c', shape, +// block.launchContext()); weights({0, 1, 0, 0}).assign(1.0f); weights({1, +// 2, 0, 0}).assign(0.0f); // mkldnn_memory_desc_t empty; -// dnnl::memory::desc batchnorm_src_md(empty), batchnorm_dst_md(empty), user_src_md(empty), user_dst_md(empty); +// dnnl::memory::desc batchnorm_src_md(empty), batchnorm_dst_md(empty), +// user_src_md(empty), user_dst_md(empty); // auto flag = dnnl::normalization_flags::use_global_stats; // if (applyScale || applyOffset) // flag |= dnnl::normalization_flags::use_scale_shift; // mkldnnUtils::getMKLDNNMemoryDescBatchNorm(input, nullptr, output, -// &batchnorm_src_md, nullptr, &batchnorm_dst_md, -// &user_src_md, nullptr, &user_dst_md, axes[0]); +// &batchnorm_src_md, nullptr, +// &batchnorm_dst_md, +// &user_src_md, nullptr, +// &user_dst_md, axes[0]); -// auto batchnorm_desc = dnnl::batch_normalization_forward::desc(dnnl::prop_kind::forward_inference, batchnorm_src_md, epsilon, flag); +// auto batchnorm_desc = +// dnnl::batch_normalization_forward::desc(dnnl::prop_kind::forward_inference, +// batchnorm_src_md, epsilon, flag); -// auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); +// auto engine = +// mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); // dnnl::stream stream(engine); -// auto batchnorm_prim_desc = dnnl::batch_normalization_forward::primitive_desc(batchnorm_desc, engine); -// auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer()); -// auto user_dst_memory = dnnl::memory(user_dst_md, engine, output->buffer()); -// auto batchnorm_mean_memory = dnnl::memory(batchnorm_prim_desc.mean_desc(), engine, +// auto batchnorm_prim_desc = +// dnnl::batch_normalization_forward::primitive_desc(batchnorm_desc, +// engine); auto user_src_memory = dnnl::memory(user_src_md, engine, +// input->buffer()); auto user_dst_memory = dnnl::memory(user_dst_md, +// engine, output->buffer()); auto batchnorm_mean_memory = +// dnnl::memory(batchnorm_prim_desc.mean_desc(), engine, // mean->buffer()); -// auto batchnorm_variance_memory = dnnl::memory(batchnorm_prim_desc.variance_desc(), engine, +// auto batchnorm_variance_memory = +// dnnl::memory(batchnorm_prim_desc.variance_desc(), engine, // variance->buffer()); // auto batchnorm_src_memory = user_src_memory; // dnnl::memory m(batchnorm_src_md, engine); // if (m.get_desc() != user_src_memory.get_desc()) { // batchnorm_src_memory = dnnl::memory(batchnorm_src_md, engine); -// dnnl::reorder(user_src_memory, batchnorm_src_memory).execute(stream, user_src_memory, +// dnnl::reorder(user_src_memory, batchnorm_src_memory).execute(stream, +// user_src_memory, // batchnorm_src_memory); // } // auto batchnorm_dst_memory = user_dst_memory; // if (batchnorm_prim_desc.dst_desc() != user_dst_memory.get_desc()) { -// batchnorm_dst_memory = dnnl::memory(batchnorm_prim_desc.dst_desc(), engine); +// batchnorm_dst_memory = dnnl::memory(batchnorm_prim_desc.dst_desc(), +// engine); // } // if (applyScale || applyOffset) { // if (gamma != nullptr) { @@ -517,22 +591,34 @@ PLATFORM_CHECK(batchnorm, ENGINE_CPU) { // weights({1, 2, 0, 0}).assign(beta); // } -// auto batchnorm_weights_memory = dnnl::memory(batchnorm_prim_desc.weights_desc(), engine, weights.buffer()); +// auto batchnorm_weights_memory = +// dnnl::memory(batchnorm_prim_desc.weights_desc(), engine, +// weights.buffer()); // dnnl::batch_normalization_forward(batchnorm_prim_desc).execute(stream, -// {{MKLDNN_ARG_SRC, batchnorm_src_memory}, -// {MKLDNN_ARG_MEAN, batchnorm_mean_memory}, -// {MKLDNN_ARG_VARIANCE, batchnorm_variance_memory}, -// {MKLDNN_ARG_WEIGHTS, batchnorm_weights_memory}, -// {MKLDNN_ARG_DST, batchnorm_dst_memory}}); +// {{MKLDNN_ARG_SRC, +// batchnorm_src_memory}, +// {MKLDNN_ARG_MEAN, +// batchnorm_mean_memory}, +// {MKLDNN_ARG_VARIANCE, +// batchnorm_variance_memory}, +// {MKLDNN_ARG_WEIGHTS, +// batchnorm_weights_memory}, +// {MKLDNN_ARG_DST, +// batchnorm_dst_memory}}); // } else { // dnnl::batch_normalization_forward(batchnorm_prim_desc).execute(stream, -// {{MKLDNN_ARG_SRC, batchnorm_src_memory}, -// {MKLDNN_ARG_MEAN, batchnorm_mean_memory}, -// {MKLDNN_ARG_VARIANCE, batchnorm_variance_memory}, -// {MKLDNN_ARG_DST, batchnorm_dst_memory}}); +// {{MKLDNN_ARG_SRC, +// batchnorm_src_memory}, +// {MKLDNN_ARG_MEAN, +// batchnorm_mean_memory}, +// {MKLDNN_ARG_VARIANCE, +// batchnorm_variance_memory}, +// {MKLDNN_ARG_DST, +// batchnorm_dst_memory}}); // } // if (batchnorm_prim_desc.dst_desc() != user_dst_memory.get_desc()) { -// dnnl::reorder(batchnorm_dst_memory, user_dst_memory).execute(stream, batchnorm_dst_memory, +// dnnl::reorder(batchnorm_dst_memory, user_dst_memory).execute(stream, +// batchnorm_dst_memory, // user_dst_memory); // } // stream.wait(); @@ -571,160 +657,192 @@ PLATFORM_CHECK(batchnorm, ENGINE_CPU) { // axes.push_back(input->rankOf() - 1); // return block.isUseMKLDNN() && -// sd::MKLDNNStream::isSupported({input, mean, variance, gamma, beta, output}) && -// axes.size() == 1; +// sd::MKLDNNStream::isSupported({input, mean, variance, gamma, beta, +// output}) && axes.size() == 1; // } - ////////////////////////////////////////////////////////////////////////// PLATFORM_IMPL(batchnorm_bp, ENGINE_CPU) { - - NDArray* input = INPUT_VARIABLE(0); // 2D:nc, 4D:nchw/nhwc, 5D:ncdhw/ndhwc - NDArray* mean = INPUT_VARIABLE(1); // [c] - NDArray* variance = INPUT_VARIABLE(2); // [c] - NDArray* gamma = nullptr; // [c] - NDArray* beta = nullptr; // [c] - NDArray* dLdO = INPUT_VARIABLE(block.width() - 1); // same as input - - NDArray* dLdI = OUTPUT_VARIABLE(0); // same as input - NDArray* dLdM = OUTPUT_VARIABLE(1); // [c] - NDArray* dLdV = OUTPUT_VARIABLE(2); // [c] - NDArray* dLdG = nullptr; // [c] - NDArray* dLdB = nullptr; // [c] - - const bool applyScale = (bool)INT_ARG(0); - const bool applyOffset = (bool)INT_ARG(1); - const float epsilon = T_ARG(0); - - if(applyScale) { - gamma = INPUT_VARIABLE(3); - dLdG = OUTPUT_VARIABLE(3); - } - if(applyOffset) { - beta = INPUT_VARIABLE(3 + (int)applyScale); - dLdB = OUTPUT_VARIABLE(3 + (int)applyScale); - } - - const int numOfIntArgs = block.getIArguments()->size(); - const int inRank = input->rankOf(); - - // get axes args to normalize input array over - std::vector axes; - if(numOfIntArgs > 2) - for(int i = 2; i < numOfIntArgs; ++i) - axes.push_back(INT_ARG(i)); + NDArray* input = INPUT_VARIABLE(0); // 2D:nc, 4D:nchw/nhwc, 5D:ncdhw/ndhwc + NDArray* mean = INPUT_VARIABLE(1); // [c] + NDArray* variance = INPUT_VARIABLE(2); // [c] + NDArray* gamma = nullptr; // [c] + NDArray* beta = nullptr; // [c] + NDArray* dLdO = INPUT_VARIABLE(block.width() - 1); // same as input + + NDArray* dLdI = OUTPUT_VARIABLE(0); // same as input + NDArray* dLdM = OUTPUT_VARIABLE(1); // [c] + NDArray* dLdV = OUTPUT_VARIABLE(2); // [c] + NDArray* dLdG = nullptr; // [c] + NDArray* dLdB = nullptr; // [c] + + const bool applyScale = (bool)INT_ARG(0); + const bool applyOffset = (bool)INT_ARG(1); + const float epsilon = T_ARG(0); + + if (applyScale) { + gamma = INPUT_VARIABLE(3); + dLdG = OUTPUT_VARIABLE(3); + } + if (applyOffset) { + beta = INPUT_VARIABLE(3 + (int)applyScale); + dLdB = OUTPUT_VARIABLE(3 + (int)applyScale); + } + + const int numOfIntArgs = block.numI(); + const int inRank = input->rankOf(); + + // get axes args to normalize input array over + std::vector axes; + if (numOfIntArgs > 2) + for (int i = 2; i < numOfIntArgs; ++i) axes.push_back(INT_ARG(i)); + else + axes.push_back(inRank - + 1); // default dimension to reduce along is last dimension + + const int numOfAxes = axes.size(); + REQUIRE_TRUE(numOfAxes == 1, 0, + "BATCHNORM_BP_MKLDNN op: mkl dnn library supports only one axis " + "which represents channel dimension, but got %i axes instead!", + numOfAxes); + REQUIRE_TRUE(inRank == 2 || inRank == 4 || inRank == 5, 0, + "BATCHNORM_BP_MKLDNN op: possible values for rank of input " + "array are 2, 4 or 5, but got %i instead!", + inRank); + REQUIRE_TRUE(input->isSameShape(dLdO), 0, + "BATCHNORM_BP_MKLDNN op: wrong shape of gradients array, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(input).c_str(), + ShapeUtils::shapeAsString(dLdO).c_str()); + REQUIRE_TRUE(mean->rankOf() == 1 && mean->sizeAt(0) == input->sizeAt(axes[0]), + 0, + "BATCHNORM_BP_MKLDNN op: wrong shape of mean array, expected is " + "[%lld], but got %s instead !", + input->sizeAt(axes[0]), ShapeUtils::shapeAsString(mean).c_str()); + REQUIRE_TRUE( + variance->rankOf() == 1 && variance->sizeAt(0) == input->sizeAt(axes[0]), + 0, + "BATCHNORM_BP_MKLDNN op: wrong shape of variance array, expected is " + "[%lld], but got %s instead !", + input->sizeAt(axes[0]), ShapeUtils::shapeAsString(variance).c_str()); + if (gamma != nullptr) + REQUIRE_TRUE( + gamma->rankOf() == 1 && gamma->sizeAt(0) == input->sizeAt(axes[0]), 0, + "BATCHNORM_BP_MKLDNN op: wrong shape of gamma array, expected is " + "[%lld], but got %s instead !", + input->sizeAt(axes[0]), ShapeUtils::shapeAsString(gamma).c_str()); + if (beta != nullptr) + REQUIRE_TRUE( + beta->rankOf() == 1 && beta->sizeAt(0) == input->sizeAt(axes[0]), 0, + "BATCHNORM_BP_MKLDNN op: wrong shape of beta array, expected is " + "[%lld], but got %s instead !", + input->sizeAt(axes[0]), ShapeUtils::shapeAsString(beta).c_str()); + + // types of all input arrays should be the same + for (int i = 1; i < block.width() - 1; ++i) + REQUIRE_TRUE(INPUT_VARIABLE(0)->dataType() == INPUT_VARIABLE(i)->dataType(), + 0, + "BATCHNORM_BP_MKLDNN op: types of all input arrays should be " + "the same !"); + + NDArray *weights = nullptr, *dLdW = nullptr; + + if (applyScale || applyOffset) { + weights = new NDArray(input->ordering(), {2, input->sizeAt(axes[0])}, + input->dataType()); + dLdW = new NDArray(input->ordering(), {2, input->sizeAt(axes[0])}, + input->dataType()); + if (applyScale) + (*weights)({0, 1, 0, 0}).assign(gamma); else - axes.push_back(inRank-1); // default dimension to reduce along is last dimension - - const int numOfAxes = axes.size(); - REQUIRE_TRUE(numOfAxes == 1, 0, "BATCHNORM_BP_MKLDNN op: mkl dnn library supports only one axis which represents channel dimension, but got %i axes instead!", numOfAxes); - REQUIRE_TRUE(inRank == 2 || inRank == 4 || inRank == 5, 0, "BATCHNORM_BP_MKLDNN op: possible values for rank of input array are 2, 4 or 5, but got %i instead!", inRank); - REQUIRE_TRUE(input->isSameShape(dLdO), 0, "BATCHNORM_BP_MKLDNN op: wrong shape of gradients array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(input).c_str(), ShapeUtils::shapeAsString(dLdO).c_str()); - REQUIRE_TRUE(mean->rankOf() == 1 && mean->sizeAt(0) == input->sizeAt(axes[0]), 0, "BATCHNORM_BP_MKLDNN op: wrong shape of mean array, expected is [%lld], but got %s instead !", input->sizeAt(axes[0]), ShapeUtils::shapeAsString(mean).c_str()); - REQUIRE_TRUE(variance->rankOf() == 1 && variance->sizeAt(0) == input->sizeAt(axes[0]), 0, "BATCHNORM_BP_MKLDNN op: wrong shape of variance array, expected is [%lld], but got %s instead !", input->sizeAt(axes[0]), ShapeUtils::shapeAsString(variance).c_str()); - if(gamma != nullptr) - REQUIRE_TRUE(gamma->rankOf() == 1 && gamma->sizeAt(0) == input->sizeAt(axes[0]), 0, "BATCHNORM_BP_MKLDNN op: wrong shape of gamma array, expected is [%lld], but got %s instead !", input->sizeAt(axes[0]), ShapeUtils::shapeAsString(gamma).c_str()); - if(beta != nullptr) - REQUIRE_TRUE(beta->rankOf() == 1 && beta->sizeAt(0) == input->sizeAt(axes[0]), 0, "BATCHNORM_BP_MKLDNN op: wrong shape of beta array, expected is [%lld], but got %s instead !", input->sizeAt(axes[0]), ShapeUtils::shapeAsString(beta).c_str()); - - // types of all input arrays should be the same - for(int i = 1; i < block.width() - 1; ++i) - REQUIRE_TRUE(INPUT_VARIABLE(0)->dataType() == INPUT_VARIABLE(i)->dataType(), 0, "BATCHNORM_BP_MKLDNN op: types of all input arrays should be the same !"); - - - NDArray *weights = nullptr, *dLdW = nullptr; - - if(applyScale || applyOffset) { - weights = new NDArray(input->ordering(), {2, input->sizeAt(axes[0])}, input->dataType()); - dLdW = new NDArray(input->ordering(), {2, input->sizeAt(axes[0])}, input->dataType()); - if(applyScale) - (*weights)({0,1, 0,0}).assign(gamma); - else - (*weights)({0,1, 0,0}).assign(1); - if(applyOffset) - (*weights)({1,2, 0,0}).assign(beta); - else - (*weights)({1,2, 0,0}).assign(0); - } - - const bool isNCHW = !(axes[0] == inRank - 1 && inRank > 2); - - if (shape::strideDescendingCAscendingF(dLdO->shapeInfo())) - batchnormBpMKLDNN(input, mean, variance, *dLdO, weights, dLdI, dLdW, epsilon, isNCHW); + (*weights)({0, 1, 0, 0}).assign(1); + if (applyOffset) + (*weights)({1, 2, 0, 0}).assign(beta); else - batchnormBpMKLDNN(input, mean, variance, dLdO->dup(), weights, dLdI, dLdW, epsilon, isNCHW); + (*weights)({1, 2, 0, 0}).assign(0); + } + + const bool isNCHW = !(axes[0] == inRank - 1 && inRank > 2); - *dLdM = 0; - *dLdV = 0; + if (shape::strideDescendingCAscendingF(dLdO->shapeInfo())) + batchnormBpMKLDNN(input, mean, variance, *dLdO, weights, dLdI, dLdW, + epsilon, isNCHW); + else + batchnormBpMKLDNN(input, mean, variance, dLdO->dup(), weights, dLdI, + dLdW, epsilon, isNCHW); - if(applyScale || applyOffset) { - if(applyScale) - dLdG->assign((*dLdW)({0,1, 0,0})); - if(applyOffset) - dLdB->assign((*dLdW)({1,2, 0,0})); + *dLdM = 0; + *dLdV = 0; - delete weights; - delete dLdW; - } + if (applyScale || applyOffset) { + if (applyScale) dLdG->assign((*dLdW)({0, 1, 0, 0})); + if (applyOffset) dLdB->assign((*dLdW)({1, 2, 0, 0})); - return Status::OK(); + delete weights; + delete dLdW; + } + + return Status::OK(); } ////////////////////////////////////////////////////////////////////////// PLATFORM_CHECK(batchnorm_bp, ENGINE_CPU) { - - NDArray* input = INPUT_VARIABLE(0); // 2D:nc, 4D:nchw, 5D:ncdhw - NDArray* mean = INPUT_VARIABLE(1); // [c] - NDArray* variance = INPUT_VARIABLE(2); // [c] - NDArray* dLdO = INPUT_VARIABLE(3); // same as input - NDArray* gamma = nullptr; // [c] - NDArray* beta = nullptr; // [c] - - NDArray* dLdI = OUTPUT_VARIABLE(0); // same as input - NDArray* dLdM = OUTPUT_VARIABLE(1); // [c] - NDArray* dLdV = OUTPUT_VARIABLE(2); // [c] - NDArray* dLdG = nullptr; // [c] - NDArray* dLdB = nullptr; // [c] - - const bool applyScale = (bool)INT_ARG(0); - const bool applyOffset = (bool)INT_ARG(1); - - if(applyScale) { - gamma = INPUT_VARIABLE(4); - dLdG = OUTPUT_VARIABLE(3); - } - if(applyOffset) { - beta = INPUT_VARIABLE(4 + (int)applyScale); - dLdB = OUTPUT_VARIABLE(3 + (int)applyScale); - } - - const int numOfIntArgs = block.getIArguments()->size(); - std::vector axes; - if(numOfIntArgs > 2) - for(int i = 2; i < numOfIntArgs; ++i) - axes.push_back(INT_ARG(i)); - else - axes.push_back(input->rankOf()-1); // default dimension to reduce along is last dimension - - DataType inputType = input->dataType(); - DataType meanType = mean->dataType(); - DataType varType = variance->dataType(); - DataType dLdOType = dLdO->dataType(); - DataType gammaType = gamma != nullptr ? gamma->dataType() : DataType::FLOAT32; - DataType betaType = beta != nullptr ? beta->dataType() : DataType::FLOAT32; - - DataType dLdIType = dLdI->dataType(); - DataType dLdGType = gamma != nullptr ? dLdG->dataType() : DataType::FLOAT32; - DataType dLdBType = beta != nullptr ? dLdB->dataType() : DataType::FLOAT32; - - const int inRank = input->rankOf(); - - return block.isUseMKLDNN() && axes.size() == 1 && (axes[0] == 1 || axes[0] == inRank - 1) && (inRank == 2 || inRank == 4 || inRank == 5) && - (inputType == DataType::FLOAT32 && meanType == DataType::FLOAT32 && varType == DataType::FLOAT32 && - dLdOType == DataType::FLOAT32 && gammaType == DataType::FLOAT32 && betaType == DataType::FLOAT32 && - dLdIType == DataType::FLOAT32 && dLdGType == DataType::FLOAT32 && dLdBType == DataType::FLOAT32); + NDArray* input = INPUT_VARIABLE(0); // 2D:nc, 4D:nchw, 5D:ncdhw + NDArray* mean = INPUT_VARIABLE(1); // [c] + NDArray* variance = INPUT_VARIABLE(2); // [c] + NDArray* dLdO = INPUT_VARIABLE(3); // same as input + NDArray* gamma = nullptr; // [c] + NDArray* beta = nullptr; // [c] + + NDArray* dLdI = OUTPUT_VARIABLE(0); // same as input + NDArray* dLdM = OUTPUT_VARIABLE(1); // [c] + NDArray* dLdV = OUTPUT_VARIABLE(2); // [c] + NDArray* dLdG = nullptr; // [c] + NDArray* dLdB = nullptr; // [c] + + const bool applyScale = (bool)INT_ARG(0); + const bool applyOffset = (bool)INT_ARG(1); + + if (applyScale) { + gamma = INPUT_VARIABLE(4); + dLdG = OUTPUT_VARIABLE(3); + } + if (applyOffset) { + beta = INPUT_VARIABLE(4 + (int)applyScale); + dLdB = OUTPUT_VARIABLE(3 + (int)applyScale); + } + + const int numOfIntArgs = block.numI(); + std::vector axes; + if (numOfIntArgs > 2) + for (int i = 2; i < numOfIntArgs; ++i) axes.push_back(INT_ARG(i)); + else + axes.push_back(input->rankOf() - + 1); // default dimension to reduce along is last dimension + + DataType inputType = input->dataType(); + DataType meanType = mean->dataType(); + DataType varType = variance->dataType(); + DataType dLdOType = dLdO->dataType(); + DataType gammaType = gamma != nullptr ? gamma->dataType() : DataType::FLOAT32; + DataType betaType = beta != nullptr ? beta->dataType() : DataType::FLOAT32; + + DataType dLdIType = dLdI->dataType(); + DataType dLdGType = gamma != nullptr ? dLdG->dataType() : DataType::FLOAT32; + DataType dLdBType = beta != nullptr ? dLdB->dataType() : DataType::FLOAT32; + + const int inRank = input->rankOf(); + + return block.isUseMKLDNN() && axes.size() == 1 && + (axes[0] == 1 || axes[0] == inRank - 1) && + (inRank == 2 || inRank == 4 || inRank == 5) && + (inputType == DataType::FLOAT32 && meanType == DataType::FLOAT32 && + varType == DataType::FLOAT32 && dLdOType == DataType::FLOAT32 && + gammaType == DataType::FLOAT32 && betaType == DataType::FLOAT32 && + dLdIType == DataType::FLOAT32 && dLdGType == DataType::FLOAT32 && + dLdBType == DataType::FLOAT32); } -} -} -} \ No newline at end of file +} // namespace platforms +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/concat.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/concat.cpp index 3bf97e586527..ea51ca54850a 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/concat.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/concat.cpp @@ -94,7 +94,7 @@ PLATFORM_IMPL(concat, ENGINE_CPU) { REQUIRE_TRUE(block.width() > 0, 0, "CONCAT MKLDNN op: No input arrays were provided"); - const bool isAxisInLastArr = block.getBArguments()->size() == 0 ? false : B_ARG(0); + const bool isAxisInLastArr = block.numB() == 0 ? false : B_ARG(0); const int numOfInArrs = isAxisInLastArr ? block.width() - 1 : block.width(); @@ -105,7 +105,7 @@ PLATFORM_IMPL(concat, ENGINE_CPU) { int index = 0; bool allOfSameType = true; auto rankOfFirstArr = block.width() > 0 ? INPUT_VARIABLE(0)->rankOf() : 0; - auto typeOfFirstArr = block.width() > 0 ? INPUT_VARIABLE(0)->dataType() : block.dataType(); + auto typeOfFirstArr = block.width() > 0 ? INPUT_VARIABLE(0)->dataType() : DataType::FLOAT32; for(int i = 0; i < numOfInArrs; ++i) { auto input = INPUT_VARIABLE(i); @@ -178,7 +178,7 @@ PLATFORM_CHECK(concat, ENGINE_CPU) { const auto zType = z->dataType(); - const bool isAxisInLastArr = block.getBArguments()->size() == 0 ? false : B_ARG(0); + const bool isAxisInLastArr = block.numB() == 0 ? false : B_ARG(0); const int numOfInArrs = isAxisInLastArr ? block.width() - 1 : block.width(); return z->rankOf() < 7 && numOfInArrs <= 3072 diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/conv2d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/conv2d.cpp index a889d030207b..1dcfa56865e1 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/conv2d.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/conv2d.cpp @@ -20,49 +20,59 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include +#include #include +#include +#include #include -#include #include "mkldnnUtils.h" -#include using namespace dnnl; -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace platforms { ////////////////////////////////////////////////////////////////////// static void conv2dMKLDNN(const NDArray *input, const NDArray *weights, - const NDArray *bias, NDArray *output, - const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, - const int paddingMode, const int isNCHW, const int wFormat) { - - // mkl support weights in [oC, iC, kH, kW] format only - - int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; - int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); - - const int pWSame = (paddingMode == 2 && dW > 1) ? ((oW - 1) * sW + (kW - 1) * dW + 1 - iW) / 2 : pW; // dH == 1 for causal mode in conv1d - - dnnl::memory::dims strides = { sH, sW }; - dnnl::memory::dims padding = { pH, pW }; - dnnl::memory::dims padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pWSame }; - dnnl::memory::dims dilation = { dH-1, dW-1}; - - auto xzFormatMkl = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; - dnnl::memory::format_tag wFormatMkl = dnnl::memory::format_tag::oihw; - - dnnl::memory::dims xDims = {bS, iC, iH, iW}; - dnnl::memory::dims wDims = {oC, iC, kH, kW}; - dnnl::memory::dims zDims = {bS, oC, oH, oW}; - - auto type = dnnl::memory::data_type::f32; - - std::vector permut; + const NDArray *bias, NDArray *output, const int kH, + const int kW, const int sH, const int sW, const int pH, + const int pW, const int dH, const int dW, + const int paddingMode, const int isNCHW, + const int wFormat) { + // mkl support weights in [oC, iC, kH, kW] format only + + int bS, iC, iH, iW, oC, oH, + oW; // batch size, input channels, input height/width, output channels, + // output height/width; + int indIOioC, indIiH, indWoC, indWiC, indWkH, + indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, + indIiH, indWiC, indWoC, indWkH, indOoH); + + const int pWSame = (paddingMode == 2 && dW > 1) + ? ((oW - 1) * sW + (kW - 1) * dW + 1 - iW) / 2 + : pW; // dH == 1 for causal mode in conv1d + + dnnl::memory::dims strides = {sH, sW}; + dnnl::memory::dims padding = {pH, pW}; + dnnl::memory::dims padding_r = {(oH - 1) * sH - iH + kH - pH, + (oW - 1) * sW - iW + kW - pWSame}; + dnnl::memory::dims dilation = {dH - 1, dW - 1}; + + auto xzFormatMkl = + isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; + dnnl::memory::format_tag wFormatMkl = dnnl::memory::format_tag::oihw; + + dnnl::memory::dims xDims = {bS, iC, iH, iW}; + dnnl::memory::dims wDims = {oC, iC, kH, kW}; + dnnl::memory::dims zDims = {bS, oC, oH, oW}; + + auto type = dnnl::memory::data_type::f32; + + std::vector permut; if(0 == wFormat) permut = {3,2,0,1}; // [kH, kW, iC, oC] -> [oC, iC, kH, kW] else if(2 == wFormat) @@ -70,94 +80,118 @@ static void conv2dMKLDNN(const NDArray *input, const NDArray *weights, // memory descriptors for arrays - // input - dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any); - dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFormatMkl); - mkldnnUtils::setBlockStrides(*input, x_user_md); - - // weights - dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any); - dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, type, wFormatMkl); - mkldnnUtils::setBlockStrides(*weights, w_user_md, permut); - - // bias - dnnl::memory::desc b_mkl_md; - if(bias != nullptr) - b_mkl_md = dnnl::memory::desc({oC}, type, dnnl::memory::format_tag::x); - - // output - dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any); - dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, type, xzFormatMkl); - mkldnnUtils::setBlockStrides(*output, z_user_md); - - auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - - // operation primitive description - dnnl::convolution_forward::desc op_desc(dnnl::prop_kind::forward_inference, dnnl::algorithm::convolution_auto, x_mkl_md, w_mkl_md, b_mkl_md, z_mkl_md, strides, dilation, padding, padding_r); - dnnl::convolution_forward::primitive_desc op_prim_desc(op_desc, engine); - - // arguments (memory buffers) necessary for calculations - std::unordered_map args; - - dnnl::stream stream(engine); - - // provide memory buffers and check whether reorder is required - - // input - mkldnnUtils::loadDataToMklStream(*input, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]); - - // weights - mkldnnUtils::loadDataToMklStream(*weights, engine, stream, w_user_md, op_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]); - - // bias - if(bias != nullptr) { - auto b_mkl_mem = dnnl::memory(b_mkl_md, engine, const_cast(bias->buffer())); - args[DNNL_ARG_BIAS] = b_mkl_mem; - } - - // output - auto z_user_mem = mkldnnUtils::loadDataToMklStream(*output, engine, stream, z_user_md, op_prim_desc.dst_desc(), args[DNNL_ARG_DST]); - - // run calculations - dnnl::convolution_forward(op_prim_desc).execute(stream, args); - - // reorder outputs if necessary - if (op_prim_desc.dst_desc() != z_user_mem.get_desc()) - dnnl::reorder(args[DNNL_ARG_DST], z_user_mem).execute(stream, args[DNNL_ARG_DST], z_user_mem); - - stream.wait(); - // shape::printArray(z_mkl_mem.map_data(),8); + // input + dnnl::memory::desc x_mkl_md = + dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any); + dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFormatMkl); + mkldnnUtils::setBlockStrides(*input, x_user_md); + + // weights + dnnl::memory::desc w_mkl_md = + dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any); + dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, type, wFormatMkl); + mkldnnUtils::setBlockStrides(*weights, + w_user_md, permut); + + // bias + dnnl::memory::desc b_mkl_md; + if (bias != nullptr) + b_mkl_md = dnnl::memory::desc({oC}, type, dnnl::memory::format_tag::x); + + // output + dnnl::memory::desc z_mkl_md = + dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any); + dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, type, xzFormatMkl); + mkldnnUtils::setBlockStrides(*output, z_user_md); + + auto engine = + mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); + + // operation primitive description + dnnl::convolution_forward::desc op_desc( + dnnl::prop_kind::forward_inference, dnnl::algorithm::convolution_auto, + x_mkl_md, w_mkl_md, b_mkl_md, z_mkl_md, strides, dilation, padding, + padding_r); + dnnl::convolution_forward::primitive_desc op_prim_desc(op_desc, engine); + + // arguments (memory buffers) necessary for calculations + std::unordered_map args; + + dnnl::stream stream(engine); + + // provide memory buffers and check whether reorder is required + + // input + mkldnnUtils::loadDataToMklStream(*input, engine, stream, x_user_md, + op_prim_desc.src_desc(), args[DNNL_ARG_SRC]); + + // weights + mkldnnUtils::loadDataToMklStream(*weights, engine, stream, w_user_md, + op_prim_desc.weights_desc(), + args[DNNL_ARG_WEIGHTS]); + + // bias + if (bias != nullptr) { + auto b_mkl_mem = + dnnl::memory(b_mkl_md, engine, const_cast(bias->buffer())); + args[DNNL_ARG_BIAS] = b_mkl_mem; + } + + // output + auto z_user_mem = mkldnnUtils::loadDataToMklStream(*output, engine, stream, z_user_md, op_prim_desc.dst_desc(), + args[DNNL_ARG_DST] ); + + // run calculations + dnnl::convolution_forward(op_prim_desc).execute(stream, args); + + // reorder outputs if necessary + if (op_prim_desc.dst_desc() != z_user_mem.get_desc()) + dnnl::reorder(args[DNNL_ARG_DST], z_user_mem).execute(stream, args[DNNL_ARG_DST], z_user_mem); + + stream.wait(); + // shape::printArray(z_mkl_mem.map_data(),8); } ////////////////////////////////////////////////////////////////////// -static void conv2dBpMKLDNN(const NDArray *input, const NDArray *weights, const NDArray *bias, const NDArray *gradO, - NDArray *gradI, NDArray *gradW, NDArray *gradB, - const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, - const int paddingMode, const int isNCHW, const int wFormat) { - - // mkl support weights/gradW in [oC, iC, kH, kW] format only - - int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; - int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); - - const int pWSame = (paddingMode == 2 && dW > 1) ? ((oW - 1) * sW + (kW - 1) * dW + 1 - iW) / 2 : pW; // dH == 1 for causal mode in conv1d - - dnnl::memory::dims strides = { sH, sW }; - dnnl::memory::dims padding = { pH, pW }; - dnnl::memory::dims padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pWSame }; - dnnl::memory::dims dilation = { dH-1, dW-1}; - - auto xzFormatMkl = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; - dnnl::memory::format_tag wFormatMkl = dnnl::memory::format_tag::oihw; - - dnnl::memory::dims xDims = {bS, iC, iH, iW}; - dnnl::memory::dims wDims = {oC, iC, kH, kW}; - dnnl::memory::dims zDims = {bS, oC, oH, oW}; - - auto type = dnnl::memory::data_type::f32; - - std::vector permut; +static void conv2dBpMKLDNN(const NDArray *input, const NDArray *weights, + const NDArray *bias, const NDArray *gradO, + NDArray *gradI, NDArray *gradW, NDArray *gradB, + const int kH, const int kW, const int sH, + const int sW, const int pH, const int pW, + const int dH, const int dW, const int paddingMode, + const int isNCHW, const int wFormat) { + // mkl support weights/gradW in [oC, iC, kH, kW] format only + + int bS, iC, iH, iW, oC, oH, + oW; // batch size, input channels, input height/width, output channels, + // output height/width; + int indIOioC, indIiH, indWoC, indWiC, indWkH, + indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, + indIiH, indWiC, indWoC, indWkH, indOoH); + + const int pWSame = (paddingMode == 2 && dW > 1) + ? ((oW - 1) * sW + (kW - 1) * dW + 1 - iW) / 2 + : pW; // dH == 1 for causal mode in conv1d + + dnnl::memory::dims strides = {sH, sW}; + dnnl::memory::dims padding = {pH, pW}; + dnnl::memory::dims padding_r = {(oH - 1) * sH - iH + kH - pH, + (oW - 1) * sW - iW + kW - pWSame}; + dnnl::memory::dims dilation = {dH - 1, dW - 1}; + + auto xzFormatMkl = + isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; + dnnl::memory::format_tag wFormatMkl = dnnl::memory::format_tag::oihw; + + dnnl::memory::dims xDims = {bS, iC, iH, iW}; + dnnl::memory::dims wDims = {oC, iC, kH, kW}; + dnnl::memory::dims zDims = {bS, oC, oH, oW}; + + auto type = dnnl::memory::data_type::f32; + + std::vector permut; if(0 == wFormat) permut = {3,2,0,1}; // [kH, kW, iC, oC] -> [oC, iC, kH, kW] else if(2 == wFormat) @@ -165,162 +199,200 @@ static void conv2dBpMKLDNN(const NDArray *input, const NDArray *weights, const N // memory descriptors for arrays - // input - dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any); - dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFormatMkl); - mkldnnUtils::setBlockStrides(*input, x_user_md); - - // weights - dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any); - dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, type, wFormatMkl); - mkldnnUtils::setBlockStrides(*weights, w_user_md, permut); - - // gradO - dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any); - dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, type, xzFormatMkl); - mkldnnUtils::setBlockStrides(*gradO, gradO_user_md); - - // gradI - dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any); - dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, type, xzFormatMkl); - mkldnnUtils::setBlockStrides(*gradI, gradI_user_md); - - // gradW - dnnl::memory::desc gradW_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any); - dnnl::memory::desc gradW_user_md = dnnl::memory::desc(wDims, type, wFormatMkl); - mkldnnUtils::setBlockStrides(*gradW, gradW_user_md, permut); - - // gradB - dnnl::memory::desc gradB_mkl_md; - if(gradB != nullptr) - gradB_mkl_md = dnnl::memory::desc({oC}, type, dnnl::memory::format_tag::x); - - auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - - // forward primitive description - dnnl::convolution_forward::desc op_ff_desc(dnnl::prop_kind::forward_inference, dnnl::algorithm::convolution_auto, x_mkl_md, w_mkl_md, gradB_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r); - dnnl::convolution_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine); - - // backward data primitive description - dnnl::convolution_backward_data::desc op_data_bp_desc(dnnl::algorithm::convolution_auto, gradI_mkl_md, w_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r); - dnnl::convolution_backward_data::primitive_desc op_data_bp_prim_desc(op_data_bp_desc, engine, op_ff_prim_desc); - - // backward weights primitive description - dnnl::convolution_backward_weights::desc op_weights_bp_desc(dnnl::algorithm::convolution_auto, x_mkl_md, gradW_mkl_md, gradB_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r); - dnnl::convolution_backward_weights::primitive_desc op_weights_bp_prim_desc(op_weights_bp_desc, engine, op_ff_prim_desc); - - // arguments (memory buffers) necessary for calculations - std::unordered_map args; - - dnnl::stream stream(engine); - - // provide memory buffers and check whether reorder is required - - // input - mkldnnUtils::loadDataToMklStream(*input, engine, stream, x_user_md, op_weights_bp_prim_desc.src_desc(), args[DNNL_ARG_SRC]); - - // weights - mkldnnUtils::loadDataToMklStream(*weights, engine, stream, w_user_md, op_data_bp_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]); - - // gradO - auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, const_cast(gradO->buffer())); - const bool gradOReorderW = op_weights_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc(); - const bool gradOReorderD = op_data_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc(); - auto gradO_mkl_memW = gradOReorderW ? dnnl::memory(op_weights_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem; - auto gradO_mkl_memD = gradOReorderD ? dnnl::memory(op_data_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem; - if (gradOReorderW) - dnnl::reorder(gradO_user_mem, gradO_mkl_memW).execute(stream, gradO_user_mem, gradO_mkl_memW); - if (gradOReorderD) - dnnl::reorder(gradO_user_mem, gradO_mkl_memD).execute(stream, gradO_user_mem, gradO_mkl_memD); - args[DNNL_ARG_DIFF_DST] = gradO_mkl_memD; - - // gradI - auto gradI_user_mem = mkldnnUtils::loadDataToMklStream(*gradI, engine, stream, gradI_user_md, op_data_bp_prim_desc.diff_src_desc(), args[DNNL_ARG_DIFF_SRC]); - - // gradW - auto gradW_user_mem = mkldnnUtils::loadDataToMklStream(*gradW, engine, stream, gradW_user_md, op_weights_bp_prim_desc.diff_weights_desc(), args[DNNL_ARG_DIFF_WEIGHTS]); - - // gradB - if(gradB != nullptr) { - auto gradB_mkl_mem = dnnl::memory(gradB_mkl_md, engine, gradB->buffer()); - args[DNNL_ARG_DIFF_BIAS] = gradB_mkl_mem; - } - - // run backward data calculations - dnnl::convolution_backward_data(op_data_bp_prim_desc).execute(stream, args); - - if(gradOReorderW || gradOReorderD) - args[DNNL_ARG_DIFF_DST] = gradO_mkl_memW; - - // run backward weights calculations - dnnl::convolution_backward_weights(op_weights_bp_prim_desc).execute(stream, args); - - // reorder gradI if necessary - if (op_data_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc()) - dnnl::reorder(args[DNNL_ARG_DIFF_SRC], gradI_user_mem).execute(stream, args[DNNL_ARG_DIFF_SRC], gradI_user_mem); - if (op_weights_bp_prim_desc.diff_weights_desc() != gradW_user_mem.get_desc()) - dnnl::reorder(args[DNNL_ARG_DIFF_WEIGHTS], gradW_user_mem).execute(stream, args[DNNL_ARG_DIFF_WEIGHTS], gradW_user_mem); - - stream.wait(); - - // shape::printArray(z_mkl_mem.map_data(),8); + // input + dnnl::memory::desc x_mkl_md = + dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any); + dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFormatMkl); + mkldnnUtils::setBlockStrides(*input, x_user_md); + + // weights + dnnl::memory::desc w_mkl_md = + dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any); + dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, type, wFormatMkl); + mkldnnUtils::setBlockStrides(*weights, + w_user_md, permut); + + // gradO + dnnl::memory::desc gradO_mkl_md = + dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any); + dnnl::memory::desc gradO_user_md = + dnnl::memory::desc(zDims, type, xzFormatMkl); + mkldnnUtils::setBlockStrides(*gradO, gradO_user_md); + + // gradI + dnnl::memory::desc gradI_mkl_md = + dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any); + dnnl::memory::desc gradI_user_md = + dnnl::memory::desc(xDims, type, xzFormatMkl); + mkldnnUtils::setBlockStrides(*gradI, gradI_user_md); + + // gradW + dnnl::memory::desc gradW_mkl_md = + dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any); + dnnl::memory::desc gradW_user_md = + dnnl::memory::desc(wDims, type, wFormatMkl); + mkldnnUtils::setBlockStrides(*gradW, + gradW_user_md, permut); + + // gradB + dnnl::memory::desc gradB_mkl_md; + if (gradB != nullptr) + gradB_mkl_md = dnnl::memory::desc({oC}, type, dnnl::memory::format_tag::x); + + auto engine = + mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); + + // forward primitive description + dnnl::convolution_forward::desc op_ff_desc( + dnnl::prop_kind::forward_inference, dnnl::algorithm::convolution_auto, + x_mkl_md, w_mkl_md, gradB_mkl_md, gradO_mkl_md, strides, dilation, + padding, padding_r); + dnnl::convolution_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine); + + // backward data primitive description + dnnl::convolution_backward_data::desc op_data_bp_desc( + dnnl::algorithm::convolution_auto, gradI_mkl_md, w_mkl_md, gradO_mkl_md, + strides, dilation, padding, padding_r); + dnnl::convolution_backward_data::primitive_desc op_data_bp_prim_desc( + op_data_bp_desc, engine, op_ff_prim_desc); + + // backward weights primitive description + dnnl::convolution_backward_weights::desc op_weights_bp_desc( + dnnl::algorithm::convolution_auto, x_mkl_md, gradW_mkl_md, gradB_mkl_md, + gradO_mkl_md, strides, dilation, padding, padding_r); + dnnl::convolution_backward_weights::primitive_desc op_weights_bp_prim_desc( + op_weights_bp_desc, engine, op_ff_prim_desc); + + // arguments (memory buffers) necessary for calculations + std::unordered_map args; + + dnnl::stream stream(engine); + + // provide memory buffers and check whether reorder is required + + // input + mkldnnUtils::loadDataToMklStream(*input, engine, stream, x_user_md, + op_weights_bp_prim_desc.src_desc(), + args[DNNL_ARG_SRC]); + + // weights + mkldnnUtils::loadDataToMklStream(*weights, engine, stream, w_user_md, + op_data_bp_prim_desc.weights_desc(), + args[DNNL_ARG_WEIGHTS]); + + // gradO + auto gradO_user_mem = + dnnl::memory(gradO_user_md, engine, const_cast(gradO->buffer())); + const bool gradOReorderW = + op_weights_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc(); + const bool gradOReorderD = + op_data_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc(); + auto gradO_mkl_memW = + gradOReorderW + ? dnnl::memory(op_weights_bp_prim_desc.diff_dst_desc(), engine) + : gradO_user_mem; + auto gradO_mkl_memD = + gradOReorderD ? dnnl::memory(op_data_bp_prim_desc.diff_dst_desc(), engine) + : gradO_user_mem; + if (gradOReorderW) + dnnl::reorder(gradO_user_mem, gradO_mkl_memW) + .execute(stream, gradO_user_mem, gradO_mkl_memW); + if (gradOReorderD) + dnnl::reorder(gradO_user_mem, gradO_mkl_memD) + .execute(stream, gradO_user_mem, gradO_mkl_memD); + args[DNNL_ARG_DIFF_DST] = gradO_mkl_memD; + + // gradI + auto gradI_user_mem = mkldnnUtils::loadDataToMklStream(*gradI, engine, stream, gradI_user_md, + op_data_bp_prim_desc.diff_src_desc() , + args[DNNL_ARG_DIFF_SRC] ); + + // gradW + auto gradW_user_mem = mkldnnUtils::loadDataToMklStream(*gradW, engine, stream, gradW_user_md, + op_weights_bp_prim_desc.diff_weights_desc() , + args[DNNL_ARG_DIFF_WEIGHTS] ); + + // gradB + if (gradB != nullptr) { + auto gradB_mkl_mem = dnnl::memory(gradB_mkl_md, engine, gradB->buffer()); + args[DNNL_ARG_DIFF_BIAS] = gradB_mkl_mem; + } + + // run backward data calculations + dnnl::convolution_backward_data(op_data_bp_prim_desc).execute(stream, args); + + if (gradOReorderW || gradOReorderD) args[DNNL_ARG_DIFF_DST] = gradO_mkl_memW; + + // run backward weights calculations + dnnl::convolution_backward_weights(op_weights_bp_prim_desc) + .execute(stream, args); + + // reorder gradI if necessary + if (op_data_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc()) + dnnl::reorder(args[DNNL_ARG_DIFF_SRC], gradI_user_mem) + .execute(stream, args[DNNL_ARG_DIFF_SRC], gradI_user_mem); + if (op_weights_bp_prim_desc.diff_weights_desc() != gradW_user_mem.get_desc()) + dnnl::reorder(args[DNNL_ARG_DIFF_WEIGHTS], gradW_user_mem) + .execute(stream, args[DNNL_ARG_DIFF_WEIGHTS], gradW_user_mem); + + stream.wait(); + + // shape::printArray(z_mkl_mem.map_data(),8); } /* ////////////////////////////////////////////////////////////////////// -static void conv2dMKLDNN(sd::graph::Context &block, const NDArray *input, const NDArray *weights, - const NDArray *bias, NDArray *output, const int kH, const int kW, const int sH, - const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, - const int isNCHW) { +static void conv2dMKLDNN(sd::graph::Context &block, const NDArray *input, const +NDArray *weights, const NDArray *bias, NDArray *output, const int kH, const int +kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, +const int paddingMode, const int isNCHW) { - int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; - int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); + int bS, iC, iH, iW, oC, oH, oW; // batch size, +input channels, input height/width, output channels, output height/width; int +indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, +iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); - ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode); + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, +dW, paddingMode); dnnl_memory_desc_t empty; - dnnl::memory::desc x_mkl_md(empty), w_mkl_md(empty), b_mkl_md(empty), z_mkl_md(empty); - dnnl::memory::desc x_user_md(empty), w_user_md(empty), b_user_md(empty), z_user_md(empty); + dnnl::memory::desc x_mkl_md(empty), w_mkl_md(empty), b_mkl_md(empty), +z_mkl_md(empty); dnnl::memory::desc x_user_md(empty), w_user_md(empty), +b_user_md(empty), z_user_md(empty); dnnl::memory::dims strides, padding, padding_r, dilation; - mkldnnUtils::getMKLDNNMemoryDescConv2d(kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, - bS, iC, iH, iW, oC, oH, oW, input, nullptr, weights, nullptr, - bias, output, - &x_mkl_md, nullptr, &w_mkl_md, nullptr, - &b_mkl_md, &z_mkl_md, - &x_user_md, nullptr, &w_user_md, nullptr, - &b_user_md, &z_user_md, - strides, padding, padding_r, dilation); - - auto conv_desc = bias != nullptr ? convolution_forward::desc(prop_kind::forward, - algorithm::convolution_auto, x_mkl_md, - w_mkl_md, b_mkl_md, - z_mkl_md, strides, dilation, padding, - padding_r) - : convolution_forward::desc(prop_kind::forward, - algorithm::convolution_auto, x_mkl_md, - w_mkl_md, - z_mkl_md, strides, dilation, padding, - padding_r); - auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); + mkldnnUtils::getMKLDNNMemoryDescConv2d(kH, kW, sH, sW, pH, pW, dH, dW, +paddingMode, isNCHW, bS, iC, iH, iW, oC, oH, oW, input, nullptr, weights, +nullptr, bias, output, &x_mkl_md, nullptr, &w_mkl_md, nullptr, &b_mkl_md, +&z_mkl_md, &x_user_md, nullptr, &w_user_md, nullptr, &b_user_md, &z_user_md, + strides, padding, padding_r, +dilation); + + auto conv_desc = bias != nullptr ? +convolution_forward::desc(prop_kind::forward, algorithm::convolution_auto, +x_mkl_md, w_mkl_md, b_mkl_md, z_mkl_md, strides, dilation, padding, padding_r) + : +convolution_forward::desc(prop_kind::forward, algorithm::convolution_auto, +x_mkl_md, w_mkl_md, z_mkl_md, strides, dilation, padding, padding_r); auto +engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); dnnl::stream stream(engine); - auto conv_prim_desc = convolution_forward::primitive_desc(conv_desc, engine); - auto user_src_memory = dnnl::memory(x_user_md, engine, const_cast(input)->buffer()); - auto user_weights_memory = dnnl::memory(w_user_md, engine, - const_cast(weights)->buffer()); - auto user_dst_memory = dnnl::memory(z_user_md, engine, output->buffer()); - auto conv_src_memory = user_src_memory; - if (conv_prim_desc.src_desc() != user_src_memory.get_desc()) { - conv_src_memory = dnnl::memory(conv_prim_desc.src_desc(), engine); - reorder(user_src_memory, conv_src_memory).execute(stream, user_src_memory, conv_src_memory); + auto conv_prim_desc = convolution_forward::primitive_desc(conv_desc, +engine); auto user_src_memory = dnnl::memory(x_user_md, engine, +const_cast(input)->buffer()); auto user_weights_memory = +dnnl::memory(w_user_md, engine, const_cast(weights)->buffer()); auto +user_dst_memory = dnnl::memory(z_user_md, engine, output->buffer()); auto +conv_src_memory = user_src_memory; if (conv_prim_desc.src_desc() != +user_src_memory.get_desc()) { conv_src_memory = +dnnl::memory(conv_prim_desc.src_desc(), engine); reorder(user_src_memory, +conv_src_memory).execute(stream, user_src_memory, conv_src_memory); } auto conv_weights_memory = user_weights_memory; if (conv_prim_desc.weights_desc() != user_weights_memory.get_desc()) { - conv_weights_memory = dnnl::memory(conv_prim_desc.weights_desc(), engine); - reorder(user_weights_memory, conv_weights_memory).execute(stream, user_weights_memory, - conv_weights_memory); + conv_weights_memory = dnnl::memory(conv_prim_desc.weights_desc(), +engine); reorder(user_weights_memory, conv_weights_memory).execute(stream, +user_weights_memory, conv_weights_memory); } auto conv_dst_memory = user_dst_memory; if (conv_prim_desc.dst_desc() != user_dst_memory.get_desc()) { @@ -328,105 +400,133 @@ static void conv2dMKLDNN(sd::graph::Context &block, const NDArray *input, const } if (bias != nullptr) { auto conv_bias_memory = dnnl::memory(conv_prim_desc.bias_desc(), engine, - const_cast(bias)->buffer()); - convolution_forward(conv_prim_desc).execute(stream, {{DNNL_ARG_SRC, conv_src_memory}, - {DNNL_ARG_WEIGHTS, conv_weights_memory}, - {DNNL_ARG_BIAS, conv_bias_memory}, - {DNNL_ARG_DST, conv_dst_memory}}); - } else { - convolution_forward(conv_prim_desc).execute(stream, {{DNNL_ARG_SRC, conv_src_memory}, - {DNNL_ARG_WEIGHTS, conv_weights_memory}, - {DNNL_ARG_DST, conv_dst_memory}}); + const_cast(bias)->buffer()); convolution_forward(conv_prim_desc).execute(stream, +{{DNNL_ARG_SRC, conv_src_memory}, {DNNL_ARG_WEIGHTS, conv_weights_memory}, + {DNNL_ARG_BIAS, +conv_bias_memory}, {DNNL_ARG_DST, conv_dst_memory}}); } else { + convolution_forward(conv_prim_desc).execute(stream, {{DNNL_ARG_SRC, +conv_src_memory}, {DNNL_ARG_WEIGHTS, conv_weights_memory}, {DNNL_ARG_DST, +conv_dst_memory}}); } if (conv_prim_desc.dst_desc() != user_dst_memory.get_desc()) { - reorder(conv_dst_memory, user_dst_memory).execute(stream, conv_dst_memory, user_dst_memory); + reorder(conv_dst_memory, user_dst_memory).execute(stream, +conv_dst_memory, user_dst_memory); } stream.wait(); } ////////////////////////////////////////////////////////////////////// static void conv2dBpMKLDNN(sd::graph::Context &block, - const NDArray *input, const NDArray *weights, const NDArray *bias, const NDArray *gradO, - NDArray *gradI, NDArray *gradW, NDArray *gradB, - const int kH, const int kW, const int sH,const int sW, int pH, int pW, const int dH, const int dW, - const int paddingMode, const int isNCHW) { + const NDArray *input, const NDArray *weights, const +NDArray *bias, const NDArray *gradO, NDArray *gradI, NDArray *gradW, NDArray +*gradB, const int kH, const int kW, const int sH,const int sW, int pH, int pW, +const int dH, const int dW, const int paddingMode, const int isNCHW) { - int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; - int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); + int bS, iC, iH, iW, oC, oH, oW; // batch size, +input channels, input height/width, output channels, output height/width; int +indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, +iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); - ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode); + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, +dW, paddingMode); dnnl_memory_desc_t empty; - dnnl::memory::desc conv_src_md(empty), conv_diff_src_md(empty), conv_weights_md(empty), conv_diff_weights_md(empty), conv_bias_md(empty), conv_dst_md(empty); - dnnl::memory::desc user_src_md(empty), user_diff_src_md(empty), user_weights_md(empty), user_diff_weights_md(empty), user_bias_md(empty), user_dst_md(empty); + dnnl::memory::desc conv_src_md(empty), conv_diff_src_md(empty), +conv_weights_md(empty), conv_diff_weights_md(empty), conv_bias_md(empty), +conv_dst_md(empty); dnnl::memory::desc user_src_md(empty), +user_diff_src_md(empty), user_weights_md(empty), user_diff_weights_md(empty), +user_bias_md(empty), user_dst_md(empty); - dnnl::memory::dims conv_strides, conv_padding, conv_padding_r, conv_dilation; + dnnl::memory::dims conv_strides, conv_padding, conv_padding_r, +conv_dilation; - mkldnnUtils::getMKLDNNMemoryDescConv2d(kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, - bS, iC, iH, iW, oC, oH, oW, input, gradI, weights, gradW, + mkldnnUtils::getMKLDNNMemoryDescConv2d(kH, kW, sH, sW, pH, pW, dH, dW, +paddingMode, isNCHW, bS, iC, iH, iW, oC, oH, oW, input, gradI, weights, gradW, gradB, gradO, - &conv_src_md, &conv_diff_src_md, &conv_weights_md, - &conv_diff_weights_md, &conv_bias_md, &conv_dst_md, - &user_src_md, &user_diff_src_md, &user_weights_md, - &user_diff_weights_md, &user_bias_md, &user_dst_md, - conv_strides, conv_padding, conv_padding_r, conv_dilation); - auto conv_desc = gradB != nullptr - ? convolution_forward::desc(prop_kind::forward, algorithm::convolution_auto, conv_src_md, conv_weights_md, conv_bias_md, conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r) - : convolution_forward::desc(prop_kind::forward, algorithm::convolution_auto, conv_src_md, conv_weights_md, conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r); - - auto conv_prim_desc = convolution_forward::primitive_desc(conv_desc, mkldnnUtils::getEngine( LaunchContext::defaultContext()->engine())); - - auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - dnnl::stream stream(engine); + &conv_src_md, &conv_diff_src_md, +&conv_weights_md, &conv_diff_weights_md, &conv_bias_md, &conv_dst_md, + &user_src_md, &user_diff_src_md, +&user_weights_md, &user_diff_weights_md, &user_bias_md, &user_dst_md, + conv_strides, conv_padding, +conv_padding_r, conv_dilation); auto conv_desc = gradB != nullptr ? +convolution_forward::desc(prop_kind::forward, algorithm::convolution_auto, +conv_src_md, conv_weights_md, conv_bias_md, conv_dst_md, conv_strides, +conv_dilation, conv_padding, conv_padding_r) : +convolution_forward::desc(prop_kind::forward, algorithm::convolution_auto, +conv_src_md, conv_weights_md, conv_dst_md, conv_strides, conv_dilation, +conv_padding, conv_padding_r); + + auto conv_prim_desc = convolution_forward::primitive_desc(conv_desc, +mkldnnUtils::getEngine( LaunchContext::defaultContext()->engine())); + + auto engine = +mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); dnnl::stream +stream(engine); if (gradW != nullptr) { - auto convW_desc = gradB != nullptr ? convolution_backward_weights::desc(algorithm::convolution_auto, conv_src_md, conv_diff_weights_md, conv_bias_md, conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r) - : convolution_backward_weights::desc(algorithm::convolution_auto, conv_src_md, conv_diff_weights_md, conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r); + auto convW_desc = gradB != nullptr ? +convolution_backward_weights::desc(algorithm::convolution_auto, conv_src_md, +conv_diff_weights_md, conv_bias_md, conv_dst_md, conv_strides, conv_dilation, +conv_padding, conv_padding_r) : +convolution_backward_weights::desc(algorithm::convolution_auto, conv_src_md, +conv_diff_weights_md, conv_dst_md, conv_strides, conv_dilation, conv_padding, +conv_padding_r); - auto convW_prim_desc = convolution_backward_weights::primitive_desc(convW_desc, engine, conv_prim_desc); + auto convW_prim_desc = +convolution_backward_weights::primitive_desc(convW_desc, engine, +conv_prim_desc); - auto userW_src_memory = dnnl::memory(user_src_md, engine, const_cast(input)->buffer()); - auto userW_weights_memory = dnnl::memory(user_diff_weights_md, engine, gradW->buffer()); - auto userW_dst_memory = dnnl::memory(user_dst_md, engine,const_cast(gradO)->buffer()); + auto userW_src_memory = dnnl::memory(user_src_md, engine, +const_cast(input)->buffer()); auto userW_weights_memory = +dnnl::memory(user_diff_weights_md, engine, gradW->buffer()); auto +userW_dst_memory = dnnl::memory(user_dst_md, engine,const_cast(gradO)->buffer()); auto convW_src_memory = userW_src_memory; if (convW_prim_desc.src_desc() != userW_src_memory.get_desc()) { convW_src_memory = dnnl::memory(convW_prim_desc.src_desc(), engine); - reorder(userW_src_memory, convW_src_memory).execute(stream, userW_src_memory,convW_src_memory); + reorder(userW_src_memory, convW_src_memory).execute(stream, +userW_src_memory,convW_src_memory); } auto convW_weights_memory = userW_weights_memory; - if (convW_prim_desc.diff_weights_desc() != userW_weights_memory.get_desc()) { - convW_weights_memory = dnnl::memory(convW_prim_desc.diff_weights_desc(), engine); + if (convW_prim_desc.diff_weights_desc() != +userW_weights_memory.get_desc()) { convW_weights_memory = +dnnl::memory(convW_prim_desc.diff_weights_desc(), engine); } auto convW_dst_memory = userW_dst_memory; if (convW_prim_desc.diff_dst_desc() != userW_dst_memory.get_desc()) { - convW_dst_memory = dnnl::memory(convW_prim_desc.diff_dst_desc(), engine); - reorder(userW_dst_memory, convW_dst_memory).execute(stream, userW_dst_memory, convW_dst_memory); + convW_dst_memory = dnnl::memory(convW_prim_desc.diff_dst_desc(), +engine); reorder(userW_dst_memory, convW_dst_memory).execute(stream, +userW_dst_memory, convW_dst_memory); } if (gradB != nullptr) { - auto convW_bias_memory = dnnl::memory(convW_prim_desc.diff_bias_desc(), engine, gradB->buffer()); + auto convW_bias_memory = +dnnl::memory(convW_prim_desc.diff_bias_desc(), engine, gradB->buffer()); convolution_backward_weights(convW_prim_desc).execute(stream, - {{DNNL_ARG_SRC, convW_src_memory}, - {DNNL_ARG_DIFF_DST, convW_dst_memory}, - {DNNL_ARG_DIFF_WEIGHTS, convW_weights_memory}, - {DNNL_ARG_DIFF_BIAS, convW_bias_memory}}); + {{DNNL_ARG_SRC, +convW_src_memory}, {DNNL_ARG_DIFF_DST, convW_dst_memory}, + {DNNL_ARG_DIFF_WEIGHTS, +convW_weights_memory}, {DNNL_ARG_DIFF_BIAS, convW_bias_memory}}); } else { convolution_backward_weights(convW_prim_desc).execute(stream, - {{DNNL_ARG_SRC, convW_src_memory}, - {DNNL_ARG_DIFF_DST, convW_dst_memory}, - {DNNL_ARG_DIFF_WEIGHTS, convW_weights_memory}}); + {{DNNL_ARG_SRC, +convW_src_memory}, {DNNL_ARG_DIFF_DST, convW_dst_memory}, + {DNNL_ARG_DIFF_WEIGHTS, +convW_weights_memory}}); } - if (convW_prim_desc.diff_weights_desc() != userW_weights_memory.get_desc()) { - reorder(convW_weights_memory, userW_weights_memory).execute(stream, convW_weights_memory, + if (convW_prim_desc.diff_weights_desc() != +userW_weights_memory.get_desc()) { reorder(convW_weights_memory, +userW_weights_memory).execute(stream, convW_weights_memory, userW_weights_memory); } @@ -435,38 +535,48 @@ static void conv2dBpMKLDNN(sd::graph::Context &block, if (gradI != nullptr) { - auto convI_desc = convolution_backward_data::desc(algorithm::convolution_auto, conv_diff_src_md, conv_weights_md, conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r); + auto convI_desc = +convolution_backward_data::desc(algorithm::convolution_auto, conv_diff_src_md, +conv_weights_md, conv_dst_md, conv_strides, conv_dilation, conv_padding, +conv_padding_r); - auto convI_prim_desc = convolution_backward_data::primitive_desc(convI_desc, engine, conv_prim_desc); - auto userI_src_memory = dnnl::memory(user_diff_src_md, engine, gradI->buffer()); - auto userI_weights_memory = dnnl::memory(user_weights_md, engine,const_cast(weights)->buffer()); - auto userI_dst_memory = dnnl::memory(user_dst_md, engine, const_cast(gradO)->buffer()); + auto convI_prim_desc = +convolution_backward_data::primitive_desc(convI_desc, engine, conv_prim_desc); + auto userI_src_memory = dnnl::memory(user_diff_src_md, engine, +gradI->buffer()); auto userI_weights_memory = dnnl::memory(user_weights_md, +engine,const_cast(weights)->buffer()); auto userI_dst_memory = +dnnl::memory(user_dst_md, engine, const_cast(gradO)->buffer()); auto convI_src_memory = userI_src_memory; if (convI_prim_desc.diff_src_desc() != userI_src_memory.get_desc()) { - convI_src_memory = dnnl::memory(convI_prim_desc.diff_src_desc(), engine); + convI_src_memory = dnnl::memory(convI_prim_desc.diff_src_desc(), +engine); } auto convI_weights_memory = userI_weights_memory; if (convI_prim_desc.weights_desc() != userI_weights_memory.get_desc()) { - convI_weights_memory = dnnl::memory(convI_prim_desc.weights_desc(), engine); - reorder(userI_weights_memory, convI_weights_memory).execute(stream, userI_weights_memory, convI_weights_memory); + convI_weights_memory = dnnl::memory(convI_prim_desc.weights_desc(), +engine); reorder(userI_weights_memory, convI_weights_memory).execute(stream, +userI_weights_memory, convI_weights_memory); } auto convI_dst_memory = userI_dst_memory; if (convI_prim_desc.diff_dst_desc() != userI_dst_memory.get_desc()) { - convI_dst_memory = dnnl::memory(convI_prim_desc.diff_dst_desc(), engine); - reorder(userI_dst_memory, convI_dst_memory).execute(stream, userI_dst_memory, convI_dst_memory); + convI_dst_memory = dnnl::memory(convI_prim_desc.diff_dst_desc(), +engine); reorder(userI_dst_memory, convI_dst_memory).execute(stream, +userI_dst_memory, convI_dst_memory); } convolution_backward_data(convI_prim_desc).execute(stream, - {{DNNL_ARG_DIFF_DST, convI_dst_memory}, - {DNNL_ARG_WEIGHTS, convI_weights_memory}, - {DNNL_ARG_DIFF_SRC, convI_src_memory}}); + {{DNNL_ARG_DIFF_DST, +convI_dst_memory}, {DNNL_ARG_WEIGHTS, convI_weights_memory}, + {DNNL_ARG_DIFF_SRC, +convI_src_memory}}); if (convI_prim_desc.diff_src_desc() != userI_src_memory.get_desc()) { - reorder(convI_src_memory, userI_src_memory).execute(stream, convI_src_memory, userI_src_memory); + reorder(convI_src_memory, userI_src_memory).execute(stream, +convI_src_memory, userI_src_memory); } stream.wait(); @@ -477,116 +587,172 @@ static void conv2dBpMKLDNN(sd::graph::Context &block, ////////////////////////////////////////////////////////////////////// PLATFORM_IMPL(conv2d, ENGINE_CPU) { - - auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] - auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] - - auto output = OUTPUT_VARIABLE(0); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) - - int sH = INT_ARG(2); // strides height - int sW = INT_ARG(3); // strides width - int pH = INT_ARG(4); // paddings height - int pW = INT_ARG(5); // paddings width - int dH = INT_ARG(6); // dilations height - int dW = INT_ARG(7); // dilations width - int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME - bool isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC - int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - [oC, kH, kW, iC] - - int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(weights->sizeAt(0)); // filter(kernel) height - int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast(weights->sizeAt(1)); // filter(kernel) width - - int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; - int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); - - ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode); - - std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, oC); - REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CONV2D MKLDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); - if (bias) - REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CONV2D MKLDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); - - conv2dMKLDNN(input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat); - - return Status::OK(); + auto input = + INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + auto weights = INPUT_VARIABLE( + 1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] + auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] + + auto output = + OUTPUT_VARIABLE(0); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) + + int sH = INT_ARG(2); // strides height + int sW = INT_ARG(3); // strides width + int pH = INT_ARG(4); // paddings height + int pW = INT_ARG(5); // paddings width + int dH = INT_ARG(6); // dilations height + int dW = INT_ARG(7); // dilations width + int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME + bool isNCHW = + block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC + int wFormat = + block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, + // iC, kH, kW], 2 - [oC, kH, kW, iC] + + int kH = INT_ARG(0) > 0 + ? INT_ARG(0) + : static_cast(weights->sizeAt(0)); // filter(kernel) height + int kW = INT_ARG(1) > 0 + ? INT_ARG(1) + : static_cast(weights->sizeAt(1)); // filter(kernel) width + + int bS, iC, iH, iW, oC, oH, + oW; // batch size, input channels, input height/width, output channels, + // output height/width; + int indIOioC, indIiH, indWoC, indWiC, indWkH, + indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, + indIiH, indWiC, indWoC, indWkH, indOoH); + + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, + dW, paddingMode); + + std::vector expectedWeightsShape = + ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, oC); + REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, + "CONV2D MKLDNN OP: wrong shape of weights array, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), + ShapeUtils::shapeAsString(weights).c_str()); + if (bias) + REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, + "CONV2D MKLDNN OP: wrong shape of array with biases, expected " + "rank, length: <=2, %i, but got %i, %i instead !", + oC, bias->rankOf(), bias->lengthOf()); + + conv2dMKLDNN(input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, + paddingMode, isNCHW, wFormat); + + return Status::OK(); } - PLATFORM_CHECK(conv2d, ENGINE_CPU) { - auto input = INPUT_VARIABLE(0); - auto weights = INPUT_VARIABLE(1); + auto input = INPUT_VARIABLE(0); + auto weights = INPUT_VARIABLE(1); - // conv2d is only available for float32 dtype - return block.isUseMKLDNN() && input->dataType() == sd::DataType::FLOAT32 && - weights->dataType() == sd::DataType::FLOAT32; + // conv2d is only available for float32 dtype + return block.isUseMKLDNN() && input->dataType() == sd::DataType::FLOAT32 && + weights->dataType() == sd::DataType::FLOAT32; } ////////////////////////////////////////////////////////////////////// PLATFORM_IMPL(conv2d_bp, ENGINE_CPU) { - - auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] - auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] - auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next - - auto gradI = OUTPUT_NULLIFIED(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon - auto gradW = OUTPUT_NULLIFIED(1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] - auto gradB = block.width() > 3 ? OUTPUT_NULLIFIED(2) : nullptr; // [oC] - - int kH = INT_ARG(0); // filter(kernel) height - int kW = INT_ARG(1); // filter(kernel) width - int sH = INT_ARG(2); // strides height - int sW = INT_ARG(3); // strides width - int pH = INT_ARG(4); // paddings height - int pW = INT_ARG(5); // paddings width - int dH = INT_ARG(6); // dilations height - int dW = INT_ARG(7); // dilations width - int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME - int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC - int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - [oC, kH, kW, iC] - - int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; - int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); - - int trueoH, trueoW; // true output height, width - ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, paddingMode); - - if(paddingMode) // SAME - ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode); - - std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1}); - std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, oC); - REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CONV2D_BP MKLDNN OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); - REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CONV2D_BP MKLDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); - if(bias) - REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CONV2D_BP MKLDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); - - conv2dBpMKLDNN(input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat); - - return Status::OK(); + auto input = + INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + auto weights = INPUT_VARIABLE( + 1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] + auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] + auto gradO = block.width() > 3 + ? INPUT_VARIABLE(3) + : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, + // oH, oW] (NCHW), epsilon_next + + auto gradI = OUTPUT_NULLIFIED( + 0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon + auto gradW = OUTPUT_NULLIFIED( + 1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] + auto gradB = block.width() > 3 ? OUTPUT_NULLIFIED(2) : nullptr; // [oC] + + int kH = INT_ARG(0); // filter(kernel) height + int kW = INT_ARG(1); // filter(kernel) width + int sH = INT_ARG(2); // strides height + int sW = INT_ARG(3); // strides width + int pH = INT_ARG(4); // paddings height + int pW = INT_ARG(5); // paddings width + int dH = INT_ARG(6); // dilations height + int dW = INT_ARG(7); // dilations width + int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME + int isNCHW = + block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC + int wFormat = + block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, + // iC, kH, kW], 2 - [oC, kH, kW, iC] + + int bS, iC, iH, iW, oC, oH, + oW; // batch size, input channels, input height/width, output channels, + // output height/width; + int indIOioC, indIiH, indWoC, indWiC, indWkH, + indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, + indIiH, indWiC, indWoC, indWkH, indOoH); + + int trueoH, trueoW; // true output height, width + ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, + dH, dW, iH, iW, paddingMode); + + if (paddingMode) // SAME + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, + dW, paddingMode); + + std::vector expectedGradOShape = + ShapeUtils::composeShapeUsingDimsAndIdx( + {bS, oC, trueoH, trueoW, 0, indIOioC, indOoH, indOoH + 1}); + std::vector expectedWeightsShape = + ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, oC); + REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, + "CONV2D_BP MKLDNN OP: wrong shape of output gradients (next " + "epsilon) array, expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedGradOShape).c_str(), + ShapeUtils::shapeAsString(gradO).c_str()); + REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, + "CONV2D_BP MKLDNN OP: wrong shape of weights array, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), + ShapeUtils::shapeAsString(weights).c_str()); + if (bias) + REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, + "CONV2D_BP MKLDNN OP: wrong shape of array with biases, " + "expected rank, length: <=2, %i, but got %i, %i instead !", + oC, bias->rankOf(), bias->lengthOf()); + + conv2dBpMKLDNN(input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, + sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat); + + return Status::OK(); } PLATFORM_CHECK(conv2d_bp, ENGINE_CPU) { - - auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC] always - auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] - auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next - - auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon - auto gradW = OUTPUT_VARIABLE(1); // [kH, kW, iC, oC] always - auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC] - - - return block.isUseMKLDNN() && - sd::MKLDNNStream::isSupported({input, weights, bias, gradO, gradI, gradW, gradB}); + auto input = + INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC] always + auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] + auto gradO = block.width() > 3 + ? INPUT_VARIABLE(3) + : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, + // oH, oW] (NCHW), epsilon_next + + auto gradI = OUTPUT_VARIABLE( + 0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon + auto gradW = OUTPUT_VARIABLE(1); // [kH, kW, iC, oC] always + auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC] + + return block.isUseMKLDNN() && + sd::MKLDNNStream::isSupported( + {input, weights, bias, gradO, gradI, gradW, gradB}); } - - -} -} -} +} // namespace platforms +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/conv3d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/conv3d.cpp index bfa9e49d1324..806918b5e7ba 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/conv3d.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/conv3d.cpp @@ -19,51 +19,59 @@ // @author raver119@gmail.com // -#include +#include #include +#include +#include #include -#include #include "mkldnnUtils.h" -#include using namespace dnnl; -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace platforms { ////////////////////////////////////////////////////////////////////// static void conv3dMKLDNN(const NDArray *input, const NDArray *weights, - const NDArray *bias, NDArray *output, - const int kD, const int kH, const int kW, - const int sD, const int sH, const int sW, - const int pD, const int pH, const int pW, - const int dD, const int dH, const int dW, - const int paddingMode, const int isNCDHW, const int wFormat) { - - // mkl support weights in [oC, iC, kD, kH, kW] format only - - int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; - int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, wFormat, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); - - // const int pWSame = (paddingMode == 2 && dW > 1) ? ((oW - 1) * sW + (kW - 1) * dW + 1 - iW) / 2 : pW; // dH == 1 for causal mode in conv1d - - dnnl::memory::dims strides = {sD, sH, sW}; - dnnl::memory::dims padding = {pD, pH, pW}; - // dnnl::memory::dims padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pWSame }; - dnnl::memory::dims padding_r = {(oD - 1) * sD - iD + kD - pD, (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pW}; - dnnl::memory::dims dilation = {dD-1, dH-1, dW-1}; - - auto xzFormatMkl = isNCDHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc; - dnnl::memory::format_tag wFormatMkl = dnnl::memory::format_tag::oidhw; - - dnnl::memory::dims xDims = {bS, iC, iD, iH, iW}; - dnnl::memory::dims wDims = {oC, iC, kD, kH, kW}; - dnnl::memory::dims zDims = {bS, oC, oD, oH, oW}; - - std::vector permut; + const NDArray *bias, NDArray *output, const int kD, + const int kH, const int kW, const int sD, const int sH, + const int sW, const int pD, const int pH, const int pW, + const int dD, const int dH, const int dW, + const int paddingMode, const int isNCDHW, + const int wFormat) { + // mkl support weights in [oC, iC, kD, kH, kW] format only + + int bS, iC, iD, iH, iW, oC, oD, oH, + oW; // batch size, input channels, input depth/height/width, output + // channels, output depth/height/width; + int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv3d( + isNCDHW, wFormat, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, + indIOioC, indIOioD, indWiC, indWoC, indWkD); + + // const int pWSame = (paddingMode == 2 && dW > 1) ? ((oW - 1) * sW + (kW - 1) + // * dW + 1 - iW) / 2 : pW; // dH == 1 for causal mode in conv1d + + dnnl::memory::dims strides = {sD, sH, sW}; + dnnl::memory::dims padding = {pD, pH, pW}; + // dnnl::memory::dims padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * + // sW - iW + kW - pWSame }; + dnnl::memory::dims padding_r = {(oD - 1) * sD - iD + kD - pD, + (oH - 1) * sH - iH + kH - pH, + (oW - 1) * sW - iW + kW - pW}; + dnnl::memory::dims dilation = {dD - 1, dH - 1, dW - 1}; + + auto xzFormatMkl = isNCDHW ? dnnl::memory::format_tag::ncdhw + : dnnl::memory::format_tag::ndhwc; + dnnl::memory::format_tag wFormatMkl = dnnl::memory::format_tag::oidhw; + + dnnl::memory::dims xDims = {bS, iC, iD, iH, iW}; + dnnl::memory::dims wDims = {oC, iC, kD, kH, kW}; + dnnl::memory::dims zDims = {bS, oC, oD, oH, oW}; + + std::vector permut; if(0 == wFormat) permut = {4,3,0,1,2}; // [kD, kH, kW, iC, oC] -> [oC, iC, kD, kH, kW] else if(2 == wFormat) @@ -71,99 +79,122 @@ static void conv3dMKLDNN(const NDArray *input, const NDArray *weights, auto type = dnnl::memory::data_type::f32; - // memory descriptors for arrays - - // input - dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any); - dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFormatMkl); - mkldnnUtils::setBlockStrides(*input, x_user_md); - - // weights - dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any); - dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, type, wFormatMkl); - mkldnnUtils::setBlockStrides(*weights, w_user_md, permut); - - // bias - dnnl::memory::desc b_mkl_md; - if(bias != nullptr) - b_mkl_md = dnnl::memory::desc({oC}, type, dnnl::memory::format_tag::x); - - // output - dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any); - dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, type, xzFormatMkl); - mkldnnUtils::setBlockStrides(*output, z_user_md); - - auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - - // operation primitive description - dnnl::convolution_forward::desc op_desc(dnnl::prop_kind::forward_inference, dnnl::algorithm::convolution_auto, x_mkl_md, w_mkl_md, b_mkl_md, z_mkl_md, strides, dilation, padding, padding_r); - dnnl::convolution_forward::primitive_desc op_prim_desc(op_desc, engine); - - // arguments (memory buffers) necessary for calculations - std::unordered_map args; - - dnnl::stream stream(engine); - - // provide memory buffers and check whether reorder is required - - // input - mkldnnUtils::loadDataToMklStream(*input, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]); - - // weights - mkldnnUtils::loadDataToMklStream(*weights, engine, stream, w_user_md, op_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]); - - // bias - if(bias != nullptr) { - auto b_mkl_mem = dnnl::memory(b_mkl_md, engine, const_cast(bias->buffer())); - args[DNNL_ARG_BIAS] = b_mkl_mem; - } - - // output - auto z_user_mem = mkldnnUtils::loadDataToMklStream(*output, engine, stream, z_user_md, op_prim_desc.dst_desc(), args[DNNL_ARG_DST]); - - // run calculations - dnnl::convolution_forward(op_prim_desc).execute(stream, args); - - // reorder outputs if necessary - if (op_prim_desc.dst_desc() != z_user_mem.get_desc()) - dnnl::reorder(args[DNNL_ARG_DST], z_user_mem).execute(stream, args[DNNL_ARG_DST], z_user_mem); - - stream.wait(); + // memory descriptors for arrays + + // input + dnnl::memory::desc x_mkl_md = + dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any); + dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFormatMkl); + mkldnnUtils::setBlockStrides(*input, x_user_md); + + // weights + dnnl::memory::desc w_mkl_md = + dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any); + dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, type, wFormatMkl); + mkldnnUtils::setBlockStrides(*weights, + w_user_md, permut); + + // bias + dnnl::memory::desc b_mkl_md; + if (bias != nullptr) + b_mkl_md = dnnl::memory::desc({oC}, type, dnnl::memory::format_tag::x); + + // output + dnnl::memory::desc z_mkl_md = + dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any); + dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, type, xzFormatMkl); + mkldnnUtils::setBlockStrides(*output, z_user_md); + + auto engine = + mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); + + // operation primitive description + dnnl::convolution_forward::desc op_desc( + dnnl::prop_kind::forward_inference, dnnl::algorithm::convolution_auto, + x_mkl_md, w_mkl_md, b_mkl_md, z_mkl_md, strides, dilation, padding, + padding_r); + dnnl::convolution_forward::primitive_desc op_prim_desc(op_desc, engine); + + // arguments (memory buffers) necessary for calculations + std::unordered_map args; + + dnnl::stream stream(engine); + + // provide memory buffers and check whether reorder is required + + // input + mkldnnUtils::loadDataToMklStream(*input, engine, stream, x_user_md, + op_prim_desc.src_desc(), args[DNNL_ARG_SRC]); + + // weights + mkldnnUtils::loadDataToMklStream(*weights, engine, stream, w_user_md, + op_prim_desc.weights_desc(), + args[DNNL_ARG_WEIGHTS]); + + // bias + if (bias != nullptr) { + auto b_mkl_mem = + dnnl::memory(b_mkl_md, engine, const_cast(bias->buffer())); + args[DNNL_ARG_BIAS] = b_mkl_mem; + } + + // output + auto z_user_mem = mkldnnUtils::loadDataToMklStream(*output, engine, stream, z_user_md, op_prim_desc.dst_desc(), + args[DNNL_ARG_DST] ); + + // run calculations + dnnl::convolution_forward(op_prim_desc).execute(stream, args); + + // reorder outputs if necessary + if (op_prim_desc.dst_desc() != z_user_mem.get_desc()) + dnnl::reorder(args[DNNL_ARG_DST], z_user_mem).execute(stream, args[DNNL_ARG_DST], z_user_mem); + + stream.wait(); } ////////////////////////////////////////////////////////////////////// -static void conv3dBpMKLDNN(const NDArray *input, const NDArray *weights, const NDArray *bias, const NDArray *gradO, - NDArray *gradI, NDArray *gradW, NDArray *gradB, - const int kD, const int kH, const int kW, - const int sD, const int sH, const int sW, - const int pD, const int pH, const int pW, - const int dD, const int dH, const int dW, - const int paddingMode, const int isNCDHW, const int wFormat) { - - // mkl support weights/gradW in [oC, iC, kD, kH, kW] format only - - int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; - int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, wFormat, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); - - // const int pWSame = (paddingMode == 2 && dW > 1) ? ((oW - 1) * sW + (kW - 1) * dW + 1 - iW) / 2 : pW; // dH == 1 for causal mode in conv1d - - dnnl::memory::dims strides = {sD, sH, sW}; - dnnl::memory::dims padding = {pD, pH, pW}; - // dnnl::memory::dims padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pWSame }; - dnnl::memory::dims padding_r = {(oD - 1) * sD - iD + kD - pD, (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pW}; - dnnl::memory::dims dilation = {dD-1, dH-1, dW-1}; - - auto xzFormatMkl = isNCDHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc; - dnnl::memory::format_tag wFormatMkl = dnnl::memory::format_tag::oidhw; - - dnnl::memory::dims xDims = {bS, iC, iD, iH, iW}; - dnnl::memory::dims wDims = {oC, iC, kD, kH, kW}; - dnnl::memory::dims zDims = {bS, oC, oD, oH, oW}; - - auto type = dnnl::memory::data_type::f32; - - std::vector permut; +static void conv3dBpMKLDNN(const NDArray *input, const NDArray *weights, + const NDArray *bias, const NDArray *gradO, + NDArray *gradI, NDArray *gradW, NDArray *gradB, + const int kD, const int kH, const int kW, + const int sD, const int sH, const int sW, + const int pD, const int pH, const int pW, + const int dD, const int dH, const int dW, + const int paddingMode, const int isNCDHW, + const int wFormat) { + // mkl support weights/gradW in [oC, iC, kD, kH, kW] format only + + int bS, iC, iD, iH, iW, oC, oD, oH, + oW; // batch size, input channels, input depth/height/width, output + // channels, output depth/height/width; + int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv3d( + isNCDHW, wFormat, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, + indIOioC, indIOioD, indWiC, indWoC, indWkD); + + // const int pWSame = (paddingMode == 2 && dW > 1) ? ((oW - 1) * sW + (kW - 1) + // * dW + 1 - iW) / 2 : pW; // dH == 1 for causal mode in conv1d + + dnnl::memory::dims strides = {sD, sH, sW}; + dnnl::memory::dims padding = {pD, pH, pW}; + // dnnl::memory::dims padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * + // sW - iW + kW - pWSame }; + dnnl::memory::dims padding_r = {(oD - 1) * sD - iD + kD - pD, + (oH - 1) * sH - iH + kH - pH, + (oW - 1) * sW - iW + kW - pW}; + dnnl::memory::dims dilation = {dD - 1, dH - 1, dW - 1}; + + auto xzFormatMkl = isNCDHW ? dnnl::memory::format_tag::ncdhw + : dnnl::memory::format_tag::ndhwc; + dnnl::memory::format_tag wFormatMkl = dnnl::memory::format_tag::oidhw; + + dnnl::memory::dims xDims = {bS, iC, iD, iH, iW}; + dnnl::memory::dims wDims = {oC, iC, kD, kH, kW}; + dnnl::memory::dims zDims = {bS, oC, oD, oH, oW}; + + auto type = dnnl::memory::data_type::f32; + + std::vector permut; if(0 == wFormat) permut = {4,3,0,1,2}; // [kD, kH, kW, iC, oC] -> [oC, iC, kD, kH, kW] else if(2 == wFormat) @@ -171,158 +202,207 @@ static void conv3dBpMKLDNN(const NDArray *input, const NDArray *weights, const N // memory descriptors for arrays - // input - dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any); - dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFormatMkl); - mkldnnUtils::setBlockStrides(*input, x_user_md); - - // weights - dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any); - dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, type, wFormatMkl); - mkldnnUtils::setBlockStrides(*weights, w_user_md, permut); - - // gradO - dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any); - dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, type, xzFormatMkl); - - mkldnnUtils::setBlockStrides(*gradO, gradO_user_md); - - // gradI - dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any); - dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, type, xzFormatMkl); - - mkldnnUtils::setBlockStrides(*gradI, gradI_user_md); - - // gradW - dnnl::memory::desc gradW_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any); - dnnl::memory::desc gradW_user_md = dnnl::memory::desc(wDims, type, wFormatMkl); - mkldnnUtils::setBlockStrides(*gradW, gradW_user_md, permut); - - // gradB - dnnl::memory::desc gradB_mkl_md; - if(gradB != nullptr) - gradB_mkl_md = dnnl::memory::desc({oC}, type, dnnl::memory::format_tag::x); - - auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - - // forward primitive description - dnnl::convolution_forward::desc op_ff_desc(dnnl::prop_kind::forward_inference, dnnl::algorithm::convolution_auto, x_mkl_md, w_mkl_md, gradB_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r); - dnnl::convolution_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine); - - // backward data primitive description - dnnl::convolution_backward_data::desc op_data_bp_desc(dnnl::algorithm::convolution_auto, gradI_mkl_md, w_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r); - dnnl::convolution_backward_data::primitive_desc op_data_bp_prim_desc(op_data_bp_desc, engine, op_ff_prim_desc); - - // backward weights primitive description - dnnl::convolution_backward_weights::desc op_weights_bp_desc(dnnl::algorithm::convolution_auto, x_mkl_md, gradW_mkl_md, gradB_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r); - dnnl::convolution_backward_weights::primitive_desc op_weights_bp_prim_desc(op_weights_bp_desc, engine, op_ff_prim_desc); - - // arguments (memory buffers) necessary for calculations - std::unordered_map args; - - dnnl::stream stream(engine); - - // provide memory buffers and check whether reorder is required - - // input - mkldnnUtils::loadDataToMklStream(*input, engine, stream, x_user_md, op_weights_bp_prim_desc.src_desc(), args[DNNL_ARG_SRC]); - - // weights - mkldnnUtils::loadDataToMklStream(*weights, engine, stream, w_user_md, op_data_bp_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]); - - // gradO - auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, const_cast(gradO->buffer())); - const bool gradOReorderW = op_weights_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc(); - const bool gradOReorderD = op_data_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc(); - auto gradO_mkl_memW = gradOReorderW ? dnnl::memory(op_weights_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem; - auto gradO_mkl_memD = gradOReorderD ? dnnl::memory(op_data_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem; - if (gradOReorderW) - dnnl::reorder(gradO_user_mem, gradO_mkl_memW).execute(stream, gradO_user_mem, gradO_mkl_memW); - if (gradOReorderD) - dnnl::reorder(gradO_user_mem, gradO_mkl_memD).execute(stream, gradO_user_mem, gradO_mkl_memD); - args[DNNL_ARG_DIFF_DST] = gradO_mkl_memD; - - // gradI - auto gradI_user_mem = mkldnnUtils::loadDataToMklStream(*gradI, engine, stream, gradI_user_md, op_data_bp_prim_desc.diff_src_desc(), args[DNNL_ARG_DIFF_SRC]); - - // gradW - auto gradW_user_mem = mkldnnUtils::loadDataToMklStream(*gradW, engine, stream, gradW_user_md, op_weights_bp_prim_desc.diff_weights_desc(), args[DNNL_ARG_DIFF_WEIGHTS]); - - // gradB - if(gradB != nullptr) { - auto gradB_mkl_mem = dnnl::memory(gradB_mkl_md, engine, gradB->buffer()); - args[DNNL_ARG_DIFF_BIAS] = gradB_mkl_mem; - } - - // run backward data calculations - dnnl::convolution_backward_data(op_data_bp_prim_desc).execute(stream, args); - - if(gradOReorderW || gradOReorderD) - args[DNNL_ARG_DIFF_DST] = gradO_mkl_memW; - - // run backward weights calculations - dnnl::convolution_backward_weights(op_weights_bp_prim_desc).execute(stream, args); - - // reorder gradI if necessary - if (op_data_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc()) - dnnl::reorder(args[DNNL_ARG_DIFF_SRC], gradI_user_mem).execute(stream, args[DNNL_ARG_DIFF_SRC], gradI_user_mem); - if (op_weights_bp_prim_desc.diff_weights_desc() != gradW_user_mem.get_desc()) - dnnl::reorder(args[DNNL_ARG_DIFF_WEIGHTS], gradW_user_mem).execute(stream, args[DNNL_ARG_DIFF_WEIGHTS], gradW_user_mem); - - stream.wait(); - - // shape::printArray(z_mkl_mem.map_data(),8); + // input + dnnl::memory::desc x_mkl_md = + dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any); + dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFormatMkl); + mkldnnUtils::setBlockStrides(*input, x_user_md); + + // weights + dnnl::memory::desc w_mkl_md = + dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any); + dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, type, wFormatMkl); + mkldnnUtils::setBlockStrides(*weights, + w_user_md, permut); + + // gradO + dnnl::memory::desc gradO_mkl_md = + dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any); + dnnl::memory::desc gradO_user_md = + dnnl::memory::desc(zDims, type, xzFormatMkl); + + mkldnnUtils::setBlockStrides(*gradO, gradO_user_md); + + // gradI + dnnl::memory::desc gradI_mkl_md = + dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any); + dnnl::memory::desc gradI_user_md = + dnnl::memory::desc(xDims, type, xzFormatMkl); + + mkldnnUtils::setBlockStrides(*gradI, gradI_user_md); + + // gradW + dnnl::memory::desc gradW_mkl_md = + dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any); + dnnl::memory::desc gradW_user_md = + dnnl::memory::desc(wDims, type, wFormatMkl); + mkldnnUtils::setBlockStrides(*gradW, + gradW_user_md, permut); + + // gradB + dnnl::memory::desc gradB_mkl_md; + if (gradB != nullptr) + gradB_mkl_md = dnnl::memory::desc({oC}, type, dnnl::memory::format_tag::x); + + auto engine = + mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); + + // forward primitive description + dnnl::convolution_forward::desc op_ff_desc( + dnnl::prop_kind::forward_inference, dnnl::algorithm::convolution_auto, + x_mkl_md, w_mkl_md, gradB_mkl_md, gradO_mkl_md, strides, dilation, + padding, padding_r); + dnnl::convolution_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine); + + // backward data primitive description + dnnl::convolution_backward_data::desc op_data_bp_desc( + dnnl::algorithm::convolution_auto, gradI_mkl_md, w_mkl_md, gradO_mkl_md, + strides, dilation, padding, padding_r); + dnnl::convolution_backward_data::primitive_desc op_data_bp_prim_desc( + op_data_bp_desc, engine, op_ff_prim_desc); + + // backward weights primitive description + dnnl::convolution_backward_weights::desc op_weights_bp_desc( + dnnl::algorithm::convolution_auto, x_mkl_md, gradW_mkl_md, gradB_mkl_md, + gradO_mkl_md, strides, dilation, padding, padding_r); + dnnl::convolution_backward_weights::primitive_desc op_weights_bp_prim_desc( + op_weights_bp_desc, engine, op_ff_prim_desc); + + // arguments (memory buffers) necessary for calculations + std::unordered_map args; + + dnnl::stream stream(engine); + + // provide memory buffers and check whether reorder is required + + // input + mkldnnUtils::loadDataToMklStream(*input, engine, stream, x_user_md, + op_weights_bp_prim_desc.src_desc(), + args[DNNL_ARG_SRC]); + + // weights + mkldnnUtils::loadDataToMklStream(*weights, engine, stream, w_user_md, + op_data_bp_prim_desc.weights_desc(), + args[DNNL_ARG_WEIGHTS]); + + // gradO + auto gradO_user_mem = + dnnl::memory(gradO_user_md, engine, const_cast(gradO->buffer())); + const bool gradOReorderW = + op_weights_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc(); + const bool gradOReorderD = + op_data_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc(); + auto gradO_mkl_memW = + gradOReorderW + ? dnnl::memory(op_weights_bp_prim_desc.diff_dst_desc(), engine) + : gradO_user_mem; + auto gradO_mkl_memD = + gradOReorderD ? dnnl::memory(op_data_bp_prim_desc.diff_dst_desc(), engine) + : gradO_user_mem; + if (gradOReorderW) + dnnl::reorder(gradO_user_mem, gradO_mkl_memW) + .execute(stream, gradO_user_mem, gradO_mkl_memW); + if (gradOReorderD) + dnnl::reorder(gradO_user_mem, gradO_mkl_memD) + .execute(stream, gradO_user_mem, gradO_mkl_memD); + args[DNNL_ARG_DIFF_DST] = gradO_mkl_memD; + + // gradI + auto gradI_user_mem = mkldnnUtils::loadDataToMklStream(*gradI, engine, stream, gradI_user_md, + op_data_bp_prim_desc.diff_src_desc() , + args[DNNL_ARG_DIFF_SRC] ); + + // gradW + auto gradW_user_mem = mkldnnUtils::loadDataToMklStream(*gradW, engine, stream, gradW_user_md, + op_weights_bp_prim_desc.diff_weights_desc() , + args[DNNL_ARG_DIFF_WEIGHTS] ); + + // gradB + if (gradB != nullptr) { + auto gradB_mkl_mem = dnnl::memory(gradB_mkl_md, engine, gradB->buffer()); + args[DNNL_ARG_DIFF_BIAS] = gradB_mkl_mem; + } + + // run backward data calculations + dnnl::convolution_backward_data(op_data_bp_prim_desc).execute(stream, args); + + if (gradOReorderW || gradOReorderD) args[DNNL_ARG_DIFF_DST] = gradO_mkl_memW; + + // run backward weights calculations + dnnl::convolution_backward_weights(op_weights_bp_prim_desc) + .execute(stream, args); + + // reorder gradI if necessary + if (op_data_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc()) + dnnl::reorder(args[DNNL_ARG_DIFF_SRC], gradI_user_mem) + .execute(stream, args[DNNL_ARG_DIFF_SRC], gradI_user_mem); + if (op_weights_bp_prim_desc.diff_weights_desc() != gradW_user_mem.get_desc()) + dnnl::reorder(args[DNNL_ARG_DIFF_WEIGHTS], gradW_user_mem) + .execute(stream, args[DNNL_ARG_DIFF_WEIGHTS], gradW_user_mem); + + stream.wait(); + + // shape::printArray(z_mkl_mem.map_data(),8); } - /* ////////////////////////////////////////////////////////////////////// static void conv3dMKLDNN(sd::graph::Context &block, - const NDArray *input, const NDArray *weights, const NDArray *bias, - NDArray *output, - const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, int pD, int pH, int pW, const int dD, const int dH, const int dW, - const int paddingMode, const int isNCDHW) { - - int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; - int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); + const NDArray *input, const NDArray *weights, const +NDArray *bias, NDArray *output, const int kD, const int kH, const int kW, const +int sD, const int sH, const int sW, int pD, int pH, int pW, const int dD, const +int dH, const int dW, const int paddingMode, const int isNCDHW) { + + int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, +input channels, input depth/height/width, output channels, output +depth/height/width; int indIOioC, indIOioD, indWoC, indWiC, indWkD; // +corresponding indexes ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, +*input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, +indWoC, indWkD); dnnl_memory_desc_t empty; - dnnl::memory::desc conv_src_md(empty), conv_weights_md(empty), conv_bias_md(empty), conv_dst_md( empty); - dnnl::memory::desc user_src_md(empty), user_weights_md(empty), user_bias_md(empty), user_dst_md( empty); - - dnnl::memory::dims conv_strides, conv_padding, conv_padding_r, conv_dilation; - - mkldnnUtils::getMKLDNNMemoryDescConv3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, paddingMode, - isNCDHW, - bS, iC, iD, iH, iW, oC, oD, oH, oW, input, nullptr, weights, - nullptr, bias, output, - &conv_src_md, nullptr, &conv_weights_md, nullptr, - &conv_bias_md, &conv_dst_md, - &user_src_md, nullptr, &user_weights_md, nullptr, - &user_bias_md, &user_dst_md, - conv_strides, conv_padding, conv_padding_r, conv_dilation); - auto conv_desc = bias != nullptr ? convolution_forward::desc(prop_kind::forward, algorithm::convolution_auto, conv_src_md, conv_weights_md, conv_bias_md, conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r) - : convolution_forward::desc(prop_kind::forward, algorithm::convolution_auto, conv_src_md, conv_weights_md, conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r); - - auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - dnnl::stream stream(engine); - - auto conv_prim_desc = convolution_forward::primitive_desc(conv_desc, engine); - auto user_src_memory = dnnl::memory(user_src_md, engine, const_cast(input)->buffer()); - auto user_weights_memory = dnnl::memory(user_weights_md, engine, const_cast(weights)->buffer()); + dnnl::memory::desc conv_src_md(empty), conv_weights_md(empty), +conv_bias_md(empty), conv_dst_md( empty); dnnl::memory::desc user_src_md(empty), +user_weights_md(empty), user_bias_md(empty), user_dst_md( empty); + + dnnl::memory::dims conv_strides, conv_padding, conv_padding_r, +conv_dilation; + + mkldnnUtils::getMKLDNNMemoryDescConv3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, +dD, dH, dW, paddingMode, isNCDHW, bS, iC, iD, iH, iW, oC, oD, oH, oW, input, +nullptr, weights, nullptr, bias, output, &conv_src_md, nullptr, +&conv_weights_md, nullptr, &conv_bias_md, &conv_dst_md, &user_src_md, nullptr, +&user_weights_md, nullptr, &user_bias_md, &user_dst_md, conv_strides, +conv_padding, conv_padding_r, conv_dilation); auto conv_desc = bias != nullptr ? +convolution_forward::desc(prop_kind::forward, algorithm::convolution_auto, +conv_src_md, conv_weights_md, conv_bias_md, conv_dst_md, conv_strides, +conv_dilation, conv_padding, conv_padding_r) : +convolution_forward::desc(prop_kind::forward, algorithm::convolution_auto, +conv_src_md, conv_weights_md, conv_dst_md, conv_strides, conv_dilation, +conv_padding, conv_padding_r); + + auto engine = +mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); dnnl::stream +stream(engine); + + auto conv_prim_desc = convolution_forward::primitive_desc(conv_desc, +engine); auto user_src_memory = dnnl::memory(user_src_md, engine, +const_cast(input)->buffer()); auto user_weights_memory = +dnnl::memory(user_weights_md, engine, const_cast(weights)->buffer()); auto user_dst_memory = dnnl::memory(user_dst_md, engine, output->buffer()); auto conv_src_memory = user_src_memory; if (conv_prim_desc.src_desc() != user_src_memory.get_desc()) { conv_src_memory = dnnl::memory(conv_prim_desc.src_desc(), engine); - reorder(user_src_memory, conv_src_memory).execute(stream, user_src_memory, conv_src_memory); + reorder(user_src_memory, conv_src_memory).execute(stream, +user_src_memory, conv_src_memory); } auto conv_weights_memory = user_weights_memory; if (conv_prim_desc.weights_desc() != user_weights_memory.get_desc()) { - conv_weights_memory = dnnl::memory(conv_prim_desc.weights_desc(), engine); - reorder(user_weights_memory, conv_weights_memory).execute(stream, user_weights_memory, conv_weights_memory); + conv_weights_memory = dnnl::memory(conv_prim_desc.weights_desc(), +engine); reorder(user_weights_memory, conv_weights_memory).execute(stream, +user_weights_memory, conv_weights_memory); } auto conv_dst_memory = user_dst_memory; @@ -331,20 +411,21 @@ static void conv3dMKLDNN(sd::graph::Context &block, } if (bias != nullptr) { - auto conv_bias_memory = dnnl::memory(conv_prim_desc.bias_desc(), engine, bias->buffer()); - convolution_forward(conv_prim_desc).execute(stream, {{DNNL_ARG_SRC, conv_src_memory}, - {DNNL_ARG_WEIGHTS, conv_weights_memory}, - {DNNL_ARG_BIAS, conv_bias_memory}, - {DNNL_ARG_DST, conv_dst_memory}}); + auto conv_bias_memory = dnnl::memory(conv_prim_desc.bias_desc(), engine, +bias->buffer()); convolution_forward(conv_prim_desc).execute(stream, +{{DNNL_ARG_SRC, conv_src_memory}, {DNNL_ARG_WEIGHTS, conv_weights_memory}, + {DNNL_ARG_BIAS, +conv_bias_memory}, {DNNL_ARG_DST, conv_dst_memory}}); } else { - convolution_forward(conv_prim_desc).execute(stream, {{DNNL_ARG_SRC, conv_src_memory}, - {DNNL_ARG_WEIGHTS, conv_weights_memory}, - {DNNL_ARG_DST, conv_dst_memory}}); + convolution_forward(conv_prim_desc).execute(stream, {{DNNL_ARG_SRC, +conv_src_memory}, {DNNL_ARG_WEIGHTS, conv_weights_memory}, {DNNL_ARG_DST, +conv_dst_memory}}); } if (conv_prim_desc.dst_desc() != user_dst_memory.get_desc()) - reorder(conv_dst_memory, user_dst_memory).execute(stream, conv_dst_memory, user_dst_memory); + reorder(conv_dst_memory, user_dst_memory).execute(stream, +conv_dst_memory, user_dst_memory); stream.wait(); } @@ -352,247 +433,380 @@ static void conv3dMKLDNN(sd::graph::Context &block, ////////////////////////////////////////////////////////////////////// static void conv3dBpMKLDNN(sd::graph::Context &block, - const NDArray *input, const NDArray *weights, const NDArray *bias, const NDArray *gradO, - NDArray *gradI, NDArray *gradW, NDArray *gradB, - const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, int pD, int pH, int pW, const int dD, const int dH, const int dW, + const NDArray *input, const NDArray *weights, const +NDArray *bias, const NDArray *gradO, NDArray *gradI, NDArray *gradW, NDArray +*gradB, const int kD, const int kH, const int kW, const int sD, const int sH, +const int sW, int pD, int pH, int pW, const int dD, const int dH, const int dW, const int paddingMode, const int isNCDHW) { - int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; - int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); + int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, +input channels, input depth/height/width, output channels, output +depth/height/width; int indIOioC, indIOioD, indWoC, indWiC, indWkD; // +corresponding indexes ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, +*input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, +indWoC, indWkD); dnnl_memory_desc_t empty; - dnnl::memory::desc conv_src_md(empty), conv_diff_src_md(empty), conv_weights_md(empty), conv_diff_weights_md(empty), conv_bias_md(empty), conv_dst_md(empty); - dnnl::memory::desc user_src_md(empty), user_diff_src_md(empty), user_weights_md(empty), user_diff_weights_md(empty), user_bias_md(empty), user_dst_md(empty); - - dnnl::memory::dims conv_strides, conv_padding, conv_padding_r, conv_dilation; - - mkldnnUtils::getMKLDNNMemoryDescConv3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, paddingMode, - isNCDHW, - bS, iC, iD, iH, iW, oC, oD, oH, oW, input, gradI, weights, - gradW, gradB, gradO, - &conv_src_md, &conv_diff_src_md, &conv_weights_md, - &conv_diff_weights_md, &conv_bias_md, &conv_dst_md, - &user_src_md, &user_diff_src_md, &user_weights_md, - &user_diff_weights_md, &user_bias_md, &user_dst_md, - conv_strides, conv_padding, conv_padding_r, conv_dilation); - - auto conv_desc = gradB != nullptr ? convolution_forward::desc(prop_kind::forward, algorithm::convolution_auto, conv_src_md, conv_weights_md, conv_bias_md, conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r) - : convolution_forward::desc(prop_kind::forward, algorithm::convolution_auto, conv_src_md, conv_weights_md, conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r); - - auto conv_prim_desc = convolution_forward::primitive_desc(conv_desc, mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine())); - - auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - dnnl::stream stream(engine); + dnnl::memory::desc conv_src_md(empty), conv_diff_src_md(empty), +conv_weights_md(empty), conv_diff_weights_md(empty), conv_bias_md(empty), +conv_dst_md(empty); dnnl::memory::desc user_src_md(empty), +user_diff_src_md(empty), user_weights_md(empty), user_diff_weights_md(empty), +user_bias_md(empty), user_dst_md(empty); + + dnnl::memory::dims conv_strides, conv_padding, conv_padding_r, +conv_dilation; + + mkldnnUtils::getMKLDNNMemoryDescConv3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, +dD, dH, dW, paddingMode, isNCDHW, bS, iC, iD, iH, iW, oC, oD, oH, oW, input, +gradI, weights, gradW, gradB, gradO, &conv_src_md, &conv_diff_src_md, +&conv_weights_md, &conv_diff_weights_md, &conv_bias_md, &conv_dst_md, + &user_src_md, &user_diff_src_md, +&user_weights_md, &user_diff_weights_md, &user_bias_md, &user_dst_md, + conv_strides, conv_padding, +conv_padding_r, conv_dilation); + + auto conv_desc = gradB != nullptr ? +convolution_forward::desc(prop_kind::forward, algorithm::convolution_auto, +conv_src_md, conv_weights_md, conv_bias_md, conv_dst_md, conv_strides, +conv_dilation, conv_padding, conv_padding_r) : +convolution_forward::desc(prop_kind::forward, algorithm::convolution_auto, +conv_src_md, conv_weights_md, conv_dst_md, conv_strides, conv_dilation, +conv_padding, conv_padding_r); + + auto conv_prim_desc = convolution_forward::primitive_desc(conv_desc, +mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine())); + + auto engine = +mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); dnnl::stream +stream(engine); if (gradW != nullptr) { - auto convW_desc = gradB != nullptr ? convolution_backward_weights::desc(algorithm::convolution_auto, conv_src_md, conv_diff_weights_md, conv_bias_md, conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r) - : convolution_backward_weights::desc(algorithm::convolution_auto, conv_src_md, conv_diff_weights_md, conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r); auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - - auto convW_prim_desc = convolution_backward_weights::primitive_desc(convW_desc, engine, conv_prim_desc); - - auto userW_src_memory = dnnl::memory(user_src_md, engine, const_cast(input)->buffer()); - auto userW_weights_memory = dnnl::memory(user_diff_weights_md, engine, gradW->buffer()); - auto userW_dst_memory = dnnl::memory(user_dst_md, engine, const_cast(gradO)->buffer()); + auto convW_desc = gradB != nullptr ? +convolution_backward_weights::desc(algorithm::convolution_auto, conv_src_md, +conv_diff_weights_md, conv_bias_md, conv_dst_md, conv_strides, conv_dilation, +conv_padding, conv_padding_r) : +convolution_backward_weights::desc(algorithm::convolution_auto, conv_src_md, +conv_diff_weights_md, conv_dst_md, conv_strides, conv_dilation, conv_padding, +conv_padding_r); auto engine = +mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); + + auto convW_prim_desc = +convolution_backward_weights::primitive_desc(convW_desc, engine, +conv_prim_desc); + + auto userW_src_memory = dnnl::memory(user_src_md, engine, +const_cast(input)->buffer()); auto userW_weights_memory = +dnnl::memory(user_diff_weights_md, engine, gradW->buffer()); auto +userW_dst_memory = dnnl::memory(user_dst_md, engine, const_cast(gradO)->buffer()); auto convW_src_memory = userW_src_memory; if (convW_prim_desc.src_desc() != userW_src_memory.get_desc()) { convW_src_memory = dnnl::memory(convW_prim_desc.src_desc(), engine); - reorder(userW_src_memory, convW_src_memory).execute(stream, userW_src_memory, convW_src_memory); + reorder(userW_src_memory, convW_src_memory).execute(stream, +userW_src_memory, convW_src_memory); } auto convW_weights_memory = userW_weights_memory; - if (convW_prim_desc.diff_weights_desc() != userW_weights_memory.get_desc()) { - convW_weights_memory = dnnl::memory(convW_prim_desc.diff_weights_desc(), engine); + if (convW_prim_desc.diff_weights_desc() != +userW_weights_memory.get_desc()) { convW_weights_memory = +dnnl::memory(convW_prim_desc.diff_weights_desc(), engine); } auto convW_dst_memory = userW_dst_memory; if (convW_prim_desc.diff_dst_desc() != userW_dst_memory.get_desc()) { - convW_dst_memory = dnnl::memory(convW_prim_desc.diff_dst_desc(), engine); - reorder(userW_dst_memory, convW_dst_memory).execute(stream, userW_dst_memory, convW_dst_memory); + convW_dst_memory = dnnl::memory(convW_prim_desc.diff_dst_desc(), +engine); reorder(userW_dst_memory, convW_dst_memory).execute(stream, +userW_dst_memory, convW_dst_memory); } if (gradB != nullptr) { - auto convW_bias_memory = dnnl::memory(convW_prim_desc.diff_bias_desc(), engine, gradB->buffer()); + auto convW_bias_memory = +dnnl::memory(convW_prim_desc.diff_bias_desc(), engine, gradB->buffer()); convolution_backward_weights(convW_prim_desc).execute(stream, - {{DNNL_ARG_SRC, convW_src_memory}, - {DNNL_ARG_DIFF_DST, convW_dst_memory}, - {DNNL_ARG_DIFF_WEIGHTS, convW_weights_memory}, - {DNNL_ARG_DIFF_BIAS, convW_bias_memory}}); + {{DNNL_ARG_SRC, +convW_src_memory}, {DNNL_ARG_DIFF_DST, convW_dst_memory}, + {DNNL_ARG_DIFF_WEIGHTS, +convW_weights_memory}, {DNNL_ARG_DIFF_BIAS, convW_bias_memory}}); } else { convolution_backward_weights(convW_prim_desc).execute(stream, - {{DNNL_ARG_SRC, convW_src_memory}, - {DNNL_ARG_DIFF_DST, convW_dst_memory}, - {DNNL_ARG_DIFF_WEIGHTS, convW_weights_memory}}); + {{DNNL_ARG_SRC, +convW_src_memory}, {DNNL_ARG_DIFF_DST, convW_dst_memory}, + {DNNL_ARG_DIFF_WEIGHTS, +convW_weights_memory}}); } - if (convW_prim_desc.diff_weights_desc() != userW_weights_memory.get_desc()) - reorder(convW_weights_memory, userW_weights_memory).execute(stream, convW_weights_memory, userW_weights_memory); + if (convW_prim_desc.diff_weights_desc() != +userW_weights_memory.get_desc()) reorder(convW_weights_memory, +userW_weights_memory).execute(stream, convW_weights_memory, +userW_weights_memory); stream.wait(); } if (gradI != nullptr) { - auto convI_desc = convolution_backward_data::desc(algorithm::convolution_auto, conv_diff_src_md, conv_weights_md, conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r); - - auto convI_prim_desc = convolution_backward_data::primitive_desc(convI_desc, engine, conv_prim_desc); - auto userI_src_memory = dnnl::memory(user_diff_src_md, engine, gradI->buffer()); - auto userI_weights_memory = dnnl::memory(user_weights_md, engine, const_cast(weights)->buffer()); - auto userI_dst_memory = dnnl::memory(user_dst_md, engine, const_cast(gradO)->buffer()); + auto convI_desc = +convolution_backward_data::desc(algorithm::convolution_auto, conv_diff_src_md, +conv_weights_md, conv_dst_md, conv_strides, conv_dilation, conv_padding, +conv_padding_r); + + auto convI_prim_desc = +convolution_backward_data::primitive_desc(convI_desc, engine, conv_prim_desc); + auto userI_src_memory = dnnl::memory(user_diff_src_md, engine, +gradI->buffer()); auto userI_weights_memory = dnnl::memory(user_weights_md, +engine, const_cast(weights)->buffer()); auto userI_dst_memory = +dnnl::memory(user_dst_md, engine, const_cast(gradO)->buffer()); auto convI_src_memory = userI_src_memory; if (convI_prim_desc.diff_src_desc() != userI_src_memory.get_desc()) - convI_src_memory = dnnl::memory(convI_prim_desc.diff_src_desc(), engine); + convI_src_memory = dnnl::memory(convI_prim_desc.diff_src_desc(), +engine); auto convI_weights_memory = userI_weights_memory; if (convI_prim_desc.weights_desc() != userI_weights_memory.get_desc()) { - convI_weights_memory = dnnl::memory(convI_prim_desc.weights_desc(), engine); - reorder(userI_weights_memory, convI_weights_memory).execute(stream, userI_weights_memory, convI_weights_memory); + convI_weights_memory = dnnl::memory(convI_prim_desc.weights_desc(), +engine); reorder(userI_weights_memory, convI_weights_memory).execute(stream, +userI_weights_memory, convI_weights_memory); } auto convI_dst_memory = userI_dst_memory; if (convI_prim_desc.diff_dst_desc() != userI_dst_memory.get_desc()) { - convI_dst_memory = dnnl::memory(convI_prim_desc.diff_dst_desc(), engine); - reorder(userI_dst_memory, convI_dst_memory).execute(stream, userI_dst_memory, convI_dst_memory); + convI_dst_memory = dnnl::memory(convI_prim_desc.diff_dst_desc(), +engine); reorder(userI_dst_memory, convI_dst_memory).execute(stream, +userI_dst_memory, convI_dst_memory); } convolution_backward_data(convI_prim_desc).execute(stream, - {{DNNL_ARG_DIFF_DST, convI_dst_memory}, - {DNNL_ARG_WEIGHTS, convI_weights_memory}, - {DNNL_ARG_DIFF_SRC, convI_src_memory}}); + {{DNNL_ARG_DIFF_DST, +convI_dst_memory}, {DNNL_ARG_WEIGHTS, convI_weights_memory}, + {DNNL_ARG_DIFF_SRC, +convI_src_memory}}); if (convI_prim_desc.diff_src_desc() != userI_src_memory.get_desc()) - reorder(convI_src_memory, userI_src_memory).execute(stream, convI_src_memory, userI_src_memory); + reorder(convI_src_memory, userI_src_memory).execute(stream, +convI_src_memory, userI_src_memory); } } */ ////////////////////////////////////////////////////////////////////// PLATFORM_IMPL(conv3dnew, ENGINE_CPU) { - - auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) - auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC] - auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] - auto output = OUTPUT_VARIABLE(0); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW) - - REQUIRE_TRUE(input->rankOf() == 5, 0, "CUSTOM CONV3D MKLDNN OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf()); - REQUIRE_TRUE(weights->rankOf() == 5, 0, "CUSTOM CONV3D MKLDNN OP: rank of weights array must be equal to 5, but got %i instead !", weights->rankOf()); - - int kD = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(weights->sizeAt(0));// filter(kernel) depth - int kH = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast(weights->sizeAt(1));// filter(kernel) height - int kW = INT_ARG(2) > 0 ? INT_ARG(2) : static_cast(weights->sizeAt(2));// filter(kernel) width - int sD = INT_ARG(3); // strides depth - int sH = INT_ARG(4); // strides height - int sW = INT_ARG(5); // strides width - int pD = INT_ARG(6); // paddings depth - int pH = INT_ARG(7); // paddings height - int pW = INT_ARG(8); // paddings width - int dD = INT_ARG(9); // dilations depth - int dH = INT_ARG(10); // dilations height - int dW = INT_ARG(11); // dilations width - int paddingMode = INT_ARG(12); // 0-SAME, 1-VALID - int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW - int wFormat = block.getIArguments()->size() > 14 ? INT_ARG(14) : 0; // 0 - [kD, kH, kW, iC, oC], 1 - [oC, iC, kD, kH, kW], 2 - [oC, kD, kH, kW, iC] - - int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; - int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, wFormat, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); - - std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, iC, oC); - REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV3D MKLDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); - if (bias) - REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV3D MKLDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); - - if (paddingMode) // SAME - ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW); - - conv3dMKLDNN(input, weights, bias, output, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, paddingMode, isNCDHW, wFormat); - - return Status::OK(); + auto input = INPUT_VARIABLE( + 0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) + auto weights = INPUT_VARIABLE( + 1); // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC] + auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] + auto output = OUTPUT_VARIABLE( + 0); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW) + + REQUIRE_TRUE(input->rankOf() == 5, 0, + "CUSTOM CONV3D MKLDNN OP: rank of input array must be equal to " + "5, but got %i instead !", + input->rankOf()); + REQUIRE_TRUE(weights->rankOf() == 5, 0, + "CUSTOM CONV3D MKLDNN OP: rank of weights array must be equal " + "to 5, but got %i instead !", + weights->rankOf()); + + int kD = INT_ARG(0) > 0 + ? INT_ARG(0) + : static_cast(weights->sizeAt(0)); // filter(kernel) depth + int kH = INT_ARG(1) > 0 + ? INT_ARG(1) + : static_cast(weights->sizeAt(1)); // filter(kernel) height + int kW = INT_ARG(2) > 0 + ? INT_ARG(2) + : static_cast(weights->sizeAt(2)); // filter(kernel) width + int sD = INT_ARG(3); // strides depth + int sH = INT_ARG(4); // strides height + int sW = INT_ARG(5); // strides width + int pD = INT_ARG(6); // paddings depth + int pH = INT_ARG(7); // paddings height + int pW = INT_ARG(8); // paddings width + int dD = INT_ARG(9); // dilations depth + int dH = INT_ARG(10); // dilations height + int dW = INT_ARG(11); // dilations width + int paddingMode = INT_ARG(12); // 0-SAME, 1-VALID + int isNCDHW = + block.numI() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW + int wFormat = block.numI() > 14 + ? INT_ARG(14) + : 0; // 0 - [kD, kH, kW, iC, oC], 1 - [oC, iC, kD, kH, kW], + // 2 - [oC, kD, kH, kW, iC] + + int bS, iC, iD, iH, iW, oC, oD, oH, + oW; // batch size, input channels, input depth/height/width, output + // channels, output depth/height/width; + int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv3d( + isNCDHW, wFormat, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, + indIOioC, indIOioD, indWiC, indWoC, indWkD); + + std::vector expectedWeightsShape = + ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, iC, oC); + REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, + "CUSTOM CONV3D MKLDNN OP: wrong shape of weights array, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), + ShapeUtils::shapeAsString(weights).c_str()); + if (bias) + REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, + "CUSTOM CONV3D MKLDNN OP: wrong shape of array with biases, " + "expected rank, length: <=2, %i, but got %i, %i instead !", + oC, bias->rankOf(), bias->lengthOf()); + + if (paddingMode) // SAME + ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, + kW, sD, sH, sW, dD, dH, dW); + + conv3dMKLDNN(input, weights, bias, output, kD, kH, kW, sD, sH, sW, pD, pH, pW, + dD, dH, dW, paddingMode, isNCDHW, wFormat); + + return Status::OK(); } PLATFORM_CHECK(conv3dnew, ENGINE_CPU) { - auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) - auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC] always - auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] - auto output = OUTPUT_VARIABLE(0); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW) - - return block.isUseMKLDNN() && sd::MKLDNNStream::isSupported({input, weights, bias, output}); + auto input = INPUT_VARIABLE( + 0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) + auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC] always + auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] + auto output = OUTPUT_VARIABLE( + 0); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW) + + return block.isUseMKLDNN() && + sd::MKLDNNStream::isSupported({input, weights, bias, output}); } ////////////////////////////////////////////////////////////////////// PLATFORM_IMPL(conv3dnew_bp, ENGINE_CPU) { - - auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) - auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC] - auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] - auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next - - auto gradI = OUTPUT_NULLIFIED(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon - auto gradW = OUTPUT_NULLIFIED(1); // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC] - auto gradB = block.width() > 3 ? OUTPUT_NULLIFIED(2) : nullptr; // [oC] - - REQUIRE_TRUE(input->rankOf() == 5, 0, "CUSTOM CONV3D_BP MKLDNN OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf()); - REQUIRE_TRUE(weights->rankOf() == 5, 0, "CUSTOM CONV3D_BP MKLDNN OP: rank of weights array must be equal to 5, but got %i instead !", weights->rankOf()); - REQUIRE_TRUE(gradO->rankOf() == 5, 0, "CUSTOM CONV3D_BP MKLDNN OP: rank of output gradients (next epsilon) array must be equal to 5, but got %i instead !", gradO->rankOf()); - - int kD = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(weights->sizeAt(0));// filter(kernel) depth - int kH = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast(weights->sizeAt(1));// filter(kernel) height - int kW = INT_ARG(2) > 0 ? INT_ARG(2) : static_cast(weights->sizeAt(2));// filter(kernel) width - int sD = INT_ARG(3); // strides depth - int sH = INT_ARG(4); // strides height - int sW = INT_ARG(5); // strides width - int pD = INT_ARG(6); // paddings depth - int pH = INT_ARG(7); // paddings height - int pW = INT_ARG(8); // paddings width - int dD = INT_ARG(9); // dilations depth - int dH = INT_ARG(10); // dilations height - int dW = INT_ARG(11); // dilations width - int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID - int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW - int wFormat = block.getIArguments()->size() > 14 ? INT_ARG(14) : 0; // 0 - [kD, kH, kW, iC, oC], 1 - [oC, iC, kD, kH, kW], 2 - [oC, kD, kH, kW, iC] - - int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; - int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, wFormat, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); - - if(paddingMode) // SAME - ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW); - - int trueoD, trueoH, trueoW; // true output depth/height/width - ConvolutionUtils::calcOutSizePool3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, paddingMode); - - std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx( {bS, oC, trueoD, trueoH, trueoW, 0, indIOioC, indIOioD, indIOioD + 1, indIOioD + 2}); - std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, iC, oC); - REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM CONV3D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); - REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV3D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); - if (bias) - REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV3D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); - - conv3dBpMKLDNN(input, weights, bias, gradO, gradI, gradW, gradB, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, paddingMode, isNCDHW, wFormat); - - return Status::OK(); + auto input = INPUT_VARIABLE( + 0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) + auto weights = INPUT_VARIABLE( + 1); // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC] + auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] + auto gradO = + block.width() > 3 + ? INPUT_VARIABLE(3) + : INPUT_VARIABLE(2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, + // oH, oW] (NCDHW), epsilon_next + + auto gradI = OUTPUT_NULLIFIED(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, + // iD, iH, iW] (NCDHW), epsilon + auto gradW = OUTPUT_NULLIFIED( + 1); // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC] + auto gradB = block.width() > 3 ? OUTPUT_NULLIFIED(2) : nullptr; // [oC] + + REQUIRE_TRUE(input->rankOf() == 5, 0, + "CUSTOM CONV3D_BP MKLDNN OP: rank of input array must be equal " + "to 5, but got %i instead !", + input->rankOf()); + REQUIRE_TRUE(weights->rankOf() == 5, 0, + "CUSTOM CONV3D_BP MKLDNN OP: rank of weights array must be " + "equal to 5, but got %i instead !", + weights->rankOf()); + REQUIRE_TRUE(gradO->rankOf() == 5, 0, + "CUSTOM CONV3D_BP MKLDNN OP: rank of output gradients (next " + "epsilon) array must be equal to 5, but got %i instead !", + gradO->rankOf()); + + int kD = INT_ARG(0) > 0 + ? INT_ARG(0) + : static_cast(weights->sizeAt(0)); // filter(kernel) depth + int kH = INT_ARG(1) > 0 + ? INT_ARG(1) + : static_cast(weights->sizeAt(1)); // filter(kernel) height + int kW = INT_ARG(2) > 0 + ? INT_ARG(2) + : static_cast(weights->sizeAt(2)); // filter(kernel) width + int sD = INT_ARG(3); // strides depth + int sH = INT_ARG(4); // strides height + int sW = INT_ARG(5); // strides width + int pD = INT_ARG(6); // paddings depth + int pH = INT_ARG(7); // paddings height + int pW = INT_ARG(8); // paddings width + int dD = INT_ARG(9); // dilations depth + int dH = INT_ARG(10); // dilations height + int dW = INT_ARG(11); // dilations width + int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID + int isNCDHW = + block.numI() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW + int wFormat = block.numI() > 14 + ? INT_ARG(14) + : 0; // 0 - [kD, kH, kW, iC, oC], 1 - [oC, iC, kD, kH, kW], + // 2 - [oC, kD, kH, kW, iC] + + int bS, iC, iD, iH, iW, oC, oD, oH, + oW; // batch size, input channels, input depth/height/width, output + // channels, output depth/height/width; + int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv3d( + isNCDHW, wFormat, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, + indIOioC, indIOioD, indWiC, indWoC, indWkD); + + if (paddingMode) // SAME + ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, + kW, sD, sH, sW, dD, dH, dW); + + int trueoD, trueoH, trueoW; // true output depth/height/width + ConvolutionUtils::calcOutSizePool3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, + sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, + iW, paddingMode); + + std::vector expectedGradOShape = + ShapeUtils::composeShapeUsingDimsAndIdx({bS, oC, trueoD, trueoH, trueoW, + 0, indIOioC, indIOioD, + indIOioD + 1, indIOioD + 2}); + std::vector expectedWeightsShape = + ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, iC, oC); + REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, + "CUSTOM CONV3D_BP OP: wrong shape of output gradients (next " + "epsilon) array, expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedGradOShape).c_str(), + ShapeUtils::shapeAsString(gradO).c_str()); + REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, + "CUSTOM CONV3D_BP OP: wrong shape of weights array, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), + ShapeUtils::shapeAsString(weights).c_str()); + if (bias) + REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, + "CUSTOM CONV3D_BP OP: wrong shape of array with biases, " + "expected rank, length: <=2, %i, but got %i, %i instead !", + oC, bias->rankOf(), bias->lengthOf()); + + conv3dBpMKLDNN(input, weights, bias, gradO, gradI, gradW, gradB, kD, kH, kW, + sD, sH, sW, pD, pH, pW, dD, dH, dW, paddingMode, isNCDHW, + wFormat); + + return Status::OK(); } PLATFORM_CHECK(conv3dnew_bp, ENGINE_CPU) { - auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) - auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC] - auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] - auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next - - auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon - auto gradW = OUTPUT_VARIABLE(1); // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC] - auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC] - - return block.isUseMKLDNN() && - sd::MKLDNNStream::isSupported({input, weights, bias, gradO, gradI, gradW, gradB}); + auto input = INPUT_VARIABLE( + 0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) + auto weights = INPUT_VARIABLE( + 1); // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC] + auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] + auto gradO = + block.width() > 3 + ? INPUT_VARIABLE(3) + : INPUT_VARIABLE(2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, + // oH, oW] (NCDHW), epsilon_next + + auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, + // iD, iH, iW] (NCDHW), epsilon + auto gradW = OUTPUT_VARIABLE( + 1); // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC] + auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC] + + return block.isUseMKLDNN() && + sd::MKLDNNStream::isSupported( + {input, weights, bias, gradO, gradI, gradW, gradB}); } - - -} -} -} \ No newline at end of file +} // namespace platforms +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/deconv2d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/deconv2d.cpp index b7b58b409d53..1f385c721c7c 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/deconv2d.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/deconv2d.cpp @@ -18,436 +18,613 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include +#include #include +#include +#include #include -#include #include "mkldnnUtils.h" -#include - -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace platforms { ////////////////////////////////////////////////////////////////////////// -static void deconv2dMKLDNN(const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, - const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, - const int paddingMode, const bool isNCHW, const int wFormat) { - - // mkl supports weights format [oC, iC, kH, kW] only - - int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; - int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH); - - dnnl::memory::dims strides = { sH, sW }; - dnnl::memory::dims padding = { pH, pW }; - dnnl::memory::dims padding_r = { (iH - 1) * sH - oH + kH - pH, (iW - 1) * sW - oW + kW - pW }; - dnnl::memory::dims dilation = { dH-1, dW-1 }; - - std::vector permut; - if(0 == wFormat) - permut = {2,3,0,1}; // [kH, kW, oC, iC] -> [oC, iC, kH, kW] - else if(1 == wFormat) - permut = {1,0,2,3}; // [iC, oC, kH, kW] -> [oC, iC, kH, kW] - else - permut = {3,0,1,2}; // [iC, kH, kW, oC] -> [oC, iC, kH, kW] - - // input type - dnnl::memory::data_type xType; - if(input->dataType() == DataType::FLOAT32) - xType = dnnl::memory::data_type::f32; - else if(input->dataType() == DataType::HALF) - xType = dnnl::memory::data_type::f16; - else if(input->dataType() == DataType::UINT8) - xType = dnnl::memory::data_type::u8; - else - xType = dnnl::memory::data_type::s8; - - // weights type - dnnl::memory::data_type wType = xType; - if(xType == dnnl::memory::data_type::u8) - wType = dnnl::memory::data_type::s8; - - // output and bias type (have the same types) - dnnl::memory::data_type zType; - if(output->dataType() == DataType::FLOAT32) - zType = dnnl::memory::data_type::f32; - else if(output->dataType() == DataType::HALF) - zType = dnnl::memory::data_type::f16; - else if(output->dataType() == DataType::UINT8) - zType = dnnl::memory::data_type::u8; - else if(output->dataType() == DataType::INT8) - zType = dnnl::memory::data_type::s8; - else - zType = dnnl::memory::data_type::s32; - - dnnl::memory::format_tag xFormatMkl = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; - dnnl::memory::format_tag wFormatMkl = dnnl::memory::format_tag::oihw; - - dnnl::memory::dims xDims = {bS, iC, iH, iW}; - dnnl::memory::dims wDims = {oC, iC, kH, kW}; - dnnl::memory::dims zDims = {bS, oC, oH, oW}; - - // memory descriptors for arrays - - // input - dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any); - dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormatMkl); - mkldnnUtils::setBlockStrides(*input, x_user_md); - - // weights - dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any); - dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormatMkl); - mkldnnUtils::setBlockStrides(*weights, w_user_md, permut); - - // bias - dnnl::memory::desc b_mkl_md; - if(bias != nullptr) - b_mkl_md = dnnl::memory::desc({oC}, zType, dnnl::memory::format_tag::x); - - // output - dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, zType, dnnl::memory::format_tag::any); - dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, zType, xFormatMkl); - mkldnnUtils::setBlockStrides(*output, z_user_md); - - auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - - // operation primitive description - dnnl::deconvolution_forward::desc op_desc(dnnl::prop_kind::forward_inference, dnnl::algorithm::deconvolution_direct, - x_mkl_md, w_mkl_md, b_mkl_md, z_mkl_md, strides, dilation, padding, padding_r); - dnnl::deconvolution_forward::primitive_desc op_prim_desc(op_desc, engine); - - // arguments (memory buffers) necessary for calculations - std::unordered_map args; - - dnnl::stream stream(engine); - - // provide memory buffers and check whether reorder is required - - // input - mkldnnUtils::loadDataToMklStream(*input, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]); - - // weights - mkldnnUtils::loadDataToMklStream(*weights, engine, stream, w_user_md, op_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]); - - // bias - if(bias != nullptr) { - auto b_mkl_mem = dnnl::memory(b_mkl_md, engine, const_cast(bias->buffer())); - args[DNNL_ARG_BIAS] = b_mkl_mem; - } - - // output - auto z_user_mem = mkldnnUtils::loadDataToMklStream(*output, engine, stream, z_user_md, op_prim_desc.dst_desc(), args[DNNL_ARG_DST]); - - // run calculations - dnnl::deconvolution_forward(op_prim_desc).execute(stream, args); - - // reorder outputs if necessary - if (op_prim_desc.dst_desc() != z_user_mem.get_desc()) - dnnl::reorder(args[DNNL_ARG_DST], z_user_mem).execute(stream, args[DNNL_ARG_DST], z_user_mem); - - stream.wait(); - - // shape::printArray(z_mkl_mem.map_data(),8); +static void deconv2dMKLDNN(const NDArray* input, const NDArray* weights, + const NDArray* bias, NDArray* output, const int kH, + const int kW, const int sH, const int sW, + const int pH, const int pW, const int dH, + const int dW, const int paddingMode, + const bool isNCHW, const int wFormat) { + // mkl supports weights format [oC, iC, kH, kW] only + + int bS, iC, iH, iW, oC, oH, + oW; // batch size, input channels, input height/width, output channels, + // output height/width; + int indIOioC, indIiH, indWoC, indWiC, indWkH, + indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, + indIiH, indWoC, indWiC, indWkH, indOoH); + + dnnl::memory::dims strides = {sH, sW}; + dnnl::memory::dims padding = {pH, pW}; + dnnl::memory::dims padding_r = {(iH - 1) * sH - oH + kH - pH, + (iW - 1) * sW - oW + kW - pW}; + dnnl::memory::dims dilation = {dH - 1, dW - 1}; + + std::vector permut; + if (0 == wFormat) + permut = {2,3,0,1}; // [kH, kW, oC, iC] -> [oC, iC, kH, kW] + else if (1 == wFormat) + permut = {1,0,2,3}; // [iC, oC, kH, kW] -> [oC, iC, kH, kW] + else + permut = {3,0,1,2}; // [iC, kH, kW, oC] -> [oC, iC, kH, kW] + + + // input type + dnnl::memory::data_type xType; + if (input->dataType() == DataType::FLOAT32) + xType = dnnl::memory::data_type::f32; + else if (input->dataType() == DataType::HALF) + xType = dnnl::memory::data_type::f16; + else if (input->dataType() == DataType::UINT8) + xType = dnnl::memory::data_type::u8; + else + xType = dnnl::memory::data_type::s8; + + // weights type + dnnl::memory::data_type wType = xType; + if (xType == dnnl::memory::data_type::u8) wType = dnnl::memory::data_type::s8; + + // output and bias type (have the same types) + dnnl::memory::data_type zType; + if (output->dataType() == DataType::FLOAT32) + zType = dnnl::memory::data_type::f32; + else if (output->dataType() == DataType::HALF) + zType = dnnl::memory::data_type::f16; + else if (output->dataType() == DataType::UINT8) + zType = dnnl::memory::data_type::u8; + else if (output->dataType() == DataType::INT8) + zType = dnnl::memory::data_type::s8; + else + zType = dnnl::memory::data_type::s32; + + dnnl::memory::format_tag xFormatMkl = + isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; + dnnl::memory::format_tag wFormatMkl = dnnl::memory::format_tag::oihw; + + dnnl::memory::dims xDims = {bS, iC, iH, iW}; + dnnl::memory::dims wDims = {oC, iC, kH, kW}; + dnnl::memory::dims zDims = {bS, oC, oH, oW}; + + // memory descriptors for arrays + + // input + dnnl::memory::desc x_mkl_md = + dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any); + dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormatMkl); + mkldnnUtils::setBlockStrides(*input, x_user_md); + + // weights + dnnl::memory::desc w_mkl_md = + dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any); + dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormatMkl); + mkldnnUtils::setBlockStrides(*weights, + w_user_md, permut); + + // bias + dnnl::memory::desc b_mkl_md; + if (bias != nullptr) + b_mkl_md = dnnl::memory::desc({oC}, zType, dnnl::memory::format_tag::x); + + // output + dnnl::memory::desc z_mkl_md = + dnnl::memory::desc(zDims, zType, dnnl::memory::format_tag::any); + dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, zType, xFormatMkl); + mkldnnUtils::setBlockStrides(*output, z_user_md); + + auto engine = + mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); + + // operation primitive description + dnnl::deconvolution_forward::desc op_desc( + dnnl::prop_kind::forward_inference, dnnl::algorithm::deconvolution_direct, + x_mkl_md, w_mkl_md, b_mkl_md, z_mkl_md, strides, dilation, padding, + padding_r); + dnnl::deconvolution_forward::primitive_desc op_prim_desc(op_desc, engine); + + // arguments (memory buffers) necessary for calculations + std::unordered_map args; + + dnnl::stream stream(engine); + + // provide memory buffers and check whether reorder is required + + // input + mkldnnUtils::loadDataToMklStream(*input, engine, stream, x_user_md, + op_prim_desc.src_desc(), args[DNNL_ARG_SRC]); + + // weights + mkldnnUtils::loadDataToMklStream(*weights, engine, stream, w_user_md, + op_prim_desc.weights_desc(), + args[DNNL_ARG_WEIGHTS]); + + // bias + if (bias != nullptr) { + auto b_mkl_mem = + dnnl::memory(b_mkl_md, engine, const_cast(bias->buffer())); + args[DNNL_ARG_BIAS] = b_mkl_mem; + } + + // output + auto z_user_mem = mkldnnUtils::loadDataToMklStream(*output, engine, stream, z_user_md, op_prim_desc.dst_desc(), + args[DNNL_ARG_DST] ); + + // run calculations + dnnl::deconvolution_forward(op_prim_desc).execute(stream, args); + + // reorder outputs if necessary + if (op_prim_desc.dst_desc() != z_user_mem.get_desc()) + dnnl::reorder(args[DNNL_ARG_DST], z_user_mem).execute(stream, args[DNNL_ARG_DST], z_user_mem); + + stream.wait(); + + // shape::printArray(z_mkl_mem.map_data(),8); } ////////////////////////////////////////////////////////////////////////// -static void deconv2dBpMKLDNN(const NDArray* input, const NDArray* weights, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, - const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, - const int paddingMode, const bool isNCHW, const int wFormat) { - - // mkl supports weights/gradW in [oC, iC, kH, kW] format only - - int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; - int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH); - - dnnl::memory::dims strides = { sH, sW }; - dnnl::memory::dims padding = { pH, pW }; - dnnl::memory::dims padding_r = { (iH - 1) * sH - oH + kH - pH, (iW - 1) * sW - oW + kW - pW }; - dnnl::memory::dims dilation = { dH-1, dW-1 }; - - std::vector permut; - if(0 == wFormat) - permut = {2,3,0,1}; // [kH, kW, oC, iC] -> [oC, iC, kH, kW] - else if(1 == wFormat) - permut = {1,0,2,3}; // [iC, oC, kH, kW] -> [oC, iC, kH, kW] - else - permut = {3,0,1,2}; // [iC, kH, kW, oC] -> [oC, iC, kH, kW] - - // input type - dnnl::memory::data_type xType = input->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16; - // weights type - dnnl::memory::data_type wType = weights->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16; - // gradO type - dnnl::memory::data_type gradOType = gradO->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16; - // gradI type - dnnl::memory::data_type gradIType = gradI->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16; - // gradW type - dnnl::memory::data_type gradWType = gradW->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16; - // gradB type - dnnl::memory::data_type gradBType = gradB != nullptr ? (gradB->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16) : dnnl::memory::data_type::f32; - - dnnl::memory::format_tag xFormatMkl = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; - dnnl::memory::format_tag wFormatMkl = dnnl::memory::format_tag::oihw; - - dnnl::memory::dims xDims = {bS, iC, iH, iW}; - dnnl::memory::dims wDims = {oC, iC, kH, kW}; - dnnl::memory::dims zDims = {bS, oC, oH, oW}; - - // memory descriptors for arrays - - // input - dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any); - dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormatMkl); - mkldnnUtils::setBlockStrides(*input, x_user_md); - - // weights - dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any); - dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormatMkl); - mkldnnUtils::setBlockStrides(*weights, w_user_md, permut); - - // gradO - dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any); - dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xFormatMkl); - mkldnnUtils::setBlockStrides(*gradO, gradO_user_md); - - // gradI - dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any); - dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xFormatMkl); - mkldnnUtils::setBlockStrides(*gradI, gradI_user_md); - - // gradW - dnnl::memory::desc gradW_mkl_md = dnnl::memory::desc(wDims, gradWType, dnnl::memory::format_tag::any); - dnnl::memory::desc gradW_user_md = dnnl::memory::desc(wDims, gradWType, wFormatMkl); - mkldnnUtils::setBlockStrides(*gradW, gradW_user_md, permut); - - // gradB - dnnl::memory::desc gradB_mkl_md; - if(gradB != nullptr) - gradB_mkl_md = dnnl::memory::desc({oC}, gradBType, dnnl::memory::format_tag::x); - - auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - - // forward primitive description - dnnl::deconvolution_forward::desc op_ff_desc(dnnl::prop_kind::forward_inference, dnnl::algorithm::deconvolution_direct, x_mkl_md, w_mkl_md, gradB_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r); - dnnl::deconvolution_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine); - - // backward data primitive description - dnnl::deconvolution_backward_data::desc op_data_bp_desc(dnnl::algorithm::deconvolution_direct, gradI_mkl_md, w_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r); - dnnl::deconvolution_backward_data::primitive_desc op_data_bp_prim_desc(op_data_bp_desc, engine, op_ff_prim_desc); - - // backward weights primitive description - dnnl::deconvolution_backward_weights::desc op_weights_bp_desc(dnnl::algorithm::deconvolution_direct, x_mkl_md, gradW_mkl_md, gradB_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r); - dnnl::deconvolution_backward_weights::primitive_desc op_weights_bp_prim_desc(op_weights_bp_desc, engine, op_ff_prim_desc); - - // arguments (memory buffers) necessary for calculations - std::unordered_map args; - - dnnl::stream stream(engine); - - // provide memory buffers and check whether reorder is required - - // input - mkldnnUtils::loadDataToMklStream(*input, engine, stream, x_user_md, op_weights_bp_prim_desc.src_desc(), args[DNNL_ARG_SRC]); - - // weights - mkldnnUtils::loadDataToMklStream(*weights, engine, stream, w_user_md, op_data_bp_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]); - - // gradO - auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, const_cast(gradO->buffer())); - const bool gradOReorderW = op_weights_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc(); - const bool gradOReorderD = op_data_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc(); - auto gradO_mkl_memW = gradOReorderW ? dnnl::memory(op_weights_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem; - auto gradO_mkl_memD = gradOReorderD ? dnnl::memory(op_data_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem; - if (gradOReorderW) - dnnl::reorder(gradO_user_mem, gradO_mkl_memW).execute(stream, gradO_user_mem, gradO_mkl_memW); - if (gradOReorderD) - dnnl::reorder(gradO_user_mem, gradO_mkl_memD).execute(stream, gradO_user_mem, gradO_mkl_memD); - args[DNNL_ARG_DIFF_DST] = gradO_mkl_memD; - - // gradI - auto gradI_user_mem = mkldnnUtils::loadDataToMklStream(*gradI, engine, stream, gradI_user_md, op_data_bp_prim_desc.diff_src_desc(), args[DNNL_ARG_DIFF_SRC]); - - // gradW - auto gradW_user_mem = mkldnnUtils::loadDataToMklStream(*gradW, engine, stream, gradW_user_md, op_weights_bp_prim_desc.diff_weights_desc(), args[DNNL_ARG_DIFF_WEIGHTS]); - - // gradB - if(gradB != nullptr) { - auto gradB_mkl_mem = dnnl::memory(gradB_mkl_md, engine, gradB->buffer()); - args[DNNL_ARG_DIFF_BIAS] = gradB_mkl_mem; - } - - // run backward data calculations - dnnl::deconvolution_backward_data(op_data_bp_prim_desc).execute(stream, args); - - if(gradOReorderW || gradOReorderD) - args[DNNL_ARG_DIFF_DST] = gradO_mkl_memW; - - // run backward weights calculations - dnnl::deconvolution_backward_weights(op_weights_bp_prim_desc).execute(stream, args); - - // reorder gradI if necessary - if (op_data_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc()) - dnnl::reorder(args[DNNL_ARG_DIFF_SRC], gradI_user_mem).execute(stream, args[DNNL_ARG_DIFF_SRC], gradI_user_mem); - if (op_weights_bp_prim_desc.diff_weights_desc() != gradW_user_mem.get_desc()) - dnnl::reorder(args[DNNL_ARG_DIFF_WEIGHTS], gradW_user_mem).execute(stream, args[DNNL_ARG_DIFF_WEIGHTS], gradW_user_mem); - - stream.wait(); - - // shape::printArray(z_mkl_mem.map_data(),8); +static void deconv2dBpMKLDNN(const NDArray* input, const NDArray* weights, + const NDArray* gradO, NDArray* gradI, + NDArray* gradW, NDArray* gradB, const int kH, + const int kW, const int sH, const int sW, + const int pH, const int pW, const int dH, + const int dW, const int paddingMode, + const bool isNCHW, const int wFormat) { + // mkl supports weights/gradW in [oC, iC, kH, kW] format only + + int bS, iC, iH, iW, oC, oH, + oW; // batch size, input channels, input height/width, output channels, + // output height/width; + int indIOioC, indIiH, indWoC, indWiC, indWkH, + indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, + indIiH, indWoC, indWiC, indWkH, indOoH); + + dnnl::memory::dims strides = {sH, sW}; + dnnl::memory::dims padding = {pH, pW}; + dnnl::memory::dims padding_r = {(iH - 1) * sH - oH + kH - pH, + (iW - 1) * sW - oW + kW - pW}; + dnnl::memory::dims dilation = {dH - 1, dW - 1}; + + std::vector permut; + if (0 == wFormat) + permut = {2,3,0,1}; // [kH, kW, oC, iC] -> [oC, iC, kH, kW] + else if (1 == wFormat) + permut = {1,0,2,3}; // [iC, oC, kH, kW] -> [oC, iC, kH, kW] + else + permut = {3,0,1,2}; // [iC, kH, kW, oC] -> [oC, iC, kH, kW] + + + // input type + dnnl::memory::data_type xType = input->dataType() == DataType::FLOAT32 + ? dnnl::memory::data_type::f32 + : dnnl::memory::data_type::bf16; + // weights type + dnnl::memory::data_type wType = weights->dataType() == DataType::FLOAT32 + ? dnnl::memory::data_type::f32 + : dnnl::memory::data_type::bf16; + // gradO type + dnnl::memory::data_type gradOType = gradO->dataType() == DataType::FLOAT32 + ? dnnl::memory::data_type::f32 + : dnnl::memory::data_type::bf16; + // gradI type + dnnl::memory::data_type gradIType = gradI->dataType() == DataType::FLOAT32 + ? dnnl::memory::data_type::f32 + : dnnl::memory::data_type::bf16; + // gradW type + dnnl::memory::data_type gradWType = gradW->dataType() == DataType::FLOAT32 + ? dnnl::memory::data_type::f32 + : dnnl::memory::data_type::bf16; + // gradB type + dnnl::memory::data_type gradBType = + gradB != nullptr ? (gradB->dataType() == DataType::FLOAT32 + ? dnnl::memory::data_type::f32 + : dnnl::memory::data_type::bf16) + : dnnl::memory::data_type::f32; + + dnnl::memory::format_tag xFormatMkl = + isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; + dnnl::memory::format_tag wFormatMkl = dnnl::memory::format_tag::oihw; + + dnnl::memory::dims xDims = {bS, iC, iH, iW}; + dnnl::memory::dims wDims = {oC, iC, kH, kW}; + dnnl::memory::dims zDims = {bS, oC, oH, oW}; + + // memory descriptors for arrays + + // input + dnnl::memory::desc x_mkl_md = + dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any); + dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormatMkl); + mkldnnUtils::setBlockStrides(*input, x_user_md); + + // weights + dnnl::memory::desc w_mkl_md = + dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any); + dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormatMkl); + mkldnnUtils::setBlockStrides(*weights, + w_user_md, permut); + + // gradO + dnnl::memory::desc gradO_mkl_md = + dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any); + dnnl::memory::desc gradO_user_md = + dnnl::memory::desc(zDims, gradOType, xFormatMkl); + mkldnnUtils::setBlockStrides(*gradO, gradO_user_md); + + // gradI + dnnl::memory::desc gradI_mkl_md = + dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any); + dnnl::memory::desc gradI_user_md = + dnnl::memory::desc(xDims, gradIType, xFormatMkl); + mkldnnUtils::setBlockStrides(*gradI, gradI_user_md); + + // gradW + dnnl::memory::desc gradW_mkl_md = + dnnl::memory::desc(wDims, gradWType, dnnl::memory::format_tag::any); + dnnl::memory::desc gradW_user_md = + dnnl::memory::desc(wDims, gradWType, wFormatMkl); + mkldnnUtils::setBlockStrides(*gradW, + gradW_user_md, permut); + + // gradB + dnnl::memory::desc gradB_mkl_md; + if (gradB != nullptr) + gradB_mkl_md = + dnnl::memory::desc({oC}, gradBType, dnnl::memory::format_tag::x); + + auto engine = + mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); + + // forward primitive description + dnnl::deconvolution_forward::desc op_ff_desc( + dnnl::prop_kind::forward_inference, dnnl::algorithm::deconvolution_direct, + x_mkl_md, w_mkl_md, gradB_mkl_md, gradO_mkl_md, strides, dilation, + padding, padding_r); + dnnl::deconvolution_forward::primitive_desc op_ff_prim_desc(op_ff_desc, + engine); + + // backward data primitive description + dnnl::deconvolution_backward_data::desc op_data_bp_desc( + dnnl::algorithm::deconvolution_direct, gradI_mkl_md, w_mkl_md, + gradO_mkl_md, strides, dilation, padding, padding_r); + dnnl::deconvolution_backward_data::primitive_desc op_data_bp_prim_desc( + op_data_bp_desc, engine, op_ff_prim_desc); + + // backward weights primitive description + dnnl::deconvolution_backward_weights::desc op_weights_bp_desc( + dnnl::algorithm::deconvolution_direct, x_mkl_md, gradW_mkl_md, + gradB_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r); + dnnl::deconvolution_backward_weights::primitive_desc op_weights_bp_prim_desc( + op_weights_bp_desc, engine, op_ff_prim_desc); + + // arguments (memory buffers) necessary for calculations + std::unordered_map args; + + dnnl::stream stream(engine); + + // provide memory buffers and check whether reorder is required + + // input + mkldnnUtils::loadDataToMklStream(*input, engine, stream, x_user_md, + op_weights_bp_prim_desc.src_desc(), + args[DNNL_ARG_SRC]); + + // weights + mkldnnUtils::loadDataToMklStream(*weights, engine, stream, w_user_md, + op_data_bp_prim_desc.weights_desc(), + args[DNNL_ARG_WEIGHTS]); + + // gradO + auto gradO_user_mem = + dnnl::memory(gradO_user_md, engine, const_cast(gradO->buffer())); + const bool gradOReorderW = + op_weights_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc(); + const bool gradOReorderD = + op_data_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc(); + auto gradO_mkl_memW = + gradOReorderW + ? dnnl::memory(op_weights_bp_prim_desc.diff_dst_desc(), engine) + : gradO_user_mem; + auto gradO_mkl_memD = + gradOReorderD ? dnnl::memory(op_data_bp_prim_desc.diff_dst_desc(), engine) + : gradO_user_mem; + if (gradOReorderW) + dnnl::reorder(gradO_user_mem, gradO_mkl_memW) + .execute(stream, gradO_user_mem, gradO_mkl_memW); + if (gradOReorderD) + dnnl::reorder(gradO_user_mem, gradO_mkl_memD) + .execute(stream, gradO_user_mem, gradO_mkl_memD); + args[DNNL_ARG_DIFF_DST] = gradO_mkl_memD; + + // gradI + auto gradI_user_mem = mkldnnUtils::loadDataToMklStream(*gradI, engine, stream, gradI_user_md, + op_data_bp_prim_desc.diff_src_desc() , + args[DNNL_ARG_DIFF_SRC] ); + + // gradW + auto gradW_user_mem = mkldnnUtils::loadDataToMklStream(*gradW, engine, stream, gradW_user_md, + op_weights_bp_prim_desc.diff_weights_desc() , + args[DNNL_ARG_DIFF_WEIGHTS] ); + + // gradB + if (gradB != nullptr) { + auto gradB_mkl_mem = dnnl::memory(gradB_mkl_md, engine, gradB->buffer()); + args[DNNL_ARG_DIFF_BIAS] = gradB_mkl_mem; + } + + // run backward data calculations + dnnl::deconvolution_backward_data(op_data_bp_prim_desc).execute(stream, args); + + if (gradOReorderW || gradOReorderD) args[DNNL_ARG_DIFF_DST] = gradO_mkl_memW; + + // run backward weights calculations + dnnl::deconvolution_backward_weights(op_weights_bp_prim_desc) + .execute(stream, args); + + // reorder gradI if necessary + if (op_data_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc()) + dnnl::reorder(args[DNNL_ARG_DIFF_SRC], gradI_user_mem) + .execute(stream, args[DNNL_ARG_DIFF_SRC], gradI_user_mem); + if (op_weights_bp_prim_desc.diff_weights_desc() != gradW_user_mem.get_desc()) + dnnl::reorder(args[DNNL_ARG_DIFF_WEIGHTS], gradW_user_mem) + .execute(stream, args[DNNL_ARG_DIFF_WEIGHTS], gradW_user_mem); + + stream.wait(); + + // shape::printArray(z_mkl_mem.map_data(),8); } - ////////////////////////////////////////////////////////////////////////// PLATFORM_IMPL(deconv2d, ENGINE_CPU) { - - auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - auto weights = INPUT_VARIABLE(1); // [kH, kW, oC, iC], [iC, oC, kH, kW], [iC, kH, kW, oC] - auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] - - auto output = OUTPUT_VARIABLE(0); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) - - REQUIRE_TRUE(input->rankOf() == 4, 0, "CUSTOM DECONV2D_MKLDNN OP: rank of input array must be equal to 4, but got %i instead !", input->rankOf()); - REQUIRE_TRUE(weights->rankOf() == 4, 0, "CUSTOM DECONV2D_MKLDNN OP: rank of weights array must be equal to 4, but got %i instead !", weights->rankOf()); - - int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(weights->sizeAt(0));// filter(kernel) height - int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast(weights->sizeAt(1));// filter(kernel) width - int sH = INT_ARG(2); // strides height - int sW = INT_ARG(3); // strides width - int pH = INT_ARG(4); // paddings height - int pW = INT_ARG(5); // paddings width - int dH = INT_ARG(6); // dilations height - int dW = INT_ARG(7); // dilations width - int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME - int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC - int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, oC, iC], 1 - [iC, oC, kH, kW], 2 - [iC, kH, kW, oC] - - int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; - int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH); - - std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, oC, iC); - REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DECONV2D_MKLDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); - if (bias) - REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DECONV2D_MKLDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); - - if(paddingMode){ // SAME - //Note: we're intentionally swapping iH and oH, to calculated the padding for a"normal" conv (not deconv) forward pass - ConvolutionUtils::calcPadding2D(pH, pW, iH, iW, oH, oW, kH, kW, sH, sW, dH, dW); - } - - deconv2dMKLDNN(input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat); - - return Status::OK(); + auto input = + INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + auto weights = INPUT_VARIABLE( + 1); // [kH, kW, oC, iC], [iC, oC, kH, kW], [iC, kH, kW, oC] + auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] + + auto output = + OUTPUT_VARIABLE(0); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) + + REQUIRE_TRUE(input->rankOf() == 4, 0, + "CUSTOM DECONV2D_MKLDNN OP: rank of input array must be equal " + "to 4, but got %i instead !", + input->rankOf()); + REQUIRE_TRUE(weights->rankOf() == 4, 0, + "CUSTOM DECONV2D_MKLDNN OP: rank of weights array must be equal " + "to 4, but got %i instead !", + weights->rankOf()); + + int kH = INT_ARG(0) > 0 + ? INT_ARG(0) + : static_cast(weights->sizeAt(0)); // filter(kernel) height + int kW = INT_ARG(1) > 0 + ? INT_ARG(1) + : static_cast(weights->sizeAt(1)); // filter(kernel) width + int sH = INT_ARG(2); // strides height + int sW = INT_ARG(3); // strides width + int pH = INT_ARG(4); // paddings height + int pW = INT_ARG(5); // paddings width + int dH = INT_ARG(6); // dilations height + int dW = INT_ARG(7); // dilations width + int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME + int isNCHW = + block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC + int wFormat = + block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, oC, iC], 1 - [iC, + // oC, kH, kW], 2 - [iC, kH, kW, oC] + + int bS, iC, iH, iW, oC, oH, + oW; // batch size, input channels, input height/width, output channels, + // output height/width; + int indIOioC, indIiH, indWoC, indWiC, indWkH, + indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, + indIiH, indWoC, indWiC, indWkH, indOoH); + + std::vector expectedWeightsShape = + ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, oC, iC); + REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, + "CUSTOM DECONV2D_MKLDNN OP: wrong shape of weights array, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), + ShapeUtils::shapeAsString(weights).c_str()); + if (bias) + REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, + "CUSTOM DECONV2D_MKLDNN OP: wrong shape of array with biases, " + "expected rank, length: <=2, %i, but got %i, %i instead !", + oC, bias->rankOf(), bias->lengthOf()); + + if (paddingMode) { // SAME + // Note: we're intentionally swapping iH and oH, to calculated the padding + // for a"normal" conv (not deconv) forward pass + ConvolutionUtils::calcPadding2D(pH, pW, iH, iW, oH, oW, kH, kW, sH, sW, dH, + dW); + } + + deconv2dMKLDNN(input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, + paddingMode, isNCHW, wFormat); + + return Status::OK(); } PLATFORM_CHECK(deconv2d, ENGINE_CPU) { - auto input = INPUT_VARIABLE(0); - auto weights = INPUT_VARIABLE(1); - auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; - - auto output = INPUT_VARIABLE(0); - - int dH = INT_ARG(6); // dilations height - int dW = INT_ARG(7); // dilations width - int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME - - const DataType xType = input->dataType(); - const DataType wType = weights->dataType(); - const DataType zType = output->dataType(); - const DataType bType = bias != nullptr ? bias->dataType() : zType; - - return block.isUseMKLDNN() && (dH <= 1 && dW <= 1 && !paddingMode) && - ( - (xType==DataType::FLOAT32 && wType==DataType::FLOAT32 && bType==DataType::FLOAT32 && zType==DataType::FLOAT32) || - ((xType==DataType::UINT8 || xType==DataType::INT8) && wType==DataType::INT8 && (zType==DataType::UINT8 || zType==DataType::INT8 || zType==DataType::INT32 || zType==DataType::FLOAT32) && bType == zType) - ); + auto input = INPUT_VARIABLE(0); + auto weights = INPUT_VARIABLE(1); + auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; + + auto output = INPUT_VARIABLE(0); + + int dH = INT_ARG(6); // dilations height + int dW = INT_ARG(7); // dilations width + int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME + + const DataType xType = input->dataType(); + const DataType wType = weights->dataType(); + const DataType zType = output->dataType(); + const DataType bType = bias != nullptr ? bias->dataType() : zType; + + return block.isUseMKLDNN() && (dH <= 1 && dW <= 1 && !paddingMode) && + ((xType == DataType::FLOAT32 && wType == DataType::FLOAT32 && + bType == DataType::FLOAT32 && zType == DataType::FLOAT32) || + ((xType == DataType::UINT8 || xType == DataType::INT8) && + wType == DataType::INT8 && + (zType == DataType::UINT8 || zType == DataType::INT8 || + zType == DataType::INT32 || zType == DataType::FLOAT32) && + bType == zType)); } - ////////////////////////////////////////////////////////////////////////// PLATFORM_IMPL(deconv2d_bp, ENGINE_CPU) { - - auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCDHW) - auto weights = INPUT_VARIABLE(1); // [kH, kW, oC, iC], [iC, oC, kH, kW], [iC, kH, kW, oC] - auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] - auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next - - auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCDHW), gradI - auto gradW = OUTPUT_VARIABLE(1); // [kH, kW, oC, iC], [iC, oC, kH, kW], [iC, kH, kW, oC] - auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC] - - REQUIRE_TRUE(input->rankOf() == 4, 0, "CUSTOM DECONV2D_MKLDNN_BP OP: rank of input array must be equal to 4, but got %i instead !", input->rankOf()); - REQUIRE_TRUE(weights->rankOf() == 4, 0, "CUSTOM DECONV2D_MKLDNN_BP OP: rank of weights array must be equal to 4 , but got %i instead !", weights->rankOf()); - REQUIRE_TRUE(gradO->rankOf() == 4, 0, "CUSTOM DECONV2D_MKLDNN_BP OP: rank of output gradients (next epsilon) array must be equal to 4, but got %i instead !", gradO->rankOf()); - - - int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(weights->sizeAt(0));// filter(kernel) height - int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast(weights->sizeAt(1));// filter(kernel) width - int sH = INT_ARG(2); // strides height - int sW = INT_ARG(3); // strides width - int pH = INT_ARG(4); // paddings height - int pW = INT_ARG(5); // paddings width - int dH = INT_ARG(6); // dilations height - int dW = INT_ARG(7); // dilations width - int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME - int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW - int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, oC, iC], 1 - [iC, oC, kH, kW], 2 - [iC, kH, kW, oC] - - int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; - int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH); - - int trueoH, trueoW; // true output height, width - ConvolutionUtils::calcOutSizeDeconv2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, paddingMode); - - std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1}); - std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, oC, iC); - REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM DECONV2D_MKLDNN_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); - REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DECONV2D_MKLDNN_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); - if(bias) - REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DECONV2D_MKLDNN_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); - - if(paddingMode){ // SAME - //Note: we're intentionally swapping iH and oH, to calculated the padding for a"normal" conv (not deconv) forward pass - ConvolutionUtils::calcPadding2D(pH, pW, iH, iW, oH, oW, kH, kW, sH, sW, dH, dW); - } - - deconv2dBpMKLDNN(input, weights, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat); - - return Status::OK(); + auto input = + INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCDHW) + auto weights = INPUT_VARIABLE( + 1); // [kH, kW, oC, iC], [iC, oC, kH, kW], [iC, kH, kW, oC] + auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] + auto gradO = block.width() > 3 + ? INPUT_VARIABLE(3) + : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, + // oH, oW] (NCDHW), epsilon_next + + auto gradI = OUTPUT_VARIABLE( + 0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCDHW), gradI + auto gradW = OUTPUT_VARIABLE( + 1); // [kH, kW, oC, iC], [iC, oC, kH, kW], [iC, kH, kW, oC] + auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC] + + REQUIRE_TRUE(input->rankOf() == 4, 0, + "CUSTOM DECONV2D_MKLDNN_BP OP: rank of input array must be " + "equal to 4, but got %i instead !", + input->rankOf()); + REQUIRE_TRUE(weights->rankOf() == 4, 0, + "CUSTOM DECONV2D_MKLDNN_BP OP: rank of weights array must be " + "equal to 4 , but got %i instead !", + weights->rankOf()); + REQUIRE_TRUE(gradO->rankOf() == 4, 0, + "CUSTOM DECONV2D_MKLDNN_BP OP: rank of output gradients (next " + "epsilon) array must be equal to 4, but got %i instead !", + gradO->rankOf()); + + int kH = INT_ARG(0) > 0 + ? INT_ARG(0) + : static_cast(weights->sizeAt(0)); // filter(kernel) height + int kW = INT_ARG(1) > 0 + ? INT_ARG(1) + : static_cast(weights->sizeAt(1)); // filter(kernel) width + int sH = INT_ARG(2); // strides height + int sW = INT_ARG(3); // strides width + int pH = INT_ARG(4); // paddings height + int pW = INT_ARG(5); // paddings width + int dH = INT_ARG(6); // dilations height + int dW = INT_ARG(7); // dilations width + int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME + int isNCHW = + block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW + int wFormat = + block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, oC, iC], 1 - [iC, + // oC, kH, kW], 2 - [iC, kH, kW, oC] + + int bS, iC, iH, iW, oC, oH, + oW; // batch size, input channels, input height/width, output channels, + // output height/width; + int indIOioC, indIiH, indWoC, indWiC, indWkH, + indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, + indIiH, indWoC, indWiC, indWkH, indOoH); + + int trueoH, trueoW; // true output height, width + ConvolutionUtils::calcOutSizeDeconv2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, + dH, dW, iH, iW, paddingMode); + + std::vector expectedGradOShape = + ShapeUtils::composeShapeUsingDimsAndIdx( + {bS, oC, trueoH, trueoW, 0, indIOioC, indOoH, indOoH + 1}); + std::vector expectedWeightsShape = + ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, oC, iC); + REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, + "CUSTOM DECONV2D_MKLDNN_BP OP: wrong shape of output gradients " + "(next epsilon) array, expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedGradOShape).c_str(), + ShapeUtils::shapeAsString(gradO).c_str()); + REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, + "CUSTOM DECONV2D_MKLDNN_BP OP: wrong shape of weights array, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), + ShapeUtils::shapeAsString(weights).c_str()); + if (bias) + REQUIRE_TRUE( + bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, + "CUSTOM DECONV2D_MKLDNN_BP OP: wrong shape of array with biases, " + "expected rank, length: <=2, %i, but got %i, %i instead !", + oC, bias->rankOf(), bias->lengthOf()); + + if (paddingMode) { // SAME + // Note: we're intentionally swapping iH and oH, to calculated the padding + // for a"normal" conv (not deconv) forward pass + ConvolutionUtils::calcPadding2D(pH, pW, iH, iW, oH, oW, kH, kW, sH, sW, dH, + dW); + } + + deconv2dBpMKLDNN(input, weights, gradO, gradI, gradW, gradB, kH, kW, sH, sW, + pH, pW, dH, dW, paddingMode, isNCHW, wFormat); + + return Status::OK(); } PLATFORM_CHECK(deconv2d_bp, ENGINE_CPU) { - auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCDHW) - auto weights = INPUT_VARIABLE(1); // [kH, kW, oC, iC], [iC, oC, kH, kW], [iC, kH, kW, oC] - auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] - auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next - - auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCDHW), gradI - auto gradW = OUTPUT_VARIABLE(1); // [kH, kW, oC, iC], [iC, oC, kH, kW], [iC, kH, kW, oC] - auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC] - - int dH = INT_ARG(6); // dilations height - int dW = INT_ARG(7); // dilations width - int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME - - const DataType xType = input->dataType(); - const DataType wType = weights->dataType(); - const DataType gradOType = gradO->dataType(); - - const DataType gradIType = gradI->dataType(); - const DataType gradWType = gradW->dataType(); - const DataType gradBType = gradB != nullptr ? gradB->dataType() : DataType::FLOAT32; - - return block.isUseMKLDNN() && (dH <= 1 && dW <= 1 && !paddingMode) && ((xType==DataType::FLOAT32 || xType==DataType::BFLOAT16) && (wType==DataType::FLOAT32 || wType==DataType::BFLOAT16) && (gradOType==DataType::FLOAT32 || gradOType==DataType::BFLOAT16) && (gradIType==DataType::FLOAT32 || gradIType==DataType::BFLOAT16) && (gradWType==DataType::FLOAT32 || gradWType==DataType::BFLOAT16) && (gradBType==DataType::FLOAT32 || gradBType==DataType::BFLOAT16) ); + auto input = + INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCDHW) + auto weights = INPUT_VARIABLE( + 1); // [kH, kW, oC, iC], [iC, oC, kH, kW], [iC, kH, kW, oC] + auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] + auto gradO = block.width() > 3 + ? INPUT_VARIABLE(3) + : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, + // oH, oW] (NCDHW), epsilon_next + + auto gradI = OUTPUT_VARIABLE( + 0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCDHW), gradI + auto gradW = OUTPUT_VARIABLE( + 1); // [kH, kW, oC, iC], [iC, oC, kH, kW], [iC, kH, kW, oC] + auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC] + + int dH = INT_ARG(6); // dilations height + int dW = INT_ARG(7); // dilations width + int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME + + const DataType xType = input->dataType(); + const DataType wType = weights->dataType(); + const DataType gradOType = gradO->dataType(); + + const DataType gradIType = gradI->dataType(); + const DataType gradWType = gradW->dataType(); + const DataType gradBType = + gradB != nullptr ? gradB->dataType() : DataType::FLOAT32; + + return block.isUseMKLDNN() && (dH <= 1 && dW <= 1 && !paddingMode) && + ((xType == DataType::FLOAT32 || xType == DataType::BFLOAT16) && + (wType == DataType::FLOAT32 || wType == DataType::BFLOAT16) && + (gradOType == DataType::FLOAT32 || gradOType == DataType::BFLOAT16) && + (gradIType == DataType::FLOAT32 || gradIType == DataType::BFLOAT16) && + (gradWType == DataType::FLOAT32 || gradWType == DataType::BFLOAT16) && + (gradBType == DataType::FLOAT32 || gradBType == DataType::BFLOAT16)); } - -} -} -} +} // namespace platforms +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/deconv2d_tf.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/deconv2d_tf.cpp index 10e3ba77ec8e..08ebcc31d666 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/deconv2d_tf.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/deconv2d_tf.cpp @@ -18,197 +18,273 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include +#include #include +#include +#include #include -#include #include "mkldnnUtils.h" -#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace platforms { ////////////////////////////////////////////////////////////////////////// -static void deconv2TFdBpMKLDNN(const NDArray* weights, const NDArray* gradO, NDArray* gradI, - const int bS, const int iC, const int iH, const int iW, const int oC, const int oH, const int oW, - const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, - const bool isNCHW, const int wFormat) { - - // gradI [bS, iH, iW, iC], mkl doesn't support ndhwc format - // weights [oC, iC, kH, kW] always, mkl doesn't support weights format [kH, kW, iC, oC] - // gradO [bS, oH, oW, oC] - - dnnl::memory::dims strides = { sH, sW }; - dnnl::memory::dims dilation = { dH - 1, dW - 1 }; - dnnl::memory::dims padding = { pH, pW }; - dnnl::memory::dims padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pW }; - - // weights type - dnnl::memory::data_type wType = weights->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16; - // gradO type - dnnl::memory::data_type gradOType = gradO->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16; - // gradI type - dnnl::memory::data_type gradIType = gradI->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16; - - dnnl::memory::format_tag xFormatMkl = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; - dnnl::memory::format_tag wFormatMkl = dnnl::memory::format_tag::oihw; - - dnnl::memory::dims xDims = {bS, iC, iH, iW}; - dnnl::memory::dims wDims = {oC, iC, kH, kW}; - dnnl::memory::dims zDims = {bS, oC, oH, oW}; - - // memory descriptors for arrays - - // input - dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, gradOType, dnnl::memory::format_tag::any); - - // weights - dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any); - dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormatMkl); - mkldnnUtils::setBlockStrides(*weights, w_user_md, {3,2,0,1}); // permute [kH, kW, iC, oC] -> [oC, iC, kH, kW] - - // gradO - dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any); - dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xFormatMkl); - mkldnnUtils::setBlockStrides(*gradO, gradO_user_md); - - // gradI - dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any); - dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xFormatMkl); - mkldnnUtils::setBlockStrides(*gradI, gradI_user_md); - - auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - - // forward primitive description - dnnl::convolution_forward::desc op_ff_desc(dnnl::prop_kind::forward_inference, dnnl::algorithm::convolution_auto, x_mkl_md, w_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r); - dnnl::convolution_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine); - - // backward data primitive description - dnnl::convolution_backward_data::desc op_data_bp_desc(dnnl::algorithm::convolution_auto, gradI_mkl_md, w_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r); - dnnl::convolution_backward_data::primitive_desc op_data_bp_prim_desc(op_data_bp_desc, engine, op_ff_prim_desc); - - // arguments (memory buffers) necessary for calculations - std::unordered_map args; - - dnnl::stream stream(engine); - - // provide memory buffers and check whether reorder is required - - // weights - mkldnnUtils::loadDataToMklStream(*weights, engine, stream, w_user_md, op_data_bp_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]); - - // gradO - mkldnnUtils::loadDataToMklStream(*gradO, engine, stream, gradO_user_md, op_data_bp_prim_desc.diff_dst_desc(), args[DNNL_ARG_DIFF_DST]); - - // gradI - auto gradI_user_mem = mkldnnUtils::loadDataToMklStream(*gradI, engine, stream, gradI_user_md, op_data_bp_prim_desc.diff_src_desc(), args[DNNL_ARG_DIFF_SRC]); - - // run backward data calculations - dnnl::convolution_backward_data(op_data_bp_prim_desc).execute(stream, args); - - // reorder gradI if necessary - if (op_data_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc()) - dnnl::reorder(args[DNNL_ARG_DIFF_SRC], gradI_user_mem).execute(stream, args[DNNL_ARG_DIFF_SRC], gradI_user_mem); - - stream.wait(); - - // shape::printArray(z_mkl_mem.map_data(),8); +static void deconv2TFdBpMKLDNN(const NDArray* weights, + const NDArray* gradO, NDArray* gradI, + const int bS, const int iC, const int iH, + const int iW, const int oC, const int oH, + const int oW, const int kH, const int kW, + const int sH, const int sW, const int pH, + const int pW, const int dH, const int dW, + const bool isNCHW, const int wFormat) { + // gradI [bS, iH, iW, iC], mkl doesn't support ndhwc format + // weights [oC, iC, kH, kW] always, mkl doesn't support weights format [kH, + // kW, iC, oC] gradO [bS, oH, oW, oC] + + dnnl::memory::dims strides = {sH, sW}; + dnnl::memory::dims dilation = {dH - 1, dW - 1}; + dnnl::memory::dims padding = {pH, pW}; + dnnl::memory::dims padding_r = {(oH - 1) * sH - iH + kH - pH, + (oW - 1) * sW - iW + kW - pW}; + + // weights type + dnnl::memory::data_type wType = weights->dataType() == DataType::FLOAT32 + ? dnnl::memory::data_type::f32 + : dnnl::memory::data_type::bf16; + // gradO type + dnnl::memory::data_type gradOType = gradO->dataType() == DataType::FLOAT32 + ? dnnl::memory::data_type::f32 + : dnnl::memory::data_type::bf16; + // gradI type + dnnl::memory::data_type gradIType = gradI->dataType() == DataType::FLOAT32 + ? dnnl::memory::data_type::f32 + : dnnl::memory::data_type::bf16; + + dnnl::memory::format_tag xFormatMkl = + isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; + dnnl::memory::format_tag wFormatMkl = dnnl::memory::format_tag::oihw; + + dnnl::memory::dims xDims = {bS, iC, iH, iW}; + dnnl::memory::dims wDims = {oC, iC, kH, kW}; + dnnl::memory::dims zDims = {bS, oC, oH, oW}; + + // memory descriptors for arrays + + // input + dnnl::memory::desc x_mkl_md = + dnnl::memory::desc(xDims, gradOType, dnnl::memory::format_tag::any); + + // weights + dnnl::memory::desc w_mkl_md = + dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any); + dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormatMkl); + mkldnnUtils::setBlockStrides(*weights, + w_user_md, {3,2,0,1}); // permute [kH, kW, iC, oC] -> [oC, iC, kH, kW] + + + // gradO + dnnl::memory::desc gradO_mkl_md = + dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any); + dnnl::memory::desc gradO_user_md = + dnnl::memory::desc(zDims, gradOType, xFormatMkl); + mkldnnUtils::setBlockStrides(*gradO, gradO_user_md); + + // gradI + dnnl::memory::desc gradI_mkl_md = + dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any); + dnnl::memory::desc gradI_user_md = + dnnl::memory::desc(xDims, gradIType, xFormatMkl); + mkldnnUtils::setBlockStrides(*gradI, gradI_user_md); + + auto engine = + mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); + + // forward primitive description + dnnl::convolution_forward::desc op_ff_desc( + dnnl::prop_kind::forward_inference, dnnl::algorithm::convolution_auto, + x_mkl_md, w_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r); + dnnl::convolution_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine); + + // backward data primitive description + dnnl::convolution_backward_data::desc op_data_bp_desc( + dnnl::algorithm::convolution_auto, gradI_mkl_md, w_mkl_md, gradO_mkl_md, + strides, dilation, padding, padding_r); + dnnl::convolution_backward_data::primitive_desc op_data_bp_prim_desc( + op_data_bp_desc, engine, op_ff_prim_desc); + + // arguments (memory buffers) necessary for calculations + std::unordered_map args; + + dnnl::stream stream(engine); + + // provide memory buffers and check whether reorder is required + + // weights + mkldnnUtils::loadDataToMklStream(*weights, engine, stream, w_user_md, + op_data_bp_prim_desc.weights_desc(), + args[DNNL_ARG_WEIGHTS]); + + // gradO + mkldnnUtils::loadDataToMklStream(*gradO, engine, stream, gradO_user_md, + op_data_bp_prim_desc.diff_dst_desc(), + args[DNNL_ARG_DIFF_DST]); + + // gradI + auto gradI_user_mem = mkldnnUtils::loadDataToMklStream(*gradI, engine, stream, gradI_user_md, + op_data_bp_prim_desc.diff_src_desc() , + args[DNNL_ARG_DIFF_SRC] ); + + // run backward data calculations + dnnl::convolution_backward_data(op_data_bp_prim_desc).execute(stream, args); + + // reorder gradI if necessary + if (op_data_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc()) + dnnl::reorder(args[DNNL_ARG_DIFF_SRC], gradI_user_mem) + .execute(stream, args[DNNL_ARG_DIFF_SRC], gradI_user_mem); + + stream.wait(); + + // shape::printArray(z_mkl_mem.map_data(),8); } ////////////////////////////////////////////////////////////////////////// PLATFORM_IMPL(deconv2d_tf, ENGINE_CPU) { - - auto gradO = INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next - auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] - auto gradIShape = INPUT_VARIABLE(0); // [4] - shape of input of conv2d (that is shape of gradI) - - auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon - - int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(weights->sizeAt(0));// filter(kernel) height - int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast(weights->sizeAt(1));// filter(kernel) width - int sH = INT_ARG(2); // strides height - int sW = INT_ARG(3); // strides width - int pH = INT_ARG(4); // paddings height - int pW = INT_ARG(5); // paddings width - int dH = INT_ARG(6); // dilations height - int dW = INT_ARG(7); // dilations width - int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME - int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW - int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - [oC, kH, kW, iC] - - const int rank = gradO->rankOf(); - - REQUIRE_TRUE(weights->rankOf() == rank, 0, "CUSTOM DECONV2D_TF MKLDNN OP: rank of weights array must be equal to 4, but got %i instead !", weights->rankOf()); - REQUIRE_TRUE(gradIShape->rankOf() == 1, 0, "CUSTOM DECONV2D_TF MKLDNN OP: rank of array with output shape must be equal to 1, but got %i instead !", gradIShape->rankOf()); - REQUIRE_TRUE(gradIShape->lengthOf() == rank, 0, "CUSTOM DECONV2D_TF MKLDNN OP: length of array with output shape must be equal to 4, but got %i instead !", gradIShape->lengthOf()); - - int indIOioC, indIiH, indWoC(3), indOoH; - if(!isNCHW) { - indIOioC = 3; indIiH = 1; indOoH = 1; - } - else { - indIOioC = 1; indIiH = 2; indOoH = 2; - } - - std::vector gradIShapeVector = gradIShape->template asVectorT(); - - const int bS = gradIShapeVector[0]; // batch size - const int iH = gradIShapeVector[indIiH]; // input height - const int iW = gradIShapeVector[indIiH+1]; // input width - const int iC = gradIShapeVector[indIOioC]; // input channels - const int oC = weights->sizeAt(indWoC); // output channels - const int oH = gradO->sizeAt(indOoH); // input height - const int oW = gradO->sizeAt(indOoH); // input width - - int trueoH, trueoW; // true output height, width - ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode); - - std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1}); - std::vector expectedWeightsShape = {kH, kW, iC, oC}; - REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM DECONV2D_TF MKLDNN OP: wrong shape of input array, basing on array with output shape expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); - REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DECONV2D_TF MKLDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); - - if(isSameMode) // SAME - ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); - - // // mkl supports only [oC, iC, kH, kW] for weights - // weights = new NDArray(weights->permute({3,2,0,1})); // [kH, kW, iC, oC] -> [oC, iC, kH, kW] - - // // mkl supports NCHW format only - // if(!isNCHW) { - // gradI = new NDArray(gradI->permute({0,3,1,2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] - // gradO = new NDArray(gradO->permute({0,3,1,2})); // [bS, oH, oW, oC] -> [bS, oC, oH, oW] - // } - - deconv2TFdBpMKLDNN(weights, gradO, gradI, bS, iC, iH, iW, oC, oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, isNCHW, wFormat); - - // delete weights; - - // if(!isNCHW) { - // delete gradI; - // delete gradO; - // } - - // ConvolutionUtils::conv2dBP(block, &input, weights, nullptr, gradO, gradI, nullptr, nullptr, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW); - - return Status::OK(); + auto gradO = INPUT_VARIABLE( + 2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next + auto weights = INPUT_VARIABLE( + 1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] + auto gradIShape = INPUT_VARIABLE( + 0); // [4] - shape of input of conv2d (that is shape of gradI) + + auto gradI = OUTPUT_VARIABLE( + 0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon + + int kH = INT_ARG(0) > 0 + ? INT_ARG(0) + : static_cast(weights->sizeAt(0)); // filter(kernel) height + int kW = INT_ARG(1) > 0 + ? INT_ARG(1) + : static_cast(weights->sizeAt(1)); // filter(kernel) width + int sH = INT_ARG(2); // strides height + int sW = INT_ARG(3); // strides width + int pH = INT_ARG(4); // paddings height + int pW = INT_ARG(5); // paddings width + int dH = INT_ARG(6); // dilations height + int dW = INT_ARG(7); // dilations width + int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME + int isNCHW = + block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW + int wFormat = + block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, + // iC, kH, kW], 2 - [oC, kH, kW, iC] + + const int rank = gradO->rankOf(); + + REQUIRE_TRUE(weights->rankOf() == rank, 0, + "CUSTOM DECONV2D_TF MKLDNN OP: rank of weights array must be " + "equal to 4, but got %i instead !", + weights->rankOf()); + REQUIRE_TRUE(gradIShape->rankOf() == 1, 0, + "CUSTOM DECONV2D_TF MKLDNN OP: rank of array with output shape " + "must be equal to 1, but got %i instead !", + gradIShape->rankOf()); + REQUIRE_TRUE(gradIShape->lengthOf() == rank, 0, + "CUSTOM DECONV2D_TF MKLDNN OP: length of array with output " + "shape must be equal to 4, but got %i instead !", + gradIShape->lengthOf()); + + int indIOioC, indIiH, indWoC(3), indOoH; + if (!isNCHW) { + indIOioC = 3; + indIiH = 1; + indOoH = 1; + } else { + indIOioC = 1; + indIiH = 2; + indOoH = 2; + } + + std::vector gradIShapeVector = + gradIShape->template asVectorT(); + + const int bS = gradIShapeVector[0]; // batch size + const int iH = gradIShapeVector[indIiH]; // input height + const int iW = gradIShapeVector[indIiH + 1]; // input width + const int iC = gradIShapeVector[indIOioC]; // input channels + const int oC = weights->sizeAt(indWoC); // output channels + const int oH = gradO->sizeAt(indOoH); // input height + const int oW = gradO->sizeAt(indOoH); // input width + + int trueoH, trueoW; // true output height, width + ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, + dH, dW, iH, iW, isSameMode); + + std::vector expectedGradOShape = + ShapeUtils::composeShapeUsingDimsAndIdx( + {bS, oC, trueoH, trueoW, 0, indIOioC, indOoH, indOoH + 1}); + std::vector expectedWeightsShape = {kH, kW, iC, oC}; + REQUIRE_TRUE( + gradO->isSameShape(expectedGradOShape), 0, + "CUSTOM DECONV2D_TF MKLDNN OP: wrong shape of input array, basing on " + "array with output shape expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedGradOShape).c_str(), + ShapeUtils::shapeAsString(gradO).c_str()); + REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, + "CUSTOM DECONV2D_TF MKLDNN OP: wrong shape of weights array, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), + ShapeUtils::shapeAsString(weights).c_str()); + + if (isSameMode) // SAME + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, + dW); + + // // mkl supports only [oC, iC, kH, kW] for weights + // weights = new NDArray(weights->permute({3,2,0,1})); // [kH, kW, iC, + // oC] -> [oC, iC, kH, kW] + + // // mkl supports NCHW format only + // if(!isNCHW) { + // gradI = new NDArray(gradI->permute({0,3,1,2})); // [bS, iH, iW, iC] + // -> [bS, iC, iH, iW] gradO = new NDArray(gradO->permute({0,3,1,2})); // + // [bS, oH, oW, oC] -> [bS, oC, oH, oW] + // } + + deconv2TFdBpMKLDNN(weights, gradO, gradI, bS, iC, iH, iW, oC, oH, oW, + kH, kW, sH, sW, pH, pW, dH, dW, isNCHW, wFormat); + + // delete weights; + + // if(!isNCHW) { + // delete gradI; + // delete gradO; + // } + + // ConvolutionUtils::conv2dBP(block, &input, weights, nullptr, gradO, gradI, + // nullptr, nullptr, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW); + + return Status::OK(); } PLATFORM_CHECK(deconv2d_tf, ENGINE_CPU) { - auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC] always - auto gradO = INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next - auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCDHW), gradI - - - const DataType wType = weights->dataType(); - const DataType gradOType = gradO->dataType(); - const DataType gradIType = gradI->dataType(); - - return block.isUseMKLDNN() && ((wType==DataType::FLOAT32 || wType==DataType::BFLOAT16) && (gradOType==DataType::FLOAT32 || gradOType==DataType::BFLOAT16) && (gradIType==DataType::FLOAT32 || gradIType==DataType::BFLOAT16)); + auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC] always + auto gradO = INPUT_VARIABLE( + 2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next + auto gradI = OUTPUT_VARIABLE( + 0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCDHW), gradI + + const DataType wType = weights->dataType(); + const DataType gradOType = gradO->dataType(); + const DataType gradIType = gradI->dataType(); + + return block.isUseMKLDNN() && + ((wType == DataType::FLOAT32 || wType == DataType::BFLOAT16) && + (gradOType == DataType::FLOAT32 || gradOType == DataType::BFLOAT16) && + (gradIType == DataType::FLOAT32 || gradIType == DataType::BFLOAT16)); } -} -} -} +} // namespace platforms +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/deconv3d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/deconv3d.cpp index 59f355c6e633..31935684e54e 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/deconv3d.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/deconv3d.cpp @@ -18,449 +18,636 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include +#include #include +#include +#include #include -#include #include "mkldnnUtils.h" -#include - -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace platforms { ////////////////////////////////////////////////////////////////////////// -static void deconv3dMKLDNN(const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, - const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, - const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, - const bool isNCDHW, const int wFormat) { - - // mkl supports weights in [oC, iC, kD, kH, kW] only - - int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; - int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, wFormat, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWoC, indWiC, indWkD); - - dnnl::memory::dims strides = { sD, sH, sW }; - dnnl::memory::dims padding = { pD, pH, pW }; - dnnl::memory::dims padding_r = { (iD - 1) * sD - oD + kD - pD, (iH - 1) * sH - oH + kH - pH, (iW - 1) * sW - oW + kW - pW }; - dnnl::memory::dims dilation = { dD-1, dH-1, dW-1 }; - - std::vector permut; - if(0 == wFormat) - permut = {3,4,0,1,2}; // [kD, kH, kW, oC, iC] -> [oC, iC, kD, kH, kW] - else if(1 == wFormat) - permut = {1,0,2,3,4}; // [iC, oC, kD, kH, kW] -> [oC, iC, kD, kH, kW] - else - permut = {4,0,1,2,3}; // [iC, kD, kH, kW, oC] -> [oC, iC, kD, kH, kW] - - // input type - dnnl::memory::data_type xType; - if(input->dataType() == DataType::FLOAT32) - xType = dnnl::memory::data_type::f32; - else if(input->dataType() == DataType::HALF) - xType = dnnl::memory::data_type::f16; - else if(input->dataType() == DataType::UINT8) - xType = dnnl::memory::data_type::u8; - else - xType = dnnl::memory::data_type::s8; - - // weights type - dnnl::memory::data_type wType = xType; - if(xType == dnnl::memory::data_type::u8) - wType = dnnl::memory::data_type::s8; - - // output and bias type (have the same types) - dnnl::memory::data_type zType; - if(output->dataType() == DataType::FLOAT32) - zType = dnnl::memory::data_type::f32; - else if(output->dataType() == DataType::HALF) - zType = dnnl::memory::data_type::f16; - else if(output->dataType() == DataType::UINT8) - zType = dnnl::memory::data_type::u8; - else if(output->dataType() == DataType::INT8) - zType = dnnl::memory::data_type::s8; - else - zType = dnnl::memory::data_type::s32; - - dnnl::memory::format_tag xFormatMkl = isNCDHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc; - dnnl::memory::format_tag wFormatMkl = dnnl::memory::format_tag::oidhw; - - dnnl::memory::dims xDims = {bS, iC, iD, iH, iW}; - dnnl::memory::dims wDims = {oC, iC, kD, kH, kW}; - dnnl::memory::dims zDims = {bS, oC, oD, oH, oW}; - - // memory descriptors for arrays - - // input - dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any); - dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormatMkl); - mkldnnUtils::setBlockStrides(*input, x_user_md); - - // weights - dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any); - dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormatMkl); - mkldnnUtils::setBlockStrides(*weights, w_user_md, permut); - - // bias - dnnl::memory::desc b_mkl_md; - if(bias != nullptr) - b_mkl_md = dnnl::memory::desc({oC}, zType, dnnl::memory::format_tag::x); - - // output - dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, zType, dnnl::memory::format_tag::any); - dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, zType, xFormatMkl); - mkldnnUtils::setBlockStrides(*output, z_user_md); - - auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - - // operation primitive description - dnnl::deconvolution_forward::desc op_desc(dnnl::prop_kind::forward_inference, dnnl::algorithm::deconvolution_direct, - x_mkl_md, w_mkl_md, b_mkl_md, z_mkl_md, strides, dilation, padding, padding_r); - dnnl::deconvolution_forward::primitive_desc op_prim_desc(op_desc, engine); - - // arguments (memory buffers) necessary for calculations - std::unordered_map args; - - dnnl::stream stream(engine); - - // provide memory buffers and check whether reorder is required - - // input - mkldnnUtils::loadDataToMklStream(*input, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]); - - // weights - mkldnnUtils::loadDataToMklStream(*weights, engine, stream, w_user_md, op_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]); - - // bias - if(bias != nullptr) { - auto b_mkl_mem = dnnl::memory(b_mkl_md, engine, const_cast(bias->buffer())); - args[DNNL_ARG_BIAS] = b_mkl_mem; - } - - // output - auto z_user_mem = mkldnnUtils::loadDataToMklStream(*output, engine, stream, z_user_md, op_prim_desc.dst_desc(), args[DNNL_ARG_DST]); - - // run calculations - dnnl::deconvolution_forward(op_prim_desc).execute(stream, args); - - // reorder outputs if necessary - if (op_prim_desc.dst_desc() != z_user_mem.get_desc()) - dnnl::reorder(args[DNNL_ARG_DST], z_user_mem).execute(stream, args[DNNL_ARG_DST], z_user_mem); - - stream.wait(); - - // shape::printArray(z_mkl_mem.map_data(),8); +static void deconv3dMKLDNN(const NDArray* input, const NDArray* weights, + const NDArray* bias, NDArray* output, const int kD, + const int kH, const int kW, const int sD, + const int sH, const int sW, const int pD, + const int pH, const int pW, const int dD, + const int dH, const int dW, const bool isNCDHW, + const int wFormat) { + // mkl supports weights in [oC, iC, kD, kH, kW] only + + int bS, iC, iD, iH, iW, oC, oD, oH, + oW; // batch size, input channels, input depth/height/width, output + // channels, output depth/height/width; + int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv3d( + isNCDHW, wFormat, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, + indIOioC, indIOioD, indWoC, indWiC, indWkD); + + dnnl::memory::dims strides = {sD, sH, sW}; + dnnl::memory::dims padding = {pD, pH, pW}; + dnnl::memory::dims padding_r = {(iD - 1) * sD - oD + kD - pD, + (iH - 1) * sH - oH + kH - pH, + (iW - 1) * sW - oW + kW - pW}; + dnnl::memory::dims dilation = {dD - 1, dH - 1, dW - 1}; + + std::vector permut; + if (0 == wFormat) + permut = {3,4,0,1,2}; // [kD, kH, kW, oC, iC] -> [oC, iC, kD, kH, kW] + else if (1 == wFormat) + permut = {1,0,2,3,4}; // [iC, oC, kD, kH, kW] -> [oC, iC, kD, kH, kW] + else + permut = {4,0,1,2,3}; // [iC, kD, kH, kW, oC] -> [oC, iC, kD, kH, kW] + + + // input type + dnnl::memory::data_type xType; + if (input->dataType() == DataType::FLOAT32) + xType = dnnl::memory::data_type::f32; + else if (input->dataType() == DataType::HALF) + xType = dnnl::memory::data_type::f16; + else if (input->dataType() == DataType::UINT8) + xType = dnnl::memory::data_type::u8; + else + xType = dnnl::memory::data_type::s8; + + // weights type + dnnl::memory::data_type wType = xType; + if (xType == dnnl::memory::data_type::u8) wType = dnnl::memory::data_type::s8; + + // output and bias type (have the same types) + dnnl::memory::data_type zType; + if (output->dataType() == DataType::FLOAT32) + zType = dnnl::memory::data_type::f32; + else if (output->dataType() == DataType::HALF) + zType = dnnl::memory::data_type::f16; + else if (output->dataType() == DataType::UINT8) + zType = dnnl::memory::data_type::u8; + else if (output->dataType() == DataType::INT8) + zType = dnnl::memory::data_type::s8; + else + zType = dnnl::memory::data_type::s32; + + dnnl::memory::format_tag xFormatMkl = isNCDHW + ? dnnl::memory::format_tag::ncdhw + : dnnl::memory::format_tag::ndhwc; + dnnl::memory::format_tag wFormatMkl = dnnl::memory::format_tag::oidhw; + + dnnl::memory::dims xDims = {bS, iC, iD, iH, iW}; + dnnl::memory::dims wDims = {oC, iC, kD, kH, kW}; + dnnl::memory::dims zDims = {bS, oC, oD, oH, oW}; + + // memory descriptors for arrays + + // input + dnnl::memory::desc x_mkl_md = + dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any); + dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormatMkl); + mkldnnUtils::setBlockStrides(*input, x_user_md); + + // weights + dnnl::memory::desc w_mkl_md = + dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any); + dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormatMkl); + mkldnnUtils::setBlockStrides(*weights, + w_user_md, permut); + + // bias + dnnl::memory::desc b_mkl_md; + if (bias != nullptr) + b_mkl_md = dnnl::memory::desc({oC}, zType, dnnl::memory::format_tag::x); + + // output + dnnl::memory::desc z_mkl_md = + dnnl::memory::desc(zDims, zType, dnnl::memory::format_tag::any); + dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, zType, xFormatMkl); + mkldnnUtils::setBlockStrides(*output, z_user_md); + + auto engine = + mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); + + // operation primitive description + dnnl::deconvolution_forward::desc op_desc( + dnnl::prop_kind::forward_inference, dnnl::algorithm::deconvolution_direct, + x_mkl_md, w_mkl_md, b_mkl_md, z_mkl_md, strides, dilation, padding, + padding_r); + dnnl::deconvolution_forward::primitive_desc op_prim_desc(op_desc, engine); + + // arguments (memory buffers) necessary for calculations + std::unordered_map args; + + dnnl::stream stream(engine); + + // provide memory buffers and check whether reorder is required + + // input + mkldnnUtils::loadDataToMklStream(*input, engine, stream, x_user_md, + op_prim_desc.src_desc(), args[DNNL_ARG_SRC]); + + // weights + mkldnnUtils::loadDataToMklStream(*weights, engine, stream, w_user_md, + op_prim_desc.weights_desc(), + args[DNNL_ARG_WEIGHTS]); + + // bias + if (bias != nullptr) { + auto b_mkl_mem = + dnnl::memory(b_mkl_md, engine, const_cast(bias->buffer())); + args[DNNL_ARG_BIAS] = b_mkl_mem; + } + + // output + auto z_user_mem = mkldnnUtils::loadDataToMklStream(*output, engine, stream, z_user_md, op_prim_desc.dst_desc(), + args[DNNL_ARG_DST] ); + + // run calculations + dnnl::deconvolution_forward(op_prim_desc).execute(stream, args); + + // reorder outputs if necessary + if (op_prim_desc.dst_desc() != z_user_mem.get_desc()) + dnnl::reorder(args[DNNL_ARG_DST], z_user_mem).execute(stream, args[DNNL_ARG_DST], z_user_mem); + + stream.wait(); + + // shape::printArray(z_mkl_mem.map_data(),8); } ////////////////////////////////////////////////////////////////////////// -static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, - const int kD, const int kH, const int kW, - const int sD, const int sH, const int sW, - const int pD, const int pH, const int pW, - const int dD, const int dH, const int dW, - const bool isNCDHW, const int wFormat) { - - // mkl supports weights/gradW in [oC, iC, kD, kH, kW] format only - - int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; - int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, wFormat, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWoC, indWiC, indWkD); - - dnnl::memory::dims strides = { sD, sH, sW }; - dnnl::memory::dims padding = { pD, pH, pW }; - dnnl::memory::dims padding_r = { (iD - 1) * sD - oD + kD - pD, (iH - 1) * sH - oH + kH - pH, (iW - 1) * sW - oW + kW - pW }; - dnnl::memory::dims dilation = { dD-1, dH-1, dW-1 }; - - std::vector permut; - if(0 == wFormat) - permut = {3,4,0,1,2}; // [kD, kH, kW, oC, iC] -> [oC, iC, kD, kH, kW] - else if(1 == wFormat) - permut = {1,0,2,3,4}; // [iC, oC, kD, kH, kW] -> [oC, iC, kD, kH, kW] - else - permut = {4,0,1,2,3}; // [iC, kD, kH, kW, oC] -> [oC, iC, kD, kH, kW] - - // input type - dnnl::memory::data_type xType = input->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16; - // weights type - dnnl::memory::data_type wType = weights->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16; - // gradO type - dnnl::memory::data_type gradOType = gradO->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16; - // gradI type - dnnl::memory::data_type gradIType = gradI->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16; - // gradW type - dnnl::memory::data_type gradWType = gradW->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16; - // gradB type - dnnl::memory::data_type gradBType = gradB != nullptr ? (gradB->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16) : dnnl::memory::data_type::f32; - - dnnl::memory::format_tag xFormatMkl = isNCDHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc; - dnnl::memory::format_tag wFormatMkl = dnnl::memory::format_tag::oidhw; - - dnnl::memory::dims xDims = {bS, iC, iD, iH, iW}; - dnnl::memory::dims wDims = {oC, iC, kD, kH, kW}; - dnnl::memory::dims zDims = {bS, oC, oD, oH, oW}; - - // memory descriptors for arrays - - // input - dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any); - dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormatMkl); - mkldnnUtils::setBlockStrides(*input, x_user_md); - - // weights - dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any); - dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormatMkl); - mkldnnUtils::setBlockStrides(*weights, w_user_md, permut); - - // gradO - dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any); - dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xFormatMkl); - mkldnnUtils::setBlockStrides(*gradO, gradO_user_md); - - // gradI - dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any); - dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xFormatMkl); - mkldnnUtils::setBlockStrides(*gradI, gradI_user_md); - - // gradW - dnnl::memory::desc gradW_mkl_md = dnnl::memory::desc(wDims, gradWType, dnnl::memory::format_tag::any); - dnnl::memory::desc gradW_user_md = dnnl::memory::desc(wDims, gradWType, wFormatMkl); - mkldnnUtils::setBlockStrides(*gradW, gradW_user_md, permut); - - // gradB - dnnl::memory::desc gradB_mkl_md; - if(gradB != nullptr) - gradB_mkl_md = dnnl::memory::desc({oC}, gradBType, dnnl::memory::format_tag::x); - - - auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - - // forward primitive description - dnnl::deconvolution_forward::desc op_ff_desc(dnnl::prop_kind::forward_inference, dnnl::algorithm::deconvolution_direct, x_mkl_md, w_mkl_md, gradB_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r); - dnnl::deconvolution_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine); - - // backward data primitive description - dnnl::deconvolution_backward_data::desc op_data_bp_desc(dnnl::algorithm::deconvolution_direct, gradI_mkl_md, w_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r); - dnnl::deconvolution_backward_data::primitive_desc op_data_bp_prim_desc(op_data_bp_desc, engine, op_ff_prim_desc); - - // backward weights primitive description - dnnl::deconvolution_backward_weights::desc op_weights_bp_desc(dnnl::algorithm::deconvolution_direct, x_mkl_md, gradW_mkl_md, gradB_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r); - dnnl::deconvolution_backward_weights::primitive_desc op_weights_bp_prim_desc(op_weights_bp_desc, engine, op_ff_prim_desc); - - // arguments (memory buffers) necessary for calculations - std::unordered_map args; - - dnnl::stream stream(engine); - - // provide memory buffers and check whether reorder is required - - // input - mkldnnUtils::loadDataToMklStream(*input, engine, stream, x_user_md, op_weights_bp_prim_desc.src_desc(), args[DNNL_ARG_SRC]); - - // weights - mkldnnUtils::loadDataToMklStream(*weights, engine, stream, w_user_md, op_data_bp_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]); - - // gradO - auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, const_cast(gradO->buffer())); - const bool gradOReorderW = op_weights_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc(); - const bool gradOReorderD = op_data_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc(); - auto gradO_mkl_memW = gradOReorderW ? dnnl::memory(op_weights_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem; - auto gradO_mkl_memD = gradOReorderD ? dnnl::memory(op_data_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem; - if (gradOReorderW) - dnnl::reorder(gradO_user_mem, gradO_mkl_memW).execute(stream, gradO_user_mem, gradO_mkl_memW); - if (gradOReorderD) - dnnl::reorder(gradO_user_mem, gradO_mkl_memD).execute(stream, gradO_user_mem, gradO_mkl_memD); - args[DNNL_ARG_DIFF_DST] = gradO_mkl_memD; - - // gradI - auto gradI_user_mem = mkldnnUtils::loadDataToMklStream(*gradI, engine, stream, gradI_user_md, op_data_bp_prim_desc.diff_src_desc(), args[DNNL_ARG_DIFF_SRC]); - - // gradW - auto gradW_user_mem = mkldnnUtils::loadDataToMklStream(*gradW, engine, stream, gradW_user_md, op_weights_bp_prim_desc.diff_weights_desc(), args[DNNL_ARG_DIFF_WEIGHTS]); - - // gradB - if(gradB != nullptr) { - auto gradB_mkl_mem = dnnl::memory(gradB_mkl_md, engine, gradB->buffer()); - args[DNNL_ARG_DIFF_BIAS] = gradB_mkl_mem; - } - - // run backward data calculations - dnnl::deconvolution_backward_data(op_data_bp_prim_desc).execute(stream, args); - - if(gradOReorderW || gradOReorderD) - args[DNNL_ARG_DIFF_DST] = gradO_mkl_memW; - - // run backward weights calculations - dnnl::deconvolution_backward_weights(op_weights_bp_prim_desc).execute(stream, args); - - // reorder gradI if necessary - if (op_data_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc()) - dnnl::reorder(args[DNNL_ARG_DIFF_SRC], gradI_user_mem).execute(stream, args[DNNL_ARG_DIFF_SRC], gradI_user_mem); - if (op_weights_bp_prim_desc.diff_weights_desc() != gradW_user_mem.get_desc()) - dnnl::reorder(args[DNNL_ARG_DIFF_WEIGHTS], gradW_user_mem).execute(stream, args[DNNL_ARG_DIFF_WEIGHTS], gradW_user_mem); - - stream.wait(); - - // shape::printArray(z_mkl_mem.map_data(),8); +static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights, + const NDArray* gradO, NDArray* gradI, + NDArray* gradW, NDArray* gradB, const int kD, + const int kH, const int kW, const int sD, + const int sH, const int sW, const int pD, + const int pH, const int pW, const int dD, + const int dH, const int dW, + const bool isNCDHW, const int wFormat) { + // mkl supports weights/gradW in [oC, iC, kD, kH, kW] format only + + int bS, iC, iD, iH, iW, oC, oD, oH, + oW; // batch size, input channels, input depth/height/width, output + // channels, output depth/height/width; + int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv3d( + isNCDHW, wFormat, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, + indIOioC, indIOioD, indWoC, indWiC, indWkD); + + dnnl::memory::dims strides = {sD, sH, sW}; + dnnl::memory::dims padding = {pD, pH, pW}; + dnnl::memory::dims padding_r = {(iD - 1) * sD - oD + kD - pD, + (iH - 1) * sH - oH + kH - pH, + (iW - 1) * sW - oW + kW - pW}; + dnnl::memory::dims dilation = {dD - 1, dH - 1, dW - 1}; + + std::vector permut; + if (0 == wFormat) + permut = {3,4,0,1,2}; // [kD, kH, kW, oC, iC] -> [oC, iC, kD, kH, kW] + else if (1 == wFormat) + permut = {1,0,2,3,4}; // [iC, oC, kD, kH, kW] -> [oC, iC, kD, kH, kW] + else + permut = {4,0,1,2,3}; // [iC, kD, kH, kW, oC] -> [oC, iC, kD, kH, kW] + + + // input type + dnnl::memory::data_type xType = input->dataType() == DataType::FLOAT32 + ? dnnl::memory::data_type::f32 + : dnnl::memory::data_type::bf16; + // weights type + dnnl::memory::data_type wType = weights->dataType() == DataType::FLOAT32 + ? dnnl::memory::data_type::f32 + : dnnl::memory::data_type::bf16; + // gradO type + dnnl::memory::data_type gradOType = gradO->dataType() == DataType::FLOAT32 + ? dnnl::memory::data_type::f32 + : dnnl::memory::data_type::bf16; + // gradI type + dnnl::memory::data_type gradIType = gradI->dataType() == DataType::FLOAT32 + ? dnnl::memory::data_type::f32 + : dnnl::memory::data_type::bf16; + // gradW type + dnnl::memory::data_type gradWType = gradW->dataType() == DataType::FLOAT32 + ? dnnl::memory::data_type::f32 + : dnnl::memory::data_type::bf16; + // gradB type + dnnl::memory::data_type gradBType = + gradB != nullptr ? (gradB->dataType() == DataType::FLOAT32 + ? dnnl::memory::data_type::f32 + : dnnl::memory::data_type::bf16) + : dnnl::memory::data_type::f32; + + dnnl::memory::format_tag xFormatMkl = isNCDHW + ? dnnl::memory::format_tag::ncdhw + : dnnl::memory::format_tag::ndhwc; + dnnl::memory::format_tag wFormatMkl = dnnl::memory::format_tag::oidhw; + + dnnl::memory::dims xDims = {bS, iC, iD, iH, iW}; + dnnl::memory::dims wDims = {oC, iC, kD, kH, kW}; + dnnl::memory::dims zDims = {bS, oC, oD, oH, oW}; + + // memory descriptors for arrays + + // input + dnnl::memory::desc x_mkl_md = + dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any); + dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormatMkl); + mkldnnUtils::setBlockStrides(*input, x_user_md); + + // weights + dnnl::memory::desc w_mkl_md = + dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any); + dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormatMkl); + mkldnnUtils::setBlockStrides(*weights, + w_user_md, permut); + + // gradO + dnnl::memory::desc gradO_mkl_md = + dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any); + dnnl::memory::desc gradO_user_md = + dnnl::memory::desc(zDims, gradOType, xFormatMkl); + mkldnnUtils::setBlockStrides(*gradO, gradO_user_md); + + // gradI + dnnl::memory::desc gradI_mkl_md = + dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any); + dnnl::memory::desc gradI_user_md = + dnnl::memory::desc(xDims, gradIType, xFormatMkl); + mkldnnUtils::setBlockStrides(*gradI, gradI_user_md); + + // gradW + dnnl::memory::desc gradW_mkl_md = + dnnl::memory::desc(wDims, gradWType, dnnl::memory::format_tag::any); + dnnl::memory::desc gradW_user_md = + dnnl::memory::desc(wDims, gradWType, wFormatMkl); + mkldnnUtils::setBlockStrides(*gradW, + gradW_user_md, permut); + + // gradB + dnnl::memory::desc gradB_mkl_md; + if (gradB != nullptr) + gradB_mkl_md = + dnnl::memory::desc({oC}, gradBType, dnnl::memory::format_tag::x); + + auto engine = + mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); + + // forward primitive description + dnnl::deconvolution_forward::desc op_ff_desc( + dnnl::prop_kind::forward_inference, dnnl::algorithm::deconvolution_direct, + x_mkl_md, w_mkl_md, gradB_mkl_md, gradO_mkl_md, strides, dilation, + padding, padding_r); + dnnl::deconvolution_forward::primitive_desc op_ff_prim_desc(op_ff_desc, + engine); + + // backward data primitive description + dnnl::deconvolution_backward_data::desc op_data_bp_desc( + dnnl::algorithm::deconvolution_direct, gradI_mkl_md, w_mkl_md, + gradO_mkl_md, strides, dilation, padding, padding_r); + dnnl::deconvolution_backward_data::primitive_desc op_data_bp_prim_desc( + op_data_bp_desc, engine, op_ff_prim_desc); + + // backward weights primitive description + dnnl::deconvolution_backward_weights::desc op_weights_bp_desc( + dnnl::algorithm::deconvolution_direct, x_mkl_md, gradW_mkl_md, + gradB_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r); + dnnl::deconvolution_backward_weights::primitive_desc op_weights_bp_prim_desc( + op_weights_bp_desc, engine, op_ff_prim_desc); + + // arguments (memory buffers) necessary for calculations + std::unordered_map args; + + dnnl::stream stream(engine); + + // provide memory buffers and check whether reorder is required + + // input + mkldnnUtils::loadDataToMklStream(*input, engine, stream, x_user_md, + op_weights_bp_prim_desc.src_desc(), + args[DNNL_ARG_SRC]); + + // weights + mkldnnUtils::loadDataToMklStream(*weights, engine, stream, w_user_md, + op_data_bp_prim_desc.weights_desc(), + args[DNNL_ARG_WEIGHTS]); + + // gradO + auto gradO_user_mem = + dnnl::memory(gradO_user_md, engine, const_cast(gradO->buffer())); + const bool gradOReorderW = + op_weights_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc(); + const bool gradOReorderD = + op_data_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc(); + auto gradO_mkl_memW = + gradOReorderW + ? dnnl::memory(op_weights_bp_prim_desc.diff_dst_desc(), engine) + : gradO_user_mem; + auto gradO_mkl_memD = + gradOReorderD ? dnnl::memory(op_data_bp_prim_desc.diff_dst_desc(), engine) + : gradO_user_mem; + if (gradOReorderW) + dnnl::reorder(gradO_user_mem, gradO_mkl_memW) + .execute(stream, gradO_user_mem, gradO_mkl_memW); + if (gradOReorderD) + dnnl::reorder(gradO_user_mem, gradO_mkl_memD) + .execute(stream, gradO_user_mem, gradO_mkl_memD); + args[DNNL_ARG_DIFF_DST] = gradO_mkl_memD; + + // gradI + auto gradI_user_mem = mkldnnUtils::loadDataToMklStream(*gradI, engine, stream, gradI_user_md, + op_data_bp_prim_desc.diff_src_desc() , + args[DNNL_ARG_DIFF_SRC] ); + + // gradW + auto gradW_user_mem = mkldnnUtils::loadDataToMklStream(*gradW, engine, stream, gradW_user_md, + op_weights_bp_prim_desc.diff_weights_desc() , + args[DNNL_ARG_DIFF_WEIGHTS] ); + + // gradB + if (gradB != nullptr) { + auto gradB_mkl_mem = dnnl::memory(gradB_mkl_md, engine, gradB->buffer()); + args[DNNL_ARG_DIFF_BIAS] = gradB_mkl_mem; + } + + // run backward data calculations + dnnl::deconvolution_backward_data(op_data_bp_prim_desc).execute(stream, args); + + if (gradOReorderW || gradOReorderD) args[DNNL_ARG_DIFF_DST] = gradO_mkl_memW; + + // run backward weights calculations + dnnl::deconvolution_backward_weights(op_weights_bp_prim_desc) + .execute(stream, args); + + // reorder gradI if necessary + if (op_data_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc()) + dnnl::reorder(args[DNNL_ARG_DIFF_SRC], gradI_user_mem) + .execute(stream, args[DNNL_ARG_DIFF_SRC], gradI_user_mem); + if (op_weights_bp_prim_desc.diff_weights_desc() != gradW_user_mem.get_desc()) + dnnl::reorder(args[DNNL_ARG_DIFF_WEIGHTS], gradW_user_mem) + .execute(stream, args[DNNL_ARG_DIFF_WEIGHTS], gradW_user_mem); + + stream.wait(); + + // shape::printArray(z_mkl_mem.map_data(),8); } - ////////////////////////////////////////////////////////////////////////// PLATFORM_IMPL(deconv3d, ENGINE_CPU) { - - auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) - auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, oC, iC], [iC, oC, kD, kH, kW], [iC, kD, kH, kW, oC] - auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] - - auto output = OUTPUT_VARIABLE(0); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW) - - REQUIRE_TRUE(input->rankOf() == 5, 0, "CUSTOM DECONV3D_MKLDNN OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf()); - REQUIRE_TRUE(weights->rankOf() == 5, 0, "CUSTOM DECONV3D_MKLDNN OP: rank of weights array must be equal to 5, but got %i instead !", weights->rankOf()); - - int kD = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(weights->sizeAt(0)); // filter(kernel) depth - int kH = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast(weights->sizeAt(1)); // filter(kernel) height - int kW = INT_ARG(2) > 0 ? INT_ARG(2) : static_cast(weights->sizeAt(2)); // filter(kernel) width - int sD = INT_ARG(3); // strides depth - int sH = INT_ARG(4); // strides height - int sW = INT_ARG(5); // strides width - int pD = INT_ARG(6); // paddings depth - int pH = INT_ARG(7); // paddings height - int pW = INT_ARG(8); // paddings width - int dD = INT_ARG(9); // dilations depth - int dH = INT_ARG(10); // dilations height - int dW = INT_ARG(11); // dilations width - int isSameMode = INT_ARG(12); // 0-SAME, 1-VALID - int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW - int wFormat = block.getIArguments()->size() > 14 ? INT_ARG(14) : 0; // 0 - [kD, kH, kW, oC, iC], 1 - [iC, oC, kD, kH, kW], 2 - [iC, kD, kH, kW, oC] - - int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; - int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, wFormat, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWoC, indWiC, indWkD); - - std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, oC, iC); - REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DECONV3D_MKLDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); - if (bias) - REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DECONV3D_MKLDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); - - if(isSameMode){ // SAME - //Note: we're intentionally swapping iH and oH, to calculated the padding for a"normal" conv (not deconv) forward pass - ConvolutionUtils::calcPadding3D(pD, pH, pW, iD, iH, iW, oD, oH, oW, kD, kH, kW, sD, sH, sW, dD, dH, dW); - } - - deconv3dMKLDNN(input, weights, bias, output, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isNCDHW, wFormat); - - return Status::OK(); + auto input = INPUT_VARIABLE( + 0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) + auto weights = INPUT_VARIABLE( + 1); // [kD, kH, kW, oC, iC], [iC, oC, kD, kH, kW], [iC, kD, kH, kW, oC] + auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] + + auto output = OUTPUT_VARIABLE( + 0); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW) + + REQUIRE_TRUE(input->rankOf() == 5, 0, + "CUSTOM DECONV3D_MKLDNN OP: rank of input array must be equal " + "to 5, but got %i instead !", + input->rankOf()); + REQUIRE_TRUE(weights->rankOf() == 5, 0, + "CUSTOM DECONV3D_MKLDNN OP: rank of weights array must be equal " + "to 5, but got %i instead !", + weights->rankOf()); + + int kD = INT_ARG(0) > 0 + ? INT_ARG(0) + : static_cast(weights->sizeAt(0)); // filter(kernel) depth + int kH = INT_ARG(1) > 0 + ? INT_ARG(1) + : static_cast(weights->sizeAt(1)); // filter(kernel) height + int kW = INT_ARG(2) > 0 + ? INT_ARG(2) + : static_cast(weights->sizeAt(2)); // filter(kernel) width + int sD = INT_ARG(3); // strides depth + int sH = INT_ARG(4); // strides height + int sW = INT_ARG(5); // strides width + int pD = INT_ARG(6); // paddings depth + int pH = INT_ARG(7); // paddings height + int pW = INT_ARG(8); // paddings width + int dD = INT_ARG(9); // dilations depth + int dH = INT_ARG(10); // dilations height + int dW = INT_ARG(11); // dilations width + int isSameMode = INT_ARG(12); // 0-SAME, 1-VALID + int isNCDHW = + block.numI() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW + int wFormat = block.numI() > 14 + ? INT_ARG(14) + : 0; // 0 - [kD, kH, kW, oC, iC], 1 - [iC, oC, kD, kH, kW], + // 2 - [iC, kD, kH, kW, oC] + + int bS, iC, iD, iH, iW, oC, oD, oH, + oW; // batch size, input channels, input depth/height/width, output + // channels, output depth/height/width; + int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv3d( + isNCDHW, wFormat, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, + indIOioC, indIOioD, indWoC, indWiC, indWkD); + + std::vector expectedWeightsShape = + ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, oC, iC); + REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, + "CUSTOM DECONV3D_MKLDNN OP: wrong shape of weights array, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), + ShapeUtils::shapeAsString(weights).c_str()); + if (bias) + REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, + "CUSTOM DECONV3D_MKLDNN OP: wrong shape of array with biases, " + "expected rank, length: <=2, %i, but got %i, %i instead !", + oC, bias->rankOf(), bias->lengthOf()); + + if (isSameMode) { // SAME + // Note: we're intentionally swapping iH and oH, to calculated the padding + // for a"normal" conv (not deconv) forward pass + ConvolutionUtils::calcPadding3D(pD, pH, pW, iD, iH, iW, oD, oH, oW, kD, kH, + kW, sD, sH, sW, dD, dH, dW); + } + + deconv3dMKLDNN(input, weights, bias, output, kD, kH, kW, sD, sH, sW, pD, pH, + pW, dD, dH, dW, isNCDHW, wFormat); + + return Status::OK(); } PLATFORM_CHECK(deconv3d, ENGINE_CPU) { - auto input = INPUT_VARIABLE(0); - auto weights = INPUT_VARIABLE(1); - auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; - - auto output = INPUT_VARIABLE(0); - - int dD = INT_ARG(9); // dilations depth - int dH = INT_ARG(10); // dilations height - int dW = INT_ARG(11); // dilations width - int isSameMode = INT_ARG(12); // 0-SAME, 1-VALID - - const DataType xType = input->dataType(); - const DataType wType = weights->dataType(); - const DataType zType = output->dataType(); - const DataType bType = bias != nullptr ? bias->dataType() : zType; - - return block.isUseMKLDNN() && (dD <= 1 && dH <= 1 && dW <= 1 && !isSameMode) && - ( - (xType==DataType::FLOAT32 && wType==DataType::FLOAT32 && bType==DataType::FLOAT32 && zType==DataType::FLOAT32) || - ((xType==DataType::UINT8 || xType==DataType::INT8) && wType==DataType::INT8 && (zType==DataType::UINT8 || zType==DataType::INT8 || zType==DataType::INT32 || zType==DataType::FLOAT32) && bType == zType) - ); + auto input = INPUT_VARIABLE(0); + auto weights = INPUT_VARIABLE(1); + auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; + + auto output = INPUT_VARIABLE(0); + + int dD = INT_ARG(9); // dilations depth + int dH = INT_ARG(10); // dilations height + int dW = INT_ARG(11); // dilations width + int isSameMode = INT_ARG(12); // 0-SAME, 1-VALID + + const DataType xType = input->dataType(); + const DataType wType = weights->dataType(); + const DataType zType = output->dataType(); + const DataType bType = bias != nullptr ? bias->dataType() : zType; + + return block.isUseMKLDNN() && + (dD <= 1 && dH <= 1 && dW <= 1 && !isSameMode) && + ((xType == DataType::FLOAT32 && wType == DataType::FLOAT32 && + bType == DataType::FLOAT32 && zType == DataType::FLOAT32) || + ((xType == DataType::UINT8 || xType == DataType::INT8) && + wType == DataType::INT8 && + (zType == DataType::UINT8 || zType == DataType::INT8 || + zType == DataType::INT32 || zType == DataType::FLOAT32) && + bType == zType)); } - ////////////////////////////////////////////////////////////////////////// PLATFORM_IMPL(deconv3d_bp, ENGINE_CPU) { - - auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) - auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, oC, iC], [iC, oC, kD, kH, kW], [iC, kD, kH, kW, oC] - auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] - auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next - - auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), gradI - auto gradW = OUTPUT_VARIABLE(1); // [kD, kH, kW, oC, iC], [iC, oC, kD, kH, kW], [iC, kD, kH, kW, oC] - auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC] - - REQUIRE_TRUE(input->rankOf() == 5, 0, "CUSTOM DECONV3D_MKLDNN_BP OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf()); - REQUIRE_TRUE(weights->rankOf() == 5, 0, "CUSTOM DECONV3D_MKLDNN_BP OP: rank of weights array must be equal to 5 , but got %i instead !", weights->rankOf()); - REQUIRE_TRUE(gradO->rankOf() == 5, 0, "CUSTOM DECONV3D_MKLDNN_BP OP: rank of output gradients (next epsilon) array must be equal to 5, but got %i instead !", gradO->rankOf()); - - - int kD = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(weights->sizeAt(0));// filter(kernel) depth - int kH = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast(weights->sizeAt(1));// filter(kernel) height - int kW = INT_ARG(2) > 0 ? INT_ARG(2) : static_cast(weights->sizeAt(2));// filter(kernel) width - int sD = INT_ARG(3); // strides depth - int sH = INT_ARG(4); // strides height - int sW = INT_ARG(5); // strides width - int pD = INT_ARG(6); // paddings depth - int pH = INT_ARG(7); // paddings height - int pW = INT_ARG(8); // paddings width - int dD = INT_ARG(9); // dilations depth - int dH = INT_ARG(10); // dilations height - int dW = INT_ARG(11); // dilations width - int isSameMode = INT_ARG(12); // 0-SAME, 1-VALID - int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW - int wFormat = block.getIArguments()->size() > 14 ? INT_ARG(14) : 0; // 0 - [kD, kH, kW, oC, iC], 1 - [iC, oC, kD, kH, kW], 2 - [iC, kD, kH, kW, oC] - - int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; - int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, wFormat, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWoC, indWiC, indWkD); - - int trueoD, trueoH, trueoW; // true output height, width - ConvolutionUtils::calcOutSizeDeconv3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, isSameMode); - - std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoD,trueoH,trueoW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); - std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, oC, iC); - REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM DECONV3D_MKLDNN_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); - REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DECONV3D_MKLDNN_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); - if(bias) - REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DECONV3D_MKLDNN_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); - - if(isSameMode) // Note: we're intentionally swapping iH and oH, to calculated the padding for a"normal" conv (not deconv) forward pass - ConvolutionUtils::calcPadding3D(pD, pH, pW, iD, iH, iW, oD, oH, oW, kD, kH, kW, sD, sH, sW, dD, dH, dW); - - deconv3dBackPropMKLDNN(input, weights, gradO, gradI, gradW, gradB, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isNCDHW, wFormat); - - return Status::OK(); + auto input = INPUT_VARIABLE( + 0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) + auto weights = INPUT_VARIABLE( + 1); // [kD, kH, kW, oC, iC], [iC, oC, kD, kH, kW], [iC, kD, kH, kW, oC] + auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] + auto gradO = + block.width() > 3 + ? INPUT_VARIABLE(3) + : INPUT_VARIABLE(2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, + // oH, oW] (NCDHW), epsilon_next + + auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, + // iD, iH, iW] (NCDHW), gradI + auto gradW = OUTPUT_VARIABLE( + 1); // [kD, kH, kW, oC, iC], [iC, oC, kD, kH, kW], [iC, kD, kH, kW, oC] + auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC] + + REQUIRE_TRUE(input->rankOf() == 5, 0, + "CUSTOM DECONV3D_MKLDNN_BP OP: rank of input array must be " + "equal to 5, but got %i instead !", + input->rankOf()); + REQUIRE_TRUE(weights->rankOf() == 5, 0, + "CUSTOM DECONV3D_MKLDNN_BP OP: rank of weights array must be " + "equal to 5 , but got %i instead !", + weights->rankOf()); + REQUIRE_TRUE(gradO->rankOf() == 5, 0, + "CUSTOM DECONV3D_MKLDNN_BP OP: rank of output gradients (next " + "epsilon) array must be equal to 5, but got %i instead !", + gradO->rankOf()); + + int kD = INT_ARG(0) > 0 + ? INT_ARG(0) + : static_cast(weights->sizeAt(0)); // filter(kernel) depth + int kH = INT_ARG(1) > 0 + ? INT_ARG(1) + : static_cast(weights->sizeAt(1)); // filter(kernel) height + int kW = INT_ARG(2) > 0 + ? INT_ARG(2) + : static_cast(weights->sizeAt(2)); // filter(kernel) width + int sD = INT_ARG(3); // strides depth + int sH = INT_ARG(4); // strides height + int sW = INT_ARG(5); // strides width + int pD = INT_ARG(6); // paddings depth + int pH = INT_ARG(7); // paddings height + int pW = INT_ARG(8); // paddings width + int dD = INT_ARG(9); // dilations depth + int dH = INT_ARG(10); // dilations height + int dW = INT_ARG(11); // dilations width + int isSameMode = INT_ARG(12); // 0-SAME, 1-VALID + int isNCDHW = + block.numI() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW + int wFormat = block.numI() > 14 + ? INT_ARG(14) + : 0; // 0 - [kD, kH, kW, oC, iC], 1 - [iC, oC, kD, kH, kW], + // 2 - [iC, kD, kH, kW, oC] + + int bS, iC, iD, iH, iW, oC, oD, oH, + oW; // batch size, input channels, input depth/height/width, output + // channels, output depth/height/width; + int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv3d( + isNCDHW, wFormat, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, + indIOioC, indIOioD, indWoC, indWiC, indWkD); + + int trueoD, trueoH, trueoW; // true output height, width + ConvolutionUtils::calcOutSizeDeconv3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, + sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, + iW, isSameMode); + + std::vector expectedGradOShape = + ShapeUtils::composeShapeUsingDimsAndIdx({bS, oC, trueoD, trueoH, trueoW, + 0, indIOioC, indIOioD, + indIOioD + 1, indIOioD + 2}); + std::vector expectedWeightsShape = + ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, oC, iC); + REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, + "CUSTOM DECONV3D_MKLDNN_BP OP: wrong shape of output gradients " + "(next epsilon) array, expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedGradOShape).c_str(), + ShapeUtils::shapeAsString(gradO).c_str()); + REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, + "CUSTOM DECONV3D_MKLDNN_BP OP: wrong shape of weights array, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), + ShapeUtils::shapeAsString(weights).c_str()); + if (bias) + REQUIRE_TRUE( + bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, + "CUSTOM DECONV3D_MKLDNN_BP OP: wrong shape of array with biases, " + "expected rank, length: <=2, %i, but got %i, %i instead !", + oC, bias->rankOf(), bias->lengthOf()); + + if (isSameMode) // Note: we're intentionally swapping iH and oH, to + // calculated the padding for a"normal" conv (not deconv) + // forward pass + ConvolutionUtils::calcPadding3D(pD, pH, pW, iD, iH, iW, oD, oH, oW, kD, kH, + kW, sD, sH, sW, dD, dH, dW); + + deconv3dBackPropMKLDNN(input, weights, gradO, gradI, gradW, gradB, kD, kH, kW, + sD, sH, sW, pD, pH, pW, dD, dH, dW, isNCDHW, wFormat); + + return Status::OK(); } - PLATFORM_CHECK(deconv3d_bp, ENGINE_CPU) { - auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NHWC) or [bS, iD, iC, iH, iW] (NCDHW) - auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, oC, iC], [iC, oC, kD, kH, kW], [iC, kD, kH, kW, oC] - auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] - auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oD, oH, oW, oC] (NHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next - - auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NHWC) or [bS, iC, iD, iH, iW] (NCDHW), gradI - auto gradW = OUTPUT_VARIABLE(1); // [kD, kH, kW, oC, iC], [iC, oC, kD, kH, kW], [iC, kD, kH, kW, oC] - auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC] - - int dD = INT_ARG(9); // dilations depth - int dH = INT_ARG(10); // dilations height - int dW = INT_ARG(11); // dilations width - int isSameMode = INT_ARG(12); // 0-SAME, 1-VALID - - const DataType xType = input->dataType(); - const DataType wType = weights->dataType(); - const DataType gradOType = gradO->dataType(); - - const DataType gradIType = gradI->dataType(); - const DataType gradWType = gradW->dataType(); - const DataType gradBType = gradB != nullptr ? gradB->dataType() : DataType::FLOAT32; - - return block.isUseMKLDNN() && (dD <= 1 && dH <= 1 && dW <= 1 && !isSameMode) && ((xType==DataType::FLOAT32 || xType==DataType::BFLOAT16) && (wType==DataType::FLOAT32 || wType==DataType::BFLOAT16) && (gradOType==DataType::FLOAT32 || gradOType==DataType::BFLOAT16) && (gradIType==DataType::FLOAT32 || gradIType==DataType::BFLOAT16) && (gradWType==DataType::FLOAT32 || gradWType==DataType::BFLOAT16) && (gradBType==DataType::FLOAT32 || gradBType==DataType::BFLOAT16) ); + auto input = INPUT_VARIABLE( + 0); // [bS, iD, iH, iW, iC] (NHWC) or [bS, iD, iC, iH, iW] (NCDHW) + auto weights = INPUT_VARIABLE( + 1); // [kD, kH, kW, oC, iC], [iC, oC, kD, kH, kW], [iC, kD, kH, kW, oC] + auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] + auto gradO = + block.width() > 3 + ? INPUT_VARIABLE(3) + : INPUT_VARIABLE(2); // [bS, oD, oH, oW, oC] (NHWC) or [bS, oC, oD, + // oH, oW] (NCDHW), epsilon_next + + auto gradI = OUTPUT_VARIABLE( + 0); // [bS, iD, iH, iW, iC] (NHWC) or [bS, iC, iD, iH, iW] (NCDHW), gradI + auto gradW = OUTPUT_VARIABLE( + 1); // [kD, kH, kW, oC, iC], [iC, oC, kD, kH, kW], [iC, kD, kH, kW, oC] + auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC] + + int dD = INT_ARG(9); // dilations depth + int dH = INT_ARG(10); // dilations height + int dW = INT_ARG(11); // dilations width + int isSameMode = INT_ARG(12); // 0-SAME, 1-VALID + + const DataType xType = input->dataType(); + const DataType wType = weights->dataType(); + const DataType gradOType = gradO->dataType(); + + const DataType gradIType = gradI->dataType(); + const DataType gradWType = gradW->dataType(); + const DataType gradBType = + gradB != nullptr ? gradB->dataType() : DataType::FLOAT32; + + return block.isUseMKLDNN() && + (dD <= 1 && dH <= 1 && dW <= 1 && !isSameMode) && + ((xType == DataType::FLOAT32 || xType == DataType::BFLOAT16) && + (wType == DataType::FLOAT32 || wType == DataType::BFLOAT16) && + (gradOType == DataType::FLOAT32 || gradOType == DataType::BFLOAT16) && + (gradIType == DataType::FLOAT32 || gradIType == DataType::BFLOAT16) && + (gradWType == DataType::FLOAT32 || gradWType == DataType::BFLOAT16) && + (gradBType == DataType::FLOAT32 || gradBType == DataType::BFLOAT16)); } -} -} -} +} // namespace platforms +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/depthwiseConv2d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/depthwiseConv2d.cpp index 938494d5a6bf..78a8bb34e3b1 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/depthwiseConv2d.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/depthwiseConv2d.cpp @@ -19,466 +19,659 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include -#include -#include #include +#include +#include #include +#include + #include "mkldnnUtils.h" using namespace dnnl; -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace platforms { ////////////////////////////////////////////////////////////////////////// -static void depthwiseConv2dMKLDNN(const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, - const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, - const int paddingMode, const bool isNCHW, const int wFormat) { - - // mkl supports only following case: mC = 1, oC = iC - - // input [bS, iC, iH, iW] nchw or [bS, iH, iW, iC] nhwc, since mkl doesn't support nhwc format we'll permute when nhwc is given - // weights {iC, mC, 1, kH, kW} - // bias [oC], may be nullptr - // output [bS, oC, oH, oW] nchw or [bS, oH, oW, oC] nhwc - // oC = iC*mC - - int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; - int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); - mC = weights->sizeAt(indWmC); // channels multiplier - - const int pWSame = (paddingMode == 2 && dW > 1) ? ((oW - 1) * sW + (kW - 1) * dW + 1 - iW) / 2 : pW; // dH == 1 for causal mode in conv1d - - dnnl::memory::dims strides = { sH, sW }; - dnnl::memory::dims padding = { pH, pW }; - dnnl::memory::dims padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pWSame }; - dnnl::memory::dims dilation = { dH-1, dW-1}; - - uint i0, i1, i2, i3; - if(0 == wFormat) { - i0 = 2; i1 = 3; i2 = 0; i3 = 1; // [kH, kW, iC, mC] -> [iC, mC, 1, kH, kW] - } - else if(1 == wFormat) { - i0 = 1; i1 = 0; i2 = 2; i3 = 3; // [mC, iC, kH, kW] -> [iC, mC, 1, kH, kW] - } - else { - i0 = 3; i1 = 0; i2 = 1; i3 = 2; // [mC, kH, kW, iC] -> [iC, mC, 1, kH, kW] - } - - // input type - dnnl::memory::data_type xType; - if(input->dataType() == DataType::FLOAT32) - xType = dnnl::memory::data_type::f32; - else if(input->dataType() == DataType::HALF) - xType = dnnl::memory::data_type::f16; - else if(input->dataType() == DataType::UINT8) - xType = dnnl::memory::data_type::u8; - else - xType = dnnl::memory::data_type::s8; - - // weights type - dnnl::memory::data_type wType = xType; - if(xType == dnnl::memory::data_type::u8) - wType = dnnl::memory::data_type::s8; - - // output and bias type (have the same types) - dnnl::memory::data_type zType; - if(output->dataType() == DataType::FLOAT32) - zType = dnnl::memory::data_type::f32; - else if(output->dataType() == DataType::HALF) - zType = dnnl::memory::data_type::f16; - else if(output->dataType() == DataType::UINT8) - zType = dnnl::memory::data_type::u8; - else if(output->dataType() == DataType::INT8) - zType = dnnl::memory::data_type::s8; - else - zType = dnnl::memory::data_type::s32; - - dnnl::memory::format_tag xzFormatMkl = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; - dnnl::memory::format_tag wFormatMkl = dnnl::memory::format_tag::goihw; - - dnnl::memory::dims xDims = {bS, iC, iH, iW}; - dnnl::memory::dims wDims = {iC, mC, 1, kH, kW}; - dnnl::memory::dims zDims = {bS, oC, oH, oW}; - - // memory descriptors for arrays - - // input - dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any); - dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xzFormatMkl); - mkldnnUtils::setBlockStrides(*input, x_user_md); - - // weights - dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any); - dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormatMkl); - w_user_md.data.format_kind = dnnl_blocked; // overrides format - w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(i0); // permute - w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(i1); - w_user_md.data.format_desc.blocking.strides[2] = 0; - w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(i2); - w_user_md.data.format_desc.blocking.strides[4] = weights->strideAt(i3); - - // bias - dnnl::memory::desc b_mkl_md; - if(bias != nullptr) - b_mkl_md = dnnl::memory::desc({oC}, zType, dnnl::memory::format_tag::x); - - // output - dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, zType, dnnl::memory::format_tag::any); - dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, zType, xzFormatMkl); - mkldnnUtils::setBlockStrides(*output, z_user_md); - - auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - - // operation primitive description - dnnl::convolution_forward::desc op_desc(dnnl::prop_kind::forward_inference, dnnl::algorithm::convolution_auto, - x_mkl_md, w_mkl_md, b_mkl_md, z_mkl_md, strides, dilation, padding, padding_r); - dnnl::convolution_forward::primitive_desc op_prim_desc(op_desc, engine); - - // arguments (memory buffers) necessary for calculations - std::unordered_map args; - - dnnl::stream stream(engine); - - // provide memory buffers and check whether reorder is required - - // input - mkldnnUtils::loadDataToMklStream(*input, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]); - - // weights - mkldnnUtils::loadDataToMklStream(*weights, engine, stream, w_user_md, op_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]); - - // bias - if(bias != nullptr) { - auto b_mkl_mem = dnnl::memory(b_mkl_md, engine, const_cast(bias->buffer())); - args[DNNL_ARG_BIAS] = b_mkl_mem; - } - - // output - auto z_user_mem = mkldnnUtils::loadDataToMklStream(*output, engine, stream, z_user_md, op_prim_desc.dst_desc(), args[DNNL_ARG_DST]); - - // run calculations - dnnl::convolution_forward(op_prim_desc).execute(stream, args); - - // reorder outputs if necessary - if (op_prim_desc.dst_desc() != z_user_mem.get_desc()) - dnnl::reorder(args[DNNL_ARG_DST], z_user_mem).execute(stream, args[DNNL_ARG_DST], z_user_mem); - - stream.wait(); - // shape::printArray(z_mkl_mem.map_data(),8); +static void depthwiseConv2dMKLDNN(const NDArray* input, const NDArray* weights, + const NDArray* bias, NDArray* output, + const int kH, const int kW, const int sH, + const int sW, const int pH, const int pW, + const int dH, const int dW, + const int paddingMode, const bool isNCHW, + const int wFormat) { + // mkl supports only following case: mC = 1, oC = iC + + // input [bS, iC, iH, iW] nchw or [bS, iH, iW, iC] nhwc, since mkl doesn't + // support nhwc format we'll permute when nhwc is given weights {iC, mC, 1, + // kH, kW} bias [oC], may be nullptr output [bS, oC, oH, oW] nchw or [bS, oH, + // oW, oC] nhwc oC = iC*mC + + int bS, iC, iH, iW, mC, oC, oH, + oW; // batch size, input channels, input height/width, output channels, + // output height/width; + int indIOioC, indIiH, indWmC, indWiC, indWkH, + indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, + indIiH, indWiC, indWmC, indWkH, indOoH); + mC = weights->sizeAt(indWmC); // channels multiplier + + const int pWSame = (paddingMode == 2 && dW > 1) + ? ((oW - 1) * sW + (kW - 1) * dW + 1 - iW) / 2 + : pW; // dH == 1 for causal mode in conv1d + + dnnl::memory::dims strides = {sH, sW}; + dnnl::memory::dims padding = {pH, pW}; + dnnl::memory::dims padding_r = {(oH - 1) * sH - iH + kH - pH, + (oW - 1) * sW - iW + kW - pWSame}; + dnnl::memory::dims dilation = {dH - 1, dW - 1}; + + uint i0, i1, i2, i3; + if (0 == wFormat) { + i0 = 2; + i1 = 3; + i2 = 0; + i3 = 1; // [kH, kW, iC, mC] -> [iC, mC, 1, kH, kW] + } else if (1 == wFormat) { + i0 = 1; + i1 = 0; + i2 = 2; + i3 = 3; // [mC, iC, kH, kW] -> [iC, mC, 1, kH, kW] + } else { + i0 = 3; + i1 = 0; + i2 = 1; + i3 = 2; // [mC, kH, kW, iC] -> [iC, mC, 1, kH, kW] + } + + // input type + dnnl::memory::data_type xType; + if (input->dataType() == DataType::FLOAT32) + xType = dnnl::memory::data_type::f32; + else if (input->dataType() == DataType::HALF) + xType = dnnl::memory::data_type::f16; + else if (input->dataType() == DataType::UINT8) + xType = dnnl::memory::data_type::u8; + else + xType = dnnl::memory::data_type::s8; + + // weights type + dnnl::memory::data_type wType = xType; + if (xType == dnnl::memory::data_type::u8) wType = dnnl::memory::data_type::s8; + + // output and bias type (have the same types) + dnnl::memory::data_type zType; + if (output->dataType() == DataType::FLOAT32) + zType = dnnl::memory::data_type::f32; + else if (output->dataType() == DataType::HALF) + zType = dnnl::memory::data_type::f16; + else if (output->dataType() == DataType::UINT8) + zType = dnnl::memory::data_type::u8; + else if (output->dataType() == DataType::INT8) + zType = dnnl::memory::data_type::s8; + else + zType = dnnl::memory::data_type::s32; + + dnnl::memory::format_tag xzFormatMkl = + isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; + dnnl::memory::format_tag wFormatMkl = dnnl::memory::format_tag::goihw; + + dnnl::memory::dims xDims = {bS, iC, iH, iW}; + dnnl::memory::dims wDims = {iC, mC, 1, kH, kW}; + dnnl::memory::dims zDims = {bS, oC, oH, oW}; + + // memory descriptors for arrays + + // input + dnnl::memory::desc x_mkl_md = + dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any); + dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xzFormatMkl); + mkldnnUtils::setBlockStrides(*input, x_user_md); + + // weights + dnnl::memory::desc w_mkl_md = + dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any); + dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormatMkl); + w_user_md.data.format_kind = dnnl_blocked; // overrides format + w_user_md.data.format_desc.blocking.strides[0] = + weights->strideAt(i0); // permute + w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(i1); + w_user_md.data.format_desc.blocking.strides[2] = 0; + w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(i2); + w_user_md.data.format_desc.blocking.strides[4] = weights->strideAt(i3); + + // bias + dnnl::memory::desc b_mkl_md; + if (bias != nullptr) + b_mkl_md = dnnl::memory::desc({oC}, zType, dnnl::memory::format_tag::x); + + // output + dnnl::memory::desc z_mkl_md = + dnnl::memory::desc(zDims, zType, dnnl::memory::format_tag::any); + dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, zType, xzFormatMkl); + mkldnnUtils::setBlockStrides(*output, z_user_md); + + auto engine = + mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); + + // operation primitive description + dnnl::convolution_forward::desc op_desc( + dnnl::prop_kind::forward_inference, dnnl::algorithm::convolution_auto, + x_mkl_md, w_mkl_md, b_mkl_md, z_mkl_md, strides, dilation, padding, + padding_r); + dnnl::convolution_forward::primitive_desc op_prim_desc(op_desc, engine); + + // arguments (memory buffers) necessary for calculations + std::unordered_map args; + + dnnl::stream stream(engine); + + // provide memory buffers and check whether reorder is required + + // input + mkldnnUtils::loadDataToMklStream(*input, engine, stream, x_user_md, + op_prim_desc.src_desc(), args[DNNL_ARG_SRC]); + + // weights + mkldnnUtils::loadDataToMklStream(*weights, engine, stream, w_user_md, + op_prim_desc.weights_desc(), + args[DNNL_ARG_WEIGHTS]); + + // bias + if (bias != nullptr) { + auto b_mkl_mem = + dnnl::memory(b_mkl_md, engine, const_cast(bias->buffer())); + args[DNNL_ARG_BIAS] = b_mkl_mem; + } + + // output + auto z_user_mem = mkldnnUtils::loadDataToMklStream(*output, engine, stream, z_user_md, op_prim_desc.dst_desc(), + args[DNNL_ARG_DST] ); + + // run calculations + dnnl::convolution_forward(op_prim_desc).execute(stream, args); + + // reorder outputs if necessary + if (op_prim_desc.dst_desc() != z_user_mem.get_desc()) + dnnl::reorder(args[DNNL_ARG_DST], z_user_mem).execute(stream, args[DNNL_ARG_DST], z_user_mem); + + stream.wait(); + // shape::printArray(z_mkl_mem.map_data(),8); } ////////////////////////////////////////////////////////////////////////// -static void depthwiseConv2dBpMKLDNN(const NDArray* input, const NDArray* weights, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, - const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, - const int paddingMode, const bool isNCHW, const int wFormat) { - - // mkl supports only following case: mC = 1, oC = iC - - // input, gradI [bS, iC, iH, iW] nchw or [bS, iH, iW, iC] nhwc, since mkl doesn't support nhwc format we'll permute when nhwc is given - // weights/gradW {iC, mC, 1, kH, kW} - // gradB [oC], may be nullptr - // gradO [bS, oC, oH, oW] nchw or [bS, oH, oW, oC] nhwc - // oC = iC*mC - - int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; - int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); - mC = weights->sizeAt(indWmC); - - const int pWSame = (paddingMode == 2 && dW > 1) ? ((oW - 1) * sW + (kW - 1) * dW + 1 - iW) / 2 : pW; // dH == 1 for causal mode in conv1d - - dnnl::memory::dims strides = { sH, sW }; - dnnl::memory::dims padding = { pH, pW }; - dnnl::memory::dims padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pWSame }; - dnnl::memory::dims dilation = { dH-1, dW-1}; - - uint i0, i1, i2, i3; - if(0 == wFormat) { - i0 = 2; i1 = 3; i2 = 0; i3 = 1; // [kH, kW, iC, mC] -> [iC, mC, 1, kH, kW] - } - else if(1 == wFormat) { - i0 = 1; i1 = 0; i2 = 2; i3 = 3; // [mC, iC, kH, kW] -> [iC, mC, 1, kH, kW] - } - else { - i0 = 3; i1 = 0; i2 = 1; i3 = 2; // [mC, kH, kW, iC] -> [iC, mC, 1, kH, kW] - } - - // input type - dnnl::memory::data_type xType = input->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16; - // weights type - dnnl::memory::data_type wType = weights->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16; - // gradO type - dnnl::memory::data_type gradOType = gradO->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16; - // gradI type - dnnl::memory::data_type gradIType = gradI->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16; - // gradW type - dnnl::memory::data_type gradWType = gradW->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16; - // gradB type - dnnl::memory::data_type gradBType = gradB != nullptr ? (gradB->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16) : dnnl::memory::data_type::f32; - - dnnl::memory::format_tag xzFormatMkl = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; - dnnl::memory::format_tag wFormatMkl = dnnl::memory::format_tag::goihw; - - dnnl::memory::dims xDims = {bS, iC, iH, iW}; - dnnl::memory::dims wDims = {iC, mC, 1, kH, kW}; - dnnl::memory::dims zDims = {bS, oC, oH, oW}; - - // memory descriptors for arrays - - // input - dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any); - dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xzFormatMkl); - mkldnnUtils::setBlockStrides(*input, x_user_md); - - // weights - dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any); - dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormatMkl); - w_user_md.data.format_kind = dnnl_blocked; // overrides format - w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(i0); // permute - w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(i1); - w_user_md.data.format_desc.blocking.strides[2] = 0; - w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(i2); - w_user_md.data.format_desc.blocking.strides[4] = weights->strideAt(i3); - - // gradO - dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any); - dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xzFormatMkl); - mkldnnUtils::setBlockStrides(*gradO, gradO_user_md); - - // gradI - dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any); - dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xzFormatMkl); - mkldnnUtils::setBlockStrides(*gradI, gradI_user_md); - - // gradW - dnnl::memory::desc gradW_mkl_md = dnnl::memory::desc(wDims, gradWType, dnnl::memory::format_tag::any); - dnnl::memory::desc gradW_user_md = dnnl::memory::desc(wDims, gradWType, wFormatMkl); - gradW_user_md.data.format_kind = dnnl_blocked; // overrides format - gradW_user_md.data.format_desc.blocking.strides[0] = gradW->strideAt(i0); // permute - gradW_user_md.data.format_desc.blocking.strides[1] = gradW->strideAt(i1); - gradW_user_md.data.format_desc.blocking.strides[2] = 0; - gradW_user_md.data.format_desc.blocking.strides[3] = gradW->strideAt(i2); - gradW_user_md.data.format_desc.blocking.strides[4] = gradW->strideAt(i3); - - // gradB - dnnl::memory::desc gradB_mkl_md; - if(gradB != nullptr) - gradB_mkl_md = dnnl::memory::desc({oC}, gradBType, dnnl::memory::format_tag::x); - - auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - - // forward primitive description - dnnl::convolution_forward::desc op_ff_desc(dnnl::prop_kind::forward_inference, dnnl::algorithm::convolution_auto, x_mkl_md, w_mkl_md, gradB_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r); - dnnl::convolution_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine); - - // backward data primitive description - dnnl::convolution_backward_data::desc op_data_bp_desc(dnnl::algorithm::convolution_auto, gradI_mkl_md, w_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r); - dnnl::convolution_backward_data::primitive_desc op_data_bp_prim_desc(op_data_bp_desc, engine, op_ff_prim_desc); - - // backward weights primitive description - dnnl::convolution_backward_weights::desc op_weights_bp_desc(dnnl::algorithm::convolution_auto, x_mkl_md, gradW_mkl_md, gradB_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r); - dnnl::convolution_backward_weights::primitive_desc op_weights_bp_prim_desc(op_weights_bp_desc, engine, op_ff_prim_desc); - - // arguments (memory buffers) necessary for calculations - std::unordered_map args; - - dnnl::stream stream(engine); - - // provide memory buffers and check whether reorder is required - - // input - mkldnnUtils::loadDataToMklStream(*input, engine, stream, x_user_md, op_weights_bp_prim_desc.src_desc(), args[DNNL_ARG_SRC]); - - // weights - mkldnnUtils::loadDataToMklStream(*weights, engine, stream, w_user_md, op_data_bp_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]); - - // gradO - auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, const_cast(gradO->buffer())); - const bool gradOReorderW = op_weights_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc(); - const bool gradOReorderD = op_data_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc(); - auto gradO_mkl_memW = gradOReorderW ? dnnl::memory(op_weights_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem; - auto gradO_mkl_memD = gradOReorderD ? dnnl::memory(op_data_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem; - if (gradOReorderW) - dnnl::reorder(gradO_user_mem, gradO_mkl_memW).execute(stream, gradO_user_mem, gradO_mkl_memW); - if (gradOReorderD) - dnnl::reorder(gradO_user_mem, gradO_mkl_memD).execute(stream, gradO_user_mem, gradO_mkl_memD); - args[DNNL_ARG_DIFF_DST] = gradO_mkl_memD; - - // gradI - auto gradI_user_mem = mkldnnUtils::loadDataToMklStream(*gradI, engine, stream, gradI_user_md, op_data_bp_prim_desc.diff_src_desc(), args[DNNL_ARG_DIFF_SRC]); - - // gradW - auto gradW_user_mem = mkldnnUtils::loadDataToMklStream(*gradW, engine, stream, gradW_user_md, op_weights_bp_prim_desc.diff_weights_desc(), args[DNNL_ARG_DIFF_WEIGHTS]); - - // gradB - if(gradB != nullptr) { - auto gradB_mkl_mem = dnnl::memory(gradB_mkl_md, engine, gradB->buffer()); - args[DNNL_ARG_DIFF_BIAS] = gradB_mkl_mem; - } - - // run backward data calculations - dnnl::convolution_backward_data(op_data_bp_prim_desc).execute(stream, args); - - if(gradOReorderW || gradOReorderD) - args[DNNL_ARG_DIFF_DST] = gradO_mkl_memW; - - // run backward weights calculations - dnnl::convolution_backward_weights(op_weights_bp_prim_desc).execute(stream, args); - - // reorder gradI if necessary - if (op_data_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc()) - dnnl::reorder(args[DNNL_ARG_DIFF_SRC], gradI_user_mem).execute(stream, args[DNNL_ARG_DIFF_SRC], gradI_user_mem); - if (op_weights_bp_prim_desc.diff_weights_desc() != gradW_user_mem.get_desc()) - dnnl::reorder(args[DNNL_ARG_DIFF_WEIGHTS], gradW_user_mem).execute(stream, args[DNNL_ARG_DIFF_WEIGHTS], gradW_user_mem); - - stream.wait(); - - // shape::printArray(z_mkl_mem.map_data(),8); +static void depthwiseConv2dBpMKLDNN( + const NDArray* input, const NDArray* weights, const NDArray* gradO, + NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, + const int sH, const int sW, const int pH, const int pW, const int dH, + const int dW, const int paddingMode, const bool isNCHW, const int wFormat) { + // mkl supports only following case: mC = 1, oC = iC + + // input, gradI [bS, iC, iH, iW] nchw or [bS, iH, iW, iC] nhwc, since mkl + // doesn't support nhwc format we'll permute when nhwc is given weights/gradW + // {iC, mC, 1, kH, kW} gradB [oC], may be nullptr gradO [bS, oC, oH, oW] nchw + // or [bS, oH, oW, oC] nhwc oC = iC*mC + + int bS, iC, iH, iW, mC, oC, oH, + oW; // batch size, input channels, input height/width, output channels, + // output height/width; + int indIOioC, indIiH, indWmC, indWiC, indWkH, + indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, + indIiH, indWiC, indWmC, indWkH, indOoH); + mC = weights->sizeAt(indWmC); + + const int pWSame = (paddingMode == 2 && dW > 1) + ? ((oW - 1) * sW + (kW - 1) * dW + 1 - iW) / 2 + : pW; // dH == 1 for causal mode in conv1d + + dnnl::memory::dims strides = {sH, sW}; + dnnl::memory::dims padding = {pH, pW}; + dnnl::memory::dims padding_r = {(oH - 1) * sH - iH + kH - pH, + (oW - 1) * sW - iW + kW - pWSame}; + dnnl::memory::dims dilation = {dH - 1, dW - 1}; + + uint i0, i1, i2, i3; + if (0 == wFormat) { + i0 = 2; + i1 = 3; + i2 = 0; + i3 = 1; // [kH, kW, iC, mC] -> [iC, mC, 1, kH, kW] + } else if (1 == wFormat) { + i0 = 1; + i1 = 0; + i2 = 2; + i3 = 3; // [mC, iC, kH, kW] -> [iC, mC, 1, kH, kW] + } else { + i0 = 3; + i1 = 0; + i2 = 1; + i3 = 2; // [mC, kH, kW, iC] -> [iC, mC, 1, kH, kW] + } + + // input type + dnnl::memory::data_type xType = input->dataType() == DataType::FLOAT32 + ? dnnl::memory::data_type::f32 + : dnnl::memory::data_type::bf16; + // weights type + dnnl::memory::data_type wType = weights->dataType() == DataType::FLOAT32 + ? dnnl::memory::data_type::f32 + : dnnl::memory::data_type::bf16; + // gradO type + dnnl::memory::data_type gradOType = gradO->dataType() == DataType::FLOAT32 + ? dnnl::memory::data_type::f32 + : dnnl::memory::data_type::bf16; + // gradI type + dnnl::memory::data_type gradIType = gradI->dataType() == DataType::FLOAT32 + ? dnnl::memory::data_type::f32 + : dnnl::memory::data_type::bf16; + // gradW type + dnnl::memory::data_type gradWType = gradW->dataType() == DataType::FLOAT32 + ? dnnl::memory::data_type::f32 + : dnnl::memory::data_type::bf16; + // gradB type + dnnl::memory::data_type gradBType = + gradB != nullptr ? (gradB->dataType() == DataType::FLOAT32 + ? dnnl::memory::data_type::f32 + : dnnl::memory::data_type::bf16) + : dnnl::memory::data_type::f32; + + dnnl::memory::format_tag xzFormatMkl = + isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; + dnnl::memory::format_tag wFormatMkl = dnnl::memory::format_tag::goihw; + + dnnl::memory::dims xDims = {bS, iC, iH, iW}; + dnnl::memory::dims wDims = {iC, mC, 1, kH, kW}; + dnnl::memory::dims zDims = {bS, oC, oH, oW}; + + // memory descriptors for arrays + + // input + dnnl::memory::desc x_mkl_md = + dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any); + dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xzFormatMkl); + mkldnnUtils::setBlockStrides(*input, x_user_md); + + // weights + dnnl::memory::desc w_mkl_md = + dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any); + dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormatMkl); + w_user_md.data.format_kind = dnnl_blocked; // overrides format + w_user_md.data.format_desc.blocking.strides[0] = + weights->strideAt(i0); // permute + w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(i1); + w_user_md.data.format_desc.blocking.strides[2] = 0; + w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(i2); + w_user_md.data.format_desc.blocking.strides[4] = weights->strideAt(i3); + + // gradO + dnnl::memory::desc gradO_mkl_md = + dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any); + dnnl::memory::desc gradO_user_md = + dnnl::memory::desc(zDims, gradOType, xzFormatMkl); + mkldnnUtils::setBlockStrides(*gradO, gradO_user_md); + + // gradI + dnnl::memory::desc gradI_mkl_md = + dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any); + dnnl::memory::desc gradI_user_md = + dnnl::memory::desc(xDims, gradIType, xzFormatMkl); + mkldnnUtils::setBlockStrides(*gradI, gradI_user_md); + + // gradW + dnnl::memory::desc gradW_mkl_md = + dnnl::memory::desc(wDims, gradWType, dnnl::memory::format_tag::any); + dnnl::memory::desc gradW_user_md = + dnnl::memory::desc(wDims, gradWType, wFormatMkl); + gradW_user_md.data.format_kind = dnnl_blocked; // overrides format + gradW_user_md.data.format_desc.blocking.strides[0] = + gradW->strideAt(i0); // permute + gradW_user_md.data.format_desc.blocking.strides[1] = gradW->strideAt(i1); + gradW_user_md.data.format_desc.blocking.strides[2] = 0; + gradW_user_md.data.format_desc.blocking.strides[3] = gradW->strideAt(i2); + gradW_user_md.data.format_desc.blocking.strides[4] = gradW->strideAt(i3); + + // gradB + dnnl::memory::desc gradB_mkl_md; + if (gradB != nullptr) + gradB_mkl_md = + dnnl::memory::desc({oC}, gradBType, dnnl::memory::format_tag::x); + + auto engine = + mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); + + // forward primitive description + dnnl::convolution_forward::desc op_ff_desc( + dnnl::prop_kind::forward_inference, dnnl::algorithm::convolution_auto, + x_mkl_md, w_mkl_md, gradB_mkl_md, gradO_mkl_md, strides, dilation, + padding, padding_r); + dnnl::convolution_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine); + + // backward data primitive description + dnnl::convolution_backward_data::desc op_data_bp_desc( + dnnl::algorithm::convolution_auto, gradI_mkl_md, w_mkl_md, gradO_mkl_md, + strides, dilation, padding, padding_r); + dnnl::convolution_backward_data::primitive_desc op_data_bp_prim_desc( + op_data_bp_desc, engine, op_ff_prim_desc); + + // backward weights primitive description + dnnl::convolution_backward_weights::desc op_weights_bp_desc( + dnnl::algorithm::convolution_auto, x_mkl_md, gradW_mkl_md, gradB_mkl_md, + gradO_mkl_md, strides, dilation, padding, padding_r); + dnnl::convolution_backward_weights::primitive_desc op_weights_bp_prim_desc( + op_weights_bp_desc, engine, op_ff_prim_desc); + + // arguments (memory buffers) necessary for calculations + std::unordered_map args; + + dnnl::stream stream(engine); + + // provide memory buffers and check whether reorder is required + + // input + mkldnnUtils::loadDataToMklStream(*input, engine, stream, x_user_md, + op_weights_bp_prim_desc.src_desc(), + args[DNNL_ARG_SRC]); + + // weights + mkldnnUtils::loadDataToMklStream(*weights, engine, stream, w_user_md, + op_data_bp_prim_desc.weights_desc(), + args[DNNL_ARG_WEIGHTS]); + + // gradO + auto gradO_user_mem = + dnnl::memory(gradO_user_md, engine, const_cast(gradO->buffer())); + const bool gradOReorderW = + op_weights_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc(); + const bool gradOReorderD = + op_data_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc(); + auto gradO_mkl_memW = + gradOReorderW + ? dnnl::memory(op_weights_bp_prim_desc.diff_dst_desc(), engine) + : gradO_user_mem; + auto gradO_mkl_memD = + gradOReorderD ? dnnl::memory(op_data_bp_prim_desc.diff_dst_desc(), engine) + : gradO_user_mem; + if (gradOReorderW) + dnnl::reorder(gradO_user_mem, gradO_mkl_memW) + .execute(stream, gradO_user_mem, gradO_mkl_memW); + if (gradOReorderD) + dnnl::reorder(gradO_user_mem, gradO_mkl_memD) + .execute(stream, gradO_user_mem, gradO_mkl_memD); + args[DNNL_ARG_DIFF_DST] = gradO_mkl_memD; + + // gradI + auto gradI_user_mem = mkldnnUtils::loadDataToMklStream(*gradI, engine, stream, gradI_user_md, + op_data_bp_prim_desc.diff_src_desc() , + args[DNNL_ARG_DIFF_SRC] ); + + // gradW + auto gradW_user_mem = mkldnnUtils::loadDataToMklStream(*gradW, engine, stream, gradW_user_md, + op_weights_bp_prim_desc.diff_weights_desc() , + args[DNNL_ARG_DIFF_WEIGHTS] ); + + // gradB + if (gradB != nullptr) { + auto gradB_mkl_mem = dnnl::memory(gradB_mkl_md, engine, gradB->buffer()); + args[DNNL_ARG_DIFF_BIAS] = gradB_mkl_mem; + } + + // run backward data calculations + dnnl::convolution_backward_data(op_data_bp_prim_desc).execute(stream, args); + + if (gradOReorderW || gradOReorderD) args[DNNL_ARG_DIFF_DST] = gradO_mkl_memW; + + // run backward weights calculations + dnnl::convolution_backward_weights(op_weights_bp_prim_desc) + .execute(stream, args); + + // reorder gradI if necessary + if (op_data_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc()) + dnnl::reorder(args[DNNL_ARG_DIFF_SRC], gradI_user_mem) + .execute(stream, args[DNNL_ARG_DIFF_SRC], gradI_user_mem); + if (op_weights_bp_prim_desc.diff_weights_desc() != gradW_user_mem.get_desc()) + dnnl::reorder(args[DNNL_ARG_DIFF_WEIGHTS], gradW_user_mem) + .execute(stream, args[DNNL_ARG_DIFF_WEIGHTS], gradW_user_mem); + + stream.wait(); + + // shape::printArray(z_mkl_mem.map_data(),8); } - ////////////////////////////////////////////////////////////////////// PLATFORM_IMPL(depthwise_conv2d, ENGINE_CPU) { - - auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] - auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] = iC*mC - - auto output = OUTPUT_VARIABLE(0); // [bS, oH, oW, iC*mC] (NHWC) or [bS, iC*mC, oH, oW] (NCHW) - - int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(weights->sizeAt(0));// filter(kernel) height - int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast(weights->sizeAt(1));// filter(kernel) width - int sH = INT_ARG(2); // strides height - int sW = INT_ARG(3); // strides width - int pH = INT_ARG(4); // paddings height - int pW = INT_ARG(5); // paddings width - int dH = INT_ARG(6); // dilations height - int dW = INT_ARG(7); // dilations width - int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME - int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC - int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - [mC, kH, kW, iC] - - int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width - int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); - mC = weights->sizeAt(indWmC); // channels multiplier - - ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode); - - std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, mC); - REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DEPTHWISECONV2D MKL OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); - REQUIRE_TRUE(output->sizeAt(indIOioC) == iC*mC, 0, "CUSTOM DEPTHWISECONV2D MKL OP: the output_channels must be equal to input_channels * channels_multiplier = %i !", iC*mC); - if (bias) - REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DEPTHWISECONV2D MKL OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); - - depthwiseConv2dMKLDNN(input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat); - - return Status::OK(); + auto input = + INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + auto weights = INPUT_VARIABLE( + 1); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] + auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] = iC*mC + + auto output = OUTPUT_VARIABLE( + 0); // [bS, oH, oW, iC*mC] (NHWC) or [bS, iC*mC, oH, oW] (NCHW) + + int kH = INT_ARG(0) > 0 + ? INT_ARG(0) + : static_cast(weights->sizeAt(0)); // filter(kernel) height + int kW = INT_ARG(1) > 0 + ? INT_ARG(1) + : static_cast(weights->sizeAt(1)); // filter(kernel) width + int sH = INT_ARG(2); // strides height + int sW = INT_ARG(3); // strides width + int pH = INT_ARG(4); // paddings height + int pW = INT_ARG(5); // paddings width + int dH = INT_ARG(6); // dilations height + int dW = INT_ARG(7); // dilations width + int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME + int isNCHW = + block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC + int wFormat = + block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, + // iC, kH, kW], 2 - [mC, kH, kW, iC] + + int bS, iC, iH, iW, mC, oC, oH, + oW; // batch size, input channels, input height/width, channels + // multiplier(oC = iC*mC), output channels, output height/width + int indIOioC, indIiH, indWmC, indWiC, indWkH, + indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, + indIiH, indWiC, indWmC, indWkH, indOoH); + mC = weights->sizeAt(indWmC); // channels multiplier + + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, + dW, paddingMode); + + std::vector expectedWeightsShape = + ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, mC); + REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, + "CUSTOM DEPTHWISECONV2D MKL OP: wrong shape of weights array, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), + ShapeUtils::shapeAsString(weights).c_str()); + REQUIRE_TRUE(output->sizeAt(indIOioC) == iC * mC, 0, + "CUSTOM DEPTHWISECONV2D MKL OP: the output_channels must be " + "equal to input_channels * channels_multiplier = %i !", + iC * mC); + if (bias) + REQUIRE_TRUE( + bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, + "CUSTOM DEPTHWISECONV2D MKL OP: wrong shape of array with biases, " + "expected rank, length: <=2, %i, but got %i, %i instead !", + oC, bias->rankOf(), bias->lengthOf()); + + depthwiseConv2dMKLDNN(input, weights, bias, output, kH, kW, sH, sW, pH, pW, + dH, dW, paddingMode, isNCHW, wFormat); + + return Status::OK(); } ////////////////////////////////////////////////////////////////////// PLATFORM_CHECK(depthwise_conv2d, ENGINE_CPU) { - - auto input = INPUT_VARIABLE(0); - auto weights = INPUT_VARIABLE(1); - auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; - - auto output = INPUT_VARIABLE(0); - - const DataType xType = input->dataType(); - const DataType wType = weights->dataType(); - const DataType zType = output->dataType(); - const DataType bType = bias != nullptr ? bias->dataType() : zType; - - const int mC = weights->sizeAt(3); - - return block.isUseMKLDNN() && mC == 1 && - ( - (xType==DataType::FLOAT32 && wType==DataType::FLOAT32 && bType==DataType::FLOAT32 && zType==DataType::FLOAT32) || - (xType==DataType::BFLOAT16 && wType==DataType::BFLOAT16 && bType==DataType::BFLOAT16 && zType==DataType::BFLOAT16) || - ((xType==DataType::UINT8 || xType==DataType::INT8) && wType==DataType::INT8 && (zType==DataType::UINT8 || zType==DataType::INT8 || zType==DataType::INT32 || zType==DataType::FLOAT32) && bType == zType) - ); + auto input = INPUT_VARIABLE(0); + auto weights = INPUT_VARIABLE(1); + auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; + + auto output = INPUT_VARIABLE(0); + + const DataType xType = input->dataType(); + const DataType wType = weights->dataType(); + const DataType zType = output->dataType(); + const DataType bType = bias != nullptr ? bias->dataType() : zType; + + const int mC = weights->sizeAt(3); + + return block.isUseMKLDNN() && mC == 1 && + ((xType == DataType::FLOAT32 && wType == DataType::FLOAT32 && + bType == DataType::FLOAT32 && zType == DataType::FLOAT32) || + (xType == DataType::BFLOAT16 && wType == DataType::BFLOAT16 && + bType == DataType::BFLOAT16 && zType == DataType::BFLOAT16) || + ((xType == DataType::UINT8 || xType == DataType::INT8) && + wType == DataType::INT8 && + (zType == DataType::UINT8 || zType == DataType::INT8 || + zType == DataType::INT32 || zType == DataType::FLOAT32) && + bType == zType)); } ////////////////////////////////////////////////////////////////////////// PLATFORM_IMPL(depthwise_conv2d_bp, ENGINE_CPU) { - - auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW) - auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] - auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] = [iC*mC] - auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NDHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next - - auto gradI = OUTPUT_NULLIFIED(0); // [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW), epsilon - auto gradW = OUTPUT_NULLIFIED(1); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] - auto gradB = block.width() > 3 ? OUTPUT_NULLIFIED(2) : nullptr; // [oC] - - REQUIRE_TRUE(input->rankOf() == 4, 0, "CUSTOM DEPTHWISECONV2D_BP MKL OP: rank of input array must be equal to 4, but got %i instead !", input->rankOf()); - REQUIRE_TRUE(weights->rankOf() == 4, 0, "CUSTOM DEPTHWISECONV2D_BP MKL OP: rank of weights array must be equal to 4, but got %i instead !", weights->rankOf()); - REQUIRE_TRUE(gradO->rankOf() == 4, 0, "CUSTOM DEPTHWISECONV2D_BP MKL OP: rank of output gradients (next epsilon) array must be equal to 4, but got %i instead !", gradO->rankOf()); - - int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(weights->sizeAt(0));// filter(kernel) height - int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast(weights->sizeAt(1));// filter(kernel) width - int sH = INT_ARG(2); // strides height - int sW = INT_ARG(3); // strides width - int pH = INT_ARG(4); // paddings height - int pW = INT_ARG(5); // paddings width - int dH = INT_ARG(6); // dilations height - int dW = INT_ARG(7); // dilations width - int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME - int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW - int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - [mC, kH, kW, iC] - - int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width - int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); - mC = weights->sizeAt(indWmC); // channels multiplier - - int trueoH, trueoW; // correct output height, width - ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, paddingMode); - - ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode); - - std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1}); - std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, mC); - REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM DEPTHWISECONV2D_BP MKL OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); - REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DEPTHWISECONV2D_BP MKL OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); - if(bias) - REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DEPTHWISECONV2D_BP MKL OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); - - depthwiseConv2dBpMKLDNN(input, weights, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat); - - return Status::OK(); + auto input = INPUT_VARIABLE( + 0); // [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW) + auto weights = INPUT_VARIABLE( + 1); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] + auto bias = + block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] = [iC*mC] + auto gradO = block.width() > 3 + ? INPUT_VARIABLE(3) + : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NDHWC) or [bS, oC, + // oH, oW] (NCDHW), epsilon_next + + auto gradI = OUTPUT_NULLIFIED( + 0); // [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW), epsilon + auto gradW = OUTPUT_NULLIFIED( + 1); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] + auto gradB = block.width() > 3 ? OUTPUT_NULLIFIED(2) : nullptr; // [oC] + + REQUIRE_TRUE(input->rankOf() == 4, 0, + "CUSTOM DEPTHWISECONV2D_BP MKL OP: rank of input array must be " + "equal to 4, but got %i instead !", + input->rankOf()); + REQUIRE_TRUE(weights->rankOf() == 4, 0, + "CUSTOM DEPTHWISECONV2D_BP MKL OP: rank of weights array must " + "be equal to 4, but got %i instead !", + weights->rankOf()); + REQUIRE_TRUE(gradO->rankOf() == 4, 0, + "CUSTOM DEPTHWISECONV2D_BP MKL OP: rank of output gradients " + "(next epsilon) array must be equal to 4, but got %i instead !", + gradO->rankOf()); + + int kH = INT_ARG(0) > 0 + ? INT_ARG(0) + : static_cast(weights->sizeAt(0)); // filter(kernel) height + int kW = INT_ARG(1) > 0 + ? INT_ARG(1) + : static_cast(weights->sizeAt(1)); // filter(kernel) width + int sH = INT_ARG(2); // strides height + int sW = INT_ARG(3); // strides width + int pH = INT_ARG(4); // paddings height + int pW = INT_ARG(5); // paddings width + int dH = INT_ARG(6); // dilations height + int dW = INT_ARG(7); // dilations width + int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME + int isNCHW = + block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW + int wFormat = + block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, + // iC, kH, kW], 2 - [mC, kH, kW, iC] + + int bS, iC, iH, iW, mC, oC, oH, + oW; // batch size, input channels, input height/width, channels + // multiplier(oC = iC*mC), output channels, output height/width + int indIOioC, indIiH, indWmC, indWiC, indWkH, + indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, + indIiH, indWiC, indWmC, indWkH, indOoH); + mC = weights->sizeAt(indWmC); // channels multiplier + + int trueoH, trueoW; // correct output height, width + ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, + dH, dW, iH, iW, paddingMode); + + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, + dW, paddingMode); + + std::vector expectedGradOShape = + ShapeUtils::composeShapeUsingDimsAndIdx( + {bS, oC, trueoH, trueoW, 0, indIOioC, indOoH, indOoH + 1}); + std::vector expectedWeightsShape = + ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, mC); + REQUIRE_TRUE( + gradO->isSameShape(expectedGradOShape), 0, + "CUSTOM DEPTHWISECONV2D_BP MKL OP: wrong shape of output gradients (next " + "epsilon) array, expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedGradOShape).c_str(), + ShapeUtils::shapeAsString(gradO).c_str()); + REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, + "CUSTOM DEPTHWISECONV2D_BP MKL OP: wrong shape of weights " + "array, expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), + ShapeUtils::shapeAsString(weights).c_str()); + if (bias) + REQUIRE_TRUE( + bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, + "CUSTOM DEPTHWISECONV2D_BP MKL OP: wrong shape of array with biases, " + "expected rank, length: <=2, %i, but got %i, %i instead !", + oC, bias->rankOf(), bias->lengthOf()); + + depthwiseConv2dBpMKLDNN(input, weights, gradO, gradI, gradW, gradB, kH, + kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, + wFormat); + + return Status::OK(); } ////////////////////////////////////////////////////////////////////// PLATFORM_CHECK(depthwise_conv2d_bp, ENGINE_CPU) { - - auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW) - auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] - auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] = [iC*mC] - auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NDHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next - - auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW), epsilon - auto gradW = OUTPUT_VARIABLE(1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] - auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC] - - const DataType xType = input->dataType(); - const DataType wType = weights->dataType(); - const DataType gradOType = gradO->dataType(); - - const DataType gradIType = gradI->dataType(); - const DataType gradWType = gradW->dataType(); - const DataType gradBType = gradB != nullptr ? gradB->dataType() : DataType::FLOAT32; - - const int mC = weights->sizeAt(3); - - return block.isUseMKLDNN() && mC == 1 && ((xType==DataType::FLOAT32 || xType==DataType::BFLOAT16) && (wType==DataType::FLOAT32 || wType==DataType::BFLOAT16) && (gradOType==DataType::FLOAT32 || gradOType==DataType::BFLOAT16) && (gradIType==DataType::FLOAT32 || gradIType==DataType::BFLOAT16) && (gradWType==DataType::FLOAT32 || gradWType==DataType::BFLOAT16) && (gradBType==DataType::FLOAT32 || gradBType==DataType::BFLOAT16) ); + auto input = INPUT_VARIABLE( + 0); // [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW) + auto weights = INPUT_VARIABLE( + 1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] + auto bias = + block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] = [iC*mC] + auto gradO = block.width() > 3 + ? INPUT_VARIABLE(3) + : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NDHWC) or [bS, oC, + // oH, oW] (NCDHW), epsilon_next + + auto gradI = OUTPUT_VARIABLE( + 0); // [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW), epsilon + auto gradW = OUTPUT_VARIABLE( + 1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] + auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC] + + const DataType xType = input->dataType(); + const DataType wType = weights->dataType(); + const DataType gradOType = gradO->dataType(); + + const DataType gradIType = gradI->dataType(); + const DataType gradWType = gradW->dataType(); + const DataType gradBType = + gradB != nullptr ? gradB->dataType() : DataType::FLOAT32; + + const int mC = weights->sizeAt(3); + + return block.isUseMKLDNN() && mC == 1 && + ((xType == DataType::FLOAT32 || xType == DataType::BFLOAT16) && + (wType == DataType::FLOAT32 || wType == DataType::BFLOAT16) && + (gradOType == DataType::FLOAT32 || gradOType == DataType::BFLOAT16) && + (gradIType == DataType::FLOAT32 || gradIType == DataType::BFLOAT16) && + (gradWType == DataType::FLOAT32 || gradWType == DataType::BFLOAT16) && + (gradBType == DataType::FLOAT32 || gradBType == DataType::BFLOAT16)); } -} -} -} +} // namespace platforms +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/lrn.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/lrn.cpp index 583ab08528c0..c99c5c93dbba 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/lrn.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/lrn.cpp @@ -19,75 +19,83 @@ // @author raver119@gmail.com // -#include +#include #include +#include +#include #include -#include #include "mkldnnUtils.h" -#include using namespace dnnl; namespace sd { - namespace ops { - namespace platforms { - PLATFORM_IMPL(lrn, ENGINE_CPU) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - REQUIRE_TRUE(input->rankOf() == 4, 0, "lrn: Input rank of 4 expected, but got %i instead", - input->rankOf()); - - double alpha = T_ARG(1); - double beta = T_ARG(2); - double bias = T_ARG(0); - int depth = INT_ARG(0); - - dnnl_memory_desc_t empty; - dnnl::memory::desc lrn_src_md(empty), lrn_dst_md(empty), user_src_md(empty), user_dst_md(empty); - - mkldnnUtils::getMKLDNNMemoryDescLrn(input, nullptr, output, &lrn_src_md, nullptr, &lrn_dst_md, - &user_src_md, nullptr, &user_dst_md, input->rankOf() - 1); - - auto lrn_desc = lrn_forward::desc(prop_kind::forward_inference, algorithm::lrn_across_channels, - lrn_src_md, (2 * depth + 1), alpha * (2 * depth + 1), beta, bias); - - auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - dnnl::stream stream(engine); - auto lrn_prim_desc = lrn_forward::primitive_desc(lrn_desc, engine); - auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer()); - auto user_dst_memory = dnnl::memory(user_dst_md, engine, output->buffer()); - - auto lrn_src_memory = user_src_memory; - if (lrn_prim_desc.src_desc() != user_src_memory.get_desc()) { - lrn_src_memory = dnnl::memory(lrn_prim_desc.src_desc(), engine); - reorder(user_src_memory, lrn_src_memory).execute(stream, user_src_memory, lrn_src_memory); - } - - auto lrn_dst_memory = user_dst_memory; - if (lrn_prim_desc.dst_desc() != user_dst_memory.get_desc()) { - lrn_dst_memory = dnnl::memory(lrn_prim_desc.dst_desc(), engine); - } - - lrn_forward(lrn_prim_desc).execute(stream, {{DNNL_ARG_SRC, lrn_src_memory}, - {DNNL_ARG_DST, lrn_dst_memory}}); - - if (lrn_prim_desc.dst_desc() != user_dst_memory.get_desc()) { - reorder(lrn_dst_memory, user_dst_memory).execute(stream, lrn_dst_memory, user_dst_memory); - } - - stream.wait(); - - return Status::OK(); - }; - - PLATFORM_CHECK(lrn, ENGINE_CPU) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - return block.isUseMKLDNN() && sd::MKLDNNStream::isSupported({input, output}); - } - } - } -} \ No newline at end of file +namespace ops { +namespace platforms { +PLATFORM_IMPL(lrn, ENGINE_CPU) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + REQUIRE_TRUE(input->rankOf() == 4, 0, + "lrn: Input rank of 4 expected, but got %i instead", + input->rankOf()); + + double alpha = T_ARG(1); + double beta = T_ARG(2); + double bias = T_ARG(0); + int depth = INT_ARG(0); + + dnnl_memory_desc_t empty; + dnnl::memory::desc lrn_src_md(empty), lrn_dst_md(empty), user_src_md(empty), + user_dst_md(empty); + + mkldnnUtils::getMKLDNNMemoryDescLrn( + input, nullptr, output, &lrn_src_md, nullptr, &lrn_dst_md, &user_src_md, + nullptr, &user_dst_md, input->rankOf() - 1); + + auto lrn_desc = lrn_forward::desc( + prop_kind::forward_inference, algorithm::lrn_across_channels, lrn_src_md, + (2 * depth + 1), alpha * (2 * depth + 1), beta, bias); + + auto engine = + mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); + dnnl::stream stream(engine); + auto lrn_prim_desc = lrn_forward::primitive_desc(lrn_desc, engine); + auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer()); + auto user_dst_memory = dnnl::memory(user_dst_md, engine, output->buffer()); + + auto lrn_src_memory = user_src_memory; + if (lrn_prim_desc.src_desc() != user_src_memory.get_desc()) { + lrn_src_memory = dnnl::memory(lrn_prim_desc.src_desc(), engine); + reorder(user_src_memory, lrn_src_memory) + .execute(stream, user_src_memory, lrn_src_memory); + } + + auto lrn_dst_memory = user_dst_memory; + if (lrn_prim_desc.dst_desc() != user_dst_memory.get_desc()) { + lrn_dst_memory = dnnl::memory(lrn_prim_desc.dst_desc(), engine); + } + + lrn_forward(lrn_prim_desc) + .execute(stream, {{DNNL_ARG_SRC, lrn_src_memory}, + {DNNL_ARG_DST, lrn_dst_memory}}); + + if (lrn_prim_desc.dst_desc() != user_dst_memory.get_desc()) { + reorder(lrn_dst_memory, user_dst_memory) + .execute(stream, lrn_dst_memory, user_dst_memory); + } + + stream.wait(); + + return Status::OK(); +}; + +PLATFORM_CHECK(lrn, ENGINE_CPU) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + return block.isUseMKLDNN() && sd::MKLDNNStream::isSupported({input, output}); +} +} // namespace platforms +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/lstmLayer.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/lstmLayer.cpp index a74a557324b1..8ec643ef1d87 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/lstmLayer.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/lstmLayer.cpp @@ -19,466 +19,505 @@ // #include + #include "mkldnnUtils.h" using namespace dnnl; -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace platforms { -static void lstmLayerMKLDNN(const NDArray* x, const NDArray* Wx, const NDArray* Wr, - const NDArray* b, const NDArray* hI, const NDArray* cI, - const std::vector& params, - NDArray* h, NDArray* hL, NDArray* cL) { - - // equations (no peephole connections) - // it = σ(Wxi * xt + Wri * ht-1 + bi) - // ft = σ(Wxf * xt + Wrf * ht-1 + bf) - // c't = tanh(Wxc * xt + Wrc * ht-1 + bc) - // ct = ft ◦ ct-1 + it ◦ c't - // ot = σ(Wxo * xt + Wro * ht-1 + bo) - // ht = ot ◦ tanh(ct) - - // notations: - // bS - batch size - // sL - sequence length, number of time steps - // nIn - input size - // nOut - output size (hidden size) - - // INPUTS: - - // ******* - // input x: - // 1) [sL, bS, nIn] when dataFormat == 0 - - // ******* - // input weights Wx: - // 1) [1, 1, nIn, 4*nOut] when directionMode < 2 - // 2) [1, 2, nIn, 4*nOut] when directionMode >= 2 - - // ******* - // recurrent weights Wr: - // 1) [1, 1, nOut, 4*nOut] when directionMode < 2 - // 2) [1, 2, nOut, 4*nOut] when directionMode >= 2 - - // ******* - // biases b: - // 1) [1, 1, 4*nOut] when directionMode < 2 - // 2) [1, 2, 4*nOut] when directionMode >= 2 - - // ******* - // initial output hI: - // 1) [1, 1, bS, nOut] when directionMode < 2 - // 2) [1, 2, bS, nOut] when directionMode >= 2 - - // ******* - // initial cell state cI (same shape as in hI): - // 1) [1, 1, bS, nOut] when directionMode < 2 - // 2) [1, 2, bS, nOut] when directionMode >= 2 - - - // OUTPUTS: - - // ******* - // output h: - // 1) [sL, bS, nOut] when directionMode <= 2 && dataFormat == 0 - // 2) [sL, bS, 2*nOut] when directionMode == 3 && dataFormat == 0 - - // ******* - // output at last step hL: - // 1) [1, 1, bS, nOut] when directionMode < 2 - // 2) [1, 2, bS, nOut] when directionMode >= 2 - - // ******* - // cell state at last step cL (same shape as in hL): - // 1) [1, 1, bS, nOut] when directionMode < 2 - // 2) [1, 2, bS, nOut] when directionMode >= 2 - - // !!! dimension 4*nOut implies order it, ft, c't, ot - // !!! dimension 3*nOut implies order it, ft, ot - - // params = {dataFormat, directionMode, cellClip, gateAct, gateAlpha, gateBeta, cellAct, cellAlpha, cellBeta, outAct, outAlpha, outBeta}; - - // dataFormat: 0 = [sL, bS, nIn] - // directionMode: 0 = forward, 1 = backward, 2 = bidirectional sum, 3 = bidirectional concat - - const int dataFormat = params[0]; - const int directionMode = params[1]; - - const int sL = x->sizeAt(0); // dataFormat == 0 ? x->sizeAt(0) : x->sizeAt(1); - const int bS = x->sizeAt(1); // dataFormat == 0 ? x->sizeAt(1) : x->sizeAt(0); - const int nIn = x->sizeAt(-1); - const int nOut = Wx->sizeAt(-1); - - const int dirDim = directionMode < 2 ? 1 : 2; // number of dimensionss, 1 unidirectional, 2 for bidirectional - const int hDirDim = directionMode <= 2 ? 1 : 2; // for h array, take into account bidirectional_sum mode (directionMode == 2) - - // evaluate direction - rnn_direction direction; - switch (directionMode) { - case 0: - direction = rnn_direction::unidirectional_left2right; - break; - case 1: - direction = rnn_direction::unidirectional_right2left; - break; - case 2: - direction = rnn_direction::bidirectional_sum; - break; - default: - direction = rnn_direction::bidirectional_concat; - } - - auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - - dnnl::memory::desc x_user_md, wx_user_md, wr_user_md, b_user_md, hI_user_md, cI_user_md, h_user_md, hL_user_md, cL_user_md, - x_lstm_md, wx_lstm_md, wr_lstm_md, b_lstm_md, hI_lstm_md, cI_lstm_md, h_lstm_md, hL_lstm_md, cL_lstm_md; - - // input type - dnnl::memory::data_type xType; - if(x->dataType() == DataType::FLOAT32) - xType = dnnl::memory::data_type::f32; - else if(x->dataType() == DataType::HALF) - xType = dnnl::memory::data_type::f16; - else - xType = dnnl::memory::data_type::u8; - - // weights type - dnnl::memory::data_type wType = xType; - if(xType == dnnl::memory::data_type::u8) - wType = dnnl::memory::data_type::s8; - - // bias type - dnnl::memory::data_type bType = xType; - if(xType == dnnl::memory::data_type::u8) - bType = dnnl::memory::data_type::f32; - - // output type - dnnl::memory::data_type hType; - if(h->dataType() == DataType::FLOAT32) - hType = dnnl::memory::data_type::f32; - else if(h->dataType() == DataType::HALF) - hType = dnnl::memory::data_type::f16; - else - hType = dnnl::memory::data_type::u8; - - - // memory descriptors for arrays - // x - x_lstm_md = dnnl::memory::desc({sL, bS, nIn}, xType, dnnl::memory::format_tag::any); - // x_user_md = dataFormat == 0 ? dnnl::memory::desc({sL, bS, nIn}, type, dnnl::memory::format_tag::tnc) : dnnl::memory::desc({bS, sL, nIn}, type, dnnl::memory::format_tag::ntc); - x_user_md = dnnl::memory::desc({sL, bS, nIn}, xType, dnnl::memory::format_tag::tnc); - mkldnnUtils::setBlockStrides(*x, x_user_md); - - // wx - wx_lstm_md = dnnl::memory::desc({1,dirDim,nIn,4,nOut}, wType, dnnl::memory::format_tag::any); - wx_user_md = dnnl::memory::desc({1,dirDim,nIn,4,nOut}, wType, dnnl::memory::format_tag::ldigo); - mkldnnUtils::setBlockStrides(*Wx, wx_user_md); - - // wr - wr_lstm_md = dnnl::memory::desc({1,dirDim,nOut,4,nOut}, wType, dnnl::memory::format_tag::any); - wr_user_md = dnnl::memory::desc({1,dirDim,nOut,4,nOut}, wType, dnnl::memory::format_tag::ldigo); - mkldnnUtils::setBlockStrides(*Wr, wr_user_md); - - // h - h_lstm_md = dnnl::memory::desc({sL, bS, hDirDim*nOut}, hType, dnnl::memory::format_tag::any); - // h_user_md = dataFormat == 0 ? dnnl::memory::desc({sL, bS, hDirDim*nOut}, type, dnnl::memory::format_tag::tnc) : dnnl::memory::desc({bS, sL, hDirDim*nOut}, type, dnnl::memory::format_tag::ntc); - h_user_md = dnnl::memory::desc({sL, bS, hDirDim*nOut}, hType, dnnl::memory::format_tag::tnc); - mkldnnUtils::setBlockStrides(*h, h_user_md); - - // b - if(b) { - b_lstm_md = dnnl::memory::desc({1,dirDim,4,nOut}, bType, dnnl::memory::format_tag::any); - b_user_md = dnnl::memory::desc({1,dirDim,4,nOut}, bType, dnnl::memory::format_tag::ldgo); - mkldnnUtils::setBlockStrides(*b, b_user_md); - } - - // hI - if(hI) { - hI_lstm_md = dnnl::memory::desc({1,dirDim,bS,nOut}, xType, dnnl::memory::format_tag::any); - hI_user_md = dnnl::memory::desc({1,dirDim,bS,nOut}, xType, dnnl::memory::format_tag::ldnc); - mkldnnUtils::setBlockStrides(*hI, hI_user_md); - } - - // cI - if(cI) { - cI_lstm_md = dnnl::memory::desc({1,dirDim,bS,nOut}, xType, dnnl::memory::format_tag::any); - cI_user_md = dnnl::memory::desc({1,dirDim,bS,nOut}, xType, dnnl::memory::format_tag::ldnc); - mkldnnUtils::setBlockStrides(*cI, cI_user_md); - } - - // hL - if(hL) { - hL_lstm_md = dnnl::memory::desc({1,dirDim,bS,nOut}, hType, dnnl::memory::format_tag::any); - hL_user_md = dnnl::memory::desc({1,dirDim,bS,nOut}, hType, dnnl::memory::format_tag::ldnc); - hL_user_md.data.format_kind = dnnl_blocked; // overrides format - mkldnnUtils::setBlockStrides(*hL, hL_user_md); - } - - if(cL) { - cL_lstm_md = dnnl::memory::desc({1,dirDim,bS,nOut}, hType, dnnl::memory::format_tag::ldnc); - cL_user_md = dnnl::memory::desc({1,dirDim,bS,nOut}, hType, dnnl::memory::format_tag::ldnc); - mkldnnUtils::setBlockStrides(*cL, cL_user_md); - } - - // lstm memory description - lstm_forward::desc lstm_desc(prop_kind::forward_inference, direction, - x_lstm_md, hI_lstm_md, cI_lstm_md, wx_lstm_md, wr_lstm_md, b_lstm_md, - h_lstm_md, hL_lstm_md, cL_lstm_md); - - dnnl::stream stream(engine); - - // lstm primitive description - lstm_forward::primitive_desc lstm_prim_desc(lstm_desc, engine); - - // arguments (memory buffers) necessary for calculations - std::unordered_map args; - - // provide memory and check whether reorder is required - // x - mkldnnUtils::loadDataToMklStream(*x, engine, stream, x_user_md, lstm_prim_desc.src_layer_desc(), args[DNNL_ARG_SRC_LAYER]); - - // wx - mkldnnUtils::loadDataToMklStream(*Wx, engine, stream, wx_user_md, lstm_prim_desc.weights_layer_desc(), args[DNNL_ARG_WEIGHTS_LAYER]); - - // wr - mkldnnUtils::loadDataToMklStream(*Wr, engine, stream, wr_user_md, lstm_prim_desc.weights_iter_desc(), args[DNNL_ARG_WEIGHTS_ITER]); - - // h - auto h_user_mem = mkldnnUtils::loadDataToMklStream(*h, engine, stream, h_user_md, lstm_prim_desc.dst_layer_desc(), args[DNNL_ARG_DST_LAYER]); - - // b - if(b) - mkldnnUtils::loadDataToMklStream(*b, engine, stream, b_user_md, lstm_prim_desc.bias_desc(), args[DNNL_ARG_BIAS]); - - // hI - if(hI) - mkldnnUtils::loadDataToMklStream(*hI, engine, stream, hI_user_md, lstm_prim_desc.src_iter_desc(), args[DNNL_ARG_SRC_ITER]); - - // cI - if(cI) - mkldnnUtils::loadDataToMklStream(*cI, engine, stream, cI_user_md, lstm_prim_desc.src_iter_c_desc(), args[DNNL_ARG_SRC_ITER_C]); - - dnnl::memory hL_user_mem, cL_user_mem, hL_lstm_mem, cL_lstm_mem; - - // hL - if(hL) - hL_user_mem = mkldnnUtils::loadDataToMklStream(*hL, engine, stream, hL_user_md, lstm_prim_desc.dst_iter_desc(), args[DNNL_ARG_DST_ITER]); - - // cL - if(cL) - cL_user_mem = mkldnnUtils::loadDataToMklStream(*cL, engine, stream, cL_user_md, lstm_prim_desc.dst_iter_c_desc(), args[DNNL_ARG_DST_ITER_C]); - - // run calculations - lstm_forward(lstm_prim_desc).execute(stream, args); - - // reorder outputs if necessary - if (lstm_prim_desc.dst_layer_desc() != h_user_mem.get_desc()) - reorder(args[DNNL_ARG_DST_LAYER], h_user_mem).execute(stream, args[DNNL_ARG_DST_LAYER], h_user_mem); - if(lstm_prim_desc.dst_iter_desc() != hL_user_mem.get_desc()) - reorder(args[DNNL_ARG_DST_ITER], hL_user_mem).execute(stream, args[DNNL_ARG_DST_ITER], hL_user_mem); - if(lstm_prim_desc.dst_iter_c_desc() != cL_user_mem.get_desc()) - reorder(args[DNNL_ARG_DST_ITER_C], cL_user_mem).execute(stream, args[DNNL_ARG_DST_ITER_C], cL_user_mem); - - stream.wait(); +static void lstmLayerMKLDNN(const NDArray* x, const NDArray* Wx, + const NDArray* Wr, const NDArray* b, + const NDArray* hI, const NDArray* cI, + const std::vector& params, NDArray* h, + NDArray* hL, NDArray* cL) { + + // equations (no peephole connections) + // it = σ(Wxi * xt + Wri * ht-1 + bi) + // ft = σ(Wxf * xt + Wrf * ht-1 + bf) + // c't = tanh(Wxc * xt + Wrc * ht-1 + bc) + // ct = ft ◦ ct-1 + it ◦ c't + // ot = σ(Wxo * xt + Wro * ht-1 + bo) + // ht = ot ◦ tanh(ct) + + // notations: + // bS - batch size + // sL - sequence length, number of time steps + // nIn - input size + // nOut - output size (hidden size) + + // INPUTS: + + // ******* + // input x: + // 1) [sL, bS, nIn] when dataFormat == 0 + + // ******* + // input weights Wx: + // 1) [1, 1, nIn, 4*nOut] when directionMode < 2 + // 2) [1, 2, nIn, 4*nOut] when directionMode >= 2 + + // ******* + // recurrent weights Wr: + // 1) [1, 1, nOut, 4*nOut] when directionMode < 2 + // 2) [1, 2, nOut, 4*nOut] when directionMode >= 2 + + // ******* + // biases b: + // 1) [1, 1, 4*nOut] when directionMode < 2 + // 2) [1, 2, 4*nOut] when directionMode >= 2 + + // ******* + // initial output hI: + // 1) [1, 1, bS, nOut] when directionMode < 2 + // 2) [1, 2, bS, nOut] when directionMode >= 2 + + // ******* + // initial cell state cI (same shape as in hI): + // 1) [1, 1, bS, nOut] when directionMode < 2 + // 2) [1, 2, bS, nOut] when directionMode >= 2 + + // OUTPUTS: + + // ******* + // output h: + // 1) [sL, bS, nOut] when directionMode <= 2 && dataFormat == 0 + // 2) [sL, bS, 2*nOut] when directionMode == 3 && dataFormat == 0 + + // ******* + // output at last step hL: + // 1) [1, 1, bS, nOut] when directionMode < 2 + // 2) [1, 2, bS, nOut] when directionMode >= 2 + + // ******* + // cell state at last step cL (same shape as in hL): + // 1) [1, 1, bS, nOut] when directionMode < 2 + // 2) [1, 2, bS, nOut] when directionMode >= 2 + + // !!! dimension 4*nOut implies order it, ft, c't, ot + // !!! dimension 3*nOut implies order it, ft, ot + + // params = {dataFormat, directionMode, cellClip, gateAct, gateAlpha, + // gateBeta, cellAct, cellAlpha, cellBeta, outAct, outAlpha, outBeta}; + + // dataFormat: 0 = [sL, bS, nIn] + // directionMode: 0 = forward, 1 = backward, 2 = bidirectional sum, 3 = + // bidirectional concat + + const int dataFormat = params[0]; + const int directionMode = params[1]; + + const int sL = x->sizeAt(0); // dataFormat == 0 ? x->sizeAt(0) : x->sizeAt(1); + const int bS = x->sizeAt(1); // dataFormat == 0 ? x->sizeAt(1) : x->sizeAt(0); + const int nIn = x->sizeAt(-1); + const int nOut = Wx->sizeAt(-1); + + const int dirDim = directionMode < 2 ? 1 : 2; // number of dimensionss, 1 unidirectional, 2 for bidirectional + const int hDirDim = directionMode <= 2 ? 1 : 2; // for h array, take into account bidirectional_sum mode (directionMode == 2) + + // evaluate direction + rnn_direction direction; + switch (directionMode) { + case 0: + direction = rnn_direction::unidirectional_left2right; + break; + case 1: + direction = rnn_direction::unidirectional_right2left; + break; + case 2: + direction = rnn_direction::bidirectional_sum; + break; + default: + direction = rnn_direction::bidirectional_concat; + } + + auto engine = + mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); + + dnnl::memory::desc x_user_md, wx_user_md, wr_user_md, b_user_md, hI_user_md, + cI_user_md, h_user_md, hL_user_md, cL_user_md, x_lstm_md, wx_lstm_md, + wr_lstm_md, b_lstm_md, hI_lstm_md, cI_lstm_md, h_lstm_md, hL_lstm_md, + cL_lstm_md; + + // input type + dnnl::memory::data_type xType; + if (x->dataType() == DataType::FLOAT32) + xType = dnnl::memory::data_type::f32; + else if (x->dataType() == DataType::HALF) + xType = dnnl::memory::data_type::f16; + else + xType = dnnl::memory::data_type::u8; + + // weights type + dnnl::memory::data_type wType = xType; + if (xType == dnnl::memory::data_type::u8) + wType = dnnl::memory::data_type::s8; + + // bias type + dnnl::memory::data_type bType = xType; + if (xType == dnnl::memory::data_type::u8) + bType = dnnl::memory::data_type::f32; + + // output type + dnnl::memory::data_type hType; + if (h->dataType() == DataType::FLOAT32) + hType = dnnl::memory::data_type::f32; + else if (h->dataType() == DataType::HALF) + hType = dnnl::memory::data_type::f16; + else + hType = dnnl::memory::data_type::u8; + + // memory descriptors for arrays + // x + x_lstm_md = dnnl::memory::desc({sL, bS, nIn}, xType, dnnl::memory::format_tag::any); + // x_user_md = dataFormat == 0 ? dnnl::memory::desc({sL, bS, nIn}, type, dnnl::memory::format_tag::tnc) : dnnl::memory::desc({bS, sL, nIn}, type, dnnl::memory::format_tag::ntc); + x_user_md = dnnl::memory::desc({sL, bS, nIn}, xType, dnnl::memory::format_tag::tnc); + mkldnnUtils::setBlockStrides(*x, x_user_md); + + // wx + wx_lstm_md = dnnl::memory::desc({1, dirDim, nIn, 4, nOut}, wType, dnnl::memory::format_tag::any); + wx_user_md = dnnl::memory::desc({1, dirDim, nIn, 4, nOut}, wType, dnnl::memory::format_tag::ldigo); + mkldnnUtils::setBlockStrides(*Wx, wx_user_md); + + // wr + wr_lstm_md = dnnl::memory::desc({1, dirDim, nOut, 4, nOut}, wType, dnnl::memory::format_tag::any); + wr_user_md = dnnl::memory::desc({1, dirDim, nOut, 4, nOut}, wType, dnnl::memory::format_tag::ldigo); + mkldnnUtils::setBlockStrides(*Wr, wr_user_md); + + // h + h_lstm_md = dnnl::memory::desc({sL, bS, hDirDim * nOut}, hType, dnnl::memory::format_tag::any); + // h_user_md = dataFormat == 0 ? dnnl::memory::desc({sL, bS, hDirDim*nOut}, type, dnnl::memory::format_tag::tnc) : dnnl::memory::desc({bS, sL, hDirDim*nOut}, type, dnnl::memory::format_tag::ntc); + h_user_md = dnnl::memory::desc({sL, bS, hDirDim * nOut}, hType, dnnl::memory::format_tag::tnc); + mkldnnUtils::setBlockStrides(*h, h_user_md); + + // b + if (b) { + b_lstm_md = dnnl::memory::desc({1, dirDim, 4, nOut}, bType, dnnl::memory::format_tag::any); + b_user_md = dnnl::memory::desc({1, dirDim, 4, nOut}, bType, dnnl::memory::format_tag::ldgo); + mkldnnUtils::setBlockStrides(*b, b_user_md); + } + + // hI + if (hI) { + hI_lstm_md = dnnl::memory::desc({1, dirDim, bS, nOut}, xType, dnnl::memory::format_tag::any); + hI_user_md = dnnl::memory::desc({1, dirDim, bS, nOut}, xType, dnnl::memory::format_tag::ldnc); + mkldnnUtils::setBlockStrides(*hI, hI_user_md); + } + + // cI + if (cI) { + cI_lstm_md = dnnl::memory::desc({1, dirDim, bS, nOut}, xType, dnnl::memory::format_tag::any); + cI_user_md = dnnl::memory::desc({1, dirDim, bS, nOut}, xType, dnnl::memory::format_tag::ldnc); + mkldnnUtils::setBlockStrides(*cI, cI_user_md); + } + + // hL + if (hL) { + hL_lstm_md = dnnl::memory::desc({1, dirDim, bS, nOut}, hType, dnnl::memory::format_tag::any); + hL_user_md = dnnl::memory::desc({1, dirDim, bS, nOut}, hType, dnnl::memory::format_tag::ldnc); + hL_user_md.data.format_kind = dnnl_blocked; // overrides format + mkldnnUtils::setBlockStrides(*hL, hL_user_md); + } + + if (cL) { + cL_lstm_md = dnnl::memory::desc({1, dirDim, bS, nOut}, hType, dnnl::memory::format_tag::ldnc); + cL_user_md = dnnl::memory::desc({1, dirDim, bS, nOut}, hType, dnnl::memory::format_tag::ldnc); + mkldnnUtils::setBlockStrides(*cL, cL_user_md); + } + + // lstm memory description + lstm_forward::desc lstm_desc(prop_kind::forward_inference, direction, x_lstm_md, hI_lstm_md, cI_lstm_md, wx_lstm_md, wr_lstm_md, b_lstm_md, h_lstm_md, hL_lstm_md, cL_lstm_md); + + dnnl::stream stream(engine); + + // lstm primitive description + lstm_forward::primitive_desc lstm_prim_desc(lstm_desc, engine); + + // arguments (memory buffers) necessary for calculations + std::unordered_map args; + + // provide memory and check whether reorder is required + // x + mkldnnUtils::loadDataToMklStream(*x, engine, stream, x_user_md, lstm_prim_desc.src_layer_desc(), args[DNNL_ARG_SRC_LAYER]); + + // wx + mkldnnUtils::loadDataToMklStream(*Wx, engine, stream, wx_user_md, lstm_prim_desc.weights_layer_desc(), args[DNNL_ARG_WEIGHTS_LAYER]); + + // wr + mkldnnUtils::loadDataToMklStream(*Wr, engine, stream, wr_user_md, lstm_prim_desc.weights_iter_desc(), args[DNNL_ARG_WEIGHTS_ITER]); + + // h + auto h_user_mem = mkldnnUtils::loadDataToMklStream(*h, engine, stream, h_user_md, lstm_prim_desc.dst_layer_desc(), args[DNNL_ARG_DST_LAYER] ); + + // b + if (b) + mkldnnUtils::loadDataToMklStream(*b, engine, stream, b_user_md, lstm_prim_desc.bias_desc(), args[DNNL_ARG_BIAS]); + + // hI + if (hI) + mkldnnUtils::loadDataToMklStream(*hI, engine, stream, hI_user_md, lstm_prim_desc.src_iter_desc(), args[DNNL_ARG_SRC_ITER]); + + // cI + if (cI) + mkldnnUtils::loadDataToMklStream(*cI, engine, stream, cI_user_md, lstm_prim_desc.src_iter_c_desc(), args[DNNL_ARG_SRC_ITER_C]); + + dnnl::memory hL_user_mem, cL_user_mem, hL_lstm_mem, cL_lstm_mem; + + // hL + if (hL) + hL_user_mem = mkldnnUtils::loadDataToMklStream(*hL, engine, stream, hL_user_md, lstm_prim_desc.dst_iter_desc(), + args[DNNL_ARG_DST_ITER] ); + + // cL + if (cL) + cL_user_mem = mkldnnUtils::loadDataToMklStream(*cL, engine, stream, cL_user_md, lstm_prim_desc.dst_iter_c_desc(), + args[DNNL_ARG_DST_ITER_C] ); + + // run calculations + lstm_forward(lstm_prim_desc).execute(stream, args); + + // reorder outputs if necessary + if (lstm_prim_desc.dst_layer_desc() != h_user_mem.get_desc()) + reorder(args[DNNL_ARG_DST_LAYER], h_user_mem).execute(stream, args[DNNL_ARG_DST_LAYER], h_user_mem); + if (lstm_prim_desc.dst_iter_desc() != hL_user_mem.get_desc()) + reorder(args[DNNL_ARG_DST_ITER], hL_user_mem).execute(stream, args[DNNL_ARG_DST_ITER], hL_user_mem); + if (lstm_prim_desc.dst_iter_c_desc() != cL_user_mem.get_desc()) + reorder(args[DNNL_ARG_DST_ITER_C], cL_user_mem).execute(stream, args[DNNL_ARG_DST_ITER_C], cL_user_mem); + + stream.wait(); } ////////////////////////////////////////////////////////////////////////// PLATFORM_IMPL(lstmLayer, ENGINE_CPU) { - const auto dataFormat = INT_ARG(0); // for unidirectional: 0 = [sL, bS, nIn], 1 = [bS, sL ,nIn], 2 = [bS, nIn, sL], for bidirectional: 3 = [sL, 2, bS, nOut] (for ONNX) - const auto directionMode = INT_ARG(1); // direction: 0 = fwd, 1 = bwd, 2 = bidirectional sum, 3 = bidirectional concat, 4 = bidirectional extra output dim (in conjunction with format dataFormat = 3) - - const auto hasBiases = B_ARG(0); // indicates whether biases array is provided - const auto hasSeqLen = B_ARG(1); // indicates whether seqLen array is provided - const auto hasInitH = B_ARG(2); // indicates whether initial output is provided - const auto hasInitC = B_ARG(3); // indicates whether initial cell state is provided - const auto hasPH = B_ARG(4); // indicates whether peephole connections are present - const auto retFullSeq = B_ARG(5); // indicates whether to return whole time sequence h {h_0, h_1, ... , h_sL-1} - const auto retLastH = B_ARG(6); // indicates whether to return output at last time step only, in this case shape would be [bS, nOut] (exact shape depends on dataFormat argument) - const auto retLastC = B_ARG(7); // indicates whether to return cells state at last time step only, in this case shape would be [bS, nOut] (exact shape depends on dataFormat argument) - - const auto cellClip = T_ARG(0); // cell clipping value, if it = 0 then do not apply clipping - - const auto x = INPUT_VARIABLE(0); // input - const auto Wx = INPUT_VARIABLE(1); // input weights - const auto Wr = INPUT_VARIABLE(2); // recurrent weights - - int count = 3; - const auto b = hasBiases ? INPUT_VARIABLE(count++) : nullptr; // biases - const auto seqLen = hasSeqLen ? INPUT_VARIABLE(count++) : nullptr; // seqLen vector - const auto hI = hasInitH ? INPUT_VARIABLE(count++) : nullptr; // initial output - const auto cI = hasInitC ? INPUT_VARIABLE(count++) : nullptr; // initial cell state - const auto Wp = hasPH ? INPUT_VARIABLE(count++) : nullptr; // peephole weights - - REQUIRE_TRUE(cellClip == 0 , 0, "LSTM_LAYER_MKLDNN operation: cell clipping is not supported currently !"); - REQUIRE_TRUE(retFullSeq, 0, "LSTM_LAYER_MKLDNN operation: option to calculate full time sequence output h should be always true in case of mkl dnn library !"); - REQUIRE_TRUE(hasPH == false , 0, "LSTM_LAYER_MKLDNN operation: mkl dnn library doesn't support peephole connections !"); - REQUIRE_TRUE(hasSeqLen == false, 0, "LSTM_LAYER_MKLDNN operation: mkl dnn library doesn't support array specifying max time step per each example in batch !"); - REQUIRE_TRUE(dataFormat < 2, 0, "LSTM_LAYER_MKLDNN operation: wrong data format, only two formats are allowed for input/output tensors in mkl dnn library: TNC and NTC!"); - REQUIRE_TRUE(directionMode < 4, 0, "LSTM_LAYER_MKLDNN operation: option for bidirectional extra output dimension is not valid in mkl dnn library !"); - REQUIRE_TRUE(retLastH == retLastC, 0, "LSTM_LAYER_MKLDNN operation: only two options are present: 1) calculate both output at last time and cell state at last time; 2) do not calculate both !"); - REQUIRE_TRUE(hasInitH == hasInitC, 0, "LSTM_LAYER_MKLDNN operation: either both of or neither of initial C and initial H must be provided"); - - count = 0; - auto h = retFullSeq ? OUTPUT_VARIABLE(count++) : nullptr; // output - auto hL = retLastH ? OUTPUT_VARIABLE(count++) : nullptr; // output at last step - auto cL = retLastC ? OUTPUT_VARIABLE(count++) : nullptr; // cell state at last step - - // evaluate dimensions - const Nd4jLong sL = x->sizeAt(dataFormat); - const Nd4jLong bS = dataFormat == 0 ? x->sizeAt(1) : x->sizeAt(0); - const Nd4jLong nIn = x->sizeAt(2); - const Nd4jLong nOut = Wx->sizeAt(-1) / 4; - - // inputs validations - if(directionMode < 2) { // no bidirectional - - // Wx validation - if(Wx->rankOf() != 2 || Wx->sizeAt(0) != nIn) - REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of input weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({nIn, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wx).c_str()); - // Wr validation - if(Wr->rankOf() != 2 || Wr->sizeAt(0) != nOut || Wr->sizeAt(1) != 4*nOut) - REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of recurrent weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({nOut, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wr).c_str()); - // biases validation - if(b != nullptr && (b->rankOf() != 1 || b->sizeAt(0) != 4*nOut)) - REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({4*nOut}).c_str(), ShapeUtils::shapeAsString(b).c_str()); - // initial output validation - if(hI != nullptr && (hI->rankOf() != 2 || hI->sizeAt(0) != bS || hI->sizeAt(1) != nOut)) - REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of initial output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({bS, nOut}).c_str(), ShapeUtils::shapeAsString(hI).c_str()); - // initial cell validation - if(cI != nullptr && (cI->rankOf() != 2 || cI->sizeAt(0) != bS || cI->sizeAt(1) != nOut)) - REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of initial cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({bS, nOut}).c_str(), ShapeUtils::shapeAsString(cI).c_str()); - } - else { // bidirectional - // Wx validation - if(Wx->rankOf() != 3 || Wx->sizeAt(0) != 2 || Wx->sizeAt(1) != nIn) - REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of input weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, nIn, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wx).c_str()); - // Wr validation - if(Wr->rankOf() != 3 || Wr->sizeAt(0) != 2 || Wr->sizeAt(1) != nOut || Wr->sizeAt(2) != 4*nOut) - REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of recurrent weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, nOut, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wr).c_str()); - // biases validation - if(b != nullptr && (b->rankOf() != 2 || b->sizeAt(0) != 2 || b->sizeAt(1) != 4*nOut)) - REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, 4*nOut}).c_str(), ShapeUtils::shapeAsString(b).c_str()); - // initial output validation - if(hI != nullptr && (hI->rankOf() != 3 || hI->sizeAt(0) != 2 || hI->sizeAt(1) != bS || hI->sizeAt(2) != nOut)) - REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of initial output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, bS, nOut}).c_str(), ShapeUtils::shapeAsString(hI).c_str()); - // initial cell validation - if(cI != nullptr && (cI->rankOf() != 3 || cI->sizeAt(0) != 2 || cI->sizeAt(1) != bS || cI->sizeAt(2) != nOut)) - REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of initial cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, bS, nOut}).c_str(), ShapeUtils::shapeAsString(cI).c_str()); - } - - std::vector params = {static_cast(dataFormat), static_cast(directionMode), static_cast(cellClip)}; - - const int dirDim = directionMode < 2 ? 1 : 2; // number of dimensions, 1 unidirectional, 2 for bidirectional - - // permut x and h to tnc format if they have ntc format - NDArray* xP(const_cast(x)), *hP(h); - if(dataFormat == 1) { - xP = new NDArray(x->permute({1,0,2})); // [bS, sL, nIn] -> [sL, bS, nIn] - hP = new NDArray(h->permute({1,0,2})); // [bS, sL, dirDim*nOn] -> [sL, bS, dirDim*nOn] - } - - // reshape arrays in accordance to mkl allowed formats - NDArray *WxR(nullptr), *WrR(nullptr), *bR(nullptr), *hIR(nullptr), *cIR(nullptr), *hLR(nullptr), *cLR(nullptr); - - WxR = new NDArray(Wx->reshape(Wx->ordering(), {1,dirDim,nIn,4,nOut})); - WrR = new NDArray(Wr->reshape(Wr->ordering(), {1,dirDim,nOut,4,nOut})); - - if(b) - bR = new NDArray(b->reshape(b->ordering(), {1,dirDim,4,nOut})); - else - bR = new NDArray(x->ordering(), {1,dirDim,4,nOut}, x->dataType(), x->getContext()); // already nullified - - if(hI) - hIR = new NDArray(hI->reshape(hI->ordering(), {1,dirDim,bS,nOut})); - - if(cI) - cIR = new NDArray(cI->reshape(cI->ordering(), {1,dirDim,bS,nOut})); - - if(hL) - hLR = new NDArray(hL->reshape(hL->ordering(), {1,dirDim,bS,nOut}, false)); - - if(cL) - cLR = new NDArray(cL->reshape(cL->ordering(), {1,dirDim,bS,nOut}, false)); - - lstmLayerMKLDNN(xP, WxR, WrR, bR, hIR, cIR, params, hP, hLR, cLR); - - delete WxR; - delete WrR; - delete bR; - delete hIR; - delete cIR; - delete hLR; - delete cLR; - - if(dataFormat == 1) { - delete xP; - delete hP; - } - - return Status::OK(); + const auto dataFormat = INT_ARG(0); // for unidirectional: 0 = [sL, bS, nIn], 1 = [bS, sL ,nIn], 2 = [bS, nIn, sL], for bidirectional: 3 = [sL, 2, bS, nOut] (for ONNX) + const auto directionMode = INT_ARG(1); // direction: 0 = fwd, 1 = bwd, 2 = bidirectional sum, 3 =bidirectional concat, 4 = bidirectional extra output dim (in conjunction with format dataFormat = 3) + + const auto hasBiases = B_ARG(0); // indicates whether biases array is provided + const auto hasSeqLen = B_ARG(1); // indicates whether seqLen array is provided + const auto hasInitH = B_ARG(2); // indicates whether initial output is provided + const auto hasInitC = B_ARG(3); // indicates whether initial cell state is provided + const auto hasPH = B_ARG(4); // indicates whether peephole connections are present + const auto retFullSeq = B_ARG(5); // indicates whether to return whole time sequence h {h_0, h_1, ... , h_sL-1} + const auto retLastH = B_ARG(6); // indicates whether to return output at last time step only, in this case shape would be [bS, nOut] (exact shape depends on dataFormat argument) + const auto retLastC = B_ARG(7); // indicates whether to return cells state at last time step only, in this case shape would be [bS, nOut] (exact shape depends on dataFormat argument) + + const auto cellClip = T_ARG(0); // cell clipping value, if it = 0 then do not apply clipping + + const auto x = INPUT_VARIABLE(0); // input + const auto Wx = INPUT_VARIABLE(1); // input weights + const auto Wr = INPUT_VARIABLE(2); // recurrent weights + + int count = 3; + const auto b = hasBiases ? INPUT_VARIABLE(count++) : nullptr; // biases + const auto seqLen = hasSeqLen ? INPUT_VARIABLE(count++) : nullptr; // seqLen vector + const auto hI = hasInitH ? INPUT_VARIABLE(count++) : nullptr; // initial output + const auto cI = hasInitC ? INPUT_VARIABLE(count++) : nullptr; // initial cell state + const auto Wp = hasPH ? INPUT_VARIABLE(count++) : nullptr; // peephole weights + + REQUIRE_TRUE(cellClip == 0, 0, "LSTM_LAYER_MKLDNN operation: cell clipping is not supported " "currently !"); + REQUIRE_TRUE(retFullSeq, 0, "LSTM_LAYER_MKLDNN operation: option to calculate full time sequence output h should be always true in case of mkl dnn library !"); + REQUIRE_TRUE(hasPH == false, 0, "LSTM_LAYER_MKLDNN operation: mkl dnn library doesn't support peephole connections !"); + REQUIRE_TRUE(hasSeqLen == false, 0, "LSTM_LAYER_MKLDNN operation: mkl dnn library doesn't support array specifying max time step per each example in batch !"); + REQUIRE_TRUE(dataFormat < 2, 0, "LSTM_LAYER_MKLDNN operation: wrong data format, only two formats are allowed for input/output tensors in mkl dnn library: TNC and NTC!"); + REQUIRE_TRUE(directionMode < 4, 0, "LSTM_LAYER_MKLDNN operation: option for bidirectional extra output dimension is not valid in mkl dnn library !"); + REQUIRE_TRUE(retLastH == retLastC, 0, "LSTM_LAYER_MKLDNN operation: only two options are present: 1) calculate both output at last time and cell state at last time; 2) do not calculate both !"); + REQUIRE_TRUE(hasInitH == hasInitC, 0, "LSTM_LAYER_MKLDNN operation: either both of or neither of initial C and initial H must be provided"); + + count = 0; + auto h = retFullSeq ? OUTPUT_VARIABLE(count++) : nullptr; // output + auto hL = retLastH ? OUTPUT_VARIABLE(count++) : nullptr; // output at last step + auto cL = retLastC ? OUTPUT_VARIABLE(count++) : nullptr; // cell state at last step + + // evaluate dimensions + const Nd4jLong sL = x->sizeAt(dataFormat); + const Nd4jLong bS = dataFormat == 0 ? x->sizeAt(1) : x->sizeAt(0); + const Nd4jLong nIn = x->sizeAt(2); + const Nd4jLong nOut = Wx->sizeAt(-1) / 4; + + // inputs validations + if (directionMode < 2) { // no bidirectional + + // Wx validation + if (Wx->rankOf() != 2 || Wx->sizeAt(0) != nIn) + REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of input weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({nIn, 4 * nOut}).c_str(), ShapeUtils::shapeAsString(Wx).c_str()); + // Wr validation + if (Wr->rankOf() != 2 || Wr->sizeAt(0) != nOut || Wr->sizeAt(1) != 4 * nOut) + REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of recurrent weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({nOut, 4 * nOut}).c_str(), ShapeUtils::shapeAsString(Wr).c_str()); + // biases validation + if (b != nullptr && (b->rankOf() != 1 || b->sizeAt(0) != 4 * nOut)) + REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({4 * nOut}).c_str(), ShapeUtils::shapeAsString(b).c_str()); + // initial output validation + if (hI != nullptr && (hI->rankOf() != 2 || hI->sizeAt(0) != bS || hI->sizeAt(1) != nOut)) + REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of initial output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({bS, nOut}).c_str(), ShapeUtils::shapeAsString(hI).c_str()); + // initial cell validation + if (cI != nullptr && (cI->rankOf() != 2 || cI->sizeAt(0) != bS || cI->sizeAt(1) != nOut)) + REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of initial cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({bS, nOut}).c_str(), ShapeUtils::shapeAsString(cI).c_str()); + } + else { // bidirectional + // Wx validation + if (Wx->rankOf() != 3 || Wx->sizeAt(0) != 2 || Wx->sizeAt(1) != nIn) + REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of input weights expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, nIn, 4 * nOut}).c_str(), ShapeUtils::shapeAsString(Wx).c_str()); + // Wr validation + if (Wr->rankOf() != 3 || Wr->sizeAt(0) != 2 || Wr->sizeAt(1) != nOut || Wr->sizeAt(2) != 4 * nOut) + REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of recurrent weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, nOut, 4 * nOut}).c_str(), ShapeUtils::shapeAsString(Wr).c_str()); + // biases validation + if (b != nullptr && (b->rankOf() != 2 || b->sizeAt(0) != 2 || b->sizeAt(1) != 4 * nOut)) + REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, 4 * nOut}).c_str(), ShapeUtils::shapeAsString(b).c_str()); + // initial output validation + if (hI != nullptr && (hI->rankOf() != 3 || hI->sizeAt(0) != 2 || hI->sizeAt(1) != bS || hI->sizeAt(2) != nOut)) + REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of initial output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, bS, nOut}).c_str(), ShapeUtils::shapeAsString(hI).c_str()); + // initial cell validation + if (cI != nullptr && (cI->rankOf() != 3 || cI->sizeAt(0) != 2 || cI->sizeAt(1) != bS || cI->sizeAt(2) != nOut)) + REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of initial cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, bS, nOut}).c_str(), ShapeUtils::shapeAsString(cI).c_str()); + } + + std::vector params = {static_cast(dataFormat), static_cast(directionMode), static_cast(cellClip)}; + + const int dirDim = directionMode < 2 ? 1 : 2; // number of dimensions, 1 unidirectional, 2 for bidirectional + + // permut x and h to tnc format if they have ntc format + NDArray *xP(const_cast(x)), *hP(h); + if (dataFormat == 1) { + xP = new NDArray(x->permute({1, 0, 2})); // [bS, sL, nIn] -> [sL, bS, nIn] + hP = new NDArray(h->permute({1, 0, 2})); // [bS, sL, dirDim*nOn] -> [sL, bS, dirDim*nOn] + } + + // reshape arrays in accordance to mkl allowed formats + NDArray *WxR(nullptr), *WrR(nullptr), *bR(nullptr), *hIR(nullptr), *cIR(nullptr), *hLR(nullptr), *cLR(nullptr); + + WxR = new NDArray(Wx->reshape(Wx->ordering(), {1, dirDim, nIn, 4, nOut})); + WrR = new NDArray(Wr->reshape(Wr->ordering(), {1, dirDim, nOut, 4, nOut})); + if (b) + bR = new NDArray(b->reshape(b->ordering(), {1, dirDim, 4, nOut})); + else + bR = new NDArray(x->ordering(), {1,dirDim,4,nOut}, x->dataType(), x->getContext()); // already nullifiedif (hI) hIR = new NDArray(hI->reshape(hI->ordering(), {1, dirDim, bS, nOut})); + if(hI) + hIR = new NDArray(hI->reshape(hI->ordering(), {1,dirDim,bS,nOut})); + if (cI) + cIR = new NDArray(cI->reshape(cI->ordering(), {1, dirDim, bS, nOut})); + if (hL) + hLR = new NDArray(hL->reshape(hL->ordering(), {1, dirDim, bS, nOut}, false)); + if (cL) + cLR = new NDArray(cL->reshape(cL->ordering(), {1, dirDim, bS, nOut}, false)); + + lstmLayerMKLDNN(xP, WxR, WrR, bR, hIR, cIR, params, hP, hLR, cLR); + + delete WxR; + delete WrR; + delete bR; + delete hIR; + delete cIR; + delete hLR; + delete cLR; + + if (dataFormat == 1) { + delete xP; + delete hP; + } + + return Status::OK(); } PLATFORM_CHECK(lstmLayer, ENGINE_CPU) { - - const auto dataFormat = INT_ARG(0); // for unidirectional: 0 = [sL, bS, nIn], 1 = [bS, sL ,nIn], 2 = [bS, nIn, sL], for bidirectional: 3 = [sL, 2, bS, nOut] (for ONNX) - const auto directionMode = INT_ARG(1); // direction: 0 = fwd, 1 = bwd, 2 = bidirectional sum, 3 = bidirectional concat, 4 = bidirectional extra output dim (in conjunction with format dataFormat = 3) - - const auto hasBiases = B_ARG(0); // indicates whether biases array is provided - const auto hasSeqLen = B_ARG(1); // indicates whether seqLen array is provided - const auto hasInitH = B_ARG(2); // indicates whether initial output is provided - const auto hasInitC = B_ARG(3); // indicates whether initial cell state is provided - const auto hasPH = B_ARG(4); // indicates whether peephole connections are present - const auto retFullSeq = B_ARG(5); // indicates whether to return whole time sequence h {h_0, h_1, ... , h_sL-1} - const auto retLastH = B_ARG(6); // indicates whether to return output at last time step only, in this case shape would be [bS, nOut] (exact shape depends on dataFormat argument) - const auto retLastC = B_ARG(7); // indicates whether to return cells state at last time step only, in this case shape would be [bS, nOut] (exact shape depends on dataFormat argument) - - const auto cellClip = T_ARG(0); // cell clipping value, if it = 0 then do not apply clipping - - const auto x = INPUT_VARIABLE(0); // input - const auto Wx = INPUT_VARIABLE(1); // input weights - const auto Wr = INPUT_VARIABLE(2); // recurrent weights - - int count = 3; - const auto b = hasBiases ? INPUT_VARIABLE(count++) : nullptr; // biases - const auto hI = hasInitH ? INPUT_VARIABLE(count++) : nullptr; // initial output - const auto cI = hasInitC ? INPUT_VARIABLE(count++) : nullptr; // initial cell state - - count = 0; - auto h = retFullSeq ? OUTPUT_VARIABLE(count++) : nullptr; // output - auto hL = retLastH ? OUTPUT_VARIABLE(count++) : nullptr; // output at last step - auto cL = retLastC ? OUTPUT_VARIABLE(count++) : nullptr; // cell state at last step - - DataType xType = x->dataType(); - DataType WxType = Wx->dataType(); - DataType WrType = Wr->dataType(); - DataType bType = b != nullptr ? b->dataType() : (xType == DataType::HALF ? xType : DataType::FLOAT32); - DataType hIType = hI != nullptr ? hI->dataType() : xType; - DataType cIType = cI != nullptr ? cI->dataType() : xType; - DataType hType = h != nullptr ? h->dataType() : xType; - DataType hLType = hL != nullptr ? hL->dataType() : xType; - DataType cLType = cL != nullptr ? cL->dataType() : xType; - - auto featuresSupported = (cellClip == 0) //Cell clipping not supported - && retFullSeq //Always return full sequence in case of MKL DNN - && !hasPH //Peephole connections not supported in MKL DNN - && !hasSeqLen //Sequence length array not supported in MKL DNN - && dataFormat < 2 //Data format - only 0 and 1 supported in MKL DNN- 0 = [sL, bS, nIn], 1 = [bS, sL ,nIn] - && directionMode < 4 //Direction mode - only 0-3 supported in MKL DNN (no extra dim option) - 0 = fwd, 1 = bwd, 2 = bidirectional sum, 3 = bidirectional concat - && retLastH == retLastC //Return both lastH and lastC, or return neither (not just 1 or other) - && hasInitH == hasInitC; //Need both or neither initial H and C - - return block.isUseMKLDNN() && featuresSupported && ( - (xType==DataType::FLOAT32 && WxType==DataType::FLOAT32 && WrType==DataType::FLOAT32 && bType==DataType::FLOAT32 && hIType==DataType::FLOAT32 && cIType==DataType::FLOAT32 && hType==DataType::FLOAT32 && hLType==DataType::FLOAT32 && cLType==DataType::FLOAT32) || - (xType==DataType::HALF && WxType==DataType::HALF && WrType==DataType::HALF && bType==DataType::HALF && hIType==DataType::HALF && cIType==DataType::HALF && hType==DataType::HALF && hLType==DataType::HALF && cLType==DataType::HALF) || - (xType==DataType::UINT8 && WxType==DataType::INT8 && WrType==DataType::INT8 && bType==DataType::FLOAT32 && hIType==DataType::UINT8 && cIType==DataType::UINT8 && (hType==DataType::FLOAT32 && hLType==DataType::FLOAT32 && cLType==DataType::FLOAT32 || hType==DataType::UINT8 && hLType==DataType::UINT8 && cLType==DataType::UINT8)) - ); + const auto dataFormat = INT_ARG( + 0); // for unidirectional: 0 = [sL, bS, nIn], 1 = [bS, sL ,nIn], 2 = [bS, + // nIn, sL], for bidirectional: 3 = [sL, 2, bS, nOut] (for ONNX) + const auto directionMode = + INT_ARG(1); // direction: 0 = fwd, 1 = bwd, 2 = bidirectional sum, 3 = + // bidirectional concat, 4 = bidirectional extra output dim + // (in conjunction with format dataFormat = 3) + + const auto hasBiases = + B_ARG(0); // indicates whether biases array is provided + const auto hasSeqLen = + B_ARG(1); // indicates whether seqLen array is provided + const auto hasInitH = + B_ARG(2); // indicates whether initial output is provided + const auto hasInitC = + B_ARG(3); // indicates whether initial cell state is provided + const auto hasPH = + B_ARG(4); // indicates whether peephole connections are present + const auto retFullSeq = B_ARG(5); // indicates whether to return whole time + // sequence h {h_0, h_1, ... , h_sL-1} + const auto retLastH = + B_ARG(6); // indicates whether to return output at last time step only, + // in this case shape would be [bS, nOut] (exact shape depends + // on dataFormat argument) + const auto retLastC = + B_ARG(7); // indicates whether to return cells state at last time step + // only, in this case shape would be [bS, nOut] (exact shape + // depends on dataFormat argument) + + const auto cellClip = + T_ARG(0); // cell clipping value, if it = 0 then do not apply clipping + + const auto x = INPUT_VARIABLE(0); // input + const auto Wx = INPUT_VARIABLE(1); // input weights + const auto Wr = INPUT_VARIABLE(2); // recurrent weights + + int count = 3; + const auto b = hasBiases ? INPUT_VARIABLE(count++) : nullptr; // biases + const auto hI = + hasInitH ? INPUT_VARIABLE(count++) : nullptr; // initial output + const auto cI = + hasInitC ? INPUT_VARIABLE(count++) : nullptr; // initial cell state + + count = 0; + auto h = retFullSeq ? OUTPUT_VARIABLE(count++) : nullptr; // output + auto hL = + retLastH ? OUTPUT_VARIABLE(count++) : nullptr; // output at last step + auto cL = + retLastC ? OUTPUT_VARIABLE(count++) : nullptr; // cell state at last step + + DataType xType = x->dataType(); + DataType WxType = Wx->dataType(); + DataType WrType = Wr->dataType(); + DataType bType = b != nullptr + ? b->dataType() + : (xType == DataType::HALF ? xType : DataType::FLOAT32); + DataType hIType = hI != nullptr ? hI->dataType() : xType; + DataType cIType = cI != nullptr ? cI->dataType() : xType; + DataType hType = h != nullptr ? h->dataType() : xType; + DataType hLType = hL != nullptr ? hL->dataType() : xType; + DataType cLType = cL != nullptr ? cL->dataType() : xType; + + auto featuresSupported = + (cellClip == 0) // Cell clipping not supported + && retFullSeq // Always return full sequence in case of MKL DNN + && !hasPH // Peephole connections not supported in MKL DNN + && !hasSeqLen // Sequence length array not supported in MKL DNN + && dataFormat < 2 // Data format - only 0 and 1 supported in MKL DNN- 0 = + // [sL, bS, nIn], 1 = [bS, sL ,nIn] + && directionMode < 4 // Direction mode - only 0-3 supported in MKL DNN + // (no extra dim option) - 0 = fwd, 1 = bwd, 2 = + // bidirectional sum, 3 = bidirectional concat + && retLastH == retLastC // Return both lastH and lastC, or return neither + // (not just 1 or other) + && hasInitH == hasInitC; // Need both or neither initial H and C + + return block.isUseMKLDNN() && featuresSupported && + ((xType == DataType::FLOAT32 && WxType == DataType::FLOAT32 && + WrType == DataType::FLOAT32 && bType == DataType::FLOAT32 && + hIType == DataType::FLOAT32 && cIType == DataType::FLOAT32 && + hType == DataType::FLOAT32 && hLType == DataType::FLOAT32 && + cLType == DataType::FLOAT32) || + (xType == DataType::HALF && WxType == DataType::HALF && + WrType == DataType::HALF && bType == DataType::HALF && + hIType == DataType::HALF && cIType == DataType::HALF && + hType == DataType::HALF && hLType == DataType::HALF && + cLType == DataType::HALF) || + (xType == DataType::UINT8 && WxType == DataType::INT8 && + WrType == DataType::INT8 && bType == DataType::FLOAT32 && + hIType == DataType::UINT8 && cIType == DataType::UINT8 && + (hType == DataType::FLOAT32 && hLType == DataType::FLOAT32 && + cLType == DataType::FLOAT32 || + hType == DataType::UINT8 && hLType == DataType::UINT8 && + cLType == DataType::UINT8))); } - - -} -} -} +} // namespace platforms +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/matmul.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/matmul.cpp index f242b2e79ee5..7a1908372012 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/matmul.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/matmul.cpp @@ -18,284 +18,355 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include +#include #include +#include #include -#include -#include "mkldnnUtils.h" #include +#include "mkldnnUtils.h" -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace platforms { ////////////////////////////////////////////////////////////////////////// -static void matmulMKLDNN(const NDArray* x, const NDArray* y, NDArray* z, const bool transX, const bool transY, float alpha = 1.f, float beta = 0.f) { - - // mkl works with following - // [M,K] x [K,N] = [M,N] - // [bS, M,K] x [bS, K,N] = [bS, M,N] - - // possible input cases not supported by mkl, however we'll perform permut/reshape procedures in order to fit requirements - // [4] x [4] = [1] --> [1,4] x [4,1] = [1,1] - // [4] x [4,5] = [5] --> [1,4] x [4,5] = [1,5] - // [4,5] x [5] = [4] --> [4,5] x [5,1] = [4,1] - // [2,3, 4,5] x [2,3, 5,4] = [2,3, 4,4] --> [6, 4,5] x [6, 5,4] = [6, 4,4] - // [2,2,3, 4,5] x [2,2,3, 5,4] = [2,2,3, 4,4] --> [12, 4,5] x [12, 5,4] = [12, 4,4] - - const auto xRank = x->rankOf(); - const auto yRank = y->rankOf(); - const auto zRank = z->rankOf(); - - std::vector permut; - - // fill permutation vector appropriately if transposition is required - if((transX && xRank > 1) || (transY && yRank > 1)) { - - const int rank = xRank >= yRank ? xRank : yRank; - permut.resize(rank); - std::iota(std::begin(permut), std::end(permut), 0); - permut[rank-2] = rank - 1; - permut[rank-1] = rank - 2; - } - - const NDArray* xT = (transX && xRank > 1) ? new NDArray(x->permute(permut)) : x; - const NDArray* yT = (transY && yRank > 1) ? new NDArray(y->permute(permut)) : y; - - const NDArray* xTR = xRank <= 3 ? xT : new NDArray(xT->reshape(xT->ordering(), {xT->lengthOf() / (xT->sizeAt(-2) * xT->sizeAt(-1)), xT->sizeAt(-2), xT->sizeAt(-1)})); - const NDArray* yTR = xRank <= 3 ? yT : new NDArray(yT->reshape(yT->ordering(), {yT->lengthOf() / (yT->sizeAt(-2) * yT->sizeAt(-1)), yT->sizeAt(-2), yT->sizeAt(-1)})); - NDArray* zR = xRank <= 3 ? z : new NDArray(z->reshape(z->ordering(), {z->lengthOf() / (z->sizeAt(-2) * z->sizeAt(-1)), z->sizeAt(-2), z->sizeAt(-1)})/*, false*/); - - // [M,K] x [K,N] = [M,N] - const int64_t M = (xRank > 1) ? xTR->sizeAt(-2) : 1; - const int64_t K = (xRank > 1) ? xTR->sizeAt(-1) : xTR->lengthOf(); - const int64_t N = (yRank > 1) ? yTR->sizeAt(-1) : 1; - const int64_t bS = (xRank > 2) ? xTR->sizeAt(0) : 1; // [bS, M,K] x [bS, K,N] = [bS, M,N] - - dnnl::memory::dims xShape = xRank < 3 ? dnnl::memory::dims({M, K}) : dnnl::memory::dims({bS, M, K}); - dnnl::memory::dims yShape = xRank < 3 ? dnnl::memory::dims({K, N}) : dnnl::memory::dims({bS, K, N}); - dnnl::memory::dims zShape = xRank < 3 ? dnnl::memory::dims({M, N}) : dnnl::memory::dims({bS, M, N}); - - // x type - dnnl::memory::data_type xType; - if(x->dataType() == DataType::FLOAT32) - xType = dnnl::memory::data_type::f32; - else if(x->dataType() == DataType::HALF) - xType = dnnl::memory::data_type::f16; - else if(x->dataType() == DataType::BFLOAT16) - xType = dnnl::memory::data_type::bf16; - else if(x->dataType() == DataType::UINT8) - xType = dnnl::memory::data_type::u8; - else - xType = dnnl::memory::data_type::s8; - - // y type - dnnl::memory::data_type yType = xType; - if(y->dataType() == DataType::UINT8) - yType = dnnl::memory::data_type::u8; - else if(y->dataType() == DataType::INT8) - yType = dnnl::memory::data_type::s8; - - // z type - dnnl::memory::data_type zType = xType; - if(z->dataType() == DataType::FLOAT32) - zType = dnnl::memory::data_type::f32; - else if(z->dataType() == DataType::INT32) - zType = dnnl::memory::data_type::s32; - else if(z->dataType() == DataType::UINT8) - zType = dnnl::memory::data_type::u8; - else if(z->dataType() == DataType::INT8) - zType = dnnl::memory::data_type::s8; - - - const auto xFormat = xRank == 1 ? dnnl::memory::format_tag::ab : mkldnnUtils::getFormat(*xTR); +static void matmulMKLDNN(const NDArray* x, const NDArray* y, NDArray* z, + const bool transX, const bool transY, + float alpha = 1.f, float beta = 0.f) { + // mkl works with following + // [M,K] x [K,N] = [M,N] + // [bS, M,K] x [bS, K,N] = [bS, M,N] + + // possible input cases not supported by mkl, however we'll perform + // permut/reshape procedures in order to fit requirements [4] x [4] + // = [1] --> [1,4] x [4,1] = [1,1] [4] x [4,5] = [5] + // --> [1,4] x [4,5] = [1,5] [4,5] x [5] = [4] --> + // [4,5] x [5,1] = [4,1] [2,3, 4,5] x [2,3, 5,4] = [2,3, 4,4] --> + // [6, 4,5] x [6, 5,4] = [6, 4,4] [2,2,3, 4,5] x [2,2,3, 5,4] = [2,2,3, 4,4] + // --> [12, 4,5] x [12, 5,4] = [12, 4,4] + + const auto xRank = x->rankOf(); + const auto yRank = y->rankOf(); + const auto zRank = z->rankOf(); + + std::vector permut; + + // fill permutation vector appropriately if transposition is required + if ((transX && xRank > 1) || (transY && yRank > 1)) { + const int rank = xRank >= yRank ? xRank : yRank; + permut.resize(rank); + std::iota(std::begin(permut), std::end(permut), 0); + permut[rank - 2] = rank - 1; + permut[rank - 1] = rank - 2; + } + + const NDArray* xT = + (transX && xRank > 1) ? new NDArray(x->permute(permut)) : x; + const NDArray* yT = + (transY && yRank > 1) ? new NDArray(y->permute(permut)) : y; + + const NDArray* xTR = + xRank <= 3 ? xT + : new NDArray(xT->reshape( + xT->ordering(), + {xT->lengthOf() / (xT->sizeAt(-2) * xT->sizeAt(-1)), + xT->sizeAt(-2), xT->sizeAt(-1)})); + const NDArray* yTR = + xRank <= 3 ? yT + : new NDArray(yT->reshape( + yT->ordering(), + {yT->lengthOf() / (yT->sizeAt(-2) * yT->sizeAt(-1)), + yT->sizeAt(-2), yT->sizeAt(-1)})); + NDArray* zR = + xRank <= 3 + ? z + : new NDArray(z->reshape( + z->ordering(), {z->lengthOf() / (z->sizeAt(-2) * z->sizeAt(-1)), + z->sizeAt(-2), z->sizeAt(-1)}) /*, false*/); + + // [M,K] x [K,N] = [M,N] + const int64_t M = (xRank > 1) ? xTR->sizeAt(-2) : 1; + const int64_t K = (xRank > 1) ? xTR->sizeAt(-1) : xTR->lengthOf(); + const int64_t N = (yRank > 1) ? yTR->sizeAt(-1) : 1; + const int64_t bS = + (xRank > 2) ? xTR->sizeAt(0) : 1; // [bS, M,K] x [bS, K,N] = [bS, M,N] + + dnnl::memory::dims xShape = + xRank < 3 ? dnnl::memory::dims({M, K}) : dnnl::memory::dims({bS, M, K}); + dnnl::memory::dims yShape = + xRank < 3 ? dnnl::memory::dims({K, N}) : dnnl::memory::dims({bS, K, N}); + dnnl::memory::dims zShape = + xRank < 3 ? dnnl::memory::dims({M, N}) : dnnl::memory::dims({bS, M, N}); + + // x type + dnnl::memory::data_type xType; + if (x->dataType() == DataType::FLOAT32) + xType = dnnl::memory::data_type::f32; + else if (x->dataType() == DataType::HALF) + xType = dnnl::memory::data_type::f16; + else if (x->dataType() == DataType::BFLOAT16) + xType = dnnl::memory::data_type::bf16; + else if (x->dataType() == DataType::UINT8) + xType = dnnl::memory::data_type::u8; + else + xType = dnnl::memory::data_type::s8; + + // y type + dnnl::memory::data_type yType = xType; + if (y->dataType() == DataType::UINT8) + yType = dnnl::memory::data_type::u8; + else if (y->dataType() == DataType::INT8) + yType = dnnl::memory::data_type::s8; + + // z type + dnnl::memory::data_type zType = xType; + if (z->dataType() == DataType::FLOAT32) + zType = dnnl::memory::data_type::f32; + else if (z->dataType() == DataType::INT32) + zType = dnnl::memory::data_type::s32; + else if (z->dataType() == DataType::UINT8) + zType = dnnl::memory::data_type::u8; + else if (z->dataType() == DataType::INT8) + zType = dnnl::memory::data_type::s8; + +const auto xFormat = xRank == 1 ? dnnl::memory::format_tag::ab : mkldnnUtils::getFormat(*xTR); const auto yFormat = yRank == 1 ? dnnl::memory::format_tag::ab : mkldnnUtils::getFormat(*yTR); const auto zFormat = zRank == 1 ? dnnl::memory::format_tag::ab : mkldnnUtils::getFormat(*zR); // memory descriptors for arrays dnnl::memory::desc x_mkl_md, x_user_md, y_mkl_md, y_user_md, z_mkl_md, z_user_md; - // x - x_user_md = x_mkl_md = dnnl::memory::desc(xShape, xType, xFormat); - if(xTR->ews() != 1) { - x_user_md.data.format_kind = dnnl_blocked; // overrides format - x_user_md.data.format_desc.blocking.strides[0] = xRank == 1 ? 1 : xTR->strideAt(0); - x_user_md.data.format_desc.blocking.strides[1] = xRank == 1 ? xTR->strideAt(0) : xTR->strideAt(1); - if(xRank > 2) - x_user_md.data.format_desc.blocking.strides[2] = xTR->strideAt(2); - } - - // y - y_user_md = y_mkl_md = dnnl::memory::desc(yShape, yType, yFormat); - if(yTR->ews() != 1) { - y_user_md.data.format_kind = dnnl_blocked; // overrides format - y_user_md.data.format_desc.blocking.strides[0] = yRank == 1 ? 1 : yTR->strideAt(0); - y_user_md.data.format_desc.blocking.strides[1] = yRank == 1 ? yTR->strideAt(0) : yTR->strideAt(1); - if(yRank > 2) - y_user_md.data.format_desc.blocking.strides[2] = yTR->strideAt(2); - } - - // z - z_user_md = z_mkl_md = dnnl::memory::desc(zShape, zType, zFormat); - if(zR->ews() != 1) { - z_user_md.data.format_kind = dnnl_blocked; // overrides format - z_user_md.data.format_desc.blocking.strides[0] = zRank == 1 ? 1 : zR->strideAt(0); - z_user_md.data.format_desc.blocking.strides[1] = zRank == 1 ? zR->strideAt(0) : zR->strideAt(1); - if(zRank > 2) - z_user_md.data.format_desc.blocking.strides[2] = zR->strideAt(2); - } - - auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - - // Create attributes (to handle alpha and beta if necessary) - dnnl::primitive_attr attr; // it is empty since we have usual values for alpha (=1) and beta (=0) - if (alpha != 1.f) attr.set_output_scales(0, {alpha}); - if (beta != 0.f) { - dnnl::post_ops po; - po.append_sum(beta); - attr.set_post_ops(po); - } - - // operation primitive description - dnnl::matmul::desc op_desc(x_mkl_md, y_mkl_md, z_mkl_md); - dnnl::matmul::primitive_desc op_prim_desc(op_desc, attr, engine); - - // arguments (memory buffers) necessary for calculations - std::unordered_map args; - - dnnl::stream stream(engine); - - // provide memory buffers and check whether reorder is required - - // input - mkldnnUtils::loadDataToMklStream(*xTR, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]); - - // y - mkldnnUtils::loadDataToMklStream(*yTR, engine, stream, y_user_md, op_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]); - - // z - auto z_user_mem = mkldnnUtils::loadDataToMklStream(*zR, engine, stream, z_user_md, op_prim_desc.dst_desc(), args[DNNL_ARG_DST]); - - // run calculations - dnnl::matmul(op_prim_desc).execute(stream, args); - - // reorder outputs if necessary - if (op_prim_desc.dst_desc() != z_user_mem.get_desc()) - dnnl::reorder(args[DNNL_ARG_DST], z_user_mem).execute(stream, args[DNNL_ARG_DST], z_user_mem); - - stream.wait(); - - if(zR->buffer() != z->buffer()) - z->assign(zR); - - if(zR != z) - delete zR; - if(xTR != xT) - delete xTR; - if(xT != x) - delete xT; - if(yTR != yT) - delete yTR; - if(yT != y) - delete yT; - - // shape::printArray(z_mkl_mem.map_data(),8); + // x + x_user_md = x_mkl_md = + dnnl::memory::desc(xShape, xType, xFormat); + if (xTR->ews() != 1) { + x_user_md.data.format_kind = dnnl_blocked; // overrides format + x_user_md.data.format_desc.blocking.strides[0] = + xRank == 1 ? 1 : xTR->strideAt(0); + x_user_md.data.format_desc.blocking.strides[1] = + xRank == 1 ? xTR->strideAt(0) : xTR->strideAt(1); + if (xRank > 2) + x_user_md.data.format_desc.blocking.strides[2] = xTR->strideAt(2); + } + + // y + y_user_md = y_mkl_md = + dnnl::memory::desc(yShape, yType, yFormat); + if (yTR->ews() != 1) { + y_user_md.data.format_kind = dnnl_blocked; // overrides format + y_user_md.data.format_desc.blocking.strides[0] = + yRank == 1 ? 1 : yTR->strideAt(0); + y_user_md.data.format_desc.blocking.strides[1] = + yRank == 1 ? yTR->strideAt(0) : yTR->strideAt(1); + if (yRank > 2) + y_user_md.data.format_desc.blocking.strides[2] = yTR->strideAt(2); + } + + // z + z_user_md = z_mkl_md = + dnnl::memory::desc(zShape, zType, zFormat); + if (zR->ews() != 1) { + z_user_md.data.format_kind = dnnl_blocked; // overrides format + z_user_md.data.format_desc.blocking.strides[0] = + zRank == 1 ? 1 : zR->strideAt(0); + z_user_md.data.format_desc.blocking.strides[1] = + zRank == 1 ? zR->strideAt(0) : zR->strideAt(1); + if (zRank > 2) + z_user_md.data.format_desc.blocking.strides[2] = zR->strideAt(2); + } + + auto engine = + mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); + + // Create attributes (to handle alpha and beta if necessary) + dnnl::primitive_attr attr; // it is empty since we have usual values for + // alpha (=1) and beta (=0) + if (alpha != 1.f) attr.set_output_scales(0, {alpha}); + if (beta != 0.f) { + dnnl::post_ops po; + po.append_sum(beta); + attr.set_post_ops(po); + } + + // operation primitive description + dnnl::matmul::desc op_desc(x_mkl_md, y_mkl_md, z_mkl_md); + dnnl::matmul::primitive_desc op_prim_desc(op_desc, attr, engine); + + // arguments (memory buffers) necessary for calculations + std::unordered_map args; + + dnnl::stream stream(engine); + + // provide memory buffers and check whether reorder is required + + // input + mkldnnUtils::loadDataToMklStream(*xTR, engine, stream, x_user_md, + op_prim_desc.src_desc(), args[DNNL_ARG_SRC]); + + + // y + mkldnnUtils::loadDataToMklStream(*yTR, engine, stream, y_user_md, + op_prim_desc.weights_desc(), + args[DNNL_ARG_WEIGHTS]); + + // z + auto z_user_mem = mkldnnUtils::loadDataToMklStream(*zR, engine, stream, z_user_md, op_prim_desc.dst_desc(), + args[DNNL_ARG_DST] ); + + // run calculations + dnnl::matmul(op_prim_desc).execute(stream, args); + + // reorder outputs if necessary + if (op_prim_desc.dst_desc() != z_user_mem.get_desc()) + dnnl::reorder(args[DNNL_ARG_DST], z_user_mem).execute(stream, args[DNNL_ARG_DST], z_user_mem); + + stream.wait(); + + if (zR->buffer() != z->buffer()) z->assign(zR); + + if (zR != z) delete zR; + if (xTR != xT) delete xTR; + if (xT != x) delete xT; + if (yTR != yT) delete yTR; + if (yT != y) delete yT; + + // shape::printArray(z_mkl_mem.map_data(),8); } ////////////////////////////////////////////////////////////////////////// PLATFORM_IMPL(matmul, ENGINE_CPU) { - - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); - - if(x->isEmpty() || y->isEmpty()) - return Status::OK(); - - int iSize = (int) block.getIArguments()->size(); - int transX = iSize > 0 ? INT_ARG(0) : 0; - int transY = iSize > 1 ? INT_ARG(1) : 0; - const int transZ = iSize > 2 ? INT_ARG(2) : 0; - - // optional use alpha nad beta - iSize = (int)block.getTArguments()->size(); - float alpha = iSize > 0 ? T_ARG(0) : 1.0; - float beta = iSize > 1 ? T_ARG(1) : 0.0; - - const int xRank = x->rankOf(); - const int yRank = y->rankOf(); - const int zRank = z->rankOf(); - - if (transZ) { - x = INPUT_VARIABLE(1); - y = INPUT_VARIABLE(0); - bool temp = transX; - transX = !transY; - transY = !temp; - } - - const int xLastDim = transX ? -2 : -1; - const int yLastDim = transY ? -2 : -1; - const int xLastButOneDim = transX ? -1 : -2; - const int yLastButOneDim = transY ? -1 : -2; - - // ******* input validation ******* // - REQUIRE_TRUE(xRank > 0 && yRank > 0, 0, "MATMUL MKLDNN OP: input arrays must have rank bigger than 0 (should not be scalars), but got instead: x rank = %i, y rank = %i !", xRank, yRank); - - if (xRank == 1 && yRank == 1) { // dot case, output is scalar (or vector with length = 1) - REQUIRE_TRUE(x->lengthOf() == y->lengthOf(), 0,"MATMUL MKLDNN OP: since input arrays are vectors they must have the same length, but got x length = %i, y length = %i !",x->lengthOf(), y->lengthOf()); - } else if (xRank == 1 && yRank == 2) { // vector x matrix, i.e. [4] x [4,5] = [5], output is vector - REQUIRE_TRUE(x->lengthOf() == y->sizeAt(yLastButOneDim), 0, "MATMUL MKLDNN OP: input arrays have inconsistent shapes for vector-matrix product: x %s, y %s !", ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str()); - } else if (xRank == 2 && yRank == 1) { // matrix x vector , i.e. [4,5] x [5] = [4], output is vector - REQUIRE_TRUE(x->sizeAt(xLastDim) == y->lengthOf(), 0, "MATMUL MKLDNN OP: input arrays have inconsistent shapes for matrix-vector product: x %s, y %s !", ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str()); - } else { - REQUIRE_TRUE(xRank == yRank && yRank == zRank, 0, "MATMUL MKLDNN OP: input and output arrays must have the same rank, but got instead: x rank = %i, y rank = %i, z rank = %i !", xRank, yRank, zRank); - REQUIRE_TRUE(x->sizeAt(xLastDim) == y->sizeAt(yLastButOneDim) && x->sizeAt(xLastButOneDim) == z->sizeAt(-2) && y->sizeAt(yLastDim) == z->sizeAt(-1), 0, "MATMUL MKLDNN OP: input/output arrays have inconsistent shapes for matrix product: x %s, y %s, z %s !", ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str(), ShapeUtils::shapeAsString(z).c_str()); - - if (xRank > 2) // outer dims must be the same - for (int i = 0; i < xRank - 2; ++i) - REQUIRE_TRUE(x->sizeAt(i) == y->sizeAt(i) && y->sizeAt(i) == z->sizeAt(i), 0, "MATMUL MKLDNN OP: input/output arrays have inconsistent shapes for matrix product: x %s, y %s, z %s !", ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str(), ShapeUtils::shapeAsString(z).c_str()); - } - // ******* end of input validation ******* // - - matmulMKLDNN(x, y, z, transX, transY, alpha, beta); - - return Status::OK(); + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); + + if (x->isEmpty() || y->isEmpty()) return Status::OK(); + + int iSize = (int)block.numI(); + int transX = iSize > 0 ? INT_ARG(0) : 0; + int transY = iSize > 1 ? INT_ARG(1) : 0; + const int transZ = iSize > 2 ? INT_ARG(2) : 0; + + // optional use alpha nad beta + iSize = (int)block.numT(); + float alpha = iSize > 0 ? T_ARG(0) : 1.0; + float beta = iSize > 1 ? T_ARG(1) : 0.0; + + const int xRank = x->rankOf(); + const int yRank = y->rankOf(); + const int zRank = z->rankOf(); + + if (transZ) { + x = INPUT_VARIABLE(1); + y = INPUT_VARIABLE(0); + bool temp = transX; + transX = !transY; + transY = !temp; + } + + const int xLastDim = transX ? -2 : -1; + const int yLastDim = transY ? -2 : -1; + const int xLastButOneDim = transX ? -1 : -2; + const int yLastButOneDim = transY ? -1 : -2; + + // ******* input validation ******* // + REQUIRE_TRUE( + xRank > 0 && yRank > 0, 0, + "MATMUL MKLDNN OP: input arrays must have rank bigger than 0 (should not " + "be scalars), but got instead: x rank = %i, y rank = %i !", + xRank, yRank); + + if (xRank == 1 && + yRank == 1) { // dot case, output is scalar (or vector with length = 1) + REQUIRE_TRUE(x->lengthOf() == y->lengthOf(), 0, + "MATMUL MKLDNN OP: since input arrays are vectors they must " + "have the same length, but got x length = %i, y length = %i !", + x->lengthOf(), y->lengthOf()); + } else if (xRank == 1 && yRank == 2) { // vector x matrix, i.e. [4] x [4,5] = + // [5], output is vector + REQUIRE_TRUE(x->lengthOf() == y->sizeAt(yLastButOneDim), 0, + "MATMUL MKLDNN OP: input arrays have inconsistent shapes for " + "vector-matrix product: x %s, y %s !", + ShapeUtils::shapeAsString(x).c_str(), + ShapeUtils::shapeAsString(y).c_str()); + } else if (xRank == 2 && yRank == 1) { // matrix x vector , i.e. [4,5] x [5] + // = [4], output is vector + REQUIRE_TRUE(x->sizeAt(xLastDim) == y->lengthOf(), 0, + "MATMUL MKLDNN OP: input arrays have inconsistent shapes for " + "matrix-vector product: x %s, y %s !", + ShapeUtils::shapeAsString(x).c_str(), + ShapeUtils::shapeAsString(y).c_str()); + } else { + REQUIRE_TRUE( + xRank == yRank && yRank == zRank, 0, + "MATMUL MKLDNN OP: input and output arrays must have the same rank, " + "but got instead: x rank = %i, y rank = %i, z rank = %i !", + xRank, yRank, zRank); + REQUIRE_TRUE(x->sizeAt(xLastDim) == y->sizeAt(yLastButOneDim) && + x->sizeAt(xLastButOneDim) == z->sizeAt(-2) && + y->sizeAt(yLastDim) == z->sizeAt(-1), + 0, + "MATMUL MKLDNN OP: input/output arrays have inconsistent " + "shapes for matrix product: x %s, y %s, z %s !", + ShapeUtils::shapeAsString(x).c_str(), + ShapeUtils::shapeAsString(y).c_str(), + ShapeUtils::shapeAsString(z).c_str()); + + if (xRank > 2) // outer dims must be the same + for (int i = 0; i < xRank - 2; ++i) + REQUIRE_TRUE( + x->sizeAt(i) == y->sizeAt(i) && y->sizeAt(i) == z->sizeAt(i), 0, + "MATMUL MKLDNN OP: input/output arrays have inconsistent shapes " + "for matrix product: x %s, y %s, z %s !", + ShapeUtils::shapeAsString(x).c_str(), + ShapeUtils::shapeAsString(y).c_str(), + ShapeUtils::shapeAsString(z).c_str()); + } + // ******* end of input validation ******* // + + matmulMKLDNN(x, y, z, transX, transY, alpha, beta); + + return Status::OK(); } ////////////////////////////////////////////////////////////////////////// PLATFORM_CHECK(matmul, ENGINE_CPU) { - - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - - auto z = OUTPUT_VARIABLE(0); - - const auto xType = x->dataType(); - const auto yType = y->dataType(); - const auto zType = z->dataType(); - - float alpha = block.numT() > 0 ? T_ARG(0) : 1.0f; - float beta = block.numT() > 1 ? T_ARG(1) : 0.0f; - - // we're skipping if result order is F or arrays are not continuous - bool skip2D = z->rankOf() == 2 && (z->ordering() == 'f' || x->ews() != 1 || y->ews() != 1 || z->ews() != 1); - - // we're skipping 3D cases if they are not C continuoys - bool skip3D = z->rankOf() == 3 && (x->ordering() == 'f' || y->ordering() == 'f' || z->ordering() == 'f' || x->ews() != 1 || y->ews() != 1 || z->ews() != 1); - - return !skip2D && !skip3D && block.isUseMKLDNN() && x->rankOf() < 3 && - ( - (xType==DataType::FLOAT32 && yType==DataType::FLOAT32 && zType==DataType::FLOAT32) || - (xType==DataType::HALF && yType==DataType::HALF && zType==DataType::FLOAT32) || - (xType==DataType::BFLOAT16 && yType==DataType::BFLOAT16 && zType==DataType::BFLOAT16) || - ((xType==DataType::UINT8 || xType==DataType::INT8) && (yType==DataType::UINT8 || yType==DataType::INT8) && (zType==DataType::UINT8 || zType==DataType::INT8 || zType==DataType::INT32 || zType==DataType::FLOAT32)) - ); + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + + auto z = OUTPUT_VARIABLE(0); + + const auto xType = x->dataType(); + const auto yType = y->dataType(); + const auto zType = z->dataType(); + + float alpha = block.numT() > 0 ? T_ARG(0) : 1.0f; + float beta = block.numT() > 1 ? T_ARG(1) : 0.0f; + + // we're skipping if result order is F or arrays are not continuous + bool skip2D = z->rankOf() == 2 && (z->ordering() == 'f' || x->ews() != 1 || + y->ews() != 1 || z->ews() != 1); + + // we're skipping 3D cases if they are not C continuoys + bool skip3D = + z->rankOf() == 3 && + (x->ordering() == 'f' || y->ordering() == 'f' || z->ordering() == 'f' || + x->ews() != 1 || y->ews() != 1 || z->ews() != 1); + + return !skip2D && !skip3D && block.isUseMKLDNN() && x->rankOf() < 3 && + ((xType == DataType::FLOAT32 && yType == DataType::FLOAT32 && + zType == DataType::FLOAT32) || + (xType == DataType::HALF && yType == DataType::HALF && + zType == DataType::FLOAT32) || + (xType == DataType::BFLOAT16 && yType == DataType::BFLOAT16 && + zType == DataType::BFLOAT16) || + ((xType == DataType::UINT8 || xType == DataType::INT8) && + (yType == DataType::UINT8 || yType == DataType::INT8) && + (zType == DataType::UINT8 || zType == DataType::INT8 || + zType == DataType::INT32 || zType == DataType::FLOAT32))); } - -} -} -} +} // namespace platforms +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling2d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling2d.cpp index 50b3fafa5625..b91b1b8f599a 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling2d.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling2d.cpp @@ -20,108 +20,143 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include +#include #include +#include +#include #include -#include #include "mkldnnUtils.h" -#include using namespace dnnl; -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace platforms { - ////////////////////////////////////////////////////////////////////////// PLATFORM_IMPL(maxpool2d, ENGINE_CPU) { - - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - REQUIRE_TRUE(input->rankOf() == 4, 0, "MAXPOOL2D MKLDNN OP: input array should have rank of 4, but got %i instead", input->rankOf()); - - // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; - const int kH = INT_ARG(0); - const int kW = INT_ARG(1); - const int sH = INT_ARG(2); - const int sW = INT_ARG(3); - int pH = INT_ARG(4); - int pW = INT_ARG(5); - const int dH = INT_ARG(6); - const int dW = INT_ARG(7); - const int paddingMode = INT_ARG(8); - // const int extraParam0 = INT_ARG(9); - const int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 1-NHWC, 0-NCHW - - REQUIRE_TRUE(dH != 0 && dW != 0, 0, "MAXPOOL2D MKLDNN op: dilation must not be zero, but got instead {%i, %i}", dH, dW); - - int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; - int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, 0, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); - - if (paddingMode) - ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); - - mkldnnUtils::poolingMKLDNN(input, output, 0,kH,kW, 0,sH,sW, 0,pH,pW, isNCHW, algorithm::pooling_max); - - return Status::OK(); + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + REQUIRE_TRUE(input->rankOf() == 4, 0, + "MAXPOOL2D MKLDNN OP: input array should have rank of 4, but " + "got %i instead", + input->rankOf()); + + // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad + // Height/Width; 6,7 - dilation Height/Width; 8 - same mode; + const int kH = INT_ARG(0); + const int kW = INT_ARG(1); + const int sH = INT_ARG(2); + const int sW = INT_ARG(3); + int pH = INT_ARG(4); + int pW = INT_ARG(5); + const int dH = INT_ARG(6); + const int dW = INT_ARG(7); + const int paddingMode = INT_ARG(8); + // const int extraParam0 = INT_ARG(9); + const int isNCHW = + block.numI() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 1-NHWC, 0-NCHW + + REQUIRE_TRUE(dH != 0 && dW != 0, 0, + "MAXPOOL2D MKLDNN op: dilation must not be zero, but got " + "instead {%i, %i}", + dH, dW); + + int bS, iC, iH, iW, oC, oH, + oW; // batch size, input channels, input height/width, output channels, + // output height/width; + int indIOioC, indIiH, indWoC, indWiC, indWkH, + indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, 0, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, + indWiC, indWoC, indWkH, indOoH); + + if (paddingMode) + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, + dW); + + mkldnnUtils::poolingMKLDNN(input, output, 0, kH, kW, 0, sH, sW, 0, pH, pW, + isNCHW, algorithm::pooling_max); + + return Status::OK(); } ////////////////////////////////////////////////////////////////////////// PLATFORM_CHECK(maxpool2d, ENGINE_CPU) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - return block.isUseMKLDNN() && sd::MKLDNNStream::isSupported({input, output}); + return block.isUseMKLDNN() && sd::MKLDNNStream::isSupported({input, output}); } ////////////////////////////////////////////////////////////////////////// PLATFORM_IMPL(maxpool2d_bp, ENGINE_CPU) { - - auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - auto gradO = INPUT_VARIABLE(1); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next - auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon - - int kH = INT_ARG(0); // filter(kernel) height - int kW = INT_ARG(1); // filter(kernel) width - int sH = INT_ARG(2); // strides height - int sW = INT_ARG(3); // strides width - int pH = INT_ARG(4); // paddings height - int pW = INT_ARG(5); // paddings width - int dH = INT_ARG(6); // dilations height - int dW = INT_ARG(7); // dilations width - int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME - // int extraParam0 = INT_ARG(9); - int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC - - REQUIRE_TRUE(input->rankOf() == 4, 0, "MAXPOOL2D_BP MKLDNN op: input should have rank of 4, but got %i instead", input->rankOf()); - REQUIRE_TRUE(dH != 0 && dW != 0, 0, "MAXPOOL2D_BP MKLDNN op: dilation must not be zero, but got instead {%i, %i}", dH, dW); - - int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; - int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, 0, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); - - std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, oH, oW, 0, indIOioC, indIiH, indIiH + 1}); - REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "MAXPOOL2D_BP MKLDNN op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); - - if (paddingMode) // SAME - ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); - - mkldnnUtils::poolingBpMKLDNN(input, gradO, gradI, 0,kH,kW, 0,sH,sW, 0,pH,pW, isNCHW, algorithm::pooling_max); - - return Status::OK(); + auto input = + INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + auto gradO = INPUT_VARIABLE( + 1); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next + auto gradI = OUTPUT_VARIABLE( + 0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon + + int kH = INT_ARG(0); // filter(kernel) height + int kW = INT_ARG(1); // filter(kernel) width + int sH = INT_ARG(2); // strides height + int sW = INT_ARG(3); // strides width + int pH = INT_ARG(4); // paddings height + int pW = INT_ARG(5); // paddings width + int dH = INT_ARG(6); // dilations height + int dW = INT_ARG(7); // dilations width + int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME + // int extraParam0 = INT_ARG(9); + int isNCHW = + block.numI() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC + + REQUIRE_TRUE( + input->rankOf() == 4, 0, + "MAXPOOL2D_BP MKLDNN op: input should have rank of 4, but got %i instead", + input->rankOf()); + REQUIRE_TRUE(dH != 0 && dW != 0, 0, + "MAXPOOL2D_BP MKLDNN op: dilation must not be zero, but got " + "instead {%i, %i}", + dH, dW); + + int bS, iC, iH, iW, oC, oH, + oW; // batch size, input channels, input height/width, output channels, + // output height/width; + int indIOioC, indIiH, indWoC, indWiC, indWkH, + indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, 0, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, + indWiC, indWoC, indWkH, indOoH); + + std::vector expectedGradOShape = + ShapeUtils::composeShapeUsingDimsAndIdx( + {bS, iC, oH, oW, 0, indIOioC, indIiH, indIiH + 1}); + REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, + "MAXPOOL2D_BP MKLDNN op: wrong shape of output's gradients " + "array (next epsilon), expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedGradOShape).c_str(), + ShapeUtils::shapeAsString(gradO).c_str()); + + if (paddingMode) // SAME + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, + dW); + + mkldnnUtils::poolingBpMKLDNN(input, gradO, gradI, 0, kH, kW, 0, sH, sW, 0, pH, + pW, isNCHW, algorithm::pooling_max); + + return Status::OK(); } PLATFORM_CHECK(maxpool2d_bp, ENGINE_CPU) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - return block.isUseMKLDNN() && sd::MKLDNNStream::isSupported({input, output}); + return block.isUseMKLDNN() && sd::MKLDNNStream::isSupported({input, output}); } -} -} -} \ No newline at end of file +} // namespace platforms +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling3d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling3d.cpp index 078b45ba0bf9..c1a0b35af33e 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling3d.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling3d.cpp @@ -19,115 +19,150 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include +#include #include +#include +#include #include -#include #include "mkldnnUtils.h" -#include using namespace dnnl; -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace platforms { ////////////////////////////////////////////////////////////////////////// PLATFORM_IMPL(maxpool3dnew, ENGINE_CPU) { - - auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) - auto output = OUTPUT_VARIABLE(0); // [bS, oD, oH, oW, iC] (NDHWC) or [bS, iC, oD, oH, oW] (NCDHW) - - int kD = INT_ARG(0); // filter(kernel) depth - int kH = INT_ARG(1); // filter(kernel) height - int kW = INT_ARG(2); // filter(kernel) width - int sD = INT_ARG(3); // strides depth - int sH = INT_ARG(4); // strides height - int sW = INT_ARG(5); // strides width - int pD = INT_ARG(6); // paddings depth - int pH = INT_ARG(7); // paddings height - int pW = INT_ARG(8); // paddings width - int dD = INT_ARG(9); // dilations depth - int dH = INT_ARG(10); // dilations height - int dW = INT_ARG(11); // dilations width - int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID - // int extraParam0 = INT_ARG(13); // unnecessary for max case, required only for avg and pnorm cases - int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1; // 1-NDHWC, 0-NCDHW - - REQUIRE_TRUE(input->rankOf() == 5, 0, "MAXPOOL3DNEW MKLDNN OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf()); - REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "MAXPOOL3DNEW MKLDNN op: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW); - - int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; - int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, 0, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); - - if(paddingMode) // SAME - ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW); - - mkldnnUtils::poolingMKLDNN(input, output, kD,kH,kW, sD,sH,sW, pD,pH,pW, isNCDHW, algorithm::pooling_max); - - return Status::OK(); - + auto input = INPUT_VARIABLE( + 0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) + auto output = OUTPUT_VARIABLE( + 0); // [bS, oD, oH, oW, iC] (NDHWC) or [bS, iC, oD, oH, oW] (NCDHW) + + int kD = INT_ARG(0); // filter(kernel) depth + int kH = INT_ARG(1); // filter(kernel) height + int kW = INT_ARG(2); // filter(kernel) width + int sD = INT_ARG(3); // strides depth + int sH = INT_ARG(4); // strides height + int sW = INT_ARG(5); // strides width + int pD = INT_ARG(6); // paddings depth + int pH = INT_ARG(7); // paddings height + int pW = INT_ARG(8); // paddings width + int dD = INT_ARG(9); // dilations depth + int dH = INT_ARG(10); // dilations height + int dW = INT_ARG(11); // dilations width + int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID + // int extraParam0 = INT_ARG(13); // + // unnecessary for max case, required only for avg and pnorm cases + int isNCDHW = block.numI() > 14 ? !INT_ARG(14) : 1; // 1-NDHWC, 0-NCDHW + + REQUIRE_TRUE(input->rankOf() == 5, 0, + "MAXPOOL3DNEW MKLDNN OP: rank of input array must be equal to " + "5, but got %i instead !", + input->rankOf()); + REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, + "MAXPOOL3DNEW MKLDNN op: dilation must not be zero, but got " + "instead {%i, %i, %i}", + dD, dH, dW); + + int bS, iC, iD, iH, iW, oC, oD, oH, + oW; // batch size, input channels, input depth/height/width, output + // channels, output depth/height/width; + int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv3d( + isNCDHW, 0, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, + indIOioD, indWiC, indWoC, indWkD); + + if (paddingMode) // SAME + ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, + kW, sD, sH, sW, dD, dH, dW); + + mkldnnUtils::poolingMKLDNN(input, output, kD, kH, kW, sD, sH, sW, pD, pH, pW, + isNCDHW, algorithm::pooling_max); + + return Status::OK(); } ////////////////////////////////////////////////////////////////////////// PLATFORM_CHECK(maxpool3dnew, ENGINE_CPU) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - return block.isUseMKLDNN() && sd::MKLDNNStream::isSupported({input, output}); + return block.isUseMKLDNN() && sd::MKLDNNStream::isSupported({input, output}); } ////////////////////////////////////////////////////////////////////////// PLATFORM_IMPL(maxpool3dnew_bp, ENGINE_CPU) { - - auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) - auto gradO = INPUT_VARIABLE(1); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next - auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon - - const int kD = INT_ARG(0); // filter(kernel) depth - const int kH = INT_ARG(1); // filter(kernel) height - const int kW = INT_ARG(2); // filter(kernel) width - const int sD = INT_ARG(3); // strides depth - const int sH = INT_ARG(4); // strides height - const int sW = INT_ARG(5); // strides width - int pD = INT_ARG(6); // paddings depth - int pH = INT_ARG(7); // paddings height - int pW = INT_ARG(8); // paddings width - const int dD = INT_ARG(9); // dilations depth - const int dH = INT_ARG(10); // dilations height - const int dW = INT_ARG(11); // dilations width - const int paddngMode = INT_ARG(12); // 1-SAME, 0-VALID - // int extraParam0 = INT_ARG(13); // unnecessary for max case, required only for avg and pnorm cases - int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1; // 1-NDHWC, 0-NCDHW - - REQUIRE_TRUE(input->rankOf() == 5, 0, "MAXPOOL3DNEW_BP MKLDNN op: input should have rank of 5, but got %i instead", input->rankOf()); - REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "MAXPOOL3DNEW_BP MKLDNN op: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW); - - int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; - int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, 0, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); - - std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); - REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "MAXPOOL3DNEW_BP MKLDNN op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); - - if(paddngMode) // SAME - ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW); - - mkldnnUtils::poolingBpMKLDNN(input, gradO, gradI, kD,kH,kW, sD,sH,sW, pD,pH,pW, isNCDHW, algorithm::pooling_max); - - return Status::OK(); + auto input = INPUT_VARIABLE( + 0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) + auto gradO = INPUT_VARIABLE(1); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, + // oD, oH, oW] (NCDHW), epsilon_next + auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, + // iD, iH, iW] (NCDHW), epsilon + + const int kD = INT_ARG(0); // filter(kernel) depth + const int kH = INT_ARG(1); // filter(kernel) height + const int kW = INT_ARG(2); // filter(kernel) width + const int sD = INT_ARG(3); // strides depth + const int sH = INT_ARG(4); // strides height + const int sW = INT_ARG(5); // strides width + int pD = INT_ARG(6); // paddings depth + int pH = INT_ARG(7); // paddings height + int pW = INT_ARG(8); // paddings width + const int dD = INT_ARG(9); // dilations depth + const int dH = INT_ARG(10); // dilations height + const int dW = INT_ARG(11); // dilations width + const int paddngMode = INT_ARG(12); // 1-SAME, 0-VALID + // int extraParam0 = INT_ARG(13); // + // unnecessary for max case, required only for avg and pnorm cases + int isNCDHW = block.numI() > 14 ? !INT_ARG(14) : 1; // 1-NDHWC, 0-NCDHW + + REQUIRE_TRUE(input->rankOf() == 5, 0, + "MAXPOOL3DNEW_BP MKLDNN op: input should have rank of 5, but " + "got %i instead", + input->rankOf()); + REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, + "MAXPOOL3DNEW_BP MKLDNN op: dilation must not be zero, but got " + "instead {%i, %i, %i}", + dD, dH, dW); + + int bS, iC, iD, iH, iW, oC, oD, oH, + oW; // batch size, input channels, input depth/height/width, output + // channels, output depth/height/width; + int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv3d( + isNCDHW, 0, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, + indIOioD, indWiC, indWoC, indWkD); + + std::vector expectedGradOShape = + ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, oD, oH, oW, 0, indIOioC, + indIOioD, indIOioD + 1, + indIOioD + 2}); + REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, + "MAXPOOL3DNEW_BP MKLDNN op: wrong shape of output's gradients " + "array (next epsilon), expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedGradOShape).c_str(), + ShapeUtils::shapeAsString(gradO).c_str()); + + if (paddngMode) // SAME + ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, + kW, sD, sH, sW, dD, dH, dW); + + mkldnnUtils::poolingBpMKLDNN(input, gradO, gradI, kD, kH, kW, sD, sH, sW, pD, + pH, pW, isNCDHW, algorithm::pooling_max); + + return Status::OK(); } ////////////////////////////////////////////////////////////////////////// PLATFORM_CHECK(maxpool3dnew_bp, ENGINE_CPU) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - return block.isUseMKLDNN() && sd::MKLDNNStream::isSupported({input, output}); + return block.isUseMKLDNN() && sd::MKLDNNStream::isSupported({input, output}); } -} -} -} \ No newline at end of file +} // namespace platforms +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.cpp index dcc0258f42b6..2763e8fe2ffb 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.cpp @@ -19,23 +19,24 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // +#include "mkldnnUtils.h" + #include #include -#include "mkldnnUtils.h" using namespace dnnl; -namespace sd { +namespace sd { namespace mkldnnUtils { ////////////////////////////////////////////////////////////////////// -void getDims(const NDArray* array, const int rank, dnnl::memory::dims& mklDims){ - - std::vector vDims(rank); - for (auto i = 0; i < rank; i++) { - vDims[i] = array->sizeAt(i); - } - mklDims = dnnl::memory::dims(vDims); +void getDims(const NDArray* array, const int rank, + dnnl::memory::dims& mklDims) { + std::vector vDims(rank); + for (auto i = 0; i < rank; i++) { + vDims[i] = array->sizeAt(i); + } + mklDims = dnnl::memory::dims(vDims); } ////////////////////////////////////////////////////////////////////// dnnl::memory::format_tag getFormat(const NDArray& arr) { @@ -70,300 +71,360 @@ dnnl::memory::format_tag getFormat(const NDArray& arr) { ////////////////////////////////////////////////////////////////////// void setBlockStrides(const NDArray& array, dnnl::memory::desc& mklMd, const std::vector& permut) { - - if (array.ews() != 1 || (array.rankOf() > 3 && array.ordering() == 'f') || !permut.empty()) { - - mklMd.data.format_kind = dnnl_blocked; // overrides format - - if(permut.empty()) - for (auto i = 0; i < array.rankOf(); ++i) - mklMd.data.format_desc.blocking.strides[i] = array.strideAt(i); - else { - if(array.rankOf() != permut.size()) - throw std::invalid_argument("mkldnnUtils::setBlockStrides: size of permut vector is not equal to array rank !"); - for (auto i = 0; i < array.rankOf(); ++i) - mklMd.data.format_desc.blocking.strides[i] = array.strideAt(permut[i]); - } - } + if (array.ews() != 1 || (array.rankOf() > 3 && array.ordering() == 'f') || !permut.empty()) { + mklMd.data.format_kind = dnnl_blocked; // overrides format + if(permut.empty()) + for (auto i = 0; i < array.rankOf(); ++i) + mklMd.data.format_desc.blocking.strides[i] = array.strideAt(i); + else { + if(array.rankOf() != permut.size()) + throw std::invalid_argument("mkldnnUtils::setBlockStrides: size of permut vector is not equal to array rank !"); + + for (auto i = 0; i < array.rankOf(); ++i) + mklMd.data.format_desc.blocking.strides[i] = array.strideAt(permut[i]); + } + } } + //////////////////////////////////////////////////////////////////////////////////////////////// -dnnl::memory loadDataToMklStream(const NDArray& array, const dnnl::engine& engine, const dnnl::stream& stream, - const dnnl::memory::desc& user_md, const dnnl::memory::desc& primitive_md, dnnl::memory& arg) { - - auto user_mem = dnnl::memory(user_md, engine, const_cast(array).buffer()); - const bool bReorder = primitive_md != user_mem.get_desc(); - auto mkl_mem = bReorder ? dnnl::memory(primitive_md, engine) : user_mem; - if (bReorder) - dnnl::reorder(user_mem, mkl_mem).execute(stream, user_mem, mkl_mem); - arg = mkl_mem; - return user_mem; +dnnl::memory loadDataToMklStream(const NDArray& array, const dnnl::engine& engine, + const dnnl::stream& stream, + const dnnl::memory::desc& user_md, + const dnnl::memory::desc& primitive_md, + dnnl::memory& arg) { + auto user_mem = + dnnl::memory(user_md, engine, const_cast(array).buffer()); + const bool bReorder = primitive_md != user_mem.get_desc(); + auto mkl_mem = bReorder ? dnnl::memory(primitive_md, engine) : user_mem; + if (bReorder) + dnnl::reorder(user_mem, mkl_mem).execute(stream, user_mem, mkl_mem); + arg = mkl_mem; + return user_mem; } ////////////////////////////////////////////////////////////////////// -void poolingMKLDNN(const NDArray *input, NDArray *output, - const int kD, const int kH, const int kW, - const int sD, const int sH, const int sW, - const int pD, const int pH, const int pW, - const int isNCHW, const dnnl::algorithm mode) { - - // unfortunately mkl dnn doesn't support any format (dnnl::memory::format_tag::any) for input - const int rank = input->rankOf(); - - int bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; - dnnl::memory::dims strides, kernel, padding, padding_r, xDims, zDims; - dnnl::memory::format_tag xzFrmat; - - const auto type = dnnl::memory::data_type::f32; - - if(rank == 4) { // 2d - - ops::ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, 0, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); - - strides = { sH, sW }; - kernel = { kH, kW }; - padding = { pH, pW }; - padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pW }; - xDims = {bS, iC, iH, iW}; - zDims = {bS, oC, oH, oW}; - - xzFrmat = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; - } - else { // 3d - - ops::ConvolutionUtils::getSizesAndIndexesConv3d(isNCHW, 0, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH); - - strides = { sD, sH, sW }; - kernel = { kD, kH, kW }; - padding = { pD, pH, pW }; - padding_r = { (oD - 1) * sD - iD + kD - pD, (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pW }; - xDims = {bS, iC, iD, iH, iW}; - zDims = {bS, oC, oD, oH, oW}; - - xzFrmat = isNCHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc; - } - - std::vector permut; +void poolingMKLDNN(const NDArray* input, NDArray* output, const int kD, + const int kH, const int kW, const int sD, const int sH, + const int sW, const int pD, const int pH, const int pW, + const int isNCHW, const dnnl::algorithm mode) { + // unfortunately mkl dnn doesn't support any format + // (dnnl::memory::format_tag::any) for input + const int rank = input->rankOf(); + + int bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIiH, indWoC, indWiC, + indWkH, indOoH; + dnnl::memory::dims strides, kernel, padding, padding_r, xDims, zDims; + dnnl::memory::format_tag xzFrmat; + + const auto type = dnnl::memory::data_type::f32; + + if (rank == 4) { // 2d + + ops::ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, 0, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, + indIiH, indWiC, indWoC, indWkH, indOoH); + + strides = {sH, sW}; + kernel = {kH, kW}; + padding = {pH, pW}; + padding_r = {(oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pW}; + xDims = {bS, iC, iH, iW}; + zDims = {bS, oC, oH, oW}; + + xzFrmat = isNCHW ? dnnl::memory::format_tag::nchw + : dnnl::memory::format_tag::nhwc; + } else { // 3d + + ops::ConvolutionUtils::getSizesAndIndexesConv3d( + isNCHW, 0, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, + indIOioC, indIiH, indWiC, indWoC, indWkH); + + strides = {sD, sH, sW}; + kernel = {kD, kH, kW}; + padding = {pD, pH, pW}; + padding_r = {(oD - 1) * sD - iD + kD - pD, (oH - 1) * sH - iH + kH - pH, + (oW - 1) * sW - iW + kW - pW}; + xDims = {bS, iC, iD, iH, iW}; + zDims = {bS, oC, oD, oH, oW}; + + xzFrmat = isNCHW ? dnnl::memory::format_tag::ncdhw + : dnnl::memory::format_tag::ndhwc; + } + + std::vector permut; if(!isNCHW) permut = rank == 4 ? std::vector({0,3,1,2}) : std::vector({0,4,1,2,3}); // memory descriptors for arrays - // input - dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, type, xzFrmat); - dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFrmat); - mkldnnUtils::setBlockStrides(*input, x_user_md, permut); + // input + dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, type, xzFrmat); + dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFrmat); + mkldnnUtils::setBlockStrides(*input, + x_user_md, permut); - // output - dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any); - dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, type, xzFrmat); - mkldnnUtils::setBlockStrides(*output, z_user_md, permut); + // output + dnnl::memory::desc z_mkl_md = + dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any); + dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, type, xzFrmat); + mkldnnUtils::setBlockStrides(*output, + z_user_md, permut); - auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); + auto engine = + mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - // operation primitive description - dnnl::pooling_forward::desc op_desc(dnnl::prop_kind::forward_inference, mode, x_mkl_md, z_mkl_md, strides, kernel, padding, padding_r); - dnnl::pooling_forward::primitive_desc op_prim_desc(op_desc, engine); + // operation primitive description + dnnl::pooling_forward::desc op_desc(dnnl::prop_kind::forward_inference, mode, + x_mkl_md, z_mkl_md, strides, kernel, + padding, padding_r); + dnnl::pooling_forward::primitive_desc op_prim_desc(op_desc, engine); - // arguments (memory buffers) necessary for calculations - std::unordered_map args; + // arguments (memory buffers) necessary for calculations + std::unordered_map args; - dnnl::stream stream(engine); + dnnl::stream stream(engine); - // provide memory buffers and check whether reorder is required + // provide memory buffers and check whether reorder is required - // input - mkldnnUtils::loadDataToMklStream(*input, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]); + // input + mkldnnUtils::loadDataToMklStream(*input, engine, stream, x_user_md, + op_prim_desc.src_desc(), args[DNNL_ARG_SRC]); - // output - auto z_user_mem = mkldnnUtils::loadDataToMklStream(*output, engine, stream, z_user_md, op_prim_desc.dst_desc(), args[DNNL_ARG_DST]); + // output + auto z_user_mem = mkldnnUtils::loadDataToMklStream(*output, engine, stream, z_user_md, op_prim_desc.dst_desc(), + args[DNNL_ARG_DST] ); - // run calculations - dnnl::pooling_forward(op_prim_desc).execute(stream, args); + // run calculations + dnnl::pooling_forward(op_prim_desc).execute(stream, args); - // reorder outputs if necessary - if (op_prim_desc.dst_desc() != z_user_mem.get_desc()) - dnnl::reorder(args[DNNL_ARG_DST], z_user_mem).execute(stream, args[DNNL_ARG_DST], z_user_mem); + // reorder outputs if necessary + if (op_prim_desc.dst_desc() != z_user_mem.get_desc()) + dnnl::reorder(args[DNNL_ARG_DST], z_user_mem).execute(stream, args[DNNL_ARG_DST], z_user_mem); - stream.wait(); + stream.wait(); } ////////////////////////////////////////////////////////////////////// -void poolingBpMKLDNN(const NDArray *input, const NDArray *gradO, NDArray *gradI, - const int kD, const int kH, const int kW, - const int sD, const int sH, const int sW, - const int pD, const int pH, const int pW, - const int isNCHW, const dnnl::algorithm mode) { - - // unfortunately mkl dnn doesn't support any format (dnnl::memory::format_tag::any) for input - - const int rank = input->rankOf(); - - int bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; - dnnl::memory::dims strides, kernel, padding, padding_r, xDims, zDims; - dnnl::memory::format_tag xzFrmat; - - const auto type = dnnl::memory::data_type::f32; - - if(rank == 4) { // 2d - - ops::ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, 0, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); - - strides = { sH, sW }; - kernel = { kH, kW }; - padding = { pH, pW }; - padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pW }; - xDims = {bS, iC, iH, iW}; - zDims = {bS, oC, oH, oW}; - - xzFrmat = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; - } - else { // 3d - - ops::ConvolutionUtils::getSizesAndIndexesConv3d(isNCHW, 0, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH); - - strides = { sD, sH, sW }; - kernel = { kD, kH, kW }; - padding = { pD, pH, pW }; - padding_r = { (oD - 1) * sD - iD + kD - pD, (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pW }; - xDims = {bS, iC, iD, iH, iW}; - zDims = {bS, oC, oD, oH, oW}; - - xzFrmat = isNCHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc; - } - - std::vector permut; +void poolingBpMKLDNN(const NDArray* input, const NDArray* gradO, NDArray* gradI, + const int kD, const int kH, const int kW, const int sD, + const int sH, const int sW, const int pD, const int pH, + const int pW, const int isNCHW, + const dnnl::algorithm mode) { + // unfortunately mkl dnn doesn't support any format + // (dnnl::memory::format_tag::any) for input + + const int rank = input->rankOf(); + + int bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIiH, indWoC, indWiC, + indWkH, indOoH; + dnnl::memory::dims strides, kernel, padding, padding_r, xDims, zDims; + dnnl::memory::format_tag xzFrmat; + + const auto type = dnnl::memory::data_type::f32; + + if (rank == 4) { // 2d + + ops::ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, 0, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, + indWiC, indWoC, indWkH, indOoH); + + strides = {sH, sW}; + kernel = {kH, kW}; + padding = {pH, pW}; + padding_r = {(oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pW}; + xDims = {bS, iC, iH, iW}; + zDims = {bS, oC, oH, oW}; + + xzFrmat = isNCHW ? dnnl::memory::format_tag::nchw + : dnnl::memory::format_tag::nhwc; + } else { // 3d + + ops::ConvolutionUtils::getSizesAndIndexesConv3d( + isNCHW, 0, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, + indIiH, indWiC, indWoC, indWkH); + + strides = {sD, sH, sW}; + kernel = {kD, kH, kW}; + padding = {pD, pH, pW}; + padding_r = {(oD - 1) * sD - iD + kD - pD, (oH - 1) * sH - iH + kH - pH, + (oW - 1) * sW - iW + kW - pW}; + xDims = {bS, iC, iD, iH, iW}; + zDims = {bS, oC, oD, oH, oW}; + + xzFrmat = isNCHW ? dnnl::memory::format_tag::ncdhw + : dnnl::memory::format_tag::ndhwc; + } + + std::vector permut; if(!isNCHW) permut = rank == 4 ? std::vector({0,3,1,2}) : std::vector({0,4,1,2,3}); // memory descriptors for arrays + // input + dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, type, xzFrmat); + dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFrmat); + mkldnnUtils::setBlockStrides(*input, + x_user_md, permut); + + // gradO + dnnl::memory::desc gradO_mkl_md = + dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any); + dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, type, xzFrmat); + mkldnnUtils::setBlockStrides(*gradO, + gradO_user_md, permut); + + // gradI + dnnl::memory::desc gradI_mkl_md = + dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any); + dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, type, xzFrmat); + mkldnnUtils::setBlockStrides(*gradI, + gradI_user_md, permut); + + auto engine = + mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); + dnnl::stream stream(engine); + + // forward primitive description + dnnl::pooling_forward::desc op_ff_desc(dnnl::prop_kind::forward, mode, + x_mkl_md, gradO_mkl_md, strides, + kernel, padding, padding_r); + dnnl::pooling_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine); + + // backward primitive description + dnnl::pooling_backward::desc op_bp_desc(mode, gradI_mkl_md, gradO_mkl_md, + strides, kernel, padding, padding_r); + dnnl::pooling_backward::primitive_desc op_bp_prim_desc(op_bp_desc, engine, + op_ff_prim_desc); + + // arguments (memory buffers) necessary for calculations + std::unordered_map args; + + // gradO + mkldnnUtils::loadDataToMklStream(*gradO, engine, stream, gradO_user_md, + op_bp_prim_desc.diff_dst_desc(), + args[DNNL_ARG_DIFF_DST]); + + // gradI + auto gradI_user_mem = mkldnnUtils::loadDataToMklStream(*gradI, engine, stream, gradI_user_md, + op_bp_prim_desc.diff_src_desc() , + args[DNNL_ARG_DIFF_SRC] ); + + if (mode == algorithm::pooling_max) { // input - dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, type, xzFrmat); - dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFrmat); - mkldnnUtils::setBlockStrides(*input, x_user_md, permut); - - // gradO - dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any); - dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, type, xzFrmat); - mkldnnUtils::setBlockStrides(*gradO, gradO_user_md, permut); - - // gradI - dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any); - dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, type, xzFrmat); - mkldnnUtils::setBlockStrides(*gradI, gradI_user_md, permut); - - auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - dnnl::stream stream(engine); - - // forward primitive description - dnnl::pooling_forward::desc op_ff_desc(dnnl::prop_kind::forward, mode, x_mkl_md, gradO_mkl_md, strides, kernel, padding, padding_r); - dnnl::pooling_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine); + mkldnnUtils::loadDataToMklStream(*input, engine, stream, x_user_md, + op_ff_prim_desc.src_desc(), + args[DNNL_ARG_SRC]); - // backward primitive description - dnnl::pooling_backward::desc op_bp_desc(mode, gradI_mkl_md, gradO_mkl_md, strides, kernel, padding, padding_r); - dnnl::pooling_backward::primitive_desc op_bp_prim_desc(op_bp_desc, engine, op_ff_prim_desc); + // z + auto z_mkl_mem = dnnl::memory(op_ff_prim_desc.dst_desc(), engine); + args[DNNL_ARG_DST] = z_mkl_mem; - // arguments (memory buffers) necessary for calculations - std::unordered_map args; + // auxiliary memory allocation + auto workspace = dnnl::memory(op_ff_prim_desc.workspace_desc(), engine); + args[DNNL_ARG_WORKSPACE] = workspace; - // gradO - mkldnnUtils::loadDataToMklStream(*gradO, engine, stream, gradO_user_md, op_bp_prim_desc.diff_dst_desc(), args[DNNL_ARG_DIFF_DST]); + // run forward calculations + dnnl::pooling_forward(op_ff_prim_desc).execute(stream, args); + } - // gradI - auto gradI_user_mem = mkldnnUtils::loadDataToMklStream(*gradI, engine, stream, gradI_user_md, op_bp_prim_desc.diff_src_desc(), args[DNNL_ARG_DIFF_SRC]); + // run backward calculations + dnnl::pooling_backward(op_bp_prim_desc).execute(stream, args); - if(mode == algorithm::pooling_max) { + // reorder gradI if necessary + if (op_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc()) + dnnl::reorder(args[DNNL_ARG_DIFF_SRC], gradI_user_mem) + .execute(stream, args[DNNL_ARG_DIFF_SRC], gradI_user_mem); - // input - mkldnnUtils::loadDataToMklStream(*input, engine, stream, x_user_md, op_ff_prim_desc.src_desc(), args[DNNL_ARG_SRC]); - - // z - auto z_mkl_mem = dnnl::memory(op_ff_prim_desc.dst_desc(), engine); - args[DNNL_ARG_DST] = z_mkl_mem; - - // auxiliary memory allocation - auto workspace = dnnl::memory(op_ff_prim_desc.workspace_desc(), engine); - args[DNNL_ARG_WORKSPACE] = workspace; - - // run forward calculations - dnnl::pooling_forward(op_ff_prim_desc).execute(stream, args); - } - - // run backward calculations - dnnl::pooling_backward(op_bp_prim_desc).execute(stream, args); - - // reorder gradI if necessary - if (op_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc()) - dnnl::reorder(args[DNNL_ARG_DIFF_SRC], gradI_user_mem).execute(stream, args[DNNL_ARG_DIFF_SRC], gradI_user_mem); - - stream.wait(); + stream.wait(); } ////////////////////////////////////////////////////////////////////////// -void getMKLDNNMemoryDescLrn(const NDArray* src, const NDArray* diff_src, const NDArray* dst, - dnnl::memory::desc* lrn_src_md, dnnl::memory::desc* lrn_diff_src_md, dnnl::memory::desc* lrn_dst_md, - dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md, int axis) { - const Nd4jLong* shape = src->shapeInfo(); - long rank = shape[0]; - long dim1 = axis; // MKL-DNN supports only 1 axis, which has to be the "channel" one - long dim2 = axis >= 2 ? 1 : 2; - long dim3 = axis >= 3 ? 2 : 3; - dnnl::memory::dims lrn_src_tz = { (int)shape[1], (int)shape[dim1 + 1], rank > 2 ? (int)shape[dim2 + 1] : 1, rank > 3 ? (int)shape[dim3 + 1] : 1}; - - auto type = dnnl::memory::data_type::f32; - auto format = axis == 1 ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; - auto supposed_to_be_any_format = format; // doesn't work with "any" - - if (src != nullptr && src->buffer() != nullptr && lrn_src_md != nullptr) { - *lrn_src_md = dnnl::memory::desc({ lrn_src_tz }, type, supposed_to_be_any_format); - *user_src_md = dnnl::memory::desc({ lrn_src_tz }, type, format); - user_src_md->data.format_kind = dnnl_blocked; - user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[0]; - user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[dim1]; - user_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? src->stridesOf()[dim2] : 1; - user_src_md->data.format_desc.blocking.strides[3] = rank > 3 ? src->stridesOf()[dim3] : 1; - } - - if (diff_src != nullptr && diff_src->buffer() != nullptr && lrn_diff_src_md != nullptr) { - *lrn_diff_src_md = dnnl::memory::desc({ lrn_src_tz }, type, supposed_to_be_any_format); - *user_diff_src_md = dnnl::memory::desc({ lrn_src_tz }, type, format); - user_diff_src_md->data.format_kind = dnnl_blocked; - user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[0]; - user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[dim1]; - user_diff_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? diff_src->stridesOf()[dim2] : 1; - user_diff_src_md->data.format_desc.blocking.strides[3] = rank > 3 ? diff_src->stridesOf()[dim3] : 1; - } - - if (dst != nullptr && dst->buffer() != nullptr && lrn_dst_md != nullptr) { - *lrn_dst_md = dnnl::memory::desc({ lrn_src_tz }, type, supposed_to_be_any_format); - *user_dst_md = dnnl::memory::desc({ lrn_src_tz }, type, format); - user_dst_md->data.format_kind = dnnl_blocked; - user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[0]; - user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[dim1]; - user_dst_md->data.format_desc.blocking.strides[2] = rank > 2 ? dst->stridesOf()[dim2] : 1; - user_dst_md->data.format_desc.blocking.strides[3] = rank > 3 ? dst->stridesOf()[dim3] : 1; - } +void getMKLDNNMemoryDescLrn(const NDArray* src, const NDArray* diff_src, + const NDArray* dst, dnnl::memory::desc* lrn_src_md, + dnnl::memory::desc* lrn_diff_src_md, + dnnl::memory::desc* lrn_dst_md, + dnnl::memory::desc* user_src_md, + dnnl::memory::desc* user_diff_src_md, + dnnl::memory::desc* user_dst_md, int axis) { + const Nd4jLong* shape = src->shapeInfo(); + long rank = shape[0]; + long dim1 = + axis; // MKL-DNN supports only 1 axis, which has to be the "channel" one + long dim2 = axis >= 2 ? 1 : 2; + long dim3 = axis >= 3 ? 2 : 3; + dnnl::memory::dims lrn_src_tz = {(int)shape[1], (int)shape[dim1 + 1], + rank > 2 ? (int)shape[dim2 + 1] : 1, + rank > 3 ? (int)shape[dim3 + 1] : 1}; + + auto type = dnnl::memory::data_type::f32; + auto format = axis == 1 ? dnnl::memory::format_tag::nchw + : dnnl::memory::format_tag::nhwc; + auto supposed_to_be_any_format = format; // doesn't work with "any" + + if (src != nullptr && src->buffer() != nullptr && lrn_src_md != nullptr) { + *lrn_src_md = + dnnl::memory::desc({lrn_src_tz}, type, supposed_to_be_any_format); + *user_src_md = dnnl::memory::desc({lrn_src_tz}, type, format); + user_src_md->data.format_kind = dnnl_blocked; + user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[0]; + user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[dim1]; + user_src_md->data.format_desc.blocking.strides[2] = + rank > 2 ? src->stridesOf()[dim2] : 1; + user_src_md->data.format_desc.blocking.strides[3] = + rank > 3 ? src->stridesOf()[dim3] : 1; + } + + if (diff_src != nullptr && diff_src->buffer() != nullptr && + lrn_diff_src_md != nullptr) { + *lrn_diff_src_md = + dnnl::memory::desc({lrn_src_tz}, type, supposed_to_be_any_format); + *user_diff_src_md = dnnl::memory::desc({lrn_src_tz}, type, format); + user_diff_src_md->data.format_kind = dnnl_blocked; + user_diff_src_md->data.format_desc.blocking.strides[0] = + diff_src->stridesOf()[0]; + user_diff_src_md->data.format_desc.blocking.strides[1] = + diff_src->stridesOf()[dim1]; + user_diff_src_md->data.format_desc.blocking.strides[2] = + rank > 2 ? diff_src->stridesOf()[dim2] : 1; + user_diff_src_md->data.format_desc.blocking.strides[3] = + rank > 3 ? diff_src->stridesOf()[dim3] : 1; + } + + if (dst != nullptr && dst->buffer() != nullptr && lrn_dst_md != nullptr) { + *lrn_dst_md = + dnnl::memory::desc({lrn_src_tz}, type, supposed_to_be_any_format); + *user_dst_md = dnnl::memory::desc({lrn_src_tz}, type, format); + user_dst_md->data.format_kind = dnnl_blocked; + user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[0]; + user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[dim1]; + user_dst_md->data.format_desc.blocking.strides[2] = + rank > 2 ? dst->stridesOf()[dim2] : 1; + user_dst_md->data.format_desc.blocking.strides[3] = + rank > 3 ? dst->stridesOf()[dim3] : 1; + } } ////////////////////////////////////////////////////////////////////////// -dnnl::engine& getEngine(void *ptr) { - auto eng = reinterpret_cast(ptr); - return *eng; +dnnl::engine& getEngine(void* ptr) { + auto eng = reinterpret_cast(ptr); + return *eng; } - /* ////////////////////////////////////////////////////////////////////////// void getMKLDNNMemoryDescPool2d( - int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, int poolingMode, int extraParam0, bool isNCHW, - int bS, int iC, int iH, int iW, int oC, int oH, int oW, - const NDArray* src, const NDArray* diff_src, const NDArray* dst, dnnl::algorithm& algorithm, - dnnl::memory::desc* pool_src_md, dnnl::memory::desc* pool_diff_src_md, dnnl::memory::desc* pool_dst_md, - dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md, - dnnl::memory::dims& pool_strides, dnnl::memory::dims& pool_kernel, dnnl::memory::dims& pool_padding, dnnl::memory::dims& pool_padding_r) { - dnnl::memory::dims pool_src_tz = { bS, iC, iH, iW }; - dnnl::memory::dims pool_dst_tz = { bS, oC, oH, oW }; + int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, int +poolingMode, int extraParam0, bool isNCHW, int bS, int iC, int iH, int iW, int +oC, int oH, int oW, const NDArray* src, const NDArray* diff_src, const NDArray* +dst, dnnl::algorithm& algorithm, dnnl::memory::desc* pool_src_md, +dnnl::memory::desc* pool_diff_src_md, dnnl::memory::desc* pool_dst_md, + dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, +dnnl::memory::desc* user_dst_md, dnnl::memory::dims& pool_strides, +dnnl::memory::dims& pool_kernel, dnnl::memory::dims& pool_padding, +dnnl::memory::dims& pool_padding_r) { dnnl::memory::dims pool_src_tz = { bS, iC, +iH, iW }; dnnl::memory::dims pool_dst_tz = { bS, oC, oH, oW }; pool_strides = { sH, sW }; pool_kernel = { kH, kW }; @@ -372,51 +433,70 @@ void getMKLDNNMemoryDescPool2d( (oW - 1) * sW - iW + kW - pW }; algorithm = poolingMode == 0 ? algorithm::pooling_max - : extraParam0 == 0 ? algorithm::pooling_avg_exclude_padding - : algorithm::pooling_avg_include_padding; + : extraParam0 == 0 ? +algorithm::pooling_avg_exclude_padding : algorithm::pooling_avg_include_padding; auto type = dnnl::memory::data_type::f32; - auto format = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; - auto supposed_to_be_any_format = dnnl::memory::format_tag::nChw8c; // doesn't work with "any" + auto format = isNCHW ? dnnl::memory::format_tag::nchw : +dnnl::memory::format_tag::nhwc; auto supposed_to_be_any_format = +dnnl::memory::format_tag::nChw8c; // doesn't work with "any" if (src != nullptr && src->buffer() != nullptr && pool_src_md != nullptr) { - *pool_src_md = dnnl::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format); - *user_src_md = dnnl::memory::desc({ pool_src_tz }, type, format); - user_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCHW ? nchw : nhwc" - user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[isNCHW ? 0 : 0]; - user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[isNCHW ? 1 : 3]; - user_src_md->data.format_desc.blocking.strides[2] = src->stridesOf()[isNCHW ? 2 : 1]; - user_src_md->data.format_desc.blocking.strides[3] = src->stridesOf()[isNCHW ? 3 : 2]; - } - - if (diff_src != nullptr && diff_src->buffer() != nullptr && pool_diff_src_md != nullptr) { - *pool_diff_src_md = dnnl::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format); - *user_diff_src_md = dnnl::memory::desc({ pool_src_tz }, type, format); - user_diff_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCHW ? nchw : nhwc" - user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[isNCHW ? 0 : 0]; - user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[isNCHW ? 1 : 3]; - user_diff_src_md->data.format_desc.blocking.strides[2] = diff_src->stridesOf()[isNCHW ? 2 : 1]; - user_diff_src_md->data.format_desc.blocking.strides[3] = diff_src->stridesOf()[isNCHW ? 3 : 2]; + *pool_src_md = dnnl::memory::desc({ pool_src_tz }, type, +supposed_to_be_any_format); *user_src_md = dnnl::memory::desc({ pool_src_tz }, +type, format); user_src_md->data.format_kind = dnnl_blocked; // overrides +"format = isNCHW ? nchw : nhwc" + user_src_md->data.format_desc.blocking.strides[0] = +src->stridesOf()[isNCHW ? 0 : 0]; + user_src_md->data.format_desc.blocking.strides[1] = +src->stridesOf()[isNCHW ? 1 : 3]; + user_src_md->data.format_desc.blocking.strides[2] = +src->stridesOf()[isNCHW ? 2 : 1]; + user_src_md->data.format_desc.blocking.strides[3] = +src->stridesOf()[isNCHW ? 3 : 2]; + } + + if (diff_src != nullptr && diff_src->buffer() != nullptr && pool_diff_src_md +!= nullptr) { *pool_diff_src_md = dnnl::memory::desc({ pool_src_tz }, type, +supposed_to_be_any_format); *user_diff_src_md = dnnl::memory::desc({ pool_src_tz +}, type, format); user_diff_src_md->data.format_kind = dnnl_blocked; // +overrides "format = isNCHW ? nchw : nhwc" + user_diff_src_md->data.format_desc.blocking.strides[0] = +diff_src->stridesOf()[isNCHW ? 0 : 0]; + user_diff_src_md->data.format_desc.blocking.strides[1] = +diff_src->stridesOf()[isNCHW ? 1 : 3]; + user_diff_src_md->data.format_desc.blocking.strides[2] = +diff_src->stridesOf()[isNCHW ? 2 : 1]; + user_diff_src_md->data.format_desc.blocking.strides[3] = +diff_src->stridesOf()[isNCHW ? 3 : 2]; } if (dst != nullptr && dst->buffer() != nullptr && pool_dst_md != nullptr) { - *pool_dst_md = dnnl::memory::desc({ pool_dst_tz }, type, supposed_to_be_any_format); - *user_dst_md = dnnl::memory::desc({ pool_dst_tz }, type, format); - user_dst_md->data.format_kind = dnnl_blocked; // overrides "format = isNCHW ? nchw : nhwc" - user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[isNCHW ? 0 : 0]; - user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[isNCHW ? 1 : 3]; - user_dst_md->data.format_desc.blocking.strides[2] = dst->stridesOf()[isNCHW ? 2 : 1]; - user_dst_md->data.format_desc.blocking.strides[3] = dst->stridesOf()[isNCHW ? 3 : 2]; + *pool_dst_md = dnnl::memory::desc({ pool_dst_tz }, type, +supposed_to_be_any_format); *user_dst_md = dnnl::memory::desc({ pool_dst_tz }, +type, format); user_dst_md->data.format_kind = dnnl_blocked; // overrides +"format = isNCHW ? nchw : nhwc" + user_dst_md->data.format_desc.blocking.strides[0] = +dst->stridesOf()[isNCHW ? 0 : 0]; + user_dst_md->data.format_desc.blocking.strides[1] = +dst->stridesOf()[isNCHW ? 1 : 3]; + user_dst_md->data.format_desc.blocking.strides[2] = +dst->stridesOf()[isNCHW ? 2 : 1]; + user_dst_md->data.format_desc.blocking.strides[3] = +dst->stridesOf()[isNCHW ? 3 : 2]; } }; ////////////////////////////////////////////////////////////////////////// void getMKLDNNMemoryDescPool3d( - int kD, int kH, int kW, int sD, int sH, int sW, int pD, int pH, int pW, int dD, int dH, int dW, int poolingMode, int extraParam0, bool isNCDHW, - int bS, int iC, int iD, int iH, int iW, int oC, int oD, int oH, int oW, - const NDArray* src, const NDArray* diff_src, const NDArray* dst, dnnl::algorithm& algorithm, - dnnl::memory::desc* pool_src_md, dnnl::memory::desc* pool_diff_src_md, dnnl::memory::desc* pool_dst_md, - dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md, - dnnl::memory::dims& pool_strides, dnnl::memory::dims& pool_kernel, dnnl::memory::dims& pool_padding, dnnl::memory::dims& pool_padding_r) { + int kD, int kH, int kW, int sD, int sH, int sW, int pD, int pH, int pW, +int dD, int dH, int dW, int poolingMode, int extraParam0, bool isNCDHW, int bS, +int iC, int iD, int iH, int iW, int oC, int oD, int oH, int oW, const NDArray* +src, const NDArray* diff_src, const NDArray* dst, dnnl::algorithm& algorithm, + dnnl::memory::desc* pool_src_md, dnnl::memory::desc* pool_diff_src_md, +dnnl::memory::desc* pool_dst_md, dnnl::memory::desc* user_src_md, +dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md, + dnnl::memory::dims& pool_strides, dnnl::memory::dims& pool_kernel, +dnnl::memory::dims& pool_padding, dnnl::memory::dims& pool_padding_r) { dnnl::memory::dims pool_src_tz = { bS, iC, iD, iH, iW }; dnnl::memory::dims pool_dst_tz = { bS, oC, oD, oH, oW }; @@ -428,258 +508,346 @@ void getMKLDNNMemoryDescPool3d( (oW - 1) * sW - iW + kW - pW }; algorithm = poolingMode == 0 ? algorithm::pooling_max - : extraParam0 == 0 ? algorithm::pooling_avg_exclude_padding - : algorithm::pooling_avg_include_padding; + : extraParam0 == 0 ? +algorithm::pooling_avg_exclude_padding : algorithm::pooling_avg_include_padding; auto type = dnnl::memory::data_type::f32; - auto format = isNCDHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc; - auto supposed_to_be_any_format = dnnl::memory::format_tag::nCdhw8c; // doesn't work with "any" + auto format = isNCDHW ? dnnl::memory::format_tag::ncdhw : +dnnl::memory::format_tag::ndhwc; auto supposed_to_be_any_format = +dnnl::memory::format_tag::nCdhw8c; // doesn't work with "any" if (src != nullptr && src->buffer() != nullptr && pool_src_md != nullptr) { - *pool_src_md = dnnl::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format); - *user_src_md = dnnl::memory::desc({ pool_src_tz }, type, format); - user_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc" - user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[isNCDHW ? 0 : 0]; - user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[isNCDHW ? 1 : 4]; - user_src_md->data.format_desc.blocking.strides[2] = src->stridesOf()[isNCDHW ? 2 : 1]; - user_src_md->data.format_desc.blocking.strides[3] = src->stridesOf()[isNCDHW ? 3 : 2]; - user_src_md->data.format_desc.blocking.strides[4] = src->stridesOf()[isNCDHW ? 4 : 3]; - } - - if (diff_src != nullptr && diff_src->buffer() != nullptr && pool_diff_src_md != nullptr) { - *pool_diff_src_md = dnnl::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format); - *user_diff_src_md = dnnl::memory::desc({ pool_src_tz }, type, format); - user_diff_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc" - user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[isNCDHW ? 0 : 0]; - user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[isNCDHW ? 1 : 4]; - user_diff_src_md->data.format_desc.blocking.strides[2] = diff_src->stridesOf()[isNCDHW ? 2 : 1]; - user_diff_src_md->data.format_desc.blocking.strides[3] = diff_src->stridesOf()[isNCDHW ? 3 : 2]; - user_diff_src_md->data.format_desc.blocking.strides[4] = diff_src->stridesOf()[isNCDHW ? 4 : 3]; + *pool_src_md = dnnl::memory::desc({ pool_src_tz }, type, +supposed_to_be_any_format); *user_src_md = dnnl::memory::desc({ pool_src_tz }, +type, format); user_src_md->data.format_kind = dnnl_blocked; // overrides +"format = isNCDHW ? ncdhw : ndhwc" + user_src_md->data.format_desc.blocking.strides[0] = +src->stridesOf()[isNCDHW ? 0 : 0]; + user_src_md->data.format_desc.blocking.strides[1] = +src->stridesOf()[isNCDHW ? 1 : 4]; + user_src_md->data.format_desc.blocking.strides[2] = +src->stridesOf()[isNCDHW ? 2 : 1]; + user_src_md->data.format_desc.blocking.strides[3] = +src->stridesOf()[isNCDHW ? 3 : 2]; + user_src_md->data.format_desc.blocking.strides[4] = +src->stridesOf()[isNCDHW ? 4 : 3]; + } + + if (diff_src != nullptr && diff_src->buffer() != nullptr && pool_diff_src_md +!= nullptr) { *pool_diff_src_md = dnnl::memory::desc({ pool_src_tz }, type, +supposed_to_be_any_format); *user_diff_src_md = dnnl::memory::desc({ pool_src_tz +}, type, format); user_diff_src_md->data.format_kind = dnnl_blocked; // +overrides "format = isNCDHW ? ncdhw : ndhwc" + user_diff_src_md->data.format_desc.blocking.strides[0] = +diff_src->stridesOf()[isNCDHW ? 0 : 0]; + user_diff_src_md->data.format_desc.blocking.strides[1] = +diff_src->stridesOf()[isNCDHW ? 1 : 4]; + user_diff_src_md->data.format_desc.blocking.strides[2] = +diff_src->stridesOf()[isNCDHW ? 2 : 1]; + user_diff_src_md->data.format_desc.blocking.strides[3] = +diff_src->stridesOf()[isNCDHW ? 3 : 2]; + user_diff_src_md->data.format_desc.blocking.strides[4] = +diff_src->stridesOf()[isNCDHW ? 4 : 3]; } if (dst != nullptr && dst->buffer() != nullptr && pool_dst_md != nullptr) { - *pool_dst_md = dnnl::memory::desc({ pool_dst_tz }, type, supposed_to_be_any_format); - *user_dst_md = dnnl::memory::desc({ pool_dst_tz }, type, format); - user_dst_md->data.format_kind = dnnl_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc" - user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[isNCDHW ? 0 : 0]; - user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[isNCDHW ? 1 : 4]; - user_dst_md->data.format_desc.blocking.strides[2] = dst->stridesOf()[isNCDHW ? 2 : 1]; - user_dst_md->data.format_desc.blocking.strides[3] = dst->stridesOf()[isNCDHW ? 3 : 2]; - user_dst_md->data.format_desc.blocking.strides[4] = dst->stridesOf()[isNCDHW ? 4 : 3]; + *pool_dst_md = dnnl::memory::desc({ pool_dst_tz }, type, +supposed_to_be_any_format); *user_dst_md = dnnl::memory::desc({ pool_dst_tz }, +type, format); user_dst_md->data.format_kind = dnnl_blocked; // overrides +"format = isNCDHW ? ncdhw : ndhwc" + user_dst_md->data.format_desc.blocking.strides[0] = +dst->stridesOf()[isNCDHW ? 0 : 0]; + user_dst_md->data.format_desc.blocking.strides[1] = +dst->stridesOf()[isNCDHW ? 1 : 4]; + user_dst_md->data.format_desc.blocking.strides[2] = +dst->stridesOf()[isNCDHW ? 2 : 1]; + user_dst_md->data.format_desc.blocking.strides[3] = +dst->stridesOf()[isNCDHW ? 3 : 2]; + user_dst_md->data.format_desc.blocking.strides[4] = +dst->stridesOf()[isNCDHW ? 4 : 3]; } }; ////////////////////////////////////////////////////////////////////////// void getMKLDNNMemoryDescConv2d( - int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, const int paddingMode, bool isNCHW, - int bS, int iC, int iH, int iW, int oC, int oH, int oW, const NDArray* src, const NDArray* diff_src, - const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst, - dnnl::memory::desc* conv_src_md, dnnl::memory::desc* conv_diff_src_md, dnnl::memory::desc* conv_weights_md, - dnnl::memory::desc* conv_diff_weights_md, dnnl::memory::desc* conv_bias_md, dnnl::memory::desc* conv_dst_md, - dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_weights_md, - dnnl::memory::desc* user_diff_weights_md, dnnl::memory::desc* user_bias_md, dnnl::memory::desc* user_dst_md, - dnnl::memory::dims& conv_strides, dnnl::memory::dims& conv_padding, dnnl::memory::dims& conv_padding_r, dnnl::memory::dims& conv_dilation) { + int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, const +int paddingMode, bool isNCHW, int bS, int iC, int iH, int iW, int oC, int oH, +int oW, const NDArray* src, const NDArray* diff_src, const NDArray* weights, +const NDArray* diff_weights, const NDArray* bias, const NDArray* dst, + dnnl::memory::desc* conv_src_md, dnnl::memory::desc* conv_diff_src_md, +dnnl::memory::desc* conv_weights_md, dnnl::memory::desc* conv_diff_weights_md, +dnnl::memory::desc* conv_bias_md, dnnl::memory::desc* conv_dst_md, + dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, +dnnl::memory::desc* user_weights_md, dnnl::memory::desc* user_diff_weights_md, +dnnl::memory::desc* user_bias_md, dnnl::memory::desc* user_dst_md, + dnnl::memory::dims& conv_strides, dnnl::memory::dims& conv_padding, +dnnl::memory::dims& conv_padding_r, dnnl::memory::dims& conv_dilation) { dnnl::memory::dims conv_src_tz = { bS, iC, iH, iW }; dnnl::memory::dims conv_weights_tz = { oC, iC, kH, kW }; dnnl::memory::dims conv_bias_tz = { oC }; dnnl::memory::dims conv_dst_tz = { bS, oC, oH, oW }; - const int pWSame = (paddingMode == 2 && dW > 1) ? ((oW - 1) * sW + (kW - 1) * dW + 1 - iW) / 2 : pW; // dH == 1 for causal mode in conv1d + const int pWSame = (paddingMode == 2 && dW > 1) ? ((oW - 1) * sW + (kW - 1) +* dW + 1 - iW) / 2 : pW; // dH == 1 for causal mode in conv1d conv_strides = { sH, sW }; conv_padding = { pH, pW }; - conv_padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pWSame }; - conv_dilation = { dH-1, dW-1}; + conv_padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - +pWSame }; conv_dilation = { dH-1, dW-1}; auto type = dnnl::memory::data_type::f32; - auto format = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; - auto formatw = dnnl::memory::format_tag::hwio; + auto format = isNCHW ? dnnl::memory::format_tag::nchw : +dnnl::memory::format_tag::nhwc; auto formatw = dnnl::memory::format_tag::hwio; if (src != nullptr && conv_src_md != nullptr) { - *conv_src_md = dnnl::memory::desc({ conv_src_tz }, type, dnnl::memory::format_tag::any); - *user_src_md = dnnl::memory::desc({ conv_src_tz }, type, format); - user_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCHW ? nchw : nhwc" - user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[isNCHW ? 0 : 0]; - user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[isNCHW ? 1 : 3]; - user_src_md->data.format_desc.blocking.strides[2] = src->stridesOf()[isNCHW ? 2 : 1]; - user_src_md->data.format_desc.blocking.strides[3] = src->stridesOf()[isNCHW ? 3 : 2]; + *conv_src_md = dnnl::memory::desc({ conv_src_tz }, type, +dnnl::memory::format_tag::any); *user_src_md = dnnl::memory::desc({ conv_src_tz +}, type, format); user_src_md->data.format_kind = dnnl_blocked; // overrides +"format = isNCHW ? nchw : nhwc" + user_src_md->data.format_desc.blocking.strides[0] = +src->stridesOf()[isNCHW ? 0 : 0]; + user_src_md->data.format_desc.blocking.strides[1] = +src->stridesOf()[isNCHW ? 1 : 3]; + user_src_md->data.format_desc.blocking.strides[2] = +src->stridesOf()[isNCHW ? 2 : 1]; + user_src_md->data.format_desc.blocking.strides[3] = +src->stridesOf()[isNCHW ? 3 : 2]; } if (diff_src != nullptr && conv_diff_src_md != nullptr) { - *conv_diff_src_md = dnnl::memory::desc({ conv_src_tz }, type, dnnl::memory::format_tag::any); - *user_diff_src_md = dnnl::memory::desc({ conv_src_tz }, type, format); - user_diff_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCHW ? nchw : nhwc" - user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[isNCHW ? 0 : 0]; - user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[isNCHW ? 1 : 3]; - user_diff_src_md->data.format_desc.blocking.strides[2] = diff_src->stridesOf()[isNCHW ? 2 : 1]; - user_diff_src_md->data.format_desc.blocking.strides[3] = diff_src->stridesOf()[isNCHW ? 3 : 2]; + *conv_diff_src_md = dnnl::memory::desc({ conv_src_tz }, type, +dnnl::memory::format_tag::any); *user_diff_src_md = dnnl::memory::desc({ +conv_src_tz }, type, format); user_diff_src_md->data.format_kind = dnnl_blocked; +// overrides "format = isNCHW ? nchw : nhwc" + user_diff_src_md->data.format_desc.blocking.strides[0] = +diff_src->stridesOf()[isNCHW ? 0 : 0]; + user_diff_src_md->data.format_desc.blocking.strides[1] = +diff_src->stridesOf()[isNCHW ? 1 : 3]; + user_diff_src_md->data.format_desc.blocking.strides[2] = +diff_src->stridesOf()[isNCHW ? 2 : 1]; + user_diff_src_md->data.format_desc.blocking.strides[3] = +diff_src->stridesOf()[isNCHW ? 3 : 2]; } if (weights != nullptr && conv_weights_md != nullptr) { - *conv_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, dnnl::memory::format_tag::any); - *user_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, formatw); - user_weights_md->data.format_kind = dnnl_blocked; // overrides "formatw = hwio" - user_weights_md->data.format_desc.blocking.strides[0] = weights->stridesOf()[3]; - user_weights_md->data.format_desc.blocking.strides[1] = weights->stridesOf()[2]; - user_weights_md->data.format_desc.blocking.strides[2] = weights->stridesOf()[0]; - user_weights_md->data.format_desc.blocking.strides[3] = weights->stridesOf()[1]; + *conv_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, +dnnl::memory::format_tag::any); *user_weights_md = dnnl::memory::desc({ +conv_weights_tz }, type, formatw); user_weights_md->data.format_kind = +dnnl_blocked; // overrides "formatw = hwio" + user_weights_md->data.format_desc.blocking.strides[0] = +weights->stridesOf()[3]; user_weights_md->data.format_desc.blocking.strides[1] = +weights->stridesOf()[2]; user_weights_md->data.format_desc.blocking.strides[2] = +weights->stridesOf()[0]; user_weights_md->data.format_desc.blocking.strides[3] = +weights->stridesOf()[1]; } if (diff_weights != nullptr && conv_diff_weights_md != nullptr) { - *conv_diff_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, dnnl::memory::format_tag::any); - *user_diff_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, formatw); - user_diff_weights_md->data.format_kind = dnnl_blocked; // overrides "formatw = hwio" - user_diff_weights_md->data.format_desc.blocking.strides[0] = diff_weights->stridesOf()[3]; - user_diff_weights_md->data.format_desc.blocking.strides[1] = diff_weights->stridesOf()[2]; - user_diff_weights_md->data.format_desc.blocking.strides[2] = diff_weights->stridesOf()[0]; - user_diff_weights_md->data.format_desc.blocking.strides[3] = diff_weights->stridesOf()[1]; + *conv_diff_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, +dnnl::memory::format_tag::any); *user_diff_weights_md = dnnl::memory::desc({ +conv_weights_tz }, type, formatw); user_diff_weights_md->data.format_kind = +dnnl_blocked; // overrides "formatw = hwio" + user_diff_weights_md->data.format_desc.blocking.strides[0] = +diff_weights->stridesOf()[3]; + user_diff_weights_md->data.format_desc.blocking.strides[1] = +diff_weights->stridesOf()[2]; + user_diff_weights_md->data.format_desc.blocking.strides[2] = +diff_weights->stridesOf()[0]; + user_diff_weights_md->data.format_desc.blocking.strides[3] = +diff_weights->stridesOf()[1]; } if (bias != nullptr && conv_bias_md != nullptr) { - *conv_bias_md = dnnl::memory::desc({ conv_bias_tz }, type, dnnl::memory::format_tag::any); - *user_bias_md = dnnl::memory::desc({ conv_bias_tz }, type, dnnl::memory::format_tag::x); + *conv_bias_md = dnnl::memory::desc({ conv_bias_tz }, type, +dnnl::memory::format_tag::any); *user_bias_md = dnnl::memory::desc({ +conv_bias_tz }, type, dnnl::memory::format_tag::x); } if (dst != nullptr && conv_dst_md != nullptr) { - *conv_dst_md = dnnl::memory::desc({ conv_dst_tz }, type, dnnl::memory::format_tag::any); - *user_dst_md = dnnl::memory::desc({ conv_dst_tz }, type, format); - user_dst_md->data.format_kind = dnnl_blocked; // overrides "format = isNCHW ? nchw : nhwc" - user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[isNCHW ? 0 : 0]; - user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[isNCHW ? 1 : 3]; - user_dst_md->data.format_desc.blocking.strides[2] = dst->stridesOf()[isNCHW ? 2 : 1]; - user_dst_md->data.format_desc.blocking.strides[3] = dst->stridesOf()[isNCHW ? 3 : 2]; + *conv_dst_md = dnnl::memory::desc({ conv_dst_tz }, type, +dnnl::memory::format_tag::any); *user_dst_md = dnnl::memory::desc({ conv_dst_tz +}, type, format); user_dst_md->data.format_kind = dnnl_blocked; // overrides +"format = isNCHW ? nchw : nhwc" + user_dst_md->data.format_desc.blocking.strides[0] = +dst->stridesOf()[isNCHW ? 0 : 0]; + user_dst_md->data.format_desc.blocking.strides[1] = +dst->stridesOf()[isNCHW ? 1 : 3]; + user_dst_md->data.format_desc.blocking.strides[2] = +dst->stridesOf()[isNCHW ? 2 : 1]; + user_dst_md->data.format_desc.blocking.strides[3] = +dst->stridesOf()[isNCHW ? 3 : 2]; } } ////////////////////////////////////////////////////////////////////////// void getMKLDNNMemoryDescConv3d( - int kD, int kH, int kW, int sD, int sH, int sW, int pD, int pH, int pW, int dD, int dH, int dW, bool paddingMode, bool isNCDHW, - int bS, int iC, int iD, int iH, int iW, int oC, int oD, int oH, int oW, const NDArray* src, const NDArray* diff_src, - const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst, - dnnl::memory::desc* conv_src_md, dnnl::memory::desc* conv_diff_src_md, dnnl::memory::desc* conv_weights_md, - dnnl::memory::desc* conv_diff_weights_md, dnnl::memory::desc* conv_bias_md, dnnl::memory::desc* conv_dst_md, - dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_weights_md, - dnnl::memory::desc* user_diff_weights_md, dnnl::memory::desc* user_bias_md, dnnl::memory::desc* user_dst_md, - dnnl::memory::dims& conv_strides, dnnl::memory::dims& conv_padding, dnnl::memory::dims& conv_padding_r, dnnl::memory::dims& conv_dilation) { - dnnl::memory::dims conv_src_tz = { bS, iC, iD, iH, iW }; - dnnl::memory::dims conv_weights_tz = { oC, iC, kD, kH, kW }; + int kD, int kH, int kW, int sD, int sH, int sW, int pD, int pH, int pW, +int dD, int dH, int dW, bool paddingMode, bool isNCDHW, int bS, int iC, int iD, +int iH, int iW, int oC, int oD, int oH, int oW, const NDArray* src, const +NDArray* diff_src, const NDArray* weights, const NDArray* diff_weights, const +NDArray* bias, const NDArray* dst, dnnl::memory::desc* conv_src_md, +dnnl::memory::desc* conv_diff_src_md, dnnl::memory::desc* conv_weights_md, + dnnl::memory::desc* conv_diff_weights_md, dnnl::memory::desc* +conv_bias_md, dnnl::memory::desc* conv_dst_md, dnnl::memory::desc* user_src_md, +dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_weights_md, + dnnl::memory::desc* user_diff_weights_md, dnnl::memory::desc* +user_bias_md, dnnl::memory::desc* user_dst_md, dnnl::memory::dims& conv_strides, +dnnl::memory::dims& conv_padding, dnnl::memory::dims& conv_padding_r, +dnnl::memory::dims& conv_dilation) { dnnl::memory::dims conv_src_tz = { bS, iC, +iD, iH, iW }; dnnl::memory::dims conv_weights_tz = { oC, iC, kD, kH, kW }; dnnl::memory::dims conv_bias_tz = { oC }; dnnl::memory::dims conv_dst_tz = { bS, oC, oD, oH, oW }; conv_strides = { sD, sH, sW }; conv_padding = { pD, pH, pW }; - conv_padding_r = { (oD - 1) * sD - iD + kD - pD, (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pW }; - conv_dilation = { dD-1, dH-1, dW-1}; + conv_padding_r = { (oD - 1) * sD - iD + kD - pD, (oH - 1) * sH - iH + kH - +pH, (oW - 1) * sW - iW + kW - pW }; conv_dilation = { dD-1, dH-1, dW-1}; auto type = dnnl::memory::data_type::f32; - auto format = isNCDHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc; - auto formatw = dnnl::memory::format_tag::dhwio; + auto format = isNCDHW ? dnnl::memory::format_tag::ncdhw : +dnnl::memory::format_tag::ndhwc; auto formatw = dnnl::memory::format_tag::dhwio; if (src != nullptr && conv_src_md != nullptr) { - *conv_src_md = dnnl::memory::desc({ conv_src_tz }, type, dnnl::memory::format_tag::any); - *user_src_md = dnnl::memory::desc({ conv_src_tz }, type, format); - user_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc" - user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[isNCDHW ? 0 : 0]; - user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[isNCDHW ? 1 : 4]; - user_src_md->data.format_desc.blocking.strides[2] = src->stridesOf()[isNCDHW ? 2 : 1]; - user_src_md->data.format_desc.blocking.strides[3] = src->stridesOf()[isNCDHW ? 3 : 2]; - user_src_md->data.format_desc.blocking.strides[4] = src->stridesOf()[isNCDHW ? 4 : 3]; + *conv_src_md = dnnl::memory::desc({ conv_src_tz }, type, +dnnl::memory::format_tag::any); *user_src_md = dnnl::memory::desc({ conv_src_tz +}, type, format); user_src_md->data.format_kind = dnnl_blocked; // overrides +"format = isNCDHW ? ncdhw : ndhwc" + user_src_md->data.format_desc.blocking.strides[0] = +src->stridesOf()[isNCDHW ? 0 : 0]; + user_src_md->data.format_desc.blocking.strides[1] = +src->stridesOf()[isNCDHW ? 1 : 4]; + user_src_md->data.format_desc.blocking.strides[2] = +src->stridesOf()[isNCDHW ? 2 : 1]; + user_src_md->data.format_desc.blocking.strides[3] = +src->stridesOf()[isNCDHW ? 3 : 2]; + user_src_md->data.format_desc.blocking.strides[4] = +src->stridesOf()[isNCDHW ? 4 : 3]; } if (diff_src != nullptr && conv_diff_src_md != nullptr) { - *conv_diff_src_md = dnnl::memory::desc({ conv_src_tz }, type, dnnl::memory::format_tag::any); - *user_diff_src_md = dnnl::memory::desc({ conv_src_tz }, type, format); - user_diff_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc" - user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[isNCDHW ? 0 : 0]; - user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[isNCDHW ? 1 : 4]; - user_diff_src_md->data.format_desc.blocking.strides[2] = diff_src->stridesOf()[isNCDHW ? 2 : 1]; - user_diff_src_md->data.format_desc.blocking.strides[3] = diff_src->stridesOf()[isNCDHW ? 3 : 2]; - user_diff_src_md->data.format_desc.blocking.strides[4] = diff_src->stridesOf()[isNCDHW ? 4 : 3]; + *conv_diff_src_md = dnnl::memory::desc({ conv_src_tz }, type, +dnnl::memory::format_tag::any); *user_diff_src_md = dnnl::memory::desc({ +conv_src_tz }, type, format); user_diff_src_md->data.format_kind = dnnl_blocked; +// overrides "format = isNCDHW ? ncdhw : ndhwc" + user_diff_src_md->data.format_desc.blocking.strides[0] = +diff_src->stridesOf()[isNCDHW ? 0 : 0]; + user_diff_src_md->data.format_desc.blocking.strides[1] = +diff_src->stridesOf()[isNCDHW ? 1 : 4]; + user_diff_src_md->data.format_desc.blocking.strides[2] = +diff_src->stridesOf()[isNCDHW ? 2 : 1]; + user_diff_src_md->data.format_desc.blocking.strides[3] = +diff_src->stridesOf()[isNCDHW ? 3 : 2]; + user_diff_src_md->data.format_desc.blocking.strides[4] = +diff_src->stridesOf()[isNCDHW ? 4 : 3]; } if (weights != nullptr && conv_weights_md != nullptr) { - *conv_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, dnnl::memory::format_tag::any); - *user_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, formatw); - user_weights_md->data.format_kind = dnnl_blocked; // overrides "formatw = dhwio" - user_weights_md->data.format_desc.blocking.strides[0] = weights->stridesOf()[4]; - user_weights_md->data.format_desc.blocking.strides[1] = weights->stridesOf()[3]; - user_weights_md->data.format_desc.blocking.strides[2] = weights->stridesOf()[0]; - user_weights_md->data.format_desc.blocking.strides[3] = weights->stridesOf()[1]; - user_weights_md->data.format_desc.blocking.strides[4] = weights->stridesOf()[2]; + *conv_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, +dnnl::memory::format_tag::any); *user_weights_md = dnnl::memory::desc({ +conv_weights_tz }, type, formatw); user_weights_md->data.format_kind = +dnnl_blocked; // overrides "formatw = dhwio" + user_weights_md->data.format_desc.blocking.strides[0] = +weights->stridesOf()[4]; user_weights_md->data.format_desc.blocking.strides[1] = +weights->stridesOf()[3]; user_weights_md->data.format_desc.blocking.strides[2] = +weights->stridesOf()[0]; user_weights_md->data.format_desc.blocking.strides[3] = +weights->stridesOf()[1]; user_weights_md->data.format_desc.blocking.strides[4] = +weights->stridesOf()[2]; } if (diff_weights != nullptr && conv_diff_weights_md != nullptr) { - *conv_diff_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, dnnl::memory::format_tag::any); - *user_diff_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, formatw); - user_diff_weights_md->data.format_kind = dnnl_blocked; // overrides "formatw = dhwio" - user_diff_weights_md->data.format_desc.blocking.strides[0] = diff_weights->stridesOf()[4]; - user_diff_weights_md->data.format_desc.blocking.strides[1] = diff_weights->stridesOf()[3]; - user_diff_weights_md->data.format_desc.blocking.strides[2] = diff_weights->stridesOf()[0]; - user_diff_weights_md->data.format_desc.blocking.strides[3] = diff_weights->stridesOf()[1]; - user_diff_weights_md->data.format_desc.blocking.strides[4] = diff_weights->stridesOf()[2]; + *conv_diff_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, +dnnl::memory::format_tag::any); *user_diff_weights_md = dnnl::memory::desc({ +conv_weights_tz }, type, formatw); user_diff_weights_md->data.format_kind = +dnnl_blocked; // overrides "formatw = dhwio" + user_diff_weights_md->data.format_desc.blocking.strides[0] = +diff_weights->stridesOf()[4]; + user_diff_weights_md->data.format_desc.blocking.strides[1] = +diff_weights->stridesOf()[3]; + user_diff_weights_md->data.format_desc.blocking.strides[2] = +diff_weights->stridesOf()[0]; + user_diff_weights_md->data.format_desc.blocking.strides[3] = +diff_weights->stridesOf()[1]; + user_diff_weights_md->data.format_desc.blocking.strides[4] = +diff_weights->stridesOf()[2]; } if (bias != nullptr && conv_bias_md != nullptr) { - *conv_bias_md = dnnl::memory::desc({ conv_bias_tz }, type, dnnl::memory::format_tag::any); - *user_bias_md = dnnl::memory::desc({ conv_bias_tz }, type, dnnl::memory::format_tag::x); + *conv_bias_md = dnnl::memory::desc({ conv_bias_tz }, type, +dnnl::memory::format_tag::any); *user_bias_md = dnnl::memory::desc({ +conv_bias_tz }, type, dnnl::memory::format_tag::x); } if (dst != nullptr && conv_dst_md != nullptr) { - *conv_dst_md = dnnl::memory::desc({ conv_dst_tz }, type, dnnl::memory::format_tag::any); - *user_dst_md = dnnl::memory::desc({ conv_dst_tz }, type, format); - user_dst_md->data.format_kind = dnnl_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc" - user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[isNCDHW ? 0 : 0]; - user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[isNCDHW ? 1 : 4]; - user_dst_md->data.format_desc.blocking.strides[2] = dst->stridesOf()[isNCDHW ? 2 : 1]; - user_dst_md->data.format_desc.blocking.strides[3] = dst->stridesOf()[isNCDHW ? 3 : 2]; - user_dst_md->data.format_desc.blocking.strides[4] = dst->stridesOf()[isNCDHW ? 4 : 3]; + *conv_dst_md = dnnl::memory::desc({ conv_dst_tz }, type, +dnnl::memory::format_tag::any); *user_dst_md = dnnl::memory::desc({ conv_dst_tz +}, type, format); user_dst_md->data.format_kind = dnnl_blocked; // overrides +"format = isNCDHW ? ncdhw : ndhwc" + user_dst_md->data.format_desc.blocking.strides[0] = +dst->stridesOf()[isNCDHW ? 0 : 0]; + user_dst_md->data.format_desc.blocking.strides[1] = +dst->stridesOf()[isNCDHW ? 1 : 4]; + user_dst_md->data.format_desc.blocking.strides[2] = +dst->stridesOf()[isNCDHW ? 2 : 1]; + user_dst_md->data.format_desc.blocking.strides[3] = +dst->stridesOf()[isNCDHW ? 3 : 2]; + user_dst_md->data.format_desc.blocking.strides[4] = +dst->stridesOf()[isNCDHW ? 4 : 3]; } }; -void getMKLDNNMemoryDescBatchNorm(const NDArray* src, const NDArray* diff_src, const NDArray* dst, - dnnl::memory::desc* batchnorm_src_md, dnnl::memory::desc* batchnorm_diff_src_md, dnnl::memory::desc* batchnorm_dst_md, - dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md, int axis) { - const Nd4jLong* shape = src->shapeInfo(); - Nd4jLong rank = shape[0]; - Nd4jLong dim1 = axis; // MKL-DNN supports only 1 axis, which has to be the "channel" one - Nd4jLong dim2 = axis >= 2 ? 1 : 2; - Nd4jLong dim3 = axis >= 3 ? 2 : 3; - dnnl::memory::dims batchnorm_src_tz = { (int)shape[1], (int)shape[dim1 + 1], rank > 2 ? (int)shape[dim2 + 1] : 1, rank > 3 ? (int)shape[dim3 + 1] : 1}; +void getMKLDNNMemoryDescBatchNorm(const NDArray* src, const NDArray* diff_src, +const NDArray* dst, dnnl::memory::desc* batchnorm_src_md, dnnl::memory::desc* +batchnorm_diff_src_md, dnnl::memory::desc* batchnorm_dst_md, dnnl::memory::desc* +user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* +user_dst_md, int axis) { const Nd4jLong* shape = src->shapeInfo(); Nd4jLong rank += shape[0]; Nd4jLong dim1 = axis; // MKL-DNN supports only 1 axis, which has to +be the "channel" one Nd4jLong dim2 = axis >= 2 ? 1 : 2; Nd4jLong dim3 = axis >= +3 ? 2 : 3; dnnl::memory::dims batchnorm_src_tz = { (int)shape[1], +(int)shape[dim1 + 1], rank > 2 ? (int)shape[dim2 + 1] : 1, rank > 3 ? +(int)shape[dim3 + 1] : 1}; auto type = dnnl::memory::data_type::f32; auto format = dnnl::memory::format_tag::nchw; - auto supposed_to_be_any_format = dnnl::memory::format_tag::nChw8c; // doesn't work with "any" - - if (src != nullptr && src->buffer() != nullptr && batchnorm_src_md != nullptr) { - *batchnorm_src_md = dnnl::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format); - *user_src_md = dnnl::memory::desc({ batchnorm_src_tz }, type, format); - user_src_md->data.format_kind = dnnl_blocked; // overrides format - user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[0]; - user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[dim1]; - user_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? src->stridesOf()[dim2] : 1; - user_src_md->data.format_desc.blocking.strides[3] = rank > 3 ? src->stridesOf()[dim3] : 1; - } - - if (diff_src != nullptr && diff_src->buffer() != nullptr && batchnorm_diff_src_md != nullptr) { - *batchnorm_diff_src_md = dnnl::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format); - *user_diff_src_md = dnnl::memory::desc({ batchnorm_src_tz }, type, format); - user_diff_src_md->data.format_kind = dnnl_blocked; // overrides format - user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[0]; - user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[dim1]; - user_diff_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? diff_src->stridesOf()[dim2] : 1; - user_diff_src_md->data.format_desc.blocking.strides[3] = rank > 3 ? diff_src->stridesOf()[dim3] : 1; - } - - if (dst != nullptr && dst->buffer() != nullptr && batchnorm_dst_md != nullptr) { - *batchnorm_dst_md = dnnl::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format); - *user_dst_md = dnnl::memory::desc({ batchnorm_src_tz }, type, format); - user_dst_md->data.format_kind = dnnl_blocked; // overrides format - user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[0]; - user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[dim1]; - user_dst_md->data.format_desc.blocking.strides[2] = rank > 2 ? dst->stridesOf()[dim2] : 1; - user_dst_md->data.format_desc.blocking.strides[3] = rank > 3 ? dst->stridesOf()[dim3] : 1; + auto supposed_to_be_any_format = dnnl::memory::format_tag::nChw8c; // +doesn't work with "any" + + if (src != nullptr && src->buffer() != nullptr && batchnorm_src_md != +nullptr) { *batchnorm_src_md = dnnl::memory::desc({ batchnorm_src_tz }, type, +supposed_to_be_any_format); *user_src_md = dnnl::memory::desc({ batchnorm_src_tz +}, type, format); user_src_md->data.format_kind = dnnl_blocked; // overrides +format user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[0]; + user_src_md->data.format_desc.blocking.strides[1] = +src->stridesOf()[dim1]; user_src_md->data.format_desc.blocking.strides[2] = rank +> 2 ? src->stridesOf()[dim2] : 1; + user_src_md->data.format_desc.blocking.strides[3] = rank > 3 ? +src->stridesOf()[dim3] : 1; + } + + if (diff_src != nullptr && diff_src->buffer() != nullptr && +batchnorm_diff_src_md != nullptr) { *batchnorm_diff_src_md = +dnnl::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format); + *user_diff_src_md = dnnl::memory::desc({ batchnorm_src_tz }, type, +format); user_diff_src_md->data.format_kind = dnnl_blocked; // overrides format + user_diff_src_md->data.format_desc.blocking.strides[0] = +diff_src->stridesOf()[0]; user_diff_src_md->data.format_desc.blocking.strides[1] += diff_src->stridesOf()[dim1]; + user_diff_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? +diff_src->stridesOf()[dim2] : 1; + user_diff_src_md->data.format_desc.blocking.strides[3] = rank > 3 ? +diff_src->stridesOf()[dim3] : 1; + } + + if (dst != nullptr && dst->buffer() != nullptr && batchnorm_dst_md != +nullptr) { *batchnorm_dst_md = dnnl::memory::desc({ batchnorm_src_tz }, type, +supposed_to_be_any_format); *user_dst_md = dnnl::memory::desc({ batchnorm_src_tz +}, type, format); user_dst_md->data.format_kind = dnnl_blocked; // overrides +format user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[0]; + user_dst_md->data.format_desc.blocking.strides[1] = +dst->stridesOf()[dim1]; user_dst_md->data.format_desc.blocking.strides[2] = rank +> 2 ? dst->stridesOf()[dim2] : 1; + user_dst_md->data.format_desc.blocking.strides[3] = rank > 3 ? +dst->stridesOf()[dim3] : 1; } }; */ -} -} \ No newline at end of file +} // namespace mkldnnUtils +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.h b/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.h index f3ff327a441f..38c70013b702 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.h +++ b/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.h @@ -14,184 +14,213 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - // - // @author saudet - // @author Yurii Shyrma (iuriish@yahoo.com) - // +// +// @author saudet +// @author Yurii Shyrma (iuriish@yahoo.com) +// -#ifndef DEV_TESTS_MKLDNNUTILS_H -#define DEV_TESTS_MKLDNNUTILS_H +#ifndef SD_MKLDNNUTILS_H +#define SD_MKLDNNUTILS_H - -#include #include -#include -#include #include +#include +#include #include #include -using namespace samediff; +#include +using namespace samediff; namespace sd { - namespace ops { - namespace platforms { - /** - * Here we actually declare our platform helpers - */ - DECLARE_PLATFORM(conv2d, ENGINE_CPU); - - DECLARE_PLATFORM(conv2d_bp, ENGINE_CPU); - - DECLARE_PLATFORM(avgpool2d, ENGINE_CPU); - - DECLARE_PLATFORM(avgpool2d_bp, ENGINE_CPU); - - DECLARE_PLATFORM(maxpool2d, ENGINE_CPU); - - DECLARE_PLATFORM(maxpool2d_bp, ENGINE_CPU); - - DECLARE_PLATFORM(conv3dnew, ENGINE_CPU); - - DECLARE_PLATFORM(conv3dnew_bp, ENGINE_CPU); - - DECLARE_PLATFORM(maxpool3dnew, ENGINE_CPU); - - DECLARE_PLATFORM(maxpool3dnew_bp, ENGINE_CPU); - - DECLARE_PLATFORM(avgpool3dnew, ENGINE_CPU); - - DECLARE_PLATFORM(avgpool3dnew_bp, ENGINE_CPU); - - DECLARE_PLATFORM(lrn, ENGINE_CPU); +namespace ops { +namespace platforms { +/** + * Here we actually declare our platform helpers + */ +DECLARE_PLATFORM(conv2d, ENGINE_CPU); - DECLARE_PLATFORM(batchnorm, ENGINE_CPU); +DECLARE_PLATFORM(conv2d_bp, ENGINE_CPU); - DECLARE_PLATFORM(batchnorm_bp, ENGINE_CPU); +DECLARE_PLATFORM(avgpool2d, ENGINE_CPU); - DECLARE_PLATFORM(lstmLayer, ENGINE_CPU); +DECLARE_PLATFORM(avgpool2d_bp, ENGINE_CPU); - DECLARE_PLATFORM(deconv2d, ENGINE_CPU); +DECLARE_PLATFORM(maxpool2d, ENGINE_CPU); - DECLARE_PLATFORM(deconv2d_tf, ENGINE_CPU); +DECLARE_PLATFORM(maxpool2d_bp, ENGINE_CPU); - DECLARE_PLATFORM(deconv3d, ENGINE_CPU); +DECLARE_PLATFORM(conv3dnew, ENGINE_CPU); - DECLARE_PLATFORM(deconv2d_bp, ENGINE_CPU); +DECLARE_PLATFORM(conv3dnew_bp, ENGINE_CPU); - DECLARE_PLATFORM(deconv3d_bp, ENGINE_CPU); +DECLARE_PLATFORM(maxpool3dnew, ENGINE_CPU); - DECLARE_PLATFORM(depthwise_conv2d, ENGINE_CPU); +DECLARE_PLATFORM(maxpool3dnew_bp, ENGINE_CPU); - DECLARE_PLATFORM(depthwise_conv2d_bp, ENGINE_CPU); +DECLARE_PLATFORM(avgpool3dnew, ENGINE_CPU); - DECLARE_PLATFORM(matmul, ENGINE_CPU); +DECLARE_PLATFORM(avgpool3dnew_bp, ENGINE_CPU); - DECLARE_PLATFORM(softmax, ENGINE_CPU); +DECLARE_PLATFORM(lrn, ENGINE_CPU); - DECLARE_PLATFORM(softmax_bp, ENGINE_CPU); +DECLARE_PLATFORM(batchnorm, ENGINE_CPU); - DECLARE_PLATFORM(tanh, ENGINE_CPU); +DECLARE_PLATFORM(batchnorm_bp, ENGINE_CPU); - DECLARE_PLATFORM(tanh_bp, ENGINE_CPU); +DECLARE_PLATFORM(lstmLayer, ENGINE_CPU); - DECLARE_PLATFORM(xw_plus_b, ENGINE_CPU); +DECLARE_PLATFORM(deconv2d, ENGINE_CPU); - DECLARE_PLATFORM(xw_plus_b_bp, ENGINE_CPU); +DECLARE_PLATFORM(deconv2d_tf, ENGINE_CPU); - DECLARE_PLATFORM(concat, ENGINE_CPU); +DECLARE_PLATFORM(deconv3d, ENGINE_CPU); - } - } +DECLARE_PLATFORM(deconv2d_bp, ENGINE_CPU); - namespace mkldnnUtils { +DECLARE_PLATFORM(deconv3d_bp, ENGINE_CPU); - void poolingMKLDNN(const NDArray* input, NDArray* output, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int isNCHW, const dnnl::algorithm mode); +DECLARE_PLATFORM(depthwise_conv2d, ENGINE_CPU); - void poolingBpMKLDNN(const NDArray* input, const NDArray* gradO, NDArray* gradI, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int isNCHW, const dnnl::algorithm mode); +DECLARE_PLATFORM(depthwise_conv2d_bp, ENGINE_CPU); - void getMKLDNNMemoryDescLrn(const NDArray* src, const NDArray* diff_src, const NDArray* dst, - dnnl::memory::desc* lrn_src_md, dnnl::memory::desc* lrn_diff_src_md, dnnl::memory::desc* lrn_dst_md, - dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md, int axis); +DECLARE_PLATFORM(matmul, ENGINE_CPU); - dnnl::engine& getEngine(void* ptr); +DECLARE_PLATFORM(softmax, ENGINE_CPU); - /** - * This function creates memory dimentions - * @param const pointer to array - * @param const array rank - * @param reference to memory dimentions - */ - void getDims(const NDArray* array, const int rank, dnnl::memory::dims& mklDims); - /** - * This function evaluate memory format tag based on array shapeInfo - * @param const array - * @return memory format - */ - dnnl::memory::format_tag getFormat(const NDArray& arr); +DECLARE_PLATFORM(softmax_bp, ENGINE_CPU); - void setBlockStrides(const NDArray& array, dnnl::memory::desc& mklMd, const std::vector& permut = {}); - ////////////////////////////////////////////////////////////////////// - /** - * This function load and reorder user memory to mkl - * @param const pointer to dataset - * @param reference to mkl engine - * @param reference to mkl stream - * @param reference to args container for dnnl - * @param reference to user memory description - * @param primitive memory descriptor - * @param dnnl arg activation enumerator - */ - dnnl::memory loadDataToMklStream(const NDArray& array, const dnnl::engine& engine, const dnnl::stream& stream, const dnnl::memory::desc& user_md, const dnnl::memory::desc& primitive_md, - dnnl::memory& arg); +DECLARE_PLATFORM(tanh, ENGINE_CPU); - /** - * Utility methods for MKLDNN - */ - /* void getMKLDNNMemoryDescConv2d( - int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, const int paddingMode, bool isNCHW, - int bS, int iC, int iH, int iW, int oC, int oH, int oW, const NDArray* src, const NDArray* diff_src, - const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst, - dnnl::memory::desc* conv_src_md, dnnl::memory::desc* conv_diff_src_md, dnnl::memory::desc* conv_weights_md, - dnnl::memory::desc* conv_diff_weights_md, dnnl::memory::desc* conv_bias_md, dnnl::memory::desc* conv_dst_md, - dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_weights_md, - dnnl::memory::desc* user_diff_weights_md, dnnl::memory::desc* user_bias_md, dnnl::memory::desc* user_dst_md, - dnnl::memory::dims& conv_strides, dnnl::memory::dims& conv_padding, dnnl::memory::dims& conv_padding_r, dnnl::memory::dims& conv_dilation); +DECLARE_PLATFORM(tanh_bp, ENGINE_CPU); - void getMKLDNNMemoryDescConv3d( - int kD, int kH, int kW, int sD, int sH, int sW, int pD, int pH, int pW, int dD, int dH, int dW, bool isSameMode, bool isNCDHW, - int bS, int iC, int iD, int iH, int iW, int oC, int oD, int oH, int oW, const NDArray* src, const NDArray* diff_src, - const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst, - dnnl::memory::desc* conv_src_md, dnnl::memory::desc* conv_diff_src_md, dnnl::memory::desc* conv_weights_md, - dnnl::memory::desc* conv_diff_weights_md, dnnl::memory::desc* conv_bias_md, dnnl::memory::desc* conv_dst_md, - dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_weights_md, - dnnl::memory::desc* user_diff_weights_md, dnnl::memory::desc* user_bias_md, dnnl::memory::desc* user_dst_md, - dnnl::memory::dims& conv_strides, dnnl::memory::dims& conv_padding, dnnl::memory::dims& conv_padding_r, dnnl::memory::dims& conv_dilation); +DECLARE_PLATFORM(xw_plus_b, ENGINE_CPU); - void getMKLDNNMemoryDescPool2d( - int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, int poolingMode, int extraParam0, bool isNCHW, - int bS, int iC, int iH, int iW, int oC, int oH, int oW, - const NDArray* src, const NDArray* diff_src, const NDArray* dst, dnnl::algorithm& algorithm, - dnnl::memory::desc* pool_src_md, dnnl::memory::desc* pool_diff_src_md, dnnl::memory::desc* pool_dst_md, - dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md, - dnnl::memory::dims& pool_strides, dnnl::memory::dims& pool_kernel, dnnl::memory::dims& pool_padding, dnnl::memory::dims& pool_padding_r); - - void getMKLDNNMemoryDescPool3d( - int kD, int kH, int kW, int sD, int sH, int sW, int pD, int pH, int pW, int dD, int dH, int dW, int poolingMode, int extraParam0, bool isNCDHW, - int bS, int iC, int iD, int iH, int iW, int oC, int oD, int oH, int oW, - const NDArray* src, const NDArray* diff_src, const NDArray* dst, dnnl::algorithm& algorithm, - dnnl::memory::desc* pool_src_md, dnnl::memory::desc* pool_diff_src_md, dnnl::memory::desc* pool_dst_md, - dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md, - dnnl::memory::dims& pool_strides, dnnl::memory::dims& pool_kernel, dnnl::memory::dims& pool_padding, dnnl::memory::dims& pool_padding_r); - - void getMKLDNNMemoryDescBatchNorm(const NDArray* src, const NDArray* diff_src, const NDArray* dst, - dnnl::memory::desc* batchnorm_src_md, dnnl::memory::desc* batchnorm_diff_src_md, dnnl::memory::desc* batchnorm_dst_md, - dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md, int axis); - */ - } -} - - - -#endif //DEV_TESTS_MKLDNNUTILS_H +DECLARE_PLATFORM(xw_plus_b_bp, ENGINE_CPU); + +DECLARE_PLATFORM(concat, ENGINE_CPU);} // namespace platforms +} // namespace ops + +namespace mkldnnUtils { + +void poolingMKLDNN(const NDArray* input, NDArray* output, const int kD, + const int kH, const int kW, const int sD, const int sH, + const int sW, const int pD, const int pH, const int pW, + const int isNCHW, const dnnl::algorithm mode); + +void poolingBpMKLDNN(const NDArray* input, const NDArray* gradO, NDArray* gradI, + const int kD, const int kH, const int kW, const int sD, + const int sH, const int sW, const int pD, const int pH, + const int pW, const int isNCHW, + const dnnl::algorithm mode); + +void getMKLDNNMemoryDescLrn(const NDArray* src, const NDArray* diff_src, + const NDArray* dst, dnnl::memory::desc* lrn_src_md, + dnnl::memory::desc* lrn_diff_src_md, + dnnl::memory::desc* lrn_dst_md, + dnnl::memory::desc* user_src_md, + dnnl::memory::desc* user_diff_src_md, + dnnl::memory::desc* user_dst_md, int axis); + +dnnl::engine& getEngine(void* ptr); + +/** + * This function creates memory dimentions + * @param const pointer to array + * @param const array rank + * @param reference to memory dimentions + */ +void getDims(const NDArray* array, const int rank, dnnl::memory::dims& mklDims); +/** + * This function evaluate memory format tag based on array shapeInfo + * @param const array + * @return memory format + */ +dnnl::memory::format_tag getFormat(const NDArray& arr); + +void setBlockStrides(const NDArray& array, dnnl::memory::desc& mklMd, const std::vector& permut = {}); +////////////////////////////////////////////////////////////////////// +/** + * This function load and reorder user memory to mkl + * @param const pointer to dataset + * @param reference to mkl engine + * @param reference to mkl stream + * @param reference to args container for dnnl + * @param reference to user memory description + * @param primitive memory descriptor + * @param dnnl arg activation enumerator + */ +dnnl::memory loadDataToMklStream(const NDArray& array, const dnnl::engine& engine, + const dnnl::stream& stream, + const dnnl::memory::desc& user_md, + const dnnl::memory::desc& primitive_md, + dnnl::memory& arg); + +/** + * Utility methods for MKLDNN + */ +/* void getMKLDNNMemoryDescConv2d( + int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, + const int paddingMode, bool isNCHW, int bS, int iC, int iH, int iW, int oC, + int oH, int oW, const NDArray* src, const NDArray* diff_src, const NDArray* + weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* + dst, dnnl::memory::desc* conv_src_md, dnnl::memory::desc* conv_diff_src_md, + dnnl::memory::desc* conv_weights_md, dnnl::memory::desc* + conv_diff_weights_md, dnnl::memory::desc* conv_bias_md, dnnl::memory::desc* + conv_dst_md, dnnl::memory::desc* user_src_md, dnnl::memory::desc* + user_diff_src_md, dnnl::memory::desc* user_weights_md, dnnl::memory::desc* + user_diff_weights_md, dnnl::memory::desc* user_bias_md, dnnl::memory::desc* + user_dst_md, dnnl::memory::dims& conv_strides, dnnl::memory::dims& + conv_padding, dnnl::memory::dims& conv_padding_r, dnnl::memory::dims& + conv_dilation); + + void getMKLDNNMemoryDescConv3d( + int kD, int kH, int kW, int sD, int sH, int sW, int pD, int pH, + int pW, int dD, int dH, int dW, bool isSameMode, bool isNCDHW, int bS, int + iC, int iD, int iH, int iW, int oC, int oD, int oH, int oW, const NDArray* + src, const NDArray* diff_src, const NDArray* weights, const NDArray* + diff_weights, const NDArray* bias, const NDArray* dst, dnnl::memory::desc* + conv_src_md, dnnl::memory::desc* conv_diff_src_md, dnnl::memory::desc* + conv_weights_md, dnnl::memory::desc* conv_diff_weights_md, + dnnl::memory::desc* conv_bias_md, dnnl::memory::desc* conv_dst_md, + dnnl::memory::desc* user_src_md, dnnl::memory::desc* + user_diff_src_md, dnnl::memory::desc* user_weights_md, dnnl::memory::desc* + user_diff_weights_md, dnnl::memory::desc* user_bias_md, dnnl::memory::desc* + user_dst_md, dnnl::memory::dims& conv_strides, dnnl::memory::dims& + conv_padding, dnnl::memory::dims& conv_padding_r, dnnl::memory::dims& + conv_dilation); + + void getMKLDNNMemoryDescPool2d( + int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, + int poolingMode, int extraParam0, bool isNCHW, int bS, int iC, int iH, int + iW, int oC, int oH, int oW, const NDArray* src, const NDArray* diff_src, + const NDArray* dst, dnnl::algorithm& algorithm, dnnl::memory::desc* + pool_src_md, dnnl::memory::desc* pool_diff_src_md, dnnl::memory::desc* + pool_dst_md, dnnl::memory::desc* user_src_md, dnnl::memory::desc* + user_diff_src_md, dnnl::memory::desc* user_dst_md, dnnl::memory::dims& + pool_strides, dnnl::memory::dims& pool_kernel, dnnl::memory::dims& + pool_padding, dnnl::memory::dims& pool_padding_r); + + void getMKLDNNMemoryDescPool3d( + int kD, int kH, int kW, int sD, int sH, int sW, int pD, int pH, + int pW, int dD, int dH, int dW, int poolingMode, int extraParam0, bool + isNCDHW, int bS, int iC, int iD, int iH, int iW, int oC, int oD, int oH, int + oW, const NDArray* src, const NDArray* diff_src, const NDArray* dst, + dnnl::algorithm& algorithm, dnnl::memory::desc* pool_src_md, + dnnl::memory::desc* pool_diff_src_md, dnnl::memory::desc* pool_dst_md, + dnnl::memory::desc* user_src_md, dnnl::memory::desc* + user_diff_src_md, dnnl::memory::desc* user_dst_md, dnnl::memory::dims& + pool_strides, dnnl::memory::dims& pool_kernel, dnnl::memory::dims& + pool_padding, dnnl::memory::dims& pool_padding_r); + + void getMKLDNNMemoryDescBatchNorm(const NDArray* src, const NDArray* + diff_src, const NDArray* dst, dnnl::memory::desc* batchnorm_src_md, + dnnl::memory::desc* batchnorm_diff_src_md, dnnl::memory::desc* + batchnorm_dst_md, dnnl::memory::desc* user_src_md, dnnl::memory::desc* + user_diff_src_md, dnnl::memory::desc* user_dst_md, int axis); +*/ +} // namespace mkldnnUtils +} // namespace sd + +#endif // SD_MKLDNNUTILS_H diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/softmax.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/softmax.cpp index 9935fd50faea..a291887b0343 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/softmax.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/softmax.cpp @@ -14,248 +14,278 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - // - // @author Oleg Semeniv - // - // +// +// @author Oleg Semeniv +// +// -#include +#include #include +#include #include -#include + #include "mkldnnUtils.h" using namespace dnnl; namespace sd { - namespace ops { - namespace platforms { - +namespace ops { +namespace platforms { - ////////////////////////////////////////////////////////////////////// - static void softmaxMKLDNN(const NDArray* x, NDArray* z, const int axis) { +////////////////////////////////////////////////////////////////////// +static void softmaxMKLDNN(const NDArray* x, NDArray* z, const int axis) { - dnnl::memory::dims shape = x->getShapeAsFlatVector(); + dnnl::memory::dims shape = x->getShapeAsFlatVector(); - const int xRank = x->rankOf(); + const int xRank = x->rankOf(); - dnnl::memory::format_tag xFormat = mkldnnUtils::getFormat(*x); +dnnl::memory::format_tag xFormat = mkldnnUtils::getFormat(*x); dnnl::memory::format_tag zFormat = mkldnnUtils::getFormat(*z); - - // optimized cases - if (2 == xRank && 0 == axis) { - if(x->ews() == 1) + // optimized cases + if (2 == xRank && 0 == axis) { + if(x->ews() == 1) xFormat = dnnl::memory::format_tag::ba; if(z->ews() == 1) zFormat = dnnl::memory::format_tag::ba; - } - else if (4 == xRank && 1 == axis && (x->sizeAt(2) * x->sizeAt(3)) > 1) { - if(x->ews() == 1) + } else if (4 == xRank && 1 == axis && (x->sizeAt(2) * x->sizeAt(3)) > 1) { + if(x->ews() == 1) xFormat = dnnl::memory::format_tag::acdb; if(z->ews() == 1) zFormat = dnnl::memory::format_tag::acdb; - } + } - dnnl::memory::data_type xType = dnnl::memory::data_type::f32; + dnnl::memory::data_type xType = dnnl::memory::data_type::f32; - dnnl::memory::desc x_mkl_md, x_user_md, z_mkl_md, z_user_md; + dnnl::memory::desc x_mkl_md , x_user_md, z_mkl_md, z_user_md; + x_user_md = x_mkl_md = dnnl::memory::desc(shape, xType, xFormat); + mkldnnUtils::setBlockStrides(*x, x_user_md); - x_user_md = x_mkl_md = dnnl::memory::desc(shape, xType, xFormat); - mkldnnUtils::setBlockStrides(*x, x_user_md); + // z + z_user_md = z_mkl_md = dnnl::memory::desc(shape, xType, zFormat); + mkldnnUtils::setBlockStrides(*z, z_user_md); - // z - z_user_md = z_mkl_md = dnnl::memory::desc(shape, xType, zFormat); - mkldnnUtils::setBlockStrides(*z, z_user_md); + auto engine = + mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); + // Create attributes (to handle alpha and beta if necessary) + dnnl::primitive_attr attr; // it is empty since we have usual values for + // alpha (=1) and beta (=0) - // Create attributes (to handle alpha and beta if necessary) - dnnl::primitive_attr attr; // it is empty since we have usual values for alpha (=1) and beta (=0) + // operation primitive description + dnnl::softmax_forward::desc op_desc(dnnl::prop_kind::forward_inference, + x_mkl_md, axis); - // operation primitive description - dnnl::softmax_forward::desc op_desc(dnnl::prop_kind::forward_inference, x_mkl_md, axis); + dnnl::softmax_forward::primitive_desc op_prim_desc(op_desc, attr, engine); - dnnl::softmax_forward::primitive_desc op_prim_desc(op_desc, attr, engine); + // arguments (memory buffers) necessary for calculations + std::unordered_map args; - // arguments (memory buffers) necessary for calculations - std::unordered_map args; + dnnl::stream stream(engine); - dnnl::stream stream(engine); + // provide memory buffers and check whether reorder is required - // provide memory buffers and check whether reorder is required + // input + mkldnnUtils::loadDataToMklStream(*x, engine, stream, x_user_md, + op_prim_desc.src_desc(), args[DNNL_ARG_SRC]); - // input - mkldnnUtils::loadDataToMklStream(*x, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]); + // z + auto z_user_mem = mkldnnUtils::loadDataToMklStream(*z, engine, stream, z_user_md, op_prim_desc.dst_desc(), + args[DNNL_ARG_DST] ); - // z - auto z_user_mem = mkldnnUtils::loadDataToMklStream(*z, engine, stream, z_user_md, op_prim_desc.dst_desc(), args[DNNL_ARG_DST]); + // run calculations + dnnl::softmax_forward(op_prim_desc).execute(stream, args); - // run calculations - dnnl::softmax_forward(op_prim_desc).execute(stream, args); + // reorder outputs if necessary + if (op_prim_desc.dst_desc() != z_user_mem.get_desc()) + dnnl::reorder(args[DNNL_ARG_DST], z_user_mem).execute(stream, args[DNNL_ARG_DST], z_user_mem); - // reorder outputs if necessary - if (op_prim_desc.dst_desc() != z_user_mem.get_desc()) - dnnl::reorder(args[DNNL_ARG_DST], z_user_mem).execute(stream, args[DNNL_ARG_DST], z_user_mem); - - stream.wait(); - } - - - PLATFORM_IMPL(softmax, ENGINE_CPU) { - - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - const int rank = input->rankOf(); - int dim = block.getIArguments()->size() > 0 ? INT_ARG(0) : rank - 1; + stream.wait(); +} - if (dim < 0) { - dim += rank; - } +PLATFORM_IMPL(softmax, ENGINE_CPU) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - REQUIRE_TRUE(dim < rank && dim >= 0, 0, "SOFTMAX_MKLDNN OP: the value of input integer parameter (dimension) must be less than input array rank %i, but got dimension = %i instead !", rank, dim); + const int rank = input->rankOf(); + int dim = block.numI() > 0 ? INT_ARG(0) : rank - 1; - REQUIRE_TRUE(rank <= 6, 0, "SOFTMAX_MKLDNN OP: the rank of input must be less or qual 6, but got rank = %i instead !", rank); + if (dim < 0) { + dim += rank; + } - // mkldnnSoftMax - softmaxMKLDNN(input, output, dim); + REQUIRE_TRUE( + dim < rank && dim >= 0, 0, + "SOFTMAX_MKLDNN OP: the value of input integer parameter (dimension) " + "must be less than input array rank %i, but got dimension = %i instead !", + rank, dim); - return Status::OK(); - } + REQUIRE_TRUE(rank <= 6, 0, + "SOFTMAX_MKLDNN OP: the rank of input must be less or qual 6, " + "but got rank = %i instead !", + rank); - PLATFORM_CHECK(softmax, ENGINE_CPU) { + // mkldnnSoftMax + softmaxMKLDNN(input, output, dim); - auto x = INPUT_VARIABLE(0); - auto z = OUTPUT_VARIABLE(0); + return Status::OK(); +} - const DataType xType = x->dataType(); - const DataType zType = z->dataType(); +PLATFORM_CHECK(softmax, ENGINE_CPU) { + auto x = INPUT_VARIABLE(0); + auto z = OUTPUT_VARIABLE(0); + + const DataType xType = x->dataType(); + const DataType zType = z->dataType(); + + const int xRank = x->rankOf(); + bool bSupportedRanks = (xRank > 2 && xRank < 7); + /* + Source Destination + f32 f32 + */ + return !x->isEmpty() && block.isUseMKLDNN() && bSupportedRanks && + (xType == DataType::FLOAT32 && zType == DataType::FLOAT32); +} - const int xRank = x->rankOf(); - bool bSupportedRanks = (xRank > 2 && xRank < 7); - /* - Source Destination - f32 f32 - */ - return !x->isEmpty() && block.isUseMKLDNN() && bSupportedRanks && (xType == DataType::FLOAT32 && zType == DataType::FLOAT32); +////////////////////////////////////////////////////////////////////// +static void softmaxBpMKLDNN(const NDArray* x, const NDArray* dLdz, + NDArray* dLdx, const int axis) { - } - ////////////////////////////////////////////////////////////////////// - static void softmaxBpMKLDNN(const NDArray* x, const NDArray* dLdz, NDArray* dLdx, const int axis) { + dnnl::memory::desc x_user_md, x_mkl_md, dLdx_mkl_md, dLdx_user_md, dLdz_mkl_md, dLdz_user_md; - dnnl::memory::desc x_user_md, x_mkl_md, dLdx_mkl_md, dLdx_user_md, dLdz_mkl_md, dLdz_user_md; + // x + x_mkl_md = + x_user_md = + dnnl::memory::desc(x->getShapeAsFlatVector(), dnnl::memory::data_type::f32, mkldnnUtils::getFormat(*x)); + mkldnnUtils::setBlockStrides(*x, x_user_md); - // x - x_mkl_md = x_user_md = dnnl::memory::desc(x->getShapeAsFlatVector(), dnnl::memory::data_type::f32, mkldnnUtils::getFormat(*x)); - mkldnnUtils::setBlockStrides(*x, x_user_md); + // dLdx + dLdx_mkl_md = + dLdx_user_md = + dnnl::memory::desc(dLdx->getShapeAsFlatVector(), dnnl::memory::data_type::f32, mkldnnUtils::getFormat(*dLdx)); + mkldnnUtils::setBlockStrides(*dLdx, dLdx_user_md); - // dLdx - dLdx_mkl_md = dLdx_user_md = dnnl::memory::desc(dLdx->getShapeAsFlatVector(), dnnl::memory::data_type::f32, mkldnnUtils::getFormat(*dLdx)); - mkldnnUtils::setBlockStrides(*dLdx, dLdx_user_md); - // dLdz - dLdz_mkl_md = dLdz_user_md = dnnl::memory::desc(dLdz->getShapeAsFlatVector(), dnnl::memory::data_type::f32, mkldnnUtils::getFormat(*dLdz)); - mkldnnUtils::setBlockStrides(*dLdz, dLdz_user_md); + // dLdz + dLdz_mkl_md = + dLdz_user_md = + dnnl::memory::desc(dLdz->getShapeAsFlatVector(), dnnl::memory::data_type::f32, mkldnnUtils::getFormat(*dLdz)); + mkldnnUtils::setBlockStrides(*dLdz, dLdz_user_md); - auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); + auto engine = + mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - // operation primitive description - // forward description - dnnl::softmax_forward::desc op_ff_desc(dnnl::prop_kind::forward_inference, x_mkl_md, axis); - dnnl::softmax_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine); + // operation primitive description + // forward description + dnnl::softmax_forward::desc op_ff_desc(dnnl::prop_kind::forward_inference, + x_mkl_md, axis); + dnnl::softmax_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine); - // backward description - dnnl::softmax_backward::desc op_bp_desc(dLdz_mkl_md, dLdx_mkl_md, axis); - dnnl::softmax_backward::primitive_desc op_bp_prim_desc(op_bp_desc, engine, op_ff_prim_desc); + // backward description + dnnl::softmax_backward::desc op_bp_desc(dLdz_mkl_md, dLdx_mkl_md, axis); + dnnl::softmax_backward::primitive_desc op_bp_prim_desc(op_bp_desc, engine, + op_ff_prim_desc); - // arguments (memory buffers) necessary for calculations - std::unordered_map argsbp, argsff; + // arguments (memory buffers) necessary for calculations + std::unordered_map argsbp, argsff; - dnnl::stream stream(engine); + dnnl::stream stream(engine); - // provide memory buffers and check whether reorder is required for forward - // input - mkldnnUtils::loadDataToMklStream(*x, engine, stream, x_user_md, op_ff_prim_desc.src_desc(), argsff[DNNL_ARG_SRC]); + // provide memory buffers and check whether reorder is required for forward + // input + mkldnnUtils::loadDataToMklStream(*x, engine, stream, x_user_md, + op_ff_prim_desc.src_desc(), + argsff[DNNL_ARG_SRC]); // dLdz mkldnnUtils::loadDataToMklStream(*dLdz, engine, stream, dLdz_user_md, op_bp_prim_desc.diff_dst_desc(), argsbp[DNNL_ARG_DIFF_DST]); - // dLdx - auto dLdx_user_mem = mkldnnUtils::loadDataToMklStream(*dLdx, engine, stream, dLdx_user_md, op_ff_prim_desc.src_desc(), argsff[DNNL_ARG_DST]); - - // check and arg set for backprob - argsbp[DNNL_ARG_DIFF_SRC] = argsff[DNNL_ARG_DST]; - argsbp[DNNL_ARG_DST] = argsff[DNNL_ARG_DST]; - - - // run calculations forward - dnnl::softmax_forward(op_ff_prim_desc).execute(stream, argsff); + // dLdx + auto dLdx_user_mem = mkldnnUtils::loadDataToMklStream(*dLdx, engine, stream, dLdx_user_md, op_ff_prim_desc.src_desc(), + argsff[DNNL_ARG_DST] ); - // run calculations backward - dnnl::softmax_backward(op_bp_prim_desc).execute(stream, argsbp); + // check and arg set for backprob + argsbp[DNNL_ARG_DIFF_SRC] = argsff[DNNL_ARG_DST]; + argsbp[DNNL_ARG_DST] = argsff[DNNL_ARG_DST]; - // reorder outputs if necessary - if (op_ff_prim_desc.dst_desc() != dLdx_user_mem.get_desc()) - dnnl::reorder(argsff[DNNL_ARG_DST], dLdx_user_mem).execute(stream, argsff[DNNL_ARG_DST], dLdx_user_mem); - stream.wait(); - } + // run calculations forward + dnnl::softmax_forward(op_ff_prim_desc).execute(stream, argsff); + // run calculations backward + dnnl::softmax_backward(op_bp_prim_desc).execute(stream, argsbp); - PLATFORM_IMPL(softmax_bp, ENGINE_CPU) { + // reorder outputs if necessary + if (op_ff_prim_desc.dst_desc() != dLdx_user_mem.get_desc()) + dnnl::reorder(argsff[DNNL_ARG_DST], dLdx_user_mem) + .execute(stream, argsff[DNNL_ARG_DST], dLdx_user_mem); - auto input = INPUT_VARIABLE(0); - auto dLdz = INPUT_VARIABLE(1); - auto dLdx = OUTPUT_VARIABLE(0); - - const int rank = input->rankOf(); - const int dLdzRank = dLdz->rankOf(); - int dim = block.getIArguments()->size() > 0 ? INT_ARG(0) : rank - 1; - - if (dim < 0) { - dim += rank; - } + stream.wait(); +} - REQUIRE_TRUE(dim < rank && dim >= 0, 0, "SOFTMAX_MKLDNN_BP OP: the value of input integer parameter (dimension) must be less than input array rank %i, but got dimension = %i instead !", rank, dim); +PLATFORM_IMPL(softmax_bp, ENGINE_CPU) { + auto input = INPUT_VARIABLE(0); + auto dLdz = INPUT_VARIABLE(1); + auto dLdx = OUTPUT_VARIABLE(0); - REQUIRE_TRUE(rank <= 6 && dLdzRank <= 6, 0, "SOFTMAX_MKLDNN_BP OP: the rank of input and dLdz must be less or qual 6, but got input rank = %i and dLdz rank rank = %i instead !", rank, dLdzRank); + const int rank = input->rankOf(); + const int dLdzRank = dLdz->rankOf(); + int dim = block.numI() > 0 ? INT_ARG(0) : rank - 1; - // mkldnnSoftMax - softmaxBpMKLDNN(input, dLdz, dLdx, dim); + if (dim < 0) { + dim += rank; + } - return Status::OK(); - } + REQUIRE_TRUE( + dim < rank && dim >= 0, 0, + "SOFTMAX_MKLDNN_BP OP: the value of input integer parameter (dimension) " + "must be less than input array rank %i, but got dimension = %i instead !", + rank, dim); - PLATFORM_CHECK(softmax_bp, ENGINE_CPU) { + REQUIRE_TRUE( + rank <= 6 && dLdzRank <= 6, 0, + "SOFTMAX_MKLDNN_BP OP: the rank of input and dLdz must be less or qual " + "6, but got input rank = %i and dLdz rank rank = %i instead !", + rank, dLdzRank); - auto x = INPUT_VARIABLE(0); - auto dLdz = INPUT_VARIABLE(1); - auto dLdx = OUTPUT_VARIABLE(0); + // mkldnnSoftMax + softmaxBpMKLDNN(input, dLdz, dLdx, dim); - const DataType xType = x->dataType(); - const DataType dLdzType = dLdz->dataType(); - const DataType dLdxType = dLdx->dataType(); + return Status::OK(); +} - const int xRank = x->rankOf(); - const int dLdzRank = dLdz->rankOf(); +PLATFORM_CHECK(softmax_bp, ENGINE_CPU) { + auto x = INPUT_VARIABLE(0); + auto dLdz = INPUT_VARIABLE(1); + auto dLdx = OUTPUT_VARIABLE(0); - bool bSupportedRanks = xRank < 7 && dLdzRank == xRank && (!x->isEmpty() && !dLdz->isEmpty()); + const DataType xType = x->dataType(); + const DataType dLdzType = dLdz->dataType(); + const DataType dLdxType = dLdx->dataType(); - if (bSupportedRanks) { - for (int i = 0; i < xRank; i++) { - if (x->sizeAt(i) != dLdz->sizeAt(i)) { - bSupportedRanks = false; - break; - } - } - } + const int xRank = x->rankOf(); + const int dLdzRank = dLdz->rankOf(); - //Source Destination - //f32 f32 - return block.isUseMKLDNN() && bSupportedRanks && (xType == DataType::FLOAT32 && dLdzType == DataType::FLOAT32 && dLdxType == DataType::FLOAT32); - } + bool bSupportedRanks = + xRank < 7 && dLdzRank == xRank && (!x->isEmpty() && !dLdz->isEmpty()); - } + if (bSupportedRanks) { + for (int i = 0; i < xRank; i++) { + if (x->sizeAt(i) != dLdz->sizeAt(i)) { + bSupportedRanks = false; + break; + } } + } + + // Source Destination + // f32 f32 + return block.isUseMKLDNN() && bSupportedRanks && + (xType == DataType::FLOAT32 && dLdzType == DataType::FLOAT32 && + dLdxType == DataType::FLOAT32); } + +} // namespace platforms +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/tanh.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/tanh.cpp index a808239dec11..e702ec0b8fe0 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/tanh.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/tanh.cpp @@ -14,204 +14,232 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - // - // @author Oleg Semeniv - // - // +// +// @author Oleg Semeniv +// +// -#include +#include #include +#include #include -#include + #include "mkldnnUtils.h" using namespace dnnl; namespace sd { - namespace ops { - namespace platforms { - - ////////////////////////////////////////////////////////////////////// - static void tanhMKLDNN(const NDArray* x, NDArray* z) { - - dnnl::memory::dims shape = x->getShapeAsFlatVector(); - - dnnl::memory::desc x_mkl_md, x_user_md, z_mkl_md, z_user_md; - - x_user_md = x_mkl_md = dnnl::memory::desc(shape, dnnl::memory::data_type::f32, mkldnnUtils::getFormat(*x)); - mkldnnUtils::setBlockStrides(*x, x_user_md); - - // z - z_user_md = z_mkl_md = dnnl::memory::desc(shape, dnnl::memory::data_type::f32, mkldnnUtils::getFormat(*z)); - mkldnnUtils::setBlockStrides(*z, z_user_md); - - auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - - // Create attributes (to handle alpha and beta if necessary) - dnnl::primitive_attr attr; // it is empty since we have usual values for alpha (=1) and beta (=0) - - // operation primitive description - dnnl::eltwise_forward::desc op_desc(dnnl::prop_kind::forward_inference, algorithm::eltwise_tanh, x_mkl_md, 0, 0); +namespace ops { +namespace platforms { - dnnl::eltwise_forward::primitive_desc op_prim_desc(op_desc, attr, engine); +////////////////////////////////////////////////////////////////////// +static void tanhMKLDNN(const NDArray* x, NDArray* z) { - // arguments (memory buffers) necessary for calculations - std::unordered_map args; + dnnl::memory::dims shape = x->getShapeAsFlatVector(); - dnnl::stream stream(engine); + dnnl::memory::desc x_mkl_md, x_user_md, z_mkl_md, z_user_md; - // provide memory buffers and check whether reorder is required - // input - mkldnnUtils::loadDataToMklStream(*x, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]); + x_user_md = x_mkl_md = + dnnl::memory::desc(shape, dnnl::memory::data_type::f32, mkldnnUtils::getFormat(*x)); + mkldnnUtils::setBlockStrides(*x, x_user_md); - // z - auto z_user_mem = mkldnnUtils::loadDataToMklStream(*z, engine, stream, z_user_md, op_prim_desc.dst_desc(), args[DNNL_ARG_DST]); + // z + z_user_md = z_mkl_md = + dnnl::memory::desc(shape, dnnl::memory::data_type::f32, mkldnnUtils::getFormat(*z)); + mkldnnUtils::setBlockStrides(*z, z_user_md); - // run calculations - dnnl::eltwise_forward(op_prim_desc).execute(stream, args); + auto engine = + mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - // reorder outputs if necessary - if (op_prim_desc.dst_desc() != z_user_mem.get_desc()) - dnnl::reorder(args[DNNL_ARG_DST], z_user_mem).execute(stream, args[DNNL_ARG_DST], z_user_mem); + // Create attributes (to handle alpha and beta if necessary) + dnnl::primitive_attr attr; // it is empty since we have usual values for + // alpha (=1) and beta (=0) - stream.wait(); - } + // operation primitive description + dnnl::eltwise_forward::desc op_desc(dnnl::prop_kind::forward_inference, + algorithm::eltwise_tanh, x_mkl_md, 0, 0); + dnnl::eltwise_forward::primitive_desc op_prim_desc(op_desc, attr, engine); - PLATFORM_IMPL(tanh, ENGINE_CPU) { + // arguments (memory buffers) necessary for calculations + std::unordered_map args; - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - const int rank = input->rankOf(); - REQUIRE_TRUE(rank <= 6, 0, "TANH_MKLDNN OP: the rank of input must be less or qual 6, but got rank = %i instead !", rank); + dnnl::stream stream(engine); - // mkldnnTanh - tanhMKLDNN(input, output); + // provide memory buffers and check whether reorder is required + // input + mkldnnUtils::loadDataToMklStream(*x, engine, stream, x_user_md, + op_prim_desc.src_desc(), args[DNNL_ARG_SRC]); - return Status::OK(); - } + // z + auto z_user_mem = mkldnnUtils::loadDataToMklStream(*z, engine, stream, z_user_md, op_prim_desc.dst_desc(), + args[DNNL_ARG_DST] ); - PLATFORM_CHECK(tanh, ENGINE_CPU) { + // run calculations + dnnl::eltwise_forward(op_prim_desc).execute(stream, args); - auto x = INPUT_VARIABLE(0); - auto z = OUTPUT_VARIABLE(0); + // reorder outputs if necessary + if (op_prim_desc.dst_desc() != z_user_mem.get_desc()) + dnnl::reorder(args[DNNL_ARG_DST], z_user_mem).execute(stream, args[DNNL_ARG_DST], z_user_mem); - const DataType xType = x->dataType(); - const DataType zType = z->dataType(); - - const int xRank = x->rankOf(); - bool bSupportedRanks = !x->isEmpty() && xRank < 7 && xRank > 0 && (xType == DataType::FLOAT32 && zType == DataType::FLOAT32); - /* - Source Destination - f32 f32 - */ - return block.isUseMKLDNN() && bSupportedRanks; - } - - - ////////////////////////////////////////////////////////////////////// - static void tanhBpMKLDNN(const NDArray* x, const NDArray* dLdz, NDArray* dLdx) { - - dnnl::memory::dims shape = x->getShapeAsFlatVector(); - - dnnl::memory::desc x_mkl_md, x_user_md, dLdx_mkl_md, dLdx_user_md, dLdz_mkl_md, dLdz_user_md; - - // x - x_user_md = x_mkl_md = dnnl::memory::desc(shape, dnnl::memory::data_type::f32, mkldnnUtils::getFormat(*x)); - mkldnnUtils::setBlockStrides(*x, x_user_md); + stream.wait(); +} - // dLdz - dLdz_user_md = dLdz_mkl_md = dnnl::memory::desc(shape, dnnl::memory::data_type::f32, mkldnnUtils::getFormat(*dLdz)); - mkldnnUtils::setBlockStrides(*dLdz, dLdz_user_md); +PLATFORM_IMPL(tanh, ENGINE_CPU) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + const int rank = input->rankOf(); + REQUIRE_TRUE(rank <= 6, 0, + "TANH_MKLDNN OP: the rank of input must be less or qual 6, but " + "got rank = %i instead !", + rank); - // dLdx - dLdx_user_md = dLdx_mkl_md = dnnl::memory::desc(shape, dnnl::memory::data_type::f32, mkldnnUtils::getFormat(*dLdx)); - mkldnnUtils::setBlockStrides(*dLdx, dLdx_user_md); + // mkldnnTanh + tanhMKLDNN(input, output); - auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); + return Status::OK(); +} - // arguments (memory buffers) necessary for calculations - std::unordered_map args; +PLATFORM_CHECK(tanh, ENGINE_CPU) { + auto x = INPUT_VARIABLE(0); + auto z = OUTPUT_VARIABLE(0); + + const DataType xType = x->dataType(); + const DataType zType = z->dataType(); + + const int xRank = x->rankOf(); + bool bSupportedRanks = + !x->isEmpty() && xRank < 7 && xRank > 0 && + (xType == DataType::FLOAT32 && zType == DataType::FLOAT32); + /* + Source Destination + f32 f32 + */ + return block.isUseMKLDNN() && bSupportedRanks; +} - dnnl::stream stream(engine); +////////////////////////////////////////////////////////////////////// +static void tanhBpMKLDNN(const NDArray* x, const NDArray* dLdz, NDArray* dLdx) { - // operation primitive description - // forward - dnnl::eltwise_forward::desc op_ff_desc(dnnl::prop_kind::forward_inference, algorithm::eltwise_tanh, x_mkl_md, 0, 0); - dnnl::eltwise_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine); + dnnl::memory::dims shape = x->getShapeAsFlatVector(); - // backward description - dnnl::eltwise_backward::desc op_desc(algorithm::eltwise_tanh, dLdz_mkl_md, x_mkl_md, 0, 0); - dnnl::eltwise_backward::primitive_desc op_prim_desc(op_desc, engine, op_ff_prim_desc); + dnnl::memory::desc x_mkl_md, x_user_md, dLdx_mkl_md, dLdx_user_md, dLdz_mkl_md, dLdz_user_md; - // provide memory buffers and check whether reorder is required for forward - // input - mkldnnUtils::loadDataToMklStream(*x, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]); + // x - // dLdz - mkldnnUtils::loadDataToMklStream(*dLdz, engine, stream, dLdz_user_md, op_prim_desc.diff_dst_desc(), args[DNNL_ARG_DIFF_DST]); + x_user_md = x_mkl_md = + dnnl::memory::desc(shape, dnnl::memory::data_type::f32, mkldnnUtils::getFormat(*x)); + mkldnnUtils::setBlockStrides(*x, x_user_md); - // dLdx - auto dLdx_user_mem = mkldnnUtils::loadDataToMklStream(*dLdx, engine, stream, dLdx_user_md, op_prim_desc.diff_src_desc(), args[DNNL_ARG_DIFF_SRC]); + // dLdz + dLdz_user_md = dLdz_mkl_md = + dnnl::memory::desc(shape, dnnl::memory::data_type::f32, mkldnnUtils::getFormat(*dLdz)); + mkldnnUtils::setBlockStrides(*dLdz, dLdz_user_md); - // run calculations backward - dnnl::eltwise_backward(op_prim_desc).execute(stream, args); + // dLdx + dLdx_user_md = dLdx_mkl_md = + dnnl::memory::desc(shape, dnnl::memory::data_type::f32, mkldnnUtils::getFormat(*dLdx)); + mkldnnUtils::setBlockStrides(*dLdx, dLdx_user_md); - // reorder outputs if necessary - if (op_prim_desc.diff_src_desc() != dLdx_user_mem.get_desc()) - dnnl::reorder(args[DNNL_ARG_DIFF_SRC], dLdx_user_mem).execute(stream, args[DNNL_ARG_DIFF_SRC], dLdx_user_mem); + auto engine = + mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - stream.wait(); - } + // arguments (memory buffers) necessary for calculations + std::unordered_map args; + dnnl::stream stream(engine); - PLATFORM_IMPL(tanh_bp, ENGINE_CPU) { + // operation primitive description + // forward + dnnl::eltwise_forward::desc op_ff_desc(dnnl::prop_kind::forward_inference, + algorithm::eltwise_tanh, x_mkl_md, 0, + 0); + dnnl::eltwise_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine); - auto input = INPUT_VARIABLE(0); - auto dLdz = INPUT_VARIABLE(1); - auto dLdx = OUTPUT_VARIABLE(0); + // backward description + dnnl::eltwise_backward::desc op_desc(algorithm::eltwise_tanh, dLdz_mkl_md, + x_mkl_md, 0, 0); + dnnl::eltwise_backward::primitive_desc op_prim_desc(op_desc, engine, + op_ff_prim_desc); - const int rank = input->rankOf(); - const int dLdzRank = dLdz->rankOf(); + // provide memory buffers and check whether reorder is required for forward + // input + mkldnnUtils::loadDataToMklStream(*x, engine, stream, x_user_md, + op_prim_desc.src_desc(), args[DNNL_ARG_SRC]); - REQUIRE_TRUE(rank <= 6 && dLdzRank <= 6, 0, "TANH_BP_MKLDNN OP: the rank of input and dLdz must be less or qual 6, but got input rank = %i and dLdz rank rank = %i instead !", rank, dLdzRank); + // dLdz + mkldnnUtils::loadDataToMklStream(*dLdz, engine, stream, dLdz_user_md, + op_prim_desc.diff_dst_desc(), + args[DNNL_ARG_DIFF_DST]); - // mkldnnSoftMax - tanhBpMKLDNN(input, dLdz, dLdx); + // dLdx + auto dLdx_user_mem = mkldnnUtils::loadDataToMklStream(*dLdx, engine, stream, dLdx_user_md, + op_prim_desc.diff_src_desc() , + args[DNNL_ARG_DIFF_SRC] ); - return Status::OK(); - } + // run calculations backward + dnnl::eltwise_backward(op_prim_desc).execute(stream, args); - PLATFORM_CHECK(tanh_bp, ENGINE_CPU) { + // reorder outputs if necessary + if (op_prim_desc.diff_src_desc() != dLdx_user_mem.get_desc()) + dnnl::reorder(args[DNNL_ARG_DIFF_SRC], dLdx_user_mem) + .execute(stream, args[DNNL_ARG_DIFF_SRC], dLdx_user_mem); - auto x = INPUT_VARIABLE(0); - auto dLdz = INPUT_VARIABLE(1); - auto dLdx = OUTPUT_VARIABLE(0); + stream.wait(); +} - const DataType xType = x->dataType(); - const DataType dLdzType = dLdz->dataType(); - const DataType dLdxType = dLdx->dataType(); +PLATFORM_IMPL(tanh_bp, ENGINE_CPU) { + auto input = INPUT_VARIABLE(0); + auto dLdz = INPUT_VARIABLE(1); + auto dLdx = OUTPUT_VARIABLE(0); - const int xRank = x->rankOf(); - const int dLdzRank = dLdz->rankOf(); + const int rank = input->rankOf(); + const int dLdzRank = dLdz->rankOf(); - bool bSupportedRanks = xRank < 7 && xRank > 0 && dLdzRank == xRank && (!x->isEmpty() && !dLdz->isEmpty()); - bSupportedRanks &= (xType == DataType::FLOAT32 && dLdzType == DataType::FLOAT32 && dLdxType == DataType::FLOAT32); + REQUIRE_TRUE( + rank <= 6 && dLdzRank <= 6, 0, + "TANH_BP_MKLDNN OP: the rank of input and dLdz must be less or qual 6, " + "but got input rank = %i and dLdz rank rank = %i instead !", + rank, dLdzRank); - if (bSupportedRanks) { - for (int i = 0; i < xRank; i++) { - if (x->sizeAt(i) != dLdz->sizeAt(i)) { - bSupportedRanks = false; - break; - } - } - } + // mkldnnSoftMax + tanhBpMKLDNN(input, dLdz, dLdx); - //Source Destination - //f32 f32 - return block.isUseMKLDNN() && bSupportedRanks; - } + return Status::OK(); +} - } +PLATFORM_CHECK(tanh_bp, ENGINE_CPU) { + auto x = INPUT_VARIABLE(0); + auto dLdz = INPUT_VARIABLE(1); + auto dLdx = OUTPUT_VARIABLE(0); + + const DataType xType = x->dataType(); + const DataType dLdzType = dLdz->dataType(); + const DataType dLdxType = dLdx->dataType(); + + const int xRank = x->rankOf(); + const int dLdzRank = dLdz->rankOf(); + + bool bSupportedRanks = xRank < 7 && xRank > 0 && dLdzRank == xRank && + (!x->isEmpty() && !dLdz->isEmpty()); + bSupportedRanks &= + (xType == DataType::FLOAT32 && dLdzType == DataType::FLOAT32 && + dLdxType == DataType::FLOAT32); + + if (bSupportedRanks) { + for (int i = 0; i < xRank; i++) { + if (x->sizeAt(i) != dLdz->sizeAt(i)) { + bSupportedRanks = false; + break; + } } + } + + // Source Destination + // f32 f32 + return block.isUseMKLDNN() && bSupportedRanks; } + +} // namespace platforms +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/xw_plus_b.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/xw_plus_b.cpp index 1097ccd34ef9..d7b8c20587e4 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/xw_plus_b.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/xw_plus_b.cpp @@ -14,373 +14,456 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - // - // @author Oleg Semeniv - // - // +// +// @author Oleg Semeniv +// +// -#include +#include #include +#include #include -#include + #include "mkldnnUtils.h" using namespace dnnl; namespace sd { - namespace ops { - namespace platforms { - - ////////////////////////////////////////////////////////////////////// - static void xwPlusBiasMKLDNN(const NDArray* x, const NDArray* weights, const NDArray* bias, NDArray* z, const bool bShouldTransp) { - - // mkl works with following - // [M,K] x [N,K]^T + [N] = [M,N] - const auto xRank = x->rankOf(); - - // [M,K] x [K,N] = [M,N] - const int M = x->sizeAt(0); - const int K = x->sizeAt(1); // K == wK - const int N = z->sizeAt(1); - - dnnl::memory::dims xShape = dnnl::memory::dims({ M, K }); - dnnl::memory::dims wShape = dnnl::memory::dims({ N, K }); - dnnl::memory::dims zShape = dnnl::memory::dims({ M, N }); - dnnl::memory::dims bShape = dnnl::memory::dims({ N }); - - dnnl::memory::format_tag format = dnnl::memory::format_tag::ab; - - // x type - dnnl::memory::data_type xType = dnnl::memory::data_type::f32; - if (x->dataType() == DataType::UINT8) - xType = dnnl::memory::data_type::u8; - else if (x->dataType() == DataType::INT8) - xType = dnnl::memory::data_type::s8; - - // weights type - dnnl::memory::data_type wType = (weights->dataType() == DataType::FLOAT32) ? - wType = dnnl::memory::data_type::f32 : wType = dnnl::memory::data_type::s8; - - // bias type need add description for bias - dnnl::memory::data_type bType = dnnl::memory::data_type::f32; - if (bias->dataType() == DataType::INT32) - bType = dnnl::memory::data_type::s32; - else if (bias->dataType() == DataType::UINT8) - bType = dnnl::memory::data_type::u8; - else if (bias->dataType() == DataType::INT8) - bType = dnnl::memory::data_type::s8; - - // z type - dnnl::memory::data_type zType = dnnl::memory::data_type::f32; - if (z->dataType() == DataType::INT32) - zType = dnnl::memory::data_type::s32; - else if (z->dataType() == DataType::UINT8) - zType = dnnl::memory::data_type::u8; - else if (z->dataType() == DataType::INT8) - zType = dnnl::memory::data_type::s8; - - // memory descriptors for arrays - // x - dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xShape, xType, dnnl::memory::format_tag::any); - dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, xType, mkldnnUtils::getFormat(*x)); - mkldnnUtils::setBlockStrides(*x, x_user_md); - - // weights - dnnl::memory::desc weights_mkl_md = dnnl::memory::desc(wShape, wType, dnnl::memory::format_tag::any); - dnnl::memory::desc weights_user_md = dnnl::memory::desc(wShape, wType, mkldnnUtils::getFormat(*weights)); - mkldnnUtils::setBlockStrides(*weights, weights_user_md, bShouldTransp ? std::vector({1,0}) : std::vector()); - - // bias - dnnl::memory::desc bias_mkl_md = dnnl::memory::desc(bShape, bType, dnnl::memory::format_tag::a); - dnnl::memory::desc bias_user_md = dnnl::memory::desc(bShape, bType, dnnl::memory::format_tag::a); - mkldnnUtils::setBlockStrides(*bias, bias_user_md); - - // z - dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zShape, zType, dnnl::memory::format_tag::any); - dnnl::memory::desc z_user_md = dnnl::memory::desc(zShape, zType, mkldnnUtils::getFormat(*z)); - mkldnnUtils::setBlockStrides(*z, z_user_md); - - auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - - // operation primitive description - dnnl::inner_product_forward::desc op_desc(dnnl::prop_kind::forward_inference, x_mkl_md, weights_mkl_md, bias_mkl_md, z_mkl_md); - - dnnl::inner_product_forward::primitive_desc op_prim_desc(op_desc, engine); - - // arguments (memory buffers) necessary for calculations - std::unordered_map args; - - dnnl::stream stream(engine); - - // provide memory buffers and check whether reorder is required - - // input - mkldnnUtils::loadDataToMklStream(*x, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]); - - // weights - mkldnnUtils::loadDataToMklStream(*weights, engine, stream, weights_user_md, op_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]); - - // bias - auto bias_mkl_mem = dnnl::memory(bias_mkl_md, engine, const_cast(bias->buffer())); - args[DNNL_ARG_BIAS] = bias_mkl_mem; - - // z - auto z_user_mem = mkldnnUtils::loadDataToMklStream(*z, engine, stream, z_user_md, op_prim_desc.dst_desc(), args[DNNL_ARG_DST]); - - // run calculations - dnnl::inner_product_forward(op_prim_desc).execute(stream, args); - - // reorder outputs if necessary - if (op_prim_desc.dst_desc() != z_user_mem.get_desc()) - dnnl::reorder(args[DNNL_ARG_DST], z_user_mem).execute(stream, args[DNNL_ARG_DST], z_user_mem); - - stream.wait(); - } - - ////////////////////////////////////////////////////////////////////// - static void xwPlusBiasBp(const NDArray* x, const NDArray* weights, const NDArray* bias, const NDArray* dLdz, - NDArray* dLdx, NDArray* dLdw, NDArray* dLdb, const bool bShouldTransp) { - - // mkl works with following - // [M,K] x [N,K]^T + [N] = [M,N] - const auto xRank = x->rankOf(); - - // [M,K] x [K,N] = [M,N] - const int M = x->sizeAt(0); - const int K = x->sizeAt(1); // K == wK - const int N = dLdz->sizeAt(1); - // input dims - dnnl::memory::dims xShape = dnnl::memory::dims({ M, K }); - dnnl::memory::dims wShape = dnnl::memory::dims({ N, K }); - dnnl::memory::dims dLdzShape = dnnl::memory::dims({ M, N }); - - dnnl::memory::dims bShape = dnnl::memory::dims({ N }); - - // output dims - dnnl::memory::dims dLdxShape = xShape; - dnnl::memory::dims dLdwShape = wShape; - - dnnl::memory::data_type dataType = dnnl::memory::data_type::f32; - - // memory descriptors for arrays - // x - dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xShape, dataType, dnnl::memory::format_tag::any); - dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, dataType, mkldnnUtils::getFormat(*x)); - mkldnnUtils::setBlockStrides(*x, x_user_md); - - // weights - dnnl::memory::desc weights_mkl_md = dnnl::memory::desc(wShape, dataType, dnnl::memory::format_tag::any); - dnnl::memory::desc weights_user_md = dnnl::memory::desc(wShape, dataType, mkldnnUtils::getFormat(*weights)); - mkldnnUtils::setBlockStrides(*weights, weights_user_md, bShouldTransp ? std::vector({1,0}) : std::vector()); - - // bias - dnnl::memory::desc bias_mkl_md = dnnl::memory::desc(bShape, dataType, dnnl::memory::format_tag::any); - dnnl::memory::desc bias_user_md = dnnl::memory::desc(bShape, dataType, mkldnnUtils::getFormat(*bias)); - mkldnnUtils::setBlockStrides(*bias, bias_user_md); - - // dLdz - dnnl::memory::desc dLdz_mkl_md = dnnl::memory::desc(dLdzShape, dataType, dnnl::memory::format_tag::any); - dnnl::memory::desc dLdz_user_md = dnnl::memory::desc(dLdzShape, dataType, mkldnnUtils::getFormat(*dLdz)); - mkldnnUtils::setBlockStrides(*dLdz, dLdz_user_md); - - - // dLdw - dnnl::memory::desc dLdw_mkl_md = dnnl::memory::desc(wShape, dataType, dnnl::memory::format_tag::any); - dnnl::memory::desc dLdw_user_md = dnnl::memory::desc(wShape, dataType, mkldnnUtils::getFormat(*dLdw)); - mkldnnUtils::setBlockStrides(*dLdw, dLdw_user_md, bShouldTransp ? std::vector({1,0}) : std::vector()); - - // dLdb - dnnl::memory::desc dLdb_mkl_md = dnnl::memory::desc(bShape, dataType, dnnl::memory::format_tag::any); - dnnl::memory::desc dLdb_user_md = dnnl::memory::desc(bShape, dataType, mkldnnUtils::getFormat(*dLdb)); - mkldnnUtils::setBlockStrides(*dLdb, dLdb_user_md); - - // dLdx - dnnl::memory::desc dLdx_mkl_md = dnnl::memory::desc(xShape, dataType, dnnl::memory::format_tag::any); - dnnl::memory::desc dLdx_user_md = dnnl::memory::desc(xShape, dataType, mkldnnUtils::getFormat(*dLdx)); - mkldnnUtils::setBlockStrides(*dLdx, dLdx_user_md); - - // create engine - auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - - // forward - // operation primitive description - dnnl::inner_product_forward::desc op_ff_desc(dnnl::prop_kind::forward_inference, x_mkl_md, weights_mkl_md, bias_mkl_md, dLdz_mkl_md); - dnnl::inner_product_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine); - - // backprob - // dLdw - auto op_bpdw_desc = inner_product_backward_weights::desc(x_mkl_md, dLdw_mkl_md, dLdb_mkl_md, dLdz_mkl_md); - auto op_bpdw_prim_desc = inner_product_backward_weights::primitive_desc(op_bpdw_desc, engine, op_ff_prim_desc); - - // backprob - // dLdx - auto op_bpdx_desc = inner_product_backward_data::desc(dLdx_mkl_md, weights_mkl_md, dLdz_mkl_md); - auto op_bpdx_prim_desc = inner_product_backward_data::primitive_desc(op_bpdx_desc, engine, op_ff_prim_desc); - - // arguments (memory buffers) necessary for calculations - std::unordered_map argsDw, argsDx; - - dnnl::stream stream(engine); - - // dLdz dw - mkldnnUtils::loadDataToMklStream(*dLdz, engine, stream, dLdz_user_md, op_bpdw_prim_desc.diff_dst_desc(), argsDw[DNNL_ARG_DIFF_DST]); - - // dLdz - dx - mkldnnUtils::loadDataToMklStream(*dLdz, engine, stream, dLdz_user_md, op_bpdx_prim_desc.diff_dst_desc(), argsDx[DNNL_ARG_DIFF_DST]); - - // input x for dw - mkldnnUtils::loadDataToMklStream(*x, engine, stream, x_user_md, op_bpdw_prim_desc.src_desc(), argsDw[DNNL_ARG_SRC]); - - // weights - dx - mkldnnUtils::loadDataToMklStream(*weights, engine, stream, weights_user_md, op_bpdx_prim_desc.weights_desc(), argsDx[DNNL_ARG_WEIGHTS]); - - // dLdw - auto dLdw_user_mem = mkldnnUtils::loadDataToMklStream(*dLdw, engine, stream, dLdw_user_md, op_bpdw_prim_desc.diff_weights_desc(), argsDw[DNNL_ARG_DIFF_WEIGHTS]); - - // dLdx - auto dLdx_user_mem = mkldnnUtils::loadDataToMklStream(*dLdx, engine, stream, dLdx_user_md, op_bpdx_prim_desc.diff_src_desc(), argsDx[DNNL_ARG_DIFF_SRC]); - - // dLdb - auto dLdb_user_mem = mkldnnUtils::loadDataToMklStream(*dLdb, engine, stream, dLdb_user_md, op_bpdw_prim_desc.diff_bias_desc(), argsDw[DNNL_ARG_DIFF_BIAS]); - - // run calculations dw - dnnl::inner_product_backward_weights(op_bpdw_prim_desc).execute(stream, argsDw); - // run calculations dx - dnnl::inner_product_backward_data(op_bpdx_prim_desc).execute(stream, argsDx); - - // reorder outputs if necessary - if (op_bpdx_prim_desc.diff_src_desc() != dLdx_user_mem.get_desc()) - dnnl::reorder(argsDx[DNNL_ARG_DIFF_SRC], dLdx_user_mem).execute(stream, argsDx[DNNL_ARG_DIFF_SRC], dLdx_user_mem); - - if (op_bpdw_prim_desc.diff_weights_desc() != dLdw_user_mem.get_desc()) - dnnl::reorder(argsDw[DNNL_ARG_DIFF_WEIGHTS], dLdw_user_mem).execute(stream, argsDw[DNNL_ARG_DIFF_WEIGHTS], dLdw_user_mem); - - if (op_bpdw_prim_desc.diff_bias_desc() != dLdb_user_mem.get_desc()) - dnnl::reorder(argsDw[DNNL_ARG_DIFF_BIAS], dLdb_user_mem).execute(stream, argsDw[DNNL_ARG_DIFF_BIAS], dLdb_user_mem); - - stream.wait(); - } - - PLATFORM_IMPL(xw_plus_b, ENGINE_CPU) { - - auto x = INPUT_VARIABLE(0); - auto w = INPUT_VARIABLE(1); - auto b = INPUT_VARIABLE(2); - auto z = OUTPUT_VARIABLE(0); - - if (x->isEmpty() || w->isEmpty() || b->isEmpty()) - return Status::OK(); - - const int xRank = x->rankOf(); - const int wRank = w->rankOf(); - const int zRank = z->rankOf(); - - const bool bShouldTransp = block.getIArguments()->size() > 0 ? (1 != INT_ARG(0)) : true; // [M,K] * [K,N] -> [M, N], mkl -> [M,K] * [N, K]^T -> [M, N] - - REQUIRE_TRUE(xRank == 2, 0, "xw_plus_b MKL: Input x array should have rank equal 2, but got instead %i!", xRank); - REQUIRE_TRUE(wRank == 2, 0, "xw_plus_b MKL: Input weights array should have rank equal 2, but got instead %i!", wRank); - REQUIRE_TRUE(zRank == 2, 0, "xw_plus_b MKL: Output array should have rank equal 2, but got instead %i!", zRank); +namespace ops { +namespace platforms { + +////////////////////////////////////////////////////////////////////// +static void xwPlusBiasMKLDNN(const NDArray* x, const NDArray* weights, + const NDArray* bias, NDArray* z, + const bool bShouldTransp) { + // mkl works with following + // [M,K] x [N,K]^T + [N] = [M,N] + const auto xRank = x->rankOf(); + + // [M,K] x [K,N] = [M,N] + const int M = x->sizeAt(0); + const int K = x->sizeAt(1); // K == wK + const int N = z->sizeAt(1); + + dnnl::memory::dims xShape = dnnl::memory::dims({M, K}); + dnnl::memory::dims wShape = dnnl::memory::dims({N, K}); + dnnl::memory::dims zShape = dnnl::memory::dims({M, N}); + dnnl::memory::dims bShape = dnnl::memory::dims({N}); + + dnnl::memory::format_tag format = dnnl::memory::format_tag::ab; + + // x type + dnnl::memory::data_type xType = dnnl::memory::data_type::f32; + if (x->dataType() == DataType::UINT8) + xType = dnnl::memory::data_type::u8; + else if (x->dataType() == DataType::INT8) + xType = dnnl::memory::data_type::s8; + + // weights type + dnnl::memory::data_type wType = (weights->dataType() == DataType::FLOAT32) + ? wType = dnnl::memory::data_type::f32 + : wType = dnnl::memory::data_type::s8; + + // bias type need add description for bias + dnnl::memory::data_type bType = dnnl::memory::data_type::f32; + if (bias->dataType() == DataType::INT32) + bType = dnnl::memory::data_type::s32; + else if (bias->dataType() == DataType::UINT8) + bType = dnnl::memory::data_type::u8; + else if (bias->dataType() == DataType::INT8) + bType = dnnl::memory::data_type::s8; + + // z type + dnnl::memory::data_type zType = dnnl::memory::data_type::f32; + if (z->dataType() == DataType::INT32) + zType = dnnl::memory::data_type::s32; + else if (z->dataType() == DataType::UINT8) + zType = dnnl::memory::data_type::u8; + else if (z->dataType() == DataType::INT8) + zType = dnnl::memory::data_type::s8; + + // memory descriptors for arrays + // x + dnnl::memory::desc x_mkl_md = + dnnl::memory::desc(xShape, xType, dnnl::memory::format_tag::any); + dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, xType, mkldnnUtils::getFormat(*x)); + mkldnnUtils::setBlockStrides(*x, x_user_md); + + // weights + dnnl::memory::desc weights_mkl_md = + dnnl::memory::desc(wShape, wType, dnnl::memory::format_tag::any); + dnnl::memory::desc weights_user_md = + dnnl::memory::desc(wShape, wType, mkldnnUtils::getFormat(*weights)); + mkldnnUtils::setBlockStrides(* + weights, + weights_user_md, bShouldTransp ? std::vector({1,0}) : std::vector()); + + // bias + dnnl::memory::desc bias_mkl_md = + dnnl::memory::desc(bShape, bType, dnnl::memory::format_tag::a); + dnnl::memory::desc bias_user_md = + dnnl::memory::desc(bShape, bType, dnnl::memory::format_tag::a); + mkldnnUtils::setBlockStrides(*bias, bias_user_md); + + // z + dnnl::memory::desc z_mkl_md = + dnnl::memory::desc(zShape, zType, dnnl::memory::format_tag::any); + dnnl::memory::desc z_user_md = dnnl::memory::desc(zShape, zType, mkldnnUtils::getFormat(*z)); + mkldnnUtils::setBlockStrides(*z, z_user_md); + + auto engine = + mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); + + // operation primitive description + dnnl::inner_product_forward::desc op_desc(dnnl::prop_kind::forward_inference, + x_mkl_md, weights_mkl_md, + bias_mkl_md, z_mkl_md); + + dnnl::inner_product_forward::primitive_desc op_prim_desc(op_desc, engine); + + // arguments (memory buffers) necessary for calculations + std::unordered_map args; + + dnnl::stream stream(engine); + + // provide memory buffers and check whether reorder is required + + // input + mkldnnUtils::loadDataToMklStream(*x, engine, stream, x_user_md, + op_prim_desc.src_desc(), args[DNNL_ARG_SRC]); + + // weights + mkldnnUtils::loadDataToMklStream(*weights, engine, stream, weights_user_md, + op_prim_desc.weights_desc(), + args[DNNL_ARG_WEIGHTS]); + + // bias + auto bias_mkl_mem = + dnnl::memory(bias_mkl_md, engine, const_cast(bias->buffer())); + args[DNNL_ARG_BIAS] = bias_mkl_mem; + + // z + auto z_user_mem = mkldnnUtils::loadDataToMklStream(*z, engine, stream, z_user_md, op_prim_desc.dst_desc(), + args[DNNL_ARG_DST] ); + + // run calculations + dnnl::inner_product_forward(op_prim_desc).execute(stream, args); + + // reorder outputs if necessary + if (op_prim_desc.dst_desc() != z_user_mem.get_desc()) + dnnl::reorder(args[DNNL_ARG_DST], z_user_mem).execute(stream, args[DNNL_ARG_DST], z_user_mem); + + stream.wait(); +} - REQUIRE_TRUE(1 == b->rankOf() && b->lengthOf() == z->sizeAt(-1), 0, "xw_plus_b MKL: Input bias vector should be 1D and have proper dimension 1x%i." - " But got rank %i, and got length %i instead %i.", z->sizeAt(-1), b->rankOf(), b->lengthOf(), z->sizeAt(-1)); - - // mkldnnInerPorductss - xwPlusBiasMKLDNN(x, w, b, z, bShouldTransp); - - return Status::OK(); - } - - PLATFORM_CHECK(xw_plus_b, ENGINE_CPU) { - - auto x = INPUT_VARIABLE(0); - auto w = INPUT_VARIABLE(1); - auto b = INPUT_VARIABLE(2); - auto z = OUTPUT_VARIABLE(0); +////////////////////////////////////////////////////////////////////// +static void xwPlusBiasBp(const NDArray* x, const NDArray* weights, + const NDArray* bias, const NDArray* dLdz, + NDArray* dLdx, NDArray* dLdw, NDArray* dLdb, + const bool bShouldTransp) { + // mkl works with following + // [M,K] x [N,K]^T + [N] = [M,N] + const auto xRank = x->rankOf(); + + // [M,K] x [K,N] = [M,N] + const int M = x->sizeAt(0); + const int K = x->sizeAt(1); // K == wK + const int N = dLdz->sizeAt(1); + // input dims + dnnl::memory::dims xShape = dnnl::memory::dims({M, K}); + dnnl::memory::dims wShape = dnnl::memory::dims({N, K}); + dnnl::memory::dims dLdzShape = dnnl::memory::dims({M, N}); + + dnnl::memory::dims bShape = dnnl::memory::dims({N}); + // output dims + dnnl::memory::dims dLdxShape = xShape; + dnnl::memory::dims dLdwShape = wShape; + + + dnnl::memory::data_type dataType = dnnl::memory::data_type::f32; + + // memory descriptors for arrays + // x + dnnl::memory::desc x_mkl_md = + dnnl::memory::desc(xShape, dataType, dnnl::memory::format_tag::any); + dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, dataType, mkldnnUtils::getFormat(*x)); + mkldnnUtils::setBlockStrides(*x, x_user_md); + + // weights + dnnl::memory::desc weights_mkl_md = + dnnl::memory::desc(wShape, dataType, dnnl::memory::format_tag::any); + dnnl::memory::desc weights_user_md = + dnnl::memory::desc(wShape, dataType, mkldnnUtils::getFormat(*weights)); + mkldnnUtils::setBlockStrides(* + weights, + weights_user_md, bShouldTransp ? std::vector({1,0}) : std::vector()); + + // bias + dnnl::memory::desc bias_mkl_md = + dnnl::memory::desc(bShape, dataType, dnnl::memory::format_tag::any); + dnnl::memory::desc bias_user_md = + dnnl::memory::desc(bShape, dataType, mkldnnUtils::getFormat(*bias)); + mkldnnUtils::setBlockStrides(*bias, bias_user_md); + + // dLdz + dnnl::memory::desc dLdz_mkl_md = + dnnl::memory::desc(dLdzShape, dataType, dnnl::memory::format_tag::any); + dnnl::memory::desc dLdz_user_md = + dnnl::memory::desc(dLdzShape, dataType, mkldnnUtils::getFormat(*dLdz)); + mkldnnUtils::setBlockStrides(*dLdz, dLdz_user_md); + + // dLdw + dnnl::memory::desc dLdw_mkl_md = dnnl::memory::desc(wShape, dataType, dnnl::memory::format_tag::any); + dnnl::memory::desc dLdw_user_md = + dnnl::memory::desc(wShape, dataType, mkldnnUtils::getFormat(*dLdw)); + mkldnnUtils::setBlockStrides(*dLdw, + dLdw_user_md, bShouldTransp ? std::vector({1,0}) : std::vector()); + + // dLdb + dnnl::memory::desc dLdb_mkl_md = + dnnl::memory::desc(bShape, dataType, dnnl::memory::format_tag::any); + dnnl::memory::desc dLdb_user_md = + dnnl::memory::desc(bShape, dataType, mkldnnUtils::getFormat(*dLdb)); + mkldnnUtils::setBlockStrides(*dLdb, dLdb_user_md); + + // dLdx + dnnl::memory::desc dLdx_mkl_md = + dnnl::memory::desc(xShape, dataType, dnnl::memory::format_tag::any); + dnnl::memory::desc dLdx_user_md = + dnnl::memory::desc(xShape, dataType, mkldnnUtils::getFormat(*dLdx)); + mkldnnUtils::setBlockStrides(*dLdx, dLdx_user_md); + + // create engine + auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); + // forward + // operation primitive description + dnnl::inner_product_forward::desc op_ff_desc( + dnnl::prop_kind::forward_inference, x_mkl_md, weights_mkl_md, bias_mkl_md,dLdz_mkl_md); + dnnl::inner_product_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine); + + // backprob + // dLdw + auto op_bpdw_desc = inner_product_backward_weights::desc( + x_mkl_md, dLdw_mkl_md, dLdb_mkl_md, dLdz_mkl_md); + auto op_bpdw_prim_desc = inner_product_backward_weights::primitive_desc( + op_bpdw_desc, engine, op_ff_prim_desc); + + // backprob + // dLdx + auto op_bpdx_desc = inner_product_backward_data::desc( + dLdx_mkl_md, weights_mkl_md, dLdz_mkl_md); + auto op_bpdx_prim_desc = inner_product_backward_data::primitive_desc( + op_bpdx_desc, engine, op_ff_prim_desc); + + // arguments (memory buffers) necessary for calculations + std::unordered_map argsDw, argsDx; + + dnnl::stream stream(engine); + + // dLdz dw + mkldnnUtils::loadDataToMklStream(*dLdz, engine, stream, dLdz_user_md, + op_bpdw_prim_desc.diff_dst_desc(), + argsDw[DNNL_ARG_DIFF_DST]); + + // dLdz - dx + mkldnnUtils::loadDataToMklStream(*dLdz, engine, stream, dLdz_user_md, + op_bpdx_prim_desc.diff_dst_desc(), + argsDx[DNNL_ARG_DIFF_DST]); + + // input x for dw + mkldnnUtils::loadDataToMklStream(*x, engine, stream, x_user_md, + op_bpdw_prim_desc.src_desc(), + argsDw[DNNL_ARG_SRC]); + + // weights - dx + mkldnnUtils::loadDataToMklStream(*weights, engine, stream, weights_user_md, + op_bpdx_prim_desc.weights_desc(), + argsDx[DNNL_ARG_WEIGHTS]); + + // dLdw + auto dLdw_user_mem = mkldnnUtils::loadDataToMklStream(*dLdw, engine, stream, dLdw_user_md, + op_bpdw_prim_desc.diff_weights_desc() , + argsDw[DNNL_ARG_DIFF_WEIGHTS] ); + + // dLdx + auto dLdx_user_mem = mkldnnUtils::loadDataToMklStream(*dLdx, engine, stream, dLdx_user_md, + op_bpdx_prim_desc.diff_src_desc() , + argsDx[DNNL_ARG_DIFF_SRC] ); + + // dLdb + auto dLdb_user_mem = mkldnnUtils::loadDataToMklStream(*dLdb, engine, stream, dLdb_user_md, + op_bpdw_prim_desc.diff_bias_desc() , + argsDw[DNNL_ARG_DIFF_BIAS] ); + + // run calculations dw + dnnl::inner_product_backward_weights(op_bpdw_prim_desc) + .execute(stream, argsDw); + // run calculations dx + dnnl::inner_product_backward_data(op_bpdx_prim_desc).execute(stream, argsDx); + + // reorder outputs if necessary + if (op_bpdx_prim_desc.diff_src_desc() != dLdx_user_mem.get_desc()) + dnnl::reorder(argsDx[DNNL_ARG_DIFF_SRC], dLdx_user_mem) + .execute(stream, argsDx[DNNL_ARG_DIFF_SRC], dLdx_user_mem); + + if (op_bpdw_prim_desc.diff_weights_desc() != dLdw_user_mem.get_desc()) + dnnl::reorder(argsDw[DNNL_ARG_DIFF_WEIGHTS], dLdw_user_mem) + .execute(stream, argsDw[DNNL_ARG_DIFF_WEIGHTS], dLdw_user_mem); + + if (op_bpdw_prim_desc.diff_bias_desc() != dLdb_user_mem.get_desc()) + dnnl::reorder(argsDw[DNNL_ARG_DIFF_BIAS], dLdb_user_mem) + .execute(stream, argsDw[DNNL_ARG_DIFF_BIAS], dLdb_user_mem); + + stream.wait(); +} - const DataType xType = x->dataType(); - const DataType wType = w->dataType(); - const DataType bType = b->dataType(); - const DataType zType = z->dataType(); +PLATFORM_IMPL(xw_plus_b, ENGINE_CPU) { + auto x = INPUT_VARIABLE(0); + auto w = INPUT_VARIABLE(1); + auto b = INPUT_VARIABLE(2); + auto z = OUTPUT_VARIABLE(0); + + if (x->isEmpty() || w->isEmpty() || b->isEmpty()) return Status::OK(); + + const int xRank = x->rankOf(); + const int wRank = w->rankOf(); + const int zRank = z->rankOf(); + + const bool bShouldTransp = + block.numI() > 0 + ? (1 != INT_ARG(0)) + : true; // [M,K] * [K,N] -> [M, N], mkl -> [M,K] * [N, K]^T -> [M, N] + + REQUIRE_TRUE(xRank == 2, 0, + "xw_plus_b MKL: Input x array should have rank equal 2, but got " + "instead %i!", + xRank); + REQUIRE_TRUE(wRank == 2, 0, + "xw_plus_b MKL: Input weights array should have rank equal 2, " + "but got instead %i!", + wRank); + REQUIRE_TRUE(zRank == 2, 0, + "xw_plus_b MKL: Output array should have rank equal 2, but got " + "instead %i!", + zRank); + + REQUIRE_TRUE(1 == b->rankOf() && b->lengthOf() == z->sizeAt(-1), 0, + "xw_plus_b MKL: Input bias vector should be 1D and have proper " + "dimension 1x%i." + " But got rank %i, and got length %i instead %i.", + z->sizeAt(-1), b->rankOf(), b->lengthOf(), z->sizeAt(-1)); + + // mkldnnInerPorductss + xwPlusBiasMKLDNN(x, w, b, z, bShouldTransp); + + return Status::OK(); +} - /* - Source Weights Destination Bias - f32 f32 f32 f32 - u8, s8 s8 u8, s8, s32, f32 u8, s8, s32, f32 - */ - return block.isUseMKLDNN() && - ((xType == DataType::FLOAT32 && wType == DataType::FLOAT32 && bType == DataType::FLOAT32 && zType == DataType::FLOAT32) || - ( // x - (xType == DataType::UINT8 || xType == DataType::INT8) && - // w - (wType == DataType::UINT8 || wType == DataType::INT8) && - // b - (bType == DataType::UINT8 || bType == DataType::INT8 || bType == DataType::INT32 || bType == DataType::FLOAT32) && - // z - (zType == DataType::UINT8 || zType == DataType::INT8 || zType == DataType::INT32 || zType == DataType::FLOAT32) - )); - } +PLATFORM_CHECK(xw_plus_b, ENGINE_CPU) { + auto x = INPUT_VARIABLE(0); + auto w = INPUT_VARIABLE(1); + auto b = INPUT_VARIABLE(2); + auto z = OUTPUT_VARIABLE(0); + + const DataType xType = x->dataType(); + const DataType wType = w->dataType(); + const DataType bType = b->dataType(); + const DataType zType = z->dataType(); + + /* + Source Weights Destination Bias + f32 f32 f32 f32 + u8, s8 s8 u8, s8, s32, f32 u8, s8, s32, f32 + */ + return block.isUseMKLDNN() && + ((xType == DataType::FLOAT32 && wType == DataType::FLOAT32 && + bType == DataType::FLOAT32 && zType == DataType::FLOAT32) || + ( // x + (xType == DataType::UINT8 || xType == DataType::INT8) && + // w + (wType == DataType::UINT8 || wType == DataType::INT8) && + // b + (bType == DataType::UINT8 || bType == DataType::INT8 || + bType == DataType::INT32 || bType == DataType::FLOAT32) && + // z + (zType == DataType::UINT8 || zType == DataType::INT8 || + zType == DataType::INT32 || zType == DataType::FLOAT32))); +} - PLATFORM_IMPL(xw_plus_b_bp, ENGINE_CPU) { +PLATFORM_IMPL(xw_plus_b_bp, ENGINE_CPU) { + auto x = INPUT_VARIABLE(0); + auto w = INPUT_VARIABLE(1); + auto b = INPUT_VARIABLE(2); + auto dLdz = INPUT_VARIABLE(3); + + auto dLdx = OUTPUT_VARIABLE(0); + auto dLdw = OUTPUT_VARIABLE(1); + auto dLdb = OUTPUT_VARIABLE(2); + + if (x->isEmpty() || w->isEmpty() || b->isEmpty() || dLdz->isEmpty()) + return Status::OK(); + + const int xRank = x->rankOf(); + const int wRank = w->rankOf(); + const int dLdzRank = dLdz->rankOf(); + + const bool bShouldTransp = + block.numI() > 0 + ? (1 != INT_ARG(0)) + : true; // [M,K] * [K,N] -> [M, N], mkl -> [M,K] * [N, K]^T -> [M, N] + + REQUIRE_TRUE(x->rankOf() == 2, 0, + "xw_plus_b BP MKL: Input x array should have rank equal 2, but " + "got instead %i!", + x->rankOf()); + REQUIRE_TRUE(w->rankOf() == 2, 0, + "xw_plus_b BP MKL: Input weights array should have rank equal " + "2, but got instead %i!", + w->rankOf()); + REQUIRE_TRUE(dLdz->rankOf() == 2, 0, + "xw_plus_b BP MKL: Output array should have rank equal 2, but " + "got instead %i!", + dLdz->rankOf()); + REQUIRE_TRUE(1 == b->rankOf() && b->lengthOf() == dLdz->sizeAt(1), 0, + "xw_plus_b BP MKL: Input bias vector should be 1D and have " + "proper dimension 1x%i." + " But got rank %i, and got length %i instead %i.", + dLdz->sizeAt(1), b->rankOf(), b->lengthOf(), dLdz->sizeAt(1)); + + xwPlusBiasBp(x, w, b, dLdz, dLdx, dLdw, dLdb, bShouldTransp); + + return Status::OK(); +} - auto x = INPUT_VARIABLE(0); - auto w = INPUT_VARIABLE(1); - auto b = INPUT_VARIABLE(2); - auto dLdz = INPUT_VARIABLE(3); - - auto dLdx = OUTPUT_VARIABLE(0); - auto dLdw = OUTPUT_VARIABLE(1); - auto dLdb = OUTPUT_VARIABLE(2); - - if (x->isEmpty() || w->isEmpty() || b->isEmpty() || dLdz->isEmpty()) - return Status::OK(); - - const int xRank = x->rankOf(); - const int wRank = w->rankOf(); - const int dLdzRank = dLdz->rankOf(); - - const bool bShouldTransp = block.getIArguments()->size() > 0 ? (1 != INT_ARG(0)) : true; // [M,K] * [K,N] -> [M, N], mkl -> [M,K] * [N, K]^T -> [M, N] - - REQUIRE_TRUE(x->rankOf() == 2, 0, "xw_plus_b BP MKL: Input x array should have rank equal 2, but got instead %i!", x->rankOf()); - REQUIRE_TRUE(w->rankOf() == 2, 0, "xw_plus_b BP MKL: Input weights array should have rank equal 2, but got instead %i!", w->rankOf()); - REQUIRE_TRUE(dLdz->rankOf() == 2, 0, "xw_plus_b BP MKL: Output array should have rank equal 2, but got instead %i!", dLdz->rankOf()); - REQUIRE_TRUE(1 == b->rankOf() && b->lengthOf() == dLdz->sizeAt(1), 0, "xw_plus_b BP MKL: Input bias vector should be 1D and have proper dimension 1x%i." - " But got rank %i, and got length %i instead %i.", dLdz->sizeAt(1), b->rankOf(), b->lengthOf(), dLdz->sizeAt(1)); - - xwPlusBiasBp(x, w, b, dLdz, dLdx, dLdw, dLdb, bShouldTransp); - - return Status::OK(); - } - - PLATFORM_CHECK(xw_plus_b_bp, ENGINE_CPU) { - - auto x = INPUT_VARIABLE(0); - auto w = INPUT_VARIABLE(1); - auto b = INPUT_VARIABLE(2); - auto dLdz = INPUT_VARIABLE(3); - - auto dLdx = OUTPUT_VARIABLE(0); - auto dLdw = OUTPUT_VARIABLE(1); - auto dLdb = OUTPUT_VARIABLE(2); - - const DataType xType = x->dataType(); - const DataType wType = w->dataType(); - const DataType bType = b->dataType(); - const DataType dLdzType = dLdz->dataType(); - const DataType dLdxType = dLdx->dataType(); - const DataType dLdwType = dLdw->dataType(); - const DataType dLdbType = dLdb->dataType(); - - /* - Source Weights Destination Bias - f32 f32 f32 f32 - */ - return block.isUseMKLDNN() && - (xType == DataType::FLOAT32 && wType == DataType::FLOAT32 && - bType == DataType::FLOAT32 && dLdzType == DataType::FLOAT32 && - dLdbType == DataType::FLOAT32 && dLdxType == DataType::FLOAT32 && - dLdwType == DataType::FLOAT32); - } - - } - } +PLATFORM_CHECK(xw_plus_b_bp, ENGINE_CPU) { + auto x = INPUT_VARIABLE(0); + auto w = INPUT_VARIABLE(1); + auto b = INPUT_VARIABLE(2); + auto dLdz = INPUT_VARIABLE(3); + + auto dLdx = OUTPUT_VARIABLE(0); + auto dLdw = OUTPUT_VARIABLE(1); + auto dLdb = OUTPUT_VARIABLE(2); + + const DataType xType = x->dataType(); + const DataType wType = w->dataType(); + const DataType bType = b->dataType(); + const DataType dLdzType = dLdz->dataType(); + const DataType dLdxType = dLdx->dataType(); + const DataType dLdwType = dLdw->dataType(); + const DataType dLdbType = dLdb->dataType(); + + /* + Source Weights Destination Bias + f32 f32 f32 f32 + */ + return block.isUseMKLDNN() && + (xType == DataType::FLOAT32 && wType == DataType::FLOAT32 && + bType == DataType::FLOAT32 && dLdzType == DataType::FLOAT32 && + dLdbType == DataType::FLOAT32 && dLdxType == DataType::FLOAT32 && + dLdwType == DataType::FLOAT32); } + +} // namespace platforms +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/gemm.h b/libnd4j/include/ops/gemm.h index 23f1636a2cfd..cf99babd93c0 100644 --- a/libnd4j/include/ops/gemm.h +++ b/libnd4j/include/ops/gemm.h @@ -25,38 +25,40 @@ #include #include - namespace sd { - namespace blas { - template - static void * transpose(int orderSource, int orderTarget, int rows, int cols, void *source); - - static inline int linearIndexC(int rows, int cols, int r, int c); - static inline int linearIndexF(int rows, int cols, int r, int c); - - template - class GEMM { - protected: - public: - static void op(int Order, int TransA, int TransB, int M, int N, int K, double alpha, void *A, int lda, void *B, int ldb, double beta, void *C, int ldc); - }; +namespace blas { +template +static void *transpose(int orderSource, int orderTarget, int rows, int cols, + void *source); - template - class GEMV : public sd::blas::GEMM{ - public: - static void op(int TRANS, int M, int N, double alpha, void* vA, int lda, void* vX, int incx, double beta, void* vY, int incy ); - }; +static inline int linearIndexC(int rows, int cols, int r, int c); +static inline int linearIndexF(int rows, int cols, int r, int c); +template +class GEMM { + protected: + public: + static void op(int Order, int TransA, int TransB, int M, int N, int K, + double alpha, void *A, int lda, void *B, int ldb, double beta, + void *C, int ldc); +}; - int FORCEINLINE linearIndexC(int rows, int cols, int r, int c) { - return (r * cols + c); - } +template +class GEMV : public sd::blas::GEMM { + public: + static void op(int TRANS, int M, int N, double alpha, void *vA, int lda, + void *vX, int incx, double beta, void *vY, int incy); +}; - int FORCEINLINE linearIndexF(int rows, int cols, int r, int c) { - return (c * rows + r); - } +int FORCEINLINE linearIndexC(int rows, int cols, int r, int c) { + return (r * cols + c); +} - } +int FORCEINLINE linearIndexF(int rows, int cols, int r, int c) { + return (c * rows + r); } -#endif //LIBND4J_GEMM_H +} // namespace blas +} // namespace sd + +#endif // LIBND4J_GEMM_H diff --git a/libnd4j/include/ops/impl/BroadcastBoolOpsTuple.cpp b/libnd4j/include/ops/impl/BroadcastBoolOpsTuple.cpp index 7e903346b817..3b2fb741b61f 100644 --- a/libnd4j/include/ops/impl/BroadcastBoolOpsTuple.cpp +++ b/libnd4j/include/ops/impl/BroadcastBoolOpsTuple.cpp @@ -20,8 +20,10 @@ #include namespace sd { - BroadcastBoolOpsTuple BroadcastBoolOpsTuple::custom(sd::scalar::BoolOps scalar, sd::pairwise::BoolOps pairwise, sd::broadcast::BoolOps broadcast) { - BroadcastBoolOpsTuple t(scalar, pairwise, broadcast); - return t; - } +BroadcastBoolOpsTuple BroadcastBoolOpsTuple::custom( + sd::scalar::BoolOps scalar, sd::pairwise::BoolOps pairwise, + sd::broadcast::BoolOps broadcast) { + BroadcastBoolOpsTuple t(scalar, pairwise, broadcast); + return t; } +} // namespace sd diff --git a/libnd4j/include/ops/impl/BroadcastIntOpsTuple.cpp b/libnd4j/include/ops/impl/BroadcastIntOpsTuple.cpp index 5680b80569dd..d42887bc394d 100644 --- a/libnd4j/include/ops/impl/BroadcastIntOpsTuple.cpp +++ b/libnd4j/include/ops/impl/BroadcastIntOpsTuple.cpp @@ -20,8 +20,10 @@ #include namespace sd { - BroadcastIntOpsTuple BroadcastIntOpsTuple::custom(sd::scalar::IntOps scalar, sd::pairwise::IntOps pairwise, sd::broadcast::IntOps broadcast) { - BroadcastIntOpsTuple t(scalar, pairwise, broadcast); - return t; - } +BroadcastIntOpsTuple BroadcastIntOpsTuple::custom( + sd::scalar::IntOps scalar, sd::pairwise::IntOps pairwise, + sd::broadcast::IntOps broadcast) { + BroadcastIntOpsTuple t(scalar, pairwise, broadcast); + return t; } +} // namespace sd diff --git a/libnd4j/include/ops/impl/BroadcastOpsTuple.cpp b/libnd4j/include/ops/impl/BroadcastOpsTuple.cpp index 71afe82605f2..4fbee9db4f3c 100644 --- a/libnd4j/include/ops/impl/BroadcastOpsTuple.cpp +++ b/libnd4j/include/ops/impl/BroadcastOpsTuple.cpp @@ -20,47 +20,56 @@ #include namespace sd { - BroadcastOpsTuple BroadcastOpsTuple::custom(sd::scalar::Ops scalar, sd::pairwise::Ops pairwise, sd::broadcast::Ops broadcast) { - BroadcastOpsTuple t(scalar, pairwise, broadcast); - return t; - } - - BroadcastOpsTuple BroadcastOpsTuple::Add() { - return custom(sd::scalar::Add, sd::pairwise::Add, sd::broadcast::Add); - } - - BroadcastOpsTuple BroadcastOpsTuple::Assign() { - return custom(sd::scalar::CopyPws, sd::pairwise::CopyPws, sd::broadcast::CopyPws); - } +BroadcastOpsTuple BroadcastOpsTuple::custom(sd::scalar::Ops scalar, + sd::pairwise::Ops pairwise, + sd::broadcast::Ops broadcast) { + BroadcastOpsTuple t(scalar, pairwise, broadcast); + return t; +} - BroadcastOpsTuple BroadcastOpsTuple::Divide() { - return custom(sd::scalar::Divide, sd::pairwise::Divide, sd::broadcast::Divide); - } +BroadcastOpsTuple BroadcastOpsTuple::Add() { + return custom(sd::scalar::Add, sd::pairwise::Add, sd::broadcast::Add); +} - BroadcastOpsTuple BroadcastOpsTuple::DivideNoNan() { - return custom(sd::scalar::DivideNoNan, sd::pairwise::DivideNoNan, sd::broadcast::DivideNoNan); - } +BroadcastOpsTuple BroadcastOpsTuple::Assign() { + return custom(sd::scalar::CopyPws, sd::pairwise::CopyPws, + sd::broadcast::CopyPws); +} - BroadcastOpsTuple BroadcastOpsTuple::Multiply() { - return custom(sd::scalar::Multiply, sd::pairwise::Multiply, sd::broadcast::Multiply); - } +BroadcastOpsTuple BroadcastOpsTuple::Divide() { + return custom(sd::scalar::Divide, sd::pairwise::Divide, + sd::broadcast::Divide); +} - BroadcastOpsTuple BroadcastOpsTuple::Subtract() { - return custom(sd::scalar::Subtract, sd::pairwise::Subtract, sd::broadcast::Subtract); - } - BroadcastOpsTuple BroadcastOpsTuple::IGamma() { - return custom(sd::scalar::IGamma, sd::pairwise::IGamma, sd::broadcast::IGamma); - } - BroadcastOpsTuple BroadcastOpsTuple::IGammac() { - return custom(sd::scalar::IGammac, sd::pairwise::IGammac, sd::broadcast::IGammac); - } +BroadcastOpsTuple BroadcastOpsTuple::DivideNoNan() { + return custom(sd::scalar::DivideNoNan, sd::pairwise::DivideNoNan, + sd::broadcast::DivideNoNan); +} +BroadcastOpsTuple BroadcastOpsTuple::Multiply() { + return custom(sd::scalar::Multiply, sd::pairwise::Multiply, + sd::broadcast::Multiply); +} - BroadcastOpsTuple BroadcastOpsTuple::Pow() { - return custom(sd::scalar::Pow, sd::pairwise::Pow, sd::broadcast::Pow); - } - BroadcastOpsTuple BroadcastOpsTuple::PowDerivative() { - return custom(sd::scalar::PowDerivative, sd::pairwise::PowDerivative, sd::broadcast::PowDerivative); - } +BroadcastOpsTuple BroadcastOpsTuple::Subtract() { + return custom(sd::scalar::Subtract, sd::pairwise::Subtract, + sd::broadcast::Subtract); +} +BroadcastOpsTuple BroadcastOpsTuple::IGamma() { + return custom(sd::scalar::IGamma, sd::pairwise::IGamma, + sd::broadcast::IGamma); +} +BroadcastOpsTuple BroadcastOpsTuple::IGammac() { + return custom(sd::scalar::IGammac, sd::pairwise::IGammac, + sd::broadcast::IGammac); +} +BroadcastOpsTuple BroadcastOpsTuple::Pow() { + return custom(sd::scalar::Pow, sd::pairwise::Pow, sd::broadcast::Pow); +} +BroadcastOpsTuple BroadcastOpsTuple::PowDerivative() { + return custom(sd::scalar::PowDerivative, sd::pairwise::PowDerivative, + sd::broadcast::PowDerivative); } + +} // namespace sd diff --git a/libnd4j/include/ops/impl/gemm.cpp b/libnd4j/include/ops/impl/gemm.cpp index 8632ddcb9ef3..2e007c7e110b 100644 --- a/libnd4j/include/ops/impl/gemm.cpp +++ b/libnd4j/include/ops/impl/gemm.cpp @@ -19,132 +19,128 @@ // Modified by GS on 3/9/2018 // +#include #include -#include #include -#include +#include namespace sd { - namespace blas { - - template - void* transpose(int orderSource, int orderTarget, int rows, int cols, void *vsource) { - auto ret = new T[rows * cols]; - auto source = reinterpret_cast(vsource); - - // handle transpose in parallel - auto func = PRAGMA_THREADS_FOR { - for (auto r = start; r < stop; r++) { - for (int c = 0; c < cols; c++) { - int zIdx = orderTarget == CblasRowMajor ? linearIndexC(rows, cols, r, c) : linearIndexF(rows, cols, r, c); - int xIdx = orderSource == CblasColMajor ? linearIndexF(rows, cols, r, c) : linearIndexC(rows, cols, r, c); +namespace blas { + +template +void *transpose(int orderSource, int orderTarget, int rows, int cols, + void *vsource) { + auto ret = new T[rows * cols]; + auto source = reinterpret_cast(vsource); + + // handle transpose in parallel + auto func = PRAGMA_THREADS_FOR { + for (auto r = start; r < stop; r++) { + for (int c = 0; c < cols; c++) { + int zIdx = orderTarget == CblasRowMajor + ? linearIndexC(rows, cols, r, c) + : linearIndexF(rows, cols, r, c); + int xIdx = orderSource == CblasColMajor + ? linearIndexF(rows, cols, r, c) + : linearIndexC(rows, cols, r, c); + + ret[zIdx] = source[xIdx]; + } + } + }; - ret[zIdx] = source[xIdx]; - } - } - }; + samediff::Threads::parallel_for(func, 0, rows); - samediff::Threads::parallel_for(func, 0, rows); + return ret; +} - return ret; +template +void GEMM::op(int Order, int TransA, int TransB, int M, int N, int K, + double alpha, void *vA, int lda, void *vB, int ldb, + double beta, void *vC, int ldc) { + auto A = reinterpret_cast(vA); + auto B = reinterpret_cast(vB); + auto C = reinterpret_cast(vC); + + bool transAFlag = TransA == CblasTrans; + bool transBFlag = TransB == CblasTrans; + + if (beta == 0.0) { + Z z = 0.f; + int length = M * N; + if (length <= Environment::getInstance().elementwiseThreshold()) { + for (int r = 0; r < length; r++) C[r] = z; + } else { + auto func = PRAGMA_THREADS_FOR { + for (auto r = start; r < stop; r++) C[r] = z; + }; + samediff::Threads::parallel_for(func, 0, length); + } + } + + auto func = PRAGMA_THREADS_FOR_2D { + for (auto r = start_x; r < stop_x; r += inc_x) { + for (auto c = start_y; c < stop_y; c += inc_y) { + int zIdx = linearIndexF(M, N, r, c); + + Z dot = static_cast(0.0f); + + if (alpha != 0.0) { + int bIdx; // = linearIndexF(K, N, 0, c); + int aIdx; + + for (int k = 0; k < K; k++) { + aIdx = (transAFlag ? linearIndexC(M, K, r, k) + : linearIndexF(M, K, r, k)); + bIdx = (transBFlag ? linearIndexC(K, N, k, c) + : linearIndexF(K, N, k, c)); + dot += static_cast(alpha) * static_cast(A[aIdx]) * + static_cast(B[bIdx]); // A[aIdx]sd::math::nd4j_dot(aX, + // bX, K) * alpha; + } } - template - void GEMM::op(int Order, int TransA, int TransB, - int M, int N, int K, - double alpha, - void *vA, int lda, - void *vB, int ldb, - double beta, - void *vC, int ldc) { - - auto A = reinterpret_cast(vA); - auto B = reinterpret_cast(vB); - auto C = reinterpret_cast(vC); - - bool transAFlag = TransA == CblasTrans; - bool transBFlag = TransB == CblasTrans; - - if (beta == 0.0) { - Z z = 0.f; - int length = M*N; - if (length <= Environment::getInstance().elementwiseThreshold()) { - for (int r = 0; r < length; r++) - C[r] = z; - } else { - auto func = PRAGMA_THREADS_FOR { - for (auto r = start; r < stop; r++) - C[r] = z; - }; - samediff::Threads::parallel_for(func, 0, length); - } - } - - - auto func = PRAGMA_THREADS_FOR_2D { - for (auto r = start_x; r < stop_x; r += inc_x) { - for (auto c = start_y; c < stop_y; c += inc_y) { - int zIdx = linearIndexF(M, N, r, c); - - Z dot = static_cast(0.0f); - - if (alpha != 0.0) { - int bIdx; // = linearIndexF(K, N, 0, c); - int aIdx; - - for (int k = 0; k < K; k++) { - aIdx = (transAFlag ? linearIndexC(M, K, r, k) : linearIndexF(M, K, r, k)); - bIdx = (transBFlag ? linearIndexC(K, N, k, c) : linearIndexF(K, N, k, c)); - dot += static_cast(alpha) * static_cast(A[aIdx]) * static_cast(B[bIdx]);//A[aIdx]sd::math::nd4j_dot(aX, bX, K) * alpha; - } - } - - if (beta != 0.0) { - C[zIdx] = static_cast(dot + static_cast(beta) * C[zIdx]); - } else { - C[zIdx] = static_cast(dot); - } - } - } - }; - - samediff::Threads::parallel_for(func, 0, M, 1, 0, N, 1); + if (beta != 0.0) { + C[zIdx] = static_cast(dot + static_cast(beta) * C[zIdx]); + } else { + C[zIdx] = static_cast(dot); } + } + } + }; + samediff::Threads::parallel_for(func, 0, M, 1, 0, N, 1); +} - template - void GEMV::op(int TRANS, int M, int N, - double alpha, - void * vX, - int lda, - void* vY, - int incx, - double beta, - void* vZ, - int incy ) { - - auto x = reinterpret_cast(vX); - auto y = reinterpret_cast(vY); - auto z = reinterpret_cast(vZ); - - auto aT = TRANS == CblasTrans ? reinterpret_cast(sd::blas::transpose(CblasColMajor, CblasRowMajor, M, N, reinterpret_cast(x))) : x; - - auto func = PRAGMA_THREADS_FOR { - for (auto r = start; r < stop; r++) { - int aIdx = linearIndexC(M, N, r, 0); - auto aX = aT + aIdx; - - auto dot = sd::math::nd4j_dot(aX, y, lda) * static_cast(alpha); - z[r] = beta == 0.0f ? dot : dot + static_cast(beta) * z[r]; - } - }; - samediff::Threads::parallel_for(func, 0, M); - - if (TRANS == CblasTrans) - delete[] aT; - } - - //BUILD_TRIPLE_TEMPLATE(template class GEMV, , LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES); - //BUILD_TRIPLE_TEMPLATE(template class GEMM, , LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES); +template +void GEMV::op(int TRANS, int M, int N, double alpha, void *vX, int lda, + void *vY, int incx, double beta, void *vZ, int incy) { + auto x = reinterpret_cast(vX); + auto y = reinterpret_cast(vY); + auto z = reinterpret_cast(vZ); + + auto aT = TRANS == CblasTrans ? reinterpret_cast(sd::blas::transpose( + CblasColMajor, CblasRowMajor, M, N, + reinterpret_cast(x))) + : x; + + auto func = PRAGMA_THREADS_FOR { + for (auto r = start; r < stop; r++) { + int aIdx = linearIndexC(M, N, r, 0); + auto aX = aT + aIdx; + + auto dot = + sd::math::nd4j_dot(aX, y, lda) * static_cast(alpha); + z[r] = beta == 0.0f ? dot : dot + static_cast(beta) * z[r]; } + }; + samediff::Threads::parallel_for(func, 0, M); + + if (TRANS == CblasTrans) delete[] aT; } + +// BUILD_TRIPLE_TEMPLATE(template class GEMV, , LIBND4J_TYPES, FLOAT_TYPES, +// FLOAT_TYPES); BUILD_TRIPLE_TEMPLATE(template class GEMM, , LIBND4J_TYPES, +// FLOAT_TYPES, FLOAT_TYPES); +} // namespace blas +} // namespace sd diff --git a/libnd4j/include/ops/impl/specials_double.hpp b/libnd4j/include/ops/impl/specials_double.hpp index d219220ac36d..93f13fb7d07b 100644 --- a/libnd4j/include/ops/impl/specials_double.hpp +++ b/libnd4j/include/ops/impl/specials_double.hpp @@ -19,252 +19,296 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // - -#include -#include +#include +#include #include +#include +#include #include #include -#include -#include +#include #include -#include namespace sd { +template +void SpecialTypeConverter::convertGeneric(Nd4jPointer *extras, void *dx, + Nd4jLong N, void *dz) { + auto x = reinterpret_cast(dx); + auto z = reinterpret_cast(dz); - template - void SpecialTypeConverter::convertGeneric(Nd4jPointer * extras, void *dx, Nd4jLong N, void *dz) { - auto x = reinterpret_cast(dx); - auto z = reinterpret_cast(dz); - - - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - z[i] = static_cast(x[i]); - } - }; - - samediff::Threads::parallel_for(func, 0, N); - }; - - - template - void quickSort_parallel_internal_key(X* key, Nd4jLong const* xShapeInfo, Y* values, Nd4jLong const* yShapeInfo, int left, int right, int cutoff, bool descending) { - int i = left, j = right; - X ktmp; - X pivot = key[shape::getIndexOffset((left + right) / 2, xShapeInfo)]; - - Y vtmp; - - { - /* PARTITION PART */ - while (i <= j) { - if (descending) { - while (key[shape::getIndexOffset(i, xShapeInfo)] > pivot) - i++; - while (key[shape::getIndexOffset(j, xShapeInfo)] < pivot) - j--; - if (i <= j) { - ktmp = key[shape::getIndexOffset(i, xShapeInfo)]; - key[shape::getIndexOffset(i, xShapeInfo)] = key[shape::getIndexOffset(j, xShapeInfo)]; - key[shape::getIndexOffset(j, xShapeInfo)] = ktmp; - - vtmp = values[shape::getIndexOffset(i, yShapeInfo)]; - values[shape::getIndexOffset(i, yShapeInfo)] = values[shape::getIndexOffset(j, yShapeInfo)]; - values[shape::getIndexOffset(j, yShapeInfo)] = vtmp; - - i++; - j--; - } - } else { - while (key[shape::getIndexOffset(i, xShapeInfo)] < pivot) - i++; - while (key[shape::getIndexOffset(j, xShapeInfo)] > pivot) - j--; - if (i <= j) { - ktmp = key[shape::getIndexOffset(i, xShapeInfo)]; - key[shape::getIndexOffset(i, xShapeInfo)] = key[shape::getIndexOffset(j, xShapeInfo)]; - key[shape::getIndexOffset(j, xShapeInfo)] = ktmp; - - vtmp = values[shape::getIndexOffset(i, yShapeInfo)]; - values[shape::getIndexOffset(i, yShapeInfo)] = values[shape::getIndexOffset(j, yShapeInfo)]; - values[shape::getIndexOffset(j, yShapeInfo)] = vtmp; - - i++; - j--; - } - } - } - + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + z[i] = static_cast(x[i]); + } + }; + + samediff::Threads::parallel_for(func, 0, N); +}; + +template +void quickSort_parallel_internal_key(X *key, Nd4jLong const *xShapeInfo, + Y *values, Nd4jLong const *yShapeInfo, + int left, int right, int cutoff, + bool descending) { + int i = left, j = right; + X ktmp; + X pivot = key[shape::getIndexOffset((left + right) / 2, xShapeInfo)]; + + Y vtmp; + + { + /* PARTITION PART */ + while (i <= j) { + if (descending) { + while (key[shape::getIndexOffset(i, xShapeInfo)] > pivot) i++; + while (key[shape::getIndexOffset(j, xShapeInfo)] < pivot) j--; + if (i <= j) { + ktmp = key[shape::getIndexOffset(i, xShapeInfo)]; + key[shape::getIndexOffset(i, xShapeInfo)] = + key[shape::getIndexOffset(j, xShapeInfo)]; + key[shape::getIndexOffset(j, xShapeInfo)] = ktmp; + + vtmp = values[shape::getIndexOffset(i, yShapeInfo)]; + values[shape::getIndexOffset(i, yShapeInfo)] = + values[shape::getIndexOffset(j, yShapeInfo)]; + values[shape::getIndexOffset(j, yShapeInfo)] = vtmp; + + i++; + j--; } - - // - - if ( ((right-left) pivot) j--; + if (i <= j) { + ktmp = key[shape::getIndexOffset(i, xShapeInfo)]; + key[shape::getIndexOffset(i, xShapeInfo)] = + key[shape::getIndexOffset(j, xShapeInfo)]; + key[shape::getIndexOffset(j, xShapeInfo)] = ktmp; + + vtmp = values[shape::getIndexOffset(i, yShapeInfo)]; + values[shape::getIndexOffset(i, yShapeInfo)] = + values[shape::getIndexOffset(j, yShapeInfo)]; + values[shape::getIndexOffset(j, yShapeInfo)] = vtmp; + + i++; + j--; } + } } + } + // - template - void quickSort_parallel_internal_value(X* key, Nd4jLong const* xShapeInfo, Y* value, Nd4jLong const* yShapeInfo, int left, int right, int cutoff, bool descending) { - int i = left, j = right; - X ktmp; - Y pivot = value[shape::getIndexOffset((left + right) / 2, yShapeInfo)]; - - Y vtmp; - - { - /* PARTITION PART */ - while (i <= j) { - if (descending) { - while (value[shape::getIndexOffset(i, yShapeInfo)] > pivot) - i++; - while (value[shape::getIndexOffset(j, yShapeInfo)] < pivot) - j--; - if (i <= j) { - ktmp = key[shape::getIndexOffset(i, xShapeInfo)]; - key[shape::getIndexOffset(i, xShapeInfo)] = key[shape::getIndexOffset(j, xShapeInfo)]; - key[shape::getIndexOffset(j, xShapeInfo)] = ktmp; - - vtmp = value[shape::getIndexOffset(i, yShapeInfo)]; - value[shape::getIndexOffset(i, yShapeInfo)] = value[shape::getIndexOffset(j, yShapeInfo)]; - value[shape::getIndexOffset(j, yShapeInfo)] = vtmp; - - i++; - j--; - } - } else { - while (value[shape::getIndexOffset(i, yShapeInfo)] < pivot) - i++; - while (value[shape::getIndexOffset(j, yShapeInfo)] > pivot) - j--; - if (i <= j) { - ktmp = key[shape::getIndexOffset(i, xShapeInfo)]; - key[shape::getIndexOffset(i, xShapeInfo)] = key[shape::getIndexOffset(j, xShapeInfo)]; - key[shape::getIndexOffset(j, xShapeInfo)] = ktmp; - - vtmp = value[shape::getIndexOffset(i, yShapeInfo)]; - value[shape::getIndexOffset(i, yShapeInfo)] = value[shape::getIndexOffset(j, yShapeInfo)]; - value[shape::getIndexOffset(j, yShapeInfo)] = vtmp; - - i++; - j--; - } - } - } - - } - - // + if (((right - left) < cutoff)) { + if (left < j) { + quickSort_parallel_internal_key(key, xShapeInfo, values, yShapeInfo, left, + j, cutoff, descending); + } + if (i < right) { + quickSort_parallel_internal_key(key, xShapeInfo, values, yShapeInfo, i, + right, cutoff, descending); + } - if ( ((right-left) +void quickSort_parallel_internal_value(X *key, Nd4jLong const *xShapeInfo, + Y *value, Nd4jLong const *yShapeInfo, + int left, int right, int cutoff, + bool descending) { + int i = left, j = right; + X ktmp; + Y pivot = value[shape::getIndexOffset((left + right) / 2, yShapeInfo)]; + + Y vtmp; + + { + /* PARTITION PART */ + while (i <= j) { + if (descending) { + while (value[shape::getIndexOffset(i, yShapeInfo)] > pivot) i++; + while (value[shape::getIndexOffset(j, yShapeInfo)] < pivot) j--; + if (i <= j) { + ktmp = key[shape::getIndexOffset(i, xShapeInfo)]; + key[shape::getIndexOffset(i, xShapeInfo)] = + key[shape::getIndexOffset(j, xShapeInfo)]; + key[shape::getIndexOffset(j, xShapeInfo)] = ktmp; + + vtmp = value[shape::getIndexOffset(i, yShapeInfo)]; + value[shape::getIndexOffset(i, yShapeInfo)] = + value[shape::getIndexOffset(j, yShapeInfo)]; + value[shape::getIndexOffset(j, yShapeInfo)] = vtmp; + + i++; + j--; + } + } else { + while (value[shape::getIndexOffset(i, yShapeInfo)] < pivot) i++; + while (value[shape::getIndexOffset(j, yShapeInfo)] > pivot) j--; + if (i <= j) { + ktmp = key[shape::getIndexOffset(i, xShapeInfo)]; + key[shape::getIndexOffset(i, xShapeInfo)] = + key[shape::getIndexOffset(j, xShapeInfo)]; + key[shape::getIndexOffset(j, xShapeInfo)] = ktmp; + + vtmp = value[shape::getIndexOffset(i, yShapeInfo)]; + value[shape::getIndexOffset(i, yShapeInfo)] = + value[shape::getIndexOffset(j, yShapeInfo)]; + value[shape::getIndexOffset(j, yShapeInfo)] = vtmp; + + i++; + j--; } + } } + } + // - template - static void quickSort_parallel_key(void *varray, Nd4jLong const* xShapeInfo, void *yarray, Nd4jLong const* yShapeInfo, Nd4jLong lenArray, int numThreads, bool descending){ - auto array = reinterpret_cast(varray); - auto values = reinterpret_cast(yarray); - int cutoff = 1000; - - PRAGMA_OMP_PARALLEL_THREADS(numThreads) - { -PRAGMA_OMP_SINGLE_ARGS(nowait) - { - quickSort_parallel_internal_key(array, xShapeInfo, values, yShapeInfo, 0, lenArray-1, cutoff, descending); - } - } + if (((right - left) < cutoff)) { + if (left < j) { + quickSort_parallel_internal_value(key, xShapeInfo, value, yShapeInfo, + left, j, cutoff, descending); } - - template - static void quickSort_parallel_value(void *varray, Nd4jLong const* xShapeInfo, void *yarray, Nd4jLong const* yShapeInfo, Nd4jLong lenArray, int numThreads, bool descending){ - auto array = reinterpret_cast(varray); - auto values = reinterpret_cast(yarray); - int cutoff = 1000; - - PRAGMA_OMP_PARALLEL_THREADS(numThreads) - { -PRAGMA_OMP_SINGLE_ARGS(nowait) - { - quickSort_parallel_internal_value(array, xShapeInfo, values, yShapeInfo, 0, lenArray-1, cutoff, descending); - } - } + if (i < right) { + quickSort_parallel_internal_value(key, xShapeInfo, value, yShapeInfo, i, + right, cutoff, descending); } - template - void DoubleMethods::sortByKey(void *vx, Nd4jLong const* xShapeInfo, void *vy, Nd4jLong const* yShapeInfo, bool descending) { - quickSort_parallel_key(vx, xShapeInfo, vy, yShapeInfo, shape::length(xShapeInfo), omp_get_max_threads(), descending); + } else { + PRAGMA_OMP_TASK { + quickSort_parallel_internal_value(key, xShapeInfo, value, yShapeInfo, + left, j, cutoff, descending); } - - template - void DoubleMethods::sortByValue(void *vx, Nd4jLong const* xShapeInfo, void *vy, Nd4jLong const* yShapeInfo, bool descending) { - quickSort_parallel_value(vx, xShapeInfo, vy, yShapeInfo, shape::length(xShapeInfo), omp_get_max_threads(), descending); + PRAGMA_OMP_TASK { + quickSort_parallel_internal_value(key, xShapeInfo, value, yShapeInfo, i, + right, cutoff, descending); } + } +} - template - void DoubleMethods::sortTadByKey(void *vx, Nd4jLong const* xShapeInfo, void *vy, Nd4jLong const* yShapeInfo, int *dimension, int dimensionLength, bool descending) { - auto x = reinterpret_cast(vx); - auto y = reinterpret_cast(vy); - - auto packX = ConstantTadHelper::getInstance().tadForDimensions(xShapeInfo, dimension, dimensionLength); - auto packY = ConstantTadHelper::getInstance().tadForDimensions(yShapeInfo, dimension, dimensionLength); - - auto xLength = shape::length(xShapeInfo); - auto xTadLength = shape::length(packX.primaryShapeInfo()); - auto numTads = packX.numberOfTads(); - - auto func = PRAGMA_THREADS_FOR { - for (auto r = start; r < stop; r++) { - auto dx = x + packX.primaryOffsets()[r]; - auto dy = y + packY.primaryOffsets()[r]; - - quickSort_parallel_key(dx, packX.primaryShapeInfo(), dy, packY.primaryShapeInfo(), xTadLength, 1, descending); - } - }; - - samediff::Threads::parallel_tad(func, 0, numTads); +template +static void quickSort_parallel_key(void *varray, Nd4jLong const *xShapeInfo, + void *yarray, Nd4jLong const *yShapeInfo, + Nd4jLong lenArray, int numThreads, + bool descending) { + auto array = reinterpret_cast(varray); + auto values = reinterpret_cast(yarray); + int cutoff = 1000; + + PRAGMA_OMP_PARALLEL_THREADS(numThreads) { + PRAGMA_OMP_SINGLE_ARGS(nowait) { + quickSort_parallel_internal_key(array, xShapeInfo, values, yShapeInfo, 0, + lenArray - 1, cutoff, descending); } + } +} - template - void DoubleMethods::sortTadByValue(void *vx, Nd4jLong const* xShapeInfo, void *vy, Nd4jLong const* yShapeInfo, int *dimension, int dimensionLength, bool descending) { - auto x = reinterpret_cast(vx); - auto y = reinterpret_cast(vy); +template +static void quickSort_parallel_value(void *varray, Nd4jLong const *xShapeInfo, + void *yarray, Nd4jLong const *yShapeInfo, + Nd4jLong lenArray, int numThreads, + bool descending) { + auto array = reinterpret_cast(varray); + auto values = reinterpret_cast(yarray); + int cutoff = 1000; + + PRAGMA_OMP_PARALLEL_THREADS(numThreads) { + PRAGMA_OMP_SINGLE_ARGS(nowait) { + quickSort_parallel_internal_value(array, xShapeInfo, values, yShapeInfo, + 0, lenArray - 1, cutoff, descending); + } + } +} - auto packX = ConstantTadHelper::getInstance().tadForDimensions(xShapeInfo, dimension, dimensionLength); - auto packY = ConstantTadHelper::getInstance().tadForDimensions(yShapeInfo, dimension, dimensionLength); +template +void DoubleMethods::sortByKey(void *vx, Nd4jLong const *xShapeInfo, + void *vy, Nd4jLong const *yShapeInfo, + bool descending) { + quickSort_parallel_key(vx, xShapeInfo, vy, yShapeInfo, + shape::length(xShapeInfo), omp_get_max_threads(), + descending); +} - auto xLength = shape::length(xShapeInfo); - auto xTadLength = shape::length(packX.primaryShapeInfo()); - auto numTads = packX.numberOfTads(); +template +void DoubleMethods::sortByValue(void *vx, Nd4jLong const *xShapeInfo, + void *vy, Nd4jLong const *yShapeInfo, + bool descending) { + quickSort_parallel_value(vx, xShapeInfo, vy, yShapeInfo, + shape::length(xShapeInfo), + omp_get_max_threads(), descending); +} - auto func = PRAGMA_THREADS_FOR { - for (auto r = start; r < stop; r++) { - auto dx = x + packX.primaryOffsets()[r]; - auto dy = y + packY.primaryOffsets()[r]; +template +void DoubleMethods::sortTadByKey(void *vx, Nd4jLong const *xShapeInfo, + void *vy, Nd4jLong const *yShapeInfo, + int *dimension, int dimensionLength, + bool descending) { + auto x = reinterpret_cast(vx); + auto y = reinterpret_cast(vy); + + auto packX = ConstantTadHelper::getInstance().tadForDimensions( + xShapeInfo, dimension, dimensionLength); + auto packY = ConstantTadHelper::getInstance().tadForDimensions( + yShapeInfo, dimension, dimensionLength); + + auto xLength = shape::length(xShapeInfo); + auto xTadLength = shape::length(packX.primaryShapeInfo()); + auto numTads = packX.numberOfTads(); + + auto func = PRAGMA_THREADS_FOR { + for (auto r = start; r < stop; r++) { + auto dx = x + packX.primaryOffsets()[r]; + auto dy = y + packY.primaryOffsets()[r]; + + quickSort_parallel_key(dx, packX.primaryShapeInfo(), dy, + packY.primaryShapeInfo(), xTadLength, 1, + descending); + } + }; - quickSort_parallel_value(dx, packX.primaryShapeInfo(), dy, packY.primaryShapeInfo(), xTadLength, 1, descending); - } - }; + samediff::Threads::parallel_tad(func, 0, numTads); +} - samediff::Threads::parallel_tad(func, 0, numTads); +template +void DoubleMethods::sortTadByValue(void *vx, Nd4jLong const *xShapeInfo, + void *vy, Nd4jLong const *yShapeInfo, + int *dimension, int dimensionLength, + bool descending) { + auto x = reinterpret_cast(vx); + auto y = reinterpret_cast(vy); + + auto packX = ConstantTadHelper::getInstance().tadForDimensions( + xShapeInfo, dimension, dimensionLength); + auto packY = ConstantTadHelper::getInstance().tadForDimensions( + yShapeInfo, dimension, dimensionLength); + + auto xLength = shape::length(xShapeInfo); + auto xTadLength = shape::length(packX.primaryShapeInfo()); + auto numTads = packX.numberOfTads(); + + auto func = PRAGMA_THREADS_FOR { + for (auto r = start; r < stop; r++) { + auto dx = x + packX.primaryOffsets()[r]; + auto dy = y + packY.primaryOffsets()[r]; + + quickSort_parallel_value(dx, packX.primaryShapeInfo(), dy, + packY.primaryShapeInfo(), xTadLength, 1, + descending); } -} + }; + samediff::Threads::parallel_tad(func, 0, numTads); +} +} // namespace sd diff --git a/libnd4j/include/ops/impl/specials_single.hpp b/libnd4j/include/ops/impl/specials_single.hpp index b6d717b83b17..f1cf52cae0c1 100644 --- a/libnd4j/include/ops/impl/specials_single.hpp +++ b/libnd4j/include/ops/impl/specials_single.hpp @@ -19,24 +19,24 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // - -#include -#include +#include +#include #include +#include +#include #include #include -#include -#include +#include #include -#include namespace sd { /** -* Concatneate multi array of the same shape together -* along a particular dimension -*/ + * Concatneate multi array of the same shape together + * along a particular dimension + */ // template -// void SpecialMethods::concatCpuGeneric(const std::vector& inArrs, NDArray& output, const int axis) { +// void SpecialMethods::concatCpuGeneric(const std::vector& inArrs, +// NDArray& output, const int axis) { // const uint numOfArrs = inArrs.size(); // int outDim; @@ -45,17 +45,20 @@ namespace sd { // if(isOutputVector || (axis == 0 && output.ordering() == 'c')) { // bool allVectorsOrScalars = true; -// const uint outEws = isOutputVector ? output.stridesOf()[outDim] : output.ews(); +// const uint outEws = isOutputVector ? output.stridesOf()[outDim] : +// output.ews(); // std::vector nonUnityDim(numOfArrs); // std::vector zOffset(numOfArrs); // for(int i = 0; i < numOfArrs; i++) { -// allVectorsOrScalars &= (inArrs[i]->lengthOf() == 1 || inArrs[i]->isCommonVector(nonUnityDim[i])); +// allVectorsOrScalars &= (inArrs[i]->lengthOf() == 1 || +// inArrs[i]->isCommonVector(nonUnityDim[i])); // if(!allVectorsOrScalars) // break; // if(i == 0) zOffset[0] = 0; -// else zOffset[i] = zOffset[i - 1] + outEws * inArrs[i - 1]->lengthOf(); +// else zOffset[i] = zOffset[i - 1] + outEws * inArrs[i - +// 1]->lengthOf(); // } // if(allVectorsOrScalars) { @@ -65,7 +68,8 @@ namespace sd { // auto func = PRAGMA_THREADS_FOR { // for (auto r = start; r < stop; r += increment) { // const Nd4jLong arrLen = inArrs[r]->lengthOf(); -// const uint xEws = (arrLen == 1) ? 1 : inArrs[r]->stridesOf()[nonUnityDim[r]]; +// const uint xEws = (arrLen == 1) ? 1 : +// inArrs[r]->stridesOf()[nonUnityDim[r]]; // T *z = outBuff + zOffset[r]; // T *x = inArrs[r]->bufferAsT(); @@ -86,21 +90,26 @@ namespace sd { // const int rank = inArrs[0]->rankOf(); // const int rank2 = 2*rank; -// std::vector> indices(numOfArrs, std::vector(rank2,0)); +// std::vector> indices(numOfArrs, +// std::vector(rank2,0)); // // take into account indices for first array // indices[0][2 * axis + 1] = inArrs[0]->sizeAt(axis); // // loop through the rest of input arrays // for(int i = 1; i < numOfArrs; ++i) { -// indices[i][2 * axis] = indices[i-1][2 * axis + 1]; // index start from -// indices[i][2 * axis + 1] = indices[i-1][2 * axis + 1] + inArrs[i]->sizeAt(axis); // index end with (excluding) +// indices[i][2 * axis] = indices[i-1][2 * axis + 1]; // index +// start from indices[i][2 * axis + 1] = indices[i-1][2 * axis + 1] +// + inArrs[i]->sizeAt(axis); // index end with (excluding) // } // auto func = PRAGMA_THREADS_FOR { // for (auto i = start; i < stop; i += increment) { // auto temp = output(indices[i], true); -// sd::TransformLoops::template loopTransform>( inArrs[i]->bufferAsT(), inArrs[i]->shapeInfo(), temp.bufferAsT(), temp.shapeInfo(), nullptr, 0, 1); +// sd::TransformLoops::template +// loopTransform>( +// inArrs[i]->bufferAsT(), inArrs[i]->shapeInfo(), +// temp.bufferAsT(), temp.shapeInfo(), nullptr, 0, 1); // } // }; @@ -127,227 +136,238 @@ namespace sd { template -void SpecialMethods::concatCpuGeneric(const std::vector& inArrs, NDArray& output, const int axis) { - - const int numOfInArrs = inArrs.size(); - const auto sizeofT = output.sizeOfT(); - - T* zBuff = output.bufferAsT(); - - bool luckCase1 = ((axis == 0 && output.ordering() == 'c') || (axis == output.rankOf() - 1 && output.ordering() == 'f')) && output.ews() == 1; - - if(luckCase1) { - for (uint i = 0; i < numOfInArrs; ++i) { - luckCase1 &= inArrs[i]->ordering() == output.ordering() && inArrs[i]->ews() == 1; - if(!luckCase1) - break; - } +void SpecialMethods::concatCpuGeneric( + const std::vector &inArrs, NDArray &output, + const int axis) { + const int numOfInArrs = inArrs.size(); + const auto sizeofT = output.sizeOfT(); + + T *zBuff = output.bufferAsT(); + + bool luckCase1 = + ((axis == 0 && output.ordering() == 'c') || + (axis == output.rankOf() - 1 && output.ordering() == 'f')) && + output.ews() == 1; + + if (luckCase1) { + for (uint i = 0; i < numOfInArrs; ++i) { + luckCase1 &= + inArrs[i]->ordering() == output.ordering() && inArrs[i]->ews() == 1; + if (!luckCase1) break; } + } - if(luckCase1) { // for example {1,10} + {2,10} + {3,10} = {6, 10} order c; or {10,1} + {10,2} + {10,3} = {10, 6} order f + if (luckCase1) { // for example {1,10} + {2,10} + {3,10} = {6, 10} order c; + // or {10,1} + {10,2} + {10,3} = {10, 6} order f - T* z = zBuff; - for (uint i = 0; i < numOfInArrs; ++i) { - const auto memAmountToCopy = inArrs[i]->lengthOf(); - memcpy(z, inArrs[i]->bufferAsT(), memAmountToCopy * sizeofT); - z += memAmountToCopy; - } - return; + T *z = zBuff; + for (uint i = 0; i < numOfInArrs; ++i) { + const auto memAmountToCopy = inArrs[i]->lengthOf(); + memcpy(z, inArrs[i]->bufferAsT(), memAmountToCopy * sizeofT); + z += memAmountToCopy; } + return; + } - // const bool isZcontin = output.strideAt(axis) == 1; - // bool areInputsContin = true; - // bool allSameOrder = true; - // std::vector strideOfContigStride(numOfInArrs); - - // if(isZcontin) { + // const bool isZcontin = output.strideAt(axis) == 1; + // bool areInputsContin = true; + // bool allSameOrder = true; + // std::vector strideOfContigStride(numOfInArrs); - // for (uint i = 0; i < numOfInArrs; ++i) { + // if(isZcontin) { - // areInputsContin &= inArrs[i]->strideAt(axis) == 1; - // allSameOrder &= inArrs[i]->ordering() == output.ordering(); - // if(!areInputsContin || !allSameOrder) - // break; + // for (uint i = 0; i < numOfInArrs; ++i) { - // strideOfContigStride[i] = strideOverContigAxis(axis, inArrs[i]->getShapeInfo()); - // } - // } + // areInputsContin &= inArrs[i]->strideAt(axis) == 1; + // allSameOrder &= inArrs[i]->ordering() == output.ordering(); + // if(!areInputsContin || !allSameOrder) + // break; - // const bool luckCase2 = isZcontin && areInputsContin && allSameOrder; + // strideOfContigStride[i] = strideOverContigAxis(axis, + // inArrs[i]->getShapeInfo()); + // } + // } - // if(luckCase2) { // for example {2,1,3} + {2,5,3} + {2,10,3} = {2,16,3}, here axis 1 shoud have stride = 1 for all inputs arrays and output array + // const bool luckCase2 = isZcontin && areInputsContin && allSameOrder; - // const auto zStep = strideOverContigAxis(axis, output.getShapeInfo()); + // if(luckCase2) { // for example {2,1,3} + {2,5,3} + {2,10,3} = {2,16,3}, + // here axis 1 shoud have stride = 1 for all inputs arrays and output array - // for (uint i = 0; i < output.lengthOf() / output.sizeAt(axis); ++i) { + // const auto zStep = strideOverContigAxis(axis, + // output.getShapeInfo()); - // T* z = zBuff + zStep * i; + // for (uint i = 0; i < output.lengthOf() / output.sizeAt(axis); ++i) { - // for (uint j = 0; j < inArrs.size(); ++j) { - // const auto xDim = inArrs[j]->sizeAt(axis); - // const T* x = inArrs[j]->bufferAsT() + strideOfContigStride[j] * i; - // memcpy(z, x, xDim * sizeofT); - // z += xDim; - // } - // } + // T* z = zBuff + zStep * i; - // return; - // } - - // general case - auto func = PRAGMA_THREADS_FOR { + // for (uint j = 0; j < inArrs.size(); ++j) { + // const auto xDim = inArrs[j]->sizeAt(axis); + // const T* x = inArrs[j]->bufferAsT() + + // strideOfContigStride[j] * i; memcpy(z, x, xDim * sizeofT); z += + // xDim; + // } + // } - int coords[MAX_RANK], temp; + // return; + // } - for (auto i = start; i < stop; i += increment) { + // general case + auto func = PRAGMA_THREADS_FOR { + int coords[MAX_RANK], temp; - shape::index2coordsCPU(start, i, output.shapeInfo(), coords); + for (auto i = start; i < stop; i += increment) { + shape::index2coordsCPU(start, i, output.shapeInfo(), coords); - const auto zOffset = shape::getOffset(output.shapeInfo(), coords); + const auto zOffset = shape::getOffset(output.shapeInfo(), coords); - uint inArrIdx = 0; - uint xDim = inArrs[inArrIdx]->sizeAt(axis); + uint inArrIdx = 0; + uint xDim = inArrs[inArrIdx]->sizeAt(axis); - temp = coords[axis]; - while (coords[axis] >= xDim) { - coords[axis] -= xDim; - xDim = inArrs[++inArrIdx]->sizeAt(axis); - } + temp = coords[axis]; + while (coords[axis] >= xDim) { + coords[axis] -= xDim; + xDim = inArrs[++inArrIdx]->sizeAt(axis); + } - const T* x = inArrs[inArrIdx]->bufferAsT(); - const auto xOffset = shape::getOffset(inArrs[inArrIdx]->shapeInfo(), coords); + const T *x = inArrs[inArrIdx]->bufferAsT(); + const auto xOffset = + shape::getOffset(inArrs[inArrIdx]->shapeInfo(), coords); - zBuff[zOffset] = x[xOffset]; + zBuff[zOffset] = x[xOffset]; - coords[axis] = temp; - } - }; + coords[axis] = temp; + } + }; - samediff::Threads::parallel_for(func, 0, output.lengthOf()); + samediff::Threads::parallel_for(func, 0, output.lengthOf()); } /** -* Concatneate multi array of the same shape together -* along a particular dimension -*/ + * Concatneate multi array of the same shape together + * along a particular dimension + */ template -void SpecialMethods::concatCpuGeneric(int dimension, int numArrays, Nd4jPointer *data, Nd4jPointer *inputShapeInfo, void *vresult, Nd4jLong const* resultShapeInfo) { - auto result = reinterpret_cast(vresult); - std::vector inputs(numArrays); +void SpecialMethods::concatCpuGeneric(int dimension, int numArrays, + Nd4jPointer *data, + Nd4jPointer *inputShapeInfo, + void *vresult, + Nd4jLong const *resultShapeInfo) { + auto result = reinterpret_cast(vresult); + std::vector inputs(numArrays); - NDArray output(static_cast(result), resultShapeInfo); + NDArray output(static_cast(result), resultShapeInfo); - for(int i = 0; i < numArrays; ++i) - inputs[i] = new NDArray(static_cast(data[i]), static_cast(inputShapeInfo[i])); + for (int i = 0; i < numArrays; ++i) + inputs[i] = new NDArray(static_cast(data[i]), + static_cast(inputShapeInfo[i])); - sd::SpecialMethods::concatCpuGeneric(inputs, output, dimension); + sd::SpecialMethods::concatCpuGeneric(inputs, output, dimension); - for(int i = 0; i < numArrays; ++i) - delete inputs[i]; + for (int i = 0; i < numArrays; ++i) delete inputs[i]; } - template -void SpecialMethods::splitCpuGeneric(const NDArray& input, const std::vector& outArrs, const int axis) { - - int numSplits = outArrs.size(); +void SpecialMethods::splitCpuGeneric(const NDArray &input, + const std::vector &outArrs, + const int axis) { + int numSplits = outArrs.size(); - const auto sizeofT = input.sizeOfT(); + const auto sizeofT = input.sizeOfT(); - auto xBuff = input.bufferAsT(); + auto xBuff = input.bufferAsT(); - bool luckCase1 = ((axis == 0 && input.ordering() == 'c') || (axis == input.rankOf() - 1 && input.ordering() == 'f')) && input.ews() == 1; + bool luckCase1 = ((axis == 0 && input.ordering() == 'c') || + (axis == input.rankOf() - 1 && input.ordering() == 'f')) && + input.ews() == 1; - if (luckCase1) { - for (uint i = 0; i < numSplits; ++i) { - luckCase1 &= outArrs[i]->ordering() == input.ordering() && outArrs[i]->ews() == 1; - if (!luckCase1) - break; - } + if (luckCase1) { + for (uint i = 0; i < numSplits; ++i) { + luckCase1 &= + outArrs[i]->ordering() == input.ordering() && outArrs[i]->ews() == 1; + if (!luckCase1) break; } - - if (luckCase1) { - - T* x = const_cast(xBuff); - for (uint i = 0; i < numSplits; ++i) { - const auto memAmountToCopy = outArrs[i]->lengthOf(); - memcpy(outArrs[i]->bufferAsT(), x, memAmountToCopy * sizeofT); - x += memAmountToCopy; - } - return; + } + + if (luckCase1) { + T *x = const_cast(xBuff); + for (uint i = 0; i < numSplits; ++i) { + const auto memAmountToCopy = outArrs[i]->lengthOf(); + memcpy(outArrs[i]->bufferAsT(), x, memAmountToCopy * sizeofT); + x += memAmountToCopy; } + return; + } - // const bool isXcontin = input.strideAt(axis) == 1; - // bool areOutsContin = true; - // bool allSameOrder = true; - // std::vector strideOfContigStride(numSplits); - - // if (isXcontin) { + // const bool isXcontin = input.strideAt(axis) == 1; + // bool areOutsContin = true; + // bool allSameOrder = true; + // std::vector strideOfContigStride(numSplits); - // for (uint i = 0; i < numSplits; ++i) { + // if (isXcontin) { - // areOutsContin &= outArrs[i]->strideAt(axis) == 1; - // allSameOrder &= outArrs[i]->ordering() == input.ordering(); - // if (!areOutsContin || !allSameOrder) - // break; + // for (uint i = 0; i < numSplits; ++i) { - // strideOfContigStride[i] = shape::strideOverContigAxis(axis, outArrs[i]->shapeInfo()); - // } - // } + // areOutsContin &= outArrs[i]->strideAt(axis) == 1; + // allSameOrder &= outArrs[i]->ordering() == input.ordering(); + // if (!areOutsContin || !allSameOrder) + // break; - // const bool luckCase2 = isXcontin && areOutsContin && allSameOrder; + // strideOfContigStride[i] = shape::strideOverContigAxis(axis, + // outArrs[i]->shapeInfo()); + // } + // } - // if (luckCase2) { + // const bool luckCase2 = isXcontin && areOutsContin && allSameOrder; - // const auto xStep = shape::strideOverContigAxis(axis, input.shapeInfo()); + // if (luckCase2) { - // for (uint i = 0; i < input.lengthOf() / input.sizeAt(axis); ++i) { + // const auto xStep = shape::strideOverContigAxis(axis, + // input.shapeInfo()); - // T* x = xBuff + xStep * i; + // for (uint i = 0; i < input.lengthOf() / input.sizeAt(axis); ++i) { - // for (uint j = 0; j < numSplits; ++j) { - // const auto zDim = outArrs[j]->sizeAt(axis); - // T* z = outArrs[j]->bufferAsT() + strideOfContigStride[j] * i; - // memcpy(z, x, zDim * sizeofT); - // x += zDim; - // } - // } + // T* x = xBuff + xStep * i; - // return; - // } + // for (uint j = 0; j < numSplits; ++j) { + // const auto zDim = outArrs[j]->sizeAt(axis); + // T* z = outArrs[j]->bufferAsT() + strideOfContigStride[j] * + // i; memcpy(z, x, zDim * sizeofT); x += zDim; + // } + // } - uint zDim = outArrs[0]->sizeAt(axis); - // general case + // return; + // } - auto func = PRAGMA_THREADS_FOR{ + uint zDim = outArrs[0]->sizeAt(axis); + // general case - int coords[MAX_RANK], temp; + auto func = PRAGMA_THREADS_FOR { + int coords[MAX_RANK], temp; - for (auto i = start; i < stop; i += increment) { + for (auto i = start; i < stop; i += increment) { + shape::index2coordsCPU(start, i, input.shapeInfo(), coords); + const auto xOffset = shape::getOffset(input.shapeInfo(), coords); - shape::index2coordsCPU(start, i, input.shapeInfo(), coords); - const auto xOffset = shape::getOffset(input.shapeInfo(), coords); + uint outArrIdx = 0; + temp = coords[axis]; - uint outArrIdx = 0; - temp = coords[axis]; + while (coords[axis] >= zDim) { + coords[axis] -= zDim; + ++outArrIdx; + } - while (coords[axis] >= zDim) { - coords[axis] -= zDim; - ++outArrIdx; - } + T *z = outArrs[outArrIdx]->bufferAsT(); + const auto zOffset = + shape::getOffset(outArrs[outArrIdx]->shapeInfo(), coords); + z[zOffset] = xBuff[xOffset]; - T* z = outArrs[outArrIdx]->bufferAsT(); - const auto zOffset = shape::getOffset(outArrs[outArrIdx]->shapeInfo(), coords); - z[zOffset] = xBuff[xOffset]; - - coords[axis] = temp; - } - }; + coords[axis] = temp; + } + }; - samediff::Threads::parallel_for(func, 0, input.lengthOf()); + samediff::Threads::parallel_for(func, 0, input.lengthOf()); } - /** * This kernel accumulates X arrays, and stores result into Z * @@ -357,22 +377,23 @@ void SpecialMethods::splitCpuGeneric(const NDArray& input, const std::vector< * @param n * @param length */ - template - void SpecialMethods::accumulateGeneric(void **vx, void *vz, Nd4jLong const* zShapeInfo, int n, const Nd4jLong length) { - auto z = reinterpret_cast(vz); - auto x = reinterpret_cast(vx); - - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - for (auto ar = 0L; ar < n; ar++) { - z[i] += x[ar][i]; - } - } - }; - - samediff::Threads::parallel_for(func, 0, length); +template +void SpecialMethods::accumulateGeneric(void **vx, void *vz, + Nd4jLong const *zShapeInfo, int n, + const Nd4jLong length) { + auto z = reinterpret_cast(vz); + auto x = reinterpret_cast(vx); + + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + for (auto ar = 0L; ar < n; ar++) { + z[i] += x[ar][i]; + } } + }; + samediff::Threads::parallel_for(func, 0, length); +} /** * This kernel averages X input arrays, and stores result to Z @@ -384,277 +405,296 @@ void SpecialMethods::splitCpuGeneric(const NDArray& input, const std::vector< * @param length * @param propagate */ - template - void SpecialMethods::averageGeneric(void **vx, void *vz, Nd4jLong const* zShapeInfo, int n, const Nd4jLong length, bool propagate) { - auto z = reinterpret_cast(vz); - auto x = reinterpret_cast(vx); - - if (z == nullptr) { - //code branch for absent Z - z = x[0]; - - PRAGMA_OMP_SIMD - for (uint64_t i = 0; i < length; i++) { - z[i] /= static_cast(n); - } - - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - for (Nd4jLong ar = 1; ar < n; ar++) { - z[i] += x[ar][i] / static_cast(n); - } - } - }; - samediff::Threads::parallel_for(func, 0, length); - - // instead of doing element-wise propagation, we just issue memcpy to propagate data - for (Nd4jLong ar = 1; ar < n; ar++) { - memcpy(x[ar], z, length * sizeof(T)); - } - } else { - // code branch for existing Z - - // memset before propagation - memset(z, 0, length * sizeof(T)); - - // aggregation step - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - for (Nd4jLong ar = 0; ar < n; ar++) { - z[i] += x[ar][i] / static_cast(n); - } - } - }; - samediff::Threads::parallel_for(func, 0, length); - - // instead of doing element-wise propagation, we just issue memcpy to propagate data - for (Nd4jLong ar = 0; ar < n; ar++) { - memcpy(x[ar], z, length * sizeof(T)); - } - } +template +void SpecialMethods::averageGeneric(void **vx, void *vz, + Nd4jLong const *zShapeInfo, int n, + const Nd4jLong length, bool propagate) { + auto z = reinterpret_cast(vz); + auto x = reinterpret_cast(vx); + + if (z == nullptr) { + // code branch for absent Z + z = x[0]; + + PRAGMA_OMP_SIMD + for (uint64_t i = 0; i < length; i++) { + z[i] /= static_cast(n); } - template - Nd4jLong SpecialMethods::getPosition(Nd4jLong const* xShapeInfo, Nd4jLong index) { - auto xEWS = shape::elementWiseStride(xShapeInfo); + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + for (Nd4jLong ar = 1; ar < n; ar++) { + z[i] += x[ar][i] / static_cast(n); + } + } + }; + samediff::Threads::parallel_for(func, 0, length); - if (xEWS == 1) - return index; - else if (xEWS > 1) - return index * xEWS; - else - return shape::getIndexOffset(index, xShapeInfo); + // instead of doing element-wise propagation, we just issue memcpy to + // propagate data + for (Nd4jLong ar = 1; ar < n; ar++) { + memcpy(x[ar], z, length * sizeof(T)); } + } else { + // code branch for existing Z - template - void SpecialMethods::quickSort_parallel_internal(T* array, Nd4jLong const* xShapeInfo, int left, int right, int cutoff, bool descending) { - - int i = left, j = right; - T tmp; - T pivot = array[getPosition(xShapeInfo, (left + right) / 2)]; - - - { - /* PARTITION PART */ - while (i <= j) { - if (descending) { - while (array[getPosition(xShapeInfo, i)] > pivot) - i++; - while (array[getPosition(xShapeInfo, j)] < pivot) - j--; - if (i <= j) { - tmp = array[getPosition(xShapeInfo, i)]; - array[getPosition(xShapeInfo, i)] = array[getPosition(xShapeInfo, j)]; - array[getPosition(xShapeInfo, j)] = tmp; - i++; - j--; - } - } else { - while (array[getPosition(xShapeInfo, i)] < pivot) - i++; - while (array[getPosition(xShapeInfo, j)] > pivot) - j--; - if (i <= j) { - tmp = array[getPosition(xShapeInfo, i)]; - array[getPosition(xShapeInfo, i)] = array[getPosition(xShapeInfo, j)]; - array[getPosition(xShapeInfo, j)] = tmp; - i++; - j--; - } - } - } + // memset before propagation + memset(z, 0, length * sizeof(T)); + // aggregation step + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + for (Nd4jLong ar = 0; ar < n; ar++) { + z[i] += x[ar][i] / static_cast(n); } + } + }; + samediff::Threads::parallel_for(func, 0, length); - // + // instead of doing element-wise propagation, we just issue memcpy to + // propagate data + for (Nd4jLong ar = 0; ar < n; ar++) { + memcpy(x[ar], z, length * sizeof(T)); + } + } +} - if ( ((right-left) +Nd4jLong SpecialMethods::getPosition(Nd4jLong const *xShapeInfo, + Nd4jLong index) { + auto xEWS = shape::elementWiseStride(xShapeInfo); + + if (xEWS == 1) + return index; + else if (xEWS > 1) + return index * xEWS; + else + return shape::getIndexOffset(index, xShapeInfo); +} - }else{ -PRAGMA_OMP_TASK - { quickSort_parallel_internal(array, xShapeInfo, left, j, cutoff, descending); } -PRAGMA_OMP_TASK - { quickSort_parallel_internal(array, xShapeInfo, i, right, cutoff, descending); } +template +void SpecialMethods::quickSort_parallel_internal(T *array, + Nd4jLong const *xShapeInfo, + int left, int right, + int cutoff, + bool descending) { + int i = left, j = right; + T tmp; + T pivot = array[getPosition(xShapeInfo, (left + right) / 2)]; + + { + /* PARTITION PART */ + while (i <= j) { + if (descending) { + while (array[getPosition(xShapeInfo, i)] > pivot) i++; + while (array[getPosition(xShapeInfo, j)] < pivot) j--; + if (i <= j) { + tmp = array[getPosition(xShapeInfo, i)]; + array[getPosition(xShapeInfo, i)] = array[getPosition(xShapeInfo, j)]; + array[getPosition(xShapeInfo, j)] = tmp; + i++; + j--; } - } - - template - void SpecialMethods::quickSort_parallel(void *varray, Nd4jLong const* xShapeInfo, Nd4jLong lenArray, int numThreads, bool descending){ - auto array = reinterpret_cast(varray); - int cutoff = 1000; - - PRAGMA_OMP_PARALLEL_THREADS(numThreads) - { -PRAGMA_OMP_SINGLE_ARGS(nowait) - { - quickSort_parallel_internal(array, xShapeInfo, 0, lenArray-1, cutoff, descending); - } + } else { + while (array[getPosition(xShapeInfo, i)] < pivot) i++; + while (array[getPosition(xShapeInfo, j)] > pivot) j--; + if (i <= j) { + tmp = array[getPosition(xShapeInfo, i)]; + array[getPosition(xShapeInfo, i)] = array[getPosition(xShapeInfo, j)]; + array[getPosition(xShapeInfo, j)] = tmp; + i++; + j--; } - + } } + } + // - - template - int SpecialMethods::nextPowerOf2(int number) { - int pos = 0; - - while (number > 0) { - pos++; - number = number >> 1; - } - return (int) pow(2, pos); + if (((right - left) < cutoff)) { + if (left < j) { + quickSort_parallel_internal(array, xShapeInfo, left, j, cutoff, + descending); + } + if (i < right) { + quickSort_parallel_internal(array, xShapeInfo, i, right, cutoff, + descending); } - template - int SpecialMethods::lastPowerOf2(int number) { - int p = 1; - while (p <= number) - p <<= 1; - - p >>= 1; - return p; + } else { + PRAGMA_OMP_TASK { + quickSort_parallel_internal(array, xShapeInfo, left, j, cutoff, + descending); + } + PRAGMA_OMP_TASK { + quickSort_parallel_internal(array, xShapeInfo, i, right, cutoff, + descending); } + } +} +template +void SpecialMethods::quickSort_parallel(void *varray, + Nd4jLong const *xShapeInfo, + Nd4jLong lenArray, int numThreads, + bool descending) { + auto array = reinterpret_cast(varray); + int cutoff = 1000; + + PRAGMA_OMP_PARALLEL_THREADS(numThreads) { + PRAGMA_OMP_SINGLE_ARGS(nowait) { + quickSort_parallel_internal(array, xShapeInfo, 0, lenArray - 1, cutoff, + descending); + } + } +} - template - void SpecialMethods::sortGeneric(void *vx, Nd4jLong const* xShapeInfo, bool descending) { - auto x = reinterpret_cast(vx); +template +int SpecialMethods::nextPowerOf2(int number) { + int pos = 0; + + while (number > 0) { + pos++; + number = number >> 1; + } + return (int)pow(2, pos); +} - quickSort_parallel(x, xShapeInfo, shape::length(xShapeInfo), omp_get_max_threads(), descending); - } +template +int SpecialMethods::lastPowerOf2(int number) { + int p = 1; + while (p <= number) p <<= 1; - template - void SpecialMethods::sortTadGeneric(void *vx, Nd4jLong const* xShapeInfo, int *dimension, int dimensionLength, Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, bool descending) { - auto x = reinterpret_cast(vx); + p >>= 1; + return p; +} - //quickSort_parallel(x, xShapeInfo, shape::length(xShapeInfo), omp_get_max_threads(), descending); - Nd4jLong xLength = shape::length(xShapeInfo); - Nd4jLong xTadLength = shape::tadLength(xShapeInfo, dimension, dimensionLength); - int numTads = xLength / xTadLength; +template +void SpecialMethods::sortGeneric(void *vx, Nd4jLong const *xShapeInfo, + bool descending) { + auto x = reinterpret_cast(vx); - auto func = PRAGMA_THREADS_FOR { - for (auto r = start; r < stop; r++) { - T *dx = x + tadOffsets[r]; + quickSort_parallel(x, xShapeInfo, shape::length(xShapeInfo), + omp_get_max_threads(), descending); +} - quickSort_parallel(dx, tadShapeInfo, xTadLength, 1, descending); - } - }; - samediff::Threads::parallel_tad(func, 0, numTads); +template +void SpecialMethods::sortTadGeneric(void *vx, Nd4jLong const *xShapeInfo, + int *dimension, int dimensionLength, + Nd4jLong const *tadShapeInfo, + Nd4jLong const *tadOffsets, + bool descending) { + auto x = reinterpret_cast(vx); + + // quickSort_parallel(x, xShapeInfo, shape::length(xShapeInfo), + // omp_get_max_threads(), descending); + Nd4jLong xLength = shape::length(xShapeInfo); + Nd4jLong xTadLength = + shape::tadLength(xShapeInfo, dimension, dimensionLength); + int numTads = xLength / xTadLength; + + auto func = PRAGMA_THREADS_FOR { + for (auto r = start; r < stop; r++) { + T *dx = x + tadOffsets[r]; + + quickSort_parallel(dx, tadShapeInfo, xTadLength, 1, descending); } + }; + samediff::Threads::parallel_tad(func, 0, numTads); +} +template +void SpecialMethods::decodeBitmapGeneric(const void *dx, Nd4jLong N, + void *vz, + Nd4jLong const *zShapeInfo) { + auto dz = reinterpret_cast(vz); + auto x = reinterpret_cast(dx); + Nd4jLong lim = N / 16 + 5; + + FloatBits2 fb; + fb.i_ = x[2]; + float threshold = fb.f_; + + auto pPos = -1; + + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + const auto v = x[e]; + for (int bitId = 0; bitId < 16; bitId++) { + bool hasBit = (v & 1 << (bitId)) != 0; + bool hasSign = (v & 1 << (bitId + 16)) != 0; + auto cPos = (e - 4) * 16 + bitId; + + if (hasBit) { + if (hasSign) + dz[cPos] -= static_cast(threshold); + else + dz[cPos] += static_cast(threshold); + } else if (hasSign) { + dz[cPos] -= static_cast(threshold / 2); + } - template - void SpecialMethods::decodeBitmapGeneric(const void *dx, Nd4jLong N, void *vz, Nd4jLong const* zShapeInfo) { - auto dz = reinterpret_cast(vz); - auto x = reinterpret_cast(dx); - Nd4jLong lim = N / 16 + 5; - - FloatBits2 fb; - fb.i_ = x[2]; - float threshold = fb.f_; - - auto pPos = -1; - - auto func = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) { - const auto v = x[e]; - for (int bitId = 0; bitId < 16; bitId++) { - bool hasBit = (v & 1 << (bitId)) != 0; - bool hasSign = (v & 1 << (bitId + 16)) != 0; - auto cPos = (e - 4) * 16 + bitId; - - if (hasBit) { - if (hasSign) - dz[cPos] -= static_cast(threshold); - else - dz[cPos] += static_cast(threshold); - } else if (hasSign) { - dz[cPos] -= static_cast(threshold / 2); - } - - pPos = cPos; - } - } - }; - - samediff::Threads::parallel_for(func, 4, lim); + pPos = cPos; + } } + }; - template - Nd4jLong SpecialMethods::encodeBitmapGeneric(void *vx, Nd4jLong const* xShapeInfo, Nd4jLong N, int *dz, float threshold) { - auto dx = reinterpret_cast(vx); - const T two(2.0f); - const T zero(0.0f); - const T t(threshold); - const T thalf = t / two; + samediff::Threads::parallel_for(func, 4, lim); +} + +template +Nd4jLong SpecialMethods::encodeBitmapGeneric(void *vx, + Nd4jLong const *xShapeInfo, + Nd4jLong N, int *dz, + float threshold) { + auto dx = reinterpret_cast(vx); + const T two(2.0f); + const T zero(0.0f); + const T t(threshold); + const T thalf = t / two; - //auto func = PRAGMA_REDUCE_LONG { - Nd4jLong retVal = 0L; + // auto func = PRAGMA_REDUCE_LONG { + Nd4jLong retVal = 0L; - PRAGMA_OMP_PARALLEL_FOR_REDUCTION(+:retVal) - for (auto x = 0; x < N; x += 16) { - int byte = 0; - int byteId = x / 16 + 4; + PRAGMA_OMP_PARALLEL_FOR_REDUCTION(+ : retVal) + for (auto x = 0; x < N; x += 16) { + int byte = 0; + int byteId = x / 16 + 4; - for (int f = 0; f < 16; f++) { - auto e = x + f; + for (int f = 0; f < 16; f++) { + auto e = x + f; - if (e >= N) - continue; + if (e >= N) continue; - T val = dx[e]; - T abs = sd::math::nd4j_abs(val); + T val = dx[e]; + T abs = sd::math::nd4j_abs(val); - int bitId = e % 16; + int bitId = e % 16; - if (abs >= t) { - byte |= 1 << (bitId); - retVal++; + if (abs >= t) { + byte |= 1 << (bitId); + retVal++; - if (val < zero) { - byte |= 1 << (bitId + 16); - dx[e] += t; - } else { - dx[e] -= t; - } - } else if (abs >= thalf && val < zero) { - byte |= 1 << (bitId + 16); - dx[e] += thalf; + if (val < zero) { + byte |= 1 << (bitId + 16); + dx[e] += t; + } else { + dx[e] -= t; + } + } else if (abs >= thalf && val < zero) { + byte |= 1 << (bitId + 16); + dx[e] += thalf; - retVal++; - } - } + retVal++; + } + } - dz[byteId] = byte; - } + dz[byteId] = byte; + } - return retVal; - //}; + return retVal; + //}; - //return samediff::Threads::parallel_long(func, LAMBDA_SUML, 0, N, 16); - } + // return samediff::Threads::parallel_long(func, LAMBDA_SUML, 0, N, 16); } - +} // namespace sd diff --git a/libnd4j/include/ops/impl/specials_sparse.cpp b/libnd4j/include/ops/impl/specials_sparse.cpp index c782ccf188a4..7de2d954a4e4 100644 --- a/libnd4j/include/ops/impl/specials_sparse.cpp +++ b/libnd4j/include/ops/impl/specials_sparse.cpp @@ -19,10 +19,10 @@ // #include -#include -#include #include #include +#include +#include #ifdef _OPENMP #include #endif @@ -30,193 +30,203 @@ #include namespace sd { - namespace sparse { - - template - void SparseUtils::printIndex(Nd4jLong *indices, int rank, int x) { - printf(" ["); - for (int e = 0; e < rank; e++) { - if (e > 0) - printf(", "); - - printf("%lld", (long long) indices[x * rank + e]); - } - printf("] "); - } - - template - bool SparseUtils::ltIndices(Nd4jLong *indices, int rank, Nd4jLong x, Nd4jLong y) { - for (int e = 0; e < rank; e++) { - Nd4jLong idxX = indices[x * rank + e]; - Nd4jLong idxY = indices[y * rank + e]; - // we're comparing indices one by one, starting from outer dimension - if (idxX < idxY) { - return true; - } else if (idxX == idxY) { - // do nothing, continue to next dimension - } else - return false; - } - - return false; - } +namespace sparse { - template - bool SparseUtils::gtIndices(Nd4jLong *indices, int rank, Nd4jLong x, Nd4jLong y) { - for (int e = 0; e < rank; e++) { - // we're comparing indices one by one, starting from outer dimension - Nd4jLong idxX = indices[x * rank + e]; - Nd4jLong idxY = indices[y * rank + e]; - if ( idxX > idxY) { - return true; - } else if (idxX == idxY) { - // do nothing, continue to next dimension - } else - return false; - } - return false; - } +template +void SparseUtils::printIndex(Nd4jLong *indices, int rank, int x) { + printf(" ["); + for (int e = 0; e < rank; e++) { + if (e > 0) printf(", "); - template - void SparseUtils::swapEverything(Nd4jLong *indices, T *array, int rank, Nd4jLong x, Nd4jLong y) { - // swap indices - for (int e = 0; e < rank; e++) { - Nd4jLong tmp = indices[x * rank + e]; - indices[x * rank + e] = indices[y * rank + e]; - indices[y * rank + e] = tmp; - } - - // swap values - T tmp = array[x]; - array[x] = array[y]; - array[y] = tmp; - } + printf("%lld", (long long)indices[x * rank + e]); + } + printf("] "); +} - template - Nd4jLong SparseUtils::coo_quickSort_findPivot(Nd4jLong *indices, T *array, Nd4jLong left, Nd4jLong right, - int rank) { - Nd4jLong mid = (left + right) / 2; +template +bool SparseUtils::ltIndices(Nd4jLong *indices, int rank, Nd4jLong x, + Nd4jLong y) { + for (int e = 0; e < rank; e++) { + Nd4jLong idxX = indices[x * rank + e]; + Nd4jLong idxY = indices[y * rank + e]; + // we're comparing indices one by one, starting from outer dimension + if (idxX < idxY) { + return true; + } else if (idxX == idxY) { + // do nothing, continue to next dimension + } else + return false; + } + + return false; +} - // ensure left < mid - if (ltIndices(indices, rank, mid, left)) { // ensure lo < mid - swapEverything(indices, array, rank, mid, left); - } +template +bool SparseUtils::gtIndices(Nd4jLong *indices, int rank, Nd4jLong x, + Nd4jLong y) { + for (int e = 0; e < rank; e++) { + // we're comparing indices one by one, starting from outer dimension + Nd4jLong idxX = indices[x * rank + e]; + Nd4jLong idxY = indices[y * rank + e]; + if (idxX > idxY) { + return true; + } else if (idxX == idxY) { + // do nothing, continue to next dimension + } else + return false; + } + return false; +} - // ensure left < right - if (ltIndices(indices, rank, right, left)) { - swapEverything(indices, array, rank, right, left); - } +template +void SparseUtils::swapEverything(Nd4jLong *indices, T *array, int rank, + Nd4jLong x, Nd4jLong y) { + // swap indices + for (int e = 0; e < rank; e++) { + Nd4jLong tmp = indices[x * rank + e]; + indices[x * rank + e] = indices[y * rank + e]; + indices[y * rank + e] = tmp; + } + + // swap values + T tmp = array[x]; + array[x] = array[y]; + array[y] = tmp; +} - // ensure mid < right - if (ltIndices(indices, rank, right, mid)) { - swapEverything(indices, array, rank, right, mid); - } +template +Nd4jLong SparseUtils::coo_quickSort_findPivot(Nd4jLong *indices, T *array, + Nd4jLong left, Nd4jLong right, + int rank) { + Nd4jLong mid = (left + right) / 2; + + // ensure left < mid + if (ltIndices(indices, rank, mid, left)) { // ensure lo < mid + swapEverything(indices, array, rank, mid, left); + } + + // ensure left < right + if (ltIndices(indices, rank, right, left)) { + swapEverything(indices, array, rank, right, left); + } + + // ensure mid < right + if (ltIndices(indices, rank, right, mid)) { + swapEverything(indices, array, rank, right, mid); + } + + // mid is the median of the 3, and is the optimal pivot point + return mid; +} - // mid is the median of the 3, and is the optimal pivot point - return mid; +template +void SparseUtils::coo_quickSort_parallel_internal(Nd4jLong *indices, + T *array, Nd4jLong left, + Nd4jLong right, int cutoff, + int rank) { + Nd4jLong span = right - left; // elements to be partitioned - 1 + + if (span == 1) { + // only 2 elements to partition. swap if needed and return directly without + // further sorting. + if (ltIndices(indices, rank, right, left)) { + swapEverything(indices, array, rank, left, right); } + return; + } + + // find optimal pivot and sort left < right < right + Nd4jLong pvt = coo_quickSort_findPivot(indices, array, left, right, rank); + + if (span == 2) { + // only 3 elements to partition. findPivot has already sorted them. no + // further sorting is needed. + return; + } + + // index that is greater than pivot - leftmost element is already partitioned + // because of findPivot. + Nd4jLong i = left + 1; + + // index that is smaller than pivot - rightmost element is already partitioned + // because of findPivot. + Nd4jLong j = right - 1; + + { + // flag that indicates that pivot index lies between i and j and *could* be + // swapped. + bool checkPivot = true; + /* PARTITION PART */ + while (i <= j) { + while (ltIndices(indices, rank, i, pvt)) i++; + + while (gtIndices(indices, rank, j, pvt)) j--; + + if (i <= j) { + if (i != j) { // swap can be fairly expensive. don't swap i -> i + swapEverything(indices, array, rank, i, j); + } - template - void SparseUtils::coo_quickSort_parallel_internal(Nd4jLong *indices, T* array, Nd4jLong left, Nd4jLong right, int cutoff, int rank) { - Nd4jLong span = right - left; // elements to be partitioned - 1 - - if (span == 1){ - // only 2 elements to partition. swap if needed and return directly without further sorting. - if (ltIndices(indices, rank, right, left)){ - swapEverything(indices, array, rank, left, right); - } - return; - } - - - // find optimal pivot and sort left < right < right - Nd4jLong pvt = coo_quickSort_findPivot(indices, array, left, right, rank); - - if (span == 2){ - // only 3 elements to partition. findPivot has already sorted them. no further sorting is needed. - return; - } - - // index that is greater than pivot - leftmost element is already partitioned because of findPivot. - Nd4jLong i = left + 1; - - // index that is smaller than pivot - rightmost element is already partitioned because of findPivot. - Nd4jLong j = right - 1; - - - { - // flag that indicates that pivot index lies between i and j and *could* be swapped. - bool checkPivot = true; - /* PARTITION PART */ - while (i <= j) { - while (ltIndices(indices, rank, i, pvt)) - i++; - - while (gtIndices(indices, rank, j, pvt)) - j--; - - - if (i <= j) { - if(i != j) { // swap can be fairly expensive. don't swap i -> i - swapEverything(indices, array, rank, i, j); - } - - // only check pivot if it hasn't already been swapped. - if (checkPivot) { - // check if we moved the pivot, if so, change pivot index accordingly - if (pvt == j) { - pvt = i; - checkPivot = false; - } else if (pvt == i) { - pvt = j; - checkPivot = false; - } - } - - i++; - j--; - } - } - - } - - if ( (span < cutoff) ){ - if (left < j){ coo_quickSort_parallel_internal(indices, array, left, j, cutoff, rank); } - if (i < right){ coo_quickSort_parallel_internal(indices, array, i, right, cutoff, rank); } - - }else{ -PRAGMA_OMP_TASK - { coo_quickSort_parallel_internal(indices, array, left, j, cutoff, rank); } -PRAGMA_OMP_TASK - { coo_quickSort_parallel_internal(indices, array, i, right, cutoff, rank); } - } - + // only check pivot if it hasn't already been swapped. + if (checkPivot) { + // check if we moved the pivot, if so, change pivot index accordingly + if (pvt == j) { + pvt = i; + checkPivot = false; + } else if (pvt == i) { + pvt = j; + checkPivot = false; + } } - template - void SparseUtils::coo_quickSort_parallel(Nd4jLong *indices, T* array, Nd4jLong lenArray, int numThreads, int rank){ + i++; + j--; + } + } + } - int cutoff = 1000; + if ((span < cutoff)) { + if (left < j) { + coo_quickSort_parallel_internal(indices, array, left, j, cutoff, rank); + } + if (i < right) { + coo_quickSort_parallel_internal(indices, array, i, right, cutoff, rank); + } - PRAGMA_OMP_PARALLEL_THREADS(numThreads) - { -PRAGMA_OMP_SINGLE_ARGS(nowait) - { - coo_quickSort_parallel_internal(indices, array, 0, lenArray-1, cutoff, rank); - } - } + } else { + PRAGMA_OMP_TASK { + coo_quickSort_parallel_internal(indices, array, left, j, cutoff, rank); + } + PRAGMA_OMP_TASK { + coo_quickSort_parallel_internal(indices, array, i, right, cutoff, rank); + } + } +} - } +template +void SparseUtils::coo_quickSort_parallel(Nd4jLong *indices, T *array, + Nd4jLong lenArray, int numThreads, + int rank) { + int cutoff = 1000; - template - void SparseUtils::sortCooIndicesGeneric(Nd4jLong *indices, T *values, Nd4jLong length, int rank) { + PRAGMA_OMP_PARALLEL_THREADS(numThreads) { + PRAGMA_OMP_SINGLE_ARGS(nowait) { + coo_quickSort_parallel_internal(indices, array, 0, lenArray - 1, cutoff, + rank); + } + } +} + +template +void SparseUtils::sortCooIndicesGeneric(Nd4jLong *indices, T *values, + Nd4jLong length, int rank) { #ifdef _OPENMP - coo_quickSort_parallel(indices, values, length, omp_get_max_threads(), rank); + coo_quickSort_parallel(indices, values, length, omp_get_max_threads(), rank); #else - coo_quickSort_parallel(indices, values, length, 1, rank); + coo_quickSort_parallel(indices, values, length, 1, rank); #endif - } - - BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT SparseUtils, , LIBND4J_TYPES); - } } + +BUILD_SINGLE_TEMPLATE(template class SD_EXPORT SparseUtils, , LIBND4J_TYPES); +} // namespace sparse +} // namespace sd diff --git a/libnd4j/include/ops/meta_ops.h b/libnd4j/include/ops/meta_ops.h index a2120477d72e..defa1d0853d3 100644 --- a/libnd4j/include/ops/meta_ops.h +++ b/libnd4j/include/ops/meta_ops.h @@ -18,157 +18,157 @@ #ifndef FUSED_OPS_H_ #define FUSED_OPS_H_ -#include -#include - #include +#include +#include namespace metaOps { - /** - * InvertedMetaOp shares the same idea as MetaOp, but op being applied to op.Y in pairwise/broadcast ops +/** + * InvertedMetaOp shares the same idea as MetaOp, but op being applied to op.Y + * in pairwise/broadcast ops + */ +template +class InvertedMetaOp { + public: + no_op_exec_special no_op_exec_special_cuda + + /* + * PREDICATE + */ + + // scalar, transform, reduce, indexreduce entry + op_def static T + op(T d1, T *params) { + /* + * We assume, that this method won't be EVER called + */ + printf("You should NEVER see this message in output\n"); + return (T)0.0f; + } + + // PWT, broadcast entry. Predicate can be only scalar, transform + op_def static T op(T d1, T d2, T *params) { + Nd4jPointer *wrap = reinterpret_cast(params); + T *paramsA = reinterpret_cast(wrap[0]); + T *paramsB = reinterpret_cast(wrap[1]); + + return OpTypeB::op(OpTypeA::op(d1, d2, paramsA), paramsB); + } + + /* + * POSTULATE + */ + + // will be called for reduce, reduce3 + op_def static T postProcess(T reduction, Nd4jLong n, T *params) { + /* + * We assume, that this method won't be EVER called + */ + printf("You should NEVER EVER see this message in output\n"); + + return (T)0.0f; + } +}; + +/** + * Special case here: MetaOp which consist of 2 operations. + * + * Predicate can be either scalar or transform, to process data before actual op + * call Postulate will be the scalar/transform, but will be applied to result of + * broadcast/reduce/reduce3 + */ +template +class MetaOp { + public: + no_op_exec_special no_op_exec_special_cuda + + /* + * PREDICATE + */ + + meta_def static T + startingValue(const T *input) { + return (T)0.0f; + } + + // scalar, transform, reduce, indexreduce entry + meta_def static T op(T d1, T *params) { + /* + * We assume, that params for MetaOp is a set of pointers to actual op A & B + * extraArgs */ - template - class InvertedMetaOp { - public: - no_op_exec_special - no_op_exec_special_cuda - - /* - * PREDICATE - */ - - // scalar, transform, reduce, indexreduce entry - op_def static T op(T d1, T *params) { - /* - * We assume, that this method won't be EVER called - */ - printf("You should NEVER see this message in output\n"); - return (T) 0.0f; - } - - // PWT, broadcast entry. Predicate can be only scalar, transform - op_def static T op(T d1, T d2, T *params) { - Nd4jPointer *wrap = reinterpret_cast (params); - T *paramsA = reinterpret_cast (wrap[0]); - T *paramsB = reinterpret_cast (wrap[1]); - - return OpTypeB::op(OpTypeA::op(d1, d2, paramsA), paramsB); - } - - /* - * POSTULATE - */ - - // will be called for reduce, reduce3 - op_def static T postProcess(T reduction, Nd4jLong n, T *params) { - /* - * We assume, that this method won't be EVER called - */ - printf("You should NEVER EVER see this message in output\n"); - - return (T) 0.0f; - } - }; - - - /** - * Special case here: MetaOp which consist of 2 operations. - * - * Predicate can be either scalar or transform, to process data before actual op call - * Postulate will be the scalar/transform, but will be applied to result of broadcast/reduce/reduce3 - */ - template - class MetaOp { - public: - no_op_exec_special - no_op_exec_special_cuda - - /* - * PREDICATE - */ - - meta_def static T startingValue(const T *input) { - return (T) 0.0f; - } - - // scalar, transform, reduce, indexreduce entry - meta_def static T op(T d1, T *params) { - /* - * We assume, that params for MetaOp is a set of pointers to actual op A & B extraArgs - */ - Nd4jPointer *wrap = reinterpret_cast (params); - T *paramsA = reinterpret_cast (wrap[0]); - T *paramsB = reinterpret_cast (wrap[1]); - - return OpTypeB::op(OpTypeA::op(d1, paramsA), paramsB); - } - - // PWT, broadcast entry. Predicate can be only scalar, transform - meta_def static T op(T d1, T d2, T *params) { - Nd4jPointer *wrap = reinterpret_cast (params); - T *paramsA = reinterpret_cast (wrap[0]); - T *paramsB = reinterpret_cast (wrap[1]); - - return OpTypeB::op(OpTypeA::op(d1, paramsA), d2, paramsB); - } - - /* - * POSTULATE - */ - - // will be called for reduce, reduce3 - meta_def static T postProcess(T reduction, Nd4jLong n, T *params) { - Nd4jPointer *wrap = reinterpret_cast (params); - T *paramsA = reinterpret_cast (wrap[0]); - T *paramsB = reinterpret_cast (wrap[1]); - - return OpTypeB::op(OpTypeA::postProcess(reduction, n, paramsA), paramsB); - } - }; - - - template - class ReduceMetaOp { - public: - no_op_exec_special - no_op_exec_special_cuda - - meta_def static T startingValue(const T *input) { - return OpTypeB::startingValue(input); - } - - meta_def static T merge(T old, T opOutput, T *params) { - Nd4jPointer *wrap = reinterpret_cast (params); -// T *paramsA = reinterpret_cast (wrap[0]); - T *paramsB = reinterpret_cast (wrap[1]); - - return OpTypeB::merge(old, opOutput, paramsB); - } - - meta_def static T update(T old, T opOutput, T *params) { - Nd4jPointer *wrap = reinterpret_cast (params); - //T *paramsA = reinterpret_cast (wrap[0]); - T *paramsB = reinterpret_cast (wrap[1]); - - return OpTypeB::update(old, opOutput, paramsB); - } - - meta_def static T op(T d1, T *params) { - Nd4jPointer *wrap = reinterpret_cast (params); - T *paramsA = reinterpret_cast (wrap[0]); - T *paramsB = reinterpret_cast (wrap[1]); - - return OpTypeB::op(OpTypeA::op(d1, paramsA), paramsB); - } - - meta_def static T postProcess(T reduction, Nd4jLong n, T *params) { - Nd4jPointer *wrap = reinterpret_cast (params); -// T *paramsA = reinterpret_cast (wrap[0]); - T *paramsB = reinterpret_cast (wrap[1]); - - return OpTypeB::postProcess(reduction, n, paramsB); - } - }; -} + Nd4jPointer *wrap = reinterpret_cast(params); + T *paramsA = reinterpret_cast(wrap[0]); + T *paramsB = reinterpret_cast(wrap[1]); + + return OpTypeB::op(OpTypeA::op(d1, paramsA), paramsB); + } + + // PWT, broadcast entry. Predicate can be only scalar, transform + meta_def static T op(T d1, T d2, T *params) { + Nd4jPointer *wrap = reinterpret_cast(params); + T *paramsA = reinterpret_cast(wrap[0]); + T *paramsB = reinterpret_cast(wrap[1]); + + return OpTypeB::op(OpTypeA::op(d1, paramsA), d2, paramsB); + } + + /* + * POSTULATE + */ + + // will be called for reduce, reduce3 + meta_def static T postProcess(T reduction, Nd4jLong n, T *params) { + Nd4jPointer *wrap = reinterpret_cast(params); + T *paramsA = reinterpret_cast(wrap[0]); + T *paramsB = reinterpret_cast(wrap[1]); + + return OpTypeB::op(OpTypeA::postProcess(reduction, n, paramsA), paramsB); + } +}; + +template +class ReduceMetaOp { + public: + no_op_exec_special no_op_exec_special_cuda + + meta_def static T + startingValue(const T *input) { + return OpTypeB::startingValue(input); + } + + meta_def static T merge(T old, T opOutput, T *params) { + Nd4jPointer *wrap = reinterpret_cast(params); + // T *paramsA = reinterpret_cast (wrap[0]); + T *paramsB = reinterpret_cast(wrap[1]); + + return OpTypeB::merge(old, opOutput, paramsB); + } + + meta_def static T update(T old, T opOutput, T *params) { + Nd4jPointer *wrap = reinterpret_cast(params); + // T *paramsA = reinterpret_cast (wrap[0]); + T *paramsB = reinterpret_cast(wrap[1]); + + return OpTypeB::update(old, opOutput, paramsB); + } + + meta_def static T op(T d1, T *params) { + Nd4jPointer *wrap = reinterpret_cast(params); + T *paramsA = reinterpret_cast(wrap[0]); + T *paramsB = reinterpret_cast(wrap[1]); + + return OpTypeB::op(OpTypeA::op(d1, paramsA), paramsB); + } + + meta_def static T postProcess(T reduction, Nd4jLong n, T *params) { + Nd4jPointer *wrap = reinterpret_cast(params); + // T *paramsA = reinterpret_cast (wrap[0]); + T *paramsB = reinterpret_cast(wrap[1]); + + return OpTypeB::postProcess(reduction, n, paramsB); + } +}; +} // namespace metaOps #endif \ No newline at end of file diff --git a/libnd4j/include/ops/ops.h b/libnd4j/include/ops/ops.h index aca6fec6f140..88d2d3b2dfc6 100644 --- a/libnd4j/include/ops/ops.h +++ b/libnd4j/include/ops/ops.h @@ -18,13 +18,14 @@ #ifndef OPS_H_ #define OPS_H_ -#include #include #include -#include -#include -#include #include +#include +#include +#include + +#include #define MIN_V 1e-12 #define MAX_FLOAT 1e37 @@ -37,21 +38,94 @@ #define DOUBLE_PI_T T(2.0 * 3.14159265358979323846) #define DOUBLE_PI_X X(2.0 * 3.14159265358979323846) -#define no_op_exec_special_any static const bool requiresSpecial = false; static void execSpecial(const X *dx, const Nd4jLong *xShapeBuffer, Z *result, const Nd4jLong *resultShapeBuffer, X *extraParams, const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) {} -#define no_op_exec_special_bool static const bool requiresSpecial = false; static void execSpecial(const X *dx, const Nd4jLong *xShapeBuffer, Z *result, const Nd4jLong *resultShapeBuffer, X *extraParams, const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) {} -#define no_op_exec_special_same static const bool requiresSpecial = false; static void execSpecial(const X *dx, const Nd4jLong *xShapeBuffer, X *result, const Nd4jLong *resultShapeBuffer, X *extraParams, const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) {} -#define no_op_exec_special static const bool requiresSpecial = false; static void execSpecial(const X *dx, const Nd4jLong *xShapeBuffer, Z *result, const Nd4jLong *resultShapeBuffer, Z *extraParams, const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) {} -#define no_op_exec_special_accumulation static const bool requiresSpecialAccumulation = false; static void execSpecial(const X *x, const Nd4jLong *xShapeInfo, Z *extraParams, Z *result, const Nd4jLong *resultShapeInfoBuffer, int *dimension, int dimensionLength, const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset){} -#define no_op_exec_special_accumulation_long static const bool requiresSpecialAccumulation = false; static void execSpecial(const X *x, const Nd4jLong *xShapeInfo, X *extraParams, Z *result, const Nd4jLong *resultShapeInfoBuffer, int *dimension, int dimensionLength, const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset){} -#define no_op_exec_special_accumulation_same static const bool requiresSpecialAccumulation = false; static void execSpecial(const X *x, const Nd4jLong *xShapeInfo, X *extraParams, X *result, const Nd4jLong *resultShapeInfoBuffer, int *dimension, int dimensionLength, const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset){} +#define no_op_exec_special_any \ + static const bool requiresSpecial = false; \ + static void execSpecial(const X *dx, const Nd4jLong *xShapeBuffer, \ + Z *result, const Nd4jLong *resultShapeBuffer, \ + X *extraParams, const Nd4jLong *tadShapeInfo, \ + const Nd4jLong *tadOffsets) {} +#define no_op_exec_special_bool \ + static const bool requiresSpecial = false; \ + static void execSpecial(const X *dx, const Nd4jLong *xShapeBuffer, \ + Z *result, const Nd4jLong *resultShapeBuffer, \ + X *extraParams, const Nd4jLong *tadShapeInfo, \ + const Nd4jLong *tadOffsets) {} +#define no_op_exec_special_same \ + static const bool requiresSpecial = false; \ + static void execSpecial(const X *dx, const Nd4jLong *xShapeBuffer, \ + X *result, const Nd4jLong *resultShapeBuffer, \ + X *extraParams, const Nd4jLong *tadShapeInfo, \ + const Nd4jLong *tadOffsets) {} +#define no_op_exec_special \ + static const bool requiresSpecial = false; \ + static void execSpecial(const X *dx, const Nd4jLong *xShapeBuffer, \ + Z *result, const Nd4jLong *resultShapeBuffer, \ + Z *extraParams, const Nd4jLong *tadShapeInfo, \ + const Nd4jLong *tadOffsets) {} +#define no_op_exec_special_accumulation \ + static const bool requiresSpecialAccumulation = false; \ + static void execSpecial( \ + const X *x, const Nd4jLong *xShapeInfo, Z *extraParams, Z *result, \ + const Nd4jLong *resultShapeInfoBuffer, int *dimension, \ + int dimensionLength, const Nd4jLong *tadShapeInfo, \ + const Nd4jLong *tadOffset) {} +#define no_op_exec_special_accumulation_long \ + static const bool requiresSpecialAccumulation = false; \ + static void execSpecial( \ + const X *x, const Nd4jLong *xShapeInfo, X *extraParams, Z *result, \ + const Nd4jLong *resultShapeInfoBuffer, int *dimension, \ + int dimensionLength, const Nd4jLong *tadShapeInfo, \ + const Nd4jLong *tadOffset) {} +#define no_op_exec_special_accumulation_same \ + static const bool requiresSpecialAccumulation = false; \ + static void execSpecial( \ + const X *x, const Nd4jLong *xShapeInfo, X *extraParams, X *result, \ + const Nd4jLong *resultShapeInfoBuffer, int *dimension, \ + int dimensionLength, const Nd4jLong *tadShapeInfo, \ + const Nd4jLong *tadOffset) {} #ifdef __CUDACC__ -#define no_op_exec_special_any_cuda static __device__ void execSpecialCuda(const X *dx, const Nd4jLong *xShapeBuffer, Z *result, const Nd4jLong *resultShapeBuffer, X *extraParams, int *allocationPointer, Z *reductionPointer, const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) {} -#define no_op_exec_special_bool_cuda static __device__ void execSpecialCuda(const X *dx, const Nd4jLong *xShapeBuffer, Z *result, const Nd4jLong *resultShapeBuffer, X *extraParams, int *allocationPointer, Z *reductionPointer, const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) {} -#define no_op_exec_special_same_cuda static __device__ void execSpecialCuda(const X *dx, const Nd4jLong *xShapeBuffer, X *result, const Nd4jLong *resultShapeBuffer, X *extraParams, int *allocationPointer, X *reductionPointer, const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) {} -#define no_op_exec_special_cuda static __device__ void execSpecialCuda(const X *dx, const Nd4jLong *xShapeBuffer,Z *result, const Nd4jLong *resultShapeBuffer,Z *extraParams, int *allocationPointer, Z *reductionPointer, const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) {} -#define no_op_exec_special_accumulation_same_cuda static inline __device__ void execSpecialCuda(const X *dx, const Nd4jLong *xShapeInfo, X *extraParams, X *result, const Nd4jLong *resultShapeInfo, int *dimension, int dimensionLength, X *reductionBuffer, const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets) {} -#define no_op_exec_special_accumulation_long_cuda static inline __device__ void execSpecialCuda(const X *dx, const Nd4jLong *xShapeInfo, X *extraParams, Z *result, const Nd4jLong *resultShapeInfo, int *dimension, int dimensionLength, Z *reductionBuffer, const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets) {} -#define no_op_exec_special_accumulation_cuda static inline __device__ void execSpecialCuda(const X *dx, const Nd4jLong *xShapeInfo, Z *extraParams, Z *result, const Nd4jLong *resultShapeInfo, int *dimension, int dimensionLength, Z *reductionBuffer, const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets) {} +#define no_op_exec_special_any_cuda \ + static __device__ void execSpecialCuda( \ + const X *dx, const Nd4jLong *xShapeBuffer, Z *result, \ + const Nd4jLong *resultShapeBuffer, X *extraParams, \ + int *allocationPointer, Z *reductionPointer, \ + const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) {} +#define no_op_exec_special_bool_cuda \ + static __device__ void execSpecialCuda( \ + const X *dx, const Nd4jLong *xShapeBuffer, Z *result, \ + const Nd4jLong *resultShapeBuffer, X *extraParams, \ + int *allocationPointer, Z *reductionPointer, \ + const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) {} +#define no_op_exec_special_same_cuda \ + static __device__ void execSpecialCuda( \ + const X *dx, const Nd4jLong *xShapeBuffer, X *result, \ + const Nd4jLong *resultShapeBuffer, X *extraParams, \ + int *allocationPointer, X *reductionPointer, \ + const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) {} +#define no_op_exec_special_cuda \ + static __device__ void execSpecialCuda( \ + const X *dx, const Nd4jLong *xShapeBuffer, Z *result, \ + const Nd4jLong *resultShapeBuffer, Z *extraParams, \ + int *allocationPointer, Z *reductionPointer, \ + const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) {} +#define no_op_exec_special_accumulation_same_cuda \ + static inline __device__ void execSpecialCuda( \ + const X *dx, const Nd4jLong *xShapeInfo, X *extraParams, X *result, \ + const Nd4jLong *resultShapeInfo, int *dimension, int dimensionLength, \ + X *reductionBuffer, const Nd4jLong *tadOnlyShapeInfo, \ + const Nd4jLong *tadOffsets) {} +#define no_op_exec_special_accumulation_long_cuda \ + static inline __device__ void execSpecialCuda( \ + const X *dx, const Nd4jLong *xShapeInfo, X *extraParams, Z *result, \ + const Nd4jLong *resultShapeInfo, int *dimension, int dimensionLength, \ + Z *reductionBuffer, const Nd4jLong *tadOnlyShapeInfo, \ + const Nd4jLong *tadOffsets) {} +#define no_op_exec_special_accumulation_cuda \ + static inline __device__ void execSpecialCuda( \ + const X *dx, const Nd4jLong *xShapeInfo, Z *extraParams, Z *result, \ + const Nd4jLong *resultShapeInfo, int *dimension, int dimensionLength, \ + Z *reductionBuffer, const Nd4jLong *tadOnlyShapeInfo, \ + const Nd4jLong *tadOffsets) {} #else // hacky fix for isnan/being being out of scope @@ -73,4602 +147,4243 @@ #define no_op_exec_special_accumulation_same_cuda #endif - #define SELU_ALPHA 1.6732632423543772848170429916717 #define SELU_LAMBDA 1.0507009873554804934193349852946 - namespace functions { - namespace indexreduce { - template - struct IndexValue { - T value; - Nd4jLong index; - _CUDA_HD IndexValue() = default; - _CUDA_HD IndexValue(const T val, const Nd4jLong ind): index(ind), value(val) {} - }; - } - - namespace summarystats { - template - class SummaryStatsData; - } +namespace indexreduce { +template +struct IndexValue { + T value; + Nd4jLong index; + _CUDA_HD IndexValue() = default; + _CUDA_HD IndexValue(const T val, const Nd4jLong ind) + : index(ind), value(val) {} +}; +} // namespace indexreduce + +namespace summarystats { +template +class SummaryStatsData; } +} // namespace functions namespace simdOps { - template - class Add { - public: - op_def static Z op(X d1, Y d2) { - return static_cast(d1 + d2); - } - - op_def static Z op(X d1, Y d2, Z *params) { - return static_cast(d1 + d2); - } - - op_def static Z op(X d1) { - return static_cast(d1); - } - - // op for MetaOps - op_def static Z op(X d1, Y *params) { - return static_cast(d1 + params[0]); - } - - op_def static X startingValue() { - return static_cast(0.f); - } - }; - - template - class NewAdd { - public: - op_def static X op(X d1, Y d2, X *params) { - return d1 + d2; - } - }; - - template - class Subtract { - public: - op_def static Z op(X d1, Y d2) { - return static_cast(d1 - d2); - } - - op_def static Z op(X d1, Y d2, Z *params) { - return static_cast(d1 - d2); - } - - op_def static Z op(X d1) { - return static_cast(d1); - } - - // op for MetaOps - op_def static Z op(X d1, Y *params) { - return static_cast(d1 - params[0]); - } - - }; - - template - class SquaredSubtract { - public: - op_def static Z op(X d1, Y d2) { - auto d = static_cast(d1 - d2); - return d * d; - } - - op_def static Z op(X d1, Y d2, Z *params) { - auto d = static_cast(d1 - d2); - return d * d; - } - - op_def static Z op(X d1) { - return d1; - } - - // op for MetaOps - op_def static Z op(X d1, Y *params) { - auto d = static_cast(d1 - params[0]); - return d * d; - } - }; - - template - class SquaredReverseSubtract { - public: - op_def static Z op(X d1, Y d2) { - auto d = static_cast(d2 - d1); - return d * d; - } - - op_def static Z op(X d1, Y d2, Z *params) { - auto d = static_cast(d2 - d1); - return d * d; - } - - op_def static Z op(X d1) { - return d1; - } - - // op for MetaOps - op_def static Z op(X d1, Y *params) { - auto d = static_cast(params[0] - d1); - return d * d; - } - }; - - template - class ReverseSubtract { - public: - op_def static Z op(X d1, Y d2) { - return static_cast(d2 - d1); - } - - op_def static Z op(X d1, Y d2, Z *params) { - return static_cast(d2 - d1); - } - - op_def static Z op(X d1) { - return d1; - } - - // op for MetaOps - op_def static Z op(X d1, Y *params) { - return static_cast(params[0] - d1); - } - }; - - - template - class LogPoissonLossFull { - - public: - op_def static Z op(X z, Y c) { - auto zz = static_cast(z); - auto zc = static_cast(c); - return (sd::math::nd4j_exp(c) - zz * zc + (zz * sd::math::nd4j_log(z) - zz + static_cast(0.5f) * sd::math::nd4j_log(static_cast(DOUBLE_PI_X) * zz))); - } - - op_def static Z op(X z, Y c, Z *params) { - auto zz = static_cast(z); - auto zc = static_cast(c); - return (sd::math::nd4j_exp(c) - zz * zc + (zz * sd::math::nd4j_log(z) - zz + static_cast(0.5f) * sd::math::nd4j_log(static_cast(DOUBLE_PI_X) * zz))); - } - - op_def static Z op(X z) { - auto zz = static_cast(z); - return (zz * sd::math::nd4j_log(z) - zz + static_cast(0.5f) * sd::math::nd4j_log(static_cast(DOUBLE_PI_X) * zz)); - } - - // op for MetaOps - op_def static X op(X z, Y *params) { - return (sd::math::nd4j_exp(params[0]) - z * params[0] + (z * sd::math::nd4j_log(z) - z + static_cast(0.5f) * sd::math::nd4j_log(DOUBLE_PI_X * z))); - } - }; - - template - class LogPoissonLoss { - - public: - op_def static Z op(X z, Y c) { - auto zz = static_cast(z); - auto zc = static_cast(c); - return (sd::math::nd4j_exp(c) - zz * zc); - } - - op_def static Z op(X z, Y c, Z *params) { - auto zz = static_cast(z); - auto zc = static_cast(c); - return (sd::math::nd4j_exp(c) - zz * zc); - } - - op_def static Z op(X z) { - return static_cast(z); - } - - // op for MetaOps - op_def static Z op(X z, Y *params) { - return (sd::math::nd4j_exp(params[0]) - static_cast(z) * static_cast(params[0])); - } - }; - - template - class Multiply { - public: - op_def static Z op(X d1, Y d2) { - return static_cast(d1 * d2); - } - - op_def static Z op(X d1, Y d2, Z *params) { - return static_cast(d1 * d2); - } - - op_def static Z op(X d1) { - return static_cast(d1); - } - - // op for MetaOps - op_def static Z op(X d1, Y *params) { - return static_cast(d1 * params[0]); - } - - op_def static X startingValue() { - return static_cast(1.f); - } - }; - - template - class Divide { - public: - op_def static Z op(X d1, Y d2) { - return static_cast(d1 / d2); - } - - op_def static Z op(X d1, Y d2, Z *params) { - return static_cast(d1 / d2); - } - - op_def static Z op(X d1) { - return static_cast(d1); - } - - // op for MetaOps - op_def static Z op(X d1, Y *params) { - return static_cast(d1 / params[0]); - } - - op_def static X startingValue() { - return static_cast(1); - } - }; - - template - class DivideNoNan { - public: - op_def static Z op(X d1, Y d2) { - if (d2 == (Y)0) return (Z)0; - return static_cast(d1 / d2); - } - - op_def static Z op(X d1, Y d2, Z *params) { - if (d2 == (Y)0) return (Z)0; - return static_cast(d1 / d2); - } - - op_def static Z op(X d1) { - return static_cast(d1); - } - - // op for MetaOps - op_def static Z op(X d1, Y *params) { - if (params[0] == (Y)0) return (Z)0; - return static_cast(d1 / params[0]); - } - - op_def static X startingValue() { - return static_cast(1); - } - }; - - template - class SafeDivide { - public: - op_def static Z op(X d1, Y d2) { - if(d2 == static_cast(0)) - return static_cast(0); - return static_cast(d1 / d2); - } - - op_def static Z op(X d1, Y d2, Z *params) { - if(d2 == static_cast(0)) - return static_cast(0); - return static_cast(d1 / d2); - } - - op_def static Z op(X d1) { - return static_cast(d1); - } - - // op for MetaOps - op_def static Z op(X d1, Y *params) { - if(params[0] == static_cast(0)) - return static_cast(0); - return static_cast(d1 / params[0]); - } - }; - - template - class FloorDiv { - public: - op_def static Z op(X d1, Y d2) { - return sd::math::nd4j_floor(static_cast(d1 / d2)); - } - - op_def static Z op(X d1, Y d2, Z *params) { - return sd::math::nd4j_floor(static_cast(d1 / d2)); - } - - op_def static Z op(X d1) { - return sd::math::nd4j_floor(static_cast(d1)); - } +template +class Add { + public: + op_def static Z op(X d1, Y d2) { return static_cast(d1 + d2); } + + op_def static Z op(X d1, Y d2, Z *params) { return static_cast(d1 + d2); } + + op_def static Z op(X d1) { return static_cast(d1); } + + // op for MetaOps + op_def static Z op(X d1, Y *params) { return static_cast(d1 + params[0]); } + + op_def static X startingValue() { return static_cast(0.f); } +}; + +template +class NewAdd { + public: + op_def static X op(X d1, Y d2, X *params) { return d1 + d2; } +}; + +template +class Subtract { + public: + op_def static Z op(X d1, Y d2) { return static_cast(d1 - d2); } + + op_def static Z op(X d1, Y d2, Z *params) { return static_cast(d1 - d2); } + + op_def static Z op(X d1) { return static_cast(d1); } + + // op for MetaOps + op_def static Z op(X d1, Y *params) { return static_cast(d1 - params[0]); } +}; + +template +class SquaredSubtract { + public: + op_def static Z op(X d1, Y d2) { + auto d = static_cast(d1 - d2); + return d * d; + } + + op_def static Z op(X d1, Y d2, Z *params) { + auto d = static_cast(d1 - d2); + return d * d; + } + + op_def static Z op(X d1) { return d1; } + + // op for MetaOps + op_def static Z op(X d1, Y *params) { + auto d = static_cast(d1 - params[0]); + return d * d; + } +}; + +template +class SquaredReverseSubtract { + public: + op_def static Z op(X d1, Y d2) { + auto d = static_cast(d2 - d1); + return d * d; + } + + op_def static Z op(X d1, Y d2, Z *params) { + auto d = static_cast(d2 - d1); + return d * d; + } + + op_def static Z op(X d1) { return d1; } + + // op for MetaOps + op_def static Z op(X d1, Y *params) { + auto d = static_cast(params[0] - d1); + return d * d; + } +}; + +template +class ReverseSubtract { + public: + op_def static Z op(X d1, Y d2) { return static_cast(d2 - d1); } + + op_def static Z op(X d1, Y d2, Z *params) { return static_cast(d2 - d1); } + + op_def static Z op(X d1) { return d1; } + + // op for MetaOps + op_def static Z op(X d1, Y *params) { return static_cast(params[0] - d1); } +}; + +template +class LogPoissonLossFull { + public: + op_def static Z op(X z, Y c) { + auto zz = static_cast(z); + auto zc = static_cast(c); + return (sd::math::nd4j_exp(c) - zz * zc + + (zz * sd::math::nd4j_log(z) - zz + + static_cast(0.5f) * + sd::math::nd4j_log(static_cast(DOUBLE_PI_X) * zz))); + } + + op_def static Z op(X z, Y c, Z *params) { + auto zz = static_cast(z); + auto zc = static_cast(c); + return (sd::math::nd4j_exp(c) - zz * zc + + (zz * sd::math::nd4j_log(z) - zz + + static_cast(0.5f) * + sd::math::nd4j_log(static_cast(DOUBLE_PI_X) * zz))); + } + + op_def static Z op(X z) { + auto zz = static_cast(z); + return (zz * sd::math::nd4j_log(z) - zz + + static_cast(0.5f) * + sd::math::nd4j_log(static_cast(DOUBLE_PI_X) * zz)); + } + + // op for MetaOps + op_def static X op(X z, Y *params) { + return (sd::math::nd4j_exp(params[0]) - z * params[0] + + (z * sd::math::nd4j_log(z) - z + + static_cast(0.5f) * sd::math::nd4j_log(DOUBLE_PI_X * z))); + } +}; + +template +class LogPoissonLoss { + public: + op_def static Z op(X z, Y c) { + auto zz = static_cast(z); + auto zc = static_cast(c); + return (sd::math::nd4j_exp(c) - zz * zc); + } + + op_def static Z op(X z, Y c, Z *params) { + auto zz = static_cast(z); + auto zc = static_cast(c); + return (sd::math::nd4j_exp(c) - zz * zc); + } + + op_def static Z op(X z) { return static_cast(z); } + + // op for MetaOps + op_def static Z op(X z, Y *params) { + return (sd::math::nd4j_exp(params[0]) - + static_cast(z) * static_cast(params[0])); + } +}; + +template +class Multiply { + public: + op_def static Z op(X d1, Y d2) { return static_cast(d1 * d2); } + + op_def static Z op(X d1, Y d2, Z *params) { return static_cast(d1 * d2); } + + op_def static Z op(X d1) { return static_cast(d1); } + + // op for MetaOps + op_def static Z op(X d1, Y *params) { return static_cast(d1 * params[0]); } + + op_def static X startingValue() { return static_cast(1.f); } +}; + +template +class Divide { + public: + op_def static Z op(X d1, Y d2) { return static_cast(d1 / d2); } + + op_def static Z op(X d1, Y d2, Z *params) { return static_cast(d1 / d2); } + + op_def static Z op(X d1) { return static_cast(d1); } + + // op for MetaOps + op_def static Z op(X d1, Y *params) { return static_cast(d1 / params[0]); } + + op_def static X startingValue() { return static_cast(1); } +}; + +template +class DivideNoNan { + public: + op_def static Z op(X d1, Y d2) { + if (d2 == (Y)0) return (Z)0; + return static_cast(d1 / d2); + } + + op_def static Z op(X d1, Y d2, Z *params) { + if (d2 == (Y)0) return (Z)0; + return static_cast(d1 / d2); + } + + op_def static Z op(X d1) { return static_cast(d1); } + + // op for MetaOps + op_def static Z op(X d1, Y *params) { + if (params[0] == (Y)0) return (Z)0; + return static_cast(d1 / params[0]); + } + + op_def static X startingValue() { return static_cast(1); } +}; + +template +class SafeDivide { + public: + op_def static Z op(X d1, Y d2) { + if (d2 == static_cast(0)) return static_cast(0); + return static_cast(d1 / d2); + } + + op_def static Z op(X d1, Y d2, Z *params) { + if (d2 == static_cast(0)) return static_cast(0); + return static_cast(d1 / d2); + } + + op_def static Z op(X d1) { return static_cast(d1); } + + // op for MetaOps + op_def static Z op(X d1, Y *params) { + if (params[0] == static_cast(0)) return static_cast(0); + return static_cast(d1 / params[0]); + } +}; + +template +class FloorDiv { + public: + op_def static Z op(X d1, Y d2) { + return sd::math::nd4j_floor(static_cast(d1 / d2)); + } + + op_def static Z op(X d1, Y d2, Z *params) { + return sd::math::nd4j_floor(static_cast(d1 / d2)); + } + + op_def static Z op(X d1) { + return sd::math::nd4j_floor(static_cast(d1)); + } + + // op for MetaOps + op_def static Z op(X d1, Y *params) { + return sd::math::nd4j_floor(static_cast(d1 / params[0])); + } +}; + +template +class TruncateDiv { + public: + op_def static Z op(X d1, Y d2) { + auto i1 = static_cast(d1); + auto i2 = static_cast(d2); + return static_cast(i1 / i2); + } + + op_def static Z op(X d1, Y d2, Z *params) { + auto i1 = static_cast(d1); + auto i2 = static_cast(d2); + return static_cast(i1 / i2); + } + + op_def static Z op(X d1) { return d1; } + + // op for MetaOps + op_def static Z op(X d1, Y *params) { + auto i1 = static_cast(d1); + auto i2 = static_cast(params[0]); + return static_cast(i1 / i2); + } +}; + +template +class TruncateMod { + public: + op_def static Z op(X d1, Y d2) { + auto i1 = static_cast(d1); + auto i2 = static_cast(d2); + return static_cast(i1 % i2); + } + + op_def static Z op(X d1, Y d2, Z *params) { + auto i1 = static_cast(d1); + auto i2 = static_cast(d2); + return static_cast(i1 % i2); + } + + op_def static Z op(X d1) { return static_cast(d1); } + + // op for MetaOps + op_def static Z op(X d1, Y *params) { + auto i1 = static_cast(d1); + auto i2 = static_cast(params[0]); + return static_cast(i1 % i2); + } +}; + +template +class Remainder { + public: + op_def static Z op(X d1, Y d2) { + return sd::math::nd4j_remainder(d1, d2); + } + + op_def static Z op(X d1, Y d2, Z *params) { + return sd::math::nd4j_remainder(d1, d2); + } + + op_def static Z op(X d1) { return d1; } + + // op for MetaOps + op_def static Z op(X d1, Y *params) { + return sd::math::nd4j_remainder(d1, params[0]); + } +}; + +template +class FMod { + public: + op_def static Z op(X d1, Y d2) { + return sd::math::nd4j_fmod(d1, d2); + } + + op_def static Z op(X d1, Y d2, Z *params) { + return sd::math::nd4j_fmod(d1, d2); + } + + op_def static Z op(X d1) { return d1; } + + // op for MetaOps + op_def static Z op(X d1, Y *params) { + return sd::math::nd4j_fmod(d1, params[0]); + } +}; + +template +class FloorMod { + public: + op_def static Z op(X d1, Y d2) { + auto m = sd::math::nd4j_fmod(d1, d2); + return (d1 < static_cast(0)) == (d2 < static_cast(0)) + ? m + : sd::math::nd4j_fmod(m + static_cast(d2), d2); + } + + op_def static Z op(X d1, Y d2, Z *params) { + auto m = sd::math::nd4j_fmod(d1, d2); + return (d1 < static_cast(0.0f)) == (d2 < static_cast(0)) + ? m + : sd::math::nd4j_fmod(m + static_cast(d2), d2); + } + + op_def static Z op(X d1) { return d1; } + + // op for MetaOps + op_def static Z op(X d1, Y *params) { return op(d1, params[0]); } +}; + +template +class ReverseDivide { + public: + op_def static Z op(X d1, Y d2) { return static_cast(d2 / d1); } + + op_def static Z op(X d1, Y d2, Z *params) { return static_cast(d2 / d1); } + + op_def static Z op(X d1) { return static_cast(d1); } + + // op for MetaOps + op_def static Z op(X d1, Y *params) { return static_cast(params[0] / d1); } +}; + +template +class CopyPws { + public: + op_def static Z op(X d1, Y d2) { return static_cast(d2); } + + op_def static Z op(X d1, Y d2, Z *params) { return static_cast(d2); } + + op_def static Z op(X d1) { return static_cast(d1); } + + op_def static Z op(X d1, Y *params) { return static_cast(d1); } +}; + +template +class Copy { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return d1; + } +}; + +template +class Copy2 { + public: + op_def static Z op(X d1, Y d2) { return static_cast(d2); } + + op_def static Z op(X d1, Y d2, Z *params) { return static_cast(d2); } + + op_def static Z op(X d1) { return static_cast(d1); } + + op_def static Z op(X d1, Y *params) { return static_cast(d1); } +}; + +template +class Axpy { + public: + op_def static Z op(X d1, Y d2) { return static_cast(d2 + d1); } + + op_def static Z op(X d1, Y d2, Z *params) { + auto alpha = params[0]; + return alpha * static_cast(d1) + static_cast(d2); + } + + op_def static Z op(X d1) { return static_cast(d1); } +}; + +template +class Assign { + public: + no_op_exec_special_any no_op_exec_special_any_cuda + + op_def static Z + op(X d1, X *params) { + return static_cast(d1); + } +}; + +template +class And { + public: + no_op_exec_special_bool no_op_exec_special_bool_cuda + + op_def static Z + op(X d1, X d2) { + return d2 + d1; + } + + op_def static Z op(X d1, X d2, X *params) { + if (params != nullptr) { + auto comp = params[0]; + return d1 != comp && d2 != comp ? static_cast(1) : static_cast(0); + } else { + auto b1 = static_cast(d1); + auto b2 = static_cast(d2); - // op for MetaOps - op_def static Z op(X d1, Y *params) { - return sd::math::nd4j_floor(static_cast(d1 / params[0])); - } - }; - - template - class TruncateDiv { - public: - op_def static Z op(X d1, Y d2) { - auto i1 = static_cast(d1); - auto i2 = static_cast(d2); - return static_cast(i1 / i2); - } - - op_def static Z op(X d1, Y d2, Z *params) { - auto i1 = static_cast(d1); - auto i2 = static_cast(d2); - return static_cast(i1 / i2); - } - - op_def static Z op(X d1) { - return d1; - } - - // op for MetaOps - op_def static Z op(X d1, Y *params) { - auto i1 = static_cast(d1); - auto i2 = static_cast(params[0]); - return static_cast(i1 / i2); - } - }; - - template - class TruncateMod { - public: - op_def static Z op(X d1, Y d2) { - auto i1 = static_cast(d1); - auto i2 = static_cast(d2); - return static_cast(i1 % i2); - } - - op_def static Z op(X d1, Y d2, Z *params) { - auto i1 = static_cast(d1); - auto i2 = static_cast(d2); - return static_cast(i1 % i2); - } - - op_def static Z op(X d1) { - return static_cast(d1); - } - - // op for MetaOps - op_def static Z op(X d1, Y *params) { - auto i1 = static_cast(d1); - auto i2 = static_cast(params[0]); - return static_cast(i1 % i2); - } - }; + return (b1 && b2) ? static_cast(1) : static_cast(0); + } + } - template - class Remainder { - public: - op_def static Z op(X d1, Y d2) { - return sd::math::nd4j_remainder(d1, d2); - } - - op_def static Z op(X d1, Y d2, Z *params) { - return sd::math::nd4j_remainder(d1, d2); - } - - op_def static Z op(X d1) { - return d1; - } - - // op for MetaOps - op_def static Z op(X d1, Y *params) { - return sd::math::nd4j_remainder(d1, params[0]); - } - }; - - template - class FMod { - public: - op_def static Z op(X d1, Y d2) { - return sd::math::nd4j_fmod(d1, d2); - } - - op_def static Z op(X d1, Y d2, Z *params) { - return sd::math::nd4j_fmod(d1, d2); - } - - op_def static Z op(X d1) { - return d1; - } - - // op for MetaOps - op_def static Z op(X d1, Y *params) { - return sd::math::nd4j_fmod(d1, params[0]); - } - }; - - template - class FloorMod { - public: - op_def static Z op(X d1, Y d2) { - auto m = sd::math::nd4j_fmod(d1, d2); - return (d1 < static_cast(0)) == (d2 < static_cast(0)) ? m : sd::math::nd4j_fmod(m + static_cast(d2), d2); - } + op_def static Z op(X d1) { return d1; } - op_def static Z op(X d1, Y d2, Z *params) { - auto m = sd::math::nd4j_fmod(d1, d2); - return (d1 < static_cast(0.0f)) == (d2 < static_cast(0)) ? m : sd::math::nd4j_fmod(m + static_cast(d2), d2); - } + // op for MetaOps + op_def static Z op(X d1, X *params) { return static_cast(119); } +}; - op_def static Z op(X d1) { - return d1; - } +template +class IntOr { + public: + op_def static X op(X d1, X d2) { return d2 | d1; } + + op_def static X op(X d1, X d2, X *params) { return op(d1, d2); } +}; + +template +class IntAnd { + public: + op_def static X op(X d1, X d2) { return d2 & d1; } + + op_def static X op(X d1, X d2, X *params) { return op(d1, d2); } +}; + +template +class IntXor { + public: + op_def static X op(X d1, X d2) { return d2 ^ d1; } + + op_def static X op(X d1, X d2, X *params) { return op(d1, d2); } +}; + +template +class ShiftLeft { + public: + op_def static X op(X d1, X d2) { return d1 << d2; } + + op_def static X op(X d1, X d2, X *params) { return op(d1, d2); } +}; + +template +class ShiftRight { + public: + op_def static X op(X d1, X d2) { return d1 >> d2; } + + op_def static X op(X d1, X d2, X *params) { return op(d1, d2); } +}; + +template +class CyclicShiftLeft { + public: + op_def static X op(X d1, X d2) { return sd::math::nd4j_rotl(d1, d2); } + + op_def static X op(X d1, X d2, X *params) { return op(d1, d2); } +}; + +template +class CyclicShiftRight { + public: + op_def static X op(X d1, X d2) { return sd::math::nd4j_rotr(d1, d2); } + + op_def static X op(X d1, X d2, X *params) { return op(d1, d2); } +}; + +template +class Or { + public: + no_op_exec_special_bool no_op_exec_special_bool_cuda + + op_def static Z + op(X d1, X d2) { + return d2 + d1; + } + + op_def static Z op(X d1, X d2, X *params) { + if (params != nullptr) { + auto comp = params[0]; + + return d1 != comp || d2 != comp ? static_cast(1) : static_cast(0); + } else { + auto b1 = static_cast(d1); + auto b2 = static_cast(d2); + + return b1 || b2 ? static_cast(1) : static_cast(0); + } + } + + op_def static Z op(X d1) { return d1; } + + // op for MetaOps + op_def static Z op(X d1, X *params) { return static_cast(119); } +}; + +template +class Xor { + public: + no_op_exec_special_bool no_op_exec_special_bool_cuda + + op_def static Z + op(X d1, X d2) { + return d2 + d1; + } + + op_def static Z op(X d1, X d2, X *params) { + if (params != nullptr) { + auto comp = params[0]; + + return ((d1 == comp && d2 != comp) || (d1 != comp && d2 == comp)) + ? static_cast(1) + : static_cast(0); + } else { + auto b1 = static_cast(d1); + auto b2 = static_cast(d2); + + return (!b1 && b2) || (b1 && !b2) ? static_cast(1) : static_cast(0); + } + } + + op_def static Z op(X d1) { return d1; } +}; - // op for MetaOps - op_def static Z op(X d1, Y *params) { - return op(d1, params[0]); - } - }; - - template - class ReverseDivide { - public: - op_def static Z op(X d1, Y d2) { - return static_cast(d2 / d1); - } - - op_def static Z op(X d1, Y d2, Z *params) { - return static_cast(d2 / d1); - } - - op_def static Z op(X d1) { - return static_cast(d1); - } - - // op for MetaOps - op_def static Z op(X d1, Y *params) { - return static_cast(params[0] / d1); - } - }; - - template - class CopyPws { - public: - op_def static Z op(X d1, Y d2) { - return static_cast(d2); - } - - op_def static Z op(X d1, Y d2, Z *params) { - return static_cast(d2); - } - - op_def static Z op(X d1) { - return static_cast(d1); - } - - op_def static Z op(X d1, Y *params) { - return static_cast(d1); - } - }; - - template - class Copy { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return d1; - } - }; - - - template - class Copy2 { - public: - op_def static Z op(X d1, Y d2) { - return static_cast(d2); - } - - op_def static Z op(X d1, Y d2, Z *params) { - return static_cast(d2); - } - - op_def static Z op(X d1) { - return static_cast(d1); - } - - op_def static Z op(X d1, Y *params) { - return static_cast(d1); - } - }; - - template - class Axpy { - public: - op_def static Z op(X d1, Y d2) { - return static_cast(d2 + d1); - } - - op_def static Z op(X d1, Y d2, Z *params) { - auto alpha = params[0]; - return alpha * static_cast(d1) + static_cast(d2); - } - - op_def static Z op(X d1) { - return static_cast(d1); - } - }; - - template - class Assign { - public: - no_op_exec_special_any - no_op_exec_special_any_cuda - - op_def static Z op(X d1, X *params) { - return static_cast(d1); - } - }; - - template - class And { - public: - no_op_exec_special_bool - no_op_exec_special_bool_cuda - - op_def static Z op(X d1, X d2) { - return d2 + d1; - } - - op_def static Z op(X d1, X d2, X *params) { - if (params != nullptr) { - auto comp = params[0]; - return d1 != comp && d2 != comp ? static_cast(1) : static_cast(0); - } else { - auto b1 = static_cast(d1); - auto b2 = static_cast(d2); - - return (b1 && b2) ? static_cast(1) : static_cast(0); - } - } - - op_def static Z op(X d1) { - return d1; - } - - // op for MetaOps - op_def static Z op(X d1, X *params) { - return static_cast(119); - } - }; - - template - class IntOr { - public: - - op_def static X op(X d1, X d2) { - return d2 | d1; - } +template +class Not { + public: + no_op_exec_special_bool no_op_exec_special_bool_cuda + + op_def static Z + op(X d1, X d2) { + return static_cast(0); + } - op_def static X op(X d1, X d2, X *params) { - return op(d1, d2); - } - }; + op_def static Z op(X d1, X d2, X *params) { + return d1 != d2 ? static_cast(1) : static_cast(0); + } - template - class IntAnd { - public: + // this transform op should run only on boolean input + op_def static Z op(X d1, X *params) { + auto b1 = static_cast(d1); + return !b1; + } +}; + +template +class LogicalNot { + public: + op_def static Z op(X d1, Y d2) { return !((int)d1 && (int)d2); } - op_def static X op(X d1, X d2) { - return d2 & d1; - } + op_def static Z op(X d1, Y d2, Z *params) { + return static_cast(!(static_cast(d1) && static_cast(d2))); + } - op_def static X op(X d1, X d2, X *params) { - return op(d1, d2); - } - }; + op_def static Z op(X d1) { return d1; } - template - class IntXor { - public: + // op for MetaOps + op_def static Z op(X d1, Y *params) { return static_cast(119); } +}; + +template +class LogicalXor { + public: + op_def static Z op(X d1, Y d2) { + auto i1 = static_cast(d1); + auto i2 = static_cast(d2); - op_def static X op(X d1, X d2) { - return d2 ^ d1; - } + return (i1 | i2) & ~(i1 & i2); + } - op_def static X op(X d1, X d2, X *params) { - return op(d1, d2); - } - }; + op_def static Z op(X d1, Y d2, Z *params) { return op(d1, d2); } - template - class ShiftLeft { - public: + op_def static Z op(X d1) { return d1; } - op_def static X op(X d1, X d2) { - return d1 << d2; - } + // op for MetaOps + op_def static Z op(X d1, Y *params) { return static_cast(119); } +}; - op_def static X op(X d1, X d2, X *params) { - return op(d1, d2); - } - }; +template +class LogicalAnd { + public: + op_def static Z op(X d1, Y d2) { + return static_cast(d1) & static_cast(d2); + } - template - class ShiftRight { - public: + op_def static Z op(X d1, Y d2, Z *params) { return op(d1, d2); } - op_def static X op(X d1, X d2) { - return d1 >> d2; - } + op_def static Z op(Y d1) { return d1; } - op_def static X op(X d1, X d2, X *params) { - return op(d1, d2); - } - }; + // op for MetaOps + op_def static Z op(X d1, Y *params) { return static_cast(119); } +}; + +template +class LogicalOr { + public: + op_def static Z op(X d1, Y d2) { + return static_cast(d1) | static_cast(d2); + } - template - class CyclicShiftLeft { - public: + op_def static Z op(X d1, Y d2, Z *params) { return op(d1, d2); } - op_def static X op(X d1, X d2) { - return sd::math::nd4j_rotl(d1, d2); - } + op_def static Z op(X d1) { return d1; } - op_def static X op(X d1, X d2, X *params) { - return op(d1, d2); - } - }; + // op for MetaOps + op_def static Z op(X d1, Y *params) { return static_cast(119); } +}; - template - class CyclicShiftRight { - public: +template +class Mod { + public: - op_def static X op(X d1, X d2) { - return sd::math::nd4j_rotr(d1, d2); - } - op_def static X op(X d1, X d2, X *params) { - return op(d1, d2); - } - }; - - - template - class Or { - public: - no_op_exec_special_bool - no_op_exec_special_bool_cuda - - op_def static Z op(X d1, X d2) { - return d2 + d1; - } - - op_def static Z op(X d1, X d2, X *params) { - if (params != nullptr) { - auto comp = params[0]; - - return d1 != comp || d2 != comp ? static_cast(1) : static_cast(0); - } else { - auto b1 = static_cast(d1); - auto b2 = static_cast(d2); - - return b1 || b2 ? static_cast(1) : static_cast(0); - } - } - - op_def static Z op(X d1) { - return d1; - } - - // op for MetaOps - op_def static Z op(X d1, X *params) { - return static_cast(119); - } - }; - - template - class Xor { - public: - - no_op_exec_special_bool - no_op_exec_special_bool_cuda - - op_def static Z op(X d1, X d2) { - return d2 + d1; - } - - op_def static Z op(X d1, X d2, X *params) { - if (params != nullptr) { - auto comp = params[0]; - - return ((d1 == comp && d2 != comp) || (d1 != comp && d2 == comp)) ? static_cast(1) : static_cast(0); - } else { - auto b1 = static_cast(d1); - auto b2 = static_cast(d2); - - return (!b1 && b2 )||(b1 && !b2) ? static_cast(1) : static_cast(0); - } - } - - op_def static Z op(X d1) { - return d1; - } - }; - - - template - class Not { - public: - no_op_exec_special_bool - no_op_exec_special_bool_cuda - - op_def static Z op(X d1, X d2) { - return static_cast(0); - } - - op_def static Z op(X d1, X d2, X *params) { - return d1 != d2 ? static_cast(1) : static_cast(0); - } - - // this transform op should run only on boolean input - op_def static Z op(X d1, X *params) { - auto b1 = static_cast(d1); - return !b1; - } - }; - - template - class LogicalNot { - public: - op_def static Z op(X d1, Y d2) { - return !((int) d1 && (int) d2); - } - - op_def static Z op(X d1, Y d2, Z *params) { - return static_cast(!(static_cast(d1) && static_cast(d2))); - } - - op_def static Z op(X d1) { - return d1; - } - - // op for MetaOps - op_def static Z op(X d1, Y *params) { - return static_cast(119); - } - }; - - template - class LogicalXor { - public: - op_def static Z op(X d1, Y d2) { - auto i1 = static_cast(d1); - auto i2 = static_cast(d2); - - return (i1 | i2) &~ (i1 & i2); - } - - op_def static Z op(X d1, Y d2, Z *params) { - return op(d1, d2); - } - - op_def static Z op(X d1) { - return d1; - } - - // op for MetaOps - op_def static Z op(X d1, Y *params) { - return static_cast(119); - } - }; - - template - class LogicalAnd { - public: - op_def static Z op(X d1, Y d2) { - return static_cast(d1) & static_cast(d2); - } - - op_def static Z op(X d1, Y d2, Z *params) { - return op(d1, d2); - } - - op_def static Z op(Y d1) { - return d1; - } - - // op for MetaOps - op_def static Z op(X d1, Y *params) { - return static_cast(119); - } - }; - - template - class LogicalOr { - public: - op_def static Z op(X d1, Y d2) { - return static_cast(d1) | static_cast(d2); - } - - op_def static Z op(X d1, Y d2, Z *params) { - return op(d1, d2); - } - - op_def static Z op(X d1) { - return d1; - } - - // op for MetaOps - op_def static Z op(X d1, Y *params) { - return static_cast(119); - } - }; - - - template - class Mod { - public: - - op_def static Z op(X d1, Y d2) { - auto dx = static_cast(d2); - auto f = sd::math::nd4j_floor(d1 / dx); + op_def static Z op(X d1, Y d2) { + auto dx = static_cast(d2); +auto f = sd::math::nd4j_floor(d1 / dx); auto r = f * dx; return d1 - r; } - op_def static Z op(X d1, Y d2, Z *params) { - return op(d1, d2); - } - - // op for MetaOp - op_def static Z op(X d1, Y *params) { - return op(d1, params[0]); - } - }; - - template - class ReverseMod { - public: - op_def static Z op(X d1, Y d2) { - return static_cast(d2) % static_cast(d1); - } - - op_def static Z op(X d1, Y d2, Z *params) { - return op(d1, d2); - } - - // op for MetaOp - op_def static Z op(X d1, Y *params) { - return op(d1, params[0]); - } - }; - - /** - * Whether 2 elements in an array - * are epsilion equal - */ - template - class Epsilon { - public: - - op_def static Z op(X d1, X d2) { - X diff = d1 - d2; - X absDiff = sd::math::nd4j_abs(diff); - if (absDiff <= static_cast(MIN_V)) - return static_cast(1); - return static_cast(0); - } - - op_def static Z op(X d1, X d2, X *params) { - return op(d1, d2); - } - - op_def static Z op(X d1, X *params) { - return d1; - } - }; - - - template - class EqualTo { - public: - op_def static Z op(X d1, X d2) { - return d1 == d2; - } - - op_def static Z op(X d1, X d2, X *params) { - return op(d1, d2); - } - - op_def static Z op(X d1, X *params) { - return d1; - } - }; - - - template - class NotEqualTo { - public: - op_def static Z op(X d1, X d2) { - return d1 != d2; - } - - op_def static Z op(X d1, X d2, X *params) { - return op(d1, d2); - } - - op_def static Z op(X d1, X *params) { - return d1; - } - }; - - - - template - class GreaterThanOrEqual { - public: - op_def static Z op(X d1, X d2) { - return d1 >= d2; - } - - op_def static Z op(X d1, X d2, X *params) { - return op(d1, d2); - } - - // FIXME: this signature clashes with MetaOp stuff - op_def static Z op(X d1, X *params) { - return d1; - } - }; - - - template - class GreaterThan { - public: - op_def static Z op(X d1, X d2) { - return d1 > d2; - } - - op_def static Z op(X d1, X d2, X *params) { - return op(d1, d2); - } - - // FIXME: this signature clashes with MetaOp stuff - op_def static Z op(X d1, X *params) { - return d1; - } - - }; - - - template - class LessThan { - public: - op_def static Z op(X d1, X d2) { - return d1 < d2; - } - - op_def static Z op(X d1, X d2, X *params) { - return op(d1, d2); - } - - op_def static Z op(X d1, X *params) { - return d1; - } - - }; - - - template - class LessThanOrEqual { - public: - op_def static Z op(X d1, X d2) { - return d1 <= d2; - } - - op_def static Z op(X d1, X d2, X *params) { - return op(d1, d2); - } - - op_def static Z op(X d1, X *params) { - return d1; - } - - }; - - - template - class Abs { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return sd::math::nd4j_abs(d1); - } - }; - - - template - class Ceiling { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return sd::math::nd4j_ceil(d1); - } - }; - - - template - class Cosine { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return sd::math::nd4j_cos(d1); - } - }; - - - template - class Exp { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return sd::math::nd4j_exp(d1); - } - }; - - - template - class HardTanhDerivative { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return ((d1 >= static_cast(-1.f) && d1 <= static_cast(1.f)) ? static_cast(1.f) : static_cast(0.f)); - } - }; - - - template - class HardTanh { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - if (d1 < static_cast(-1)) - return static_cast(-1); - else if (d1 > static_cast(1)) - return static_cast(1); - else - return d1; - - } - }; - - - template - class Floor { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return sd::math::nd4j_floor(d1); - } - }; - - - template - class Log { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return sd::math::nd4j_log(d1); - } - }; - - template - class Log1p { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return sd::math::nd4j_log(1 + d1); - } - }; - - template - class LogX { - public: - - op_def static Z op(X d1, Y d2, Z *params) { - return sd::math::nd4j_log(d1) / sd::math::nd4j_log(d2) ; - } - }; - - template - class StabilizeFP16 { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - if (d1 <= static_cast(0)) - return static_cast(sd::DataTypeUtils::min()); - else return d1; - } - }; - - template - class StabilizeX { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - if (d1 <= static_cast(0)) - return sd::DataTypeUtils::min(); - else return d1; - } - }; - - template - class SpecialDerivative { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return d1 * (static_cast(1.f) - d1); - } - }; - - - template - class Neg { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return -d1; - } - }; - - template - class Erf { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return sd::math::nd4j_erf(d1); - } - }; - - - template - class Erfc { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return sd::math::nd4j_erfc(d1); - } - }; - - template - class Reciprocal { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda -// op_def static T op(T d1) { -// return (T(1.0f) / d1); -// } - // op for MetaOps - op_def static X op(X d1, X *params) { - return (static_cast(1) / d1); - } - }; - - template - class Sqr { - public: - no_op_exec_special - no_op_exec_special_cuda - - op_def static Z op(X d1, Z *params) { - return sd::math::nd4j_pow(d1, static_cast(2)); - } - - op_def static Z op(X d1) { - return sd::math::nd4j_pow(d1, static_cast(2)); - } - }; - - - template - class RelativeError { - public: - no_op_exec_special - no_op_exec_special_cuda - - op_def static Z op(X d1, Y d2) { - return sd::math::nd4j_re(d1, d2); - } - - op_def static Z op(X d1, Y d2, Z *params) { - return op(d1, d2); - } - - op_def static Z op(X d1) { - return static_cast(0); - } - }; - - template - class BinaryRelativeError { - public: - no_op_exec_special - no_op_exec_special_cuda - - op_def static Z op(X d1, Y d2, Z *params) { - X threshold = params[0]; - return sd::math::nd4j_re(d1, d2) > threshold ? static_cast(1) : static_cast(0); - } - - op_def static Z op(X d1) { - return static_cast(0); - } - }; - - template - class BinaryMinimumAbsoluteRelativeError { - public: - no_op_exec_special - no_op_exec_special_cuda - - op_def static Z op(X d1, X *params) { - X d2 = params[0]; - X thresholdRelative = params[1]; - X thresholdAbsolute = params[2]; - return sd::math::nd4j_re(d1, d2) > thresholdRelative ? (sd::math::nd4j_abs(d1 - static_cast(d2)) < thresholdAbsolute ? static_cast(0) : static_cast(1)) : static_cast(0); - } - - op_def static Z op(X d1, Y d2, Z *params) { - X thresholdRelative = params[0]; - X thresholdAbsolute = params[1]; - return sd::math::nd4j_re(d1, d2) > thresholdRelative ? (sd::math::nd4j_abs(d1 - static_cast(d2)) < thresholdAbsolute ? static_cast(0) : static_cast(1)) : static_cast(0); - } - - op_def static Z op(X d1) { - return static_cast(0); - } - }; - - template - class ReversePow { - public: - no_op_exec_special - no_op_exec_special_cuda - - op_def static Z op(X d1, Z *params) { - return sd::math::nd4j_pow(params[0], d1); - } - - op_def static Z op(X d1, Y d2) { - return sd::math::nd4j_pow(d2, d1); - } - - op_def static Z op(X d1, Y d2, Z *params) { - return sd::math::nd4j_pow(d2, d1); - } - - op_def static Z op(X d1) { - return d1; - } - }; - - template - class Pow { - public: - no_op_exec_special - no_op_exec_special_cuda - - op_def static Z op(X d1, Z *params) { - return sd::math::nd4j_pow(d1, params[0]); - } - - op_def static Z op(X d1, Y d2) { - return sd::math::nd4j_pow(d1, d2); - } - - op_def static Z op(X d1, Y d2, Z *params) { - return sd::math::nd4j_pow(d1, d2); - } - - op_def static Z op(X d1) { - return d1; - } - }; - - - template - class PowDerivative { - public: - no_op_exec_special - no_op_exec_special_cuda - - op_def static Z op(X d1, Z *params) { - return params[0] * sd::math::nd4j_pow(d1, static_cast(params[0]) - static_cast(1.f)); - } - - op_def static Z op(X d1, Y d2) { - return static_cast(d2) * sd::math::nd4j_pow(d1, static_cast(d2) - static_cast(1.f)); - } - - op_def static Z op(X d1, Y d2, Z *params) { - return static_cast(d2) * sd::math::nd4j_pow(d1, static_cast(d2) - static_cast(1.f)); - } - - op_def static Z op(X d1) { - return static_cast(d1); - } - }; - - - template - class IGamma { - public: - no_op_exec_special - no_op_exec_special_cuda - - op_def static Z op(X d1, Z *params) { - return sd::math::nd4j_igamma(d1, params[0]); - } - - op_def static Z op(X d1, Y d2) { - return sd::math::nd4j_igamma(d1, d2); - } - - op_def static Z op(X d1, Y d2, Z *params) { - return sd::math::nd4j_igamma(d1, d2); - } - - op_def static Z op(X d1) { - return d1; - } - }; - - template - class IGammac { - public: - no_op_exec_special - no_op_exec_special_cuda - - op_def static Z op(X d1, Z *params) { - return sd::math::nd4j_igammac(d1, params[0]); - } - - op_def static Z op(X d1, Y d2) { - return sd::math::nd4j_igammac(d1, d2); - } - - op_def static Z op(X d1, Y d2, Z *params) { - return sd::math::nd4j_igammac(d1, d2); - } - - op_def static Z op(X d1) { - return d1; - } - }; - - template - class Round { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return sd::math::nd4j_round(d1); - } - }; - - template - class IsNan { - public: - no_op_exec_special_bool - no_op_exec_special_bool_cuda - - no_op_exec_special_accumulation - no_op_exec_special_accumulation_cuda - - op_def static Z op(X d1, X *params) { - return sd::math::nd4j_isnan(d1) ? static_cast(1) : static_cast(0); - } - - op_def static X startingValue(const X *input) { - return static_cast(0); - } - - op_def static Z merge(X old, X opOutput, X *extraParams) { - return opOutput + old; - } - - - op_def static Z update(X old, X opOutput, X *extraParams) { - return opOutput + old; - } - - - op_def static Z postProcess(X reduction, Nd4jLong n, X *extraParams) { - return reduction; - } - }; - - - template - class Expm1 { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return sd::math::nd4j_exp(d1) - static_cast(1); - } - }; - - template - class IsPositive { - public: - no_op_exec_special_bool - no_op_exec_special_bool_cuda - - no_op_exec_special_accumulation - no_op_exec_special_accumulation_cuda - - op_def static Z op(X d1, X *params) { - return d1 > (X)0.f; - } - - op_def static X startingValue(const X *input) { - return static_cast(0); - } - - op_def static Z merge(X old, X opOutput, X *extraParams) { - return opOutput + old; - } - - op_def static Z update(X old, X opOutput, X *extraParams) { - return opOutput + old; - } - - op_def static Z postProcess(X reduction, Nd4jLong n, X *extraParams) { - return reduction; - } - }; - - template - class IsNegative { - public: - no_op_exec_special_bool - no_op_exec_special_bool_cuda - - no_op_exec_special_accumulation - no_op_exec_special_accumulation_cuda - - op_def static Z op(X d1, X *params) { - return d1 < (X)0.f; - } - - op_def static X startingValue(const X *input) { - return static_cast(0); - } - - op_def static Z merge(X old, X opOutput, X *extraParams) { - return opOutput + old; - } - - op_def static Z update(X old, X opOutput, X *extraParams) { - return opOutput + old; - } - - op_def static Z postProcess(X reduction, Nd4jLong n, X *extraParams) { - return reduction; - } - }; - - template - class IsInf { - public: - no_op_exec_special_bool - no_op_exec_special_bool_cuda - - no_op_exec_special_accumulation - no_op_exec_special_accumulation_cuda - - op_def static Z op(X d1, X *params) { - return sd::math::nd4j_isinf(d1) ? static_cast(1) : static_cast(0); - } - - op_def static X startingValue(const X *input) { - return static_cast(0); - } - - op_def static Z merge(X old, X opOutput, X *extraParams) { - return opOutput + old; - } - - - op_def static Z update(X old, X opOutput, X *extraParams) { - return opOutput + old; - } - - - op_def static Z postProcess(X reduction, Nd4jLong n, X *extraParams) { - return reduction; - } - }; - - - template - class IsInfOrNan{ - public: - no_op_exec_special_bool - no_op_exec_special_bool_cuda - - no_op_exec_special_accumulation - no_op_exec_special_accumulation_cuda - - op_def static Z op(X d1, X *params) { - return sd::math::nd4j_isfin(d1) ? static_cast(0) : static_cast(1); - } - - op_def static X startingValue(const X *input) { - return static_cast(0); - } - - op_def static Z merge(X old, X opOutput, X *extraParams) { - return opOutput == static_cast(0) && old == static_cast(0) ? static_cast(0) : static_cast(1); - } - - - op_def static Z update(X old, X opOutput, X *extraParams) { - return opOutput == static_cast(0) && old == static_cast(0) ? static_cast(0) : static_cast(1); - } - - op_def static Z postProcess(X reduction, Nd4jLong n, X *extraParams) { - return reduction != static_cast(0); - } - }; - - - - template - class IsFinite { - public: - no_op_exec_special_bool - no_op_exec_special_bool_cuda - - no_op_exec_special_accumulation - no_op_exec_special_accumulation_cuda - - op_def static Z op(X d1, X *params) { - return sd::math::nd4j_isfin(d1) ? static_cast(1) : static_cast(0); - } - - op_def static X startingValue(const X *input) { - return static_cast(1); - } - - op_def static Z merge(X old, X opOutput, X *extraParams) { - return opOutput == static_cast(0) || old == static_cast(0) ? static_cast(0) : static_cast(1); - } - - op_def static Z update(X old, X opOutput, X *extraParams) { - return opOutput == static_cast(0) || old == static_cast(0) ? static_cast(0) : static_cast(1); - } - - op_def static Z postProcess(X reduction, Nd4jLong n, X *extraParams) { - return reduction != static_cast(0); - } - }; - - - template - class ClipByValue { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - if (d1 > params[1]) - return params[1]; - if (d1 < params[0]) - return params[0]; - return d1; - } - }; - - template - class LstmClip { - public: - no_op_exec_special - no_op_exec_special_cuda - - op_def static Z op(X d1, Y d2, Z *params) { - X _v = (X) d2; - if (d1 > _v) - return _v; - else if (d1 < -_v) - return -_v; - else return d1; - } - }; - - template - class Swish { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return d1 * sd::math::nd4j_sigmoid(d1); - } - }; - - template - class Mish { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return d1 * sd::math::nd4j_tanh(sd::math::nd4j_softplus(d1)); - } - }; - - template - class MishDerivative { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - auto ex = sd::math::nd4j_exp(d1); - auto e2x = ex * ex; - auto e3x = ex * ex * ex; - - return (ex * (4 * (d1 + 1) + 4 * e2x + e3x + ex *(4 * d1 + 6))) / sd::math::nd4j_pow((2 * ex + e2x + 2), (X) 2.f); - } - }; - - template - class GELU { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return d1 * sd::math::nd4j_sigmoid(static_cast(1.702f) * d1); - } - }; - - template - class PreciseGELU { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - auto sp = sd::math::nd4j_sqrt(static_cast(2) / static_cast(M_PI)); - auto xp = d1 + sd::math::nd4j_pow(static_cast(0.044715) * d1, static_cast(3)); - return (d1 / static_cast(2)) * (static_cast(1) + sd::math::nd4j_tanh(sp * xp)); - } - }; - - template - class GELUDerivative { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - auto x17 = static_cast(1.702f) * d1; - auto ep = sd::math::nd4j_pow(static_cast(M_E), x17); - // (E^(1.702 x) (1. + E^(1.702 x) + 1.702 x))/(1. + E^(1.702 x))^2 - return (ep * (static_cast(1.f) + ep + x17)) / sd::math::nd4j_pow((static_cast(1.f) + ep), 2); - } - }; - - template - class PreciseGELUDerivative { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - auto x79 = static_cast(0.797885) * d1; - auto x03 = sd::math::nd4j_pow(static_cast(0.0356774) * d1, 3); - auto x39 = static_cast(0.398942) * d1; - auto x05 = sd::math::nd4j_pow(static_cast(0.0535161) * d1, 3); - auto scz = sd::math::nd4j_sech(x79 + x03); - // 0.5 + (0.398942 x + 0.0535161 x^3) Sech[0.797885 x + 0.0356774 x^3]^2 + 0.5 Tanh[0.797885 x + 0.0356774 x^3] - return static_cast(0.5) + (x39 + x05) * (scz * scz) + static_cast(0.5) * sd::math::nd4j_tanh(x79 + x03); - } - }; - - - template - class SwishDerivative { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - X ex = sd::math::nd4j_pow(static_cast(M_E), d1); - return (ex * (d1 + ex + static_cast(1.f))) / sd::math::nd4j_pow((ex + static_cast(1.f)) , static_cast(2.f)); - } - }; - - - template - class LogSigmoid { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return sd::math::nd4j_log(sd::math::nd4j_sigmoid(d1)); - } - }; - - template - class LogSigmoidDerivative { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - X ex = sd::math::nd4j_pow(M_E, d1); - return static_cast(1.f) / (ex + static_cast(1.f)); - } - }; - - template - class Sigmoid { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return sd::math::nd4j_sigmoid(d1); - } - }; - - template - class Affine { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return params[0] * d1 + params[1]; - } - }; - - template - class SigmoidDerivative { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return sd::math::nd4j_sigmoidderivative(d1); - } - }; - - - template - class HardSigmoid { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return sd::math::nd4j_min(static_cast(1), sd::math::nd4j_max(static_cast(0), (static_cast(0.2f)) * d1 + static_cast(0.5f))); - } - }; - - template - class HardSigmoidDerivative { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return d1 < static_cast(-2.5f) || d1 > static_cast(2.5f) ? static_cast(0.f) : static_cast(0.2f); - } - }; - - - /** - * Scale to be between a min and max - */ - template - class SetRange { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - auto min = params[0]; - auto max = params[1]; - if (static_cast(d1) >= min && static_cast(d1) <= max) - return d1; - if (min == static_cast(0) && max == static_cast(1)) { - auto val = static_cast(1) / (static_cast(1) + sd::math::nd4j_exp(-d1)); - return (sd::math::nd4j_floor(val * (max - min)) + min); - } - - return (sd::math::nd4j_floor(d1 * (max - min)) + min); - } - }; - - - template - class Sin { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return sd::math::nd4j_sin(d1); - } - }; - - template - class Square { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return d1 * d1; - } - }; - - template - class Sqrt { - public: - no_op_exec_special - no_op_exec_special_cuda - - op_def static Z op(X d1, Z *params) { - return sd::math::nd4j_sqrt(d1); - } - }; - - template - class RSqrt { - public: - no_op_exec_special - no_op_exec_special_cuda - - op_def static Z op(X d1, Z *params) { - return static_cast(1) / sd::math::nd4j_sqrt(d1); - } - }; - - template - class Rint { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return sd::math::nd4j_rint(d1); - } - }; - - - template - class SoftPlus { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return sd::math::nd4j_softplus(d1); - } - }; - - - template - class Sign { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return (d1 > static_cast(0)) - (d1 < static_cast(0)); - } - }; - - - template - class TimesOneMinus { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return d1 * (static_cast(1) - d1); - } - }; - - - template - class RationalTanh { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - // keep 2/3 as runtime variable, to match precision - auto dis = (static_cast(2) / static_cast(3)) * d1; - - auto tanh = sd::math::nd4j_sgn(dis) * (static_cast(1) - (static_cast(1) / (static_cast(1) + static_cast(sd::math::nd4j_abs(dis)) + sd::math::nd4j_pow(dis, static_cast(2)) + static_cast(1.41645f) * sd::math::nd4j_pow(dis, static_cast(4)) ))); - return static_cast(1.7159f) * tanh; - } - }; - - template - class RationalTanhDerivative { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - auto dis = (static_cast(2.f) / static_cast(3.f)) * d1; - - auto a = static_cast(1.f) + sd::math::nd4j_abs(dis) + sd::math::nd4j_pow(dis, static_cast(2.f)) + static_cast(1.41645f) * sd::math::nd4j_pow(dis, static_cast(4)); - - auto tDeriv = (static_cast(1.f) + sd::math::nd4j_sign(dis) * (static_cast(2.f) * dis + static_cast(4.f) * static_cast(1.41645f) * sd::math::nd4j_pow(dis, static_cast(3)))) / (a * a); - - return static_cast(1.7159f) * (static_cast(2.f) / static_cast(3.f)) * tDeriv; - } - }; - - template - class Tanh { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return sd::math::nd4j_tanh(d1); - } - }; - - template - class ScaledTanh { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return params[0] * sd::math::nd4j_tanh(params[1] * d1); - } - }; - - template - class RectifiedTanh { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return sd::math::nd4j_max(static_cast(0), sd::math::nd4j_tanh(d1)); - } - }; - - template - class RectifiedTanhDerivative { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return d1 > static_cast(0.f) ? sd::math::nd4j_tanhderivative(d1) : static_cast(0.f); - } - }; - - template - class ATanh { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return sd::math::nd4j_atanh(d1); - } - }; - - template - class TanhDerivative { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return sd::math::nd4j_tanhderivative(d1); - } - }; - - template - class Cube { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return d1 * d1 * d1; - } - }; - - - template - class CubeDerivative { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return static_cast(3) * d1 * d1; - } - }; - - template - class ACos { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return sd::math::nd4j_acos(d1); - } - }; - - template - class ASinh { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return sd::math::nd4j_asinh(d1); - } - }; - - template - class ASinhDerivative { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return static_cast(1.f) / (sd::math::nd4j_sqrt(sd::math::nd4j_pow(d1, static_cast(2.f)) + static_cast(1.f))); - } - }; - - template - class ACosh { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return sd::math::nd4j_acosh(d1); - } - }; - - - template - class ACoshDerivative { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return static_cast(1.f) / (sd::math::nd4j_sqrt(d1 - static_cast(1.f)) * sd::math::nd4j_sqrt(d1 + static_cast(1.f))); - } - }; - - - - template - class Ones { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return static_cast(1.0f); - } - }; - - - - template - class SoftSign { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return sd::math::nd4j_softsign(d1); - } - }; - - - template - class SoftSignDerivative { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return sd::math::nd4j_softsignderivative(d1); - } - }; - - template - class MatchConditionBool { - public: - no_op_exec_special_bool - no_op_exec_special_bool_cuda - - // this op return 1.0 if condition met, 0.0 otherwise - op_def static Z op(X d1, X *extraParams) { - X compare = extraParams[0]; - X eps = extraParams[1]; - - auto mode = static_cast(extraParams[2]); - //nd4j_printf("value: %f; comp: %f; eps: %f; mode: %i;\n", d1, compare, eps, mode); - - switch (mode) { - case 0: // equals - return sd::math::nd4j_abs(d1 - compare) <= eps ? true : false; - case 1: // not equals - return sd::math::nd4j_abs(d1 - compare) > eps ? true : false; - case 2: // less_than - return d1 < compare ? true : false; - case 3: // greater_than - return d1 > compare ? true : false; - case 4: // less_or_equals_than - return d1 <= compare ? true : false; - case 5: // greater_or_equals_than - return d1 >= compare ? true : false; - case 6: // abs_less_than - return sd::math::nd4j_abs(d1) < compare ? true : false; - case 7: // abs_greater_than - return sd::math::nd4j_abs(d1) > compare ? true : false; - case 8: // is inf - return sd::math::nd4j_isinf(d1) ? true : false; - case 9: // is nan - return sd::math::nd4j_isnan(d1) ? true : false; - case 10: - return (d1 == compare) ? true : false; - case 11: - return (d1 != compare) ? true : false; - case 12: // abs_greater_or_equals_than - return sd::math::nd4j_abs(d1) >= compare ? true : false; - case 13: // abs_less_or_equals_than - return sd::math::nd4j_abs(d1) <= compare ? true : false; - case 14: - // isFinite - return !(sd::math::nd4j_isinf(d1) || sd::math::nd4j_isnan(d1)); - case 15: - // isInfinite - return sd::math::nd4j_isinf(d1) || sd::math::nd4j_isnan(d1); - default: - printf("Undefined match condition: [%i]\n", mode); - } - - return d1; - } - }; - - template - class MatchCondition { - public: - no_op_exec_special - no_op_exec_special_cuda - - no_op_exec_special_accumulation_long - no_op_exec_special_accumulation_cuda - - op_def static Z startingValue(const X *input) { - return static_cast(0); - } - - op_def static Z merge(Z old, Z opOutput, X *extraParams) { - return old + opOutput; - } - - op_def static Z update(Z old, Z opOutput, X *extraParams) { - return old + opOutput; - } - - op_def static Z op(X d1, X compare, X eps, int mode) { - switch (mode) { - case 0: // equals - return sd::math::nd4j_abs(d1 - compare) <= eps ? 1 : 0; - case 1: // not equals - return sd::math::nd4j_abs(d1 - compare) > eps ? 1 : 0; - case 2: // less_than - return d1 < compare ? 1 : 0; - case 3: // greater_than - return d1 > compare ? 1 : 0; - case 4: // less_or_equals_than - return d1 <= compare ? 1 : 0; - case 5: // greater_or_equals_than - return d1 >= compare ? 1 : 0; - case 6: // abs_less_than - return sd::math::nd4j_abs(d1) < compare ? 1 : 0; - case 7: // abs_greater_than - return sd::math::nd4j_abs(d1) > compare ? 1 : 0; - case 8: // is inf - return sd::math::nd4j_isinf(d1) ? 1 : 0; - case 9: // is nan - return sd::math::nd4j_isnan(d1) ? 1 : 0; - case 10: - return (d1 == compare) ? 1 : 0; - case 11: - return (d1 != compare) ? 1 : 0; - case 12: // abs_greater_or_equals_than - return sd::math::nd4j_abs(d1) >= compare ? 1 : 0; - case 13: // abs_less_or_equals_than - return sd::math::nd4j_abs(d1) <= compare ? 1 : 0; - case 14: - // isFinite - return !(sd::math::nd4j_isinf(d1) || sd::math::nd4j_isnan(d1)) ? 1 : 0; - case 15: - // isInfinite - return sd::math::nd4j_isinf(d1) || sd::math::nd4j_isnan(d1) ? 1 : 0; - default: - printf("Undefined match condition: [%i]\n", mode); - } - - return d1; - } - - // this op return 1.0 if condition met, 0.0 otherwise - op_def static Z op(X d1, X compare, X *extraParams) { - X eps = extraParams[1]; - - auto mode = static_cast(extraParams[0]); - - return op(d1, compare, eps, mode); - } - - // this op return 1.0 if condition met, 0.0 otherwise - op_def static Z op(X d1, X *extraParams) { - X compare = extraParams[0]; - X eps = extraParams[1]; - - auto mode = static_cast(extraParams[2]); - - return op(d1, compare, eps, mode); - } - - op_def static Z postProcess(Z reduction, Nd4jLong n, X *extraParams) { - return reduction; - } - }; - - template - class ELU { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static Z op(X d1, Y d2, Z *params) { - return sd::math::nd4j_elu(d1, static_cast(d2)); - } - }; - - - template - class ELUDerivative { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static Z op(X d1, Y d2, Z *params) { - return sd::math::nd4j_eluderivative(d1, static_cast(d2)); - } - }; - - - template - class RELU { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static Z op(X d1, Y d2, Z *params) { - auto xt = static_cast(d1); - auto xf = static_cast(d2); - return xt < xf ? xf : xt; - } - }; - - template - class RELUDerivative { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static Z op(X d1, Y d2, Z *params) { - auto xt = static_cast(d1); - auto xf = static_cast(d2); - return xt > xf ? static_cast(1.f) : static_cast(0.f); - } - }; - - template - class SXELogitsSmoother { - public: - op_def static Z op(X d1, Y d2, Z *params) { - return d1 * ((X)1.f - (X) d2) + (X)(0.5f) * (X) d2; - } - }; - - template - class RELU6 { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static Z op(X d1, Y d2, Z *params) { - auto relu = simdOps::RELU::op(d1, d2, params); - return relu < static_cast(6) ? relu : static_cast(6); - } - }; - - template - class LeakyRELU { - public: - no_op_exec_special - no_op_exec_special_cuda - - op_def static Z op(X d1, Y d2, Z *params) { - auto val = static_cast(d1); - auto alpha = static_cast(d2); - return val < 0.0f ? alpha * val : val; - } - }; - - template - class SELU { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return d1 > static_cast(0.0f) ? static_cast(SELU_LAMBDA) * static_cast(d1) : static_cast(SELU_LAMBDA) * (static_cast(SELU_ALPHA) * sd::math::nd4j_exp(d1) - static_cast(SELU_ALPHA)); - } - }; - - template - class SELUDerivative { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return d1 > static_cast(0.f) ? static_cast(SELU_LAMBDA) : static_cast(SELU_ALPHA) * static_cast(SELU_LAMBDA) * sd::math::nd4j_exp(d1); - } - }; - - template - class LeakyRELUDerivative { - public: - no_op_exec_special - no_op_exec_special_cuda - - op_def static Z op(X d1, Y d2, Z *params) { - if (d1 >= static_cast(0)) - return static_cast(1); - else - return static_cast(d2); - } - }; - - - template - class ASin { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return sd::math::nd4j_asin(d1); - } - }; - - template - class Sinh { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return sd::math::nd4j_sinh(d1); - } - }; - - template - class SinhDerivative { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return sd::math::nd4j_cosh(d1); - } - }; - - template - class Cosh { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return sd::math::nd4j_cosh(d1); - } - }; - - - template - class Tan { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return sd::math::nd4j_tan(d1); - } - }; - - template - class TanDerivative { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return static_cast(1.f) / sd::math::nd4j_pow(sd::math::nd4j_cos(d1), static_cast(2.0f)); - } - }; - - template - class ATan { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return sd::math::nd4j_atan(d1); - } - }; - - template - class Atan2 { - public: - no_op_exec_special - no_op_exec_special_cuda - - op_def static Z op(X d1, Y d2) { - return sd::math::nd4j_atan2(d2, d1); - } - - op_def static Z op(X d1, Y d2, Z *params) { - return op(d1, d2); - } - - // op for MetaOps - op_def static Z op(X d1, Y *params) { - return op(d1, params[0]); - } - }; - - - template - class Identity { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return d1; - } - }; - - - template - class Stabilize { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - X k = params[0]; - if (d1 * k > static_cast(- MIN_CUTFOFF)) - return static_cast(- MIN_CUTFOFF) / k; - else if (d1 * k < static_cast(MIN_CUTFOFF)) - return static_cast(MIN_CUTFOFF) / k; - return d1; - } - }; - - - - template - class Step { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static Z op(X d1, Y d2, Z *params) { - return (d1 > static_cast(d2) ? static_cast(1) : static_cast(0)); - } - }; - - - - template - class OneMinus { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return static_cast(1) - d1; - } - }; - - template - class Sum { - public: - no_op_exec_special_accumulation_same - no_op_exec_special_accumulation_same_cuda - - op_def static X startingValue(const X *input) { - return static_cast(0.0f); - } - - op_def static X merge(X old, X opOutput, X *extraParams) { - return opOutput + old; - } - - op_def static X update(X old, X opOutput, X *extraParams) { - return opOutput + old; - } - - op_def static X op(X d1, X *extraParams) { - return d1; - } - - op_def static X postProcess(X reduction, Nd4jLong n, X *extraParams) { - return reduction; - } - }; - - template - class ReduceSameBenchmarkOp { - public: - no_op_exec_special_accumulation_same - no_op_exec_special_accumulation_same_cuda - - const static functions::ReduceType reduceType = functions::ReduceType::SUM; - - op_def static X startingValue(const X *input) { - return static_cast(0.0f); - } - - op_def static X merge(X old, X opOutput, X *extraParams) { - return opOutput + old; - } - - op_def static X update(X old, X opOutput, X *extraParams) { - return opOutput + old; - } - - op_def static X op(X d1, X *extraParams) { - auto f1 = static_cast(d1); - return static_cast(sd::math::nd4j_pow(f1, 3) - + sd::math::nd4j_log(f1) * sd::math::nd4j_sin(f1) - / sd::math::nd4j_tanh(static_cast(M_E) * static_cast(M_PI) * f1) - * sd::math::nd4j_sqrt(static_cast(M_PI) / f1) - - sd::math::nd4j_atan(static_cast(M_E) / f1)); - } - - op_def static X postProcess(X reduction, Nd4jLong n, X *extraParams) { - return reduction; - } - }; - - - template - class ShannonEntropy { - public: - no_op_exec_special_accumulation - no_op_exec_special_accumulation_cuda - - const static functions::ReduceType reduceType = functions::ReduceType::SUM; - - op_def static X startingValue(const X *input) { - return static_cast(0); - } - - op_def static Z merge(Z old, Z opOutput, Z *extraParams) { - return opOutput + old; - } - - op_def static Z update(Z old, Z opOutput, Z *extraParams) { - return opOutput + old; - } - - op_def static Z op(X d1, Z *extraParams) { - auto p = d1 * d1; - return static_cast(p) * sd::math::nd4j_log(p); - } - - op_def static Z postProcess(Z reduction, Nd4jLong n, Z *extraParams) { - return -reduction; - } - }; - - - template - class LogEntropy { - public: - no_op_exec_special_accumulation - no_op_exec_special_accumulation_cuda - - const static functions::ReduceType reduceType = functions::ReduceType::SUM; - - op_def static X startingValue(const X *input) { - return static_cast(0); - } - - op_def static Z merge(Z old, Z opOutput, Z *extraParams) { - return opOutput + old; - } - - op_def static Z update(Z old, Z opOutput, Z *extraParams) { - return opOutput + old; - } - - op_def static Z op(X d1, Z *extraParams) { - return static_cast(d1) * sd::math::nd4j_log(d1); - } - - op_def static Z postProcess(Z reduction, Nd4jLong n, Z *extraParams) { - //entropy is -sum(p(x) * log(p(x))); log entropy is log of this - return sd::math::nd4j_log(-reduction); - } - }; - - template - class Entropy { - public: - no_op_exec_special_accumulation - no_op_exec_special_accumulation_cuda - - const static functions::ReduceType reduceType = functions::ReduceType::SUM; - - op_def static X startingValue(const X *input) { - return static_cast(0); - } - - op_def static Z merge(Z old, Z opOutput, Z *extraParams) { - return opOutput + old; - } - - op_def static Z update(Z old, Z opOutput, Z *extraParams) { - return opOutput + old; - } - - op_def static Z op(X d1, Z *extraParams) { - return static_cast(d1) * sd::math::nd4j_log(d1); - } - - op_def static Z postProcess(Z reduction, Nd4jLong n, Z *extraParams) { - return static_cast(-reduction); //entropy is -sum(p(x) * log(p(x))) - } - }; - - - template - class ASum { - public: - no_op_exec_special_accumulation_same - no_op_exec_special_accumulation_same_cuda - - const static functions::ReduceType reduceType = functions::ReduceType::ASUM; - - op_def static X startingValue(const X *input) { - return static_cast(0); - } - - op_def static X merge(X old, X opOutput, X *extraParams) { - return sd::math::nd4j_abs(opOutput) + sd::math::nd4j_abs(old); - } - - op_def static X update(X old, X opOutput, X *extraParams) { - return sd::math::nd4j_abs(opOutput) + sd::math::nd4j_abs(old); - } - - op_def static X op(X d1, X *extraParams) { - return sd::math::nd4j_abs(d1); - } - - op_def static X postProcess(X reduction, Nd4jLong n, X *extraParams) { - return sd::math::nd4j_abs(reduction); - } - }; - - - template - class CountNonZero { - public: - no_op_exec_special_accumulation_long - no_op_exec_special_accumulation_cuda - - const static functions::ReduceType reduceType = functions::ReduceType::ASUM; - - op_def static Z startingValue(const X *input) { - return static_cast(0); - } - - op_def static Z merge(Z old, Z opOutput, X *extraParams) { - return opOutput + old; - } - - op_def static Z update(Z old, Z opOutput, X *extraParams) { - return opOutput + old; - } - - op_def static Z op(X d1, X *extraParams) { - return d1 == static_cast(0.0f) ? static_cast(0.0f) : static_cast(1.0f); - } - - op_def static Z postProcess(Z reduction, Nd4jLong n, X *extraParams) { - return reduction; - } - }; + op_def static Z op(X d1, Y d2, Z *params) { return op(d1, d2); } + // op for MetaOp + op_def static Z op(X d1, Y *params) { return op(d1, params[0]); } +}; - template - class CountZero { - public: - no_op_exec_special_accumulation_long - no_op_exec_special_accumulation_cuda +template +class ReverseMod { + public: + op_def static Z op(X d1, Y d2) { + return static_cast(d2) % static_cast(d1); + } - const static functions::ReduceType reduceType = functions::ReduceType::SUM; + op_def static Z op(X d1, Y d2, Z *params) { return op(d1, d2); } - op_def static Z startingValue(const X *input) { - return static_cast(0.0f); - } + // op for MetaOp + op_def static Z op(X d1, Y *params) { return op(d1, params[0]); } +}; - op_def static Z merge(Z old, Z opOutput, X *extraParams) { - return opOutput + old; - } +/** + * Whether 2 elements in an array + * are epsilion equal + */ +template +class Epsilon { + public: + op_def static Z op(X d1, X d2) { + X diff = d1 - d2; + X absDiff = sd::math::nd4j_abs(diff); + if (absDiff <= static_cast(MIN_V)) return static_cast(1); + return static_cast(0); + } - op_def static Z update(Z old, Z opOutput, X *extraParams) { - return opOutput + old; - } + op_def static Z op(X d1, X d2, X *params) { return op(d1, d2); } - op_def static Z op(X d1, X *extraParams) { - return d1 == static_cast(0) ? static_cast(1) : static_cast(0); - } + op_def static Z op(X d1, X *params) { return d1; } +}; - op_def static Z postProcess(X reduction, Nd4jLong n, X *extraParams) { - return static_cast(reduction); - } - }; +template +class EqualTo { + public: + op_def static Z op(X d1, X d2) { return d1 == d2; } + + op_def static Z op(X d1, X d2, X *params) { return op(d1, d2); } - template - class Prod { - public: - no_op_exec_special_accumulation_same - no_op_exec_special_accumulation_same_cuda + op_def static Z op(X d1, X *params) { return d1; } +}; + +template +class NotEqualTo { + public: + op_def static Z op(X d1, X d2) { return d1 != d2; } + + op_def static Z op(X d1, X d2, X *params) { return op(d1, d2); } + + op_def static Z op(X d1, X *params) { return d1; } +}; + +template +class GreaterThanOrEqual { + public: + op_def static Z op(X d1, X d2) { return d1 >= d2; } + + op_def static Z op(X d1, X d2, X *params) { return op(d1, d2); } + + // FIXME: this signature clashes with MetaOp stuff + op_def static Z op(X d1, X *params) { return d1; } +}; + +template +class GreaterThan { + public: + op_def static Z op(X d1, X d2) { return d1 > d2; } + + op_def static Z op(X d1, X d2, X *params) { return op(d1, d2); } + + // FIXME: this signature clashes with MetaOp stuff + op_def static Z op(X d1, X *params) { return d1; } +}; + +template +class LessThan { + public: + op_def static Z op(X d1, X d2) { return d1 < d2; } + + op_def static Z op(X d1, X d2, X *params) { return op(d1, d2); } + + op_def static Z op(X d1, X *params) { return d1; } +}; + +template +class LessThanOrEqual { + public: + op_def static Z op(X d1, X d2) { return d1 <= d2; } + + op_def static Z op(X d1, X d2, X *params) { return op(d1, d2); } + + op_def static Z op(X d1, X *params) { return d1; } +}; + +template +class Abs { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return sd::math::nd4j_abs(d1); + } +}; + +template +class Ceiling { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return sd::math::nd4j_ceil(d1); + } +}; + +template +class Cosine { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return sd::math::nd4j_cos(d1); + } +}; + +template +class Exp { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return sd::math::nd4j_exp(d1); + } +}; + +template +class HardTanhDerivative { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return ((d1 >= static_cast(-1.f) && d1 <= static_cast(1.f)) + ? static_cast(1.f) + : static_cast(0.f)); + } +}; + +template +class HardTanh { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + if (d1 < static_cast(-1)) + return static_cast(-1); + else if (d1 > static_cast(1)) + return static_cast(1); + else + return d1; + } +}; + +template +class Floor { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return sd::math::nd4j_floor(d1); + } +}; + +template +class Log { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return sd::math::nd4j_log(d1); + } +}; + +template +class Log1p { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return sd::math::nd4j_log(1 + d1); + } +}; + +template +class LogX { + public: + op_def static Z op(X d1, Y d2, Z *params) { + return sd::math::nd4j_log(d1) / sd::math::nd4j_log(d2); + } +}; + +template +class StabilizeFP16 { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + if (d1 <= static_cast(0)) + return static_cast(sd::DataTypeUtils::min()); + else + return d1; + } +}; + +template +class StabilizeX { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + if (d1 <= static_cast(0)) + return sd::DataTypeUtils::min(); + else + return d1; + } +}; + +template +class SpecialDerivative { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return d1 * (static_cast(1.f) - d1); + } +}; + +template +class Neg { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return -d1; + } +}; + +template +class Erf { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return sd::math::nd4j_erf(d1); + } +}; + +template +class Erfc { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return sd::math::nd4j_erfc(d1); + } +}; + +template +class Reciprocal { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + // op_def static T op(T d1) { + // return (T(1.0f) / d1); + // } + // op for MetaOps + op_def static X + op(X d1, X *params) { + return (static_cast(1) / d1); + } +}; + +template +class Sqr { + public: + no_op_exec_special no_op_exec_special_cuda + + op_def static Z + op(X d1, Z *params) { + return sd::math::nd4j_pow(d1, static_cast(2)); + } + + op_def static Z op(X d1) { + return sd::math::nd4j_pow(d1, static_cast(2)); + } +}; + +template +class RelativeError { + public: + no_op_exec_special no_op_exec_special_cuda + + op_def static Z + op(X d1, Y d2) { + return sd::math::nd4j_re(d1, d2); + } + + op_def static Z op(X d1, Y d2, Z *params) { return op(d1, d2); } + + op_def static Z op(X d1) { return static_cast(0); } +}; + +template +class BinaryRelativeError { + public: + no_op_exec_special no_op_exec_special_cuda + + op_def static Z + op(X d1, Y d2, Z *params) { + X threshold = params[0]; + return sd::math::nd4j_re(d1, d2) > threshold ? static_cast(1) + : static_cast(0); + } + + op_def static Z op(X d1) { return static_cast(0); } +}; + +template +class BinaryMinimumAbsoluteRelativeError { + public: + no_op_exec_special no_op_exec_special_cuda + + op_def static Z + op(X d1, X *params) { + X d2 = params[0]; + X thresholdRelative = params[1]; + X thresholdAbsolute = params[2]; + return sd::math::nd4j_re(d1, d2) > thresholdRelative + ? (sd::math::nd4j_abs(d1 - static_cast(d2)) < + thresholdAbsolute + ? static_cast(0) + : static_cast(1)) + : static_cast(0); + } + + op_def static Z op(X d1, Y d2, Z *params) { + X thresholdRelative = params[0]; + X thresholdAbsolute = params[1]; + return sd::math::nd4j_re(d1, d2) > thresholdRelative + ? (sd::math::nd4j_abs(d1 - static_cast(d2)) < + thresholdAbsolute + ? static_cast(0) + : static_cast(1)) + : static_cast(0); + } + + op_def static Z op(X d1) { return static_cast(0); } +}; + +template +class ReversePow { + public: + no_op_exec_special no_op_exec_special_cuda + + op_def static Z + op(X d1, Z *params) { + return sd::math::nd4j_pow(params[0], d1); + } + + op_def static Z op(X d1, Y d2) { return sd::math::nd4j_pow(d2, d1); } + + op_def static Z op(X d1, Y d2, Z *params) { + return sd::math::nd4j_pow(d2, d1); + } + + op_def static Z op(X d1) { return d1; } +}; + +template +class Pow { + public: + no_op_exec_special no_op_exec_special_cuda + + op_def static Z + op(X d1, Z *params) { + return sd::math::nd4j_pow(d1, params[0]); + } + + op_def static Z op(X d1, Y d2) { return sd::math::nd4j_pow(d1, d2); } + + op_def static Z op(X d1, Y d2, Z *params) { + return sd::math::nd4j_pow(d1, d2); + } + + op_def static Z op(X d1) { return d1; } +}; + +template +class PowDerivative { + public: + no_op_exec_special no_op_exec_special_cuda + + op_def static Z + op(X d1, Z *params) { + return params[0] * sd::math::nd4j_pow( + d1, static_cast(params[0]) - static_cast(1.f)); + } + + op_def static Z op(X d1, Y d2) { + return static_cast(d2) * + sd::math::nd4j_pow( + d1, static_cast(d2) - static_cast(1.f)); + } + + op_def static Z op(X d1, Y d2, Z *params) { + return static_cast(d2) * + sd::math::nd4j_pow( + d1, static_cast(d2) - static_cast(1.f)); + } + + op_def static Z op(X d1) { return static_cast(d1); } +}; + +template +class IGamma { + public: + no_op_exec_special no_op_exec_special_cuda + + op_def static Z + op(X d1, Z *params) { + return sd::math::nd4j_igamma(d1, params[0]); + } + + op_def static Z op(X d1, Y d2) { + return sd::math::nd4j_igamma(d1, d2); + } + + op_def static Z op(X d1, Y d2, Z *params) { + return sd::math::nd4j_igamma(d1, d2); + } + + op_def static Z op(X d1) { return d1; } +}; + +template +class IGammac { + public: + no_op_exec_special no_op_exec_special_cuda + + op_def static Z + op(X d1, Z *params) { + return sd::math::nd4j_igammac(d1, params[0]); + } + + op_def static Z op(X d1, Y d2) { + return sd::math::nd4j_igammac(d1, d2); + } + + op_def static Z op(X d1, Y d2, Z *params) { + return sd::math::nd4j_igammac(d1, d2); + } + + op_def static Z op(X d1) { return d1; } +}; + +template +class Round { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return sd::math::nd4j_round(d1); + } +}; + +template +class IsNan { + public: + no_op_exec_special_bool no_op_exec_special_bool_cuda + + no_op_exec_special_accumulation no_op_exec_special_accumulation_cuda + + op_def static Z + op(X d1, X *params) { + return sd::math::nd4j_isnan(d1) ? static_cast(1) : static_cast(0); + } + + op_def static X startingValue(const X *input) { return static_cast(0); } + + op_def static Z merge(X old, X opOutput, X *extraParams) { + return opOutput + old; + } + + op_def static Z update(X old, X opOutput, X *extraParams) { + return opOutput + old; + } + + op_def static Z postProcess(X reduction, Nd4jLong n, X *extraParams) { + return reduction; + } +}; + +template +class Expm1 { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return sd::math::nd4j_exp(d1) - static_cast(1); + } +}; + +template +class IsPositive { + public: + no_op_exec_special_bool no_op_exec_special_bool_cuda + + no_op_exec_special_accumulation no_op_exec_special_accumulation_cuda + + op_def static Z + op(X d1, X *params) { + return d1 > (X)0.f; + } + + op_def static X startingValue(const X *input) { return static_cast(0); } + + op_def static Z merge(X old, X opOutput, X *extraParams) { + return opOutput + old; + } + + op_def static Z update(X old, X opOutput, X *extraParams) { + return opOutput + old; + } + + op_def static Z postProcess(X reduction, Nd4jLong n, X *extraParams) { + return reduction; + } +}; - const static functions::ReduceType reduceType = functions::ReduceType::PRODUCT; +template +class IsNegative { + public: + no_op_exec_special_bool no_op_exec_special_bool_cuda + + no_op_exec_special_accumulation no_op_exec_special_accumulation_cuda + + op_def static Z + op(X d1, X *params) { + return d1 < (X)0.f; + } + + op_def static X startingValue(const X *input) { return static_cast(0); } + + op_def static Z merge(X old, X opOutput, X *extraParams) { + return opOutput + old; + } + + op_def static Z update(X old, X opOutput, X *extraParams) { + return opOutput + old; + } + + op_def static Z postProcess(X reduction, Nd4jLong n, X *extraParams) { + return reduction; + } +}; + +template +class IsInf { + public: + no_op_exec_special_bool no_op_exec_special_bool_cuda + + no_op_exec_special_accumulation no_op_exec_special_accumulation_cuda + + op_def static Z + op(X d1, X *params) { + return sd::math::nd4j_isinf(d1) ? static_cast(1) : static_cast(0); + } + + op_def static X startingValue(const X *input) { return static_cast(0); } + + op_def static Z merge(X old, X opOutput, X *extraParams) { + return opOutput + old; + } + + op_def static Z update(X old, X opOutput, X *extraParams) { + return opOutput + old; + } + + op_def static Z postProcess(X reduction, Nd4jLong n, X *extraParams) { + return reduction; + } +}; + +template +class IsInfOrNan { + public: + no_op_exec_special_bool no_op_exec_special_bool_cuda + + no_op_exec_special_accumulation no_op_exec_special_accumulation_cuda + + op_def static Z + op(X d1, X *params) { + return sd::math::nd4j_isfin(d1) ? static_cast(0) : static_cast(1); + } + + op_def static X startingValue(const X *input) { return static_cast(0); } + + op_def static Z merge(X old, X opOutput, X *extraParams) { + return opOutput == static_cast(0) && old == static_cast(0) + ? static_cast(0) + : static_cast(1); + } + + op_def static Z update(X old, X opOutput, X *extraParams) { + return opOutput == static_cast(0) && old == static_cast(0) + ? static_cast(0) + : static_cast(1); + } + + op_def static Z postProcess(X reduction, Nd4jLong n, X *extraParams) { + return reduction != static_cast(0); + } +}; + +template +class IsFinite { + public: + no_op_exec_special_bool no_op_exec_special_bool_cuda + + no_op_exec_special_accumulation no_op_exec_special_accumulation_cuda + + op_def static Z + op(X d1, X *params) { + return sd::math::nd4j_isfin(d1) ? static_cast(1) : static_cast(0); + } + + op_def static X startingValue(const X *input) { return static_cast(1); } + + op_def static Z merge(X old, X opOutput, X *extraParams) { + return opOutput == static_cast(0) || old == static_cast(0) + ? static_cast(0) + : static_cast(1); + } + + op_def static Z update(X old, X opOutput, X *extraParams) { + return opOutput == static_cast(0) || old == static_cast(0) + ? static_cast(0) + : static_cast(1); + } + + op_def static Z postProcess(X reduction, Nd4jLong n, X *extraParams) { + return reduction != static_cast(0); + } +}; + +template +class ClipByValue { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + if (d1 > params[1]) return params[1]; + if (d1 < params[0]) return params[0]; + return d1; + } +}; + +template +class LstmClip { + public: + no_op_exec_special no_op_exec_special_cuda + + op_def static Z + op(X d1, Y d2, Z *params) { + X _v = (X)d2; + if (d1 > _v) + return _v; + else if (d1 < -_v) + return -_v; + else + return d1; + } +}; + +template +class Swish { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return d1 * sd::math::nd4j_sigmoid(d1); + } +}; + +template +class Mish { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return d1 * sd::math::nd4j_tanh(sd::math::nd4j_softplus(d1)); + } +}; + +template +class MishDerivative { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + auto ex = sd::math::nd4j_exp(d1); + auto e2x = ex * ex; + auto e3x = ex * ex * ex; + + return (ex * (4 * (d1 + 1) + 4 * e2x + e3x + ex * (4 * d1 + 6))) / + sd::math::nd4j_pow((2 * ex + e2x + 2), (X)2.f); + } +}; + +template +class GELU { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return d1 * sd::math::nd4j_sigmoid(static_cast(1.702f) * d1); + } +}; + +template +class PreciseGELU { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + auto sp = + sd::math::nd4j_sqrt(static_cast(2) / static_cast(M_PI)); + auto xp = d1 + sd::math::nd4j_pow(static_cast(0.044715) * d1, + static_cast(3)); + return (d1 / static_cast(2)) * + (static_cast(1) + sd::math::nd4j_tanh(sp * xp)); + } +}; + +template +class GELUDerivative { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + auto x17 = static_cast(1.702f) * d1; + auto ep = sd::math::nd4j_pow(static_cast(M_E), x17); + // (E^(1.702 x) (1. + E^(1.702 x) + 1.702 x))/(1. + E^(1.702 x))^2 + return (ep * (static_cast(1.f) + ep + x17)) / + sd::math::nd4j_pow((static_cast(1.f) + ep), 2); + } +}; + +template +class PreciseGELUDerivative { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + auto x79 = static_cast(0.797885) * d1; + auto x03 = sd::math::nd4j_pow(static_cast(0.0356774) * d1, 3); + auto x39 = static_cast(0.398942) * d1; + auto x05 = sd::math::nd4j_pow(static_cast(0.0535161) * d1, 3); + auto scz = sd::math::nd4j_sech(x79 + x03); + // 0.5 + (0.398942 x + 0.0535161 x^3) Sech[0.797885 x + 0.0356774 x^3]^2 + + // 0.5 Tanh[0.797885 x + 0.0356774 x^3] + return static_cast(0.5) + (x39 + x05) * (scz * scz) + + static_cast(0.5) * sd::math::nd4j_tanh(x79 + x03); + } +}; + +template +class SwishDerivative { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + X ex = sd::math::nd4j_pow(static_cast(M_E), d1); + return (ex * (d1 + ex + static_cast(1.f))) / + sd::math::nd4j_pow((ex + static_cast(1.f)), + static_cast(2.f)); + } +}; + +template +class LogSigmoid { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return sd::math::nd4j_log(sd::math::nd4j_sigmoid(d1)); + } +}; + +template +class LogSigmoidDerivative { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + X ex = sd::math::nd4j_pow(M_E, d1); + return static_cast(1.f) / (ex + static_cast(1.f)); + } +}; + +template +class Sigmoid { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return sd::math::nd4j_sigmoid(d1); + } +}; + +template +class Affine { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return params[0] * d1 + params[1]; + } +}; + +template +class SigmoidDerivative { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return sd::math::nd4j_sigmoidderivative(d1); + } +}; + +template +class HardSigmoid { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return sd::math::nd4j_min( + static_cast(1), + sd::math::nd4j_max(static_cast(0), (static_cast(0.2f)) * d1 + + static_cast(0.5f))); + } +}; + +template +class HardSigmoidDerivative { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return d1 < static_cast(-2.5f) || d1 > static_cast(2.5f) + ? static_cast(0.f) + : static_cast(0.2f); + } +}; + +/** + * Scale to be between a min and max + */ +template +class SetRange { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + auto min = params[0]; + auto max = params[1]; + if (static_cast(d1) >= min && static_cast(d1) <= max) return d1; + if (min == static_cast(0) && max == static_cast(1)) { + auto val = static_cast(1) / + (static_cast(1) + sd::math::nd4j_exp(-d1)); + return (sd::math::nd4j_floor(val * (max - min)) + min); + } + + return (sd::math::nd4j_floor(d1 * (max - min)) + min); + } +}; + +template +class Sin { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return sd::math::nd4j_sin(d1); + } +}; + +template +class Square { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return d1 * d1; + } +}; + +template +class Sqrt { + public: + no_op_exec_special no_op_exec_special_cuda + + op_def static Z + op(X d1, Z *params) { + return sd::math::nd4j_sqrt(d1); + } +}; + +template +class RSqrt { + public: + no_op_exec_special no_op_exec_special_cuda + + op_def static Z + op(X d1, Z *params) { + return static_cast(1) / sd::math::nd4j_sqrt(d1); + } +}; + +template +class Rint { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return sd::math::nd4j_rint(d1); + } +}; + +template +class SoftPlus { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return sd::math::nd4j_softplus(d1); + } +}; + +template +class Sign { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return (d1 > static_cast(0)) - (d1 < static_cast(0)); + } +}; + +template +class TimesOneMinus { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return d1 * (static_cast(1) - d1); + } +}; + +template +class RationalTanh { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + // keep 2/3 as runtime variable, to match precision + auto dis = (static_cast(2) / static_cast(3)) * d1; + + auto tanh = + sd::math::nd4j_sgn(dis) * + (static_cast(1) - + (static_cast(1) / + (static_cast(1) + static_cast(sd::math::nd4j_abs(dis)) + + sd::math::nd4j_pow(dis, static_cast(2)) + + static_cast(1.41645f) * + sd::math::nd4j_pow(dis, static_cast(4))))); + return static_cast(1.7159f) * tanh; + } +}; + +template +class RationalTanhDerivative { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + auto dis = (static_cast(2.f) / static_cast(3.f)) * d1; + + auto a = static_cast(1.f) + sd::math::nd4j_abs(dis) + + sd::math::nd4j_pow(dis, static_cast(2.f)) + + static_cast(1.41645f) * + sd::math::nd4j_pow(dis, static_cast(4)); + + auto tDeriv = + (static_cast(1.f) + + sd::math::nd4j_sign(dis) * + (static_cast(2.f) * dis + + static_cast(4.f) * static_cast(1.41645f) * + sd::math::nd4j_pow(dis, static_cast(3)))) / + (a * a); + + return static_cast(1.7159f) * + (static_cast(2.f) / static_cast(3.f)) * tDeriv; + } +}; + +template +class Tanh { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return sd::math::nd4j_tanh(d1); + } +}; + +template +class ScaledTanh { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return params[0] * sd::math::nd4j_tanh(params[1] * d1); + } +}; + +template +class RectifiedTanh { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return sd::math::nd4j_max(static_cast(0), + sd::math::nd4j_tanh(d1)); + } +}; + +template +class RectifiedTanhDerivative { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return d1 > static_cast(0.f) ? sd::math::nd4j_tanhderivative(d1) + : static_cast(0.f); + } +}; + +template +class ATanh { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return sd::math::nd4j_atanh(d1); + } +}; + +template +class TanhDerivative { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return sd::math::nd4j_tanhderivative(d1); + } +}; + +template +class Cube { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return d1 * d1 * d1; + } +}; + +template +class CubeDerivative { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return static_cast(3) * d1 * d1; + } +}; + +template +class ACos { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return sd::math::nd4j_acos(d1); + } +}; + +template +class ASinh { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return sd::math::nd4j_asinh(d1); + } +}; + +template +class ASinhDerivative { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return static_cast(1.f) / + (sd::math::nd4j_sqrt( + sd::math::nd4j_pow(d1, static_cast(2.f)) + + static_cast(1.f))); + } +}; + +template +class ACosh { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return sd::math::nd4j_acosh(d1); + } +}; + +template +class ACoshDerivative { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return static_cast(1.f) / + (sd::math::nd4j_sqrt(d1 - static_cast(1.f)) * + sd::math::nd4j_sqrt(d1 + static_cast(1.f))); + } +}; + +template +class Ones { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return static_cast(1.0f); + } +}; + +template +class SoftSign { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return sd::math::nd4j_softsign(d1); + } +}; + +template +class SoftSignDerivative { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return sd::math::nd4j_softsignderivative(d1); + } +}; + +template +class MatchConditionBool { + public: + no_op_exec_special_bool no_op_exec_special_bool_cuda + + // this op return 1.0 if condition met, 0.0 otherwise + op_def static Z + op(X d1, X *extraParams) { + X compare = extraParams[0]; + X eps = extraParams[1]; + + auto mode = static_cast(extraParams[2]); + // nd4j_printf("value: %f; comp: %f; eps: %f; mode: %i;\n", d1, compare, + // eps, mode); + + switch (mode) { + case 0: // equals + return sd::math::nd4j_abs(d1 - compare) <= eps ? true : false; + case 1: // not equals + return sd::math::nd4j_abs(d1 - compare) > eps ? true : false; + case 2: // less_than + return d1 < compare ? true : false; + case 3: // greater_than + return d1 > compare ? true : false; + case 4: // less_or_equals_than + return d1 <= compare ? true : false; + case 5: // greater_or_equals_than + return d1 >= compare ? true : false; + case 6: // abs_less_than + return sd::math::nd4j_abs(d1) < compare ? true : false; + case 7: // abs_greater_than + return sd::math::nd4j_abs(d1) > compare ? true : false; + case 8: // is inf + return sd::math::nd4j_isinf(d1) ? true : false; + case 9: // is nan + return sd::math::nd4j_isnan(d1) ? true : false; + case 10: + return (d1 == compare) ? true : false; + case 11: + return (d1 != compare) ? true : false; + case 12: // abs_greater_or_equals_than + return sd::math::nd4j_abs(d1) >= compare ? true : false; + case 13: // abs_less_or_equals_than + return sd::math::nd4j_abs(d1) <= compare ? true : false; + case 14: + // isFinite + return !(sd::math::nd4j_isinf(d1) || sd::math::nd4j_isnan(d1)); + case 15: + // isInfinite + return sd::math::nd4j_isinf(d1) || sd::math::nd4j_isnan(d1); + default: + printf("Undefined match condition: [%i]\n", mode); + } + + return d1; + } +}; + +template +class MatchCondition { + public: + no_op_exec_special no_op_exec_special_cuda + + no_op_exec_special_accumulation_long no_op_exec_special_accumulation_cuda + + op_def static Z + startingValue(const X *input) { + return static_cast(0); + } + + op_def static Z merge(Z old, Z opOutput, X *extraParams) { + return old + opOutput; + } + + op_def static Z update(Z old, Z opOutput, X *extraParams) { + return old + opOutput; + } + + op_def static Z op(X d1, X compare, X eps, int mode) { + switch (mode) { + case 0: // equals + return sd::math::nd4j_abs(d1 - compare) <= eps ? 1 : 0; + case 1: // not equals + return sd::math::nd4j_abs(d1 - compare) > eps ? 1 : 0; + case 2: // less_than + return d1 < compare ? 1 : 0; + case 3: // greater_than + return d1 > compare ? 1 : 0; + case 4: // less_or_equals_than + return d1 <= compare ? 1 : 0; + case 5: // greater_or_equals_than + return d1 >= compare ? 1 : 0; + case 6: // abs_less_than + return sd::math::nd4j_abs(d1) < compare ? 1 : 0; + case 7: // abs_greater_than + return sd::math::nd4j_abs(d1) > compare ? 1 : 0; + case 8: // is inf + return sd::math::nd4j_isinf(d1) ? 1 : 0; + case 9: // is nan + return sd::math::nd4j_isnan(d1) ? 1 : 0; + case 10: + return (d1 == compare) ? 1 : 0; + case 11: + return (d1 != compare) ? 1 : 0; + case 12: // abs_greater_or_equals_than + return sd::math::nd4j_abs(d1) >= compare ? 1 : 0; + case 13: // abs_less_or_equals_than + return sd::math::nd4j_abs(d1) <= compare ? 1 : 0; + case 14: + // isFinite + return !(sd::math::nd4j_isinf(d1) || sd::math::nd4j_isnan(d1)) ? 1 : 0; + case 15: + // isInfinite + return sd::math::nd4j_isinf(d1) || sd::math::nd4j_isnan(d1) ? 1 : 0; + default: + printf("Undefined match condition: [%i]\n", mode); + } + + return d1; + } + + // this op return 1.0 if condition met, 0.0 otherwise + op_def static Z op(X d1, X compare, X *extraParams) { + X eps = extraParams[1]; + + auto mode = static_cast(extraParams[0]); + + return op(d1, compare, eps, mode); + } + + // this op return 1.0 if condition met, 0.0 otherwise + op_def static Z op(X d1, X *extraParams) { + X compare = extraParams[0]; + X eps = extraParams[1]; + + auto mode = static_cast(extraParams[2]); + + return op(d1, compare, eps, mode); + } + + op_def static Z postProcess(Z reduction, Nd4jLong n, X *extraParams) { + return reduction; + } +}; + +template +class ELU { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static Z + op(X d1, Y d2, Z *params) { + return sd::math::nd4j_elu(d1, static_cast(d2)); + } +}; + +template +class ELUDerivative { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static Z + op(X d1, Y d2, Z *params) { + return sd::math::nd4j_eluderivative(d1, static_cast(d2)); + } +}; + +template +class RELU { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static Z + op(X d1, Y d2, Z *params) { + auto xt = static_cast(d1); + auto xf = static_cast(d2); + return xt < xf ? xf : xt; + } +}; + +template +class RELUDerivative { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static Z + op(X d1, Y d2, Z *params) { + auto xt = static_cast(d1); + auto xf = static_cast(d2); + return xt > xf ? static_cast(1.f) : static_cast(0.f); + } +}; + +template +class SXELogitsSmoother { + public: + op_def static Z op(X d1, Y d2, Z *params) { + return d1 * ((X)1.f - (X)d2) + (X)(0.5f) * (X)d2; + } +}; + +template +class RELU6 { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static Z + op(X d1, Y d2, Z *params) { + auto relu = simdOps::RELU::op(d1, d2, params); + return relu < static_cast(6) ? relu : static_cast(6); + } +}; + +template +class LeakyRELU { + public: + no_op_exec_special no_op_exec_special_cuda + + op_def static Z + op(X d1, Y d2, Z *params) { + auto val = static_cast(d1); + auto alpha = static_cast(d2); + return val < 0.0f ? alpha * val : val; + } +}; + +template +class SELU { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return d1 > static_cast(0.0f) + ? static_cast(SELU_LAMBDA) * static_cast(d1) + : static_cast(SELU_LAMBDA) * + (static_cast(SELU_ALPHA) * + sd::math::nd4j_exp(d1) - + static_cast(SELU_ALPHA)); + } +}; + +template +class SELUDerivative { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return d1 > static_cast(0.f) + ? static_cast(SELU_LAMBDA) + : static_cast(SELU_ALPHA) * static_cast(SELU_LAMBDA) * + sd::math::nd4j_exp(d1); + } +}; + +template +class LeakyRELUDerivative { + public: + no_op_exec_special no_op_exec_special_cuda + + op_def static Z + op(X d1, Y d2, Z *params) { + if (d1 >= static_cast(0)) + return static_cast(1); + else + return static_cast(d2); + } +}; + +template +class ASin { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return sd::math::nd4j_asin(d1); + } +}; + +template +class Sinh { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return sd::math::nd4j_sinh(d1); + } +}; + +template +class SinhDerivative { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return sd::math::nd4j_cosh(d1); + } +}; + +template +class Cosh { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return sd::math::nd4j_cosh(d1); + } +}; + +template +class Tan { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return sd::math::nd4j_tan(d1); + } +}; + +template +class TanDerivative { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return static_cast(1.f) / + sd::math::nd4j_pow(sd::math::nd4j_cos(d1), + static_cast(2.0f)); + } +}; + +template +class ATan { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return sd::math::nd4j_atan(d1); + } +}; + +template +class Atan2 { + public: + no_op_exec_special no_op_exec_special_cuda + + op_def static Z + op(X d1, Y d2) { + return sd::math::nd4j_atan2(d2, d1); + } + + op_def static Z op(X d1, Y d2, Z *params) { return op(d1, d2); } + + // op for MetaOps + op_def static Z op(X d1, Y *params) { return op(d1, params[0]); } +}; + +template +class Identity { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return d1; + } +}; + +template +class Stabilize { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + X k = params[0]; + if (d1 * k > static_cast(-MIN_CUTFOFF)) + return static_cast(-MIN_CUTFOFF) / k; + else if (d1 * k < static_cast(MIN_CUTFOFF)) + return static_cast(MIN_CUTFOFF) / k; + return d1; + } +}; + +template +class Step { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static Z + op(X d1, Y d2, Z *params) { + return (d1 > static_cast(d2) ? static_cast(1) : static_cast(0)); + } +}; + +template +class OneMinus { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return static_cast(1) - d1; + } +}; + +template +class Sum { + public: + no_op_exec_special_accumulation_same no_op_exec_special_accumulation_same_cuda + + op_def static X + startingValue(const X *input) { + return static_cast(0.0f); + } + + op_def static X merge(X old, X opOutput, X *extraParams) { + return opOutput + old; + } + + op_def static X update(X old, X opOutput, X *extraParams) { + return opOutput + old; + } + + op_def static X op(X d1, X *extraParams) { return d1; } + + op_def static X postProcess(X reduction, Nd4jLong n, X *extraParams) { + return reduction; + } +}; + +template +class ReduceSameBenchmarkOp { + public: + no_op_exec_special_accumulation_same no_op_exec_special_accumulation_same_cuda + + const static functions::ReduceType reduceType = + functions::ReduceType::SUM; + + op_def static X startingValue(const X *input) { return static_cast(0.0f); } + + op_def static X merge(X old, X opOutput, X *extraParams) { + return opOutput + old; + } + + op_def static X update(X old, X opOutput, X *extraParams) { + return opOutput + old; + } + + op_def static X op(X d1, X *extraParams) { + auto f1 = static_cast(d1); + return static_cast( + sd::math::nd4j_pow(f1, 3) + + sd::math::nd4j_log(f1) * + sd::math::nd4j_sin(f1) / + sd::math::nd4j_tanh(static_cast(M_E) * + static_cast(M_PI) * f1) * + sd::math::nd4j_sqrt(static_cast(M_PI) / f1) - + sd::math::nd4j_atan(static_cast(M_E) / f1)); + } + + op_def static X postProcess(X reduction, Nd4jLong n, X *extraParams) { + return reduction; + } +}; + +template +class ShannonEntropy { + public: + no_op_exec_special_accumulation no_op_exec_special_accumulation_cuda + + const static functions::ReduceType reduceType = + functions::ReduceType::SUM; + + op_def static X startingValue(const X *input) { return static_cast(0); } + + op_def static Z merge(Z old, Z opOutput, Z *extraParams) { + return opOutput + old; + } + + op_def static Z update(Z old, Z opOutput, Z *extraParams) { + return opOutput + old; + } + + op_def static Z op(X d1, Z *extraParams) { + auto p = d1 * d1; + return static_cast(p) * sd::math::nd4j_log(p); + } + + op_def static Z postProcess(Z reduction, Nd4jLong n, Z *extraParams) { + return -reduction; + } +}; + +template +class LogEntropy { + public: + no_op_exec_special_accumulation no_op_exec_special_accumulation_cuda + + const static functions::ReduceType reduceType = + functions::ReduceType::SUM; + + op_def static X startingValue(const X *input) { return static_cast(0); } + + op_def static Z merge(Z old, Z opOutput, Z *extraParams) { + return opOutput + old; + } + + op_def static Z update(Z old, Z opOutput, Z *extraParams) { + return opOutput + old; + } - op_def static X startingValue(const X *input) { - return static_cast(1); - } + op_def static Z op(X d1, Z *extraParams) { + return static_cast(d1) * sd::math::nd4j_log(d1); + } - op_def static X merge(X old, X opOutput, X *extraParams) { - return opOutput * old; - } + op_def static Z postProcess(Z reduction, Nd4jLong n, Z *extraParams) { + // entropy is -sum(p(x) * log(p(x))); log entropy is log of this + return sd::math::nd4j_log(-reduction); + } +}; - op_def static X update(X old, X opOutput, X *extraParams) { - return opOutput * old; - } - - op_def static X op(X d1, X *extraParams) { - return d1; - } - - op_def static X postProcess(X reduction, Nd4jLong n, X *extraParams) { - return reduction; - } - }; - - - template - class Any { - public: - no_op_exec_special_accumulation - no_op_exec_special_accumulation_cuda - - const static functions::ReduceType reduceType = functions::ReduceType::SUM; - - op_def static X startingValue(const X *input) { - return static_cast(0.0f); - } - - op_def static Z merge(X old, X opOutput, X *extraParams) { - return opOutput + old; - } - - op_def static Z update(X old, X opOutput, X *extraParams) { - return opOutput + old; - } - - op_def static Z op(X d1, X *extraParams) { - return d1; - } - - op_def static Z postProcess(X reduction, Nd4jLong n, X *extraParams) { - return reduction > static_cast(0) ? static_cast(1) : static_cast(0) ; - } - }; - - - template - class All { - public: - no_op_exec_special_accumulation - no_op_exec_special_accumulation_cuda - - const static functions::ReduceType reduceType = functions::ReduceType::PRODUCT; - - op_def static X startingValue(const X *input) { - return static_cast(1); - } +template +class Entropy { + public: + no_op_exec_special_accumulation no_op_exec_special_accumulation_cuda + + const static functions::ReduceType reduceType = + functions::ReduceType::SUM; - op_def static Z merge(X old, X opOutput, X *extraParams) { - return opOutput * old; - } + op_def static X startingValue(const X *input) { return static_cast(0); } + + op_def static Z merge(Z old, Z opOutput, Z *extraParams) { + return opOutput + old; + } - op_def static Z update(X old, X opOutput, X *extraParams) { - return opOutput * old; - } + op_def static Z update(Z old, Z opOutput, Z *extraParams) { + return opOutput + old; + } - op_def static Z op(X d1, X *extraParams) { - return d1; - } + op_def static Z op(X d1, Z *extraParams) { + return static_cast(d1) * sd::math::nd4j_log(d1); + } - op_def static Z postProcess(X reduction, Nd4jLong n, X *extraParams) { - return reduction > static_cast(0) ? static_cast(1) : static_cast(0); - } - }; + op_def static Z postProcess(Z reduction, Nd4jLong n, Z *extraParams) { + return static_cast(-reduction); // entropy is -sum(p(x) * log(p(x))) + } +}; - template - class Mean { - public: - no_op_exec_special_accumulation - no_op_exec_special_accumulation_cuda +template +class ASum { + public: + no_op_exec_special_accumulation_same no_op_exec_special_accumulation_same_cuda - const static functions::ReduceType reduceType = functions::ReduceType::SUM; + const static functions::ReduceType reduceType = + functions::ReduceType::ASUM; - op_def static X startingValue(const X *input) { - return static_cast(0); - } + op_def static X startingValue(const X *input) { return static_cast(0); } - op_def static Z merge(Z old, Z opOutput, Z *extraParams) { - return opOutput + old; - } + op_def static X merge(X old, X opOutput, X *extraParams) { + return sd::math::nd4j_abs(opOutput) + sd::math::nd4j_abs(old); + } - op_def static Z update(Z old, Z opOutput, Z *extraParams) { - return opOutput + old; - } + op_def static X update(X old, X opOutput, X *extraParams) { + return sd::math::nd4j_abs(opOutput) + sd::math::nd4j_abs(old); + } - op_def static Z op(X d1, Z *extraParams) { - return d1; - } + op_def static X op(X d1, X *extraParams) { return sd::math::nd4j_abs(d1); } - op_def static Z postProcess(Z reduction, Nd4jLong n, Z *extraParams) { - return reduction / (Z) n; - } - }; + op_def static X postProcess(X reduction, Nd4jLong n, X *extraParams) { + return sd::math::nd4j_abs(reduction); + } +}; - template - class ReduceFloatBenchmarkOp { - public: - no_op_exec_special_accumulation - no_op_exec_special_accumulation_cuda +template +class CountNonZero { + public: + no_op_exec_special_accumulation_long no_op_exec_special_accumulation_cuda - const static functions::ReduceType reduceType = functions::ReduceType::SUM; + const static functions::ReduceType reduceType = + functions::ReduceType::ASUM; - op_def static X startingValue(const X *input) { - return static_cast(0); - } + op_def static Z startingValue(const X *input) { return static_cast(0); } - op_def static Z merge(Z old, Z opOutput, Z *extraParams) { - return opOutput + old; - } + op_def static Z merge(Z old, Z opOutput, X *extraParams) { + return opOutput + old; + } - op_def static Z update(Z old, Z opOutput, Z *extraParams) { - return opOutput + old; - } + op_def static Z update(Z old, Z opOutput, X *extraParams) { + return opOutput + old; + } - op_def static Z op(X d1, Z *extraParams) { - auto f1 = static_cast(d1); - return static_cast(sd::math::nd4j_pow(f1, 3) - + sd::math::nd4j_log(f1) * sd::math::nd4j_sin(f1) - / sd::math::nd4j_tanh(static_cast(M_E) * static_cast(M_PI) * f1) - * sd::math::nd4j_sqrt(static_cast(M_PI) / f1) - - sd::math::nd4j_atan(static_cast(M_E) / f1)); - } + op_def static Z op(X d1, X *extraParams) { + return d1 == static_cast(0.0f) ? static_cast(0.0f) + : static_cast(1.0f); + } - op_def static Z postProcess(Z reduction, Nd4jLong n, Z *extraParams) { - return (Z) reduction / (Z) n; - } - }; + op_def static Z postProcess(Z reduction, Nd4jLong n, X *extraParams) { + return reduction; + } +}; +template +class CountZero { + public: + no_op_exec_special_accumulation_long no_op_exec_special_accumulation_cuda - template - class AMean { - public: - no_op_exec_special_accumulation - no_op_exec_special_accumulation_cuda + const static functions::ReduceType reduceType = + functions::ReduceType::SUM; - const static functions::ReduceType reduceType = functions::ReduceType::SUM; + op_def static Z startingValue(const X *input) { return static_cast(0.0f); } - op_def static X startingValue(const X *input) { - return static_cast(0); - } + op_def static Z merge(Z old, Z opOutput, X *extraParams) { + return opOutput + old; + } - op_def static Z merge(Z old, Z opOutput, Z *extraParams) { - return sd::math::nd4j_abs(opOutput) + sd::math::nd4j_abs(old); - } + op_def static Z update(Z old, Z opOutput, X *extraParams) { + return opOutput + old; + } - op_def static Z update(Z old, Z opOutput, Z *extraParams) { - return opOutput + old; - } + op_def static Z op(X d1, X *extraParams) { + return d1 == static_cast(0) ? static_cast(1) : static_cast(0); + } - op_def static Z op(X d1, Z *extraParams) { - return sd::math::nd4j_abs(d1); - } + op_def static Z postProcess(X reduction, Nd4jLong n, X *extraParams) { + return static_cast(reduction); + } +}; - op_def static Z postProcess(Z reduction, Nd4jLong n, Z *extraParams) { - return sd::math::nd4j_abs(reduction) / static_cast(n); - } - }; +template +class Prod { + public: + no_op_exec_special_accumulation_same no_op_exec_special_accumulation_same_cuda - template - class Max { - public: - no_op_exec_special_accumulation_same - no_op_exec_special_accumulation_same_cuda + const static functions::ReduceType reduceType = + functions::ReduceType::PRODUCT; - const static functions::ReduceType reduceType = functions::ReduceType::MAX; + op_def static X startingValue(const X *input) { return static_cast(1); } - op_def static X startingValue(const X *input) { - return -sd::DataTypeUtils::infOrMax(); - } + op_def static X merge(X old, X opOutput, X *extraParams) { + return opOutput * old; + } - op_def static X merge(X old, X opOutput, X *extraParams) { - return sd::math::nd4j_max(old, opOutput); - } + op_def static X update(X old, X opOutput, X *extraParams) { + return opOutput * old; + } - op_def static X update(X old, X opOutput, X *extraParams) { - return sd::math::nd4j_max(opOutput, old); - } + op_def static X op(X d1, X *extraParams) { return d1; } - op_def static X op(X d1, X d2, X *params) { - return sd::math::nd4j_max(d1, d2); - } + op_def static X postProcess(X reduction, Nd4jLong n, X *extraParams) { + return reduction; + } +}; - op_def static X op(X d1, X d2) { - return sd::math::nd4j_max(d1, d2); - } +template +class Any { + public: + no_op_exec_special_accumulation no_op_exec_special_accumulation_cuda - // FIXME: this signature overlaps with MetaOp - op_def static X op(X d1, X *extraParams) { - return d1; - } - - op_def static X postProcess(X reduction, Nd4jLong n, X *extraParams) { - return reduction; - } - }; - - - template - class AMaxPairwise { - public: - op_def static Z op(X d1, Y d2, Z *params) { - return op(d1, d2); - } - - op_def static Z op(X d1, Y d2) { - auto z1 = static_cast(d1); - auto z2 = static_cast(d2); - - if (sd::math::nd4j_abs(z1) > sd::math::nd4j_abs(z2)) - return z1; - else - return z2; - } - }; - - - template - class AMinPairwise { - public: - op_def static Z op(X d1, Y d2, Z *params) { - return op(d1, d2); - } - - op_def static Z op(X d1, Y d2) { - auto z1 = static_cast(d1); - auto z2 = static_cast(d2); - - if (sd::math::nd4j_abs(z1) < sd::math::nd4j_abs(z2)) - return z1; - else - return z2; - } - }; - - template - class MaxPairwise { - public: - op_def static Z op(X d1, Y d2, Z *params) { - return sd::math::nd4j_max(static_cast(d1), static_cast(d2)); - } - - op_def static Z op(X d1, Y d2) { - return sd::math::nd4j_max(static_cast(d1), static_cast(d2)); - } - }; - - - template - class MinPairwise { - public: - op_def static Z op(X d1, Y d2, Z *params) { - return sd::math::nd4j_min(static_cast(d1), static_cast(d2)); - } - - op_def static Z op(X d1, Y d2) { - return sd::math::nd4j_min(static_cast(d1), static_cast(d2)); - } - }; - - template - class AMax { - public: - no_op_exec_special_accumulation_same - no_op_exec_special_accumulation_same_cuda - - const static functions::ReduceType reduceType = functions::ReduceType::AMAX; - - op_def static X startingValue(const X *input) { - return input[0]; - } + const static functions::ReduceType reduceType = + functions::ReduceType::SUM; - op_def static X merge(X old, X opOutput, X *extraParams) { - return sd::math::nd4j_max(sd::math::nd4j_abs(old), sd::math::nd4j_abs(opOutput)); - } + op_def static X startingValue(const X *input) { return static_cast(0.0f); } - op_def static X update(X old, X opOutput, X *extraParams) { - return sd::math::nd4j_max(sd::math::nd4j_abs(opOutput), sd::math::nd4j_abs(old)); - } + op_def static Z merge(X old, X opOutput, X *extraParams) { + return opOutput + old; + } - op_def static X op(X d1, X d2, X *params) { - return sd::math::nd4j_max(sd::math::nd4j_abs(d1), sd::math::nd4j_abs(d2)); - } + op_def static Z update(X old, X opOutput, X *extraParams) { + return opOutput + old; + } - op_def static X op(X d1, X d2) { - return sd::math::nd4j_abs(d1) > sd::math::nd4j_abs(d2) ? d1 : d2; - } + op_def static Z op(X d1, X *extraParams) { return d1; } - // FIXME: this signature overlaps with MetaOp - op_def static X op(X d1, X *extraParams) { - return sd::math::nd4j_abs(d1); - } + op_def static Z postProcess(X reduction, Nd4jLong n, X *extraParams) { + return reduction > static_cast(0) ? static_cast(1) + : static_cast(0); + } +}; - op_def static X postProcess(X reduction, Nd4jLong n, X *extraParams) { - return sd::math::nd4j_abs(reduction); - } - }; +template +class All { + public: + no_op_exec_special_accumulation no_op_exec_special_accumulation_cuda + const static functions::ReduceType reduceType = + functions::ReduceType::PRODUCT; - template - class AMin { - public: - no_op_exec_special_accumulation_same - no_op_exec_special_accumulation_same_cuda + op_def static X startingValue(const X *input) { return static_cast(1); } - const static functions::ReduceType reduceType = functions::ReduceType::AMIN; + op_def static Z merge(X old, X opOutput, X *extraParams) { + return opOutput * old; + } - op_def static X startingValue(const X *input) { - return input[0]; - } + op_def static Z update(X old, X opOutput, X *extraParams) { + return opOutput * old; + } - op_def static X merge(X old, X opOutput, X *extraParams) { - return sd::math::nd4j_min(sd::math::nd4j_abs(old), sd::math::nd4j_abs(opOutput)); - } + op_def static Z op(X d1, X *extraParams) { return d1; } - op_def static X update(X old, X opOutput, X *extraParams) { - return sd::math::nd4j_min(sd::math::nd4j_abs(opOutput), sd::math::nd4j_abs(old)); - } + op_def static Z postProcess(X reduction, Nd4jLong n, X *extraParams) { + return reduction > static_cast(0) ? static_cast(1) + : static_cast(0); + } +}; - op_def static X op(X d1, X d2, X *params) { - return sd::math::nd4j_min(sd::math::nd4j_abs(d1), sd::math::nd4j_abs(d2)); - } +template +class Mean { + public: + no_op_exec_special_accumulation no_op_exec_special_accumulation_cuda - op_def static X op(X d1, X d2) { - return sd::math::nd4j_min(sd::math::nd4j_abs(d1), sd::math::nd4j_abs(d2)); - } + const static functions::ReduceType reduceType = + functions::ReduceType::SUM; - // FIXME: this signature overlaps with MetaOp - op_def static X op(X d1, X *extraParams) { - return sd::math::nd4j_abs(d1); - } + op_def static X startingValue(const X *input) { return static_cast(0); } - op_def static X postProcess(X reduction, Nd4jLong n, X *extraParams) { - return sd::math::nd4j_abs(reduction); - } - }; + op_def static Z merge(Z old, Z opOutput, Z *extraParams) { + return opOutput + old; + } - template - class Min { - public: - no_op_exec_special_accumulation_same - no_op_exec_special_accumulation_same_cuda + op_def static Z update(Z old, Z opOutput, Z *extraParams) { + return opOutput + old; + } - const static functions::ReduceType reduceType = functions::ReduceType::MIN; + op_def static Z op(X d1, Z *extraParams) { return d1; } - op_def static X startingValue(const X *input) { - return sd::DataTypeUtils::infOrMax(); - } + op_def static Z postProcess(Z reduction, Nd4jLong n, Z *extraParams) { + return reduction / (Z)n; + } +}; - op_def static X merge(X old, X opOutput, X *extraParams) { - return sd::math::nd4j_min(old, opOutput); - } +template +class ReduceFloatBenchmarkOp { + public: + no_op_exec_special_accumulation no_op_exec_special_accumulation_cuda - op_def static X update(X old, X opOutput, X *extraParams) { - return sd::math::nd4j_min(opOutput, old); - } + const static functions::ReduceType reduceType = + functions::ReduceType::SUM; - op_def static X op(X d1, X d2, X *params) { - return sd::math::nd4j_min(d1, d2); - } + op_def static X startingValue(const X *input) { return static_cast(0); } - op_def static X op(X d1, X d2) { - return sd::math::nd4j_min(d1, d2); - } + op_def static Z merge(Z old, Z opOutput, Z *extraParams) { + return opOutput + old; + } - // FIXME: this signature overlaps with MetaOp - op_def static X op(X d1, X *extraParams) { - return d1; - } + op_def static Z update(Z old, Z opOutput, Z *extraParams) { + return opOutput + old; + } - op_def static X postProcess(X reduction, Nd4jLong n, X *extraParams) { - return reduction; - } - }; + op_def static Z op(X d1, Z *extraParams) { + auto f1 = static_cast(d1); + return static_cast( + sd::math::nd4j_pow(f1, 3) + + sd::math::nd4j_log(f1) * + sd::math::nd4j_sin(f1) / + sd::math::nd4j_tanh(static_cast(M_E) * + static_cast(M_PI) * f1) * + sd::math::nd4j_sqrt(static_cast(M_PI) / f1) - + sd::math::nd4j_atan(static_cast(M_E) / f1)); + } + + op_def static Z postProcess(Z reduction, Nd4jLong n, Z *extraParams) { + return (Z)reduction / (Z)n; + } +}; + +template +class AMean { + public: + no_op_exec_special_accumulation no_op_exec_special_accumulation_cuda + + const static functions::ReduceType reduceType = + functions::ReduceType::SUM; + + op_def static X startingValue(const X *input) { return static_cast(0); } + + op_def static Z merge(Z old, Z opOutput, Z *extraParams) { + return sd::math::nd4j_abs(opOutput) + sd::math::nd4j_abs(old); + } + + op_def static Z update(Z old, Z opOutput, Z *extraParams) { + return opOutput + old; + } + + op_def static Z op(X d1, Z *extraParams) { return sd::math::nd4j_abs(d1); } + + op_def static Z postProcess(Z reduction, Nd4jLong n, Z *extraParams) { + return sd::math::nd4j_abs(reduction) / static_cast(n); + } +}; + +template +class Max { + public: + no_op_exec_special_accumulation_same no_op_exec_special_accumulation_same_cuda + + const static functions::ReduceType reduceType = + functions::ReduceType::MAX; + + op_def static X startingValue(const X *input) { + return -sd::DataTypeUtils::infOrMax(); + } + + op_def static X merge(X old, X opOutput, X *extraParams) { + return sd::math::nd4j_max(old, opOutput); + } + + op_def static X update(X old, X opOutput, X *extraParams) { + return sd::math::nd4j_max(opOutput, old); + } + + op_def static X op(X d1, X d2, X *params) { + return sd::math::nd4j_max(d1, d2); + } + + op_def static X op(X d1, X d2) { return sd::math::nd4j_max(d1, d2); } + + // FIXME: this signature overlaps with MetaOp + op_def static X op(X d1, X *extraParams) { return d1; } + + op_def static X postProcess(X reduction, Nd4jLong n, X *extraParams) { + return reduction; + } +}; + +template +class AMaxPairwise { + public: + op_def static Z op(X d1, Y d2, Z *params) { return op(d1, d2); } + + op_def static Z op(X d1, Y d2) { + auto z1 = static_cast(d1); + auto z2 = static_cast(d2); + + if (sd::math::nd4j_abs(z1) > sd::math::nd4j_abs(z2)) + return z1; + else + return z2; + } +}; + +template +class AMinPairwise { + public: + op_def static Z op(X d1, Y d2, Z *params) { return op(d1, d2); } + + op_def static Z op(X d1, Y d2) { + auto z1 = static_cast(d1); + auto z2 = static_cast(d2); + + if (sd::math::nd4j_abs(z1) < sd::math::nd4j_abs(z2)) + return z1; + else + return z2; + } +}; + +template +class MaxPairwise { + public: + op_def static Z op(X d1, Y d2, Z *params) { + return sd::math::nd4j_max(static_cast(d1), static_cast(d2)); + } + + op_def static Z op(X d1, Y d2) { + return sd::math::nd4j_max(static_cast(d1), static_cast(d2)); + } +}; + +template +class MinPairwise { + public: + op_def static Z op(X d1, Y d2, Z *params) { + return sd::math::nd4j_min(static_cast(d1), static_cast(d2)); + } + + op_def static Z op(X d1, Y d2) { + return sd::math::nd4j_min(static_cast(d1), static_cast(d2)); + } +}; + +template +class AMax { + public: + no_op_exec_special_accumulation_same no_op_exec_special_accumulation_same_cuda + + const static functions::ReduceType reduceType = + functions::ReduceType::AMAX; + + op_def static X startingValue(const X *input) { return input[0]; } + + op_def static X merge(X old, X opOutput, X *extraParams) { + return sd::math::nd4j_max(sd::math::nd4j_abs(old), + sd::math::nd4j_abs(opOutput)); + } + + op_def static X update(X old, X opOutput, X *extraParams) { + return sd::math::nd4j_max(sd::math::nd4j_abs(opOutput), + sd::math::nd4j_abs(old)); + } + + op_def static X op(X d1, X d2, X *params) { + return sd::math::nd4j_max(sd::math::nd4j_abs(d1), + sd::math::nd4j_abs(d2)); + } + + op_def static X op(X d1, X d2) { + return sd::math::nd4j_abs(d1) > sd::math::nd4j_abs(d2) ? d1 : d2; + } + + // FIXME: this signature overlaps with MetaOp + op_def static X op(X d1, X *extraParams) { return sd::math::nd4j_abs(d1); } + + op_def static X postProcess(X reduction, Nd4jLong n, X *extraParams) { + return sd::math::nd4j_abs(reduction); + } +}; + +template +class AMin { + public: + no_op_exec_special_accumulation_same no_op_exec_special_accumulation_same_cuda + + const static functions::ReduceType reduceType = + functions::ReduceType::AMIN; + + op_def static X startingValue(const X *input) { return input[0]; } + + op_def static X merge(X old, X opOutput, X *extraParams) { + return sd::math::nd4j_min(sd::math::nd4j_abs(old), + sd::math::nd4j_abs(opOutput)); + } + + op_def static X update(X old, X opOutput, X *extraParams) { + return sd::math::nd4j_min(sd::math::nd4j_abs(opOutput), + sd::math::nd4j_abs(old)); + } + + op_def static X op(X d1, X d2, X *params) { + return sd::math::nd4j_min(sd::math::nd4j_abs(d1), + sd::math::nd4j_abs(d2)); + } + + op_def static X op(X d1, X d2) { + return sd::math::nd4j_min(sd::math::nd4j_abs(d1), + sd::math::nd4j_abs(d2)); + } + + // FIXME: this signature overlaps with MetaOp + op_def static X op(X d1, X *extraParams) { return sd::math::nd4j_abs(d1); } + + op_def static X postProcess(X reduction, Nd4jLong n, X *extraParams) { + return sd::math::nd4j_abs(reduction); + } +}; +template +class Min { + public: + no_op_exec_special_accumulation_same no_op_exec_special_accumulation_same_cuda - template - class Norm1 { - public: - no_op_exec_special_accumulation - no_op_exec_special_accumulation_cuda + const static functions::ReduceType reduceType = + functions::ReduceType::MIN; - const static functions::ReduceType reduceType = functions::ReduceType::SUM; + op_def static X startingValue(const X *input) { + return sd::DataTypeUtils::infOrMax(); + } - op_def static X startingValue(const X *input) { - return static_cast(0); - } + op_def static X merge(X old, X opOutput, X *extraParams) { + return sd::math::nd4j_min(old, opOutput); + } - op_def static Z merge(Z old, Z opOutput, Z *extraParams) { - return opOutput + old; + op_def static X update(X old, X opOutput, X *extraParams) { + return sd::math::nd4j_min(opOutput, old); + } - } + op_def static X op(X d1, X d2, X *params) { + return sd::math::nd4j_min(d1, d2); + } - op_def static Z update(Z old, Z opOutput, Z *extraParams) { - return opOutput + old; + op_def static X op(X d1, X d2) { return sd::math::nd4j_min(d1, d2); } - } + // FIXME: this signature overlaps with MetaOp + op_def static X op(X d1, X *extraParams) { return d1; } - op_def static Z op(X d1, Z *extraParams) { - return static_cast(sd::math::nd4j_abs(d1)); - } + op_def static X postProcess(X reduction, Nd4jLong n, X *extraParams) { + return reduction; + } +}; - op_def static Z postProcess(Z reduction, Nd4jLong n, Z *extraParams) { - return reduction; - } - }; +template +class Norm1 { + public: + no_op_exec_special_accumulation no_op_exec_special_accumulation_cuda + const static functions::ReduceType reduceType = + functions::ReduceType::SUM; - template - class Norm2 { - public: - no_op_exec_special_accumulation - no_op_exec_special_accumulation_cuda + op_def static X startingValue(const X *input) { return static_cast(0); } - const static functions::ReduceType reduceType = functions::ReduceType::SUM; + op_def static Z merge(Z old, Z opOutput, Z *extraParams) { + return opOutput + old; + } - op_def static X startingValue(const X *input) { - return static_cast(0); - } + op_def static Z update(Z old, Z opOutput, Z *extraParams) { + return opOutput + old; + } - op_def static Z merge(Z old, Z opOutput, Z *extraParams) { - return opOutput + old; - } + op_def static Z op(X d1, Z *extraParams) { + return static_cast(sd::math::nd4j_abs(d1)); + } + op_def static Z postProcess(Z reduction, Nd4jLong n, Z *extraParams) { + return reduction; + } +}; - op_def static Z update(Z old, Z opOutput, Z *extraParams) { - return opOutput + old; - } +template +class Norm2 { + public: + no_op_exec_special_accumulation no_op_exec_special_accumulation_cuda + const static functions::ReduceType reduceType = + functions::ReduceType::SUM; - op_def static Z postProcess(Z reduction, Nd4jLong n, Z *extraParams) { - return sd::math::nd4j_sqrt(reduction); - } + op_def static X startingValue(const X *input) { return static_cast(0); } - op_def static Z op(X d1, Z *extraParams) { - return static_cast(d1 * d1); - } - }; + op_def static Z merge(Z old, Z opOutput, Z *extraParams) { + return opOutput + old; + } - template - class SquaredNorm { - public: - no_op_exec_special_accumulation - no_op_exec_special_accumulation_cuda + op_def static Z update(Z old, Z opOutput, Z *extraParams) { + return opOutput + old; + } - const static functions::ReduceType reduceType = functions::ReduceType::SUM; + op_def static Z postProcess(Z reduction, Nd4jLong n, Z *extraParams) { + return sd::math::nd4j_sqrt(reduction); + } - op_def static X startingValue(const X *input) { - return static_cast(0); - } + op_def static Z op(X d1, Z *extraParams) { return static_cast(d1 * d1); } +}; - op_def static Z merge(Z old, Z opOutput, Z *extraParams) { - return opOutput + old; - } +template +class SquaredNorm { + public: + no_op_exec_special_accumulation no_op_exec_special_accumulation_cuda + const static functions::ReduceType reduceType = + functions::ReduceType::SUM; - op_def static Z update(Z old, Z opOutput, Z *extraParams) { - return opOutput + old; - } + op_def static X startingValue(const X *input) { return static_cast(0); } - op_def static Z op(X d1, Z *extraParams) { - return static_cast(d1 * d1); - } + op_def static Z merge(Z old, Z opOutput, Z *extraParams) { + return opOutput + old; + } - op_def static Z postProcess(Z reduction, Nd4jLong n, Z *extraParams) { - return reduction; - } - }; + op_def static Z update(Z old, Z opOutput, Z *extraParams) { + return opOutput + old; + } - template - class NormFrobenius { - public: - no_op_exec_special_accumulation - no_op_exec_special_accumulation_cuda + op_def static Z op(X d1, Z *extraParams) { return static_cast(d1 * d1); } - const static functions::ReduceType reduceType = functions::ReduceType::SUM; + op_def static Z postProcess(Z reduction, Nd4jLong n, Z *extraParams) { + return reduction; + } +}; - op_def static X startingValue(const X *input) { - return static_cast(0); - } +template +class NormFrobenius { + public: + no_op_exec_special_accumulation no_op_exec_special_accumulation_cuda - op_def static Z merge(Z old, Z opOutput, Z *extraParams) { - return opOutput + old; - } + const static functions::ReduceType reduceType = + functions::ReduceType::SUM; + op_def static X startingValue(const X *input) { return static_cast(0); } - op_def static Z update(Z old, Z opOutput, Z *extraParams) { - return opOutput + old; - } + op_def static Z merge(Z old, Z opOutput, Z *extraParams) { + return opOutput + old; + } - op_def static Z op(X d1, Z *extraParams) { - X v = sd::math::nd4j_abs(d1); - return static_cast(v * v); - } + op_def static Z update(Z old, Z opOutput, Z *extraParams) { + return opOutput + old; + } - op_def static Z postProcess(Z reduction, Nd4jLong n, Z *extraParams) { - return sd::math::nd4j_sqrt(reduction); - } - }; + op_def static Z op(X d1, Z *extraParams) { + X v = sd::math::nd4j_abs(d1); + return static_cast(v * v); + } - template - class NormP { - public: - no_op_exec_special_accumulation - no_op_exec_special_accumulation_cuda + op_def static Z postProcess(Z reduction, Nd4jLong n, Z *extraParams) { + return sd::math::nd4j_sqrt(reduction); + } +}; - const static functions::ReduceType reduceType = functions::ReduceType::SUM; +template +class NormP { + public: + no_op_exec_special_accumulation no_op_exec_special_accumulation_cuda - op_def static X startingValue(const X *input) { - return static_cast(0); - } + const static functions::ReduceType reduceType = + functions::ReduceType::SUM; - op_def static Z merge(Z old, Z opOutput, Z *extraParams) { - return opOutput + old; - } + op_def static X startingValue(const X *input) { return static_cast(0); } - op_def static Z update(Z old, Z opOutput, Z *extraParams) { - return opOutput + old; - } - - op_def static Z op(X d1, Z *extraParams) { - return sd::math::nd4j_pow(sd::math::nd4j_abs(d1), extraParams[0]); - } - - op_def static Z postProcess(Z reduction, Nd4jLong n, Z *extraParams) { - return sd::math::nd4j_pow(reduction, static_cast(1.0f) / extraParams[0]); - } - }; - - template - class NormMax { - public: - no_op_exec_special_accumulation - no_op_exec_special_accumulation_cuda - - const static functions::ReduceType reduceType = functions::ReduceType::SUM; - - op_def static X startingValue(const X *input) { - return static_cast(0); - } - - op_def static Z merge(Z old, Z opOutput, Z *extraParams) { - return opOutput + old; - - } - - op_def static Z update(Z old, Z opOutput, Z *extraParams) { - return sd::math::nd4j_max(sd::math::nd4j_abs(old), - sd::math::nd4j_abs(opOutput)); - } - - op_def static Z op(X d1, Z *extraParams) { - return static_cast(d1); - } - - op_def static Z postProcess(Z reduction, Nd4jLong n, Z *extraParams) { - return sd::math::nd4j_max(sd::math::nd4j_abs(reduction), sd::math::nd4j_abs(reduction)); - } - }; - - template - class Variance { - public: - no_op_exec_special_accumulation - no_op_exec_special_accumulation_cuda - - const static functions::ReduceType reduceType = functions::ReduceType::SUM; - - op_def static X startingValue(const X *input) { - return static_cast(0.0f); - } - - op_def static Z merge(X old, X opOutput, Z *extraParams) { - return old + opOutput; - } - - op_def static Z update(X old, X opOutput, Z *extraParams) { - return old + opOutput; - - } - - op_def static X op(X d1, Z *extraParams) { - X mean = static_cast(extraParams[0]); - X ret = d1 - mean; - return ret * ret; - } - - op_def static Z postProcess(X reduction, Nd4jLong n, Z *extraParams) { - // T bias = extraParams[1]; - // return (reduction - (sd::math::nd4j_pow(bias, static_cast(2.0f)) / static_cast(n))) / (n - 1) - return static_cast(reduction) / static_cast(n - 1); - } - }; - - /** - * Standard deviation of a buffer - */ - template - class StandardDeviation { - public: - no_op_exec_special_accumulation - no_op_exec_special_accumulation_cuda - - const static functions::ReduceType reduceType = functions::ReduceType::SUM; - - op_def static X startingValue(const X *input) { - return static_cast(0.0f); - } - - op_def static Z merge(X old, X opOutput, Z *extraParams) { - return old + opOutput; - } - - op_def static Z update(X old, X opOutput, Z *extraParams) { - return old + opOutput; - - } - - op_def static Z op(X d1, Z *extraParams) { - X mean = extraParams[0]; - X ret = d1 - mean; - return ret * ret; - } - - op_def static Z postProcess(X reduction, Nd4jLong n, Z *extraParams) { - Z ret = Variance::postProcess(reduction, n, extraParams); - Z sqrtRet = sd::math::nd4j_sqrt(ret); - return sqrtRet; - } - }; - - template - class CosineSimilarity { - public: - static const int extraParamsLen = 2; - - op_def static X *generateExtraParams() { - //T *extraParams = new T[2]; - return nullptr; - } - - op_def static void finalizeExtraParams(X *extraParams) { - //delete[] extraParams; - } - - op_def static Y startingValue(const X *input) { - return static_cast(0.0f); - } - - op_def static Y postProcess(Y reduction, Nd4jLong n, Y *extraParams) { - return reduction / (sd::math::nd4j_sqrt(extraParams[0]) * sd::math::nd4j_sqrt(extraParams[1])); - } - - op_def static Y op(X d1, X d2, Y *extraParams) { - extraParams[0] += static_cast(d1 * d1); - extraParams[1] += static_cast(d2 * d2); - return static_cast(d1 * d2); - } - - op_def static void aggregateExtraParams(Y *extraParamsTotal, Y *extraParamsLocal) { - extraParamsTotal[0] += extraParamsLocal[0]; - extraParamsTotal[1] += extraParamsLocal[1]; - } + op_def static Z merge(Z old, Z opOutput, Z *extraParams) { + return opOutput + old; + } + + op_def static Z update(Z old, Z opOutput, Z *extraParams) { + return opOutput + old; + } + + op_def static Z op(X d1, Z *extraParams) { + return sd::math::nd4j_pow(sd::math::nd4j_abs(d1), + extraParams[0]); + } + + op_def static Z postProcess(Z reduction, Nd4jLong n, Z *extraParams) { + return sd::math::nd4j_pow(reduction, + static_cast(1.0f) / extraParams[0]); + } +}; + +template +class NormMax { + public: + no_op_exec_special_accumulation no_op_exec_special_accumulation_cuda + + const static functions::ReduceType reduceType = + functions::ReduceType::SUM; + + op_def static X startingValue(const X *input) { return static_cast(0); } + + op_def static Z merge(Z old, Z opOutput, Z *extraParams) { + return opOutput + old; + } + + op_def static Z update(Z old, Z opOutput, Z *extraParams) { + return sd::math::nd4j_max(sd::math::nd4j_abs(old), + sd::math::nd4j_abs(opOutput)); + } + + op_def static Z op(X d1, Z *extraParams) { return static_cast(d1); } + + op_def static Z postProcess(Z reduction, Nd4jLong n, Z *extraParams) { + return sd::math::nd4j_max(sd::math::nd4j_abs(reduction), + sd::math::nd4j_abs(reduction)); + } +}; + +template +class Variance { + public: + no_op_exec_special_accumulation no_op_exec_special_accumulation_cuda + + const static functions::ReduceType reduceType = + functions::ReduceType::SUM; + + op_def static X startingValue(const X *input) { return static_cast(0.0f); } + + op_def static Z merge(X old, X opOutput, Z *extraParams) { + return old + opOutput; + } + + op_def static Z update(X old, X opOutput, Z *extraParams) { + return old + opOutput; + } + + op_def static X op(X d1, Z *extraParams) { + X mean = static_cast(extraParams[0]); + X ret = d1 - mean; + return ret * ret; + } + + op_def static Z postProcess(X reduction, Nd4jLong n, Z *extraParams) { + // T bias = extraParams[1]; + // return (reduction - (sd::math::nd4j_pow(bias, static_cast(2.0f)) / + // static_cast(n))) / (n - 1) + return static_cast(reduction) / static_cast(n - 1); + } +}; + +/** + * Standard deviation of a buffer + */ +template +class StandardDeviation { + public: + no_op_exec_special_accumulation no_op_exec_special_accumulation_cuda + + const static functions::ReduceType reduceType = + functions::ReduceType::SUM; + + op_def static X startingValue(const X *input) { return static_cast(0.0f); } + + op_def static Z merge(X old, X opOutput, Z *extraParams) { + return old + opOutput; + } + + op_def static Z update(X old, X opOutput, Z *extraParams) { + return old + opOutput; + } + + op_def static Z op(X d1, Z *extraParams) { + X mean = extraParams[0]; + X ret = d1 - mean; + return ret * ret; + } + + op_def static Z postProcess(X reduction, Nd4jLong n, Z *extraParams) { + Z ret = Variance::postProcess(reduction, n, extraParams); + Z sqrtRet = sd::math::nd4j_sqrt(ret); + return sqrtRet; + } +}; + +template +class CosineSimilarity { + public: + static const int extraParamsLen = 2; + + op_def static X *generateExtraParams() { + // T *extraParams = new T[2]; + return nullptr; + } + + op_def static void finalizeExtraParams(X *extraParams) { + // delete[] extraParams; + } + + op_def static Y startingValue(const X *input) { return static_cast(0.0f); } + + op_def static Y postProcess(Y reduction, Nd4jLong n, Y *extraParams) { + return reduction / (sd::math::nd4j_sqrt(extraParams[0]) * + sd::math::nd4j_sqrt(extraParams[1])); + } + + op_def static Y op(X d1, X d2, Y *extraParams) { + extraParams[0] += static_cast(d1 * d1); + extraParams[1] += static_cast(d2 * d2); + return static_cast(d1 * d2); + } + + op_def static void aggregateExtraParams(Y *extraParamsTotal, + Y *extraParamsLocal) { + extraParamsTotal[0] += extraParamsLocal[0]; + extraParamsTotal[1] += extraParamsLocal[1]; + } #ifdef __CUDACC__ - static _CUDA_D inline Y opAtomic(X d1, X d2, Y *extraParams) { - sd::math::atomics::nd4j_atomicAdd(&extraParams[0],static_cast(d1 * d1)); - sd::math::atomics::nd4j_atomicAdd(&extraParams[1],static_cast(d2 * d2)); + static _CUDA_D inline Y opAtomic(X d1, X d2, Y *extraParams) { + sd::math::atomics::nd4j_atomicAdd(&extraParams[0], static_cast(d1 * d1)); + sd::math::atomics::nd4j_atomicAdd(&extraParams[1], static_cast(d2 * d2)); - return static_cast(d1 * d2); - } + return static_cast(d1 * d2); + } #endif - op_def static Y update(Y old, Y opOutput, Y *extraParams) { - return old + opOutput; - } - - - op_def static Y merge(Y old, Y opOutput, Y *extraParams) { - return update(old, opOutput, extraParams); - } - }; + op_def static Y update(Y old, Y opOutput, Y *extraParams) { + return old + opOutput; + } + op_def static Y merge(Y old, Y opOutput, Y *extraParams) { + return update(old, opOutput, extraParams); + } +}; - template - class JaccardDistance { - public: - static const int extraParamsLen = 2; +template +class JaccardDistance { + public: + static const int extraParamsLen = 2; - op_def static X *generateExtraParams() { - //T *extraParams = new T[2]; - return nullptr; - } + op_def static X *generateExtraParams() { + // T *extraParams = new T[2]; + return nullptr; + } - op_def static void finalizeExtraParams(X *extraParams) { - //delete[] extraParams; - } + op_def static void finalizeExtraParams(X *extraParams) { + // delete[] extraParams; + } - op_def static Y startingValue(const X *input) { - return static_cast(0.0f); - } + op_def static Y startingValue(const X *input) { return static_cast(0.0f); } - op_def static Y postProcess(Y reduction, Nd4jLong n, Y *extraParams) { - // num / denom - return (static_cast(1.0f)) - (extraParams[0] / extraParams[1]); - } + op_def static Y postProcess(Y reduction, Nd4jLong n, Y *extraParams) { + // num / denom + return (static_cast(1.0f)) - (extraParams[0] / extraParams[1]); + } - op_def static Y num(X d1, X d2) { - return sd::math::nd4j_min(d1, d2); - } + op_def static Y num(X d1, X d2) { return sd::math::nd4j_min(d1, d2); } - op_def static Y denom(X d1, X d2) { - return sd::math::nd4j_max(d1, d2); - } + op_def static Y denom(X d1, X d2) { return sd::math::nd4j_max(d1, d2); } - op_def static Y op(X d1, X d2, Y *extraParams) { - extraParams[0] += static_cast(num(d1, d2)); - extraParams[1] += static_cast(denom(d1, d2)); - return static_cast(0.0f); - } + op_def static Y op(X d1, X d2, Y *extraParams) { + extraParams[0] += static_cast(num(d1, d2)); + extraParams[1] += static_cast(denom(d1, d2)); + return static_cast(0.0f); + } - op_def static void aggregateExtraParams(Y *extraParamsTotal, Y *extraParamsLocal) { - extraParamsTotal[0] += extraParamsLocal[0]; - extraParamsTotal[1] += extraParamsLocal[1]; - } + op_def static void aggregateExtraParams(Y *extraParamsTotal, + Y *extraParamsLocal) { + extraParamsTotal[0] += extraParamsLocal[0]; + extraParamsTotal[1] += extraParamsLocal[1]; + } #ifdef __CUDACC__ - __device__ - static inline Y opAtomic(X d1, X d2, Y *extraParams) { - sd::math::atomics::nd4j_atomicAdd(&extraParams[0],num(d1, d2)); - sd::math::atomics::nd4j_atomicAdd(&extraParams[1], denom(d1, d2)); + __device__ static inline Y opAtomic(X d1, X d2, Y *extraParams) { + sd::math::atomics::nd4j_atomicAdd(&extraParams[0], num(d1, d2)); + sd::math::atomics::nd4j_atomicAdd(&extraParams[1], denom(d1, d2)); - return static_cast(0.0f); - } + return static_cast(0.0f); + } #endif - op_def static Y update(Y old, Y opOutput, Y *extraParams) { - return old + opOutput; - } + op_def static Y update(Y old, Y opOutput, Y *extraParams) { + return old + opOutput; + } + op_def static Y merge(Y old, Y opOutput, Y *extraParams) { + return update(old, opOutput, extraParams); + } +}; - op_def static Y merge(Y old, Y opOutput, Y *extraParams) { - return update(old, opOutput, extraParams); - } - }; +template +class SimpleHammingDistance { + public: + static const int extraParamsLen = 0; + op_def static X *generateExtraParams() { + // T *extraParams = new T[2]; + return nullptr; + } - template - class SimpleHammingDistance { - public: - static const int extraParamsLen = 0; + op_def static void finalizeExtraParams(X *extraParams) { + // delete[] extraParams; + } - op_def static X *generateExtraParams() { - //T *extraParams = new T[2]; - return nullptr; - } + op_def static Y startingValue(const X *input) { return static_cast(0.0f); } - op_def static void finalizeExtraParams(X *extraParams) { - //delete[] extraParams; - } + op_def static Y postProcess(Y reduction, Nd4jLong n, Y *extraParams) { + return static_cast(reduction / n); + } - op_def static Y startingValue(const X *input) { - return static_cast(0.0f); - } - - op_def static Y postProcess(Y reduction, Nd4jLong n, Y *extraParams) { - return static_cast(reduction / n); - } - - op_def static Y op(X d1, X d2, Y *extraParams) { - return (d1 == d2) ? static_cast(0.0f) : static_cast(1.0f); - } - - op_def static void aggregateExtraParams(Y *extraParamsTotal, Y *extraParamsLocal) { + op_def static Y op(X d1, X d2, Y *extraParams) { + return (d1 == d2) ? static_cast(0.0f) : static_cast(1.0f); + } - } + op_def static void aggregateExtraParams(Y *extraParamsTotal, + Y *extraParamsLocal) {} #ifdef __CUDACC__ - __device__ - static inline Y opAtomic(X d1, X d2, Y *extraParams) { - return op(d1, d2, extraParams); - } + __device__ static inline Y opAtomic(X d1, X d2, Y *extraParams) { + return op(d1, d2, extraParams); + } #endif - op_def static Y update(Y old, Y opOutput, Y *extraParams) { - return old + opOutput; - } - - - op_def static Y merge(Y old, Y opOutput, Y *extraParams) { - return update(old, opOutput, extraParams); - } - }; - - template - class CosineDistance { - public: - static const int extraParamsLen = 2; - - op_def static X *generateExtraParams() { - //T *extraParams = new T[2]; - return nullptr; - } - - op_def static void finalizeExtraParams(X *extraParams) { - //delete[] extraParams; - } - - op_def static Y startingValue(const X *input) { - return static_cast(0.0f); - } - - op_def static Y postProcess(Y reduction, Nd4jLong n, Y *extraParams) { - return (static_cast(1.0f)) - (reduction / (sd::math::nd4j_sqrt(extraParams[0]) * sd::math::nd4j_sqrt(extraParams[1]))); - } - - op_def static Y op(X d1, X d2, Y *extraParams) { - extraParams[0] += static_cast(sd::math::nd4j_abs(d1) * sd::math::nd4j_abs(d1)); - extraParams[1] += static_cast(sd::math::nd4j_abs(d2) * sd::math::nd4j_abs(d2)); - return (d1 * d2); - } - - op_def static void aggregateExtraParams(Y *extraParamsTotal, Y *extraParamsLocal) { - extraParamsTotal[0] += extraParamsLocal[0]; - extraParamsTotal[1] += extraParamsLocal[1]; - } + op_def static Y update(Y old, Y opOutput, Y *extraParams) { + return old + opOutput; + } + + op_def static Y merge(Y old, Y opOutput, Y *extraParams) { + return update(old, opOutput, extraParams); + } +}; + +template +class CosineDistance { + public: + static const int extraParamsLen = 2; + + op_def static X *generateExtraParams() { + // T *extraParams = new T[2]; + return nullptr; + } + + op_def static void finalizeExtraParams(X *extraParams) { + // delete[] extraParams; + } + + op_def static Y startingValue(const X *input) { return static_cast(0.0f); } + + op_def static Y postProcess(Y reduction, Nd4jLong n, Y *extraParams) { + return (static_cast(1.0f)) - + (reduction / (sd::math::nd4j_sqrt(extraParams[0]) * + sd::math::nd4j_sqrt(extraParams[1]))); + } + + op_def static Y op(X d1, X d2, Y *extraParams) { + extraParams[0] += + static_cast(sd::math::nd4j_abs(d1) * sd::math::nd4j_abs(d1)); + extraParams[1] += + static_cast(sd::math::nd4j_abs(d2) * sd::math::nd4j_abs(d2)); + return (d1 * d2); + } + + op_def static void aggregateExtraParams(Y *extraParamsTotal, + Y *extraParamsLocal) { + extraParamsTotal[0] += extraParamsLocal[0]; + extraParamsTotal[1] += extraParamsLocal[1]; + } #ifdef __CUDACC__ - static _CUDA_D inline Y opAtomic(X d1, X d2, Y *extraParams) { - sd::math::atomics::nd4j_atomicAdd(&extraParams[0], sd::math::nd4j_abs(d1) * sd::math::nd4j_abs(d1)); - sd::math::atomics::nd4j_atomicAdd(&extraParams[1], sd::math::nd4j_abs(d2) * sd::math::nd4j_abs(d2)); - - return (d1 * d2); - } + static _CUDA_D inline Y opAtomic(X d1, X d2, Y *extraParams) { + sd::math::atomics::nd4j_atomicAdd( + &extraParams[0], sd::math::nd4j_abs(d1) * sd::math::nd4j_abs(d1)); + sd::math::atomics::nd4j_atomicAdd( + &extraParams[1], sd::math::nd4j_abs(d2) * sd::math::nd4j_abs(d2)); + + return (d1 * d2); + } #endif - op_def static Y update(Y old, Y opOutput, Y *extraParams) { - return old + opOutput; - } - - - op_def static Y merge(Y old, Y opOutput, Y *extraParams) { - return update(old, opOutput, extraParams); - } - }; + op_def static Y update(Y old, Y opOutput, Y *extraParams) { + return old + opOutput; + } + op_def static Y merge(Y old, Y opOutput, Y *extraParams) { + return update(old, opOutput, extraParams); + } +}; - /** - * Dot product between 2 arrays - */ - template - class Dot { - public: - static const int extraParamsLen = 0; +/** + * Dot product between 2 arrays + */ +template +class Dot { + public: + static const int extraParamsLen = 0; - op_def static X * generateExtraParams() { - return nullptr; - } + op_def static X *generateExtraParams() { return nullptr; } - op_def static void finalizeExtraParams(X *extraParamsRef) { - //no-op - //delete[] * extraParamsRef; - } + op_def static void finalizeExtraParams(X *extraParamsRef) { + // no-op + // delete[] * extraParamsRef; + } - op_def static Y startingValue(const X *input) { - return static_cast(0.0f); - } + op_def static Y startingValue(const X *input) { return static_cast(0.0f); } - op_def static Y postProcess(Y reduction, Nd4jLong n, Y *extraParamsRef) { - return reduction; - } - - op_def static Y op(X d1, X d2, Y *extraParamsRef) { - return static_cast(d1 * d2); - } + op_def static Y postProcess(Y reduction, Nd4jLong n, Y *extraParamsRef) { + return reduction; + } + op_def static Y op(X d1, X d2, Y *extraParamsRef) { + return static_cast(d1 * d2); + } #ifdef __CUDACC__ - __device__ - static inline Y opAtomic(X d1, X d2, Y *extraParamsRef) { - return op(d1, d2, extraParamsRef); - } + __device__ static inline Y opAtomic(X d1, X d2, Y *extraParamsRef) { + return op(d1, d2, extraParamsRef); + } #endif - op_def static Y update(Y old, Y opOutput, Y *extraParamsRef) { - return opOutput + old; - } + op_def static Y update(Y old, Y opOutput, Y *extraParamsRef) { + return opOutput + old; + } - op_def static Y merge(Y old, Y opOutput, Y *extraParamsRef) { - return update(old, opOutput, extraParamsRef); - } + op_def static Y merge(Y old, Y opOutput, Y *extraParamsRef) { + return update(old, opOutput, extraParamsRef); + } - op_def static void aggregateExtraParams(Y *extraParamsTotal, Y *extraParamsLocal) {} - }; + op_def static void aggregateExtraParams(Y *extraParamsTotal, + Y *extraParamsLocal) {} +}; +/** + * Op to check equality within arrays + */ +template +class EqualsWithEps { + public: + static const int extraParamsLen = 0; - /** - * Op to check equality within arrays - */ - template - class EqualsWithEps { - public: - static const int extraParamsLen = 0; + op_def static X *generateExtraParams() { return nullptr; } - op_def static X * generateExtraParams() { - return nullptr; - } + op_def static void finalizeExtraParams(X *extraParamsRef) { + // no-op + } - op_def static void finalizeExtraParams(X *extraParamsRef) { - //no-op - } - - op_def static Z startingValue(const X *input) { - return static_cast(0.0f); - } - - op_def static Z postProcess(Z reduction, Nd4jLong n, Z *extraParamsRef) { - return reduction; - } + op_def static Z startingValue(const X *input) { return static_cast(0.0f); } - op_def static Z op(X d1, X d2, Z *extraParamsRef) { - double eps = sd::math::nd4j_abs(extraParamsRef[2]); - return static_cast(!sd::math::nd4j_eq(d1, d2, eps)); - } + op_def static Z postProcess(Z reduction, Nd4jLong n, Z *extraParamsRef) { + return reduction; + } + op_def static Z op(X d1, X d2, Z *extraParamsRef) { + double eps = sd::math::nd4j_abs(extraParamsRef[2]); + return static_cast(!sd::math::nd4j_eq(d1, d2, eps)); + } #ifdef __CUDACC__ - __device__ - static inline Z opAtomic(X d1, X d2, Z *extraParamsRef) { - return op(d1, d2, extraParamsRef); - } + __device__ static inline Z opAtomic(X d1, X d2, Z *extraParamsRef) { + return op(d1, d2, extraParamsRef); + } #endif - op_def static Z update(Z old, Z opOutput, Z *extraParamsRef) { - return opOutput + old; - } + op_def static Z update(Z old, Z opOutput, Z *extraParamsRef) { + return opOutput + old; + } - op_def static Z merge(X old, Z opOutput, Z *extraParamsRef) { - return update(old, opOutput, extraParamsRef); - } + op_def static Z merge(X old, Z opOutput, Z *extraParamsRef) { + return update(old, opOutput, extraParamsRef); + } - op_def static void aggregateExtraParams(Z *extraParamsTotal, Z *extraParamsLocal) {} - }; + op_def static void aggregateExtraParams(Z *extraParamsTotal, + Z *extraParamsLocal) {} +}; +template +class EuclideanDistance { + public: + static const int extraParamsLen = 0; + op_def static X *generateExtraParams() { return nullptr; } - template - class EuclideanDistance { - public: - static const int extraParamsLen = 0; + op_def static void finalizeExtraParams(X *extraParamsRef) { + // no-op + } - op_def static X * generateExtraParams() { - return nullptr; - } + op_def static Y startingValue(const X *input) { return static_cast(0.0f); } - op_def static void finalizeExtraParams(X *extraParamsRef) { - //no-op - } - - op_def static Y startingValue(const X *input) { - return static_cast(0.0f); - } - - op_def static Y postProcess(Y reduction, Nd4jLong n, Y *extraParamsRef) { - return sd::math::nd4j_sqrt(reduction); - } - - op_def static Y op(X d1, X d2, Y *extraParamsRef) { - X ret = d1 - d2; - return static_cast(ret * ret); - } + op_def static Y postProcess(Y reduction, Nd4jLong n, Y *extraParamsRef) { + return sd::math::nd4j_sqrt(reduction); + } + op_def static Y op(X d1, X d2, Y *extraParamsRef) { + X ret = d1 - d2; + return static_cast(ret * ret); + } #ifdef __CUDACC__ - __device__ - static inline Y opAtomic(X d1, X d2, Y *extraParamsRef) { - return op(d1, d2, extraParamsRef); - } + __device__ static inline Y opAtomic(X d1, X d2, Y *extraParamsRef) { + return op(d1, d2, extraParamsRef); + } #endif - op_def static Y update(Y old, Y opOutput, Y *extraParamsRef) { - return opOutput + old; - } - - op_def static Y merge(Y old, Y opOutput, Y *extraParamsRef) { - return update(old, opOutput, extraParamsRef); - } - op_def static void aggregateExtraParams(Y *extraParamsTotal, Y *extraParamsLocal) {} - - }; - - - template - class ManhattanDistance { - public: - static const int extraParamsLen = 0; + op_def static Y update(Y old, Y opOutput, Y *extraParamsRef) { + return opOutput + old; + } - op_def static X * generateExtraParams() { - return nullptr; - } + op_def static Y merge(Y old, Y opOutput, Y *extraParamsRef) { + return update(old, opOutput, extraParamsRef); + } + op_def static void aggregateExtraParams(Y *extraParamsTotal, + Y *extraParamsLocal) {} +}; - op_def static void finalizeExtraParams(X *extraParamsRef) { - //no-op - } +template +class ManhattanDistance { + public: + static const int extraParamsLen = 0; - op_def static Y startingValue(const X *input) { - return static_cast(0.0f); - } + op_def static X *generateExtraParams() { return nullptr; } - op_def static Y postProcess(Y reduction, Nd4jLong n, Y *extraParamsRef) { - return reduction; - } + op_def static void finalizeExtraParams(X *extraParamsRef) { + // no-op + } - op_def static Y op(X d1, X d2, Y *extraParamsRef) { - return sd::math::nd4j_abs(d1 - d2); - } + op_def static Y startingValue(const X *input) { return static_cast(0.0f); } - op_def static Y update(Y old, Y opOutput, Y *extraParamsRef) { - return old + opOutput; - } + op_def static Y postProcess(Y reduction, Nd4jLong n, Y *extraParamsRef) { + return reduction; + } - op_def static void aggregateExtraParams(Y *extraParamsTotal, Y *extraParamsLocal) { + op_def static Y op(X d1, X d2, Y *extraParamsRef) { + return sd::math::nd4j_abs(d1 - d2); + } - } + op_def static Y update(Y old, Y opOutput, Y *extraParamsRef) { + return old + opOutput; + } + op_def static void aggregateExtraParams(Y *extraParamsTotal, + Y *extraParamsLocal) {} #ifdef __CUDACC__ - __device__ - static inline Y opAtomic(X d1, X d2, Y *extraParamsRef) { - return op(d1, d2, extraParamsRef); - } + __device__ static inline Y opAtomic(X d1, X d2, Y *extraParamsRef) { + return op(d1, d2, extraParamsRef); + } #endif - op_def static Y merge(X old, X opOutput, X *extraParamsRef) { - return update(old, opOutput, extraParamsRef); - } - }; - - template - class IndexAbsoluteMax { - public: - static _CUDA_HD inline functions::indexreduce::IndexValue op(functions::indexreduce::IndexValue val, X *extraParams) { - return sd::math::nd4j_abs(val); - } - - static _CUDA_HD inline functions::indexreduce::IndexValue update(functions::indexreduce::IndexValue &old, functions::indexreduce::IndexValue &opOutput, X *extraParams) { - opOutput.value = sd::math::nd4j_abs(opOutput.value); - old.value = sd::math::nd4j_abs(old.value); - if (opOutput.value > old.value) - return opOutput; + op_def static Y merge(X old, X opOutput, X *extraParamsRef) { + return update(old, opOutput, extraParamsRef); + } +}; + +template +class IndexAbsoluteMax { + public: + static _CUDA_HD inline functions::indexreduce::IndexValue op( + functions::indexreduce::IndexValue val, X *extraParams) { + return sd::math::nd4j_abs(val); + } + + static _CUDA_HD inline functions::indexreduce::IndexValue update( + functions::indexreduce::IndexValue &old, + functions::indexreduce::IndexValue &opOutput, X *extraParams) { + opOutput.value = sd::math::nd4j_abs(opOutput.value); + old.value = sd::math::nd4j_abs(old.value); + if (opOutput.value > old.value) return opOutput; #ifdef __CUDACC__ - // workaround for cuda race condition at merge phase - else if (opOutput.value == old.value && opOutput.index < old.index) - return opOutput; + // workaround for cuda race condition at merge phase + else if (opOutput.value == old.value && opOutput.index < old.index) + return opOutput; #elif defined(__GNUC__) #endif - return old; - } - - static _CUDA_HD inline functions::indexreduce::IndexValue merge( - functions::indexreduce::IndexValue f1, - functions::indexreduce::IndexValue f2, X *extraParams) { - if (sd::math::nd4j_abs(f1.value) > sd::math::nd4j_abs(f2.value)) - return f2; - return f1; - } - - static _CUDA_HD inline functions::indexreduce::IndexValue postProcess( - functions::indexreduce::IndexValue reduction, int n, int xOffset, - X *dx, int incx, X *extraParams, X *result) { - return reduction; - } - - static _CUDA_HD inline X startingValue(const X *input) { - return 0; - } - - static _CUDA_HD inline functions::indexreduce::IndexValue startingIndexValue(const X *input) { - functions::indexreduce::IndexValue local; - local.value = startingValue(input); - local.index = 0; - return local; - } - - static _CUDA_HD inline functions::indexreduce::IndexValue op(functions::indexreduce::IndexValue d1, - functions::indexreduce::IndexValue d2, X *extraParams) { - return d1; - } - }; - - template - class FirstIndex { - public: - static _CUDA_HD inline functions::indexreduce::IndexValue op(functions::indexreduce::IndexValue val, X *extraParams) { - return val; - } - - static _CUDA_HD functions::indexreduce::IndexValue update(functions::indexreduce::IndexValue &old, functions::indexreduce::IndexValue &opOutput, X *extraParams) { - + return old; + } + + static _CUDA_HD inline functions::indexreduce::IndexValue merge( + functions::indexreduce::IndexValue f1, + functions::indexreduce::IndexValue f2, X *extraParams) { + if (sd::math::nd4j_abs(f1.value) > sd::math::nd4j_abs(f2.value)) + return f2; + return f1; + } + + static _CUDA_HD inline functions::indexreduce::IndexValue postProcess( + functions::indexreduce::IndexValue reduction, int n, int xOffset, + X *dx, int incx, X *extraParams, X *result) { + return reduction; + } + + static _CUDA_HD inline X startingValue(const X *input) { return 0; } + + static _CUDA_HD inline functions::indexreduce::IndexValue + startingIndexValue(const X *input) { + functions::indexreduce::IndexValue local; + local.value = startingValue(input); + local.index = 0; + return local; + } + + static _CUDA_HD inline functions::indexreduce::IndexValue op( + functions::indexreduce::IndexValue d1, + functions::indexreduce::IndexValue d2, X *extraParams) { + return d1; + } +}; + +template +class FirstIndex { + public: + static _CUDA_HD inline functions::indexreduce::IndexValue op( + functions::indexreduce::IndexValue val, X *extraParams) { + return val; + } + + static _CUDA_HD functions::indexreduce::IndexValue update( + functions::indexreduce::IndexValue &old, + functions::indexreduce::IndexValue &opOutput, X *extraParams) { #ifdef __CUDACC__ - if (opOutput.index < 0) - return old; + if (opOutput.index < 0) return old; #endif - auto res = simdOps::MatchCondition::op(opOutput.value, extraParams); - - //printf("res: %f; oldIdx: %i; newIdx: %i\n", res, old.index, opOutput.index); - - if (res == static_cast(0)) - return old; - - if (old.index < 0) - return opOutput; - - if (old.index > opOutput.index) - return opOutput; - - return old; - } - - static _CUDA_HD inline X startingValue(const X *input) { - return -sd::DataTypeUtils::infOrMax(); - } - - static _CUDA_HD inline functions::indexreduce::IndexValue startingIndexValue(const X *input) { - functions::indexreduce::IndexValue local; - local.value = startingValue(input); - local.index = -1; - return local; - } - - static _CUDA_HD inline functions::indexreduce::IndexValue op(functions::indexreduce::IndexValue d1, - functions::indexreduce::IndexValue d2, X *extraParams) { - return d1; - } - - static _CUDA_HD inline functions::indexreduce::IndexValue merge( - functions::indexreduce::IndexValue f1, - functions::indexreduce::IndexValue f2, X *extraParams) { - if (f1.index > f2.index) - return f2; - return f1; - } - - static _CUDA_HD inline functions::indexreduce::IndexValue postProcess( - functions::indexreduce::IndexValue reduction, int n, int xOffset, - X *dx, int incx, X *extraParams, X *result) { - return reduction; - } - }; - - - template - class LastIndex { - public: - static _CUDA_HD inline functions::indexreduce::IndexValue op(functions::indexreduce::IndexValue val, X *extraParams) { - return val; - } - - static _CUDA_HD functions::indexreduce::IndexValue update(functions::indexreduce::IndexValue &old, functions::indexreduce::IndexValue &opOutput, X *extraParams) { + auto res = simdOps::MatchCondition::op(opOutput.value, extraParams); + + // printf("res: %f; oldIdx: %i; newIdx: %i\n", res, old.index, + // opOutput.index); + + if (res == static_cast(0)) return old; + + if (old.index < 0) return opOutput; + + if (old.index > opOutput.index) return opOutput; + + return old; + } + + static _CUDA_HD inline X startingValue(const X *input) { + return -sd::DataTypeUtils::infOrMax(); + } + + static _CUDA_HD inline functions::indexreduce::IndexValue + startingIndexValue(const X *input) { + functions::indexreduce::IndexValue local; + local.value = startingValue(input); + local.index = -1; + return local; + } + + static _CUDA_HD inline functions::indexreduce::IndexValue op( + functions::indexreduce::IndexValue d1, + functions::indexreduce::IndexValue d2, X *extraParams) { + return d1; + } + + static _CUDA_HD inline functions::indexreduce::IndexValue merge( + functions::indexreduce::IndexValue f1, + functions::indexreduce::IndexValue f2, X *extraParams) { + if (f1.index > f2.index) return f2; + return f1; + } + + static _CUDA_HD inline functions::indexreduce::IndexValue postProcess( + functions::indexreduce::IndexValue reduction, int n, int xOffset, + X *dx, int incx, X *extraParams, X *result) { + return reduction; + } +}; + +template +class LastIndex { + public: + static _CUDA_HD inline functions::indexreduce::IndexValue op( + functions::indexreduce::IndexValue val, X *extraParams) { + return val; + } + + static _CUDA_HD functions::indexreduce::IndexValue update( + functions::indexreduce::IndexValue &old, + functions::indexreduce::IndexValue &opOutput, X *extraParams) { #ifdef __CUDACC__ - if (opOutput.index < 0) - return old; + if (opOutput.index < 0) return old; #endif - auto res = simdOps::MatchCondition::op(opOutput.value, extraParams); - - if (res == static_cast(0)) - return old; - - if (old.index < 0) - return opOutput; - - if (old.index < opOutput.index) - return opOutput; - - return old; - } - - static _CUDA_HD inline X startingValue(const X *input) { - return -sd::DataTypeUtils::infOrMax(); - } - - static _CUDA_HD inline functions::indexreduce::IndexValue startingIndexValue(const X *input) { - functions::indexreduce::IndexValue local; - local.value = startingValue(input); - local.index = -1; - return local; - } - - static _CUDA_HD inline functions::indexreduce::IndexValue op(functions::indexreduce::IndexValue d1, - functions::indexreduce::IndexValue d2, X *extraParams) { - return d1; - } - - static _CUDA_HD inline functions::indexreduce::IndexValue merge( - functions::indexreduce::IndexValue f1, - functions::indexreduce::IndexValue f2, X *extraParams) { - if (f1.index < f2.index) - return f2; - return f1; - } - - static _CUDA_HD inline functions::indexreduce::IndexValue postProcess( - functions::indexreduce::IndexValue reduction, int n, int xOffset, - X *dx, int incx, X *extraParams, X *result) { - return reduction; - } - }; - - - template - class IndexMax { - public: - - static _CUDA_HD inline functions::indexreduce::IndexValue op(functions::indexreduce::IndexValue val, X *extraParams) { - return val; - } - - static _CUDA_HD functions::indexreduce::IndexValue update(functions::indexreduce::IndexValue &old, functions::indexreduce::IndexValue &opOutput, X *extraParams) { - if (opOutput.value > old.value) { - return opOutput; - } + auto res = simdOps::MatchCondition::op(opOutput.value, extraParams); + + if (res == static_cast(0)) return old; + + if (old.index < 0) return opOutput; + + if (old.index < opOutput.index) return opOutput; + + return old; + } + + static _CUDA_HD inline X startingValue(const X *input) { + return -sd::DataTypeUtils::infOrMax(); + } + + static _CUDA_HD inline functions::indexreduce::IndexValue + startingIndexValue(const X *input) { + functions::indexreduce::IndexValue local; + local.value = startingValue(input); + local.index = -1; + return local; + } + + static _CUDA_HD inline functions::indexreduce::IndexValue op( + functions::indexreduce::IndexValue d1, + functions::indexreduce::IndexValue d2, X *extraParams) { + return d1; + } + + static _CUDA_HD inline functions::indexreduce::IndexValue merge( + functions::indexreduce::IndexValue f1, + functions::indexreduce::IndexValue f2, X *extraParams) { + if (f1.index < f2.index) return f2; + return f1; + } + + static _CUDA_HD inline functions::indexreduce::IndexValue postProcess( + functions::indexreduce::IndexValue reduction, int n, int xOffset, + X *dx, int incx, X *extraParams, X *result) { + return reduction; + } +}; + +template +class IndexMax { + public: + static _CUDA_HD inline functions::indexreduce::IndexValue op( + functions::indexreduce::IndexValue val, X *extraParams) { + return val; + } + + static _CUDA_HD functions::indexreduce::IndexValue update( + functions::indexreduce::IndexValue &old, + functions::indexreduce::IndexValue &opOutput, X *extraParams) { + if (opOutput.value > old.value) { + return opOutput; + } #ifdef __CUDACC__ - // workaround for cuda race condition at merge phase - else if (opOutput.value == old.value && opOutput.index < old.index) - return opOutput; + // workaround for cuda race condition at merge phase + else if (opOutput.value == old.value && opOutput.index < old.index) + return opOutput; #elif defined(__GNUC__) #endif - return old; - } - - static _CUDA_HD inline functions::indexreduce::IndexValue merge( - functions::indexreduce::IndexValue f1, - functions::indexreduce::IndexValue f2, X *extraParams) { - if (f1.value > f2.value) - return f2; - return f1; - } - - static _CUDA_HD inline functions::indexreduce::IndexValue postProcess( - functions::indexreduce::IndexValue reduction, int n, int xOffset, - X *dx, int incx, X *extraParams, X *result) { - return reduction; - } - - static _CUDA_HD inline X startingValue(const X *input) { - return -sd::DataTypeUtils::infOrMax(); - } - - static _CUDA_HD inline functions::indexreduce::IndexValue startingIndexValue(const X *input) { - functions::indexreduce::IndexValue local; - local.value = startingValue(input); - local.index = 0; - return local; - } - - static _CUDA_HD inline functions::indexreduce::IndexValue op(functions::indexreduce::IndexValue d1, - functions::indexreduce::IndexValue d2, X *extraParams) { - return d1; - } - }; - - - template - class IndexAbsoluteMin { - public: - static _CUDA_HD inline functions::indexreduce::IndexValue op( - functions::indexreduce::IndexValue val, X *extraParams) { - return val; - } - - static _CUDA_HD inline X startingValue(const X *input) { - return sd::DataTypeUtils::infOrMax(); - } - - static _CUDA_HD inline functions::indexreduce::IndexValue startingIndexValue(const X *input) { - functions::indexreduce::IndexValue local; - local.value = startingValue(input); - local.index = 0; - return local; - } - - static _CUDA_HD inline functions::indexreduce::IndexValue update(functions::indexreduce::IndexValue &old, functions::indexreduce::IndexValue &opOutput, X *extraParams) { - opOutput.value = sd::math::nd4j_abs(opOutput.value); - old.value = sd::math::nd4j_abs(old.value); - if (opOutput.value < old.value) - return opOutput; + return old; + } + + static _CUDA_HD inline functions::indexreduce::IndexValue merge( + functions::indexreduce::IndexValue f1, + functions::indexreduce::IndexValue f2, X *extraParams) { + if (f1.value > f2.value) return f2; + return f1; + } + + static _CUDA_HD inline functions::indexreduce::IndexValue postProcess( + functions::indexreduce::IndexValue reduction, int n, int xOffset, + X *dx, int incx, X *extraParams, X *result) { + return reduction; + } + + static _CUDA_HD inline X startingValue(const X *input) { + return -sd::DataTypeUtils::infOrMax(); + } + + static _CUDA_HD inline functions::indexreduce::IndexValue + startingIndexValue(const X *input) { + functions::indexreduce::IndexValue local; + local.value = startingValue(input); + local.index = 0; + return local; + } + + static _CUDA_HD inline functions::indexreduce::IndexValue op( + functions::indexreduce::IndexValue d1, + functions::indexreduce::IndexValue d2, X *extraParams) { + return d1; + } +}; + +template +class IndexAbsoluteMin { + public: + static _CUDA_HD inline functions::indexreduce::IndexValue op( + functions::indexreduce::IndexValue val, X *extraParams) { + return val; + } + + static _CUDA_HD inline X startingValue(const X *input) { + return sd::DataTypeUtils::infOrMax(); + } + + static _CUDA_HD inline functions::indexreduce::IndexValue + startingIndexValue(const X *input) { + functions::indexreduce::IndexValue local; + local.value = startingValue(input); + local.index = 0; + return local; + } + + static _CUDA_HD inline functions::indexreduce::IndexValue update( + functions::indexreduce::IndexValue &old, + functions::indexreduce::IndexValue &opOutput, X *extraParams) { + opOutput.value = sd::math::nd4j_abs(opOutput.value); + old.value = sd::math::nd4j_abs(old.value); + if (opOutput.value < old.value) return opOutput; #ifdef __CUDACC__ - // workaround for cuda race condition at merge phase - else if (opOutput.value == old.value && opOutput.index < old.index) - return opOutput; + // workaround for cuda race condition at merge phase + else if (opOutput.value == old.value && opOutput.index < old.index) + return opOutput; #elif defined(__GNUC__) #endif - return old; - } - - static _CUDA_HD inline functions::indexreduce::IndexValue merge( - functions::indexreduce::IndexValue f1, - functions::indexreduce::IndexValue f2, X *extraParams) { - if (sd::math::nd4j_abs(f1.value) < sd::math::nd4j_abs(f2.value)) - return f2; - return f1; - } - - static _CUDA_HD inline functions::indexreduce::IndexValue postProcess( - functions::indexreduce::IndexValue reduction, int n, int xOffset, - X *dx, int incx, X *extraParams, X *result) { - return reduction; - } - - static _CUDA_HD inline functions::indexreduce::IndexValue op(functions::indexreduce::IndexValue d1, - functions::indexreduce::IndexValue d2, X *extraParams) { - return d1; - } - }; - - - template - class IndexMin { - public: - static _CUDA_HD inline functions::indexreduce::IndexValue op( - functions::indexreduce::IndexValue val, X *extraParams) { - return val; - } - - static _CUDA_HD inline X startingValue(const X *input) { - return sd::DataTypeUtils::infOrMax(); - } - - static _CUDA_HD inline functions::indexreduce::IndexValue startingIndexValue(const X *input) { - functions::indexreduce::IndexValue local; - local.value = startingValue(input); - local.index = 0; - return local; - } - - static _CUDA_HD inline functions::indexreduce::IndexValue update(functions::indexreduce::IndexValue &old, functions::indexreduce::IndexValue &opOutput, X *extraParams) { - if (opOutput.value < old.value) - return opOutput; + return old; + } + + static _CUDA_HD inline functions::indexreduce::IndexValue merge( + functions::indexreduce::IndexValue f1, + functions::indexreduce::IndexValue f2, X *extraParams) { + if (sd::math::nd4j_abs(f1.value) < sd::math::nd4j_abs(f2.value)) + return f2; + return f1; + } + + static _CUDA_HD inline functions::indexreduce::IndexValue postProcess( + functions::indexreduce::IndexValue reduction, int n, int xOffset, + X *dx, int incx, X *extraParams, X *result) { + return reduction; + } + + static _CUDA_HD inline functions::indexreduce::IndexValue op( + functions::indexreduce::IndexValue d1, + functions::indexreduce::IndexValue d2, X *extraParams) { + return d1; + } +}; + +template +class IndexMin { + public: + static _CUDA_HD inline functions::indexreduce::IndexValue op( + functions::indexreduce::IndexValue val, X *extraParams) { + return val; + } + + static _CUDA_HD inline X startingValue(const X *input) { + return sd::DataTypeUtils::infOrMax(); + } + + static _CUDA_HD inline functions::indexreduce::IndexValue + startingIndexValue(const X *input) { + functions::indexreduce::IndexValue local; + local.value = startingValue(input); + local.index = 0; + return local; + } + + static _CUDA_HD inline functions::indexreduce::IndexValue update( + functions::indexreduce::IndexValue &old, + functions::indexreduce::IndexValue &opOutput, X *extraParams) { + if (opOutput.value < old.value) return opOutput; #ifdef __CUDACC__ - // workaround for cuda race condition at merge phase - else if (opOutput.value == old.value && opOutput.index < old.index) - return opOutput; + // workaround for cuda race condition at merge phase + else if (opOutput.value == old.value && opOutput.index < old.index) + return opOutput; #elif defined(__GNUC__) #endif - return old; - } - - static _CUDA_HD inline functions::indexreduce::IndexValue merge( - functions::indexreduce::IndexValue f1, - functions::indexreduce::IndexValue f2, X *extraParams) { - if (f1.value < f2.value) - return f2; - return f1; - } - - static _CUDA_HD inline functions::indexreduce::IndexValue postProcess( - functions::indexreduce::IndexValue reduction, int n, int xOffset, - X *dx, int incx, X *extraParams, X *result) { - return reduction; - } - - static _CUDA_HD inline functions::indexreduce::IndexValue op(functions::indexreduce::IndexValue d1, - functions::indexreduce::IndexValue d2, X *extraParams) { - return d1; - } - }; - - template - class SummaryStatsVariance { - public: - - static _CUDA_HD inline Z getValue(const bool biasCorrected, functions::summarystats::SummaryStatsData val) { - if (biasCorrected) { - Z ret = static_cast(val.varianceBiasCorrected()); - if (ret < static_cast(0.0f)) - return static_cast(val.variance()); - return ret; - } - return static_cast(val.variance()); - } - - static _CUDA_HD inline functions::summarystats::SummaryStatsData op(functions::summarystats::SummaryStatsData d1, Z *extraParams) { - return d1; - } - }; - - template - class SummaryStatsStandardDeviation { - public: - - static _CUDA_HD inline Z getValue(const bool biasCorrected, functions::summarystats::SummaryStatsData val) { - if (biasCorrected) { - auto ret = static_cast(val.varianceBiasCorrected()); - if (ret < static_cast(0.0f)) - return sd::math::nd4j_sqrt(val.variance()); - else - return sd::math::nd4j_sqrt(ret); - } - return sd::math::nd4j_sqrt(val.variance()); - } - - static _CUDA_HD inline functions::summarystats::SummaryStatsData op(functions::summarystats::SummaryStatsData d1, Z *extraParams) { - return d1; - } - }; - - template - class DropOut { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - inline _CUDA_D static X op(X d1, X *params) { - X prob = params[0]; + return old; + } + + static _CUDA_HD inline functions::indexreduce::IndexValue merge( + functions::indexreduce::IndexValue f1, + functions::indexreduce::IndexValue f2, X *extraParams) { + if (f1.value < f2.value) return f2; + return f1; + } + + static _CUDA_HD inline functions::indexreduce::IndexValue postProcess( + functions::indexreduce::IndexValue reduction, int n, int xOffset, + X *dx, int incx, X *extraParams, X *result) { + return reduction; + } + + static _CUDA_HD inline functions::indexreduce::IndexValue op( + functions::indexreduce::IndexValue d1, + functions::indexreduce::IndexValue d2, X *extraParams) { + return d1; + } +}; + +template +class SummaryStatsVariance { + public: + static _CUDA_HD inline Z getValue( + const bool biasCorrected, + functions::summarystats::SummaryStatsData val) { + if (biasCorrected) { + Z ret = static_cast(val.varianceBiasCorrected()); + if (ret < static_cast(0.0f)) return static_cast(val.variance()); + return ret; + } + return static_cast(val.variance()); + } + + static _CUDA_HD inline functions::summarystats::SummaryStatsData op( + functions::summarystats::SummaryStatsData d1, Z *extraParams) { + return d1; + } +}; + +template +class SummaryStatsStandardDeviation { + public: + static _CUDA_HD inline Z getValue( + const bool biasCorrected, + functions::summarystats::SummaryStatsData val) { + if (biasCorrected) { + auto ret = static_cast(val.varianceBiasCorrected()); + if (ret < static_cast(0.0f)) + return sd::math::nd4j_sqrt(val.variance()); + else + return sd::math::nd4j_sqrt(ret); + } + return sd::math::nd4j_sqrt(val.variance()); + } + + static _CUDA_HD inline functions::summarystats::SummaryStatsData op( + functions::summarystats::SummaryStatsData d1, Z *extraParams) { + return d1; + } +}; + +template +class DropOut { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + inline _CUDA_D static X + op(X d1, X *params) { + X prob = params[0]; #ifdef __CUDACC__ - X length = params[1]; - X tid = blockIdx.x * blockDim.x + threadIdx.x; - X rnd = sd::math::nd4j_abs(sd::math::nd4j_cos(static_cast(clock64()) * static_cast(tid) + static_cast(length) * static_cast(tid))); + X length = params[1]; + X tid = blockIdx.x * blockDim.x + threadIdx.x; + X rnd = sd::math::nd4j_abs( + sd::math::nd4j_cos(static_cast(clock64()) * static_cast(tid) + + static_cast(length) * static_cast(tid))); #else - X rnd = static_cast(rand() / RAND_MAX); + X rnd = static_cast(rand() / RAND_MAX); #endif - return rnd >= prob ? static_cast(0.0f) : d1; - } - }; - - template - class DropOutInverted { - public: - no_op_exec_special - no_op_exec_special_cuda - + return rnd >= prob ? static_cast(0.0f) : d1; + } +}; + +template +class DropOutInverted { + public: + no_op_exec_special no_op_exec_special_cuda #ifdef __CUDACC__ - __device__ + __device__ #endif - inline static Z op(X d1, Y d2, Z *params) { - Y prob = d2; + inline static Z + op(X d1, Y d2, Z *params) { + Y prob = d2; #ifdef __CUDACC__ - X length = params[1]; - X tid = blockIdx.x * blockDim.x + threadIdx.x; - X rnd = sd::math::nd4j_abs(sd::math::nd4j_cos(static_cast(clock64()) * static_cast(tid) + static_cast(length) * static_cast(tid))); + X length = params[1]; + X tid = blockIdx.x * blockDim.x + threadIdx.x; + X rnd = sd::math::nd4j_abs( + sd::math::nd4j_cos(static_cast(clock64()) * static_cast(tid) + + static_cast(length) * static_cast(tid))); #else - X rnd = static_cast(rand() / RAND_MAX); + X rnd = static_cast(rand() / RAND_MAX); #endif - return rnd >= static_cast(prob) ? static_cast(0.0f) : reinterpret_cast(d1 / static_cast(prob)); - } - }; - - - template - class ReplaceNans { - public: - no_op_exec_special - no_op_exec_special_cuda - - op_def static Z op(X d1, Y d2, Z *params) { - return sd::math::nd4j_isnan(d1) ? static_cast(d2) : static_cast(d1) ; - } - }; - - // this op is used for conditional pairwise transforms only - template - class CompareAndReplace{ - public: - // op definition for PairWise Transform - op_def static Z op(X d1, Y d2, Z *params) { - auto zd1 = static_cast(d1); - auto zd2 = static_cast(d2); - auto compare = params[0]; - auto eps = params[2]; - int mode = (int) params[3]; - if (mode == 0) // equals - if (sd::math::nd4j_abs(zd1 - compare) <= eps) - return zd2; - else - return zd1; - else if (mode == 1) // not equals eps - if (sd::math::nd4j_abs(zd1 - compare) > eps) - return zd2; - else - return zd1; - else if (mode == 2) // less_than eps - if (zd1 < compare) - return zd2; - else - return zd1; - else if (mode ==3) // greater_than - if (zd1 > compare) - return zd2; - else - return zd1; - else if (mode == 4) // less_or_equals_than - if (zd1 <= compare) - return zd2; - else - return zd1; - else if (mode == 5) // greater_or_equals_than - if (zd1 >= compare) - return zd2; - else - return zd1; - else if (mode == 6) // abs_less_than - if (sd::math::nd4j_abs(zd1) < compare) - return zd2; - else - return zd1; - else if (mode == 7) // abs_greater_than - if (sd::math::nd4j_abs(zd1) > compare) - return zd2; - else - return zd1; - else if (mode == 8) // is inf - if (sd::math::nd4j_isinf(zd1)) - return zd2; - else - return zd1; - else if (mode == 9) // is nan - if (sd::math::nd4j_isnan(zd1)) - return zd2; - else - return zd1; - else if (mode == 10) - if (zd1 == compare) - return zd2; - else - return zd1; - else if (mode == 11) - if (zd1 != compare) - return zd2; - else - return zd1; - else if (mode == 12) // abs_greater_or_equals_than - if (sd::math::nd4j_abs(zd1) >= compare) - return zd2; - else - return zd1; - else if (mode == 13) // abs_less_or_equals_than - if (sd::math::nd4j_abs(zd1) <= compare) - return zd2; - else - return zd1; - else - printf("Undefined boolean operation: [%i]\n", mode); - return zd1; - } - }; - - template - class CompareAndSet { - public: - - - // op definition for PairWise Transform - op_def static Z op(X dX, Y dY, Z *params) { - auto d1 = static_cast(dX); - auto d2 = static_cast(dY); - auto compare = params[0]; - auto eps = params[2]; - auto mode = static_cast(params[3]); - if (mode == 0) // equals - if (sd::math::nd4j_abs(d2 - compare) <= eps) - return d2; - else - return d1; - else if (mode == 1) // not equals - if (sd::math::nd4j_abs(d2 - compare) > eps) - return d2; - else - return d1; - else if (mode == 2) // less_than - if (d2 < compare) - return d2; - else - return d1; - else if (mode ==3) // greater_than - if (d2 > compare) - return d2; - else - return d1; - else if (mode == 4) // less_or_equals_than - if (d2 <= compare) - return d2; - else - return d1; - else if (mode == 5) // greater_or_equals_than - if (d2 >= compare) - return d2; - else - return d1; - else if (mode == 6) // abs_less_than - if (sd::math::nd4j_abs(d2) < compare) - return d2; - else - return d1; - else if (mode == 7) // abs_greater_than - if (sd::math::nd4j_abs(d2) > compare) - return d2; - else - return d1; - else if (mode == 8) // is inf - if (sd::math::nd4j_isinf(d2)) - return d2; - else - return d1; - else if (mode == 9) // is nan - if (sd::math::nd4j_isnan(d2)) - return d2; - else - return d1; - else if (mode == 10) - if (d2 == compare) - return d2; - else - return d1; - else if (mode == 11) - if (d2 != compare) - return d2; - else - return d1; - else if (mode == 12) // abs_greater_or_equals_than - if (sd::math::nd4j_abs(d1) >= compare) - return d2; - else - return d1; - else if (mode == 13) // abs_less_or_equals_than - if (sd::math::nd4j_abs(d1) <= compare) - return d2; - else - return d1; - else - printf("Undefined boolean operation: [%i]\n", mode); - return d1; - } - }; - - template - class CompareAndSetTransform { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - - // op definition for Transform - op_def static X op(X d1, X *params) { - auto compare = params[0]; - auto set = params[1]; - auto eps = params[2]; - - // with mode == 0 we do set if d1 equals to compare, and with mode == 1 - we go otherwise - int mode = (int) params[3]; - if (mode == 0) // equals - if (sd::math::nd4j_abs(d1 - compare) <= eps) - return set; - else - return d1; - //return sd::math::nd4j_abs(d1 - compare) <= eps ? set : d1; - else if (mode == 1) // not equals - if (sd::math::nd4j_abs(d1 - compare) > eps) - return set; - else - return d1; - //return sd::math::nd4j_abs(d1 - compare) > eps ? set : d1; - else if (mode == 2) // less_than - if (d1 < compare) - return set; - else - return d1; - else if (mode ==3) // greater_than - if (d1 > compare) - return set; - else - return d1; - else if (mode == 4) // less_or_equals_than - if (d1 <= compare) - return set; - else - return d1; - else if (mode == 5) // greater_or_equals_than - if (d1 >= compare) - return set; - else - return d1; - else if (mode == 6) // abs_less_than - if (sd::math::nd4j_abs(d1) < compare) - return set; - else - return d1; - else if (mode == 7) // abs_greater_than - if (sd::math::nd4j_abs(d1) > compare) - return set; - else - return d1; - else if (mode == 8) // is inf - if (sd::math::nd4j_isinf(d1)) - return set; - else - return d1; - else if (mode == 9) // is nan - if (sd::math::nd4j_isnan(d1)) - return set; - else - return d1; - else if (mode == 10) - if (d1 == compare) - return set; - else - return d1; - else if (mode == 11) - if (d1 != compare) - return set; - else - return d1; - else if (mode == 12) // abs_greater_or_equals_than - if (sd::math::nd4j_abs(d1) >= compare) - return set; - else - return d1; - else if (mode == 13) // abs_less_or_equals_than - if (sd::math::nd4j_abs(d1) <= compare) - return set; - else - return d1; - else - printf("Undefined boolean operation: [%i]\n", mode); - return d1; - } - }; - - -} + return rnd >= static_cast(prob) + ? static_cast(0.0f) + : reinterpret_cast(d1 / static_cast(prob)); + } +}; + +template +class ReplaceNans { + public: + no_op_exec_special no_op_exec_special_cuda + + op_def static Z + op(X d1, Y d2, Z *params) { + return sd::math::nd4j_isnan(d1) ? static_cast(d2) : static_cast(d1); + } +}; + +// this op is used for conditional pairwise transforms only +template +class CompareAndReplace { + public: + // op definition for PairWise Transform + op_def static Z op(X d1, Y d2, Z *params) { + auto zd1 = static_cast(d1); + auto zd2 = static_cast(d2); + auto compare = params[0]; + auto eps = params[2]; + int mode = (int)params[3]; + if (mode == 0) // equals + if (sd::math::nd4j_abs(zd1 - compare) <= eps) + return zd2; + else + return zd1; + else if (mode == 1) // not equals eps + if (sd::math::nd4j_abs(zd1 - compare) > eps) + return zd2; + else + return zd1; + else if (mode == 2) // less_than eps + if (zd1 < compare) + return zd2; + else + return zd1; + else if (mode == 3) // greater_than + if (zd1 > compare) + return zd2; + else + return zd1; + else if (mode == 4) // less_or_equals_than + if (zd1 <= compare) + return zd2; + else + return zd1; + else if (mode == 5) // greater_or_equals_than + if (zd1 >= compare) + return zd2; + else + return zd1; + else if (mode == 6) // abs_less_than + if (sd::math::nd4j_abs(zd1) < compare) + return zd2; + else + return zd1; + else if (mode == 7) // abs_greater_than + if (sd::math::nd4j_abs(zd1) > compare) + return zd2; + else + return zd1; + else if (mode == 8) // is inf + if (sd::math::nd4j_isinf(zd1)) + return zd2; + else + return zd1; + else if (mode == 9) // is nan + if (sd::math::nd4j_isnan(zd1)) + return zd2; + else + return zd1; + else if (mode == 10) + if (zd1 == compare) + return zd2; + else + return zd1; + else if (mode == 11) + if (zd1 != compare) + return zd2; + else + return zd1; + else if (mode == 12) // abs_greater_or_equals_than + if (sd::math::nd4j_abs(zd1) >= compare) + return zd2; + else + return zd1; + else if (mode == 13) // abs_less_or_equals_than + if (sd::math::nd4j_abs(zd1) <= compare) + return zd2; + else + return zd1; + else + printf("Undefined boolean operation: [%i]\n", mode); + return zd1; + } +}; + +template +class CompareAndSet { + public: + // op definition for PairWise Transform + op_def static Z op(X dX, Y dY, Z *params) { + auto d1 = static_cast(dX); + auto d2 = static_cast(dY); + auto compare = params[0]; + auto eps = params[2]; + auto mode = static_cast(params[3]); + if (mode == 0) // equals + if (sd::math::nd4j_abs(d2 - compare) <= eps) + return d2; + else + return d1; + else if (mode == 1) // not equals + if (sd::math::nd4j_abs(d2 - compare) > eps) + return d2; + else + return d1; + else if (mode == 2) // less_than + if (d2 < compare) + return d2; + else + return d1; + else if (mode == 3) // greater_than + if (d2 > compare) + return d2; + else + return d1; + else if (mode == 4) // less_or_equals_than + if (d2 <= compare) + return d2; + else + return d1; + else if (mode == 5) // greater_or_equals_than + if (d2 >= compare) + return d2; + else + return d1; + else if (mode == 6) // abs_less_than + if (sd::math::nd4j_abs(d2) < compare) + return d2; + else + return d1; + else if (mode == 7) // abs_greater_than + if (sd::math::nd4j_abs(d2) > compare) + return d2; + else + return d1; + else if (mode == 8) // is inf + if (sd::math::nd4j_isinf(d2)) + return d2; + else + return d1; + else if (mode == 9) // is nan + if (sd::math::nd4j_isnan(d2)) + return d2; + else + return d1; + else if (mode == 10) + if (d2 == compare) + return d2; + else + return d1; + else if (mode == 11) + if (d2 != compare) + return d2; + else + return d1; + else if (mode == 12) // abs_greater_or_equals_than + if (sd::math::nd4j_abs(d1) >= compare) + return d2; + else + return d1; + else if (mode == 13) // abs_less_or_equals_than + if (sd::math::nd4j_abs(d1) <= compare) + return d2; + else + return d1; + else + printf("Undefined boolean operation: [%i]\n", mode); + return d1; + } +}; + +template +class CompareAndSetTransform { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + // op definition for Transform + op_def static X + op(X d1, X *params) { + auto compare = params[0]; + auto set = params[1]; + auto eps = params[2]; + + // with mode == 0 we do set if d1 equals to compare, and with mode == 1 - we + // go otherwise + int mode = (int)params[3]; + if (mode == 0) // equals + if (sd::math::nd4j_abs(d1 - compare) <= eps) + return set; + else + return d1; + // return sd::math::nd4j_abs(d1 - compare) <= eps ? set : d1; + else if (mode == 1) // not equals + if (sd::math::nd4j_abs(d1 - compare) > eps) + return set; + else + return d1; + // return sd::math::nd4j_abs(d1 - compare) > eps ? set : d1; + else if (mode == 2) // less_than + if (d1 < compare) + return set; + else + return d1; + else if (mode == 3) // greater_than + if (d1 > compare) + return set; + else + return d1; + else if (mode == 4) // less_or_equals_than + if (d1 <= compare) + return set; + else + return d1; + else if (mode == 5) // greater_or_equals_than + if (d1 >= compare) + return set; + else + return d1; + else if (mode == 6) // abs_less_than + if (sd::math::nd4j_abs(d1) < compare) + return set; + else + return d1; + else if (mode == 7) // abs_greater_than + if (sd::math::nd4j_abs(d1) > compare) + return set; + else + return d1; + else if (mode == 8) // is inf + if (sd::math::nd4j_isinf(d1)) + return set; + else + return d1; + else if (mode == 9) // is nan + if (sd::math::nd4j_isnan(d1)) + return set; + else + return d1; + else if (mode == 10) + if (d1 == compare) + return set; + else + return d1; + else if (mode == 11) + if (d1 != compare) + return set; + else + return d1; + else if (mode == 12) // abs_greater_or_equals_than + if (sd::math::nd4j_abs(d1) >= compare) + return set; + else + return d1; + else if (mode == 13) // abs_less_or_equals_than + if (sd::math::nd4j_abs(d1) <= compare) + return set; + else + return d1; + else + printf("Undefined boolean operation: [%i]\n", mode); + return d1; + } +}; + +} // namespace simdOps #endif - diff --git a/libnd4j/include/ops/random_ops.h b/libnd4j/include/ops/random_ops.h index d738589a7c08..c77dbcd8f8de 100644 --- a/libnd4j/include/ops/random_ops.h +++ b/libnd4j/include/ops/random_ops.h @@ -27,259 +27,305 @@ #define random_def inline static #endif -// since we can't inherit/overwrite static methods - we just define default impls -#define method_idx random_def T op(Nd4jLong idx, Nd4jLong length, sd::graph::RandomGenerator* rng, T *extraParams) { return -1.0f; } -#define method_X random_def T op(T valueX, Nd4jLong idx, Nd4jLong length, sd::graph::RandomGenerator* rng, T *extraParams) { return -2.0f; } -#define method_XY random_def T op(T valueX, T valueY, Nd4jLong idx, Nd4jLong length, sd::graph::RandomGenerator* rng, T *extraParams) { return -3.0f; } - -#define no_exec_special static const bool requiresSpecial = false; static inline void specialOp(Nd4jPointer state, const T *x, const Nd4jLong *xShapeBuffer, const T *y, const Nd4jLong *yShapeBuffer, T *z, const Nd4jLong *zShapeBuffer, T *extraArguments) { } +// since we can't inherit/overwrite static methods - we just define default +// impls +#define method_idx \ + random_def T op(Nd4jLong idx, Nd4jLong length, \ + sd::graph::RandomGenerator *rng, T *extraParams) { \ + return -1.0f; \ + } +#define method_X \ + random_def T op(T valueX, Nd4jLong idx, Nd4jLong length, \ + sd::graph::RandomGenerator *rng, T *extraParams) { \ + return -2.0f; \ + } +#define method_XY \ + random_def T op(T valueX, T valueY, Nd4jLong idx, Nd4jLong length, \ + sd::graph::RandomGenerator *rng, T *extraParams) { \ + return -3.0f; \ + } + +#define no_exec_special \ + static const bool requiresSpecial = false; \ + static inline void specialOp( \ + Nd4jPointer state, const T *x, const Nd4jLong *xShapeBuffer, const T *y, \ + const Nd4jLong *yShapeBuffer, T *z, const Nd4jLong *zShapeBuffer, \ + T *extraArguments) {} #ifdef __CUDACC__ -#define no_exec_special_cuda __device__ static inline void specialOpCuda(Nd4jPointer state, T const* x, Nd4jLong const* xShapeBuffer, T const* y, Nd4jLong const* yShapeBuffer, T *z, Nd4jLong const* zShapeBuffer, T *extraArguments) { } +#define no_exec_special_cuda \ + __device__ static inline void specialOpCuda( \ + Nd4jPointer state, T const *x, Nd4jLong const *xShapeBuffer, T const *y, \ + Nd4jLong const *yShapeBuffer, T *z, Nd4jLong const *zShapeBuffer, \ + T *extraArguments) {} #else #define no_exec_special_cuda #endif -#include -#include #include +#include +#include namespace randomOps { - /** - * This Op merges two arrays per-element, if probability meets threshold - */ - template - class ProbablisticMerge { - public: - - no_exec_special - no_exec_special_cuda - - method_idx - method_X - - random_def T op(T valueX, T valueY, Nd4jLong idx, Nd4jLong length, sd::graph::RandomGenerator *helper, T *extraParams) { - T threshold = extraParams[0]; - T randVal = helper->relativeT(idx); - - return randVal <= threshold ? valueY : valueX; - } - }; - - /** - * This Op produces random values within specified boundaries. Disribution is uniform - */ - template - class UniformDistribution { - public: - - no_exec_special - no_exec_special_cuda - - method_XY - method_X - - random_def T op(Nd4jLong idx, Nd4jLong length, sd::graph::RandomGenerator *helper, T *extraParams) { - return helper->relativeT(idx, extraParams[0], extraParams[1]); - } - }; - - /** - * This op produces single bernoulli trial - */ - template - class BernoulliDistribution { - public: - no_exec_special - no_exec_special_cuda - - method_XY - - random_def T op(Nd4jLong idx, Nd4jLong length, sd::graph::RandomGenerator *helper, T *extraParams) { - return extraParams[0] >= helper->relativeT(idx) ? (T) 1.0f : (T) 0.0f; - } - - random_def T op(T valueX, Nd4jLong idx, Nd4jLong length, sd::graph::RandomGenerator *helper, T *extraParams) { - return valueX >= helper->relativeT(idx) ? (T) 1.0f : (T) 0.0f; - } - }; - - - /** - * This op produces single bernoulli trial - */ - template - class ExponentialDistribution { - public: - no_exec_special - no_exec_special_cuda - - method_XY - - random_def T op(Nd4jLong idx, Nd4jLong length, sd::graph::RandomGenerator *helper, T *extraParams) { - T lambda = extraParams[0]; - T x = helper->relativeT(idx, sd::DataTypeUtils::min(), T(1.f) - sd::DataTypeUtils::template min()); // x from (0, 1) without bounds - T xVal = -sd::math::nd4j_log(x); - - return xVal <= (T)0.f ? (T)0.f : xVal / lambda; //pow((T) M_E, -(lambda * x)); - } - - random_def T op(T valueX, Nd4jLong idx, Nd4jLong length, sd::graph::RandomGenerator *helper, T *extraParams) { - T lambda = extraParams[0]; - return valueX <= (T)0.f ? (T)0.f : (T)(valueX/lambda); //1.f - sd::math::nd4j_exp(-lambda * valueX); //pow((T) M_E, -(lambda * valueX)); - } - }; - - template - class PoissonDistribution { - public: - no_exec_special - no_exec_special_cuda - - method_XY - - random_def T op(Nd4jLong idx, Nd4jLong length, sd::graph::RandomGenerator *helper, T *extraParams) { - T lambda = extraParams[0]; - T x = helper->relativeT(idx, -sd::DataTypeUtils::template max() / 10 , sd::DataTypeUtils::template max() / 10); - return x <= (T)0.f ? (T)0.f : sd::math::nd4j_igammac(sd::math::nd4j_floor(x), lambda); - } - - random_def T op(T valueX, Nd4jLong idx, Nd4jLong length, sd::graph::RandomGenerator *helper, T *extraParams) { - T lambda = extraParams[0]; - return valueX <= (T)0.f ? (T)0.f : (T)sd::math::nd4j_igammac(sd::math::nd4j_floor(valueX), lambda); - } - }; - - template - class GammaDistribution { - public: - no_exec_special - no_exec_special_cuda - - method_XY - - random_def T op(Nd4jLong idx, Nd4jLong length, sd::graph::RandomGenerator *helper, T *extraParams) { - T alpha = extraParams[0]; - T beta = extraParams[1]; - T x = helper->relativeT(idx, -sd::DataTypeUtils::template max() / 10 , sd::DataTypeUtils::template max() / 10); - return x <= (T)0.f ? (T)0.f : sd::math::nd4j_igamma(alpha, x * beta); - } - - random_def T op(T valueX, Nd4jLong idx, Nd4jLong length, sd::graph::RandomGenerator *helper, T *extraParams) { - T alpha = extraParams[0]; - T beta = extraParams[1]; - return valueX <= (T)0.f ? (T)0.f : sd::math::nd4j_igamma(alpha, beta * valueX); - } - }; - - /** - * Basic DropOut/DropConnect Op - */ - template - class DropOut { - public: - - no_exec_special - no_exec_special_cuda - - method_idx - method_XY - - // please note: prob is chance to retain original value - random_def T op(T valueX, Nd4jLong idx, Nd4jLong length, sd::graph::RandomGenerator *helper, T *extraParams) { - T randVal = helper->relativeT(idx); - return randVal >= extraParams[0] ? (T) 0.0f : valueX; - } - }; - - template - class AlphaDropOut { - public: - - no_exec_special - no_exec_special_cuda - - method_idx - method_XY - - // please note: prob is chance to retain original value - random_def T op(T valueX, Nd4jLong idx, Nd4jLong length, sd::graph::RandomGenerator *helper, T *extraParams) { - T randVal = helper->relativeT(idx); - // extraParams[0] == p - // [1] = a - // [2] = b - // [3] = alphaPrime - return randVal >= extraParams[0] ? (T) extraParams[1] * extraParams[3] + extraParams[2] : extraParams[1] * valueX + extraParams[2]; - } - }; - - /** - * Inverted DropOut implementation, used in DL4j - */ - template - class DropOutInverted { - public: - - no_exec_special - no_exec_special_cuda - - method_idx - method_XY - - // please note: prob is chance to retain original value - random_def T op(T valueX, Nd4jLong idx, Nd4jLong length, sd::graph::RandomGenerator *helper, T *extraParams) { - T prob = extraParams[0]; - T randVal = helper->relativeT(idx); - return randVal >= prob ? (T) 0.0f : valueX / prob; - } - }; - - - template - class Linspace { - public: - - no_exec_special - no_exec_special_cuda - - method_X - method_XY - - random_def T op(Nd4jLong idx, Nd4jLong length, sd::graph::RandomGenerator *helper, T *extraParams) { - T from = extraParams[0]; - T to = extraParams[1]; - T step = extraParams[2]; - - if (step == static_cast(0.0f)) { - step = (T) idx / ((T)length - (T) 1.0f); - return from * ((T) 1.0f - step) + step * to; - } - return from + (idx * step); - - } - }; - - template - class ExponentialDistributionInv { // inverse exponential distribution - public: - no_exec_special - no_exec_special_cuda - - method_XY - - random_def T op(Nd4jLong idx, Nd4jLong length, sd::graph::RandomGenerator *helper, T *extraParams) { - T lambda = extraParams[0]; - T x = helper->relativeT(idx, sd::DataTypeUtils::template min(), (T)1.f - sd::DataTypeUtils::template min()); - return -sd::math::nd4j_log((T)1.f - x) / lambda; - } - - random_def T op(T valueX, Nd4jLong idx, Nd4jLong length, sd::graph::RandomGenerator *helper, T *extraParams) { - T lambda = extraParams[0]; - return -sd::math::nd4j_log((T)1.f - valueX) / lambda; // valueX must be within (0, 1] - } - }; - -} - -#endif //LIBND4J_RANDOM_OPS_H +/** + * This Op merges two arrays per-element, if probability meets threshold + */ +template +class ProbablisticMerge { + public: + no_exec_special no_exec_special_cuda + + method_idx method_X + + random_def T + op(T valueX, T valueY, Nd4jLong idx, Nd4jLong length, + sd::graph::RandomGenerator *helper, T *extraParams) { + T threshold = extraParams[0]; + T randVal = helper->relativeT(idx); + + return randVal <= threshold ? valueY : valueX; + } +}; + +/** + * This Op produces random values within specified boundaries. Disribution is + * uniform + */ +template +class UniformDistribution { + public: + no_exec_special no_exec_special_cuda + + method_XY method_X + + random_def T + op(Nd4jLong idx, Nd4jLong length, sd::graph::RandomGenerator *helper, + T *extraParams) { + return helper->relativeT(idx, extraParams[0], extraParams[1]); + } +}; + +/** + * This op produces single bernoulli trial + */ +template +class BernoulliDistribution { + public: + no_exec_special no_exec_special_cuda + + method_XY + + random_def T + op(Nd4jLong idx, Nd4jLong length, sd::graph::RandomGenerator *helper, + T *extraParams) { + return extraParams[0] >= helper->relativeT(idx) ? (T)1.0f : (T)0.0f; + } + + random_def T op(T valueX, Nd4jLong idx, Nd4jLong length, + sd::graph::RandomGenerator *helper, T *extraParams) { + return valueX >= helper->relativeT(idx) ? (T)1.0f : (T)0.0f; + } +}; + +/** + * This op produces single bernoulli trial + */ +template +class ExponentialDistribution { + public: + no_exec_special no_exec_special_cuda + + method_XY + + random_def T + op(Nd4jLong idx, Nd4jLong length, sd::graph::RandomGenerator *helper, + T *extraParams) { + T lambda = extraParams[0]; + T x = helper->relativeT( + idx, sd::DataTypeUtils::min(), + T(1.f) - sd::DataTypeUtils::template min()); // x from (0, 1) + // without bounds + T xVal = -sd::math::nd4j_log(x); + + return xVal <= (T)0.f + ? (T)0.f + : xVal / lambda; // pow((T) M_E, -(lambda * x)); + } + + random_def T op(T valueX, Nd4jLong idx, Nd4jLong length, + sd::graph::RandomGenerator *helper, T *extraParams) { + T lambda = extraParams[0]; + return valueX <= (T)0.f + ? (T)0.f + : (T)(valueX / lambda); // 1.f - sd::math::nd4j_exp(-lambda + // * valueX); //pow((T) M_E, + // -(lambda * valueX)); + } +}; + +template +class PoissonDistribution { + public: + no_exec_special no_exec_special_cuda + + method_XY + + random_def T + op(Nd4jLong idx, Nd4jLong length, sd::graph::RandomGenerator *helper, + T *extraParams) { + T lambda = extraParams[0]; + T x = helper->relativeT(idx, -sd::DataTypeUtils::template max() / 10, + sd::DataTypeUtils::template max() / 10); + return x <= (T)0.f ? (T)0.f + : sd::math::nd4j_igammac( + sd::math::nd4j_floor(x), lambda); + } + + random_def T op(T valueX, Nd4jLong idx, Nd4jLong length, + sd::graph::RandomGenerator *helper, T *extraParams) { + T lambda = extraParams[0]; + return valueX <= (T)0.f ? (T)0.f + : (T)sd::math::nd4j_igammac( + sd::math::nd4j_floor(valueX), lambda); + } +}; + +template +class GammaDistribution { + public: + no_exec_special no_exec_special_cuda + + method_XY + + random_def T + op(Nd4jLong idx, Nd4jLong length, sd::graph::RandomGenerator *helper, + T *extraParams) { + T alpha = extraParams[0]; + T beta = extraParams[1]; + T x = helper->relativeT(idx, -sd::DataTypeUtils::template max() / 10, + sd::DataTypeUtils::template max() / 10); + return x <= (T)0.f ? (T)0.f + : sd::math::nd4j_igamma(alpha, x * beta); + } + + random_def T op(T valueX, Nd4jLong idx, Nd4jLong length, + sd::graph::RandomGenerator *helper, T *extraParams) { + T alpha = extraParams[0]; + T beta = extraParams[1]; + return valueX <= (T)0.f + ? (T)0.f + : sd::math::nd4j_igamma(alpha, beta * valueX); + } +}; + +/** + * Basic DropOut/DropConnect Op + */ +template +class DropOut { + public: + no_exec_special no_exec_special_cuda + + method_idx method_XY + + // please note: prob is chance to retain original value + random_def T + op(T valueX, Nd4jLong idx, Nd4jLong length, + sd::graph::RandomGenerator *helper, T *extraParams) { + T randVal = helper->relativeT(idx); + return randVal >= extraParams[0] ? (T)0.0f : valueX; + } +}; + +template +class AlphaDropOut { + public: + no_exec_special no_exec_special_cuda + + method_idx method_XY + + // please note: prob is chance to retain original value + random_def T + op(T valueX, Nd4jLong idx, Nd4jLong length, + sd::graph::RandomGenerator *helper, T *extraParams) { + T randVal = helper->relativeT(idx); + // extraParams[0] == p + // [1] = a + // [2] = b + // [3] = alphaPrime + return randVal >= extraParams[0] + ? (T)extraParams[1] * extraParams[3] + extraParams[2] + : extraParams[1] * valueX + extraParams[2]; + } +}; + +/** + * Inverted DropOut implementation, used in DL4j + */ +template +class DropOutInverted { + public: + no_exec_special no_exec_special_cuda + + method_idx method_XY + + // please note: prob is chance to retain original value + random_def T + op(T valueX, Nd4jLong idx, Nd4jLong length, + sd::graph::RandomGenerator *helper, T *extraParams) { + T prob = extraParams[0]; + T randVal = helper->relativeT(idx); + return randVal >= prob ? (T)0.0f : valueX / prob; + } +}; + +template +class Linspace { + public: + no_exec_special no_exec_special_cuda + + method_X method_XY + + random_def T + op(Nd4jLong idx, Nd4jLong length, sd::graph::RandomGenerator *helper, + T *extraParams) { + T from = extraParams[0]; + T to = extraParams[1]; + T step = extraParams[2]; + + if (step == static_cast(0.0f)) { + step = (T)idx / ((T)length - (T)1.0f); + return from * ((T)1.0f - step) + step * to; + } + return from + (idx * step); + } +}; + +template +class ExponentialDistributionInv { // inverse exponential distribution + public: + no_exec_special no_exec_special_cuda + + method_XY + + random_def T + op(Nd4jLong idx, Nd4jLong length, sd::graph::RandomGenerator *helper, + T *extraParams) { + T lambda = extraParams[0]; + T x = helper->relativeT(idx, sd::DataTypeUtils::template min(), + (T)1.f - sd::DataTypeUtils::template min()); + return -sd::math::nd4j_log((T)1.f - x) / lambda; + } + + random_def T op(T valueX, Nd4jLong idx, Nd4jLong length, + sd::graph::RandomGenerator *helper, T *extraParams) { + T lambda = extraParams[0]; + return -sd::math::nd4j_log((T)1.f - valueX) / + lambda; // valueX must be within (0, 1] + } +}; + +} // namespace randomOps + +#endif // LIBND4J_RANDOM_OPS_H diff --git a/libnd4j/include/ops/special_random_ops.h b/libnd4j/include/ops/special_random_ops.h index f9bacf5cb88a..6999fcbe8060 100644 --- a/libnd4j/include/ops/special_random_ops.h +++ b/libnd4j/include/ops/special_random_ops.h @@ -21,824 +21,892 @@ #ifndef LIBND4J_SPECIAL_RANDOM_OPS_H #define LIBND4J_SPECIAL_RANDOM_OPS_H -#include -#include +#include #include +#include +#include #include -#include namespace randomOps { ////////////////////////////////////////////////////////////////////// - template - class Choice { - public: - - method_idx - method_X - method_XY - - static const bool requiresSpecial = true; +template +class Choice { + public: + method_idx method_X method_XY + static const bool requiresSpecial = true; #ifdef __CUDACC__ - __device__ static inline void specialOpCuda(Nd4jPointer state, T const* x, Nd4jLong const* xShapeBuffer, T const* y, Nd4jLong const* yShapeBuffer, T *z, Nd4jLong const* zShapeBuffer, T *extraArguments) { - /** - * X holds data, - * Y holds probabilities - * Z will hold results - */ - - // TODO: we probably might want to skip this sum, and state that probabilities array should be real probabilities, i.e. should sum to 1.0 - //T probSum = extraArguments[0]; - - __shared__ Nd4jLong xLength; - __shared__ Nd4jLong yLength; - __shared__ Nd4jLong zLength; - - __shared__ Nd4jLong xEWS; - __shared__ Nd4jLong yEWS; - __shared__ Nd4jLong zEWS; - __shared__ char xOrder; - __shared__ char yOrder; - __shared__ char zOrder; - - __shared__ sd::graph::RandomGenerator *rng; - __shared__ unsigned char *cB; - __shared__ unsigned char *dB; - __shared__ sd::graph::RandomGenerator *devRng; - - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - rng = (sd::graph::RandomGenerator*) shmem; - cB = shmem; - devRng = reinterpret_cast (state); - dB = reinterpret_cast (state); - - xLength = shape::length(xShapeBuffer); - yLength = shape::length(yShapeBuffer); - zLength = shape::length(zShapeBuffer); - - xEWS = shape::elementWiseStride(xShapeBuffer); - yEWS = shape::elementWiseStride(yShapeBuffer); - zEWS = shape::elementWiseStride(zShapeBuffer); - xOrder = shape::order(xShapeBuffer); - yOrder = shape::order(yShapeBuffer); - zOrder = shape::order(zShapeBuffer); - } - __syncthreads(); - - // using this loop instead of memcpy - for (int e = threadIdx.x; e < sizeof(sd::graph::RandomGenerator); e+= blockDim.x) - cB[e] = dB[e]; - - __syncthreads(); - - int tid = blockIdx.x * blockDim.x + threadIdx.x; - - if (zEWS >= 1 && xEWS >= 1 && yEWS >= 1 && xOrder == yOrder && xOrder == zOrder) { - for (Nd4jLong e = tid; e < zLength; e+=blockDim.x * gridDim.x) { - T prob = rng->relativeT(e); - T cumProb = (T) 0.0f; - for (Nd4jLong f = 0; f < yLength; f++) { - T relProb = y[f * yEWS]; - cumProb += relProb; - - if (prob <= cumProb || f == yLength - 1) { - z[e * zEWS] = x[f * xEWS]; - f += yLength; - } -// __syncthreads(); // Eliminated due RTX20xx specific - } -// __syncthreads(); // Eliminated due RTX20xx specific - } - } - else { - - for (Nd4jLong i = tid; i < zLength; i+=blockDim.x * gridDim.x) { - - auto zOffset2 = shape::getIndexOffset(i, zShapeBuffer); - T prob = rng->relativeT(i); - T cumProb = (T) 0.0f; - - for (Nd4jLong f = 0; f < yLength; f++) { - - auto yOffset2 = shape::getIndexOffset(f, yShapeBuffer); - T relProb = y[yOffset2]; - cumProb += relProb; - - if (prob <= cumProb || f == yLength - 1) { - - auto xOffset2 = shape::getIndexOffset(f, xShapeBuffer); - z[zOffset2] = x[xOffset2]; - f += yLength; - } -// __syncthreads(); // Eliminated due RTX20xx specific - } -// __syncthreads(); // Eliminated due RTX20xx specific - } - } + __device__ static inline void specialOpCuda( + Nd4jPointer state, T const *x, Nd4jLong const *xShapeBuffer, T const *y, + Nd4jLong const *yShapeBuffer, T *z, Nd4jLong const *zShapeBuffer, + T *extraArguments) { + /** + * X holds data, + * Y holds probabilities + * Z will hold results + */ + + // TODO: we probably might want to skip this sum, and state that + // probabilities array should be real probabilities, i.e. should sum to 1.0 + // T probSum = extraArguments[0]; + + __shared__ Nd4jLong xLength; + __shared__ Nd4jLong yLength; + __shared__ Nd4jLong zLength; + + __shared__ Nd4jLong xEWS; + __shared__ Nd4jLong yEWS; + __shared__ Nd4jLong zEWS; + __shared__ char xOrder; + __shared__ char yOrder; + __shared__ char zOrder; + + __shared__ sd::graph::RandomGenerator *rng; + __shared__ unsigned char *cB; + __shared__ unsigned char *dB; + __shared__ sd::graph::RandomGenerator *devRng; + + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + rng = (sd::graph::RandomGenerator *)shmem; + cB = shmem; + devRng = reinterpret_cast(state); + dB = reinterpret_cast(state); + + xLength = shape::length(xShapeBuffer); + yLength = shape::length(yShapeBuffer); + zLength = shape::length(zShapeBuffer); + + xEWS = shape::elementWiseStride(xShapeBuffer); + yEWS = shape::elementWiseStride(yShapeBuffer); + zEWS = shape::elementWiseStride(zShapeBuffer); + xOrder = shape::order(xShapeBuffer); + yOrder = shape::order(yShapeBuffer); + zOrder = shape::order(zShapeBuffer); + } + __syncthreads(); + + // using this loop instead of memcpy + for (int e = threadIdx.x; e < sizeof(sd::graph::RandomGenerator); + e += blockDim.x) + cB[e] = dB[e]; + + __syncthreads(); + + int tid = blockIdx.x * blockDim.x + threadIdx.x; + + if (zEWS >= 1 && xEWS >= 1 && yEWS >= 1 && xOrder == yOrder && + xOrder == zOrder) { + for (Nd4jLong e = tid; e < zLength; e += blockDim.x * gridDim.x) { + T prob = rng->relativeT(e); + T cumProb = (T)0.0f; + for (Nd4jLong f = 0; f < yLength; f++) { + T relProb = y[f * yEWS]; + cumProb += relProb; + + if (prob <= cumProb || f == yLength - 1) { + z[e * zEWS] = x[f * xEWS]; + f += yLength; + } + // __syncthreads(); // Eliminated due RTX20xx + // specific } + // __syncthreads(); // Eliminated due RTX20xx + // specific + } + } else { + for (Nd4jLong i = tid; i < zLength; i += blockDim.x * gridDim.x) { + auto zOffset2 = shape::getIndexOffset(i, zShapeBuffer); + T prob = rng->relativeT(i); + T cumProb = (T)0.0f; + + for (Nd4jLong f = 0; f < yLength; f++) { + auto yOffset2 = shape::getIndexOffset(f, yShapeBuffer); + T relProb = y[yOffset2]; + cumProb += relProb; + + if (prob <= cumProb || f == yLength - 1) { + auto xOffset2 = shape::getIndexOffset(f, xShapeBuffer); + z[zOffset2] = x[xOffset2]; + f += yLength; + } + // __syncthreads(); // Eliminated due RTX20xx + // specific + } + // __syncthreads(); // Eliminated due RTX20xx + // specific + } + } + } #endif - static inline void specialOp(Nd4jPointer state, const T *x, const Nd4jLong *xShapeBuffer, const T *y, const Nd4jLong *yShapeBuffer, T *z, const Nd4jLong *zShapeBuffer, T *extraArguments) { - /** - * X holds data, - * Y holds probabilities - * Z will hold results - */ - - //sd::random::RandomBuffer *buffer = reinterpret_cast (state); - sd::graph::RandomGenerator* rng = reinterpret_cast(state); - // TODO: we probably might want to skip this sum, and state that probabilities array should be real probabilities, i.e. should sum to 1.0 - //T probSum = extraArguments[0]; - - auto xLength = shape::length(xShapeBuffer); - auto yLength = shape::length(yShapeBuffer); - auto zLength = shape::length(zShapeBuffer); - - auto xEWS = shape::elementWiseStride(xShapeBuffer); - auto yEWS = shape::elementWiseStride(yShapeBuffer); - auto zEWS = shape::elementWiseStride(zShapeBuffer); - - int elementsPerThread = zLength / TAD_THRESHOLD; - int _threads = sd::math::nd4j_max(1, elementsPerThread); - _threads = sd::math::nd4j_min(_threads, sd::Environment::getInstance().maxThreads()); - - if (zEWS >= 1 && xEWS >= 1 && yEWS >= 1) { - auto func = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) { - T prob = rng->relativeT(e); - T cumProb = (T) 0.0f; - for (Nd4jLong f = 0; f < yLength; f++) { - T relProb = y[f * yEWS]; - cumProb += relProb; - - if (prob <= cumProb || f == yLength - 1) { - z[e * zEWS] = x[f * xEWS]; - break; - } - } - } - }; - - samediff::Threads::parallel_for(func, 0, zLength, 1, _threads); + static inline void specialOp(Nd4jPointer state, const T *x, + const Nd4jLong *xShapeBuffer, const T *y, + const Nd4jLong *yShapeBuffer, T *z, + const Nd4jLong *zShapeBuffer, + T *extraArguments) { + /** + * X holds data, + * Y holds probabilities + * Z will hold results + */ + + // sd::random::RandomBuffer *buffer = + // reinterpret_cast (state); + sd::graph::RandomGenerator *rng = + reinterpret_cast(state); + // TODO: we probably might want to skip this sum, and state that + // probabilities array should be real probabilities, i.e. should sum to 1.0 + // T probSum = extraArguments[0]; + + auto xLength = shape::length(xShapeBuffer); + auto yLength = shape::length(yShapeBuffer); + auto zLength = shape::length(zShapeBuffer); + + auto xEWS = shape::elementWiseStride(xShapeBuffer); + auto yEWS = shape::elementWiseStride(yShapeBuffer); + auto zEWS = shape::elementWiseStride(zShapeBuffer); + + int elementsPerThread = zLength / TAD_THRESHOLD; + int _threads = sd::math::nd4j_max(1, elementsPerThread); + _threads = sd::math::nd4j_min( + _threads, sd::Environment::getInstance().maxThreads()); + + if (zEWS >= 1 && xEWS >= 1 && yEWS >= 1) { + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + T prob = rng->relativeT(e); + T cumProb = (T)0.0f; + for (Nd4jLong f = 0; f < yLength; f++) { + T relProb = y[f * yEWS]; + cumProb += relProb; + + if (prob <= cumProb || f == yLength - 1) { + z[e * zEWS] = x[f * xEWS]; + break; } - else { - - auto func = PRAGMA_THREADS_FOR { - for (Nd4jLong i = 0; i < zLength; i++) { - - auto zOffset2 = shape::getIndexOffset(i, zShapeBuffer); - T prob = rng->relativeT(i); - T cumProb = (T) 0.0f; - - for (Nd4jLong f = 0; f < yLength; f++) { - - auto yOffset2 = shape::getIndexOffset(f, yShapeBuffer); - T relProb = y[yOffset2]; - cumProb += relProb; - - if (prob <= cumProb || f == yLength - 1) { - - auto xOffset2 = shape::getIndexOffset(f, xShapeBuffer); - z[zOffset2] = x[xOffset2]; - break; - } - } - } - }; - - samediff::Threads::parallel_for(func, 0, zLength, 1, _threads); + } + } + }; + + samediff::Threads::parallel_for(func, 0, zLength, 1, _threads); + } else { + auto func = PRAGMA_THREADS_FOR { + for (Nd4jLong i = 0; i < zLength; i++) { + auto zOffset2 = shape::getIndexOffset(i, zShapeBuffer); + T prob = rng->relativeT(i); + T cumProb = (T)0.0f; + + for (Nd4jLong f = 0; f < yLength; f++) { + auto yOffset2 = shape::getIndexOffset(f, yShapeBuffer); + T relProb = y[yOffset2]; + cumProb += relProb; + + if (prob <= cumProb || f == yLength - 1) { + auto xOffset2 = shape::getIndexOffset(f, xShapeBuffer); + z[zOffset2] = x[xOffset2]; + break; } + } } - }; + }; + samediff::Threads::parallel_for(func, 0, zLength, 1, _threads); + } + } +}; ////////////////////////////////////////////////////////////////////// - /** - * This Op produces random values within specified boundaries. Distribuion is Gaussian - */ - template - class GaussianDistribution { - public: +/** + * This Op produces random values within specified boundaries. Distribuion is + * Gaussian + */ +template +class GaussianDistribution { + public: + method_XY method_X method_idx - method_XY - method_X - method_idx - - static const bool requiresSpecial = true; + static const bool requiresSpecial = true; #ifdef __CUDACC__ - __device__ static inline void specialOpCuda(Nd4jPointer state, T const* x, Nd4jLong const* xShapeBuffer, T const* y, Nd4jLong const *yShapeBuffer, T *z, Nd4jLong const* zShapeBuffer, T *extraArguments) { - - __shared__ T epsilon; - __shared__ T two_pi; - - __shared__ Nd4jLong zLength; - __shared__ Nd4jLong zEWS; - __shared__ Nd4jLong yEWS; - __shared__ T mean; - __shared__ T stddev; - __shared__ int step; - - __shared__ T *tZ; - - __shared__ sd::graph::RandomGenerator* rng; - __shared__ unsigned char *cB; - __shared__ unsigned char *dB; - __shared__ sd::graph::RandomGenerator *devRng; - - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - rng = reinterpret_cast(shmem); - cB = shmem; - devRng = reinterpret_cast (state); - dB = reinterpret_cast (state); - - tZ = reinterpret_cast(shmem + sizeof(sd::graph::RandomGenerator)); - - zLength = shape::length(zShapeBuffer); - zEWS = shape::elementWiseStride(zShapeBuffer); - yEWS = shape::elementWiseStride(yShapeBuffer); - - - epsilon = static_cast(1e-5); - two_pi = static_cast(2.0f) * static_cast(3.14159265358979323846); - - mean = extraArguments[0]; - stddev = extraArguments[1]; - - step = (blockDim.x * gridDim.x); - } - __syncthreads(); - - // using this loop instead of memcpy - for (int e = threadIdx.x; e < sizeof(sd::graph::RandomGenerator); e+= blockDim.x) - cB[e] = dB[e]; - - __syncthreads(); - - int tid = blockIdx.x * blockDim.x + threadIdx.x; - - int middle = zLength % 2 == 0 ? zLength / 2 : zLength / 2 + 1; - T t(-2.0f); - - for (int e = tid; e < middle; e += step) { - auto epm = e + middle; - // we need to get random values - T r0 = rng->relativeT(e, epsilon, static_cast(1.0f)); - T r1 = rng->relativeT(epm, epsilon, static_cast(1.0f)); - - T realMean0 = y == z ? mean : y[e * yEWS]; - - z[e * zEWS] = (sd::math::nd4j_sqrt(t * sd::math::nd4j_log(r0)) * sd::math::nd4j_cos(two_pi * r1)) * stddev + realMean0; - - if (epm < zLength) { - T realMean1 = y == z ? mean : y[epm * yEWS]; - z[epm * zEWS] = (sd::math::nd4j_sqrt(t * sd::math::nd4j_log(r0)) * sd::math::nd4j_sin(two_pi * r1)) * stddev + realMean1; - } - } - } + __device__ static inline void specialOpCuda( + Nd4jPointer state, T const *x, Nd4jLong const *xShapeBuffer, T const *y, + Nd4jLong const *yShapeBuffer, T *z, Nd4jLong const *zShapeBuffer, + T *extraArguments) { + __shared__ T epsilon; + __shared__ T two_pi; + + __shared__ Nd4jLong zLength; + __shared__ Nd4jLong zEWS; + __shared__ Nd4jLong yEWS; + __shared__ T mean; + __shared__ T stddev; + __shared__ int step; + + __shared__ T *tZ; + + __shared__ sd::graph::RandomGenerator *rng; + __shared__ unsigned char *cB; + __shared__ unsigned char *dB; + __shared__ sd::graph::RandomGenerator *devRng; + + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + rng = reinterpret_cast(shmem); + cB = shmem; + devRng = reinterpret_cast(state); + dB = reinterpret_cast(state); + + tZ = reinterpret_cast(shmem + sizeof(sd::graph::RandomGenerator)); + + zLength = shape::length(zShapeBuffer); + zEWS = shape::elementWiseStride(zShapeBuffer); + yEWS = shape::elementWiseStride(yShapeBuffer); + + epsilon = static_cast(1e-5); + two_pi = static_cast(2.0f) * static_cast(3.14159265358979323846); + + mean = extraArguments[0]; + stddev = extraArguments[1]; + + step = (blockDim.x * gridDim.x); + } + __syncthreads(); + + // using this loop instead of memcpy + for (int e = threadIdx.x; e < sizeof(sd::graph::RandomGenerator); + e += blockDim.x) + cB[e] = dB[e]; + + __syncthreads(); + + int tid = blockIdx.x * blockDim.x + threadIdx.x; + + int middle = zLength % 2 == 0 ? zLength / 2 : zLength / 2 + 1; + T t(-2.0f); + + for (int e = tid; e < middle; e += step) { + auto epm = e + middle; + // we need to get random values + T r0 = rng->relativeT(e, epsilon, static_cast(1.0f)); + T r1 = rng->relativeT(epm, epsilon, static_cast(1.0f)); + + T realMean0 = y == z ? mean : y[e * yEWS]; + + z[e * zEWS] = + (sd::math::nd4j_sqrt(t * sd::math::nd4j_log(r0)) * + sd::math::nd4j_cos(two_pi * r1)) * + stddev + + realMean0; + + if (epm < zLength) { + T realMean1 = y == z ? mean : y[epm * yEWS]; + z[epm * zEWS] = + (sd::math::nd4j_sqrt(t * sd::math::nd4j_log(r0)) * + sd::math::nd4j_sin(two_pi * r1)) * + stddev + + realMean1; + } + } + } #endif - - static inline void - specialOp(Nd4jPointer state, const T *x, const Nd4jLong *xShapeBuffer, const T *y, const Nd4jLong *yShapeBuffer, T *z, const Nd4jLong *zShapeBuffer, T *extraArguments) { - const T two_pi = static_cast(2.0f) * static_cast(3.14159265358979323846); - - auto zLength = shape::length(zShapeBuffer); - auto yEWS = shape::elementWiseStride(yShapeBuffer); - auto zEWS = shape::elementWiseStride(zShapeBuffer); - - auto middle = zLength % 2 + zLength / 2; - - int elementsPerThread = middle / TAD_THRESHOLD; - int _threads = sd::math::nd4j_max(1, elementsPerThread); - _threads = sd::math::nd4j_min(_threads, sd::Environment::getInstance().maxThreads()); - - int span = (middle / _threads) + 8; - - // we're enforcing even chunks, since it's mandatory for this algorithm - span -= span % 2; - - //sd::random::RandomBuffer *buffer = reinterpret_cast (state); - sd::graph::RandomGenerator* rng = reinterpret_cast(state); - const T mean = extraArguments[0]; - const T stddev = extraArguments[1]; - - const T epsilon = static_cast(1e-5); - - auto func = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) { - auto epm = e + middle; - - // we need to get random values - T r0 = rng->relativeT(e, epsilon, static_cast(1.0f)); - T r1 = rng->relativeT(epm, epsilon, static_cast(1.0f)); - - T realMean0 = y == z ? mean : y[e * yEWS]; - - auto z0 = (sd::math::nd4j_sqrt(static_cast(-2.0f) * sd::math::nd4j_log(r0)) * - sd::math::nd4j_cos(two_pi * r1)) * stddev + realMean0; - z[e * zEWS] = z0; - - if (epm < zLength) { - T realMean1 = y == z ? mean : y[epm * yEWS]; - auto z1 = (sd::math::nd4j_sqrt(static_cast(-2.0f) * sd::math::nd4j_log(r0)) * - sd::math::nd4j_sin(two_pi * r1)) * stddev + realMean1; - z[epm * zEWS] = z1; - } - } - }; - - samediff::Threads::parallel_for(func, 0, middle, 1, _threads); + static inline void specialOp(Nd4jPointer state, const T *x, + const Nd4jLong *xShapeBuffer, const T *y, + const Nd4jLong *yShapeBuffer, T *z, + const Nd4jLong *zShapeBuffer, + T *extraArguments) { + const T two_pi = + static_cast(2.0f) * static_cast(3.14159265358979323846); + + auto zLength = shape::length(zShapeBuffer); + auto yEWS = shape::elementWiseStride(yShapeBuffer); + auto zEWS = shape::elementWiseStride(zShapeBuffer); + + auto middle = zLength % 2 + zLength / 2; + + int elementsPerThread = middle / TAD_THRESHOLD; + int _threads = sd::math::nd4j_max(1, elementsPerThread); + _threads = sd::math::nd4j_min( + _threads, sd::Environment::getInstance().maxThreads()); + + int span = (middle / _threads) + 8; + + // we're enforcing even chunks, since it's mandatory for this algorithm + span -= span % 2; + + // sd::random::RandomBuffer *buffer = + // reinterpret_cast (state); + sd::graph::RandomGenerator *rng = + reinterpret_cast(state); + const T mean = extraArguments[0]; + const T stddev = extraArguments[1]; + + const T epsilon = static_cast(1e-5); + + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + auto epm = e + middle; + + // we need to get random values + T r0 = rng->relativeT(e, epsilon, static_cast(1.0f)); + T r1 = rng->relativeT(epm, epsilon, static_cast(1.0f)); + + T realMean0 = y == z ? mean : y[e * yEWS]; + + auto z0 = (sd::math::nd4j_sqrt(static_cast(-2.0f) * + sd::math::nd4j_log(r0)) * + sd::math::nd4j_cos(two_pi * r1)) * + stddev + + realMean0; + z[e * zEWS] = z0; + + if (epm < zLength) { + T realMean1 = y == z ? mean : y[epm * yEWS]; + auto z1 = (sd::math::nd4j_sqrt(static_cast(-2.0f) * + sd::math::nd4j_log(r0)) * + sd::math::nd4j_sin(two_pi * r1)) * + stddev + + realMean1; + z[epm * zEWS] = z1; } + } }; + samediff::Threads::parallel_for(func, 0, middle, 1, _threads); + } +}; ////////////////////////////////////////////////////////////////////// - /** - * This Op produces random values within [0..N], Distribuion is binomial - */ - template - class BinomialDistribution { - public: - - - method_XY - method_X - method_idx +/** + * This Op produces random values within [0..N], Distribuion is binomial + */ +template +class BinomialDistribution { + public: + method_XY method_X method_idx - static const bool requiresSpecial = true; + static const bool requiresSpecial = true; #ifdef __CUDACC__ - __device__ static inline void specialOpCuda(Nd4jPointer state, T const* x, Nd4jLong const* xShapeBuffer, T const* y, Nd4jLong const* yShapeBuffer, T *z, Nd4jLong const* zShapeBuffer, T *extraArguments) { - int trials = (int) extraArguments[0]; - T prob = extraArguments[1]; - - __shared__ Nd4jLong zLength; - __shared__ int yEWS; - __shared__ int zEWS; - - __shared__ sd::graph::RandomGenerator* rng; - __shared__ unsigned char *cB; - __shared__ unsigned char *dB; - __shared__ sd::graph::RandomGenerator *devRng; - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - rng = reinterpret_cast(shmem); - cB = shmem; - devRng = reinterpret_cast(state); - dB = reinterpret_cast (state); - - zLength = shape::length(zShapeBuffer); - yEWS = shape::elementWiseStride(yShapeBuffer); - zEWS = shape::elementWiseStride(zShapeBuffer); - } - __syncthreads(); - - // using this loop instead of memcpy - for (int e = threadIdx.x; e < sizeof(sd::graph::RandomGenerator); e+= blockDim.x) - cB[e] = dB[e]; - - __syncthreads(); - - int tid = blockIdx.x * blockDim.x + threadIdx.x; - - for (Nd4jLong e = tid; e < zLength; e += blockDim.x * gridDim.x) { - int success = 0; - for (int t = 1; t <= trials; t++) { - T randVal = rng->relativeT((e+1) * t); - if (y != z) { - // we're using external probs - prob = y[(t-1) * yEWS]; - } - - if (randVal < prob) - success++; - } - // if trials is set to 0, effectively we just have successful memset - z[e * zEWS] = static_cast(success); - } + __device__ static inline void specialOpCuda( + Nd4jPointer state, T const *x, Nd4jLong const *xShapeBuffer, T const *y, + Nd4jLong const *yShapeBuffer, T *z, Nd4jLong const *zShapeBuffer, + T *extraArguments) { + int trials = (int)extraArguments[0]; + T prob = extraArguments[1]; + + __shared__ Nd4jLong zLength; + __shared__ int yEWS; + __shared__ int zEWS; + + __shared__ sd::graph::RandomGenerator *rng; + __shared__ unsigned char *cB; + __shared__ unsigned char *dB; + __shared__ sd::graph::RandomGenerator *devRng; + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + rng = reinterpret_cast(shmem); + cB = shmem; + devRng = reinterpret_cast(state); + dB = reinterpret_cast(state); + + zLength = shape::length(zShapeBuffer); + yEWS = shape::elementWiseStride(yShapeBuffer); + zEWS = shape::elementWiseStride(zShapeBuffer); + } + __syncthreads(); + + // using this loop instead of memcpy + for (int e = threadIdx.x; e < sizeof(sd::graph::RandomGenerator); + e += blockDim.x) + cB[e] = dB[e]; + + __syncthreads(); + + int tid = blockIdx.x * blockDim.x + threadIdx.x; + + for (Nd4jLong e = tid; e < zLength; e += blockDim.x * gridDim.x) { + int success = 0; + for (int t = 1; t <= trials; t++) { + T randVal = rng->relativeT((e + 1) * t); + if (y != z) { + // we're using external probs + prob = y[(t - 1) * yEWS]; } -#endif - - static inline void specialOp(Nd4jPointer state, const T *x, const Nd4jLong *xShapeBuffer, const T *y, const Nd4jLong *yShapeBuffer, T *z, const Nd4jLong *zShapeBuffer, T *extraArguments) { - int trials = (int) extraArguments[0]; - - Nd4jLong zLength = shape::length(zShapeBuffer); - - auto yEWS = shape::elementWiseStride(yShapeBuffer); - auto zEWS = shape::elementWiseStride(zShapeBuffer); - - int elementsPerThread = zLength / TAD_THRESHOLD; - int _threads = sd::math::nd4j_max(1, elementsPerThread); - _threads = sd::math::nd4j_min(_threads, sd::Environment::getInstance().maxThreads()); - - T prob = extraArguments[1]; - - sd::graph::RandomGenerator* rng = reinterpret_cast(state); - auto func = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) { - int success = 0; - for (int t = 1; t <= trials; t++) { - T randVal = rng->relativeT((e+1) * t); - if (y != z) { - // we're using external probs - prob = y[(t-1) * yEWS]; - } - - if (randVal < prob) - success++; - } - - // if trials is set to 0, effectively we just have successful memset - z[e * zEWS] = static_cast(success); - } - }; + if (randVal < prob) success++; + } + // if trials is set to 0, effectively we just have successful memset + z[e * zEWS] = static_cast(success); + } + } +#endif - samediff::Threads::parallel_for(func, 0, zLength, 1, _threads); + static inline void specialOp(Nd4jPointer state, const T *x, + const Nd4jLong *xShapeBuffer, const T *y, + const Nd4jLong *yShapeBuffer, T *z, + const Nd4jLong *zShapeBuffer, + T *extraArguments) { + int trials = (int)extraArguments[0]; + + Nd4jLong zLength = shape::length(zShapeBuffer); + + auto yEWS = shape::elementWiseStride(yShapeBuffer); + auto zEWS = shape::elementWiseStride(zShapeBuffer); + + int elementsPerThread = zLength / TAD_THRESHOLD; + int _threads = sd::math::nd4j_max(1, elementsPerThread); + _threads = sd::math::nd4j_min( + _threads, sd::Environment::getInstance().maxThreads()); + + T prob = extraArguments[1]; + + sd::graph::RandomGenerator *rng = + reinterpret_cast(state); + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + int success = 0; + for (int t = 1; t <= trials; t++) { + T randVal = rng->relativeT((e + 1) * t); + if (y != z) { + // we're using external probs + prob = y[(t - 1) * yEWS]; + } + + if (randVal < prob) success++; } + + // if trials is set to 0, effectively we just have successful memset + z[e * zEWS] = static_cast(success); + } }; + samediff::Threads::parallel_for(func, 0, zLength, 1, _threads); + } +}; ////////////////////////////////////////////////////////////////////// - /** - * This Op produces random values within [0..N], Distribuion is binomial - */ - template - class BinomialDistributionEx { - public: +/** + * This Op produces random values within [0..N], Distribuion is binomial + */ +template +class BinomialDistributionEx { + public: + method_XY method_X method_idx - - method_XY - method_X - method_idx - - static const bool requiresSpecial = true; + static const bool requiresSpecial = true; #ifdef __CUDACC__ - __device__ static inline void specialOpCuda(Nd4jPointer state, T const* x, Nd4jLong const* xShapeBuffer, T const* y, Nd4jLong const* yShapeBuffer, T *z, Nd4jLong const* zShapeBuffer, T *extraArguments) { - int trials = (int) extraArguments[0]; - T prob = extraArguments[1]; - - __shared__ Nd4jLong zLength; - __shared__ int yEWS; - __shared__ int zEWS; - - __shared__ sd::graph::RandomGenerator* rng; - __shared__ unsigned char *cB; - __shared__ unsigned char *dB; - __shared__ sd::graph::RandomGenerator *devRng; - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - rng = (sd::graph::RandomGenerator*) shmem; - cB = shmem; - devRng = reinterpret_cast (state); - dB = reinterpret_cast (state); - - zLength = shape::length(zShapeBuffer); - yEWS = shape::elementWiseStride(yShapeBuffer); - zEWS = shape::elementWiseStride(zShapeBuffer); - } - __syncthreads(); - - // using this loop instead of memcpy - for (int e = threadIdx.x; e < sizeof(sd::graph::RandomGenerator); e+= blockDim.x) - cB[e] = dB[e]; - - __syncthreads(); - - int tid = blockIdx.x * blockDim.x + threadIdx.x; - - for (Nd4jLong e = tid; e < zLength; e += blockDim.x * gridDim.x) { - int success = 0; - for (int t = 1; t <= trials; t++) { - T randVal = rng->relativeT((e+1) * t); - if (y != z) { - // we're using external probs - prob = y[e * yEWS]; - } - - if (randVal < prob) - success++; - } - - // if trials is set to 0, effectively we just have successful memset - z[e * zEWS] = (T) success; - } + __device__ static inline void specialOpCuda( + Nd4jPointer state, T const *x, Nd4jLong const *xShapeBuffer, T const *y, + Nd4jLong const *yShapeBuffer, T *z, Nd4jLong const *zShapeBuffer, + T *extraArguments) { + int trials = (int)extraArguments[0]; + T prob = extraArguments[1]; + + __shared__ Nd4jLong zLength; + __shared__ int yEWS; + __shared__ int zEWS; + + __shared__ sd::graph::RandomGenerator *rng; + __shared__ unsigned char *cB; + __shared__ unsigned char *dB; + __shared__ sd::graph::RandomGenerator *devRng; + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + rng = (sd::graph::RandomGenerator *)shmem; + cB = shmem; + devRng = reinterpret_cast(state); + dB = reinterpret_cast(state); + + zLength = shape::length(zShapeBuffer); + yEWS = shape::elementWiseStride(yShapeBuffer); + zEWS = shape::elementWiseStride(zShapeBuffer); + } + __syncthreads(); + + // using this loop instead of memcpy + for (int e = threadIdx.x; e < sizeof(sd::graph::RandomGenerator); + e += blockDim.x) + cB[e] = dB[e]; + + __syncthreads(); + + int tid = blockIdx.x * blockDim.x + threadIdx.x; + + for (Nd4jLong e = tid; e < zLength; e += blockDim.x * gridDim.x) { + int success = 0; + for (int t = 1; t <= trials; t++) { + T randVal = rng->relativeT((e + 1) * t); + if (y != z) { + // we're using external probs + prob = y[e * yEWS]; } -#endif - - static inline void specialOp(Nd4jPointer state, const T *x, const Nd4jLong *xShapeBuffer, const T *y, const Nd4jLong *yShapeBuffer, T *z, const Nd4jLong *zShapeBuffer, T *extraArguments) { - int trials = (int) extraArguments[0]; - - Nd4jLong zLength = shape::length(zShapeBuffer); - - auto yEWS = shape::elementWiseStride(yShapeBuffer); - auto zEWS = shape::elementWiseStride(zShapeBuffer); - - int elementsPerThread = zLength / TAD_THRESHOLD; - int _threads = sd::math::nd4j_max(1, elementsPerThread); - _threads = sd::math::nd4j_min(_threads, sd::Environment::getInstance().maxThreads()); - T prob = extraArguments[1]; + if (randVal < prob) success++; + } - auto rng = reinterpret_cast(state); - auto func = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) { - - int success = 0; - for (int t = 1; t <= trials; t++) { - T randVal = rng->relativeT((e+1) * t); - if (y != z) { - // we're using external probs - prob = y[e * yEWS]; - } - - if (randVal < prob) - success++; - } - - // if trials is set to 0, effectively we just have successful memset - z[e * zEWS] = static_cast(success); - } - }; + // if trials is set to 0, effectively we just have successful memset + z[e * zEWS] = (T)success; + } + } +#endif - samediff::Threads::parallel_for(func, 0, zLength, 1, _threads); + static inline void specialOp(Nd4jPointer state, const T *x, + const Nd4jLong *xShapeBuffer, const T *y, + const Nd4jLong *yShapeBuffer, T *z, + const Nd4jLong *zShapeBuffer, + T *extraArguments) { + int trials = (int)extraArguments[0]; + + Nd4jLong zLength = shape::length(zShapeBuffer); + + auto yEWS = shape::elementWiseStride(yShapeBuffer); + auto zEWS = shape::elementWiseStride(zShapeBuffer); + + int elementsPerThread = zLength / TAD_THRESHOLD; + int _threads = sd::math::nd4j_max(1, elementsPerThread); + _threads = sd::math::nd4j_min( + _threads, sd::Environment::getInstance().maxThreads()); + + T prob = extraArguments[1]; + + auto rng = reinterpret_cast(state); + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + int success = 0; + for (int t = 1; t <= trials; t++) { + T randVal = rng->relativeT((e + 1) * t); + if (y != z) { + // we're using external probs + prob = y[e * yEWS]; + } + + if (randVal < prob) success++; } - }; -////////////////////////////////////////////////////////////////////// - // This Op produces random Gaussian values within [mean-2*stddev,mean+2*stddev] - template - class TruncatedNormalDistribution { - private: - static inline _CUDA_HD T step(sd::graph::RandomGenerator* rng, T mean, T stddev, Nd4jLong e, Nd4jLong middle, T& z) { - auto epm = e + middle; - const T two_pi = static_cast(2.0f) * static_cast(3.14159265358979323846); - const T epsilon = static_cast(1.e-5f); - // we need to get random values - T r0 = rng->relativeT(e, epsilon, static_cast(1.0f)); - T r1 = rng->relativeT(epm, epsilon, static_cast(1.0f)); - - T realMean0 = mean; - - auto z0 = (sd::math::nd4j_sqrt(static_cast(-2.0f) * sd::math::nd4j_log(r0)) * sd::math::nd4j_cos(two_pi * r1)) * stddev + realMean0; - z = z0; - if (epm < middle) { - T realMean1 = mean; - auto z1 = (sd::math::nd4j_sqrt(static_cast(-2.0f) * sd::math::nd4j_log(r0)) * - sd::math::nd4j_sin(two_pi * r1)) * stddev + realMean1; - z = z1; - } - return z; - } - public: + // if trials is set to 0, effectively we just have successful memset + z[e * zEWS] = static_cast(success); + } + }; - method_XY - method_X - method_idx + samediff::Threads::parallel_for(func, 0, zLength, 1, _threads); + } +}; - static const bool requiresSpecial = true; +////////////////////////////////////////////////////////////////////// +// This Op produces random Gaussian values within [mean-2*stddev,mean+2*stddev] +template +class TruncatedNormalDistribution { + private: + static inline _CUDA_HD T step(sd::graph::RandomGenerator *rng, T mean, + T stddev, Nd4jLong e, Nd4jLong middle, T &z) { + auto epm = e + middle; + const T two_pi = + static_cast(2.0f) * static_cast(3.14159265358979323846); + const T epsilon = static_cast(1.e-5f); + // we need to get random values + T r0 = rng->relativeT(e, epsilon, static_cast(1.0f)); + T r1 = rng->relativeT(epm, epsilon, static_cast(1.0f)); + + T realMean0 = mean; + + auto z0 = (sd::math::nd4j_sqrt(static_cast(-2.0f) * + sd::math::nd4j_log(r0)) * + sd::math::nd4j_cos(two_pi * r1)) * + stddev + + realMean0; + z = z0; + if (epm < middle) { + T realMean1 = mean; + auto z1 = (sd::math::nd4j_sqrt(static_cast(-2.0f) * + sd::math::nd4j_log(r0)) * + sd::math::nd4j_sin(two_pi * r1)) * + stddev + + realMean1; + z = z1; + } + return z; + } + + public: + method_XY method_X method_idx + + static const bool requiresSpecial = true; #ifdef __CUDACC__ - __device__ static inline void specialOpCuda(Nd4jPointer state, T const* x, Nd4jLong const* xShapeBuffer, T const* y, Nd4jLong const* yShapeBuffer, T *z, Nd4jLong const* zShapeBuffer, T *extraArguments) { - __shared__ T epsilon; - __shared__ T two_pi; - - __shared__ Nd4jLong zLength; - __shared__ Nd4jLong zEWS; - __shared__ Nd4jLong yEWS; - __shared__ T mean; - __shared__ T stddev; - __shared__ int step; - - __shared__ T *tZ; - - __shared__ sd::graph::RandomGenerator* rng; - __shared__ unsigned char *cB; - __shared__ unsigned char *dB; - __shared__ sd::graph::RandomGenerator* devRng; - __shared__ Nd4jLong middle; - - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - rng = reinterpret_cast(shmem); - cB = shmem; - devRng = reinterpret_cast (state); - dB = reinterpret_cast (state); - - tZ = reinterpret_cast(shmem + sizeof(sd::graph::RandomGenerator)); - - zLength = shape::length(zShapeBuffer); - zEWS = shape::elementWiseStride(zShapeBuffer); - yEWS = shape::elementWiseStride(yShapeBuffer); - - epsilon = static_cast(1e-6f); - two_pi = static_cast(2.0f) * static_cast(3.14159265358979323846); - - mean = extraArguments[0]; - stddev = extraArguments[1]; - - step = (blockDim.x * gridDim.x); - middle = zLength / 2 + (zLength % 2); - } - __syncthreads(); - - // using this loop instead of memcpy - for (int e = threadIdx.x; e < sizeof(sd::graph::RandomGenerator); e+= blockDim.x) - cB[e] = dB[e]; - - __syncthreads(); - - int tid = blockIdx.x * blockDim.x + threadIdx.x; - - GaussianDistribution::specialOpCuda(state, x, xShapeBuffer, y, yShapeBuffer, z, zShapeBuffer, extraArguments); - __syncthreads(); - - T ds = sd::math::nd4j_abs(stddev) * static_cast(2.0f); - for (Nd4jLong e = tid; e < zLength; e += step) { - if (z[e] > mean + ds || z[e] < mean - ds) { - z[e] = TruncatedNormalDistribution::step(rng, mean, stddev, e, middle, z[e]); - - if (z[e] > mean + ds || z[e] < mean - ds) - z[e] = mean + sd::DataTypeUtils::min(); - } - } - } + __device__ static inline void specialOpCuda( + Nd4jPointer state, T const *x, Nd4jLong const *xShapeBuffer, T const *y, + Nd4jLong const *yShapeBuffer, T *z, Nd4jLong const *zShapeBuffer, + T *extraArguments) { + __shared__ T epsilon; + __shared__ T two_pi; + + __shared__ Nd4jLong zLength; + __shared__ Nd4jLong zEWS; + __shared__ Nd4jLong yEWS; + __shared__ T mean; + __shared__ T stddev; + __shared__ int step; + + __shared__ T *tZ; + + __shared__ sd::graph::RandomGenerator *rng; + __shared__ unsigned char *cB; + __shared__ unsigned char *dB; + __shared__ sd::graph::RandomGenerator *devRng; + __shared__ Nd4jLong middle; + + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + rng = reinterpret_cast(shmem); + cB = shmem; + devRng = reinterpret_cast(state); + dB = reinterpret_cast(state); + + tZ = reinterpret_cast(shmem + sizeof(sd::graph::RandomGenerator)); + + zLength = shape::length(zShapeBuffer); + zEWS = shape::elementWiseStride(zShapeBuffer); + yEWS = shape::elementWiseStride(yShapeBuffer); + + epsilon = static_cast(1e-6f); + two_pi = static_cast(2.0f) * static_cast(3.14159265358979323846); + + mean = extraArguments[0]; + stddev = extraArguments[1]; + + step = (blockDim.x * gridDim.x); + middle = zLength / 2 + (zLength % 2); + } + __syncthreads(); + + // using this loop instead of memcpy + for (int e = threadIdx.x; e < sizeof(sd::graph::RandomGenerator); + e += blockDim.x) + cB[e] = dB[e]; + + __syncthreads(); + + int tid = blockIdx.x * blockDim.x + threadIdx.x; + + GaussianDistribution::specialOpCuda(state, x, xShapeBuffer, y, + yShapeBuffer, z, zShapeBuffer, + extraArguments); + __syncthreads(); + + T ds = sd::math::nd4j_abs(stddev) * static_cast(2.0f); + for (Nd4jLong e = tid; e < zLength; e += step) { + if (z[e] > mean + ds || z[e] < mean - ds) { + z[e] = TruncatedNormalDistribution::step(rng, mean, stddev, e, + middle, z[e]); + + if (z[e] > mean + ds || z[e] < mean - ds) + z[e] = mean + sd::DataTypeUtils::min(); + } + } + } #endif - static inline void - specialOp(Nd4jPointer state, const T *x, const Nd4jLong *xShapeBuffer, const T *y, const Nd4jLong *yShapeBuffer, T *z, const Nd4jLong *zShapeBuffer, T *extraArguments) { - GaussianDistribution::specialOp(state, x, xShapeBuffer, y, yShapeBuffer, z, zShapeBuffer, extraArguments); - Nd4jLong zLength = shape::length(zShapeBuffer); - //auto yEWS = shape::elementWiseStride(yShapeBuffer); - //auto zEWS = shape::elementWiseStride(zShapeBuffer); - auto rng = reinterpret_cast(state); - T mean = extraArguments[0]; - T stddev = extraArguments[1]; - T ds = sd::math::nd4j_abs(stddev) * (T) 2.0f; - Nd4jLong middle = zLength / 2 + (zLength % 2); - int elementsPerThread = middle / TAD_THRESHOLD; - int _threads = sd::math::nd4j_max(1, elementsPerThread); - _threads = sd::math::nd4j_min(_threads, sd::Environment::getInstance().maxThreads()); - - const T epsilon = static_cast(1e-5); - - auto func = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) { - if (z[e] > mean + ds || z[e] < mean - ds) { - z[e] = step(rng, mean, stddev, e, middle, z[e]); - - if (z[e] > mean + ds || z[e] < mean - ds) - z[e] = mean + sd::DataTypeUtils::min(); - } - } - }; - - samediff::Threads::parallel_for(func, 0, zLength, 1, _threads); + static inline void specialOp(Nd4jPointer state, const T *x, + const Nd4jLong *xShapeBuffer, const T *y, + const Nd4jLong *yShapeBuffer, T *z, + const Nd4jLong *zShapeBuffer, + T *extraArguments) { + GaussianDistribution::specialOp(state, x, xShapeBuffer, y, yShapeBuffer, + z, zShapeBuffer, extraArguments); + Nd4jLong zLength = shape::length(zShapeBuffer); + // auto yEWS = shape::elementWiseStride(yShapeBuffer); + // auto zEWS = shape::elementWiseStride(zShapeBuffer); + auto rng = reinterpret_cast(state); + T mean = extraArguments[0]; + T stddev = extraArguments[1]; + T ds = sd::math::nd4j_abs(stddev) * (T)2.0f; + Nd4jLong middle = zLength / 2 + (zLength % 2); + int elementsPerThread = middle / TAD_THRESHOLD; + int _threads = sd::math::nd4j_max(1, elementsPerThread); + _threads = sd::math::nd4j_min( + _threads, sd::Environment::getInstance().maxThreads()); + + const T epsilon = static_cast(1e-5); + + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + if (z[e] > mean + ds || z[e] < mean - ds) { + z[e] = step(rng, mean, stddev, e, middle, z[e]); + + if (z[e] > mean + ds || z[e] < mean - ds) + z[e] = mean + sd::DataTypeUtils::min(); } + } }; + samediff::Threads::parallel_for(func, 0, zLength, 1, _threads); + } +}; + ////////////////////////////////////////////////////////////////////// // This Op produces random Log-normal distribution - template - class LogNormalDistribution { - public: - - method_XY - method_X - method_idx - - static const bool requiresSpecial = true; +template +class LogNormalDistribution { + public: + method_XY method_X method_idx + static const bool requiresSpecial = true; #ifdef __CUDACC__ - __device__ static inline void specialOpCuda(Nd4jPointer state, T const* x, Nd4jLong const* xShapeBuffer, T const* y, Nd4jLong const* yShapeBuffer, T *z, Nd4jLong const* zShapeBuffer, T *extraArguments) { - __shared__ T epsilon; - __shared__ T two_pi; - - __shared__ Nd4jLong zLength; - __shared__ Nd4jLong zEWS; - __shared__ Nd4jLong yEWS; - __shared__ T mean; - __shared__ T stddev; - __shared__ int step; - - __shared__ T *tZ; - - __shared__ sd::graph::RandomGenerator* rng; - __shared__ unsigned char *cB; - __shared__ unsigned char *dB; - __shared__ sd::graph::RandomGenerator* devRng; - - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - rng = reinterpret_cast(state); - cB = shmem; - devRng = reinterpret_cast(state); - - dB = reinterpret_cast (state); - - tZ = reinterpret_cast(shmem + sizeof(sd::graph::RandomGenerator)); - - zLength = shape::length(zShapeBuffer); - zEWS = shape::elementWiseStride(zShapeBuffer); - yEWS = shape::elementWiseStride(yShapeBuffer); - - - epsilon = static_cast(1e-5); - two_pi = static_cast(2.0f) * static_cast(3.14159265358979323846); - - mean = extraArguments[0]; - stddev = extraArguments[1]; - - step = (blockDim.x * gridDim.x); - } - __syncthreads(); - - // using this loop instead of memcpy - for (int e = threadIdx.x; e < sizeof(sd::graph::RandomGenerator); e+= blockDim.x) - cB[e] = dB[e]; - - __syncthreads(); - - int tid = blockIdx.x * blockDim.x + threadIdx.x; - - int middle = zLength % 2 == 0 ? zLength / 2 : zLength / 2 + 1; - - for (Nd4jLong e = tid; e < middle; e += step) { - auto epm = e + middle; - - // we need to get random values - T r0 = rng->relativeT(e, epsilon, static_cast(1.0f)); - T r1 = rng->relativeT(epm, epsilon, static_cast(1.0f)); - - T realMean = y == z ? mean : y[e * yEWS]; - - z[e *zEWS] = sd::math::nd4j_exp((sd::math::nd4j_sqrt(static_cast(-2.0f) * sd::math::nd4j_log(r0)) * sd::math::nd4j_cos(two_pi * r1)) * stddev + realMean); - - if (epm < zLength) { - realMean = y == z ? mean : y[epm * yEWS]; - z[epm *zEWS] = sd::math::nd4j_exp((sd::math::nd4j_sqrt(static_cast(-2.0f) * sd::math::nd4j_log(r0)) * sd::math::nd4j_sin(two_pi * r1)) * stddev + realMean); - } - } - } + __device__ static inline void specialOpCuda( + Nd4jPointer state, T const *x, Nd4jLong const *xShapeBuffer, T const *y, + Nd4jLong const *yShapeBuffer, T *z, Nd4jLong const *zShapeBuffer, + T *extraArguments) { + __shared__ T epsilon; + __shared__ T two_pi; + + __shared__ Nd4jLong zLength; + __shared__ Nd4jLong zEWS; + __shared__ Nd4jLong yEWS; + __shared__ T mean; + __shared__ T stddev; + __shared__ int step; + + __shared__ T *tZ; + + __shared__ sd::graph::RandomGenerator *rng; + __shared__ unsigned char *cB; + __shared__ unsigned char *dB; + __shared__ sd::graph::RandomGenerator *devRng; + + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + rng = reinterpret_cast(state); + cB = shmem; + devRng = reinterpret_cast(state); + + dB = reinterpret_cast(state); + + tZ = reinterpret_cast(shmem + sizeof(sd::graph::RandomGenerator)); + + zLength = shape::length(zShapeBuffer); + zEWS = shape::elementWiseStride(zShapeBuffer); + yEWS = shape::elementWiseStride(yShapeBuffer); + + epsilon = static_cast(1e-5); + two_pi = static_cast(2.0f) * static_cast(3.14159265358979323846); + + mean = extraArguments[0]; + stddev = extraArguments[1]; + + step = (blockDim.x * gridDim.x); + } + __syncthreads(); + + // using this loop instead of memcpy + for (int e = threadIdx.x; e < sizeof(sd::graph::RandomGenerator); + e += blockDim.x) + cB[e] = dB[e]; + + __syncthreads(); + + int tid = blockIdx.x * blockDim.x + threadIdx.x; + + int middle = zLength % 2 == 0 ? zLength / 2 : zLength / 2 + 1; + + for (Nd4jLong e = tid; e < middle; e += step) { + auto epm = e + middle; + + // we need to get random values + T r0 = rng->relativeT(e, epsilon, static_cast(1.0f)); + T r1 = rng->relativeT(epm, epsilon, static_cast(1.0f)); + + T realMean = y == z ? mean : y[e * yEWS]; + + z[e * zEWS] = sd::math::nd4j_exp( + (sd::math::nd4j_sqrt(static_cast(-2.0f) * + sd::math::nd4j_log(r0)) * + sd::math::nd4j_cos(two_pi * r1)) * + stddev + + realMean); + + if (epm < zLength) { + realMean = y == z ? mean : y[epm * yEWS]; + z[epm * zEWS] = sd::math::nd4j_exp( + (sd::math::nd4j_sqrt(static_cast(-2.0f) * + sd::math::nd4j_log(r0)) * + sd::math::nd4j_sin(two_pi * r1)) * + stddev + + realMean); + } + } + } #endif - static inline void - specialOp(Nd4jPointer state, const T *x, const Nd4jLong *xShapeBuffer, const T *y, const Nd4jLong *yShapeBuffer, T *z, const Nd4jLong *zShapeBuffer, T *extraArguments) { - const T two_pi = static_cast(2.0f) * static_cast(3.14159265358979323846); - - Nd4jLong zLength = shape::length(zShapeBuffer); - auto yEWS = shape::elementWiseStride(yShapeBuffer); - auto zEWS = shape::elementWiseStride(zShapeBuffer); - - auto middle = zLength % 2 == 0 ? zLength / 2 : zLength / 2 + 1; - - int elementsPerThread = middle / TAD_THRESHOLD; - int _threads = sd::math::nd4j_max(1, elementsPerThread); - _threads = sd::math::nd4j_min(_threads, sd::Environment::getInstance().maxThreads()); - - int span = (zLength / _threads) + 8; - - // we're enforcing even chunks, since it's mandatory for this algorithm - span -= span % 2; - - auto rng = reinterpret_cast(state); - - const T mean = extraArguments[0]; - const T stddev = extraArguments[1]; - const T epsilon = static_cast(1e-5); - - auto func = PRAGMA_THREADS_FOR { - PRAGMA_OMP_SIMD - for (auto e = start; e < stop; e++) { - auto epm = e + middle; - - // we need to get random values - T r0 = rng->relativeT(e, epsilon, static_cast(1.0f)); - T r1 = rng->relativeT(epm, epsilon, static_cast(1.0f)); - - T realMean = y == z ? mean : y[e * yEWS]; - - z[e * zEWS] = sd::math::nd4j_exp((sd::math::nd4j_sqrt(static_cast(-2.0f) * sd::math::nd4j_log(r0)) * sd::math::nd4j_cos(two_pi * r1)) * stddev + realMean); - - if (epm < zLength) { - realMean = y == z ? mean : y[epm * yEWS]; - z[epm * zEWS] = sd::math::nd4j_exp((sd::math::nd4j_sqrt(static_cast(-2.0f) * sd::math::nd4j_log(r0)) * sd::math::nd4j_sin(two_pi * r1)) * stddev + realMean); - } - } - }; - - samediff::Threads::parallel_for(func, 0, middle, 1, _threads); + static inline void specialOp(Nd4jPointer state, const T *x, + const Nd4jLong *xShapeBuffer, const T *y, + const Nd4jLong *yShapeBuffer, T *z, + const Nd4jLong *zShapeBuffer, + T *extraArguments) { + const T two_pi = + static_cast(2.0f) * static_cast(3.14159265358979323846); + + Nd4jLong zLength = shape::length(zShapeBuffer); + auto yEWS = shape::elementWiseStride(yShapeBuffer); + auto zEWS = shape::elementWiseStride(zShapeBuffer); + + auto middle = zLength % 2 == 0 ? zLength / 2 : zLength / 2 + 1; + + int elementsPerThread = middle / TAD_THRESHOLD; + int _threads = sd::math::nd4j_max(1, elementsPerThread); + _threads = sd::math::nd4j_min( + _threads, sd::Environment::getInstance().maxThreads()); + + int span = (zLength / _threads) + 8; + + // we're enforcing even chunks, since it's mandatory for this algorithm + span -= span % 2; + + auto rng = reinterpret_cast(state); + + const T mean = extraArguments[0]; + const T stddev = extraArguments[1]; + const T epsilon = static_cast(1e-5); + + auto func = PRAGMA_THREADS_FOR { + PRAGMA_OMP_SIMD + for (auto e = start; e < stop; e++) { + auto epm = e + middle; + + // we need to get random values + T r0 = rng->relativeT(e, epsilon, static_cast(1.0f)); + T r1 = rng->relativeT(epm, epsilon, static_cast(1.0f)); + + T realMean = y == z ? mean : y[e * yEWS]; + + z[e * zEWS] = sd::math::nd4j_exp( + (sd::math::nd4j_sqrt(static_cast(-2.0f) * + sd::math::nd4j_log(r0)) * + sd::math::nd4j_cos(two_pi * r1)) * + stddev + + realMean); + + if (epm < zLength) { + realMean = y == z ? mean : y[epm * yEWS]; + z[epm * zEWS] = sd::math::nd4j_exp( + (sd::math::nd4j_sqrt(static_cast(-2.0f) * + sd::math::nd4j_log(r0)) * + sd::math::nd4j_sin(two_pi * r1)) * + stddev + + realMean); } + } }; + samediff::Threads::parallel_for(func, 0, middle, 1, _threads); + } +}; -} +} // namespace randomOps -#endif //LIBND4J_SPECIAL_RANDOM_OPS_H +#endif // LIBND4J_SPECIAL_RANDOM_OPS_H diff --git a/libnd4j/include/ops/specials.h b/libnd4j/include/ops/specials.h index ed5f8fb8c68a..a4e8fa2d3613 100644 --- a/libnd4j/include/ops/specials.h +++ b/libnd4j/include/ops/specials.h @@ -21,67 +21,84 @@ #ifndef LIBND4J_SPECIALS_H #define LIBND4J_SPECIALS_H - #ifdef __CUDACC__ #define ELEMENT_THRESHOLD 8192 #define TAD_THRESHOLD 2 #endif #include + #include namespace sd { - class NDArray; - - //FIXME: get rid of this redefinition - typedef union - { - float f_; - int i_; - } FloatBits2; - - - class ND4J_EXPORT SpecialTypeConverter { - public: - template - static void convertGeneric(Nd4jPointer * extras, void *dx, Nd4jLong N, void *dz); - }; - - template - class ND4J_EXPORT SpecialMethods { - public: - static void concatCpuGeneric(const std::vector& inArrs, NDArray& output, int axis); - static void concatCpuGeneric(int dimension, int numArrays, Nd4jPointer *data, Nd4jPointer *inputShapeInfo, void *result, Nd4jLong const* resultShapeInfo); - static void splitCpuGeneric(const NDArray& input, const std::vector& outArrs, int axis); - static void accumulateGeneric(void **x, void *z, const Nd4jLong *zShapeInfo, int n, Nd4jLong length); - static void averageGeneric(void **x, void *z, const Nd4jLong *zShapeInfo, int n, Nd4jLong length, bool propagate); - - static Nd4jLong getPosition(const Nd4jLong *xShapeInfo, Nd4jLong index); - static void quickSort_parallel_internal(T* array, const Nd4jLong *xShapeInfo, int left, int right, int cutoff, bool descending); - static void quickSort_parallel(void* array, const Nd4jLong *xShapeInfo, Nd4jLong lenArray, int numThreads, bool descending); - - static int nextPowerOf2(int number); - static int lastPowerOf2(int number); - - static void sortGeneric(void *x, const Nd4jLong *xShapeInfo, bool descending); - static void sortTadGeneric(void *x, const Nd4jLong *xShapeInfo, int *dimension, int dimensionLength, const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets, bool descending); - - static void decodeBitmapGeneric(const void *dx, Nd4jLong N, void *dz, const Nd4jLong *zShapeInfo); - static Nd4jLong encodeBitmapGeneric(void *dx, const Nd4jLong *zShapeInfo, Nd4jLong N, int *dz, float threshold); - - }; - - template - class ND4J_EXPORT DoubleMethods{ - public: - static void sortByKey(void *vx, Nd4jLong const* xShapeInfo, void *vy, Nd4jLong const* yShapeInfo, bool descending); - static void sortByValue(void *vx, Nd4jLong const* xShapeInfo, void *vy, Nd4jLong const* yShapeInfo, bool descending); - - - static void sortTadByKey(void *vx, Nd4jLong const* xShapeInfo, void *vy, Nd4jLong const* yShapeInfo, int *dimension, int dimensionLength, bool descending); - static void sortTadByValue(void *vx, Nd4jLong const* xShapeInfo, void *vy, Nd4jLong const* yShapeInfo, int *dimension, int dimensionLength, bool descending); - }; -} - - -#endif //LIBND4J_SPECIALS_H +class NDArray; + +// FIXME: get rid of this redefinition +typedef union { + float f_; + int i_; +} FloatBits2; + +class SD_EXPORT SpecialTypeConverter { + public: + template + static void convertGeneric(Nd4jPointer *extras, void *dx, Nd4jLong N, + void *dz); +}; + +template +class SD_EXPORT SpecialMethods { + public: + static void concatCpuGeneric(const std::vector &inArrs, + NDArray &output, int axis); + static void concatCpuGeneric(int dimension, int numArrays, Nd4jPointer *data, + Nd4jPointer *inputShapeInfo, void *result, + Nd4jLong const *resultShapeInfo); + static void splitCpuGeneric(const NDArray &input, + const std::vector &outArrs, int axis); + static void accumulateGeneric(void **x, void *z, const Nd4jLong *zShapeInfo, + int n, Nd4jLong length); + static void averageGeneric(void **x, void *z, const Nd4jLong *zShapeInfo, + int n, Nd4jLong length, bool propagate); + + static Nd4jLong getPosition(const Nd4jLong *xShapeInfo, Nd4jLong index); + static void quickSort_parallel_internal(T *array, const Nd4jLong *xShapeInfo, + int left, int right, int cutoff, + bool descending); + static void quickSort_parallel(void *array, const Nd4jLong *xShapeInfo, + Nd4jLong lenArray, int numThreads, + bool descending); + + static int nextPowerOf2(int number); + static int lastPowerOf2(int number); + + static void sortGeneric(void *x, const Nd4jLong *xShapeInfo, bool descending); + static void sortTadGeneric(void *x, const Nd4jLong *xShapeInfo, + int *dimension, int dimensionLength, + const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets, bool descending); + + static void decodeBitmapGeneric(const void *dx, Nd4jLong N, void *dz, + const Nd4jLong *zShapeInfo); + static Nd4jLong encodeBitmapGeneric(void *dx, const Nd4jLong *zShapeInfo, + Nd4jLong N, int *dz, float threshold); +}; + +template +class SD_EXPORT DoubleMethods { + public: + static void sortByKey(void *vx, Nd4jLong const *xShapeInfo, void *vy, + Nd4jLong const *yShapeInfo, bool descending); + static void sortByValue(void *vx, Nd4jLong const *xShapeInfo, void *vy, + Nd4jLong const *yShapeInfo, bool descending); + + static void sortTadByKey(void *vx, Nd4jLong const *xShapeInfo, void *vy, + Nd4jLong const *yShapeInfo, int *dimension, + int dimensionLength, bool descending); + static void sortTadByValue(void *vx, Nd4jLong const *xShapeInfo, void *vy, + Nd4jLong const *yShapeInfo, int *dimension, + int dimensionLength, bool descending); +}; +} // namespace sd + +#endif // LIBND4J_SPECIALS_H diff --git a/libnd4j/include/ops/specials_cuda.h b/libnd4j/include/ops/specials_cuda.h index a12fd302f648..17cf26e30af7 100644 --- a/libnd4j/include/ops/specials_cuda.h +++ b/libnd4j/include/ops/specials_cuda.h @@ -21,85 +21,115 @@ #ifndef PROJECT_SPECIALS_CUDA_H #define PROJECT_SPECIALS_CUDA_H -#include #include +#include #ifdef __CUDACC__ //////////////////////////////////////////////////////////////////////// -template -__host__ void bitonicSortStepGeneric(dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong const* xShapeInfo, int j, int k, int length, bool descending); +template +__host__ void bitonicSortStepGeneric(dim3 &launchDims, cudaStream_t *stream, + void *vx, Nd4jLong const *xShapeInfo, + int j, int k, int length, bool descending); //////////////////////////////////////////////////////////////////////// -template -__host__ void bitonicArbitraryStepGeneric(dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong const* xShapeInfo, int window, int length, int reverse, bool descending); +template +__host__ void bitonicArbitraryStepGeneric(dim3 &launchDims, + cudaStream_t *stream, void *vx, + Nd4jLong const *xShapeInfo, + int window, int length, int reverse, + bool descending); //////////////////////////////////////////////////////////////////////// template -__host__ void bitonicSortStepGenericKey(dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong const* xShapeInfo, void *vy, Nd4jLong const* yShapeInfo, int j, int k, int length, bool descending); +__host__ void bitonicSortStepGenericKey(dim3 &launchDims, cudaStream_t *stream, + void *vx, Nd4jLong const *xShapeInfo, + void *vy, Nd4jLong const *yShapeInfo, + int j, int k, int length, + bool descending); //////////////////////////////////////////////////////////////////////// template -__host__ void bitonicArbitraryStepGenericKey(dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong const* xShapeInfo, void *vy, Nd4jLong const* yShapeInfo, int window, int length, int reverse, bool descending); +__host__ void bitonicArbitraryStepGenericKey( + dim3 &launchDims, cudaStream_t *stream, void *vx, + Nd4jLong const *xShapeInfo, void *vy, Nd4jLong const *yShapeInfo, + int window, int length, int reverse, bool descending); //////////////////////////////////////////////////////////////////////// template -__host__ void bitonicSortStepGenericValue(dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong const* xShapeInfo, void *vy, Nd4jLong const* yShapeInfo, int j, int k, int length, bool descending); +__host__ void bitonicSortStepGenericValue(dim3 &launchDims, + cudaStream_t *stream, void *vx, + Nd4jLong const *xShapeInfo, void *vy, + Nd4jLong const *yShapeInfo, int j, + int k, int length, bool descending); //////////////////////////////////////////////////////////////////////// template -__host__ void bitonicArbitraryStepGenericValue(dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong const* xShapeInfo, void *vy, Nd4jLong const* yShapeInfo, int window, int length, int reverse, bool descending); - - +__host__ void bitonicArbitraryStepGenericValue( + dim3 &launchDims, cudaStream_t *stream, void *vx, + Nd4jLong const *xShapeInfo, void *vy, Nd4jLong const *yShapeInfo, + int window, int length, int reverse, bool descending); //////////////////////////////////////////////////////////////////////// -template -__host__ void oesTadGeneric(dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong const* xShapeInfo, int *dimension, int dimensionLength, Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, bool descending); +template +__host__ void oesTadGeneric(dim3 &launchDims, cudaStream_t *stream, void *vx, + Nd4jLong const *xShapeInfo, int *dimension, + int dimensionLength, Nd4jLong const *tadShapeInfo, + Nd4jLong const *tadOffsets, bool descending); template -__host__ void oesTadGenericKey(dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong const* xShapeInfo, void *vy, Nd4jLong const* yShapeInfo, int *dimension, int dimensionLength, Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, bool descending); +__host__ void oesTadGenericKey(dim3 &launchDims, cudaStream_t *stream, void *vx, + Nd4jLong const *xShapeInfo, void *vy, + Nd4jLong const *yShapeInfo, int *dimension, + int dimensionLength, + Nd4jLong const *tadShapeInfo, + Nd4jLong const *tadOffsets, bool descending); template -__host__ void oesTadGenericValue(dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong const* xShapeInfo, void *vy, Nd4jLong const* yShapeInfo, int *dimension, int dimensionLength, Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, bool descending); +__host__ void oesTadGenericValue(dim3 &launchDims, cudaStream_t *stream, + void *vx, Nd4jLong const *xShapeInfo, void *vy, + Nd4jLong const *yShapeInfo, int *dimension, + int dimensionLength, + Nd4jLong const *tadShapeInfo, + Nd4jLong const *tadOffsets, bool descending); //////////////////////////////////////////////////////////////////////// -template -__global__ void printCudaGlobal(void* pointer, const int len) { - - for(int i = 0; i < len; ++i) - printf("%f, ", (double)reinterpret_cast(pointer)[i] ); - printf("\n"); +template +__global__ void printCudaGlobal(void *pointer, const int len) { + for (int i = 0; i < len; ++i) + printf("%f, ", (double)reinterpret_cast(pointer)[i]); + printf("\n"); } //////////////////////////////////////////////////////////////////////// -template -__device__ void printCudaDevice(void* pointer, const int len, const int tid = 0) { - - if(blockIdx.x * blockDim.x + threadIdx.x != tid) return; - for(int i = 0; i < len; ++i) - printf("%f, ", (double)reinterpret_cast(pointer)[i] ); - printf("\n"); +template +__device__ void printCudaDevice(void *pointer, const int len, + const int tid = 0) { + if (blockIdx.x * blockDim.x + threadIdx.x != tid) return; + for (int i = 0; i < len; ++i) + printf("%f, ", (double)reinterpret_cast(pointer)[i]); + printf("\n"); } //////////////////////////////////////////////////////////////////////// -template -__host__ void printCudaHost(void* pointer, const int len, cudaStream_t& stream) { - - void* ptr = malloc(sizeof(T)*len); - - cudaMemcpyAsync(ptr, pointer, sizeof(T)*len, cudaMemcpyDeviceToHost, stream); - cudaError_t cudaResult = cudaStreamSynchronize(stream); - if(cudaResult != 0) - throw std::runtime_error("printCudaHost:: cudaStreamSynchronize failed!"); - - for(int i = 0; i < len; ++i) - printf("%f, ", (double)reinterpret_cast(ptr)[i]); - printf("\n"); - - free(ptr); +template +__host__ void printCudaHost(void *pointer, const int len, + cudaStream_t &stream) { + void *ptr = malloc(sizeof(T) * len); + + cudaMemcpyAsync(ptr, pointer, sizeof(T) * len, cudaMemcpyDeviceToHost, + stream); + cudaError_t cudaResult = cudaStreamSynchronize(stream); + if (cudaResult != 0) + throw std::runtime_error("printCudaHost:: cudaStreamSynchronize failed!"); + + for (int i = 0; i < len; ++i) + printf("%f, ", (double)reinterpret_cast(ptr)[i]); + printf("\n"); + + free(ptr); } - #endif -#endif //PROJECT_SPECIALS_CUDA_H +#endif // PROJECT_SPECIALS_CUDA_H diff --git a/libnd4j/include/ops/specials_sparse.h b/libnd4j/include/ops/specials_sparse.h index cd0e2f6b529b..32715509d7ba 100644 --- a/libnd4j/include/ops/specials_sparse.h +++ b/libnd4j/include/ops/specials_sparse.h @@ -26,44 +26,50 @@ #include namespace sd { - namespace sparse { +namespace sparse { - template - class SparseUtils { - public: - /** - * Just simple helper for debugging :) - * - * @param indices - * @param rank - * @param x - */ - static void printIndex(Nd4jLong *indices, int rank, int x); - static bool ltIndices(Nd4jLong *indices, int rank, Nd4jLong x, Nd4jLong y); +template +class SparseUtils { + public: + /** + * Just simple helper for debugging :) + * + * @param indices + * @param rank + * @param x + */ + static void printIndex(Nd4jLong *indices, int rank, int x); + static bool ltIndices(Nd4jLong *indices, int rank, Nd4jLong x, Nd4jLong y); - /** - * Returns true, if x > y, false otherwise - * @param indices - * @param rank - * @param x - * @param y - * @return - */ - static bool gtIndices(Nd4jLong *indices, int rank, Nd4jLong x, Nd4jLong y); + /** + * Returns true, if x > y, false otherwise + * @param indices + * @param rank + * @param x + * @param y + * @return + */ + static bool gtIndices(Nd4jLong *indices, int rank, Nd4jLong x, Nd4jLong y); - static void swapEverything(Nd4jLong *indices, T *array, int rank, Nd4jLong x, Nd4jLong y); + static void swapEverything(Nd4jLong *indices, T *array, int rank, Nd4jLong x, + Nd4jLong y); - static void coo_quickSort_parallel_internal(Nd4jLong *indices, T* array, Nd4jLong left, Nd4jLong right, int cutoff, int rank); + static void coo_quickSort_parallel_internal(Nd4jLong *indices, T *array, + Nd4jLong left, Nd4jLong right, + int cutoff, int rank); - static void coo_quickSort_parallel(Nd4jLong *indices, T* array, Nd4jLong lenArray, int numThreads, int rank); + static void coo_quickSort_parallel(Nd4jLong *indices, T *array, + Nd4jLong lenArray, int numThreads, + int rank); - static Nd4jLong coo_quickSort_findPivot(Nd4jLong *indices, T *array, Nd4jLong left, Nd4jLong right, - int rank); + static Nd4jLong coo_quickSort_findPivot(Nd4jLong *indices, T *array, + Nd4jLong left, Nd4jLong right, + int rank); - static void sortCooIndicesGeneric(Nd4jLong *indices, T *values, Nd4jLong length, int rank); - }; - } -} + static void sortCooIndicesGeneric(Nd4jLong *indices, T *values, + Nd4jLong length, int rank); +}; +} // namespace sparse +} // namespace sd - -#endif //LIBND4J_SPECIALS_SPARSE_H +#endif // LIBND4J_SPECIALS_SPARSE_H diff --git a/libnd4j/include/performance/benchmarking/BenchmarkSuit.h b/libnd4j/include/performance/benchmarking/BenchmarkSuit.h index 7805e570e47a..1d5de9c518d9 100644 --- a/libnd4j/include/performance/benchmarking/BenchmarkSuit.h +++ b/libnd4j/include/performance/benchmarking/BenchmarkSuit.h @@ -21,21 +21,21 @@ #ifndef LIBND4J_BENCHMARKSUIT_H #define LIBND4J_BENCHMARKSUIT_H -#include -#include -#include -#include #include +#include +#include +#include -namespace sd { - class ND4J_EXPORT BenchmarkSuit { - public: - BenchmarkSuit() = default; - ~BenchmarkSuit() = default; +#include - virtual std::string runSuit() = 0; - }; -} +namespace sd { +class SD_EXPORT BenchmarkSuit { + public: + BenchmarkSuit() = default; + ~BenchmarkSuit() = default; + virtual std::string runSuit() = 0; +}; +} // namespace sd -#endif //DEV_TESTS_BENCHMARKSUIT_H +#endif // SD_BENCHMARKSUIT_H diff --git a/libnd4j/include/performance/benchmarking/FullBenchmarkSuit.h b/libnd4j/include/performance/benchmarking/FullBenchmarkSuit.h index 6b2314b96138..14fe0cf78695 100644 --- a/libnd4j/include/performance/benchmarking/FullBenchmarkSuit.h +++ b/libnd4j/include/performance/benchmarking/FullBenchmarkSuit.h @@ -24,11 +24,10 @@ #include namespace sd { - class FullBenchmarkSuit : public BenchmarkSuit { - public: - std::string runSuit() override; - }; -} +class FullBenchmarkSuit : public BenchmarkSuit { + public: + std::string runSuit() override; +}; +} // namespace sd - -#endif //DEV_TESTS_FULLBENCHMARKSUIT_H +#endif // SD_FULLBENCHMARKSUIT_H diff --git a/libnd4j/include/performance/benchmarking/LightBenchmarkSuit.h b/libnd4j/include/performance/benchmarking/LightBenchmarkSuit.h index 65a74b1fe7e0..846d29759941 100644 --- a/libnd4j/include/performance/benchmarking/LightBenchmarkSuit.h +++ b/libnd4j/include/performance/benchmarking/LightBenchmarkSuit.h @@ -24,11 +24,10 @@ #include namespace sd { - class LightBenchmarkSuit : public BenchmarkSuit { - public: - std::string runSuit() override; - }; -} +class LightBenchmarkSuit : public BenchmarkSuit { + public: + std::string runSuit() override; +}; +} // namespace sd - -#endif //DEV_TESTS_LIGHTBENCHMARKSUIT_H +#endif // SD_LIGHTBENCHMARKSUIT_H diff --git a/libnd4j/include/performance/benchmarking/impl/FullBenchmarkSuit.cpp b/libnd4j/include/performance/benchmarking/impl/FullBenchmarkSuit.cpp index 0eabe959ad11..3b0cba7a5251 100644 --- a/libnd4j/include/performance/benchmarking/impl/FullBenchmarkSuit.cpp +++ b/libnd4j/include/performance/benchmarking/impl/FullBenchmarkSuit.cpp @@ -19,1904 +19,2044 @@ // #include -#include #include +#include + #include #ifdef RELEASE_BUILD - int wIterations = 4; - int rIterations = 20; - int gemmRegularUpperPow = 11; - int scalarBenchmarkPowLimit = 26; - int transformBenchmarkPowLimit = 26; - int intermediateTransformPowLimit = 22; - int intermediateTransformPowLimit2 = 18; - int pairwisePowLimit = 26; - int heavyPowLimit = 22; - int nonEwsPowLimit = 10; - int reduceScalarPowLimit = 26; - int stridedReductionPowLimit = 20; - int mismatchedAssignPowLimit = 26; - int gatherOpPowLimit = 18; - int gatherOpPowLimit2 = 16; - int gatherOpPowLimit3 = 12; - int broadcastMatrixRankLimit = 5; - int limit30 = 30; - int limit26 = 26; - int limit24 = 24; - int limit22 = 22; - int limit20 = 20; - int limit18 = 18; - int limit10 = 10; - int limit5 = 5; - int limit3 = 3; +int wIterations = 4; +int rIterations = 20; +int gemmRegularUpperPow = 11; +int scalarBenchmarkPowLimit = 26; +int transformBenchmarkPowLimit = 26; +int intermediateTransformPowLimit = 22; +int intermediateTransformPowLimit2 = 18; +int pairwisePowLimit = 26; +int heavyPowLimit = 22; +int nonEwsPowLimit = 10; +int reduceScalarPowLimit = 26; +int stridedReductionPowLimit = 20; +int mismatchedAssignPowLimit = 26; +int gatherOpPowLimit = 18; +int gatherOpPowLimit2 = 16; +int gatherOpPowLimit3 = 12; +int broadcastMatrixRankLimit = 5; +int limit30 = 30; +int limit26 = 26; +int limit24 = 24; +int limit22 = 22; +int limit20 = 20; +int limit18 = 18; +int limit10 = 10; +int limit5 = 5; +int limit3 = 3; #else - int wIterations = 0; - int rIterations = 1; - int gemmRegularUpperPow = 7; - int scalarBenchmarkPowLimit = 10; - int transformBenchmarkPowLimit = 10; - int intermediateTransformPowLimit = 10; - int intermediateTransformPowLimit2 = 10; - int pairwisePowLimit = 10; - int heavyPowLimit = 10; - int nonEwsPowLimit = 6; - int reduceScalarPowLimit = 10; - int stridedReductionPowLimit = 12; - int mismatchedAssignPowLimit = 2; - int gatherOpPowLimit = 10; - int gatherOpPowLimit2 = 8; - int gatherOpPowLimit3 = 8; - int broadcastMatrixRankLimit = 3; - int limit26 = 8; - int limit24 = 8; - int limit22 = 8; - int limit20 = 8; - int limit18 = 8; - int limit10 = 4; - int limit5 = 3; - int limit3 = 1; +int wIterations = 0; +int rIterations = 1; +int gemmRegularUpperPow = 7; +int scalarBenchmarkPowLimit = 10; +int transformBenchmarkPowLimit = 10; +int intermediateTransformPowLimit = 10; +int intermediateTransformPowLimit2 = 10; +int pairwisePowLimit = 10; +int heavyPowLimit = 10; +int nonEwsPowLimit = 6; +int reduceScalarPowLimit = 10; +int stridedReductionPowLimit = 12; +int mismatchedAssignPowLimit = 2; +int gatherOpPowLimit = 10; +int gatherOpPowLimit2 = 8; +int gatherOpPowLimit3 = 8; +int broadcastMatrixRankLimit = 3; +int limit26 = 8; +int limit24 = 8; +int limit22 = 8; +int limit20 = 8; +int limit18 = 8; +int limit10 = 4; +int limit5 = 3; +int limit3 = 1; #endif namespace sd { - static std::string layerNormBenchmark() { - std::string output; - BenchmarkHelper helper(wIterations, rIterations); +static std::string layerNormBenchmark() { + std::string output; + BenchmarkHelper helper(wIterations, rIterations); - BoolParameters nhwc("nhwc"); //0 = nchw + BoolParameters nhwc("nhwc"); // 0 = nchw #ifdef _RELEASE - int c = 32; - int hw = 64; + int c = 32; + int hw = 64; #else - int c = 3; - int hw = 8; + int c = 3; + int hw = 8; #endif - ParametersBatch batch({&nhwc}); - - auto generator = PARAMETRIC_D() { - auto ctx = new Context(1); - int n = p.getIntParam("nhwc"); - - int axis; - if (n == 0) { - //nchw - auto input = NDArrayFactory::create_('c', {16, c, hw, hw}); - auto output = NDArrayFactory::create_('c', {16, c, hw, hw}); - ctx->setInputArray(0, input, true); - ctx->setOutputArray(0, output, true); - axis = 1; - } else { - auto input = NDArrayFactory::create_('c', {32, hw, hw, c}); - auto output = NDArrayFactory::create_('c', {32, hw, hw, c}); - ctx->setInputArray(0, input, true); - ctx->setOutputArray(0, output, true); - axis = 3; - } - - auto bias = NDArrayFactory::create_('c', {c}); - ctx->setInputArray(1, bias, true); - auto iargs = new Nd4jLong[1]; - iargs[0] = axis; - ctx->setIArguments(iargs, 1); - delete[] iargs; - - return ctx; - }; + ParametersBatch batch({&nhwc}); + + auto generator = PARAMETRIC_D() { + auto ctx = new Context(1); + int n = p.getIntParam("nhwc"); + + int axis; + if (n == 0) { + // nchw + auto input = NDArrayFactory::create('c', {16, c, hw, hw}); + auto output = NDArrayFactory::create('c', {16, c, hw, hw}); + ctx->setInputArray(0, input); + ctx->setOutputArray(0, output); + axis = 1; + } else { + auto input = NDArrayFactory::create('c', {32, hw, hw, c}); + auto output = NDArrayFactory::create('c', {32, hw, hw, c}); + ctx->setInputArray(0, input); + ctx->setOutputArray(0, output); + axis = 3; + } - sd::ops::layer_norm layerNorm; - DeclarableBenchmark benchmark(layerNorm, "layer norm"); - output += helper.runOperationSuit(&benchmark, generator, batch, "Layer Norm"); + auto bias = NDArrayFactory::create('c', {c}); + ctx->setInputArray(1, bias); + auto iargs = new Nd4jLong[1]; + iargs[0] = axis; + ctx->setIArguments(iargs, 1); + delete[] iargs; - return output; - } + return ctx; + }; + + sd::ops::layer_norm layerNorm; + DeclarableBenchmark benchmark(layerNorm, "layer norm"); + output += helper.runOperationSuit(&benchmark, generator, batch, "Layer Norm"); + return output; +} - static std::string maxPool3DBenchmark(){ - std::string output; - BenchmarkHelper helper(wIterations, rIterations); +static std::string maxPool3DBenchmark() { + std::string output; + BenchmarkHelper helper(wIterations, rIterations); - BoolParameters ncdhw("ncdhw"); //1 = ndhwc - ParametersBatch batch({&ncdhw}); + BoolParameters ncdhw("ncdhw"); // 1 = ndhwc + ParametersBatch batch({&ncdhw}); - sd::ops::maxpool3dnew maxpool3Dnew; - DeclarableBenchmark benchmark(maxpool3Dnew, "maxPool3d"); + sd::ops::maxpool3dnew maxpool3Dnew; + DeclarableBenchmark benchmark(maxpool3Dnew, "maxPool3d"); #ifdef _RELEASE - int mb = 16; - int chIn = 16; - int chOut = 16; - int dhw = 64; + int mb = 16; + int chIn = 16; + int chOut = 16; + int dhw = 64; #else - int mb = 1; - int chIn = 3; - int chOut = 3; - int dhw = 16; + int mb = 1; + int chIn = 3; + int chOut = 3; + int dhw = 16; #endif - auto generator = PARAMETRIC_D() { - auto ctx = new Context(1); - int format = p.getIntParam("ncdhw"); - - //Set inputs and outputs - //Same mode + stride 1: output is same shape as input - if(format == 1) { - //NDHWC - ctx->setInputArray(0, NDArrayFactory::create_('c', {mb, dhw, dhw, dhw, chIn}), true); - ctx->setOutputArray(0, NDArrayFactory::create_('c', {mb, dhw, dhw, dhw, chIn}), true); - } else { - //NCDHW - ctx->setInputArray(0, NDArrayFactory::create_('c', {mb, chIn, dhw, dhw, dhw}), true); - ctx->setOutputArray(0, NDArrayFactory::create_('c', {mb, chIn, dhw, dhw, dhw}), true); - } - - auto iargs = new Nd4jLong[15]; - //Kernel, strides, padding, dilation - x3 each - iargs[0] = 3; //Kernel - iargs[1] = 3; - iargs[2] = 3; - iargs[3] = 1; //Stride - iargs[4] = 1; - iargs[5] = 1; - iargs[6] = 0; //Padding - iargs[7] = 0; - iargs[8] = 0; - iargs[9] = 1; //Dilation - iargs[10] = 1; - iargs[11] = 1; - iargs[12] = 1; //Same mode - iargs[13] = 0; //Unused for max - iargs[14] = format; //0 = ncdhw - ctx->setIArguments(iargs, 14); - delete[] iargs; - - return ctx; - }; - - output += helper.runOperationSuit(&benchmark, generator, batch, "maxPool3d"); - return output; + auto generator = PARAMETRIC_D() { + auto ctx = new Context(1); + int format = p.getIntParam("ncdhw"); + + // Set inputs and outputs + // Same mode + stride 1: output is same shape as input + if (format == 1) { + // NDHWC + ctx->setInputArray( + 0, NDArrayFactory::create('c', {mb, dhw, dhw, dhw, chIn})); + ctx->setOutputArray( + 0, NDArrayFactory::create('c', {mb, dhw, dhw, dhw, chIn})); + } else { + // NCDHW + ctx->setInputArray( + 0, NDArrayFactory::create('c', {mb, chIn, dhw, dhw, dhw})); + ctx->setOutputArray( + 0, NDArrayFactory::create('c', {mb, chIn, dhw, dhw, dhw})); } - - static std::string conv3dBenchmark(){ - std::string output; - BenchmarkHelper helper(wIterations, rIterations); - - BoolParameters ncdhw("ncdhw"); //1 = ndhwc - ParametersBatch batch({&ncdhw}); - - sd::ops::conv3dnew conv3Dnew; - DeclarableBenchmark benchmark(conv3Dnew, "conv3d"); + auto iargs = new Nd4jLong[15]; + // Kernel, strides, padding, dilation - x3 each + iargs[0] = 3; // Kernel + iargs[1] = 3; + iargs[2] = 3; + iargs[3] = 1; // Stride + iargs[4] = 1; + iargs[5] = 1; + iargs[6] = 0; // Padding + iargs[7] = 0; + iargs[8] = 0; + iargs[9] = 1; // Dilation + iargs[10] = 1; + iargs[11] = 1; + iargs[12] = 1; // Same mode + iargs[13] = 0; // Unused for max + iargs[14] = format; // 0 = ncdhw + ctx->setIArguments(iargs, 14); + delete[] iargs; + + return ctx; + }; + + output += helper.runOperationSuit(&benchmark, generator, batch, "maxPool3d"); + return output; +} + +static std::string conv3dBenchmark() { + std::string output; + BenchmarkHelper helper(wIterations, rIterations); + + BoolParameters ncdhw("ncdhw"); // 1 = ndhwc + ParametersBatch batch({&ncdhw}); + + sd::ops::conv3dnew conv3Dnew; + DeclarableBenchmark benchmark(conv3Dnew, "conv3d"); #ifdef _RELEASE - int mb = 16; - int chIn = 16; - int chOut = 16; - int dhw = 64; + int mb = 16; + int chIn = 16; + int chOut = 16; + int dhw = 64; #else - int mb = 1; - int chIn = 3; - int chOut = 3; - int dhw = 16; + int mb = 1; + int chIn = 3; + int chOut = 3; + int dhw = 16; #endif - auto generator = PARAMETRIC_D() { - auto ctx = new Context(1); - int format = p.getIntParam("ncdhw"); - - //Set inputs and outputs - //Same mode + stride 1: output is same shape as input - if(format == 1) { - //NDHWC - ctx->setInputArray(0, NDArrayFactory::create_('c', {mb, dhw, dhw, dhw, chIn}), true); - ctx->setOutputArray(0, NDArrayFactory::create_('c', {mb, dhw, dhw, dhw, chIn}), true); - } else { - //NCDHW - ctx->setInputArray(0, NDArrayFactory::create_('c', {mb, chIn, dhw, dhw, dhw}), true); - ctx->setOutputArray(0, NDArrayFactory::create_('c', {mb, chIn, dhw, dhw, dhw}), true); - } - - //Weights and bias: - ctx->setInputArray(1, NDArrayFactory::create_('c', {3, 3, 3, chIn, chOut}), true); - ctx->setInputArray(2, NDArrayFactory::create_('c', {chOut}), true); - - - auto iargs = new Nd4jLong[14]; - //Kernel, strides, padding, dilation - x3 each - iargs[0] = 3; //Kernel - iargs[1] = 3; - iargs[2] = 3; - iargs[3] = 1; //Stride - iargs[4] = 1; - iargs[5] = 1; - iargs[6] = 0; //Padding - iargs[7] = 0; - iargs[8] = 0; - iargs[9] = 1; //Dilation - iargs[10] = 1; - iargs[11] = 1; - iargs[12] = 1; //Same mode - iargs[13] = format; //0 = ncdhw - ctx->setIArguments(iargs, 14); - delete[] iargs; - - return ctx; - }; - - output += helper.runOperationSuit(&benchmark, generator, batch, "CNN3D"); - return output; + auto generator = PARAMETRIC_D() { + auto ctx = new Context(1); + int format = p.getIntParam("ncdhw"); + + // Set inputs and outputs + // Same mode + stride 1: output is same shape as input + if (format == 1) { + // NDHWC + ctx->setInputArray( + 0, NDArrayFactory::create('c', {mb, dhw, dhw, dhw, chIn})); + ctx->setOutputArray( + 0, NDArrayFactory::create('c', {mb, dhw, dhw, dhw, chIn})); + } else { + // NCDHW + ctx->setInputArray( + 0, NDArrayFactory::create('c', {mb, chIn, dhw, dhw, dhw})); + ctx->setOutputArray( + 0, NDArrayFactory::create('c', {mb, chIn, dhw, dhw, dhw})); } - - static std::string lstmBenchmark() { - std::string output; - BenchmarkHelper helper(wIterations, rIterations); - - BoolParameters format("format"); //0=TNS=[seqLen,mb,size]; 1=NST=[mb,size,seqLen] + // Weights and bias: + ctx->setInputArray( + 1, NDArrayFactory::create('c', {3, 3, 3, chIn, chOut})); + ctx->setInputArray(2, NDArrayFactory::create('c', {chOut})); + + auto iargs = new Nd4jLong[14]; + // Kernel, strides, padding, dilation - x3 each + iargs[0] = 3; // Kernel + iargs[1] = 3; + iargs[2] = 3; + iargs[3] = 1; // Stride + iargs[4] = 1; + iargs[5] = 1; + iargs[6] = 0; // Padding + iargs[7] = 0; + iargs[8] = 0; + iargs[9] = 1; // Dilation + iargs[10] = 1; + iargs[11] = 1; + iargs[12] = 1; // Same mode + iargs[13] = format; // 0 = ncdhw + ctx->setIArguments(iargs, 14); + delete[] iargs; + + return ctx; + }; + + output += helper.runOperationSuit(&benchmark, generator, batch, "CNN3D"); + return output; +} + +static std::string lstmBenchmark() { + std::string output; + BenchmarkHelper helper(wIterations, rIterations); + + BoolParameters format( + "format"); // 0=TNS=[seqLen,mb,size]; 1=NST=[mb,size,seqLen] #ifdef _RELEASE - PredefinedParameters mb("mb", {1, 8, 64}); - PredefinedParameters nInOut("nInOut", {32, 256, 1024}); + PredefinedParameters mb("mb", {1, 8, 64}); + PredefinedParameters nInOut("nInOut", {32, 256, 1024}); #else - PredefinedParameters mb("mb", {1}); - PredefinedParameters nInOut("nInOut", {32}); + PredefinedParameters mb("mb", {1}); + PredefinedParameters nInOut("nInOut", {32}); #endif - ParametersBatch batch({&format, &mb, &nInOut}); - sd::ops::lstmBlock lstmBlock; - DeclarableBenchmark benchmark(lstmBlock, "lstm"); - - int seqLength = 32; - - auto generator = PARAMETRIC_D() { - auto ctx = new Context(1); - int f = p.getIntParam("format"); - int m = p.getIntParam("mb"); - int n = p.getIntParam("nInOut"); - - Nd4jLong l = 0; - ctx->setInputArray(0, NDArrayFactory::create_(l), true); //Max TS length (unused) - - - if (f == 0) { - //TNS format - ctx->setInputArray(1, NDArrayFactory::create_('c', {seqLength, m, n}), true); //x - ctx->setOutputArray(0, NDArrayFactory::create_('c', {seqLength, m, n}), true); //i - ctx->setOutputArray(1, NDArrayFactory::create_('c', {seqLength, m, n}), true); //c - ctx->setOutputArray(2, NDArrayFactory::create_('c', {seqLength, m, n}), true); //f - ctx->setOutputArray(3, NDArrayFactory::create_('c', {seqLength, m, n}), true); //o - ctx->setOutputArray(4, NDArrayFactory::create_('c', {seqLength, m, n}), true); //z - ctx->setOutputArray(5, NDArrayFactory::create_('c', {seqLength, m, n}), true); //h - ctx->setOutputArray(6, NDArrayFactory::create_('c', {seqLength, m, n}), true); //y - } else { - //NST format - ctx->setInputArray(1, NDArrayFactory::create_('f', {m, n, seqLength}), true); //x - ctx->setOutputArray(0, NDArrayFactory::create_('f', {m, n, seqLength}), true); //i - ctx->setOutputArray(1, NDArrayFactory::create_('f', {m, n, seqLength}), true); //c - ctx->setOutputArray(2, NDArrayFactory::create_('f', {m, n, seqLength}), true); //f - ctx->setOutputArray(3, NDArrayFactory::create_('f', {m, n, seqLength}), true); //o - ctx->setOutputArray(4, NDArrayFactory::create_('f', {m, n, seqLength}), true); //z - ctx->setOutputArray(5, NDArrayFactory::create_('f', {m, n, seqLength}), true); //h - ctx->setOutputArray(6, NDArrayFactory::create_('f', {m, n, seqLength}), true); //y - } - - auto cLast = NDArrayFactory::create_('c', {m, n}); - auto yLast = NDArrayFactory::create_('c', {m, n}); - auto W = NDArrayFactory::create_('c', {2 * n, 4 * n}); - auto Wci = NDArrayFactory::create_('c', {n}); - auto Wcf = NDArrayFactory::create_('c', {n}); - auto Wco = NDArrayFactory::create_('c', {n}); - auto b = NDArrayFactory::create_('c', {4 * n}); - - ctx->setInputArray(2, cLast, true); - ctx->setInputArray(3, yLast, true); - ctx->setInputArray(4, W, true); - ctx->setInputArray(5, Wci, true); - ctx->setInputArray(6, Wcf, true); - ctx->setInputArray(7, Wco, true); - ctx->setInputArray(8, b, true); - - auto iargs = new Nd4jLong[2]; - iargs[0] = 0; //No peephole - iargs[1] = f; - ctx->setIArguments(iargs, 2); - delete[] iargs; - - auto targs = new double[2]; - targs[0] = 1.0; //forget bias - targs[1] = 0.0; //cell clipping value - ctx->setTArguments(targs, 2); - delete[] targs; - return ctx; - }; - - output += helper.runOperationSuit(&benchmark, generator, batch, "LSTMBlock"); - return output; + ParametersBatch batch({&format, &mb, &nInOut}); + sd::ops::lstmBlock lstmBlock; + DeclarableBenchmark benchmark(lstmBlock, "lstm"); + + int seqLength = 32; + + auto generator = PARAMETRIC_D() { + auto ctx = new Context(1); + int f = p.getIntParam("format"); + int m = p.getIntParam("mb"); + int n = p.getIntParam("nInOut"); + + Nd4jLong l = 0; + ctx->setInputArray( + 0, NDArrayFactory::create(l)); // Max TS length (unused) + + if (f == 0) { + // TNS format + ctx->setInputArray( + 1, NDArrayFactory::create('c', {seqLength, m, n})); // x + ctx->setOutputArray( + 0, NDArrayFactory::create('c', {seqLength, m, n})); // i + ctx->setOutputArray( + 1, NDArrayFactory::create('c', {seqLength, m, n})); // c + ctx->setOutputArray( + 2, NDArrayFactory::create('c', {seqLength, m, n})); // f + ctx->setOutputArray( + 3, NDArrayFactory::create('c', {seqLength, m, n})); // o + ctx->setOutputArray( + 4, NDArrayFactory::create('c', {seqLength, m, n})); // z + ctx->setOutputArray( + 5, NDArrayFactory::create('c', {seqLength, m, n})); // h + ctx->setOutputArray( + 6, NDArrayFactory::create('c', {seqLength, m, n})); // y + } else { + // NST format + ctx->setInputArray( + 1, NDArrayFactory::create('f', {m, n, seqLength})); // x + ctx->setOutputArray( + 0, NDArrayFactory::create('f', {m, n, seqLength})); // i + ctx->setOutputArray( + 1, NDArrayFactory::create('f', {m, n, seqLength})); // c + ctx->setOutputArray( + 2, NDArrayFactory::create('f', {m, n, seqLength})); // f + ctx->setOutputArray( + 3, NDArrayFactory::create('f', {m, n, seqLength})); // o + ctx->setOutputArray( + 4, NDArrayFactory::create('f', {m, n, seqLength})); // z + ctx->setOutputArray( + 5, NDArrayFactory::create('f', {m, n, seqLength})); // h + ctx->setOutputArray( + 6, NDArrayFactory::create('f', {m, n, seqLength})); // y } - static std::string batchnormBenchmark() { - std::string output; - BenchmarkHelper helper(wIterations, rIterations); - - //Convolution2D op - BoolParameters nhwc("nhwc"); + auto cLast = NDArrayFactory::create('c', {m, n}); + auto yLast = NDArrayFactory::create('c', {m, n}); + auto W = NDArrayFactory::create('c', {2 * n, 4 * n}); + auto Wci = NDArrayFactory::create('c', {n}); + auto Wcf = NDArrayFactory::create('c', {n}); + auto Wco = NDArrayFactory::create('c', {n}); + auto b = NDArrayFactory::create('c', {4 * n}); + + ctx->setInputArray(2, cLast); + ctx->setInputArray(3, yLast); + ctx->setInputArray(4, W); + ctx->setInputArray(5, Wci); + ctx->setInputArray(6, Wcf); + ctx->setInputArray(7, Wco); + ctx->setInputArray(8, b); + + auto iargs = new Nd4jLong[2]; + iargs[0] = 0; // No peephole + iargs[1] = f; + ctx->setIArguments(iargs, 2); + delete[] iargs; + + auto targs = new double[2]; + targs[0] = 1.0; // forget bias + targs[1] = 0.0; // cell clipping value + ctx->setTArguments(targs, 2); + delete[] targs; + return ctx; + }; + + output += helper.runOperationSuit(&benchmark, generator, batch, "LSTMBlock"); + return output; +} + +static std::string batchnormBenchmark() { + std::string output; + BenchmarkHelper helper(wIterations, rIterations); + + // Convolution2D op + BoolParameters nhwc("nhwc"); #ifdef _RELEASE - PredefinedParameters c("c", {3, 32, 128}); - PredefinedParameters hw("hw", {32, 128}); + PredefinedParameters c("c", {3, 32, 128}); + PredefinedParameters hw("hw", {32, 128}); #else - PredefinedParameters c("c", {3}); - PredefinedParameters hw("hw", {16}); + PredefinedParameters c("c", {3}); + PredefinedParameters hw("hw", {16}); #endif - ParametersBatch batch({&nhwc, &c, &hw}); - - auto generator = PARAMETRIC_D() { - auto ctx = new Context(1); - int n = p.getIntParam("nhwc"); - int hw = p.getIntParam("hw"); - int ch = p.getIntParam("c"); - - auto args = new Nd4jLong[3]; - args[0] = args[1] = 1; //apply scale and offset - if (n == 0) { - auto input = NDArrayFactory::create_('c', {32, ch, hw, hw}); - auto output = NDArrayFactory::create_('c', {32, ch, hw, hw}); - ctx->setInputArray(0, input, true); - ctx->setOutputArray(0, output, true); - args[2] = 1; //axis - } else { - auto input = NDArrayFactory::create_('c', {32, hw, hw, ch}); - auto output = NDArrayFactory::create_('c', {32, hw, hw, ch}); - ctx->setInputArray(0, input, true); - ctx->setOutputArray(0, output, true); - args[2] = 3; //axis - } - ctx->setIArguments(args, 3); - delete[] args; - - ctx->setInputArray(1, NDArrayFactory::create_('c', {ch}), true); //mean - auto v = NDArrayFactory::create_('c', {ch}); - v->assign(1.0f); - ctx->setInputArray(2, v, true); //variance - auto g = NDArrayFactory::create_('c', {ch}); - g->assign(1.0); - ctx->setInputArray(3, g, true); //gamma - auto b = NDArrayFactory::create_('c', {ch}); - b->assign(1.0); - ctx->setInputArray(4, b, true); //beta - - auto targs = new double[1]; - targs[0] = 1e-5; - ctx->setTArguments(targs, 1); - delete[] targs; - - return ctx; - }; - - sd::ops::batchnorm batchnorm; - DeclarableBenchmark benchmark(batchnorm, "batchnorm"); - output += helper.runOperationSuit(&benchmark, generator, batch, "Batch Normalization"); - - return output; + ParametersBatch batch({&nhwc, &c, &hw}); + + auto generator = PARAMETRIC_D() { + auto ctx = new Context(1); + int n = p.getIntParam("nhwc"); + int hw = p.getIntParam("hw"); + int ch = p.getIntParam("c"); + + auto args = new Nd4jLong[3]; + args[0] = args[1] = 1; // apply scale and offset + if (n == 0) { + auto input = NDArrayFactory::create('c', {32, ch, hw, hw}); + auto output = NDArrayFactory::create('c', {32, ch, hw, hw}); + ctx->setInputArray(0, input); + ctx->setOutputArray(0, output); + args[2] = 1; // axis + } else { + auto input = NDArrayFactory::create('c', {32, hw, hw, ch}); + auto output = NDArrayFactory::create('c', {32, hw, hw, ch}); + ctx->setInputArray(0, input); + ctx->setOutputArray(0, output); + args[2] = 3; // axis } - - static std::string pool2dBenchmark() { - std::string output; - BenchmarkHelper helper(wIterations, rIterations); - - //Convolution2D op - BoolParameters nhwc("nhwc"); + ctx->setIArguments(args, 3); + delete[] args; + + ctx->setInputArray(1, NDArrayFactory::create('c', {ch})); // mean + auto v = NDArrayFactory::create('c', {ch}); + v.assign(1.0f); + ctx->setInputArray(2, v); // variance + auto g = NDArrayFactory::create('c', {ch}); + g.assign(1.0); + ctx->setInputArray(3, g); // gamma + auto b = NDArrayFactory::create('c', {ch}); + b.assign(1.0); + ctx->setInputArray(4, b); // beta + + auto targs = new double[1]; + targs[0] = 1e-5; + ctx->setTArguments(targs, 1); + delete[] targs; + + return ctx; + }; + + sd::ops::batchnorm batchnorm; + DeclarableBenchmark benchmark(batchnorm, "batchnorm"); + output += helper.runOperationSuit(&benchmark, generator, batch, + "Batch Normalization"); + + return output; +} + +static std::string pool2dBenchmark() { + std::string output; + BenchmarkHelper helper(wIterations, rIterations); + + // Convolution2D op + BoolParameters nhwc("nhwc"); #ifdef _RELEASE - PredefinedParameters k("k", {2, 3, 5}); - PredefinedParameters c("c", {3, 32, 128}); - PredefinedParameters hw("hw", {32, 128}); + PredefinedParameters k("k", {2, 3, 5}); + PredefinedParameters c("c", {3, 32, 128}); + PredefinedParameters hw("hw", {32, 128}); #else - PredefinedParameters k("k", {2}); - PredefinedParameters c("c", {3}); - PredefinedParameters hw("hw", {8}); + PredefinedParameters k("k", {2}); + PredefinedParameters c("c", {3}); + PredefinedParameters hw("hw", {8}); #endif - ParametersBatch batch({&nhwc, &k, &c, &hw}); - - auto generator = PARAMETRIC_D() { - auto ctx = new Context(1); - int n = p.getIntParam("nhwc"); - int hw = p.getIntParam("hw"); - int khw = p.getIntParam("k"); - - if (n == 0) { - auto input = NDArrayFactory::create_('c', {32, p.getIntParam("c"), hw, hw}); - auto output = NDArrayFactory::create_('c', {32, p.getIntParam("c"), hw, hw}); - ctx->setInputArray(0, input, true); - ctx->setOutputArray(0, output, true); - } else { - auto input = NDArrayFactory::create_('c', {32, hw, hw, p.getIntParam("c")}); - auto output = NDArrayFactory::create_('c', {32, hw, hw, p.getIntParam("c")}); - ctx->setInputArray(0, input, true); - ctx->setOutputArray(0, output, true); - } - - auto args = new Nd4jLong[11]; - args[0] = args[1] = khw; //Kernel - args[2] = args[3] = 1;//Stride - args[4] = args[5] = 0; //Pad - args[6] = args[7] = 1; //Dilation - args[8] = 1; //SAME - args[9] = 0; //Divisor mode - 0 = exclude padding in divisor - args[10] = n;//0-nchw, 1=nhwc - ctx->setIArguments(args, 11); - delete[] args; - - return ctx; - }; - - sd::ops::avgpool2d avgpool2d; - DeclarableBenchmark benchmark1(avgpool2d, "avgpool"); - output += helper.runOperationSuit(&benchmark1, generator, batch, "Average Pooling 2d Operation"); - - sd::ops::maxpool2d maxpool2d; - DeclarableBenchmark benchmark2(maxpool2d, "maxpool"); - output += helper.runOperationSuit(&benchmark2, generator, batch, "Max Pooling 2d Operation"); - return output; + ParametersBatch batch({&nhwc, &k, &c, &hw}); + + auto generator = PARAMETRIC_D() { + auto ctx = new Context(1); + int n = p.getIntParam("nhwc"); + int hw = p.getIntParam("hw"); + int khw = p.getIntParam("k"); + + if (n == 0) { + auto input = + NDArrayFactory::create('c', {32, p.getIntParam("c"), hw, hw}); + auto output = + NDArrayFactory::create('c', {32, p.getIntParam("c"), hw, hw}); + ctx->setInputArray(0, input); + ctx->setOutputArray(0, output); + } else { + auto input = + NDArrayFactory::create('c', {32, hw, hw, p.getIntParam("c")}); + auto output = + NDArrayFactory::create('c', {32, hw, hw, p.getIntParam("c")}); + ctx->setInputArray(0, input); + ctx->setOutputArray(0, output); } - static std::string conv2dBenchmark() { - std::string output; - BenchmarkHelper helper(wIterations, rIterations); - - //Convolution2D op - BoolParameters nhwc("nhwc"); + auto args = new Nd4jLong[11]; + args[0] = args[1] = khw; // Kernel + args[2] = args[3] = 1; // Stride + args[4] = args[5] = 0; // Pad + args[6] = args[7] = 1; // Dilation + args[8] = 1; // SAME + args[9] = 0; // Divisor mode - 0 = exclude padding in divisor + args[10] = n; // 0-nchw, 1=nhwc + ctx->setIArguments(args, 11); + delete[] args; + + return ctx; + }; + + sd::ops::avgpool2d avgpool2d; + DeclarableBenchmark benchmark1(avgpool2d, "avgpool"); + output += helper.runOperationSuit(&benchmark1, generator, batch, + "Average Pooling 2d Operation"); + + sd::ops::maxpool2d maxpool2d; + DeclarableBenchmark benchmark2(maxpool2d, "maxpool"); + output += helper.runOperationSuit(&benchmark2, generator, batch, + "Max Pooling 2d Operation"); + return output; +} + +static std::string conv2dBenchmark() { + std::string output; + BenchmarkHelper helper(wIterations, rIterations); + + // Convolution2D op + BoolParameters nhwc("nhwc"); #ifdef _RELEASE - PredefinedParameters k("k", {2, 3, 5}); - PredefinedParameters c("c", {3, 32, 128}); - PredefinedParameters hw("hw", {32, 128}); + PredefinedParameters k("k", {2, 3, 5}); + PredefinedParameters c("c", {3, 32, 128}); + PredefinedParameters hw("hw", {32, 128}); #else - PredefinedParameters k("k", {2}); - PredefinedParameters c("c", {3}); - PredefinedParameters hw("hw", {8}); + PredefinedParameters k("k", {2}); + PredefinedParameters c("c", {3}); + PredefinedParameters hw("hw", {8}); #endif - ParametersBatch batch({&nhwc, &k, &c, &hw}); - sd::ops::conv2d conv2d; - DeclarableBenchmark benchmark(conv2d, "conv2d"); - - auto generator = PARAMETRIC_D() { - auto ctx = new Context(1); - int n = p.getIntParam("nhwc"); - int hw = p.getIntParam("hw"); - int khw = p.getIntParam("k"); - - if (n == 0) { - auto input = NDArrayFactory::create_('c', {32, p.getIntParam("c"), hw, hw}); - auto output = NDArrayFactory::create_('c', {32, p.getIntParam("c"), hw, hw}); - ctx->setInputArray(0, input, true); - ctx->setOutputArray(0, output, true); - } else { - auto input = NDArrayFactory::create_('c', {32, hw, hw, p.getIntParam("c")}); - auto output = NDArrayFactory::create_('c', {32, hw, hw, p.getIntParam("c")}); - ctx->setInputArray(0, input, true); - ctx->setOutputArray(0, output, true); - } - - auto b = NDArrayFactory::create_('c', {p.getIntParam("c")}); - auto w = NDArrayFactory::create_('c', {khw, khw, p.getIntParam("c"), p.getIntParam("c")}); // [kH, kW, iC, oC] always - - ctx->setInputArray(1, w, true); - ctx->setInputArray(2, b, true); - - auto args = new Nd4jLong[10]; - args[0] = args[1] = khw; //Kernel - args[2] = args[3] = 1;//Stride - args[4] = args[5] = 0; //Pad - args[6] = args[7] = 1; //Dilation - args[8] = 1; //SAME - args[9] = n;//0-nchw, 1=nhwc - ctx->setIArguments(args, 10); - delete[] args; - - return ctx; - }; - - output += helper.runOperationSuit(&benchmark, generator, batch, "Conv2d Operation"); - return output; - } - - static std::string rngBenchmark() { - std::string output; - BenchmarkHelper helper(wIterations, rIterations); - //Uniform, gaussian and bernoulli RNG generation - - IntPowerParameters length("length", 2, 4, scalarBenchmarkPowLimit, 3); //2^8 to 2^30 in steps of 3 - - ParametersBatch batch({&length}); - - auto gen01 = PARAMETRIC_D() { - auto ctx = new Context(1); - ctx->setInputArray(0, NDArrayFactory::create_('c', {2},{1, p.getIntParam("length")}), true); //Shape as NDArray - ctx->setOutputArray(0, NDArrayFactory::create_('c', {1, p.getIntParam("length")}), true); - auto d = new double[2]; - d[0] = 0.0; - d[1] = 1.0; - ctx->setTArguments(d, 2); - delete[] d; - return ctx; - }; - - auto gen05 = PARAMETRIC_D() { - auto ctx = new Context(1); - ctx->setInputArray(0, NDArrayFactory::create_('c', {2},{1, p.getIntParam("length")}), true); //Shape as NDArray - ctx->setOutputArray(0, NDArrayFactory::create_('c', {1, p.getIntParam("length")}), true); - auto d = new double[1]; - d[0] = 0.5; - ctx->setTArguments(d, 1); - delete[] d; - return ctx; - }; - - sd::ops::LegacyRandomOp unif(random::UniformDistribution); - DeclarableBenchmark dbU(unif, "uniform"); - output += helper.runOperationSuit(&dbU, gen01, batch, "Uniform Distribution"); - - sd::ops::LegacyRandomOp gaussian(random::GaussianDistribution); - DeclarableBenchmark dbG(gaussian, "gaussian"); - output += helper.runOperationSuit(&dbG, gen01, batch, "Gaussian Distribution"); - - sd::ops::LegacyRandomOp trunc(random::TruncatedNormalDistribution); - DeclarableBenchmark dbTU(unif, "trunc.norm"); - output += helper.runOperationSuit(&dbTU, gen01, batch, "Truncated Normal Distribution"); - - sd::ops::LegacyRandomOp ln(random::LogNormalDistribution); - DeclarableBenchmark dbLN(ln, "uniform"); - output += helper.runOperationSuit(&dbLN, gen01, batch, "Log Normal Distribution"); - - sd::ops::LegacyRandomOp bernoulli(random::BernoulliDistribution); - DeclarableBenchmark dbB(bernoulli, "bernoulli"); - output += helper.runOperationSuit(&dbB, gen05, batch, "Bernoulli Distribution"); - - sd::ops::LegacyRandomOp dropout(random::BernoulliDistribution); - DeclarableBenchmark dbD(dropout, "dropout"); - output += helper.runOperationSuit(&dbD, gen05, batch, "Dropout"); - - return output; + ParametersBatch batch({&nhwc, &k, &c, &hw}); + sd::ops::conv2d conv2d; + DeclarableBenchmark benchmark(conv2d, "conv2d"); + + auto generator = PARAMETRIC_D() { + auto ctx = new Context(1); + int n = p.getIntParam("nhwc"); + int hw = p.getIntParam("hw"); + int khw = p.getIntParam("k"); + + if (n == 0) { + auto input = + NDArrayFactory::create('c', {32, p.getIntParam("c"), hw, hw}); + auto output = + NDArrayFactory::create('c', {32, p.getIntParam("c"), hw, hw}); + ctx->setInputArray(0, input); + ctx->setOutputArray(0, output); + } else { + auto input = + NDArrayFactory::create('c', {32, hw, hw, p.getIntParam("c")}); + auto output = + NDArrayFactory::create('c', {32, hw, hw, p.getIntParam("c")}); + ctx->setInputArray(0, input); + ctx->setOutputArray(0, output); } - static std::string gemmIrregularBenchmark() { - std::string output; - BenchmarkHelper helper(wIterations, rIterations); - - //Basically the same as above, but with irregular shapes (not multiples of 8, etc) + auto b = NDArrayFactory::create('c', {p.getIntParam("c")}); + auto w = NDArrayFactory::create( + 'c', {khw, khw, p.getIntParam("c"), + p.getIntParam("c")}); // [kH, kW, iC, oC] always + + ctx->setInputArray(1, w); + ctx->setInputArray(2, b); + + auto args = new Nd4jLong[10]; + args[0] = args[1] = khw; // Kernel + args[2] = args[3] = 1; // Stride + args[4] = args[5] = 0; // Pad + args[6] = args[7] = 1; // Dilation + args[8] = 1; // SAME + args[9] = n; // 0-nchw, 1=nhwc + ctx->setIArguments(args, 10); + delete[] args; + + return ctx; + }; + + output += + helper.runOperationSuit(&benchmark, generator, batch, "Conv2d Operation"); + return output; +} + +static std::string rngBenchmark() { + std::string output; + BenchmarkHelper helper(wIterations, rIterations); + // Uniform, gaussian and bernoulli RNG generation + + IntPowerParameters length("length", 2, 4, scalarBenchmarkPowLimit, + 3); // 2^8 to 2^30 in steps of 3 + + ParametersBatch batch({&length}); + + auto gen01 = PARAMETRIC_D() { + auto ctx = new Context(1); + ctx->setInputArray( + 0, NDArrayFactory::create( + 'c', {2}, {1, p.getIntParam("length")})); // Shape as NDArray + ctx->setOutputArray( + 0, NDArrayFactory::create('c', {1, p.getIntParam("length")})); + auto d = new double[2]; + d[0] = 0.0; + d[1] = 1.0; + ctx->setTArguments(d, 2); + delete[] d; + return ctx; + }; + + auto gen05 = PARAMETRIC_D() { + auto ctx = new Context(1); + ctx->setInputArray( + 0, NDArrayFactory::create( + 'c', {2}, {1, p.getIntParam("length")})); // Shape as NDArray + ctx->setOutputArray( + 0, NDArrayFactory::create('c', {1, p.getIntParam("length")})); + auto d = new double[1]; + d[0] = 0.5; + ctx->setTArguments(d, 1); + delete[] d; + return ctx; + }; + + sd::ops::LegacyRandomOp unif(random::UniformDistribution); + DeclarableBenchmark dbU(unif, "uniform"); + output += helper.runOperationSuit(&dbU, gen01, batch, "Uniform Distribution"); + + sd::ops::LegacyRandomOp gaussian(random::GaussianDistribution); + DeclarableBenchmark dbG(gaussian, "gaussian"); + output += + helper.runOperationSuit(&dbG, gen01, batch, "Gaussian Distribution"); + + sd::ops::LegacyRandomOp trunc(random::TruncatedNormalDistribution); + DeclarableBenchmark dbTU(unif, "trunc.norm"); + output += helper.runOperationSuit(&dbTU, gen01, batch, + "Truncated Normal Distribution"); + + sd::ops::LegacyRandomOp ln(random::LogNormalDistribution); + DeclarableBenchmark dbLN(ln, "uniform"); + output += + helper.runOperationSuit(&dbLN, gen01, batch, "Log Normal Distribution"); + + sd::ops::LegacyRandomOp bernoulli(random::BernoulliDistribution); + DeclarableBenchmark dbB(bernoulli, "bernoulli"); + output += + helper.runOperationSuit(&dbB, gen05, batch, "Bernoulli Distribution"); + + sd::ops::LegacyRandomOp dropout(random::BernoulliDistribution); + DeclarableBenchmark dbD(dropout, "dropout"); + output += helper.runOperationSuit(&dbD, gen05, batch, "Dropout"); + + return output; +} + +static std::string gemmIrregularBenchmark() { + std::string output; + BenchmarkHelper helper(wIterations, rIterations); + + // Basically the same as above, but with irregular shapes (not multiples of 8, + // etc) #ifdef _RELEASE - int tAMax = 1; - int tBMax = 1; - int b = 1024; - int c = 1024; + int tAMax = 1; + int tBMax = 1; + int b = 1024; + int c = 1024; #else - int tAMax = 1; - int tBMax = 1; - int b = 32; - int c = 32; + int tAMax = 1; + int tBMax = 1; + int b = 32; + int c = 32; #endif - for (int tA = 0; tA <= tAMax; tA++) { - for (int tB = 0; tB <= tBMax; tB++) { - IntParameters d("d", 1020, 1028, 1); //1020, 1021, ..., 1028 - ParametersBatch dim({&d}); - - //Vary A.rows: - auto generator = PARAMETRIC_XYZ() { - auto a = p.getIntParam("d"); - std::vector shapeA; - std::vector shapeB; - if (tA) { - shapeA = {b, a}; - } else { - shapeA = {a, b}; - } - if (tB) { - shapeB = {c, b}; - } else { - shapeB = {b, c}; - } - auto A = NDArrayFactory::create_('c', shapeA); - auto B = NDArrayFactory::create_('c', shapeB); - auto C = NDArrayFactory::create_('f', {a, c}); - - x.push_back(A); - y.push_back(B); - z.push_back(C); - }; - - std::string n; - n += "Gemm (a.rows) - tA="; - n += std::to_string(tA); - n += ", tB="; - n += std::to_string(tB); - - MatrixBenchmark mb(1.0, 0.0, tA, tB, n); - - output += helper.runOperationSuit(&mb, generator, dim, n.c_str()); - - //Vary A.columns / B.rows - auto generator2 = PARAMETRIC_XYZ() { - auto a = 1024; - auto b = p.getIntParam("d"); - auto c = 1024; - std::vector shapeA; - std::vector shapeB; - if (tA) { - shapeA = {b, a}; - } else { - shapeA = {a, b}; - } - if (tB) { - shapeB = {c, b}; - } else { - shapeB = {b, c}; - } - auto A = NDArrayFactory::create_('c', shapeA); - auto B = NDArrayFactory::create_('c', shapeB); - auto C = NDArrayFactory::create_('f', {a, c}); - - x.push_back(A); - y.push_back(B); - z.push_back(C); - }; - - std::string n2; - n2 += "Gemm (a.columns) - tA="; - n2 += std::to_string(tA); - n2 += ", tB="; - n2 += std::to_string(tB); - - MatrixBenchmark mb2(1.0, 0.0, tA, tB, n2); - - output += helper.runOperationSuit(&mb2, generator2, dim, n2.c_str()); - - //Vary A.columns / B.rows - auto generator3 = PARAMETRIC_XYZ() { - auto a = 1024; - auto b = 1024; - auto c = p.getIntParam("d"); - std::vector shapeA; - std::vector shapeB; - if (tA) { - shapeA = {b, a}; - } else { - shapeA = {a, b}; - } - if (tB) { - shapeB = {c, b}; - } else { - shapeB = {b, c}; - } - auto A = NDArrayFactory::create_('c', shapeA); - auto B = NDArrayFactory::create_('c', shapeB); - auto C = NDArrayFactory::create_('f', {a, c}); - - x.push_back(A); - y.push_back(B); - z.push_back(C); - }; - - std::string n3; - n3 += "Gemm (b.columns) - tA="; - n3 += std::to_string(tA); - n3 += ", tB="; - n3 += std::to_string(tB); - - MatrixBenchmark mb3(1.0, 0.0, tA, tB, n); - - output += helper.runOperationSuit(&mb3, generator3, dim, n3.c_str()); - } + for (int tA = 0; tA <= tAMax; tA++) { + for (int tB = 0; tB <= tBMax; tB++) { + IntParameters d("d", 1020, 1028, 1); // 1020, 1021, ..., 1028 + ParametersBatch dim({&d}); + + // Vary A.rows: + auto generator = PARAMETRIC_XYZ() { + auto a = p.getIntParam("d"); + std::vector shapeA; + std::vector shapeB; + if (tA) { + shapeA = {b, a}; + } else { + shapeA = {a, b}; } + if (tB) { + shapeB = {c, b}; + } else { + shapeB = {b, c}; + } + auto A = NDArrayFactory::create('c', shapeA); + auto B = NDArrayFactory::create('c', shapeB); + auto C = NDArrayFactory::create('f', {a, c}); + + x.push_back(A); + y.push_back(B); + z.push_back(C); + }; + + std::string n; + n += "Gemm (a.rows) - tA="; + n += std::to_string(tA); + n += ", tB="; + n += std::to_string(tB); + + MatrixBenchmark mb(1.0, 0.0, tA, tB, n); + + output += helper.runOperationSuit(&mb, generator, dim, n.c_str()); + + // Vary A.columns / B.rows + auto generator2 = PARAMETRIC_XYZ() { + auto a = 1024; + auto b = p.getIntParam("d"); + auto c = 1024; + std::vector shapeA; + std::vector shapeB; + if (tA) { + shapeA = {b, a}; + } else { + shapeA = {a, b}; + } + if (tB) { + shapeB = {c, b}; + } else { + shapeB = {b, c}; + } + auto A = NDArrayFactory::create('c', shapeA); + auto B = NDArrayFactory::create('c', shapeB); + auto C = NDArrayFactory::create('f', {a, c}); + + x.push_back(A); + y.push_back(B); + z.push_back(C); + }; + + std::string n2; + n2 += "Gemm (a.columns) - tA="; + n2 += std::to_string(tA); + n2 += ", tB="; + n2 += std::to_string(tB); + + MatrixBenchmark mb2(1.0, 0.0, tA, tB, n2); + + output += helper.runOperationSuit(&mb2, generator2, dim, n2.c_str()); + + // Vary A.columns / B.rows + auto generator3 = PARAMETRIC_XYZ() { + auto a = 1024; + auto b = 1024; + auto c = p.getIntParam("d"); + std::vector shapeA; + std::vector shapeB; + if (tA) { + shapeA = {b, a}; + } else { + shapeA = {a, b}; + } + if (tB) { + shapeB = {c, b}; + } else { + shapeB = {b, c}; + } + auto A = NDArrayFactory::create('c', shapeA); + auto B = NDArrayFactory::create('c', shapeB); + auto C = NDArrayFactory::create('f', {a, c}); - return output; - } - - static std::string batchGemmBenchmark() { - std::string output; - BenchmarkHelper helper(wIterations, rIterations); - - //Rank 3 - [32,1024,1024]x[32,1024,1024] - //Rank 4 - [4,8,1024,1024]x[4,8,1024,1024] - - IntParameters rank("rank", 3, 4, 1); - - ParametersBatch b({&rank}); - - auto generator = PARAMETRIC_D() { - auto rank = p.getIntParam("rank"); - std::vector shapeA; - std::vector shapeB; - auto ctx = new Context(1); - - if(rank == 3){ - ctx->setInputArray(0, NDArrayFactory::create_('c', {32, 1024, 1024}), true); - ctx->setInputArray(1, NDArrayFactory::create_('c', {32, 1024, 1024}), true); - ctx->setOutputArray(0, NDArrayFactory::create_('c', {32, 1024, 1024}), true); - } else { - ctx->setInputArray(0, NDArrayFactory::create_('c', {4, 8, 1024, 1024}), true); - ctx->setInputArray(1, NDArrayFactory::create_('c', {4, 8, 1024, 1024}), true); - ctx->setOutputArray(0, NDArrayFactory::create_('c', {4, 8, 1024, 1024}), true); - } + x.push_back(A); + y.push_back(B); + z.push_back(C); + }; - return ctx; - }; + std::string n3; + n3 += "Gemm (b.columns) - tA="; + n3 += std::to_string(tA); + n3 += ", tB="; + n3 += std::to_string(tB); - sd::ops::matmul mmul; - DeclarableBenchmark benchmark(mmul, "mmul (batch)"); - output += helper.runOperationSuit(&benchmark, generator, b, "MMul (batch)"); + MatrixBenchmark mb3(1.0, 0.0, tA, tB, n); - return output; + output += helper.runOperationSuit(&mb3, generator3, dim, n3.c_str()); + } + } + + return output; +} + +static std::string batchGemmBenchmark() { + std::string output; + BenchmarkHelper helper(wIterations, rIterations); + + // Rank 3 - [32,1024,1024]x[32,1024,1024] + // Rank 4 - [4,8,1024,1024]x[4,8,1024,1024] + + IntParameters rank("rank", 3, 4, 1); + + ParametersBatch b({&rank}); + + auto generator = PARAMETRIC_D() { + auto rank = p.getIntParam("rank"); + std::vector shapeA; + std::vector shapeB; + auto ctx = new Context(1); + + if (rank == 3) { + ctx->setInputArray(0, + NDArrayFactory::create('c', {32, 1024, 1024})); + ctx->setInputArray(1, + NDArrayFactory::create('c', {32, 1024, 1024})); + ctx->setOutputArray(0, + NDArrayFactory::create('c', {32, 1024, 1024})); + } else { + ctx->setInputArray( + 0, NDArrayFactory::create('c', {4, 8, 1024, 1024})); + ctx->setInputArray( + 1, NDArrayFactory::create('c', {4, 8, 1024, 1024})); + ctx->setOutputArray( + 0, NDArrayFactory::create('c', {4, 8, 1024, 1024})); } - static std::string gemmRegularBenchmark() { - std::string output; - BenchmarkHelper helper(wIterations, rIterations); - - for (int o = 0; o <= 1; o++) { - char resultOrder = (o == 0 ? 'f' : 'c'); - for (int tA = 0; tA <= 1; tA++) { - for (int tB = 0; tB <= 1; tB++) { - - - IntPowerParameters pa("sz", 2, 7, gemmRegularUpperPow, 2); //2^7=128, 2^9=512, 2^11=2048 - - ParametersBatch b({&pa}); - - auto generator = PARAMETRIC_XYZ() { - auto s = p.getIntParam("sz"); - auto A = NDArrayFactory::create_('c', {s, s}); - auto B = NDArrayFactory::create_('c', {s, s}); - auto C = NDArrayFactory::create_(resultOrder, {s, s}); + return ctx; + }; - x.push_back(A); - y.push_back(B); - z.push_back(C); - }; + sd::ops::matmul mmul; + DeclarableBenchmark benchmark(mmul, "mmul (batch)"); + output += helper.runOperationSuit(&benchmark, generator, b, "MMul (batch)"); - std::string n; - n += "Gemm - tA="; - n += std::to_string(tA); - n += ", tB="; - n += std::to_string(tB); - n += ", cOrder="; - n += resultOrder; + return output; +} - MatrixBenchmark mb(1.0, 0.0, tA == 0 ? false : true, tB == 0 ? false : true, n); +static std::string gemmRegularBenchmark() { + std::string output; + BenchmarkHelper helper(wIterations, rIterations); - output += helper.runOperationSuit(&mb, generator, b, n.c_str()); - } - } - } + for (int o = 0; o <= 1; o++) { + char resultOrder = (o == 0 ? 'f' : 'c'); + for (int tA = 0; tA <= 1; tA++) { + for (int tB = 0; tB <= 1; tB++) { + IntPowerParameters pa("sz", 2, 7, gemmRegularUpperPow, + 2); // 2^7=128, 2^9=512, 2^11=2048 - return output; - } + ParametersBatch b({&pa}); - static std::string scatterOpBenchmark() { - std::string output; - BenchmarkHelper helper(wIterations, rIterations); - - IntPowerParameters length("length", 2, 10, gatherOpPowLimit, 4); //2^10 to 2^26 in steps of 4 - ParametersBatch batch({&length}); - - //Gather 1D tests - 1d ref, 1d indices, 1d updates -> 1d output - sd::ops::scatter_upd scatter_update1; - DeclarableBenchmark sa1d(scatter_update1, "scatter_update1d"); - auto generator = PARAMETRIC_D() { - auto ctx = new Context(1); - int length = p.getIntParam("length"); - auto in = NDArrayFactory::create_('c', {length}); - auto indices = NDArrayFactory::create_('c', {length}); - auto updates = NDArrayFactory::create_('c', {length}); - - int* a = new int[length]; - for( int i=0; ip(i, a[i]); - } - delete[] a; - - ctx->setInputArray(0, in, true); - ctx->setInputArray(1, indices, true); - ctx->setInputArray(2, updates, true); - ctx->setOutputArray(0, in); //Needs to be inplace to avoid copy! - ctx->markInplace(true); - return ctx; + auto generator = PARAMETRIC_XYZ() { + auto s = p.getIntParam("sz"); + auto A = NDArrayFactory::create('c', {s, s}); + auto B = NDArrayFactory::create('c', {s, s}); + auto C = NDArrayFactory::create(resultOrder, {s, s}); + + x.push_back(A); + y.push_back(B); + z.push_back(C); }; - output += helper.runOperationSuit(&sa1d, generator, batch, "Scatter Update - 1d"); - - //Gather 2D tests - 2d input, 1d indices, 2d updates -> 2d output - IntPowerParameters rows("rows", 2, 8, gatherOpPowLimit2, 4); //2^10 to 2^16 in steps of 2: 2^10, ..., 2^20 - PredefinedParameters cols("cols", {32}); - ParametersBatch batch2({&rows, &cols}); - sd::ops::scatter_upd scatter_update2; - DeclarableBenchmark sa2d(scatter_update2, "scatter_update2d"); - auto generator2 = PARAMETRIC_D() { - auto ctx = new Context(1); - int rows = p.getIntParam("rows"); - int cols = p.getIntParam("cols"); - auto in = NDArrayFactory::create_('c', {rows, cols}); - auto indices = NDArrayFactory::create_('c', {rows}); - auto updates = NDArrayFactory::create_('c', {rows, cols}); - - int* a = new int[rows]; - for( int i=0; ip(i, a[i]); - } - delete[] a; - - ctx->setInputArray(0, in, true); - ctx->setInputArray(1, indices, true); - ctx->setInputArray(2, updates, true); - ctx->setOutputArray(0, in); //Needs to be inplace to avoid copy! - ctx->markInplace(true); - return ctx; - }; + std::string n; + n += "Gemm - tA="; + n += std::to_string(tA); + n += ", tB="; + n += std::to_string(tB); + n += ", cOrder="; + n += resultOrder; - output += helper.runOperationSuit(&sa2d, generator2, batch2, "Scatter Update - 2d"); - - //Gather 3D tests - 3d input, 1d indices -> 3d output - IntPowerParameters sz0("sz0", 2, 8, gatherOpPowLimit3, 4); - PredefinedParameters sz1("sz1", {32}); - ParametersBatch batch3({&sz0, &sz1}); - sd::ops::scatter_upd scatter_update3; - DeclarableBenchmark sa3d(scatter_update3, "scatter3d"); - auto generator3 = PARAMETRIC_D() { - auto ctx = new Context(1); - int sz0 = p.getIntParam("sz0"); - int sz1 = p.getIntParam("sz1"); - auto in = NDArrayFactory::create_('c', {sz0, sz1, 512/sz1}); - auto indices = NDArrayFactory::create_('c', {sz0}); - auto updates = NDArrayFactory::create_('c', {sz0, sz1, 512/sz1}); - - int* a = new int[sz0]; - for( int i=0; ip(i, a[i]); - } - delete[] a; - - ctx->setInputArray(0, in, true); - ctx->setInputArray(1, indices, true); - ctx->setInputArray(2, updates, true); - ctx->setOutputArray(0, in); //Needs to be inplace to avoid copy! - ctx->markInplace(true); - return ctx; - }; + MatrixBenchmark mb(1.0, 0.0, tA == 0 ? false : true, + tB == 0 ? false : true, n); - output += helper.runOperationSuit(&sa3d, generator3, batch3, "Scatter Update - 3d"); - return output; + output += helper.runOperationSuit(&mb, generator, b, n.c_str()); + } } - - static std::string gatherOpBenchmark() { - std::string output; - BenchmarkHelper helper(wIterations, rIterations); - - IntPowerParameters length("length", 2, 10, gatherOpPowLimit, 4); //2^10 to 2^22 in steps of 4 - ParametersBatch batch({&length}); - - //Gather 1D tests - 1d input, 1d indices -> 1d output - sd::ops::gather gather1; - DeclarableBenchmark gather1d(gather1, "gather1d"); - auto generator = PARAMETRIC_D() { - auto ctx = new Context(1); - int length = p.getIntParam("length"); - auto in = NDArrayFactory::create_('c', {length}); - auto indices = NDArrayFactory::create_('c', {length}); - int* a = new int[length]; - for( int i=0; ip(i, a[i]); - } - delete[] a; - - ctx->setInputArray(0, in, true); - ctx->setInputArray(1, indices, true); - ctx->setOutputArray(0, NDArrayFactory::create_('c', {length}), true); - return ctx; - }; - - output += helper.runOperationSuit(&gather1d, generator, batch, "Gather - 1d"); - - //Gather 2D tests - 2d input, 1d indices -> 2d output - IntPowerParameters rows("rows", 2, 8, gatherOpPowLimit2, 4); //2^10 to 2^20 in steps of 2: 2^10, ..., 2^20 - PredefinedParameters cols("cols", {32}); - ParametersBatch batch2({&rows, &cols}); - sd::ops::gather gather2; - DeclarableBenchmark gather2d(gather2, "gather2d"); - auto generator2 = PARAMETRIC_D() { - auto ctx = new Context(1); - int rows = p.getIntParam("rows"); - int cols = p.getIntParam("cols"); - auto in = NDArrayFactory::create_('c', {rows, cols}); - auto indices = NDArrayFactory::create_('c', {rows}); - - int* a = new int[rows]; - for( int i=0; ip(i, a[i]); - } - delete[] a; - - ctx->setInputArray(0, in, true); - ctx->setInputArray(1, indices, true); - ctx->setOutputArray(0, NDArrayFactory::create_('c', {rows, cols}), true); - return ctx; - }; - - output += helper.runOperationSuit(&gather2d, generator2, batch2, "Gather - 2d"); - - //Gather 3D tests - 3d input, 1d indices -> 3d output - IntPowerParameters sz0("sz0", 2, 8, gatherOpPowLimit3, 4); //2^8 to 2^16 in steps of 4 - PredefinedParameters sz1("sz1", {32}); - ParametersBatch batch3({&sz0, &sz1}); - sd::ops::gather gather3; - DeclarableBenchmark gather3d(gather3, "gather3d"); - auto generator3 = PARAMETRIC_D() { - auto ctx = new Context(1); - int sz0 = p.getIntParam("sz0"); - int sz1 = p.getIntParam("sz1"); - auto in = NDArrayFactory::create_('c', {sz0, sz1, 512/sz1}); - auto indices = NDArrayFactory::create_('c', {sz0}); - - int* a = new int[sz0]; - for( int i=0; ip(i, a[i]); - } - delete[] a; - - ctx->setInputArray(0, in, true); - ctx->setInputArray(1, indices, true); - ctx->setOutputArray(0, NDArrayFactory::create_('c', {sz0, sz1, 512/sz1}), true); - return ctx; - }; - - output += helper.runOperationSuit(&gather3d, generator3, batch3, "Gather - 3d"); - - return output; + } + + return output; +} + +static std::string scatterOpBenchmark() { + std::string output; + BenchmarkHelper helper(wIterations, rIterations); + + IntPowerParameters length("length", 2, 10, gatherOpPowLimit, + 4); // 2^10 to 2^26 in steps of 4 + ParametersBatch batch({&length}); + + // Gather 1D tests - 1d ref, 1d indices, 1d updates -> 1d output + sd::ops::scatter_upd scatter_update1; + DeclarableBenchmark sa1d(scatter_update1, "scatter_update1d"); + auto generator = PARAMETRIC_D() { + auto ctx = new Context(1); + int length = p.getIntParam("length"); + auto in = NDArrayFactory::create('c', {length}); + auto indices = NDArrayFactory::create('c', {length}); + auto updates = NDArrayFactory::create('c', {length}); + + int *a = new int[length]; + for (int i = 0; i < length; i++) { + a[i] = i; } - - static std::string mismatchedOrdersAssignBenchmark() { - std::string output; - BenchmarkHelper helper(wIterations, rIterations); - - IntPowerParameters rows("rows", 2, 2, mismatchedAssignPowLimit, 4); //2^2 to 2^26 in steps of 2 - 2^1=2, ..., 2^26=67108864 - BoolParameters cf("cf"); - - ParametersBatch batch({&rows, &cf}); - - auto generator = PARAMETRIC_XZ() { - int numElements = 67108864; //2^26 - int rows = p.getIntParam("rows"); - int cols = numElements / rows; - bool c = p.getIntParam("cf"); - - auto arr = NDArrayFactory::create_(c ? 'c' : 'f', {rows, cols}); - auto arr2 = NDArrayFactory::create_(c ? 'f' : 'c', {rows, cols}); - x.push_back(arr); - z.push_back(arr2); - }; - - TransformBenchmark tb(transform::AnyOps::Assign, "assign"); - output += helper.runOperationSuit(&tb, generator, batch, "C->F and F->C Assign"); - - //Also test: NCHW to NHWC and back - BoolParameters nchw("nchw"); - ParametersBatch batch2({&nchw}); - auto generator2 = PARAMETRIC_XZ() { - bool nchw = p.getIntParam("nchw"); - - if(nchw) { - auto orig = NDArrayFactory::create_('c', {16, 32, 64, 64}); - orig->permutei({0,2,3,1}); - x.push_back(orig); - z.push_back(NDArrayFactory::create_('c', {16, 64, 64, 32})); - } else { - auto orig = NDArrayFactory::create_('c', {16, 64, 64, 32}); - orig->permutei({0,3,1,2}); - x.push_back(orig); - z.push_back(NDArrayFactory::create_('c', {16, 32, 64, 64})); - } - }; - - TransformBenchmark tb2(transform::AnyOps::Assign, "assign_nchw"); - output += helper.runOperationSuit(&tb2, generator2, batch2, "nchw->nhwc and nhwc->nchw Assign"); - return output; + srand(12345); + std::random_shuffle(a, (a + length - 1)); + for (int i = 0; i < length; i++) { + indices.p(i, a[i]); } - - static std::string broadcastOpsMatrixBenchmark() { - std::string output; - BenchmarkHelper helper(wIterations, rIterations); - - //Broadcast ops: matrices for rank 3, 4, 5 - for( int rank=3; rank <= broadcastMatrixRankLimit; rank++ ){ - int numAxisTests = -1; - if(rank == 3){ - numAxisTests = 3; - } else if(rank == 4){ - numAxisTests = 6; - } else if(rank == 5){ - numAxisTests = 10; - } - - IntParameters testNum("testNum", 0,numAxisTests-1,1); - ParametersBatch b({&testNum}); - - auto generator = PARAMETRIC_D(){ - int n = p.getIntParam("testNum"); - std::vector axis({}); - switch(n){ - //rank 3+ - case 0: - axis = std::vector({0,1}); - break; - case 1: - axis = std::vector({0,2}); - break; - case 2: - axis = std::vector({1,2}); - break; - //rank 4+ - case 3: - axis = std::vector({0,3}); - break; - case 4: - axis = std::vector({1,3}); - break; - case 5: - axis = std::vector({2,3}); - break; - //Rank 5 - case 6: - axis = std::vector({0,4}); - break; - case 7: - axis = std::vector({1,4}); - break; - case 8: - axis = std::vector({2,4}); - break; - case 9: - axis = std::vector({3,4}); - break; - } - - - std::vector shape({}); - std::vector toBcShape({}); - int vectorLength; - if(rank == 3){ - shape = std::vector({64,64,64}); - toBcShape = std::vector({64,64,64}); - vectorLength = 64; - } else if(rank == 4){ - shape = std::vector({32,32,32,32}); - toBcShape = std::vector({32,32,32,32}); - vectorLength = 32; - } else if(rank == 5){ - shape = std::vector({16,16,16,16,16}); - toBcShape = std::vector({16,16,16,16,16}); - vectorLength = 16; - } - - for( int i=0; isetInputArray(0, NDArrayFactory::create_('c', shape), true); - ctx->setInputArray(1, NDArrayFactory::create_('c', toBcShape), true); - ctx->setOutputArray(0, NDArrayFactory::create_('c', shape), true); - return ctx; - }; - - std::string name; - name += "Broadcast Matrix Add (Custom) - Rank"; - name += std::to_string(rank); - - sd::ops::add op; - DeclarableBenchmark benchmark(op, "add"); - output += helper.runOperationSuit(&benchmark, generator, b, name.c_str()); - } - - return output; + delete[] a; + + ctx->setInputArray(0, in); + ctx->setInputArray(1, indices); + ctx->setInputArray(2, updates); + ctx->setOutputArray(0, in); // Needs to be inplace to avoid copy! + ctx->markInplace(true); + return ctx; + }; + + output += + helper.runOperationSuit(&sa1d, generator, batch, "Scatter Update - 1d"); + + // Gather 2D tests - 2d input, 1d indices, 2d updates -> 2d output + IntPowerParameters rows("rows", 2, 8, gatherOpPowLimit2, + 4); // 2^10 to 2^16 in steps of 2: 2^10, ..., 2^20 + PredefinedParameters cols("cols", {32}); + ParametersBatch batch2({&rows, &cols}); + sd::ops::scatter_upd scatter_update2; + DeclarableBenchmark sa2d(scatter_update2, "scatter_update2d"); + auto generator2 = PARAMETRIC_D() { + auto ctx = new Context(1); + int rows = p.getIntParam("rows"); + int cols = p.getIntParam("cols"); + auto in = NDArrayFactory::create('c', {rows, cols}); + auto indices = NDArrayFactory::create('c', {rows}); + auto updates = NDArrayFactory::create('c', {rows, cols}); + + int *a = new int[rows]; + for (int i = 0; i < rows; i++) { + a[i] = i; } - - - static std::string broadcast2dBenchmark() { - std::string output; - BenchmarkHelper helper(wIterations, rIterations); - - PredefinedParameters rows("rows", {65536}); - IntPowerParameters cols("cols", 2, 2, limit10, 4); //2^2, 2^6, 2^10 - BoolParameters axis("axis"); - BoolParameters inplace("inplace"); - - ParametersBatch batch({&rows, &cols, &axis, &inplace}); - - auto generator = PARAMETRIC_D() { - auto a = p.getIntParam("axis"); - auto arr = NDArrayFactory::create_('c', {p.getIntParam("rows"), p.getIntParam("cols")}); - - auto ctx = new Context(1); - ctx->setInputArray(0, arr, true); - if(a == 0){ - ctx->setInputArray(1, NDArrayFactory::create_('c', {p.getIntParam("rows"), 1}), true); - } else { - ctx->setInputArray(1, NDArrayFactory::create_('c', {1, p.getIntParam("cols")}), true); - } - if (p.getIntParam("inplace") == 1) { - ctx->setOutputArray(0, arr); - ctx->markInplace(true); - } else { - ctx->setOutputArray(0, NDArrayFactory::create_('c', {p.getIntParam("rows"), p.getIntParam("cols")}), true); - } - return ctx; - }; - - std::string s("add"); - sd::ops::add op; - DeclarableBenchmark benchmark(op, "add"); - output += helper.runOperationSuit(&benchmark, generator, batch, "Broadcast (Custom) Add - 2d"); - return output; + srand(12345); + std::random_shuffle(a, (a + rows - 1)); + for (int i = 0; i < rows; i++) { + indices.p(i, a[i]); } - - static std::string broadcastBenchmark() { - std::string output; - BenchmarkHelper helper(wIterations, rIterations); - - //Broadcast ops: vectors for rank 2, 3, 4, 5 - for( int axis=0; axis<=1; axis++ ){ - PredefinedParameters rows("rows", {65536}); - IntPowerParameters cols("cols", 2, 2, limit10, 4); //2^1 to 2^10 in steps of 2 - 2^1=2, ..., 2^10=1024 - BoolParameters inplace("inplace"); - - ParametersBatch batch({&rows, &cols, &inplace}); - - auto generator = PARAMETRIC_XYZ() { - auto arr = NDArrayFactory::create_('c', {p.getIntParam("rows"), p.getIntParam("cols")}); - x.push_back(arr); - if(axis == 0){ - y.push_back(NDArrayFactory::create_('c', {p.getIntParam("rows")})); - } else { - y.push_back(NDArrayFactory::create_('c', {p.getIntParam("cols")})); - } - if (p.getIntParam("inplace") == 1) { - z.push_back(arr); - } else { - z.push_back(NDArrayFactory::create_('c', {p.getIntParam("rows"), p.getIntParam("cols")})); - } - }; - - std::string s("bAdd"); s += std::to_string(axis); s += "r2"; - BroadcastBenchmark bAdd(broadcast::Add, s, {axis}); - output += helper.runOperationSuit(&bAdd, generator, batch, "Broadcast Add - Rank 2"); - } - - for( int rank=3; rank<=5; rank++ ){ - for( int axis=1; axis shape({}); - int vectorLength; - if(rank == 3){ - shape = std::vector({32,128,128}); - vectorLength = 128; - } else if(rank == 4){ - shape = std::vector({16,64,64,64}); - vectorLength = 64; - } else if(rank == 5){ - shape = std::vector({16,48,48,48,48}); - vectorLength = 48; - } - - ParametersBatch batch({}); - - //Note: always inplace here - auto generator = PARAMETRIC_XYZ() { - auto arr = NDArrayFactory::create_('c', shape); - x.push_back(arr); - y.push_back(NDArrayFactory::create_('c', {vectorLength})); - z.push_back(arr); - }; - - std::string name("bArr-r"); name += std::to_string(rank); name += "a"; name += std::to_string(axis); - BroadcastBenchmark bAdd(broadcast::Add, name, {axis}); - std::string n2("Broadcast Add - Rank"); n2 += std::to_string(rank); n2 += " - axis="; n2 += std::to_string(axis); - output += helper.runOperationSuit(&bAdd, generator, batch, n2.c_str()); - } - } - - return output; + delete[] a; + + ctx->setInputArray(0, in); + ctx->setInputArray(1, indices); + ctx->setInputArray(2, updates); + ctx->setOutputArray(0, in); // Needs to be inplace to avoid copy! + ctx->markInplace(true); + return ctx; + }; + + output += + helper.runOperationSuit(&sa2d, generator2, batch2, "Scatter Update - 2d"); + + // Gather 3D tests - 3d input, 1d indices -> 3d output + IntPowerParameters sz0("sz0", 2, 8, gatherOpPowLimit3, 4); + PredefinedParameters sz1("sz1", {32}); + ParametersBatch batch3({&sz0, &sz1}); + sd::ops::scatter_upd scatter_update3; + DeclarableBenchmark sa3d(scatter_update3, "scatter3d"); + auto generator3 = PARAMETRIC_D() { + auto ctx = new Context(1); + int sz0 = p.getIntParam("sz0"); + int sz1 = p.getIntParam("sz1"); + auto in = NDArrayFactory::create('c', {sz0, sz1, 512 / sz1}); + auto indices = NDArrayFactory::create('c', {sz0}); + auto updates = NDArrayFactory::create('c', {sz0, sz1, 512 / sz1}); + + int *a = new int[sz0]; + for (int i = 0; i < sz0; i++) { + a[i] = i; } - - static std::string fastStridedReductionNonEws() { - std::string output; - BenchmarkHelper helper(wIterations, rIterations); - - IntPowerParameters stride("stride", 2, 0, 10, 2); //2^0=1, ..., 2^10=1024 - - ParametersBatch batch({&stride}); - - //This is an edge case: technically an EWS *should* be available here - auto generator1 = PARAMETRIC_XYZ() { - auto stride = p.getIntParam("stride"); - auto arr = NDArrayFactory::create_('c', {131072 + (stride == 1 ? 0 : 1), stride}); - - NDArray* strided; - if(stride == 1){ - strided = arr; - } else { - IndicesList indices({NDIndex::interval(0,131072), NDIndex::interval(0,1)}); - strided = new NDArray(arr->subarray(indices)); //All rows, first column - delete arr; - } - - strided->assign(1.0); - x.push_back(strided); - y.push_back(nullptr); - z.push_back(NDArrayFactory::create_(0.0f)); - }; - - ReductionBenchmark rbSum(reduce::SameOps::Sum, "stridedSum"); - output += helper.runOperationSuit(&rbSum, (const std::function)(generator1), batch, "Strided Sum - No EWS Test 1"); - - - //No EWS defined for this case - auto generator2 = PARAMETRIC_XYZ() { - auto stride = p.getIntParam("stride"); - auto arr = NDArrayFactory::create_('c', {(stride == 1 ? 1 : 2) * 1024, 1024, stride}); - - NDArray* strided; - if(stride == 1){ - strided = arr; - } else { - IndicesList indices({NDIndex::interval(0,2*1024,2), NDIndex::all(), NDIndex::interval(0,1)}); - strided = new NDArray(arr->subarray(indices)); - delete arr; - } - - strided->assign(1.0); - x.push_back(strided); - y.push_back(nullptr); - z.push_back(NDArrayFactory::create_(0.0f)); - }; - - ReductionBenchmark rbSum2(reduce::SameOps::Sum, "stridedSumNoEWS"); - output += helper.runOperationSuit(&rbSum2, (const std::function)(generator2), batch, "Strided Sum - No EWS Test 2"); - - return output; + srand(12345); + std::random_shuffle(a, (a + sz0 - 1)); + for (int i = 0; i < sz0; i++) { + indices.p(i, a[i]); } - - static std::string fastStridedReductionIrregular() { - std::string output; - BenchmarkHelper helper(wIterations, rIterations); - - IntPowerParameters length("length", 2, 12, stridedReductionPowLimit, 4); //2^12 to 2^20 in steps of 4 - PredefinedParameters stride("stride", {26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, - 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, - 1018, 1019, 1020, 1021, 1022, 1023, 1024, 1025, 1026, 1027, 1028}); - - ParametersBatch batch({&length, &stride}); - - auto generator = PARAMETRIC_XYZ() { - auto stride = p.getIntParam("stride"); - auto arr = NDArrayFactory::create_('c', {p.getIntParam("length"), stride}); - - NDArray* strided; - if(stride == 1){ - strided = arr; - } else { - IndicesList indices({NDIndex::all(), NDIndex::interval(0,1)}); - strided = new NDArray(arr->subarray(indices)); //All rows, first column - delete arr; - } - - strided->assign(1.0); - x.push_back(strided); - y.push_back(nullptr); - z.push_back(NDArrayFactory::create_(0.0f)); - }; - - ReductionBenchmark rbSum(reduce::SameOps::Sum, "stridedSum"); - - output += helper.runOperationSuit(&rbSum, (const std::function)(generator), batch, "Strided Sum - Irregular Strides"); - - return output; + delete[] a; + + ctx->setInputArray(0, in); + ctx->setInputArray(1, indices); + ctx->setInputArray(2, updates); + ctx->setOutputArray(0, in); // Needs to be inplace to avoid copy! + ctx->markInplace(true); + return ctx; + }; + + output += + helper.runOperationSuit(&sa3d, generator3, batch3, "Scatter Update - 3d"); + return output; +} + +static std::string gatherOpBenchmark() { + std::string output; + BenchmarkHelper helper(wIterations, rIterations); + + IntPowerParameters length("length", 2, 10, gatherOpPowLimit, + 4); // 2^10 to 2^22 in steps of 4 + ParametersBatch batch({&length}); + + // Gather 1D tests - 1d input, 1d indices -> 1d output + sd::ops::gather gather1; + DeclarableBenchmark gather1d(gather1, "gather1d"); + auto generator = PARAMETRIC_D() { + auto ctx = new Context(1); + int length = p.getIntParam("length"); + auto in = NDArrayFactory::create('c', {length}); + auto indices = NDArrayFactory::create('c', {length}); + int *a = new int[length]; + for (int i = 0; i < length; i++) { + a[i] = i; } - - static std::string fastStridedReductionsRegular() { - std::string output; - BenchmarkHelper helper(wIterations, rIterations); - - IntPowerParameters length("length", 2, 12, stridedReductionPowLimit, 4); //2^12 to 2^20 in steps of 4 - IntPowerParameters stride("stride", 2, 0, 10); //2^0=1, ..., 2^10=1024 - - ParametersBatch batch({&length, &stride}); - - auto generator = PARAMETRIC_XYZ() { - auto stride = p.getIntParam("stride"); - auto arr = NDArrayFactory::create_('c', {p.getIntParam("length"), stride}); - - NDArray* strided; - if(stride == 1){ - strided = arr; - } else { - IndicesList indices({NDIndex::all(), NDIndex::point(0)}); - strided = new NDArray(arr->subarray(indices)); //All rows, first column - delete arr; - } - - strided->assign(1.0); - x.push_back(strided); - y.push_back(nullptr); -// z.push_back(NDArrayFactory::create_(0.0f)); - z.push_back(NDArrayFactory::create_('c', {1})); - }; - - ReductionBenchmark rbSum(reduce::SameOps::Sum, "Strided Sum"); - - output += helper.runOperationSuit(&rbSum, (const std::function)(generator), batch, "Strided Sum - Regular Strides (powers of 2)"); - - auto generator3 = PARAMETRIC_D(){ - auto ctx = new Context(1); - auto stride = p.getIntParam("stride"); - auto arr = NDArrayFactory::create_('c', {p.getIntParam("length"), stride}); - - NDArray* strided; - if(stride == 1){ - strided = arr; - } else { - IndicesList indices({NDIndex::all(), NDIndex::point(0)}); - strided = new NDArray(arr->subarray(indices)); //All rows, first column - delete arr; - } - - strided->assign(1.0); - ctx->setInputArray(0, strided, true); - ctx->setOutputArray(0, NDArrayFactory::create_('c', {1}), true); - auto iargs = new Nd4jLong[1]; - iargs[0] = 0; - ctx->setIArguments(iargs, 1); - delete[] iargs; - return ctx; - }; - - sd::ops::argmax opArgmax; - DeclarableBenchmark dbArgmax(opArgmax, "stridedArgmax"); - output += helper.runOperationSuit(&dbArgmax, generator3, batch, "Strided Argmax"); - return output; + srand(12345); + std::random_shuffle(a, (a + length - 1)); + for (int i = 0; i < length; i++) { + indices.p(i, a[i]); + } + delete[] a; + + ctx->setInputArray(0, in); + ctx->setInputArray(1, indices); + ctx->setOutputArray(0, NDArrayFactory::create('c', {length})); + return ctx; + }; + + output += helper.runOperationSuit(&gather1d, generator, batch, "Gather - 1d"); + + // Gather 2D tests - 2d input, 1d indices -> 2d output + IntPowerParameters rows("rows", 2, 8, gatherOpPowLimit2, + 4); // 2^10 to 2^20 in steps of 2: 2^10, ..., 2^20 + PredefinedParameters cols("cols", {32}); + ParametersBatch batch2({&rows, &cols}); + sd::ops::gather gather2; + DeclarableBenchmark gather2d(gather2, "gather2d"); + auto generator2 = PARAMETRIC_D() { + auto ctx = new Context(1); + int rows = p.getIntParam("rows"); + int cols = p.getIntParam("cols"); + auto in = NDArrayFactory::create('c', {rows, cols}); + auto indices = NDArrayFactory::create('c', {rows}); + + int *a = new int[rows]; + for (int i = 0; i < rows; i++) { + a[i] = i; + } + srand(12345); + std::random_shuffle(a, (a + rows - 1)); + for (int i = 0; i < rows; i++) { + indices.p(i, a[i]); + } + delete[] a; + + ctx->setInputArray(0, in); + ctx->setInputArray(1, indices); + ctx->setOutputArray(0, NDArrayFactory::create('c', {rows, cols})); + return ctx; + }; + + output += + helper.runOperationSuit(&gather2d, generator2, batch2, "Gather - 2d"); + + // Gather 3D tests - 3d input, 1d indices -> 3d output + IntPowerParameters sz0("sz0", 2, 8, gatherOpPowLimit3, + 4); // 2^8 to 2^16 in steps of 4 + PredefinedParameters sz1("sz1", {32}); + ParametersBatch batch3({&sz0, &sz1}); + sd::ops::gather gather3; + DeclarableBenchmark gather3d(gather3, "gather3d"); + auto generator3 = PARAMETRIC_D() { + auto ctx = new Context(1); + int sz0 = p.getIntParam("sz0"); + int sz1 = p.getIntParam("sz1"); + auto in = NDArrayFactory::create('c', {sz0, sz1, 512 / sz1}); + auto indices = NDArrayFactory::create('c', {sz0}); + + int *a = new int[sz0]; + for (int i = 0; i < sz0; i++) { + a[i] = i; + } + srand(12345); + std::random_shuffle(a, (a + sz0 - 1)); + for (int i = 0; i < sz0; i++) { + indices.p(i, a[i]); + } + delete[] a; + + ctx->setInputArray(0, in); + ctx->setInputArray(1, indices); + ctx->setOutputArray( + 0, NDArrayFactory::create('c', {sz0, sz1, 512 / sz1})); + return ctx; + }; + + output += + helper.runOperationSuit(&gather3d, generator3, batch3, "Gather - 3d"); + + return output; +} + +static std::string mismatchedOrdersAssignBenchmark() { + std::string output; + BenchmarkHelper helper(wIterations, rIterations); + + IntPowerParameters rows( + "rows", 2, 2, mismatchedAssignPowLimit, + 4); // 2^2 to 2^26 in steps of 2 - 2^1=2, ..., 2^26=67108864 + BoolParameters cf("cf"); + + ParametersBatch batch({&rows, &cf}); + + auto generator = PARAMETRIC_XZ() { + int numElements = 67108864; // 2^26 + int rows = p.getIntParam("rows"); + int cols = numElements / rows; + bool c = p.getIntParam("cf"); + + auto arr = NDArrayFactory::create(c ? 'c' : 'f', {rows, cols}); + auto arr2 = NDArrayFactory::create(c ? 'f' : 'c', {rows, cols}); + x.push_back(arr); + z.push_back(arr2); + }; + + TransformBenchmark tb(transform::AnyOps::Assign, "assign"); + output += + helper.runOperationSuit(&tb, generator, batch, "C->F and F->C Assign"); + + // Also test: NCHW to NHWC and back + BoolParameters nchw("nchw"); + ParametersBatch batch2({&nchw}); + auto generator2 = PARAMETRIC_XZ() { + bool nchw = p.getIntParam("nchw"); + + if (nchw) { + auto orig = NDArrayFactory::create('c', {16, 32, 64, 64}); + orig.permutei({0, 2, 3, 1}); + x.push_back(orig); + z.push_back(NDArrayFactory::create('c', {16, 64, 64, 32})); + } else { + auto orig = NDArrayFactory::create('c', {16, 64, 64, 32}); + orig.permutei({0, 3, 1, 2}); + x.push_back(orig); + z.push_back(NDArrayFactory::create('c', {16, 32, 64, 64})); + } + }; + + TransformBenchmark tb2(transform::AnyOps::Assign, "assign_nchw"); + output += helper.runOperationSuit(&tb2, generator2, batch2, + "nchw->nhwc and nhwc->nchw Assign"); + return output; +} + +static std::string broadcastOpsMatrixBenchmark() { + std::string output; + BenchmarkHelper helper(wIterations, rIterations); + + // Broadcast ops: matrices for rank 3, 4, 5 + for (int rank = 3; rank <= broadcastMatrixRankLimit; rank++) { + int numAxisTests = -1; + if (rank == 3) { + numAxisTests = 3; + } else if (rank == 4) { + numAxisTests = 6; + } else if (rank == 5) { + numAxisTests = 10; } - static std::string fastReduceAlongDimBenchmark() { - std::string output; - BenchmarkHelper helper(wIterations, rIterations); - - int length[] = {1024*1024, 64*1024*1024}; - int powLimit[] = {10, 20, 26}; - int powStep[] = {2, 2, 4}; - - for( int i=0; i < limit3; i++ ){ - IntPowerParameters rows("rows", 2, 0, powLimit[i], powStep[i]); - BoolParameters dim("dim"); - - - ParametersBatch batch({&rows, &dim}); - - auto generator = PARAMETRIC_XYZ() { - int rows = p.getIntParam("rows"); - int cols = length[i] / rows; - int dim = p.getIntParam("dim"); - auto arr = NDArrayFactory::create_('c', {rows, cols}); - - - x.push_back(arr); - y.push_back(NDArrayFactory::create_(dim)); - - NDArray* result; - if(dim == 0){ - result = NDArrayFactory::create_('c', {cols}); - } else { - result = NDArrayFactory::create_('c', {rows}); - } - z.push_back(result); - }; - - ReductionBenchmark rbSum(reduce::SameOps::Sum, "sum"); - ReductionBenchmark rbMax(reduce::SameOps::Max, "max"); - - std::string s1("Sum Along Dimension - "); - s1 += std::to_string(length[i]); - - output += helper.runOperationSuit(&rbSum, (const std::function)(generator), batch, s1.c_str()); - - - auto generator3 = PARAMETRIC_D(){ - auto ctx = new Context(1); - int rows = p.getIntParam("rows"); - int cols = length[i] / rows; - int dim = p.getIntParam("dim"); - auto arr = NDArrayFactory::create_('c', {rows, cols}); - - Nd4jLong* dimArg = new Nd4jLong[1]; - dimArg[0] = dim; - ctx->setIArguments(dimArg, 1); - delete[] dimArg; - - ctx->setInputArray(0, arr, true); - - NDArray* result; - if(dim == 0){ - result = NDArrayFactory::create_('c', {cols}); - } else { - result = NDArrayFactory::create_('c', {rows}); - } - ctx->setOutputArray(0, result, true); - return ctx; - }; - - std::string s5("Argmax Along Dimension - "); - s5 += std::to_string(length[i]); - - sd::ops::argmax opArgmax; - DeclarableBenchmark dbArgmax(opArgmax, "Argmax"); - output += helper.runOperationSuit(&dbArgmax, generator3, batch, s5.c_str()); + IntParameters testNum("testNum", 0, numAxisTests - 1, 1); + ParametersBatch b({&testNum}); + + auto generator = PARAMETRIC_D() { + int n = p.getIntParam("testNum"); + std::vector axis({}); + switch (n) { + // rank 3+ + case 0: + axis = std::vector({0, 1}); + break; + case 1: + axis = std::vector({0, 2}); + break; + case 2: + axis = std::vector({1, 2}); + break; + // rank 4+ + case 3: + axis = std::vector({0, 3}); + break; + case 4: + axis = std::vector({1, 3}); + break; + case 5: + axis = std::vector({2, 3}); + break; + // Rank 5 + case 6: + axis = std::vector({0, 4}); + break; + case 7: + axis = std::vector({1, 4}); + break; + case 8: + axis = std::vector({2, 4}); + break; + case 9: + axis = std::vector({3, 4}); + break; + } + + std::vector shape({}); + std::vector toBcShape({}); + int vectorLength; + if (rank == 3) { + shape = std::vector({64, 64, 64}); + toBcShape = std::vector({64, 64, 64}); + vectorLength = 64; + } else if (rank == 4) { + shape = std::vector({32, 32, 32, 32}); + toBcShape = std::vector({32, 32, 32, 32}); + vectorLength = 32; + } else if (rank == 5) { + shape = std::vector({16, 16, 16, 16, 16}); + toBcShape = std::vector({16, 16, 16, 16, 16}); + vectorLength = 16; + } + + for (int i = 0; i < rank; i++) { + if (axis[0] == i || axis[1] == i) { + continue; } - - return output; + toBcShape[i] = 1; + } + + auto ctx = new Context(1); + ctx->setInputArray(0, NDArrayFactory::create('c', shape)); + ctx->setInputArray(1, NDArrayFactory::create('c', toBcShape)); + ctx->setOutputArray(0, NDArrayFactory::create('c', shape)); + return ctx; + }; + + std::string name; + name += "Broadcast Matrix Add (Custom) - Rank"; + name += std::to_string(rank); + + sd::ops::add op; + DeclarableBenchmark benchmark(op, "add"); + output += helper.runOperationSuit(&benchmark, generator, b, name.c_str()); + } + + return output; +} + +static std::string broadcast2dBenchmark() { + std::string output; + BenchmarkHelper helper(wIterations, rIterations); + + PredefinedParameters rows("rows", {65536}); + IntPowerParameters cols("cols", 2, 2, limit10, 4); // 2^2, 2^6, 2^10 + BoolParameters axis("axis"); + BoolParameters inplace("inplace"); + + ParametersBatch batch({&rows, &cols, &axis, &inplace}); + + auto generator = PARAMETRIC_D() { + auto a = p.getIntParam("axis"); + auto arr = NDArrayFactory::create( + 'c', {p.getIntParam("rows"), p.getIntParam("cols")}); + + auto ctx = new Context(1); + ctx->setInputArray(0, arr); + if (a == 0) { + ctx->setInputArray( + 1, NDArrayFactory::create('c', {p.getIntParam("rows"), 1})); + } else { + ctx->setInputArray( + 1, NDArrayFactory::create('c', {1, p.getIntParam("cols")})); } - - static std::string fastReduceToScalarBenchmark() { - std::string output; - BenchmarkHelper helper(wIterations, rIterations); - - IntPowerParameters length("length", 2, 10, reduceScalarPowLimit, 4); //2^10 to 2^26 in steps of 4 - - ParametersBatch batch({&length}); - - auto generator = PARAMETRIC_XYZ() { - auto arr = NDArrayFactory::create_('c', {p.getIntParam("length")}); - - x.push_back(arr); - y.push_back(nullptr); - z.push_back(NDArrayFactory::create_(0.0f)); - }; - - ReductionBenchmark rbSum(reduce::SameOps::Sum, "sum"); - - output += helper.runOperationSuit(&rbSum, (const std::function)(generator), batch, "Sum - Full Array Reduction"); - - //Index reduction - sd::ops::argmax opArgmax; - DeclarableBenchmark dbArgmax(opArgmax, "Argmax"); - auto generator3 = PARAMETRIC_D(){ - auto ctx = new Context(1); - - ctx->setInputArray(0, NDArrayFactory::create_('c', {p.getIntParam("length")}), true); - ctx->setInputArray(1, NDArrayFactory::create_((Nd4jLong)0), true); - ctx->setOutputArray(0, NDArrayFactory::create_(0), true); - - return ctx; - }; - output += helper.runOperationSuit(&dbArgmax, generator3, batch, "Argmax Full Array Reduction"); - - return output; + if (p.getIntParam("inplace") == 1) { + ctx->setOutputArray(0, arr); + ctx->markInplace(true); + } else { + ctx->setOutputArray( + 0, NDArrayFactory::create( + 'c', {p.getIntParam("rows"), p.getIntParam("cols")})); } - - static std::string fastNonEwsTransformBenchmark() { - std::string output; - BenchmarkHelper helper(wIterations, rIterations); - IntPowerParameters rowcol("rowcol", 2, 2, nonEwsPowLimit, 4); //2^2 to 2^14 in steps of 4 -> non-inplace case: 2x 2^10 x 2^10 = 128mb - BoolParameters inplace("inplace"); - - ParametersBatch batch({&rowcol, &inplace}); - - auto generator = PARAMETRIC_XZ() { - int r = p.getIntParam("rowcol"); - auto arr = NDArrayFactory::create_('c', {r, r+1}); - IndicesList indices({NDIndex::all(), NDIndex::interval(0,r-1)}); - auto view = new NDArray(arr->subarray(indices)); - //nd4j_printf("VIEW ARRAY: rows=%lld, columns=%lld", view->sizeAt(0), view->sizeAt(1)); - x.push_back(view); - if(p.getIntParam("inplace") == 1){ - z.push_back(view); - } else { - z.push_back(NDArrayFactory::create_('c', {view->sizeAt(0),view->sizeAt(1)})); - } - delete arr; - }; - - ScalarBenchmark sbLRelu(scalar::Ops::LeakyRELU, "LeakyRELU_View"); - sbLRelu.setY(NDArrayFactory::create_(0.0)); - - TransformBenchmark tbExp(transform::StrictOps::Exp, "exp view"); - - output += helper.runOperationSuit(&sbLRelu, generator, batch, "LeakyRELU View"); - output += helper.runOperationSuit(&tbExp, generator, batch, "Exp View"); - - return output; + return ctx; + }; + + std::string s("add"); + sd::ops::add op; + DeclarableBenchmark benchmark(op, "add"); + output += helper.runOperationSuit(&benchmark, generator, batch, + "Broadcast (Custom) Add - 2d"); + return output; +} + +static std::string broadcastBenchmark() { + std::string output; + BenchmarkHelper helper(wIterations, rIterations); + + // Broadcast ops: vectors for rank 2, 3, 4, 5 + for (int axis = 0; axis <= 1; axis++) { + PredefinedParameters rows("rows", {65536}); + IntPowerParameters cols( + "cols", 2, 2, limit10, + 4); // 2^1 to 2^10 in steps of 2 - 2^1=2, ..., 2^10=1024 + BoolParameters inplace("inplace"); + + ParametersBatch batch({&rows, &cols, &inplace}); + + auto generator = PARAMETRIC_XYZ() { + auto arr = NDArrayFactory::create( + 'c', {p.getIntParam("rows"), p.getIntParam("cols")}); + x.push_back(arr); + if (axis == 0) { + y.push_back( + NDArrayFactory::create('c', {p.getIntParam("rows")})); + } else { + y.push_back( + NDArrayFactory::create('c', {p.getIntParam("cols")})); + } + if (p.getIntParam("inplace") == 1) { + z.push_back(arr); + } else { + z.push_back(NDArrayFactory::create( + 'c', {p.getIntParam("rows"), p.getIntParam("cols")})); + } + }; + + std::string s("bAdd"); + s += std::to_string(axis); + s += "r2"; + BroadcastBenchmark bAdd(broadcast::Add, s, {axis}); + output += helper.runOperationSuit(&bAdd, generator, batch, + "Broadcast Add - Rank 2"); + } + + for (int rank = 3; rank <= 5; rank++) { + for (int axis = 1; axis < rank; axis++) { + std::vector shape({}); + int vectorLength; + if (rank == 3) { + shape = std::vector({32, 128, 128}); + vectorLength = 128; + } else if (rank == 4) { + shape = std::vector({16, 64, 64, 64}); + vectorLength = 64; + } else if (rank == 5) { + shape = std::vector({16, 48, 48, 48, 48}); + vectorLength = 48; + } + + ParametersBatch batch({}); + + // Note: always inplace here + auto generator = PARAMETRIC_XYZ() { + auto arr = NDArrayFactory::create('c', shape); + x.push_back(arr); + y.push_back(NDArrayFactory::create('c', {vectorLength})); + z.push_back(arr); + }; + + std::string name("bArr-r"); + name += std::to_string(rank); + name += "a"; + name += std::to_string(axis); + BroadcastBenchmark bAdd(broadcast::Add, name, {axis}); + std::string n2("Broadcast Add - Rank"); + n2 += std::to_string(rank); + n2 += " - axis="; + n2 += std::to_string(axis); + output += helper.runOperationSuit(&bAdd, generator, batch, n2.c_str()); } + } - static std::string fastPairwiseBenchmark() { - std::string output; - BenchmarkHelper helper(wIterations, rIterations); - IntPowerParameters length("length", 2, 10, pairwisePowLimit, 4); //2^10 to 2^26 in steps of 4 -> max is 512mb - BoolParameters inplace("inplace"); + return output; +} - ParametersBatch batch({&length, &inplace}); +static std::string fastStridedReductionNonEws() { + std::string output; + BenchmarkHelper helper(wIterations, rIterations); - auto generator = PARAMETRIC_XYZ() { - auto arr1 = NDArrayFactory::create_('c', {p.getIntParam("length")}); - auto arr2 = NDArrayFactory::create_('c', {p.getIntParam("length")}); - x.push_back(arr1); - y.push_back(arr2); - if(p.getIntParam("inplace") == 1){ - z.push_back(arr1); - } else { - z.push_back(NDArrayFactory::create_('c', {p.getIntParam("length")})); - } - }; + IntPowerParameters stride("stride", 2, 0, 10, 2); // 2^0=1, ..., 2^10=1024 - PairwiseBenchmark pb1(pairwise::Ops::Add, "Add"); - output += helper.runOperationSuit(&pb1, generator, batch, "Pairwise Add"); + ParametersBatch batch({&stride}); - PairwiseBenchmark pb2(pairwise::Ops::Add, "Multiply"); - output += helper.runOperationSuit(&pb2, generator, batch, "Pairwise Multiply"); + // This is an edge case: technically an EWS *should* be available here + auto generator1 = PARAMETRIC_XYZ() { + auto stride = p.getIntParam("stride"); + auto arr = NDArrayFactory::create( + 'c', {131072 + (stride == 1 ? 0 : 1), stride}); - return output; + NDArray strided; + if (stride == 1) { + strided = arr; + } else { + IndicesList indices( + {NDIndex::interval(0, 131072), NDIndex::interval(0, 1)}); + strided = arr.subarray(indices); // All rows, first column } - static std::string heavyTransformsBenchmark() { - std::string output; - BenchmarkHelper helper(wIterations, rIterations); - IntPowerParameters length("length", 2, 10, heavyPowLimit, 4); //2^10 to 2^22, steps of 4 - BoolParameters inplace("inplace"); - - ParametersBatch batch({&length, &inplace}); - - auto generator = PARAMETRIC_XZ() { - auto arr = NDArrayFactory::create_('c', {p.getIntParam("length")}); - arr->assign(1.0); - x.push_back(arr); - if (p.getIntParam("inplace") == 1) { - z.push_back(arr); - } else { - z.push_back(NDArrayFactory::create_('c', {p.getIntParam("length")})); - } - }; - - //Ops to test: erf (transform), betainc (custom), polygamma, synthetic ops? - TransformBenchmark erf(transform::StrictOps::Erf, "Erf"); - output += helper.runOperationSuit(&erf, generator, batch, "Error Function (Erf)"); - - ParametersBatch batch2({&length}); - sd::ops::polygamma op1; - DeclarableBenchmark pg(op1, "polygamma"); - auto generator2 = PARAMETRIC_D() { - auto ctx = new Context(1); - auto in0 = NDArrayFactory::create_('c', {p.getIntParam("length")}); - in0->assign(0.25); - auto in1 = NDArrayFactory::create_('c', {p.getIntParam("length")}); - in1->assign(0.5); - ctx->setInputArray(0, in0, true); - ctx->setInputArray(1, in1, true); - ctx->setOutputArray(0, NDArrayFactory::create_('c', {p.getIntParam("length")}), true); - return ctx; - }; - - - IntPowerParameters lengthBetaInc("length", 2, 10, heavyPowLimit, 4); //2^10 to 2^22 in steps of 4 - ParametersBatch batch3({&lengthBetaInc}); - sd::ops::betainc op2; - DeclarableBenchmark binc(op2, "betainc"); - auto generator3 = PARAMETRIC_D() { - auto ctx = new Context(1); - auto in0 = NDArrayFactory::create_('c', {p.getIntParam("length")}); - in0->assign(0.25); - auto in1 = NDArrayFactory::create_('c', {p.getIntParam("length")}); - in1->assign(0.5); - auto in2 = NDArrayFactory::create_('c', {p.getIntParam("length")}); - in2->assign(0.75); - ctx->setInputArray(0, in0, true); - ctx->setInputArray(1, in1, true); - ctx->setInputArray(2, in2, true); - ctx->setOutputArray(0, NDArrayFactory::create_('c', {p.getIntParam("length")}), true); - return ctx; - }; - - output += helper.runOperationSuit(&pg, generator2, batch2, "PolyGamma Function"); - output += helper.runOperationSuit(&binc, generator3, batch3, "Incomplete Beta Function (BetaInc)"); + strided.assign(1.0); + x.push_back(strided); + y.push_back(NDArray()); + z.push_back(NDArrayFactory::create(0.0f)); + }; + + ReductionBenchmark rbSum(reduce::SameOps::Sum, "stridedSum"); + output += helper.runOperationSuit( + &rbSum, + (const std::function)(generator1), + batch, "Strided Sum - No EWS Test 1"); + + // No EWS defined for this case + auto generator2 = PARAMETRIC_XYZ() { + auto stride = p.getIntParam("stride"); + auto arr = NDArrayFactory::create( + 'c', {(stride == 1 ? 1 : 2) * 1024, 1024, stride}); + + NDArray strided; + if (stride == 1) { + strided = arr; + } else { + IndicesList indices({NDIndex::interval(0, 2 * 1024, 2), NDIndex::all(), + NDIndex::interval(0, 1)}); + strided = arr.subarray(indices); + } - return output; + strided.assign(1.0); + x.push_back(strided); + y.push_back(NDArray()); + z.push_back(NDArrayFactory::create(0.0f)); + }; + + ReductionBenchmark rbSum2(reduce::SameOps::Sum, "stridedSumNoEWS"); + output += helper.runOperationSuit( + &rbSum2, + (const std::function)(generator2), + batch, "Strided Sum - No EWS Test 2"); + + return output; +} + +static std::string fastStridedReductionIrregular() { + std::string output; + BenchmarkHelper helper(wIterations, rIterations); + + IntPowerParameters length("length", 2, 12, stridedReductionPowLimit, + 4); // 2^12 to 2^20 in steps of 4 + PredefinedParameters stride( + "stride", + {26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, + 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, + 1018, 1019, 1020, 1021, 1022, 1023, 1024, 1025, 1026, 1027, 1028}); + + ParametersBatch batch({&length, &stride}); + + auto generator = PARAMETRIC_XYZ() { + auto stride = p.getIntParam("stride"); + auto arr = + NDArrayFactory::create('c', {p.getIntParam("length"), stride}); + + NDArray strided; + if (stride == 1) { + strided = arr; + } else { + IndicesList indices({NDIndex::all(), NDIndex::interval(0, 1)}); + strided = arr.subarray(indices); // All rows, first column } - static std::string intermediateTransformsBenchmark() { - std::string output; - - //Non-inplace: 2x 2^26 elements FP32 -> 512MB - BenchmarkHelper helper(wIterations, rIterations); - IntPowerParameters length("length", 2, 10, intermediateTransformPowLimit, 4); //2^20 to 2^22 in steps of 4 - BoolParameters inplace("inplace"); - - ParametersBatch batch({&length, &inplace}); - - auto generator = PARAMETRIC_XZ() { - auto arr = NDArrayFactory::create_('c', {p.getIntParam("length")}); - arr->assign(1.0); - x.push_back(arr); - if(p.getIntParam("inplace") == 1){ - z.push_back(arr); - } else { - z.push_back(NDArrayFactory::create_('c', {p.getIntParam("length")})); - } - }; + strided.assign(1.0); + x.push_back(strided); + y.push_back(NDArray()); + z.push_back(NDArrayFactory::create(0.0f)); + }; - TransformBenchmark tbTanh(transform::StrictOps::Tanh, "tanh"); - TransformBenchmark tbGelu(transform::StrictOps::GELU, "gelu"); + ReductionBenchmark rbSum(reduce::SameOps::Sum, "stridedSum"); - output += helper.runOperationSuit(&tbTanh, generator, batch, "Tanh"); - output += helper.runOperationSuit(&tbGelu, generator, batch, "gelu"); + output += helper.runOperationSuit( + &rbSum, + (const std::function)(generator), + batch, "Strided Sum - Irregular Strides"); + return output; +} - //2x 1024 cols x 2^18 = 2GB - IntPowerParameters rows("rows", 2, 10, intermediateTransformPowLimit2, 4); - PredefinedParameters cols("cols", {4, 128, 1024}); +static std::string fastStridedReductionsRegular() { + std::string output; + BenchmarkHelper helper(wIterations, rIterations); - ParametersBatch batch2({&rows, &cols, &inplace}); + IntPowerParameters length("length", 2, 12, stridedReductionPowLimit, + 4); // 2^12 to 2^20 in steps of 4 + IntPowerParameters stride("stride", 2, 0, 10); // 2^0=1, ..., 2^10=1024 - auto generator2 = PARAMETRIC_XZ() { - auto arr = NDArrayFactory::create_('c', {p.getIntParam("rows"), p.getIntParam("cols")}); - arr->assign(1.0); - x.push_back(arr); - if(p.getIntParam("inplace") == 1){ - z.push_back(arr); - } else { - z.push_back(NDArrayFactory::create_('c', {p.getIntParam("rows"), p.getIntParam("cols")})); - } - }; + ParametersBatch batch({&length, &stride}); - //TransformBenchmark tbSoftmax(transform::StrictOps::SoftMax, "softmax"); + auto generator = PARAMETRIC_XYZ() { + auto stride = p.getIntParam("stride"); + auto arr = + NDArrayFactory::create('c', {p.getIntParam("length"), stride}); - //output += helper.runOperationSuit(&tbSoftmax, generator2, batch2, "Softmax"); - - return output; + NDArray strided; + if (stride == 1) { + strided = arr; + } else { + IndicesList indices({NDIndex::all(), NDIndex::point(0)}); + strided = arr.subarray(indices); // All rows, first column } - static std::string fastTransformsBenchmark() { - std::string output; - BenchmarkHelper helper(wIterations, rIterations); - IntPowerParameters length("length", 2, 10, transformBenchmarkPowLimit, 4); //2^10 to 2^30 in steps of 4 - 2^10, 2^14, ..., 2^26 - BoolParameters inplace("inplace"); - - ParametersBatch batch({&length, &inplace}); - - auto generator = PARAMETRIC_XZ() { - auto arr = NDArrayFactory::create_('c', {p.getIntParam("length")}); - arr->assign(1.0); - x.push_back(arr); - if(p.getIntParam("inplace") == 1){ - z.push_back(arr); - } else { - z.push_back(NDArrayFactory::create_('c', {p.getIntParam("length")})); - } - }; - - ScalarBenchmark sbLRelu(scalar::Ops::LeakyRELU, "LeakyRELU"); - sbLRelu.setY(NDArrayFactory::create_(0.0)); - - TransformBenchmark tbAbs(transform::SameOps::Abs, "abs"); - TransformBenchmark tbExp(transform::StrictOps::Exp, "exp"); - - output += helper.runOperationSuit(&sbLRelu, generator, batch, "LeakyRELU"); - output += helper.runOperationSuit(&tbAbs, generator, batch, "Abs"); - output += helper.runOperationSuit(&tbExp, generator, batch, "Exp"); - - return output; + strided.assign(1.0); + x.push_back(strided); + y.push_back(NDArray()); + // z.push_back(NDArrayFactory::create(0.0f)); + z.push_back(NDArrayFactory::create('c', {1})); + }; + + ReductionBenchmark rbSum(reduce::SameOps::Sum, "Strided Sum"); + + output += helper.runOperationSuit( + &rbSum, + (const std::function)(generator), + batch, "Strided Sum - Regular Strides (powers of 2)"); + + auto generator3 = PARAMETRIC_D() { + auto ctx = new Context(1); + auto stride = p.getIntParam("stride"); + auto arr = + NDArrayFactory::create('c', {p.getIntParam("length"), stride}); + + NDArray strided; + if (stride == 1) { + strided = arr; + } else { + IndicesList indices({NDIndex::all(), NDIndex::point(0)}); + strided = arr.subarray(indices); // All rows, first column } - static std::string fastScalarBenchmark() { - std::string output; - BenchmarkHelper helper(wIterations, rIterations); - - IntPowerParameters length("length", 2, 10, scalarBenchmarkPowLimit, 4); //2^10 to 2^30 in steps of 4 - 2^10, 2^14, ..., 2^26 - BoolParameters inplace("inplace"); - - ParametersBatch batch({&length, &inplace}); - - auto generator = PARAMETRIC_XZ() { - auto arr = NDArrayFactory::create_('c', {p.getIntParam("length")}); - arr->assign(1.0); - x.push_back(arr); - if(p.getIntParam("inplace") == 1){ - z.push_back(arr); - } else { - z.push_back(NDArrayFactory::create_('c', {p.getIntParam("length")})); - } - }; - - ScalarBenchmark sbAdd(scalar::Ops::Add, "sAdd"); - ScalarBenchmark sbDiv(scalar::Ops::Divide, "sDiv"); - ScalarBenchmark sbPow(scalar::Ops::Pow, "sPow"); - - - sbAdd.setY(NDArrayFactory::create_(3.14159265359)); - sbDiv.setY(NDArrayFactory::create_(3.14159265359)); - sbPow.setY(NDArrayFactory::create_(3.14159265359)); - - - output += helper.runOperationSuit(&sbAdd, generator, batch, "Scalar Addition - x.add(3.14159265359) - F32"); - output += helper.runOperationSuit(&sbDiv, generator, batch, "Scalar Division - x.div(3.14159265359) - F32"); - output += helper.runOperationSuit(&sbPow, generator, batch, "Scalar Power - x.pow(3.14159265359) - F32"); - - return output; + strided.assign(1.0); + ctx->setInputArray(0, strided); + ctx->setOutputArray(0, NDArrayFactory::create('c', {1})); + auto iargs = new Nd4jLong[1]; + iargs[0] = 0; + ctx->setIArguments(iargs, 1); + delete[] iargs; + return ctx; + }; + + sd::ops::argmax opArgmax; + DeclarableBenchmark dbArgmax(opArgmax, "stridedArgmax"); + output += + helper.runOperationSuit(&dbArgmax, generator3, batch, "Strided Argmax"); + return output; +} + +static std::string fastReduceAlongDimBenchmark() { + std::string output; + BenchmarkHelper helper(wIterations, rIterations); + + int length[] = {1024 * 1024, 64 * 1024 * 1024}; + int powLimit[] = {10, 20, 26}; + int powStep[] = {2, 2, 4}; + + for (int i = 0; i < limit3; i++) { + IntPowerParameters rows("rows", 2, 0, powLimit[i], powStep[i]); + BoolParameters dim("dim"); + + ParametersBatch batch({&rows, &dim}); + + auto generator = PARAMETRIC_XYZ() { + int rows = p.getIntParam("rows"); + int cols = length[i] / rows; + int dim = p.getIntParam("dim"); + auto arr = NDArrayFactory::create('c', {rows, cols}); + + x.push_back(arr); + y.push_back(NDArrayFactory::create(dim)); + + NDArray result; + if (dim == 0) { + result = NDArrayFactory::create('c', {cols}); + } else { + result = NDArrayFactory::create('c', {rows}); + } + z.push_back(result); + }; + + ReductionBenchmark rbSum(reduce::SameOps::Sum, "sum"); + ReductionBenchmark rbMax(reduce::SameOps::Max, "max"); + + std::string s1("Sum Along Dimension - "); + s1 += std::to_string(length[i]); + + output += helper.runOperationSuit( + &rbSum, + (const std::function)(generator), + batch, s1.c_str()); + + auto generator3 = PARAMETRIC_D() { + auto ctx = new Context(1); + int rows = p.getIntParam("rows"); + int cols = length[i] / rows; + int dim = p.getIntParam("dim"); + auto arr = NDArrayFactory::create('c', {rows, cols}); + + Nd4jLong *dimArg = new Nd4jLong[1]; + dimArg[0] = dim; + ctx->setIArguments(dimArg, 1); + delete[] dimArg; + + ctx->setInputArray(0, arr); + + NDArray result; + if (dim == 0) { + result = NDArrayFactory::create('c', {cols}); + } else { + result = NDArrayFactory::create('c', {rows}); + } + ctx->setOutputArray(0, result); + return ctx; + }; + + std::string s5("Argmax Along Dimension - "); + s5 += std::to_string(length[i]); + + sd::ops::argmax opArgmax; + DeclarableBenchmark dbArgmax(opArgmax, "Argmax"); + output += helper.runOperationSuit(&dbArgmax, generator3, batch, s5.c_str()); + } + + return output; +} + +static std::string fastReduceToScalarBenchmark() { + std::string output; + BenchmarkHelper helper(wIterations, rIterations); + + IntPowerParameters length("length", 2, 10, reduceScalarPowLimit, + 4); // 2^10 to 2^26 in steps of 4 + + ParametersBatch batch({&length}); + + auto generator = PARAMETRIC_XYZ() { + auto arr = NDArrayFactory::create('c', {p.getIntParam("length")}); + + x.push_back(arr); + y.push_back(NDArray()); + z.push_back(NDArrayFactory::create(0.0f)); + }; + + ReductionBenchmark rbSum(reduce::SameOps::Sum, "sum"); + + output += helper.runOperationSuit( + &rbSum, + (const std::function)(generator), + batch, "Sum - Full Array Reduction"); + + // Index reduction + sd::ops::argmax opArgmax; + DeclarableBenchmark dbArgmax(opArgmax, "Argmax"); + auto generator3 = PARAMETRIC_D() { + auto ctx = new Context(1); + + ctx->setInputArray( + 0, NDArrayFactory::create('c', {p.getIntParam("length")})); + ctx->setInputArray(1, NDArrayFactory::create((Nd4jLong)0)); + ctx->setOutputArray(0, NDArrayFactory::create(0)); + + return ctx; + }; + output += helper.runOperationSuit(&dbArgmax, generator3, batch, + "Argmax Full Array Reduction"); + + return output; +} + +static std::string fastNonEwsTransformBenchmark() { + std::string output; + BenchmarkHelper helper(wIterations, rIterations); + IntPowerParameters rowcol("rowcol", 2, 2, nonEwsPowLimit, + 4); // 2^2 to 2^14 in steps of 4 -> non-inplace + // case: 2x 2^10 x 2^10 = 128mb + BoolParameters inplace("inplace"); + + ParametersBatch batch({&rowcol, &inplace}); + + auto generator = PARAMETRIC_XZ() { + int r = p.getIntParam("rowcol"); + auto arr = NDArrayFactory::create('c', {r, r + 1}); + IndicesList indices({NDIndex::all(), NDIndex::interval(0, r - 1)}); + auto view = arr.subarray(indices); + // nd4j_printf("VIEW ARRAY: rows=%lld, columns=%lld", view->sizeAt(0), + // view->sizeAt(1)); + x.push_back(view); + if (p.getIntParam("inplace") == 1) { + z.push_back(view); + } else { + z.push_back( + NDArrayFactory::create('c', {view.sizeAt(0), view.sizeAt(1)})); + } + }; + + ScalarBenchmark sbLRelu(scalar::Ops::LeakyRELU, "LeakyRELU_View"); + sbLRelu.setY(NDArrayFactory::create(0.0)); + + TransformBenchmark tbExp(transform::StrictOps::Exp, "exp view"); + + output += + helper.runOperationSuit(&sbLRelu, generator, batch, "LeakyRELU View"); + output += helper.runOperationSuit(&tbExp, generator, batch, "Exp View"); + + return output; +} + +static std::string fastPairwiseBenchmark() { + std::string output; + BenchmarkHelper helper(wIterations, rIterations); + IntPowerParameters length("length", 2, 10, pairwisePowLimit, + 4); // 2^10 to 2^26 in steps of 4 -> max is 512mb + BoolParameters inplace("inplace"); + + ParametersBatch batch({&length, &inplace}); + + auto generator = PARAMETRIC_XYZ() { + auto arr1 = NDArrayFactory::create('c', {p.getIntParam("length")}); + auto arr2 = NDArrayFactory::create('c', {p.getIntParam("length")}); + x.push_back(arr1); + y.push_back(arr2); + if (p.getIntParam("inplace") == 1) { + z.push_back(arr1); + } else { + z.push_back( + NDArrayFactory::create('c', {p.getIntParam("length")})); + } + }; + + PairwiseBenchmark pb1(pairwise::Ops::Add, "Add"); + output += helper.runOperationSuit(&pb1, generator, batch, "Pairwise Add"); + + PairwiseBenchmark pb2(pairwise::Ops::Add, "Multiply"); + output += + helper.runOperationSuit(&pb2, generator, batch, "Pairwise Multiply"); + + return output; +} + +static std::string heavyTransformsBenchmark() { + std::string output; + BenchmarkHelper helper(wIterations, rIterations); + IntPowerParameters length("length", 2, 10, heavyPowLimit, + 4); // 2^10 to 2^22, steps of 4 + BoolParameters inplace("inplace"); + + ParametersBatch batch({&length, &inplace}); + + auto generator = PARAMETRIC_XZ() { + auto arr = NDArrayFactory::create('c', {p.getIntParam("length")}); + arr.assign(1.0); + x.push_back(arr); + if (p.getIntParam("inplace") == 1) { + z.push_back(arr); + } else { + z.push_back( + NDArrayFactory::create('c', {p.getIntParam("length")})); + } + }; + + // Ops to test: erf (transform), betainc (custom), polygamma, synthetic ops? + TransformBenchmark erf(transform::StrictOps::Erf, "Erf"); + output += + helper.runOperationSuit(&erf, generator, batch, "Error Function (Erf)"); + + ParametersBatch batch2({&length}); + sd::ops::polygamma op1; + DeclarableBenchmark pg(op1, "polygamma"); + auto generator2 = PARAMETRIC_D() { + auto ctx = new Context(1); + auto in0 = NDArrayFactory::create('c', {p.getIntParam("length")}); + in0.assign(0.25); + auto in1 = NDArrayFactory::create('c', {p.getIntParam("length")}); + in1.assign(0.5); + ctx->setInputArray(0, in0); + ctx->setInputArray(1, in1); + ctx->setOutputArray( + 0, NDArrayFactory::create('c', {p.getIntParam("length")})); + return ctx; + }; + + IntPowerParameters lengthBetaInc("length", 2, 10, heavyPowLimit, + 4); // 2^10 to 2^22 in steps of 4 + ParametersBatch batch3({&lengthBetaInc}); + sd::ops::betainc op2; + DeclarableBenchmark binc(op2, "betainc"); + auto generator3 = PARAMETRIC_D() { + auto ctx = new Context(1); + auto in0 = NDArrayFactory::create('c', {p.getIntParam("length")}); + in0.assign(0.25); + auto in1 = NDArrayFactory::create('c', {p.getIntParam("length")}); + in1.assign(0.5); + auto in2 = NDArrayFactory::create('c', {p.getIntParam("length")}); + in2.assign(0.75); + ctx->setInputArray(0, in0); + ctx->setInputArray(1, in1); + ctx->setInputArray(2, in2); + ctx->setOutputArray( + 0, NDArrayFactory::create('c', {p.getIntParam("length")})); + return ctx; + }; + + output += + helper.runOperationSuit(&pg, generator2, batch2, "PolyGamma Function"); + output += helper.runOperationSuit(&binc, generator3, batch3, + "Incomplete Beta Function (BetaInc)"); + + return output; +} + +static std::string intermediateTransformsBenchmark() { + std::string output; + + // Non-inplace: 2x 2^26 elements FP32 -> 512MB + BenchmarkHelper helper(wIterations, rIterations); + IntPowerParameters length("length", 2, 10, intermediateTransformPowLimit, + 4); // 2^20 to 2^22 in steps of 4 + BoolParameters inplace("inplace"); + + ParametersBatch batch({&length, &inplace}); + + auto generator = PARAMETRIC_XZ() { + auto arr = NDArrayFactory::create('c', {p.getIntParam("length")}); + arr.assign(1.0); + x.push_back(arr); + if (p.getIntParam("inplace") == 1) { + z.push_back(arr); + } else { + z.push_back( + NDArrayFactory::create('c', {p.getIntParam("length")})); + } + }; + + TransformBenchmark tbTanh(transform::StrictOps::Tanh, "tanh"); + TransformBenchmark tbGelu(transform::StrictOps::GELU, "gelu"); + + output += helper.runOperationSuit(&tbTanh, generator, batch, "Tanh"); + output += helper.runOperationSuit(&tbGelu, generator, batch, "gelu"); + + // 2x 1024 cols x 2^18 = 2GB + IntPowerParameters rows("rows", 2, 10, intermediateTransformPowLimit2, 4); + PredefinedParameters cols("cols", {4, 128, 1024}); + + ParametersBatch batch2({&rows, &cols, &inplace}); + + auto generator2 = PARAMETRIC_XZ() { + auto arr = NDArrayFactory::create( + 'c', {p.getIntParam("rows"), p.getIntParam("cols")}); + arr.assign(1.0); + x.push_back(arr); + if (p.getIntParam("inplace") == 1) { + z.push_back(arr); + } else { + z.push_back(NDArrayFactory::create( + 'c', {p.getIntParam("rows"), p.getIntParam("cols")})); } + }; + + // TransformBenchmark tbSoftmax(transform::StrictOps::SoftMax, "softmax"); + + // output += helper.runOperationSuit(&tbSoftmax, generator2, batch2, + // "Softmax"); + + return output; +} + +static std::string fastTransformsBenchmark() { + std::string output; + BenchmarkHelper helper(wIterations, rIterations); + IntPowerParameters length( + "length", 2, 10, transformBenchmarkPowLimit, + 4); // 2^10 to 2^30 in steps of 4 - 2^10, 2^14, ..., 2^26 + BoolParameters inplace("inplace"); + + ParametersBatch batch({&length, &inplace}); + + auto generator = PARAMETRIC_XZ() { + auto arr = NDArrayFactory::create('c', {p.getIntParam("length")}); + arr.assign(1.0); + x.push_back(arr); + if (p.getIntParam("inplace") == 1) { + z.push_back(arr); + } else { + z.push_back( + NDArrayFactory::create('c', {p.getIntParam("length")})); + } + }; + ScalarBenchmark sbLRelu(scalar::Ops::LeakyRELU, "LeakyRELU"); + sbLRelu.setY(NDArrayFactory::create(0.0)); - static long nowMs(){ - auto s = std::chrono::system_clock::now().time_since_epoch(); - auto v = std::chrono::duration_cast(s).count(); - return v; - } + TransformBenchmark tbAbs(transform::SameOps::Abs, "abs"); + TransformBenchmark tbExp(transform::StrictOps::Exp, "exp"); - static long duration(long start){ - return nowMs() - start; - } + output += helper.runOperationSuit(&sbLRelu, generator, batch, "LeakyRELU"); + output += helper.runOperationSuit(&tbAbs, generator, batch, "Abs"); + output += helper.runOperationSuit(&tbExp, generator, batch, "Exp"); - static long done(long start){ - long dur = duration(start); - nd4j_printf("Done: %i ms\n", dur); - return nowMs(); - } + return output; +} +static std::string fastScalarBenchmark() { + std::string output; + BenchmarkHelper helper(wIterations, rIterations); - std::string FullBenchmarkSuit::runSuit() { - std::string result; - - long start = nowMs(); - - // set 1 - nd4j_printf("Running FullBenchmarkSuite.fastScalarBenchmark\n", ""); - result += fastScalarBenchmark(); - start = done(start); - nd4j_printf("Running FullBenchmarkSuite.fastTransformsBenchmark\n", ""); - result += fastTransformsBenchmark(); - start = done(start); - nd4j_printf("Running FullBenchmarkSuite.intermediateTransformsBenchmark\n", ""); - result += intermediateTransformsBenchmark(); - start = done(start); - nd4j_printf("Running FullBenchmarkSuite.fastPairwiseBenchmark\n", ""); - result += fastPairwiseBenchmark(); - start = done(start); - nd4j_printf("Running FullBenchmarkSuite.heavyTransformsBenchmark\n", ""); - result += heavyTransformsBenchmark(); - start = done(start); - nd4j_printf("Running FullBenchmarkSuite.fastNonEwsTransformBenchmark\n", ""); - result += fastNonEwsTransformBenchmark(); - start = done(start); - - // set 2 - nd4j_printf("Running FullBenchmarkSuite.fastReduceToScalarBenchmark\n", ""); - result += fastReduceToScalarBenchmark(); - start = done(start); - nd4j_printf("Running FullBenchmarkSuite.fastReduceAlongDimBenchmark\n", ""); - result += fastReduceAlongDimBenchmark(); - start = done(start); - nd4j_printf("Running FullBenchmarkSuite.fastStridedReductionsRegular\n", ""); - result += fastStridedReductionsRegular(); - start = done(start); - nd4j_printf("Running FullBenchmarkSuite.fastStridedReductionIrregular\n", ""); - result += fastStridedReductionIrregular(); - start = done(start); - nd4j_printf("Running FullBenchmarkSuite.fastStridedReductionNonEws\n", ""); - result += fastStridedReductionNonEws(); - start = done(start); - nd4j_printf("Running FullBenchmarkSuite.broadcastBenchmark\n", ""); - result += broadcastBenchmark(); - start = done(start); - nd4j_printf("Running FullBenchmarkSuite.broadcast2dBenchmark\n", ""); - result += broadcast2dBenchmark(); - start = done(start); - nd4j_printf("Running FullBenchmarkSuite.broadcastOpsMatrixBenchmark\n", ""); - result += broadcastOpsMatrixBenchmark(); - start = done(start); - nd4j_printf("Running FullBenchmarkSuite.mismatchedOrdersAssignBenchmark\n", ""); - result += mismatchedOrdersAssignBenchmark(); - start = done(start); - - - // set 3 - nd4j_printf("Running FullBenchmarkSuite.gatherOpBenchmark\n", ""); - result += gatherOpBenchmark(); - start = done(start); - nd4j_printf("Running FullBenchmarkSuite.scatterOpBenchmark\n", ""); - result += scatterOpBenchmark(); - start = done(start); - - // set 4 - nd4j_printf("Running FullBenchmarkSuite.gemmRegularBenchmark\n", ""); - result += gemmRegularBenchmark(); - start = done(start); - nd4j_printf("Running FullBenchmarkSuite.gemmIrregularBenchmark\n", ""); - result += gemmIrregularBenchmark(); - start = done(start); - nd4j_printf("Running FullBenchmarkSuite.rngBenchmark\n", ""); - result += rngBenchmark(); - start = done(start); - nd4j_printf("Running FullBenchmarkSuite.conv2dBenchmark\n", ""); - result += conv2dBenchmark(); - start = done(start); - nd4j_printf("Running FullBenchmarkSuite.pool2dBenchmark\n", ""); - result += pool2dBenchmark(); - start = done(start); - nd4j_printf("Running FullBenchmarkSuite.batchnormBenchmark\n", ""); - result += batchnormBenchmark(); - start = done(start); - - nd4j_printf("Running FullBenchmarkSuite.lstmBenchmark\n", ""); - result += lstmBenchmark(); - start = done(start); - nd4j_printf("Running FullBenchmarkSuite.conv3dBenchmark\n", ""); - result += conv3dBenchmark(); - start = done(start); - nd4j_printf("Running FullBenchmarkSuite.maxPool3DBenchmark\n", ""); - result += maxPool3DBenchmark(); - start = done(start); -// nd4j_printf("Running FullBenchmarkSuite.layerNormBenchmark\n", ""); -// result += layerNormBenchmark(); -// start = done(start); - - return result; - } + IntPowerParameters length( + "length", 2, 10, scalarBenchmarkPowLimit, + 4); // 2^10 to 2^30 in steps of 4 - 2^10, 2^14, ..., 2^26 + BoolParameters inplace("inplace"); + ParametersBatch batch({&length, &inplace}); -} \ No newline at end of file + auto generator = PARAMETRIC_XZ() { + auto arr = NDArrayFactory::create('c', {p.getIntParam("length")}); + arr.assign(1.0); + x.push_back(arr); + if (p.getIntParam("inplace") == 1) { + z.push_back(arr); + } else { + z.push_back( + NDArrayFactory::create('c', {p.getIntParam("length")})); + } + }; + + ScalarBenchmark sbAdd(scalar::Ops::Add, "sAdd"); + ScalarBenchmark sbDiv(scalar::Ops::Divide, "sDiv"); + ScalarBenchmark sbPow(scalar::Ops::Pow, "sPow"); + + sbAdd.setY(NDArrayFactory::create(3.14159265359)); + sbDiv.setY(NDArrayFactory::create(3.14159265359)); + sbPow.setY(NDArrayFactory::create(3.14159265359)); + + output += helper.runOperationSuit( + &sbAdd, generator, batch, "Scalar Addition - x.add(3.14159265359) - F32"); + output += helper.runOperationSuit( + &sbDiv, generator, batch, "Scalar Division - x.div(3.14159265359) - F32"); + output += helper.runOperationSuit( + &sbPow, generator, batch, "Scalar Power - x.pow(3.14159265359) - F32"); + + return output; +} + +static long nowMs() { + auto s = std::chrono::system_clock::now().time_since_epoch(); + auto v = std::chrono::duration_cast(s).count(); + return v; +} + +static long duration(long start) { return nowMs() - start; } + +static long done(long start) { + long dur = duration(start); + nd4j_printf("Done: %i ms\n", dur); + return nowMs(); +} + +std::string FullBenchmarkSuit::runSuit() { + std::string result; + + long start = nowMs(); + + // set 1 + nd4j_printf("Running FullBenchmarkSuite.fastScalarBenchmark\n", ""); + result += fastScalarBenchmark(); + start = done(start); + nd4j_printf("Running FullBenchmarkSuite.fastTransformsBenchmark\n", ""); + result += fastTransformsBenchmark(); + start = done(start); + nd4j_printf("Running FullBenchmarkSuite.intermediateTransformsBenchmark\n", + ""); + result += intermediateTransformsBenchmark(); + start = done(start); + nd4j_printf("Running FullBenchmarkSuite.fastPairwiseBenchmark\n", ""); + result += fastPairwiseBenchmark(); + start = done(start); + nd4j_printf("Running FullBenchmarkSuite.heavyTransformsBenchmark\n", ""); + result += heavyTransformsBenchmark(); + start = done(start); + nd4j_printf("Running FullBenchmarkSuite.fastNonEwsTransformBenchmark\n", ""); + result += fastNonEwsTransformBenchmark(); + start = done(start); + + // set 2 + nd4j_printf("Running FullBenchmarkSuite.fastReduceToScalarBenchmark\n", ""); + result += fastReduceToScalarBenchmark(); + start = done(start); + nd4j_printf("Running FullBenchmarkSuite.fastReduceAlongDimBenchmark\n", ""); + result += fastReduceAlongDimBenchmark(); + start = done(start); + nd4j_printf("Running FullBenchmarkSuite.fastStridedReductionsRegular\n", ""); + result += fastStridedReductionsRegular(); + start = done(start); + nd4j_printf("Running FullBenchmarkSuite.fastStridedReductionIrregular\n", ""); + result += fastStridedReductionIrregular(); + start = done(start); + nd4j_printf("Running FullBenchmarkSuite.fastStridedReductionNonEws\n", ""); + result += fastStridedReductionNonEws(); + start = done(start); + nd4j_printf("Running FullBenchmarkSuite.broadcastBenchmark\n", ""); + result += broadcastBenchmark(); + start = done(start); + nd4j_printf("Running FullBenchmarkSuite.broadcast2dBenchmark\n", ""); + result += broadcast2dBenchmark(); + start = done(start); + nd4j_printf("Running FullBenchmarkSuite.broadcastOpsMatrixBenchmark\n", ""); + result += broadcastOpsMatrixBenchmark(); + start = done(start); + nd4j_printf("Running FullBenchmarkSuite.mismatchedOrdersAssignBenchmark\n", + ""); + result += mismatchedOrdersAssignBenchmark(); + start = done(start); + + // set 3 + nd4j_printf("Running FullBenchmarkSuite.gatherOpBenchmark\n", ""); + result += gatherOpBenchmark(); + start = done(start); + nd4j_printf("Running FullBenchmarkSuite.scatterOpBenchmark\n", ""); + result += scatterOpBenchmark(); + start = done(start); + + // set 4 + nd4j_printf("Running FullBenchmarkSuite.gemmRegularBenchmark\n", ""); + result += gemmRegularBenchmark(); + start = done(start); + nd4j_printf("Running FullBenchmarkSuite.gemmIrregularBenchmark\n", ""); + result += gemmIrregularBenchmark(); + start = done(start); + nd4j_printf("Running FullBenchmarkSuite.rngBenchmark\n", ""); + result += rngBenchmark(); + start = done(start); + nd4j_printf("Running FullBenchmarkSuite.conv2dBenchmark\n", ""); + result += conv2dBenchmark(); + start = done(start); + nd4j_printf("Running FullBenchmarkSuite.pool2dBenchmark\n", ""); + result += pool2dBenchmark(); + start = done(start); + nd4j_printf("Running FullBenchmarkSuite.batchnormBenchmark\n", ""); + result += batchnormBenchmark(); + start = done(start); + + nd4j_printf("Running FullBenchmarkSuite.lstmBenchmark\n", ""); + result += lstmBenchmark(); + start = done(start); + nd4j_printf("Running FullBenchmarkSuite.conv3dBenchmark\n", ""); + result += conv3dBenchmark(); + start = done(start); + nd4j_printf("Running FullBenchmarkSuite.maxPool3DBenchmark\n", ""); + result += maxPool3DBenchmark(); + start = done(start); + // nd4j_printf("Running FullBenchmarkSuite.layerNormBenchmark\n", ""); + // result += layerNormBenchmark(); + // start = done(start); + + return result; +} + +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/performance/benchmarking/impl/LightBenchmarkSuit.cpp b/libnd4j/include/performance/benchmarking/impl/LightBenchmarkSuit.cpp index 99a1b05bf92a..2ee5bfead650 100644 --- a/libnd4j/include/performance/benchmarking/impl/LightBenchmarkSuit.cpp +++ b/libnd4j/include/performance/benchmarking/impl/LightBenchmarkSuit.cpp @@ -18,9 +18,10 @@ // @author raver119@gmail.com // -#include #include "performance/benchmarking/LightBenchmarkSuit.h" +#include + #ifdef RELEASE_BUILD #define WARMUP 5 #define NUM_ITER 100 @@ -34,607 +35,670 @@ namespace sd { - template - static std::string transformBenchmark() { - std::string output; - output += "transformBenchmark " + DataTypeUtils::asString(DataTypeUtils::fromT()); - - BenchmarkHelper helper(WARMUP, NUM_ITER); - IntPowerParameters length("length", 2, 8, 20, 4); //2^8, 2^12, 2^16, 2^20 - 4MB - BoolParameters inplace("inplace"); - - ParametersBatch batch({&length, &inplace}); - - auto generator = PARAMETRIC_XZ() { - auto arr = NDArrayFactory::create_('c', {p.getIntParam("length")}); - arr->assign(1.0); - x.push_back(arr); - if(p.getIntParam("inplace") == 1){ - z.push_back(arr); - } else { - z.push_back(NDArrayFactory::create_('c', {p.getIntParam("length")})); - } - }; - - ScalarBenchmark sbRelu(scalar::Ops::RELU, "RELU"); - sbRelu.setY(NDArrayFactory::create_(0.0)); - - TransformBenchmark tbSigmoid(transform::StrictOps::Sigmoid, "sigmoid"); - //TransformBenchmark tbSoftmax(transform::StrictOps::SoftMax, "softmax"); - - output += helper.runOperationSuit(&sbRelu, generator, batch, "RELU"); - output += helper.runOperationSuit(&tbSigmoid, generator, batch, "Sigmoid"); - //output += helper.runOperationSuit(&tbSigmoid, generator, batch, "Softmax"); - - return output; - } - - template - static std::string scalarBenchmark() { - std::string output; - output += "scalarBenchmark " + DataTypeUtils::asString(DataTypeUtils::fromT()); - - BenchmarkHelper helper(WARMUP, NUM_ITER); - - IntPowerParameters length("length", 2, 8, 20, 4); //2^8, 2^12, 2^16, 2^20 - BoolParameters inplace("inplace"); - - ParametersBatch batch({&length, &inplace}); - - auto generator = PARAMETRIC_XZ() { - auto arr = NDArrayFactory::create_('c', {p.getIntParam("length")}); - arr->assign(1.0); - x.push_back(arr); - if(p.getIntParam("inplace") == 1){ - z.push_back(arr); - } else { - z.push_back(NDArrayFactory::create_('c', {p.getIntParam("length")})); - } - }; - - ScalarBenchmark sbAdd(scalar::Ops::Add, "sAdd"); - ScalarBenchmark sbDiv(scalar::Ops::Divide, "sDiv"); - ScalarBenchmark sbPow(scalar::Ops::Pow, "sPow"); - - - sbAdd.setY(NDArrayFactory::create_(3.14159265359)); - sbDiv.setY(NDArrayFactory::create_(3.14159265359)); - sbPow.setY(NDArrayFactory::create_(3.14159265359)); - - - output += helper.runOperationSuit(&sbAdd, generator, batch, "Scalar Addition - x.add(3.14159265359)"); - output += helper.runOperationSuit(&sbDiv, generator, batch, "Scalar Division - x.div(3.14159265359)"); - output += helper.runOperationSuit(&sbPow, generator, batch, "Scalar Power - x.pow(3.14159265359)"); - - return output; +template +static std::string transformBenchmark() { + std::string output; + output += "transformBenchmark " + + DataTypeUtils::asString(DataTypeUtils::fromT()); + + BenchmarkHelper helper(WARMUP, NUM_ITER); + IntPowerParameters length("length", 2, 8, 20, + 4); // 2^8, 2^12, 2^16, 2^20 - 4MB + BoolParameters inplace("inplace"); + + ParametersBatch batch({&length, &inplace}); + + auto generator = PARAMETRIC_XZ() { + auto arr = NDArrayFactory::create('c', {p.getIntParam("length")}); + arr.assign(1.0); + x.push_back(arr); + if (p.getIntParam("inplace") == 1) { + z.push_back(arr); + } else { + z.push_back(NDArrayFactory::create('c', {p.getIntParam("length")})); } + }; + ScalarBenchmark sbRelu(scalar::Ops::RELU, "RELU"); + sbRelu.setY(NDArrayFactory::create(0.0)); - template - static std::string pairwiseBenchmark() { - std::string output; - output += "pairwiseBenchmark " + DataTypeUtils::asString(DataTypeUtils::fromT()); + TransformBenchmark tbSigmoid(transform::StrictOps::Sigmoid, "sigmoid"); + // TransformBenchmark tbSoftmax(transform::StrictOps::SoftMax, "softmax"); - BenchmarkHelper helper(WARMUP, NUM_ITER); - IntPowerParameters length("length", 2, 8, 20, 4); //2^4 to 2^20 in steps of 4 - 2^4, 2^8, 2^16, 2^20 - BoolParameters inplace("inplace"); + output += helper.runOperationSuit(&sbRelu, generator, batch, "RELU"); + output += helper.runOperationSuit(&tbSigmoid, generator, batch, "Sigmoid"); + // output += helper.runOperationSuit(&tbSigmoid, generator, batch, "Softmax"); - ParametersBatch batch({&length, &inplace}); + return output; +} - auto generator = PARAMETRIC_XYZ() { - auto arr1 = NDArrayFactory::create_('c', {p.getIntParam("length")}); - auto arr2 = NDArrayFactory::create_('c', {p.getIntParam("length")}); - x.push_back(arr1); - y.push_back(arr2); - if(p.getIntParam("inplace") == 1){ - z.push_back(arr1); - } else { - z.push_back(NDArrayFactory::create_('c', {p.getIntParam("length")})); - } - }; +template +static std::string scalarBenchmark() { + std::string output; + output += + "scalarBenchmark " + DataTypeUtils::asString(DataTypeUtils::fromT()); - PairwiseBenchmark pb1(pairwise::Ops::Add, "Add"); - output += helper.runOperationSuit(&pb1, generator, batch, "Pairwise Add"); + BenchmarkHelper helper(WARMUP, NUM_ITER); - PairwiseBenchmark pb2(pairwise::Ops::Divide, "Divide"); - output += helper.runOperationSuit(&pb2, generator, batch, "Pairwise Divide"); + IntPowerParameters length("length", 2, 8, 20, 4); // 2^8, 2^12, 2^16, 2^20 + BoolParameters inplace("inplace"); - return output; - } + ParametersBatch batch({&length, &inplace}); - static std::string mismatchedOrderAssign() { - std::string output; - BenchmarkHelper helper(WARMUP, NUM_ITER); - - IntPowerParameters rows("rows", 2, 8, 20, 4); //2^8, 2^12, 2^16, 2^20 - BoolParameters cf("cf"); - - ParametersBatch batch({&rows, &cf}); - - auto generator = PARAMETRIC_XZ() { - int numElements = 4194304; //2^24 - int rows = p.getIntParam("rows"); - int cols = numElements / rows; - bool c = p.getIntParam("cf"); - - auto arr = NDArrayFactory::create_(c ? 'c' : 'f', {rows, cols}); - auto arr2 = NDArrayFactory::create_(c ? 'f' : 'c', {rows, cols}); - x.push_back(arr); - z.push_back(arr2); - }; - - TransformBenchmark tb(transform::AnyOps::Assign, "assign"); - output += helper.runOperationSuit(&tb, generator, batch, "C->F and F->C Assign F32"); - - //Also test: NCHW to NHWC and back - BoolParameters nchw("nchw"); - int mb = 8; - int hw = 64; - int c = 3; - ParametersBatch batch2({&nchw}); - auto generator2 = PARAMETRIC_XZ() { - bool nchw = p.getIntParam("nchw"); - - if(nchw) { - auto orig = NDArrayFactory::create_('c', {mb, c, hw, hw}); - orig->permutei({0,2,3,1}); - x.push_back(orig); - z.push_back(NDArrayFactory::create_('c', {mb, hw, hw, c})); - } else { - auto orig = NDArrayFactory::create_('c', {mb, hw, hw, c}); - orig->permutei({0,3,1,2}); - x.push_back(orig); - z.push_back(NDArrayFactory::create_('c', {mb, c, hw, hw})); - } - }; - - TransformBenchmark tb2(transform::AnyOps::Assign, "assign_nchw"); - output += helper.runOperationSuit(&tb2, generator2, batch2, "nchw->nhwc and nhwc->nchw Assign FP32"); - return output; + auto generator = PARAMETRIC_XZ() { + auto arr = NDArrayFactory::create('c', {p.getIntParam("length")}); + arr.assign(1.0); + x.push_back(arr); + if (p.getIntParam("inplace") == 1) { + z.push_back(arr); + } else { + z.push_back(NDArrayFactory::create('c', {p.getIntParam("length")})); } - - template - static std::string gemmBenchmark() { - std::string output; - output += "gemm " + DataTypeUtils::asString(DataTypeUtils::fromT()); - BenchmarkHelper helper(WARMUP, NUM_ITER); - - for (int o = 0; o <= 1; o++) { - char resultOrder = (o == 0 ? 'f' : 'c'); - IntPowerParameters sz("sz", 2, 4, 10, 2); //2^4=16, ..., 2^10=1024 -> 4 elements - - ParametersBatch b({&sz}); - - auto generator = PARAMETRIC_XYZ() { - auto a = p.getIntParam("sz"); - auto b = p.getIntParam("sz"); - auto c = p.getIntParam("sz"); - std::vector shapeA; - std::vector shapeB; - shapeA = {a, b}; - shapeB = {b, c}; - auto A = NDArrayFactory::create_('c', shapeA); - auto B = NDArrayFactory::create_('c', shapeB); - auto C = NDArrayFactory::create_(resultOrder, {a, c}); - - x.push_back(A); - y.push_back(B); - z.push_back(C); - }; - - std::string n; - n += "Gemm - cOrder="; - n += resultOrder; - - MatrixBenchmark mb(1.0, 0.0, false, false, n); - - output += helper.runOperationSuit(&mb, generator, b, n.c_str()); - } - - return output; + }; + + ScalarBenchmark sbAdd(scalar::Ops::Add, "sAdd"); + ScalarBenchmark sbDiv(scalar::Ops::Divide, "sDiv"); + ScalarBenchmark sbPow(scalar::Ops::Pow, "sPow"); + + sbAdd.setY(NDArrayFactory::create(3.14159265359)); + sbDiv.setY(NDArrayFactory::create(3.14159265359)); + sbPow.setY(NDArrayFactory::create(3.14159265359)); + + output += helper.runOperationSuit(&sbAdd, generator, batch, + "Scalar Addition - x.add(3.14159265359)"); + output += helper.runOperationSuit(&sbDiv, generator, batch, + "Scalar Division - x.div(3.14159265359)"); + output += helper.runOperationSuit(&sbPow, generator, batch, + "Scalar Power - x.pow(3.14159265359)"); + + return output; +} + +template +static std::string pairwiseBenchmark() { + std::string output; + output += + "pairwiseBenchmark " + DataTypeUtils::asString(DataTypeUtils::fromT()); + + BenchmarkHelper helper(WARMUP, NUM_ITER); + IntPowerParameters length( + "length", 2, 8, 20, + 4); // 2^4 to 2^20 in steps of 4 - 2^4, 2^8, 2^16, 2^20 + BoolParameters inplace("inplace"); + + ParametersBatch batch({&length, &inplace}); + + auto generator = PARAMETRIC_XYZ() { + auto arr1 = NDArrayFactory::create('c', {p.getIntParam("length")}); + auto arr2 = NDArrayFactory::create('c', {p.getIntParam("length")}); + x.push_back(arr1); + y.push_back(arr2); + if (p.getIntParam("inplace") == 1) { + z.push_back(arr1); + } else { + z.push_back(NDArrayFactory::create('c', {p.getIntParam("length")})); } - - template - static std::string reduceFullBenchmark() { - std::string output; - output += "reduceFullBenchmark " + DataTypeUtils::asString(DataTypeUtils::fromT()); - - BenchmarkHelper helper(WARMUP, NUM_ITER); - - IntPowerParameters length("length", 2, 8, 20, 4); //2^8, 2^12, 2^16, 2^20 - - ParametersBatch batch({&length}); - - auto generator = PARAMETRIC_XYZ() { - auto arr = NDArrayFactory::create_('c', {p.getIntParam("length")}); - - x.push_back(arr); - y.push_back(nullptr); - z.push_back(NDArrayFactory::create_(0.0f)); - }; - - ReductionBenchmark rbSum(reduce::SameOps::Sum, "sum"); - ReductionBenchmark rbProd(reduce::SameOps::Prod, "prod"); - ReductionBenchmark rbMax(reduce::SameOps::Max, "max"); - - output += helper.runOperationSuit(&rbSum, (const std::function)(generator), batch, "Sum - Full Array Reduction"); - output += helper.runOperationSuit(&rbProd, (const std::function)(generator), batch, "Product - Full Array Reduction"); - output += helper.runOperationSuit(&rbMax, (const std::function)(generator), batch, "Maximum - Full Array Reduction"); - - //Index reduction - sd::ops::argmax opArgmax; - DeclarableBenchmark dbArgmax(opArgmax, "Argmax"); - auto generator3 = PARAMETRIC_D(){ - auto ctx = new Context(1); - - ctx->setInputArray(0, NDArrayFactory::create_('c', {p.getIntParam("length")}), true); - ctx->setInputArray(1, NDArrayFactory::create_((Nd4jLong)0), true); - ctx->setOutputArray(0, NDArrayFactory::create_(0), true); - - return ctx; - }; - output += helper.runOperationSuit(&dbArgmax, generator3, batch, "Argmax Full Array Reduction"); - return output; + }; + + PairwiseBenchmark pb1(pairwise::Ops::Add, "Add"); + output += helper.runOperationSuit(&pb1, generator, batch, "Pairwise Add"); + + PairwiseBenchmark pb2(pairwise::Ops::Divide, "Divide"); + output += helper.runOperationSuit(&pb2, generator, batch, "Pairwise Divide"); + + return output; +} + +static std::string mismatchedOrderAssign() { + std::string output; + BenchmarkHelper helper(WARMUP, NUM_ITER); + + IntPowerParameters rows("rows", 2, 8, 20, 4); // 2^8, 2^12, 2^16, 2^20 + BoolParameters cf("cf"); + + ParametersBatch batch({&rows, &cf}); + + auto generator = PARAMETRIC_XZ() { + int numElements = 4194304; // 2^24 + int rows = p.getIntParam("rows"); + int cols = numElements / rows; + bool c = p.getIntParam("cf"); + + auto arr = NDArrayFactory::create(c ? 'c' : 'f', {rows, cols}); + auto arr2 = NDArrayFactory::create(c ? 'f' : 'c', {rows, cols}); + x.push_back(arr); + z.push_back(arr2); + }; + + TransformBenchmark tb(transform::AnyOps::Assign, "assign"); + output += helper.runOperationSuit(&tb, generator, batch, + "C->F and F->C Assign F32"); + + // Also test: NCHW to NHWC and back + BoolParameters nchw("nchw"); + int mb = 8; + int hw = 64; + int c = 3; + ParametersBatch batch2({&nchw}); + auto generator2 = PARAMETRIC_XZ() { + bool nchw = p.getIntParam("nchw"); + + if (nchw) { + auto orig = NDArrayFactory::create('c', {mb, c, hw, hw}); + orig.permutei({0, 2, 3, 1}); + x.push_back(orig); + z.push_back(NDArrayFactory::create('c', {mb, hw, hw, c})); + } else { + auto orig = NDArrayFactory::create('c', {mb, hw, hw, c}); + orig.permutei({0, 3, 1, 2}); + x.push_back(orig); + z.push_back(NDArrayFactory::create('c', {mb, c, hw, hw})); } - - template - static std::string reduceDimBenchmark(){ - std::string output; - output += "reduceDimBenchmark " + DataTypeUtils::asString(DataTypeUtils::fromT()); - - BenchmarkHelper helper(WARMUP, NUM_ITER); - - int length[] = {1024*1024}; - int pow[] = {10}; - - for( int i=0; i<1; i++ ){ - IntPowerParameters rows("rows", 2, 0, pow[i], 2); - BoolParameters dim("dim"); - - - ParametersBatch batch({&rows, &dim}); - - auto generator = PARAMETRIC_XYZ() { - int rows = p.getIntParam("rows"); - int cols = length[i] / rows; - int dim = p.getIntParam("dim"); - auto arr = NDArrayFactory::create_('c', {rows, cols}); - - - x.push_back(arr); - y.push_back(NDArrayFactory::create_(dim)); - - NDArray* result; - if(dim == 0){ - result = NDArrayFactory::create_('c', {cols}); - } else { - result = NDArrayFactory::create_('c', {rows}); - } - z.push_back(result); - }; - - ReductionBenchmark rbSum(reduce::SameOps::Sum, "sum"); - ReductionBenchmark rbMax(reduce::SameOps::Max, "max"); - - std::string s1("Sum Along Dimension - "); - s1 += std::to_string(length[i]); - std::string s3("Maximum Along Dimension - "); - s3 += std::to_string(length[i]); - - output += helper.runOperationSuit(&rbSum, (const std::function)(generator), batch, s1.c_str()); - output += helper.runOperationSuit(&rbMax, (const std::function)(generator), batch, s3.c_str()); - - - - auto generator3 = PARAMETRIC_D(){ - auto ctx = new Context(1); - int rows = p.getIntParam("rows"); - int cols = length[i] / rows; - int dim = p.getIntParam("dim"); - auto arr = NDArrayFactory::create_('c', {rows, cols}); - - auto dimArg = new Nd4jLong[1]; - dimArg[0] = dim; - ctx->setIArguments(dimArg, 1); - delete[] dimArg; - - ctx->setInputArray(0, arr, true); - - NDArray* result; - if(dim == 0){ - result = NDArrayFactory::create_('c', {cols}); - } else { - result = NDArrayFactory::create_('c', {rows}); - } - ctx->setOutputArray(0, result, true); - return ctx; - }; - - std::string s5("Argmax Along Dimension - "); - s5 += std::to_string(length[i]); - - sd::ops::argmax opArgmax; - DeclarableBenchmark dbArgmax(opArgmax, "Argmax"); - output += helper.runOperationSuit(&dbArgmax, generator3, batch, s5.c_str()); - } - return output; + }; + + TransformBenchmark tb2(transform::AnyOps::Assign, "assign_nchw"); + output += helper.runOperationSuit(&tb2, generator2, batch2, + "nchw->nhwc and nhwc->nchw Assign FP32"); + return output; +} + +template +static std::string gemmBenchmark() { + std::string output; + output += "gemm " + DataTypeUtils::asString(DataTypeUtils::fromT()); + BenchmarkHelper helper(WARMUP, NUM_ITER); + + for (int o = 0; o <= 1; o++) { + char resultOrder = (o == 0 ? 'f' : 'c'); + IntPowerParameters sz("sz", 2, 4, 10, + 2); // 2^4=16, ..., 2^10=1024 -> 4 elements + + ParametersBatch b({&sz}); + + auto generator = PARAMETRIC_XYZ() { + auto a = p.getIntParam("sz"); + auto b = p.getIntParam("sz"); + auto c = p.getIntParam("sz"); + std::vector shapeA; + std::vector shapeB; + shapeA = {a, b}; + shapeB = {b, c}; + auto A = NDArrayFactory::create('c', shapeA); + auto B = NDArrayFactory::create('c', shapeB); + auto C = NDArrayFactory::create(resultOrder, {a, c}); + + x.push_back(A); + y.push_back(B); + z.push_back(C); + }; + + std::string n; + n += "Gemm - cOrder="; + n += resultOrder; + + MatrixBenchmark mb(1.0, 0.0, false, false, n); + + output += helper.runOperationSuit(&mb, generator, b, n.c_str()); + } + + return output; +} + +template +static std::string reduceFullBenchmark() { + std::string output; + output += "reduceFullBenchmark " + + DataTypeUtils::asString(DataTypeUtils::fromT()); + + BenchmarkHelper helper(WARMUP, NUM_ITER); + + IntPowerParameters length("length", 2, 8, 20, 4); // 2^8, 2^12, 2^16, 2^20 + + ParametersBatch batch({&length}); + + auto generator = PARAMETRIC_XYZ() { + auto arr = NDArrayFactory::create('c', {p.getIntParam("length")}); + + x.push_back(arr); + y.push_back(NDArray()); + z.push_back(NDArrayFactory::create(0.0f)); + }; + + ReductionBenchmark rbSum(reduce::SameOps::Sum, "sum"); + ReductionBenchmark rbProd(reduce::SameOps::Prod, "prod"); + ReductionBenchmark rbMax(reduce::SameOps::Max, "max"); + + output += helper.runOperationSuit( + &rbSum, + (const std::function)(generator), + batch, "Sum - Full Array Reduction"); + output += helper.runOperationSuit( + &rbProd, + (const std::function)(generator), + batch, "Product - Full Array Reduction"); + output += helper.runOperationSuit( + &rbMax, + (const std::function)(generator), + batch, "Maximum - Full Array Reduction"); + + // Index reduction + sd::ops::argmax opArgmax; + DeclarableBenchmark dbArgmax(opArgmax, "Argmax"); + auto generator3 = PARAMETRIC_D() { + auto ctx = new Context(1); + + ctx->setInputArray( + 0, NDArrayFactory::create('c', {p.getIntParam("length")})); + ctx->setInputArray(1, NDArrayFactory::create((Nd4jLong)0)); + ctx->setOutputArray(0, NDArrayFactory::create(0)); + + return ctx; + }; + output += helper.runOperationSuit(&dbArgmax, generator3, batch, + "Argmax Full Array Reduction"); + return output; +} + +template +static std::string reduceDimBenchmark() { + std::string output; + output += "reduceDimBenchmark " + + DataTypeUtils::asString(DataTypeUtils::fromT()); + + BenchmarkHelper helper(WARMUP, NUM_ITER); + + int length[] = {1024 * 1024}; + int pow[] = {10}; + + for (int i = 0; i < 1; i++) { + IntPowerParameters rows("rows", 2, 0, pow[i], 2); + BoolParameters dim("dim"); + + ParametersBatch batch({&rows, &dim}); + + auto generator = PARAMETRIC_XYZ() { + int rows = p.getIntParam("rows"); + int cols = length[i] / rows; + int dim = p.getIntParam("dim"); + auto arr = NDArrayFactory::create('c', {rows, cols}); + + x.push_back(arr); + y.push_back(NDArrayFactory::create(dim)); + + NDArray result; + if (dim == 0) { + result = NDArrayFactory::create('c', {cols}); + } else { + result = NDArrayFactory::create('c', {rows}); + } + z.push_back(result); + }; + + ReductionBenchmark rbSum(reduce::SameOps::Sum, "sum"); + ReductionBenchmark rbMax(reduce::SameOps::Max, "max"); + + std::string s1("Sum Along Dimension - "); + s1 += std::to_string(length[i]); + std::string s3("Maximum Along Dimension - "); + s3 += std::to_string(length[i]); + + output += helper.runOperationSuit( + &rbSum, + (const std::function)(generator), + batch, s1.c_str()); + output += helper.runOperationSuit( + &rbMax, + (const std::function)(generator), + batch, s3.c_str()); + + auto generator3 = PARAMETRIC_D() { + auto ctx = new Context(1); + int rows = p.getIntParam("rows"); + int cols = length[i] / rows; + int dim = p.getIntParam("dim"); + auto arr = NDArrayFactory::create('c', {rows, cols}); + + auto dimArg = new Nd4jLong[1]; + dimArg[0] = dim; + ctx->setIArguments(dimArg, 1); + delete[] dimArg; + + ctx->setInputArray(0, arr); + + NDArray result; + if (dim == 0) { + result = NDArrayFactory::create('c', {cols}); + } else { + result = NDArrayFactory::create('c', {rows}); + } + ctx->setOutputArray(0, result); + return ctx; + }; + + std::string s5("Argmax Along Dimension - "); + s5 += std::to_string(length[i]); + + sd::ops::argmax opArgmax; + DeclarableBenchmark dbArgmax(opArgmax, "Argmax"); + output += helper.runOperationSuit(&dbArgmax, generator3, batch, s5.c_str()); + } + return output; +} + +template +static std::string conv2d() { + std::string output; + output += "conv2d " + DataTypeUtils::asString(DataTypeUtils::fromT()); + BenchmarkHelper helper(WARMUP, NUM_ITER); + + // Convolution2D op + BoolParameters nhwc("nhwc"); + PredefinedParameters k("k", {2, 3}); + + ParametersBatch batch({&nhwc, &k}); + sd::ops::conv2d conv2d; + DeclarableBenchmark benchmark(conv2d, "conv2d"); + + int hw = 64; + + auto generator = PARAMETRIC_D() { + auto ctx = new Context(1); + int n = p.getIntParam("nhwc"); + int khw = p.getIntParam("k"); + + if (n == 0) { + auto input = NDArrayFactory::create('c', {8, 3, hw, hw}); + auto output = NDArrayFactory::create('c', {8, 3, hw, hw}); + ctx->setInputArray(0, input); + ctx->setOutputArray(0, output); + } else { + auto input = NDArrayFactory::create('c', {8, hw, hw, 3}); + auto output = NDArrayFactory::create('c', {8, hw, hw, 3}); + ctx->setInputArray(0, input); + ctx->setOutputArray(0, output); } - template - static std::string conv2d(){ - std::string output; - output += "conv2d " + DataTypeUtils::asString(DataTypeUtils::fromT()); - BenchmarkHelper helper(WARMUP, NUM_ITER); - - //Convolution2D op - BoolParameters nhwc("nhwc"); - PredefinedParameters k("k", {2, 3}); - - ParametersBatch batch({&nhwc, &k}); - sd::ops::conv2d conv2d; - DeclarableBenchmark benchmark(conv2d, "conv2d"); - - int hw = 64; - - auto generator = PARAMETRIC_D() { - auto ctx = new Context(1); - int n = p.getIntParam("nhwc"); - int khw = p.getIntParam("k"); - - if (n == 0) { - auto input = NDArrayFactory::create_('c', {8, 3, hw, hw}); - auto output = NDArrayFactory::create_('c', {8, 3, hw, hw}); - ctx->setInputArray(0, input, true); - ctx->setOutputArray(0, output, true); - } else { - auto input = NDArrayFactory::create_('c', {8, hw, hw, 3}); - auto output = NDArrayFactory::create_('c', {8, hw, hw, 3}); - ctx->setInputArray(0, input, true); - ctx->setOutputArray(0, output, true); - } - - auto b = NDArrayFactory::create_('c', {3}); - auto w = NDArrayFactory::create_('c', {khw, khw, 3, 3}); // [kH, kW, iC, oC] always - - ctx->setInputArray(1, w, true); - ctx->setInputArray(2, b, true); - - auto args = new Nd4jLong[10]; - args[0] = args[1] = khw; //Kernel - args[2] = args[3] = 1;//Stride - args[4] = args[5] = 0; //Pad - args[6] = args[7] = 1; //Dilation - args[8] = 1; //SAME - args[9] = n;//0-nchw, 1=nhwc - ctx->setIArguments(args, 10); - delete[] args; - - return ctx; - }; - - output += helper.runOperationSuit(&benchmark, generator, batch, "Conv2d"); - return output; + auto b = NDArrayFactory::create('c', {3}); + auto w = NDArrayFactory::create( + 'c', {khw, khw, 3, 3}); // [kH, kW, iC, oC] always + + ctx->setInputArray(1, w); + ctx->setInputArray(2, b); + + auto args = new Nd4jLong[10]; + args[0] = args[1] = khw; // Kernel + args[2] = args[3] = 1; // Stride + args[4] = args[5] = 0; // Pad + args[6] = args[7] = 1; // Dilation + args[8] = 1; // SAME + args[9] = n; // 0-nchw, 1=nhwc + ctx->setIArguments(args, 10); + delete[] args; + + return ctx; + }; + + output += helper.runOperationSuit(&benchmark, generator, batch, "Conv2d"); + return output; +} + +template +static std::string pool2d() { + std::string output; + output += "pool2d " + DataTypeUtils::asString(DataTypeUtils::fromT()); + BenchmarkHelper helper(WARMUP, NUM_ITER); + + // Convolution2D op + BoolParameters nhwc("nhwc"); + PredefinedParameters k("k", {2, 3}); + + ParametersBatch batch({&nhwc, &k}); + + int c = 3; + int hw = 64; + + auto generator = PARAMETRIC_D() { + auto ctx = new Context(1); + int n = p.getIntParam("nhwc"); + int khw = p.getIntParam("k"); + + if (n == 0) { + auto input = NDArrayFactory::create('c', {8, c, hw, hw}); + auto output = NDArrayFactory::create('c', {8, c, hw, hw}); + ctx->setInputArray(0, input); + ctx->setOutputArray(0, output); + } else { + auto input = NDArrayFactory::create('c', {8, hw, hw, c}); + auto output = NDArrayFactory::create('c', {8, hw, hw, c}); + ctx->setInputArray(0, input); + ctx->setOutputArray(0, output); } - template - static std::string pool2d() { - std::string output; - output += "pool2d " + DataTypeUtils::asString(DataTypeUtils::fromT()); - BenchmarkHelper helper(WARMUP, NUM_ITER); - - //Convolution2D op - BoolParameters nhwc("nhwc"); - PredefinedParameters k("k", {2, 3}); - - ParametersBatch batch({&nhwc, &k}); - - int c = 3; - int hw = 64; - - auto generator = PARAMETRIC_D() { - auto ctx = new Context(1); - int n = p.getIntParam("nhwc"); - int khw = p.getIntParam("k"); - - if (n == 0) { - auto input = NDArrayFactory::create_('c', {8, c, hw, hw}); - auto output = NDArrayFactory::create_('c', {8, c, hw, hw}); - ctx->setInputArray(0, input, true); - ctx->setOutputArray(0, output, true); - } else { - auto input = NDArrayFactory::create_('c', {8, hw, hw, c}); - auto output = NDArrayFactory::create_('c', {8, hw, hw, c}); - ctx->setInputArray(0, input, true); - ctx->setOutputArray(0, output, true); - } - - auto args = new Nd4jLong[11]; - args[0] = args[1] = khw; //Kernel - args[2] = args[3] = 1;//Stride - args[4] = args[5] = 0; //Pad - args[6] = args[7] = 1; //Dilation - args[8] = 1; //SAME - args[9] = 0; //Divisor mode - 0 = exclude padding in divisor - args[10] = n;//0-nchw, 1=nhwc - ctx->setIArguments(args, 11); - delete[] args; - - return ctx; - }; - - sd::ops::avgpool2d avgpool2d; - DeclarableBenchmark benchmark1(avgpool2d, "avgpool"); - output += helper.runOperationSuit(&benchmark1, generator, batch, "Average Pool 2d"); - - sd::ops::maxpool2d maxpool2d; - DeclarableBenchmark benchmark2(maxpool2d, "maxpool"); - output += helper.runOperationSuit(&benchmark2, generator, batch, "Max Pool 2d"); - return output; + auto args = new Nd4jLong[11]; + args[0] = args[1] = khw; // Kernel + args[2] = args[3] = 1; // Stride + args[4] = args[5] = 0; // Pad + args[6] = args[7] = 1; // Dilation + args[8] = 1; // SAME + args[9] = 0; // Divisor mode - 0 = exclude padding in divisor + args[10] = n; // 0-nchw, 1=nhwc + ctx->setIArguments(args, 11); + delete[] args; + + return ctx; + }; + + sd::ops::avgpool2d avgpool2d; + DeclarableBenchmark benchmark1(avgpool2d, "avgpool"); + output += + helper.runOperationSuit(&benchmark1, generator, batch, "Average Pool 2d"); + + sd::ops::maxpool2d maxpool2d; + DeclarableBenchmark benchmark2(maxpool2d, "maxpool"); + output += + helper.runOperationSuit(&benchmark2, generator, batch, "Max Pool 2d"); + return output; +} + +template +static std::string lstmBenchmark() { + std::string output; + output += "lstm " + DataTypeUtils::asString(DataTypeUtils::fromT()); + BenchmarkHelper helper(WARMUP, NUM_ITER); + + BoolParameters format( + "format"); // 0=TNS=[seqLen,mb,size]; 1=NST=[mb,size,seqLen] + PredefinedParameters mb("mb", {1, 8}); + int n = 128; + + ParametersBatch batch({&format, &mb}); + sd::ops::lstmBlock lstmBlock; + DeclarableBenchmark benchmark(lstmBlock, "lstm"); + + int seqLength = 8; + + auto generator = PARAMETRIC_D() { + auto ctx = new Context(1); + int f = p.getIntParam("format"); + int m = p.getIntParam("mb"); + + Nd4jLong l = 0; + ctx->setInputArray( + 0, NDArrayFactory::create(l)); // Max TS length (unused) + + if (f == 0) { + // TNS format + ctx->setInputArray( + 1, NDArrayFactory::create('c', {seqLength, m, n})); // x + ctx->setOutputArray( + 0, NDArrayFactory::create('c', {seqLength, m, n})); // i + ctx->setOutputArray( + 1, NDArrayFactory::create('c', {seqLength, m, n})); // c + ctx->setOutputArray( + 2, NDArrayFactory::create('c', {seqLength, m, n})); // f + ctx->setOutputArray( + 3, NDArrayFactory::create('c', {seqLength, m, n})); // o + ctx->setOutputArray( + 4, NDArrayFactory::create('c', {seqLength, m, n})); // z + ctx->setOutputArray( + 5, NDArrayFactory::create('c', {seqLength, m, n})); // h + ctx->setOutputArray( + 6, NDArrayFactory::create('c', {seqLength, m, n})); // y + } else { + // NST format + ctx->setInputArray( + 1, NDArrayFactory::create('f', {m, n, seqLength})); // x + ctx->setOutputArray( + 0, NDArrayFactory::create('f', {m, n, seqLength})); // i + ctx->setOutputArray( + 1, NDArrayFactory::create('f', {m, n, seqLength})); // c + ctx->setOutputArray( + 2, NDArrayFactory::create('f', {m, n, seqLength})); // f + ctx->setOutputArray( + 3, NDArrayFactory::create('f', {m, n, seqLength})); // o + ctx->setOutputArray( + 4, NDArrayFactory::create('f', {m, n, seqLength})); // z + ctx->setOutputArray( + 5, NDArrayFactory::create('f', {m, n, seqLength})); // h + ctx->setOutputArray( + 6, NDArrayFactory::create('f', {m, n, seqLength})); // y } - template - static std::string lstmBenchmark() { - std::string output; - output += "lstm " + DataTypeUtils::asString(DataTypeUtils::fromT()); - BenchmarkHelper helper(WARMUP, NUM_ITER); - - BoolParameters format("format"); //0=TNS=[seqLen,mb,size]; 1=NST=[mb,size,seqLen] - PredefinedParameters mb("mb", {1, 8}); - int n = 128; - - ParametersBatch batch({&format, &mb}); - sd::ops::lstmBlock lstmBlock; - DeclarableBenchmark benchmark(lstmBlock, "lstm"); - - int seqLength = 8; - - auto generator = PARAMETRIC_D() { - auto ctx = new Context(1); - int f = p.getIntParam("format"); - int m = p.getIntParam("mb"); - - Nd4jLong l = 0; - ctx->setInputArray(0, NDArrayFactory::create_(l), true); //Max TS length (unused) - - - if (f == 0) { - //TNS format - ctx->setInputArray(1, NDArrayFactory::create_('c', {seqLength, m, n}), true); //x - ctx->setOutputArray(0, NDArrayFactory::create_('c', {seqLength, m, n}), true); //i - ctx->setOutputArray(1, NDArrayFactory::create_('c', {seqLength, m, n}), true); //c - ctx->setOutputArray(2, NDArrayFactory::create_('c', {seqLength, m, n}), true); //f - ctx->setOutputArray(3, NDArrayFactory::create_('c', {seqLength, m, n}), true); //o - ctx->setOutputArray(4, NDArrayFactory::create_('c', {seqLength, m, n}), true); //z - ctx->setOutputArray(5, NDArrayFactory::create_('c', {seqLength, m, n}), true); //h - ctx->setOutputArray(6, NDArrayFactory::create_('c', {seqLength, m, n}), true); //y - } else { - //NST format - ctx->setInputArray(1, NDArrayFactory::create_('f', {m, n, seqLength}), true); //x - ctx->setOutputArray(0, NDArrayFactory::create_('f', {m, n, seqLength}), true); //i - ctx->setOutputArray(1, NDArrayFactory::create_('f', {m, n, seqLength}), true); //c - ctx->setOutputArray(2, NDArrayFactory::create_('f', {m, n, seqLength}), true); //f - ctx->setOutputArray(3, NDArrayFactory::create_('f', {m, n, seqLength}), true); //o - ctx->setOutputArray(4, NDArrayFactory::create_('f', {m, n, seqLength}), true); //z - ctx->setOutputArray(5, NDArrayFactory::create_('f', {m, n, seqLength}), true); //h - ctx->setOutputArray(6, NDArrayFactory::create_('f', {m, n, seqLength}), true); //y - } - - auto cLast = NDArrayFactory::create_('c', {m, n}); - auto yLast = NDArrayFactory::create_('c', {m, n}); - auto W = NDArrayFactory::create_('c', {2 * n, 4 * n}); - auto Wci = NDArrayFactory::create_('c', {n}); - auto Wcf = NDArrayFactory::create_('c', {n}); - auto Wco = NDArrayFactory::create_('c', {n}); - auto b = NDArrayFactory::create_('c', {4 * n}); - - ctx->setInputArray(2, cLast, true); - ctx->setInputArray(3, yLast, true); - ctx->setInputArray(4, W, true); - ctx->setInputArray(5, Wci, true); - ctx->setInputArray(6, Wcf, true); - ctx->setInputArray(7, Wco, true); - ctx->setInputArray(8, b, true); - - auto iargs = new Nd4jLong[2]; - iargs[0] = 0; //No peephole - iargs[1] = f; - ctx->setIArguments(iargs, 2); - delete[] iargs; - - auto targs = new double[2]; - targs[0] = 1.0; //forget bias - targs[1] = 0.0; //cell clipping value - ctx->setTArguments(targs, 2); - delete[] targs; - return ctx; - }; - - output += helper.runOperationSuit(&benchmark, generator, batch, "LSTMBlock"); - return output; + auto cLast = NDArrayFactory::create('c', {m, n}); + auto yLast = NDArrayFactory::create('c', {m, n}); + auto W = NDArrayFactory::create('c', {2 * n, 4 * n}); + auto Wci = NDArrayFactory::create('c', {n}); + auto Wcf = NDArrayFactory::create('c', {n}); + auto Wco = NDArrayFactory::create('c', {n}); + auto b = NDArrayFactory::create('c', {4 * n}); + + ctx->setInputArray(2, cLast); + ctx->setInputArray(3, yLast); + ctx->setInputArray(4, W); + ctx->setInputArray(5, Wci); + ctx->setInputArray(6, Wcf); + ctx->setInputArray(7, Wco); + ctx->setInputArray(8, b); + + auto iargs = new Nd4jLong[2]; + iargs[0] = 0; // No peephole + iargs[1] = f; + ctx->setIArguments(iargs, 2); + delete[] iargs; + + auto targs = new double[2]; + targs[0] = 1.0; // forget bias + targs[1] = 0.0; // cell clipping value + ctx->setTArguments(targs, 2); + delete[] targs; + return ctx; + }; + + output += helper.runOperationSuit(&benchmark, generator, batch, "LSTMBlock"); + return output; +} + +static std::string broadcast2d() { + std::string output; + BenchmarkHelper helper(WARMUP, NUM_ITER); + + int rows = 65536; + IntPowerParameters cols( + "cols", 2, 2, 12, + 4); // 2^2 to 2^12 in steps of 2 - 2^1=2, ..., 2^10=1024 + BoolParameters axis("axis"); + BoolParameters inplace("inplace"); + + ParametersBatch batch({&cols, &axis, &inplace}); + + auto generator = PARAMETRIC_D() { + auto a = p.getIntParam("axis"); + auto arr = + NDArrayFactory::create('c', {rows, p.getIntParam("cols")}); + + auto ctx = new Context(1); + ctx->setInputArray(0, arr); + if (a == 0) { + ctx->setInputArray(1, NDArrayFactory::create('c', {rows, 1})); + } else { + ctx->setInputArray( + 1, NDArrayFactory::create('c', {1, p.getIntParam("cols")})); } - - static std::string broadcast2d() { - std::string output; - BenchmarkHelper helper(WARMUP, NUM_ITER); - - int rows = 65536; - IntPowerParameters cols("cols", 2, 2, 12, 4); //2^2 to 2^12 in steps of 2 - 2^1=2, ..., 2^10=1024 - BoolParameters axis("axis"); - BoolParameters inplace("inplace"); - - ParametersBatch batch({&cols, &axis, &inplace}); - - auto generator = PARAMETRIC_D() { - auto a = p.getIntParam("axis"); - auto arr = NDArrayFactory::create_('c', {rows, p.getIntParam("cols")}); - - auto ctx = new Context(1); - ctx->setInputArray(0, arr, true); - if(a == 0){ - ctx->setInputArray(1, NDArrayFactory::create_('c', {rows, 1}), true); - } else { - ctx->setInputArray(1, NDArrayFactory::create_('c', {1, p.getIntParam("cols")}), true); - } - if (p.getIntParam("inplace") == 1) { - ctx->setOutputArray(0, arr); - ctx->markInplace(true); - } else { - ctx->setOutputArray(0, NDArrayFactory::create_('c', {rows, p.getIntParam("cols")}), true); - } - return ctx; - }; - - std::string s("add"); - sd::ops::add op; - DeclarableBenchmark benchmark(op, "add"); - output += helper.runOperationSuit(&benchmark, generator, batch, "Broadcast (Custom) Add - 2d"); - return output; + if (p.getIntParam("inplace") == 1) { + ctx->setOutputArray(0, arr); + ctx->markInplace(true); + } else { + ctx->setOutputArray( + 0, NDArrayFactory::create('c', {rows, p.getIntParam("cols")})); } - - std::string LightBenchmarkSuit::runSuit() { + return ctx; + }; + + std::string s("add"); + sd::ops::add op; + DeclarableBenchmark benchmark(op, "add"); + output += helper.runOperationSuit(&benchmark, generator, batch, + "Broadcast (Custom) Add - 2d"); + return output; +} + +std::string LightBenchmarkSuit::runSuit() { #ifdef RELEASE_BUILD - std::vector dtypes({sd::DataType::FLOAT32, sd::DataType::HALF}); + std::vector dtypes({sd::DataType::FLOAT32, sd::DataType::HALF}); #else - std::vector dtypes({sd::DataType::FLOAT32}); + std::vector dtypes({sd::DataType::FLOAT32}); #endif - std::string result; - - for (auto t:dtypes) { - nd4j_printf("Running LightBenchmarkSuite.transformBenchmark [%s]\n", DataTypeUtils::asString(t).c_str()); - BUILD_SINGLE_SELECTOR(t, result += transformBenchmark, (), LIBND4J_TYPES); + std::string result; - nd4j_printf("Running LightBenchmarkSuite.scalarBenchmark [%s]\n", DataTypeUtils::asString(t).c_str()); - BUILD_SINGLE_SELECTOR(t, result += scalarBenchmark, (), LIBND4J_TYPES); + for (auto t : dtypes) { + nd4j_printf("Running LightBenchmarkSuite.transformBenchmark [%s]\n", + DataTypeUtils::asString(t).c_str()); + BUILD_SINGLE_SELECTOR(t, result += transformBenchmark, (), LIBND4J_TYPES); - nd4j_printf("Running LightBenchmarkSuite.pairwiseBenchmark [%s]\n", DataTypeUtils::asString(t).c_str()); - BUILD_SINGLE_SELECTOR(t, result += pairwiseBenchmark, (), LIBND4J_TYPES); + nd4j_printf("Running LightBenchmarkSuite.scalarBenchmark [%s]\n", + DataTypeUtils::asString(t).c_str()); + BUILD_SINGLE_SELECTOR(t, result += scalarBenchmark, (), LIBND4J_TYPES); - nd4j_printf("Running LightBenchmarkSuite.reduceFullBenchmark [%s]\n", DataTypeUtils::asString(t).c_str()); - BUILD_SINGLE_SELECTOR(t, result += reduceFullBenchmark, (), LIBND4J_TYPES); + nd4j_printf("Running LightBenchmarkSuite.pairwiseBenchmark [%s]\n", + DataTypeUtils::asString(t).c_str()); + BUILD_SINGLE_SELECTOR(t, result += pairwiseBenchmark, (), LIBND4J_TYPES); - nd4j_printf("Running LightBenchmarkSuite.reduceDimBenchmark [%s]\n", DataTypeUtils::asString(t).c_str()); - BUILD_SINGLE_SELECTOR(t, result += reduceDimBenchmark, (), LIBND4J_TYPES); + nd4j_printf("Running LightBenchmarkSuite.reduceFullBenchmark [%s]\n", + DataTypeUtils::asString(t).c_str()); + BUILD_SINGLE_SELECTOR(t, result += reduceFullBenchmark, (), LIBND4J_TYPES); - nd4j_printf("Running LightBenchmarkSuite.gemmBenchmark [%s]\n", DataTypeUtils::asString(t).c_str()); - BUILD_SINGLE_SELECTOR(t, result += gemmBenchmark, (), LIBND4J_TYPES); + nd4j_printf("Running LightBenchmarkSuite.reduceDimBenchmark [%s]\n", + DataTypeUtils::asString(t).c_str()); + BUILD_SINGLE_SELECTOR(t, result += reduceDimBenchmark, (), LIBND4J_TYPES); - nd4j_printf("Running LightBenchmarkSuite.conv2d [%s]\n", DataTypeUtils::asString(t).c_str()); - BUILD_SINGLE_SELECTOR(t, result += conv2d, (), LIBND4J_TYPES); + nd4j_printf("Running LightBenchmarkSuite.gemmBenchmark [%s]\n", + DataTypeUtils::asString(t).c_str()); + BUILD_SINGLE_SELECTOR(t, result += gemmBenchmark, (), LIBND4J_TYPES); - nd4j_printf("Running LightBenchmarkSuite.pool2d [%s]\n", DataTypeUtils::asString(t).c_str()); - BUILD_SINGLE_SELECTOR(t, result += pool2d, (), LIBND4J_TYPES); + nd4j_printf("Running LightBenchmarkSuite.conv2d [%s]\n", + DataTypeUtils::asString(t).c_str()); + BUILD_SINGLE_SELECTOR(t, result += conv2d, (), LIBND4J_TYPES); - nd4j_printf("Running LightBenchmarkSuite.lstmBenchmark [%s]\n", DataTypeUtils::asString(t).c_str()); - BUILD_SINGLE_SELECTOR(t, result += lstmBenchmark, (), LIBND4J_TYPES); + nd4j_printf("Running LightBenchmarkSuite.pool2d [%s]\n", + DataTypeUtils::asString(t).c_str()); + BUILD_SINGLE_SELECTOR(t, result += pool2d, (), LIBND4J_TYPES); - } + nd4j_printf("Running LightBenchmarkSuite.lstmBenchmark [%s]\n", + DataTypeUtils::asString(t).c_str()); + BUILD_SINGLE_SELECTOR(t, result += lstmBenchmark, (), LIBND4J_TYPES); + } - nd4j_printf("Running LightBenchmarkSuite.broadcast2d\n", ""); - result += broadcast2d(); - nd4j_printf("Running LightBenchmarkSuite.mismatchedOrderAssign\n", ""); - result += mismatchedOrderAssign(); + nd4j_printf("Running LightBenchmarkSuite.broadcast2d\n", ""); + result += broadcast2d(); + nd4j_printf("Running LightBenchmarkSuite.mismatchedOrderAssign\n", ""); + result += mismatchedOrderAssign(); - return result; - } -} \ No newline at end of file + return result; +} +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/samediff.h b/libnd4j/include/samediff.h index 4907c980256f..71a058f45d88 100644 --- a/libnd4j/include/samediff.h +++ b/libnd4j/include/samediff.h @@ -30,10 +30,10 @@ #include // basic Graph-related includes -#include #include +#include // ML ops includes #include -#endif //S_SAMEDIFF_H +#endif // S_SAMEDIFF_H diff --git a/libnd4j/include/system/BlasVersionHelper.h b/libnd4j/include/system/BlasVersionHelper.h index 7cc97a26cc1e..1878db24a370 100644 --- a/libnd4j/include/system/BlasVersionHelper.h +++ b/libnd4j/include/system/BlasVersionHelper.h @@ -21,20 +21,20 @@ #ifndef SAMEDIFF_BLASVERSIONHELPER_H #define SAMEDIFF_BLASVERSIONHELPER_H -#include #include #include +#include namespace sd { - class ND4J_EXPORT BlasVersionHelper { - public: - int _blasMajorVersion = 0; - int _blasMinorVersion = 0; - int _blasPatchVersion = 0; +class SD_EXPORT BlasVersionHelper { + public: + int _blasMajorVersion = 0; + int _blasMinorVersion = 0; + int _blasPatchVersion = 0; - BlasVersionHelper(); - ~BlasVersionHelper() = default; - }; -} + BlasVersionHelper(); + ~BlasVersionHelper() = default; +}; +} // namespace sd -#endif //DEV_TESTS_BLASVERSIONHELPER_H +#endif // SD_BLASVERSIONHELPER_H diff --git a/libnd4j/include/system/Environment.h b/libnd4j/include/system/Environment.h index 9b2a4b65bbd7..ede57ea906f9 100644 --- a/libnd4j/include/system/Environment.h +++ b/libnd4j/include/system/Environment.h @@ -21,129 +21,130 @@ #ifndef LIBND4J_ENVIRONMENT_H #define LIBND4J_ENVIRONMENT_H -#include -#include -#include -#include #include -#include +#include #include +#include + +#include +#include +#include -namespace sd{ - class ND4J_EXPORT Environment { - private: - std::atomic _tadThreshold; - std::atomic _elementThreshold; - std::atomic _verbose; - std::atomic _debug; - std::atomic _leaks; - std::atomic _profile; - std::atomic _dataType; - std::atomic _precBoost; - std::atomic _useMKLDNN{true}; - std::atomic _allowHelpers{true}; - - std::atomic _maxThreads; - std::atomic _maxMasterThreads; - - // these fields hold defaults - std::atomic _maxTotalPrimaryMemory{-1}; - std::atomic _maxTotalSpecialMemory{-1}; - std::atomic _maxDeviceMemory{-1}; - - bool _blasFallback = false; +namespace sd { +class SD_EXPORT Environment { + private: + std::atomic _tadThreshold; + std::atomic _elementThreshold; + std::atomic _verbose; + std::atomic _debug; + std::atomic _leaks; + std::atomic _profile; + std::atomic _dataType; + std::atomic _precBoost; + std::atomic _useMKLDNN{true}; + std::atomic _allowHelpers{true}; + + std::atomic _maxThreads; + std::atomic _maxMasterThreads; + + // these fields hold defaults + std::atomic _maxTotalPrimaryMemory{-1}; + std::atomic _maxTotalSpecialMemory{-1}; + std::atomic _maxDeviceMemory{-1}; + + bool _blasFallback = false; #ifdef __ND4J_EXPERIMENTAL__ - const bool _experimental = true; + const bool _experimental = true; #else - const bool _experimental = false; + const bool _experimental = false; #endif - // device compute capability for CUDA - std::vector _capabilities; + // device compute capability for CUDA + std::vector _capabilities; - Environment(); + Environment(); public: - ~Environment(); - /** - * These 3 fields are mostly for CUDA/cuBLAS version tracking - */ - int _blasMajorVersion = 0; - int _blasMinorVersion = 0; - int _blasPatchVersion = 0; + ~Environment(); + + /** + * These 3 fields are mostly for CUDA/cuBLAS version tracking + */ + int _blasMajorVersion = 0; + int _blasMinorVersion = 0; + int _blasPatchVersion = 0; - static Environment& getInstance(); + static Environment& getInstance(); - bool isVerbose(); - void setVerbose(bool reallyVerbose); - bool isDebug(); - bool isProfiling(); - bool isDetectingLeaks(); - bool isDebugAndVerbose(); - void setDebug(bool reallyDebug); - void setProfiling(bool reallyProfile); - void setLeaksDetector(bool reallyDetect); - bool helpersAllowed(); - void allowHelpers(bool reallyAllow); + bool isVerbose(); + void setVerbose(bool reallyVerbose); + bool isDebug(); + bool isProfiling(); + bool isDetectingLeaks(); + bool isDebugAndVerbose(); + void setDebug(bool reallyDebug); + void setProfiling(bool reallyProfile); + void setLeaksDetector(bool reallyDetect); + bool helpersAllowed(); + void allowHelpers(bool reallyAllow); - bool blasFallback(); - - int tadThreshold(); - void setTadThreshold(int threshold); + bool blasFallback(); - int elementwiseThreshold(); - void setElementwiseThreshold(int threshold); + int tadThreshold(); + void setTadThreshold(int threshold); - int maxThreads(); - void setMaxThreads(int max); + int elementwiseThreshold(); + void setElementwiseThreshold(int threshold); - int maxMasterThreads(); - void setMaxMasterThreads(int max); + int maxThreads(); + void setMaxThreads(int max); - /* - * Legacy memory limits API, still used in new API as simplified version - */ - void setMaxPrimaryMemory(uint64_t maxBytes); - void setMaxSpecialyMemory(uint64_t maxBytes); - void setMaxDeviceMemory(uint64_t maxBytes); + int maxMasterThreads(); + void setMaxMasterThreads(int max); - uint64_t maxPrimaryMemory(); - uint64_t maxSpecialMemory(); - //////////////////////// + /* + * Legacy memory limits API, still used in new API as simplified version + */ + void setMaxPrimaryMemory(uint64_t maxBytes); + void setMaxSpecialyMemory(uint64_t maxBytes); + void setMaxDeviceMemory(uint64_t maxBytes); - /* - * Methods for memory limits/counters - */ - void setGroupLimit(int group, Nd4jLong numBytes); - void setDeviceLimit(int deviceId, Nd4jLong numBytes); + uint64_t maxPrimaryMemory(); + uint64_t maxSpecialMemory(); + //////////////////////// - Nd4jLong getGroupLimit(int group); - Nd4jLong getDeviceLimit(int deviceId); + /* + * Methods for memory limits/counters + */ + void setGroupLimit(int group, Nd4jLong numBytes); + void setDeviceLimit(int deviceId, Nd4jLong numBytes); - Nd4jLong getGroupCounter(int group); - Nd4jLong getDeviceCounter(int deviceId); - //////////////////////// + Nd4jLong getGroupLimit(int group); + Nd4jLong getDeviceLimit(int deviceId); - bool isUseMKLDNN() { return _useMKLDNN.load(); } - void setUseMKLDNN(bool useMKLDNN) { _useMKLDNN.store(useMKLDNN); } + Nd4jLong getGroupCounter(int group); + Nd4jLong getDeviceCounter(int deviceId); + //////////////////////// - sd::DataType defaultFloatDataType(); - void setDefaultFloatDataType(sd::DataType dtype); + bool isUseMKLDNN() { return _useMKLDNN.load(); } + void setUseMKLDNN(bool useMKLDNN) { _useMKLDNN.store(useMKLDNN); } - bool precisionBoostAllowed(); - void allowPrecisionBoost(bool reallyAllow); + sd::DataType defaultFloatDataType(); + void setDefaultFloatDataType(sd::DataType dtype); - bool isExperimentalBuild(); + bool precisionBoostAllowed(); + void allowPrecisionBoost(bool reallyAllow); - bool isCPU(); + bool isExperimentalBuild(); - int blasMajorVersion(); - int blasMinorVersion(); - int blasPatchVersion(); + bool isCPU(); - std::vector& capabilities(); - }; -} + int blasMajorVersion(); + int blasMinorVersion(); + int blasPatchVersion(); + std::vector& capabilities(); +}; +} // namespace sd -#endif //LIBND4J_ENVIRONMENT_H +#endif // LIBND4J_ENVIRONMENT_H diff --git a/libnd4j/include/system/buffer.h b/libnd4j/include/system/buffer.h old mode 100755 new mode 100644 index 5072965ca9bd..d54ebe441343 --- a/libnd4j/include/system/buffer.h +++ b/libnd4j/include/system/buffer.h @@ -28,62 +28,56 @@ #include #include #endif -#include - #include -#include #include +#include #include - //Question: Should the indexes here really be int? Isn't size_t or Nd4jLong more appropriate? +// Question: Should the indexes here really be int? Isn't size_t or Nd4jLong +// more appropriate? namespace sd { - namespace buffer { +namespace buffer { /** * Represents both a cpu and gpu * buffer - mainly used for testing */ - template - struct Buffer { - int length = 0; - int allocatedOnGpu = 0; - T *data = nullptr; - T *gData = nullptr; - T one, two; - public: - ~Buffer() { - delete []data; - delete []gData; - } - - void assign(T *val) { - data = val; - } - - T &operator=(T x) { - one = x; - return x; - } - - class Proxy { - Buffer &a; - int idx; - public: - Proxy(Buffer &a, int idx) : - a(a), idx(idx) { - } - - T &operator=(T x) { - a.two = x; - a.data[idx] = x; - return a.data[idx]; - } - }; - - - Proxy operator[](int index) { - return Proxy(*this, index); - } - }; +template +struct Buffer { + int length = 0; + int allocatedOnGpu = 0; + T *data = nullptr; + T *gData = nullptr; + T one, two; + + public: + ~Buffer() { + delete[] data; + delete[] gData; + } + + void assign(T *val) { data = val; } + + T &operator=(T x) { + one = x; + return x; + } + + class Proxy { + Buffer &a; + int idx; + + public: + Proxy(Buffer &a, int idx) : a(a), idx(idx) {} + + T &operator=(T x) { + a.two = x; + a.data[idx] = x; + return a.data[idx]; + } + }; + + Proxy operator[](int index) { return Proxy(*this, index); } +}; /** * Returns the size of the buffer @@ -91,13 +85,14 @@ namespace sd { * @param buffer the buffer to get the size of * @return the size of the buffer in bytes */ - template +template #ifdef __CUDACC__ - __host__ __device__ +__host__ __device__ #endif - int bufferSize(Buffer *buffer); + int + bufferSize(Buffer *buffer); /** * Copies data to the gpu @@ -105,45 +100,41 @@ namespace sd { */ #ifdef __CUDACC__ - template - __host__ - void copyDataToGpu(Buffer **buffer, cudaStream_t stream); +template +__host__ void copyDataToGpu(Buffer **buffer, cudaStream_t stream); #endif - - /** * Copies data from the gpu * @param buffer the buffer to copy */ #ifdef __CUDACC__ - template - __host__ - void copyDataFromGpu(Buffer **buffer, cudaStream_t stream); +template +__host__ void copyDataFromGpu(Buffer **buffer, cudaStream_t stream); #endif - - /** * Allocate buffer of the given * length on the cpu and gpu. */ - template +template #ifdef __CUDACC__ - __host__ +__host__ #endif - void allocBuffer(Buffer **buffer, int length); + void + allocBuffer(Buffer **buffer, int length); /** * Frees the given buffer * (gpu and cpu */ - template +template #ifdef __CUDACC__ - __host__ +__host__ #endif - void freeBuffer(Buffer **buffer); + void + freeBuffer(Buffer **buffer); /** * Creates a buffer @@ -151,60 +142,65 @@ namespace sd { * and also synchronizes * the data on the gpu. */ - template +template #ifdef __CUDACC__ - __host__ +__host__ #endif - Buffer - * - createBuffer(T *data, int length); + Buffer + *createBuffer(T *data, int length); /** * Print the buffer on the host * @param buff */ - template +template #ifdef __CUDACC__ - __host__ +__host__ #endif - void printArr(Buffer *buff); + void + printArr(Buffer *buff); /** * * @param buffer * @return */ - template +template #ifdef __CUDACC__ - __host__ __device__ +__host__ __device__ #endif - int bufferSize(Buffer *buffer) { - return sizeof(T) * buffer->length; - } + int + bufferSize(Buffer *buffer) { + return sizeof(T) * buffer->length; +} #ifdef __CUDACC__ - /** +/** * * @param buffer */ -template -__host__ void copyDataToGpu(Buffer **buffer, cudaStream_t stream) { - Buffer *bufferRef = *buffer; - checkCudaErrors(cudaMemcpyAsync(bufferRef->gData, bufferRef->data, bufferSize(bufferRef), cudaMemcpyHostToDevice, stream)); - checkCudaErrors(cudaStreamSynchronize(stream)); +template +__host__ void copyDataToGpu(Buffer **buffer, cudaStream_t stream) { + Buffer *bufferRef = *buffer; + checkCudaErrors(cudaMemcpyAsync(bufferRef->gData, bufferRef->data, + bufferSize(bufferRef), cudaMemcpyHostToDevice, + stream)); + checkCudaErrors(cudaStreamSynchronize(stream)); } /** * * @param buffer */ -template -__host__ void copyDataFromGpu(Buffer **buffer, cudaStream_t stream) { - Buffer *bufferRef = *buffer; - int bufferTotalSize = bufferSize(bufferRef); - checkCudaErrors(cudaMemcpyAsync(bufferRef->data, bufferRef->gData, bufferTotalSize, cudaMemcpyDeviceToHost, stream)); - checkCudaErrors(cudaStreamSynchronize(stream)); +template +__host__ void copyDataFromGpu(Buffer **buffer, cudaStream_t stream) { + Buffer *bufferRef = *buffer; + int bufferTotalSize = bufferSize(bufferRef); + checkCudaErrors(cudaMemcpyAsync(bufferRef->data, bufferRef->gData, + bufferTotalSize, cudaMemcpyDeviceToHost, + stream)); + checkCudaErrors(cudaStreamSynchronize(stream)); } #endif @@ -212,38 +208,40 @@ __host__ void copyDataFromGpu(Buffer **buffer, cudaStream_t stream) { * Allocate buffer of the given * length on the cpu and gpu. */ - template +template #ifdef __CUDACC__ - __host__ +__host__ #endif - void allocBuffer(Buffer **buffer, int length) { - Buffer *bufferRef = *buffer; - bufferRef->length = length; - bufferRef->data = reinterpret_cast(malloc(sizeof(T) * length)); - - CHECK_ALLOC(bufferRef->data, "Failed to allocate new buffer", sizeof(T) * length); + void + allocBuffer(Buffer **buffer, int length) { + Buffer *bufferRef = *buffer; + bufferRef->length = length; + bufferRef->data = reinterpret_cast(malloc(sizeof(T) * length)); + + CHECK_ALLOC(bufferRef->data, "Failed to allocate new buffer", + sizeof(T) * length); #ifdef __CUDACC__ - checkCudaErrors(cudaMalloc(&bufferRef->gData, sizeof(T) * length)); + checkCudaErrors(cudaMalloc(&bufferRef->gData, sizeof(T) * length)); #endif - } +} /** * Frees the given buffer * (gpu and cpu */ - template +template #ifdef __CUDACC__ - __host__ +__host__ #endif - void freeBuffer(Buffer *buffer) { + void + freeBuffer(Buffer *buffer) { #ifdef __CUDACC__ - if(buffer->gData != nullptr) - checkCudaErrors(cudaFree(buffer->gData)); + if (buffer->gData != nullptr) checkCudaErrors(cudaFree(buffer->gData)); #endif - delete buffer; - } + delete buffer; +} /** * Creates a buffer @@ -251,47 +249,44 @@ __host__ void copyDataFromGpu(Buffer **buffer, cudaStream_t stream) { * and also synchronizes * the data on the gpu. */ - template +template #ifdef __CUDACC__ - __host__ +__host__ #endif - Buffer *createBuffer(T *data, int length) { - Buffer *ret = new Buffer; - T *buffData = new T[length]; - for(int i = 0; i < length; i++) - buffData[i] = data[i]; - ret->data = buffData; - ret->length = length; - return ret; - } - - + Buffer + *createBuffer(T *data, int length) { + Buffer *ret = new Buffer; + T *buffData = new T[length]; + for (int i = 0; i < length; i++) buffData[i] = data[i]; + ret->data = buffData; + ret->length = length; + return ret; +} #ifdef __CUDACC__ - template - __host__ - Buffer *createBuffer(T *data, int length, cudaStream_t stream) { - Buffer *ret = createBuffer(data, length); - - T *gData; - T **gDataRef = &(gData); - checkCudaErrors(cudaMalloc(reinterpret_cast(gDataRef), sizeof(T) * length)); - ret->gData = gData; - checkCudaErrors(cudaMemcpyAsync(ret->gData, ret->data, sizeof(T) * length, cudaMemcpyHostToDevice, stream)); - return ret; - } -#endif - } +template +__host__ Buffer *createBuffer(T *data, int length, cudaStream_t stream) { + Buffer *ret = createBuffer(data, length); + + T *gData; + T **gDataRef = &(gData); + checkCudaErrors( + cudaMalloc(reinterpret_cast(gDataRef), sizeof(T) * length)); + ret->gData = gData; + checkCudaErrors(cudaMemcpyAsync(ret->gData, ret->data, sizeof(T) * length, + cudaMemcpyHostToDevice, stream)); + return ret; } - - +#endif +} // namespace buffer +} // namespace sd #ifdef __CUDACC__ -template -__host__ void printArr(sd::buffer::Buffer *buff) { - for (int i = 0; i < buff->length; i++) { - printf("Buffer[%d] was %f\n", i, buff->data[i]); - } +template +__host__ void printArr(sd::buffer::Buffer *buff) { + for (int i = 0; i < buff->length; i++) { + printf("Buffer[%d] was %f\n", i, buff->data[i]); + } } #endif diff --git a/libnd4j/include/system/dll.h b/libnd4j/include/system/dll.h index 71098f8bf472..99dc203a0664 100644 --- a/libnd4j/include/system/dll.h +++ b/libnd4j/include/system/dll.h @@ -25,8 +25,8 @@ #ifdef _WIN32 //#include -# define ND4J_EXPORT __declspec(dllexport) +#define SD_EXPORT __declspec(dllexport) #else -# define ND4J_EXPORT +#define SD_EXPORT #endif -#endif //NATIVEOPERATIONS_DLL_H +#endif // NATIVEOPERATIONS_DLL_H diff --git a/libnd4j/include/system/enum_boilerplate.h b/libnd4j/include/system/enum_boilerplate.h index 2acb1536aa37..5b198b7005c1 100644 --- a/libnd4j/include/system/enum_boilerplate.h +++ b/libnd4j/include/system/enum_boilerplate.h @@ -23,113 +23,133 @@ #include - #define EN_1(WHAT, OP_PAIR) WHAT(OP_PAIR) -#define EN_2(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_1(WHAT, __VA_ARGS__)) -#define EN_3(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_2(WHAT, __VA_ARGS__)) -#define EN_4(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_3(WHAT, __VA_ARGS__)) -#define EN_5(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_4(WHAT, __VA_ARGS__)) -#define EN_6(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_5(WHAT, __VA_ARGS__)) -#define EN_7(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_6(WHAT, __VA_ARGS__)) -#define EN_8(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_7(WHAT, __VA_ARGS__)) -#define EN_9(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_8(WHAT, __VA_ARGS__)) -#define EN_10(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_9(WHAT, __VA_ARGS__)) -#define EN_11(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_10(WHAT, __VA_ARGS__)) -#define EN_12(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_11(WHAT, __VA_ARGS__)) -#define EN_13(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_12(WHAT, __VA_ARGS__)) -#define EN_14(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_13(WHAT, __VA_ARGS__)) -#define EN_15(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_14(WHAT, __VA_ARGS__)) -#define EN_16(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_15(WHAT, __VA_ARGS__)) -#define EN_17(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_16(WHAT, __VA_ARGS__)) -#define EN_18(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_17(WHAT, __VA_ARGS__)) -#define EN_19(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_18(WHAT, __VA_ARGS__)) -#define EN_20(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_19(WHAT, __VA_ARGS__)) -#define EN_21(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_20(WHAT, __VA_ARGS__)) -#define EN_22(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_21(WHAT, __VA_ARGS__)) -#define EN_23(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_22(WHAT, __VA_ARGS__)) -#define EN_24(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_23(WHAT, __VA_ARGS__)) -#define EN_25(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_24(WHAT, __VA_ARGS__)) -#define EN_26(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_25(WHAT, __VA_ARGS__)) -#define EN_27(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_26(WHAT, __VA_ARGS__)) -#define EN_28(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_27(WHAT, __VA_ARGS__)) -#define EN_29(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_28(WHAT, __VA_ARGS__)) -#define EN_30(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_29(WHAT, __VA_ARGS__)) -#define EN_31(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_30(WHAT, __VA_ARGS__)) -#define EN_32(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_31(WHAT, __VA_ARGS__)) -#define EN_33(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_32(WHAT, __VA_ARGS__)) -#define EN_34(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_33(WHAT, __VA_ARGS__)) -#define EN_35(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_34(WHAT, __VA_ARGS__)) -#define EN_36(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_35(WHAT, __VA_ARGS__)) -#define EN_37(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_36(WHAT, __VA_ARGS__)) -#define EN_38(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_37(WHAT, __VA_ARGS__)) -#define EN_39(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_38(WHAT, __VA_ARGS__)) -#define EN_40(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_39(WHAT, __VA_ARGS__)) -#define EN_41(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_40(WHAT, __VA_ARGS__)) -#define EN_42(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_41(WHAT, __VA_ARGS__)) -#define EN_43(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_42(WHAT, __VA_ARGS__)) -#define EN_44(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_43(WHAT, __VA_ARGS__)) -#define EN_45(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_44(WHAT, __VA_ARGS__)) -#define EN_46(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_45(WHAT, __VA_ARGS__)) -#define EN_47(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_46(WHAT, __VA_ARGS__)) -#define EN_48(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_47(WHAT, __VA_ARGS__)) -#define EN_49(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_48(WHAT, __VA_ARGS__)) -#define EN_50(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_49(WHAT, __VA_ARGS__)) -#define EN_51(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_50(WHAT, __VA_ARGS__)) -#define EN_52(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_51(WHAT, __VA_ARGS__)) -#define EN_53(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_52(WHAT, __VA_ARGS__)) -#define EN_54(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_53(WHAT, __VA_ARGS__)) -#define EN_55(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_54(WHAT, __VA_ARGS__)) -#define EN_56(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_55(WHAT, __VA_ARGS__)) -#define EN_57(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_56(WHAT, __VA_ARGS__)) -#define EN_58(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_57(WHAT, __VA_ARGS__)) -#define EN_59(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_58(WHAT, __VA_ARGS__)) -#define EN_60(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_59(WHAT, __VA_ARGS__)) -#define EN_61(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_60(WHAT, __VA_ARGS__)) -#define EN_62(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_61(WHAT, __VA_ARGS__)) -#define EN_63(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_62(WHAT, __VA_ARGS__)) -#define EN_64(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_63(WHAT, __VA_ARGS__)) -#define EN_65(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_64(WHAT, __VA_ARGS__)) -#define EN_66(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_65(WHAT, __VA_ARGS__)) -#define EN_67(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_66(WHAT, __VA_ARGS__)) -#define EN_68(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_67(WHAT, __VA_ARGS__)) -#define EN_69(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_68(WHAT, __VA_ARGS__)) -#define EN_70(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_69(WHAT, __VA_ARGS__)) -#define EN_71(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_70(WHAT, __VA_ARGS__)) -#define EN_72(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_71(WHAT, __VA_ARGS__)) -#define EN_73(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_72(WHAT, __VA_ARGS__)) -#define EN_74(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_73(WHAT, __VA_ARGS__)) -#define EN_75(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_74(WHAT, __VA_ARGS__)) -#define EN_76(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_75(WHAT, __VA_ARGS__)) -#define EN_77(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_76(WHAT, __VA_ARGS__)) -#define EN_78(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_77(WHAT, __VA_ARGS__)) -#define EN_79(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_78(WHAT, __VA_ARGS__)) -#define EN_80(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_79(WHAT, __VA_ARGS__)) -#define EN_81(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_80(WHAT, __VA_ARGS__)) -#define EN_82(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_81(WHAT, __VA_ARGS__)) -#define EN_83(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_82(WHAT, __VA_ARGS__)) -#define EN_84(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_83(WHAT, __VA_ARGS__)) -#define EN_85(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_84(WHAT, __VA_ARGS__)) -#define EN_86(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_85(WHAT, __VA_ARGS__)) -#define EN_87(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_86(WHAT, __VA_ARGS__)) -#define EN_88(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_87(WHAT, __VA_ARGS__)) -#define EN_89(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_88(WHAT, __VA_ARGS__)) -#define EN_90(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_89(WHAT, __VA_ARGS__)) -#define EN_91(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_90(WHAT, __VA_ARGS__)) -#define EN_92(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_91(WHAT, __VA_ARGS__)) -#define EN_93(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_92(WHAT, __VA_ARGS__)) -#define EN_94(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_93(WHAT, __VA_ARGS__)) -#define EN_95(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_94(WHAT, __VA_ARGS__)) -#define EN_96(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_95(WHAT, __VA_ARGS__)) -#define EN_97(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_96(WHAT, __VA_ARGS__)) -#define EN_98(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_97(WHAT, __VA_ARGS__)) -#define EN_99(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_98(WHAT, __VA_ARGS__)) +#define EN_2(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_1(WHAT, __VA_ARGS__)) +#define EN_3(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_2(WHAT, __VA_ARGS__)) +#define EN_4(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_3(WHAT, __VA_ARGS__)) +#define EN_5(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_4(WHAT, __VA_ARGS__)) +#define EN_6(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_5(WHAT, __VA_ARGS__)) +#define EN_7(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_6(WHAT, __VA_ARGS__)) +#define EN_8(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_7(WHAT, __VA_ARGS__)) +#define EN_9(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_8(WHAT, __VA_ARGS__)) +#define EN_10(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_9(WHAT, __VA_ARGS__)) +#define EN_11(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_10(WHAT, __VA_ARGS__)) +#define EN_12(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_11(WHAT, __VA_ARGS__)) +#define EN_13(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_12(WHAT, __VA_ARGS__)) +#define EN_14(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_13(WHAT, __VA_ARGS__)) +#define EN_15(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_14(WHAT, __VA_ARGS__)) +#define EN_16(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_15(WHAT, __VA_ARGS__)) +#define EN_17(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_16(WHAT, __VA_ARGS__)) +#define EN_18(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_17(WHAT, __VA_ARGS__)) +#define EN_19(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_18(WHAT, __VA_ARGS__)) +#define EN_20(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_19(WHAT, __VA_ARGS__)) +#define EN_21(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_20(WHAT, __VA_ARGS__)) +#define EN_22(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_21(WHAT, __VA_ARGS__)) +#define EN_23(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_22(WHAT, __VA_ARGS__)) +#define EN_24(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_23(WHAT, __VA_ARGS__)) +#define EN_25(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_24(WHAT, __VA_ARGS__)) +#define EN_26(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_25(WHAT, __VA_ARGS__)) +#define EN_27(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_26(WHAT, __VA_ARGS__)) +#define EN_28(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_27(WHAT, __VA_ARGS__)) +#define EN_29(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_28(WHAT, __VA_ARGS__)) +#define EN_30(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_29(WHAT, __VA_ARGS__)) +#define EN_31(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_30(WHAT, __VA_ARGS__)) +#define EN_32(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_31(WHAT, __VA_ARGS__)) +#define EN_33(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_32(WHAT, __VA_ARGS__)) +#define EN_34(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_33(WHAT, __VA_ARGS__)) +#define EN_35(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_34(WHAT, __VA_ARGS__)) +#define EN_36(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_35(WHAT, __VA_ARGS__)) +#define EN_37(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_36(WHAT, __VA_ARGS__)) +#define EN_38(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_37(WHAT, __VA_ARGS__)) +#define EN_39(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_38(WHAT, __VA_ARGS__)) +#define EN_40(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_39(WHAT, __VA_ARGS__)) +#define EN_41(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_40(WHAT, __VA_ARGS__)) +#define EN_42(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_41(WHAT, __VA_ARGS__)) +#define EN_43(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_42(WHAT, __VA_ARGS__)) +#define EN_44(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_43(WHAT, __VA_ARGS__)) +#define EN_45(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_44(WHAT, __VA_ARGS__)) +#define EN_46(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_45(WHAT, __VA_ARGS__)) +#define EN_47(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_46(WHAT, __VA_ARGS__)) +#define EN_48(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_47(WHAT, __VA_ARGS__)) +#define EN_49(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_48(WHAT, __VA_ARGS__)) +#define EN_50(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_49(WHAT, __VA_ARGS__)) +#define EN_51(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_50(WHAT, __VA_ARGS__)) +#define EN_52(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_51(WHAT, __VA_ARGS__)) +#define EN_53(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_52(WHAT, __VA_ARGS__)) +#define EN_54(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_53(WHAT, __VA_ARGS__)) +#define EN_55(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_54(WHAT, __VA_ARGS__)) +#define EN_56(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_55(WHAT, __VA_ARGS__)) +#define EN_57(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_56(WHAT, __VA_ARGS__)) +#define EN_58(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_57(WHAT, __VA_ARGS__)) +#define EN_59(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_58(WHAT, __VA_ARGS__)) +#define EN_60(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_59(WHAT, __VA_ARGS__)) +#define EN_61(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_60(WHAT, __VA_ARGS__)) +#define EN_62(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_61(WHAT, __VA_ARGS__)) +#define EN_63(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_62(WHAT, __VA_ARGS__)) +#define EN_64(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_63(WHAT, __VA_ARGS__)) +#define EN_65(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_64(WHAT, __VA_ARGS__)) +#define EN_66(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_65(WHAT, __VA_ARGS__)) +#define EN_67(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_66(WHAT, __VA_ARGS__)) +#define EN_68(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_67(WHAT, __VA_ARGS__)) +#define EN_69(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_68(WHAT, __VA_ARGS__)) +#define EN_70(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_69(WHAT, __VA_ARGS__)) +#define EN_71(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_70(WHAT, __VA_ARGS__)) +#define EN_72(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_71(WHAT, __VA_ARGS__)) +#define EN_73(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_72(WHAT, __VA_ARGS__)) +#define EN_74(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_73(WHAT, __VA_ARGS__)) +#define EN_75(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_74(WHAT, __VA_ARGS__)) +#define EN_76(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_75(WHAT, __VA_ARGS__)) +#define EN_77(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_76(WHAT, __VA_ARGS__)) +#define EN_78(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_77(WHAT, __VA_ARGS__)) +#define EN_79(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_78(WHAT, __VA_ARGS__)) +#define EN_80(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_79(WHAT, __VA_ARGS__)) +#define EN_81(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_80(WHAT, __VA_ARGS__)) +#define EN_82(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_81(WHAT, __VA_ARGS__)) +#define EN_83(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_82(WHAT, __VA_ARGS__)) +#define EN_84(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_83(WHAT, __VA_ARGS__)) +#define EN_85(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_84(WHAT, __VA_ARGS__)) +#define EN_86(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_85(WHAT, __VA_ARGS__)) +#define EN_87(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_86(WHAT, __VA_ARGS__)) +#define EN_88(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_87(WHAT, __VA_ARGS__)) +#define EN_89(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_88(WHAT, __VA_ARGS__)) +#define EN_90(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_89(WHAT, __VA_ARGS__)) +#define EN_91(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_90(WHAT, __VA_ARGS__)) +#define EN_92(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_91(WHAT, __VA_ARGS__)) +#define EN_93(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_92(WHAT, __VA_ARGS__)) +#define EN_94(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_93(WHAT, __VA_ARGS__)) +#define EN_95(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_94(WHAT, __VA_ARGS__)) +#define EN_96(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_95(WHAT, __VA_ARGS__)) +#define EN_97(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_96(WHAT, __VA_ARGS__)) +#define EN_98(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_97(WHAT, __VA_ARGS__)) +#define EN_99(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_98(WHAT, __VA_ARGS__)) #define __EXPAND_ENUM(NUM, CLASS) CLASS = NUM, -#define _EXPAND_ENUM(OP_PAIR) EVALUATING_PASTE(__EXPAND, _ENUM(UNPAREN(OP_PAIR))) +#define _EXPAND_ENUM(OP_PAIR) \ + EVALUATING_PASTE(__EXPAND, _ENUM(UNPAREN(OP_PAIR))) -#define GET_MACROS_ENUM(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51, _52, _53, _54, _55, _56, _57, _58, _59, _60, _61, _62, _63, _64, _65, _66, _67, _68, _69, _70, _71, _72, _73, _74, _75, _76, _77, _78, _79, _80, _81, _82, _83, _84, _85, _86, _87, _88, _89, _90, _91, _92, _93, _94, _95, _96, _97, _98, _99, NAME,...) NAME +#define GET_MACROS_ENUM( \ + _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, \ + _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, \ + _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, \ + _47, _48, _49, _50, _51, _52, _53, _54, _55, _56, _57, _58, _59, _60, _61, \ + _62, _63, _64, _65, _66, _67, _68, _69, _70, _71, _72, _73, _74, _75, _76, \ + _77, _78, _79, _80, _81, _82, _83, _84, _85, _86, _87, _88, _89, _90, _91, \ + _92, _93, _94, _95, _96, _97, _98, _99, NAME, ...) \ + NAME -#define FOR_EACH_ENUM(WHAT, ...) EXPAND(GET_MACROS_ENUM(__VA_ARGS__, EN_99, EN_98, EN_97, EN_96, EN_95, EN_94, EN_93, EN_92, EN_91, EN_90, EN_89, EN_88, EN_87, EN_86, EN_85, EN_84, EN_83, EN_82, EN_81, EN_80, EN_79, EN_78, EN_77, EN_76, EN_75, EN_74, EN_73, EN_72, EN_71, EN_70, EN_69, EN_68, EN_67, EN_66, EN_65, EN_64, EN_63, EN_62, EN_61, EN_60, EN_59, EN_58, EN_57, EN_56, EN_55, EN_54, EN_53, EN_52, EN_51, EN_50, EN_49, EN_48, EN_47, EN_46, EN_45, EN_44, EN_43, EN_42, EN_41, EN_40, EN_39, EN_38, EN_37, EN_36, EN_35, EN_34, EN_33, EN_32, EN_31, EN_30, EN_29, EN_28, EN_27, EN_26, EN_25, EN_24, EN_23, EN_22, EN_21, EN_20, EN_19, EN_18, EN_17, EN_16, EN_15, EN_14, EN_13, EN_12, EN_11, EN_10, EN_9, EN_8, EN_7, EN_6, EN_5, EN_4, EN_3, EN_2, EN_1)(WHAT, __VA_ARGS__)) +#define FOR_EACH_ENUM(WHAT, ...) \ + EXPAND(GET_MACROS_ENUM( \ + __VA_ARGS__, EN_99, EN_98, EN_97, EN_96, EN_95, EN_94, EN_93, EN_92, \ + EN_91, EN_90, EN_89, EN_88, EN_87, EN_86, EN_85, EN_84, EN_83, EN_82, \ + EN_81, EN_80, EN_79, EN_78, EN_77, EN_76, EN_75, EN_74, EN_73, EN_72, \ + EN_71, EN_70, EN_69, EN_68, EN_67, EN_66, EN_65, EN_64, EN_63, EN_62, \ + EN_61, EN_60, EN_59, EN_58, EN_57, EN_56, EN_55, EN_54, EN_53, EN_52, \ + EN_51, EN_50, EN_49, EN_48, EN_47, EN_46, EN_45, EN_44, EN_43, EN_42, \ + EN_41, EN_40, EN_39, EN_38, EN_37, EN_36, EN_35, EN_34, EN_33, EN_32, \ + EN_31, EN_30, EN_29, EN_28, EN_27, EN_26, EN_25, EN_24, EN_23, EN_22, \ + EN_21, EN_20, EN_19, EN_18, EN_17, EN_16, EN_15, EN_14, EN_13, EN_12, \ + EN_11, EN_10, EN_9, EN_8, EN_7, EN_6, EN_5, EN_4, EN_3, EN_2, \ + EN_1)(WHAT, __VA_ARGS__)) #define _EXEC_ENUM(WHAT, ...) EVAL(FOR_EACH_ENUM(WHAT, __VA_ARGS__)) #define EXEC_ENUMERATOR(...) EXPAND(_EXEC_ENUM(_EXPAND_ENUM, __VA_ARGS__)) diff --git a/libnd4j/include/system/msvc.h b/libnd4j/include/system/msvc.h index c884736f3ec7..f14423a8572c 100644 --- a/libnd4j/include/system/msvc.h +++ b/libnd4j/include/system/msvc.h @@ -23,17 +23,17 @@ #if defined(_MSC_VER) -#pragma warning( disable : 4244 ) -#pragma warning( disable : 4267 ) -#pragma warning( disable : 4251 ) -#pragma warning( disable : 4101 ) -#pragma warning( disable : 4305 ) -#pragma warning( disable : 4309 ) -#pragma warning( disable : 4333 ) -#pragma warning( disable : 4146 ) -#pragma warning( disable : 4018 ) -#pragma warning( disable : 4297 ) +#pragma warning(disable : 4244) +#pragma warning(disable : 4267) +#pragma warning(disable : 4251) +#pragma warning(disable : 4101) +#pragma warning(disable : 4305) +#pragma warning(disable : 4309) +#pragma warning(disable : 4333) +#pragma warning(disable : 4146) +#pragma warning(disable : 4018) +#pragma warning(disable : 4297) #endif -#endif //DEV_TESTS_MSVC_H +#endif // SD_MSVC_H diff --git a/libnd4j/include/system/nd4jmalloc.h b/libnd4j/include/system/nd4jmalloc.h index c808aad090f7..1ee40a021302 100644 --- a/libnd4j/include/system/nd4jmalloc.h +++ b/libnd4j/include/system/nd4jmalloc.h @@ -21,4 +21,4 @@ #ifndef NATIVEOPERATIONS_ND4JMALLOC_H #define NATIVEOPERATIONS_ND4JMALLOC_H #include -#endif //NATIVEOPERATIONS_ND4JMALLOC_H +#endif // NATIVEOPERATIONS_ND4JMALLOC_H diff --git a/libnd4j/include/system/nd4jmemset.h b/libnd4j/include/system/nd4jmemset.h index 6482dcb8a2f7..7df35d5a0fef 100644 --- a/libnd4j/include/system/nd4jmemset.h +++ b/libnd4j/include/system/nd4jmemset.h @@ -21,4 +21,4 @@ #ifndef NATIVEOPERATIONS_ND4JSTRING_H #define NATIVEOPERATIONS_ND4JSTRING_H #include -#endif //NATIVEOPERATIONS_ND4JSTRING_H +#endif // NATIVEOPERATIONS_ND4JSTRING_H diff --git a/libnd4j/include/system/op_boilerplate.h b/libnd4j/include/system/op_boilerplate.h index 0c2630f22001..0860c0dd8063 100644 --- a/libnd4j/include/system/op_boilerplate.h +++ b/libnd4j/include/system/op_boilerplate.h @@ -1231,12 +1231,12 @@ #define REQUIRE_OK(A) if (sd::ops::resultHelper( (A), #A, __FILE__, __LINE__ ) != 0) return ND4J_STATUS_VALIDATION; #define REQUIRE_TRUE(COND, ...) if (!(COND)) { if (sd::ops::conditionHelper(__FILE__, __LINE__, COND, __VA_ARGS__) != 0) throw std::invalid_argument("Op validation failed");}; -#define DECLARE_ENTRY(NAME, ...) template struct ND4J_EXPORT __registratorFloat>; \ - template struct ND4J_EXPORT __registratorHalf>; \ - template struct ND4J_EXPORT __registratorDouble>; \ - template struct ND4J_EXPORT __registratorSynonymHalf>; \ - template struct ND4J_EXPORT __registratorSynonymDouble>; \ - template struct ND4J_EXPORT __registratorSynonymFloat>; +#define DECLARE_ENTRY(NAME, ...) template struct SD_EXPORT __registratorFloat>; \ + template struct SD_EXPORT __registratorHalf>; \ + template struct SD_EXPORT __registratorDouble>; \ + template struct SD_EXPORT __registratorSynonymHalf>; \ + template struct SD_EXPORT __registratorSynonymDouble>; \ + template struct SD_EXPORT __registratorSynonymFloat>; #if defined(_MSC_VER) || defined(_WIN64) || defined(_WIN32) || defined(__CLION_IDE__) || defined(__VSCODE__) @@ -1255,7 +1255,7 @@ #define REGISTER_H(NAME) template \ struct __registrator_##NAME {\ __registrator_##NAME() {\ - OpName *ptr = new OpName(); \ + auto ptr = std::make_shared(); \ OpRegistrator::getInstance().registerOperation(ptr); \ }\ };\ @@ -1268,7 +1268,7 @@ #define REGISTER_C(NAME) template \ struct __registrator_##NAME {\ __registrator_##NAME() {\ - OpName *ptr = new OpName(); \ + auto ptr = std::make_shared(); \ OpRegistrator::getInstance().registerOperation(ptr); \ }\ };\ @@ -1277,7 +1277,7 @@ #define REGISTER_C(NAME) #endif -#define DECLARE_OP(NAME, NIN, NOUT, INPLACEABLE) class ND4J_EXPORT NAME: public sd::ops::DeclarableOp { \ +#define DECLARE_OP(NAME, NIN, NOUT, INPLACEABLE) class SD_EXPORT NAME: public sd::ops::DeclarableOp { \ public:\ NAME(); \ sd::ShapeList* calculateOutputShape(sd::ShapeList* inputShape, sd::graph::Context& block); \ @@ -1287,7 +1287,7 @@ };\ REGISTER_H(NAME) -#define DECLARE_BOOLEAN_OP(NAME, NIN, SCALAR) class ND4J_EXPORT NAME: public sd::ops::BooleanOp { \ +#define DECLARE_BOOLEAN_OP(NAME, NIN, SCALAR) class SD_EXPORT NAME: public sd::ops::BooleanOp { \ public:\ NAME(); \ protected: \ @@ -1300,7 +1300,7 @@ REGISTER_C(NAME) \ Nd4jStatus sd::ops::NAME::validateAndExecute(sd::graph::Context& block) -#define DECLARE_LIST_OP(NAME, NIN, NOUT, TARGS, IARGS) class ND4J_EXPORT NAME: public sd::ops::DeclarableListOp { \ +#define DECLARE_LIST_OP(NAME, NIN, NOUT, TARGS, IARGS) class SD_EXPORT NAME: public sd::ops::DeclarableListOp { \ public:\ NAME(); \ protected: \ @@ -1312,7 +1312,7 @@ REGISTER_C(NAME) \ Nd4jStatus sd::ops::NAME::validateAndExecute(sd::graph::Context& block) -#define DECLARE_LOGIC_OP(NAME) class ND4J_EXPORT NAME: public sd::ops::LogicOp { \ +#define DECLARE_LOGIC_OP(NAME) class SD_EXPORT NAME: public sd::ops::LogicOp { \ public:\ NAME(); \ protected: \ @@ -1343,7 +1343,7 @@ #define DECLARE_SYN(NAME, ORIGINAL) template \ struct __registratorSynonym_##NAME {\ __registratorSynonym_##NAME(const char *name, const char *oname) {\ - auto ptr = reinterpret_cast(OpRegistrator::getInstance().getOperation(oname)); \ + auto ptr = OpRegistrator::getInstance().getOperation(oname); \ if (ptr == nullptr) { \ std::string newName(name); \ std::string oldName(oname); \ @@ -1355,7 +1355,7 @@ };\ static sd::ops::__registratorSynonym_##NAME zzz_register_opd_##NAME(#NAME, #ORIGINAL) -#define DECLARE_DIVERGENT_OP(NAME, NIN, NOUT, INPLACEABLE) class ND4J_EXPORT NAME: public sd::ops::DeclarableOp { \ +#define DECLARE_DIVERGENT_OP(NAME, NIN, NOUT, INPLACEABLE) class SD_EXPORT NAME: public sd::ops::DeclarableOp { \ public:\ NAME(); \ sd::ShapeList* calculateOutputShape(sd::ShapeList* inputShape, sd::graph::Context& block); \ @@ -1378,7 +1378,7 @@ } \ Nd4jStatus sd::ops::NAME::validateAndExecute(sd::graph::Context& block) -#define DECLARE_CONFIGURABLE_OP(NAME, NIN, NOUT, INPLACEABLE, TARGS, IARGS) class ND4J_EXPORT NAME: public sd::ops::DeclarableOp { \ +#define DECLARE_CONFIGURABLE_OP(NAME, NIN, NOUT, INPLACEABLE, TARGS, IARGS) class SD_EXPORT NAME: public sd::ops::DeclarableOp { \ public:\ NAME(); \ sd::ShapeList* calculateOutputShape(sd::ShapeList* inputShape, sd::graph::Context& block); \ @@ -1401,7 +1401,7 @@ } \ Nd4jStatus sd::ops::NAME::validateAndExecute(Context& block) -#define DECLARE_REDUCTION_OP(NAME, NIN, NOUT, INPLACEABLE, TARGS, IARGS) class ND4J_EXPORT NAME: public sd::ops::DeclarableReductionOp { \ +#define DECLARE_REDUCTION_OP(NAME, NIN, NOUT, INPLACEABLE, TARGS, IARGS) class SD_EXPORT NAME: public sd::ops::DeclarableReductionOp { \ public:\ NAME(); \ protected: \ @@ -1415,7 +1415,7 @@ Nd4jStatus sd::ops::NAME::validateAndExecute(sd::graph::Context& block) -#define DECLARE_CUSTOM_OP(NAME, NIN, NOUT, INPLACEABLE, TARGS, IARGS) class ND4J_EXPORT NAME: public sd::ops::DeclarableCustomOp { \ +#define DECLARE_CUSTOM_OP(NAME, NIN, NOUT, INPLACEABLE, TARGS, IARGS) class SD_EXPORT NAME: public sd::ops::DeclarableCustomOp { \ protected: \ void registerTypes(); \ Nd4jStatus validateAndExecute(Context& block); \ @@ -1437,7 +1437,7 @@ #define DECLARE_TYPES(NAME) void sd::ops::NAME::registerTypes() -#define DECLARE_BROADCASTABLE_OP(NAME,TARGS, IARGS) class ND4J_EXPORT NAME: public sd::ops::BroadcastableOp { \ +#define DECLARE_BROADCASTABLE_OP(NAME,TARGS, IARGS) class SD_EXPORT NAME: public sd::ops::BroadcastableOp { \ protected: \ void registerTypes(); \ Nd4jStatus validateAndExecute(Context& block); \ @@ -1446,7 +1446,7 @@ };\ REGISTER_H(NAME) -#define DECLARE_BROADCASTABLE_BOOL_OP(NAME,TARGS, IARGS) class ND4J_EXPORT NAME: public sd::ops::BroadcastableBoolOp { \ +#define DECLARE_BROADCASTABLE_BOOL_OP(NAME,TARGS, IARGS) class SD_EXPORT NAME: public sd::ops::BroadcastableBoolOp { \ protected: \ void registerTypes(); \ Nd4jStatus validateAndExecute(Context& block); \ @@ -1525,20 +1525,20 @@ #define CHECK_STASH(NAME) block.getStash()->checkStash(block.getNodeId(), NAME); #define UNSTASH(NAME) block.getStash()->extractArray(block.getNodeId(), NAME); -#define INPUT_VARIABLE(INDEX) block.array(INDEX) +#define INPUT_VARIABLE(INDEX) block.arrayForOp(INDEX) #define OUTPUT_VARIABLE(INDEX) reinterpret_cast(this->getZ(block, INDEX)) #define OUTPUT_NULLIFIED(INDEX) reinterpret_cast(this->getNullifiedZ(block, INDEX)) -#define INPUT_LIST(INDEX) reinterpret_cast(block.getVariable(INDEX)->getNDArrayList()) +#define INPUT_LIST(INDEX) reinterpret_cast(block.getVariable(INDEX)->getNDArrayList().get()) -#define D_ARG(INDEX) block.getDArguments()->at(INDEX) -#define INT_ARG(INDEX) block.getIArguments()->at(INDEX) +#define D_ARG(INDEX) block.getDArguments().at(INDEX) +#define INT_ARG(INDEX) block.getIArguments().at(INDEX) #define I_ARG(INDEX) INT_ARG(INDEX) -#define T_ARG(INDEX) block.getTArguments()->at(INDEX) -#define B_ARG(INDEX) block.getBArguments()->at(INDEX) +#define T_ARG(INDEX) block.getTArguments().at(INDEX) +#define B_ARG(INDEX) block.getBArguments().at(INDEX) -#define COPY_SHAPE(SRC, TGT) TGT = ShapeBuilders::copyShapeInfo(SRC, true, block.getWorkspace()) +#define COPY_SHAPE(SRC, TGT) TGT = ShapeBuilders::copyShapeInfo(SRC, true, block.workspace()) #define COPY_SHAPE_EX(SRC, TGT, WORKSPACE) TGT = ShapeBuilders::copyShapeInfo(SRC, true, WORKSPACE) diff --git a/libnd4j/include/system/op_enums.h b/libnd4j/include/system/op_enums.h index ad16d281e56b..5a1bb3217473 100644 --- a/libnd4j/include/system/op_enums.h +++ b/libnd4j/include/system/op_enums.h @@ -18,136 +18,91 @@ // @author raver119@gmail.com // - #ifndef LIBND4J_OP_ENUMS_H #define LIBND4J_OP_ENUMS_H #include -#include #include +#include namespace sd { - namespace random { - enum Ops { - BUILD_ENUMERATION(RANDOM_OPS) - }; - } - - namespace transform { - enum FloatOps { - BUILD_ENUMERATION(TRANSFORM_FLOAT_OPS) - }; - - enum SameOps { - BUILD_ENUMERATION(TRANSFORM_SAME_OPS) - }; - - enum BoolOps { - BUILD_ENUMERATION(TRANSFORM_BOOL_OPS) - }; - - enum AnyOps { - BUILD_ENUMERATION(TRANSFORM_ANY_OPS) - }; - - enum StrictOps { - BUILD_ENUMERATION(TRANSFORM_STRICT_OPS) - }; - } - - namespace pairwise { - enum Ops { - BUILD_ENUMERATION(PAIRWISE_TRANSFORM_OPS) - }; - - enum BoolOps { - BUILD_ENUMERATION(PAIRWISE_BOOL_OPS) - }; - - enum IntOps { - BUILD_ENUMERATION(PAIRWISE_INT_OPS) - }; - } - - namespace scalar { - enum Ops { - BUILD_ENUMERATION(SCALAR_OPS) - }; - - enum BoolOps { - BUILD_ENUMERATION(SCALAR_BOOL_OPS) - }; - - enum IntOps { - BUILD_ENUMERATION(SCALAR_INT_OPS) - }; - } - - namespace reduce { - enum FloatOps { - BUILD_ENUMERATION(REDUCE_FLOAT_OPS) - }; - - enum SameOps { - BUILD_ENUMERATION(REDUCE_SAME_OPS) - }; - - enum BoolOps { - BUILD_ENUMERATION(REDUCE_BOOL_OPS) - }; - - enum LongOps { - BUILD_ENUMERATION(REDUCE_LONG_OPS) - }; - } - - namespace reduce3 { - enum Ops { - BUILD_ENUMERATION(REDUCE3_OPS) - }; - } - - namespace indexreduce { - enum Ops { - BUILD_ENUMERATION(INDEX_REDUCE_OPS) - }; - } - - namespace broadcast { - enum Ops { - BUILD_ENUMERATION(BROADCAST_OPS) - }; - - enum BoolOps { - BUILD_ENUMERATION(BROADCAST_BOOL_OPS) - }; - - enum IntOps { - BUILD_ENUMERATION(BROADCAST_INT_OPS) - }; - } - - namespace variance { - enum Ops { - BUILD_ENUMERATION(SUMMARY_STATS_OPS) - }; - } - - namespace logic { - enum Ops { - While = 0, - Scope = 10, - Conditional = 20, - Switch = 30, - Return = 40, - Expose = 50, - Merge = 60, - LoopCond = 70, - NextIteration = 80, - Exit = 90, - Enter = 100, - }; - } +namespace random { +enum Ops { BUILD_ENUMERATION(RANDOM_OPS) }; +} + +namespace transform { +enum FloatOps { BUILD_ENUMERATION(TRANSFORM_FLOAT_OPS) }; + +enum SameOps { BUILD_ENUMERATION(TRANSFORM_SAME_OPS) }; + +enum BoolOps { BUILD_ENUMERATION(TRANSFORM_BOOL_OPS) }; + +enum AnyOps { BUILD_ENUMERATION(TRANSFORM_ANY_OPS) }; + +enum StrictOps { BUILD_ENUMERATION(TRANSFORM_STRICT_OPS) }; +} // namespace transform + +namespace pairwise { +enum Ops { BUILD_ENUMERATION(PAIRWISE_TRANSFORM_OPS) }; + +enum BoolOps { BUILD_ENUMERATION(PAIRWISE_BOOL_OPS) }; + +enum IntOps { BUILD_ENUMERATION(PAIRWISE_INT_OPS) }; +} // namespace pairwise + +namespace scalar { +enum Ops { BUILD_ENUMERATION(SCALAR_OPS) }; + +enum BoolOps { BUILD_ENUMERATION(SCALAR_BOOL_OPS) }; + +enum IntOps { BUILD_ENUMERATION(SCALAR_INT_OPS) }; +} // namespace scalar + +namespace reduce { +enum FloatOps { BUILD_ENUMERATION(REDUCE_FLOAT_OPS) }; + +enum SameOps { BUILD_ENUMERATION(REDUCE_SAME_OPS) }; + +enum BoolOps { BUILD_ENUMERATION(REDUCE_BOOL_OPS) }; + +enum LongOps { BUILD_ENUMERATION(REDUCE_LONG_OPS) }; +} // namespace reduce + +namespace reduce3 { +enum Ops { BUILD_ENUMERATION(REDUCE3_OPS) }; +} + +namespace indexreduce { +enum Ops { BUILD_ENUMERATION(INDEX_REDUCE_OPS) }; +} + +namespace broadcast { +enum Ops { BUILD_ENUMERATION(BROADCAST_OPS) }; + +enum BoolOps { BUILD_ENUMERATION(BROADCAST_BOOL_OPS) }; + +enum IntOps { BUILD_ENUMERATION(BROADCAST_INT_OPS) }; +} // namespace broadcast + +namespace variance { +enum Ops { BUILD_ENUMERATION(SUMMARY_STATS_OPS) }; +} + +namespace logic { +enum Ops { + While = 0, + Scope = 10, + Conditional = 20, + Switch = 30, + Return = 40, + Expose = 50, + Merge = 60, + LoopCond = 70, + NextIteration = 80, + Exit = 90, + Enter = 100, +}; } +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/system/openmp_pragmas.h b/libnd4j/include/system/openmp_pragmas.h index 667f54521d31..14a42241e6c6 100644 --- a/libnd4j/include/system/openmp_pragmas.h +++ b/libnd4j/include/system/openmp_pragmas.h @@ -18,8 +18,8 @@ // @author raver119@gmail.com // -#ifndef DEV_TESTS_OPENMP_PRAGMAS_H -#define DEV_TESTS_OPENMP_PRAGMAS_H +#ifndef SD_OPENMP_PRAGMAS_H +#define SD_OPENMP_PRAGMAS_H #if defined(_MSC_VER) @@ -61,9 +61,8 @@ #else - #define OMP_STRINGIFY(args) #args -#define OMP_IF(args) if(args) +#define OMP_IF(args) if (args) #define OMP_SCHEDULE(args) schedule(args) #define OMP_MAXT maxT #define OMP_SUMT sumT @@ -73,13 +72,21 @@ #define PRAGMA_OMP_CRITICAL _Pragma(OMP_STRINGIFY(omp critical)) #define PRAGMA_OMP_SIMD _Pragma(OMP_STRINGIFY(omp simd)) #define PRAGMA_OMP_SIMD_ARGS(args) _Pragma(OMP_STRINGIFY(omp simd args)) -#define PRAGMA_OMP_SIMD_SUM(args) _Pragma(OMP_STRINGIFY(omp simd reduction(sumT:args))) -#define PRAGMA_OMP_SIMD_MAX(args) _Pragma(OMP_STRINGIFY(omp simd reduction(maxTF:args))) +#define PRAGMA_OMP_SIMD_SUM(args) \ + _Pragma(OMP_STRINGIFY(omp simd reduction(sumT : args))) +#define PRAGMA_OMP_SIMD_MAX(args) \ + _Pragma(OMP_STRINGIFY(omp simd reduction(maxTF : args))) #define PRAGMA_OMP_PARALLEL _Pragma(OMP_STRINGIFY(omp parallel default(shared))) -#define PRAGMA_OMP_PARALLEL_REDUCTION(args) _Pragma(OMP_STRINGIFY(omp parallel reduction(args) default(shared))) -#define PRAGMA_OMP_PARALLEL_ARGS(args) _Pragma(OMP_STRINGIFY(omp parallel args default(shared))) -#define PRAGMA_OMP_PARALLEL_THREADS(args) _Pragma(OMP_STRINGIFY(omp parallel num_threads(args) if(args > 1) default(shared))) -#define PRAGMA_OMP_PARALLEL_THREADS_IF(threads, condition) _Pragma(OMP_STRINGIFY(omp parallel num_threads(threads) if(condition) default(shared))) +#define PRAGMA_OMP_PARALLEL_REDUCTION(args) \ + _Pragma(OMP_STRINGIFY(omp parallel reduction(args) default(shared))) +#define PRAGMA_OMP_PARALLEL_ARGS(args) \ + _Pragma(OMP_STRINGIFY(omp parallel args default(shared))) +#define PRAGMA_OMP_PARALLEL_THREADS(args) \ + _Pragma(OMP_STRINGIFY( \ + omp parallel num_threads(args) if (args > 1) default(shared))) +#define PRAGMA_OMP_PARALLEL_THREADS_IF(threads, condition) \ + _Pragma(OMP_STRINGIFY( \ + omp parallel num_threads(threads) if (condition) default(shared))) #define PRAGMA_OMP_PARALLEL_FOR _Pragma(OMP_STRINGIFY(omp parallel for default(shared))) #define PRAGMA_OMP_PARALLEL_FOR_REDUCTION(args) _Pragma(OMP_STRINGIFY(omp parallel for reduction(args) default(shared))) #define PRAGMA_OMP_PARALLEL_FOR_ARGS(args) _Pragma(OMP_STRINGIFY(omp parallel for args default(shared))) @@ -92,7 +99,8 @@ #define PRAGMA_OMP_PARALLEL_FOR_SIMD_COLLAPSE(loops) _Pragma(OMP_STRINGIFY(omp parallel for simd default(shared) collapse(loops))) #define PRAGMA_OMP_PARALLEL_FOR_SIMD_REDUCTION(args) _Pragma(OMP_STRINGIFY(omp parallel for simd reduction(args) default(shared))) #define PRAGMA_OMP_PARALLEL_FOR_SIMD_THREADS(args) _Pragma(OMP_STRINGIFY(omp parallel for simd num_threads(args) if(args > 1) default(shared))) -#define PRAGMA_OMP_PARALLEL_SECTIONS _Pragma(OMP_STRINGIFY(omp parallel sections)) +#define PRAGMA_OMP_PARALLEL_SECTIONS \ + _Pragma(OMP_STRINGIFY(omp parallel sections)) #define PRAGMA_OMP_SECTION _Pragma(OMP_STRINGIFY(omp section)) #define PRAGMA_OMP_SINGLE _Pragma(OMP_STRINGIFY(omp single)) #define PRAGMA_OMP_SINGLE_ARGS(args) _Pragma(OMP_STRINGIFY(omp single args)) @@ -113,26 +121,43 @@ // parallel_for block #define FUNC_1D std::function -#define FUNC_2D std::function -#define FUNC_3D std::function +#define FUNC_2D \ + std::function +#define FUNC_3D \ + std::function // aggregation lambda #define LAMBDA_AL [&](int64_t _old, int64_t _new) -> int64_t #define LAMBDA_AD [&](double _old, double _new) -> double -#define LAMBDA_SUML LAMBDA_AL {return _old + _new; } -#define LAMBDA_SUMD LAMBDA_AD {return _old + _new; } +#define LAMBDA_SUML \ + LAMBDA_AL { return _old + _new; } +#define LAMBDA_SUMD \ + LAMBDA_AD { return _old + _new; } // reduction lambda -#define PRAGMA_REDUCE_LONG [&] (uint64_t thread_id, int64_t start, int64_t stop, int64_t increment) mutable -> int64_t -#define PRAGMA_REDUCE_DOUBLE [&] (uint64_t thread_id, int64_t start, int64_t stop, int64_t increment) mutable -> double +#define PRAGMA_REDUCE_LONG \ + [&](uint64_t thread_id, int64_t start, int64_t stop, \ + int64_t increment) mutable -> int64_t +#define PRAGMA_REDUCE_DOUBLE \ + [&](uint64_t thread_id, int64_t start, int64_t stop, \ + int64_t increment) mutable -> double // paralllel block lambda -#define PRAGMA_THREADS_DO [&](uint64_t thread_id, uint64_t numThreads) -> void +#define PRAGMA_THREADS_DO [&](uint64_t thread_id, uint64_t numThreads) -> void // paralllel_for lambdas -#define PRAGMA_THREADS_FOR [&](uint64_t thread_id, int64_t start, int64_t stop, int64_t increment) -> void -#define PRAGMA_THREADS_FOR_2D [&](uint64_t thread_id, int64_t start_x, int64_t stop_x, int64_t inc_x, int64_t start_y, int64_t stop_y, int64_t inc_y) -> void -#define PRAGMA_THREADS_FOR_3D [&](uint64_t thread_id, int64_t start_x, int64_t stop_x, int64_t inc_x, int64_t start_y, int64_t stop_y, int64_t inc_y, int64_t start_z, int64_t stop_z, int64_t inc_z) -> void - -#endif //DEV_TESTS_OPENMP_PRAGMAS_H +#define PRAGMA_THREADS_FOR \ + [&](uint64_t thread_id, int64_t start, int64_t stop, \ + int64_t increment) -> void +#define PRAGMA_THREADS_FOR_2D \ + [&](uint64_t thread_id, int64_t start_x, int64_t stop_x, int64_t inc_x, \ + int64_t start_y, int64_t stop_y, int64_t inc_y) -> void +#define PRAGMA_THREADS_FOR_3D \ + [&](uint64_t thread_id, int64_t start_x, int64_t stop_x, int64_t inc_x, \ + int64_t start_y, int64_t stop_y, int64_t inc_y, int64_t start_z, \ + int64_t stop_z, int64_t inc_z) -> void + +#endif // SD_OPENMP_PRAGMAS_H diff --git a/libnd4j/include/system/optype.h b/libnd4j/include/system/optype.h index 70323f104c8b..23008db8de25 100644 --- a/libnd4j/include/system/optype.h +++ b/libnd4j/include/system/optype.h @@ -16,7 +16,7 @@ #pragma once -//TODO convert this into an enum class +// TODO convert this into an enum class // might break JNI though... typedef int OpType; @@ -25,9 +25,7 @@ typedef int OpType; #define constexpr #endif -namespace op_type -{ - constexpr OpType Variance = 0; - constexpr OpType StandardDeviation = 1; -} - +namespace op_type { +constexpr OpType Variance = 0; +constexpr OpType StandardDeviation = 1; +} // namespace op_type diff --git a/libnd4j/include/system/pairwise_util.h b/libnd4j/include/system/pairwise_util.h old mode 100755 new mode 100644 index d9e0965c89d0..b494690259a5 --- a/libnd4j/include/system/pairwise_util.h +++ b/libnd4j/include/system/pairwise_util.h @@ -33,16 +33,17 @@ #endif #include -#include -#include -#include #include #include +#include +#include + +#include #ifdef _OPENMP #include #endif -//Loops adapted from: -//https://github.com/numpy/numpy/blob/009b17a85a22707e63ac9ea1896413992bbf9ce5/numpy/core/src/private/lowlevel_strided_loops.h#L401-L401 +// Loops adapted from: +// https://github.com/numpy/numpy/blob/009b17a85a22707e63ac9ea1896413992bbf9ce5/numpy/core/src/private/lowlevel_strided_loops.h#L401-L401 /* namespace shape { @@ -60,16 +61,15 @@ namespace shape { ************************************************************/ typedef struct { - int perm, stride; + int perm, stride; } StridePermutation; - /** * Credit to: * http://alienryderflex.com/quicksort/ * - * The non recursive implementation is important for being able to run on cuda as well - * as host. + * The non recursive implementation is important for being able to run on cuda + * as well as host. * * In practice the work loads intended for * this won't be so hard @@ -81,108 +81,97 @@ typedef struct { #ifdef __CUDACC__ __host__ __device__ #endif -void quickSort(StridePermutation *arr, int elements); - + void + quickSort(StridePermutation *arr, int elements); /* Start raw iteration */ #define ND4J_RAW_ITER_START(idim, ndim, coord, shape) \ - memset((coord), 0, (ndim) * sizeof(coord[0])); \ - do { - + memset((coord), 0, (ndim) * sizeof(coord[0])); \ + do { /* Increment to the next n-dimensional coordinate for one raw array */ #define ND4J_RAW_ITER_ONE_NEXT(idim, ndim, coord, shape, data, strides) \ - for ((idim) = 0; (idim) < (ndim); (idim)++) { \ - if (++(coord)[idim] == (shape)[idim]) { \ - (coord)[idim] = 0; \ - (data) -= ((shape)[idim] - 1) * (strides)[idim]; \ - } \ - else { \ - (data) += (strides)[idim]; \ - break; \ - } \ - } \ - } while ((idim) < (ndim)) + for ((idim) = 0; (idim) < (ndim); (idim)++) { \ + if (++(coord)[idim] == (shape)[idim]) { \ + (coord)[idim] = 0; \ + (data) -= ((shape)[idim] - 1) * (strides)[idim]; \ + } else { \ + (data) += (strides)[idim]; \ + break; \ + } \ + } \ + } \ + while ((idim) < (ndim)) #define ND4J_RAW_ITER_ONE_NEXTF(idim, ndim, coord, shape, data, strides) \ - for ((idim) = ndim - 1; (idim) >= (0); (idim)--) { \ - if (++(coord)[idim] == (shape)[idim]) { \ - (coord)[idim] = 0; \ - (data) -= ((shape)[idim] - 1) * (strides)[idim]; \ - } \ - else { \ - (data) += (strides)[idim]; \ - break; \ - } \ - } \ - } while ((idim) >= (0)) - + for ((idim) = ndim - 1; (idim) >= (0); (idim)--) { \ + if (++(coord)[idim] == (shape)[idim]) { \ + (coord)[idim] = 0; \ + (data) -= ((shape)[idim] - 1) * (strides)[idim]; \ + } else { \ + (data) += (strides)[idim]; \ + break; \ + } \ + } \ + } \ + while ((idim) >= (0)) /* Increment to the next n-dimensional coordinate for two raw arrays */ -#define ND4J_RAW_ITER_TWO_NEXT(idim, ndim, coord, shape, \ - dataA, stridesA, dataB, stridesB) \ - for ((idim) = 0; (idim) < (ndim); (idim)++) { \ - if (++(coord)[idim] == (shape)[idim]) { \ - (coord)[idim] = 0; \ - (dataA) -= ((shape)[idim] - 1) * (stridesA)[idim]; \ - (dataB) -= ((shape)[idim] - 1) * (stridesB)[idim]; \ - } \ - else { \ - (dataA) += (stridesA)[idim]; \ - (dataB) += (stridesB)[idim]; \ - break; \ - } \ - } \ - } while ((idim) < (ndim)) +#define ND4J_RAW_ITER_TWO_NEXT(idim, ndim, coord, shape, dataA, stridesA, \ + dataB, stridesB) \ + for ((idim) = 0; (idim) < (ndim); (idim)++) { \ + if (++(coord)[idim] == (shape)[idim]) { \ + (coord)[idim] = 0; \ + (dataA) -= ((shape)[idim] - 1) * (stridesA)[idim]; \ + (dataB) -= ((shape)[idim] - 1) * (stridesB)[idim]; \ + } else { \ + (dataA) += (stridesA)[idim]; \ + (dataB) += (stridesB)[idim]; \ + break; \ + } \ + } \ + } \ + while ((idim) < (ndim)) /* Increment to the next n-dimensional coordinate for three raw arrays */ -#define ND4J_RAW_ITER_THREE_NEXT(idim, ndim, coord, shape, \ - dataA, stridesA, \ - dataB, stridesB, \ - dataC, stridesC) \ - for ((idim) = 0; (idim) < (ndim); (idim)++) { \ - if (++(coord)[idim] == (shape)[idim]) { \ - (coord)[idim] = 0; \ - (dataA) -= ((shape)[idim] - 1) * (stridesA)[idim]; \ - (dataB) -= ((shape)[idim] - 1) * (stridesB)[idim]; \ - (dataC) -= ((shape)[idim] - 1) * (stridesC)[idim]; \ - } \ - else { \ - (dataA) += (stridesA)[idim]; \ - (dataB) += (stridesB)[idim]; \ - (dataC) += (stridesC)[idim]; \ - break; \ - } \ - } \ - } while ((idim) < (ndim)) +#define ND4J_RAW_ITER_THREE_NEXT(idim, ndim, coord, shape, dataA, stridesA, \ + dataB, stridesB, dataC, stridesC) \ + for ((idim) = 0; (idim) < (ndim); (idim)++) { \ + if (++(coord)[idim] == (shape)[idim]) { \ + (coord)[idim] = 0; \ + (dataA) -= ((shape)[idim] - 1) * (stridesA)[idim]; \ + (dataB) -= ((shape)[idim] - 1) * (stridesB)[idim]; \ + (dataC) -= ((shape)[idim] - 1) * (stridesC)[idim]; \ + } else { \ + (dataA) += (stridesA)[idim]; \ + (dataB) += (stridesB)[idim]; \ + (dataC) += (stridesC)[idim]; \ + break; \ + } \ + } \ + } \ + while ((idim) < (ndim)) /* Increment to the next n-dimensional coordinate for four raw arrays */ -#define ND4J_RAW_ITER_FOUR_NEXT(idim, ndim, coord, shape, \ - dataA, stridesA, \ - dataB, stridesB, \ - dataC, stridesC, \ - dataD, stridesD) \ - for ((idim) = 0; (idim) < (ndim); (idim)++) { \ - if (++(coord)[idim] == (shape)[idim]) { \ - (coord)[idim] = 0; \ - (dataA) -= ((shape)[idim] - 1) * (stridesA)[idim]; \ - (dataB) -= ((shape)[idim] - 1) * (stridesB)[idim]; \ - (dataC) -= ((shape)[idim] - 1) * (stridesC)[idim]; \ - (dataD) -= ((shape)[idim] - 1) * (stridesD)[idim]; \ - } \ - else { \ - (dataA) += (stridesA)[idim]; \ - (dataB) += (stridesB)[idim]; \ - (dataC) += (stridesC)[idim]; \ - (dataD) += (stridesD)[idim]; \ - break; \ - } \ - } \ - } while ((idim) < (ndim)) - - - - - +#define ND4J_RAW_ITER_FOUR_NEXT(idim, ndim, coord, shape, dataA, stridesA, \ + dataB, stridesB, dataC, stridesC, dataD, \ + stridesD) \ + for ((idim) = 0; (idim) < (ndim); (idim)++) { \ + if (++(coord)[idim] == (shape)[idim]) { \ + (coord)[idim] = 0; \ + (dataA) -= ((shape)[idim] - 1) * (stridesA)[idim]; \ + (dataB) -= ((shape)[idim] - 1) * (stridesB)[idim]; \ + (dataC) -= ((shape)[idim] - 1) * (stridesC)[idim]; \ + (dataD) -= ((shape)[idim] - 1) * (stridesD)[idim]; \ + } else { \ + (dataA) += (stridesA)[idim]; \ + (dataB) += (stridesB)[idim]; \ + (dataC) += (stridesC)[idim]; \ + (dataD) += (stridesD)[idim]; \ + break; \ + } \ + } \ + } \ + while ((idim) < (ndim)) /*NUMPY_API * @@ -194,18 +183,18 @@ void quickSort(StridePermutation *arr, int elements); #ifdef __CUDACC__ __host__ __device__ #endif -inline void SortStrideArray(int ndim, int strides[], - StridePermutation *out_strideperm) { - - /* Set up the strideperm values */ - for (int i = 0; i < ndim; i++) { - out_strideperm[i].perm = i; - out_strideperm[i].stride = strides[i]; - } - - /* Sort them */ - quickSort(out_strideperm,ndim); - + inline void + SortStrideArray(int ndim, int strides[], + StridePermutation *out_strideperm) { + + /* Set up the strideperm values */ + for (int i = 0; i < ndim; i++) { + out_strideperm[i].perm = i; + out_strideperm[i].stride = strides[i]; + } + + /* Sort them */ + quickSort(out_strideperm, ndim); } /* @@ -224,21 +213,19 @@ inline void SortStrideArray(int ndim, int strides[], * Returns 0 on success, -1 on failure. */ template - #ifdef __CUDACC__ __host__ __device__ #endif -inline int PrepareOneRawArrayIter(int ndim, Nd4jLong shape[], - T data[], Nd4jLong strides[], - int *out_ndim, Nd4jLong outShape[], - T **out_data, Nd4jLong *outStrides) { - - for (int i = 0; i < ndim; i++) { - outShape[i] = shape[i]; - outStrides[i] = strides[i]; - } - + inline int + PrepareOneRawArrayIter(int ndim, Nd4jLong shape[], T data[], + Nd4jLong strides[], int *out_ndim, + Nd4jLong outShape[], T **out_data, + Nd4jLong *outStrides) { + for (int i = 0; i < ndim; i++) { + outShape[i] = shape[i]; + outStrides[i] = strides[i]; + } #if 0 /* DEBUG */ @@ -257,51 +244,47 @@ inline int PrepareOneRawArrayIter(int ndim, Nd4jLong shape[], } #endif - *out_data = data; - *out_ndim = ndim; - return 0; + *out_data = data; + *out_ndim = ndim; + return 0; } - class BlockInformation { -public: - Nd4jLong items; - int threads; - Nd4jLong chunks; - Nd4jLong modulo; - Nd4jLong remainder; - - BlockInformation(Nd4jLong length, int threshold) { - + public: + Nd4jLong items; + int threads; + Nd4jLong chunks; + Nd4jLong modulo; + Nd4jLong remainder; + + BlockInformation(Nd4jLong length, int threshold) { threads = length / threshold; - threads = (1 < threads)?threads:1;//sd::math::nd4j_max(1, threads); - threads = (threads < omp_get_max_threads())?threads:omp_get_max_threads();//sd::math::nd4j_min(threads, omp_get_max_threads()); + threads = + (1 < threads) ? threads : 1; // sd::math::nd4j_max(1, threads); + threads = (threads < omp_get_max_threads()) + ? threads + : omp_get_max_threads(); // sd::math::nd4j_min(threads, + // omp_get_max_threads()); items = length / threads; remainder = length % threads; - if(items < 1) - items = 1; + if (items < 1) items = 1; chunks = length / items; modulo = length % items; - //one left over chunk - if(modulo > 0) - chunks++; - } -}; - - -class CudaBlockInformation { - + // one left over chunk + if (modulo > 0) chunks++; + } }; +class CudaBlockInformation {}; /** * Credit to: * http://alienryderflex.com/quicksort/ * - * The non recursive implementation is important for being able to run on cuda as well - * as host. + * The non recursive implementation is important for being able to run on cuda + * as well as host. * * In practice the work loads intended for * this won't be so hard @@ -313,46 +296,42 @@ class CudaBlockInformation { #ifdef __CUDACC__ __host__ __device__ #endif -inline void quickSort(StridePermutation *arr, int elements) { -#define MAX_LEVELS 300 - - int beg[MAX_LEVELS], end[MAX_LEVELS], i= 0, L, R, swap ; - StridePermutation piv; - beg[0] = 0; - end[0] = elements; - while (i >= 0) { - L = beg[i]; - R= end[i] - 1; - if (L < R) { - piv = arr[L]; - while (L < R) { - while (arr[R].stride >= piv.stride && L < R) - R--; - if (L < R) - arr[L++] = arr[R]; - while (arr[L].stride <= piv.stride && L < R) - L++; - if (L end[i - 1] - beg[i - 1]) { - swap = beg[i]; - beg[i]= beg[i - 1]; - beg[i - 1] = swap; - swap = end[i]; - end[i] = end[i - 1]; - end[i - 1] = swap; - } - } - else { - i--; - } + inline void + quickSort(StridePermutation *arr, int elements) { +#define MAX_LEVELS 300 + + int beg[MAX_LEVELS], end[MAX_LEVELS], i = 0, L, R, swap; + StridePermutation piv; + beg[0] = 0; + end[0] = elements; + while (i >= 0) { + L = beg[i]; + R = end[i] - 1; + if (L < R) { + piv = arr[L]; + while (L < R) { + while (arr[R].stride >= piv.stride && L < R) R--; + if (L < R) arr[L++] = arr[R]; + while (arr[L].stride <= piv.stride && L < R) L++; + if (L < R) arr[R--] = arr[L]; + } + + arr[L] = piv; + beg[i + 1] = L + 1; + end[i + 1] = end[i]; + end[i++] = L; + if (end[i] - beg[i] > end[i - 1] - beg[i - 1]) { + swap = beg[i]; + beg[i] = beg[i - 1]; + beg[i - 1] = swap; + swap = end[i]; + end[i] = end[i - 1]; + end[i - 1] = swap; + } + } else { + i--; } + } } /** @@ -372,50 +351,48 @@ inline void quickSort(StridePermutation *arr, int elements) { * Returns 0 on success, -1 on failure. */ template -int _CUDA_HD PrepareTwoRawArrayIter(int ndim, Nd4jLong *shape, - X *dataA, Nd4jLong *stridesA, - Y *dataB, Nd4jLong *stridesB, - int *out_ndim, Nd4jLong *outShape, - X **out_dataA, Nd4jLong *outStridesA, - Y **out_dataB, Nd4jLong *outStridesB) { - int i; - -/* Sort the axes based on the destination strides */ - for (i = 0; i < ndim; ++i) { - outShape[i] = shape[i]; - outStridesA[i] = stridesA[i]; - outStridesB[i] = stridesB[i]; +int _CUDA_HD PrepareTwoRawArrayIter(int ndim, Nd4jLong *shape, X *dataA, + Nd4jLong *stridesA, Y *dataB, + Nd4jLong *stridesB, int *out_ndim, + Nd4jLong *outShape, X **out_dataA, + Nd4jLong *outStridesA, Y **out_dataB, + Nd4jLong *outStridesB) { + int i; + + /* Sort the axes based on the destination strides */ + for (i = 0; i < ndim; ++i) { + outShape[i] = shape[i]; + outStridesA[i] = stridesA[i]; + outStridesB[i] = stridesB[i]; + } + + /* Reverse any negative strides of operand A */ + for (i = 0; i < ndim; i++) { + Nd4jLong stride_entryA = outStridesA[i]; + Nd4jLong stride_entryB = outStridesB[i]; + Nd4jLong shape_entry = outShape[i]; + + if (stride_entryA < 0) { + dataA += stride_entryA * (shape_entry - 1); + dataB += stride_entryB * (shape_entry - 1); + outStridesA[i] = -stride_entryA; + outStridesB[i] = -stride_entryB; } - -/* Reverse any negative strides of operand A */ - for (i = 0; i < ndim; i++) { - Nd4jLong stride_entryA = outStridesA[i]; - Nd4jLong stride_entryB = outStridesB[i]; - Nd4jLong shape_entry = outShape[i]; - - if (stride_entryA < 0) { - dataA += stride_entryA * (shape_entry - 1); - dataB += stride_entryB * (shape_entry - 1); - outStridesA[i] = -stride_entryA; - outStridesB[i] = -stride_entryB; - } -/* Detect 0-size arrays here */ - if (shape_entry == 0) { - *out_ndim = 1; - *out_dataA = dataA; - *out_dataB = dataB; - outShape[0] = 0; - outStridesA[0] = 0; - outStridesB[0] = 0; - return 0; - } + /* Detect 0-size arrays here */ + if (shape_entry == 0) { + *out_ndim = 1; + *out_dataA = dataA; + *out_dataB = dataB; + outShape[0] = 0; + outStridesA[0] = 0; + outStridesB[0] = 0; + return 0; } + } - - *out_dataA = dataA; - *out_dataB = dataB; - *out_ndim = ndim; - + *out_dataA = dataA; + *out_dataB = dataB; + *out_ndim = ndim; #if 0 /* DEBUG */ @@ -443,7 +420,7 @@ int _CUDA_HD PrepareTwoRawArrayIter(int ndim, Nd4jLong *shape, } #endif - return 0; + return 0; } /** @@ -466,97 +443,93 @@ template #ifdef __CUDACC__ __host__ __device__ #endif -int PrepareThreeRawArrayIter(int ndim, Nd4jLong shape[], - X *dataA, Nd4jLong *stridesA, - Y *dataB, Nd4jLong *stridesB, - Z *dataC, Nd4jLong *stridesC, - int &out_ndim, Nd4jLong *outShape, - X **out_dataA, Nd4jLong outStridesA[], - Y **out_dataB, Nd4jLong outStridesB[], - Z **out_dataC, Nd4jLong outStridesC[]) -{ - - /* Special case 0 and 1 dimensions */ - if (ndim == 0) { - out_ndim = 1; - *out_dataA = dataA; - *out_dataB = dataB; - *out_dataC = dataC; - outShape[0] = 1; - outStridesA[0] = 0; - outStridesB[0] = 0; - outStridesC[0] = 0; - return 0; - } - else if (ndim == 1) { - auto stride_entryA = stridesA[0]; - auto stride_entryB = stridesB[0]; - auto stride_entryC = stridesC[0]; - auto shape_entry = shape[0]; - out_ndim = 1; - outShape[0] = shape[0]; - /* Always make a positive stride for the first operand */ - if (stride_entryA >= 0) { - *out_dataA = dataA; - *out_dataB = dataB; - *out_dataC = dataC; - outStridesA[0] = stride_entryA; - outStridesB[0] = stride_entryB; - outStridesC[0] = stride_entryC; - } - else { - *out_dataA = dataA + stride_entryA * (shape_entry - 1); - *out_dataB = dataB + stride_entryB * (shape_entry - 1); - *out_dataC = dataC + stride_entryC * (shape_entry - 1); - outStridesA[0] = -stride_entryA; - outStridesB[0] = -stride_entryB; - outStridesC[0] = -stride_entryC; - } - return 0; - } - - for (int i = 0; i < ndim; ++i) { - outShape[i] = shape[i]; - outStridesA[i] = stridesA[i]; - outStridesB[i] = stridesB[i]; - outStridesC[i] = stridesC[i]; - } - - /* Reverse any negative strides of operand A */ - for (int i = 0; i < ndim; ++i) { - auto stride_entryA = outStridesA[i]; - auto stride_entryB = outStridesB[i]; - auto stride_entryC = outStridesC[i]; - auto shape_entry = outShape[i]; - - if (stride_entryA < 0) { - dataA += stride_entryA * (shape_entry - 1); - dataB += stride_entryB * (shape_entry - 1); - dataC += stride_entryC * (shape_entry - 1); - outStridesA[i] = -stride_entryA; - outStridesB[i] = -stride_entryB; - outStridesC[i] = -stride_entryC; - } - /* Detect 0-size arrays here */ - if (shape_entry == 0) { - out_ndim = 1; - *out_dataA = dataA; - *out_dataB = dataB; - *out_dataC = dataC; - outShape[0] = 0; - outStridesA[0] = 0; - outStridesB[0] = 0; - outStridesC[0] = 0; - return 0; - } - } - - + int + PrepareThreeRawArrayIter(int ndim, Nd4jLong shape[], X *dataA, + Nd4jLong *stridesA, Y *dataB, Nd4jLong *stridesB, + Z *dataC, Nd4jLong *stridesC, int &out_ndim, + Nd4jLong *outShape, X **out_dataA, + Nd4jLong outStridesA[], Y **out_dataB, + Nd4jLong outStridesB[], Z **out_dataC, + Nd4jLong outStridesC[]) { + + /* Special case 0 and 1 dimensions */ + if (ndim == 0) { + out_ndim = 1; *out_dataA = dataA; *out_dataB = dataB; *out_dataC = dataC; - out_ndim = ndim; + outShape[0] = 1; + outStridesA[0] = 0; + outStridesB[0] = 0; + outStridesC[0] = 0; + return 0; + } else if (ndim == 1) { + auto stride_entryA = stridesA[0]; + auto stride_entryB = stridesB[0]; + auto stride_entryC = stridesC[0]; + auto shape_entry = shape[0]; + out_ndim = 1; + outShape[0] = shape[0]; + /* Always make a positive stride for the first operand */ + if (stride_entryA >= 0) { + *out_dataA = dataA; + *out_dataB = dataB; + *out_dataC = dataC; + outStridesA[0] = stride_entryA; + outStridesB[0] = stride_entryB; + outStridesC[0] = stride_entryC; + } else { + *out_dataA = dataA + stride_entryA * (shape_entry - 1); + *out_dataB = dataB + stride_entryB * (shape_entry - 1); + *out_dataC = dataC + stride_entryC * (shape_entry - 1); + outStridesA[0] = -stride_entryA; + outStridesB[0] = -stride_entryB; + outStridesC[0] = -stride_entryC; + } return 0; + } + + for (int i = 0; i < ndim; ++i) { + outShape[i] = shape[i]; + outStridesA[i] = stridesA[i]; + outStridesB[i] = stridesB[i]; + outStridesC[i] = stridesC[i]; + } + + /* Reverse any negative strides of operand A */ + for (int i = 0; i < ndim; ++i) { + auto stride_entryA = outStridesA[i]; + auto stride_entryB = outStridesB[i]; + auto stride_entryC = outStridesC[i]; + auto shape_entry = outShape[i]; + + if (stride_entryA < 0) { + dataA += stride_entryA * (shape_entry - 1); + dataB += stride_entryB * (shape_entry - 1); + dataC += stride_entryC * (shape_entry - 1); + outStridesA[i] = -stride_entryA; + outStridesB[i] = -stride_entryB; + outStridesC[i] = -stride_entryC; + } + /* Detect 0-size arrays here */ + if (shape_entry == 0) { + out_ndim = 1; + *out_dataA = dataA; + *out_dataB = dataB; + *out_dataC = dataC; + outShape[0] = 0; + outStridesA[0] = 0; + outStridesB[0] = 0; + outStridesC[0] = 0; + return 0; + } + } + + *out_dataA = dataA; + *out_dataB = dataB; + *out_dataC = dataC; + out_ndim = ndim; + return 0; } -#endif //NATIVEOPERATIONS_PAIRWISE_UTIL_H +#endif // NATIVEOPERATIONS_PAIRWISE_UTIL_H diff --git a/libnd4j/include/system/platform_boilerplate.h b/libnd4j/include/system/platform_boilerplate.h index b74a0530fe4d..2ac6d4cddc1d 100644 --- a/libnd4j/include/system/platform_boilerplate.h +++ b/libnd4j/include/system/platform_boilerplate.h @@ -23,35 +23,35 @@ #include - - -#define CONCATP(A,B) A ##_##B - - -#define DECLARE_PLATFORM_F(NAME, ENGINE, CNAME) class ND4J_EXPORT PLATFORM_##CNAME : public PlatformHelper {\ - public: \ - PLATFORM_##CNAME() : PlatformHelper(#NAME, samediff::Engine::ENGINE) { } \ - bool isUsable(graph::Context &context) override; \ - Nd4jStatus invokeHelper(graph::Context &context) override; \ - }; - -#define DECLARE_PLATFORM(NAME, ENGINE) DECLARE_PLATFORM_F(NAME, ENGINE, NAME ##_## ENGINE) - -#define PLATFORM_IMPL_F(NAME, ENGINE, CNAME) struct ND4J_EXPORT __registratorPlatformHelper_##CNAME { \ - __registratorPlatformHelper_##CNAME() { \ - auto helper = new PLATFORM_##CNAME(); \ - OpRegistrator::getInstance().registerHelper(helper); \ - } \ - }; \ - static __registratorPlatformHelper_##CNAME platformHelper_##CNAME; \ - Nd4jStatus PLATFORM_##CNAME::invokeHelper(sd::graph::Context &block) - - -#define PLATFORM_IMPL(NAME, ENGINE) PLATFORM_IMPL_F(NAME, ENGINE, NAME ##_## ENGINE) - - -#define PLATFORM_CHECK_F(NAME, ENGINE, CNAME) bool PLATFORM_##CNAME::isUsable(graph::Context &block) -#define PLATFORM_CHECK(NAME, ENGINE) PLATFORM_CHECK_F(NAME, ENGINE, NAME ##_## ENGINE) - - -#endif //SD_PLATFORM_BOILERPLATE_H +#define CONCATP(A, B) A##_##B + +#define DECLARE_PLATFORM_F(NAME, ENGINE, CNAME) \ + class SD_EXPORT PLATFORM_##CNAME : public PlatformHelper { \ + public: \ + PLATFORM_##CNAME() : PlatformHelper(#NAME, samediff::Engine::ENGINE) {} \ + bool isUsable(graph::Context &context) override; \ + Nd4jStatus invokeHelper(graph::Context &context) override; \ + }; + +#define DECLARE_PLATFORM(NAME, ENGINE) \ + DECLARE_PLATFORM_F(NAME, ENGINE, NAME##_##ENGINE) + +#define PLATFORM_IMPL_F(NAME, ENGINE, CNAME) \ + struct SD_EXPORT __registratorPlatformHelper_##CNAME { \ + __registratorPlatformHelper_##CNAME() { \ + auto helper = new PLATFORM_##CNAME(); \ + OpRegistrator::getInstance().registerHelper(helper); \ + } \ + }; \ + static __registratorPlatformHelper_##CNAME platformHelper_##CNAME; \ + Nd4jStatus PLATFORM_##CNAME::invokeHelper(sd::graph::Context &block) + +#define PLATFORM_IMPL(NAME, ENGINE) \ + PLATFORM_IMPL_F(NAME, ENGINE, NAME##_##ENGINE) + +#define PLATFORM_CHECK_F(NAME, ENGINE, CNAME) \ + bool PLATFORM_##CNAME::isUsable(graph::Context &block) +#define PLATFORM_CHECK(NAME, ENGINE) \ + PLATFORM_CHECK_F(NAME, ENGINE, NAME##_##ENGINE) + +#endif // SD_PLATFORM_BOILERPLATE_H diff --git a/libnd4j/include/system/play.h b/libnd4j/include/system/play.h index 9e121d88bd92..37a31def95ec 100644 --- a/libnd4j/include/system/play.h +++ b/libnd4j/include/system/play.h @@ -30,7 +30,7 @@ #define Y_TYPES \ (DATA_INT8, int8_t) ,\ - (DATA_INT16, int16_t) + (DATA_INT16, int16_t) #define Z_TYPES \ (DATA_UINT8, uint8_t) ,\ @@ -38,7 +38,7 @@ #define PWT_LIST \ (float, long, float),\ - (float, long, long) + (float, long, long) BUILD_SINGLE_TEMPLATE_TWICE(template class functionName, , DATA_TYPES) @@ -46,23 +46,27 @@ BUILD_SINGLE_TEMPLATE_TWICE(template class functionName, , DATA_TYPES) DECLARE_PLATFORM(conv2d, ENGINE_CPU) -//BUILD_PAIRWISE_SELECTOR(xType, yType, zType, functionName, (signature), DATA_TYPES, Y_TYPES); +// BUILD_PAIRWISE_SELECTOR(xType, yType, zType, functionName, (signature), +// DATA_TYPES, Y_TYPES); -//BUILD_SINGLE_UNCHAINED_TEMPLATE(functionName , (signature), Y_TYPES); +// BUILD_SINGLE_UNCHAINED_TEMPLATE(functionName , (signature), Y_TYPES); -//BUILD_TRIPLE_SELECTOR(xType, yType, zType, functionName, (signature), DATA_TYPES, Y_TYPES, Z_TYPES) +// BUILD_TRIPLE_SELECTOR(xType, yType, zType, functionName, (signature), +// DATA_TYPES, Y_TYPES, Z_TYPES) +// BUILD_TRIPLE_TEMPLATE(functionName, (signature), DATA_TYPES, Y_TYPES, +// Z_TYPES) -//BUILD_TRIPLE_TEMPLATE(functionName, (signature), DATA_TYPES, Y_TYPES, Z_TYPES) +// BUILD_ENUMERATION(DATA_TYPES) -//BUILD_ENUMERATION(DATA_TYPES) +// BUILD_SINGLE_SELECTOR(xType, functions::IndexReduce, ::op(a, b, c, d, e), +// DATA_TYPES) BUILD_DOUBLE_SELECTOR(xType, yType, functions::IndexReduce, +// ::op(a, b, c, d, e), DATA_TYPES, DATA_TYPES) -//BUILD_SINGLE_SELECTOR(xType, functions::IndexReduce, ::op(a, b, c, d, e), DATA_TYPES) -//BUILD_DOUBLE_SELECTOR(xType, yType, functions::IndexReduce, ::op(a, b, c, d, e), DATA_TYPES, DATA_TYPES) +// BUILD_SINGLE_TEMPLATE(template class Alpha, (signature), DATA_TYPES); -//BUILD_SINGLE_TEMPLATE(template class Alpha, (signature), DATA_TYPES); - -//BUILD_DOUBLE_TEMPLATE(template class Alpha, (signature) , DATA_TYPES, DATA_TYPES); +// BUILD_DOUBLE_TEMPLATE(template class Alpha, (signature) , DATA_TYPES, +// DATA_TYPES); /* #define SCALAR_OPS \ @@ -111,60 +115,94 @@ DECLARE_PLATFORM(conv2d, ENGINE_CPU) EXECUTE_NOE((x, y, extras), OPS_A(PAIRWISE_TRANSFORM_OPS)) */ +// EXECUTE_NOE((x, extras), OPS_A(SCALAR_OPS)) -//EXECUTE_NOE((x, extras), OPS_A(SCALAR_OPS)) - -//BUILD_CALL_1(template void sd::NDArray::applyTransform, float16, (NDArray* a, float16* b), TRANSFORM_OPS) +// BUILD_CALL_1(template void sd::NDArray::applyTransform, float16, +// (NDArray* a, float16* b), TRANSFORM_OPS) -//BUILD_CALL_1(template void sd::NDArray::applyPairwiseTransform, float16, (NDArray* other, float16* extraParams), PAIRWISE_TRANSFORM_OPS) -//BUILD_TRACKER(TRANSFORM, ACTIVATIONS) +// BUILD_CALL_1(template void sd::NDArray::applyPairwiseTransform, +// float16, (NDArray* other, float16* extraParams), +// PAIRWISE_TRANSFORM_OPS) BUILD_TRACKER(TRANSFORM, ACTIVATIONS) -//BUILD_CALL_1(template void sd::NDArray::applyScalar, float16, (float16 scalar, NDArray* target, float16 *extraParams) , ACTIVATIONS); +// BUILD_CALL_1(template void sd::NDArray::applyScalar, float16, +// (float16 scalar, NDArray* target, float16 *extraParams) , +// ACTIVATIONS); /* -#define DECLARE_OP(NAME, NIN, NOUT) DECLARE_OP_UNIQ(__COUNTER__, NAME, NIN, NOUT) +#define DECLARE_OP(NAME, NIN, NOUT) DECLARE_OP_UNIQ(__COUNTER__, NAME, NIN, +NOUT) #define DECLARE_OP_UNIQ(CTR, NAME, NIN, NOUT) template \ - class NAME: public sd::ops::DeclarableOp { \ + class NAME: public +sd::ops::DeclarableOp { \ public:\ - NAME() : sd::ops::DeclarableOp(NIN, NOUT, #NAME) { } \ + NAME() : +sd::ops::DeclarableOp(NIN, NOUT, #NAME) { } \ protected: \ - Nd4jStatus validateAndExecute(Block& block); \ + Nd4jStatus +validateAndExecute(Block& block); \ };\ template \ - Nd4jStatus sd::ops::NAME::validateAndExecute(Block& block) + Nd4jStatus +sd::ops::NAME::validateAndExecute(Block& block) */ -//#define END_OP(NAME) }; static sd::ops::__registrator> register_op##Name; +//#define END_OP(NAME) }; static sd::ops::__registrator> +// register_op##Name; //#DECLARE_OP(Concat, -1, 1) -//END_OP(Concat) - - -//BUILD_LAYERS_FACTORY(float, OPS_A(NATIVE_LAYERS), OPS_B(ACTIVATIONS)) +// END_OP(Concat) +// BUILD_LAYERS_FACTORY(float, OPS_A(NATIVE_LAYERS), OPS_B(ACTIVATIONS)) -//DISPATCH_SIMPLE(scalarAlongDimension_, float, PARAMS(x, xShapeInfo, extraParamx, z, zShapeInfo, scalars, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ), OPS_A(SCALAR_OPS)) +// DISPATCH_SIMPLE(scalarAlongDimension_, float, PARAMS(x, xShapeInfo, +// extraParamx, z, zShapeInfo, scalars, tadShapeInfo, tadOffsets, tadShapeInfoZ, +// tadOffsetsZ), OPS_A(SCALAR_OPS)) -//_EXEC_KERNEL_F(scalarAlongDimension_, scalarAlongDimensionGeneric, float, (float inputA, float inputB), (paramA, paramB), (10, SCALAR::Add), (11, SCALAR::Subtract), (12, SCALAR::Multiply)) +//_EXEC_KERNEL_F(scalarAlongDimension_, scalarAlongDimensionGeneric, float, +//(float inputA, float inputB), (paramA, paramB), (10, SCALAR::Add), (11, +// SCALAR::Subtract), (12, SCALAR::Multiply)) -//DISPATCH_KERNEL_SIMPLE(scalarAlongDimension_, scalarAlongDimensionGeneric, float, INPUT(float inputA, float inputB), PARAMS(paramA, paramB), OPS_A(SCALAR_OPS)) +// DISPATCH_KERNEL_SIMPLE(scalarAlongDimension_, scalarAlongDimensionGeneric, +// float, INPUT(float inputA, float inputB), PARAMS(paramA, paramB), +// OPS_A(SCALAR_OPS)) // original version -// DISPATCH_METAOP(functions::pairwise_transforms::PairWiseTransform::template transformCuda, PARAMS(N, dx, dy, xStride, yStride, paramsPtr, dz, zStride, nullptr, nullptr, nullptr), InvertedMetaOp, OPS_A(SCALAR_OPS), OPS_B(PAIRWISE_TRANSFORM_OPS)) +// DISPATCH_METAOP(functions::pairwise_transforms::PairWiseTransform::template +// transformCuda, PARAMS(N, dx, dy, xStride, yStride, paramsPtr, dz, zStride, +// nullptr, nullptr, nullptr), InvertedMetaOp, OPS_A(SCALAR_OPS), +// OPS_B(PAIRWISE_TRANSFORM_OPS)) /* -DISPATCH_METAOP(invertedMetaPairwiseShaped_Pairwise_Scalar, PARAMS(opTypeA, opTypeB, N, x, xShape, y, yShape, z, zShape, extrasA, extrasB, scalarA, scalarB), float, OPS_A(PAIRWISE_TRANSFORM_OPS), OPS_B(SCALAR_OPS));*/ - -//DISPATCH_KERNEL_META(invertedMetaPairwiseShaped_Pairwise_Scalar_, invertedMetaPairwiseShapedGeneric, float, simdOps::InvertedMetaOp, INPUT(const int opTypeA, const int opTypeB, long N, float *dx, int *xShapeInfo, float *dy, int *yShapeInfo, float *dz, int *zShapeInfo, float *extraA, float *extraB, float scalarA, float scalarB), PARAMS(opTypeA, opTypeB, N, dx, xShapeInfo, dy, yShapeInfo, dz, zShapeInfo, extraA, extraB, scalarA, scalarB), OPS_A(PAIRWISE_TRANSFORM_OPS), OPS_B(SCALAR_OPS)) - -//_EXPAND_KERNEL_CALL(invertedMetaPairwiseShaped_Pairwise_Scalar_, invertedMetaPairwiseShapedGeneric, float, simdOps::InvertedMetaOp, INPUT(const int opTypeA, const int opTypeB, long N, float *dx, int *xShapeInfo, float *dy, int *yShapeInfo, float *dz, int *zShapeInfo, float *extraA, float *extraB, float scalarA, float scalarB), PARAMS(N, dx, dy, xStride, yStride, paramsPtr, dz, zStride, nullptr, nullptr, nullptr), 66, simdOps::SomeOpA, 99, simdOps::SomeOpB) +DISPATCH_METAOP(invertedMetaPairwiseShaped_Pairwise_Scalar, PARAMS(opTypeA, +opTypeB, N, x, xShape, y, yShape, z, zShape, extrasA, extrasB, scalarA, +scalarB), float, OPS_A(PAIRWISE_TRANSFORM_OPS), OPS_B(SCALAR_OPS));*/ + +// DISPATCH_KERNEL_META(invertedMetaPairwiseShaped_Pairwise_Scalar_, +// invertedMetaPairwiseShapedGeneric, float, simdOps::InvertedMetaOp, +// INPUT(const int opTypeA, const int opTypeB, long N, float *dx, int +// *xShapeInfo, float *dy, int *yShapeInfo, float *dz, int *zShapeInfo, float +// *extraA, float *extraB, float scalarA, float scalarB), PARAMS(opTypeA, +// opTypeB, N, dx, xShapeInfo, dy, yShapeInfo, dz, zShapeInfo, extraA, extraB, +// scalarA, scalarB), OPS_A(PAIRWISE_TRANSFORM_OPS), OPS_B(SCALAR_OPS)) + +//_EXPAND_KERNEL_CALL(invertedMetaPairwiseShaped_Pairwise_Scalar_, +// invertedMetaPairwiseShapedGeneric, float, simdOps::InvertedMetaOp, +// INPUT(const int opTypeA, const int opTypeB, long N, float *dx, int +// *xShapeInfo, float *dy, int *yShapeInfo, float *dz, int *zShapeInfo, float +// *extraA, float *extraB, float scalarA, float scalarB), PARAMS(N, dx, dy, +// xStride, yStride, paramsPtr, dz, zStride, nullptr, nullptr, nullptr), 66, +// simdOps::SomeOpA, 99, simdOps::SomeOpB) /* - extern "C" __global__ void invertedMetaOpKernel_Pairwise_Scalar_16_1_float(const int opTypeA, const int opTypeB, long N, float *dx, int *xShapeInfo, float *dy, int *yShapeInfo, float *dz, int *zShapeInfo, float *extraA, float *extraB, float scalarA, float scalarB) { - invertedMetaPairwiseShapedGeneric, simdOps::Multiply>>(opTypeA, opTypeB, N, dx, xShapeInfo, dy, yShapeInfo, dz, zShapeInfo, extraA, extraB, scalarA, scalarB); + extern "C" __global__ void + invertedMetaOpKernel_Pairwise_Scalar_16_1_float(const int opTypeA, const int + opTypeB, long N, float *dx, int *xShapeInfo, float *dy, int *yShapeInfo, float + *dz, int *zShapeInfo, float *extraA, float *extraB, float scalarA, float + scalarB) { invertedMetaPairwiseShapedGeneric, + simdOps::Multiply>>(opTypeA, opTypeB, N, dx, xShapeInfo, dy, yShapeInfo, + dz, zShapeInfo, extraA, extraB, scalarA, scalarB); } */ - - -#endif //LIBND4J_PLAY_H +#endif // LIBND4J_PLAY_H diff --git a/libnd4j/include/system/pointercast.h b/libnd4j/include/system/pointercast.h index 2c64d608eda2..ead4f4f70b5a 100644 --- a/libnd4j/include/system/pointercast.h +++ b/libnd4j/include/system/pointercast.h @@ -21,44 +21,41 @@ #ifndef NATIVEOPERATIONS_POINTERCAST_H #define NATIVEOPERATIONS_POINTERCAST_H -#include #include +#include typedef void* Nd4jPointer; typedef long long Nd4jLong; typedef uint64_t Nd4jULong; typedef int Nd4jStatus; -#define ND4J_STATUS_OK 0 -#define ND4J_STATUS_BAD_INPUT 1 -#define ND4J_STATUS_BAD_SHAPE 2 -#define ND4J_STATUS_BAD_RANK 3 -#define ND4J_STATUS_BAD_PARAMS 4 -#define ND4J_STATUS_BAD_OUTPUT 5 -#define ND4J_STATUS_BAD_RNG 6 -#define ND4J_STATUS_BAD_EPSILON 7 +#define ND4J_STATUS_OK 0 +#define ND4J_STATUS_BAD_INPUT 1 +#define ND4J_STATUS_BAD_SHAPE 2 +#define ND4J_STATUS_BAD_RANK 3 +#define ND4J_STATUS_BAD_PARAMS 4 +#define ND4J_STATUS_BAD_OUTPUT 5 +#define ND4J_STATUS_BAD_RNG 6 +#define ND4J_STATUS_BAD_EPSILON 7 #define ND4J_STATUS_BAD_GRADIENTS 8 -#define ND4J_STATUS_BAD_BIAS 9 - -#define ND4J_STATUS_VALIDATION 20 +#define ND4J_STATUS_BAD_BIAS 9 -#define ND4J_STATUS_BAD_GRAPH 30 -#define ND4J_STATUS_BAD_LENGTH 31 -#define ND4J_STATUS_BAD_DIMENSIONS 32 -#define ND4J_STATUS_BAD_ORDER 33 -#define ND4J_STATUS_BAD_ARGUMENTS 34 +#define ND4J_STATUS_VALIDATION 20 -#define ND4J_STATUS_DOUBLE_WRITE 40 -#define ND4J_STATUS_DOUBLE_READ 45 +#define ND4J_STATUS_BAD_GRAPH 30 +#define ND4J_STATUS_BAD_LENGTH 31 +#define ND4J_STATUS_BAD_DIMENSIONS 32 +#define ND4J_STATUS_BAD_ORDER 33 +#define ND4J_STATUS_BAD_ARGUMENTS 34 +#define ND4J_STATUS_DOUBLE_WRITE 40 +#define ND4J_STATUS_DOUBLE_READ 45 -#define ND4J_STATUS_KERNEL_FAILURE 50 - - -#define ND4J_STATUS_TRUE 100 -#define ND4J_STATUS_FALSE 101 -#define ND4J_STATUS_MAYBE 119 +#define ND4J_STATUS_KERNEL_FAILURE 50 +#define ND4J_STATUS_TRUE 100 +#define ND4J_STATUS_FALSE 101 +#define ND4J_STATUS_MAYBE 119 #ifdef _MSC_VER @@ -72,8 +69,8 @@ typedef int Nd4jStatus; #elif __GNUC__ -#include -#define MAP_IMPL std::unordered_map +#include +#define MAP_IMPL std::map #else @@ -82,6 +79,4 @@ typedef int Nd4jStatus; #endif - - -#endif //NATIVEOPERATIONS_POINTERCAST_H +#endif // NATIVEOPERATIONS_POINTERCAST_H diff --git a/libnd4j/include/system/type_boilerplate.h b/libnd4j/include/system/type_boilerplate.h index 997fcab22dee..dcc51b739d48 100644 --- a/libnd4j/include/system/type_boilerplate.h +++ b/libnd4j/include/system/type_boilerplate.h @@ -26,539 +26,1513 @@ #define EXPAND3(...) __VA_ARGS__ #define EXTRACT(...) EXTRACT __VA_ARGS__ #define NOTHING_EXTRACT -#define PASTE(x, ...) x ## __VA_ARGS__ -#define PASTE2(x, ...) x ## __VA_ARGS__ -#define PASTE3(x, ...) x ## __VA_ARGS__ +#define PASTE(x, ...) x##__VA_ARGS__ +#define PASTE2(x, ...) x##__VA_ARGS__ +#define PASTE3(x, ...) x##__VA_ARGS__ #define EVALUATING_PASTE(x, ...) PASTE(x, __VA_ARGS__) #define EVALUATING_PASTE2(x, ...) PASTE2(x, __VA_ARGS__) #define EVALUATING_PASTE3(x, ...) PASTE3(x, __VA_ARGS__) #define UNPAREN(x) EVALUATING_PASTE(NOTHING_, EXTRACT x) #define UNPAREN2(x) EVALUATING_PASTE2(NOTHING_, EXTRACT x) #define UNPAREN3(x) EVALUATING_PASTE3(NOTHING_, EXTRACT x) -#define EVAL( x ) x -#define EVALX( x ) x -#define EVAL0(...) EVAL1(EVAL1(EVAL1(__VA_ARGS__))) +#define EVAL(x) x +#define EVALX(x) x +#define EVAL0(...) EVAL1(EVAL1(EVAL1(__VA_ARGS__))) #define EVAL1(...) EVAL2(EVAL2(EVAL2(__VA_ARGS__))) #define EVAL2(...) EVAL3(EVAL3(EVAL3(__VA_ARGS__))) #define EVAL3(...) EVAL4(EVAL4(EVAL4(__VA_ARGS__))) #define EVAL4(...) EVAL5(EVAL5(EVAL5(__VA_ARGS__))) #define EVAL5(...) __VA_ARGS__ - #define SEL_T_1(WHAT, NAME, SIGNATURE, TYPE_A) WHAT(NAME, SIGNATURE, TYPE_A) -#define SEL_T_2(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(SEL_T_1(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define SEL_T_3(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(SEL_T_2(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define SEL_T_4(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(SEL_T_3(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define SEL_T_5(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(SEL_T_4(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define SEL_T_6(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(SEL_T_5(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define SEL_T_7(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(SEL_T_6(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define SEL_T_8(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(SEL_T_7(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define SEL_T_9(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(SEL_T_8(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define SEL_T_10(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(SEL_T_9(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define SEL_T_11(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(SEL_T_10(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define SEL_T_12(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(SEL_T_11(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define SEL_T_13(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(SEL_T_12(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define SEL_T_14(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(SEL_T_13(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define SEL_T_15(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(SEL_T_14(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define SEL_T_16(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(SEL_T_15(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define SEL_T_17(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(SEL_T_16(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define SEL_T_18(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(SEL_T_17(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define SEL_T_19(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(SEL_T_18(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define SEL_T_20(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(SEL_T_19(WHAT, NAME, SIGNATURE, __VA_ARGS__)) - - -#define SEL_TT1_1(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) -#define SEL_TT1_2(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT1_1(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_TT1_3(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT1_2(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_TT1_4(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT1_3(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_TT1_5(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT1_4(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_TT1_6(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT1_5(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_TT1_7(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT1_6(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_TT1_8(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT1_7(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_TT1_9(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT1_8(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_TT1_10(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT1_9(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_TT1_11(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT1_10(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_TT1_12(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT1_11(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_TT1_13(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT1_12(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_TT1_14(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT1_13(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_TT1_15(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT1_14(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_TT1_16(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT1_15(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_TT1_17(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT1_16(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_TT1_18(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT1_17(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_TT1_19(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT1_18(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_TT1_20(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT1_19(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) - - -#define SEL_P1_1(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) -#define SEL_P1_2(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P1_1(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_P1_3(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P1_2(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_P1_4(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P1_3(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_P1_5(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P1_4(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_P1_6(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P1_5(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_P1_7(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P1_6(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_P1_8(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P1_7(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_P1_9(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P1_8(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_P1_10(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P1_9(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_P1_11(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P1_10(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_P1_12(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P1_11(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_P1_13(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P1_12(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_P1_14(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P1_13(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_P1_15(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P1_14(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_P1_16(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P1_15(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_P1_17(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P1_16(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_P1_18(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P1_17(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_P1_19(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P1_18(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_P1_20(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P1_19(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) - -#define SEL_P2_1(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) -#define SEL_P2_2(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P2_1(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_P2_3(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P2_2(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_P2_4(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P2_3(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_P2_5(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P2_4(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_P2_6(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P2_5(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_P2_7(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P2_6(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_P2_8(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P2_7(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_P2_9(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P2_8(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_P2_10(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P2_9(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_P2_11(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P2_10(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_P2_12(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P2_11(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_P2_13(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P2_12(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_P2_14(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P2_13(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_P2_15(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P2_14(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_P2_16(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P2_15(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_P2_17(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P2_16(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_P2_18(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P2_17(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_P2_19(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P2_18(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_P2_20(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P2_19(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) - - - - -#define SEL_TT2_1(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) -#define SEL_TT2_2(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT2_1(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_TT2_3(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT2_2(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_TT2_4(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT2_3(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_TT2_5(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT2_4(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_TT2_6(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT2_5(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_TT2_7(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT2_6(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_TT2_8(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT2_7(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_TT2_9(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT2_8(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_TT2_10(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT2_9(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_TT2_11(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT2_10(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_TT2_12(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT2_11(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_TT2_13(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT2_12(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_TT2_14(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT2_13(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_TT2_15(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT2_14(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_TT2_16(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT2_15(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_TT2_17(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT2_16(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_TT2_18(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT2_17(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_TT2_19(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT2_18(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_TT2_20(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT2_19(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) - +#define SEL_T_2(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) \ + EVAL(SEL_T_1(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define SEL_T_3(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) \ + EVAL(SEL_T_2(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define SEL_T_4(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) \ + EVAL(SEL_T_3(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define SEL_T_5(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) \ + EVAL(SEL_T_4(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define SEL_T_6(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) \ + EVAL(SEL_T_5(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define SEL_T_7(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) \ + EVAL(SEL_T_6(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define SEL_T_8(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) \ + EVAL(SEL_T_7(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define SEL_T_9(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) \ + EVAL(SEL_T_8(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define SEL_T_10(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) \ + EVAL(SEL_T_9(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define SEL_T_11(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) \ + EVAL(SEL_T_10(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define SEL_T_12(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) \ + EVAL(SEL_T_11(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define SEL_T_13(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) \ + EVAL(SEL_T_12(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define SEL_T_14(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) \ + EVAL(SEL_T_13(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define SEL_T_15(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) \ + EVAL(SEL_T_14(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define SEL_T_16(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) \ + EVAL(SEL_T_15(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define SEL_T_17(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) \ + EVAL(SEL_T_16(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define SEL_T_18(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) \ + EVAL(SEL_T_17(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define SEL_T_19(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) \ + EVAL(SEL_T_18(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define SEL_T_20(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) \ + EVAL(SEL_T_19(WHAT, NAME, SIGNATURE, __VA_ARGS__)) + +#define SEL_TT1_1(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +#define SEL_TT1_2(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_TT1_1(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define SEL_TT1_3(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_TT1_2(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define SEL_TT1_4(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_TT1_3(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define SEL_TT1_5(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_TT1_4(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define SEL_TT1_6(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_TT1_5(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define SEL_TT1_7(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_TT1_6(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define SEL_TT1_8(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_TT1_7(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define SEL_TT1_9(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_TT1_8(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define SEL_TT1_10(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_TT1_9(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define SEL_TT1_11(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_TT1_10(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define SEL_TT1_12(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_TT1_11(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define SEL_TT1_13(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_TT1_12(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define SEL_TT1_14(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_TT1_13(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define SEL_TT1_15(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_TT1_14(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define SEL_TT1_16(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_TT1_15(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define SEL_TT1_17(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_TT1_16(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define SEL_TT1_18(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_TT1_17(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define SEL_TT1_19(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_TT1_18(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define SEL_TT1_20(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_TT1_19(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) + +#define SEL_P1_1(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +#define SEL_P1_2(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + ...) \ + WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_P1_1(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, \ + __VA_ARGS__)) +#define SEL_P1_3(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + ...) \ + WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_P1_2(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, \ + __VA_ARGS__)) +#define SEL_P1_4(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + ...) \ + WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_P1_3(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, \ + __VA_ARGS__)) +#define SEL_P1_5(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + ...) \ + WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_P1_4(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, \ + __VA_ARGS__)) +#define SEL_P1_6(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + ...) \ + WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_P1_5(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, \ + __VA_ARGS__)) +#define SEL_P1_7(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + ...) \ + WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_P1_6(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, \ + __VA_ARGS__)) +#define SEL_P1_8(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + ...) \ + WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_P1_7(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, \ + __VA_ARGS__)) +#define SEL_P1_9(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + ...) \ + WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_P1_8(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, \ + __VA_ARGS__)) +#define SEL_P1_10(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + ...) \ + WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_P1_9(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, \ + __VA_ARGS__)) +#define SEL_P1_11(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + ...) \ + WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_P1_10(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, \ + __VA_ARGS__)) +#define SEL_P1_12(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + ...) \ + WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_P1_11(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, \ + __VA_ARGS__)) +#define SEL_P1_13(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + ...) \ + WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_P1_12(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, \ + __VA_ARGS__)) +#define SEL_P1_14(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + ...) \ + WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_P1_13(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, \ + __VA_ARGS__)) +#define SEL_P1_15(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + ...) \ + WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_P1_14(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, \ + __VA_ARGS__)) +#define SEL_P1_16(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + ...) \ + WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_P1_15(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, \ + __VA_ARGS__)) +#define SEL_P1_17(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + ...) \ + WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_P1_16(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, \ + __VA_ARGS__)) +#define SEL_P1_18(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + ...) \ + WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_P1_17(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, \ + __VA_ARGS__)) +#define SEL_P1_19(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + ...) \ + WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_P1_18(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, \ + __VA_ARGS__)) +#define SEL_P1_20(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + ...) \ + WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_P1_19(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, \ + __VA_ARGS__)) + +#define SEL_P2_1(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +#define SEL_P2_2(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + ...) \ + WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_P2_1(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, \ + __VA_ARGS__)) +#define SEL_P2_3(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + ...) \ + WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_P2_2(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, \ + __VA_ARGS__)) +#define SEL_P2_4(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + ...) \ + WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_P2_3(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, \ + __VA_ARGS__)) +#define SEL_P2_5(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + ...) \ + WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_P2_4(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, \ + __VA_ARGS__)) +#define SEL_P2_6(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + ...) \ + WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_P2_5(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, \ + __VA_ARGS__)) +#define SEL_P2_7(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + ...) \ + WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_P2_6(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, \ + __VA_ARGS__)) +#define SEL_P2_8(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + ...) \ + WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_P2_7(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, \ + __VA_ARGS__)) +#define SEL_P2_9(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + ...) \ + WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_P2_8(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, \ + __VA_ARGS__)) +#define SEL_P2_10(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + ...) \ + WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_P2_9(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, \ + __VA_ARGS__)) +#define SEL_P2_11(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + ...) \ + WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_P2_10(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, \ + __VA_ARGS__)) +#define SEL_P2_12(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + ...) \ + WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_P2_11(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, \ + __VA_ARGS__)) +#define SEL_P2_13(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + ...) \ + WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_P2_12(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, \ + __VA_ARGS__)) +#define SEL_P2_14(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + ...) \ + WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_P2_13(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, \ + __VA_ARGS__)) +#define SEL_P2_15(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + ...) \ + WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_P2_14(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, \ + __VA_ARGS__)) +#define SEL_P2_16(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + ...) \ + WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_P2_15(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, \ + __VA_ARGS__)) +#define SEL_P2_17(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + ...) \ + WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_P2_16(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, \ + __VA_ARGS__)) +#define SEL_P2_18(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + ...) \ + WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_P2_17(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, \ + __VA_ARGS__)) +#define SEL_P2_19(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + ...) \ + WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_P2_18(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, \ + __VA_ARGS__)) +#define SEL_P2_20(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + ...) \ + WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_P2_19(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, \ + __VA_ARGS__)) + +#define SEL_TT2_1(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +#define SEL_TT2_2(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_TT2_1(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define SEL_TT2_3(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_TT2_2(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define SEL_TT2_4(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_TT2_3(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define SEL_TT2_5(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_TT2_4(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define SEL_TT2_6(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_TT2_5(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define SEL_TT2_7(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_TT2_6(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define SEL_TT2_8(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_TT2_7(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define SEL_TT2_9(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_TT2_8(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define SEL_TT2_10(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_TT2_9(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define SEL_TT2_11(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_TT2_10(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define SEL_TT2_12(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_TT2_11(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define SEL_TT2_13(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_TT2_12(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define SEL_TT2_14(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_TT2_13(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define SEL_TT2_15(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_TT2_14(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define SEL_TT2_16(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_TT2_15(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define SEL_TT2_17(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_TT2_16(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define SEL_TT2_18(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_TT2_17(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define SEL_TT2_19(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_TT2_18(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define SEL_TT2_20(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_TT2_19(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) #define DS_1(WHAT, NAME, SIGNATURE, TYPE_A) WHAT(NAME, SIGNATURE, TYPE_A) -#define DS_2(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DS_1(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DS_3(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DS_2(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DS_4(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DS_3(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DS_5(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DS_4(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DS_6(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DS_5(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DS_7(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DS_6(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DS_8(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DS_7(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DS_9(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DS_8(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DS_10(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DS_9(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DS_11(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DS_10(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DS_12(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DS_11(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DS_13(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DS_12(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DS_14(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DS_13(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DS_15(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DS_14(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DS_16(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DS_15(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DS_17(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DS_16(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DS_18(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DS_17(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DS_19(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DS_18(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DS_20(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DS_19(WHAT, NAME, SIGNATURE, __VA_ARGS__)) - +#define DS_2(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DS_1(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DS_3(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DS_2(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DS_4(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DS_3(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DS_5(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DS_4(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DS_6(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DS_5(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DS_7(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DS_6(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DS_8(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DS_7(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DS_9(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DS_8(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DS_10(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DS_9(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DS_11(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DS_10(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DS_12(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DS_11(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DS_13(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DS_12(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DS_14(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DS_13(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DS_15(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DS_14(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DS_16(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DS_15(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DS_17(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DS_16(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DS_18(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DS_17(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DS_19(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DS_18(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DS_20(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DS_19(WHAT, NAME, SIGNATURE, __VA_ARGS__)) #define DP_1(WHAT, NAME, SIGNATURE, TYPE_A) WHAT(NAME, SIGNATURE, TYPE_A) -#define DP_2(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_1(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_3(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_2(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_4(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_3(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_5(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_4(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_6(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_5(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_7(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_6(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_8(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_7(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_9(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_8(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_10(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_9(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_11(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_10(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_12(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_11(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_13(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_12(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_14(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_13(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_15(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_14(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_16(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_15(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_17(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_16(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_18(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_17(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_19(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_18(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_20(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_19(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_21(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_20(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_22(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_21(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_23(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_22(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_24(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_23(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_25(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_24(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_26(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_25(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_27(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_26(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_28(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_27(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_29(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_28(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_30(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_29(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_31(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_30(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_32(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_31(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_33(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_32(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_34(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_33(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_35(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_34(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_36(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_35(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_37(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_36(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_38(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_37(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_39(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_38(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_40(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_39(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_41(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_40(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_42(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_41(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_43(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_42(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_44(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_43(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_45(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_44(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_46(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_45(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_47(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_46(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_48(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_47(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_49(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_48(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_50(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_49(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_51(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_50(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_52(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_51(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_53(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_52(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_54(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_53(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_55(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_54(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_56(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_55(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_57(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_56(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_58(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_57(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_59(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_58(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_60(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_59(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_61(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_60(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_62(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_61(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_63(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_62(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_64(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_63(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_65(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_64(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_66(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_65(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_67(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_66(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_68(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_67(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_69(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_68(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_70(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_69(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_71(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_70(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_72(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_71(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_73(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_72(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_74(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_73(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_75(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_74(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_76(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_75(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_77(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_76(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_78(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_77(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_79(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_78(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_80(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_79(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_81(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_80(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_82(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_81(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_83(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_82(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_84(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_83(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_85(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_84(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_86(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_85(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_87(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_86(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_88(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_87(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_89(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_88(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_90(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_89(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_91(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_90(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_92(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_91(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_93(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_92(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_94(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_93(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_95(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_94(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_96(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_95(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_97(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_96(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_98(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_97(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_99(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_98(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_100(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_99(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_101(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_100(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_102(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_101(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_103(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_102(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_104(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_103(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_105(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_104(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_106(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_105(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_107(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_106(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_108(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_107(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_109(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_108(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_110(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_109(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_111(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_110(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_112(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_111(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_113(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_112(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_114(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_113(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_115(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_114(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_116(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_115(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_117(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_116(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_118(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_117(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_119(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_118(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_120(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_119(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_121(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_120(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_122(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_121(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_123(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_122(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_124(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_123(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_125(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_124(WHAT, NAME, SIGNATURE, __VA_ARGS__)) - - -#define DT_1(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) -#define DT_2(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT_1(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define DT_3(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT_2(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define DT_4(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT_3(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define DT_5(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT_4(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define DT_6(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT_5(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define DT_7(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT_6(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define DT_8(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT_7(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define DT_9(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT_8(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define DT_10(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT_9(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define DT_11(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT_10(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define DT_12(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT_11(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define DT_13(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT_12(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define DT_14(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT_13(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define DT_15(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT_14(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define DT_16(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT_15(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define DT_17(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT_16(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define DT_18(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT_17(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define DT_19(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT_18(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define DT_20(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT_19(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) - -#define DT2_1(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) -#define DT2_2(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT2_1(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define DT2_3(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT2_2(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define DT2_4(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT2_3(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define DT2_5(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT2_4(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define DT2_6(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT2_5(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define DT2_7(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT2_6(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define DT2_8(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT2_7(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define DT2_9(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT2_8(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define DT2_10(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT2_9(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define DT2_11(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT2_10(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define DT2_12(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT2_11(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define DT2_13(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT2_12(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define DT2_14(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT2_13(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define DT2_15(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT2_14(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define DT2_16(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT2_15(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define DT2_17(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT2_16(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define DT2_18(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT2_17(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define DT2_19(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT2_18(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define DT2_20(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT2_19(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) - -#define TTT1_1(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) -#define TTT1_2(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT1_1(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT1_3(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT1_2(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT1_4(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT1_3(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT1_5(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT1_4(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT1_6(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT1_5(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT1_7(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT1_6(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT1_8(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT1_7(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT1_9(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT1_8(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT1_10(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT1_9(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT1_11(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT1_10(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT1_12(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT1_11(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT1_13(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT1_12(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT1_14(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT1_13(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT1_15(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT1_14(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT1_16(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT1_15(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT1_17(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT1_16(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT1_18(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT1_17(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT1_19(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT1_18(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT1_20(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT1_19(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) - -#define TTT2_1(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) -#define TTT2_2(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT2_1(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT2_3(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT2_2(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT2_4(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT2_3(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT2_5(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT2_4(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT2_6(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT2_5(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT2_7(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT2_6(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT2_8(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT2_7(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT2_9(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT2_8(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT2_10(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT2_9(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT2_11(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT2_10(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT2_12(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT2_11(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT2_13(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT2_12(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT2_14(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT2_13(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT2_15(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT2_14(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT2_16(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT2_15(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT2_17(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT2_16(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT2_18(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT2_17(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT2_19(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT2_18(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT2_20(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT2_19(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) - - -#define TTT3_1(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) -#define TTT3_2(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT3_1(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT3_3(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT3_2(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT3_4(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT3_3(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT3_5(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT3_4(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT3_6(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT3_5(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT3_7(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT3_6(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT3_8(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT3_7(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT3_9(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT3_8(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT3_10(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT3_9(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT3_11(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT3_10(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT3_12(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT3_11(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT3_13(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT3_12(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT3_14(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT3_13(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT3_15(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT3_14(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT3_16(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT3_15(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT3_17(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT3_16(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT3_18(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT3_17(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT3_19(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT3_18(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT3_20(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT3_19(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) - - - -#define TT1_1(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) -#define TT1_2(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT1_1(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT1_3(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT1_2(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT1_4(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT1_3(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT1_5(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT1_4(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT1_6(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT1_5(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT1_7(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT1_6(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT1_8(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT1_7(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT1_9(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT1_8(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT1_10(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT1_9(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT1_11(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT1_10(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT1_12(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT1_11(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT1_13(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT1_12(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT1_14(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT1_13(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT1_15(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT1_14(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT1_16(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT1_15(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT1_17(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT1_16(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT1_18(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT1_17(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT1_19(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT1_18(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT1_20(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT1_19(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) - -#define TT2_1(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) -#define TT2_2(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT2_1(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT2_3(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT2_2(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT2_4(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT2_3(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT2_5(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT2_4(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT2_6(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT2_5(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT2_7(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT2_6(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT2_8(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT2_7(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT2_9(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT2_8(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT2_10(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT2_9(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT2_11(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT2_10(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT2_12(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT2_11(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT2_13(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT2_12(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT2_14(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT2_13(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT2_15(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT2_14(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT2_16(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT2_15(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT2_17(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT2_16(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT2_18(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT2_17(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT2_19(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT2_18(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT2_20(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT2_19(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) - -#define TT3_1(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) -#define TT3_2(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT3_1(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT3_3(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT3_2(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT3_4(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT3_3(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT3_5(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT3_4(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT3_6(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT3_5(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT3_7(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT3_6(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT3_8(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT3_7(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT3_9(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT3_8(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT3_10(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT3_9(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT3_11(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT3_10(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT3_12(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT3_11(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT3_13(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT3_12(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT3_14(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT3_13(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT3_15(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT3_14(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT3_16(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT3_15(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT3_17(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT3_16(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT3_18(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT3_17(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT3_19(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT3_18(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT3_20(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT3_19(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) - - -#define GET_MACRO_SEL_T(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, NAME,...) NAME -#define GET_MACRO_SEL_P1(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, NAME,...) NAME -#define GET_MACRO_SEL_P2(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, NAME,...) NAME -#define GET_MACRO_SEL_TT1(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, NAME,...) NAME -#define GET_MACRO_SEL_TT2(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, NAME,...) NAME -#define GET_MACRO_DS(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, NAME,...) NAME -#define GET_MACRO_DT(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, NAME,...) NAME -#define GET_MACRO_DP(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51, _52, _53, _54, _55, _56, _57, _58, _59, _60, _61, _62, _63, _64, _65, _66, _67, _68, _69, _70, _71, _72, _73, _74, _75, _76, _77, _78, _79, _80, _81, _82, _83, _84, _85, _86, _87, _88, _89, _90, _91, _92, _93, _94, _95, _96, _97, _98, _99, _100, _101, _102, _103, _104, _105, _106, _107, _108, _109, _110, _111, _112, _113, _114, _115, _116, _117, _118, _119, _120, _121, _122, _123, _124, _125, NAME,...) NAME -#define GET_MACRO_DT2(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, NAME,...) NAME - - -#define GET_MACRO_TT1(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, NAME,...) NAME -#define GET_MACRO_TT2(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, NAME,...) NAME -#define GET_MACRO_TT3(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, NAME,...) NAME - -#define GET_MACRO_TTT1(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, NAME,...) NAME -#define GET_MACRO_TTT2(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, NAME,...) NAME -#define GET_MACRO_TTT3(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, NAME,...) NAME - -#define FOR_EACH_S1(WHAT, NAME, SIGNATURE, ...) EXPAND(GET_MACRO_SEL_T(__VA_ARGS__, SEL_T_20, SEL_T_19, SEL_T_18, SEL_T_17, SEL_T_16, SEL_T_15, SEL_T_14, SEL_T_13, SEL_T_12, SEL_T_11, SEL_T_10, SEL_T_9, SEL_T_8, SEL_T_7, SEL_T_6, SEL_T_5, SEL_T_4, SEL_T_3, SEL_T_2, SEL_T_1)(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define FOR_EACH_S2(WHAT, YTYPE, NAME, SIGNATURE, TYPES_A, ...) EXPAND(GET_MACRO_SEL_TT1(__VA_ARGS__, SEL_TT1_20, SEL_TT1_19, SEL_TT1_18, SEL_TT1_17, SEL_TT1_16, SEL_TT1_15, SEL_TT1_14, SEL_TT1_13, SEL_TT1_12, SEL_TT1_11, SEL_TT1_10, SEL_TT1_9, SEL_TT1_8, SEL_TT1_7, SEL_TT1_6, SEL_TT1_5, SEL_TT1_4, SEL_TT1_3, SEL_TT1_2, SEL_TT1_1)(WHAT, YTYPE, NAME, SIGNATURE, TYPES_A, __VA_ARGS__)) -#define FOR_EACH_P1(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_A, ...) EXPAND(GET_MACRO_SEL_P1(__VA_ARGS__, SEL_P1_20, SEL_P1_19, SEL_P1_18, SEL_P1_17, SEL_P1_16, SEL_P1_15, SEL_P1_14, SEL_P1_13, SEL_P1_12, SEL_P1_11, SEL_P1_10, SEL_P1_9, SEL_P1_8, SEL_P1_7, SEL_P1_6, SEL_P1_5, SEL_P1_4, SEL_P1_3, SEL_P1_2, SEL_P1_1)(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_A, __VA_ARGS__)) -#define FOR_EACH_P2(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_A, ...) EXPAND2(GET_MACRO_SEL_P2(__VA_ARGS__, SEL_P2_20, SEL_P2_19, SEL_P2_18, SEL_P2_17, SEL_P2_16, SEL_P2_15, SEL_P2_14, SEL_P2_13, SEL_P2_12, SEL_P2_11, SEL_P2_10, SEL_P2_9, SEL_P2_8, SEL_P2_7, SEL_P2_6, SEL_P2_5, SEL_P2_4, SEL_P2_3, SEL_P2_2, SEL_P2_1)(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_A, __VA_ARGS__)) -#define FOR_EACH_S3(WHAT, NAME, SIGNATURE, TYPE_A, ...) EXPAND(GET_MACRO_SEL_TT2(__VA_ARGS__, SEL_TT2_20, SEL_TT2_19, SEL_TT2_18, SEL_TT2_17, SEL_TT2_16, SEL_TT2_15, SEL_TT2_14, SEL_TT2_13, SEL_TT2_12, SEL_TT2_11, SEL_TT2_10, SEL_TT2_9, SEL_TT2_8, SEL_TT2_7, SEL_TT2_6, SEL_TT2_5, SEL_TT2_4, SEL_TT2_3, SEL_TT2_2, SEL_TT2_1)(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define FOR_EACH_DS(WHAT, NAME, SIGNATURE, ...) EXPAND(GET_MACRO_DS(__VA_ARGS__, DS_20, DS_19, DS_18, DS_17, DS_16, DS_15, DS_14, DS_13, DS_12, DS_11, DS_10, DS_9, DS_8, DS_7, DS_6, DS_5, DS_4, DS_3, DS_2, DS_1)(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define FOR_EACH_DT(WHAT, NAME, SIGNATURE, TYPES_A, ...) EXPAND(GET_MACRO_DT(__VA_ARGS__, DT_20, DT_19, DT_18, DT_17, DT_16, DT_15, DT_14, DT_13, DT_12, DT_11, DT_10, DT_9, DT_8, DT_7, DT_6, DT_5, DT_4, DT_3, DT_2, DT_1)(WHAT, NAME, SIGNATURE, TYPES_A, __VA_ARGS__)) -#define FOR_EACH_DT2(WHAT, NAME, SIGNATURE, TYPE_A, ...) EXPAND(GET_MACRO_DT2(__VA_ARGS__, DT2_20, DT2_19, DT2_18, DT2_17, DT2_16, DT2_15, DT2_14, DT2_13, DT2_12, DT2_11, DT2_10, DT2_9, DT2_8, DT2_7, DT2_6, DT2_5, DT2_4, DT2_3, DT2_2, DT2_1)(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define FOR_EACH_DP(WHAT, NAME, SIGNATURE, ...) EXPAND(GET_MACRO_DP(__VA_ARGS__, DP_125, DP_124, DP_123, DP_122, DP_121, DP_120, DP_119, DP_118, DP_117, DP_116, DP_115, DP_114, DP_113, DP_112, DP_111, DP_110, DP_109, DP_108, DP_107, DP_106, DP_105, DP_104, DP_103, DP_102, DP_101, DP_100, DP_99, DP_98, DP_97, DP_96, DP_95, DP_94, DP_93, DP_92, DP_91, DP_90, DP_89, DP_88, DP_87, DP_86, DP_85, DP_84, DP_83, DP_82, DP_81, DP_80, DP_79, DP_78, DP_77, DP_76, DP_75, DP_74, DP_73, DP_72, DP_71, DP_70, DP_69, DP_68, DP_67, DP_66, DP_65, DP_64, DP_63, DP_62, DP_61, DP_60, DP_59, DP_58, DP_57, DP_56, DP_55, DP_54, DP_53, DP_52, DP_51, DP_50, DP_49, DP_48, DP_47, DP_46, DP_45, DP_44, DP_43, DP_42, DP_41, DP_40, DP_39, DP_38, DP_37, DP_36, DP_35, DP_34, DP_33, DP_32, DP_31, DP_30, DP_29, DP_28, DP_27, DP_26, DP_25, DP_24, DP_23, DP_22, DP_21, DP_20, DP_19, DP_18, DP_17, DP_16, DP_15, DP_14, DP_13, DP_12, DP_11, DP_10, DP_9, DP_8, DP_7, DP_6, DP_5, DP_4, DP_3, DP_2, DP_1)(WHAT, NAME, SIGNATURE, __VA_ARGS__)) - - -#define FOR_EACH_TT1(WHAT, NAME, SIGNATURE, TYPES_X, TYPES_Y, ...) EXPAND(GET_MACRO_TT1(__VA_ARGS__, TT1_20, TT1_19, TT1_18, TT1_17, TT1_16, TT1_15, TT1_14, TT1_13, TT1_12, TT1_11, TT1_10, TT1_9, TT1_8, TT1_7, TT1_6, TT1_5, TT1_4, TT1_3, TT1_2, TT1_1)(WHAT, NAME, SIGNATURE, TYPES_X, TYPES_Y, __VA_ARGS__)) -#define FOR_EACH_TT2(WHAT, NAME, SIGNATURE, TYPE_Z, TYPES_X, ...) EXPAND(GET_MACRO_TT2(__VA_ARGS__, TT2_20, TT2_19, TT2_18, TT2_17, TT2_16, TT2_15, TT2_14, TT2_13, TT2_12, TT2_11, TT2_10, TT2_9, TT2_8, TT2_7, TT2_6, TT2_5, TT2_4, TT2_3, TT2_2, TT2_1)(WHAT, NAME, SIGNATURE, TYPE_Z, TYPES_X, __VA_ARGS__)) -#define FOR_EACH_TT3(WHAT, NAME, SIGNATURE, TYPE_Z, TYPE_Y, ...) EXPAND(GET_MACRO_TT3(__VA_ARGS__, TT3_20, TT3_19, TT3_18, TT3_17, TT3_16, TT3_15, TT3_14, TT3_13, TT3_12, TT3_11, TT3_10, TT3_9, TT3_8, TT3_7, TT3_6, TT3_5, TT3_4, TT3_3, TT3_2, TT3_1)(WHAT, NAME, SIGNATURE, TYPE_Z, TYPE_Y, __VA_ARGS__)) - -#define FOR_EACH_TTT1(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_Z, TYPES_Y, ...) EXPAND(GET_MACRO_TTT1(__VA_ARGS__, TTT1_20, TTT1_19, TTT1_18, TTT1_17, TTT1_16, TTT1_15, TTT1_14, TTT1_13, TTT1_12, TTT1_11, TTT1_10, TTT1_9, TTT1_8, TTT1_7, TTT1_6, TTT1_5, TTT1_4, TTT1_3, TTT1_2, TTT1_1)(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_Z, TYPES_Y, __VA_ARGS__)) -#define FOR_EACH_TTT2(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_X, TYPES_Z, ...) EXPAND(GET_MACRO_TTT2(__VA_ARGS__, TTT2_20, TTT2_19, TTT2_18, TTT2_17, TTT2_16, TTT2_15, TTT2_14, TTT2_13, TTT2_12, TTT2_11, TTT2_10, TTT2_9, TTT2_8, TTT2_7, TTT2_6, TTT2_5, TTT2_4, TTT2_3, TTT2_2, TTT2_1)(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_X, TYPES_Z, __VA_ARGS__)) -#define FOR_EACH_TTT3(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_X, TYPE_Y, ...) EXPAND(GET_MACRO_TTT3(__VA_ARGS__, TTT3_20, TTT3_19, TTT3_18, TTT3_17, TTT3_16, TTT3_15, TTT3_14, TTT3_13, TTT3_12, TTT3_11, TTT3_10, TTT3_9, TTT3_8, TTT3_7, TTT3_6, TTT3_5, TTT3_4, TTT3_3, TTT3_2, TTT3_1)(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_X, TYPE_Y, __VA_ARGS__)) - -#define _EXEC_SELECTOR_T(WHAT, NAME, SIGNATURE, ...) EVAL(FOR_EACH_S1(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define _EXEC_SELECTOR_P_1(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_A, ...) EVAL(FOR_EACH_P1(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_A, __VA_ARGS__)) -#define _EXEC_SELECTOR_P_2(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, ...) EVAL(FOR_EACH_P2(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define _EXEC_SELECTOR_TT_1(WHAT, YTYPE, NAME, SIGNATURE, TYPES_A, ...) EVAL(FOR_EACH_S2(WHAT, YTYPE, NAME, SIGNATURE, TYPES_A, __VA_ARGS__)) -#define _EXEC_SELECTOR_TT_2(WHAT, NAME, SIGNATURE, TYPE_A, ...) EVAL(FOR_EACH_S3(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define _EXEC_SINGLE_T(WHAT, NAME, SIGNATURE, ...) EVAL(FOR_EACH_DS(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define _EXEC_DOUBLE_T(WHAT, NAME, SIGNATURE, TYPES_A, ...) EVAL(FOR_EACH_DT(WHAT, NAME, SIGNATURE, LIST(TYPES_A), __VA_ARGS__)) -#define _EXEC_DOUBLE_T2(WHAT, NAME, SIGNATURE, TYPE_A, ...) EVAL(FOR_EACH_DT2(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define _EXEC_DOUBLE_P(WHAT, NAME, SIGNATURE, ...) EVAL(FOR_EACH_DP(WHAT, NAME, SIGNATURE, __VA_ARGS__)) - -#define _EXEC_SELECTOR_TTT_1(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_Z, TYPES_Y, ...) EVAL(FOR_EACH_TTT1(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_Z, TYPES_Y, __VA_ARGS__)) -#define _EXEC_SELECTOR_TTT_2(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_X, TYPES_Z, ...) EVAL(FOR_EACH_TTT2(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_X, TYPES_Z, __VA_ARGS__)) -#define _EXEC_SELECTOR_TTT_3(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_X, TYPE_Y, ...) EVAL(FOR_EACH_TTT3(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_X, TYPE_Y, __VA_ARGS__)) - -#define _EXEC_TRIPLE_T3(WHAT, NAME, SIGNATURE, TYPE_Z, TYPE_Y, ...) EVAL(FOR_EACH_TT3(WHAT, NAME, SIGNATURE, TYPE_Z, TYPE_Y, __VA_ARGS__)) -#define _EXEC_TRIPLE_T2(WHAT, NAME, SIGNATURE, TYPE_Z, TYPES_X, ...) EVAL(FOR_EACH_TT2(WHAT, NAME, SIGNATURE, TYPE_Z, LIST(TYPES_X), __VA_ARGS__)) -#define _EXEC_TRIPLE_T1(WHAT, NAME, SIGNATURE, TYPES_X, TYPES_Y, ...) EVAL(FOR_EACH_TT1(WHAT, NAME, SIGNATURE, LIST(TYPES_X), LIST(TYPES_Y), __VA_ARGS__)) - -#define DISPATCH_PAIRWISE(NAME, SIGNATURE, TYPE, TYPES_B) EVAL(_EXEC_DOUBLE_T2(RANDOMPAIRWISE2, NAME, SIGNATURE, TYPE, TYPES_B)) -#define DISPATCH_PAIRWISE2(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE, ...) EVAL(_EXEC_SELECTOR_P_2(SELECTOR_PAIRWISE_2, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE, __VA_ARGS__)) - -#define DISPATCH_DTYPES(NAME, SIGNATURE, TYPE, TYPES_B) EVAL(_EXEC_DOUBLE_T2(RANDOMDOUBLE2, NAME, SIGNATURE, TYPE, TYPES_B)) -#define DISPATCH_DTYPES2(NAME, SIGNATURE, TYPE, ...) EVAL(_EXEC_SELECTOR_TT_2(SELECTOR_DOUBLE_2, NAME, SIGNATURE, TYPE, __VA_ARGS__)) - -#define DISPATCH_TTYPES2(ZTYPE, NAME, SIGNATURE, TYPE_X, TYPES_Z, ...) EVAL(_EXEC_SELECTOR_TTT_2(SELECTOR_TRIPLE_2, ZTYPE, NAME, SIGNATURE, TYPE_X, TYPES_Z, __VA_ARGS__)) -#define DISPATCH_TTYPES3(ZTYPE, NAME, SIGNATURE, TYPE_X, TYPE_Y, ...) EVAL(_EXEC_SELECTOR_TTT_3(SELECTOR_TRIPLE_3, ZTYPE, NAME, SIGNATURE, TYPE_X, TYPE_Y, __VA_ARGS__)) - +#define DP_2(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_1(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_3(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_2(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_4(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_3(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_5(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_4(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_6(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_5(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_7(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_6(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_8(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_7(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_9(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_8(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_10(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_9(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_11(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_10(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_12(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_11(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_13(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_12(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_14(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_13(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_15(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_14(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_16(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_15(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_17(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_16(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_18(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_17(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_19(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_18(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_20(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_19(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_21(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_20(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_22(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_21(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_23(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_22(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_24(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_23(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_25(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_24(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_26(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_25(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_27(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_26(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_28(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_27(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_29(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_28(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_30(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_29(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_31(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_30(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_32(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_31(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_33(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_32(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_34(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_33(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_35(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_34(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_36(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_35(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_37(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_36(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_38(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_37(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_39(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_38(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_40(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_39(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_41(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_40(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_42(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_41(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_43(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_42(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_44(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_43(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_45(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_44(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_46(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_45(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_47(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_46(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_48(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_47(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_49(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_48(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_50(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_49(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_51(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_50(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_52(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_51(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_53(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_52(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_54(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_53(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_55(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_54(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_56(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_55(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_57(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_56(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_58(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_57(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_59(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_58(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_60(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_59(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_61(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_60(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_62(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_61(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_63(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_62(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_64(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_63(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_65(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_64(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_66(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_65(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_67(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_66(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_68(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_67(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_69(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_68(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_70(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_69(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_71(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_70(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_72(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_71(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_73(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_72(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_74(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_73(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_75(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_74(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_76(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_75(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_77(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_76(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_78(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_77(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_79(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_78(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_80(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_79(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_81(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_80(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_82(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_81(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_83(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_82(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_84(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_83(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_85(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_84(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_86(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_85(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_87(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_86(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_88(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_87(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_89(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_88(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_90(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_89(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_91(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_90(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_92(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_91(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_93(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_92(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_94(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_93(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_95(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_94(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_96(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_95(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_97(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_96(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_98(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_97(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_99(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_98(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_100(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_99(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_101(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_100(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_102(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_101(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_103(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_102(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_104(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_103(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_105(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_104(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_106(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_105(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_107(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_106(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_108(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_107(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_109(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_108(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_110(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_109(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_111(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_110(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_112(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_111(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_113(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_112(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_114(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_113(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_115(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_114(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_116(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_115(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_117(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_116(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_118(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_117(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_119(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_118(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_120(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_119(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_121(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_120(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_122(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_121(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_123(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_122(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_124(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_123(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_125(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_124(WHAT, NAME, SIGNATURE, __VA_ARGS__)) + +#define DT_1(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +#define DT_2(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(DT_1(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define DT_3(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(DT_2(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define DT_4(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(DT_3(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define DT_5(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(DT_4(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define DT_6(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(DT_5(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define DT_7(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(DT_6(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define DT_8(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(DT_7(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define DT_9(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(DT_8(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define DT_10(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(DT_9(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define DT_11(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(DT_10(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define DT_12(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(DT_11(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define DT_13(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(DT_12(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define DT_14(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(DT_13(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define DT_15(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(DT_14(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define DT_16(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(DT_15(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define DT_17(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(DT_16(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define DT_18(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(DT_17(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define DT_19(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(DT_18(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define DT_20(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(DT_19(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) + +#define DT2_1(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +#define DT2_2(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(DT2_1(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define DT2_3(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(DT2_2(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define DT2_4(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(DT2_3(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define DT2_5(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(DT2_4(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define DT2_6(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(DT2_5(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define DT2_7(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(DT2_6(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define DT2_8(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(DT2_7(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define DT2_9(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(DT2_8(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define DT2_10(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(DT2_9(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define DT2_11(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(DT2_10(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define DT2_12(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(DT2_11(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define DT2_13(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(DT2_12(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define DT2_14(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(DT2_13(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define DT2_15(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(DT2_14(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define DT2_16(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(DT2_15(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define DT2_17(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(DT2_16(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define DT2_18(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(DT2_17(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define DT2_19(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(DT2_18(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define DT2_20(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(DT2_19(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) + +#define TTT1_1(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +#define TTT1_2(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, \ + ...) \ + WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT1_1(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + __VA_ARGS__)) +#define TTT1_3(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, \ + ...) \ + WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT1_2(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + __VA_ARGS__)) +#define TTT1_4(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, \ + ...) \ + WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT1_3(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + __VA_ARGS__)) +#define TTT1_5(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, \ + ...) \ + WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT1_4(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + __VA_ARGS__)) +#define TTT1_6(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, \ + ...) \ + WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT1_5(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + __VA_ARGS__)) +#define TTT1_7(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, \ + ...) \ + WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT1_6(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + __VA_ARGS__)) +#define TTT1_8(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, \ + ...) \ + WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT1_7(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + __VA_ARGS__)) +#define TTT1_9(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, \ + ...) \ + WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT1_8(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + __VA_ARGS__)) +#define TTT1_10(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, \ + ...) \ + WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT1_9(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + __VA_ARGS__)) +#define TTT1_11(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, \ + ...) \ + WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT1_10(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + __VA_ARGS__)) +#define TTT1_12(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, \ + ...) \ + WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT1_11(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + __VA_ARGS__)) +#define TTT1_13(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, \ + ...) \ + WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT1_12(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + __VA_ARGS__)) +#define TTT1_14(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, \ + ...) \ + WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT1_13(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + __VA_ARGS__)) +#define TTT1_15(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, \ + ...) \ + WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT1_14(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + __VA_ARGS__)) +#define TTT1_16(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, \ + ...) \ + WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT1_15(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + __VA_ARGS__)) +#define TTT1_17(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, \ + ...) \ + WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT1_16(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + __VA_ARGS__)) +#define TTT1_18(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, \ + ...) \ + WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT1_17(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + __VA_ARGS__)) +#define TTT1_19(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, \ + ...) \ + WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT1_18(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + __VA_ARGS__)) +#define TTT1_20(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, \ + ...) \ + WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT1_19(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + __VA_ARGS__)) + +#define TTT2_1(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +#define TTT2_2(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT2_1(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TTT2_3(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT2_2(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TTT2_4(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT2_3(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TTT2_5(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT2_4(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TTT2_6(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT2_5(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TTT2_7(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT2_6(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TTT2_8(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT2_7(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TTT2_9(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT2_8(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TTT2_10(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT2_9(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TTT2_11(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT2_10(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TTT2_12(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT2_11(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TTT2_13(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT2_12(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TTT2_14(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT2_13(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TTT2_15(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT2_14(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TTT2_16(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT2_15(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TTT2_17(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT2_16(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TTT2_18(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT2_17(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TTT2_19(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT2_18(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TTT2_20(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT2_19(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) + +#define TTT3_1(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +#define TTT3_2(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT3_1(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TTT3_3(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT3_2(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TTT3_4(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT3_3(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TTT3_5(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT3_4(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TTT3_6(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT3_5(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TTT3_7(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT3_6(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TTT3_8(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT3_7(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TTT3_9(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT3_8(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TTT3_10(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT3_9(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TTT3_11(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT3_10(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TTT3_12(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT3_11(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TTT3_13(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT3_12(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TTT3_14(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT3_13(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TTT3_15(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT3_14(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TTT3_16(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT3_15(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TTT3_17(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT3_16(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TTT3_18(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT3_17(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TTT3_19(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT3_18(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TTT3_20(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT3_19(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) + +#define TT1_1(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +#define TT1_2(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT1_1(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT1_3(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT1_2(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT1_4(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT1_3(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT1_5(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT1_4(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT1_6(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT1_5(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT1_7(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT1_6(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT1_8(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT1_7(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT1_9(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT1_8(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT1_10(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT1_9(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT1_11(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT1_10(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT1_12(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT1_11(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT1_13(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT1_12(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT1_14(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT1_13(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT1_15(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT1_14(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT1_16(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT1_15(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT1_17(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT1_16(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT1_18(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT1_17(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT1_19(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT1_18(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT1_20(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT1_19(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) + +#define TT2_1(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +#define TT2_2(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT2_1(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT2_3(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT2_2(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT2_4(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT2_3(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT2_5(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT2_4(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT2_6(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT2_5(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT2_7(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT2_6(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT2_8(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT2_7(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT2_9(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT2_8(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT2_10(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT2_9(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT2_11(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT2_10(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT2_12(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT2_11(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT2_13(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT2_12(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT2_14(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT2_13(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT2_15(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT2_14(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT2_16(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT2_15(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT2_17(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT2_16(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT2_18(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT2_17(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT2_19(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT2_18(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT2_20(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT2_19(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) + +#define TT3_1(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +#define TT3_2(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT3_1(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT3_3(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT3_2(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT3_4(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT3_3(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT3_5(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT3_4(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT3_6(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT3_5(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT3_7(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT3_6(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT3_8(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT3_7(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT3_9(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT3_8(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT3_10(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT3_9(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT3_11(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT3_10(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT3_12(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT3_11(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT3_13(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT3_12(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT3_14(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT3_13(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT3_15(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT3_14(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT3_16(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT3_15(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT3_17(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT3_16(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT3_18(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT3_17(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT3_19(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT3_18(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT3_20(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT3_19(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) + +#define GET_MACRO_SEL_T(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, \ + _13, _14, _15, _16, _17, _18, _19, _20, NAME, ...) \ + NAME +#define GET_MACRO_SEL_P1(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, \ + _13, _14, _15, _16, _17, _18, _19, _20, NAME, ...) \ + NAME +#define GET_MACRO_SEL_P2(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, \ + _13, _14, _15, _16, _17, _18, _19, _20, NAME, ...) \ + NAME +#define GET_MACRO_SEL_TT1(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, \ + _13, _14, _15, _16, _17, _18, _19, _20, NAME, ...) \ + NAME +#define GET_MACRO_SEL_TT2(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, \ + _13, _14, _15, _16, _17, _18, _19, _20, NAME, ...) \ + NAME +#define GET_MACRO_DS(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, \ + _14, _15, _16, _17, _18, _19, _20, NAME, ...) \ + NAME +#define GET_MACRO_DT(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, \ + _14, _15, _16, _17, _18, _19, _20, NAME, ...) \ + NAME +#define GET_MACRO_DP( \ + _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, \ + _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, \ + _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, \ + _47, _48, _49, _50, _51, _52, _53, _54, _55, _56, _57, _58, _59, _60, _61, \ + _62, _63, _64, _65, _66, _67, _68, _69, _70, _71, _72, _73, _74, _75, _76, \ + _77, _78, _79, _80, _81, _82, _83, _84, _85, _86, _87, _88, _89, _90, _91, \ + _92, _93, _94, _95, _96, _97, _98, _99, _100, _101, _102, _103, _104, \ + _105, _106, _107, _108, _109, _110, _111, _112, _113, _114, _115, _116, \ + _117, _118, _119, _120, _121, _122, _123, _124, _125, NAME, ...) \ + NAME +#define GET_MACRO_DT2(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, \ + _14, _15, _16, _17, _18, _19, _20, NAME, ...) \ + NAME + +#define GET_MACRO_TT1(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, \ + _14, _15, _16, _17, _18, _19, _20, NAME, ...) \ + NAME +#define GET_MACRO_TT2(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, \ + _14, _15, _16, _17, _18, _19, _20, NAME, ...) \ + NAME +#define GET_MACRO_TT3(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, \ + _14, _15, _16, _17, _18, _19, _20, NAME, ...) \ + NAME + +#define GET_MACRO_TTT1(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, \ + _14, _15, _16, _17, _18, _19, _20, NAME, ...) \ + NAME +#define GET_MACRO_TTT2(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, \ + _14, _15, _16, _17, _18, _19, _20, NAME, ...) \ + NAME +#define GET_MACRO_TTT3(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, \ + _14, _15, _16, _17, _18, _19, _20, NAME, ...) \ + NAME + +#define FOR_EACH_S1(WHAT, NAME, SIGNATURE, ...) \ + EXPAND(GET_MACRO_SEL_T(__VA_ARGS__, SEL_T_20, SEL_T_19, SEL_T_18, SEL_T_17, \ + SEL_T_16, SEL_T_15, SEL_T_14, SEL_T_13, SEL_T_12, \ + SEL_T_11, SEL_T_10, SEL_T_9, SEL_T_8, SEL_T_7, \ + SEL_T_6, SEL_T_5, SEL_T_4, SEL_T_3, SEL_T_2, \ + SEL_T_1)(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define FOR_EACH_S2(WHAT, YTYPE, NAME, SIGNATURE, TYPES_A, ...) \ + EXPAND(GET_MACRO_SEL_TT1( \ + __VA_ARGS__, SEL_TT1_20, SEL_TT1_19, SEL_TT1_18, SEL_TT1_17, SEL_TT1_16, \ + SEL_TT1_15, SEL_TT1_14, SEL_TT1_13, SEL_TT1_12, SEL_TT1_11, SEL_TT1_10, \ + SEL_TT1_9, SEL_TT1_8, SEL_TT1_7, SEL_TT1_6, SEL_TT1_5, SEL_TT1_4, \ + SEL_TT1_3, SEL_TT1_2, \ + SEL_TT1_1)(WHAT, YTYPE, NAME, SIGNATURE, TYPES_A, __VA_ARGS__)) +#define FOR_EACH_P1(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_A, ...) \ + EXPAND(GET_MACRO_SEL_P1(__VA_ARGS__, SEL_P1_20, SEL_P1_19, SEL_P1_18, \ + SEL_P1_17, SEL_P1_16, SEL_P1_15, SEL_P1_14, \ + SEL_P1_13, SEL_P1_12, SEL_P1_11, SEL_P1_10, \ + SEL_P1_9, SEL_P1_8, SEL_P1_7, SEL_P1_6, SEL_P1_5, \ + SEL_P1_4, SEL_P1_3, SEL_P1_2, SEL_P1_1)( \ + WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_A, __VA_ARGS__)) +#define FOR_EACH_P2(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_A, ...) \ + EXPAND2(GET_MACRO_SEL_P2(__VA_ARGS__, SEL_P2_20, SEL_P2_19, SEL_P2_18, \ + SEL_P2_17, SEL_P2_16, SEL_P2_15, SEL_P2_14, \ + SEL_P2_13, SEL_P2_12, SEL_P2_11, SEL_P2_10, \ + SEL_P2_9, SEL_P2_8, SEL_P2_7, SEL_P2_6, SEL_P2_5, \ + SEL_P2_4, SEL_P2_3, SEL_P2_2, SEL_P2_1)( \ + WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_A, __VA_ARGS__)) +#define FOR_EACH_S3(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + EXPAND(GET_MACRO_SEL_TT2( \ + __VA_ARGS__, SEL_TT2_20, SEL_TT2_19, SEL_TT2_18, SEL_TT2_17, SEL_TT2_16, \ + SEL_TT2_15, SEL_TT2_14, SEL_TT2_13, SEL_TT2_12, SEL_TT2_11, SEL_TT2_10, \ + SEL_TT2_9, SEL_TT2_8, SEL_TT2_7, SEL_TT2_6, SEL_TT2_5, SEL_TT2_4, \ + SEL_TT2_3, SEL_TT2_2, \ + SEL_TT2_1)(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define FOR_EACH_DS(WHAT, NAME, SIGNATURE, ...) \ + EXPAND(GET_MACRO_DS(__VA_ARGS__, DS_20, DS_19, DS_18, DS_17, DS_16, DS_15, \ + DS_14, DS_13, DS_12, DS_11, DS_10, DS_9, DS_8, DS_7, \ + DS_6, DS_5, DS_4, DS_3, DS_2, \ + DS_1)(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define FOR_EACH_DT(WHAT, NAME, SIGNATURE, TYPES_A, ...) \ + EXPAND(GET_MACRO_DT(__VA_ARGS__, DT_20, DT_19, DT_18, DT_17, DT_16, DT_15, \ + DT_14, DT_13, DT_12, DT_11, DT_10, DT_9, DT_8, DT_7, \ + DT_6, DT_5, DT_4, DT_3, DT_2, \ + DT_1)(WHAT, NAME, SIGNATURE, TYPES_A, __VA_ARGS__)) +#define FOR_EACH_DT2(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + EXPAND(GET_MACRO_DT2(__VA_ARGS__, DT2_20, DT2_19, DT2_18, DT2_17, DT2_16, \ + DT2_15, DT2_14, DT2_13, DT2_12, DT2_11, DT2_10, DT2_9, \ + DT2_8, DT2_7, DT2_6, DT2_5, DT2_4, DT2_3, DT2_2, \ + DT2_1)(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define FOR_EACH_DP(WHAT, NAME, SIGNATURE, ...) \ + EXPAND(GET_MACRO_DP( \ + __VA_ARGS__, DP_125, DP_124, DP_123, DP_122, DP_121, DP_120, DP_119, \ + DP_118, DP_117, DP_116, DP_115, DP_114, DP_113, DP_112, DP_111, DP_110, \ + DP_109, DP_108, DP_107, DP_106, DP_105, DP_104, DP_103, DP_102, DP_101, \ + DP_100, DP_99, DP_98, DP_97, DP_96, DP_95, DP_94, DP_93, DP_92, DP_91, \ + DP_90, DP_89, DP_88, DP_87, DP_86, DP_85, DP_84, DP_83, DP_82, DP_81, \ + DP_80, DP_79, DP_78, DP_77, DP_76, DP_75, DP_74, DP_73, DP_72, DP_71, \ + DP_70, DP_69, DP_68, DP_67, DP_66, DP_65, DP_64, DP_63, DP_62, DP_61, \ + DP_60, DP_59, DP_58, DP_57, DP_56, DP_55, DP_54, DP_53, DP_52, DP_51, \ + DP_50, DP_49, DP_48, DP_47, DP_46, DP_45, DP_44, DP_43, DP_42, DP_41, \ + DP_40, DP_39, DP_38, DP_37, DP_36, DP_35, DP_34, DP_33, DP_32, DP_31, \ + DP_30, DP_29, DP_28, DP_27, DP_26, DP_25, DP_24, DP_23, DP_22, DP_21, \ + DP_20, DP_19, DP_18, DP_17, DP_16, DP_15, DP_14, DP_13, DP_12, DP_11, \ + DP_10, DP_9, DP_8, DP_7, DP_6, DP_5, DP_4, DP_3, DP_2, \ + DP_1)(WHAT, NAME, SIGNATURE, __VA_ARGS__)) + +#define FOR_EACH_TT1(WHAT, NAME, SIGNATURE, TYPES_X, TYPES_Y, ...) \ + EXPAND(GET_MACRO_TT1(__VA_ARGS__, TT1_20, TT1_19, TT1_18, TT1_17, TT1_16, \ + TT1_15, TT1_14, TT1_13, TT1_12, TT1_11, TT1_10, TT1_9, \ + TT1_8, TT1_7, TT1_6, TT1_5, TT1_4, TT1_3, TT1_2, \ + TT1_1)(WHAT, NAME, SIGNATURE, TYPES_X, TYPES_Y, \ + __VA_ARGS__)) +#define FOR_EACH_TT2(WHAT, NAME, SIGNATURE, TYPE_Z, TYPES_X, ...) \ + EXPAND(GET_MACRO_TT2(__VA_ARGS__, TT2_20, TT2_19, TT2_18, TT2_17, TT2_16, \ + TT2_15, TT2_14, TT2_13, TT2_12, TT2_11, TT2_10, TT2_9, \ + TT2_8, TT2_7, TT2_6, TT2_5, TT2_4, TT2_3, TT2_2, \ + TT2_1)(WHAT, NAME, SIGNATURE, TYPE_Z, TYPES_X, \ + __VA_ARGS__)) +#define FOR_EACH_TT3(WHAT, NAME, SIGNATURE, TYPE_Z, TYPE_Y, ...) \ + EXPAND(GET_MACRO_TT3(__VA_ARGS__, TT3_20, TT3_19, TT3_18, TT3_17, TT3_16, \ + TT3_15, TT3_14, TT3_13, TT3_12, TT3_11, TT3_10, TT3_9, \ + TT3_8, TT3_7, TT3_6, TT3_5, TT3_4, TT3_3, TT3_2, \ + TT3_1)(WHAT, NAME, SIGNATURE, TYPE_Z, TYPE_Y, \ + __VA_ARGS__)) + +#define FOR_EACH_TTT1(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_Z, TYPES_Y, \ + ...) \ + EXPAND(GET_MACRO_TTT1(__VA_ARGS__, TTT1_20, TTT1_19, TTT1_18, TTT1_17, \ + TTT1_16, TTT1_15, TTT1_14, TTT1_13, TTT1_12, TTT1_11, \ + TTT1_10, TTT1_9, TTT1_8, TTT1_7, TTT1_6, TTT1_5, \ + TTT1_4, TTT1_3, TTT1_2, TTT1_1)( \ + WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_Z, TYPES_Y, __VA_ARGS__)) +#define FOR_EACH_TTT2(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_X, TYPES_Z, ...) \ + EXPAND(GET_MACRO_TTT2(__VA_ARGS__, TTT2_20, TTT2_19, TTT2_18, TTT2_17, \ + TTT2_16, TTT2_15, TTT2_14, TTT2_13, TTT2_12, TTT2_11, \ + TTT2_10, TTT2_9, TTT2_8, TTT2_7, TTT2_6, TTT2_5, \ + TTT2_4, TTT2_3, TTT2_2, TTT2_1)( \ + WHAT, ZTYPE, NAME, SIGNATURE, TYPE_X, TYPES_Z, __VA_ARGS__)) +#define FOR_EACH_TTT3(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_X, TYPE_Y, ...) \ + EXPAND(GET_MACRO_TTT3(__VA_ARGS__, TTT3_20, TTT3_19, TTT3_18, TTT3_17, \ + TTT3_16, TTT3_15, TTT3_14, TTT3_13, TTT3_12, TTT3_11, \ + TTT3_10, TTT3_9, TTT3_8, TTT3_7, TTT3_6, TTT3_5, \ + TTT3_4, TTT3_3, TTT3_2, TTT3_1)( \ + WHAT, ZTYPE, NAME, SIGNATURE, TYPE_X, TYPE_Y, __VA_ARGS__)) + +#define _EXEC_SELECTOR_T(WHAT, NAME, SIGNATURE, ...) \ + EVAL(FOR_EACH_S1(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define _EXEC_SELECTOR_P_1(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, \ + TYPES_A, ...) \ + EVAL(FOR_EACH_P1(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_A, \ + __VA_ARGS__)) +#define _EXEC_SELECTOR_P_2(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, \ + ...) \ + EVAL(FOR_EACH_P2(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, \ + __VA_ARGS__)) +#define _EXEC_SELECTOR_TT_1(WHAT, YTYPE, NAME, SIGNATURE, TYPES_A, ...) \ + EVAL(FOR_EACH_S2(WHAT, YTYPE, NAME, SIGNATURE, TYPES_A, __VA_ARGS__)) +#define _EXEC_SELECTOR_TT_2(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + EVAL(FOR_EACH_S3(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define _EXEC_SINGLE_T(WHAT, NAME, SIGNATURE, ...) \ + EVAL(FOR_EACH_DS(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define _EXEC_DOUBLE_T(WHAT, NAME, SIGNATURE, TYPES_A, ...) \ + EVAL(FOR_EACH_DT(WHAT, NAME, SIGNATURE, LIST(TYPES_A), __VA_ARGS__)) +#define _EXEC_DOUBLE_T2(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + EVAL(FOR_EACH_DT2(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define _EXEC_DOUBLE_P(WHAT, NAME, SIGNATURE, ...) \ + EVAL(FOR_EACH_DP(WHAT, NAME, SIGNATURE, __VA_ARGS__)) + +#define _EXEC_SELECTOR_TTT_1(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_Z, \ + TYPES_Y, ...) \ + EVAL(FOR_EACH_TTT1(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_Z, TYPES_Y, \ + __VA_ARGS__)) +#define _EXEC_SELECTOR_TTT_2(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_X, TYPES_Z, \ + ...) \ + EVAL(FOR_EACH_TTT2(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_X, TYPES_Z, \ + __VA_ARGS__)) +#define _EXEC_SELECTOR_TTT_3(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_X, TYPE_Y, \ + ...) \ + EVAL(FOR_EACH_TTT3(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_X, TYPE_Y, __VA_ARGS__)) + +#define _EXEC_TRIPLE_T3(WHAT, NAME, SIGNATURE, TYPE_Z, TYPE_Y, ...) \ + EVAL(FOR_EACH_TT3(WHAT, NAME, SIGNATURE, TYPE_Z, TYPE_Y, __VA_ARGS__)) +#define _EXEC_TRIPLE_T2(WHAT, NAME, SIGNATURE, TYPE_Z, TYPES_X, ...) \ + EVAL(FOR_EACH_TT2(WHAT, NAME, SIGNATURE, TYPE_Z, LIST(TYPES_X), __VA_ARGS__)) +#define _EXEC_TRIPLE_T1(WHAT, NAME, SIGNATURE, TYPES_X, TYPES_Y, ...) \ + EVAL(FOR_EACH_TT1(WHAT, NAME, SIGNATURE, LIST(TYPES_X), LIST(TYPES_Y), \ + __VA_ARGS__)) + +#define DISPATCH_PAIRWISE(NAME, SIGNATURE, TYPE, TYPES_B) \ + EVAL(_EXEC_DOUBLE_T2(RANDOMPAIRWISE2, NAME, SIGNATURE, TYPE, TYPES_B)) +#define DISPATCH_PAIRWISE2(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE, ...) \ + EVAL(_EXEC_SELECTOR_P_2(SELECTOR_PAIRWISE_2, XTYPE, YTYPE, ZTYPE, NAME, \ + SIGNATURE, TYPE, __VA_ARGS__)) + +#define DISPATCH_DTYPES(NAME, SIGNATURE, TYPE, TYPES_B) \ + EVAL(_EXEC_DOUBLE_T2(RANDOMDOUBLE2, NAME, SIGNATURE, TYPE, TYPES_B)) +#define DISPATCH_DTYPES2(NAME, SIGNATURE, TYPE, ...) \ + EVAL(_EXEC_SELECTOR_TT_2(SELECTOR_DOUBLE_2, NAME, SIGNATURE, TYPE, \ + __VA_ARGS__)) + +#define DISPATCH_TTYPES2(ZTYPE, NAME, SIGNATURE, TYPE_X, TYPES_Z, ...) \ + EVAL(_EXEC_SELECTOR_TTT_2(SELECTOR_TRIPLE_2, ZTYPE, NAME, SIGNATURE, TYPE_X, \ + TYPES_Z, __VA_ARGS__)) +#define DISPATCH_TTYPES3(ZTYPE, NAME, SIGNATURE, TYPE_X, TYPE_Y, ...) \ + EVAL(_EXEC_SELECTOR_TTT_3(SELECTOR_TRIPLE_3, ZTYPE, NAME, SIGNATURE, TYPE_X, \ + TYPE_Y, __VA_ARGS__)) #ifndef __CLION_IDE__ -#define BUILD_SINGLE_UNCHAINED_TEMPLATE(NAME, SIGNATURE, TYPES) EVAL(_EXEC_SINGLE_T(RANDOMSINGLEU, NAME, (SIGNATURE), TYPES)) -#define BUILD_SINGLE_TEMPLATE(NAME, SIGNATURE, TYPES) EVAL(_EXEC_SINGLE_T(RANDOMSINGLE, NAME, (SIGNATURE), TYPES)) -#define BUILD_SINGLE_TEMPLATE_TWICE(NAME, SIGNATURE, TYPES) EVAL(_EXEC_SELECTOR_T(TEMPLATE_SINGLE_TWICE, NAME, SIGNATURE, TYPES)) -#define BUILD_DOUBLE_TEMPLATE(NAME, SIGNATURE, TYPES_A, TYPES_B) EVAL(_EXEC_DOUBLE_T(RANDOMDOUBLE, NAME, (SIGNATURE), (TYPES_A), TYPES_B)) -#define BUILD_SINGLE_SELECTOR(XTYPE, NAME, SIGNATURE, TYPES) switch(XTYPE) { EVAL(_EXEC_SELECTOR_T(SELECTOR_SINGLE, NAME, SIGNATURE, TYPES)); default: {printf("[ERROR] Unknown dtypeX=%d on %s:%d", XTYPE, __FILE__, __LINE__); fflush(stdout); throw std::runtime_error("bad data type");}} -#define BUILD_SINGLE_SELECTOR_TWICE(XTYPE, NAME, SIGNATURE, TYPES) switch(XTYPE) { EVAL(_EXEC_SELECTOR_T(SELECTOR_SINGLE_TWICE, NAME, SIGNATURE, TYPES)); default: {printf("[ERROR] Unknown dtypeX=%d on %s:%d", XTYPE, __FILE__, __LINE__); fflush(stdout); throw std::runtime_error("bad data type");}} -#define BUILD_SINGLE_SELECTOR_THRICE(XTYPE, NAME, SIGNATURE, TYPES) switch(XTYPE) { EVAL(_EXEC_SELECTOR_T(SELECTOR_SINGLE_THRICE, NAME, SIGNATURE, TYPES)); default: {printf("[ERROR] Unknown dtypeX=%d on %s:%d", XTYPE, __FILE__, __LINE__); fflush(stdout); throw std::runtime_error("bad data type");}} - - -#define BUILD_SINGLE_PARTIAL_SELECTOR(XTYPE, NAME, SIGNATURE, TYPES) switch(XTYPE) { EVAL(_EXEC_SELECTOR_T(SELECTOR_PARTIAL_SINGLE, NAME, SIGNATURE, TYPES)); default: {printf("[ERROR] Unknown dtypeX=%d on %s:%d", XTYPE, __FILE__, __LINE__); fflush(stdout); throw std::runtime_error("bad data type"); }} -#define BUILD_DOUBLE_SELECTOR(XTYPE, YTYPE, NAME, SIGNATURE, TYPES_A, TYPES_B) switch(XTYPE) { EVAL(_EXEC_SELECTOR_TT_1(SELECTOR_DOUBLE, YTYPE, NAME, (SIGNATURE), (TYPES_B), TYPES_A)); default: {printf("[ERROR] Unknown dtypeX=%d on %s:%d", XTYPE, __FILE__, __LINE__); fflush(stdout); throw std::runtime_error("bad data type");}} -#define BUILD_TRIPLE_SELECTOR(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_X, TYPES_Y, TYPES_Z) switch(XTYPE) { EVAL(_EXEC_SELECTOR_TTT_1(SELECTOR_TRIPLE, YTYPE, ZTYPE, NAME, SIGNATURE, (TYPES_Z), (TYPES_Y), TYPES_X)); default: {printf("[ERROR] Unknown dtypeX=%d on %s:%d", XTYPE, __FILE__, __LINE__); fflush(stdout); throw std::runtime_error("bad data type"); } } -#define BUILD_TRIPLE_TEMPLATE(NAME, SIGNATURE, TYPES_X, TYPES_Y, TYPES_Z) EVAL(_EXEC_TRIPLE_T1(RANDOMTRIPLE, NAME, (SIGNATURE), (TYPES_X), (TYPES_Y), TYPES_Z)) -#define BUILD_PAIRWISE_TEMPLATE(NAME, SIGNATURE, TYPES_A) EVAL(_EXEC_DOUBLE_P(RANDOMPAIRWISE, NAME, (SIGNATURE), TYPES_A)) -#define BUILD_PAIRWISE_SELECTOR(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_A, TYPES_B) switch(XTYPE) { EVAL(_EXEC_SELECTOR_P_1(SELECTOR_PAIRWISE, XTYPE, YTYPE, ZTYPE, NAME, (SIGNATURE), (TYPES_B), TYPES_A)); default: {printf("[ERROR] Unknown dtypeX=%d on %s:%d", XTYPE, __FILE__, __LINE__); fflush(stdout); throw std::runtime_error("bad data type"); }} +#define BUILD_SINGLE_UNCHAINED_TEMPLATE(NAME, SIGNATURE, TYPES) \ + EVAL(_EXEC_SINGLE_T(RANDOMSINGLEU, NAME, (SIGNATURE), TYPES)) +#define BUILD_SINGLE_TEMPLATE(NAME, SIGNATURE, TYPES) \ + EVAL(_EXEC_SINGLE_T(RANDOMSINGLE, NAME, (SIGNATURE), TYPES)) +#define BUILD_SINGLE_TEMPLATE_TWICE(NAME, SIGNATURE, TYPES) \ + EVAL(_EXEC_SELECTOR_T(TEMPLATE_SINGLE_TWICE, NAME, SIGNATURE, TYPES)) +#define BUILD_DOUBLE_TEMPLATE(NAME, SIGNATURE, TYPES_A, TYPES_B) \ + EVAL(_EXEC_DOUBLE_T(RANDOMDOUBLE, NAME, (SIGNATURE), (TYPES_A), TYPES_B)) +#define BUILD_SINGLE_SELECTOR(XTYPE, NAME, SIGNATURE, TYPES) \ + switch (XTYPE) { \ + EVAL(_EXEC_SELECTOR_T(SELECTOR_SINGLE, NAME, SIGNATURE, TYPES)); \ + default: { \ + printf("[ERROR] Unknown dtypeX=%d on %s:%d", XTYPE, __FILE__, __LINE__); \ + fflush(stdout); \ + throw std::runtime_error("bad data type"); \ + } \ + } +#define BUILD_SINGLE_SELECTOR_TWICE(XTYPE, NAME, SIGNATURE, TYPES) \ + switch (XTYPE) { \ + EVAL(_EXEC_SELECTOR_T(SELECTOR_SINGLE_TWICE, NAME, SIGNATURE, TYPES)); \ + default: { \ + printf("[ERROR] Unknown dtypeX=%d on %s:%d", XTYPE, __FILE__, __LINE__); \ + fflush(stdout); \ + throw std::runtime_error("bad data type"); \ + } \ + } +#define BUILD_SINGLE_SELECTOR_THRICE(XTYPE, NAME, SIGNATURE, TYPES) \ + switch (XTYPE) { \ + EVAL(_EXEC_SELECTOR_T(SELECTOR_SINGLE_THRICE, NAME, SIGNATURE, TYPES)); \ + default: { \ + printf("[ERROR] Unknown dtypeX=%d on %s:%d", XTYPE, __FILE__, __LINE__); \ + fflush(stdout); \ + throw std::runtime_error("bad data type"); \ + } \ + } + +#define BUILD_SINGLE_PARTIAL_SELECTOR(XTYPE, NAME, SIGNATURE, TYPES) \ + switch (XTYPE) { \ + EVAL(_EXEC_SELECTOR_T(SELECTOR_PARTIAL_SINGLE, NAME, SIGNATURE, TYPES)); \ + default: { \ + printf("[ERROR] Unknown dtypeX=%d on %s:%d", XTYPE, __FILE__, __LINE__); \ + fflush(stdout); \ + throw std::runtime_error("bad data type"); \ + } \ + } +#define BUILD_DOUBLE_SELECTOR(XTYPE, YTYPE, NAME, SIGNATURE, TYPES_A, TYPES_B) \ + switch (XTYPE) { \ + EVAL(_EXEC_SELECTOR_TT_1(SELECTOR_DOUBLE, YTYPE, NAME, (SIGNATURE), \ + (TYPES_B), TYPES_A)); \ + default: { \ + printf("[ERROR] Unknown dtypeX=%d on %s:%d", XTYPE, __FILE__, __LINE__); \ + fflush(stdout); \ + throw std::runtime_error("bad data type"); \ + } \ + } +#define BUILD_TRIPLE_SELECTOR(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_X, \ + TYPES_Y, TYPES_Z) \ + switch (XTYPE) { \ + EVAL(_EXEC_SELECTOR_TTT_1(SELECTOR_TRIPLE, YTYPE, ZTYPE, NAME, SIGNATURE, \ + (TYPES_Z), (TYPES_Y), TYPES_X)); \ + default: { \ + printf("[ERROR] Unknown dtypeX=%d on %s:%d", XTYPE, __FILE__, __LINE__); \ + fflush(stdout); \ + throw std::runtime_error("bad data type"); \ + } \ + } +#define BUILD_TRIPLE_TEMPLATE(NAME, SIGNATURE, TYPES_X, TYPES_Y, TYPES_Z) \ + EVAL(_EXEC_TRIPLE_T1(RANDOMTRIPLE, NAME, (SIGNATURE), (TYPES_X), (TYPES_Y), \ + TYPES_Z)) +#define BUILD_PAIRWISE_TEMPLATE(NAME, SIGNATURE, TYPES_A) \ + EVAL(_EXEC_DOUBLE_P(RANDOMPAIRWISE, NAME, (SIGNATURE), TYPES_A)) +#define BUILD_PAIRWISE_SELECTOR(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_A, \ + TYPES_B) \ + switch (XTYPE) { \ + EVAL(_EXEC_SELECTOR_P_1(SELECTOR_PAIRWISE, XTYPE, YTYPE, ZTYPE, NAME, \ + (SIGNATURE), (TYPES_B), TYPES_A)); \ + default: { \ + printf("[ERROR] Unknown dtypeX=%d on %s:%d", XTYPE, __FILE__, __LINE__); \ + fflush(stdout); \ + throw std::runtime_error("bad data type"); \ + } \ + } #else #define BUILD_SINGLE_UNCHAINED_TEMPLATE(NAME, SIGNATURE, TYPES) #define BUILD_SINGLE_TEMPLATE(NAME, SIGNATURE, TYPES) @@ -569,74 +1543,210 @@ #define BUILD_SINGLE_SELECTOR_THRICE(XTYPE, NAME, SIGNATURE, TYPES) #define BUILD_SINGLE_PARTIAL_SELECTOR(XTYPE, NAME, SIGNATURE, TYPES) #define BUILD_DOUBLE_SELECTOR(XTYPE, YTYPE, NAME, SIGNATURE, TYPES_A, TYPES_B) -#define BUILD_TRIPLE_SELECTOR(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_X, TYPES_Y, TYPES_Z) +#define BUILD_TRIPLE_SELECTOR(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_X, \ + TYPES_Y, TYPES_Z) #define BUILD_TRIPLE_TEMPLATE(NAME, SIGNATURE, TYPES_X, TYPES_Y, TYPES_Z) #define BUILD_PAIRWISE_TEMPLATE(NAME, SIGNATURE, TYPES_A) -#define BUILD_PAIRWISE_SELECTOR(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_A, TYPES_B) +#define BUILD_PAIRWISE_SELECTOR(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_A, \ + TYPES_B) #endif #define LIST(...) __VA_ARGS__ -#define _SELECTOR_DOUBLE_2(NAME, SIGNATURE, TYPE_A, ENUM, TYPE_B) case ENUM: { NAME SIGNATURE; break; }; -#define SELECTOR_DOUBLE_2(NAME, SIGNATURE, TYPE_A, TYPE_B) EVALUATING_PASTE2(_SELECT, OR_DOUBLE_2(NAME, UNPAREN3(SIGNATURE), TYPE_A, UNPAREN3(TYPE_B))) - -#define _SELECTOR_DOUBLE(YTYPE, NAME, SIGNATURE, ENUM, TYPE_A, ...) case ENUM: { switch(YTYPE) { EXPAND(DISPATCH_DTYPES2(NAME, SIGNATURE, TYPE_A, __VA_ARGS__)); default: {printf("[ERROR] Unknown dtypeX=%d on %s:%d\n", YTYPE, __FILE__, __LINE__); fflush(stdout);}}; break; }; -#define SELECTOR_DOUBLE(YTYPE, NAME, SIGNATURE, TYPES_B, TYPE_A) EVALUATING_PASTE(_SELECTOR, _DOUBLE(YTYPE, NAME, SIGNATURE, UNPAREN(TYPE_A), UNPAREN(TYPES_B))) - -#define _SELECTOR_PAIRWISE_2(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, ENUM, TYPE_B) case ENUM: { if (ZTYPE == YTYPE) {NAME SIGNATURE;} else if (XTYPE == ZTYPE ){NAME SIGNATURE;} else {printf("[ERROR] Unknown dtypeX=%d on %s:%d\n", YTYPE, __FILE__, __LINE__); fflush(stdout); throw std::runtime_error("Unknown Z operand");}; break; }; -#define SELECTOR_PAIRWISE_2(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) EVALUATING_PASTE2(_SELECT, OR_PAIRWISE_2(XTYPE, YTYPE, ZTYPE, NAME, UNPAREN3(SIGNATURE), TYPE_A, UNPAREN3(TYPE_B))) -#define _SELECTOR_PAIRWISE(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, ENUM, TYPE_A, ...) case ENUM: { switch(YTYPE) { EXPAND(DISPATCH_PAIRWISE2(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)); default: {printf("[ERROR] Unknown dtypeX=%d on %s:%d\n", YTYPE, __FILE__, __LINE__); fflush(stdout);}}; break; }; -#define SELECTOR_PAIRWISE(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_B, TYPE_A) EVALUATING_PASTE(_SELECTOR, _PAIRWISE(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, UNPAREN(TYPE_A), UNPAREN(TYPES_B))) - -#define _SELECTOR_TRIPLE_3(NAME, SIGNATURE, TYPE_X, TYPE_Y, ENUM_Z, TYPE_Z) case ENUM_Z: { NAMESIGNATURE;}; break; -#define SELECTOR_TRIPLE_3(ZTYPE, NAME, SIGNATURE, TYPE_X, TYPE_Y, TYPE_Z) EVALUATING_PASTE3(_SELECTOR, _TRIPLE_3(NAME, SIGNATURE, TYPE_X, TYPE_Y, UNPAREN3(TYPE_Z))) -#define _SELECTOR_TRIPLE_2(ZTYPE, NAME, SIGNATURE, TYPE_X, ENUM_Y, TYPE_Y, TYPES_Z) case ENUM_Y: { switch (ZTYPE) { EXPAND2(DISPATCH_TTYPES3(ZTYPE, NAME, SIGNATURE, TYPE_X, TYPE_Y, UNPAREN3(TYPES_Z))); default: {printf("[ERROR] Unknown dtypeZ=%d on %s:%d\n", ZTYPE, __FILE__, __LINE__); ; fflush(stdout);} } break; }; -#define SELECTOR_TRIPLE_2(ZTYPE, NAME, SIGNATURE, TYPE_X, TYPES_Z, TYPE_Y) EVALUATING_PASTE2(_SELECTOR, _TRIPLE_2(ZTYPE, NAME, SIGNATURE, TYPE_X, UNPAREN2(TYPE_Y), TYPES_Z)) -#define _SELECTOR_TRIPLE(YTYPE, ZTYPE, NAME, SIGNATURE, ENUM_X, TYPE_X, TYPES_Z, ...) case ENUM_X: { switch (YTYPE) { EXPAND(DISPATCH_TTYPES2(ZTYPE, NAME, SIGNATURE, TYPE_X, TYPES_Z, __VA_ARGS__ )); default: {printf("[ERROR] Unknown dtypeY=%d on %s:%d\n", YTYPE, __FILE__, __LINE__); ; fflush(stdout);} } break; }; -#define SELECTOR_TRIPLE(YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_Z, TYPES_Y, TYPE_X) EVALUATING_PASTE(_SELECTOR, _TRIPLE(YTYPE, ZTYPE, NAME, SIGNATURE, UNPAREN(TYPE_X), TYPES_Z, UNPAREN(TYPES_Y))) - -#define _SELECTOR_SINGLE(A, B, C, D) case C: {AB; break;}; -#define SELECTOR_SINGLE(A, B, C) EVALUATING_PASTE(_SEL, ECTOR_SINGLE(A, B, UNPAREN(C))) - -#define _SELECTOR_SINGLE_THRICE(A, B, C, D) case C: {AB; break;}; -#define SELECTOR_SINGLE_THRICE(A, B, C) EVALUATING_PASTE(_SEL, ECTOR_SINGLE_THRICE(A, B, UNPAREN(C))) - -#define _SELECTOR_SINGLE_TWICE(A, B, C, D) case C: {AB; break;}; -#define SELECTOR_SINGLE_TWICE(A, B, C) EVALUATING_PASTE(_SEL, ECTOR_SINGLE_TWICE(A, B, UNPAREN(C))) - -#define _TEMPLATE_SINGLE_TWICE(A, B, C, D) AB; -#define TEMPLATE_SINGLE_TWICE(A, B, C) EVALUATING_PASTE(_TEM, PLATE_SINGLE_TWICE(A, B, UNPAREN(C))) - -#define _SELECTOR_PARTIAL_SINGLE(A, B, C, D) case C: {A D, UNPAREN2(B); break;}; -#define SELECTOR_PARTIAL_SINGLE(A, B, C) EVALUATING_PASTE(_SEL, ECTOR_PARTIAL_SINGLE(A, B, UNPAREN(C))) - -#define _RANDOMSINGLE(A, B, C, D) AB; +#define _SELECTOR_DOUBLE_2(NAME, SIGNATURE, TYPE_A, ENUM, TYPE_B) \ + case ENUM: { \ + NAME SIGNATURE; \ + break; \ + }; +#define SELECTOR_DOUBLE_2(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVALUATING_PASTE2(_SELECT, OR_DOUBLE_2(NAME, UNPAREN3(SIGNATURE), TYPE_A, \ + UNPAREN3(TYPE_B))) + +#define _SELECTOR_DOUBLE(YTYPE, NAME, SIGNATURE, ENUM, TYPE_A, ...) \ + case ENUM: { \ + switch (YTYPE) { \ + EXPAND(DISPATCH_DTYPES2(NAME, SIGNATURE, TYPE_A, __VA_ARGS__)); \ + default: { \ + printf("[ERROR] Unknown dtypeX=%d on %s:%d\n", YTYPE, __FILE__, \ + __LINE__); \ + fflush(stdout); \ + } \ + }; \ + break; \ + }; +#define SELECTOR_DOUBLE(YTYPE, NAME, SIGNATURE, TYPES_B, TYPE_A) \ + EVALUATING_PASTE(_SELECTOR, _DOUBLE(YTYPE, NAME, SIGNATURE, UNPAREN(TYPE_A), \ + UNPAREN(TYPES_B))) + +#define _SELECTOR_PAIRWISE_2(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, \ + ENUM, TYPE_B) \ + case ENUM: { \ + if (ZTYPE == YTYPE) { \ + NAME SIGNATURE; \ + } else if (XTYPE == ZTYPE) { \ + NAME SIGNATURE; \ + } else { \ + printf("[ERROR] Unknown dtypeX=%d on %s:%d\n", YTYPE, __FILE__, \ + __LINE__); \ + fflush(stdout); \ + throw std::runtime_error("Unknown Z operand"); \ + }; \ + break; \ + }; +#define SELECTOR_PAIRWISE_2(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, \ + TYPE_B) \ + EVALUATING_PASTE2( \ + _SELECT, OR_PAIRWISE_2(XTYPE, YTYPE, ZTYPE, NAME, UNPAREN3(SIGNATURE), \ + TYPE_A, UNPAREN3(TYPE_B))) +#define _SELECTOR_PAIRWISE(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, ENUM, TYPE_A, \ + ...) \ + case ENUM: { \ + switch (YTYPE) { \ + EXPAND(DISPATCH_PAIRWISE2(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, \ + __VA_ARGS__)); \ + default: { \ + printf("[ERROR] Unknown dtypeX=%d on %s:%d\n", YTYPE, __FILE__, \ + __LINE__); \ + fflush(stdout); \ + } \ + }; \ + break; \ + }; +#define SELECTOR_PAIRWISE(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_B, \ + TYPE_A) \ + EVALUATING_PASTE(_SELECTOR, _PAIRWISE(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, \ + UNPAREN(TYPE_A), UNPAREN(TYPES_B))) + +#define _SELECTOR_TRIPLE_3(NAME, SIGNATURE, TYPE_X, TYPE_Y, ENUM_Z, TYPE_Z) \ + case ENUM_Z: { \ + NAME SIGNATURE; \ + }; break; +#define SELECTOR_TRIPLE_3(ZTYPE, NAME, SIGNATURE, TYPE_X, TYPE_Y, TYPE_Z) \ + EVALUATING_PASTE3( \ + _SELECTOR, _TRIPLE_3(NAME, SIGNATURE, TYPE_X, TYPE_Y, UNPAREN3(TYPE_Z))) +#define _SELECTOR_TRIPLE_2(ZTYPE, NAME, SIGNATURE, TYPE_X, ENUM_Y, TYPE_Y, \ + TYPES_Z) \ + case ENUM_Y: { \ + switch (ZTYPE) { \ + EXPAND2(DISPATCH_TTYPES3(ZTYPE, NAME, SIGNATURE, TYPE_X, TYPE_Y, \ + UNPAREN3(TYPES_Z))); \ + default: { \ + printf("[ERROR] Unknown dtypeZ=%d on %s:%d\n", ZTYPE, __FILE__, \ + __LINE__); \ + ; \ + fflush(stdout); \ + } \ + } \ + break; \ + }; +#define SELECTOR_TRIPLE_2(ZTYPE, NAME, SIGNATURE, TYPE_X, TYPES_Z, TYPE_Y) \ + EVALUATING_PASTE2(_SELECTOR, _TRIPLE_2(ZTYPE, NAME, SIGNATURE, TYPE_X, \ + UNPAREN2(TYPE_Y), TYPES_Z)) +#define _SELECTOR_TRIPLE(YTYPE, ZTYPE, NAME, SIGNATURE, ENUM_X, TYPE_X, \ + TYPES_Z, ...) \ + case ENUM_X: { \ + switch (YTYPE) { \ + EXPAND(DISPATCH_TTYPES2(ZTYPE, NAME, SIGNATURE, TYPE_X, TYPES_Z, \ + __VA_ARGS__)); \ + default: { \ + printf("[ERROR] Unknown dtypeY=%d on %s:%d\n", YTYPE, __FILE__, \ + __LINE__); \ + ; \ + fflush(stdout); \ + } \ + } \ + break; \ + }; +#define SELECTOR_TRIPLE(YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_Z, TYPES_Y, \ + TYPE_X) \ + EVALUATING_PASTE(_SELECTOR, \ + _TRIPLE(YTYPE, ZTYPE, NAME, SIGNATURE, UNPAREN(TYPE_X), \ + TYPES_Z, UNPAREN(TYPES_Y))) + +#define _SELECTOR_SINGLE(A, B, C, D) \ + case C: { \ + A B; \ + break; \ + }; +#define SELECTOR_SINGLE(A, B, C) \ + EVALUATING_PASTE(_SEL, ECTOR_SINGLE(A, B, UNPAREN(C))) + +#define _SELECTOR_SINGLE_THRICE(A, B, C, D) \ + case C: { \ + A B; \ + break; \ + }; +#define SELECTOR_SINGLE_THRICE(A, B, C) \ + EVALUATING_PASTE(_SEL, ECTOR_SINGLE_THRICE(A, B, UNPAREN(C))) + +#define _SELECTOR_SINGLE_TWICE(A, B, C, D) \ + case C: { \ + A B; \ + break; \ + }; +#define SELECTOR_SINGLE_TWICE(A, B, C) \ + EVALUATING_PASTE(_SEL, ECTOR_SINGLE_TWICE(A, B, UNPAREN(C))) + +#define _TEMPLATE_SINGLE_TWICE(A, B, C, D) A B; +#define TEMPLATE_SINGLE_TWICE(A, B, C) \ + EVALUATING_PASTE(_TEM, PLATE_SINGLE_TWICE(A, B, UNPAREN(C))) + +#define _SELECTOR_PARTIAL_SINGLE(A, B, C, D) \ + case C: { \ + A D, UNPAREN2(B); \ + break; \ + }; +#define SELECTOR_PARTIAL_SINGLE(A, B, C) \ + EVALUATING_PASTE(_SEL, ECTOR_PARTIAL_SINGLE(A, B, UNPAREN(C))) + +#define _RANDOMSINGLE(A, B, C, D) A B; #define _RANDOMSINGLEU(A, B, C, D) A D B; -#define RANDOMSINGLE(A, B, C) EVALUATING_PASTE(_RAND, OMSINGLE(A, UNPAREN(B), UNPAREN(C))) -#define RANDOMSINGLEU(A, B, C) EVALUATING_PASTE(_RAND, OMSINGLEU(A, UNPAREN(B), UNPAREN(C))) -#define RANDOMDOUBLE(A, B, C, D) EXPAND(DISPATCH_DTYPES(A, UNPAREN(B), D, UNPAREN(C))) - -#define _RANDOMDOUBLE2(A, B, C, D, E, F) AB; -#define RANDOMDOUBLE2(A, B, C, D) EVALUATING_PASTE(_RAND, OMDOUBLE2(A, B, UNPAREN(C), UNPAREN(D))) - -#define _RANDOMPAIRWISE2(A, B, C, D, E) AE; -#define RANDOMPAIRWISE(A, B, C) EVALUATING_PASTE(_RANDOM, PAIRWISE2(A, UNPAREN(C), UNPAREN(B))) - -#define _RANDOMTRIPLE3(A, B, ZN, ZT, YN, YT, XN, XT) AB; -#define RANDOMTRIPLE3(A, B, Z, Y, X) EVALUATING_PASTE(_RANDOM, TRIPLE3(A, UNPAREN(B), UNPAREN(Z), UNPAREN(Y), UNPAREN(X))) - -#define _RANDOMTRIPLE2(NAME, SIGNATURE, TYPE_Z, TYPE_Y, TYPES_X) EVALX(_EXEC_TRIPLE_T3(RANDOMTRIPLE3, NAME, SIGNATURE, TYPE_Z, TYPE_Y, UNPAREN(TYPES_X))) -#define RANDOMTRIPLE2(NAME, SIGNATURE, TYPE_Z, TYPES_X, TYPE_Y) _RANDOMTRIPLE2(NAME, SIGNATURE, TYPE_Z, TYPE_Y, TYPES_X) -#define _RANDOMTRIPLE(NAME, SIGNATURE, TYPE_Z, TYPES_X, TYPES_Y) EVAL(_EXEC_TRIPLE_T2(RANDOMTRIPLE2, NAME, SIGNATURE, TYPE_Z, TYPES_X, UNPAREN(TYPES_Y))) -#define RANDOMTRIPLE(NAME, SIGNATURE, TYPES_X, TYPES_Y, TYPE_Z) _RANDOMTRIPLE(NAME, SIGNATURE, TYPE_Z, TYPES_X, TYPES_Y) - - -#define BROADCAST(NAME) sd::BroadcastOpsTuple::custom(sd::scalar::NAME, sd::pairwise::NAME, sd::broadcast::NAME) -#define BROADCAST_BOOL(NAME) sd::BroadcastBoolOpsTuple::custom(sd::scalar::NAME, sd::pairwise::NAME, sd::broadcast::NAME) +#define RANDOMSINGLE(A, B, C) \ + EVALUATING_PASTE(_RAND, OMSINGLE(A, UNPAREN(B), UNPAREN(C))) +#define RANDOMSINGLEU(A, B, C) \ + EVALUATING_PASTE(_RAND, OMSINGLEU(A, UNPAREN(B), UNPAREN(C))) +#define RANDOMDOUBLE(A, B, C, D) \ + EXPAND(DISPATCH_DTYPES(A, UNPAREN(B), D, UNPAREN(C))) + +#define _RANDOMDOUBLE2(A, B, C, D, E, F) A B; +#define RANDOMDOUBLE2(A, B, C, D) \ + EVALUATING_PASTE(_RAND, OMDOUBLE2(A, B, UNPAREN(C), UNPAREN(D))) + +#define _RANDOMPAIRWISE2(A, B, C, D, E) A E; +#define RANDOMPAIRWISE(A, B, C) \ + EVALUATING_PASTE(_RANDOM, PAIRWISE2(A, UNPAREN(C), UNPAREN(B))) + +#define _RANDOMTRIPLE3(A, B, ZN, ZT, YN, YT, XN, XT) A B; +#define RANDOMTRIPLE3(A, B, Z, Y, X) \ + EVALUATING_PASTE(_RANDOM, \ + TRIPLE3(A, UNPAREN(B), UNPAREN(Z), UNPAREN(Y), UNPAREN(X))) + +#define _RANDOMTRIPLE2(NAME, SIGNATURE, TYPE_Z, TYPE_Y, TYPES_X) \ + EVALX(_EXEC_TRIPLE_T3(RANDOMTRIPLE3, NAME, SIGNATURE, TYPE_Z, TYPE_Y, \ + UNPAREN(TYPES_X))) +#define RANDOMTRIPLE2(NAME, SIGNATURE, TYPE_Z, TYPES_X, TYPE_Y) \ + _RANDOMTRIPLE2(NAME, SIGNATURE, TYPE_Z, TYPE_Y, TYPES_X) +#define _RANDOMTRIPLE(NAME, SIGNATURE, TYPE_Z, TYPES_X, TYPES_Y) \ + EVAL(_EXEC_TRIPLE_T2(RANDOMTRIPLE2, NAME, SIGNATURE, TYPE_Z, TYPES_X, \ + UNPAREN(TYPES_Y))) +#define RANDOMTRIPLE(NAME, SIGNATURE, TYPES_X, TYPES_Y, TYPE_Z) \ + _RANDOMTRIPLE(NAME, SIGNATURE, TYPE_Z, TYPES_X, TYPES_Y) + +#define BROADCAST(NAME) \ + sd::BroadcastOpsTuple::custom(sd::scalar::NAME, sd::pairwise::NAME, \ + sd::broadcast::NAME) +#define BROADCAST_BOOL(NAME) \ + sd::BroadcastBoolOpsTuple::custom(sd::scalar::NAME, sd::pairwise::NAME, \ + sd::broadcast::NAME) #define ALL_STRINGS sd::DataType::UTF8, sd::DataType::UTF16, sd::DataType::UTF32 #define ALL_INDICES sd::DataType::INT32, sd::DataType::INT64 -#define ALL_INTS sd::DataType::INT8, sd::DataType::UINT8, sd::DataType::INT16, sd::DataType::UINT16, sd::DataType::INT32, sd::DataType::UINT32, sd::DataType::INT64, sd::DataType::UINT64 -#define ALL_FLOATS sd::DataType::HALF, sd::DataType::FLOAT32, sd::DataType::DOUBLE, sd::DataType::BFLOAT16 - -#endif //TESTS_CPU_TYPE_BOILERPLATE_H +#define ALL_INTS \ + sd::DataType::INT8, sd::DataType::UINT8, sd::DataType::INT16, \ + sd::DataType::UINT16, sd::DataType::INT32, sd::DataType::UINT32, \ + sd::DataType::INT64, sd::DataType::UINT64 +#define ALL_FLOATS \ + sd::DataType::HALF, sd::DataType::FLOAT32, sd::DataType::DOUBLE, \ + sd::DataType::BFLOAT16 + +#endif // TESTS_CPU_TYPE_BOILERPLATE_H diff --git a/libnd4j/include/system/util.h b/libnd4j/include/system/util.h index aa2055606fcc..fbd3214c4c35 100644 --- a/libnd4j/include/system/util.h +++ b/libnd4j/include/system/util.h @@ -14,7 +14,7 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -/* +/* * File: util.h * Author: saudet * @@ -34,14 +34,14 @@ static inline Nd4jLong microTime() { #ifdef WIN32 - LARGE_INTEGER freq, count; - QueryPerformanceFrequency(&freq); - QueryPerformanceCounter(&count); - return (Nd4jLong)count.QuadPart/freq.QuadPart; + LARGE_INTEGER freq, count; + QueryPerformanceFrequency(&freq); + QueryPerformanceCounter(&count); + return (Nd4jLong)count.QuadPart / freq.QuadPart; #else - timeval tv; - gettimeofday(&tv, NULL); - return (Nd4jLong)tv.tv_sec*1000000 + tv.tv_usec; + timeval tv; + gettimeofday(&tv, NULL); + return (Nd4jLong)tv.tv_sec * 1000000 + tv.tv_usec; #endif } diff --git a/libnd4j/include/types/bfloat16.h b/libnd4j/include/types/bfloat16.h index a0590981605f..2be0f4e3c140 100644 --- a/libnd4j/include/types/bfloat16.h +++ b/libnd4j/include/types/bfloat16.h @@ -16,7 +16,8 @@ /* - Intel bfloat16 data type, based on https://software.intel.com/sites/default/files/managed/40/8b/bf16-hardware-numerics-definition-white-paper.pdf + Intel bfloat16 data type, based on + https://software.intel.com/sites/default/files/managed/40/8b/bf16-hardware-numerics-definition-white-paper.pdf */ @@ -32,7 +33,6 @@ #include #endif - #ifdef __CUDACC__ #define local_def inline __host__ __device__ #elif _MSC_VER @@ -43,212 +43,352 @@ #define local_def inline #endif -//namespace sd -//{ - struct bfloat16 - { - private: - template - struct isNumericType { static bool const value = std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value; }; - // struct isNumericType { static bool const value = std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::value;; }; - - public: - int16_t _data; - - local_def bfloat16() { - _data = 0; - } - - template ::value>::type> - local_def bfloat16(const T& rhs) { - *this = rhs; - } - - local_def operator float() const { - int32_t temp = this->_data << 16; //((sign << 31) | (exponent << 23) | mantissa); - return *reinterpret_cast(&temp); - } - - local_def explicit operator bool() const { - return this->_data == 0 ? false : true; - } - - template ::value>::type> - local_def explicit operator T() const { - return static_cast(static_cast(*this)); - } - - local_def bfloat16& operator=(const bool rhs) { - *this = (float)rhs ? 1.f: 0.f; - return *this; - } - - local_def bfloat16& operator=(const float& rhs) { - #ifdef __CUDACC__ - if(::isnan(rhs)) { - _data = bfloat16::nan(); - return *this; - } - #endif - auto x = *reinterpret_cast(& const_cast(rhs)); - uint32_t lsb = (x >> 16) & 1; - uint32_t rounding_bias = 0x7fff + lsb; - x += rounding_bias; - this->_data = static_cast(x >> 16); - - return *this; - } - - local_def bfloat16& operator=(const bfloat16& rhs) { - _data = rhs._data; - return *this; - } - - template ::value>::type> - local_def bfloat16& operator=(const T& rhs) { - *this = (float)rhs; - return *this; - } - - local_def friend bool operator==(const bfloat16& a, const bfloat16& b) { return (a._data == b._data); } - local_def friend bool operator!=(const bfloat16& a, const bfloat16& b) { return !(a == b); } - local_def friend bool operator<(const bfloat16& a, const bfloat16& b) { return (float)a < (float)b; } - local_def friend bool operator>(const bfloat16& a, const bfloat16& b) { return (float)a > (float)b; } - local_def friend bool operator<=(const bfloat16& a, const bfloat16& b) { return (float)a <= (float)b; } - local_def friend bool operator>=(const bfloat16& a, const bfloat16& b) { return (float)a >= (float)b; } - - local_def friend bfloat16 operator+(const bfloat16& a, const bfloat16& b) { return bfloat16((float)a + (float)b); } - local_def friend bfloat16 operator-(const bfloat16& a, const bfloat16& b) { return bfloat16((float)a - (float)b); } - local_def friend bfloat16 operator*(const bfloat16& a, const bfloat16& b) { return bfloat16((float)a * (float)b); } - local_def friend bfloat16 operator/(const bfloat16& a, const bfloat16& b) { return bfloat16((float)a / (float)b); } - - template ::value>::type> - local_def friend bfloat16 operator+(const bfloat16& a, const T& b) { return a + static_cast(b); } - template ::value>::type> - local_def friend bfloat16 operator+(const T& a, const bfloat16& b) { return static_cast(a) + b; } - - template ::value>::type> - local_def friend bfloat16 operator-(const bfloat16& a, const T& b) { return a - static_cast(b); } - template ::value>::type> - local_def friend bfloat16 operator-(const T& a, const bfloat16& b) { return static_cast(a) - b; } - - template ::value>::type> - local_def friend bfloat16 operator*(const bfloat16& a, const T& b) { return a * static_cast(b); } - template ::value>::type> - local_def friend bfloat16 operator*(const T& a, const bfloat16& b) { return static_cast(a) * b; } - - template ::value>::type> - local_def friend bfloat16 operator/(const bfloat16& a, const T& b) { return a / static_cast(b); } - template ::value>::type> - local_def friend bfloat16 operator/(const T& a, const bfloat16& b) { return static_cast(a) / b; } - - template ::value>::type> - local_def friend bool operator==(const bfloat16& a, const T& b) { return a == static_cast(b); } - template ::value>::type> - local_def friend bool operator==(const T& a, const bfloat16& b) { return static_cast(a) == b; } - - template ::value>::type> - local_def friend bool operator!=(const bfloat16& a, const T& b) { return a != static_cast(b); } - template ::value>::type> - local_def friend bool operator!=(const T& a, const bfloat16& b) { return static_cast(a) != b; } - - template ::value>::type> - local_def friend bool operator<(const bfloat16& a, const T& b) { return a < static_cast(b); } - template ::value>::type> - local_def friend bool operator<(const T& a, const bfloat16& b) { return static_cast(a) < b; } - - template ::value>::type> - local_def friend bool operator>(const bfloat16& a, const T& b) { return a > static_cast(b); } - template ::value>::type> - local_def friend bool operator>(const T& a, const bfloat16& b) { return static_cast(a) > b; } - - template ::value>::type> - local_def friend bool operator<=(const bfloat16& a, const T& b) { return a <= static_cast(b); } - template ::value>::type> - local_def friend bool operator<=(const T& a, const bfloat16& b) { return static_cast(a) <= b; } - - template ::value>::type> - local_def friend bool operator>=(const bfloat16& a, const T& b) { return a >= static_cast(b); } - template ::value>::type> - local_def friend bool operator>=(const T& a, const bfloat16& b) { return static_cast(a) >= b; } - - local_def bfloat16& operator+=(bfloat16 rhs) { *this = (float)(*this) + (float)rhs; return *this; } - - local_def bfloat16& operator-=(bfloat16 rhs) { *this = (float)(*this) - (float)rhs; return *this; } - - local_def bfloat16& operator*=(bfloat16 rhs) { *this = (float)(*this) * (float)rhs; return *this; } - - local_def bfloat16& operator/=(bfloat16 rhs) { *this = (float)(*this) / (float)rhs; return *this; } - - template ::value>::type> - local_def bfloat16& operator+=(const T& rhs) { *this = *this + rhs; return *this; } - - template ::value>::type> - local_def bfloat16& operator-=(const T& rhs) { *this = *this - rhs; return *this; } - - template ::value>::type> - local_def bfloat16& operator*=(const T& rhs) { *this = *this * rhs; return *this; } - - template ::value>::type> - local_def bfloat16& operator/=(const T& rhs) { *this = *this / rhs; return *this; } - - local_def bfloat16& operator++() { *this = (float)*this + (float)1.f; return *this; } - - local_def bfloat16& operator--() { *this = (float)*this - (float)1.f; return *this; } - - local_def bfloat16 operator++(int) { *this = (float)*this + (float)1.f; return *this; } - - local_def bfloat16 operator--(int) { *this = (float)*this - (float)1.f; return *this; } - - local_def bfloat16 operator-() const { - return 0.f - (float)*this; - } - - - - // local_def std::ostream& operator<<(std::ostream& os) { - // os << static_cast(*this); - // return os; - // } - local_def static bfloat16 min() { - bfloat16 res; - res._data = 0xFF7F; - return res; - } - local_def static bfloat16 max() { - bfloat16 res; - res._data = 0x7F7F; - return res; - - } - local_def static bfloat16 eps() { - bfloat16 res; - res._data = 0x3C00; - return res; - } +struct float16; - local_def static bfloat16 inf() { - bfloat16 res; - res._data = 0x3C00; - return res; - } - - local_def static bfloat16 nan() { - bfloat16 res; - res._data = 0x7FC0; - return res; - } +// namespace sd +//{ +struct bfloat16 { + private: + template + struct isNumericType { + static bool const value = + std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value || + std::is_same::value || + std::is_same::value || + std::is_same::value || + std::is_same::value || + std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value; + }; + // struct isNumericType { static bool const value = std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::value;; }; + + public: + int16_t _data; + + local_def bfloat16() { _data = 0; } + + template ::value>::type> + local_def bfloat16(const T& rhs) { + *this = rhs; + } + + local_def operator float() const { + int32_t temp = this->_data + << 16; //((sign << 31) | (exponent << 23) | mantissa); + return *reinterpret_cast(&temp); + } + + local_def explicit operator bool() const { + return this->_data == 0 ? false : true; + } + + template ::value>::type> + local_def explicit operator T() const { + return static_cast(static_cast(*this)); + } + + local_def bfloat16& operator=(const bool rhs) { + *this = (float)rhs ? 1.f : 0.f; + return *this; + } + + local_def bfloat16& operator=(const float& rhs) { +#ifdef __CUDACC__ + if (::isnan(rhs)) { + _data = bfloat16::nan(); + return *this; + } +#endif + auto x = *reinterpret_cast(&const_cast(rhs)); + uint32_t lsb = (x >> 16) & 1; + uint32_t rounding_bias = 0x7fff + lsb; + x += rounding_bias; + this->_data = static_cast(x >> 16); + + return *this; + } + + local_def bfloat16& operator=(const bfloat16& rhs) { + _data = rhs._data; + return *this; + } + + template ::value>::type> + local_def bfloat16& operator=(const T& rhs) { + *this = (float)rhs; + return *this; + } + + local_def friend bool operator==(const bfloat16& a, const bfloat16& b) { + return (a._data == b._data); + } + local_def friend bool operator!=(const bfloat16& a, const bfloat16& b) { + return !(a == b); + } + local_def friend bool operator<(const bfloat16& a, const bfloat16& b) { + return (float)a < (float)b; + } + local_def friend bool operator>(const bfloat16& a, const bfloat16& b) { + return (float)a > (float)b; + } + local_def friend bool operator<=(const bfloat16& a, const bfloat16& b) { + return (float)a <= (float)b; + } + local_def friend bool operator>=(const bfloat16& a, const bfloat16& b) { + return (float)a >= (float)b; + } + + local_def friend bfloat16 operator+(const bfloat16& a, const bfloat16& b) { + return bfloat16((float)a + (float)b); + } + local_def friend bfloat16 operator-(const bfloat16& a, const bfloat16& b) { + return bfloat16((float)a - (float)b); + } + local_def friend bfloat16 operator*(const bfloat16& a, const bfloat16& b) { + return bfloat16((float)a * (float)b); + } + local_def friend bfloat16 operator/(const bfloat16& a, const bfloat16& b) { + return bfloat16((float)a / (float)b); + } + + template ::value>::type> + local_def friend bfloat16 operator+(const bfloat16& a, const T& b) { + return a + static_cast(b); + } + template ::value>::type> + local_def friend bfloat16 operator+(const T& a, const bfloat16& b) { + return static_cast(a) + b; + } + + template ::value>::type> + local_def friend bfloat16 operator-(const bfloat16& a, const T& b) { + return a - static_cast(b); + } + template ::value>::type> + local_def friend bfloat16 operator-(const T& a, const bfloat16& b) { + return static_cast(a) - b; + } + + template ::value>::type> + local_def friend bfloat16 operator*(const bfloat16& a, const T& b) { + return a * static_cast(b); + } + template ::value>::type> + local_def friend bfloat16 operator*(const T& a, const bfloat16& b) { + return static_cast(a) * b; + } + + template ::value>::type> + local_def friend bfloat16 operator/(const bfloat16& a, const T& b) { + return a / static_cast(b); + } + template ::value>::type> + local_def friend bfloat16 operator/(const T& a, const bfloat16& b) { + return static_cast(a) / b; + } + + template ::value>::type> + local_def friend bool operator==(const bfloat16& a, const T& b) { + return a == static_cast(b); + } + template ::value>::type> + local_def friend bool operator==(const T& a, const bfloat16& b) { + return static_cast(a) == b; + } + + template ::value>::type> + local_def friend bool operator!=(const bfloat16& a, const T& b) { + return a != static_cast(b); + } + template ::value>::type> + local_def friend bool operator!=(const T& a, const bfloat16& b) { + return static_cast(a) != b; + } + + template ::value>::type> + local_def friend bool operator<(const bfloat16& a, const T& b) { + return a < static_cast(b); + } + template ::value>::type> + local_def friend bool operator<(const T& a, const bfloat16& b) { + return static_cast(a) < b; + } + + template ::value>::type> + local_def friend bool operator>(const bfloat16& a, const T& b) { + return a > static_cast(b); + } + template ::value>::type> + local_def friend bool operator>(const T& a, const bfloat16& b) { + return static_cast(a) > b; + } + + template ::value>::type> + local_def friend bool operator<=(const bfloat16& a, const T& b) { + return a <= static_cast(b); + } + template ::value>::type> + local_def friend bool operator<=(const T& a, const bfloat16& b) { + return static_cast(a) <= b; + } + + template ::value>::type> + local_def friend bool operator>=(const bfloat16& a, const T& b) { + return a >= static_cast(b); + } + template ::value>::type> + local_def friend bool operator>=(const T& a, const bfloat16& b) { + return static_cast(a) >= b; + } + + local_def bfloat16& operator+=(bfloat16 rhs) { + *this = (float)(*this) + (float)rhs; + return *this; + } + + local_def bfloat16& operator-=(bfloat16 rhs) { + *this = (float)(*this) - (float)rhs; + return *this; + } + + local_def bfloat16& operator*=(bfloat16 rhs) { + *this = (float)(*this) * (float)rhs; + return *this; + } + + local_def bfloat16& operator/=(bfloat16 rhs) { + *this = (float)(*this) / (float)rhs; + return *this; + } + + template ::value>::type> + local_def bfloat16& operator+=(const T& rhs) { + *this = *this + rhs; + return *this; + } + + template ::value>::type> + local_def bfloat16& operator-=(const T& rhs) { + *this = *this - rhs; + return *this; + } + + template ::value>::type> + local_def bfloat16& operator*=(const T& rhs) { + *this = *this * rhs; + return *this; + } + + template ::value>::type> + local_def bfloat16& operator/=(const T& rhs) { + *this = *this / rhs; + return *this; + } + + local_def bfloat16& operator++() { + *this = (float)*this + (float)1.f; + return *this; + } + + local_def bfloat16& operator--() { + *this = (float)*this - (float)1.f; + return *this; + } + + local_def bfloat16 operator++(int) { + *this = (float)*this + (float)1.f; + return *this; + } + + local_def bfloat16 operator--(int) { + *this = (float)*this - (float)1.f; + return *this; + } + + local_def bfloat16 operator-() const { return 0.f - (float)*this; } + + // local_def std::ostream& operator<<(std::ostream& os) { + // os << static_cast(*this); + // return os; + // } + local_def static bfloat16 min() { + bfloat16 res; + res._data = 0xFF7F; + return res; + } + local_def static bfloat16 max() { + bfloat16 res; + res._data = 0x7F7F; + return res; + } + local_def static bfloat16 eps() { + bfloat16 res; + res._data = 0x3C00; + return res; + } + + local_def static bfloat16 inf() { + bfloat16 res; + res._data = 0x3C00; + return res; + } + + local_def static bfloat16 nan() { + bfloat16 res; + res._data = 0x7FC0; + return res; + } }; - - // local_def std::ostream& operator<<(std::ostream &os, const bfloat16 &f) { // os << static_cast(f); // return os; // } - -// local_def bfloat16 /* constexpr */ operator+(const bfloat16& h) { return h; } +// local_def bfloat16 /* constexpr */ operator+(const bfloat16& h) { return h; +// } // local_def bfloat16 operator - (const bfloat16& h) { // auto temp = h._data; @@ -258,8 +398,8 @@ // return t; // } -// WARNING: this implementation only for avoid cyclic references between float16 and bfloat16 types. -// local_def void float16::assign(const bfloat16& rhs) { +// WARNING: this implementation only for avoid cyclic references between float16 +// and bfloat16 types. local_def void float16::assign(const bfloat16& rhs) { // assign((float)rhs); // } diff --git a/libnd4j/include/types/float16.h b/libnd4j/include/types/float16.h index 761e66f1b2b5..3eecf3f32744 100644 --- a/libnd4j/include/types/float16.h +++ b/libnd4j/include/types/float16.h @@ -17,12 +17,13 @@ #ifndef LIBND4J_FLOAT16_H #define LIBND4J_FLOAT16_H +#include + #include #include #include -#include #if defined(__INTEL_COMPILER) || defined(SD_F16C) - #include +#include #endif struct bfloat16; @@ -34,63 +35,50 @@ struct bfloat16; // CUDA_9 and above struct ihalf : public __half { - public: - __host__ __device__ ihalf() : half() { - // - } - - inline __host__ __device__ unsigned short * getXP() { - return &this->__x; - } - - inline __host__ __device__ unsigned short getX() const { - return this->__x; - } - - inline __host__ __device__ void assign(const half f) { - this->__x = ((__half_raw *) &f)->x; - } + public: + __host__ __device__ ihalf() : half() { + // + } + + inline __host__ __device__ unsigned short* getXP() { return &this->__x; } + + inline __host__ __device__ unsigned short getX() const { return this->__x; } + + inline __host__ __device__ void assign(const half f) { + this->__x = ((__half_raw*)&f)->x; + } }; #else struct ihalf : public __half { - public: - __host__ __device__ ihalf() : half() { - // - } - - inline __host__ __device__ unsigned short * getXP() { - return &this->x; - } - - inline __host__ __device__ unsigned short getX() const { - return this->x; - } - - inline __host__ __device__ void assign(const half f) { - this->x = ((__half *) &f)->x; - } + public: + __host__ __device__ ihalf() : half() { + // + } + + inline __host__ __device__ unsigned short* getXP() { return &this->x; } + + inline __host__ __device__ unsigned short getX() const { return this->x; } + + inline __host__ __device__ void assign(const half f) { + this->x = ((__half*)&f)->x; + } }; -#endif // CUDA_8 +#endif // CUDA_8 #else struct __half { -public: - unsigned short x; - inline unsigned short * getXP() { - return &this->x; - } + public: + unsigned short x; + inline unsigned short* getXP() { return &this->x; } - inline unsigned short getX() const { - return this->x; - } + inline unsigned short getX() const { return this->x; } }; typedef __half half; typedef __half ihalf; - -#endif // CUDA +#endif // CUDA #ifdef __CUDACC__ #define local_def inline __host__ __device__ @@ -102,386 +90,545 @@ typedef __half ihalf; #define local_def inline #endif - static local_def int ishnan_(unsigned short h) { - return (h & 0x7c00U) == 0x7c00U && (h & 0x03ffU) != 0; + return (h & 0x7c00U) == 0x7c00U && (h & 0x03ffU) != 0; } static local_def int ishinf_(unsigned short h) { - return (h & 0x7c00U) == 0x7c00U && (h & 0x03ffU) == 0; + return (h & 0x7c00U) == 0x7c00U && (h & 0x03ffU) == 0; } static local_def int ishequ_(unsigned short x, unsigned short y) { - return ishnan_(x) == 0 && ishnan_(y) == 0 && x == y; + return ishnan_(x) == 0 && ishnan_(y) == 0 && x == y; } static local_def unsigned short hneg(unsigned short h) { - h ^= 0x8000U; - return h; + h ^= 0x8000U; + return h; } - #if defined(__INTEL_COMPILER) || defined(SD_F16C) //_Pragma("omp declare simd") inline -local_def float cpu_ihalf2float(ihalf h) { - return _cvtsh_ss(h.getX()); -} +local_def float cpu_ihalf2float(ihalf h) { return _cvtsh_ss(h.getX()); } #else local_def float cpu_ihalf2float(ihalf h) { - unsigned sign = ((h.getX() >> 15) & 1); - unsigned exponent = ((h.getX() >> 10) & 0x1f); - unsigned mantissa = ((h.getX() & 0x3ff) << 13); - - if (exponent == 0x1f) { /* NaN or Inf */ - mantissa = (mantissa ? (sign = 0, 0x7fffff) : 0); - exponent = 0xff; - } else if (!exponent) { /* Denorm or Zero */ - if (mantissa) { - unsigned int msb; - exponent = 0x71; - do { - msb = (mantissa & 0x400000); - mantissa <<= 1; /* normalize */ - --exponent; - } while (!msb); - mantissa &= 0x7fffff; /* 1.mantissa is implicit */ - } - } else { - exponent += 0x70; + unsigned sign = ((h.getX() >> 15) & 1); + unsigned exponent = ((h.getX() >> 10) & 0x1f); + unsigned mantissa = ((h.getX() & 0x3ff) << 13); + + if (exponent == 0x1f) { /* NaN or Inf */ + mantissa = (mantissa ? (sign = 0, 0x7fffff) : 0); + exponent = 0xff; + } else if (!exponent) { /* Denorm or Zero */ + if (mantissa) { + unsigned int msb; + exponent = 0x71; + do { + msb = (mantissa & 0x400000); + mantissa <<= 1; /* normalize */ + --exponent; + } while (!msb); + mantissa &= 0x7fffff; /* 1.mantissa is implicit */ } + } else { + exponent += 0x70; + } - int temp = ((sign << 31) | (exponent << 23) | mantissa); + int temp = ((sign << 31) | (exponent << 23) | mantissa); - return *((float*)((void*)&temp)); + return *((float*)((void*)&temp)); } #endif #if defined(__INTEL_COMPILER) || defined(SD_F16C) //_Pragma("omp declare simd") inline local_def ihalf cpu_float2ihalf_rn(float f) { - ihalf ret; - ret.x = _cvtss_sh(f, 0); - return ret; + ihalf ret; + ret.x = _cvtss_sh(f, 0); + return ret; } #else -local_def ihalf cpu_float2ihalf_rn(float f) -{ - ihalf ret; - - unsigned x = *((int*)(void*)(&f)); - unsigned u = (x & 0x7fffffff), remainder, shift, lsb, lsb_s1, lsb_m1; - unsigned sign, exponent, mantissa; - - // Get rid of +NaN/-NaN case first. - if (u > 0x7f800000) { - *ret.getXP() = 0x7fffU; - return ret; - } +local_def ihalf cpu_float2ihalf_rn(float f) { + ihalf ret; - sign = ((x >> 16) & 0x8000); + unsigned x = *((int*)(void*)(&f)); + unsigned u = (x & 0x7fffffff), remainder, shift, lsb, lsb_s1, lsb_m1; + unsigned sign, exponent, mantissa; - // Get rid of +Inf/-Inf, +0/-0. - if (u > 0x477fefff) { - *ret.getXP() = sign | 0x7c00U; - return ret; - } - if (u < 0x33000001) { - *ret.getXP() = (sign | 0x0000); - return ret; - } + // Get rid of +NaN/-NaN case first. + if (u > 0x7f800000) { + *ret.getXP() = 0x7fffU; + return ret; + } - exponent = ((u >> 23) & 0xff); - mantissa = (u & 0x7fffff); + sign = ((x >> 16) & 0x8000); - if (exponent > 0x70) { - shift = 13; - exponent -= 0x70; - } else { - shift = 0x7e - exponent; - exponent = 0; - mantissa |= 0x800000; - } - lsb = (1 << shift); - lsb_s1 = (lsb >> 1); - lsb_m1 = (lsb - 1); - - // Round to nearest even. - remainder = (mantissa & lsb_m1); - mantissa >>= shift; - if (remainder > lsb_s1 || (remainder == lsb_s1 && (mantissa & 0x1))) { - ++mantissa; - if (!(mantissa & 0x3ff)) { - ++exponent; - mantissa = 0; - } + // Get rid of +Inf/-Inf, +0/-0. + if (u > 0x477fefff) { + *ret.getXP() = sign | 0x7c00U; + return ret; + } + if (u < 0x33000001) { + *ret.getXP() = (sign | 0x0000); + return ret; + } + + exponent = ((u >> 23) & 0xff); + mantissa = (u & 0x7fffff); + + if (exponent > 0x70) { + shift = 13; + exponent -= 0x70; + } else { + shift = 0x7e - exponent; + exponent = 0; + mantissa |= 0x800000; + } + lsb = (1 << shift); + lsb_s1 = (lsb >> 1); + lsb_m1 = (lsb - 1); + + // Round to nearest even. + remainder = (mantissa & lsb_m1); + mantissa >>= shift; + if (remainder > lsb_s1 || (remainder == lsb_s1 && (mantissa & 0x1))) { + ++mantissa; + if (!(mantissa & 0x3ff)) { + ++exponent; + mantissa = 0; } + } - *ret.getXP() = (sign | (exponent << 10) | mantissa); + *ret.getXP() = (sign | (exponent << 10) | mantissa); - return ret; + return ret; } #endif struct float16 { + private: + template + struct isNumericType { + static bool const value = + std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value || + std::is_same::value || + std::is_same::value || + std::is_same::value || + std::is_same::value || + std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value || + std::is_same::value; + }; // || std::is_same::value; }; + // struct isNumericType { static bool const value = std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::value;; }; + + public: + ihalf data; + local_def float16() { *data.getXP() = 0; } + + template ::value || + std::is_same::value>::type> + local_def float16(const T& rhs) { + *this = rhs; + } + + local_def float16(const half& rhs) { +#ifdef __CUDACC__ + data.assign(rhs); +#endif + } + + local_def operator float() const { +#ifdef __CUDA_ARCH__ + return __half2float(data); +#else + return cpu_ihalf2float(data); +#endif + } + + local_def explicit operator bool() const { + return static_cast(*this) != 0.0f; + } + + local_def explicit operator half() const { return data; } + + template ::value>::type> + local_def explicit operator T() const { + return static_cast(static_cast(*this)); + } - private: - template - struct isNumericType { static bool const value = std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value; };// || std::is_same::value; }; - // struct isNumericType { static bool const value = std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::value;; }; - - public: - ihalf data; - local_def float16() { *data.getXP() = 0; } - - template ::value || std::is_same::value>::type> - local_def float16(const T& rhs) { - *this = rhs; - } - - local_def float16(const half& rhs) { - #ifdef __CUDACC__ - data.assign(rhs); - #endif - } - - local_def operator float() const { - #ifdef __CUDA_ARCH__ - return __half2float(data); - #else - return cpu_ihalf2float(data); - #endif - } - - local_def explicit operator bool() const { - return static_cast(*this) != 0.0f; - } - - local_def explicit operator half() const { - return data; - } - - template ::value>::type> - local_def explicit operator T() const { - return static_cast(static_cast(*this)); - } - - local_def float16& operator=(const float& rhs) { - #ifdef __CUDA_ARCH__ - auto t = __float2half_rn(rhs); - auto b = *(data.getXP()); - - #ifdef CUDA_8 - *(data.getXP()) = t; - #else - data.assign(t); - #endif - - #else - data = cpu_float2ihalf_rn(rhs); - #endif - - return *this; - } - - local_def float16& operator=(const unsigned short rhs) { - *data.getXP() = rhs; - return *this; - } - - local_def float16& operator=(const bool rhs) { - *this = (float)rhs ? 1.f: 0.f; - return *this; - } - - local_def float16& operator=(const ihalf& rhs) { - *data.getXP() = ((ihalf) rhs).getX(); - return *this; - } - - #ifdef __CUDACC__ - local_def float16& operator=(const half& rhs) { - data.assign(rhs); - return *this; - } - #endif - - local_def float16& operator=(const float16& rhs) { - data = rhs.data; - return *this; - } - - template ::value || std::is_same::value>::type> - local_def float16& operator=(const T& rhs) { - *this = (float)rhs; - return *this; - } - - #ifdef NATIVE_HALFS - local_def friend bool operator==(const float16& a, const float16& b) { return __hequ(a.data, b.data); } - #else - local_def friend bool operator==(const float16& a, const float16& b) { return ishequ_(((ihalf) a.data).getX(), ((ihalf)b.data).getX()); } - #endif - - #ifdef NATIVE_HALFS - local_def friend bool operator!=(const float16& a, const float16& b) { return !(__hequ(a.data, b.data)); } - #else - local_def friend bool operator!=(const float16& a, const float16& b) { return !(a == b); } - #endif - - #ifdef NATIVE_HALFS - local_def friend bool operator<(const float16& a, const float16& b) { return __hlt(a.data, b.data); } - #else - local_def friend bool operator<(const float16& a, const float16& b) { return (float)a < (float)b; } - #endif - - #ifdef NATIVE_HALFS - local_def friend bool operator>(const float16& a, const float16& b) { return __hgt(a.data, b.data); } - #else - local_def friend bool operator>(const float16& a, const float16& b) { return (float)a > (float)b; } - #endif - - #ifdef NATIVE_HALFS - local_def friend bool operator<=(const float16& a, const float16& b) { return __hle(a.data, b.data); } - #else - local_def friend bool operator<=(const float16& a, const float16& b) { return (float)a <= (float)b; } - #endif - - #ifdef NATIVE_HALFS - local_def friend bool operator>=(const float16& a, const float16& b) { return __hge(a.data, b.data); } - #else - local_def friend bool operator>=(const float16& a, const float16& b) { return (float)a >= (float)b; } - #endif - - #ifdef NATIVE_HALFS - local_def friend float16 operator+(const float16& a, const float16& b) { return __hadd(a.data, b.data); } - - local_def friend float16 operator-(const float16& a, const float16& b) { return __hsub(a.data, b.data); } - - local_def friend float16 operator*(const float16& a, const float16& b) { return __hmul(a.data, b.data); } - - local_def friend float16 operator/(const float16& a, const float16& b) { - #ifdef CUDA_8 - return hdiv(a.data, b.data); - #else - return __hdiv(a.data, b.data); - #endif - } - #else - local_def friend float16 operator+(const float16& a, const float16& b) { return float16((float)a + (float)b); } - local_def friend float16 operator-(const float16& a, const float16& b) { return float16((float)a - (float)b); } - local_def friend float16 operator*(const float16& a, const float16& b) { return float16((float)a * (float)b); } - local_def friend float16 operator/(const float16& a, const float16& b) { return float16((float)a / (float)b); } - #endif - - template ::value>::type> - local_def friend float16 operator+(const float16& a, const T& b) { return a + static_cast(b); } - template ::value>::type> - local_def friend float16 operator+(const T& a, const float16& b) { return static_cast(a) + b; } - - template ::value>::type> - local_def friend float16 operator-(const float16& a, const T& b) { return a - static_cast(b); } - template ::value>::type> - local_def friend float16 operator-(const T& a, const float16& b) { return static_cast(a) - b; } - - template ::value>::type> - local_def friend float16 operator*(const float16& a, const T& b) { return a * static_cast(b); } - template ::value>::type> - local_def friend float16 operator*(const T& a, const float16& b) { return static_cast(a) * b; } - - template ::value>::type> - local_def friend float16 operator/(const float16& a, const T& b) { return a / static_cast(b); } - template ::value>::type> - local_def friend float16 operator/(const T& a, const float16& b) { return static_cast(a) / b; } - - template ::value>::type> - local_def friend bool operator==(const float16& a, const T& b) { return a == static_cast(b); } - template ::value>::type> - local_def friend bool operator==(const T& a, const float16& b) { return static_cast(a) == b; } - - template ::value>::type> - local_def friend bool operator!=(const float16& a, const T& b) { return a != static_cast(b); } - template ::value>::type> - local_def friend bool operator!=(const T& a, const float16& b) { return static_cast(a) != b; } - - template ::value>::type> - local_def friend bool operator<(const float16& a, const T& b) { return a < static_cast(b); } - template ::value>::type> - local_def friend bool operator<(const T& a, const float16& b) { return static_cast(a) < b; } - - template ::value>::type> - local_def friend bool operator>(const float16& a, const T& b) { return a > static_cast(b); } - template ::value>::type> - local_def friend bool operator>(const T& a, const float16& b) { return static_cast(a) > b; } - - template ::value>::type> - local_def friend bool operator<=(const float16& a, const T& b) { return a <= static_cast(b); } - template ::value>::type> - local_def friend bool operator<=(const T& a, const float16& b) { return static_cast(a) <= b; } - - template ::value>::type> - local_def friend bool operator>=(const float16& a, const T& b) { return a >= static_cast(b); } - template ::value>::type> - local_def friend bool operator>=(const T& a, const float16& b) { return static_cast(a) >= b; } - - local_def float16& operator+=(float16 rhs) { *this = (float)*this + (float)rhs; return *this; } - - local_def float16& operator-=(float16 rhs) { *this = (float)*this - (float)rhs; return *this; } - - local_def float16& operator*=(float16 rhs) { *this = (float)*this * (float)rhs; return *this; } - - local_def float16& operator/=(float16 rhs) { *this = (float)*this / (float)rhs; return *this; } - - template ::value>::type> - local_def float16& operator+=(const T& rhs) { *this = *this + rhs; return *this; } + local_def float16& operator=(const float& rhs) { +#ifdef __CUDA_ARCH__ + auto t = __float2half_rn(rhs); + auto b = *(data.getXP()); - template ::value>::type> - local_def float16& operator-=(const T& rhs) { *this = *this - rhs; return *this; } +#ifdef CUDA_8 + *(data.getXP()) = t; +#else + data.assign(t); +#endif - template ::value>::type> - local_def float16& operator*=(const T& rhs) { *this = *this * rhs; return *this; } +#else + data = cpu_float2ihalf_rn(rhs); +#endif - template ::value>::type> - local_def float16& operator/=(const T& rhs) { *this = *this / rhs; return *this; } + return *this; + } - local_def float16& operator++() { *this = *this + (float16)1.f; return *this; } + local_def float16& operator=(const unsigned short rhs) { + *data.getXP() = rhs; + return *this; + } - local_def float16& operator--() { *this = *this - (float16)1.f; return *this; } + local_def float16& operator=(const bool rhs) { + *this = (float)rhs ? 1.f : 0.f; + return *this; + } - local_def float16 operator++(int) { *this = *this + (float16)1.f; return *this; } + local_def float16& operator=(const ihalf& rhs) { + *data.getXP() = ((ihalf)rhs).getX(); + return *this; + } - local_def float16 operator--(int) { *this = *this - (float16)1.f; return *this; } +#ifdef __CUDACC__ + local_def float16& operator=(const half& rhs) { + data.assign(rhs); + return *this; + } +#endif - local_def float16 operator-() const { - return 0.f - (float)*this; - } + local_def float16& operator=(const float16& rhs) { + data = rhs.data; + return *this; + } + + template ::value || + std::is_same::value>::type> + local_def float16& operator=(const T& rhs) { + *this = (float)rhs; + return *this; + } + +#ifdef NATIVE_HALFS + local_def friend bool operator==(const float16& a, const float16& b) { + return __hequ(a.data, b.data); + } +#else + local_def friend bool operator==(const float16& a, const float16& b) { + return ishequ_(((ihalf)a.data).getX(), ((ihalf)b.data).getX()); + } +#endif - // local_def std::ostream& operator<<(std::ostream& os) { - // os << static_cast(*this); - // return os; - // } -}; +#ifdef NATIVE_HALFS + local_def friend bool operator!=(const float16& a, const float16& b) { + return !(__hequ(a.data, b.data)); + } +#else + local_def friend bool operator!=(const float16& a, const float16& b) { + return !(a == b); + } +#endif +#ifdef NATIVE_HALFS + local_def friend bool operator<(const float16& a, const float16& b) { + return __hlt(a.data, b.data); + } +#else + local_def friend bool operator<(const float16& a, const float16& b) { + return (float)a < (float)b; + } +#endif +#ifdef NATIVE_HALFS + local_def friend bool operator>(const float16& a, const float16& b) { + return __hgt(a.data, b.data); + } +#else + local_def friend bool operator>(const float16& a, const float16& b) { + return (float)a > (float)b; + } +#endif - // local_def std::ostream& operator<<(std::ostream &os, const float16 &f) { - // os << static_cast(f); - // return os; - // } +#ifdef NATIVE_HALFS + local_def friend bool operator<=(const float16& a, const float16& b) { + return __hle(a.data, b.data); + } +#else + local_def friend bool operator<=(const float16& a, const float16& b) { + return (float)a <= (float)b; + } +#endif - // local_def float16 operator+(const float16& h) { return h; } +#ifdef NATIVE_HALFS + local_def friend bool operator>=(const float16& a, const float16& b) { + return __hge(a.data, b.data); + } +#else + local_def friend bool operator>=(const float16& a, const float16& b) { + return (float)a >= (float)b; + } +#endif + +#ifdef NATIVE_HALFS + local_def friend float16 operator+(const float16& a, const float16& b) { + return __hadd(a.data, b.data); + } + + local_def friend float16 operator-(const float16& a, const float16& b) { + return __hsub(a.data, b.data); + } + + local_def friend float16 operator*(const float16& a, const float16& b) { + return __hmul(a.data, b.data); + } + + local_def friend float16 operator/(const float16& a, const float16& b) { +#ifdef CUDA_8 + return hdiv(a.data, b.data); +#else + return __hdiv(a.data, b.data); +#endif + } +#else + local_def friend float16 operator+(const float16& a, const float16& b) { + return float16((float)a + (float)b); + } + local_def friend float16 operator-(const float16& a, const float16& b) { + return float16((float)a - (float)b); + } + local_def friend float16 operator*(const float16& a, const float16& b) { + return float16((float)a * (float)b); + } + local_def friend float16 operator/(const float16& a, const float16& b) { + return float16((float)a / (float)b); + } +#endif + + template ::value>::type> + local_def friend float16 operator+(const float16& a, const T& b) { + return a + static_cast(b); + } + template ::value>::type> + local_def friend float16 operator+(const T& a, const float16& b) { + return static_cast(a) + b; + } + + template ::value>::type> + local_def friend float16 operator-(const float16& a, const T& b) { + return a - static_cast(b); + } + template ::value>::type> + local_def friend float16 operator-(const T& a, const float16& b) { + return static_cast(a) - b; + } + + template ::value>::type> + local_def friend float16 operator*(const float16& a, const T& b) { + return a * static_cast(b); + } + template ::value>::type> + local_def friend float16 operator*(const T& a, const float16& b) { + return static_cast(a) * b; + } + + template ::value>::type> + local_def friend float16 operator/(const float16& a, const T& b) { + return a / static_cast(b); + } + template ::value>::type> + local_def friend float16 operator/(const T& a, const float16& b) { + return static_cast(a) / b; + } + + template ::value>::type> + local_def friend bool operator==(const float16& a, const T& b) { + return a == static_cast(b); + } + template ::value>::type> + local_def friend bool operator==(const T& a, const float16& b) { + return static_cast(a) == b; + } + + template ::value>::type> + local_def friend bool operator!=(const float16& a, const T& b) { + return a != static_cast(b); + } + template ::value>::type> + local_def friend bool operator!=(const T& a, const float16& b) { + return static_cast(a) != b; + } + + template ::value>::type> + local_def friend bool operator<(const float16& a, const T& b) { + return a < static_cast(b); + } + template ::value>::type> + local_def friend bool operator<(const T& a, const float16& b) { + return static_cast(a) < b; + } + + template ::value>::type> + local_def friend bool operator>(const float16& a, const T& b) { + return a > static_cast(b); + } + template ::value>::type> + local_def friend bool operator>(const T& a, const float16& b) { + return static_cast(a) > b; + } + + template ::value>::type> + local_def friend bool operator<=(const float16& a, const T& b) { + return a <= static_cast(b); + } + template ::value>::type> + local_def friend bool operator<=(const T& a, const float16& b) { + return static_cast(a) <= b; + } + + template ::value>::type> + local_def friend bool operator>=(const float16& a, const T& b) { + return a >= static_cast(b); + } + template ::value>::type> + local_def friend bool operator>=(const T& a, const float16& b) { + return static_cast(a) >= b; + } + + local_def float16& operator+=(float16 rhs) { + *this = (float)*this + (float)rhs; + return *this; + } + + local_def float16& operator-=(float16 rhs) { + *this = (float)*this - (float)rhs; + return *this; + } + + local_def float16& operator*=(float16 rhs) { + *this = (float)*this * (float)rhs; + return *this; + } + + local_def float16& operator/=(float16 rhs) { + *this = (float)*this / (float)rhs; + return *this; + } + + template ::value>::type> + local_def float16& operator+=(const T& rhs) { + *this = *this + rhs; + return *this; + } + + template ::value>::type> + local_def float16& operator-=(const T& rhs) { + *this = *this - rhs; + return *this; + } + + template ::value>::type> + local_def float16& operator*=(const T& rhs) { + *this = *this * rhs; + return *this; + } + + template ::value>::type> + local_def float16& operator/=(const T& rhs) { + *this = *this / rhs; + return *this; + } + + local_def float16& operator++() { + *this = *this + (float16)1.f; + return *this; + } + + local_def float16& operator--() { + *this = *this - (float16)1.f; + return *this; + } + + local_def float16 operator++(int) { + *this = *this + (float16)1.f; + return *this; + } + + local_def float16 operator--(int) { + *this = *this - (float16)1.f; + return *this; + } + + local_def float16 operator-() const { return 0.f - (float)*this; } + + // local_def std::ostream& operator<<(std::ostream& os) { + // os << static_cast(*this); + // return os; + // } +}; - // local_def float16 operator - (const float16& h) { - // const ihalf * tmp = &h.data; - // return float16(hneg(tmp->getX())); - // } +// local_def std::ostream& operator<<(std::ostream &os, const float16 &f) { +// os << static_cast(f); +// return os; +// } + +// local_def float16 operator+(const float16& h) { return h; } + +// local_def float16 operator - (const float16& h) { +// const ihalf * tmp = &h.data; +// return float16(hneg(tmp->getX())); +// } #ifdef __CUDACC__ - local_def int isnan(const float16& h) { return ishnan_(((ihalf)h.data).getX()); } +local_def int isnan(const float16& h) { + return ishnan_(((ihalf)h.data).getX()); +} - local_def int isinf(const float16& h) { return ishinf_(((ihalf)h.data).getX()); } +local_def int isinf(const float16& h) { + return ishinf_(((ihalf)h.data).getX()); +} #endif - // std::ostream& operator << (std::ostream& s, const float16&); +// std::ostream& operator << (std::ostream& s, const float16&); #endif diff --git a/libnd4j/include/types/float8.h b/libnd4j/include/types/float8.h index 6dc03bba45fc..4060bb80ad27 100644 --- a/libnd4j/include/types/float8.h +++ b/libnd4j/include/types/float8.h @@ -35,151 +35,137 @@ #include - namespace sd { - typedef struct { - unsigned char x; - } __quarter; - - typedef __quarter quarter; - - quarter _CUDA_HD FORCEINLINE cpu_float2quarter_rn(float f); - float _CUDA_HD FORCEINLINE cpu_quarter2float(quarter b); - - struct float8 { - quarter data; +typedef struct { + unsigned char x; +} __quarter; - _CUDA_HD FORCEINLINE float8(); +typedef __quarter quarter; - template - _CUDA_HD FORCEINLINE float8(const T& rhs); +quarter _CUDA_HD FORCEINLINE cpu_float2quarter_rn(float f); +float _CUDA_HD FORCEINLINE cpu_quarter2float(quarter b); - template - _CUDA_HD FORCEINLINE float8& operator=(const T& rhs); +struct float8 { + quarter data; - _CUDA_HD FORCEINLINE operator float() const; + _CUDA_HD FORCEINLINE float8(); - _CUDA_HD FORCEINLINE void assign(double rhs); + template + _CUDA_HD FORCEINLINE float8(const T& rhs); - _CUDA_HD FORCEINLINE void assign(float rhs); - }; + template + _CUDA_HD FORCEINLINE float8& operator=(const T& rhs); + _CUDA_HD FORCEINLINE operator float() const; - float cpu_quarter2float(quarter b) { - unsigned sign = ((b.x >> 7) & 1); - unsigned exponent = ((b.x >> 4) & 0x7); - unsigned mantissa = ((b.x & 0xf) << 19); + _CUDA_HD FORCEINLINE void assign(double rhs); - if (exponent == 0x7) { /* NaN or Inf */ - mantissa = (mantissa ? (sign = 0, 0x7fffff) : 0); - exponent = 0xff; - } else if (!exponent) { /* Denorm or Zero */ - if (mantissa) { - unsigned int msb; - exponent = 0x7d; - do { - msb = (mantissa & 0x400000); - mantissa <<= 1; /* normalize */ - --exponent; - } while (!msb); - mantissa &= 0x7fffff; /* 1.mantissa is implicit */ - } - } else { - exponent += 0x7C; - } + _CUDA_HD FORCEINLINE void assign(float rhs); +}; - int temp = ((sign << 31) | (exponent << 23) | mantissa); +float cpu_quarter2float(quarter b) { + unsigned sign = ((b.x >> 7) & 1); + unsigned exponent = ((b.x >> 4) & 0x7); + unsigned mantissa = ((b.x & 0xf) << 19); - return *((float*)((void*)&temp)); + if (exponent == 0x7) { /* NaN or Inf */ + mantissa = (mantissa ? (sign = 0, 0x7fffff) : 0); + exponent = 0xff; + } else if (!exponent) { /* Denorm or Zero */ + if (mantissa) { + unsigned int msb; + exponent = 0x7d; + do { + msb = (mantissa & 0x400000); + mantissa <<= 1; /* normalize */ + --exponent; + } while (!msb); + mantissa &= 0x7fffff; /* 1.mantissa is implicit */ } + } else { + exponent += 0x7C; + } + int temp = ((sign << 31) | (exponent << 23) | mantissa); + return *((float*)((void*)&temp)); +} - quarter cpu_float2quarter_rn(float f) - { - quarter ret; - - unsigned x = *((int*)(void*)(&f)); - unsigned u = (x & 0x7fffffff), remainder, shift, lsb, lsb_s1, lsb_m1; - unsigned sign, exponent, mantissa; - - // Get rid of +NaN/-NaN case first. - if (u > 0x7f800000) { - ret.x = 0x7fU; - return ret; - } - - sign = ((x >> 24) & 0x80); - - // Get rid of +Inf/-Inf, +0/-0. - if (u > 0x477fefff) { - ret.x = sign | 0x70U; - return ret; - } - if (u < 0x33000001) { - ret.x = (sign | 0x00); - return ret; - } - - exponent = ((u >> 23) & 0xff); - mantissa = (u & 0x7fffff); - - if (exponent > 0x7C) { - shift = 19; - exponent -= 0x7C; - } else { - shift = 0x90 - exponent; - exponent = 0; - mantissa |= 0x800000; - } - lsb = (1 << shift); - lsb_s1 = (lsb >> 1); - lsb_m1 = (lsb - 1); - - // Round to nearest even. - remainder = (mantissa & lsb_m1); - mantissa >>= shift; - if (remainder > lsb_s1 || (remainder == lsb_s1 && (mantissa & 0x1))) { - ++mantissa; - if (!(mantissa & 0xf)) { - ++exponent; - mantissa = 0; - } - } - - ret.x = (sign | (exponent << 4) | mantissa); - - return ret; +quarter cpu_float2quarter_rn(float f) { + quarter ret; + + unsigned x = *((int*)(void*)(&f)); + unsigned u = (x & 0x7fffffff), remainder, shift, lsb, lsb_s1, lsb_m1; + unsigned sign, exponent, mantissa; + + // Get rid of +NaN/-NaN case first. + if (u > 0x7f800000) { + ret.x = 0x7fU; + return ret; + } + + sign = ((x >> 24) & 0x80); + + // Get rid of +Inf/-Inf, +0/-0. + if (u > 0x477fefff) { + ret.x = sign | 0x70U; + return ret; + } + if (u < 0x33000001) { + ret.x = (sign | 0x00); + return ret; + } + + exponent = ((u >> 23) & 0xff); + mantissa = (u & 0x7fffff); + + if (exponent > 0x7C) { + shift = 19; + exponent -= 0x7C; + } else { + shift = 0x90 - exponent; + exponent = 0; + mantissa |= 0x800000; + } + lsb = (1 << shift); + lsb_s1 = (lsb >> 1); + lsb_m1 = (lsb - 1); + + // Round to nearest even. + remainder = (mantissa & lsb_m1); + mantissa >>= shift; + if (remainder > lsb_s1 || (remainder == lsb_s1 && (mantissa & 0x1))) { + ++mantissa; + if (!(mantissa & 0xf)) { + ++exponent; + mantissa = 0; } + } + ret.x = (sign | (exponent << 4) | mantissa); - float8::float8() { - data = cpu_float2quarter_rn(0.0f); - } + return ret; +} - template - float8::float8(const T& rhs) { - assign(rhs); - } +float8::float8() { data = cpu_float2quarter_rn(0.0f); } - template - float8& float8::operator=(const T& rhs) { - assign(rhs); return *this; - } +template +float8::float8(const T& rhs) { + assign(rhs); +} +template +float8& float8::operator=(const T& rhs) { + assign(rhs); + return *this; +} - float8::operator float() const { - return cpu_quarter2float(data); - } +float8::operator float() const { return cpu_quarter2float(data); } - void float8::assign(double rhs) { - assign((float)rhs); - } +void float8::assign(double rhs) { assign((float)rhs); } - void float8::assign(float rhs) { - data = cpu_float2quarter_rn(rhs); - } -} +void float8::assign(float rhs) { data = cpu_float2quarter_rn(rhs); } +} // namespace sd -#endif //LIBND4J_FLOAT8_H +#endif // LIBND4J_FLOAT8_H diff --git a/libnd4j/include/types/impl/int16.cpp b/libnd4j/include/types/impl/int16.cpp index 67f90e9d8724..7ab8290557e4 100644 --- a/libnd4j/include/types/impl/int16.cpp +++ b/libnd4j/include/types/impl/int16.cpp @@ -22,11 +22,11 @@ namespace sd { - /* - template int16::int16(const float& rhs); - template int16::int16(const double& rhs); +/* +template int16::int16(const float& rhs); +template int16::int16(const double& rhs); - template int16& int16::operator=(const float& rhs); - template int16& int16::operator=(const double& rhs); - */ +template int16& int16::operator=(const float& rhs); +template int16& int16::operator=(const double& rhs); + */ } \ No newline at end of file diff --git a/libnd4j/include/types/impl/pair.cpp b/libnd4j/include/types/impl/pair.cpp index 767bfa63028f..b12fd767cdfe 100644 --- a/libnd4j/include/types/impl/pair.cpp +++ b/libnd4j/include/types/impl/pair.cpp @@ -21,16 +21,12 @@ #include namespace sd { - Pair::Pair(int first, int second) { - _first = first; - _second = second; - } +Pair::Pair(int first, int second) { + _first = first; + _second = second; +} - int Pair::first() const { - return _first; - } +int Pair::first() const { return _first; } - int Pair::second() const { - return _second; - }; -} +int Pair::second() const { return _second; }; +} // namespace sd diff --git a/libnd4j/include/types/impl/triple.cpp b/libnd4j/include/types/impl/triple.cpp index 0b39d4bac15e..9e03c24cd5b5 100644 --- a/libnd4j/include/types/impl/triple.cpp +++ b/libnd4j/include/types/impl/triple.cpp @@ -21,21 +21,15 @@ #include namespace sd { - int Triple::first() const { - return _first; - } +int Triple::first() const { return _first; } - int Triple::second() const { - return _second; - } +int Triple::second() const { return _second; } - int Triple::third() const { - return _third; - } +int Triple::third() const { return _third; } - Triple::Triple(int first, int second, int third) { - _first = first; - _second = second; - _third = third; - } +Triple::Triple(int first, int second, int third) { + _first = first; + _second = second; + _third = third; } +} // namespace sd diff --git a/libnd4j/include/types/impl/uint16.cpp b/libnd4j/include/types/impl/uint16.cpp index 5b858222da3b..234fcaa4a815 100644 --- a/libnd4j/include/types/impl/uint16.cpp +++ b/libnd4j/include/types/impl/uint16.cpp @@ -23,12 +23,11 @@ namespace sd { +/* +template uint16::uint16(const float& rhs); +template uint16::uint16(const double& rhs); - /* - template uint16::uint16(const float& rhs); - template uint16::uint16(const double& rhs); - - template uint16& uint16::operator=(const double& rhs); - template uint16& uint16::operator=(const float& rhs); - */ +template uint16& uint16::operator=(const double& rhs); +template uint16& uint16::operator=(const float& rhs); + */ } \ No newline at end of file diff --git a/libnd4j/include/types/impl/uint8.cpp b/libnd4j/include/types/impl/uint8.cpp index a6d25c9d3780..0ab6861ca59a 100644 --- a/libnd4j/include/types/impl/uint8.cpp +++ b/libnd4j/include/types/impl/uint8.cpp @@ -22,11 +22,11 @@ namespace sd { - /* - template uint8::uint8(const float& rhs); - template uint8::uint8(const double& rhs); +/* +template uint8::uint8(const float& rhs); +template uint8::uint8(const double& rhs); - template uint8& uint8::operator=(const float& rhs); - template uint8& uint8::operator=(const double& rhs); - */ +template uint8& uint8::operator=(const float& rhs); +template uint8& uint8::operator=(const double& rhs); + */ } \ No newline at end of file diff --git a/libnd4j/include/types/impl/utf8string.cpp b/libnd4j/include/types/impl/utf8string.cpp index a7df7cc28ddd..7cc6dc8a20ea 100644 --- a/libnd4j/include/types/impl/utf8string.cpp +++ b/libnd4j/include/types/impl/utf8string.cpp @@ -19,57 +19,57 @@ // #include + #include namespace sd { - utf8string::~utf8string() { - if (_allocated) - delete[] _buffer; - } +utf8string::~utf8string() { + if (_allocated) delete[] _buffer; +} - utf8string::utf8string() { - _allocated = false; - _length = 0; - _buffer = nullptr; - } +utf8string::utf8string() { + _allocated = false; + _length = 0; + _buffer = nullptr; +} - utf8string::utf8string(const char *string, int length) { - _length = length; - _buffer = new char[_length]; - _allocated = true; - std::memset(_buffer, 0, _length + 1); - std::memcpy(_buffer, string, _length); - } +utf8string::utf8string(const char *string, int length) { + _length = length; + _buffer = new char[_length]; + _allocated = true; + std::memset(_buffer, 0, _length + 1); + std::memcpy(_buffer, string, _length); +} - utf8string::utf8string(const std::string &str) { - _length = str.length(); - _buffer = new char[_length + 1]; - _allocated = true; - std::memset(_buffer, 0, _length + 1); - std::memcpy(_buffer, str.data(), _length); - _buffer[_length] = 0; - } +utf8string::utf8string(const std::string &str) { + _length = str.length(); + _buffer = new char[_length + 1]; + _allocated = true; + std::memset(_buffer, 0, _length + 1); + std::memcpy(_buffer, str.data(), _length); + _buffer[_length] = 0; +} - utf8string::utf8string(const utf8string &other) { - _length = other._length; - _buffer = new char[_length+1]; - _allocated = true; - std::memset(_buffer, 0, _length + 1); - std::memcpy(_buffer, other._buffer, _length); - _buffer[_length] = 0; - } +utf8string::utf8string(const utf8string &other) { + _length = other._length; + _buffer = new char[_length + 1]; + _allocated = true; + std::memset(_buffer, 0, _length + 1); + std::memcpy(_buffer, other._buffer, _length); + _buffer[_length] = 0; +} - void utf8string::Swap(utf8string &other) { - std::swap(_length, other._length); - std::swap(_buffer, other._buffer);// = new char[_length+1]; - std::swap(_allocated, other._allocated); // = true; - } +void utf8string::Swap(utf8string &other) { + std::swap(_length, other._length); + std::swap(_buffer, other._buffer); // = new char[_length+1]; + std::swap(_allocated, other._allocated); // = true; +} - utf8string& utf8string::operator=(const utf8string &other) { - if (this != &other) { - utf8string temp(other); - Swap(temp); - } - return *this; - } +utf8string &utf8string::operator=(const utf8string &other) { + if (this != &other) { + utf8string temp(other); + Swap(temp); + } + return *this; } +} // namespace sd diff --git a/libnd4j/include/types/int16.h b/libnd4j/include/types/int16.h index 25a77138151d..ab5ac4eb75a3 100644 --- a/libnd4j/include/types/int16.h +++ b/libnd4j/include/types/int16.h @@ -24,75 +24,61 @@ #include #include - namespace sd { - float _CUDA_HD FORCEINLINE cpu_int162float(int16_t data); - int16_t _CUDA_HD FORCEINLINE cpu_float2int16(float data); - - struct int16 { - int16_t data; - - _CUDA_HD FORCEINLINE int16(); - _CUDA_HD FORCEINLINE ~int16() = default; - - template - _CUDA_HD FORCEINLINE int16(const T& rhs); - - template - _CUDA_HD FORCEINLINE int16& operator=(const T& rhs); +float _CUDA_HD FORCEINLINE cpu_int162float(int16_t data); +int16_t _CUDA_HD FORCEINLINE cpu_float2int16(float data); +struct int16 { + int16_t data; - _CUDA_HD FORCEINLINE operator float() const; + _CUDA_HD FORCEINLINE int16(); + _CUDA_HD FORCEINLINE ~int16() = default; - _CUDA_HD FORCEINLINE void assign(double rhs); + template + _CUDA_HD FORCEINLINE int16(const T& rhs); - _CUDA_HD FORCEINLINE void assign(float rhs); - }; + template + _CUDA_HD FORCEINLINE int16& operator=(const T& rhs); + _CUDA_HD FORCEINLINE operator float() const; - ////////////////////////////// + _CUDA_HD FORCEINLINE void assign(double rhs); - float cpu_int162float(int16_t data) { - return (float) ((int) data); - } + _CUDA_HD FORCEINLINE void assign(float rhs); +}; - int16_t cpu_float2int16(float data) { - auto t = static_cast(data); - if (t > 32767 ) t = 32767; - if (t < -32768) t = -32768; +////////////////////////////// - return static_cast(t); - } +float cpu_int162float(int16_t data) { return (float)((int)data); } +int16_t cpu_float2int16(float data) { + auto t = static_cast(data); + if (t > 32767) t = 32767; + if (t < -32768) t = -32768; - int16::int16() { - data = cpu_float2int16(0.0f); - } + return static_cast(t); +} - template - int16::int16(const T& rhs) { - assign(rhs); - } +int16::int16() { data = cpu_float2int16(0.0f); } - template - int16& int16::operator=(const T& rhs) { - assign(rhs); return *this; - } +template +int16::int16(const T& rhs) { + assign(rhs); +} +template +int16& int16::operator=(const T& rhs) { + assign(rhs); + return *this; +} - int16::operator float() const { - return cpu_int162float(data); - } +int16::operator float() const { return cpu_int162float(data); } - void int16::assign(double rhs) { - assign(static_cast(rhs)); - } +void int16::assign(double rhs) { assign(static_cast(rhs)); } - void int16::assign(float rhs) { - data = cpu_float2int16(rhs); - } +void int16::assign(float rhs) { data = cpu_float2int16(rhs); } -} +} // namespace sd -#endif //LIBND4J_INT16_H +#endif // LIBND4J_INT16_H diff --git a/libnd4j/include/types/int8.h b/libnd4j/include/types/int8.h index 19e1b91e16a4..1e1daaeba8d6 100644 --- a/libnd4j/include/types/int8.h +++ b/libnd4j/include/types/int8.h @@ -24,71 +24,58 @@ #include #include - namespace sd { - float _CUDA_HD FORCEINLINE cpu_int82float(int8_t data); - int8_t _CUDA_HD FORCEINLINE cpu_float2int8(float data); - - struct int8 { - int8_t data; - - _CUDA_HD FORCEINLINE int8(); - _CUDA_HD FORCEINLINE ~int8() = default; - - template - _CUDA_HD FORCEINLINE int8(const T& rhs); +float _CUDA_HD FORCEINLINE cpu_int82float(int8_t data); +int8_t _CUDA_HD FORCEINLINE cpu_float2int8(float data); - template - _CUDA_HD FORCEINLINE int8& operator=(const T& rhs); +struct int8 { + int8_t data; + _CUDA_HD FORCEINLINE int8(); + _CUDA_HD FORCEINLINE ~int8() = default; - _CUDA_HD FORCEINLINE operator float() const; + template + _CUDA_HD FORCEINLINE int8(const T& rhs); - _CUDA_HD FORCEINLINE void assign(double rhs); + template + _CUDA_HD FORCEINLINE int8& operator=(const T& rhs); - _CUDA_HD FORCEINLINE void assign(float rhs); - }; + _CUDA_HD FORCEINLINE operator float() const; + _CUDA_HD FORCEINLINE void assign(double rhs); - float cpu_int82float(int8_t data) { - return (float) ((int) data); - } + _CUDA_HD FORCEINLINE void assign(float rhs); +}; - int8_t cpu_float2int8(float data) { - int t = (int) data; - if (t > 127) t = 127; - if (t < -128) t = -128; +float cpu_int82float(int8_t data) { return (float)((int)data); } - return (int8_t) t; - } +int8_t cpu_float2int8(float data) { + int t = (int)data; + if (t > 127) t = 127; + if (t < -128) t = -128; - int8::int8() { - data = cpu_float2int8(0.0f); - } + return (int8_t)t; +} - template - int8::int8(const T& rhs) { - assign(rhs); - } +int8::int8() { data = cpu_float2int8(0.0f); } - template - int8& int8::operator=(const T& rhs) { - assign(rhs); return *this; - } +template +int8::int8(const T& rhs) { + assign(rhs); +} +template +int8& int8::operator=(const T& rhs) { + assign(rhs); + return *this; +} - int8::operator float() const { - return cpu_int82float(data); - } +int8::operator float() const { return cpu_int82float(data); } - void int8::assign(double rhs) { - assign((float)rhs); - } +void int8::assign(double rhs) { assign((float)rhs); } - void int8::assign(float rhs) { - data = cpu_float2int8(rhs); - } -} +void int8::assign(float rhs) { data = cpu_float2int8(rhs); } +} // namespace sd -#endif //LIBND4J_INT8_H +#endif // LIBND4J_INT8_H diff --git a/libnd4j/include/types/pair.h b/libnd4j/include/types/pair.h index 0471c45ed33f..2a4142042c91 100644 --- a/libnd4j/include/types/pair.h +++ b/libnd4j/include/types/pair.h @@ -24,19 +24,18 @@ #include namespace sd { - class ND4J_EXPORT Pair { - protected: - int _first = 0; - int _second = 0; +class SD_EXPORT Pair { + protected: + int _first = 0; + int _second = 0; - public: - Pair(int first = 0, int second = 0); - ~Pair() = default; + public: + Pair(int first = 0, int second = 0); + ~Pair() = default; - int first() const; - int second() const; - }; -} + int first() const; + int second() const; +}; +} // namespace sd - -#endif //LIBND4J_PAIR_H +#endif // LIBND4J_PAIR_H diff --git a/libnd4j/include/types/triple.h b/libnd4j/include/types/triple.h index 0a5310265888..e7912084f535 100644 --- a/libnd4j/include/types/triple.h +++ b/libnd4j/include/types/triple.h @@ -21,24 +21,23 @@ #ifndef LIBND4J_TRIPLE_H #define LIBND4J_TRIPLE_H - #include namespace sd { - class ND4J_EXPORT Triple { - protected: - int _first = 0; - int _second = 0; - int _third = 0; - - public: - Triple(int first = 0, int second = 0, int third = 0); - ~Triple() = default; - - int first() const; - int second() const; - int third() const; - }; -} - -#endif //LIBND4J_TRIPLE_H +class SD_EXPORT Triple { + protected: + int _first = 0; + int _second = 0; + int _third = 0; + + public: + Triple(int first = 0, int second = 0, int third = 0); + ~Triple() = default; + + int first() const; + int second() const; + int third() const; +}; +} // namespace sd + +#endif // LIBND4J_TRIPLE_H diff --git a/libnd4j/include/types/types.h b/libnd4j/include/types/types.h index 7717c801908d..a6fe4064d4bb 100644 --- a/libnd4j/include/types/types.h +++ b/libnd4j/include/types/types.h @@ -22,478 +22,289 @@ #define LIBND4J_TYPES_H #include -#include +#include +#include #include -#include +#include #include -#include +#include #include +#include #include -#include -#include - -#define LIBND4J_STRINGTYPES \ - (sd::DataType::UTF8, std::string),\ - (sd::DataType::UTF16, std::u16string), \ - (sd::DataType::UTF32, std::u32string) - -#define LIBND4J_TYPES \ - (sd::DataType::BFLOAT16, bfloat16),\ - (sd::DataType::HALF, float16), \ - (sd::DataType::FLOAT32, float), \ - (sd::DataType::DOUBLE, double), \ - (sd::DataType::BOOL, bool), \ - (sd::DataType::INT8, int8_t), \ - (sd::DataType::UINT8, uint8_t), \ - (sd::DataType::UINT16, uint16_t), \ - (sd::DataType::UINT32, uint32_t), \ - (sd::DataType::UINT64, uint64_t), \ - (sd::DataType::INT16, int16_t), \ - (sd::DataType::INT32, int32_t), \ - (sd::DataType::INT64, Nd4jLong) - -#define LIBND4J_TYPES_EXTENDED \ - (sd::DataType::HALF, float16), \ - (sd::DataType::FLOAT32, float), \ - (sd::DataType::DOUBLE, double), \ - (sd::DataType::BOOL, bool), \ - (sd::DataType::INT8, int8_t), \ - (sd::DataType::UINT8, uint8_t), \ - (sd::DataType::INT16, int16_t), \ - (sd::DataType::INT32, int32_t), \ - (sd::DataType::INT64, Nd4jLong), \ - (sd::DataType::UINT16, uint16_t), \ - (sd::DataType::UINT64, Nd4jULong), \ - (sd::DataType::UINT32, uint32_t), \ - (sd::DataType::BFLOAT16, bfloat16) - -#define BOOL_TYPES \ - (sd::DataType::BOOL, bool) +#define LIBND4J_STRINGTYPES \ + (sd::DataType::UTF8, std::string), (sd::DataType::UTF16, std::u16string), \ + (sd::DataType::UTF32, std::u32string) + +#define LIBND4J_TYPES \ + (sd::DataType::BFLOAT16, bfloat16), (sd::DataType::HALF, float16), \ + (sd::DataType::FLOAT32, float), (sd::DataType::DOUBLE, double), \ + (sd::DataType::BOOL, bool), (sd::DataType::INT8, int8_t), \ + (sd::DataType::UINT8, uint8_t), (sd::DataType::UINT16, uint16_t), \ + (sd::DataType::UINT32, uint32_t), (sd::DataType::UINT64, uint64_t), \ + (sd::DataType::INT16, int16_t), (sd::DataType::INT32, int32_t), \ + (sd::DataType::INT64, Nd4jLong) + +#define LIBND4J_TYPES_EXTENDED \ + (sd::DataType::HALF, float16), (sd::DataType::FLOAT32, float), \ + (sd::DataType::DOUBLE, double), (sd::DataType::BOOL, bool), \ + (sd::DataType::INT8, int8_t), (sd::DataType::UINT8, uint8_t), \ + (sd::DataType::INT16, int16_t), (sd::DataType::INT32, int32_t), \ + (sd::DataType::INT64, Nd4jLong), (sd::DataType::UINT16, uint16_t), \ + (sd::DataType::UINT64, Nd4jULong), (sd::DataType::UINT32, uint32_t), \ + (sd::DataType::BFLOAT16, bfloat16) + +#define BOOL_TYPES (sd::DataType::BOOL, bool) #define LONG_TYPES \ - (sd::DataType::INT64, Nd4jLong),\ - (sd::DataType::UINT64, uint64_t) + (sd::DataType::INT64, Nd4jLong), (sd::DataType::UINT64, uint64_t) -#define FLOAT_TYPES \ - (sd::DataType::BFLOAT16, bfloat16) ,\ - (sd::DataType::HALF, float16), \ - (sd::DataType::FLOAT32, float), \ - (sd::DataType::DOUBLE, double) +#define FLOAT_TYPES \ + (sd::DataType::BFLOAT16, bfloat16), (sd::DataType::HALF, float16), \ + (sd::DataType::FLOAT32, float), (sd::DataType::DOUBLE, double) #define INDEXING_TYPES \ - (sd::DataType::INT32, int32_t), \ - (sd::DataType::INT64, Nd4jLong) + (sd::DataType::INT32, int32_t), (sd::DataType::INT64, Nd4jLong) #define FLOAT_NATIVE \ - (sd::DataType::FLOAT32, float), \ - (sd::DataType::DOUBLE, double) + (sd::DataType::FLOAT32, float), (sd::DataType::DOUBLE, double) -#define FLOAT_TYPES_0 \ - (sd::DataType::HALF, float16) +#define FLOAT_TYPES_0 (sd::DataType::HALF, float16) -#define FLOAT_TYPES_1 \ - (sd::DataType::FLOAT32, float) +#define FLOAT_TYPES_1 (sd::DataType::FLOAT32, float) -#define FLOAT_TYPES_2 \ - (sd::DataType::DOUBLE, double) +#define FLOAT_TYPES_2 (sd::DataType::DOUBLE, double) -#define FLOAT_TYPES_3 \ - (sd::DataType::BFLOAT16, bfloat16) +#define FLOAT_TYPES_3 (sd::DataType::BFLOAT16, bfloat16) -#define LIBND4J_TYPES_0 \ - (sd::DataType::HALF, float16) +#define LIBND4J_TYPES_0 (sd::DataType::HALF, float16) -#define LIBND4J_TYPES_1 \ - (sd::DataType::FLOAT32, float) +#define LIBND4J_TYPES_1 (sd::DataType::FLOAT32, float) -#define LIBND4J_TYPES_2 \ - (sd::DataType::DOUBLE, double) +#define LIBND4J_TYPES_2 (sd::DataType::DOUBLE, double) -#define LIBND4J_TYPES_3 \ - (sd::DataType::BOOL, bool) +#define LIBND4J_TYPES_3 (sd::DataType::BOOL, bool) -#define LIBND4J_TYPES_4 \ - (sd::DataType::INT8, int8_t) +#define LIBND4J_TYPES_4 (sd::DataType::INT8, int8_t) -#define LIBND4J_TYPES_5 \ - (sd::DataType::UINT8, uint8_t) +#define LIBND4J_TYPES_5 (sd::DataType::UINT8, uint8_t) #define LIBND4J_TYPES_6 \ - (sd::DataType::INT16, int16_t),\ - (sd::DataType::UINT16, uint16_t) + (sd::DataType::INT16, int16_t), (sd::DataType::UINT16, uint16_t) #define LIBND4J_TYPES_7 \ - (sd::DataType::INT32, int32_t), \ - (sd::DataType::UINT32, uint32_t) + (sd::DataType::INT32, int32_t), (sd::DataType::UINT32, uint32_t) #define LIBND4J_TYPES_8 \ - (sd::DataType::INT64, Nd4jLong),\ - (sd::DataType::UINT64, uint64_t) + (sd::DataType::INT64, Nd4jLong), (sd::DataType::UINT64, uint64_t) -#define LIBND4J_TYPES_9 \ - (sd::DataType::BFLOAT16, bfloat16) +#define LIBND4J_TYPES_9 (sd::DataType::BFLOAT16, bfloat16) -#define INTEGER_TYPES \ - (sd::DataType::INT8, int8_t), \ - (sd::DataType::UINT8, uint8_t), \ - (sd::DataType::UINT16, uint16_t), \ - (sd::DataType::UINT32, uint32_t), \ - (sd::DataType::UINT64, uint64_t), \ - (sd::DataType::INT16, int16_t), \ - (sd::DataType::INT32, int32_t), \ - (sd::DataType::INT64, Nd4jLong) +#define INTEGER_TYPES \ + (sd::DataType::INT8, int8_t), (sd::DataType::UINT8, uint8_t), \ + (sd::DataType::UINT16, uint16_t), (sd::DataType::UINT32, uint32_t), \ + (sd::DataType::UINT64, uint64_t), (sd::DataType::INT16, int16_t), \ + (sd::DataType::INT32, int32_t), (sd::DataType::INT64, Nd4jLong) -#define INTEGER_TYPES_0 \ - (sd::DataType::INT8, int8_t) +#define INTEGER_TYPES_0 (sd::DataType::INT8, int8_t) -#define INTEGER_TYPES_1 \ - (sd::DataType::UINT8, uint8_t) +#define INTEGER_TYPES_1 (sd::DataType::UINT8, uint8_t) -#define INTEGER_TYPES_2 \ - (sd::DataType::UINT16, uint16_t) +#define INTEGER_TYPES_2 (sd::DataType::UINT16, uint16_t) -#define INTEGER_TYPES_3 \ - (sd::DataType::UINT32, uint32_t) +#define INTEGER_TYPES_3 (sd::DataType::UINT32, uint32_t) -#define INTEGER_TYPES_4 \ - (sd::DataType::UINT64, uint64_t) +#define INTEGER_TYPES_4 (sd::DataType::UINT64, uint64_t) -#define INTEGER_TYPES_5 \ - (sd::DataType::INT16, int16_t) +#define INTEGER_TYPES_5 (sd::DataType::INT16, int16_t) -#define INTEGER_TYPES_6 \ - (sd::DataType::INT32, int32_t) +#define INTEGER_TYPES_6 (sd::DataType::INT32, int32_t) -#define INTEGER_TYPES_7 \ - (sd::DataType::INT64, Nd4jLong) +#define INTEGER_TYPES_7 (sd::DataType::INT64, Nd4jLong) +#define NUMERIC_TYPES \ + (sd::DataType::HALF, float16), (sd::DataType::FLOAT32, float), \ + (sd::DataType::DOUBLE, double), (sd::DataType::INT8, int8_t), \ + (sd::DataType::UINT8, uint8_t), (sd::DataType::UINT16, uint16_t), \ + (sd::DataType::UINT32, uint32_t), (sd::DataType::UINT64, uint64_t), \ + (sd::DataType::INT16, int16_t), (sd::DataType::INT32, int32_t), \ + (sd::DataType::INT64, Nd4jLong), (sd::DataType::BFLOAT16, bfloat16) -#define NUMERIC_TYPES \ - (sd::DataType::HALF, float16), \ - (sd::DataType::FLOAT32, float), \ - (sd::DataType::DOUBLE, double), \ - (sd::DataType::INT8, int8_t), \ - (sd::DataType::UINT8, uint8_t), \ - (sd::DataType::UINT16, uint16_t), \ - (sd::DataType::UINT32, uint32_t), \ - (sd::DataType::UINT64, uint64_t), \ - (sd::DataType::INT16, int16_t), \ - (sd::DataType::INT32, int32_t), \ - (sd::DataType::INT64, Nd4jLong), \ - (sd::DataType::BFLOAT16, bfloat16) +#define NUMERIC_TYPES_0 (sd::DataType::HALF, float16) -#define NUMERIC_TYPES_0 \ - (sd::DataType::HALF, float16) +#define NUMERIC_TYPES_1 (sd::DataType::FLOAT32, float) -#define NUMERIC_TYPES_1 \ - (sd::DataType::FLOAT32, float) - -#define NUMERIC_TYPES_2 \ - (sd::DataType::DOUBLE, double) +#define NUMERIC_TYPES_2 (sd::DataType::DOUBLE, double) #define NUMERIC_TYPES_3 \ - (sd::DataType::INT8, int8_t), \ - (sd::DataType::BFLOAT16, bfloat16) + (sd::DataType::INT8, int8_t), (sd::DataType::BFLOAT16, bfloat16) -#define NUMERIC_TYPES_4 \ - (sd::DataType::UINT8, uint8_t) +#define NUMERIC_TYPES_4 (sd::DataType::UINT8, uint8_t) -#define NUMERIC_TYPES_5 \ - (sd::DataType::UINT16, uint16_t) +#define NUMERIC_TYPES_5 (sd::DataType::UINT16, uint16_t) -#define NUMERIC_TYPES_6 \ - (sd::DataType::UINT32, uint32_t) +#define NUMERIC_TYPES_6 (sd::DataType::UINT32, uint32_t) -#define NUMERIC_TYPES_7 \ - (sd::DataType::UINT64, uint64_t) +#define NUMERIC_TYPES_7 (sd::DataType::UINT64, uint64_t) -#define NUMERIC_TYPES_8 \ - (sd::DataType::INT16, int16_t) +#define NUMERIC_TYPES_8 (sd::DataType::INT16, int16_t) #define NUMERIC_TYPES_9 \ - (sd::DataType::INT32, int32_t), \ - (sd::DataType::INT64, Nd4jLong) - - -#define GENERIC_NUMERIC_TYPES \ - (sd::DataType::HALF, float16), \ - (sd::DataType::FLOAT32, float), \ - (sd::DataType::DOUBLE, double), \ - (sd::DataType::INT32, int32_t), \ - (sd::DataType::INT64, Nd4jLong), \ - (sd::DataType::BFLOAT16, bfloat16) + (sd::DataType::INT32, int32_t), (sd::DataType::INT64, Nd4jLong) +#define GENERIC_NUMERIC_TYPES \ + (sd::DataType::HALF, float16), (sd::DataType::FLOAT32, float), \ + (sd::DataType::DOUBLE, double), (sd::DataType::INT32, int32_t), \ + (sd::DataType::INT64, Nd4jLong), (sd::DataType::BFLOAT16, bfloat16) #ifdef __ND4J_EXPERIMENTAL__ -#define PAIRWISE_TYPES_0 \ - (double, double, double), \ - (double, uint8_t, double), \ - (double, uint8_t, uint8_t), \ - (double, float, double), \ - (double, float, float), \ - (double, bfloat16, double), \ - (double, bfloat16, bfloat16), \ - (double, Nd4jLong, double), \ - (double, Nd4jLong, Nd4jLong), \ - (double, int32_t, double), \ - (double, int32_t, int32_t) , \ - (bool, bool, bool), \ - (bool, int8_t, bool), \ - (int8_t, bool, bool), \ - (int8_t, int8_t, int8_t), \ - (int16_t, bool, int16_t), \ - (float16, int8_t, float16), \ - (bfloat16, bool, bool), \ - (double, int8_t, double) - -#define PAIRWISE_TYPES_9 \ - (double, int16_t, double), \ - (double, int16_t, int16_t), \ - (double, float16, double), \ - (double, float16, float16), \ - (double, bool, double), \ - (double, bool, bool), \ - (int8_t, double, int8_t), \ - (int8_t, double, double), \ - (int8_t, uint8_t, int8_t), \ - (int8_t, uint8_t, uint8_t), \ - (int8_t, float, int8_t), \ - (double, int8_t, int8_t) , \ - (bool, int8_t, int8_t) ,\ - (float16, int32_t, float16), \ - (float16, int32_t, int32_t), \ - (float16, int16_t, float16), \ - (float16, int16_t, int16_t), \ - (float16, float16, float16), \ - (float16, bool, float16), \ - (float16, bool, bool), \ - (float16, int8_t, int8_t) - -#define PAIRWISE_TYPES_1 \ - (uint8_t, double, uint8_t), \ - (uint8_t, double, double), \ - (uint8_t, uint8_t, uint8_t), \ - (uint8_t, float, uint8_t), \ - (uint8_t, float, float), \ - (uint8_t, bfloat16, uint8_t), \ - (uint8_t, bfloat16, bfloat16), \ - (uint8_t, Nd4jLong, uint8_t) , \ - (uint8_t, Nd4jLong, Nd4jLong), \ - (uint8_t, int32_t, uint8_t), \ - (uint8_t, int32_t, int32_t), \ - (uint8_t, int16_t, uint8_t), \ - (uint8_t, int16_t, int16_t), \ - (uint8_t, float16, uint8_t), \ - (uint8_t, float16, float16), \ - (uint8_t, bool, uint8_t), \ - (uint8_t, bool, bool), \ - (uint8_t, int8_t, uint8_t), \ - (uint8_t, int8_t, int8_t) - -#define PAIRWISE_TYPES_2 \ - (float, double, float), \ - (float, double, double), \ - (float, uint8_t, float), \ - (float, uint8_t, uint8_t), \ - (float, float, float), \ - (float, bfloat16, float), \ - (float, bfloat16, bfloat16), \ - (float, Nd4jLong, float), \ - (float, Nd4jLong, Nd4jLong) , \ - (float, int32_t, float), \ - (int8_t, int32_t, int8_t), \ - (int8_t, int32_t, int32_t), \ - (float, int32_t, int32_t), \ - (float, int16_t, float), \ - (float, int16_t, int16_t), \ - (float, float16, float), \ - (float, float16, float16), \ - (float, bool, float) - -#define PAIRWISE_TYPES_3 \ - (bfloat16, double, bfloat16), \ - (bfloat16, double, double), \ - (bfloat16, uint8_t, bfloat16), \ - (bfloat16, uint8_t, uint8_t), \ - (bfloat16, float, bfloat16), \ - (bfloat16, float, float), \ - (bfloat16, bfloat16, bfloat16), \ - (bfloat16, Nd4jLong, bfloat16), \ - (bfloat16, Nd4jLong, Nd4jLong), \ - (bfloat16, int32_t, bfloat16) , \ - (float, bool, bool), \ - (bfloat16, int32_t, int32_t), \ - (float, int8_t, float), \ - (int8_t, float, float), \ - (int8_t, bfloat16, int8_t), \ - (int8_t, bfloat16, bfloat16), \ - (int8_t, Nd4jLong, int8_t), \ - (int8_t, Nd4jLong, Nd4jLong), \ - (float, int8_t, int8_t) - -#define PAIRWISE_TYPES_4 \ - (Nd4jLong, double, Nd4jLong), \ - (Nd4jLong, double, double), \ - (Nd4jLong, uint8_t, Nd4jLong), \ - (Nd4jLong, uint8_t, uint8_t), \ - (Nd4jLong, float, Nd4jLong), \ - (Nd4jLong, float, float), \ - (Nd4jLong, bfloat16, Nd4jLong), \ - (Nd4jLong, bfloat16, bfloat16), \ - (Nd4jLong, Nd4jLong, Nd4jLong), \ - (Nd4jLong, int32_t, Nd4jLong), \ - (bfloat16, int16_t, bfloat16), \ - (bfloat16, int16_t, int16_t), \ - (bfloat16, float16, bfloat16), \ - (bfloat16, float16, float16), \ - (bfloat16, bool, bfloat16), \ - (int8_t, int16_t, int8_t), \ - (int8_t, int16_t, int16_t), \ - (bfloat16, int8_t, bfloat16), \ - (bfloat16, int8_t, int8_t) - -#define PAIRWISE_TYPES_5 \ - (int32_t, double, int32_t), \ - (int32_t, double, double), \ - (int32_t, uint8_t, int32_t), \ - (int32_t, uint8_t, uint8_t), \ - (int32_t, float, int32_t), \ - (int32_t, float, float), \ - (int32_t, bfloat16, int32_t), \ - (int32_t, bfloat16, bfloat16), \ - (int32_t, Nd4jLong, int32_t), \ - (Nd4jLong, int32_t, int32_t), \ - (Nd4jLong, int16_t, Nd4jLong), \ - (Nd4jLong, int16_t, int16_t), \ - (Nd4jLong, float16, Nd4jLong), \ - (Nd4jLong, float16, float16), \ - (Nd4jLong, bool, Nd4jLong), \ - (Nd4jLong, bool, bool), \ - (Nd4jLong, int8_t, Nd4jLong), \ - (Nd4jLong, int8_t, int8_t) - - -#define PAIRWISE_TYPES_6 \ - (int16_t, double, int16_t), \ - (int16_t, double, double), \ - (int16_t, uint8_t, int16_t), \ - (int16_t, uint8_t, uint8_t), \ - (int16_t, float, int16_t), \ - (int16_t, float, float), \ - (int16_t, bfloat16, int16_t), \ - (int16_t, bfloat16, bfloat16), \ - (int16_t, Nd4jLong, int16_t), \ - (float16, bfloat16, bfloat16), \ - (int16_t, Nd4jLong, Nd4jLong), \ - (int32_t, Nd4jLong, Nd4jLong), \ - (int32_t, int32_t, int32_t), \ - (int32_t, int16_t, int32_t), \ - (int32_t, int16_t, int16_t), \ - (int32_t, float16, int32_t), \ - (int32_t, float16, float16), \ - (int32_t, bool, int32_t), \ - (int32_t, bool, bool), \ - (int32_t, int8_t, int32_t), \ - (int32_t, int8_t, int8_t) - - -#define PAIRWISE_TYPES_7 \ - (float16, double, float16), \ - (float16, double, double), \ - (float16, uint8_t, float16), \ - (float16, uint8_t, uint8_t), \ - (float16, float, float16), \ - (float16, float, float), \ - (float16, bfloat16, float16), \ - (float16, Nd4jLong, float16), \ - (float16, Nd4jLong, Nd4jLong), \ - (int16_t, int32_t, int16_t), \ - (int16_t, int32_t, int32_t), \ - (int16_t, int16_t, int16_t), \ - (int16_t, float16, int16_t), \ - (int16_t, float16, float16), \ - (int8_t, float16, int8_t), \ - (int8_t, float16, float16), \ - (int8_t, bool, int8_t), \ - (int16_t, bool, bool), \ - (int16_t, int8_t, int16_t), \ - (int16_t, int8_t, int8_t) - - -#define PAIRWISE_TYPES_8 \ - (bool, double, bool), \ - (bool, double, double), \ - (bool, uint8_t, bool), \ - (bool, uint8_t, uint8_t), \ - (bool, float, bool), \ - (bool, float, float), \ - (bool, bfloat16, bool) ,\ - (bool, bfloat16, bfloat16), \ - (bool, Nd4jLong, bool), \ - (bool, Nd4jLong, Nd4jLong), \ - (bool, int32_t, bool), \ - (bool, int32_t, int32_t), \ - (bool, int16_t, bool), \ - (bool, int16_t, int16_t), \ - (bool, float16, bool), \ - (bool, float16, float16) - +#define PAIRWISE_TYPES_0 \ + (double, double, double), (double, uint8_t, double), \ + (double, uint8_t, uint8_t), (double, float, double), \ + (double, float, float), (double, bfloat16, double), \ + (double, bfloat16, bfloat16), (double, Nd4jLong, double), \ + (double, Nd4jLong, Nd4jLong), (double, int32_t, double), \ + (double, int32_t, int32_t), (bool, bool, bool), (bool, int8_t, bool), \ + (int8_t, bool, bool), (int8_t, int8_t, int8_t), \ + (int16_t, bool, int16_t), (float16, int8_t, float16), \ + (bfloat16, bool, bool), (double, int8_t, double) + +#define PAIRWISE_TYPES_9 \ + (double, int16_t, double), (double, int16_t, int16_t), \ + (double, float16, double), (double, float16, float16), \ + (double, bool, double), (double, bool, bool), (int8_t, double, int8_t), \ + (int8_t, double, double), (int8_t, uint8_t, int8_t), \ + (int8_t, uint8_t, uint8_t), (int8_t, float, int8_t), \ + (double, int8_t, int8_t), (bool, int8_t, int8_t), \ + (float16, int32_t, float16), (float16, int32_t, int32_t), \ + (float16, int16_t, float16), (float16, int16_t, int16_t), \ + (float16, float16, float16), (float16, bool, float16), \ + (float16, bool, bool), (float16, int8_t, int8_t) + +#define PAIRWISE_TYPES_1 \ + (uint8_t, double, uint8_t), (uint8_t, double, double), \ + (uint8_t, uint8_t, uint8_t), (uint8_t, float, uint8_t), \ + (uint8_t, float, float), (uint8_t, bfloat16, uint8_t), \ + (uint8_t, bfloat16, bfloat16), (uint8_t, Nd4jLong, uint8_t), \ + (uint8_t, Nd4jLong, Nd4jLong), (uint8_t, int32_t, uint8_t), \ + (uint8_t, int32_t, int32_t), (uint8_t, int16_t, uint8_t), \ + (uint8_t, int16_t, int16_t), (uint8_t, float16, uint8_t), \ + (uint8_t, float16, float16), (uint8_t, bool, uint8_t), \ + (uint8_t, bool, bool), (uint8_t, int8_t, uint8_t), \ + (uint8_t, int8_t, int8_t) + +#define PAIRWISE_TYPES_2 \ + (float, double, float), (float, double, double), (float, uint8_t, float), \ + (float, uint8_t, uint8_t), (float, float, float), \ + (float, bfloat16, float), (float, bfloat16, bfloat16), \ + (float, Nd4jLong, float), (float, Nd4jLong, Nd4jLong), \ + (float, int32_t, float), (int8_t, int32_t, int8_t), \ + (int8_t, int32_t, int32_t), (float, int32_t, int32_t), \ + (float, int16_t, float), (float, int16_t, int16_t), \ + (float, float16, float), (float, float16, float16), (float, bool, float) + +#define PAIRWISE_TYPES_3 \ + (bfloat16, double, bfloat16), (bfloat16, double, double), \ + (bfloat16, uint8_t, bfloat16), (bfloat16, uint8_t, uint8_t), \ + (bfloat16, float, bfloat16), (bfloat16, float, float), \ + (bfloat16, bfloat16, bfloat16), (bfloat16, Nd4jLong, bfloat16), \ + (bfloat16, Nd4jLong, Nd4jLong), (bfloat16, int32_t, bfloat16), \ + (float, bool, bool), (bfloat16, int32_t, int32_t), \ + (float, int8_t, float), (int8_t, float, float), \ + (int8_t, bfloat16, int8_t), (int8_t, bfloat16, bfloat16), \ + (int8_t, Nd4jLong, int8_t), (int8_t, Nd4jLong, Nd4jLong), \ + (float, int8_t, int8_t) + +#define PAIRWISE_TYPES_4 \ + (Nd4jLong, double, Nd4jLong), (Nd4jLong, double, double), \ + (Nd4jLong, uint8_t, Nd4jLong), (Nd4jLong, uint8_t, uint8_t), \ + (Nd4jLong, float, Nd4jLong), (Nd4jLong, float, float), \ + (Nd4jLong, bfloat16, Nd4jLong), (Nd4jLong, bfloat16, bfloat16), \ + (Nd4jLong, Nd4jLong, Nd4jLong), (Nd4jLong, int32_t, Nd4jLong), \ + (bfloat16, int16_t, bfloat16), (bfloat16, int16_t, int16_t), \ + (bfloat16, float16, bfloat16), (bfloat16, float16, float16), \ + (bfloat16, bool, bfloat16), (int8_t, int16_t, int8_t), \ + (int8_t, int16_t, int16_t), (bfloat16, int8_t, bfloat16), \ + (bfloat16, int8_t, int8_t) + +#define PAIRWISE_TYPES_5 \ + (int32_t, double, int32_t), (int32_t, double, double), \ + (int32_t, uint8_t, int32_t), (int32_t, uint8_t, uint8_t), \ + (int32_t, float, int32_t), (int32_t, float, float), \ + (int32_t, bfloat16, int32_t), (int32_t, bfloat16, bfloat16), \ + (int32_t, Nd4jLong, int32_t), (Nd4jLong, int32_t, int32_t), \ + (Nd4jLong, int16_t, Nd4jLong), (Nd4jLong, int16_t, int16_t), \ + (Nd4jLong, float16, Nd4jLong), (Nd4jLong, float16, float16), \ + (Nd4jLong, bool, Nd4jLong), (Nd4jLong, bool, bool), \ + (Nd4jLong, int8_t, Nd4jLong), (Nd4jLong, int8_t, int8_t) + +#define PAIRWISE_TYPES_6 \ + (int16_t, double, int16_t), (int16_t, double, double), \ + (int16_t, uint8_t, int16_t), (int16_t, uint8_t, uint8_t), \ + (int16_t, float, int16_t), (int16_t, float, float), \ + (int16_t, bfloat16, int16_t), (int16_t, bfloat16, bfloat16), \ + (int16_t, Nd4jLong, int16_t), (float16, bfloat16, bfloat16), \ + (int16_t, Nd4jLong, Nd4jLong), (int32_t, Nd4jLong, Nd4jLong), \ + (int32_t, int32_t, int32_t), (int32_t, int16_t, int32_t), \ + (int32_t, int16_t, int16_t), (int32_t, float16, int32_t), \ + (int32_t, float16, float16), (int32_t, bool, int32_t), \ + (int32_t, bool, bool), (int32_t, int8_t, int32_t), \ + (int32_t, int8_t, int8_t) + +#define PAIRWISE_TYPES_7 \ + (float16, double, float16), (float16, double, double), \ + (float16, uint8_t, float16), (float16, uint8_t, uint8_t), \ + (float16, float, float16), (float16, float, float), \ + (float16, bfloat16, float16), (float16, Nd4jLong, float16), \ + (float16, Nd4jLong, Nd4jLong), (int16_t, int32_t, int16_t), \ + (int16_t, int32_t, int32_t), (int16_t, int16_t, int16_t), \ + (int16_t, float16, int16_t), (int16_t, float16, float16), \ + (int8_t, float16, int8_t), (int8_t, float16, float16), \ + (int8_t, bool, int8_t), (int16_t, bool, bool), \ + (int16_t, int8_t, int16_t), (int16_t, int8_t, int8_t) + +#define PAIRWISE_TYPES_8 \ + (bool, double, bool), (bool, double, double), (bool, uint8_t, bool), \ + (bool, uint8_t, uint8_t), (bool, float, bool), (bool, float, float), \ + (bool, bfloat16, bool), (bool, bfloat16, bfloat16), \ + (bool, Nd4jLong, bool), (bool, Nd4jLong, Nd4jLong), \ + (bool, int32_t, bool), (bool, int32_t, int32_t), (bool, int16_t, bool), \ + (bool, int16_t, int16_t), (bool, float16, bool), \ + (bool, float16, float16) #else -#define PAIRWISE_TYPES_0 \ -(float16, float16, float16) , \ -(float16, bool, float16) +#define PAIRWISE_TYPES_0 (float16, float16, float16), (float16, bool, float16) -#define PAIRWISE_TYPES_1 \ -(float, float, float) , \ -(float, bool, float) +#define PAIRWISE_TYPES_1 (float, float, float), (float, bool, float) -#define PAIRWISE_TYPES_2 \ -(double, double, double) , \ -(double, bool, double) +#define PAIRWISE_TYPES_2 (double, double, double), (double, bool, double) -#define PAIRWISE_TYPES_3 \ -(int8_t, int8_t, int8_t) , \ -(int8_t, bool, int8_t) +#define PAIRWISE_TYPES_3 (int8_t, int8_t, int8_t), (int8_t, bool, int8_t) -#define PAIRWISE_TYPES_4 \ -(int16_t, int16_t, int16_t) , \ -(int16_t, bool, int16_t) +#define PAIRWISE_TYPES_4 (int16_t, int16_t, int16_t), (int16_t, bool, int16_t) -#define PAIRWISE_TYPES_5 \ -(uint8_t, uint8_t, uint8_t) , \ -(uint8_t, bool, uint8_t) +#define PAIRWISE_TYPES_5 (uint8_t, uint8_t, uint8_t), (uint8_t, bool, uint8_t) -#define PAIRWISE_TYPES_6 \ -(int, int, int) ,\ -(int, bool, int) +#define PAIRWISE_TYPES_6 (int, int, int), (int, bool, int) -#define PAIRWISE_TYPES_7 \ -(bool, bool, bool) +#define PAIRWISE_TYPES_7 (bool, bool, bool) #define PAIRWISE_TYPES_8 \ -(Nd4jLong, Nd4jLong, Nd4jLong) ,\ -(Nd4jLong, bool, Nd4jLong) + (Nd4jLong, Nd4jLong, Nd4jLong), (Nd4jLong, bool, Nd4jLong) #define PAIRWISE_TYPES_9 \ -(bfloat16, bfloat16, bfloat16) , \ -(bfloat16, bool, bfloat16) + (bfloat16, bfloat16, bfloat16), (bfloat16, bool, bfloat16) #define PAIRWISE_TYPES_10 \ -(uint64_t, uint64_t, uint64_t) ,\ -(uint64_t, bool, uint64_t) + (uint64_t, uint64_t, uint64_t), (uint64_t, bool, uint64_t) #define PAIRWISE_TYPES_11 \ -(uint32_t, uint32_t, uint32_t) ,\ -(uint32_t, bool, uint32_t) + (uint32_t, uint32_t, uint32_t), (uint32_t, bool, uint32_t) #define PAIRWISE_TYPES_12 \ -(uint16_t, uint16_t, uint16_t) ,\ -(uint16_t, bool, uint16_t) + (uint16_t, uint16_t, uint16_t), (uint16_t, bool, uint16_t) #endif -#endif //LIBND4J_TYPES_H - +#endif // LIBND4J_TYPES_H diff --git a/libnd4j/include/types/u64.h b/libnd4j/include/types/u64.h index 908a9ba1cd63..ef9537ce3eeb 100644 --- a/libnd4j/include/types/u64.h +++ b/libnd4j/include/types/u64.h @@ -20,45 +20,43 @@ #ifndef LIBND4J_U64_H #define LIBND4J_U64_H -#include #include #include +#include namespace sd { - typedef struct { - int16_t _v0; - int16_t _v1; - int16_t _v2; - int16_t _v3; - } di16; - - typedef struct { - int _v0; - int _v1; - } di32; - - typedef struct { - uint32_t _v0; - uint32_t _v1; - } du32; - - union u64 { - bool _bool; - int8_t _char; - int16_t _short; - int32_t _int; - //float16 _half = 0.0f; - float _float; - double _double; - Nd4jLong _long; - uint64_t _ulong; - di32 _di32; - du32 _du32; - u64() { - _long = 0; - } - }; -} +typedef struct { + int16_t _v0; + int16_t _v1; + int16_t _v2; + int16_t _v3; +} di16; + +typedef struct { + int _v0; + int _v1; +} di32; + +typedef struct { + uint32_t _v0; + uint32_t _v1; +} du32; + +union u64 { + bool _bool; + int8_t _char; + int16_t _short; + int32_t _int; + // float16 _half = 0.0f; + float _float; + double _double; + Nd4jLong _long; + uint64_t _ulong; + di32 _di32; + du32 _du32; + u64() { _long = 0; } +}; +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/types/uint16.h b/libnd4j/include/types/uint16.h index 5fee50e7aa30..d45454336538 100644 --- a/libnd4j/include/types/uint16.h +++ b/libnd4j/include/types/uint16.h @@ -24,75 +24,66 @@ #include #include - namespace sd { - uint16_t _CUDA_HD FORCEINLINE cpu_float2uint16(float data); - float _CUDA_HD FORCEINLINE cpu_uint162float(uint16_t data); +uint16_t _CUDA_HD FORCEINLINE cpu_float2uint16(float data); +float _CUDA_HD FORCEINLINE cpu_uint162float(uint16_t data); - struct uint16 { - uint16_t data; +struct uint16 { + uint16_t data; - _CUDA_HD FORCEINLINE uint16(); - _CUDA_HD FORCEINLINE ~uint16(); + _CUDA_HD FORCEINLINE uint16(); + _CUDA_HD FORCEINLINE ~uint16(); - template - _CUDA_HD FORCEINLINE uint16(const T& rhs); + template + _CUDA_HD FORCEINLINE uint16(const T& rhs); - template - _CUDA_HD FORCEINLINE uint16& operator=(const T& rhs); + template + _CUDA_HD FORCEINLINE uint16& operator=(const T& rhs); - _CUDA_HD FORCEINLINE operator float() const; + _CUDA_HD FORCEINLINE operator float() const; - _CUDA_HD FORCEINLINE void assign(double rhs); + _CUDA_HD FORCEINLINE void assign(double rhs); - _CUDA_HD FORCEINLINE void assign(float rhs); - }; + _CUDA_HD FORCEINLINE void assign(float rhs); +}; //////////////////// IMPLEMENTATIONS - float _CUDA_HD cpu_uint162float(uint16_t data) { - return static_cast(data); - } +float _CUDA_HD cpu_uint162float(uint16_t data) { + return static_cast(data); +} - uint16_t _CUDA_HD cpu_float2uint16(float data) { - auto t = static_cast(data); - if (t > 65536 ) t = 65536; - if (t < 0) t = 0; +uint16_t _CUDA_HD cpu_float2uint16(float data) { + auto t = static_cast(data); + if (t > 65536) t = 65536; + if (t < 0) t = 0; - return static_cast(t); - } + return static_cast(t); +} - _CUDA_HD uint16::uint16() { - data = cpu_float2uint16(0.0f); - } +_CUDA_HD uint16::uint16() { data = cpu_float2uint16(0.0f); } - _CUDA_HD uint16::~uint16() { - // - } +_CUDA_HD uint16::~uint16() { + // +} - template - _CUDA_HD uint16::uint16(const T& rhs) { - assign(rhs); - } +template +_CUDA_HD uint16::uint16(const T& rhs) { + assign(rhs); +} - template - _CUDA_HD uint16& uint16::operator=(const T& rhs) { - assign(rhs); - return *this; - } +template +_CUDA_HD uint16& uint16::operator=(const T& rhs) { + assign(rhs); + return *this; +} - _CUDA_HD uint16::operator float() const { - return cpu_uint162float(data); - } +_CUDA_HD uint16::operator float() const { return cpu_uint162float(data); } - _CUDA_HD void uint16::assign(float rhs) { - data = cpu_float2uint16(rhs); - } +_CUDA_HD void uint16::assign(float rhs) { data = cpu_float2uint16(rhs); } - _CUDA_HD void uint16::assign(double rhs) { - assign((float)rhs); - } -} +_CUDA_HD void uint16::assign(double rhs) { assign((float)rhs); } +} // namespace sd -#endif //LIBND4J_UINT16_H +#endif // LIBND4J_UINT16_H diff --git a/libnd4j/include/types/uint8.h b/libnd4j/include/types/uint8.h index a2505c9ab64c..b3c8de86f327 100644 --- a/libnd4j/include/types/uint8.h +++ b/libnd4j/include/types/uint8.h @@ -24,71 +24,62 @@ #include #include - namespace sd { - float _CUDA_HD FORCEINLINE cpu_uint82float(uint8_t data); - uint8_t _CUDA_HD FORCEINLINE cpu_float2uint8(float data); - - struct uint8 { - uint8_t data; - - _CUDA_HD FORCEINLINE uint8(); - _CUDA_HD FORCEINLINE ~uint8() = default; - - template - _CUDA_HD FORCEINLINE uint8(const T& rhs); - - template - _CUDA_HD FORCEINLINE uint8& operator=(const T& rhs); - +float _CUDA_HD FORCEINLINE cpu_uint82float(uint8_t data); +uint8_t _CUDA_HD FORCEINLINE cpu_float2uint8(float data); - _CUDA_HD FORCEINLINE operator float() const; +struct uint8 { + uint8_t data; - _CUDA_HD FORCEINLINE void assign(double rhs); + _CUDA_HD FORCEINLINE uint8(); + _CUDA_HD FORCEINLINE ~uint8() = default; - _CUDA_HD FORCEINLINE void assign(float rhs); - }; + template + _CUDA_HD FORCEINLINE uint8(const T& rhs); + template + _CUDA_HD FORCEINLINE uint8& operator=(const T& rhs); + _CUDA_HD FORCEINLINE operator float() const; - /////////////////////////// + _CUDA_HD FORCEINLINE void assign(double rhs); + _CUDA_HD FORCEINLINE void assign(float rhs); +}; - float cpu_uint82float(uint8_t data) { - return static_cast(static_cast(data)); - } +/////////////////////////// - uint8_t cpu_float2uint8(float data) { - auto t = static_cast(data); - if (t > 255) t = 255; - if (t < 0) t = 0; +float cpu_uint82float(uint8_t data) { + return static_cast(static_cast(data)); +} - return static_cast(t); - } +uint8_t cpu_float2uint8(float data) { + auto t = static_cast(data); + if (t > 255) t = 255; + if (t < 0) t = 0; - uint8::uint8() { data = cpu_float2uint8(0.0f); } + return static_cast(t); +} - template - uint8::uint8(const T& rhs) { - assign(rhs); - } +uint8::uint8() { data = cpu_float2uint8(0.0f); } - template - uint8& uint8::operator=(const T& rhs) { assign(rhs); return *this; } +template +uint8::uint8(const T& rhs) { + assign(rhs); +} +template +uint8& uint8::operator=(const T& rhs) { + assign(rhs); + return *this; +} - uint8::operator float() const { - return cpu_uint82float(data); - } +uint8::operator float() const { return cpu_uint82float(data); } - void uint8::assign(double rhs) { - assign(static_cast(rhs)); - } +void uint8::assign(double rhs) { assign(static_cast(rhs)); } - void uint8::assign(float rhs) { - data = cpu_float2uint8(rhs); - } -} +void uint8::assign(float rhs) { data = cpu_float2uint8(rhs); } +} // namespace sd -#endif //LIBND4J_UINT8_H +#endif // LIBND4J_UINT8_H diff --git a/libnd4j/include/types/utf8string.h b/libnd4j/include/types/utf8string.h index ed25c6e10735..8a580d1ee364 100644 --- a/libnd4j/include/types/utf8string.h +++ b/libnd4j/include/types/utf8string.h @@ -18,32 +18,33 @@ // @author raver119@gmail.com // -#ifndef DEV_TESTS_UTF8STRING_H -#define DEV_TESTS_UTF8STRING_H +#ifndef SD_UTF8STRING_H +#define SD_UTF8STRING_H -#include #include +#include + namespace sd { - struct ND4J_EXPORT utf8string { - private: - bool _allocated = false; - public: - char *_buffer = nullptr; - unsigned int _length = 0; +struct SD_EXPORT utf8string { + private: + bool _allocated = false; - utf8string(); - ~utf8string(); + public: + char *_buffer = nullptr; + unsigned int _length = 0; - utf8string(const char *string, int length); - utf8string(const std::string &string); - utf8string(const utf8string &other); - utf8string& operator=(const utf8string &other); + utf8string(); + ~utf8string(); - protected: - void Swap(utf8string &other); - }; -} + utf8string(const char *string, int length); + utf8string(const std::string &string); + utf8string(const utf8string &other); + utf8string &operator=(const utf8string &other); + protected: + void Swap(utf8string &other); +}; +} // namespace sd -#endif //DEV_TESTS_UTF8STRING_H +#endif // SD_UTF8STRING_H diff --git a/libnd4j/minifier/graphopt.cpp b/libnd4j/minifier/graphopt.cpp index 50321aacc66c..28fa29b121b2 100644 --- a/libnd4j/minifier/graphopt.cpp +++ b/libnd4j/minifier/graphopt.cpp @@ -21,113 +21,112 @@ * */ +#include "graphopt.h" + #include #include -#include "graphopt.h" - -std::ostream& -operator<< (std::ostream& out, GraphOpt const& opts) { - if (opts._files.empty() && opts._opts.empty()) { - out << "Empty options" << std::endl; - return out; - } - out << "==================================================" << std::endl; - out << "Files:" << std::endl; - int index = 1; - for (auto file: opts._files) { - out << "File " << index++ << ": " << file << std::endl; - } - out << "Options:" << std::endl; - for (char opt: opts._opts) { - out << "Option: " << opt; - if (opts._args.find(opt) != opts._args.end()) { - out << " with arg: " << opts._args.at(opt) << std::endl; - } - else { - out << std::endl; - } - } - out << "=================================================="; +std::ostream& operator<<(std::ostream& out, GraphOpt const& opts) { + if (opts._files.empty() && opts._opts.empty()) { + out << "Empty options" << std::endl; return out; + } + out << "==================================================" << std::endl; + out << "Files:" << std::endl; + int index = 1; + for (auto file : opts._files) { + out << "File " << index++ << ": " << file << std::endl; + } + out << "Options:" << std::endl; + for (char opt : opts._opts) { + out << "Option: " << opt; + if (opts._args.find(opt) != opts._args.end()) { + out << " with arg: " << opts._args.at(opt) << std::endl; + } else { + out << std::endl; + } + } + out << "=================================================="; + return out; } //////////////////////////////////////////////////////////////////////////////// -int -GraphOpt::optionsWithArgs(int argc, char* argv[], GraphOpt& res) { - char* optArg = nullptr; - int optIndex = 1; - - char const* optionStr = "lxa:o:e"; - std::string const defaultOutputName("nd4jlib_mini"); - - for (optIndex = 1; (optIndex < argc) && (argv[optIndex][0] == '-') && - (argv[optIndex][0]); optIndex++) { - - int opt = argv[optIndex][1]; - - if (opt == '?' || opt == 'h') { - res.help(argv[0], std::cout); - res.reset(); - return 1; - } - - char const* p = strchr(optionStr, opt); - - if (p == nullptr) - { - std::cerr << "opt " << (char)opt << " not found with " << optionStr << std::endl; - res._opts.push_back('?'); - res.reset(); - return -1; - } - else { - res._opts.push_back(opt); - - if (p[1] == ':') // processing param with - { - optIndex++; - if (optIndex >= argc) - { - std::cerr << "optIndex " << optIndex << " is out of bounds " << argc << std::endl; - res.reset(); - res._opts.push_back('?'); - return -2; - } - res._args[opt] = std::string(argv[optIndex]); - } - } - } - - if ( !res.hasParam('l') && !res.hasParam('x') ) { - std::cerr << "No -l or -x params are provided. At least one of them should be used." << std::endl; - res.reset(); - res._opts.push_back('?'); - return -3; +int GraphOpt::optionsWithArgs(int argc, char* argv[], GraphOpt& res) { + char* optArg = nullptr; + int optIndex = 1; + + char const* optionStr = "lxa:o:e"; + std::string const defaultOutputName("nd4jlib_mini"); + + for (optIndex = 1; + (optIndex < argc) && (argv[optIndex][0] == '-') && (argv[optIndex][0]); + optIndex++) { + int opt = argv[optIndex][1]; + + if (opt == '?' || opt == 'h') { + res.help(argv[0], std::cout); + res.reset(); + return 1; } - if (res._args.empty()) - res._args['o'] = defaultOutputName; - - for ( ; optIndex < argc; optIndex++) { - res._files.push_back(std::string(argv[optIndex])); + char const* p = strchr(optionStr, opt); + + if (p == nullptr) { + std::cerr << "opt " << (char)opt << " not found with " << optionStr + << std::endl; + res._opts.push_back('?'); + res.reset(); + return -1; + } else { + res._opts.push_back(opt); + + if (p[1] == ':') // processing param with + { + optIndex++; + if (optIndex >= argc) { + std::cerr << "optIndex " << optIndex << " is out of bounds " << argc + << std::endl; + res.reset(); + res._opts.push_back('?'); + return -2; + } + res._args[opt] = std::string(argv[optIndex]); + } } - return 0; + } + + if (!res.hasParam('l') && !res.hasParam('x')) { + std::cerr << "No -l or -x params are provided. At least one of them should " + "be used." + << std::endl; + res.reset(); + res._opts.push_back('?'); + return -3; + } + + if (res._args.empty()) res._args['o'] = defaultOutputName; + + for (; optIndex < argc; optIndex++) { + res._files.push_back(std::string(argv[optIndex])); + } + return 0; } //////////////////////////////////////////////////////////////////////////////// -std::ostream& -GraphOpt::help(std::string app, std::ostream& out) { - out << "Usage: \n" << app << " [-lxe] [-o outname] filename1 " - "[filename2 filename3 ... filenameN]" << std::endl; - out << "Parameters:" << std::endl; - out << "\t-l\t Generate library" << std::endl; - out << "\t-x\t Generate executable" << std::endl; - out << "\t-e\t Embed the Graph(s) into executable as resource" << std::endl; - out << "\t-o Set up output name (for library, executable or both)" << std::endl; - out << "\t-a target CPU architecture" << std::endl; - out << "\t-h\t This help" << std::endl; - - return out; +std::ostream& GraphOpt::help(std::string app, std::ostream& out) { + out << "Usage: \n" + << app + << " [-lxe] [-o outname] filename1 " + "[filename2 filename3 ... filenameN]" + << std::endl; + out << "Parameters:" << std::endl; + out << "\t-l\t Generate library" << std::endl; + out << "\t-x\t Generate executable" << std::endl; + out << "\t-e\t Embed the Graph(s) into executable as resource" << std::endl; + out << "\t-o Set up output name (for library, executable or both)" + << std::endl; + out << "\t-a target CPU architecture" << std::endl; + out << "\t-h\t This help" << std::endl; + + return out; } - diff --git a/libnd4j/minifier/graphopt.h b/libnd4j/minifier/graphopt.h index 329fc22d6a7d..d236e008c5d1 100644 --- a/libnd4j/minifier/graphopt.h +++ b/libnd4j/minifier/graphopt.h @@ -17,8 +17,8 @@ /* * GraphOpt class declarations * - * GraphOpt class used for parsing command line arguments - * + * GraphOpt class used for parsing command line arguments + * * * Created by GS 3/2/2018 * @@ -27,49 +27,51 @@ #ifndef __H__GRAPH_OPTIONS__ #define __H__GRAPH_OPTIONS__ -#include +#include +#include #include +#include #include -#include -#include class GraphOpt { -public: - typedef std::list FileList; - typedef std::list OptionList; - typedef std::unordered_map ArgumentDict; -public: - GraphOpt() - {} + public: + typedef std::list FileList; + typedef std::list OptionList; + typedef std::unordered_map ArgumentDict; - static int optionsWithArgs(int argc, char* argv[], GraphOpt& options); + public: + GraphOpt() {} - FileList& files() { return _files; } - FileList const& files() const { return _files; } - OptionList const& options() const { return _opts; } - std::string outputName() const { return _args.at('o'); } - std::string arch() const { - if (_args.count('a') < 1) { - printf("No Arg!!!\n"); - fflush(stdout); - } - return _args.at('a'); - }; - std::ostream& help(std::string app, std::ostream& out); - bool hasParam(int param) const { return std::find(_opts.begin(), _opts.end(), param) != _opts.end(); } - - friend std::ostream& operator<< (std::ostream& out, GraphOpt const& opts); + static int optionsWithArgs(int argc, char* argv[], GraphOpt& options); - void reset() { - _files.clear(); - _opts.clear(); - _args.clear(); + FileList& files() { return _files; } + FileList const& files() const { return _files; } + OptionList const& options() const { return _opts; } + std::string outputName() const { return _args.at('o'); } + std::string arch() const { + if (_args.count('a') < 1) { + printf("No Arg!!!\n"); + fflush(stdout); } + return _args.at('a'); + }; + std::ostream& help(std::string app, std::ostream& out); + bool hasParam(int param) const { + return std::find(_opts.begin(), _opts.end(), param) != _opts.end(); + } + + friend std::ostream& operator<<(std::ostream& out, GraphOpt const& opts); + + void reset() { + _files.clear(); + _opts.clear(); + _args.clear(); + } -private: - FileList _files; - OptionList _opts; - ArgumentDict _args; + private: + FileList _files; + OptionList _opts; + ArgumentDict _args; }; #endif diff --git a/libnd4j/minifier/minifier.cpp b/libnd4j/minifier/minifier.cpp index 043f2b696fdf..81c2ac0073fe 100644 --- a/libnd4j/minifier/minifier.cpp +++ b/libnd4j/minifier/minifier.cpp @@ -17,140 +17,144 @@ #include #ifdef _WIN32 - #include +#include #else - #include +#include #endif + +#include +#include +#include + #include + #include "graphopt.h" -#include -#include -#include using namespace sd::ops; using namespace sd::graph; -int -main(int argc, char *argv[]) { - // this string will contain list of operations - std::string opts_arg; - - // this string will contain optional name for output binary file - std::string name_arg; - - // this string will contain binary compilation mode: shared/static/executable - std::string build_arg; - - // this string will contain target arch/optimization mode - std::string arch_arg; - - GraphOpt opt; - int err = GraphOpt::optionsWithArgs(argc, argv, opt); - - //std::cout << opt << std::endl; - if (err > 0) { - // only help message - return err; +int main(int argc, char *argv[]) { + // this string will contain list of operations + std::string opts_arg; + + // this string will contain optional name for output binary file + std::string name_arg; + + // this string will contain binary compilation mode: shared/static/executable + std::string build_arg; + + // this string will contain target arch/optimization mode + std::string arch_arg; + + GraphOpt opt; + int err = GraphOpt::optionsWithArgs(argc, argv, opt); + + // std::cout << opt << std::endl; + if (err > 0) { + // only help message + return err; + } + + if (err < 0) { + std::cerr << "Wrong parameter list" << std::endl; + opt.help(argv[0], std::cerr); + return err; + } + + for (int option : opt.options()) { + std::cout << "Option \'" << (char)option << "\': "; + switch (option) { + case 'l': + std::cout << "Build library" << std::endl; + break; + case 'x': + std::cout << "Build executable" << std::endl; + break; + case 'e': + std::cout << "Link the Graph to executable as Resource" << std::endl; + break; + case 'o': + std::cout << "Output file name is " << opt.outputName() << std::endl; + break; + case 'a': + std::cout << "Target arch: " << opt.arch() << std::endl; + break; + default: + std::cerr << "Wrong parameter " << (char)option << std::endl; } + } - if (err < 0) { - std::cerr << "Wrong parameter list" << std::endl; - opt.help(argv[0], std::cerr); - return err; - } - - for (int option: opt.options()) { - std::cout << "Option \'" << (char)option <<"\': "; - switch (option) { - case 'l': - std::cout << "Build library" << std::endl; - break; - case 'x': - std::cout << "Build executable" << std::endl; - break; - case 'e': - std::cout << "Link the Graph to executable as Resource" << std::endl; - break; - case 'o': - std::cout << "Output file name is " << opt.outputName() << std::endl; - break; - case 'a': - std::cout << "Target arch: " << opt.arch() << std::endl; - break; - default: - std::cerr << "Wrong parameter " << (char)option << std::endl; - } - } - - if (!opt.hasParam('o')) { - std::cout << "Ouput file name is " << opt.outputName() << std::endl; - } + if (!opt.hasParam('o')) { + std::cout << "Ouput file name is " << opt.outputName() << std::endl; + } + + name_arg = " --name \'" + opt.outputName() + "\' "; - name_arg = " --name \'" + opt.outputName() + "\' "; + if (opt.hasParam('a')) arch_arg = opt.arch(); - if (opt.hasParam('a')) - arch_arg = opt.arch(); - - std::vector descriptors; - nd4j_printf("Total available operations: %i\n", OpRegistrator::getInstance().numberOfOperations()); + std::vector descriptors; + nd4j_printf("Total available operations: %i\n", + OpRegistrator::getInstance().numberOfOperations()); - for (auto file: opt.files()) { - // all files will be checked for accessibility & size + for (auto file : opt.files()) { + // all files will be checked for accessibility & size #ifdef _WIN32 - if (_access(file.c_str(), 1) != -1) { + if (_access(file.c_str(), 1) != -1) { #else - if (access(file.c_str(), F_OK | R_OK) != -1) { + if (access(file.c_str(), F_OK | R_OK) != -1) { #endif #ifdef _WIN32 - struct _stat st; - _stat(file.c_str(), &st); + struct _stat st; + _stat(file.c_str(), &st); #else - struct stat st; - stat(file.c_str(), &st); -#endif - if (st.st_size != 0) { - //std::cout << "File " << file << " exists and can be read" << std::endl; - auto graph = GraphExecutioner::importFromFlatBuffers(file.c_str()); - auto ops = graph->getOperations(); - - for (auto &v:ops) { - descriptors.emplace_back(v); - } - } else { - std::cerr << "File " << file << " exists, but has zero size" << std::endl; - return 2; - } - } - else { - std::cerr << "File " << file << " does not exists " << std::endl; - return 10; + struct stat st; + stat(file.c_str(), &st); +#endif + if (st.st_size != 0) { + // std::cout << "File " << file << " exists and can be read" << + // std::endl; + auto graph = Graph::fromFlatBuffers(file.c_str()); + auto ops = graph.unmappedNodes(); + + for (auto &v : ops) { + if (v.second.hasCustomOp()) + descriptors.emplace_back(*v.second.customOp()->getOpDescriptor()); } + } else { + std::cerr << "File " << file << " exists, but has zero size" + << std::endl; + return 2; + } + } else { + std::cerr << "File " << file << " does not exists " << std::endl; + return 10; } + } - if (!descriptors.empty()) { - GraphUtils::filterOperations(descriptors); - - nd4j_printf("Operations found so far:\n",""); - for (auto &v: descriptors) { - nd4j_printf("%s\n", v.getOpName()->c_str()); - } + if (!descriptors.empty()) { + GraphUtils::filterOperations(descriptors); - // building list of operations - opts_arg = GraphUtils::makeCommandLine(descriptors); + nd4j_printf("Operations found so far:\n", ""); + for (auto &v : descriptors) { + nd4j_printf("%s\n", v.getOpName()->c_str()); } - nd4j_printf("\n",""); - std::string output(opt.outputName()); + // building list of operations + opts_arg = GraphUtils::makeCommandLine(descriptors); + } + nd4j_printf("\n", ""); - std::string input("../include/ops/declarable/CustomOperations.h"); + std::string output(opt.outputName()); - if (0 == GraphUtils::runPreprocessor(input.c_str(), output.c_str())) { - nd4j_printf("All done successfully.\n", ""); - } + std::string input("../include/ops/declarable/CustomOperations.h"); + + if (0 == GraphUtils::runPreprocessor(input.c_str(), output.c_str())) { + nd4j_printf("All done successfully.\n", ""); + } - //nd4j_printf("Command line: %s\n", cmdline.c_str()); - // FIXME: do this in cross-platform way - nd4j_printf("Building minified library...\n", ""); + // nd4j_printf("Command line: %s\n", cmdline.c_str()); + // FIXME: do this in cross-platform way + nd4j_printf("Building minified library...\n", ""); - return EXIT_SUCCESS; + return EXIT_SUCCESS; } diff --git a/libnd4j/server/GraphServer.cpp b/libnd4j/server/GraphServer.cpp index b7615dd5cd89..d354ad35c548 100644 --- a/libnd4j/server/GraphServer.cpp +++ b/libnd4j/server/GraphServer.cpp @@ -19,121 +19,132 @@ // #include "GraphServer.h" -#include + #include +#include +#include +#include +#include +#include #include #include + #include #include -#include -#include -#include -#include +namespace sd { +namespace graph { +grpc::Status GraphInferenceServerImpl::RegisterGraph( + grpc::ServerContext *context, + const flatbuffers::grpc::Message *request_msg, + flatbuffers::grpc::Message *response_msg) { + auto flat_graph = request_msg->GetRoot(); + + try { + // building our graph + auto graph = new Graph(flat_graph); + + // single data type for now + GraphHolder::getInstance().registerGraph(flat_graph->id(), graph); + + // sending out OK response + auto response_offset = CreateFlatResponse(mb_, 0); + mb_.Finish(response_offset); + *response_msg = mb_.ReleaseMessage(); + assert(response_msg->Verify()); + + return grpc::Status::OK; + } catch (std::runtime_error &e) { + grpc::string gmsg("Caught runtime_error exception"); + return grpc::Status(grpc::StatusCode::UNKNOWN, gmsg); + } +} +grpc::Status GraphInferenceServerImpl::ReplaceGraph( + grpc::ServerContext *context, + const flatbuffers::grpc::Message *request_msg, + flatbuffers::grpc::Message *response_msg) { + auto flat_graph = request_msg->GetRoot(); + + try { + // building our graph + auto graph = new Graph(flat_graph); + + // single data type for now + GraphHolder::getInstance().replaceGraph(flat_graph->id(), graph); + + // sending out OK response + auto response_offset = CreateFlatResponse(mb_, 0); + mb_.Finish(response_offset); + *response_msg = mb_.ReleaseMessage(); + assert(response_msg->Verify()); + + return grpc::Status::OK; + } catch (sd::graph::unknown_graph_exception &e) { + grpc::string gmsg(e.message()); + return grpc::Status(grpc::StatusCode::NOT_FOUND, gmsg); + } catch (std::runtime_error &e) { + grpc::string gmsg("Caught runtime_error exception"); + return grpc::Status(grpc::StatusCode::UNKNOWN, gmsg); + } +} +grpc::Status GraphInferenceServerImpl::ForgetGraph( + grpc::ServerContext *context, + const flatbuffers::grpc::Message *request_msg, + flatbuffers::grpc::Message *response_msg) { + try { + // getting drop request + auto request = request_msg->GetRoot(); + + // dropping out graph (any datatype) + GraphHolder::getInstance().dropGraphAny(request->id()); + + // sending out OK response + auto response_offset = CreateFlatResponse(mb_, 0); + mb_.Finish(response_offset); + *response_msg = mb_.ReleaseMessage(); + assert(response_msg->Verify()); + + return grpc::Status::OK; + } catch (sd::graph::unknown_graph_exception &e) { + grpc::string gmsg(e.message()); + return grpc::Status(grpc::StatusCode::NOT_FOUND, gmsg); + } +} -namespace sd { - namespace graph { - grpc::Status GraphInferenceServerImpl::RegisterGraph( grpc::ServerContext *context, const flatbuffers::grpc::Message *request_msg, flatbuffers::grpc::Message *response_msg) { - auto flat_graph = request_msg->GetRoot(); - - try { - // building our graph - auto graph = new Graph(flat_graph); - - // single data type for now - GraphHolder::getInstance().registerGraph(flat_graph->id(), graph); - - // sending out OK response - auto response_offset = CreateFlatResponse(mb_, 0); - mb_.Finish(response_offset); - *response_msg = mb_.ReleaseMessage(); - assert(response_msg->Verify()); - - return grpc::Status::OK; - } catch (std::runtime_error &e) { - grpc::string gmsg("Caught runtime_error exception"); - return grpc::Status(grpc::StatusCode::UNKNOWN, gmsg); - } - } - - grpc::Status GraphInferenceServerImpl::ReplaceGraph( grpc::ServerContext *context, const flatbuffers::grpc::Message *request_msg, flatbuffers::grpc::Message *response_msg) { - auto flat_graph = request_msg->GetRoot(); - - try { - // building our graph - auto graph = new Graph(flat_graph); - - // single data type for now - GraphHolder::getInstance().replaceGraph(flat_graph->id(), graph); - - // sending out OK response - auto response_offset = CreateFlatResponse(mb_, 0); - mb_.Finish(response_offset); - *response_msg = mb_.ReleaseMessage(); - assert(response_msg->Verify()); - - return grpc::Status::OK; - } catch (sd::graph::unknown_graph_exception &e) { - grpc::string gmsg(e.message()); - return grpc::Status(grpc::StatusCode::NOT_FOUND, gmsg); - } catch (std::runtime_error &e) { - grpc::string gmsg("Caught runtime_error exception"); - return grpc::Status(grpc::StatusCode::UNKNOWN, gmsg); - } - } - - grpc::Status GraphInferenceServerImpl::ForgetGraph( grpc::ServerContext *context, const flatbuffers::grpc::Message *request_msg, flatbuffers::grpc::Message *response_msg) { - try { - - // getting drop request - auto request = request_msg->GetRoot(); - - // dropping out graph (any datatype) - GraphHolder::getInstance().dropGraphAny(request->id()); - - // sending out OK response - auto response_offset = CreateFlatResponse(mb_, 0); - mb_.Finish(response_offset); - *response_msg = mb_.ReleaseMessage(); - assert(response_msg->Verify()); - - return grpc::Status::OK; - } catch (sd::graph::unknown_graph_exception &e) { - grpc::string gmsg(e.message()); - return grpc::Status(grpc::StatusCode::NOT_FOUND, gmsg); - } - } - - grpc::Status GraphInferenceServerImpl::InferenceRequest( grpc::ServerContext *context, const flatbuffers::grpc::Message *request_msg, flatbuffers::grpc::Message *response_msg) { - auto request = request_msg->GetRoot(); - - try { - // GraphHolder - auto response_offset = GraphHolder::getInstance().execute(request->id(), mb_, request); - - mb_.Finish(response_offset); - *response_msg = mb_.ReleaseMessage(); - assert(response_msg->Verify()); - - return grpc::Status::OK; - } catch (sd::graph::no_results_exception &e) { - grpc::string gmsg(e.message()); - return grpc::Status(grpc::StatusCode::INTERNAL, gmsg); - } catch (sd::graph::unknown_graph_exception &e) { - grpc::string gmsg(e.message()); - return grpc::Status(grpc::StatusCode::NOT_FOUND, gmsg); - } catch (sd::graph::graph_execution_exception &e) { - grpc::string gmsg(e.message()); - return grpc::Status(grpc::StatusCode::INTERNAL, gmsg); - } catch (std::runtime_error &e) { - grpc::string gmsg("Caught runtime_error exception"); - return grpc::Status(grpc::StatusCode::UNKNOWN, gmsg); - } - } - } +grpc::Status GraphInferenceServerImpl::InferenceRequest( + grpc::ServerContext *context, + const flatbuffers::grpc::Message *request_msg, + flatbuffers::grpc::Message *response_msg) { + auto request = request_msg->GetRoot(); + + try { + // GraphHolder + auto response_offset = + GraphHolder::getInstance().execute(request->id(), mb_, request); + + mb_.Finish(response_offset); + *response_msg = mb_.ReleaseMessage(); + assert(response_msg->Verify()); + + return grpc::Status::OK; + } catch (sd::graph::no_results_exception &e) { + grpc::string gmsg(e.message()); + return grpc::Status(grpc::StatusCode::INTERNAL, gmsg); + } catch (sd::graph::unknown_graph_exception &e) { + grpc::string gmsg(e.message()); + return grpc::Status(grpc::StatusCode::NOT_FOUND, gmsg); + } catch (sd::graph::graph_execution_exception &e) { + grpc::string gmsg(e.message()); + return grpc::Status(grpc::StatusCode::INTERNAL, gmsg); + } catch (std::runtime_error &e) { + grpc::string gmsg("Caught runtime_error exception"); + return grpc::Status(grpc::StatusCode::UNKNOWN, gmsg); + } } +} // namespace graph +} // namespace sd void RunServer(int port) { assert(port > 0 && port < 65535); @@ -148,43 +159,44 @@ void RunServer(int port) { builder.AddListeningPort(server_address, grpc::InsecureServerCredentials()); builder.RegisterService(&service); std::unique_ptr server(builder.BuildAndStart()); - std::cerr << "Server listening on: [" << server_address << "]; Number of operations: [" << registrator->numberOfOperations() << "]"<< std::endl; + std::cerr << "Server listening on: [" << server_address + << "]; Number of operations: [" << registrator->numberOfOperations() + << "]" << std::endl; server->Wait(); } -char* getCmdOption(char **begin, char **end, const std::string & option) { - auto itr = std::find(begin, end, option); - if (itr != end && ++itr != end) - return *itr; +char *getCmdOption(char **begin, char **end, const std::string &option) { + auto itr = std::find(begin, end, option); + if (itr != end && ++itr != end) return *itr; - return 0; + return 0; } -bool cmdOptionExists(char** begin, char** end, const std::string& option) { - return std::find(begin, end, option) != end; +bool cmdOptionExists(char **begin, char **end, const std::string &option) { + return std::find(begin, end, option) != end; } int main(int argc, char *argv[]) { - /** - * basically we only care about few things here: - * 1) port number - * 2) if we should use gprc, json, or both - * 3) if there's any graph(s) provided at startup - */ - int port = 40123; - if(cmdOptionExists(argv, argv+argc, "-p")) { - auto sPort = getCmdOption(argv, argv + argc, "-p"); - port = atoi(sPort); - } - - if(cmdOptionExists(argv, argv+argc, "-f")) { - auto file = getCmdOption(argv, argv + argc, "-f"); - auto graph = GraphExecutioner::importFromFlatBuffers(file); - sd::graph::GraphHolder::getInstance().registerGraph(0L, graph); - } - - RunServer(port); - - return 0; + /** + * basically we only care about few things here: + * 1) port number + * 2) if we should use gprc, json, or both + * 3) if there's any graph(s) provided at startup + */ + int port = 40123; + if (cmdOptionExists(argv, argv + argc, "-p")) { + auto sPort = getCmdOption(argv, argv + argc, "-p"); + port = atoi(sPort); + } + + if (cmdOptionExists(argv, argv + argc, "-f")) { + auto file = getCmdOption(argv, argv + argc, "-f"); + auto graph = GraphExecutioner::importFromFlatBuffers(file); + sd::graph::GraphHolder::getInstance().registerGraph(0L, graph); + } + + RunServer(port); + + return 0; } \ No newline at end of file diff --git a/libnd4j/server/GraphServer.h b/libnd4j/server/GraphServer.h index 0dceacf25a33..4be077443e61 100644 --- a/libnd4j/server/GraphServer.h +++ b/libnd4j/server/GraphServer.h @@ -18,27 +18,38 @@ // @author raver119@gmail.com // - -#include #include #include -#include - #include +#include +#include namespace sd { - namespace graph { - class GraphInferenceServerImpl final : public GraphInferenceServer::Service { - private: - flatbuffers::grpc::MessageBuilder mb_; - public: - virtual grpc::Status RegisterGraph( grpc::ServerContext *context, const flatbuffers::grpc::Message *request_msg, flatbuffers::grpc::Message *response_msg); - - virtual grpc::Status ForgetGraph( grpc::ServerContext *context, const flatbuffers::grpc::Message *request_msg, flatbuffers::grpc::Message *response_msg); - - virtual grpc::Status ReplaceGraph( grpc::ServerContext *context, const flatbuffers::grpc::Message *request_msg, flatbuffers::grpc::Message *response_msg); - - virtual grpc::Status InferenceRequest( grpc::ServerContext *context, const flatbuffers::grpc::Message *request_msg, flatbuffers::grpc::Message *response_msg); - }; - } -} \ No newline at end of file +namespace graph { +class GraphInferenceServerImpl final : public GraphInferenceServer::Service { + private: + flatbuffers::grpc::MessageBuilder mb_; + + public: + virtual grpc::Status RegisterGraph( + grpc::ServerContext *context, + const flatbuffers::grpc::Message *request_msg, + flatbuffers::grpc::Message *response_msg); + + virtual grpc::Status ForgetGraph( + grpc::ServerContext *context, + const flatbuffers::grpc::Message *request_msg, + flatbuffers::grpc::Message *response_msg); + + virtual grpc::Status ReplaceGraph( + grpc::ServerContext *context, + const flatbuffers::grpc::Message *request_msg, + flatbuffers::grpc::Message *response_msg); + + virtual grpc::Status InferenceRequest( + grpc::ServerContext *context, + const flatbuffers::grpc::Message *request_msg, + flatbuffers::grpc::Message *response_msg); +}; +} // namespace graph +} // namespace sd \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/AllTests.cpp b/libnd4j/tests_cpu/layers_tests/AllTests.cpp index 669a5b1d0bed..32c84a4704c9 100644 --- a/libnd4j/tests_cpu/layers_tests/AllTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/AllTests.cpp @@ -20,19 +20,19 @@ // #include "testlayers.h" /* +#include "ConvolutionTests.cpp" +#include "DeclarableOpsTests.cpp" #include "DenseLayerTests.cpp" +#include "FlatBuffersTests.cpp" +#include "GraphTests.cpp" +#include "HashUtilsTests.cpp" #include "NDArrayTests.cpp" +#include "SessionLocalTests.cpp" +#include "StashTests.cpp" +#include "TadTests.cpp" #include "VariableSpaceTests.cpp" #include "VariableTests.cpp" -#include "DeclarableOpsTests.cpp" -#include "HashUtilsTests.cpp" #include "WorkspaceTests.cpp" -#include "ConvolutionTests.cpp" -#include "TadTests.cpp" -#include "StashTests.cpp" -#include "SessionLocalTests.cpp" -#include "GraphTests.cpp" -#include "FlatBuffersTests.cpp" */ /////// @@ -40,6 +40,6 @@ // #include "ProtoBufTests.cpp" int main(int argc, char **argv) { - testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/ArrayOptionsTests.cpp b/libnd4j/tests_cpu/layers_tests/ArrayOptionsTests.cpp index 97893ca5bc8d..507d5f744ea2 100644 --- a/libnd4j/tests_cpu/layers_tests/ArrayOptionsTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/ArrayOptionsTests.cpp @@ -18,92 +18,89 @@ // Created by raver119 on 13.01.2018. // -#include "testlayers.h" #include #include -using namespace sd; +#include "testlayers.h" +using namespace sd; class ArrayOptionsTests : public testing::Test { -public: - Nd4jLong shape[8] = {2, 5, 5, 5, 1, 0, 1, 99}; + public: + Nd4jLong shape[8] = {2, 5, 5, 5, 1, 0, 1, 99}; }; TEST_F(ArrayOptionsTests, TestShape_Basic_0) { - shape[5] = 1; + shape[5] = 1; - ASSERT_TRUE(ArrayOptions::isNewFormat(shape)); - ASSERT_FALSE(ArrayOptions::isSparseArray(shape)); + ASSERT_TRUE(ArrayOptions::isNewFormat(shape)); + ASSERT_FALSE(ArrayOptions::isSparseArray(shape)); } - TEST_F(ArrayOptionsTests, TestShape_Basic_1) { - shape[5] = 2; - + shape[5] = 2; - ASSERT_TRUE(ArrayOptions::isNewFormat(shape)); - ASSERT_TRUE(ArrayOptions::isSparseArray(shape)); + ASSERT_TRUE(ArrayOptions::isNewFormat(shape)); + ASSERT_TRUE(ArrayOptions::isSparseArray(shape)); } TEST_F(ArrayOptionsTests, TestShape_Basic_2) { - shape[5] = 258; + shape[5] = 258; - ASSERT_TRUE(ArrayOptions::isNewFormat(shape)); + ASSERT_TRUE(ArrayOptions::isNewFormat(shape)); - ASSERT_TRUE(ArrayOptions::isSparseArray(shape)); - ASSERT_EQ(SpaceType::CONTINUOUS, ArrayOptions::spaceType(shape)); + ASSERT_TRUE(ArrayOptions::isSparseArray(shape)); + ASSERT_EQ(SpaceType::CONTINUOUS, ArrayOptions::spaceType(shape)); } TEST_F(ArrayOptionsTests, TestShape_Basic_3) { - ASSERT_EQ(0, shape::extra(shape)); + ASSERT_EQ(0, shape::extra(shape)); - ASSERT_EQ(SpaceType::CONTINUOUS, ArrayOptions::spaceType(shape)); + ASSERT_EQ(SpaceType::CONTINUOUS, ArrayOptions::spaceType(shape)); } TEST_F(ArrayOptionsTests, TestShape_Basic_4) { + ArrayOptions::setPropertyBits(shape, {ARRAY_HALF, ARRAY_QUANTIZED}); - ArrayOptions::setPropertyBits(shape, {ARRAY_HALF, ARRAY_QUANTIZED}); - - auto dtype = ArrayOptions::dataType(shape); + auto dtype = ArrayOptions::dataType(shape); - ASSERT_FALSE(ArrayOptions::isSparseArray(shape)); - ASSERT_TRUE(sd::DataType::HALF == ArrayOptions::dataType(shape)); - ASSERT_EQ(sd::ArrayType::DENSE, ArrayOptions::arrayType(shape)); - ASSERT_EQ(sd::SpaceType::QUANTIZED, ArrayOptions::spaceType(shape)); + ASSERT_FALSE(ArrayOptions::isSparseArray(shape)); + ASSERT_TRUE(sd::DataType::HALF == ArrayOptions::dataType(shape)); + ASSERT_EQ(sd::ArrayType::DENSE, ArrayOptions::arrayType(shape)); + ASSERT_EQ(sd::SpaceType::QUANTIZED, ArrayOptions::spaceType(shape)); } TEST_F(ArrayOptionsTests, TestShape_Basic_5) { - ArrayOptions::setPropertyBits(shape, {ARRAY_SPARSE, ARRAY_INT, ARRAY_CSC}); + ArrayOptions::setPropertyBits(shape, {ARRAY_SPARSE, ARRAY_INT, ARRAY_CSC}); - ASSERT_TRUE(ArrayOptions::isSparseArray(shape)); - ASSERT_TRUE(sd::DataType::INT32 == ArrayOptions::dataType(shape)); - ASSERT_EQ(sd::SparseType::CSC, ArrayOptions::sparseType(shape)); + ASSERT_TRUE(ArrayOptions::isSparseArray(shape)); + ASSERT_TRUE(sd::DataType::INT32 == ArrayOptions::dataType(shape)); + ASSERT_EQ(sd::SparseType::CSC, ArrayOptions::sparseType(shape)); } TEST_F(ArrayOptionsTests, TestShape_Basic_6) { - ArrayOptions::setPropertyBits(shape, {ARRAY_EMPTY, ARRAY_INT, ARRAY_CSC}); + ArrayOptions::setPropertyBits(shape, {ARRAY_EMPTY, ARRAY_INT, ARRAY_CSC}); - ASSERT_EQ(sd::ArrayType::EMPTY, ArrayOptions::arrayType(shape)); + ASSERT_EQ(sd::ArrayType::EMPTY, ArrayOptions::arrayType(shape)); } TEST_F(ArrayOptionsTests, TestShape_Basic_7) { - ArrayOptions::setDataType(shape, sd::DataType::FLOAT32); - ArrayOptions::setDataType(shape, sd::DataType::FLOAT32); + ArrayOptions::setDataType(shape, sd::DataType::FLOAT32); + ArrayOptions::setDataType(shape, sd::DataType::FLOAT32); - ASSERT_EQ(sd::DataType::FLOAT32, ArrayOptions::dataType(shape)); + ASSERT_EQ(sd::DataType::FLOAT32, ArrayOptions::dataType(shape)); } TEST_F(ArrayOptionsTests, TestShape_Basic_8) { - ArrayOptions::setDataType(shape, sd::DataType::DOUBLE); - ArrayOptions::setDataType(shape, sd::DataType::FLOAT32); + ArrayOptions::setDataType(shape, sd::DataType::DOUBLE); + ArrayOptions::setDataType(shape, sd::DataType::FLOAT32); - ASSERT_EQ(sd::DataType::FLOAT32, ArrayOptions::dataType(shape)); + ASSERT_EQ(sd::DataType::FLOAT32, ArrayOptions::dataType(shape)); } TEST_F(ArrayOptionsTests, TestShape_Basic_9) { - ArrayOptions::setDataType(shape, sd::DataType::FLOAT32); - ArrayOptions::setDataType(shape, sd::DataType::DOUBLE); + ArrayOptions::setDataType(shape, sd::DataType::FLOAT32); + ArrayOptions::setDataType(shape, sd::DataType::DOUBLE); - ASSERT_EQ(sd::DataType::DOUBLE, ArrayOptions::dataType(shape)); + ASSERT_EQ(sd::DataType::DOUBLE, ArrayOptions::dataType(shape)); } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/AtomicTests.cu b/libnd4j/tests_cpu/layers_tests/AtomicTests.cu index f8248e7ea52f..1030edd88678 100644 --- a/libnd4j/tests_cpu/layers_tests/AtomicTests.cu +++ b/libnd4j/tests_cpu/layers_tests/AtomicTests.cu @@ -18,224 +18,253 @@ // @author raver119@gmail.com // -#include "testlayers.h" -#include #include -#include +#include #include #include -#include +#include +#include +#include "testlayers.h" using namespace sd; - class AtomicTests : public testing::Test { -public: - AtomicTests() { - // - } + public: + AtomicTests() { + // + } }; template -static _CUDA_G void multiplyKernel(void *vbuffer, uint64_t length, void *vresult) { - auto buffer = reinterpret_cast(vbuffer); - auto result = reinterpret_cast(vresult); +static _CUDA_G void multiplyKernel(void *vbuffer, uint64_t length, + void *vresult) { + auto buffer = reinterpret_cast(vbuffer); + auto result = reinterpret_cast(vresult); - auto tid = blockIdx.x * blockDim.x + threadIdx.x; + auto tid = blockIdx.x * blockDim.x + threadIdx.x; - for (auto e = tid; e < length; e += gridDim.x * blockDim.x) { - auto rem = e % 4; - auto i = (e - rem) / 4; + for (auto e = tid; e < length; e += gridDim.x * blockDim.x) { + auto rem = e % 4; + auto i = (e - rem) / 4; - sd::math::atomics::nd4j_atomicMul(&result[i], buffer[e]); - } + sd::math::atomics::nd4j_atomicMul(&result[i], buffer[e]); + } } template static void multiplyLauncher(void *vbuffer, uint64_t length, void *vresult) { - multiplyKernel<<<256, 256, 1024, *sd::LaunchContext::defaultContext()->getCudaStream()>>>(vbuffer, length, vresult); - auto err = cudaStreamSynchronize(*sd::LaunchContext::defaultContext()->getCudaStream()); - if (err != 0) - throw sd::cuda_exception::build("multiply failed", err); + multiplyKernel<<<256, 256, 1024, + *sd::LaunchContext::defaultContext()->getCudaStream()>>>( + vbuffer, length, vresult); + auto err = cudaStreamSynchronize( + *sd::LaunchContext::defaultContext()->getCudaStream()); + if (err != 0) throw sd::cuda_exception::build("multiply failed", err); } template static _CUDA_G void sumKernel(void *vbuffer, uint64_t length, void *vresult) { - auto buffer = reinterpret_cast(vbuffer); - auto result = reinterpret_cast(vresult); + auto buffer = reinterpret_cast(vbuffer); + auto result = reinterpret_cast(vresult); - auto tid = blockIdx.x * blockDim.x + threadIdx.x; + auto tid = blockIdx.x * blockDim.x + threadIdx.x; - for (auto e = tid; e < length; e += gridDim.x * blockDim.x) { - auto rem = e % 4; - auto i = (e - rem) / 4; + for (auto e = tid; e < length; e += gridDim.x * blockDim.x) { + auto rem = e % 4; + auto i = (e - rem) / 4; - sd::math::atomics::nd4j_atomicAdd(&result[i], buffer[e]); - } + sd::math::atomics::nd4j_atomicAdd(&result[i], buffer[e]); + } } template static void sumLauncher(void *vbuffer, uint64_t length, void *vresult) { - sumKernel<<<256, 256, 1024, *sd::LaunchContext::defaultContext()->getCudaStream()>>>(vbuffer, length, vresult); - auto err = cudaStreamSynchronize(*sd::LaunchContext::defaultContext()->getCudaStream()); - if (err != 0) - throw sd::cuda_exception::build("sum failed", err); + sumKernel<<<256, 256, 1024, + *sd::LaunchContext::defaultContext()->getCudaStream()>>>( + vbuffer, length, vresult); + auto err = cudaStreamSynchronize( + *sd::LaunchContext::defaultContext()->getCudaStream()); + if (err != 0) throw sd::cuda_exception::build("sum failed", err); } template static _CUDA_G void subKernel(void *vbuffer, uint64_t length, void *vresult) { - auto buffer = reinterpret_cast(vbuffer); - auto result = reinterpret_cast(vresult); + auto buffer = reinterpret_cast(vbuffer); + auto result = reinterpret_cast(vresult); - auto tid = blockIdx.x * blockDim.x + threadIdx.x; + auto tid = blockIdx.x * blockDim.x + threadIdx.x; - for (auto e = tid; e < length; e += gridDim.x * blockDim.x) { - auto rem = e % 4; - auto i = (e - rem) / 4; + for (auto e = tid; e < length; e += gridDim.x * blockDim.x) { + auto rem = e % 4; + auto i = (e - rem) / 4; - sd::math::atomics::nd4j_atomicSub(&result[i], buffer[e]); - } + sd::math::atomics::nd4j_atomicSub(&result[i], buffer[e]); + } } template static void subLauncher(void *vbuffer, uint64_t length, void *vresult) { - subKernel<<<256, 256, 1024, *sd::LaunchContext::defaultContext()->getCudaStream()>>>(vbuffer, length, vresult); - auto err = cudaStreamSynchronize(*sd::LaunchContext::defaultContext()->getCudaStream()); - if (err != 0) - throw sd::cuda_exception::build("sub failed", err); + subKernel<<<256, 256, 1024, + *sd::LaunchContext::defaultContext()->getCudaStream()>>>( + vbuffer, length, vresult); + auto err = cudaStreamSynchronize( + *sd::LaunchContext::defaultContext()->getCudaStream()); + if (err != 0) throw sd::cuda_exception::build("sub failed", err); } template static _CUDA_G void divKernel(void *vbuffer, uint64_t length, void *vresult) { - auto buffer = reinterpret_cast(vbuffer); - auto result = reinterpret_cast(vresult); + auto buffer = reinterpret_cast(vbuffer); + auto result = reinterpret_cast(vresult); - auto tid = blockIdx.x * blockDim.x + threadIdx.x; + auto tid = blockIdx.x * blockDim.x + threadIdx.x; - for (auto e = tid; e < length; e += gridDim.x * blockDim.x) { - auto rem = e % 4; - auto i = (e - rem) / 4; + for (auto e = tid; e < length; e += gridDim.x * blockDim.x) { + auto rem = e % 4; + auto i = (e - rem) / 4; - sd::math::atomics::nd4j_atomicDiv(&result[i], buffer[e]); - } + sd::math::atomics::nd4j_atomicDiv(&result[i], buffer[e]); + } } template static void divLauncher(void *vbuffer, uint64_t length, void *vresult) { - divKernel<<<256, 256, 1024, *sd::LaunchContext::defaultContext()->getCudaStream()>>>(vbuffer, length, vresult); - auto err = cudaStreamSynchronize(*sd::LaunchContext::defaultContext()->getCudaStream()); - if (err != 0) - throw sd::cuda_exception::build("div failed", err); + divKernel<<<256, 256, 1024, + *sd::LaunchContext::defaultContext()->getCudaStream()>>>( + vbuffer, length, vresult); + auto err = cudaStreamSynchronize( + *sd::LaunchContext::defaultContext()->getCudaStream()); + if (err != 0) throw sd::cuda_exception::build("div failed", err); } static void multiplyHost(NDArray &input, NDArray &output) { - BUILD_SINGLE_SELECTOR(input.dataType(), multiplyLauncher, (input.specialBuffer(), input.lengthOf(), output.specialBuffer()), NUMERIC_TYPES); + BUILD_SINGLE_SELECTOR( + input.dataType(), multiplyLauncher, + (input.specialBuffer(), input.lengthOf(), output.specialBuffer()), + NUMERIC_TYPES); } static void sumHost(NDArray &input, NDArray &output) { - BUILD_SINGLE_SELECTOR(input.dataType(), sumLauncher, (input.specialBuffer(), input.lengthOf(), output.specialBuffer()), NUMERIC_TYPES); + BUILD_SINGLE_SELECTOR( + input.dataType(), sumLauncher, + (input.specialBuffer(), input.lengthOf(), output.specialBuffer()), + NUMERIC_TYPES); } static void subHost(NDArray &input, NDArray &output) { - BUILD_SINGLE_SELECTOR(input.dataType(), subLauncher, (input.specialBuffer(), input.lengthOf(), output.specialBuffer()), FLOAT_TYPES); + BUILD_SINGLE_SELECTOR( + input.dataType(), subLauncher, + (input.specialBuffer(), input.lengthOf(), output.specialBuffer()), + FLOAT_TYPES); } static void divHost(NDArray &input, NDArray &output) { - BUILD_SINGLE_SELECTOR(input.dataType(), divLauncher, (input.specialBuffer(), input.lengthOf(), output.specialBuffer()), FLOAT_TYPES); + BUILD_SINGLE_SELECTOR( + input.dataType(), divLauncher, + (input.specialBuffer(), input.lengthOf(), output.specialBuffer()), + FLOAT_TYPES); } TEST_F(AtomicTests, test_multiply) { - std::vector dtypes = {sd::DataType::FLOAT32, sd::DataType::DOUBLE, sd::DataType::INT16, sd::DataType::HALF}; - - for (auto t:dtypes) { - nd4j_printf("Trying data type [%s]\n", DataTypeUtils::asString(t).c_str()); - NDArray input('c', {4, 25}, t); - NDArray output('c', {input.lengthOf() / 4}, t); - NDArray exp = output.ulike(); - - input.assign(2); - output.assign(2); - exp.assign(32); - - multiplyHost(input, output); - ASSERT_EQ(exp, output); - } + std::vector dtypes = {sd::DataType::FLOAT32, + sd::DataType::DOUBLE, sd::DataType::INT16, + sd::DataType::HALF}; + + for (auto t : dtypes) { + nd4j_printf("Trying data type [%s]\n", DataTypeUtils::asString(t).c_str()); + NDArray input('c', {4, 25}, t); + NDArray output('c', {input.lengthOf() / 4}, t); + NDArray exp = output.ulike(); + + input.assign(2); + output.assign(2); + exp.assign(32); + + multiplyHost(input, output); + ASSERT_EQ(exp, output); + } } TEST_F(AtomicTests, test_multiply_2) { - std::vector dtypes = {sd::DataType::FLOAT32, sd::DataType::DOUBLE, sd::DataType::HALF, sd::DataType::BFLOAT16}; - - for (auto t:dtypes) { - nd4j_printf("Trying data type [%s]\n", DataTypeUtils::asString(t).c_str()); - NDArray input('c', {4, 25}, t); - NDArray output('c', {input.lengthOf() / 4}, t); - NDArray exp = output.ulike(); - - input.assign(1.5); - output.assign(2); - exp.assign(10.125); - - multiplyHost(input, output); -// output.printBuffer("multiply 2"); - ASSERT_EQ(exp, output); - } + std::vector dtypes = {sd::DataType::FLOAT32, + sd::DataType::DOUBLE, sd::DataType::HALF, + sd::DataType::BFLOAT16}; + + for (auto t : dtypes) { + nd4j_printf("Trying data type [%s]\n", DataTypeUtils::asString(t).c_str()); + NDArray input('c', {4, 25}, t); + NDArray output('c', {input.lengthOf() / 4}, t); + NDArray exp = output.ulike(); + + input.assign(1.5); + output.assign(2); + exp.assign(10.125); + + multiplyHost(input, output); + // output.printBuffer("multiply 2"); + ASSERT_EQ(exp, output); + } } TEST_F(AtomicTests, test_sum) { - std::vector dtypes = {sd::DataType::FLOAT32, sd::DataType::DOUBLE, sd::DataType::BFLOAT16, sd::DataType::HALF, sd::DataType::INT16}; - - for (auto t:dtypes) { - nd4j_printf("Trying data type [%s]\n", DataTypeUtils::asString(t).c_str()); - NDArray input('c', {4, 25}, t); - NDArray output('c', {input.lengthOf() / 4}, t); - NDArray exp = output.ulike(); - - input.assign(1); - output.assign(1); - exp.assign(5); - - sumHost(input, output); -// output.printIndexedBuffer("Sum"); - ASSERT_EQ(exp, output); - } + std::vector dtypes = { + sd::DataType::FLOAT32, sd::DataType::DOUBLE, sd::DataType::BFLOAT16, + sd::DataType::HALF, sd::DataType::INT16}; + + for (auto t : dtypes) { + nd4j_printf("Trying data type [%s]\n", DataTypeUtils::asString(t).c_str()); + NDArray input('c', {4, 25}, t); + NDArray output('c', {input.lengthOf() / 4}, t); + NDArray exp = output.ulike(); + + input.assign(1); + output.assign(1); + exp.assign(5); + + sumHost(input, output); + // output.printIndexedBuffer("Sum"); + ASSERT_EQ(exp, output); + } } TEST_F(AtomicTests, test_sub) { - std::vector dtypes = {sd::DataType::FLOAT32, sd::DataType::DOUBLE, sd::DataType::HALF}; + std::vector dtypes = {sd::DataType::FLOAT32, + sd::DataType::DOUBLE, sd::DataType::HALF}; - for (auto t:dtypes) { - nd4j_printf("Trying data type [%s]\n", DataTypeUtils::asString(t).c_str()); - NDArray input('c', {4, 25}, t); - NDArray output('c', {input.lengthOf() / 4}, t); - NDArray exp = output.ulike(); + for (auto t : dtypes) { + nd4j_printf("Trying data type [%s]\n", DataTypeUtils::asString(t).c_str()); + NDArray input('c', {4, 25}, t); + NDArray output('c', {input.lengthOf() / 4}, t); + NDArray exp = output.ulike(); - input.assign(1); - output.assign(5); - exp.assign(1); + input.assign(1); + output.assign(5); + exp.assign(1); - subHost(input, output); -// output.printBuffer("Sub"); + subHost(input, output); + // output.printBuffer("Sub"); - ASSERT_EQ(exp, output); - } + ASSERT_EQ(exp, output); + } } TEST_F(AtomicTests, test_div) { - std::vector dtypes = {sd::DataType::FLOAT32, sd::DataType::DOUBLE, sd::DataType::BFLOAT16, sd::DataType::HALF}; - - for (auto t:dtypes) { - nd4j_printf("Trying data type [%s]\n", DataTypeUtils::asString(t).c_str()); - NDArray input('c', {4, 25}, t); - NDArray output('c', {input.lengthOf() / 4}, t); - NDArray exp = output.ulike(); - - input.assign(2); - output.assign(32); - exp.assign(2); - - divHost(input, output); -// output.printBuffer("Div"); - ASSERT_EQ(exp, output); - } + std::vector dtypes = { + sd::DataType::FLOAT32, sd::DataType::DOUBLE, sd::DataType::BFLOAT16, + sd::DataType::HALF}; + + for (auto t : dtypes) { + nd4j_printf("Trying data type [%s]\n", DataTypeUtils::asString(t).c_str()); + NDArray input('c', {4, 25}, t); + NDArray output('c', {input.lengthOf() / 4}, t); + NDArray exp = output.ulike(); + + input.assign(2); + output.assign(32); + exp.assign(2); + + divHost(input, output); + // output.printBuffer("Div"); + ASSERT_EQ(exp, output); + } } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/AttentionTests.cpp b/libnd4j/tests_cpu/layers_tests/AttentionTests.cpp index ab6d50b53895..7b7066ed5e3b 100644 --- a/libnd4j/tests_cpu/layers_tests/AttentionTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/AttentionTests.cpp @@ -18,37 +18,37 @@ // @author raver119@gmail.com // -#include "testlayers.h" -#include #include -#include #include #include +#include +#include +#include "testlayers.h" using namespace sd; - class AttentionTests : public testing::Test { -public: - AttentionTests() { - printf("\n"); - fflush(stdout); - } + public: + AttentionTests() { + printf("\n"); + fflush(stdout); + } }; TEST_F(AttentionTests, basic_dot_product_attention) { - auto keys = NDArrayFactory::create('c', {10, 4, 3}); - auto values = NDArrayFactory::create('c', {10, 4, 3}); - auto queries = NDArrayFactory::create('c', {10, 4, 1}); + auto keys = NDArrayFactory::create('c', {10, 4, 3}); + auto values = NDArrayFactory::create('c', {10, 4, 3}); + auto queries = NDArrayFactory::create('c', {10, 4, 1}); - sd::ops::dot_product_attention op; - auto result = op.evaluate({&queries, &keys, &values}, {1, 0}); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::dot_product_attention op; + auto result = op.evaluate({&queries, &keys, &values}, {1, 0}); + ASSERT_EQ(Status::OK(), result.status()); } /* -//Ignored: AB 2019/05/21 - Segmentation fault on on linux-ppc64le-cpu - https://github.com/deeplearning4j/deeplearning4j/issues/7657 +//Ignored: AB 2019/05/21 - Segmentation fault on on linux-ppc64le-cpu - +https://github.com/deeplearning4j/deeplearning4j/issues/7657 TEST_F(AttentionTests, basic_dot_product_attention_bp) { auto keys = NDArrayFactory::create('c', {10, 4, 3}); auto values = NDArrayFactory::create('c', {10, 4, 3}); @@ -64,25 +64,25 @@ TEST_F(AttentionTests, basic_dot_product_attention_bp) { */ TEST_F(AttentionTests, basic_dot_product_attention_with_weights) { - auto keys = NDArrayFactory::create('c', {10, 4, 3}); - auto values = NDArrayFactory::create('c', {10, 4, 3}); - auto queries = NDArrayFactory::create('c', {10, 4, 1}); + auto keys = NDArrayFactory::create('c', {10, 4, 3}); + auto values = NDArrayFactory::create('c', {10, 4, 3}); + auto queries = NDArrayFactory::create('c', {10, 4, 1}); - sd::ops::dot_product_attention op; - auto result = op.evaluate({&queries, &keys, &values}, {1, 1}); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::dot_product_attention op; + auto result = op.evaluate({&queries, &keys, &values}, {1, 1}); + ASSERT_EQ(Status::OK(), result.status()); } TEST_F(AttentionTests, basic_dot_product_attention_with_mask) { - auto keys = NDArrayFactory::create('c', {10, 4, 3}); - auto values = NDArrayFactory::create('c', {10, 4, 3}); - auto queries = NDArrayFactory::create('c', {10, 4, 1}); - auto mask = NDArrayFactory::create('c', {10, 3}); - mask.assign(1.); - - sd::ops::dot_product_attention op; - auto result = op.evaluate({&queries, &keys, &values, &mask}, {1, 0}); - ASSERT_EQ(Status::OK(), result.status()); + auto keys = NDArrayFactory::create('c', {10, 4, 3}); + auto values = NDArrayFactory::create('c', {10, 4, 3}); + auto queries = NDArrayFactory::create('c', {10, 4, 1}); + auto mask = NDArrayFactory::create('c', {10, 3}); + mask.assign(1.); + + sd::ops::dot_product_attention op; + auto result = op.evaluate({&queries, &keys, &values, &mask}, {1, 0}); + ASSERT_EQ(Status::OK(), result.status()); } /* @@ -96,23 +96,23 @@ TEST_F(AttentionTests, basic_dot_product_attention_bp_with_mask) { mask.assign(1.); sd::ops::dot_product_attention_bp op; - auto result = op.execute({&queries, &keys, &values, &eps, &mask}, {}, {1, 0}, {}); - ASSERT_EQ(Status::OK(), result->status()); + auto result = op.execute({&queries, &keys, &values, &eps, &mask}, {}, {1, +0}, {}); ASSERT_EQ(Status::OK(), result->status()); delete result; } */ TEST_F(AttentionTests, multi_head_input_dot_product_attention_with_mask) { - auto keys = NDArrayFactory::create('c', {2, 5, 4, 3}); - auto values = NDArrayFactory::create('c', {2, 5, 4, 3}); - auto queries = NDArrayFactory::create('c', {2, 5, 4, 1}); - auto mask = NDArrayFactory::create('c', {2, 3}); - mask.assign(1.); - - sd::ops::dot_product_attention op; - auto result = op.evaluate({&queries, &keys, &values, &mask}, {1, 0}); - ASSERT_EQ(Status::OK(), result.status()); + auto keys = NDArrayFactory::create('c', {2, 5, 4, 3}); + auto values = NDArrayFactory::create('c', {2, 5, 4, 3}); + auto queries = NDArrayFactory::create('c', {2, 5, 4, 1}); + auto mask = NDArrayFactory::create('c', {2, 3}); + mask.assign(1.); + + sd::ops::dot_product_attention op; + auto result = op.evaluate({&queries, &keys, &values, &mask}, {1, 0}); + ASSERT_EQ(Status::OK(), result.status()); } /* @@ -126,35 +126,36 @@ TEST_F(AttentionTests, multi_head_input_dot_product_attention_bp_with_mask) { mask.assign(1.); sd::ops::dot_product_attention_bp op; - auto result = op.execute({&queries, &keys, &values, &eps, &mask}, {}, {1, 0}, {}); - ASSERT_EQ(Status::OK(), result->status()); + auto result = op.execute({&queries, &keys, &values, &eps, &mask}, {}, {1, +0}, {}); ASSERT_EQ(Status::OK(), result->status()); delete result; } */ - TEST_F(AttentionTests, basic_multi_head_dot_product_attention) { - auto keys = NDArrayFactory::create('c', {10, 4, 5}); - auto values = NDArrayFactory::create('c', {10, 4, 5}); - auto queries = NDArrayFactory::create('c', {10, 4, 2}); - - auto Wk = NDArrayFactory::create('c', {2, 3, 4}); - auto Wv = NDArrayFactory::create('c', {2, 3, 4}); - auto Wq = NDArrayFactory::create('c', {2, 3, 4}); - auto Wo = NDArrayFactory::create('c', {2* 3, 4}); - - sd::ops::multi_head_dot_product_attention op; - auto result = op.evaluate({&queries, &keys, &values, &Wk, &Wv, &Wq, &Wo}, {1, 0}); - ASSERT_EQ(Status::OK(), result.status()); + auto keys = NDArrayFactory::create('c', {10, 4, 5}); + auto values = NDArrayFactory::create('c', {10, 4, 5}); + auto queries = NDArrayFactory::create('c', {10, 4, 2}); + + auto Wk = NDArrayFactory::create('c', {2, 3, 4}); + auto Wv = NDArrayFactory::create('c', {2, 3, 4}); + auto Wq = NDArrayFactory::create('c', {2, 3, 4}); + auto Wo = NDArrayFactory::create('c', {2 * 3, 4}); + + sd::ops::multi_head_dot_product_attention op; + auto result = + op.evaluate({&queries, &keys, &values, &Wk, &Wv, &Wq, &Wo}, {1, 0}); + ASSERT_EQ(Status::OK(), result.status()); } /* -//AB 2019/05/30 - Other attention BP tests are segfaulting on ppc64le - disabling this pre-emptively - See issue #7657 -TEST_F(AttentionTests, basic_multi_head_dot_product_bp_attention) { - auto keys = NDArrayFactory::create('c', {10, 4, 5}); - auto values = NDArrayFactory::create('c', {10, 4, 5}); - auto queries = NDArrayFactory::create('c', {10, 4, 2}); +//AB 2019/05/30 - Other attention BP tests are segfaulting on ppc64le - +disabling this pre-emptively - See issue #7657 TEST_F(AttentionTests, +basic_multi_head_dot_product_bp_attention) { auto keys = +NDArrayFactory::create('c', {10, 4, 5}); auto values = +NDArrayFactory::create('c', {10, 4, 5}); auto queries = +NDArrayFactory::create('c', {10, 4, 2}); auto Wk = NDArrayFactory::create('c', {2, 3, 4}); auto Wv = NDArrayFactory::create('c', {2, 3, 4}); @@ -165,38 +166,39 @@ TEST_F(AttentionTests, basic_multi_head_dot_product_bp_attention) { sd::ops::multi_head_dot_product_attention_bp op; - auto result = op.execute({&queries, &keys, &values, &Wk, &Wv, &Wq, &Wo, &eps}, {}, {1, 0}, {}); - ASSERT_EQ(Status::OK(), result->status()); + auto result = op.execute({&queries, &keys, &values, &Wk, &Wv, &Wq, &Wo, +&eps}, {}, {1, 0}, {}); ASSERT_EQ(Status::OK(), result->status()); delete result; } */ TEST_F(AttentionTests, basic_multi_head_dot_product_attention_with_mask) { - auto keys = NDArrayFactory::create('c', {10, 4, 5}); - auto values = NDArrayFactory::create('c', {10, 4, 5}); - auto queries = NDArrayFactory::create('c', {10, 4, 2}); - - auto Wk = NDArrayFactory::create('c', {2, 3, 4}); - auto Wv = NDArrayFactory::create('c', {2, 3, 4}); - auto Wq = NDArrayFactory::create('c', {2, 3, 4}); - auto Wo = NDArrayFactory::create('c', {2* 3, 4}); - - auto mask = NDArrayFactory::create('c', {10, 5}); - mask.assign(1.); - - - sd::ops::multi_head_dot_product_attention op; - auto result = op.evaluate({&queries, &keys, &values, &Wk, &Wv, &Wq, &Wo, &mask}, {1, 0}); - ASSERT_EQ(Status::OK(), result.status()); + auto keys = NDArrayFactory::create('c', {10, 4, 5}); + auto values = NDArrayFactory::create('c', {10, 4, 5}); + auto queries = NDArrayFactory::create('c', {10, 4, 2}); + + auto Wk = NDArrayFactory::create('c', {2, 3, 4}); + auto Wv = NDArrayFactory::create('c', {2, 3, 4}); + auto Wq = NDArrayFactory::create('c', {2, 3, 4}); + auto Wo = NDArrayFactory::create('c', {2 * 3, 4}); + + auto mask = NDArrayFactory::create('c', {10, 5}); + mask.assign(1.); + + sd::ops::multi_head_dot_product_attention op; + auto result = op.evaluate( + {&queries, &keys, &values, &Wk, &Wv, &Wq, &Wo, &mask}, {1, 0}); + ASSERT_EQ(Status::OK(), result.status()); } /* -//AB 2019/05/30 - Other attention BP tests are segfaulting on ppc64le - disabling this pre-emptively - See issue #7657 -TEST_F(AttentionTests, basic_multi_head_dot_product_bp_attention_with_mask) { - auto keys = NDArrayFactory::create('c', {10, 4, 5}); - auto values = NDArrayFactory::create('c', {10, 4, 5}); - auto queries = NDArrayFactory::create('c', {10, 4, 2}); +//AB 2019/05/30 - Other attention BP tests are segfaulting on ppc64le - +disabling this pre-emptively - See issue #7657 TEST_F(AttentionTests, +basic_multi_head_dot_product_bp_attention_with_mask) { auto keys = +NDArrayFactory::create('c', {10, 4, 5}); auto values = +NDArrayFactory::create('c', {10, 4, 5}); auto queries = +NDArrayFactory::create('c', {10, 4, 2}); auto Wk = NDArrayFactory::create('c', {2, 3, 4}); auto Wv = NDArrayFactory::create('c', {2, 3, 4}); @@ -210,8 +212,8 @@ TEST_F(AttentionTests, basic_multi_head_dot_product_bp_attention_with_mask) { sd::ops::multi_head_dot_product_attention_bp op; - auto result = op.execute({&queries, &keys, &values, &Wk, &Wv, &Wq, &Wo, &eps, &mask}, {}, {1, 0}, {}); - ASSERT_EQ(Status::OK(), result->status()); + auto result = op.execute({&queries, &keys, &values, &Wk, &Wv, &Wq, &Wo, +&eps, &mask}, {}, {1, 0}, {}); ASSERT_EQ(Status::OK(), result->status()); delete result; } diff --git a/libnd4j/tests_cpu/layers_tests/BackpropTests.cpp b/libnd4j/tests_cpu/layers_tests/BackpropTests.cpp index 5c528f072597..d2bf722a49e4 100644 --- a/libnd4j/tests_cpu/layers_tests/BackpropTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/BackpropTests.cpp @@ -18,32 +18,31 @@ // Created by raver119 on 13.01.2018. // -#include "testlayers.h" #include +#include "testlayers.h" + using namespace sd; using namespace sd::ops; using namespace sd::graph; class BackpropTests : public testing::Test { -public: - + public: }; TEST_F(BackpropTests, Test_Add_1) { + NDArray x('c', {2, 3, 4}, sd::DataType::FLOAT32); + NDArray y('c', {3, 4}, sd::DataType::FLOAT32); + NDArray e('c', {2, 3, 4}, sd::DataType::FLOAT32); - NDArray x('c', {2, 3, 4}, sd::DataType::FLOAT32); - NDArray y('c', {3, 4}, sd::DataType::FLOAT32); - NDArray e('c', {2, 3, 4}, sd::DataType::FLOAT32); - - sd::ops::add_bp op; - auto result = op.evaluate({&x, &y, &e}); + sd::ops::add_bp op; + auto result = op.evaluate({&x, &y, &e}); - ASSERT_EQ(Status::OK(), result.status()); + ASSERT_EQ(Status::OK(), result.status()); - auto eps = result.at(0); - auto grad = result.at(1); + auto eps = result.at(0); + auto grad = result.at(1); - ASSERT_TRUE(x.isSameShape(eps)); - ASSERT_TRUE(y.isSameShape(grad)); + ASSERT_TRUE(x.isSameShape(eps)); + ASSERT_TRUE(y.isSameShape(grad)); } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/BitwiseUtilsTests.cpp b/libnd4j/tests_cpu/layers_tests/BitwiseUtilsTests.cpp index 4174637e201a..c05f152b321c 100644 --- a/libnd4j/tests_cpu/layers_tests/BitwiseUtilsTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/BitwiseUtilsTests.cpp @@ -18,59 +18,57 @@ // Created by raver119 on 10.11.2017. // -#include "testlayers.h" -#include #include -#include #include +#include +#include + +#include "testlayers.h" using namespace sd; using namespace sd::graph; class BitwiseUtilsTests : public testing::Test { -public: - + public: }; // oviously, this test will fail on big-endian machines, but who cares TEST_F(BitwiseUtilsTests, Test_Runtime_Endianess_1) { - bool isBE = BitwiseUtils::isBE(); + bool isBE = BitwiseUtils::isBE(); - ASSERT_FALSE(isBE); + ASSERT_FALSE(isBE); } TEST_F(BitwiseUtilsTests, Test_ValueBit_1) { - int idx = BitwiseUtils::valueBit(1); + int idx = BitwiseUtils::valueBit(1); - ASSERT_EQ(0, idx); + ASSERT_EQ(0, idx); } TEST_F(BitwiseUtilsTests, Test_ValueBit_2) { - int idx = BitwiseUtils::valueBit(2); + int idx = BitwiseUtils::valueBit(2); - ASSERT_EQ(1, idx); + ASSERT_EQ(1, idx); } TEST_F(BitwiseUtilsTests, Test_ValueBits_1) { - std::vector expected({1, 1}); - while (expected.size() < 32) - expected.push_back(0); + std::vector expected({1, 1}); + while (expected.size() < 32) expected.push_back(0); - std::vector result = BitwiseUtils::valueBits(3); + std::vector result = BitwiseUtils::valueBits(3); - ASSERT_EQ(32, result.size()); - ASSERT_EQ(expected, result); + ASSERT_EQ(32, result.size()); + ASSERT_EQ(expected, result); } TEST_F(BitwiseUtilsTests, Test_ValueBits_2) { - int value = 48; - int flipped = BitwiseUtils::flip_bits(value); + int value = 48; + int flipped = BitwiseUtils::flip_bits(value); - ASSERT_NE(value, flipped); + ASSERT_NE(value, flipped); - auto o = BitwiseUtils::valueBits(value); - auto f = BitwiseUtils::valueBits(flipped); + auto o = BitwiseUtils::valueBits(value); + auto f = BitwiseUtils::valueBits(flipped); - for (int e = 0; e < o.size(); e++) - ASSERT_NE(o.at(e), f.at(e)); + for (int e = 0; e < o.size(); e++) ASSERT_NE(o.at(e), f.at(e)); } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/BooleanOpsTests.cpp b/libnd4j/tests_cpu/layers_tests/BooleanOpsTests.cpp index 199cb88eb0d5..3bcb90dd8ee8 100644 --- a/libnd4j/tests_cpu/layers_tests/BooleanOpsTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/BooleanOpsTests.cpp @@ -18,131 +18,123 @@ // Created by raver119 on 13.10.2017. // -#include "testlayers.h" #include +#include "testlayers.h" + using namespace sd; using namespace sd::ops; using namespace sd::graph; class BooleanOpsTests : public testing::Test { -public: - + public: }; - TEST_F(BooleanOpsTests, LtTest_1) { - auto x = NDArrayFactory::create_(1.0f); - auto y = NDArrayFactory::create_(2.0f); - - sd::ops::lt_scalar op; + auto x = NDArrayFactory::create_(1.0f); + auto y = NDArrayFactory::create_(2.0f); + sd::ops::lt_scalar op; - ASSERT_TRUE(op.verify({x, y})); + ASSERT_TRUE(op.verify({x, y})); - delete x; - delete y; + delete x; + delete y; } TEST_F(BooleanOpsTests, LtTest_2) { - auto x = NDArrayFactory::create_(2.0f); - auto y = NDArrayFactory::create_(1.0f); + auto x = NDArrayFactory::create_(2.0f); + auto y = NDArrayFactory::create_(1.0f); - sd::ops::lt_scalar op; + sd::ops::lt_scalar op; + ASSERT_FALSE(op.verify({x, y})); - ASSERT_FALSE(op.verify({x, y})); - - delete x; - delete y; + delete x; + delete y; } TEST_F(BooleanOpsTests, Is_non_decreasing_1) { - auto x = NDArrayFactory::create('c', {2 , 2}, {1, 2, 4, 4}); - - sd::ops::is_non_decreasing op; + auto x = NDArrayFactory::create('c', {2, 2}, {1, 2, 4, 4}); - ASSERT_TRUE(op.verify({&x})); + sd::ops::is_non_decreasing op; + ASSERT_TRUE(op.verify({&x})); } TEST_F(BooleanOpsTests, Is_non_decreasing_2) { - auto x = NDArrayFactory::create('c', {2 , 2}, {1, 2, 4, 3}); - - sd::ops::is_non_decreasing op; + auto x = NDArrayFactory::create('c', {2, 2}, {1, 2, 4, 3}); - ASSERT_FALSE(op.verify({&x})); + sd::ops::is_non_decreasing op; + ASSERT_FALSE(op.verify({&x})); } TEST_F(BooleanOpsTests, Is_strictly_increasing_1) { - auto x = NDArrayFactory::create('c', {2 , 2}, {1, 2, 4, 5}); + auto x = NDArrayFactory::create('c', {2, 2}, {1, 2, 4, 5}); - sd::ops::is_strictly_increasing op; - - ASSERT_TRUE(op.verify({&x})); + sd::ops::is_strictly_increasing op; + ASSERT_TRUE(op.verify({&x})); } TEST_F(BooleanOpsTests, Is_strictly_increasing_2) { - auto x = NDArrayFactory::create('c', {2 , 2}, {1, 2, 3, 3}); - - sd::ops::is_strictly_increasing op; + auto x = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 3}); - ASSERT_FALSE(op.verify({&x})); + sd::ops::is_strictly_increasing op; + ASSERT_FALSE(op.verify({&x})); } TEST_F(BooleanOpsTests, Is_strictly_increasing_3) { - auto x = NDArrayFactory::create('c', {2 , 2}, {1, 2, 4, 3}); + auto x = NDArrayFactory::create('c', {2, 2}, {1, 2, 4, 3}); - sd::ops::is_strictly_increasing op; + sd::ops::is_strictly_increasing op; - ASSERT_FALSE(op.verify({&x})); + ASSERT_FALSE(op.verify({&x})); } TEST_F(BooleanOpsTests, Is_strictly_increasing_5) { - auto x = NDArrayFactory::create('c', {64, 512}); - x.linspace(1.0); + auto x = NDArrayFactory::create('c', {64, 512}); + x.linspace(1.0); - sd::ops::is_strictly_increasing op; + sd::ops::is_strictly_increasing op; - ASSERT_TRUE(op.verify({&x})); + ASSERT_TRUE(op.verify({&x})); } TEST_F(BooleanOpsTests, Is_strictly_increasing_6) { - auto x = NDArrayFactory::create('c', {64, 512}); - x.linspace(1.0); + auto x = NDArrayFactory::create('c', {64, 512}); + x.linspace(1.0); - x.p(18, 1000323.f); + x.p(18, 1000323.f); - sd::ops::is_strictly_increasing op; + sd::ops::is_strictly_increasing op; - ASSERT_FALSE(op.verify({&x})); + ASSERT_FALSE(op.verify({&x})); } TEST_F(BooleanOpsTests, Is_numeric_tensor_1) { - auto x = NDArrayFactory::create('c', {2 , 2}, {1.f, 2.f, 4.f, 3.f}); + auto x = NDArrayFactory::create('c', {2, 2}, {1.f, 2.f, 4.f, 3.f}); - sd::ops::is_numeric_tensor op; + sd::ops::is_numeric_tensor op; - ASSERT_TRUE(op.verify({&x})); + ASSERT_TRUE(op.verify({&x})); } TEST_F(BooleanOpsTests, test_where_1) { - auto x = NDArrayFactory::create('c', {6}, { 1, -3, 4, 8, -2, 5 }); - auto y = NDArrayFactory::create('c', {6}, { 2, -3, 1, 1, -2, 1 }); - auto e = NDArrayFactory::create('c', {3}, { 4, 8, 5 }); + auto x = NDArrayFactory::create('c', {6}, {1, -3, 4, 8, -2, 5}); + auto y = NDArrayFactory::create('c', {6}, {2, -3, 1, 1, -2, 1}); + auto e = NDArrayFactory::create('c', {3}, {4, 8, 5}); - sd::ops::choose op; + sd::ops::choose op; - auto result = op.evaluate({&x, &y}, {3}); - ASSERT_EQ(Status::OK(), result.status()); + auto result = op.evaluate({&x, &y}, {3}); + ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); + auto z = result.at(0); - //z->printIndexedBuffer("z"); + // z->printIndexedBuffer("z"); - ASSERT_EQ(e, *z); + ASSERT_EQ(e, z); } - diff --git a/libnd4j/tests_cpu/layers_tests/BroadcastableOpsTests.cpp b/libnd4j/tests_cpu/layers_tests/BroadcastableOpsTests.cpp index 5a4db9fb8163..fa72c879d2e6 100644 --- a/libnd4j/tests_cpu/layers_tests/BroadcastableOpsTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/BroadcastableOpsTests.cpp @@ -18,838 +18,799 @@ // Created by raver119 on 23.11.17. // - -#include "testlayers.h" #include #include #include +#include "testlayers.h" + using namespace sd; using namespace sd::graph; class BroadcastableOpsTests : public testing::Test { -public: - + public: }; TEST_F(BroadcastableOpsTests, Test_Add_1) { + NDArray x('c', {5, 5}, sd::DataType::FLOAT32); + NDArray y('c', {1, 5}, sd::DataType::FLOAT32); + NDArray exp('c', {5, 5}, sd::DataType::FLOAT32); + x.linspace(1); + y.linspace(1); + exp.linspace(1); - NDArray x('c', {5, 5}, sd::DataType::FLOAT32); - NDArray y('c', {1, 5}, sd::DataType::FLOAT32); - NDArray exp('c', {5, 5}, sd::DataType::FLOAT32); - x.linspace(1); - y.linspace(1); - exp.linspace(1); - - //exp.printIndexedBuffer("E B"); - - exp.applyBroadcast(broadcast::Add, {1}, y, exp); + // exp.printIndexedBuffer("E B"); - sd::ops::add op; - auto result = op.evaluate({&x, &y}); + exp.applyBroadcast(broadcast::Add, {1}, y, exp); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::add op; + auto result = op.evaluate({&x, &y}); - auto z = result.at(0); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - //exp.printIndexedBuffer("E A"); - //z->printIndexedBuffer("Z"); + auto z = result.at(0); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + // exp.printIndexedBuffer("E A"); + // z->printIndexedBuffer("Z"); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(BroadcastableOpsTests, Test_Multiply_1) { - auto x = NDArrayFactory::create('c', {5, 5}); - auto y = NDArrayFactory::create('c', {1, 5}); - auto exp = NDArrayFactory::create('c', {5, 5}); - x.linspace(1); - y.linspace(1); - exp.linspace(1); + auto x = NDArrayFactory::create('c', {5, 5}); + auto y = NDArrayFactory::create('c', {1, 5}); + auto exp = NDArrayFactory::create('c', {5, 5}); + x.linspace(1); + y.linspace(1); + exp.linspace(1); - exp.applyBroadcast(broadcast::Multiply, {1}, y, exp); + exp.applyBroadcast(broadcast::Multiply, {1}, y, exp); - sd::ops::multiply op; - auto result = op.evaluate({&x, &y}); + sd::ops::multiply op; + auto result = op.evaluate({&x, &y}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(BroadcastableOpsTests, Test_SquaredSubtract_1) { - auto x = NDArrayFactory::create('c', {5, 5}); - auto y = NDArrayFactory::create('c', {1, 5}); - auto exp = NDArrayFactory::create('c', {5, 5}); - x.linspace(1); - y.linspace(1); - exp.linspace(1); - - exp.applyBroadcast(broadcast::SquaredSubtract, {1}, y, exp); + auto x = NDArrayFactory::create('c', {5, 5}); + auto y = NDArrayFactory::create('c', {1, 5}); + auto exp = NDArrayFactory::create('c', {5, 5}); + x.linspace(1); + y.linspace(1); + exp.linspace(1); + exp.applyBroadcast(broadcast::SquaredSubtract, {1}, y, exp); - sd::ops::squaredsubtract op; - auto result = op.evaluate({&x, &y}); + sd::ops::squaredsubtract op; + auto result = op.evaluate({&x, &y}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(BroadcastableOpsTests, Test_ScalarBroadcast_1) { - auto x = NDArrayFactory::create('c', {1, 1}, {1}); - auto y = NDArrayFactory::create('c', {1, 3}, {0, 1, 2}); - auto exp = NDArrayFactory::create('c', {1,3}, {1, 0, -1}); + auto x = NDArrayFactory::create('c', {1, 1}, {1}); + auto y = NDArrayFactory::create('c', {1, 3}, {0, 1, 2}); + auto exp = NDArrayFactory::create('c', {1, 3}, {1, 0, -1}); - sd::ops::subtract op; - auto result = op.evaluate({&x, &y}); + sd::ops::subtract op; + auto result = op.evaluate({&x, &y}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(BroadcastableOpsTests, Test_ScalarBroadcast_2) { - auto x = NDArrayFactory::create('c', {1, 1}, {1}); - auto y = NDArrayFactory::create('c', {1, 3}, {0, 1, 2}); - auto exp = NDArrayFactory::create('c', {1,3}, {1, 2, 3}); + auto x = NDArrayFactory::create('c', {1, 1}, {1}); + auto y = NDArrayFactory::create('c', {1, 3}, {0, 1, 2}); + auto exp = NDArrayFactory::create('c', {1, 3}, {1, 2, 3}); - sd::ops::add op; - auto result = op.evaluate({&x, &y}); + sd::ops::add op; + auto result = op.evaluate({&x, &y}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(BroadcastableOpsTests, Test_Maximum_1) { - auto x = NDArrayFactory::create('c', {2, 3}, {1, 2, 1, 2, 3, 2}); - auto row = NDArrayFactory::create('c', {1, 3}, {2, 2, 2}); - auto exp = NDArrayFactory::create('c', {2, 3}, {2, 2, 2, 2, 3, 2}); + auto x = NDArrayFactory::create('c', {2, 3}, {1, 2, 1, 2, 3, 2}); + auto row = NDArrayFactory::create('c', {1, 3}, {2, 2, 2}); + auto exp = NDArrayFactory::create('c', {2, 3}, {2, 2, 2, 2, 3, 2}); - sd::ops::maximum op; - auto result = op.evaluate({&x, &row}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::maximum op; + auto result = op.evaluate({&x, &row}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(BroadcastableOpsTests, Test_Minimum_1) { - auto x = NDArrayFactory::create('c', {2, 3}, {1, 2, 1, 2, 3, 2}); - auto col = NDArrayFactory::create('c', {2, 1}, {2, 1}); - auto exp = NDArrayFactory::create('c', {2, 3}, {1, 2, 1, 1, 1, 1}); + auto x = NDArrayFactory::create('c', {2, 3}, {1, 2, 1, 2, 3, 2}); + auto col = NDArrayFactory::create('c', {2, 1}, {2, 1}); + auto exp = NDArrayFactory::create('c', {2, 3}, {1, 2, 1, 1, 1, 1}); - sd::ops::minimum op; - auto result = op.evaluate({&x, &col}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::minimum op; + auto result = op.evaluate({&x, &col}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_TRUE(exp.isSameShape(z)); - - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(BroadcastableOpsTests, Test_Shape_1) { - sd::ops::minimum op; + sd::ops::minimum op; - Nd4jLong shapeX[] = {2, 2, 5, 5, 1, 8192, 1, 99}; - Nd4jLong shapeY[] = {2, 2, 5, 5, 1, 8192, 1, 99}; - ShapeList inputShape({shapeX, shapeY}); - VariableSpace vs; - Context ctx(1, &vs, false); + Nd4jLong shapeX[] = {2, 2, 5, 5, 1, 8192, 1, 99}; + Nd4jLong shapeY[] = {2, 2, 5, 5, 1, 8192, 1, 99}; + ShapeList inputShape({shapeX, shapeY}); + VariableSpace vs; + Context ctx(1, &vs, false); - auto shapes = op.calculateOutputShape(&inputShape, ctx); + auto shapes = op.calculateOutputShape(&inputShape, ctx); - auto shapeZ = shapes->at(0); - ASSERT_TRUE(shape::shapeEquals(shapeX, shapeZ)); + auto shapeZ = shapes->at(0); + ASSERT_TRUE(shape::shapeEquals(shapeX, shapeZ)); - delete shapes; + delete shapes; } TEST_F(BroadcastableOpsTests, Test_Shape_2) { - sd::ops::minimum op; + sd::ops::minimum op; - const Nd4jLong shapeX[] = {2, 1, 1, 1, 1, 8192, 1, 99}; - const Nd4jLong shapeY[] = {2, 2, 5, 5, 1, 8192, 1, 99}; - ShapeList inputShape({shapeX, shapeY}); - VariableSpace vs; - Context ctx(1, &vs, false); + const Nd4jLong shapeX[] = {2, 1, 1, 1, 1, 8192, 1, 99}; + const Nd4jLong shapeY[] = {2, 2, 5, 5, 1, 8192, 1, 99}; + ShapeList inputShape({shapeX, shapeY}); + VariableSpace vs; + Context ctx(1, &vs, false); - auto shapes = op.calculateOutputShape(&inputShape, ctx); + auto shapes = op.calculateOutputShape(&inputShape, ctx); - auto shapeZ = shapes->at(0); - ASSERT_TRUE(shape::shapeEquals(shapeY, shapeZ)); + auto shapeZ = shapes->at(0); + ASSERT_TRUE(shape::shapeEquals(shapeY, shapeZ)); - delete shapes; + delete shapes; } - TEST_F(BroadcastableOpsTests, Test_Shape_3) { - sd::ops::minimum op; + sd::ops::minimum op; - const Nd4jLong shapeX[] = {2, 5, 3, 1, 1, 8192, 1, 99}; - const Nd4jLong shapeY[] = {2, 1, 3, 3, 1, 8192, 1, 99}; - ShapeList inputShape({shapeX, shapeY}); - VariableSpace vs; - Context ctx(1, &vs, false); + const Nd4jLong shapeX[] = {2, 5, 3, 1, 1, 8192, 1, 99}; + const Nd4jLong shapeY[] = {2, 1, 3, 3, 1, 8192, 1, 99}; + ShapeList inputShape({shapeX, shapeY}); + VariableSpace vs; + Context ctx(1, &vs, false); - auto shapes = op.calculateOutputShape(&inputShape, ctx); + auto shapes = op.calculateOutputShape(&inputShape, ctx); - auto shapeZ = shapes->at(0); - ASSERT_TRUE(shape::shapeEquals(shapeX, shapeZ)); + auto shapeZ = shapes->at(0); + ASSERT_TRUE(shape::shapeEquals(shapeX, shapeZ)); - delete shapes; + delete shapes; } - TEST_F(BroadcastableOpsTests, Test_Shape_4) { - sd::ops::minimum op; + sd::ops::minimum op; - const Nd4jLong shapeX[] = {2, 5, 3, 1, 1, 8192, 1, 99}; - const Nd4jLong shapeY[] = {2, 5, 1, 1, 1, 8192, 1, 99}; - ShapeList inputShape({shapeX, shapeY}); - VariableSpace vs; - Context ctx(1, &vs, false); + const Nd4jLong shapeX[] = {2, 5, 3, 1, 1, 8192, 1, 99}; + const Nd4jLong shapeY[] = {2, 5, 1, 1, 1, 8192, 1, 99}; + ShapeList inputShape({shapeX, shapeY}); + VariableSpace vs; + Context ctx(1, &vs, false); - auto shapes = op.calculateOutputShape(&inputShape, ctx); + auto shapes = op.calculateOutputShape(&inputShape, ctx); - auto shapeZ = shapes->at(0); - ASSERT_TRUE(shape::shapeEquals(shapeX, shapeZ)); + auto shapeZ = shapes->at(0); + ASSERT_TRUE(shape::shapeEquals(shapeX, shapeZ)); - delete shapes; + delete shapes; } // (2,1,3) + (4,3) = (2,4,3) TEST_F(BroadcastableOpsTests, Test_Shape_5) { - sd::ops::minimum op; + sd::ops::minimum op; - const Nd4jLong shapeX[] = {3, 2, 1, 3, 3, 3, 1, 8192, 1, 99}; - const Nd4jLong shapeY[] = {2, 4, 3, 3, 1, 8192, 1, 99}; - const Nd4jLong shapeE[] = {3, 2, 4, 3, 12, 3, 1, 8192, 1, 99}; - ShapeList inputShape({shapeX, shapeY}); - VariableSpace vs; - Context ctx(1, &vs, false); + const Nd4jLong shapeX[] = {3, 2, 1, 3, 3, 3, 1, 8192, 1, 99}; + const Nd4jLong shapeY[] = {2, 4, 3, 3, 1, 8192, 1, 99}; + const Nd4jLong shapeE[] = {3, 2, 4, 3, 12, 3, 1, 8192, 1, 99}; + ShapeList inputShape({shapeX, shapeY}); + VariableSpace vs; + Context ctx(1, &vs, false); - auto shapes = op.calculateOutputShape(&inputShape, ctx); + auto shapes = op.calculateOutputShape(&inputShape, ctx); - auto shapeZ = shapes->at(0); - ASSERT_TRUE(shape::shapeEquals(shapeE, shapeZ)); + auto shapeZ = shapes->at(0); + ASSERT_TRUE(shape::shapeEquals(shapeE, shapeZ)); - delete shapes; + delete shapes; } TEST_F(BroadcastableOpsTests, Test_Scalar_Add_1) { - auto x = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); - auto y = NDArrayFactory::create(2.0f); - auto exp = NDArrayFactory::create('c', {2, 2}, {3, 4, 5, 6}); - - sd::ops::add op; - auto result = op.evaluate({&x, &y}); - ASSERT_EQ(Status::OK(), result.status()); + auto x = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); + auto y = NDArrayFactory::create(2.0f); + auto exp = NDArrayFactory::create('c', {2, 2}, {3, 4, 5, 6}); - auto z = result.at(0); + sd::ops::add op; + auto result = op.evaluate({&x, &y}); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(BroadcastableOpsTests, Test_Inplace_Output_1) { - auto x = NDArrayFactory::create('c', {2, 1, 3}); - auto y = NDArrayFactory::create('c', {4, 3}); - auto o = NDArrayFactory::create('c', {2, 4, 3}); - auto e = NDArrayFactory::create('c', {2, 4, 3}); - auto buffO1 = reinterpret_cast(o.buffer()); - y.assign(1.0f); - e.assign(1.0f); + auto x = NDArrayFactory::create('c', {2, 1, 3}); + auto y = NDArrayFactory::create('c', {4, 3}); + auto o = NDArrayFactory::create('c', {2, 4, 3}); + auto e = NDArrayFactory::create('c', {2, 4, 3}); + auto buffO1 = reinterpret_cast(o.buffer()); + y.assign(1.0f); + e.assign(1.0f); - sd::ops::add op; - auto result = op.execute({&x, &y}, {&o}, {}, {}, {}); - ASSERT_EQ(Status::OK(), result); + sd::ops::add op; + auto result = op.execute({&x, &y}, {&o}, {}, {}, {}); + ASSERT_EQ(Status::OK(), result); - auto buffO2 = reinterpret_cast(o.buffer()); + auto buffO2 = reinterpret_cast(o.buffer()); - ASSERT_TRUE(e.isSameShape(o)); - ASSERT_TRUE(e.equalsTo(o)); + ASSERT_TRUE(e.isSameShape(o)); + ASSERT_TRUE(e.equalsTo(o)); - ASSERT_TRUE(buffO1 == buffO2); + ASSERT_TRUE(buffO1 == buffO2); } TEST_F(BroadcastableOpsTests, Test_Subtract_1) { + auto x = NDArrayFactory::create(1.0f); + auto y = NDArrayFactory::create('c', {2}, {0.0f, 1.0f}); + auto e = NDArrayFactory::create('c', {2}, {1.0f, 0.0f}); - auto x = NDArrayFactory::create(1.0f); - auto y = NDArrayFactory::create('c', {2}, {0.0f, 1.0f}); - auto e = NDArrayFactory::create('c', {2}, {1.0f, 0.0f}); - - auto z = x - y; + auto z = x - y; - ASSERT_TRUE(e.equalsTo(z)); + ASSERT_TRUE(e.equalsTo(z)); } TEST_F(BroadcastableOpsTests, Test_Subtract_2) { - auto x = NDArrayFactory::create(1.0f); - auto y = NDArrayFactory::create('c', {2}, {0.0f, 1.0f}); - auto e = NDArrayFactory::create('c', {2}, {1.0f, 0.0f}); + auto x = NDArrayFactory::create(1.0f); + auto y = NDArrayFactory::create('c', {2}, {0.0f, 1.0f}); + auto e = NDArrayFactory::create('c', {2}, {1.0f, 0.0f}); - sd::ops::subtract op; - auto result = op.evaluate({&x, &y}); - auto z = result.at(0); + sd::ops::subtract op; + auto result = op.evaluate({&x, &y}); + auto z = result.at(0); - ASSERT_TRUE(e.equalsTo(z)); + ASSERT_TRUE(e.equalsTo(z)); } TEST_F(BroadcastableOpsTests, Test_Subtract_3) { - auto x = NDArrayFactory::create(1.0f); - auto y = NDArrayFactory::create('c', {2}, {0.0f, 1.0f}); - auto z = NDArrayFactory::create('c', {2}, {0.0f, 0.0f}); - auto e = NDArrayFactory::create('c', {2}, {1.0f, 0.0f}); + auto x = NDArrayFactory::create(1.0f); + auto y = NDArrayFactory::create('c', {2}, {0.0f, 1.0f}); + auto z = NDArrayFactory::create('c', {2}, {0.0f, 0.0f}); + auto e = NDArrayFactory::create('c', {2}, {1.0f, 0.0f}); - sd::ops::subtract op; - auto result = op.execute({&x, &y}, {&z}, {}, {}, {}); + sd::ops::subtract op; + auto result = op.execute({&x, &y}, {&z}, {}, {}, {}); - ASSERT_EQ(Status::OK(), result); - ASSERT_TRUE(e.equalsTo(z)); + ASSERT_EQ(Status::OK(), result); + ASSERT_TRUE(e.equalsTo(z)); } TEST_F(BroadcastableOpsTests, Test_Subtract_4) { - auto x = NDArrayFactory::create(1.0f); - auto y = NDArrayFactory::create('c', {2}, {0.0f, 1.0f}); - auto e = NDArrayFactory::create('c', {2}, {1.0f, 0.0f}); + auto x = NDArrayFactory::create(1.0f); + auto y = NDArrayFactory::create('c', {2}, {0.0f, 1.0f}); + auto e = NDArrayFactory::create('c', {2}, {1.0f, 0.0f}); - auto z = x.applyTrueBroadcast(BroadcastOpsTuple::Subtract(), y); + auto z = x.applyTrueBroadcast(BroadcastOpsTuple::Subtract(), y); - ASSERT_TRUE(e.isSameShape(z)); - ASSERT_TRUE(e.equalsTo(z)); + ASSERT_TRUE(e.isSameShape(z)); + ASSERT_TRUE(e.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(BroadcastableOpsTests, Test_Subtract_5) { - auto x = NDArrayFactory::create(1.0f); - auto y = NDArrayFactory::create('c', {2}, {0.0f, 1.0f}); - auto e = NDArrayFactory::create('c', {2}, {-1., 0.}); + auto x = NDArrayFactory::create(1.0f); + auto y = NDArrayFactory::create('c', {2}, {0.0f, 1.0f}); + auto e = NDArrayFactory::create('c', {2}, {-1., 0.}); - auto z = y - x; + auto z = y - x; - ASSERT_TRUE(e.equalsTo(z)); + ASSERT_TRUE(e.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(BroadcastableOpsTests, Test_Subtract_6) { - auto x = NDArrayFactory::create(1.0f); - auto y = NDArrayFactory::create(4.f); - auto e = NDArrayFactory::create(3.f); + auto x = NDArrayFactory::create(1.0f); + auto y = NDArrayFactory::create(4.f); + auto e = NDArrayFactory::create(3.f); - auto z = y - x; + auto z = y - x; - ASSERT_TRUE(e.equalsTo(z)); + ASSERT_TRUE(e.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(BroadcastableOpsTests, Test_Subtract_7) { - auto x = NDArrayFactory::create(1.0f); - auto y = NDArrayFactory::create(4.f); - auto e = NDArrayFactory::create(-3.f); + auto x = NDArrayFactory::create(1.0f); + auto y = NDArrayFactory::create(4.f); + auto e = NDArrayFactory::create(-3.f); - auto z = x - y; + auto z = x - y; - ASSERT_TRUE(e.equalsTo(z)); + ASSERT_TRUE(e.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(BroadcastableOpsTests, Test_Add_2) { + auto x = NDArrayFactory::create(1.0f); + auto y = NDArrayFactory::create('c', {2}, {0.0f, 1.0f}); + auto e = NDArrayFactory::create('c', {2}, {1.f, 2.f}); - auto x = NDArrayFactory::create(1.0f); - auto y = NDArrayFactory::create('c', {2}, {0.0f, 1.0f}); - auto e = NDArrayFactory::create('c', {2}, {1.f, 2.f}); - - auto z = x + y; + auto z = x + y; - ASSERT_TRUE(e.equalsTo(z)); + ASSERT_TRUE(e.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(BroadcastableOpsTests, Test_Add_3) { + auto x = NDArrayFactory::create(1.0f); + auto y = NDArrayFactory::create('c', {2}, {0.0f, 1.0f}); + auto e = NDArrayFactory::create('c', {2}, {1.f, 2.f}); - auto x = NDArrayFactory::create(1.0f); - auto y = NDArrayFactory::create('c', {2}, {0.0f, 1.0f}); - auto e = NDArrayFactory::create('c', {2}, {1.f, 2.f}); + auto z = y + x; - auto z = y + x; - - ASSERT_TRUE(e.equalsTo(z)); + ASSERT_TRUE(e.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(BroadcastableOpsTests, Test_Add_4) { + auto x = NDArrayFactory::create(1.0f); + auto y = NDArrayFactory::create(4.f); + auto e = NDArrayFactory::create(5.f); - auto x = NDArrayFactory::create(1.0f); - auto y = NDArrayFactory::create(4.f); - auto e = NDArrayFactory::create(5.f); - - auto z = x + y; + auto z = x + y; - ASSERT_TRUE(e.equalsTo(z)); + ASSERT_TRUE(e.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(BroadcastableOpsTests, Test_Add_5) { + auto x = NDArrayFactory::create(1.0f); + auto y = NDArrayFactory::create(4.f); + auto e = NDArrayFactory::create(5.f); - auto x = NDArrayFactory::create(1.0f); - auto y = NDArrayFactory::create(4.f); - auto e = NDArrayFactory::create(5.f); - - auto z = y + x; + auto z = y + x; - ASSERT_TRUE(e.equalsTo(z)); + ASSERT_TRUE(e.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(BroadcastableOpsTests, Test_Multiply_2) { + auto x = NDArrayFactory::create(2.0f); + auto y = NDArrayFactory::create('c', {2}, {3.f, 4.f}); + auto e = NDArrayFactory::create('c', {2}, {6.f, 8.f}); - auto x = NDArrayFactory::create(2.0f); - auto y = NDArrayFactory::create('c', {2}, {3.f, 4.f}); - auto e = NDArrayFactory::create('c', {2}, {6.f, 8.f}); + auto z = y * x; - auto z = y * x; - - ASSERT_TRUE(e.equalsTo(z)); + ASSERT_TRUE(e.equalsTo(z)); } - //////////////////////////////////////////////////////////////////////////////// TEST_F(BroadcastableOpsTests, Test_Multiply_3) { + auto x = NDArrayFactory::create(2.0f); + auto y = NDArrayFactory::create('c', {2}, {3.f, 4.f}); + auto e = NDArrayFactory::create('c', {2}, {6.f, 8.f}); - auto x = NDArrayFactory::create(2.0f); - auto y = NDArrayFactory::create('c', {2}, {3.f, 4.f}); - auto e = NDArrayFactory::create('c', {2}, {6.f, 8.f}); + auto z = x * y; - auto z = x * y; - - ASSERT_TRUE(e.equalsTo(z)); + ASSERT_TRUE(e.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(BroadcastableOpsTests, Test_Multiply_4) { + auto x = NDArrayFactory::create(2.0f); + auto y = NDArrayFactory::create(4.f); + auto e = NDArrayFactory::create(8.f); - auto x = NDArrayFactory::create(2.0f); - auto y = NDArrayFactory::create(4.f); - auto e = NDArrayFactory::create(8.f); - - auto z = y * x; + auto z = y * x; - ASSERT_TRUE(e.equalsTo(z)); + ASSERT_TRUE(e.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(BroadcastableOpsTests, Test_Multiply_5) { + auto x = NDArrayFactory::create(2.0f); + auto y = NDArrayFactory::create(4.f); + auto e = NDArrayFactory::create(8.f); - auto x = NDArrayFactory::create(2.0f); - auto y = NDArrayFactory::create(4.f); - auto e = NDArrayFactory::create(8.f); + auto z = x * y; - auto z = x * y; - - ASSERT_TRUE(e.equalsTo(z)); + ASSERT_TRUE(e.equalsTo(z)); } TEST_F(BroadcastableOpsTests, Test_Multiply_6) { - auto x = NDArrayFactory::create(2.0f); - auto y = NDArrayFactory::create('c', {1}, {4.f}); - auto e = NDArrayFactory::create('c', {1}, {8.f}); + auto x = NDArrayFactory::create(2.0f); + auto y = NDArrayFactory::create('c', {1}, {4.f}); + auto e = NDArrayFactory::create('c', {1}, {8.f}); - auto z = x * y; + auto z = x * y; - ASSERT_TRUE(e.equalsTo(z)); + ASSERT_TRUE(e.equalsTo(z)); } TEST_F(BroadcastableOpsTests, Test_Multiply_7) { - auto x = NDArrayFactory::create(2.0f); - auto y = NDArrayFactory::create('c', {1}, {4.f}); - auto e = NDArrayFactory::create('c', {1}, {8.f}); - - sd::ops::multiply op; - auto result = op.evaluate({&x, &y}); - ASSERT_EQ(Status::OK(), result.status()); + auto x = NDArrayFactory::create(2.0f); + auto y = NDArrayFactory::create('c', {1}, {4.f}); + auto e = NDArrayFactory::create('c', {1}, {8.f}); - auto z = result.at(0); + sd::ops::multiply op; + auto result = op.evaluate({&x, &y}); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(e.equalsTo(z)); + auto z = result.at(0); + ASSERT_TRUE(e.equalsTo(z)); } TEST_F(BroadcastableOpsTests, Test_Multiply_8) { - auto x = NDArrayFactory::create(2.0f); - auto y = NDArrayFactory::create('c', {1, 1}, {4.f}); - auto e = NDArrayFactory::create('c', {1, 1}, {8.f}); + auto x = NDArrayFactory::create(2.0f); + auto y = NDArrayFactory::create('c', {1, 1}, {4.f}); + auto e = NDArrayFactory::create('c', {1, 1}, {8.f}); - sd::ops::multiply op; - auto result = op.evaluate({&x, &y}); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::multiply op; + auto result = op.evaluate({&x, &y}); + ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_TRUE(e.equalsTo(z)); + ASSERT_TRUE(e.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(BroadcastableOpsTests, broadcast_add_1) { + NDArray x('c', {4}, {1, 1, 1, 1}); + NDArray y('c', {1, 4}, {1, 2, 3, 4}); + NDArray z('c', {1, 4}, sd::DataType::DOUBLE); + NDArray exp('c', {1, 4}, {2, 3, 4, 5}, sd::DataType::DOUBLE); - NDArray x('c', {4}, {1,1,1,1}); - NDArray y('c', {1,4}, {1,2,3,4}); - NDArray z('c', {1,4}, sd::DataType::DOUBLE); - NDArray exp('c', {1,4}, {2,3,4,5}, sd::DataType::DOUBLE); - - sd::ops::add op; - auto status = op.execute({&x, &y}, {&z}); + sd::ops::add op; + auto status = op.execute({&x, &y}, {&z}); - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(z.equalsTo(exp)); + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(z.equalsTo(exp)); } ////////////////////////////////////////////////////////////////////// TEST_F(BroadcastableOpsTests, broadcast_equals_1) { + NDArray x('c', {1, 4}, {1, 2, 3, 4}); + NDArray y('c', {3, 4}, {0, 0, 0, 0, 1, 2, 3, 4, 1, 2, 3, 4}); + NDArray z('c', {3, 4}, sd::DataType::BOOL); + NDArray exp('c', {3, 4}, {0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1}, + sd::DataType::BOOL); - NDArray x('c', {1,4}, {1,2,3,4}); - NDArray y('c', {3,4}, {0,0,0,0, 1,2,3,4, 1,2,3,4}); - NDArray z('c', {3,4}, sd::DataType::BOOL); - NDArray exp('c', {3,4}, {0,0,0,0, 1,1,1,1, 1,1,1,1}, sd::DataType::BOOL); + sd::ops::equals op; + auto status = op.execute({&x, &y}, {&z}); + // z.printIndexedBuffer(); - sd::ops::equals op; - auto status = op.execute({&x, &y}, {&z}); - // z.printIndexedBuffer(); - - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(z.equalsTo(exp)); + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(z.equalsTo(exp)); } ////////////////////////////////////////////////////////////////////// TEST_F(BroadcastableOpsTests, broadcast_empty_1) { + NDArray y('c', {3, 4}, {0, 0, 0, 0, 1, 2, 3, 4, 1, 2, 3, 4}); + NDArray x(sd::DataType::DOUBLE, y.getContext(), false); + NDArray z(sd::DataType::DOUBLE, y.getContext(), false); + NDArray zExp(sd::DataType::DOUBLE, y.getContext(), false); - NDArray y('c', {3,4}, {0,0,0,0, 1,2,3,4, 1,2,3,4}); - NDArray x(sd::DataType::DOUBLE, y.getContext(), false); - NDArray z(sd::DataType::DOUBLE, y.getContext(), false); - NDArray zExp(sd::DataType::DOUBLE, y.getContext(), false); - - sd::ops::multiply op; - auto status = op.execute({&x, &y}, {&z}, {}, {}, {}); + sd::ops::multiply op; + auto status = op.execute({&x, &y}, {&z}, {}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(z.isSameShape(zExp)); - ASSERT_TRUE(z.equalsTo(zExp)); + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(z.isSameShape(zExp)); + ASSERT_TRUE(z.equalsTo(zExp)); } TEST_F(BroadcastableOpsTests, broadcast_empty_2) { + NDArray y('c', {1, 4}, {1, 2, 3, 4}); + NDArray x = NDArrayFactory::create('c', {0, 4}); + NDArray e = NDArrayFactory::create('c', {0, 4}); + ; - NDArray y('c', {1,4}, {1,2,3,4}); - NDArray x = NDArrayFactory::create('c', {0, 4}); - NDArray e = NDArrayFactory::create('c', {0, 4});; - - sd::ops::multiply op; - auto status = op.execute({&x, &y}, {&x}, {}, {}, {}); + sd::ops::multiply op; + auto status = op.execute({&x, &y}, {&x}, {}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(e.isSameShape(x)); - ASSERT_TRUE(e.equalsTo(x)); + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(e.isSameShape(x)); + ASSERT_TRUE(e.equalsTo(x)); } TEST_F(BroadcastableOpsTests, broadcast_empty_3) { + NDArray x = NDArrayFactory::create('c', {1, 0, 2}); + NDArray y('c', {}, std::vector{0.1}, sd::DataType::FLOAT32); + NDArray e = NDArrayFactory::create('c', {1, 0, 2}); + ; - NDArray x = NDArrayFactory::create('c', {1, 0, 2}); - NDArray y('c', {}, std::vector{0.1}, sd::DataType::FLOAT32); - NDArray e = NDArrayFactory::create('c', {1, 0, 2});; + sd::ops::maximum op; + auto result = op.evaluate({&x, &y}); - sd::ops::maximum op; - auto result = op.evaluate({&x, &y}); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); - auto z = result.at(0); - - ASSERT_TRUE(e.isSameShape(z)); - ASSERT_TRUE(e.equalsTo(*z)); + ASSERT_TRUE(e.isSameShape(z)); + ASSERT_TRUE(e.equalsTo(z)); } TEST_F(BroadcastableOpsTests, broadcast_empty_4) { + NDArray x = NDArrayFactory::create('c', {1, 0, 1}); + NDArray y = NDArrayFactory::create('c', {1, 0, 2}); + NDArray e = NDArrayFactory::create('c', {1, 0, 2}); + ; - NDArray x = NDArrayFactory::create('c', {1, 0, 1}); - NDArray y = NDArrayFactory::create('c', {1, 0, 2}); - NDArray e = NDArrayFactory::create('c', {1, 0, 2});; - - sd::ops::maximum op; - auto result = op.evaluate({&x, &y}); - - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::maximum op; + auto result = op.evaluate({&x, &y}); - auto z = result.at(0); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(e.isSameShape(z)); - ASSERT_TRUE(e.equalsTo(*z)); + auto z = result.at(0); + ASSERT_TRUE(e.isSameShape(z)); + ASSERT_TRUE(e.equalsTo(z)); } TEST_F(BroadcastableOpsTests, broadcast_empty_5) { + NDArray x = NDArrayFactory::create('c', {1, 0, 1}); + NDArray y = NDArrayFactory::create('c', {1, 0, 2}); + NDArray e = NDArrayFactory::create('c', {1, 0, 2}); + ; - NDArray x = NDArrayFactory::create('c', {1, 0, 1}); - NDArray y = NDArrayFactory::create('c', {1, 0, 2}); - NDArray e = NDArrayFactory::create('c', {1, 0, 2});; + sd::ops::realdiv op; + auto result = op.evaluate({&x, &y}); - sd::ops::realdiv op; - auto result = op.evaluate({&x, &y}); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(e.isSameShape(z)); - ASSERT_TRUE(e.equalsTo(*z)); + auto z = result.at(0); + ASSERT_TRUE(e.isSameShape(z)); + ASSERT_TRUE(e.equalsTo(z)); } TEST_F(BroadcastableOpsTests, broadcast_empty_6) { + NDArray x = NDArrayFactory::create('c', {1, 0, 1}); + NDArray y = NDArrayFactory::create('c', {1, 2}, {2, 2}); + NDArray e = NDArrayFactory::create('c', {1, 0, 2}); + ; - NDArray x = NDArrayFactory::create('c', {1, 0, 1}); - NDArray y = NDArrayFactory::create('c', {1, 2}, {2, 2}); - NDArray e = NDArrayFactory::create('c', {1, 0, 2});; - - sd::ops::realdiv op; - auto result = op.evaluate({&x, &y}); + sd::ops::realdiv op; + auto result = op.evaluate({&x, &y}); - ASSERT_EQ(Status::OK(), result.status()); + ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); - - ASSERT_TRUE(e.isSameShape(z)); - ASSERT_TRUE(e.equalsTo(*z)); + auto z = result.at(0); + ASSERT_TRUE(e.isSameShape(z)); + ASSERT_TRUE(e.equalsTo(z)); } TEST_F(BroadcastableOpsTests, broadcast_empty_7) { + NDArray x = NDArrayFactory::create('c', {1, 0, 2, 1}); + NDArray y = NDArrayFactory::create('c', {1, 2, 0}); + NDArray e = NDArrayFactory::create('c', {1, 0, 2, 0}); + ; - NDArray x = NDArrayFactory::create('c', {1, 0, 2, 1}); - NDArray y = NDArrayFactory::create('c', {1, 2, 0}); - NDArray e = NDArrayFactory::create('c', {1, 0, 2, 0});; - - sd::ops::realdiv op; - auto result = op.evaluate({&x, &y}); + sd::ops::realdiv op; + auto result = op.evaluate({&x, &y}); - ASSERT_EQ(Status::OK(), result.status()); + ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_TRUE(e.isSameShape(z)); - ASSERT_TRUE(e.equalsTo(*z)); + ASSERT_TRUE(e.isSameShape(z)); + ASSERT_TRUE(e.equalsTo(z)); } - TEST_F(BroadcastableOpsTests, broadcast_bool_empty_1) { + NDArray y('c', {3, 4}, {0, 0, 0, 0, 1, 2, 3, 4, 1, 2, 3, 4}); + NDArray x(sd::DataType::DOUBLE, y.getContext(), false); + NDArray z(sd::DataType::BOOL, y.getContext(), false); + NDArray zExp(sd::DataType::BOOL, y.getContext(), false); - NDArray y('c', {3,4}, {0,0,0,0, 1,2,3,4, 1,2,3,4}); - NDArray x(sd::DataType::DOUBLE, y.getContext(), false); - NDArray z(sd::DataType::BOOL, y.getContext(), false); - NDArray zExp(sd::DataType::BOOL, y.getContext(), false); + sd::ops::greater op; + auto status = op.execute({&x, &y}, {&z}, {}, {}, {}); - sd::ops::greater op; - auto status = op.execute({&x, &y}, {&z}, {}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(z.isSameShape(zExp)); - ASSERT_TRUE(z.equalsTo(zExp)); + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(z.isSameShape(zExp)); + ASSERT_TRUE(z.equalsTo(zExp)); } TEST_F(BroadcastableOpsTests, broadcast_bool_empty_2) { + NDArray y('c', {1, 4}, {1, 2, 3, 4}); + NDArray x = NDArrayFactory::create('c', {0, 4}); + NDArray e = NDArrayFactory::create('c', {0, 4}); + ; - NDArray y('c', {1,4}, {1,2,3,4}); - NDArray x = NDArrayFactory::create('c', {0, 4}); - NDArray e = NDArrayFactory::create('c', {0, 4});; - - - sd::ops::greater op; - auto result = op.evaluate({&x, &y}); + sd::ops::greater op; + auto result = op.evaluate({&x, &y}); - auto z = result.at(0); + auto z = result.at(0); - // z->printShapeInfo("z"); + // z->printShapeInfo("z"); - ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(e.isSameShape(z)); - ASSERT_TRUE(e.equalsTo(*z)); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(e.isSameShape(z)); + ASSERT_TRUE(e.equalsTo(z)); } TEST_F(BroadcastableOpsTests, broadcast_bool_1) { + NDArray x('c', {3, 1, 2}, sd::DataType::FLOAT32); + NDArray y('c', {2, 2}, sd::DataType::FLOAT32); + NDArray z('c', {3, 2, 2}, sd::DataType::BOOL); + NDArray e('c', {3, 2, 2}, sd::DataType::BOOL); - NDArray x('c', {3, 1, 2}, sd::DataType::FLOAT32); - NDArray y('c', {2, 2}, sd::DataType::FLOAT32); - NDArray z('c', {3, 2, 2}, sd::DataType::BOOL); - NDArray e('c', {3, 2, 2}, sd::DataType::BOOL); + x.assign(4.f); + y.assign(2.f); + e.assign(true); - x.assign(4.f); - y.assign(2.f); - e.assign(true); + sd::ops::greater op; - sd::ops::greater op; + auto status = op.execute({&x, &y}, {&z}); - auto status = op.execute({&x, &y}, {&z}); + ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_EQ(ND4J_STATUS_OK, status); + // z.printIndexedBuffer("Z"); - // z.printIndexedBuffer("Z"); - - ASSERT_TRUE(z.isSameShape(e)); - ASSERT_TRUE(z.equalsTo(e)); + ASSERT_TRUE(z.isSameShape(e)); + ASSERT_TRUE(z.equalsTo(e)); } TEST_F(BroadcastableOpsTests, broadcast_bool_2) { + NDArray x('c', {3, 1, 2}, sd::DataType::FLOAT32); + NDArray y('c', {2, 2}, sd::DataType::FLOAT32); + NDArray z('c', {3, 2, 2}, sd::DataType::BOOL); + NDArray e('c', {3, 2, 2}, sd::DataType::BOOL); - NDArray x('c', {3, 1, 2}, sd::DataType::FLOAT32); - NDArray y('c', {2, 2}, sd::DataType::FLOAT32); - NDArray z('c', {3, 2, 2}, sd::DataType::BOOL); - NDArray e('c', {3, 2, 2}, sd::DataType::BOOL); - - x.assign(1.f); - y.assign(2.f); - e.assign(false); + x.assign(1.f); + y.assign(2.f); + e.assign(false); - sd::ops::equals op; + sd::ops::equals op; - auto status = op.execute({&x, &y}, {&z}, {}, {}, {}); + auto status = op.execute({&x, &y}, {&z}, {}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_EQ(ND4J_STATUS_OK, status); - // z.printIndexedBuffer("Z"); + // z.printIndexedBuffer("Z"); - ASSERT_TRUE(z.isSameShape(e)); - ASSERT_TRUE(z.equalsTo(e)); + ASSERT_TRUE(z.isSameShape(e)); + ASSERT_TRUE(z.equalsTo(e)); } TEST_F(BroadcastableOpsTests, broadcast_bool_3) { + auto x = NDArrayFactory::create(0); + auto y = NDArrayFactory::create('c', {3}, {2, 1, 2}); + NDArray z('c', {3}, sd::DataType::BOOL); + NDArray e('c', {3}, sd::DataType::BOOL); - auto x = NDArrayFactory::create(0); - auto y = NDArrayFactory::create('c', {3}, {2, 1, 2}); - NDArray z('c', {3}, sd::DataType::BOOL); - NDArray e('c', {3}, sd::DataType::BOOL); - - e.assign(true); + e.assign(true); - sd::ops::less op; - auto status = op.execute({&x, &y}, {&z}, {}, {}, {}); + sd::ops::less op; + auto status = op.execute({&x, &y}, {&z}, {}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_EQ(ND4J_STATUS_OK, status); - // z.printIndexedBuffer("Z"); + // z.printIndexedBuffer("Z"); - ASSERT_TRUE(z.isSameShape(e)); - ASSERT_TRUE(z.equalsTo(e)); + ASSERT_TRUE(z.isSameShape(e)); + ASSERT_TRUE(z.equalsTo(e)); } TEST_F(BroadcastableOpsTests, broadcast_2) { - NDArray x('c', {3, 1, 2}, sd::DataType::FLOAT32); - NDArray y('c', {2, 2}, sd::DataType::FLOAT32); - NDArray z('c', {3, 2, 2}, sd::DataType::FLOAT32); - NDArray e('c', {3, 2, 2}, sd::DataType::FLOAT32); + NDArray x('c', {3, 1, 2}, sd::DataType::FLOAT32); + NDArray y('c', {2, 2}, sd::DataType::FLOAT32); + NDArray z('c', {3, 2, 2}, sd::DataType::FLOAT32); + NDArray e('c', {3, 2, 2}, sd::DataType::FLOAT32); - x = 4.f; - y = 2.f; - e = -2.f; + x = 4.f; + y = 2.f; + e = -2.f; - sd::ops::reversesubtract op; // z = y - x; + sd::ops::reversesubtract op; // z = y - x; - auto status = op.execute({&x, &y}, {&z}, {}, {}, {}); + auto status = op.execute({&x, &y}, {&z}, {}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_EQ(ND4J_STATUS_OK, status); - // z.printIndexedBuffer("Z"); + // z.printIndexedBuffer("Z"); - ASSERT_TRUE(z.isSameShape(e)); - ASSERT_TRUE(z.equalsTo(e)); + ASSERT_TRUE(z.isSameShape(e)); + ASSERT_TRUE(z.equalsTo(e)); } TEST_F(BroadcastableOpsTests, broadcast_3) { - auto x = NDArrayFactory::create(0); - auto y = NDArrayFactory::create('c', {3}, {2, 1, 2}); - NDArray z('c', {3}, sd::DataType::INT32); - auto e = NDArrayFactory::create('c', {3}, {2, 1, 2}); + auto x = NDArrayFactory::create(0); + auto y = NDArrayFactory::create('c', {3}, {2, 1, 2}); + NDArray z('c', {3}, sd::DataType::INT32); + auto e = NDArrayFactory::create('c', {3}, {2, 1, 2}); - sd::ops::add op; - auto status = op.execute({&x, &y}, {&z}, {}, {}, {}); + sd::ops::add op; + auto status = op.execute({&x, &y}, {&z}, {}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_EQ(ND4J_STATUS_OK, status); - // z.printIndexedBuffer("Z"); + // z.printIndexedBuffer("Z"); - ASSERT_TRUE(z.isSameShape(e)); - ASSERT_TRUE(z.equalsTo(e)); + ASSERT_TRUE(z.isSameShape(e)); + ASSERT_TRUE(z.equalsTo(e)); } TEST_F(BroadcastableOpsTests, test_bert_multiply_1) { - auto x = NDArrayFactory::create('c', {4, 128, 1}); - auto y = NDArrayFactory::create('c', {4, 1, 128}); - auto z = NDArrayFactory::create('c', {4, 128, 128}); - auto e = NDArrayFactory::create('c', {4, 128, 128}); + auto x = NDArrayFactory::create('c', {4, 128, 1}); + auto y = NDArrayFactory::create('c', {4, 1, 128}); + auto z = NDArrayFactory::create('c', {4, 128, 128}); + auto e = NDArrayFactory::create('c', {4, 128, 128}); - x.assign(0.f); - y.assign(1.f); - z.assign(119.f); - e.assign(0.f); -/* - Context ctx(1); - ctx.setInputArray(0, &x); - ctx.setInputArray(1, &y); - ctx.setOutputArray(0, &z); + x.assign(0.f); + y.assign(1.f); + z.assign(119.f); + e.assign(0.f); + /* + Context ctx(1); + ctx.setInputArray(0, &x); + ctx.setInputArray(1, &y); + ctx.setOutputArray(0, &z); - sd::ops::multiply op; - auto status = op.execute(&ctx); - ASSERT_EQ(Status::OK(), status); + sd::ops::multiply op; + auto status = op.execute(&ctx); + ASSERT_EQ(Status::OK(), status); - z.printIndexedBuffer(); -*/ + z.printIndexedBuffer(); + */ - x.applyTrueBroadcast(BroadcastOpsTuple::Multiply(), y, z); + x.applyTrueBroadcast(BroadcastOpsTuple::Multiply(), y, z); - //z.printIndexedBuffer(); + // z.printIndexedBuffer(); - ASSERT_EQ(e, z); + ASSERT_EQ(e, z); } TEST_F(BroadcastableOpsTests, test_bert_multiply_2) { - auto x = NDArrayFactory::create('c', {4, 128, 1}); - auto y = NDArrayFactory::create('c', {768}); - auto z = NDArrayFactory::create('c', {4, 128, 768}); - auto e = NDArrayFactory::create('c', {4, 128, 768}); + auto x = NDArrayFactory::create('c', {4, 128, 1}); + auto y = NDArrayFactory::create('c', {768}); + auto z = NDArrayFactory::create('c', {4, 128, 768}); + auto e = NDArrayFactory::create('c', {4, 128, 768}); - x.assign(1.f); - y.assign(2.f); - z.assign(119.f); - e.assign(2.f); + x.assign(1.f); + y.assign(2.f); + z.assign(119.f); + e.assign(2.f); - x.applyTrueBroadcast(BroadcastOpsTuple::Multiply(), y, z); + x.applyTrueBroadcast(BroadcastOpsTuple::Multiply(), y, z); - ASSERT_EQ(e, z); + ASSERT_EQ(e, z); } diff --git a/libnd4j/tests_cpu/layers_tests/BrodcastTests.cpp b/libnd4j/tests_cpu/layers_tests/BrodcastTests.cpp index 34d0132bb221..135fb0c4ba1a 100644 --- a/libnd4j/tests_cpu/layers_tests/BrodcastTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/BrodcastTests.cpp @@ -18,49 +18,56 @@ // Created by agibsonccc on 1/19/17. // -#include "testinclude.h" #include +#include "testinclude.h" + class BroadcastMultiDimTest : public testing::Test { -public: - int dimensions[2] = {0,2}; - Nd4jLong inputShapeBuffer[10] = {3,2,3,5,15,5,1,8192,1,99}; - float inputData[30] = {1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0,9.0,10.0,11.0,12.0,13.0,14.0,15.0,16.0,17.0,18.0,19.0,20.0,21.0,22.0,23.0,24.0,25.0,26.0,27.0,28.0,29.0,30.0}; - float dataAssertion[30] = {1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0,9.0,10.0,11.0,12.0,13.0,14.0,15.0,16.0,17.0,18.0,0.0,0.0,21.0,22.0,23.0,0.0,0.0,26.0,27.0,28.0,0.0,0.0}; - float result[30] = {0.0}; - float broadcastData[10] = {1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,0.0,0.0}; - Nd4jLong broadcastShapeInfo[8] = {2,2,5,5,1,8192,1,99}; - int opNum = 2; - int dimensionLength = 2; + public: + int dimensions[2] = {0, 2}; + Nd4jLong inputShapeBuffer[10] = {3, 2, 3, 5, 15, 5, 1, 8192, 1, 99}; + float inputData[30] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, + 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, + 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, + 25.0, 26.0, 27.0, 28.0, 29.0, 30.0}; + float dataAssertion[30] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, + 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, + 17.0, 18.0, 0.0, 0.0, 21.0, 22.0, 23.0, 0.0, + 0.0, 26.0, 27.0, 28.0, 0.0, 0.0}; + float result[30] = {0.0}; + float broadcastData[10] = {1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0}; + Nd4jLong broadcastShapeInfo[8] = {2, 2, 5, 5, 1, 8192, 1, 99}; + int opNum = 2; + int dimensionLength = 2; }; #ifndef __CUDABLAS__ -TEST_F(BroadcastMultiDimTest,MultimDimTest) { - auto tad = new shape::TAD(); - tad->init(inputShapeBuffer,dimensions,dimensionLength); - tad->createTadOnlyShapeInfo(); - tad-> createOffsets(); - functions::broadcast::Broadcast::exec( - opNum, - inputData, //x - inputShapeBuffer, //xShapeInfo - broadcastData, //y - broadcastShapeInfo, //yShapeInfo - result, //result - inputShapeBuffer, //resultShapeInfo - dimensions, //dimension - dimensionLength, //dimensionLength - tad->tadOnlyShapeInfo, //tadShapeInfo - tad->tadOffsets, //tadOffset - tad->tadOnlyShapeInfo, //tadShapeInfoZ - tad->tadOffsets, sd::LoopKind::COMMON, 0, tad->numTads); //tadOffsetZ +TEST_F(BroadcastMultiDimTest, MultimDimTest) { + auto tad = new shape::TAD(); + tad->init(inputShapeBuffer, dimensions, dimensionLength); + tad->createTadOnlyShapeInfo(); + tad->createOffsets(); + functions::broadcast::Broadcast::exec( + opNum, + inputData, // x + inputShapeBuffer, // xShapeInfo + broadcastData, // y + broadcastShapeInfo, // yShapeInfo + result, // result + inputShapeBuffer, // resultShapeInfo + dimensions, // dimension + dimensionLength, // dimensionLength + tad->tadOnlyShapeInfo, // tadShapeInfo + tad->tadOffsets, // tadOffset + tad->tadOnlyShapeInfo, // tadShapeInfoZ + tad->tadOffsets, sd::LoopKind::COMMON, 0, tad->numTads); // tadOffsetZ - for(int i = 0; i < 30; i++) { - ASSERT_EQ(dataAssertion[i],result[i]); - } + for (int i = 0; i < 30; i++) { + ASSERT_EQ(dataAssertion[i], result[i]); + } - delete tad; + delete tad; } #endif \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/CMakeLists.txt b/libnd4j/tests_cpu/layers_tests/CMakeLists.txt index 563bf58f6c0b..ffe508987397 100644 --- a/libnd4j/tests_cpu/layers_tests/CMakeLists.txt +++ b/libnd4j/tests_cpu/layers_tests/CMakeLists.txt @@ -60,7 +60,8 @@ else() set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=native -mtune=native") endif() - if (SD_CPU AND SD_SANITIZE) + # Use ASAN for GCC only + if (SD_CPU AND SD_SANITIZE AND "${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" ) set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -fsanitize=address") else() # CUDA? diff --git a/libnd4j/tests_cpu/layers_tests/CnpyTests.cpp b/libnd4j/tests_cpu/layers_tests/CnpyTests.cpp index ea8025592c82..256be6b110b8 100644 --- a/libnd4j/tests_cpu/layers_tests/CnpyTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/CnpyTests.cpp @@ -18,53 +18,50 @@ // Created by agibsonccc on 3/30/17. // -#include "testinclude.h" -#include #include -class FileTest : public testing::Test { - -}; +#include -class LoadFromStringTest : public testing::Test { +#include "testinclude.h" -}; +class FileTest : public testing::Test {}; -class HeaderTest : public testing::Test { +class LoadFromStringTest : public testing::Test {}; -}; +class HeaderTest : public testing::Test {}; TEST_F(HeaderTest, test_dataTypes_1) { - std::string header("0NUMPY6789{'descr': '>f4"); + std::string header("0NUMPY6789{'descr': '>f4"); - - ASSERT_EQ(sd::DataType::FLOAT32, dataTypeFromNpyHeader(const_cast(header.data()))); + ASSERT_EQ(sd::DataType::FLOAT32, + dataTypeFromNpyHeader(const_cast(header.data()))); } TEST_F(HeaderTest, test_dataTypes_2) { - std::string header("0NUMPY6789{'descr': '>f8"); - + std::string header("0NUMPY6789{'descr': '>f8"); - ASSERT_EQ(sd::DataType::DOUBLE, dataTypeFromNpyHeader(const_cast(header.data()))); + ASSERT_EQ(sd::DataType::DOUBLE, + dataTypeFromNpyHeader(const_cast(header.data()))); } TEST_F(HeaderTest, test_dataTypes_3) { - std::string header("0NUMPY6789{'descr': '(header.data()))); + ASSERT_EQ(sd::DataType::INT32, + dataTypeFromNpyHeader(const_cast(header.data()))); } TEST_F(HeaderTest, test_dataTypes_4) { - std::string header("0NUMPY6789{'descr': '>u2"); - + std::string header("0NUMPY6789{'descr': '>u2"); - ASSERT_EQ(sd::DataType::UINT16, dataTypeFromNpyHeader(const_cast(header.data()))); + ASSERT_EQ(sd::DataType::UINT16, + dataTypeFromNpyHeader(const_cast(header.data()))); } /* TEST_F(FileTest,T) { - cnpy::NpyArray npy = cnpy::npyLoad(std::string("/home/agibsonccc/code/libnd4j/test.npy")); + cnpy::NpyArray npy = +cnpy::npyLoad(std::string("/home/agibsonccc/code/libnd4j/test.npy")); ASSERT_FALSE(npy.fortranOrder); ASSERT_EQ(2,npy.shape[0]); diff --git a/libnd4j/tests_cpu/layers_tests/ConditionalTests.cpp b/libnd4j/tests_cpu/layers_tests/ConditionalTests.cpp deleted file mode 100644 index 5167abcd1c01..000000000000 --- a/libnd4j/tests_cpu/layers_tests/ConditionalTests.cpp +++ /dev/null @@ -1,332 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -// -// Created by raver119 on 16.10.2017. -// - -#include "testlayers.h" -#include -#include -#include -#include - -using namespace sd; -using namespace sd::graph; - -class ConditionalTests : public testing::Test { -public: - ConditionalTests(){ - //Environment::getInstance().setVerbose(true); - //Environment::getInstance().setDebug(true); - } - - ~ConditionalTests(){ - //Environment::getInstance().setVerbose(false); - //Environment::getInstance().setDebug(false); - } -}; - - -TEST_F(ConditionalTests, BasicTests_1) { - Graph graph; - - auto x = NDArrayFactory::valueOf({2, 2}, 1.0f); - auto y0 = NDArrayFactory::valueOf({2, 2}, 5.0f); - auto y1 = NDArrayFactory::valueOf({2, 2}, -5.0f); - auto scalar = NDArrayFactory::create_(1.0f); - - auto variableSpace = graph.getVariableSpace(); - - variableSpace->putVariable(-1, x); - variableSpace->putVariable(-2, y0); - variableSpace->putVariable(-3, y1); - variableSpace->putVariable(-4, scalar); - - - auto scopeCondition = new Node(OpType_LOGIC, logic::Scope, 1); - scopeCondition->setName("scopeCondition"); - - auto scopeFalse = new Node(OpType_LOGIC, logic::Scope, 2); - scopeFalse->setName("scopeFalse"); - - auto scopeTrue = new Node(OpType_LOGIC, logic::Scope, 3); - scopeTrue->setName("scopeTrue"); - - auto nodeF = new Node(OpType_PAIRWISE, pairwise::Add, 5, {-1, -2}); - nodeF->setScopeInfo(2, "scopeFalse"); - - auto nodeT = new Node(OpType_PAIRWISE, pairwise::Subtract, 6, {-1, -2}); - nodeT->setScopeInfo(3, "scopeTrue"); - - auto nodeC0 = new Node(OpType_REDUCE_SAME, reduce::Sum, 7, {-1}); - nodeC0->setScopeInfo(1, "scopeCondition"); - - sd::ops::eq_scalar op; - auto nodeC1 = new Node(&op, 8, {7, -4}); - nodeC1->setScopeInfo(1, "scopeCondition"); - - graph.addNode(scopeCondition); - graph.addNode(scopeFalse); - graph.addNode(scopeTrue); - graph.addNode(nodeF); - graph.addNode(nodeT); - graph.addNode(nodeC0); - graph.addNode(nodeC1); - - // at this point graph should ounly have Nodes referring to the Scopes: condition scope, true scope and false scope - ASSERT_EQ(3, graph.totalNodes()); - - // now we're adding Condition op, that'll take all of those in - auto nodeCondition = new Node(OpType_LOGIC, logic::Conditional, 10, {1, 2, 3}); - graph.addNode(nodeCondition); - - ASSERT_EQ(4, graph.totalNodes()); - - Nd4jStatus status = GraphExecutioner::execute(&graph); - ASSERT_EQ(ND4J_STATUS_OK, status); - - ASSERT_TRUE(variableSpace->hasVariable(10, 0)); - auto conditionalResult = variableSpace->getVariable(10, 0)->getNDArray(); - ASSERT_NE(nullptr, conditionalResult); - - ASSERT_NEAR(6.0, conditionalResult->meanNumber().e(0), 1e-5); -} -#ifdef GRAPH_FILES_OK -/** - * Condition is False - */ -TEST_F(ConditionalTests, Flat_Test_1) { - sd::ops::identity op0; - - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/simpleif_0_1.fb"); - auto varSpace = graph->getVariableSpace(); - //varSpace->getVariable(1)->getNDArray()->assign(2.0); - //varSpace->getVariable(2)->getNDArray()->assign(0.0); - - //graph->printOut(); - - auto status = GraphExecutioner::execute(graph); - ASSERT_EQ(Status::OK(), status); - - ASSERT_TRUE(varSpace->hasVariable(15)); - - auto z = varSpace->getVariable(15)->getNDArray(); - - ASSERT_NE(nullptr, z); - - auto exp = NDArrayFactory::create('c', {2, 2}, {-2, -2, -2, -2}); - - ASSERT_TRUE(exp.equalsTo(z)); - - delete graph; -} - -/** - * Condition is True - */ -TEST_F(ConditionalTests, Flat_Test_2) { - Environment::getInstance().setDebug(true); - Environment::getInstance().setVerbose(true); - sd::ops::identity op0; - - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/simpleif_0.fb"); - auto varSpace = graph->getVariableSpace(); - varSpace->getVariable(1)->getNDArray()->assign(-1.0); - - graph->printOut(); - - auto status = GraphExecutioner::execute(graph); - ASSERT_EQ(Status::OK(), status); - - ASSERT_TRUE(varSpace->hasVariable(15)); - - auto z = varSpace->getVariable(15)->getNDArray(); - - ASSERT_NE(nullptr, z); - - auto exp = NDArrayFactory::create('c', {2, 2}, {1, 1, 1, 1}); - - ASSERT_TRUE(exp.equalsTo(z)); - delete graph; -} - - -/** - * Condition is false here, so there loop will be skipped - */ -TEST_F(ConditionalTests, Flat_Test_3) { - sd::ops::identity op0; - - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/simplewhile_0_3.fb"); - auto varSpace = graph->getVariableSpace(); - varSpace->getVariable(1)->getNDArray()->assign(1.0); - - //graph->printOut(); - - auto status = GraphExecutioner::execute(graph); - ASSERT_EQ(Status::OK(), status); - - ASSERT_TRUE(varSpace->hasVariable(17)); - - auto z = varSpace->getVariable(17)->getNDArray(); - - ASSERT_NE(nullptr, z); - - auto exp = NDArrayFactory::create('c', {2, 2}, {1, 1, 1, 1}); - ASSERT_TRUE(exp.equalsTo(z)); - - delete graph; -} - -/** - * just one cycle in body - */ -TEST_F(ConditionalTests, Flat_Test_4) { - sd::ops::identity op0; - - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/simplewhile_0_4.fb"); - auto varSpace = graph->getVariableSpace(); - varSpace->getVariable(2)->getNDArray()->assign(4.0); - - //graph->printOut(); - - auto status = GraphExecutioner::execute(graph); - ASSERT_EQ(Status::OK(), status); - - ASSERT_TRUE(varSpace->hasVariable(17)); - - auto z = varSpace->getVariable(17)->getNDArray(); - - ASSERT_NE(nullptr, z); - - // 0.0 + 2.0 = 2.0 in each element - auto exp = NDArrayFactory::create('c', {2, 2}, {2, 2, 2, 2}); - ASSERT_TRUE(exp.equalsTo(z)); - - delete graph; -} - - -/** - * just two cycles in body - */ -TEST_F(ConditionalTests, Flat_Test_5) { - sd::ops::identity op0; - - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/simplewhile_0_4.fb"); - auto varSpace = graph->getVariableSpace(); - varSpace->getVariable(2)->getNDArray()->assign(9.0); - - //graph->printOut(); - - auto status = GraphExecutioner::execute(graph); - ASSERT_EQ(Status::OK(), status); - - ASSERT_TRUE(varSpace->hasVariable(17)); - - auto z = varSpace->getVariable(17)->getNDArray(); - - ASSERT_NE(nullptr, z); - - // 0.0 + 2.0 + 2.0 = 4.0 in each element - auto exp = NDArrayFactory::create('c', {2, 2}, {4, 4, 4, 4}); - ASSERT_TRUE(exp.equalsTo(z)); - - delete graph; -} - -/** - * While loop with multiple variables - */ -TEST_F(ConditionalTests, Flat_Test_6) { - sd::ops::identity op0; - - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/simplewhile_1.fb"); - auto varSpace = graph->getVariableSpace(); - varSpace->getVariable(1)->getNDArray()->assign(-4.0f); - varSpace->getVariable(2)->getNDArray()->assign(1.0f); - - //graph->printOut(); - - auto status = GraphExecutioner::execute(graph); - ASSERT_EQ(Status::OK(), status); - - ASSERT_TRUE(varSpace->hasVariable(25)); - - auto z = varSpace->getVariable(25)->getNDArray(); - - ASSERT_NE(nullptr, z); - - //z->printIndexedBuffer(); - - auto exp = NDArrayFactory::create('c', {2, 2}, {-1, -1, -1, -1}); - ASSERT_TRUE(exp.equalsTo(z)); - - delete graph; -} - -TEST_F(ConditionalTests, Flat_Test_7) { - sd::ops::identity op0; - - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/simplewhile_1.fb"); - auto varSpace = graph->getVariableSpace(); - varSpace->getVariable(1)->getNDArray()->assign(-9.0f); - varSpace->getVariable(2)->getNDArray()->assign(1.0f); - - //graph->printOut(); - - auto status = GraphExecutioner::execute(graph); - ASSERT_EQ(Status::OK(), status); - - ASSERT_TRUE(varSpace->hasVariable(25)); - - auto z = varSpace->getVariable(25)->getNDArray(); - - ASSERT_NE(nullptr, z); - - //z->printIndexedBuffer(); - - auto exp = NDArrayFactory::create('c', {2, 2}, {-3, -3, -3, -3}); - ASSERT_TRUE(exp.equalsTo(z)); - - delete graph; -} - -/** - * This test checks nested while execution - */ -TEST_F(ConditionalTests, Flat_Test_8) { - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/simplewhile_nested.fb"); - auto varSpace = graph->getVariableSpace(); - //graph->printOut(); - - auto status = GraphExecutioner::execute(graph); - ASSERT_EQ(Status::OK(), status); - - ASSERT_TRUE(varSpace->hasVariable(52)); - - auto z = varSpace->getVariable(52)->getNDArray(); - - ASSERT_NE(nullptr, z); - - //val exp = Nd4j.create(2, 2).assign(15.0); - auto exp = NDArrayFactory::create('c', {2, 2}, {15, 15, 15, 15}); - ASSERT_TRUE(exp.equalsTo(z)); - - delete graph; -} -#endif diff --git a/libnd4j/tests_cpu/layers_tests/ConstantShapeHelperTests.cpp b/libnd4j/tests_cpu/layers_tests/ConstantShapeHelperTests.cpp index a9a42ac8818f..cc4a35af1916 100644 --- a/libnd4j/tests_cpu/layers_tests/ConstantShapeHelperTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/ConstantShapeHelperTests.cpp @@ -18,226 +18,237 @@ // @author raver119@gmail.com // -#include "testlayers.h" -#include -#include -#include #include +#include +#include #include +#include + +#include "testlayers.h" using namespace sd; using namespace sd::ops; using namespace sd::graph; class ConstantShapeHelperTests : public testing::Test { -public: - + public: }; class ConstantHelperTests : public testing::Test { -public: - + public: }; class ConstantTadHelperTests : public testing::Test { -public: - + public: }; TEST_F(ConstantShapeHelperTests, test_cachedAmount_1) { - auto ttlBefore = ConstantShapeHelper::getInstance().totalCachedEntries(); + auto ttlBefore = ConstantShapeHelper::getInstance().totalCachedEntries(); - auto arrayA = NDArrayFactory::create('c', {7, 11, 17, 23, 31, 43}); + auto arrayA = NDArrayFactory::create('c', {7, 11, 17, 23, 31, 43}); - auto ttlMiddle = ConstantShapeHelper::getInstance().totalCachedEntries(); + auto ttlMiddle = ConstantShapeHelper::getInstance().totalCachedEntries(); - auto arrayB = NDArrayFactory::create('c', {7, 11, 17, 23, 31, 43}); + auto arrayB = NDArrayFactory::create('c', {7, 11, 17, 23, 31, 43}); - auto ttlAfter = ConstantShapeHelper::getInstance().totalCachedEntries(); + auto ttlAfter = ConstantShapeHelper::getInstance().totalCachedEntries(); - ASSERT_TRUE(ttlBefore <= ttlMiddle); - ASSERT_EQ(ttlMiddle, ttlAfter); + ASSERT_TRUE(ttlBefore <= ttlMiddle); + ASSERT_EQ(ttlMiddle, ttlAfter); } TEST_F(ConstantTadHelperTests, test_cachedAmount_1) { - auto arrayA = NDArrayFactory::create('c', {7, 11, 17, 23, 31, 43}); - auto ttlBefore = ConstantTadHelper::getInstance().totalCachedEntries(); + auto arrayA = NDArrayFactory::create('c', {7, 11, 17, 23, 31, 43}); + auto ttlBefore = ConstantTadHelper::getInstance().totalCachedEntries(); - auto packAA = ConstantTadHelper::getInstance().tadForDimensions(arrayA.shapeInfo(), {3, 4}); + auto packAA = ConstantTadHelper::getInstance().tadForDimensions( + arrayA.shapeInfo(), {3, 4}); - auto ttlMiddle = ConstantTadHelper::getInstance().totalCachedEntries(); + auto ttlMiddle = ConstantTadHelper::getInstance().totalCachedEntries(); - auto packAB = ConstantTadHelper::getInstance().tadForDimensions(arrayA.shapeInfo(), {3, 4}); + auto packAB = ConstantTadHelper::getInstance().tadForDimensions( + arrayA.shapeInfo(), {3, 4}); - auto ttlAfter = ConstantTadHelper::getInstance().totalCachedEntries(); + auto ttlAfter = ConstantTadHelper::getInstance().totalCachedEntries(); - ASSERT_TRUE(ttlBefore <= ttlMiddle); - ASSERT_EQ(ttlMiddle, ttlAfter); + ASSERT_TRUE(ttlBefore <= ttlMiddle); + ASSERT_EQ(ttlMiddle, ttlAfter); } TEST_F(ConstantShapeHelperTests, basic_test_1) { - auto ptr = ShapeBuilders::createShapeInfo(sd::DataType::BFLOAT16, 'f', {5, 10, 15}); - ShapeDescriptor descriptor(ptr); - ShapeDescriptor descriptor2(ptr); - - ASSERT_EQ(descriptor, descriptor2); + auto ptr = + ShapeBuilders::createShapeInfo(sd::DataType::BFLOAT16, 'f', {5, 10, 15}); + ShapeDescriptor descriptor(ptr); + ShapeDescriptor descriptor2(ptr); - ASSERT_EQ(1, descriptor.ews()); - ASSERT_EQ(3, descriptor.rank()); - ASSERT_EQ('f', descriptor.order()); - ASSERT_EQ(sd::DataType::BFLOAT16, descriptor.dataType()); - ASSERT_FALSE(descriptor.isEmpty()); + ASSERT_EQ(descriptor, descriptor2); - ASSERT_FALSE(ConstantShapeHelper::getInstance().checkBufferExistenceForShapeInfo(descriptor)); + ASSERT_EQ(1, descriptor.ews()); + ASSERT_EQ(3, descriptor.rank()); + ASSERT_EQ('f', descriptor.order()); + ASSERT_EQ(sd::DataType::BFLOAT16, descriptor.dataType()); + ASSERT_FALSE(descriptor.isEmpty()); - auto buffer = ConstantShapeHelper::getInstance().bufferForShapeInfo(descriptor); + ASSERT_FALSE( + ConstantShapeHelper::getInstance().checkBufferExistenceForShapeInfo( + descriptor)); - ASSERT_TRUE(ConstantShapeHelper::getInstance().checkBufferExistenceForShapeInfo(descriptor)); + auto buffer = + ConstantShapeHelper::getInstance().bufferForShapeInfo(descriptor); - auto buffer2 = ConstantShapeHelper::getInstance().bufferForShapeInfo(descriptor2); + ASSERT_TRUE( + ConstantShapeHelper::getInstance().checkBufferExistenceForShapeInfo( + descriptor)); + auto buffer2 = + ConstantShapeHelper::getInstance().bufferForShapeInfo(descriptor2); - ASSERT_TRUE(buffer.primary() != nullptr); - ASSERT_TRUE(buffer.primary() == buffer2.primary()); - ASSERT_TRUE(buffer.special() == buffer2.special()); + ASSERT_TRUE(buffer.primary() != nullptr); + ASSERT_TRUE(buffer.primary() == buffer2.primary()); + ASSERT_TRUE(buffer.special() == buffer2.special()); - delete []ptr; + delete[] ptr; } TEST_F(ConstantShapeHelperTests, stress_test_1) { - - for (auto x = 0; x < 1000; x++) { - auto ptr = ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', {5, x + 10, x + 1}); - ShapeDescriptor descriptor(ptr); - ConstantShapeHelper::getInstance().createShapeInfo(descriptor); - delete [] ptr; - } - ShapeDescriptor aShape(sd::DataType::FLOAT32, 'c', {(Nd4jLong)5, (Nd4jLong)382, (Nd4jLong)373}); -// nd4j_printf("%d\n", ConstantShapeHelper::getInstance().cachedEntriesForDevice(0)); - - auto timeStart = std::chrono::system_clock::now(); - ASSERT_TRUE(ConstantShapeHelper::getInstance().checkBufferExistenceForShapeInfo(aShape)); - auto timeEnd = std::chrono::system_clock::now(); - - auto outerTime = std::chrono::duration_cast(timeEnd - timeStart).count(); - nd4j_printf("Total time (us) %lld\n", outerTime); + for (auto x = 0; x < 1000; x++) { + auto ptr = ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', + {5, x + 10, x + 1}); + ShapeDescriptor descriptor(ptr); + ConstantShapeHelper::getInstance().createShapeInfo(descriptor); + delete[] ptr; + } + ShapeDescriptor aShape(sd::DataType::FLOAT32, 'c', + {(Nd4jLong)5, (Nd4jLong)382, (Nd4jLong)373}); + // nd4j_printf("%d\n", + // ConstantShapeHelper::getInstance().cachedEntriesForDevice(0)); + + auto timeStart = std::chrono::system_clock::now(); + ASSERT_TRUE( + ConstantShapeHelper::getInstance().checkBufferExistenceForShapeInfo( + aShape)); + auto timeEnd = std::chrono::system_clock::now(); + + auto outerTime = + std::chrono::duration_cast(timeEnd - timeStart) + .count(); + nd4j_printf("Total time (us) %lld\n", outerTime); } TEST_F(ConstantShapeHelperTests, basic_test_3) { - auto array = NDArrayFactory::create_('c', {128}); + auto array = NDArrayFactory::create_('c', {128}); - ASSERT_TRUE(array->shapeInfo() != nullptr); + ASSERT_TRUE(array->shapeInfo() != nullptr); #ifdef __CUDABLAS__ - ASSERT_TRUE(array->specialShapeInfo() != nullptr); + ASSERT_TRUE(array->specialShapeInfo() != nullptr); #endif - delete array; + delete array; } - TEST_F(ConstantShapeHelperTests, basic_test_4) { - auto array = NDArrayFactory::create_('c', {128, 256}); + auto array = NDArrayFactory::create_('c', {128, 256}); - auto dup = new NDArray(array->dup('f')); + auto dup = new NDArray(array->dup('f')); - ASSERT_TRUE(dup->shapeInfo() != nullptr); + ASSERT_TRUE(dup->shapeInfo() != nullptr); #ifdef __CUDABLAS__ - ASSERT_TRUE(dup->specialShapeInfo() != nullptr); - PointersManager manager(sd::LaunchContext ::defaultContext(), "test"); - // manager.printDevContentOnDev(dup->special(), shape::shapeInfoLength(2), 0); + ASSERT_TRUE(dup->specialShapeInfo() != nullptr); + PointersManager manager(sd::LaunchContext ::defaultContext(), "test"); + // manager.printDevContentOnDev(dup->special(), + // shape::shapeInfoLength(2), 0); #endif - delete array; - delete dup; + delete array; + delete dup; } - TEST_F(ConstantShapeHelperTests, basic_test_5) { + auto arrayA = NDArrayFactory::create(1); + auto arrayB = NDArrayFactory::create_('c', {128, 256}); - auto arrayA = NDArrayFactory::create(1); - auto arrayB = NDArrayFactory::create_('c', {128, 256}); - - //arrayA.printShapeInfo("A"); - //arrayB->printShapeInfo("B"); - ASSERT_EQ(0, arrayA.rankOf()); - ASSERT_EQ(2, arrayB->rankOf()); - ASSERT_NE(arrayA.dataType(), arrayB->dataType()); + // arrayA.printShapeInfo("A"); + // arrayB->printShapeInfo("B"); + ASSERT_EQ(0, arrayA.rankOf()); + ASSERT_EQ(2, arrayB->rankOf()); + ASSERT_NE(arrayA.dataType(), arrayB->dataType()); - delete arrayB; + delete arrayB; } TEST_F(ConstantShapeHelperTests, basic_test_6) { - ShapeDescriptor descriptorA(sd::DataType::INT32, 'c', {}); - ShapeDescriptor descriptorB(sd::DataType::FLOAT32, 'c', {10, 10}); + ShapeDescriptor descriptorA(sd::DataType::INT32, 'c', {}); + ShapeDescriptor descriptorB(sd::DataType::FLOAT32, 'c', {10, 10}); - // ASSERT_FALSE(descriptorA < descriptorB); - // ASSERT_TRUE(descriptorB < descriptorA); + // ASSERT_FALSE(descriptorA < descriptorB); + // ASSERT_TRUE(descriptorB < descriptorA); - ASSERT_TRUE(descriptorA < descriptorB); - ASSERT_FALSE(descriptorB < descriptorA); + ASSERT_TRUE(descriptorA < descriptorB); + ASSERT_FALSE(descriptorB < descriptorA); } TEST_F(ConstantShapeHelperTests, basic_test_7) { - auto array = NDArrayFactory::create_('c', {32, 256}); + auto array = NDArrayFactory::create_('c', {32, 256}); - IndicesList indices({NDIndex::all(), NDIndex::interval(0,1)}); - auto strided = array->subarray(indices); - strided.assign(1.0f); + IndicesList indices({NDIndex::all(), NDIndex::interval(0, 1)}); + auto strided = array->subarray(indices); + strided.assign(1.0f); - //strided->printIndexedBuffer("column"); + // strided->printIndexedBuffer("column"); - delete array; + delete array; } TEST_F(ConstantHelperTests, basic_test_1) { + ConstantDescriptor descriptor({1, 2, 3}); - ConstantDescriptor descriptor({1, 2, 3}); - - ConstantDataBuffer* fBuffer = ConstantHelper::getInstance().constantBuffer(descriptor, sd::DataType::FLOAT32); - auto fPtr = fBuffer->primaryAsT(); + ConstantDataBuffer* fBuffer = ConstantHelper::getInstance().constantBuffer( + descriptor, sd::DataType::FLOAT32); + auto fPtr = fBuffer->primaryAsT(); - ASSERT_NEAR(1.f, fPtr[0], 1e-5); - ASSERT_NEAR(2.f, fPtr[1], 1e-5); - ASSERT_NEAR(3.f, fPtr[2], 1e-5); + ASSERT_NEAR(1.f, fPtr[0], 1e-5); + ASSERT_NEAR(2.f, fPtr[1], 1e-5); + ASSERT_NEAR(3.f, fPtr[2], 1e-5); - auto iBuffer = ConstantHelper::getInstance().constantBuffer(descriptor, sd::DataType::INT32); - auto iPtr = iBuffer->primaryAsT(); + auto iBuffer = ConstantHelper::getInstance().constantBuffer( + descriptor, sd::DataType::INT32); + auto iPtr = iBuffer->primaryAsT(); - ASSERT_EQ(1, iPtr[0]); - ASSERT_EQ(2, iPtr[1]); - ASSERT_EQ(3, iPtr[2]); + ASSERT_EQ(1, iPtr[0]); + ASSERT_EQ(2, iPtr[1]); + ASSERT_EQ(3, iPtr[2]); } TEST_F(ConstantHelperTests, basic_test_2) { + double array[] = {1., 2., 3.}; + ConstantDescriptor descriptor(array, 3); - double array[] = {1., 2., 3.}; - ConstantDescriptor descriptor(array, 3); + ConstantDataBuffer* fBuffer = ConstantHelper::getInstance().constantBuffer( + descriptor, sd::DataType::FLOAT32); + auto fPtr = fBuffer->primaryAsT(); - ConstantDataBuffer* fBuffer = ConstantHelper::getInstance().constantBuffer(descriptor, sd::DataType::FLOAT32); - auto fPtr = fBuffer->primaryAsT(); + ASSERT_NEAR(1.f, fPtr[0], 1e-5); + ASSERT_NEAR(2.f, fPtr[1], 1e-5); + ASSERT_NEAR(3.f, fPtr[2], 1e-5); - ASSERT_NEAR(1.f, fPtr[0], 1e-5); - ASSERT_NEAR(2.f, fPtr[1], 1e-5); - ASSERT_NEAR(3.f, fPtr[2], 1e-5); + auto iBuffer = ConstantHelper::getInstance().constantBuffer( + descriptor, sd::DataType::INT32); + auto iPtr = iBuffer->primaryAsT(); - auto iBuffer = ConstantHelper::getInstance().constantBuffer(descriptor, sd::DataType::INT32); - auto iPtr = iBuffer->primaryAsT(); - - ASSERT_EQ(1, iPtr[0]); - ASSERT_EQ(2, iPtr[1]); - ASSERT_EQ(3, iPtr[2]); + ASSERT_EQ(1, iPtr[0]); + ASSERT_EQ(2, iPtr[1]); + ASSERT_EQ(3, iPtr[2]); } ////////////////////////////////////////////////////////////////////// TEST_F(ConstantShapeHelperTests, ShapeDescriptor_1) { + Nd4jLong shapeInfo1[] = {4, 2, 5, 5, 2, 25, 5, 1, 50, 8192, 0, 99}; + Nd4jLong shapeInfo2[] = {4, 2, 5, 5, 2, 50, 10, 2, 1, 8192, 1, 99}; - Nd4jLong shapeInfo1[] = {4, 2, 5, 5, 2, 25, 5, 1, 50, 8192, 0, 99}; - Nd4jLong shapeInfo2[] = {4, 2, 5, 5, 2, 50, 10, 2, 1, 8192, 1, 99}; - - ShapeDescriptor descr1(shapeInfo1); - ShapeDescriptor descr2(shapeInfo2); + ShapeDescriptor descr1(shapeInfo1); + ShapeDescriptor descr2(shapeInfo2); - ASSERT_FALSE(descr1 == descr2); + ASSERT_FALSE(descr1 == descr2); } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/ContextTests.cpp b/libnd4j/tests_cpu/layers_tests/ContextTests.cpp index 57d9ce88d94d..460de23abf23 100644 --- a/libnd4j/tests_cpu/layers_tests/ContextTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/ContextTests.cpp @@ -18,339 +18,352 @@ // Created by raver119 on 30.10.2017. // -#include "testlayers.h" #include +#include "testlayers.h" + using namespace sd; using namespace sd::ops; using namespace sd::graph; class ContextTests : public testing::Test { -public: - + public: }; - TEST_F(ContextTests, Basic_Test_1) { - VariableSpace variableSpace; + VariableSpace variableSpace; - auto _20 = NDArrayFactory::create_('c', {2, 2}); - auto _21 = NDArrayFactory::create_('c', {2, 2}); + auto _20 = NDArrayFactory::create('c', {2, 2}); + auto _21 = NDArrayFactory::create('c', {2, 2}); - _20->assign(1.0f); - _21->assign(2.0f); + _20.assign(1.0f); + _21.assign(2.0f); - variableSpace.putVariable(2, 0, _20); - variableSpace.putVariable(2, 1, _21); + variableSpace.putVariable(2, 0, _20); + variableSpace.putVariable(2, 1, _21); - Context block(1, &variableSpace); + Context block(1, &variableSpace); - block.pickInput(2, 0); - block.pickInput(2, 1); + block.pickInput(2, 0); + block.pickInput(2, 1); - ASSERT_EQ(2, block.inputs()->size()); - ASSERT_EQ(2, block.width()); + ASSERT_EQ(2, block.inputs().size()); + ASSERT_EQ(2, block.width()); - ASSERT_TRUE(variableSpace.hasVariable(2, 0)); - ASSERT_TRUE(variableSpace.hasVariable(2, 1)); + ASSERT_TRUE(variableSpace.hasVariable(2, 0)); + ASSERT_TRUE(variableSpace.hasVariable(2, 1)); - ASSERT_NEAR(1.0f, block.variable(0)->getNDArray()->meanNumber().e(0), 1e-5); - ASSERT_NEAR(2.0f, block.variable(1)->getNDArray()->meanNumber().e(0), 1e-5); + ASSERT_NEAR(1.0f, block.variable(0)->getNDArray()->meanNumber().e(0), + 1e-5); + ASSERT_NEAR(2.0f, block.variable(1)->getNDArray()->meanNumber().e(0), + 1e-5); } - TEST_F(ContextTests, Basic_Test_2) { - VariableSpace variableSpace; + VariableSpace variableSpace; - auto _20 = NDArrayFactory::create_('c', {2, 2}); - auto _21 = NDArrayFactory::create_('c', {2, 2}); + auto _20 = NDArrayFactory::create('c', {2, 2}); + auto _21 = NDArrayFactory::create('c', {2, 2}); - _20->assign(1.0f); - _21->assign(2.0f); + _20.assign(1.0f); + _21.assign(2.0f); - variableSpace.putVariable(-1, _20); - variableSpace.putVariable(-2, _21); + variableSpace.putVariable(-1, _20); + variableSpace.putVariable(-2, _21); - Context block(1, &variableSpace); + Context block(1, &variableSpace); - block.pickInput(-1); - block.pickInput(-2); + block.pickInput(-1); + block.pickInput(-2); - ASSERT_EQ(2, block.inputs()->size()); - ASSERT_EQ(2, block.width()); + ASSERT_EQ(2, block.inputs().size()); + ASSERT_EQ(2, block.width()); - ASSERT_TRUE(variableSpace.hasVariable(-1)); - ASSERT_TRUE(variableSpace.hasVariable(-2)); + ASSERT_TRUE(variableSpace.hasVariable(-1)); + ASSERT_TRUE(variableSpace.hasVariable(-2)); - ASSERT_NEAR(1.0f, block.variable(0)->getNDArray()->meanNumber().e(0), 1e-5); - ASSERT_NEAR(2.0f, block.variable(1)->getNDArray()->meanNumber().e(0), 1e-5); + ASSERT_NEAR(1.0f, block.variable(0)->getNDArray()->meanNumber().e(0), + 1e-5); + ASSERT_NEAR(2.0f, block.variable(1)->getNDArray()->meanNumber().e(0), + 1e-5); } - TEST_F(ContextTests, Basic_Test_3) { - VariableSpace variableSpace; + VariableSpace variableSpace; - Context ctx(1, &variableSpace); + Context ctx(1, &variableSpace); - auto _20 = NDArrayFactory::create_('c', {2, 2}); + auto _20 = NDArrayFactory::create('c', {2, 2}); - ctx.pushNDArrayToVariableSpace(1, 1, _20); + ctx.pushNDArrayToVariableSpace(1, 1, _20); - ASSERT_TRUE(variableSpace.hasVariable(1, 1)); + ASSERT_TRUE(variableSpace.hasVariable(1, 1)); } - TEST_F(ContextTests, Basic_Test_4) { - VariableSpace variableSpace; + VariableSpace variableSpace; - Context ctx(1, &variableSpace); + Context ctx(1, &variableSpace); - auto _20 = NDArrayFactory::create_('c', {2, 2}); - _20->linspace(1); + auto _20 = NDArrayFactory::create('c', {2, 2}); + _20.linspace(1); - auto _21 = NDArrayFactory::create_('c', {2, 2}); - _21->linspace(10); + auto _21 = NDArrayFactory::create('c', {2, 2}); + _21.linspace(10); - ctx.pushNDArrayToVariableSpace(1, 1, _20); + ctx.pushNDArrayToVariableSpace(1, 1, _20); - ASSERT_TRUE(variableSpace.hasVariable(1, 1)); + ASSERT_TRUE(variableSpace.hasVariable(1, 1)); - ctx.pushNDArrayToVariableSpace(1, 1, _21); + ctx.pushNDArrayToVariableSpace(1, 1, _21); - auto vA = ctx.variable(1, 1); + auto vA = ctx.variable(1, 1); - ASSERT_TRUE(vA->getNDArray()->equalsTo(_21)); + ASSERT_TRUE(vA->getNDArray()->equalsTo(_21)); } TEST_F(ContextTests, Basic_Test_5) { - VariableSpace variableSpace; - - Context ctx(1, &variableSpace); - - auto _20 = NDArrayFactory::create_('c', {2, 2}); - _20->linspace(1); + VariableSpace variableSpace; - auto exp = new NDArray(_20->dup()); + Context ctx(1, &variableSpace); - ctx.pushNDArrayToVariableSpace(1, 1, _20); + auto _20 = NDArrayFactory::create('c', {2, 2}); + _20.linspace(1); - ASSERT_TRUE(variableSpace.hasVariable(1, 1)); + auto exp = _20.dup(); - ctx.pushNDArrayToVariableSpace(1, 1, _20); + ctx.pushNDArrayToVariableSpace(1, 1, _20); - auto vA = ctx.variable(1, 1); + ASSERT_TRUE(variableSpace.hasVariable(1, 1)); - ASSERT_TRUE(vA->getNDArray() == _20); + ctx.pushNDArrayToVariableSpace(1, 1, _20); - ASSERT_TRUE(vA->getNDArray()->equalsTo(exp)); + auto vA = ctx.variable(1, 1); - delete exp; + ASSERT_TRUE(vA->getNDArray()->equalsTo(exp)); } - TEST_F(ContextTests, Basic_Test_6) { - VariableSpace variableSpace; + VariableSpace variableSpace; - Context ctx(1, &variableSpace); + Context ctx(1, &variableSpace); - auto v0 = ctx.ensureVariable(); - auto v1 = ctx.ensureVariable(1); + auto v0 = ctx.ensureVariable("", 1, 0); + auto v1 = ctx.ensureVariable("", 1, 1); - ASSERT_TRUE(variableSpace.hasVariable(1, 0)); - ASSERT_TRUE(variableSpace.hasVariable(1, 1)); + ASSERT_TRUE(variableSpace.hasVariable(1, 0)); + ASSERT_TRUE(variableSpace.hasVariable(1, 1)); - auto var0 = variableSpace.getVariable(1, 0); - auto var1 = variableSpace.getVariable(1, 1); + auto var0 = variableSpace.getVariable(1, 0); + auto var1 = variableSpace.getVariable(1, 1); - ASSERT_TRUE(v0 == var0); - ASSERT_TRUE(v1 == var1); + ASSERT_TRUE(v0 == var0); + ASSERT_TRUE(v1 == var1); } - TEST_F(ContextTests, Basic_Test_7) { - VariableSpace variableSpace; - - Context ctx(1, &variableSpace); + VariableSpace variableSpace; - auto v0 = ctx.ensureVariable(); - auto v1 = ctx.ensureVariable(1); + Context ctx(1, &variableSpace); - ASSERT_TRUE(variableSpace.hasVariable(1, 0)); - ASSERT_TRUE(variableSpace.hasVariable(1, 1)); + auto v0 = ctx.ensureVariable("", 1, 0); + auto v1 = ctx.ensureVariable("", 1, 1); - auto var0 = variableSpace.getVariable(1, 0); - auto var1 = variableSpace.getVariable(1, 1); + ASSERT_TRUE(variableSpace.hasVariable(1, 0)); + ASSERT_TRUE(variableSpace.hasVariable(1, 1)); - ASSERT_TRUE(v0 == var0); - ASSERT_TRUE(v1 == var1); + auto var0 = variableSpace.getVariable(1, 0); + auto var1 = variableSpace.getVariable(1, 1); + ASSERT_TRUE(v0 == var0); + ASSERT_TRUE(v1 == var1); - auto _10 = NDArrayFactory::create_('c', {2, 2}); - _10->linspace(1); + auto _10 = NDArrayFactory::create('c', {2, 2}); + _10.linspace(1); - auto _11 = NDArrayFactory::create_('c', {2, 2}); - _11->linspace(10); + auto _11 = NDArrayFactory::create('c', {2, 2}); + _11.linspace(10); - ctx.pushNDArrayToVariableSpace(1, 0, _10); - ctx.pushNDArrayToVariableSpace(1, 1, _11); + ctx.pushNDArrayToVariableSpace(1, 0, _10); + ctx.pushNDArrayToVariableSpace(1, 1, _11); - auto z0 = variableSpace.getVariable(1, 0); - auto z1 = variableSpace.getVariable(1, 1); + auto z0 = variableSpace.getVariable(1, 0); + auto z1 = variableSpace.getVariable(1, 1); - ASSERT_TRUE(v0 == z0); - ASSERT_TRUE(v1 == z1); + ASSERT_TRUE(v0 == z0); + ASSERT_TRUE(v1 == z1); } TEST_F(ContextTests, Basic_Test_8) { - VariableSpace variableSpace; + VariableSpace variableSpace; - Context ctx(1, &variableSpace); + Context ctx(1, &variableSpace); - auto _10 = NDArrayFactory::create_('c', {2, 2}); - _10->linspace(1); + auto _10 = NDArrayFactory::create('c', {2, 2}); + _10.linspace(1); - auto _11 = NDArrayFactory::create_('c', {2, 2}); - _11->linspace(10); + auto _11 = NDArrayFactory::create('c', {2, 2}); + _11.linspace(10); - ctx.pushNDArrayToVariableSpace(1, 0, _10); - ctx.pushNDArrayToVariableSpace(1, 1, _11); + ctx.pushNDArrayToVariableSpace(1, 0, _10); + ctx.pushNDArrayToVariableSpace(1, 1, _11); - auto z0 = variableSpace.getVariable(1, 0); - auto z1 = variableSpace.getVariable(1, 1); + auto z0 = variableSpace.getVariable(1, 0); + auto z1 = variableSpace.getVariable(1, 1); - auto v0 = ctx.ensureVariable(); - auto v1 = ctx.ensureVariable(1); + auto v0 = ctx.ensureVariable("", 1, 0); + auto v1 = ctx.ensureVariable("", 1, 1); - ASSERT_TRUE(v0 == z0); - ASSERT_TRUE(v1 == z1); + ASSERT_TRUE(v0 == z0); + ASSERT_TRUE(v1 == z1); } - TEST_F(ContextTests, Basic_Test_9) { - VariableSpace variableSpace; + VariableSpace variableSpace; - auto in = NDArrayFactory::create('c', {5, 5}); + auto in = NDArrayFactory::create('c', {5, 5}); - Context ctx(1, &variableSpace, true); - ctx.pushNDArrayToVariableSpace(1, 1, &in, false); + Context ctx(1, &variableSpace, true); + ctx.pushNDArrayToVariableSpace(1, 1, in); } TEST_F(ContextTests, Basic_Test_10) { - VariableSpace variableSpace; + VariableSpace variableSpace; - Context ctx(119, &variableSpace); + Context ctx(119, &variableSpace); } - TEST_F(ContextTests, Prototype_Test_1) { - ContextPrototype prototype(nullptr, 119, true); - prototype.pickInput(12, 3); - prototype.pickInput(12, 4); + ContextPrototype prototype(nullptr, 119, true); + prototype.pickInput(12, 3); + prototype.pickInput(12, 4); - prototype.getTArguments()->push_back(2.0); - prototype.getTArguments()->push_back(-2.0); + prototype.appendT(2.0); + prototype.appendT(-2.0); - prototype.getIArguments()->push_back(17); - prototype.getIArguments()->push_back(119); + prototype.appendI(17); + prototype.appendI(119); - Context ctx(&prototype, nullptr); + Context ctx(prototype, nullptr); - ASSERT_EQ(ctx.nodeId(), prototype.nodeId()); - ASSERT_EQ(ctx.isInplace(), prototype.isInplace()); + ASSERT_EQ(ctx.nodeId(), prototype.nodeId()); + ASSERT_EQ(ctx.isInplace(), prototype.isInplace()); - ASSERT_EQ(2, ctx.inputs()->size()); - ASSERT_EQ(2, ctx.getTArguments()->size()); - ASSERT_EQ(2, ctx.getIArguments()->size()); + ASSERT_EQ(2, ctx.inputs().size()); + ASSERT_EQ(2, ctx.getTArguments().size()); + ASSERT_EQ(2, ctx.getIArguments().size()); - ASSERT_EQ(2.0, ctx.getTArguments()->at(0)); - ASSERT_EQ(-2.0, ctx.getTArguments()->at(1)); + ASSERT_EQ(2.0, ctx.getTArguments().at(0)); + ASSERT_EQ(-2.0, ctx.getTArguments().at(1)); - ASSERT_EQ(17, ctx.getIArguments()->at(0)); - ASSERT_EQ(119, ctx.getIArguments()->at(1)); + ASSERT_EQ(17, ctx.getIArguments().at(0)); + ASSERT_EQ(119, ctx.getIArguments().at(1)); } - TEST_F(ContextTests, Prototype_Test_2) { - ContextPrototype prototype(nullptr, 119, false); - prototype.setOpNum(179); + ContextPrototype prototype(nullptr, 119, false); + prototype.setOpNum(179); - Context ctx(&prototype, nullptr); + Context ctx(prototype, nullptr); - ASSERT_EQ(ctx.isInplace(), prototype.isInplace()); - ASSERT_EQ(ctx.opNum(), prototype.opNum()); + ASSERT_EQ(ctx.isInplace(), prototype.isInplace()); + ASSERT_EQ(ctx.opNum(), prototype.opNum()); - ASSERT_EQ(0, ctx.inputs()->size()); - ASSERT_EQ(0, ctx.getTArguments()->size()); - ASSERT_EQ(0, ctx.getIArguments()->size()); + ASSERT_EQ(0, ctx.inputs().size()); + ASSERT_EQ(0, ctx.getTArguments().size()); + ASSERT_EQ(0, ctx.getIArguments().size()); } TEST_F(ContextTests, test_short_context_1) { - auto array0 = NDArrayFactory::create('c', {3, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); - auto array1 = NDArrayFactory::create('c', {3, 2}, {-1.f, -2.f, -3.f, -4.f, -5.f, -6.f}); - Context ctx(1); + auto array0 = NDArrayFactory::create('c', {3, 2}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); + auto array1 = NDArrayFactory::create( + 'c', {3, 2}, {-1.f, -2.f, -3.f, -4.f, -5.f, -6.f}); + Context ctx(1); - ctx.setInputArray(0, array0.buffer(), array0.shapeInfo(), array0.specialBuffer(), array0.specialShapeInfo()); - ctx.setInputArray(1, array1.buffer(), array1.shapeInfo(), array1.specialBuffer(), array1.specialShapeInfo()); + ctx.setInputArray(0, array0.buffer(), array0.shapeInfo(), + array0.specialBuffer(), array0.specialShapeInfo()); + ctx.setInputArray(1, array1.buffer(), array1.shapeInfo(), + array1.specialBuffer(), array1.specialShapeInfo()); - ASSERT_EQ(2, ctx.width()); + ASSERT_EQ(2, ctx.width()); - auto input0 = ctx.array(0); - ASSERT_TRUE(input0 != nullptr); + auto input0 = ctx.array(0); + ASSERT_TRUE(input0 != nullptr); - auto input1 = ctx.array(1); - ASSERT_TRUE(input1 != nullptr); + auto input1 = ctx.array(1); + ASSERT_TRUE(input1 != nullptr); - ASSERT_TRUE(input0->buffer() == array0.buffer()); - ASSERT_TRUE(input0->shapeInfo() == array0.shapeInfo()); + ASSERT_TRUE(input0->buffer() == array0.buffer()); + ASSERT_TRUE(input0->shapeInfo() == array0.shapeInfo()); - ASSERT_TRUE(input0->specialBuffer() == array0.specialBuffer()); - ASSERT_TRUE(input0->specialShapeInfo() == array0.specialShapeInfo()); + ASSERT_TRUE(input0->specialBuffer() == array0.specialBuffer()); + ASSERT_TRUE(input0->specialShapeInfo() == array0.specialShapeInfo()); - ASSERT_TRUE(input1->buffer() == array1.buffer()); - ASSERT_TRUE(input1->shapeInfo() == array1.shapeInfo()); + ASSERT_TRUE(input1->buffer() == array1.buffer()); + ASSERT_TRUE(input1->shapeInfo() == array1.shapeInfo()); - ASSERT_TRUE(input1->specialBuffer() == array1.specialBuffer()); - ASSERT_TRUE(input1->specialShapeInfo() == array1.specialShapeInfo()); + ASSERT_TRUE(input1->specialBuffer() == array1.specialBuffer()); + ASSERT_TRUE(input1->specialShapeInfo() == array1.specialShapeInfo()); } TEST_F(ContextTests, test_short_context_2) { - auto array0 = NDArrayFactory::create('c', {3, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); - auto array1 = NDArrayFactory::create('c', {3, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); - auto z = NDArrayFactory::create('c', {3, 2}); + auto array0 = NDArrayFactory::create('c', {3, 2}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); + auto array1 = NDArrayFactory::create('c', {3, 2}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); + auto z = NDArrayFactory::create('c', {3, 2}); - auto exp = NDArrayFactory::create('c', {3, 2}, {2.f, 4.f, 6.f, 8.f, 10.f, 12.f}); - Context ctx(1); + auto exp = NDArrayFactory::create('c', {3, 2}, + {2.f, 4.f, 6.f, 8.f, 10.f, 12.f}); + Context ctx(1); - ctx.setInputArray(0, array0.buffer(), array0.shapeInfo(), array0.specialBuffer(), array0.specialShapeInfo()); - ctx.setInputArray(1, array1.buffer(), array1.shapeInfo(), array1.specialBuffer(), array1.specialShapeInfo()); - ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo()); + ctx.setInputArray(0, array0.buffer(), array0.shapeInfo(), + array0.specialBuffer(), array0.specialShapeInfo()); + ctx.setInputArray(1, array1.buffer(), array1.shapeInfo(), + array1.specialBuffer(), array1.specialShapeInfo()); + ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo()); - ASSERT_EQ(2, ctx.width()); + ASSERT_EQ(2, ctx.width()); - sd::ops::add op; - op.execute(&ctx); + sd::ops::add op; + op.execute(&ctx); - ASSERT_EQ(exp, z); + ASSERT_EQ(exp, z); } TEST_F(ContextTests, test_short_context_3) { - auto array0 = NDArrayFactory::create('c', {3, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); - auto array1 = NDArrayFactory::create('c', {3, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); + auto array0 = NDArrayFactory::create('c', {3, 2}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); + auto array1 = NDArrayFactory::create('c', {3, 2}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); + + auto exp = NDArrayFactory::create('c', {3, 2}, + {2.f, 4.f, 6.f, 8.f, 10.f, 12.f}); + Context ctx(1); + + ctx.setInputArray(0, array0.buffer(), array0.shapeInfo(), + array0.specialBuffer(), array0.specialShapeInfo()); + ctx.setInputArray(1, array1.buffer(), array1.shapeInfo(), + array1.specialBuffer(), array1.specialShapeInfo()); - auto exp = NDArrayFactory::create('c', {3, 2}, {2.f, 4.f, 6.f, 8.f, 10.f, 12.f}); - Context ctx(1); + ASSERT_EQ(2, ctx.width()); - ctx.setInputArray(0, array0.buffer(), array0.shapeInfo(), array0.specialBuffer(), array0.specialShapeInfo()); - ctx.setInputArray(1, array1.buffer(), array1.shapeInfo(), array1.specialBuffer(), array1.specialShapeInfo()); + sd::ops::add op; + op.execute(&ctx); - ASSERT_EQ(2, ctx.width()); + ASSERT_EQ(1, ctx.fastpath_out().size()); - sd::ops::add op; - op.execute(&ctx); + auto z = ctx.fastpath_out()[0]; + + ASSERT_EQ(exp, *z); +} - ASSERT_EQ(1, ctx.fastpath_out().size()); +TEST_F(ContextTests, test_copy_1) { + ContextPrototype prototype(nullptr, 12); - auto z = ctx.fastpath_out()[0]; + auto copy = prototype; - ASSERT_EQ(exp, *z); + ASSERT_EQ(prototype.nodeId(), copy.nodeId()); } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/ConvolutionTests1.cpp b/libnd4j/tests_cpu/layers_tests/ConvolutionTests1.cpp index b879854583e0..087c12bfebc3 100644 --- a/libnd4j/tests_cpu/layers_tests/ConvolutionTests1.cpp +++ b/libnd4j/tests_cpu/layers_tests/ConvolutionTests1.cpp @@ -21,17 +21,18 @@ #ifndef LIBND4J_CONVOLUTIONTESTS1_H #define LIBND4J_CONVOLUTIONTESTS1_H -#include "testlayers.h" #include #include #include #include #include +#include +#include #include -#include #include -#include -#include +#include + +#include "testlayers.h" #ifdef HAVE_MKLDNN #include @@ -41,2817 +42,4853 @@ using namespace sd; using namespace sd::graph; class ConvolutionTests1 : public testing::Test { -public: - + public: }; template class TypedConvolutionTests1 : public testing::Test { -public: - + public: }; typedef ::testing::Types TestingTypes; -TYPED_TEST_CASE(TypedConvolutionTests1, TestingTypes); +TYPED_TEST_SUITE(TypedConvolutionTests1, TestingTypes); ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests1, conv2d_1) { - - int bS=1, iH=5,iW=4, iC=2,oC=3, kH=2,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - TypeParam _expB[]{664.0, 700.0, 736.0, 344.0, 808.0, 844.0, 880.0, 408.0, 952.0, 988.0, 1024.0, 472.0, 1096.0, 1132.0, 1168.0, 536.0, 466.0, 480.0, 494.0, 220.0, 1528.0, 1628.0, 1728.0, 856.0, 1928.0, 2028.0, 2128.0, 1048.0, 2328.0, 2428.0, 2528.0, 1240.0, 2728.0, 2828.0, 2928.0, 1432.0, 1346.0, 1392.0, 1438.0, 700.0, 2392.0, 2556.0, 2720.0, 1368.0, 3048.0, 3212.0, 3376.0, 1688.0, 3704.0, 3868.0, 4032.0, 2008.0, 4360.0, 4524.0, 4688.0, 2328.0, 2226.0, 2304.0, 2382.0, 1180.0}; - Nd4jLong _expS[]{4, 1, 3, 5, 4, 60, 20, 4, 1, typeid(TypeParam) == typeid(float) ? 8192 : 16384, 1, 99}; - auto input = NDArrayFactory::create_('c', {bS, iC, iH, iW}); - auto weights = NDArrayFactory::create_('c', {oC, iC, kH, kW}); - for (int e = 0; e < input->lengthOf(); e++) - input->p(e, e + 1); - - for (int e = 0; e < weights->lengthOf(); e++) - weights->p(e, e + 1); - weights->permutei({2,3,1,0}); - - // weights->printShapeInfo("weights"); - - ArrayOptions::setDataType(_expS, input->dataType()); - auto exp = new NDArray(_expB, _expS); - - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, input); - variableSpace->putVariable(-2, weights); - - auto block = new Context(1, variableSpace, false); // not-in-place - block->fillInputs({-1, -2}); - // 5,5 kernel - block->getIArguments()->push_back(kH); - block->getIArguments()->push_back(kW); - - // 1,1 stride - block->getIArguments()->push_back(sH); - block->getIArguments()->push_back(sW); - - // 0,0 padding - block->getIArguments()->push_back(pH); - block->getIArguments()->push_back(pW); - - // 1,1 dilation - block->getIArguments()->push_back(dH); - block->getIArguments()->push_back(dW); - - // same mode - block->getIArguments()->push_back(1); - - // is NHWC - block->getIArguments()->push_back(0); - - sd::ops::conv2d op; - - Nd4jStatus status = op.execute(block); - ASSERT_EQ(ND4J_STATUS_OK, status); - - auto res = variableSpace->getVariable(1)->getNDArray(); - - - // checking output shape - ASSERT_EQ(1, res->sizeAt(0)); - ASSERT_EQ(3, res->sizeAt(1)); - ASSERT_EQ(5, res->sizeAt(2)); - ASSERT_EQ(4, res->sizeAt(3)); - - // basically the same as above - ASSERT_TRUE(res->isSameShape(exp)); - // just for visual validation - // exp->printIndexedBuffer("Expected"); - // res->printIndexedBuffer("Actual "); - // res->printShapeInfo("Result shape"); - // final check - ASSERT_TRUE(res->equalsTo(exp)); - - delete block; - delete variableSpace; - delete exp; + int bS = 1, iH = 5, iW = 4, iC = 2, oC = 3, kH = 2, kW = 2, sH = 1, sW = 1, + pH = 0, pW = 0, dH = 1, dW = 1; + TypeParam _expB[]{ + 664.0, 700.0, 736.0, 344.0, 808.0, 844.0, 880.0, 408.0, 952.0, + 988.0, 1024.0, 472.0, 1096.0, 1132.0, 1168.0, 536.0, 466.0, 480.0, + 494.0, 220.0, 1528.0, 1628.0, 1728.0, 856.0, 1928.0, 2028.0, 2128.0, + 1048.0, 2328.0, 2428.0, 2528.0, 1240.0, 2728.0, 2828.0, 2928.0, 1432.0, + 1346.0, 1392.0, 1438.0, 700.0, 2392.0, 2556.0, 2720.0, 1368.0, 3048.0, + 3212.0, 3376.0, 1688.0, 3704.0, 3868.0, 4032.0, 2008.0, 4360.0, 4524.0, + 4688.0, 2328.0, 2226.0, 2304.0, 2382.0, 1180.0}; + Nd4jLong _expS[]{ + 4, 1, 3, 5, 4, + 60, 20, 4, 1, typeid(TypeParam) == typeid(float) ? 8192 : 16384, + 1, 99}; + auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); + auto weights = NDArrayFactory::create('c', {oC, iC, kH, kW}); + for (int e = 0; e < input.lengthOf(); e++) input.p(e, e + 1); + + for (int e = 0; e < weights.lengthOf(); e++) weights.p(e, e + 1); + weights.permutei({2, 3, 1, 0}); + + // weights->printShapeInfo("weights"); + + ArrayOptions::setDataType(_expS, input.dataType()); + auto exp = new NDArray(_expB, _expS); + + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, input); + variableSpace->putVariable(-2, weights); + + auto block = new Context(1, variableSpace, false); // not-in-place + block->fillInputs({-1, -2}); + // 5,5 kernel + block->appendI(kH); + block->appendI(kW); + + // 1,1 stride + block->appendI(sH); + block->appendI(sW); + + // 0,0 padding + block->appendI(pH); + block->appendI(pW); + + // 1,1 dilation + block->appendI(dH); + block->appendI(dW); + + // same mode + block->appendI(1); + + // is NHWC + block->appendI(0); + + sd::ops::conv2d op; + + Nd4jStatus status = op.execute(block); + ASSERT_EQ(ND4J_STATUS_OK, status); + + auto res = variableSpace->getVariable(1)->getNDArray(); + + // checking output shape + ASSERT_EQ(1, res->sizeAt(0)); + ASSERT_EQ(3, res->sizeAt(1)); + ASSERT_EQ(5, res->sizeAt(2)); + ASSERT_EQ(4, res->sizeAt(3)); + + // basically the same as above + ASSERT_TRUE(res->isSameShape(exp)); + // just for visual validation + // exp->printIndexedBuffer("Expected"); + // res->printIndexedBuffer("Actual "); + // res->printShapeInfo("Result shape"); + // final check + ASSERT_TRUE(res->equalsTo(exp)); + + delete block; + delete variableSpace; + delete exp; } ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests1, conv2d_2) { - auto input = NDArrayFactory::create('c', {1, 1, 1, 4}); - auto weights = NDArrayFactory::create('c', {1, 1, 1, 4}); - auto exp = NDArrayFactory::create('c', {1, 4, 1, 4}, {2.f, 4.f, 6.f, 8.f, 2.f, 4.f, 6.f, 8.f, 2.f, 4.f, 6.f, 8.f, 2.f, 4.f, 6.f, 8.f}); + auto input = NDArrayFactory::create('c', {1, 1, 1, 4}); + auto weights = NDArrayFactory::create('c', {1, 1, 1, 4}); + auto exp = NDArrayFactory::create( + 'c', {1, 4, 1, 4}, + {2.f, 4.f, 6.f, 8.f, 2.f, 4.f, 6.f, 8.f, 2.f, 4.f, 6.f, 8.f, 2.f, 4.f, + 6.f, 8.f}); - weights.assign(2.0); - input.linspace(1); + weights.assign(2.0); + input.linspace(1); - sd::ops::conv2d op; - auto result = op.evaluate({&input, &weights}, {}, {1, 1, 1, 1, 0, 0, 1, 1, 0, 0}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::conv2d op; + auto result = + op.evaluate({&input, &weights}, {}, {1, 1, 1, 1, 0, 0, 1, 1, 0, 0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests1, conv2d_3) { - - int bS=2, iH=4,iW=3, iC=4,oC=3, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oH=4,oW=3; - int paddingMode = 1; // 1-SAME, 0-VALID; - int dataFormat = 1; // 1-NHWC, 0-NCHW - - auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); - auto weights = NDArrayFactory::create('c', {kH, kW, iC, oC}); - auto bias = NDArrayFactory::create('c', {oC}, {1.f, 2.f, 3.f}); - - - auto expOutput = NDArrayFactory::create('c', {bS, oH, oW, oC},{ 152.f, 155.2f, 158.4f, 152.f, 155.2f, 158.4f, 66.4f, 68.f, 69.6f, 170.4f, 175.2f, 180.f, 170.4f, 175.2f, 180.f, 70.8f, 73.2f, 75.6f, - 170.4f, 175.2f, 180.f, 170.4f, 175.2f, 180.f, 70.8f, 73.2f, 75.6f, 75.2f, 78.4f, 81.6f, 75.2f, 78.4f, 81.6f, 28.f, 29.6f, 31.2f, - 152.f, 155.2f, 158.4f, 152.f, 155.2f, 158.4f, 66.4f, 68.f, 69.6f, 170.4f, 175.2f, 180.f, 170.4f, 175.2f, 180.f, 70.8f, 73.2f, 75.6f, - 170.4f, 175.2f, 180.f, 170.4f, 175.2f, 180.f, 70.8f, 73.2f, 75.6f, 75.2f, 78.4f, 81.6f, 75.2f, 78.4f, 81.6f, 28.f, 29.6f, 31.2f}); - input = 2.; - weights.linspace(0.1, 0.1); - - sd::ops::conv2d op; - auto results = op.evaluate({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expOutput.isSameShape(output)); - ASSERT_TRUE(expOutput.equalsTo(output)); - - + int bS = 2, iH = 4, iW = 3, iC = 4, oC = 3, kH = 3, kW = 2, sH = 1, sW = 1, + pH = 0, pW = 0, dH = 1, dW = 1; + int oH = 4, oW = 3; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NHWC, 0-NCHW + + auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); + auto weights = NDArrayFactory::create('c', {kH, kW, iC, oC}); + auto bias = NDArrayFactory::create('c', {oC}, {1.f, 2.f, 3.f}); + + auto expOutput = NDArrayFactory::create( + 'c', {bS, oH, oW, oC}, + {152.f, 155.2f, 158.4f, 152.f, 155.2f, 158.4f, 66.4f, 68.f, 69.6f, + 170.4f, 175.2f, 180.f, 170.4f, 175.2f, 180.f, 70.8f, 73.2f, 75.6f, + 170.4f, 175.2f, 180.f, 170.4f, 175.2f, 180.f, 70.8f, 73.2f, 75.6f, + 75.2f, 78.4f, 81.6f, 75.2f, 78.4f, 81.6f, 28.f, 29.6f, 31.2f, + 152.f, 155.2f, 158.4f, 152.f, 155.2f, 158.4f, 66.4f, 68.f, 69.6f, + 170.4f, 175.2f, 180.f, 170.4f, 175.2f, 180.f, 70.8f, 73.2f, 75.6f, + 170.4f, 175.2f, 180.f, 170.4f, 175.2f, 180.f, 70.8f, 73.2f, 75.6f, + 75.2f, 78.4f, 81.6f, 75.2f, 78.4f, 81.6f, 28.f, 29.6f, 31.2f}); + input = 2.; + weights.linspace(0.1, 0.1); + + sd::ops::conv2d op; + auto results = + op.evaluate({&input, &weights}, {}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests1, conv2d_4) { - - int bS=2, iH=4,iW=3, iC=4,oC=3, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oH=2,oW=2; - int paddingMode = 0; // 1-SAME, 0-VALID; - int dataFormat = 1; // 1-NHWC, 0-NCHW - - auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); - auto weights = NDArrayFactory::create('c', {kH, kW, iC, oC}); - auto bias = NDArrayFactory::create('c', {oC}, {1,2,3}); - - auto expOutput = NDArrayFactory::create('c', {bS, oH, oW, oC},{ 170.4f,175.20001f,180.f,170.4f,175.20001f,180.f,170.4f,175.20001f,180.f,170.4f,175.20001f,180.f,170.4f,175.20001f,180.f,170.4f,175.20001f,180.f,170.4f,175.20001f,180.f,170.4f,175.20001f,180.f}); - - input = 2.; - weights.linspace(0.1, 0.1); - - sd::ops::conv2d op; - auto results = op.evaluate({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expOutput.isSameShape(output)); - ASSERT_TRUE(expOutput.equalsTo(output)); + int bS = 2, iH = 4, iW = 3, iC = 4, oC = 3, kH = 3, kW = 2, sH = 1, sW = 1, + pH = 0, pW = 0, dH = 1, dW = 1; + int oH = 2, oW = 2; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NHWC, 0-NCHW + + auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); + auto weights = NDArrayFactory::create('c', {kH, kW, iC, oC}); + auto bias = NDArrayFactory::create('c', {oC}, {1, 2, 3}); + + auto expOutput = NDArrayFactory::create( + 'c', {bS, oH, oW, oC}, + {170.4f, 175.20001f, 180.f, 170.4f, 175.20001f, 180.f, + 170.4f, 175.20001f, 180.f, 170.4f, 175.20001f, 180.f, + 170.4f, 175.20001f, 180.f, 170.4f, 175.20001f, 180.f, + 170.4f, 175.20001f, 180.f, 170.4f, 175.20001f, 180.f}); + + input = 2.; + weights.linspace(0.1, 0.1); + + sd::ops::conv2d op; + auto results = + op.evaluate({&input, &weights}, {}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests1, conv2d_5) { - - int bS=2, iH=4,iW=3, iC=4,oC=3, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oH=2,oW=2; - int paddingMode = 0; // 1-SAME, 0-VALID; - int dataFormat = 0; // 1-NHWC, 0-NCHW - - auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); - auto weights = NDArrayFactory::create('c', {oC, iC, kH, kW}); - auto bias = NDArrayFactory::create('c', {oC}, {1,2,3}); - - auto expOutput = NDArrayFactory::create('c', {bS, oC, oH, oW}, {61.f, 61.f, 61.f, 61.f, 177.2f, 177.2f, 177.2f, 177.2f, 293.4f, 293.4f, 293.4f, 293.4f, 61.f, 61.f, 61.f, 61.f, 177.2f, 177.2f, 177.2f, 177.2f, 293.4f, 293.4f, 293.4f, 293.4f}); - - input = 2.; - weights.linspace(0.1, 0.1); - weights.permutei({2,3,1,0}); - - sd::ops::conv2d op; - auto results = op.evaluate({&input, &weights, &bias}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); - auto output = results.at(0); - - // output->printIndexedBuffer(); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expOutput.isSameShape(output)); - ASSERT_TRUE(expOutput.equalsTo(output)); + int bS = 2, iH = 4, iW = 3, iC = 4, oC = 3, kH = 3, kW = 2, sH = 1, sW = 1, + pH = 0, pW = 0, dH = 1, dW = 1; + int oH = 2, oW = 2; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NHWC, 0-NCHW + + auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); + auto weights = NDArrayFactory::create('c', {oC, iC, kH, kW}); + auto bias = NDArrayFactory::create('c', {oC}, {1, 2, 3}); + + auto expOutput = NDArrayFactory::create( + 'c', {bS, oC, oH, oW}, + {61.f, 61.f, 61.f, 61.f, 177.2f, 177.2f, 177.2f, 177.2f, + 293.4f, 293.4f, 293.4f, 293.4f, 61.f, 61.f, 61.f, 61.f, + 177.2f, 177.2f, 177.2f, 177.2f, 293.4f, 293.4f, 293.4f, 293.4f}); + + input = 2.; + weights.linspace(0.1, 0.1); + weights.permutei({2, 3, 1, 0}); + + sd::ops::conv2d op; + auto results = + op.evaluate({&input, &weights, &bias}, {}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, dataFormat}); + auto output = results.at(0); + + // output->printIndexedBuffer(); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests1, conv2d_6) { - auto input = NDArrayFactory::create('c', {54, 1, 12, 12}); - auto weights = NDArrayFactory::create('c', {1, 2, 12, 2}); + auto input = NDArrayFactory::create('c', {54, 1, 12, 12}); + auto weights = NDArrayFactory::create('c', {1, 2, 12, 2}); - sd::ops::conv2d op; - auto result = op.evaluate({&input, &weights}, {}, {-1,-1, 1,1, 0,0, 1,1, 1,1}); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::conv2d op; + auto result = + op.evaluate({&input, &weights}, {}, {-1, -1, 1, 1, 0, 0, 1, 1, 1, 1}); + ASSERT_EQ(Status::OK(), result.status()); } ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests1, conv2d_7) { + int bS = 1, iH = 256, iW = 256, iC = 1, oC = 1, kH = 4, kW = 3, sH = 1, + sW = 1, pH = 0, pW = 0, dH = 1, dW = 1; + // int oH=256,oW=256; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NHWC, 0-NCHW - int bS=1, iH=256,iW=256, iC=1,oC=1, kH=4,kW=3, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - // int oH=256,oW=256; - int paddingMode = 1; // 1-SAME, 0-VALID; - int dataFormat = 0; // 1-NHWC, 0-NCHW - - auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); - auto weights = NDArrayFactory::create('c', {kH, kW, iC, oC}); - - input = 5.; - weights = 3.; + auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); + auto weights = NDArrayFactory::create('c', {kH, kW, iC, oC}); - sd::ops::conv2d op; - auto results = op.evaluate({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); - auto* output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); + input = 5.; + weights = 3.; + sd::ops::conv2d op; + auto results = + op.evaluate({&input, &weights}, {}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, dataFormat}); + auto output = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests1, conv2d_8) { - - int bS=1, iH=6,iW=8, iC=2,oC=2, kH=2,kW=1, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oH=6,oW=8; - int paddingMode = 1; // 1-SAME, 0-VALID; - int dataFormat = 0; // 1-NHWC, 0-NCHW - - NDArray input('c', {bS, iC, iH, iW}, {0.679350, 0.355087, 0.842789, 0.200313, 0.701499, 0.310693, 0.447940, 0.938010, 0.326674, 0.151873, 0.383318, 0.782123, 0.198807, - 0.798564, 0.163263, 0.146968, 0.260897, 0.135058, 0.756209, 0.275454, 0.369088, 0.092826, 0.836492, 0.268413, 0.095062, 0.312795, 0.135918, 0.517544, 0.328703, - 0.061736, 0.396431, 0.248016, 0.548959, 0.115046, 0.814362, 0.721564, 0.404494, 0.299089, 0.403884, 0.988311, 0.022296, 0.927782, 0.318416, 0.068546, 0.284533, - 0.232720, 0.352142, 0.058909, 0.711221, 0.674457, 0.196946, 0.699497, 0.074322, 0.420425, 0.584263, 0.149574, 0.446406, 0.723072, 0.064481, 0.483078, 0.875996, - 0.569819, 0.445863, 0.527755, 0.016646, 0.753678, 0.140636, 0.754129, 0.161932, 0.775037, 0.332645, 0.117394, 0.017711, 0.608476, 0.525152, 0.917194, 0.849891, - 0.589423, 0.852278, 0.390636, 0.889683, 0.669445, 0.698873, 0.961480, 0.157401, 0.157364, 0.493520, 0.569937, 0.126832, 0.115728, 0.786368, 0.737939, 0.490079, 0.608414, 0.956500, 0.390098}); - - NDArray weights('c', {kH, kW, iC, oC}, {0.07581716775894165, 0.8706002235412598, 0.29345420002937317, 0.5281786322593689, 0.10540834069252014, 0.3663792014122009, 0.17209206521511078, 0.6257694959640503}); - NDArray bias('c', {1, oC}, {0.7414038777351379, 0.8980839848518372}); - - NDArray expOutput('c', {bS, oC, oH, oW}, {1.112878, 1.106691, 0.914598, 1.127438, 0.988108, 1.070572, 1.040759, 0.962728, 0.927537, 1.109045, 0.893301, 1.101278, 1.080314, - 1.112327, 1.030041, 0.955914, 0.779137, 1.110499, 0.944709, 1.195986, 0.997814, 1.083822, 1.090898, 0.889572, 0.964781, 1.071012, 1.111928, 1.291319, 1.085454, 0.977661, - 1.149068, 1.077099, 1.068283, 1.064290, 1.177125, 1.212480, 0.932593, 0.939493, 1.118576, 1.056927, 0.780314, 0.845707, 0.996308, 0.963152, 0.906792, 0.937590, 1.048791, - 0.860346, 2.264212, 2.071576, 1.916629, 2.030785, 2.169075, 2.039786, 1.935480, 2.177816, 1.524273, 1.933327, 1.630923, 2.406983, 1.770406, 2.413284, 1.790349, 1.476586, - 1.179925, 1.909109, 2.009143, 2.299778, 1.957207, 1.779718, 2.480604, 1.529086, 1.748063, 1.952856, 2.029487, 2.699131, 1.879842, 1.471205, 2.150177, 2.039078, 1.933456, - 1.764169, 2.584944, 2.521004, 1.744296, 1.707578, 2.237938, 2.325231, 0.984485, 1.766936, 1.590640, 1.347524, 1.404648, 1.422042, 1.709862, 1.155412}); - - sd::ops::conv2d op; - auto results = op.evaluate({&input, &weights, &bias}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); - auto output = results.at(0); - - // output->printBuffer(); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expOutput.isSameShape(output)); - ASSERT_TRUE(expOutput.equalsTo(output)); - + int bS = 1, iH = 6, iW = 8, iC = 2, oC = 2, kH = 2, kW = 1, sH = 1, sW = 1, + pH = 0, pW = 0, dH = 1, dW = 1; + int oH = 6, oW = 8; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NHWC, 0-NCHW + + NDArray input( + 'c', {bS, iC, iH, iW}, + {0.679350, 0.355087, 0.842789, 0.200313, 0.701499, 0.310693, 0.447940, + 0.938010, 0.326674, 0.151873, 0.383318, 0.782123, 0.198807, 0.798564, + 0.163263, 0.146968, 0.260897, 0.135058, 0.756209, 0.275454, 0.369088, + 0.092826, 0.836492, 0.268413, 0.095062, 0.312795, 0.135918, 0.517544, + 0.328703, 0.061736, 0.396431, 0.248016, 0.548959, 0.115046, 0.814362, + 0.721564, 0.404494, 0.299089, 0.403884, 0.988311, 0.022296, 0.927782, + 0.318416, 0.068546, 0.284533, 0.232720, 0.352142, 0.058909, 0.711221, + 0.674457, 0.196946, 0.699497, 0.074322, 0.420425, 0.584263, 0.149574, + 0.446406, 0.723072, 0.064481, 0.483078, 0.875996, 0.569819, 0.445863, + 0.527755, 0.016646, 0.753678, 0.140636, 0.754129, 0.161932, 0.775037, + 0.332645, 0.117394, 0.017711, 0.608476, 0.525152, 0.917194, 0.849891, + 0.589423, 0.852278, 0.390636, 0.889683, 0.669445, 0.698873, 0.961480, + 0.157401, 0.157364, 0.493520, 0.569937, 0.126832, 0.115728, 0.786368, + 0.737939, 0.490079, 0.608414, 0.956500, 0.390098}); + + NDArray weights('c', {kH, kW, iC, oC}, + {0.07581716775894165, 0.8706002235412598, 0.29345420002937317, + 0.5281786322593689, 0.10540834069252014, 0.3663792014122009, + 0.17209206521511078, 0.6257694959640503}); + NDArray bias('c', {1, oC}, {0.7414038777351379, 0.8980839848518372}); + + NDArray expOutput( + 'c', {bS, oC, oH, oW}, + {1.112878, 1.106691, 0.914598, 1.127438, 0.988108, 1.070572, 1.040759, + 0.962728, 0.927537, 1.109045, 0.893301, 1.101278, 1.080314, 1.112327, + 1.030041, 0.955914, 0.779137, 1.110499, 0.944709, 1.195986, 0.997814, + 1.083822, 1.090898, 0.889572, 0.964781, 1.071012, 1.111928, 1.291319, + 1.085454, 0.977661, 1.149068, 1.077099, 1.068283, 1.064290, 1.177125, + 1.212480, 0.932593, 0.939493, 1.118576, 1.056927, 0.780314, 0.845707, + 0.996308, 0.963152, 0.906792, 0.937590, 1.048791, 0.860346, 2.264212, + 2.071576, 1.916629, 2.030785, 2.169075, 2.039786, 1.935480, 2.177816, + 1.524273, 1.933327, 1.630923, 2.406983, 1.770406, 2.413284, 1.790349, + 1.476586, 1.179925, 1.909109, 2.009143, 2.299778, 1.957207, 1.779718, + 2.480604, 1.529086, 1.748063, 1.952856, 2.029487, 2.699131, 1.879842, + 1.471205, 2.150177, 2.039078, 1.933456, 1.764169, 2.584944, 2.521004, + 1.744296, 1.707578, 2.237938, 2.325231, 0.984485, 1.766936, 1.590640, + 1.347524, 1.404648, 1.422042, 1.709862, 1.155412}); + + sd::ops::conv2d op; + auto results = + op.evaluate({&input, &weights, &bias}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, dataFormat}); + auto output = results.at(0); + + // output->printBuffer(); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests1, conv2d_9) { - - int bS=2, iH=4,iW=3, iC=4,oC=3, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oH=2,oW=2; - int paddingMode = 0; // 1-SAME, 0-VALID; - int dataFormat = 0; // 1-NHWC, 0-NCHW - int wFormat = 1; // 0-[kH, kW, iC, oC], 1-[oC, iC, kH, kW], 2-[oC, kH, kW, iC] - - NDArray input('c', {bS, iC, iH, iW}, sd::DataType::FLOAT32); - NDArray weights('c', {oC, iC, kH, kW}, {-3., -1.8, -0.6, 0.6, 1.8, 3., -2.7, -1.5, -0.3, 0.9, 2.1, 3.3, -2.4, -1.2, 0., 1.2, 2.4, 3.6, -2.1, -0.9, 0.3, 1.5, - 2.7, 3.9, -2.9, -1.7, -0.5, 0.7, 1.9, 3.1, -2.6, -1.4, -0.2, 1., 2.2, 3.4, -2.3, -1.1, 0.1, 1.3, 2.5, 3.7, -2., -0.8, 0.4, 1.6, - 2.8, 4., -2.8, -1.6, -0.4, 0.8, 2., 3.2, -2.5, -1.3, -0.1, 1.1, 2.3, 3.5, -2.2, -1., 0.2, 1.4, 2.6, 3.8, -1.9, -0.7, 0.5, 1.7, 2.9, 4.1}, sd::DataType::FLOAT32); - NDArray bias('c', {oC}, {-1,2,0.5}, sd::DataType::FLOAT32); - - NDArray expOutput('c', {bS, oC, oH, oW}, {37.699997, 32.300041, 21.499989, 16.100004, 74.900024, 68.300003, 55.100006, 48.499969, 107.599983, 99.799988, - 84.200005, 76.400009, -221.5, -226.899994, -237.699997, -243.099991, -241.899994, -248.5, -261.700012, -268.299988, - -266.799988, -274.600006, -290.200012, -298.}, sd::DataType::FLOAT32); - - input.linspace(25,-0.5); - - sd::ops::conv2d op; - auto results = op.evaluate({&input, &weights, &bias}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat, wFormat}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expOutput.isSameShape(output)); - ASSERT_TRUE(expOutput.equalsTo(output)); + int bS = 2, iH = 4, iW = 3, iC = 4, oC = 3, kH = 3, kW = 2, sH = 1, sW = 1, + pH = 0, pW = 0, dH = 1, dW = 1; + int oH = 2, oW = 2; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NHWC, 0-NCHW + int wFormat = + 1; // 0-[kH, kW, iC, oC], 1-[oC, iC, kH, kW], 2-[oC, kH, kW, iC] + + NDArray input('c', {bS, iC, iH, iW}, sd::DataType::FLOAT32); + NDArray weights( + 'c', {oC, iC, kH, kW}, + {-3., -1.8, -0.6, 0.6, 1.8, 3., -2.7, -1.5, -0.3, 0.9, 2.1, 3.3, + -2.4, -1.2, 0., 1.2, 2.4, 3.6, -2.1, -0.9, 0.3, 1.5, 2.7, 3.9, + -2.9, -1.7, -0.5, 0.7, 1.9, 3.1, -2.6, -1.4, -0.2, 1., 2.2, 3.4, + -2.3, -1.1, 0.1, 1.3, 2.5, 3.7, -2., -0.8, 0.4, 1.6, 2.8, 4., + -2.8, -1.6, -0.4, 0.8, 2., 3.2, -2.5, -1.3, -0.1, 1.1, 2.3, 3.5, + -2.2, -1., 0.2, 1.4, 2.6, 3.8, -1.9, -0.7, 0.5, 1.7, 2.9, 4.1}, + sd::DataType::FLOAT32); + NDArray bias('c', {oC}, {-1, 2, 0.5}, sd::DataType::FLOAT32); + + NDArray expOutput( + 'c', {bS, oC, oH, oW}, + {37.699997, 32.300041, 21.499989, 16.100004, 74.900024, + 68.300003, 55.100006, 48.499969, 107.599983, 99.799988, + 84.200005, 76.400009, -221.5, -226.899994, -237.699997, + -243.099991, -241.899994, -248.5, -261.700012, -268.299988, + -266.799988, -274.600006, -290.200012, -298.}, + sd::DataType::FLOAT32); + + input.linspace(25, -0.5); + + sd::ops::conv2d op; + auto results = op.evaluate( + {&input, &weights, &bias}, {}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, dataFormat, wFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests1, conv2d_10) { - - int bS=2, iH=4,iW=3, iC=4,oC=3, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oH=4,oW=3; - int paddingMode = 1; // 1-SAME, 0-VALID; - int dataFormat = 1; // 1-NHWC, 0-NCHW - int wFormat = 2; // 0-[kH, kW, iC, oC], 1-[oC, iC, kH, kW], 2-[oC, kH, kW, iC] - - NDArray input('c', {bS, iH, iW, iC}, sd::DataType::FLOAT32); - NDArray weights('c', {oC, kH, kW, iC}, {-3., -2.7, -2.4, -2.1, -1.8, -1.5, -1.2, -0.9, -0.6, -0.3, 0., 0.3, 0.6, 0.9, 1.2, 1.5, 1.8, 2.1, 2.4, 2.7, 3., 3.3, - 3.6, 3.9, -2.9, -2.6, -2.3, -2., -1.7, -1.4, -1.1, -0.8, -0.5, -0.2, 0.1, 0.4, 0.7, 1., 1.3, 1.6, 1.9, 2.2, 2.5, 2.8, - 3.1, 3.4, 3.7, 4., -2.8, -2.5, -2.2, -1.9, -1.6, -1.3, -1., -0.7, -0.4, -0.1, 0.2, 0.5, 0.8, 1.1, 1.4, 1.7, 2., 2.3, 2.6, - 2.9, 3.2, 3.5, 3.8, 4.1}, sd::DataType::FLOAT32); - NDArray bias('c', {oC}, {-1,2,0.5}, sd::DataType::FLOAT32); - - NDArray expOutput('c', {bS, oH, oW, oC}, {463.400055, 498.800018, 529.700012, 410.600006, 442.799988, 470.500031, 113.600006, 130.400009, 142.699982, - -63.999958, -19.600082, 20.300007, -85.600052, -45.999939, -10.899940, -144.100021, -124., -108.399994, -128.799988, -98.799973, -73.300011, - -150.400009, -125.200012, -104.500008, -133.300003, -120.399994, -112.000008, -170.199997, -154., -142.299988, -146.200012, -133.199997, -124.699997, - -88.000008, -80.800003, -78.099991, -170.200012, -173.199997, -180.699982, -223., -229.199997, -239.900009, -88., -90.400002, -97.300003, -323.200012, - -336.399994, -354.100037, -344.800018, -362.799988, -385.299957, -100.900002, -109.600006, -122.800003, -388.000031, -415.599976, -447.700012, -409.599976, - -442., -478.900024, -90.099991, -105.999992, -126.399994, 117.800003, 95.599991, 68.899994, 141.799988, 116.399994, 86.5, 171.200012, 159.200012, 142.699997}, sd::DataType::FLOAT32); - - input.linspace(25,-0.5); - - sd::ops::conv2d op; - auto results = op.evaluate({&input, &weights, &bias}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat, wFormat}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expOutput.isSameShape(output)); - ASSERT_TRUE(expOutput.equalsTo(output)); + int bS = 2, iH = 4, iW = 3, iC = 4, oC = 3, kH = 3, kW = 2, sH = 1, sW = 1, + pH = 0, pW = 0, dH = 1, dW = 1; + int oH = 4, oW = 3; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NHWC, 0-NCHW + int wFormat = + 2; // 0-[kH, kW, iC, oC], 1-[oC, iC, kH, kW], 2-[oC, kH, kW, iC] + + NDArray input('c', {bS, iH, iW, iC}, sd::DataType::FLOAT32); + NDArray weights( + 'c', {oC, kH, kW, iC}, + {-3., -2.7, -2.4, -2.1, -1.8, -1.5, -1.2, -0.9, -0.6, -0.3, 0., 0.3, + 0.6, 0.9, 1.2, 1.5, 1.8, 2.1, 2.4, 2.7, 3., 3.3, 3.6, 3.9, + -2.9, -2.6, -2.3, -2., -1.7, -1.4, -1.1, -0.8, -0.5, -0.2, 0.1, 0.4, + 0.7, 1., 1.3, 1.6, 1.9, 2.2, 2.5, 2.8, 3.1, 3.4, 3.7, 4., + -2.8, -2.5, -2.2, -1.9, -1.6, -1.3, -1., -0.7, -0.4, -0.1, 0.2, 0.5, + 0.8, 1.1, 1.4, 1.7, 2., 2.3, 2.6, 2.9, 3.2, 3.5, 3.8, 4.1}, + sd::DataType::FLOAT32); + NDArray bias('c', {oC}, {-1, 2, 0.5}, sd::DataType::FLOAT32); + + NDArray expOutput( + 'c', {bS, oH, oW, oC}, + {463.400055, 498.800018, 529.700012, 410.600006, 442.799988, + 470.500031, 113.600006, 130.400009, 142.699982, -63.999958, + -19.600082, 20.300007, -85.600052, -45.999939, -10.899940, + -144.100021, -124., -108.399994, -128.799988, -98.799973, + -73.300011, -150.400009, -125.200012, -104.500008, -133.300003, + -120.399994, -112.000008, -170.199997, -154., -142.299988, + -146.200012, -133.199997, -124.699997, -88.000008, -80.800003, + -78.099991, -170.200012, -173.199997, -180.699982, -223., + -229.199997, -239.900009, -88., -90.400002, -97.300003, + -323.200012, -336.399994, -354.100037, -344.800018, -362.799988, + -385.299957, -100.900002, -109.600006, -122.800003, -388.000031, + -415.599976, -447.700012, -409.599976, -442., -478.900024, + -90.099991, -105.999992, -126.399994, 117.800003, 95.599991, + 68.899994, 141.799988, 116.399994, 86.5, 171.200012, + 159.200012, 142.699997}, + sd::DataType::FLOAT32); + + input.linspace(25, -0.5); + + sd::ops::conv2d op; + auto results = op.evaluate( + {&input, &weights, &bias}, {}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, dataFormat, wFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests1, sconv2d_1) { - float _expB[] = {10025.0f, 10350.0f, 10675.0f, 11000.0f, 11325.0f, 11650.0f, 13275.0f, 13600.0f, 13925.0f, 14250.0f, 14575.0f, 14900.0f, 16525.0f, 16850.0f, 17175.0f, 17500.0f, 17825.0f, 18150.0f, 19775.0f, 20100.0f, 20425.0f, 20750.0f, 21075.0f, 21400.0f, 23025.0f, 23350.0f, 23675.0f, 24000.0f, 24325.0f, 24650.0f, 26275.0f, 26600.0f, 26925.0f, 27250.0f, 27575.0f, 27900.0f, 38775.0f, 40350.0f, 41925.0f, 43500.0f, 45075.0f, 46650.0f, 54525.0f, 56100.0f, 57675.0f, 59250.0f, 60825.0f, 62400.0f, 70275.0f, 71850.0f, 73425.0f, 75000.0f, 76575.0f, 78150.0f, 86025.0f, 87600.0f, 89175.0f, 90750.0f, 92325.0f, 93900.0f, 101775.0f, 103350.0f, 104925.0f, 106500.0f, 108075.0f, 109650.0f, 117525.0f, 119100.0f, 120675.0f, 122250.0f, 123825.0f, 125400.0f, 67525.0f, 70350.0f, 73175.0f, 76000.0f, 78825.0f, 81650.0f, 95775.0f, 98600.0f, 101425.0f, 104250.0f, 107075.0f, 109900.0f, 124025.0f, 126850.0f, 129675.0f, 132500.0f, 135325.0f, 138150.0f, 152275.0f, 155100.0f, 157925.0f, 160750.0f, 163575.0f, 166400.0f, 180525.0f, 183350.0f, 186175.0f, 189000.0f, 191825.0f, 194650.0f, 208775.0f, 211600.0f, 214425.0f, 217250.0f, 220075.0f, 222900.0f, 119400.0f, 120350.0f, 121300.0f, 122250.0f, 123200.0f, 124150.0f, 128900.0f, 129850.0f, 130800.0f, 131750.0f, 132700.0f, 133650.0f, 138400.0f, 139350.0f, 140300.0f, 141250.0f, 142200.0f, 143150.0f, 147900.0f, 148850.0f, 149800.0f, 150750.0f, 151700.0f, 152650.0f, 157400.0f, 158350.0f, 159300.0f, 160250.0f, 161200.0f, 162150.0f, 166900.0f, 167850.0f, 168800.0f, 169750.0f, 170700.0f, 171650.0f, 273150.0f, 275350.0f, 277550.0f, 279750.0f, 281950.0f, 284150.0f, 295150.0f, 297350.0f, 299550.0f, 301750.0f, 303950.0f, 306150.0f, 317150.0f, 319350.0f, 321550.0f, 323750.0f, 325950.0f, 328150.0f, 339150.0f, 341350.0f, 343550.0f, 345750.0f, 347950.0f, 350150.0f, 361150.0f, 363350.0f, 365550.0f, 367750.0f, 369950.0f, 372150.0f, 383150.0f, 385350.0f, 387550.0f, 389750.0f, 391950.0f, 394150.0f, 426900.0f, 430350.0f, 433800.0f, 437250.0f, 440700.0f, 444150.0f, 461400.0f, 464850.0f, 468300.0f, 471750.0f, 475200.0f, 478650.0f, 495900.0f, 499350.0f, 502800.0f, 506250.0f, 509700.0f, 513150.0f, 530400.0f, 533850.0f, 537300.0f, 540750.0f, 544200.0f, 547650.0f, 564900.0f, 568350.0f, 571800.0f, 575250.0f, 578700.0f, 582150.0f, 599400.0f, 602850.0f, 606300.0f, 609750.0f, 613200.0f, 616650.0f, 75025.0f, 75350.0f, 75675.0f, 76000.0f, 76325.0f, 76650.0f, 78275.0f, 78600.0f, 78925.0f, 79250.0f, 79575.0f, 79900.0f, 81525.0f, 81850.0f, 82175.0f, 82500.0f, 82825.0f, 83150.0f, 84775.0f, 85100.0f, 85425.0f, 85750.0f, 86075.0f, 86400.0f, 88025.0f, 88350.0f, 88675.0f, 89000.0f, 89325.0f, 89650.0f, 91275.0f, 91600.0f, 91925.0f, 92250.0f, 92575.0f, 92900.0f, 353775.0f, 355350.0f, 356925.0f, 358500.0f, 360075.0f, 361650.0f, 369525.0f, 371100.0f, 372675.0f, 374250.0f, 375825.0f, 377400.0f, 385275.0f, 386850.0f, 388425.0f, 390000.0f, 391575.0f, 393150.0f, 401025.0f, 402600.0f, 404175.0f, 405750.0f, 407325.0f, 408900.0f, 416775.0f, 418350.0f, 419925.0f, 421500.0f, 423075.0f, 424650.0f, 432525.0f, 434100.0f, 435675.0f, 437250.0f, 438825.0f, 440400.0f, 632525.0f, 635350.0f, 638175.0f, 641000.0f, 643825.0f, 646650.0f, 660775.0f, 663600.0f, 666425.0f, 669250.0f, 672075.0f, 674900.0f, 689025.0f, 691850.0f, 694675.0f, 697500.0f, 700325.0f, 703150.0f, 717275.0f, 720100.0f, 722925.0f, 725750.0f, 728575.0f, 731400.0f, 745525.0f, 748350.0f, 751175.0f, 754000.0f, 756825.0f, 759650.0f, 773775.0f, 776600.0f, 779425.0f, 782250.0f, 785075.0f, 787900.0f, 309400.0f, 310350.0f, 311300.0f, 312250.0f, 313200.0f, 314150.0f, 318900.0f, 319850.0f, 320800.0f, 321750.0f, 322700.0f, 323650.0f, 328400.0f, 329350.0f, 330300.0f, 331250.0f, 332200.0f, 333150.0f, 337900.0f, 338850.0f, 339800.0f, 340750.0f, 341700.0f, 342650.0f, 347400.0f, 348350.0f, 349300.0f, 350250.0f, 351200.0f, 352150.0f, 356900.0f, 357850.0f, 358800.0f, 359750.0f, 360700.0f, 361650.0f, 713150.0f, 715350.0f, 717550.0f, 719750.0f, 721950.0f, 724150.0f, 735150.0f, 737350.0f, 739550.0f, 741750.0f, 743950.0f, 746150.0f, 757150.0f, 759350.0f, 761550.0f, 763750.0f, 765950.0f, 768150.0f, 779150.0f, 781350.0f, 783550.0f, 785750.0f, 787950.0f, 790150.0f, 801150.0f, 803350.0f, 805550.0f, 807750.0f, 809950.0f, 812150.0f, 823150.0f, 825350.0f, 827550.0f, 829750.0f, 831950.0f, 834150.0f, 1116900.0f, 1120350.0f, 1123800.0f, 1127250.0f, 1130700.0f, 1134150.0f, 1151400.0f, 1154850.0f, 1158300.0f, 1161750.0f, 1165200.0f, 1168650.0f, 1185900.0f, 1189350.0f, 1192800.0f, 1196250.0f, 1199700.0f, 1203150.0f, 1220400.0f, 1223850.0f, 1227300.0f, 1230750.0f, 1234200.0f, 1237650.0f, 1254900.0f, 1258350.0f, 1261800.0f, 1265250.0f, 1268700.0f, 1272150.0f, 1289400.0f, 1292850.0f, 1296300.0f, 1299750.0f, 1303200.0f, 1306650.0f,}; - Nd4jLong _expS[] = {4, 2, 6, 6, 6, 144, 36, 6, 1, 8192, 1, 99}; - NDArray exp(_expB, _expS); - - int sY = 1; - int sX = 1; - int pY = 0; - int pX = 0; - int iC = 2; - int oC = 3; - int kY = 5; - int kX = 5; - int iY = 10; - int iX = 10; - int B = 2; - - auto input = NDArrayFactory::create_('c', {B, iC, iY, iX}); - for (int e = 0; e < input->lengthOf(); e++) - input->p(e, e+1); - - auto weights = NDArrayFactory::create_('c', {oC, iC, kY, kX}); - for (int e = 0; e < weights->lengthOf(); e++) - weights->p(e, e+1); - weights->permutei({2,3,1,0}); - - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, input); - variableSpace->putVariable(-2, weights); - - auto block = new Context(1, variableSpace, false); - block->fillInputs({-1, -2}); - - block->getIArguments()->push_back(kY); - block->getIArguments()->push_back(kX); - - block->getIArguments()->push_back(sY); - block->getIArguments()->push_back(sX); - - block->getIArguments()->push_back(pY); - block->getIArguments()->push_back(pX); - - // dilation - block->getIArguments()->push_back(1); - block->getIArguments()->push_back(1); - - // NOT same mode - block->getIArguments()->push_back(0); - - sd::ops::sconv2d op; - - Nd4jStatus status = op.execute(block); - - ASSERT_EQ(ND4J_STATUS_OK, status); - auto output = variableSpace->getVariable(1)->getNDArray(); - - //exp.printShapeInfo("Expected shape"); - //output->printShapeInfo("Result shape"); - ASSERT_TRUE(exp.isSameShape(output)); - - //exp.printBuffer("Expctd buffer"); - //output->printBuffer("Result buffer"); - ASSERT_TRUE(exp.equalsTo(output)); - - delete block; - delete variableSpace; + float _expB[] = { + 10025.0f, 10350.0f, 10675.0f, 11000.0f, 11325.0f, 11650.0f, + 13275.0f, 13600.0f, 13925.0f, 14250.0f, 14575.0f, 14900.0f, + 16525.0f, 16850.0f, 17175.0f, 17500.0f, 17825.0f, 18150.0f, + 19775.0f, 20100.0f, 20425.0f, 20750.0f, 21075.0f, 21400.0f, + 23025.0f, 23350.0f, 23675.0f, 24000.0f, 24325.0f, 24650.0f, + 26275.0f, 26600.0f, 26925.0f, 27250.0f, 27575.0f, 27900.0f, + 38775.0f, 40350.0f, 41925.0f, 43500.0f, 45075.0f, 46650.0f, + 54525.0f, 56100.0f, 57675.0f, 59250.0f, 60825.0f, 62400.0f, + 70275.0f, 71850.0f, 73425.0f, 75000.0f, 76575.0f, 78150.0f, + 86025.0f, 87600.0f, 89175.0f, 90750.0f, 92325.0f, 93900.0f, + 101775.0f, 103350.0f, 104925.0f, 106500.0f, 108075.0f, 109650.0f, + 117525.0f, 119100.0f, 120675.0f, 122250.0f, 123825.0f, 125400.0f, + 67525.0f, 70350.0f, 73175.0f, 76000.0f, 78825.0f, 81650.0f, + 95775.0f, 98600.0f, 101425.0f, 104250.0f, 107075.0f, 109900.0f, + 124025.0f, 126850.0f, 129675.0f, 132500.0f, 135325.0f, 138150.0f, + 152275.0f, 155100.0f, 157925.0f, 160750.0f, 163575.0f, 166400.0f, + 180525.0f, 183350.0f, 186175.0f, 189000.0f, 191825.0f, 194650.0f, + 208775.0f, 211600.0f, 214425.0f, 217250.0f, 220075.0f, 222900.0f, + 119400.0f, 120350.0f, 121300.0f, 122250.0f, 123200.0f, 124150.0f, + 128900.0f, 129850.0f, 130800.0f, 131750.0f, 132700.0f, 133650.0f, + 138400.0f, 139350.0f, 140300.0f, 141250.0f, 142200.0f, 143150.0f, + 147900.0f, 148850.0f, 149800.0f, 150750.0f, 151700.0f, 152650.0f, + 157400.0f, 158350.0f, 159300.0f, 160250.0f, 161200.0f, 162150.0f, + 166900.0f, 167850.0f, 168800.0f, 169750.0f, 170700.0f, 171650.0f, + 273150.0f, 275350.0f, 277550.0f, 279750.0f, 281950.0f, 284150.0f, + 295150.0f, 297350.0f, 299550.0f, 301750.0f, 303950.0f, 306150.0f, + 317150.0f, 319350.0f, 321550.0f, 323750.0f, 325950.0f, 328150.0f, + 339150.0f, 341350.0f, 343550.0f, 345750.0f, 347950.0f, 350150.0f, + 361150.0f, 363350.0f, 365550.0f, 367750.0f, 369950.0f, 372150.0f, + 383150.0f, 385350.0f, 387550.0f, 389750.0f, 391950.0f, 394150.0f, + 426900.0f, 430350.0f, 433800.0f, 437250.0f, 440700.0f, 444150.0f, + 461400.0f, 464850.0f, 468300.0f, 471750.0f, 475200.0f, 478650.0f, + 495900.0f, 499350.0f, 502800.0f, 506250.0f, 509700.0f, 513150.0f, + 530400.0f, 533850.0f, 537300.0f, 540750.0f, 544200.0f, 547650.0f, + 564900.0f, 568350.0f, 571800.0f, 575250.0f, 578700.0f, 582150.0f, + 599400.0f, 602850.0f, 606300.0f, 609750.0f, 613200.0f, 616650.0f, + 75025.0f, 75350.0f, 75675.0f, 76000.0f, 76325.0f, 76650.0f, + 78275.0f, 78600.0f, 78925.0f, 79250.0f, 79575.0f, 79900.0f, + 81525.0f, 81850.0f, 82175.0f, 82500.0f, 82825.0f, 83150.0f, + 84775.0f, 85100.0f, 85425.0f, 85750.0f, 86075.0f, 86400.0f, + 88025.0f, 88350.0f, 88675.0f, 89000.0f, 89325.0f, 89650.0f, + 91275.0f, 91600.0f, 91925.0f, 92250.0f, 92575.0f, 92900.0f, + 353775.0f, 355350.0f, 356925.0f, 358500.0f, 360075.0f, 361650.0f, + 369525.0f, 371100.0f, 372675.0f, 374250.0f, 375825.0f, 377400.0f, + 385275.0f, 386850.0f, 388425.0f, 390000.0f, 391575.0f, 393150.0f, + 401025.0f, 402600.0f, 404175.0f, 405750.0f, 407325.0f, 408900.0f, + 416775.0f, 418350.0f, 419925.0f, 421500.0f, 423075.0f, 424650.0f, + 432525.0f, 434100.0f, 435675.0f, 437250.0f, 438825.0f, 440400.0f, + 632525.0f, 635350.0f, 638175.0f, 641000.0f, 643825.0f, 646650.0f, + 660775.0f, 663600.0f, 666425.0f, 669250.0f, 672075.0f, 674900.0f, + 689025.0f, 691850.0f, 694675.0f, 697500.0f, 700325.0f, 703150.0f, + 717275.0f, 720100.0f, 722925.0f, 725750.0f, 728575.0f, 731400.0f, + 745525.0f, 748350.0f, 751175.0f, 754000.0f, 756825.0f, 759650.0f, + 773775.0f, 776600.0f, 779425.0f, 782250.0f, 785075.0f, 787900.0f, + 309400.0f, 310350.0f, 311300.0f, 312250.0f, 313200.0f, 314150.0f, + 318900.0f, 319850.0f, 320800.0f, 321750.0f, 322700.0f, 323650.0f, + 328400.0f, 329350.0f, 330300.0f, 331250.0f, 332200.0f, 333150.0f, + 337900.0f, 338850.0f, 339800.0f, 340750.0f, 341700.0f, 342650.0f, + 347400.0f, 348350.0f, 349300.0f, 350250.0f, 351200.0f, 352150.0f, + 356900.0f, 357850.0f, 358800.0f, 359750.0f, 360700.0f, 361650.0f, + 713150.0f, 715350.0f, 717550.0f, 719750.0f, 721950.0f, 724150.0f, + 735150.0f, 737350.0f, 739550.0f, 741750.0f, 743950.0f, 746150.0f, + 757150.0f, 759350.0f, 761550.0f, 763750.0f, 765950.0f, 768150.0f, + 779150.0f, 781350.0f, 783550.0f, 785750.0f, 787950.0f, 790150.0f, + 801150.0f, 803350.0f, 805550.0f, 807750.0f, 809950.0f, 812150.0f, + 823150.0f, 825350.0f, 827550.0f, 829750.0f, 831950.0f, 834150.0f, + 1116900.0f, 1120350.0f, 1123800.0f, 1127250.0f, 1130700.0f, 1134150.0f, + 1151400.0f, 1154850.0f, 1158300.0f, 1161750.0f, 1165200.0f, 1168650.0f, + 1185900.0f, 1189350.0f, 1192800.0f, 1196250.0f, 1199700.0f, 1203150.0f, + 1220400.0f, 1223850.0f, 1227300.0f, 1230750.0f, 1234200.0f, 1237650.0f, + 1254900.0f, 1258350.0f, 1261800.0f, 1265250.0f, 1268700.0f, 1272150.0f, + 1289400.0f, 1292850.0f, 1296300.0f, 1299750.0f, 1303200.0f, 1306650.0f, + }; + Nd4jLong _expS[] = {4, 2, 6, 6, 6, 144, 36, 6, 1, 8192, 1, 99}; + NDArray exp(_expB, _expS); + + int sY = 1; + int sX = 1; + int pY = 0; + int pX = 0; + int iC = 2; + int oC = 3; + int kY = 5; + int kX = 5; + int iY = 10; + int iX = 10; + int B = 2; + + auto input = NDArrayFactory::create('c', {B, iC, iY, iX}); + for (int e = 0; e < input.lengthOf(); e++) input.p(e, e + 1); + + auto weights = NDArrayFactory::create('c', {oC, iC, kY, kX}); + for (int e = 0; e < weights.lengthOf(); e++) weights.p(e, e + 1); + weights.permutei({2, 3, 1, 0}); + + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, input); + variableSpace->putVariable(-2, weights); + + auto block = new Context(1, variableSpace, false); + block->fillInputs({-1, -2}); + + block->appendI(kY); + block->appendI(kX); + + block->appendI(sY); + block->appendI(sX); + + block->appendI(pY); + block->appendI(pX); + + // dilation + block->appendI(1); + block->appendI(1); + + // NOT same mode + block->appendI(0); + + sd::ops::sconv2d op; + + Nd4jStatus status = op.execute(block); + + ASSERT_EQ(ND4J_STATUS_OK, status); + auto output = variableSpace->getVariable(1)->getNDArray(); + + // exp.printShapeInfo("Expected shape"); + // output->printShapeInfo("Result shape"); + ASSERT_TRUE(exp.isSameShape(*output)); + + // exp.printBuffer("Expctd buffer"); + // output->printBuffer("Result buffer"); + ASSERT_TRUE(exp.equalsTo(*output)); + + delete block; + delete variableSpace; } ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests1, sconv2d_2) { - TypeParam _expBFF[] = {108.9405008f, 109.5920008f, 110.2435008f, 110.8950008f, 111.5465008f, 112.1980008f, 115.4555008f, 116.1070008f, 116.7585008f, 117.410000f, 118.061500f, 118.7130009f, 121.9705009f, 122.6220009f, 123.2735009f, 123.9250009f, 124.5765009f, 125.2280009f, 128.4855009f, 129.1370009f, 129.7885009f, 130.4400009f, 131.09150f, 131.74300f, 135.0005010f, 135.6520010f, 136.3035010f, 136.9550010f, 137.6065010f, 138.2580010f, 141.5155010f, 142.1670010f, 142.8185010f, 143.4700010f, 144.1215010f, 144.7730010f, 248.9617514f, 250.670751f, 252.3797515f, 254.0887515f, 255.7977515f, 257.5067515f, 266.0517515f, 267.7607515f, 269.469751f, 271.1787516f, 272.8877516f, 274.5967516f, 283.1417516f, 284.8507516f, - 286.5597516f, 288.268751f, 289.9777517f, 291.6867517f, 300.2317517f, 301.9407517f, 303.6497517f, 305.3587517f, 307.067751f, 308.7767518f, 317.3217518f, 319.0307518f, 320.7397518f, 322.4487518f, 324.157751f, 325.866751f, 334.4117519f, 336.1207519f, 337.8297519f, 339.5387519f, 341.2477519f, 342.95675f, 388.9829964f, 391.7494964f, 394.5159964f, 397.2824964f, 400.048996f, 402.8154963f, 416.647996f, 419.4144962f, 422.1809962f, 424.9474962f, 427.7139962f, 430.4804962f, 444.3129961f, 447.0794961f, 449.8459961f, 452.6124960f, 455.3789960f, 458.1454960f, 471.9779959f, 474.7444959f, 477.5109959f, 480.2774959f, 483.0439959f, 485.8104958f, 499.6429958f, 502.4094957f, 505.1759957f, 507.9424957f, - 510.7089957f, 513.4754957f, 527.3079956f, 530.0744956f, 532.8409956f, 535.607495f, 538.3739955f, 541.1404955f, 529.0042487f, 532.8282487f, 536.6522487f, 540.4762487f, 544.3002487f, 548.1242487f, 567.2442487f, 571.068248f, 574.892248f, 578.716248f, 582.540248f, 586.3642486f, 605.4842486f, 609.3082486f, 613.1322486f, 616.9562486f, 620.7802486f, 624.6042486f, 643.7242486f, 647.5482486f, 651.3722486f, 655.1962486f, 659.0202486f, 662.8442486f, 681.9642486f, 685.7882486f, 689.6122486f, 693.4362486f, 697.2602486f, 701.0842486f, 720.2042486f, 724.0282486f, 727.852248f, 731.676248f, 735.500248f, 739.324248f, 669.0255044f, 673.9070044f, 678.7885044f, 683.6700044f, 688.5515044f, 693.4330044f, - 717.8405044f, 722.7220044f, 727.6035044f, 732.4850044f, 737.3665044f, 742.2480044f, 766.6555043f, 771.5370043f, 776.4185043f, 781.3000043f, 786.1815043f, 791.0630043f, 815.4705043f, 820.3520043f, 825.2335043f, 830.1150043f, 834.9965043f, 839.8780043f, 864.2855042f, 869.1670042f, 874.0485042f, 878.9300042f, 883.8115042f, 888.6930042f, 913.1005042f, 917.9820042f, 922.8635042f, 927.7450042f, 932.6265042f, 937.5080042f, 809.0467424f, 814.9857424f, 820.9247424f, 826.8637423f, 832.8027423f, 838.7417423f, 868.4367421f, 874.3757421f, 880.3147420f, 886.2537420f, 892.1927420f, 898.13174f, 927.8267418f, 933.7657418f, 939.7047417f, 945.6437417f, 951.5827417f, 957.5217416f, 987.2167415f, 993.155741f, - 999.0947414f, 1005.0337414f, 1010.972741f, 1016.9117413f, 1046.6067412f, 1052.5457411f, 1058.4847411f, 1064.4237411f, 1070.3627410f, 1076.3017410f, 1105.996740f, 1111.9357408f, 1117.8747408f, 1123.8137408f, 1129.7527407f, 1135.6917407f, 949.0679815f, 956.0644814f, 963.060981f, 970.0574813f, 977.0539812f, 984.0504811f, 1019.0329807f, 1026.0294807f, 1033.0259806f, 1040.0224805f, 1047.0189804f, 1054.0154804f, 1088.9979800f, 1095.9944799f, 1102.9909798f, 1109.987479f, 1116.9839797f, 1123.9804796f, 1158.9629792f, 1165.9594791f, 1172.9559791f, 1179.9524790f, 1186.9489789f, 1193.9454788f, 1228.9279785f, 1235.9244784f, 1242.9209783f, 1249.9174782f, 1256.913978f, 1263.9104781f, 1298.8929777f, 1305.8894776f, 1312.8859775f, 1319.8824775f, 1326.8789774f, 1333.8754773f, 1089.0892560f, 1097.1432561f, 1105.1972562f, 1113.251256f, 1121.3052563f, 1129.3592564f, 1169.6292568f, 1177.6832568f, 1185.7372569f, 1193.7912570f, 1201.845257f, 1209.8992571f, 1250.1692575f, 1258.2232576f, 1266.2772576f, 1274.3312577f, 1282.3852578f, 1290.4392579f, 1330.7092582f, 1338.7632583f, 1346.8172584f, 1354.8712584f, 1362.9252585f, 1370.9792586f, 1411.24925f, 1419.3032590f, 1427.3572591f, 1435.4112592f, 1443.465259f, 1451.5192593f, 1491.7892597f, 1499.8432598f, 1507.8972598f, 1515.9512599f, 1524.0052600f, 1532.059260f, 1229.1105073f, 1238.2220073f, 1247.3335073f, 1256.4450073f, 1265.5565073f, 1274.668007f, 1320.2255074f, 1329.3370074f, 1338.4485074f, 1347.5600075f, 1356.6715075f, 1365.7830075f, 1411.340507f, 1420.4520076f, 1429.5635076f, 1438.6750076f, 1447.7865076f, 1456.8980076f, 1502.4555077f, 1511.5670077f, 1520.6785077f, 1529.7900077f, 1538.9015077f, 1548.013007f, 1593.5705078f, 1602.6820078f, 1611.793507f, 1620.9050079f, 1630.0165079f, 1639.1280079f, 1684.6855080f, 1693.7970080f, 1702.9085080f, 1712.0200080f, 1721.1315080f, 1730.2430080f, 1369.1317613f, 1379.3007614f, 1389.4697614f, 1399.6387615f, 1409.8077615f, 1419.976761f, 1470.8217618f, 1480.9907618f, 1491.159761f, 1501.3287619f, 1511.4977619f, 1521.6667620f, 1572.5117622f, 1582.6807622f, 1592.8497623f, 1603.0187623f, 1613.1877624f, 1623.3567624f, 1674.2017626f, 1684.3707627f, 1694.5397627f, 1704.7087628f, 1714.8777628f, 1725.046762f, 1775.8917631f, 1786.0607631f, 1796.229763f, 1806.3987632f, 1816.5677632f, 1826.7367633f, 1877.5817635f, 1887.7507635f, 1897.9197636f, 1908.0887636f, 1918.2577637f, 1928.4267637f, 304.3905022f, 305.0420022f, 305.6935022f, 306.3450022f, 306.9965022f, 307.6480022f, 310.9055022f, 311.5570022f, 312.208502f, 312.860002f, 313.5115023f, 314.1630023f, 317.4205023f, 318.0720023f, 318.7235023f, 319.3750023f, 320.0265023f, 320.6780023f, 323.9355023f, 324.5870023f, 325.2385023f, 325.8900023f, 326.541502f, 327.193002f, 330.4505024f, 331.1020024f, 331.7535024f, 332.4050024f, 333.0565024f, 333.7080024f, 336.9655024f, 337.6170024f, 338.2685024f, 338.9200024f, 339.5715024f, 340.223002f, 761.6617542f, 763.3707542f, 765.0797542f, 766.7887542f, 768.4977542f, 770.206754f, 778.7517543f, 780.4607543f, 782.1697543f, 783.8787543f, 785.5877543f, 787.2967543f, 795.8417544f, 797.5507544f, 799.2597544f, 800.9687544f, 802.6777544f, 804.3867544f, 812.9317545f, 814.6407545f, 816.3497545f, 818.0587545f, 819.7677545f, 821.4767545f, 830.0217546f, 831.7307546f, 833.4397546f, 835.1487546f, 836.8577546f, 838.5667546f, 847.1117547f, 848.8207547f, 850.5297547f, 852.2387547f, 853.9477547f, 855.6567547f, 1218.9329915f, 1221.6994915f, 1224.4659915f, 1227.232491f, 1229.9989914f, 1232.7654914f, 1246.5979913f, 1249.3644913f, 1252.1309913f, 1254.8974913f, 1257.6639913f, 1260.430491f, 1274.2629912f, 1277.029491f, 1279.7959911f, 1282.5624911f, 1285.3289911f, 1288.0954911f, 1301.9279910f, 1304.6944910f, 1307.4609910f, 1310.22749f, 1312.9939909f, 1315.7604909f, 1329.5929908f, 1332.3594908f, 1335.1259908f, 1337.8924908f, 1340.6589908f, 1343.4254908f, 1357.2579907f, - 1360.0244907f, 1362.7909906f, 1365.5574906f, 1368.3239906f, 1371.0904906f, 1676.2042479f, 1680.0282479f, 1683.8522479f, 1687.6762479f, 1691.5002479f, 1695.3242479f, 1714.4442479f, 1718.2682479f, 1722.0922479f, 1725.9162479f, 1729.7402479f, 1733.5642479f, 1752.6842479f, 1756.5082479f, 1760.3322479f, 1764.1562479f, 1767.9802479f, 1771.8042479f, 1790.9242479f, 1794.7482479f, 1798.5722479f, 1802.3962479f, 1806.2202479f, 1810.044247f, 1829.1642478f, 1832.9882478f, 1836.8122478f, 1840.6362478f, 1844.4602478f, 1848.2842478f, 1867.4042478f, 1871.2282478f, 1875.0522478f, 1878.8762478f, 1882.7002478f, 1886.5242478f, 2133.4755029f, 2138.3570029f, 2143.2385029f, 2148.1200029f, 2153.0015029f, 2157.8830029f, 2182.2905028f, 2187.1720028f, 2192.0535028f, 2196.9350028f, 2201.8165028f, 2206.6980028f, 2231.1055028f, 2235.9870028f, 2240.8685028f, 2245.7500028f, 2250.6315028f, 2255.5130028f, 2279.9205027f, 2284.8020027f, 2289.6835027f, 2294.5650027f, 2299.4465027f, 2304.3280027f, 2328.7355027f, 2333.6170027f, 2338.4985027f, 2343.3800027f, 2348.2615027f, 2353.1430027f, 2377.5505026f, 2382.4320026f, 2387.3135026f, 2392.1950026f, 2397.0765026f, 2401.9580026f, 2590.7467330f, 2596.6857330f, 2602.6247329f, 2608.5637329f, 2614.5027329f, 2620.441732f, 2650.1367327f, 2656.0757327f, 2662.0147326f, 2667.9537326f, 2673.8927326f, 2679.8317325f, 2709.5267324f, 2715.465732f, 2721.4047323f, 2727.3437323f, 2733.282732f, 2739.2217322f, 2768.9167321f, 2774.8557320f, 2780.7947320f, 2786.7337320f, 2792.6727319f, 2798.6117319f, 2828.306731f, 2834.2457317f, 2840.1847317f, 2846.1237317f, 2852.0627316f, 2858.0017316f, 2887.6967314f, 2893.6357314f, 2899.5747314f, 2905.5137313f, 2911.4527313f, 2917.3917313f, 3048.0179587f, 3055.0144586f, 3062.0109585f, 3069.0074584f, 3076.0039584f, 3083.0004583f, 3117.9829579f, 3124.9794578f, 3131.9759578f, 3138.9724577f, 3145.9689576f, 3152.9654575f, 3187.947957f, 3194.9444571f, 3201.9409570f, 3208.9374569f, 3215.933956f, 3222.9304568f, 3257.9129564f, 3264.9094563f, 3271.9059562f, 3278.9024562f, 3285.8989561f, - 3292.8954560f, 3327.8779556f, 3334.874455f, 3341.8709555f, 3348.8674554f, 3355.8639553f, 3362.860455f, 3397.8429549f, 3404.8394548f, 3411.8359547f, 3418.8324546f, 3425.8289546f, 3432.8254545f, 3505.28927f, 3513.3432780f, 3521.3972781f, 3529.4512782f, 3537.5052782f, 3545.5592783f, 3585.8292787f, 3593.8832788f, 3601.9372788f, 3609.9912789f, 3618.0452790f, 3626.099279f, - 3666.3692794f, 3674.4232795f, 3682.4772796f, 3690.5312796f, 3698.5852797f, 3706.6392798f, 3746.9092801f, 3754.9632802f, 3763.0172803f, 3771.0712804f, 3779.1252804f, 3787.1792805f, 3827.4492809f, 3835.50328f, 3843.5572810f, 3851.6112811f, 3859.6652812f, 3867.7192812f, 3907.9892816f, 3916.0432817f, 3924.097281f, - 3932.1512818f, 3940.2052819f, 3948.2592820f, 3962.5605113f, 3971.6720113f, 3980.783511f, 3989.8950114f, 3999.0065114f, 4008.1180114f, 4053.6755115f, 4062.7870115f, 4071.8985115f, 4081.0100115f, 4090.1215115f, 4099.2330115f, 4144.7905116f, 4153.9020116f, 4163.0135116f, 4172.1250116f, - 4181.236511f, 4190.3480117f, 4235.9055117f, 4245.0170117f, 4254.128511f, 4263.2400118f, 4272.3515118f, 4281.4630118f, 4327.0205119f, 4336.1320119f, 4345.2435119f, 4354.3550119f, 4363.4665119f, 4372.5780119f, 4418.1355120f, 4427.2470120f, 4436.3585120f, 4445.4700120f, 4454.581512f, 4463.6930121f, 4419.8317743f, 4430.0007744f, 4440.1697744f, 4450.338774f, 4460.5077745f, 4470.6767745f, 4521.521774f, 4531.6907748f, - 4541.8597748f, 4552.0287749f, 4562.1977749f, 4572.3667750f, 4623.2117752f, 4633.3807752f, 4643.5497753f, 4653.7187753f, 4663.8877754f, 4674.0567754f, 4724.9017756f, 4735.0707757f, 4745.2397757f, 4755.4087757f, 4765.5777758f, 4775.7467758f, 4826.591776f, 4836.7607761f, 4846.9297761f, 4857.0987762f, 4867.2677762f, 4877.4367763f, 4928.2817765f, 4938.4507765f, 4948.6197766f, 4958.7887766f, 4968.957776f, 4979.12677675f}; - Nd4jLong _expSFF[] = {4, 2, 10, 6, 6, 360, 36, 6, 1, typeid(TypeParam) == typeid(float) ? 8192 : 16384, 1, 99,}; - NDArray expFF(_expBFF, _expSFF); - - auto input = NDArrayFactory::create('c', {2, 3, 10, 10}); - auto weightsD = NDArrayFactory::create('c', {5, 3, 5, 5}); - auto weightsP = NDArrayFactory::create('c', {10, 15, 1, 1}); - - input.linspace(1); - weightsD.linspace(1); - weightsP.linspace(1); - weightsD.permutei({2,3,1,0}); - weightsP.permutei({2,3,1,0}); - - input.applyScalar(scalar::Divide, 100.0, input); - weightsD.applyScalar(scalar::Divide, 100.0, weightsD); - weightsP.applyScalar(scalar::Divide, 100.0, weightsP); - - sd::ops::sconv2d op; - - auto resultFF = op.evaluate({&input, &weightsD, &weightsP}, {5, 5, 1, 1, 0, 0, 1, 1, 0, 0}); - - auto z = resultFF.at(0); - //z->printShapeInfo("FF shape"); - - - ASSERT_TRUE(z->isSameShape(&expFF)); - - //expFF.printBuffer("e"); - //z->printBuffer("z"); - ASSERT_TRUE(z->equalsTo(&expFF, 1e-3)); + TypeParam _expBFF[] = { + 108.9405008f, 109.5920008f, 110.2435008f, 110.8950008f, + 111.5465008f, 112.1980008f, 115.4555008f, 116.1070008f, + 116.7585008f, 117.410000f, 118.061500f, 118.7130009f, + 121.9705009f, 122.6220009f, 123.2735009f, 123.9250009f, + 124.5765009f, 125.2280009f, 128.4855009f, 129.1370009f, + 129.7885009f, 130.4400009f, 131.09150f, 131.74300f, + 135.0005010f, 135.6520010f, 136.3035010f, 136.9550010f, + 137.6065010f, 138.2580010f, 141.5155010f, 142.1670010f, + 142.8185010f, 143.4700010f, 144.1215010f, 144.7730010f, + 248.9617514f, 250.670751f, 252.3797515f, 254.0887515f, + 255.7977515f, 257.5067515f, 266.0517515f, 267.7607515f, + 269.469751f, 271.1787516f, 272.8877516f, 274.5967516f, + 283.1417516f, 284.8507516f, 286.5597516f, 288.268751f, + 289.9777517f, 291.6867517f, 300.2317517f, 301.9407517f, + 303.6497517f, 305.3587517f, 307.067751f, 308.7767518f, + 317.3217518f, 319.0307518f, 320.7397518f, 322.4487518f, + 324.157751f, 325.866751f, 334.4117519f, 336.1207519f, + 337.8297519f, 339.5387519f, 341.2477519f, 342.95675f, + 388.9829964f, 391.7494964f, 394.5159964f, 397.2824964f, + 400.048996f, 402.8154963f, 416.647996f, 419.4144962f, + 422.1809962f, 424.9474962f, 427.7139962f, 430.4804962f, + 444.3129961f, 447.0794961f, 449.8459961f, 452.6124960f, + 455.3789960f, 458.1454960f, 471.9779959f, 474.7444959f, + 477.5109959f, 480.2774959f, 483.0439959f, 485.8104958f, + 499.6429958f, 502.4094957f, 505.1759957f, 507.9424957f, + 510.7089957f, 513.4754957f, 527.3079956f, 530.0744956f, + 532.8409956f, 535.607495f, 538.3739955f, 541.1404955f, + 529.0042487f, 532.8282487f, 536.6522487f, 540.4762487f, + 544.3002487f, 548.1242487f, 567.2442487f, 571.068248f, + 574.892248f, 578.716248f, 582.540248f, 586.3642486f, + 605.4842486f, 609.3082486f, 613.1322486f, 616.9562486f, + 620.7802486f, 624.6042486f, 643.7242486f, 647.5482486f, + 651.3722486f, 655.1962486f, 659.0202486f, 662.8442486f, + 681.9642486f, 685.7882486f, 689.6122486f, 693.4362486f, + 697.2602486f, 701.0842486f, 720.2042486f, 724.0282486f, + 727.852248f, 731.676248f, 735.500248f, 739.324248f, + 669.0255044f, 673.9070044f, 678.7885044f, 683.6700044f, + 688.5515044f, 693.4330044f, 717.8405044f, 722.7220044f, + 727.6035044f, 732.4850044f, 737.3665044f, 742.2480044f, + 766.6555043f, 771.5370043f, 776.4185043f, 781.3000043f, + 786.1815043f, 791.0630043f, 815.4705043f, 820.3520043f, + 825.2335043f, 830.1150043f, 834.9965043f, 839.8780043f, + 864.2855042f, 869.1670042f, 874.0485042f, 878.9300042f, + 883.8115042f, 888.6930042f, 913.1005042f, 917.9820042f, + 922.8635042f, 927.7450042f, 932.6265042f, 937.5080042f, + 809.0467424f, 814.9857424f, 820.9247424f, 826.8637423f, + 832.8027423f, 838.7417423f, 868.4367421f, 874.3757421f, + 880.3147420f, 886.2537420f, 892.1927420f, 898.13174f, + 927.8267418f, 933.7657418f, 939.7047417f, 945.6437417f, + 951.5827417f, 957.5217416f, 987.2167415f, 993.155741f, + 999.0947414f, 1005.0337414f, 1010.972741f, 1016.9117413f, + 1046.6067412f, 1052.5457411f, 1058.4847411f, 1064.4237411f, + 1070.3627410f, 1076.3017410f, 1105.996740f, 1111.9357408f, + 1117.8747408f, 1123.8137408f, 1129.7527407f, 1135.6917407f, + 949.0679815f, 956.0644814f, 963.060981f, 970.0574813f, + 977.0539812f, 984.0504811f, 1019.0329807f, 1026.0294807f, + 1033.0259806f, 1040.0224805f, 1047.0189804f, 1054.0154804f, + 1088.9979800f, 1095.9944799f, 1102.9909798f, 1109.987479f, + 1116.9839797f, 1123.9804796f, 1158.9629792f, 1165.9594791f, + 1172.9559791f, 1179.9524790f, 1186.9489789f, 1193.9454788f, + 1228.9279785f, 1235.9244784f, 1242.9209783f, 1249.9174782f, + 1256.913978f, 1263.9104781f, 1298.8929777f, 1305.8894776f, + 1312.8859775f, 1319.8824775f, 1326.8789774f, 1333.8754773f, + 1089.0892560f, 1097.1432561f, 1105.1972562f, 1113.251256f, + 1121.3052563f, 1129.3592564f, 1169.6292568f, 1177.6832568f, + 1185.7372569f, 1193.7912570f, 1201.845257f, 1209.8992571f, + 1250.1692575f, 1258.2232576f, 1266.2772576f, 1274.3312577f, + 1282.3852578f, 1290.4392579f, 1330.7092582f, 1338.7632583f, + 1346.8172584f, 1354.8712584f, 1362.9252585f, 1370.9792586f, + 1411.24925f, 1419.3032590f, 1427.3572591f, 1435.4112592f, + 1443.465259f, 1451.5192593f, 1491.7892597f, 1499.8432598f, + 1507.8972598f, 1515.9512599f, 1524.0052600f, 1532.059260f, + 1229.1105073f, 1238.2220073f, 1247.3335073f, 1256.4450073f, + 1265.5565073f, 1274.668007f, 1320.2255074f, 1329.3370074f, + 1338.4485074f, 1347.5600075f, 1356.6715075f, 1365.7830075f, + 1411.340507f, 1420.4520076f, 1429.5635076f, 1438.6750076f, + 1447.7865076f, 1456.8980076f, 1502.4555077f, 1511.5670077f, + 1520.6785077f, 1529.7900077f, 1538.9015077f, 1548.013007f, + 1593.5705078f, 1602.6820078f, 1611.793507f, 1620.9050079f, + 1630.0165079f, 1639.1280079f, 1684.6855080f, 1693.7970080f, + 1702.9085080f, 1712.0200080f, 1721.1315080f, 1730.2430080f, + 1369.1317613f, 1379.3007614f, 1389.4697614f, 1399.6387615f, + 1409.8077615f, 1419.976761f, 1470.8217618f, 1480.9907618f, + 1491.159761f, 1501.3287619f, 1511.4977619f, 1521.6667620f, + 1572.5117622f, 1582.6807622f, 1592.8497623f, 1603.0187623f, + 1613.1877624f, 1623.3567624f, 1674.2017626f, 1684.3707627f, + 1694.5397627f, 1704.7087628f, 1714.8777628f, 1725.046762f, + 1775.8917631f, 1786.0607631f, 1796.229763f, 1806.3987632f, + 1816.5677632f, 1826.7367633f, 1877.5817635f, 1887.7507635f, + 1897.9197636f, 1908.0887636f, 1918.2577637f, 1928.4267637f, + 304.3905022f, 305.0420022f, 305.6935022f, 306.3450022f, + 306.9965022f, 307.6480022f, 310.9055022f, 311.5570022f, + 312.208502f, 312.860002f, 313.5115023f, 314.1630023f, + 317.4205023f, 318.0720023f, 318.7235023f, 319.3750023f, + 320.0265023f, 320.6780023f, 323.9355023f, 324.5870023f, + 325.2385023f, 325.8900023f, 326.541502f, 327.193002f, + 330.4505024f, 331.1020024f, 331.7535024f, 332.4050024f, + 333.0565024f, 333.7080024f, 336.9655024f, 337.6170024f, + 338.2685024f, 338.9200024f, 339.5715024f, 340.223002f, + 761.6617542f, 763.3707542f, 765.0797542f, 766.7887542f, + 768.4977542f, 770.206754f, 778.7517543f, 780.4607543f, + 782.1697543f, 783.8787543f, 785.5877543f, 787.2967543f, + 795.8417544f, 797.5507544f, 799.2597544f, 800.9687544f, + 802.6777544f, 804.3867544f, 812.9317545f, 814.6407545f, + 816.3497545f, 818.0587545f, 819.7677545f, 821.4767545f, + 830.0217546f, 831.7307546f, 833.4397546f, 835.1487546f, + 836.8577546f, 838.5667546f, 847.1117547f, 848.8207547f, + 850.5297547f, 852.2387547f, 853.9477547f, 855.6567547f, + 1218.9329915f, 1221.6994915f, 1224.4659915f, 1227.232491f, + 1229.9989914f, 1232.7654914f, 1246.5979913f, 1249.3644913f, + 1252.1309913f, 1254.8974913f, 1257.6639913f, 1260.430491f, + 1274.2629912f, 1277.029491f, 1279.7959911f, 1282.5624911f, + 1285.3289911f, 1288.0954911f, 1301.9279910f, 1304.6944910f, + 1307.4609910f, 1310.22749f, 1312.9939909f, 1315.7604909f, + 1329.5929908f, 1332.3594908f, 1335.1259908f, 1337.8924908f, + 1340.6589908f, 1343.4254908f, 1357.2579907f, 1360.0244907f, + 1362.7909906f, 1365.5574906f, 1368.3239906f, 1371.0904906f, + 1676.2042479f, 1680.0282479f, 1683.8522479f, 1687.6762479f, + 1691.5002479f, 1695.3242479f, 1714.4442479f, 1718.2682479f, + 1722.0922479f, 1725.9162479f, 1729.7402479f, 1733.5642479f, + 1752.6842479f, 1756.5082479f, 1760.3322479f, 1764.1562479f, + 1767.9802479f, 1771.8042479f, 1790.9242479f, 1794.7482479f, + 1798.5722479f, 1802.3962479f, 1806.2202479f, 1810.044247f, + 1829.1642478f, 1832.9882478f, 1836.8122478f, 1840.6362478f, + 1844.4602478f, 1848.2842478f, 1867.4042478f, 1871.2282478f, + 1875.0522478f, 1878.8762478f, 1882.7002478f, 1886.5242478f, + 2133.4755029f, 2138.3570029f, 2143.2385029f, 2148.1200029f, + 2153.0015029f, 2157.8830029f, 2182.2905028f, 2187.1720028f, + 2192.0535028f, 2196.9350028f, 2201.8165028f, 2206.6980028f, + 2231.1055028f, 2235.9870028f, 2240.8685028f, 2245.7500028f, + 2250.6315028f, 2255.5130028f, 2279.9205027f, 2284.8020027f, + 2289.6835027f, 2294.5650027f, 2299.4465027f, 2304.3280027f, + 2328.7355027f, 2333.6170027f, 2338.4985027f, 2343.3800027f, + 2348.2615027f, 2353.1430027f, 2377.5505026f, 2382.4320026f, + 2387.3135026f, 2392.1950026f, 2397.0765026f, 2401.9580026f, + 2590.7467330f, 2596.6857330f, 2602.6247329f, 2608.5637329f, + 2614.5027329f, 2620.441732f, 2650.1367327f, 2656.0757327f, + 2662.0147326f, 2667.9537326f, 2673.8927326f, 2679.8317325f, + 2709.5267324f, 2715.465732f, 2721.4047323f, 2727.3437323f, + 2733.282732f, 2739.2217322f, 2768.9167321f, 2774.8557320f, + 2780.7947320f, 2786.7337320f, 2792.6727319f, 2798.6117319f, + 2828.306731f, 2834.2457317f, 2840.1847317f, 2846.1237317f, + 2852.0627316f, 2858.0017316f, 2887.6967314f, 2893.6357314f, + 2899.5747314f, 2905.5137313f, 2911.4527313f, 2917.3917313f, + 3048.0179587f, 3055.0144586f, 3062.0109585f, 3069.0074584f, + 3076.0039584f, 3083.0004583f, 3117.9829579f, 3124.9794578f, + 3131.9759578f, 3138.9724577f, 3145.9689576f, 3152.9654575f, + 3187.947957f, 3194.9444571f, 3201.9409570f, 3208.9374569f, + 3215.933956f, 3222.9304568f, 3257.9129564f, 3264.9094563f, + 3271.9059562f, 3278.9024562f, 3285.8989561f, 3292.8954560f, + 3327.8779556f, 3334.874455f, 3341.8709555f, 3348.8674554f, + 3355.8639553f, 3362.860455f, 3397.8429549f, 3404.8394548f, + 3411.8359547f, 3418.8324546f, 3425.8289546f, 3432.8254545f, + 3505.28927f, 3513.3432780f, 3521.3972781f, 3529.4512782f, + 3537.5052782f, 3545.5592783f, 3585.8292787f, 3593.8832788f, + 3601.9372788f, 3609.9912789f, 3618.0452790f, 3626.099279f, + 3666.3692794f, 3674.4232795f, 3682.4772796f, 3690.5312796f, + 3698.5852797f, 3706.6392798f, 3746.9092801f, 3754.9632802f, + 3763.0172803f, 3771.0712804f, 3779.1252804f, 3787.1792805f, + 3827.4492809f, 3835.50328f, 3843.5572810f, 3851.6112811f, + 3859.6652812f, 3867.7192812f, 3907.9892816f, 3916.0432817f, + 3924.097281f, 3932.1512818f, 3940.2052819f, 3948.2592820f, + 3962.5605113f, 3971.6720113f, 3980.783511f, 3989.8950114f, + 3999.0065114f, 4008.1180114f, 4053.6755115f, 4062.7870115f, + 4071.8985115f, 4081.0100115f, 4090.1215115f, 4099.2330115f, + 4144.7905116f, 4153.9020116f, 4163.0135116f, 4172.1250116f, + 4181.236511f, 4190.3480117f, 4235.9055117f, 4245.0170117f, + 4254.128511f, 4263.2400118f, 4272.3515118f, 4281.4630118f, + 4327.0205119f, 4336.1320119f, 4345.2435119f, 4354.3550119f, + 4363.4665119f, 4372.5780119f, 4418.1355120f, 4427.2470120f, + 4436.3585120f, 4445.4700120f, 4454.581512f, 4463.6930121f, + 4419.8317743f, 4430.0007744f, 4440.1697744f, 4450.338774f, + 4460.5077745f, 4470.6767745f, 4521.521774f, 4531.6907748f, + 4541.8597748f, 4552.0287749f, 4562.1977749f, 4572.3667750f, + 4623.2117752f, 4633.3807752f, 4643.5497753f, 4653.7187753f, + 4663.8877754f, 4674.0567754f, 4724.9017756f, 4735.0707757f, + 4745.2397757f, 4755.4087757f, 4765.5777758f, 4775.7467758f, + 4826.591776f, 4836.7607761f, 4846.9297761f, 4857.0987762f, + 4867.2677762f, 4877.4367763f, 4928.2817765f, 4938.4507765f, + 4948.6197766f, 4958.7887766f, 4968.957776f, 4979.12677675f}; + Nd4jLong _expSFF[] = { + 4, 2, 10, 6, 6, + 360, 36, 6, 1, typeid(TypeParam) == typeid(float) ? 8192 : 16384, + 1, 99, + }; + NDArray expFF(_expBFF, _expSFF); + + auto input = NDArrayFactory::create('c', {2, 3, 10, 10}); + auto weightsD = NDArrayFactory::create('c', {5, 3, 5, 5}); + auto weightsP = NDArrayFactory::create('c', {10, 15, 1, 1}); + + input.linspace(1); + weightsD.linspace(1); + weightsP.linspace(1); + weightsD.permutei({2, 3, 1, 0}); + weightsP.permutei({2, 3, 1, 0}); + + input.applyScalar(scalar::Divide, 100.0, input); + weightsD.applyScalar(scalar::Divide, 100.0, weightsD); + weightsP.applyScalar(scalar::Divide, 100.0, weightsP); + + sd::ops::sconv2d op; + + auto resultFF = op.evaluate({&input, &weightsD, &weightsP}, + {5, 5, 1, 1, 0, 0, 1, 1, 0, 0}); + + auto z = resultFF.at(0); + // z->printShapeInfo("FF shape"); + + ASSERT_TRUE(z.isSameShape(&expFF)); + + // expFF.printBuffer("e"); + // z->printBuffer("z"); + ASSERT_TRUE(z.equalsTo(&expFF, 1e-3)); } TYPED_TEST(TypedConvolutionTests1, sconv2d_3) { - auto input = NDArrayFactory::create('c', {3, 3, 8, 8}); - auto weightsD = NDArrayFactory::create('c', {1, 3, 1, 1}); - auto weightsP = NDArrayFactory::create('c', {2, 3, 1, 1}); - auto bias = NDArrayFactory::create('c', {2}); - auto output = NDArrayFactory::create('c', {3, 2, 8, 8}); - output.assign(0.0); - - input.linspace(1); - weightsD.linspace(1); - weightsP.linspace(1); - bias.linspace(1); - weightsD.permutei({2,3,1,0}); - weightsP.permutei({2,3,1,0}); - - auto expOutput = NDArrayFactory::create('c', {3, 2, 8, 8}); - - sd::ops::sconv2d op; - Nd4jStatus status = op.execute({&input, &weightsD, &weightsP, &bias}, {&output}, {1, 1, 1, 1, 0, 0, 1, 1, 0}); - auto result = op.evaluate({&input, &weightsD, &weightsP, &bias}, {1, 1, 1, 1, 0, 0, 1, 1, 0}); - - auto z = result.at(0); - - //printf("\n"); - //output.printBuffer("output"); - //z->printBuffer("z"); - - - //ASSERT_TRUE(expOutput.isSameShape(z)); + auto input = NDArrayFactory::create('c', {3, 3, 8, 8}); + auto weightsD = NDArrayFactory::create('c', {1, 3, 1, 1}); + auto weightsP = NDArrayFactory::create('c', {2, 3, 1, 1}); + auto bias = NDArrayFactory::create('c', {2}); + auto output = NDArrayFactory::create('c', {3, 2, 8, 8}); + output.assign(0.0); + + input.linspace(1); + weightsD.linspace(1); + weightsP.linspace(1); + bias.linspace(1); + weightsD.permutei({2, 3, 1, 0}); + weightsP.permutei({2, 3, 1, 0}); + + auto expOutput = NDArrayFactory::create('c', {3, 2, 8, 8}); + + sd::ops::sconv2d op; + Nd4jStatus status = op.execute({&input, &weightsD, &weightsP, &bias}, + {&output}, {1, 1, 1, 1, 0, 0, 1, 1, 0}); + auto result = op.evaluate({&input, &weightsD, &weightsP, &bias}, + {1, 1, 1, 1, 0, 0, 1, 1, 0}); + + auto z = result.at(0); + + // printf("\n"); + // output.printBuffer("output"); + // z->printBuffer("z"); + + // ASSERT_TRUE(expOutput.isSameShape(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests1, sconv2d_4) { - - int bS=1, iH=6,iW=6, iC=3,oC=2,mC=3, kH=1,kW=1, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oH=6,oW=6; - int paddingMode = 0; // 1-SAME, 0-VALID; - int dataFormat = 0; // 1-NHWC, 0-NCHW - - NDArray input('c', {bS, iC, iH, iW}, {0.679350, 0.355087, 0.842789, 0.200313, 0.701499, 0.310693, 0.447940, 0.938010, 0.326674, 0.151873, 0.383318, 0.782123, 0.198807, - 0.798564, 0.163263, 0.146968, 0.260897, 0.135058, 0.756209, 0.275454, 0.369088, 0.092826, 0.836492, 0.268413, 0.095062, 0.312795, 0.135918, 0.517544, 0.328703, - 0.061736, 0.396431, 0.248016, 0.548959, 0.115046, 0.814362, 0.721564, 0.404494, 0.299089, 0.403884, 0.988311, 0.022296, 0.927782, 0.318416, 0.068546, 0.284533, - 0.232720, 0.352142, 0.058909, 0.711221, 0.674457, 0.196946, 0.699497, 0.074322, 0.420425, 0.584263, 0.149574, 0.446406, 0.723072, 0.064481, 0.483078, 0.875996, - 0.569819, 0.445863, 0.527755, 0.016646, 0.753678, 0.140636, 0.754129, 0.161932, 0.775037, 0.332645, 0.117394, 0.017711, 0.608476, 0.525152, 0.917194, 0.849891, - 0.589423, 0.852278, 0.390636, 0.889683, 0.669445, 0.698873, 0.961480, 0.157401, 0.157364, 0.493520, 0.569937, 0.126832, 0.115728, 0.786368, 0.737939, 0.490079, - 0.608414, 0.956500, 0.390098, 0.147305, 0.850645, 0.497650, 0.071866, 0.082150, 0.035314, 0.732041, 0.369934, 0.840666, 0.273894, 0.431796, 0.133231}); - NDArray weightsD('c', {kH, kW, iC, mC}, {0.5340641736984253, 0.8257383108139038, 0.3279532492160797, 0.27217748761177063, 0.05432872101664543, 0.31322699785232544, 0.6599581837654114, 0.35526034235954285, 0.5765137672424316}); - NDArray weightsP('c', {1, 1, iC*mC, oC}, {0.4442146420478821, 0.3362849950790405, 0.5215804576873779, 0.5305071473121643, 0.7323054075241089, 0.5168435573577881, 0.8601323962211609, 0.2587810158729553, 0.9473239779472351, 0.39540114998817444, 0.04835261031985283, 0.8724213242530823, 0.8607604503631592, 0.8382210731506348, 0.8573186993598938, 0.6496091485023499, 0.8864102959632874, 0.14267340302467346}); - NDArray biases('c', {1,oC}, {0.8807470202445984, 0.6262521147727966}); - - NDArray expOutput('c', {bS, oC, oH, oW}, {1.643804, 2.135067, 2.494167, 2.628944, 2.700440, 2.257452, 2.562539, 2.293667, 2.493985, 2.014933, 2.301736, 2.939066, 1.492952, - 2.026476, 1.771098, 2.013162, 1.315507, 1.289951, 2.831223, 2.196924, 2.028261, 2.024326, 2.983223, 1.809527, 1.434322, 2.513157, 1.826834, 1.608869, 1.297912, 1.212318, - 2.295934, 1.844615, 2.591148, 1.597267, 2.317755, 1.755642, 1.324064, 1.542060, 1.892052, 1.939339, 1.922781, 1.720199, 1.833396, 1.728024, 1.757968, 1.410675, 1.661960, - 2.096277, 1.178815, 1.637460, 1.254187, 1.491076, 0.968625, 0.986342, 2.116042, 1.536920, 1.504321, 1.490398, 2.136795, 1.351860, 1.148578, 1.817408, 1.327139, 1.288620, - 0.962232, 0.980667, 1.623775, 1.417320, 1.845710, 1.237095, 1.762792, 1.352515}); - - sd::ops::sconv2d op; - auto results = op.evaluate({&input, &weightsD, &weightsP, &biases}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); - auto* output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expOutput.isSameShape(output)); - ASSERT_TRUE(expOutput.equalsTo(output)); - + int bS = 1, iH = 6, iW = 6, iC = 3, oC = 2, mC = 3, kH = 1, kW = 1, sH = 1, + sW = 1, pH = 0, pW = 0, dH = 1, dW = 1; + int oH = 6, oW = 6; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NHWC, 0-NCHW + + NDArray input( + 'c', {bS, iC, iH, iW}, + {0.679350, 0.355087, 0.842789, 0.200313, 0.701499, 0.310693, 0.447940, + 0.938010, 0.326674, 0.151873, 0.383318, 0.782123, 0.198807, 0.798564, + 0.163263, 0.146968, 0.260897, 0.135058, 0.756209, 0.275454, 0.369088, + 0.092826, 0.836492, 0.268413, 0.095062, 0.312795, 0.135918, 0.517544, + 0.328703, 0.061736, 0.396431, 0.248016, 0.548959, 0.115046, 0.814362, + 0.721564, 0.404494, 0.299089, 0.403884, 0.988311, 0.022296, 0.927782, + 0.318416, 0.068546, 0.284533, 0.232720, 0.352142, 0.058909, 0.711221, + 0.674457, 0.196946, 0.699497, 0.074322, 0.420425, 0.584263, 0.149574, + 0.446406, 0.723072, 0.064481, 0.483078, 0.875996, 0.569819, 0.445863, + 0.527755, 0.016646, 0.753678, 0.140636, 0.754129, 0.161932, 0.775037, + 0.332645, 0.117394, 0.017711, 0.608476, 0.525152, 0.917194, 0.849891, + 0.589423, 0.852278, 0.390636, 0.889683, 0.669445, 0.698873, 0.961480, + 0.157401, 0.157364, 0.493520, 0.569937, 0.126832, 0.115728, 0.786368, + 0.737939, 0.490079, 0.608414, 0.956500, 0.390098, 0.147305, 0.850645, + 0.497650, 0.071866, 0.082150, 0.035314, 0.732041, 0.369934, 0.840666, + 0.273894, 0.431796, 0.133231}); + NDArray weightsD( + 'c', {kH, kW, iC, mC}, + {0.5340641736984253, 0.8257383108139038, 0.3279532492160797, + 0.27217748761177063, 0.05432872101664543, 0.31322699785232544, + 0.6599581837654114, 0.35526034235954285, 0.5765137672424316}); + NDArray weightsP( + 'c', {1, 1, iC * mC, oC}, + {0.4442146420478821, 0.3362849950790405, 0.5215804576873779, + 0.5305071473121643, 0.7323054075241089, 0.5168435573577881, + 0.8601323962211609, 0.2587810158729553, 0.9473239779472351, + 0.39540114998817444, 0.04835261031985283, 0.8724213242530823, + 0.8607604503631592, 0.8382210731506348, 0.8573186993598938, + 0.6496091485023499, 0.8864102959632874, 0.14267340302467346}); + NDArray biases('c', {1, oC}, {0.8807470202445984, 0.6262521147727966}); + + NDArray expOutput( + 'c', {bS, oC, oH, oW}, + {1.643804, 2.135067, 2.494167, 2.628944, 2.700440, 2.257452, 2.562539, + 2.293667, 2.493985, 2.014933, 2.301736, 2.939066, 1.492952, 2.026476, + 1.771098, 2.013162, 1.315507, 1.289951, 2.831223, 2.196924, 2.028261, + 2.024326, 2.983223, 1.809527, 1.434322, 2.513157, 1.826834, 1.608869, + 1.297912, 1.212318, 2.295934, 1.844615, 2.591148, 1.597267, 2.317755, + 1.755642, 1.324064, 1.542060, 1.892052, 1.939339, 1.922781, 1.720199, + 1.833396, 1.728024, 1.757968, 1.410675, 1.661960, 2.096277, 1.178815, + 1.637460, 1.254187, 1.491076, 0.968625, 0.986342, 2.116042, 1.536920, + 1.504321, 1.490398, 2.136795, 1.351860, 1.148578, 1.817408, 1.327139, + 1.288620, 0.962232, 0.980667, 1.623775, 1.417320, 1.845710, 1.237095, + 1.762792, 1.352515}); + + sd::ops::sconv2d op; + auto results = + op.evaluate({&input, &weightsD, &weightsP, &biases}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); } TYPED_TEST(TypedConvolutionTests1, conv2D_BP_Bias_1) { - TypeParam _expWGradB[] = {9312.0, 12580.0, 9528.0, 13168.0, 17712.0, 13360.0, 9960.0, 13348.0, 10032.0, 13344.0, 18148.0, 13848.0, 19312.0, 26160.0, 19888.0, 15144.0, 20452.0, 15504.0}; - Nd4jLong _expWGradS[] = {4, 2, 1, 3, 3, 9, 9, 3, 1, typeid(TypeParam) == typeid(float) ? 8192 : 16384, 1, 99}; - NDArray expWGrad(_expWGradB, _expWGradS); - expWGrad.permutei({2,3,1,0}); - - TypeParam _expBGradB[] = {784.0, 1296.0}; - Nd4jLong _expBGradS[] = {2, 2, 1, 1, 1, typeid(TypeParam) == typeid(float) ? 8192 : 16384, 1, 99}; - - NDArray expBGrad(_expBGradB, _expBGradS); - - auto input = NDArrayFactory::create('c', {2, 1, 4, 4}); - auto weights = NDArrayFactory::create('c', {2, 1, 3, 3}); - auto bias = NDArrayFactory::create('c', {2, 1}); - auto epsilonNext = NDArrayFactory::create('c', {2, 2, 4, 4}); - - - TypeParam _expEpsB[] = {952.0, 1540.0, 1636.0, 1180.0, 1791.0, 2886.0, 3057.0, 2193.0, 2223.0, 3570.0, 3741.0, 2673.0, 1900.0, 3028.0, 3160.0, 2240.0, 2872.0, 4612.0, 4708.0, 3356.0, 5247.0, 8358.0, 8529.0, 6033.0, 5679.0, 9042.0, 9213.0, 6513.0, 4588.0, 7252.0, 7384.0, 5184.0}; - NDArray expEps(_expEpsB, input.shapeInfo()); - - input.linspace(1); - weights.linspace(1); - epsilonNext.linspace(1); - weights.permutei({2,3,1,0}); - - sd::ops::conv2d_bp op; - - auto results = op.evaluate({&input, &weights, &bias, &epsilonNext}, {}, {3, 3, 1, 1, 0, 0, 1, 1, 1}, {}); - - ASSERT_TRUE(results.size() == 3); - - auto epsilon = results.at(0); - auto gradW = results.at(1); - auto gradB = results.at(2); - - ASSERT_TRUE(expWGrad.isSameShape(gradW)); - - //expWGrad.printBuffer("Expctd buffer"); - // gradW->printBuffer("Result buffer"); - ASSERT_TRUE(expWGrad.equalsTo(gradW)); - - - ASSERT_TRUE(input.isSameShape(epsilon)); - - // expEps.printBuffer("Expctd buffer"); - //epsilon->printBuffer("Result buffer"); - ASSERT_TRUE(expEps.equalsTo(epsilon)); - - ASSERT_TRUE(expBGrad.isSameShape(gradB)); - - ASSERT_TRUE(expBGrad.equalsTo(gradB)); - - -} - - -TYPED_TEST(TypedConvolutionTests1, conv2D_BP_NoBias_1) { - TypeParam _expWGradB[] = {9312.0, 12580.0, 9528.0, 13168.0, 17712.0, 13360.0, 9960.0, 13348.0, 10032.0, 13344.0, 18148.0, 13848.0, 19312.0, 26160.0, 19888.0, 15144.0, 20452.0, 15504.0}; - Nd4jLong _expWGradS[] = {4, 2, 1, 3, 3, 9, 9, 3, 1, typeid(TypeParam) == typeid(float) ? 8192 : 16384, 1, 99}; - NDArray expWGrad(_expWGradB, _expWGradS); - expWGrad.permutei({2,3,1,0}); - - auto input = NDArrayFactory::create('c', {2, 1, 4, 4}); - auto weights = NDArrayFactory::create('c', {2, 1, 3, 3}); - auto epsilonNext = NDArrayFactory::create('c', {2, 2, 4, 4}); - - - TypeParam _expEpsB[] = {952.0, 1540.0, 1636.0, 1180.0, 1791.0, 2886.0, 3057.0, 2193.0, 2223.0, 3570.0, 3741.0, 2673.0, 1900.0, 3028.0, 3160.0, 2240.0, 2872.0, 4612.0, 4708.0, 3356.0, 5247.0, 8358.0, 8529.0, 6033.0, 5679.0, 9042.0, 9213.0, 6513.0, 4588.0, 7252.0, 7384.0, 5184.0}; - NDArray expEps(_expEpsB, input.shapeInfo()); - - input.linspace(1); - weights.linspace(1); - epsilonNext.linspace(1); - weights.permutei({2,3,1,0}); - - sd::ops::conv2d_bp op; - - auto results = op.evaluate({&input, &weights, &epsilonNext}, {}, {3, 3, 1, 1, 0, 0, 1, 1, 1}, {}); - - ASSERT_TRUE(results.size() == 2); - - auto epsilon = results.at(0); - auto gradW = results.at(1); - - ASSERT_TRUE(expWGrad.isSameShape(gradW)); - - //expWGrad.printBuffer("Expctd buffer"); - // gradW->printBuffer("Result buffer"); - ASSERT_TRUE(expWGrad.equalsTo(gradW)); - - - ASSERT_TRUE(input.isSameShape(epsilon)); - - // expEps.printBuffer("Expctd buffer"); - //epsilon->printBuffer("Result buffer"); - ASSERT_TRUE(expEps.equalsTo(epsilon)); - - -} - -TYPED_TEST(TypedConvolutionTests1, sconv2d_conv2d_1) { - - auto input = NDArrayFactory::create('c', {2, 3, 10, 10}); - auto weightsD = NDArrayFactory::create('c', {5, 5, 3, 2}, {1.f, 76.f, 26.f, 101.f, 51.f, 126.f, 2.f, 77.f, 27.f, 102.f, 52.f, 127.f, 3.f, 78.f, 28.f, 103.f, 53.f, 128.f, 4.f, 79.f, 29.f, 104.f, 54.f, 129.f, 5.f, 80.f, 30.f, 105.f, 55.f, 130.f, - 6.f, 81.f, 31.f, 106.f, 56.f, 131.f, 7.f, 82.f, 32.f, 107.f, 57.f, 132.f, 8.f, 83.f, 33.f, 108.f, 58.f, 133.f, 9.f, 84.f, 34.f, 109.f, 59.f, 134.f, 10.f, 85.f, 35.f, 110.f, 60.f, 135.f, - 11.f, 86.f, 36.f, 111.f, 61.f, 136.f, 12.f, 87.f, 37.f, 112.f, 62.f, 137.f, 13.f, 88.f, 38.f, 113.f, 63.f, 138.f, 14.f, 89.f, 39.f, 114.f, 64.f, 139.f, 15.f, 90.f, 40.f, 115.f, 65.f, 140.f, - 16.f, 91.f, 41.f, 116.f, 66.f, 141.f, 17.f, 92.f, 42.f, 117.f, 67.f, 142.f, 18.f, 93.f, 43.f, 118.f, 68.f, 143.f, 19.f, 94.f, 44.f, 119.f, 69.f, 144.f, 20.f, 95.f, 45.f, 120.f, 70.f, 145.f, - 21.f, 96.f, 46.f, 121.f, 71.f, 146.f, 22.f, 97.f, 47.f, 122.f, 72.f, 147.f, 23.f, 98.f, 48.f, 123.f, 73.f, 148.f, 24.f, 99.f, 49.f, 124.f, 74.f, 149.f, 25.f, 100.f, 50.f, 125.f, 75.f, 150.f}); - auto weightsP = NDArrayFactory::create('c', {1, 1, 6, 10}, {0.0001f, 0.0007f, 0.0013f, 0.0019f, 0.0025f, 0.0031f, 0.0037f, 0.0043f, 0.0049f, 0.0055f,0.0002f, 0.0008f, 0.0014f, 0.0020f, 0.0026f, 0.0032f, 0.0038f, 0.0044f, 0.0050f, 0.0056f, - 0.0003f, 0.0009f, 0.0015f, 0.0021f, 0.0027f, 0.0033f, 0.0039f, 0.0045f, 0.0051f, 0.0057f,0.0004f, 0.0010f, 0.0016f, 0.0022f, 0.0028f, 0.0034f, 0.0040f, 0.0046f, 0.0052f, 0.0058f, - 0.0005f, 0.0011f, 0.0017f, 0.0023f, 0.0029f, 0.0035f, 0.0041f, 0.0047f, 0.0053f, 0.0059f,0.0006f, 0.0012f, 0.0018f, 0.0024f, 0.0030f, 0.0036f, 0.0042f, 0.0048f, 0.0054f, 0.0060f}); - - auto expFF = NDArrayFactory::create('c', {2, 6, 6, 6}, {10025.0f,10350.0f,10675.0f,11000.0f,11325.0f,11650.0f,13275.0f,13600.0f,13925.0f,14250.0f,14575.0f,14900.0f,16525.0f,16850.0f, - 17175.0f,17500.0f,17825.0f,18150.0f,19775.0f,20100.0f,20425.0f,20750.0f,21075.0f,21400.0f,23025.0f,23350.0f,23675.0f,24000.0f, - 24325.0f,24650.0f,26275.0f,26600.0f,26925.0f,27250.0f,27575.0f,27900.0f,53150.0f,55350.0f,57550.0f,59750.0f,61950.0f,64150.0f, - 75150.0f,77350.0f,79550.0f,81750.0f,83950.0f,86150.0f,97150.0f,99350.0f,101550.0f,103750.0f,105950.0f,108150.0f,119150.0f, - 121350.0f,123550.0f,125750.0f,127950.0f,130150.0f,141150.0f,143350.0f,145550.0f,147750.0f,149950.0f,152150.0f,163150.0f, - 165350.0f,167550.0f,169750.0f,171950.0f,174150.0f,119400.0f,120350.0f,121300.0f,122250.0f,123200.0f,124150.0f,128900.0f, - 129850.0f,130800.0f,131750.0f,132700.0f,133650.0f,138400.0f,139350.0f,140300.0f,141250.0f,142200.0f,143150.0f,147900.0f, - 148850.0f,149800.0f,150750.0f,151700.0f,152650.0f,157400.0f,158350.0f,159300.0f,160250.0f,161200.0f,162150.0f,166900.0f, - 167850.0f,168800.0f,169750.0f,170700.0f,171650.0f,350025.0f,352850.0f,355675.0f,358500.0f,361325.0f,364150.0f,378275.0f, - 381100.0f,383925.0f,386750.0f,389575.0f,392400.0f,406525.0f,409350.0f,412175.0f,415000.0f,417825.0f,420650.0f,434775.0f, - 437600.0f,440425.0f,443250.0f,446075.0f,448900.0f,463025.0f,465850.0f,468675.0f,471500.0f,474325.0f,477150.0f,491275.0f, - 494100.0f,496925.0f,499750.0f,502575.0f,505400.0f,353775.0f,355350.0f,356925.0f,358500.0f,360075.0f,361650.0f,369525.0f, - 371100.0f,372675.0f,374250.0f,375825.0f,377400.0f,385275.0f,386850.0f,388425.0f,390000.0f,391575.0f,393150.0f,401025.0f, - 402600.0f,404175.0f,405750.0f,407325.0f,408900.0f,416775.0f,418350.0f,419925.0f,421500.0f,423075.0f,424650.0f,432525.0f, - 434100.0f,435675.0f,437250.0f,438825.0f,440400.0f,771900.0f,775350.0f,778800.0f,782250.0f,785700.0f,789150.0f,806400.0f, - 809850.0f,813300.0f,816750.0f,820200.0f,823650.0f,840900.0f,844350.0f,847800.0f,851250.0f,854700.0f,858150.0f,875400.0f, - 878850.0f,882300.0f,885750.0f,889200.0f,892650.0f,909900.0f,913350.0f,916800.0f,920250.0f,923700.0f,927150.0f,944400.0f, - 947850.0f,951300.0f,954750.0f,958200.0f,961650.0f,107525.0f,107850.0f,108175.0f,108500.0f,108825.0f,109150.0f,110775.0f, - 111100.0f,111425.0f,111750.0f,112075.0f,112400.0f,114025.0f,114350.0f,114675.0f,115000.0f,115325.0f,115650.0f,117275.0f, - 117600.0f,117925.0f,118250.0f,118575.0f,118900.0f,120525.0f,120850.0f,121175.0f,121500.0f,121825.0f,122150.0f,123775.0f, - 124100.0f,124425.0f,124750.0f,125075.0f,125400.0f,713150.0f,715350.0f,717550.0f,719750.0f,721950.0f,724150.0f,735150.0f, - 737350.0f,739550.0f,741750.0f,743950.0f,746150.0f,757150.0f,759350.0f,761550.0f,763750.0f,765950.0f,768150.0f,779150.0f, - 781350.0f,783550.0f,785750.0f,787950.0f,790150.0f,801150.0f,803350.0f,805550.0f,807750.0f,809950.0f,812150.0f,823150.0f, - 825350.0f,827550.0f,829750.0f,831950.0f,834150.0f,404400.0f,405350.0f,406300.0f,407250.0f,408200.0f,409150.0f,413900.0f, - 414850.0f,415800.0f,416750.0f,417700.0f,418650.0f,423400.0f,424350.0f,425300.0f,426250.0f,427200.0f,428150.0f,432900.0f,433850.0f,434800.0f,435750.0f,436700.0f,437650.0f,442400.0f,443350.0f,444300.0f,445250.0f,446200.0f,447150.0f,451900.0f,452850.0f,453800.0f,454750.0f,455700.0f,456650.0f,1197525.0f,1200350.0f,1203175.0f,1206000.0f,1208825.0f,1211650.0f,1225775.0f,1228600.0f,1231425.0f,1234250.0f,1237075.0f,1239900.0f,1254025.0f,1256850.0f,1259675.0f,1262500.0f,1265325.0f,1268150.0f,1282275.0f,1285100.0f,1287925.0f,1290750.0f,1293575.0f,1296400.0f,1310525.0f,1313350.0f,1316175.0f,1319000.0f,1321825.0f,1324650.0f,1338775.0f,1341600.0f,1344425.0f,1347250.0f,1350075.0f,1352900.0f,826275.0f,827850.0f,829425.0f,831000.0f,832575.0f,834150.0f,842025.0f,843600.0f,845175.0f,846750.0f,848325.0f,849900.0f,857775.0f,859350.0f,860925.0f,862500.0f,864075.0f,865650.0f,873525.0f,875100.0f,876675.0f,878250.0f,879825.0f,881400.0f,889275.0f,890850.0f,892425.0f,894000.0f,895575.0f,897150.0f,905025.0f,906600.0f,908175.0f,909750.0f,911325.0f,912900.0f,1806900.0f,1810350.0f,1813800.0f,1817250.0f,1820700.0f,1824150.0f,1841400.0f,1844850.0f,1848300.0f,1851750.0f,1855200.0f,1858650.0f,1875900.0f,1879350.0f,1882800.0f,1886250.0f,1889700.0f,1893150.0f,1910400.0f,1913850.0f,1917300.0f,1920750.0f,1924200.0f,1927650.0f,1944900.0f,1948350.0f,1951800.0f,1955250.0f,1958700.0f,1962150.0f,1979400.0f,1982850.0f,1986300.0f,1989750.0f,1993200.0f,1996650.f}); - auto exp2FF = NDArrayFactory::create('c', {2, 10, 6, 6}, {827.4900282f,832.2350283f,836.9800284f,841.725028f,846.4700287f,851.2150288f,874.9400293f,879.6850294f,884.4300295f,889.1750296f,893.9200297f,898.665029f, - 922.3900304f,927.1350305f,931.8800306f,936.6250307f,941.3700308f,946.1150309f,969.8400315f,974.5850316f,979.3300317f,984.0750318f,988.8200319f,993.5650320f, - 1017.2900326f,1022.0350327f,1026.7800328f,1031.5250329f,1036.2700330f,1041.0150331f,1064.7400337f,1069.4850338f,1074.2300339f,1078.9750340f,1083.7200341f, - 1088.4650342f,1822.4550553f,1833.995055f,1845.5350558f,1857.075056f,1868.6150563f,1880.1550566f,1937.8550578f,1949.3950581f,1960.9350583f,1972.4750586f, - 1984.015058f,1995.5550591f,2053.2550604f,2064.7950606f,2076.3350609f,2087.8750611f,2099.4150614f,2110.955061f,2168.6550629f,2180.1950632f,2191.7350634f, - 2203.2750637f,2214.8150639f,2226.3550642f,2284.0550655f,2295.5950657f,2307.1350660f,2318.6750662f,2330.2150665f,2341.7550667f,2399.4550680f,2410.9950683f, - 2422.5350685f,2434.0750688f,2445.6150690f,2457.1550693f,2817.419968f,2835.7549686f,2854.0899683f,2872.4249680f,2890.7599677f,2909.0949674f,3000.7699660f, - 3019.104965f,3037.4399655f,3055.7749652f,3074.1099649f,3092.4449646f,3184.1199632f,3202.4549629f,3220.789962f,3239.1249624f,3257.4599621f,3275.7949618f, - 3367.4699604f,3385.8049601f,3404.1399598f,3422.474959f,3440.8099593f,3459.1449590f,3550.8199576f,3569.1549573f,3587.4899570f,3605.8249567f,3624.1599565f, - 3642.4949562f,3734.1699548f,3752.5049545f,3770.8399542f,3789.1749539f,3807.5099536f,3825.8449534f,3812.385098f,3837.5150988f,3862.6450994f,3887.7751000f, - 3912.9051006f,3938.0351012f,4063.6851041f,4088.8151047f,4113.9451053f,4139.0751059f,4164.2051065f,4189.3351071f,4314.9851100f,4340.1151106f,4365.2451112f, - 4390.3751118f,4415.5051124f,4440.6351130f,4566.2851159f,4591.4151165f,4616.5451171f,4641.6751177f,4666.805118f,4691.9351188f,4817.5851218f,4842.7151224f, - 4867.8451230f,4892.975123f,4918.1051241f,4943.2351247f,5068.8851277f,5094.0151283f,5119.1451288f,5144.2751294f,5169.4051300f,5194.5351306f,4807.3499803f, - 4839.2749801f,4871.1999799f,4903.1249797f,4935.0499795f,4966.9749793f,5126.5999784f,5158.5249782f,5190.4499780f,5222.3749778f,5254.2999777f,5286.2249775f, - 5445.8499765f,5477.774976f,5509.6999762f,5541.6249760f,5573.5499758f,5605.4749756f,5765.0999747f,5797.0249745f,5828.9499743f,5860.8749741f,5892.7999739f, - 5924.724973f,6084.3499728f,6116.2749726f,6148.1999724f,6180.1249723f,6212.0499721f,6243.9749719f,6403.59997f,6435.5249708f,6467.4499706f,6499.3749704f, - 6531.2999702f,6563.2249700f,5802.3150007f,5841.0350006f,5879.7550005f,5918.4750004f,5957.195000f,5995.9150003f,6189.5149999f,6228.2349998f,6266.9549997f, - 6305.6749996f,6344.3949995f,6383.114999f,6576.7149990f,6615.4349990f,6654.1549989f,6692.8749988f,6731.5949987f,6770.3149986f,6963.9149982f,7002.6349981f, - 7041.3549981f,7080.0749980f,7118.7949979f,7157.5149978f,7351.1149974f,7389.8349973f,7428.5549972f,7467.2749972f,7505.9949971f,7544.7149970f,7738.3149966f,7777.0349965f,7815.7549964f,7854.4749963f,7893.1949963f,7931.9149962f,6797.2799488f,6842.794948f,6888.3099489f,6933.8249490f,6979.3399491f,7024.8549492f,7252.4299497f,7297.9449498f,7343.4599499f,7388.9749500f,7434.489950f,7480.0049501f,7707.5799506f,7753.0949507f,7798.6099508f,7844.1249509f,7889.6399510f,7935.1549511f,8162.7299515f,8208.2449516f,8253.7599517f,8299.2749518f,8344.7899519f,8390.3049520f,8617.8799525f,8663.394952f,8708.9099526f,8754.4249527f,8799.9399528f,8845.4549529f,9073.0299534f,9118.5449535f,9164.0599536f,9209.5749537f,9255.089953f,9300.604953f,7792.2451647f,7844.5551655f,7896.8651663f,7949.1751671f,8001.4851679f,8053.7951686f,8315.3451725f,8367.6551733f,8419.9651741f,8472.2751749f,8524.585175f,8576.8951764f,8838.4451803f,8890.7551811f,8943.0651819f,8995.3751827f,9047.6851834f,9099.9951842f,9361.5451881f,9413.8551889f,9466.1651897f,9518.475190f,9570.7851912f,9623.0951920f,9884.6451959f,9936.9551967f,9989.2651975f,10041.5751982f,10093.8851990f,10146.1951998f,10407.7452037f,10460.0552045f,10512.3652053f,10564.6752060f,10616.9852068f,10669.2952076f,8787.210074f,8846.3150748f,8905.4200750f,8964.5250752f,9023.6300755f,9082.7350757f,9378.2600768f,9437.3650770f,9496.4700773f,9555.5750775f,9614.6800777f,9673.7850779f,9969.3100791f,10028.4150793f,10087.5200795f,10146.625079f,10205.7300800f,10264.8350802f,10560.3600813f,10619.465081f,10678.5700818f,10737.6750820f,10796.7800822f,10855.8850825f,11151.4100836f,11210.5150838f,11269.6200840f,11328.7250843f,11387.8300845f,11446.9350847f,11742.4600858f,11801.5650861f,11860.6700863f,11919.7750865f,11978.880086f,12037.9850870f,9782.1750935f,9848.0750935f,9913.9750934f,9979.8750934f,10045.7750934f,10111.6750933f,10441.1750931f,10507.0750931f,10572.9750931f,10638.8750930f,10704.7750930f,10770.6750930f,11100.1750928f,11166.0750927f,11231.9750927f,11297.8750927f,11363.7750926f,11429.6750926f,11759.1750924f,11825.0750924f,11890.9750923f,11956.8750923f,12022.7750923f,12088.6750922f,12418.175092f,12484.0750920f,12549.9750920f,12615.8750919f,12681.7750919f,12747.6750919f,13077.1750917f,13143.0750916f,13208.9750916f,13274.8750916f,13340.7750915f,13406.6750915f,2250.990060f,2255.7350610f,2260.4800611f,2265.2250612f,2269.9700613f,2274.7150614f,2298.4400619f,2303.185062f,2307.9300622f,2312.6750623f,2317.4200624f,2322.1650625f,2345.8900630f,2350.6350631f,2355.380063f,2360.1250634f,2364.8700635f,2369.6150636f,2393.3400641f,2398.0850642f,2402.8300643f,2407.5750644f,2412.320064f,2417.0650647f,2440.7900652f,2445.5350653f,2450.2800654f,2455.0250655f,2459.7700656f,2464.515065f,2488.2400663f,2492.9850664f,2497.7300665f,2502.4750666f,2507.2200667f,2511.9650668f,5284.4551315f,5295.9951318f,5307.535132f,5319.0751323f,5330.6151326f,5342.1551328f,5399.8551341f,5411.3951343f,5422.9351346f,5434.475134f,5446.0151351f,5457.5551354f,5515.2551366f,5526.7951369f,5538.3351371f,5549.8751374f,5561.4151376f,5572.9551379f,5630.6551392f,5642.1951394f,5653.7351397f,5665.2751399f,5676.8151402f,5688.3551404f,5746.0551417f,5757.5951420f,5769.1351422f,5780.6751425f,5792.2151427f,5803.7551430f,5861.455144f,5872.9951445f,5884.5351448f,5896.0751450f,5907.6151453f,5919.1551455f,8317.919884f,8336.2548841f,8354.5898838f,8372.9248835f,8391.2598832f,8409.59488f,8501.2698815f,8519.6048813f,8537.9398810f,8556.2748807f,8574.6098804f,8592.9448801f,8684.6198787f,8702.9548784f,8721.2898782f,8739.6248779f,8757.9598776f,8776.2948773f,8867.9698759f,8886.3048756f,8904.6398753f,8922.9748751f,8941.3098748f,8959.6448745f,9051.3198731f,9069.6548728f,9087.9898725f,9106.3248722f,9124.6598720f,9142.9948717f,9234.6698703f,9253.0048700f,9271.3398697f,9289.6748694f,9308.0098691f,9326.3448689f,11351.3852747f,11376.5152753f,11401.6452759f,11426.7752765f,11451.9052771f,11477.0352777f,11602.6852806f,11627.8152812f,11652.9452818f,11678.0752824f,11703.2052830f,11728.335283f,11853.9852865f,11879.1152871f,11904.2452877f,11929.3752883f,11954.505288f,11979.6352894f,12105.2852924f,12130.4152930f,12155.545293f,12180.6752941f,12205.8052947f,12230.9352953f,12356.5852983f,12381.715298f,12406.8452994f,12431.9753000f,12457.1053006f,12482.2353012f,12607.8853041f,12633.0153047f,12658.1453053f,12683.2753059f,12708.4053065f,12733.5353071f,14384.8499244f,14416.7749242f,14448.6999240f,14480.6249238f,14512.549923f,14544.4749235f,14704.0999225f,14736.024922f,14767.9499222f,14799.8749220f,14831.7999218f,14863.7249216f,15023.3499207f,15055.2749205f,15087.1999203f,15119.1249201f,15151.0499199f,15182.9749197f,15342.5999188f,15374.5249186f,15406.4499184f,15438.374918f,15470.2999181f,15502.2249179f,15661.84991f,15693.7749168f,15725.6999166f,15757.6249164f,15789.5499162f,15821.4749160f,15981.0999151f,16013.0249149f,16044.9499147f,16076.8749145f,16108.7999143f,16140.7249142f,17418.314976f,17457.0349761f,17495.7549760f,17534.4749759f,17573.1949758f,17611.9149757f,17805.5149753f,17844.234975f,17882.9549752f,17921.6749751f,17960.3949750f,17999.1149749f,18192.7149745f,18231.4349744f,18270.154974f,18308.8749743f,18347.5949742f,18386.3149741f,18579.9149737f,18618.6349736f,18657.3549735f,18696.074973f,18734.7949734f,18773.5149733f,18967.1149729f,19005.8349728f,19044.5549727f,19083.2749726f,19121.994972f,19160.7149725f,19354.3149721f,19393.0349720f,19431.7549719f,19470.4749718f,19509.1949717f,19547.914971f,20451.7799765f,20497.2949766f,20542.8099767f,20588.3249768f,20633.8399769f,20679.3549770f,20906.929977f,20952.4449775f,20997.9599776f,21043.4749777f,21088.9899778f,21134.5049779f,21362.0799784f,21407.5949785f,21453.1099786f,21498.624978f,21544.139978f,21589.6549788f,21817.2299793f,21862.7449794f,21908.2599795f,21953.7749796f,21999.2899797f,22044.8049798f,22272.3799802f,22317.8949803f,22363.4099804f,22408.9249805f,22454.4399806f,22499.9549807f,22727.529981f,22773.044981f,22818.5599813f,22864.0749814f,22909.5899815f,22955.1049816f,23485.2453985f,23537.555399f,23589.8654000f,23642.1754008f,23694.4854016f,23746.7954024f,24008.3454063f,24060.655407f,24112.9654078f,24165.2754086f,24217.5854094f,24269.8954102f,24531.4454141f,24583.7554148f,24636.0654156f,24688.3754164f,24740.6854172f,24792.99541f,25054.545421f,25106.8554226f,25159.1654234f,25211.4754242f,25263.7854250f,25316.0954257f,25577.6454296f,25629.9554304f,25682.2654312f,25734.5754320f,25786.8854328f,25839.1954335f,26100.7454374f,26153.0554382f,26205.3654390f,26257.6754398f,26309.985440f,26362.2954413f,26518.7101423f,26577.8151425f,26636.920142f,26696.0251430f,26755.1301432f,26814.2351434f,27109.7601446f,27168.8651448f,27227.9701450f,27287.0751452f,27346.1801455f,27405.2851457f,27700.8101468f,27759.9151470f,27819.0201473f,27878.1251475f,27937.2301477f,27996.33514f,28291.8601491f,28350.9651493f,28410.0701495f,28469.175149f,28528.2801500f,28587.3851502f,28882.9101513f,28942.0151516f,29001.1201518f,29060.2251520f,29119.3301522f,29178.4351525f,29473.9601536f,29533.0651538f,29592.1701540f,29651.2751543f,29710.3801545f,29769.4851547f,29552.1750826f,29618.0750825f,29683.9750825f,29749.8750825f,29815.7750824f,29881.6750824f,30211.1750822f,30277.0750822f,30342.9750821f,30408.8750821f,30474.7750821f,30540.6750820f,30870.175081f,30936.0750818f,31001.9750818f,31067.8750817f,31133.7750817f,31199.6750817f,31529.1750815f,31595.075081f,31660.9750814f,31726.8750814f,31792.7750813f,31858.6750813f,32188.1750811f,32254.0750811f,32319.975081f,32385.8750810f,32451.7750810f,32517.6750809f,32847.1750808f,32913.0750807f,32978.9750807f,33044.875080f,33110.7750806f,33176.67508062f}); - - input.linspace(1); - - sd::ops::sconv2d op; - auto resultFF = op.evaluate({&input, &weightsD}, {}, {5, 5, 1, 1, 0, 0, 1, 1, 0}, {}); - - auto z = resultFF.at(0); - - ASSERT_TRUE(z->isSameShape(&expFF)); - ASSERT_TRUE(z->equalsTo(&expFF, 1)); - - - sd::ops::conv2d op2d; - // weightsP.printShapeInfo(); - auto result2D = op2d.evaluate({z, &weightsP}, {}, {1, 1, 1, 1, 0, 0, 1, 1, 0, 0}, {}); - - auto z2d = result2D.at(0); - // z2d->printBuffer(); - - ASSERT_TRUE(z2d->isSameShape(&exp2FF)); - ASSERT_TRUE(z2d->equalsTo(&exp2FF)); -} - -TEST_F(ConvolutionTests1, deconv2d_bp_1) { - - int bS=3, iH=4,iW=4, iC=3,oC=2, kH=1,kW=1, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oH=4,oW=4; - int paddingMode = 1; // 1-SAME, 0-VALID; - int dataFormat = 0; // 1-NHWC, 0-NCHW - - NDArray input('c', {bS, iC, iH, iW}, sd::DataType::FLOAT32); - NDArray bias('c', {oC}, sd::DataType::FLOAT32); - NDArray weights('c',{kH,kW,oC,iC}, {1,3,5,2,4,6}, sd::DataType::FLOAT32); - NDArray gradO('c', {bS, oC, oH, oW},sd::DataType::FLOAT32); - - NDArray expGradI('c', {bS, iC, iH, iW}, {35.f, 38.f, 41.f, 44.f, 47.f, 50.f, 53.f, 56.f, 59.f, 62.f, 65.f, 68.f, 71.f, 74.f, - 77.f, 80.f, 71.f, 78.f, 85.f, 92.f, 99.f, 106.f, 113.f, 120.f, 127.f, 134.f, 141.f, 148.f, 155.f, 162.f, 169.f, - 176.f, 107.f, 118.f, 129.f, 140.f, 151.f, 162.f, 173.f, 184.f, 195.f, 206.f, 217.f, 228.f, 239.f, 250.f, 261.f, 272.f, - 131.f, 134.f, 137.f, 140.f, 143.f, 146.f, 149.f, 152.f, 155.f, 158.f, 161.f, 164.f, 167.f, 170.f, 173.f, 176.f, 295.f, - 302.f, 309.f, 316.f, 323.f, 330.f, 337.f, 344.f, 351.f, 358.f, 365.f, 372.f, 379.f, 386.f, 393.f, 400.f, 459.f, 470.f, - 481.f, 492.f, 503.f, 514.f, 525.f, 536.f, 547.f, 558.f, 569.f, 580.f, 591.f, 602.f, 613.f, 624.f, 227.f, 230.f, 233.f, - 236.f, 239.f, 242.f, 245.f, 248.f, 251.f, 254.f, 257.f, 260.f, 263.f, 266.f, 269.f, 272.f, 519.f, 526.f, 533.f, 540.f, - 547.f, 554.f, 561.f, 568.f, 575.f, 582.f, 589.f, 596.f, 603.f, 610.f, 617.f, 624.f, 811.f, 822.f, 833.f, 844.f, 855.f, - 866.f, 877.f, 888.f, 899.f, 910.f, 921.f, 932.f, 943.f, 954.f, 965.f, 976.f}, sd::DataType::FLOAT32); - NDArray expGradW('c', {kH, kW, oC, iC}, {160008., 191112., 222216., 203400., 246792., 290184.f}, sd::DataType::FLOAT32); - NDArray expGradB('c', {oC}, {1944.f, 2712.f}, sd::DataType::FLOAT32); - - input.linspace(1); - bias.linspace(1); - gradO.linspace(1); - - - sd::ops::deconv2d_bp op; - auto results = op.evaluate({&input, &weights, &bias, &gradO}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto gradI = results.at(0); - auto gradW = results.at(1); - auto gradB = results.at(2); - - ASSERT_TRUE(expGradI.isSameShape(gradI)); - ASSERT_TRUE(expGradI.equalsTo(gradI)); - - ASSERT_TRUE(expGradW.isSameShape(gradW)); - ASSERT_TRUE(expGradW.equalsTo(gradW)); - - ASSERT_TRUE(expGradB.isSameShape(gradB)); - ASSERT_TRUE(expGradB.equalsTo(gradB)); - -} - -////////////////////////////////////////////////////////////////////// -TEST_F(ConvolutionTests1, deconv2d_bp_2) { - - int bS=3, iH=4,iW=4, iC=3,oC=2, kH=2,kW=1, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oH=4,oW=4; // 5,4 - int paddingMode = 1; // 1-SAME, 0-VALID; - int dataFormat = 0; // 1-NHWC, 0-NCHW - int wFormat = 1; // 0 - [kH, kW, oC, iC], 1 - [iC, oC, kH, kW], 2 - [iC, kH, kW, oC] - - NDArray input('c', {bS, iC, iH, iW}, sd::DataType::FLOAT32); - NDArray bias('c', {oC}, {-0.1, 0.2}, sd::DataType::FLOAT32); - NDArray weights('c',{iC, oC, kH, kW}, {1., 7., 2., 10., 3., 8., 4., 11., 5., 9., 6., 12.}, sd::DataType::FLOAT32); - NDArray gradO('c', {bS, oC, oH, oW},sd::DataType::FLOAT32); - - NDArray expGradI('c', {bS, iC, iH, iW}, {-77.400002, -77.199997, -77., -76.800003, -76.599998, -76.400002, -76.200005, -76., -75.800003, -75.599998, -75.399994, - -75.199997, -11.32, -11.29, -11.26, -11.23, -100.839996, -100.580002, -100.32, -100.059998, -99.800003, -99.540001, -99.279999, -99.019997, -98.760002, -98.50, - -98.240005, -97.979996, -26.52, -26.450001, -26.380001, -26.309999, -124.279999, -123.959991, -123.639999, -123.32, -123., -122.68, -122.360001, -122.040001, - -121.720001, -121.400009, -121.080002, -120.759995, -41.720001, -41.610001, -41.50, -41.389999, -71., -70.800003, -70.599998, -70.399994, -70.199997, -70., -69.800003, -69.600006, -69.400002, -69.199997, -69., -68.799995, -10.360001, -10.33, -10.30, -10.27, -92.519997, -92.260002, -92., -91.740005, -91.479996, -91.220001, -90.960007, -90.700005, -90.440002, -90.18, -89.919998, -89.660004, -24.280001, -24.209999, -24.139999, -24.07, -114.040001, -113.720001, -113.400009, -113.080002, -112.759995, -112.440002, -112.120003, -111.800003, -111.480003, -111.159996, -110.839996, -110.520004, -38.200001, -38.09, -37.980003, -37.869999, -64.599998, -64.400002, -64.199997, -64., -63.799995, -63.599998, -63.400002, -63.199997, -63., -62.799995, -62.599998, -62.400002, -9.40, -9.37, -9.34, -9.309999, -84.200005, -83.940002, -83.68, -83.419998, -83.160004, -82.900002, -82.639999, -82.379997, -82.119995, -81.860001, -81.600006, -81.339996, -22.040001, -21.970001, -21.90, -21.83, -103.800003, -103.480003, -103.159996, -102.839996, -102.520004, -102.200005, -101.879997, -101.559998, -101.239998, -100.919998, -100.599998, -100.279999, -34.68, -34.57, -34.459999, -34.349998}, sd::DataType::FLOAT32); - - NDArray expGradW('c', {iC, oC, kH, kW}, {-3010.799805, -2502.420410, -2899.439209, -2407.380615, -242.159332, -437.460510, -253.680466, -434.580048, 2526.479980, 1627.500000, 2392.079834, 1538.220093}, sd::DataType::FLOAT32); - NDArray expGradB('c', {oC}, {-173.040009, -165.360016}, sd::DataType::FLOAT32); - - input.linspace(70., -1); - gradO.linspace(-4, 0.01); - - sd::ops::deconv2d_bp op; - auto results = op.evaluate({&input, &weights, &bias, &gradO}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat, wFormat}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto gradI = results.at(0); - auto gradW = results.at(1); - auto gradB = results.at(2); - - ASSERT_TRUE(expGradI.isSameShape(gradI)); - ASSERT_TRUE(expGradI.equalsTo(gradI)); - - ASSERT_TRUE(expGradW.isSameShape(gradW)); - ASSERT_TRUE(expGradW.equalsTo(gradW)); - - ASSERT_TRUE(expGradB.isSameShape(gradB)); - ASSERT_TRUE(expGradB.equalsTo(gradB)); -} - -////////////////////////////////////////////////////////////////////// -TEST_F(ConvolutionTests1, deconv2d_bp_3) { - - int bS=3, iH=4,iW=4, iC=3,oC=2, kH=2,kW=1, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oH=5,oW=4; - int paddingMode = 0; // 1-SAME, 0-VALID; - int dataFormat = 1; // 1-NHWC, 0-NCHW - int wFormat = 2; // 0 - [kH, kW, oC, iC], 1 - [iC, oC, kH, kW], 2 - [iC, kH, kW, oC] - - NDArray input('c', {bS, iH, iW, iC}, sd::DataType::FLOAT32); - NDArray bias('c', {oC}, {-0.1, 0.2}, sd::DataType::FLOAT32); - NDArray weights('c',{iC, kH, kW, oC}, {1., 4., 7., 10., 2., 5., 8., 11., 3., 6., 9., 12.}, sd::DataType::FLOAT32); - NDArray gradO('c', {bS, oH, oW, oC}, sd::DataType::FLOAT32); - - NDArray expGradI('c', {bS, iH, iW, iC}, {-86.5, -102.320007, -118.139999, -86.060005, -101.800003, -117.540001, -85.619995, -101.279999, -116.940002, -85.18, - -100.759995, -116.339996, -84.740005, -100.239998, -115.739998, -84.300003, -99.720001, -115.139999, -83.860001, -99.199997, -114.539993, -83.419998, -98.68, - -113.939995, -82.979996, -98.160004, -113.339996, -82.539993, -97.639999, -112.739998, -82.099998, -97.120003, -112.139999, -81.660004, -96.600006, -111.539993, - -81.220001, -96.080002, -110.939995, -80.779999, -95.559998, -110.340012, -80.340004, -95.040001, -109.740005, -79.900002, -94.519997, -109.139992, -77.699997, - -91.919998, -106.139999, -77.260002, -91.400002, -105.540001, -76.820007, -90.880005, -104.940002, -76.380005, -90.360001, -104.339996, -75.940002, -89.839996, -103.740005, -75.5, -89.320007, -103.139999, -75.060005, -88.800003, -102.540001, -74.619995, -88.279999, -101.940002, -74.18, -87.759995, -101.339996, -73.740005, -87.239998, -100.739998, -73.300003, -86.720001, -100.139999, -72.860001, -86.199997, -99.539993, -72.419998, -85.68, -98.939995, -71.979996, -85.160004, -98.339996, -71.539993, -84.639999, -97.740005, -71.099998, -84.120003, -97.139999, -68.899994, -81.519997, -94.139999, -68.459999, -81.00, -93.539993, -68.019997, -80.479996, -92.940002, -67.580002, -79.959999, -92.339996, -67.139999, -79.440002, -91.740005, -66.699997, -78.919998, -91.139999, -66.260002, -78.399994, -90.540001, -65.820007, -77.880005, -89.940002, -65.380005, -77.360001, -89.339996, -64.940002, -76.839996, -88.740005, -64.5, -76.320007, -88.139999, -64.060005, -75.800003, -87.540001, -63.619995, -75.279999, -86.940002, -63.18, -74.759995, -86.339996, -62.739998, -74.239998, -85.739998, -62.299999, -73.720001, -85.139999}, sd::DataType::FLOAT32); - - NDArray expGradW('c', {iC, kH, kW, oC}, {-592.800110, -593.039917, -594.719116, -594.960266, -427.199890, -427.919617, -432.959900, -433.679993, -261.600281, -262.799591, -271.200317, -272.399536}, sd::DataType::FLOAT32); - NDArray expGradB('c', {oC}, {-204.600006, -204.}, sd::DataType::FLOAT32); - - input.linspace(70., -1); - gradO.linspace(-4, 0.01); - - sd::ops::deconv2d_bp op; - auto results = op.evaluate({&input, &weights, &bias, &gradO}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat, wFormat}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto gradI = results.at(0); - auto gradW = results.at(1); - auto gradB = results.at(2); - - ASSERT_TRUE(expGradI.isSameShape(gradI)); - ASSERT_TRUE(expGradI.equalsTo(gradI)); - - ASSERT_TRUE(expGradW.isSameShape(gradW)); - ASSERT_TRUE(expGradW.equalsTo(gradW)); - - ASSERT_TRUE(expGradB.isSameShape(gradB)); - ASSERT_TRUE(expGradB.equalsTo(gradB)); -} - -TYPED_TEST(TypedConvolutionTests1, Test_Conv1D_ff_1) { - auto input = NDArrayFactory::create('c', {2, 2, 6}); - auto weights = NDArrayFactory::create('c', {2, 2, 3}, {1,5,9,3,7,11,2,6,10,4,8,12}); - auto bias = NDArrayFactory::create('c', {3}); - auto expFF = NDArrayFactory::create('c', {2, 3, 5}, {59.0f, 69.0f, 79.0f, 89.0f, 99.0f, 132.0f, 158.0f, 184.0f, 210.0f, 236.0f, 205.0f, 247.0f, 289.0f, 331.0f, 373.0f, 179.0f, 189.0f, 199.0f, 209.0f, 219.0f, 444.0f, 470.0f, 496.0f, 522.0f, 548.0f, 709.0f, 751.0f, 793.0f, 835.0f, 877.0f}); - auto expEps = NDArrayFactory::create('c', {2, 2, 6}, {130.0f, 293.0f, 326.0f, 359.0f, 392.0f, 220.0f, 166.0f, 371.0f, 416.0f, 461.0f, 506.0f, 280.0f, 355.0f, 788.0f, 821.0f, 854.0f, 887.0f, 490.0f, 481.0f, 1046.0f, 1091.0f, 1136.0f, 1181.0f, 640.0f}); - auto expGW = NDArrayFactory::create('c', {3, 2, 2}, {1415.0f, 1520.0f, 2045.0f, 2150.0f, 1865.0f, 2020.0f, 2795.0f, 2950.0f, 2315.0f, 2520.0f, 3545.0f, 3750.0f}); - auto expGB = NDArrayFactory::create('c', {3}, {105.0f, 155.0f, 205.0f}); - - expGW.permutei({2,1,0}); - input.linspace(1); - bias.linspace(1); - - sd::ops::conv1d op; - auto result_FF = op.evaluate({&input, &weights, &bias}, {}, {2, 1, 0, 1, 0, 0}); - - ASSERT_EQ(ND4J_STATUS_OK, result_FF.status()); - - auto z = result_FF.at(0); - - ASSERT_TRUE(expFF.isSameShape(z)); - ASSERT_TRUE(expFF.equalsTo(z)); - - sd::ops::conv1d_bp op_bp; - - auto epsilonNxt = new NDArray(z->dup()); - epsilonNxt->linspace(1); - - auto result_BP = op_bp.evaluate({&input, &weights, &bias, epsilonNxt}, {}, {2, 1, 0, 1, 0, 0}); - ASSERT_EQ(ND4J_STATUS_OK, result_BP.status()); - - auto eps = result_BP.at(0); - auto gradW = result_BP.at(1); - auto gradB = result_BP.at(2); - - ASSERT_TRUE(expEps.isSameShape(eps)); - ASSERT_TRUE(expGW.isSameShape(gradW)); - ASSERT_TRUE(expGB.isSameShape(gradB)); - - ASSERT_TRUE(expEps.equalsTo(eps)); - ASSERT_TRUE(expGW.equalsTo(gradW)); - ASSERT_TRUE(expGB.equalsTo(gradB)); - - delete epsilonNxt; -} - - -TYPED_TEST(TypedConvolutionTests1, Test_Conv1D_ff_2) { - auto input = NDArrayFactory::create('c', {2, 2, 6}); - auto weights = NDArrayFactory::create('c', {2, 2, 3}, {1.f, 5.f, 9.f, 3.f, 7.f, 11.f, 2.f, 6.f, 10.f, 4.f, 8.f, 12.f}); - - input.linspace(1); - - sd::ops::conv1d op; - auto result = op.evaluate({&input, &weights}, {}, {2, 1, 0, 1, 1,0}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - -} - -////////////////////////////////////////////////////////////////////// -TEST_F(ConvolutionTests1, conv1d_causal_1) { - - int bS=2, iW=3, iC=4,oC=3, kW=2, sW=1, pW=0, dW=1; - int oW = (iW-1)/sW + 1; - int paddingMode = 2; // CAUSAL - int dataFormat = 1; // 1-NHWC, 0-NCHW - - NDArray input('c', {bS, iW, iC}); - NDArray weights('c', {kW, iC, oC}); - NDArray bias('c', {oC}, {-1,-2,-3}); - - NDArray expOutput('c', {bS, oW, oC}, {18. , 18. , 18. , 53. , 55.6, 58.2, 89.8, 95.6, 101.4, 102. , 106.8, 111.6, 163.4, 175.6, 187.8, 200.2, 215.6, 231.}); - - input.linspace(1., 1.); - weights.linspace(0.1, 0.1); - - sd::ops::conv1d op; - auto results = op.evaluate({&input, &weights, &bias}, {kW, sW, pW, dW, paddingMode, dataFormat}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expOutput.isSameShape(output)); - ASSERT_TRUE(expOutput.equalsTo(output)); - -} - -////////////////////////////////////////////////////////////////////// -TEST_F(ConvolutionTests1, conv1d_causal_2) { - - int bS=2, iW=16, iC=3,oC=4, kW=2, sW=2, pW=0, dW=1; - int oW = (iW-1)/sW + 1; - int paddingMode = 2; // CAUSAL - int dataFormat = 1; // 1-NHWC, 0-NCHW - - NDArray input('c', {bS, iW, iC}); - NDArray weights('c', {kW, iC, oC}); - NDArray bias('c', {oC}, {-1,-2,-3,-4}); - - NDArray expOutput('c', {bS, oW, oC}, { 10. , 9.6, 9.2, 8.8, 48.9, 51.8, 54.7, 57.6, 88.5, 95. , 101.5, 108. , 128.1, 138.2, 148.3, 158.4, - 167.7, 181.4, 195.1, 208.8, 207.3, 224.6, 241.9, 259.2, 246.9, 267.8, 288.7, 309.6, 286.5, 311. , 335.5, 360. , - 254.8, 268.8, 282.8, 296.8, 365.7, 397.4, 429.1, 460.8, 405.3, 440.6, 475.9, 511.2, 444.9, 483.8, 522.7, 561.6, - 484.5, 527. , 569.5, 612. , 524.1, 570.2, 616.3, 662.4, 563.7, 613.4, 663.1, 712.8, 603.3, 656.6, 709.9, 763.2}); - - input.linspace(1., 1.); - weights.linspace(0.1, 0.1); - - sd::ops::conv1d op; - auto results = op.evaluate({&input, &weights, &bias}, {kW, sW, pW, dW, paddingMode, dataFormat}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expOutput.isSameShape(output)); - ASSERT_TRUE(expOutput.equalsTo(output)); - -} - -////////////////////////////////////////////////////////////////////// -TEST_F(ConvolutionTests1, conv1d_causal_3) { - - int bS=2, iW=16, iC=3,oC=4, kW=3, sW=3, pW=0, dW=1; - int oW = (iW-1)/sW + 1; - int paddingMode = 2; // CAUSAL - int dataFormat = 1; // 1-NHWC, 0-NCHW - - NDArray input('c', {bS, iW, iC}); - NDArray weights('c', {kW, iC, oC}); - NDArray bias('c', {oC}, {-1,-2,-3,-4}); - - NDArray expOutput('c', {bS, oW, oC}, {17.2, 16.8, 16.4, 16.,145.4, 151.6, 157.8, 164.,283.1, 297.4, 311.7, 326., 420.8, 443.2, 465.6, 488., - 558.5, 589., 619.5, 650.,696.2001, 734.8, 773.4, 812., 434.8, 448.8, 462.8, 476.8, 879.8, 929.2, 978.6, 1028., - 1017.5, 1075., 1132.5, 1190.,1155.2001, 1220.8, 1286.4, 1352.,1292.8999, 1366.6, 1440.3, 1514., 1430.6001, 1512.4, 1594.2, 1676.}); - - input.linspace(1., 1.); - weights.linspace(0.1, 0.1); - - sd::ops::conv1d op; - auto results = op.evaluate({&input, &weights, &bias}, {kW, sW, pW, dW, paddingMode, dataFormat}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expOutput.isSameShape(output)); - ASSERT_TRUE(expOutput.equalsTo(output)); - -} - -////////////////////////////////////////////////////////////////////// -TEST_F(ConvolutionTests1, conv1d_causal_4) { - - int bS=2, iW=8, iC=3,oC=4, kW=3, sW=1, pW=0, dW=3; - int oW = (iW-1)/sW + 1; - int paddingMode = 2; // CAUSAL - int dataFormat = 1; // 1-NHWC, 0-NCHW - - NDArray input('c', {bS, iW, iC}); - NDArray weights('c', {kW, iC, oC}); - NDArray bias('c', {oC}, {-1,-2,-3,-4}); - - NDArray expOutput('c', {bS, oW, oC}, {17.2, 16.8, 16.4, 16. ,43.3, 43.8, 44.3, 44.8,69.4, 70.8, 72.2, 73.6,106.5, 109.4, 112.3, 115.2,147.9, 152.6, 157.3, 162. ,189.3, 195.8, 202.3, - 208.8,234.5, 243.4, 252.3, 261.2,280.4, 292. , 303.6, 315.2, 226. , 232.8, 239.6, 246.4, 252.1, 259.8, 267.5, 275.2,278.2, 286.8, 295.4, 304. ,437.7, - 455. , 472.3, 489.6,479.1, 498.2, 517.3, 536.4,520.5, 541.4, 562.3, 583.2, 601.7, 632.2, 662.7, 693.2, 647.6, 680.8, 714. , 747.2}); - - input.linspace(1., 1.); - weights.linspace(0.1, 0.1); - - sd::ops::conv1d op; - auto results = op.evaluate({&input, &weights, &bias}, {kW, sW, pW, dW, paddingMode, dataFormat}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expOutput.isSameShape(output)); - ASSERT_TRUE(expOutput.equalsTo(output)); - -} - -////////////////////////////////////////////////////////////////////// -TEST_F(ConvolutionTests1, conv1d_causal_5) { - - int bS=2, iW=8, iC=3,oC=4, kW=3, sW=1, pW=0, dW=3; - int oW = (iW-1)/sW + 1; - int paddingMode = 2; // CAUSAL - int dataFormat = 0; // 1-NHWC, 0-NCHW - - NDArray input('c', {bS, iC, iW}); - NDArray weights('c', {kW, iC, oC}); - NDArray bias('c', {oC}, {-1,-2,-3,-4}); - - NDArray expOutput('c', {bS, oC, oW}, { 83.7, 92.4, 101.1, 162.1, 175.9, 189.7, 223.4, 238.7,85.4, 94.4, 103.4, 167.4, 181.8, 196.2, 233.2, 249.4,87.1, 96.4, 105.7, 172.7, 187.7, 202.7, 243. , 260.1, - 88.8, 98.4, 108. , 178. , 193.6, 209.2, 252.8, 270.8, 292.5, 301.2, 309.9, 493.3, 507.1, 520.9, 590.6, 605.9, 301.4, 310.4, 319.4, 513. , 527.4, 541.8, 622. , 638.2, - 310.3, 319.6, 328.9, 532.7, 547.7, 562.7, 653.4, 670.5, 319.2, 328.8, 338.4, 552.4, 568. , 583.6, 684.8, 702.8}); - - input.linspace(1., 1.); - weights.linspace(0.1, 0.1); - - sd::ops::conv1d op; - auto results = op.evaluate({&input, &weights, &bias}, {kW, sW, pW, dW, paddingMode, dataFormat}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expOutput.isSameShape(output)); - ASSERT_TRUE(expOutput.equalsTo(output)); - -} - -////////////////////////////////////////////////////////////////////// -TEST_F(ConvolutionTests1, conv1d_causal_6) { - - int bS=2, iW=16, iC=3,oC=4, kW=3, sW=3, pW=0, dW=1; - int oW = (iW-1)/sW + 1; - int paddingMode = 2; // CAUSAL - int dataFormat = 0; // 1-NHWC, 0-NCHW - - NDArray input('c', {bS, iC, iW}); - NDArray weights('c', {kW, iC, oC}); - NDArray bias('c', {oC}, {-1,-2,-3,-4}); - - NDArray expOutput('c', {bS, oC, oW}, {159.7,335.3,381.2,427.1,473. ,518.9,163.8,351.4,400. ,448.6,497.2,545.8,167.9,367.5,418.8,470.1,521.4,572.7,172. ,383.6,437.6,491.6,545.6,599.6, - 577.3, 1069.7, 1115.6, 1161.5, 1207.4, 1253.3,595.8, 1129. , 1177.6, 1226.2, 1274.8, 1323.4,614.3, 1188.3, 1239.6, 1290.9, 1342.2, 1393.5, - 632.8, 1247.6, 1301.6, 1355.6, 1409.6, 1463.6}); - - input.linspace(1., 1.); - weights.linspace(0.1, 0.1); - - sd::ops::conv1d op; - auto results = op.evaluate({&input, &weights, &bias}, {kW, sW, pW, dW, paddingMode, dataFormat}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expOutput.isSameShape(output)); - ASSERT_TRUE(expOutput.equalsTo(output)); - -} - -////////////////////////////////////////////////////////////////////// -TEST_F(ConvolutionTests1, conv1d_causal_7) { - - int bS=2, iW=8, iC=3,oC=4, kW=2, sW=1, pW=0, dW=1; - int oW = (iW-1)/sW + 1; - int paddingMode = 2; // CAUSAL - int dataFormat = 1; // 1-NHWC, 0-NCHW - - NDArray input('c', {bS, iW, iC}, sd::DataType::FLOAT32); - NDArray weights('c', {kW, iC, oC}, sd::DataType::FLOAT32); - - NDArray expOutput('c', {bS, oW, oC}, {11.000000, 11.600000, 12.200000, 12.800000, 30.099998, 32.200001, 34.299999, 36.400002, 49.899998, 53.800003, 57.699997, - 61.599998, 69.699997, 75.400002, 81.099998, 86.800003, 89.500000, 97.000000, 104.500000, 112.000000, 109.300003, 118.600006, 127.899994, 137.199997, 129.100006, - 140.199997, 151.300003, 162.399994, 148.899994, 161.800003, 174.699997, 187.600006, 133.399994, 141.200012, 149.000000, 156.800003, 188.500000, 205.000000, - 221.500000, 238.000000, 208.299988, 226.600006, 244.899994, 263.200012, 228.100006, 248.200012, 268.299988, 288.399994, 247.899994, 269.799988, 291.700012, - 313.600006, 267.700012, 291.399994, 315.100006, 338.799988, 287.500000, 313.000000, 338.500000, 364.000000, 307.299988, 334.600006, 361.899994, 389.200012}, sd::DataType::FLOAT32); - - input.linspace(1., 1.); - weights.linspace(0.1, 0.1); - - sd::ops::conv1d op; - auto results = op.evaluate({&input, &weights}, {kW, sW, pW, dW, paddingMode, dataFormat}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expOutput.isSameShape(output)); - ASSERT_TRUE(expOutput.equalsTo(output)); - -} - -////////////////////////////////////////////////////////////////////// -TEST_F(ConvolutionTests1, conv1d_causal_8) { - - int bS=2, iW=8, iC=3,oC=4, kW=2, sW=1, pW=0, dW=2; - int oW = (iW-1)/sW + 1; - int paddingMode = 2; // CAUSAL - int dataFormat = 1; // 1-NHWC, 0-NCHW - - NDArray input('c', {bS, iW, iC}, sd::DataType::FLOAT32); - NDArray weights('c', {kW, iC, oC}, sd::DataType::FLOAT32); - - NDArray expOutput('c', {bS, oW, oC}, {11.000000, 11.600000, 12.200000, 12.800000, 26.299999, 27.799999, 29.299999, 30.799999, 45.399998, 48.399998, - 51.400002, 54.400005, 65.199997, 70.000000, 74.800003, 79.600006, 85.000000, 91.600006, 98.199997, 104.800003, 104.799995, 113.199997, 121.600006, - 130.000000, 124.599998, 134.800003, 145.000000, 155.200012, 144.399994, 156.399994, 168.399994, 180.400009, 133.400009, 141.199997, 149.000000, - 156.800003, 148.699997, 157.400009, 166.099991, 174.800003, 203.800003, 221.200012, 238.599991, 256.000000, 223.599991, 242.799988, 262.000000, - 281.200012, 243.399994, 264.399994, 285.399994, 306.399994, 263.199982, 286.000000, 308.799988, 331.600006, 283.000000, 307.600006, 332.200012, - 356.800018, 302.799988, 329.199982, 355.600006, 382.000000}, sd::DataType::FLOAT32); - - input.linspace(1., 1.); - weights.linspace(0.1, 0.1); - - sd::ops::conv1d op; - auto results = op.evaluate({&input, &weights}, {kW, sW, pW, dW, paddingMode, dataFormat}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expOutput.isSameShape(output)); - ASSERT_TRUE(expOutput.equalsTo(output)); - -} - -////////////////////////////////////////////////////////////////////// -TEST_F(ConvolutionTests1, conv1d_causal_bp_1) { - - int bS=2, iW=3, iC=4,oC=3, kW=2, sW=1, pW=0, dW=1; - int oW = (iW-1)/sW + 1; - int paddingMode = 2; // CAUSAL - int dataFormat = 1; // 1-NHWC, 0-NCHW - - NDArray input('c', {bS, iW, iC}); - NDArray weights('c', {kW, iC, oC}); - NDArray bias('c', {oC}, {-1,-2,-3}); - NDArray gradO('c', {bS, oW, oC}); - - input.linspace(1., 1.); - weights.linspace(0.1, 0.1); - gradO.linspace(-1.5, 0.1); - - const OpArgsHolder argsHolderFF({&input, &weights, &bias}, {}, {kW, sW, pW, dW, paddingMode, dataFormat}); - const OpArgsHolder argsHolderBP({&input, &weights, &bias, &gradO}, {}, {kW, sW, pW, dW, paddingMode, dataFormat}); - - sd::ops::conv1d opFF; - sd::ops::conv1d_bp opBP; - - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); - - ASSERT_TRUE(isGradCorrect); -} - -TEST_F(ConvolutionTests1, Test_Dilation2D_1) { - auto input = NDArrayFactory::create('c', {2, 6, 6, 3}); - auto weights = NDArrayFactory::create('c', {3, 2, 3}); - auto exp = NDArrayFactory::create('c', {2, 3, 3, 3}, {77, 79, 81, 83, 85, 87, 80, 82, 84, 113, 115, 117, 119, 121, 123, 116, 118, 120, 107, 109, 111, 113, 115, 117, 110, 112, 114, 185, 187, 189, 191, 193, 195, 188, 190, 192, 221, 223, 225, 227, 229, 231, 224, 226, 228, 215, 217, 219, 221, 223, 225, 218, 220, 222,}); - - input.linspace(1); - weights.linspace(1); - - sd::ops::dilation2d op; - auto result = op.evaluate({&input, &weights}, {1, 1,2,2,1, 1,2,2,1}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); -} - -TEST_F(ConvolutionTests1, Test_Dilation2D_2) { - auto input = NDArrayFactory::create('c', {2, 6, 6, 3}); - auto weights = NDArrayFactory::create('c', {3, 2, 3}); - auto exp = NDArrayFactory::create('c', {2, 1, 2, 3}, {95, 97, 99, 101, 103, 105, 203, 205, 207, 209, 211, 213}); - - input.linspace(1); - weights.linspace(1); - - sd::ops::dilation2d op; - auto result = op.evaluate({&input, &weights}, {0, 1,2,2,1, 1,2,2,1}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - -} - -////////////////////////////////////////////////////////////////////// -TYPED_TEST(TypedConvolutionTests1, conv2d_bp_test1) { - - int bS=2, iH=4,iW=3, iC=4,oC=3, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oH=4,oW=3; - int paddingMode = 1; // 1-SAME, 0-VALID; - int dataFormat = 1; // 1-NHWC, 0-NCHW - - auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); - auto weights = NDArrayFactory::create('c', {kH, kW, iC, oC}); - auto bias = NDArrayFactory::create('c', {oC}, {1,2,3}); - auto gradO = NDArrayFactory::create('c', {bS, oH, oW, oC}); - - auto expGradI = NDArrayFactory::create('c', {bS, iH, iW, iC},{ 0.226f, 0.343f, 0.46f, 0.577f, 1.172f, 1.46f, 1.748f, 2.036f, 1.892f, 2.288f, 2.684f, 3.08f, 1.284f, 1.581f, 1.878f, 2.175f, 4.458f, 5.133f, 5.808f, 6.483f, 6.186f, 7.023f, 7.86f, 8.697f, - 3.39f, 3.93f, 4.47f, 5.01f, 9.642f, 10.803f, 11.964f, 13.125f,11.37f, 12.693f, 14.016f, 15.339f, 5.266f, 5.707f, 6.148f, 6.589f,12.98f, 13.916f, 14.852f, 15.788f,14.564f, 15.608f, 16.652f, 17.696f, - 3.25f, 4.015f, 4.78f, 5.545f, 9.812f, 11.396f, 12.98f, 14.564f,10.532f, 12.224f, 13.916f, 15.608f, 9.708f, 10.977f, 12.246f, 13.515f,25.194f, 27.813f, 30.432f, 33.051f,26.922f, 29.703f, 32.484f, 35.265f, - 11.814f, 13.326f, 14.838f, 16.35f,30.378f, 33.483f, 36.588f, 39.693f,32.106f, 35.373f, 38.64f, 41.907f,13.474f, 14.563f, 15.652f, 16.741f,31.988f, 34.22f, 36.452f, 38.684f,33.572f, 35.912f, 38.252f, 40.592f}); - - auto expGradW = NDArrayFactory::create('c', {kH, kW, iC, oC},{14.4f, 14.76f, 15.12f,14.4f, 14.76f, 15.12f,14.4f, 14.76f, 15.12f,14.4f, 14.76f, 15.12f, 9.24f, 9.48f, 9.72f, 9.24f, 9.48f, 9.72f, 9.24f, 9.48f, 9.72f, 9.24f, 9.48f, 9.72f, - 17.04f, 17.52f, 18.f,17.04f, 17.52f, 18.f, 17.04f, 17.52f, 18.f, 17.04f, 17.52f, 18.f,10.88f, 11.2f, 11.52f,10.88f, 11.2f, 11.52f,10.88f, 11.2f, 11.52f,10.88f, 11.2f, 11.52f, - 11.16f, 11.52f, 11.88f,11.16f, 11.52f, 11.88f,11.16f, 11.52f, 11.88f,11.16f, 11.52f, 11.88f, 7.08f, 7.32f, 7.56f, 7.08f, 7.32f, 7.56f, 7.08f, 7.32f, 7.56f, 7.08f, 7.32f, 7.56f}); - // auto expGradB('c', {oC},{}); - - input = 2.; - weights.linspace(0.1, 0.1); - gradO.linspace(0.01, 0.01); - - sd::ops::conv2d_bp op; - auto results = op.evaluate({&input, &weights, &bias, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); - auto gradI = results.at(0); - auto gradW = results.at(1); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expGradI.isSameShape(gradI)); - ASSERT_TRUE(expGradI.equalsTo(gradI)); - - ASSERT_TRUE(expGradW.isSameShape(gradW)); - ASSERT_TRUE(expGradW.equalsTo(gradW)); - -} - -////////////////////////////////////////////////////////////////////// -TYPED_TEST(TypedConvolutionTests1, conv2d_bp_test2) { - - int bS=2, iH=4,iW=3, iC=4,oC=3, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oH=2,oW=2; - int paddingMode = 0; // 1-SAME, 0-VALID; - int dataFormat = 1; // 1-NHWC, 0-NCHW - - auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); - auto weights = NDArrayFactory::create('c', {kH, kW, iC, oC}); - auto bias = NDArrayFactory::create('c', {oC}, {1,2,3}); - auto gradO = NDArrayFactory::create('c', {bS, oH, oW, oC}); - - auto expGradI = NDArrayFactory::create('c', {bS, iH, iW, iC},{ 0.014f, 0.032f, 0.05f, 0.068f,0.118f,0.181f, 0.244f, 0.307f,0.212f,0.257f, 0.302f, 0.347f,0.208f,0.298f, 0.388f, 0.478f,1.028f,1.262f, 1.496f, 1.73f,1.036f,1.18f, 1.324f, 1.468f, - 0.928f,1.018f, 1.108f, 1.198f,2.9f,3.134f, 3.368f, 3.602f,2.188f,2.332f, 2.476f, 2.62f, 1.202f,1.274f, 1.346f, 1.418f,3.142f,3.313f, 3.484f, 3.655f,2.048f,2.147f, 2.246f, 2.345f, - 0.086f,0.212f, 0.338f, 0.464f,0.694f,0.973f, 1.252f, 1.531f,0.716f,0.869f, 1.022f, 1.175f,1.216f,1.522f, 1.828f, 2.134f,3.908f,4.574f, 5.24f, 5.906f,2.908f,3.268f, 3.628f, 3.988f, - 3.664f,3.97f, 4.276f, 4.582f,9.236f,9.902f,10.568f,11.234f,5.788f,6.148f, 6.508f, 6.868f,3.002f,3.182f, 3.362f, 3.542f,7.174f,7.561f, 7.948f, 8.335f,4.28f,4.487f, 4.694f, 4.901f}); - - auto expGradW = NDArrayFactory::create('c', {kH, kW, iC, oC},{1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f, - 1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f, - 1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f}); - // auto expGradB('c', {oC},{}); - - input = 2.; - weights.linspace(0.1, 0.1); - gradO.linspace(0.01, 0.01); - - sd::ops::conv2d_bp op; - auto results = op.evaluate({&input, &weights, &bias, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); - auto gradI = results.at(0); - auto gradW = results.at(1); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expGradI.isSameShape(gradI)); - ASSERT_TRUE(expGradI.equalsTo(gradI)); - - ASSERT_TRUE(expGradW.isSameShape(gradW)); - ASSERT_TRUE(expGradW.equalsTo(gradW)); -} - -////////////////////////////////////////////////////////////////////// -TYPED_TEST(TypedConvolutionTests1, conv2d_bp_test3) { - int bS=2, iH=4,iW=3, iC=4,oC=3, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oH=2,oW=2; - int paddingMode = 0; // 1-SAME, 0-VALID; - int dataFormat = 0; // 1-NHWC, 0-NCHW - - auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); - auto weights = NDArrayFactory::create('c', {oC, iC, kH, kW}); - auto bias = NDArrayFactory::create('c', {oC}, {1,2,3}); - auto gradO = NDArrayFactory::create('c', {bS, oC, oH, oW}); - - auto expGradI = NDArrayFactory::create('c', {bS, iC, iH, iW},{ 0.567f, 1.224f, 0.66f, 1.314f, 2.82f, 1.512f, 1.386f, 2.976f, 1.596f, 0.801f, 1.71f, 0.912f, 0.657f, 1.422f, 0.768f, 1.53f, 3.288f, 1.764f, 1.602f, 3.444f, 1.848f, 0.927f, 1.98f, 1.056f, - 0.747f, 1.62f, 0.876f, 1.746f, 3.756f, 2.016f, 1.818f, 3.912f, 2.1f, 1.053f, 2.25f, 1.2f, 0.837f, 1.818f, 0.984f, 1.962f, 4.224f, 2.268f, 2.034f, 4.38f, 2.352f, 1.179f, 2.52f, 1.344f, - 1.467f, 3.06f, 1.596f, 3.186f, 6.636f, 3.456f, 3.402f, 7.08f, 3.684f, 1.845f, 3.834f, 1.992f, 1.773f, 3.69f, 1.92f, 3.834f, 7.968f, 4.14f, 4.05f, 8.412f, 4.368f, 2.187f, 4.536f, 2.352f, - 2.079f, 4.32f, 2.244f, 4.482f, 9.3f, 4.824f, 4.698f, 9.744f, 5.052f, 2.529f, 5.238f, 2.712f, 2.385f, 4.95f, 2.568f, 5.13f, 10.632f, 5.508f, 5.346f, 11.076f, 5.736f, 2.871f, 5.94f, 3.072f}); - - auto expGradW = NDArrayFactory::create('c', {oC, iC, kH, kW},{1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, - 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, - 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, - 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f}); - auto expGradB = NDArrayFactory::create('c', {oC},{0.68f, 1.f, 1.32f}); - - input = 2.; - weights.linspace(0.1, 0.1); - gradO.linspace(0.01, 0.01); - weights.permutei({2,3,1,0}); - expGradW.permutei({2,3,1,0}); - - sd::ops::conv2d_bp op; - auto results = op.evaluate({&input, &weights, &bias, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); - auto gradI = results.at(0); - auto gradW = results.at(1); - auto gradB = results.at(2); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expGradI.isSameShape(gradI)); - ASSERT_TRUE(expGradI.equalsTo(gradI)); - - ASSERT_TRUE(expGradW.isSameShape(gradW)); - ASSERT_TRUE(expGradW.equalsTo(gradW)); - - ASSERT_TRUE(expGradB.isSameShape(gradB)); - ASSERT_TRUE(expGradB.equalsTo(gradB)); -} - -////////////////////////////////////////////////////////////////////// -TEST_F(ConvolutionTests1, conv2d_bp_4) { - - int bS=1, iH=7,iW=1, iC=2,oC=3, kH=2,kW=1, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oH=7,oW=1; - int paddingMode = 1; // 1-SAME, 0-VALID; - int dataFormat = 0; // 1-NHWC, 0-NCHW - - NDArray input('c', {bS, iC, iH, iW}, sd::DataType::FLOAT32); - NDArray weights('c', {kH, kW, iC, oC}, sd::DataType::FLOAT32); - NDArray bias('c', {oC}, {1,2,3}, sd::DataType::FLOAT32); - NDArray gradO('c', {bS, oC, oH, oW}, sd::DataType::FLOAT32); - - NDArray gradI('c', {bS, iC, iH, iW}, sd::DataType::FLOAT32); - NDArray gradW('c', {kH, kW, iC, oC}, sd::DataType::FLOAT32); - NDArray gradB('c', {oC}, sd::DataType::FLOAT32); - - - input = 2.; - weights.linspace(0.1, 0.1); - gradO.linspace(0.01, 0.01); - - sd::ops::conv2d_bp op; - auto status = op.execute({&input, &weights, &bias, &gradO}, {&gradI, &gradW, &gradB}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}, {}); - - ASSERT_EQ(Status::OK(), status); -} - -////////////////////////////////////////////////////////////////////// -TEST_F(ConvolutionTests1, conv2d_bp_5) { - - int bS=2, iH=4,iW=3, iC=4,oC=3, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oH=2,oW=2; - int paddingMode = 0; // 1-SAME, 0-VALID; - int dataFormat = 0; // 1-NHWC, 0-NCHW - int wFormat = 1; // 0-[kH, kW, iC, oC], 1-[oC, iC, kH, kW], 2-[oC, kH, kW, iC] - - NDArray input('c', {bS, iC, iH, iW}, sd::DataType::FLOAT32); - NDArray weights('c', {oC, iC, kH, kW}, {3.6, 2.4, 1.2, 0.0, -1.2, -2.4, 3.3, 2.1, 0.9, -0.3, -1.5, -2.7, 3.0, 1.8, 0.6, -0.6, -1.8, -3.0, 2.7, 1.5, 0.3, -0.9, -2.1, -3.3, 3.5, 2.3, 1.1, -0.1, -1.3, -2.5, 3.2, 2.0, 0.8, -0.4, -1.6, -2.8, 2.9, 1.7, 0.5, -0.7, -1.9, -3.1, 2.6, 1.4, 0.2, -1.0, -2.2, -3.4, 3.4, 2.2, 1.0, -0.2, -1.4, -2.6, 3.1, 1.9, 0.7, -0.5, -1.7, -2.9, 2.8, 1.6, 0.4, -0.8, -2.0, -3.2, 2.5, 1.3, 0.1, -1.1, -2.3, -3.5}, sd::DataType::FLOAT32); - NDArray bias('c', {oC}, {1,-0.5, 0.1}, sd::DataType::FLOAT32); - NDArray gradO('c', {bS, oC, oH, oW}, sd::DataType::FLOAT32); - - NDArray expGradI('c', {bS, iC, iH, iW},{0.517, 0.959, 0.406, 0.884, 1.474, 0.518, 0.020, -0.398, -0.490, -0.281, -0.853, -0.608, 0.472, 0.860, 0.352, 0.776, 1.240, - 0.392, -0.088, -0.632, -0.616, -0.344, -0.988, -0.680, 0.427, 0.761, 0.298, 0.668, 1.006, 0.266, -0.196, -0.866, -0.742, -0.407, -1.123, -0.752, 0.382, 0.662, - 0.244, 0.560, 0.772, 0.140, -0.304, -1.100, -0.868, -0.470, -1.258, -0.824, 1.777, 3.047, 1.234, 2.540, 3.922, 1.310, -0.052, -1.406, -1.426, -0.749, -2.221, - -1.508, 1.624, 2.732, 1.072, 2.216, 3.256, 0.968, -0.376, -2.072, -1.768, -0.920, -2.572, -1.688, 1.471, 2.417, 0.910, 1.892, 2.590, 0.626, -0.700, -2.738, -2.110, - -1.091, -2.923, -1.868, 1.318, 2.102, 0.748, 1.568, 1.924, 0.284, -1.024, -3.404, -2.452, -1.262, -3.274, -2.048}, sd::DataType::FLOAT32); - - NDArray expGradW('c', {oC, iC, kH, kW},{-3.3, -2.62, -1.26, -0.58, 0.78, 1.46, 4.86, 5.54, 6.9, 7.58, 8.940001, 9.619999, 13.02, 13.700001, 15.06, 15.74, 17.1, - 17.780001, 21.18, 21.860001, 23.219999, 23.900002, 25.259998, 25.940001, -10.340001, -9.34, -7.339999, -6.34, -4.339999, -3.339999, 1.66, 2.66, 4.660001, - 5.660001, 7.66, 8.66, 13.66, 14.660001, 16.66, 17.66, 19.66, 20.66, 25.66, 26.66, 28.66, 29.66, 31.66, 32.66, -17.380001, -16.059999, -13.420003, -12.099999, - -9.46, -8.139999, -1.540001, -0.219999, 2.419999, 3.739999, 6.379999, 7.7, 14.299999, 15.62, 18.26, 19.58, 22.219999, 23.539999, 30.139999, 31.459999, 34.099998, - 35.419998, 38.060001, 39.380001}, sd::DataType::FLOAT32); - - NDArray expGradB('c', {oC}, {0.68, 1., 1.32}, sd::DataType::FLOAT32); - - input.linspace(-48, 1); - // weights.linspace(3.6, -0.1); - gradO.linspace(0.01, 0.01); - - sd::ops::conv2d_bp op; - auto results = op.evaluate({&input, &weights, &bias, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat, wFormat}); - auto gradI = results.at(0); - auto gradW = results.at(1); - auto gradB = results.at(2); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expGradI.isSameShape(gradI)); - ASSERT_TRUE(expGradI.equalsTo(gradI)); - - ASSERT_TRUE(expGradW.isSameShape(gradW)); - ASSERT_TRUE(expGradW.equalsTo(gradW)); - - ASSERT_TRUE(expGradB.isSameShape(gradB)); - ASSERT_TRUE(expGradB.equalsTo(gradB)); -} - -////////////////////////////////////////////////////////////////////// -TEST_F(ConvolutionTests1, conv2d_bp_6) { - - int bS=2, iH=4,iW=3, iC=4,oC=3, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oH=4,oW=3; - int paddingMode = 1; // 1-SAME, 0-VALID; - int dataFormat = 1; // 1-NHWC, 0-NCHW - int wFormat = 2; // 0-[kH, kW, iC, oC], 1-[oC, iC, kH, kW], 2-[oC, kH, kW, iC] - - NDArray input('c', {bS, iH, iW, iC}, sd::DataType::FLOAT32); - NDArray weights('c', {oC, kH, kW, iC}, {3.6, 0.0, 3.3, -0.3, 3.0, -0.6, 2.7, -0.9, 3.5, -0.1, 3.2, -0.4, 2.9, -0.7, 2.6, -1.0, 3.4, -0.2, 3.1, -0.5, 2.8, -0.8, 2.5, -1.1, 2.4, -1.2, 2.1, -1.5, 1.8, -1.8, 1.5, -2.1, 2.3, -1.3, 2.0, -1.6, 1.7, -1.9, 1.4, -2.2, 2.2, -1.4, 1.9, -1.7, 1.6, -2.0, 1.3, -2.3, 1.2, -2.4, 0.9, -2.7, 0.6, -3.0, 0.3, -3.3, 1.1, -2.5, 0.8, -2.8, 0.5, -3.1, 0.2, -3.4, 1.0, -2.6, 0.7, -2.9, 0.4, -3.2, 0.1, -3.5}, sd::DataType::FLOAT32); - NDArray bias('c', {oC}, {1,-0.5, 0.1}, sd::DataType::FLOAT32); - NDArray gradO('c', {bS, oH, oW, oC}, sd::DataType::FLOAT32); - - NDArray expGradI('c', {bS, iH, iW, iC}, {0.882, -0.522, 0.765, -0.639, 1.953, -1.503, 1.665, -1.791, 2.691, -2.061, 2.295, -2.457, 2.259, -1.305, 1.962, -1.602, 4.545, - -3.555, 3.870, -4.230, 5.625, -4.419, 4.788, -5.256001, 4.122, -2.358, 3.582, -2.898, 7.785, -6.147, 6.624, -7.308, 8.865, -7.011, 7.541999, -8.334, 3.273, -2.019, - 2.832, -2.460, 6.069, -5.163, 5.133, -6.099, 6.771, -5.757, 5.727, -6.801, 5.958, -3.222, 5.193, -3.987, 10.809, -8.198999, 9.225, -9.783, 11.547, -8.757, 9.855, - -10.448999, 9.711, -5.517, 8.441999, -6.786, 17.505001, -13.922999, 14.886, -16.542, 18.585001, -14.787001, 15.804001, -17.568001, 11.574, -6.570, 10.062, -8.082, - 20.745001, -16.514999, 17.639999, -19.619999, 21.825001, -17.379002, 18.558001, -20.646, 8.133, -4.935, 7.044, -6.024, 14.492998, -12.291, 12.261, -14.523001, 15.195001, -12.885, 12.855, -15.225}, sd::DataType::FLOAT32); - - NDArray expGradW('c', {oC, kH, kW, iC},{34.559998, 41.760010, 48.959999, 56.160004, 33.119999, 37.739998, 42.360001, 46.979996, 120.960007, 129.480011, 138.0, 146.519989, - 91.200005, 96.639999, 102.079994, 107.520004, 114.479996, 120.059998, 125.639999, 131.220001, 82.080002, 85.620003, 89.160004, 92.699997, 33.120003, 40.499996, - 47.879993, 55.260002, 32.399998, 37.139996, 41.880001, 46.620003, 120.479988, 129.240005, 137.999985, 146.759995, 91.199997, 96.799995, 102.399994, 108.0, 115.199989, - 120.959999, 126.720001, 132.479996, 82.799995, 86.460007, 90.119995, 93.779999, 31.679998, 39.239994, 46.800003, 54.359997, 31.680000, 36.540001, 41.400002, 46.260002, - 120.0, 129.0, 138.0, 147.0, 91.200005, 96.960007, 102.720001, 108.480003, 115.919998, 121.860001, 127.799988, 133.740005, 83.520004, 87.300003, 91.080002, 94.860001}, sd::DataType::FLOAT32); - - NDArray expGradB('c', {oC}, {8.520, 8.760, 9.}, sd::DataType::FLOAT32); - - input.linspace(-48, 1); - gradO.linspace(0.01, 0.01); - - sd::ops::conv2d_bp op; - auto results = op.evaluate({&input, &weights, &bias, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat, wFormat}); - auto gradI = results.at(0); - auto gradW = results.at(1); - auto gradB = results.at(2); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expGradI.isSameShape(gradI)); - ASSERT_TRUE(expGradI.equalsTo(gradI)); - - ASSERT_TRUE(expGradW.isSameShape(gradW)); - ASSERT_TRUE(expGradW.equalsTo(gradW)); - - ASSERT_TRUE(expGradB.isSameShape(gradB)); - ASSERT_TRUE(expGradB.equalsTo(gradB)); -} - -//////////////////////////////////////////////////////////////////// -TYPED_TEST(TypedConvolutionTests1, conv3d_bp_test1) { - - int bS=2, iD=3,iH=4,iW=3, iC=4,oC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; - int oD=3,oH=4,oW=3; - int paddingMode = 1; // 1-SAME, 0-VALID; - int dataFormat = 1; // 1-NDHWC, 0-NCDHW - - auto input = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); - auto weights = NDArrayFactory::create('c', {kD, kH, kW, iC, oC}); - auto bias = NDArrayFactory::create('c', {oC}, {1,2,3}); - auto gradO = NDArrayFactory::create('c', {bS, oD, oH, oW, oC}); - - auto expGradI = NDArrayFactory::create('c', {bS, iD, iH, iW, iC},{0.226f, 0.343f, 0.46f, 0.577f, 1.172f, 1.46f, 1.748f, 2.036f, 1.892f, 2.288f, 2.684f, 3.08f, 1.284f, 1.581f, 1.878f, 2.175f, 4.458f, 5.133f, 5.808f, 6.483f, 6.186f, 7.023f, 7.86f, 8.697f, 3.39f, 3.93f, 4.47f, 5.01f, 9.642f, 10.803f, 11.964f, 13.125f, 11.37f, 12.693f, 14.016f, 15.339f, - 5.266f, 5.707f, 6.148f, 6.589f, 12.98f, 13.916f, 14.852f, 15.788f, 14.564f, 15.608f, 16.652f, 17.696f, 6.284f, 7.166f, 8.048f, 8.93f, 17.896f, 19.768f, 21.64f, 23.512f, 21.928f, 24.016f, 26.104f, 28.192f, 18.12f, 19.686f, 21.252f, 22.818f, 45.852f, 49.146f, 52.44f, 55.734f, 53.196f, 56.814f, 60.432f, 64.05f, - 28.164f, 30.216f, 32.268f, 34.32f, 67.884f, 72.15f, 76.416f, 80.682f, 75.228f, 79.818f, 84.408f, 88.998f, 29.324f, 30.854f, 32.384f, 33.914f, 67.432f, 70.6f, 73.768f, 76.936f, 73.192f, 76.576f, 79.96f, 83.344f, 27.884f, 30.062f, 32.24f, 34.418f, 66.28f, 70.744f, 75.208f, 79.672f, 70.312f, 74.992f, 79.672f, 84.352f, - 58.296f, 61.806f, 65.316f, 68.826f, 133.98f, 141.162f, 148.344f, 155.526f, 141.324f, 148.83f, 156.336f, 163.842f, 68.34f, 72.336f, 76.332f, 80.328f, 156.012f, 164.166f, 172.32f, 180.474f, 163.356f, 171.834f, 180.312f, 188.79f, 61.292f, 64.118f, 66.944f, 69.77f, 136.552f, 142.312f, 148.072f, 153.832f, 142.312f, 148.288f, 154.264f, 160.24f, - 9.298f, 11.359f, 13.42f, 15.481f, 27.092f, 31.268f, 35.444f, 39.62f, 27.812f, 32.096f, 36.38f, 40.664f, 26.556f, 29.769f, 32.982f, 36.195f, 66.666f, 73.173f, 79.68f, 86.187f, 68.394f, 75.063f, 81.732f, 88.401f, 28.662f, 32.118f, 35.574f, 39.03f, 71.85f, 78.843f, 85.836f, 92.829f, 73.578f, 80.733f, 87.888f, 95.043f, - 29.89f, 32.275f, 34.66f, 37.045f, 70.004f, 74.828f, 79.652f, 84.476f, 71.588f, 76.52f, 81.452f, 86.384f, 71.084f, 75.854f, 80.624f, 85.394f, 163.048f, 172.696f, 182.344f, 191.992f, 167.08f, 176.944f, 186.808f, 196.672f, 138.648f, 146.046f, 153.444f, 160.842f, 310.236f, 325.194f, 340.152f, 355.11f, 317.58f, 332.862f, 348.144f, 363.426f, - 148.692f, 156.576f, 164.46f, 172.344f, 332.268f, 348.198f, 364.128f, 380.058f, 339.612f, 355.866f, 372.12f, 388.374f, 125.228f, 130.646f, 136.064f, 141.482f, 274.792f, 285.736f, 296.68f, 307.624f, 280.552f, 291.712f, 302.872f, 314.032f, 92.684f, 98.75f, 104.816f, 110.882f, 211.432f, 223.672f, 235.912f, 248.152f, 215.464f, 227.92f, 240.376f, 252.832f, - 178.824f, 188.166f, 197.508f, 206.85f, 398.364f, 417.21f, 436.056f, 454.902f, 405.708f, 424.878f, 444.048f, 463.218f, 188.868f, 198.696f, 208.524f, 218.352f, 420.396f, 440.214f, 460.032f, 479.85f, 427.74f, 447.882f, 468.024f, 488.166f, 157.196f, 163.91f, 170.624f, 177.338f, 343.912f, 357.448f, 370.984f, 384.52f, 349.672f, 363.424f, 377.176f, 390.928f}); - - auto expGradW = NDArrayFactory::create('c', {kD, kH, kW, iC, oC},{120.96f, 122.04f, 123.12f, 120.96f, 122.04f, 123.12f, 120.96f, 122.04f, 123.12f, 120.96f, 122.04f, 123.12f, 79.56f, 80.28f, 81.f, 79.56f, 80.28f, 81.f, 79.56f, 80.28f, 81.f, 79.56f, 80.28f, 81.f, - 154.8f, 156.24f, 157.68f, 154.8f, 156.24f, 157.68f, 154.8f, 156.24f, 157.68f, 154.8f, 156.24f, 157.68f, 101.76f, 102.72f, 103.68f, 101.76f, 102.72f, 103.68f, 101.76f, 102.72f, 103.68f, 101.76f, 102.72f, 103.68f, - 111.24f, 112.32f, 113.4f, 111.24f, 112.32f, 113.4f, 111.24f, 112.32f, 113.4f, 111.24f, 112.32f, 113.4f, 73.08f, 73.8f, 74.52f, 73.08f, 73.8f, 74.52f, 73.08f, 73.8f, 74.52f, 73.08f, 73.8f, 74.52f, - 67.68f, 68.4f, 69.12f, 67.68f, 68.4f, 69.12f, 67.68f, 68.4f, 69.12f, 67.68f, 68.4f, 69.12f, 44.4f, 44.88f, 45.36f, 44.4f, 44.88f, 45.36f, 44.4f, 44.88f, 45.36f, 44.4f, 44.88f, 45.36f, - 85.92f, 86.88f, 87.84f, 85.92f, 86.88f, 87.84f, 85.92f, 86.88f, 87.84f, 85.92f, 86.88f, 87.84f, 56.32f, 56.96f, 57.6f, 56.32f, 56.96f, 57.6f, 56.32f, 56.96f, 57.6f, 56.32f, 56.96f, 57.6f, - 61.2f, 61.92f, 62.64f, 61.2f, 61.92f, 62.64f, 61.2f, 61.92f, 62.64f, 61.2f, 61.92f, 62.64f, 40.08f, 40.56f, 41.04f, 40.08f, 40.56f, 41.04f, 40.08f, 40.56f, 41.04f, 40.08f, 40.56f, 41.04f}); - // auto expGradB('c', {oC},{}); - - input = 2.; - weights.linspace(0.1, 0.1); - gradO.linspace(0.01, 0.01); - - sd::ops::conv3dnew_bp op; - auto results = op.evaluate({&input, &weights, &bias, &gradO}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}); - auto gradI = results.at(0); - auto gradW = results.at(1); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expGradI.isSameShape(gradI)); - ASSERT_TRUE(expGradI.equalsTo(gradI)); - - ASSERT_TRUE(expGradW.isSameShape(gradW)); - ASSERT_TRUE(expGradW.equalsTo(gradW)); - - -} - - -//////////////////////////////////////////////////////////////////// -TYPED_TEST(TypedConvolutionTests1, conv3d_bp_test2) { - - int bS=2, iD=3,iH=4,iW=3, iC=4,oC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; - int oD=2,oH=2,oW=2; - int paddingMode = 0; // 1-SAME, 0-VALID; - int dataFormat = 1; // 1-NDHWC, 0-NCDHW - - auto input = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); - auto weights = NDArrayFactory::create('c', {kD, kH, kW, iC, oC}); - auto bias = NDArrayFactory::create('c', {oC}, {1,2,3}); - auto gradO = NDArrayFactory::create('c', {bS, oD, oH, oW, oC}); - - auto expGradI = NDArrayFactory::create('c', {bS, iD, iH, iW, iC},{ 0.014f, 0.032f, 0.05f, 0.068f, 0.118f, 0.181f, 0.244f, 0.307f, 0.212f, 0.257f, 0.302f, 0.347f, 0.208f, 0.298f, 0.388f, 0.478f, 1.028f, 1.262f, 1.496f, 1.73f, 1.036f, 1.18f, 1.324f, 1.468f, 0.928f, 1.018f, 1.108f, 1.198f, 2.9f, 3.134f, 3.368f, 3.602f, 2.188f, 2.332f, 2.476f, 2.62f, - 1.202f, 1.274f, 1.346f, 1.418f, 3.142f, 3.313f, 3.484f, 3.655f, 2.048f, 2.147f, 2.246f, 2.345f, 0.532f, 0.676f, 0.82f, 0.964f, 2.324f, 2.666f, 3.008f, 3.35f, 2.008f, 2.206f, 2.404f, 2.602f, 3.584f, 3.98f, 4.376f, 4.772f, 10.552f, 11.452f, 12.352f, 13.252f, 7.4f, 7.904f, 8.408f, 8.912f, - 6.752f, 7.148f, 7.544f, 7.94f, 17.752f, 18.652f, 19.552f, 20.452f, 11.432f, 11.936f, 12.44f, 12.944f, 5.932f, 6.184f, 6.436f, 6.688f, 14.42f, 14.978f, 15.536f, 16.094f, 8.704f, 9.01f, 9.316f, 9.622f, 3.11f, 3.236f, 3.362f, 3.488f, 7.39f, 7.669f, 7.948f, 8.227f, 4.388f, 4.541f, 4.694f, 4.847f, - 8.56f, 8.866f, 9.172f, 9.478f, 19.892f, 20.558f, 21.224f, 21.89f, 11.548f, 11.908f, 12.268f, 12.628f, 11.008f, 11.314f, 11.62f, 11.926f, 25.22f, 25.886f, 26.552f, 27.218f, 14.428f, 14.788f, 15.148f, 15.508f, 7.322f, 7.502f, 7.682f, 7.862f, 16.462f, 16.849f, 17.236f, 17.623f, 9.248f, 9.455f, 9.662f, 9.869f, - 0.158f, 0.392f, 0.626f, 0.86f, 1.27f, 1.765f, 2.26f, 2.755f, 1.22f, 1.481f, 1.742f, 2.003f, 2.224f, 2.746f, 3.268f, 3.79f, 6.788f, 7.886f, 8.984f, 10.082f, 4.78f, 5.356f, 5.932f, 6.508f, 6.4f, 6.922f, 7.444f, 7.966f, 15.572f, 16.67f, 17.768f, 18.866f, 9.388f, 9.964f, 10.54f, 11.116f, - 4.802f, 5.09f, 5.378f, 5.666f, 11.206f, 11.809f, 12.412f, 13.015f, 6.512f, 6.827f, 7.142f, 7.457f, 6.004f, 6.58f, 7.156f, 7.732f, 14.996f, 16.202f, 17.408f, 18.614f, 9.208f, 9.838f, 10.468f, 11.098f, 17.984f, 19.244f, 20.504f, 21.764f, 42.808f, 45.436f, 48.064f, 50.692f, 25.256f, 26.624f, 27.992f, 29.36f, - 28.064f, 29.324f, 30.584f, 31.844f, 63.832f, 66.46f, 69.088f, 71.716f, 36.2f, 37.568f, 38.936f, 40.304f, 18.316f, 19.f, 19.684f, 20.368f, 40.916f, 42.338f, 43.76f, 45.182f, 22.816f, 23.554f, 24.292f, 25.03f, 8.438f, 8.78f, 9.122f, 9.464f, 18.91f, 19.621f, 20.332f, 21.043f, 10.58f, 10.949f, 11.318f, 11.687f, - 20.944f, 21.682f, 22.42f, 23.158f, 46.388f, 47.918f, 49.448f, 50.978f, 25.66f, 26.452f, 27.244f, 28.036f, 26.848f, 27.586f, 28.324f, 29.062f, 58.628f, 60.158f, 61.688f, 63.218f, 31.996f, 32.788f, 33.58f, 34.372f, 16.106f, 16.502f, 16.898f, 17.294f, 34.894f, 35.713f, 36.532f, 37.351f, 18.896f, 19.319f, 19.742f, 20.165f}); - - auto expGradW = NDArrayFactory::create('c', {kD, kH, kW, iC, oC},{7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, - 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, - 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, - 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f}); - // auto expGradB('c', {oC},{}); - - input = 2.; - weights.linspace(0.1, 0.1); - gradO.linspace(0.01, 0.01); - - sd::ops::conv3dnew_bp op; - auto results = op.evaluate({&input, &weights, &bias, &gradO}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}); - auto gradI = results.at(0); - auto gradW = results.at(1); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expGradI.isSameShape(gradI)); - ASSERT_TRUE(expGradI.equalsTo(gradI)); - - ASSERT_TRUE(expGradW.isSameShape(gradW)); - ASSERT_TRUE(expGradW.equalsTo(gradW)); - + TypeParam _expWGradB[] = {9312.0, 12580.0, 9528.0, 13168.0, 17712.0, + 13360.0, 9960.0, 13348.0, 10032.0, 13344.0, + 18148.0, 13848.0, 19312.0, 26160.0, 19888.0, + 15144.0, 20452.0, 15504.0}; + Nd4jLong _expWGradS[] = { + 4, 2, 1, 3, 3, + 9, 9, 3, 1, typeid(TypeParam) == typeid(float) ? 8192 : 16384, + 1, 99}; + NDArray expWGrad(_expWGradB, _expWGradS); + expWGrad.permutei({2, 3, 1, 0}); -} + TypeParam _expBGradB[] = {784.0, 1296.0}; + Nd4jLong _expBGradS[] = { + 2, 2, 1, 1, 1, typeid(TypeParam) == typeid(float) ? 8192 : 16384, 1, 99}; -//////////////////////////////////////////////////////////////////// -TYPED_TEST(TypedConvolutionTests1, conv3d_bp_test3) { + NDArray expBGrad(_expBGradB, _expBGradS); - int bS=2, iD=3,iH=4,iW=3, iC=4,oC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; - int oD=2,oH=2,oW=2; - int paddingMode = 0; // 1-SAME, 0-VALID; - int dataFormat = 0; // 1-NDHWC, 0-NCDHW + auto input = NDArrayFactory::create('c', {2, 1, 4, 4}); + auto weights = NDArrayFactory::create('c', {2, 1, 3, 3}); + auto bias = NDArrayFactory::create('c', {2, 1}); + auto epsilonNext = NDArrayFactory::create('c', {2, 2, 4, 4}); - auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); - auto weights = NDArrayFactory::create('c', {oC, iC, kD, kH, kW}); - auto bias = NDArrayFactory::create('c', {oC}, {1,2,3}); - auto gradO = NDArrayFactory::create('c', {bS, oC, oD, oH, oW}); + TypeParam _expEpsB[] = { + 952.0, 1540.0, 1636.0, 1180.0, 1791.0, 2886.0, 3057.0, 2193.0, + 2223.0, 3570.0, 3741.0, 2673.0, 1900.0, 3028.0, 3160.0, 2240.0, + 2872.0, 4612.0, 4708.0, 3356.0, 5247.0, 8358.0, 8529.0, 6033.0, + 5679.0, 9042.0, 9213.0, 6513.0, 4588.0, 7252.0, 7384.0, 5184.0}; + NDArray expEps(_expEpsB, input.shapeInfo()); - auto expGradI = NDArrayFactory::create('c', {bS, iC, iD, iH, iW},{2.091f, 4.356f, 2.268f, 4.53f, 9.42f, 4.896f, 4.65f, 9.672f, 5.028f, 2.517f, 5.226f, 2.712f, 4.932f, 10.242f, 5.316f, 10.62f, 22.02f, 11.412f, 10.908f, 22.62f, 11.724f, 5.868f, 12.15f, 6.288f, 2.913f, 6.03f, 3.12f, 6.234f, 12.888f, 6.66f, 6.402f, 13.236f, 6.84f, 3.423f, 7.068f, 3.648f, - 2.415f, 5.04f, 2.628f, 5.25f, 10.932f, 5.688f, 5.37f, 11.184f, 5.82f, 2.913f, 6.054f, 3.144f, 5.724f, 11.898f, 6.18f, 12.348f, 25.62f, 13.284f, 12.636f, 26.22f, 13.596f, 6.804f, 14.094f, 7.296f, 3.381f, 7.002f, 3.624f, 7.242f, 14.976f, 7.74f, 7.41f, 15.324f, 7.92f, 3.963f, 8.184f, 4.224f, - 2.739f, 5.724f, 2.988f, 5.97f, 12.444f, 6.48f, 6.09f, 12.696f, 6.612f, 3.309f, 6.882f, 3.576f, 6.516f, 13.554f, 7.044f, 14.076f, 29.22f, 15.156f, 14.364f, 29.82f, 15.468f, 7.74f, 16.038f, 8.304f, 3.849f, 7.974f, 4.128f, 8.25f, 17.064f, 8.82f, 8.418f, 17.412f, 9.f, 4.503f, 9.3f, 4.8f, - 3.063f, 6.408f, 3.348f, 6.69f, 13.956f, 7.272f, 6.81f, 14.208f, 7.404f, 3.705f, 7.71f, 4.008f, 7.308f, 15.21f, 7.908f, 15.804f, 32.82f, 17.028f, 16.092f, 33.42f, 17.34f, 8.676f, 17.982f, 9.312f, 4.317f, 8.946f, 4.632f, 9.258f, 19.152f, 9.9f, 9.426f, 19.5f, 10.08f, 5.043f, 10.416f, 5.376f, - 5.619f, 11.484f, 5.868f, 11.73f, 23.964f, 12.24f, 12.138f, 24.792f, 12.66f, 6.333f, 12.93f, 6.6f, 12.42f, 25.362f, 12.948f, 25.884f, 52.836f, 26.964f, 26.748f, 54.588f, 27.852f, 13.932f, 28.422f, 14.496f, 6.873f, 14.022f, 7.152f, 14.298f, 29.16f, 14.868f, 14.754f, 30.084f, 15.336f, 7.671f, 15.636f, 7.968f, - 6.807f, 13.896f, 7.092f, 14.178f, 28.932f, 14.76f, 14.586f, 29.76f, 15.18f, 7.593f, 15.486f, 7.896f, 14.94f, 30.474f, 15.54f, 31.068f, 63.348f, 32.292f, 31.932f, 65.1f, 33.18f, 16.596f, 33.822f, 17.232f, 8.205f, 16.722f, 8.52f, 17.034f, 34.704f, 17.676f, 17.49f, 35.628f, 18.144f, 9.075f, 18.48f, 9.408f, - 7.995f, 16.308f, 8.316f, 16.626f, 33.9f, 17.28f, 17.034f, 34.728f, 17.7f, 8.853f, 18.042f, 9.192f, 17.46f, 35.586f, 18.132f, 36.252f, 73.86f, 37.62f, 37.116f, 75.612f, 38.508f, 19.26f, 39.222f, 19.968f, 9.537f, 19.422f, 9.888f, 19.77f, 40.248f, 20.484f, 20.226f, 41.172f, 20.952f, 10.479f, 21.324f, 10.848f, - 9.183f, 18.72f, 9.54f, 19.074f, 38.868f, 19.8f, 19.482f, 39.696f, 20.22f, 10.113f, 20.598f, 10.488f, 19.98f, 40.698f, 20.724f, 41.436f, 84.372f, 42.948f, 42.3f, 86.124f, 43.836f, 21.924f, 44.622f, 22.704f, 10.869f, 22.122f, 11.256f, 22.506f, 45.792f, 23.292f, 22.962f, 46.716f, 23.76f, 11.883f, 24.168f, 12.288f}); + input.linspace(1); + weights.linspace(1); + epsilonNext.linspace(1); + weights.permutei({2, 3, 1, 0}); - auto expGradW = NDArrayFactory::create('c', {oC, iC, kD, kH, kW},{5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, - 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, - 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, - 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, - 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, - 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f}); + sd::ops::conv2d_bp op; - auto expGradB = NDArrayFactory::create('c', {oC},{2.64f, 3.92f, 5.2f}); + auto results = op.evaluate({&input, &weights, &bias, &epsilonNext}, {}, + {3, 3, 1, 1, 0, 0, 1, 1, 1}, {}); - input = 2.; - weights.linspace(0.1, 0.1); - gradO.linspace(0.01, 0.01); - weights.permutei({2, 3, 4, 1, 0}); - expGradW.permutei({2, 3, 4, 1, 0}); + ASSERT_TRUE(results.size() == 3); - sd::ops::conv3dnew_bp op; - auto results = op.evaluate({&input, &weights, &bias, &gradO}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}); - auto* gradI = results.at(0); - auto* gradW = results.at(1); - auto* gradB = results.at(2); + auto epsilon = results.at(0); + auto gradW = results.at(1); + auto gradB = results.at(2); - ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expWGrad.isSameShape(gradW)); - ASSERT_TRUE(expGradI.isSameShape(gradI)); - ASSERT_TRUE(expGradI.equalsTo(gradI)); + // expWGrad.printBuffer("Expctd buffer"); + // gradW->printBuffer("Result buffer"); + ASSERT_TRUE(expWGrad.equalsTo(gradW)); - ASSERT_TRUE(expGradW.isSameShape(gradW)); - ASSERT_TRUE(expGradW.equalsTo(gradW)); + ASSERT_TRUE(input.isSameShape(epsilon)); - ASSERT_TRUE(expGradB.isSameShape(gradB)); - ASSERT_TRUE(expGradB.equalsTo(gradB)); + // expEps.printBuffer("Expctd buffer"); + // epsilon->printBuffer("Result buffer"); + ASSERT_TRUE(expEps.equalsTo(epsilon)); + ASSERT_TRUE(expBGrad.isSameShape(gradB)); + ASSERT_TRUE(expBGrad.equalsTo(gradB)); } -////////////////////////////////////////////////////////////////////// -TEST_F(ConvolutionTests1, conv3d_bp_test4) { - - int bS=2, iD=4,iH=3,iW=3, iC=4,oC=3, kD=3,kH=2,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; - int oD=2,oH=2,oW=2; - int paddingMode = 0; // 1-SAME, 0-VALID; - int dataFormat = 0; // 1-NHWC, 0-NCHW - int wFormat = 1; // 0-[kD, kH, kW, iC, oC], 1-[oC, iC, kD, kH, kW], 2-[oC, kD, kH, kW, iC] +TYPED_TEST(TypedConvolutionTests1, conv2D_BP_NoBias_1) { + TypeParam _expWGradB[] = {9312.0, 12580.0, 9528.0, 13168.0, 17712.0, + 13360.0, 9960.0, 13348.0, 10032.0, 13344.0, + 18148.0, 13848.0, 19312.0, 26160.0, 19888.0, + 15144.0, 20452.0, 15504.0}; + Nd4jLong _expWGradS[] = { + 4, 2, 1, 3, 3, + 9, 9, 3, 1, typeid(TypeParam) == typeid(float) ? 8192 : 16384, + 1, 99}; + NDArray expWGrad(_expWGradB, _expWGradS); + expWGrad.permutei({2, 3, 1, 0}); + + auto input = NDArrayFactory::create('c', {2, 1, 4, 4}); + auto weights = NDArrayFactory::create('c', {2, 1, 3, 3}); + auto epsilonNext = NDArrayFactory::create('c', {2, 2, 4, 4}); + + TypeParam _expEpsB[] = { + 952.0, 1540.0, 1636.0, 1180.0, 1791.0, 2886.0, 3057.0, 2193.0, + 2223.0, 3570.0, 3741.0, 2673.0, 1900.0, 3028.0, 3160.0, 2240.0, + 2872.0, 4612.0, 4708.0, 3356.0, 5247.0, 8358.0, 8529.0, 6033.0, + 5679.0, 9042.0, 9213.0, 6513.0, 4588.0, 7252.0, 7384.0, 5184.0}; + NDArray expEps(_expEpsB, input.shapeInfo()); + + input.linspace(1); + weights.linspace(1); + epsilonNext.linspace(1); + weights.permutei({2, 3, 1, 0}); + + sd::ops::conv2d_bp op; + + auto results = op.evaluate({&input, &weights, &epsilonNext}, {}, + {3, 3, 1, 1, 0, 0, 1, 1, 1}, {}); + + ASSERT_TRUE(results.size() == 2); + + auto epsilon = results.at(0); + auto gradW = results.at(1); + + ASSERT_TRUE(expWGrad.isSameShape(gradW)); + + // expWGrad.printBuffer("Expctd buffer"); + // gradW->printBuffer("Result buffer"); + ASSERT_TRUE(expWGrad.equalsTo(gradW)); + + ASSERT_TRUE(input.isSameShape(epsilon)); + + // expEps.printBuffer("Expctd buffer"); + // epsilon->printBuffer("Result buffer"); + ASSERT_TRUE(expEps.equalsTo(epsilon)); +} - NDArray input('c', {bS, iC, iD, iH, iW}, sd::DataType::FLOAT32); - NDArray weights('c', {oC, iC, kD, kH, kW}, {7., 5.8, 4.6, 3.4, 2.2, 1., -0.2, -1.4, -2.6, -3.8, -5., -6.2, 6.7, 5.5, 4.3, 3.1, 1.9, 0.7, -0.5, -1.7, -2.9, -4.1, - -5.3, -6.5, 6.4, 5.2, 4., 2.8, 1.6, 0.4, -0.8, -2., -3.2, -4.4, -5.6, -6.8, 6.1, 4.9, 3.7, 2.5, 1.3, 0.1, -1.1, -2.3, -3.5, -4.7, -5.9, -7.1, 6.9, 5.7, 4.5, - 3.3, 2.1, 0.9, -0.3, -1.5, -2.7, -3.9, -5.1, -6.3, 6.6, 5.4, 4.2, 3., 1.8, 0.6, -0.6, -1.8, -3., -4.2, -5.4, -6.6, 6.3, 5.1, 3.9, 2.7, 1.5, 0.3, -0.9, -2.1, - -3.3, -4.5, -5.7, -6.9, 6., 4.8, 3.6, 2.4, 1.2, 0., -1.2, -2.4, -3.6, -4.8, -6., -7.2, 6.8, 5.6, 4.4, 3.2, 2., 0.8, -0.4, -1.6, -2.8, -4., -5.2, -6.4, 6.5, 5.3, 4.1, 2.9, 1.7, 0.5, -0.7, -1.9, -3.1, -4.3, -5.5, -6.7, 6.2, 5., 3.8, 2.6, 1.4, 0.2, -1., -2.2, -3.4, -4.6, -5.8, -7., 5.9, 4.7, 3.5, 2.3, 1.1, -0.1, -1.3, -2.5, -3.7, -4.9, -6.1, -7.3}, sd::DataType::FLOAT32); - NDArray bias('c', {oC}, {1,-0.5, 0.1}, sd::DataType::FLOAT32); - NDArray gradO('c', {bS, oC, oD, oH, oW}, sd::DataType::FLOAT32); +TYPED_TEST(TypedConvolutionTests1, sconv2d_conv2d_1) { + auto input = NDArrayFactory::create('c', {2, 3, 10, 10}); + auto weightsD = NDArrayFactory::create( + 'c', {5, 5, 3, 2}, + {1.f, 76.f, 26.f, 101.f, 51.f, 126.f, 2.f, 77.f, 27.f, 102.f, + 52.f, 127.f, 3.f, 78.f, 28.f, 103.f, 53.f, 128.f, 4.f, 79.f, + 29.f, 104.f, 54.f, 129.f, 5.f, 80.f, 30.f, 105.f, 55.f, 130.f, + 6.f, 81.f, 31.f, 106.f, 56.f, 131.f, 7.f, 82.f, 32.f, 107.f, + 57.f, 132.f, 8.f, 83.f, 33.f, 108.f, 58.f, 133.f, 9.f, 84.f, + 34.f, 109.f, 59.f, 134.f, 10.f, 85.f, 35.f, 110.f, 60.f, 135.f, + 11.f, 86.f, 36.f, 111.f, 61.f, 136.f, 12.f, 87.f, 37.f, 112.f, + 62.f, 137.f, 13.f, 88.f, 38.f, 113.f, 63.f, 138.f, 14.f, 89.f, + 39.f, 114.f, 64.f, 139.f, 15.f, 90.f, 40.f, 115.f, 65.f, 140.f, + 16.f, 91.f, 41.f, 116.f, 66.f, 141.f, 17.f, 92.f, 42.f, 117.f, + 67.f, 142.f, 18.f, 93.f, 43.f, 118.f, 68.f, 143.f, 19.f, 94.f, + 44.f, 119.f, 69.f, 144.f, 20.f, 95.f, 45.f, 120.f, 70.f, 145.f, + 21.f, 96.f, 46.f, 121.f, 71.f, 146.f, 22.f, 97.f, 47.f, 122.f, + 72.f, 147.f, 23.f, 98.f, 48.f, 123.f, 73.f, 148.f, 24.f, 99.f, + 49.f, 124.f, 74.f, 149.f, 25.f, 100.f, 50.f, 125.f, 75.f, 150.f}); + auto weightsP = NDArrayFactory::create( + 'c', {1, 1, 6, 10}, + {0.0001f, 0.0007f, 0.0013f, 0.0019f, 0.0025f, 0.0031f, 0.0037f, 0.0043f, + 0.0049f, 0.0055f, 0.0002f, 0.0008f, 0.0014f, 0.0020f, 0.0026f, 0.0032f, + 0.0038f, 0.0044f, 0.0050f, 0.0056f, 0.0003f, 0.0009f, 0.0015f, 0.0021f, + 0.0027f, 0.0033f, 0.0039f, 0.0045f, 0.0051f, 0.0057f, 0.0004f, 0.0010f, + 0.0016f, 0.0022f, 0.0028f, 0.0034f, 0.0040f, 0.0046f, 0.0052f, 0.0058f, + 0.0005f, 0.0011f, 0.0017f, 0.0023f, 0.0029f, 0.0035f, 0.0041f, 0.0047f, + 0.0053f, 0.0059f, 0.0006f, 0.0012f, 0.0018f, 0.0024f, 0.0030f, 0.0036f, + 0.0042f, 0.0048f, 0.0054f, 0.0060f}); + + auto expFF = NDArrayFactory::create( + 'c', {2, 6, 6, 6}, + {10025.0f, 10350.0f, 10675.0f, 11000.0f, 11325.0f, 11650.0f, + 13275.0f, 13600.0f, 13925.0f, 14250.0f, 14575.0f, 14900.0f, + 16525.0f, 16850.0f, 17175.0f, 17500.0f, 17825.0f, 18150.0f, + 19775.0f, 20100.0f, 20425.0f, 20750.0f, 21075.0f, 21400.0f, + 23025.0f, 23350.0f, 23675.0f, 24000.0f, 24325.0f, 24650.0f, + 26275.0f, 26600.0f, 26925.0f, 27250.0f, 27575.0f, 27900.0f, + 53150.0f, 55350.0f, 57550.0f, 59750.0f, 61950.0f, 64150.0f, + 75150.0f, 77350.0f, 79550.0f, 81750.0f, 83950.0f, 86150.0f, + 97150.0f, 99350.0f, 101550.0f, 103750.0f, 105950.0f, 108150.0f, + 119150.0f, 121350.0f, 123550.0f, 125750.0f, 127950.0f, 130150.0f, + 141150.0f, 143350.0f, 145550.0f, 147750.0f, 149950.0f, 152150.0f, + 163150.0f, 165350.0f, 167550.0f, 169750.0f, 171950.0f, 174150.0f, + 119400.0f, 120350.0f, 121300.0f, 122250.0f, 123200.0f, 124150.0f, + 128900.0f, 129850.0f, 130800.0f, 131750.0f, 132700.0f, 133650.0f, + 138400.0f, 139350.0f, 140300.0f, 141250.0f, 142200.0f, 143150.0f, + 147900.0f, 148850.0f, 149800.0f, 150750.0f, 151700.0f, 152650.0f, + 157400.0f, 158350.0f, 159300.0f, 160250.0f, 161200.0f, 162150.0f, + 166900.0f, 167850.0f, 168800.0f, 169750.0f, 170700.0f, 171650.0f, + 350025.0f, 352850.0f, 355675.0f, 358500.0f, 361325.0f, 364150.0f, + 378275.0f, 381100.0f, 383925.0f, 386750.0f, 389575.0f, 392400.0f, + 406525.0f, 409350.0f, 412175.0f, 415000.0f, 417825.0f, 420650.0f, + 434775.0f, 437600.0f, 440425.0f, 443250.0f, 446075.0f, 448900.0f, + 463025.0f, 465850.0f, 468675.0f, 471500.0f, 474325.0f, 477150.0f, + 491275.0f, 494100.0f, 496925.0f, 499750.0f, 502575.0f, 505400.0f, + 353775.0f, 355350.0f, 356925.0f, 358500.0f, 360075.0f, 361650.0f, + 369525.0f, 371100.0f, 372675.0f, 374250.0f, 375825.0f, 377400.0f, + 385275.0f, 386850.0f, 388425.0f, 390000.0f, 391575.0f, 393150.0f, + 401025.0f, 402600.0f, 404175.0f, 405750.0f, 407325.0f, 408900.0f, + 416775.0f, 418350.0f, 419925.0f, 421500.0f, 423075.0f, 424650.0f, + 432525.0f, 434100.0f, 435675.0f, 437250.0f, 438825.0f, 440400.0f, + 771900.0f, 775350.0f, 778800.0f, 782250.0f, 785700.0f, 789150.0f, + 806400.0f, 809850.0f, 813300.0f, 816750.0f, 820200.0f, 823650.0f, + 840900.0f, 844350.0f, 847800.0f, 851250.0f, 854700.0f, 858150.0f, + 875400.0f, 878850.0f, 882300.0f, 885750.0f, 889200.0f, 892650.0f, + 909900.0f, 913350.0f, 916800.0f, 920250.0f, 923700.0f, 927150.0f, + 944400.0f, 947850.0f, 951300.0f, 954750.0f, 958200.0f, 961650.0f, + 107525.0f, 107850.0f, 108175.0f, 108500.0f, 108825.0f, 109150.0f, + 110775.0f, 111100.0f, 111425.0f, 111750.0f, 112075.0f, 112400.0f, + 114025.0f, 114350.0f, 114675.0f, 115000.0f, 115325.0f, 115650.0f, + 117275.0f, 117600.0f, 117925.0f, 118250.0f, 118575.0f, 118900.0f, + 120525.0f, 120850.0f, 121175.0f, 121500.0f, 121825.0f, 122150.0f, + 123775.0f, 124100.0f, 124425.0f, 124750.0f, 125075.0f, 125400.0f, + 713150.0f, 715350.0f, 717550.0f, 719750.0f, 721950.0f, 724150.0f, + 735150.0f, 737350.0f, 739550.0f, 741750.0f, 743950.0f, 746150.0f, + 757150.0f, 759350.0f, 761550.0f, 763750.0f, 765950.0f, 768150.0f, + 779150.0f, 781350.0f, 783550.0f, 785750.0f, 787950.0f, 790150.0f, + 801150.0f, 803350.0f, 805550.0f, 807750.0f, 809950.0f, 812150.0f, + 823150.0f, 825350.0f, 827550.0f, 829750.0f, 831950.0f, 834150.0f, + 404400.0f, 405350.0f, 406300.0f, 407250.0f, 408200.0f, 409150.0f, + 413900.0f, 414850.0f, 415800.0f, 416750.0f, 417700.0f, 418650.0f, + 423400.0f, 424350.0f, 425300.0f, 426250.0f, 427200.0f, 428150.0f, + 432900.0f, 433850.0f, 434800.0f, 435750.0f, 436700.0f, 437650.0f, + 442400.0f, 443350.0f, 444300.0f, 445250.0f, 446200.0f, 447150.0f, + 451900.0f, 452850.0f, 453800.0f, 454750.0f, 455700.0f, 456650.0f, + 1197525.0f, 1200350.0f, 1203175.0f, 1206000.0f, 1208825.0f, 1211650.0f, + 1225775.0f, 1228600.0f, 1231425.0f, 1234250.0f, 1237075.0f, 1239900.0f, + 1254025.0f, 1256850.0f, 1259675.0f, 1262500.0f, 1265325.0f, 1268150.0f, + 1282275.0f, 1285100.0f, 1287925.0f, 1290750.0f, 1293575.0f, 1296400.0f, + 1310525.0f, 1313350.0f, 1316175.0f, 1319000.0f, 1321825.0f, 1324650.0f, + 1338775.0f, 1341600.0f, 1344425.0f, 1347250.0f, 1350075.0f, 1352900.0f, + 826275.0f, 827850.0f, 829425.0f, 831000.0f, 832575.0f, 834150.0f, + 842025.0f, 843600.0f, 845175.0f, 846750.0f, 848325.0f, 849900.0f, + 857775.0f, 859350.0f, 860925.0f, 862500.0f, 864075.0f, 865650.0f, + 873525.0f, 875100.0f, 876675.0f, 878250.0f, 879825.0f, 881400.0f, + 889275.0f, 890850.0f, 892425.0f, 894000.0f, 895575.0f, 897150.0f, + 905025.0f, 906600.0f, 908175.0f, 909750.0f, 911325.0f, 912900.0f, + 1806900.0f, 1810350.0f, 1813800.0f, 1817250.0f, 1820700.0f, 1824150.0f, + 1841400.0f, 1844850.0f, 1848300.0f, 1851750.0f, 1855200.0f, 1858650.0f, + 1875900.0f, 1879350.0f, 1882800.0f, 1886250.0f, 1889700.0f, 1893150.0f, + 1910400.0f, 1913850.0f, 1917300.0f, 1920750.0f, 1924200.0f, 1927650.0f, + 1944900.0f, 1948350.0f, 1951800.0f, 1955250.0f, 1958700.0f, 1962150.0f, + 1979400.0f, 1982850.0f, 1986300.0f, 1989750.0f, 1993200.0f, 1996650.f}); + auto exp2FF = NDArrayFactory::create( + 'c', {2, 10, 6, 6}, + {827.4900282f, 832.2350283f, 836.9800284f, 841.725028f, + 846.4700287f, 851.2150288f, 874.9400293f, 879.6850294f, + 884.4300295f, 889.1750296f, 893.9200297f, 898.665029f, + 922.3900304f, 927.1350305f, 931.8800306f, 936.6250307f, + 941.3700308f, 946.1150309f, 969.8400315f, 974.5850316f, + 979.3300317f, 984.0750318f, 988.8200319f, 993.5650320f, + 1017.2900326f, 1022.0350327f, 1026.7800328f, 1031.5250329f, + 1036.2700330f, 1041.0150331f, 1064.7400337f, 1069.4850338f, + 1074.2300339f, 1078.9750340f, 1083.7200341f, 1088.4650342f, + 1822.4550553f, 1833.995055f, 1845.5350558f, 1857.075056f, + 1868.6150563f, 1880.1550566f, 1937.8550578f, 1949.3950581f, + 1960.9350583f, 1972.4750586f, 1984.015058f, 1995.5550591f, + 2053.2550604f, 2064.7950606f, 2076.3350609f, 2087.8750611f, + 2099.4150614f, 2110.955061f, 2168.6550629f, 2180.1950632f, + 2191.7350634f, 2203.2750637f, 2214.8150639f, 2226.3550642f, + 2284.0550655f, 2295.5950657f, 2307.1350660f, 2318.6750662f, + 2330.2150665f, 2341.7550667f, 2399.4550680f, 2410.9950683f, + 2422.5350685f, 2434.0750688f, 2445.6150690f, 2457.1550693f, + 2817.419968f, 2835.7549686f, 2854.0899683f, 2872.4249680f, + 2890.7599677f, 2909.0949674f, 3000.7699660f, 3019.104965f, + 3037.4399655f, 3055.7749652f, 3074.1099649f, 3092.4449646f, + 3184.1199632f, 3202.4549629f, 3220.789962f, 3239.1249624f, + 3257.4599621f, 3275.7949618f, 3367.4699604f, 3385.8049601f, + 3404.1399598f, 3422.474959f, 3440.8099593f, 3459.1449590f, + 3550.8199576f, 3569.1549573f, 3587.4899570f, 3605.8249567f, + 3624.1599565f, 3642.4949562f, 3734.1699548f, 3752.5049545f, + 3770.8399542f, 3789.1749539f, 3807.5099536f, 3825.8449534f, + 3812.385098f, 3837.5150988f, 3862.6450994f, 3887.7751000f, + 3912.9051006f, 3938.0351012f, 4063.6851041f, 4088.8151047f, + 4113.9451053f, 4139.0751059f, 4164.2051065f, 4189.3351071f, + 4314.9851100f, 4340.1151106f, 4365.2451112f, 4390.3751118f, + 4415.5051124f, 4440.6351130f, 4566.2851159f, 4591.4151165f, + 4616.5451171f, 4641.6751177f, 4666.805118f, 4691.9351188f, + 4817.5851218f, 4842.7151224f, 4867.8451230f, 4892.975123f, + 4918.1051241f, 4943.2351247f, 5068.8851277f, 5094.0151283f, + 5119.1451288f, 5144.2751294f, 5169.4051300f, 5194.5351306f, + 4807.3499803f, 4839.2749801f, 4871.1999799f, 4903.1249797f, + 4935.0499795f, 4966.9749793f, 5126.5999784f, 5158.5249782f, + 5190.4499780f, 5222.3749778f, 5254.2999777f, 5286.2249775f, + 5445.8499765f, 5477.774976f, 5509.6999762f, 5541.6249760f, + 5573.5499758f, 5605.4749756f, 5765.0999747f, 5797.0249745f, + 5828.9499743f, 5860.8749741f, 5892.7999739f, 5924.724973f, + 6084.3499728f, 6116.2749726f, 6148.1999724f, 6180.1249723f, + 6212.0499721f, 6243.9749719f, 6403.59997f, 6435.5249708f, + 6467.4499706f, 6499.3749704f, 6531.2999702f, 6563.2249700f, + 5802.3150007f, 5841.0350006f, 5879.7550005f, 5918.4750004f, + 5957.195000f, 5995.9150003f, 6189.5149999f, 6228.2349998f, + 6266.9549997f, 6305.6749996f, 6344.3949995f, 6383.114999f, + 6576.7149990f, 6615.4349990f, 6654.1549989f, 6692.8749988f, + 6731.5949987f, 6770.3149986f, 6963.9149982f, 7002.6349981f, + 7041.3549981f, 7080.0749980f, 7118.7949979f, 7157.5149978f, + 7351.1149974f, 7389.8349973f, 7428.5549972f, 7467.2749972f, + 7505.9949971f, 7544.7149970f, 7738.3149966f, 7777.0349965f, + 7815.7549964f, 7854.4749963f, 7893.1949963f, 7931.9149962f, + 6797.2799488f, 6842.794948f, 6888.3099489f, 6933.8249490f, + 6979.3399491f, 7024.8549492f, 7252.4299497f, 7297.9449498f, + 7343.4599499f, 7388.9749500f, 7434.489950f, 7480.0049501f, + 7707.5799506f, 7753.0949507f, 7798.6099508f, 7844.1249509f, + 7889.6399510f, 7935.1549511f, 8162.7299515f, 8208.2449516f, + 8253.7599517f, 8299.2749518f, 8344.7899519f, 8390.3049520f, + 8617.8799525f, 8663.394952f, 8708.9099526f, 8754.4249527f, + 8799.9399528f, 8845.4549529f, 9073.0299534f, 9118.5449535f, + 9164.0599536f, 9209.5749537f, 9255.089953f, 9300.604953f, + 7792.2451647f, 7844.5551655f, 7896.8651663f, 7949.1751671f, + 8001.4851679f, 8053.7951686f, 8315.3451725f, 8367.6551733f, + 8419.9651741f, 8472.2751749f, 8524.585175f, 8576.8951764f, + 8838.4451803f, 8890.7551811f, 8943.0651819f, 8995.3751827f, + 9047.6851834f, 9099.9951842f, 9361.5451881f, 9413.8551889f, + 9466.1651897f, 9518.475190f, 9570.7851912f, 9623.0951920f, + 9884.6451959f, 9936.9551967f, 9989.2651975f, 10041.5751982f, + 10093.8851990f, 10146.1951998f, 10407.7452037f, 10460.0552045f, + 10512.3652053f, 10564.6752060f, 10616.9852068f, 10669.2952076f, + 8787.210074f, 8846.3150748f, 8905.4200750f, 8964.5250752f, + 9023.6300755f, 9082.7350757f, 9378.2600768f, 9437.3650770f, + 9496.4700773f, 9555.5750775f, 9614.6800777f, 9673.7850779f, + 9969.3100791f, 10028.4150793f, 10087.5200795f, 10146.625079f, + 10205.7300800f, 10264.8350802f, 10560.3600813f, 10619.465081f, + 10678.5700818f, 10737.6750820f, 10796.7800822f, 10855.8850825f, + 11151.4100836f, 11210.5150838f, 11269.6200840f, 11328.7250843f, + 11387.8300845f, 11446.9350847f, 11742.4600858f, 11801.5650861f, + 11860.6700863f, 11919.7750865f, 11978.880086f, 12037.9850870f, + 9782.1750935f, 9848.0750935f, 9913.9750934f, 9979.8750934f, + 10045.7750934f, 10111.6750933f, 10441.1750931f, 10507.0750931f, + 10572.9750931f, 10638.8750930f, 10704.7750930f, 10770.6750930f, + 11100.1750928f, 11166.0750927f, 11231.9750927f, 11297.8750927f, + 11363.7750926f, 11429.6750926f, 11759.1750924f, 11825.0750924f, + 11890.9750923f, 11956.8750923f, 12022.7750923f, 12088.6750922f, + 12418.175092f, 12484.0750920f, 12549.9750920f, 12615.8750919f, + 12681.7750919f, 12747.6750919f, 13077.1750917f, 13143.0750916f, + 13208.9750916f, 13274.8750916f, 13340.7750915f, 13406.6750915f, + 2250.990060f, 2255.7350610f, 2260.4800611f, 2265.2250612f, + 2269.9700613f, 2274.7150614f, 2298.4400619f, 2303.185062f, + 2307.9300622f, 2312.6750623f, 2317.4200624f, 2322.1650625f, + 2345.8900630f, 2350.6350631f, 2355.380063f, 2360.1250634f, + 2364.8700635f, 2369.6150636f, 2393.3400641f, 2398.0850642f, + 2402.8300643f, 2407.5750644f, 2412.320064f, 2417.0650647f, + 2440.7900652f, 2445.5350653f, 2450.2800654f, 2455.0250655f, + 2459.7700656f, 2464.515065f, 2488.2400663f, 2492.9850664f, + 2497.7300665f, 2502.4750666f, 2507.2200667f, 2511.9650668f, + 5284.4551315f, 5295.9951318f, 5307.535132f, 5319.0751323f, + 5330.6151326f, 5342.1551328f, 5399.8551341f, 5411.3951343f, + 5422.9351346f, 5434.475134f, 5446.0151351f, 5457.5551354f, + 5515.2551366f, 5526.7951369f, 5538.3351371f, 5549.8751374f, + 5561.4151376f, 5572.9551379f, 5630.6551392f, 5642.1951394f, + 5653.7351397f, 5665.2751399f, 5676.8151402f, 5688.3551404f, + 5746.0551417f, 5757.5951420f, 5769.1351422f, 5780.6751425f, + 5792.2151427f, 5803.7551430f, 5861.455144f, 5872.9951445f, + 5884.5351448f, 5896.0751450f, 5907.6151453f, 5919.1551455f, + 8317.919884f, 8336.2548841f, 8354.5898838f, 8372.9248835f, + 8391.2598832f, 8409.59488f, 8501.2698815f, 8519.6048813f, + 8537.9398810f, 8556.2748807f, 8574.6098804f, 8592.9448801f, + 8684.6198787f, 8702.9548784f, 8721.2898782f, 8739.6248779f, + 8757.9598776f, 8776.2948773f, 8867.9698759f, 8886.3048756f, + 8904.6398753f, 8922.9748751f, 8941.3098748f, 8959.6448745f, + 9051.3198731f, 9069.6548728f, 9087.9898725f, 9106.3248722f, + 9124.6598720f, 9142.9948717f, 9234.6698703f, 9253.0048700f, + 9271.3398697f, 9289.6748694f, 9308.0098691f, 9326.3448689f, + 11351.3852747f, 11376.5152753f, 11401.6452759f, 11426.7752765f, + 11451.9052771f, 11477.0352777f, 11602.6852806f, 11627.8152812f, + 11652.9452818f, 11678.0752824f, 11703.2052830f, 11728.335283f, + 11853.9852865f, 11879.1152871f, 11904.2452877f, 11929.3752883f, + 11954.505288f, 11979.6352894f, 12105.2852924f, 12130.4152930f, + 12155.545293f, 12180.6752941f, 12205.8052947f, 12230.9352953f, + 12356.5852983f, 12381.715298f, 12406.8452994f, 12431.9753000f, + 12457.1053006f, 12482.2353012f, 12607.8853041f, 12633.0153047f, + 12658.1453053f, 12683.2753059f, 12708.4053065f, 12733.5353071f, + 14384.8499244f, 14416.7749242f, 14448.6999240f, 14480.6249238f, + 14512.549923f, 14544.4749235f, 14704.0999225f, 14736.024922f, + 14767.9499222f, 14799.8749220f, 14831.7999218f, 14863.7249216f, + 15023.3499207f, 15055.2749205f, 15087.1999203f, 15119.1249201f, + 15151.0499199f, 15182.9749197f, 15342.5999188f, 15374.5249186f, + 15406.4499184f, 15438.374918f, 15470.2999181f, 15502.2249179f, + 15661.84991f, 15693.7749168f, 15725.6999166f, 15757.6249164f, + 15789.5499162f, 15821.4749160f, 15981.0999151f, 16013.0249149f, + 16044.9499147f, 16076.8749145f, 16108.7999143f, 16140.7249142f, + 17418.314976f, 17457.0349761f, 17495.7549760f, 17534.4749759f, + 17573.1949758f, 17611.9149757f, 17805.5149753f, 17844.234975f, + 17882.9549752f, 17921.6749751f, 17960.3949750f, 17999.1149749f, + 18192.7149745f, 18231.4349744f, 18270.154974f, 18308.8749743f, + 18347.5949742f, 18386.3149741f, 18579.9149737f, 18618.6349736f, + 18657.3549735f, 18696.074973f, 18734.7949734f, 18773.5149733f, + 18967.1149729f, 19005.8349728f, 19044.5549727f, 19083.2749726f, + 19121.994972f, 19160.7149725f, 19354.3149721f, 19393.0349720f, + 19431.7549719f, 19470.4749718f, 19509.1949717f, 19547.914971f, + 20451.7799765f, 20497.2949766f, 20542.8099767f, 20588.3249768f, + 20633.8399769f, 20679.3549770f, 20906.929977f, 20952.4449775f, + 20997.9599776f, 21043.4749777f, 21088.9899778f, 21134.5049779f, + 21362.0799784f, 21407.5949785f, 21453.1099786f, 21498.624978f, + 21544.139978f, 21589.6549788f, 21817.2299793f, 21862.7449794f, + 21908.2599795f, 21953.7749796f, 21999.2899797f, 22044.8049798f, + 22272.3799802f, 22317.8949803f, 22363.4099804f, 22408.9249805f, + 22454.4399806f, 22499.9549807f, 22727.529981f, 22773.044981f, + 22818.5599813f, 22864.0749814f, 22909.5899815f, 22955.1049816f, + 23485.2453985f, 23537.555399f, 23589.8654000f, 23642.1754008f, + 23694.4854016f, 23746.7954024f, 24008.3454063f, 24060.655407f, + 24112.9654078f, 24165.2754086f, 24217.5854094f, 24269.8954102f, + 24531.4454141f, 24583.7554148f, 24636.0654156f, 24688.3754164f, + 24740.6854172f, 24792.99541f, 25054.545421f, 25106.8554226f, + 25159.1654234f, 25211.4754242f, 25263.7854250f, 25316.0954257f, + 25577.6454296f, 25629.9554304f, 25682.2654312f, 25734.5754320f, + 25786.8854328f, 25839.1954335f, 26100.7454374f, 26153.0554382f, + 26205.3654390f, 26257.6754398f, 26309.985440f, 26362.2954413f, + 26518.7101423f, 26577.8151425f, 26636.920142f, 26696.0251430f, + 26755.1301432f, 26814.2351434f, 27109.7601446f, 27168.8651448f, + 27227.9701450f, 27287.0751452f, 27346.1801455f, 27405.2851457f, + 27700.8101468f, 27759.9151470f, 27819.0201473f, 27878.1251475f, + 27937.2301477f, 27996.33514f, 28291.8601491f, 28350.9651493f, + 28410.0701495f, 28469.175149f, 28528.2801500f, 28587.3851502f, + 28882.9101513f, 28942.0151516f, 29001.1201518f, 29060.2251520f, + 29119.3301522f, 29178.4351525f, 29473.9601536f, 29533.0651538f, + 29592.1701540f, 29651.2751543f, 29710.3801545f, 29769.4851547f, + 29552.1750826f, 29618.0750825f, 29683.9750825f, 29749.8750825f, + 29815.7750824f, 29881.6750824f, 30211.1750822f, 30277.0750822f, + 30342.9750821f, 30408.8750821f, 30474.7750821f, 30540.6750820f, + 30870.175081f, 30936.0750818f, 31001.9750818f, 31067.8750817f, + 31133.7750817f, 31199.6750817f, 31529.1750815f, 31595.075081f, + 31660.9750814f, 31726.8750814f, 31792.7750813f, 31858.6750813f, + 32188.1750811f, 32254.0750811f, 32319.975081f, 32385.8750810f, + 32451.7750810f, 32517.6750809f, 32847.1750808f, 32913.0750807f, + 32978.9750807f, 33044.875080f, 33110.7750806f, 33176.67508062f}); + + input.linspace(1); + + sd::ops::sconv2d op; + auto resultFF = + op.evaluate({&input, &weightsD}, {}, {5, 5, 1, 1, 0, 0, 1, 1, 0}, {}); + + auto z = resultFF.at(0); + + ASSERT_TRUE(z.isSameShape(&expFF)); + ASSERT_TRUE(z.equalsTo(&expFF, 1)); + + sd::ops::conv2d op2d; + // weightsP.printShapeInfo(); + auto result2D = + op2d.evaluate({&z, &weightsP}, {}, {1, 1, 1, 1, 0, 0, 1, 1, 0, 0}, {}); + + auto z2d = result2D.at(0); + // z2d->printBuffer(); + + ASSERT_TRUE(z2d.isSameShape(&exp2FF)); + ASSERT_TRUE(z2d.equalsTo(&exp2FF)); +} - NDArray expGradI('c', {bS, iC, iD, iH, iW},{1.847, 3.577, 1.694, 3.460, 6.542, 3.010, 1.469, 2.677, 1.172, 3.226, 5.929999, 2.632, 5.408, 9.483999, 3.932, 1.894, - 2.978, 1.012, 0.058, -0.694, -0.824, -1.504, -4.916, -3.556, -1.850, -4.798, -3.020, -1.069, -2.687, -1.654, -3.236, -7.714, -4.550, -2.311, -5.315, -3.040, - 1.766, 3.406, 1.604, 3.280, 6.164, 2.812, 1.370, 2.470, 1.064, 3.028, 5.516, 2.416, 4.976, 8.584001, 3.464, 1.660, 2.492, 0.760, -0.140, -1.108, -1.040, -1.936, - -5.816, -4.024, -2.084, -5.284, -3.272, -1.186, -2.930, -1.780, -3.488, -8.236, -4.820, -2.446, -5.594, -3.184, 1.685, 3.235, 1.514, 3.100, 5.786, 2.614, 1.271, - 2.263, 0.956, 2.830, 5.102, 2.200, 4.544001, 7.683999, 2.996, 1.426, 2.006, 0.508, -0.338, -1.522, -1.256, -2.368, -6.716, -4.492, -2.318, -5.770, -3.524, -1.303, - -3.173, -1.906, -3.740, -8.757999, -5.090, -2.581, -5.873, -3.328, 1.604, 3.064, 1.424, 2.920, 5.408, 2.416, 1.172, 2.056, 0.848, 2.632, 4.688, 1.984, 4.112, 6.784, 2.528, 1.192, 1.520, 0.256, -0.536, -1.936, -1.472, -2.800, -7.616, -4.960, -2.552, -6.256, -3.776, -1.420, -3.416, -2.032, -3.992, -9.280001, -5.360, -2.716, -6.152, -3.472, 6.815001, 12.649, 5.798, 11.668, 21.230, 9.490, 4.709, 8.292999, 3.548, 9.706, 17.162001, 7.384, 14.912001, 25.036001, 9.980001, 4.918, 7.298, 2.308, -0.374, -3.286, -2.984, -5.824, -17.012001, -11.332001, -5.738, -14.302, -8.636, -3.013, -7.439, -4.462, -8.852, -20.674, -11.894, -5.983, -13.523, -7.576, 6.518, 12.046, 5.492, 11.056, 19.988001, 8.860001, 4.394, 7.654, 3.224, 9.075999, 15.883999, 6.736001, 13.616, 22.407999, 8.648, 4.252, 5.947999, 1.624, -1.004, -4.564, -3.632, -7.120, -19.639999, -12.664001, -6.404, -15.652, -9.320, -3.346, -8.114, -4.804, -9.536, -22.059999, -12.596, -6.334, -14.233999, -7.936, 6.221, 11.443, 5.186, 10.444, 18.746, 8.230, 4.079, 7.015, 2.900, 8.446, 14.606001, 6.088, 12.320, 19.779999, 7.316, 3.586, 4.598001, 0.940, -1.634, -5.842, -4.280, -8.416, -22.268002, -13.996, -7.070001, -17.001999, -10.004001, -3.679, -8.789, -5.146, -10.220, -23.445999, -13.298, -6.684999, -14.945, -8.296, 5.924, 10.840, 4.880, 9.832001, 17.504, 7.600, 3.764, 6.376, 2.576, 7.816, 13.328, 5.440001, 11.024, 17.152, 5.983999, 2.920, 3.247999, 0.256, -2.264, -7.120, -4.928, -9.712, -24.896, -15.328, -7.736, -18.352001, -10.688, -4.012, -9.464, -5.488, -10.903999, -24.832001, -14.000, -7.035999, -15.656, -8.655999}, sd::DataType::FLOAT32); +TEST_F(ConvolutionTests1, deconv2d_bp_1) { + int bS = 3, iH = 4, iW = 4, iC = 3, oC = 2, kH = 1, kW = 1, sH = 1, sW = 1, + pH = 0, pW = 0, dH = 1, dW = 1; + int oH = 4, oW = 4; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NHWC, 0-NCHW + + NDArray input('c', {bS, iC, iH, iW}, sd::DataType::FLOAT32); + NDArray bias('c', {oC}, sd::DataType::FLOAT32); + NDArray weights('c', {kH, kW, oC, iC}, {1, 3, 5, 2, 4, 6}, + sd::DataType::FLOAT32); + NDArray gradO('c', {bS, oC, oH, oW}, sd::DataType::FLOAT32); + + NDArray expGradI( + 'c', {bS, iC, iH, iW}, + {35.f, 38.f, 41.f, 44.f, 47.f, 50.f, 53.f, 56.f, 59.f, 62.f, + 65.f, 68.f, 71.f, 74.f, 77.f, 80.f, 71.f, 78.f, 85.f, 92.f, + 99.f, 106.f, 113.f, 120.f, 127.f, 134.f, 141.f, 148.f, 155.f, 162.f, + 169.f, 176.f, 107.f, 118.f, 129.f, 140.f, 151.f, 162.f, 173.f, 184.f, + 195.f, 206.f, 217.f, 228.f, 239.f, 250.f, 261.f, 272.f, 131.f, 134.f, + 137.f, 140.f, 143.f, 146.f, 149.f, 152.f, 155.f, 158.f, 161.f, 164.f, + 167.f, 170.f, 173.f, 176.f, 295.f, 302.f, 309.f, 316.f, 323.f, 330.f, + 337.f, 344.f, 351.f, 358.f, 365.f, 372.f, 379.f, 386.f, 393.f, 400.f, + 459.f, 470.f, 481.f, 492.f, 503.f, 514.f, 525.f, 536.f, 547.f, 558.f, + 569.f, 580.f, 591.f, 602.f, 613.f, 624.f, 227.f, 230.f, 233.f, 236.f, + 239.f, 242.f, 245.f, 248.f, 251.f, 254.f, 257.f, 260.f, 263.f, 266.f, + 269.f, 272.f, 519.f, 526.f, 533.f, 540.f, 547.f, 554.f, 561.f, 568.f, + 575.f, 582.f, 589.f, 596.f, 603.f, 610.f, 617.f, 624.f, 811.f, 822.f, + 833.f, 844.f, 855.f, 866.f, 877.f, 888.f, 899.f, 910.f, 921.f, 932.f, + 943.f, 954.f, 965.f, 976.f}, + sd::DataType::FLOAT32); + NDArray expGradW('c', {kH, kW, oC, iC}, + {160008., 191112., 222216., 203400., 246792., 290184.f}, + sd::DataType::FLOAT32); + NDArray expGradB('c', {oC}, {1944.f, 2712.f}, sd::DataType::FLOAT32); + + input.linspace(1); + bias.linspace(1); + gradO.linspace(1); + + sd::ops::deconv2d_bp op; + auto results = + op.evaluate({&input, &weights, &bias, &gradO}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, dataFormat}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto gradI = results.at(0); + auto gradW = results.at(1); + auto gradB = results.at(2); + + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); + + ASSERT_TRUE(expGradW.isSameShape(gradW)); + ASSERT_TRUE(expGradW.equalsTo(gradW)); + + ASSERT_TRUE(expGradB.isSameShape(gradB)); + ASSERT_TRUE(expGradB.equalsTo(gradB)); +} - NDArray expGradW('c', {oC, iC, kD, kH, kW},{-24.399998, -23.080000, -20.440001, -19.119999, -12.519999, -11.199998, -8.560001, -7.240002, -0.639999, 0.679999, - 3.320001, 4.640001, 23.119999, 24.439999, 27.080002, 28.400002, 35.000000, 36.320000, 38.959999, 40.279999, 46.879997, 48.200005, 50.839996, 52.160004, - 70.639999, 71.959999, 74.599998, 75.919998, 82.520004, 83.840004, 86.479996, 87.800003, 94.399994, 95.719994, 98.360001, 99.680008, 118.160004, 119.479996, - 122.120003, 123.440010, 130.040009, 131.360001, 134.000000, 135.319992, 141.919998, 143.239990, 145.879990, 147.200012, -70.159996, -68.200005, -64.279999, - -62.319996, -52.519993, -50.559994, -46.640003, -44.680000, -34.880001, -32.919998, -29.000002, -27.040005, 0.400004, 2.359996, 6.279998, 8.240004, 18.040001, - 20.000000, 23.920002, 25.879999, 35.680000, 37.639996, 41.560001, 43.520000, 70.959999, 72.919998, 76.840004, 78.799995, 88.599998, 90.560005, 94.479996, 96.440002, 106.240005, 108.199997, 112.120003, 114.080002, 141.519989, 143.479996, 147.400009, 149.360001, 159.159988, 161.119995, 165.040009, 167.000000, 176.800003, 178.760010, 182.679993, 184.639999, -115.920006, -113.320000, -108.120003, -105.520012, -92.520004, -89.919991, -84.720001, -82.119995, -69.120010, -66.520004, -61.320000, -58.719994, -22.320000, -19.719999, -14.520001, -11.920001, 1.079997, 3.679997, 8.879997, 11.480003, 24.480001, 27.079998, 32.280003, 34.880001, 71.279999, 73.880005, 79.080002, 81.680000, 94.679993, 97.280006, 102.479996, 105.080002, 118.080002, 120.679993, 125.879997, 128.479996, 164.880005, 167.479996, 172.679993, 175.279999, 188.279984, 190.880005, 196.080002, 198.679993, 211.680008, 214.280014, 219.479996, 222.079987}, sd::DataType::FLOAT32); +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests1, deconv2d_bp_2) { + int bS = 3, iH = 4, iW = 4, iC = 3, oC = 2, kH = 2, kW = 1, sH = 1, sW = 1, + pH = 0, pW = 0, dH = 1, dW = 1; + int oH = 4, oW = 4; // 5,4 + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NHWC, 0-NCHW + int wFormat = + 1; // 0 - [kH, kW, oC, iC], 1 - [iC, oC, kH, kW], 2 - [iC, kH, kW, oC] + + NDArray input('c', {bS, iC, iH, iW}, sd::DataType::FLOAT32); + NDArray bias('c', {oC}, {-0.1, 0.2}, sd::DataType::FLOAT32); + NDArray weights('c', {iC, oC, kH, kW}, + {1., 7., 2., 10., 3., 8., 4., 11., 5., 9., 6., 12.}, + sd::DataType::FLOAT32); + NDArray gradO('c', {bS, oC, oH, oW}, sd::DataType::FLOAT32); + + NDArray expGradI( + 'c', {bS, iC, iH, iW}, + {-77.400002, -77.199997, -77., -76.800003, -76.599998, + -76.400002, -76.200005, -76., -75.800003, -75.599998, + -75.399994, -75.199997, -11.32, -11.29, -11.26, + -11.23, -100.839996, -100.580002, -100.32, -100.059998, + -99.800003, -99.540001, -99.279999, -99.019997, -98.760002, + -98.50, -98.240005, -97.979996, -26.52, -26.450001, + -26.380001, -26.309999, -124.279999, -123.959991, -123.639999, + -123.32, -123., -122.68, -122.360001, -122.040001, + -121.720001, -121.400009, -121.080002, -120.759995, -41.720001, + -41.610001, -41.50, -41.389999, -71., -70.800003, + -70.599998, -70.399994, -70.199997, -70., -69.800003, + -69.600006, -69.400002, -69.199997, -69., -68.799995, + -10.360001, -10.33, -10.30, -10.27, -92.519997, + -92.260002, -92., -91.740005, -91.479996, -91.220001, + -90.960007, -90.700005, -90.440002, -90.18, -89.919998, + -89.660004, -24.280001, -24.209999, -24.139999, -24.07, + -114.040001, -113.720001, -113.400009, -113.080002, -112.759995, + -112.440002, -112.120003, -111.800003, -111.480003, -111.159996, + -110.839996, -110.520004, -38.200001, -38.09, -37.980003, + -37.869999, -64.599998, -64.400002, -64.199997, -64., + -63.799995, -63.599998, -63.400002, -63.199997, -63., + -62.799995, -62.599998, -62.400002, -9.40, -9.37, + -9.34, -9.309999, -84.200005, -83.940002, -83.68, + -83.419998, -83.160004, -82.900002, -82.639999, -82.379997, + -82.119995, -81.860001, -81.600006, -81.339996, -22.040001, + -21.970001, -21.90, -21.83, -103.800003, -103.480003, + -103.159996, -102.839996, -102.520004, -102.200005, -101.879997, + -101.559998, -101.239998, -100.919998, -100.599998, -100.279999, + -34.68, -34.57, -34.459999, -34.349998}, + sd::DataType::FLOAT32); + + NDArray expGradW('c', {iC, oC, kH, kW}, + {-3010.799805, -2502.420410, -2899.439209, -2407.380615, + -242.159332, -437.460510, -253.680466, -434.580048, + 2526.479980, 1627.500000, 2392.079834, 1538.220093}, + sd::DataType::FLOAT32); + NDArray expGradB('c', {oC}, {-173.040009, -165.360016}, + sd::DataType::FLOAT32); + + input.linspace(70., -1); + gradO.linspace(-4, 0.01); + + sd::ops::deconv2d_bp op; + auto results = op.evaluate( + {&input, &weights, &bias, &gradO}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, dataFormat, wFormat}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto gradI = results.at(0); + auto gradW = results.at(1); + auto gradB = results.at(2); + + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); + + ASSERT_TRUE(expGradW.isSameShape(gradW)); + ASSERT_TRUE(expGradW.equalsTo(gradW)); + + ASSERT_TRUE(expGradB.isSameShape(gradB)); + ASSERT_TRUE(expGradB.equalsTo(gradB)); +} - NDArray expGradB('c', {oC}, {2.64, 3.92, 5.2}, sd::DataType::FLOAT32); +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests1, deconv2d_bp_3) { + int bS = 3, iH = 4, iW = 4, iC = 3, oC = 2, kH = 2, kW = 1, sH = 1, sW = 1, + pH = 0, pW = 0, dH = 1, dW = 1; + int oH = 5, oW = 4; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NHWC, 0-NCHW + int wFormat = + 2; // 0 - [kH, kW, oC, iC], 1 - [iC, oC, kH, kW], 2 - [iC, kH, kW, oC] + + NDArray input('c', {bS, iH, iW, iC}, sd::DataType::FLOAT32); + NDArray bias('c', {oC}, {-0.1, 0.2}, sd::DataType::FLOAT32); + NDArray weights('c', {iC, kH, kW, oC}, + {1., 4., 7., 10., 2., 5., 8., 11., 3., 6., 9., 12.}, + sd::DataType::FLOAT32); + NDArray gradO('c', {bS, oH, oW, oC}, sd::DataType::FLOAT32); + + NDArray expGradI( + 'c', {bS, iH, iW, iC}, + {-86.5, -102.320007, -118.139999, -86.060005, -101.800003, + -117.540001, -85.619995, -101.279999, -116.940002, -85.18, + -100.759995, -116.339996, -84.740005, -100.239998, -115.739998, + -84.300003, -99.720001, -115.139999, -83.860001, -99.199997, + -114.539993, -83.419998, -98.68, -113.939995, -82.979996, + -98.160004, -113.339996, -82.539993, -97.639999, -112.739998, + -82.099998, -97.120003, -112.139999, -81.660004, -96.600006, + -111.539993, -81.220001, -96.080002, -110.939995, -80.779999, + -95.559998, -110.340012, -80.340004, -95.040001, -109.740005, + -79.900002, -94.519997, -109.139992, -77.699997, -91.919998, + -106.139999, -77.260002, -91.400002, -105.540001, -76.820007, + -90.880005, -104.940002, -76.380005, -90.360001, -104.339996, + -75.940002, -89.839996, -103.740005, -75.5, -89.320007, + -103.139999, -75.060005, -88.800003, -102.540001, -74.619995, + -88.279999, -101.940002, -74.18, -87.759995, -101.339996, + -73.740005, -87.239998, -100.739998, -73.300003, -86.720001, + -100.139999, -72.860001, -86.199997, -99.539993, -72.419998, + -85.68, -98.939995, -71.979996, -85.160004, -98.339996, + -71.539993, -84.639999, -97.740005, -71.099998, -84.120003, + -97.139999, -68.899994, -81.519997, -94.139999, -68.459999, + -81.00, -93.539993, -68.019997, -80.479996, -92.940002, + -67.580002, -79.959999, -92.339996, -67.139999, -79.440002, + -91.740005, -66.699997, -78.919998, -91.139999, -66.260002, + -78.399994, -90.540001, -65.820007, -77.880005, -89.940002, + -65.380005, -77.360001, -89.339996, -64.940002, -76.839996, + -88.740005, -64.5, -76.320007, -88.139999, -64.060005, + -75.800003, -87.540001, -63.619995, -75.279999, -86.940002, + -63.18, -74.759995, -86.339996, -62.739998, -74.239998, + -85.739998, -62.299999, -73.720001, -85.139999}, + sd::DataType::FLOAT32); + + NDArray expGradW('c', {iC, kH, kW, oC}, + {-592.800110, -593.039917, -594.719116, -594.960266, + -427.199890, -427.919617, -432.959900, -433.679993, + -261.600281, -262.799591, -271.200317, -272.399536}, + sd::DataType::FLOAT32); + NDArray expGradB('c', {oC}, {-204.600006, -204.}, sd::DataType::FLOAT32); + + input.linspace(70., -1); + gradO.linspace(-4, 0.01); + + sd::ops::deconv2d_bp op; + auto results = op.evaluate( + {&input, &weights, &bias, &gradO}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, dataFormat, wFormat}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto gradI = results.at(0); + auto gradW = results.at(1); + auto gradB = results.at(2); + + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); + + ASSERT_TRUE(expGradW.isSameShape(gradW)); + ASSERT_TRUE(expGradW.equalsTo(gradW)); + + ASSERT_TRUE(expGradB.isSameShape(gradB)); + ASSERT_TRUE(expGradB.equalsTo(gradB)); +} - input.linspace(-75, 0.5); - gradO.linspace(0.01, 0.01); +TYPED_TEST(TypedConvolutionTests1, Test_Conv1D_ff_1) { + auto input = NDArrayFactory::create('c', {2, 2, 6}); + auto weights = NDArrayFactory::create( + 'c', {2, 2, 3}, {1, 5, 9, 3, 7, 11, 2, 6, 10, 4, 8, 12}); + auto bias = NDArrayFactory::create('c', {3}); + auto expFF = NDArrayFactory::create( + 'c', {2, 3, 5}, + {59.0f, 69.0f, 79.0f, 89.0f, 99.0f, 132.0f, 158.0f, 184.0f, + 210.0f, 236.0f, 205.0f, 247.0f, 289.0f, 331.0f, 373.0f, 179.0f, + 189.0f, 199.0f, 209.0f, 219.0f, 444.0f, 470.0f, 496.0f, 522.0f, + 548.0f, 709.0f, 751.0f, 793.0f, 835.0f, 877.0f}); + auto expEps = NDArrayFactory::create( + 'c', {2, 2, 6}, + {130.0f, 293.0f, 326.0f, 359.0f, 392.0f, 220.0f, 166.0f, 371.0f, + 416.0f, 461.0f, 506.0f, 280.0f, 355.0f, 788.0f, 821.0f, 854.0f, + 887.0f, 490.0f, 481.0f, 1046.0f, 1091.0f, 1136.0f, 1181.0f, 640.0f}); + auto expGW = NDArrayFactory::create( + 'c', {3, 2, 2}, + {1415.0f, 1520.0f, 2045.0f, 2150.0f, 1865.0f, 2020.0f, 2795.0f, 2950.0f, + 2315.0f, 2520.0f, 3545.0f, 3750.0f}); + auto expGB = + NDArrayFactory::create('c', {3}, {105.0f, 155.0f, 205.0f}); + + expGW.permutei({2, 1, 0}); + input.linspace(1); + bias.linspace(1); + + sd::ops::conv1d op; + auto result_FF = + op.evaluate({&input, &weights, &bias}, {}, {2, 1, 0, 1, 0, 0}); + + ASSERT_EQ(ND4J_STATUS_OK, result_FF.status()); + + auto z = result_FF.at(0); + + ASSERT_TRUE(expFF.isSameShape(z)); + ASSERT_TRUE(expFF.equalsTo(z)); + + sd::ops::conv1d_bp op_bp; + + auto epsilonNxt = new NDArray(z.dup()); + epsilonNxt->linspace(1); + + auto result_BP = op_bp.evaluate({&input, &weights, &bias, epsilonNxt}, {}, + {2, 1, 0, 1, 0, 0}); + ASSERT_EQ(ND4J_STATUS_OK, result_BP.status()); + + auto eps = result_BP.at(0); + auto gradW = result_BP.at(1); + auto gradB = result_BP.at(2); + + ASSERT_TRUE(expEps.isSameShape(eps)); + ASSERT_TRUE(expGW.isSameShape(gradW)); + ASSERT_TRUE(expGB.isSameShape(gradB)); + + ASSERT_TRUE(expEps.equalsTo(eps)); + ASSERT_TRUE(expGW.equalsTo(gradW)); + ASSERT_TRUE(expGB.equalsTo(gradB)); + + delete epsilonNxt; +} - sd::ops::conv3dnew_bp op; - auto results = op.evaluate({&input, &weights, &bias, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat, wFormat}); - auto gradI = results.at(0); - auto gradW = results.at(1); - auto gradB = results.at(2); +TYPED_TEST(TypedConvolutionTests1, Test_Conv1D_ff_2) { + auto input = NDArrayFactory::create('c', {2, 2, 6}); + auto weights = NDArrayFactory::create( + 'c', {2, 2, 3}, + {1.f, 5.f, 9.f, 3.f, 7.f, 11.f, 2.f, 6.f, 10.f, 4.f, 8.f, 12.f}); - ASSERT_EQ(Status::OK(), results.status()); + input.linspace(1); - ASSERT_TRUE(expGradI.isSameShape(gradI)); - ASSERT_TRUE(expGradI.equalsTo(gradI)); + sd::ops::conv1d op; + auto result = op.evaluate({&input, &weights}, {}, {2, 1, 0, 1, 1, 0}); - ASSERT_TRUE(expGradW.isSameShape(gradW)); - ASSERT_TRUE(expGradW.equalsTo(gradW)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(expGradB.isSameShape(gradB)); - ASSERT_TRUE(expGradB.equalsTo(gradB)); + auto z = result.at(0); } ////////////////////////////////////////////////////////////////////// -TEST_F(ConvolutionTests1, conv3d_bp_test5) { +TEST_F(ConvolutionTests1, conv1d_causal_1) { + int bS = 2, iW = 3, iC = 4, oC = 3, kW = 2, sW = 1, pW = 0, dW = 1; + int oW = (iW - 1) / sW + 1; + int paddingMode = 2; // CAUSAL + int dataFormat = 1; // 1-NHWC, 0-NCHW + + NDArray input('c', {bS, iW, iC}); + NDArray weights('c', {kW, iC, oC}); + NDArray bias('c', {oC}, {-1, -2, -3}); - int bS=2, iD=4,iH=3,iW=3, iC=4,oC=3, kD=3,kH=2,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; - int oD=4,oH=3,oW=3; - int paddingMode = 1; // 1-SAME, 0-VALID; - int dataFormat = 1; // 1-NHWC, 0-NCHW - int wFormat = 2; // 0-[kD, kH, kW, iC, oC], 1-[oC, iC, kD, kH, kW], 2-[oC, kD, kH, kW, iC] + NDArray expOutput('c', {bS, oW, oC}, + {18., 18., 18., 53., 55.6, 58.2, 89.8, 95.6, 101.4, 102., + 106.8, 111.6, 163.4, 175.6, 187.8, 200.2, 215.6, 231.}); - NDArray input('c', {bS, iD, iH, iW, iC}, sd::DataType::FLOAT32); - NDArray weights('c', {oC, kD, kH, kW, iC}, {15., 14.7, 14.4, 14.1, 13.8, 13.5, 13.2, 12.9, 12.6, 12.3, 12., 11.7, 11.4, 11.1, 10.8, 10.5, 10.2, 9.9, 9.6, 9.3, 9., - 8.7, 8.4, 8.1, 7.8, 7.5, 7.2, 6.9, 6.6, 6.3, 6., 5.7, 5.4, 5.1, 4.8, 4.5, 4.2, 3.9, 3.6, 3.3, 3., 2.7, 2.4, 2.1, 1.8, 1.5, 1.2, 0.9, 14.9, 14.6, 14.3, 14., - 13.7, 13.4, 13.1, 12.8, 12.5, 12.2, 11.9, 11.6, 11.3, 11., 10.7, 10.4, 10.1, 9.8, 9.5, 9.2, 8.9, 8.6, 8.3, 8., 7.7, 7.4, 7.1, 6.8, 6.5, 6.2, 5.9, 5.6, 5.3, 5., - 4.7, 4.4, 4.1, 3.8, 3.5, 3.2, 2.9, 2.6, 2.3, 2., 1.7, 1.4, 1.1, 0.8, 14.8, 14.5, 14.2, 13.9, 13.6, 13.3, 13., 12.7, 12.4, 12.1, 11.8, 11.5, 11.2, 10.9, 10.6, - 10.3, 10., 9.7, 9.4, 9.1, 8.8, 8.5, 8.2, 7.9, 7.6, 7.3, 7., 6.7, 6.4, 6.1, 5.8, 5.5, 5.2, 4.9, 4.6, 4.3, 4., 3.7, 3.4, 3.1, 2.8, 2.5, 2.2, 1.9, 1.6, 1.3, 1., 0.7}, sd::DataType::FLOAT32); - NDArray bias('c', {oC}, {1,-0.5, 0.1}, sd::DataType::FLOAT32); - NDArray gradO('c', {bS, oD, oH, oW, oC}, sd::DataType::FLOAT32); + input.linspace(1., 1.); + weights.linspace(0.1, 0.1); - NDArray expGradI('c', {bS, iD, iH, iW, iC}, {13.565001, 13.286001, 13.007000, 12.728001, 28.264000, 27.652000, 27.040001, 26.427999, 32.547997, 31.827999, 31.108002, - 30.388000, 31.647999, 30.927998, 30.208000, 29.487999, 64.484001, 62.935997, 61.387997, 59.839996, 72.188004, 70.424004, 68.660004, 66.896004, 43.852001, 42.807999, - 41.764000, 40.719997, 87.596001, 85.400002, 83.204002, 81.007996, 95.299988, 92.887993, 90.475998, 88.063995, 34.130997, 33.348000, 32.564999, 31.782001, 67.856995, - 66.210007, 64.563004, 62.916000, 72.987000, 71.178001, 69.369003, 67.559998, 70.179001, 68.369995, 66.561005, 64.751999, 137.927994, 134.147995, 130.367996, 126.587997, - 146.891998, 142.787994, 138.683990, 134.580017, 84.597000, 82.302002, 80.007004, 77.711998, 164.820007, 160.067993, 155.316010, 150.563995, 173.783997, 168.707993, - 163.631989, 158.556000, 58.674000, 57.162003, 55.649994, 54.138000, 114.027008, 110.921997, 107.816994, 104.711990, 119.156998, 115.889999, 112.623001, 109.355995, 113.433006, 110.166000, 106.899002, 103.632004, 218.603989, 211.908020, 205.211975, 198.515991, 227.568008, 220.547974, 213.528015, 206.507996, 127.850998, 124.098000, 120.345001, 116.591995, 245.496002, 237.828018, 230.159988, 222.492004, 254.459991, 246.468002, 238.475998, 230.483994, 34.049000, 32.797997, 31.547001, 30.295998, 64.479996, 61.924000, 59.368004, 56.812000, 67.035995, 64.372002, 61.707996, 59.044003, 62.248001, 59.584003, 56.919998, 54.256001, 116.180000, 110.744003, 105.307999, 99.872002, 120.428001, 114.776001, 109.124001, 103.472000, 69.268005, 66.279999, 63.292000, 60.304001, 128.923996, 122.839996, 116.755997, 110.671997, 133.171997, 126.872002, 120.571991, 114.271996, 94.565002, 92.342010, 90.118996, 87.896004, 182.488007, 177.988007, 173.488007, 168.988007, 186.772003, 182.164001, 177.556000, 172.947998, 178.095993, 173.488007, 168.880005, 164.272003, 341.828003, 332.504028, 323.180023, 313.856018, 349.532013, 339.992004, 330.451996, 320.911987, 190.299988, 185.368011, 180.436005, 175.503998, 364.940002, 354.967987, 344.996002, 335.024017, 372.644012, 362.455994, 352.268005, 342.080017, 132.303009, 128.604004, 124.904999, 121.206001, 252.536987, 245.057999, 237.578979, 230.100006, 257.666992, 250.026001, 242.385010, 234.744019, 243.195007, 235.554001, 227.912994, 220.272003, 460.631958, 445.188019, 429.744019, 414.299988, 469.595947, 453.827972, 438.059998, 422.291992, 257.613007, 249.486008, 241.358994, 233.232010, 487.523987, 471.108032, 454.691986, 438.276001, 496.488037, 479.748016, 463.007996, 446.268005, 156.846008, 152.417999, 147.989990, 143.561996, 298.707001, 289.769989, 280.833008, 271.895996, 303.837006, 294.737976, 285.638977, 276.540009, 286.449005, 277.350006, 268.250977, 259.151978, 541.307983, 522.947998, 504.587982, 486.227997, 550.271973, 531.588013, 512.903992, 494.220032, 300.867004, 291.281982, 281.696991, 272.112000, 568.200012, 548.868042, 529.535950, 510.204010, 577.164062, 557.507935, 537.851990, 518.196045, 83.944992, 80.750000, 77.555000, 74.360001, 156.496002, 150.052002, 143.608002, 137.164001, 159.052002, 152.500000, 145.947998, 139.395996, 146.488007, 139.936005, 133.384003, 126.832001, 269.107971, 255.895996, 242.684006, 229.471985, 273.356018, 259.927979, 246.500000, 233.071991, 153.507996, 146.632004, 139.755997, 132.880005, 281.851990, 267.992004, 254.132004, 240.272003, 286.100006, 272.023987, 257.947998, 243.872009}, sd::DataType::FLOAT32); + sd::ops::conv1d op; + auto results = op.evaluate({&input, &weights, &bias}, + {kW, sW, pW, dW, paddingMode, dataFormat}); + auto output = results.at(0); - NDArray expGradW('c', {oC, kD, kH, kW, iC}, {396.899872, 429.570007, 462.240234, 494.910156, 313.739960, 335.250000, 356.760071, 378.270020, 403.379944, 424.350006, - 445.320007, 466.289978, 299.520020, 313.319977, 327.119995, 340.920013, 1556.280029, 1594.979980, 1633.679932, 1672.379883, 1090.080078, 1115.520020, 1140.959961, - 1166.400024, 1183.679932, 1208.400024, 1233.119995, 1257.840088, 821.279907, 837.519897, 853.760010, 870.000000, 1500.119873, 1525.500122, 1550.880005, 1576.260010, - 1029.780029, 1046.429932, 1063.080078, 1079.729980, 1080.539917, 1096.650024, 1112.760010, 1128.869995, 738.000000, 748.560059, 759.119995, 769.679993, 389.880005, - 422.819946, 455.759979, 488.699951, 309.420013, 331.109985, 352.799988, 374.490051, 399.780029, 420.930023, 442.080017, 463.230011, 297.359985, 311.280029, 325.200012, 339.120056, 1553.400146, 1592.459961, 1631.520020, 1670.579956, 1088.640015, 1114.320068, 1140.000000, 1165.679932, 1183.199951, 1208.160034, 1233.119995, 1258.079956, 821.280029, 837.680054, 854.079956, 870.479980, 1502.819946, 1528.469971, 1554.119995, 1579.770020, 1031.939941, 1048.770020, 1065.599976, 1082.429932, 1083.420044, 1099.709961, 1116.000000, 1132.290039, 740.159973, 750.840027, 761.519958, 772.199951, 382.859924, 416.070099, 449.279968, 482.489990, 305.099976, 326.970062, 348.840027, 370.709991, 396.179962, 417.510010, 438.839966, 460.169952, 295.200012, 309.239990, 323.279968, 337.320007, 1550.519775, 1589.939941, 1629.359985, 1668.779907, 1087.200073, 1113.119995, 1139.039917, 1164.959961, 1182.719971, 1207.920044, 1233.119995, 1258.320190, 821.279968, 837.840027, 854.400024, 870.959961, 1505.520142, 1531.439819, 1557.359985, 1583.279907, 1034.100098, 1051.110107, 1068.120117, 1085.130005, 1086.299927, 1102.770020, 1119.239990, 1135.710083, 742.319946, 753.119995, 763.919983, 774.720032}, sd::DataType::FLOAT32); + ASSERT_EQ(Status::OK(), results.status()); - NDArray expGradB('c', {oC}, {77.400002, 78.119995, 78.840004}, sd::DataType::FLOAT32); + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests1, conv1d_causal_2) { + int bS = 2, iW = 16, iC = 3, oC = 4, kW = 2, sW = 2, pW = 0, dW = 1; + int oW = (iW - 1) / sW + 1; + int paddingMode = 2; // CAUSAL + int dataFormat = 1; // 1-NHWC, 0-NCHW + + NDArray input('c', {bS, iW, iC}); + NDArray weights('c', {kW, iC, oC}); + NDArray bias('c', {oC}, {-1, -2, -3, -4}); + + NDArray expOutput( + 'c', {bS, oW, oC}, + {10., 9.6, 9.2, 8.8, 48.9, 51.8, 54.7, 57.6, 88.5, 95., + 101.5, 108., 128.1, 138.2, 148.3, 158.4, 167.7, 181.4, 195.1, 208.8, + 207.3, 224.6, 241.9, 259.2, 246.9, 267.8, 288.7, 309.6, 286.5, 311., + 335.5, 360., 254.8, 268.8, 282.8, 296.8, 365.7, 397.4, 429.1, 460.8, + 405.3, 440.6, 475.9, 511.2, 444.9, 483.8, 522.7, 561.6, 484.5, 527., + 569.5, 612., 524.1, 570.2, 616.3, 662.4, 563.7, 613.4, 663.1, 712.8, + 603.3, 656.6, 709.9, 763.2}); + + input.linspace(1., 1.); + weights.linspace(0.1, 0.1); + + sd::ops::conv1d op; + auto results = op.evaluate({&input, &weights, &bias}, + {kW, sW, pW, dW, paddingMode, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); +} - input.linspace(-75, 0.5); - gradO.linspace(0.01, 0.01); +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests1, conv1d_causal_3) { + int bS = 2, iW = 16, iC = 3, oC = 4, kW = 3, sW = 3, pW = 0, dW = 1; + int oW = (iW - 1) / sW + 1; + int paddingMode = 2; // CAUSAL + int dataFormat = 1; // 1-NHWC, 0-NCHW + + NDArray input('c', {bS, iW, iC}); + NDArray weights('c', {kW, iC, oC}); + NDArray bias('c', {oC}, {-1, -2, -3, -4}); + + NDArray expOutput( + 'c', {bS, oW, oC}, + {17.2, 16.8, 16.4, 16., 145.4, 151.6, 157.8, 164., + 283.1, 297.4, 311.7, 326., 420.8, 443.2, 465.6, 488., + 558.5, 589., 619.5, 650., 696.2001, 734.8, 773.4, 812., + 434.8, 448.8, 462.8, 476.8, 879.8, 929.2, 978.6, 1028., + 1017.5, 1075., 1132.5, 1190., 1155.2001, 1220.8, 1286.4, 1352., + 1292.8999, 1366.6, 1440.3, 1514., 1430.6001, 1512.4, 1594.2, 1676.}); + + input.linspace(1., 1.); + weights.linspace(0.1, 0.1); + + sd::ops::conv1d op; + auto results = op.evaluate({&input, &weights, &bias}, + {kW, sW, pW, dW, paddingMode, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); +} - sd::ops::conv3dnew_bp op; - auto results = op.evaluate({&input, &weights, &bias, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat, wFormat}); - auto gradI = results.at(0); - auto gradW = results.at(1); - auto gradB = results.at(2); +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests1, conv1d_causal_4) { + int bS = 2, iW = 8, iC = 3, oC = 4, kW = 3, sW = 1, pW = 0, dW = 3; + int oW = (iW - 1) / sW + 1; + int paddingMode = 2; // CAUSAL + int dataFormat = 1; // 1-NHWC, 0-NCHW + + NDArray input('c', {bS, iW, iC}); + NDArray weights('c', {kW, iC, oC}); + NDArray bias('c', {oC}, {-1, -2, -3, -4}); + + NDArray expOutput( + 'c', {bS, oW, oC}, + {17.2, 16.8, 16.4, 16., 43.3, 43.8, 44.3, 44.8, 69.4, 70.8, + 72.2, 73.6, 106.5, 109.4, 112.3, 115.2, 147.9, 152.6, 157.3, 162., + 189.3, 195.8, 202.3, 208.8, 234.5, 243.4, 252.3, 261.2, 280.4, 292., + 303.6, 315.2, 226., 232.8, 239.6, 246.4, 252.1, 259.8, 267.5, 275.2, + 278.2, 286.8, 295.4, 304., 437.7, 455., 472.3, 489.6, 479.1, 498.2, + 517.3, 536.4, 520.5, 541.4, 562.3, 583.2, 601.7, 632.2, 662.7, 693.2, + 647.6, 680.8, 714., 747.2}); + + input.linspace(1., 1.); + weights.linspace(0.1, 0.1); + + sd::ops::conv1d op; + auto results = op.evaluate({&input, &weights, &bias}, + {kW, sW, pW, dW, paddingMode, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); +} - ASSERT_EQ(Status::OK(), results.status()); +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests1, conv1d_causal_5) { + int bS = 2, iW = 8, iC = 3, oC = 4, kW = 3, sW = 1, pW = 0, dW = 3; + int oW = (iW - 1) / sW + 1; + int paddingMode = 2; // CAUSAL + int dataFormat = 0; // 1-NHWC, 0-NCHW + + NDArray input('c', {bS, iC, iW}); + NDArray weights('c', {kW, iC, oC}); + NDArray bias('c', {oC}, {-1, -2, -3, -4}); + + NDArray expOutput( + 'c', {bS, oC, oW}, + {83.7, 92.4, 101.1, 162.1, 175.9, 189.7, 223.4, 238.7, 85.4, 94.4, + 103.4, 167.4, 181.8, 196.2, 233.2, 249.4, 87.1, 96.4, 105.7, 172.7, + 187.7, 202.7, 243., 260.1, 88.8, 98.4, 108., 178., 193.6, 209.2, + 252.8, 270.8, 292.5, 301.2, 309.9, 493.3, 507.1, 520.9, 590.6, 605.9, + 301.4, 310.4, 319.4, 513., 527.4, 541.8, 622., 638.2, 310.3, 319.6, + 328.9, 532.7, 547.7, 562.7, 653.4, 670.5, 319.2, 328.8, 338.4, 552.4, + 568., 583.6, 684.8, 702.8}); + + input.linspace(1., 1.); + weights.linspace(0.1, 0.1); + + sd::ops::conv1d op; + auto results = op.evaluate({&input, &weights, &bias}, + {kW, sW, pW, dW, paddingMode, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); +} - ASSERT_TRUE(expGradI.isSameShape(gradI)); - ASSERT_TRUE(expGradI.equalsTo(gradI)); +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests1, conv1d_causal_6) { + int bS = 2, iW = 16, iC = 3, oC = 4, kW = 3, sW = 3, pW = 0, dW = 1; + int oW = (iW - 1) / sW + 1; + int paddingMode = 2; // CAUSAL + int dataFormat = 0; // 1-NHWC, 0-NCHW + + NDArray input('c', {bS, iC, iW}); + NDArray weights('c', {kW, iC, oC}); + NDArray bias('c', {oC}, {-1, -2, -3, -4}); + + NDArray expOutput( + 'c', {bS, oC, oW}, + {159.7, 335.3, 381.2, 427.1, 473., 518.9, 163.8, 351.4, + 400., 448.6, 497.2, 545.8, 167.9, 367.5, 418.8, 470.1, + 521.4, 572.7, 172., 383.6, 437.6, 491.6, 545.6, 599.6, + 577.3, 1069.7, 1115.6, 1161.5, 1207.4, 1253.3, 595.8, 1129., + 1177.6, 1226.2, 1274.8, 1323.4, 614.3, 1188.3, 1239.6, 1290.9, + 1342.2, 1393.5, 632.8, 1247.6, 1301.6, 1355.6, 1409.6, 1463.6}); + + input.linspace(1., 1.); + weights.linspace(0.1, 0.1); + + sd::ops::conv1d op; + auto results = op.evaluate({&input, &weights, &bias}, + {kW, sW, pW, dW, paddingMode, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); +} - ASSERT_TRUE(expGradW.isSameShape(gradW)); - ASSERT_TRUE(expGradW.equalsTo(gradW)); +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests1, conv1d_causal_7) { + int bS = 2, iW = 8, iC = 3, oC = 4, kW = 2, sW = 1, pW = 0, dW = 1; + int oW = (iW - 1) / sW + 1; + int paddingMode = 2; // CAUSAL + int dataFormat = 1; // 1-NHWC, 0-NCHW + + NDArray input('c', {bS, iW, iC}, sd::DataType::FLOAT32); + NDArray weights('c', {kW, iC, oC}, sd::DataType::FLOAT32); + + NDArray expOutput( + 'c', {bS, oW, oC}, + {11.000000, 11.600000, 12.200000, 12.800000, 30.099998, 32.200001, + 34.299999, 36.400002, 49.899998, 53.800003, 57.699997, 61.599998, + 69.699997, 75.400002, 81.099998, 86.800003, 89.500000, 97.000000, + 104.500000, 112.000000, 109.300003, 118.600006, 127.899994, 137.199997, + 129.100006, 140.199997, 151.300003, 162.399994, 148.899994, 161.800003, + 174.699997, 187.600006, 133.399994, 141.200012, 149.000000, 156.800003, + 188.500000, 205.000000, 221.500000, 238.000000, 208.299988, 226.600006, + 244.899994, 263.200012, 228.100006, 248.200012, 268.299988, 288.399994, + 247.899994, 269.799988, 291.700012, 313.600006, 267.700012, 291.399994, + 315.100006, 338.799988, 287.500000, 313.000000, 338.500000, 364.000000, + 307.299988, 334.600006, 361.899994, 389.200012}, + sd::DataType::FLOAT32); + + input.linspace(1., 1.); + weights.linspace(0.1, 0.1); + + sd::ops::conv1d op; + auto results = op.evaluate({&input, &weights}, + {kW, sW, pW, dW, paddingMode, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); +} - ASSERT_TRUE(expGradB.isSameShape(gradB)); - ASSERT_TRUE(expGradB.equalsTo(gradB)); +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests1, conv1d_causal_8) { + int bS = 2, iW = 8, iC = 3, oC = 4, kW = 2, sW = 1, pW = 0, dW = 2; + int oW = (iW - 1) / sW + 1; + int paddingMode = 2; // CAUSAL + int dataFormat = 1; // 1-NHWC, 0-NCHW + + NDArray input('c', {bS, iW, iC}, sd::DataType::FLOAT32); + NDArray weights('c', {kW, iC, oC}, sd::DataType::FLOAT32); + + NDArray expOutput( + 'c', {bS, oW, oC}, + {11.000000, 11.600000, 12.200000, 12.800000, 26.299999, 27.799999, + 29.299999, 30.799999, 45.399998, 48.399998, 51.400002, 54.400005, + 65.199997, 70.000000, 74.800003, 79.600006, 85.000000, 91.600006, + 98.199997, 104.800003, 104.799995, 113.199997, 121.600006, 130.000000, + 124.599998, 134.800003, 145.000000, 155.200012, 144.399994, 156.399994, + 168.399994, 180.400009, 133.400009, 141.199997, 149.000000, 156.800003, + 148.699997, 157.400009, 166.099991, 174.800003, 203.800003, 221.200012, + 238.599991, 256.000000, 223.599991, 242.799988, 262.000000, 281.200012, + 243.399994, 264.399994, 285.399994, 306.399994, 263.199982, 286.000000, + 308.799988, 331.600006, 283.000000, 307.600006, 332.200012, 356.800018, + 302.799988, 329.199982, 355.600006, 382.000000}, + sd::DataType::FLOAT32); + + input.linspace(1., 1.); + weights.linspace(0.1, 0.1); + + sd::ops::conv1d op; + auto results = op.evaluate({&input, &weights}, + {kW, sW, pW, dW, paddingMode, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// -TYPED_TEST(TypedConvolutionTests1, conv3d_test1) { +TEST_F(ConvolutionTests1, conv1d_causal_bp_1) { + int bS = 2, iW = 3, iC = 4, oC = 3, kW = 2, sW = 1, pW = 0, dW = 1; + int oW = (iW - 1) / sW + 1; + int paddingMode = 2; // CAUSAL + int dataFormat = 1; // 1-NHWC, 0-NCHW - int bS=2, iD=3,iH=4,iW=3, iC=4,oC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; - int paddingMode = 1; // 1-SAME, 0-VALID; - int dataFormat = 1; // 1-NDHWC, 0-NCDHW + NDArray input('c', {bS, iW, iC}); + NDArray weights('c', {kW, iC, oC}); + NDArray bias('c', {oC}, {-1, -2, -3}); + NDArray gradO('c', {bS, oW, oC}); - auto input = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); - auto weights = NDArrayFactory::create('c', {kD, kH, kW, iC, oC}); - auto expected = NDArrayFactory::create('c', {2, 3, 4, 3, 3}, {534.4f, 540.8f, 547.2f, 534.4f, 540.8f, 547.2f, 248.f, 251.2f, 254.4f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f, - 380.8f, 387.2f, 393.6f, 380.8f, 387.2f, 393.6f, 171.2f, 174.4f, 177.6f, 534.4f, 540.8f, 547.2f, 534.4f, 540.8f, 547.2f, 248.f, 251.2f, 254.4f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f, - 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f, 380.8f, 387.2f, 393.6f, 380.8f, 387.2f, 393.6f, 171.2f, 174.4f, 177.6f, 152.f, 155.2f, 158.4f, 152.f, 155.2f, 158.4f, 66.4f, 68.f, 69.6f, - 170.4f, 175.2f, 180.f, 170.4f, 175.2f, 180.f, 70.8f, 73.2f, 75.6f, 170.4f, 175.2f, 180.f, 170.4f, 175.2f, 180.f, 70.8f, 73.2f, 75.6f, 75.2f, 78.4f, 81.6f, 75.2f, 78.4f, 81.6f, 28.f, 29.6f, 31.2f, - 534.4f, 540.8f, 547.2f, 534.4f, 540.8f, 547.2f, 248.f, 251.2f, 254.4f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f, - 380.8f, 387.2f, 393.6f, 380.8f, 387.2f, 393.6f, 171.2f, 174.4f, 177.6f, 534.4f, 540.8f, 547.2f, 534.4f, 540.8f, 547.2f, 248.f, 251.2f, 254.4f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f, - 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f, 380.8f, 387.2f, 393.6f, 380.8f, 387.2f, 393.6f, 171.2f, 174.4f, 177.6f, 152.f, 155.2f, 158.4f, 152.f, 155.2f, 158.4f, 66.4f, 68.f, 69.6f, - 170.4f, 175.2f, 180.f, 170.4f, 175.2f, 180.f, 70.8f, 73.2f, 75.6f, 170.4f, 175.2f, 180.f, 170.4f, 175.2f, 180.f, 70.8f, 73.2f, 75.6f, 75.2f, 78.4f, 81.6f, 75.2f, 78.4f, 81.6f, 28.f, 29.6f, 31.2f}); - input = 2.; - weights.linspace(0.1, 0.1); + input.linspace(1., 1.); + weights.linspace(0.1, 0.1); + gradO.linspace(-1.5, 0.1); - sd::ops::conv3dnew op; - auto results = op.evaluate({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}); - auto* output = results.at(0); + const OpArgsHolder argsHolderFF({&input, &weights, &bias}, {}, + {kW, sW, pW, dW, paddingMode, dataFormat}); + const OpArgsHolder argsHolderBP({&input, &weights, &bias, &gradO}, {}, + {kW, sW, pW, dW, paddingMode, dataFormat}); - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + sd::ops::conv1d opFF; + sd::ops::conv1d_bp opBP; + const bool isGradCorrect = + GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); + + ASSERT_TRUE(isGradCorrect); } -////////////////////////////////////////////////////////////////////// -TYPED_TEST(TypedConvolutionTests1, conv3d_test2) { +TEST_F(ConvolutionTests1, Test_Dilation2D_1) { + auto input = NDArrayFactory::create('c', {2, 6, 6, 3}); + auto weights = NDArrayFactory::create('c', {3, 2, 3}); + auto exp = NDArrayFactory::create( + 'c', {2, 3, 3, 3}, + { + 77, 79, 81, 83, 85, 87, 80, 82, 84, 113, 115, 117, 119, 121, + 123, 116, 118, 120, 107, 109, 111, 113, 115, 117, 110, 112, 114, 185, + 187, 189, 191, 193, 195, 188, 190, 192, 221, 223, 225, 227, 229, 231, + 224, 226, 228, 215, 217, 219, 221, 223, 225, 218, 220, 222, + }); + + input.linspace(1); + weights.linspace(1); + + sd::ops::dilation2d op; + auto result = op.evaluate({&input, &weights}, {1, 1, 2, 2, 1, 1, 2, 2, 1}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); +} - int bS=2, iD=3,iH=4,iW=3, iC=4,oC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; - int paddingMode = 0; // 1-SAME, 0-VALID; - int dataFormat = 1; // 1-NDHWC, 0-NCDHW +TEST_F(ConvolutionTests1, Test_Dilation2D_2) { + auto input = NDArrayFactory::create('c', {2, 6, 6, 3}); + auto weights = NDArrayFactory::create('c', {3, 2, 3}); + auto exp = NDArrayFactory::create( + 'c', {2, 1, 2, 3}, + {95, 97, 99, 101, 103, 105, 203, 205, 207, 209, 211, 213}); - auto input = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); - auto weights = NDArrayFactory::create('c', {kD, kH, kW, iC, oC}); - auto expected = NDArrayFactory::create('c', {2, 2, 2, 2, 3}, {686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, - 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, - 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, - 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f}); - input = 2.; - weights.linspace(0.1, 0.1); + input.linspace(1); + weights.linspace(1); - sd::ops::conv3dnew op; - auto results = op.evaluate({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}); - auto* output = results.at(0); + sd::ops::dilation2d op; + auto result = op.evaluate({&input, &weights}, {0, 1, 2, 2, 1, 1, 2, 2, 1}); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } +////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedConvolutionTests1, conv2d_bp_test1) { + int bS = 2, iH = 4, iW = 3, iC = 4, oC = 3, kH = 3, kW = 2, sH = 1, sW = 1, + pH = 0, pW = 0, dH = 1, dW = 1; + int oH = 4, oW = 3; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NHWC, 0-NCHW + + auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); + auto weights = NDArrayFactory::create('c', {kH, kW, iC, oC}); + auto bias = NDArrayFactory::create('c', {oC}, {1, 2, 3}); + auto gradO = NDArrayFactory::create('c', {bS, oH, oW, oC}); + + auto expGradI = NDArrayFactory::create( + 'c', {bS, iH, iW, iC}, + {0.226f, 0.343f, 0.46f, 0.577f, 1.172f, 1.46f, 1.748f, 2.036f, + 1.892f, 2.288f, 2.684f, 3.08f, 1.284f, 1.581f, 1.878f, 2.175f, + 4.458f, 5.133f, 5.808f, 6.483f, 6.186f, 7.023f, 7.86f, 8.697f, + 3.39f, 3.93f, 4.47f, 5.01f, 9.642f, 10.803f, 11.964f, 13.125f, + 11.37f, 12.693f, 14.016f, 15.339f, 5.266f, 5.707f, 6.148f, 6.589f, + 12.98f, 13.916f, 14.852f, 15.788f, 14.564f, 15.608f, 16.652f, 17.696f, + 3.25f, 4.015f, 4.78f, 5.545f, 9.812f, 11.396f, 12.98f, 14.564f, + 10.532f, 12.224f, 13.916f, 15.608f, 9.708f, 10.977f, 12.246f, 13.515f, + 25.194f, 27.813f, 30.432f, 33.051f, 26.922f, 29.703f, 32.484f, 35.265f, + 11.814f, 13.326f, 14.838f, 16.35f, 30.378f, 33.483f, 36.588f, 39.693f, + 32.106f, 35.373f, 38.64f, 41.907f, 13.474f, 14.563f, 15.652f, 16.741f, + 31.988f, 34.22f, 36.452f, 38.684f, 33.572f, 35.912f, 38.252f, 40.592f}); + + auto expGradW = NDArrayFactory::create( + 'c', {kH, kW, iC, oC}, + {14.4f, 14.76f, 15.12f, 14.4f, 14.76f, 15.12f, 14.4f, 14.76f, 15.12f, + 14.4f, 14.76f, 15.12f, 9.24f, 9.48f, 9.72f, 9.24f, 9.48f, 9.72f, + 9.24f, 9.48f, 9.72f, 9.24f, 9.48f, 9.72f, 17.04f, 17.52f, 18.f, + 17.04f, 17.52f, 18.f, 17.04f, 17.52f, 18.f, 17.04f, 17.52f, 18.f, + 10.88f, 11.2f, 11.52f, 10.88f, 11.2f, 11.52f, 10.88f, 11.2f, 11.52f, + 10.88f, 11.2f, 11.52f, 11.16f, 11.52f, 11.88f, 11.16f, 11.52f, 11.88f, + 11.16f, 11.52f, 11.88f, 11.16f, 11.52f, 11.88f, 7.08f, 7.32f, 7.56f, + 7.08f, 7.32f, 7.56f, 7.08f, 7.32f, 7.56f, 7.08f, 7.32f, 7.56f}); + // auto expGradB('c', {oC},{}); + + input = 2.; + weights.linspace(0.1, 0.1); + gradO.linspace(0.01, 0.01); + + sd::ops::conv2d_bp op; + auto results = + op.evaluate({&input, &weights, &bias, &gradO}, {}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, dataFormat}); + auto gradI = results.at(0); + auto gradW = results.at(1); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); + + ASSERT_TRUE(expGradW.isSameShape(gradW)); + ASSERT_TRUE(expGradW.equalsTo(gradW)); +} ////////////////////////////////////////////////////////////////////// -TYPED_TEST(TypedConvolutionTests1, conv3d_test3) { +TYPED_TEST(TypedConvolutionTests1, conv2d_bp_test2) { + int bS = 2, iH = 4, iW = 3, iC = 4, oC = 3, kH = 3, kW = 2, sH = 1, sW = 1, + pH = 0, pW = 0, dH = 1, dW = 1; + int oH = 2, oW = 2; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NHWC, 0-NCHW + + auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); + auto weights = NDArrayFactory::create('c', {kH, kW, iC, oC}); + auto bias = NDArrayFactory::create('c', {oC}, {1, 2, 3}); + auto gradO = NDArrayFactory::create('c', {bS, oH, oW, oC}); + + auto expGradI = NDArrayFactory::create( + 'c', {bS, iH, iW, iC}, + {0.014f, 0.032f, 0.05f, 0.068f, 0.118f, 0.181f, 0.244f, 0.307f, 0.212f, + 0.257f, 0.302f, 0.347f, 0.208f, 0.298f, 0.388f, 0.478f, 1.028f, 1.262f, + 1.496f, 1.73f, 1.036f, 1.18f, 1.324f, 1.468f, 0.928f, 1.018f, 1.108f, + 1.198f, 2.9f, 3.134f, 3.368f, 3.602f, 2.188f, 2.332f, 2.476f, 2.62f, + 1.202f, 1.274f, 1.346f, 1.418f, 3.142f, 3.313f, 3.484f, 3.655f, 2.048f, + 2.147f, 2.246f, 2.345f, 0.086f, 0.212f, 0.338f, 0.464f, 0.694f, 0.973f, + 1.252f, 1.531f, 0.716f, 0.869f, 1.022f, 1.175f, 1.216f, 1.522f, 1.828f, + 2.134f, 3.908f, 4.574f, 5.24f, 5.906f, 2.908f, 3.268f, 3.628f, 3.988f, + 3.664f, 3.97f, 4.276f, 4.582f, 9.236f, 9.902f, 10.568f, 11.234f, 5.788f, + 6.148f, 6.508f, 6.868f, 3.002f, 3.182f, 3.362f, 3.542f, 7.174f, 7.561f, + 7.948f, 8.335f, 4.28f, 4.487f, 4.694f, 4.901f}); + + auto expGradW = NDArrayFactory::create( + 'c', {kH, kW, iC, oC}, + {1.84f, 2.f, 2.16f, 1.84f, 2.f, 2.16f, 1.84f, 2.f, 2.16f, + 1.84f, 2.f, 2.16f, 1.84f, 2.f, 2.16f, 1.84f, 2.f, 2.16f, + 1.84f, 2.f, 2.16f, 1.84f, 2.f, 2.16f, 1.84f, 2.f, 2.16f, + 1.84f, 2.f, 2.16f, 1.84f, 2.f, 2.16f, 1.84f, 2.f, 2.16f, + 1.84f, 2.f, 2.16f, 1.84f, 2.f, 2.16f, 1.84f, 2.f, 2.16f, + 1.84f, 2.f, 2.16f, 1.84f, 2.f, 2.16f, 1.84f, 2.f, 2.16f, + 1.84f, 2.f, 2.16f, 1.84f, 2.f, 2.16f, 1.84f, 2.f, 2.16f, + 1.84f, 2.f, 2.16f, 1.84f, 2.f, 2.16f, 1.84f, 2.f, 2.16f}); + // auto expGradB('c', {oC},{}); + + input = 2.; + weights.linspace(0.1, 0.1); + gradO.linspace(0.01, 0.01); + + sd::ops::conv2d_bp op; + auto results = + op.evaluate({&input, &weights, &bias, &gradO}, {}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, dataFormat}); + auto gradI = results.at(0); + auto gradW = results.at(1); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); + + ASSERT_TRUE(expGradW.isSameShape(gradW)); + ASSERT_TRUE(expGradW.equalsTo(gradW)); +} - int bS=2, iD=3,iH=4,iW=3, iC=4,oC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; - int paddingMode = 0; // 1-SAME, 0-VALID; - int dataFormat = 0; // 1-NDHWC, 0-NCDHW +////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedConvolutionTests1, conv2d_bp_test3) { + int bS = 2, iH = 4, iW = 3, iC = 4, oC = 3, kH = 3, kW = 2, sH = 1, sW = 1, + pH = 0, pW = 0, dH = 1, dW = 1; + int oH = 2, oW = 2; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NHWC, 0-NCHW + + auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); + auto weights = NDArrayFactory::create('c', {oC, iC, kH, kW}); + auto bias = NDArrayFactory::create('c', {oC}, {1, 2, 3}); + auto gradO = NDArrayFactory::create('c', {bS, oC, oH, oW}); + + auto expGradI = NDArrayFactory::create( + 'c', {bS, iC, iH, iW}, + {0.567f, 1.224f, 0.66f, 1.314f, 2.82f, 1.512f, 1.386f, 2.976f, 1.596f, + 0.801f, 1.71f, 0.912f, 0.657f, 1.422f, 0.768f, 1.53f, 3.288f, 1.764f, + 1.602f, 3.444f, 1.848f, 0.927f, 1.98f, 1.056f, 0.747f, 1.62f, 0.876f, + 1.746f, 3.756f, 2.016f, 1.818f, 3.912f, 2.1f, 1.053f, 2.25f, 1.2f, + 0.837f, 1.818f, 0.984f, 1.962f, 4.224f, 2.268f, 2.034f, 4.38f, 2.352f, + 1.179f, 2.52f, 1.344f, 1.467f, 3.06f, 1.596f, 3.186f, 6.636f, 3.456f, + 3.402f, 7.08f, 3.684f, 1.845f, 3.834f, 1.992f, 1.773f, 3.69f, 1.92f, + 3.834f, 7.968f, 4.14f, 4.05f, 8.412f, 4.368f, 2.187f, 4.536f, 2.352f, + 2.079f, 4.32f, 2.244f, 4.482f, 9.3f, 4.824f, 4.698f, 9.744f, 5.052f, + 2.529f, 5.238f, 2.712f, 2.385f, 4.95f, 2.568f, 5.13f, 10.632f, 5.508f, + 5.346f, 11.076f, 5.736f, 2.871f, 5.94f, 3.072f}); + + auto expGradW = NDArrayFactory::create( + 'c', {oC, iC, kH, kW}, + {1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, + 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, + 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, + 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, + 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 2.0000e+00f, + 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, + 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, + 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, + 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, + 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.6400e+00f, 2.6400e+00f, + 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, + 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, + 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, + 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, + 2.6400e+00f, 2.6400e+00f}); + auto expGradB = + NDArrayFactory::create('c', {oC}, {0.68f, 1.f, 1.32f}); + + input = 2.; + weights.linspace(0.1, 0.1); + gradO.linspace(0.01, 0.01); + weights.permutei({2, 3, 1, 0}); + expGradW.permutei({2, 3, 1, 0}); + + sd::ops::conv2d_bp op; + auto results = + op.evaluate({&input, &weights, &bias, &gradO}, {}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, dataFormat}); + auto gradI = results.at(0); + auto gradW = results.at(1); + auto gradB = results.at(2); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); + + ASSERT_TRUE(expGradW.isSameShape(gradW)); + ASSERT_TRUE(expGradW.equalsTo(gradW)); + + ASSERT_TRUE(expGradB.isSameShape(gradB)); + ASSERT_TRUE(expGradB.equalsTo(gradB)); +} - auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); - auto weights = NDArrayFactory::create('c', {kD, kH, kW, iC, oC}); - auto expected = NDArrayFactory::create('c', {2, 3, 2, 2, 2}); - input = 2.; - weights = 0.5; - expected = 48.; +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests1, conv2d_bp_4) { + int bS = 1, iH = 7, iW = 1, iC = 2, oC = 3, kH = 2, kW = 1, sH = 1, sW = 1, + pH = 0, pW = 0, dH = 1, dW = 1; + int oH = 7, oW = 1; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NHWC, 0-NCHW + + NDArray input('c', {bS, iC, iH, iW}, sd::DataType::FLOAT32); + NDArray weights('c', {kH, kW, iC, oC}, sd::DataType::FLOAT32); + NDArray bias('c', {oC}, {1, 2, 3}, sd::DataType::FLOAT32); + NDArray gradO('c', {bS, oC, oH, oW}, sd::DataType::FLOAT32); + + NDArray gradI('c', {bS, iC, iH, iW}, sd::DataType::FLOAT32); + NDArray gradW('c', {kH, kW, iC, oC}, sd::DataType::FLOAT32); + NDArray gradB('c', {oC}, sd::DataType::FLOAT32); + + input = 2.; + weights.linspace(0.1, 0.1); + gradO.linspace(0.01, 0.01); + + sd::ops::conv2d_bp op; + auto status = op.execute( + {&input, &weights, &bias, &gradO}, {&gradI, &gradW, &gradB}, {}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, dataFormat}, {}); + + ASSERT_EQ(Status::OK(), status); +} - sd::ops::conv3dnew op; - auto results = op.evaluate({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}); - auto* output = results.at(0); +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests1, conv2d_bp_5) { + int bS = 2, iH = 4, iW = 3, iC = 4, oC = 3, kH = 3, kW = 2, sH = 1, sW = 1, + pH = 0, pW = 0, dH = 1, dW = 1; + int oH = 2, oW = 2; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NHWC, 0-NCHW + int wFormat = + 1; // 0-[kH, kW, iC, oC], 1-[oC, iC, kH, kW], 2-[oC, kH, kW, iC] + + NDArray input('c', {bS, iC, iH, iW}, sd::DataType::FLOAT32); + NDArray weights( + 'c', {oC, iC, kH, kW}, + {3.6, 2.4, 1.2, 0.0, -1.2, -2.4, 3.3, 2.1, 0.9, -0.3, -1.5, -2.7, + 3.0, 1.8, 0.6, -0.6, -1.8, -3.0, 2.7, 1.5, 0.3, -0.9, -2.1, -3.3, + 3.5, 2.3, 1.1, -0.1, -1.3, -2.5, 3.2, 2.0, 0.8, -0.4, -1.6, -2.8, + 2.9, 1.7, 0.5, -0.7, -1.9, -3.1, 2.6, 1.4, 0.2, -1.0, -2.2, -3.4, + 3.4, 2.2, 1.0, -0.2, -1.4, -2.6, 3.1, 1.9, 0.7, -0.5, -1.7, -2.9, + 2.8, 1.6, 0.4, -0.8, -2.0, -3.2, 2.5, 1.3, 0.1, -1.1, -2.3, -3.5}, + sd::DataType::FLOAT32); + NDArray bias('c', {oC}, {1, -0.5, 0.1}, sd::DataType::FLOAT32); + NDArray gradO('c', {bS, oC, oH, oW}, sd::DataType::FLOAT32); + + NDArray expGradI( + 'c', {bS, iC, iH, iW}, + {0.517, 0.959, 0.406, 0.884, 1.474, 0.518, 0.020, -0.398, -0.490, + -0.281, -0.853, -0.608, 0.472, 0.860, 0.352, 0.776, 1.240, 0.392, + -0.088, -0.632, -0.616, -0.344, -0.988, -0.680, 0.427, 0.761, 0.298, + 0.668, 1.006, 0.266, -0.196, -0.866, -0.742, -0.407, -1.123, -0.752, + 0.382, 0.662, 0.244, 0.560, 0.772, 0.140, -0.304, -1.100, -0.868, + -0.470, -1.258, -0.824, 1.777, 3.047, 1.234, 2.540, 3.922, 1.310, + -0.052, -1.406, -1.426, -0.749, -2.221, -1.508, 1.624, 2.732, 1.072, + 2.216, 3.256, 0.968, -0.376, -2.072, -1.768, -0.920, -2.572, -1.688, + 1.471, 2.417, 0.910, 1.892, 2.590, 0.626, -0.700, -2.738, -2.110, + -1.091, -2.923, -1.868, 1.318, 2.102, 0.748, 1.568, 1.924, 0.284, + -1.024, -3.404, -2.452, -1.262, -3.274, -2.048}, + sd::DataType::FLOAT32); + + NDArray expGradW( + 'c', {oC, iC, kH, kW}, + {-3.3, -2.62, -1.26, -0.58, 0.78, 1.46, + 4.86, 5.54, 6.9, 7.58, 8.940001, 9.619999, + 13.02, 13.700001, 15.06, 15.74, 17.1, 17.780001, + 21.18, 21.860001, 23.219999, 23.900002, 25.259998, 25.940001, + -10.340001, -9.34, -7.339999, -6.34, -4.339999, -3.339999, + 1.66, 2.66, 4.660001, 5.660001, 7.66, 8.66, + 13.66, 14.660001, 16.66, 17.66, 19.66, 20.66, + 25.66, 26.66, 28.66, 29.66, 31.66, 32.66, + -17.380001, -16.059999, -13.420003, -12.099999, -9.46, -8.139999, + -1.540001, -0.219999, 2.419999, 3.739999, 6.379999, 7.7, + 14.299999, 15.62, 18.26, 19.58, 22.219999, 23.539999, + 30.139999, 31.459999, 34.099998, 35.419998, 38.060001, 39.380001}, + sd::DataType::FLOAT32); + + NDArray expGradB('c', {oC}, {0.68, 1., 1.32}, sd::DataType::FLOAT32); + + input.linspace(-48, 1); + // weights.linspace(3.6, -0.1); + gradO.linspace(0.01, 0.01); + + sd::ops::conv2d_bp op; + auto results = op.evaluate( + {&input, &weights, &bias, &gradO}, {}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, dataFormat, wFormat}); + auto gradI = results.at(0); + auto gradW = results.at(1); + auto gradB = results.at(2); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); + + ASSERT_TRUE(expGradW.isSameShape(gradW)); + ASSERT_TRUE(expGradW.equalsTo(gradW)); + + ASSERT_TRUE(expGradB.isSameShape(gradB)); + ASSERT_TRUE(expGradB.equalsTo(gradB)); +} - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests1, conv2d_bp_6) { + int bS = 2, iH = 4, iW = 3, iC = 4, oC = 3, kH = 3, kW = 2, sH = 1, sW = 1, + pH = 0, pW = 0, dH = 1, dW = 1; + int oH = 4, oW = 3; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NHWC, 0-NCHW + int wFormat = + 2; // 0-[kH, kW, iC, oC], 1-[oC, iC, kH, kW], 2-[oC, kH, kW, iC] + + NDArray input('c', {bS, iH, iW, iC}, sd::DataType::FLOAT32); + NDArray weights( + 'c', {oC, kH, kW, iC}, + {3.6, 0.0, 3.3, -0.3, 3.0, -0.6, 2.7, -0.9, 3.5, -0.1, 3.2, -0.4, + 2.9, -0.7, 2.6, -1.0, 3.4, -0.2, 3.1, -0.5, 2.8, -0.8, 2.5, -1.1, + 2.4, -1.2, 2.1, -1.5, 1.8, -1.8, 1.5, -2.1, 2.3, -1.3, 2.0, -1.6, + 1.7, -1.9, 1.4, -2.2, 2.2, -1.4, 1.9, -1.7, 1.6, -2.0, 1.3, -2.3, + 1.2, -2.4, 0.9, -2.7, 0.6, -3.0, 0.3, -3.3, 1.1, -2.5, 0.8, -2.8, + 0.5, -3.1, 0.2, -3.4, 1.0, -2.6, 0.7, -2.9, 0.4, -3.2, 0.1, -3.5}, + sd::DataType::FLOAT32); + NDArray bias('c', {oC}, {1, -0.5, 0.1}, sd::DataType::FLOAT32); + NDArray gradO('c', {bS, oH, oW, oC}, sd::DataType::FLOAT32); + + NDArray expGradI( + 'c', {bS, iH, iW, iC}, + {0.882, -0.522, 0.765, -0.639, 1.953, -1.503, + 1.665, -1.791, 2.691, -2.061, 2.295, -2.457, + 2.259, -1.305, 1.962, -1.602, 4.545, -3.555, + 3.870, -4.230, 5.625, -4.419, 4.788, -5.256001, + 4.122, -2.358, 3.582, -2.898, 7.785, -6.147, + 6.624, -7.308, 8.865, -7.011, 7.541999, -8.334, + 3.273, -2.019, 2.832, -2.460, 6.069, -5.163, + 5.133, -6.099, 6.771, -5.757, 5.727, -6.801, + 5.958, -3.222, 5.193, -3.987, 10.809, -8.198999, + 9.225, -9.783, 11.547, -8.757, 9.855, -10.448999, + 9.711, -5.517, 8.441999, -6.786, 17.505001, -13.922999, + 14.886, -16.542, 18.585001, -14.787001, 15.804001, -17.568001, + 11.574, -6.570, 10.062, -8.082, 20.745001, -16.514999, + 17.639999, -19.619999, 21.825001, -17.379002, 18.558001, -20.646, + 8.133, -4.935, 7.044, -6.024, 14.492998, -12.291, + 12.261, -14.523001, 15.195001, -12.885, 12.855, -15.225}, + sd::DataType::FLOAT32); + + NDArray expGradW( + 'c', {oC, kH, kW, iC}, + {34.559998, 41.760010, 48.959999, 56.160004, 33.119999, 37.739998, + 42.360001, 46.979996, 120.960007, 129.480011, 138.0, 146.519989, + 91.200005, 96.639999, 102.079994, 107.520004, 114.479996, 120.059998, + 125.639999, 131.220001, 82.080002, 85.620003, 89.160004, 92.699997, + 33.120003, 40.499996, 47.879993, 55.260002, 32.399998, 37.139996, + 41.880001, 46.620003, 120.479988, 129.240005, 137.999985, 146.759995, + 91.199997, 96.799995, 102.399994, 108.0, 115.199989, 120.959999, + 126.720001, 132.479996, 82.799995, 86.460007, 90.119995, 93.779999, + 31.679998, 39.239994, 46.800003, 54.359997, 31.680000, 36.540001, + 41.400002, 46.260002, 120.0, 129.0, 138.0, 147.0, + 91.200005, 96.960007, 102.720001, 108.480003, 115.919998, 121.860001, + 127.799988, 133.740005, 83.520004, 87.300003, 91.080002, 94.860001}, + sd::DataType::FLOAT32); + + NDArray expGradB('c', {oC}, {8.520, 8.760, 9.}, sd::DataType::FLOAT32); + + input.linspace(-48, 1); + gradO.linspace(0.01, 0.01); + + sd::ops::conv2d_bp op; + auto results = op.evaluate( + {&input, &weights, &bias, &gradO}, {}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, dataFormat, wFormat}); + auto gradI = results.at(0); + auto gradW = results.at(1); + auto gradB = results.at(2); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); + + ASSERT_TRUE(expGradW.isSameShape(gradW)); + ASSERT_TRUE(expGradW.equalsTo(gradW)); + + ASSERT_TRUE(expGradB.isSameShape(gradB)); + ASSERT_TRUE(expGradB.equalsTo(gradB)); +} +//////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedConvolutionTests1, conv3d_bp_test1) { + int bS = 2, iD = 3, iH = 4, iW = 3, iC = 4, oC = 3, kD = 2, kH = 3, kW = 2, + sD = 1, sH = 1, sW = 1, pD = 0, pH = 0, pW = 0, dD = 1, dH = 1, dW = 1; + int oD = 3, oH = 4, oW = 3; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); + auto weights = NDArrayFactory::create('c', {kD, kH, kW, iC, oC}); + auto bias = NDArrayFactory::create('c', {oC}, {1, 2, 3}); + auto gradO = NDArrayFactory::create('c', {bS, oD, oH, oW, oC}); + + auto expGradI = NDArrayFactory::create( + 'c', {bS, iD, iH, iW, iC}, + {0.226f, 0.343f, 0.46f, 0.577f, 1.172f, 1.46f, 1.748f, + 2.036f, 1.892f, 2.288f, 2.684f, 3.08f, 1.284f, 1.581f, + 1.878f, 2.175f, 4.458f, 5.133f, 5.808f, 6.483f, 6.186f, + 7.023f, 7.86f, 8.697f, 3.39f, 3.93f, 4.47f, 5.01f, + 9.642f, 10.803f, 11.964f, 13.125f, 11.37f, 12.693f, 14.016f, + 15.339f, 5.266f, 5.707f, 6.148f, 6.589f, 12.98f, 13.916f, + 14.852f, 15.788f, 14.564f, 15.608f, 16.652f, 17.696f, 6.284f, + 7.166f, 8.048f, 8.93f, 17.896f, 19.768f, 21.64f, 23.512f, + 21.928f, 24.016f, 26.104f, 28.192f, 18.12f, 19.686f, 21.252f, + 22.818f, 45.852f, 49.146f, 52.44f, 55.734f, 53.196f, 56.814f, + 60.432f, 64.05f, 28.164f, 30.216f, 32.268f, 34.32f, 67.884f, + 72.15f, 76.416f, 80.682f, 75.228f, 79.818f, 84.408f, 88.998f, + 29.324f, 30.854f, 32.384f, 33.914f, 67.432f, 70.6f, 73.768f, + 76.936f, 73.192f, 76.576f, 79.96f, 83.344f, 27.884f, 30.062f, + 32.24f, 34.418f, 66.28f, 70.744f, 75.208f, 79.672f, 70.312f, + 74.992f, 79.672f, 84.352f, 58.296f, 61.806f, 65.316f, 68.826f, + 133.98f, 141.162f, 148.344f, 155.526f, 141.324f, 148.83f, 156.336f, + 163.842f, 68.34f, 72.336f, 76.332f, 80.328f, 156.012f, 164.166f, + 172.32f, 180.474f, 163.356f, 171.834f, 180.312f, 188.79f, 61.292f, + 64.118f, 66.944f, 69.77f, 136.552f, 142.312f, 148.072f, 153.832f, + 142.312f, 148.288f, 154.264f, 160.24f, 9.298f, 11.359f, 13.42f, + 15.481f, 27.092f, 31.268f, 35.444f, 39.62f, 27.812f, 32.096f, + 36.38f, 40.664f, 26.556f, 29.769f, 32.982f, 36.195f, 66.666f, + 73.173f, 79.68f, 86.187f, 68.394f, 75.063f, 81.732f, 88.401f, + 28.662f, 32.118f, 35.574f, 39.03f, 71.85f, 78.843f, 85.836f, + 92.829f, 73.578f, 80.733f, 87.888f, 95.043f, 29.89f, 32.275f, + 34.66f, 37.045f, 70.004f, 74.828f, 79.652f, 84.476f, 71.588f, + 76.52f, 81.452f, 86.384f, 71.084f, 75.854f, 80.624f, 85.394f, + 163.048f, 172.696f, 182.344f, 191.992f, 167.08f, 176.944f, 186.808f, + 196.672f, 138.648f, 146.046f, 153.444f, 160.842f, 310.236f, 325.194f, + 340.152f, 355.11f, 317.58f, 332.862f, 348.144f, 363.426f, 148.692f, + 156.576f, 164.46f, 172.344f, 332.268f, 348.198f, 364.128f, 380.058f, + 339.612f, 355.866f, 372.12f, 388.374f, 125.228f, 130.646f, 136.064f, + 141.482f, 274.792f, 285.736f, 296.68f, 307.624f, 280.552f, 291.712f, + 302.872f, 314.032f, 92.684f, 98.75f, 104.816f, 110.882f, 211.432f, + 223.672f, 235.912f, 248.152f, 215.464f, 227.92f, 240.376f, 252.832f, + 178.824f, 188.166f, 197.508f, 206.85f, 398.364f, 417.21f, 436.056f, + 454.902f, 405.708f, 424.878f, 444.048f, 463.218f, 188.868f, 198.696f, + 208.524f, 218.352f, 420.396f, 440.214f, 460.032f, 479.85f, 427.74f, + 447.882f, 468.024f, 488.166f, 157.196f, 163.91f, 170.624f, 177.338f, + 343.912f, 357.448f, 370.984f, 384.52f, 349.672f, 363.424f, 377.176f, + 390.928f}); + + auto expGradW = NDArrayFactory::create( + 'c', {kD, kH, kW, iC, oC}, + {120.96f, 122.04f, 123.12f, 120.96f, 122.04f, 123.12f, 120.96f, 122.04f, + 123.12f, 120.96f, 122.04f, 123.12f, 79.56f, 80.28f, 81.f, 79.56f, + 80.28f, 81.f, 79.56f, 80.28f, 81.f, 79.56f, 80.28f, 81.f, + 154.8f, 156.24f, 157.68f, 154.8f, 156.24f, 157.68f, 154.8f, 156.24f, + 157.68f, 154.8f, 156.24f, 157.68f, 101.76f, 102.72f, 103.68f, 101.76f, + 102.72f, 103.68f, 101.76f, 102.72f, 103.68f, 101.76f, 102.72f, 103.68f, + 111.24f, 112.32f, 113.4f, 111.24f, 112.32f, 113.4f, 111.24f, 112.32f, + 113.4f, 111.24f, 112.32f, 113.4f, 73.08f, 73.8f, 74.52f, 73.08f, + 73.8f, 74.52f, 73.08f, 73.8f, 74.52f, 73.08f, 73.8f, 74.52f, + 67.68f, 68.4f, 69.12f, 67.68f, 68.4f, 69.12f, 67.68f, 68.4f, + 69.12f, 67.68f, 68.4f, 69.12f, 44.4f, 44.88f, 45.36f, 44.4f, + 44.88f, 45.36f, 44.4f, 44.88f, 45.36f, 44.4f, 44.88f, 45.36f, + 85.92f, 86.88f, 87.84f, 85.92f, 86.88f, 87.84f, 85.92f, 86.88f, + 87.84f, 85.92f, 86.88f, 87.84f, 56.32f, 56.96f, 57.6f, 56.32f, + 56.96f, 57.6f, 56.32f, 56.96f, 57.6f, 56.32f, 56.96f, 57.6f, + 61.2f, 61.92f, 62.64f, 61.2f, 61.92f, 62.64f, 61.2f, 61.92f, + 62.64f, 61.2f, 61.92f, 62.64f, 40.08f, 40.56f, 41.04f, 40.08f, + 40.56f, 41.04f, 40.08f, 40.56f, 41.04f, 40.08f, 40.56f, 41.04f}); + // auto expGradB('c', {oC},{}); + + input = 2.; + weights.linspace(0.1, 0.1); + gradO.linspace(0.01, 0.01); + + sd::ops::conv3dnew_bp op; + auto results = op.evaluate({&input, &weights, &bias, &gradO}, + {kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, + paddingMode, dataFormat}); + auto gradI = results.at(0); + auto gradW = results.at(1); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); + + ASSERT_TRUE(expGradW.isSameShape(gradW)); + ASSERT_TRUE(expGradW.equalsTo(gradW)); } //////////////////////////////////////////////////////////////////// -TYPED_TEST(TypedConvolutionTests1, conv3d_test4) { +TYPED_TEST(TypedConvolutionTests1, conv3d_bp_test2) { + int bS = 2, iD = 3, iH = 4, iW = 3, iC = 4, oC = 3, kD = 2, kH = 3, kW = 2, + sD = 1, sH = 1, sW = 1, pD = 0, pH = 0, pW = 0, dD = 1, dH = 1, dW = 1; + int oD = 2, oH = 2, oW = 2; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); + auto weights = NDArrayFactory::create('c', {kD, kH, kW, iC, oC}); + auto bias = NDArrayFactory::create('c', {oC}, {1, 2, 3}); + auto gradO = NDArrayFactory::create('c', {bS, oD, oH, oW, oC}); + + auto expGradI = NDArrayFactory::create( + 'c', {bS, iD, iH, iW, iC}, + {0.014f, 0.032f, 0.05f, 0.068f, 0.118f, 0.181f, 0.244f, 0.307f, + 0.212f, 0.257f, 0.302f, 0.347f, 0.208f, 0.298f, 0.388f, 0.478f, + 1.028f, 1.262f, 1.496f, 1.73f, 1.036f, 1.18f, 1.324f, 1.468f, + 0.928f, 1.018f, 1.108f, 1.198f, 2.9f, 3.134f, 3.368f, 3.602f, + 2.188f, 2.332f, 2.476f, 2.62f, 1.202f, 1.274f, 1.346f, 1.418f, + 3.142f, 3.313f, 3.484f, 3.655f, 2.048f, 2.147f, 2.246f, 2.345f, + 0.532f, 0.676f, 0.82f, 0.964f, 2.324f, 2.666f, 3.008f, 3.35f, + 2.008f, 2.206f, 2.404f, 2.602f, 3.584f, 3.98f, 4.376f, 4.772f, + 10.552f, 11.452f, 12.352f, 13.252f, 7.4f, 7.904f, 8.408f, 8.912f, + 6.752f, 7.148f, 7.544f, 7.94f, 17.752f, 18.652f, 19.552f, 20.452f, + 11.432f, 11.936f, 12.44f, 12.944f, 5.932f, 6.184f, 6.436f, 6.688f, + 14.42f, 14.978f, 15.536f, 16.094f, 8.704f, 9.01f, 9.316f, 9.622f, + 3.11f, 3.236f, 3.362f, 3.488f, 7.39f, 7.669f, 7.948f, 8.227f, + 4.388f, 4.541f, 4.694f, 4.847f, 8.56f, 8.866f, 9.172f, 9.478f, + 19.892f, 20.558f, 21.224f, 21.89f, 11.548f, 11.908f, 12.268f, 12.628f, + 11.008f, 11.314f, 11.62f, 11.926f, 25.22f, 25.886f, 26.552f, 27.218f, + 14.428f, 14.788f, 15.148f, 15.508f, 7.322f, 7.502f, 7.682f, 7.862f, + 16.462f, 16.849f, 17.236f, 17.623f, 9.248f, 9.455f, 9.662f, 9.869f, + 0.158f, 0.392f, 0.626f, 0.86f, 1.27f, 1.765f, 2.26f, 2.755f, + 1.22f, 1.481f, 1.742f, 2.003f, 2.224f, 2.746f, 3.268f, 3.79f, + 6.788f, 7.886f, 8.984f, 10.082f, 4.78f, 5.356f, 5.932f, 6.508f, + 6.4f, 6.922f, 7.444f, 7.966f, 15.572f, 16.67f, 17.768f, 18.866f, + 9.388f, 9.964f, 10.54f, 11.116f, 4.802f, 5.09f, 5.378f, 5.666f, + 11.206f, 11.809f, 12.412f, 13.015f, 6.512f, 6.827f, 7.142f, 7.457f, + 6.004f, 6.58f, 7.156f, 7.732f, 14.996f, 16.202f, 17.408f, 18.614f, + 9.208f, 9.838f, 10.468f, 11.098f, 17.984f, 19.244f, 20.504f, 21.764f, + 42.808f, 45.436f, 48.064f, 50.692f, 25.256f, 26.624f, 27.992f, 29.36f, + 28.064f, 29.324f, 30.584f, 31.844f, 63.832f, 66.46f, 69.088f, 71.716f, + 36.2f, 37.568f, 38.936f, 40.304f, 18.316f, 19.f, 19.684f, 20.368f, + 40.916f, 42.338f, 43.76f, 45.182f, 22.816f, 23.554f, 24.292f, 25.03f, + 8.438f, 8.78f, 9.122f, 9.464f, 18.91f, 19.621f, 20.332f, 21.043f, + 10.58f, 10.949f, 11.318f, 11.687f, 20.944f, 21.682f, 22.42f, 23.158f, + 46.388f, 47.918f, 49.448f, 50.978f, 25.66f, 26.452f, 27.244f, 28.036f, + 26.848f, 27.586f, 28.324f, 29.062f, 58.628f, 60.158f, 61.688f, 63.218f, + 31.996f, 32.788f, 33.58f, 34.372f, 16.106f, 16.502f, 16.898f, 17.294f, + 34.894f, 35.713f, 36.532f, 37.351f, 18.896f, 19.319f, 19.742f, 20.165f}); + + auto expGradW = NDArrayFactory::create( + 'c', {kD, kH, kW, iC, oC}, + {7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, + 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, + 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, + 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, + 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, + 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, + 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, + 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, + 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, + 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, + 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, + 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, + 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, + 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, + 8.16f, 7.52f, 7.84f, 8.16f}); + // auto expGradB('c', {oC},{}); + + input = 2.; + weights.linspace(0.1, 0.1); + gradO.linspace(0.01, 0.01); + + sd::ops::conv3dnew_bp op; + auto results = op.evaluate({&input, &weights, &bias, &gradO}, + {kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, + paddingMode, dataFormat}); + auto gradI = results.at(0); + auto gradW = results.at(1); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); + + ASSERT_TRUE(expGradW.isSameShape(gradW)); + ASSERT_TRUE(expGradW.equalsTo(gradW)); +} - int bS=2, iD=3,iH=4,iW=3, iC=4,oC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; - int paddingMode = 0; // 1-SAME, 0-VALID; - int dataFormat = 0; // 1-NDHWC, 0-NCDHW +//////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedConvolutionTests1, conv3d_bp_test3) { + int bS = 2, iD = 3, iH = 4, iW = 3, iC = 4, oC = 3, kD = 2, kH = 3, kW = 2, + sD = 1, sH = 1, sW = 1, pD = 0, pH = 0, pW = 0, dD = 1, dH = 1, dW = 1; + int oD = 2, oH = 2, oW = 2; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); + auto weights = NDArrayFactory::create('c', {oC, iC, kD, kH, kW}); + auto bias = NDArrayFactory::create('c', {oC}, {1, 2, 3}); + auto gradO = NDArrayFactory::create('c', {bS, oC, oD, oH, oW}); + + auto expGradI = NDArrayFactory::create( + 'c', {bS, iC, iD, iH, iW}, + {2.091f, 4.356f, 2.268f, 4.53f, 9.42f, 4.896f, 4.65f, 9.672f, + 5.028f, 2.517f, 5.226f, 2.712f, 4.932f, 10.242f, 5.316f, 10.62f, + 22.02f, 11.412f, 10.908f, 22.62f, 11.724f, 5.868f, 12.15f, 6.288f, + 2.913f, 6.03f, 3.12f, 6.234f, 12.888f, 6.66f, 6.402f, 13.236f, + 6.84f, 3.423f, 7.068f, 3.648f, 2.415f, 5.04f, 2.628f, 5.25f, + 10.932f, 5.688f, 5.37f, 11.184f, 5.82f, 2.913f, 6.054f, 3.144f, + 5.724f, 11.898f, 6.18f, 12.348f, 25.62f, 13.284f, 12.636f, 26.22f, + 13.596f, 6.804f, 14.094f, 7.296f, 3.381f, 7.002f, 3.624f, 7.242f, + 14.976f, 7.74f, 7.41f, 15.324f, 7.92f, 3.963f, 8.184f, 4.224f, + 2.739f, 5.724f, 2.988f, 5.97f, 12.444f, 6.48f, 6.09f, 12.696f, + 6.612f, 3.309f, 6.882f, 3.576f, 6.516f, 13.554f, 7.044f, 14.076f, + 29.22f, 15.156f, 14.364f, 29.82f, 15.468f, 7.74f, 16.038f, 8.304f, + 3.849f, 7.974f, 4.128f, 8.25f, 17.064f, 8.82f, 8.418f, 17.412f, + 9.f, 4.503f, 9.3f, 4.8f, 3.063f, 6.408f, 3.348f, 6.69f, + 13.956f, 7.272f, 6.81f, 14.208f, 7.404f, 3.705f, 7.71f, 4.008f, + 7.308f, 15.21f, 7.908f, 15.804f, 32.82f, 17.028f, 16.092f, 33.42f, + 17.34f, 8.676f, 17.982f, 9.312f, 4.317f, 8.946f, 4.632f, 9.258f, + 19.152f, 9.9f, 9.426f, 19.5f, 10.08f, 5.043f, 10.416f, 5.376f, + 5.619f, 11.484f, 5.868f, 11.73f, 23.964f, 12.24f, 12.138f, 24.792f, + 12.66f, 6.333f, 12.93f, 6.6f, 12.42f, 25.362f, 12.948f, 25.884f, + 52.836f, 26.964f, 26.748f, 54.588f, 27.852f, 13.932f, 28.422f, 14.496f, + 6.873f, 14.022f, 7.152f, 14.298f, 29.16f, 14.868f, 14.754f, 30.084f, + 15.336f, 7.671f, 15.636f, 7.968f, 6.807f, 13.896f, 7.092f, 14.178f, + 28.932f, 14.76f, 14.586f, 29.76f, 15.18f, 7.593f, 15.486f, 7.896f, + 14.94f, 30.474f, 15.54f, 31.068f, 63.348f, 32.292f, 31.932f, 65.1f, + 33.18f, 16.596f, 33.822f, 17.232f, 8.205f, 16.722f, 8.52f, 17.034f, + 34.704f, 17.676f, 17.49f, 35.628f, 18.144f, 9.075f, 18.48f, 9.408f, + 7.995f, 16.308f, 8.316f, 16.626f, 33.9f, 17.28f, 17.034f, 34.728f, + 17.7f, 8.853f, 18.042f, 9.192f, 17.46f, 35.586f, 18.132f, 36.252f, + 73.86f, 37.62f, 37.116f, 75.612f, 38.508f, 19.26f, 39.222f, 19.968f, + 9.537f, 19.422f, 9.888f, 19.77f, 40.248f, 20.484f, 20.226f, 41.172f, + 20.952f, 10.479f, 21.324f, 10.848f, 9.183f, 18.72f, 9.54f, 19.074f, + 38.868f, 19.8f, 19.482f, 39.696f, 20.22f, 10.113f, 20.598f, 10.488f, + 19.98f, 40.698f, 20.724f, 41.436f, 84.372f, 42.948f, 42.3f, 86.124f, + 43.836f, 21.924f, 44.622f, 22.704f, 10.869f, 22.122f, 11.256f, 22.506f, + 45.792f, 23.292f, 22.962f, 46.716f, 23.76f, 11.883f, 24.168f, 12.288f}); + + auto expGradW = NDArrayFactory::create( + 'c', {oC, iC, kD, kH, kW}, + {5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, + 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, + 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, + 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, + 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 7.84f, 7.84f, + 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, + 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, + 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, + 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, + 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 10.4f, 10.4f, 10.4f, 10.4f, + 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, + 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, + 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, + 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, + 10.4f, 10.4f, 10.4f, 10.4f}); + + auto expGradB = + NDArrayFactory::create('c', {oC}, {2.64f, 3.92f, 5.2f}); + + input = 2.; + weights.linspace(0.1, 0.1); + gradO.linspace(0.01, 0.01); + weights.permutei({2, 3, 4, 1, 0}); + expGradW.permutei({2, 3, 4, 1, 0}); + + sd::ops::conv3dnew_bp op; + auto results = op.evaluate({&input, &weights, &bias, &gradO}, + {kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, + paddingMode, dataFormat}); + auto gradI = results.at(0); + auto gradW = results.at(1); + auto gradB = results.at(2); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); + + ASSERT_TRUE(expGradW.isSameShape(gradW)); + ASSERT_TRUE(expGradW.equalsTo(gradW)); + + ASSERT_TRUE(expGradB.isSameShape(gradB)); + ASSERT_TRUE(expGradB.equalsTo(gradB)); +} - auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); - auto weights = NDArrayFactory::create('c', {kD, kH, kW, iC, oC}); - auto bias = NDArrayFactory::create('c', {oC}); - auto expected = NDArrayFactory::create('c', {2, 3, 2, 2, 2}); +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests1, conv3d_bp_test4) { + int bS = 2, iD = 4, iH = 3, iW = 3, iC = 4, oC = 3, kD = 3, kH = 2, kW = 2, + sD = 1, sH = 1, sW = 1, pD = 0, pH = 0, pW = 0, dD = 1, dH = 1, dW = 1; + int oD = 2, oH = 2, oW = 2; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NHWC, 0-NCHW + int wFormat = 1; // 0-[kD, kH, kW, iC, oC], 1-[oC, iC, kD, kH, kW], 2-[oC, + // kD, kH, kW, iC] + + NDArray input('c', {bS, iC, iD, iH, iW}, sd::DataType::FLOAT32); + NDArray weights( + 'c', {oC, iC, kD, kH, kW}, + {7., 5.8, 4.6, 3.4, 2.2, 1., -0.2, -1.4, -2.6, -3.8, -5., -6.2, + 6.7, 5.5, 4.3, 3.1, 1.9, 0.7, -0.5, -1.7, -2.9, -4.1, -5.3, -6.5, + 6.4, 5.2, 4., 2.8, 1.6, 0.4, -0.8, -2., -3.2, -4.4, -5.6, -6.8, + 6.1, 4.9, 3.7, 2.5, 1.3, 0.1, -1.1, -2.3, -3.5, -4.7, -5.9, -7.1, + 6.9, 5.7, 4.5, 3.3, 2.1, 0.9, -0.3, -1.5, -2.7, -3.9, -5.1, -6.3, + 6.6, 5.4, 4.2, 3., 1.8, 0.6, -0.6, -1.8, -3., -4.2, -5.4, -6.6, + 6.3, 5.1, 3.9, 2.7, 1.5, 0.3, -0.9, -2.1, -3.3, -4.5, -5.7, -6.9, + 6., 4.8, 3.6, 2.4, 1.2, 0., -1.2, -2.4, -3.6, -4.8, -6., -7.2, + 6.8, 5.6, 4.4, 3.2, 2., 0.8, -0.4, -1.6, -2.8, -4., -5.2, -6.4, + 6.5, 5.3, 4.1, 2.9, 1.7, 0.5, -0.7, -1.9, -3.1, -4.3, -5.5, -6.7, + 6.2, 5., 3.8, 2.6, 1.4, 0.2, -1., -2.2, -3.4, -4.6, -5.8, -7., + 5.9, 4.7, 3.5, 2.3, 1.1, -0.1, -1.3, -2.5, -3.7, -4.9, -6.1, -7.3}, + sd::DataType::FLOAT32); + NDArray bias('c', {oC}, {1, -0.5, 0.1}, sd::DataType::FLOAT32); + NDArray gradO('c', {bS, oC, oD, oH, oW}, sd::DataType::FLOAT32); + + NDArray expGradI( + 'c', {bS, iC, iD, iH, iW}, + {1.847, 3.577, 1.694, 3.460, 6.542, 3.010, + 1.469, 2.677, 1.172, 3.226, 5.929999, 2.632, + 5.408, 9.483999, 3.932, 1.894, 2.978, 1.012, + 0.058, -0.694, -0.824, -1.504, -4.916, -3.556, + -1.850, -4.798, -3.020, -1.069, -2.687, -1.654, + -3.236, -7.714, -4.550, -2.311, -5.315, -3.040, + 1.766, 3.406, 1.604, 3.280, 6.164, 2.812, + 1.370, 2.470, 1.064, 3.028, 5.516, 2.416, + 4.976, 8.584001, 3.464, 1.660, 2.492, 0.760, + -0.140, -1.108, -1.040, -1.936, -5.816, -4.024, + -2.084, -5.284, -3.272, -1.186, -2.930, -1.780, + -3.488, -8.236, -4.820, -2.446, -5.594, -3.184, + 1.685, 3.235, 1.514, 3.100, 5.786, 2.614, + 1.271, 2.263, 0.956, 2.830, 5.102, 2.200, + 4.544001, 7.683999, 2.996, 1.426, 2.006, 0.508, + -0.338, -1.522, -1.256, -2.368, -6.716, -4.492, + -2.318, -5.770, -3.524, -1.303, -3.173, -1.906, + -3.740, -8.757999, -5.090, -2.581, -5.873, -3.328, + 1.604, 3.064, 1.424, 2.920, 5.408, 2.416, + 1.172, 2.056, 0.848, 2.632, 4.688, 1.984, + 4.112, 6.784, 2.528, 1.192, 1.520, 0.256, + -0.536, -1.936, -1.472, -2.800, -7.616, -4.960, + -2.552, -6.256, -3.776, -1.420, -3.416, -2.032, + -3.992, -9.280001, -5.360, -2.716, -6.152, -3.472, + 6.815001, 12.649, 5.798, 11.668, 21.230, 9.490, + 4.709, 8.292999, 3.548, 9.706, 17.162001, 7.384, + 14.912001, 25.036001, 9.980001, 4.918, 7.298, 2.308, + -0.374, -3.286, -2.984, -5.824, -17.012001, -11.332001, + -5.738, -14.302, -8.636, -3.013, -7.439, -4.462, + -8.852, -20.674, -11.894, -5.983, -13.523, -7.576, + 6.518, 12.046, 5.492, 11.056, 19.988001, 8.860001, + 4.394, 7.654, 3.224, 9.075999, 15.883999, 6.736001, + 13.616, 22.407999, 8.648, 4.252, 5.947999, 1.624, + -1.004, -4.564, -3.632, -7.120, -19.639999, -12.664001, + -6.404, -15.652, -9.320, -3.346, -8.114, -4.804, + -9.536, -22.059999, -12.596, -6.334, -14.233999, -7.936, + 6.221, 11.443, 5.186, 10.444, 18.746, 8.230, + 4.079, 7.015, 2.900, 8.446, 14.606001, 6.088, + 12.320, 19.779999, 7.316, 3.586, 4.598001, 0.940, + -1.634, -5.842, -4.280, -8.416, -22.268002, -13.996, + -7.070001, -17.001999, -10.004001, -3.679, -8.789, -5.146, + -10.220, -23.445999, -13.298, -6.684999, -14.945, -8.296, + 5.924, 10.840, 4.880, 9.832001, 17.504, 7.600, + 3.764, 6.376, 2.576, 7.816, 13.328, 5.440001, + 11.024, 17.152, 5.983999, 2.920, 3.247999, 0.256, + -2.264, -7.120, -4.928, -9.712, -24.896, -15.328, + -7.736, -18.352001, -10.688, -4.012, -9.464, -5.488, + -10.903999, -24.832001, -14.000, -7.035999, -15.656, -8.655999}, + sd::DataType::FLOAT32); + + NDArray expGradW( + 'c', {oC, iC, kD, kH, kW}, + {-24.399998, -23.080000, -20.440001, -19.119999, -12.519999, + -11.199998, -8.560001, -7.240002, -0.639999, 0.679999, + 3.320001, 4.640001, 23.119999, 24.439999, 27.080002, + 28.400002, 35.000000, 36.320000, 38.959999, 40.279999, + 46.879997, 48.200005, 50.839996, 52.160004, 70.639999, + 71.959999, 74.599998, 75.919998, 82.520004, 83.840004, + 86.479996, 87.800003, 94.399994, 95.719994, 98.360001, + 99.680008, 118.160004, 119.479996, 122.120003, 123.440010, + 130.040009, 131.360001, 134.000000, 135.319992, 141.919998, + 143.239990, 145.879990, 147.200012, -70.159996, -68.200005, + -64.279999, -62.319996, -52.519993, -50.559994, -46.640003, + -44.680000, -34.880001, -32.919998, -29.000002, -27.040005, + 0.400004, 2.359996, 6.279998, 8.240004, 18.040001, + 20.000000, 23.920002, 25.879999, 35.680000, 37.639996, + 41.560001, 43.520000, 70.959999, 72.919998, 76.840004, + 78.799995, 88.599998, 90.560005, 94.479996, 96.440002, + 106.240005, 108.199997, 112.120003, 114.080002, 141.519989, + 143.479996, 147.400009, 149.360001, 159.159988, 161.119995, + 165.040009, 167.000000, 176.800003, 178.760010, 182.679993, + 184.639999, -115.920006, -113.320000, -108.120003, -105.520012, + -92.520004, -89.919991, -84.720001, -82.119995, -69.120010, + -66.520004, -61.320000, -58.719994, -22.320000, -19.719999, + -14.520001, -11.920001, 1.079997, 3.679997, 8.879997, + 11.480003, 24.480001, 27.079998, 32.280003, 34.880001, + 71.279999, 73.880005, 79.080002, 81.680000, 94.679993, + 97.280006, 102.479996, 105.080002, 118.080002, 120.679993, + 125.879997, 128.479996, 164.880005, 167.479996, 172.679993, + 175.279999, 188.279984, 190.880005, 196.080002, 198.679993, + 211.680008, 214.280014, 219.479996, 222.079987}, + sd::DataType::FLOAT32); + + NDArray expGradB('c', {oC}, {2.64, 3.92, 5.2}, sd::DataType::FLOAT32); + + input.linspace(-75, 0.5); + gradO.linspace(0.01, 0.01); + + sd::ops::conv3dnew_bp op; + auto results = op.evaluate({&input, &weights, &bias, &gradO}, {}, + {kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, + paddingMode, dataFormat, wFormat}); + auto gradI = results.at(0); + auto gradW = results.at(1); + auto gradB = results.at(2); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); + + ASSERT_TRUE(expGradW.isSameShape(gradW)); + ASSERT_TRUE(expGradW.equalsTo(gradW)); + + ASSERT_TRUE(expGradB.isSameShape(gradB)); + ASSERT_TRUE(expGradB.equalsTo(gradB)); +} - input = 2.; - weights = 0.5; - expected = 49.; - bias = 1.; +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests1, conv3d_bp_test5) { + int bS = 2, iD = 4, iH = 3, iW = 3, iC = 4, oC = 3, kD = 3, kH = 2, kW = 2, + sD = 1, sH = 1, sW = 1, pD = 0, pH = 0, pW = 0, dD = 1, dH = 1, dW = 1; + int oD = 4, oH = 3, oW = 3; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NHWC, 0-NCHW + int wFormat = 2; // 0-[kD, kH, kW, iC, oC], 1-[oC, iC, kD, kH, kW], 2-[oC, + // kD, kH, kW, iC] + + NDArray input('c', {bS, iD, iH, iW, iC}, sd::DataType::FLOAT32); + NDArray weights( + 'c', {oC, kD, kH, kW, iC}, + {15., 14.7, 14.4, 14.1, 13.8, 13.5, 13.2, 12.9, 12.6, 12.3, 12., 11.7, + 11.4, 11.1, 10.8, 10.5, 10.2, 9.9, 9.6, 9.3, 9., 8.7, 8.4, 8.1, + 7.8, 7.5, 7.2, 6.9, 6.6, 6.3, 6., 5.7, 5.4, 5.1, 4.8, 4.5, + 4.2, 3.9, 3.6, 3.3, 3., 2.7, 2.4, 2.1, 1.8, 1.5, 1.2, 0.9, + 14.9, 14.6, 14.3, 14., 13.7, 13.4, 13.1, 12.8, 12.5, 12.2, 11.9, 11.6, + 11.3, 11., 10.7, 10.4, 10.1, 9.8, 9.5, 9.2, 8.9, 8.6, 8.3, 8., + 7.7, 7.4, 7.1, 6.8, 6.5, 6.2, 5.9, 5.6, 5.3, 5., 4.7, 4.4, + 4.1, 3.8, 3.5, 3.2, 2.9, 2.6, 2.3, 2., 1.7, 1.4, 1.1, 0.8, + 14.8, 14.5, 14.2, 13.9, 13.6, 13.3, 13., 12.7, 12.4, 12.1, 11.8, 11.5, + 11.2, 10.9, 10.6, 10.3, 10., 9.7, 9.4, 9.1, 8.8, 8.5, 8.2, 7.9, + 7.6, 7.3, 7., 6.7, 6.4, 6.1, 5.8, 5.5, 5.2, 4.9, 4.6, 4.3, + 4., 3.7, 3.4, 3.1, 2.8, 2.5, 2.2, 1.9, 1.6, 1.3, 1., 0.7}, + sd::DataType::FLOAT32); + NDArray bias('c', {oC}, {1, -0.5, 0.1}, sd::DataType::FLOAT32); + NDArray gradO('c', {bS, oD, oH, oW, oC}, sd::DataType::FLOAT32); + + NDArray expGradI( + 'c', {bS, iD, iH, iW, iC}, + {13.565001, 13.286001, 13.007000, 12.728001, 28.264000, 27.652000, + 27.040001, 26.427999, 32.547997, 31.827999, 31.108002, 30.388000, + 31.647999, 30.927998, 30.208000, 29.487999, 64.484001, 62.935997, + 61.387997, 59.839996, 72.188004, 70.424004, 68.660004, 66.896004, + 43.852001, 42.807999, 41.764000, 40.719997, 87.596001, 85.400002, + 83.204002, 81.007996, 95.299988, 92.887993, 90.475998, 88.063995, + 34.130997, 33.348000, 32.564999, 31.782001, 67.856995, 66.210007, + 64.563004, 62.916000, 72.987000, 71.178001, 69.369003, 67.559998, + 70.179001, 68.369995, 66.561005, 64.751999, 137.927994, 134.147995, + 130.367996, 126.587997, 146.891998, 142.787994, 138.683990, 134.580017, + 84.597000, 82.302002, 80.007004, 77.711998, 164.820007, 160.067993, + 155.316010, 150.563995, 173.783997, 168.707993, 163.631989, 158.556000, + 58.674000, 57.162003, 55.649994, 54.138000, 114.027008, 110.921997, + 107.816994, 104.711990, 119.156998, 115.889999, 112.623001, 109.355995, + 113.433006, 110.166000, 106.899002, 103.632004, 218.603989, 211.908020, + 205.211975, 198.515991, 227.568008, 220.547974, 213.528015, 206.507996, + 127.850998, 124.098000, 120.345001, 116.591995, 245.496002, 237.828018, + 230.159988, 222.492004, 254.459991, 246.468002, 238.475998, 230.483994, + 34.049000, 32.797997, 31.547001, 30.295998, 64.479996, 61.924000, + 59.368004, 56.812000, 67.035995, 64.372002, 61.707996, 59.044003, + 62.248001, 59.584003, 56.919998, 54.256001, 116.180000, 110.744003, + 105.307999, 99.872002, 120.428001, 114.776001, 109.124001, 103.472000, + 69.268005, 66.279999, 63.292000, 60.304001, 128.923996, 122.839996, + 116.755997, 110.671997, 133.171997, 126.872002, 120.571991, 114.271996, + 94.565002, 92.342010, 90.118996, 87.896004, 182.488007, 177.988007, + 173.488007, 168.988007, 186.772003, 182.164001, 177.556000, 172.947998, + 178.095993, 173.488007, 168.880005, 164.272003, 341.828003, 332.504028, + 323.180023, 313.856018, 349.532013, 339.992004, 330.451996, 320.911987, + 190.299988, 185.368011, 180.436005, 175.503998, 364.940002, 354.967987, + 344.996002, 335.024017, 372.644012, 362.455994, 352.268005, 342.080017, + 132.303009, 128.604004, 124.904999, 121.206001, 252.536987, 245.057999, + 237.578979, 230.100006, 257.666992, 250.026001, 242.385010, 234.744019, + 243.195007, 235.554001, 227.912994, 220.272003, 460.631958, 445.188019, + 429.744019, 414.299988, 469.595947, 453.827972, 438.059998, 422.291992, + 257.613007, 249.486008, 241.358994, 233.232010, 487.523987, 471.108032, + 454.691986, 438.276001, 496.488037, 479.748016, 463.007996, 446.268005, + 156.846008, 152.417999, 147.989990, 143.561996, 298.707001, 289.769989, + 280.833008, 271.895996, 303.837006, 294.737976, 285.638977, 276.540009, + 286.449005, 277.350006, 268.250977, 259.151978, 541.307983, 522.947998, + 504.587982, 486.227997, 550.271973, 531.588013, 512.903992, 494.220032, + 300.867004, 291.281982, 281.696991, 272.112000, 568.200012, 548.868042, + 529.535950, 510.204010, 577.164062, 557.507935, 537.851990, 518.196045, + 83.944992, 80.750000, 77.555000, 74.360001, 156.496002, 150.052002, + 143.608002, 137.164001, 159.052002, 152.500000, 145.947998, 139.395996, + 146.488007, 139.936005, 133.384003, 126.832001, 269.107971, 255.895996, + 242.684006, 229.471985, 273.356018, 259.927979, 246.500000, 233.071991, + 153.507996, 146.632004, 139.755997, 132.880005, 281.851990, 267.992004, + 254.132004, 240.272003, 286.100006, 272.023987, 257.947998, 243.872009}, + sd::DataType::FLOAT32); + + NDArray expGradW( + 'c', {oC, kD, kH, kW, iC}, + {396.899872, 429.570007, 462.240234, 494.910156, 313.739960, + 335.250000, 356.760071, 378.270020, 403.379944, 424.350006, + 445.320007, 466.289978, 299.520020, 313.319977, 327.119995, + 340.920013, 1556.280029, 1594.979980, 1633.679932, 1672.379883, + 1090.080078, 1115.520020, 1140.959961, 1166.400024, 1183.679932, + 1208.400024, 1233.119995, 1257.840088, 821.279907, 837.519897, + 853.760010, 870.000000, 1500.119873, 1525.500122, 1550.880005, + 1576.260010, 1029.780029, 1046.429932, 1063.080078, 1079.729980, + 1080.539917, 1096.650024, 1112.760010, 1128.869995, 738.000000, + 748.560059, 759.119995, 769.679993, 389.880005, 422.819946, + 455.759979, 488.699951, 309.420013, 331.109985, 352.799988, + 374.490051, 399.780029, 420.930023, 442.080017, 463.230011, + 297.359985, 311.280029, 325.200012, 339.120056, 1553.400146, + 1592.459961, 1631.520020, 1670.579956, 1088.640015, 1114.320068, + 1140.000000, 1165.679932, 1183.199951, 1208.160034, 1233.119995, + 1258.079956, 821.280029, 837.680054, 854.079956, 870.479980, + 1502.819946, 1528.469971, 1554.119995, 1579.770020, 1031.939941, + 1048.770020, 1065.599976, 1082.429932, 1083.420044, 1099.709961, + 1116.000000, 1132.290039, 740.159973, 750.840027, 761.519958, + 772.199951, 382.859924, 416.070099, 449.279968, 482.489990, + 305.099976, 326.970062, 348.840027, 370.709991, 396.179962, + 417.510010, 438.839966, 460.169952, 295.200012, 309.239990, + 323.279968, 337.320007, 1550.519775, 1589.939941, 1629.359985, + 1668.779907, 1087.200073, 1113.119995, 1139.039917, 1164.959961, + 1182.719971, 1207.920044, 1233.119995, 1258.320190, 821.279968, + 837.840027, 854.400024, 870.959961, 1505.520142, 1531.439819, + 1557.359985, 1583.279907, 1034.100098, 1051.110107, 1068.120117, + 1085.130005, 1086.299927, 1102.770020, 1119.239990, 1135.710083, + 742.319946, 753.119995, 763.919983, 774.720032}, + sd::DataType::FLOAT32); + + NDArray expGradB('c', {oC}, {77.400002, 78.119995, 78.840004}, + sd::DataType::FLOAT32); + + input.linspace(-75, 0.5); + gradO.linspace(0.01, 0.01); + + sd::ops::conv3dnew_bp op; + auto results = op.evaluate({&input, &weights, &bias, &gradO}, {}, + {kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, + paddingMode, dataFormat, wFormat}); + auto gradI = results.at(0); + auto gradW = results.at(1); + auto gradB = results.at(2); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); + + ASSERT_TRUE(expGradW.isSameShape(gradW)); + ASSERT_TRUE(expGradW.equalsTo(gradW)); + + ASSERT_TRUE(expGradB.isSameShape(gradB)); + ASSERT_TRUE(expGradB.equalsTo(gradB)); +} - sd::ops::conv3dnew op; - auto results = op.evaluate({&input, &weights, &bias}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}); - auto* output = results.at(0); +////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedConvolutionTests1, conv3d_test1) { + int bS = 2, iD = 3, iH = 4, iW = 3, iC = 4, oC = 3, kD = 2, kH = 3, kW = 2, + sD = 1, sH = 1, sW = 1, pD = 0, pH = 0, pW = 0, dD = 1, dH = 1, dW = 1; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); + auto weights = NDArrayFactory::create('c', {kD, kH, kW, iC, oC}); + auto expected = NDArrayFactory::create( + 'c', {2, 3, 4, 3, 3}, + {534.4f, 540.8f, 547.2f, 534.4f, 540.8f, 547.2f, 248.f, 251.2f, 254.4f, + 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f, + 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f, + 380.8f, 387.2f, 393.6f, 380.8f, 387.2f, 393.6f, 171.2f, 174.4f, 177.6f, + 534.4f, 540.8f, 547.2f, 534.4f, 540.8f, 547.2f, 248.f, 251.2f, 254.4f, + 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f, + 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f, + 380.8f, 387.2f, 393.6f, 380.8f, 387.2f, 393.6f, 171.2f, 174.4f, 177.6f, + 152.f, 155.2f, 158.4f, 152.f, 155.2f, 158.4f, 66.4f, 68.f, 69.6f, + 170.4f, 175.2f, 180.f, 170.4f, 175.2f, 180.f, 70.8f, 73.2f, 75.6f, + 170.4f, 175.2f, 180.f, 170.4f, 175.2f, 180.f, 70.8f, 73.2f, 75.6f, + 75.2f, 78.4f, 81.6f, 75.2f, 78.4f, 81.6f, 28.f, 29.6f, 31.2f, + 534.4f, 540.8f, 547.2f, 534.4f, 540.8f, 547.2f, 248.f, 251.2f, 254.4f, + 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f, + 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f, + 380.8f, 387.2f, 393.6f, 380.8f, 387.2f, 393.6f, 171.2f, 174.4f, 177.6f, + 534.4f, 540.8f, 547.2f, 534.4f, 540.8f, 547.2f, 248.f, 251.2f, 254.4f, + 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f, + 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f, + 380.8f, 387.2f, 393.6f, 380.8f, 387.2f, 393.6f, 171.2f, 174.4f, 177.6f, + 152.f, 155.2f, 158.4f, 152.f, 155.2f, 158.4f, 66.4f, 68.f, 69.6f, + 170.4f, 175.2f, 180.f, 170.4f, 175.2f, 180.f, 70.8f, 73.2f, 75.6f, + 170.4f, 175.2f, 180.f, 170.4f, 175.2f, 180.f, 70.8f, 73.2f, 75.6f, + 75.2f, 78.4f, 81.6f, 75.2f, 78.4f, 81.6f, 28.f, 29.6f, 31.2f}); + input = 2.; + weights.linspace(0.1, 0.1); + + sd::ops::conv3dnew op; + auto results = op.evaluate({&input, &weights}, {}, + {kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, + paddingMode, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); +} - // output->printIndexedBuffer(); +////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedConvolutionTests1, conv3d_test2) { + int bS = 2, iD = 3, iH = 4, iW = 3, iC = 4, oC = 3, kD = 2, kH = 3, kW = 2, + sD = 1, sH = 1, sW = 1, pD = 0, pH = 0, pW = 0, dD = 1, dH = 1, dW = 1; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); + auto weights = NDArrayFactory::create('c', {kD, kH, kW, iC, oC}); + auto expected = NDArrayFactory::create( + 'c', {2, 2, 2, 2, 3}, + {686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, + 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, + 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, + 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, + 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, + 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f}); + input = 2.; + weights.linspace(0.1, 0.1); + + sd::ops::conv3dnew op; + auto results = op.evaluate({&input, &weights}, {}, + {kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, + paddingMode, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); +} - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); +////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedConvolutionTests1, conv3d_test3) { + int bS = 2, iD = 3, iH = 4, iW = 3, iC = 4, oC = 3, kD = 2, kH = 3, kW = 2, + sD = 1, sH = 1, sW = 1, pD = 0, pH = 0, pW = 0, dD = 1, dH = 1, dW = 1; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); + auto weights = NDArrayFactory::create('c', {kD, kH, kW, iC, oC}); + auto expected = NDArrayFactory::create('c', {2, 3, 2, 2, 2}); + input = 2.; + weights = 0.5; + expected = 48.; + + sd::ops::conv3dnew op; + auto results = op.evaluate({&input, &weights}, {}, + {kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, + paddingMode, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); +} +//////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedConvolutionTests1, conv3d_test4) { + int bS = 2, iD = 3, iH = 4, iW = 3, iC = 4, oC = 3, kD = 2, kH = 3, kW = 2, + sD = 1, sH = 1, sW = 1, pD = 0, pH = 0, pW = 0, dD = 1, dH = 1, dW = 1; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); + auto weights = NDArrayFactory::create('c', {kD, kH, kW, iC, oC}); + auto bias = NDArrayFactory::create('c', {oC}); + auto expected = NDArrayFactory::create('c', {2, 3, 2, 2, 2}); + + input = 2.; + weights = 0.5; + expected = 49.; + bias = 1.; + + sd::ops::conv3dnew op; + auto results = op.evaluate({&input, &weights, &bias}, {}, + {kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, + paddingMode, dataFormat}); + auto output = results.at(0); + + // output->printIndexedBuffer(); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } //////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests1, conv3d_test5) { - - int bS=2, iD=3,iH=4,iW=3, iC=4,oC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; - int paddingMode = 0; // 1-SAME, 0-VALID; - int dataFormat = 0; // 1-NDHWC, 0-NCDHW - - auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); - auto weights = NDArrayFactory::create('c', {kD, kH, kW, iC, oC}); - auto bias = NDArrayFactory::create('c', {oC},{1.f, 2.f, 3.f}); - auto expected = NDArrayFactory::create('c', {2, 3, 2, 2, 2},{49.f, 49.f, 49.f, 49.f, 49.f, 49.f, 49.f, 49.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, - 51.f, 51.f, 51.f, 51.f, 51.f, 51.f, 51.f, 51.f, 49.f, 49.f, 49.f, 49.f, 49.f, 49.f, 49.f, 49.f, - 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 51.f, 51.f, 51.f, 51.f, 51.f, 51.f, 51.f, 51.f}); - input = 2.; - weights = 0.5; - - sd::ops::conv3dnew op; - auto results = op.evaluate({&input, &weights, &bias}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}); - auto* output = results.at(0); - - // output->printIndexedBuffer(); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - + int bS = 2, iD = 3, iH = 4, iW = 3, iC = 4, oC = 3, kD = 2, kH = 3, kW = 2, + sD = 1, sH = 1, sW = 1, pD = 0, pH = 0, pW = 0, dD = 1, dH = 1, dW = 1; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); + auto weights = NDArrayFactory::create('c', {kD, kH, kW, iC, oC}); + auto bias = NDArrayFactory::create('c', {oC}, {1.f, 2.f, 3.f}); + auto expected = NDArrayFactory::create( + 'c', {2, 3, 2, 2, 2}, + {49.f, 49.f, 49.f, 49.f, 49.f, 49.f, 49.f, 49.f, 50.f, 50.f, 50.f, 50.f, + 50.f, 50.f, 50.f, 50.f, 51.f, 51.f, 51.f, 51.f, 51.f, 51.f, 51.f, 51.f, + 49.f, 49.f, 49.f, 49.f, 49.f, 49.f, 49.f, 49.f, 50.f, 50.f, 50.f, 50.f, + 50.f, 50.f, 50.f, 50.f, 51.f, 51.f, 51.f, 51.f, 51.f, 51.f, 51.f, 51.f}); + input = 2.; + weights = 0.5; + + sd::ops::conv3dnew op; + auto results = op.evaluate({&input, &weights, &bias}, {}, + {kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, + paddingMode, dataFormat}); + auto output = results.at(0); + + // output->printIndexedBuffer(); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } //////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests1, conv3d_test6) { - - int bS=2, iD=3,iH=4,iW=3, iC=4,oC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; - int paddingMode = 0; // 1-SAME, 0-VALID; - int dataFormat = 0; // 1-NDHWC, 0-NCDHW - - auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); - auto weights = NDArrayFactory::create('c', {oC, iC, kD, kH, kW}); - auto bias = NDArrayFactory::create('c', {oC},{1.f, 2.f, 3.f}); - auto expected = NDArrayFactory::create('c', {2, 3, 2, 2, 2},{236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 698.f, 698.f, 698.f, 698.f, - 698.f, 698.f, 698.f, 698.f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, - 236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 698.f, 698.f, 698.f, 698.f, - 698.f, 698.f, 698.f, 698.f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f}); - input = 2.; - weights.linspace(0.1, 0.1); - weights.permutei({2, 3, 4, 1, 0}); - - sd::ops::conv3dnew op; - auto results = op.evaluate({&input, &weights, &bias}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}); - auto* output = results.at(0); - - // output->printIndexedBuffer(); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - - + int bS = 2, iD = 3, iH = 4, iW = 3, iC = 4, oC = 3, kD = 2, kH = 3, kW = 2, + sD = 1, sH = 1, sW = 1, pD = 0, pH = 0, pW = 0, dD = 1, dH = 1, dW = 1; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); + auto weights = NDArrayFactory::create('c', {oC, iC, kD, kH, kW}); + auto bias = NDArrayFactory::create('c', {oC}, {1.f, 2.f, 3.f}); + auto expected = NDArrayFactory::create( + 'c', {2, 3, 2, 2, 2}, + {236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 236.2f, + 698.f, 698.f, 698.f, 698.f, 698.f, 698.f, 698.f, 698.f, + 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, + 236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 236.2f, + 698.f, 698.f, 698.f, 698.f, 698.f, 698.f, 698.f, 698.f, + 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f}); + input = 2.; + weights.linspace(0.1, 0.1); + weights.permutei({2, 3, 4, 1, 0}); + + sd::ops::conv3dnew op; + auto results = op.evaluate({&input, &weights, &bias}, {}, + {kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, + paddingMode, dataFormat}); + auto output = results.at(0); + + // output->printIndexedBuffer(); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } //////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests1, conv3d_test7) { - - int bS=2, iD=3,iH=4,iW=3, iC=4,oC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; - int paddingMode = 0; // 1-SAME, 0-VALID; - int dataFormat = 0; // 1-NDHWC, 0-NCDHW - - auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); - auto weights = NDArrayFactory::create('c', {oC, iC, kD, kH, kW}); - auto expected = NDArrayFactory::create('c', {2, 3, 2, 2, 2},{235.2f, 235.2f, 235.2f, 235.2f, 235.2f, 235.2f, 235.2f, 235.2f, 696.f, 696.f, 696.f, 696.f, 696.f, 696.f, 696.f, 696.f, - 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 235.2f, 235.2f, 235.2f, 235.2f, 235.2f, 235.2f, 235.2f, 235.2f, - 696.f, 696.f, 696.f, 696.f, 696.f, 696.f, 696.f, 696.f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f}); - input = 2.; - weights.linspace(0.1, 0.1); - weights.permutei({2, 3, 4, 1, 0}); - - sd::ops::conv3dnew op; - auto results = op.evaluate({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}); - auto* output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - - + int bS = 2, iD = 3, iH = 4, iW = 3, iC = 4, oC = 3, kD = 2, kH = 3, kW = 2, + sD = 1, sH = 1, sW = 1, pD = 0, pH = 0, pW = 0, dD = 1, dH = 1, dW = 1; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); + auto weights = NDArrayFactory::create('c', {oC, iC, kD, kH, kW}); + auto expected = NDArrayFactory::create( + 'c', {2, 3, 2, 2, 2}, + {235.2f, 235.2f, 235.2f, 235.2f, 235.2f, 235.2f, 235.2f, 235.2f, + 696.f, 696.f, 696.f, 696.f, 696.f, 696.f, 696.f, 696.f, + 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, + 235.2f, 235.2f, 235.2f, 235.2f, 235.2f, 235.2f, 235.2f, 235.2f, + 696.f, 696.f, 696.f, 696.f, 696.f, 696.f, 696.f, 696.f, + 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f}); + input = 2.; + weights.linspace(0.1, 0.1); + weights.permutei({2, 3, 4, 1, 0}); + + sd::ops::conv3dnew op; + auto results = op.evaluate({&input, &weights}, {}, + {kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, + paddingMode, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests1, conv3d_test8) { - auto x = NDArrayFactory::create('c', {4, 2, 28, 28, 3}); - auto y = NDArrayFactory::create('c', {2, 5, 5, 3, 4}); - auto e = NDArrayFactory::create('c', {4, 1, 7, 10, 4}); + auto x = NDArrayFactory::create('c', {4, 2, 28, 28, 3}); + auto y = NDArrayFactory::create('c', {2, 5, 5, 3, 4}); + auto e = NDArrayFactory::create('c', {4, 1, 7, 10, 4}); - sd::ops::conv3dnew op; - auto result = op.evaluate({&x, &y}, {}, {2,5,5, 5,4,3, 0,0,0, 1,1,1, 1,1}); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::conv3dnew op; + auto result = + op.evaluate({&x, &y}, {}, {2, 5, 5, 5, 4, 3, 0, 0, 0, 1, 1, 1, 1, 1}); + ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_TRUE(e.isSameShape(z)); + ASSERT_TRUE(e.isSameShape(z)); } TYPED_TEST(TypedConvolutionTests1, conv3d_test9) { - auto x = NDArrayFactory::create('c', {4, 2, 28, 28, 3}); - auto w = NDArrayFactory::create('c', {2, 5, 5, 3, 4}); - auto exp = NDArrayFactory::create('c', {4, 1, 7, 10, 4}); + auto x = NDArrayFactory::create('c', {4, 2, 28, 28, 3}); + auto w = NDArrayFactory::create('c', {2, 5, 5, 3, 4}); + auto exp = NDArrayFactory::create('c', {4, 1, 7, 10, 4}); - sd::ops::conv3dnew op; - auto result = op.evaluate({&x, &w}, {}, {2,5,5, 5,4,3, 0,0,0, 1,1,1, 1,1}); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::conv3dnew op; + auto result = + op.evaluate({&x, &w}, {}, {2, 5, 5, 5, 4, 3, 0, 0, 0, 1, 1, 1, 1, 1}); + ASSERT_EQ(Status::OK(), result.status()); - ShapeList shapeList({x.shapeInfo(), w.shapeInfo()}); - ContextPrototype proto; - Context ctx(1); - ctx.getIArguments()->push_back(2); - ctx.getIArguments()->push_back(5); - ctx.getIArguments()->push_back(5); + ShapeList shapeList({x.shapeInfo(), w.shapeInfo()}); + ContextPrototype proto; + Context ctx(1); + ctx.appendI(2); + ctx.appendI(5); + ctx.appendI(5); - ctx.getIArguments()->push_back(5); - ctx.getIArguments()->push_back(4); - ctx.getIArguments()->push_back(3); + ctx.appendI(5); + ctx.appendI(4); + ctx.appendI(3); - ctx.getIArguments()->push_back(0); - ctx.getIArguments()->push_back(0); - ctx.getIArguments()->push_back(0); + ctx.appendI(0); + ctx.appendI(0); + ctx.appendI(0); - ctx.getIArguments()->push_back(1); - ctx.getIArguments()->push_back(1); - ctx.getIArguments()->push_back(1); + ctx.appendI(1); + ctx.appendI(1); + ctx.appendI(1); - ctx.getIArguments()->push_back(0); - ctx.getIArguments()->push_back(1); // previous variant was "ctx.getIArguments()->push_back(0)" and this caused fail + ctx.appendI(0); + ctx.appendI(1); // previous variant was "ctx.appendI(0)" and this caused fail - auto shapes = op.calculateOutputShape(&shapeList, ctx); - ASSERT_EQ(1, shapes->size()); + auto shapes = op.calculateOutputShape(&shapeList, ctx); + ASSERT_EQ(1, shapes->size()); - auto s = shapes->at(0); + auto s = shapes->at(0); - auto z = result.at(0); - // z->printShapeInfo("z shape"); + auto z = result.at(0); + // z->printShapeInfo("z shape"); - ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.isSameShape(z)); - delete shapes; + + delete shapes; } ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests1, conv3d_test10) { + int bS = 1, iD = 2, iH = 2, iW = 2, iC = 1, oC = 1, kD = 2, kH = 2, kW = 2, + sD = 1, sH = 1, sW = 1, pD = 0, pH = 0, pW = 0, dD = 1, dH = 1, dW = 1; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NDHWC, 0-NCDHW - int bS=1, iD=2,iH=2,iW=2, iC=1,oC=1, kD=2,kH=2,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; - int paddingMode = 1; // 1-SAME, 0-VALID; - int dataFormat = 0; // 1-NDHWC, 0-NCDHW - - auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); - auto weights = NDArrayFactory::create('c', {kD, kH, kW, iC, oC}); + auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); + auto weights = NDArrayFactory::create('c', {kD, kH, kW, iC, oC}); - input = 2.; - weights = 1.; + input = 2.; + weights = 1.; - sd::ops::conv3dnew op; - auto results = op.evaluate({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}); - auto* output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); + sd::ops::conv3dnew op; + auto results = op.evaluate({&input, &weights}, {}, + {kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, + paddingMode, dataFormat}); + auto output = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); } ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests1, conv3d_test11) { - - int bS=5, iD=4,iH=14,iW=14, iC=1,oC=1, kD=2,kH=2,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; - int oD=3,oH=13,oW=13; - int paddingMode = 0; // 1-SAME, 0-VALID; - int dataFormat = 0; // 1-NDHWC, 0-NCDHW - - auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); - auto weights = NDArrayFactory::create('c', {kD, kH, kW, iC, oC}); - auto expected = NDArrayFactory::create('c', {bS, oC, oD, oH, oW}); - - input = 2.; - weights = 1.; - - sd::ops::conv3dnew op; - auto results = op.evaluate({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}); - auto* output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(output->isSameShape(&expected)); - + int bS = 5, iD = 4, iH = 14, iW = 14, iC = 1, oC = 1, kD = 2, kH = 2, kW = 2, + sD = 1, sH = 1, sW = 1, pD = 0, pH = 0, pW = 0, dD = 1, dH = 1, dW = 1; + int oD = 3, oH = 13, oW = 13; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); + auto weights = NDArrayFactory::create('c', {kD, kH, kW, iC, oC}); + auto expected = NDArrayFactory::create('c', {bS, oC, oD, oH, oW}); + + input = 2.; + weights = 1.; + + sd::ops::conv3dnew op; + auto results = op.evaluate({&input, &weights}, {}, + {kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, + paddingMode, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(output.isSameShape(&expected)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests1, conv3d_test12) { - - int bS=2, iD=4,iH=3,iW=3, iC=4,oC=3, kD=3,kH=2,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; - int oD=2,oH=2,oW=2; - int paddingMode = 0; // 1-SAME, 0-VALID; - int dataFormat = 0; // 1-NHWC, 0-NCHW - int wFormat = 1; // 0-[kD, kH, kW, iC, oC], 1-[oC, iC, kD, kH, kW], 2-[oC, kD, kH, kW, iC] - - NDArray input('c', {bS, iC, iD, iH, iW}, sd::DataType::FLOAT32); - NDArray weights('c', {oC, iC, kD, kH, kW}, {-14.4, -13.2, -12.0, -10.8, -9.6, -8.4, -7.2, -6.0, -4.8, -3.6, -2.4, -1.2, -14.1, -12.9, -11.7, -10.5, -9.3, -8.1, - -6.9, -5.7, -4.5, -3.3, -2.1, -0.9, -13.8, -12.6, -11.4, -10.2, -9.0, -7.8, -6.6, -5.4, -4.2, -3.0, -1.8, -0.6, -13.5, -12.3, -11.1, -9.9, -8.7, -7.5, -6.3, - -5.1, -3.9, -2.7, -1.5, -0.3, -14.3, -13.1, -11.9, -10.7, -9.5, -8.3, -7.1, -5.9, -4.7, -3.5, -2.3, -1.1, -14.0, -12.8, -11.6, -10.4, -9.2, -8.0, -6.8, -5.6, - -4.4, -3.2, -2.0, -0.8, -13.7, -12.5, -11.3, -10.1, -8.9, -7.7, -6.5, -5.3, -4.1, -2.9, -1.7, -0.5, -13.4, -12.2, -11.0, -9.8, -8.6, -7.4, -6.2, -5.0, -3.8, -2.6, -1.4, -0.2, -14.2, -13.0, -11.8, -10.6, -9.4, -8.2, -7.0, -5.8, -4.6, -3.4, -2.2, -1.0, -13.9, -12.7, -11.5, -10.3, -9.1, -7.9, -6.7, -5.5, -4.3, -3.1, -1.9, -0.7, -13.6, -12.4, -11.2, -10.0, -8.8, -7.6, -6.4, -5.2, -4.0, -2.8, -1.6, -0.4, -13.3, -12.1, -10.9, -9.7, -8.5, -7.3, -6.1, -4.9, -3.7, -2.5, -1.3, -0.1}, sd::DataType::FLOAT32); - NDArray bias('c', {oC}, {-1,2,0.5}, sd::DataType::FLOAT32); - - NDArray expOutput('c', {bS, oC, oD, oH, oW}, {-42520.597656, -42344.199219, -41991.402344, -41814.996094, -40932.992188, -40756.597656, -40403.800781, -40227.406250, - -41953.601562, -41779.601562, -41431.597656, -41257.601562, -40387.601562, -40213.597656, -39865.601562, -39691.597656, -41391.105469, -41219.492188, - -40876.300781, -40704.699219, -39846.707031, -39675.097656, -39331.898438, -39160.300781, -17119.001953, -16942.599609, -16589.798828, -16413.400391, - -15531.399414, -15355.000000, -15002.199219, -14825.800781, -16897.597656, -16723.597656, -16375.599609, -16201.599609, -15331.599609, -15157.600586, - -14809.601562, -14635.598633, -16680.703125, -16509.099609, -16165.900391, -15994.300781, -15136.300781, -14964.700195, -14621.500000, -14449.900391}, sd::DataType::FLOAT32); - - input.linspace(150,-0.5); - - sd::ops::conv3dnew op; - auto results = op.evaluate({&input, &weights, &bias}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat, wFormat}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expOutput.isSameShape(output)); - ASSERT_TRUE(expOutput.equalsTo(output)); + int bS = 2, iD = 4, iH = 3, iW = 3, iC = 4, oC = 3, kD = 3, kH = 2, kW = 2, + sD = 1, sH = 1, sW = 1, pD = 0, pH = 0, pW = 0, dD = 1, dH = 1, dW = 1; + int oD = 2, oH = 2, oW = 2; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NHWC, 0-NCHW + int wFormat = 1; // 0-[kD, kH, kW, iC, oC], 1-[oC, iC, kD, kH, kW], 2-[oC, + // kD, kH, kW, iC] + + NDArray input('c', {bS, iC, iD, iH, iW}, sd::DataType::FLOAT32); + NDArray weights( + 'c', {oC, iC, kD, kH, kW}, + {-14.4, -13.2, -12.0, -10.8, -9.6, -8.4, -7.2, -6.0, -4.8, -3.6, + -2.4, -1.2, -14.1, -12.9, -11.7, -10.5, -9.3, -8.1, -6.9, -5.7, + -4.5, -3.3, -2.1, -0.9, -13.8, -12.6, -11.4, -10.2, -9.0, -7.8, + -6.6, -5.4, -4.2, -3.0, -1.8, -0.6, -13.5, -12.3, -11.1, -9.9, + -8.7, -7.5, -6.3, -5.1, -3.9, -2.7, -1.5, -0.3, -14.3, -13.1, + -11.9, -10.7, -9.5, -8.3, -7.1, -5.9, -4.7, -3.5, -2.3, -1.1, + -14.0, -12.8, -11.6, -10.4, -9.2, -8.0, -6.8, -5.6, -4.4, -3.2, + -2.0, -0.8, -13.7, -12.5, -11.3, -10.1, -8.9, -7.7, -6.5, -5.3, + -4.1, -2.9, -1.7, -0.5, -13.4, -12.2, -11.0, -9.8, -8.6, -7.4, + -6.2, -5.0, -3.8, -2.6, -1.4, -0.2, -14.2, -13.0, -11.8, -10.6, + -9.4, -8.2, -7.0, -5.8, -4.6, -3.4, -2.2, -1.0, -13.9, -12.7, + -11.5, -10.3, -9.1, -7.9, -6.7, -5.5, -4.3, -3.1, -1.9, -0.7, + -13.6, -12.4, -11.2, -10.0, -8.8, -7.6, -6.4, -5.2, -4.0, -2.8, + -1.6, -0.4, -13.3, -12.1, -10.9, -9.7, -8.5, -7.3, -6.1, -4.9, + -3.7, -2.5, -1.3, -0.1}, + sd::DataType::FLOAT32); + NDArray bias('c', {oC}, {-1, 2, 0.5}, sd::DataType::FLOAT32); + + NDArray expOutput( + 'c', {bS, oC, oD, oH, oW}, + {-42520.597656, -42344.199219, -41991.402344, -41814.996094, + -40932.992188, -40756.597656, -40403.800781, -40227.406250, + -41953.601562, -41779.601562, -41431.597656, -41257.601562, + -40387.601562, -40213.597656, -39865.601562, -39691.597656, + -41391.105469, -41219.492188, -40876.300781, -40704.699219, + -39846.707031, -39675.097656, -39331.898438, -39160.300781, + -17119.001953, -16942.599609, -16589.798828, -16413.400391, + -15531.399414, -15355.000000, -15002.199219, -14825.800781, + -16897.597656, -16723.597656, -16375.599609, -16201.599609, + -15331.599609, -15157.600586, -14809.601562, -14635.598633, + -16680.703125, -16509.099609, -16165.900391, -15994.300781, + -15136.300781, -14964.700195, -14621.500000, -14449.900391}, + sd::DataType::FLOAT32); + + input.linspace(150, -0.5); + + sd::ops::conv3dnew op; + auto results = op.evaluate({&input, &weights, &bias}, {}, + {kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, + paddingMode, dataFormat, wFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests1, conv3d_test13) { - - int bS=2, iD=4,iH=3,iW=3, iC=4,oC=3, kD=3,kH=2,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; - int oD=4,oH=3,oW=3; - int paddingMode = 1; // 1-SAME, 0-VALID; - int dataFormat = 1; // 1-NHWC, 0-NCHW - int wFormat = 2; // 0-[kD, kH, kW, iC, oC], 1-[oC, iC, kD, kH, kW], 2-[oC, kD, kH, kW, iC] - - NDArray input('c', {bS, iD, iH, iW, iC}, sd::DataType::FLOAT32); - NDArray weights('c', {oC, kD, kH, kW, iC}, {-7., -6.7, -6.4, -6.1, -5.8, -5.5, -5.2, -4.9, -4.6, -4.3, -4., -3.7, -3.4, -3.1, -2.8, -2.5, -2.2, -1.9, -1.6, -1.3, - -1., -0.7, -0.4, -0.1, 0.2, 0.5, 0.8, 1.1, 1.4, 1.7, 2., 2.3, 2.6, 2.9, 3.2, 3.5, 3.8, 4.1, 4.4, 4.7, 5., 5.3, 5.6, 5.9, 6.2, 6.5, 6.8, 7.1, -6.9, -6.6, -6.3, - -6., -5.7, -5.4, -5.1, -4.8, -4.5, -4.2, -3.9, -3.6, -3.3, -3., -2.7, -2.4, -2.1, -1.8, -1.5, -1.2, -0.9, -0.6, -0.3, 0., 0.3, 0.6, 0.9, 1.2, 1.5, 1.8, 2.1, - 2.4, 2.7, 3., 3.3, 3.6, 3.9, 4.2, 4.5, 4.8, 5.1, 5.4, 5.7, 6., 6.3, 6.6, 6.9, 7.2, -6.8, -6.5, -6.2, -5.9, -5.6, -5.3, -5., -4.7, -4.4, -4.1, -3.8, -3.5, -3.2, - -2.9, -2.6, -2.3, -2., -1.7, -1.4, -1.1, -0.8, -0.5, -0.2, 0.1, 0.4, 0.7, 1., 1.3, 1.6, 1.9, 2.2, 2.5, 2.8, 3.1, 3.4, 3.7, 4., 4.3, 4.6, 4.9, 5.2, 5.5, 5.8, 6.1, 6.4, 6.7, 7., 7.3}, sd::DataType::FLOAT32); - NDArray bias('c', {oC}, {-1,2,0.5}, sd::DataType::FLOAT32); - - NDArray expOutput('c', {bS, oD, oH, oW, oC}, {3969.399658, 4168.399902, 4362.899414, 3812.600586, 4005.200195, 4193.299805, 1317.000000, 1413.199829, 1504.899902, - 3498.999756, 3678.800049, 3854.100098, 3342.200195, 3515.599854, 3684.500244, 1139.400024, 1226.000000, 1308.099976, 685.799927, 772.400024, 854.500000, - 645.800049, 729.200073, 808.099976, 80.799995, 123.200012, 161.100006, -2851.000732, -2597.199707, -2347.899414, -2855.799805, -2611.600098, -2371.900879, - -2124.399414, -2003.199951, -1886.500244, -2865.399902, -2640.400146, -2419.899902, -2870.199951, -2654.800049, -2443.899902, -2045.200073, -1938.399902, - -1836.100220, -2596.000244, -2489.199707, -2386.900146, -2540.799561, -2438.800049, -2341.300049, -1539.699951, -1488.400024, -1441.599854, -2894.200195, - -2726.800049, -2563.899902, -2899.000488, -2741.199707, -2587.899658, -1886.800171, -1808.800049, -1735.300171, -2908.599121, -2770.000488, -2635.900146, -2913.400146, -2784.399658, -2659.899902, -1807.599976, -1743.999878, -1684.900146, -2099.199951, -2035.599976, -1976.500366, -2044.000244, -1985.199707, -1930.900024, -1161.699951, -1132.000122, -1106.800171, -2731.399902, -2647.599609, -2568.300293, -2580.999756, -2503.600098, -2430.699951, -1457.400024, -1418.800049, -1384.700073, -2280.200195, -2215.600098, -2155.500732, -2129.799561, -2071.600098, -2017.899780, -1174.200073, -1145.200195, -1120.699829, -1282.200073, -1253.199951, -1228.699951, -1168.599976, -1142.799927, -1121.500122, -615.199951, -601.600037, -592.500000, -1675.399658, -1706.800049, -1742.700073, -1832.200073, -1870.000000, -1912.299561, -814.199951, -833.200012, -856.699951, -2145.800049, -2196.399902, -2251.500244, -2302.600342, -2359.599854, -2421.100098, -991.800049, -1020.400024, -1053.500000, -754.199951, -782.800049, -815.900085, -794.199951, -825.999939, -862.299988, -293.600006, -308.800018, -328.500000, -3023.800293, -3115.600098, -3211.900391, -3028.599121, -3130.000244, -3235.899902, -1173.999878, -1225.600098, -1281.699951, -3038.200195, -3158.799805, -3283.899902, -3043.000000, -3173.199707, -3307.900391, -1094.800049, -1160.800049, -1231.300049, -608.799988, -674.799988, -745.300049, -553.599976, -624.400024, -699.700012, -27.700012, -62.799988, -102.400009, -3066.999512, -3245.199707, -3427.900391, -3071.800293, -3259.599854, -3451.900146, -936.400085, -1031.199951, -1130.500000, -3081.400146, -3288.400635, -3499.899414, -3086.200439, -3302.799805, -3523.899902, -857.199951, -966.400024, -1080.099976, -111.999969, -221.199936, -334.900024, -56.800079, -170.799988, -289.299927, 350.299927, 293.600037, 232.399979, 2683.000244, 2536.400146, 2385.300049, 2833.399658, 2680.400391, 2522.900391, 1940.999878, 1864.399902, 1783.300049, 3134.200195, 2968.399414, 2798.100098, 3284.600098, 3112.400391, 2935.699707, 2224.199707, 2138.000244, 2047.300049, 2807.399658, 2721.200195, 2630.500000, 2921.000000, 2831.599854, 2737.699707, 1775.200195, 1731.199951, 1682.699829}, sd::DataType::FLOAT32); - - input.linspace(75,-0.5); - - sd::ops::conv3dnew op; - auto results = op.evaluate({&input, &weights, &bias}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat, wFormat}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expOutput.isSameShape(output)); - ASSERT_TRUE(expOutput.equalsTo(output)); + int bS = 2, iD = 4, iH = 3, iW = 3, iC = 4, oC = 3, kD = 3, kH = 2, kW = 2, + sD = 1, sH = 1, sW = 1, pD = 0, pH = 0, pW = 0, dD = 1, dH = 1, dW = 1; + int oD = 4, oH = 3, oW = 3; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NHWC, 0-NCHW + int wFormat = 2; // 0-[kD, kH, kW, iC, oC], 1-[oC, iC, kD, kH, kW], 2-[oC, + // kD, kH, kW, iC] + + NDArray input('c', {bS, iD, iH, iW, iC}, sd::DataType::FLOAT32); + NDArray weights( + 'c', {oC, kD, kH, kW, iC}, + {-7., -6.7, -6.4, -6.1, -5.8, -5.5, -5.2, -4.9, -4.6, -4.3, -4., -3.7, + -3.4, -3.1, -2.8, -2.5, -2.2, -1.9, -1.6, -1.3, -1., -0.7, -0.4, -0.1, + 0.2, 0.5, 0.8, 1.1, 1.4, 1.7, 2., 2.3, 2.6, 2.9, 3.2, 3.5, + 3.8, 4.1, 4.4, 4.7, 5., 5.3, 5.6, 5.9, 6.2, 6.5, 6.8, 7.1, + -6.9, -6.6, -6.3, -6., -5.7, -5.4, -5.1, -4.8, -4.5, -4.2, -3.9, -3.6, + -3.3, -3., -2.7, -2.4, -2.1, -1.8, -1.5, -1.2, -0.9, -0.6, -0.3, 0., + 0.3, 0.6, 0.9, 1.2, 1.5, 1.8, 2.1, 2.4, 2.7, 3., 3.3, 3.6, + 3.9, 4.2, 4.5, 4.8, 5.1, 5.4, 5.7, 6., 6.3, 6.6, 6.9, 7.2, + -6.8, -6.5, -6.2, -5.9, -5.6, -5.3, -5., -4.7, -4.4, -4.1, -3.8, -3.5, + -3.2, -2.9, -2.6, -2.3, -2., -1.7, -1.4, -1.1, -0.8, -0.5, -0.2, 0.1, + 0.4, 0.7, 1., 1.3, 1.6, 1.9, 2.2, 2.5, 2.8, 3.1, 3.4, 3.7, + 4., 4.3, 4.6, 4.9, 5.2, 5.5, 5.8, 6.1, 6.4, 6.7, 7., 7.3}, + sd::DataType::FLOAT32); + NDArray bias('c', {oC}, {-1, 2, 0.5}, sd::DataType::FLOAT32); + + NDArray expOutput( + 'c', {bS, oD, oH, oW, oC}, + {3969.399658, 4168.399902, 4362.899414, 3812.600586, 4005.200195, + 4193.299805, 1317.000000, 1413.199829, 1504.899902, 3498.999756, + 3678.800049, 3854.100098, 3342.200195, 3515.599854, 3684.500244, + 1139.400024, 1226.000000, 1308.099976, 685.799927, 772.400024, + 854.500000, 645.800049, 729.200073, 808.099976, 80.799995, + 123.200012, 161.100006, -2851.000732, -2597.199707, -2347.899414, + -2855.799805, -2611.600098, -2371.900879, -2124.399414, -2003.199951, + -1886.500244, -2865.399902, -2640.400146, -2419.899902, -2870.199951, + -2654.800049, -2443.899902, -2045.200073, -1938.399902, -1836.100220, + -2596.000244, -2489.199707, -2386.900146, -2540.799561, -2438.800049, + -2341.300049, -1539.699951, -1488.400024, -1441.599854, -2894.200195, + -2726.800049, -2563.899902, -2899.000488, -2741.199707, -2587.899658, + -1886.800171, -1808.800049, -1735.300171, -2908.599121, -2770.000488, + -2635.900146, -2913.400146, -2784.399658, -2659.899902, -1807.599976, + -1743.999878, -1684.900146, -2099.199951, -2035.599976, -1976.500366, + -2044.000244, -1985.199707, -1930.900024, -1161.699951, -1132.000122, + -1106.800171, -2731.399902, -2647.599609, -2568.300293, -2580.999756, + -2503.600098, -2430.699951, -1457.400024, -1418.800049, -1384.700073, + -2280.200195, -2215.600098, -2155.500732, -2129.799561, -2071.600098, + -2017.899780, -1174.200073, -1145.200195, -1120.699829, -1282.200073, + -1253.199951, -1228.699951, -1168.599976, -1142.799927, -1121.500122, + -615.199951, -601.600037, -592.500000, -1675.399658, -1706.800049, + -1742.700073, -1832.200073, -1870.000000, -1912.299561, -814.199951, + -833.200012, -856.699951, -2145.800049, -2196.399902, -2251.500244, + -2302.600342, -2359.599854, -2421.100098, -991.800049, -1020.400024, + -1053.500000, -754.199951, -782.800049, -815.900085, -794.199951, + -825.999939, -862.299988, -293.600006, -308.800018, -328.500000, + -3023.800293, -3115.600098, -3211.900391, -3028.599121, -3130.000244, + -3235.899902, -1173.999878, -1225.600098, -1281.699951, -3038.200195, + -3158.799805, -3283.899902, -3043.000000, -3173.199707, -3307.900391, + -1094.800049, -1160.800049, -1231.300049, -608.799988, -674.799988, + -745.300049, -553.599976, -624.400024, -699.700012, -27.700012, + -62.799988, -102.400009, -3066.999512, -3245.199707, -3427.900391, + -3071.800293, -3259.599854, -3451.900146, -936.400085, -1031.199951, + -1130.500000, -3081.400146, -3288.400635, -3499.899414, -3086.200439, + -3302.799805, -3523.899902, -857.199951, -966.400024, -1080.099976, + -111.999969, -221.199936, -334.900024, -56.800079, -170.799988, + -289.299927, 350.299927, 293.600037, 232.399979, 2683.000244, + 2536.400146, 2385.300049, 2833.399658, 2680.400391, 2522.900391, + 1940.999878, 1864.399902, 1783.300049, 3134.200195, 2968.399414, + 2798.100098, 3284.600098, 3112.400391, 2935.699707, 2224.199707, + 2138.000244, 2047.300049, 2807.399658, 2721.200195, 2630.500000, + 2921.000000, 2831.599854, 2737.699707, 1775.200195, 1731.199951, + 1682.699829}, + sd::DataType::FLOAT32); + + input.linspace(75, -0.5); + + sd::ops::conv3dnew op; + auto results = op.evaluate({&input, &weights, &bias}, {}, + {kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, + paddingMode, dataFormat, wFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests1, pointwise_conv2d_test1) { - - int bS=2, iH=4,iW=3, iC=4,oC=3; - - int dataFormat = 1; // 1-NHWC, 0-NCHW - - auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); - auto weights = NDArrayFactory::create('c', {1, 1, iC, oC}); - auto bias = NDArrayFactory::create('c', {oC}); - - - auto expOutput = NDArrayFactory::create('c', {bS, iH, iW, oC},{ 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, - 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, - 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, - 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f}); - input = 2.; - weights.linspace(0.1, 0.1); - bias = 1.; - - sd::ops::pointwise_conv2d op; - auto results = op.evaluate({&input, &weights, &bias}, {}, {dataFormat}); - auto* output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expOutput.isSameShape(output)); - ASSERT_TRUE(expOutput.equalsTo(output)); - + int bS = 2, iH = 4, iW = 3, iC = 4, oC = 3; + + int dataFormat = 1; // 1-NHWC, 0-NCHW + + auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); + auto weights = NDArrayFactory::create('c', {1, 1, iC, oC}); + auto bias = NDArrayFactory::create('c', {oC}); + + auto expOutput = NDArrayFactory::create( + 'c', {bS, iH, iW, oC}, + {5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, + 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, + 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, + 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, + 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, + 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f}); + input = 2.; + weights.linspace(0.1, 0.1); + bias = 1.; + + sd::ops::pointwise_conv2d op; + auto results = op.evaluate({&input, &weights, &bias}, {}, {dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests1, vol2col_test1) { - - int bS=2, iD=2,iH=3,iW=2, iC=3,oC=2, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; - int oD=2,oH=3,oW=2; - - NDArray volume('c', {bS, iC, iD, iH, iW}, sd::DataType::FLOAT32); - NDArray columns('c', {bS, iC, kD, kH, kW, oD, oH, oW}, sd::DataType::FLOAT32); - - columns = -1.; - volume.linspace(1); - - NDArray columnsExpected('c', {bS, iC, kD, kH, kW, oD, oH, oW}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 2., 0., 4., 0., 6.,0., 8., 0., 10., 0., 12., 0., 3., 4., 5., 6., 0., 0., 9., 10., 11., 12., 0., 0., 4., 0., 6., 0., 0., 0., 10., 0., 12., 0., 0., 0., 5., 6., -0., 0., 0., 0., 11., 12., 0., 0., 0., 0., 6., 0., 0., 0., 0., 0., 12., 0., 0., 0., 0., 0., 7., 8., 9., 10., 11., 12., 0., 0., 0., 0., 0., 0., 8., 0., 10., 0., 12., 0., 0., 0., 0., 0., 0., 0., 9., 10., 11., 12., 0., 0., 0., 0., 0., 0., 0., 0., 10., 0., 12., 0., 0., 0., 0., 0., -0., 0., 0., 0., 11., 12., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 12., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 13., 14., 15., 16., 17.,18., 19., 20., 21., 22., 23., 24., 14., 0., 16., 0., 18., 0., 20., 0., 22., 0., 24., 0., 15., 16., 17., 18., 0., 0., 21., 22., 23., 24., 0., -0., 16., 0., 18., 0., 0., 0., 22., 0., 24., 0., 0., 0., 17., 18., 0., 0., 0., 0., 23., 24., 0., 0., 0., 0., 18., 0., 0., 0., 0., 0., 24., 0., 0., 0., 0., 0., 19., 20., 21., 22., 23., 24., 0., 0., 0., 0., 0., 0., 20., 0., 22., 0., 24., 0., 0., 0., 0., 0., 0., 0., 21., 22., 23., -24., 0., 0., 0., 0., 0., 0., 0., 0., 22., 0., 24., 0., 0., 0., 0., 0., 0., 0., 0., 0., 23., 24., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 24.,0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 25., 26., 27., 28., 29., 30., 31., 32., 33., 34., 35., 36., 26., 0., 28., 0., 30., 0., 32., 0., -34., 0., 36., 0., 27., 28., 29., 30., 0., 0., 33., 34., 35., 36., 0., 0., 28., 0., 30., 0., 0., 0., 34., 0., 36., 0., 0., 0., 29., 30., 0., 0., 0., 0., 35., 36., 0., 0., 0., 0., 30., 0., 0., 0., 0., 0., 36., 0., 0., 0., 0., 0., 31., 32., 33., 34., 35., 36., 0., 0., 0., 0., 0., -0., 32., 0., 34., 0., 36., 0., 0., 0., 0., 0., 0., 0., 33., 34., 35., 36., 0., 0., 0., 0., 0., 0., 0., 0., 34., 0., 36., 0., 0., 0., 0., 0., 0., 0., 0., 0., 35., 36., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 36., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 37., 38., 39., 40., -41., 42., 43., 44., 45., 46., 47., 48., 38., 0., 40., 0., 42., 0., 44., 0., 46., 0., 48., 0., 39., 40., 41., 42., 0., 0., 45., 46., 47., 48., 0., 0., 40., 0., 42., 0., 0., 0., 46., 0., 48., 0., 0., 0., 41., 42., 0., 0., 0., 0., 47., 48., 0., 0., 0., 0., 42., 0., 0., 0., 0., -0., 48., 0., 0., 0., 0., 0., 43., 44., 45., 46., 47., 48., 0., 0., 0., 0., 0., 0., 44., 0., 46., 0., 48., 0., 0., 0., 0., 0., 0., 0., 45., 46., 47., 48., 0., 0., 0., 0., 0., 0., 0., 0., 46., 0., 48., 0., 0., 0., 0., 0., 0., 0., 0., 0., 47., 48., 0., 0., 0., 0., 0., 0., 0., 0., -0., 0., 48., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 49., 50., 51., 52., 53., 54., 55., 56., 57., 58., 59., 60., 50., 0., 52., 0., 54.,0., 56., 0., 58., 0., 60., 0., 51., 52., 53., 54., 0., 0., 57., 58., 59., 60., 0., 0., 52., 0., 54., 0., 0., 0., 58., 0., 60., 0., 0., 0., -53., 54., 0., 0., 0., 0., 59., 60., 0., 0., 0., 0., 54., 0., 0., 0., 0., 0., 60., 0., 0., 0., 0., 0., 55., 56., 57., 58., 59., 60., 0., 0.,0., 0., 0., 0., 56., 0., 58., 0., 60., 0., 0., 0., 0., 0., 0., 0., 57., 58., 59., 60., 0., 0., 0., 0., 0., 0., 0., 0., 58., 0., 60., 0., -0., 0., 0., 0., 0., 0., 0., 0., 59., 60., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 60., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 61., 62., 63., 64., 65., 66., 67., 68., 69., 70., 71., 72., 62., 0., 64., 0., 66., 0., 68., 0., 70., 0., 72., 0., 63., 64., 65., 66., 0., 0., 69., -70., 71., 72., 0., 0., 64., 0., 66., 0., 0., 0., 70., 0., 72., 0., 0., 0., 65., 66., 0., 0., 0., 0., 71., 72., 0., 0., 0., 0., 66., 0., 0., 0., 0., 0., 72., 0., 0., 0., 0., 0., 67., 68., 69., 70., 71., 72., 0., 0., 0., 0., 0., 0., 68., 0., 70., 0., 72., 0., 0., 0., 0., 0., 0., -0., 69., 70., 71., 72., 0., 0., 0., 0., 0., 0., 0., 0., 70., 0., 72., 0., 0., 0., 0., 0., 0., 0., 0., 0., 71., 72., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 72., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.}, sd::DataType::FLOAT32); - - graph::Context context(1); - sd::ops::ConvolutionUtils::vol2col(context, volume, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW); - // columns.printBuffer(); - - ASSERT_TRUE(columns.equalsTo(columnsExpected)); + int bS = 2, iD = 2, iH = 3, iW = 2, iC = 3, oC = 2, kD = 2, kH = 3, kW = 2, + sD = 1, sH = 1, sW = 1, pD = 0, pH = 0, pW = 0, dD = 1, dH = 1, dW = 1; + int oD = 2, oH = 3, oW = 2; + + NDArray volume('c', {bS, iC, iD, iH, iW}, sd::DataType::FLOAT32); + NDArray columns('c', {bS, iC, kD, kH, kW, oD, oH, oW}, sd::DataType::FLOAT32); + + columns = -1.; + volume.linspace(1); + + NDArray columnsExpected( + 'c', {bS, iC, kD, kH, kW, oD, oH, oW}, + {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 2., 0., + 4., 0., 6., 0., 8., 0., 10., 0., 12., 0., 3., 4., 5., 6., + 0., 0., 9., 10., 11., 12., 0., 0., 4., 0., 6., 0., 0., 0., + 10., 0., 12., 0., 0., 0., 5., 6., 0., 0., 0., 0., 11., 12., + 0., 0., 0., 0., 6., 0., 0., 0., 0., 0., 12., 0., 0., 0., + 0., 0., 7., 8., 9., 10., 11., 12., 0., 0., 0., 0., 0., 0., + 8., 0., 10., 0., 12., 0., 0., 0., 0., 0., 0., 0., 9., 10., + 11., 12., 0., 0., 0., 0., 0., 0., 0., 0., 10., 0., 12., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 11., 12., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., 12., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., + 23., 24., 14., 0., 16., 0., 18., 0., 20., 0., 22., 0., 24., 0., + 15., 16., 17., 18., 0., 0., 21., 22., 23., 24., 0., 0., 16., 0., + 18., 0., 0., 0., 22., 0., 24., 0., 0., 0., 17., 18., 0., 0., + 0., 0., 23., 24., 0., 0., 0., 0., 18., 0., 0., 0., 0., 0., + 24., 0., 0., 0., 0., 0., 19., 20., 21., 22., 23., 24., 0., 0., + 0., 0., 0., 0., 20., 0., 22., 0., 24., 0., 0., 0., 0., 0., + 0., 0., 21., 22., 23., 24., 0., 0., 0., 0., 0., 0., 0., 0., + 22., 0., 24., 0., 0., 0., 0., 0., 0., 0., 0., 0., 23., 24., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 24., 0., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 25., 26., 27., 28., 29., 30., + 31., 32., 33., 34., 35., 36., 26., 0., 28., 0., 30., 0., 32., 0., + 34., 0., 36., 0., 27., 28., 29., 30., 0., 0., 33., 34., 35., 36., + 0., 0., 28., 0., 30., 0., 0., 0., 34., 0., 36., 0., 0., 0., + 29., 30., 0., 0., 0., 0., 35., 36., 0., 0., 0., 0., 30., 0., + 0., 0., 0., 0., 36., 0., 0., 0., 0., 0., 31., 32., 33., 34., + 35., 36., 0., 0., 0., 0., 0., 0., 32., 0., 34., 0., 36., 0., + 0., 0., 0., 0., 0., 0., 33., 34., 35., 36., 0., 0., 0., 0., + 0., 0., 0., 0., 34., 0., 36., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 35., 36., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 36., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 37., 38., + 39., 40., 41., 42., 43., 44., 45., 46., 47., 48., 38., 0., 40., 0., + 42., 0., 44., 0., 46., 0., 48., 0., 39., 40., 41., 42., 0., 0., + 45., 46., 47., 48., 0., 0., 40., 0., 42., 0., 0., 0., 46., 0., + 48., 0., 0., 0., 41., 42., 0., 0., 0., 0., 47., 48., 0., 0., + 0., 0., 42., 0., 0., 0., 0., 0., 48., 0., 0., 0., 0., 0., + 43., 44., 45., 46., 47., 48., 0., 0., 0., 0., 0., 0., 44., 0., + 46., 0., 48., 0., 0., 0., 0., 0., 0., 0., 45., 46., 47., 48., + 0., 0., 0., 0., 0., 0., 0., 0., 46., 0., 48., 0., 0., 0., + 0., 0., 0., 0., 0., 0., 47., 48., 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 48., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 49., 50., 51., 52., 53., 54., 55., 56., 57., 58., 59., 60., + 50., 0., 52., 0., 54., 0., 56., 0., 58., 0., 60., 0., 51., 52., + 53., 54., 0., 0., 57., 58., 59., 60., 0., 0., 52., 0., 54., 0., + 0., 0., 58., 0., 60., 0., 0., 0., 53., 54., 0., 0., 0., 0., + 59., 60., 0., 0., 0., 0., 54., 0., 0., 0., 0., 0., 60., 0., + 0., 0., 0., 0., 55., 56., 57., 58., 59., 60., 0., 0., 0., 0., + 0., 0., 56., 0., 58., 0., 60., 0., 0., 0., 0., 0., 0., 0., + 57., 58., 59., 60., 0., 0., 0., 0., 0., 0., 0., 0., 58., 0., + 60., 0., 0., 0., 0., 0., 0., 0., 0., 0., 59., 60., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 60., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., 61., 62., 63., 64., 65., 66., 67., 68., + 69., 70., 71., 72., 62., 0., 64., 0., 66., 0., 68., 0., 70., 0., + 72., 0., 63., 64., 65., 66., 0., 0., 69., 70., 71., 72., 0., 0., + 64., 0., 66., 0., 0., 0., 70., 0., 72., 0., 0., 0., 65., 66., + 0., 0., 0., 0., 71., 72., 0., 0., 0., 0., 66., 0., 0., 0., + 0., 0., 72., 0., 0., 0., 0., 0., 67., 68., 69., 70., 71., 72., + 0., 0., 0., 0., 0., 0., 68., 0., 70., 0., 72., 0., 0., 0., + 0., 0., 0., 0., 69., 70., 71., 72., 0., 0., 0., 0., 0., 0., + 0., 0., 70., 0., 72., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 71., 72., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 72., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.}, + sd::DataType::FLOAT32); + + graph::Context context(1); + sd::ops::ConvolutionUtils::vol2col(context, volume, columns, sD, sH, sW, pD, + pH, pW, dD, dH, dW); + // columns.printBuffer(); + + ASSERT_TRUE(columns.equalsTo(columnsExpected)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests1, vol2col_test2) { - - int bS=2, iD=2,iH=3,iW=2, iC=3,oC=2, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; - int oD=2,oH=3,oW=2; - - auto volume = NDArrayFactory::create('c', {iD, bS, iH, iC, iW}); - volume.permutei({1, 3, 0, 2, 4}); - volume.linspace(1); - - auto columns = NDArrayFactory::create('c', {kD, iC, kH, oW, kW, bS, oD, oH}); - columns.permutei({5, 1, 0, 2, 4, 6, 7, 3}); - columns = -1.; - auto columnsExpected = NDArrayFactory::create('c', {bS, iC, kD, kH, kW, oD, oH, oW}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, -10.f, 11.f, 12.f, 2.f, 0.f, 4.f, 0.f, 6.f, 0.f, 8.f, 0.f, 10.f, 0.f, 12.f, 0.f, 3.f, 4.f, 5.f, 6.f, 0.f, 0.f, 9.f, 10.f, 11.f, 12.f, 0.f, 0.f, 4.f, 0.f, 6.f, 0.f, 0.f, 0.f, 10.f, 0.f, 12.f, 0.f, 0.f, 0.f, 5.f, 6.f, 0.f, 0.f, 0.f, 0.f, 11.f, 12.f, 0.f, 0.f, 0.f, 0.f, 6.f, 0.f, 0.f, 0.f, 0.f, 0.f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 7.f, 8.f, -9.f, 10.f, 11.f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 8.f, 0.f, 10.f, 0.f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 9.f, 10.f, 11.f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 10.f, 0.f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 11.f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, -0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 14.f, 0.f, 16.f, 0.f, 18.f, 0.f, 20.f, 0.f, 22.f, 0.f, 24.f, 0.f, 15.f, 16.f, 17.f, 18.f, 0.f, 0.f, 21.f, 22.f, 23.f, 24.f, 0.f, 0.f, 16.f, 0.f, 18.f, 0.f, 0.f, 0.f, 22.f, 0.f, 24.f, 0.f, 0.f, 0.f, 17.f, 18.f, 0.f, 0.f, 0.f, 0.f, -23.f, 24.f, 0.f, 0.f, 0.f, 0.f, 18.f, 0.f, 0.f, 0.f, 0.f, 0.f, 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 20.f, 0.f, 22.f, 0.f, 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 21.f, 22.f, 23.f, 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 22.f, 0.f, 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, -0.f, 0.f, 0.f, 23.f, 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 26.f, 0.f, 28.f, 0.f, 30.f, 0.f, 32.f, 0.f, 34.f, 0.f, 36.f, 0.f, 27.f, 28.f, 29.f, 30.f, 0.f, 0.f, 33.f, 34.f, 35.f, 36.f, -0.f, 0.f, 28.f, 0.f, 30.f, 0.f, 0.f, 0.f, 34.f, 0.f, 36.f, 0.f, 0.f, 0.f, 29.f, 30.f, 0.f, 0.f, 0.f, 0.f, 35.f, 36.f, 0.f, 0.f, 0.f, 0.f, 30.f, 0.f, 0.f, 0.f, 0.f, 0.f, 36.f, 0.f, 0.f, 0.f, 0.f, 0.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 32.f, 0.f, 34.f, 0.f, 36.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 33.f, -34.f, 35.f, 36.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 34.f, 0.f, 36.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 35.f, 36.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 36.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 38.f, 0.f, 40.f, -0.f, 42.f, 0.f, 44.f, 0.f, 46.f, 0.f, 48.f, 0.f, 39.f, 40.f, 41.f, 42.f, 0.f, 0.f, 45.f, 46.f, 47.f, 48.f, 0.f, 0.f, 40.f, 0.f, 42.f, 0.f, 0.f, 0.f, 46.f, 0.f, 48.f, 0.f, 0.f, 0.f, 41.f, 42.f, 0.f, 0.f, 0.f, 0.f, 47.f, 48.f, 0.f, 0.f, 0.f, 0.f, 42.f, 0.f, 0.f, 0.f, 0.f, 0.f, 48.f, 0.f, 0.f, 0.f, 0.f, 0.f, 43.f, 44.f, 45.f, 46.f, 47.f, -48.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 44.f, 0.f, 46.f, 0.f, 48.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 45.f, 46.f, 47.f, 48.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 46.f, 0.f, 48.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 47.f, 48.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 48.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, -0.f, 0.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 50.f, 0.f, 52.f, 0.f, 54.f, 0.f, 56.f, 0.f, 58.f, 0.f, 60.f, 0.f, 51.f, 52.f, 53.f, 54.f, 0.f, 0.f, 57.f, 58.f, 59.f, 60.f, 0.f, 0.f, 52.f, 0.f, 54.f, 0.f, 0.f, 0.f, 58.f, 0.f, 60.f, 0.f, 0.f, 0.f, 53.f, 54.f, 0.f, 0.f, 0.f, 0.f, 59.f, 60.f, 0.f, 0.f, -0.f, 0.f, 54.f, 0.f, 0.f, 0.f, 0.f, 0.f, 60.f, 0.f, 0.f, 0.f, 0.f, 0.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 56.f, 0.f, 58.f, 0.f, 60.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 57.f, 58.f, 59.f, 60.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 58.f, 0.f, 60.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 59.f, 60.f, -0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 60.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 62.f, 0.f, 64.f, 0.f, 66.f, 0.f, 68.f, 0.f, 70.f, 0.f, 72.f, 0.f, 63.f, 64.f, 65.f, 66.f, 0.f, 0.f, 69.f, 70.f, 71.f, 72.f, 0.f, 0.f, 64.f, 0.f, 66.f, -0.f, 0.f, 0.f, 70.f, 0.f, 72.f, 0.f, 0.f, 0.f, 65.f, 66.f, 0.f, 0.f, 0.f, 0.f, 71.f, 72.f, 0.f, 0.f, 0.f, 0.f, 66.f, 0.f, 0.f, 0.f, 0.f, 0.f, 72.f, 0.f, 0.f, 0.f, 0.f, 0.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 68.f, 0.f, 70.f, 0.f, 72.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 69.f, 70.f, 71.f, 72.f, 0.f, 0.f, -0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 70.f, 0.f, 72.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 71.f, 72.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 72.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}); - - graph::Context context(1); - sd::ops::ConvolutionUtils::vol2col(context, volume, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW); - // columns.printBuffer(); - - ASSERT_TRUE(columns.equalsTo(columnsExpected)); + int bS = 2, iD = 2, iH = 3, iW = 2, iC = 3, oC = 2, kD = 2, kH = 3, kW = 2, + sD = 1, sH = 1, sW = 1, pD = 0, pH = 0, pW = 0, dD = 1, dH = 1, dW = 1; + int oD = 2, oH = 3, oW = 2; + + auto volume = NDArrayFactory::create('c', {iD, bS, iH, iC, iW}); + volume.permutei({1, 3, 0, 2, 4}); + volume.linspace(1); + + auto columns = + NDArrayFactory::create('c', {kD, iC, kH, oW, kW, bS, oD, oH}); + columns.permutei({5, 1, 0, 2, 4, 6, 7, 3}); + columns = -1.; + auto columnsExpected = NDArrayFactory::create( + 'c', {bS, iC, kD, kH, kW, oD, oH, oW}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, + 2.f, 0.f, 4.f, 0.f, 6.f, 0.f, 8.f, 0.f, 10.f, 0.f, 12.f, 0.f, + 3.f, 4.f, 5.f, 6.f, 0.f, 0.f, 9.f, 10.f, 11.f, 12.f, 0.f, 0.f, + 4.f, 0.f, 6.f, 0.f, 0.f, 0.f, 10.f, 0.f, 12.f, 0.f, 0.f, 0.f, + 5.f, 6.f, 0.f, 0.f, 0.f, 0.f, 11.f, 12.f, 0.f, 0.f, 0.f, 0.f, + 6.f, 0.f, 0.f, 0.f, 0.f, 0.f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 8.f, 0.f, 10.f, 0.f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 9.f, 10.f, 11.f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 10.f, 0.f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 11.f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, + 14.f, 0.f, 16.f, 0.f, 18.f, 0.f, 20.f, 0.f, 22.f, 0.f, 24.f, 0.f, + 15.f, 16.f, 17.f, 18.f, 0.f, 0.f, 21.f, 22.f, 23.f, 24.f, 0.f, 0.f, + 16.f, 0.f, 18.f, 0.f, 0.f, 0.f, 22.f, 0.f, 24.f, 0.f, 0.f, 0.f, + 17.f, 18.f, 0.f, 0.f, 0.f, 0.f, 23.f, 24.f, 0.f, 0.f, 0.f, 0.f, + 18.f, 0.f, 0.f, 0.f, 0.f, 0.f, 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 20.f, 0.f, 22.f, 0.f, 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 21.f, 22.f, 23.f, 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 22.f, 0.f, 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 23.f, 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, + 26.f, 0.f, 28.f, 0.f, 30.f, 0.f, 32.f, 0.f, 34.f, 0.f, 36.f, 0.f, + 27.f, 28.f, 29.f, 30.f, 0.f, 0.f, 33.f, 34.f, 35.f, 36.f, 0.f, 0.f, + 28.f, 0.f, 30.f, 0.f, 0.f, 0.f, 34.f, 0.f, 36.f, 0.f, 0.f, 0.f, + 29.f, 30.f, 0.f, 0.f, 0.f, 0.f, 35.f, 36.f, 0.f, 0.f, 0.f, 0.f, + 30.f, 0.f, 0.f, 0.f, 0.f, 0.f, 36.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 32.f, 0.f, 34.f, 0.f, 36.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 33.f, 34.f, 35.f, 36.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 34.f, 0.f, 36.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 35.f, 36.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 36.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, + 38.f, 0.f, 40.f, 0.f, 42.f, 0.f, 44.f, 0.f, 46.f, 0.f, 48.f, 0.f, + 39.f, 40.f, 41.f, 42.f, 0.f, 0.f, 45.f, 46.f, 47.f, 48.f, 0.f, 0.f, + 40.f, 0.f, 42.f, 0.f, 0.f, 0.f, 46.f, 0.f, 48.f, 0.f, 0.f, 0.f, + 41.f, 42.f, 0.f, 0.f, 0.f, 0.f, 47.f, 48.f, 0.f, 0.f, 0.f, 0.f, + 42.f, 0.f, 0.f, 0.f, 0.f, 0.f, 48.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 44.f, 0.f, 46.f, 0.f, 48.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 45.f, 46.f, 47.f, 48.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 46.f, 0.f, 48.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 47.f, 48.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 48.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, + 50.f, 0.f, 52.f, 0.f, 54.f, 0.f, 56.f, 0.f, 58.f, 0.f, 60.f, 0.f, + 51.f, 52.f, 53.f, 54.f, 0.f, 0.f, 57.f, 58.f, 59.f, 60.f, 0.f, 0.f, + 52.f, 0.f, 54.f, 0.f, 0.f, 0.f, 58.f, 0.f, 60.f, 0.f, 0.f, 0.f, + 53.f, 54.f, 0.f, 0.f, 0.f, 0.f, 59.f, 60.f, 0.f, 0.f, 0.f, 0.f, + 54.f, 0.f, 0.f, 0.f, 0.f, 0.f, 60.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 56.f, 0.f, 58.f, 0.f, 60.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 57.f, 58.f, 59.f, 60.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 58.f, 0.f, 60.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 59.f, 60.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 60.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, + 62.f, 0.f, 64.f, 0.f, 66.f, 0.f, 68.f, 0.f, 70.f, 0.f, 72.f, 0.f, + 63.f, 64.f, 65.f, 66.f, 0.f, 0.f, 69.f, 70.f, 71.f, 72.f, 0.f, 0.f, + 64.f, 0.f, 66.f, 0.f, 0.f, 0.f, 70.f, 0.f, 72.f, 0.f, 0.f, 0.f, + 65.f, 66.f, 0.f, 0.f, 0.f, 0.f, 71.f, 72.f, 0.f, 0.f, 0.f, 0.f, + 66.f, 0.f, 0.f, 0.f, 0.f, 0.f, 72.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 68.f, 0.f, 70.f, 0.f, 72.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 69.f, 70.f, 71.f, 72.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 70.f, 0.f, 72.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 71.f, 72.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 72.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}); + + graph::Context context(1); + sd::ops::ConvolutionUtils::vol2col(context, volume, columns, sD, sH, sW, pD, + pH, pW, dD, dH, dW); + // columns.printBuffer(); + + ASSERT_TRUE(columns.equalsTo(columnsExpected)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests1, col2im_test1) { + int bS = 2, iH = 2, iW = 2, iC = 2, kH = 2, kW = 2, sD = 1, sH = 1, sW = 1, + pD = 0, pH = 0, pW = 0, dD = 1, dH = 1, dW = 1; + int oH = 2, oW = 2; - int bS=2, iH=2,iW=2, iC=2, kH=2,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; - int oH=2,oW=2; - - auto image = NDArrayFactory::create('c', {bS, iC, iH, iW}); - image = -2.; - - auto columns = NDArrayFactory::create('c', {bS, iC, kH, kW, oH, oW}); - columns.linspace(1); + auto image = NDArrayFactory::create('c', {bS, iC, iH, iW}); + image = -2.; - auto imageExpected = NDArrayFactory::create('c', {bS, iC, iH, iW}, {1.f, 7.f, 12.f, 34.f, 17.f, 39.f, 44.f, 98.f, 33.f, 71.f, 76.f, 162.f, 49.f, 103.f, 108.f, 226.f}); + auto columns = NDArrayFactory::create('c', {bS, iC, kH, kW, oH, oW}); + columns.linspace(1); + auto imageExpected = NDArrayFactory::create( + 'c', {bS, iC, iH, iW}, + {1.f, 7.f, 12.f, 34.f, 17.f, 39.f, 44.f, 98.f, 33.f, 71.f, 76.f, 162.f, + 49.f, 103.f, 108.f, 226.f}); - sd::ops::col2im op; - auto status = op.execute({&columns}, {&image}, {sH, sW, pH, pW, iH, iW, dH, dW, 0}); - ASSERT_EQ(Status::OK(), status); + sd::ops::col2im op; + auto status = + op.execute({&columns}, {&image}, {sH, sW, pH, pW, iH, iW, dH, dW, 0}); + ASSERT_EQ(Status::OK(), status); - ASSERT_TRUE(image.equalsTo(imageExpected)); + ASSERT_TRUE(image.equalsTo(imageExpected)); } - ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests1, upsampling2d_test1) { - - const int bS=3, iH=2,iW=2, iC=3; - const int factorH=2, factorW=3; - const int isNCHW = 0; // data format, default is NCHW - - auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); - input.linspace(1); - - auto expOutput = NDArrayFactory::create('c', {bS, iH*factorH, iW*factorW, iC}, {1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, - 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, - 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, - 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, - 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, - 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f}); - - sd::ops::upsampling2d op; - auto results = op.evaluate({&input}, {factorH, factorW, isNCHW}); - auto* output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expOutput.isSameShape(output)); - ASSERT_TRUE(expOutput.equalsTo(output)); - + const int bS = 3, iH = 2, iW = 2, iC = 3; + const int factorH = 2, factorW = 3; + const int isNCHW = 0; // data format, default is NCHW + + auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); + input.linspace(1); + + auto expOutput = NDArrayFactory::create( + 'c', {bS, iH * factorH, iW * factorW, iC}, + {1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, + 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, + 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, + 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, + 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, + 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, + 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, + 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, + 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, + 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, + 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, + 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, + 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, + 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, + 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, + 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, + 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, + 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f}); + + sd::ops::upsampling2d op; + auto results = op.evaluate({&input}, {factorH, factorW, isNCHW}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests1, upsampling2d_test2) { - - const int bS=3, iH=2,iW=2, iC=3; - const int factorH=2, factorW=3; - const int isNCHW = 1; // data format, default is NCHW - - auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); - input.linspace(1); - - auto expOutput = NDArrayFactory::create('c', {bS, iC, iH*factorH, iW*factorW}, {1.f, 1.f, 1.f, 2.f, 2.f, 2.f, 1.f, 1.f, 1.f, 2.f, 2.f, 2.f, 3.f, 3.f, 3.f, 4.f, 4.f, 4.f, 3.f, 3.f, 3.f, 4.f, 4.f, 4.f, - 5.f, 5.f, 5.f, 6.f, 6.f, 6.f, 5.f, 5.f, 5.f, 6.f, 6.f, 6.f, 7.f, 7.f, 7.f, 8.f, 8.f, 8.f, 7.f, 7.f, 7.f, 8.f, 8.f, 8.f, 9.f, 9.f, 9.f, 10.f, 10.f, 10.f, 9.f, 9.f, 9.f, 10.f, 10.f, 10.f, 11.f, 11.f, 11.f, 12.f, 12.f, 12.f, 11.f, 11.f, 11.f, 12.f, 12.f, 12.f, - 13.f, 13.f, 13.f, 14.f, 14.f, 14.f, 13.f, 13.f, 13.f, 14.f, 14.f, 14.f, 15.f, 15.f, 15.f, 16.f, 16.f, 16.f, 15.f, 15.f, 15.f, 16.f, 16.f, 16.f, 17.f, 17.f, 17.f, 18.f, 18.f, 18.f, 17.f, 17.f, 17.f, 18.f, 18.f, 18.f, 19.f, 19.f, 19.f, 20.f, 20.f, 20.f, 19.f, 19.f, 19.f, 20.f, 20.f, 20.f, - 21.f, 21.f, 21.f, 22.f, 22.f, 22.f, 21.f, 21.f, 21.f, 22.f, 22.f, 22.f, 23.f, 23.f, 23.f, 24.f, 24.f, 24.f, 23.f, 23.f, 23.f, 24.f, 24.f, 24.f, 25.f, 25.f, 25.f, 26.f, 26.f, 26.f, 25.f, 25.f, 25.f, 26.f, 26.f, 26.f, 27.f, 27.f, 27.f, 28.f, 28.f, 28.f, 27.f, 27.f, 27.f, 28.f, 28.f, 28.f, - 29.f, 29.f, 29.f, 30.f, 30.f, 30.f, 29.f, 29.f, 29.f, 30.f, 30.f, 30.f, 31.f, 31.f, 31.f, 32.f, 32.f, 32.f, 31.f, 31.f, 31.f, 32.f, 32.f, 32.f, - 33.f, 33.f, 33.f, 34.f, 34.f, 34.f, 33.f, 33.f, 33.f, 34.f, 34.f, 34.f, 35.f, 35.f, 35.f, 36.f, 36.f, 36.f, 35.f, 35.f, 35.f, 36.f, 36.f, 36.f}); - - sd::ops::upsampling2d op; - auto results = op.evaluate({&input}, {factorH, factorW, isNCHW}); - auto* output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expOutput.isSameShape(output)); - ASSERT_TRUE(expOutput.equalsTo(output)); - + const int bS = 3, iH = 2, iW = 2, iC = 3; + const int factorH = 2, factorW = 3; + const int isNCHW = 1; // data format, default is NCHW + + auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); + input.linspace(1); + + auto expOutput = NDArrayFactory::create( + 'c', {bS, iC, iH * factorH, iW * factorW}, + {1.f, 1.f, 1.f, 2.f, 2.f, 2.f, 1.f, 1.f, 1.f, 2.f, 2.f, 2.f, + 3.f, 3.f, 3.f, 4.f, 4.f, 4.f, 3.f, 3.f, 3.f, 4.f, 4.f, 4.f, + 5.f, 5.f, 5.f, 6.f, 6.f, 6.f, 5.f, 5.f, 5.f, 6.f, 6.f, 6.f, + 7.f, 7.f, 7.f, 8.f, 8.f, 8.f, 7.f, 7.f, 7.f, 8.f, 8.f, 8.f, + 9.f, 9.f, 9.f, 10.f, 10.f, 10.f, 9.f, 9.f, 9.f, 10.f, 10.f, 10.f, + 11.f, 11.f, 11.f, 12.f, 12.f, 12.f, 11.f, 11.f, 11.f, 12.f, 12.f, 12.f, + 13.f, 13.f, 13.f, 14.f, 14.f, 14.f, 13.f, 13.f, 13.f, 14.f, 14.f, 14.f, + 15.f, 15.f, 15.f, 16.f, 16.f, 16.f, 15.f, 15.f, 15.f, 16.f, 16.f, 16.f, + 17.f, 17.f, 17.f, 18.f, 18.f, 18.f, 17.f, 17.f, 17.f, 18.f, 18.f, 18.f, + 19.f, 19.f, 19.f, 20.f, 20.f, 20.f, 19.f, 19.f, 19.f, 20.f, 20.f, 20.f, + 21.f, 21.f, 21.f, 22.f, 22.f, 22.f, 21.f, 21.f, 21.f, 22.f, 22.f, 22.f, + 23.f, 23.f, 23.f, 24.f, 24.f, 24.f, 23.f, 23.f, 23.f, 24.f, 24.f, 24.f, + 25.f, 25.f, 25.f, 26.f, 26.f, 26.f, 25.f, 25.f, 25.f, 26.f, 26.f, 26.f, + 27.f, 27.f, 27.f, 28.f, 28.f, 28.f, 27.f, 27.f, 27.f, 28.f, 28.f, 28.f, + 29.f, 29.f, 29.f, 30.f, 30.f, 30.f, 29.f, 29.f, 29.f, 30.f, 30.f, 30.f, + 31.f, 31.f, 31.f, 32.f, 32.f, 32.f, 31.f, 31.f, 31.f, 32.f, 32.f, 32.f, + 33.f, 33.f, 33.f, 34.f, 34.f, 34.f, 33.f, 33.f, 33.f, 34.f, 34.f, 34.f, + 35.f, 35.f, 35.f, 36.f, 36.f, 36.f, 35.f, 35.f, 35.f, 36.f, 36.f, 36.f}); + + sd::ops::upsampling2d op; + auto results = op.evaluate({&input}, {factorH, factorW, isNCHW}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests1, upsampling3d_test1) { - - const int bS=3, iD=2,iH=2,iW=2, iC=3; - const int factorD=2,factorH=3,factorW=2; - const int isNCDHW = 0; // data format, default is NCHW - - auto input = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); - input.linspace(1); - - auto expOutput = NDArrayFactory::create('c', {bS, iD*factorD, iH*factorH, iW*factorW, iC}, {1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, - 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, - 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, - 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, - 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, - 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, - 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, - 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, 37.f, 38.f, 39.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 40.f, 41.f, 42.f, 37.f, 38.f, 39.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 40.f, 41.f, 42.f, 37.f, 38.f, 39.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 40.f, 41.f, 42.f, 43.f, 44.f, 45.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 46.f, 47.f, 48.f, - 43.f, 44.f, 45.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 46.f, 47.f, 48.f, 43.f, 44.f, 45.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 46.f, 47.f, 48.f, 37.f, 38.f, 39.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 40.f, 41.f, 42.f, 37.f, 38.f, 39.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 40.f, 41.f, 42.f, 37.f, 38.f, 39.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 40.f, 41.f, 42.f, - 43.f, 44.f, 45.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 46.f, 47.f, 48.f, 43.f, 44.f, 45.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 46.f, 47.f, 48.f, 43.f, 44.f, 45.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 46.f, 47.f, 48.f, 49.f, 50.f, 51.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 52.f, 53.f, 54.f, 49.f, 50.f, 51.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 52.f, 53.f, 54.f, - 49.f, 50.f, 51.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 52.f, 53.f, 54.f, 55.f, 56.f, 57.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 58.f, 59.f, 60.f, 55.f, 56.f, 57.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 58.f, 59.f, 60.f, 55.f, 56.f, 57.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 58.f, 59.f, 60.f, 49.f, 50.f, 51.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 52.f, 53.f, 54.f, - 49.f, 50.f, 51.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 52.f, 53.f, 54.f, 49.f, 50.f, 51.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 52.f, 53.f, 54.f, 55.f, 56.f, 57.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 58.f, 59.f, 60.f, 55.f, 56.f, 57.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 58.f, 59.f, 60.f, 55.f, 56.f, 57.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 58.f, 59.f, 60.f, - 61.f, 62.f, 63.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 64.f, 65.f, 66.f, 61.f, 62.f, 63.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 64.f, 65.f, 66.f, 61.f, 62.f, 63.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 64.f, 65.f, 66.f, 67.f, 68.f, 69.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f, 67.f, 68.f, 69.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f, - 67.f, 68.f, 69.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f, 61.f, 62.f, 63.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 64.f, 65.f, 66.f, 61.f, 62.f, 63.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 64.f, 65.f, 66.f, 61.f, 62.f, 63.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 64.f, 65.f, 66.f, 67.f, 68.f, 69.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f, - 67.f, 68.f, 69.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f, 67.f, 68.f, 69.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f}); - - sd::ops::upsampling3d op; - auto results = op.evaluate({&input}, {factorD, factorH, factorW, isNCDHW}); - auto* output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expOutput.isSameShape(output)); - ASSERT_TRUE(expOutput.equalsTo(output)); - + const int bS = 3, iD = 2, iH = 2, iW = 2, iC = 3; + const int factorD = 2, factorH = 3, factorW = 2; + const int isNCDHW = 0; // data format, default is NCHW + + auto input = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); + input.linspace(1); + + auto expOutput = NDArrayFactory::create( + 'c', {bS, iD * factorD, iH * factorH, iW * factorW, iC}, + {1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, + 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, + 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, + 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, + 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, + 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, + 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, + 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, + 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, + 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, + 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, + 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, + 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, + 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, + 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, + 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, + 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, + 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, + 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, + 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, + 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, + 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, + 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, + 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, + 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, + 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, + 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, + 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, + 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, + 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, + 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, + 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, + 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, + 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, + 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, + 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, + 37.f, 38.f, 39.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 40.f, 41.f, 42.f, + 37.f, 38.f, 39.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 40.f, 41.f, 42.f, + 37.f, 38.f, 39.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 40.f, 41.f, 42.f, + 43.f, 44.f, 45.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 46.f, 47.f, 48.f, + 43.f, 44.f, 45.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 46.f, 47.f, 48.f, + 43.f, 44.f, 45.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 46.f, 47.f, 48.f, + 37.f, 38.f, 39.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 40.f, 41.f, 42.f, + 37.f, 38.f, 39.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 40.f, 41.f, 42.f, + 37.f, 38.f, 39.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 40.f, 41.f, 42.f, + 43.f, 44.f, 45.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 46.f, 47.f, 48.f, + 43.f, 44.f, 45.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 46.f, 47.f, 48.f, + 43.f, 44.f, 45.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 46.f, 47.f, 48.f, + 49.f, 50.f, 51.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 52.f, 53.f, 54.f, + 49.f, 50.f, 51.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 52.f, 53.f, 54.f, + 49.f, 50.f, 51.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 52.f, 53.f, 54.f, + 55.f, 56.f, 57.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 58.f, 59.f, 60.f, + 55.f, 56.f, 57.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 58.f, 59.f, 60.f, + 55.f, 56.f, 57.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 58.f, 59.f, 60.f, + 49.f, 50.f, 51.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 52.f, 53.f, 54.f, + 49.f, 50.f, 51.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 52.f, 53.f, 54.f, + 49.f, 50.f, 51.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 52.f, 53.f, 54.f, + 55.f, 56.f, 57.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 58.f, 59.f, 60.f, + 55.f, 56.f, 57.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 58.f, 59.f, 60.f, + 55.f, 56.f, 57.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 58.f, 59.f, 60.f, + 61.f, 62.f, 63.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 64.f, 65.f, 66.f, + 61.f, 62.f, 63.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 64.f, 65.f, 66.f, + 61.f, 62.f, 63.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 64.f, 65.f, 66.f, + 67.f, 68.f, 69.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f, + 67.f, 68.f, 69.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f, + 67.f, 68.f, 69.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f, + 61.f, 62.f, 63.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 64.f, 65.f, 66.f, + 61.f, 62.f, 63.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 64.f, 65.f, 66.f, + 61.f, 62.f, 63.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 64.f, 65.f, 66.f, + 67.f, 68.f, 69.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f, + 67.f, 68.f, 69.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f, + 67.f, 68.f, 69.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f}); + + sd::ops::upsampling3d op; + auto results = op.evaluate({&input}, {factorD, factorH, factorW, isNCDHW}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests1, upsampling3d_test2) { - - const int bS=3, iD=2,iH=2,iW=2, iC=3; - const int factorD=2,factorH=3,factorW=2; - const int isNCDHW = 1; // data format, default is NCHW - - auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); - input.linspace(1); - - auto expOutput = NDArrayFactory::create('c', {bS, iC, iD*factorD, iH*factorH, iW*factorW}, { 1.f, 1.f, 2.f, 2.f, 1.f, 1.f, 2.f, 2.f, 1.f, 1.f, 2.f, 2.f, 3.f, 3.f, 4.f, 4.f, 3.f, 3.f, 4.f, 4.f, 3.f, 3.f, 4.f, 4.f, 1.f, 1.f, 2.f, 2.f, 1.f, 1.f, 2.f, 2.f, 1.f, 1.f, 2.f, 2.f, 3.f, 3.f, 4.f, 4.f, 3.f, 3.f, 4.f, 4.f, 3.f, 3.f, 4.f, 4.f, 5.f, 5.f, 6.f, 6.f, 5.f, 5.f, 6.f, 6.f, 5.f, 5.f, 6.f, 6.f, 7.f, 7.f, 8.f, 8.f, 7.f, 7.f, 8.f, 8.f, 7.f, 7.f, 8.f, 8.f, - 5.f, 5.f, 6.f, 6.f, 5.f, 5.f, 6.f, 6.f, 5.f, 5.f, 6.f, 6.f, 7.f, 7.f, 8.f, 8.f, 7.f, 7.f, 8.f, 8.f, 7.f, 7.f, 8.f, 8.f, 9.f, 9.f, 10.f, 10.f, 9.f, 9.f, 10.f, 10.f, 9.f, 9.f, 10.f, 10.f, 11.f, 11.f, 12.f, 12.f, 11.f, 11.f, 12.f, 12.f, 11.f, 11.f, 12.f, 12.f, 9.f, 9.f, 10.f, 10.f, 9.f, 9.f, 10.f, 10.f, 9.f, 9.f, 10.f, 10.f, 11.f, 11.f, 12.f, 12.f, 11.f, 11.f, 12.f, 12.f, 11.f, 11.f, 12.f, 12.f, - 13.f, 13.f, 14.f, 14.f, 13.f, 13.f, 14.f, 14.f, 13.f, 13.f, 14.f, 14.f, 15.f, 15.f, 16.f, 16.f, 15.f, 15.f, 16.f, 16.f, 15.f, 15.f, 16.f, 16.f, 13.f, 13.f, 14.f, 14.f, 13.f, 13.f, 14.f, 14.f, 13.f, 13.f, 14.f, 14.f, 15.f, 15.f, 16.f, 16.f, 15.f, 15.f, 16.f, 16.f, 15.f, 15.f, 16.f, 16.f, 17.f, 17.f, 18.f, 18.f, 17.f, 17.f, 18.f, 18.f, 17.f, 17.f, 18.f, 18.f, 19.f, 19.f, 20.f, 20.f, 19.f, 19.f, 20.f, 20.f, 19.f, 19.f, 20.f, 20.f, - 17.f, 17.f, 18.f, 18.f, 17.f, 17.f, 18.f, 18.f, 17.f, 17.f, 18.f, 18.f, 19.f, 19.f, 20.f, 20.f, 19.f, 19.f, 20.f, 20.f, 19.f, 19.f, 20.f, 20.f, 21.f, 21.f, 22.f, 22.f, 21.f, 21.f, 22.f, 22.f, 21.f, 21.f, 22.f, 22.f, 23.f, 23.f, 24.f, 24.f, 23.f, 23.f, 24.f, 24.f, 23.f, 23.f, 24.f, 24.f, 21.f, 21.f, 22.f, 22.f, 21.f, 21.f, 22.f, 22.f, 21.f, 21.f, 22.f, 22.f, 23.f, 23.f, 24.f, 24.f, 23.f, 23.f, 24.f, 24.f, 23.f, 23.f, 24.f, 24.f, - 25.f, 25.f, 26.f, 26.f, 25.f, 25.f, 26.f, 26.f, 25.f, 25.f, 26.f, 26.f, 27.f, 27.f, 28.f, 28.f, 27.f, 27.f, 28.f, 28.f, 27.f, 27.f, 28.f, 28.f, 25.f, 25.f, 26.f, 26.f, 25.f, 25.f, 26.f, 26.f, 25.f, 25.f, 26.f, 26.f, 27.f, 27.f, 28.f, 28.f, 27.f, 27.f, 28.f, 28.f, 27.f, 27.f, 28.f, 28.f, 29.f, 29.f, 30.f, 30.f, 29.f, 29.f, 30.f, 30.f, 29.f, 29.f, 30.f, 30.f, 31.f, 31.f, 32.f, 32.f, 31.f, 31.f, 32.f, 32.f, 31.f, 31.f, 32.f, 32.f, - 29.f, 29.f, 30.f, 30.f, 29.f, 29.f, 30.f, 30.f, 29.f, 29.f, 30.f, 30.f, 31.f, 31.f, 32.f, 32.f, 31.f, 31.f, 32.f, 32.f, 31.f, 31.f, 32.f, 32.f, 33.f, 33.f, 34.f, 34.f, 33.f, 33.f, 34.f, 34.f, 33.f, 33.f, 34.f, 34.f, 35.f, 35.f, 36.f, 36.f, 35.f, 35.f, 36.f, 36.f, 35.f, 35.f, 36.f, 36.f, 33.f, 33.f, 34.f, 34.f, 33.f, 33.f, 34.f, 34.f, 33.f, 33.f, 34.f, 34.f, 35.f, 35.f, 36.f, 36.f, 35.f, 35.f, 36.f, 36.f, 35.f, 35.f, 36.f, 36.f, - 37.f, 37.f, 38.f, 38.f, 37.f, 37.f, 38.f, 38.f, 37.f, 37.f, 38.f, 38.f, 39.f, 39.f, 40.f, 40.f, 39.f, 39.f, 40.f, 40.f, 39.f, 39.f, 40.f, 40.f, 37.f, 37.f, 38.f, 38.f, 37.f, 37.f, 38.f, 38.f, 37.f, 37.f, 38.f, 38.f, 39.f, 39.f, 40.f, 40.f, 39.f, 39.f, 40.f, 40.f, 39.f, 39.f, 40.f, 40.f, 41.f, 41.f, 42.f, 42.f, 41.f, 41.f, 42.f, 42.f, 41.f, 41.f, 42.f, 42.f, 43.f, 43.f, 44.f, 44.f, 43.f, 43.f, 44.f, 44.f, 43.f, 43.f, 44.f, 44.f, - 41.f, 41.f, 42.f, 42.f, 41.f, 41.f, 42.f, 42.f, 41.f, 41.f, 42.f, 42.f, 43.f, 43.f, 44.f, 44.f, 43.f, 43.f, 44.f, 44.f, 43.f, 43.f, 44.f, 44.f, 45.f, 45.f, 46.f, 46.f, 45.f, 45.f, 46.f, 46.f, 45.f, 45.f, 46.f, 46.f, 47.f, 47.f, 48.f, 48.f, 47.f, 47.f, 48.f, 48.f, 47.f, 47.f, 48.f, 48.f, 45.f, 45.f, 46.f, 46.f, 45.f, 45.f, 46.f, 46.f, 45.f, 45.f, 46.f, 46.f, 47.f, 47.f, 48.f, 48.f, 47.f, 47.f, 48.f, 48.f, 47.f, 47.f, 48.f, 48.f, - 49.f, 49.f, 50.f, 50.f, 49.f, 49.f, 50.f, 50.f, 49.f, 49.f, 50.f, 50.f, 51.f, 51.f, 52.f, 52.f, 51.f, 51.f, 52.f, 52.f, 51.f, 51.f, 52.f, 52.f, 49.f, 49.f, 50.f, 50.f, 49.f, 49.f, 50.f, 50.f, 49.f, 49.f, 50.f, 50.f, 51.f, 51.f, 52.f, 52.f, 51.f, 51.f, 52.f, 52.f, 51.f, 51.f, 52.f, 52.f, 53.f, 53.f, 54.f, 54.f, 53.f, 53.f, 54.f, 54.f, 53.f, 53.f, 54.f, 54.f, 55.f, 55.f, 56.f, 56.f, 55.f, 55.f, 56.f, 56.f, 55.f, 55.f, 56.f, 56.f, - 53.f, 53.f, 54.f, 54.f, 53.f, 53.f, 54.f, 54.f, 53.f, 53.f, 54.f, 54.f, 55.f, 55.f, 56.f, 56.f, 55.f, 55.f, 56.f, 56.f, 55.f, 55.f, 56.f, 56.f, 57.f, 57.f, 58.f, 58.f, 57.f, 57.f, 58.f, 58.f, 57.f, 57.f, 58.f, 58.f, 59.f, 59.f, 60.f, 60.f, 59.f, 59.f, 60.f, 60.f, 59.f, 59.f, 60.f, 60.f, 57.f, 57.f, 58.f, 58.f, 57.f, 57.f, 58.f, 58.f, 57.f, 57.f, 58.f, 58.f, 59.f, 59.f, 60.f, 60.f, 59.f, 59.f, 60.f, 60.f, 59.f, 59.f, 60.f, 60.f, - 61.f, 61.f, 62.f, 62.f, 61.f, 61.f, 62.f, 62.f, 61.f, 61.f, 62.f, 62.f, 63.f, 63.f, 64.f, 64.f, 63.f, 63.f, 64.f, 64.f, 63.f, 63.f, 64.f, 64.f, 61.f, 61.f, 62.f, 62.f, 61.f, 61.f, 62.f, 62.f, 61.f, 61.f, 62.f, 62.f, 63.f, 63.f, 64.f, 64.f, 63.f, 63.f, 64.f, 64.f, 63.f, 63.f, 64.f, 64.f, 65.f, 65.f, 66.f, 66.f, 65.f, 65.f, 66.f, 66.f, 65.f, 65.f, 66.f, 66.f, 67.f, 67.f, 68.f, 68.f, 67.f, 67.f, 68.f, 68.f, 67.f, 67.f, 68.f, 68.f, - 65.f, 65.f, 66.f, 66.f, 65.f, 65.f, 66.f, 66.f, 65.f, 65.f, 66.f, 66.f, 67.f, 67.f, 68.f, 68.f, 67.f, 67.f, 68.f, 68.f, 67.f, 67.f, 68.f, 68.f, 69.f, 69.f, 70.f, 70.f, 69.f, 69.f, 70.f, 70.f, 69.f, 69.f, 70.f, 70.f, 71.f, 71.f, 72.f, 72.f, 71.f, 71.f, 72.f, 72.f, 71.f, 71.f, 72.f, 72.f, 69.f, 69.f, 70.f, 70.f, 69.f, 69.f, 70.f, 70.f, 69.f, 69.f, 70.f, 70.f, 71.f, 71.f, 72.f, 72.f, 71.f, 71.f, 72.f, 72.f, 71.f, 71.f, 72.f, 72.f}); - - sd::ops::upsampling3d op; - auto results = op.evaluate({&input}, {factorD, factorH, factorW, isNCDHW}); - auto* output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expOutput.isSameShape(output)); - ASSERT_TRUE(expOutput.equalsTo(output)); - + const int bS = 3, iD = 2, iH = 2, iW = 2, iC = 3; + const int factorD = 2, factorH = 3, factorW = 2; + const int isNCDHW = 1; // data format, default is NCHW + + auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); + input.linspace(1); + + auto expOutput = NDArrayFactory::create( + 'c', {bS, iC, iD * factorD, iH * factorH, iW * factorW}, + {1.f, 1.f, 2.f, 2.f, 1.f, 1.f, 2.f, 2.f, 1.f, 1.f, 2.f, 2.f, + 3.f, 3.f, 4.f, 4.f, 3.f, 3.f, 4.f, 4.f, 3.f, 3.f, 4.f, 4.f, + 1.f, 1.f, 2.f, 2.f, 1.f, 1.f, 2.f, 2.f, 1.f, 1.f, 2.f, 2.f, + 3.f, 3.f, 4.f, 4.f, 3.f, 3.f, 4.f, 4.f, 3.f, 3.f, 4.f, 4.f, + 5.f, 5.f, 6.f, 6.f, 5.f, 5.f, 6.f, 6.f, 5.f, 5.f, 6.f, 6.f, + 7.f, 7.f, 8.f, 8.f, 7.f, 7.f, 8.f, 8.f, 7.f, 7.f, 8.f, 8.f, + 5.f, 5.f, 6.f, 6.f, 5.f, 5.f, 6.f, 6.f, 5.f, 5.f, 6.f, 6.f, + 7.f, 7.f, 8.f, 8.f, 7.f, 7.f, 8.f, 8.f, 7.f, 7.f, 8.f, 8.f, + 9.f, 9.f, 10.f, 10.f, 9.f, 9.f, 10.f, 10.f, 9.f, 9.f, 10.f, 10.f, + 11.f, 11.f, 12.f, 12.f, 11.f, 11.f, 12.f, 12.f, 11.f, 11.f, 12.f, 12.f, + 9.f, 9.f, 10.f, 10.f, 9.f, 9.f, 10.f, 10.f, 9.f, 9.f, 10.f, 10.f, + 11.f, 11.f, 12.f, 12.f, 11.f, 11.f, 12.f, 12.f, 11.f, 11.f, 12.f, 12.f, + 13.f, 13.f, 14.f, 14.f, 13.f, 13.f, 14.f, 14.f, 13.f, 13.f, 14.f, 14.f, + 15.f, 15.f, 16.f, 16.f, 15.f, 15.f, 16.f, 16.f, 15.f, 15.f, 16.f, 16.f, + 13.f, 13.f, 14.f, 14.f, 13.f, 13.f, 14.f, 14.f, 13.f, 13.f, 14.f, 14.f, + 15.f, 15.f, 16.f, 16.f, 15.f, 15.f, 16.f, 16.f, 15.f, 15.f, 16.f, 16.f, + 17.f, 17.f, 18.f, 18.f, 17.f, 17.f, 18.f, 18.f, 17.f, 17.f, 18.f, 18.f, + 19.f, 19.f, 20.f, 20.f, 19.f, 19.f, 20.f, 20.f, 19.f, 19.f, 20.f, 20.f, + 17.f, 17.f, 18.f, 18.f, 17.f, 17.f, 18.f, 18.f, 17.f, 17.f, 18.f, 18.f, + 19.f, 19.f, 20.f, 20.f, 19.f, 19.f, 20.f, 20.f, 19.f, 19.f, 20.f, 20.f, + 21.f, 21.f, 22.f, 22.f, 21.f, 21.f, 22.f, 22.f, 21.f, 21.f, 22.f, 22.f, + 23.f, 23.f, 24.f, 24.f, 23.f, 23.f, 24.f, 24.f, 23.f, 23.f, 24.f, 24.f, + 21.f, 21.f, 22.f, 22.f, 21.f, 21.f, 22.f, 22.f, 21.f, 21.f, 22.f, 22.f, + 23.f, 23.f, 24.f, 24.f, 23.f, 23.f, 24.f, 24.f, 23.f, 23.f, 24.f, 24.f, + 25.f, 25.f, 26.f, 26.f, 25.f, 25.f, 26.f, 26.f, 25.f, 25.f, 26.f, 26.f, + 27.f, 27.f, 28.f, 28.f, 27.f, 27.f, 28.f, 28.f, 27.f, 27.f, 28.f, 28.f, + 25.f, 25.f, 26.f, 26.f, 25.f, 25.f, 26.f, 26.f, 25.f, 25.f, 26.f, 26.f, + 27.f, 27.f, 28.f, 28.f, 27.f, 27.f, 28.f, 28.f, 27.f, 27.f, 28.f, 28.f, + 29.f, 29.f, 30.f, 30.f, 29.f, 29.f, 30.f, 30.f, 29.f, 29.f, 30.f, 30.f, + 31.f, 31.f, 32.f, 32.f, 31.f, 31.f, 32.f, 32.f, 31.f, 31.f, 32.f, 32.f, + 29.f, 29.f, 30.f, 30.f, 29.f, 29.f, 30.f, 30.f, 29.f, 29.f, 30.f, 30.f, + 31.f, 31.f, 32.f, 32.f, 31.f, 31.f, 32.f, 32.f, 31.f, 31.f, 32.f, 32.f, + 33.f, 33.f, 34.f, 34.f, 33.f, 33.f, 34.f, 34.f, 33.f, 33.f, 34.f, 34.f, + 35.f, 35.f, 36.f, 36.f, 35.f, 35.f, 36.f, 36.f, 35.f, 35.f, 36.f, 36.f, + 33.f, 33.f, 34.f, 34.f, 33.f, 33.f, 34.f, 34.f, 33.f, 33.f, 34.f, 34.f, + 35.f, 35.f, 36.f, 36.f, 35.f, 35.f, 36.f, 36.f, 35.f, 35.f, 36.f, 36.f, + 37.f, 37.f, 38.f, 38.f, 37.f, 37.f, 38.f, 38.f, 37.f, 37.f, 38.f, 38.f, + 39.f, 39.f, 40.f, 40.f, 39.f, 39.f, 40.f, 40.f, 39.f, 39.f, 40.f, 40.f, + 37.f, 37.f, 38.f, 38.f, 37.f, 37.f, 38.f, 38.f, 37.f, 37.f, 38.f, 38.f, + 39.f, 39.f, 40.f, 40.f, 39.f, 39.f, 40.f, 40.f, 39.f, 39.f, 40.f, 40.f, + 41.f, 41.f, 42.f, 42.f, 41.f, 41.f, 42.f, 42.f, 41.f, 41.f, 42.f, 42.f, + 43.f, 43.f, 44.f, 44.f, 43.f, 43.f, 44.f, 44.f, 43.f, 43.f, 44.f, 44.f, + 41.f, 41.f, 42.f, 42.f, 41.f, 41.f, 42.f, 42.f, 41.f, 41.f, 42.f, 42.f, + 43.f, 43.f, 44.f, 44.f, 43.f, 43.f, 44.f, 44.f, 43.f, 43.f, 44.f, 44.f, + 45.f, 45.f, 46.f, 46.f, 45.f, 45.f, 46.f, 46.f, 45.f, 45.f, 46.f, 46.f, + 47.f, 47.f, 48.f, 48.f, 47.f, 47.f, 48.f, 48.f, 47.f, 47.f, 48.f, 48.f, + 45.f, 45.f, 46.f, 46.f, 45.f, 45.f, 46.f, 46.f, 45.f, 45.f, 46.f, 46.f, + 47.f, 47.f, 48.f, 48.f, 47.f, 47.f, 48.f, 48.f, 47.f, 47.f, 48.f, 48.f, + 49.f, 49.f, 50.f, 50.f, 49.f, 49.f, 50.f, 50.f, 49.f, 49.f, 50.f, 50.f, + 51.f, 51.f, 52.f, 52.f, 51.f, 51.f, 52.f, 52.f, 51.f, 51.f, 52.f, 52.f, + 49.f, 49.f, 50.f, 50.f, 49.f, 49.f, 50.f, 50.f, 49.f, 49.f, 50.f, 50.f, + 51.f, 51.f, 52.f, 52.f, 51.f, 51.f, 52.f, 52.f, 51.f, 51.f, 52.f, 52.f, + 53.f, 53.f, 54.f, 54.f, 53.f, 53.f, 54.f, 54.f, 53.f, 53.f, 54.f, 54.f, + 55.f, 55.f, 56.f, 56.f, 55.f, 55.f, 56.f, 56.f, 55.f, 55.f, 56.f, 56.f, + 53.f, 53.f, 54.f, 54.f, 53.f, 53.f, 54.f, 54.f, 53.f, 53.f, 54.f, 54.f, + 55.f, 55.f, 56.f, 56.f, 55.f, 55.f, 56.f, 56.f, 55.f, 55.f, 56.f, 56.f, + 57.f, 57.f, 58.f, 58.f, 57.f, 57.f, 58.f, 58.f, 57.f, 57.f, 58.f, 58.f, + 59.f, 59.f, 60.f, 60.f, 59.f, 59.f, 60.f, 60.f, 59.f, 59.f, 60.f, 60.f, + 57.f, 57.f, 58.f, 58.f, 57.f, 57.f, 58.f, 58.f, 57.f, 57.f, 58.f, 58.f, + 59.f, 59.f, 60.f, 60.f, 59.f, 59.f, 60.f, 60.f, 59.f, 59.f, 60.f, 60.f, + 61.f, 61.f, 62.f, 62.f, 61.f, 61.f, 62.f, 62.f, 61.f, 61.f, 62.f, 62.f, + 63.f, 63.f, 64.f, 64.f, 63.f, 63.f, 64.f, 64.f, 63.f, 63.f, 64.f, 64.f, + 61.f, 61.f, 62.f, 62.f, 61.f, 61.f, 62.f, 62.f, 61.f, 61.f, 62.f, 62.f, + 63.f, 63.f, 64.f, 64.f, 63.f, 63.f, 64.f, 64.f, 63.f, 63.f, 64.f, 64.f, + 65.f, 65.f, 66.f, 66.f, 65.f, 65.f, 66.f, 66.f, 65.f, 65.f, 66.f, 66.f, + 67.f, 67.f, 68.f, 68.f, 67.f, 67.f, 68.f, 68.f, 67.f, 67.f, 68.f, 68.f, + 65.f, 65.f, 66.f, 66.f, 65.f, 65.f, 66.f, 66.f, 65.f, 65.f, 66.f, 66.f, + 67.f, 67.f, 68.f, 68.f, 67.f, 67.f, 68.f, 68.f, 67.f, 67.f, 68.f, 68.f, + 69.f, 69.f, 70.f, 70.f, 69.f, 69.f, 70.f, 70.f, 69.f, 69.f, 70.f, 70.f, + 71.f, 71.f, 72.f, 72.f, 71.f, 71.f, 72.f, 72.f, 71.f, 71.f, 72.f, 72.f, + 69.f, 69.f, 70.f, 70.f, 69.f, 69.f, 70.f, 70.f, 69.f, 69.f, 70.f, 70.f, + 71.f, 71.f, 72.f, 72.f, 71.f, 71.f, 72.f, 72.f, 71.f, 71.f, 72.f, 72.f}); + + sd::ops::upsampling3d op; + auto results = op.evaluate({&input}, {factorD, factorH, factorW, isNCDHW}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); } - ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests1, upsampling3d_bp_test1) { + const int bS = 1, iD = 2, iH = 2, iW = 2, iC = 1; + const int factorD = 2, factorH = 2, factorW = 2; + const int isNCDHW = 1; // data format, default is NCHW - const int bS=1, iD=2,iH=2,iW=2, iC=1; - const int factorD=2, factorH=2, factorW=2; - const int isNCDHW = 1; // data format, default is NCHW + auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); + auto gradO = NDArrayFactory::create( + 'c', {bS, iC, iD * factorD, iH * factorH, iW * factorW}); + gradO = 1.; - auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); - auto gradO = NDArrayFactory::create('c', {bS, iC, iD*factorD, iH*factorH, iW*factorW}); - gradO = 1.; + auto expGradI = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); + expGradI = 8.; - auto expGradI = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); - expGradI = 8.; + sd::ops::upsampling3d_bp op; + auto results = op.evaluate({&input, &gradO}, {isNCDHW}); + auto gradI = results.at(0); - sd::ops::upsampling3d_bp op; - auto results = op.evaluate({&input, &gradO}, {isNCDHW}); - auto* gradI = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expGradI.isSameShape(gradI)); - ASSERT_TRUE(expGradI.equalsTo(gradI)); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); } TYPED_TEST(TypedConvolutionTests1, conv2D_input_BP_test1) { + auto inputShape = NDArrayFactory::create('c', {4}, {2, 1, 4, 4}); + auto weights = NDArrayFactory::create('c', {2, 1, 3, 3}); + auto epsilonNext = NDArrayFactory::create('c', {2, 2, 4, 4}); + auto shapeArr = NDArrayFactory::create('c', {2, 1, 4, 4}); - auto inputShape = NDArrayFactory::create('c', {4}, {2, 1, 4, 4}); - auto weights = NDArrayFactory::create('c', {2, 1, 3, 3}); - auto epsilonNext = NDArrayFactory::create('c', {2, 2, 4, 4}); - auto shapeArr = NDArrayFactory::create('c', {2, 1, 4, 4}); - - - TypeParam _expEpsB[] = {952.0, 1540.0, 1636.0, 1180.0, 1791.0, 2886.0, 3057.0, 2193.0, 2223.0, 3570.0, 3741.0, 2673.0, 1900.0, 3028.0, 3160.0, 2240.0, 2872.0, 4612.0, 4708.0, 3356.0, 5247.0, 8358.0, 8529.0, 6033.0, 5679.0, 9042.0, 9213.0, 6513.0, 4588.0, 7252.0, 7384.0, 5184.0}; - NDArray expEps(_expEpsB, shapeArr.shapeInfo()); - - weights.linspace(1); - epsilonNext.linspace(1); - weights.permutei({2,3,1,0}); + TypeParam _expEpsB[] = { + 952.0, 1540.0, 1636.0, 1180.0, 1791.0, 2886.0, 3057.0, 2193.0, + 2223.0, 3570.0, 3741.0, 2673.0, 1900.0, 3028.0, 3160.0, 2240.0, + 2872.0, 4612.0, 4708.0, 3356.0, 5247.0, 8358.0, 8529.0, 6033.0, + 5679.0, 9042.0, 9213.0, 6513.0, 4588.0, 7252.0, 7384.0, 5184.0}; + NDArray expEps(_expEpsB, shapeArr.shapeInfo()); - sd::ops::conv2d_input_bp op; + weights.linspace(1); + epsilonNext.linspace(1); + weights.permutei({2, 3, 1, 0}); - auto results = op.evaluate({&inputShape, &weights, &epsilonNext}, {}, {3, 3, 1, 1, 0, 0, 1, 1, 1}); + sd::ops::conv2d_input_bp op; - ASSERT_TRUE(results.size() == 1); + auto results = op.evaluate({&inputShape, &weights, &epsilonNext}, {}, + {3, 3, 1, 1, 0, 0, 1, 1, 1}); - auto epsilon = results.at(0); + ASSERT_TRUE(results.size() == 1); - ASSERT_TRUE(shapeArr.isSameShape(epsilon)); - ASSERT_TRUE(expEps.equalsTo(epsilon)); + auto epsilon = results.at(0); + ASSERT_TRUE(shapeArr.isSameShape(epsilon)); + ASSERT_TRUE(expEps.equalsTo(epsilon)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests1, upsampling3d_bp_test3) { - - const int bS=1, iD=3,iH=3,iW=3, iC=2; - const int factorD=2, factorH=2, factorW=2; - const int isNCDHW = 1; // data format, default is NCHW - - NDArray input('c', {bS, iC, iD, iH, iW}, sd::DataType::FLOAT32); - NDArray gradO('c', {bS, iC, iD*factorD, iH*factorH, iW*factorW}, {0.6793504, 0.35508695, 0.84278935, 0.20031333, 0.7014987, 0.31069338, - 0.44793984, 0.93800974, 0.32667395, 0.15187258, 0.38331753, 0.78212297, 0.1988072, 0.7985636, 0.1632634, 0.14696825, 0.26089668, - 0.13505761, 0.7562093, 0.27545404, 0.36908787, 0.09282647, 0.83649176, 0.26841334, 0.09506222, 0.31279507, 0.13591796, 0.5175439, - 0.32870287, 0.061735712, 0.39643127, 0.248016, 0.5489592, 0.115046196, 0.8143622, 0.7215636, 0.40449402, 0.29908907, 0.4038839, - 0.9883108, 0.022296403, 0.927782, 0.3184157, 0.0685462, 0.28453344, 0.23272, 0.35214192, 0.058909304, 0.7112212, 0.6744568, 0.19694561, - 0.6994972, 0.0743224, 0.42042503, 0.5842631, 0.14957358, 0.44640633, 0.72307247, 0.06448108, 0.48307765, 0.8759956, 0.5698191, 0.4458631, - 0.5277549, 0.016646361, 0.753678, 0.14063567, 0.7541292, 0.16193217, 0.7750374, 0.3326449, 0.11739397, 0.017710684, 0.60847557, 0.52515227, - 0.9171938, 0.84989065, 0.5894228, 0.85227835, 0.39063585, 0.88968325, 0.6694452, 0.698873, 0.96147966, 0.15740126, 0.15736352, 0.49352047, - 0.5699365, 0.12683152, 0.11572781, 0.7863682, 0.737939, 0.49007934, 0.6084143, 0.9564999, 0.3900982, 0.14730452, 0.8506447, 0.49765033, - 0.07186628, 0.08214969, 0.035314173, 0.7320408, 0.36993408, 0.8406658, 0.27389422, 0.43179566, 0.13323106, 0.19297548, 0.24689731, 0.38641843, - 0.51154125, 0.19903564, 0.1416313, 0.69769853, 0.25363067, 0.78221816, 0.9300991, 0.3355119, 0.5588076, 0.6643576, 0.018850708, 0.63755876, - 0.2904297, 0.43490165, 0.84251267, 0.46609768, 0.38139546, 0.52318525, 0.9901826, 0.9257676, 0.6434591, 0.016828254, 0.9187561, 0.22897908, - 0.0063138064, 0.66597503, 0.19036093, 0.59552056, 0.69888055, 0.22146936, 0.9124342, 0.8708221, 0.7273687, 0.52397245, 0.66288394, 0.2188415, - 0.3354802, 0.03566524, 0.5101009, 0.5017283, 0.75122046, 0.1884508, 0.7407126, 0.6253045, 0.47145858, 0.5369367, 0.19884548, 0.99008304, - 0.08256686, 0.91884845, 0.02360027, 0.98895234, 0.3751719, 0.91783875, 0.4338776, 0.6783008, 0.6667967, 0.46720362, 0.7508773, 0.52304846, - 0.76631916, 0.4187526, 0.7653719, 0.5159193, 0.42730415, 0.49462363, 0.2731735, 0.8862948, 0.043214794, 0.3197591, 0.040378205, 0.5427239, - 0.9228089, 0.045940384, 0.70047987, 0.8419288, 0.53966296, 0.009444186, 0.038044546, 0.03158029, 0.43485752, 0.9204235, 0.5478789, 0.8290083, - 0.11868837, 0.0229866, 0.6639305, 0.8757367, 0.8279557, 0.76270294, 0.43242732, 0.4713431, 0.2569212, 0.30575937, 0.44395888, 0.99384075, - 0.6127142, 0.44844577, 0.6347944, 0.098358564, 0.34233716, 0.9329664, 0.65776783, 0.108565055, 0.2052629, 0.46441218, 0.041791342, 0.89369565, - 0.7000381, 0.2106213, 0.51152664, 0.44200692, 0.8293282, 0.20901772, 0.6387249, 0.8016979, 0.11178707, 0.109545894, 0.19654618, 0.060582615, - 0.08239174, 0.64630795, 0.32862368, 0.60225064, 0.8328141, 0.5484566, 0.8120276, 0.38822946, 0.6742381, 0.34913155, 0.42887798, 0.45344824, - 0.73956585, 0.9714739, 0.42937812, 0.45185348, 0.84535813, 0.046436775, 0.8802151, 0.8676222, 0.42625394, 0.4985318, 0.42399272, 0.122144565, - 0.0060101906, 0.47253844, 0.18123977, 0.86316174, 0.5863874, 0.3852012, 0.9785553, 0.0054711984, 0.88500834, 0.020897374, 0.27467912, 0.3852802, - 0.0766939, 0.94622654, 0.38687763, 0.3308602, 0.7770494, 0.9052543, 0.22258204, 0.42207044, 0.18050623, 0.21057767, 0.012561422, 0.7977821, - 0.61251044, 0.7203693, 0.6028265, 0.6036933, 0.1446382, 0.6712341, 0.76634467, 0.4854034, 0.26634562, 0.76523924, 0.16348523, 0.2663676, - 0.96846986, 0.8273284, 0.10700377, 0.7600526, 0.6771002, 0.47963092, 0.21264452, 0.56934077, 0.5514792, 0.85725874, 0.99090636, 0.54562527, - 0.93597686, 0.21142527, 0.4628326, 0.35011524, 0.31464386, 0.31164807, 0.65928996, 0.94418925, 0.39666295, 0.9496393, 0.103756346, 0.482158, - 0.49171793, 0.4108867, 0.22594318, 0.97093135, 0.5974685, 0.34632966, 0.54835194, 0.10499302, 0.9767778, 0.55008715, 0.54379046, 0.3583731, - 0.33369112, 0.04279039, 0.24939054, 0.23943715, 0.06775989, 0.7750291, 0.24329625, 0.4327169, 0.86916673, 0.80322117, 0.049972698, 0.47177452, - 0.37419558, 0.15303156, 0.121425234, 0.75884604, 0.8191354, 0.48554084, 0.053899214, 0.7858246, 0.39219773, 0.77579063, 0.34507045, 0.46070176, - 0.14496958, 0.47706795, 0.50678796, 0.64902323, 0.3277943, 0.0017530271, 0.6536156, 0.8582253, 0.95703506, 0.9963951, 0.8239163, 0.305142, - 0.012419582, 0.9498972, 0.1595827, 0.47947606, 0.5071124, 0.78227425, 0.2066719, 0.5217094, 0.7841406, 0.5260441, 0.49798164, 0.10975622, - 0.8633349, 0.76298475, 0.14295428, 0.6131504, 0.43794408, 0.50339264, 0.4504877, 0.19235311, 0.6678411, 0.80769485, 0.67495126, 0.96461457, - 0.10535406, 0.66438645, 0.4372345, 0.93851465, 0.8635335, 0.3405871, 0.45652762, 0.3636232, 0.52931345, 0.20154329, 0.07698499, 0.6125804, - 0.3583082, 0.3894796, 0.32601944, 0.5237369, 0.66683626, 0.08541841, 0.4815708, 0.11897489, 0.97555137, 0.3602705, 0.9620871, 0.6361821, - 0.71167386, 0.5134439, 0.57761437, 0.58598644, 0.39387667, 0.6966405, 0.46841687, 0.85788506, 0.9957087, 0.051309288, 0.24846801, 0.55938333, - 0.10230542, 0.9370694, 0.57527155, 0.54656035, 0.28896323, 0.51303476, 0.8865, 0.38641605, 0.9836358}, sd::DataType::FLOAT32); - - NDArray expGradI('c', {bS, iC, iD, iH, iW}, {3.510932, 3.4310975, 3.538762, 4.148549, 2.8380678, 2.5431657, 3.3928843, 3.228055, 3.1467278, - 3.2603023, 5.611751, 4.334653, 3.3697734, 4.603307, 4.4357986, 4.32991, 3.0532732, 3.1370173, 4.181534, 2.9965065, 2.8553872, 5.2719016, - 4.5671935, 3.7027276, 3.3517184, 5.2544537, 3.5107024, 4.1496124, 3.9333878, 3.1798909, 3.1446428, 3.0932689, 3.9730802, 3.0466917, - 4.9675374, 4.769673, 3.766952, 3.6375027, 3.6492167, 4.9440994, 3.8379507, 3.467589, 4.719474, 3.1295977, 4.5177174, 4.2760015, 2.8443856, - 4.225355, 4.377341, 4.4398847, 4.710785, 4.4199953, 3.928307, 4.8769503}, sd::DataType::FLOAT32); - - sd::ops::upsampling3d_bp op; - auto results = op.evaluate({&input, &gradO}, {isNCDHW}); - auto* gradI = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expGradI.isSameShape(gradI)); - ASSERT_TRUE(expGradI.equalsTo(gradI)); - + const int bS = 1, iD = 3, iH = 3, iW = 3, iC = 2; + const int factorD = 2, factorH = 2, factorW = 2; + const int isNCDHW = 1; // data format, default is NCHW + + NDArray input('c', {bS, iC, iD, iH, iW}, sd::DataType::FLOAT32); + NDArray gradO( + 'c', {bS, iC, iD * factorD, iH * factorH, iW * factorW}, + {0.6793504, 0.35508695, 0.84278935, 0.20031333, 0.7014987, + 0.31069338, 0.44793984, 0.93800974, 0.32667395, 0.15187258, + 0.38331753, 0.78212297, 0.1988072, 0.7985636, 0.1632634, + 0.14696825, 0.26089668, 0.13505761, 0.7562093, 0.27545404, + 0.36908787, 0.09282647, 0.83649176, 0.26841334, 0.09506222, + 0.31279507, 0.13591796, 0.5175439, 0.32870287, 0.061735712, + 0.39643127, 0.248016, 0.5489592, 0.115046196, 0.8143622, + 0.7215636, 0.40449402, 0.29908907, 0.4038839, 0.9883108, + 0.022296403, 0.927782, 0.3184157, 0.0685462, 0.28453344, + 0.23272, 0.35214192, 0.058909304, 0.7112212, 0.6744568, + 0.19694561, 0.6994972, 0.0743224, 0.42042503, 0.5842631, + 0.14957358, 0.44640633, 0.72307247, 0.06448108, 0.48307765, + 0.8759956, 0.5698191, 0.4458631, 0.5277549, 0.016646361, + 0.753678, 0.14063567, 0.7541292, 0.16193217, 0.7750374, + 0.3326449, 0.11739397, 0.017710684, 0.60847557, 0.52515227, + 0.9171938, 0.84989065, 0.5894228, 0.85227835, 0.39063585, + 0.88968325, 0.6694452, 0.698873, 0.96147966, 0.15740126, + 0.15736352, 0.49352047, 0.5699365, 0.12683152, 0.11572781, + 0.7863682, 0.737939, 0.49007934, 0.6084143, 0.9564999, + 0.3900982, 0.14730452, 0.8506447, 0.49765033, 0.07186628, + 0.08214969, 0.035314173, 0.7320408, 0.36993408, 0.8406658, + 0.27389422, 0.43179566, 0.13323106, 0.19297548, 0.24689731, + 0.38641843, 0.51154125, 0.19903564, 0.1416313, 0.69769853, + 0.25363067, 0.78221816, 0.9300991, 0.3355119, 0.5588076, + 0.6643576, 0.018850708, 0.63755876, 0.2904297, 0.43490165, + 0.84251267, 0.46609768, 0.38139546, 0.52318525, 0.9901826, + 0.9257676, 0.6434591, 0.016828254, 0.9187561, 0.22897908, + 0.0063138064, 0.66597503, 0.19036093, 0.59552056, 0.69888055, + 0.22146936, 0.9124342, 0.8708221, 0.7273687, 0.52397245, + 0.66288394, 0.2188415, 0.3354802, 0.03566524, 0.5101009, + 0.5017283, 0.75122046, 0.1884508, 0.7407126, 0.6253045, + 0.47145858, 0.5369367, 0.19884548, 0.99008304, 0.08256686, + 0.91884845, 0.02360027, 0.98895234, 0.3751719, 0.91783875, + 0.4338776, 0.6783008, 0.6667967, 0.46720362, 0.7508773, + 0.52304846, 0.76631916, 0.4187526, 0.7653719, 0.5159193, + 0.42730415, 0.49462363, 0.2731735, 0.8862948, 0.043214794, + 0.3197591, 0.040378205, 0.5427239, 0.9228089, 0.045940384, + 0.70047987, 0.8419288, 0.53966296, 0.009444186, 0.038044546, + 0.03158029, 0.43485752, 0.9204235, 0.5478789, 0.8290083, + 0.11868837, 0.0229866, 0.6639305, 0.8757367, 0.8279557, + 0.76270294, 0.43242732, 0.4713431, 0.2569212, 0.30575937, + 0.44395888, 0.99384075, 0.6127142, 0.44844577, 0.6347944, + 0.098358564, 0.34233716, 0.9329664, 0.65776783, 0.108565055, + 0.2052629, 0.46441218, 0.041791342, 0.89369565, 0.7000381, + 0.2106213, 0.51152664, 0.44200692, 0.8293282, 0.20901772, + 0.6387249, 0.8016979, 0.11178707, 0.109545894, 0.19654618, + 0.060582615, 0.08239174, 0.64630795, 0.32862368, 0.60225064, + 0.8328141, 0.5484566, 0.8120276, 0.38822946, 0.6742381, + 0.34913155, 0.42887798, 0.45344824, 0.73956585, 0.9714739, + 0.42937812, 0.45185348, 0.84535813, 0.046436775, 0.8802151, + 0.8676222, 0.42625394, 0.4985318, 0.42399272, 0.122144565, + 0.0060101906, 0.47253844, 0.18123977, 0.86316174, 0.5863874, + 0.3852012, 0.9785553, 0.0054711984, 0.88500834, 0.020897374, + 0.27467912, 0.3852802, 0.0766939, 0.94622654, 0.38687763, + 0.3308602, 0.7770494, 0.9052543, 0.22258204, 0.42207044, + 0.18050623, 0.21057767, 0.012561422, 0.7977821, 0.61251044, + 0.7203693, 0.6028265, 0.6036933, 0.1446382, 0.6712341, + 0.76634467, 0.4854034, 0.26634562, 0.76523924, 0.16348523, + 0.2663676, 0.96846986, 0.8273284, 0.10700377, 0.7600526, + 0.6771002, 0.47963092, 0.21264452, 0.56934077, 0.5514792, + 0.85725874, 0.99090636, 0.54562527, 0.93597686, 0.21142527, + 0.4628326, 0.35011524, 0.31464386, 0.31164807, 0.65928996, + 0.94418925, 0.39666295, 0.9496393, 0.103756346, 0.482158, + 0.49171793, 0.4108867, 0.22594318, 0.97093135, 0.5974685, + 0.34632966, 0.54835194, 0.10499302, 0.9767778, 0.55008715, + 0.54379046, 0.3583731, 0.33369112, 0.04279039, 0.24939054, + 0.23943715, 0.06775989, 0.7750291, 0.24329625, 0.4327169, + 0.86916673, 0.80322117, 0.049972698, 0.47177452, 0.37419558, + 0.15303156, 0.121425234, 0.75884604, 0.8191354, 0.48554084, + 0.053899214, 0.7858246, 0.39219773, 0.77579063, 0.34507045, + 0.46070176, 0.14496958, 0.47706795, 0.50678796, 0.64902323, + 0.3277943, 0.0017530271, 0.6536156, 0.8582253, 0.95703506, + 0.9963951, 0.8239163, 0.305142, 0.012419582, 0.9498972, + 0.1595827, 0.47947606, 0.5071124, 0.78227425, 0.2066719, + 0.5217094, 0.7841406, 0.5260441, 0.49798164, 0.10975622, + 0.8633349, 0.76298475, 0.14295428, 0.6131504, 0.43794408, + 0.50339264, 0.4504877, 0.19235311, 0.6678411, 0.80769485, + 0.67495126, 0.96461457, 0.10535406, 0.66438645, 0.4372345, + 0.93851465, 0.8635335, 0.3405871, 0.45652762, 0.3636232, + 0.52931345, 0.20154329, 0.07698499, 0.6125804, 0.3583082, + 0.3894796, 0.32601944, 0.5237369, 0.66683626, 0.08541841, + 0.4815708, 0.11897489, 0.97555137, 0.3602705, 0.9620871, + 0.6361821, 0.71167386, 0.5134439, 0.57761437, 0.58598644, + 0.39387667, 0.6966405, 0.46841687, 0.85788506, 0.9957087, + 0.051309288, 0.24846801, 0.55938333, 0.10230542, 0.9370694, + 0.57527155, 0.54656035, 0.28896323, 0.51303476, 0.8865, + 0.38641605, 0.9836358}, + sd::DataType::FLOAT32); + + NDArray expGradI( + 'c', {bS, iC, iD, iH, iW}, + {3.510932, 3.4310975, 3.538762, 4.148549, 2.8380678, 2.5431657, + 3.3928843, 3.228055, 3.1467278, 3.2603023, 5.611751, 4.334653, + 3.3697734, 4.603307, 4.4357986, 4.32991, 3.0532732, 3.1370173, + 4.181534, 2.9965065, 2.8553872, 5.2719016, 4.5671935, 3.7027276, + 3.3517184, 5.2544537, 3.5107024, 4.1496124, 3.9333878, 3.1798909, + 3.1446428, 3.0932689, 3.9730802, 3.0466917, 4.9675374, 4.769673, + 3.766952, 3.6375027, 3.6492167, 4.9440994, 3.8379507, 3.467589, + 4.719474, 3.1295977, 4.5177174, 4.2760015, 2.8443856, 4.225355, + 4.377341, 4.4398847, 4.710785, 4.4199953, 3.928307, 4.8769503}, + sd::DataType::FLOAT32); + + sd::ops::upsampling3d_bp op; + auto results = op.evaluate({&input, &gradO}, {isNCDHW}); + auto gradI = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); } - ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests1, deconv2d_test1) { - - int bS=2, oH=4,oW=4, oC=5,iC=10, kH=2,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int iH=3,iW=3; - int paddingMode = 0; // 1-SAME, 0-VALID; - int dataFormat = 1; // 1-NHWC, 0-NCHW - - auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); - auto weights = NDArrayFactory::create('c', {kH, kW, oC, iC}); - auto exp = NDArrayFactory::create('c', {bS, oH, oW, oC}, { 2.75f, 7.75f, 12.75f, 17.75f, 22.75f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 27.75f, 32.75f, 37.75f, 42.75f, 47.75f, - 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 105.5f, 115.5f, 125.5f, 135.5f, 145.5f, - 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 105.5f, 115.5f, 125.5f, 135.5f, 145.5f, - 52.75f, 57.75f, 62.75f, 67.75f, 72.75f, 130.5f, 140.5f, 150.5f, 160.5f, 170.5f, 130.5f, 140.5f, 150.5f, 160.5f, 170.5f, 77.75f, 82.75f, 87.75f, 92.75f, 97.75f, - 2.75f, 7.75f, 12.75f, 17.75f, 22.75f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 27.75f, 32.75f, 37.75f, 42.75f, 47.75f, - 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 105.5f, 115.5f, 125.5f, 135.5f, 145.5f, - 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 105.5f, 115.5f, 125.5f, 135.5f, 145.5f, - 52.75f, 57.75f, 62.75f, 67.75f, 72.75f, 130.5f, 140.5f, 150.5f, 160.5f, 170.5f, 130.5f, 140.5f, 150.5f, 160.5f, 170.5f, 77.75f, 82.75f, 87.75f, 92.75f, 97.75f}); - input = 0.5; - weights.linspace(0.1, 0.1); - - sd::ops::deconv2d op; - auto results = op.evaluate({&input, &weights}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); - ASSERT_EQ(Status::OK(), results.status()); - - auto output = results.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + int bS = 2, oH = 4, oW = 4, oC = 5, iC = 10, kH = 2, kW = 2, sH = 1, sW = 1, + pH = 0, pW = 0, dH = 1, dW = 1; + int iH = 3, iW = 3; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NHWC, 0-NCHW + + auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); + auto weights = NDArrayFactory::create('c', {kH, kW, oC, iC}); + auto exp = NDArrayFactory::create( + 'c', {bS, oH, oW, oC}, + {2.75f, 7.75f, 12.75f, 17.75f, 22.75f, 30.5f, 40.5f, 50.5f, 60.5f, + 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 27.75f, 32.75f, 37.75f, + 42.75f, 47.75f, 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, + 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 105.5f, + 115.5f, 125.5f, 135.5f, 145.5f, 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, + 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, + 241.f, 105.5f, 115.5f, 125.5f, 135.5f, 145.5f, 52.75f, 57.75f, 62.75f, + 67.75f, 72.75f, 130.5f, 140.5f, 150.5f, 160.5f, 170.5f, 130.5f, 140.5f, + 150.5f, 160.5f, 170.5f, 77.75f, 82.75f, 87.75f, 92.75f, 97.75f, 2.75f, + 7.75f, 12.75f, 17.75f, 22.75f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, + 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 27.75f, 32.75f, 37.75f, 42.75f, + 47.75f, 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, + 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 105.5f, 115.5f, + 125.5f, 135.5f, 145.5f, 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, + 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, + 105.5f, 115.5f, 125.5f, 135.5f, 145.5f, 52.75f, 57.75f, 62.75f, 67.75f, + 72.75f, 130.5f, 140.5f, 150.5f, 160.5f, 170.5f, 130.5f, 140.5f, 150.5f, + 160.5f, 170.5f, 77.75f, 82.75f, 87.75f, 92.75f, 97.75f}); + input = 0.5; + weights.linspace(0.1, 0.1); + + sd::ops::deconv2d op; + auto results = op.evaluate({&input, &weights}, {kH, kW, sH, sW, pH, pW, dH, + dW, paddingMode, dataFormat}); + ASSERT_EQ(Status::OK(), results.status()); + + auto output = results.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests1, deconv2d_test2) { - - int bS=2, iH=4,iW=4, iC=5,oC=10, kH=2,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oH=4,oW=4; - int paddingMode = 1; // 1-SAME, 0-VALID; - int dataFormat = 1; // 1-NHWC, 0-NCHW - - auto input = NDArrayFactory::create('c', {bS, oH, oW, oC}); - auto weights = NDArrayFactory::create('c', {kH, kW, iC, oC}); - auto exp = NDArrayFactory::create('c', {bS, iH, iW, iC}, {2.75f, 7.75f, 12.75f, 17.75f, 22.75f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, - 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, - 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, - 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, - 2.75f, 7.75f, 12.75f, 17.75f, 22.75f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, - 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, - 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, - 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f }); - input = 0.5; - weights.linspace(0.1, 0.1); - - sd::ops::deconv2d op; - auto results = op.evaluate({&input, &weights}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + int bS = 2, iH = 4, iW = 4, iC = 5, oC = 10, kH = 2, kW = 2, sH = 1, sW = 1, + pH = 0, pW = 0, dH = 1, dW = 1; + int oH = 4, oW = 4; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NHWC, 0-NCHW + + auto input = NDArrayFactory::create('c', {bS, oH, oW, oC}); + auto weights = NDArrayFactory::create('c', {kH, kW, iC, oC}); + auto exp = NDArrayFactory::create( + 'c', {bS, iH, iW, iC}, + {2.75f, 7.75f, 12.75f, 17.75f, 22.75f, 30.5f, 40.5f, 50.5f, 60.5f, + 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, + 60.5f, 70.5f, 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, + 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, + 181.f, 201.f, 221.f, 241.f, 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, + 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, + 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 55.5f, 65.5f, 75.5f, + 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, + 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 2.75f, + 7.75f, 12.75f, 17.75f, 22.75f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, + 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, + 70.5f, 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, + 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, + 201.f, 221.f, 241.f, 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, + 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, + 161.f, 181.f, 201.f, 221.f, 241.f, 55.5f, 65.5f, 75.5f, 85.5f, + 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, + 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f}); + input = 0.5; + weights.linspace(0.1, 0.1); + + sd::ops::deconv2d op; + auto results = op.evaluate({&input, &weights}, {kH, kW, sH, sW, pH, pW, dH, + dW, paddingMode, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests1, deconv2d_test3) { - - int bS=1, oH=5,oW=5, oC=3,iC=2, kH=2,kW=2, sH=1,sW=1, pH=0,pW=0, dH=2,dW=2; - int iH=3,iW=3; - int paddingMode = 0; // 1-SAME, 0-VALID; - int dataFormat = 1; // 1-NHWC, 0-NCHW - - auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); - auto weights = NDArrayFactory::create('c', {kH, kW, oC, iC}); - auto bias = NDArrayFactory::create('c', {oC}); - - auto exp = NDArrayFactory::create('c', {bS, oH, oW, oC}, {-2.9f, -6.8f, -10.7f, -2.6f, -6.1f, -9.6f, -16.9f, -23.9f, -30.9f, -13.1f, -16.6f, -20.1f, -11.6f, -14.7f, -17.8f, -2.0f, -4.7f, -7.4f, -1.7f, -4.0f, -6.3f, -11.5f, -16.1f, - -20.7f, -8.6f, -10.9f, -13.2f, -7.1f, -9.0f, -10.9f, -27.4f, -32.8f, -38.2f, -24.4f, -29.0f, -33.6f, -65.0f, -74.2f, -83.4f, -38.2f, -42.8f, -47.4f, - -32.8f, -36.6f, -40.4f, -18.2f, -20.9f, -23.6f, -15.5f, -17.8f, -20.1f, -39.1f, -43.7f, -48.3f, -22.4f, -24.7f, -27.0f, -18.5f, -20.4f, -22.3f, -10.1f, -11.6f, -13.1f, - -7.4f, -8.5f, -9.6f, -19.3f, -21.5f, -23.7f, -10.7f, -11.8f, -12.9f, -6.8f, -7.5f, -8.2f}); - - input.linspace(-10, 0.5); - weights.linspace(0.1, 0.1); - bias = 0.2; - - sd::ops::deconv2d op; - auto results = op.evaluate({&input, &weights}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); - ASSERT_EQ(Status::OK(), results.status()); - - auto output = results.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - + int bS = 1, oH = 5, oW = 5, oC = 3, iC = 2, kH = 2, kW = 2, sH = 1, sW = 1, + pH = 0, pW = 0, dH = 2, dW = 2; + int iH = 3, iW = 3; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NHWC, 0-NCHW + + auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); + auto weights = NDArrayFactory::create('c', {kH, kW, oC, iC}); + auto bias = NDArrayFactory::create('c', {oC}); + + auto exp = NDArrayFactory::create( + 'c', {bS, oH, oW, oC}, + {-2.9f, -6.8f, -10.7f, -2.6f, -6.1f, -9.6f, -16.9f, -23.9f, -30.9f, + -13.1f, -16.6f, -20.1f, -11.6f, -14.7f, -17.8f, -2.0f, -4.7f, -7.4f, + -1.7f, -4.0f, -6.3f, -11.5f, -16.1f, -20.7f, -8.6f, -10.9f, -13.2f, + -7.1f, -9.0f, -10.9f, -27.4f, -32.8f, -38.2f, -24.4f, -29.0f, -33.6f, + -65.0f, -74.2f, -83.4f, -38.2f, -42.8f, -47.4f, -32.8f, -36.6f, -40.4f, + -18.2f, -20.9f, -23.6f, -15.5f, -17.8f, -20.1f, -39.1f, -43.7f, -48.3f, + -22.4f, -24.7f, -27.0f, -18.5f, -20.4f, -22.3f, -10.1f, -11.6f, -13.1f, + -7.4f, -8.5f, -9.6f, -19.3f, -21.5f, -23.7f, -10.7f, -11.8f, -12.9f, + -6.8f, -7.5f, -8.2f}); + + input.linspace(-10, 0.5); + weights.linspace(0.1, 0.1); + bias = 0.2; + + sd::ops::deconv2d op; + auto results = op.evaluate({&input, &weights}, {kH, kW, sH, sW, pH, pW, dH, + dW, paddingMode, dataFormat}); + ASSERT_EQ(Status::OK(), results.status()); + + auto output = results.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests1, deconv2d_test4) { - - NDArray input('c', {2, 3, 4, 4}, sd::DataType::FLOAT32); - NDArray weights('c', {3, 3, 5, 5}, sd::DataType::FLOAT32); - NDArray exp('c', {2,3,8,8}, {6276.0,12831.0,19668.0,26790.0,27012.0,20703.0,14100.0,7200.0,13719.0,28023.0,42918.0,58410.0,58902.0,45105.0,30693.0,15660.0,22389.0,45696.0,69930.0,95100.0,95910.0,73386.0,49899.0,25440.0,32346.0,65970.0, - 100884.0,137100.0,138276.0,105726.0,71838.0,36600.0,33726.0,68790.0,105204.0,142980.0,144156.0,110226.0,74898.0,38160.0,27555.0,56154.0,85806.0,116520.0,117474.0,89748.0,60933.0,31020.0,19917.0,40557.0,61926.0, - 84030.0,84714.0,64671.0,43875.0,22320.0,10752.0,21879.0,33384.0,45270.0,45636.0,34815.0,23604.0,12000.0,7551.0,15456.0,23718.0,32340.0,32562.0,24978.0,17025.0,8700.0,16569.0,33873.0,51918.0,70710.0,71202.0, - 54555.0,37143.0,18960.0,27114.0,55371.0,84780.0,115350.0,116160.0,88911.0,60474.0,30840.0,39246.0,80070.0,122484.0,166500.0,167676.0,128226.0,87138.0,44400.0,40626.0,82890.0,126804.0,172380.0,173556.0,132726.0, - 90198.0,45960.0,33180.0,67629.0,103356.0,140370.0,141324.0,107973.0,73308.0,37320.0,23967.0,48807.0,74526.0,101130.0,101814.0,77721.0,52725.0,26820.0,12927.0,26304.0,40134.0,54420.0,54786.0,41790.0,28329.0,14400.0, - 8826.0,18081.0,27768.0,37890.0,38112.0,29253.0,19950.0,10200.0,19419.0,39723.0,60918.0,83010.0,83502.0,64005.0,43593.0,22260.0,31839.0,65046.0,99630.0,135600.0,136410.0,104436.0,71049.0,36240.0,46146.0,94170.0, - 144084.0,195900.0,197076.0,150726.0,102438.0,52200.0,47526.0,96990.0,148404.0,201780.0,202956.0,155226.0,105498.0,53760.0,38805.0,79104.0,120906.0,164220.0,165174.0,126198.0,85683.0,43620.0,28017.0,57057.0,87126.0, - 118230.0,118914.0,90771.0,61575.0,31320.0,15102.0,30729.0,46884.0,63570.0,63936.0,48765.0,33054.0,16800.0,17220.0,34863.0,52932.0,71430.0,72228.0,54831.0,36996.0,18720.0,36327.0,73527.0,111606.0,150570.0,152214.0, - 115521.0,77925.0,39420.0,57381.0,116112.0,176202.0,237660.0,240198.0,182250.0,122907.0,62160.0,80442.0,162738.0,246900.0,332940.0,336420.0,255198.0,172062.0,87000.0,84702.0,171318.0,259860.0,350340.0,353820.0, - 268338.0,180882.0,91440.0,66867.0,135210.0,205038.0,276360.0,279042.0,211572.0,142581.0,72060.0,46845.0,94701.0,143574.0,193470.0,195306.0,148047.0,99747.0,50400.0,24576.0,49671.0,75288.0,101430.0,102372.0,77583.0, - 52260.0,26400.0,22095.0,44688.0,67782.0,91380.0,92178.0,69906.0,47121.0,23820.0,46377.0,93777.0,142206.0,191670.0,193314.0,146571.0,98775.0,49920.0,72906.0,147387.0,223452.0,301110.0,303648.0,230175.0,155082.0, - 78360.0,101742.0,205638.0,311700.0,419940.0,423420.0,320898.0,216162.0,109200.0,106002.0,214218.0,324660.0,437340.0,440820.0,334038.0,224982.0,113640.0,83292.0,168285.0,254988.0,343410.0,346092.0,262197.0,176556.0, - 89160.0,58095.0,117351.0,177774.0,239370.0,241206.0,182697.0,122997.0,62100.0,30351.0,61296.0,92838.0,124980.0,125922.0,95358.0,64185.0,32400.0,26970.0,54513.0,82632.0,111330.0,112128.0,84981.0,57246.0,28920.0,56427.0,114027.0,172806.0,232770.0,234414.0,177621.0,119625.0,60420.0,88431.0,178662.0,270702.0,364560.0,367098.0,278100.0,187257.0,94560.0,123042.0,248538.0,376500.0,506940.0,510420.0,386598.0,260262.0,131400.0,127302.0,257118.0,389460.0,524340.0,527820.0,399738.0,269082.0,135840.0,99717.0,201360.0,304938.0,410460.0,413142.0,312822.0,210531.0,106260.0,69345.0,140001.0,211974.0,285270.0,287106.0,217347.0,146247.0,73800.0,36126.0,72921.0,110388.0,148530.0,149472.0,113133.0,76110.0,38400.0}, sd::DataType::FLOAT32); - - input.linspace(1); - weights.linspace(1); - weights.permutei({2,3,1,0}); - - sd::ops::deconv2d op; - auto result = op.evaluate({&input, &weights}, {5, 5, 1, 1, 0, 0, 1, 1, 0, 0}); - - auto z = result.at(0); - // z->printShapeInfo(); - // z->printBuffer(); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - + NDArray input('c', {2, 3, 4, 4}, sd::DataType::FLOAT32); + NDArray weights('c', {3, 3, 5, 5}, sd::DataType::FLOAT32); + NDArray exp( + 'c', {2, 3, 8, 8}, + {6276.0, 12831.0, 19668.0, 26790.0, 27012.0, 20703.0, 14100.0, + 7200.0, 13719.0, 28023.0, 42918.0, 58410.0, 58902.0, 45105.0, + 30693.0, 15660.0, 22389.0, 45696.0, 69930.0, 95100.0, 95910.0, + 73386.0, 49899.0, 25440.0, 32346.0, 65970.0, 100884.0, 137100.0, + 138276.0, 105726.0, 71838.0, 36600.0, 33726.0, 68790.0, 105204.0, + 142980.0, 144156.0, 110226.0, 74898.0, 38160.0, 27555.0, 56154.0, + 85806.0, 116520.0, 117474.0, 89748.0, 60933.0, 31020.0, 19917.0, + 40557.0, 61926.0, 84030.0, 84714.0, 64671.0, 43875.0, 22320.0, + 10752.0, 21879.0, 33384.0, 45270.0, 45636.0, 34815.0, 23604.0, + 12000.0, 7551.0, 15456.0, 23718.0, 32340.0, 32562.0, 24978.0, + 17025.0, 8700.0, 16569.0, 33873.0, 51918.0, 70710.0, 71202.0, + 54555.0, 37143.0, 18960.0, 27114.0, 55371.0, 84780.0, 115350.0, + 116160.0, 88911.0, 60474.0, 30840.0, 39246.0, 80070.0, 122484.0, + 166500.0, 167676.0, 128226.0, 87138.0, 44400.0, 40626.0, 82890.0, + 126804.0, 172380.0, 173556.0, 132726.0, 90198.0, 45960.0, 33180.0, + 67629.0, 103356.0, 140370.0, 141324.0, 107973.0, 73308.0, 37320.0, + 23967.0, 48807.0, 74526.0, 101130.0, 101814.0, 77721.0, 52725.0, + 26820.0, 12927.0, 26304.0, 40134.0, 54420.0, 54786.0, 41790.0, + 28329.0, 14400.0, 8826.0, 18081.0, 27768.0, 37890.0, 38112.0, + 29253.0, 19950.0, 10200.0, 19419.0, 39723.0, 60918.0, 83010.0, + 83502.0, 64005.0, 43593.0, 22260.0, 31839.0, 65046.0, 99630.0, + 135600.0, 136410.0, 104436.0, 71049.0, 36240.0, 46146.0, 94170.0, + 144084.0, 195900.0, 197076.0, 150726.0, 102438.0, 52200.0, 47526.0, + 96990.0, 148404.0, 201780.0, 202956.0, 155226.0, 105498.0, 53760.0, + 38805.0, 79104.0, 120906.0, 164220.0, 165174.0, 126198.0, 85683.0, + 43620.0, 28017.0, 57057.0, 87126.0, 118230.0, 118914.0, 90771.0, + 61575.0, 31320.0, 15102.0, 30729.0, 46884.0, 63570.0, 63936.0, + 48765.0, 33054.0, 16800.0, 17220.0, 34863.0, 52932.0, 71430.0, + 72228.0, 54831.0, 36996.0, 18720.0, 36327.0, 73527.0, 111606.0, + 150570.0, 152214.0, 115521.0, 77925.0, 39420.0, 57381.0, 116112.0, + 176202.0, 237660.0, 240198.0, 182250.0, 122907.0, 62160.0, 80442.0, + 162738.0, 246900.0, 332940.0, 336420.0, 255198.0, 172062.0, 87000.0, + 84702.0, 171318.0, 259860.0, 350340.0, 353820.0, 268338.0, 180882.0, + 91440.0, 66867.0, 135210.0, 205038.0, 276360.0, 279042.0, 211572.0, + 142581.0, 72060.0, 46845.0, 94701.0, 143574.0, 193470.0, 195306.0, + 148047.0, 99747.0, 50400.0, 24576.0, 49671.0, 75288.0, 101430.0, + 102372.0, 77583.0, 52260.0, 26400.0, 22095.0, 44688.0, 67782.0, + 91380.0, 92178.0, 69906.0, 47121.0, 23820.0, 46377.0, 93777.0, + 142206.0, 191670.0, 193314.0, 146571.0, 98775.0, 49920.0, 72906.0, + 147387.0, 223452.0, 301110.0, 303648.0, 230175.0, 155082.0, 78360.0, + 101742.0, 205638.0, 311700.0, 419940.0, 423420.0, 320898.0, 216162.0, + 109200.0, 106002.0, 214218.0, 324660.0, 437340.0, 440820.0, 334038.0, + 224982.0, 113640.0, 83292.0, 168285.0, 254988.0, 343410.0, 346092.0, + 262197.0, 176556.0, 89160.0, 58095.0, 117351.0, 177774.0, 239370.0, + 241206.0, 182697.0, 122997.0, 62100.0, 30351.0, 61296.0, 92838.0, + 124980.0, 125922.0, 95358.0, 64185.0, 32400.0, 26970.0, 54513.0, + 82632.0, 111330.0, 112128.0, 84981.0, 57246.0, 28920.0, 56427.0, + 114027.0, 172806.0, 232770.0, 234414.0, 177621.0, 119625.0, 60420.0, + 88431.0, 178662.0, 270702.0, 364560.0, 367098.0, 278100.0, 187257.0, + 94560.0, 123042.0, 248538.0, 376500.0, 506940.0, 510420.0, 386598.0, + 260262.0, 131400.0, 127302.0, 257118.0, 389460.0, 524340.0, 527820.0, + 399738.0, 269082.0, 135840.0, 99717.0, 201360.0, 304938.0, 410460.0, + 413142.0, 312822.0, 210531.0, 106260.0, 69345.0, 140001.0, 211974.0, + 285270.0, 287106.0, 217347.0, 146247.0, 73800.0, 36126.0, 72921.0, + 110388.0, 148530.0, 149472.0, 113133.0, 76110.0, 38400.0}, + sd::DataType::FLOAT32); + + input.linspace(1); + weights.linspace(1); + weights.permutei({2, 3, 1, 0}); + + sd::ops::deconv2d op; + auto result = op.evaluate({&input, &weights}, {5, 5, 1, 1, 0, 0, 1, 1, 0, 0}); + + auto z = result.at(0); + // z->printShapeInfo(); + // z->printBuffer(); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests1, deconv2d_test5) { - Nd4jLong _expS[] = {4, 2, 3, 8, 8, 192, 64, 8, 1, 16384, 1, 99}; - double _expB[] = {6276.0,12831.0,19668.0,26790.0,27012.0,20703.0,14100.0,7200.0,13719.0,28023.0,42918.0,58410.0,58902.0,45105.0,30693.0,15660.0,22389.0,45696.0,69930.0,95100.0,95910.0,73386.0,49899.0,25440.0,32346.0,65970.0,100884.0,137100.0,138276.0,105726.0,71838.0,36600.0,33726.0,68790.0,105204.0,142980.0,144156.0,110226.0,74898.0,38160.0,27555.0,56154.0,85806.0,116520.0,117474.0,89748.0,60933.0,31020.0,19917.0,40557.0,61926.0,84030.0,84714.0,64671.0,43875.0,22320.0,10752.0,21879.0,33384.0,45270.0,45636.0,34815.0,23604.0,12000.0,7551.0,15456.0,23718.0,32340.0,32562.0,24978.0,17025.0,8700.0,16569.0,33873.0,51918.0,70710.0,71202.0,54555.0,37143.0,18960.0,27114.0,55371.0,84780.0,115350.0,116160.0,88911.0,60474.0,30840.0,39246.0,80070.0,122484.0,166500.0,167676.0,128226.0,87138.0,44400.0,40626.0,82890.0,126804.0,172380.0,173556.0,132726.0,90198.0,45960.0,33180.0,67629.0,103356.0,140370.0,141324.0,107973.0,73308.0,37320.0,23967.0,48807.0,74526.0,101130.0,101814.0,77721.0,52725.0,26820.0,12927.0,26304.0,40134.0,54420.0,54786.0,41790.0,28329.0,14400.0,8826.0,18081.0,27768.0,37890.0,38112.0,29253.0,19950.0,10200.0,19419.0,39723.0,60918.0,83010.0,83502.0,64005.0,43593.0,22260.0,31839.0,65046.0,99630.0,135600.0,136410.0,104436.0,71049.0,36240.0,46146.0,94170.0,144084.0,195900.0,197076.0,150726.0,102438.0,52200.0,47526.0,96990.0,148404.0,201780.0,202956.0,155226.0,105498.0,53760.0,38805.0,79104.0,120906.0,164220.0,165174.0,126198.0,85683.0,43620.0,28017.0,57057.0,87126.0,118230.0,118914.0,90771.0,61575.0,31320.0,15102.0,30729.0,46884.0,63570.0,63936.0,48765.0,33054.0,16800.0,17220.0,34863.0,52932.0,71430.0,72228.0,54831.0,36996.0,18720.0,36327.0,73527.0,111606.0,150570.0,152214.0,115521.0,77925.0,39420.0,57381.0,116112.0,176202.0,237660.0,240198.0,182250.0,122907.0,62160.0,80442.0,162738.0,246900.0,332940.0,336420.0,255198.0,172062.0,87000.0,84702.0,171318.0,259860.0,350340.0,353820.0,268338.0,180882.0,91440.0,66867.0,135210.0,205038.0,276360.0,279042.0,211572.0,142581.0,72060.0,46845.0,94701.0,143574.0,193470.0,195306.0,148047.0,99747.0,50400.0,24576.0,49671.0,75288.0,101430.0,102372.0,77583.0,52260.0,26400.0,22095.0,44688.0,67782.0,91380.0,92178.0,69906.0,47121.0,23820.0,46377.0,93777.0,142206.0,191670.0,193314.0,146571.0,98775.0,49920.0,72906.0,147387.0,223452.0,301110.0,303648.0,230175.0,155082.0,78360.0,101742.0,205638.0,311700.0,419940.0,423420.0,320898.0,216162.0,109200.0,106002.0,214218.0,324660.0,437340.0,440820.0,334038.0,224982.0,113640.0,83292.0,168285.0,254988.0,343410.0,346092.0,262197.0,176556.0,89160.0,58095.0,117351.0,177774.0,239370.0,241206.0,182697.0,122997.0,62100.0,30351.0,61296.0,92838.0,124980.0,125922.0,95358.0,64185.0,32400.0,26970.0,54513.0,82632.0,111330.0,112128.0,84981.0,57246.0,28920.0,56427.0,114027.0,172806.0,232770.0,234414.0,177621.0,119625.0,60420.0,88431.0,178662.0,270702.0,364560.0,367098.0,278100.0,187257.0,94560.0,123042.0,248538.0,376500.0,506940.0,510420.0,386598.0,260262.0,131400.0,127302.0,257118.0,389460.0,524340.0,527820.0,399738.0,269082.0,135840.0,99717.0,201360.0,304938.0,410460.0,413142.0,312822.0,210531.0,106260.0,69345.0,140001.0,211974.0,285270.0,287106.0,217347.0,146247.0,73800.0,36126.0,72921.0,110388.0,148530.0,149472.0,113133.0,76110.0,38400.0,}; - NDArray exp(_expB, _expS); - - auto input = NDArrayFactory::create('c', {2, 3, 4, 4}); - auto weights = NDArrayFactory::create('c', {3, 3, 5, 5}); - auto z = NDArrayFactory::create('c', {2, 3, 8, 8}); - - input.linspace(1); - weights.linspace(1); - weights.permutei({2,3,1,0}); - - sd::ops::deconv2d op; - auto result = op.execute({&input, &weights}, {&z}, {5, 5, 1, 1, 0, 0, 1, 1, 0, 0}); - - ASSERT_EQ(ND4J_STATUS_OK, result); - - ASSERT_TRUE(exp.isSameShape(&z)); - ASSERT_TRUE(exp.equalsTo(&z)); + Nd4jLong _expS[] = {4, 2, 3, 8, 8, 192, 64, 8, 1, 16384, 1, 99}; + double _expB[] = { + 6276.0, 12831.0, 19668.0, 26790.0, 27012.0, 20703.0, 14100.0, + 7200.0, 13719.0, 28023.0, 42918.0, 58410.0, 58902.0, 45105.0, + 30693.0, 15660.0, 22389.0, 45696.0, 69930.0, 95100.0, 95910.0, + 73386.0, 49899.0, 25440.0, 32346.0, 65970.0, 100884.0, 137100.0, + 138276.0, 105726.0, 71838.0, 36600.0, 33726.0, 68790.0, 105204.0, + 142980.0, 144156.0, 110226.0, 74898.0, 38160.0, 27555.0, 56154.0, + 85806.0, 116520.0, 117474.0, 89748.0, 60933.0, 31020.0, 19917.0, + 40557.0, 61926.0, 84030.0, 84714.0, 64671.0, 43875.0, 22320.0, + 10752.0, 21879.0, 33384.0, 45270.0, 45636.0, 34815.0, 23604.0, + 12000.0, 7551.0, 15456.0, 23718.0, 32340.0, 32562.0, 24978.0, + 17025.0, 8700.0, 16569.0, 33873.0, 51918.0, 70710.0, 71202.0, + 54555.0, 37143.0, 18960.0, 27114.0, 55371.0, 84780.0, 115350.0, + 116160.0, 88911.0, 60474.0, 30840.0, 39246.0, 80070.0, 122484.0, + 166500.0, 167676.0, 128226.0, 87138.0, 44400.0, 40626.0, 82890.0, + 126804.0, 172380.0, 173556.0, 132726.0, 90198.0, 45960.0, 33180.0, + 67629.0, 103356.0, 140370.0, 141324.0, 107973.0, 73308.0, 37320.0, + 23967.0, 48807.0, 74526.0, 101130.0, 101814.0, 77721.0, 52725.0, + 26820.0, 12927.0, 26304.0, 40134.0, 54420.0, 54786.0, 41790.0, + 28329.0, 14400.0, 8826.0, 18081.0, 27768.0, 37890.0, 38112.0, + 29253.0, 19950.0, 10200.0, 19419.0, 39723.0, 60918.0, 83010.0, + 83502.0, 64005.0, 43593.0, 22260.0, 31839.0, 65046.0, 99630.0, + 135600.0, 136410.0, 104436.0, 71049.0, 36240.0, 46146.0, 94170.0, + 144084.0, 195900.0, 197076.0, 150726.0, 102438.0, 52200.0, 47526.0, + 96990.0, 148404.0, 201780.0, 202956.0, 155226.0, 105498.0, 53760.0, + 38805.0, 79104.0, 120906.0, 164220.0, 165174.0, 126198.0, 85683.0, + 43620.0, 28017.0, 57057.0, 87126.0, 118230.0, 118914.0, 90771.0, + 61575.0, 31320.0, 15102.0, 30729.0, 46884.0, 63570.0, 63936.0, + 48765.0, 33054.0, 16800.0, 17220.0, 34863.0, 52932.0, 71430.0, + 72228.0, 54831.0, 36996.0, 18720.0, 36327.0, 73527.0, 111606.0, + 150570.0, 152214.0, 115521.0, 77925.0, 39420.0, 57381.0, 116112.0, + 176202.0, 237660.0, 240198.0, 182250.0, 122907.0, 62160.0, 80442.0, + 162738.0, 246900.0, 332940.0, 336420.0, 255198.0, 172062.0, 87000.0, + 84702.0, 171318.0, 259860.0, 350340.0, 353820.0, 268338.0, 180882.0, + 91440.0, 66867.0, 135210.0, 205038.0, 276360.0, 279042.0, 211572.0, + 142581.0, 72060.0, 46845.0, 94701.0, 143574.0, 193470.0, 195306.0, + 148047.0, 99747.0, 50400.0, 24576.0, 49671.0, 75288.0, 101430.0, + 102372.0, 77583.0, 52260.0, 26400.0, 22095.0, 44688.0, 67782.0, + 91380.0, 92178.0, 69906.0, 47121.0, 23820.0, 46377.0, 93777.0, + 142206.0, 191670.0, 193314.0, 146571.0, 98775.0, 49920.0, 72906.0, + 147387.0, 223452.0, 301110.0, 303648.0, 230175.0, 155082.0, 78360.0, + 101742.0, 205638.0, 311700.0, 419940.0, 423420.0, 320898.0, 216162.0, + 109200.0, 106002.0, 214218.0, 324660.0, 437340.0, 440820.0, 334038.0, + 224982.0, 113640.0, 83292.0, 168285.0, 254988.0, 343410.0, 346092.0, + 262197.0, 176556.0, 89160.0, 58095.0, 117351.0, 177774.0, 239370.0, + 241206.0, 182697.0, 122997.0, 62100.0, 30351.0, 61296.0, 92838.0, + 124980.0, 125922.0, 95358.0, 64185.0, 32400.0, 26970.0, 54513.0, + 82632.0, 111330.0, 112128.0, 84981.0, 57246.0, 28920.0, 56427.0, + 114027.0, 172806.0, 232770.0, 234414.0, 177621.0, 119625.0, 60420.0, + 88431.0, 178662.0, 270702.0, 364560.0, 367098.0, 278100.0, 187257.0, + 94560.0, 123042.0, 248538.0, 376500.0, 506940.0, 510420.0, 386598.0, + 260262.0, 131400.0, 127302.0, 257118.0, 389460.0, 524340.0, 527820.0, + 399738.0, 269082.0, 135840.0, 99717.0, 201360.0, 304938.0, 410460.0, + 413142.0, 312822.0, 210531.0, 106260.0, 69345.0, 140001.0, 211974.0, + 285270.0, 287106.0, 217347.0, 146247.0, 73800.0, 36126.0, 72921.0, + 110388.0, 148530.0, 149472.0, 113133.0, 76110.0, 38400.0, + }; + NDArray exp(_expB, _expS); + + auto input = NDArrayFactory::create('c', {2, 3, 4, 4}); + auto weights = NDArrayFactory::create('c', {3, 3, 5, 5}); + auto z = NDArrayFactory::create('c', {2, 3, 8, 8}); + + input.linspace(1); + weights.linspace(1); + weights.permutei({2, 3, 1, 0}); + + sd::ops::deconv2d op; + auto result = + op.execute({&input, &weights}, {&z}, {5, 5, 1, 1, 0, 0, 1, 1, 0, 0}); + + ASSERT_EQ(ND4J_STATUS_OK, result); + + ASSERT_TRUE(exp.isSameShape(&z)); + ASSERT_TRUE(exp.equalsTo(&z)); } TYPED_TEST(TypedConvolutionTests1, deconv2d_test6) { - - int bS=2, iH=4,iW=4, iC=3,oC=3, kH=5,kW=5, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oH=8,oW=8; - int paddingMode = 0; // 1-SAME, 0-VALID; - int dataFormat = 0; // 1-NHWC, 0-NCHW - - auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); - auto weights = NDArrayFactory::create('c', {kH, kW, oC, iC}, {1.f, 76.f, 151.f, 26.f, 101.f, 176.f, 51.f, 126.f, 201.f, 2.f, 77.f, 152.f, 27.f, 102.f, 177.f, 52.f, 127.f, 202.f, 3.f, 78.f, 153.f, 28.f, 103.f, 178.f, 53.f, 128.f, 203.f, - 4.f, 79.f, 154.f, 29.f, 104.f, 179.f, 54.f, 129.f, 204.f, 5.f, 80.f, 155.f, 30.f, 105.f, 180.f, 55.f, 130.f, 205.f, 6.f, 81.f, 156.f, 31.f, 106.f, 181.f, 56.f, 131.f, 206.f, - 7.f, 82.f, 157.f, 32.f, 107.f, 182.f, 57.f, 132.f, 207.f, 8.f, 83.f, 158.f, 33.f, 108.f, 183.f, 58.f, 133.f, 208.f, 9.f, 84.f, 159.f, 34.f, 109.f, 184.f, 59.f, 134.f, 209.f, - 10.f, 85.f, 160.f, 35.f, 110.f, 185.f, 60.f, 135.f, 210.f, 11.f, 86.f, 161.f, 36.f, 111.f, 186.f, 61.f, 136.f, 211.f, 12.f, 87.f, 162.f, 37.f, 112.f, 187.f, 62.f, 137.f, 212.f, - 13.f, 88.f, 163.f, 38.f, 113.f, 188.f, 63.f, 138.f, 213.f, 14.f, 89.f, 164.f, 39.f, 114.f, 189.f, 64.f, 139.f, 214.f, 15.f, 90.f, 165.f, 40.f, 115.f, 190.f, 65.f, 140.f, 215.f, - 16.f, 91.f, 166.f, 41.f, 116.f, 191.f, 66.f, 141.f, 216.f, 17.f, 92.f, 167.f, 42.f, 117.f, 192.f, 67.f, 142.f, 217.f, 18.f, 93.f, 168.f, 43.f, 118.f, 193.f, 68.f, 143.f, 218.f, - 19.f, 94.f, 169.f, 44.f, 119.f, 194.f, 69.f, 144.f, 219.f, 20.f, 95.f, 170.f, 45.f, 120.f, 195.f, 70.f, 145.f, 220.f, 21.f, 96.f, 171.f, 46.f, 121.f, 196.f, 71.f, 146.f, 221.f, - 22.f, 97.f, 172.f, 47.f, 122.f, 197.f, 72.f, 147.f, 222.f, 23.f, 98.f, 173.f, 48.f, 123.f, 198.f, 73.f, 148.f, 223.f, 24.f, 99.f, 174.f, 49.f, 124.f, 199.f, 74.f, 149.f, 224.f, - 25.f, 100.f, 175.f,50.f, 125.f, 200.f,75.f, 150.f, 225.f}); - - auto exp = NDArrayFactory::create('c', {bS, oC, oH, oW}, {6276.0f, 12831.0f, 19668.0f, 26790.0f, 27012.0f, 20703.0f, 14100.0f, 7200.0f, 13719.0f, 28023.0f, 42918.0f, 58410.0f, 58902.0f, 45105.0f, 30693.0f, 15660.0f, 22389.0f, 45696.0f, 69930.0f, 95100.0f, 95910.0f, 73386.0f, 49899.0f, 25440.0f, 32346.0f, 65970.0f, 100884.0f, 137100.0f, 138276.0f, 105726.0f, 71838.0f, 36600.0f, 33726.0f, 68790.0f, 105204.0f, 142980.0f, 144156.0f, 110226.0f, 74898.0f, 38160.0f, 27555.0f, 56154.0f, 85806.0f, 116520.0f, 117474.0f, 89748.0f, 60933.0f, 31020.0f, 19917.0f, 40557.0f, 61926.0f, 84030.0f, 84714.0f, 64671.0f, 43875.0f, 22320.0f, 10752.0f, 21879.0f, 33384.0f, 45270.0f, 45636.0f, 34815.0f, 23604.0f, 12000.0f, 7551.0f, 15456.0f, 23718.0f, 32340.0f, 32562.0f, 24978.0f, 17025.0f, 8700.0f, 16569.0f, 33873.0f, 51918.0f, 70710.0f, 71202.0f, 54555.0f, 37143.0f, 18960.0f, 27114.0f, 55371.0f, 84780.0f, 115350.0f, 116160.0f, 88911.0f, 60474.0f, 30840.0f, 39246.0f, 80070.0f, 122484.0f, 166500.0f, 167676.0f, 128226.0f, 87138.0f, 44400.0f, 40626.0f, 82890.0f, 126804.0f, 172380.0f, 173556.0f, 132726.0f, 90198.0f, 45960.0f, 33180.0f, 67629.0f, 103356.0f, 140370.0f, 141324.0f, 107973.0f, 73308.0f, 37320.0f, 23967.0f, 48807.0f, 74526.0f, 101130.0f, 101814.0f, 77721.0f, 52725.0f, 26820.0f, 12927.0f, 26304.0f, 40134.0f, 54420.0f, 54786.0f, 41790.0f, 28329.0f, 14400.0f, 8826.0f, 18081.0f, 27768.0f, 37890.0f, 38112.0f, 29253.0f, 19950.0f, 10200.0f, 19419.0f, 39723.0f, 60918.0f, 83010.0f, 83502.0f, 64005.0f, 43593.0f, 22260.0f, 31839.0f, 65046.0f, 99630.0f, 135600.0f, 136410.0f, 104436.0f, 71049.0f, 36240.0f, 46146.0f, 94170.0f, 144084.0f, 195900.0f, 197076.0f, 150726.0f, 102438.0f, 52200.0f, 47526.0f, 96990.0f, 148404.0f, 201780.0f, 202956.0f, 155226.0f, 105498.0f, 53760.0f, 38805.0f, 79104.0f, 120906.0f, 164220.0f, 165174.0f, 126198.0f, 85683.0f, 43620.0f, 28017.0f, 57057.0f, 87126.0f, 118230.0f, 118914.0f, 90771.0f, 61575.0f, 31320.0f, 15102.0f, 30729.0f, 46884.0f, 63570.0f, 63936.0f, 48765.0f, 33054.0f, 16800.0f, 17220.0f, 34863.0f, 52932.0f, 71430.0f, 72228.0f, 54831.0f, 36996.0f, 18720.0f, 36327.0f, 73527.0f, 111606.0f, 150570.0f, 152214.0f, 115521.0f, 77925.0f, 39420.0f, 57381.0f, 116112.0f, 176202.0f, 237660.0f, 240198.0f, 182250.0f, 122907.0f, 62160.0f, 80442.0f, 162738.0f, 246900.0f, 332940.0f, 336420.0f, 255198.0f, 172062.0f, 87000.0f, 84702.0f, 171318.0f, 259860.0f, 350340.0f, 353820.0f, 268338.0f, 180882.0f, 91440.0f, 66867.0f, 135210.0f, 205038.0f, 276360.0f, 279042.0f, 211572.0f, 142581.0f, 72060.0f, 46845.0f, 94701.0f, 143574.0f, 193470.0f, 195306.0f, 148047.0f, 99747.0f, 50400.0f, 24576.0f, 49671.0f, 75288.0f, 101430.0f, 102372.0f, 77583.0f, 52260.0f, 26400.0f, 22095.0f, 44688.0f, 67782.0f, 91380.0f, 92178.0f, 69906.0f, 47121.0f, 23820.0f, 46377.0f, 93777.0f, 142206.0f, 191670.0f, 193314.0f, 146571.0f, 98775.0f, 49920.0f, 72906.0f, 147387.0f, 223452.0f, 301110.0f, 303648.0f, 230175.0f, 155082.0f, 78360.0f, 101742.0f, 205638.0f, 311700.0f, 419940.0f, 423420.0f, 320898.0f, 216162.0f, 109200.0f, 106002.0f, 214218.0f, 324660.0f, 437340.0f, 440820.0f, 334038.0f, 224982.0f, 113640.0f, 83292.0f, 168285.0f, 254988.0f, 343410.0f, 346092.0f, 262197.0f, 176556.0f, 89160.0f, 58095.0f, 117351.0f, 177774.0f, 239370.0f, 241206.0f, 182697.0f, 122997.0f, 62100.0f, 30351.0f, 61296.0f, 92838.0f, 124980.0f, 125922.0f, 95358.0f, 64185.0f, 32400.0f, 26970.0f, 54513.0f, 82632.0f, 111330.0f, 112128.0f, 84981.0f, 57246.0f, 28920.0f, 56427.0f, 114027.0f, 172806.0f, 232770.0f, 234414.0f, 177621.0f, 119625.0f, 60420.0f, 88431.0f, 178662.0f, 270702.0f, 364560.0f, 367098.0f, 278100.0f, 187257.0f, 94560.0f, 123042.0f, 248538.0f, 376500.0f, 506940.0f, 510420.0f, 386598.0f, 260262.0f, 131400.0f, 127302.0f, 257118.0f, 389460.0f, 524340.0f, 527820.0f, 399738.0f, 269082.0f, 135840.0f, 99717.0f, 201360.0f, 304938.0f, 410460.0f, 413142.0f, 312822.0f, 210531.0f, 106260.0f, 69345.0f, 140001.0f, 211974.0f, 285270.0f, 287106.0f, 217347.0f, 146247.0f, 73800.0f, 36126.0f, 72921.0f, 110388.0f, 148530.0f, 149472.0f, 113133.0f, 76110.0f, 38400.0f}); - - input.linspace(1); - - sd::ops::deconv2d op; - auto results = op.evaluate({&input, &weights}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); - - ASSERT_EQ(Status::OK(), results.status()); - - auto output = results.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - + int bS = 2, iH = 4, iW = 4, iC = 3, oC = 3, kH = 5, kW = 5, sH = 1, sW = 1, + pH = 0, pW = 0, dH = 1, dW = 1; + int oH = 8, oW = 8; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NHWC, 0-NCHW + + auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); + auto weights = NDArrayFactory::create( + 'c', {kH, kW, oC, iC}, + {1.f, 76.f, 151.f, 26.f, 101.f, 176.f, 51.f, 126.f, 201.f, 2.f, + 77.f, 152.f, 27.f, 102.f, 177.f, 52.f, 127.f, 202.f, 3.f, 78.f, + 153.f, 28.f, 103.f, 178.f, 53.f, 128.f, 203.f, 4.f, 79.f, 154.f, + 29.f, 104.f, 179.f, 54.f, 129.f, 204.f, 5.f, 80.f, 155.f, 30.f, + 105.f, 180.f, 55.f, 130.f, 205.f, 6.f, 81.f, 156.f, 31.f, 106.f, + 181.f, 56.f, 131.f, 206.f, 7.f, 82.f, 157.f, 32.f, 107.f, 182.f, + 57.f, 132.f, 207.f, 8.f, 83.f, 158.f, 33.f, 108.f, 183.f, 58.f, + 133.f, 208.f, 9.f, 84.f, 159.f, 34.f, 109.f, 184.f, 59.f, 134.f, + 209.f, 10.f, 85.f, 160.f, 35.f, 110.f, 185.f, 60.f, 135.f, 210.f, + 11.f, 86.f, 161.f, 36.f, 111.f, 186.f, 61.f, 136.f, 211.f, 12.f, + 87.f, 162.f, 37.f, 112.f, 187.f, 62.f, 137.f, 212.f, 13.f, 88.f, + 163.f, 38.f, 113.f, 188.f, 63.f, 138.f, 213.f, 14.f, 89.f, 164.f, + 39.f, 114.f, 189.f, 64.f, 139.f, 214.f, 15.f, 90.f, 165.f, 40.f, + 115.f, 190.f, 65.f, 140.f, 215.f, 16.f, 91.f, 166.f, 41.f, 116.f, + 191.f, 66.f, 141.f, 216.f, 17.f, 92.f, 167.f, 42.f, 117.f, 192.f, + 67.f, 142.f, 217.f, 18.f, 93.f, 168.f, 43.f, 118.f, 193.f, 68.f, + 143.f, 218.f, 19.f, 94.f, 169.f, 44.f, 119.f, 194.f, 69.f, 144.f, + 219.f, 20.f, 95.f, 170.f, 45.f, 120.f, 195.f, 70.f, 145.f, 220.f, + 21.f, 96.f, 171.f, 46.f, 121.f, 196.f, 71.f, 146.f, 221.f, 22.f, + 97.f, 172.f, 47.f, 122.f, 197.f, 72.f, 147.f, 222.f, 23.f, 98.f, + 173.f, 48.f, 123.f, 198.f, 73.f, 148.f, 223.f, 24.f, 99.f, 174.f, + 49.f, 124.f, 199.f, 74.f, 149.f, 224.f, 25.f, 100.f, 175.f, 50.f, + 125.f, 200.f, 75.f, 150.f, 225.f}); + + auto exp = NDArrayFactory::create( + 'c', {bS, oC, oH, oW}, + {6276.0f, 12831.0f, 19668.0f, 26790.0f, 27012.0f, 20703.0f, + 14100.0f, 7200.0f, 13719.0f, 28023.0f, 42918.0f, 58410.0f, + 58902.0f, 45105.0f, 30693.0f, 15660.0f, 22389.0f, 45696.0f, + 69930.0f, 95100.0f, 95910.0f, 73386.0f, 49899.0f, 25440.0f, + 32346.0f, 65970.0f, 100884.0f, 137100.0f, 138276.0f, 105726.0f, + 71838.0f, 36600.0f, 33726.0f, 68790.0f, 105204.0f, 142980.0f, + 144156.0f, 110226.0f, 74898.0f, 38160.0f, 27555.0f, 56154.0f, + 85806.0f, 116520.0f, 117474.0f, 89748.0f, 60933.0f, 31020.0f, + 19917.0f, 40557.0f, 61926.0f, 84030.0f, 84714.0f, 64671.0f, + 43875.0f, 22320.0f, 10752.0f, 21879.0f, 33384.0f, 45270.0f, + 45636.0f, 34815.0f, 23604.0f, 12000.0f, 7551.0f, 15456.0f, + 23718.0f, 32340.0f, 32562.0f, 24978.0f, 17025.0f, 8700.0f, + 16569.0f, 33873.0f, 51918.0f, 70710.0f, 71202.0f, 54555.0f, + 37143.0f, 18960.0f, 27114.0f, 55371.0f, 84780.0f, 115350.0f, + 116160.0f, 88911.0f, 60474.0f, 30840.0f, 39246.0f, 80070.0f, + 122484.0f, 166500.0f, 167676.0f, 128226.0f, 87138.0f, 44400.0f, + 40626.0f, 82890.0f, 126804.0f, 172380.0f, 173556.0f, 132726.0f, + 90198.0f, 45960.0f, 33180.0f, 67629.0f, 103356.0f, 140370.0f, + 141324.0f, 107973.0f, 73308.0f, 37320.0f, 23967.0f, 48807.0f, + 74526.0f, 101130.0f, 101814.0f, 77721.0f, 52725.0f, 26820.0f, + 12927.0f, 26304.0f, 40134.0f, 54420.0f, 54786.0f, 41790.0f, + 28329.0f, 14400.0f, 8826.0f, 18081.0f, 27768.0f, 37890.0f, + 38112.0f, 29253.0f, 19950.0f, 10200.0f, 19419.0f, 39723.0f, + 60918.0f, 83010.0f, 83502.0f, 64005.0f, 43593.0f, 22260.0f, + 31839.0f, 65046.0f, 99630.0f, 135600.0f, 136410.0f, 104436.0f, + 71049.0f, 36240.0f, 46146.0f, 94170.0f, 144084.0f, 195900.0f, + 197076.0f, 150726.0f, 102438.0f, 52200.0f, 47526.0f, 96990.0f, + 148404.0f, 201780.0f, 202956.0f, 155226.0f, 105498.0f, 53760.0f, + 38805.0f, 79104.0f, 120906.0f, 164220.0f, 165174.0f, 126198.0f, + 85683.0f, 43620.0f, 28017.0f, 57057.0f, 87126.0f, 118230.0f, + 118914.0f, 90771.0f, 61575.0f, 31320.0f, 15102.0f, 30729.0f, + 46884.0f, 63570.0f, 63936.0f, 48765.0f, 33054.0f, 16800.0f, + 17220.0f, 34863.0f, 52932.0f, 71430.0f, 72228.0f, 54831.0f, + 36996.0f, 18720.0f, 36327.0f, 73527.0f, 111606.0f, 150570.0f, + 152214.0f, 115521.0f, 77925.0f, 39420.0f, 57381.0f, 116112.0f, + 176202.0f, 237660.0f, 240198.0f, 182250.0f, 122907.0f, 62160.0f, + 80442.0f, 162738.0f, 246900.0f, 332940.0f, 336420.0f, 255198.0f, + 172062.0f, 87000.0f, 84702.0f, 171318.0f, 259860.0f, 350340.0f, + 353820.0f, 268338.0f, 180882.0f, 91440.0f, 66867.0f, 135210.0f, + 205038.0f, 276360.0f, 279042.0f, 211572.0f, 142581.0f, 72060.0f, + 46845.0f, 94701.0f, 143574.0f, 193470.0f, 195306.0f, 148047.0f, + 99747.0f, 50400.0f, 24576.0f, 49671.0f, 75288.0f, 101430.0f, + 102372.0f, 77583.0f, 52260.0f, 26400.0f, 22095.0f, 44688.0f, + 67782.0f, 91380.0f, 92178.0f, 69906.0f, 47121.0f, 23820.0f, + 46377.0f, 93777.0f, 142206.0f, 191670.0f, 193314.0f, 146571.0f, + 98775.0f, 49920.0f, 72906.0f, 147387.0f, 223452.0f, 301110.0f, + 303648.0f, 230175.0f, 155082.0f, 78360.0f, 101742.0f, 205638.0f, + 311700.0f, 419940.0f, 423420.0f, 320898.0f, 216162.0f, 109200.0f, + 106002.0f, 214218.0f, 324660.0f, 437340.0f, 440820.0f, 334038.0f, + 224982.0f, 113640.0f, 83292.0f, 168285.0f, 254988.0f, 343410.0f, + 346092.0f, 262197.0f, 176556.0f, 89160.0f, 58095.0f, 117351.0f, + 177774.0f, 239370.0f, 241206.0f, 182697.0f, 122997.0f, 62100.0f, + 30351.0f, 61296.0f, 92838.0f, 124980.0f, 125922.0f, 95358.0f, + 64185.0f, 32400.0f, 26970.0f, 54513.0f, 82632.0f, 111330.0f, + 112128.0f, 84981.0f, 57246.0f, 28920.0f, 56427.0f, 114027.0f, + 172806.0f, 232770.0f, 234414.0f, 177621.0f, 119625.0f, 60420.0f, + 88431.0f, 178662.0f, 270702.0f, 364560.0f, 367098.0f, 278100.0f, + 187257.0f, 94560.0f, 123042.0f, 248538.0f, 376500.0f, 506940.0f, + 510420.0f, 386598.0f, 260262.0f, 131400.0f, 127302.0f, 257118.0f, + 389460.0f, 524340.0f, 527820.0f, 399738.0f, 269082.0f, 135840.0f, + 99717.0f, 201360.0f, 304938.0f, 410460.0f, 413142.0f, 312822.0f, + 210531.0f, 106260.0f, 69345.0f, 140001.0f, 211974.0f, 285270.0f, + 287106.0f, 217347.0f, 146247.0f, 73800.0f, 36126.0f, 72921.0f, + 110388.0f, 148530.0f, 149472.0f, 113133.0f, 76110.0f, 38400.0f}); + + input.linspace(1); + + sd::ops::deconv2d op; + auto results = op.evaluate({&input, &weights}, {kH, kW, sH, sW, pH, pW, dH, + dW, paddingMode, dataFormat}); + + ASSERT_EQ(Status::OK(), results.status()); + + auto output = results.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } TEST_F(ConvolutionTests1, deconv2d_test7) { - - NDArray exp('c', {3, 2, 4, 4}, {218., 227., 236., 245., 254., 263., 272., 281., 290., 299., 308., 317., 326., 335., 344., 353., 270., 282., 294., 306., 318., 330., 342., 354., 366., 378., 390., 402., 414., 426., 438., 450., 650., 659., 668., 677., 686., 695., 704., 713., 722., 731., 740., 749., 758., 767., 776., 785., 846., 858., 870., 882., 894., 906., 918., 930., 942., 954., 966., 978., 990., 1002., 1014., 1026., 1082., 1091., 1100., 1109., 1118., 1127., 1136., 1145., 1154., 1163., 1172., 1181., 1190., 1199., 1208., 1217., 1422., 1434., 1446., 1458., 1470., 1482., 1494., 1506., 1518., 1530., 1542., 1554., 1566., 1578., 1590., 1602.}); - - auto input = NDArrayFactory::create('c', {3, 3, 4, 4}); - auto weights = NDArrayFactory::create('c',{1, 1, 2, 3}, {1,3,5,2,4,6}); - auto bias = NDArrayFactory::create('c', {2}); - - input.linspace(1); - bias.linspace(1); - - sd::ops::deconv2d op; - - auto result = op.evaluate({&input, &weights, &bias}, {1, 1, 1, 1, 0, 0, 1, 1, 1, 0}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto output = result.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + NDArray exp( + 'c', {3, 2, 4, 4}, + {218., 227., 236., 245., 254., 263., 272., 281., 290., 299., + 308., 317., 326., 335., 344., 353., 270., 282., 294., 306., + 318., 330., 342., 354., 366., 378., 390., 402., 414., 426., + 438., 450., 650., 659., 668., 677., 686., 695., 704., 713., + 722., 731., 740., 749., 758., 767., 776., 785., 846., 858., + 870., 882., 894., 906., 918., 930., 942., 954., 966., 978., + 990., 1002., 1014., 1026., 1082., 1091., 1100., 1109., 1118., 1127., + 1136., 1145., 1154., 1163., 1172., 1181., 1190., 1199., 1208., 1217., + 1422., 1434., 1446., 1458., 1470., 1482., 1494., 1506., 1518., 1530., + 1542., 1554., 1566., 1578., 1590., 1602.}); + + auto input = NDArrayFactory::create('c', {3, 3, 4, 4}); + auto weights = + NDArrayFactory::create('c', {1, 1, 2, 3}, {1, 3, 5, 2, 4, 6}); + auto bias = NDArrayFactory::create('c', {2}); + + input.linspace(1); + bias.linspace(1); + + sd::ops::deconv2d op; + + auto result = + op.evaluate({&input, &weights, &bias}, {1, 1, 1, 1, 0, 0, 1, 1, 1, 0}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto output = result.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests1, deconv2d_test8) { - - int bS=1, iH=7,iW=7, iC=3,oC=2, kH=1,kW=1, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oH=7,oW=7; - int paddingMode = 0; // 1-SAME, 0-VALID; - int dataFormat = 0; // 1-NHWC, 0-NCHW - - NDArray input('c', {bS, iC, iH, iW}, {0.679350, 0.355087, 0.842789, 0.200313, 0.701499, 0.310693, 0.447940, 0.938010, 0.326674, 0.151873, 0.383318, 0.782123, 0.198807, - 0.798564, 0.163263, 0.146968, 0.260897, 0.135058, 0.756209, 0.275454, 0.369088, 0.092826, 0.836492, 0.268413, 0.095062, 0.312795, 0.135918, 0.517544, 0.328703, - 0.061736, 0.396431, 0.248016, 0.548959, 0.115046, 0.814362, 0.721564, 0.404494, 0.299089, 0.403884, 0.988311, 0.022296, 0.927782, 0.318416, 0.068546, 0.284533, - 0.232720, 0.352142, 0.058909, 0.711221, 0.674457, 0.196946, 0.699497, 0.074322, 0.420425, 0.584263, 0.149574, 0.446406, 0.723072, 0.064481, 0.483078, 0.875996, - 0.569819, 0.445863, 0.527755, 0.016646, 0.753678, 0.140636, 0.754129, 0.161932, 0.775037, 0.332645, 0.117394, 0.017711, 0.608476, 0.525152, 0.917194, 0.849891, - 0.589423, 0.852278, 0.390636, 0.889683, 0.669445, 0.698873, 0.961480, 0.157401, 0.157364, 0.493520, 0.569937, 0.126832, 0.115728, 0.786368, 0.737939, 0.490079, - 0.608414, 0.956500, 0.390098, 0.147305, 0.850645, 0.497650, 0.071866, 0.082150, 0.035314, 0.732041, 0.369934, 0.840666, 0.273894, 0.431796, 0.133231, 0.192975, - 0.246897, 0.386418, 0.511541, 0.199036, 0.141631, 0.697699, 0.253631, 0.782218, 0.930099, 0.335512, 0.558808, 0.664358, 0.018851, 0.637559, 0.290430, 0.434902, - 0.842513, 0.466098, 0.381395, 0.523185, 0.990183, 0.925768, 0.643459, 0.016828, 0.918756, 0.228979, 0.006314, 0.665975, 0.190361, 0.595521, 0.698881, 0.221469, - 0.912434, 0.870822, 0.727369, 0.523972, 0.662884, 0.218841}); - - NDArray weights('c', {kH, kW, oC, iC}, {0.4195024073123932, 0.22738978266716003, 0.10093523561954498, 0.25008103251457214, 0.3183899223804474, 0.5976081490516663}); - NDArray bias('c', {1, oC}, {0.3596062958240509, 0.6866418123245239}); - - NDArray exp('c', {bS, oC, oH, oW}, {0.848190, 0.560603, 0.880509, 0.464103, 0.823376, 0.660138, 0.666382, 0.882257, 0.704650, 0.451427, 0.649734, 0.911822, 0.611581, - 0.847623, 0.568191, 0.439341, 0.710854, 0.473843, 0.927273, 0.605861, 0.724540, 0.530591, 0.804268, 0.478136, 0.602198, 0.639553, 0.669082, 0.855013, 0.678572, - 0.617800, 0.667545, 0.765899, 0.835564, 0.631733, 0.921562, 0.790830, 0.588187, 0.597934, 0.725855, 0.822259, 0.455384, 0.998167, 0.683336, 0.591897, 0.705213, - 0.748148, 0.648922, 0.484723, 0.873482, 1.368675, 0.881096, 1.169214, 0.781504, 1.433406, 1.171439, 1.348675, 1.227033, 1.256600, 0.824772, 1.051633, 1.308692, - 1.148711, 1.334007, 1.014448, 0.813336, 1.408801, 0.916766, 1.583323, 1.362920, 1.226212, 1.149715, 1.330235, 0.770671, 1.285158, 1.105632, 1.272558, 1.590159, - 1.235054, 1.201363, 1.222816, 1.623673, 1.590317, 1.322463, 1.206481, 1.466262, 0.974741, 0.922343, 1.367100, 1.087943, 1.084952, 1.586691, 1.133576, 1.405098, - 1.471922, 1.484062, 1.212039, 1.144419, 1.266123}); - - sd::ops::deconv2d op; - auto results = op.evaluate({&input, &weights, &bias}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); - - ASSERT_EQ(Status::OK(), results.status()); - - auto output = results.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - + int bS = 1, iH = 7, iW = 7, iC = 3, oC = 2, kH = 1, kW = 1, sH = 1, sW = 1, + pH = 0, pW = 0, dH = 1, dW = 1; + int oH = 7, oW = 7; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NHWC, 0-NCHW + + NDArray input( + 'c', {bS, iC, iH, iW}, + {0.679350, 0.355087, 0.842789, 0.200313, 0.701499, 0.310693, 0.447940, + 0.938010, 0.326674, 0.151873, 0.383318, 0.782123, 0.198807, 0.798564, + 0.163263, 0.146968, 0.260897, 0.135058, 0.756209, 0.275454, 0.369088, + 0.092826, 0.836492, 0.268413, 0.095062, 0.312795, 0.135918, 0.517544, + 0.328703, 0.061736, 0.396431, 0.248016, 0.548959, 0.115046, 0.814362, + 0.721564, 0.404494, 0.299089, 0.403884, 0.988311, 0.022296, 0.927782, + 0.318416, 0.068546, 0.284533, 0.232720, 0.352142, 0.058909, 0.711221, + 0.674457, 0.196946, 0.699497, 0.074322, 0.420425, 0.584263, 0.149574, + 0.446406, 0.723072, 0.064481, 0.483078, 0.875996, 0.569819, 0.445863, + 0.527755, 0.016646, 0.753678, 0.140636, 0.754129, 0.161932, 0.775037, + 0.332645, 0.117394, 0.017711, 0.608476, 0.525152, 0.917194, 0.849891, + 0.589423, 0.852278, 0.390636, 0.889683, 0.669445, 0.698873, 0.961480, + 0.157401, 0.157364, 0.493520, 0.569937, 0.126832, 0.115728, 0.786368, + 0.737939, 0.490079, 0.608414, 0.956500, 0.390098, 0.147305, 0.850645, + 0.497650, 0.071866, 0.082150, 0.035314, 0.732041, 0.369934, 0.840666, + 0.273894, 0.431796, 0.133231, 0.192975, 0.246897, 0.386418, 0.511541, + 0.199036, 0.141631, 0.697699, 0.253631, 0.782218, 0.930099, 0.335512, + 0.558808, 0.664358, 0.018851, 0.637559, 0.290430, 0.434902, 0.842513, + 0.466098, 0.381395, 0.523185, 0.990183, 0.925768, 0.643459, 0.016828, + 0.918756, 0.228979, 0.006314, 0.665975, 0.190361, 0.595521, 0.698881, + 0.221469, 0.912434, 0.870822, 0.727369, 0.523972, 0.662884, 0.218841}); + + NDArray weights( + 'c', {kH, kW, oC, iC}, + {0.4195024073123932, 0.22738978266716003, 0.10093523561954498, + 0.25008103251457214, 0.3183899223804474, 0.5976081490516663}); + NDArray bias('c', {1, oC}, {0.3596062958240509, 0.6866418123245239}); + + NDArray exp( + 'c', {bS, oC, oH, oW}, + {0.848190, 0.560603, 0.880509, 0.464103, 0.823376, 0.660138, 0.666382, + 0.882257, 0.704650, 0.451427, 0.649734, 0.911822, 0.611581, 0.847623, + 0.568191, 0.439341, 0.710854, 0.473843, 0.927273, 0.605861, 0.724540, + 0.530591, 0.804268, 0.478136, 0.602198, 0.639553, 0.669082, 0.855013, + 0.678572, 0.617800, 0.667545, 0.765899, 0.835564, 0.631733, 0.921562, + 0.790830, 0.588187, 0.597934, 0.725855, 0.822259, 0.455384, 0.998167, + 0.683336, 0.591897, 0.705213, 0.748148, 0.648922, 0.484723, 0.873482, + 1.368675, 0.881096, 1.169214, 0.781504, 1.433406, 1.171439, 1.348675, + 1.227033, 1.256600, 0.824772, 1.051633, 1.308692, 1.148711, 1.334007, + 1.014448, 0.813336, 1.408801, 0.916766, 1.583323, 1.362920, 1.226212, + 1.149715, 1.330235, 0.770671, 1.285158, 1.105632, 1.272558, 1.590159, + 1.235054, 1.201363, 1.222816, 1.623673, 1.590317, 1.322463, 1.206481, + 1.466262, 0.974741, 0.922343, 1.367100, 1.087943, 1.084952, 1.586691, + 1.133576, 1.405098, 1.471922, 1.484062, 1.212039, 1.144419, 1.266123}); + + sd::ops::deconv2d op; + auto results = + op.evaluate({&input, &weights, &bias}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, dataFormat}); + + ASSERT_EQ(Status::OK(), results.status()); + + auto output = results.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests1, deconv2d_test9) { - - int bS=2, oH=4,oW=4, oC=5,iC=10, kH=2,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int iH=3,iW=3; - int paddingMode = 0; // 1-SAME, 0-VALID; - int dataFormat = 1; // 1-NHWC, 0-NCHW - int wFormat = 1; // 0-[kH, kW, oC, iC], [iC, oC, kH, kW], [iC, kH, kW, oC] - - NDArray input('c', {bS, iH, iW, iC}, sd::DataType::FLOAT32); - NDArray weights('c', {iC, oC, kH, kW}, {100.000000, 75.000000, 50.000000, 25.000000, 95.000000, 70.000000, 45.000000, 20.000000, 90.000000, 65.000000, 40.000000, - 15.000000, 85.000000, 60.000000, 35.000000, 10.000000, 80.000000, 55.000000, 30.000000, 5.000000, 99.500000, 74.500000, 49.500000, 24.500000, 94.500000, 69.500000, - 44.500000, 19.500000, 89.500000, 64.500000, 39.500000, 14.500000, 84.500000, 59.500000, 34.500000, 9.500000, 79.500000, 54.500000, 29.500000, 4.500000, 99.000000, - 74.000000, 49.000000, 24.000000, 94.000000, 69.000000, 44.000000, 19.000000, 89.000000, 64.000000, 39.000000, 14.000000, 84.000000, 59.000000, 34.000000, 9.000000, - 79.000000, 54.000000, 29.000000, 4.000000, 98.500000, 73.500000, 48.500000, 23.500000, 93.500000, 68.500000, 43.500000, 18.500000, 88.500000, 63.500000, 38.500000, - 13.500000, 83.500000, 58.500000, 33.500000, 8.500000, 78.500000, 53.500000, 28.500000, 3.500000, 98.000000, 73.000000, 48.000000, 23.000000, 93.000000, 68.000000, - 43.000000, 18.000000, 88.000000, 63.000000, 38.000000, 13.000000, 83.000000, 58.000000, 33.000000, 8.000000, 78.000000, 53.000000, 28.000000, 3.000000, 97.500000, 72.500000, 47.500000, 22.500000, 92.500000, 67.500000, 42.500000, 17.500000, 87.500000, 62.500000, 37.500000, 12.500000, 82.500000, 57.500000, 32.500000, 7.500000, 77.500000, 52.500000, 27.500000, 2.500000, 97.000000, 72.000000, 47.000000, 22.000000, 92.000000, 67.000000, 42.000000, 17.000000, 87.000000, 62.000000, 37.000000, 12.000000, 82.000000, 57.000000, 32.000000, 7.000000, 77.000000, 52.000000, 27.000000, 2.000000, 96.500000, 71.500000, 46.500000, 21.500000, 91.500000, 66.500000, 41.500000, 16.500000, 86.500000, 61.500000, 36.500000, 11.500000, 81.500000, 56.500000, 31.500000, 6.500000, 76.500000, 51.500000, 26.500000, 1.500000, 96.000000, 71.000000, 46.000000, 21.000000, 91.000000, 66.000000, 41.000000, 16.000000, 86.000000, 61.000000, 36.000000, 11.000000, 81.000000, 56.000000, 31.000000, 6.000000, 76.000000, 51.000000, 26.000000, 1.000000, 95.500000, 70.500000, 45.500000, 20.500000, 90.500000, 65.500000, 40.500000, 15.500000, 85.500000, 60.500000, 35.500000, 10.500000, 80.500000, 55.500000, 30.500000, 5.500000, 75.500000, 50.500000, 25.500000, 0.500000}, sd::DataType::FLOAT32); - NDArray expOutput('c', {bS, oH, oW, oC}, {-30844.250000, -29266.750000, -27689.250000, -26111.750000, -24534.250000, -52823.500000, -49718.500000, -46613.500000, -43508.500000, -40403.500000, -51118.500000, - -48113.500000, -45108.500000, -42103.500000, -39098.500000, -21501.750000, -20024.250000, -18546.750000, -17069.250000, -15591.750000, -42981.000000, -39976.000000, -36971.000000, -33966.000000, -30961.000000, - -69482.000000, -63572.000000, -57662.000000, -51752.000000, -45842.000000, -67072.000000, -61362.000000, -55652.000000, -49942.000000, -44232.000000, -26046.000000, -23241.000000, -20436.000000, -17631.000000, - -14826.000000, -38616.000000, -35911.000000, -33206.000000, -30501.000000, -27796.000000, -62252.000000, -56942.000000, -51632.000000, -46322.000000, -41012.000000, -59842.000000, -54732.000000, -49622.000000, - -44512.000000, -39402.000000, -23181.000000, -20676.000000, -18171.000000, -15666.000000, -13161.000000, -12204.250000, -10926.750000, -9649.250000, -8371.750000, -7094.250000, -17543.500000, -15038.500000, - -12533.500000, -10028.500000, -7523.500000, -16838.500000, -14433.499023, -12028.500000, -9623.500000, -7218.500000, -5361.750000, -4184.250000, -3006.750000, -1829.250000, -651.750000, -22046.750000, -20919.250000, - -19791.750000, -18664.250000, -17536.750000, -37478.500000, -35273.500000, -33068.500000, -30863.500000, -28658.500000, -35773.500000, -33668.500000, -31563.500000, -29458.500000, -27353.500000, -14954.250000, - -13926.750000, -12899.250000, -11871.750000, -10844.250000, -29886.000000, -27781.000000, -25676.000000, -23571.000000, -21466.000000, -47792.000000, -43682.000000, -39572.000000, -35462.000000, -31352.000000, - -45382.000000, -41472.000000, -37562.000000, -33652.000000, -29742.000000, -17451.000000, -15546.000000, -13641.000000, -11736.000000, -9831.000000, -25521.000000, -23716.000000, -21911.000000, -20106.000000, -18301.000000, -40562.000000, -37052.000000, -33542.000000, -30032.000000, -26522.000000, -38152.000000, -34842.000000, -31532.000000, -28222.000000, -24912.000000, -14586.000000, -12981.000000, -11376.000000, -9771.000000, -8166.000000, -7906.750000, -7079.250000, -6251.750000, -5424.250000, -4596.750000, -11198.500000, -9593.500000, -7988.500000, -6383.500000, -4778.500000, -10493.500000, -8988.500000, -7483.500000, -5978.500000, -4473.500000, -3314.250000, -2586.750000, -1859.250000, -1131.750000, -404.250000}, sd::DataType::FLOAT32); - - input.linspace(-32, 0.1); - - sd::ops::deconv2d op; - auto results = op.evaluate({&input, &weights}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat, wFormat}); - ASSERT_EQ(Status::OK(), results.status()); - - auto output = results.at(0); - - ASSERT_TRUE(expOutput.isSameShape(output)); - ASSERT_TRUE(expOutput.equalsTo(output)); + int bS = 2, oH = 4, oW = 4, oC = 5, iC = 10, kH = 2, kW = 2, sH = 1, sW = 1, + pH = 0, pW = 0, dH = 1, dW = 1; + int iH = 3, iW = 3; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NHWC, 0-NCHW + int wFormat = 1; // 0-[kH, kW, oC, iC], [iC, oC, kH, kW], [iC, kH, kW, oC] + + NDArray input('c', {bS, iH, iW, iC}, sd::DataType::FLOAT32); + NDArray weights( + 'c', {iC, oC, kH, kW}, + {100.000000, 75.000000, 50.000000, 25.000000, 95.000000, 70.000000, + 45.000000, 20.000000, 90.000000, 65.000000, 40.000000, 15.000000, + 85.000000, 60.000000, 35.000000, 10.000000, 80.000000, 55.000000, + 30.000000, 5.000000, 99.500000, 74.500000, 49.500000, 24.500000, + 94.500000, 69.500000, 44.500000, 19.500000, 89.500000, 64.500000, + 39.500000, 14.500000, 84.500000, 59.500000, 34.500000, 9.500000, + 79.500000, 54.500000, 29.500000, 4.500000, 99.000000, 74.000000, + 49.000000, 24.000000, 94.000000, 69.000000, 44.000000, 19.000000, + 89.000000, 64.000000, 39.000000, 14.000000, 84.000000, 59.000000, + 34.000000, 9.000000, 79.000000, 54.000000, 29.000000, 4.000000, + 98.500000, 73.500000, 48.500000, 23.500000, 93.500000, 68.500000, + 43.500000, 18.500000, 88.500000, 63.500000, 38.500000, 13.500000, + 83.500000, 58.500000, 33.500000, 8.500000, 78.500000, 53.500000, + 28.500000, 3.500000, 98.000000, 73.000000, 48.000000, 23.000000, + 93.000000, 68.000000, 43.000000, 18.000000, 88.000000, 63.000000, + 38.000000, 13.000000, 83.000000, 58.000000, 33.000000, 8.000000, + 78.000000, 53.000000, 28.000000, 3.000000, 97.500000, 72.500000, + 47.500000, 22.500000, 92.500000, 67.500000, 42.500000, 17.500000, + 87.500000, 62.500000, 37.500000, 12.500000, 82.500000, 57.500000, + 32.500000, 7.500000, 77.500000, 52.500000, 27.500000, 2.500000, + 97.000000, 72.000000, 47.000000, 22.000000, 92.000000, 67.000000, + 42.000000, 17.000000, 87.000000, 62.000000, 37.000000, 12.000000, + 82.000000, 57.000000, 32.000000, 7.000000, 77.000000, 52.000000, + 27.000000, 2.000000, 96.500000, 71.500000, 46.500000, 21.500000, + 91.500000, 66.500000, 41.500000, 16.500000, 86.500000, 61.500000, + 36.500000, 11.500000, 81.500000, 56.500000, 31.500000, 6.500000, + 76.500000, 51.500000, 26.500000, 1.500000, 96.000000, 71.000000, + 46.000000, 21.000000, 91.000000, 66.000000, 41.000000, 16.000000, + 86.000000, 61.000000, 36.000000, 11.000000, 81.000000, 56.000000, + 31.000000, 6.000000, 76.000000, 51.000000, 26.000000, 1.000000, + 95.500000, 70.500000, 45.500000, 20.500000, 90.500000, 65.500000, + 40.500000, 15.500000, 85.500000, 60.500000, 35.500000, 10.500000, + 80.500000, 55.500000, 30.500000, 5.500000, 75.500000, 50.500000, + 25.500000, 0.500000}, + sd::DataType::FLOAT32); + NDArray expOutput('c', {bS, oH, oW, oC}, + {-30844.250000, -29266.750000, -27689.250000, -26111.750000, + -24534.250000, -52823.500000, -49718.500000, -46613.500000, + -43508.500000, -40403.500000, -51118.500000, -48113.500000, + -45108.500000, -42103.500000, -39098.500000, -21501.750000, + -20024.250000, -18546.750000, -17069.250000, -15591.750000, + -42981.000000, -39976.000000, -36971.000000, -33966.000000, + -30961.000000, -69482.000000, -63572.000000, -57662.000000, + -51752.000000, -45842.000000, -67072.000000, -61362.000000, + -55652.000000, -49942.000000, -44232.000000, -26046.000000, + -23241.000000, -20436.000000, -17631.000000, -14826.000000, + -38616.000000, -35911.000000, -33206.000000, -30501.000000, + -27796.000000, -62252.000000, -56942.000000, -51632.000000, + -46322.000000, -41012.000000, -59842.000000, -54732.000000, + -49622.000000, -44512.000000, -39402.000000, -23181.000000, + -20676.000000, -18171.000000, -15666.000000, -13161.000000, + -12204.250000, -10926.750000, -9649.250000, -8371.750000, + -7094.250000, -17543.500000, -15038.500000, -12533.500000, + -10028.500000, -7523.500000, -16838.500000, -14433.499023, + -12028.500000, -9623.500000, -7218.500000, -5361.750000, + -4184.250000, -3006.750000, -1829.250000, -651.750000, + -22046.750000, -20919.250000, -19791.750000, -18664.250000, + -17536.750000, -37478.500000, -35273.500000, -33068.500000, + -30863.500000, -28658.500000, -35773.500000, -33668.500000, + -31563.500000, -29458.500000, -27353.500000, -14954.250000, + -13926.750000, -12899.250000, -11871.750000, -10844.250000, + -29886.000000, -27781.000000, -25676.000000, -23571.000000, + -21466.000000, -47792.000000, -43682.000000, -39572.000000, + -35462.000000, -31352.000000, -45382.000000, -41472.000000, + -37562.000000, -33652.000000, -29742.000000, -17451.000000, + -15546.000000, -13641.000000, -11736.000000, -9831.000000, + -25521.000000, -23716.000000, -21911.000000, -20106.000000, + -18301.000000, -40562.000000, -37052.000000, -33542.000000, + -30032.000000, -26522.000000, -38152.000000, -34842.000000, + -31532.000000, -28222.000000, -24912.000000, -14586.000000, + -12981.000000, -11376.000000, -9771.000000, -8166.000000, + -7906.750000, -7079.250000, -6251.750000, -5424.250000, + -4596.750000, -11198.500000, -9593.500000, -7988.500000, + -6383.500000, -4778.500000, -10493.500000, -8988.500000, + -7483.500000, -5978.500000, -4473.500000, -3314.250000, + -2586.750000, -1859.250000, -1131.750000, -404.250000}, + sd::DataType::FLOAT32); + + input.linspace(-32, 0.1); + + sd::ops::deconv2d op; + auto results = op.evaluate( + {&input, &weights}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, dataFormat, wFormat}); + ASSERT_EQ(Status::OK(), results.status()); + + auto output = results.at(0); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests1, deconv2d_test10) { - - int bS=2, oH=4,oW=4, iC=5,oC=10, kH=2,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int iH=4,iW=4; - int paddingMode = 1; // 1-SAME, 0-VALID; - int dataFormat = 0; // 1-NHWC, 0-NCHW - int wFormat = 2; // 0-[kH, kW, oC, iC], [iC, oC, kH, kW], [iC, kH, kW, oC] - - NDArray input('c', {bS, iC, iH, iW}, sd::DataType::FLOAT32); - NDArray weights('c', {iC, kH, kW, oC}, {100., 95., 90., 85., 80., 75., 70., 65., 60., 55., 50., 45., 40., 35., 30., 25., 20., 15., 10., 5., 0., -5., -10., -15., - -20., -25., -30., -35., -40., -45., -50., -55., -60., -65., -70., -75., -80., -85., -90., -95., 99., 94., 89., 84., 79., 74., 69., 64., 59., 54., 49., 44., - 39., 34., 29., 24., 19., 14., 9., 4., -1., -6., -11., -16., -21., -26., -31., -36., -41., -46., -51., -56., -61., -66., -71., -76., -81., -86., -91., -96., - 98., 93., 88., 83., 78., 73., 68., 63., 58., 53., 48., 43., 38., 33., 28., 23., 18., 13., 8., 3., -2., -7., -12., -17., -22., -27., -32., -37., -42., -47., - -52., -57., -62., -67., -72., -77., -82., -87., -92., -97., 97., 92., 87., 82., 77., 72., 67., 62., 57., 52., 47., 42., 37., 32., 27., 22., 17., 12., 7., 2., - -3., -8., -13., -18., -23., -28., -33., -38., -43., -48., -53., -58., -63., -68., -73., -78., -83., -88., -93., -98., 96., 91., 86., 81., 76., 71., 66., 61., - 56., 51., 46., 41., 36., 31., 26., 21., 16., 11., 6., 1., -4., -9., -14., -19., -24., -29., -34., -39., -44., -49., -54., -59., -64., -69., -74., -79., -84., -89., -94., -99.}, sd::DataType::FLOAT32); - NDArray expOutput('c', {bS, oC, oH, oW}, {-14128., -21007., -20934., -20861., -13660., -12972., -12926.000977, -12880., -13468., -12788., -12742., -12696.000977, - -13276., -12604., -12558., -12512., -13408., -19569.5, -19501.5, -19433.5, -12230., -10117., -10081.000977, -10045., -12058., -9973., -9937., -9901.000977, - -11886., -9829., -9793., -9757., -12688., -18132., -18069., -18006., -10800., -7262., -7236., -7210., -10648., -7157.999512, -7132., -7106., -10496., -7054., - -7027.999512, -7002., -11968., -16694.5, -16636.5, -16578.5, -9370., -4406.999023, -4391., -4375., -9238., -4343., -4326.999023, -4311., -9106., -4279., -4263., - -4246.999023, -11247.999023, -15257., -15204., -15151., -7940., -1551.999023, -1546., -1540., -7828., -1528.000977, -1521.999023, -1516., -7716., -1504., - -1498.000977, -1491.999023, -10527.999023, -13819.5, -13771.5, -13723.5, -6510., 1303.000977, 1299., 1295., -6418., 1286.999023, 1283.000977, 1279., -6326., - 1271., 1266.999023, 1263.000977, -9807.999023, -12382., -12339., -12296., -5080., 4158.000977, 4144., 4130., -5008., 4101.999023, 4088., 4074., -4936., 4046., 4031.999023, 4018., -9088., -10944.5, -10906.5, -10868.5, -3650., 7013., 6989., 6965., -3598., 6917., 6893., 6869., -3546., 6821., 6797., 6773., -8368., -9507., -9474., -9441., -2220., 9868., 9834., 9800., -2187.999512, 9732., 9698., 9664., -2156., 9596., 9562., 9528., -7648., -8069.5, -8041.5, -8013.499512, -790.000488, 12723., 12679., 12635., -777.999512, 12547., 12503., 12459., -766., 12371., 12327., 12283., -10208., -15167., -15094., -15021., -9820., -9292., -9246., -9200., -9628., -9108., -9062., -9016., -9436., -8924., -8878., -8832., -9687.999023, -14129.5, -14061.5, -13993.5, -8790., -7236.999023, -7201., -7164.999512, -8618., -7093., -7057., -7021., -8446., -6949., -6913., -6877., -9168., -13092., -13029., -12966., -7760., -5182., -5156., -5129.999512, -7608., -5078., -5052., -5026., -7456., -4974., -4948., -4922., -8648., -12054.5, -11996.5, -11938.5, -6730., -3127., -3111., -3095., -6598., -3063., -3047., -3031., -6465.999512, -2999., -2983.000488, -2967., -8128., -11017., -10964., -10911., -5700.000488, -1072., -1066., -1060., -5587.999512, -1048.000488, -1042., -1036., -5476., -1023.999512, -1018.000488, -1012., -7608., -9979.5, -9931.5, -9883.5, -4670.000488, 983., 979., 975., -4577.999512, 966.999512, 963., 959., -4486., 951.000488, 946.999512, 943., -7088., -8942., -8899., -8856., -3640.000488, 3038., 3024., 3010., -3567.999512, 2981.999512, 2968., 2954., -3496., 2926.000488, 2911.999512, 2898., -6568., -7904.5, -7866.5, -7828.499512, -2610.000488, 5093., 5069., 5045., -2557.999512, 4996.999512, 4973., 4949., -2506., 4901.000488, 4877., 4853., -6048., -6867., -6834., -6800.999512, -1580., 7148., 7114., 7080., -1547.999512, 7012., 6978., 6944., -1516., 6876.000488, 6842., 6808., -5528., -5829.5, -5801.5, -5773.499512, -550., 9203., 9159., 9115., -537.999512, 9027., 8983., 8939., -526., 8851., 8807., 8763.}, sd::DataType::FLOAT32); - - input.linspace(-32, 0.1); - - sd::ops::deconv2d op; - auto results = op.evaluate({&input, &weights}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat, wFormat}); - ASSERT_EQ(Status::OK(), results.status()); - - auto output = results.at(0); - - ASSERT_TRUE(expOutput.isSameShape(output)); - ASSERT_TRUE(expOutput.equalsTo(output)); + int bS = 2, oH = 4, oW = 4, iC = 5, oC = 10, kH = 2, kW = 2, sH = 1, sW = 1, + pH = 0, pW = 0, dH = 1, dW = 1; + int iH = 4, iW = 4; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NHWC, 0-NCHW + int wFormat = 2; // 0-[kH, kW, oC, iC], [iC, oC, kH, kW], [iC, kH, kW, oC] + + NDArray input('c', {bS, iC, iH, iW}, sd::DataType::FLOAT32); + NDArray weights( + 'c', {iC, kH, kW, oC}, + {100., 95., 90., 85., 80., 75., 70., 65., 60., 55., 50., 45., + 40., 35., 30., 25., 20., 15., 10., 5., 0., -5., -10., -15., + -20., -25., -30., -35., -40., -45., -50., -55., -60., -65., -70., -75., + -80., -85., -90., -95., 99., 94., 89., 84., 79., 74., 69., 64., + 59., 54., 49., 44., 39., 34., 29., 24., 19., 14., 9., 4., + -1., -6., -11., -16., -21., -26., -31., -36., -41., -46., -51., -56., + -61., -66., -71., -76., -81., -86., -91., -96., 98., 93., 88., 83., + 78., 73., 68., 63., 58., 53., 48., 43., 38., 33., 28., 23., + 18., 13., 8., 3., -2., -7., -12., -17., -22., -27., -32., -37., + -42., -47., -52., -57., -62., -67., -72., -77., -82., -87., -92., -97., + 97., 92., 87., 82., 77., 72., 67., 62., 57., 52., 47., 42., + 37., 32., 27., 22., 17., 12., 7., 2., -3., -8., -13., -18., + -23., -28., -33., -38., -43., -48., -53., -58., -63., -68., -73., -78., + -83., -88., -93., -98., 96., 91., 86., 81., 76., 71., 66., 61., + 56., 51., 46., 41., 36., 31., 26., 21., 16., 11., 6., 1., + -4., -9., -14., -19., -24., -29., -34., -39., -44., -49., -54., -59., + -64., -69., -74., -79., -84., -89., -94., -99.}, + sd::DataType::FLOAT32); + NDArray expOutput( + 'c', {bS, oC, oH, oW}, + {-14128., -21007., -20934., -20861., -13660., + -12972., -12926.000977, -12880., -13468., -12788., + -12742., -12696.000977, -13276., -12604., -12558., + -12512., -13408., -19569.5, -19501.5, -19433.5, + -12230., -10117., -10081.000977, -10045., -12058., + -9973., -9937., -9901.000977, -11886., -9829., + -9793., -9757., -12688., -18132., -18069., + -18006., -10800., -7262., -7236., -7210., + -10648., -7157.999512, -7132., -7106., -10496., + -7054., -7027.999512, -7002., -11968., -16694.5, + -16636.5, -16578.5, -9370., -4406.999023, -4391., + -4375., -9238., -4343., -4326.999023, -4311., + -9106., -4279., -4263., -4246.999023, -11247.999023, + -15257., -15204., -15151., -7940., -1551.999023, + -1546., -1540., -7828., -1528.000977, -1521.999023, + -1516., -7716., -1504., -1498.000977, -1491.999023, + -10527.999023, -13819.5, -13771.5, -13723.5, -6510., + 1303.000977, 1299., 1295., -6418., 1286.999023, + 1283.000977, 1279., -6326., 1271., 1266.999023, + 1263.000977, -9807.999023, -12382., -12339., -12296., + -5080., 4158.000977, 4144., 4130., -5008., + 4101.999023, 4088., 4074., -4936., 4046., + 4031.999023, 4018., -9088., -10944.5, -10906.5, + -10868.5, -3650., 7013., 6989., 6965., + -3598., 6917., 6893., 6869., -3546., + 6821., 6797., 6773., -8368., -9507., + -9474., -9441., -2220., 9868., 9834., + 9800., -2187.999512, 9732., 9698., 9664., + -2156., 9596., 9562., 9528., -7648., + -8069.5, -8041.5, -8013.499512, -790.000488, 12723., + 12679., 12635., -777.999512, 12547., 12503., + 12459., -766., 12371., 12327., 12283., + -10208., -15167., -15094., -15021., -9820., + -9292., -9246., -9200., -9628., -9108., + -9062., -9016., -9436., -8924., -8878., + -8832., -9687.999023, -14129.5, -14061.5, -13993.5, + -8790., -7236.999023, -7201., -7164.999512, -8618., + -7093., -7057., -7021., -8446., -6949., + -6913., -6877., -9168., -13092., -13029., + -12966., -7760., -5182., -5156., -5129.999512, + -7608., -5078., -5052., -5026., -7456., + -4974., -4948., -4922., -8648., -12054.5, + -11996.5, -11938.5, -6730., -3127., -3111., + -3095., -6598., -3063., -3047., -3031., + -6465.999512, -2999., -2983.000488, -2967., -8128., + -11017., -10964., -10911., -5700.000488, -1072., + -1066., -1060., -5587.999512, -1048.000488, -1042., + -1036., -5476., -1023.999512, -1018.000488, -1012., + -7608., -9979.5, -9931.5, -9883.5, -4670.000488, + 983., 979., 975., -4577.999512, 966.999512, + 963., 959., -4486., 951.000488, 946.999512, + 943., -7088., -8942., -8899., -8856., + -3640.000488, 3038., 3024., 3010., -3567.999512, + 2981.999512, 2968., 2954., -3496., 2926.000488, + 2911.999512, 2898., -6568., -7904.5, -7866.5, + -7828.499512, -2610.000488, 5093., 5069., 5045., + -2557.999512, 4996.999512, 4973., 4949., -2506., + 4901.000488, 4877., 4853., -6048., -6867., + -6834., -6800.999512, -1580., 7148., 7114., + 7080., -1547.999512, 7012., 6978., 6944., + -1516., 6876.000488, 6842., 6808., -5528., + -5829.5, -5801.5, -5773.499512, -550., 9203., + 9159., 9115., -537.999512, 9027., 8983., + 8939., -526., 8851., 8807., 8763.}, + sd::DataType::FLOAT32); + + input.linspace(-32, 0.1); + + sd::ops::deconv2d op; + auto results = op.evaluate( + {&input, &weights}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, dataFormat, wFormat}); + ASSERT_EQ(Status::OK(), results.status()); + + auto output = results.at(0); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests1, deconv2d_tf_test1) { - - int bS=2, iH=4,iW=4, iC=5,oC=10, kH=2,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oH=3,oW=3; - int paddingMode = 0; // 1-SAME, 0-VALID; - int dataFormat = 1; // 1-NHWC, 0-NCHW - - auto input = NDArrayFactory::create('c', {bS, oH, oW, oC}); - auto weights = NDArrayFactory::create('c', {kH, kW, iC, oC}); - auto outShape = NDArrayFactory::create('c', {4}, {static_cast(bS), static_cast(iH), static_cast(iW), static_cast(iC)}); - auto exp = NDArrayFactory::create('c', {bS, iH, iW, iC}, { 2.75f, 7.75f, 12.75f, 17.75f, 22.75f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 27.75f, 32.75f, 37.75f, 42.75f, 47.75f, - 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 105.5f, 115.5f, 125.5f, 135.5f, 145.5f, - 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 105.5f, 115.5f, 125.5f, 135.5f, 145.5f, - 52.75f, 57.75f, 62.75f, 67.75f, 72.75f, 130.5f, 140.5f, 150.5f, 160.5f, 170.5f, 130.5f, 140.5f, 150.5f, 160.5f, 170.5f, 77.75f, 82.75f, 87.75f, 92.75f, 97.75f, - 2.75f, 7.75f, 12.75f, 17.75f, 22.75f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 27.75f, 32.75f, 37.75f, 42.75f, 47.75f, - 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 105.5f, 115.5f, 125.5f, 135.5f, 145.5f, - 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 105.5f, 115.5f, 125.5f, 135.5f, 145.5f, - 52.75f, 57.75f, 62.75f, 67.75f, 72.75f, 130.5f, 140.5f, 150.5f, 160.5f, 170.5f, 130.5f, 140.5f, 150.5f, 160.5f, 170.5f, 77.75f, 82.75f, 87.75f, 92.75f, 97.75f}); - input = 0.5; - weights.linspace(0.1, 0.1); - - sd::ops::deconv2d_tf op; - auto results = op.evaluate({&outShape, &weights, &input}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - + int bS = 2, iH = 4, iW = 4, iC = 5, oC = 10, kH = 2, kW = 2, sH = 1, sW = 1, + pH = 0, pW = 0, dH = 1, dW = 1; + int oH = 3, oW = 3; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NHWC, 0-NCHW + + auto input = NDArrayFactory::create('c', {bS, oH, oW, oC}); + auto weights = NDArrayFactory::create('c', {kH, kW, iC, oC}); + auto outShape = NDArrayFactory::create( + 'c', {4}, + {static_cast(bS), static_cast(iH), + static_cast(iW), static_cast(iC)}); + auto exp = NDArrayFactory::create( + 'c', {bS, iH, iW, iC}, + {2.75f, 7.75f, 12.75f, 17.75f, 22.75f, 30.5f, 40.5f, 50.5f, 60.5f, + 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 27.75f, 32.75f, 37.75f, + 42.75f, 47.75f, 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, + 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 105.5f, + 115.5f, 125.5f, 135.5f, 145.5f, 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, + 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, + 241.f, 105.5f, 115.5f, 125.5f, 135.5f, 145.5f, 52.75f, 57.75f, 62.75f, + 67.75f, 72.75f, 130.5f, 140.5f, 150.5f, 160.5f, 170.5f, 130.5f, 140.5f, + 150.5f, 160.5f, 170.5f, 77.75f, 82.75f, 87.75f, 92.75f, 97.75f, 2.75f, + 7.75f, 12.75f, 17.75f, 22.75f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, + 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 27.75f, 32.75f, 37.75f, 42.75f, + 47.75f, 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, + 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 105.5f, 115.5f, + 125.5f, 135.5f, 145.5f, 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, + 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, + 105.5f, 115.5f, 125.5f, 135.5f, 145.5f, 52.75f, 57.75f, 62.75f, 67.75f, + 72.75f, 130.5f, 140.5f, 150.5f, 160.5f, 170.5f, 130.5f, 140.5f, 150.5f, + 160.5f, 170.5f, 77.75f, 82.75f, 87.75f, 92.75f, 97.75f}); + input = 0.5; + weights.linspace(0.1, 0.1); + + sd::ops::deconv2d_tf op; + auto results = + op.evaluate({&outShape, &weights, &input}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } - -#endif //LIBND4J_CONVOLUTIONTESTS1_H - +#endif // LIBND4J_CONVOLUTIONTESTS1_H diff --git a/libnd4j/tests_cpu/layers_tests/ConvolutionTests2.cpp b/libnd4j/tests_cpu/layers_tests/ConvolutionTests2.cpp index 169c51124ff2..f9f39eb086b2 100644 --- a/libnd4j/tests_cpu/layers_tests/ConvolutionTests2.cpp +++ b/libnd4j/tests_cpu/layers_tests/ConvolutionTests2.cpp @@ -22,2824 +22,6652 @@ #ifndef LIBND4J_CONVOLUTIONTESTS2_H #define LIBND4J_CONVOLUTIONTESTS2_H -#include "testlayers.h" #include #include #include #include #include +#include +#include #include -#include #include -#include -#include +#include + +#include "testlayers.h" using namespace sd; using namespace sd::graph; class ConvolutionTests2 : public testing::Test { -public: - - const int bS = 2; // batch size - const int iD = 1; // input depth (number of picture channels, for example rgb=3) - const int iH = 28; // picture height in pixels - const int iW = 28; // picture width in pixels - const int oD = 3; // output depth (= N for dense layer) - const int kH = 5; // kernel height in pixels - const int kW = 5; // kernel width in pixels - const int sH = 1; // stride step in horizontal direction - const int sW = 1; // stride step in vertical direction - const int pH = 0; // padding height - const int pW = 0; // padding width - const int dH = 2; // dilation height - const int dW = 2; // dilation width - const int oH = (iH - kH - (kH-1)*(dH-1) + 2*pH)/sH + 1; // output height - const int oW = (iW - kW - (kW-1)*(dW-1) + 2*pW)/sW + 1; // output width - + public: + const int bS = 2; // batch size + const int iD = + 1; // input depth (number of picture channels, for example rgb=3) + const int iH = 28; // picture height in pixels + const int iW = 28; // picture width in pixels + const int oD = 3; // output depth (= N for dense layer) + const int kH = 5; // kernel height in pixels + const int kW = 5; // kernel width in pixels + const int sH = 1; // stride step in horizontal direction + const int sW = 1; // stride step in vertical direction + const int pH = 0; // padding height + const int pW = 0; // padding width + const int dH = 2; // dilation height + const int dW = 2; // dilation width + const int oH = + (iH - kH - (kH - 1) * (dH - 1) + 2 * pH) / sH + 1; // output height + const int oW = + (iW - kW - (kW - 1) * (dW - 1) + 2 * pW) / sW + 1; // output width }; ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests2, im2col_1) { - - int bS=2, iH=4,iW=3, iC=4, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oH = (iH - (kH + (kH-1)*(dH-1)) + 2*pH)/sH + 1; // VALID - int oW = (iW - (kW + (kW-1)*(dW-1)) + 2*pW)/sW + 1; // VALID - - int paddingMode = 0; // 1-SAME, 0-VALID; - - NDArray image('c', {bS, iC, iH, iW}, sd::DataType::DOUBLE); - NDArray expected('c', {bS, iC, kH, kW, oH, oW}, {1, 2, 4, 5, 2, 3, 5, 6, 4, 5, 7, 8, 5, 6, 8, 9, 7, 8, 10, 11, 8, 9, 11, 12, 13, 14, 16, 17, 14, - 15, 17, 18, 16, 17, 19, 20, 17, 18, 20, 21, 19, 20, 22, 23, 20, 21, 23, 24, 25, 26, 28, 29, 26, 27, 29, 30, - 28, 29, 31, 32, 29, 30, 32, 33, 31, 32, 34, 35, 32, 33, 35, 36, 37, 38, 40, 41, 38, 39, 41, 42, 40, 41, 43, - 44, 41, 42, 44, 45, 43, 44, 46, 47, 44, 45, 47, 48, 49, 50, 52, 53, 50, 51, 53, 54, 52, 53, 55, 56, 53, 54, - 56, 57, 55, 56, 58, 59, 56, 57, 59, 60, 61, 62, 64, 65, 62, 63, 65, 66, 64, 65, 67, 68, 65, 66, 68, 69, 67, - 68, 70, 71, 68, 69, 71, 72, 73, 74, 76, 77, 74, 75, 77, 78, 76, 77, 79, 80, 77, 78, 80, 81, 79, 80, 82, 83, - 80, 81, 83, 84, 85, 86, 88, 89, 86, 87, 89, 90, 88, 89, 91, 92, 89, 90, 92, 93, 91, 92, 94, 95, 92, 93, 95, 96}); - - image.linspace(1, 1); - - sd::ops::im2col op; - auto results = op.evaluate({&image}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode}); - auto column = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expected.isSameShape(column)); - ASSERT_TRUE(expected.equalsTo(column)); - + int bS = 2, iH = 4, iW = 3, iC = 4, kH = 3, kW = 2, sH = 1, sW = 1, pH = 0, + pW = 0, dH = 1, dW = 1; + int oH = (iH - (kH + (kH - 1) * (dH - 1)) + 2 * pH) / sH + 1; // VALID + int oW = (iW - (kW + (kW - 1) * (dW - 1)) + 2 * pW) / sW + 1; // VALID + + int paddingMode = 0; // 1-SAME, 0-VALID; + + NDArray image('c', {bS, iC, iH, iW}, sd::DataType::DOUBLE); + NDArray expected( + 'c', {bS, iC, kH, kW, oH, oW}, + {1, 2, 4, 5, 2, 3, 5, 6, 4, 5, 7, 8, 5, 6, 8, 9, 7, 8, + 10, 11, 8, 9, 11, 12, 13, 14, 16, 17, 14, 15, 17, 18, 16, 17, 19, 20, + 17, 18, 20, 21, 19, 20, 22, 23, 20, 21, 23, 24, 25, 26, 28, 29, 26, 27, + 29, 30, 28, 29, 31, 32, 29, 30, 32, 33, 31, 32, 34, 35, 32, 33, 35, 36, + 37, 38, 40, 41, 38, 39, 41, 42, 40, 41, 43, 44, 41, 42, 44, 45, 43, 44, + 46, 47, 44, 45, 47, 48, 49, 50, 52, 53, 50, 51, 53, 54, 52, 53, 55, 56, + 53, 54, 56, 57, 55, 56, 58, 59, 56, 57, 59, 60, 61, 62, 64, 65, 62, 63, + 65, 66, 64, 65, 67, 68, 65, 66, 68, 69, 67, 68, 70, 71, 68, 69, 71, 72, + 73, 74, 76, 77, 74, 75, 77, 78, 76, 77, 79, 80, 77, 78, 80, 81, 79, 80, + 82, 83, 80, 81, 83, 84, 85, 86, 88, 89, 86, 87, 89, 90, 88, 89, 91, 92, + 89, 90, 92, 93, 91, 92, 94, 95, 92, 93, 95, 96}); + + image.linspace(1, 1); + + sd::ops::im2col op; + auto results = + op.evaluate({&image}, {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode}); + auto column = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expected.isSameShape(column)); + ASSERT_TRUE(expected.equalsTo(column)); } template class TypedConvolutionTests2 : public testing::Test { -public: - + public: }; typedef ::testing::Types TestingTypes; -TYPED_TEST_CASE(TypedConvolutionTests2, TestingTypes); +TYPED_TEST_SUITE(TypedConvolutionTests2, TestingTypes); ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests2, deconv2d_tf_test2) { - - int bS=2, iH=4,iW=4, iC=5,oC=10, kH=2,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oH=4,oW=4; - int paddingMode = 1; // 1-SAME, 0-VALID; - int dataFormat = 1; // 1-NHWC, 0-NCHW - - auto input = NDArrayFactory::create('c', {bS, oH, oW, oC}); - auto weights = NDArrayFactory::create('c', {kH, kW, iC, oC}); - auto outShape = NDArrayFactory::create('c', {4}, {static_cast(bS), static_cast(iH), static_cast(iW), static_cast(iC)}); - auto exp = NDArrayFactory::create('c', {bS, iH, iW, iC}, {2.75f, 7.75f, 12.75f, 17.75f, 22.75f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, - 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, - 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, - 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, - 2.75f, 7.75f, 12.75f, 17.75f, 22.75f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, - 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, - 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, - 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f}); - input = 0.5; - weights.linspace(0.1, 0.1); - - sd::ops::deconv2d_tf op; - auto results = op.evaluate({&outShape, &weights, &input}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - + int bS = 2, iH = 4, iW = 4, iC = 5, oC = 10, kH = 2, kW = 2, sH = 1, sW = 1, + pH = 0, pW = 0, dH = 1, dW = 1; + int oH = 4, oW = 4; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NHWC, 0-NCHW + + auto input = NDArrayFactory::create('c', {bS, oH, oW, oC}); + auto weights = NDArrayFactory::create('c', {kH, kW, iC, oC}); + auto outShape = NDArrayFactory::create( + 'c', {4}, + {static_cast(bS), static_cast(iH), + static_cast(iW), static_cast(iC)}); + auto exp = NDArrayFactory::create( + 'c', {bS, iH, iW, iC}, + {2.75f, 7.75f, 12.75f, 17.75f, 22.75f, 30.5f, 40.5f, 50.5f, 60.5f, + 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, + 60.5f, 70.5f, 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, + 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, + 181.f, 201.f, 221.f, 241.f, 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, + 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, + 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 55.5f, 65.5f, 75.5f, + 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, + 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 2.75f, + 7.75f, 12.75f, 17.75f, 22.75f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, + 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, + 70.5f, 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, + 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, + 201.f, 221.f, 241.f, 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, + 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, + 161.f, 181.f, 201.f, 221.f, 241.f, 55.5f, 65.5f, 75.5f, 85.5f, + 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, + 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f}); + input = 0.5; + weights.linspace(0.1, 0.1); + + sd::ops::deconv2d_tf op; + auto results = + op.evaluate({&outShape, &weights, &input}, {}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests2, Test_DeConv2D_TF_1) { - auto input0 = NDArrayFactory::create('c', {4}, {12.f, 5.f, 5.f, 32.f}); - auto input1 = NDArrayFactory::create('c', {2, 2, 32, 16}); - auto input2 = NDArrayFactory::create('c', {12, 4, 4, 16}); - auto exp = NDArrayFactory::create('c', {12, 5, 5, 32}); - - sd::ops::deconv2d_tf op; - auto result = op.evaluate({&input0, &input1, &input2}, {}, {2, 2, 1, 1, 0, 0, 1, 1, 0, 1}); - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_EQ(exp, *result.at(0)); - + auto input0 = + NDArrayFactory::create('c', {4}, {12.f, 5.f, 5.f, 32.f}); + auto input1 = NDArrayFactory::create('c', {2, 2, 32, 16}); + auto input2 = NDArrayFactory::create('c', {12, 4, 4, 16}); + auto exp = NDArrayFactory::create('c', {12, 5, 5, 32}); + + sd::ops::deconv2d_tf op; + auto result = op.evaluate({&input0, &input1, &input2}, {}, + {2, 2, 1, 1, 0, 0, 1, 1, 0, 1}); + ASSERT_EQ(Status::OK(), result.status()); + + ASSERT_EQ(exp, result.at(0)); } ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests2, Test_DeConv2D_TF_2) { - auto input0 = NDArrayFactory::create('c', {4}, {3.f, 8.f, 8.f, 16.f}); - - auto input1 = NDArrayFactory::create('c', {7, 7, 16, 5}, {1.05293429f, -0.89349967f, 0.31027254f, 1.22991478f, -0.62926656f, 0.56918693f, --1.60992694f, 1.10167944f, -0.80843484f, 0.07521993f, -1.15994942f, 0.76016301f, -0.40056285f, -1.16872537f, -0.91384381f, -0.36700436f, 1.82389200f, -1.18200207f, 0.51612782f, -0.92479187f, -0.09307563f, -0.55122334f, 1.23532486f, -1.11124146f, -0.05812126f, 0.68159896f, 0.69125599f, -0.77127314f, -0.10874277f, 0.86469102f, --1.31614351f, 0.33354419f, -1.71750402f, 0.17197680f, -1.03965557f, 1.10570908f, -1.19115615f, 1.05115080f, 0.18277600f, 1.08820546f, -0.72191417f, -0.10999311f, 1.56521320f, -0.35433730f, -1.11799145f, 0.34499285f, 0.64998639f, -1.64371550f, 0.92592359f, -0.47659501f, 0.49101439f, -0.15613313f, 1.47486567f, 0.43576995f, -2.19538260f, -0.83567709f, -1.21846950f, 0.80400819f, 1.14637423f, -1.01503456f, -0.61992753f, -0.47378838f, 0.86503726f, 0.27147385f, 0.37073180f, -0.19951358f, 0.79167330f, -0.33982825f, 0.18631981f, -1.54715073f, 0.39967480f, 0.95067030f, 1.12508667f, -0.86676019f, -1.10341156f, 2.33141375f, 1.10972047f, 0.71407092f, -1.70640314f, 1.80666339f, 0.59465605f, -0.39653218f, -2.61163163f, -1.15013492f, -1.19908321f, 0.41783467f, -0.22730024f, 0.31425011f, -0.58562893f, -0.10131568f, -0.85047537f, -2.59974790f, 1.22072542f, -2.08812046f, -0.19363593f, -1.27664304f, -0.02703438f, 1.08477545f, -0.65506506f, 0.46040919f, -0.13715318f, --0.74945593f, -0.69006950f, -1.29617655f, -0.15865716f, 1.38956285f, 0.90216327f, -1.31185400f, -0.15067385f, -0.63093358f, -0.05895613f, 0.26545224f, 0.29332840f, 0.42852548f, 0.72409540f, 0.12879130f, 1.43038857f, 0.68647617f, 2.19654775f, 0.51878077f, -0.03769343f, 0.52877223f, -0.21733910f, 1.13710785f, -0.59003806f, -1.54624867f, -0.64997369f, -1.03239334f, 0.19708300f, 0.68658423f, 0.71048903f, -1.55250466f, -1.38636279f, 0.32385820f, 0.81226677f, 0.19209047f, -0.23002781f, -0.63631231f, 1.02101684f, 0.65428704f, -0.17206922f, 1.09488952f, 1.03022420f, -0.95567745f, -0.07595373f, -1.48606372f, 2.57174873f, -1.75366247f, 1.12913883f, -0.97053039f, -0.28552356f, 0.56511772f, -0.79568213f, 0.07561764f, -1.02085686f, 1.05770981f, -1.25715709f, 0.42046708f, -2.57390857f, 0.96947151f, 1.05215812f, 0.65624017f, -1.29019403f, 0.64157075f, -0.40509227f, -0.65354455f, 0.42348680f, -1.34107757f, 0.05931387f, -0.54337227f, 0.95460182f, 1.59319806f, -0.44433126f, --0.33717924f, 0.79566282f, 0.50112695f, -0.22244534f, 1.76904583f, -0.89817202f, 1.82985342f, 0.17671813f, 0.80720717f, 1.32469308f, 0.39417782f, -0.23720963f, 0.96796370f, -1.02348757f, -0.86615551f, -1.58120525f, -0.37634999f, 0.00905940f, 0.01880967f, 1.75771821f, -0.64372772f, 0.36687651f, 0.15854552f, -0.67599791f, -0.53726906f, -1.20158446f, -1.78549063f, 0.96476388f, -0.66158366f, -0.41681561f, -0.97541636f, 2.35928202f, 0.32130197f, 1.06886065f, 1.38736427f, -0.73718959f, 0.11215294f, 2.12865782f, -0.37927702f, 0.55621815f, -1.10108411f, -0.02032263f, 0.29595461f, 1.58737493f, 1.24001300f, -0.66748160f, 0.80729002f, -0.10575818f, --1.03175950f, 1.80755460f, 0.10825710f, 2.20666361f, 1.33633149f, 1.39290452f, 0.45211342f, -0.07837920f, 2.08304930f, -0.28387162f, -0.70775616f, 0.43626297f, 0.53556961f, 0.06201901f, -0.59255266f, -0.11854446f, 2.10024118f, 0.37638292f, -0.56178707f, -0.25220188f, -1.23731256f, -1.30002999f, 0.34283713f, 0.30502397f, --1.09233856f, 1.12430644f, 0.52273953f, -0.68507338f, -0.69913578f, 0.88440478f, -0.76959240f, 1.07093310f, -0.34802195f, 0.35683727f, -0.76079178f, -1.92807376f, 0.84499562f, 1.39131641f, 0.44825050f, 0.34567752f, 0.44607711f, -1.00986362f, -0.50038189f, -0.09060892f, -2.55645394f, 0.56416476f, -0.83058155f, -0.65931624f, --0.73649710f, 0.59814465f, -0.86736494f, -0.32200798f, -1.28087902f, -0.76818323f, 0.86848933f, -0.98678392f, -1.30813944f, -0.20255326f, 0.26557815f, -0.31090519f, -1.46331608f, -0.62782109f, 0.59034890f, 1.63147473f, -0.17727259f, -0.37636510f, 1.27368402f, 0.19096918f, -0.29936951f, -1.99038267f, 0.54831523f, 0.48849005f, -2.55680346f, -0.63126534f, 1.21715927f, 1.22841084f, -0.67416084f, 0.02927168f, -0.36693662f, 0.63204330f, 0.13721083f, 0.28742912f, 0.19470036f, 0.74873924f, -1.47602463f, 0.86264688f, -0.23730527f, -0.99978864f, -1.17048764f, -0.34996086f, 1.43019187f, 0.26224539f, 0.60689932f, -0.75002515f, -0.79823422f, -1.37300086f, -0.19951135f, -0.12150808f, -0.75272322f, 0.23755015f, 0.31270382f, 1.66539109f, -1.04104745f, 0.79540199f, -0.54042423f, -0.54150617f, 0.43871084f, 0.24163951f, -0.24517761f, -0.66178995f, -1.13064528f, -0.84426326f, 0.56437236f, 0.09088907f, -0.82823074f, 0.81753862f, -1.74096012f, -1.80599844f, -0.60943592f, 1.36094582f, -1.47762752f, 0.15931177f, 1.05569172f, 0.36751524f, 0.06497604f, 0.13536447f, -1.57156146f, 0.22783801f, -0.96910107f, -1.24294984f, -1.47147155f, -1.04790676f, 0.64629447f, -0.32266054f, -0.55675793f, -0.95612079f, -0.23005411f, -0.75229394f, 0.03050950f, -1.72484553f, -2.06055546f, 0.19892083f, -0.13597751f, 0.65180075f, 0.27096850f, 0.08977254f, 0.57564765f, -0.43227410f, 0.09541437f, -0.00358280f, 0.65680492f, 0.04006556f, 0.57160908f, 0.43821687f, 1.96118212f, 0.42602235f, -0.36731303f, 0.67200917f, -0.56667900f, 0.44014785f, 0.06970236f, -1.34415269f, -1.13301528f, -0.08848868f, 0.35615012f, -0.06426942f, -0.81406075f, 0.94097465f, -0.54560357f, -0.65877116f, -1.29646838f, -1.13109028f, -1.64186084f, -2.12723470f, 1.86027610f, 1.22621441f, 0.26098135f, -0.05608099f, 0.21143445f, -0.87244326f, 0.79408187f, 1.24279130f, 0.14458629f, 0.25532281f, -1.24023473f, 2.42278886f, 0.00405578f, -1.00119174f, 1.19856644f, -1.37395728f, -0.16656208f, 0.46858498f, -0.00678801f, -0.34960639f, 0.16614936f, 2.41560221f, -0.53880709f, 0.91618651f, -1.77009308f, 0.32911557f, 0.30216452f, 0.02881077f, 0.77705866f, 0.27061903f, -0.07440855f, -1.14010465f, 1.25383139f, -1.58615100f, 1.04185510f, 0.15140508f, -0.88059032f, -0.33872122f, -0.42526904f, 2.17365575f, 0.29308075f, -2.24234557f, -1.03164542f, -0.09263755f, 0.08050421f, -0.74946511f, -0.64589006f, -1.13416314f, -0.64989561f, 0.16502371f, -0.33831969f, 0.22832428f, -0.08389475f, -0.28009200f, 1.34536922f, -0.19075738f, 0.36238208f, 0.83690089f, 0.26144615f, 0.04457319f, -2.55585861f, -0.01807522f, 1.68334866f, -0.05795629f, -0.21315987f, -1.84039557f, 0.06512877f, -1.77318645f, -0.27637982f, 0.20439345f, 0.67558700f, -0.77179354f, -0.17902173f, 0.70381826f, -0.40395790f, -0.96492916f, 0.84138173f, 2.43879008f, -0.32297835f, -1.74370265f, -0.10330839f, -1.07465363f, 1.85030377f, -0.59153467f, 0.99667048f, -0.56753993f, 0.57383025f, -1.90630126f, 1.24299097f, 0.22797665f, 0.30468231f, -0.07360230f, 1.64654350f, 0.57195550f, 0.03227921f, 1.11005175f, 0.00088721f, 1.19266295f, 0.61323351f, 0.13754399f, 0.59900171f, -0.75831634f, 1.11500823f, 0.99747783f, -1.36923385f, 1.26563418f, 0.01253266f, 0.35483193f, 1.95143735f, -2.02703261f, -1.38265920f, -0.02404256f, 2.02788448f, -0.75144875f, -0.58445263f, 0.26129767f, 0.60691077f, -1.84661067f, 0.65872228f, -0.58298993f, 0.33067298f, -0.09431327f, 0.43333948f, -1.52616286f, -0.25961858f, -1.65459549f, -0.72950101f, -0.89906919f, -0.80081612f, -1.32189929f, -1.36574399f, -0.35809481f, 0.36385000f, 0.31480747f, -0.35797358f, -1.04066050f, 0.07971872f, -0.21176252f, -0.76559299f, -0.10352154f, 0.29248312f, -1.75030553f, 0.68219930f, 0.56189102f, -1.11212170f, 0.06501702f, -0.07131009f, 1.23410738f, 0.29311740f, -1.02052307f, 1.40220940f, -1.00995779f, 0.57955760f, 0.22640309f, 0.74853230f, -0.02586563f, -0.33427954f, 1.70311153f, -0.53405988f, 0.90975094f, -0.46450076f, 0.19904344f, 0.28559047f, 0.23167793f, -0.69065529f, -0.17176504f, -0.29301846f, -0.85477978f, -0.00267053f, -0.28529504f, -0.64201307f, 1.03479636f, 1.03805065f, 0.83270210f, -0.09405448f, 2.50615931f, 0.62019676f, 0.31354564f, -1.51599669f, 0.42848015f, 0.66263914f, 0.74651009f, -1.13042867f, -0.58933645f, -0.35146511f, 0.06223279f, 0.28065836f, 0.66506970f, 0.16942430f, -0.23316263f, -0.87481076f, 1.21992230f, 1.48536301f, -0.79667616f, -0.75519305f, 1.40999961f, -0.42802793f, -0.20252463f, 0.30573779f, -0.23319976f, 1.77525878f, -1.80704832f, 2.71519923f, -0.67500192f, 0.12268137f, -0.13014549f, -0.07479453f, -1.51065743f, 1.04198146f, 0.96205556f, -2.00525570f, -0.37911776f, 0.89329720f, -0.39495832f, -0.03683375f, -0.90928614f, -1.56263304f, 0.45038295f, -2.62184358f, -0.45686841f, -0.52536523f, 1.05351484f, 0.89982438f, -0.63724512f, 3.21004057f, -0.08608918f, 1.55209303f, 0.62688643f, -0.59702635f, 1.85774517f, 0.38172096f, -1.25640929f, -2.59278178f, 0.85050315f, -1.10080361f, -1.26422560f, -1.80045366f, -0.34494889f, 0.68448657f, 1.25671864f, -1.26594126f, 0.32244179f, -0.51956522f, -0.56212711f, -0.95574015f, 0.71973872f, 0.46736258f, -0.11772985f, -1.52736545f, 0.19571695f, 0.73147154f, 0.87724912f, -0.26265728f, -2.60267401f, 0.19263546f, 0.18320183f, 0.11485019f, -0.82999659f, 0.13582672f, -0.08040185f, 0.28152901f, -0.51421624f, -2.32467175f, 0.19923948f, 0.64616692f, 0.29718629f, 0.32785949f, -0.62266952f, -0.98174316f, 1.23276305f, 0.58563638f, 1.28528512f, -2.13718534f, 0.28842899f, 0.12676710f, -1.72105229f, 0.15053287f, 2.19496536f, 1.28683448f, -0.96318281f, 0.17043279f, -0.05245409f, -0.38710704f, -0.30441490f, -0.08249986f, 0.28423953f, 0.72963721f, -1.49658203f, 0.99077344f, -0.78913772f, -1.12661564f, -1.26294816f, 0.16517465f, 0.10124251f, -0.77198768f, -0.16342169f, 0.08615876f, 0.49711797f, -0.66083062f, 0.76648003f, 1.04756033f, 1.46122825f, -0.42798752f, -2.29203916f, 0.30444992f, 0.58697921f, 1.22166932f, 0.09022947f, -0.03920181f, 0.10444995f, 0.10361757f, 1.18224072f, -0.76641631f, 0.90802073f, 1.41639423f, 1.55682337f, 1.28101575f, -0.35396016f, 1.11443567f, 1.18218529f, -0.06048089f, 0.85024464f, -1.01789165f, -0.69154263f, 0.06663221f, 0.68429029f, 0.12560424f, 0.37915874f, -0.66829866f, -0.64524972f, -0.05568011f, 0.12230454f, -0.35041061f, 0.62027830f, -0.16739209f, -0.72145337f, 0.46263054f, -1.67837834f, 0.69413221f, -0.57243419f, 0.37638462f, -0.21446526f, -0.89821470f, 0.60078722f, -1.06706369f, -1.26132309f, 0.35714921f, 2.39221811f, -0.09376130f, 0.30760849f, 0.59180892f, 0.55815399f, -0.32628775f, 1.28890121f, -2.53237987f, -0.98241091f, 1.10520673f, -1.74751687f, -0.90837651f, -0.25220659f, -0.56625104f, -0.30691949f, 0.16058689f, 0.44309673f, -1.09874964f, -0.76747823f, -0.33679363f, -0.02535496f, 0.00990100f, 1.35318136f, -0.70140815f, 0.50937581f, 0.55386209f, -1.21721983f, 0.71376961f, -0.18079315f, -0.11077732f, 0.09292522f, -0.57235324f, 0.62748206f, 0.42587611f, 0.64860481f, -1.10635614f, 1.66414368f, 0.47505483f, 1.48602211f, -0.59611166f, -0.41932896f, -0.96542233f, -0.41756630f, -1.02963889f, -0.70070386f, 1.65803933f, 0.20138647f, 0.05895034f, -1.46152759f, -0.37278318f, 1.05535650f, 0.34437978f, -1.13257408f, 0.17635690f, 0.09386671f, 0.37079874f, 1.47695887f, -1.58420062f, -0.26100200f, 0.44847637f, 0.88847303f, -0.13877590f, -0.64620668f, -0.38019657f, 1.01608157f, 0.13357787f, 0.05137976f, 0.93498152f, -0.62226880f, 0.80461699f, -0.71682596f, -0.88756353f, 0.40933055f, -1.52167451f, 0.79756850f, -0.17307425f, 0.62368619f, -0.22466940f, -1.72802913f, 0.59047443f, -0.58020931f, 0.09096476f, -0.07317388f, 0.44522321f, -0.64880705f, 0.15684015f, 0.08708375f, -0.41556796f, 1.11579072f, -0.81733495f, 0.11643656f, -0.73995101f, 0.93685871f, 1.57971406f, 0.67606360f, 0.70509088f, -0.25283816f, -0.00010609f, -0.61884147f, -0.86409342f, 0.95383751f, -0.05895388f, -1.45261180f, 0.45166013f, -1.01434863f, 0.18496066f, 1.06517637f, 1.81127059f, 0.89470667f, -0.13232610f, 0.46958798f, 0.13884509f, 0.57117194f, 0.29575035f, -0.97884250f, 0.83291447f, -0.59255791f, -0.04354135f, -0.19431923f, 0.30071029f, -0.95421529f, 0.76359886f, -0.47799742f, 0.68254346f, 1.19368529f, -0.48935115f, 0.30357337f, -0.50225669f, -0.23370270f, 1.96702433f, 1.46558523f, 2.68482018f, 0.41622332f, 0.73697484f, 1.43430734f, 0.15387188f, 0.20875402f, -2.49335337f, -1.39674246f, -0.22125854f, -0.00424605f, 0.91416460f, 0.33384630f, 0.44703746f, 0.25610185f, 0.38966551f, -0.01784045f, 1.66148460f, 0.36005461f, 0.95716912f, -0.18246566f, -0.15480693f, 0.38775176f, -0.56969136f, -0.29644895f, -1.04565966f, -1.00455630f, 0.30897698f, -1.46885884f, 0.03657720f, -0.49302089f, 1.34134722f, 0.01673754f, 1.22725964f, 0.55256772f, 0.63803208f, -0.29041430f, 1.11455286f, 0.76329172f, 0.27073982f, 0.77173829f, -1.79884446f, -0.11889492f, -1.92040312f, -0.46382675f, 0.20078070f, -0.98889589f, 1.46711135f, -1.68280172f, -0.52852470f, 0.66245162f, 0.29575166f, 1.34826505f, -0.22362417f, -0.14345661f, -2.34815073f, 1.26572001f, 0.66505629f, 1.01141500f, 1.08030057f, 0.17036134f, 0.00168786f, -0.37282917f, 0.69206375f, 1.07367527f, -0.49708191f, 1.49504781f, 0.58224988f, 0.96593714f, -1.07661915f, 0.25202179f, 0.25531644f, 0.42357162f, -0.31236249f, 0.48383278f, -0.06361829f, 0.24131298f, -0.95695931f, -0.12589653f, 0.36134180f, 3.20266032f, -0.40879184f, -0.66985190f, 1.51674330f, 0.34072638f, 1.15076303f, -0.40199137f, 0.46223637f, -0.48608047f, 0.99119538f, -0.22506073f, 0.30968750f, 0.64210880f, 0.54640514f, 0.18607031f, 1.26293361f, -0.77960914f, 0.79572529f, 1.01936150f, 2.27160740f, -1.48034489f, 0.74466604f, 0.14863680f, 0.31102443f, -1.15673816f, -0.38609681f, -2.65026069f, -0.45524642f, -0.74022961f, 2.74991131f, 0.00103815f, -3.03303242f, -0.41556966f, -0.87103498f, 0.78306234f, -0.88195556f, -0.77297026f, 1.21203196f, -1.09754920f, -0.03556008f, -0.31546223f, 0.72954375f, 0.25251788f, 0.11378583f, 0.50921023f, 0.30301905f, -1.60631680f, 0.27152416f, 1.17342317f, -0.70891970f, -0.08392961f, 0.92137378f, -0.10568139f, -0.31653777f, -0.28878728f, 1.22166574f, 1.12693942f, -0.21325994f, 0.94010323f, 1.21796405f, -0.68866694f, 2.30724216f, 0.28141466f, 0.83481526f, -0.04885862f, 0.01675143f, 1.04355800f, -0.81050140f, 1.51300573f, 0.53429186f, -0.56439877f, 0.38572624f, -0.05620475f, 0.67644542f, 0.72528905f, 0.05937041f, -1.06315899f, -0.51393986f, 0.46937627f, -0.34699562f, -0.64765716f, -1.45512629f, 0.47739139f, -0.88228017f, -2.00791359f, 1.29929042f, 0.05482405f, -0.66725296f, -0.54735124f, 0.09972951f, 0.76675093f, 0.98748523f, 0.08900899f, -0.78854066f, 1.47970486f, -0.61667502f, 0.45625573f, -0.21766303f, -0.46250847f, -0.07130960f, 0.64414692f, 0.12784545f, 0.26393634f, 1.07720757f, -1.23938286f, 0.62483376f, -0.55001754f, -0.05358591f, 0.07322436f, 1.12003291f, -1.00830650f, -0.20486419f, 0.76664752f, 0.28850746f, -0.04464776f, -0.40146068f, 0.73262817f, -1.12827921f, -0.19989438f, -1.15999687f, 1.37973154f, 0.78881019f, -0.34762639f, 1.22088552f, -1.64088547f, 0.63218033f, 0.45736769f, 0.05502866f, 2.22683382f, -1.78935897f, -1.49635041f, 0.83450896f, 1.67770112f, 1.33909333f, 1.51158953f, 0.28595078f, -0.08593627f, 0.45812801f, -0.15193029f, 1.14770603f, -0.88920450f, -1.96352005f, -1.49894583f, 0.49629962f, 1.59872091f, 0.00903497f, 2.15563583f, 2.25149560f, -2.01200557f, 2.56229877f, -1.38850498f, 0.73552012f, -0.39378855f, 0.52616280f, -0.03685786f, 0.87403935f, 0.12163408f, 0.74297994f, -0.30697080f, 0.38139752f, 0.49113834f, -0.95485127f, -0.99908817f, 0.71716321f, 0.04000283f, -2.09645271f, 1.38789880f, 1.37198520f, 0.82493287f, 0.17114936f, 0.53696346f, -0.19516060f, -0.50377476f, -0.91730285f, -0.70113552f, -0.02406530f, 0.84943396f, -0.17428185f, -1.09140801f, -0.68156958f, 1.70756388f, -1.00399911f, 0.03023832f, -0.39023280f, -1.89737976f, 1.14469039f, -0.58337289f, -0.60037899f, -1.17490256f, -1.56342828f, 0.48714057f, 0.62266618f, -0.15967095f, 1.32789338f, -1.25700688f, -0.55633998f, -0.83128709f, -0.49346271f, 1.59561753f, -0.24675299f, 0.38012561f, 0.91796309f, -0.38522810f, -0.65509188f, 0.94100451f, -0.57324487f, 2.19070768f, 1.24058700f, -0.75978851f, -0.40460554f, 0.79189235f, 0.70192885f, 1.93569362f, -0.03070199f, 0.77010989f, 0.58794290f, 0.51087004f, 0.22892070f, 0.35007235f, 1.56023848f, -0.67453802f, -0.18485607f, 0.64349502f, -0.31489357f, -1.95834625f, 0.06560058f, 2.30394220f, 1.18194163f, -0.88034087f, -1.05000436f, -1.05471325f, -0.98481798f, 0.49904808f, 0.16438948f, -1.10297823f, -1.39736509f, 0.01306054f, -1.85160267f, -0.87292641f, -0.15418227f, 0.43412164f, 1.16518164f, 0.06273691f, 0.24659210f, -0.08267246f, 1.28885782f, 0.73575675f, -0.01019809f, -0.08753663f, -0.61827368f, -0.40863234f, 2.12599611f, -0.53620332f, 0.53789747f, -0.66386080f, -1.70461988f, 0.86608189f, -1.11151052f, 0.14120635f, 1.18858743f, -0.31760478f, -0.73533046f, 0.20978074f, -0.84074509f, 0.16523147f, -1.03362834f, 0.59721231f, 0.21318658f, 0.23671274f, 1.75115061f, 0.25363782f, -1.32541454f, 1.13056135f, 0.24652456f, 0.60381413f, 0.21478581f, 0.75044096f, -0.63125616f, -1.69889998f, -0.02116571f, 1.46165359f, 1.03068244f, 0.63693464f, 0.67795700f, 1.20033514f, -1.39205134f, -0.61743122f, 0.56549704f, 0.65182322f, -0.74250507f, -1.61939359f, 1.14054918f, -0.45725963f, 1.74519682f, -0.66251940f, -0.94811529f, -1.60865819f, -0.59968346f, 0.86309159f, -1.91936195f, -1.02646923f, -1.50352538f, 0.58292735f, 0.05320299f, 1.53582895f, 0.01069612f, 0.15226212f, -0.71840125f, -1.36896348f, 2.14600968f, 0.96626586f, -0.52014917f, 0.41001406f, 0.59478027f, 0.15282436f, 0.27790198f, 0.76614654f, -0.38971323f, -0.01839927f, -1.57882118f, 0.61391610f, -0.62133092f, -0.03968323f, -0.88467252f, -1.24041140f, 2.07306671f, -0.41776338f, 0.14537935f, -0.91069067f, 1.67362070f, 4.72630215f, -0.07395106f, 0.46280116f, -0.40843824f, 0.70683080f, -0.27510864f, -0.63465804f, -0.83630908f, -0.44419941f, 0.60405648f, -0.65039170f, -1.02413189f, 1.05983019f, 1.73366308f, 0.73343736f, -0.00895882f, -1.00826013f, 0.17323074f, 0.73995626f, 0.24128854f, 0.94510227f, 0.25557515f, 0.02244723f, -0.95197725f, -0.16297856f, -0.38497585f, 1.17993331f, 1.20282137f, -1.31491220f, 0.44229278f, -0.24349044f, -0.01230415f, 1.37944865f, 0.48554277f, -0.54510897f, -0.10793537f, 0.41121426f, -0.12889031f, 0.26434359f, 1.27966082f, 0.64518744f, -0.15577169f, -0.99864733f, -0.61746484f, 2.01614976f, 1.56254935f, 1.86473298f, -0.54662132f, -0.22047071f, -0.06118120f, 0.84799510f, 0.17009684f, -1.30523121f, 0.64000309f, 0.36299205f, -0.59620583f, 1.36372304f, -0.05389515f, -0.93849313f, 0.98043185f, -0.39373067f, -0.84898937f, 1.32077873f, 1.05988657f, -1.35339200f, 0.23259017f, 0.63816410f, -0.80297333f, 0.60017115f, 1.25715804f, 1.18894124f, -0.62473553f, 1.05611980f, 0.02335166f, 1.07509828f, 0.25873449f, -1.68341100f, 0.54547334f, 0.79288185f, -0.93678916f, 0.19202201f, -1.48575914f, 1.08649087f, 0.50851744f, -0.45758674f, -0.39734635f, 0.35637981f, -1.63079453f, -0.75910008f, 0.92640859f, -0.55599529f, -0.40276715f, 0.31307653f, 0.39907026f, -1.18830419f, 0.71051043f, 0.14157933f, -0.39581308f, -1.64361024f, -0.06161860f, -0.25312796f, 1.10018682f, 0.56500763f, 0.80385065f, 0.35395023f, 0.81813669f, 0.27644628f, 0.65563256f, 1.73197234f, 0.68178749f, 0.76769936f, 0.44597456f, 0.67761195f, 0.67635447f, -0.32315412f, 0.19330767f, -0.25557944f, 1.91693723f, 0.38335562f, 0.07107610f, -0.57384586f, 0.79184365f, 1.87835479f, 0.60902315f, -0.94220877f, 0.79479855f, -0.25656971f, 0.08739131f, 0.53384244f, 1.22159266f, -0.39152125f, -1.46373534f, -0.02458516f, 1.62825716f, -1.26112676f, 0.19967082f, -0.71114451f, 0.27929229f, 0.65001321f, -0.11868202f, -0.55587751f, 0.78069001f, 0.57969242f, -0.60274386f, 0.31650013f, 0.90339553f, 0.09453616f, -0.37119162f, -1.00320566f, 0.33299938f, -0.48636708f, 0.26342997f, -0.91914523f, 0.28682709f, -1.24780893f, -1.59254742f, 0.97176319f, 0.14744301f, -0.53056234f, -1.73221612f, -0.67645556f, 0.98705006f, 0.79895812f, -2.04333115f, -0.60132772f, -0.91653955f, -0.28094748f, 0.47943443f, 0.38157779f, -0.67648011f, 1.09093642f, 1.66012859f, -0.29358891f, -1.26773024f, 0.36747769f, -1.10141146f, 0.82383633f, -0.89772314f, -0.47145563f, 0.63939518f, -0.64430422f, -0.48889321f, -0.37680882f, -1.06962025f, -1.28689516f, 1.28365147f, 0.61859220f, -0.84676331f, 1.38404000f, 1.21053445f, -0.14871351f, 1.06349385f, 1.45878971f, -0.47362664f, 1.40707004f, 1.25224137f, 0.87364739f, 0.92858213f, 0.00157326f, 1.45661485f, -0.27318576f, 0.15482858f, -1.07058907f, -0.06903186f, -0.74147576f, -1.64111829f, -0.67226541f, -1.13458407f, 1.28511488f, -0.41041154f, 2.09085560f, 0.45243183f, -0.67437285f, 0.84960121f, -1.49300814f, -0.42961186f, -2.35021853f, 0.57255560f, -0.73903763f, 1.37607956f, -2.44575167f, 1.25105727f, 1.38575912f, -1.16299784f, -0.13719854f, -1.11507034f, 0.35796806f, -0.64511567f, -0.87903833f, 0.32833642f, -0.87696886f, 0.02714214f, 0.30224666f, -0.69118696f, -1.23500824f, 0.76678628f, -3.20508122f, -0.24704689f, 0.49019828f, -1.20862615f, -0.03778638f, -0.07273687f, -0.11517122f, -1.75857520f, -1.64188445f, 1.21574795f, 0.57325113f, 1.14370298f, -1.07824504f, 1.70653832f, -0.03700557f, -0.47645858f, 0.11065386f, -1.03143036f, -2.18094873f, -0.94403434f, -0.09335683f, -0.44817665f, 1.39707148f, -1.21947956f, 0.56575936f, -0.69612634f, -1.12361753f, -0.17105591f, 1.15422392f, 0.02840637f, 0.09469353f, -0.52859986f, -2.08487725f, 1.28789508f, -0.03740775f, 0.61196613f, 1.23405397f, 1.56595814f, -0.65800631f, 2.02985072f, -0.69446486f, -0.88443804f, -0.23448054f, -0.43628734f, -0.45888957f, -0.21943338f, 1.78258693f, 1.75214970f, 0.71804136f, 0.49782532f, 0.37886053f, -1.59176385f, -1.74758542f, -0.02820176f, 0.75398153f, 1.00119829f, 0.80881971f, -0.53365272f, -0.22720885f, 0.37476870f, 0.01005529f, -1.23421800f, -0.13431595f, -1.01843679f, 1.87386346f, -1.68539488f, -1.04942071f, -0.77322137f, 0.53964764f, 0.29278332f, -0.58299130f, -1.56022692f, -0.79441273f, 0.49289709f, 0.44112054f, 1.07305002f, 0.54899335f, 1.13781393f, 0.77809113f, 0.81795985f, 0.16576190f, 0.32552773f, -0.20250474f, 1.46543837f, 0.12731771f, 0.21013761f, -1.34241438f, 0.44267517f, 0.93246883f, 0.08808212f, 0.92653406f, -1.21083558f, 0.17247954f, -0.70557106f, 0.04630012f, 0.48834828f, 0.89634645f, 0.46683592f, -0.29553145f, 0.46363977f, -0.48971879f, -0.88603491f, -0.12333342f, 0.37073737f, 0.92061806f, 0.54675460f, -0.14716248f, 0.75578392f, -0.98173791f, -1.15983224f, -0.58713156f, 0.07950903f, -0.59016788f, 0.41622928f, -0.32474482f, 0.42086437f, 0.23061797f, 0.62596649f, -0.22615278f, -2.14721417f, 1.01685894f, -0.25976995f, 0.00739352f, -1.31597066f, 0.39005190f, -1.09549701f, 1.68375242f, 0.43331525f, -0.37124026f, 0.22255214f, 0.59654880f, -0.73840386f, -1.20048976f, 0.12226126f, 0.12997478f, 1.04826224f, 0.03894836f, -0.36289826f, 1.14466560f, -1.18198848f, -0.03713558f, 0.67677927f, -0.42329931f, -0.89409167f, -0.77874780f, 0.58438253f, -0.35176343f, -1.53329861f, -0.02995299f, -0.40145162f, -1.51052392f, 0.09194464f, -1.13275242f, -0.61983156f, -0.40004560f, -0.19893464f, 0.22134103f, -0.03903082f, 1.14894116f, -0.03476744f, 0.22520730f, -0.55851930f, 0.76650429f, -0.57863152f, -1.34161711f, -0.31498179f, -1.19411755f, 1.70044947f, -0.17428267f, -0.35983825f, -0.42613637f, 0.58165723f, -0.77866900f, -1.59727287f, -0.61723864f, 1.51078022f, 0.32971445f, -0.86441469f, 0.60552609f, 0.00208178f, -0.47096625f, -1.10479307f, -1.21652532f, -0.08211990f, -1.43739200f, -1.31684434f, 0.43312529f, -0.76822090f, 1.88128507f, -0.02179282f, 1.04971325f, -1.55004108f, 1.25337446f, 0.11203052f, -1.16048300f, 1.59467411f, -1.29469275f, 1.14019871f, 1.20021439f, 1.84098923f, 0.05004879f, 0.73529941f, 2.05272865f, -0.13080600f, -0.08436690f, -1.17919350f, -0.66256678f, -0.36727047f, 0.73840511f, 1.22293818f, -0.00206342f, -0.29839504f, -0.00618613f, 1.04213119f, 1.21176076f, -0.62886089f, -0.02589060f, 0.96009409f, -0.64478731f, -1.16516542f, 0.57528079f, 1.04294407f, -0.09774588f, 0.45935291f, 1.03263175f, 1.00633478f, -1.82209253f, -0.18035053f, -0.28302726f, -0.83813244f, 0.57593471f, -0.03807700f, 1.60498738f, 0.16530658f, -1.43083501f, 2.10824299f, 0.30279446f, -0.03961089f, -0.38900724f, 1.31272805f, -0.56575215f, 0.57970244f, -0.48305038f, 1.34114623f, 0.21859215f, 0.66399640f, -1.52087069f, -1.30717897f, 0.14394683f, 0.97648209f, -0.71372712f, -1.22574198f, -0.27702177f, 0.04041927f, 0.02442212f, 2.19617033f, -0.48566443f, 0.81463927f, 0.20383844f, 1.17562282f, -0.33829874f, -0.42141283f, -0.96415234f, -2.39141965f, -1.04285860f, -0.23004992f, 0.41186509f, 0.03811268f, 0.36818987f, -0.71099734f, -0.56749570f, 0.18486284f, -0.44530040f, 2.14008284f, -0.27467576f, 1.70690107f, -1.40462613f, 0.24697532f, -1.31629777f, -2.20674944f, -0.67868507f, -1.15767133f, -0.64391804f, -1.79037917f, 0.58749497f, -1.58303332f, -0.69021022f, 1.64376318f, -0.95393223f, 1.98415601f, -0.10991055f, 0.02474386f, 0.23683345f, -0.63420391f, -0.57991928f, 0.83028817f, -0.40033704f, 0.19212338f, 0.74640590f, 1.10264432f, -1.65286255f, 0.92683482f, -1.42252541f, -0.74605089f, 2.14535880f, 0.12971123f, -0.47971717f, 1.67546797f, 0.42268261f, 0.22648531f, -0.42369929f, 0.77403021f, -1.31818616f, -0.67143595f, -0.04311426f, 1.64128351f, 0.34776631f, -0.39353722f, -0.42765084f, 0.16170517f, -0.54488391f, -0.38428506f, 0.42097485f, -0.55982012f, -1.74543798f, 1.53704774f, 0.43562424f, -0.30395737f, 0.31846946f, 0.39205357f, 0.57386035f, -1.11912560f, -1.39164317f, -1.04337609f, 0.31629622f, 1.51927638f, 0.88745505f, -0.40445471f, 0.25783861f, 1.88646257f, 0.36509129f, -1.13266826f, -0.45394278f, -0.48400903f, -1.22332740f, 0.38626808f, -1.10049105f, 0.84138852f, 1.27863181f, 0.53942156f, -0.67743856f, -0.03896645f, 1.70393491f, 0.60997570f, 0.43368068f, -0.13338457f, -0.18920666f, -0.29583672f, -1.40738738f, 1.03876019f, 1.71253765f, 2.12821221f, -0.96092403f, 0.93841934f, -0.79030478f, 1.36427641f, -1.39196694f, 0.08514920f, 0.16223004f, 0.71259701f, 0.20150672f, 0.25068361f, -0.99952722f, 1.80129099f, -1.28586197f, -0.64957166f, -0.94813949f, -0.40161121f, 0.31977695f, 0.54932386f, -0.67757767f, 1.88086259f, 0.92337233f, -1.64887333f, 0.44333732f, -0.19468001f, 0.12977587f, 0.21171951f, 0.27679422f, 0.49134475f, -1.44429457f, 1.25617445f, 0.39978400f, 0.99869555f, -1.61617446f, 1.61177349f, 0.70243025f, -0.95748568f, -0.61795151f, -0.77302909f, 0.72967088f, 0.81964350f, -0.71813750f, 0.90140164f, -1.45950246f, -0.79972702f, 0.40875742f, 0.00152073f, -1.74491429f, 1.53776145f, 0.75769204f, -0.22075878f, -0.58385569f, 2.18884754f, 0.33597681f, -1.66265559f, 1.03805876f, -1.55245185f, -0.03582226f, -1.94542754f, -0.76081425f, -0.50471377f, 1.35763168f, -0.39631784f, -0.17134467f, -0.82220149f, -0.41021580f, -0.00940776f, -0.80176353f, -0.19816744f, 1.22061026f, -0.14486519f, -0.71727395f, -0.65721530f, 0.47020102f, -0.70403302f, -0.94795334f, 1.79884899f, 0.07779162f, -1.50615680f, 0.04140327f, -0.22001404f, 0.63735324f, 0.79237640f, -2.25412822f, -0.52519119f, -0.87280381f, -0.07100742f, -0.94734806f, -0.12286110f, -0.13623615f, -0.42595413f, 0.17547913f, -0.81707209f, 0.36855817f, -1.68186557f, 0.19312963f, -0.66249490f, -0.98283452f, -0.33314428f, 0.40918943f, 0.88268638f, -0.05390308f, -0.22440539f, -0.15879378f, -0.34859571f, -0.01013108f, -0.30005428f, -1.19408464f, 0.21789688f, -1.07769871f, 0.81475031f, -0.69555300f, 2.35201311f, -0.40362412f, 0.93497628f, 1.13343573f, 0.92343372f, 0.26987928f, 0.46123627f, 0.22577702f, 1.26289701f, -0.45956740f, 0.55994868f, -0.58410591f, 0.13304594f, -0.25806463f, 0.49044946f, -0.82065403f, -3.06672239f, -0.27774641f, 0.68504512f, -0.21386372f, 1.11427057f, -0.73201770f, 0.51655543f, 1.77261138f, 0.72081727f, 0.11116749f, 0.16637769f, -0.74987584f, 0.66579849f, -0.75808716f, 0.20678560f, -0.67698354f, -0.82141948f, 0.61008269f, 0.66520184f, 0.44894725f, 0.73015076f, -1.52517414f, 0.11714164f, 1.90452611f, -1.30355322f, 0.12144456f, 1.18547559f, -0.07349755f, -2.28061509f, 0.83522540f, 0.78438890f, 2.19334102f, 0.90305614f, -0.59345531f, 0.77925014f, 1.32338643f, 0.14068902f, 1.19032264f, 0.20666829f, -0.76595837f, 0.74967057f, 2.86965609f, 0.55690205f, -1.72530472f, -0.83317834f, -0.85842621f, -0.29678273f, 1.80955839f, -0.70496303f, 1.19106734f, -0.92985237f, -1.00617313f, -0.56049556f, -0.29382578f, -2.04022193f, -1.95356870f, -0.42553005f, -0.33369407f, 1.02115977f, -1.45769477f, -0.67720300f, 0.53819913f, 1.57643425f, -0.47015440f, -1.47861958f, -0.00545934f, -0.97836047f, 0.42680529f, 1.56110144f, -1.49487829f, -0.65198445f, 0.22720462f, 1.83036661f, -0.47099793f, -0.09915133f, 0.14923312f, -1.16313052f, 0.67798084f, -1.63665557f, -0.38220280f, 0.01719763f, 0.30041245f, 0.43148938f, -0.44021657f, -1.25734651f, 0.02465564f, -1.00845659f, -0.28574651f, 0.01367745f, 0.77253437f, -0.99399441f, 0.61445391f, 0.18343423f, -0.50997210f, 0.41359940f, 0.77279282f, 0.83511519f, 0.27929801f, 0.70800692f, -0.20278299f, 1.57884383f, 0.22650529f, 0.43347472f, 0.74003208f, -0.71401161f, -0.69829476f, -1.56766701f, -0.99254119f, 1.27301061f, 2.73726511f, 0.66089469f, -1.95778012f, -1.24642098f, -0.63579029f, -1.63168180f, -0.66980726f, 0.81933254f, 0.61866677f, 1.40594471f, 0.05158535f, 0.00196500f, -0.24592508f, -0.50780547f, -0.83905292f, -0.10748957f, 0.04490763f, 0.27769178f, -0.23227681f, 0.82108080f, 0.03562285f, 0.95483875f, -1.49897683f, 0.67809856f, 0.35497451f, -0.44021592f, -1.67361462f, -0.88895375f, 1.44293678f, -0.85046643f, -0.46437624f, -1.87252641f, 0.26775804f, -0.24535774f, 0.73365933f, 0.52253938f, 0.27947086f, -0.58796054f, 0.59045380f, 1.93476331f, -0.46775359f, 0.25238225f, -1.26601815f, -0.13324316f, -0.71454948f, -0.21610366f, -1.49586582f, 1.04903507f, 0.22208478f, 0.25512528f, -0.46157327f, -0.41319233f, -0.63846964f, -0.25100923f, 0.81277549f, -0.26959971f, 0.88737756f, 1.24578953f, -0.91121447f, -1.05756927f, 0.44390878f, 0.16672316f, -1.22941923f, 0.89547867f, -1.50212002f, -1.69620168f, 0.53339505f, -0.23656729f, -1.69879091f, 0.01510374f, 0.08315694f, -0.73196459f, -1.60263407f, -1.07601058f, -0.76389569f, -1.65307498f, -0.61484390f, -0.43546933f, 0.71318507f, -0.16273083f, 0.64122051f, -0.15406294f, 1.17673671f, -0.91240519f, 0.71091145f, 2.40497613f, 1.26343656f, 0.71469337f, 0.20705548f, 0.81776261f, 0.36253929f, -1.92106628f, -0.09300470f, -0.36648872f, 1.27732766f, -0.39180157f, -0.61186749f, -1.03455031f, -0.25079829f, -0.61479062f, -1.07094336f, 0.82218504f, 0.89934880f, 0.41308978f, -0.59968555f, 0.37682834f, -1.77388155f, 0.00294951f, -0.66145372f, -0.50789726f, -0.85123241f, -0.89909405f, -1.89454281f, -0.56692821f, 1.52272677f, -0.11961794f, 0.27843913f, -0.60582250f, 1.01871169f, -0.36098275f, -0.12242325f, -0.67375034f, -0.11204147f, -2.62773919f, -0.95901299f, 0.14040214f, 1.32364666f, -1.35099924f, -0.11077739f, -0.79319423f, 0.75949597f, -0.25485823f, -0.90959758f, -0.42373934f, -1.29850340f, 0.85699379f, -1.11882365f, 0.63470817f, 0.49696380f, -0.07983235f, -0.23903450f, -0.22618714f, -0.12117998f, -0.09442677f, 1.55589819f, -0.11996678f, -1.72700179f, 0.54683149f, -0.40804827f, -0.50099218f, 0.34596699f, -1.81841791f, 0.06385052f, 0.84428120f, 0.69901514f, 1.94559097f, 0.43251973f, 0.16794942f, 1.82829034f, 1.70959795f, 0.36130908f, -0.94608402f, -0.53498030f, 0.47781768f, -0.24203247f, 1.25065851f, 0.51788396f, -2.09381890f, 0.72973937f, 0.03281829f, 0.58632666f, 1.85737121f, -0.49569523f, 0.45921183f, 1.87173629f, 0.22803484f, 1.66433418f, -1.05872321f, -1.13663685f, 0.12397861f, -0.65112090f, 0.98152941f, 0.83739656f, -0.18783289f, 1.84249437f, -0.90706986f, -0.80824369f, -1.23854923f, -0.86488134f, -1.02627063f, 0.10976455f, -0.61403006f, 1.27554715f, 0.14653525f, -0.03953953f, -0.08512071f, -1.30043304f, -0.02566035f, 0.12054887f, 0.00282162f, 0.48921332f, -1.74398839f, 1.44554436f, -1.35854721f, 0.69256759f, 0.34101671f, 2.50045252f, 0.49121150f, -0.27115449f, 0.93974596f, 0.26258010f, 0.27151433f, -0.87214381f, -0.92580765f, -1.03269923f, 0.20615758f, -0.37822601f, 0.58983004f, 0.16426525f, 0.68218285f, 1.98158526f, 0.47492698f, 0.54224718f, 1.28722692f, -1.76915324f, -1.11240053f, 0.77428484f, 0.27184650f, 2.22473478f, -0.05574624f, 0.39976570f, -0.43911108f, 0.52805597f, 0.17340177f, 1.36057591f, -0.35004014f, 1.72787797f, 0.68357420f, 1.25532615f, -0.56752264f, 0.51840127f, -0.21237844f, -0.58821255f, -0.85278064f, 1.90179110f, -0.67447448f, -0.36831430f, -0.22930753f, 0.98231596f, -0.07011599f, -0.08560387f, 0.05998110f, -0.02481356f, -0.57335132f, -0.44288307f, -0.24468307f, 0.53321087f, 1.19609559f, 0.10664973f, 0.24379487f, 0.93687552f, 0.93615580f, 1.74319768f, -0.68310338f, 1.32163060f, 0.61918712f, -0.76501870f, -0.54549301f, 1.74077415f, -0.69977754f, -0.66880983f, -1.15981388f, 0.81571609f, 0.53788543f, 0.47898352f, -0.02484704f, -1.64646924f, -0.69822907f, 0.27020717f, 0.05027051f, 1.75149667f, 0.01548872f, 0.32615909f, 2.55151844f, -1.29172051f, -0.36133784f, 0.98637396f, 0.14009331f, -0.50038946f, -0.92230296f, 0.17307127f, 1.05361068f, -1.46784890f, 2.38960409f, 1.19413340f, -1.33349669f, 1.59141159f, -0.71811068f, 1.22429430f, 1.26947939f, 1.08177102f, -1.18138707f, -0.72775704f, 0.17282635f, -0.40554270f, -0.40341887f, 0.46564049f, -1.02069795f, -0.07653128f, -0.13979210f, -0.31195050f, -1.72042310f, 1.37131393f, 0.63849634f, 0.75561279f, 1.81152904f, 0.26686314f, 1.32796574f, 0.56100166f, 0.70058894f, -0.88962644f, -0.04360984f, -0.88249093f, 0.24311203f, 0.50410056f, -2.22567797f, 0.94520348f, -2.12467694f, 0.47282359f, -0.71379906f, -0.09857135f, 0.62374717f, 1.37182784f, 0.73380554f, 0.59745449f, 2.80427694f, 0.67253572f, 1.65335357f, 1.69891667f, 1.34585941f, -0.79989213f, 1.44980943f, -0.52013642f, -0.46971673f, -1.50070012f, -0.25687039f, -0.56916732f, 0.71065760f, -1.31996286f, 0.96031237f, 0.13929774f, 1.49679291f, -0.05966444f, -0.58674580f, -0.08278833f, -0.93390942f, 0.42415768f, -1.77889526f, 0.75336021f, -0.72699982f, -0.82880586f, 0.63955617f, 0.42771208f, -0.42366457f, -0.91581815f, 0.94750947f, 0.43123913f, -0.99053741f, 0.70470595f, -1.16662264f, 1.14847183f, -0.83885664f, 0.46714026f, -2.27748466f, -1.23656678f, 0.14695056f, -0.33159894f, -0.52553117f, -0.04391259f, -0.29630372f, 0.25949728f, 0.96991086f, -0.37714824f, -0.28251833f, 0.16106486f, 1.38844633f, -0.18713553f, -1.30708838f, 0.48490265f, 0.29553881f, -0.45505449f, 0.83341682f, 0.87346369f, -0.63516861f, 0.66063565f, 0.93892503f, -2.73996735f, -0.81515318f, -0.91458052f, 0.00978268f, 0.43472794f, -0.08090764f, 1.37249672f, 0.76722521f, -1.19154143f, 0.22046764f, 0.34916410f, 0.51383299f, -0.56379753f, -2.49949312f, -0.74207872f, -0.68400806f, -0.09663232f, -0.07199454f, -1.05562651f, -0.75028551f, -0.87253797f, 0.69039482f, 0.45923674f, -1.27515161f, -0.04555376f, -1.41501272f, -0.83773375f, -0.74807298f, 1.36646152f, 0.06317432f, -1.32559633f, 1.89092779f, 1.24883330f, -1.03608561f, 1.08677161f, -0.99629849f, -0.69947034f, -0.85716367f, -0.07947286f, -0.25485426f, -0.19732477f, 1.64581251f, 1.04618108f, 1.87186897f, -0.18198362f, -0.83807969f, 0.70462501f, -3.18930101f, 0.74610996f, -0.60935193f, -0.49383929f, -2.88986492f, 0.51707613f, 1.04620326f, 1.09837818f, -1.19840038f, -0.10391295f, -0.20789115f, -1.51052022f, -0.31087330f, 0.22411564f, -1.30506921f, -1.52000105f, -1.51593041f, 1.04321992f, 0.97611690f, 0.90424490f, 1.83324766f, -0.08682299f, 0.47035542f, 1.70865905f, -0.31108001f, 0.04115159f, -1.36352801f, -0.90797836f, 0.32128647f, 0.66191489f, 0.08681208f, 0.14993365f, 0.47110486f, -0.31522670f, -0.38906571f, -0.08876022f, -0.13106902f, 2.25685239f, -0.62211353f, -1.68553007f, -0.23707703f, 0.69236159f, -0.46686995f, -0.27520603f, 0.26619941f, 1.48525345f, 1.61278927f, 0.49452963f, 1.20846486f, -1.11853909f, -0.30010033f, -0.75471467f, -1.69959772f, -0.52042168f, -0.43881389f, -1.45240712f, 1.02122891f, 1.73639011f, -0.03813924f, -0.22239220f, 0.15797073f, -0.64418089f, -0.60228932f, -0.83248150f, -0.02042520f, 0.38137484f, 0.86056453f, 0.06410559f, -0.62785137f, -0.49916875f, -2.53796315f, -0.79168582f, -0.69197005f, -0.77175534f, -0.28669405f, -0.79764080f, 0.97218460f, -0.10351621f, -0.52759898f, 1.02840185f, 1.16363287f, 0.08351815f, -0.61088538f, 0.59944046f, 1.54409397f, -1.39842033f, 0.27917057f, -0.27146137f, 1.46310735f, 0.03626106f, 0.15038440f, -0.07894899f, -1.42527366f, 1.69641745f, 1.48384345f, -0.43328866f, -0.54252565f, -0.94416499f, 1.54436302f, -0.81367069f, -1.67925239f, -0.17525831f, 0.27891046f, -0.69066733f, 0.89911050f, 0.11606655f, 0.67450327f, 0.41538724f, 0.90886223f, 1.19786549f, 0.85810721f, 1.32862210f, -0.83469814f, -1.09682298f, 0.88092703f, -0.97478902f, -0.11664717f, -0.07929394f, -0.69581884f, -0.16928329f, -0.70731819f, -0.40485084f, -0.28954300f, 0.52882415f, 0.38769314f, -1.38704026f, 1.15099049f, -0.43566978f, 0.34459323f, 0.49520254f, 1.11130333f, 0.28783718f, -0.53783375f, -1.63577271f, 1.02222812f, 0.86302060f, 0.48346213f, 0.46627176f, -1.30133855f, -1.48477137f, 0.31219670f, -1.21498191f, 0.89838904f, 0.87186617f, -0.39968935f, 0.34930915f, -0.32909471f, -1.39364409f, 2.13006306f, 0.33270469f, 0.00215986f, 0.97776711f, 0.24908836f, 1.56164885f, 0.45157790f, -1.55970144f, 0.27677536f, 0.07662498f, -0.08262251f, -0.17658773f, 0.65820259f, 2.01052690f, -1.71946216f, 0.84686053f, -1.23594892f, 1.40792072f, -1.47772563f, -0.36132276f, -0.50405115f, 0.09009213f, 0.81659186f, 1.85574234f, -0.64974433f, 0.63352364f, 1.01766217f, -1.54804432f, -0.42570522f, -0.24763709f, 0.72822112f, -0.93733686f, 0.68087620f, -1.40644944f, 0.48672482f, 0.09725539f, -0.64416331f, -0.95747960f, 0.36771363f, 0.39155054f, -0.71790671f, -2.17222738f, -0.08655047f, -0.97842115f, -0.22991380f, 0.52029115f, -1.42072022f, 0.29576331f, 0.32391560f, -1.00823236f, 1.67909145f, 1.16841447f, -0.32307062f, 0.15756166f, -0.97590631f, -0.39429301f, -0.03583352f, 0.17554663f, 0.57961231f, -0.46873134f, -0.23343173f, -0.85060924f, 1.71745574f, -0.04658702f, 0.63088381f, -0.67581934f, -1.53171062f, -1.58800113f, -1.17987096f, -1.16737640f, -0.87544650f, -1.17138922f, 0.38979119f, -2.39369726f, -1.34747124f, 0.58450359f, 0.87791806f, -0.04459394f, 0.97995293f, -0.10354915f, 0.65324986f, -0.17833626f, -0.85849386f, -0.42063358f, 0.19708554f, 0.10255250f, -0.59539181f, 0.86194044f, 1.68610668f, 0.55275291f, -0.43127069f, -0.04218780f, -0.08466262f, 0.31236625f, -0.92824298f, -0.09879152f, 0.32358822f, 1.04045570f, 0.35617545f, 0.09059231f, 1.19069445f, 1.96978688f, 0.63561743f, 0.15030998f, -0.29879019f, 0.22774190f, -1.01608860f, 1.03605175f, 0.47804731f, -0.30450734f, -0.61382371f, 0.45390254f, -1.93547988f, 2.01267338f, 0.52447683f, 0.18379784f, 1.11913633f, -1.24273467f, 0.15803322f, 1.72184098f, -0.79349059f, 0.10258614f, -1.53445125f, 0.02630571f, 0.81649125f, 0.91089755f, -1.12968338f, 1.04016411f, 0.28999722f, 0.74863863f, -0.61388236f, 0.01665530f, 1.43592548f, 0.68138391f, 0.11963340f, -1.26123953f, 1.36340797f, 0.25696915f, -0.58877039f, 1.42209792f, 0.55563360f, -1.33329606f, 1.84695840f, 0.88433737f, 1.04359078f, 0.18906727f, -0.03448994f, 1.17944050f, 0.86783957f, 0.44934425f, -0.77892244f, -1.76232874f, -1.01689589f, 0.78943914f, 0.92141974f, -1.00187087f, -0.13809921f, -0.90222073f, 1.10094714f, -0.13657950f, -0.44349849f, -1.61441302f, 1.05724919f, 1.50337231f, -0.05785890f, -0.76958144f, -0.51498759f, 0.69227600f, -0.37975949f, 1.31949317f, 0.82049531f, 0.32868597f, -0.31557772f, -0.75534385f, 1.27303052f, 0.43453619f, 0.11296938f, 1.18182182f, 2.23387384f, -0.86412978f, -0.01599468f, -0.70869064f, -0.09221385f, -1.23729551f, 0.79490280f, 0.03522846f, -0.95069039f, -1.73461652f, 0.72329187f, 1.40385795f, -0.11585230f, -0.78033113f, 0.07491048f, -1.12873089f, 0.18476245f, 0.57568848f, -0.28792691f, 1.35411644f, -0.76956165f, 0.29571572f, 1.03178787f, -0.38780826f, 0.31680650f, 0.69368076f, -1.23856580f, -0.49848995f, 0.14766994f, 1.02625990f, 3.03858209f, -0.51030380f, 0.96796870f, 1.35078156f, -1.07729447f, 0.84322494f, 0.54886484f, 1.31453705f, -0.45792100f, 0.31196272f, -0.15701357f, 0.83586836f, -0.74952888f, -1.17432022f, -0.31002575f, -1.02149463f, -0.36117774f, -1.22079086f, 0.03532525f, 0.00555908f, -0.45891216f, 0.29636297f, -0.68272704f, 0.41257843f, 0.37988129f, 0.01747893f, 0.82739186f, 1.52292180f, -0.79456621f, 2.20275712f, 2.13212132f, -0.81393015f, -1.15712392f, 0.22488308f, 0.62776327f, -0.85444915f, 0.44017896f, 0.05863331f, -0.83198178f, 0.93063420f, -0.16121253f, 0.12382501f, -0.37826315f, 0.93118382f, 0.19507533f, -0.58595538f, 1.46994352f, 0.13170272f, -0.70031989f, -0.12820166f, 0.30487457f, 0.84148771f, -0.68807501f, 0.21187615f, -0.67030680f, -1.79136002f, 0.70810199f, -1.20959783f, -0.08468831f, -0.06317700f, 1.35527098f, -0.47018668f, -0.91693246f, 0.14818805f, -0.05405350f, 1.16875637f, -0.17363262f, -1.61833882f, -0.32934523f, -0.38346377f, -0.62702698f, 0.34135151f, 0.48015586f, -0.65263331f, -0.04689486f, 0.01156854f, 0.37580970f, -0.16174591f, 0.59627324f, 0.24351901f, -0.87983090f, 1.57049024f, 1.25836349f, -0.41464049f, -0.62279183f, 0.09693756f, -0.23850618f, -0.49007827f, 0.22298151f, 0.10914832f, -0.35192192f, -1.27221346f, 1.10203624f, -0.86399704f, -0.47319838f, -0.77105570f, -1.68624854f, 0.81198281f, 0.82534081f, 0.75654501f, 1.47631240f, -0.61000234f, -0.58933264f, 0.54822850f, -1.22829592f, 0.11107657f, 0.56449169f, 1.50693524f, -0.59280968f, -0.64286685f, -0.20120731f, 0.27184448f, 1.55500400f, -0.48919386f, 1.04044867f, -0.87048137f, -0.40569979f, 0.21908638f, -0.51829034f, -1.48748124f, 0.02990401f, 1.83462536f, 0.29885170f, 1.32370698f, -1.30129600f, 2.43271399f, 0.22967771f, -1.13014007f, 0.95529765f, -0.83325785f, 0.43633386f, 0.85774118f, 0.78160155f, 0.58583075f, 1.18906367f, -1.54354560f, -0.68320692f, 0.01900371f, -0.79777133f, 0.12851712f, 1.10176420f, 0.79418170f, -1.41154039f, 0.36929929f, 1.12176800f, 1.23849642f, -0.89377707f, 1.01390159f, -0.50889206f, -1.12554002f, 0.17932732f, 0.48949540f, -0.54235244f, -0.28146735f, -1.39125514f, 0.13309635f, -1.12864995f, -1.29901242f, -0.04266220f, -1.98028529f, -1.34869373f, 0.00038156f, -0.92473024f, 1.48010647f, -0.02754467f, -0.26030368f, 0.93083733f, 0.27946711f, 0.64052200f, -0.04220961f, 1.25002527f, -1.07923257f, 0.19048618f, 0.08900311f, -0.40813437f, -0.73068553f, 0.52122378f, 0.68990833f, -0.38749605f, -1.09269309f, -1.63480806f, 1.01789618f, -0.61596102f, 0.81049860f, 1.30838764f, -1.49213874f, -0.77916288f, -0.72660202f, -0.92013240f, -1.61726642f, -0.11527207f, 0.35143322f, -1.11646879f, -1.45525432f, -0.82892823f, 0.15512508f, 1.01891017f, 1.40162635f, 1.02494884f, 0.33882582f, -0.78747398f, -0.26009330f, -0.38519114f, 0.79247451f, 0.02065756f, -0.48030257f, 1.01167107f, -1.74057114f, -0.84549171f, -0.15337363f, -1.92544484f, 1.01270044f, 0.00762185f, -0.16405612f, 1.61778915f, 0.93316060f, -0.68960994f, -1.13214970f, -0.94695878f, -0.28418848f, 0.17102109f, -0.08787476f, -1.83799696f, -0.13761258f, -0.18652774f, 1.46456254f, 0.34169790f, -0.40697145f, 1.49663997f, -0.99555492f, -0.67775637f, -0.51951116f, 1.35157657f, -0.27099034f, -0.46987835f, 2.28101230f, 0.59104478f, 0.75010139f, 1.01472175f, 0.25741309f, -0.56074983f, 1.12267506f, 0.35336846f, 0.61733276f, -1.63976014f, -0.17700450f, -0.25093642f, -0.75599891f, 2.10956192f, 0.95155340f, 0.72049862f, 0.50492924f, 0.62067389f, 2.08688402f, -0.73604703f, 0.63383341f, -0.53528428f, -2.11538506f, -0.98173052f, 0.59560484f, -0.26205051f, -0.91948050f, 0.00593397f, -0.11734286f, -1.41261208f, -0.83611172f, -0.27682739f, -0.20619918f, -0.36557615f, 0.77194935f, 1.67695415f, -1.39265156f, 0.04892010f, -0.37773246f, 0.16124558f, -0.18348448f, -1.38248885f, 0.58459854f, 0.65064198f, 1.11349559f, 0.36708066f, -0.15471332f, 0.14208725f, -2.06860566f, 0.29629150f, 0.93084633f, -0.47215626f, 0.60208917f, 0.95415461f, 1.03390312f, -0.03639749f, -0.23988228f, 1.27037442f, 0.95133096f, 0.33187470f, -0.34527761f, 0.22134073f, 1.01799667f, -0.81475645f, -1.18869019f, 0.23314142f, 0.25180560f, -1.23762786f, 1.25283313f, 0.16980635f, 0.40740708f, 0.59256923f, 0.16274920f, -0.69713289f, -0.16444311f, -2.41602516f, 0.37952334f, -0.05604568f, -0.23772651f, 0.20581599f, -0.54303211f, 1.71877348f, 0.83602583f, -0.32586128f, 0.73609394f, -1.73640239f, 0.07249248f, 0.31248692f, 1.77627432f, 0.97660398f, -0.42095289f, -0.18750280f, -0.84246057f, 0.29762223f, 1.87054563f, -1.46980762f, -0.45306337f, 1.52366042f, 1.39061129f, -0.04980387f, -0.55382830f, -0.96987218f, -0.06910808f, -0.41276473f, -0.83891344f, -0.92597574f, 0.60252470f, 0.21938549f, -0.04451685f, -1.00330937f, -0.36955237f, -1.52876902f, 0.27296364f, -1.96721256f, 0.05291027f, -0.91540521f, 0.48990685f, -1.99560380f, -0.68551093f, -0.14532298f, -1.56881595f, -0.08319287f, 0.31003201f, -1.42829597f, -0.61810297f, -0.03581250f, 0.77747720f, 1.25297558f, -1.36239243f, -1.13274276f, -0.35045877f, -2.34157228f, 0.04515179f, -0.83044821f, 1.81353962f, -1.36855912f, 0.39704823f, 0.16665934f, -0.16654585f, 1.17806077f, 1.00086153f, -1.25474250f, -1.46876431f, 1.18021631f, -0.32257929f, 2.12062597f, 0.86819613f, -1.18048275f, -1.69747460f, -0.74092305f, 0.05086798f, 1.15339577f, 1.32972670f, 0.27247882f, 0.98499072f, 2.35597157f, 0.30179837f, -0.66633248f, 0.13794266f, -0.22753908f, -0.22868259f, -1.81792033f, 0.50151759f, -0.79408127f, -1.05343878f, 0.45727381f, 0.84800923f, -1.73605800f, -0.02032863f, 1.82778001f, 1.41025102f, -0.81715560f, 0.25888795f, -0.25075480f, 0.66256499f, 0.11993053f, 1.81336939f, -0.06345166f, -1.49658346f, 0.07531686f, 0.96972889f, 0.87405980f, 0.75830793f, -0.13497087f, -2.45855975f, -0.65984958f, 0.93919373f, -0.97305542f, 0.73477978f, 1.04337513f, -1.22712576f, -0.46385625f, -1.20876372f, -0.82760453f, 0.01455977f, -1.05089867f, -0.02801843f, 0.60899758f, -0.82052249f, -1.48932517f, -0.98073828f, -0.19311285f, -0.25602359f, 0.50351876f, -1.24557400f, -0.82138073f, -1.45966852f, 0.44991320f, -0.75550151f, -0.98550314f, -1.21418869f, -1.15771639f, -1.72192061f, -0.39616469f, -0.55566746f, -1.31880891f, -0.08843257f, 1.00422776f, 0.35846478f, 0.46060917f, 0.77326930f, 1.60129988f, -1.85124147f, -0.30582917f, 1.30227256f, 1.81890345f, -0.44084981f, 0.25315762f, 0.70259613f, -0.94882858f, 1.97040296f, 0.71473581f, -0.68193883f, -0.36290962f, 1.16348684f, 0.15418798f, 1.07806778f, 0.40554729f, 0.10280909f, -1.06474805f, 0.64398485f, -0.63568884f, -0.06108581f, -1.03290677f, 1.02834034f, 1.15284693f, 0.14046004f, 1.86630619f, 0.46804786f, -0.68397558f, 1.60733378f, -1.64890087f, -1.03819239f, -1.19212389f, -0.78382361f, 0.03925850f, 1.52259934f, 0.09540676f, -0.21220762f, 0.55955195f, -0.39845437f, -2.14541650f, 0.49337825f, -0.68574250f, 0.74040270f, 0.50783634f, -1.60461199f, -1.26806450f, -0.12652303f, -0.83992827f, -0.15524681f, 0.40098447f, 0.23392735f, -0.23262636f, 0.06525709f, -0.35994548f, -1.08432877f, -0.21395946f, -0.78357452f, -0.57157278f, 0.71407390f, 0.86596155f, -1.13723528f, 0.13460183f, -1.20881450f, 0.71018457f, 0.68943661f, -0.70428050f, 0.64600736f, 0.01990297f, -0.10575775f, -0.80263519f, 0.10618331f, 0.08865548f, 1.51651669f, 0.60851854f, 1.15161908f, 1.04919207f, 1.18359745f, -0.04352076f, -0.83643389f, -0.07922365f, 0.10597949f, -1.34984851f, -1.91319740f, 0.71585363f, -2.10845160f, 0.64385056f, -0.54551518f, -1.02039802f, -1.62510490f, 1.65401149f, -0.42711899f, 0.07970079f, -0.21404363f, 0.30498922f, 1.07942021f, 0.63995659f, -1.82114816f, 0.56396323f, 1.07084870f, -2.00350380f, 0.53339815f, 0.18500003f, 1.15034151f, -0.21436051f, -0.99986565f, -0.58812016f, -0.07247020f, 0.78910017f, 0.48839527f, 0.98795873f, 0.10357288f, -0.05604928f, 0.38977858f, 0.73745090f, 1.40838420f, 0.25967824f, 0.23588051f, -0.03451392f, 1.04897523f, -1.77121758f, 2.35625434f, -0.67086869f, -0.84005541f, -0.85940343f, -1.04449213f, -0.65917015f, -0.78713167f, -0.95910054f, 0.38597879f, -0.31879017f, -0.86260867f, -1.08593106f, 0.02802678f, 0.99484950f, -0.55113328f, 2.60936737f, -0.03388772f, -0.47583574f, -0.14021793f, 0.99019170f, -1.22431207f, 0.78734446f, -1.77037835f, 0.15018673f, 0.36423206f, 1.36447549f, -1.61007094f, 0.51875496f, -1.60788095f, -1.73557448f, -0.41414359f, -0.93710536f, 0.38715765f, 0.04243837f, -1.59682858f, -1.10728157f, 1.88292623f, -1.01428258f, 0.01074958f, -1.88169158f, -0.31616244f, 0.45334938f, 1.12449574f, -1.16699445f, -1.59505820f, 0.04126552f, -0.89016622f, 0.45838884f, 0.71463561f, 0.14563711f, 0.30694655f, 0.67193079f, 0.61429602f, 1.00201404f, -0.49295208f, 0.05997690f, 0.99491668f, -0.73801446f, -1.17185295f, 0.94778723f, 0.36106884f, -0.43561545f, 0.04102699f, 0.52626407f, 0.08442099f, -1.57626402f, 1.56855237f, -1.65396678f, 1.74014664f, -0.38219589f, 0.39305371f, -0.31705827f, -1.15742850f, 0.11669596f, 0.54043210f, -0.52270615f, -0.13375773f, 0.68094701f, -1.84134769f, -1.49383473f, 0.14632171f, -0.54607725f, -1.20867658f, -1.28439069f, -1.81734920f, 1.54257309f, 0.78347659f, -0.24049839f, 1.69973648f, 0.99825776f, 0.99971974f, -0.26055810f, 0.34143049f, -0.44862366f, 0.11253342f, -0.60932243f, 0.70383030f, -1.87318194f, 0.21953633f, 0.82791799f, 1.64545465f, -0.42693698f, -0.64897031f, -0.97996652f, -1.06616282f, 0.52939081f, -0.12541170f, -0.57480675f, 0.73600835f, 0.35711968f, -0.03528263f, 0.79997194f, 0.55742902f, -0.28909785f, 0.64331138f, -1.79893720f, 1.01572442f, 0.27111965f, -0.51778597f, 0.12906317f, 0.76148927f, 1.51315522f, 0.41101140f, 0.38008851f, 0.66759896f, -0.13804778f, 0.64854795f, 1.73474562f, 0.75999504f, -0.73411214f, -0.05406699f, 1.35664344f, -0.25298578f, -0.12696666f, -0.42628938f, 0.61129904f, 1.55259824f, -0.05820796f, -0.38598019f, -0.87325627f, -0.55066222f, -1.24557889f, -0.26509118f, -0.32103062f, 1.14031804f, -0.75985742f, 0.70659167f, -1.15016067f, 1.24906838f, 0.90396994f, -0.16241251f, 0.43682271f, -1.42695689f, 0.47134697f, -1.66143429f, 0.08698819f, -1.00775325f, -2.24129725f, -1.04226267f, -0.98537570f, -0.89938259f, -1.80710697f, -1.22866321f, 0.78125423f, 1.55150509f, 0.46235040f, 0.18444096f, 0.19313288f, -2.20686269f, -0.40341458f, 0.50321484f, 0.47339424f, -0.81383848f, -0.21972439f, 0.66612029f, 0.60239881f, 1.20443010f, 0.70015103f, 0.30632916f, 0.01489905f, 0.68129027f, -0.89645082f, -2.68969011f, -0.96684915f, 1.66421318f, 0.74333072f, -0.78321886f, 1.60063362f, -1.27524030f, -1.95856726f, 0.47504124f, 0.15398432f, -0.20796098f, -0.13449343f, 0.93458968f, 1.60390890f, 0.21798505f, -0.27035928f, -1.23248971f, -1.25361061f, 1.34666133f, 1.07233441f, 0.88799530f, -1.23687923f, -0.40781614f, -0.11916534f, -0.88050151f, -0.66422415f, -2.61471510f, 0.78276747f, 2.42323995f, -1.70715427f, 0.71550035f, -0.60298312f, 0.70491880f, 0.46175584f, 0.80827898f, -0.45108104f, -0.98219043f, -1.72823501f, 1.73190725f, 0.53906441f, -1.50445580f, -0.59250867f, -0.07239901f, 0.44743437f, -0.13740127f, 1.69935930f, -1.00480616f, -0.58191377f, 0.39853972f, -0.60960841f, -0.45473522f, -0.76396072f, -0.31872150f, 1.74509728f, -0.59950751f, 0.89810580f, -0.81400329f, 1.14280319f, 1.11165059f, -1.31295311f, -1.60784578f, -0.87506992f, -1.13461006f, -2.09486437f, -0.16449419f, -0.37728927f, 0.47595578f, -0.55342919f, -0.17574213f, 2.21499181f, 1.14331865f, -0.14938518f, 0.18935619f, -0.33802557f, 0.52538890f, 0.82673949f, 1.16562462f, 1.24713838f, 0.98890215f, -0.64991701f, 1.49886703f, 1.97769642f, 0.08059916f, -1.60925281f, -1.23822486f, -1.40829837f, 0.51331180f, -0.29928651f, -1.04348791f, -0.39911583f, 0.69380492f, 1.54516888f, 1.22791195f, 2.25008130f, 1.33348894f, -0.21775827f, -0.71937007f, 0.54982573f, 1.70691478f, 0.32459491f, -0.57187974f, -0.21614684f, 1.08274269f, 0.41384646f, 0.24497485f, -1.43703413f, 0.89616930f, 0.82032162f, -0.24598582f, 0.84271127f, -0.81894702f, -0.01828136f, 1.70397091f, 0.39505738f, -0.51221430f, -0.87979966f, 0.10795479f, 0.45194778f, -0.76008922f, 1.23394477f, -0.56798172f, 1.06459570f, -0.44333413f, -2.40399075f, -0.37267187f, 1.42946172f, 0.95734519f, 1.86127949f, -0.15217264f, 1.68742633f, 1.97638428f, -0.44211119f, -0.98393327f, -0.54173928f, -1.72017395f, 0.74697793f, -1.77827263f, -1.92299354f, -0.17189410f, -0.48633271f, -2.21230388f, -0.45906609f, -0.53493047f, 0.37253976f, -0.56951141f, 0.07728028f, 0.03530006f, -1.18123293f, 1.94158125f, -1.55930352f, 0.69334733f, -1.95163214f, -0.95800400f, -0.01804711f, -0.56747472f, -0.99099451f, -1.52853060f, -0.98279524f, -1.67307866f, 0.96121490f, 0.35654056f, 1.74034202f, -1.44633865f, -0.27781928f, 1.79457986f, -0.41029963f, -0.76871634f, 0.36555341f, -0.77664107f, 0.19535238f, -0.76185411f, -0.19828433f, -0.88820636f, 0.63885397f, 0.11346363f, -2.50265074f, 0.16319332f, -1.01288569f, 1.86605489f, 0.89761645f, 1.11795115f, -0.00714116f, -0.89034635f, -0.76447034f, -0.18822117f, -0.48340848f, -0.99788517f, 1.02172959f, -0.39395007f, 0.72566581f, -0.81438208f, -0.71715081f, 0.96243578f, -1.36424279f, -1.13870537f, 1.17602491f, 0.16320205f, 0.71959788f, 1.66669416f, 0.55690295f, -0.28912008f, -1.19219172f, 0.23308393f, -0.37963116f, 0.45347008f, -0.42606446f, 1.30938649f, 1.25128853f, 0.57649273f, 0.34440875f, -0.23893952f, -1.06604803f, 0.31336102f, 0.75727910f, 0.46772480f, -0.37650385f, -0.06036821f, 1.03686309f, 0.46158856f, -1.81028461f, 1.43393028f, 0.85494965f, -2.34685564f, -0.17571987f, -0.45592231f, -1.31190526f, 1.73194158f, -0.11856517f, 0.07041293f, 0.25689471f, -0.56000596f, 2.06649089f, 0.38954756f, 1.36627376f, 0.13905638f, 0.77370811f, 0.43944249f, -0.08798827f, 0.07245751f, -1.30234015f, 0.29710820f, 0.74389762f, 0.11971968f, -0.07381748f, 1.32652700f, 1.34079397f}); - - auto input2 = NDArrayFactory::create('c', {3, 4, 4, 5}, {0.98114507f, 0.96400015f, 0.58669623f, 0.60073098f, 0.75425418f, 0.44258752f, 0.76373084f, 0.96593234f, 0.34067846f, 0.57962620f, 0.77517051f, 0.97472977f, 0.79237527f, 0.68690428f, 0.21719366f, 0.79959206f, 0.84814187f, 0.22496814f, 0.08646965f, 0.31110474f, 0.79813162f, 0.19661444f, 0.57760099f, 0.72138960f, 0.15244268f, 0.87687051f, 0.11130344f, 0.01087698f, 0.34817841f, 0.54992017f, 0.23443850f, 0.31725614f, 0.59755220f, 0.20364695f, 0.00531392f, 0.23403114f, 0.07442912f, 0.83707647f, 0.89291743f, 0.09044587f, 0.69041462f, 0.29904183f, 0.61904680f, 0.85306847f, 0.34467042f, 0.95839152f, 0.54517124f, 0.29640937f, 0.94855959f, 0.95970016f, 0.94045145f, 0.95510301f, 0.34666505f, 0.34717010f, 0.69245678f, 0.71669175f, 0.59043738f, 0.64924132f, 0.06033522f, 0.60185199f, 0.04690073f, 0.59241154f, 0.40229547f, 0.23002481f, 0.45161195f, 0.73743778f, 0.93209113f, 0.37294358f, 0.50177744f, 0.15072501f, 0.26146917f, 0.05252146f, 0.04758931f, 0.76448288f, 0.85149045f, 0.08840467f, 0.07692576f, 0.33180160f, 0.27241259f, 0.74834620f, 0.56453640f, 0.23057286f, 0.68429752f, 0.11961551f, 0.39045977f, 0.44356094f, 0.77018807f, 0.07984410f, 0.47926806f, 0.26165759f, 0.18606064f, 0.89972877f, 0.17962874f, 0.47273120f, 0.64641705f, 0.61890443f, 0.58730015f, 0.25937832f, 0.35231561f, 0.10243882f, 0.17459193f, 0.95906995f, 0.09227025f, 0.30003223f, 0.41601210f, 0.38269713f, 0.84799751f, 0.59295173f, 0.76277990f, 0.68910424f, 0.37672606f, 0.40675461f, 0.94346058f, 0.91438505f, 0.84728183f, 0.64367667f, 0.74899979f, 0.60570691f, 0.16417363f, 0.68852426f, 0.85486889f, 0.22585792f, 0.86953176f, 0.07465519f, 0.93096301f, 0.38008822f, 0.38752587f, 0.44004038f, 0.13170612f, 0.94541045f, 0.89349973f, 0.69245307f, 0.94978877f, 0.98776658f, 0.79445884f, 0.30607409f, 0.58264961f, 0.37980538f, 0.41810784f, 0.48903038f, 0.51615888f, 0.57682794f, 0.82481897f, 0.78341080f, 0.48446465f, 0.17447931f, 0.71125424f, 0.30263851f, 0.70675352f, 0.03215584f, 0.92381065f, 0.22343694f, 0.08851149f, 0.91402490f, 0.70074717f, 0.30912192f, 0.37723206f, 0.97579397f, 0.23554587f, 0.95939133f, 0.41565709f, 0.01741416f, 0.58362787f, 0.22106662f, 0.89065537f, 0.31900249f, 0.41280911f, 0.67947610f, 0.04545590f, 0.15352812f, 0.85412524f, 0.84933222f, 0.80000225f, 0.93147073f, 0.70094105f, 0.69269875f, 0.95282194f, 0.65913582f, 0.79186874f, 0.59855248f, 0.39707430f, 0.95126239f, 0.15618217f, 0.33446689f, 0.98123758f, 0.84770758f, 0.98081012f, 0.54427413f, 0.18728519f, 0.89792955f, 0.53360126f, 0.72812986f, 0.13307744f, 0.51217443f, 0.66708084f, 0.29416915f, 0.31298995f, 0.39155037f, 0.29288291f, 0.87063305f, 0.61759154f, 0.73723332f, 0.37167635f, 0.82122716f, 0.22937430f, 0.76570536f, 0.47911792f, 0.02826214f, 0.94277323f, 0.59945469f, 0.19042060f, 0.68173155f, 0.82771295f, 0.95649538f, 0.40833101f, 0.90838542f, 0.55245881f, 0.49011012f, 0.36773444f, 0.34513527f, 0.42050683f, 0.16113964f, 0.30969388f, 0.27174174f, 0.12117655f, 0.35270175f, 0.81967867f, 0.63723136f, 0.84309389f, 0.71822576f, 0.84883484f, 0.32306117f, 0.08176457f, 0.56175486f, 0.34892198f, 0.09306929f, 0.85437582f, 0.13925577f, 0.48629188f, 0.29923539f}); - auto exp = NDArrayFactory::create('c', {3, 8, 8, 16}, {5.98743296f, -2.83037376f, -0.87943113f, 1.41339970f, 1.32433391f, -1.20299149f, -0.02893090f, 2.05326009f, 1.19417048f, 5.58212376f, 3.28139353f, 1.19237995f, -1.09431255f, -2.55264497f, 3.11014652f, 6.81296825f, -2.09029293f, -4.32068443f, -0.52808392f, -1.97968531f, -0.18673831f, 0.84605980f, 4.55825520f, 2.71503139f, 0.15210046f, 0.85310984f, -3.82062817f, 2.76470995f, 3.69004202f, -1.45017099f, -2.59361267f, -1.35094655f, 7.24145126f, -5.25432396f, 0.19920218f, -4.30596399f, 1.35318923f, -3.88142037f, 3.67493343f, 2.25931478f, 2.87630725f, 1.66349852f, 6.21347952f, 0.94105923f, -1.61742055f, -2.35699606f, 0.12850338f, 1.79141688f, -2.09535933f, -6.35418081f, -0.06303531f, -4.38615131f, 0.48237842f, 0.26528549f, 3.38231516f, 3.76315165f, -0.40254810f, -0.23716694f, -6.13381910f, -0.41950428f, -0.89680839f, -1.46491277f, -1.98541689f, -0.99357355f, 5.58237648f, -2.38937521f, -0.00872564f, -2.37138414f, 4.91117287f, -4.51916361f, 0.97943687f, 2.91052818f, -2.50362611f, 1.70252812f, 5.04137802f, 3.57108784f, -1.87532270f, -3.66677809f, -2.38861251f, 5.55765152f, -7.27571774f, -1.68887305f, -0.72266489f, -4.42809057f, -0.92118186f, 1.02381468f, 4.44284725f, 5.17150497f, -0.42438728f, 2.02693963f, -1.36484981f, -1.47912180f, 0.26649538f, -0.02091765f, -2.86906910f, -3.03046989f, 1.35122132f, -3.21707630f, 2.21112418f, 0.24121630f, 3.96940088f, -7.66105747f, 2.76352382f, -0.99061489f, -2.16720009f, -1.63170409f, 1.12701774f, -1.02415371f, -0.90435314f, -1.51372027f, -0.76884907f, 0.39066136f, -0.89562428f, -2.03204703f, 1.28074932f, -2.14551091f, -2.36843777f, 0.46580017f, 0.75451565f, -0.00336730f, -1.06597757f, 3.27195978f, -0.41307712f, -0.10376054f, -1.34102952f, -2.22901654f, 2.31929803f, 1.40851438f, -2.23774385f, 0.20417206f, -1.12153268f, -0.13188094f, -3.96649432f, 2.10269976f, 0.49845099f, 6.18937683f, -0.51783508f, -0.48048639f, -1.92970264f, 3.16670656f, 1.13355756f, -0.07890664f, 1.31536257f, -0.43924797f, -0.04562932f, -0.87974954f, 0.75411212f, -2.39745235f, -3.97132111f, 0.37202546f, -2.40399146f, -1.50796390f, -3.08302689f, 0.23075986f, -0.94316757f, 1.34948587f, 0.58591264f, 2.18529797f, 7.97652435f, 2.32798409f, -4.09404373f, 0.89634895f, 0.77697754f, -0.65091681f, -7.05506849f, 5.86194515f, 2.51394033f, 4.69959354f, 0.20835471f, 3.18049693f, -1.29682434f, 3.70832396f, -0.48123091f, -1.67904007f, -1.35418940f, 1.58435583f, -1.13851106f, -1.19225955f, 0.59713769f, -5.80462933f, -7.45143986f, -1.08658695f, 1.03244078f, -1.75307107f, -7.07100582f, 3.85825157f, 1.62127817f, 2.32572675f, 0.56171900f, -0.80591971f, 3.98835945f, 0.15742642f, -2.97832179f, 0.13821673f, -0.72556758f, -0.84936106f, -7.28444147f, 3.94134307f, 0.80779338f, 7.47784615f, 8.23335075f, 4.80595016f, -4.89574575f, 4.03362942f, -6.67522192f, -4.55204487f, 2.12511182f, -2.70781207f, -1.57226098f, -3.08408356f, -0.30812448f, -5.32870674f, -5.13238287f, 0.49605465f, -0.55042171f, 0.46324944f, -3.83545256f, -0.12562510f, -0.20978995f, -0.13068712f, -1.92144060f, -1.68787408f, 5.45581436f, -0.79583496f, -2.38866687f, -3.90546346f, -0.47028148f, -0.14319679f, -3.37016582f, 2.00905991f, -1.21345615f, 1.81376505f, 7.73004007f, 0.74310112f, -4.64536428f, 3.78111577f, -9.05182457f, -0.10674095f, 1.53476238f, 0.63345337f, -0.40907967f, -1.44729769f, -1.87145400f, -2.46623540f, 1.07472968f, 0.77390999f, -3.93438888f, 4.49174690f, -0.96686655f, 1.92278123f, 0.30049133f, -0.02388665f, -1.99777114f, -3.23885751f, 5.87784004f, 2.13776040f, 3.56758308f, -3.37774134f, -3.67526293f, 1.63700044f, -1.69959962f, -0.99112594f, 6.03103638f, 1.67399430f, -1.28699589f, 7.16759014f, 12.63490295f, 3.62937450f, -4.75982571f, 2.17861104f, -2.03065681f, 4.30207729f, -0.46797156f, -2.96022511f, -6.02702332f, 3.09229851f, -1.39771092f, -0.03471333f, 3.22175527f, 5.63565636f, 1.78195477f, -0.63545251f, -3.99497652f, 1.46043062f, 4.60050488f, -2.96651959f, -2.03159475f, -1.52386189f, -0.15129802f, -3.90390921f, -0.63852370f, 0.79210538f, 2.35288715f, -5.55609035f, 5.36427498f, -0.60248077f, -0.26181316f, 5.04884720f, 8.53192806f, 5.05080223f, -6.56371737f, 1.52260923f, -7.13623667f, 6.49414349f, 2.33445597f, -4.11490965f, -6.44347477f, -0.47079402f, -0.63467920f, 2.60399365f, 1.05958164f, 3.66901422f, -1.05657935f, 1.88611507f, -6.37475634f, 2.01480770f, 3.36020517f, -5.11001921f, -0.46132171f, 2.16525555f, 4.21938848f, -2.08346295f, 2.86168146f, 1.26987600f, 6.76066971f, -7.84916353f, 4.11700916f, 0.47985530f, -4.60113716f, 7.42062473f, 6.37472820f, 4.37820530f, -7.12197018f, 0.01357239f, -7.90392113f, 8.32131577f, -0.87593079f, -0.16994858f, -5.86345863f, -0.20697471f, -1.37845206f, 1.63819647f, 1.59720242f, -0.74357712f, -1.88725603f, -1.98357940f, -8.57950306f, -4.10104513f, 3.57231879f, -2.89855957f, -0.11263305f, 2.78033924f, 1.53078973f, -2.93089223f, 0.73189604f, 3.20563078f, 3.92601013f, -5.21916151f, 0.89163935f, -0.42978728f, -6.70888853f, 4.56477976f, 1.20105875f, 3.83393812f, -6.27205181f, 4.05993128f, -7.35513067f, 1.60660768f, -1.21052051f, 1.58191252f, -1.37899971f, -1.20117283f, 2.93301678f, 1.06302834f, 1.38993621f, -1.66884089f, -3.34452581f, 1.04498529f, -4.10412455f, -4.03310585f, 1.61513603f, -1.09388447f, 2.11451387f, -0.94192362f, -0.23287666f, 5.88265705f, -0.83010495f, -2.15317154f, -0.60276151f, -1.49265075f, 3.93397975f, 5.45194483f, 1.45161700f, -2.57401872f, -5.59288931f, 4.29170895f, 1.87151814f, 0.08362055f, -0.28767288f, 1.17675185f, 0.85266006f, 1.30549634f, -5.60830832f, 0.19398519f, -0.83982587f, 1.75940764f, -5.46077394f, 1.64495635f, 0.17102760f, -0.54459631f, -2.21975255f, -0.37443402f, -2.08474159f, 1.85959935f, 11.19680309f, -0.18611598f, -2.59765387f, 3.06330776f, -1.52183700f, -4.88415241f, -0.75097847f, 2.58201051f, 7.40885210f, 3.58994508f, 1.62457407f, 3.12514591f, -4.36833286f, 1.39830995f, 3.61003447f, -0.63837433f, -3.62661815f, 3.78898096f, 2.92802262f, 5.87374496f, -4.38554621f, -2.53411579f, -2.87311554f, -1.31391978f, -4.26736879f, 3.45099425f, 1.58769250f, 1.73341393f, -1.08842182f, 2.27120280f, -1.78938174f, -2.29940319f, 7.07046986f, 0.51426595f, -6.22928905f, 5.28968811f, 2.31827855f, -4.20915890f, -1.27249205f, 5.92120600f, 3.19458675f, 7.09252501f, 3.96577907f, 6.41484213f, -4.66009521f, 10.00181389f, 0.51108456f, -4.62243366f, -5.18351841f, 2.12961674f, 5.10694027f, 7.29412317f, 0.15912467f, -3.38902974f, -4.01918602f, -2.17383957f, 0.13118666f, 0.27872476f, -0.92317247f, 3.51440644f, 1.84171486f, 1.03378081f, 1.30569839f, -2.09583759f, 9.03952980f, -0.55187917f, -2.04549074f, 1.08294606f, -2.65263700f, -2.93977118f, 1.88909876f, 0.96043622f, 1.76579499f, 3.14314699f, 5.86394691f, 7.36944389f, -7.04524136f, 6.68673229f, -5.52591467f, -2.19745898f, -4.32036924f, 0.52971321f, 2.26268244f, 6.91575766f, -0.94590527f, -3.98923349f, -0.12266219f, 0.24294075f, -1.07783222f, 1.87989080f, -3.57109427f, 1.61553633f, 0.42486978f, 0.75852054f, -6.19481468f, -3.80570698f, 2.39946675f, -1.93851781f, -5.42234039f, -6.34092760f, -2.52374983f, -1.85044456f, 3.92693520f, 0.40042299f, 4.69742584f, 5.40483189f, -1.02398944f, 8.89605045f, 0.64680403f, 0.89943957f, 0.76993859f, -1.88244629f, 1.90714884f, 3.10836840f, -0.17064989f, 0.84892416f, -6.94988108f, 1.92141032f, -1.36458397f, 6.39284658f, 0.45201308f, 2.58823442f, 6.33375788f, -4.76916075f, -8.45738983f, -0.48962492f, 2.40652561f, 4.56602001f, -3.34420681f, 1.86862195f, -7.01420689f, -6.94657421f, -2.47419310f, -4.61693668f, -0.18822384f, -0.36949772f, 2.01374269f, 4.11018658f, -5.11564064f, 8.04294395f, 2.88567662f, -2.87645102f, -1.23238611f, -5.91409397f, -0.62205851f, 1.38689423f, -0.01120412f, 5.25955677f, -1.98474956f, -3.72012186f, 3.00445986f, 4.99141550f, 2.97457719f, 2.70827627f, 6.04544449f, -0.20756161f, -10.87035751f, 0.80454814f, 0.33568168f, -2.48132324f, -2.84452009f, 2.63126230f, -3.99351716f, -7.39294338f, 3.62798953f, -8.65815926f, 2.65992808f, -6.98126554f, 3.09881067f, 0.67735767f, -1.15946686f, 5.63180256f, -0.17694545f, -8.59651184f, 3.75297594f, -2.35913754f, -0.20330384f, 5.49958467f, 1.00861740f, 1.42849684f, 0.00062013f, -0.11073381f, 2.15207863f, 4.07368469f, 1.14344299f, -1.27953362f, 6.64699316f, -0.73672432f, -8.55606937f, -0.19439441f, -4.14319754f, -4.69964647f, -5.86446047f, 2.87106085f, -3.42714882f, -5.00668287f, 6.22464132f, -7.72335291f, 4.05667686f, -5.72637177f, 6.35073948f, -1.29593158f, 0.00813985f, 3.63368607f, -1.05764008f, -7.88486052f, 3.73919106f, 1.41835213f, -1.04935634f, 0.65119827f, 0.03547254f, 1.88996327f, 1.58701086f, -0.56215239f, -0.80187100f, 4.55604362f, -0.67249978f, 1.41084409f, 7.86281586f, -2.38301182f, -8.50535774f, -3.82098866f, -2.40856767f, -5.33439016f, -3.34747362f, 2.69389009f, -1.64118791f, 4.52447939f, 0.04468334f, -1.48768258f, -0.69848812f, -0.71123981f, 3.66259432f, 6.10314512f, 1.37305343f, -0.62758982f, -2.99383426f, 4.20510864f, 1.48497128f, -0.08954811f, 2.43872309f, -0.59880185f, 0.37431365f, 2.45458341f, -3.28401661f, -1.94629693f, -1.93975246f, -0.26385683f, -0.45814323f, -0.18108580f, -3.74811840f, -0.29739976f, -2.24116230f, -0.28150487f, -2.24421668f, 3.46930790f, 8.35415077f, 0.05562943f, -2.81079793f, 1.10388446f, -2.82245207f, -2.98102283f, -1.08132946f, 1.19089699f, 8.00183105f, 6.35385323f, 3.72591257f, 4.59467506f, -5.74890900f, 4.42238331f, -3.36533451f, 0.18350232f, 3.05606651f, 1.18788099f, 2.87450886f, 0.27472210f, -2.80111074f, -0.66314960f, -1.96376896f, 0.75167024f, -4.72056293f, 1.10629988f, -5.00775242f, 1.48246133f, -3.91681528f, -1.86573625f, -6.17714882f, -0.67820001f, 5.69730282f, 1.04399037f, -4.93794823f, 3.09619617f, 2.18692017f, -5.54232264f, -3.10046840f, -0.68972743f, 2.81824327f, 3.04334164f, 6.13203907f, 4.14081764f, 1.02573645f, 5.71970081f, -6.01574707f, -2.07346702f, 0.99554527f, 1.69641590f, 0.66776669f, -0.80132431f, -2.03513098f, -3.42513680f, -0.06704485f, -1.87195873f, -5.42428589f, -0.20748445f, -1.52408111f, 0.97084987f, -0.48799962f, -0.45379883f, -0.26652339f, -1.20720732f, 3.94169855f, -3.18480229f, -1.87440264f, -1.18028760f, 0.52011997f, -2.13437462f, -4.52583313f, 1.69722807f, -0.89371562f, 3.37972403f, 6.38838720f, 6.98663378f, -4.05421400f, 6.89512825f, -5.09085655f, -2.16257906f, -3.33272719f, -3.01246452f, 0.37613097f, 1.80455804f, -0.36456174f, -5.32273912f, -1.29978943f, -0.53685790f, -2.12896323f, 2.55506587f, -2.57999182f, 3.40891910f, 1.36033249f, 0.83864629f, -2.88629293f, -7.36048365f, 5.61314154f, 1.32668555f, -2.58041072f, -3.71943092f, 1.60647738f, -2.74816346f, 2.47269106f, 0.85507953f, 8.39183426f, 3.42624784f, -0.01519036f, 5.68412066f, 2.51771593f, 1.03045523f, -2.08733034f, -2.44337177f, 0.81668580f, 1.30275154f, 2.99679208f, -2.91957355f, -1.71337795f, 3.34979844f, 1.51825011f, 5.20375061f, 2.27888370f, 1.38787699f, 4.23474550f, -4.05878592f, -4.85074377f, -0.22794735f, 4.64402294f, 1.24391258f, -2.04935098f, 1.26285601f, -7.51862240f, 0.62138438f, -1.95792389f, -0.96587181f, 0.85141110f, 0.79354531f, 7.93766356f, 6.07677746f, 2.05947518f, 6.55480623f, 1.44032848f, -0.70615625f, -0.07896036f, -5.08359432f, -0.01047915f, -1.89632201f, 2.57555676f, 3.83779287f, 0.42850614f, 1.80754125f, -0.06942326f, 6.35997963f, 6.06101418f, -0.97032297f, 5.71477222f, -6.06671238f, -3.46607208f, -4.98306370f, 2.84659123f, -2.11025190f, -0.04609144f, 5.26831341f, -9.56940651f, -3.67193556f, -1.71143103f, -1.35221267f, -4.26226807f, -6.89146233f, 8.21761799f, 5.69823503f, 2.28137946f, 1.88911343f, -1.44562483f, -1.60295713f, -0.52568185f, -3.31892347f, -2.81997776f, 0.35287106f, 2.98202395f, -1.39432132f, -2.70001364f, -4.14169264f, 3.50194883f, 4.12610435f, 5.52755260f, 2.65859175f, 3.61353087f, -0.83027136f, -5.10652542f, -4.48625374f, 2.06585884f, -2.76383352f, -0.64300913f, 8.19686604f, 0.96106279f, 2.45952058f, 2.47275925f, -1.03288829f, -0.64897656f, -3.77937531f, 4.27940083f, 2.58320260f, -0.57665241f, 1.87247813f, -3.81604433f, -0.24543774f, -1.62118483f, -0.73075479f, -0.48533297f, 2.05016756f, 0.45561486f, 0.03316188f, 0.77791005f, -1.56283605f, 2.36616826f, 5.58082104f, -1.30925488f, -1.06329608f, 2.17189479f, -3.43008828f, -4.71520567f, -2.56184673f, 0.17508316f, -3.25817418f, -0.41749167f, 0.18119079f, -0.73181152f, 3.99792433f, -3.08002281f, -0.99143314f, -1.83520067f, 1.18565679f, 2.98040128f, 5.67814350f, 2.35128760f, 1.41600966f, 4.02718067f, -0.08193968f, 0.64636409f, 1.35931289f, 2.37125754f, 1.75978124f, 3.90977740f, 1.50662971f, -2.84089065f, 1.29824126f, -3.38730979f, -1.61005294f, 0.58292413f, -0.03019404f, -1.57986510f, -0.56102908f, -3.03128719f, 0.51644313f, -2.01147819f, 0.98400700f, 3.00028515f, 0.74579155f, -3.37098312f, 0.93339360f, -1.29018497f, -2.14695001f, 1.30411184f, 0.71501279f, 7.47793055f, 4.06516457f, 3.50772929f, 3.52762985f, 0.55643129f, 0.32272506f, -4.30955982f, 2.49414706f, 2.07820845f, -0.34377906f, 4.39805031f, 2.77561307f, -3.91292810f, 2.43981409f, 0.18861845f, -2.76658440f, -4.97148752f, 3.25273705f, -0.08929539f, 0.19818619f, -5.83767605f, -0.97381884f, -5.68745661f, -5.42433214f, 3.98769903f, -0.40394354f, -1.83387578f, -0.80109525f, 1.47454357f, -3.14899540f, 0.80130816f, -2.26348829f, 4.06121159f, 6.13077354f, 5.31226397f, 2.94966197f, -3.65217376f, -1.08136678f, -7.14119816f, -0.85269439f, -0.70365787f, -0.81598872f, 3.62807679f, 3.08123684f, -7.82739496f, 4.07951784f, -0.14204243f, -0.66969109f, -5.07225513f, 2.88492823f, 0.47202343f, 0.72683257f, -6.84280777f, 0.41807127f, -5.09785986f, -3.74514675f, 2.03936672f, -1.06096244f, -1.52409148f, -0.97046643f, 2.27491093f, -1.55597985f, -1.29215479f, -0.79737484f, -0.01979581f, 7.65407991f, 5.54527044f, 4.04147148f, -2.64274883f, -1.89246953f, -3.89547634f, -1.06029689f, -2.85982800f, -1.41247237f, 1.55836034f, 3.38194537f, -2.97655582f, 0.87510300f, 1.26282072f, -1.77029657f, -3.57144690f, -4.19456863f, 0.53179169f, -1.42221975f, -3.09144497f, -0.84294832f, -5.02758694f, -2.68011904f, 0.89156240f, -0.34783912f, 4.64484835f, -2.34453487f, -1.28573155f, 0.09990287f, 0.01828218f, -1.79960847f, -1.06579173f, 1.08763921f, 0.43687880f, 3.24747229f, 3.83097172f, 1.07253766f, -1.33810723f, 0.76530832f, 1.58660865f, 5.60743904f, -3.54124737f, -0.89264417f, -3.83942485f, -1.03707337f, -1.61659896f, 1.65349591f, 1.72698796f, 4.96013832f, 0.78927267f, -0.35563886f, -3.48121166f, 3.79677629f, 2.59023166f, 2.74940348f, -2.17589283f, -5.91757107f, 2.43766379f, -4.15906048f, -1.74731481f, -2.49113035f, -0.57349741f, -4.04455185f, -1.46939647f, 2.21418452f, 0.09153593f, 2.23016739f, 7.91880608f, 4.04464149f, 0.07706618f, -2.41892862f, -2.19280314f, 7.61760712f, -5.89153862f, 0.33551922f, -1.70855618f, -0.30561331f, -0.14341974f, -2.48878574f, 1.31269515f, 3.45388412f, -0.02453184f, -0.12132037f, -4.27916241f, 1.25179088f, 4.09455204f, -1.83801770f, -1.86743176f, -4.02864933f, 3.44515228f, -4.39244986f, -0.56988084f, -1.69426417f, 2.18254852f, -4.78135824f, 1.73193693f, -2.27968478f, -1.49523509f, 2.51696730f, 4.03677559f, -2.03679037f, 1.32167840f, -2.22570705f, -2.74843621f, 6.29655170f, -3.67230225f, -1.86765468f, -0.14842367f, -1.21552539f, -0.92038238f, -0.51692355f, 1.08433771f, -0.01929832f, 0.15660909f, 2.31432915f, -3.86507082f, -0.69797570f, 0.13505173f, -1.50951028f, -0.69980979f, -1.51297045f, 3.63725281f, 0.13388813f, 2.73131752f, -0.96528149f, 4.92000961f, -5.92699385f, 1.69444644f, -1.17121375f, -2.33710480f, 1.35302818f, 1.39608085f, 1.68293881f, 0.94960749f, 1.89011908f, -4.08865070f, 0.13722643f, -1.62849212f, -0.19044125f, 1.37906075f, -3.92504406f, -1.45033538f, -0.42085981f, 3.38237071f, -3.06508875f, -1.39420545f, 1.13067436f, 0.92206454f, 0.49917889f, -2.74508023f, -2.19221997f, 1.77914095f, 0.10854459f, -2.62178278f, 2.35042715f, -0.15322030f, -0.67014873f, -1.75627899f, 2.64074945f, 2.76339936f, 2.67275214f, -0.62736398f, 0.58251178f, -4.64895678f, 5.50419283f, 2.53566456f, -2.44196153f, -0.07845879f, -2.80389643f, -0.64810950f, -0.05813205f, 1.67155504f, -2.69673729f, -1.72486305f, -0.53888649f, 1.86805439f, -1.37128329f, -5.37923479f, -2.08133769f, 0.58187997f, -1.39498150f, 0.21874082f, 4.33726025f, 6.29673958f, 0.72312093f, -3.32683516f, 1.73482585f, -0.00766110f, -2.63785434f, -0.13511759f, 4.07195950f, 0.94139838f, 3.15717316f, 1.53720927f, 1.87664819f, -2.33655119f, 6.18176556f, -2.73912525f, -2.45279956f, 2.20392370f, -0.56854641f, 0.98915887f, -2.64472580f, 2.40633702f, -4.93327999f, -1.28942823f, 0.98247659f, 1.31774998f, 0.07669818f, -5.91169453f, -0.43135011f, 1.27404964f, -0.59787154f, -0.22716975f, 0.74409103f, 10.27316475f, -2.29192710f, -2.19403267f, 3.78925133f, 3.19553399f, -4.42490482f, -0.80781460f, 2.16568565f, -2.54165983f, 2.54885101f, 4.18779039f, 1.73079813f, -1.48891807f, 11.60153770f, -0.98686743f, -2.88813901f, 2.32898521f, -0.36101711f, 2.34522438f, 0.29057693f, 1.39800644f, -4.31848240f, -3.21217132f, 0.11740226f, -1.21613467f, 0.57248503f, -4.44853830f, 1.54665899f, 3.14459944f, 1.76809108f, 0.26693153f, 0.86913753f, 9.47121620f, -2.07677889f, 2.08578467f, 1.30181742f, 1.58683562f, -3.52757788f, -1.32763624f, 0.79821301f, -2.19358301f, 1.17707348f, 6.01983643f, 4.11209440f, -2.04209709f, 7.00413418f, -1.84904683f, -1.32542288f, -0.01298118f, 0.70377320f, 0.27815005f, 2.07879829f, -0.71606725f, -4.94399881f, -2.11898828f, -0.39051518f, -2.21034360f, 3.05337906f, -1.56889665f, 1.97065282f, 2.61320901f, -0.34063196f, -0.57001418f, -2.13183641f, 3.48879004f, -0.12067288f, 0.48568326f, -1.81424558f, 2.28868723f, 1.44802380f, 1.25918829f, -1.76415455f, 5.35742331f, 3.50682044f, 4.71371317f, 5.89110756f, 8.51241302f, 4.07391453f, -0.05887252f, -0.18202400f, 2.27119660f, 6.78274727f, -2.87470293f, -5.14336634f, 0.76443815f, 2.04625130f, -0.43199503f, -1.01353514f, 2.42951298f, 2.35641170f, 0.32345510f, -4.04195738f, -4.77967072f, 0.26564783f, 6.11455107f, -2.53868008f, -3.11839914f, -1.04203856f, 5.17195654f, -4.15338612f, -3.84149241f, 0.48130888f, 3.09706950f, -4.18423653f, 5.26233864f, 3.55831861f, 3.75122595f, 8.14969349f, 6.80038738f, 4.68907356f, -1.40135396f, -3.19287133f, -3.15895939f, 8.77363205f, -4.48793411f, -3.80537176f, -2.40145254f, -2.74341679f, -2.02862644f, 5.33402443f, 9.25365734f, 2.50246119f, 0.32847846f, -1.50564361f, -4.26163197f, -1.40994716f, 2.50708485f, 0.44500345f, -0.62516934f, 4.09846306f, 5.29355669f, -4.02224922f, 0.73442125f, 0.46648952f, 0.67028689f, -6.30715466f, 6.56297970f, 3.80854273f, -5.19078207f, 4.98839283f, 7.59161472f, 0.46010983f, -2.10227895f, 0.29324162f, -2.67019558f, 4.57838106f, -3.02338457f, -3.08647728f, -2.00112700f, -3.81710315f, -0.08346784f, 1.69288683f, 5.68807268f, 3.29351830f, 0.54618967f, 1.83540761f, -5.38810253f, 0.51326782f, 4.40081882f, -4.03805828f, 0.49482727f, -1.36024392f, 2.91845679f, -2.00959015f, 2.47489738f, -1.43354976f, 1.92024410f, -6.55897284f, 1.79488957f, -0.89570928f, -6.13094234f, -0.45504010f, 2.35239482f, 1.29039919f, -4.78849840f, -1.52545333f, -6.50420475f, 2.99257326f, -0.55620033f, 0.26807702f, -2.52090979f, -4.59419632f, 0.57965040f, 2.19423151f, 2.04760551f, -0.57048106f, -2.20812702f, -0.04777686f, 1.38053393f, -2.71448946f, -1.06219673f, -3.62008905f, 1.85719645f, 1.28355026f, -2.76315832f, 1.65295160f, -4.01645803f, -3.10454416f, -0.65713316f, 1.22384977f, -0.70416176f, 4.45064926f, 1.31602776f, 2.06907344f, 2.48872757f, 4.25775290f, 3.50504255f, -0.68262041f, 1.29799378f, -1.01969171f, 2.98593879f, 0.12607655f, 0.37219539f, -0.84196299f, -3.80019331f, -1.82315290f, -0.38489276f, -1.45200360f, -4.00882292f, 0.61042011f, -0.16738498f, 1.33787775f, -2.26938057f, 1.03656030f, 8.89089870f, -1.60370600f, -5.38691807f, 5.72182989f, 2.72854710f, -6.18535757f, -3.13408709f, 2.79175353f, 5.18425512f, 9.46434212f, 2.40110517f, 1.11330092f, -3.57366538f, 4.80967665f, 0.40691876f, -3.65484858f, 0.92398167f, 2.53852940f, 3.17747331f, 2.14199781f, -1.69107199f, -1.91864693f, -3.18452644f, -2.42408276f, -2.14332366f, -1.35526609f, -4.50732136f, 0.58234072f, -1.81547785f, 0.57311213f, 1.10584176f, -0.97226644f, 11.73174381f, -2.00559855f, -1.81175601f, 2.33131361f, 0.49264961f, -0.42245382f, -1.37528467f, 1.55768061f, 0.21152198f, 13.08896351f, 10.33674145f, 5.77929306f, -6.19886398f, 5.67007637f, -6.61288071f, -2.58029866f, -4.05192375f, 1.77221894f, 0.29821560f, 5.23508501f, -5.09560966f, -0.97536200f, -5.17957878f, 1.02876794f, -4.52072096f, 2.22126532f, -4.81708670f, 0.44538212f, -2.30738068f, 3.15900373f, -4.99227905f, 0.82632786f, 9.65415478f, -0.63819492f, -3.25479436f, -0.13276935f, 0.21337092f, -2.22116399f, -3.04922724f, 0.65568435f, -0.10706246f, 4.58047390f, 7.80782652f, 5.49080181f, -3.97114491f, 6.43327618f, -6.54772758f, -2.10962629f, -0.79831678f, -0.08316499f, 2.48658133f, 4.14070511f, -0.59806836f, -4.58636141f, -0.31166920f, 0.31757897f, -3.92562199f, 0.65357721f, 0.55871534f, 1.71843934f, 1.62395024f, 0.00695819f, -4.56716251f, -3.76420808f, 4.24979544f, -0.86128616f, 0.23126510f, -6.32968998f, 1.83346081f, 3.81335950f, 2.98407745f, -1.80454743f, 6.61764765f, -1.39372075f, -0.86780751f, 7.24317265f, 2.24205112f, 1.05702817f, 0.55431479f, -1.54557061f, 3.36389136f, 4.70898724f, 1.11327887f, -3.78462076f, -3.63381767f, 2.86510396f, 0.74203897f, 0.81488025f, 3.54250598f, 3.24824381f, 3.19000244f, -0.58995843f, -7.05670738f, 3.18306041f, 3.95191574f, 0.81820154f, -1.91068232f, -2.05426741f, -1.05589008f, -3.18377590f, -1.86278260f, -8.80374908f, 0.93416154f, -4.60517359f, 8.38999462f, 5.26356745f, -8.89992714f, 8.95298958f, 4.22590351f, 1.00351548f, -6.90151119f, -8.07641125f, -4.82450199f, 8.02293015f, 4.11661243f, 0.95457208f, -7.07843113f, -4.30524826f, 5.02697992f, 5.21011686f, 0.80132771f, 3.23420191f, 3.82452774f, -2.13171721f, -7.88879967f, 1.31062031f, 1.90848613f, -3.51572514f, -3.75684500f, 3.62577081f, -5.76075602f, -2.79389215f, 0.32598805f, -4.28981733f, 4.21048594f, -3.84532523f, 3.19815183f, -0.40756655f, -2.19974327f, 6.25655174f, 3.42396951f, -1.88986623f, -1.92803884f, -2.97344875f, -0.09756154f, 5.24342251f, -0.72513700f, 1.06113195f, -1.30720282f, 4.69107103f, 0.58984971f, 2.33985567f, 1.46385121f, 3.16576266f, 6.77769995f, -5.92685127f, -12.61141014f, -2.83663774f, 4.90253258f, -6.32688522f, -3.00096869f, 2.38634992f, -7.21459866f, -5.89208746f, 2.84085894f, -1.21792030f, 6.70161343f, -4.00450230f, 5.29881001f, -1.45574808f, 0.77542424f, 1.38336325f, -0.21572059f, -3.38088870f, 2.33249640f, 0.68824625f, -3.68440270f, 0.33481622f, -0.39239681f, 0.14560902f, 1.61039007f, -3.11967754f, 2.49372435f, 2.68783092f, -1.17559779f, 0.95257235f, 4.35451412f, -0.56818569f, -7.32110357f, -7.58534050f, -2.10573673f, -3.34446383f, -0.32183546f, -0.78525496f, -1.76974547f, 5.19060802f, -2.11319876f, -3.41755080f, -0.36864156f, 1.32680905f, 0.45004874f, 6.17223930f, -1.60707474f, 0.46096295f, -3.88852644f, 1.84729624f, -0.03412050f, 0.99224162f, -2.05553341f, 3.47793245f, -0.06305170f, 0.51314175f, -2.91650558f, -1.78121483f, -2.85465693f, 0.24649808f, -2.70376635f, 0.42334458f, -1.13862336f, -0.98409218f, -0.96593523f, 2.22128963f, 0.53402066f, 3.33979344f, 8.57430458f, 2.34217858f, -2.40062976f, 5.81624222f, 1.13290989f, -5.06850052f, -4.72865725f, 1.82859278f, 6.78569555f, 8.56885242f, 2.76462936f, 0.33891773f, -2.81092787f, 0.79498398f, -2.27208567f, 1.55182552f, 2.17166376f, 6.12517643f, 3.56859684f, 0.27685475f, -1.38408327f, -1.03533340f, -3.46618199f, 0.79240030f, -3.89390516f, -0.55852515f, -1.16367757f, -0.07008934f, -2.20105195f, 3.81210446f, -0.66834474f, 0.43603873f, 10.92334938f, 2.48571420f, -6.34997845f, 4.23135757f, 0.45045292f, -4.13489866f, -3.92324209f, 1.88537407f, 2.57159734f, 9.90973091f, 4.37453461f, 7.34546280f, -2.51120615f, 11.12575245f, -3.23452854f, -2.49947500f, 1.39819741f, -3.78950691f, 2.40617585f, 5.10036278f, -3.55743456f, -6.42888737f, -2.51929998f, -1.90880990f, -1.81618094f, 1.60946512f, -4.09737110f, 1.96408439f, -1.90115595f, 2.44444203f, -2.31254292f, -4.01332951f, 8.65541840f, -0.58626485f, -4.02226830f, 0.43893200f, -3.78272748f, -5.46277428f, 0.01306701f, 0.61185312f, 0.24469066f, 1.30214953f, 5.87789631f, 8.75197792f, -5.31634712f, 3.43556309f, -5.90755081f, 0.54375106f, -2.48162293f, -3.51843548f, 2.55853295f, 5.06387186f, -2.09662485f, -3.00377345f, -3.21781397f, -0.14537808f, -4.65453672f, 1.92747557f, 0.41553855f, 4.09379959f, 0.83387995f, 1.50868511f, -6.54959488f, -8.38881016f, 5.50689125f, -2.88616610f, -1.21597648f, -0.23817590f, 1.50816703f, -2.26873541f, 2.29862142f, -1.61143053f, 5.97371244f, 4.71440220f, -0.20635787f, 8.85926723f, 0.56064367f, -1.04103339f, -4.47060108f, -2.63824081f, 3.06782055f, -2.07702565f, 3.38269401f, -1.59988797f, -3.80122590f, 2.35341501f, 2.69095278f, 3.87612104f, 1.89984226f, 0.95496917f, 3.14841127f, -5.84543085f, -7.24945450f, -2.65708590f, 2.87417006f, 0.97556210f, -3.75203967f, 1.55287778f, -7.43401051f, -1.29005826f, -3.40252638f, -4.01049423f, 2.82721639f, -1.21479535f, 8.54563904f, 7.39749908f, -0.61361837f, 7.60177565f, 1.65812778f, -0.83008504f, -3.60961151f, -7.69062138f, -1.26275063f, -4.17071676f, 5.28448200f, 4.04685593f, -1.18231702f, 1.15276611f, 1.58620787f, 6.75060844f, 3.29332161f, -0.67640316f, 5.78984785f, -3.14913464f, -6.41867924f, -2.58316016f, -2.04366302f, 2.01089478f, -3.81723452f, 3.63843751f, -5.13238430f, -3.79432917f, 4.86581373f, -1.06922054f, 3.95978498f, -0.78166616f, 8.35650539f, 5.35834265f, 0.35594034f, 9.41657066f, -0.84108615f, -6.54425859f, -3.44328952f, -6.55536795f, -0.08963367f, -1.53906262f, 0.17658240f, -0.13108420f, -0.44371247f, -0.78411150f, 2.64754868f, 9.66306782f, 1.70506203f, -0.31588936f, 4.31715870f, -6.16665173f, -10.43371868f, -3.72962189f, 4.35245228f, -1.75867891f, -4.20046234f, 8.62637043f, 1.45946813f, -3.30153608f, 0.85179043f, -2.66643381f, 3.01863337f, -2.52916121f, 8.35405540f, -0.37298933f, -0.89473486f, 6.88681793f, -4.46370125f, -7.50776386f, 3.80255938f, -3.55003357f, 1.43528831f, -2.20383263f, 2.34999895f, 2.03803205f, 1.94830751f, -1.85976326f, 0.97718471f, 5.53710842f, -0.80560827f, 0.23925614f, 5.98795223f, -2.03578377f, -7.77835321f, -2.79955530f, -1.88185954f, -2.49112058f, -0.76095992f, 2.71161270f, -0.55918610f, 0.83789903f, -1.42063200f, -0.61528748f, -4.18273115f, 1.76384258f, 4.21265936f, 5.50964785f, -0.93324339f, 3.83215356f, 1.52210593f, -0.91594946f, 1.31148386f, 3.20160103f, 1.24493563f, -0.72693497f, 1.84716725f, 3.09897518f, -1.34605026f, -1.17511916f, -1.05526352f, -1.08590937f, -1.41319299f, -3.75052118f, -2.67095542f, -0.76179552f, -3.32081509f, -1.04692316f, -1.30194843f, -1.98795474f, 5.01223469f, 0.21895903f, -1.85535169f, 3.12362719f, 0.16198632f, -3.86784005f, -2.03062248f, -0.15415624f, 8.22020721f, 4.83055592f, 4.50315666f, 4.19443417f, 0.42727345f, -4.67786789f, -5.18739986f, 2.53988838f, 3.19683266f, 1.80313504f, 1.94664574f, 0.59795094f, -4.21626759f, 0.50492239f, -0.41232634f, -0.99224532f, -3.94929314f, 1.74060190f, -0.92474866f, -1.00664830f, -6.17397356f, -1.33146775f, -3.78111315f, -4.91876888f, 2.50303864f, -0.34890354f, -1.25013232f, 0.38168997f, -1.84135628f, -4.46107960f, -4.05920792f, -2.61709857f, 0.71046209f, 9.80566883f, 6.34086990f, 2.73394704f, -2.03342366f, -2.21424174f, -5.56514263f, -4.74755144f, -2.20672894f, 0.09010231f, 1.70423889f, 3.19200158f, -6.99027634f, 1.14216340f, 0.05824995f, -0.76996505f, -6.51575899f, -0.41109252f, 0.78229940f, 1.36170781f, -5.65170193f, 1.12221193f, -4.60430050f, -4.40174437f, 4.01805925f, 0.10774946f, -2.77991009f, -0.18023163f, 0.02151692f, -1.77023101f, -1.86639869f, -0.69443607f, 4.92290831f, 6.83520412f, 4.27372265f, 6.54272366f, -7.59249687f, -1.40776849f, -3.52368808f, 1.01398587f, -3.58802676f, -0.35658866f, 1.14716864f, 3.75847244f, -2.30159235f, -0.72130895f, -0.24564353f, -1.77531350f, -3.08677864f, -0.73486501f, -1.20357263f, 0.60789430f, -3.46990204f, -0.20668676f, -5.46096087f, -5.22016764f, 0.98259866f, 1.81012678f, 3.92534304f, -2.94997001f, 1.65154219f, 2.27040243f, 0.99095678f, 0.09144652f, -0.99103236f, -1.11210847f, 0.78181303f, 2.38706732f, 2.96695375f, -0.17279971f, 0.31143007f, 1.35465562f, 2.03586054f, 6.19515753f, -3.14652419f, -2.89027119f, -3.26665854f, -1.93043876f, -0.46601450f, 1.07655203f, 1.74946189f, 4.02148342f, 0.69275337f, 0.50094581f, -4.07613230f, 2.98369169f, 4.24537849f, 0.49480581f, -2.02408123f, -2.02068973f, 6.54505825f, -5.19377470f, -0.12596917f, -0.70204186f, -0.98308045f, -3.19708824f, 1.63609934f, 1.35475993f, 0.16313422f, 4.13918924f, 7.69187021f, 3.72601676f, -1.97790039f, -1.16739464f, -3.31835508f, 8.14553452f, -1.78718984f, 1.21505618f, -3.84255409f, -3.21992350f, 0.07376552f, -0.81223297f, 3.57002878f, 1.48521733f, -0.45995998f, 0.30551746f, -3.33944130f, 1.39538884f, 1.84758544f, -0.21494150f, -2.27316713f, -4.37771225f, 6.48841667f, -5.00251961f, -0.45162797f, -5.01056004f, 0.70199943f, -4.60057783f, -2.22394514f, 0.07777429f, -1.49820781f, 3.47308421f, 6.13231564f, 1.18605387f, -4.78924608f, -3.49548388f, -2.73382568f, 6.24617863f, -2.74291611f, -1.03833354f, -2.20752788f, -2.33219409f, 1.48633552f, 1.65796840f, 4.95045471f, 2.58479190f, -0.90922785f, 0.71312457f, -4.44465590f, 1.37020862f, 2.37683725f, 0.18805164f, -3.28422308f, -1.64939332f, 3.64181972f, -3.75277281f, 3.67203593f, -0.11204052f, 2.24140930f, -3.90657187f, 2.56883717f, -1.44016707f, -2.83842611f, -0.29104578f, 2.17757058f, -0.71431804f, 1.36911654f, 0.85083604f, -1.60110259f, -1.97247636f, -1.61163378f, -0.81236130f, -0.38993555f, -3.03631902f, -0.38213277f, 0.06394482f, 3.19348621f, 0.36771113f, 1.36763072f, 2.49159527f, -0.39599860f, -2.69996762f, -0.97561121f, -2.97563028f, -0.49662948f, -0.17564940f, -2.79042959f, 0.72395414f, 2.07260203f, -0.99439794f, -2.20248008f, -0.07389921f, 0.65536159f, 4.73054695f, -0.63917702f, 0.58788192f, -3.60156059f, 6.59609890f, 3.88419437f, -3.38469863f, -3.56237841f, -2.03295064f, 0.07279694f, 3.71804547f, 0.79928309f, -2.13411403f, -1.13909864f, -0.34193408f, -1.00338125f, -1.44231665f, -5.39835978f, -0.45086145f, 1.16064668f, 2.58335257f, 2.10072684f, 4.64244223f, 7.10090065f, 1.01974952f, -4.44687223f, 2.99792576f, 1.10303724f, -1.22736573f, -3.91514421f, 3.07458854f, 2.18765211f, 3.34481716f, 2.46166849f, 2.99648619f, -0.94046807f, 5.55028200f, 0.92199719f, -0.83934361f, -0.72042274f, 0.84869325f, 1.46914721f, 0.85937387f, 4.77306223f, -4.06436539f, -2.59847593f, 2.44828081f, 0.50484699f, -2.71092367f, -6.39010477f, 0.91778028f, 3.25469685f, 1.30310678f, 1.35258150f, 3.56171441f, 7.82435083f, -2.51527429f, -4.24328852f, 2.36876059f, 1.94595242f, -2.59290171f, -6.62389565f, 3.32567835f, 2.13659120f, 4.09299326f, 3.48293996f, 2.64965177f, -3.19157362f, 13.37204266f, -0.50297594f, -4.57448196f, 3.95582604f, -0.69038916f, 0.10098404f, 1.18737555f, 3.65761185f, -5.69623756f, -2.03357077f, 1.02868807f, -1.38448596f, -0.05690211f, -8.48874187f, 0.56755424f, 1.45485961f, 0.66273880f, 0.06495565f, 1.79539490f, 8.46864319f, -1.22696662f, -1.87585378f, -0.99768794f, 2.72801924f, -0.66980243f, -2.31924677f, 0.33271110f, 0.11666083f, 1.86980045f, 5.95332909f, 7.38583708f, -2.80956483f, 6.79227638f, -6.78070831f, 1.21884382f, -1.40695429f, 0.90236962f, -1.13695288f, 0.50760663f, 1.00955284f, -5.39029121f, 0.24987072f, 2.24283314f, -4.02145576f, 2.18057394f, -3.35627747f, 1.26061773f, 1.30342579f, 0.11311233f, -1.11199212f, -4.06509686f, 5.82649660f, -1.24059582f, 5.51652861f, -1.90937877f, 1.10658336f, -0.47065550f, -2.39167786f, -1.95931304f, 4.12717247f, 1.15396059f, 1.26015663f, 7.97836876f, 7.33633423f, 2.27785325f, -2.83802366f, -2.74850106f, 0.86126029f, 6.18781090f, -1.43707538f, -6.97134876f, -3.25486469f, -1.95214593f, 0.91066706f, 0.89637989f, 1.06481194f, 6.25791073f, 0.81779671f, -1.08384395f, -3.21191931f, 2.04216075f, 4.76030350f, -2.37217665f, -1.42571259f, -6.35876131f, 4.62536526f, -5.40060568f, -3.14868999f, -1.00587153f, 1.80662942f, -7.03201485f, 6.08373499f, 0.99862772f, 2.21717811f, 4.06814623f, 6.02428913f, 5.33422756f, -0.87013257f, -2.22477579f, -2.51505303f, 5.82925224f, -0.82854009f, -4.30698347f, -1.75007713f, 2.08352375f, -2.25235629f, 1.17517352f, 5.77717733f, 2.27472878f, 2.72778273f, -1.95411634f, -4.52602863f, 1.13983536f, 1.16340065f, -2.02740526f, -3.11290503f, -1.94906235f, 1.54855204f, -4.52984142f, 1.97465122f, -1.79415476f, 4.03510094f, -8.45349979f, 10.87430096f, 2.19863629f, -5.39083815f, 5.86213875f, 6.25744534f, 6.52600002f, -4.72149038f, -1.75254321f, -5.51459169f, 7.03155518f, -2.01889277f, -4.58441257f, -3.61226106f, 0.42395937f, -0.93263882f, 2.28703761f, 2.80611467f, 2.59498215f, 0.65989012f, -1.51268566f, -4.49465561f, -4.70453882f, 5.44696808f, -4.37603617f, 0.46670085f, 2.82488608f, 2.18854523f, -2.04817152f, 1.19557285f, 1.53618634f, 4.44758606f, -7.31593513f, 7.43966007f, -3.55480957f, -5.29834652f, 2.14622784f, 1.65194583f, 2.71262598f, -4.86145496f, 0.79726243f, -8.88541985f, 1.19627261f, 0.79660845f, -1.98016644f, 1.03741014f, -3.93128228f, 1.05535269f, 2.01378822f, -0.46086323f, -0.77754641f, -1.43942690f, 0.49809402f, -2.27861357f, -3.29815221f, 0.38201320f, -3.98481083f, 4.88261318f, -0.44555628f, -2.57224536f, 2.35001850f, -2.65835261f, -2.43422794f, -2.97889376f, 1.07349825f, 1.88157082f, 4.74075413f, 0.60376728f, -0.48894715f, -1.15800071f, 4.68110943f, -0.86976886f, 1.49192941f, 0.62665290f, 0.20652676f, 0.53916287f, -1.45706177f, 0.66133004f, 1.34405875f, -4.27689552f, -0.20838106f, -5.14266443f, -1.29718637f, -1.74506426f, -0.86022055f, -3.57553625f, 0.46880072f, -1.25287139f, 3.28596354f, 11.33191013f, 1.23942876f, -3.87616491f, 7.57880497f, -0.22940339f, -5.68512678f, -1.94969654f, 5.85449600f, 3.75705457f, 4.24395847f, 1.60086083f, 2.62553668f, -0.93964291f, 5.84753895f, -0.79931092f, 0.48274064f, 2.07170033f, 3.02243996f, 2.63509989f, -0.76043403f, -1.64048159f, -6.17683458f, -3.09974527f, -2.12773156f, -0.89379883f, 2.82242465f, -1.99981332f, -0.08763933f, 0.01921120f, -1.94142103f, 2.48067307f, 0.41083777f, 8.24922180f, -1.84516132f, -1.39224625f, 5.03956223f, 0.49562740f, -5.28296328f, -0.20005548f, 3.13672113f, 0.51187158f, 7.11563921f, 6.43059587f, 3.48430967f, -5.37095928f, 8.03863049f, -5.53923941f, -2.16421175f, -3.77641368f, 3.29633045f, 5.04030085f, 2.25945377f, -3.04169011f, -2.16198015f, -2.49559617f, -0.26252726f, -6.99201345f, 2.87374353f, -0.12568980f, 0.23314142f, -1.32087135f, 4.39030552f, -0.24638844f, -4.37242651f, 14.09276772f, 1.23987353f, -1.72249663f, 0.31124914f, -2.13725138f, -3.74915648f, -1.87147236f, 0.47318631f, 1.13337576f, 3.00416899f, 8.82548523f, 4.80538750f, -5.28486395f, 5.51870108f, -5.15801477f, 0.95712411f, -1.50416136f, 2.34657240f, 4.20726633f, 5.56757259f, -3.30645251f, -3.39945269f, -2.68488026f, -2.53525281f, -3.15145874f, 2.74529529f, -0.96283442f, 2.87778258f, 0.22186530f, 1.24905694f, -7.07941198f, -5.45916176f, 3.46988297f, 0.92430985f, -0.98330998f, -2.23672342f, -3.03262734f, 0.73941302f, 0.98004431f, 0.83219361f, 7.17411804f, 4.27849865f, 0.14765590f, 8.61269569f, 9.04497051f, 1.53991723f, -2.08305025f, -4.34939337f, 0.63786775f, 2.60098696f, 0.02432060f, -1.48516297f, -4.06825686f, 5.12420368f, -0.75312757f, 1.96927559f, 4.91575956f, 3.41533065f, 3.62557888f, -4.35002136f, -5.91343403f, 0.45026422f, 4.93286371f, 3.45830250f, -4.39032364f, -0.51697755f, -7.41543341f, -3.06703568f, 1.01196158f, 2.47106576f, 5.54014874f, -4.65312243f, 8.61000633f, 8.25905323f, -1.41497111f, 8.69221878f, 0.40090930f, 1.11325574f, -1.67089832f, -4.01080132f, 1.07925677f, 2.68086481f, -0.73093414f, -1.35081220f, -7.85765076f, -5.98989439f, -0.04651213f, 4.63693142f, 2.07757711f, -0.22652936f, 3.45525455f, -0.69198442f, -10.39761639f, -2.02106953f, 4.77755499f, -2.67665577f, -1.72481167f, 4.49634743f, -2.55717134f, -4.55044937f, 0.46377492f, -3.08933020f, 3.86891365f, -2.79104614f, 8.36974335f, 0.86471701f, -5.39342690f, 12.54906940f, -0.41536295f, -5.29502535f, -3.94430566f, -5.67391300f, -4.65079165f, 2.22505951f, -0.30000746f, 2.27855444f, -4.81604433f, -1.73440599f, 4.68784523f, 5.00208044f, 0.18863934f, -1.74989462f, 3.17923450f, -1.59773099f, -12.59962940f, -1.54495025f, -0.00576371f, 1.79913878f, -2.43449807f, 1.49516344f, -3.90507102f, 1.68647158f, 4.50177765f, -5.32286358f, 3.47539330f, -2.90529680f, 1.61576962f, 0.83679676f, -5.55615807f, 3.78939056f, -4.46644831f, -5.95550919f, 0.37808037f, 0.51334500f, 1.74658906f, -0.82085419f, -0.65387219f, 3.67790437f, 0.03758264f, -2.42622781f, 1.83335185f, 4.73835945f, -0.83536482f, -0.03993917f, 3.78230667f, -4.81265640f, -8.26869011f, -1.30363441f, -2.09106350f, -3.96769738f, -1.89037073f, 0.38682747f, 0.05434489f, 5.72213697f, 0.55685395f, -3.47729349f, -1.11535001f, 2.09416127f, 5.08877802f, 5.72183466f, 1.29632664f, 0.16822398f, -2.43180108f, 3.49967623f, 2.15753818f, -0.26548505f, 3.24446392f, -0.00599277f, 1.08215356f, -0.23225522f, -2.40723038f, 0.18496060f, -3.70608735f, -0.19918591f, -1.64028871f, 0.80792952f, -0.85334057f, -2.52314138f, -3.12099195f, 0.17949918f, -0.82650864f, 2.32224989f, 9.56476116f, -0.20134282f, -0.48428559f, 2.86784410f, 0.07289505f, -3.92880869f, -2.11887884f, 0.59164631f, 6.31267452f, 7.49149418f, 2.88749456f, 2.40504885f, -3.57608175f, -1.48019314f, -0.69410253f, 0.90275228f, -0.34111357f, 2.19190216f, 3.39090061f, 3.39631820f, -5.19105434f, 2.67546582f, -2.56549048f, -0.59797800f, -4.21802664f, 0.63918972f, -0.69969130f, 0.47496963f, -4.30976725f, 0.16531238f, -3.59595251f, -0.76877379f, 11.79971790f, -0.93276632f, -1.48630571f, 8.04754066f, 2.09168458f, -3.77018499f, -4.19337654f, 0.26171905f, 1.99359691f, 8.96759701f, 8.39609814f, 6.19231987f, -5.36037970f, 4.69818354f, -4.22453928f, -4.61665344f, -2.52073431f, 1.34026706f, 2.80182385f, 2.56681514f, -4.04676390f, -3.01466990f, -4.10480118f, 0.38737059f, -0.37146521f, -2.26529670f, -1.72867084f, 0.93472683f, -2.47562981f, 0.89871657f, -1.67618203f, -0.28950238f, 5.30124855f, -0.14731219f, -0.81319761f, -1.11265934f, 0.11356127f, -2.52802444f, -1.93826056f, 1.06187987f, 1.48062325f, 4.28070498f, 5.69893932f, 9.26904392f, -4.23773003f, 5.78582096f, -6.18445301f, -2.85200453f, -5.30461454f, -4.16009140f, -0.07239690f, 4.11531162f, -1.12266588f, -1.50265646f, 0.47661865f, -1.90043914f, -6.48978710f, 1.71005368f, 0.18256521f, -0.88272136f, -0.51324779f, -0.78045660f, -5.21036625f, -4.11805344f, 3.99454761f, -1.04999924f, -6.99629354f, -5.02737141f, 0.94748145f, -2.35882139f, 4.13982439f, -1.41835535f, 7.56763077f, 3.97024012f, -4.08156776f, 6.90305424f, 0.53571963f, -2.22625160f, -2.09144926f, -4.98530245f, -0.15102190f, 0.59995949f, 3.28562784f, 0.77991986f, -3.08389306f, 3.34046674f, 0.41394949f, 5.10031366f, 2.99692893f, 0.17706826f, 2.85998058f, -6.68330860f, -6.72653008f, -0.04071128f, 3.71085787f, 3.17834806f, -4.88019037f, 6.74075413f, -7.41782188f, -5.22026348f, -1.94595623f, -3.61318684f, 1.85610664f, 1.08613706f, 6.41580677f, 1.46376514f, -4.11524010f, 9.59146214f, -2.92772651f, -1.70753336f, -1.51594138f, -4.88185692f, 1.47331417f, -2.23893595f, 4.98459148f, 1.29359996f, -2.29221845f, -0.99594390f, 3.05759239f, 6.86030054f, 2.40487719f, 3.28339863f, 7.72739315f, -3.60563445f, -9.73502827f, -1.51672328f, -0.08473521f, -2.43673515f, -3.26616001f, 3.63767886f, -11.25394535f, -5.17597103f, -1.27523947f, -7.82669783f, 0.67929745f, -4.50530529f, 5.49323797f, 6.78993320f, -2.28033876f, 4.61412525f, 2.55109429f, -12.38607693f, -0.63024014f, -3.45992327f, -0.84092742f, -0.03252453f, 4.58635283f, 5.28213978f, -1.28417206f, -1.71185923f, -0.26850975f, 8.28257561f, 4.47432184f, 2.72818279f, 8.42217731f, -4.22216320f, -8.95128918f, -1.57179546f, 1.34253705f, -5.47035217f, -5.50866985f, 4.64156532f, -6.11207914f, -5.46734476f, 3.54298997f, -2.79237103f, -0.70766860f, -3.62739944f, 3.22660995f, -2.02262759f, 0.11224222f, 2.63832402f, -0.91955596f, -4.65958309f, -0.29729855f, -1.78957534f, -0.40749407f, 0.51688713f, 0.83725226f, 0.30945438f, 1.20769620f, -1.75219965f, 2.59689760f, 5.01501608f, -1.59034789f, 0.58155286f, 3.75831509f, -5.26110506f, -8.65382767f, -6.19066620f, -0.61932850f, -2.71863723f, -0.87443137f, 3.40582991f, -1.27868056f, 3.51236677f, -2.07806540f, -0.85076392f, -1.14599180f, 1.16361260f, 1.86411846f, 5.86179352f, 0.69029891f, -0.06060839f, 1.54649436f, -0.60351688f, 1.51970077f, 0.04187265f, 1.64540339f, 2.75502157f, 2.46308279f, 1.69071770f, -3.23827076f, 0.92096543f, -3.09458661f, -1.23823690f, 0.24035048f, -0.74456501f, -1.85476089f, -0.32914662f, -2.10325241f, 1.19795251f, -2.05372071f, 1.02114081f, 2.56286955f, 0.42165697f, -1.65826249f, 4.00724554f, -2.18727994f, -1.05848944f, -0.52338278f, -0.28714985f, 8.08780861f, 5.04444599f, 3.51866961f, 3.37445784f, -1.96067202f, -1.21509445f, -3.96595931f, -0.80801201f, 0.76944816f, 1.80147493f, 4.14419460f, -0.12201095f, -2.77788162f, 1.13284469f, -2.05441403f, -0.61129224f, -2.69690657f, 1.91634214f, -2.17146754f, -0.22308528f, -6.02561045f, 0.49161875f, -6.74280357f, -4.62689781f, 2.47910833f, 1.86534905f, -3.24152899f, -1.39898300f, 0.29427958f, -2.16338181f, 0.90073711f, 1.75551236f, 4.42651892f, 8.34437466f, 5.50070190f, 5.68162251f, 1.65345454f, -2.72315669f, -5.43411493f, -0.29380533f, 1.07508349f, -1.73533511f, 2.56912184f, 3.62010550f, -6.30422783f, 1.74158525f, -1.22070909f, -0.80982518f, -4.14757967f, 4.29217434f, 0.70600843f, -2.09282112f, -5.09018898f, -0.11623126f, -5.99775553f, -4.66743088f, 1.61512172f, -1.30276895f, -3.17103505f, -0.26310229f, -1.00843918f, -0.77664804f, -2.05240250f, 0.04728425f, 1.15720487f, 4.01001406f, 7.24615860f, 2.55452180f, -5.76347876f, 0.34683830f, -6.05540276f, -4.70677900f, -0.93182588f, -4.37759733f, 2.93209839f, 1.63947964f, -2.43563962f, 1.35213876f, 0.00670356f, -0.02742785f, -2.16460943f, 1.39449501f, 0.23929763f, 2.37476778f, -4.17733765f, -0.81475425f, -6.15027046f, -5.74441719f, 3.53978682f, 0.66798484f}); - - sd::ops::deconv2d_tf op; - auto result = op.evaluate({&input0, &input1, &input2}, {}, {7,7, 2,2, 0,0, 1,1, 1,1}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - + auto input0 = + NDArrayFactory::create('c', {4}, {3.f, 8.f, 8.f, 16.f}); + + auto input1 = NDArrayFactory::create( + 'c', {7, 7, 16, 5}, + {1.05293429f, -0.89349967f, 0.31027254f, 1.22991478f, -0.62926656f, + 0.56918693f, -1.60992694f, 1.10167944f, -0.80843484f, 0.07521993f, + -1.15994942f, 0.76016301f, -0.40056285f, -1.16872537f, -0.91384381f, + -0.36700436f, 1.82389200f, -1.18200207f, 0.51612782f, -0.92479187f, + -0.09307563f, -0.55122334f, 1.23532486f, -1.11124146f, -0.05812126f, + 0.68159896f, 0.69125599f, -0.77127314f, -0.10874277f, 0.86469102f, + -1.31614351f, 0.33354419f, -1.71750402f, 0.17197680f, -1.03965557f, + 1.10570908f, -1.19115615f, 1.05115080f, 0.18277600f, 1.08820546f, + -0.72191417f, -0.10999311f, 1.56521320f, -0.35433730f, -1.11799145f, + 0.34499285f, 0.64998639f, -1.64371550f, 0.92592359f, -0.47659501f, + 0.49101439f, -0.15613313f, 1.47486567f, 0.43576995f, 2.19538260f, + -0.83567709f, -1.21846950f, 0.80400819f, 1.14637423f, -1.01503456f, + -0.61992753f, -0.47378838f, 0.86503726f, 0.27147385f, 0.37073180f, + -0.19951358f, 0.79167330f, -0.33982825f, 0.18631981f, -1.54715073f, + 0.39967480f, 0.95067030f, 1.12508667f, -0.86676019f, -1.10341156f, + 2.33141375f, 1.10972047f, 0.71407092f, 1.70640314f, 1.80666339f, + 0.59465605f, -0.39653218f, -2.61163163f, -1.15013492f, -1.19908321f, + 0.41783467f, -0.22730024f, 0.31425011f, -0.58562893f, -0.10131568f, + -0.85047537f, -2.59974790f, 1.22072542f, -2.08812046f, -0.19363593f, + -1.27664304f, -0.02703438f, 1.08477545f, -0.65506506f, 0.46040919f, + -0.13715318f, -0.74945593f, -0.69006950f, -1.29617655f, -0.15865716f, + 1.38956285f, 0.90216327f, -1.31185400f, -0.15067385f, -0.63093358f, + -0.05895613f, 0.26545224f, 0.29332840f, 0.42852548f, 0.72409540f, + 0.12879130f, 1.43038857f, 0.68647617f, 2.19654775f, 0.51878077f, + -0.03769343f, 0.52877223f, -0.21733910f, 1.13710785f, -0.59003806f, + 1.54624867f, -0.64997369f, -1.03239334f, 0.19708300f, 0.68658423f, + 0.71048903f, -1.55250466f, -1.38636279f, 0.32385820f, 0.81226677f, + 0.19209047f, -0.23002781f, -0.63631231f, 1.02101684f, 0.65428704f, + -0.17206922f, 1.09488952f, 1.03022420f, -0.95567745f, -0.07595373f, + -1.48606372f, 2.57174873f, -1.75366247f, 1.12913883f, 0.97053039f, + -0.28552356f, 0.56511772f, -0.79568213f, 0.07561764f, -1.02085686f, + 1.05770981f, -1.25715709f, 0.42046708f, -2.57390857f, 0.96947151f, + 1.05215812f, 0.65624017f, -1.29019403f, 0.64157075f, -0.40509227f, + -0.65354455f, 0.42348680f, -1.34107757f, 0.05931387f, -0.54337227f, + 0.95460182f, 1.59319806f, -0.44433126f, -0.33717924f, 0.79566282f, + 0.50112695f, -0.22244534f, 1.76904583f, -0.89817202f, 1.82985342f, + 0.17671813f, 0.80720717f, 1.32469308f, 0.39417782f, -0.23720963f, + 0.96796370f, -1.02348757f, -0.86615551f, -1.58120525f, -0.37634999f, + 0.00905940f, 0.01880967f, 1.75771821f, -0.64372772f, 0.36687651f, + 0.15854552f, -0.67599791f, 0.53726906f, -1.20158446f, -1.78549063f, + 0.96476388f, -0.66158366f, -0.41681561f, -0.97541636f, 2.35928202f, + 0.32130197f, 1.06886065f, 1.38736427f, -0.73718959f, 0.11215294f, + 2.12865782f, -0.37927702f, 0.55621815f, -1.10108411f, -0.02032263f, + 0.29595461f, 1.58737493f, 1.24001300f, -0.66748160f, 0.80729002f, + -0.10575818f, -1.03175950f, 1.80755460f, 0.10825710f, 2.20666361f, + 1.33633149f, 1.39290452f, 0.45211342f, -0.07837920f, 2.08304930f, + -0.28387162f, -0.70775616f, 0.43626297f, 0.53556961f, 0.06201901f, + -0.59255266f, -0.11854446f, 2.10024118f, 0.37638292f, -0.56178707f, + -0.25220188f, -1.23731256f, -1.30002999f, 0.34283713f, 0.30502397f, + -1.09233856f, 1.12430644f, 0.52273953f, -0.68507338f, -0.69913578f, + 0.88440478f, -0.76959240f, 1.07093310f, -0.34802195f, 0.35683727f, + -0.76079178f, -1.92807376f, 0.84499562f, 1.39131641f, 0.44825050f, + 0.34567752f, 0.44607711f, -1.00986362f, -0.50038189f, -0.09060892f, + -2.55645394f, 0.56416476f, -0.83058155f, -0.65931624f, -0.73649710f, + 0.59814465f, -0.86736494f, -0.32200798f, -1.28087902f, -0.76818323f, + 0.86848933f, -0.98678392f, -1.30813944f, -0.20255326f, 0.26557815f, + -0.31090519f, -1.46331608f, -0.62782109f, 0.59034890f, 1.63147473f, + -0.17727259f, -0.37636510f, 1.27368402f, 0.19096918f, -0.29936951f, + -1.99038267f, 0.54831523f, 0.48849005f, -2.55680346f, -0.63126534f, + 1.21715927f, 1.22841084f, -0.67416084f, 0.02927168f, -0.36693662f, + 0.63204330f, 0.13721083f, 0.28742912f, 0.19470036f, 0.74873924f, + -1.47602463f, 0.86264688f, -0.23730527f, -0.99978864f, -1.17048764f, + -0.34996086f, 1.43019187f, 0.26224539f, 0.60689932f, -0.75002515f, + -0.79823422f, -1.37300086f, -0.19951135f, -0.12150808f, -0.75272322f, + 0.23755015f, 0.31270382f, 1.66539109f, -1.04104745f, 0.79540199f, + -0.54042423f, -0.54150617f, 0.43871084f, 0.24163951f, -0.24517761f, + -0.66178995f, -1.13064528f, -0.84426326f, 0.56437236f, 0.09088907f, + -0.82823074f, 0.81753862f, -1.74096012f, -1.80599844f, -0.60943592f, + 1.36094582f, -1.47762752f, 0.15931177f, 1.05569172f, 0.36751524f, + 0.06497604f, 0.13536447f, -1.57156146f, 0.22783801f, -0.96910107f, + -1.24294984f, -1.47147155f, -1.04790676f, 0.64629447f, -0.32266054f, + -0.55675793f, -0.95612079f, -0.23005411f, -0.75229394f, 0.03050950f, + -1.72484553f, -2.06055546f, 0.19892083f, -0.13597751f, 0.65180075f, + 0.27096850f, 0.08977254f, 0.57564765f, -0.43227410f, 0.09541437f, + -0.00358280f, 0.65680492f, 0.04006556f, 0.57160908f, 0.43821687f, + 1.96118212f, 0.42602235f, -0.36731303f, 0.67200917f, -0.56667900f, + 0.44014785f, 0.06970236f, -1.34415269f, -1.13301528f, -0.08848868f, + 0.35615012f, -0.06426942f, -0.81406075f, 0.94097465f, -0.54560357f, + -0.65877116f, -1.29646838f, -1.13109028f, -1.64186084f, -2.12723470f, + 1.86027610f, 1.22621441f, 0.26098135f, -0.05608099f, 0.21143445f, + -0.87244326f, 0.79408187f, 1.24279130f, 0.14458629f, 0.25532281f, + -1.24023473f, 2.42278886f, 0.00405578f, -1.00119174f, 1.19856644f, + -1.37395728f, -0.16656208f, 0.46858498f, -0.00678801f, -0.34960639f, + 0.16614936f, 2.41560221f, -0.53880709f, 0.91618651f, -1.77009308f, + 0.32911557f, 0.30216452f, 0.02881077f, 0.77705866f, 0.27061903f, + -0.07440855f, -1.14010465f, 1.25383139f, -1.58615100f, 1.04185510f, + 0.15140508f, -0.88059032f, -0.33872122f, -0.42526904f, 2.17365575f, + 0.29308075f, -2.24234557f, -1.03164542f, -0.09263755f, 0.08050421f, + -0.74946511f, -0.64589006f, -1.13416314f, -0.64989561f, 0.16502371f, + -0.33831969f, 0.22832428f, -0.08389475f, -0.28009200f, 1.34536922f, + -0.19075738f, 0.36238208f, 0.83690089f, 0.26144615f, 0.04457319f, + -2.55585861f, -0.01807522f, 1.68334866f, -0.05795629f, -0.21315987f, + -1.84039557f, 0.06512877f, -1.77318645f, -0.27637982f, 0.20439345f, + 0.67558700f, -0.77179354f, -0.17902173f, 0.70381826f, -0.40395790f, + -0.96492916f, 0.84138173f, 2.43879008f, -0.32297835f, -1.74370265f, + -0.10330839f, -1.07465363f, 1.85030377f, -0.59153467f, 0.99667048f, + -0.56753993f, 0.57383025f, -1.90630126f, 1.24299097f, 0.22797665f, + 0.30468231f, -0.07360230f, 1.64654350f, 0.57195550f, 0.03227921f, + 1.11005175f, 0.00088721f, 1.19266295f, 0.61323351f, 0.13754399f, + 0.59900171f, -0.75831634f, 1.11500823f, 0.99747783f, -1.36923385f, + 1.26563418f, 0.01253266f, 0.35483193f, 1.95143735f, -2.02703261f, + -1.38265920f, -0.02404256f, 2.02788448f, -0.75144875f, -0.58445263f, + 0.26129767f, 0.60691077f, -1.84661067f, 0.65872228f, -0.58298993f, + 0.33067298f, -0.09431327f, 0.43333948f, -1.52616286f, -0.25961858f, + -1.65459549f, -0.72950101f, -0.89906919f, -0.80081612f, -1.32189929f, + -1.36574399f, -0.35809481f, 0.36385000f, 0.31480747f, -0.35797358f, + -1.04066050f, 0.07971872f, -0.21176252f, -0.76559299f, -0.10352154f, + 0.29248312f, -1.75030553f, 0.68219930f, 0.56189102f, -1.11212170f, + 0.06501702f, -0.07131009f, 1.23410738f, 0.29311740f, -1.02052307f, + 1.40220940f, -1.00995779f, 0.57955760f, 0.22640309f, 0.74853230f, + -0.02586563f, -0.33427954f, 1.70311153f, -0.53405988f, 0.90975094f, + -0.46450076f, 0.19904344f, 0.28559047f, 0.23167793f, -0.69065529f, + -0.17176504f, -0.29301846f, -0.85477978f, -0.00267053f, -0.28529504f, + -0.64201307f, 1.03479636f, 1.03805065f, 0.83270210f, -0.09405448f, + 2.50615931f, 0.62019676f, 0.31354564f, -1.51599669f, 0.42848015f, + 0.66263914f, 0.74651009f, -1.13042867f, -0.58933645f, -0.35146511f, + 0.06223279f, 0.28065836f, 0.66506970f, 0.16942430f, -0.23316263f, + -0.87481076f, 1.21992230f, 1.48536301f, -0.79667616f, -0.75519305f, + 1.40999961f, -0.42802793f, -0.20252463f, 0.30573779f, -0.23319976f, + 1.77525878f, -1.80704832f, 2.71519923f, -0.67500192f, 0.12268137f, + -0.13014549f, -0.07479453f, -1.51065743f, 1.04198146f, 0.96205556f, + -2.00525570f, -0.37911776f, 0.89329720f, -0.39495832f, -0.03683375f, + -0.90928614f, -1.56263304f, 0.45038295f, -2.62184358f, -0.45686841f, + -0.52536523f, 1.05351484f, 0.89982438f, -0.63724512f, 3.21004057f, + -0.08608918f, 1.55209303f, 0.62688643f, -0.59702635f, 1.85774517f, + 0.38172096f, -1.25640929f, -2.59278178f, 0.85050315f, -1.10080361f, + -1.26422560f, -1.80045366f, -0.34494889f, 0.68448657f, 1.25671864f, + -1.26594126f, 0.32244179f, -0.51956522f, -0.56212711f, -0.95574015f, + 0.71973872f, 0.46736258f, -0.11772985f, -1.52736545f, 0.19571695f, + 0.73147154f, 0.87724912f, -0.26265728f, -2.60267401f, 0.19263546f, + 0.18320183f, 0.11485019f, -0.82999659f, 0.13582672f, -0.08040185f, + 0.28152901f, -0.51421624f, -2.32467175f, 0.19923948f, 0.64616692f, + 0.29718629f, 0.32785949f, -0.62266952f, -0.98174316f, 1.23276305f, + 0.58563638f, 1.28528512f, -2.13718534f, 0.28842899f, 0.12676710f, + -1.72105229f, 0.15053287f, 2.19496536f, 1.28683448f, -0.96318281f, + 0.17043279f, -0.05245409f, -0.38710704f, -0.30441490f, -0.08249986f, + 0.28423953f, 0.72963721f, -1.49658203f, 0.99077344f, -0.78913772f, + -1.12661564f, -1.26294816f, 0.16517465f, 0.10124251f, -0.77198768f, + -0.16342169f, 0.08615876f, 0.49711797f, -0.66083062f, 0.76648003f, + 1.04756033f, 1.46122825f, -0.42798752f, -2.29203916f, 0.30444992f, + 0.58697921f, 1.22166932f, 0.09022947f, -0.03920181f, 0.10444995f, + 0.10361757f, 1.18224072f, -0.76641631f, 0.90802073f, 1.41639423f, + 1.55682337f, 1.28101575f, -0.35396016f, 1.11443567f, 1.18218529f, + -0.06048089f, 0.85024464f, -1.01789165f, -0.69154263f, 0.06663221f, + 0.68429029f, 0.12560424f, 0.37915874f, -0.66829866f, -0.64524972f, + -0.05568011f, 0.12230454f, -0.35041061f, 0.62027830f, -0.16739209f, + -0.72145337f, 0.46263054f, -1.67837834f, 0.69413221f, -0.57243419f, + 0.37638462f, -0.21446526f, -0.89821470f, 0.60078722f, -1.06706369f, + -1.26132309f, 0.35714921f, 2.39221811f, -0.09376130f, 0.30760849f, + 0.59180892f, 0.55815399f, -0.32628775f, 1.28890121f, -2.53237987f, + -0.98241091f, 1.10520673f, -1.74751687f, -0.90837651f, -0.25220659f, + -0.56625104f, -0.30691949f, 0.16058689f, 0.44309673f, -1.09874964f, + -0.76747823f, -0.33679363f, -0.02535496f, 0.00990100f, 1.35318136f, + -0.70140815f, 0.50937581f, 0.55386209f, -1.21721983f, 0.71376961f, + -0.18079315f, -0.11077732f, 0.09292522f, -0.57235324f, 0.62748206f, + 0.42587611f, 0.64860481f, -1.10635614f, 1.66414368f, 0.47505483f, + 1.48602211f, -0.59611166f, -0.41932896f, -0.96542233f, -0.41756630f, + -1.02963889f, -0.70070386f, 1.65803933f, 0.20138647f, 0.05895034f, + -1.46152759f, -0.37278318f, 1.05535650f, 0.34437978f, -1.13257408f, + 0.17635690f, 0.09386671f, 0.37079874f, 1.47695887f, -1.58420062f, + -0.26100200f, 0.44847637f, 0.88847303f, -0.13877590f, -0.64620668f, + -0.38019657f, 1.01608157f, 0.13357787f, 0.05137976f, 0.93498152f, + -0.62226880f, 0.80461699f, -0.71682596f, -0.88756353f, 0.40933055f, + -1.52167451f, 0.79756850f, -0.17307425f, 0.62368619f, -0.22466940f, + -1.72802913f, 0.59047443f, -0.58020931f, 0.09096476f, -0.07317388f, + 0.44522321f, -0.64880705f, 0.15684015f, 0.08708375f, -0.41556796f, + 1.11579072f, -0.81733495f, 0.11643656f, -0.73995101f, 0.93685871f, + 1.57971406f, 0.67606360f, 0.70509088f, -0.25283816f, -0.00010609f, + -0.61884147f, -0.86409342f, 0.95383751f, -0.05895388f, -1.45261180f, + 0.45166013f, -1.01434863f, 0.18496066f, 1.06517637f, 1.81127059f, + 0.89470667f, -0.13232610f, 0.46958798f, 0.13884509f, 0.57117194f, + 0.29575035f, -0.97884250f, 0.83291447f, -0.59255791f, -0.04354135f, + -0.19431923f, 0.30071029f, -0.95421529f, 0.76359886f, -0.47799742f, + 0.68254346f, 1.19368529f, -0.48935115f, 0.30357337f, -0.50225669f, + -0.23370270f, 1.96702433f, 1.46558523f, 2.68482018f, 0.41622332f, + 0.73697484f, 1.43430734f, 0.15387188f, 0.20875402f, -2.49335337f, + -1.39674246f, -0.22125854f, -0.00424605f, 0.91416460f, 0.33384630f, + 0.44703746f, 0.25610185f, 0.38966551f, -0.01784045f, 1.66148460f, + 0.36005461f, 0.95716912f, -0.18246566f, -0.15480693f, 0.38775176f, + -0.56969136f, -0.29644895f, -1.04565966f, -1.00455630f, 0.30897698f, + -1.46885884f, 0.03657720f, -0.49302089f, 1.34134722f, 0.01673754f, + 1.22725964f, 0.55256772f, 0.63803208f, -0.29041430f, 1.11455286f, + 0.76329172f, 0.27073982f, 0.77173829f, -1.79884446f, -0.11889492f, + -1.92040312f, -0.46382675f, 0.20078070f, -0.98889589f, 1.46711135f, + -1.68280172f, -0.52852470f, 0.66245162f, 0.29575166f, 1.34826505f, + -0.22362417f, -0.14345661f, -2.34815073f, 1.26572001f, 0.66505629f, + 1.01141500f, 1.08030057f, 0.17036134f, 0.00168786f, -0.37282917f, + 0.69206375f, 1.07367527f, -0.49708191f, 1.49504781f, 0.58224988f, + 0.96593714f, -1.07661915f, 0.25202179f, 0.25531644f, 0.42357162f, + -0.31236249f, 0.48383278f, -0.06361829f, 0.24131298f, -0.95695931f, + -0.12589653f, 0.36134180f, 3.20266032f, -0.40879184f, -0.66985190f, + 1.51674330f, 0.34072638f, 1.15076303f, -0.40199137f, 0.46223637f, + -0.48608047f, 0.99119538f, -0.22506073f, 0.30968750f, 0.64210880f, + 0.54640514f, 0.18607031f, 1.26293361f, -0.77960914f, 0.79572529f, + 1.01936150f, 2.27160740f, -1.48034489f, 0.74466604f, 0.14863680f, + 0.31102443f, -1.15673816f, -0.38609681f, -2.65026069f, -0.45524642f, + -0.74022961f, 2.74991131f, 0.00103815f, -3.03303242f, -0.41556966f, + -0.87103498f, 0.78306234f, -0.88195556f, -0.77297026f, 1.21203196f, + -1.09754920f, -0.03556008f, -0.31546223f, 0.72954375f, 0.25251788f, + 0.11378583f, 0.50921023f, 0.30301905f, -1.60631680f, 0.27152416f, + 1.17342317f, -0.70891970f, -0.08392961f, 0.92137378f, -0.10568139f, + -0.31653777f, -0.28878728f, 1.22166574f, 1.12693942f, -0.21325994f, + 0.94010323f, 1.21796405f, -0.68866694f, 2.30724216f, 0.28141466f, + 0.83481526f, -0.04885862f, 0.01675143f, 1.04355800f, -0.81050140f, + 1.51300573f, 0.53429186f, -0.56439877f, 0.38572624f, -0.05620475f, + 0.67644542f, 0.72528905f, 0.05937041f, -1.06315899f, -0.51393986f, + 0.46937627f, -0.34699562f, -0.64765716f, -1.45512629f, 0.47739139f, + -0.88228017f, -2.00791359f, 1.29929042f, 0.05482405f, -0.66725296f, + -0.54735124f, 0.09972951f, 0.76675093f, 0.98748523f, 0.08900899f, + -0.78854066f, 1.47970486f, -0.61667502f, 0.45625573f, -0.21766303f, + -0.46250847f, -0.07130960f, 0.64414692f, 0.12784545f, 0.26393634f, + 1.07720757f, -1.23938286f, 0.62483376f, -0.55001754f, -0.05358591f, + 0.07322436f, 1.12003291f, -1.00830650f, -0.20486419f, 0.76664752f, + 0.28850746f, -0.04464776f, -0.40146068f, 0.73262817f, -1.12827921f, + -0.19989438f, -1.15999687f, 1.37973154f, 0.78881019f, -0.34762639f, + 1.22088552f, -1.64088547f, 0.63218033f, 0.45736769f, 0.05502866f, + 2.22683382f, -1.78935897f, -1.49635041f, 0.83450896f, 1.67770112f, + 1.33909333f, 1.51158953f, 0.28595078f, -0.08593627f, 0.45812801f, + -0.15193029f, 1.14770603f, -0.88920450f, -1.96352005f, -1.49894583f, + 0.49629962f, 1.59872091f, 0.00903497f, 2.15563583f, 2.25149560f, + -2.01200557f, 2.56229877f, -1.38850498f, 0.73552012f, -0.39378855f, + 0.52616280f, -0.03685786f, 0.87403935f, 0.12163408f, 0.74297994f, + -0.30697080f, 0.38139752f, 0.49113834f, -0.95485127f, -0.99908817f, + 0.71716321f, 0.04000283f, -2.09645271f, 1.38789880f, 1.37198520f, + 0.82493287f, 0.17114936f, 0.53696346f, -0.19516060f, -0.50377476f, + -0.91730285f, -0.70113552f, -0.02406530f, 0.84943396f, -0.17428185f, + -1.09140801f, -0.68156958f, 1.70756388f, -1.00399911f, 0.03023832f, + -0.39023280f, -1.89737976f, 1.14469039f, -0.58337289f, -0.60037899f, + -1.17490256f, -1.56342828f, 0.48714057f, 0.62266618f, -0.15967095f, + 1.32789338f, -1.25700688f, -0.55633998f, -0.83128709f, -0.49346271f, + 1.59561753f, -0.24675299f, 0.38012561f, 0.91796309f, -0.38522810f, + -0.65509188f, 0.94100451f, -0.57324487f, 2.19070768f, 1.24058700f, + -0.75978851f, -0.40460554f, 0.79189235f, 0.70192885f, 1.93569362f, + -0.03070199f, 0.77010989f, 0.58794290f, 0.51087004f, 0.22892070f, + 0.35007235f, 1.56023848f, -0.67453802f, -0.18485607f, 0.64349502f, + -0.31489357f, -1.95834625f, 0.06560058f, 2.30394220f, 1.18194163f, + -0.88034087f, -1.05000436f, -1.05471325f, -0.98481798f, 0.49904808f, + 0.16438948f, -1.10297823f, -1.39736509f, 0.01306054f, -1.85160267f, + -0.87292641f, -0.15418227f, 0.43412164f, 1.16518164f, 0.06273691f, + 0.24659210f, -0.08267246f, 1.28885782f, 0.73575675f, -0.01019809f, + -0.08753663f, -0.61827368f, -0.40863234f, 2.12599611f, -0.53620332f, + 0.53789747f, -0.66386080f, -1.70461988f, 0.86608189f, -1.11151052f, + 0.14120635f, 1.18858743f, -0.31760478f, -0.73533046f, 0.20978074f, + -0.84074509f, 0.16523147f, -1.03362834f, 0.59721231f, 0.21318658f, + 0.23671274f, 1.75115061f, 0.25363782f, -1.32541454f, 1.13056135f, + 0.24652456f, 0.60381413f, 0.21478581f, 0.75044096f, -0.63125616f, + -1.69889998f, -0.02116571f, 1.46165359f, 1.03068244f, 0.63693464f, + 0.67795700f, 1.20033514f, -1.39205134f, -0.61743122f, 0.56549704f, + 0.65182322f, -0.74250507f, -1.61939359f, 1.14054918f, -0.45725963f, + 1.74519682f, -0.66251940f, -0.94811529f, -1.60865819f, -0.59968346f, + 0.86309159f, -1.91936195f, -1.02646923f, -1.50352538f, 0.58292735f, + 0.05320299f, 1.53582895f, 0.01069612f, 0.15226212f, -0.71840125f, + -1.36896348f, 2.14600968f, 0.96626586f, -0.52014917f, 0.41001406f, + 0.59478027f, 0.15282436f, 0.27790198f, 0.76614654f, -0.38971323f, + -0.01839927f, -1.57882118f, 0.61391610f, -0.62133092f, -0.03968323f, + -0.88467252f, -1.24041140f, 2.07306671f, -0.41776338f, 0.14537935f, + -0.91069067f, 1.67362070f, 4.72630215f, -0.07395106f, 0.46280116f, + -0.40843824f, 0.70683080f, -0.27510864f, -0.63465804f, -0.83630908f, + -0.44419941f, 0.60405648f, -0.65039170f, -1.02413189f, 1.05983019f, + 1.73366308f, 0.73343736f, -0.00895882f, -1.00826013f, 0.17323074f, + 0.73995626f, 0.24128854f, 0.94510227f, 0.25557515f, 0.02244723f, + -0.95197725f, -0.16297856f, -0.38497585f, 1.17993331f, 1.20282137f, + -1.31491220f, 0.44229278f, -0.24349044f, -0.01230415f, 1.37944865f, + 0.48554277f, -0.54510897f, -0.10793537f, 0.41121426f, -0.12889031f, + 0.26434359f, 1.27966082f, 0.64518744f, -0.15577169f, -0.99864733f, + -0.61746484f, 2.01614976f, 1.56254935f, 1.86473298f, -0.54662132f, + -0.22047071f, -0.06118120f, 0.84799510f, 0.17009684f, -1.30523121f, + 0.64000309f, 0.36299205f, -0.59620583f, 1.36372304f, -0.05389515f, + -0.93849313f, 0.98043185f, -0.39373067f, -0.84898937f, 1.32077873f, + 1.05988657f, -1.35339200f, 0.23259017f, 0.63816410f, -0.80297333f, + 0.60017115f, 1.25715804f, 1.18894124f, -0.62473553f, 1.05611980f, + 0.02335166f, 1.07509828f, 0.25873449f, -1.68341100f, 0.54547334f, + 0.79288185f, -0.93678916f, 0.19202201f, -1.48575914f, 1.08649087f, + 0.50851744f, -0.45758674f, -0.39734635f, 0.35637981f, -1.63079453f, + -0.75910008f, 0.92640859f, -0.55599529f, -0.40276715f, 0.31307653f, + 0.39907026f, -1.18830419f, 0.71051043f, 0.14157933f, -0.39581308f, + -1.64361024f, -0.06161860f, -0.25312796f, 1.10018682f, 0.56500763f, + 0.80385065f, 0.35395023f, 0.81813669f, 0.27644628f, 0.65563256f, + 1.73197234f, 0.68178749f, 0.76769936f, 0.44597456f, 0.67761195f, + 0.67635447f, -0.32315412f, 0.19330767f, -0.25557944f, 1.91693723f, + 0.38335562f, 0.07107610f, -0.57384586f, 0.79184365f, 1.87835479f, + 0.60902315f, -0.94220877f, 0.79479855f, -0.25656971f, 0.08739131f, + 0.53384244f, 1.22159266f, -0.39152125f, -1.46373534f, -0.02458516f, + 1.62825716f, -1.26112676f, 0.19967082f, -0.71114451f, 0.27929229f, + 0.65001321f, -0.11868202f, -0.55587751f, 0.78069001f, 0.57969242f, + -0.60274386f, 0.31650013f, 0.90339553f, 0.09453616f, -0.37119162f, + -1.00320566f, 0.33299938f, -0.48636708f, 0.26342997f, -0.91914523f, + 0.28682709f, -1.24780893f, -1.59254742f, 0.97176319f, 0.14744301f, + -0.53056234f, -1.73221612f, -0.67645556f, 0.98705006f, 0.79895812f, + -2.04333115f, -0.60132772f, -0.91653955f, -0.28094748f, 0.47943443f, + 0.38157779f, -0.67648011f, 1.09093642f, 1.66012859f, -0.29358891f, + -1.26773024f, 0.36747769f, -1.10141146f, 0.82383633f, -0.89772314f, + -0.47145563f, 0.63939518f, -0.64430422f, -0.48889321f, -0.37680882f, + -1.06962025f, -1.28689516f, 1.28365147f, 0.61859220f, -0.84676331f, + 1.38404000f, 1.21053445f, -0.14871351f, 1.06349385f, 1.45878971f, + -0.47362664f, 1.40707004f, 1.25224137f, 0.87364739f, 0.92858213f, + 0.00157326f, 1.45661485f, -0.27318576f, 0.15482858f, -1.07058907f, + -0.06903186f, -0.74147576f, -1.64111829f, -0.67226541f, -1.13458407f, + 1.28511488f, -0.41041154f, 2.09085560f, 0.45243183f, -0.67437285f, + 0.84960121f, -1.49300814f, -0.42961186f, -2.35021853f, 0.57255560f, + -0.73903763f, 1.37607956f, -2.44575167f, 1.25105727f, 1.38575912f, + -1.16299784f, -0.13719854f, -1.11507034f, 0.35796806f, -0.64511567f, + -0.87903833f, 0.32833642f, -0.87696886f, 0.02714214f, 0.30224666f, + -0.69118696f, -1.23500824f, 0.76678628f, -3.20508122f, -0.24704689f, + 0.49019828f, -1.20862615f, -0.03778638f, -0.07273687f, -0.11517122f, + -1.75857520f, -1.64188445f, 1.21574795f, 0.57325113f, 1.14370298f, + -1.07824504f, 1.70653832f, -0.03700557f, -0.47645858f, 0.11065386f, + -1.03143036f, -2.18094873f, -0.94403434f, -0.09335683f, -0.44817665f, + 1.39707148f, -1.21947956f, 0.56575936f, -0.69612634f, -1.12361753f, + -0.17105591f, 1.15422392f, 0.02840637f, 0.09469353f, -0.52859986f, + -2.08487725f, 1.28789508f, -0.03740775f, 0.61196613f, 1.23405397f, + 1.56595814f, -0.65800631f, 2.02985072f, -0.69446486f, -0.88443804f, + -0.23448054f, -0.43628734f, -0.45888957f, -0.21943338f, 1.78258693f, + 1.75214970f, 0.71804136f, 0.49782532f, 0.37886053f, -1.59176385f, + -1.74758542f, -0.02820176f, 0.75398153f, 1.00119829f, 0.80881971f, + -0.53365272f, -0.22720885f, 0.37476870f, 0.01005529f, -1.23421800f, + -0.13431595f, -1.01843679f, 1.87386346f, -1.68539488f, -1.04942071f, + -0.77322137f, 0.53964764f, 0.29278332f, -0.58299130f, -1.56022692f, + -0.79441273f, 0.49289709f, 0.44112054f, 1.07305002f, 0.54899335f, + 1.13781393f, 0.77809113f, 0.81795985f, 0.16576190f, 0.32552773f, + -0.20250474f, 1.46543837f, 0.12731771f, 0.21013761f, -1.34241438f, + 0.44267517f, 0.93246883f, 0.08808212f, 0.92653406f, -1.21083558f, + 0.17247954f, -0.70557106f, 0.04630012f, 0.48834828f, 0.89634645f, + 0.46683592f, -0.29553145f, 0.46363977f, -0.48971879f, -0.88603491f, + -0.12333342f, 0.37073737f, 0.92061806f, 0.54675460f, -0.14716248f, + 0.75578392f, -0.98173791f, -1.15983224f, -0.58713156f, 0.07950903f, + -0.59016788f, 0.41622928f, -0.32474482f, 0.42086437f, 0.23061797f, + 0.62596649f, -0.22615278f, -2.14721417f, 1.01685894f, -0.25976995f, + 0.00739352f, -1.31597066f, 0.39005190f, -1.09549701f, 1.68375242f, + 0.43331525f, -0.37124026f, 0.22255214f, 0.59654880f, -0.73840386f, + -1.20048976f, 0.12226126f, 0.12997478f, 1.04826224f, 0.03894836f, + -0.36289826f, 1.14466560f, -1.18198848f, -0.03713558f, 0.67677927f, + -0.42329931f, -0.89409167f, -0.77874780f, 0.58438253f, -0.35176343f, + -1.53329861f, -0.02995299f, -0.40145162f, -1.51052392f, 0.09194464f, + -1.13275242f, -0.61983156f, -0.40004560f, -0.19893464f, 0.22134103f, + -0.03903082f, 1.14894116f, -0.03476744f, 0.22520730f, -0.55851930f, + 0.76650429f, -0.57863152f, -1.34161711f, -0.31498179f, -1.19411755f, + 1.70044947f, -0.17428267f, -0.35983825f, -0.42613637f, 0.58165723f, + -0.77866900f, -1.59727287f, -0.61723864f, 1.51078022f, 0.32971445f, + -0.86441469f, 0.60552609f, 0.00208178f, -0.47096625f, -1.10479307f, + -1.21652532f, -0.08211990f, -1.43739200f, -1.31684434f, 0.43312529f, + -0.76822090f, 1.88128507f, -0.02179282f, 1.04971325f, -1.55004108f, + 1.25337446f, 0.11203052f, -1.16048300f, 1.59467411f, -1.29469275f, + 1.14019871f, 1.20021439f, 1.84098923f, 0.05004879f, 0.73529941f, + 2.05272865f, -0.13080600f, -0.08436690f, -1.17919350f, -0.66256678f, + -0.36727047f, 0.73840511f, 1.22293818f, -0.00206342f, -0.29839504f, + -0.00618613f, 1.04213119f, 1.21176076f, -0.62886089f, -0.02589060f, + 0.96009409f, -0.64478731f, -1.16516542f, 0.57528079f, 1.04294407f, + -0.09774588f, 0.45935291f, 1.03263175f, 1.00633478f, -1.82209253f, + -0.18035053f, -0.28302726f, -0.83813244f, 0.57593471f, -0.03807700f, + 1.60498738f, 0.16530658f, -1.43083501f, 2.10824299f, 0.30279446f, + -0.03961089f, -0.38900724f, 1.31272805f, -0.56575215f, 0.57970244f, + -0.48305038f, 1.34114623f, 0.21859215f, 0.66399640f, -1.52087069f, + -1.30717897f, 0.14394683f, 0.97648209f, -0.71372712f, -1.22574198f, + -0.27702177f, 0.04041927f, 0.02442212f, 2.19617033f, -0.48566443f, + 0.81463927f, 0.20383844f, 1.17562282f, -0.33829874f, -0.42141283f, + -0.96415234f, -2.39141965f, -1.04285860f, -0.23004992f, 0.41186509f, + 0.03811268f, 0.36818987f, -0.71099734f, -0.56749570f, 0.18486284f, + -0.44530040f, 2.14008284f, -0.27467576f, 1.70690107f, -1.40462613f, + 0.24697532f, -1.31629777f, -2.20674944f, -0.67868507f, -1.15767133f, + -0.64391804f, -1.79037917f, 0.58749497f, -1.58303332f, -0.69021022f, + 1.64376318f, -0.95393223f, 1.98415601f, -0.10991055f, 0.02474386f, + 0.23683345f, -0.63420391f, -0.57991928f, 0.83028817f, -0.40033704f, + 0.19212338f, 0.74640590f, 1.10264432f, -1.65286255f, 0.92683482f, + -1.42252541f, -0.74605089f, 2.14535880f, 0.12971123f, -0.47971717f, + 1.67546797f, 0.42268261f, 0.22648531f, -0.42369929f, 0.77403021f, + -1.31818616f, -0.67143595f, -0.04311426f, 1.64128351f, 0.34776631f, + -0.39353722f, -0.42765084f, 0.16170517f, -0.54488391f, -0.38428506f, + 0.42097485f, -0.55982012f, -1.74543798f, 1.53704774f, 0.43562424f, + -0.30395737f, 0.31846946f, 0.39205357f, 0.57386035f, -1.11912560f, + -1.39164317f, -1.04337609f, 0.31629622f, 1.51927638f, 0.88745505f, + -0.40445471f, 0.25783861f, 1.88646257f, 0.36509129f, -1.13266826f, + -0.45394278f, -0.48400903f, -1.22332740f, 0.38626808f, -1.10049105f, + 0.84138852f, 1.27863181f, 0.53942156f, -0.67743856f, -0.03896645f, + 1.70393491f, 0.60997570f, 0.43368068f, -0.13338457f, -0.18920666f, + -0.29583672f, -1.40738738f, 1.03876019f, 1.71253765f, 2.12821221f, + -0.96092403f, 0.93841934f, -0.79030478f, 1.36427641f, -1.39196694f, + 0.08514920f, 0.16223004f, 0.71259701f, 0.20150672f, 0.25068361f, + -0.99952722f, 1.80129099f, -1.28586197f, -0.64957166f, -0.94813949f, + -0.40161121f, 0.31977695f, 0.54932386f, -0.67757767f, 1.88086259f, + 0.92337233f, -1.64887333f, 0.44333732f, -0.19468001f, 0.12977587f, + 0.21171951f, 0.27679422f, 0.49134475f, -1.44429457f, 1.25617445f, + 0.39978400f, 0.99869555f, -1.61617446f, 1.61177349f, 0.70243025f, + -0.95748568f, -0.61795151f, -0.77302909f, 0.72967088f, 0.81964350f, + -0.71813750f, 0.90140164f, -1.45950246f, -0.79972702f, 0.40875742f, + 0.00152073f, -1.74491429f, 1.53776145f, 0.75769204f, -0.22075878f, + -0.58385569f, 2.18884754f, 0.33597681f, -1.66265559f, 1.03805876f, + -1.55245185f, -0.03582226f, -1.94542754f, -0.76081425f, -0.50471377f, + 1.35763168f, -0.39631784f, -0.17134467f, -0.82220149f, -0.41021580f, + -0.00940776f, -0.80176353f, -0.19816744f, 1.22061026f, -0.14486519f, + -0.71727395f, -0.65721530f, 0.47020102f, -0.70403302f, -0.94795334f, + 1.79884899f, 0.07779162f, -1.50615680f, 0.04140327f, -0.22001404f, + 0.63735324f, 0.79237640f, -2.25412822f, -0.52519119f, -0.87280381f, + -0.07100742f, -0.94734806f, -0.12286110f, -0.13623615f, -0.42595413f, + 0.17547913f, -0.81707209f, 0.36855817f, -1.68186557f, 0.19312963f, + -0.66249490f, -0.98283452f, -0.33314428f, 0.40918943f, 0.88268638f, + -0.05390308f, -0.22440539f, -0.15879378f, -0.34859571f, -0.01013108f, + -0.30005428f, -1.19408464f, 0.21789688f, -1.07769871f, 0.81475031f, + -0.69555300f, 2.35201311f, -0.40362412f, 0.93497628f, 1.13343573f, + 0.92343372f, 0.26987928f, 0.46123627f, 0.22577702f, 1.26289701f, + -0.45956740f, 0.55994868f, -0.58410591f, 0.13304594f, -0.25806463f, + 0.49044946f, -0.82065403f, -3.06672239f, -0.27774641f, 0.68504512f, + -0.21386372f, 1.11427057f, -0.73201770f, 0.51655543f, 1.77261138f, + 0.72081727f, 0.11116749f, 0.16637769f, -0.74987584f, 0.66579849f, + -0.75808716f, 0.20678560f, -0.67698354f, -0.82141948f, 0.61008269f, + 0.66520184f, 0.44894725f, 0.73015076f, -1.52517414f, 0.11714164f, + 1.90452611f, -1.30355322f, 0.12144456f, 1.18547559f, -0.07349755f, + -2.28061509f, 0.83522540f, 0.78438890f, 2.19334102f, 0.90305614f, + -0.59345531f, 0.77925014f, 1.32338643f, 0.14068902f, 1.19032264f, + 0.20666829f, -0.76595837f, 0.74967057f, 2.86965609f, 0.55690205f, + -1.72530472f, -0.83317834f, -0.85842621f, -0.29678273f, 1.80955839f, + -0.70496303f, 1.19106734f, -0.92985237f, -1.00617313f, -0.56049556f, + -0.29382578f, -2.04022193f, -1.95356870f, -0.42553005f, -0.33369407f, + 1.02115977f, -1.45769477f, -0.67720300f, 0.53819913f, 1.57643425f, + -0.47015440f, -1.47861958f, -0.00545934f, -0.97836047f, 0.42680529f, + 1.56110144f, -1.49487829f, -0.65198445f, 0.22720462f, 1.83036661f, + -0.47099793f, -0.09915133f, 0.14923312f, -1.16313052f, 0.67798084f, + -1.63665557f, -0.38220280f, 0.01719763f, 0.30041245f, 0.43148938f, + -0.44021657f, -1.25734651f, 0.02465564f, -1.00845659f, -0.28574651f, + 0.01367745f, 0.77253437f, -0.99399441f, 0.61445391f, 0.18343423f, + -0.50997210f, 0.41359940f, 0.77279282f, 0.83511519f, 0.27929801f, + 0.70800692f, -0.20278299f, 1.57884383f, 0.22650529f, 0.43347472f, + 0.74003208f, -0.71401161f, -0.69829476f, -1.56766701f, -0.99254119f, + 1.27301061f, 2.73726511f, 0.66089469f, -1.95778012f, -1.24642098f, + -0.63579029f, -1.63168180f, -0.66980726f, 0.81933254f, 0.61866677f, + 1.40594471f, 0.05158535f, 0.00196500f, -0.24592508f, -0.50780547f, + -0.83905292f, -0.10748957f, 0.04490763f, 0.27769178f, -0.23227681f, + 0.82108080f, 0.03562285f, 0.95483875f, -1.49897683f, 0.67809856f, + 0.35497451f, -0.44021592f, -1.67361462f, -0.88895375f, 1.44293678f, + -0.85046643f, -0.46437624f, -1.87252641f, 0.26775804f, -0.24535774f, + 0.73365933f, 0.52253938f, 0.27947086f, -0.58796054f, 0.59045380f, + 1.93476331f, -0.46775359f, 0.25238225f, -1.26601815f, -0.13324316f, + -0.71454948f, -0.21610366f, -1.49586582f, 1.04903507f, 0.22208478f, + 0.25512528f, -0.46157327f, -0.41319233f, -0.63846964f, -0.25100923f, + 0.81277549f, -0.26959971f, 0.88737756f, 1.24578953f, -0.91121447f, + -1.05756927f, 0.44390878f, 0.16672316f, -1.22941923f, 0.89547867f, + -1.50212002f, -1.69620168f, 0.53339505f, -0.23656729f, -1.69879091f, + 0.01510374f, 0.08315694f, -0.73196459f, -1.60263407f, -1.07601058f, + -0.76389569f, -1.65307498f, -0.61484390f, -0.43546933f, 0.71318507f, + -0.16273083f, 0.64122051f, -0.15406294f, 1.17673671f, -0.91240519f, + 0.71091145f, 2.40497613f, 1.26343656f, 0.71469337f, 0.20705548f, + 0.81776261f, 0.36253929f, -1.92106628f, -0.09300470f, -0.36648872f, + 1.27732766f, -0.39180157f, -0.61186749f, -1.03455031f, -0.25079829f, + -0.61479062f, -1.07094336f, 0.82218504f, 0.89934880f, 0.41308978f, + -0.59968555f, 0.37682834f, -1.77388155f, 0.00294951f, -0.66145372f, + -0.50789726f, -0.85123241f, -0.89909405f, -1.89454281f, -0.56692821f, + 1.52272677f, -0.11961794f, 0.27843913f, -0.60582250f, 1.01871169f, + -0.36098275f, -0.12242325f, -0.67375034f, -0.11204147f, -2.62773919f, + -0.95901299f, 0.14040214f, 1.32364666f, -1.35099924f, -0.11077739f, + -0.79319423f, 0.75949597f, -0.25485823f, -0.90959758f, -0.42373934f, + -1.29850340f, 0.85699379f, -1.11882365f, 0.63470817f, 0.49696380f, + -0.07983235f, -0.23903450f, -0.22618714f, -0.12117998f, -0.09442677f, + 1.55589819f, -0.11996678f, -1.72700179f, 0.54683149f, -0.40804827f, + -0.50099218f, 0.34596699f, -1.81841791f, 0.06385052f, 0.84428120f, + 0.69901514f, 1.94559097f, 0.43251973f, 0.16794942f, 1.82829034f, + 1.70959795f, 0.36130908f, -0.94608402f, -0.53498030f, 0.47781768f, + -0.24203247f, 1.25065851f, 0.51788396f, -2.09381890f, 0.72973937f, + 0.03281829f, 0.58632666f, 1.85737121f, -0.49569523f, 0.45921183f, + 1.87173629f, 0.22803484f, 1.66433418f, -1.05872321f, -1.13663685f, + 0.12397861f, -0.65112090f, 0.98152941f, 0.83739656f, -0.18783289f, + 1.84249437f, -0.90706986f, -0.80824369f, -1.23854923f, -0.86488134f, + -1.02627063f, 0.10976455f, -0.61403006f, 1.27554715f, 0.14653525f, + -0.03953953f, -0.08512071f, -1.30043304f, -0.02566035f, 0.12054887f, + 0.00282162f, 0.48921332f, -1.74398839f, 1.44554436f, -1.35854721f, + 0.69256759f, 0.34101671f, 2.50045252f, 0.49121150f, -0.27115449f, + 0.93974596f, 0.26258010f, 0.27151433f, -0.87214381f, -0.92580765f, + -1.03269923f, 0.20615758f, -0.37822601f, 0.58983004f, 0.16426525f, + 0.68218285f, 1.98158526f, 0.47492698f, 0.54224718f, 1.28722692f, + -1.76915324f, -1.11240053f, 0.77428484f, 0.27184650f, 2.22473478f, + -0.05574624f, 0.39976570f, -0.43911108f, 0.52805597f, 0.17340177f, + 1.36057591f, -0.35004014f, 1.72787797f, 0.68357420f, 1.25532615f, + -0.56752264f, 0.51840127f, -0.21237844f, -0.58821255f, -0.85278064f, + 1.90179110f, -0.67447448f, -0.36831430f, -0.22930753f, 0.98231596f, + -0.07011599f, -0.08560387f, 0.05998110f, -0.02481356f, -0.57335132f, + -0.44288307f, -0.24468307f, 0.53321087f, 1.19609559f, 0.10664973f, + 0.24379487f, 0.93687552f, 0.93615580f, 1.74319768f, -0.68310338f, + 1.32163060f, 0.61918712f, -0.76501870f, -0.54549301f, 1.74077415f, + -0.69977754f, -0.66880983f, -1.15981388f, 0.81571609f, 0.53788543f, + 0.47898352f, -0.02484704f, -1.64646924f, -0.69822907f, 0.27020717f, + 0.05027051f, 1.75149667f, 0.01548872f, 0.32615909f, 2.55151844f, + -1.29172051f, -0.36133784f, 0.98637396f, 0.14009331f, -0.50038946f, + -0.92230296f, 0.17307127f, 1.05361068f, -1.46784890f, 2.38960409f, + 1.19413340f, -1.33349669f, 1.59141159f, -0.71811068f, 1.22429430f, + 1.26947939f, 1.08177102f, -1.18138707f, -0.72775704f, 0.17282635f, + -0.40554270f, -0.40341887f, 0.46564049f, -1.02069795f, -0.07653128f, + -0.13979210f, -0.31195050f, -1.72042310f, 1.37131393f, 0.63849634f, + 0.75561279f, 1.81152904f, 0.26686314f, 1.32796574f, 0.56100166f, + 0.70058894f, -0.88962644f, -0.04360984f, -0.88249093f, 0.24311203f, + 0.50410056f, -2.22567797f, 0.94520348f, -2.12467694f, 0.47282359f, + -0.71379906f, -0.09857135f, 0.62374717f, 1.37182784f, 0.73380554f, + 0.59745449f, 2.80427694f, 0.67253572f, 1.65335357f, 1.69891667f, + 1.34585941f, -0.79989213f, 1.44980943f, -0.52013642f, -0.46971673f, + -1.50070012f, -0.25687039f, -0.56916732f, 0.71065760f, -1.31996286f, + 0.96031237f, 0.13929774f, 1.49679291f, -0.05966444f, -0.58674580f, + -0.08278833f, -0.93390942f, 0.42415768f, -1.77889526f, 0.75336021f, + -0.72699982f, -0.82880586f, 0.63955617f, 0.42771208f, -0.42366457f, + -0.91581815f, 0.94750947f, 0.43123913f, -0.99053741f, 0.70470595f, + -1.16662264f, 1.14847183f, -0.83885664f, 0.46714026f, -2.27748466f, + -1.23656678f, 0.14695056f, -0.33159894f, -0.52553117f, -0.04391259f, + -0.29630372f, 0.25949728f, 0.96991086f, -0.37714824f, -0.28251833f, + 0.16106486f, 1.38844633f, -0.18713553f, -1.30708838f, 0.48490265f, + 0.29553881f, -0.45505449f, 0.83341682f, 0.87346369f, -0.63516861f, + 0.66063565f, 0.93892503f, -2.73996735f, -0.81515318f, -0.91458052f, + 0.00978268f, 0.43472794f, -0.08090764f, 1.37249672f, 0.76722521f, + -1.19154143f, 0.22046764f, 0.34916410f, 0.51383299f, -0.56379753f, + -2.49949312f, -0.74207872f, -0.68400806f, -0.09663232f, -0.07199454f, + -1.05562651f, -0.75028551f, -0.87253797f, 0.69039482f, 0.45923674f, + -1.27515161f, -0.04555376f, -1.41501272f, -0.83773375f, -0.74807298f, + 1.36646152f, 0.06317432f, -1.32559633f, 1.89092779f, 1.24883330f, + -1.03608561f, 1.08677161f, -0.99629849f, -0.69947034f, -0.85716367f, + -0.07947286f, -0.25485426f, -0.19732477f, 1.64581251f, 1.04618108f, + 1.87186897f, -0.18198362f, -0.83807969f, 0.70462501f, -3.18930101f, + 0.74610996f, -0.60935193f, -0.49383929f, -2.88986492f, 0.51707613f, + 1.04620326f, 1.09837818f, -1.19840038f, -0.10391295f, -0.20789115f, + -1.51052022f, -0.31087330f, 0.22411564f, -1.30506921f, -1.52000105f, + -1.51593041f, 1.04321992f, 0.97611690f, 0.90424490f, 1.83324766f, + -0.08682299f, 0.47035542f, 1.70865905f, -0.31108001f, 0.04115159f, + -1.36352801f, -0.90797836f, 0.32128647f, 0.66191489f, 0.08681208f, + 0.14993365f, 0.47110486f, -0.31522670f, -0.38906571f, -0.08876022f, + -0.13106902f, 2.25685239f, -0.62211353f, -1.68553007f, -0.23707703f, + 0.69236159f, -0.46686995f, -0.27520603f, 0.26619941f, 1.48525345f, + 1.61278927f, 0.49452963f, 1.20846486f, -1.11853909f, -0.30010033f, + -0.75471467f, -1.69959772f, -0.52042168f, -0.43881389f, -1.45240712f, + 1.02122891f, 1.73639011f, -0.03813924f, -0.22239220f, 0.15797073f, + -0.64418089f, -0.60228932f, -0.83248150f, -0.02042520f, 0.38137484f, + 0.86056453f, 0.06410559f, -0.62785137f, -0.49916875f, -2.53796315f, + -0.79168582f, -0.69197005f, -0.77175534f, -0.28669405f, -0.79764080f, + 0.97218460f, -0.10351621f, -0.52759898f, 1.02840185f, 1.16363287f, + 0.08351815f, -0.61088538f, 0.59944046f, 1.54409397f, -1.39842033f, + 0.27917057f, -0.27146137f, 1.46310735f, 0.03626106f, 0.15038440f, + -0.07894899f, -1.42527366f, 1.69641745f, 1.48384345f, -0.43328866f, + -0.54252565f, -0.94416499f, 1.54436302f, -0.81367069f, -1.67925239f, + -0.17525831f, 0.27891046f, -0.69066733f, 0.89911050f, 0.11606655f, + 0.67450327f, 0.41538724f, 0.90886223f, 1.19786549f, 0.85810721f, + 1.32862210f, -0.83469814f, -1.09682298f, 0.88092703f, -0.97478902f, + -0.11664717f, -0.07929394f, -0.69581884f, -0.16928329f, -0.70731819f, + -0.40485084f, -0.28954300f, 0.52882415f, 0.38769314f, -1.38704026f, + 1.15099049f, -0.43566978f, 0.34459323f, 0.49520254f, 1.11130333f, + 0.28783718f, -0.53783375f, -1.63577271f, 1.02222812f, 0.86302060f, + 0.48346213f, 0.46627176f, -1.30133855f, -1.48477137f, 0.31219670f, + -1.21498191f, 0.89838904f, 0.87186617f, -0.39968935f, 0.34930915f, + -0.32909471f, -1.39364409f, 2.13006306f, 0.33270469f, 0.00215986f, + 0.97776711f, 0.24908836f, 1.56164885f, 0.45157790f, -1.55970144f, + 0.27677536f, 0.07662498f, -0.08262251f, -0.17658773f, 0.65820259f, + 2.01052690f, -1.71946216f, 0.84686053f, -1.23594892f, 1.40792072f, + -1.47772563f, -0.36132276f, -0.50405115f, 0.09009213f, 0.81659186f, + 1.85574234f, -0.64974433f, 0.63352364f, 1.01766217f, -1.54804432f, + -0.42570522f, -0.24763709f, 0.72822112f, -0.93733686f, 0.68087620f, + -1.40644944f, 0.48672482f, 0.09725539f, -0.64416331f, -0.95747960f, + 0.36771363f, 0.39155054f, -0.71790671f, -2.17222738f, -0.08655047f, + -0.97842115f, -0.22991380f, 0.52029115f, -1.42072022f, 0.29576331f, + 0.32391560f, -1.00823236f, 1.67909145f, 1.16841447f, -0.32307062f, + 0.15756166f, -0.97590631f, -0.39429301f, -0.03583352f, 0.17554663f, + 0.57961231f, -0.46873134f, -0.23343173f, -0.85060924f, 1.71745574f, + -0.04658702f, 0.63088381f, -0.67581934f, -1.53171062f, -1.58800113f, + -1.17987096f, -1.16737640f, -0.87544650f, -1.17138922f, 0.38979119f, + -2.39369726f, -1.34747124f, 0.58450359f, 0.87791806f, -0.04459394f, + 0.97995293f, -0.10354915f, 0.65324986f, -0.17833626f, -0.85849386f, + -0.42063358f, 0.19708554f, 0.10255250f, -0.59539181f, 0.86194044f, + 1.68610668f, 0.55275291f, -0.43127069f, -0.04218780f, -0.08466262f, + 0.31236625f, -0.92824298f, -0.09879152f, 0.32358822f, 1.04045570f, + 0.35617545f, 0.09059231f, 1.19069445f, 1.96978688f, 0.63561743f, + 0.15030998f, -0.29879019f, 0.22774190f, -1.01608860f, 1.03605175f, + 0.47804731f, -0.30450734f, -0.61382371f, 0.45390254f, -1.93547988f, + 2.01267338f, 0.52447683f, 0.18379784f, 1.11913633f, -1.24273467f, + 0.15803322f, 1.72184098f, -0.79349059f, 0.10258614f, -1.53445125f, + 0.02630571f, 0.81649125f, 0.91089755f, -1.12968338f, 1.04016411f, + 0.28999722f, 0.74863863f, -0.61388236f, 0.01665530f, 1.43592548f, + 0.68138391f, 0.11963340f, -1.26123953f, 1.36340797f, 0.25696915f, + -0.58877039f, 1.42209792f, 0.55563360f, -1.33329606f, 1.84695840f, + 0.88433737f, 1.04359078f, 0.18906727f, -0.03448994f, 1.17944050f, + 0.86783957f, 0.44934425f, -0.77892244f, -1.76232874f, -1.01689589f, + 0.78943914f, 0.92141974f, -1.00187087f, -0.13809921f, -0.90222073f, + 1.10094714f, -0.13657950f, -0.44349849f, -1.61441302f, 1.05724919f, + 1.50337231f, -0.05785890f, -0.76958144f, -0.51498759f, 0.69227600f, + -0.37975949f, 1.31949317f, 0.82049531f, 0.32868597f, -0.31557772f, + -0.75534385f, 1.27303052f, 0.43453619f, 0.11296938f, 1.18182182f, + 2.23387384f, -0.86412978f, -0.01599468f, -0.70869064f, -0.09221385f, + -1.23729551f, 0.79490280f, 0.03522846f, -0.95069039f, -1.73461652f, + 0.72329187f, 1.40385795f, -0.11585230f, -0.78033113f, 0.07491048f, + -1.12873089f, 0.18476245f, 0.57568848f, -0.28792691f, 1.35411644f, + -0.76956165f, 0.29571572f, 1.03178787f, -0.38780826f, 0.31680650f, + 0.69368076f, -1.23856580f, -0.49848995f, 0.14766994f, 1.02625990f, + 3.03858209f, -0.51030380f, 0.96796870f, 1.35078156f, -1.07729447f, + 0.84322494f, 0.54886484f, 1.31453705f, -0.45792100f, 0.31196272f, + -0.15701357f, 0.83586836f, -0.74952888f, -1.17432022f, -0.31002575f, + -1.02149463f, -0.36117774f, -1.22079086f, 0.03532525f, 0.00555908f, + -0.45891216f, 0.29636297f, -0.68272704f, 0.41257843f, 0.37988129f, + 0.01747893f, 0.82739186f, 1.52292180f, -0.79456621f, 2.20275712f, + 2.13212132f, -0.81393015f, -1.15712392f, 0.22488308f, 0.62776327f, + -0.85444915f, 0.44017896f, 0.05863331f, -0.83198178f, 0.93063420f, + -0.16121253f, 0.12382501f, -0.37826315f, 0.93118382f, 0.19507533f, + -0.58595538f, 1.46994352f, 0.13170272f, -0.70031989f, -0.12820166f, + 0.30487457f, 0.84148771f, -0.68807501f, 0.21187615f, -0.67030680f, + -1.79136002f, 0.70810199f, -1.20959783f, -0.08468831f, -0.06317700f, + 1.35527098f, -0.47018668f, -0.91693246f, 0.14818805f, -0.05405350f, + 1.16875637f, -0.17363262f, -1.61833882f, -0.32934523f, -0.38346377f, + -0.62702698f, 0.34135151f, 0.48015586f, -0.65263331f, -0.04689486f, + 0.01156854f, 0.37580970f, -0.16174591f, 0.59627324f, 0.24351901f, + -0.87983090f, 1.57049024f, 1.25836349f, -0.41464049f, -0.62279183f, + 0.09693756f, -0.23850618f, -0.49007827f, 0.22298151f, 0.10914832f, + -0.35192192f, -1.27221346f, 1.10203624f, -0.86399704f, -0.47319838f, + -0.77105570f, -1.68624854f, 0.81198281f, 0.82534081f, 0.75654501f, + 1.47631240f, -0.61000234f, -0.58933264f, 0.54822850f, -1.22829592f, + 0.11107657f, 0.56449169f, 1.50693524f, -0.59280968f, -0.64286685f, + -0.20120731f, 0.27184448f, 1.55500400f, -0.48919386f, 1.04044867f, + -0.87048137f, -0.40569979f, 0.21908638f, -0.51829034f, -1.48748124f, + 0.02990401f, 1.83462536f, 0.29885170f, 1.32370698f, -1.30129600f, + 2.43271399f, 0.22967771f, -1.13014007f, 0.95529765f, -0.83325785f, + 0.43633386f, 0.85774118f, 0.78160155f, 0.58583075f, 1.18906367f, + -1.54354560f, -0.68320692f, 0.01900371f, -0.79777133f, 0.12851712f, + 1.10176420f, 0.79418170f, -1.41154039f, 0.36929929f, 1.12176800f, + 1.23849642f, -0.89377707f, 1.01390159f, -0.50889206f, -1.12554002f, + 0.17932732f, 0.48949540f, -0.54235244f, -0.28146735f, -1.39125514f, + 0.13309635f, -1.12864995f, -1.29901242f, -0.04266220f, -1.98028529f, + -1.34869373f, 0.00038156f, -0.92473024f, 1.48010647f, -0.02754467f, + -0.26030368f, 0.93083733f, 0.27946711f, 0.64052200f, -0.04220961f, + 1.25002527f, -1.07923257f, 0.19048618f, 0.08900311f, -0.40813437f, + -0.73068553f, 0.52122378f, 0.68990833f, -0.38749605f, -1.09269309f, + -1.63480806f, 1.01789618f, -0.61596102f, 0.81049860f, 1.30838764f, + -1.49213874f, -0.77916288f, -0.72660202f, -0.92013240f, -1.61726642f, + -0.11527207f, 0.35143322f, -1.11646879f, -1.45525432f, -0.82892823f, + 0.15512508f, 1.01891017f, 1.40162635f, 1.02494884f, 0.33882582f, + -0.78747398f, -0.26009330f, -0.38519114f, 0.79247451f, 0.02065756f, + -0.48030257f, 1.01167107f, -1.74057114f, -0.84549171f, -0.15337363f, + -1.92544484f, 1.01270044f, 0.00762185f, -0.16405612f, 1.61778915f, + 0.93316060f, -0.68960994f, -1.13214970f, -0.94695878f, -0.28418848f, + 0.17102109f, -0.08787476f, -1.83799696f, -0.13761258f, -0.18652774f, + 1.46456254f, 0.34169790f, -0.40697145f, 1.49663997f, -0.99555492f, + -0.67775637f, -0.51951116f, 1.35157657f, -0.27099034f, -0.46987835f, + 2.28101230f, 0.59104478f, 0.75010139f, 1.01472175f, 0.25741309f, + -0.56074983f, 1.12267506f, 0.35336846f, 0.61733276f, -1.63976014f, + -0.17700450f, -0.25093642f, -0.75599891f, 2.10956192f, 0.95155340f, + 0.72049862f, 0.50492924f, 0.62067389f, 2.08688402f, -0.73604703f, + 0.63383341f, -0.53528428f, -2.11538506f, -0.98173052f, 0.59560484f, + -0.26205051f, -0.91948050f, 0.00593397f, -0.11734286f, -1.41261208f, + -0.83611172f, -0.27682739f, -0.20619918f, -0.36557615f, 0.77194935f, + 1.67695415f, -1.39265156f, 0.04892010f, -0.37773246f, 0.16124558f, + -0.18348448f, -1.38248885f, 0.58459854f, 0.65064198f, 1.11349559f, + 0.36708066f, -0.15471332f, 0.14208725f, -2.06860566f, 0.29629150f, + 0.93084633f, -0.47215626f, 0.60208917f, 0.95415461f, 1.03390312f, + -0.03639749f, -0.23988228f, 1.27037442f, 0.95133096f, 0.33187470f, + -0.34527761f, 0.22134073f, 1.01799667f, -0.81475645f, -1.18869019f, + 0.23314142f, 0.25180560f, -1.23762786f, 1.25283313f, 0.16980635f, + 0.40740708f, 0.59256923f, 0.16274920f, -0.69713289f, -0.16444311f, + -2.41602516f, 0.37952334f, -0.05604568f, -0.23772651f, 0.20581599f, + -0.54303211f, 1.71877348f, 0.83602583f, -0.32586128f, 0.73609394f, + -1.73640239f, 0.07249248f, 0.31248692f, 1.77627432f, 0.97660398f, + -0.42095289f, -0.18750280f, -0.84246057f, 0.29762223f, 1.87054563f, + -1.46980762f, -0.45306337f, 1.52366042f, 1.39061129f, -0.04980387f, + -0.55382830f, -0.96987218f, -0.06910808f, -0.41276473f, -0.83891344f, + -0.92597574f, 0.60252470f, 0.21938549f, -0.04451685f, -1.00330937f, + -0.36955237f, -1.52876902f, 0.27296364f, -1.96721256f, 0.05291027f, + -0.91540521f, 0.48990685f, -1.99560380f, -0.68551093f, -0.14532298f, + -1.56881595f, -0.08319287f, 0.31003201f, -1.42829597f, -0.61810297f, + -0.03581250f, 0.77747720f, 1.25297558f, -1.36239243f, -1.13274276f, + -0.35045877f, -2.34157228f, 0.04515179f, -0.83044821f, 1.81353962f, + -1.36855912f, 0.39704823f, 0.16665934f, -0.16654585f, 1.17806077f, + 1.00086153f, -1.25474250f, -1.46876431f, 1.18021631f, -0.32257929f, + 2.12062597f, 0.86819613f, -1.18048275f, -1.69747460f, -0.74092305f, + 0.05086798f, 1.15339577f, 1.32972670f, 0.27247882f, 0.98499072f, + 2.35597157f, 0.30179837f, -0.66633248f, 0.13794266f, -0.22753908f, + -0.22868259f, -1.81792033f, 0.50151759f, -0.79408127f, -1.05343878f, + 0.45727381f, 0.84800923f, -1.73605800f, -0.02032863f, 1.82778001f, + 1.41025102f, -0.81715560f, 0.25888795f, -0.25075480f, 0.66256499f, + 0.11993053f, 1.81336939f, -0.06345166f, -1.49658346f, 0.07531686f, + 0.96972889f, 0.87405980f, 0.75830793f, -0.13497087f, -2.45855975f, + -0.65984958f, 0.93919373f, -0.97305542f, 0.73477978f, 1.04337513f, + -1.22712576f, -0.46385625f, -1.20876372f, -0.82760453f, 0.01455977f, + -1.05089867f, -0.02801843f, 0.60899758f, -0.82052249f, -1.48932517f, + -0.98073828f, -0.19311285f, -0.25602359f, 0.50351876f, -1.24557400f, + -0.82138073f, -1.45966852f, 0.44991320f, -0.75550151f, -0.98550314f, + -1.21418869f, -1.15771639f, -1.72192061f, -0.39616469f, -0.55566746f, + -1.31880891f, -0.08843257f, 1.00422776f, 0.35846478f, 0.46060917f, + 0.77326930f, 1.60129988f, -1.85124147f, -0.30582917f, 1.30227256f, + 1.81890345f, -0.44084981f, 0.25315762f, 0.70259613f, -0.94882858f, + 1.97040296f, 0.71473581f, -0.68193883f, -0.36290962f, 1.16348684f, + 0.15418798f, 1.07806778f, 0.40554729f, 0.10280909f, -1.06474805f, + 0.64398485f, -0.63568884f, -0.06108581f, -1.03290677f, 1.02834034f, + 1.15284693f, 0.14046004f, 1.86630619f, 0.46804786f, -0.68397558f, + 1.60733378f, -1.64890087f, -1.03819239f, -1.19212389f, -0.78382361f, + 0.03925850f, 1.52259934f, 0.09540676f, -0.21220762f, 0.55955195f, + -0.39845437f, -2.14541650f, 0.49337825f, -0.68574250f, 0.74040270f, + 0.50783634f, -1.60461199f, -1.26806450f, -0.12652303f, -0.83992827f, + -0.15524681f, 0.40098447f, 0.23392735f, -0.23262636f, 0.06525709f, + -0.35994548f, -1.08432877f, -0.21395946f, -0.78357452f, -0.57157278f, + 0.71407390f, 0.86596155f, -1.13723528f, 0.13460183f, -1.20881450f, + 0.71018457f, 0.68943661f, -0.70428050f, 0.64600736f, 0.01990297f, + -0.10575775f, -0.80263519f, 0.10618331f, 0.08865548f, 1.51651669f, + 0.60851854f, 1.15161908f, 1.04919207f, 1.18359745f, -0.04352076f, + -0.83643389f, -0.07922365f, 0.10597949f, -1.34984851f, -1.91319740f, + 0.71585363f, -2.10845160f, 0.64385056f, -0.54551518f, -1.02039802f, + -1.62510490f, 1.65401149f, -0.42711899f, 0.07970079f, -0.21404363f, + 0.30498922f, 1.07942021f, 0.63995659f, -1.82114816f, 0.56396323f, + 1.07084870f, -2.00350380f, 0.53339815f, 0.18500003f, 1.15034151f, + -0.21436051f, -0.99986565f, -0.58812016f, -0.07247020f, 0.78910017f, + 0.48839527f, 0.98795873f, 0.10357288f, -0.05604928f, 0.38977858f, + 0.73745090f, 1.40838420f, 0.25967824f, 0.23588051f, -0.03451392f, + 1.04897523f, -1.77121758f, 2.35625434f, -0.67086869f, -0.84005541f, + -0.85940343f, -1.04449213f, -0.65917015f, -0.78713167f, -0.95910054f, + 0.38597879f, -0.31879017f, -0.86260867f, -1.08593106f, 0.02802678f, + 0.99484950f, -0.55113328f, 2.60936737f, -0.03388772f, -0.47583574f, + -0.14021793f, 0.99019170f, -1.22431207f, 0.78734446f, -1.77037835f, + 0.15018673f, 0.36423206f, 1.36447549f, -1.61007094f, 0.51875496f, + -1.60788095f, -1.73557448f, -0.41414359f, -0.93710536f, 0.38715765f, + 0.04243837f, -1.59682858f, -1.10728157f, 1.88292623f, -1.01428258f, + 0.01074958f, -1.88169158f, -0.31616244f, 0.45334938f, 1.12449574f, + -1.16699445f, -1.59505820f, 0.04126552f, -0.89016622f, 0.45838884f, + 0.71463561f, 0.14563711f, 0.30694655f, 0.67193079f, 0.61429602f, + 1.00201404f, -0.49295208f, 0.05997690f, 0.99491668f, -0.73801446f, + -1.17185295f, 0.94778723f, 0.36106884f, -0.43561545f, 0.04102699f, + 0.52626407f, 0.08442099f, -1.57626402f, 1.56855237f, -1.65396678f, + 1.74014664f, -0.38219589f, 0.39305371f, -0.31705827f, -1.15742850f, + 0.11669596f, 0.54043210f, -0.52270615f, -0.13375773f, 0.68094701f, + -1.84134769f, -1.49383473f, 0.14632171f, -0.54607725f, -1.20867658f, + -1.28439069f, -1.81734920f, 1.54257309f, 0.78347659f, -0.24049839f, + 1.69973648f, 0.99825776f, 0.99971974f, -0.26055810f, 0.34143049f, + -0.44862366f, 0.11253342f, -0.60932243f, 0.70383030f, -1.87318194f, + 0.21953633f, 0.82791799f, 1.64545465f, -0.42693698f, -0.64897031f, + -0.97996652f, -1.06616282f, 0.52939081f, -0.12541170f, -0.57480675f, + 0.73600835f, 0.35711968f, -0.03528263f, 0.79997194f, 0.55742902f, + -0.28909785f, 0.64331138f, -1.79893720f, 1.01572442f, 0.27111965f, + -0.51778597f, 0.12906317f, 0.76148927f, 1.51315522f, 0.41101140f, + 0.38008851f, 0.66759896f, -0.13804778f, 0.64854795f, 1.73474562f, + 0.75999504f, -0.73411214f, -0.05406699f, 1.35664344f, -0.25298578f, + -0.12696666f, -0.42628938f, 0.61129904f, 1.55259824f, -0.05820796f, + -0.38598019f, -0.87325627f, -0.55066222f, -1.24557889f, -0.26509118f, + -0.32103062f, 1.14031804f, -0.75985742f, 0.70659167f, -1.15016067f, + 1.24906838f, 0.90396994f, -0.16241251f, 0.43682271f, -1.42695689f, + 0.47134697f, -1.66143429f, 0.08698819f, -1.00775325f, -2.24129725f, + -1.04226267f, -0.98537570f, -0.89938259f, -1.80710697f, -1.22866321f, + 0.78125423f, 1.55150509f, 0.46235040f, 0.18444096f, 0.19313288f, + -2.20686269f, -0.40341458f, 0.50321484f, 0.47339424f, -0.81383848f, + -0.21972439f, 0.66612029f, 0.60239881f, 1.20443010f, 0.70015103f, + 0.30632916f, 0.01489905f, 0.68129027f, -0.89645082f, -2.68969011f, + -0.96684915f, 1.66421318f, 0.74333072f, -0.78321886f, 1.60063362f, + -1.27524030f, -1.95856726f, 0.47504124f, 0.15398432f, -0.20796098f, + -0.13449343f, 0.93458968f, 1.60390890f, 0.21798505f, -0.27035928f, + -1.23248971f, -1.25361061f, 1.34666133f, 1.07233441f, 0.88799530f, + -1.23687923f, -0.40781614f, -0.11916534f, -0.88050151f, -0.66422415f, + -2.61471510f, 0.78276747f, 2.42323995f, -1.70715427f, 0.71550035f, + -0.60298312f, 0.70491880f, 0.46175584f, 0.80827898f, -0.45108104f, + -0.98219043f, -1.72823501f, 1.73190725f, 0.53906441f, -1.50445580f, + -0.59250867f, -0.07239901f, 0.44743437f, -0.13740127f, 1.69935930f, + -1.00480616f, -0.58191377f, 0.39853972f, -0.60960841f, -0.45473522f, + -0.76396072f, -0.31872150f, 1.74509728f, -0.59950751f, 0.89810580f, + -0.81400329f, 1.14280319f, 1.11165059f, -1.31295311f, -1.60784578f, + -0.87506992f, -1.13461006f, -2.09486437f, -0.16449419f, -0.37728927f, + 0.47595578f, -0.55342919f, -0.17574213f, 2.21499181f, 1.14331865f, + -0.14938518f, 0.18935619f, -0.33802557f, 0.52538890f, 0.82673949f, + 1.16562462f, 1.24713838f, 0.98890215f, -0.64991701f, 1.49886703f, + 1.97769642f, 0.08059916f, -1.60925281f, -1.23822486f, -1.40829837f, + 0.51331180f, -0.29928651f, -1.04348791f, -0.39911583f, 0.69380492f, + 1.54516888f, 1.22791195f, 2.25008130f, 1.33348894f, -0.21775827f, + -0.71937007f, 0.54982573f, 1.70691478f, 0.32459491f, -0.57187974f, + -0.21614684f, 1.08274269f, 0.41384646f, 0.24497485f, -1.43703413f, + 0.89616930f, 0.82032162f, -0.24598582f, 0.84271127f, -0.81894702f, + -0.01828136f, 1.70397091f, 0.39505738f, -0.51221430f, -0.87979966f, + 0.10795479f, 0.45194778f, -0.76008922f, 1.23394477f, -0.56798172f, + 1.06459570f, -0.44333413f, -2.40399075f, -0.37267187f, 1.42946172f, + 0.95734519f, 1.86127949f, -0.15217264f, 1.68742633f, 1.97638428f, + -0.44211119f, -0.98393327f, -0.54173928f, -1.72017395f, 0.74697793f, + -1.77827263f, -1.92299354f, -0.17189410f, -0.48633271f, -2.21230388f, + -0.45906609f, -0.53493047f, 0.37253976f, -0.56951141f, 0.07728028f, + 0.03530006f, -1.18123293f, 1.94158125f, -1.55930352f, 0.69334733f, + -1.95163214f, -0.95800400f, -0.01804711f, -0.56747472f, -0.99099451f, + -1.52853060f, -0.98279524f, -1.67307866f, 0.96121490f, 0.35654056f, + 1.74034202f, -1.44633865f, -0.27781928f, 1.79457986f, -0.41029963f, + -0.76871634f, 0.36555341f, -0.77664107f, 0.19535238f, -0.76185411f, + -0.19828433f, -0.88820636f, 0.63885397f, 0.11346363f, -2.50265074f, + 0.16319332f, -1.01288569f, 1.86605489f, 0.89761645f, 1.11795115f, + -0.00714116f, -0.89034635f, -0.76447034f, -0.18822117f, -0.48340848f, + -0.99788517f, 1.02172959f, -0.39395007f, 0.72566581f, -0.81438208f, + -0.71715081f, 0.96243578f, -1.36424279f, -1.13870537f, 1.17602491f, + 0.16320205f, 0.71959788f, 1.66669416f, 0.55690295f, -0.28912008f, + -1.19219172f, 0.23308393f, -0.37963116f, 0.45347008f, -0.42606446f, + 1.30938649f, 1.25128853f, 0.57649273f, 0.34440875f, -0.23893952f, + -1.06604803f, 0.31336102f, 0.75727910f, 0.46772480f, -0.37650385f, + -0.06036821f, 1.03686309f, 0.46158856f, -1.81028461f, 1.43393028f, + 0.85494965f, -2.34685564f, -0.17571987f, -0.45592231f, -1.31190526f, + 1.73194158f, -0.11856517f, 0.07041293f, 0.25689471f, -0.56000596f, + 2.06649089f, 0.38954756f, 1.36627376f, 0.13905638f, 0.77370811f, + 0.43944249f, -0.08798827f, 0.07245751f, -1.30234015f, 0.29710820f, + 0.74389762f, 0.11971968f, -0.07381748f, 1.32652700f, 1.34079397f}); + + auto input2 = NDArrayFactory::create( + 'c', {3, 4, 4, 5}, + {0.98114507f, 0.96400015f, 0.58669623f, 0.60073098f, 0.75425418f, + 0.44258752f, 0.76373084f, 0.96593234f, 0.34067846f, 0.57962620f, + 0.77517051f, 0.97472977f, 0.79237527f, 0.68690428f, 0.21719366f, + 0.79959206f, 0.84814187f, 0.22496814f, 0.08646965f, 0.31110474f, + 0.79813162f, 0.19661444f, 0.57760099f, 0.72138960f, 0.15244268f, + 0.87687051f, 0.11130344f, 0.01087698f, 0.34817841f, 0.54992017f, + 0.23443850f, 0.31725614f, 0.59755220f, 0.20364695f, 0.00531392f, + 0.23403114f, 0.07442912f, 0.83707647f, 0.89291743f, 0.09044587f, + 0.69041462f, 0.29904183f, 0.61904680f, 0.85306847f, 0.34467042f, + 0.95839152f, 0.54517124f, 0.29640937f, 0.94855959f, 0.95970016f, + 0.94045145f, 0.95510301f, 0.34666505f, 0.34717010f, 0.69245678f, + 0.71669175f, 0.59043738f, 0.64924132f, 0.06033522f, 0.60185199f, + 0.04690073f, 0.59241154f, 0.40229547f, 0.23002481f, 0.45161195f, + 0.73743778f, 0.93209113f, 0.37294358f, 0.50177744f, 0.15072501f, + 0.26146917f, 0.05252146f, 0.04758931f, 0.76448288f, 0.85149045f, + 0.08840467f, 0.07692576f, 0.33180160f, 0.27241259f, 0.74834620f, + 0.56453640f, 0.23057286f, 0.68429752f, 0.11961551f, 0.39045977f, + 0.44356094f, 0.77018807f, 0.07984410f, 0.47926806f, 0.26165759f, + 0.18606064f, 0.89972877f, 0.17962874f, 0.47273120f, 0.64641705f, + 0.61890443f, 0.58730015f, 0.25937832f, 0.35231561f, 0.10243882f, + 0.17459193f, 0.95906995f, 0.09227025f, 0.30003223f, 0.41601210f, + 0.38269713f, 0.84799751f, 0.59295173f, 0.76277990f, 0.68910424f, + 0.37672606f, 0.40675461f, 0.94346058f, 0.91438505f, 0.84728183f, + 0.64367667f, 0.74899979f, 0.60570691f, 0.16417363f, 0.68852426f, + 0.85486889f, 0.22585792f, 0.86953176f, 0.07465519f, 0.93096301f, + 0.38008822f, 0.38752587f, 0.44004038f, 0.13170612f, 0.94541045f, + 0.89349973f, 0.69245307f, 0.94978877f, 0.98776658f, 0.79445884f, + 0.30607409f, 0.58264961f, 0.37980538f, 0.41810784f, 0.48903038f, + 0.51615888f, 0.57682794f, 0.82481897f, 0.78341080f, 0.48446465f, + 0.17447931f, 0.71125424f, 0.30263851f, 0.70675352f, 0.03215584f, + 0.92381065f, 0.22343694f, 0.08851149f, 0.91402490f, 0.70074717f, + 0.30912192f, 0.37723206f, 0.97579397f, 0.23554587f, 0.95939133f, + 0.41565709f, 0.01741416f, 0.58362787f, 0.22106662f, 0.89065537f, + 0.31900249f, 0.41280911f, 0.67947610f, 0.04545590f, 0.15352812f, + 0.85412524f, 0.84933222f, 0.80000225f, 0.93147073f, 0.70094105f, + 0.69269875f, 0.95282194f, 0.65913582f, 0.79186874f, 0.59855248f, + 0.39707430f, 0.95126239f, 0.15618217f, 0.33446689f, 0.98123758f, + 0.84770758f, 0.98081012f, 0.54427413f, 0.18728519f, 0.89792955f, + 0.53360126f, 0.72812986f, 0.13307744f, 0.51217443f, 0.66708084f, + 0.29416915f, 0.31298995f, 0.39155037f, 0.29288291f, 0.87063305f, + 0.61759154f, 0.73723332f, 0.37167635f, 0.82122716f, 0.22937430f, + 0.76570536f, 0.47911792f, 0.02826214f, 0.94277323f, 0.59945469f, + 0.19042060f, 0.68173155f, 0.82771295f, 0.95649538f, 0.40833101f, + 0.90838542f, 0.55245881f, 0.49011012f, 0.36773444f, 0.34513527f, + 0.42050683f, 0.16113964f, 0.30969388f, 0.27174174f, 0.12117655f, + 0.35270175f, 0.81967867f, 0.63723136f, 0.84309389f, 0.71822576f, + 0.84883484f, 0.32306117f, 0.08176457f, 0.56175486f, 0.34892198f, + 0.09306929f, 0.85437582f, 0.13925577f, 0.48629188f, 0.29923539f}); + auto exp = NDArrayFactory::create( + 'c', {3, 8, 8, 16}, + {5.98743296f, -2.83037376f, -0.87943113f, 1.41339970f, 1.32433391f, + -1.20299149f, -0.02893090f, 2.05326009f, 1.19417048f, 5.58212376f, + 3.28139353f, 1.19237995f, -1.09431255f, -2.55264497f, 3.11014652f, + 6.81296825f, -2.09029293f, -4.32068443f, -0.52808392f, -1.97968531f, + -0.18673831f, 0.84605980f, 4.55825520f, 2.71503139f, 0.15210046f, + 0.85310984f, -3.82062817f, 2.76470995f, 3.69004202f, -1.45017099f, + -2.59361267f, -1.35094655f, 7.24145126f, -5.25432396f, 0.19920218f, + -4.30596399f, 1.35318923f, -3.88142037f, 3.67493343f, 2.25931478f, + 2.87630725f, 1.66349852f, 6.21347952f, 0.94105923f, -1.61742055f, + -2.35699606f, 0.12850338f, 1.79141688f, -2.09535933f, -6.35418081f, + -0.06303531f, -4.38615131f, 0.48237842f, 0.26528549f, 3.38231516f, + 3.76315165f, -0.40254810f, -0.23716694f, -6.13381910f, -0.41950428f, + -0.89680839f, -1.46491277f, -1.98541689f, -0.99357355f, 5.58237648f, + -2.38937521f, -0.00872564f, -2.37138414f, 4.91117287f, -4.51916361f, + 0.97943687f, 2.91052818f, -2.50362611f, 1.70252812f, 5.04137802f, + 3.57108784f, -1.87532270f, -3.66677809f, -2.38861251f, 5.55765152f, + -7.27571774f, -1.68887305f, -0.72266489f, -4.42809057f, -0.92118186f, + 1.02381468f, 4.44284725f, 5.17150497f, -0.42438728f, 2.02693963f, + -1.36484981f, -1.47912180f, 0.26649538f, -0.02091765f, -2.86906910f, + -3.03046989f, 1.35122132f, -3.21707630f, 2.21112418f, 0.24121630f, + 3.96940088f, -7.66105747f, 2.76352382f, -0.99061489f, -2.16720009f, + -1.63170409f, 1.12701774f, -1.02415371f, -0.90435314f, -1.51372027f, + -0.76884907f, 0.39066136f, -0.89562428f, -2.03204703f, 1.28074932f, + -2.14551091f, -2.36843777f, 0.46580017f, 0.75451565f, -0.00336730f, + -1.06597757f, 3.27195978f, -0.41307712f, -0.10376054f, -1.34102952f, + -2.22901654f, 2.31929803f, 1.40851438f, -2.23774385f, 0.20417206f, + -1.12153268f, -0.13188094f, -3.96649432f, 2.10269976f, 0.49845099f, + 6.18937683f, -0.51783508f, -0.48048639f, -1.92970264f, 3.16670656f, + 1.13355756f, -0.07890664f, 1.31536257f, -0.43924797f, -0.04562932f, + -0.87974954f, 0.75411212f, -2.39745235f, -3.97132111f, 0.37202546f, + -2.40399146f, -1.50796390f, -3.08302689f, 0.23075986f, -0.94316757f, + 1.34948587f, 0.58591264f, 2.18529797f, 7.97652435f, 2.32798409f, + -4.09404373f, 0.89634895f, 0.77697754f, -0.65091681f, -7.05506849f, + 5.86194515f, 2.51394033f, 4.69959354f, 0.20835471f, 3.18049693f, + -1.29682434f, 3.70832396f, -0.48123091f, -1.67904007f, -1.35418940f, + 1.58435583f, -1.13851106f, -1.19225955f, 0.59713769f, -5.80462933f, + -7.45143986f, -1.08658695f, 1.03244078f, -1.75307107f, -7.07100582f, + 3.85825157f, 1.62127817f, 2.32572675f, 0.56171900f, -0.80591971f, + 3.98835945f, 0.15742642f, -2.97832179f, 0.13821673f, -0.72556758f, + -0.84936106f, -7.28444147f, 3.94134307f, 0.80779338f, 7.47784615f, + 8.23335075f, 4.80595016f, -4.89574575f, 4.03362942f, -6.67522192f, + -4.55204487f, 2.12511182f, -2.70781207f, -1.57226098f, -3.08408356f, + -0.30812448f, -5.32870674f, -5.13238287f, 0.49605465f, -0.55042171f, + 0.46324944f, -3.83545256f, -0.12562510f, -0.20978995f, -0.13068712f, + -1.92144060f, -1.68787408f, 5.45581436f, -0.79583496f, -2.38866687f, + -3.90546346f, -0.47028148f, -0.14319679f, -3.37016582f, 2.00905991f, + -1.21345615f, 1.81376505f, 7.73004007f, 0.74310112f, -4.64536428f, + 3.78111577f, -9.05182457f, -0.10674095f, 1.53476238f, 0.63345337f, + -0.40907967f, -1.44729769f, -1.87145400f, -2.46623540f, 1.07472968f, + 0.77390999f, -3.93438888f, 4.49174690f, -0.96686655f, 1.92278123f, + 0.30049133f, -0.02388665f, -1.99777114f, -3.23885751f, 5.87784004f, + 2.13776040f, 3.56758308f, -3.37774134f, -3.67526293f, 1.63700044f, + -1.69959962f, -0.99112594f, 6.03103638f, 1.67399430f, -1.28699589f, + 7.16759014f, 12.63490295f, 3.62937450f, -4.75982571f, 2.17861104f, + -2.03065681f, 4.30207729f, -0.46797156f, -2.96022511f, -6.02702332f, + 3.09229851f, -1.39771092f, -0.03471333f, 3.22175527f, 5.63565636f, + 1.78195477f, -0.63545251f, -3.99497652f, 1.46043062f, 4.60050488f, + -2.96651959f, -2.03159475f, -1.52386189f, -0.15129802f, -3.90390921f, + -0.63852370f, 0.79210538f, 2.35288715f, -5.55609035f, 5.36427498f, + -0.60248077f, -0.26181316f, 5.04884720f, 8.53192806f, 5.05080223f, + -6.56371737f, 1.52260923f, -7.13623667f, 6.49414349f, 2.33445597f, + -4.11490965f, -6.44347477f, -0.47079402f, -0.63467920f, 2.60399365f, + 1.05958164f, 3.66901422f, -1.05657935f, 1.88611507f, -6.37475634f, + 2.01480770f, 3.36020517f, -5.11001921f, -0.46132171f, 2.16525555f, + 4.21938848f, -2.08346295f, 2.86168146f, 1.26987600f, 6.76066971f, + -7.84916353f, 4.11700916f, 0.47985530f, -4.60113716f, 7.42062473f, + 6.37472820f, 4.37820530f, -7.12197018f, 0.01357239f, -7.90392113f, + 8.32131577f, -0.87593079f, -0.16994858f, -5.86345863f, -0.20697471f, + -1.37845206f, 1.63819647f, 1.59720242f, -0.74357712f, -1.88725603f, + -1.98357940f, -8.57950306f, -4.10104513f, 3.57231879f, -2.89855957f, + -0.11263305f, 2.78033924f, 1.53078973f, -2.93089223f, 0.73189604f, + 3.20563078f, 3.92601013f, -5.21916151f, 0.89163935f, -0.42978728f, + -6.70888853f, 4.56477976f, 1.20105875f, 3.83393812f, -6.27205181f, + 4.05993128f, -7.35513067f, 1.60660768f, -1.21052051f, 1.58191252f, + -1.37899971f, -1.20117283f, 2.93301678f, 1.06302834f, 1.38993621f, + -1.66884089f, -3.34452581f, 1.04498529f, -4.10412455f, -4.03310585f, + 1.61513603f, -1.09388447f, 2.11451387f, -0.94192362f, -0.23287666f, + 5.88265705f, -0.83010495f, -2.15317154f, -0.60276151f, -1.49265075f, + 3.93397975f, 5.45194483f, 1.45161700f, -2.57401872f, -5.59288931f, + 4.29170895f, 1.87151814f, 0.08362055f, -0.28767288f, 1.17675185f, + 0.85266006f, 1.30549634f, -5.60830832f, 0.19398519f, -0.83982587f, + 1.75940764f, -5.46077394f, 1.64495635f, 0.17102760f, -0.54459631f, + -2.21975255f, -0.37443402f, -2.08474159f, 1.85959935f, 11.19680309f, + -0.18611598f, -2.59765387f, 3.06330776f, -1.52183700f, -4.88415241f, + -0.75097847f, 2.58201051f, 7.40885210f, 3.58994508f, 1.62457407f, + 3.12514591f, -4.36833286f, 1.39830995f, 3.61003447f, -0.63837433f, + -3.62661815f, 3.78898096f, 2.92802262f, 5.87374496f, -4.38554621f, + -2.53411579f, -2.87311554f, -1.31391978f, -4.26736879f, 3.45099425f, + 1.58769250f, 1.73341393f, -1.08842182f, 2.27120280f, -1.78938174f, + -2.29940319f, 7.07046986f, 0.51426595f, -6.22928905f, 5.28968811f, + 2.31827855f, -4.20915890f, -1.27249205f, 5.92120600f, 3.19458675f, + 7.09252501f, 3.96577907f, 6.41484213f, -4.66009521f, 10.00181389f, + 0.51108456f, -4.62243366f, -5.18351841f, 2.12961674f, 5.10694027f, + 7.29412317f, 0.15912467f, -3.38902974f, -4.01918602f, -2.17383957f, + 0.13118666f, 0.27872476f, -0.92317247f, 3.51440644f, 1.84171486f, + 1.03378081f, 1.30569839f, -2.09583759f, 9.03952980f, -0.55187917f, + -2.04549074f, 1.08294606f, -2.65263700f, -2.93977118f, 1.88909876f, + 0.96043622f, 1.76579499f, 3.14314699f, 5.86394691f, 7.36944389f, + -7.04524136f, 6.68673229f, -5.52591467f, -2.19745898f, -4.32036924f, + 0.52971321f, 2.26268244f, 6.91575766f, -0.94590527f, -3.98923349f, + -0.12266219f, 0.24294075f, -1.07783222f, 1.87989080f, -3.57109427f, + 1.61553633f, 0.42486978f, 0.75852054f, -6.19481468f, -3.80570698f, + 2.39946675f, -1.93851781f, -5.42234039f, -6.34092760f, -2.52374983f, + -1.85044456f, 3.92693520f, 0.40042299f, 4.69742584f, 5.40483189f, + -1.02398944f, 8.89605045f, 0.64680403f, 0.89943957f, 0.76993859f, + -1.88244629f, 1.90714884f, 3.10836840f, -0.17064989f, 0.84892416f, + -6.94988108f, 1.92141032f, -1.36458397f, 6.39284658f, 0.45201308f, + 2.58823442f, 6.33375788f, -4.76916075f, -8.45738983f, -0.48962492f, + 2.40652561f, 4.56602001f, -3.34420681f, 1.86862195f, -7.01420689f, + -6.94657421f, -2.47419310f, -4.61693668f, -0.18822384f, -0.36949772f, + 2.01374269f, 4.11018658f, -5.11564064f, 8.04294395f, 2.88567662f, + -2.87645102f, -1.23238611f, -5.91409397f, -0.62205851f, 1.38689423f, + -0.01120412f, 5.25955677f, -1.98474956f, -3.72012186f, 3.00445986f, + 4.99141550f, 2.97457719f, 2.70827627f, 6.04544449f, -0.20756161f, + -10.87035751f, 0.80454814f, 0.33568168f, -2.48132324f, -2.84452009f, + 2.63126230f, -3.99351716f, -7.39294338f, 3.62798953f, -8.65815926f, + 2.65992808f, -6.98126554f, 3.09881067f, 0.67735767f, -1.15946686f, + 5.63180256f, -0.17694545f, -8.59651184f, 3.75297594f, -2.35913754f, + -0.20330384f, 5.49958467f, 1.00861740f, 1.42849684f, 0.00062013f, + -0.11073381f, 2.15207863f, 4.07368469f, 1.14344299f, -1.27953362f, + 6.64699316f, -0.73672432f, -8.55606937f, -0.19439441f, -4.14319754f, + -4.69964647f, -5.86446047f, 2.87106085f, -3.42714882f, -5.00668287f, + 6.22464132f, -7.72335291f, 4.05667686f, -5.72637177f, 6.35073948f, + -1.29593158f, 0.00813985f, 3.63368607f, -1.05764008f, -7.88486052f, + 3.73919106f, 1.41835213f, -1.04935634f, 0.65119827f, 0.03547254f, + 1.88996327f, 1.58701086f, -0.56215239f, -0.80187100f, 4.55604362f, + -0.67249978f, 1.41084409f, 7.86281586f, -2.38301182f, -8.50535774f, + -3.82098866f, -2.40856767f, -5.33439016f, -3.34747362f, 2.69389009f, + -1.64118791f, 4.52447939f, 0.04468334f, -1.48768258f, -0.69848812f, + -0.71123981f, 3.66259432f, 6.10314512f, 1.37305343f, -0.62758982f, + -2.99383426f, 4.20510864f, 1.48497128f, -0.08954811f, 2.43872309f, + -0.59880185f, 0.37431365f, 2.45458341f, -3.28401661f, -1.94629693f, + -1.93975246f, -0.26385683f, -0.45814323f, -0.18108580f, -3.74811840f, + -0.29739976f, -2.24116230f, -0.28150487f, -2.24421668f, 3.46930790f, + 8.35415077f, 0.05562943f, -2.81079793f, 1.10388446f, -2.82245207f, + -2.98102283f, -1.08132946f, 1.19089699f, 8.00183105f, 6.35385323f, + 3.72591257f, 4.59467506f, -5.74890900f, 4.42238331f, -3.36533451f, + 0.18350232f, 3.05606651f, 1.18788099f, 2.87450886f, 0.27472210f, + -2.80111074f, -0.66314960f, -1.96376896f, 0.75167024f, -4.72056293f, + 1.10629988f, -5.00775242f, 1.48246133f, -3.91681528f, -1.86573625f, + -6.17714882f, -0.67820001f, 5.69730282f, 1.04399037f, -4.93794823f, + 3.09619617f, 2.18692017f, -5.54232264f, -3.10046840f, -0.68972743f, + 2.81824327f, 3.04334164f, 6.13203907f, 4.14081764f, 1.02573645f, + 5.71970081f, -6.01574707f, -2.07346702f, 0.99554527f, 1.69641590f, + 0.66776669f, -0.80132431f, -2.03513098f, -3.42513680f, -0.06704485f, + -1.87195873f, -5.42428589f, -0.20748445f, -1.52408111f, 0.97084987f, + -0.48799962f, -0.45379883f, -0.26652339f, -1.20720732f, 3.94169855f, + -3.18480229f, -1.87440264f, -1.18028760f, 0.52011997f, -2.13437462f, + -4.52583313f, 1.69722807f, -0.89371562f, 3.37972403f, 6.38838720f, + 6.98663378f, -4.05421400f, 6.89512825f, -5.09085655f, -2.16257906f, + -3.33272719f, -3.01246452f, 0.37613097f, 1.80455804f, -0.36456174f, + -5.32273912f, -1.29978943f, -0.53685790f, -2.12896323f, 2.55506587f, + -2.57999182f, 3.40891910f, 1.36033249f, 0.83864629f, -2.88629293f, + -7.36048365f, 5.61314154f, 1.32668555f, -2.58041072f, -3.71943092f, + 1.60647738f, -2.74816346f, 2.47269106f, 0.85507953f, 8.39183426f, + 3.42624784f, -0.01519036f, 5.68412066f, 2.51771593f, 1.03045523f, + -2.08733034f, -2.44337177f, 0.81668580f, 1.30275154f, 2.99679208f, + -2.91957355f, -1.71337795f, 3.34979844f, 1.51825011f, 5.20375061f, + 2.27888370f, 1.38787699f, 4.23474550f, -4.05878592f, -4.85074377f, + -0.22794735f, 4.64402294f, 1.24391258f, -2.04935098f, 1.26285601f, + -7.51862240f, 0.62138438f, -1.95792389f, -0.96587181f, 0.85141110f, + 0.79354531f, 7.93766356f, 6.07677746f, 2.05947518f, 6.55480623f, + 1.44032848f, -0.70615625f, -0.07896036f, -5.08359432f, -0.01047915f, + -1.89632201f, 2.57555676f, 3.83779287f, 0.42850614f, 1.80754125f, + -0.06942326f, 6.35997963f, 6.06101418f, -0.97032297f, 5.71477222f, + -6.06671238f, -3.46607208f, -4.98306370f, 2.84659123f, -2.11025190f, + -0.04609144f, 5.26831341f, -9.56940651f, -3.67193556f, -1.71143103f, + -1.35221267f, -4.26226807f, -6.89146233f, 8.21761799f, 5.69823503f, + 2.28137946f, 1.88911343f, -1.44562483f, -1.60295713f, -0.52568185f, + -3.31892347f, -2.81997776f, 0.35287106f, 2.98202395f, -1.39432132f, + -2.70001364f, -4.14169264f, 3.50194883f, 4.12610435f, 5.52755260f, + 2.65859175f, 3.61353087f, -0.83027136f, -5.10652542f, -4.48625374f, + 2.06585884f, -2.76383352f, -0.64300913f, 8.19686604f, 0.96106279f, + 2.45952058f, 2.47275925f, -1.03288829f, -0.64897656f, -3.77937531f, + 4.27940083f, 2.58320260f, -0.57665241f, 1.87247813f, -3.81604433f, + -0.24543774f, -1.62118483f, -0.73075479f, -0.48533297f, 2.05016756f, + 0.45561486f, 0.03316188f, 0.77791005f, -1.56283605f, 2.36616826f, + 5.58082104f, -1.30925488f, -1.06329608f, 2.17189479f, -3.43008828f, + -4.71520567f, -2.56184673f, 0.17508316f, -3.25817418f, -0.41749167f, + 0.18119079f, -0.73181152f, 3.99792433f, -3.08002281f, -0.99143314f, + -1.83520067f, 1.18565679f, 2.98040128f, 5.67814350f, 2.35128760f, + 1.41600966f, 4.02718067f, -0.08193968f, 0.64636409f, 1.35931289f, + 2.37125754f, 1.75978124f, 3.90977740f, 1.50662971f, -2.84089065f, + 1.29824126f, -3.38730979f, -1.61005294f, 0.58292413f, -0.03019404f, + -1.57986510f, -0.56102908f, -3.03128719f, 0.51644313f, -2.01147819f, + 0.98400700f, 3.00028515f, 0.74579155f, -3.37098312f, 0.93339360f, + -1.29018497f, -2.14695001f, 1.30411184f, 0.71501279f, 7.47793055f, + 4.06516457f, 3.50772929f, 3.52762985f, 0.55643129f, 0.32272506f, + -4.30955982f, 2.49414706f, 2.07820845f, -0.34377906f, 4.39805031f, + 2.77561307f, -3.91292810f, 2.43981409f, 0.18861845f, -2.76658440f, + -4.97148752f, 3.25273705f, -0.08929539f, 0.19818619f, -5.83767605f, + -0.97381884f, -5.68745661f, -5.42433214f, 3.98769903f, -0.40394354f, + -1.83387578f, -0.80109525f, 1.47454357f, -3.14899540f, 0.80130816f, + -2.26348829f, 4.06121159f, 6.13077354f, 5.31226397f, 2.94966197f, + -3.65217376f, -1.08136678f, -7.14119816f, -0.85269439f, -0.70365787f, + -0.81598872f, 3.62807679f, 3.08123684f, -7.82739496f, 4.07951784f, + -0.14204243f, -0.66969109f, -5.07225513f, 2.88492823f, 0.47202343f, + 0.72683257f, -6.84280777f, 0.41807127f, -5.09785986f, -3.74514675f, + 2.03936672f, -1.06096244f, -1.52409148f, -0.97046643f, 2.27491093f, + -1.55597985f, -1.29215479f, -0.79737484f, -0.01979581f, 7.65407991f, + 5.54527044f, 4.04147148f, -2.64274883f, -1.89246953f, -3.89547634f, + -1.06029689f, -2.85982800f, -1.41247237f, 1.55836034f, 3.38194537f, + -2.97655582f, 0.87510300f, 1.26282072f, -1.77029657f, -3.57144690f, + -4.19456863f, 0.53179169f, -1.42221975f, -3.09144497f, -0.84294832f, + -5.02758694f, -2.68011904f, 0.89156240f, -0.34783912f, 4.64484835f, + -2.34453487f, -1.28573155f, 0.09990287f, 0.01828218f, -1.79960847f, + -1.06579173f, 1.08763921f, 0.43687880f, 3.24747229f, 3.83097172f, + 1.07253766f, -1.33810723f, 0.76530832f, 1.58660865f, 5.60743904f, + -3.54124737f, -0.89264417f, -3.83942485f, -1.03707337f, -1.61659896f, + 1.65349591f, 1.72698796f, 4.96013832f, 0.78927267f, -0.35563886f, + -3.48121166f, 3.79677629f, 2.59023166f, 2.74940348f, -2.17589283f, + -5.91757107f, 2.43766379f, -4.15906048f, -1.74731481f, -2.49113035f, + -0.57349741f, -4.04455185f, -1.46939647f, 2.21418452f, 0.09153593f, + 2.23016739f, 7.91880608f, 4.04464149f, 0.07706618f, -2.41892862f, + -2.19280314f, 7.61760712f, -5.89153862f, 0.33551922f, -1.70855618f, + -0.30561331f, -0.14341974f, -2.48878574f, 1.31269515f, 3.45388412f, + -0.02453184f, -0.12132037f, -4.27916241f, 1.25179088f, 4.09455204f, + -1.83801770f, -1.86743176f, -4.02864933f, 3.44515228f, -4.39244986f, + -0.56988084f, -1.69426417f, 2.18254852f, -4.78135824f, 1.73193693f, + -2.27968478f, -1.49523509f, 2.51696730f, 4.03677559f, -2.03679037f, + 1.32167840f, -2.22570705f, -2.74843621f, 6.29655170f, -3.67230225f, + -1.86765468f, -0.14842367f, -1.21552539f, -0.92038238f, -0.51692355f, + 1.08433771f, -0.01929832f, 0.15660909f, 2.31432915f, -3.86507082f, + -0.69797570f, 0.13505173f, -1.50951028f, -0.69980979f, -1.51297045f, + 3.63725281f, 0.13388813f, 2.73131752f, -0.96528149f, 4.92000961f, + -5.92699385f, 1.69444644f, -1.17121375f, -2.33710480f, 1.35302818f, + 1.39608085f, 1.68293881f, 0.94960749f, 1.89011908f, -4.08865070f, + 0.13722643f, -1.62849212f, -0.19044125f, 1.37906075f, -3.92504406f, + -1.45033538f, -0.42085981f, 3.38237071f, -3.06508875f, -1.39420545f, + 1.13067436f, 0.92206454f, 0.49917889f, -2.74508023f, -2.19221997f, + 1.77914095f, 0.10854459f, -2.62178278f, 2.35042715f, -0.15322030f, + -0.67014873f, -1.75627899f, 2.64074945f, 2.76339936f, 2.67275214f, + -0.62736398f, 0.58251178f, -4.64895678f, 5.50419283f, 2.53566456f, + -2.44196153f, -0.07845879f, -2.80389643f, -0.64810950f, -0.05813205f, + 1.67155504f, -2.69673729f, -1.72486305f, -0.53888649f, 1.86805439f, + -1.37128329f, -5.37923479f, -2.08133769f, 0.58187997f, -1.39498150f, + 0.21874082f, 4.33726025f, 6.29673958f, 0.72312093f, -3.32683516f, + 1.73482585f, -0.00766110f, -2.63785434f, -0.13511759f, 4.07195950f, + 0.94139838f, 3.15717316f, 1.53720927f, 1.87664819f, -2.33655119f, + 6.18176556f, -2.73912525f, -2.45279956f, 2.20392370f, -0.56854641f, + 0.98915887f, -2.64472580f, 2.40633702f, -4.93327999f, -1.28942823f, + 0.98247659f, 1.31774998f, 0.07669818f, -5.91169453f, -0.43135011f, + 1.27404964f, -0.59787154f, -0.22716975f, 0.74409103f, 10.27316475f, + -2.29192710f, -2.19403267f, 3.78925133f, 3.19553399f, -4.42490482f, + -0.80781460f, 2.16568565f, -2.54165983f, 2.54885101f, 4.18779039f, + 1.73079813f, -1.48891807f, 11.60153770f, -0.98686743f, -2.88813901f, + 2.32898521f, -0.36101711f, 2.34522438f, 0.29057693f, 1.39800644f, + -4.31848240f, -3.21217132f, 0.11740226f, -1.21613467f, 0.57248503f, + -4.44853830f, 1.54665899f, 3.14459944f, 1.76809108f, 0.26693153f, + 0.86913753f, 9.47121620f, -2.07677889f, 2.08578467f, 1.30181742f, + 1.58683562f, -3.52757788f, -1.32763624f, 0.79821301f, -2.19358301f, + 1.17707348f, 6.01983643f, 4.11209440f, -2.04209709f, 7.00413418f, + -1.84904683f, -1.32542288f, -0.01298118f, 0.70377320f, 0.27815005f, + 2.07879829f, -0.71606725f, -4.94399881f, -2.11898828f, -0.39051518f, + -2.21034360f, 3.05337906f, -1.56889665f, 1.97065282f, 2.61320901f, + -0.34063196f, -0.57001418f, -2.13183641f, 3.48879004f, -0.12067288f, + 0.48568326f, -1.81424558f, 2.28868723f, 1.44802380f, 1.25918829f, + -1.76415455f, 5.35742331f, 3.50682044f, 4.71371317f, 5.89110756f, + 8.51241302f, 4.07391453f, -0.05887252f, -0.18202400f, 2.27119660f, + 6.78274727f, -2.87470293f, -5.14336634f, 0.76443815f, 2.04625130f, + -0.43199503f, -1.01353514f, 2.42951298f, 2.35641170f, 0.32345510f, + -4.04195738f, -4.77967072f, 0.26564783f, 6.11455107f, -2.53868008f, + -3.11839914f, -1.04203856f, 5.17195654f, -4.15338612f, -3.84149241f, + 0.48130888f, 3.09706950f, -4.18423653f, 5.26233864f, 3.55831861f, + 3.75122595f, 8.14969349f, 6.80038738f, 4.68907356f, -1.40135396f, + -3.19287133f, -3.15895939f, 8.77363205f, -4.48793411f, -3.80537176f, + -2.40145254f, -2.74341679f, -2.02862644f, 5.33402443f, 9.25365734f, + 2.50246119f, 0.32847846f, -1.50564361f, -4.26163197f, -1.40994716f, + 2.50708485f, 0.44500345f, -0.62516934f, 4.09846306f, 5.29355669f, + -4.02224922f, 0.73442125f, 0.46648952f, 0.67028689f, -6.30715466f, + 6.56297970f, 3.80854273f, -5.19078207f, 4.98839283f, 7.59161472f, + 0.46010983f, -2.10227895f, 0.29324162f, -2.67019558f, 4.57838106f, + -3.02338457f, -3.08647728f, -2.00112700f, -3.81710315f, -0.08346784f, + 1.69288683f, 5.68807268f, 3.29351830f, 0.54618967f, 1.83540761f, + -5.38810253f, 0.51326782f, 4.40081882f, -4.03805828f, 0.49482727f, + -1.36024392f, 2.91845679f, -2.00959015f, 2.47489738f, -1.43354976f, + 1.92024410f, -6.55897284f, 1.79488957f, -0.89570928f, -6.13094234f, + -0.45504010f, 2.35239482f, 1.29039919f, -4.78849840f, -1.52545333f, + -6.50420475f, 2.99257326f, -0.55620033f, 0.26807702f, -2.52090979f, + -4.59419632f, 0.57965040f, 2.19423151f, 2.04760551f, -0.57048106f, + -2.20812702f, -0.04777686f, 1.38053393f, -2.71448946f, -1.06219673f, + -3.62008905f, 1.85719645f, 1.28355026f, -2.76315832f, 1.65295160f, + -4.01645803f, -3.10454416f, -0.65713316f, 1.22384977f, -0.70416176f, + 4.45064926f, 1.31602776f, 2.06907344f, 2.48872757f, 4.25775290f, + 3.50504255f, -0.68262041f, 1.29799378f, -1.01969171f, 2.98593879f, + 0.12607655f, 0.37219539f, -0.84196299f, -3.80019331f, -1.82315290f, + -0.38489276f, -1.45200360f, -4.00882292f, 0.61042011f, -0.16738498f, + 1.33787775f, -2.26938057f, 1.03656030f, 8.89089870f, -1.60370600f, + -5.38691807f, 5.72182989f, 2.72854710f, -6.18535757f, -3.13408709f, + 2.79175353f, 5.18425512f, 9.46434212f, 2.40110517f, 1.11330092f, + -3.57366538f, 4.80967665f, 0.40691876f, -3.65484858f, 0.92398167f, + 2.53852940f, 3.17747331f, 2.14199781f, -1.69107199f, -1.91864693f, + -3.18452644f, -2.42408276f, -2.14332366f, -1.35526609f, -4.50732136f, + 0.58234072f, -1.81547785f, 0.57311213f, 1.10584176f, -0.97226644f, + 11.73174381f, -2.00559855f, -1.81175601f, 2.33131361f, 0.49264961f, + -0.42245382f, -1.37528467f, 1.55768061f, 0.21152198f, 13.08896351f, + 10.33674145f, 5.77929306f, -6.19886398f, 5.67007637f, -6.61288071f, + -2.58029866f, -4.05192375f, 1.77221894f, 0.29821560f, 5.23508501f, + -5.09560966f, -0.97536200f, -5.17957878f, 1.02876794f, -4.52072096f, + 2.22126532f, -4.81708670f, 0.44538212f, -2.30738068f, 3.15900373f, + -4.99227905f, 0.82632786f, 9.65415478f, -0.63819492f, -3.25479436f, + -0.13276935f, 0.21337092f, -2.22116399f, -3.04922724f, 0.65568435f, + -0.10706246f, 4.58047390f, 7.80782652f, 5.49080181f, -3.97114491f, + 6.43327618f, -6.54772758f, -2.10962629f, -0.79831678f, -0.08316499f, + 2.48658133f, 4.14070511f, -0.59806836f, -4.58636141f, -0.31166920f, + 0.31757897f, -3.92562199f, 0.65357721f, 0.55871534f, 1.71843934f, + 1.62395024f, 0.00695819f, -4.56716251f, -3.76420808f, 4.24979544f, + -0.86128616f, 0.23126510f, -6.32968998f, 1.83346081f, 3.81335950f, + 2.98407745f, -1.80454743f, 6.61764765f, -1.39372075f, -0.86780751f, + 7.24317265f, 2.24205112f, 1.05702817f, 0.55431479f, -1.54557061f, + 3.36389136f, 4.70898724f, 1.11327887f, -3.78462076f, -3.63381767f, + 2.86510396f, 0.74203897f, 0.81488025f, 3.54250598f, 3.24824381f, + 3.19000244f, -0.58995843f, -7.05670738f, 3.18306041f, 3.95191574f, + 0.81820154f, -1.91068232f, -2.05426741f, -1.05589008f, -3.18377590f, + -1.86278260f, -8.80374908f, 0.93416154f, -4.60517359f, 8.38999462f, + 5.26356745f, -8.89992714f, 8.95298958f, 4.22590351f, 1.00351548f, + -6.90151119f, -8.07641125f, -4.82450199f, 8.02293015f, 4.11661243f, + 0.95457208f, -7.07843113f, -4.30524826f, 5.02697992f, 5.21011686f, + 0.80132771f, 3.23420191f, 3.82452774f, -2.13171721f, -7.88879967f, + 1.31062031f, 1.90848613f, -3.51572514f, -3.75684500f, 3.62577081f, + -5.76075602f, -2.79389215f, 0.32598805f, -4.28981733f, 4.21048594f, + -3.84532523f, 3.19815183f, -0.40756655f, -2.19974327f, 6.25655174f, + 3.42396951f, -1.88986623f, -1.92803884f, -2.97344875f, -0.09756154f, + 5.24342251f, -0.72513700f, 1.06113195f, -1.30720282f, 4.69107103f, + 0.58984971f, 2.33985567f, 1.46385121f, 3.16576266f, 6.77769995f, + -5.92685127f, -12.61141014f, -2.83663774f, 4.90253258f, -6.32688522f, + -3.00096869f, 2.38634992f, -7.21459866f, -5.89208746f, 2.84085894f, + -1.21792030f, 6.70161343f, -4.00450230f, 5.29881001f, -1.45574808f, + 0.77542424f, 1.38336325f, -0.21572059f, -3.38088870f, 2.33249640f, + 0.68824625f, -3.68440270f, 0.33481622f, -0.39239681f, 0.14560902f, + 1.61039007f, -3.11967754f, 2.49372435f, 2.68783092f, -1.17559779f, + 0.95257235f, 4.35451412f, -0.56818569f, -7.32110357f, -7.58534050f, + -2.10573673f, -3.34446383f, -0.32183546f, -0.78525496f, -1.76974547f, + 5.19060802f, -2.11319876f, -3.41755080f, -0.36864156f, 1.32680905f, + 0.45004874f, 6.17223930f, -1.60707474f, 0.46096295f, -3.88852644f, + 1.84729624f, -0.03412050f, 0.99224162f, -2.05553341f, 3.47793245f, + -0.06305170f, 0.51314175f, -2.91650558f, -1.78121483f, -2.85465693f, + 0.24649808f, -2.70376635f, 0.42334458f, -1.13862336f, -0.98409218f, + -0.96593523f, 2.22128963f, 0.53402066f, 3.33979344f, 8.57430458f, + 2.34217858f, -2.40062976f, 5.81624222f, 1.13290989f, -5.06850052f, + -4.72865725f, 1.82859278f, 6.78569555f, 8.56885242f, 2.76462936f, + 0.33891773f, -2.81092787f, 0.79498398f, -2.27208567f, 1.55182552f, + 2.17166376f, 6.12517643f, 3.56859684f, 0.27685475f, -1.38408327f, + -1.03533340f, -3.46618199f, 0.79240030f, -3.89390516f, -0.55852515f, + -1.16367757f, -0.07008934f, -2.20105195f, 3.81210446f, -0.66834474f, + 0.43603873f, 10.92334938f, 2.48571420f, -6.34997845f, 4.23135757f, + 0.45045292f, -4.13489866f, -3.92324209f, 1.88537407f, 2.57159734f, + 9.90973091f, 4.37453461f, 7.34546280f, -2.51120615f, 11.12575245f, + -3.23452854f, -2.49947500f, 1.39819741f, -3.78950691f, 2.40617585f, + 5.10036278f, -3.55743456f, -6.42888737f, -2.51929998f, -1.90880990f, + -1.81618094f, 1.60946512f, -4.09737110f, 1.96408439f, -1.90115595f, + 2.44444203f, -2.31254292f, -4.01332951f, 8.65541840f, -0.58626485f, + -4.02226830f, 0.43893200f, -3.78272748f, -5.46277428f, 0.01306701f, + 0.61185312f, 0.24469066f, 1.30214953f, 5.87789631f, 8.75197792f, + -5.31634712f, 3.43556309f, -5.90755081f, 0.54375106f, -2.48162293f, + -3.51843548f, 2.55853295f, 5.06387186f, -2.09662485f, -3.00377345f, + -3.21781397f, -0.14537808f, -4.65453672f, 1.92747557f, 0.41553855f, + 4.09379959f, 0.83387995f, 1.50868511f, -6.54959488f, -8.38881016f, + 5.50689125f, -2.88616610f, -1.21597648f, -0.23817590f, 1.50816703f, + -2.26873541f, 2.29862142f, -1.61143053f, 5.97371244f, 4.71440220f, + -0.20635787f, 8.85926723f, 0.56064367f, -1.04103339f, -4.47060108f, + -2.63824081f, 3.06782055f, -2.07702565f, 3.38269401f, -1.59988797f, + -3.80122590f, 2.35341501f, 2.69095278f, 3.87612104f, 1.89984226f, + 0.95496917f, 3.14841127f, -5.84543085f, -7.24945450f, -2.65708590f, + 2.87417006f, 0.97556210f, -3.75203967f, 1.55287778f, -7.43401051f, + -1.29005826f, -3.40252638f, -4.01049423f, 2.82721639f, -1.21479535f, + 8.54563904f, 7.39749908f, -0.61361837f, 7.60177565f, 1.65812778f, + -0.83008504f, -3.60961151f, -7.69062138f, -1.26275063f, -4.17071676f, + 5.28448200f, 4.04685593f, -1.18231702f, 1.15276611f, 1.58620787f, + 6.75060844f, 3.29332161f, -0.67640316f, 5.78984785f, -3.14913464f, + -6.41867924f, -2.58316016f, -2.04366302f, 2.01089478f, -3.81723452f, + 3.63843751f, -5.13238430f, -3.79432917f, 4.86581373f, -1.06922054f, + 3.95978498f, -0.78166616f, 8.35650539f, 5.35834265f, 0.35594034f, + 9.41657066f, -0.84108615f, -6.54425859f, -3.44328952f, -6.55536795f, + -0.08963367f, -1.53906262f, 0.17658240f, -0.13108420f, -0.44371247f, + -0.78411150f, 2.64754868f, 9.66306782f, 1.70506203f, -0.31588936f, + 4.31715870f, -6.16665173f, -10.43371868f, -3.72962189f, 4.35245228f, + -1.75867891f, -4.20046234f, 8.62637043f, 1.45946813f, -3.30153608f, + 0.85179043f, -2.66643381f, 3.01863337f, -2.52916121f, 8.35405540f, + -0.37298933f, -0.89473486f, 6.88681793f, -4.46370125f, -7.50776386f, + 3.80255938f, -3.55003357f, 1.43528831f, -2.20383263f, 2.34999895f, + 2.03803205f, 1.94830751f, -1.85976326f, 0.97718471f, 5.53710842f, + -0.80560827f, 0.23925614f, 5.98795223f, -2.03578377f, -7.77835321f, + -2.79955530f, -1.88185954f, -2.49112058f, -0.76095992f, 2.71161270f, + -0.55918610f, 0.83789903f, -1.42063200f, -0.61528748f, -4.18273115f, + 1.76384258f, 4.21265936f, 5.50964785f, -0.93324339f, 3.83215356f, + 1.52210593f, -0.91594946f, 1.31148386f, 3.20160103f, 1.24493563f, + -0.72693497f, 1.84716725f, 3.09897518f, -1.34605026f, -1.17511916f, + -1.05526352f, -1.08590937f, -1.41319299f, -3.75052118f, -2.67095542f, + -0.76179552f, -3.32081509f, -1.04692316f, -1.30194843f, -1.98795474f, + 5.01223469f, 0.21895903f, -1.85535169f, 3.12362719f, 0.16198632f, + -3.86784005f, -2.03062248f, -0.15415624f, 8.22020721f, 4.83055592f, + 4.50315666f, 4.19443417f, 0.42727345f, -4.67786789f, -5.18739986f, + 2.53988838f, 3.19683266f, 1.80313504f, 1.94664574f, 0.59795094f, + -4.21626759f, 0.50492239f, -0.41232634f, -0.99224532f, -3.94929314f, + 1.74060190f, -0.92474866f, -1.00664830f, -6.17397356f, -1.33146775f, + -3.78111315f, -4.91876888f, 2.50303864f, -0.34890354f, -1.25013232f, + 0.38168997f, -1.84135628f, -4.46107960f, -4.05920792f, -2.61709857f, + 0.71046209f, 9.80566883f, 6.34086990f, 2.73394704f, -2.03342366f, + -2.21424174f, -5.56514263f, -4.74755144f, -2.20672894f, 0.09010231f, + 1.70423889f, 3.19200158f, -6.99027634f, 1.14216340f, 0.05824995f, + -0.76996505f, -6.51575899f, -0.41109252f, 0.78229940f, 1.36170781f, + -5.65170193f, 1.12221193f, -4.60430050f, -4.40174437f, 4.01805925f, + 0.10774946f, -2.77991009f, -0.18023163f, 0.02151692f, -1.77023101f, + -1.86639869f, -0.69443607f, 4.92290831f, 6.83520412f, 4.27372265f, + 6.54272366f, -7.59249687f, -1.40776849f, -3.52368808f, 1.01398587f, + -3.58802676f, -0.35658866f, 1.14716864f, 3.75847244f, -2.30159235f, + -0.72130895f, -0.24564353f, -1.77531350f, -3.08677864f, -0.73486501f, + -1.20357263f, 0.60789430f, -3.46990204f, -0.20668676f, -5.46096087f, + -5.22016764f, 0.98259866f, 1.81012678f, 3.92534304f, -2.94997001f, + 1.65154219f, 2.27040243f, 0.99095678f, 0.09144652f, -0.99103236f, + -1.11210847f, 0.78181303f, 2.38706732f, 2.96695375f, -0.17279971f, + 0.31143007f, 1.35465562f, 2.03586054f, 6.19515753f, -3.14652419f, + -2.89027119f, -3.26665854f, -1.93043876f, -0.46601450f, 1.07655203f, + 1.74946189f, 4.02148342f, 0.69275337f, 0.50094581f, -4.07613230f, + 2.98369169f, 4.24537849f, 0.49480581f, -2.02408123f, -2.02068973f, + 6.54505825f, -5.19377470f, -0.12596917f, -0.70204186f, -0.98308045f, + -3.19708824f, 1.63609934f, 1.35475993f, 0.16313422f, 4.13918924f, + 7.69187021f, 3.72601676f, -1.97790039f, -1.16739464f, -3.31835508f, + 8.14553452f, -1.78718984f, 1.21505618f, -3.84255409f, -3.21992350f, + 0.07376552f, -0.81223297f, 3.57002878f, 1.48521733f, -0.45995998f, + 0.30551746f, -3.33944130f, 1.39538884f, 1.84758544f, -0.21494150f, + -2.27316713f, -4.37771225f, 6.48841667f, -5.00251961f, -0.45162797f, + -5.01056004f, 0.70199943f, -4.60057783f, -2.22394514f, 0.07777429f, + -1.49820781f, 3.47308421f, 6.13231564f, 1.18605387f, -4.78924608f, + -3.49548388f, -2.73382568f, 6.24617863f, -2.74291611f, -1.03833354f, + -2.20752788f, -2.33219409f, 1.48633552f, 1.65796840f, 4.95045471f, + 2.58479190f, -0.90922785f, 0.71312457f, -4.44465590f, 1.37020862f, + 2.37683725f, 0.18805164f, -3.28422308f, -1.64939332f, 3.64181972f, + -3.75277281f, 3.67203593f, -0.11204052f, 2.24140930f, -3.90657187f, + 2.56883717f, -1.44016707f, -2.83842611f, -0.29104578f, 2.17757058f, + -0.71431804f, 1.36911654f, 0.85083604f, -1.60110259f, -1.97247636f, + -1.61163378f, -0.81236130f, -0.38993555f, -3.03631902f, -0.38213277f, + 0.06394482f, 3.19348621f, 0.36771113f, 1.36763072f, 2.49159527f, + -0.39599860f, -2.69996762f, -0.97561121f, -2.97563028f, -0.49662948f, + -0.17564940f, -2.79042959f, 0.72395414f, 2.07260203f, -0.99439794f, + -2.20248008f, -0.07389921f, 0.65536159f, 4.73054695f, -0.63917702f, + 0.58788192f, -3.60156059f, 6.59609890f, 3.88419437f, -3.38469863f, + -3.56237841f, -2.03295064f, 0.07279694f, 3.71804547f, 0.79928309f, + -2.13411403f, -1.13909864f, -0.34193408f, -1.00338125f, -1.44231665f, + -5.39835978f, -0.45086145f, 1.16064668f, 2.58335257f, 2.10072684f, + 4.64244223f, 7.10090065f, 1.01974952f, -4.44687223f, 2.99792576f, + 1.10303724f, -1.22736573f, -3.91514421f, 3.07458854f, 2.18765211f, + 3.34481716f, 2.46166849f, 2.99648619f, -0.94046807f, 5.55028200f, + 0.92199719f, -0.83934361f, -0.72042274f, 0.84869325f, 1.46914721f, + 0.85937387f, 4.77306223f, -4.06436539f, -2.59847593f, 2.44828081f, + 0.50484699f, -2.71092367f, -6.39010477f, 0.91778028f, 3.25469685f, + 1.30310678f, 1.35258150f, 3.56171441f, 7.82435083f, -2.51527429f, + -4.24328852f, 2.36876059f, 1.94595242f, -2.59290171f, -6.62389565f, + 3.32567835f, 2.13659120f, 4.09299326f, 3.48293996f, 2.64965177f, + -3.19157362f, 13.37204266f, -0.50297594f, -4.57448196f, 3.95582604f, + -0.69038916f, 0.10098404f, 1.18737555f, 3.65761185f, -5.69623756f, + -2.03357077f, 1.02868807f, -1.38448596f, -0.05690211f, -8.48874187f, + 0.56755424f, 1.45485961f, 0.66273880f, 0.06495565f, 1.79539490f, + 8.46864319f, -1.22696662f, -1.87585378f, -0.99768794f, 2.72801924f, + -0.66980243f, -2.31924677f, 0.33271110f, 0.11666083f, 1.86980045f, + 5.95332909f, 7.38583708f, -2.80956483f, 6.79227638f, -6.78070831f, + 1.21884382f, -1.40695429f, 0.90236962f, -1.13695288f, 0.50760663f, + 1.00955284f, -5.39029121f, 0.24987072f, 2.24283314f, -4.02145576f, + 2.18057394f, -3.35627747f, 1.26061773f, 1.30342579f, 0.11311233f, + -1.11199212f, -4.06509686f, 5.82649660f, -1.24059582f, 5.51652861f, + -1.90937877f, 1.10658336f, -0.47065550f, -2.39167786f, -1.95931304f, + 4.12717247f, 1.15396059f, 1.26015663f, 7.97836876f, 7.33633423f, + 2.27785325f, -2.83802366f, -2.74850106f, 0.86126029f, 6.18781090f, + -1.43707538f, -6.97134876f, -3.25486469f, -1.95214593f, 0.91066706f, + 0.89637989f, 1.06481194f, 6.25791073f, 0.81779671f, -1.08384395f, + -3.21191931f, 2.04216075f, 4.76030350f, -2.37217665f, -1.42571259f, + -6.35876131f, 4.62536526f, -5.40060568f, -3.14868999f, -1.00587153f, + 1.80662942f, -7.03201485f, 6.08373499f, 0.99862772f, 2.21717811f, + 4.06814623f, 6.02428913f, 5.33422756f, -0.87013257f, -2.22477579f, + -2.51505303f, 5.82925224f, -0.82854009f, -4.30698347f, -1.75007713f, + 2.08352375f, -2.25235629f, 1.17517352f, 5.77717733f, 2.27472878f, + 2.72778273f, -1.95411634f, -4.52602863f, 1.13983536f, 1.16340065f, + -2.02740526f, -3.11290503f, -1.94906235f, 1.54855204f, -4.52984142f, + 1.97465122f, -1.79415476f, 4.03510094f, -8.45349979f, 10.87430096f, + 2.19863629f, -5.39083815f, 5.86213875f, 6.25744534f, 6.52600002f, + -4.72149038f, -1.75254321f, -5.51459169f, 7.03155518f, -2.01889277f, + -4.58441257f, -3.61226106f, 0.42395937f, -0.93263882f, 2.28703761f, + 2.80611467f, 2.59498215f, 0.65989012f, -1.51268566f, -4.49465561f, + -4.70453882f, 5.44696808f, -4.37603617f, 0.46670085f, 2.82488608f, + 2.18854523f, -2.04817152f, 1.19557285f, 1.53618634f, 4.44758606f, + -7.31593513f, 7.43966007f, -3.55480957f, -5.29834652f, 2.14622784f, + 1.65194583f, 2.71262598f, -4.86145496f, 0.79726243f, -8.88541985f, + 1.19627261f, 0.79660845f, -1.98016644f, 1.03741014f, -3.93128228f, + 1.05535269f, 2.01378822f, -0.46086323f, -0.77754641f, -1.43942690f, + 0.49809402f, -2.27861357f, -3.29815221f, 0.38201320f, -3.98481083f, + 4.88261318f, -0.44555628f, -2.57224536f, 2.35001850f, -2.65835261f, + -2.43422794f, -2.97889376f, 1.07349825f, 1.88157082f, 4.74075413f, + 0.60376728f, -0.48894715f, -1.15800071f, 4.68110943f, -0.86976886f, + 1.49192941f, 0.62665290f, 0.20652676f, 0.53916287f, -1.45706177f, + 0.66133004f, 1.34405875f, -4.27689552f, -0.20838106f, -5.14266443f, + -1.29718637f, -1.74506426f, -0.86022055f, -3.57553625f, 0.46880072f, + -1.25287139f, 3.28596354f, 11.33191013f, 1.23942876f, -3.87616491f, + 7.57880497f, -0.22940339f, -5.68512678f, -1.94969654f, 5.85449600f, + 3.75705457f, 4.24395847f, 1.60086083f, 2.62553668f, -0.93964291f, + 5.84753895f, -0.79931092f, 0.48274064f, 2.07170033f, 3.02243996f, + 2.63509989f, -0.76043403f, -1.64048159f, -6.17683458f, -3.09974527f, + -2.12773156f, -0.89379883f, 2.82242465f, -1.99981332f, -0.08763933f, + 0.01921120f, -1.94142103f, 2.48067307f, 0.41083777f, 8.24922180f, + -1.84516132f, -1.39224625f, 5.03956223f, 0.49562740f, -5.28296328f, + -0.20005548f, 3.13672113f, 0.51187158f, 7.11563921f, 6.43059587f, + 3.48430967f, -5.37095928f, 8.03863049f, -5.53923941f, -2.16421175f, + -3.77641368f, 3.29633045f, 5.04030085f, 2.25945377f, -3.04169011f, + -2.16198015f, -2.49559617f, -0.26252726f, -6.99201345f, 2.87374353f, + -0.12568980f, 0.23314142f, -1.32087135f, 4.39030552f, -0.24638844f, + -4.37242651f, 14.09276772f, 1.23987353f, -1.72249663f, 0.31124914f, + -2.13725138f, -3.74915648f, -1.87147236f, 0.47318631f, 1.13337576f, + 3.00416899f, 8.82548523f, 4.80538750f, -5.28486395f, 5.51870108f, + -5.15801477f, 0.95712411f, -1.50416136f, 2.34657240f, 4.20726633f, + 5.56757259f, -3.30645251f, -3.39945269f, -2.68488026f, -2.53525281f, + -3.15145874f, 2.74529529f, -0.96283442f, 2.87778258f, 0.22186530f, + 1.24905694f, -7.07941198f, -5.45916176f, 3.46988297f, 0.92430985f, + -0.98330998f, -2.23672342f, -3.03262734f, 0.73941302f, 0.98004431f, + 0.83219361f, 7.17411804f, 4.27849865f, 0.14765590f, 8.61269569f, + 9.04497051f, 1.53991723f, -2.08305025f, -4.34939337f, 0.63786775f, + 2.60098696f, 0.02432060f, -1.48516297f, -4.06825686f, 5.12420368f, + -0.75312757f, 1.96927559f, 4.91575956f, 3.41533065f, 3.62557888f, + -4.35002136f, -5.91343403f, 0.45026422f, 4.93286371f, 3.45830250f, + -4.39032364f, -0.51697755f, -7.41543341f, -3.06703568f, 1.01196158f, + 2.47106576f, 5.54014874f, -4.65312243f, 8.61000633f, 8.25905323f, + -1.41497111f, 8.69221878f, 0.40090930f, 1.11325574f, -1.67089832f, + -4.01080132f, 1.07925677f, 2.68086481f, -0.73093414f, -1.35081220f, + -7.85765076f, -5.98989439f, -0.04651213f, 4.63693142f, 2.07757711f, + -0.22652936f, 3.45525455f, -0.69198442f, -10.39761639f, -2.02106953f, + 4.77755499f, -2.67665577f, -1.72481167f, 4.49634743f, -2.55717134f, + -4.55044937f, 0.46377492f, -3.08933020f, 3.86891365f, -2.79104614f, + 8.36974335f, 0.86471701f, -5.39342690f, 12.54906940f, -0.41536295f, + -5.29502535f, -3.94430566f, -5.67391300f, -4.65079165f, 2.22505951f, + -0.30000746f, 2.27855444f, -4.81604433f, -1.73440599f, 4.68784523f, + 5.00208044f, 0.18863934f, -1.74989462f, 3.17923450f, -1.59773099f, + -12.59962940f, -1.54495025f, -0.00576371f, 1.79913878f, -2.43449807f, + 1.49516344f, -3.90507102f, 1.68647158f, 4.50177765f, -5.32286358f, + 3.47539330f, -2.90529680f, 1.61576962f, 0.83679676f, -5.55615807f, + 3.78939056f, -4.46644831f, -5.95550919f, 0.37808037f, 0.51334500f, + 1.74658906f, -0.82085419f, -0.65387219f, 3.67790437f, 0.03758264f, + -2.42622781f, 1.83335185f, 4.73835945f, -0.83536482f, -0.03993917f, + 3.78230667f, -4.81265640f, -8.26869011f, -1.30363441f, -2.09106350f, + -3.96769738f, -1.89037073f, 0.38682747f, 0.05434489f, 5.72213697f, + 0.55685395f, -3.47729349f, -1.11535001f, 2.09416127f, 5.08877802f, + 5.72183466f, 1.29632664f, 0.16822398f, -2.43180108f, 3.49967623f, + 2.15753818f, -0.26548505f, 3.24446392f, -0.00599277f, 1.08215356f, + -0.23225522f, -2.40723038f, 0.18496060f, -3.70608735f, -0.19918591f, + -1.64028871f, 0.80792952f, -0.85334057f, -2.52314138f, -3.12099195f, + 0.17949918f, -0.82650864f, 2.32224989f, 9.56476116f, -0.20134282f, + -0.48428559f, 2.86784410f, 0.07289505f, -3.92880869f, -2.11887884f, + 0.59164631f, 6.31267452f, 7.49149418f, 2.88749456f, 2.40504885f, + -3.57608175f, -1.48019314f, -0.69410253f, 0.90275228f, -0.34111357f, + 2.19190216f, 3.39090061f, 3.39631820f, -5.19105434f, 2.67546582f, + -2.56549048f, -0.59797800f, -4.21802664f, 0.63918972f, -0.69969130f, + 0.47496963f, -4.30976725f, 0.16531238f, -3.59595251f, -0.76877379f, + 11.79971790f, -0.93276632f, -1.48630571f, 8.04754066f, 2.09168458f, + -3.77018499f, -4.19337654f, 0.26171905f, 1.99359691f, 8.96759701f, + 8.39609814f, 6.19231987f, -5.36037970f, 4.69818354f, -4.22453928f, + -4.61665344f, -2.52073431f, 1.34026706f, 2.80182385f, 2.56681514f, + -4.04676390f, -3.01466990f, -4.10480118f, 0.38737059f, -0.37146521f, + -2.26529670f, -1.72867084f, 0.93472683f, -2.47562981f, 0.89871657f, + -1.67618203f, -0.28950238f, 5.30124855f, -0.14731219f, -0.81319761f, + -1.11265934f, 0.11356127f, -2.52802444f, -1.93826056f, 1.06187987f, + 1.48062325f, 4.28070498f, 5.69893932f, 9.26904392f, -4.23773003f, + 5.78582096f, -6.18445301f, -2.85200453f, -5.30461454f, -4.16009140f, + -0.07239690f, 4.11531162f, -1.12266588f, -1.50265646f, 0.47661865f, + -1.90043914f, -6.48978710f, 1.71005368f, 0.18256521f, -0.88272136f, + -0.51324779f, -0.78045660f, -5.21036625f, -4.11805344f, 3.99454761f, + -1.04999924f, -6.99629354f, -5.02737141f, 0.94748145f, -2.35882139f, + 4.13982439f, -1.41835535f, 7.56763077f, 3.97024012f, -4.08156776f, + 6.90305424f, 0.53571963f, -2.22625160f, -2.09144926f, -4.98530245f, + -0.15102190f, 0.59995949f, 3.28562784f, 0.77991986f, -3.08389306f, + 3.34046674f, 0.41394949f, 5.10031366f, 2.99692893f, 0.17706826f, + 2.85998058f, -6.68330860f, -6.72653008f, -0.04071128f, 3.71085787f, + 3.17834806f, -4.88019037f, 6.74075413f, -7.41782188f, -5.22026348f, + -1.94595623f, -3.61318684f, 1.85610664f, 1.08613706f, 6.41580677f, + 1.46376514f, -4.11524010f, 9.59146214f, -2.92772651f, -1.70753336f, + -1.51594138f, -4.88185692f, 1.47331417f, -2.23893595f, 4.98459148f, + 1.29359996f, -2.29221845f, -0.99594390f, 3.05759239f, 6.86030054f, + 2.40487719f, 3.28339863f, 7.72739315f, -3.60563445f, -9.73502827f, + -1.51672328f, -0.08473521f, -2.43673515f, -3.26616001f, 3.63767886f, + -11.25394535f, -5.17597103f, -1.27523947f, -7.82669783f, 0.67929745f, + -4.50530529f, 5.49323797f, 6.78993320f, -2.28033876f, 4.61412525f, + 2.55109429f, -12.38607693f, -0.63024014f, -3.45992327f, -0.84092742f, + -0.03252453f, 4.58635283f, 5.28213978f, -1.28417206f, -1.71185923f, + -0.26850975f, 8.28257561f, 4.47432184f, 2.72818279f, 8.42217731f, + -4.22216320f, -8.95128918f, -1.57179546f, 1.34253705f, -5.47035217f, + -5.50866985f, 4.64156532f, -6.11207914f, -5.46734476f, 3.54298997f, + -2.79237103f, -0.70766860f, -3.62739944f, 3.22660995f, -2.02262759f, + 0.11224222f, 2.63832402f, -0.91955596f, -4.65958309f, -0.29729855f, + -1.78957534f, -0.40749407f, 0.51688713f, 0.83725226f, 0.30945438f, + 1.20769620f, -1.75219965f, 2.59689760f, 5.01501608f, -1.59034789f, + 0.58155286f, 3.75831509f, -5.26110506f, -8.65382767f, -6.19066620f, + -0.61932850f, -2.71863723f, -0.87443137f, 3.40582991f, -1.27868056f, + 3.51236677f, -2.07806540f, -0.85076392f, -1.14599180f, 1.16361260f, + 1.86411846f, 5.86179352f, 0.69029891f, -0.06060839f, 1.54649436f, + -0.60351688f, 1.51970077f, 0.04187265f, 1.64540339f, 2.75502157f, + 2.46308279f, 1.69071770f, -3.23827076f, 0.92096543f, -3.09458661f, + -1.23823690f, 0.24035048f, -0.74456501f, -1.85476089f, -0.32914662f, + -2.10325241f, 1.19795251f, -2.05372071f, 1.02114081f, 2.56286955f, + 0.42165697f, -1.65826249f, 4.00724554f, -2.18727994f, -1.05848944f, + -0.52338278f, -0.28714985f, 8.08780861f, 5.04444599f, 3.51866961f, + 3.37445784f, -1.96067202f, -1.21509445f, -3.96595931f, -0.80801201f, + 0.76944816f, 1.80147493f, 4.14419460f, -0.12201095f, -2.77788162f, + 1.13284469f, -2.05441403f, -0.61129224f, -2.69690657f, 1.91634214f, + -2.17146754f, -0.22308528f, -6.02561045f, 0.49161875f, -6.74280357f, + -4.62689781f, 2.47910833f, 1.86534905f, -3.24152899f, -1.39898300f, + 0.29427958f, -2.16338181f, 0.90073711f, 1.75551236f, 4.42651892f, + 8.34437466f, 5.50070190f, 5.68162251f, 1.65345454f, -2.72315669f, + -5.43411493f, -0.29380533f, 1.07508349f, -1.73533511f, 2.56912184f, + 3.62010550f, -6.30422783f, 1.74158525f, -1.22070909f, -0.80982518f, + -4.14757967f, 4.29217434f, 0.70600843f, -2.09282112f, -5.09018898f, + -0.11623126f, -5.99775553f, -4.66743088f, 1.61512172f, -1.30276895f, + -3.17103505f, -0.26310229f, -1.00843918f, -0.77664804f, -2.05240250f, + 0.04728425f, 1.15720487f, 4.01001406f, 7.24615860f, 2.55452180f, + -5.76347876f, 0.34683830f, -6.05540276f, -4.70677900f, -0.93182588f, + -4.37759733f, 2.93209839f, 1.63947964f, -2.43563962f, 1.35213876f, + 0.00670356f, -0.02742785f, -2.16460943f, 1.39449501f, 0.23929763f, + 2.37476778f, -4.17733765f, -0.81475425f, -6.15027046f, -5.74441719f, + 3.53978682f, 0.66798484f}); + + sd::ops::deconv2d_tf op; + auto result = op.evaluate({&input0, &input1, &input2}, {}, + {7, 7, 2, 2, 0, 0, 1, 1, 1, 1}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests2, Test_Dilation2D_Again_1) { - auto x = NDArrayFactory::create('c', {4, 128, 128, 4}); - auto w = NDArrayFactory::create('c', {4, 5, 4}); - auto exp = NDArrayFactory::create('c', {4, 64, 43, 4}); - + auto x = NDArrayFactory::create('c', {4, 128, 128, 4}); + auto w = NDArrayFactory::create('c', {4, 5, 4}); + auto exp = NDArrayFactory::create('c', {4, 64, 43, 4}); - sd::ops::dilation2d op; - auto result = op.evaluate({&x, &w}, {}, {1, 1,5,7,1, 1,2,3,1}); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::dilation2d op; + auto result = op.evaluate({&x, &w}, {}, {1, 1, 5, 7, 1, 1, 2, 3, 1}); + ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.isSameShape(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests2, Test_Dilation2D_Again_2) { - auto x = NDArrayFactory::create('c', {4, 26, 19, 4}); - auto w = NDArrayFactory::create('c', {11, 7, 4}); - - sd::ops::dilation2d op; - auto result = op.evaluate({&x, &w}, {}, {0, 1,2,3,1, 1,3,2,1}); - ASSERT_EQ(Status::OK(), result.status()); + auto x = NDArrayFactory::create('c', {4, 26, 19, 4}); + auto w = NDArrayFactory::create('c', {11, 7, 4}); + sd::ops::dilation2d op; + auto result = op.evaluate({&x, &w}, {}, {0, 1, 2, 3, 1, 1, 3, 2, 1}); + ASSERT_EQ(Status::OK(), result.status()); } TYPED_TEST(TypedConvolutionTests2, sconv2d_bp_1) { - TypeParam _expGradWpB[] = {1603.7102981f, 10645.6278024f, 5975.4227995f, 17697.0903052f, 12133.6353024f, 26535.0528052f, 1779.221097f, 11795.5686029f, 6721.9835994f, 19904.0811062f, 13775.2461029f, 30123.0936062f, 1954.7318976f, 12945.5094033f, 7468.5443993f, 22111.071907f, 15416.8569033f, 33711.134407f, 2130.2426974f, 14095.4502038f, 8215.1051992f, 24318.0627081f, 17058.4677038f, 37299.1752081f, 2305.7534972f, 15245.3910042f, 8961.6659991f, 26525.0535091f, 18700.0785042f, 40887.2160091f, 2481.2642970f, 16395.3318047f, 9708.2267991f, 28732.0443100f, 20341.6893047f, 44475.2568100f, 2656.7750968f, 17545.2726051f, 10454.7875990f, 30939.0351110f, 21983.3001051f, 48063.2976110f, 2832.2858966f, 18695.2134056f, 11201.3483989f, 33146.0259119f, 23624.9109056f, 51651.3384119f, 3007.7966964f, 19845.1542060f, 11947.9091988f, 35353.0167129f, 25266.5217060f, 55239.3792129f, 3183.3074962f, 20995.095006f, 12694.4699987f, 37560.007513f, 26908.132506f, 58827.4200139f}; - Nd4jLong _expGradWpS[] {4, 10, 6, 1, 1, 6, 1, 1, 1, typeid(TypeParam) == typeid(float) ? 8192 : 16384, 1, 99}; - NDArray expGWP(_expGradWpB, _expGradWpS); - expGWP.permutei({2,3,1,0}); - - TypeParam _expGradWdB[] = {2074.21032f, 2082.76104f, 2091.31176f, 2099.86248f, 2108.4132f, 2159.71752f, 2168.26824f, 2176.81896f, 2185.36968f, 2193.9204f, 2245.22472f, 2253.77544f, 2262.32616f, 2270.87688f, 2279.4276f, 2330.73192f, 2339.28264f, 2347.83336f, 2356.38408f, 2364.9348f, 2416.23912f, 2424.78984f, 2433.34056f, 2441.89128f, 2450.442f, 3112.99344f, 3122.06328f, 3131.13312f, 3140.20296f, 3149.2728f, 3203.69184f, 3212.76168f, 3221.83152f, 3230.90136f, 3239.9712f, 3294.39024f, 3303.46008f, 3312.52992f, 3321.59976f, 3330.6696f, 3385.08864f, 3394.15848f, 3403.22832f, 3412.29816f, 3421.368f, 3475.78704f, 3484.85688f, 3493.92672f, 3502.99656f, 3512.0664f, 4255.60056f, 4265.18952f, 4274.77848f, 4284.36744f, 4293.9564f, 4351.49016f, 4361.07912f, 4370.66808f, 4380.25704f, 4389.846f, 4447.37976f, 4456.96872f, 4466.55768f, 4476.14664f, 4485.7356f, 4543.26936f, 4552.85832f, 4562.44728f, 4572.03624f, 4581.6252f, 4639.15896f, 4648.74792f, 4658.33688f, 4667.92584f, 4677.5148f, 2140.10988f, 2148.92016f, 2157.73044f, 2166.54072f, 2175.351f, 2228.21268f, 2237.02296f, 2245.83324f, 2254.64352f, 2263.4538f, 2316.31548f, 2325.12576f, 2333.93604f, 2342.74632f, 2351.5566f, 2404.41828f, 2413.22856f, 2422.03884f, 2430.84912f, 2439.6594f, 2492.52108f, 2501.33136f, 2510.14164f, 2518.95192f, 2527.7622f, 3204.849f, 3214.1784f, 3223.5078f, 3232.8372f, 3242.1666f, 3298.143f, 3307.4724f, 3316.8018f, 3326.1312f, 3335.4606f, 3391.437f, 3400.7664f, 3410.0958f, 3419.4252f, 3428.7546f, 3484.731f, 3494.0604f, 3503.3898f, 3512.7192f, 3522.0486f, 3578.025f, 3587.3544f, 3596.6838f, 3606.0132f, 3615.3426f, 4373.41212f, 4383.26064f, 4393.10916f, 4402.95768f, 4412.8062f, 4471.89732f, 4481.74584f, 4491.59436f, 4501.44288f, 4511.2914f, 4570.38252f, 4580.23104f, 4590.07956f, 4599.92808f, 4609.7766f, 4668.86772f, 4678.71624f, 4688.56476f, 4698.41328f, 4708.2618f, 4767.35292f, 4777.20144f, 4787.04996f, 4796.89848f, 4806.747f}; - Nd4jLong _expGradWdS[] = {4, 2, 3, 5, 5, 75, 25, 5, 1, typeid(TypeParam) == typeid(float) ? 8192 : 16384, 1, 99}; - NDArray expGWD(_expGradWdB, _expGradWdS); - expGWD.permutei({2,3,1,0}); - - TypeParam _expEB[] = {5.0103f, 10.17147f, 15.48408f, 20.9487f, 26.5659f, 26.6832f, 21.65628f, 16.47507f, 11.139f, 5.6475f, 10.79727f, 21.90255f, 33.31698f, 45.0417f, 57.07785f, 57.3267f, 46.49334f, 35.34513f, 23.88093f, 12.0996f, 17.37801f, 35.22744f, 53.55f, 72.3474f, 91.62135f, 92.016f, 74.57958f, 56.66148f, 38.25999f, 19.3734f, 24.76962f, 50.18034f, 76.23444f, 102.9342f, 130.2819f, 130.8366f, 105.9834f, 80.47542f, 54.31038f, 27.486f, 32.9892f, 66.79545f, 101.4216f, 136.8705f, 173.145f, 173.874f, 140.7732f, 106.83825f, 72.0663f, 36.4545f, 33.8298f, 68.49375f, 103.9947f, 140.3355f, 177.519f, 178.248f, 144.3066f, 109.51395f, 73.8672f, 37.3635f, 28.85658f, 58.39302f, 88.6116f, 119.5146f, 151.1043f, 151.716f, 122.76444f, 93.11934f, 62.77842f, 31.7394f, 23.00409f, 46.52748f, 70.57188f, 95.139f, 120.23055f, 120.7107f, 97.6311f, 74.02194f, 49.88151f, 25.2081f, 16.25523f, 32.86293f, 49.82424f, 67.1403f, 84.81225f, 85.1466f, 68.83818f, 52.17045f, 35.14227f, 17.7525f, 8.5929f, 17.36517f, 26.31738f, 35.4501f, 44.7639f, 44.9382f, 36.31728f, 27.51357f, 18.5265f, 9.3555f, 8.63807f, 17.45032f, 26.43736f, 35.5998f, 44.93825f, 45.1399f, 36.46882f, 27.6199f, 18.59253f, 9.3861f, 18.18615f, 36.72737f, 55.62488f, 74.8799f, 94.49365f, 94.9122f, 76.65698f, 58.03937f, 39.05815f, 19.7121f, 28.66254f, 57.86775f, 87.61746f, 117.9135f, 148.7577f, 149.4084f, 120.63768f, 91.31331f, 61.43346f, 30.9963f, 40.08554f, 80.90806f, 122.47f, 164.7738f, 207.8219f, 208.72f, 168.48412f, 127.49662f, 85.75506f, 43.257f, 52.47345f, 105.8849f, 160.2374f, 215.534f, 271.77775f, 272.9385f, 220.2695f, 166.6442f, 112.05955f, 56.5125f, 53.82975f, 108.6158f, 164.3612f, 221.069f, 278.74225f, 279.903f, 225.8777f, 170.8778f, 114.90025f, 57.942f, 45.14002f, 91.0585f, 137.75788f, 185.2406f, 233.5091f, 234.4682f, 189.16564f, 143.06998f, 96.17878f, 48.4896f, 35.43048f, 71.45487f, 108.075f, 145.2927f, 183.1098f, 183.852f, 148.29504f, 112.13319f, 75.36462f, 37.9875f, 24.68283f, 49.76831f, 75.25766f, 101.1521f, 127.45285f, 127.9629f, 103.1927f, 78.01253f, 52.42117f, 26.4174f, 12.87877f, 25.96222f, 39.25096f, 52.7456f, 66.44675f, 66.7094f, 53.78542f, 40.6531f, 27.31183f, 13.761f, 12.59184f, 25.38317f, 38.37464f, 51.5669f, 64.9606f, 65.2566f, 52.61336f, 39.76673f, 26.71606f, 13.4607f, 26.23903f, 52.88419f, 79.93678f, 107.3981f, 135.26945f, 135.8777f, 109.53262f, 82.77361f, 55.59937f, 28.0086f, 40.96107f, 82.54206f, 124.74492f, 167.5716f, 211.02405f, 211.9608f, 170.83578f, 129.07914f, 86.68893f, 43.6632f, 56.77746f, 114.39578f, 172.85756f, 232.1654f, 292.3219f, 293.6034f, 236.60084f, 178.74182f, 120.02374f, 60.444f, 73.7077f, 148.48435f, 224.3332f, 301.2575f, 379.2605f, 380.903f, 306.9058f, 231.82015f, 155.6428f, 78.3705f, 75.6397f, 152.36785f, 230.1877f, 309.1025f, 389.1155f, 390.758f, 314.8288f, 237.79165f, 159.6433f, 80.3805f, 62.89546f, 126.67598f, 191.34416f, 256.9026f, 323.3539f, 324.7004f, 261.56684f, 197.53262f, 132.59514f, 66.7518f, 48.97887f, 98.63226f, 148.96212f, 199.9704f, 251.65905f, 252.6933f, 203.53098f, 153.68244f, 103.14573f, 51.9189f, 33.87043f, 68.19769f, 102.98308f, 138.2279f, 173.93345f, 174.6392f, 140.64322f, 106.18261f, 71.25607f, 35.8623f, 17.55064f, 35.33327f, 53.34854f, 71.5971f, 90.0796f, 90.4406f, 72.82556f, 54.97463f, 36.88716f, 18.5625f, 13.0455f, 26.44707f, 40.20528f, 54.3207f, 68.7939f, 68.9112f, 55.84908f, 42.42747f, 28.6458f, 14.5035f, 27.89367f, 56.50575f, 85.83738f, 115.8897f, 146.66385f, 146.9127f, 118.98294f, 90.32793f, 60.94653f, 30.8376f, 44.56161f, 90.21024f, 136.9476f, 184.7754f, 233.69535f, 234.09f, 189.46998f, 143.75268f, 96.93639f, 49.0194f, 63.06642f, 127.59474f, 193.58724f, 261.0462f, 329.9739f, 330.5286f, 267.3786f, 202.75302f, 136.64958f, 69.066f, 83.4252f, 168.69345f, 255.8076f, 344.7705f, 435.585f, 436.314f, 352.7772f, 267.38025f, 180.1203f, 90.9945f, 84.2658f, 170.39175f, 258.3807f, 348.2355f, 439.959f, 440.688f, 356.3106f, 270.05595f, 181.9212f, 91.9035f, 71.25738f, 144.01542f, 218.2764f, 294.0426f, 371.3163f, 371.928f, 300.57564f, 227.70894f, 153.32562f, 77.4234f, 56.34369f, 113.82228f, 172.43748f, 232.191f, 293.08455f, 293.5647f, 237.1455f, 179.58114f, 120.86991f, 61.0101f, 39.50763f, 79.77813f, 120.81264f, 162.6123f, 205.17825f, 205.5126f, 165.95178f, 125.62125f, 84.51987f, 42.6465f, 20.7321f, 41.84877f, 63.35058f, 85.2381f, 107.5119f, 107.6862f, 86.92608f, 65.77797f, 44.2413f, 22.3155f, 22.71767f, 45.82912f, 69.33496f, 93.2358f, 117.53225f, 117.7339f, 94.98322f, 71.8351f, 48.28893f, 24.3441f, 47.44335f, 95.68097f, 144.71408f, 194.5439f, 245.17165f, 245.5902f, 198.07778f, 149.76377f, 100.64695f, 50.7261f, 74.19534f, 149.59215f, 226.19226f, 303.9975f, 383.0097f, 383.6604f, 309.35688f, 233.84091f, 157.11066f, 79.1643f, 102.99194f, 207.59926f, 313.8244f, 421.6698f, 531.1379f, 532.036f, 428.89372f, 324.12142f, 217.71666f, 109.677f, 133.85145f, 269.7389f, 407.6654f, 547.634f, 689.64775f, 690.8085f, 556.7615f, 420.6602f, 282.50155f, 142.2825f, 135.20775f, 272.4698f, 411.7892f, 553.169f, 696.61225f, 697.773f, 562.3697f, 424.8938f, 285.34225f, 143.712f, 112.43842f, 226.5337f, 342.28828f, 459.7046f, 578.7851f, 579.7442f, 467.14324f, 352.87078f, 236.92438f, 119.3016f, 87.55128f, 176.35527f, 266.4138f, 357.7287f, 450.3018f, 451.044f, 363.36624f, 274.42479f, 184.21782f, 92.7435f, 60.52803f, 121.89791f, 184.11086f, 247.1681f, 311.07085f, 311.5809f, 250.9655f, 189.50093f, 127.18597f, 64.0194f, 31.35037f, 63.12502f, 95.32456f, 127.9496f, 161.00075f, 161.2634f, 129.86782f, 98.0443f, 65.79223f, 33.111f, 33.43584f, 67.30517f, 101.60864f, 136.3469f, 171.5206f, 171.8166f, 138.32936f, 104.40473f, 70.04206f, 35.2407f, 69.09703f, 139.06819f, 209.91478f, 281.6381f, 354.23945f, 354.8477f, 285.64462f, 215.55961f, 144.59137f, 72.7386f, 107.00307f, 215.32806f, 324.97692f, 435.9516f, 548.25405f, 549.1908f, 442.02378f, 333.52314f, 223.68693f, 112.5132f, 147.17346f, 296.12378f, 446.85356f, 599.3654f, 753.6619f, 754.9434f, 607.54484f, 458.35382f, 307.36774f, 154.584f, 189.6277f, 381.49435f, 575.6032f, 771.9575f, 970.5605f, 972.203f, 782.2858f, 590.11015f, 395.6728f, 198.9705f, 191.5597f, 385.37785f, 581.4577f, 779.8025f, 980.4155f, 982.058f, 790.2088f, 596.08165f, 399.6733f, 200.9805f, 157.97146f, 317.76398f, 479.38016f, 642.8226f, 808.0939f, 809.4404f, 651.23084f, 491.18462f, 329.29914f, 165.5718f, 122.04087f, 245.45826f, 370.25412f, 496.4304f, 623.98905f, 625.0233f, 502.79898f, 379.18644f, 254.18373f, 127.7889f, 83.74843f, 168.42169f, 254.02108f, 340.5479f, 428.00345f, 428.7092f, 344.83522f, 260.02861f, 174.28807f, 87.6123f, 43.07464f, 86.61527f, 130.62254f, 175.0971f, 220.0396f, 220.4006f, 177.26156f, 133.65263f, 89.57316f, 45.0225f }; - Nd4jLong _expES[] = {4, 2, 3, 10, 10, 300, 100, 10, 1, typeid(TypeParam) == typeid(float) ? 8192 : 16384, 1, 99}; - NDArray expE(_expEB, _expES); - - auto input = NDArrayFactory::create('c', {2, 3, 10, 10}); - auto weightsD = NDArrayFactory::create('c', {2, 3, 5, 5}); - auto weightsP = NDArrayFactory::create('c', {10, 6, 1, 1}); - - auto epsilon = NDArrayFactory::create('c', {2, 3, 10, 10}); - auto epsilonNext = NDArrayFactory::create('c', {2, 10, 6, 6}); - - input.linspace(1); - weightsD.linspace(1); - weightsP.linspace(1); - epsilonNext.linspace(1); - weightsD.permutei({2,3,1,0}); - weightsP.permutei({2,3,1,0}); - - input.applyScalar(scalar::Divide, 100.0, input); - weightsD.applyScalar(scalar::Divide, 100.0, weightsD); - weightsP.applyScalar(scalar::Divide, 100.0, weightsP); - epsilonNext.applyScalar(scalar::Divide, 100.0, epsilonNext); - - sd::ops::sconv2d_bp op; - auto resultBP = op.evaluate({&input, &epsilonNext, &weightsD, &weightsP },{}, {5, 5, 1, 1, 0, 0, 1, 1, 0}, {}); - - ASSERT_EQ(3, resultBP.size()); - - auto _epsilon = resultBP.at(0); - auto _gradWD = resultBP.at(1); - auto _gradWP = resultBP.at(2); - - //_gradWP->printBuffer("gradWP"); - - ASSERT_TRUE(_gradWP->isSameShape(&expGWP)); - ASSERT_TRUE(_gradWP->isSameShape(&weightsP)); - - ASSERT_TRUE(_gradWP->equalsTo(&expGWP)); - - //_gradWD->printShapeInfo("gradWD shape"); - - ASSERT_TRUE(_gradWD->isSameShape(&expGWD)); - ASSERT_TRUE(_gradWD->isSameShape(&weightsD)); -// _gradWD->printIndexedBuffer(); - ASSERT_TRUE(_gradWD->equalsTo(&expGWD)); - - ASSERT_TRUE(_epsilon->isSameShape(&input)); - ASSERT_TRUE(_epsilon->isSameShape(&expE)); - - ASSERT_TRUE(_epsilon->equalsTo(&expE)); - + TypeParam _expGradWpB[] = { + 1603.7102981f, 10645.6278024f, 5975.4227995f, 17697.0903052f, + 12133.6353024f, 26535.0528052f, 1779.221097f, 11795.5686029f, + 6721.9835994f, 19904.0811062f, 13775.2461029f, 30123.0936062f, + 1954.7318976f, 12945.5094033f, 7468.5443993f, 22111.071907f, + 15416.8569033f, 33711.134407f, 2130.2426974f, 14095.4502038f, + 8215.1051992f, 24318.0627081f, 17058.4677038f, 37299.1752081f, + 2305.7534972f, 15245.3910042f, 8961.6659991f, 26525.0535091f, + 18700.0785042f, 40887.2160091f, 2481.2642970f, 16395.3318047f, + 9708.2267991f, 28732.0443100f, 20341.6893047f, 44475.2568100f, + 2656.7750968f, 17545.2726051f, 10454.7875990f, 30939.0351110f, + 21983.3001051f, 48063.2976110f, 2832.2858966f, 18695.2134056f, + 11201.3483989f, 33146.0259119f, 23624.9109056f, 51651.3384119f, + 3007.7966964f, 19845.1542060f, 11947.9091988f, 35353.0167129f, + 25266.5217060f, 55239.3792129f, 3183.3074962f, 20995.095006f, + 12694.4699987f, 37560.007513f, 26908.132506f, 58827.4200139f}; + Nd4jLong _expGradWpS[]{ + 4, 10, 6, 1, 1, + 6, 1, 1, 1, typeid(TypeParam) == typeid(float) ? 8192 : 16384, + 1, 99}; + NDArray expGWP(_expGradWpB, _expGradWpS); + expGWP.permutei({2, 3, 1, 0}); + + TypeParam _expGradWdB[] = { + 2074.21032f, 2082.76104f, 2091.31176f, 2099.86248f, 2108.4132f, + 2159.71752f, 2168.26824f, 2176.81896f, 2185.36968f, 2193.9204f, + 2245.22472f, 2253.77544f, 2262.32616f, 2270.87688f, 2279.4276f, + 2330.73192f, 2339.28264f, 2347.83336f, 2356.38408f, 2364.9348f, + 2416.23912f, 2424.78984f, 2433.34056f, 2441.89128f, 2450.442f, + 3112.99344f, 3122.06328f, 3131.13312f, 3140.20296f, 3149.2728f, + 3203.69184f, 3212.76168f, 3221.83152f, 3230.90136f, 3239.9712f, + 3294.39024f, 3303.46008f, 3312.52992f, 3321.59976f, 3330.6696f, + 3385.08864f, 3394.15848f, 3403.22832f, 3412.29816f, 3421.368f, + 3475.78704f, 3484.85688f, 3493.92672f, 3502.99656f, 3512.0664f, + 4255.60056f, 4265.18952f, 4274.77848f, 4284.36744f, 4293.9564f, + 4351.49016f, 4361.07912f, 4370.66808f, 4380.25704f, 4389.846f, + 4447.37976f, 4456.96872f, 4466.55768f, 4476.14664f, 4485.7356f, + 4543.26936f, 4552.85832f, 4562.44728f, 4572.03624f, 4581.6252f, + 4639.15896f, 4648.74792f, 4658.33688f, 4667.92584f, 4677.5148f, + 2140.10988f, 2148.92016f, 2157.73044f, 2166.54072f, 2175.351f, + 2228.21268f, 2237.02296f, 2245.83324f, 2254.64352f, 2263.4538f, + 2316.31548f, 2325.12576f, 2333.93604f, 2342.74632f, 2351.5566f, + 2404.41828f, 2413.22856f, 2422.03884f, 2430.84912f, 2439.6594f, + 2492.52108f, 2501.33136f, 2510.14164f, 2518.95192f, 2527.7622f, + 3204.849f, 3214.1784f, 3223.5078f, 3232.8372f, 3242.1666f, + 3298.143f, 3307.4724f, 3316.8018f, 3326.1312f, 3335.4606f, + 3391.437f, 3400.7664f, 3410.0958f, 3419.4252f, 3428.7546f, + 3484.731f, 3494.0604f, 3503.3898f, 3512.7192f, 3522.0486f, + 3578.025f, 3587.3544f, 3596.6838f, 3606.0132f, 3615.3426f, + 4373.41212f, 4383.26064f, 4393.10916f, 4402.95768f, 4412.8062f, + 4471.89732f, 4481.74584f, 4491.59436f, 4501.44288f, 4511.2914f, + 4570.38252f, 4580.23104f, 4590.07956f, 4599.92808f, 4609.7766f, + 4668.86772f, 4678.71624f, 4688.56476f, 4698.41328f, 4708.2618f, + 4767.35292f, 4777.20144f, 4787.04996f, 4796.89848f, 4806.747f}; + Nd4jLong _expGradWdS[] = { + 4, 2, 3, 5, 5, + 75, 25, 5, 1, typeid(TypeParam) == typeid(float) ? 8192 : 16384, + 1, 99}; + NDArray expGWD(_expGradWdB, _expGradWdS); + expGWD.permutei({2, 3, 1, 0}); + + TypeParam _expEB[] = { + 5.0103f, 10.17147f, 15.48408f, 20.9487f, 26.5659f, 26.6832f, + 21.65628f, 16.47507f, 11.139f, 5.6475f, 10.79727f, 21.90255f, + 33.31698f, 45.0417f, 57.07785f, 57.3267f, 46.49334f, 35.34513f, + 23.88093f, 12.0996f, 17.37801f, 35.22744f, 53.55f, 72.3474f, + 91.62135f, 92.016f, 74.57958f, 56.66148f, 38.25999f, 19.3734f, + 24.76962f, 50.18034f, 76.23444f, 102.9342f, 130.2819f, 130.8366f, + 105.9834f, 80.47542f, 54.31038f, 27.486f, 32.9892f, 66.79545f, + 101.4216f, 136.8705f, 173.145f, 173.874f, 140.7732f, 106.83825f, + 72.0663f, 36.4545f, 33.8298f, 68.49375f, 103.9947f, 140.3355f, + 177.519f, 178.248f, 144.3066f, 109.51395f, 73.8672f, 37.3635f, + 28.85658f, 58.39302f, 88.6116f, 119.5146f, 151.1043f, 151.716f, + 122.76444f, 93.11934f, 62.77842f, 31.7394f, 23.00409f, 46.52748f, + 70.57188f, 95.139f, 120.23055f, 120.7107f, 97.6311f, 74.02194f, + 49.88151f, 25.2081f, 16.25523f, 32.86293f, 49.82424f, 67.1403f, + 84.81225f, 85.1466f, 68.83818f, 52.17045f, 35.14227f, 17.7525f, + 8.5929f, 17.36517f, 26.31738f, 35.4501f, 44.7639f, 44.9382f, + 36.31728f, 27.51357f, 18.5265f, 9.3555f, 8.63807f, 17.45032f, + 26.43736f, 35.5998f, 44.93825f, 45.1399f, 36.46882f, 27.6199f, + 18.59253f, 9.3861f, 18.18615f, 36.72737f, 55.62488f, 74.8799f, + 94.49365f, 94.9122f, 76.65698f, 58.03937f, 39.05815f, 19.7121f, + 28.66254f, 57.86775f, 87.61746f, 117.9135f, 148.7577f, 149.4084f, + 120.63768f, 91.31331f, 61.43346f, 30.9963f, 40.08554f, 80.90806f, + 122.47f, 164.7738f, 207.8219f, 208.72f, 168.48412f, 127.49662f, + 85.75506f, 43.257f, 52.47345f, 105.8849f, 160.2374f, 215.534f, + 271.77775f, 272.9385f, 220.2695f, 166.6442f, 112.05955f, 56.5125f, + 53.82975f, 108.6158f, 164.3612f, 221.069f, 278.74225f, 279.903f, + 225.8777f, 170.8778f, 114.90025f, 57.942f, 45.14002f, 91.0585f, + 137.75788f, 185.2406f, 233.5091f, 234.4682f, 189.16564f, 143.06998f, + 96.17878f, 48.4896f, 35.43048f, 71.45487f, 108.075f, 145.2927f, + 183.1098f, 183.852f, 148.29504f, 112.13319f, 75.36462f, 37.9875f, + 24.68283f, 49.76831f, 75.25766f, 101.1521f, 127.45285f, 127.9629f, + 103.1927f, 78.01253f, 52.42117f, 26.4174f, 12.87877f, 25.96222f, + 39.25096f, 52.7456f, 66.44675f, 66.7094f, 53.78542f, 40.6531f, + 27.31183f, 13.761f, 12.59184f, 25.38317f, 38.37464f, 51.5669f, + 64.9606f, 65.2566f, 52.61336f, 39.76673f, 26.71606f, 13.4607f, + 26.23903f, 52.88419f, 79.93678f, 107.3981f, 135.26945f, 135.8777f, + 109.53262f, 82.77361f, 55.59937f, 28.0086f, 40.96107f, 82.54206f, + 124.74492f, 167.5716f, 211.02405f, 211.9608f, 170.83578f, 129.07914f, + 86.68893f, 43.6632f, 56.77746f, 114.39578f, 172.85756f, 232.1654f, + 292.3219f, 293.6034f, 236.60084f, 178.74182f, 120.02374f, 60.444f, + 73.7077f, 148.48435f, 224.3332f, 301.2575f, 379.2605f, 380.903f, + 306.9058f, 231.82015f, 155.6428f, 78.3705f, 75.6397f, 152.36785f, + 230.1877f, 309.1025f, 389.1155f, 390.758f, 314.8288f, 237.79165f, + 159.6433f, 80.3805f, 62.89546f, 126.67598f, 191.34416f, 256.9026f, + 323.3539f, 324.7004f, 261.56684f, 197.53262f, 132.59514f, 66.7518f, + 48.97887f, 98.63226f, 148.96212f, 199.9704f, 251.65905f, 252.6933f, + 203.53098f, 153.68244f, 103.14573f, 51.9189f, 33.87043f, 68.19769f, + 102.98308f, 138.2279f, 173.93345f, 174.6392f, 140.64322f, 106.18261f, + 71.25607f, 35.8623f, 17.55064f, 35.33327f, 53.34854f, 71.5971f, + 90.0796f, 90.4406f, 72.82556f, 54.97463f, 36.88716f, 18.5625f, + 13.0455f, 26.44707f, 40.20528f, 54.3207f, 68.7939f, 68.9112f, + 55.84908f, 42.42747f, 28.6458f, 14.5035f, 27.89367f, 56.50575f, + 85.83738f, 115.8897f, 146.66385f, 146.9127f, 118.98294f, 90.32793f, + 60.94653f, 30.8376f, 44.56161f, 90.21024f, 136.9476f, 184.7754f, + 233.69535f, 234.09f, 189.46998f, 143.75268f, 96.93639f, 49.0194f, + 63.06642f, 127.59474f, 193.58724f, 261.0462f, 329.9739f, 330.5286f, + 267.3786f, 202.75302f, 136.64958f, 69.066f, 83.4252f, 168.69345f, + 255.8076f, 344.7705f, 435.585f, 436.314f, 352.7772f, 267.38025f, + 180.1203f, 90.9945f, 84.2658f, 170.39175f, 258.3807f, 348.2355f, + 439.959f, 440.688f, 356.3106f, 270.05595f, 181.9212f, 91.9035f, + 71.25738f, 144.01542f, 218.2764f, 294.0426f, 371.3163f, 371.928f, + 300.57564f, 227.70894f, 153.32562f, 77.4234f, 56.34369f, 113.82228f, + 172.43748f, 232.191f, 293.08455f, 293.5647f, 237.1455f, 179.58114f, + 120.86991f, 61.0101f, 39.50763f, 79.77813f, 120.81264f, 162.6123f, + 205.17825f, 205.5126f, 165.95178f, 125.62125f, 84.51987f, 42.6465f, + 20.7321f, 41.84877f, 63.35058f, 85.2381f, 107.5119f, 107.6862f, + 86.92608f, 65.77797f, 44.2413f, 22.3155f, 22.71767f, 45.82912f, + 69.33496f, 93.2358f, 117.53225f, 117.7339f, 94.98322f, 71.8351f, + 48.28893f, 24.3441f, 47.44335f, 95.68097f, 144.71408f, 194.5439f, + 245.17165f, 245.5902f, 198.07778f, 149.76377f, 100.64695f, 50.7261f, + 74.19534f, 149.59215f, 226.19226f, 303.9975f, 383.0097f, 383.6604f, + 309.35688f, 233.84091f, 157.11066f, 79.1643f, 102.99194f, 207.59926f, + 313.8244f, 421.6698f, 531.1379f, 532.036f, 428.89372f, 324.12142f, + 217.71666f, 109.677f, 133.85145f, 269.7389f, 407.6654f, 547.634f, + 689.64775f, 690.8085f, 556.7615f, 420.6602f, 282.50155f, 142.2825f, + 135.20775f, 272.4698f, 411.7892f, 553.169f, 696.61225f, 697.773f, + 562.3697f, 424.8938f, 285.34225f, 143.712f, 112.43842f, 226.5337f, + 342.28828f, 459.7046f, 578.7851f, 579.7442f, 467.14324f, 352.87078f, + 236.92438f, 119.3016f, 87.55128f, 176.35527f, 266.4138f, 357.7287f, + 450.3018f, 451.044f, 363.36624f, 274.42479f, 184.21782f, 92.7435f, + 60.52803f, 121.89791f, 184.11086f, 247.1681f, 311.07085f, 311.5809f, + 250.9655f, 189.50093f, 127.18597f, 64.0194f, 31.35037f, 63.12502f, + 95.32456f, 127.9496f, 161.00075f, 161.2634f, 129.86782f, 98.0443f, + 65.79223f, 33.111f, 33.43584f, 67.30517f, 101.60864f, 136.3469f, + 171.5206f, 171.8166f, 138.32936f, 104.40473f, 70.04206f, 35.2407f, + 69.09703f, 139.06819f, 209.91478f, 281.6381f, 354.23945f, 354.8477f, + 285.64462f, 215.55961f, 144.59137f, 72.7386f, 107.00307f, 215.32806f, + 324.97692f, 435.9516f, 548.25405f, 549.1908f, 442.02378f, 333.52314f, + 223.68693f, 112.5132f, 147.17346f, 296.12378f, 446.85356f, 599.3654f, + 753.6619f, 754.9434f, 607.54484f, 458.35382f, 307.36774f, 154.584f, + 189.6277f, 381.49435f, 575.6032f, 771.9575f, 970.5605f, 972.203f, + 782.2858f, 590.11015f, 395.6728f, 198.9705f, 191.5597f, 385.37785f, + 581.4577f, 779.8025f, 980.4155f, 982.058f, 790.2088f, 596.08165f, + 399.6733f, 200.9805f, 157.97146f, 317.76398f, 479.38016f, 642.8226f, + 808.0939f, 809.4404f, 651.23084f, 491.18462f, 329.29914f, 165.5718f, + 122.04087f, 245.45826f, 370.25412f, 496.4304f, 623.98905f, 625.0233f, + 502.79898f, 379.18644f, 254.18373f, 127.7889f, 83.74843f, 168.42169f, + 254.02108f, 340.5479f, 428.00345f, 428.7092f, 344.83522f, 260.02861f, + 174.28807f, 87.6123f, 43.07464f, 86.61527f, 130.62254f, 175.0971f, + 220.0396f, 220.4006f, 177.26156f, 133.65263f, 89.57316f, 45.0225f}; + Nd4jLong _expES[] = { + 4, 2, 3, 10, 10, + 300, 100, 10, 1, typeid(TypeParam) == typeid(float) ? 8192 : 16384, + 1, 99}; + NDArray expE(_expEB, _expES); + + auto input = NDArrayFactory::create('c', {2, 3, 10, 10}); + auto weightsD = NDArrayFactory::create('c', {2, 3, 5, 5}); + auto weightsP = NDArrayFactory::create('c', {10, 6, 1, 1}); + + auto epsilon = NDArrayFactory::create('c', {2, 3, 10, 10}); + auto epsilonNext = NDArrayFactory::create('c', {2, 10, 6, 6}); + + input.linspace(1); + weightsD.linspace(1); + weightsP.linspace(1); + epsilonNext.linspace(1); + weightsD.permutei({2, 3, 1, 0}); + weightsP.permutei({2, 3, 1, 0}); + + input.applyScalar(scalar::Divide, 100.0, input); + weightsD.applyScalar(scalar::Divide, 100.0, weightsD); + weightsP.applyScalar(scalar::Divide, 100.0, weightsP); + epsilonNext.applyScalar(scalar::Divide, 100.0, epsilonNext); + + sd::ops::sconv2d_bp op; + auto resultBP = op.evaluate({&input, &epsilonNext, &weightsD, &weightsP}, {}, + {5, 5, 1, 1, 0, 0, 1, 1, 0}, {}); + + ASSERT_EQ(3, resultBP.size()); + + auto _epsilon = resultBP.at(0); + auto _gradWD = resultBP.at(1); + auto _gradWP = resultBP.at(2); + + //_gradWP->printBuffer("gradWP"); + + ASSERT_TRUE(_gradWP.isSameShape(&expGWP)); + ASSERT_TRUE(_gradWP.isSameShape(&weightsP)); + + ASSERT_TRUE(_gradWP.equalsTo(&expGWP)); + + //_gradWD->printShapeInfo("gradWD shape"); + + ASSERT_TRUE(_gradWD.isSameShape(&expGWD)); + ASSERT_TRUE(_gradWD.isSameShape(&weightsD)); + // _gradWD->printIndexedBuffer(); + ASSERT_TRUE(_gradWD.equalsTo(&expGWD)); + + ASSERT_TRUE(_epsilon.isSameShape(&input)); + ASSERT_TRUE(_epsilon.isSameShape(&expE)); + + ASSERT_TRUE(_epsilon.equalsTo(&expE)); } ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests2, sconv2d_bp_2) { - - int bS=3, iH=16,iW=16, iC=3,mC=3, kH=1,kW=1, sH=1,sW=1, pH=0,pW=0, dH=2,dW=2; - int oH=16,oW=16; - int oC=2; - int paddingMode = 0; // 1-SAME, 0-VALID; - int dataFormat = 0; // 1-NHWC, 0-NCHW - - NDArray input('c', {bS, iC, iH, iW}, typeid(TypeParam) == typeid(float) ? sd::DataType::FLOAT32 : sd::DataType::DOUBLE); - NDArray gradO('c', {bS, oC, oH, oW}, typeid(TypeParam) == typeid(float) ? sd::DataType::FLOAT32 : sd::DataType::DOUBLE); - NDArray weightsDepth('c', {kH, kW, iC, mC}, typeid(TypeParam) == typeid(float) ? sd::DataType::FLOAT32 : sd::DataType::DOUBLE); - NDArray weightsPoint('f', {1, 1, iC*mC, oC}, typeid(TypeParam) == typeid(float) ? sd::DataType::FLOAT32 : sd::DataType::DOUBLE); - NDArray bias('c', {1,oC}, {0.5, 0.5}, typeid(TypeParam) == typeid(float) ? sd::DataType::FLOAT32 : sd::DataType::DOUBLE); - - NDArray gradI(&input); - NDArray gradWD(&weightsDepth); - NDArray gradWP(&weightsPoint); - NDArray gradB(&bias); - - input = 2.; - weightsDepth.linspace(0.1, 0.1); - weightsPoint.linspace(0.15, 0.1); - gradO.linspace(0.01, 0.01); - - sd::ops::sconv2d_bp op; - Nd4jStatus status = op.execute({&input, &gradO, &weightsDepth, & weightsPoint, &bias}, - {&gradI, &gradWD, &gradWP, &gradB}, - {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}, {}); - + int bS = 3, iH = 16, iW = 16, iC = 3, mC = 3, kH = 1, kW = 1, sH = 1, sW = 1, + pH = 0, pW = 0, dH = 2, dW = 2; + int oH = 16, oW = 16; + int oC = 2; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NHWC, 0-NCHW + + NDArray input('c', {bS, iC, iH, iW}, + typeid(TypeParam) == typeid(float) ? sd::DataType::FLOAT32 + : sd::DataType::DOUBLE); + NDArray gradO('c', {bS, oC, oH, oW}, + typeid(TypeParam) == typeid(float) ? sd::DataType::FLOAT32 + : sd::DataType::DOUBLE); + NDArray weightsDepth('c', {kH, kW, iC, mC}, + typeid(TypeParam) == typeid(float) + ? sd::DataType::FLOAT32 + : sd::DataType::DOUBLE); + NDArray weightsPoint('f', {1, 1, iC * mC, oC}, + typeid(TypeParam) == typeid(float) + ? sd::DataType::FLOAT32 + : sd::DataType::DOUBLE); + NDArray bias('c', {1, oC}, {0.5, 0.5}, + typeid(TypeParam) == typeid(float) ? sd::DataType::FLOAT32 + : sd::DataType::DOUBLE); + + NDArray gradI(&input); + NDArray gradWD(&weightsDepth); + NDArray gradWP(&weightsPoint); + NDArray gradB(&bias); + + input = 2.; + weightsDepth.linspace(0.1, 0.1); + weightsPoint.linspace(0.15, 0.1); + gradO.linspace(0.01, 0.01); + + sd::ops::sconv2d_bp op; + Nd4jStatus status = + op.execute({&input, &gradO, &weightsDepth, &weightsPoint, &bias}, + {&gradI, &gradWD, &gradWP, &gradB}, {}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, dataFormat}, {}); + + ASSERT_EQ(Status::OK(), status); + + NDArray expGradI = gradI; + NDArray expGradWD = gradWD; + NDArray expGradWP = gradWP; + NDArray expGradB = gradB; + + for (int i = 0; i < 10; i++) { + Nd4jStatus status = op.execute( + {&input, &gradO, &weightsDepth, &weightsPoint, &bias}, + {&gradI, &gradWD, &gradWP, &gradB}, {}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, dataFormat}, {}); ASSERT_EQ(Status::OK(), status); - NDArray expGradI = gradI; - NDArray expGradWD = gradWD; - NDArray expGradWP = gradWP; - NDArray expGradB = gradB; - - for( int i=0; i<10; i++ ) { - Nd4jStatus status = op.execute({&input, &gradO, &weightsDepth, & weightsPoint, &bias}, - {&gradI, &gradWD, &gradWP, &gradB}, - {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}, {}); - ASSERT_EQ(Status::OK(), status); - - ASSERT_TRUE(expGradI.equalsTo(gradI)); - ASSERT_TRUE(expGradWD.equalsTo(gradWD)); - ASSERT_TRUE(expGradWP.equalsTo(gradWP)); - ASSERT_TRUE(expGradB.equalsTo(expGradB)); - } + ASSERT_TRUE(expGradI.equalsTo(gradI)); + ASSERT_TRUE(expGradWD.equalsTo(gradWD)); + ASSERT_TRUE(expGradWP.equalsTo(gradWP)); + ASSERT_TRUE(expGradB.equalsTo(expGradB)); + } } ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests2, sconv2d_bp_3) { + auto input = NDArrayFactory::create('c', {3, 3, 16, 16}); + auto weightsD = NDArrayFactory::create('c', {1, 3, 2, 2}); + auto weightsP = NDArrayFactory::create('c', {2, 3, 1, 1}); + auto bias = NDArrayFactory::create('c', {1, 2}); - auto input = NDArrayFactory::create('c', {3, 3, 16, 16}); - auto weightsD = NDArrayFactory::create('c', {1, 3, 2, 2}); - auto weightsP = NDArrayFactory::create('c', {2, 3, 1, 1}); - auto bias = NDArrayFactory::create('c', {1, 2}); - - weightsD.permutei({2,3,1,0}); - weightsP.permutei({2,3,1,0}); + weightsD.permutei({2, 3, 1, 0}); + weightsP.permutei({2, 3, 1, 0}); - auto epsilonNext = NDArrayFactory::create('c', {3, 2, 14, 14}); + auto epsilonNext = NDArrayFactory::create('c', {3, 2, 14, 14}); - auto epsilon = NDArrayFactory::create('c', {3, 3, 16, 16}); + auto epsilon = NDArrayFactory::create('c', {3, 3, 16, 16}); - sd::ops::sconv2d_bp op; - auto result = op.evaluate({&input, &epsilonNext, &weightsD, &weightsP}, {}, {2, 2, 1, 1, 0, 0, 2, 2, 0}); + sd::ops::sconv2d_bp op; + auto result = op.evaluate({&input, &epsilonNext, &weightsD, &weightsP}, {}, + {2, 2, 1, 1, 0, 0, 2, 2, 0}); - auto eps = result.at(0); - auto gWD = result.at(1); - auto gWP = result.at(2); + auto eps = result.at(0); + auto gWD = result.at(1); + auto gWP = result.at(2); - - ASSERT_TRUE(epsilon.isSameShape(eps)); + ASSERT_TRUE(epsilon.isSameShape(eps)); } ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests2, sconv2d_bp_4) { - - int bS=2, iH=4,iW=3, iC=2,mC=2, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oH=4,oW=3; - int oC=iC*mC; - int paddingMode = 1; // 1-SAME, 0-VALID; - int dataFormat = 1; // 1-NHWC, 0-NCHW - - auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); - auto weightsDepth = NDArrayFactory::create('c', {kH, kW, iC, mC}); - auto bias = NDArrayFactory::create('c', {oC}, {1,2,3,4}); - auto gradO = NDArrayFactory::create('c', {bS, oH, oW, oC}); - - auto expGradI = NDArrayFactory::create('c', {bS, iH, iW, iC},{0.07f, 0.19f, 0.348f, 0.652f, 0.588f, 0.956f, 0.387f, 0.687f, 1.326f, 2.022f, 1.878f, 2.67f, 1.071f, 1.515f, 2.982f, 3.966f, 3.534f, 4.614f, 1.606f, 1.982f, 3.932f, 4.748f, 4.428f, 5.308f, - 1.126f, 1.63f, 3.228f, 4.3f, 3.468f, 4.604f, 3.123f, 3.999f, 7.95f, 9.798f, 8.502f, 10.446f, 3.807f, 4.827f, 9.606f, 11.742f,10.158f, 12.39f, 4.198f, 4.958f, 9.884f, 11.468f,10.38f, 12.028f}); - - auto expGradW = NDArrayFactory::create('c', {kH, kW, iC, mC},{19.08f, 19.44f, 19.8f, 20.16f, 12.24f, 12.48f, 12.72f, 12.96f, 22.56f, 23.04f, 23.52f, 24.f, 14.4f, 14.72f, 15.04f, 15.36f, 14.76f, 15.12f, 15.48f, 15.84f, 9.36f, 9.6f, 9.84f, 10.08f}); - - input = 2.; - weightsDepth.linspace(0.1, 0.1); - gradO.linspace(0.01, 0.01); - - sd::ops::sconv2d_bp op; - auto results = op.evaluate({&input, &gradO, &weightsDepth, &bias}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); - auto* gradI = results.at(0); - auto* gradWD = results.at(1); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expGradI.isSameShape(gradI)); - ASSERT_TRUE(expGradI.equalsTo(gradI)); - - ASSERT_TRUE(expGradW.isSameShape(gradWD)); - ASSERT_TRUE(expGradW.equalsTo(gradWD)); - + int bS = 2, iH = 4, iW = 3, iC = 2, mC = 2, kH = 3, kW = 2, sH = 1, sW = 1, + pH = 0, pW = 0, dH = 1, dW = 1; + int oH = 4, oW = 3; + int oC = iC * mC; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NHWC, 0-NCHW + + auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); + auto weightsDepth = NDArrayFactory::create('c', {kH, kW, iC, mC}); + auto bias = NDArrayFactory::create('c', {oC}, {1, 2, 3, 4}); + auto gradO = NDArrayFactory::create('c', {bS, oH, oW, oC}); + + auto expGradI = NDArrayFactory::create( + 'c', {bS, iH, iW, iC}, + {0.07f, 0.19f, 0.348f, 0.652f, 0.588f, 0.956f, 0.387f, 0.687f, + 1.326f, 2.022f, 1.878f, 2.67f, 1.071f, 1.515f, 2.982f, 3.966f, + 3.534f, 4.614f, 1.606f, 1.982f, 3.932f, 4.748f, 4.428f, 5.308f, + 1.126f, 1.63f, 3.228f, 4.3f, 3.468f, 4.604f, 3.123f, 3.999f, + 7.95f, 9.798f, 8.502f, 10.446f, 3.807f, 4.827f, 9.606f, 11.742f, + 10.158f, 12.39f, 4.198f, 4.958f, 9.884f, 11.468f, 10.38f, 12.028f}); + + auto expGradW = NDArrayFactory::create( + 'c', {kH, kW, iC, mC}, + {19.08f, 19.44f, 19.8f, 20.16f, 12.24f, 12.48f, 12.72f, 12.96f, + 22.56f, 23.04f, 23.52f, 24.f, 14.4f, 14.72f, 15.04f, 15.36f, + 14.76f, 15.12f, 15.48f, 15.84f, 9.36f, 9.6f, 9.84f, 10.08f}); + + input = 2.; + weightsDepth.linspace(0.1, 0.1); + gradO.linspace(0.01, 0.01); + + sd::ops::sconv2d_bp op; + auto results = + op.evaluate({&input, &gradO, &weightsDepth, &bias}, {}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, dataFormat}); + auto gradI = results.at(0); + auto gradWD = results.at(1); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); + + ASSERT_TRUE(expGradW.isSameShape(gradWD)); + ASSERT_TRUE(expGradW.equalsTo(gradWD)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests2, sconv2d_bp_5) { - - int bS=1, iH=8,iW=8, iC=3,mC=3, kH=1,kW=1, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oH=8,oW=8; - int oC=2; // iC*mC if weightsPoint = nullptr - int paddingMode = 0; // 1-SAME, 0-VALID; - int dataFormat = 0; // 1-NHWC, 0-NCHW - - auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); - auto gradO = NDArrayFactory::create('c', {bS, oC, oH, oW}); - auto weightsDepth = NDArrayFactory::create('c', {kH, kW, iC, mC}); - auto weightsPoint = NDArrayFactory::create('c', {1, 1, iC*mC, oC}); - auto bias = NDArrayFactory::create('c', {1,oC}, {1,2}); - - auto gradI = NDArrayFactory::create('c', {bS, iC, iH, iW}); - auto gradWD = NDArrayFactory::create('f', {kH, kW, iC, mC}); - auto gradWP = NDArrayFactory::create('c', {1, 1, iC*mC, oC}); - auto gradB = NDArrayFactory::create('c', {1,oC}, {1,2}); - - input = 2.; - weightsDepth.linspace(0.1, 0.1); - weightsDepth.linspace(-0.5, 0.1); - gradO.linspace(0.01, 0.01); - - sd::ops::sconv2d_bp op; - auto status = op.execute({&input, &gradO, &weightsDepth, &weightsPoint, &bias}, {&gradI, &gradWD, &gradWP, &gradB}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}, {}); - ASSERT_EQ(Status::OK(), status); + int bS = 1, iH = 8, iW = 8, iC = 3, mC = 3, kH = 1, kW = 1, sH = 1, sW = 1, + pH = 0, pW = 0, dH = 1, dW = 1; + int oH = 8, oW = 8; + int oC = 2; // iC*mC if weightsPoint = nullptr + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NHWC, 0-NCHW + + auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); + auto gradO = NDArrayFactory::create('c', {bS, oC, oH, oW}); + auto weightsDepth = NDArrayFactory::create('c', {kH, kW, iC, mC}); + auto weightsPoint = NDArrayFactory::create('c', {1, 1, iC * mC, oC}); + auto bias = NDArrayFactory::create('c', {1, oC}, {1, 2}); + + auto gradI = NDArrayFactory::create('c', {bS, iC, iH, iW}); + auto gradWD = NDArrayFactory::create('f', {kH, kW, iC, mC}); + auto gradWP = NDArrayFactory::create('c', {1, 1, iC * mC, oC}); + auto gradB = NDArrayFactory::create('c', {1, oC}, {1, 2}); + + input = 2.; + weightsDepth.linspace(0.1, 0.1); + weightsDepth.linspace(-0.5, 0.1); + gradO.linspace(0.01, 0.01); + + sd::ops::sconv2d_bp op; + auto status = + op.execute({&input, &gradO, &weightsDepth, &weightsPoint, &bias}, + {&gradI, &gradWD, &gradWP, &gradB}, {}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, dataFormat}, {}); + ASSERT_EQ(Status::OK(), status); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests2, im2col_bp_1) { + int bS = 3, iH = 12, iW = 12, iC = 6, oC = 3, kH = 2, kW = 2, sH = 1, sW = 1, + pH = 0, pW = 0, dH = 1, dW = 1; + int oH = 12, oW = 12; - int bS=3, iH=12,iW=12, iC=6,oC=3, kH=2,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oH=12,oW=12; + // [bS, iC, kH, kW, oH, oW] is de-convoluted to [bS, iC, iH, iW] + NDArray input('c', {bS, iC, iH, iW}, sd::DataType::DOUBLE); + NDArray gradO('c', {bS, iC, kH, kW, oH, oW}, sd::DataType::DOUBLE); + NDArray gradI('c', {bS, iC, iH, iW}, sd::DataType::DOUBLE); // output - // [bS, iC, kH, kW, oH, oW] is de-convoluted to [bS, iC, iH, iW] - NDArray input('c', {bS, iC, iH, iW}, sd::DataType::DOUBLE); - NDArray gradO('c', {bS, iC, kH, kW, oH, oW}, sd::DataType::DOUBLE); - NDArray gradI('c', {bS, iC, iH, iW}, sd::DataType::DOUBLE); // output - - sd::ops::im2col_bp op; - Nd4jStatus status = op.execute({&input, &gradO}, {&gradI}, {}, {kH, kW, sH, sW, pH, pW, dH, dW, 1}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, status); + sd::ops::im2col_bp op; + Nd4jStatus status = op.execute({&input, &gradO}, {&gradI}, {}, + {kH, kW, sH, sW, pH, pW, dH, dW, 1}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests2, deconv3d_test1) { - - int bS=2, iD=4,iH=4,iW=4, iC=2,oC=3, kD=2,kH=2,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; - int oD=3,oH=3,oW=3; - int paddingMode = 0; // 1-SAME, 0-VALID; - int dataFormat = 1; // 1-NDHWC, 0-NCDHW - - auto input = NDArrayFactory::create('c', {bS, oD, oH, oW, oC}); - auto weights = NDArrayFactory::create('c', {kD, kH, kW, iC, oC}); - auto exp = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}, {0.3 , 0.75, 1.5 , 2.4 , 1.5 , 2.4 , 1.2 , 1.65, 2.4 , 3.3 , 6.6 , 8.4 , 6.6 , 8.4 , 4.2 , 5.1 , 2.4 , 3.3 , 6.6 , 8.4 , 6.6 , 8.4 , 4.2 , 5.1 , 2.1 , 2.55, 5.1 , 6. , 5.1 , 6. , 3. , 3.45, - 4.2 , 5.1 ,10.2 ,12. ,10.2 ,12. , 6. , 6.9 ,12. ,13.8 ,27.6 ,31.2 ,27.6 ,31.2 ,15.6 ,17.4 ,12. ,13.8 ,27.6 ,31.2 ,27.6 ,31.2 ,15.6 ,17.4 , 7.8 , 8.7 ,17.4 ,19.2 ,17.4 ,19.2 , 9.6 ,10.5 , - 4.2 , 5.1 ,10.2 ,12. ,10.2 ,12. , 6. , 6.9 ,12. ,13.8 ,27.6 ,31.2 ,27.6 ,31.2 ,15.6 ,17.4 ,12. ,13.8 ,27.6 ,31.2 ,27.6 ,31.2 ,15.6 ,17.4 , 7.8 , 8.7 ,17.4 ,19.2 ,17.4 ,19.2 , 9.6 ,10.5 , - 3.9 , 4.35, 8.7 , 9.6 , 8.7 , 9.6 , 4.8 , 5.25, 9.6 ,10.5 ,21. ,22.8 ,21. ,22.8 ,11.4 ,12.3 , 9.6 ,10.5 ,21. ,22.8 ,21. ,22.8 ,11.4 ,12.3 , 5.7 , 6.15,12.3 ,13.2 ,12.3 ,13.2 , 6.6 , 7.05, - 0.3 , 0.75, 1.5 , 2.4 , 1.5 , 2.4 , 1.2 , 1.65, 2.4 , 3.3 , 6.6 , 8.4 , 6.6 , 8.4 , 4.2 , 5.1 , 2.4 , 3.3 , 6.6 , 8.4 , 6.6 , 8.4 , 4.2 , 5.1 , 2.1 , 2.55, 5.1 , 6. , 5.1 , 6. , 3. , 3.45, - 4.2 , 5.1 ,10.2 ,12. ,10.2 ,12. , 6. , 6.9 ,12. ,13.8 ,27.6 ,31.2 ,27.6 ,31.2 ,15.6 ,17.4 ,12. ,13.8 ,27.6 ,31.2 ,27.6 ,31.2 ,15.6 ,17.4 , 7.8 , 8.7 ,17.4 ,19.2 ,17.4 ,19.2 , 9.6 ,10.5 , - 4.2 , 5.1 ,10.2 ,12. ,10.2 ,12. , 6. , 6.9 ,12. ,13.8 ,27.6 ,31.2 ,27.6 ,31.2 ,15.6 ,17.4 ,12. ,13.8 ,27.6 ,31.2 ,27.6 ,31.2 ,15.6 ,17.4 , 7.8 , 8.7 ,17.4 ,19.2 ,17.4 ,19.2 , 9.6 ,10.5 , - 3.9 , 4.35, 8.7 , 9.6 , 8.7 , 9.6 , 4.8 , 5.25, 9.6 ,10.5 ,21. ,22.8 ,21. ,22.8 ,11.4 ,12.3 , 9.6 ,10.5 ,21. ,22.8 ,21. ,22.8 ,11.4 ,12.3 , 5.7 , 6.15,12.3 ,13.2 ,12.3 ,13.2 , 6.6 , 7.05}); - input = 0.5; - weights.linspace(0.1, 0.1); - - sd::ops::deconv3d op; - auto results = op.evaluate({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}, {}); - auto output = results.at(0); - - // output->printBuffer(); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - + int bS = 2, iD = 4, iH = 4, iW = 4, iC = 2, oC = 3, kD = 2, kH = 2, kW = 2, + sD = 1, sH = 1, sW = 1, pD = 0, pH = 0, pW = 0, dD = 1, dH = 1, dW = 1; + int oD = 3, oH = 3, oW = 3; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, oD, oH, oW, oC}); + auto weights = NDArrayFactory::create('c', {kD, kH, kW, iC, oC}); + auto exp = NDArrayFactory::create( + 'c', {bS, iD, iH, iW, iC}, + {0.3, 0.75, 1.5, 2.4, 1.5, 2.4, 1.2, 1.65, 2.4, 3.3, 6.6, 8.4, + 6.6, 8.4, 4.2, 5.1, 2.4, 3.3, 6.6, 8.4, 6.6, 8.4, 4.2, 5.1, + 2.1, 2.55, 5.1, 6., 5.1, 6., 3., 3.45, 4.2, 5.1, 10.2, 12., + 10.2, 12., 6., 6.9, 12., 13.8, 27.6, 31.2, 27.6, 31.2, 15.6, 17.4, + 12., 13.8, 27.6, 31.2, 27.6, 31.2, 15.6, 17.4, 7.8, 8.7, 17.4, 19.2, + 17.4, 19.2, 9.6, 10.5, 4.2, 5.1, 10.2, 12., 10.2, 12., 6., 6.9, + 12., 13.8, 27.6, 31.2, 27.6, 31.2, 15.6, 17.4, 12., 13.8, 27.6, 31.2, + 27.6, 31.2, 15.6, 17.4, 7.8, 8.7, 17.4, 19.2, 17.4, 19.2, 9.6, 10.5, + 3.9, 4.35, 8.7, 9.6, 8.7, 9.6, 4.8, 5.25, 9.6, 10.5, 21., 22.8, + 21., 22.8, 11.4, 12.3, 9.6, 10.5, 21., 22.8, 21., 22.8, 11.4, 12.3, + 5.7, 6.15, 12.3, 13.2, 12.3, 13.2, 6.6, 7.05, 0.3, 0.75, 1.5, 2.4, + 1.5, 2.4, 1.2, 1.65, 2.4, 3.3, 6.6, 8.4, 6.6, 8.4, 4.2, 5.1, + 2.4, 3.3, 6.6, 8.4, 6.6, 8.4, 4.2, 5.1, 2.1, 2.55, 5.1, 6., + 5.1, 6., 3., 3.45, 4.2, 5.1, 10.2, 12., 10.2, 12., 6., 6.9, + 12., 13.8, 27.6, 31.2, 27.6, 31.2, 15.6, 17.4, 12., 13.8, 27.6, 31.2, + 27.6, 31.2, 15.6, 17.4, 7.8, 8.7, 17.4, 19.2, 17.4, 19.2, 9.6, 10.5, + 4.2, 5.1, 10.2, 12., 10.2, 12., 6., 6.9, 12., 13.8, 27.6, 31.2, + 27.6, 31.2, 15.6, 17.4, 12., 13.8, 27.6, 31.2, 27.6, 31.2, 15.6, 17.4, + 7.8, 8.7, 17.4, 19.2, 17.4, 19.2, 9.6, 10.5, 3.9, 4.35, 8.7, 9.6, + 8.7, 9.6, 4.8, 5.25, 9.6, 10.5, 21., 22.8, 21., 22.8, 11.4, 12.3, + 9.6, 10.5, 21., 22.8, 21., 22.8, 11.4, 12.3, 5.7, 6.15, 12.3, 13.2, + 12.3, 13.2, 6.6, 7.05}); + input = 0.5; + weights.linspace(0.1, 0.1); + + sd::ops::deconv3d op; + auto results = op.evaluate( + {&input, &weights}, {}, + {kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, paddingMode, dataFormat}, + {}); + auto output = results.at(0); + + // output->printBuffer(); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests2, deconv3d_test2) { - - int bS=2, iD=4,iH=4,iW=4, iC=2,oC=3, kD=2,kH=2,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; - int oD=4,oH=4,oW=4; - int paddingMode = 1; // 1-SAME, 0-VALID; - int dataFormat = 1; // 1-NDHWC, 0-NCDHW - - auto input = NDArrayFactory::create('c', {bS, oD, oH, oW, oC}); - auto weights = NDArrayFactory::create('c', {kD, kH, kW, iC, oC}); - auto exp = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}, {0.3 , 0.75, 1.5 , 2.4 , 1.5 , 2.4 , 1.5 , 2.4 , 2.4 , 3.3 , 6.6 , 8.4 , 6.6 , 8.4 , 6.6 , 8.4 , 2.4 , 3.3 , 6.6 , 8.4 , 6.6 , 8.4 , 6.6 , 8.4 , 2.4 , 3.3 , 6.6 , 8.4 , 6.6 , 8.4 , 6.6 , 8.4 , - 4.2 , 5.1 ,10.2 , 12. ,10.2 , 12. ,10.2 , 12. ,12. , 13.8 ,27.6 , 31.2 ,27.6 , 31.2 ,27.6 , 31.2 ,12. , 13.8 ,27.6 , 31.2 ,27.6 , 31.2 ,27.6 , 31.2 ,12. , 13.8 ,27.6 , 31.2 ,27.6 , 31.2 ,27.6 , 31.2 , - 4.2 , 5.1 ,10.2 , 12. ,10.2 , 12. ,10.2 , 12. ,12. , 13.8 ,27.6 , 31.2 ,27.6 , 31.2 ,27.6 , 31.2 ,12. , 13.8 ,27.6 , 31.2 ,27.6 , 31.2 ,27.6 , 31.2 ,12. , 13.8 ,27.6 , 31.2 ,27.6 , 31.2 ,27.6 , 31.2 , - 4.2 , 5.1 ,10.2 , 12. ,10.2 , 12. ,10.2 , 12. ,12. , 13.8 ,27.6 , 31.2 ,27.6 , 31.2 ,27.6 , 31.2 ,12. , 13.8 ,27.6 , 31.2 ,27.6 , 31.2 ,27.6 , 31.2 ,12. , 13.8 ,27.6 , 31.2 ,27.6 , 31.2 ,27.6 , 31.2 , - 0.3 , 0.75, 1.5 , 2.4 , 1.5 , 2.4 , 1.5 , 2.4 , 2.4 , 3.3 , 6.6 , 8.4 , 6.6 , 8.4 , 6.6 , 8.4 , 2.4 , 3.3 , 6.6 , 8.4 , 6.6 , 8.4 , 6.6 , 8.4 , 2.4 , 3.3 , 6.6 , 8.4 , 6.6 , 8.4 , 6.6 , 8.4 , - 4.2 , 5.1 ,10.2 , 12. ,10.2 , 12. ,10.2 , 12. ,12. , 13.8 ,27.6 , 31.2 ,27.6 , 31.2 ,27.6 , 31.2 ,12. , 13.8 ,27.6 , 31.2 ,27.6 , 31.2 ,27.6 , 31.2 ,12. , 13.8 ,27.6 , 31.2 ,27.6 , 31.2 ,27.6 , 31.2 , - 4.2 , 5.1 ,10.2 , 12. ,10.2 , 12. ,10.2 , 12. ,12. , 13.8 ,27.6 , 31.2 ,27.6 , 31.2 ,27.6 , 31.2 ,12. , 13.8 ,27.6 , 31.2 ,27.6 , 31.2 ,27.6 , 31.2 ,12. , 13.8 ,27.6 , 31.2 ,27.6 , 31.2 ,27.6 , 31.2 , - 4.2 , 5.1 ,10.2 , 12. ,10.2 , 12. ,10.2 , 12. ,12. , 13.8 ,27.6 , 31.2 ,27.6 , 31.2 ,27.6 , 31.2 ,12. , 13.8 ,27.6 , 31.2 ,27.6 , 31.2 ,27.6 , 31.2 ,12. , 13.8 ,27.6 , 31.2 ,27.6 , 31.2 ,27.6 , 31.2 }); - input = 0.5; - weights.linspace(0.1, 0.1); - - sd::ops::deconv3d op; - auto results = op.evaluate({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}, {}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - + int bS = 2, iD = 4, iH = 4, iW = 4, iC = 2, oC = 3, kD = 2, kH = 2, kW = 2, + sD = 1, sH = 1, sW = 1, pD = 0, pH = 0, pW = 0, dD = 1, dH = 1, dW = 1; + int oD = 4, oH = 4, oW = 4; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, oD, oH, oW, oC}); + auto weights = NDArrayFactory::create('c', {kD, kH, kW, iC, oC}); + auto exp = NDArrayFactory::create( + 'c', {bS, iD, iH, iW, iC}, + {0.3, 0.75, 1.5, 2.4, 1.5, 2.4, 1.5, 2.4, 2.4, 3.3, 6.6, 8.4, + 6.6, 8.4, 6.6, 8.4, 2.4, 3.3, 6.6, 8.4, 6.6, 8.4, 6.6, 8.4, + 2.4, 3.3, 6.6, 8.4, 6.6, 8.4, 6.6, 8.4, 4.2, 5.1, 10.2, 12., + 10.2, 12., 10.2, 12., 12., 13.8, 27.6, 31.2, 27.6, 31.2, 27.6, 31.2, + 12., 13.8, 27.6, 31.2, 27.6, 31.2, 27.6, 31.2, 12., 13.8, 27.6, 31.2, + 27.6, 31.2, 27.6, 31.2, 4.2, 5.1, 10.2, 12., 10.2, 12., 10.2, 12., + 12., 13.8, 27.6, 31.2, 27.6, 31.2, 27.6, 31.2, 12., 13.8, 27.6, 31.2, + 27.6, 31.2, 27.6, 31.2, 12., 13.8, 27.6, 31.2, 27.6, 31.2, 27.6, 31.2, + 4.2, 5.1, 10.2, 12., 10.2, 12., 10.2, 12., 12., 13.8, 27.6, 31.2, + 27.6, 31.2, 27.6, 31.2, 12., 13.8, 27.6, 31.2, 27.6, 31.2, 27.6, 31.2, + 12., 13.8, 27.6, 31.2, 27.6, 31.2, 27.6, 31.2, 0.3, 0.75, 1.5, 2.4, + 1.5, 2.4, 1.5, 2.4, 2.4, 3.3, 6.6, 8.4, 6.6, 8.4, 6.6, 8.4, + 2.4, 3.3, 6.6, 8.4, 6.6, 8.4, 6.6, 8.4, 2.4, 3.3, 6.6, 8.4, + 6.6, 8.4, 6.6, 8.4, 4.2, 5.1, 10.2, 12., 10.2, 12., 10.2, 12., + 12., 13.8, 27.6, 31.2, 27.6, 31.2, 27.6, 31.2, 12., 13.8, 27.6, 31.2, + 27.6, 31.2, 27.6, 31.2, 12., 13.8, 27.6, 31.2, 27.6, 31.2, 27.6, 31.2, + 4.2, 5.1, 10.2, 12., 10.2, 12., 10.2, 12., 12., 13.8, 27.6, 31.2, + 27.6, 31.2, 27.6, 31.2, 12., 13.8, 27.6, 31.2, 27.6, 31.2, 27.6, 31.2, + 12., 13.8, 27.6, 31.2, 27.6, 31.2, 27.6, 31.2, 4.2, 5.1, 10.2, 12., + 10.2, 12., 10.2, 12., 12., 13.8, 27.6, 31.2, 27.6, 31.2, 27.6, 31.2, + 12., 13.8, 27.6, 31.2, 27.6, 31.2, 27.6, 31.2, 12., 13.8, 27.6, 31.2, + 27.6, 31.2, 27.6, 31.2}); + input = 0.5; + weights.linspace(0.1, 0.1); + + sd::ops::deconv3d op; + auto results = op.evaluate( + {&input, &weights}, {}, + {kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, paddingMode, dataFormat}, + {}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests2, deconv3d_test3) { - - int bS=2, iD=4,iH=4,iW=4, iC=2,oC=3, kD=2,kH=2,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; - int oD=3,oH=3,oW=3; - int paddingMode = 0; // 1-SAME, 0-VALID; - int dataFormat = 0; // 1-NDHWC, 0-NCDHW - - auto input = NDArrayFactory::create('c', {bS, oC, oD, oH, oW}); - auto weights = NDArrayFactory::create('c', {oC, iC, kD, kH, kW}); - auto exp = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}, {2.55, 5.25, 5.25, 2.7, 5.4 , 11.1 , 11.1 , 5.7, 5.4 , 11.1 , 11.1 , 5.7, 2.85, 5.85, 5.85, 3. , 5.7 , 11.7 , 11.7 , 6. ,12. , 24.6 , 24.6 , 12.6,12. , 24.6 , 24.6 , 12.6, 6.3 , 12.9 , 12.9 , 6.6, - 5.7 , 11.7 , 11.7 , 6. ,12. , 24.6 , 24.6 , 12.6,12. , 24.6 , 24.6 , 12.6, 6.3 , 12.9 , 12.9 , 6.6, 3.15, 6.45, 6.45, 3.3, 6.6 , 13.5 , 13.5 , 6.9, 6.6 , 13.5 , 13.5 , 6.9, 3.45, 7.05, 7.05, 3.6, - 3.75, 7.65, 7.65, 3.9, 7.8 , 15.9 , 15.9 , 8.1, 7.8 , 15.9 , 15.9 , 8.1, 4.05, 8.25, 8.25, 4.2, 8.1 , 16.5 , 16.5 , 8.4,16.8 , 34.2 , 34.2 , 17.4,16.8 , 34.2 , 34.2 , 17.4, 8.7 , 17.7 , 17.7 , 9. , - 8.1 , 16.5 , 16.5 , 8.4,16.8 , 34.2 , 34.2 , 17.4,16.8 , 34.2 , 34.2 , 17.4, 8.7 , 17.7 , 17.7 , 9. , 4.35, 8.85, 8.85, 4.5, 9. , 18.3 , 18.3 , 9.3, 9. , 18.3 , 18.3 , 9.3, 4.65, 9.45, 9.45, 4.8, - 2.55, 5.25, 5.25, 2.7, 5.4 , 11.1 , 11.1 , 5.7, 5.4 , 11.1 , 11.1 , 5.7, 2.85, 5.85, 5.85, 3. , 5.7 , 11.7 , 11.7 , 6. ,12. , 24.6 , 24.6 , 12.6,12. , 24.6 , 24.6 , 12.6, 6.3 , 12.9 , 12.9 , 6.6, - 5.7 , 11.7 , 11.7 , 6. ,12. , 24.6 , 24.6 , 12.6,12. , 24.6 , 24.6 , 12.6, 6.3 , 12.9 , 12.9 , 6.6, 3.15, 6.45, 6.45, 3.3, 6.6 , 13.5 , 13.5 , 6.9, 6.6 , 13.5 , 13.5 , 6.9, 3.45, 7.05, 7.05, 3.6, - 3.75, 7.65, 7.65, 3.9, 7.8 , 15.9 , 15.9 , 8.1, 7.8 , 15.9 , 15.9 , 8.1, 4.05, 8.25, 8.25, 4.2, 8.1 , 16.5 , 16.5 , 8.4,16.8 , 34.2 , 34.2 , 17.4,16.8 , 34.2 , 34.2 , 17.4, 8.7 , 17.7 , 17.7 , 9. , - 8.1 , 16.5 , 16.5 , 8.4,16.8 , 34.2 , 34.2 , 17.4,16.8 , 34.2 , 34.2 , 17.4, 8.7 , 17.7 , 17.7 , 9. , 4.35, 8.85, 8.85, 4.5, 9. , 18.3 , 18.3 , 9.3, 9. , 18.3 , 18.3 , 9.3, 4.65, 9.45, 9.45, 4.8}); - input = 0.5; - weights.linspace(0.1, 0.1); - weights.permutei({2, 3, 4, 1, 0}); - - sd::ops::deconv3d op; - auto results = op.evaluate({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}, {}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - + int bS = 2, iD = 4, iH = 4, iW = 4, iC = 2, oC = 3, kD = 2, kH = 2, kW = 2, + sD = 1, sH = 1, sW = 1, pD = 0, pH = 0, pW = 0, dD = 1, dH = 1, dW = 1; + int oD = 3, oH = 3, oW = 3; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, oC, oD, oH, oW}); + auto weights = NDArrayFactory::create('c', {oC, iC, kD, kH, kW}); + auto exp = NDArrayFactory::create( + 'c', {bS, iC, iD, iH, iW}, + {2.55, 5.25, 5.25, 2.7, 5.4, 11.1, 11.1, 5.7, 5.4, 11.1, 11.1, 5.7, + 2.85, 5.85, 5.85, 3., 5.7, 11.7, 11.7, 6., 12., 24.6, 24.6, 12.6, + 12., 24.6, 24.6, 12.6, 6.3, 12.9, 12.9, 6.6, 5.7, 11.7, 11.7, 6., + 12., 24.6, 24.6, 12.6, 12., 24.6, 24.6, 12.6, 6.3, 12.9, 12.9, 6.6, + 3.15, 6.45, 6.45, 3.3, 6.6, 13.5, 13.5, 6.9, 6.6, 13.5, 13.5, 6.9, + 3.45, 7.05, 7.05, 3.6, 3.75, 7.65, 7.65, 3.9, 7.8, 15.9, 15.9, 8.1, + 7.8, 15.9, 15.9, 8.1, 4.05, 8.25, 8.25, 4.2, 8.1, 16.5, 16.5, 8.4, + 16.8, 34.2, 34.2, 17.4, 16.8, 34.2, 34.2, 17.4, 8.7, 17.7, 17.7, 9., + 8.1, 16.5, 16.5, 8.4, 16.8, 34.2, 34.2, 17.4, 16.8, 34.2, 34.2, 17.4, + 8.7, 17.7, 17.7, 9., 4.35, 8.85, 8.85, 4.5, 9., 18.3, 18.3, 9.3, + 9., 18.3, 18.3, 9.3, 4.65, 9.45, 9.45, 4.8, 2.55, 5.25, 5.25, 2.7, + 5.4, 11.1, 11.1, 5.7, 5.4, 11.1, 11.1, 5.7, 2.85, 5.85, 5.85, 3., + 5.7, 11.7, 11.7, 6., 12., 24.6, 24.6, 12.6, 12., 24.6, 24.6, 12.6, + 6.3, 12.9, 12.9, 6.6, 5.7, 11.7, 11.7, 6., 12., 24.6, 24.6, 12.6, + 12., 24.6, 24.6, 12.6, 6.3, 12.9, 12.9, 6.6, 3.15, 6.45, 6.45, 3.3, + 6.6, 13.5, 13.5, 6.9, 6.6, 13.5, 13.5, 6.9, 3.45, 7.05, 7.05, 3.6, + 3.75, 7.65, 7.65, 3.9, 7.8, 15.9, 15.9, 8.1, 7.8, 15.9, 15.9, 8.1, + 4.05, 8.25, 8.25, 4.2, 8.1, 16.5, 16.5, 8.4, 16.8, 34.2, 34.2, 17.4, + 16.8, 34.2, 34.2, 17.4, 8.7, 17.7, 17.7, 9., 8.1, 16.5, 16.5, 8.4, + 16.8, 34.2, 34.2, 17.4, 16.8, 34.2, 34.2, 17.4, 8.7, 17.7, 17.7, 9., + 4.35, 8.85, 8.85, 4.5, 9., 18.3, 18.3, 9.3, 9., 18.3, 18.3, 9.3, + 4.65, 9.45, 9.45, 4.8}); + input = 0.5; + weights.linspace(0.1, 0.1); + weights.permutei({2, 3, 4, 1, 0}); + + sd::ops::deconv3d op; + auto results = op.evaluate( + {&input, &weights}, {}, + {kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, paddingMode, dataFormat}, + {}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests2, deconv3d_test4) { - - int bS=2, iD=2,iH=2,iW=2, iC=2,oC=3, kD=2,kH=2,kW=2, sD=1,sH=1,sW=1, pD=1,pH=1,pW=1, dD=1,dH=1,dW=1; - int oD=3,oH=3,oW=3; - int paddingMode = 0; // 1-SAME, 0-VALID; - int dataFormat = 0; // 1-NDHWC, 0-NCDHW - - auto input = NDArrayFactory::create('c', {bS, oC, oD, oH, oW}); - auto weights = NDArrayFactory::create('c', {oC, iC, kD, kH, kW}); - auto exp = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}, {24.6, 24.6,24.6, 24.6,24.6, 24.6,24.6, 24.6,34.2, 34.2,34.2, 34.2,34.2, 34.2,34.2, 34.2,24.6, 24.6,24.6, 24.6, - 24.6, 24.6,24.6, 24.6,34.2, 34.2,34.2, 34.2,34.2, 34.2,34.2, 34.2}); - input = 0.5; - weights.linspace(0.1, 0.1); - weights.permutei({2, 3, 4, 1, 0}); - - sd::ops::deconv3d op; - auto results = op.evaluate({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}, {}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - + int bS = 2, iD = 2, iH = 2, iW = 2, iC = 2, oC = 3, kD = 2, kH = 2, kW = 2, + sD = 1, sH = 1, sW = 1, pD = 1, pH = 1, pW = 1, dD = 1, dH = 1, dW = 1; + int oD = 3, oH = 3, oW = 3; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, oC, oD, oH, oW}); + auto weights = NDArrayFactory::create('c', {oC, iC, kD, kH, kW}); + auto exp = NDArrayFactory::create( + 'c', {bS, iC, iD, iH, iW}, + {24.6, 24.6, 24.6, 24.6, 24.6, 24.6, 24.6, 24.6, 34.2, 34.2, 34.2, + 34.2, 34.2, 34.2, 34.2, 34.2, 24.6, 24.6, 24.6, 24.6, 24.6, 24.6, + 24.6, 24.6, 34.2, 34.2, 34.2, 34.2, 34.2, 34.2, 34.2, 34.2}); + input = 0.5; + weights.linspace(0.1, 0.1); + weights.permutei({2, 3, 4, 1, 0}); + + sd::ops::deconv3d op; + auto results = op.evaluate( + {&input, &weights}, {}, + {kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, paddingMode, dataFormat}, + {}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests2, deconv3d_test5) { - int bS=1, oD=5,oH=5,oW=5, oC=3,iC=2, kD=2,kH=2,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=2,dH=2,dW=2; - int iD=3,iH=3,iW=3; - int paddingMode = 0; // 1-SAME, 0-VALID; - int dataFormat = 1; // 1-NHWC, 0-NCHW - - auto input = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); - auto weights = NDArrayFactory::create('c', {kD, kH, kW, oC, iC}); - auto bias = NDArrayFactory::create('c', {oC}); - - auto exp = NDArrayFactory::create('c', {bS, oD, oH, oW, oC}, {-2.9f, -6.8f, -10.7f, -2.6f, -6.1f, -9.6f, -16.9f, -23.9f, -30.9f, -13.1f, -16.6f, -20.1f, -11.6f, -14.7f, -17.8f, -2.0f, -4.7f, -7.4f, -1.7f, -4.0f, -6.3f, -11.5f, - -16.1f, -20.7f, -8.6f, -10.9f, -13.2f, -7.1f, -9.0f, -10.9f, -27.4f, -32.8f, -38.2f, -24.4f, -29.0f, -33.6f, -65.0f, -74.2f, -83.4f, -38.2f, -42.8f, -47.4f, -32.8f, - -36.6f, -40.4f, -18.2f, -20.9f, -23.6f, -15.5f, -17.8f, -20.1f, -39.1f, -43.7f, -48.3f, -22.4f, -24.7f, -27.0f, -18.5f, -20.4f, -22.3f, -10.1f, -11.6f, -13.1f, -7.4f, - -8.5f, -9.6f, -19.3f, -21.5f, -23.7f, -10.7f, -11.8f, -12.9f, -6.8f, -7.5f, -8.2f, -0.2f, -0.5f, -0.8f, 0.1f, 0.2f, 0.3f, -0.7f, -0.5f, -0.3f, 0.4f, 0.5f, 0.6f, 1.9f, 2.4f, - 2.9f, 0.7f, 1.6f, 2.5f, 1.0f, 2.3f, 3.6f, 4.7f, 7.3f, 9.9f, 4.9f, 6.2f, 7.5f, 6.4f, 8.1f, 9.8f, -0.4f, 1.4f, 3.2f, 2.6f, 5.2f, 7.8f, 10.6f, 15.8f, 21.0f, 10.4f, 13.0f, 15.6f, - 15.8f, 19.2f, 22.6f, 6.1f, 7.0f, 7.9f, 8.8f, 10.1f, 11.4f, 20.3f, 22.9f, 25.5f, 12.7f, 14.0f, 15.3f, 16.6f, 18.3f, 20.0f, 14.2f, 16.3f, 18.4f, 16.9f, 19.4f, 21.9f, 40.1f, - 45.1f, 50.1f, 24.4f, 26.9f, 29.4f, 28.3f, 31.2f, 34.1f, -47.2f, -47.8f, -48.4f, -41.8f, -41.6f, -41.4f, -85.4f, -85.f, -84.6f, -41.2f, -41.0f, -40.8f, -33.4f, -32.4f, -31.4f, - -31.f, -29.2f, -27.4f, -25.6f, -23.0f, -20.4f, -45.8f, -40.6f, -35.4f, -17.8f, -15.2f, -12.6f, -10.0f, -6.6f, -3.2f, -65.6f, -62.0f, -58.4f, -50.0f, -44.8f, -39.6f, -89.2f, - -78.8f, -68.4f, -34.4f, -29.2f, -24.f, -14.0f, -7.2f, -0.4f, -20.2f, -18.4f, -16.6f, -10.f, -7.4f, -4.8f, -14.6f, -9.4f, -4.2f, -2.2f, 0.4f, 3.0f, 10.4f, 13.8f, 17.2f, 10.4f, - 14.6f, 18.8f, 20.6f, 25.6f, 30.6f, 53.8f, 63.8f, 73.8f, 35.6f, 40.6f, 45.6f, 48.2f, 54.0f, 59.8f, -3.8f, -4.1f, -4.4f, 1.3f, 1.4f, 1.5f, 1.7f, 1.9f, 2.1f, 1.6f, 1.7f, 1.8f, 7.9f, - 8.4f, 8.9f, 11.5f, 12.4f, 13.3f, 16.6f, 17.9f, 19.2f, 35.9f, 38.5f, 41.1f, 20.5f, 21.8f, 23.1f, 26.8f, 28.5f, 30.2f, 21.2f, 23.0f, 24.8f, 33.8f, 36.4f, 39.0f, 73.0f, 78.2f, - 83.4f, 41.6f, 44.2f, 46.8f, 56.6f, 60.0f, 63.4f, 16.9f, 17.8f, 18.7f, 24.4f, 25.7f, 27.f, 51.5f, 54.1f, 56.7f, 28.3f, 29.6f, 30.9f, 37.0f, 38.7f, 40.4f, 39.4f, 41.5f, - 43.6f, 46.9f, 49.4f, 51.9f, 100.1f, 105.1f, 110.1f, 54.4f, 56.9f, 59.4f, 63.1f, 66.0f, 68.9f, 42.1f, 45.4f, 48.7f, 47.2f, 50.9f, 54.6f, 104.3f, 111.7f, - 119.1f, 58.3f, 62.0f, 65.7f, 64.6f, 68.7f, 72.8f, 57.4f, 61.9f, 66.4f, 62.5f, 67.4f, 72.3f, 138.5f, 148.3f, 158.1f, 77.2f, 82.1f, 87.0f, 83.5f, 88.8f, 94.1f, - 134.6f, 143.6f, 152.6f, 147.2f, 157.0f, 166.8f, 321.4f, 341.0f, 360.6f, 176.6f, 186.4f, 196.2f, 191.6f, 202.2f, 212.8f, 84.4f, 88.9f, - 93.4f, 91.9f, 96.8f, 101.7f, 197.3f, 207.1f, 216.9f, 106.6f, 111.5f, 116.4f, 115.3f, 120.6f, 125.9f, 106.9f, 112.6f, 118.3f, 114.4f, 120.5f, 126.6f, 245.9f, 258.1f, 270.3f, 132.7f, 138.8f, 144.9f, 141.4f, 147.9f, 154.4f}); - - input.linspace(-10, 0.5); - weights.linspace(0.1, 0.1); - bias = 0.2; - - sd::ops::deconv3d op; - auto results = op.evaluate({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}); - ASSERT_EQ(Status::OK(), results.status()); - - auto output = results.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - + int bS = 1, oD = 5, oH = 5, oW = 5, oC = 3, iC = 2, kD = 2, kH = 2, kW = 2, + sD = 1, sH = 1, sW = 1, pD = 0, pH = 0, pW = 0, dD = 2, dH = 2, dW = 2; + int iD = 3, iH = 3, iW = 3; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NHWC, 0-NCHW + + auto input = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); + auto weights = NDArrayFactory::create('c', {kD, kH, kW, oC, iC}); + auto bias = NDArrayFactory::create('c', {oC}); + + auto exp = NDArrayFactory::create( + 'c', {bS, oD, oH, oW, oC}, + {-2.9f, -6.8f, -10.7f, -2.6f, -6.1f, -9.6f, -16.9f, -23.9f, -30.9f, + -13.1f, -16.6f, -20.1f, -11.6f, -14.7f, -17.8f, -2.0f, -4.7f, -7.4f, + -1.7f, -4.0f, -6.3f, -11.5f, -16.1f, -20.7f, -8.6f, -10.9f, -13.2f, + -7.1f, -9.0f, -10.9f, -27.4f, -32.8f, -38.2f, -24.4f, -29.0f, -33.6f, + -65.0f, -74.2f, -83.4f, -38.2f, -42.8f, -47.4f, -32.8f, -36.6f, -40.4f, + -18.2f, -20.9f, -23.6f, -15.5f, -17.8f, -20.1f, -39.1f, -43.7f, -48.3f, + -22.4f, -24.7f, -27.0f, -18.5f, -20.4f, -22.3f, -10.1f, -11.6f, -13.1f, + -7.4f, -8.5f, -9.6f, -19.3f, -21.5f, -23.7f, -10.7f, -11.8f, -12.9f, + -6.8f, -7.5f, -8.2f, -0.2f, -0.5f, -0.8f, 0.1f, 0.2f, 0.3f, + -0.7f, -0.5f, -0.3f, 0.4f, 0.5f, 0.6f, 1.9f, 2.4f, 2.9f, + 0.7f, 1.6f, 2.5f, 1.0f, 2.3f, 3.6f, 4.7f, 7.3f, 9.9f, + 4.9f, 6.2f, 7.5f, 6.4f, 8.1f, 9.8f, -0.4f, 1.4f, 3.2f, + 2.6f, 5.2f, 7.8f, 10.6f, 15.8f, 21.0f, 10.4f, 13.0f, 15.6f, + 15.8f, 19.2f, 22.6f, 6.1f, 7.0f, 7.9f, 8.8f, 10.1f, 11.4f, + 20.3f, 22.9f, 25.5f, 12.7f, 14.0f, 15.3f, 16.6f, 18.3f, 20.0f, + 14.2f, 16.3f, 18.4f, 16.9f, 19.4f, 21.9f, 40.1f, 45.1f, 50.1f, + 24.4f, 26.9f, 29.4f, 28.3f, 31.2f, 34.1f, -47.2f, -47.8f, -48.4f, + -41.8f, -41.6f, -41.4f, -85.4f, -85.f, -84.6f, -41.2f, -41.0f, -40.8f, + -33.4f, -32.4f, -31.4f, -31.f, -29.2f, -27.4f, -25.6f, -23.0f, -20.4f, + -45.8f, -40.6f, -35.4f, -17.8f, -15.2f, -12.6f, -10.0f, -6.6f, -3.2f, + -65.6f, -62.0f, -58.4f, -50.0f, -44.8f, -39.6f, -89.2f, -78.8f, -68.4f, + -34.4f, -29.2f, -24.f, -14.0f, -7.2f, -0.4f, -20.2f, -18.4f, -16.6f, + -10.f, -7.4f, -4.8f, -14.6f, -9.4f, -4.2f, -2.2f, 0.4f, 3.0f, + 10.4f, 13.8f, 17.2f, 10.4f, 14.6f, 18.8f, 20.6f, 25.6f, 30.6f, + 53.8f, 63.8f, 73.8f, 35.6f, 40.6f, 45.6f, 48.2f, 54.0f, 59.8f, + -3.8f, -4.1f, -4.4f, 1.3f, 1.4f, 1.5f, 1.7f, 1.9f, 2.1f, + 1.6f, 1.7f, 1.8f, 7.9f, 8.4f, 8.9f, 11.5f, 12.4f, 13.3f, + 16.6f, 17.9f, 19.2f, 35.9f, 38.5f, 41.1f, 20.5f, 21.8f, 23.1f, + 26.8f, 28.5f, 30.2f, 21.2f, 23.0f, 24.8f, 33.8f, 36.4f, 39.0f, + 73.0f, 78.2f, 83.4f, 41.6f, 44.2f, 46.8f, 56.6f, 60.0f, 63.4f, + 16.9f, 17.8f, 18.7f, 24.4f, 25.7f, 27.f, 51.5f, 54.1f, 56.7f, + 28.3f, 29.6f, 30.9f, 37.0f, 38.7f, 40.4f, 39.4f, 41.5f, 43.6f, + 46.9f, 49.4f, 51.9f, 100.1f, 105.1f, 110.1f, 54.4f, 56.9f, 59.4f, + 63.1f, 66.0f, 68.9f, 42.1f, 45.4f, 48.7f, 47.2f, 50.9f, 54.6f, + 104.3f, 111.7f, 119.1f, 58.3f, 62.0f, 65.7f, 64.6f, 68.7f, 72.8f, + 57.4f, 61.9f, 66.4f, 62.5f, 67.4f, 72.3f, 138.5f, 148.3f, 158.1f, + 77.2f, 82.1f, 87.0f, 83.5f, 88.8f, 94.1f, 134.6f, 143.6f, 152.6f, + 147.2f, 157.0f, 166.8f, 321.4f, 341.0f, 360.6f, 176.6f, 186.4f, 196.2f, + 191.6f, 202.2f, 212.8f, 84.4f, 88.9f, 93.4f, 91.9f, 96.8f, 101.7f, + 197.3f, 207.1f, 216.9f, 106.6f, 111.5f, 116.4f, 115.3f, 120.6f, 125.9f, + 106.9f, 112.6f, 118.3f, 114.4f, 120.5f, 126.6f, 245.9f, 258.1f, 270.3f, + 132.7f, 138.8f, 144.9f, 141.4f, 147.9f, 154.4f}); + + input.linspace(-10, 0.5); + weights.linspace(0.1, 0.1); + bias = 0.2; + + sd::ops::deconv3d op; + auto results = op.evaluate({&input, &weights}, {}, + {kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, + paddingMode, dataFormat}); + ASSERT_EQ(Status::OK(), results.status()); + + auto output = results.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests2, deconv3d_test6) { - - int bS=2, oD=4,oH=4,oW=4, oC=5,iC=10, kD=2,kH=2,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; - int iD=3,iH=3,iW=3; - int paddingMode = 0; // 1-SAME, 0-VALID; - int dataFormat = 1; // 1-NHWC, 0-NCHW - int wFormat = 1; // 0 - [kD, kH, kW, oC, iC], 1 - [iC, oC, kD, kH, kW], 2 - [iC, kD, kH, kW, oC] - - NDArray input('c', {bS, iD, iH, iW, iC}, sd::DataType::FLOAT32); - NDArray weights('c', {iC, oC, kD, kH, kW}, {20., 15., 10., 5., 0., -5., -10., -15., 19., 14., 9., 4., -1., -6., -11., -16., 18., 13., 8., 3., -2., -7., -12., -17., - 17., 12., 7., 2., -3., -8., -13., -18., 16., 11., 6., 1., -4., -9., -14., -19., 19.9, 14.9, 9.9, 4.9, -0.1, -5.1, -10.1, -15.1, 18.9, 13.9, 8.9, 3.9, -1.1, -6.1, - -11.1, -16.1, 17.9, 12.9, 7.9, 2.9, -2.1, -7.1, -12.1, -17.1, 16.9, 11.9, 6.9, 1.9, -3.1, -8.1, -13.1, -18.1, 15.9, 10.9, 5.9, 0.9, -4.1, -9.1, -14.1, -19.1, - 19.799999, 14.8, 9.8, 4.8, -0.2, -5.2, -10.2, -15.2, 18.799999, 13.8, 8.8, 3.8, -1.2, -6.2, -11.2, -16.200001, 17.799999, 12.8, 7.8, 2.8, -2.2, -7.2, -12.2, - -17.200001, 16.799999, 11.8, 6.8, 1.8, -3.2, -8.2, -13.2, -18.200001, 15.8, 10.8, 5.8, 0.8, -4.2, -9.2, -14.2, -19.200001, 19.700001, 14.7, 9.7, 4.7, -0.3, -5.3, -10.3, -15.3, 18.700001, 13.7, 8.7, 3.7, -1.3, -6.3, -11.3, -16.299999, 17.700001, 12.7, 7.7, 2.7, -2.3, -7.3, -12.3, -17.299999, 16.700001, 11.7, 6.7, 1.7, -3.3, -8.3, -13.3, -18.299999, 15.7, 10.7, 5.7, 0.7, -4.3, -9.3, -14.3, -19.299999, 19.6, 14.6, 9.6, 4.6, -0.4, -5.4, -10.4, -15.4, 18.6, 13.6, 8.6, 3.6, -1.4, -6.4, -11.4, -16.4, 17.6, 12.6, 7.6, 2.6, -2.4, -7.4, -12.4, -17.4, 16.6, 11.6, 6.6, 1.6, -3.4, -8.4, -13.4, -18.4, 15.6, 10.6, 5.6, 0.6, -4.4, -9.4, -14.4, -19.4, 19.5, 14.5, 9.5, 4.5, -0.5, -5.5, -10.5, -15.5, 18.5, 13.5, 8.5, 3.5, -1.5, -6.5, -11.5, -16.5, 17.5, 12.5, 7.5, 2.5, -2.5, -7.5, -12.5, -17.5, 16.5, 11.5, 6.5, 1.5, -3.5, -8.5, -13.5, -18.5, 15.5, 10.5, 5.5, 0.5, -4.5, -9.5, -14.5, -19.5, 19.4, 14.4, 9.4, 4.4, -0.6, -5.6, -10.6, -15.6, 18.4, 13.4, 8.4, 3.4, -1.6, -6.6, -11.6, -16.6, 17.4, 12.4, 7.4, 2.4, -2.6, -7.6, -12.6, -17.6, 16.4, 11.4, 6.4, 1.4, -3.6, -8.6, -13.6, -18.6, 15.4, 10.4, 5.4, 0.4, -4.6, -9.6, -14.6, -19.6, 19.299999, 14.3, 9.3, 4.3, -0.7, -5.7, -10.7, -15.7, 18.299999, 13.3, 8.3, 3.3, -1.7, -6.7, -11.7, -16.700001, 17.299999, 12.3, 7.3, 2.3, -2.7, -7.7, -12.7, -17.700001, 16.299999, 11.3, 6.3, 1.3, -3.7, -8.7, -13.7, -18.700001, 15.3, 10.3, 5.3, 0.3, -4.7, -9.7, -14.7, -19.700001, 19.200001, 14.2, 9.2, 4.2, -0.8, -5.8, -10.8, -15.8, 18.200001, 13.2, 8.2, 3.2, -1.8, -6.8, -11.8, -16.799999, 17.200001, 12.2, 7.2, 2.2, -2.8, -7.8, -12.8, -17.799999, 16.200001, 11.2, 6.2, 1.2, -3.8, -8.8, -13.8, -18.799999, 15.2, 10.2, 5.2, 0.2, -4.8, -9.8, -14.8, -19.799999, 19.1, 14.1, 9.1, 4.1, -0.9, -5.9, -10.9, -15.9, 18.1, 13.1, 8.1, 3.1, -1.9, -6.9, -11.9, -16.9, 17.1, 12.1, 7.1, 2.1, -2.9, -7.9, -12.9, -17.9, 16.1, 11.1, 6.1, 1.1, -3.9, -8.9, -13.9, -18.9, 15.1, 10.1, 5.1, 0.1, -4.9, -9.9, -14.9, -19.9}, sd::DataType::FLOAT32); - NDArray expOutput('c', {bS, oD, oH, oW, oC}, {-5191.349609, -4925.850098, -4660.350098, -4394.850098, -4129.349609, -8859.700195, -8338.700195, -7817.700195, - -7296.700195, -6775.700195, -8518.700195, -8017.700195, -7516.700195, -7015.700195, -6514.700195, -3572.850098, -3327.349854, -3081.850098, -2836.350098, - -2590.850098, -7141.200195, -6640.200195, -6139.199707, -5638.200195, -5137.200195, -11486.400391, -10504.400391, -9522.400391, -8540.400391, -7558.399902, - -11004.400391, -10062.400391, -9120.400391, -8178.399414, -7236.399414, -4254.200195, -3793.200195, -3332.200195, -2871.199951, -2410.200195, -6268.200195, - -5827.200195, -5386.200195, -4945.200195, -4504.200195, -10040.400391, -9178.400391, -8316.400391, -7454.400391, -6592.399902, -9558.400391, -8736.400391, - -7914.400391, -7092.399902, -6270.400391, -3681.199707, -3280.200195, -2879.200195, -2478.200195, -2077.200195, -1963.350098, -1757.850098, -1552.349854, -1346.849976, -1141.349976, -2803.700195, -2402.699951, -2001.699951, -1600.699951, -1199.699951, -2662.699951, -2281.699951, -1900.699951, -1519.699951, -1138.700073, -844.850037, -659.349976, -473.850006, -288.350006, -102.849998, -3313.200195, -2872.199951, -2431.200195, -1990.200195, -1549.199829, -4230.399902, -3368.400391, -2506.400391, -1644.400146, -782.400146, -3948.400146, -3126.400391, -2304.399902, -1482.400146, -660.400269, -926.200195, -525.199951, -124.199951, 276.799927, 677.799805, -1643.400269, -821.400146, 0.599609, 822.600098, 1644.599609, 1005.199951, 2609.199707, 4213.200195, 5817.200195, 7421.200684, 1169.199463, 2693.200195, 4217.199707, 5741.201172, 7265.203125, 2430.599609, 3172.600098, 3914.600098, 4656.599609, 5398.599609, -1097.400391, -395.400269, 306.599609, 1008.599854, 1710.599731, 1497.199219, 2861.199219, 4225.201172, 5589.200684, 6953.200684, 1661.199219, 2945.199463, 4229.199707, 5513.201172, 6797.200684, 2376.599609, 2998.599854, 3620.599609, 4242.600098, 4864.600098, 1042.799927, 1363.799927, 1684.800171, 2005.799805, 2326.799805, 3681.599609, 4303.599609, 4925.599609, 5547.600098, 6169.599609, 3563.599609, 4145.599609, 4727.600098, 5309.600098, 5891.599609, 2429.800293, 2710.800293, 2991.799805, 3272.799805, 3553.799805, -1594.199829, -1333.199951, -1072.200073, -811.200012, -550.200134, -1692.400024, -1190.399902, -688.400024, -186.400269, 315.600098, -1410.399902, -948.399902, -486.399902, -24.399780, 437.599731, -107.199890, 113.799988, 334.799988, 555.799988, 776.800049, -5.400024, 456.599731, 918.600281, 1380.599731, 1842.599976, 2481.199219, 3365.199219, 4249.199219, 5133.199219, 6017.199219, 2645.199219, 3449.199219, 4253.199707, 5057.199219, 5861.199707, 2268.600098, 2650.599609, 3032.600098, 3414.600098, 3796.599609, 540.599976, 882.600220, 1224.599854, 1566.599854, 1908.600220, 2973.200195, 3617.199707, 4261.199219, 4905.199219, 5549.199219, 3137.199707, 3701.199219, 4265.199707, 4829.199219, 5393.199219, 2214.599609, 2476.600098, 2738.599609, 3000.599854, 3262.599854, 961.800049, 1102.800049, 1243.799927, 1384.800171, 1525.799927, 2619.599609, 2881.599854, 3143.599854, 3405.599609, 3667.599609, 2501.599854, 2723.599609, 2945.599854, 3167.599609, 3389.600098, 1448.799927, 1549.800049, 1650.799927, 1751.800049, 1852.799927, 37.650002, 123.150009, 208.650009, 294.149994, 379.650024, 498.300018, 659.300049, 820.300049, 981.299927, 1142.299927, 439.300018, 580.299988, 721.299927, 862.300049, 1003.300049, 356.149963, 421.649994, 487.150024, 552.649963, 618.150024, 916.799988, 1057.800049, 1198.800171, 1339.800049, 1480.800171, 2429.600098, 2691.600098, 2953.599609, 3215.599609, 3477.599609, 2111.599854, 2333.599854, 2555.600098, 2777.599609, 2999.600098, 1203.800049, 1304.800049, 1405.799927, 1506.800049, 1607.800049, 589.799927, 670.800049, 751.800049, 832.800049, 913.800049, 1475.599976, 1617.600098, 1759.600098, 1901.600098, 2043.600098, 1157.600098, 1259.600098, 1361.600098, 1463.600098, 1565.599976, 576.799988, 617.800049, 658.799988, 699.799927, 740.800049, 265.649994, 291.149994, 316.650024, 342.150024, 367.649994, 554.300049, 595.299988, 636.299927, 677.299988, 718.299988, 295.300018, 316.300018, 337.299988, 358.299988, 379.300018, 84.149994, 89.650002, 95.150002, 100.650009, 106.150009, 87.150002, 82.650002, 78.150002, 73.650002, 69.150002, 347.299988, 328.300018, 309.300018, 290.299988, 271.299988, 688.300049, 649.299927, 610.299988, 571.300049, 532.300049, 355.650024, 331.149963, 306.649994, 282.149994, 257.649994, 715.800049, 676.800049, 637.799988, 598.800049, 559.800049, 1527.600098, 1429.599976, 1331.599976, 1233.600098, 1135.600098, 2009.600098, 1871.600098, 1733.599976, 1595.600098, 1457.600098, 902.799988, 823.799927, 744.800049, 665.800049, 586.800049, 1588.800049, 1489.800049, 1390.800049, 1291.800049, 1192.799927, 2973.600098, 2755.600098, 2537.600098, 2319.600098, 2101.600098, 3455.600098, 3197.600098, 2939.600098, 2681.600098, 2423.600098, 1475.800049, 1336.800049, 1197.800049, 1058.799927, 919.800049, 615.150024, 550.650024, 486.149994, 421.649994, 357.150024, 1003.300049, 864.300049, 725.299988, 586.300049, 447.300018, 1144.300049, 985.299988, 826.300049, 667.299988, 508.299988, 383.649994, 299.149994, 214.649994, 130.149994, 45.649998, 1843.799927, 1744.799927, 1645.800049, 1546.799927, 1447.800049, 3383.600098, 3165.600098, 2947.600098, 2729.599854, 2511.600098, 3665.599854, 3407.600098, 3149.599854, 2891.599854, 2633.599854, 1530.800171, 1391.800049, 1252.800049, 1113.800049, 974.800171, 3270.599609, 3012.599854, 2754.600098, 2496.599854, 2238.600098, 5433.199707, 4877.200195, 4321.200195, 3765.199707, 3209.199951, 5597.200195, 4961.199707, 4325.200195, 3689.199707, 3053.199951, 1944.600098, 1606.599854, 1268.600098, 930.599976, 592.600098, 3816.599854, 3438.600342, 3060.599854, 2682.600098, 2304.600098, 5925.200195, 5129.200684, 4333.200195, 3537.199951, 2741.199707, 6089.200684, 5213.200195, 4337.200195, 3461.199707, 2585.200195, 1890.599609, 1432.600220, 974.599976, 516.599976, 58.599976, 799.799927, 580.800171, 361.800110, 142.800110, -76.200073, 495.599976, 37.599976, -420.399902, -878.399902, -1336.400024, 377.599854, -120.399902, -618.399902, -1116.400391, -1614.399902, -513.199951, -772.200012, -1031.199951, -1290.199829, -1549.200073, 3562.800049, 3283.799805, 3004.799805, 2725.800293, 2446.800293, 5921.599609, 5343.599609, 4765.600098, 4187.599609, 3609.599854, 6203.599609, 5585.600098, 4967.600098, 4349.599609, 3731.600098, 2349.799805, 2030.800171, 1711.800293, 1392.800171, 1073.799927, 4908.600098, 4290.599609, 3672.600098, 3054.600098, 2436.600098, 6909.199219, 5633.200684, 4357.200195, 3081.199219, 1805.199463, 7073.200684, 5717.199707, 4361.199219, 3005.199463, 1649.199951, 1782.600464, 1084.599609, 386.599609, -311.400146, -1009.400635, 5454.600098, 4716.599609, 3978.599854, 3240.600098, 2502.600098, 7401.199219, 5885.199219, 4369.200195, 2853.200195, 1337.199219, 7565.199219, 5969.200195, 4373.200195, 2777.199219, 1181.199219, 1728.599854, 910.600098, 92.600098, -725.400391, -1543.400391, 718.799927, 319.800049, -79.200073, -478.200073, -877.200073, -566.400391, -1384.400391, -2202.400391, -3020.400391, -3838.400391, -684.400146, -1542.400391, -2400.400391, -3258.400391, -4116.400391, -1494.200073, -1933.200073, -2372.199707, -2811.200195, -3250.199951, -83.850006, -268.350006, -452.849945, -637.350037, -821.849976, -1094.699951, -1473.699951, -1852.700073, -2231.699707, -2610.699951, -1153.700073, -1552.699829, -1951.699829, -2350.700195, -2749.700195, -1115.350098, -1319.849854, -1524.350098, -1728.849976, -1933.350098, -2026.200073, -2425.200195, -2824.200195, -3223.199707, -3622.200195, -6156.400391, -6974.400391, -7792.400391, -8610.400391, -9428.399414, -6474.400391, -7332.400391, -8190.400391, -9048.399414, -9906.399414, -4439.200195, -4878.199707, -5317.200195, -5756.200195, -6195.200195, -2353.199951, -2812.200195, -3271.200195, -3730.200195, -4189.200195, -7110.400391, -8048.400391, -8986.399414, -9924.400391, -10862.400391, -7428.400391, -8406.399414, -9384.399414, -10362.400391, -11340.400391, -5066.200195, -5565.200195, -6064.200195, -6563.200195, -7062.200195, -2555.849854, -2800.349854, -3044.849854, -3289.350098, -3533.850098, -6438.700195, -6937.700195, -7436.700195, -7935.700195, -8434.699219, -6697.700195, -7216.700195, -7735.700195, -8254.699219, -8773.700195, -4087.349854, -4351.850098, -4616.349609, -4880.850098, -5145.350098}, sd::DataType::FLOAT32); - - input.linspace(-27, 0.1); - - sd::ops::deconv3d op; - auto results = op.evaluate({&input, &weights}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat, wFormat}); - ASSERT_EQ(Status::OK(), results.status()); - - auto output = results.at(0); - - ASSERT_TRUE(expOutput.isSameShape(output)); - ASSERT_TRUE(expOutput.equalsTo(output, 1e-3)); + int bS = 2, oD = 4, oH = 4, oW = 4, oC = 5, iC = 10, kD = 2, kH = 2, kW = 2, + sD = 1, sH = 1, sW = 1, pD = 0, pH = 0, pW = 0, dD = 1, dH = 1, dW = 1; + int iD = 3, iH = 3, iW = 3; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NHWC, 0-NCHW + int wFormat = 1; // 0 - [kD, kH, kW, oC, iC], 1 - [iC, oC, kD, kH, kW], 2 - + // [iC, kD, kH, kW, oC] + + NDArray input('c', {bS, iD, iH, iW, iC}, sd::DataType::FLOAT32); + NDArray weights('c', {iC, oC, kD, kH, kW}, + {20., 15., 10., 5., 0., -5., -10., -15., + 19., 14., 9., 4., -1., -6., -11., -16., + 18., 13., 8., 3., -2., -7., -12., -17., + 17., 12., 7., 2., -3., -8., -13., -18., + 16., 11., 6., 1., -4., -9., -14., -19., + 19.9, 14.9, 9.9, 4.9, -0.1, -5.1, -10.1, -15.1, + 18.9, 13.9, 8.9, 3.9, -1.1, -6.1, -11.1, -16.1, + 17.9, 12.9, 7.9, 2.9, -2.1, -7.1, -12.1, -17.1, + 16.9, 11.9, 6.9, 1.9, -3.1, -8.1, -13.1, -18.1, + 15.9, 10.9, 5.9, 0.9, -4.1, -9.1, -14.1, -19.1, + 19.799999, 14.8, 9.8, 4.8, -0.2, -5.2, -10.2, -15.2, + 18.799999, 13.8, 8.8, 3.8, -1.2, -6.2, -11.2, -16.200001, + 17.799999, 12.8, 7.8, 2.8, -2.2, -7.2, -12.2, -17.200001, + 16.799999, 11.8, 6.8, 1.8, -3.2, -8.2, -13.2, -18.200001, + 15.8, 10.8, 5.8, 0.8, -4.2, -9.2, -14.2, -19.200001, + 19.700001, 14.7, 9.7, 4.7, -0.3, -5.3, -10.3, -15.3, + 18.700001, 13.7, 8.7, 3.7, -1.3, -6.3, -11.3, -16.299999, + 17.700001, 12.7, 7.7, 2.7, -2.3, -7.3, -12.3, -17.299999, + 16.700001, 11.7, 6.7, 1.7, -3.3, -8.3, -13.3, -18.299999, + 15.7, 10.7, 5.7, 0.7, -4.3, -9.3, -14.3, -19.299999, + 19.6, 14.6, 9.6, 4.6, -0.4, -5.4, -10.4, -15.4, + 18.6, 13.6, 8.6, 3.6, -1.4, -6.4, -11.4, -16.4, + 17.6, 12.6, 7.6, 2.6, -2.4, -7.4, -12.4, -17.4, + 16.6, 11.6, 6.6, 1.6, -3.4, -8.4, -13.4, -18.4, + 15.6, 10.6, 5.6, 0.6, -4.4, -9.4, -14.4, -19.4, + 19.5, 14.5, 9.5, 4.5, -0.5, -5.5, -10.5, -15.5, + 18.5, 13.5, 8.5, 3.5, -1.5, -6.5, -11.5, -16.5, + 17.5, 12.5, 7.5, 2.5, -2.5, -7.5, -12.5, -17.5, + 16.5, 11.5, 6.5, 1.5, -3.5, -8.5, -13.5, -18.5, + 15.5, 10.5, 5.5, 0.5, -4.5, -9.5, -14.5, -19.5, + 19.4, 14.4, 9.4, 4.4, -0.6, -5.6, -10.6, -15.6, + 18.4, 13.4, 8.4, 3.4, -1.6, -6.6, -11.6, -16.6, + 17.4, 12.4, 7.4, 2.4, -2.6, -7.6, -12.6, -17.6, + 16.4, 11.4, 6.4, 1.4, -3.6, -8.6, -13.6, -18.6, + 15.4, 10.4, 5.4, 0.4, -4.6, -9.6, -14.6, -19.6, + 19.299999, 14.3, 9.3, 4.3, -0.7, -5.7, -10.7, -15.7, + 18.299999, 13.3, 8.3, 3.3, -1.7, -6.7, -11.7, -16.700001, + 17.299999, 12.3, 7.3, 2.3, -2.7, -7.7, -12.7, -17.700001, + 16.299999, 11.3, 6.3, 1.3, -3.7, -8.7, -13.7, -18.700001, + 15.3, 10.3, 5.3, 0.3, -4.7, -9.7, -14.7, -19.700001, + 19.200001, 14.2, 9.2, 4.2, -0.8, -5.8, -10.8, -15.8, + 18.200001, 13.2, 8.2, 3.2, -1.8, -6.8, -11.8, -16.799999, + 17.200001, 12.2, 7.2, 2.2, -2.8, -7.8, -12.8, -17.799999, + 16.200001, 11.2, 6.2, 1.2, -3.8, -8.8, -13.8, -18.799999, + 15.2, 10.2, 5.2, 0.2, -4.8, -9.8, -14.8, -19.799999, + 19.1, 14.1, 9.1, 4.1, -0.9, -5.9, -10.9, -15.9, + 18.1, 13.1, 8.1, 3.1, -1.9, -6.9, -11.9, -16.9, + 17.1, 12.1, 7.1, 2.1, -2.9, -7.9, -12.9, -17.9, + 16.1, 11.1, 6.1, 1.1, -3.9, -8.9, -13.9, -18.9, + 15.1, 10.1, 5.1, 0.1, -4.9, -9.9, -14.9, -19.9}, + sd::DataType::FLOAT32); + NDArray expOutput( + 'c', {bS, oD, oH, oW, oC}, + {-5191.349609, -4925.850098, -4660.350098, -4394.850098, -4129.349609, + -8859.700195, -8338.700195, -7817.700195, -7296.700195, -6775.700195, + -8518.700195, -8017.700195, -7516.700195, -7015.700195, -6514.700195, + -3572.850098, -3327.349854, -3081.850098, -2836.350098, -2590.850098, + -7141.200195, -6640.200195, -6139.199707, -5638.200195, -5137.200195, + -11486.400391, -10504.400391, -9522.400391, -8540.400391, -7558.399902, + -11004.400391, -10062.400391, -9120.400391, -8178.399414, -7236.399414, + -4254.200195, -3793.200195, -3332.200195, -2871.199951, -2410.200195, + -6268.200195, -5827.200195, -5386.200195, -4945.200195, -4504.200195, + -10040.400391, -9178.400391, -8316.400391, -7454.400391, -6592.399902, + -9558.400391, -8736.400391, -7914.400391, -7092.399902, -6270.400391, + -3681.199707, -3280.200195, -2879.200195, -2478.200195, -2077.200195, + -1963.350098, -1757.850098, -1552.349854, -1346.849976, -1141.349976, + -2803.700195, -2402.699951, -2001.699951, -1600.699951, -1199.699951, + -2662.699951, -2281.699951, -1900.699951, -1519.699951, -1138.700073, + -844.850037, -659.349976, -473.850006, -288.350006, -102.849998, + -3313.200195, -2872.199951, -2431.200195, -1990.200195, -1549.199829, + -4230.399902, -3368.400391, -2506.400391, -1644.400146, -782.400146, + -3948.400146, -3126.400391, -2304.399902, -1482.400146, -660.400269, + -926.200195, -525.199951, -124.199951, 276.799927, 677.799805, + -1643.400269, -821.400146, 0.599609, 822.600098, 1644.599609, + 1005.199951, 2609.199707, 4213.200195, 5817.200195, 7421.200684, + 1169.199463, 2693.200195, 4217.199707, 5741.201172, 7265.203125, + 2430.599609, 3172.600098, 3914.600098, 4656.599609, 5398.599609, + -1097.400391, -395.400269, 306.599609, 1008.599854, 1710.599731, + 1497.199219, 2861.199219, 4225.201172, 5589.200684, 6953.200684, + 1661.199219, 2945.199463, 4229.199707, 5513.201172, 6797.200684, + 2376.599609, 2998.599854, 3620.599609, 4242.600098, 4864.600098, + 1042.799927, 1363.799927, 1684.800171, 2005.799805, 2326.799805, + 3681.599609, 4303.599609, 4925.599609, 5547.600098, 6169.599609, + 3563.599609, 4145.599609, 4727.600098, 5309.600098, 5891.599609, + 2429.800293, 2710.800293, 2991.799805, 3272.799805, 3553.799805, + -1594.199829, -1333.199951, -1072.200073, -811.200012, -550.200134, + -1692.400024, -1190.399902, -688.400024, -186.400269, 315.600098, + -1410.399902, -948.399902, -486.399902, -24.399780, 437.599731, + -107.199890, 113.799988, 334.799988, 555.799988, 776.800049, + -5.400024, 456.599731, 918.600281, 1380.599731, 1842.599976, + 2481.199219, 3365.199219, 4249.199219, 5133.199219, 6017.199219, + 2645.199219, 3449.199219, 4253.199707, 5057.199219, 5861.199707, + 2268.600098, 2650.599609, 3032.600098, 3414.600098, 3796.599609, + 540.599976, 882.600220, 1224.599854, 1566.599854, 1908.600220, + 2973.200195, 3617.199707, 4261.199219, 4905.199219, 5549.199219, + 3137.199707, 3701.199219, 4265.199707, 4829.199219, 5393.199219, + 2214.599609, 2476.600098, 2738.599609, 3000.599854, 3262.599854, + 961.800049, 1102.800049, 1243.799927, 1384.800171, 1525.799927, + 2619.599609, 2881.599854, 3143.599854, 3405.599609, 3667.599609, + 2501.599854, 2723.599609, 2945.599854, 3167.599609, 3389.600098, + 1448.799927, 1549.800049, 1650.799927, 1751.800049, 1852.799927, + 37.650002, 123.150009, 208.650009, 294.149994, 379.650024, + 498.300018, 659.300049, 820.300049, 981.299927, 1142.299927, + 439.300018, 580.299988, 721.299927, 862.300049, 1003.300049, + 356.149963, 421.649994, 487.150024, 552.649963, 618.150024, + 916.799988, 1057.800049, 1198.800171, 1339.800049, 1480.800171, + 2429.600098, 2691.600098, 2953.599609, 3215.599609, 3477.599609, + 2111.599854, 2333.599854, 2555.600098, 2777.599609, 2999.600098, + 1203.800049, 1304.800049, 1405.799927, 1506.800049, 1607.800049, + 589.799927, 670.800049, 751.800049, 832.800049, 913.800049, + 1475.599976, 1617.600098, 1759.600098, 1901.600098, 2043.600098, + 1157.600098, 1259.600098, 1361.600098, 1463.600098, 1565.599976, + 576.799988, 617.800049, 658.799988, 699.799927, 740.800049, + 265.649994, 291.149994, 316.650024, 342.150024, 367.649994, + 554.300049, 595.299988, 636.299927, 677.299988, 718.299988, + 295.300018, 316.300018, 337.299988, 358.299988, 379.300018, + 84.149994, 89.650002, 95.150002, 100.650009, 106.150009, + 87.150002, 82.650002, 78.150002, 73.650002, 69.150002, + 347.299988, 328.300018, 309.300018, 290.299988, 271.299988, + 688.300049, 649.299927, 610.299988, 571.300049, 532.300049, + 355.650024, 331.149963, 306.649994, 282.149994, 257.649994, + 715.800049, 676.800049, 637.799988, 598.800049, 559.800049, + 1527.600098, 1429.599976, 1331.599976, 1233.600098, 1135.600098, + 2009.600098, 1871.600098, 1733.599976, 1595.600098, 1457.600098, + 902.799988, 823.799927, 744.800049, 665.800049, 586.800049, + 1588.800049, 1489.800049, 1390.800049, 1291.800049, 1192.799927, + 2973.600098, 2755.600098, 2537.600098, 2319.600098, 2101.600098, + 3455.600098, 3197.600098, 2939.600098, 2681.600098, 2423.600098, + 1475.800049, 1336.800049, 1197.800049, 1058.799927, 919.800049, + 615.150024, 550.650024, 486.149994, 421.649994, 357.150024, + 1003.300049, 864.300049, 725.299988, 586.300049, 447.300018, + 1144.300049, 985.299988, 826.300049, 667.299988, 508.299988, + 383.649994, 299.149994, 214.649994, 130.149994, 45.649998, + 1843.799927, 1744.799927, 1645.800049, 1546.799927, 1447.800049, + 3383.600098, 3165.600098, 2947.600098, 2729.599854, 2511.600098, + 3665.599854, 3407.600098, 3149.599854, 2891.599854, 2633.599854, + 1530.800171, 1391.800049, 1252.800049, 1113.800049, 974.800171, + 3270.599609, 3012.599854, 2754.600098, 2496.599854, 2238.600098, + 5433.199707, 4877.200195, 4321.200195, 3765.199707, 3209.199951, + 5597.200195, 4961.199707, 4325.200195, 3689.199707, 3053.199951, + 1944.600098, 1606.599854, 1268.600098, 930.599976, 592.600098, + 3816.599854, 3438.600342, 3060.599854, 2682.600098, 2304.600098, + 5925.200195, 5129.200684, 4333.200195, 3537.199951, 2741.199707, + 6089.200684, 5213.200195, 4337.200195, 3461.199707, 2585.200195, + 1890.599609, 1432.600220, 974.599976, 516.599976, 58.599976, + 799.799927, 580.800171, 361.800110, 142.800110, -76.200073, + 495.599976, 37.599976, -420.399902, -878.399902, -1336.400024, + 377.599854, -120.399902, -618.399902, -1116.400391, -1614.399902, + -513.199951, -772.200012, -1031.199951, -1290.199829, -1549.200073, + 3562.800049, 3283.799805, 3004.799805, 2725.800293, 2446.800293, + 5921.599609, 5343.599609, 4765.600098, 4187.599609, 3609.599854, + 6203.599609, 5585.600098, 4967.600098, 4349.599609, 3731.600098, + 2349.799805, 2030.800171, 1711.800293, 1392.800171, 1073.799927, + 4908.600098, 4290.599609, 3672.600098, 3054.600098, 2436.600098, + 6909.199219, 5633.200684, 4357.200195, 3081.199219, 1805.199463, + 7073.200684, 5717.199707, 4361.199219, 3005.199463, 1649.199951, + 1782.600464, 1084.599609, 386.599609, -311.400146, -1009.400635, + 5454.600098, 4716.599609, 3978.599854, 3240.600098, 2502.600098, + 7401.199219, 5885.199219, 4369.200195, 2853.200195, 1337.199219, + 7565.199219, 5969.200195, 4373.200195, 2777.199219, 1181.199219, + 1728.599854, 910.600098, 92.600098, -725.400391, -1543.400391, + 718.799927, 319.800049, -79.200073, -478.200073, -877.200073, + -566.400391, -1384.400391, -2202.400391, -3020.400391, -3838.400391, + -684.400146, -1542.400391, -2400.400391, -3258.400391, -4116.400391, + -1494.200073, -1933.200073, -2372.199707, -2811.200195, -3250.199951, + -83.850006, -268.350006, -452.849945, -637.350037, -821.849976, + -1094.699951, -1473.699951, -1852.700073, -2231.699707, -2610.699951, + -1153.700073, -1552.699829, -1951.699829, -2350.700195, -2749.700195, + -1115.350098, -1319.849854, -1524.350098, -1728.849976, -1933.350098, + -2026.200073, -2425.200195, -2824.200195, -3223.199707, -3622.200195, + -6156.400391, -6974.400391, -7792.400391, -8610.400391, -9428.399414, + -6474.400391, -7332.400391, -8190.400391, -9048.399414, -9906.399414, + -4439.200195, -4878.199707, -5317.200195, -5756.200195, -6195.200195, + -2353.199951, -2812.200195, -3271.200195, -3730.200195, -4189.200195, + -7110.400391, -8048.400391, -8986.399414, -9924.400391, -10862.400391, + -7428.400391, -8406.399414, -9384.399414, -10362.400391, -11340.400391, + -5066.200195, -5565.200195, -6064.200195, -6563.200195, -7062.200195, + -2555.849854, -2800.349854, -3044.849854, -3289.350098, -3533.850098, + -6438.700195, -6937.700195, -7436.700195, -7935.700195, -8434.699219, + -6697.700195, -7216.700195, -7735.700195, -8254.699219, -8773.700195, + -4087.349854, -4351.850098, -4616.349609, -4880.850098, -5145.350098}, + sd::DataType::FLOAT32); + + input.linspace(-27, 0.1); + + sd::ops::deconv3d op; + auto results = op.evaluate({&input, &weights}, + {kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, + paddingMode, dataFormat, wFormat}); + ASSERT_EQ(Status::OK(), results.status()); + + auto output = results.at(0); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output, 1e-3)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests2, deconv3d_test7) { - - int bS=2, oD=4,oH=4,oW=4, iC=5,oC=10, kD=2,kH=2,kW=2, sD=1,sH=1,sW=1, pD=1,pH=0,pW=0, dD=1,dH=1,dW=1; - int iD=4,iH=4,iW=4; - int paddingMode = 1; // 1-SAME, 0-VALID; - int dataFormat = 0; // 1-NHWC, 0-NCHW - int wFormat = 2; // 0 - [kD, kH, kW, oC, iC], 1 - [iC, oC, kD, kH, kW], 2 - [iC, kD, kH, kW, oC] - - NDArray input('c', {bS, iC, iD, iH, iW}, sd::DataType::FLOAT32); - NDArray weights('c', {iC, kD, kH, kW, oC}, {20., 19.5, 19., 18.5, 18., 17.5, 17., 16.5, 16., 15.5, 15., 14.5, 14., 13.5, 13., 12.5, 12., 11.5, 11., 10.5, 10., - 9.5, 9., 8.5, 8., 7.5, 7., 6.5, 6., 5.5, 5., 4.5, 4., 3.5, 3., 2.5, 2., 1.5, 1., 0.5, 0., -0.5, -1., -1.5, -2., -2.5, -3., -3.5, -4., -4.5, -5., -5.5, -6., - -6.5, -7., -7.5, -8., -8.5, -9., -9.5, -10., -10.5, -11., -11.5, -12., -12.5, -13., -13.5, -14., -14.5, -15., -15.5, -16., -16.5, -17., -17.5, -18., -18.5, - -19., -19.5, 19.9, 19.4, 18.9, 18.4, 17.9, 17.4, 16.9, 16.4, 15.9, 15.4, 14.9, 14.4, 13.9, 13.4, 12.9, 12.4, 11.9, 11.4, 10.9, 10.4, 9.9, 9.4, 8.9, 8.4, 7.9, - 7.4, 6.9, 6.4, 5.9, 5.4, 4.9, 4.4, 3.9, 3.4, 2.9, 2.4, 1.9, 1.4, 0.9, 0.4, -0.1, -0.6, -1.1, -1.6, -2.1, -2.6, -3.1, -3.6, -4.1, -4.6, -5.1, -5.6, -6.1, -6.6, -7.1, -7.6, -8.1, -8.6, -9.1, -9.6, -10.1, -10.6, -11.1, -11.6, -12.1, -12.6, -13.1, -13.6, -14.1, -14.6, -15.1, -15.6, -16.1, -16.6, -17.1, -17.6, -18.1, -18.6, -19.1, -19.6, 19.799999, 19.299999, 18.799999, 18.299999, 17.799999, 17.299999, 16.799999, 16.299999, 15.8, 15.3, 14.8, 14.3, 13.8, 13.3, 12.8, 12.3, 11.8, 11.3, 10.8, 10.3, 9.8, 9.3, 8.8, 8.3, 7.8, 7.3, 6.8, 6.3, 5.8, 5.3, 4.8, 4.3, 3.8, 3.3, 2.8, 2.3, 1.8, 1.3, 0.8, 0.3, -0.2, -0.7, -1.2, -1.7, -2.2, -2.7, -3.2, -3.7, -4.2, -4.7, -5.2, -5.7, -6.2, -6.7, -7.2, -7.7, -8.2, -8.7, -9.2, -9.7, -10.2, -10.7, -11.2, -11.7, -12.2, -12.7, -13.2, -13.7, -14.2, -14.7, -15.2, -15.7, -16.200001, -16.700001, -17.200001, -17.700001, -18.200001, -18.700001, -19.200001, -19.700001, 19.700001, 19.200001, 18.700001, 18.200001, 17.700001, 17.200001, 16.700001, 16.200001, 15.7, 15.2, 14.7, 14.2, 13.7, 13.2, 12.7, 12.2, 11.7, 11.2, 10.7, 10.2, 9.7, 9.2, 8.7, 8.2, 7.7, 7.2, 6.7, 6.2, 5.7, 5.2, 4.7, 4.2, 3.7, 3.2, 2.7, 2.2, 1.7, 1.2, 0.7, 0.2, -0.3, -0.8, -1.3, -1.8, -2.3, -2.8, -3.3, -3.8, -4.3, -4.8, -5.3, -5.8, -6.3, -6.8, -7.3, -7.8, -8.3, -8.8, -9.3, -9.8, -10.3, -10.8, -11.3, -11.8, -12.3, -12.8, -13.3, -13.8, -14.3, -14.8, -15.3, -15.8, -16.299999, -16.799999, -17.299999, -17.799999, -18.299999, -18.799999, -19.299999, -19.799999, 19.6, 19.1, 18.6, 18.1, 17.6, 17.1, 16.6, 16.1, 15.6, 15.1, 14.6, 14.1, 13.6, 13.1, 12.6, 12.1, 11.6, 11.1, 10.6, 10.1, 9.6, 9.1, 8.6, 8.1, 7.6, 7.1, 6.6, 6.1, 5.6, 5.1, 4.6, 4.1, 3.6, 3.1, 2.6, 2.1, 1.6, 1.1, 0.6, 0.1, -0.4, -0.9, -1.4, -1.9, -2.4, -2.9, -3.4, -3.9, -4.4, -4.9, -5.4, -5.9, -6.4, -6.9, -7.4, -7.9, -8.4, -8.9, -9.4, -9.9, -10.4, -10.9, -11.4, -11.9, -12.4, -12.9, -13.4, -13.9, -14.4, -14.9, -15.4, -15.9, -16.4, -16.9, -17.4, -17.9, -18.4, -18.9, -19.4, -19.9}, sd::DataType::FLOAT32); - NDArray expOutput('c', {bS, oC, oD, oH, oW}, {-1907.199951, -3324.499756, -3307.199707, -3289.899902, -2814.799805, -4664.800293, -4640.199707, -4615.600098, - -2755.599854, -4566.400391, -4541.800293, -4517.199707, -2696.400146, -4468., -4443.400391, -4418.799805, -1735.999878, -2542.199951, -2527.600098, -2513., - -1592.800049, -1355.999756, -1346.799805, -1337.599854, -1554.400024, -1319.199829, -1310.000122, -1300.800049, -1516., -1282.400024, -1273.200195, -1263.999878, - -1579.200073, -2308.599854, -2294., -2279.400146, -1439.199951, -1208.799683, -1199.599976, -1190.399902, -1400.800049, -1172., -1162.800049, -1153.600098, - -1362.399902, -1135.199951, -1126., -1116.799805, -1422.400024, -2075., -2060.399902, -2045.799683, -1285.599976, -1061.599854, -1052.399902, -1043.200195, - -1247.199951, -1024.800049, -1015.599976, -1006.400146, -1208.799927, -988.000122, -978.799683, -969.599976, -1859.199951, -3228.75, -3211.949951, -3195.150146, -2719.800049, -4475.299805, -4451.699707, -4428.100098, -2662.600098, -4380.899902, -4357.300293, -4333.699707, -2605.399902, -4286.5, -4262.899902, -4239.300293, -1643.999878, -2358.700195, -2345.099854, -2331.5, -1410.800049, -992.999756, -985.799438, -978.600098, -1376.400024, -964.199707, -957., -949.800049, -1342., -935.399902, -928.199951, -921.000122, -1495.200073, -2141.099854, -2127.5, -2113.900391, -1273.199951, -877.799683, -870.599976, -863.39978, -1238.800049, -849., -841.800171, -834.599976, -1204.400024, -820.199707, -813., -805.799438, -1346.400146, -1923.500122, -1909.899902, -1896.299927, -1135.599976, -762.599976, -755.399658, -748.200195, -1101.199951, -733.800049, -726.599854, -719.400024, -1066.800049, -705., -697.800171, -690.599976, -1811.199951, -3133., -3116.699951, -3100.399902, -2624.799805, -4285.799805, -4263.199707, -4240.600098, -2569.600098, -4195.399902, -4172.800293, -4150.199707, -2514.399902, -4105., -4082.400146, -4059.800293, -1552., -2175.200195, -2162.599854, -2150., -1228.800049, -630., -624.799561, -619.599854, -1198.400024, -609.199463, -603.999756, -598.800049, -1167.999878, -588.400391, -583.199951, -578., -1411.200073, -1973.599854, -1961.000122, -1948.400146, -1107.199829, -546.800171, -541.599976, -536.400269, -1076.800049, -525.999756, -520.800049, -515.599976, -1046.400146, -505.199829, -500., -494.799683, -1270.399902, -1772., -1759.400146, -1746.799927, -985.599976, -463.600098, -458.399902, -453.199951, -955.199951, -442.799927, -437.599976, -432.400269, -924.799988, -422.000122, -416.800171, -411.599976, -1763.199951, -3037.25, -3021.449951, -3005.649902, -2529.800293, -4096.299805, -4074.699951, -4053.100098, -2476.600098, -4009.900146, -3988.300049, -3966.699951, -2423.399902, -3923.5, -3901.899902, -3880.299805, -1459.999878, -1991.699951, -1980.099854, -1968.500122, -1046.800049, -266.999878, -263.799805, -260.599854, -1020.400146, -254.199829, -251., -247.799927, -994., -241.400269, -238.200073, -234.999878, -1327.200073, -1806.099854, -1794.500122, -1782.900146, -941.199951, -215.799927, -212.600098, -209.399902, -914.799988, -203.000122, -199.799683, -196.599976, -888.400024, -190.200317, -186.999878, -183.799805, -1194.399902, -1620.500122, -1608.899902, -1597.299927, -835.599915, -164.599976, -161.400269, -158.200195, -809.200073, -151.799927, -148.599976, -145.400024, -782.799927, -139., -135.799805, -132.599976, -1715.200073, -2941.5, -2926.199951, -2910.899902, -2434.800049, -3906.799805, -3886.199951, -3865.599609, -2383.600098, -3824.400391, -3803.800049, -3783.199951, -2332.400146, -3742., -3721.400146, -3700.799805, -1367.999878, -1808.199707, -1797.599854, -1786.999878, -864.800049, 95.999878, 97.200073, 98.400024, -842.39978, 100.799927, 102.000244, 103.200439, -820., 105.599609, 106.800171, 108., -1243.199951, -1638.599854, -1628.000122, -1617.400146, -775.199829, 115.200195, 116.400146, 117.60022, -752.799805, 120., 121.200073, 122.400024, -730.399841, 124.799927, 125.999878, 127.199951, -1118.400024, -1468.999878, -1458.400146, -1447.799927, -685.599915, 134.400146, 135.60022, 136.800171, -663.199951, 139.200073, 140.399902, 141.599731, -640.799988, 144., 145.200195, 146.400146, -1667.199951, -2845.749756, -2830.949707, -2816.149902, -2339.799805, -3717.300049, -3697.699951, -3678.100098, -2290.600098, -3638.900146, -3619.300049, -3599.699951, -2241.399902, -3560.5, -3540.899902, -3521.299805, -1276., -1624.699951, -1615.100098, -1605.499878, -682.799927, 459.000122, 458.199951, 457.400146, -664.400024, 455.800049, 454.999878, 454.200439, -646.000122, 452.599976, 451.799805, 451.000122, -1159.200073, -1471.099854, -1461.5, -1451.900146, -609.199829, 446.200195, 445.400024, 444.600098, -590.799927, 443., 442.200073, 441.399658, -572.39978, 439.799927, 439.000122, 438.200073, -1042.399902, -1317.499756, -1307.900146, -1298.299683, -535.599976, 433.399963, 432.600098, 431.799744, -517.200012, 430.200195, 429.400024, 428.599976, -498.799927, 427.000061, 426.200256, 425.400024, -1619.199951, -2750., -2735.699951, -2721.399902, -2244.799805, -3527.799805, -3509.199951, -3490.600098, -2197.600098, -3453.400146, -3434.800049, -3416.199951, -2150.399902, -3379., -3360.400146, -3341.800049, -1184., -1441.199951, -1432.599854, -1424., -500.799927, 822.000122, 819.200195, 816.400146, -486.400024, 810.799927, 808.000244, 805.200073, -472., 799.60022, 796.799683, 794.000122, -1075.199951, -1303.599854, -1295.000122, -1286.400024, -443.199951, 777.200073, 774.400024, 771.599854, -428.799927, 766., 763.200317, 760.400024, -414.400146, 754.800049, 752.000244, 749.200195, -966.400146, -1166.000122, -1157.400146, -1148.799927, -385.600098, 732.400024, 729.599976, 726.799927, -371.200134, 721.200012, 718.400146, 715.599792, -356.799988, 710.000183, 707.199951, 704.400024, -1571.199951, -2654.25, -2640.449951, -2626.649902, -2149.800049, -3338.299805, -3320.699951, -3303.100098, -2104.600098, -3267.900146, -3250.299805, -3232.699951, -2059.399902, -3197.5, -3179.900146, -3162.300049, -1092., -1257.699951, -1250.099854, -1242.499878, -318.799927, 1185.000122, 1180.200439, 1175.400146, -308.399902, 1165.800293, 1161.000122, 1156.200073, -298., 1146.599731, 1141.800049, 1137.000122, -991.199951, -1136.099976, -1128.500122, -1120.899902, -277.199951, 1108.199829, 1103.400146, 1098.599976, -266.799927, 1089.000366, 1084.199951, 1079.400024, -256.399902, 1069.799927, 1065.000122, 1060.200317, -890.400024, -1014.5, -1006.900024, -999.299988, -235.599976, 1031.399902, 1026.599854, 1021.800049, -225.199951, 1012.200195, 1007.400024, 1002.599854, -214.799805, 992.999878, 988.199707, 983.400146, -1523.199951, -2558.5, -2545.199951, -2531.899902, -2054.800049, -3148.800049, -3132.199951, -3115.599854, -2011.599976, -3082.400146, -3065.800049, -3049.199951, -1968.400024, -3016., -2999.400146, -2982.799805, -1000.000061, -1074.199951, -1067.599976, -1061.000244, -136.799805, 1548.000244, 1541.200195, 1534.400269, -130.400146, 1520.800171, 1514.000122, 1507.200073, -124., 1493.600098, 1486.799805, 1480.000244, -907.200073, -968.599976, -962.000122, -955.400085, -111.199951, 1439.200073, 1432.399902, 1425.599854, -104.800049, 1412.000122, 1405.200195, 1398.400024, -98.400024, 1384.799927, 1378.000366, 1371.200195, -814.400024, -862.999939, -856.399902, -849.799927, -85.599976, 1330.400024, 1323.599854, 1316.799927, -79.200073, 1303.200073, 1296.399902, 1289.599731, -72.799927, 1276., 1269.200195, 1262.400024, -1475.200073, -2462.75, -2449.949951, -2437.149902, -1959.800049, -2959.299805, -2943.699951, -2928.099854, -1918.599976, -2896.900146, -2881.300049, -2865.699951, -1877.399902, -2834.5, -2818.900146, -2803.300049, -907.999939, -890.700012, -885.099915, -879.499878, 45.199829, 1911., 1902.200073, 1893.400024, 47.599976, 1875.800293, 1867.000244, 1858.200073, 49.999878, 1840.599976, 1831.800171, 1823.000244, -823.200073, -801.100098, -795.500061, -789.900024, 54.799927, 1770.199951, 1761.400269, 1752.599976, 57.200073, 1735., 1726.200073, 1717.400269, 59.599976, 1699.799805, 1691., 1682.200073, -738.400024, -711.499817, -705.900085, -700.299927, 64.400146, 1629.399902, 1620.599976, 1611.800171, 66.800049, 1594.200195, 1585.39978, 1576.599976, 69.200073, 1559.000122, 1550.199829, 1541.400146, 1260.800049, 2211.5, 2228.800049, 2246.100098, 1921.200073, 3207.200195, 3231.800049, 3256.399902, 1980.400024, 3305.599854, 3330.200195, 3354.800049, 2039.599854, 3404., 3428.599854, 3453.200195, 1400., 2129.800049, 2144.400146, 2159., 1479.199951, 1588.000244, 1597.200073, 1606.400024, 1517.599976, 1624.800171, 1634., 1643.199951, 1556., 1661.600098, 1670.800171, 1679.999878, 1556.799927, 2363.400146, 2378., 2392.600098, 1632.799805, 1735.199951, 1744.400146, 1753.600098, 1671.199829, 1771.999878, 1781.200073, 1790.400024, 1709.60022, 1808.800171, 1818.000244, 1827.200073, 1713.599976, 2597., 2611.599854, 2626.199951, 1786.400024, 1882.400024, 1891.600098, 1900.800171, 1824.799805, 1919.200195, 1928.400146, 1937.600098, 1863.199951, 1956., 1965.199951, 1974.400391, 1228.800049, 2147.25, 2164.049805, 2180.850098, 1856.199951, 3076.700195, 3100.300049, 3123.899902, 1913.400024, 3171.099854, 3194.700195, 3218.300049, 1970.599976, 3265.5, 3289.099854, 3312.699951, 1332., 1993.300049, 2006.900146, 2020.499878, 1341.199951, 1310.999878, 1318.199951, 1325.400146, 1375.60022, 1339.800171, 1347., 1354.199951, 1410., 1368.600098, 1375.800171, 1383., 1480.800049, 2210.900146, 2224.5, 2238.100098, 1478.799805, 1426.200073, 1433.400146, 1440.599609, 1513.199951, 1455., 1462.199951, 1469.400024, 1547.60022, 1483.799927, 1490.999878, 1498.199951, 1629.599976, 2428.500244, 2442.100098, 2455.699951, 1616.399902, 1541.400146, 1548.600098, 1555.799683, 1650.800049, 1570.200073, 1577.400024, 1584.600098, 1685.199951, 1598.99939, 1606.200317, 1613.400024, 1196.800049, 2083., 2099.300049, 2115.600098, 1791.200073, 2946.200195, 2968.800049, 2991.400146, 1846.400024, 3036.599854, 3059.200195, 3081.800049, 1901.599976, 3127., 3149.599854, 3172.200195, 1264., 1856.800049, 1869.400146, 1881.999878, 1203.200073, 1034., 1039.200073, 1044.400146, 1233.599976, 1054.799927, 1059.999878, 1065.199951, 1263.999878, 1075.599609, 1080.800171, 1086., 1404.799927, 2058.400146, 2071., 2083.599854, 1324.799927, 1117.199951, 1122.400146, 1127.599609, 1355.199951, 1138., 1143.200439, 1148.400146, 1385.599976, 1158.800171, 1164.000244, 1169.200073, 1545.599976, 2260., 2272.600098, 2285.199951, 1446.400024, 1200.400146, 1205.600098, 1210.800171, 1476.799805, 1221.199951, 1226.400024, 1231.600098, 1507.199951, 1242.000244, 1247.200073, 1252.400146, 1164.800049, 2018.75, 2034.549927, 2050.350098, 1726.200073, 2815.700195, 2837.300049, 2858.900146, 1779.400024, 2902.099854, 2923.700195, 2945.300049, 1832.599976, 2988.5, 3010.099854, 3031.700195, 1196.000122, 1720.300049, 1731.900146, 1743.499878, 1065.200073, 757.000122, 760.200073, 763.400024, 1091.599976, 769.800171, 773., 776.199951, 1118., 782.599976, 785.800049, 789., 1328.800049, 1905.900146, 1917.499878, 1929.100098, 1170.799805, 808.200073, 811.400024, 814.60022, 1197.199951, 821., 824.199951, 827.400024, 1223.599976, 833.799927, 837.000244, 840.199951, 1461.599976, 2091.5, 2103.100098, 2114.700195, 1276.400146, 859.400024, 862.600098, 865.800293, 1302.799927, 872.200073, 875.400146, 878.599854, 1329.199951, 885., 888.199951, 891.400024, 1132.800049, 1954.500122, 1969.799927, 1985.099976, 1661.199951, 2685.200195, 2705.800049, 2726.399902, 1712.399902, 2767.599854, 2788.200195, 2808.800049, 1763.599976, 2850., 2870.599854, 2891.199951, 1128., 1583.800049, 1594.400146, 1605., 927.200012, 480., 481.199951, 482.400146, 949.599976, 484.800171, 486., 487.200073, 971.999878, 489.599731, 490.800171, 492.000122, 1252.799927, 1753.400146, 1763.999878, 1774.600098, 1016.799805, 499.200195, 500.400024, 501.60022, 1039.199951, 504., 505.199951, 506.400146, 1061.599976, 508.799927, 510., 511.200195, 1377.599976, 1923.000122, 1933.600098, 1944.200073, 1106.400024, 518.400024, 519.60022, 520.800171, 1128.799927, 523.199829, 524.400024, 525.600098, 1151.199829, 528., 529.199829, 530.400146, 1100.800049, 1890.25, 1905.050049, 1919.849976, 1596.199951, 2554.700195, 2574.300049, 2593.900146, 1645.399902, 2633.099854, 2652.700195, 2672.300049, 1694.599976, 2711.5, 2731.099854, 2750.700195, 1060., 1447.299805, 1456.900146, 1466.499878, 789.200012, 203.000122, 202.200195, 201.400146, 807.600098, 199.800171, 199., 198.200195, 826., 196.599731, 195.800049, 195., 1176.799927, 1600.900146, 1610.500244, 1620.099854, 862.80011, 190.200317, 189.400146, 188.60022, 881.199951, 187., 186.199829, 185.400024, 899.60022, 183.800171, 183., 182.200073, 1293.599976, 1754.499878, 1764.099854, 1773.700073, 936.400024, 177.400146, 176.60022, 175.800049, 954.799805, 174.199951, 173.400024, 172.599854, 973.200073, 171., 170.200073, 169.400146, 1068.800049, 1826., 1840.299927, 1854.599976, 1531.199951, 2424.200195, 2442.800049, 2461.399902, 1578.399902, 2498.599854, 2517.199951, 2535.800049, 1625.599976, 2573., 2591.599854, 2610.200195, 991.999939, 1310.800049, 1319.400146, 1328., 651.199951, -74., -76.799805, -79.599854, 665.600098, -85.199829, -87.999756, -90.799805, 680., -96.400024, -99.199829, -102., 1100.800049, 1448.400146, 1456.999878, 1465.600098, 708.800049, -118.799805, -121.599976, -124.400269, 723.199829, -130., -132.800171, -135.599976, 737.599976, -141.200073, -144., -146.799805, 1209.599976, 1586., 1594.600098, 1603.200073, 766.400146, -163.599976, -166.39978, -169.200073, 780.800049, -174.799927, -177.599976, -180.400146, 795.199951, -185.999878, -188.800171, -191.599854, 1036.800049, 1761.75, 1775.550049, 1789.349976, 1466.200073, 2293.700195, 2311.300049, 2328.900146, 1511.399902, 2364.099854, 2381.700195, 2399.300049, 1556.599976, 2434.5, 2452.099854, 2469.700195, 923.999939, 1174.300049, 1181.899902, 1189.5, 513.200073, -350.999756, -355.799805, -360.599854, 523.599976, -370.199951, -374.999939, -379.799805, 534., -389.400146, -394.19989, -398.999817, 1024.800049, 1295.900146, 1303.5, 1311.10022, 554.799927, -427.800171, -432.599854, -437.400146, 565.199951, -446.999878, -451.799805, -456.599854, 575.599976, -466.200317, -470.999756, -475.799805, 1125.599976, 1417.499878, 1425.100098, 1432.700073, 596.400024, -504.599854, -509.400269, -514.199951, 606.800049, -523.800171, -528.599609, -533.400146, 617.200073, -542.999878, -547.800171, -552.599854, 1004.800049, 1697.5, 1710.799927, 1724.099976, 1401.199951, 2163.200195, 2179.800049, 2196.400146, 1444.400024, 2229.599854, 2246.200195, 2262.800049, 1487.599976, 2296., 2312.599854, 2329.200195, 855.999939, 1037.800049, 1044.400146, 1051., 375.199951, -627.999756, -634.800171, -641.599976, 381.599976, -655.199829, -661.999878, -668.80011, 388.000061, -682.400146, -689.199951, -695.999756, 948.799988, 1143.400146, 1149.999878, 1156.60022, 400.799805, -736.799927, -743.599976, -750.399902, 407.200073, -763.999878, -770.799805, -777.599731, 413.599976, -791.200073, -797.999756, -804.800171, 1041.599976, 1248.999878, 1255.60022, 1262.200073, 426.399902, -845.599854, -852.400146, -859.200073, 432.799927, -872.799805, -879.599854, -886.400024, 439.200073, -899.999878, -906.799927, -913.599976, 972.800049, 1633.25, 1646.049927, 1658.850098, 1336.200073, 2032.700195, 2048.300049, 2063.900146, 1377.400024, 2095.099854, 2110.700195, 2126.300049, 1418.599976, 2157.5, 2173.099854, 2188.700195, 787.999939, 901.299988, 906.899963, 912.500061, 237.200012, -904.999817, -913.799866, -922.599792, 239.599976, -940.199707, -948.999817, -957.800171, 242., -975.400146, -984.199829, -992.999756, 872.799988, 990.899963, 996.499878, 1002.10022, 246.800049, -1045.799927, -1054.599854, -1063.400024, 249.200073, -1080.999878, -1089.799805, -1098.599854, 251.600098, -1116.199951, -1124.999878, -1133.799683, 957.599976, 1080.499878, 1086.10022, 1091.700073, 256.400024, -1186.599854, -1195.400146, -1204.199829, 258.799927, -1221.800171, -1230.599976, -1239.400269, 261.199951, -1257., -1265.799927, -1274.600098}, sd::DataType::FLOAT32); - - input.linspace(-32, 0.1); - - sd::ops::deconv3d op; - auto results = op.evaluate({&input, &weights}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat, wFormat}); - ASSERT_EQ(Status::OK(), results.status()); - - auto output = results.at(0); - - ASSERT_TRUE(expOutput.isSameShape(output)); - ASSERT_TRUE(expOutput.equalsTo(output)); + int bS = 2, oD = 4, oH = 4, oW = 4, iC = 5, oC = 10, kD = 2, kH = 2, kW = 2, + sD = 1, sH = 1, sW = 1, pD = 1, pH = 0, pW = 0, dD = 1, dH = 1, dW = 1; + int iD = 4, iH = 4, iW = 4; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NHWC, 0-NCHW + int wFormat = 2; // 0 - [kD, kH, kW, oC, iC], 1 - [iC, oC, kD, kH, kW], 2 - + // [iC, kD, kH, kW, oC] + + NDArray input('c', {bS, iC, iD, iH, iW}, sd::DataType::FLOAT32); + NDArray weights( + 'c', {iC, kD, kH, kW, oC}, + {20., 19.5, 19., 18.5, 18., 17.5, + 17., 16.5, 16., 15.5, 15., 14.5, + 14., 13.5, 13., 12.5, 12., 11.5, + 11., 10.5, 10., 9.5, 9., 8.5, + 8., 7.5, 7., 6.5, 6., 5.5, + 5., 4.5, 4., 3.5, 3., 2.5, + 2., 1.5, 1., 0.5, 0., -0.5, + -1., -1.5, -2., -2.5, -3., -3.5, + -4., -4.5, -5., -5.5, -6., -6.5, + -7., -7.5, -8., -8.5, -9., -9.5, + -10., -10.5, -11., -11.5, -12., -12.5, + -13., -13.5, -14., -14.5, -15., -15.5, + -16., -16.5, -17., -17.5, -18., -18.5, + -19., -19.5, 19.9, 19.4, 18.9, 18.4, + 17.9, 17.4, 16.9, 16.4, 15.9, 15.4, + 14.9, 14.4, 13.9, 13.4, 12.9, 12.4, + 11.9, 11.4, 10.9, 10.4, 9.9, 9.4, + 8.9, 8.4, 7.9, 7.4, 6.9, 6.4, + 5.9, 5.4, 4.9, 4.4, 3.9, 3.4, + 2.9, 2.4, 1.9, 1.4, 0.9, 0.4, + -0.1, -0.6, -1.1, -1.6, -2.1, -2.6, + -3.1, -3.6, -4.1, -4.6, -5.1, -5.6, + -6.1, -6.6, -7.1, -7.6, -8.1, -8.6, + -9.1, -9.6, -10.1, -10.6, -11.1, -11.6, + -12.1, -12.6, -13.1, -13.6, -14.1, -14.6, + -15.1, -15.6, -16.1, -16.6, -17.1, -17.6, + -18.1, -18.6, -19.1, -19.6, 19.799999, 19.299999, + 18.799999, 18.299999, 17.799999, 17.299999, 16.799999, 16.299999, + 15.8, 15.3, 14.8, 14.3, 13.8, 13.3, + 12.8, 12.3, 11.8, 11.3, 10.8, 10.3, + 9.8, 9.3, 8.8, 8.3, 7.8, 7.3, + 6.8, 6.3, 5.8, 5.3, 4.8, 4.3, + 3.8, 3.3, 2.8, 2.3, 1.8, 1.3, + 0.8, 0.3, -0.2, -0.7, -1.2, -1.7, + -2.2, -2.7, -3.2, -3.7, -4.2, -4.7, + -5.2, -5.7, -6.2, -6.7, -7.2, -7.7, + -8.2, -8.7, -9.2, -9.7, -10.2, -10.7, + -11.2, -11.7, -12.2, -12.7, -13.2, -13.7, + -14.2, -14.7, -15.2, -15.7, -16.200001, -16.700001, + -17.200001, -17.700001, -18.200001, -18.700001, -19.200001, -19.700001, + 19.700001, 19.200001, 18.700001, 18.200001, 17.700001, 17.200001, + 16.700001, 16.200001, 15.7, 15.2, 14.7, 14.2, + 13.7, 13.2, 12.7, 12.2, 11.7, 11.2, + 10.7, 10.2, 9.7, 9.2, 8.7, 8.2, + 7.7, 7.2, 6.7, 6.2, 5.7, 5.2, + 4.7, 4.2, 3.7, 3.2, 2.7, 2.2, + 1.7, 1.2, 0.7, 0.2, -0.3, -0.8, + -1.3, -1.8, -2.3, -2.8, -3.3, -3.8, + -4.3, -4.8, -5.3, -5.8, -6.3, -6.8, + -7.3, -7.8, -8.3, -8.8, -9.3, -9.8, + -10.3, -10.8, -11.3, -11.8, -12.3, -12.8, + -13.3, -13.8, -14.3, -14.8, -15.3, -15.8, + -16.299999, -16.799999, -17.299999, -17.799999, -18.299999, -18.799999, + -19.299999, -19.799999, 19.6, 19.1, 18.6, 18.1, + 17.6, 17.1, 16.6, 16.1, 15.6, 15.1, + 14.6, 14.1, 13.6, 13.1, 12.6, 12.1, + 11.6, 11.1, 10.6, 10.1, 9.6, 9.1, + 8.6, 8.1, 7.6, 7.1, 6.6, 6.1, + 5.6, 5.1, 4.6, 4.1, 3.6, 3.1, + 2.6, 2.1, 1.6, 1.1, 0.6, 0.1, + -0.4, -0.9, -1.4, -1.9, -2.4, -2.9, + -3.4, -3.9, -4.4, -4.9, -5.4, -5.9, + -6.4, -6.9, -7.4, -7.9, -8.4, -8.9, + -9.4, -9.9, -10.4, -10.9, -11.4, -11.9, + -12.4, -12.9, -13.4, -13.9, -14.4, -14.9, + -15.4, -15.9, -16.4, -16.9, -17.4, -17.9, + -18.4, -18.9, -19.4, -19.9}, + sd::DataType::FLOAT32); + NDArray expOutput( + 'c', {bS, oC, oD, oH, oW}, + {-1907.199951, -3324.499756, -3307.199707, -3289.899902, -2814.799805, + -4664.800293, -4640.199707, -4615.600098, -2755.599854, -4566.400391, + -4541.800293, -4517.199707, -2696.400146, -4468., -4443.400391, + -4418.799805, -1735.999878, -2542.199951, -2527.600098, -2513., + -1592.800049, -1355.999756, -1346.799805, -1337.599854, -1554.400024, + -1319.199829, -1310.000122, -1300.800049, -1516., -1282.400024, + -1273.200195, -1263.999878, -1579.200073, -2308.599854, -2294., + -2279.400146, -1439.199951, -1208.799683, -1199.599976, -1190.399902, + -1400.800049, -1172., -1162.800049, -1153.600098, -1362.399902, + -1135.199951, -1126., -1116.799805, -1422.400024, -2075., + -2060.399902, -2045.799683, -1285.599976, -1061.599854, -1052.399902, + -1043.200195, -1247.199951, -1024.800049, -1015.599976, -1006.400146, + -1208.799927, -988.000122, -978.799683, -969.599976, -1859.199951, + -3228.75, -3211.949951, -3195.150146, -2719.800049, -4475.299805, + -4451.699707, -4428.100098, -2662.600098, -4380.899902, -4357.300293, + -4333.699707, -2605.399902, -4286.5, -4262.899902, -4239.300293, + -1643.999878, -2358.700195, -2345.099854, -2331.5, -1410.800049, + -992.999756, -985.799438, -978.600098, -1376.400024, -964.199707, + -957., -949.800049, -1342., -935.399902, -928.199951, + -921.000122, -1495.200073, -2141.099854, -2127.5, -2113.900391, + -1273.199951, -877.799683, -870.599976, -863.39978, -1238.800049, + -849., -841.800171, -834.599976, -1204.400024, -820.199707, + -813., -805.799438, -1346.400146, -1923.500122, -1909.899902, + -1896.299927, -1135.599976, -762.599976, -755.399658, -748.200195, + -1101.199951, -733.800049, -726.599854, -719.400024, -1066.800049, + -705., -697.800171, -690.599976, -1811.199951, -3133., + -3116.699951, -3100.399902, -2624.799805, -4285.799805, -4263.199707, + -4240.600098, -2569.600098, -4195.399902, -4172.800293, -4150.199707, + -2514.399902, -4105., -4082.400146, -4059.800293, -1552., + -2175.200195, -2162.599854, -2150., -1228.800049, -630., + -624.799561, -619.599854, -1198.400024, -609.199463, -603.999756, + -598.800049, -1167.999878, -588.400391, -583.199951, -578., + -1411.200073, -1973.599854, -1961.000122, -1948.400146, -1107.199829, + -546.800171, -541.599976, -536.400269, -1076.800049, -525.999756, + -520.800049, -515.599976, -1046.400146, -505.199829, -500., + -494.799683, -1270.399902, -1772., -1759.400146, -1746.799927, + -985.599976, -463.600098, -458.399902, -453.199951, -955.199951, + -442.799927, -437.599976, -432.400269, -924.799988, -422.000122, + -416.800171, -411.599976, -1763.199951, -3037.25, -3021.449951, + -3005.649902, -2529.800293, -4096.299805, -4074.699951, -4053.100098, + -2476.600098, -4009.900146, -3988.300049, -3966.699951, -2423.399902, + -3923.5, -3901.899902, -3880.299805, -1459.999878, -1991.699951, + -1980.099854, -1968.500122, -1046.800049, -266.999878, -263.799805, + -260.599854, -1020.400146, -254.199829, -251., -247.799927, + -994., -241.400269, -238.200073, -234.999878, -1327.200073, + -1806.099854, -1794.500122, -1782.900146, -941.199951, -215.799927, + -212.600098, -209.399902, -914.799988, -203.000122, -199.799683, + -196.599976, -888.400024, -190.200317, -186.999878, -183.799805, + -1194.399902, -1620.500122, -1608.899902, -1597.299927, -835.599915, + -164.599976, -161.400269, -158.200195, -809.200073, -151.799927, + -148.599976, -145.400024, -782.799927, -139., -135.799805, + -132.599976, -1715.200073, -2941.5, -2926.199951, -2910.899902, + -2434.800049, -3906.799805, -3886.199951, -3865.599609, -2383.600098, + -3824.400391, -3803.800049, -3783.199951, -2332.400146, -3742., + -3721.400146, -3700.799805, -1367.999878, -1808.199707, -1797.599854, + -1786.999878, -864.800049, 95.999878, 97.200073, 98.400024, + -842.39978, 100.799927, 102.000244, 103.200439, -820., + 105.599609, 106.800171, 108., -1243.199951, -1638.599854, + -1628.000122, -1617.400146, -775.199829, 115.200195, 116.400146, + 117.60022, -752.799805, 120., 121.200073, 122.400024, + -730.399841, 124.799927, 125.999878, 127.199951, -1118.400024, + -1468.999878, -1458.400146, -1447.799927, -685.599915, 134.400146, + 135.60022, 136.800171, -663.199951, 139.200073, 140.399902, + 141.599731, -640.799988, 144., 145.200195, 146.400146, + -1667.199951, -2845.749756, -2830.949707, -2816.149902, -2339.799805, + -3717.300049, -3697.699951, -3678.100098, -2290.600098, -3638.900146, + -3619.300049, -3599.699951, -2241.399902, -3560.5, -3540.899902, + -3521.299805, -1276., -1624.699951, -1615.100098, -1605.499878, + -682.799927, 459.000122, 458.199951, 457.400146, -664.400024, + 455.800049, 454.999878, 454.200439, -646.000122, 452.599976, + 451.799805, 451.000122, -1159.200073, -1471.099854, -1461.5, + -1451.900146, -609.199829, 446.200195, 445.400024, 444.600098, + -590.799927, 443., 442.200073, 441.399658, -572.39978, + 439.799927, 439.000122, 438.200073, -1042.399902, -1317.499756, + -1307.900146, -1298.299683, -535.599976, 433.399963, 432.600098, + 431.799744, -517.200012, 430.200195, 429.400024, 428.599976, + -498.799927, 427.000061, 426.200256, 425.400024, -1619.199951, + -2750., -2735.699951, -2721.399902, -2244.799805, -3527.799805, + -3509.199951, -3490.600098, -2197.600098, -3453.400146, -3434.800049, + -3416.199951, -2150.399902, -3379., -3360.400146, -3341.800049, + -1184., -1441.199951, -1432.599854, -1424., -500.799927, + 822.000122, 819.200195, 816.400146, -486.400024, 810.799927, + 808.000244, 805.200073, -472., 799.60022, 796.799683, + 794.000122, -1075.199951, -1303.599854, -1295.000122, -1286.400024, + -443.199951, 777.200073, 774.400024, 771.599854, -428.799927, + 766., 763.200317, 760.400024, -414.400146, 754.800049, + 752.000244, 749.200195, -966.400146, -1166.000122, -1157.400146, + -1148.799927, -385.600098, 732.400024, 729.599976, 726.799927, + -371.200134, 721.200012, 718.400146, 715.599792, -356.799988, + 710.000183, 707.199951, 704.400024, -1571.199951, -2654.25, + -2640.449951, -2626.649902, -2149.800049, -3338.299805, -3320.699951, + -3303.100098, -2104.600098, -3267.900146, -3250.299805, -3232.699951, + -2059.399902, -3197.5, -3179.900146, -3162.300049, -1092., + -1257.699951, -1250.099854, -1242.499878, -318.799927, 1185.000122, + 1180.200439, 1175.400146, -308.399902, 1165.800293, 1161.000122, + 1156.200073, -298., 1146.599731, 1141.800049, 1137.000122, + -991.199951, -1136.099976, -1128.500122, -1120.899902, -277.199951, + 1108.199829, 1103.400146, 1098.599976, -266.799927, 1089.000366, + 1084.199951, 1079.400024, -256.399902, 1069.799927, 1065.000122, + 1060.200317, -890.400024, -1014.5, -1006.900024, -999.299988, + -235.599976, 1031.399902, 1026.599854, 1021.800049, -225.199951, + 1012.200195, 1007.400024, 1002.599854, -214.799805, 992.999878, + 988.199707, 983.400146, -1523.199951, -2558.5, -2545.199951, + -2531.899902, -2054.800049, -3148.800049, -3132.199951, -3115.599854, + -2011.599976, -3082.400146, -3065.800049, -3049.199951, -1968.400024, + -3016., -2999.400146, -2982.799805, -1000.000061, -1074.199951, + -1067.599976, -1061.000244, -136.799805, 1548.000244, 1541.200195, + 1534.400269, -130.400146, 1520.800171, 1514.000122, 1507.200073, + -124., 1493.600098, 1486.799805, 1480.000244, -907.200073, + -968.599976, -962.000122, -955.400085, -111.199951, 1439.200073, + 1432.399902, 1425.599854, -104.800049, 1412.000122, 1405.200195, + 1398.400024, -98.400024, 1384.799927, 1378.000366, 1371.200195, + -814.400024, -862.999939, -856.399902, -849.799927, -85.599976, + 1330.400024, 1323.599854, 1316.799927, -79.200073, 1303.200073, + 1296.399902, 1289.599731, -72.799927, 1276., 1269.200195, + 1262.400024, -1475.200073, -2462.75, -2449.949951, -2437.149902, + -1959.800049, -2959.299805, -2943.699951, -2928.099854, -1918.599976, + -2896.900146, -2881.300049, -2865.699951, -1877.399902, -2834.5, + -2818.900146, -2803.300049, -907.999939, -890.700012, -885.099915, + -879.499878, 45.199829, 1911., 1902.200073, 1893.400024, + 47.599976, 1875.800293, 1867.000244, 1858.200073, 49.999878, + 1840.599976, 1831.800171, 1823.000244, -823.200073, -801.100098, + -795.500061, -789.900024, 54.799927, 1770.199951, 1761.400269, + 1752.599976, 57.200073, 1735., 1726.200073, 1717.400269, + 59.599976, 1699.799805, 1691., 1682.200073, -738.400024, + -711.499817, -705.900085, -700.299927, 64.400146, 1629.399902, + 1620.599976, 1611.800171, 66.800049, 1594.200195, 1585.39978, + 1576.599976, 69.200073, 1559.000122, 1550.199829, 1541.400146, + 1260.800049, 2211.5, 2228.800049, 2246.100098, 1921.200073, + 3207.200195, 3231.800049, 3256.399902, 1980.400024, 3305.599854, + 3330.200195, 3354.800049, 2039.599854, 3404., 3428.599854, + 3453.200195, 1400., 2129.800049, 2144.400146, 2159., + 1479.199951, 1588.000244, 1597.200073, 1606.400024, 1517.599976, + 1624.800171, 1634., 1643.199951, 1556., 1661.600098, + 1670.800171, 1679.999878, 1556.799927, 2363.400146, 2378., + 2392.600098, 1632.799805, 1735.199951, 1744.400146, 1753.600098, + 1671.199829, 1771.999878, 1781.200073, 1790.400024, 1709.60022, + 1808.800171, 1818.000244, 1827.200073, 1713.599976, 2597., + 2611.599854, 2626.199951, 1786.400024, 1882.400024, 1891.600098, + 1900.800171, 1824.799805, 1919.200195, 1928.400146, 1937.600098, + 1863.199951, 1956., 1965.199951, 1974.400391, 1228.800049, + 2147.25, 2164.049805, 2180.850098, 1856.199951, 3076.700195, + 3100.300049, 3123.899902, 1913.400024, 3171.099854, 3194.700195, + 3218.300049, 1970.599976, 3265.5, 3289.099854, 3312.699951, + 1332., 1993.300049, 2006.900146, 2020.499878, 1341.199951, + 1310.999878, 1318.199951, 1325.400146, 1375.60022, 1339.800171, + 1347., 1354.199951, 1410., 1368.600098, 1375.800171, + 1383., 1480.800049, 2210.900146, 2224.5, 2238.100098, + 1478.799805, 1426.200073, 1433.400146, 1440.599609, 1513.199951, + 1455., 1462.199951, 1469.400024, 1547.60022, 1483.799927, + 1490.999878, 1498.199951, 1629.599976, 2428.500244, 2442.100098, + 2455.699951, 1616.399902, 1541.400146, 1548.600098, 1555.799683, + 1650.800049, 1570.200073, 1577.400024, 1584.600098, 1685.199951, + 1598.99939, 1606.200317, 1613.400024, 1196.800049, 2083., + 2099.300049, 2115.600098, 1791.200073, 2946.200195, 2968.800049, + 2991.400146, 1846.400024, 3036.599854, 3059.200195, 3081.800049, + 1901.599976, 3127., 3149.599854, 3172.200195, 1264., + 1856.800049, 1869.400146, 1881.999878, 1203.200073, 1034., + 1039.200073, 1044.400146, 1233.599976, 1054.799927, 1059.999878, + 1065.199951, 1263.999878, 1075.599609, 1080.800171, 1086., + 1404.799927, 2058.400146, 2071., 2083.599854, 1324.799927, + 1117.199951, 1122.400146, 1127.599609, 1355.199951, 1138., + 1143.200439, 1148.400146, 1385.599976, 1158.800171, 1164.000244, + 1169.200073, 1545.599976, 2260., 2272.600098, 2285.199951, + 1446.400024, 1200.400146, 1205.600098, 1210.800171, 1476.799805, + 1221.199951, 1226.400024, 1231.600098, 1507.199951, 1242.000244, + 1247.200073, 1252.400146, 1164.800049, 2018.75, 2034.549927, + 2050.350098, 1726.200073, 2815.700195, 2837.300049, 2858.900146, + 1779.400024, 2902.099854, 2923.700195, 2945.300049, 1832.599976, + 2988.5, 3010.099854, 3031.700195, 1196.000122, 1720.300049, + 1731.900146, 1743.499878, 1065.200073, 757.000122, 760.200073, + 763.400024, 1091.599976, 769.800171, 773., 776.199951, + 1118., 782.599976, 785.800049, 789., 1328.800049, + 1905.900146, 1917.499878, 1929.100098, 1170.799805, 808.200073, + 811.400024, 814.60022, 1197.199951, 821., 824.199951, + 827.400024, 1223.599976, 833.799927, 837.000244, 840.199951, + 1461.599976, 2091.5, 2103.100098, 2114.700195, 1276.400146, + 859.400024, 862.600098, 865.800293, 1302.799927, 872.200073, + 875.400146, 878.599854, 1329.199951, 885., 888.199951, + 891.400024, 1132.800049, 1954.500122, 1969.799927, 1985.099976, + 1661.199951, 2685.200195, 2705.800049, 2726.399902, 1712.399902, + 2767.599854, 2788.200195, 2808.800049, 1763.599976, 2850., + 2870.599854, 2891.199951, 1128., 1583.800049, 1594.400146, + 1605., 927.200012, 480., 481.199951, 482.400146, + 949.599976, 484.800171, 486., 487.200073, 971.999878, + 489.599731, 490.800171, 492.000122, 1252.799927, 1753.400146, + 1763.999878, 1774.600098, 1016.799805, 499.200195, 500.400024, + 501.60022, 1039.199951, 504., 505.199951, 506.400146, + 1061.599976, 508.799927, 510., 511.200195, 1377.599976, + 1923.000122, 1933.600098, 1944.200073, 1106.400024, 518.400024, + 519.60022, 520.800171, 1128.799927, 523.199829, 524.400024, + 525.600098, 1151.199829, 528., 529.199829, 530.400146, + 1100.800049, 1890.25, 1905.050049, 1919.849976, 1596.199951, + 2554.700195, 2574.300049, 2593.900146, 1645.399902, 2633.099854, + 2652.700195, 2672.300049, 1694.599976, 2711.5, 2731.099854, + 2750.700195, 1060., 1447.299805, 1456.900146, 1466.499878, + 789.200012, 203.000122, 202.200195, 201.400146, 807.600098, + 199.800171, 199., 198.200195, 826., 196.599731, + 195.800049, 195., 1176.799927, 1600.900146, 1610.500244, + 1620.099854, 862.80011, 190.200317, 189.400146, 188.60022, + 881.199951, 187., 186.199829, 185.400024, 899.60022, + 183.800171, 183., 182.200073, 1293.599976, 1754.499878, + 1764.099854, 1773.700073, 936.400024, 177.400146, 176.60022, + 175.800049, 954.799805, 174.199951, 173.400024, 172.599854, + 973.200073, 171., 170.200073, 169.400146, 1068.800049, + 1826., 1840.299927, 1854.599976, 1531.199951, 2424.200195, + 2442.800049, 2461.399902, 1578.399902, 2498.599854, 2517.199951, + 2535.800049, 1625.599976, 2573., 2591.599854, 2610.200195, + 991.999939, 1310.800049, 1319.400146, 1328., 651.199951, + -74., -76.799805, -79.599854, 665.600098, -85.199829, + -87.999756, -90.799805, 680., -96.400024, -99.199829, + -102., 1100.800049, 1448.400146, 1456.999878, 1465.600098, + 708.800049, -118.799805, -121.599976, -124.400269, 723.199829, + -130., -132.800171, -135.599976, 737.599976, -141.200073, + -144., -146.799805, 1209.599976, 1586., 1594.600098, + 1603.200073, 766.400146, -163.599976, -166.39978, -169.200073, + 780.800049, -174.799927, -177.599976, -180.400146, 795.199951, + -185.999878, -188.800171, -191.599854, 1036.800049, 1761.75, + 1775.550049, 1789.349976, 1466.200073, 2293.700195, 2311.300049, + 2328.900146, 1511.399902, 2364.099854, 2381.700195, 2399.300049, + 1556.599976, 2434.5, 2452.099854, 2469.700195, 923.999939, + 1174.300049, 1181.899902, 1189.5, 513.200073, -350.999756, + -355.799805, -360.599854, 523.599976, -370.199951, -374.999939, + -379.799805, 534., -389.400146, -394.19989, -398.999817, + 1024.800049, 1295.900146, 1303.5, 1311.10022, 554.799927, + -427.800171, -432.599854, -437.400146, 565.199951, -446.999878, + -451.799805, -456.599854, 575.599976, -466.200317, -470.999756, + -475.799805, 1125.599976, 1417.499878, 1425.100098, 1432.700073, + 596.400024, -504.599854, -509.400269, -514.199951, 606.800049, + -523.800171, -528.599609, -533.400146, 617.200073, -542.999878, + -547.800171, -552.599854, 1004.800049, 1697.5, 1710.799927, + 1724.099976, 1401.199951, 2163.200195, 2179.800049, 2196.400146, + 1444.400024, 2229.599854, 2246.200195, 2262.800049, 1487.599976, + 2296., 2312.599854, 2329.200195, 855.999939, 1037.800049, + 1044.400146, 1051., 375.199951, -627.999756, -634.800171, + -641.599976, 381.599976, -655.199829, -661.999878, -668.80011, + 388.000061, -682.400146, -689.199951, -695.999756, 948.799988, + 1143.400146, 1149.999878, 1156.60022, 400.799805, -736.799927, + -743.599976, -750.399902, 407.200073, -763.999878, -770.799805, + -777.599731, 413.599976, -791.200073, -797.999756, -804.800171, + 1041.599976, 1248.999878, 1255.60022, 1262.200073, 426.399902, + -845.599854, -852.400146, -859.200073, 432.799927, -872.799805, + -879.599854, -886.400024, 439.200073, -899.999878, -906.799927, + -913.599976, 972.800049, 1633.25, 1646.049927, 1658.850098, + 1336.200073, 2032.700195, 2048.300049, 2063.900146, 1377.400024, + 2095.099854, 2110.700195, 2126.300049, 1418.599976, 2157.5, + 2173.099854, 2188.700195, 787.999939, 901.299988, 906.899963, + 912.500061, 237.200012, -904.999817, -913.799866, -922.599792, + 239.599976, -940.199707, -948.999817, -957.800171, 242., + -975.400146, -984.199829, -992.999756, 872.799988, 990.899963, + 996.499878, 1002.10022, 246.800049, -1045.799927, -1054.599854, + -1063.400024, 249.200073, -1080.999878, -1089.799805, -1098.599854, + 251.600098, -1116.199951, -1124.999878, -1133.799683, 957.599976, + 1080.499878, 1086.10022, 1091.700073, 256.400024, -1186.599854, + -1195.400146, -1204.199829, 258.799927, -1221.800171, -1230.599976, + -1239.400269, 261.199951, -1257., -1265.799927, -1274.600098}, + sd::DataType::FLOAT32); + + input.linspace(-32, 0.1); + + sd::ops::deconv3d op; + auto results = op.evaluate({&input, &weights}, + {kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, + paddingMode, dataFormat, wFormat}); + ASSERT_EQ(Status::OK(), results.status()); + + auto output = results.at(0); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests2, deconv3d_bp_test1) { - - int bS=1, iD=3,iH=3,iW=3, iC=1,oC=2, kD=2,kH=2,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; - int oD=2,oH=2,oW=2; - int paddingMode = 0; // 1-SAME, 0-VALID; - int dataFormat = 1; // 1-NDHWC, 0-NCDHW - - auto input = NDArrayFactory::create('c', {bS, oD, oH, oW, oC}); - auto weights = NDArrayFactory::create('c', {kD, kH, kW, iC, oC}); - auto bias = NDArrayFactory::create('c', {iC}); - auto gradO = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); - - NDArray expGradI('c', {bS, oD, oH, oW, oC}, {62., 67.6, 68.4, 74.8, 81.2, 89.2, 87.6, 96.4, 119.6, 132.4, 126., 139.6, 138.8, 154., 145.2, 161.2}, sd::DataType::FLOAT32); - NDArray expGradW('c', {kD, kH, kW, iC, oC}, {28., 28., 32., 32., 40., 40., 44., 44., 64, 64., 68., 68., 76., 76., 80., 80.}, sd::DataType::FLOAT32); - NDArray expGradB('c', {iC}, std::vector{364.5}, sd::DataType::FLOAT32); - - input = 0.5; - weights.linspace(0.1, 0.1); - gradO.linspace(0.5); - - sd::ops::deconv3d_bp op; - auto results = op.evaluate({&input, &weights, &bias, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}, {}); - - auto gradI = results.at(0); - auto gradW = results.at(1); - auto gradB = results.at(2); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expGradI.isSameShape(gradI)); - ASSERT_TRUE(expGradI.equalsTo(gradI)); - - ASSERT_TRUE(expGradW.isSameShape(gradW)); - ASSERT_TRUE(expGradW.equalsTo(gradW)); - - ASSERT_TRUE(expGradB.isSameShape(gradB)); - ASSERT_TRUE(expGradB.equalsTo(gradB)); - - } + int bS = 1, iD = 3, iH = 3, iW = 3, iC = 1, oC = 2, kD = 2, kH = 2, kW = 2, + sD = 1, sH = 1, sW = 1, pD = 0, pH = 0, pW = 0, dD = 1, dH = 1, dW = 1; + int oD = 2, oH = 2, oW = 2; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, oD, oH, oW, oC}); + auto weights = NDArrayFactory::create('c', {kD, kH, kW, iC, oC}); + auto bias = NDArrayFactory::create('c', {iC}); + auto gradO = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); + + NDArray expGradI('c', {bS, oD, oH, oW, oC}, + {62., 67.6, 68.4, 74.8, 81.2, 89.2, 87.6, 96.4, 119.6, 132.4, + 126., 139.6, 138.8, 154., 145.2, 161.2}, + sd::DataType::FLOAT32); + NDArray expGradW('c', {kD, kH, kW, iC, oC}, + {28., 28., 32., 32., 40., 40., 44., 44., 64, 64., 68., 68., + 76., 76., 80., 80.}, + sd::DataType::FLOAT32); + NDArray expGradB('c', {iC}, std::vector{364.5}, + sd::DataType::FLOAT32); + + input = 0.5; + weights.linspace(0.1, 0.1); + gradO.linspace(0.5); + + sd::ops::deconv3d_bp op; + auto results = op.evaluate( + {&input, &weights, &bias, &gradO}, {}, + {kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, paddingMode, dataFormat}, + {}); + + auto gradI = results.at(0); + auto gradW = results.at(1); + auto gradB = results.at(2); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); + + ASSERT_TRUE(expGradW.isSameShape(gradW)); + ASSERT_TRUE(expGradW.equalsTo(gradW)); + + ASSERT_TRUE(expGradB.isSameShape(gradB)); + ASSERT_TRUE(expGradB.equalsTo(gradB)); +} ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests2, deconv3d_bp_test2) { - - int bS=1, iD=2,iH=2,iW=2, iC=1,oC=2, kD=2,kH=2,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; - int oD=2,oH=2,oW=2; - int paddingMode = 1; // 1-SAME, 0-VALID; - int dataFormat = 1; // 1-NDHWC, 0-NCDHW - - auto input = NDArrayFactory::create('c', {bS, oD, oH, oW, oC}); - auto weights = NDArrayFactory::create('c', {kD, kH, kW, iC, oC}); - auto gradO = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); - - NDArray expGradI('c', {bS, oD, oH, oW, oC}, {34, 37.2, 16.6, 18.4, 15.4, 17.4, 7.1, 8.2, 10.6, 13., 4.3, 5.6, 2.9, 4.3, 0.75, 1.5}, sd::DataType::FLOAT32); - NDArray expGradW('c', {kD, kH, kW, iC, oC}, {16, 16, 9, 9, 10, 10, 5.5, 5.5, 12, 12, 6.5, 6.5, 7, 7, 3.75, 3.75}, sd::DataType::FLOAT32); - - input = 0.5; - weights.linspace(0.1, 0.1); - gradO.linspace(0.5); - - sd::ops::deconv3d_bp op; - auto results = op.evaluate({&input, &weights, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}, {}); - - auto gradI = results.at(0); - auto gradW = results.at(1); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expGradI.isSameShape(gradI)); - ASSERT_TRUE(expGradI.equalsTo(gradI)); - - ASSERT_TRUE(expGradW.isSameShape(gradW)); - ASSERT_TRUE(expGradW.equalsTo(gradW)); - + int bS = 1, iD = 2, iH = 2, iW = 2, iC = 1, oC = 2, kD = 2, kH = 2, kW = 2, + sD = 1, sH = 1, sW = 1, pD = 0, pH = 0, pW = 0, dD = 1, dH = 1, dW = 1; + int oD = 2, oH = 2, oW = 2; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, oD, oH, oW, oC}); + auto weights = NDArrayFactory::create('c', {kD, kH, kW, iC, oC}); + auto gradO = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); + + NDArray expGradI('c', {bS, oD, oH, oW, oC}, + {34, 37.2, 16.6, 18.4, 15.4, 17.4, 7.1, 8.2, 10.6, 13., 4.3, + 5.6, 2.9, 4.3, 0.75, 1.5}, + sd::DataType::FLOAT32); + NDArray expGradW( + 'c', {kD, kH, kW, iC, oC}, + {16, 16, 9, 9, 10, 10, 5.5, 5.5, 12, 12, 6.5, 6.5, 7, 7, 3.75, 3.75}, + sd::DataType::FLOAT32); + + input = 0.5; + weights.linspace(0.1, 0.1); + gradO.linspace(0.5); + + sd::ops::deconv3d_bp op; + auto results = op.evaluate( + {&input, &weights, &gradO}, {}, + {kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, paddingMode, dataFormat}, + {}); + + auto gradI = results.at(0); + auto gradW = results.at(1); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); + + ASSERT_TRUE(expGradW.isSameShape(gradW)); + ASSERT_TRUE(expGradW.equalsTo(gradW)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests2, deconv3d_bp_test3) { - - int bS=1, iD=3,iH=3,iW=3, iC=1,oC=2, kD=2,kH=2,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; - int oD=2,oH=2,oW=2; - int paddingMode = 0; // 1-SAME, 0-VALID; - int dataFormat = 0; // 1-NDHWC, 0-NCDHW - - auto input = NDArrayFactory::create('c', {bS, oC, oD, oH, oW}); - auto weights = NDArrayFactory::create('c', {kD, kH, kW, iC, oC}, {0.1f, 0.9f, 0.2f, 0.1f, 0.3f, 1.1f, 0.4f, 1.2f, 0.5f, 1.3f, 0.6f, 1.4f, 0.7f, 1.5f, 0.8f, 1.6f}); - auto gradO = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); - - NDArray expGradI('c', {bS, oD, oH, oW, oC}, {33.8, 37.4, 44.6, 48.2, 66.2, 69.8, 77., 80.6, 77.25, 86.35, 104.55, 113.65, 159.15, 168.25, 186.45, 195.55}, sd::DataType::FLOAT32); - NDArray expGradW('c', {kD, kH, kW, iC, oC}, {28., 28, 32, 32, 40, 40, 44, 44, 64, 64, 68, 68, 76, 76, 80, 80.}, sd::DataType::FLOAT32); - - input = 0.5; - gradO.linspace(0.5); - - sd::ops::deconv3d_bp op; - auto results = op.evaluate({&input, &weights, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}, {}); - - auto gradI = results.at(0); - auto gradW = results.at(1); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expGradI.isSameShape(gradI)); - ASSERT_TRUE(expGradI.equalsTo(gradI)); - - ASSERT_TRUE(expGradW.isSameShape(gradW)); - ASSERT_TRUE(expGradW.equalsTo(gradW)); - + int bS = 1, iD = 3, iH = 3, iW = 3, iC = 1, oC = 2, kD = 2, kH = 2, kW = 2, + sD = 1, sH = 1, sW = 1, pD = 0, pH = 0, pW = 0, dD = 1, dH = 1, dW = 1; + int oD = 2, oH = 2, oW = 2; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, oC, oD, oH, oW}); + auto weights = NDArrayFactory::create( + 'c', {kD, kH, kW, iC, oC}, + {0.1f, 0.9f, 0.2f, 0.1f, 0.3f, 1.1f, 0.4f, 1.2f, 0.5f, 1.3f, 0.6f, 1.4f, + 0.7f, 1.5f, 0.8f, 1.6f}); + auto gradO = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); + + NDArray expGradI('c', {bS, oD, oH, oW, oC}, + {33.8, 37.4, 44.6, 48.2, 66.2, 69.8, 77., 80.6, 77.25, 86.35, + 104.55, 113.65, 159.15, 168.25, 186.45, 195.55}, + sd::DataType::FLOAT32); + NDArray expGradW( + 'c', {kD, kH, kW, iC, oC}, + {28., 28, 32, 32, 40, 40, 44, 44, 64, 64, 68, 68, 76, 76, 80, 80.}, + sd::DataType::FLOAT32); + + input = 0.5; + gradO.linspace(0.5); + + sd::ops::deconv3d_bp op; + auto results = op.evaluate( + {&input, &weights, &gradO}, {}, + {kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, paddingMode, dataFormat}, + {}); + + auto gradI = results.at(0); + auto gradW = results.at(1); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); + + ASSERT_TRUE(expGradW.isSameShape(gradW)); + ASSERT_TRUE(expGradW.equalsTo(gradW)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests2, deconv3d_bp_test4) { - - int bS=1, iD=2,iH=2,iW=2, iC=1,oC=2, kD=2,kH=2,kW=2, sD=1,sH=1,sW=1, pD=1,pH=1,pW=1, dD=1,dH=1,dW=1; - int oD=3,oH=3,oW=3; - int paddingMode = 0; // 1-SAME, 0-VALID; - int dataFormat = 0; // 1-NDHWC, 0-NCDHW - - auto input = NDArrayFactory::create('c', {bS, oC, oD, oH, oW}); - auto weights = NDArrayFactory::create('c', {kD, kH, kW, iC, oC}, {0.1f, 0.9f, 0.2f, 0.1f, 0.3f, 1.1f, 0.4f, 1.2f, 0.5f, 1.3f, 0.6f, 1.4f, 0.7f, 1.5f, 0.8f, 1.6f}); - auto gradO = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); - - NDArray expGradI('c', {bS, oC, oD, oH, oW}, {0.4, 1.55, 1.05, 2.3, 5.7, 3.2, 1.5, 3.35, 1.75, 3.8, 8.3, 4.3, 9.0, 18.6, 9.2, 4.4, 8.7, 4.1, 1.8, 3.55, 1.65, 3.5, 6.5, 2.8, 1.3, 2.15, 0.75, 0.8, 3.15, 2.25, 4.7, 12.1, 7.2, 3.5, 8.15, 4.55, 7.8, 17.9, 9.9, 19.75, 42.85, 23.6, 9.35, 21.55, 12.9, 5.4, 11.55, 6.05, 8.25, 20.75, 13.2, 0.65, 6.6, 6.75}, sd::DataType::FLOAT32); - NDArray expGradW('c', {kD, kH, kW, iC, oC}, {16.0, 16.0, 16.0, 16.0, 16.0, 16.0, 16.0, 16.0, 16.0, 16.0, 16.0, 16.0, 16.0, 16.0, 16.0, 16.}, sd::DataType::FLOAT32); - - input = 0.5; - gradO.linspace(0.5); - - sd::ops::deconv3d_bp op; - auto results = op.evaluate({&input, &weights, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}, {}); - - auto gradI = results.at(0); - auto gradW = results.at(1); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expGradI.isSameShape(gradI)); - ASSERT_TRUE(expGradI.equalsTo(gradI)); - - ASSERT_TRUE(expGradW.isSameShape(gradW)); - ASSERT_TRUE(expGradW.equalsTo(gradW)); + int bS = 1, iD = 2, iH = 2, iW = 2, iC = 1, oC = 2, kD = 2, kH = 2, kW = 2, + sD = 1, sH = 1, sW = 1, pD = 1, pH = 1, pW = 1, dD = 1, dH = 1, dW = 1; + int oD = 3, oH = 3, oW = 3; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, oC, oD, oH, oW}); + auto weights = NDArrayFactory::create( + 'c', {kD, kH, kW, iC, oC}, + {0.1f, 0.9f, 0.2f, 0.1f, 0.3f, 1.1f, 0.4f, 1.2f, 0.5f, 1.3f, 0.6f, 1.4f, + 0.7f, 1.5f, 0.8f, 1.6f}); + auto gradO = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); + + NDArray expGradI( + 'c', {bS, oC, oD, oH, oW}, + {0.4, 1.55, 1.05, 2.3, 5.7, 3.2, 1.5, 3.35, 1.75, 3.8, 8.3, + 4.3, 9.0, 18.6, 9.2, 4.4, 8.7, 4.1, 1.8, 3.55, 1.65, 3.5, + 6.5, 2.8, 1.3, 2.15, 0.75, 0.8, 3.15, 2.25, 4.7, 12.1, 7.2, + 3.5, 8.15, 4.55, 7.8, 17.9, 9.9, 19.75, 42.85, 23.6, 9.35, 21.55, + 12.9, 5.4, 11.55, 6.05, 8.25, 20.75, 13.2, 0.65, 6.6, 6.75}, + sd::DataType::FLOAT32); + NDArray expGradW('c', {kD, kH, kW, iC, oC}, + {16.0, 16.0, 16.0, 16.0, 16.0, 16.0, 16.0, 16.0, 16.0, 16.0, + 16.0, 16.0, 16.0, 16.0, 16.0, 16.}, + sd::DataType::FLOAT32); + + input = 0.5; + gradO.linspace(0.5); + + sd::ops::deconv3d_bp op; + auto results = op.evaluate( + {&input, &weights, &gradO}, {}, + {kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, paddingMode, dataFormat}, + {}); + + auto gradI = results.at(0); + auto gradW = results.at(1); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); + + ASSERT_TRUE(expGradW.isSameShape(gradW)); + ASSERT_TRUE(expGradW.equalsTo(gradW)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests2, deconv3d_bp_test5) { - - int bS=2, iD=4,iH=4,iW=4, iC=3,oC=2, kD=2,kH=1,kW=1, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; - int oD=4,oH=4,oW=4; - int paddingMode = 1; // 1-SAME, 0-VALID; - int dataFormat = 0; // 1-NHWC, 0-NCHW - int wFormat = 1; // 0 - [kD, kH, kW, oC, iC], 1 - [iC, oC, kD, kH, kW], 2 - [iC, kD, kH, kW, oC] - - NDArray input('c', {bS, iC, iD, iH, iW}, sd::DataType::FLOAT32); - NDArray bias('c', {oC}, {-0.1, 0.2}, sd::DataType::FLOAT32); - NDArray weights('c',{iC, oC, kD, kH, kW}, {-0.6, 0., -0.3, 0.3, -0.5, 0.1, -0.2, 0.4, -0.4, 0.2, -0.1, 0.5}, sd::DataType::FLOAT32); - NDArray gradO('c', {bS, oC, oD, oH, oW},sd::DataType::FLOAT32); - - NDArray expGradI('c', {bS, iC, iD, iH, iW}, {9.696001, 9.684001, 9.672001, 9.66, 9.648001, 9.636, 9.624001, 9.612, 9.600001, 9.587999, 9.576, 9.564001, 9.552, - 9.540001, 9.528, 9.516, 9.504001, 9.492, 9.480001, 9.468, 9.455999, 9.444, 9.432001, 9.420001, 9.408001, 9.396, 9.384001, 9.372001, 9.36, 9.348001, 9.335999, - 9.324001, 9.312, 9.300001, 9.288001, 9.276001, 9.264, 9.252001, 9.24, 9.228001, 9.216, 9.204, 9.191999, 9.18, 9.168001, 9.156, 9.144001, 9.132, 13.152, 13.134001, - 13.116, 13.098, 13.080001, 13.062, 13.044001, 13.026001, 13.008001, 12.990001, 12.972, 12.954, 12.936001, 12.918, 12.900002, 12.882, 3.616001, 3.612, 3.608, 3.604, - 3.6, 3.596, 3.592, 3.588, 3.584001, 3.579999, 3.576001, 3.571999, 3.568, 3.564, 3.56, 3.556, 3.552, 3.548, 3.544, 3.539999, 3.536001, 3.532001, 3.527999, 3.524001, 3.52, 3.516, 3.512, 3.508, 3.504, 3.5, 3.496, 3.492, 3.487999, 3.484001, 3.48, 3.476, 3.472, 3.468, 3.464, 3.46, 3.456, 3.452, 3.447999, 3.444001, 3.439999, 3.436, 3.432001, 3.428, 10.272, 10.258, 10.244, 10.23, 10.216, 10.202, 10.188, 10.174, 10.16, 10.146, 10.132, 10.118, 10.104, 10.09, 10.076, 10.062, -2.464, -2.460001, -2.455999, -2.452, -2.448, -2.444, -2.44, -2.436, -2.432, -2.428, -2.424, -2.42, -2.415999, -2.412, -2.408, -2.404, -2.4, -2.396, -2.392, -2.388, -2.384, -2.38, -2.376, -2.372, -2.368, -2.363999, -2.36, -2.356, -2.352, -2.348, -2.344, -2.34, -2.336, -2.332, -2.328001, -2.323999, -2.32, -2.316, -2.312, -2.308, -2.304, -2.3, -2.296, -2.292, -2.288, -2.283999, -2.28, -2.276, 7.392, 7.382, 7.372, 7.362, 7.352, 7.342, 7.332, 7.322, 7.312, 7.302, 7.292, 7.282, 7.272, 7.262, 7.252, 7.242, 8.16, 8.148001, 8.136001, 8.124001, 8.112, 8.1, 8.087999, 8.076, 8.063999, 8.052, 8.04, 8.028001, 8.016, 8.004001, 7.992001, 7.98, 7.968, 7.956, 7.944, 7.932001, 7.92, 7.908, 7.896, 7.884, 7.872001, 7.86, 7.848001, 7.835999, 7.824, 7.812, 7.800001, 7.788, 7.776, 7.764, 7.752, 7.740001, 7.728, 7.716001, 7.704, 7.692, 7.68, 7.668, 7.656, 7.644001, 7.632001, 7.62, 7.608001, 7.596001, 10.848, 10.830001, 10.812, 10.794001, 10.776, 10.758, 10.74, 10.722, 10.704, 10.686001, 10.668, 10.650001, 10.632, 10.614, 10.596001, 10.578001, 3.104, 3.1, 3.096, 3.092, 3.088, 3.084, 3.079999, 3.076001, 3.072, 3.068, 3.064, 3.06, 3.056, 3.052, 3.048, 3.044, 3.039999, 3.036001, 3.032, 3.028, 3.024001, 3.02, 3.016, 3.012, 3.008, 3.004, 3., 2.996, 2.992, 2.987999, 2.984001, 2.98, 2.976, 2.972, 2.968, 2.964, 2.96, 2.956, 2.952, 2.947999, 2.944001, 2.94, 2.936, 2.932001, 2.928, 2.924, 2.92, 2.916, 8.48, 8.466, 8.452, 8.438, 8.424, 8.41, 8.396, 8.382, 8.368, 8.354, 8.34, 8.326, 8.312, 8.298, 8.284, 8.27, -1.952, -1.948, -1.944, -1.94, -1.936, -1.932, -1.928, -1.924, -1.92, -1.916, -1.912, -1.908, -1.904, -1.9, -1.896, -1.892, -1.888, -1.884, -1.88, -1.876, -1.872, -1.868, -1.863999, -1.86, -1.856, -1.852, -1.848, -1.844, -1.84, -1.836, -1.832, -1.828, -1.823999, -1.82, -1.816, -1.812, -1.808, -1.804, -1.8, -1.796, -1.792, -1.788, -1.784, -1.78, -1.776, -1.771999, -1.768, -1.764, 6.112, 6.102, 6.092, 6.082, 6.072, 6.062, 6.052, 6.042, 6.032, 6.022, 6.012, 6.002, 5.992, 5.982, 5.972, 5.962}, sd::DataType::FLOAT32); - - NDArray expGradW('c', {iC, oC, kD, kH, kW}, {-73678.695312, -59907.972656, -67739.515625, -54962.082031, -15966.075195, -17115.042969, -15269.777344, -16101.275391, 41746.566406, 25677.917969, 37200.003906, 22759.517578}, sd::DataType::FLOAT32); - NDArray expGradB('c', {oC}, {-1803.520020, -1639.679932}, sd::DataType::FLOAT32); - - input.linspace(100., -0.5); - gradO.linspace(-16, 0.02); - - sd::ops::deconv3d_bp op; - auto results = op.evaluate({&input, &weights, &bias, &gradO}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat, wFormat}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto gradI = results.at(0); - auto gradW = results.at(1); - auto gradB = results.at(2); - - ASSERT_TRUE(expGradI.isSameShape(gradI)); - ASSERT_TRUE(expGradI.equalsTo(gradI)); - - ASSERT_TRUE(expGradW.isSameShape(gradW)); - ASSERT_TRUE(expGradW.equalsTo(gradW)); - - ASSERT_TRUE(expGradB.isSameShape(gradB)); - ASSERT_TRUE(expGradB.equalsTo(gradB)); + int bS = 2, iD = 4, iH = 4, iW = 4, iC = 3, oC = 2, kD = 2, kH = 1, kW = 1, + sD = 1, sH = 1, sW = 1, pD = 0, pH = 0, pW = 0, dD = 1, dH = 1, dW = 1; + int oD = 4, oH = 4, oW = 4; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NHWC, 0-NCHW + int wFormat = 1; // 0 - [kD, kH, kW, oC, iC], 1 - [iC, oC, kD, kH, kW], 2 - + // [iC, kD, kH, kW, oC] + + NDArray input('c', {bS, iC, iD, iH, iW}, sd::DataType::FLOAT32); + NDArray bias('c', {oC}, {-0.1, 0.2}, sd::DataType::FLOAT32); + NDArray weights( + 'c', {iC, oC, kD, kH, kW}, + {-0.6, 0., -0.3, 0.3, -0.5, 0.1, -0.2, 0.4, -0.4, 0.2, -0.1, 0.5}, + sd::DataType::FLOAT32); + NDArray gradO('c', {bS, oC, oD, oH, oW}, sd::DataType::FLOAT32); + + NDArray expGradI( + 'c', {bS, iC, iD, iH, iW}, + {9.696001, 9.684001, 9.672001, 9.66, 9.648001, 9.636, + 9.624001, 9.612, 9.600001, 9.587999, 9.576, 9.564001, + 9.552, 9.540001, 9.528, 9.516, 9.504001, 9.492, + 9.480001, 9.468, 9.455999, 9.444, 9.432001, 9.420001, + 9.408001, 9.396, 9.384001, 9.372001, 9.36, 9.348001, + 9.335999, 9.324001, 9.312, 9.300001, 9.288001, 9.276001, + 9.264, 9.252001, 9.24, 9.228001, 9.216, 9.204, + 9.191999, 9.18, 9.168001, 9.156, 9.144001, 9.132, + 13.152, 13.134001, 13.116, 13.098, 13.080001, 13.062, + 13.044001, 13.026001, 13.008001, 12.990001, 12.972, 12.954, + 12.936001, 12.918, 12.900002, 12.882, 3.616001, 3.612, + 3.608, 3.604, 3.6, 3.596, 3.592, 3.588, + 3.584001, 3.579999, 3.576001, 3.571999, 3.568, 3.564, + 3.56, 3.556, 3.552, 3.548, 3.544, 3.539999, + 3.536001, 3.532001, 3.527999, 3.524001, 3.52, 3.516, + 3.512, 3.508, 3.504, 3.5, 3.496, 3.492, + 3.487999, 3.484001, 3.48, 3.476, 3.472, 3.468, + 3.464, 3.46, 3.456, 3.452, 3.447999, 3.444001, + 3.439999, 3.436, 3.432001, 3.428, 10.272, 10.258, + 10.244, 10.23, 10.216, 10.202, 10.188, 10.174, + 10.16, 10.146, 10.132, 10.118, 10.104, 10.09, + 10.076, 10.062, -2.464, -2.460001, -2.455999, -2.452, + -2.448, -2.444, -2.44, -2.436, -2.432, -2.428, + -2.424, -2.42, -2.415999, -2.412, -2.408, -2.404, + -2.4, -2.396, -2.392, -2.388, -2.384, -2.38, + -2.376, -2.372, -2.368, -2.363999, -2.36, -2.356, + -2.352, -2.348, -2.344, -2.34, -2.336, -2.332, + -2.328001, -2.323999, -2.32, -2.316, -2.312, -2.308, + -2.304, -2.3, -2.296, -2.292, -2.288, -2.283999, + -2.28, -2.276, 7.392, 7.382, 7.372, 7.362, + 7.352, 7.342, 7.332, 7.322, 7.312, 7.302, + 7.292, 7.282, 7.272, 7.262, 7.252, 7.242, + 8.16, 8.148001, 8.136001, 8.124001, 8.112, 8.1, + 8.087999, 8.076, 8.063999, 8.052, 8.04, 8.028001, + 8.016, 8.004001, 7.992001, 7.98, 7.968, 7.956, + 7.944, 7.932001, 7.92, 7.908, 7.896, 7.884, + 7.872001, 7.86, 7.848001, 7.835999, 7.824, 7.812, + 7.800001, 7.788, 7.776, 7.764, 7.752, 7.740001, + 7.728, 7.716001, 7.704, 7.692, 7.68, 7.668, + 7.656, 7.644001, 7.632001, 7.62, 7.608001, 7.596001, + 10.848, 10.830001, 10.812, 10.794001, 10.776, 10.758, + 10.74, 10.722, 10.704, 10.686001, 10.668, 10.650001, + 10.632, 10.614, 10.596001, 10.578001, 3.104, 3.1, + 3.096, 3.092, 3.088, 3.084, 3.079999, 3.076001, + 3.072, 3.068, 3.064, 3.06, 3.056, 3.052, + 3.048, 3.044, 3.039999, 3.036001, 3.032, 3.028, + 3.024001, 3.02, 3.016, 3.012, 3.008, 3.004, + 3., 2.996, 2.992, 2.987999, 2.984001, 2.98, + 2.976, 2.972, 2.968, 2.964, 2.96, 2.956, + 2.952, 2.947999, 2.944001, 2.94, 2.936, 2.932001, + 2.928, 2.924, 2.92, 2.916, 8.48, 8.466, + 8.452, 8.438, 8.424, 8.41, 8.396, 8.382, + 8.368, 8.354, 8.34, 8.326, 8.312, 8.298, + 8.284, 8.27, -1.952, -1.948, -1.944, -1.94, + -1.936, -1.932, -1.928, -1.924, -1.92, -1.916, + -1.912, -1.908, -1.904, -1.9, -1.896, -1.892, + -1.888, -1.884, -1.88, -1.876, -1.872, -1.868, + -1.863999, -1.86, -1.856, -1.852, -1.848, -1.844, + -1.84, -1.836, -1.832, -1.828, -1.823999, -1.82, + -1.816, -1.812, -1.808, -1.804, -1.8, -1.796, + -1.792, -1.788, -1.784, -1.78, -1.776, -1.771999, + -1.768, -1.764, 6.112, 6.102, 6.092, 6.082, + 6.072, 6.062, 6.052, 6.042, 6.032, 6.022, + 6.012, 6.002, 5.992, 5.982, 5.972, 5.962}, + sd::DataType::FLOAT32); + + NDArray expGradW('c', {iC, oC, kD, kH, kW}, + {-73678.695312, -59907.972656, -67739.515625, -54962.082031, + -15966.075195, -17115.042969, -15269.777344, -16101.275391, + 41746.566406, 25677.917969, 37200.003906, 22759.517578}, + sd::DataType::FLOAT32); + NDArray expGradB('c', {oC}, {-1803.520020, -1639.679932}, + sd::DataType::FLOAT32); + + input.linspace(100., -0.5); + gradO.linspace(-16, 0.02); + + sd::ops::deconv3d_bp op; + auto results = op.evaluate({&input, &weights, &bias, &gradO}, + {kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, + paddingMode, dataFormat, wFormat}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto gradI = results.at(0); + auto gradW = results.at(1); + auto gradB = results.at(2); + + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); + + ASSERT_TRUE(expGradW.isSameShape(gradW)); + ASSERT_TRUE(expGradW.equalsTo(gradW)); + + ASSERT_TRUE(expGradB.isSameShape(gradB)); + ASSERT_TRUE(expGradB.equalsTo(gradB)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests2, deconv3d_bp_test6) { - - int bS=2, iD=4,iH=4,iW=4, iC=3,oC=2, kD=2,kH=1,kW=1, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; - int oD=5,oH=4,oW=4; - int paddingMode = 0; // 1-SAME, 0-VALID; - int dataFormat = 1; // 1-NHWC, 0-NCHW - int wFormat = 2; // 0 - [kD, kH, kW, oC, iC], 1 - [iC, oC, kD, kH, kW], 2 - [iC, kD, kH, kW, oC] - - NDArray input('c', {bS, iD, iH, iW, iC}, sd::DataType::FLOAT32); - NDArray bias('c', {oC}, {-0.1, 0.2}, sd::DataType::FLOAT32); - NDArray weights('c',{iC, kD, kH, kW, oC}, {-0.6, -0.3, 0., 0.3, -0.5, -0.2, 0.1, 0.4, -0.4, -0.1, 0.2, 0.5}, sd::DataType::FLOAT32); - NDArray gradO('c', {bS, oD, oH, oW, oC}, sd::DataType::FLOAT32); - - NDArray expGradI('c', {bS, iD, iH, iW, iC}, {1.056, 0.482, -0.092, 1.044, 0.478, -0.088, 1.032, 0.474, -0.084, 1.02, 0.47, -0.08, 1.008, 0.466, -0.076, 0.996, - 0.462, -0.072, 0.984, 0.458, -0.068, 0.972, 0.454, -0.064, 0.96, 0.45, -0.06, 0.948, 0.446, -0.056, 0.936, 0.442, -0.052, 0.924, 0.438, -0.048, 0.912, 0.434, - -0.044, 0.9, 0.43, -0.04, 0.888, 0.426, -0.036, 0.876, 0.422, -0.032, 0.864, 0.418, -0.028, 0.852, 0.414, -0.024, 0.84, 0.41, -0.02, 0.828, 0.406, -0.016, - 0.816, 0.402, -0.012, 0.804, 0.398, -0.008, 0.792, 0.394, -0.004, 0.78, 0.39, 0., 0.768, 0.386, 0.004, 0.756, 0.382, 0.008, 0.744, 0.378, 0.012, 0.732, 0.374, - 0.016, 0.72, 0.37, 0.02, 0.708, 0.366, 0.024, 0.696, 0.362, 0.028, 0.684, 0.358, 0.032, 0.672, 0.354, 0.036, 0.66, 0.35, 0.04, 0.648, 0.346, 0.044, 0.636, 0.342, 0.048, 0.624, 0.338, 0.052, 0.612, 0.334, 0.056, 0.6, 0.33, 0.06, 0.588, 0.326, 0.064, 0.576, 0.322, 0.068, 0.564, 0.318, 0.072, 0.552, 0.314, 0.076, 0.54, 0.31, 0.08, 0.528, 0.306, 0.084, 0.516, 0.302, 0.088, 0.504, 0.298, 0.092, 0.492, 0.294, 0.096, 0.48, 0.29, 0.1, 0.468, 0.286, 0.104, 0.456, 0.282, 0.108, 0.444, 0.278, 0.112, 0.432, 0.274, 0.116, 0.42, 0.27, 0.12, 0.408, 0.266, 0.124, 0.396, 0.262, 0.128, 0.384, 0.258, 0.132, 0.372, 0.254, 0.136, 0.36, 0.25, 0.14, 0.348, 0.246, 0.144, 0.336, 0.242, 0.148, 0.324, 0.238, 0.152, 0.312, 0.234, 0.156, 0.3, 0.23, 0.16, 0.096, 0.162, 0.228, 0.084, 0.158, 0.232, 0.072, 0.154, 0.236, 0.06, 0.15, 0.24, 0.048, 0.146, 0.244, 0.036, 0.142, 0.248, 0.024, 0.138, 0.252, 0.012, 0.134, 0.256, 0., 0.13, 0.26, -0.012, 0.126, 0.264, -0.024, 0.122, 0.268, -0.036, 0.118, 0.272, -0.048, 0.114, 0.276, -0.06, 0.11, 0.28, -0.072, 0.106, 0.284, -0.084, 0.102, 0.288, -0.096, 0.098, 0.292, -0.108, 0.094, 0.296, -0.12, 0.09, 0.3, -0.132, 0.086, 0.304, -0.144, 0.082, 0.308, -0.156, 0.078, 0.312, -0.168, 0.074, 0.316, -0.18, 0.07, 0.32, -0.192, 0.066, 0.324, -0.204, 0.062, 0.328, -0.216, 0.058, 0.332, -0.228, 0.054, 0.336, -0.24, 0.05, 0.34, -0.252, 0.046, 0.344, -0.264, 0.042, 0.348, -0.276, 0.038, 0.352, -0.288, 0.034, 0.356, -0.3, 0.03, 0.36, -0.312, 0.026, 0.364, -0.324, 0.022, 0.368, -0.336, 0.018, 0.372, -0.348, 0.014, 0.376, -0.36, 0.01, 0.38, -0.372, 0.006, 0.384, -0.384, 0.002, 0.388, -0.396, -0.002, 0.392, -0.408, -0.006, 0.396, -0.42, -0.01, 0.4, -0.432, -0.014, 0.404, -0.444, -0.018, 0.408, -0.456, -0.022, 0.412, -0.468, -0.026, 0.416, -0.48, -0.03, 0.42, -0.492, -0.034, 0.424, -0.504, -0.038, 0.428, -0.516, -0.042, 0.432, -0.528, -0.046, 0.436, -0.54, -0.05, 0.44, -0.552, -0.054, 0.444, -0.564, -0.058, 0.448, -0.576, -0.062, 0.452, -0.588, -0.066, 0.456, -0.6, -0.07, 0.46, -0.612, -0.074, 0.464, -0.624, -0.078, 0.468, -0.636, -0.082, 0.472, -0.648, -0.086, 0.476, -0.66, -0.09, 0.48}, sd::DataType::FLOAT32); - - NDArray expGradW('c', {iC, kD, kH, kW, oC}, {-6328.958984, -6322.880371, -6134.400879, -6128.319824, -6318.079590, -6312.640137, -6144.000000, -6138.560547, -6307.202637, -6302.399414, -6153.599609, -6148.799316}, sd::DataType::FLOAT32); - NDArray expGradB('c', {oC}, {-1.599994, 0.000001}, sd::DataType::FLOAT32); - - input.linspace(100., -0.5); - gradO.linspace(-1.6, 0.01); - - sd::ops::deconv3d_bp op; - auto results = op.evaluate({&input, &weights, &bias, &gradO}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat, wFormat}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto gradI = results.at(0); - auto gradW = results.at(1); - auto gradB = results.at(2); - - ASSERT_TRUE(expGradI.isSameShape(gradI)); - ASSERT_TRUE(expGradI.equalsTo(gradI)); - - ASSERT_TRUE(expGradW.isSameShape(gradW)); - ASSERT_TRUE(expGradW.equalsTo(gradW)); - - ASSERT_TRUE(expGradB.isSameShape(gradB)); - ASSERT_TRUE(expGradB.equalsTo(gradB)); + int bS = 2, iD = 4, iH = 4, iW = 4, iC = 3, oC = 2, kD = 2, kH = 1, kW = 1, + sD = 1, sH = 1, sW = 1, pD = 0, pH = 0, pW = 0, dD = 1, dH = 1, dW = 1; + int oD = 5, oH = 4, oW = 4; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NHWC, 0-NCHW + int wFormat = 2; // 0 - [kD, kH, kW, oC, iC], 1 - [iC, oC, kD, kH, kW], 2 - + // [iC, kD, kH, kW, oC] + + NDArray input('c', {bS, iD, iH, iW, iC}, sd::DataType::FLOAT32); + NDArray bias('c', {oC}, {-0.1, 0.2}, sd::DataType::FLOAT32); + NDArray weights( + 'c', {iC, kD, kH, kW, oC}, + {-0.6, -0.3, 0., 0.3, -0.5, -0.2, 0.1, 0.4, -0.4, -0.1, 0.2, 0.5}, + sd::DataType::FLOAT32); + NDArray gradO('c', {bS, oD, oH, oW, oC}, sd::DataType::FLOAT32); + + NDArray expGradI( + 'c', {bS, iD, iH, iW, iC}, + {1.056, 0.482, -0.092, 1.044, 0.478, -0.088, 1.032, 0.474, -0.084, + 1.02, 0.47, -0.08, 1.008, 0.466, -0.076, 0.996, 0.462, -0.072, + 0.984, 0.458, -0.068, 0.972, 0.454, -0.064, 0.96, 0.45, -0.06, + 0.948, 0.446, -0.056, 0.936, 0.442, -0.052, 0.924, 0.438, -0.048, + 0.912, 0.434, -0.044, 0.9, 0.43, -0.04, 0.888, 0.426, -0.036, + 0.876, 0.422, -0.032, 0.864, 0.418, -0.028, 0.852, 0.414, -0.024, + 0.84, 0.41, -0.02, 0.828, 0.406, -0.016, 0.816, 0.402, -0.012, + 0.804, 0.398, -0.008, 0.792, 0.394, -0.004, 0.78, 0.39, 0., + 0.768, 0.386, 0.004, 0.756, 0.382, 0.008, 0.744, 0.378, 0.012, + 0.732, 0.374, 0.016, 0.72, 0.37, 0.02, 0.708, 0.366, 0.024, + 0.696, 0.362, 0.028, 0.684, 0.358, 0.032, 0.672, 0.354, 0.036, + 0.66, 0.35, 0.04, 0.648, 0.346, 0.044, 0.636, 0.342, 0.048, + 0.624, 0.338, 0.052, 0.612, 0.334, 0.056, 0.6, 0.33, 0.06, + 0.588, 0.326, 0.064, 0.576, 0.322, 0.068, 0.564, 0.318, 0.072, + 0.552, 0.314, 0.076, 0.54, 0.31, 0.08, 0.528, 0.306, 0.084, + 0.516, 0.302, 0.088, 0.504, 0.298, 0.092, 0.492, 0.294, 0.096, + 0.48, 0.29, 0.1, 0.468, 0.286, 0.104, 0.456, 0.282, 0.108, + 0.444, 0.278, 0.112, 0.432, 0.274, 0.116, 0.42, 0.27, 0.12, + 0.408, 0.266, 0.124, 0.396, 0.262, 0.128, 0.384, 0.258, 0.132, + 0.372, 0.254, 0.136, 0.36, 0.25, 0.14, 0.348, 0.246, 0.144, + 0.336, 0.242, 0.148, 0.324, 0.238, 0.152, 0.312, 0.234, 0.156, + 0.3, 0.23, 0.16, 0.096, 0.162, 0.228, 0.084, 0.158, 0.232, + 0.072, 0.154, 0.236, 0.06, 0.15, 0.24, 0.048, 0.146, 0.244, + 0.036, 0.142, 0.248, 0.024, 0.138, 0.252, 0.012, 0.134, 0.256, + 0., 0.13, 0.26, -0.012, 0.126, 0.264, -0.024, 0.122, 0.268, + -0.036, 0.118, 0.272, -0.048, 0.114, 0.276, -0.06, 0.11, 0.28, + -0.072, 0.106, 0.284, -0.084, 0.102, 0.288, -0.096, 0.098, 0.292, + -0.108, 0.094, 0.296, -0.12, 0.09, 0.3, -0.132, 0.086, 0.304, + -0.144, 0.082, 0.308, -0.156, 0.078, 0.312, -0.168, 0.074, 0.316, + -0.18, 0.07, 0.32, -0.192, 0.066, 0.324, -0.204, 0.062, 0.328, + -0.216, 0.058, 0.332, -0.228, 0.054, 0.336, -0.24, 0.05, 0.34, + -0.252, 0.046, 0.344, -0.264, 0.042, 0.348, -0.276, 0.038, 0.352, + -0.288, 0.034, 0.356, -0.3, 0.03, 0.36, -0.312, 0.026, 0.364, + -0.324, 0.022, 0.368, -0.336, 0.018, 0.372, -0.348, 0.014, 0.376, + -0.36, 0.01, 0.38, -0.372, 0.006, 0.384, -0.384, 0.002, 0.388, + -0.396, -0.002, 0.392, -0.408, -0.006, 0.396, -0.42, -0.01, 0.4, + -0.432, -0.014, 0.404, -0.444, -0.018, 0.408, -0.456, -0.022, 0.412, + -0.468, -0.026, 0.416, -0.48, -0.03, 0.42, -0.492, -0.034, 0.424, + -0.504, -0.038, 0.428, -0.516, -0.042, 0.432, -0.528, -0.046, 0.436, + -0.54, -0.05, 0.44, -0.552, -0.054, 0.444, -0.564, -0.058, 0.448, + -0.576, -0.062, 0.452, -0.588, -0.066, 0.456, -0.6, -0.07, 0.46, + -0.612, -0.074, 0.464, -0.624, -0.078, 0.468, -0.636, -0.082, 0.472, + -0.648, -0.086, 0.476, -0.66, -0.09, 0.48}, + sd::DataType::FLOAT32); + + NDArray expGradW('c', {iC, kD, kH, kW, oC}, + {-6328.958984, -6322.880371, -6134.400879, -6128.319824, + -6318.079590, -6312.640137, -6144.000000, -6138.560547, + -6307.202637, -6302.399414, -6153.599609, -6148.799316}, + sd::DataType::FLOAT32); + NDArray expGradB('c', {oC}, {-1.599994, 0.000001}, sd::DataType::FLOAT32); + + input.linspace(100., -0.5); + gradO.linspace(-1.6, 0.01); + + sd::ops::deconv3d_bp op; + auto results = op.evaluate({&input, &weights, &bias, &gradO}, + {kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, + paddingMode, dataFormat, wFormat}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto gradI = results.at(0); + auto gradW = results.at(1); + auto gradB = results.at(2); + + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); + + ASSERT_TRUE(expGradW.isSameShape(gradW)); + ASSERT_TRUE(expGradW.equalsTo(gradW)); + + ASSERT_TRUE(expGradB.isSameShape(gradB)); + ASSERT_TRUE(expGradB.equalsTo(gradB)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests2, maxpool2d_1) { - - auto x = NDArrayFactory::create_('c', {bS,iD,iH,iW}); - auto exp = NDArrayFactory::create('c',{bS,iD,oH,oW}); - // auto z('c',{bS,iD,oH,oW}); - - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, x); - // variableSpace->putVariable(1, &z); - - auto block = new Context(1, variableSpace, false); - block->fillInputs({-1}); - std::vector* argI = block->getIArguments(); - *argI = {kH,kW, sH,sW, pH,pW, dH,dW, 0}; // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; - - sd::ops::maxpool2d pooling; - Nd4jStatus status = pooling.execute(block); - ASSERT_EQ(ND4J_STATUS_OK, status); - - auto result = variableSpace->getVariable(block->getNodeId())->getNDArray(); - // result.printShapeInfo(); - ASSERT_TRUE(exp.isSameShape(result)); - - delete variableSpace; - delete block; + auto x = NDArrayFactory::create('c', {bS, iD, iH, iW}); + auto exp = NDArrayFactory::create('c', {bS, iD, oH, oW}); + // auto z('c',{bS,iD,oH,oW}); + + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + // variableSpace->putVariable(1, &z); + + auto block = new Context(1, variableSpace, false); + block->fillInputs({-1}); + block->appendI( + {kH, kW, sH, sW, pH, pW, dH, dW, + 0}); // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad + // Height/Width; 6,7 - dilation Height/Width; 8 - same mode; + + sd::ops::maxpool2d pooling; + Nd4jStatus status = pooling.execute(block); + ASSERT_EQ(ND4J_STATUS_OK, status); + + auto result = variableSpace->getVariable(block->getNodeId())->getNDArray(); + // result.printShapeInfo(); + ASSERT_TRUE(exp.isSameShape(*result)); + + delete variableSpace; + delete block; } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests2, maxpool2d_2) { - - const int bS = 2; - const int iD = 1; - const int iH = 28; - const int iW = 28; - const int kH = 5; - const int kW = 5; - const int sH = 1; - const int sW = 1; - const int pH = 0; - const int pW = 0; - const int dH = 1; - const int dW = 1; - const int oH = (iH - kH - (kH-1)*(dH-1) + 2*pH)/sH + 1; // output height - const int oW = (iW - kW - (kW-1)*(dW-1) + 2*pW)/sW + 1; // output width - - - auto x = NDArrayFactory::create_('c', {bS,iD,iH,iW}); - auto exp = NDArrayFactory::create('c',{bS,iD,oH,oW}); - // auto z('c',{bS,iD,oH,oW}); - - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, x); - // variableSpace->putVariable(1, &z); - - auto block = new Context(1, variableSpace, false); - block->fillInputs({-1}); - std::vector* argI = block->getIArguments(); - *argI = {kH,kW, sH,sW, pH,pW, dH,dW, 0}; // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; - - sd::ops::maxpool2d pooling; - Nd4jStatus status = pooling.execute(block); - ASSERT_EQ(ND4J_STATUS_OK, status); - - auto result = variableSpace->getVariable(block->getNodeId())->getNDArray(); - // result.printShapeInfo(); - ASSERT_TRUE(exp.isSameShape(result)); - - delete variableSpace; - delete block; + const int bS = 2; + const int iD = 1; + const int iH = 28; + const int iW = 28; + const int kH = 5; + const int kW = 5; + const int sH = 1; + const int sW = 1; + const int pH = 0; + const int pW = 0; + const int dH = 1; + const int dW = 1; + const int oH = + (iH - kH - (kH - 1) * (dH - 1) + 2 * pH) / sH + 1; // output height + const int oW = + (iW - kW - (kW - 1) * (dW - 1) + 2 * pW) / sW + 1; // output width + + auto x = NDArrayFactory::create('c', {bS, iD, iH, iW}); + auto exp = NDArrayFactory::create('c', {bS, iD, oH, oW}); + // auto z('c',{bS,iD,oH,oW}); + + VariableSpace variableSpace; + variableSpace.putVariable(-1, x); + // variableSpace->putVariable(1, &z); + + Context block(1, &variableSpace, false); + block.setName("alpha"); + block.fillInputs({-1}); + block.appendI( + {kH, kW, sH, sW, pH, pW, dH, dW, + 0}); // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad + // Height/Width; 6,7 - dilation Height/Width; 8 - same mode; + + sd::ops::maxpool2d pooling; + Nd4jStatus status = pooling.execute(&block); + ASSERT_EQ(ND4J_STATUS_OK, status); + + ASSERT_TRUE(variableSpace.hasVariable(block.nodeId(), 0)); + ASSERT_TRUE(variableSpace.hasVariable("alpha")); + ASSERT_TRUE(variableSpace.hasVariable(block.nodeId())); + auto result = variableSpace.getVariable(block.nodeId())->getNDArray(); + // result.printShapeInfo(); + ASSERT_TRUE(exp.isSameShape(*result)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests2, maxpool2d_3) { - - const int bS = 2; - const int iD = 1; - const int iH = 28; - const int iW = 28; - const int kH = 5; - const int kW = 5; - const int sH = 1; - const int sW = 1; - const int pH = 0; - const int pW = 0; - const int dH = 1; - const int dW = 1; - const int oH = (int) sd::math::nd4j_ceil(iH * 1.f / sH); - const int oW = (int) sd::math::nd4j_ceil(iW * 1.f / sW); - - - auto x = NDArrayFactory::create_('c', {bS,iD,iH,iW}); - auto exp = NDArrayFactory::create('c',{bS,iD,oH,oW}); - // auto z('c',{bS,iD,oH,oW}); - - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, x); - // variableSpace->putVariable(1, &z); - - auto block = new Context(1, variableSpace, false); - block->fillInputs({-1}); - std::vector* argI = block->getIArguments(); - *argI = {kH,kW, sH,sW, pH,pW, dH,dW, 1}; // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; - - sd::ops::maxpool2d pooling; - Nd4jStatus status = pooling.execute(block); - ASSERT_EQ(ND4J_STATUS_OK, status); - - auto result = variableSpace->getVariable(block->getNodeId())->getNDArray(); - // result.printShapeInfo(); - ASSERT_TRUE(exp.isSameShape(result)); - - delete variableSpace; - delete block; + const int bS = 2; + const int iD = 1; + const int iH = 28; + const int iW = 28; + const int kH = 5; + const int kW = 5; + const int sH = 1; + const int sW = 1; + const int pH = 0; + const int pW = 0; + const int dH = 1; + const int dW = 1; + const int oH = (int)sd::math::nd4j_ceil(iH * 1.f / sH); + const int oW = (int)sd::math::nd4j_ceil(iW * 1.f / sW); + + auto x = NDArrayFactory::create('c', {bS, iD, iH, iW}); + auto exp = NDArrayFactory::create('c', {bS, iD, oH, oW}); + // auto z('c',{bS,iD,oH,oW}); + + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + // variableSpace->putVariable(1, &z); + + auto block = new Context(1, variableSpace, false); + block->fillInputs({-1}); + block->appendI( + {kH, kW, sH, sW, pH, pW, dH, dW, + 1}); // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad + // Height/Width; 6,7 - dilation Height/Width; 8 - same mode; + + sd::ops::maxpool2d pooling; + Nd4jStatus status = pooling.execute(block); + ASSERT_EQ(ND4J_STATUS_OK, status); + + auto result = variableSpace->getVariable(block->getNodeId())->getNDArray(); + // result.printShapeInfo(); + ASSERT_TRUE(exp.isSameShape(*result)); + + delete variableSpace; + delete block; } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests2, maxpool2d_4) { - - const int bS = 2; - const int iD = 1; - const int iH = 24; - const int iW = 24; - const int kH = 3; - const int kW = 3; - const int sH = 1; - const int sW = 1; - const int pH = 0; - const int pW = 0; - const int dH = 1; - const int dW = 1; - const int oH = (iH - kH - (kH-1)*(dH-1) + 2*pH)/sH + 1; // output height - const int oW = (iW - kW - (kW-1)*(dW-1) + 2*pW)/sW + 1; // output width - - - auto x = NDArrayFactory::create_('c', {bS,iD,iH,iW}); - auto exp = NDArrayFactory::create('c',{bS,iD,oH,oW}); - // auto z('c',{bS,iD,oH,oW}); - - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, x); - // variableSpace->putVariable(1, &z); - - auto block = new Context(1, variableSpace, false); - block->fillInputs({-1}); - std::vector* argI = block->getIArguments(); - *argI = {kH,kW, sH,sW, pH,pW, dH,dW, 0}; // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; - - sd::ops::maxpool2d pooling; - Nd4jStatus status = pooling.execute(block); - ASSERT_EQ(ND4J_STATUS_OK, status); - - auto result = variableSpace->getVariable(block->getNodeId())->getNDArray(); - // result.printShapeInfo(); - ASSERT_TRUE(exp.isSameShape(result)); - - delete variableSpace; - delete block; + const int bS = 2; + const int iD = 1; + const int iH = 24; + const int iW = 24; + const int kH = 3; + const int kW = 3; + const int sH = 1; + const int sW = 1; + const int pH = 0; + const int pW = 0; + const int dH = 1; + const int dW = 1; + const int oH = + (iH - kH - (kH - 1) * (dH - 1) + 2 * pH) / sH + 1; // output height + const int oW = + (iW - kW - (kW - 1) * (dW - 1) + 2 * pW) / sW + 1; // output width + + auto x = NDArrayFactory::create('c', {bS, iD, iH, iW}); + auto exp = NDArrayFactory::create('c', {bS, iD, oH, oW}); + // auto z('c',{bS,iD,oH,oW}); + + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + // variableSpace->putVariable(1, &z); + + auto block = new Context(1, variableSpace, false); + block->fillInputs({-1}); + block->appendI( + {kH, kW, sH, sW, pH, pW, dH, dW, + 0}); // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad + // Height/Width; 6,7 - dilation Height/Width; 8 - same mode; + + sd::ops::maxpool2d pooling; + Nd4jStatus status = pooling.execute(block); + ASSERT_EQ(ND4J_STATUS_OK, status); + + auto result = variableSpace->getVariable(block->getNodeId())->getNDArray(); + // result.printShapeInfo(); + ASSERT_TRUE(exp.isSameShape(*result)); + + delete variableSpace; + delete block; } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests2, maxpool2d_5) { - - const int bS = 2; - const int iD = 1; - const int iH = 24; - const int iW = 24; - const int kH = 3; - const int kW = 3; - const int sH = 1; - const int sW = 1; - const int pH = 0; - const int pW = 0; - const int dH = 1; - const int dW = 1; - const int oH = (int) sd::math::nd4j_ceil(iH * 1.f / sH); - const int oW = (int) sd::math::nd4j_ceil(iW * 1.f / sW); - - - auto x = NDArrayFactory::create_('c', {bS,iD,iH,iW}); - auto exp = NDArrayFactory::create('c',{bS,iD,oH,oW}); - // auto z('c',{bS,iD,oH,oW}); - - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, x); - // variableSpace->putVariable(1, &z); - - auto block = new Context(1, variableSpace, false); - block->fillInputs({-1}); - std::vector* argI = block->getIArguments(); - *argI = {kH,kW, sH,sW, pH,pW, dH,dW, 1}; // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; - - sd::ops::maxpool2d pooling; - Nd4jStatus status = pooling.execute(block); - ASSERT_EQ(ND4J_STATUS_OK, status); - - auto result = variableSpace->getVariable(block->getNodeId())->getNDArray(); - // result.printShapeInfo(); - ASSERT_TRUE(exp.isSameShape(result)); - - delete variableSpace; - delete block; + const int bS = 2; + const int iD = 1; + const int iH = 24; + const int iW = 24; + const int kH = 3; + const int kW = 3; + const int sH = 1; + const int sW = 1; + const int pH = 0; + const int pW = 0; + const int dH = 1; + const int dW = 1; + const int oH = (int)sd::math::nd4j_ceil(iH * 1.f / sH); + const int oW = (int)sd::math::nd4j_ceil(iW * 1.f / sW); + + auto x = NDArrayFactory::create('c', {bS, iD, iH, iW}); + auto exp = NDArrayFactory::create('c', {bS, iD, oH, oW}); + // auto z('c',{bS,iD,oH,oW}); + + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + // variableSpace->putVariable(1, &z); + + auto block = new Context(1, variableSpace, false); + block->fillInputs({-1}); + block->appendI( + {kH, kW, sH, sW, pH, pW, dH, dW, + 1}); // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad + // Height/Width; 6,7 - dilation Height/Width; 8 - same mode; + + sd::ops::maxpool2d pooling; + Nd4jStatus status = pooling.execute(block); + ASSERT_EQ(ND4J_STATUS_OK, status); + + auto result = variableSpace->getVariable(block->getNodeId())->getNDArray(); + // result.printShapeInfo(); + ASSERT_TRUE(exp.isSameShape(*result)); + + delete variableSpace; + delete block; } ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests2, maxpool2d_6) { - auto x = NDArrayFactory::create('c', {2, 4, 4, 2}); - auto exp = NDArrayFactory::create('c', {2, 2, 2, 2}, {11.f, 12.f, 15.f, 16.f, 27.f, 28.f, 31.f, 32.f, 43.f, 44.f, 47.f, 48.f, 59.f, 60.f, 63.f, 64.f}); - - x.linspace(1); + auto x = NDArrayFactory::create('c', {2, 4, 4, 2}); + auto exp = NDArrayFactory::create( + 'c', {2, 2, 2, 2}, + {11.f, 12.f, 15.f, 16.f, 27.f, 28.f, 31.f, 32.f, 43.f, 44.f, 47.f, 48.f, + 59.f, 60.f, 63.f, 64.f}); - sd::ops::maxpool2d op; - auto result = op.evaluate({&x}, {}, {2, 2, 2, 2, 0, 0, 1, 1, 1, 1, 1}); + x.linspace(1); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::maxpool2d op; + auto result = op.evaluate({&x}, {}, {2, 2, 2, 2, 0, 0, 1, 1, 1, 1, 1}); - auto z = result.at(0); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests2, maxpool2d_7) { - auto x = NDArrayFactory::create('c', {2, 4, 4, 2}); - auto exp = NDArrayFactory::create('c', {2, 2, 2, 2}, {11.f, 12.f, 15.f, 16.f, 27.f, 28.f, 31.f, 32.f, 43.f, 44.f, 47.f, 48.f, 59.f, 60.f, 63.f, 64.f}); + auto x = NDArrayFactory::create('c', {2, 4, 4, 2}); + auto exp = NDArrayFactory::create( + 'c', {2, 2, 2, 2}, + {11.f, 12.f, 15.f, 16.f, 27.f, 28.f, 31.f, 32.f, 43.f, 44.f, 47.f, 48.f, + 59.f, 60.f, 63.f, 64.f}); - x.linspace(1); + x.linspace(1); - sd::ops::maxpool2d op; - auto result = op.evaluate({&x}, {}, {2, 2, 2, 2, 0, 0, 1, 1, 0, 1, 1}); + sd::ops::maxpool2d op; + auto result = op.evaluate({&x}, {}, {2, 2, 2, 2, 0, 0, 1, 1, 0, 1, 1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests2, maxpool2d_8) { - auto x = NDArrayFactory::create('c', {2, 2, 5, 5}); - auto exp = NDArrayFactory::create('c', {2, 2, 2, 2}, {7.f, 9.f, 17.f, 19.f, 32.f, 34.f, 42.f, 44.f, 57.f, 59.f, 67.f, 69.f, 82.f, 84.f, 92.f, 94.f}); + auto x = NDArrayFactory::create('c', {2, 2, 5, 5}); + auto exp = NDArrayFactory::create( + 'c', {2, 2, 2, 2}, + {7.f, 9.f, 17.f, 19.f, 32.f, 34.f, 42.f, 44.f, 57.f, 59.f, 67.f, 69.f, + 82.f, 84.f, 92.f, 94.f}); - x.linspace(1); + x.linspace(1); - sd::ops::maxpool2d op; - auto result = op.evaluate({&x}, {}, {2, 2, 2, 2, 0, 0, 1, 1, 0, 1, 0}); + sd::ops::maxpool2d op; + auto result = op.evaluate({&x}, {}, {2, 2, 2, 2, 0, 0, 1, 1, 0, 1, 0}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests2, maxpool2d_9) { + int bS = 3; // batch size (number of samples) + int iC = 3; // input channels + int iH = 28, iW = 28; // input height/width + int kH = 2, kW = 2; // kernel (filter) height/width + int sH = 1, sW = 1; // stride height/width + int pH = 0, pW = 0; // padding height/width + int dH = 1, dW = 1; // dilation height/width - int bS = 3; // batch size (number of samples) - int iC = 3; // input channels - int iH = 28, iW = 28; // input height/width - int kH = 2, kW = 2; // kernel (filter) height/width - int sH = 1, sW = 1; // stride height/width - int pH = 0, pW = 0; // padding height/width - int dH = 1, dW = 1; // dilation height/width - - int oH = 27, oW = 27; // output height/width + int oH = 27, oW = 27; // output height/width - int isSameMode = 0; // 1-SAME, 0-VALID + int isSameMode = 0; // 1-SAME, 0-VALID - auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); + auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); - sd::ops::maxpool2d op; - auto results = op.evaluate({&input}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, 1, 0}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(output->isSameShape({bS, iC, oH, oW})); + sd::ops::maxpool2d op; + auto results = op.evaluate( + {&input}, {}, {kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, 1, 0}); + auto output = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(output.isSameShape({bS, iC, oH, oW})); } ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests2, maxpool2d_10) { - - int bS=1, iH=4,iW=4, iC=3, kH=2,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oH=3,oW=3; - int paddingMode = 0; // 1-SAME, 0-VALID; - int dataFormat = 0; // 1-NHWC, 0-NCHW - - auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}, {0.27620894f, 0.21801452f, 0.062078513f, 7.348895E-4f, 0.24149609f, 0.4948205f, 0.93483436f, 0.52035654f, 0.30292067f, 0.3289706f, 0.7977864f, - 0.03180518f, 0.1455722f, 0.90352905f, 0.9405744f, 0.0048329555f, 0.44062102f, 0.111197524f, 0.31742015f, 0.1933705f, 0.23825112f, 0.35076278f, 0.7135856f, 0.28229436f, 0.18310733f, - 0.9613717f, 0.56823575f, 0.78289545f, 0.62195826f, 0.5244586f, 0.5040889f, 0.025349546f, 0.41400263f, 0.28420195f, 0.8536445f, 0.3044107f, 0.7997134f, 0.45762005f, 0.7653578f, - 0.07198584f, 0.5304998f, 0.7334402f, 0.85019743f, 0.031957153f, 0.37088063f, 0.85722464f, 0.06376881f, 0.39791203f}); - - auto expOutput = NDArrayFactory::create('c', {bS, iC, oH, oW}, {0.4948205f, 0.93483436f, 0.93483436f, 0.4948205f, 0.93483436f, 0.93483436f, 0.90352905f, 0.9405744f, 0.9405744f, 0.44062102f, 0.7135856f, - 0.7135856f, 0.9613717f, 0.9613717f, 0.78289545f, 0.9613717f, 0.9613717f, 0.78289545f, 0.7997134f, 0.8536445f, 0.8536445f, 0.7997134f, 0.85019743f, 0.85019743f, - 0.85722464f, 0.85722464f, 0.85019743f}); - - sd::ops::maxpool2d op; - auto results = op.evaluate({&input}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode}); - auto* output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expOutput.isSameShape(output)); - ASSERT_TRUE(expOutput.equalsTo(output)); + int bS = 1, iH = 4, iW = 4, iC = 3, kH = 2, kW = 2, sH = 1, sW = 1, pH = 0, + pW = 0, dH = 1, dW = 1; + int oH = 3, oW = 3; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NHWC, 0-NCHW + + auto input = NDArrayFactory::create( + 'c', {bS, iC, iH, iW}, + {0.27620894f, 0.21801452f, 0.062078513f, 7.348895E-4f, 0.24149609f, + 0.4948205f, 0.93483436f, 0.52035654f, 0.30292067f, 0.3289706f, + 0.7977864f, 0.03180518f, 0.1455722f, 0.90352905f, 0.9405744f, + 0.0048329555f, 0.44062102f, 0.111197524f, 0.31742015f, 0.1933705f, + 0.23825112f, 0.35076278f, 0.7135856f, 0.28229436f, 0.18310733f, + 0.9613717f, 0.56823575f, 0.78289545f, 0.62195826f, 0.5244586f, + 0.5040889f, 0.025349546f, 0.41400263f, 0.28420195f, 0.8536445f, + 0.3044107f, 0.7997134f, 0.45762005f, 0.7653578f, 0.07198584f, + 0.5304998f, 0.7334402f, 0.85019743f, 0.031957153f, 0.37088063f, + 0.85722464f, 0.06376881f, 0.39791203f}); + + auto expOutput = NDArrayFactory::create( + 'c', {bS, iC, oH, oW}, + {0.4948205f, 0.93483436f, 0.93483436f, 0.4948205f, 0.93483436f, + 0.93483436f, 0.90352905f, 0.9405744f, 0.9405744f, 0.44062102f, + 0.7135856f, 0.7135856f, 0.9613717f, 0.9613717f, 0.78289545f, + 0.9613717f, 0.9613717f, 0.78289545f, 0.7997134f, 0.8536445f, + 0.8536445f, 0.7997134f, 0.85019743f, 0.85019743f, 0.85722464f, + 0.85722464f, 0.85019743f}); + + sd::ops::maxpool2d op; + auto results = + op.evaluate({&input}, {}, {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests2, maxpool2d_11) { + NDArray input('c', {1, 1, 4, 5}, sd::DataType::FLOAT32); + NDArray z('c', {1, 1, 4, 5}, sd::DataType::FLOAT32); - NDArray input('c', {1,1,4,5}, sd::DataType::FLOAT32); - NDArray z('c', {1,1,4,5}, sd::DataType::FLOAT32); - - input.linspace(1.); + input.linspace(1.); - sd::ops::maxpool2d op; - auto results = op.evaluate({&input}, {}, {2,2, 1,1, 1,1, 2,2, 1,0,0}); - - ASSERT_EQ(Status::OK(), results.status()); + sd::ops::maxpool2d op; + auto results = op.evaluate({&input}, {}, {2, 2, 1, 1, 1, 1, 2, 2, 1, 0, 0}); + ASSERT_EQ(Status::OK(), results.status()); } ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests2, avgpool3d_test1) { - - int bS=2, iD=3,iH=4,iW=3, iC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; - int oD=2,oH=2,oW=2; - int paddingMode = 0; // 1-SAME, 0-VALID - int dataFormat = 0; // 1-NDHWC, 0-NCDHW - - auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); - auto expected = NDArrayFactory::create('c', {bS, iC, oD, oH, oW}, {10.5f, 11.5f, 13.5f, 14.5f, 22.5f, 23.5f, 25.5f, 26.5f, 46.5f, 47.5f, 49.5f, 50.5f, 58.5f, 59.5f, 61.5f, 62.5f, - 82.5f, 83.5f, 85.5f, 86.5f, 94.5f, 95.5f, 97.5f, 98.5f,118.5f,119.5f,121.5f,122.5f,130.5f,131.5f,133.5f,134.5f, - 154.5f,155.5f,157.5f,158.5f,166.5f,167.5f,169.5f,170.5f,190.5f,191.5f,193.5f,194.5f,202.5f,203.5f,205.5f,206.5f}); - input.linspace(1.); - - sd::ops::avgpool3dnew op; - auto results = op.evaluate({&input}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 1, dataFormat}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - + int bS = 2, iD = 3, iH = 4, iW = 3, iC = 3, kD = 2, kH = 3, kW = 2, sD = 1, + sH = 1, sW = 1, pD = 0, pH = 0, pW = 0, dD = 1, dH = 1, dW = 1; + int oD = 2, oH = 2, oW = 2; + int paddingMode = 0; // 1-SAME, 0-VALID + int dataFormat = 0; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); + auto expected = NDArrayFactory::create( + 'c', {bS, iC, oD, oH, oW}, + {10.5f, 11.5f, 13.5f, 14.5f, 22.5f, 23.5f, 25.5f, 26.5f, + 46.5f, 47.5f, 49.5f, 50.5f, 58.5f, 59.5f, 61.5f, 62.5f, + 82.5f, 83.5f, 85.5f, 86.5f, 94.5f, 95.5f, 97.5f, 98.5f, + 118.5f, 119.5f, 121.5f, 122.5f, 130.5f, 131.5f, 133.5f, 134.5f, + 154.5f, 155.5f, 157.5f, 158.5f, 166.5f, 167.5f, 169.5f, 170.5f, + 190.5f, 191.5f, 193.5f, 194.5f, 202.5f, 203.5f, 205.5f, 206.5f}); + input.linspace(1.); + + sd::ops::avgpool3dnew op; + auto results = op.evaluate({&input}, {}, + {kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, + paddingMode, 1, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests2, avgpool3d_test2) { - - int bS=2, iD=3,iH=4,iW=3, iC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; - int oD=3,oH=4,oW=3; - int paddingMode = 1; // 1-SAME, 0-VALID - int dataFormat = 1; // 1-NDHWC, 0-NCDHW - - auto input = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); - auto expected = NDArrayFactory::create('c', {bS, oD, oH, oW, iC}, { 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 29.5f, 30.5f, 31.5f, 29.5f, 30.5f, 31.5f, 32.5f, 33.5f, 34.5f, 34.f, 35.f, 36.f, 38.5f, 39.5f, 40.5f, 41.5f, 42.5f, 43.5f, 43.f, 44.f, 45.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 47.5f, 48.5f, 49.5f, - 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 65.5f, 66.5f, 67.5f, 65.5f, 66.5f, 67.5f, 68.5f, 69.5f, 70.5f, 70.f, 71.f, 72.f, 74.5f, 75.5f, 76.5f, 77.5f, 78.5f, 79.5f, 79.f, 80.f, 81.f, 79.f, 80.f, 81.f, 82.f, 83.f, 84.f, 83.5f, 84.5f, 85.5f, - 79.f, 80.f, 81.f, 82.f, 83.f, 84.f, 83.5f, 84.5f, 85.5f, 83.5f, 84.5f, 85.5f, 86.5f, 87.5f, 88.5f, 88.f, 89.f, 90.f, 92.5f, 93.5f, 94.5f, 95.5f, 96.5f, 97.5f, 97.f, 98.f, 99.f, 97.f, 98.f, 99.f, 100.f, 101.f, 102.f, 101.5f, 102.5f, 103.5f, - 133.f, 134.f, 135.f, 136.f, 137.f, 138.f, 137.5f, 138.5f, 139.5f, 137.5f, 138.5f, 139.5f, 140.5f, 141.5f, 142.5f, 142.f, 143.f, 144.f, 146.5f, 147.5f, 148.5f, 149.5f, 150.5f, 151.5f, 151.f, 152.f, 153.f, 151.f, 152.f, 153.f, 154.f, 155.f, 156.f, 155.5f, 156.5f, 157.5f, - 169.f, 170.f, 171.f, 172.f, 173.f, 174.f, 173.5f, 174.5f, 175.5f, 173.5f, 174.5f, 175.5f, 176.5f, 177.5f, 178.5f, 178.f, 179.f, 180.f, 182.5f, 183.5f, 184.5f, 185.5f, 186.5f, 187.5f, 187.f, 188.f, 189.f, 187.f, 188.f, 189.f, 190.f, 191.f, 192.f, 191.5f, 192.5f, 193.5f, - 187.f, 188.f, 189.f, 190.f, 191.f, 192.f, 191.5f, 192.5f, 193.5f, 191.5f, 192.5f, 193.5f, 194.5f, 195.5f, 196.5f, 196.f, 197.f, 198.f, 200.5f, 201.5f, 202.5f, 203.5f, 204.5f, 205.5f, 205.f, 206.f, 207.f, 205.f, 206.f, 207.f, 208.f, 209.f, 210.f, 209.5f, 210.5f, 211.5f}); - input.linspace(1.); - - sd::ops::avgpool3dnew op; - auto results = op.evaluate({&input}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 0, dataFormat}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - + int bS = 2, iD = 3, iH = 4, iW = 3, iC = 3, kD = 2, kH = 3, kW = 2, sD = 1, + sH = 1, sW = 1, pD = 0, pH = 0, pW = 0, dD = 1, dH = 1, dW = 1; + int oD = 3, oH = 4, oW = 3; + int paddingMode = 1; // 1-SAME, 0-VALID + int dataFormat = 1; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); + auto expected = NDArrayFactory::create( + 'c', {bS, oD, oH, oW, iC}, + {25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 29.5f, 30.5f, 31.5f, + 29.5f, 30.5f, 31.5f, 32.5f, 33.5f, 34.5f, 34.f, 35.f, 36.f, + 38.5f, 39.5f, 40.5f, 41.5f, 42.5f, 43.5f, 43.f, 44.f, 45.f, + 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 47.5f, 48.5f, 49.5f, + 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 65.5f, 66.5f, 67.5f, + 65.5f, 66.5f, 67.5f, 68.5f, 69.5f, 70.5f, 70.f, 71.f, 72.f, + 74.5f, 75.5f, 76.5f, 77.5f, 78.5f, 79.5f, 79.f, 80.f, 81.f, + 79.f, 80.f, 81.f, 82.f, 83.f, 84.f, 83.5f, 84.5f, 85.5f, + 79.f, 80.f, 81.f, 82.f, 83.f, 84.f, 83.5f, 84.5f, 85.5f, + 83.5f, 84.5f, 85.5f, 86.5f, 87.5f, 88.5f, 88.f, 89.f, 90.f, + 92.5f, 93.5f, 94.5f, 95.5f, 96.5f, 97.5f, 97.f, 98.f, 99.f, + 97.f, 98.f, 99.f, 100.f, 101.f, 102.f, 101.5f, 102.5f, 103.5f, + 133.f, 134.f, 135.f, 136.f, 137.f, 138.f, 137.5f, 138.5f, 139.5f, + 137.5f, 138.5f, 139.5f, 140.5f, 141.5f, 142.5f, 142.f, 143.f, 144.f, + 146.5f, 147.5f, 148.5f, 149.5f, 150.5f, 151.5f, 151.f, 152.f, 153.f, + 151.f, 152.f, 153.f, 154.f, 155.f, 156.f, 155.5f, 156.5f, 157.5f, + 169.f, 170.f, 171.f, 172.f, 173.f, 174.f, 173.5f, 174.5f, 175.5f, + 173.5f, 174.5f, 175.5f, 176.5f, 177.5f, 178.5f, 178.f, 179.f, 180.f, + 182.5f, 183.5f, 184.5f, 185.5f, 186.5f, 187.5f, 187.f, 188.f, 189.f, + 187.f, 188.f, 189.f, 190.f, 191.f, 192.f, 191.5f, 192.5f, 193.5f, + 187.f, 188.f, 189.f, 190.f, 191.f, 192.f, 191.5f, 192.5f, 193.5f, + 191.5f, 192.5f, 193.5f, 194.5f, 195.5f, 196.5f, 196.f, 197.f, 198.f, + 200.5f, 201.5f, 202.5f, 203.5f, 204.5f, 205.5f, 205.f, 206.f, 207.f, + 205.f, 206.f, 207.f, 208.f, 209.f, 210.f, 209.5f, 210.5f, 211.5f}); + input.linspace(1.); + + sd::ops::avgpool3dnew op; + auto results = op.evaluate({&input}, {}, + {kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, + paddingMode, 0, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests2, avgpool3d_test3) { - - int bS=2, iD=3,iH=4,iW=3, iC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; - int oD=2,oH=2,oW=2; - int paddingMode = 0; // 1-SAME, 0-VALID - int dataFormat = 1; // 1-NDHWC, 0-NCDHW - - auto input = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); - auto expected = NDArrayFactory::create('c', {bS, oD, oH, oW, iC}, { 29.5f, 30.5f, 31.5f, 32.5f, 33.5f, 34.5f, 38.5f, 39.5f, 40.5f, 41.5f, 42.5f, 43.5f, 65.5f, 66.5f, 67.5f, 68.5f, 69.5f, 70.5f, - 74.5f, 75.5f, 76.5f, 77.5f, 78.5f, 79.5f, 137.5f, 138.5f, 139.5f, 140.5f, 141.5f, 142.5f, 146.5f, 147.5f, 148.5f, 149.5f, 150.5f, 151.5f, - 173.5f, 174.5f, 175.5f, 176.5f, 177.5f, 178.5f, 182.5f, 183.5f, 184.5f, 185.5f, 186.5f, 187.5f}); - input.linspace(1.); - - sd::ops::avgpool3dnew op; - auto results = op.evaluate({&input}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 1, dataFormat}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - + int bS = 2, iD = 3, iH = 4, iW = 3, iC = 3, kD = 2, kH = 3, kW = 2, sD = 1, + sH = 1, sW = 1, pD = 0, pH = 0, pW = 0, dD = 1, dH = 1, dW = 1; + int oD = 2, oH = 2, oW = 2; + int paddingMode = 0; // 1-SAME, 0-VALID + int dataFormat = 1; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); + auto expected = NDArrayFactory::create( + 'c', {bS, oD, oH, oW, iC}, + {29.5f, 30.5f, 31.5f, 32.5f, 33.5f, 34.5f, 38.5f, 39.5f, + 40.5f, 41.5f, 42.5f, 43.5f, 65.5f, 66.5f, 67.5f, 68.5f, + 69.5f, 70.5f, 74.5f, 75.5f, 76.5f, 77.5f, 78.5f, 79.5f, + 137.5f, 138.5f, 139.5f, 140.5f, 141.5f, 142.5f, 146.5f, 147.5f, + 148.5f, 149.5f, 150.5f, 151.5f, 173.5f, 174.5f, 175.5f, 176.5f, + 177.5f, 178.5f, 182.5f, 183.5f, 184.5f, 185.5f, 186.5f, 187.5f}); + input.linspace(1.); + + sd::ops::avgpool3dnew op; + auto results = op.evaluate({&input}, {}, + {kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, + paddingMode, 1, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests2, avgpool3d_test4) { - - int bS=2, iD=3,iH=4,iW=3, iC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=1,pH=1,pW=1, dD=1,dH=1,dW=1; - int oD=4,oH=4,oW=4; - int paddingMode = 0; // 1-SAME, 0-VALID - int dataFormat = 0; // 1-NDHWC, 0-NCDHW - - auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); - auto expected = NDArrayFactory::create('c', {bS, iC, oD, oH, oW},{0.416667f, 1.00f, 1.333333f, 0.75f, 1.00f, 2.25f, 2.75f, 1.50f, 1.75f, 3.75f, 4.25f, 2.25f, 1.416667f, 3.00f, 3.333333f, 1.75f, 2.833333f, 6.00f, 6.666667f, 3.50f, 5.00f, 10.50f, 11.50f, 6.00f, 6.50f, - 13.50f, 14.50f, 7.50f, 4.833333f, 10.00f, 10.666667f, 5.50f, 6.833333f, 14.00f, 14.666667f, 7.50f, 11.00f, 22.50f, 23.50f, 12.00f, 12.50f, 25.50f, 26.50f, 13.50f, 8.833333f, 18.00f, 18.666666f, 9.50f, - 4.416667f, 9.00f, 9.333333f, 4.75f, 7.00f, 14.25f, 14.75f, 7.50f, 7.75f, 15.75f, 16.25f, 8.25f, 5.416667f, 11.00f, 11.333333f, 5.75f, 6.416667f, 13.00f, 13.333333f, 6.75f, 10.00f, 20.25f, 20.75f, - 10.50f, 10.75f, 21.75f, 22.25f, 11.25f, 7.416667f, 15.00f, 15.333333f, 7.75f, 14.833333f, 30.00f, 30.666666f, 15.50f, 23.00f, 46.50f, 47.50f, 24.00f, 24.50f, 49.50f, 50.50f, 25.50f, 16.833334f, - 34.00f, 34.666668f, 17.50f, 18.833334f, 38.00f, 38.666668f, 19.50f, 29.00f, 58.50f, 59.50f, 30.00f, 30.50f, 61.50f, 62.50f, 31.50f, 20.833334f, 42.00f, 42.666668f, 21.50f, 10.416667f, 21.00f, - 21.333334f, 10.75f, 16.00f, 32.25f, 32.75f, 16.50f, 16.75f, 33.75f, 34.25f, 17.25f, 11.416667f, 23.00f, 23.333334f, 11.75f, 12.416667f, 25.00f, 25.333334f, 12.75f, 19.00f, 38.25f, 38.75f, 19.50f, - 19.75f, 39.75f, 40.25f, 20.25f, 13.416667f, 27.00f, 27.333334f, 13.75f, 26.833334f, 54.00f, 54.666668f, 27.50f, 41.00f, 82.50f, 83.50f, 42.00f, 42.50f, 85.50f, 86.50f, 43.50f, 28.833334f, 58.00f, - 58.666668f, 29.50f, 30.833334f, 62.00f, 62.666668f, 31.50f, 47.00f, 94.50f, 95.50f, 48.00f, 48.50f, 97.50f, 98.50f, 49.50f, 32.833332f, 66.00f, 66.666664f, 33.50f, 16.416666f, 33.00f, 33.333332f, - 16.75f, 25.00f, 50.25f, 50.75f, 25.50f, 25.75f, 51.75f, 52.25f, 26.25f, 17.416666f, 35.00f, 35.333332f, 17.75f, 18.416666f, 37.00f, 37.333332f, 18.75f, 28.00f, 56.25f, 56.75f, 28.50f, 28.75f, - 57.75f, 58.25f, 29.25f, 19.416666f, 39.00f, 39.333332f, 19.75f, 38.833332f, 78.00f, 78.666664f, 39.50f, 59.00f, 118.50f, 119.50f, 60.00f, 60.50f, 121.50f, 122.50f, 61.50f, 40.833332f, 82.00f, - 82.666664f, 41.50f, 42.833332f, 86.00f, 86.666664f, 43.50f, 65.00f, 130.50f, 131.50f, 66.00f, 66.50f, 133.50f, 134.50f, 67.50f, 44.833332f, 90.00f, 90.666664f, 45.50f, 22.416666f, 45.00f, - 45.333332f, 22.75f, 34.00f, 68.25f, 68.75f, 34.50f, 34.75f, 69.75f, 70.25f, 35.25f, 23.416666f, 47.00f, 47.333332f, 23.75f, 24.416666f, 49.00f, 49.333332f, 24.75f, 37.00f, 74.25f, 74.75f, - 37.50f, 37.75f, 75.75f, 76.25f, 38.25f, 25.416666f, 51.00f, 51.333332f, 25.75f, 50.833332f, 102.00f, 102.666664f, 51.50f, 77.00f, 154.50f, 155.50f, 78.00f, 78.50f, 157.50f, 158.50f, 79.50f, - 52.833332f, 106.00f, 106.666664f, 53.50f, 54.833332f, 110.00f, 110.666664f, 55.50f, 83.00f, 166.50f, 167.50f, 84.00f, 84.50f, 169.50f, 170.50f, 85.50f, 56.833332f, 114.00f, 114.666664f, - 57.50f, 28.416666f, 57.00f, 57.333332f, 28.75f, 43.00f, 86.25f, 86.75f, 43.50f, 43.75f, 87.75f, 88.25f, 44.25f, 29.416666f, 59.00f, 59.333332f, 29.75f, 30.416666f, 61.00f, 61.333332f, 30.75f, - 46.00f, 92.25f, 92.75f, 46.50f, 46.75f, 93.75f, 94.25f, 47.25f, 31.416666f, 63.00f, 63.333332f, 31.75f, 62.833332f, 126.00f, 126.666664f, 63.50f, 95.00f, 190.50f, 191.50f, 96.00f, 96.50f, - 193.50f, 194.50f, 97.50f, 64.833336f, 130.00f, 130.666672f, 65.50f, 66.833336f, 134.00f, 134.666672f, 67.50f, 101.00f, 202.50f, 203.50f, 102.00f, 102.50f, 205.50f, 206.50f, 103.50f, - 68.833336f, 138.00f, 138.666672f, 69.50f, 34.416668f, 69.00f, 69.333336f, 34.75f, 52.00f, 104.25f, 104.75f, 52.50f, 52.75f, 105.75f, 106.25f, 53.25f, 35.416668f, 71.00f, 71.333336f, 35.75f}); - input.linspace(1.); - - sd::ops::avgpool3dnew op; - auto results = op.evaluate({&input}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 1, dataFormat}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - + int bS = 2, iD = 3, iH = 4, iW = 3, iC = 3, kD = 2, kH = 3, kW = 2, sD = 1, + sH = 1, sW = 1, pD = 1, pH = 1, pW = 1, dD = 1, dH = 1, dW = 1; + int oD = 4, oH = 4, oW = 4; + int paddingMode = 0; // 1-SAME, 0-VALID + int dataFormat = 0; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); + auto expected = NDArrayFactory::create( + 'c', {bS, iC, oD, oH, oW}, + {0.416667f, 1.00f, 1.333333f, 0.75f, 1.00f, 2.25f, + 2.75f, 1.50f, 1.75f, 3.75f, 4.25f, 2.25f, + 1.416667f, 3.00f, 3.333333f, 1.75f, 2.833333f, 6.00f, + 6.666667f, 3.50f, 5.00f, 10.50f, 11.50f, 6.00f, + 6.50f, 13.50f, 14.50f, 7.50f, 4.833333f, 10.00f, + 10.666667f, 5.50f, 6.833333f, 14.00f, 14.666667f, 7.50f, + 11.00f, 22.50f, 23.50f, 12.00f, 12.50f, 25.50f, + 26.50f, 13.50f, 8.833333f, 18.00f, 18.666666f, 9.50f, + 4.416667f, 9.00f, 9.333333f, 4.75f, 7.00f, 14.25f, + 14.75f, 7.50f, 7.75f, 15.75f, 16.25f, 8.25f, + 5.416667f, 11.00f, 11.333333f, 5.75f, 6.416667f, 13.00f, + 13.333333f, 6.75f, 10.00f, 20.25f, 20.75f, 10.50f, + 10.75f, 21.75f, 22.25f, 11.25f, 7.416667f, 15.00f, + 15.333333f, 7.75f, 14.833333f, 30.00f, 30.666666f, 15.50f, + 23.00f, 46.50f, 47.50f, 24.00f, 24.50f, 49.50f, + 50.50f, 25.50f, 16.833334f, 34.00f, 34.666668f, 17.50f, + 18.833334f, 38.00f, 38.666668f, 19.50f, 29.00f, 58.50f, + 59.50f, 30.00f, 30.50f, 61.50f, 62.50f, 31.50f, + 20.833334f, 42.00f, 42.666668f, 21.50f, 10.416667f, 21.00f, + 21.333334f, 10.75f, 16.00f, 32.25f, 32.75f, 16.50f, + 16.75f, 33.75f, 34.25f, 17.25f, 11.416667f, 23.00f, + 23.333334f, 11.75f, 12.416667f, 25.00f, 25.333334f, 12.75f, + 19.00f, 38.25f, 38.75f, 19.50f, 19.75f, 39.75f, + 40.25f, 20.25f, 13.416667f, 27.00f, 27.333334f, 13.75f, + 26.833334f, 54.00f, 54.666668f, 27.50f, 41.00f, 82.50f, + 83.50f, 42.00f, 42.50f, 85.50f, 86.50f, 43.50f, + 28.833334f, 58.00f, 58.666668f, 29.50f, 30.833334f, 62.00f, + 62.666668f, 31.50f, 47.00f, 94.50f, 95.50f, 48.00f, + 48.50f, 97.50f, 98.50f, 49.50f, 32.833332f, 66.00f, + 66.666664f, 33.50f, 16.416666f, 33.00f, 33.333332f, 16.75f, + 25.00f, 50.25f, 50.75f, 25.50f, 25.75f, 51.75f, + 52.25f, 26.25f, 17.416666f, 35.00f, 35.333332f, 17.75f, + 18.416666f, 37.00f, 37.333332f, 18.75f, 28.00f, 56.25f, + 56.75f, 28.50f, 28.75f, 57.75f, 58.25f, 29.25f, + 19.416666f, 39.00f, 39.333332f, 19.75f, 38.833332f, 78.00f, + 78.666664f, 39.50f, 59.00f, 118.50f, 119.50f, 60.00f, + 60.50f, 121.50f, 122.50f, 61.50f, 40.833332f, 82.00f, + 82.666664f, 41.50f, 42.833332f, 86.00f, 86.666664f, 43.50f, + 65.00f, 130.50f, 131.50f, 66.00f, 66.50f, 133.50f, + 134.50f, 67.50f, 44.833332f, 90.00f, 90.666664f, 45.50f, + 22.416666f, 45.00f, 45.333332f, 22.75f, 34.00f, 68.25f, + 68.75f, 34.50f, 34.75f, 69.75f, 70.25f, 35.25f, + 23.416666f, 47.00f, 47.333332f, 23.75f, 24.416666f, 49.00f, + 49.333332f, 24.75f, 37.00f, 74.25f, 74.75f, 37.50f, + 37.75f, 75.75f, 76.25f, 38.25f, 25.416666f, 51.00f, + 51.333332f, 25.75f, 50.833332f, 102.00f, 102.666664f, 51.50f, + 77.00f, 154.50f, 155.50f, 78.00f, 78.50f, 157.50f, + 158.50f, 79.50f, 52.833332f, 106.00f, 106.666664f, 53.50f, + 54.833332f, 110.00f, 110.666664f, 55.50f, 83.00f, 166.50f, + 167.50f, 84.00f, 84.50f, 169.50f, 170.50f, 85.50f, + 56.833332f, 114.00f, 114.666664f, 57.50f, 28.416666f, 57.00f, + 57.333332f, 28.75f, 43.00f, 86.25f, 86.75f, 43.50f, + 43.75f, 87.75f, 88.25f, 44.25f, 29.416666f, 59.00f, + 59.333332f, 29.75f, 30.416666f, 61.00f, 61.333332f, 30.75f, + 46.00f, 92.25f, 92.75f, 46.50f, 46.75f, 93.75f, + 94.25f, 47.25f, 31.416666f, 63.00f, 63.333332f, 31.75f, + 62.833332f, 126.00f, 126.666664f, 63.50f, 95.00f, 190.50f, + 191.50f, 96.00f, 96.50f, 193.50f, 194.50f, 97.50f, + 64.833336f, 130.00f, 130.666672f, 65.50f, 66.833336f, 134.00f, + 134.666672f, 67.50f, 101.00f, 202.50f, 203.50f, 102.00f, + 102.50f, 205.50f, 206.50f, 103.50f, 68.833336f, 138.00f, + 138.666672f, 69.50f, 34.416668f, 69.00f, 69.333336f, 34.75f, + 52.00f, 104.25f, 104.75f, 52.50f, 52.75f, 105.75f, + 106.25f, 53.25f, 35.416668f, 71.00f, 71.333336f, 35.75f}); + input.linspace(1.); + + sd::ops::avgpool3dnew op; + auto results = op.evaluate({&input}, {}, + {kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, + paddingMode, 1, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests2, maxpool3d_test1) { - - int bS=2, iD=3,iH=4,iW=3, iC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; - int oD=2,oH=2,oW=2; - int paddingMode = 0; // 1-SAME, 0-VALID - int dataFormat = 0; // 1-NDHWC, 0-NCDHW - - auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); - auto expected = NDArrayFactory::create('c', {bS, iC, oD, oH, oW}, {20.f, 21.f, 23.f, 24.f, 32.f, 33.f, 35.f, 36.f, 56.f, 57.f, 59.f, 60.f, 68.f, 69.f, 71.f, 72.f, 92.f, 93.f, 95.f, 96.f, 104.f, 105.f, 107.f, 108.f, - 128.f, 129.f, 131.f, 132.f, 140.f, 141.f, 143.f, 144.f, 164.f, 165.f, 167.f, 168.f, 176.f, 177.f, 179.f, 180.f, 200.f, 201.f, 203.f, 204.f, 212.f, 213.f, 215.f, 216.f}); - input.linspace(1.); - - sd::ops::maxpool3dnew op; - auto results = op.evaluate({&input}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 1, dataFormat}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + int bS = 2, iD = 3, iH = 4, iW = 3, iC = 3, kD = 2, kH = 3, kW = 2, sD = 1, + sH = 1, sW = 1, pD = 0, pH = 0, pW = 0, dD = 1, dH = 1, dW = 1; + int oD = 2, oH = 2, oW = 2; + int paddingMode = 0; // 1-SAME, 0-VALID + int dataFormat = 0; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); + auto expected = NDArrayFactory::create( + 'c', {bS, iC, oD, oH, oW}, + {20.f, 21.f, 23.f, 24.f, 32.f, 33.f, 35.f, 36.f, 56.f, 57.f, + 59.f, 60.f, 68.f, 69.f, 71.f, 72.f, 92.f, 93.f, 95.f, 96.f, + 104.f, 105.f, 107.f, 108.f, 128.f, 129.f, 131.f, 132.f, 140.f, 141.f, + 143.f, 144.f, 164.f, 165.f, 167.f, 168.f, 176.f, 177.f, 179.f, 180.f, + 200.f, 201.f, 203.f, 204.f, 212.f, 213.f, 215.f, 216.f}); + input.linspace(1.); + + sd::ops::maxpool3dnew op; + auto results = op.evaluate({&input}, {}, + {kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, + paddingMode, 1, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests2, maxpool3d_test2) { - - int bS=2, iD=3,iH=4,iW=3, iC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; - int oD=3,oH=4,oW=3; - int paddingMode = 1; // 1-SAME, 0-VALID - int dataFormat = 1; // 1-NDHWC, 0-NCDHW - - auto input = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); - auto expected = NDArrayFactory::create('c', {bS, oD, oH, oW, iC}, { 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 52.f, 53.f, 54.f, 58.f, 59.f, 60.f, 61.f, 62.f, 63.f, 61.f, 62.f, 63.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f, - 85.f, 86.f, 87.f, 88.f, 89.f, 90.f, 88.f, 89.f, 90.f, 94.f, 95.f, 96.f, 97.f, 98.f, 99.f, 97.f, 98.f, 99.f, 103.f, 104.f, 105.f, 106.f, 107.f, 108.f, 106.f, 107.f, 108.f, 103.f, 104.f, 105.f, 106.f, 107.f, 108.f, 106.f, 107.f, 108.f, - 85.f, 86.f, 87.f, 88.f, 89.f, 90.f, 88.f, 89.f, 90.f, 94.f, 95.f, 96.f, 97.f, 98.f, 99.f, 97.f, 98.f, 99.f, 103.f, 104.f, 105.f, 106.f, 107.f, 108.f, 106.f, 107.f, 108.f, 103.f, 104.f, 105.f, 106.f, 107.f, 108.f, 106.f, 107.f, 108.f, - 157.f, 158.f, 159.f, 160.f, 161.f, 162.f, 160.f, 161.f, 162.f, 166.f, 167.f, 168.f, 169.f, 170.f, 171.f, 169.f, 170.f, 171.f, 175.f, 176.f, 177.f, 178.f, 179.f, 180.f, 178.f, 179.f, 180.f, 175.f, 176.f, 177.f, 178.f, 179.f, 180.f, 178.f, 179.f, 180.f, - 193.f, 194.f, 195.f, 196.f, 197.f, 198.f, 196.f, 197.f, 198.f, 202.f, 203.f, 204.f, 205.f, 206.f, 207.f, 205.f, 206.f, 207.f, 211.f, 212.f, 213.f, 214.f, 215.f, 216.f, 214.f, 215.f, 216.f, 211.f, 212.f, 213.f, 214.f, 215.f, 216.f, 214.f, 215.f, 216.f, - 193.f, 194.f, 195.f, 196.f, 197.f, 198.f, 196.f, 197.f, 198.f, 202.f, 203.f, 204.f, 205.f, 206.f, 207.f, 205.f, 206.f, 207.f, 211.f, 212.f, 213.f, 214.f, 215.f, 216.f, 214.f, 215.f, 216.f, 211.f, 212.f, 213.f, 214.f, 215.f, 216.f, 214.f, 215.f, 216.f}); - input.linspace(1.); - - sd::ops::maxpool3dnew op; - auto results = op.evaluate({&input}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 1, dataFormat}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); -} + int bS = 2, iD = 3, iH = 4, iW = 3, iC = 3, kD = 2, kH = 3, kW = 2, sD = 1, + sH = 1, sW = 1, pD = 0, pH = 0, pW = 0, dD = 1, dH = 1, dW = 1; + int oD = 3, oH = 4, oW = 3; + int paddingMode = 1; // 1-SAME, 0-VALID + int dataFormat = 1; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); + auto expected = NDArrayFactory::create( + 'c', {bS, oD, oH, oW, iC}, + {49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 52.f, 53.f, 54.f, 58.f, + 59.f, 60.f, 61.f, 62.f, 63.f, 61.f, 62.f, 63.f, 67.f, 68.f, + 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f, 67.f, 68.f, 69.f, + 70.f, 71.f, 72.f, 70.f, 71.f, 72.f, 85.f, 86.f, 87.f, 88.f, + 89.f, 90.f, 88.f, 89.f, 90.f, 94.f, 95.f, 96.f, 97.f, 98.f, + 99.f, 97.f, 98.f, 99.f, 103.f, 104.f, 105.f, 106.f, 107.f, 108.f, + 106.f, 107.f, 108.f, 103.f, 104.f, 105.f, 106.f, 107.f, 108.f, 106.f, + 107.f, 108.f, 85.f, 86.f, 87.f, 88.f, 89.f, 90.f, 88.f, 89.f, + 90.f, 94.f, 95.f, 96.f, 97.f, 98.f, 99.f, 97.f, 98.f, 99.f, + 103.f, 104.f, 105.f, 106.f, 107.f, 108.f, 106.f, 107.f, 108.f, 103.f, + 104.f, 105.f, 106.f, 107.f, 108.f, 106.f, 107.f, 108.f, 157.f, 158.f, + 159.f, 160.f, 161.f, 162.f, 160.f, 161.f, 162.f, 166.f, 167.f, 168.f, + 169.f, 170.f, 171.f, 169.f, 170.f, 171.f, 175.f, 176.f, 177.f, 178.f, + 179.f, 180.f, 178.f, 179.f, 180.f, 175.f, 176.f, 177.f, 178.f, 179.f, + 180.f, 178.f, 179.f, 180.f, 193.f, 194.f, 195.f, 196.f, 197.f, 198.f, + 196.f, 197.f, 198.f, 202.f, 203.f, 204.f, 205.f, 206.f, 207.f, 205.f, + 206.f, 207.f, 211.f, 212.f, 213.f, 214.f, 215.f, 216.f, 214.f, 215.f, + 216.f, 211.f, 212.f, 213.f, 214.f, 215.f, 216.f, 214.f, 215.f, 216.f, + 193.f, 194.f, 195.f, 196.f, 197.f, 198.f, 196.f, 197.f, 198.f, 202.f, + 203.f, 204.f, 205.f, 206.f, 207.f, 205.f, 206.f, 207.f, 211.f, 212.f, + 213.f, 214.f, 215.f, 216.f, 214.f, 215.f, 216.f, 211.f, 212.f, 213.f, + 214.f, 215.f, 216.f, 214.f, 215.f, 216.f}); + input.linspace(1.); + + sd::ops::maxpool3dnew op; + auto results = op.evaluate({&input}, {}, + {kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, + paddingMode, 1, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); +} ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests2, maxpool3d_test3) { - - int bS=2, iD=3,iH=4,iW=3, iC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; - int oD=2,oH=2,oW=2; - int paddingMode = 0; // 1-SAME, 0-VALID - int dataFormat = 1; // 1-NDHWC, 0-NCDHW - - auto input = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); - auto expected = NDArrayFactory::create('c', {bS, oD, oH, oW, iC}, {58.f, 59.f, 60.f, 61.f, 62.f, 63.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 94.f, 95.f, 96.f, 97.f, 98.f, 99.f, 103.f, 104.f, 105.f, 106.f, 107.f, 108.f, - 166.f, 167.f, 168.f, 169.f, 170.f, 171.f, 175.f, 176.f, 177.f, 178.f, 179.f, 180.f, 202.f, 203.f, 204.f, 205.f, 206.f, 207.f, 211.f, 212.f, 213.f, 214.f, 215.f, 216.f}); - input.linspace(1.); - - sd::ops::maxpool3dnew op; - auto results = op.evaluate({&input}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 1, dataFormat}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - + int bS = 2, iD = 3, iH = 4, iW = 3, iC = 3, kD = 2, kH = 3, kW = 2, sD = 1, + sH = 1, sW = 1, pD = 0, pH = 0, pW = 0, dD = 1, dH = 1, dW = 1; + int oD = 2, oH = 2, oW = 2; + int paddingMode = 0; // 1-SAME, 0-VALID + int dataFormat = 1; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); + auto expected = NDArrayFactory::create( + 'c', {bS, oD, oH, oW, iC}, + {58.f, 59.f, 60.f, 61.f, 62.f, 63.f, 67.f, 68.f, 69.f, 70.f, + 71.f, 72.f, 94.f, 95.f, 96.f, 97.f, 98.f, 99.f, 103.f, 104.f, + 105.f, 106.f, 107.f, 108.f, 166.f, 167.f, 168.f, 169.f, 170.f, 171.f, + 175.f, 176.f, 177.f, 178.f, 179.f, 180.f, 202.f, 203.f, 204.f, 205.f, + 206.f, 207.f, 211.f, 212.f, 213.f, 214.f, 215.f, 216.f}); + input.linspace(1.); + + sd::ops::maxpool3dnew op; + auto results = op.evaluate({&input}, {}, + {kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, + paddingMode, 1, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests2, maxpool3d_test4) { - - int bS=2, iD=3,iH=4,iW=3, iC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=1,pH=1,pW=1, dD=1,dH=1,dW=1; - int oD=4,oH=4,oW=4; - int paddingMode = 0; // -SAME, 0-VALID - int dataFormat = 0; // -NDHWC, 0-NCDHW - - auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); - auto expected = NDArrayFactory::create('c', {bS, iC, oD, oH, oW},{ 4.f, 5.f, 6.f, 6.f, 7.f, 8.f, 9.f, 9.f, 10.f, 11.f, 12.f, 12.f, 10.f, 11.f, 12.f, 12.f, 16.f, 17.f, 18.f, 18.f, 19.f, 20.f, 21.f, 21.f, 22.f, 23.f, 24.f, 24.f, 22.f, 23.f, 24.f, 24.f, 28.f, 29.f, 30.f, 30.f, 31.f, 32.f, 33.f, 33.f, 34.f, 35.f, 36.f, 36.f, 34.f, 35.f, 36.f, 36.f, - 28.f, 29.f, 30.f, 30.f, 31.f, 32.f, 33.f, 33.f, 34.f, 35.f, 36.f, 36.f, 34.f, 35.f, 36.f, 36.f, 40.f, 41.f, 42.f, 42.f, 43.f, 44.f, 45.f, 45.f, 46.f, 47.f, 48.f, 48.f, 46.f, 47.f, 48.f, 48.f, 52.f, 53.f, 54.f, 54.f, 55.f, 56.f, 57.f, 57.f, 58.f, 59.f, 60.f, 60.f, 58.f, 59.f, 60.f, 60.f, - 64.f, 65.f, 66.f, 66.f, 67.f, 68.f, 69.f, 69.f, 70.f, 71.f, 72.f, 72.f, 70.f, 71.f, 72.f, 72.f, 64.f, 65.f, 66.f, 66.f, 67.f, 68.f, 69.f, 69.f, 70.f, 71.f, 72.f, 72.f, 70.f, 71.f, 72.f, 72.f, 76.f, 77.f, 78.f, 78.f, 79.f, 80.f, 81.f, 81.f, 82.f, 83.f, 84.f, 84.f, 82.f, 83.f, 84.f, 84.f, - 88.f, 89.f, 90.f, 90.f, 91.f, 92.f, 93.f, 93.f, 94.f, 95.f, 96.f, 96.f, 94.f, 95.f, 96.f, 96.f, 100.f, 101.f, 102.f, 102.f, 103.f, 104.f, 105.f, 105.f, 106.f, 107.f, 108.f, 108.f, 106.f, 107.f, 108.f, 108.f, 100.f, 101.f, 102.f, 102.f, 103.f, 104.f, 105.f, 105.f, 106.f, 107.f, 108.f, 108.f, 106.f, 107.f, 108.f, 108.f, - 112.f, 113.f, 114.f, 114.f, 115.f, 116.f, 117.f, 117.f, 118.f, 119.f, 120.f, 120.f, 118.f, 119.f, 120.f, 120.f, 124.f, 125.f, 126.f, 126.f, 127.f, 128.f, 129.f, 129.f, 130.f, 131.f, 132.f, 132.f, 130.f, 131.f, 132.f, 132.f, 136.f, 137.f, 138.f, 138.f, 139.f, 140.f, 141.f, 141.f, 142.f, 143.f, 144.f, 144.f, 142.f, 143.f, 144.f, 144.f, - 136.f, 137.f, 138.f, 138.f, 139.f, 140.f, 141.f, 141.f, 142.f, 143.f, 144.f, 144.f, 142.f, 143.f, 144.f, 144.f, 148.f, 149.f, 150.f, 150.f, 151.f, 152.f, 153.f, 153.f, 154.f, 155.f, 156.f, 156.f, 154.f, 155.f, 156.f, 156.f, 160.f, 161.f, 162.f, 162.f, 163.f, 164.f, 165.f, 165.f, 166.f, 167.f, 168.f, 168.f, 166.f, 167.f, 168.f, 168.f, - 172.f, 173.f, 174.f, 174.f, 175.f, 176.f, 177.f, 177.f, 178.f, 179.f, 180.f, 180.f, 178.f, 179.f, 180.f, 180.f, 172.f, 173.f, 174.f, 174.f, 175.f, 176.f, 177.f, 177.f, 178.f, 179.f, 180.f, 180.f, 178.f, 179.f, 180.f, 180.f, 184.f, 185.f, 186.f, 186.f, 187.f, 188.f, 189.f, 189.f, 190.f, 191.f, 192.f, 192.f, 190.f, 191.f, 192.f, 192.f, - 196.f, 197.f, 198.f, 198.f, 199.f, 200.f, 201.f, 201.f, 202.f, 203.f, 204.f, 204.f, 202.f, 203.f, 204.f, 204.f, 208.f, 209.f, 210.f, 210.f, 211.f, 212.f, 213.f, 213.f, 214.f, 215.f, 216.f, 216.f, 214.f, 215.f, 216.f, 216.f, 208.f, 209.f, 210.f, 210.f, 211.f, 212.f, 213.f, 213.f, 214.f, 215.f, 216.f, 216.f, 214.f, 215.f, 216.f, 216.f}); - input.linspace(1.); - - sd::ops::maxpool3dnew op; - auto results = op.evaluate({&input}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 1, dataFormat}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + int bS = 2, iD = 3, iH = 4, iW = 3, iC = 3, kD = 2, kH = 3, kW = 2, sD = 1, + sH = 1, sW = 1, pD = 1, pH = 1, pW = 1, dD = 1, dH = 1, dW = 1; + int oD = 4, oH = 4, oW = 4; + int paddingMode = 0; // -SAME, 0-VALID + int dataFormat = 0; // -NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); + auto expected = NDArrayFactory::create( + 'c', {bS, iC, oD, oH, oW}, + {4.f, 5.f, 6.f, 6.f, 7.f, 8.f, 9.f, 9.f, 10.f, 11.f, + 12.f, 12.f, 10.f, 11.f, 12.f, 12.f, 16.f, 17.f, 18.f, 18.f, + 19.f, 20.f, 21.f, 21.f, 22.f, 23.f, 24.f, 24.f, 22.f, 23.f, + 24.f, 24.f, 28.f, 29.f, 30.f, 30.f, 31.f, 32.f, 33.f, 33.f, + 34.f, 35.f, 36.f, 36.f, 34.f, 35.f, 36.f, 36.f, 28.f, 29.f, + 30.f, 30.f, 31.f, 32.f, 33.f, 33.f, 34.f, 35.f, 36.f, 36.f, + 34.f, 35.f, 36.f, 36.f, 40.f, 41.f, 42.f, 42.f, 43.f, 44.f, + 45.f, 45.f, 46.f, 47.f, 48.f, 48.f, 46.f, 47.f, 48.f, 48.f, + 52.f, 53.f, 54.f, 54.f, 55.f, 56.f, 57.f, 57.f, 58.f, 59.f, + 60.f, 60.f, 58.f, 59.f, 60.f, 60.f, 64.f, 65.f, 66.f, 66.f, + 67.f, 68.f, 69.f, 69.f, 70.f, 71.f, 72.f, 72.f, 70.f, 71.f, + 72.f, 72.f, 64.f, 65.f, 66.f, 66.f, 67.f, 68.f, 69.f, 69.f, + 70.f, 71.f, 72.f, 72.f, 70.f, 71.f, 72.f, 72.f, 76.f, 77.f, + 78.f, 78.f, 79.f, 80.f, 81.f, 81.f, 82.f, 83.f, 84.f, 84.f, + 82.f, 83.f, 84.f, 84.f, 88.f, 89.f, 90.f, 90.f, 91.f, 92.f, + 93.f, 93.f, 94.f, 95.f, 96.f, 96.f, 94.f, 95.f, 96.f, 96.f, + 100.f, 101.f, 102.f, 102.f, 103.f, 104.f, 105.f, 105.f, 106.f, 107.f, + 108.f, 108.f, 106.f, 107.f, 108.f, 108.f, 100.f, 101.f, 102.f, 102.f, + 103.f, 104.f, 105.f, 105.f, 106.f, 107.f, 108.f, 108.f, 106.f, 107.f, + 108.f, 108.f, 112.f, 113.f, 114.f, 114.f, 115.f, 116.f, 117.f, 117.f, + 118.f, 119.f, 120.f, 120.f, 118.f, 119.f, 120.f, 120.f, 124.f, 125.f, + 126.f, 126.f, 127.f, 128.f, 129.f, 129.f, 130.f, 131.f, 132.f, 132.f, + 130.f, 131.f, 132.f, 132.f, 136.f, 137.f, 138.f, 138.f, 139.f, 140.f, + 141.f, 141.f, 142.f, 143.f, 144.f, 144.f, 142.f, 143.f, 144.f, 144.f, + 136.f, 137.f, 138.f, 138.f, 139.f, 140.f, 141.f, 141.f, 142.f, 143.f, + 144.f, 144.f, 142.f, 143.f, 144.f, 144.f, 148.f, 149.f, 150.f, 150.f, + 151.f, 152.f, 153.f, 153.f, 154.f, 155.f, 156.f, 156.f, 154.f, 155.f, + 156.f, 156.f, 160.f, 161.f, 162.f, 162.f, 163.f, 164.f, 165.f, 165.f, + 166.f, 167.f, 168.f, 168.f, 166.f, 167.f, 168.f, 168.f, 172.f, 173.f, + 174.f, 174.f, 175.f, 176.f, 177.f, 177.f, 178.f, 179.f, 180.f, 180.f, + 178.f, 179.f, 180.f, 180.f, 172.f, 173.f, 174.f, 174.f, 175.f, 176.f, + 177.f, 177.f, 178.f, 179.f, 180.f, 180.f, 178.f, 179.f, 180.f, 180.f, + 184.f, 185.f, 186.f, 186.f, 187.f, 188.f, 189.f, 189.f, 190.f, 191.f, + 192.f, 192.f, 190.f, 191.f, 192.f, 192.f, 196.f, 197.f, 198.f, 198.f, + 199.f, 200.f, 201.f, 201.f, 202.f, 203.f, 204.f, 204.f, 202.f, 203.f, + 204.f, 204.f, 208.f, 209.f, 210.f, 210.f, 211.f, 212.f, 213.f, 213.f, + 214.f, 215.f, 216.f, 216.f, 214.f, 215.f, 216.f, 216.f, 208.f, 209.f, + 210.f, 210.f, 211.f, 212.f, 213.f, 213.f, 214.f, 215.f, 216.f, 216.f, + 214.f, 215.f, 216.f, 216.f}); + input.linspace(1.); + + sd::ops::maxpool3dnew op; + auto results = op.evaluate({&input}, {}, + {kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, + paddingMode, 1, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests2, avgpool3d_bp_test1) { - - int bS=2, iD=3,iH=4,iW=3, iC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; - int oD=2,oH=2,oW=2; - int paddingMode = 0; // 1-SAME, 0-VALID - int dataFormat = 0; // 1-NDHWC, 0-NCDHW - - auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); - auto gradO = NDArrayFactory::create('c', {bS, iC, oD, oH, oW}); - auto expected = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}, {0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.666667f, 1.333333f, 0.666667f, 0.666667f, 1.333333f, 0.666667f, 0.333333f, 0.666667f, 0.333333f, - 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, - 0.333333f, 0.666667f, 0.333333f, 0.666667f, 1.333333f, 0.666667f, 0.666667f, 1.333333f, 0.666667f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, - 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.666667f, 1.333333f, 0.666667f, 0.666667f, 1.333333f, 0.666667f, 0.333333f, 0.666667f, 0.333333f, - 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, - 0.333333f, 0.666667f, 0.333333f, 0.666667f, 1.333333f, 0.666667f, 0.666667f, 1.333333f, 0.666667f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, - 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.666667f, 1.333333f, 0.666667f, 0.666667f, 1.333333f, 0.666667f, 0.333333f, 0.666667f, 0.333333f, - 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, - 0.333333f, 0.666667f, 0.333333f, 0.666667f, 1.333333f, 0.666667f, 0.666667f, 1.333333f, 0.666667f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f}); - input.linspace(1.); - gradO = 2.; - - sd::ops::avgpool3dnew_bp op; - auto results = op.evaluate({&input, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 1, dataFormat}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - + int bS = 2, iD = 3, iH = 4, iW = 3, iC = 3, kD = 2, kH = 3, kW = 2, sD = 1, + sH = 1, sW = 1, pD = 0, pH = 0, pW = 0, dD = 1, dH = 1, dW = 1; + int oD = 2, oH = 2, oW = 2; + int paddingMode = 0; // 1-SAME, 0-VALID + int dataFormat = 0; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); + auto gradO = NDArrayFactory::create('c', {bS, iC, oD, oH, oW}); + auto expected = NDArrayFactory::create( + 'c', {bS, iC, iD, iH, iW}, + {0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, + 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, + 0.333333f, 0.666667f, 0.333333f, 0.666667f, 1.333333f, 0.666667f, + 0.666667f, 1.333333f, 0.666667f, 0.333333f, 0.666667f, 0.333333f, + 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, + 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, + 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, + 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, + 0.333333f, 0.666667f, 0.333333f, 0.666667f, 1.333333f, 0.666667f, + 0.666667f, 1.333333f, 0.666667f, 0.333333f, 0.666667f, 0.333333f, + 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, + 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, + 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, + 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, + 0.333333f, 0.666667f, 0.333333f, 0.666667f, 1.333333f, 0.666667f, + 0.666667f, 1.333333f, 0.666667f, 0.333333f, 0.666667f, 0.333333f, + 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, + 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, + 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, + 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, + 0.333333f, 0.666667f, 0.333333f, 0.666667f, 1.333333f, 0.666667f, + 0.666667f, 1.333333f, 0.666667f, 0.333333f, 0.666667f, 0.333333f, + 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, + 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, + 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, + 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, + 0.333333f, 0.666667f, 0.333333f, 0.666667f, 1.333333f, 0.666667f, + 0.666667f, 1.333333f, 0.666667f, 0.333333f, 0.666667f, 0.333333f, + 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, + 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, + 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, + 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, + 0.333333f, 0.666667f, 0.333333f, 0.666667f, 1.333333f, 0.666667f, + 0.666667f, 1.333333f, 0.666667f, 0.333333f, 0.666667f, 0.333333f, + 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, + 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f}); + input.linspace(1.); + gradO = 2.; + + sd::ops::avgpool3dnew_bp op; + auto results = op.evaluate({&input, &gradO}, {}, + {kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, + paddingMode, 1, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } - ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests2, avgpool3d_bp_test2) { - - int bS=2, iD=3,iH=4,iW=3, iC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=1,pH=1,pW=1, dD=1,dH=1,dW=1; - int oD=4,oH=4,oW=4; - int paddingMode = 0; // 1-SAME, 0-VALID - int dataFormat = 0; // 1-NDHWC, 0-NCDHW - - auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); - auto gradO = NDArrayFactory::create('c', {bS, iC, oD, oH, oW}); - auto expected = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}, {1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, - 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, - 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, - 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, - 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, - 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, - 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, - 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, - 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f}); - input.linspace(1.); - gradO = 2.; - - sd::ops::avgpool3dnew_bp op; - auto results = op.evaluate({&input, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 1, dataFormat}); - auto output = results.at(0); - - // output->printBuffer(); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - + int bS = 2, iD = 3, iH = 4, iW = 3, iC = 3, kD = 2, kH = 3, kW = 2, sD = 1, + sH = 1, sW = 1, pD = 1, pH = 1, pW = 1, dD = 1, dH = 1, dW = 1; + int oD = 4, oH = 4, oW = 4; + int paddingMode = 0; // 1-SAME, 0-VALID + int dataFormat = 0; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); + auto gradO = NDArrayFactory::create('c', {bS, iC, oD, oH, oW}); + auto expected = NDArrayFactory::create( + 'c', {bS, iC, iD, iH, iW}, + {1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, + 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, + 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, + 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, + 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, + 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, + 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, + 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, + 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, + 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, + 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, + 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, + 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, + 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, + 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, + 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, + 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, + 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, + 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, + 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, + 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, + 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, + 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, + 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, + 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, + 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, + 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, + 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, + 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, + 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, + 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, + 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, + 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, + 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, + 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, + 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f}); + input.linspace(1.); + gradO = 2.; + + sd::ops::avgpool3dnew_bp op; + auto results = op.evaluate({&input, &gradO}, {}, + {kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, + paddingMode, 1, dataFormat}); + auto output = results.at(0); + + // output->printBuffer(); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } - ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests2, avgpool3d_bp_test3) { - - int bS=2, iD=3,iH=4,iW=3, iC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; - int oD=3,oH=4,oW=3; - int paddingMode = 1; // 1-SAME, 0-VALID - int dataFormat = 1; // 1-NDHWC, 0-NCDHW - - auto input = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); - auto gradO = NDArrayFactory::create('c', {bS, oD, oH, oW, iC}); - auto expected = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}, {0.41667f, 0.41667f, 0.41667f, 0.83333f, 0.83333f, 0.83333f, 1.25f, 1.25f, 1.25f, 0.58333f, 0.58333f, 0.58333f, 1.16667f, 1.16667f, 1.16667f, 1.75f, 1.75f, 1.75f, 0.58333f, 0.58333f, 0.58333f, 1.16667f, 1.16667f, 1.16667f, 1.75f, 1.75f, 1.75f, - 0.41667f, 0.41667f, 0.41667f, 0.83333f, 0.83333f, 0.83333f, 1.25f, 1.25f, 1.25f, 0.83333f, 0.83333f, 0.83333f, 1.66667f, 1.66667f, 1.66667f, 2.5f, 2.5f, 2.5f, 1.16667f, 1.16667f, 1.16667f, 2.33333f, 2.33333f, 2.33333f, 3.5f, 3.5f, 3.5f, - 1.16667f, 1.16667f, 1.16667f, 2.33333f, 2.33333f, 2.33333f, 3.5f, 3.5f, 3.5f, 0.83333f, 0.83333f, 0.83333f, 1.66667f, 1.66667f, 1.66667f, 2.5f, 2.5f, 2.5f, 1.25f, 1.25f, 1.25f, 2.5f, 2.5f, 2.5f, 3.75f, 3.75f, 3.75f, - 1.75f, 1.75f, 1.75f, 3.5f, 3.5f, 3.5f, 5.25f, 5.25f, 5.25f, 1.75f, 1.75f, 1.75f, 3.5f, 3.5f, 3.5f, 5.25f, 5.25f, 5.25f, 1.25f, 1.25f, 1.25f, 2.5f, 2.5f, 2.5f, 3.75f, 3.75f, 3.75f, - 0.41667f, 0.41667f, 0.41667f, 0.83333f, 0.83333f, 0.83333f, 1.25f, 1.25f, 1.25f, 0.58333f, 0.58333f, 0.58333f, 1.16667f, 1.16667f, 1.16667f, 1.75f, 1.75f, 1.75f, 0.58333f, 0.58333f, 0.58333f, 1.16667f, 1.16667f, 1.16667f, 1.75f, 1.75f, 1.75f, - 0.41667f, 0.41667f, 0.41667f, 0.83333f, 0.83333f, 0.83333f, 1.25f, 1.25f, 1.25f, 0.83333f, 0.83333f, 0.83333f, 1.66667f, 1.66667f, 1.66667f, 2.5f, 2.5f, 2.5f, 1.16667f, 1.16667f, 1.16667f, 2.33333f, 2.33333f, 2.33333f, 3.5f, 3.5f, 3.5f, - 1.16667f, 1.16667f, 1.16667f, 2.33333f, 2.33333f, 2.33333f, 3.5f, 3.5f, 3.5f, 0.83333f, 0.83333f, 0.83333f, 1.66667f, 1.66667f, 1.66667f, 2.5f, 2.5f, 2.5f, 1.25f, 1.25f, 1.25f, 2.5f, 2.5f, 2.5f, 3.75f, 3.75f, 3.75f, - 1.75f, 1.75f, 1.75f, 3.5f, 3.5f, 3.5f, 5.25f, 5.25f, 5.25f, 1.75f, 1.75f, 1.75f, 3.5f, 3.5f, 3.5f, 5.25f, 5.25f, 5.25f, 1.25f, 1.25f, 1.25f, 2.5f, 2.5f, 2.5f, 3.75f, 3.75f, 3.75f}); - input.linspace(1.); - gradO = 2.; - - sd::ops::avgpool3dnew_bp op; - auto results = op.evaluate({&input, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 0, dataFormat}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + int bS = 2, iD = 3, iH = 4, iW = 3, iC = 3, kD = 2, kH = 3, kW = 2, sD = 1, + sH = 1, sW = 1, pD = 0, pH = 0, pW = 0, dD = 1, dH = 1, dW = 1; + int oD = 3, oH = 4, oW = 3; + int paddingMode = 1; // 1-SAME, 0-VALID + int dataFormat = 1; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); + auto gradO = NDArrayFactory::create('c', {bS, oD, oH, oW, iC}); + auto expected = NDArrayFactory::create( + 'c', {bS, iD, iH, iW, iC}, + {0.41667f, 0.41667f, 0.41667f, 0.83333f, 0.83333f, 0.83333f, 1.25f, + 1.25f, 1.25f, 0.58333f, 0.58333f, 0.58333f, 1.16667f, 1.16667f, + 1.16667f, 1.75f, 1.75f, 1.75f, 0.58333f, 0.58333f, 0.58333f, + 1.16667f, 1.16667f, 1.16667f, 1.75f, 1.75f, 1.75f, 0.41667f, + 0.41667f, 0.41667f, 0.83333f, 0.83333f, 0.83333f, 1.25f, 1.25f, + 1.25f, 0.83333f, 0.83333f, 0.83333f, 1.66667f, 1.66667f, 1.66667f, + 2.5f, 2.5f, 2.5f, 1.16667f, 1.16667f, 1.16667f, 2.33333f, + 2.33333f, 2.33333f, 3.5f, 3.5f, 3.5f, 1.16667f, 1.16667f, + 1.16667f, 2.33333f, 2.33333f, 2.33333f, 3.5f, 3.5f, 3.5f, + 0.83333f, 0.83333f, 0.83333f, 1.66667f, 1.66667f, 1.66667f, 2.5f, + 2.5f, 2.5f, 1.25f, 1.25f, 1.25f, 2.5f, 2.5f, + 2.5f, 3.75f, 3.75f, 3.75f, 1.75f, 1.75f, 1.75f, + 3.5f, 3.5f, 3.5f, 5.25f, 5.25f, 5.25f, 1.75f, + 1.75f, 1.75f, 3.5f, 3.5f, 3.5f, 5.25f, 5.25f, + 5.25f, 1.25f, 1.25f, 1.25f, 2.5f, 2.5f, 2.5f, + 3.75f, 3.75f, 3.75f, 0.41667f, 0.41667f, 0.41667f, 0.83333f, + 0.83333f, 0.83333f, 1.25f, 1.25f, 1.25f, 0.58333f, 0.58333f, + 0.58333f, 1.16667f, 1.16667f, 1.16667f, 1.75f, 1.75f, 1.75f, + 0.58333f, 0.58333f, 0.58333f, 1.16667f, 1.16667f, 1.16667f, 1.75f, + 1.75f, 1.75f, 0.41667f, 0.41667f, 0.41667f, 0.83333f, 0.83333f, + 0.83333f, 1.25f, 1.25f, 1.25f, 0.83333f, 0.83333f, 0.83333f, + 1.66667f, 1.66667f, 1.66667f, 2.5f, 2.5f, 2.5f, 1.16667f, + 1.16667f, 1.16667f, 2.33333f, 2.33333f, 2.33333f, 3.5f, 3.5f, + 3.5f, 1.16667f, 1.16667f, 1.16667f, 2.33333f, 2.33333f, 2.33333f, + 3.5f, 3.5f, 3.5f, 0.83333f, 0.83333f, 0.83333f, 1.66667f, + 1.66667f, 1.66667f, 2.5f, 2.5f, 2.5f, 1.25f, 1.25f, + 1.25f, 2.5f, 2.5f, 2.5f, 3.75f, 3.75f, 3.75f, + 1.75f, 1.75f, 1.75f, 3.5f, 3.5f, 3.5f, 5.25f, + 5.25f, 5.25f, 1.75f, 1.75f, 1.75f, 3.5f, 3.5f, + 3.5f, 5.25f, 5.25f, 5.25f, 1.25f, 1.25f, 1.25f, + 2.5f, 2.5f, 2.5f, 3.75f, 3.75f, 3.75f}); + input.linspace(1.); + gradO = 2.; + + sd::ops::avgpool3dnew_bp op; + auto results = op.evaluate({&input, &gradO}, {}, + {kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, + paddingMode, 0, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests2, avgpool3d_bp_test4) { - - int bS=2, iD=3,iH=4,iW=3, iC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; - int oD=3,oH=4,oW=3; - int paddingMode = 0; // 1-SAME, 0-VALID - int dataFormat = 1; // 1-NDHWC, 0-NCDHW - - auto input = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); - auto gradO = NDArrayFactory::create('c', {bS, oD, oH, oW, iC}); - auto expected = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}, {0.16667f, 0.16667f, 0.16667f, 0.33333f, 0.33333f, 0.33333f, 0.5f, 0.5f, 0.5f, 0.33333f, 0.33333f, 0.33333f, 0.66667f, 0.66667f, 0.66667f, 1.f, 1.f, 1.f, 0.58333f, 0.58333f, 0.58333f, 1.16667f, 1.16667f, 1.16667f, 1.75f, 1.75f, 1.75f, - 0.91667f, 0.91667f, 0.91667f, 1.83333f, 1.83333f, 1.83333f, 2.75f, 2.75f, 2.75f, 0.33333f, 0.33333f, 0.33333f, 0.66667f, 0.66667f, 0.66667f, 1.f, 1.f, 1.f, 0.66667f, 0.66667f, 0.66667f, 1.33333f, 1.33333f, 1.33333f, 2.f, 2.f, 2.f, - 1.16667f, 1.16667f, 1.16667f, 2.33333f, 2.33333f, 2.33333f, 3.5f, 3.5f, 3.5f, 1.83333f, 1.83333f, 1.83333f, 3.66667f, 3.66667f, 3.66667f, 5.5f, 5.5f, 5.5f, 0.5f, 0.5f, 0.5f, 1.f, 1.f, 1.f, 1.5f, 1.5f, 1.5f, - 1.f, 1.f, 1.f, 2.f, 2.f, 2.f, 3.f, 3.f, 3.f, 1.75f, 1.75f, 1.75f, 3.5f, 3.5f, 3.5f, 5.25f, 5.25f, 5.25f, 2.75f, 2.75f, 2.75f, 5.5f, 5.5f, 5.5f, 8.25f, 8.25f, 8.25f, - 0.16667f, 0.16667f, 0.16667f, 0.33333f, 0.33333f, 0.33333f, 0.5f, 0.5f, 0.5f, 0.33333f, 0.33333f, 0.33333f, 0.66667f, 0.66667f, 0.66667f, 1.f, 1.f, 1.f, 0.58333f, 0.58333f, 0.58333f, 1.16667f, 1.16667f, 1.16667f, 1.75f, 1.75f, 1.75f, - 0.91667f, 0.91667f, 0.91667f, 1.83333f, 1.83333f, 1.83333f, 2.75f, 2.75f, 2.75f, 0.33333f, 0.33333f, 0.33333f, 0.66667f, 0.66667f, 0.66667f, 1.f, 1.f, 1.f, 0.66667f, 0.66667f, 0.66667f, 1.33333f, 1.33333f, 1.33333f, 2.f, 2.f, 2.f, - 1.16667f, 1.16667f, 1.16667f, 2.33333f, 2.33333f, 2.33333f, 3.5f, 3.5f, 3.5f, 1.83333f, 1.83333f, 1.83333f, 3.66667f, 3.66667f, 3.66667f, 5.5f, 5.5f, 5.5f, 0.5f, 0.5f, 0.5f, 1.f, 1.f, 1.f, 1.5f, 1.5f, 1.5f, - 1.f, 1.f, 1.f, 2.f, 2.f, 2.f, 3.f, 3.f, 3.f, 1.75f, 1.75f, 1.75f, 3.5f, 3.5f, 3.5f, 5.25f, 5.25f, 5.25f, 2.75f, 2.75f, 2.75f, 5.5f, 5.5f, 5.5f, 8.25f, 8.25f, 8.25f}); - input.linspace(1.); - gradO = 2.; - - sd::ops::avgpool3dnew_bp op; - auto results = op.evaluate({&input, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 0, dataFormat}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - + int bS = 2, iD = 3, iH = 4, iW = 3, iC = 3, kD = 2, kH = 3, kW = 2, sD = 1, + sH = 1, sW = 1, pD = 0, pH = 0, pW = 0, dD = 1, dH = 1, dW = 1; + int oD = 3, oH = 4, oW = 3; + int paddingMode = 0; // 1-SAME, 0-VALID + int dataFormat = 1; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); + auto gradO = NDArrayFactory::create('c', {bS, oD, oH, oW, iC}); + auto expected = NDArrayFactory::create( + 'c', {bS, iD, iH, iW, iC}, + {0.16667f, 0.16667f, 0.16667f, 0.33333f, 0.33333f, 0.33333f, 0.5f, + 0.5f, 0.5f, 0.33333f, 0.33333f, 0.33333f, 0.66667f, 0.66667f, + 0.66667f, 1.f, 1.f, 1.f, 0.58333f, 0.58333f, 0.58333f, + 1.16667f, 1.16667f, 1.16667f, 1.75f, 1.75f, 1.75f, 0.91667f, + 0.91667f, 0.91667f, 1.83333f, 1.83333f, 1.83333f, 2.75f, 2.75f, + 2.75f, 0.33333f, 0.33333f, 0.33333f, 0.66667f, 0.66667f, 0.66667f, + 1.f, 1.f, 1.f, 0.66667f, 0.66667f, 0.66667f, 1.33333f, + 1.33333f, 1.33333f, 2.f, 2.f, 2.f, 1.16667f, 1.16667f, + 1.16667f, 2.33333f, 2.33333f, 2.33333f, 3.5f, 3.5f, 3.5f, + 1.83333f, 1.83333f, 1.83333f, 3.66667f, 3.66667f, 3.66667f, 5.5f, + 5.5f, 5.5f, 0.5f, 0.5f, 0.5f, 1.f, 1.f, + 1.f, 1.5f, 1.5f, 1.5f, 1.f, 1.f, 1.f, + 2.f, 2.f, 2.f, 3.f, 3.f, 3.f, 1.75f, + 1.75f, 1.75f, 3.5f, 3.5f, 3.5f, 5.25f, 5.25f, + 5.25f, 2.75f, 2.75f, 2.75f, 5.5f, 5.5f, 5.5f, + 8.25f, 8.25f, 8.25f, 0.16667f, 0.16667f, 0.16667f, 0.33333f, + 0.33333f, 0.33333f, 0.5f, 0.5f, 0.5f, 0.33333f, 0.33333f, + 0.33333f, 0.66667f, 0.66667f, 0.66667f, 1.f, 1.f, 1.f, + 0.58333f, 0.58333f, 0.58333f, 1.16667f, 1.16667f, 1.16667f, 1.75f, + 1.75f, 1.75f, 0.91667f, 0.91667f, 0.91667f, 1.83333f, 1.83333f, + 1.83333f, 2.75f, 2.75f, 2.75f, 0.33333f, 0.33333f, 0.33333f, + 0.66667f, 0.66667f, 0.66667f, 1.f, 1.f, 1.f, 0.66667f, + 0.66667f, 0.66667f, 1.33333f, 1.33333f, 1.33333f, 2.f, 2.f, + 2.f, 1.16667f, 1.16667f, 1.16667f, 2.33333f, 2.33333f, 2.33333f, + 3.5f, 3.5f, 3.5f, 1.83333f, 1.83333f, 1.83333f, 3.66667f, + 3.66667f, 3.66667f, 5.5f, 5.5f, 5.5f, 0.5f, 0.5f, + 0.5f, 1.f, 1.f, 1.f, 1.5f, 1.5f, 1.5f, + 1.f, 1.f, 1.f, 2.f, 2.f, 2.f, 3.f, + 3.f, 3.f, 1.75f, 1.75f, 1.75f, 3.5f, 3.5f, + 3.5f, 5.25f, 5.25f, 5.25f, 2.75f, 2.75f, 2.75f, + 5.5f, 5.5f, 5.5f, 8.25f, 8.25f, 8.25f}); + input.linspace(1.); + gradO = 2.; + + sd::ops::avgpool3dnew_bp op; + auto results = op.evaluate({&input, &gradO}, {}, + {kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, + paddingMode, 0, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests2, maxpool3d_bp_test1) { - - int bS=2, iD=3,iH=4,iW=3, iC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; - int oD=2,oH=2,oW=2; - int paddingMode = 0; // 1-SAME, 0-VALID - int dataFormat = 0; // 1-NDHWC, 0-NCDHW - - auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); - auto gradO = NDArrayFactory::create('c', {bS, iC, oD, oH, oW}); - auto expected = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.1f, 0.2f, 0.f, 0.3f, 0.4f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.5f, 0.6f, 0.f, 0.7f, 0.8f, - 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.9f, 1.f, 0.f, 1.1f, 1.2f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 1.3f, 1.4f, 0.f, 1.5f, 1.6f, - 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 1.7f, 1.8f, 0.f, 1.9f, 2.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 2.1f, 2.2f, 0.f, 2.3f, 2.4f, - 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 2.5f, 2.6f, 0.f, 2.7f, 2.8f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 2.9f, 3.f, 0.f, 3.1f, 3.2f, - 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 3.3f, 3.4f, 0.f, 3.5f, 3.6f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 3.7f, 3.8f, 0.f, 3.9f, 4.f, - 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 4.1f, 4.2f, 0.f, 4.3f, 4.4f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 4.5f, 4.6f, 0.f, 4.7f, 4.8f}); - - input.linspace(1.); - gradO.linspace(0.1, 0.1); - - sd::ops::maxpool3dnew_bp op; - auto results = op.evaluate({&input, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 1, dataFormat}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + int bS = 2, iD = 3, iH = 4, iW = 3, iC = 3, kD = 2, kH = 3, kW = 2, sD = 1, + sH = 1, sW = 1, pD = 0, pH = 0, pW = 0, dD = 1, dH = 1, dW = 1; + int oD = 2, oH = 2, oW = 2; + int paddingMode = 0; // 1-SAME, 0-VALID + int dataFormat = 0; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); + auto gradO = NDArrayFactory::create('c', {bS, iC, oD, oH, oW}); + auto expected = NDArrayFactory::create( + 'c', {bS, iC, iD, iH, iW}, + {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.1f, 0.2f, 0.f, 0.3f, 0.4f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.5f, 0.6f, 0.f, 0.7f, 0.8f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.9f, 1.f, 0.f, 1.1f, 1.2f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 1.3f, 1.4f, 0.f, 1.5f, 1.6f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 1.7f, 1.8f, 0.f, 1.9f, 2.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 2.1f, 2.2f, 0.f, 2.3f, 2.4f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 2.5f, 2.6f, 0.f, 2.7f, 2.8f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 2.9f, 3.f, 0.f, 3.1f, 3.2f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 3.3f, 3.4f, 0.f, 3.5f, 3.6f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 3.7f, 3.8f, 0.f, 3.9f, 4.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 4.1f, 4.2f, 0.f, 4.3f, 4.4f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 4.5f, 4.6f, 0.f, 4.7f, 4.8f}); + + input.linspace(1.); + gradO.linspace(0.1, 0.1); + + sd::ops::maxpool3dnew_bp op; + auto results = op.evaluate({&input, &gradO}, {}, + {kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, + paddingMode, 1, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests2, maxpool3d_bp_test2) { - - int bS=2, iD=3,iH=4,iW=3, iC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=1,pH=1,pW=1, dD=1,dH=1,dW=1; - int oD=4,oH=4,oW=4; - int paddingMode = 0; // 1-SAME, 0-VALID - int dataFormat = 0; // 1-NDHWC, 0-NCDHW - - auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); - auto gradO = NDArrayFactory::create('c', {bS, iC, oD, oH, oW}); - auto expected = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}, {0.000e+00f, 0.000e+00f, 0.000e+00f, 1.000e-01f, 2.000e-01f, 7.000e-01f, 5.000e-01f, 6.000e-01f, 1.500e+00f, 2.200e+00f, 2.400e+00f, 5.400e+00f, 0.000e+00f, 0.000e+00f, 0.000e+00f, 1.700e+00f, 1.800e+00f, 3.900e+00f, 2.100e+00f, 2.200e+00f, 4.700e+00f, 5.400e+00f, 5.600e+00f, 1.180e+01f, - 0.000e+00f, 0.000e+00f, 0.000e+00f, 8.200e+00f, 8.400e+00f, 1.740e+01f, 9.000e+00f, 9.200e+00f, 1.900e+01f, 2.040e+01f, 2.080e+01f, 4.280e+01f, 0.000e+00f, 0.000e+00f, 0.000e+00f, 6.500e+00f, 6.600e+00f, 1.350e+01f, 6.900e+00f, 7.000e+00f, 1.430e+01f, 1.500e+01f, 1.520e+01f, 3.100e+01f, - 0.000e+00f, 0.000e+00f, 0.000e+00f, 8.100e+00f, 8.200e+00f, 1.670e+01f, 8.500e+00f, 8.600e+00f, 1.750e+01f, 1.820e+01f, 1.840e+01f, 3.740e+01f, 0.000e+00f, 0.000e+00f, 0.000e+00f, 2.100e+01f, 2.120e+01f, 4.300e+01f, 2.180e+01f, 2.200e+01f, 4.460e+01f, 4.600e+01f, 4.640e+01f, 9.400e+01f, - 0.000e+00f, 0.000e+00f, 0.000e+00f, 1.290e+01f, 1.300e+01f, 2.630e+01f, 1.330e+01f, 1.340e+01f, 2.710e+01f, 2.780e+01f, 2.800e+01f, 5.660e+01f, 0.000e+00f, 0.000e+00f, 0.000e+00f, 1.450e+01f, 1.460e+01f, 2.950e+01f, 1.490e+01f, 1.500e+01f, 3.030e+01f, 3.100e+01f, 3.120e+01f, 6.300e+01f, - 0.000e+00f, 0.000e+00f, 0.000e+00f, 3.380e+01f, 3.400e+01f, 6.860e+01f, 3.460e+01f, 3.480e+01f, 7.020e+01f, 7.160e+01f, 7.200e+01f, 1.452e+02f, 0.000e+00f, 0.000e+00f, 0.000e+00f, 1.930e+01f, 1.940e+01f, 3.910e+01f, 1.970e+01f, 1.980e+01f, 3.990e+01f, 4.060e+01f, 4.080e+01f, 8.220e+01f, - 0.000e+00f, 0.000e+00f, 0.000e+00f, 2.090e+01f, 2.100e+01f, 4.230e+01f, 2.130e+01f, 2.140e+01f, 4.310e+01f, 4.380e+01f, 4.400e+01f, 8.860e+01f, 0.000e+00f, 0.000e+00f, 0.000e+00f, 4.660e+01f, 4.680e+01f, 9.420e+01f, 4.740e+01f, 4.760e+01f, 9.580e+01f, 9.720e+01f, 9.760e+01f, 1.964e+02f, - 0.000e+00f, 0.000e+00f, 0.000e+00f, 2.570e+01f, 2.580e+01f, 5.190e+01f, 2.610e+01f, 2.620e+01f, 5.270e+01f, 5.340e+01f, 5.360e+01f, 1.078e+02f, 0.000e+00f, 0.000e+00f, 0.000e+00f, 2.730e+01f, 2.740e+01f, 5.510e+01f, 2.770e+01f, 2.780e+01f, 5.590e+01f, 5.660e+01f, 5.680e+01f, 1.142e+02f, - 0.000e+00f, 0.000e+00f, 0.000e+00f, 5.940e+01f, 5.960e+01f, 1.198e+02f, 6.020e+01f, 6.040e+01f, 1.214e+02f, 1.228e+02f, 1.232e+02f, 2.476e+02f, 0.000e+00f, 0.000e+00f, 0.000e+00f, 3.210e+01f, 3.220e+01f, 6.470e+01f, 3.250e+01f, 3.260e+01f, 6.550e+01f, 6.620e+01f, 6.640e+01f, 1.334e+02f, - 0.000e+00f, 0.000e+00f, 0.000e+00f, 3.370e+01f, 3.380e+01f, 6.790e+01f, 3.410e+01f, 3.420e+01f, 6.870e+01f, 6.940e+01f, 6.960e+01f, 1.398e+02f, 0.000e+00f, 0.000e+00f, 0.000e+00f, 7.220e+01f, 7.240e+01f, 1.454e+02f, 7.300e+01f, 7.320e+01f, 1.470e+02f, 1.484e+02f, 1.488e+02f, 2.988e+02f}); - input.linspace(1.); - gradO.linspace(0.1, 0.1); - - sd::ops::maxpool3dnew_bp op; - auto results = op.evaluate({&input, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 1, dataFormat}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + int bS = 2, iD = 3, iH = 4, iW = 3, iC = 3, kD = 2, kH = 3, kW = 2, sD = 1, + sH = 1, sW = 1, pD = 1, pH = 1, pW = 1, dD = 1, dH = 1, dW = 1; + int oD = 4, oH = 4, oW = 4; + int paddingMode = 0; // 1-SAME, 0-VALID + int dataFormat = 0; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); + auto gradO = NDArrayFactory::create('c', {bS, iC, oD, oH, oW}); + auto expected = NDArrayFactory::create( + 'c', {bS, iC, iD, iH, iW}, + {0.000e+00f, 0.000e+00f, 0.000e+00f, 1.000e-01f, 2.000e-01f, 7.000e-01f, + 5.000e-01f, 6.000e-01f, 1.500e+00f, 2.200e+00f, 2.400e+00f, 5.400e+00f, + 0.000e+00f, 0.000e+00f, 0.000e+00f, 1.700e+00f, 1.800e+00f, 3.900e+00f, + 2.100e+00f, 2.200e+00f, 4.700e+00f, 5.400e+00f, 5.600e+00f, 1.180e+01f, + 0.000e+00f, 0.000e+00f, 0.000e+00f, 8.200e+00f, 8.400e+00f, 1.740e+01f, + 9.000e+00f, 9.200e+00f, 1.900e+01f, 2.040e+01f, 2.080e+01f, 4.280e+01f, + 0.000e+00f, 0.000e+00f, 0.000e+00f, 6.500e+00f, 6.600e+00f, 1.350e+01f, + 6.900e+00f, 7.000e+00f, 1.430e+01f, 1.500e+01f, 1.520e+01f, 3.100e+01f, + 0.000e+00f, 0.000e+00f, 0.000e+00f, 8.100e+00f, 8.200e+00f, 1.670e+01f, + 8.500e+00f, 8.600e+00f, 1.750e+01f, 1.820e+01f, 1.840e+01f, 3.740e+01f, + 0.000e+00f, 0.000e+00f, 0.000e+00f, 2.100e+01f, 2.120e+01f, 4.300e+01f, + 2.180e+01f, 2.200e+01f, 4.460e+01f, 4.600e+01f, 4.640e+01f, 9.400e+01f, + 0.000e+00f, 0.000e+00f, 0.000e+00f, 1.290e+01f, 1.300e+01f, 2.630e+01f, + 1.330e+01f, 1.340e+01f, 2.710e+01f, 2.780e+01f, 2.800e+01f, 5.660e+01f, + 0.000e+00f, 0.000e+00f, 0.000e+00f, 1.450e+01f, 1.460e+01f, 2.950e+01f, + 1.490e+01f, 1.500e+01f, 3.030e+01f, 3.100e+01f, 3.120e+01f, 6.300e+01f, + 0.000e+00f, 0.000e+00f, 0.000e+00f, 3.380e+01f, 3.400e+01f, 6.860e+01f, + 3.460e+01f, 3.480e+01f, 7.020e+01f, 7.160e+01f, 7.200e+01f, 1.452e+02f, + 0.000e+00f, 0.000e+00f, 0.000e+00f, 1.930e+01f, 1.940e+01f, 3.910e+01f, + 1.970e+01f, 1.980e+01f, 3.990e+01f, 4.060e+01f, 4.080e+01f, 8.220e+01f, + 0.000e+00f, 0.000e+00f, 0.000e+00f, 2.090e+01f, 2.100e+01f, 4.230e+01f, + 2.130e+01f, 2.140e+01f, 4.310e+01f, 4.380e+01f, 4.400e+01f, 8.860e+01f, + 0.000e+00f, 0.000e+00f, 0.000e+00f, 4.660e+01f, 4.680e+01f, 9.420e+01f, + 4.740e+01f, 4.760e+01f, 9.580e+01f, 9.720e+01f, 9.760e+01f, 1.964e+02f, + 0.000e+00f, 0.000e+00f, 0.000e+00f, 2.570e+01f, 2.580e+01f, 5.190e+01f, + 2.610e+01f, 2.620e+01f, 5.270e+01f, 5.340e+01f, 5.360e+01f, 1.078e+02f, + 0.000e+00f, 0.000e+00f, 0.000e+00f, 2.730e+01f, 2.740e+01f, 5.510e+01f, + 2.770e+01f, 2.780e+01f, 5.590e+01f, 5.660e+01f, 5.680e+01f, 1.142e+02f, + 0.000e+00f, 0.000e+00f, 0.000e+00f, 5.940e+01f, 5.960e+01f, 1.198e+02f, + 6.020e+01f, 6.040e+01f, 1.214e+02f, 1.228e+02f, 1.232e+02f, 2.476e+02f, + 0.000e+00f, 0.000e+00f, 0.000e+00f, 3.210e+01f, 3.220e+01f, 6.470e+01f, + 3.250e+01f, 3.260e+01f, 6.550e+01f, 6.620e+01f, 6.640e+01f, 1.334e+02f, + 0.000e+00f, 0.000e+00f, 0.000e+00f, 3.370e+01f, 3.380e+01f, 6.790e+01f, + 3.410e+01f, 3.420e+01f, 6.870e+01f, 6.940e+01f, 6.960e+01f, 1.398e+02f, + 0.000e+00f, 0.000e+00f, 0.000e+00f, 7.220e+01f, 7.240e+01f, 1.454e+02f, + 7.300e+01f, 7.320e+01f, 1.470e+02f, 1.484e+02f, 1.488e+02f, 2.988e+02f}); + input.linspace(1.); + gradO.linspace(0.1, 0.1); + + sd::ops::maxpool3dnew_bp op; + auto results = op.evaluate({&input, &gradO}, {}, + {kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, + paddingMode, 1, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests2, maxpool3d_bp_test3) { - - int bS=2, iD=3,iH=4,iW=3, iC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; - int oD=3,oH=4,oW=3; - int paddingMode = 1; // 1-SAME, 0-VALID - int dataFormat = 1; // 1-NDHWC, 0-NCDHW - - auto input = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); - auto gradO = NDArrayFactory::create('c', {bS, oD, oH, oW, iC}); - auto expected = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}, { 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, - 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.1f, 0.2f, 0.3f, 1.1f, 1.3f, 1.5f, - 0.f, 0.f, 0.f, 1.f, 1.1f, 1.2f, 2.9f, 3.1f, 3.3f, 0.f, 0.f, 0.f, 4.7f, 4.9f, 5.1f, 11.2f, 11.6f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, - 0.f, 0.f, 0.f, 11.f, 11.2f, 11.4f, 23.8f, 24.2f, 24.6f, 0.f, 0.f, 0.f, 12.8f, 13.f, 13.2f, 27.4f, 27.8f, 28.2f, 0.f, 0.f, 0.f, 31.f, 31.4f, 31.8f, 65.6f, 66.39999f, 67.2f, - 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, - 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 10.9f, 11.f, 11.1f, 22.7f, 22.9f, 23.1f, - 0.f, 0.f, 0.f, 11.8f, 11.9f, 12.f, 24.5f, 24.7f, 24.9f, 0.f, 0.f, 0.f, 26.3f, 26.5f, 26.7f, 54.4f, 54.8f, 55.2f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, - 0.f, 0.f, 0.f, 32.6f, 32.8f, 33.f, 67.f, 67.4f, 67.8f, 0.f, 0.f, 0.f, 34.4f, 34.6f, 34.8f, 70.6f, 71.f, 71.4f, 0.f, 0.f, 0.f, 74.2f, 74.6f, 75.f, 152.f, 152.8f, 153.6f}); - input.linspace(1.); - gradO.linspace(0.1, 0.1); - - sd::ops::maxpool3dnew_bp op; - auto results = op.evaluate({&input, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 1, dataFormat}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - + int bS = 2, iD = 3, iH = 4, iW = 3, iC = 3, kD = 2, kH = 3, kW = 2, sD = 1, + sH = 1, sW = 1, pD = 0, pH = 0, pW = 0, dD = 1, dH = 1, dW = 1; + int oD = 3, oH = 4, oW = 3; + int paddingMode = 1; // 1-SAME, 0-VALID + int dataFormat = 1; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); + auto gradO = NDArrayFactory::create('c', {bS, oD, oH, oW, iC}); + auto expected = NDArrayFactory::create( + 'c', {bS, iD, iH, iW, iC}, + {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.1f, 0.2f, 0.3f, 1.1f, 1.3f, 1.5f, + 0.f, 0.f, 0.f, 1.f, 1.1f, 1.2f, 2.9f, 3.1f, 3.3f, + 0.f, 0.f, 0.f, 4.7f, 4.9f, 5.1f, 11.2f, 11.6f, 12.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 11.f, 11.2f, 11.4f, 23.8f, 24.2f, 24.6f, + 0.f, 0.f, 0.f, 12.8f, 13.f, 13.2f, 27.4f, 27.8f, 28.2f, + 0.f, 0.f, 0.f, 31.f, 31.4f, 31.8f, 65.6f, 66.39999f, 67.2f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 10.9f, 11.f, 11.1f, 22.7f, 22.9f, 23.1f, + 0.f, 0.f, 0.f, 11.8f, 11.9f, 12.f, 24.5f, 24.7f, 24.9f, + 0.f, 0.f, 0.f, 26.3f, 26.5f, 26.7f, 54.4f, 54.8f, 55.2f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 32.6f, 32.8f, 33.f, 67.f, 67.4f, 67.8f, + 0.f, 0.f, 0.f, 34.4f, 34.6f, 34.8f, 70.6f, 71.f, 71.4f, + 0.f, 0.f, 0.f, 74.2f, 74.6f, 75.f, 152.f, 152.8f, 153.6f}); + input.linspace(1.); + gradO.linspace(0.1, 0.1); + + sd::ops::maxpool3dnew_bp op; + auto results = op.evaluate({&input, &gradO}, {}, + {kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, + paddingMode, 1, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } - ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests2, maxpool3d_bp_test4) { - - int bS=2, iD=3,iH=4,iW=3, iC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; - int oD=3,oH=4,oW=3; - int paddingMode = 0; // 1-SAME, 0-VALID - int dataFormat = 1; // 1-NDHWC, 0-NCDHW - - auto input = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); - auto gradO = NDArrayFactory::create('c', {bS, oD, oH, oW, iC}); - auto expected = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, - 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.1f, 0.2f, 0.3f, 1.1f, 1.3f, 1.5f, 0.f, 0.f, 0.f, 5.7f, 6.f, 6.3f, - 14.1f, 14.7f, 15.3f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 11.f, 11.2f, 11.4f, 23.8f, 24.2f, - 24.6f, 0.f, 0.f, 0.f, 43.8f, 44.4f, 45.f, 93.f, 94.2f, 95.4f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, - 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, - 10.9f, 11.f, 11.1f, 22.7f, 22.9f, 23.1f, 0.f, 0.f, 0.f, 38.1f, 38.4f, 38.7f, 78.9f, 79.5f, 80.1f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, - 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 32.6f, 32.8f, 33.f, 67.f, 67.4f, 67.8f, 0.f, 0.f, 0.f, 108.6f, 109.2f, 109.8f, 222.6f, 223.8f, 225.f}); - input.linspace(1.); - gradO.linspace(0.1, 0.1); - - sd::ops::maxpool3dnew_bp op; - auto results = op.evaluate({&input, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 1, dataFormat}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - + int bS = 2, iD = 3, iH = 4, iW = 3, iC = 3, kD = 2, kH = 3, kW = 2, sD = 1, + sH = 1, sW = 1, pD = 0, pH = 0, pW = 0, dD = 1, dH = 1, dW = 1; + int oD = 3, oH = 4, oW = 3; + int paddingMode = 0; // 1-SAME, 0-VALID + int dataFormat = 1; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); + auto gradO = NDArrayFactory::create('c', {bS, oD, oH, oW, iC}); + auto expected = NDArrayFactory::create( + 'c', {bS, iD, iH, iW, iC}, + {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.1f, 0.2f, 0.3f, 1.1f, 1.3f, 1.5f, + 0.f, 0.f, 0.f, 5.7f, 6.f, 6.3f, 14.1f, 14.7f, 15.3f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 11.f, 11.2f, 11.4f, 23.8f, 24.2f, 24.6f, + 0.f, 0.f, 0.f, 43.8f, 44.4f, 45.f, 93.f, 94.2f, 95.4f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 10.9f, 11.f, 11.1f, 22.7f, 22.9f, 23.1f, + 0.f, 0.f, 0.f, 38.1f, 38.4f, 38.7f, 78.9f, 79.5f, 80.1f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 32.6f, 32.8f, 33.f, 67.f, 67.4f, 67.8f, + 0.f, 0.f, 0.f, 108.6f, 109.2f, 109.8f, 222.6f, 223.8f, 225.f}); + input.linspace(1.); + gradO.linspace(0.1, 0.1); + + sd::ops::maxpool3dnew_bp op; + auto results = op.evaluate({&input, &gradO}, {}, + {kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, + paddingMode, 1, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests2, maxpool2d_bp_1) { - - auto input = NDArrayFactory::create_('c', {bS,iD,iH,iW}); - auto epsilon = NDArrayFactory::create_('c', {bS,iD,oH,oW}); - auto exp = NDArrayFactory::create('c', {bS,iD,iH,iW}); - - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, input); - variableSpace->putVariable(-2, epsilon); - // variableSpace->putVariable(1, &z); - - auto block = new Context(1, variableSpace, false); - block->fillInputs({-1}); - block->fillInputs({-2}); - std::vector* argI = block->getIArguments(); - *argI = {kH,kW, sH,sW, pH,pW, dW,dH, 0, 0, 0}; // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; - - sd::ops::maxpool2d_bp bp; - Nd4jStatus status = bp.execute(block); - ASSERT_EQ(ND4J_STATUS_OK, status); - - auto result = variableSpace->getVariable(block->getNodeId())->getNDArray(); - ASSERT_TRUE(exp.isSameShape(result)); - - delete variableSpace; - delete block; + auto input = NDArrayFactory::create('c', {bS, iD, iH, iW}); + auto epsilon = NDArrayFactory::create('c', {bS, iD, oH, oW}); + auto exp = NDArrayFactory::create('c', {bS, iD, iH, iW}); + + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, input); + variableSpace->putVariable(-2, epsilon); + // variableSpace->putVariable(1, &z); + + auto block = new Context(1, variableSpace, false); + block->fillInputs({-1}); + block->fillInputs({-2}); + block->appendI( + {kH, kW, sH, sW, pH, pW, dW, dH, 0, 0, + 0}); // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad + // Height/Width; 6,7 - dilation Height/Width; 8 - same mode; + + sd::ops::maxpool2d_bp bp; + Nd4jStatus status = bp.execute(block); + ASSERT_EQ(ND4J_STATUS_OK, status); + + auto result = variableSpace->getVariable(block->getNodeId())->getNDArray(); + ASSERT_TRUE(exp.isSameShape(*result)); + + delete variableSpace; + delete block; } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests2, maxpool2d_bp_2) { - - int bS=2, iD=1, iH=4,iW=4, oD=3, kH=2,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oH = (iH - kH - (kH-1)*(dH-1) + 2*pH)/sH + 1; - int oW = (iW - kW - (kW-1)*(dW-1) + 2*pW)/sW + 1; - - // TypeParam epsilonBuff[] = {6., 7., 8., 10., 11., 12., 14., 15., 16., 22., 23., 24., 26., 27., 28., 30., 31., 32.}; - // TypeParam expectedBuff[] = {0., 0., 0., 0.,0., 6., 7., 8.,0.,10.,11.,12.,0.,14.,15.,16.,0., 0., 0., 0.,0.,22.,23.,24.,0.,26.,27.,28.,0.,30.,31.,32.}; - - NDArray input('c', {bS,iD,iH,iW}); - NDArray epsilon('c', {bS,iD,oH,oW}, {6., 7., 8., 10., 11., 12., 14., 15., 16., 22., 23., 24., 26., 27., 28., 30., 31., 32.}); - NDArray expected('c', {bS,iD,iH,iW}, {0., 0., 0., 0.,0., 6., 7., 8.,0.,10.,11.,12.,0.,14.,15.,16.,0., 0., 0., 0.,0.,22.,23.,24.,0.,26.,27.,28.,0.,30.,31.,32.}); - - - input.linspace(1.); - - std::initializer_list argI = {kH,kW, sH,sW, pH,pW, dW,dH, 0, 0, 0}; // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; - - sd::ops::maxpool2d_bp op; - auto results = op.evaluate({&input, &epsilon}, {}, argI); - auto output = results.at(0); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - + int bS = 2, iD = 1, iH = 4, iW = 4, oD = 3, kH = 2, kW = 2, sH = 1, sW = 1, + pH = 0, pW = 0, dH = 1, dW = 1; + int oH = (iH - kH - (kH - 1) * (dH - 1) + 2 * pH) / sH + 1; + int oW = (iW - kW - (kW - 1) * (dW - 1) + 2 * pW) / sW + 1; + + // TypeParam epsilonBuff[] = + // {6., 7., 8., 10., 11., 12., 14., 15., 16., 22., 23., 24., 26., 27., 28., 30., + // 31., 32.}; TypeParam expectedBuff[] = {0., 0., 0., + // 0.,0., 6., 7., 8.,0.,10.,11.,12.,0.,14.,15.,16.,0., 0., 0., + // 0.,0.,22.,23.,24.,0.,26.,27.,28.,0.,30.,31.,32.}; + + NDArray input('c', {bS, iD, iH, iW}); + NDArray epsilon('c', {bS, iD, oH, oW}, + {6., 7., 8., 10., 11., 12., 14., 15., 16., 22., 23., 24., 26., + 27., 28., 30., 31., 32.}); + NDArray expected('c', {bS, iD, iH, iW}, + {0., 0., 0., 0., 0., 6., 7., 8., 0., 10., 11., + 12., 0., 14., 15., 16., 0., 0., 0., 0., 0., 22., + 23., 24., 0., 26., 27., 28., 0., 30., 31., 32.}); + + input.linspace(1.); + + std::initializer_list argI = { + kH, kW, sH, sW, pH, pW, + dW, dH, 0, 0, 0}; // 0,1 - kernel Height/Width; 2,3 - stride + // Height/Width; 4,5 - pad Height/Width; 6,7 - + // dilation Height/Width; 8 - same mode; + + sd::ops::maxpool2d_bp op; + auto results = op.evaluate({&input, &epsilon}, {}, argI); + auto output = results.at(0); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests2, maxpool2d_bp_3) { - - int bS=2, iH=4,iW=3, iC=3, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oH=2,oW=2; - int paddingMode = 0; // 1-SAME, 0-VALID - int dataFormat = 0; // 1-NDHWC, 0-NCDHW - - auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); - auto gradO = NDArrayFactory::create('c', {bS, iC, oH, oW}); - auto expected = NDArrayFactory::create('c', {bS, iC, iH, iW}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.1f, 0.2f, 0.f, 0.3f, 0.4f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.5f, 0.6f, 0.f, 0.7f, 0.8f, - 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.9f, 1.f, 0.f, 1.1f, 1.2f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 1.3f, 1.4f, 0.f, 1.5f, 1.6f, - 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 1.7f, 1.8f, 0.f, 1.9f, 2.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 2.1f, 2.2f, 0.f, 2.3f, 2.4f}); - input.linspace(1.); - gradO.linspace(0.1, 0.1); - - sd::ops::maxpool2d_bp op; - auto results = op.evaluate({&input, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, 1, dataFormat}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + int bS = 2, iH = 4, iW = 3, iC = 3, kH = 3, kW = 2, sH = 1, sW = 1, pH = 0, + pW = 0, dH = 1, dW = 1; + int oH = 2, oW = 2; + int paddingMode = 0; // 1-SAME, 0-VALID + int dataFormat = 0; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); + auto gradO = NDArrayFactory::create('c', {bS, iC, oH, oW}); + auto expected = NDArrayFactory::create( + 'c', {bS, iC, iH, iW}, + {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.1f, 0.2f, 0.f, 0.3f, 0.4f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.5f, 0.6f, 0.f, 0.7f, 0.8f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.9f, 1.f, 0.f, 1.1f, 1.2f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 1.3f, 1.4f, 0.f, 1.5f, 1.6f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 1.7f, 1.8f, 0.f, 1.9f, 2.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 2.1f, 2.2f, 0.f, 2.3f, 2.4f}); + input.linspace(1.); + gradO.linspace(0.1, 0.1); + + sd::ops::maxpool2d_bp op; + auto results = + op.evaluate({&input, &gradO}, {}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, 1, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests2, maxpool2d_bp_4) { - - int bS=2, iH=4,iW=3, iC=3, kH=3,kW=2, sH=1,sW=1, pH=1,pW=1, dH=1,dW=1; - int oH=4,oW=4; - int paddingMode = 0; // 1-SAME, 0-VALID - int dataFormat = 0; // 1-NDHWC, 0-NCDHW - - auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); - auto gradO = NDArrayFactory::create('c', {bS, iC, oH, oW}); - auto expected = NDArrayFactory::create('c', {bS, iC, iH, iW}, {0.f, 0.f, 0.f, 0.1f, 0.2f, 0.7f, 0.5f, 0.6f, 1.5f, 2.2f, 2.4f, 5.4f, 0.f, 0.f, 0.f, 1.7f, 1.8f, 3.9f, 2.1f, 2.2f, 4.7f, 5.4f, 5.6f, 11.8f, - 0.f, 0.f, 0.f, 3.3f, 3.4f, 7.1f, 3.7f, 3.8f, 7.9f, 8.6f, 8.8f, 18.2f, 0.f, 0.f, 0.f, 4.9f, 5.f, 10.3f, 5.3f, 5.4f, 11.1f, 11.8f, 12.f, 24.6f, - 0.f, 0.f, 0.f, 6.5f, 6.6f, 13.5f, 6.9f, 7.f, 14.3f, 15.f, 15.2f, 31.f, 0.f, 0.f, 0.f, 8.1f, 8.2f, 16.7f, 8.5f, 8.6f, 17.5f, 18.2f, 18.4f, 37.4f}); - input.linspace(1.); - gradO.linspace(0.1, 0.1); - - sd::ops::maxpool2d_bp op; - auto results = op.evaluate({&input, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, 1, dataFormat}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - + int bS = 2, iH = 4, iW = 3, iC = 3, kH = 3, kW = 2, sH = 1, sW = 1, pH = 1, + pW = 1, dH = 1, dW = 1; + int oH = 4, oW = 4; + int paddingMode = 0; // 1-SAME, 0-VALID + int dataFormat = 0; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); + auto gradO = NDArrayFactory::create('c', {bS, iC, oH, oW}); + auto expected = NDArrayFactory::create( + 'c', {bS, iC, iH, iW}, + {0.f, 0.f, 0.f, 0.1f, 0.2f, 0.7f, 0.5f, 0.6f, 1.5f, 2.2f, 2.4f, + 5.4f, 0.f, 0.f, 0.f, 1.7f, 1.8f, 3.9f, 2.1f, 2.2f, 4.7f, 5.4f, + 5.6f, 11.8f, 0.f, 0.f, 0.f, 3.3f, 3.4f, 7.1f, 3.7f, 3.8f, 7.9f, + 8.6f, 8.8f, 18.2f, 0.f, 0.f, 0.f, 4.9f, 5.f, 10.3f, 5.3f, 5.4f, + 11.1f, 11.8f, 12.f, 24.6f, 0.f, 0.f, 0.f, 6.5f, 6.6f, 13.5f, 6.9f, + 7.f, 14.3f, 15.f, 15.2f, 31.f, 0.f, 0.f, 0.f, 8.1f, 8.2f, 16.7f, + 8.5f, 8.6f, 17.5f, 18.2f, 18.4f, 37.4f}); + input.linspace(1.); + gradO.linspace(0.1, 0.1); + + sd::ops::maxpool2d_bp op; + auto results = + op.evaluate({&input, &gradO}, {}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, 1, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests2, maxpool2d_bp_5) { - - int bS=2, iH=4,iW=3, iC=3, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oH=4,oW=3; - int paddingMode = 1; // 1-SAME, 0-VALID - int dataFormat = 1; // 1-NDHWC, 0-NCDHW - - auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); - auto gradO = NDArrayFactory::create('c', {bS, oH, oW, iC}); - auto expected = NDArrayFactory::create('c', {bS, iH, iW, iC}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.1f, 0.2f, 0.3f, 1.1f, 1.3f, 1.5f, 0.f, 0.f, 0.f, 1.f, 1.1f, 1.2f, 2.9f, 3.1f, 3.3f, - 0.f, 0.f, 0.f, 4.7f, 4.9f, 5.1f, 11.2f, 11.6f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 3.7f, 3.8f, 3.9f, 8.3f, 8.5f, 8.7f, - 0.f, 0.f, 0.f, 4.6f, 4.7f, 4.8f, 10.1f, 10.3f, 10.5f, 0.f, 0.f, 0.f, 11.9f, 12.1f, 12.3f, 25.6f, 26.f, 26.4f}); - input.linspace(1.); - gradO.linspace(0.1, 0.1); - - sd::ops::maxpool2d_bp op; - auto results = op.evaluate({&input, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, 1, dataFormat}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - + int bS = 2, iH = 4, iW = 3, iC = 3, kH = 3, kW = 2, sH = 1, sW = 1, pH = 0, + pW = 0, dH = 1, dW = 1; + int oH = 4, oW = 3; + int paddingMode = 1; // 1-SAME, 0-VALID + int dataFormat = 1; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); + auto gradO = NDArrayFactory::create('c', {bS, oH, oW, iC}); + auto expected = NDArrayFactory::create( + 'c', {bS, iH, iW, iC}, + {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.1f, 0.2f, 0.3f, 1.1f, 1.3f, 1.5f, 0.f, 0.f, 0.f, 1.f, + 1.1f, 1.2f, 2.9f, 3.1f, 3.3f, 0.f, 0.f, 0.f, 4.7f, 4.9f, 5.1f, + 11.2f, 11.6f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 3.7f, 3.8f, 3.9f, 8.3f, 8.5f, 8.7f, 0.f, + 0.f, 0.f, 4.6f, 4.7f, 4.8f, 10.1f, 10.3f, 10.5f, 0.f, 0.f, 0.f, + 11.9f, 12.1f, 12.3f, 25.6f, 26.f, 26.4f}); + input.linspace(1.); + gradO.linspace(0.1, 0.1); + + sd::ops::maxpool2d_bp op; + auto results = + op.evaluate({&input, &gradO}, {}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, 1, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests2, maxpool2d_bp_6) { - - int bS=2, iH=4,iW=3, iC=3, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oH=2,oW=2; - int paddingMode = 0; // 1-SAME, 0-VALID - int dataFormat = 1; // 1-NDHWC, 0-NCDHW - - auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); - auto gradO = NDArrayFactory::create('c', {bS, oH, oW, iC}); - auto expected = NDArrayFactory::create('c', {bS, iH, iW, iC}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, - 0.f, 0.f, 0.f, 0.7f, 0.8f, 0.9f, 1.f, 1.1f, 1.2f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, - 0.f, 0.f, 0.f, 1.3f, 1.4f, 1.5f, 1.6f, 1.7f, 1.8f, 0.f, 0.f, 0.f, 1.9f, 2.f, 2.1f, 2.2f, 2.3f, 2.4f}); - input.linspace(1.); - gradO.linspace(0.1, 0.1); - - sd::ops::maxpool2d_bp op; - auto results = op.evaluate({&input, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, 1, dataFormat}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - + int bS = 2, iH = 4, iW = 3, iC = 3, kH = 3, kW = 2, sH = 1, sW = 1, pH = 0, + pW = 0, dH = 1, dW = 1; + int oH = 2, oW = 2; + int paddingMode = 0; // 1-SAME, 0-VALID + int dataFormat = 1; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); + auto gradO = NDArrayFactory::create('c', {bS, oH, oW, iC}); + auto expected = NDArrayFactory::create( + 'c', {bS, iH, iW, iC}, + {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.1f, 0.2f, 0.3f, + 0.4f, 0.5f, 0.6f, 0.f, 0.f, 0.f, 0.7f, 0.8f, 0.9f, 1.f, 1.1f, 1.2f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 1.3f, 1.4f, 1.5f, + 1.6f, 1.7f, 1.8f, 0.f, 0.f, 0.f, 1.9f, 2.f, 2.1f, 2.2f, 2.3f, 2.4f}); + input.linspace(1.); + gradO.linspace(0.1, 0.1); + + sd::ops::maxpool2d_bp op; + auto results = + op.evaluate({&input, &gradO}, {}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, 1, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests2, maxpool2d_bp_7) { - - int bS=2, iH=56,iW=56, iC=3, kH=2,kW=2, sH=2,sW=2, pH=0,pW=0, dH=1,dW=1; - int oH=28,oW=28; - int paddingMode = 1; // 1-SAME, 0-VALID - int dataFormat = 0; // 1-NDHWC, 0-NCDHW - - auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); - auto gradO = NDArrayFactory::create('c', {bS, iC, oH, oW}); - - input.linspace(1.); - gradO.linspace(0.1, 0.1); - - sd::ops::maxpool2d_bp op; - auto results = op.evaluate({&input, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, 1, dataFormat}); - // auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - // ASSERT_TRUE(expected.isSameShape(output)); - // ASSERT_TRUE(expected.equalsTo(output)); - + int bS = 2, iH = 56, iW = 56, iC = 3, kH = 2, kW = 2, sH = 2, sW = 2, pH = 0, + pW = 0, dH = 1, dW = 1; + int oH = 28, oW = 28; + int paddingMode = 1; // 1-SAME, 0-VALID + int dataFormat = 0; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); + auto gradO = NDArrayFactory::create('c', {bS, iC, oH, oW}); + + input.linspace(1.); + gradO.linspace(0.1, 0.1); + + sd::ops::maxpool2d_bp op; + auto results = + op.evaluate({&input, &gradO}, {}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, 1, dataFormat}); + // auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + // ASSERT_TRUE(expected.isSameShape(output)); + // ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests2, avgpool2d_bp_1) { - - auto input = NDArrayFactory::create_('c', {bS,iD,iH,iW}); - auto epsilon = NDArrayFactory::create_('c', {bS,iD,oH,oW}); - auto exp = NDArrayFactory::create('c', {bS,iD,iH,iW}); - - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, input); - variableSpace->putVariable(-2, epsilon); - // variableSpace->putVariable(1, &z); - - auto block = new Context(1, variableSpace, false); - block->fillInputs({-1}); - block->fillInputs({-2}); - std::vector* argI = block->getIArguments(); - *argI = {kH,kW, sH,sW, pH,pW, dW,dH, 0, 1, 0}; // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode, 9 - extraParam0 (unnecessary for avg mode), 10 - data format - - sd::ops::avgpool2d_bp bp; - Nd4jStatus status = bp.execute(block); - ASSERT_EQ(ND4J_STATUS_OK, status); - - auto result = variableSpace->getVariable(block->getNodeId())->getNDArray(); - ASSERT_TRUE(exp.isSameShape(result)); - - delete variableSpace; - delete block; + auto input = NDArrayFactory::create('c', {bS, iD, iH, iW}); + auto epsilon = NDArrayFactory::create('c', {bS, iD, oH, oW}); + auto exp = NDArrayFactory::create('c', {bS, iD, iH, iW}); + + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, input); + variableSpace->putVariable(-2, epsilon); + // variableSpace->putVariable(1, &z); + + auto block = new Context(1, variableSpace, false); + block->fillInputs({-1}); + block->fillInputs({-2}); + block->appendI( + {kH, kW, sH, sW, pH, pW, dW, dH, 0, 1, + 0}); // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad + // Height/Width; 6,7 - dilation Height/Width; 8 - same mode, 9 - + // extraParam0 (unnecessary for avg mode), 10 - data format + + sd::ops::avgpool2d_bp bp; + Nd4jStatus status = bp.execute(block); + ASSERT_EQ(ND4J_STATUS_OK, status); + + auto result = variableSpace->getVariable(block->getNodeId())->getNDArray(); + ASSERT_TRUE(exp.isSameShape(*result)); + + delete variableSpace; + delete block; } ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests2, avgpool2d_bp_2) { - - int bS=2, iD=1, iH=4,iW=4, oD=3, kH=2,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oH = (iH - kH - (kH-1)*(dH-1) + 2*pH)/sH + 1; - int oW = (iW - kW - (kW-1)*(dW-1) + 2*pW)/sW + 1; - - // TypeParam epsilonBuff[] = {3.5 , 4.5 , 5.5, 7.5 , 8.5 , 9.5, 11.5, 12.5, 13.5, 19.5, 20.5, 21.5, 23.5, 24.5, 25.5, 27.5, 28.5, 29.5}; - // TypeParam expectedBuff[] = {0.875, 2., 2.5,1.375, 2.75 , 6., 7., 3.75, 4.75 ,10., 11., 5.75, 2.875, 6., 6.5, 3.375, 4.875, 10.,10.5, 5.375, 10.75, 22.,23., 11.75, 12.75, 26.,27., 13.75, 6.875, 14.,14.5, 7.375}; - - auto input = NDArrayFactory::create('c', {bS,iD,iH,iW}); - auto epsilon = NDArrayFactory::create('c', {bS,iD,oH,oW}, {3.5f, 4.5f, 5.5f, 7.5f, 8.5f, 9.5f, 11.5f, 12.5f, 13.5f, 19.5f, 20.5f, 21.5f, 23.5f, 24.5f, 25.5f, 27.5f, 28.5f, 29.5f}); - auto expected = NDArrayFactory::create('c', {bS,iD,iH,iW}, {0.875f, 2.f, 2.5f, 1.375f, 2.75f, 6.f, 7.f, 3.75f, 4.75f, 10.f, 11.f, 5.75f, 2.875f, 6.f, 6.5f, 3.375f, 4.875f, 10.f, 10.5f, 5.375f, 10.75f, 22.f, 23.f, 11.75f, 12.75f, 26.f, 27.f, 13.75f, 6.875f, 14.f, 14.5f, 7.375f}); - - input.linspace(1.); - - std::initializer_list argI = {kH,kW, sH,sW, pH,pW, dW,dH, 1, 1, 0}; - - sd::ops::avgpool2d_bp op; - auto results = op.evaluate({&input, &epsilon}, {}, argI); - auto output = results.at(0); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - + int bS = 2, iD = 1, iH = 4, iW = 4, oD = 3, kH = 2, kW = 2, sH = 1, sW = 1, + pH = 0, pW = 0, dH = 1, dW = 1; + int oH = (iH - kH - (kH - 1) * (dH - 1) + 2 * pH) / sH + 1; + int oW = (iW - kW - (kW - 1) * (dW - 1) + 2 * pW) / sW + 1; + + // TypeParam epsilonBuff[] = {3.5 , 4.5 , 5.5, 7.5 , 8.5 + // , 9.5, 11.5, 12.5, 13.5, 19.5, 20.5, 21.5, 23.5, 24.5, 25.5, 27.5, 28.5, 29.5}; + // TypeParam expectedBuff[] = {0.875, 2., 2.5,1.375, 2.75 + // , 6., 7., 3.75, 4.75 + // ,10., 11., 5.75, 2.875, 6., 6.5, 3.375, 4.875, 10.,10.5, 5.375, 10.75, 22.,23., + // 11.75, 12.75, 26.,27., 13.75, 6.875, 14.,14.5, 7.375}; + + auto input = NDArrayFactory::create('c', {bS, iD, iH, iW}); + auto epsilon = NDArrayFactory::create( + 'c', {bS, iD, oH, oW}, + {3.5f, 4.5f, 5.5f, 7.5f, 8.5f, 9.5f, 11.5f, 12.5f, 13.5f, 19.5f, 20.5f, + 21.5f, 23.5f, 24.5f, 25.5f, 27.5f, 28.5f, 29.5f}); + auto expected = NDArrayFactory::create( + 'c', {bS, iD, iH, iW}, + {0.875f, 2.f, 2.5f, 1.375f, 2.75f, 6.f, 7.f, 3.75f, + 4.75f, 10.f, 11.f, 5.75f, 2.875f, 6.f, 6.5f, 3.375f, + 4.875f, 10.f, 10.5f, 5.375f, 10.75f, 22.f, 23.f, 11.75f, + 12.75f, 26.f, 27.f, 13.75f, 6.875f, 14.f, 14.5f, 7.375f}); + + input.linspace(1.); + + std::initializer_list argI = {kH, kW, sH, sW, pH, pW, + dW, dH, 1, 1, 0}; + + sd::ops::avgpool2d_bp op; + auto results = op.evaluate({&input, &epsilon}, {}, argI); + auto output = results.at(0); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests2, avgpool2d_bp_3) { - - int bS=2, iH=4,iW=3, iC=3, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oH=2,oW=2; - int paddingMode = 0; // 1-SAME, 0-VALID - int dataFormat = 0; // 1-NDHWC, 0-NCDHW - - auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); - auto gradO = NDArrayFactory::create('c', {bS, iC, oH, oW}); - auto expected = NDArrayFactory::create('c', {bS, iC, iH, iW}, {0.016667f, 0.05f, 0.033333f, 0.066667f, 0.166667f, 0.1f, 0.066667f, 0.166667f, 0.1f, 0.05f, 0.116667f, 0.066667f, - 0.083333f, 0.183333f, 0.1f, 0.2f, 0.433333f, 0.233333f, 0.2f, 0.433333f, 0.233333f, 0.116667f, 0.25f, 0.133333f, - 0.15f, 0.316667f, 0.166667f, 0.333333f, 0.7f, 0.366667f, 0.333333f, 0.7f, 0.366667f, 0.183333f, 0.383333f, 0.2f, - 0.216667f, 0.45f, 0.233333f, 0.466667f, 0.966667f, 0.5f, 0.466667f, 0.966667f, 0.5f, 0.25f, 0.516667f, 0.266667f, - 0.283333f, 0.583333f, 0.3f, 0.6f, 1.233333f, 0.633333f, 0.6f, 1.233333f, 0.633333f, 0.316667f, 0.65f, 0.333333f, - 0.35f, 0.716667f, 0.366667f, 0.733333f, 1.5f, 0.766667f, 0.733333f, 1.5f, 0.766667f, 0.383333f, 0.783333f, 0.4f }); - input.linspace(1.); - gradO.linspace(0.1, 0.1); - - sd::ops::avgpool2d_bp op; - auto results = op.evaluate({&input, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, 1, dataFormat}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - + int bS = 2, iH = 4, iW = 3, iC = 3, kH = 3, kW = 2, sH = 1, sW = 1, pH = 0, + pW = 0, dH = 1, dW = 1; + int oH = 2, oW = 2; + int paddingMode = 0; // 1-SAME, 0-VALID + int dataFormat = 0; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); + auto gradO = NDArrayFactory::create('c', {bS, iC, oH, oW}); + auto expected = NDArrayFactory::create( + 'c', {bS, iC, iH, iW}, + {0.016667f, 0.05f, 0.033333f, 0.066667f, 0.166667f, 0.1f, + 0.066667f, 0.166667f, 0.1f, 0.05f, 0.116667f, 0.066667f, + 0.083333f, 0.183333f, 0.1f, 0.2f, 0.433333f, 0.233333f, + 0.2f, 0.433333f, 0.233333f, 0.116667f, 0.25f, 0.133333f, + 0.15f, 0.316667f, 0.166667f, 0.333333f, 0.7f, 0.366667f, + 0.333333f, 0.7f, 0.366667f, 0.183333f, 0.383333f, 0.2f, + 0.216667f, 0.45f, 0.233333f, 0.466667f, 0.966667f, 0.5f, + 0.466667f, 0.966667f, 0.5f, 0.25f, 0.516667f, 0.266667f, + 0.283333f, 0.583333f, 0.3f, 0.6f, 1.233333f, 0.633333f, + 0.6f, 1.233333f, 0.633333f, 0.316667f, 0.65f, 0.333333f, + 0.35f, 0.716667f, 0.366667f, 0.733333f, 1.5f, 0.766667f, + 0.733333f, 1.5f, 0.766667f, 0.383333f, 0.783333f, 0.4f}); + input.linspace(1.); + gradO.linspace(0.1, 0.1); + + sd::ops::avgpool2d_bp op; + auto results = + op.evaluate({&input, &gradO}, {}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, 1, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } - ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests2, avgpool2d_bp_4) { - - int bS=2, iH=4,iW=3, iC=3, kH=3,kW=2, sH=1,sW=1, pH=1,pW=1, dH=1,dW=1; - int oH=4,oW=4; - int paddingMode = 0; // 1-SAME, 0-VALID - int dataFormat = 0; // 1-NDHWC, 0-NCDHW - - auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); - auto gradO = NDArrayFactory::create('c', {bS, iC, oH, oW}); - auto expected = NDArrayFactory::create('c', {bS, iC, iH, iW}, {0.233333f, 0.3f, 0.366667f, 0.55f, 0.65f, 0.75f, 0.95f, 1.05f, 1.15f, 0.766667f, 0.833333f, 0.9f, - 1.3f, 1.366667f, 1.433333f, 2.15f, 2.25f, 2.35f, 2.55f, 2.65f, 2.75f, 1.833333f, 1.9f, 1.966667f, - 2.366667f, 2.433333f, 2.5f, 3.75f, 3.85f, 3.95f, 4.15f, 4.25f, 4.35f, 2.9f, 2.966667f, 3.033333f, - 3.433333f, 3.5f, 3.566667f, 5.35f, 5.45f, 5.55f, 5.75f, 5.85f, 5.95f, 3.966667f, 4.033333f, 4.1f, - 4.5f, 4.566667f, 4.633333f, 6.95f, 7.05f, 7.15f, 7.35f, 7.45f, 7.55f, 5.033333f, 5.1f, 5.166667f, - 5.566667f, 5.633333f, 5.7f, 8.549999f, 8.65f, 8.75f, 8.95f, 9.05f, 9.150001f, 6.1f, 6.166667f, 6.233334f}); - input.linspace(1.); - gradO.linspace(0.1, 0.1); - - sd::ops::avgpool2d_bp op; - auto results = op.evaluate({&input, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, 1, dataFormat}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - + int bS = 2, iH = 4, iW = 3, iC = 3, kH = 3, kW = 2, sH = 1, sW = 1, pH = 1, + pW = 1, dH = 1, dW = 1; + int oH = 4, oW = 4; + int paddingMode = 0; // 1-SAME, 0-VALID + int dataFormat = 0; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); + auto gradO = NDArrayFactory::create('c', {bS, iC, oH, oW}); + auto expected = NDArrayFactory::create( + 'c', {bS, iC, iH, iW}, + {0.233333f, 0.3f, 0.366667f, 0.55f, 0.65f, 0.75f, + 0.95f, 1.05f, 1.15f, 0.766667f, 0.833333f, 0.9f, + 1.3f, 1.366667f, 1.433333f, 2.15f, 2.25f, 2.35f, + 2.55f, 2.65f, 2.75f, 1.833333f, 1.9f, 1.966667f, + 2.366667f, 2.433333f, 2.5f, 3.75f, 3.85f, 3.95f, + 4.15f, 4.25f, 4.35f, 2.9f, 2.966667f, 3.033333f, + 3.433333f, 3.5f, 3.566667f, 5.35f, 5.45f, 5.55f, + 5.75f, 5.85f, 5.95f, 3.966667f, 4.033333f, 4.1f, + 4.5f, 4.566667f, 4.633333f, 6.95f, 7.05f, 7.15f, + 7.35f, 7.45f, 7.55f, 5.033333f, 5.1f, 5.166667f, + 5.566667f, 5.633333f, 5.7f, 8.549999f, 8.65f, 8.75f, + 8.95f, 9.05f, 9.150001f, 6.1f, 6.166667f, 6.233334f}); + input.linspace(1.); + gradO.linspace(0.1, 0.1); + + sd::ops::avgpool2d_bp op; + auto results = + op.evaluate({&input, &gradO}, {}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, 1, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } - //////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests2, avgpool2d_bp_5) { - - int bS=2, iH=4,iW=3, iC=3, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oH=4,oW=3; - int paddingMode = 1; // 1-SAME, 0-VALID - int dataFormat = 1; // 1-NDHWC, 0-NCDHW - - auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); - auto gradO = NDArrayFactory::create('c', {bS, oH, oW, iC}); - auto expected = NDArrayFactory::create('c', {bS, iH, iW, iC}, {0.19167f, 0.23333f, 0.275f, 0.50833f, 0.59167f, 0.675f, 1.2f, 1.325f, 1.45f, 0.50833f, 0.56667f, 0.625f, 1.19167f, 1.30833f, 1.425f, 2.4f, 2.575f, 2.75f, - 1.18333f, 1.24167f, 1.3f, 2.54167f, 2.65833f, 2.775f, 4.425f, 4.6f, 4.775f, 1.01667f, 1.05833f, 1.1f, 2.15833f, 2.24167f, 2.325f, 3.675f, 3.8f, 3.925f, - 1.69167f, 1.73333f, 1.775f, 3.50833f, 3.59167f, 3.675f, 5.7f, 5.825f, 5.95f, 2.60833f, 2.66667f, 2.725f, 5.39167f, 5.50833f, 5.625f, 8.7f, 8.875f, 9.05f, - 3.28333f, 3.34167f, 3.4f, 6.74167f, 6.85833f, 6.975f, 10.725f, 10.9f, 11.075f, 2.51667f, 2.55833f, 2.6f, 5.15833f, 5.24167f, 5.325f, 8.175f, 8.3f, 8.425f}); - input.linspace(1.); - gradO.linspace(0.1, 0.1); - - sd::ops::avgpool2d_bp op; - auto results = op.evaluate({&input, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, 0, dataFormat}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - + int bS = 2, iH = 4, iW = 3, iC = 3, kH = 3, kW = 2, sH = 1, sW = 1, pH = 0, + pW = 0, dH = 1, dW = 1; + int oH = 4, oW = 3; + int paddingMode = 1; // 1-SAME, 0-VALID + int dataFormat = 1; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); + auto gradO = NDArrayFactory::create('c', {bS, oH, oW, iC}); + auto expected = NDArrayFactory::create( + 'c', {bS, iH, iW, iC}, + {0.19167f, 0.23333f, 0.275f, 0.50833f, 0.59167f, 0.675f, 1.2f, + 1.325f, 1.45f, 0.50833f, 0.56667f, 0.625f, 1.19167f, 1.30833f, + 1.425f, 2.4f, 2.575f, 2.75f, 1.18333f, 1.24167f, 1.3f, + 2.54167f, 2.65833f, 2.775f, 4.425f, 4.6f, 4.775f, 1.01667f, + 1.05833f, 1.1f, 2.15833f, 2.24167f, 2.325f, 3.675f, 3.8f, + 3.925f, 1.69167f, 1.73333f, 1.775f, 3.50833f, 3.59167f, 3.675f, + 5.7f, 5.825f, 5.95f, 2.60833f, 2.66667f, 2.725f, 5.39167f, + 5.50833f, 5.625f, 8.7f, 8.875f, 9.05f, 3.28333f, 3.34167f, + 3.4f, 6.74167f, 6.85833f, 6.975f, 10.725f, 10.9f, 11.075f, + 2.51667f, 2.55833f, 2.6f, 5.15833f, 5.24167f, 5.325f, 8.175f, + 8.3f, 8.425f}); + input.linspace(1.); + gradO.linspace(0.1, 0.1); + + sd::ops::avgpool2d_bp op; + auto results = + op.evaluate({&input, &gradO}, {}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, 0, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } - ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests2, avgpool2d_bp_6) { - - int bS=2, iH=4,iW=3, iC=3, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oH=2,oW=2; - int paddingMode = 0; // 1-SAME, 0-VALID - int dataFormat = 1; // 1-NDHWC, 0-NCDHW - - auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); - auto gradO = NDArrayFactory::create('c', {bS, oH, oW, iC}); - auto expected = NDArrayFactory::create('c', {bS, iH, iW, iC}, {0.01667f, 0.03333f, 0.05f, 0.08333f, 0.11667f, 0.15f, 0.06667f, 0.08333f, 0.1f, 0.13333f, 0.16667f, 0.2f, 0.36667f, 0.43333f, 0.5f, 0.23333f, 0.26667f, 0.3f, - 0.13333f, 0.16667f, 0.2f, 0.36667f, 0.43333f, 0.5f, 0.23333f, 0.26667f, 0.3f, 0.11667f, 0.13333f, 0.15f, 0.28333f, 0.31667f, 0.35f, 0.16667f, 0.18333f, 0.2f, - 0.21667f, 0.23333f, 0.25f, 0.48333f, 0.51667f, 0.55f, 0.26667f, 0.28333f, 0.3f, 0.53333f, 0.56667f, 0.6f, 1.16667f, 1.23333f, 1.3f, 0.63333f, 0.66667f, 0.7f, - 0.53333f, 0.56667f, 0.6f, 1.16667f, 1.23333f, 1.3f, 0.63333f, 0.66667f, 0.7f, 0.31667f, 0.33333f, 0.35f, 0.68333f, 0.71667f, 0.75f, 0.36667f, 0.38333f, 0.4f}); - input.linspace(1.); - gradO.linspace(0.1, 0.1); - - sd::ops::avgpool2d_bp op; - auto results = op.evaluate({&input, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, 1, dataFormat}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - + int bS = 2, iH = 4, iW = 3, iC = 3, kH = 3, kW = 2, sH = 1, sW = 1, pH = 0, + pW = 0, dH = 1, dW = 1; + int oH = 2, oW = 2; + int paddingMode = 0; // 1-SAME, 0-VALID + int dataFormat = 1; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); + auto gradO = NDArrayFactory::create('c', {bS, oH, oW, iC}); + auto expected = NDArrayFactory::create( + 'c', {bS, iH, iW, iC}, + {0.01667f, 0.03333f, 0.05f, 0.08333f, 0.11667f, 0.15f, 0.06667f, + 0.08333f, 0.1f, 0.13333f, 0.16667f, 0.2f, 0.36667f, 0.43333f, + 0.5f, 0.23333f, 0.26667f, 0.3f, 0.13333f, 0.16667f, 0.2f, + 0.36667f, 0.43333f, 0.5f, 0.23333f, 0.26667f, 0.3f, 0.11667f, + 0.13333f, 0.15f, 0.28333f, 0.31667f, 0.35f, 0.16667f, 0.18333f, + 0.2f, 0.21667f, 0.23333f, 0.25f, 0.48333f, 0.51667f, 0.55f, + 0.26667f, 0.28333f, 0.3f, 0.53333f, 0.56667f, 0.6f, 1.16667f, + 1.23333f, 1.3f, 0.63333f, 0.66667f, 0.7f, 0.53333f, 0.56667f, + 0.6f, 1.16667f, 1.23333f, 1.3f, 0.63333f, 0.66667f, 0.7f, + 0.31667f, 0.33333f, 0.35f, 0.68333f, 0.71667f, 0.75f, 0.36667f, + 0.38333f, 0.4f}); + input.linspace(1.); + gradO.linspace(0.1, 0.1); + + sd::ops::avgpool2d_bp op; + auto results = + op.evaluate({&input, &gradO}, {}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, 1, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests2, pnormpool2d_bp_1) { - - auto input = NDArrayFactory::create_('c', {bS,iD,iH,iW}); - auto epsilon = NDArrayFactory::create_('c', {bS,iD,oH,oW}); - auto exp = NDArrayFactory::create('c', {bS,iD,iH,iW}); - - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, input); - variableSpace->putVariable(-2, epsilon); - // variableSpace->putVariable(1, &z); - - auto block = new Context(1, variableSpace, false); - block->fillInputs({-1}); - block->fillInputs({-2}); - auto argI = block->getIArguments(); - *argI = {kH,kW, sH,sW, pH,pW, dW,dH, 0, 3}; // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; 9 - divisor - std::vector* argT = block->getTArguments(); - *argT = {0.000001}; - - sd::ops::pnormpool2d_bp bp; - Nd4jStatus status = bp.execute(block); - ASSERT_EQ(ND4J_STATUS_OK, status); - - auto result = variableSpace->getVariable(block->getNodeId())->getNDArray(); - ASSERT_TRUE(exp.isSameShape(result)); - - delete variableSpace; - delete block; + auto input = NDArrayFactory::create('c', {bS, iD, iH, iW}); + auto epsilon = NDArrayFactory::create('c', {bS, iD, oH, oW}); + auto exp = NDArrayFactory::create('c', {bS, iD, iH, iW}); + + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, input); + variableSpace->putVariable(-2, epsilon); + // variableSpace->putVariable(1, &z); + + auto block = new Context(1, variableSpace, false); + block->fillInputs({-1}); + block->fillInputs({-2}); + block->appendI({kH, kW, sH, sW, pH, pW, dW, dH, 0, + 3}); // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; + // 4,5 - pad Height/Width; 6,7 - dilation Height/Width; + // 8 - same mode; 9 - divisor + block->appendT(0.000001); + + sd::ops::pnormpool2d_bp bp; + Nd4jStatus status = bp.execute(block); + ASSERT_EQ(ND4J_STATUS_OK, status); + + auto result = variableSpace->getVariable(block->getNodeId())->getNDArray(); + ASSERT_TRUE(exp.isSameShape(*result)); + + delete variableSpace; + delete block; } ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests2, pnormpool2d_bp_2) { - - int bS=2, iH=4,iW=3, iC=3, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oH=2,oW=2; - int pnorm = 3; - double eps = 0.; - - int paddingMode = 0; // 1-SAME, 0-VALID - int dataFormat = 0; // 1-NDHWC, 0-NCDHW - - auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); - auto gradO = NDArrayFactory::create('c', {bS, iC, oH, oW}); - auto expected = NDArrayFactory::create('c', {bS, iC, iH, iW}, {9.661570e-04f, 9.671602e-03f, 1.306569e-02f, 3.679184e-02f, 1.297220e-01f, 1.040181e-01f, 1.126750e-01f, 3.320884e-01f, 2.340406e-01f, 1.333333e-01f, 3.352886e-01f, 2.070211e-01f, - 8.991618e-02f, 2.160601e-01f, 1.283173e-01f, 2.744226e-01f, 6.364498e-01f, 3.662123e-01f, 3.869788e-01f, 8.808994e-01f, 4.984556e-01f, 2.613189e-01f, 5.818475e-01f, 3.225517e-01f, - 2.065654e-01f, 4.553546e-01f, 2.501175e-01f, 5.190718e-01f, 1.131343e+00f, 6.148388e-01f, 6.362602e-01f, 1.377521e+00f, 7.439550e-01f, 3.833026e-01f, 8.227519e-01f, 4.407146e-01f, - 3.261206e-01f, 6.969233e-01f, 3.717564e-01f, 7.627507e-01f, 1.620991e+00f, 8.600952e-01f, 8.814538e-01f, 1.866888e+00f, 9.873542e-01f, 5.046682e-01f, 1.064004e+00f, 5.602558e-01f, - 4.464697e-01f, 9.389536e-01f, 4.932274e-01f, 1.005908e+00f, 2.108550e+00f, 1.104095e+00f, 1.125322e+00f, 2.354009e+00f, 1.230180e+00f, 6.258913e-01f, 1.305581e+00f, 6.804127e-01f, - 5.671396e-01f, 1.181128e+00f, 6.145977e-01f, 1.248783e+00f, 2.595083e+00f, 1.347494e+00f, 1.368600e+00f, 2.840157e+00f, 1.472778e+00f, 7.470673e-01f, 1.547362e+00f, 8.008900e-01f}); - input.linspace(1.); - gradO.linspace(0.1, 0.1); - - sd::ops::pnormpool2d_bp op; - auto results = op.evaluate({&input, &gradO}, {eps}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, pnorm, dataFormat}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - + int bS = 2, iH = 4, iW = 3, iC = 3, kH = 3, kW = 2, sH = 1, sW = 1, pH = 0, + pW = 0, dH = 1, dW = 1; + int oH = 2, oW = 2; + int pnorm = 3; + double eps = 0.; + + int paddingMode = 0; // 1-SAME, 0-VALID + int dataFormat = 0; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); + auto gradO = NDArrayFactory::create('c', {bS, iC, oH, oW}); + auto expected = NDArrayFactory::create( + 'c', {bS, iC, iH, iW}, + {9.661570e-04f, 9.671602e-03f, 1.306569e-02f, 3.679184e-02f, + 1.297220e-01f, 1.040181e-01f, 1.126750e-01f, 3.320884e-01f, + 2.340406e-01f, 1.333333e-01f, 3.352886e-01f, 2.070211e-01f, + 8.991618e-02f, 2.160601e-01f, 1.283173e-01f, 2.744226e-01f, + 6.364498e-01f, 3.662123e-01f, 3.869788e-01f, 8.808994e-01f, + 4.984556e-01f, 2.613189e-01f, 5.818475e-01f, 3.225517e-01f, + 2.065654e-01f, 4.553546e-01f, 2.501175e-01f, 5.190718e-01f, + 1.131343e+00f, 6.148388e-01f, 6.362602e-01f, 1.377521e+00f, + 7.439550e-01f, 3.833026e-01f, 8.227519e-01f, 4.407146e-01f, + 3.261206e-01f, 6.969233e-01f, 3.717564e-01f, 7.627507e-01f, + 1.620991e+00f, 8.600952e-01f, 8.814538e-01f, 1.866888e+00f, + 9.873542e-01f, 5.046682e-01f, 1.064004e+00f, 5.602558e-01f, + 4.464697e-01f, 9.389536e-01f, 4.932274e-01f, 1.005908e+00f, + 2.108550e+00f, 1.104095e+00f, 1.125322e+00f, 2.354009e+00f, + 1.230180e+00f, 6.258913e-01f, 1.305581e+00f, 6.804127e-01f, + 5.671396e-01f, 1.181128e+00f, 6.145977e-01f, 1.248783e+00f, + 2.595083e+00f, 1.347494e+00f, 1.368600e+00f, 2.840157e+00f, + 1.472778e+00f, 7.470673e-01f, 1.547362e+00f, 8.008900e-01f}); + input.linspace(1.); + gradO.linspace(0.1, 0.1); + + sd::ops::pnormpool2d_bp op; + auto results = op.evaluate( + {&input, &gradO}, {eps}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, pnorm, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests2, pnormpool2d_bp_3) { - - int bS=2, iH=4,iW=3, iC=3, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oH=2,oW=2; - int pnorm = 2; - double eps = 0.; - - int paddingMode = 0; // 1-SAME, 0-VALID - int dataFormat = 0; // 1-NDHWC, 0-NCDHW - - auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); - auto gradO = NDArrayFactory::create('c', {bS, iC, oH, oW}); - auto expected = NDArrayFactory::create('c', {bS, iC, iH, iW}, {0.007931f, 0.042891f, 0.040544f, 0.09369f, 0.276841f, 0.191675f, 0.163957f, 0.442946f, 0.287512f, 0.154919f, 0.373153f, 0.221172f, - 0.15901f, 0.365232f, 0.207846f, 0.428282f, 0.959455f, 0.534076f, 0.508585f, 1.128771f, 0.623089f, 0.319794f, 0.698063f, 0.379547f, - 0.321068f, 0.692438f, 0.372316f, 0.757521f, 1.620323f, 0.864566f, 0.838684f, 1.787943f, 0.951023f, 0.483194f, 1.023434f, 0.541058f, - 0.483937f, 1.019414f, 0.536145f, 1.085348f, 2.276996f, 1.192917f, 1.166749f, 2.443606f, 1.278126f, 0.646499f, 1.349361f, 0.703463f, - 0.647021f, 1.346249f, 0.699745f, 1.412654f, 2.932174f, 1.520512f, 1.494153f, 3.098146f, 1.604985f, 0.809791f, 1.675544f, 0.866229f, - 0.810192f, 1.673009f, 0.863237f, 1.739711f, 3.58665f, 1.847753f, 1.82126f, 3.752188f, 1.931741f, 0.973081f, 2.001861f, 1.029173f}); - input.linspace(1.); - gradO.linspace(0.1, 0.1); - - sd::ops::pnormpool2d_bp op; - auto results = op.evaluate({&input, &gradO}, {eps}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, pnorm, dataFormat}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - + int bS = 2, iH = 4, iW = 3, iC = 3, kH = 3, kW = 2, sH = 1, sW = 1, pH = 0, + pW = 0, dH = 1, dW = 1; + int oH = 2, oW = 2; + int pnorm = 2; + double eps = 0.; + + int paddingMode = 0; // 1-SAME, 0-VALID + int dataFormat = 0; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); + auto gradO = NDArrayFactory::create('c', {bS, iC, oH, oW}); + auto expected = NDArrayFactory::create( + 'c', {bS, iC, iH, iW}, + {0.007931f, 0.042891f, 0.040544f, 0.09369f, 0.276841f, 0.191675f, + 0.163957f, 0.442946f, 0.287512f, 0.154919f, 0.373153f, 0.221172f, + 0.15901f, 0.365232f, 0.207846f, 0.428282f, 0.959455f, 0.534076f, + 0.508585f, 1.128771f, 0.623089f, 0.319794f, 0.698063f, 0.379547f, + 0.321068f, 0.692438f, 0.372316f, 0.757521f, 1.620323f, 0.864566f, + 0.838684f, 1.787943f, 0.951023f, 0.483194f, 1.023434f, 0.541058f, + 0.483937f, 1.019414f, 0.536145f, 1.085348f, 2.276996f, 1.192917f, + 1.166749f, 2.443606f, 1.278126f, 0.646499f, 1.349361f, 0.703463f, + 0.647021f, 1.346249f, 0.699745f, 1.412654f, 2.932174f, 1.520512f, + 1.494153f, 3.098146f, 1.604985f, 0.809791f, 1.675544f, 0.866229f, + 0.810192f, 1.673009f, 0.863237f, 1.739711f, 3.58665f, 1.847753f, + 1.82126f, 3.752188f, 1.931741f, 0.973081f, 2.001861f, 1.029173f}); + input.linspace(1.); + gradO.linspace(0.1, 0.1); + + sd::ops::pnormpool2d_bp op; + auto results = op.evaluate( + {&input, &gradO}, {eps}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, pnorm, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } - ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests2, upsampling2d_bp_1) { + const int bS = 1, iH = 2, iW = 2, iC = 1; + const int factorH = 2, factorW = 2; + const int isNCHW = 1; // data format, default is NCHW - const int bS=1, iH=2,iW=2, iC=1; - const int factorH=2, factorW=2; - const int isNCHW = 1; // data format, default is NCHW + auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); + auto gradO = + NDArrayFactory::create('c', {bS, iC, iH * factorH, iW * factorW}); + gradO = 1.; - auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); - auto gradO = NDArrayFactory::create('c', {bS, iC, iH*factorH, iW*factorW}); - gradO = 1.; + auto expGradI = NDArrayFactory::create('c', {bS, iC, iH, iW}); + expGradI = 4.; - auto expGradI = NDArrayFactory::create('c', {bS, iC, iH, iW}); - expGradI = 4.; - - sd::ops::upsampling2d_bp op; - auto results = op.evaluate({&input, &gradO}, {}, {isNCHW}); - auto* gradI = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expGradI.isSameShape(gradI)); - ASSERT_TRUE(expGradI.equalsTo(gradI)); + sd::ops::upsampling2d_bp op; + auto results = op.evaluate({&input, &gradO}, {}, {isNCHW}); + auto gradI = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests2, upsampling2d_bp_2) { + const int bS = 1, iH = 2, iW = 2, iC = 1; + const int factorH = 2, factorW = 2; + const int isNCHW = 0; // data format, default is NCHW - const int bS=1, iH=2,iW=2, iC=1; - const int factorH=2, factorW=2; - const int isNCHW = 0; // data format, default is NCHW - - auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); - auto gradO = NDArrayFactory::create('c', {bS, iH*factorH, iW*factorW, iC}); - gradO = 1.; + auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); + auto gradO = + NDArrayFactory::create('c', {bS, iH * factorH, iW * factorW, iC}); + gradO = 1.; - auto expGradI = NDArrayFactory::create('c', {bS, iH, iW, iC}); - expGradI = 4.; + auto expGradI = NDArrayFactory::create('c', {bS, iH, iW, iC}); + expGradI = 4.; - sd::ops::upsampling2d_bp op; - auto results = op.evaluate({&input, &gradO}, {}, {isNCHW}); - auto* gradI = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expGradI.isSameShape(gradI)); - ASSERT_TRUE(expGradI.equalsTo(gradI)); + sd::ops::upsampling2d_bp op; + auto results = op.evaluate({&input, &gradO}, {}, {isNCHW}); + auto gradI = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests2, upsampling2d_bp_3) { - - const int bS=1, iH=3,iW=3, iC=2; - const int factorH=2, factorW=2; - const int isNCHW = 1; // data format, default is NCHW - - NDArray input('c', {bS, iC, iH, iW}, sd::DataType::FLOAT32); - - NDArray gradO('c', {bS, iC, iH*factorH, iW*factorW}, {0.6793504, 0.35508695, 0.84278935, 0.20031333, 0.7014987, 0.31069338, 0.44793984, - 0.93800974, 0.32667395, 0.15187258, 0.38331753, 0.78212297, 0.1988072, 0.7985636, 0.1632634, 0.14696825, 0.26089668, 0.13505761, - 0.7562093, 0.27545404, 0.36908787, 0.09282647, 0.83649176, 0.26841334, 0.09506222, 0.31279507, 0.13591796, 0.5175439, 0.32870287, - 0.061735712, 0.39643127, 0.248016, 0.5489592, 0.115046196, 0.8143622, 0.7215636, 0.40449402, 0.29908907, 0.4038839, 0.9883108, - 0.022296403, 0.927782, 0.3184157, 0.0685462, 0.28453344, 0.23272, 0.35214192, 0.058909304, 0.7112212, 0.6744568, 0.19694561, 0.6994972, - 0.0743224, 0.42042503, 0.5842631, 0.14957358, 0.44640633, 0.72307247, 0.06448108, 0.48307765, 0.8759956, 0.5698191, 0.4458631, 0.5277549, - 0.016646361, 0.753678, 0.14063567, 0.7541292, 0.16193217, 0.7750374, 0.3326449, 0.11739397}, sd::DataType::FLOAT32); - - NDArray expGradI('c', {bS, iC, iH, iW}, {2.4203868, 1.5216494, 2.1776323, 2.0290341, 0.772146, 1.5008594, 1.0523045, 1.3174672, 1.9263644, - 1.090545, 1.9094483, 1.3611296, 2.1195147, 2.0659215, 1.0423062, 2.3405795, 1.9105877, 1.2203633}, sd::DataType::FLOAT32); - - sd::ops::upsampling2d_bp op; - auto results = op.evaluate({&input, &gradO}, {}, {isNCHW}); - auto* gradI = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expGradI.isSameShape(gradI)); - ASSERT_TRUE(expGradI.equalsTo(gradI)); - + const int bS = 1, iH = 3, iW = 3, iC = 2; + const int factorH = 2, factorW = 2; + const int isNCHW = 1; // data format, default is NCHW + + NDArray input('c', {bS, iC, iH, iW}, sd::DataType::FLOAT32); + + NDArray gradO('c', {bS, iC, iH * factorH, iW * factorW}, + {0.6793504, 0.35508695, 0.84278935, 0.20031333, 0.7014987, + 0.31069338, 0.44793984, 0.93800974, 0.32667395, 0.15187258, + 0.38331753, 0.78212297, 0.1988072, 0.7985636, 0.1632634, + 0.14696825, 0.26089668, 0.13505761, 0.7562093, 0.27545404, + 0.36908787, 0.09282647, 0.83649176, 0.26841334, 0.09506222, + 0.31279507, 0.13591796, 0.5175439, 0.32870287, 0.061735712, + 0.39643127, 0.248016, 0.5489592, 0.115046196, 0.8143622, + 0.7215636, 0.40449402, 0.29908907, 0.4038839, 0.9883108, + 0.022296403, 0.927782, 0.3184157, 0.0685462, 0.28453344, + 0.23272, 0.35214192, 0.058909304, 0.7112212, 0.6744568, + 0.19694561, 0.6994972, 0.0743224, 0.42042503, 0.5842631, + 0.14957358, 0.44640633, 0.72307247, 0.06448108, 0.48307765, + 0.8759956, 0.5698191, 0.4458631, 0.5277549, 0.016646361, + 0.753678, 0.14063567, 0.7541292, 0.16193217, 0.7750374, + 0.3326449, 0.11739397}, + sd::DataType::FLOAT32); + + NDArray expGradI( + 'c', {bS, iC, iH, iW}, + {2.4203868, 1.5216494, 2.1776323, 2.0290341, 0.772146, 1.5008594, + 1.0523045, 1.3174672, 1.9263644, 1.090545, 1.9094483, 1.3611296, + 2.1195147, 2.0659215, 1.0423062, 2.3405795, 1.9105877, 1.2203633}, + sd::DataType::FLOAT32); + + sd::ops::upsampling2d_bp op; + auto results = op.evaluate({&input, &gradO}, {}, {isNCHW}); + auto gradI = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); } ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests2, depthwise_conv2d_1) { - - int bS=2, iH=4,iW=3, iC=2,mC=2, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oC=iC*mC; - int oH=4,oW=3; - int paddingMode = 1; // 1-SAME, 0-VALID; - int dataFormat = 1; // 1-NHWC, 0-NCHW - - auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); - auto weights = NDArrayFactory::create('c', {kH, kW, iC, mC}); - - - auto expOutput = NDArrayFactory::create('c', {bS, oH, oW, oC},{12.f, 12.8f, 13.6f, 14.4f, 12.f, 12.8f, 13.6f, 14.4f, 5.2f, 5.6f, 6.f, 6.4f, 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 5.4f, 6.f, 6.6f, 7.2f, - 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 5.4f, 6.f, 6.6f, 7.2f, 5.6f, 6.4f, 7.2f, 8.f, 5.6f, 6.4f, 7.2f, 8.f, 2.f, 2.4f, 2.8f, 3.2f, - 12.f, 12.8f, 13.6f, 14.4f, 12.f, 12.8f, 13.6f, 14.4f, 5.2f, 5.6f, 6.f, 6.4f, 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 5.4f, 6.f, 6.6f, 7.2f, - 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 5.4f, 6.f, 6.6f, 7.2f, 5.6f, 6.4f, 7.2f, 8.f, 5.6f, 6.4f, 7.2f, 8.f, 2.f, 2.4f, 2.8f, 3.2f}); - input = 2.; - weights.linspace(0.1, 0.1); - - sd::ops::depthwise_conv2d op; - auto results = op.evaluate({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); - auto* output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expOutput.isSameShape(output)); - ASSERT_TRUE(expOutput.equalsTo(output)); + int bS = 2, iH = 4, iW = 3, iC = 2, mC = 2, kH = 3, kW = 2, sH = 1, sW = 1, + pH = 0, pW = 0, dH = 1, dW = 1; + int oC = iC * mC; + int oH = 4, oW = 3; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NHWC, 0-NCHW + + auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); + auto weights = NDArrayFactory::create('c', {kH, kW, iC, mC}); + + auto expOutput = NDArrayFactory::create( + 'c', {bS, oH, oW, oC}, + {12.f, 12.8f, 13.6f, 14.4f, 12.f, 12.8f, 13.6f, 14.4f, 5.2f, 5.6f, + 6.f, 6.4f, 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, + 5.4f, 6.f, 6.6f, 7.2f, 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, + 15.6f, 16.8f, 5.4f, 6.f, 6.6f, 7.2f, 5.6f, 6.4f, 7.2f, 8.f, + 5.6f, 6.4f, 7.2f, 8.f, 2.f, 2.4f, 2.8f, 3.2f, 12.f, 12.8f, + 13.6f, 14.4f, 12.f, 12.8f, 13.6f, 14.4f, 5.2f, 5.6f, 6.f, 6.4f, + 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 5.4f, 6.f, + 6.6f, 7.2f, 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, + 5.4f, 6.f, 6.6f, 7.2f, 5.6f, 6.4f, 7.2f, 8.f, 5.6f, 6.4f, + 7.2f, 8.f, 2.f, 2.4f, 2.8f, 3.2f}); + input = 2.; + weights.linspace(0.1, 0.1); + + sd::ops::depthwise_conv2d op; + auto results = + op.evaluate({&input, &weights}, {}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests2, depthwise_conv2d_2) { - - int bS=2, iH=4,iW=3, iC=2,mC=2, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oC=iC*mC; - int oH=2,oW=2; - int paddingMode = 0; // 1-SAME, 0-VALID; - int dataFormat = 1; // 1-NHWC, 0-NCHW - - auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); - auto weights = NDArrayFactory::create('c', {kH, kW, iC, mC}); - - - auto expOutput = NDArrayFactory::create('c', {bS, oH, oW, oC},{13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, - 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f}); - input = 2.; - weights.linspace(0.1, 0.1); - - sd::ops::depthwise_conv2d op; - auto results = op.evaluate({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); - auto* output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expOutput.isSameShape(output)); - ASSERT_TRUE(expOutput.equalsTo(output)); - + int bS = 2, iH = 4, iW = 3, iC = 2, mC = 2, kH = 3, kW = 2, sH = 1, sW = 1, + pH = 0, pW = 0, dH = 1, dW = 1; + int oC = iC * mC; + int oH = 2, oW = 2; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NHWC, 0-NCHW + + auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); + auto weights = NDArrayFactory::create('c', {kH, kW, iC, mC}); + + auto expOutput = NDArrayFactory::create( + 'c', {bS, oH, oW, oC}, + {13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, + 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, + 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, + 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f}); + input = 2.; + weights.linspace(0.1, 0.1); + + sd::ops::depthwise_conv2d op; + auto results = + op.evaluate({&input, &weights}, {}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); } - ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests2, depthwise_conv2d_3) { - - int bS=2, iH=4,iW=3, iC=2,mC=2, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oC=iC*mC; - int oH=2,oW=2; - int paddingMode = 0; // 1-SAME, 0-VALID; - int dataFormat = 0; // 1-NHWC, 0-NCHW - - auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); - auto weights = NDArrayFactory::create('c', {mC, iC, kH, kW}); - auto biases = NDArrayFactory::create('c', {iC*mC}, {1.f,2.f,3.f,4.f}); - - NDArray expOutput('c', {bS, oC, oH, oW},{5.2, 5.2, 5.2, 5.2,20.6,20.6,20.6,20.6,14.4,14.4,14.4,14.4,29.8,29.8,29.8,29.8, 5.2, 5.2, 5.2, 5.2,20.6,20.6,20.6,20.6,14.4,14.4,14.4,14.4,29.8,29.8,29.8,29.8}, sd::DataType::FLOAT32); - - input = 2.; - weights.linspace(0.1, 0.1); - weights.permutei({2,3,1,0}); - - sd::ops::depthwise_conv2d op; - auto results = op.evaluate({&input, &weights, &biases}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); - auto* output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expOutput.isSameShape(output)); - ASSERT_TRUE(expOutput.equalsTo(output)); - + int bS = 2, iH = 4, iW = 3, iC = 2, mC = 2, kH = 3, kW = 2, sH = 1, sW = 1, + pH = 0, pW = 0, dH = 1, dW = 1; + int oC = iC * mC; + int oH = 2, oW = 2; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NHWC, 0-NCHW + + auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); + auto weights = NDArrayFactory::create('c', {mC, iC, kH, kW}); + auto biases = + NDArrayFactory::create('c', {iC * mC}, {1.f, 2.f, 3.f, 4.f}); + + NDArray expOutput( + 'c', {bS, oC, oH, oW}, + {5.2, 5.2, 5.2, 5.2, 20.6, 20.6, 20.6, 20.6, 14.4, 14.4, 14.4, + 14.4, 29.8, 29.8, 29.8, 29.8, 5.2, 5.2, 5.2, 5.2, 20.6, 20.6, + 20.6, 20.6, 14.4, 14.4, 14.4, 14.4, 29.8, 29.8, 29.8, 29.8}, + sd::DataType::FLOAT32); + + input = 2.; + weights.linspace(0.1, 0.1); + weights.permutei({2, 3, 1, 0}); + + sd::ops::depthwise_conv2d op; + auto results = + op.evaluate({&input, &weights, &biases}, {}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests2, depthwise_conv2d_4) { + int bS = 1, iH = 111, iW = 111, iC = 32, mC = 1, kH = 7, kW = 7, sH = 2, + sW = 2, pH = 0, pW = 0, dH = 1, dW = 1; + int oC = iC * mC; + int oH = 56, oW = 56; - int bS=1, iH=111,iW=111, iC=32,mC=1, kH=7,kW=7, sH=2,sW=2, pH=0,pW=0, dH=1,dW=1; - int oC=iC*mC; - int oH=56,oW=56; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NHWC, 0-NCHW - int paddingMode = 1; // 1-SAME, 0-VALID; - int dataFormat = 1; // 1-NHWC, 0-NCHW + const float unique = -1000000; - const float unique = -1000000; + NDArray input('c', {bS, iH, iW, iC}, sd::DataType::FLOAT32); + NDArray weights('c', {kH, kW, iC, mC}, sd::DataType::FLOAT32); + NDArray output('c', {bS, oH, oW, oC}, sd::DataType::FLOAT32); + input.linspace(0.1, 0.0001); + weights = 0.5; + output = unique; - NDArray input('c', {bS, iH, iW, iC}, sd::DataType::FLOAT32); - NDArray weights('c', {kH, kW, iC, mC}, sd::DataType::FLOAT32); - NDArray output('c', {bS, oH, oW, oC}, sd::DataType::FLOAT32); - input.linspace(0.1, 0.0001); - weights = 0.5; - output = unique; + sd::ops::depthwise_conv2d op; + Nd4jStatus status = + op.execute({&input, &weights}, {&output}, {}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, dataFormat}, {}); - sd::ops::depthwise_conv2d op; - Nd4jStatus status = op.execute({&input, &weights}, {&output} , {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}, {}); + ASSERT_EQ(Status::OK(), status); - ASSERT_EQ(Status::OK(), status); - - for(Nd4jLong i=output.lengthOf()/1.5; i < output.lengthOf(); ++i) - ASSERT_EQ(output.e(i) != unique, true); + for (Nd4jLong i = output.lengthOf() / 1.5; i < output.lengthOf(); ++i) + ASSERT_EQ(output.e(i) != unique, true); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests2, depthwise_conv2d_5) { - - int bS=1, iH=3,iW=3, iC=2,mC=1, kH=2,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oC=iC*mC; - int oH=3,oW=3; - int paddingMode = 1; // 1-SAME, 0-VALID; - int dataFormat = 1; // 1-NHWC, 0-NCHW - - auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); - auto weights = NDArrayFactory::create('c', {kH, kW, iC, mC}); - - NDArray expOutput('c', {bS, oH, oW, oC}, {10., 12., 14., 16., 8., 9., 22., 24., 26., 28., 14., 15., 14., 15., 16., 17., 8.5, 9.}, sd::DataType::FLOAT32); - - input.linspace(1.); - weights = 0.5; - - sd::ops::depthwise_conv2d op; - auto results = op.evaluate({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expOutput.isSameShape(output)); - ASSERT_TRUE(expOutput.equalsTo(output)); - + int bS = 1, iH = 3, iW = 3, iC = 2, mC = 1, kH = 2, kW = 2, sH = 1, sW = 1, + pH = 0, pW = 0, dH = 1, dW = 1; + int oC = iC * mC; + int oH = 3, oW = 3; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NHWC, 0-NCHW + + auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); + auto weights = NDArrayFactory::create('c', {kH, kW, iC, mC}); + + NDArray expOutput('c', {bS, oH, oW, oC}, + {10., 12., 14., 16., 8., 9., 22., 24., 26., 28., 14., 15., + 14., 15., 16., 17., 8.5, 9.}, + sd::DataType::FLOAT32); + + input.linspace(1.); + weights = 0.5; + + sd::ops::depthwise_conv2d op; + auto results = + op.evaluate({&input, &weights}, {}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests2, depthwise_conv2d_6) { - - int bS=1, iH=3,iW=3, iC=2,mC=1, kH=2,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oC=iC*mC; - int oH=3,oW=3; - int paddingMode = 1; // 1-SAME, 0-VALID; - int dataFormat = 1; // 1-NHWC, 0-NCHW - - NDArray input('c', {bS, iH, iW, iC}, sd::DataType::FLOAT32); - NDArray weights('c', {kH, kW, iC, mC}, sd::DataType::FLOAT32); - - NDArray expOutput('c', {bS, oH, oW, oC}, {20., 24.,28., 32.,16., 18.,44., 48.,52., 56.,28., 30.,28., 30.,32., 34.,17., 18.}, sd::DataType::FLOAT32); - input.linspace(1.); - weights = 1.; - - sd::ops::depthwise_conv2d op; - auto results = op.evaluate({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); - NDArray* output = results.at(0); - // output.printIndexedBuffer(); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expOutput.isSameShape(output)); - ASSERT_TRUE(expOutput.equalsTo(output)); + int bS = 1, iH = 3, iW = 3, iC = 2, mC = 1, kH = 2, kW = 2, sH = 1, sW = 1, + pH = 0, pW = 0, dH = 1, dW = 1; + int oC = iC * mC; + int oH = 3, oW = 3; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NHWC, 0-NCHW + + NDArray input('c', {bS, iH, iW, iC}, sd::DataType::FLOAT32); + NDArray weights('c', {kH, kW, iC, mC}, sd::DataType::FLOAT32); + + NDArray expOutput('c', {bS, oH, oW, oC}, + {20., 24., 28., 32., 16., 18., 44., 48., 52., 56., 28., 30., + 28., 30., 32., 34., 17., 18.}, + sd::DataType::FLOAT32); + input.linspace(1.); + weights = 1.; + + sd::ops::depthwise_conv2d op; + auto results = + op.evaluate({&input, &weights}, {}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, dataFormat}); + auto output = results.at(0); + // output.printIndexedBuffer(); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests2, depthwise_conv2d_7) { - - int bS=1, iH=3,iW=3, iC=2,mC=2, kH=1,kW=1, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oC=iC*mC; - int oH=3,oW=3; - int paddingMode = 0; // 1-SAME, 0-VALID; - int dataFormat = 0; // 1-NHWC, 0-NCHW - - NDArray input('c', {bS, iC, iH, iW}, {0.6793503761291504, 0.35508695244789124, 0.842789351940155, 0.20031332969665527, 0.7014986872673035, 0.3106933832168579, - 0.44793984293937683, 0.9380097389221191, 0.3266739547252655, 0.15187257528305054, 0.3833175301551819, 0.7821229696273804, - 0.19880719482898712, 0.7985635995864868, 0.16326339542865753, 0.14696824550628662, 0.2608966827392578, 0.13505761325359344}, sd::DataType::FLOAT32); - NDArray weights('c', {kH, kW, iC, mC}, {0.1308445781469345, 0.6442840099334717, 0.5698848366737366, 0.19896849989891052}, sd::DataType::FLOAT32); - NDArray biases('c', {1,iC*mC}, {0.6123566627502441, 0.37637925148010254, 0.17464971542358398, 0.4270855486392975}, sd::DataType::FLOAT32); - - NDArray expOutput('c', {bS, oC, oH, oW}, {0.7012459761288241, 0.6588178652487691, 0.722631079971582, 0.6385665758716108, 0.7041439625563628, 0.6530092074102978, - 0.670967162534851, 0.735090151337225, 0.6551001785478623, 0.8140738359624038, 0.6051560970782859, 0.9193749546773375, 0.5054379267801892, 0.8283436386757472, - 0.5765540302788565, 0.6649797296980537, 0.9807239274294943, 0.586850056971322, 0.261199593183985, 0.3930965634902499, 0.6203697362284615, 0.28794692117826504, - 0.6297390019475202, 0.26769104886224415, 0.25840469001015975, 0.3233307788551656, 0.25161700129415276, 0.4573034071191504, 0.5033536625992294, 0.5827033826425385, - 0.4666419179635315, 0.585974550122895, 0.4595698215161401, 0.45632759998045813, 0.4789957702325296, 0.4539577593482922}, sd::DataType::FLOAT32); - - - sd::ops::depthwise_conv2d op; - auto results = op.evaluate({&input, &weights, &biases}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); - auto* output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expOutput.isSameShape(output)); - ASSERT_TRUE(expOutput.equalsTo(output)); + int bS = 1, iH = 3, iW = 3, iC = 2, mC = 2, kH = 1, kW = 1, sH = 1, sW = 1, + pH = 0, pW = 0, dH = 1, dW = 1; + int oC = iC * mC; + int oH = 3, oW = 3; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NHWC, 0-NCHW + + NDArray input('c', {bS, iC, iH, iW}, + {0.6793503761291504, 0.35508695244789124, 0.842789351940155, + 0.20031332969665527, 0.7014986872673035, 0.3106933832168579, + 0.44793984293937683, 0.9380097389221191, 0.3266739547252655, + 0.15187257528305054, 0.3833175301551819, 0.7821229696273804, + 0.19880719482898712, 0.7985635995864868, 0.16326339542865753, + 0.14696824550628662, 0.2608966827392578, 0.13505761325359344}, + sd::DataType::FLOAT32); + NDArray weights('c', {kH, kW, iC, mC}, + {0.1308445781469345, 0.6442840099334717, 0.5698848366737366, + 0.19896849989891052}, + sd::DataType::FLOAT32); + NDArray biases('c', {1, iC * mC}, + {0.6123566627502441, 0.37637925148010254, 0.17464971542358398, + 0.4270855486392975}, + sd::DataType::FLOAT32); + + NDArray expOutput( + 'c', {bS, oC, oH, oW}, + {0.7012459761288241, 0.6588178652487691, 0.722631079971582, + 0.6385665758716108, 0.7041439625563628, 0.6530092074102978, + 0.670967162534851, 0.735090151337225, 0.6551001785478623, + 0.8140738359624038, 0.6051560970782859, 0.9193749546773375, + 0.5054379267801892, 0.8283436386757472, 0.5765540302788565, + 0.6649797296980537, 0.9807239274294943, 0.586850056971322, + 0.261199593183985, 0.3930965634902499, 0.6203697362284615, + 0.28794692117826504, 0.6297390019475202, 0.26769104886224415, + 0.25840469001015975, 0.3233307788551656, 0.25161700129415276, + 0.4573034071191504, 0.5033536625992294, 0.5827033826425385, + 0.4666419179635315, 0.585974550122895, 0.4595698215161401, + 0.45632759998045813, 0.4789957702325296, 0.4539577593482922}, + sd::DataType::FLOAT32); + + sd::ops::depthwise_conv2d op; + auto results = + op.evaluate({&input, &weights, &biases}, {}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests2, depthwise_conv2d_8) { - - int bS=1, iH=10,iW=10, iC=8,mC=1, kH=3,kW=3, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oC=iC*mC; - int oH=10,oW=10; - int paddingMode = 1; // 1-SAME, 0-VALID; - int dataFormat = 1; // 1-NHWC, 0-NCHW - - NDArray input('c', {bS, iH, iW, iC}, sd::DataType::FLOAT32); - NDArray weights('c', {kH, kW, iC, mC}, sd::DataType::FLOAT32); - - NDArray expOutput('c', {bS, oH, oW, oC}, {-42.879997, -43.959999, -44.959999, -45.879997, -46.720005, -47.480003, -48.160000, -48.760002, -43.519997, -45.139999, -46.639996, -48.020000, -49.280003, -50.419998, -51.440006, -52.340000, -31.999998, -33.139999, -34.160000, -35.060001, -35.840004, -36.500004, -37.039997, -37.459999, -20.480000, - -21.139997, -21.680000, -22.100000, -22.399998, -22.579998, -22.639996, -22.580002, -8.960000, -9.139998, -9.200002, -9.140001, -8.960001, -8.660000, -8.240002, -7.700001, 2.560000, 2.860002, 3.279998, 3.820000, 4.480001, 5.260000, 6.160001, 7.180000, 14.080000, 14.860000, 15.759998, 16.779999, 17.920002, 19.180000, 20.560001, 22.059998, - 25.600000, 26.860001, 28.239998, 29.739998, 31.360001, 33.099998, 34.959999, 36.939999, 37.119999, 38.860001, 40.720001, 42.699997, 44.800003, 47.020000, 49.360001, 51.820000, 26.239998, 27.400002, 28.639999, 29.959999, 31.360001, 32.840000, 34.400002, 36.040001, 62.400002, 62.459999, 62.639999, 62.940002, 63.360001, 63.900002, 64.559998, - 65.340004, 106.080002, 106.169998, 106.440002, 106.889999, 107.519997, 108.330002, 109.320000, 110.490005, 114.720001, 115.529999, 116.520004, 117.690002, 119.040009, 120.570000, 122.279999, 124.169998, 123.359985, 124.889999, 126.599998, 128.490005, 130.559998, 132.809998, 135.240005, 137.850006, 132.000000, 134.250000, 136.679993, - 139.290009, 142.080002, 145.049988, 148.199997, 151.529999, 140.639999, 143.610001, 146.760010, 150.089996, 153.600006, 157.290009, 161.160004, 165.209991, 149.279999, 152.970001, 156.839996, 160.889999, 165.120010, 169.529999, 174.119995, 178.889999, 157.919998, 162.330002, 166.919983, 171.690002, 176.639999, 181.769989, 187.079987, - 192.570007, 166.559998, 171.690002, 177.000000, 182.489990, 188.160004, 194.010010, 200.040009, 206.250000, 100.799995, 104.220001, 107.760002, 111.419998, 115.200005, 119.099998, 123.120003, 127.260010, 139.200012, 144.059998, 149.040009, 154.139999, 159.360001, 164.699997, 170.160004, 175.739990, 192.479996, 199.770020, 207.239990, - 214.889999, 222.720001, 230.730011, 238.919998, 247.290009, 201.119995, 209.129990, 217.319992, 225.690002, 234.240005, 242.970001, 251.880005, 260.970001, 209.760010, 218.489990, 227.399994, 236.490005, 245.760010, 255.209991, 264.839996, 274.649994, 218.399994, 227.850006, 237.479996, 247.289993, 257.279999, 267.449982, 277.799988, - 288.330017, 227.040009, 237.209991, 247.559998, 258.089996, 268.800018, 279.690002, 290.760010, 302.010010, 235.679993, 246.570007, 257.639984, 268.889984, 280.320007, 291.929993, 303.720001, 315.690002, 244.320007, 255.929993, 267.720001, 279.690002, 291.839996, 304.169983, 316.679993, 329.369995, 252.959991, 265.290009, 277.799988, 290.489990, 303.359985, 316.410004, 329.640015, 343.050018, 139.199997, 147.419998, 155.760010, 164.220001, 172.799988, 181.500000, 190.319992, 199.260010, 216.000000, 225.660004, 235.440002, 245.339996, 255.360016, 265.500000, 275.760010, 286.140015, 278.880005, 293.369995, 308.040009, 322.889984, 337.920013, 353.129974, 368.519989, - 384.090027, 287.520020, 302.730011, 318.119995, 333.690002, 349.440002, 365.369995, 381.479980, 397.770020, 296.160004, 312.089996, 328.199982, 344.489990, 360.960022, 377.609985, 394.440002, 411.449982, 304.799988, 321.450012, 338.280029, 355.289978, 372.480011, 389.850006, 407.399994, 425.130005, 313.440002, 330.809998, 348.359985, 366.089996, 384.000000, 402.090027, 420.359985, 438.809998, 322.079987, 340.169983, 358.440002, 376.889984, 395.520020, 414.329987, 433.320007, 452.489990, 330.720001, 349.530029, 368.520020, 387.690002, 407.039978, 426.570007, 446.279999, 466.170013, 339.360016, 358.890015, 378.599976, 398.490021, 418.559998, 438.809998, 459.239990, 479.849976, 177.600006, 190.619995, 203.759995, 217.020004, 230.399994, 243.899994, 257.519989, 271.260010, 292.799988, 307.260010, 321.839996, 336.539978, 351.360016, 366.299988, 381.359985, 396.540009, 365.279999, 386.970001, 408.839996, 430.889984, 453.120026, 475.529968, 498.119995, 520.890015, 373.920013, 396.329987, 418.919983, 441.690002, 464.640015, 487.769958, 511.079987, 534.570007, 382.559998, 405.690002, 429.000000, 452.489990, 476.160004, 500.010010, 524.039978, 548.250000, 391.200012, 415.049988, 439.080017, 463.290009, 487.679993, 512.250000, 537.000000, 561.930054, 399.839996, 424.409973, 449.160034, 474.089966, 499.200012, 524.489990, 549.959961, 575.609985, 408.479980, 433.770020, 459.239990, 484.889954, 510.720032, 536.729980, 562.919983, 589.290039, 417.119995, 443.130005, 469.319977, 495.690002, 522.239990, 548.969971, 575.880005, 602.969971, 425.760010, 452.489990, 479.399994, 506.489990, 533.760010, 561.209961, 588.839966, 616.650024, 216.000000, 233.819992, 251.760010, 269.820007, 288.000000, 306.299988, 324.719971, 343.260010, 369.600006, 388.859985, 408.239990, 427.739990, 447.360016, 467.100006, 486.959961, 506.940002, 451.679993, 480.570007, 509.639984, 538.890015, 568.320007, 597.929993, 627.719971, 657.690002, 460.320007, 489.929993, 519.719971, 549.690002, 579.840027, 610.170044, 640.680054, 671.369995, 468.960022, 499.289978, 529.799988, 560.489990, 591.359985, 622.409973, 653.640015, 685.049988, 477.599976, 508.650024, 539.880005, 571.289978, 602.880005, 634.650024, 666.599976, 698.729980, 486.239990, 518.010010, 549.960022, 582.089966, 614.400024, 646.890015, 679.559937, 712.410034, 494.879974, 527.369995, 560.039978, 592.890015, 625.920044, 659.130005, 692.520020, 726.089966, 503.519989, 536.729980, 570.119995, 603.689941, 637.440063, 671.369995, 705.480042, 739.770020, 512.160034, 546.089966, 580.199951, 614.489990, 648.960022, 683.609985, 718.440002, 753.449951, 254.400009, 277.020020, 299.760010, 322.619995, 345.600006, 368.700012, 391.919983, 415.260010, 446.399994, 470.459961, 494.640015, 518.940002, 543.360046, 567.900024, 592.559998, 617.340027, 538.080017, 574.170044, 610.440002, 646.890015, 683.520020, 720.329956, 757.320007, 794.489990, 546.719971, 583.530029, 620.520020, 657.690002, 695.040039, 732.570007, 770.279968, 808.169983, 555.359985, 592.889954, 630.599976, 668.489990, 706.559998, 744.809998, 783.239990, 821.849976, 564.000000, 602.250000, 640.679993, 679.289978, 718.080017, 757.050049, 796.199951, 835.530029, 572.640015, 611.609985, 650.760010, 690.089966, 729.600037, 769.289978, 809.160034, 849.210083, 581.279968, 620.970032, 660.839966, 700.889954, 741.119995, 781.529968, 822.119995, 862.890015, 589.919983, 630.330017, 670.919983, 711.690002, 752.640015, 793.770020, 835.079956, 876.570007, 598.559998, 639.690002, 681.000000, 722.490051, 764.160034, 806.010010, 848.039978, 890.250061, 292.799988, 320.220001, 347.760010, 375.419983, 403.200012, 431.100006, 459.119995, 487.260010, 523.199951, 552.059998, 581.040039, 610.139954, 639.360046, 668.699951, 698.159973, 727.739990, 624.479980, 667.770020, 711.239990, 754.890015, 798.719971, 842.729980, 886.919983, 931.290039, 633.119995, 677.130005, 721.319946, 765.690002, 810.239990, 854.969971, 899.880005, 944.969971, 641.760010, 686.489990, 731.400024, 776.489990, 821.760010, 867.209961, 912.839966, 958.650024, 650.400024, 695.849976, 741.479980, 787.290039, 833.279968, 879.449951, 925.799927, 972.330017, 659.040039, 705.210022, 751.559998, 798.089966, 844.800049, 891.690002, 938.760010, 986.010010, 667.679993, 714.569946, 761.640015, 808.890015, 856.320007, 903.929993, 951.719971, 999.690063, 676.320007, 723.929993, 771.719971, 819.690002, 867.839966, 916.169922, 964.679932, 1013.369995, 684.959961, 733.290039, 781.800049, 830.489990, 879.359985, 928.410034, 977.640015, 1027.050049, 331.199982, 363.419983, 395.760010, 428.220001, 460.799988, 493.500000, 526.320007, 559.260010, 600.000000, 633.660034, 667.440002, 701.339966, 735.359985, 769.500000, 803.759949, 838.140015, 710.880005, 761.369995, 812.039978, 862.889893, 913.919983, 965.130005, 1016.520020, 1068.090088, 719.520020, 770.729980, 822.119934, 873.689941, 925.440063, 977.369995, 1029.479980, 1081.770020, 728.160034, 780.090088, 832.199951, 884.489990, 936.960022, 989.610046, 1042.439941, 1095.449951, 736.799927, 789.449951, 842.280029, 895.290039, 948.480042, 1001.849976, 1055.399902, 1109.129883, 745.439941, 798.810059, 852.359985, 906.089966, 960.000000, 1014.089966, 1068.359985, 1122.810059, 754.080017, 808.170044, 862.440002, 916.890015, 971.520020, 1026.330078, 1081.319946, 1136.489990, 762.720032, 817.530029, 872.520020, 927.689941, 983.040039, 1038.569946, 1094.280029, 1150.169922, 771.359985, 826.890015, 882.599976, 938.489990, 994.559998, 1050.810059, 1107.239990, 1163.849976, 369.599976, 406.619995, 443.760010, 481.020020, 518.400024, 555.900024, 593.520020, 631.260010, 113.279999, 136.839996, 160.480011, 184.199982, 208.000015, 231.880005, 255.839996, 279.880005, 31.359985, 66.699989, 102.160004, 137.740005, 173.440002, 209.260010, 245.199982, 281.260010, 31.359993, 67.179993, 103.120003, 139.179993, 175.360016, 211.660004, 248.079987, 284.619995, 31.359993, 67.659996, 104.080009, 140.619995, 177.280014, 214.060013, 250.959991, 287.980011, 31.359993, 68.139999, 105.039993, 142.059982, 179.200027, 216.459991, 253.839996, 291.339996, 31.360008, 68.619995, 106.000000, 143.499985, 181.119995, 218.860001, 256.719971, 294.700012, 31.360001, 69.099991, 106.959984, 144.939987, 183.040009, 221.260010, 259.600006, 298.059998, 31.360008, 69.579971, 107.920006, 146.379990, 184.960007, 223.660004, 262.479980, 301.419983, 31.360001, 70.059975, 108.880020, 147.819977, 186.880020, 226.059998, 265.359985, 304.779999, -83.840004, -58.040001, -32.159988, -6.200012, 19.840012, 45.959984, 72.159996, 98.440010}, sd::DataType::FLOAT32); - - input.linspace(-10, 0.1); - weights.linspace(-2, 0.1); - - sd::ops::depthwise_conv2d op; - auto results = op.evaluate({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); - auto output = results.at(0); - // output->printBuffer(); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expOutput.isSameShape(output)); - ASSERT_TRUE(expOutput.equalsTo(output)); + int bS = 1, iH = 10, iW = 10, iC = 8, mC = 1, kH = 3, kW = 3, sH = 1, sW = 1, + pH = 0, pW = 0, dH = 1, dW = 1; + int oC = iC * mC; + int oH = 10, oW = 10; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NHWC, 0-NCHW + + NDArray input('c', {bS, iH, iW, iC}, sd::DataType::FLOAT32); + NDArray weights('c', {kH, kW, iC, mC}, sd::DataType::FLOAT32); + + NDArray expOutput( + 'c', {bS, oH, oW, oC}, + {-42.879997, -43.959999, -44.959999, -45.879997, -46.720005, + -47.480003, -48.160000, -48.760002, -43.519997, -45.139999, + -46.639996, -48.020000, -49.280003, -50.419998, -51.440006, + -52.340000, -31.999998, -33.139999, -34.160000, -35.060001, + -35.840004, -36.500004, -37.039997, -37.459999, -20.480000, + -21.139997, -21.680000, -22.100000, -22.399998, -22.579998, + -22.639996, -22.580002, -8.960000, -9.139998, -9.200002, + -9.140001, -8.960001, -8.660000, -8.240002, -7.700001, + 2.560000, 2.860002, 3.279998, 3.820000, 4.480001, + 5.260000, 6.160001, 7.180000, 14.080000, 14.860000, + 15.759998, 16.779999, 17.920002, 19.180000, 20.560001, + 22.059998, 25.600000, 26.860001, 28.239998, 29.739998, + 31.360001, 33.099998, 34.959999, 36.939999, 37.119999, + 38.860001, 40.720001, 42.699997, 44.800003, 47.020000, + 49.360001, 51.820000, 26.239998, 27.400002, 28.639999, + 29.959999, 31.360001, 32.840000, 34.400002, 36.040001, + 62.400002, 62.459999, 62.639999, 62.940002, 63.360001, + 63.900002, 64.559998, 65.340004, 106.080002, 106.169998, + 106.440002, 106.889999, 107.519997, 108.330002, 109.320000, + 110.490005, 114.720001, 115.529999, 116.520004, 117.690002, + 119.040009, 120.570000, 122.279999, 124.169998, 123.359985, + 124.889999, 126.599998, 128.490005, 130.559998, 132.809998, + 135.240005, 137.850006, 132.000000, 134.250000, 136.679993, + 139.290009, 142.080002, 145.049988, 148.199997, 151.529999, + 140.639999, 143.610001, 146.760010, 150.089996, 153.600006, + 157.290009, 161.160004, 165.209991, 149.279999, 152.970001, + 156.839996, 160.889999, 165.120010, 169.529999, 174.119995, + 178.889999, 157.919998, 162.330002, 166.919983, 171.690002, + 176.639999, 181.769989, 187.079987, 192.570007, 166.559998, + 171.690002, 177.000000, 182.489990, 188.160004, 194.010010, + 200.040009, 206.250000, 100.799995, 104.220001, 107.760002, + 111.419998, 115.200005, 119.099998, 123.120003, 127.260010, + 139.200012, 144.059998, 149.040009, 154.139999, 159.360001, + 164.699997, 170.160004, 175.739990, 192.479996, 199.770020, + 207.239990, 214.889999, 222.720001, 230.730011, 238.919998, + 247.290009, 201.119995, 209.129990, 217.319992, 225.690002, + 234.240005, 242.970001, 251.880005, 260.970001, 209.760010, + 218.489990, 227.399994, 236.490005, 245.760010, 255.209991, + 264.839996, 274.649994, 218.399994, 227.850006, 237.479996, + 247.289993, 257.279999, 267.449982, 277.799988, 288.330017, + 227.040009, 237.209991, 247.559998, 258.089996, 268.800018, + 279.690002, 290.760010, 302.010010, 235.679993, 246.570007, + 257.639984, 268.889984, 280.320007, 291.929993, 303.720001, + 315.690002, 244.320007, 255.929993, 267.720001, 279.690002, + 291.839996, 304.169983, 316.679993, 329.369995, 252.959991, + 265.290009, 277.799988, 290.489990, 303.359985, 316.410004, + 329.640015, 343.050018, 139.199997, 147.419998, 155.760010, + 164.220001, 172.799988, 181.500000, 190.319992, 199.260010, + 216.000000, 225.660004, 235.440002, 245.339996, 255.360016, + 265.500000, 275.760010, 286.140015, 278.880005, 293.369995, + 308.040009, 322.889984, 337.920013, 353.129974, 368.519989, + 384.090027, 287.520020, 302.730011, 318.119995, 333.690002, + 349.440002, 365.369995, 381.479980, 397.770020, 296.160004, + 312.089996, 328.199982, 344.489990, 360.960022, 377.609985, + 394.440002, 411.449982, 304.799988, 321.450012, 338.280029, + 355.289978, 372.480011, 389.850006, 407.399994, 425.130005, + 313.440002, 330.809998, 348.359985, 366.089996, 384.000000, + 402.090027, 420.359985, 438.809998, 322.079987, 340.169983, + 358.440002, 376.889984, 395.520020, 414.329987, 433.320007, + 452.489990, 330.720001, 349.530029, 368.520020, 387.690002, + 407.039978, 426.570007, 446.279999, 466.170013, 339.360016, + 358.890015, 378.599976, 398.490021, 418.559998, 438.809998, + 459.239990, 479.849976, 177.600006, 190.619995, 203.759995, + 217.020004, 230.399994, 243.899994, 257.519989, 271.260010, + 292.799988, 307.260010, 321.839996, 336.539978, 351.360016, + 366.299988, 381.359985, 396.540009, 365.279999, 386.970001, + 408.839996, 430.889984, 453.120026, 475.529968, 498.119995, + 520.890015, 373.920013, 396.329987, 418.919983, 441.690002, + 464.640015, 487.769958, 511.079987, 534.570007, 382.559998, + 405.690002, 429.000000, 452.489990, 476.160004, 500.010010, + 524.039978, 548.250000, 391.200012, 415.049988, 439.080017, + 463.290009, 487.679993, 512.250000, 537.000000, 561.930054, + 399.839996, 424.409973, 449.160034, 474.089966, 499.200012, + 524.489990, 549.959961, 575.609985, 408.479980, 433.770020, + 459.239990, 484.889954, 510.720032, 536.729980, 562.919983, + 589.290039, 417.119995, 443.130005, 469.319977, 495.690002, + 522.239990, 548.969971, 575.880005, 602.969971, 425.760010, + 452.489990, 479.399994, 506.489990, 533.760010, 561.209961, + 588.839966, 616.650024, 216.000000, 233.819992, 251.760010, + 269.820007, 288.000000, 306.299988, 324.719971, 343.260010, + 369.600006, 388.859985, 408.239990, 427.739990, 447.360016, + 467.100006, 486.959961, 506.940002, 451.679993, 480.570007, + 509.639984, 538.890015, 568.320007, 597.929993, 627.719971, + 657.690002, 460.320007, 489.929993, 519.719971, 549.690002, + 579.840027, 610.170044, 640.680054, 671.369995, 468.960022, + 499.289978, 529.799988, 560.489990, 591.359985, 622.409973, + 653.640015, 685.049988, 477.599976, 508.650024, 539.880005, + 571.289978, 602.880005, 634.650024, 666.599976, 698.729980, + 486.239990, 518.010010, 549.960022, 582.089966, 614.400024, + 646.890015, 679.559937, 712.410034, 494.879974, 527.369995, + 560.039978, 592.890015, 625.920044, 659.130005, 692.520020, + 726.089966, 503.519989, 536.729980, 570.119995, 603.689941, + 637.440063, 671.369995, 705.480042, 739.770020, 512.160034, + 546.089966, 580.199951, 614.489990, 648.960022, 683.609985, + 718.440002, 753.449951, 254.400009, 277.020020, 299.760010, + 322.619995, 345.600006, 368.700012, 391.919983, 415.260010, + 446.399994, 470.459961, 494.640015, 518.940002, 543.360046, + 567.900024, 592.559998, 617.340027, 538.080017, 574.170044, + 610.440002, 646.890015, 683.520020, 720.329956, 757.320007, + 794.489990, 546.719971, 583.530029, 620.520020, 657.690002, + 695.040039, 732.570007, 770.279968, 808.169983, 555.359985, + 592.889954, 630.599976, 668.489990, 706.559998, 744.809998, + 783.239990, 821.849976, 564.000000, 602.250000, 640.679993, + 679.289978, 718.080017, 757.050049, 796.199951, 835.530029, + 572.640015, 611.609985, 650.760010, 690.089966, 729.600037, + 769.289978, 809.160034, 849.210083, 581.279968, 620.970032, + 660.839966, 700.889954, 741.119995, 781.529968, 822.119995, + 862.890015, 589.919983, 630.330017, 670.919983, 711.690002, + 752.640015, 793.770020, 835.079956, 876.570007, 598.559998, + 639.690002, 681.000000, 722.490051, 764.160034, 806.010010, + 848.039978, 890.250061, 292.799988, 320.220001, 347.760010, + 375.419983, 403.200012, 431.100006, 459.119995, 487.260010, + 523.199951, 552.059998, 581.040039, 610.139954, 639.360046, + 668.699951, 698.159973, 727.739990, 624.479980, 667.770020, + 711.239990, 754.890015, 798.719971, 842.729980, 886.919983, + 931.290039, 633.119995, 677.130005, 721.319946, 765.690002, + 810.239990, 854.969971, 899.880005, 944.969971, 641.760010, + 686.489990, 731.400024, 776.489990, 821.760010, 867.209961, + 912.839966, 958.650024, 650.400024, 695.849976, 741.479980, + 787.290039, 833.279968, 879.449951, 925.799927, 972.330017, + 659.040039, 705.210022, 751.559998, 798.089966, 844.800049, + 891.690002, 938.760010, 986.010010, 667.679993, 714.569946, + 761.640015, 808.890015, 856.320007, 903.929993, 951.719971, + 999.690063, 676.320007, 723.929993, 771.719971, 819.690002, + 867.839966, 916.169922, 964.679932, 1013.369995, 684.959961, + 733.290039, 781.800049, 830.489990, 879.359985, 928.410034, + 977.640015, 1027.050049, 331.199982, 363.419983, 395.760010, + 428.220001, 460.799988, 493.500000, 526.320007, 559.260010, + 600.000000, 633.660034, 667.440002, 701.339966, 735.359985, + 769.500000, 803.759949, 838.140015, 710.880005, 761.369995, + 812.039978, 862.889893, 913.919983, 965.130005, 1016.520020, + 1068.090088, 719.520020, 770.729980, 822.119934, 873.689941, + 925.440063, 977.369995, 1029.479980, 1081.770020, 728.160034, + 780.090088, 832.199951, 884.489990, 936.960022, 989.610046, + 1042.439941, 1095.449951, 736.799927, 789.449951, 842.280029, + 895.290039, 948.480042, 1001.849976, 1055.399902, 1109.129883, + 745.439941, 798.810059, 852.359985, 906.089966, 960.000000, + 1014.089966, 1068.359985, 1122.810059, 754.080017, 808.170044, + 862.440002, 916.890015, 971.520020, 1026.330078, 1081.319946, + 1136.489990, 762.720032, 817.530029, 872.520020, 927.689941, + 983.040039, 1038.569946, 1094.280029, 1150.169922, 771.359985, + 826.890015, 882.599976, 938.489990, 994.559998, 1050.810059, + 1107.239990, 1163.849976, 369.599976, 406.619995, 443.760010, + 481.020020, 518.400024, 555.900024, 593.520020, 631.260010, + 113.279999, 136.839996, 160.480011, 184.199982, 208.000015, + 231.880005, 255.839996, 279.880005, 31.359985, 66.699989, + 102.160004, 137.740005, 173.440002, 209.260010, 245.199982, + 281.260010, 31.359993, 67.179993, 103.120003, 139.179993, + 175.360016, 211.660004, 248.079987, 284.619995, 31.359993, + 67.659996, 104.080009, 140.619995, 177.280014, 214.060013, + 250.959991, 287.980011, 31.359993, 68.139999, 105.039993, + 142.059982, 179.200027, 216.459991, 253.839996, 291.339996, + 31.360008, 68.619995, 106.000000, 143.499985, 181.119995, + 218.860001, 256.719971, 294.700012, 31.360001, 69.099991, + 106.959984, 144.939987, 183.040009, 221.260010, 259.600006, + 298.059998, 31.360008, 69.579971, 107.920006, 146.379990, + 184.960007, 223.660004, 262.479980, 301.419983, 31.360001, + 70.059975, 108.880020, 147.819977, 186.880020, 226.059998, + 265.359985, 304.779999, -83.840004, -58.040001, -32.159988, + -6.200012, 19.840012, 45.959984, 72.159996, 98.440010}, + sd::DataType::FLOAT32); + + input.linspace(-10, 0.1); + weights.linspace(-2, 0.1); + + sd::ops::depthwise_conv2d op; + auto results = + op.evaluate({&input, &weights}, {}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, dataFormat}); + auto output = results.at(0); + // output->printBuffer(); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests2, depthwise_conv2d_9) { - - int bS=1, iH=10,iW=10, iC=8,mC=1, kH=3,kW=3, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oC=iC*mC; - int oH=10,oW=10; - int paddingMode = 1; // 1-SAME, 0-VALID; - int dataFormat = 0; // 1-NHWC, 0-NCHW - - NDArray input('c', {bS, iC, iH, iW}, sd::DataType::FLOAT32); - NDArray weights('c', {kH, kW, iC, mC}, sd::DataType::FLOAT32); - - NDArray expOutput('c', {bS, oC, oH, oW}, {-103.360001, -131.440002, -130.000000, -128.559998, -127.120003, -125.680000, -124.240005, -122.799995, -121.360001, -66.720001,-76.199997, -81.239998, -80.160004, -79.080002, -78.000000, -76.919998, -75.840004, -74.760002, -73.680000, -29.400002, -66.599998, -70.440002, -69.360001, -68.279999, - -67.199997, -66.120003, -65.040001, -63.959999, -62.879997, -24.599997, -57.000000, -59.639999, -58.560005, -57.479996, -56.399998, -55.320000, -54.240002, -53.159996, -52.080002, -19.799997, -47.400002, -48.840000, -47.760002, -46.680000, -45.599998, -44.520000, -43.440002, -42.360001, -41.279999, -15.000000, -37.799999, -38.040001, - -36.959999, -35.879997, -34.799999, -33.720001, -32.639999, -31.560001, -30.479996, -10.199999, -28.200001, -27.240002, -26.160000, -25.080002, -24.000000, -22.919998,-21.840000, -20.759998, -19.679998, -5.400000, -18.599998, -16.439999, -15.360001, -14.280001, -13.200001, -12.120001, -11.040000, -9.960001, -8.880000, -0.600000, - -9.000000, -5.639999, -4.560000, -3.480000, -2.400000, -1.320001, -0.240000, 0.840001, 1.920000, 4.200000, 0.160000, 3.920000, 3.920000, 3.920000, 3.920000, 3.920000,3.920001, 3.920000, 3.920000, 3.520000, 8.860001, 12.920000, 14.420000, 15.920000, 17.420000, 18.920000, 20.420000, 21.920000, 23.420000, 13.820000, 20.430000, 27.750000, - 28.919998, 30.090000, 31.260000, 32.430000, 33.600002, 34.770000, 35.939999, 19.709999, 30.630001, 39.450001, 40.619999, 41.790001, 42.960003, 44.129997, 45.299999, 46.470001, 47.639999, 25.110001, 40.829998, 51.150002, 52.320000, 53.489998, 54.660004, 55.829994, 57.000000, 58.169998, 59.340004, 30.510002, 51.029999, 62.849998, - 64.019997, 65.190002, 66.360001, 67.529999, 68.699997, 69.870003, 71.040001, 35.910000, 61.229996, 74.550003, 75.720001, 76.889999, 78.059998, 79.229996, 80.400002, 81.570000, 82.740005, 41.310001, 71.430000, 86.250000, 87.419998, 88.589996, 89.760002, 90.929993, 92.099991, 93.270004, 94.440002, 46.709999, 81.630005, 97.949997, - 99.120003, 100.290009, 101.459999, 102.630005, 103.800003, 104.970001, 106.139999, 52.110001, 91.830002, 109.649994, 110.820007, 111.990005, 113.159996, 114.330002, 115.500000, 116.669998, 117.839996, 57.509995, 19.580000, 9.079998, 9.139999, 9.199999, 9.259996, 9.320001, 9.379998, 9.440000, 9.500000, -8.740000, 129.080002, 169.279999, - 170.839996, 172.399994, 173.960007, 175.520004, 177.080002, 178.639999, 180.199982, 102.360001, 129.059998, 154.739990, 156.000000, 157.259995, 158.520004, 159.779999, 161.039993, 162.300003, 163.559998, 80.820000, 139.860001, 167.340012, 168.600006, 169.860001, 171.119995, 172.380005, 173.639999, 174.899994, 176.160004, 86.820000, - 150.660004, 179.940002, 181.200012, 182.459991, 183.720001, 184.980011, 186.239990, 187.500000, 188.759995, 92.820007, 161.459991, 192.540009, 193.799988, 195.059998, 196.319992, 197.579987, 198.839996, 200.100006, 201.360001, 98.820000, 172.259995, 205.139999, 206.399994, 207.660004, 208.919983, 210.179993, 211.440002, 212.700012, - 213.959991, 104.819992, 183.059998, 217.739990, 219.000000, 220.259995, 221.519989, 222.779999, 224.039993, 225.300018, 226.559998, 110.819992, 193.860016, 230.339996, 231.600006, 232.860001, 234.119995, 235.380005, 236.639999, 237.900009, 239.160004, 116.820000, 204.660004, 242.940002, 244.199982, 245.459991, 246.720001, 247.980011, - 249.239990, 250.500000, 251.759995, 122.819992, 47.000000, 26.240004, 26.360004, 26.479998, 26.600002, 26.720001, 26.840002, 26.959997, 27.080000, -12.999998, 257.299988, 337.640015, 339.260010, 340.879974, 342.499969, 344.119995, 345.740021, 347.359985, 348.979980, 198.899994, 249.690002, 299.729980, 301.079987, 302.429993, 303.779999, 305.130005, 306.480011, 307.829987, 309.179993, 153.929993, 261.089996, 313.230011, 314.580017, 315.929993, 317.279968, 318.630005, 319.979980, 321.329987, 322.679993, 160.529999, 272.489990, 326.729980, 328.079987, 329.429993, 330.779968, 332.130005, 333.479980, 334.829987, 336.179993, 167.130005, 283.889984, 340.230011, 341.580017, 342.929993, 344.279999, 345.630005, 346.980011, 348.330017, 349.679993, 173.729996, 295.289978, 353.729980, 355.079987, 356.429993, 357.779968, 359.130005, 360.480011, 361.829987, 363.179993, 180.329987, 306.690002, 367.230011, 368.580017, 369.929993, 371.279999, 372.630005, 373.980011, 375.330017, 376.679993, 186.929993, 318.089996, 380.729980, 382.080017, 383.429993, 384.779968, 386.130005, 387.479980, 388.829987, 390.179993, 193.529984, 329.489990, 394.229980, 395.579987, 396.929993, 398.279999, 399.630005, 400.980011, 402.330017, 403.679993, 200.130005, 82.419998, 55.400005, 55.580002, 55.759995, 55.939999, 56.120003, 56.299995, 56.479996, 56.659996, -9.260002, 393.520020, 518.000000, 519.679993, 521.359985, 523.040039, 524.720032, 526.400024, 528.080017, 529.760010, 303.440002, 382.320007, 462.720032, 464.160004, 465.600037, 467.040009, 468.479980, 469.919983, 471.359985, 472.800018, 239.040009, 394.320007, 477.119995, 478.559998, 480.000000, 481.440002, 482.880005, 484.320007, 485.760010, 487.200012, 246.240005, 406.320007, 491.520020, 492.960022, 494.400024, 495.839996, 497.280029, 498.720032, 500.160004, 501.600037, 253.440002, 418.320007, 505.919983, 507.359985, 508.800018, 510.240051, 511.680023, 513.119995, 514.559998, 516.000000, 260.640015, 430.319977, 520.320007, 521.760010, 523.200012, 524.640015, 526.079956, 527.520020, 528.960022, 530.400024, 267.839996, 442.320007, 534.720032, 536.160034, 537.600037, 539.040039, 540.479980, 541.919983, 543.359985, 544.800049, 275.040009, 454.320007, 549.119995, 550.559998, 552.000000, 553.440002, 554.880005, 556.320007, 557.760010, 559.200012, 282.239990, 466.320007, 563.520020, 564.960022, 566.400024, 567.839966, 569.280029, 570.720032, 572.160034, 573.600037, 289.440002, 125.839996, 96.559998, 96.799995, 97.040009, 97.280014, 97.520004, 97.759995, 98.000000, 98.240013, 2.480007, 537.739990, 710.359985, 712.099976, 713.840027, 715.579956, 717.319946, 719.059998, 720.799988, 722.539978, 415.980011, 526.950012, 643.710022, 645.240051, 646.770020, 648.300049, 649.829956, 651.359985, 652.890015, 654.419983, 336.149994, 539.549988, 659.010010, 660.539978, 662.070007, 663.600037, 665.130005, 666.660034, 668.190002, 669.720032, 343.950012, 552.150024, 674.309998, 675.839966, 677.369995, 678.900024, 680.429993, 681.960022, 683.490051, 685.020020, 351.750000, 564.750000, 689.609985, 691.140015, 692.669983, 694.200012, 695.729980, 697.260010, 698.789978, 700.320007, 359.549988, 577.349976, 704.910034, 706.440002, 707.970032, 709.500000, 711.029968, 712.559998, 714.089966, 715.619995, 367.350037, 589.950012, 720.210022, 721.740051, 723.270020, 724.800049, 726.329956, 727.859985, 729.390015, 730.919983, 375.149994, 602.549988, 735.510010, 737.039978, 738.570007, 740.100037, 741.630005, 743.160034, 744.690002, 746.220032, 382.950012, 615.150024, 750.809998, 752.339966, 753.869995, 755.399963, 756.929993, 758.460022, 759.990051, 761.520020, 390.750000, 177.260010, 149.720001, 150.020004, 150.319992, 150.619995, 150.919998, 151.220001, 151.520004, 151.819992, 22.220009, 689.959961, 914.720032, 916.519958, 918.319946, 920.119995, 921.919983, 923.719971, 925.520020, 927.320007, 536.519958, 683.579956, 842.699951, 844.319946, 845.940002, 847.559998, 849.179993, 850.799988, 852.419983, 854.039978, 445.260010, 696.779968, 858.900024, 860.520020, 862.140015, 863.760010, 865.380005, 867.000000, 868.619995, 870.239990, 453.659973, 709.979980, 875.099976, 876.719971, 878.339966, 879.959961, 881.579956, 883.199951, 884.819946, 886.440002, 462.059998, 723.179993, 891.299988, 892.919983, 894.539978, 896.159973, 897.779968, 899.400024, 901.020020, 902.640015, 470.459991, 736.380005, 907.500000, 909.119995, 910.739990, 912.359985, 913.979980, 915.599976, 917.219971, 918.839966, 478.859985, 749.579956, 923.699951, 925.319946, 926.940002, 928.559998, 930.179993, 931.799988, 933.419983, 935.039978, 487.260010, 762.779968, 939.900024, 941.520020, 943.140015, 944.760010, 946.380005, 948.000000, 949.619995, 951.239990, 495.659973, 775.979980, 956.099976, 957.719971, 959.339966, 960.959961, 962.579956, 964.199951, 965.819946, 967.440002, 504.059998, 236.679977, 214.880005, 215.239990, 215.599991, 215.959991, 216.319992, 216.679993, 217.040009, 217.399994, 49.959995, 850.180054, 1131.079956, 1132.939941, 1134.800049, 1136.660034, 1138.520020, 1140.380005, 1142.239990, 1144.100098, 665.060059, 852.209961, 1059.689941, 1061.399902, 1063.110107, 1064.820068, 1066.530029, 1068.239990, 1069.950073, 1071.660034, 566.370056, 866.010010, 1076.790039, 1078.500000, 1080.209961, 1081.920044, 1083.630005, 1085.339966, 1087.050049, 1088.760010, 575.369995, 879.809998, 1093.890015, 1095.599976, 1097.310059, 1099.020020, 1100.729980, 1102.439941, 1104.149902, 1105.859985, 584.369995, 893.609985, 1110.989990, 1112.699951, 1114.410034, 1116.120117, 1117.830078, 1119.540039, 1121.250000, 1122.959961, 593.370056, 907.410034, 1128.089966, 1129.800049, 1131.510010, 1133.220093, 1134.929932, 1136.639893, 1138.349976, 1140.060059, 602.369995, 921.209961, 1145.189941, 1146.900024, 1148.609985, 1150.320068, 1152.030029, 1153.739990, 1155.449951, 1157.160034, 611.370056, 935.010010, 1162.290039, 1164.000000, 1165.709961, 1167.420044, 1169.130005, 1170.839966, 1172.550049, 1174.260010, 620.369995, 948.809998, 1179.390015, 1181.099976, 1182.810059, 1184.520020, 1186.229980, 1187.939941, 1189.650024, 1191.359985, 629.370056, 304.099976, 292.039978, 292.460022, 292.880005, 293.300018, 293.720001, 294.140015, 294.559998, 294.980042, 85.700005}, sd::DataType::FLOAT32); - - input.linspace(-10, 0.1); - weights.linspace(-2, 0.1); - - sd::ops::depthwise_conv2d op; - auto results = op.evaluate({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expOutput.isSameShape(output)); - ASSERT_TRUE(expOutput.equalsTo(output, 1e-4)); - + int bS = 1, iH = 10, iW = 10, iC = 8, mC = 1, kH = 3, kW = 3, sH = 1, sW = 1, + pH = 0, pW = 0, dH = 1, dW = 1; + int oC = iC * mC; + int oH = 10, oW = 10; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NHWC, 0-NCHW + + NDArray input('c', {bS, iC, iH, iW}, sd::DataType::FLOAT32); + NDArray weights('c', {kH, kW, iC, mC}, sd::DataType::FLOAT32); + + NDArray expOutput( + 'c', {bS, oC, oH, oW}, + {-103.360001, -131.440002, -130.000000, -128.559998, -127.120003, + -125.680000, -124.240005, -122.799995, -121.360001, -66.720001, + -76.199997, -81.239998, -80.160004, -79.080002, -78.000000, + -76.919998, -75.840004, -74.760002, -73.680000, -29.400002, + -66.599998, -70.440002, -69.360001, -68.279999, -67.199997, + -66.120003, -65.040001, -63.959999, -62.879997, -24.599997, + -57.000000, -59.639999, -58.560005, -57.479996, -56.399998, + -55.320000, -54.240002, -53.159996, -52.080002, -19.799997, + -47.400002, -48.840000, -47.760002, -46.680000, -45.599998, + -44.520000, -43.440002, -42.360001, -41.279999, -15.000000, + -37.799999, -38.040001, -36.959999, -35.879997, -34.799999, + -33.720001, -32.639999, -31.560001, -30.479996, -10.199999, + -28.200001, -27.240002, -26.160000, -25.080002, -24.000000, + -22.919998, -21.840000, -20.759998, -19.679998, -5.400000, + -18.599998, -16.439999, -15.360001, -14.280001, -13.200001, + -12.120001, -11.040000, -9.960001, -8.880000, -0.600000, + -9.000000, -5.639999, -4.560000, -3.480000, -2.400000, + -1.320001, -0.240000, 0.840001, 1.920000, 4.200000, + 0.160000, 3.920000, 3.920000, 3.920000, 3.920000, + 3.920000, 3.920001, 3.920000, 3.920000, 3.520000, + 8.860001, 12.920000, 14.420000, 15.920000, 17.420000, + 18.920000, 20.420000, 21.920000, 23.420000, 13.820000, + 20.430000, 27.750000, 28.919998, 30.090000, 31.260000, + 32.430000, 33.600002, 34.770000, 35.939999, 19.709999, + 30.630001, 39.450001, 40.619999, 41.790001, 42.960003, + 44.129997, 45.299999, 46.470001, 47.639999, 25.110001, + 40.829998, 51.150002, 52.320000, 53.489998, 54.660004, + 55.829994, 57.000000, 58.169998, 59.340004, 30.510002, + 51.029999, 62.849998, 64.019997, 65.190002, 66.360001, + 67.529999, 68.699997, 69.870003, 71.040001, 35.910000, + 61.229996, 74.550003, 75.720001, 76.889999, 78.059998, + 79.229996, 80.400002, 81.570000, 82.740005, 41.310001, + 71.430000, 86.250000, 87.419998, 88.589996, 89.760002, + 90.929993, 92.099991, 93.270004, 94.440002, 46.709999, + 81.630005, 97.949997, 99.120003, 100.290009, 101.459999, + 102.630005, 103.800003, 104.970001, 106.139999, 52.110001, + 91.830002, 109.649994, 110.820007, 111.990005, 113.159996, + 114.330002, 115.500000, 116.669998, 117.839996, 57.509995, + 19.580000, 9.079998, 9.139999, 9.199999, 9.259996, + 9.320001, 9.379998, 9.440000, 9.500000, -8.740000, + 129.080002, 169.279999, 170.839996, 172.399994, 173.960007, + 175.520004, 177.080002, 178.639999, 180.199982, 102.360001, + 129.059998, 154.739990, 156.000000, 157.259995, 158.520004, + 159.779999, 161.039993, 162.300003, 163.559998, 80.820000, + 139.860001, 167.340012, 168.600006, 169.860001, 171.119995, + 172.380005, 173.639999, 174.899994, 176.160004, 86.820000, + 150.660004, 179.940002, 181.200012, 182.459991, 183.720001, + 184.980011, 186.239990, 187.500000, 188.759995, 92.820007, + 161.459991, 192.540009, 193.799988, 195.059998, 196.319992, + 197.579987, 198.839996, 200.100006, 201.360001, 98.820000, + 172.259995, 205.139999, 206.399994, 207.660004, 208.919983, + 210.179993, 211.440002, 212.700012, 213.959991, 104.819992, + 183.059998, 217.739990, 219.000000, 220.259995, 221.519989, + 222.779999, 224.039993, 225.300018, 226.559998, 110.819992, + 193.860016, 230.339996, 231.600006, 232.860001, 234.119995, + 235.380005, 236.639999, 237.900009, 239.160004, 116.820000, + 204.660004, 242.940002, 244.199982, 245.459991, 246.720001, + 247.980011, 249.239990, 250.500000, 251.759995, 122.819992, + 47.000000, 26.240004, 26.360004, 26.479998, 26.600002, + 26.720001, 26.840002, 26.959997, 27.080000, -12.999998, + 257.299988, 337.640015, 339.260010, 340.879974, 342.499969, + 344.119995, 345.740021, 347.359985, 348.979980, 198.899994, + 249.690002, 299.729980, 301.079987, 302.429993, 303.779999, + 305.130005, 306.480011, 307.829987, 309.179993, 153.929993, + 261.089996, 313.230011, 314.580017, 315.929993, 317.279968, + 318.630005, 319.979980, 321.329987, 322.679993, 160.529999, + 272.489990, 326.729980, 328.079987, 329.429993, 330.779968, + 332.130005, 333.479980, 334.829987, 336.179993, 167.130005, + 283.889984, 340.230011, 341.580017, 342.929993, 344.279999, + 345.630005, 346.980011, 348.330017, 349.679993, 173.729996, + 295.289978, 353.729980, 355.079987, 356.429993, 357.779968, + 359.130005, 360.480011, 361.829987, 363.179993, 180.329987, + 306.690002, 367.230011, 368.580017, 369.929993, 371.279999, + 372.630005, 373.980011, 375.330017, 376.679993, 186.929993, + 318.089996, 380.729980, 382.080017, 383.429993, 384.779968, + 386.130005, 387.479980, 388.829987, 390.179993, 193.529984, + 329.489990, 394.229980, 395.579987, 396.929993, 398.279999, + 399.630005, 400.980011, 402.330017, 403.679993, 200.130005, + 82.419998, 55.400005, 55.580002, 55.759995, 55.939999, + 56.120003, 56.299995, 56.479996, 56.659996, -9.260002, + 393.520020, 518.000000, 519.679993, 521.359985, 523.040039, + 524.720032, 526.400024, 528.080017, 529.760010, 303.440002, + 382.320007, 462.720032, 464.160004, 465.600037, 467.040009, + 468.479980, 469.919983, 471.359985, 472.800018, 239.040009, + 394.320007, 477.119995, 478.559998, 480.000000, 481.440002, + 482.880005, 484.320007, 485.760010, 487.200012, 246.240005, + 406.320007, 491.520020, 492.960022, 494.400024, 495.839996, + 497.280029, 498.720032, 500.160004, 501.600037, 253.440002, + 418.320007, 505.919983, 507.359985, 508.800018, 510.240051, + 511.680023, 513.119995, 514.559998, 516.000000, 260.640015, + 430.319977, 520.320007, 521.760010, 523.200012, 524.640015, + 526.079956, 527.520020, 528.960022, 530.400024, 267.839996, + 442.320007, 534.720032, 536.160034, 537.600037, 539.040039, + 540.479980, 541.919983, 543.359985, 544.800049, 275.040009, + 454.320007, 549.119995, 550.559998, 552.000000, 553.440002, + 554.880005, 556.320007, 557.760010, 559.200012, 282.239990, + 466.320007, 563.520020, 564.960022, 566.400024, 567.839966, + 569.280029, 570.720032, 572.160034, 573.600037, 289.440002, + 125.839996, 96.559998, 96.799995, 97.040009, 97.280014, + 97.520004, 97.759995, 98.000000, 98.240013, 2.480007, + 537.739990, 710.359985, 712.099976, 713.840027, 715.579956, + 717.319946, 719.059998, 720.799988, 722.539978, 415.980011, + 526.950012, 643.710022, 645.240051, 646.770020, 648.300049, + 649.829956, 651.359985, 652.890015, 654.419983, 336.149994, + 539.549988, 659.010010, 660.539978, 662.070007, 663.600037, + 665.130005, 666.660034, 668.190002, 669.720032, 343.950012, + 552.150024, 674.309998, 675.839966, 677.369995, 678.900024, + 680.429993, 681.960022, 683.490051, 685.020020, 351.750000, + 564.750000, 689.609985, 691.140015, 692.669983, 694.200012, + 695.729980, 697.260010, 698.789978, 700.320007, 359.549988, + 577.349976, 704.910034, 706.440002, 707.970032, 709.500000, + 711.029968, 712.559998, 714.089966, 715.619995, 367.350037, + 589.950012, 720.210022, 721.740051, 723.270020, 724.800049, + 726.329956, 727.859985, 729.390015, 730.919983, 375.149994, + 602.549988, 735.510010, 737.039978, 738.570007, 740.100037, + 741.630005, 743.160034, 744.690002, 746.220032, 382.950012, + 615.150024, 750.809998, 752.339966, 753.869995, 755.399963, + 756.929993, 758.460022, 759.990051, 761.520020, 390.750000, + 177.260010, 149.720001, 150.020004, 150.319992, 150.619995, + 150.919998, 151.220001, 151.520004, 151.819992, 22.220009, + 689.959961, 914.720032, 916.519958, 918.319946, 920.119995, + 921.919983, 923.719971, 925.520020, 927.320007, 536.519958, + 683.579956, 842.699951, 844.319946, 845.940002, 847.559998, + 849.179993, 850.799988, 852.419983, 854.039978, 445.260010, + 696.779968, 858.900024, 860.520020, 862.140015, 863.760010, + 865.380005, 867.000000, 868.619995, 870.239990, 453.659973, + 709.979980, 875.099976, 876.719971, 878.339966, 879.959961, + 881.579956, 883.199951, 884.819946, 886.440002, 462.059998, + 723.179993, 891.299988, 892.919983, 894.539978, 896.159973, + 897.779968, 899.400024, 901.020020, 902.640015, 470.459991, + 736.380005, 907.500000, 909.119995, 910.739990, 912.359985, + 913.979980, 915.599976, 917.219971, 918.839966, 478.859985, + 749.579956, 923.699951, 925.319946, 926.940002, 928.559998, + 930.179993, 931.799988, 933.419983, 935.039978, 487.260010, + 762.779968, 939.900024, 941.520020, 943.140015, 944.760010, + 946.380005, 948.000000, 949.619995, 951.239990, 495.659973, + 775.979980, 956.099976, 957.719971, 959.339966, 960.959961, + 962.579956, 964.199951, 965.819946, 967.440002, 504.059998, + 236.679977, 214.880005, 215.239990, 215.599991, 215.959991, + 216.319992, 216.679993, 217.040009, 217.399994, 49.959995, + 850.180054, 1131.079956, 1132.939941, 1134.800049, 1136.660034, + 1138.520020, 1140.380005, 1142.239990, 1144.100098, 665.060059, + 852.209961, 1059.689941, 1061.399902, 1063.110107, 1064.820068, + 1066.530029, 1068.239990, 1069.950073, 1071.660034, 566.370056, + 866.010010, 1076.790039, 1078.500000, 1080.209961, 1081.920044, + 1083.630005, 1085.339966, 1087.050049, 1088.760010, 575.369995, + 879.809998, 1093.890015, 1095.599976, 1097.310059, 1099.020020, + 1100.729980, 1102.439941, 1104.149902, 1105.859985, 584.369995, + 893.609985, 1110.989990, 1112.699951, 1114.410034, 1116.120117, + 1117.830078, 1119.540039, 1121.250000, 1122.959961, 593.370056, + 907.410034, 1128.089966, 1129.800049, 1131.510010, 1133.220093, + 1134.929932, 1136.639893, 1138.349976, 1140.060059, 602.369995, + 921.209961, 1145.189941, 1146.900024, 1148.609985, 1150.320068, + 1152.030029, 1153.739990, 1155.449951, 1157.160034, 611.370056, + 935.010010, 1162.290039, 1164.000000, 1165.709961, 1167.420044, + 1169.130005, 1170.839966, 1172.550049, 1174.260010, 620.369995, + 948.809998, 1179.390015, 1181.099976, 1182.810059, 1184.520020, + 1186.229980, 1187.939941, 1189.650024, 1191.359985, 629.370056, + 304.099976, 292.039978, 292.460022, 292.880005, 293.300018, + 293.720001, 294.140015, 294.559998, 294.980042, 85.700005}, + sd::DataType::FLOAT32); + + input.linspace(-10, 0.1); + weights.linspace(-2, 0.1); + + sd::ops::depthwise_conv2d op; + auto results = + op.evaluate({&input, &weights}, {}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output, 1e-4)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests2, depthwise_conv2d_10) { - - int bS=1, iH=3,iW=3, iC=2,mC=2, kH=1,kW=1, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oC=iC*mC; - int oH=3,oW=3; - int paddingMode = 0; // 1-SAME, 0-VALID; - int dataFormat = 0; // 1-NHWC, 0-NCHW - int wFormat = 1; // 0-[kH, kW, iC, mC], 1-[mC, iC, kH, kW], 2-[mC, kH, kW, iC] - - NDArray input('c', {bS, iC, iH, iW}, {0.6793503761291504, 0.35508695244789124, 0.842789351940155, 0.20031332969665527, 0.7014986872673035, 0.3106933832168579, - 0.44793984293937683, 0.9380097389221191, 0.3266739547252655, 0.15187257528305054, 0.3833175301551819, 0.7821229696273804, - 0.19880719482898712, 0.7985635995864868, 0.16326339542865753, 0.14696824550628662, 0.2608966827392578, 0.13505761325359344}, sd::DataType::FLOAT32); - NDArray weights('c', {mC, iC, kH, kW}, {0.130845, 0.569885, 0.644284, 0.198968}, sd::DataType::FLOAT32); - NDArray biases('c', {iC*mC}, {0.6123566627502441, 0.37637925148010254, 0.17464971542358398, 0.4270855486392975}, sd::DataType::FLOAT32); - - NDArray expOutput('c', {bS, oC, oH, oW}, {0.7012459761288241, 0.6588178652487691, 0.722631079971582, 0.6385665758716108, 0.7041439625563628, 0.6530092074102978, - 0.670967162534851, 0.735090151337225, 0.6551001785478623, 0.8140738359624038, 0.6051560970782859, 0.9193749546773375, 0.5054379267801892, 0.8283436386757472, - 0.5765540302788565, 0.6649797296980537, 0.9807239274294943, 0.586850056971322, 0.261199593183985, 0.3930965634902499, 0.6203697362284615, 0.28794692117826504, - 0.6297390019475202, 0.26769104886224415, 0.25840469001015975, 0.3233307788551656, 0.25161700129415276, 0.4573034071191504, 0.5033536625992294, 0.5827033826425385, - 0.4666419179635315, 0.585974550122895, 0.4595698215161401, 0.45632759998045813, 0.4789957702325296, 0.4539577593482922}, sd::DataType::FLOAT32); - - sd::ops::depthwise_conv2d op; - auto results = op.evaluate({&input, &weights, &biases}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat, wFormat}); - auto* output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expOutput.isSameShape(output)); - ASSERT_TRUE(expOutput.equalsTo(output)); + int bS = 1, iH = 3, iW = 3, iC = 2, mC = 2, kH = 1, kW = 1, sH = 1, sW = 1, + pH = 0, pW = 0, dH = 1, dW = 1; + int oC = iC * mC; + int oH = 3, oW = 3; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NHWC, 0-NCHW + int wFormat = + 1; // 0-[kH, kW, iC, mC], 1-[mC, iC, kH, kW], 2-[mC, kH, kW, iC] + + NDArray input('c', {bS, iC, iH, iW}, + {0.6793503761291504, 0.35508695244789124, 0.842789351940155, + 0.20031332969665527, 0.7014986872673035, 0.3106933832168579, + 0.44793984293937683, 0.9380097389221191, 0.3266739547252655, + 0.15187257528305054, 0.3833175301551819, 0.7821229696273804, + 0.19880719482898712, 0.7985635995864868, 0.16326339542865753, + 0.14696824550628662, 0.2608966827392578, 0.13505761325359344}, + sd::DataType::FLOAT32); + NDArray weights('c', {mC, iC, kH, kW}, + {0.130845, 0.569885, 0.644284, 0.198968}, + sd::DataType::FLOAT32); + NDArray biases('c', {iC * mC}, + {0.6123566627502441, 0.37637925148010254, 0.17464971542358398, + 0.4270855486392975}, + sd::DataType::FLOAT32); + + NDArray expOutput( + 'c', {bS, oC, oH, oW}, + {0.7012459761288241, 0.6588178652487691, 0.722631079971582, + 0.6385665758716108, 0.7041439625563628, 0.6530092074102978, + 0.670967162534851, 0.735090151337225, 0.6551001785478623, + 0.8140738359624038, 0.6051560970782859, 0.9193749546773375, + 0.5054379267801892, 0.8283436386757472, 0.5765540302788565, + 0.6649797296980537, 0.9807239274294943, 0.586850056971322, + 0.261199593183985, 0.3930965634902499, 0.6203697362284615, + 0.28794692117826504, 0.6297390019475202, 0.26769104886224415, + 0.25840469001015975, 0.3233307788551656, 0.25161700129415276, + 0.4573034071191504, 0.5033536625992294, 0.5827033826425385, + 0.4666419179635315, 0.585974550122895, 0.4595698215161401, + 0.45632759998045813, 0.4789957702325296, 0.4539577593482922}, + sd::DataType::FLOAT32); + + sd::ops::depthwise_conv2d op; + auto results = op.evaluate( + {&input, &weights, &biases}, {}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, dataFormat, wFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests2, depthwise_conv2d_11) { - - int bS=1, iH=10,iW=10, iC=8,mC=1, kH=3,kW=3, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oC=iC*mC; - int oH=10,oW=10; - int paddingMode = 1; // 1-SAME, 0-VALID; - int dataFormat = 1; // 1-NHWC, 0-NCHW - int wFormat = 2; // 0-[kH, kW, iC, mC], 1-[mC, iC, kH, kW], 2-[mC, kH, kW, iC] - - NDArray input('c', {bS, iH, iW, iC}, sd::DataType::FLOAT32); - NDArray weights('c', {mC, kH, kW, iC}, {-2., -1.9, -1.8, -1.7, -1.6, -1.5, -1.4, -1.3, -1.2, -1.1, -1., -0.9, -0.8, -0.7, -0.6, -0.5, -0.4, -0.3, -0.2, -0.1, - 0., 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1., 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2., 2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9, 3., - 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8, 3.9, 4., 4.1, 4.2, 4.3, 4.4, 4.5, 4.6, 4.7, 4.8, 4.9, 5., 5.1}, sd::DataType::FLOAT32); - - NDArray expOutput('c', {bS, oH, oW, oC}, {-42.879997, -43.959999, -44.959999, -45.879997, -46.720005, -47.480003, -48.160000, -48.760002, -43.519997, -45.139999, -46.639996, -48.020000, -49.280003, -50.419998, -51.440006, -52.340000, -31.999998, -33.139999, -34.160000, -35.060001, -35.840004, -36.500004, -37.039997, -37.459999, -20.480000, - -21.139997, -21.680000, -22.100000, -22.399998, -22.579998, -22.639996, -22.580002, -8.960000, -9.139998, -9.200002, -9.140001, -8.960001, -8.660000, -8.240002, -7.700001, 2.560000, 2.860002, 3.279998, 3.820000, 4.480001, 5.260000, 6.160001, 7.180000, 14.080000, 14.860000, 15.759998, 16.779999, 17.920002, 19.180000, 20.560001, 22.059998, - 25.600000, 26.860001, 28.239998, 29.739998, 31.360001, 33.099998, 34.959999, 36.939999, 37.119999, 38.860001, 40.720001, 42.699997, 44.800003, 47.020000, 49.360001, 51.820000, 26.239998, 27.400002, 28.639999, 29.959999, 31.360001, 32.840000, 34.400002, 36.040001, 62.400002, 62.459999, 62.639999, 62.940002, 63.360001, 63.900002, 64.559998, - 65.340004, 106.080002, 106.169998, 106.440002, 106.889999, 107.519997, 108.330002, 109.320000, 110.490005, 114.720001, 115.529999, 116.520004, 117.690002, 119.040009, 120.570000, 122.279999, 124.169998, 123.359985, 124.889999, 126.599998, 128.490005, 130.559998, 132.809998, 135.240005, 137.850006, 132.000000, 134.250000, 136.679993, - 139.290009, 142.080002, 145.049988, 148.199997, 151.529999, 140.639999, 143.610001, 146.760010, 150.089996, 153.600006, 157.290009, 161.160004, 165.209991, 149.279999, 152.970001, 156.839996, 160.889999, 165.120010, 169.529999, 174.119995, 178.889999, 157.919998, 162.330002, 166.919983, 171.690002, 176.639999, 181.769989, 187.079987, - 192.570007, 166.559998, 171.690002, 177.000000, 182.489990, 188.160004, 194.010010, 200.040009, 206.250000, 100.799995, 104.220001, 107.760002, 111.419998, 115.200005, 119.099998, 123.120003, 127.260010, 139.200012, 144.059998, 149.040009, 154.139999, 159.360001, 164.699997, 170.160004, 175.739990, 192.479996, 199.770020, 207.239990, - 214.889999, 222.720001, 230.730011, 238.919998, 247.290009, 201.119995, 209.129990, 217.319992, 225.690002, 234.240005, 242.970001, 251.880005, 260.970001, 209.760010, 218.489990, 227.399994, 236.490005, 245.760010, 255.209991, 264.839996, 274.649994, 218.399994, 227.850006, 237.479996, 247.289993, 257.279999, 267.449982, 277.799988, - 288.330017, 227.040009, 237.209991, 247.559998, 258.089996, 268.800018, 279.690002, 290.760010, 302.010010, 235.679993, 246.570007, 257.639984, 268.889984, 280.320007, 291.929993, 303.720001, 315.690002, 244.320007, 255.929993, 267.720001, 279.690002, 291.839996, 304.169983, 316.679993, 329.369995, 252.959991, 265.290009, 277.799988, - 290.489990, 303.359985, 316.410004, 329.640015, 343.050018, 139.199997, 147.419998, 155.760010, 164.220001, 172.799988, 181.500000, 190.319992, 199.260010, 216.000000, 225.660004, 235.440002, 245.339996, 255.360016, 265.500000, 275.760010, 286.140015, 278.880005, 293.369995, 308.040009, 322.889984, 337.920013, 353.129974, 368.519989, - 384.090027, 287.520020, 302.730011, 318.119995, 333.690002, 349.440002, 365.369995, 381.479980, 397.770020, 296.160004, 312.089996, 328.199982, 344.489990, 360.960022, 377.609985, 394.440002, 411.449982, 304.799988, 321.450012, 338.280029, 355.289978, 372.480011, 389.850006, 407.399994, 425.130005, 313.440002, 330.809998, 348.359985, 366.089996, 384.000000, 402.090027, 420.359985, 438.809998, 322.079987, 340.169983, 358.440002, 376.889984, 395.520020, 414.329987, 433.320007, 452.489990, 330.720001, 349.530029, 368.520020, 387.690002, 407.039978, 426.570007, 446.279999, 466.170013, 339.360016, 358.890015, 378.599976, 398.490021, 418.559998, 438.809998, 459.239990, 479.849976, 177.600006, 190.619995, 203.759995, 217.020004, 230.399994, 243.899994, 257.519989, 271.260010, 292.799988, 307.260010, 321.839996, 336.539978, 351.360016, 366.299988, 381.359985, 396.540009, 365.279999, 386.970001, 408.839996, 430.889984, 453.120026, 475.529968, 498.119995, 520.890015, 373.920013, 396.329987, 418.919983, 441.690002, 464.640015, 487.769958, 511.079987, 534.570007, 382.559998, 405.690002, 429.000000, 452.489990, 476.160004, 500.010010, 524.039978, 548.250000, 391.200012, 415.049988, 439.080017, 463.290009, 487.679993, 512.250000, 537.000000, 561.930054, 399.839996, 424.409973, 449.160034, 474.089966, 499.200012, 524.489990, 549.959961, 575.609985, 408.479980, 433.770020, 459.239990, 484.889954, 510.720032, 536.729980, 562.919983, 589.290039, 417.119995, 443.130005, 469.319977, 495.690002, 522.239990, 548.969971, 575.880005, 602.969971, 425.760010, 452.489990, 479.399994, 506.489990, 533.760010, 561.209961, 588.839966, 616.650024, 216.000000, 233.819992, 251.760010, 269.820007, 288.000000, 306.299988, 324.719971, 343.260010, 369.600006, 388.859985, 408.239990, 427.739990, 447.360016, 467.100006, 486.959961, 506.940002, 451.679993, 480.570007, 509.639984, 538.890015, 568.320007, 597.929993, 627.719971, 657.690002, 460.320007, 489.929993, 519.719971, 549.690002, 579.840027, 610.170044, 640.680054, 671.369995, 468.960022, 499.289978, 529.799988, 560.489990, 591.359985, 622.409973, 653.640015, 685.049988, 477.599976, 508.650024, 539.880005, 571.289978, 602.880005, 634.650024, 666.599976, 698.729980, 486.239990, 518.010010, 549.960022, 582.089966, 614.400024, 646.890015, 679.559937, 712.410034, 494.879974, 527.369995, 560.039978, 592.890015, 625.920044, 659.130005, 692.520020, 726.089966, 503.519989, 536.729980, 570.119995, 603.689941, 637.440063, 671.369995, 705.480042, 739.770020, 512.160034, 546.089966, 580.199951, 614.489990, 648.960022, 683.609985, 718.440002, 753.449951, 254.400009, 277.020020, 299.760010, 322.619995, 345.600006, 368.700012, 391.919983, 415.260010, 446.399994, 470.459961, 494.640015, 518.940002, 543.360046, 567.900024, 592.559998, 617.340027, 538.080017, 574.170044, 610.440002, 646.890015, 683.520020, 720.329956, 757.320007, 794.489990, 546.719971, 583.530029, 620.520020, 657.690002, 695.040039, 732.570007, 770.279968, 808.169983, 555.359985, 592.889954, 630.599976, 668.489990, 706.559998, 744.809998, 783.239990, 821.849976, 564.000000, 602.250000, 640.679993, 679.289978, 718.080017, 757.050049, 796.199951, 835.530029, 572.640015, 611.609985, 650.760010, 690.089966, 729.600037, 769.289978, 809.160034, 849.210083, 581.279968, 620.970032, 660.839966, 700.889954, 741.119995, 781.529968, 822.119995, 862.890015, 589.919983, 630.330017, 670.919983, 711.690002, 752.640015, 793.770020, 835.079956, 876.570007, 598.559998, 639.690002, 681.000000, 722.490051, 764.160034, 806.010010, 848.039978, 890.250061, 292.799988, 320.220001, 347.760010, 375.419983, 403.200012, 431.100006, 459.119995, 487.260010, 523.199951, 552.059998, 581.040039, 610.139954, 639.360046, 668.699951, 698.159973, 727.739990, 624.479980, 667.770020, 711.239990, 754.890015, 798.719971, 842.729980, 886.919983, 931.290039, 633.119995, 677.130005, 721.319946, 765.690002, 810.239990, 854.969971, 899.880005, 944.969971, 641.760010, 686.489990, 731.400024, 776.489990, 821.760010, 867.209961, 912.839966, 958.650024, 650.400024, 695.849976, 741.479980, 787.290039, 833.279968, 879.449951, 925.799927, 972.330017, 659.040039, 705.210022, 751.559998, 798.089966, 844.800049, 891.690002, 938.760010, 986.010010, 667.679993, 714.569946, 761.640015, 808.890015, 856.320007, 903.929993, 951.719971, 999.690063, 676.320007, 723.929993, 771.719971, 819.690002, 867.839966, 916.169922, 964.679932, 1013.369995, 684.959961, 733.290039, 781.800049, 830.489990, 879.359985, 928.410034, 977.640015, 1027.050049, 331.199982, 363.419983, 395.760010, 428.220001, 460.799988, 493.500000, 526.320007, 559.260010, 600.000000, 633.660034, 667.440002, 701.339966, 735.359985, 769.500000, 803.759949, 838.140015, 710.880005, 761.369995, 812.039978, 862.889893, 913.919983, 965.130005, 1016.520020, 1068.090088, 719.520020, 770.729980, 822.119934, 873.689941, 925.440063, 977.369995, 1029.479980, 1081.770020, 728.160034, 780.090088, 832.199951, 884.489990, 936.960022, 989.610046, 1042.439941, 1095.449951, 736.799927, 789.449951, 842.280029, 895.290039, 948.480042, 1001.849976, 1055.399902, 1109.129883, 745.439941, 798.810059, 852.359985, 906.089966, 960.000000, 1014.089966, 1068.359985, 1122.810059, 754.080017, 808.170044, 862.440002, 916.890015, 971.520020, 1026.330078, 1081.319946, 1136.489990, 762.720032, 817.530029, 872.520020, 927.689941, 983.040039, 1038.569946, 1094.280029, 1150.169922, 771.359985, 826.890015, 882.599976, 938.489990, 994.559998, 1050.810059, 1107.239990, 1163.849976, 369.599976, 406.619995, 443.760010, 481.020020, 518.400024, 555.900024, 593.520020, 631.260010, 113.279999, 136.839996, 160.480011, 184.199982, 208.000015, 231.880005, 255.839996, 279.880005, 31.359985, 66.699989, 102.160004, 137.740005, 173.440002, 209.260010, 245.199982, 281.260010, 31.359993, 67.179993, 103.120003, 139.179993, 175.360016, 211.660004, 248.079987, 284.619995, 31.359993, 67.659996, 104.080009, 140.619995, 177.280014, 214.060013, 250.959991, 287.980011, 31.359993, 68.139999, 105.039993, 142.059982, 179.200027, 216.459991, 253.839996, 291.339996, 31.360008, 68.619995, 106.000000, 143.499985, 181.119995, 218.860001, 256.719971, 294.700012, 31.360001, 69.099991, 106.959984, 144.939987, 183.040009, 221.260010, 259.600006, 298.059998, 31.360008, 69.579971, 107.920006, 146.379990, 184.960007, 223.660004, 262.479980, 301.419983, 31.360001, 70.059975, 108.880020, 147.819977, 186.880020, 226.059998, 265.359985, 304.779999, -83.840004, -58.040001, -32.159988, -6.200012, 19.840012, 45.959984, 72.159996, 98.440010}, sd::DataType::FLOAT32); - - input.linspace(-10, 0.1); - weights.linspace(-2, 0.1); - - sd::ops::depthwise_conv2d op; - auto results = op.evaluate({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat, wFormat}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expOutput.isSameShape(output)); - ASSERT_TRUE(expOutput.equalsTo(output)); + int bS = 1, iH = 10, iW = 10, iC = 8, mC = 1, kH = 3, kW = 3, sH = 1, sW = 1, + pH = 0, pW = 0, dH = 1, dW = 1; + int oC = iC * mC; + int oH = 10, oW = 10; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NHWC, 0-NCHW + int wFormat = + 2; // 0-[kH, kW, iC, mC], 1-[mC, iC, kH, kW], 2-[mC, kH, kW, iC] + + NDArray input('c', {bS, iH, iW, iC}, sd::DataType::FLOAT32); + NDArray weights( + 'c', {mC, kH, kW, iC}, + {-2., -1.9, -1.8, -1.7, -1.6, -1.5, -1.4, -1.3, -1.2, -1.1, -1., -0.9, + -0.8, -0.7, -0.6, -0.5, -0.4, -0.3, -0.2, -0.1, 0., 0.1, 0.2, 0.3, + 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1., 1.1, 1.2, 1.3, 1.4, 1.5, + 1.6, 1.7, 1.8, 1.9, 2., 2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, + 2.8, 2.9, 3., 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8, 3.9, + 4., 4.1, 4.2, 4.3, 4.4, 4.5, 4.6, 4.7, 4.8, 4.9, 5., 5.1}, + sd::DataType::FLOAT32); + + NDArray expOutput( + 'c', {bS, oH, oW, oC}, + {-42.879997, -43.959999, -44.959999, -45.879997, -46.720005, + -47.480003, -48.160000, -48.760002, -43.519997, -45.139999, + -46.639996, -48.020000, -49.280003, -50.419998, -51.440006, + -52.340000, -31.999998, -33.139999, -34.160000, -35.060001, + -35.840004, -36.500004, -37.039997, -37.459999, -20.480000, + -21.139997, -21.680000, -22.100000, -22.399998, -22.579998, + -22.639996, -22.580002, -8.960000, -9.139998, -9.200002, + -9.140001, -8.960001, -8.660000, -8.240002, -7.700001, + 2.560000, 2.860002, 3.279998, 3.820000, 4.480001, + 5.260000, 6.160001, 7.180000, 14.080000, 14.860000, + 15.759998, 16.779999, 17.920002, 19.180000, 20.560001, + 22.059998, 25.600000, 26.860001, 28.239998, 29.739998, + 31.360001, 33.099998, 34.959999, 36.939999, 37.119999, + 38.860001, 40.720001, 42.699997, 44.800003, 47.020000, + 49.360001, 51.820000, 26.239998, 27.400002, 28.639999, + 29.959999, 31.360001, 32.840000, 34.400002, 36.040001, + 62.400002, 62.459999, 62.639999, 62.940002, 63.360001, + 63.900002, 64.559998, 65.340004, 106.080002, 106.169998, + 106.440002, 106.889999, 107.519997, 108.330002, 109.320000, + 110.490005, 114.720001, 115.529999, 116.520004, 117.690002, + 119.040009, 120.570000, 122.279999, 124.169998, 123.359985, + 124.889999, 126.599998, 128.490005, 130.559998, 132.809998, + 135.240005, 137.850006, 132.000000, 134.250000, 136.679993, + 139.290009, 142.080002, 145.049988, 148.199997, 151.529999, + 140.639999, 143.610001, 146.760010, 150.089996, 153.600006, + 157.290009, 161.160004, 165.209991, 149.279999, 152.970001, + 156.839996, 160.889999, 165.120010, 169.529999, 174.119995, + 178.889999, 157.919998, 162.330002, 166.919983, 171.690002, + 176.639999, 181.769989, 187.079987, 192.570007, 166.559998, + 171.690002, 177.000000, 182.489990, 188.160004, 194.010010, + 200.040009, 206.250000, 100.799995, 104.220001, 107.760002, + 111.419998, 115.200005, 119.099998, 123.120003, 127.260010, + 139.200012, 144.059998, 149.040009, 154.139999, 159.360001, + 164.699997, 170.160004, 175.739990, 192.479996, 199.770020, + 207.239990, 214.889999, 222.720001, 230.730011, 238.919998, + 247.290009, 201.119995, 209.129990, 217.319992, 225.690002, + 234.240005, 242.970001, 251.880005, 260.970001, 209.760010, + 218.489990, 227.399994, 236.490005, 245.760010, 255.209991, + 264.839996, 274.649994, 218.399994, 227.850006, 237.479996, + 247.289993, 257.279999, 267.449982, 277.799988, 288.330017, + 227.040009, 237.209991, 247.559998, 258.089996, 268.800018, + 279.690002, 290.760010, 302.010010, 235.679993, 246.570007, + 257.639984, 268.889984, 280.320007, 291.929993, 303.720001, + 315.690002, 244.320007, 255.929993, 267.720001, 279.690002, + 291.839996, 304.169983, 316.679993, 329.369995, 252.959991, + 265.290009, 277.799988, 290.489990, 303.359985, 316.410004, + 329.640015, 343.050018, 139.199997, 147.419998, 155.760010, + 164.220001, 172.799988, 181.500000, 190.319992, 199.260010, + 216.000000, 225.660004, 235.440002, 245.339996, 255.360016, + 265.500000, 275.760010, 286.140015, 278.880005, 293.369995, + 308.040009, 322.889984, 337.920013, 353.129974, 368.519989, + 384.090027, 287.520020, 302.730011, 318.119995, 333.690002, + 349.440002, 365.369995, 381.479980, 397.770020, 296.160004, + 312.089996, 328.199982, 344.489990, 360.960022, 377.609985, + 394.440002, 411.449982, 304.799988, 321.450012, 338.280029, + 355.289978, 372.480011, 389.850006, 407.399994, 425.130005, + 313.440002, 330.809998, 348.359985, 366.089996, 384.000000, + 402.090027, 420.359985, 438.809998, 322.079987, 340.169983, + 358.440002, 376.889984, 395.520020, 414.329987, 433.320007, + 452.489990, 330.720001, 349.530029, 368.520020, 387.690002, + 407.039978, 426.570007, 446.279999, 466.170013, 339.360016, + 358.890015, 378.599976, 398.490021, 418.559998, 438.809998, + 459.239990, 479.849976, 177.600006, 190.619995, 203.759995, + 217.020004, 230.399994, 243.899994, 257.519989, 271.260010, + 292.799988, 307.260010, 321.839996, 336.539978, 351.360016, + 366.299988, 381.359985, 396.540009, 365.279999, 386.970001, + 408.839996, 430.889984, 453.120026, 475.529968, 498.119995, + 520.890015, 373.920013, 396.329987, 418.919983, 441.690002, + 464.640015, 487.769958, 511.079987, 534.570007, 382.559998, + 405.690002, 429.000000, 452.489990, 476.160004, 500.010010, + 524.039978, 548.250000, 391.200012, 415.049988, 439.080017, + 463.290009, 487.679993, 512.250000, 537.000000, 561.930054, + 399.839996, 424.409973, 449.160034, 474.089966, 499.200012, + 524.489990, 549.959961, 575.609985, 408.479980, 433.770020, + 459.239990, 484.889954, 510.720032, 536.729980, 562.919983, + 589.290039, 417.119995, 443.130005, 469.319977, 495.690002, + 522.239990, 548.969971, 575.880005, 602.969971, 425.760010, + 452.489990, 479.399994, 506.489990, 533.760010, 561.209961, + 588.839966, 616.650024, 216.000000, 233.819992, 251.760010, + 269.820007, 288.000000, 306.299988, 324.719971, 343.260010, + 369.600006, 388.859985, 408.239990, 427.739990, 447.360016, + 467.100006, 486.959961, 506.940002, 451.679993, 480.570007, + 509.639984, 538.890015, 568.320007, 597.929993, 627.719971, + 657.690002, 460.320007, 489.929993, 519.719971, 549.690002, + 579.840027, 610.170044, 640.680054, 671.369995, 468.960022, + 499.289978, 529.799988, 560.489990, 591.359985, 622.409973, + 653.640015, 685.049988, 477.599976, 508.650024, 539.880005, + 571.289978, 602.880005, 634.650024, 666.599976, 698.729980, + 486.239990, 518.010010, 549.960022, 582.089966, 614.400024, + 646.890015, 679.559937, 712.410034, 494.879974, 527.369995, + 560.039978, 592.890015, 625.920044, 659.130005, 692.520020, + 726.089966, 503.519989, 536.729980, 570.119995, 603.689941, + 637.440063, 671.369995, 705.480042, 739.770020, 512.160034, + 546.089966, 580.199951, 614.489990, 648.960022, 683.609985, + 718.440002, 753.449951, 254.400009, 277.020020, 299.760010, + 322.619995, 345.600006, 368.700012, 391.919983, 415.260010, + 446.399994, 470.459961, 494.640015, 518.940002, 543.360046, + 567.900024, 592.559998, 617.340027, 538.080017, 574.170044, + 610.440002, 646.890015, 683.520020, 720.329956, 757.320007, + 794.489990, 546.719971, 583.530029, 620.520020, 657.690002, + 695.040039, 732.570007, 770.279968, 808.169983, 555.359985, + 592.889954, 630.599976, 668.489990, 706.559998, 744.809998, + 783.239990, 821.849976, 564.000000, 602.250000, 640.679993, + 679.289978, 718.080017, 757.050049, 796.199951, 835.530029, + 572.640015, 611.609985, 650.760010, 690.089966, 729.600037, + 769.289978, 809.160034, 849.210083, 581.279968, 620.970032, + 660.839966, 700.889954, 741.119995, 781.529968, 822.119995, + 862.890015, 589.919983, 630.330017, 670.919983, 711.690002, + 752.640015, 793.770020, 835.079956, 876.570007, 598.559998, + 639.690002, 681.000000, 722.490051, 764.160034, 806.010010, + 848.039978, 890.250061, 292.799988, 320.220001, 347.760010, + 375.419983, 403.200012, 431.100006, 459.119995, 487.260010, + 523.199951, 552.059998, 581.040039, 610.139954, 639.360046, + 668.699951, 698.159973, 727.739990, 624.479980, 667.770020, + 711.239990, 754.890015, 798.719971, 842.729980, 886.919983, + 931.290039, 633.119995, 677.130005, 721.319946, 765.690002, + 810.239990, 854.969971, 899.880005, 944.969971, 641.760010, + 686.489990, 731.400024, 776.489990, 821.760010, 867.209961, + 912.839966, 958.650024, 650.400024, 695.849976, 741.479980, + 787.290039, 833.279968, 879.449951, 925.799927, 972.330017, + 659.040039, 705.210022, 751.559998, 798.089966, 844.800049, + 891.690002, 938.760010, 986.010010, 667.679993, 714.569946, + 761.640015, 808.890015, 856.320007, 903.929993, 951.719971, + 999.690063, 676.320007, 723.929993, 771.719971, 819.690002, + 867.839966, 916.169922, 964.679932, 1013.369995, 684.959961, + 733.290039, 781.800049, 830.489990, 879.359985, 928.410034, + 977.640015, 1027.050049, 331.199982, 363.419983, 395.760010, + 428.220001, 460.799988, 493.500000, 526.320007, 559.260010, + 600.000000, 633.660034, 667.440002, 701.339966, 735.359985, + 769.500000, 803.759949, 838.140015, 710.880005, 761.369995, + 812.039978, 862.889893, 913.919983, 965.130005, 1016.520020, + 1068.090088, 719.520020, 770.729980, 822.119934, 873.689941, + 925.440063, 977.369995, 1029.479980, 1081.770020, 728.160034, + 780.090088, 832.199951, 884.489990, 936.960022, 989.610046, + 1042.439941, 1095.449951, 736.799927, 789.449951, 842.280029, + 895.290039, 948.480042, 1001.849976, 1055.399902, 1109.129883, + 745.439941, 798.810059, 852.359985, 906.089966, 960.000000, + 1014.089966, 1068.359985, 1122.810059, 754.080017, 808.170044, + 862.440002, 916.890015, 971.520020, 1026.330078, 1081.319946, + 1136.489990, 762.720032, 817.530029, 872.520020, 927.689941, + 983.040039, 1038.569946, 1094.280029, 1150.169922, 771.359985, + 826.890015, 882.599976, 938.489990, 994.559998, 1050.810059, + 1107.239990, 1163.849976, 369.599976, 406.619995, 443.760010, + 481.020020, 518.400024, 555.900024, 593.520020, 631.260010, + 113.279999, 136.839996, 160.480011, 184.199982, 208.000015, + 231.880005, 255.839996, 279.880005, 31.359985, 66.699989, + 102.160004, 137.740005, 173.440002, 209.260010, 245.199982, + 281.260010, 31.359993, 67.179993, 103.120003, 139.179993, + 175.360016, 211.660004, 248.079987, 284.619995, 31.359993, + 67.659996, 104.080009, 140.619995, 177.280014, 214.060013, + 250.959991, 287.980011, 31.359993, 68.139999, 105.039993, + 142.059982, 179.200027, 216.459991, 253.839996, 291.339996, + 31.360008, 68.619995, 106.000000, 143.499985, 181.119995, + 218.860001, 256.719971, 294.700012, 31.360001, 69.099991, + 106.959984, 144.939987, 183.040009, 221.260010, 259.600006, + 298.059998, 31.360008, 69.579971, 107.920006, 146.379990, + 184.960007, 223.660004, 262.479980, 301.419983, 31.360001, + 70.059975, 108.880020, 147.819977, 186.880020, 226.059998, + 265.359985, 304.779999, -83.840004, -58.040001, -32.159988, + -6.200012, 19.840012, 45.959984, 72.159996, 98.440010}, + sd::DataType::FLOAT32); + + input.linspace(-10, 0.1); + weights.linspace(-2, 0.1); + + sd::ops::depthwise_conv2d op; + auto results = op.evaluate( + {&input, &weights}, {}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, dataFormat, wFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests2, depthwise_conv2d_bp_test1) { - - int bS=2, iH=4,iW=3, iC=2,mC=2, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oH=4,oW=3; - int oC=iC*mC; - int paddingMode = 1; // 1-SAME, 0-VALID; - int dataFormat = 1; // 1-NHWC, 0-NCHW - - auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); - auto weights = NDArrayFactory::create('c', {kH, kW, iC, mC}); - auto bias = NDArrayFactory::create('c', {oC}, {1,2,3,4}); - auto gradO = NDArrayFactory::create('c', {bS, oH, oW, oC}); - - NDArray expGradI('c', {bS, iH, iW, iC},{0.07 , 0.19 , 0.348, 0.652, 0.588, 0.956, 0.387, 0.687, 1.326, 2.022, 1.878, 2.67 , 1.071, 1.515, 2.982, 3.966, 3.534, 4.614, 1.606, 1.982, 3.932, 4.748, 4.428, 5.308, - 1.126, 1.63 , 3.228, 4.3 , 3.468, 4.604, 3.123, 3.999, 7.95 , 9.798, 8.502, 10.446, 3.807, 4.827, 9.606, 11.742,10.158, 12.39 , 4.198, 4.958, 9.884, 11.468,10.38 , 12.028}, sd::DataType::FLOAT32); - - NDArray expGradW('c', {kH, kW, iC, mC},{19.08, 19.44,19.8 , 20.16,12.24, 12.48,12.72, 12.96,22.56, 23.04,23.52, 24. ,14.4 , 14.72,15.04, 15.36,14.76, 15.12,15.48, 15.84, 9.36, 9.6 , 9.84, 10.08}, sd::DataType::FLOAT32); - - input = 2.; - weights.linspace(0.1, 0.1); - gradO.linspace(0.01, 0.01); - - sd::ops::depthwise_conv2d_bp op; - auto results = op.evaluate({&input, &weights, &bias, &gradO}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); - auto* gradI = results.at(0); - auto* gradW = results.at(1); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expGradI.isSameShape(gradI)); - ASSERT_TRUE(expGradI.equalsTo(gradI)); - - ASSERT_TRUE(expGradW.isSameShape(gradW)); - ASSERT_TRUE(expGradW.equalsTo(gradW)); - + int bS = 2, iH = 4, iW = 3, iC = 2, mC = 2, kH = 3, kW = 2, sH = 1, sW = 1, + pH = 0, pW = 0, dH = 1, dW = 1; + int oH = 4, oW = 3; + int oC = iC * mC; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NHWC, 0-NCHW + + auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); + auto weights = NDArrayFactory::create('c', {kH, kW, iC, mC}); + auto bias = NDArrayFactory::create('c', {oC}, {1, 2, 3, 4}); + auto gradO = NDArrayFactory::create('c', {bS, oH, oW, oC}); + + NDArray expGradI( + 'c', {bS, iH, iW, iC}, + {0.07, 0.19, 0.348, 0.652, 0.588, 0.956, 0.387, 0.687, 1.326, 2.022, + 1.878, 2.67, 1.071, 1.515, 2.982, 3.966, 3.534, 4.614, 1.606, 1.982, + 3.932, 4.748, 4.428, 5.308, 1.126, 1.63, 3.228, 4.3, 3.468, 4.604, + 3.123, 3.999, 7.95, 9.798, 8.502, 10.446, 3.807, 4.827, 9.606, 11.742, + 10.158, 12.39, 4.198, 4.958, 9.884, 11.468, 10.38, 12.028}, + sd::DataType::FLOAT32); + + NDArray expGradW('c', {kH, kW, iC, mC}, + {19.08, 19.44, 19.8, 20.16, 12.24, 12.48, 12.72, 12.96, + 22.56, 23.04, 23.52, 24., 14.4, 14.72, 15.04, 15.36, + 14.76, 15.12, 15.48, 15.84, 9.36, 9.6, 9.84, 10.08}, + sd::DataType::FLOAT32); + + input = 2.; + weights.linspace(0.1, 0.1); + gradO.linspace(0.01, 0.01); + + sd::ops::depthwise_conv2d_bp op; + auto results = + op.evaluate({&input, &weights, &bias, &gradO}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, dataFormat}); + auto gradI = results.at(0); + auto gradW = results.at(1); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); + + ASSERT_TRUE(expGradW.isSameShape(gradW)); + ASSERT_TRUE(expGradW.equalsTo(gradW)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests2, depthwise_conv2d_bp_test2) { - - int bS=2, iH=4,iW=3, iC=2,mC=2, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oH=2,oW=2; - int oC=iC*mC; - int paddingMode = 0; // 1-SAME, 0-VALID; - int dataFormat = 1; // 1-NHWC, 0-NCHW - - auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); - auto weights = NDArrayFactory::create('c', {kH, kW, iC, mC}); - auto bias = NDArrayFactory::create('c', {oC}, {1,2,3,4}); - auto gradO = NDArrayFactory::create('c', {bS, oH, oW, oC}); - - NDArray expGradI('c', {bS, iH, iW, iC},{0.005, 0.025,0.034, 0.106,0.061, 0.113,0.058, 0.162,0.292, 0.564,0.298, 0.466,0.234, 0.402,0.772, 1.172,0.602, 0.834,0.333, 0.449,0.882, 1.146,0.581, 0.729, - 0.053, 0.137,0.258, 0.458,0.237, 0.353,0.41 , 0.642,1.252, 1.78 ,0.906, 1.202,1.098, 1.394,2.756, 3.412,1.722, 2.082,0.893, 1.073,2.13 , 2.522,1.269, 1.481}, sd::DataType::FLOAT32); - NDArray expGradW('c', {kH, kW, iC, mC},{2.4 , 2.56,2.72, 2.88,2.4 , 2.56,2.72, 2.88,2.4 , 2.56,2.72, 2.88,2.4 , 2.56,2.72, 2.88,2.4 , 2.56,2.72, 2.88,2.4 , 2.56,2.72, 2.88}, sd::DataType::FLOAT32); - - input = 2.; - weights.linspace(0.1, 0.1); - gradO.linspace(0.01, 0.01); - - sd::ops::depthwise_conv2d_bp op; - auto results = op.evaluate({&input, &weights, &bias, &gradO}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); - auto* gradI = results.at(0); - auto* gradW = results.at(1); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expGradI.isSameShape(gradI)); - ASSERT_TRUE(expGradI.equalsTo(gradI)); - - ASSERT_TRUE(expGradW.isSameShape(gradW)); - ASSERT_TRUE(expGradW.equalsTo(gradW)); - + int bS = 2, iH = 4, iW = 3, iC = 2, mC = 2, kH = 3, kW = 2, sH = 1, sW = 1, + pH = 0, pW = 0, dH = 1, dW = 1; + int oH = 2, oW = 2; + int oC = iC * mC; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NHWC, 0-NCHW + + auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); + auto weights = NDArrayFactory::create('c', {kH, kW, iC, mC}); + auto bias = NDArrayFactory::create('c', {oC}, {1, 2, 3, 4}); + auto gradO = NDArrayFactory::create('c', {bS, oH, oW, oC}); + + NDArray expGradI( + 'c', {bS, iH, iW, iC}, + {0.005, 0.025, 0.034, 0.106, 0.061, 0.113, 0.058, 0.162, 0.292, 0.564, + 0.298, 0.466, 0.234, 0.402, 0.772, 1.172, 0.602, 0.834, 0.333, 0.449, + 0.882, 1.146, 0.581, 0.729, 0.053, 0.137, 0.258, 0.458, 0.237, 0.353, + 0.41, 0.642, 1.252, 1.78, 0.906, 1.202, 1.098, 1.394, 2.756, 3.412, + 1.722, 2.082, 0.893, 1.073, 2.13, 2.522, 1.269, 1.481}, + sd::DataType::FLOAT32); + NDArray expGradW( + 'c', {kH, kW, iC, mC}, + {2.4, 2.56, 2.72, 2.88, 2.4, 2.56, 2.72, 2.88, 2.4, 2.56, 2.72, 2.88, + 2.4, 2.56, 2.72, 2.88, 2.4, 2.56, 2.72, 2.88, 2.4, 2.56, 2.72, 2.88}, + sd::DataType::FLOAT32); + + input = 2.; + weights.linspace(0.1, 0.1); + gradO.linspace(0.01, 0.01); + + sd::ops::depthwise_conv2d_bp op; + auto results = + op.evaluate({&input, &weights, &bias, &gradO}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, dataFormat}); + auto gradI = results.at(0); + auto gradW = results.at(1); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); + + ASSERT_TRUE(expGradW.isSameShape(gradW)); + ASSERT_TRUE(expGradW.equalsTo(gradW)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests2, depthwise_conv2d_bp_test3) { - - auto in = NDArrayFactory::create('c', {4, 8, 64, 64}); - auto w = NDArrayFactory::create('c', {2, 2, 8, 2}); - auto b = NDArrayFactory::create('c', {1, 16}); - auto grad = NDArrayFactory::create('c', {4, 16, 64, 64}); - - auto gradI = in.like(); - auto gradW = w.like(); - auto gradB = b.like(); - - nd4j:ops::depthwise_conv2d_bp op; - auto status = op.execute({&in, &w, &b, &grad}, {&gradI, &gradW, &gradB}, {2, 2, 1, 1, 0, 0, 1, 1, 1, 0}); - ASSERT_EQ(Status::OK(), status); + auto in = NDArrayFactory::create('c', {4, 8, 64, 64}); + auto w = NDArrayFactory::create('c', {2, 2, 8, 2}); + auto b = NDArrayFactory::create('c', {1, 16}); + auto grad = NDArrayFactory::create('c', {4, 16, 64, 64}); + + auto gradI = in.like(); + auto gradW = w.like(); + auto gradB = b.like(); + +nd4j: + ops::depthwise_conv2d_bp op; + auto status = op.execute({&in, &w, &b, &grad}, {&gradI, &gradW, &gradB}, + {2, 2, 1, 1, 0, 0, 1, 1, 1, 0}); + ASSERT_EQ(Status::OK(), status); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests2, depthwise_conv2d_bp_test4) { - - int bS=1, iH=10,iW=10, iC=8,mC=1, kH=3,kW=3, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oC=iC*mC; - int oH=10,oW=10; - int paddingMode = 1; // 1-SAME, 0-VALID; - int dataFormat = 1; // 1-NHWC, 0-NCHW - - NDArray input('c', {bS, iH, iW, iC}, sd::DataType::FLOAT32); - NDArray weights('c', {kH, kW, iC, mC}, sd::DataType::FLOAT32); - NDArray gradO('c', {bS, oH, oW, oC}, sd::DataType::FLOAT32); - NDArray bias('c', {oC}, sd::DataType::FLOAT32); - - input.linspace(-10, 0.1); - weights.linspace(-2, 0.1); - gradO.linspace(10, -0.1); - - - NDArray expGradI('c', {bS, iH, iW, iC},{10.880001, 13.239998, 15.520001, 17.719997, 19.840000, 21.880001, 23.839998, 25.720001, 31.360004, 34.420002, 37.360001, 40.180004, 42.880005, 45.460003, 47.919994, 50.260002, 31.360001, 33.939999, 36.400002, 38.739998, 40.959999, 43.059998, 45.040001, 46.900005, 31.359997, 33.459999, 35.439999, 37.300003, 39.040001, 40.660000, 42.160000, 43.539997, 31.360001, 32.980000, 34.480000, 35.860001, 37.119999, 38.259998, 39.279999, 40.180000, 31.360001, 32.499996, 33.520000, 34.419998, 35.200001, 35.860001, 36.400002, 36.820000, 31.360001, 32.019997, 32.560001, 32.979996, 33.280003, 33.459999, 33.520000, 33.459999, 31.360001, 31.540001, 31.599998, 31.539999, 31.360001, 31.059999, 30.639999, 30.100000, 31.360001, 31.060001, 30.639999, 30.099998, 29.440002, 28.660000, 27.759998, 26.740000, 18.559999, 18.040001, 17.440001, 16.760000, 16.000000, 15.160000, 14.240001, 13.240000, 85.439995, 85.860001, 86.159996, 86.339996, 86.400002, 86.340012, 86.159996, 85.860008, 132.000000, 131.910004, 131.639999, 131.190002, 130.559998, 129.750000, 128.760010, 127.589996, 123.360001, 122.550003, 121.559998, 120.389999, 119.040009, 117.510002, 115.799988, 113.910004, 114.720001, 113.189995, 111.480003, 109.590004, 107.520004, 105.270004, 102.839996, 100.230011, 106.079994, 103.830002, 101.400009, 98.790009, 96.000008, - 93.030006, 89.879990, 86.549988, 97.439995, 94.469994, 91.319992, 87.990005, 84.479996, 80.789993, 76.919998, 72.870003, 88.800003, 85.110001, 81.239998, 77.190002, 72.960007, 68.550003, 63.959999, 59.190002, 80.160004, 75.750000, 71.160004, 66.389999, 61.440002, 56.309994, 51.000000, 45.510002, 71.519997, 66.389999, 61.079998, 55.590000, 49.919998, 44.070000, 38.040001, 31.830002, 31.680000, 27.780003, 23.760000, 19.619999, 15.360001, 10.980000, 6.480000, 1.859999, 47.040001, 42.660004, 38.160000, 33.540001, 28.799999, 23.939999, 18.960001, 13.860001, 45.599998, 38.310001, 30.840000, 23.190002, 15.360001, 7.349998, -0.840002, -9.210003, 36.959999, 28.950003, 20.759998, 12.390001, 3.839998, -4.889999, -13.799999, -22.890003, 28.320002, 19.589998, 10.680000, 1.590002, -7.680002, -17.129999, -26.759998, -36.570007, 19.680002, 10.230003, 0.599998, -9.210001, -19.199999, -29.370003, -39.720001, -50.250008, 11.039999, 0.869999, -9.480000, -20.010002, -30.719994, -41.610001, -52.679996, -63.930008, 2.400005, -8.489998, -19.560005, -30.809998, -42.239998, -53.849991, -65.639992, -77.610001, -6.239998, -17.849998, -29.639988, -41.609985, -53.760002, -66.090004, -78.599991, -91.290009, -14.879990, -27.209995, -39.720009, -52.410007, -65.279999, -78.330002, -91.559998, -104.969986, -45.119995, -53.820000, -62.639999, -71.580002, -80.640007, -89.819992, -99.119995, -108.540009, 8.639999, -0.540001, -9.839996, -19.259998, -28.799995, -38.459999, -48.240002, -58.140003, -40.799999, -55.289997, -69.960007, -84.810013, -99.840004, -115.050011, -130.440018, -146.010010, -49.439991, -64.650009, -80.040009, -95.610016, -111.360008, -127.290001, -143.399994, -159.690018, -58.080009, -74.009987, -90.119995, -106.409988, -122.880005, -139.530014, -156.360001, -173.369995, -66.720001, -83.369995, -100.199997, - -117.209999, -134.399994, -151.769989, -169.319992, -187.049988, -75.360008, -92.729996, -110.279991, -128.009979, -145.920013, -164.009995, -182.279984, -200.729996, -84.000000, -102.089996, -120.360016, -138.809967, -157.440002, -176.249969, -195.240005, -214.410019, -92.639999, -111.449997, -130.440018, -149.610016, -168.960007, -188.489990, -208.200012, -228.090012, -101.279976, -120.809982, -140.519989, -160.410004, -180.480011, -200.730011, -221.160034, -241.770020, -121.920006, -135.420013, -149.040009, -162.779999, -176.640015, -190.619995, -204.719986, -218.940002, -29.760002, -43.739998, -57.840000, -72.059998, -86.400009, -100.860001, -115.439995, -130.140015, -127.199997, -148.890015, -170.760010, -192.809998, -215.040024, -237.450012, -260.039978, -282.809998, -135.839996, -158.250000, -180.840012, -203.610046, -226.559982, -249.690002, -272.999969, -296.489990, -144.479980, -167.609985, -190.920013, -214.410019, -238.080032, -261.929993, -285.959991, -310.169983, -153.119995, -176.969986, -201.000031, -225.210022, -249.599976, -274.170013, -298.920013, -323.849976, -161.760040, -186.330017, -211.079987, -236.009995, -261.120026, -286.410034, -311.879974, -337.530029, -170.400009, -195.689987, -221.159973, -246.809998, -272.639954, -298.650024, -324.840057, -351.209991, -179.039963, -205.050018, -231.240021, -257.609985, -284.160004, -310.890015, -337.799988, -364.890015, -187.680023, -214.410004, -241.319977, -268.410004, -295.679993, -323.130005, -350.760010, -378.570038, -198.720016, -217.019989, -235.440002, -253.979980, -272.640045, -291.419983, -310.319977, -329.339996, -68.159981, -86.939987, -105.840012, -124.860001, -144.000000, -163.260010, -182.639984, -202.140015, -213.600021, -242.489990, -271.559937, -300.809998, -330.239990, -359.849976, -389.639984, - -419.610016, -222.240036, -251.849960, -281.640015, -311.609985, -341.760040, -372.089996, -402.600037, -433.290009, -230.880005, -261.210022, -291.719971, -322.410034, -353.280029, -384.329956, -415.559998, -446.970001, -239.519989, -270.570007, -301.800018, -333.209991, -364.800018, -396.570007, -428.520020, -460.650024, -248.160034, -279.929962, -311.880005, -344.010010, -376.320038, -408.809998, -441.479980, -474.330017, -256.799988, -289.289978, -321.960022, -354.809967, -387.839996, -421.050018, -454.440002, -488.009979, -265.440002, -298.650024, -332.040009, -365.609985, -399.360016, -433.290009, -467.399963, -501.689941, -274.080017, -308.009949, -342.119995, -376.409973, -410.880005, -445.530029, -480.359985, -515.369995, -275.520020, -298.619995, -321.839966, -345.179993, -368.640015, -392.220001, -415.919952, -439.740021, -106.560005, -130.140030, -153.840027, -177.659973, -201.599991, -225.660019, -249.840012, -274.140015, -300.000000, -336.090057, -372.360046, -408.809937, -445.440002, -482.250031, -519.240051, -556.410034, -308.640015, -345.450012, -382.440002, -419.609955, -456.959961, -494.489960, -532.200012, -570.089966, -317.280029, -354.809998, -392.520020, -430.410004, -468.480042, -506.729980, -545.159912, -583.770020, -325.920013, -364.169952, -402.600037, -441.210022, -480.000000, -518.970032, -558.119873, -597.449951, -334.559967, -373.529999, -412.679993, -452.009949, -491.519989, -531.209961, -571.080017, -611.129944, -343.200012, -382.889984, -422.760071, -462.809906, -503.039978, -543.449951, -584.039978, -624.809998, -351.839966, -392.250000, -432.839966, -473.609955, -514.560120, -555.689941, -596.999939, -638.489990, -360.480011, -401.610016, -442.920044, -484.409912, -526.080017, -567.929993, -609.959961, -652.169983, -352.320007, -380.220001, - -408.239990, -436.380005, -464.639984, -493.019989, -521.519958, -550.139954, -144.960022, -173.339996, -201.839996, -230.459976, -259.200043, -288.059998, -317.039978, -346.140015, -386.399963, -429.690002, -473.159912, -516.809937, -560.640076, -604.650024, -648.839966, -693.210022, -395.039978, -439.050018, -483.239929, -527.609985, -572.159973, -616.890015, -661.799988, -706.890015, -403.680023, -448.409973, -493.320007, -538.410034, -583.680054, -629.129944, -674.760010, -720.570068, -412.320007, -457.769897, -503.399963, -549.210083, -595.199951, -641.369995, -687.720093, -734.250000, -420.960052, -467.130035, -513.479980, -560.010010, -606.720093, -653.610046, -700.680054, -747.930115, -429.599976, -476.489990, -523.559998, -570.809937, -618.239990, -665.849976, -713.640015, -761.609985, -438.239990, -485.850037, -533.640015, -581.610046, -629.760010, -678.089966, -726.600037, -775.289917, -446.880035,-495.210052, -543.719971, -592.410034, -641.279968, -690.330017, -739.559937, -788.970093, -429.120026, -461.819946, -494.639984, -527.580017, -560.640015, -593.820007, -627.119995, -660.540039, -183.360016, -216.540009, -249.839996, -283.260040, -316.800018, -350.459961, -384.239990, -418.139984, -472.800049, -523.289917, -573.959961, -624.809998, -675.839966, -727.050049, -778.440063, -830.010010, -481.440002, -532.649963, -584.040100, -635.609985, -687.359924, -739.290039, -791.399963, -843.689941, -490.079987, -542.010010, -594.119995, -646.410034, -698.880005, -751.529968, -804.359985, -857.369995, -498.720032, -551.369995, -604.200012, -657.210022, -710.400024, -763.770081, -817.319946, -871.050049, -507.359955, -560.729919, -614.280029, -668.010010, -721.919983, -776.010010, -830.280029, -884.730042, -515.999939, -570.089966, -624.360046, -678.809937, -733.440002, - -788.250000, -843.239990, -898.410034, -524.639954, -579.449951, -634.440002, -689.609985, -744.960022, -800.489990, -856.200012, -912.090027, -533.280029, -588.810059, -644.520081, -700.409973, -756.480042, -812.730103, -869.159912, -925.769958, -505.920013, -543.420044, -581.040039, -618.780029, -656.640015, -694.620056, -732.719971, -770.940002, -447.359985, -471.559998, -495.840027, -520.200012, -544.640015, -569.159973, -593.760010, -618.440002, -815.359985, -852.140015, -889.040039, -926.059937, -963.200073, -1000.460022, -1037.839966, -1075.339966, -826.879944, -864.139954, -901.519958, -939.019958, -976.640076, -1014.379944, -1052.239990, -1090.219971, -838.400024, -876.140015, -913.999939, -951.979919, -990.080017, -1028.299927, -1066.640015, -1105.099976, -849.919983, -888.140015, -926.479980, -964.939941, -1003.520081, -1042.219971, -1081.040039, -1119.979980, -861.440063, -900.140015, -938.960022,-977.899963, -1016.960022, -1056.140015, -1095.440063, -1134.859985, -872.960022, -912.140015, -951.439941, -990.859985, -1030.400024, -1070.060059, -1109.839844, -1149.739990, -884.479980, -924.140015, -963.919922, -1003.819946, -1043.839966, -1083.979980, -1124.239990, -1164.619995, -896.000000, -936.140015, -976.399963, -1016.780029, -1057.280029, -1097.899902, -1138.640015, -1179.500122, -705.919983, -733.000000, -760.159912, -787.400024, -814.719971, -842.119995, -869.599976, -897.160034}, sd::DataType::FLOAT32); - - NDArray expGradW('c', {kH, kW, iC, mC},{-104306.421875, -104786.734375, -105268.687500, -105752.250000, -106237.421875, -106724.242188, -107212.671875, - -107702.734375, -116289.593750, -116823.296875, -117358.781250, -117896.109375, -118435.210938, -118976.109375, -119518.796875, -120063.296875, -104824.789062, - -105305.117188, -105787.070312, -106270.640625, -106755.843750, -107242.640625, -107731.078125, -108221.117188, -126744.000000, -127277.710938, -127813.187500, - -128350.484375, -128889.601562, -129430.515625, -129973.210938, -130517.703125, -140944.000000, -141536.984375, -142131.984375, -142729.000000, -143328.000000, - -143929.015625, -144532.000000, -145137.000000, -126744.000000, -127277.710938, -127813.187500, -128350.484375, -128889.601562, -129430.515625, -129973.210938, -130517.703125, -104824.789062, -105305.117188, -105787.070312, -106270.640625, -106755.843750, -107242.640625, -107731.078125, -108221.117188, -116289.593750, -116823.296875, -117358.781250, -117896.109375, -118435.210938, -118976.109375, -119518.796875, -120063.296875, -104306.421875, -104786.734375, -105268.687500, -105752.250000, -106237.421875, -106724.242188, -107212.671875, -107702.734375}, sd::DataType::FLOAT32); - - NDArray expGradB('c', {oC}, {-2960., -2970., -2980., -2990., -3000., -3010., -3020., -3030.}, sd::DataType::FLOAT32); - - sd::ops::depthwise_conv2d_bp op; - auto results = op.evaluate({&input, &weights, &bias, &gradO}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); - NDArray* gradI = results.at(0); - NDArray* gradW = results.at(1); - NDArray* gradB = results.at(2); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expGradI.isSameShape(gradI)); - ASSERT_TRUE(expGradI.equalsTo(gradI)); - - ASSERT_TRUE(expGradW.isSameShape(gradW)); - ASSERT_TRUE(expGradW.equalsTo(gradW)); - - ASSERT_TRUE(expGradB.isSameShape(gradB)); - ASSERT_TRUE(expGradB.equalsTo(gradB)); - + int bS = 1, iH = 10, iW = 10, iC = 8, mC = 1, kH = 3, kW = 3, sH = 1, sW = 1, + pH = 0, pW = 0, dH = 1, dW = 1; + int oC = iC * mC; + int oH = 10, oW = 10; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NHWC, 0-NCHW + + NDArray input('c', {bS, iH, iW, iC}, sd::DataType::FLOAT32); + NDArray weights('c', {kH, kW, iC, mC}, sd::DataType::FLOAT32); + NDArray gradO('c', {bS, oH, oW, oC}, sd::DataType::FLOAT32); + NDArray bias('c', {oC}, sd::DataType::FLOAT32); + + input.linspace(-10, 0.1); + weights.linspace(-2, 0.1); + gradO.linspace(10, -0.1); + + NDArray expGradI( + 'c', {bS, iH, iW, iC}, + {10.880001, 13.239998, 15.520001, 17.719997, 19.840000, + 21.880001, 23.839998, 25.720001, 31.360004, 34.420002, + 37.360001, 40.180004, 42.880005, 45.460003, 47.919994, + 50.260002, 31.360001, 33.939999, 36.400002, 38.739998, + 40.959999, 43.059998, 45.040001, 46.900005, 31.359997, + 33.459999, 35.439999, 37.300003, 39.040001, 40.660000, + 42.160000, 43.539997, 31.360001, 32.980000, 34.480000, + 35.860001, 37.119999, 38.259998, 39.279999, 40.180000, + 31.360001, 32.499996, 33.520000, 34.419998, 35.200001, + 35.860001, 36.400002, 36.820000, 31.360001, 32.019997, + 32.560001, 32.979996, 33.280003, 33.459999, 33.520000, + 33.459999, 31.360001, 31.540001, 31.599998, 31.539999, + 31.360001, 31.059999, 30.639999, 30.100000, 31.360001, + 31.060001, 30.639999, 30.099998, 29.440002, 28.660000, + 27.759998, 26.740000, 18.559999, 18.040001, 17.440001, + 16.760000, 16.000000, 15.160000, 14.240001, 13.240000, + 85.439995, 85.860001, 86.159996, 86.339996, 86.400002, + 86.340012, 86.159996, 85.860008, 132.000000, 131.910004, + 131.639999, 131.190002, 130.559998, 129.750000, 128.760010, + 127.589996, 123.360001, 122.550003, 121.559998, 120.389999, + 119.040009, 117.510002, 115.799988, 113.910004, 114.720001, + 113.189995, 111.480003, 109.590004, 107.520004, 105.270004, + 102.839996, 100.230011, 106.079994, 103.830002, 101.400009, + 98.790009, 96.000008, 93.030006, 89.879990, 86.549988, + 97.439995, 94.469994, 91.319992, 87.990005, 84.479996, + 80.789993, 76.919998, 72.870003, 88.800003, 85.110001, + 81.239998, 77.190002, 72.960007, 68.550003, 63.959999, + 59.190002, 80.160004, 75.750000, 71.160004, 66.389999, + 61.440002, 56.309994, 51.000000, 45.510002, 71.519997, + 66.389999, 61.079998, 55.590000, 49.919998, 44.070000, + 38.040001, 31.830002, 31.680000, 27.780003, 23.760000, + 19.619999, 15.360001, 10.980000, 6.480000, 1.859999, + 47.040001, 42.660004, 38.160000, 33.540001, 28.799999, + 23.939999, 18.960001, 13.860001, 45.599998, 38.310001, + 30.840000, 23.190002, 15.360001, 7.349998, -0.840002, + -9.210003, 36.959999, 28.950003, 20.759998, 12.390001, + 3.839998, -4.889999, -13.799999, -22.890003, 28.320002, + 19.589998, 10.680000, 1.590002, -7.680002, -17.129999, + -26.759998, -36.570007, 19.680002, 10.230003, 0.599998, + -9.210001, -19.199999, -29.370003, -39.720001, -50.250008, + 11.039999, 0.869999, -9.480000, -20.010002, -30.719994, + -41.610001, -52.679996, -63.930008, 2.400005, -8.489998, + -19.560005, -30.809998, -42.239998, -53.849991, -65.639992, + -77.610001, -6.239998, -17.849998, -29.639988, -41.609985, + -53.760002, -66.090004, -78.599991, -91.290009, -14.879990, + -27.209995, -39.720009, -52.410007, -65.279999, -78.330002, + -91.559998, -104.969986, -45.119995, -53.820000, -62.639999, + -71.580002, -80.640007, -89.819992, -99.119995, -108.540009, + 8.639999, -0.540001, -9.839996, -19.259998, -28.799995, + -38.459999, -48.240002, -58.140003, -40.799999, -55.289997, + -69.960007, -84.810013, -99.840004, -115.050011, -130.440018, + -146.010010, -49.439991, -64.650009, -80.040009, -95.610016, + -111.360008, -127.290001, -143.399994, -159.690018, -58.080009, + -74.009987, -90.119995, -106.409988, -122.880005, -139.530014, + -156.360001, -173.369995, -66.720001, -83.369995, -100.199997, + -117.209999, -134.399994, -151.769989, -169.319992, -187.049988, + -75.360008, -92.729996, -110.279991, -128.009979, -145.920013, + -164.009995, -182.279984, -200.729996, -84.000000, -102.089996, + -120.360016, -138.809967, -157.440002, -176.249969, -195.240005, + -214.410019, -92.639999, -111.449997, -130.440018, -149.610016, + -168.960007, -188.489990, -208.200012, -228.090012, -101.279976, + -120.809982, -140.519989, -160.410004, -180.480011, -200.730011, + -221.160034, -241.770020, -121.920006, -135.420013, -149.040009, + -162.779999, -176.640015, -190.619995, -204.719986, -218.940002, + -29.760002, -43.739998, -57.840000, -72.059998, -86.400009, + -100.860001, -115.439995, -130.140015, -127.199997, -148.890015, + -170.760010, -192.809998, -215.040024, -237.450012, -260.039978, + -282.809998, -135.839996, -158.250000, -180.840012, -203.610046, + -226.559982, -249.690002, -272.999969, -296.489990, -144.479980, + -167.609985, -190.920013, -214.410019, -238.080032, -261.929993, + -285.959991, -310.169983, -153.119995, -176.969986, -201.000031, + -225.210022, -249.599976, -274.170013, -298.920013, -323.849976, + -161.760040, -186.330017, -211.079987, -236.009995, -261.120026, + -286.410034, -311.879974, -337.530029, -170.400009, -195.689987, + -221.159973, -246.809998, -272.639954, -298.650024, -324.840057, + -351.209991, -179.039963, -205.050018, -231.240021, -257.609985, + -284.160004, -310.890015, -337.799988, -364.890015, -187.680023, + -214.410004, -241.319977, -268.410004, -295.679993, -323.130005, + -350.760010, -378.570038, -198.720016, -217.019989, -235.440002, + -253.979980, -272.640045, -291.419983, -310.319977, -329.339996, + -68.159981, -86.939987, -105.840012, -124.860001, -144.000000, + -163.260010, -182.639984, -202.140015, -213.600021, -242.489990, + -271.559937, -300.809998, -330.239990, -359.849976, -389.639984, + -419.610016, -222.240036, -251.849960, -281.640015, -311.609985, + -341.760040, -372.089996, -402.600037, -433.290009, -230.880005, + -261.210022, -291.719971, -322.410034, -353.280029, -384.329956, + -415.559998, -446.970001, -239.519989, -270.570007, -301.800018, + -333.209991, -364.800018, -396.570007, -428.520020, -460.650024, + -248.160034, -279.929962, -311.880005, -344.010010, -376.320038, + -408.809998, -441.479980, -474.330017, -256.799988, -289.289978, + -321.960022, -354.809967, -387.839996, -421.050018, -454.440002, + -488.009979, -265.440002, -298.650024, -332.040009, -365.609985, + -399.360016, -433.290009, -467.399963, -501.689941, -274.080017, + -308.009949, -342.119995, -376.409973, -410.880005, -445.530029, + -480.359985, -515.369995, -275.520020, -298.619995, -321.839966, + -345.179993, -368.640015, -392.220001, -415.919952, -439.740021, + -106.560005, -130.140030, -153.840027, -177.659973, -201.599991, + -225.660019, -249.840012, -274.140015, -300.000000, -336.090057, + -372.360046, -408.809937, -445.440002, -482.250031, -519.240051, + -556.410034, -308.640015, -345.450012, -382.440002, -419.609955, + -456.959961, -494.489960, -532.200012, -570.089966, -317.280029, + -354.809998, -392.520020, -430.410004, -468.480042, -506.729980, + -545.159912, -583.770020, -325.920013, -364.169952, -402.600037, + -441.210022, -480.000000, -518.970032, -558.119873, -597.449951, + -334.559967, -373.529999, -412.679993, -452.009949, -491.519989, + -531.209961, -571.080017, -611.129944, -343.200012, -382.889984, + -422.760071, -462.809906, -503.039978, -543.449951, -584.039978, + -624.809998, -351.839966, -392.250000, -432.839966, -473.609955, + -514.560120, -555.689941, -596.999939, -638.489990, -360.480011, + -401.610016, -442.920044, -484.409912, -526.080017, -567.929993, + -609.959961, -652.169983, -352.320007, -380.220001, -408.239990, + -436.380005, -464.639984, -493.019989, -521.519958, -550.139954, + -144.960022, -173.339996, -201.839996, -230.459976, -259.200043, + -288.059998, -317.039978, -346.140015, -386.399963, -429.690002, + -473.159912, -516.809937, -560.640076, -604.650024, -648.839966, + -693.210022, -395.039978, -439.050018, -483.239929, -527.609985, + -572.159973, -616.890015, -661.799988, -706.890015, -403.680023, + -448.409973, -493.320007, -538.410034, -583.680054, -629.129944, + -674.760010, -720.570068, -412.320007, -457.769897, -503.399963, + -549.210083, -595.199951, -641.369995, -687.720093, -734.250000, + -420.960052, -467.130035, -513.479980, -560.010010, -606.720093, + -653.610046, -700.680054, -747.930115, -429.599976, -476.489990, + -523.559998, -570.809937, -618.239990, -665.849976, -713.640015, + -761.609985, -438.239990, -485.850037, -533.640015, -581.610046, + -629.760010, -678.089966, -726.600037, -775.289917, -446.880035, + -495.210052, -543.719971, -592.410034, -641.279968, -690.330017, + -739.559937, -788.970093, -429.120026, -461.819946, -494.639984, + -527.580017, -560.640015, -593.820007, -627.119995, -660.540039, + -183.360016, -216.540009, -249.839996, -283.260040, -316.800018, + -350.459961, -384.239990, -418.139984, -472.800049, -523.289917, + -573.959961, -624.809998, -675.839966, -727.050049, -778.440063, + -830.010010, -481.440002, -532.649963, -584.040100, -635.609985, + -687.359924, -739.290039, -791.399963, -843.689941, -490.079987, + -542.010010, -594.119995, -646.410034, -698.880005, -751.529968, + -804.359985, -857.369995, -498.720032, -551.369995, -604.200012, + -657.210022, -710.400024, -763.770081, -817.319946, -871.050049, + -507.359955, -560.729919, -614.280029, -668.010010, -721.919983, + -776.010010, -830.280029, -884.730042, -515.999939, -570.089966, + -624.360046, -678.809937, -733.440002, -788.250000, -843.239990, + -898.410034, -524.639954, -579.449951, -634.440002, -689.609985, + -744.960022, -800.489990, -856.200012, -912.090027, -533.280029, + -588.810059, -644.520081, -700.409973, -756.480042, -812.730103, + -869.159912, -925.769958, -505.920013, -543.420044, -581.040039, + -618.780029, -656.640015, -694.620056, -732.719971, -770.940002, + -447.359985, -471.559998, -495.840027, -520.200012, -544.640015, + -569.159973, -593.760010, -618.440002, -815.359985, -852.140015, + -889.040039, -926.059937, -963.200073, -1000.460022, -1037.839966, + -1075.339966, -826.879944, -864.139954, -901.519958, -939.019958, + -976.640076, -1014.379944, -1052.239990, -1090.219971, -838.400024, + -876.140015, -913.999939, -951.979919, -990.080017, -1028.299927, + -1066.640015, -1105.099976, -849.919983, -888.140015, -926.479980, + -964.939941, -1003.520081, -1042.219971, -1081.040039, -1119.979980, + -861.440063, -900.140015, -938.960022, -977.899963, -1016.960022, + -1056.140015, -1095.440063, -1134.859985, -872.960022, -912.140015, + -951.439941, -990.859985, -1030.400024, -1070.060059, -1109.839844, + -1149.739990, -884.479980, -924.140015, -963.919922, -1003.819946, + -1043.839966, -1083.979980, -1124.239990, -1164.619995, -896.000000, + -936.140015, -976.399963, -1016.780029, -1057.280029, -1097.899902, + -1138.640015, -1179.500122, -705.919983, -733.000000, -760.159912, + -787.400024, -814.719971, -842.119995, -869.599976, -897.160034}, + sd::DataType::FLOAT32); + + NDArray expGradW( + 'c', {kH, kW, iC, mC}, + {-104306.421875, -104786.734375, -105268.687500, -105752.250000, + -106237.421875, -106724.242188, -107212.671875, -107702.734375, + -116289.593750, -116823.296875, -117358.781250, -117896.109375, + -118435.210938, -118976.109375, -119518.796875, -120063.296875, + -104824.789062, -105305.117188, -105787.070312, -106270.640625, + -106755.843750, -107242.640625, -107731.078125, -108221.117188, + -126744.000000, -127277.710938, -127813.187500, -128350.484375, + -128889.601562, -129430.515625, -129973.210938, -130517.703125, + -140944.000000, -141536.984375, -142131.984375, -142729.000000, + -143328.000000, -143929.015625, -144532.000000, -145137.000000, + -126744.000000, -127277.710938, -127813.187500, -128350.484375, + -128889.601562, -129430.515625, -129973.210938, -130517.703125, + -104824.789062, -105305.117188, -105787.070312, -106270.640625, + -106755.843750, -107242.640625, -107731.078125, -108221.117188, + -116289.593750, -116823.296875, -117358.781250, -117896.109375, + -118435.210938, -118976.109375, -119518.796875, -120063.296875, + -104306.421875, -104786.734375, -105268.687500, -105752.250000, + -106237.421875, -106724.242188, -107212.671875, -107702.734375}, + sd::DataType::FLOAT32); + + NDArray expGradB( + 'c', {oC}, + {-2960., -2970., -2980., -2990., -3000., -3010., -3020., -3030.}, + sd::DataType::FLOAT32); + + sd::ops::depthwise_conv2d_bp op; + auto results = + op.evaluate({&input, &weights, &bias, &gradO}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, dataFormat}); + auto gradI = results.at(0); + auto gradW = results.at(1); + auto gradB = results.at(2); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); + + ASSERT_TRUE(expGradW.isSameShape(gradW)); + ASSERT_TRUE(expGradW.equalsTo(gradW)); + + ASSERT_TRUE(expGradB.isSameShape(gradB)); + ASSERT_TRUE(expGradB.equalsTo(gradB)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests2, depthwise_conv2d_bp_test5) { - - int bS=1, iH=10,iW=10, iC=8,mC=1, kH=3,kW=3, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oC=iC*mC; - int oH=10,oW=10; - int paddingMode = 1; // 1-SAME, 0-VALID; - int dataFormat = 0; // 1-NHWC, 0-NCHW - - NDArray input('c', {bS, iC, iH, iW}, sd::DataType::FLOAT32); - NDArray weights('c', {kH, kW, iC, mC}, sd::DataType::FLOAT32); - NDArray gradO('c', {bS, oC, oH, oW}, sd::DataType::FLOAT32); - NDArray bias('c', {oC}, sd::DataType::FLOAT32); - - input.linspace(-10, 0.1); - weights.linspace(-2, 0.1); - gradO.linspace(10, -0.1); - - - NDArray expGradI('c', {bS, iC, iH, iW}, {-12.639999, 3.920004, 3.920000, 3.920000, 3.920002, 3.920000, 3.920000, 3.919998, 3.919998, 16.319998, 52.680004, 111.000015, 109.919991, 108.840004, 107.760002, 106.680008, 105.600006, 104.519997, 103.440018, 87.960007, 47.880001, 100.200005, 99.119995, 98.040001, 96.959999, 95.879990, 94.799995, 93.720001, 92.639999, 78.360001, 43.079998, 89.399994, 88.320007, 87.240005, 86.159996, 85.079994, 84.000000, 82.919998, 81.840004, 68.759995, 38.279999, 78.600006, 77.519997, 76.440010, 75.360001, 74.279999, 73.200005, 72.120003, 71.040001, 59.160004, 33.480000, 67.799995, 66.720009, 65.639999, 64.559998, 63.480000, 62.399994, 61.320007, 60.240002, 49.559998, 28.680004, 57.000004, 55.919998, 54.839993, 53.759998, 52.680000, 51.600002, 50.519997, 49.440002, 39.959999, 23.880001, 46.200001, 45.120003, 44.039997, 42.959999, 41.880001, 40.799999, 39.719994, 38.639999, 30.360001, 19.079998, 35.400002, 34.320000, 33.239998, 32.159996, 31.080000, 29.999998, 28.919998, 27.840000, 20.759998, 14.079999, 24.080000, 22.639997, 21.200001, 19.759998, 18.320002, 16.880001, 15.440001, 14.000000, 9.759999, 3.140000, 3.560000, 3.500000, 3.440000, 3.380000, 3.320000, 3.260000, 3.200000, 3.140000, -0.220000, 4.050000, 2.010000, 0.840000, -0.330000, -1.499999, -2.670000, -3.840000, -5.010000, -6.179998, -9.150000, -1.350000, -9.690001, -10.859999, -12.029998, -13.200001, -14.370001, -15.539999, -16.710001, -17.879999, -19.349998, -6.750000, -21.389997, -22.560003, -23.730003, -24.900002, -26.069998, -27.239998, -28.410007, -29.580002, -29.550003, -12.150001, -33.089996, -34.260002, -35.430000, -36.600002, -37.770000, -38.939995, -40.110001, -41.280003, -39.749996, -17.550003, -44.790005, -45.959991, -47.129993, -48.300003, -49.470001, -50.640003, -51.809990, -52.979996, -49.950001, -22.949999, -56.490005, -57.660000, -58.829998, -60.000000, -61.170002, -62.340004, -63.510002, -64.680000, - -60.149994, -28.349998, -68.189987, -69.360001, -70.529999, -71.700005, -72.870010, -74.039993, -75.209999, -76.379990, -70.349998, -33.749996, -79.889999, -81.059990, -82.229988, -83.399994, -84.570007, -85.740005, -86.910004, -88.079994, -80.549995, -69.340004, -125.080002, -126.580002, -128.080002, -129.580002, -131.080002, -132.580002, -134.080002, -135.580002, -105.979996, 10.919998, -8.799997, -8.919998, -9.040003, -9.160004, -9.279999, -9.400002, -9.520002, -9.640003, -24.760000, -56.580009, -124.980003, -126.240005, -127.499992, -128.759995, -130.020020, -131.279999, -132.540009, -133.800003, -118.260002, -62.580009, -137.580002, -138.840012, -140.099991, -141.360001, -142.620010, -143.879974, -145.139999, -146.399994, -129.060013, -68.580002, -150.179993, -151.439987, -152.699997, -153.959991, -155.219986, -156.480011, -157.740005, -159.000000, -139.860001, -74.579994, -162.779999, -164.040024, -165.300003, -166.560028, -167.819977, -169.080002, -170.339996, -171.599991, -150.660004, -80.580002, -175.379990, -176.639999, -177.899994, -179.160019, -180.419998, -181.679993, -182.940002, -184.199997, -161.459991, -86.580002, -187.979996, -189.240005, -190.499985, -191.759995, -193.020020, -194.279999, -195.540024, -196.800018, -172.260010, -92.580002, -200.579987, -201.839981, -203.100006, -204.359970, -205.620010, -206.880005, -208.139999, -209.399994, -183.060013, -98.580002, -213.180023, -214.440002, -215.700012, -216.959991, -218.220001, -219.480011, -220.739975, -222.000000, -193.860001, -160.760010, -286.239990, -287.799988, -289.360016, -290.920013, -292.480011, -294.040009, -295.599976, -297.160004, -229.719986, 10.700003, -33.160004, -33.339996, -33.519993, -33.700001, - -33.879997, -34.059994, -34.239994, -34.419994, -57.299995, -129.209991, -269.969971, -271.319977, -272.670044, -274.019989, -275.369995, -276.720001, -278.070007, -279.420013, -239.369980, -135.809998, -283.470001, -284.820007, -286.169983, -287.520020, -288.869995, -290.220001, -291.570038, -292.919983, -250.770004, -142.410004, -296.969971, -298.320007, -299.669983, -301.020020, -302.369995, -303.719971, -305.070007, -306.419983, -262.169983, -149.009995, -310.470001, -311.820007, -313.170013, -314.519989, -315.869995, -317.220001, -318.570007, -319.919983, -273.570007, -155.610016, -323.969971, -325.320038, -326.669983, -328.020020, -329.369965, -330.719971, -332.070007, -333.419983, -284.970001, -162.209991, -337.469971, -338.820007, -340.169983, -341.519958, -342.869995, -344.220001, -345.570007, -346.920013, -296.369995, -168.809998, -350.970001, -352.320007, -353.669983, -355.019989, -356.369995, -357.719971, -359.070038, -360.419983, -307.769989, -175.410004, -364.469971, -365.820007, -367.169983, -368.520020, -369.869995, -371.219971, -372.570007, -373.919983, -319.169983, -260.179993, -459.399994, -461.019958, -462.639984, -464.260010, -465.880005, -467.500000, -469.119995, -470.739990, -361.459991, 2.480003, -69.520004, -69.760025, -70.000000, -70.239990, -70.479996, -70.720001, -70.960007, -71.200005, -97.839996, -213.840012, -432.960022, -434.400055, -435.840027, -437.279999, -438.720001, -440.160065, -441.599976, -443.040039, -372.480011, -221.040009, -447.360016, -448.800018, -450.239990, -451.679993, -453.119995, -454.559967, -456.000061, -457.440033, -384.480011, -228.239990, -461.759979, -463.200012, -464.639984, -466.079956, -467.520081, -468.960052, -470.399963, -471.839996, -396.479980, -235.440002, -476.159912, - -477.600006, -479.040039, -480.479980, -481.919952, -483.360046, -484.800079, -486.239990, -408.480042, -242.639999, -490.559967, -491.999969, -493.440063, -494.880035, -496.319946, -497.759979, -499.200012, -500.639984, -420.480011, -249.840012, -504.960052, -506.399963, -507.839996, -509.280029, -510.720001, -512.159973, -513.599976, -515.040039, -432.480011, -257.040009, -519.360046, -520.800049, -522.239990, -523.680054, -525.120056, -526.559998, -527.999939, -529.440002, -444.480011, -264.239990, -533.760010, -535.200012, -536.640015, -538.079956, -539.520020, -540.960022, -542.399963, -543.839966, -456.479980, -367.599976, -644.559998, -646.239929, -647.920044, -649.599976, -651.280029, -652.960022, -654.640076, -656.320007, -501.200043, -13.740002, -117.880005, -118.179993, -118.479996, -118.780014, -119.080002, -119.379990, -119.680008, -119.979996, -146.379990, -310.470001, -613.950012, -615.479980, -617.010071, -618.539978, -620.069946, -621.599976, -623.130005, -624.660034, -517.589966, -318.269958, -629.250000, -630.779968, -632.309937, -633.840027, -635.369995, -636.899902, -638.429993, -639.959961, -530.190063, -326.070038, -644.550049, -646.079956, -647.609985, -649.140015, -650.669922, -652.200012, -653.729980, -655.260010, -542.789978, -333.870026, -659.849976, -661.380005, -662.910034, -664.439941, -665.970093, -667.500000, -669.029968, -670.559937, -555.390015, -341.669983, -675.149902, -676.679993, -678.209961, -679.740051, -681.270020, -682.800049, -684.329956, -685.859985, -567.989990, -349.470001, -690.450012, -691.979980, -693.510010, -695.039978, -696.569946, -698.099976, -699.630005, -701.160034, -580.589966, -357.269958, -705.750000, -707.279968, -708.809937, -710.340027, -711.869995, -713.399902, -714.929993, -716.459961, -593.190002, -365.070038, -721.050049, -722.579956, -724.109985, -725.640015, -727.169922, -728.700012, - -730.229980, -731.760010, -605.789978, -483.019958, -841.719971, -843.460022, -845.200073, -846.939941, -848.680054, -850.419983, -852.159973, -853.899963, -648.940002, -37.960014, -178.240021, -178.599976, -178.959991, -179.320007, -179.679993, -180.039978, -180.399994, -180.759964, -202.919983, -419.099915, -812.939941, -814.559937, -816.179993, -817.800049, -819.419922, -821.040039, -822.660034, -824.279968, -674.699951, -427.500031, -829.140015, -830.759949, -832.380005, -833.999939, -835.619995, -837.240051, -838.859924, -840.479980, -687.899963, -435.899994, -845.339966, -846.959961, -848.579956, -850.200012, -851.819885, -853.439941, -855.059937, -856.679993, -701.100037, -444.299927, -861.540039, -863.160034, -864.779968, -866.399963, -868.020020, -869.640015, -871.259949, -872.880005, -714.299988, -452.700012, -877.740051, -879.359924, -880.979980, -882.599915, -884.219971, -885.839966, -887.459961, -889.079956, -727.500000, -461.099915, -893.939941, -895.559937, -897.179993, -898.800049, -900.419922, -902.040039, -903.660034, -905.279968, -740.700012, -469.499969, -910.140015, -911.759949, -913.380005, -914.999939, -916.620056, -918.239990, -919.860046, -921.479919, -753.899963, -477.899902, -926.339905, -927.959961, -929.579956, -931.200012, -932.819946, -934.439880, -936.059937, -937.679932, -767.100037, -606.439941, -1050.880005, -1052.680054, -1054.479980, -1056.280029, -1058.079956, -1059.880005, -1061.679932, -1063.479980, -804.679993, -70.180008, -250.600006, -251.019958, -251.440033, -251.860001, -252.280029, -252.700043, -253.120026, -253.540039, -267.459991, -539.730042, -1029.929932, -1031.640137, -1033.350098, -1035.060059, -1036.770020, -1038.479980, -1040.190063, -1041.900024, -843.809998, -548.729980, -1047.030029, -1048.740112, -1050.449829, -1052.160034, -1053.870117, -1055.580078, -1057.289917, -1059.000122, -857.609985, -557.729980, - -1064.130005, -1065.840088, -1067.550049, -1069.260010, -1070.969849, -1072.679932, -1074.390137, -1076.100098, -871.410034, -566.729980, -1081.229980, -1082.940063, -1084.650024, -1086.359985, -1088.069946, -1089.780029, -1091.489990, -1093.199951, -885.210022, -575.729980, -1098.329956, -1100.040039, -1101.750122, -1103.460205, -1105.170166, -1106.879883, -1108.589966, -1110.300049, -899.010071, -584.730042, -1115.429932, -1117.140137, -1118.850098, -1120.560059, -1122.270020, -1123.979980, -1125.689941, -1127.400024, -912.810059, -593.730042, -1132.530029, -1134.240234, -1135.949951, -1137.659912, -1139.370117, -1141.079956, -1142.790039, -1144.500122, -926.610046, -602.730042, -1149.629883, -1151.339966, -1153.050049, -1154.760132, -1156.469971, -1158.179810, -1159.890137, -1161.600098, -940.410034, -737.859985, -1272.040039, -1273.899902, -1275.760010, -1277.619995, -1279.479980, -1281.340088, -1283.200195, -1285.060059, -968.420044}, sd::DataType::FLOAT32); - - NDArray expGradW('c', {kH, kW, iC, mC}, {-2586.600586, -2505.600098, -18624.595703, -50943.605469, -99462.601562, -164181.609375, -245100.609375, -342219.625000, - -2880.149902, -2790.150146, -20700.152344, -56610.148438, -110520.156250, -182430.156250, -272340.156250, -380250.125000, -2594.701416, -2513.699951, - -18632.699219, -50951.695312, -99470.695312, -164189.703125, -245108.687500, -342227.750000, -3043.501465, -2953.500244, -20863.500000, -56773.492188, - -110683.515625, -182593.515625, -272503.531250, -380413.562500, -3383.499756, -3283.500000, -23183.501953, -63083.500000, -122983.500000, -202883.515625, - -302783.531250, -422683.468750, -3043.501465, -2953.500244, -20863.500000, -56773.492188, -110683.515625, -182593.515625, -272503.531250, -380413.562500, - -2594.701416, -2513.699951, -18632.699219, -50951.695312, -99470.695312, -164189.703125, -245108.687500, -342227.750000, -2880.149902, -2790.150146, -20700.152344, -56610.148438, -110520.156250, -182430.156250, -272340.156250, -380250.125000, -2586.600586, -2505.600098, -18624.595703, -50943.605469, -99462.601562, -164181.609375, -245100.609375, -342219.625000}, sd::DataType::FLOAT32); - - NDArray expGradB('c', {oC}, {505., -495., -1495., -2495., -3495., -4494.999512, -5495., -6495.}, sd::DataType::FLOAT32); - - sd::ops::depthwise_conv2d_bp op; - auto results = op.evaluate({&input, &weights, &bias, &gradO}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); - NDArray* gradI = results.at(0); - NDArray* gradW = results.at(1); - NDArray* gradB = results.at(2); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expGradI.isSameShape(gradI)); - ASSERT_TRUE(expGradI.equalsTo(gradI)); - - ASSERT_TRUE(expGradW.isSameShape(gradW)); - ASSERT_TRUE(expGradW.equalsTo(gradW)); - - ASSERT_TRUE(expGradB.isSameShape(gradB)); - ASSERT_TRUE(expGradB.equalsTo(gradB)); + int bS = 1, iH = 10, iW = 10, iC = 8, mC = 1, kH = 3, kW = 3, sH = 1, sW = 1, + pH = 0, pW = 0, dH = 1, dW = 1; + int oC = iC * mC; + int oH = 10, oW = 10; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NHWC, 0-NCHW + + NDArray input('c', {bS, iC, iH, iW}, sd::DataType::FLOAT32); + NDArray weights('c', {kH, kW, iC, mC}, sd::DataType::FLOAT32); + NDArray gradO('c', {bS, oC, oH, oW}, sd::DataType::FLOAT32); + NDArray bias('c', {oC}, sd::DataType::FLOAT32); + + input.linspace(-10, 0.1); + weights.linspace(-2, 0.1); + gradO.linspace(10, -0.1); + + NDArray expGradI( + 'c', {bS, iC, iH, iW}, + {-12.639999, 3.920004, 3.920000, 3.920000, 3.920002, + 3.920000, 3.920000, 3.919998, 3.919998, 16.319998, + 52.680004, 111.000015, 109.919991, 108.840004, 107.760002, + 106.680008, 105.600006, 104.519997, 103.440018, 87.960007, + 47.880001, 100.200005, 99.119995, 98.040001, 96.959999, + 95.879990, 94.799995, 93.720001, 92.639999, 78.360001, + 43.079998, 89.399994, 88.320007, 87.240005, 86.159996, + 85.079994, 84.000000, 82.919998, 81.840004, 68.759995, + 38.279999, 78.600006, 77.519997, 76.440010, 75.360001, + 74.279999, 73.200005, 72.120003, 71.040001, 59.160004, + 33.480000, 67.799995, 66.720009, 65.639999, 64.559998, + 63.480000, 62.399994, 61.320007, 60.240002, 49.559998, + 28.680004, 57.000004, 55.919998, 54.839993, 53.759998, + 52.680000, 51.600002, 50.519997, 49.440002, 39.959999, + 23.880001, 46.200001, 45.120003, 44.039997, 42.959999, + 41.880001, 40.799999, 39.719994, 38.639999, 30.360001, + 19.079998, 35.400002, 34.320000, 33.239998, 32.159996, + 31.080000, 29.999998, 28.919998, 27.840000, 20.759998, + 14.079999, 24.080000, 22.639997, 21.200001, 19.759998, + 18.320002, 16.880001, 15.440001, 14.000000, 9.759999, + 3.140000, 3.560000, 3.500000, 3.440000, 3.380000, + 3.320000, 3.260000, 3.200000, 3.140000, -0.220000, + 4.050000, 2.010000, 0.840000, -0.330000, -1.499999, + -2.670000, -3.840000, -5.010000, -6.179998, -9.150000, + -1.350000, -9.690001, -10.859999, -12.029998, -13.200001, + -14.370001, -15.539999, -16.710001, -17.879999, -19.349998, + -6.750000, -21.389997, -22.560003, -23.730003, -24.900002, + -26.069998, -27.239998, -28.410007, -29.580002, -29.550003, + -12.150001, -33.089996, -34.260002, -35.430000, -36.600002, + -37.770000, -38.939995, -40.110001, -41.280003, -39.749996, + -17.550003, -44.790005, -45.959991, -47.129993, -48.300003, + -49.470001, -50.640003, -51.809990, -52.979996, -49.950001, + -22.949999, -56.490005, -57.660000, -58.829998, -60.000000, + -61.170002, -62.340004, -63.510002, -64.680000, -60.149994, + -28.349998, -68.189987, -69.360001, -70.529999, -71.700005, + -72.870010, -74.039993, -75.209999, -76.379990, -70.349998, + -33.749996, -79.889999, -81.059990, -82.229988, -83.399994, + -84.570007, -85.740005, -86.910004, -88.079994, -80.549995, + -69.340004, -125.080002, -126.580002, -128.080002, -129.580002, + -131.080002, -132.580002, -134.080002, -135.580002, -105.979996, + 10.919998, -8.799997, -8.919998, -9.040003, -9.160004, + -9.279999, -9.400002, -9.520002, -9.640003, -24.760000, + -56.580009, -124.980003, -126.240005, -127.499992, -128.759995, + -130.020020, -131.279999, -132.540009, -133.800003, -118.260002, + -62.580009, -137.580002, -138.840012, -140.099991, -141.360001, + -142.620010, -143.879974, -145.139999, -146.399994, -129.060013, + -68.580002, -150.179993, -151.439987, -152.699997, -153.959991, + -155.219986, -156.480011, -157.740005, -159.000000, -139.860001, + -74.579994, -162.779999, -164.040024, -165.300003, -166.560028, + -167.819977, -169.080002, -170.339996, -171.599991, -150.660004, + -80.580002, -175.379990, -176.639999, -177.899994, -179.160019, + -180.419998, -181.679993, -182.940002, -184.199997, -161.459991, + -86.580002, -187.979996, -189.240005, -190.499985, -191.759995, + -193.020020, -194.279999, -195.540024, -196.800018, -172.260010, + -92.580002, -200.579987, -201.839981, -203.100006, -204.359970, + -205.620010, -206.880005, -208.139999, -209.399994, -183.060013, + -98.580002, -213.180023, -214.440002, -215.700012, -216.959991, + -218.220001, -219.480011, -220.739975, -222.000000, -193.860001, + -160.760010, -286.239990, -287.799988, -289.360016, -290.920013, + -292.480011, -294.040009, -295.599976, -297.160004, -229.719986, + 10.700003, -33.160004, -33.339996, -33.519993, -33.700001, + -33.879997, -34.059994, -34.239994, -34.419994, -57.299995, + -129.209991, -269.969971, -271.319977, -272.670044, -274.019989, + -275.369995, -276.720001, -278.070007, -279.420013, -239.369980, + -135.809998, -283.470001, -284.820007, -286.169983, -287.520020, + -288.869995, -290.220001, -291.570038, -292.919983, -250.770004, + -142.410004, -296.969971, -298.320007, -299.669983, -301.020020, + -302.369995, -303.719971, -305.070007, -306.419983, -262.169983, + -149.009995, -310.470001, -311.820007, -313.170013, -314.519989, + -315.869995, -317.220001, -318.570007, -319.919983, -273.570007, + -155.610016, -323.969971, -325.320038, -326.669983, -328.020020, + -329.369965, -330.719971, -332.070007, -333.419983, -284.970001, + -162.209991, -337.469971, -338.820007, -340.169983, -341.519958, + -342.869995, -344.220001, -345.570007, -346.920013, -296.369995, + -168.809998, -350.970001, -352.320007, -353.669983, -355.019989, + -356.369995, -357.719971, -359.070038, -360.419983, -307.769989, + -175.410004, -364.469971, -365.820007, -367.169983, -368.520020, + -369.869995, -371.219971, -372.570007, -373.919983, -319.169983, + -260.179993, -459.399994, -461.019958, -462.639984, -464.260010, + -465.880005, -467.500000, -469.119995, -470.739990, -361.459991, + 2.480003, -69.520004, -69.760025, -70.000000, -70.239990, + -70.479996, -70.720001, -70.960007, -71.200005, -97.839996, + -213.840012, -432.960022, -434.400055, -435.840027, -437.279999, + -438.720001, -440.160065, -441.599976, -443.040039, -372.480011, + -221.040009, -447.360016, -448.800018, -450.239990, -451.679993, + -453.119995, -454.559967, -456.000061, -457.440033, -384.480011, + -228.239990, -461.759979, -463.200012, -464.639984, -466.079956, + -467.520081, -468.960052, -470.399963, -471.839996, -396.479980, + -235.440002, -476.159912, -477.600006, -479.040039, -480.479980, + -481.919952, -483.360046, -484.800079, -486.239990, -408.480042, + -242.639999, -490.559967, -491.999969, -493.440063, -494.880035, + -496.319946, -497.759979, -499.200012, -500.639984, -420.480011, + -249.840012, -504.960052, -506.399963, -507.839996, -509.280029, + -510.720001, -512.159973, -513.599976, -515.040039, -432.480011, + -257.040009, -519.360046, -520.800049, -522.239990, -523.680054, + -525.120056, -526.559998, -527.999939, -529.440002, -444.480011, + -264.239990, -533.760010, -535.200012, -536.640015, -538.079956, + -539.520020, -540.960022, -542.399963, -543.839966, -456.479980, + -367.599976, -644.559998, -646.239929, -647.920044, -649.599976, + -651.280029, -652.960022, -654.640076, -656.320007, -501.200043, + -13.740002, -117.880005, -118.179993, -118.479996, -118.780014, + -119.080002, -119.379990, -119.680008, -119.979996, -146.379990, + -310.470001, -613.950012, -615.479980, -617.010071, -618.539978, + -620.069946, -621.599976, -623.130005, -624.660034, -517.589966, + -318.269958, -629.250000, -630.779968, -632.309937, -633.840027, + -635.369995, -636.899902, -638.429993, -639.959961, -530.190063, + -326.070038, -644.550049, -646.079956, -647.609985, -649.140015, + -650.669922, -652.200012, -653.729980, -655.260010, -542.789978, + -333.870026, -659.849976, -661.380005, -662.910034, -664.439941, + -665.970093, -667.500000, -669.029968, -670.559937, -555.390015, + -341.669983, -675.149902, -676.679993, -678.209961, -679.740051, + -681.270020, -682.800049, -684.329956, -685.859985, -567.989990, + -349.470001, -690.450012, -691.979980, -693.510010, -695.039978, + -696.569946, -698.099976, -699.630005, -701.160034, -580.589966, + -357.269958, -705.750000, -707.279968, -708.809937, -710.340027, + -711.869995, -713.399902, -714.929993, -716.459961, -593.190002, + -365.070038, -721.050049, -722.579956, -724.109985, -725.640015, + -727.169922, -728.700012, -730.229980, -731.760010, -605.789978, + -483.019958, -841.719971, -843.460022, -845.200073, -846.939941, + -848.680054, -850.419983, -852.159973, -853.899963, -648.940002, + -37.960014, -178.240021, -178.599976, -178.959991, -179.320007, + -179.679993, -180.039978, -180.399994, -180.759964, -202.919983, + -419.099915, -812.939941, -814.559937, -816.179993, -817.800049, + -819.419922, -821.040039, -822.660034, -824.279968, -674.699951, + -427.500031, -829.140015, -830.759949, -832.380005, -833.999939, + -835.619995, -837.240051, -838.859924, -840.479980, -687.899963, + -435.899994, -845.339966, -846.959961, -848.579956, -850.200012, + -851.819885, -853.439941, -855.059937, -856.679993, -701.100037, + -444.299927, -861.540039, -863.160034, -864.779968, -866.399963, + -868.020020, -869.640015, -871.259949, -872.880005, -714.299988, + -452.700012, -877.740051, -879.359924, -880.979980, -882.599915, + -884.219971, -885.839966, -887.459961, -889.079956, -727.500000, + -461.099915, -893.939941, -895.559937, -897.179993, -898.800049, + -900.419922, -902.040039, -903.660034, -905.279968, -740.700012, + -469.499969, -910.140015, -911.759949, -913.380005, -914.999939, + -916.620056, -918.239990, -919.860046, -921.479919, -753.899963, + -477.899902, -926.339905, -927.959961, -929.579956, -931.200012, + -932.819946, -934.439880, -936.059937, -937.679932, -767.100037, + -606.439941, -1050.880005, -1052.680054, -1054.479980, -1056.280029, + -1058.079956, -1059.880005, -1061.679932, -1063.479980, -804.679993, + -70.180008, -250.600006, -251.019958, -251.440033, -251.860001, + -252.280029, -252.700043, -253.120026, -253.540039, -267.459991, + -539.730042, -1029.929932, -1031.640137, -1033.350098, -1035.060059, + -1036.770020, -1038.479980, -1040.190063, -1041.900024, -843.809998, + -548.729980, -1047.030029, -1048.740112, -1050.449829, -1052.160034, + -1053.870117, -1055.580078, -1057.289917, -1059.000122, -857.609985, + -557.729980, -1064.130005, -1065.840088, -1067.550049, -1069.260010, + -1070.969849, -1072.679932, -1074.390137, -1076.100098, -871.410034, + -566.729980, -1081.229980, -1082.940063, -1084.650024, -1086.359985, + -1088.069946, -1089.780029, -1091.489990, -1093.199951, -885.210022, + -575.729980, -1098.329956, -1100.040039, -1101.750122, -1103.460205, + -1105.170166, -1106.879883, -1108.589966, -1110.300049, -899.010071, + -584.730042, -1115.429932, -1117.140137, -1118.850098, -1120.560059, + -1122.270020, -1123.979980, -1125.689941, -1127.400024, -912.810059, + -593.730042, -1132.530029, -1134.240234, -1135.949951, -1137.659912, + -1139.370117, -1141.079956, -1142.790039, -1144.500122, -926.610046, + -602.730042, -1149.629883, -1151.339966, -1153.050049, -1154.760132, + -1156.469971, -1158.179810, -1159.890137, -1161.600098, -940.410034, + -737.859985, -1272.040039, -1273.899902, -1275.760010, -1277.619995, + -1279.479980, -1281.340088, -1283.200195, -1285.060059, -968.420044}, + sd::DataType::FLOAT32); + + NDArray expGradW( + 'c', {kH, kW, iC, mC}, + {-2586.600586, -2505.600098, -18624.595703, -50943.605469, + -99462.601562, -164181.609375, -245100.609375, -342219.625000, + -2880.149902, -2790.150146, -20700.152344, -56610.148438, + -110520.156250, -182430.156250, -272340.156250, -380250.125000, + -2594.701416, -2513.699951, -18632.699219, -50951.695312, + -99470.695312, -164189.703125, -245108.687500, -342227.750000, + -3043.501465, -2953.500244, -20863.500000, -56773.492188, + -110683.515625, -182593.515625, -272503.531250, -380413.562500, + -3383.499756, -3283.500000, -23183.501953, -63083.500000, + -122983.500000, -202883.515625, -302783.531250, -422683.468750, + -3043.501465, -2953.500244, -20863.500000, -56773.492188, + -110683.515625, -182593.515625, -272503.531250, -380413.562500, + -2594.701416, -2513.699951, -18632.699219, -50951.695312, + -99470.695312, -164189.703125, -245108.687500, -342227.750000, + -2880.149902, -2790.150146, -20700.152344, -56610.148438, + -110520.156250, -182430.156250, -272340.156250, -380250.125000, + -2586.600586, -2505.600098, -18624.595703, -50943.605469, + -99462.601562, -164181.609375, -245100.609375, -342219.625000}, + sd::DataType::FLOAT32); + + NDArray expGradB( + 'c', {oC}, + {505., -495., -1495., -2495., -3495., -4494.999512, -5495., -6495.}, + sd::DataType::FLOAT32); + + sd::ops::depthwise_conv2d_bp op; + auto results = + op.evaluate({&input, &weights, &bias, &gradO}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, dataFormat}); + auto gradI = results.at(0); + auto gradW = results.at(1); + auto gradB = results.at(2); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); + + ASSERT_TRUE(expGradW.isSameShape(gradW)); + ASSERT_TRUE(expGradW.equalsTo(gradW)); + + ASSERT_TRUE(expGradB.isSameShape(gradB)); + ASSERT_TRUE(expGradB.equalsTo(gradB)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests2, depthwise_conv2d_bp_test6) { - - int bS=2, iH=4,iW=3, iC=2,mC=1, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oH=2,oW=2; - int oC=iC*mC; - int paddingMode = 0; // 1-SAME, 0-VALID; - int dataFormat = 0; // 1-NHWC, 0-NCHW - - auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); - auto weights = NDArrayFactory::create('c', {kH, kW, iC, mC}); - auto bias = NDArrayFactory::create('c', {oC}, {3,4}); - auto gradO = NDArrayFactory::create('c', {bS, oC, oH, oW}); - - auto expGradI = NDArrayFactory::create('c', {bS, iC, iH, iW},{0.001, 0.005, 0.006, 0.008, 0.03, 0.026, 0.024, 0.07, 0.05, 0.027, 0.069, 0.044, 0.01, - 0.032, 0.024, 0.044, 0.12, 0.08, 0.092, 0.224, 0.136, 0.07, 0.164, 0.096, 0.009, 0.037, 0.03, 0.056, 0.158, 0.106, 0.136, - 0.326, 0.194, 0.099, 0.229, 0.132, 0.026, 0.08, 0.056, 0.108, 0.28, 0.176, 0.22, 0.512, 0.296, 0.15, 0.34, 0.192}); - - auto expGradW = NDArrayFactory::create('c', {kH, kW, iC, mC}, {1.04, 1.68, 1.04, 1.68, 1.04, 1.68, 1.04, 1.68, 1.04, 1.68, 1.04, 1.68}); - - input = 2.; - weights.linspace(0.1, 0.1); - gradO.linspace(0.01, 0.01); - - sd::ops::depthwise_conv2d_bp op; - auto results = op.evaluate({&input, &weights, &bias, &gradO}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); - auto* gradI = results.at(0); - auto* gradW = results.at(1); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expGradI.isSameShape(gradI)); - ASSERT_TRUE(expGradI.equalsTo(gradI)); - - ASSERT_TRUE(expGradW.isSameShape(gradW)); - ASSERT_TRUE(expGradW.equalsTo(gradW)); + int bS = 2, iH = 4, iW = 3, iC = 2, mC = 1, kH = 3, kW = 2, sH = 1, sW = 1, + pH = 0, pW = 0, dH = 1, dW = 1; + int oH = 2, oW = 2; + int oC = iC * mC; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NHWC, 0-NCHW + + auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); + auto weights = NDArrayFactory::create('c', {kH, kW, iC, mC}); + auto bias = NDArrayFactory::create('c', {oC}, {3, 4}); + auto gradO = NDArrayFactory::create('c', {bS, oC, oH, oW}); + + auto expGradI = NDArrayFactory::create( + 'c', {bS, iC, iH, iW}, + {0.001, 0.005, 0.006, 0.008, 0.03, 0.026, 0.024, 0.07, 0.05, 0.027, + 0.069, 0.044, 0.01, 0.032, 0.024, 0.044, 0.12, 0.08, 0.092, 0.224, + 0.136, 0.07, 0.164, 0.096, 0.009, 0.037, 0.03, 0.056, 0.158, 0.106, + 0.136, 0.326, 0.194, 0.099, 0.229, 0.132, 0.026, 0.08, 0.056, 0.108, + 0.28, 0.176, 0.22, 0.512, 0.296, 0.15, 0.34, 0.192}); + + auto expGradW = NDArrayFactory::create( + 'c', {kH, kW, iC, mC}, + {1.04, 1.68, 1.04, 1.68, 1.04, 1.68, 1.04, 1.68, 1.04, 1.68, 1.04, 1.68}); + + input = 2.; + weights.linspace(0.1, 0.1); + gradO.linspace(0.01, 0.01); + + sd::ops::depthwise_conv2d_bp op; + auto results = + op.evaluate({&input, &weights, &bias, &gradO}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, dataFormat}); + auto gradI = results.at(0); + auto gradW = results.at(1); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); + + ASSERT_TRUE(expGradW.isSameShape(gradW)); + ASSERT_TRUE(expGradW.equalsTo(gradW)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests2, depthwise_conv2d_bp_test7) { - - int bS=2, iH=4,iW=3, iC=2,mC=1, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oH=2,oW=2; - int oC=iC*mC; - int paddingMode = 0; // 1-SAME, 0-VALID; - int dataFormat = 0; // 1-NHWC, 0-NCHW - int wFormat = 1; // 0-[kH, kW, iC, mC], 1-[mC, iC, kH, kW], 2-[mC, kH, kW, iC] - - NDArray input('c', {bS, iC, iH, iW}, sd::DataType::FLOAT32); - NDArray weights('c', {mC, iC, kH, kW}, {0.10, 0.30, 0.50, 0.70, 0.90, 1.10, 0.20, 0.40, 0.60, 0.80, 1., 1.2}, sd::DataType::FLOAT32); - NDArray bias('c', {oC}, {3,4}, sd::DataType::FLOAT32); - NDArray gradO('c', {bS, oC, oH, oW}, sd::DataType::FLOAT32); - - - NDArray expGradI('c', {bS, iC, iH, iW},{0.001, 0.005, 0.006, 0.008, 0.03, 0.026, 0.024, 0.07, 0.05, 0.027, 0.069, 0.044, 0.01, - 0.032, 0.024, 0.044, 0.12, 0.08, 0.092, 0.224, 0.136, 0.07, 0.164, 0.096, 0.009, 0.037, 0.03, 0.056, 0.158, 0.106, 0.136, - 0.326, 0.194, 0.099, 0.229, 0.132, 0.026, 0.08, 0.056, 0.108, 0.28, 0.176, 0.22, 0.512, 0.296, 0.15, 0.34, 0.192}, sd::DataType::FLOAT32); - - NDArray expGradW('c', {mC, iC, kH, kW}, {1.04, 1.04, 1.04, 1.04, 1.04, 1.04, 1.68, 1.68, 1.68, 1.68, 1.68, 1.68}, sd::DataType::FLOAT32); - - input = 2.; - gradO.linspace(0.01, 0.01); - - sd::ops::depthwise_conv2d_bp op; - auto results = op.evaluate({&input, &weights, &bias, &gradO}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat, wFormat}); - auto* gradI = results.at(0); - auto* gradW = results.at(1); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expGradI.isSameShape(gradI)); - ASSERT_TRUE(expGradI.equalsTo(gradI)); - - ASSERT_TRUE(expGradW.isSameShape(gradW)); - ASSERT_TRUE(expGradW.equalsTo(gradW)); + int bS = 2, iH = 4, iW = 3, iC = 2, mC = 1, kH = 3, kW = 2, sH = 1, sW = 1, + pH = 0, pW = 0, dH = 1, dW = 1; + int oH = 2, oW = 2; + int oC = iC * mC; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NHWC, 0-NCHW + int wFormat = + 1; // 0-[kH, kW, iC, mC], 1-[mC, iC, kH, kW], 2-[mC, kH, kW, iC] + + NDArray input('c', {bS, iC, iH, iW}, sd::DataType::FLOAT32); + NDArray weights( + 'c', {mC, iC, kH, kW}, + {0.10, 0.30, 0.50, 0.70, 0.90, 1.10, 0.20, 0.40, 0.60, 0.80, 1., 1.2}, + sd::DataType::FLOAT32); + NDArray bias('c', {oC}, {3, 4}, sd::DataType::FLOAT32); + NDArray gradO('c', {bS, oC, oH, oW}, sd::DataType::FLOAT32); + + NDArray expGradI( + 'c', {bS, iC, iH, iW}, + {0.001, 0.005, 0.006, 0.008, 0.03, 0.026, 0.024, 0.07, 0.05, 0.027, + 0.069, 0.044, 0.01, 0.032, 0.024, 0.044, 0.12, 0.08, 0.092, 0.224, + 0.136, 0.07, 0.164, 0.096, 0.009, 0.037, 0.03, 0.056, 0.158, 0.106, + 0.136, 0.326, 0.194, 0.099, 0.229, 0.132, 0.026, 0.08, 0.056, 0.108, + 0.28, 0.176, 0.22, 0.512, 0.296, 0.15, 0.34, 0.192}, + sd::DataType::FLOAT32); + + NDArray expGradW( + 'c', {mC, iC, kH, kW}, + {1.04, 1.04, 1.04, 1.04, 1.04, 1.04, 1.68, 1.68, 1.68, 1.68, 1.68, 1.68}, + sd::DataType::FLOAT32); + + input = 2.; + gradO.linspace(0.01, 0.01); + + sd::ops::depthwise_conv2d_bp op; + auto results = op.evaluate( + {&input, &weights, &bias, &gradO}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, dataFormat, wFormat}); + auto gradI = results.at(0); + auto gradW = results.at(1); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); + + ASSERT_TRUE(expGradW.isSameShape(gradW)); + ASSERT_TRUE(expGradW.equalsTo(gradW)); } -#endif //LIBND4J_CONVOLUTIONTESTS2_H \ No newline at end of file +#endif // LIBND4J_CONVOLUTIONTESTS2_H \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/CuDnnTests.cu b/libnd4j/tests_cpu/layers_tests/CuDnnTests.cu index 47892c973f2b..25729de21cdb 100644 --- a/libnd4j/tests_cpu/layers_tests/CuDnnTests.cu +++ b/libnd4j/tests_cpu/layers_tests/CuDnnTests.cu @@ -14,17 +14,18 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - // // @author raver119@gmail.com // -#include "testlayers.h" -#include #include -#include -#include #include +#include +#include + +#include + +#include "testlayers.h" #ifdef HAVE_CUDNN @@ -35,114 +36,115 @@ using namespace sd; class CuDnnTests : public testing::Test { -public: - + public: }; -static void printer(std::initializer_list helpers) { - - for (auto v:helpers) { - nd4j_printf("Initialized [%s]\n", v->name().c_str()); - } +static void printer( + std::initializer_list helpers) { + for (auto v : helpers) { + nd4j_printf("Initialized [%s]\n", v->name().c_str()); + } } - TEST_F(CuDnnTests, helpers_includer) { - // we need this block, to make sure all helpers are still available within binary, and not optimized out by linker + // we need this block, to make sure all helpers are still available within + // binary, and not optimized out by linker #ifdef HAVE_CUDNN - sd::ops::platforms::PLATFORM_conv2d_ENGINE_CUDA conv2d; - sd::ops::platforms::PLATFORM_conv2d_bp_ENGINE_CUDA conv2d_bp; - sd::ops::platforms::PLATFORM_conv3dnew_ENGINE_CUDA conv3dnew; - sd::ops::platforms::PLATFORM_conv3dnew_bp_ENGINE_CUDA conv3dnew_bp; - sd::ops::platforms::PLATFORM_depthwise_conv2d_ENGINE_CUDA depthwise_conv2d; - sd::ops::platforms::PLATFORM_depthwise_conv2d_bp_ENGINE_CUDA depthwise_conv2d_bp; - sd::ops::platforms::PLATFORM_batchnorm_ENGINE_CUDA batchnorm; - sd::ops::platforms::PLATFORM_batchnorm_bp_ENGINE_CUDA batchnorm_bp; - sd::ops::platforms::PLATFORM_avgpool2d_ENGINE_CUDA avgpool2d; - sd::ops::platforms::PLATFORM_avgpool2d_bp_ENGINE_CUDA avgpool2d_bp; - sd::ops::platforms::PLATFORM_maxpool2d_ENGINE_CUDA maxpool2d; - sd::ops::platforms::PLATFORM_maxpool2d_bp_ENGINE_CUDA maxpool2d_bp; - sd::ops::platforms::PLATFORM_avgpool3dnew_ENGINE_CUDA avgpool3dnew; - sd::ops::platforms::PLATFORM_avgpool3dnew_bp_ENGINE_CUDA avgpool3dnew_bp; - sd::ops::platforms::PLATFORM_maxpool3dnew_ENGINE_CUDA maxpool3dnew; - sd::ops::platforms::PLATFORM_maxpool3dnew_bp_ENGINE_CUDA maxpool3dnew_bp; - - - - printer({&conv2d}); - printer({&conv2d_bp}); - printer({&conv3dnew}); - printer({&conv3dnew_bp}); - printer({&depthwise_conv2d}); - printer({&depthwise_conv2d_bp}); - printer({&batchnorm}); - printer({&batchnorm_bp}); - printer({&avgpool2d}); - printer({&avgpool2d_bp}); - printer({&maxpool2d}); - printer({&maxpool2d_bp}); - printer({&avgpool3dnew}); - printer({&avgpool3dnew_bp}); - printer({&maxpool3dnew}); - printer({&maxpool3dnew_bp}); + sd::ops::platforms::PLATFORM_conv2d_ENGINE_CUDA conv2d; + sd::ops::platforms::PLATFORM_conv2d_bp_ENGINE_CUDA conv2d_bp; + sd::ops::platforms::PLATFORM_conv3dnew_ENGINE_CUDA conv3dnew; + sd::ops::platforms::PLATFORM_conv3dnew_bp_ENGINE_CUDA conv3dnew_bp; + sd::ops::platforms::PLATFORM_depthwise_conv2d_ENGINE_CUDA depthwise_conv2d; + sd::ops::platforms::PLATFORM_depthwise_conv2d_bp_ENGINE_CUDA + depthwise_conv2d_bp; + sd::ops::platforms::PLATFORM_batchnorm_ENGINE_CUDA batchnorm; + sd::ops::platforms::PLATFORM_batchnorm_bp_ENGINE_CUDA batchnorm_bp; + sd::ops::platforms::PLATFORM_avgpool2d_ENGINE_CUDA avgpool2d; + sd::ops::platforms::PLATFORM_avgpool2d_bp_ENGINE_CUDA avgpool2d_bp; + sd::ops::platforms::PLATFORM_maxpool2d_ENGINE_CUDA maxpool2d; + sd::ops::platforms::PLATFORM_maxpool2d_bp_ENGINE_CUDA maxpool2d_bp; + sd::ops::platforms::PLATFORM_avgpool3dnew_ENGINE_CUDA avgpool3dnew; + sd::ops::platforms::PLATFORM_avgpool3dnew_bp_ENGINE_CUDA avgpool3dnew_bp; + sd::ops::platforms::PLATFORM_maxpool3dnew_ENGINE_CUDA maxpool3dnew; + sd::ops::platforms::PLATFORM_maxpool3dnew_bp_ENGINE_CUDA maxpool3dnew_bp; + + printer({&conv2d}); + printer({&conv2d_bp}); + printer({&conv3dnew}); + printer({&conv3dnew_bp}); + printer({&depthwise_conv2d}); + printer({&depthwise_conv2d_bp}); + printer({&batchnorm}); + printer({&batchnorm_bp}); + printer({&avgpool2d}); + printer({&avgpool2d_bp}); + printer({&maxpool2d}); + printer({&maxpool2d_bp}); + printer({&avgpool3dnew}); + printer({&avgpool3dnew_bp}); + printer({&maxpool3dnew}); + printer({&maxpool3dnew_bp}); #endif } - TEST_F(CuDnnTests, mixed_helpers_test_1) { -#if defined(HAVE_CUDNN) && defined (HAVE_MKLDNN) - nd4j_printf("Mixed platforms test\n", ""); - - - int bS=2, iH=4,iW=3, iC=4,oC=3, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oH=2,oW=2; - int paddingMode = 0; // 1-SAME, 0-VALID; - int dataFormat = 0; // 1-NHWC, 0-NCHW - - auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); - auto weights = NDArrayFactory::create('c', {oC, iC, kH, kW}); - auto bias = NDArrayFactory::create('c', {oC}, {1,2,3}); - - auto expOutput = NDArrayFactory::create('c', {bS, oC, oH, oW}, {61.f, 61.f, 61.f, 61.f, 177.2f, 177.2f, 177.2f, 177.2f, 293.4f, 293.4f, 293.4f, 293.4f, 61.f, 61.f, 61.f, 61.f, 177.2f, 177.2f, 177.2f, 177.2f, 293.4f, 293.4f, 293.4f, 293.4f}); - auto zCUDA = expOutput.like(); - auto zMKL = expOutput.like(); - - input = 2.; - weights.linspace(0.1, 0.1); - weights.permutei({2,3,1,0}); - - input.syncToHost(); - weights.syncToHost(); - bias.syncToHost(); - - sd::ops::conv2d op; - - // cuDNN part - Context cuda(1); - cuda.setTargetEngine(samediff::Engine::ENGINE_CUDA); - cuda.setInputArray(0, &input); - cuda.setInputArray(1, &weights); - cuda.setInputArray(2, &bias); - cuda.setOutputArray(0, &zCUDA); - cuda.setIArguments({kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); - auto statusCUDA = op.execute(&cuda); - - ASSERT_EQ(Status::OK(), statusCUDA); - ASSERT_EQ(expOutput, zCUDA); - - // MKL-DNN part - Context mkl(1); - mkl.setTargetEngine(samediff::Engine::ENGINE_CPU); - mkl.setInputArray(0, &input); - mkl.setInputArray(1, &weights); - mkl.setInputArray(2, &bias); - mkl.setOutputArray(0, &zMKL); - mkl.setIArguments({kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); - auto statusMKL = op.execute(&mkl); - - zMKL.tickWriteHost(); - - ASSERT_EQ(Status::OK(), statusMKL); - ASSERT_EQ(expOutput, zMKL); +#if defined(HAVE_CUDNN) && defined(HAVE_MKLDNN) + nd4j_printf("Mixed platforms test\n", ""); + + int bS = 2, iH = 4, iW = 3, iC = 4, oC = 3, kH = 3, kW = 2, sH = 1, sW = 1, + pH = 0, pW = 0, dH = 1, dW = 1; + int oH = 2, oW = 2; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NHWC, 0-NCHW + + auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); + auto weights = NDArrayFactory::create('c', {oC, iC, kH, kW}); + auto bias = NDArrayFactory::create('c', {oC}, {1, 2, 3}); + + auto expOutput = NDArrayFactory::create( + 'c', {bS, oC, oH, oW}, + {61.f, 61.f, 61.f, 61.f, 177.2f, 177.2f, 177.2f, 177.2f, + 293.4f, 293.4f, 293.4f, 293.4f, 61.f, 61.f, 61.f, 61.f, + 177.2f, 177.2f, 177.2f, 177.2f, 293.4f, 293.4f, 293.4f, 293.4f}); + auto zCUDA = expOutput.like(); + auto zMKL = expOutput.like(); + + input = 2.; + weights.linspace(0.1, 0.1); + weights.permutei({2, 3, 1, 0}); + + input.syncToHost(); + weights.syncToHost(); + bias.syncToHost(); + + sd::ops::conv2d op; + + // cuDNN part + Context cuda(1); + cuda.setTargetEngine(samediff::Engine::ENGINE_CUDA); + cuda.setInputArray(0, &input); + cuda.setInputArray(1, &weights); + cuda.setInputArray(2, &bias); + cuda.setOutputArray(0, &zCUDA); + cuda.setIArguments({kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, dataFormat}); + auto statusCUDA = op.execute(&cuda); + + ASSERT_EQ(Status::OK(), statusCUDA); + ASSERT_EQ(expOutput, zCUDA); + + // MKL-DNN part + Context mkl(1); + mkl.setTargetEngine(samediff::Engine::ENGINE_CPU); + mkl.setInputArray(0, &input); + mkl.setInputArray(1, &weights); + mkl.setInputArray(2, &bias); + mkl.setOutputArray(0, &zMKL); + mkl.setIArguments({kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, dataFormat}); + auto statusMKL = op.execute(&mkl); + + zMKL.tickWriteHost(); + + ASSERT_EQ(Status::OK(), statusMKL); + ASSERT_EQ(expOutput, zMKL); #endif } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/CudaBasicsTests1.cu b/libnd4j/tests_cpu/layers_tests/CudaBasicsTests1.cu index 3d6886565cee..dfd57b7350f1 100644 --- a/libnd4j/tests_cpu/layers_tests/CudaBasicsTests1.cu +++ b/libnd4j/tests_cpu/layers_tests/CudaBasicsTests1.cu @@ -14,3091 +14,3591 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - // - // @author raver119@gmail.com - // +// +// @author raver119@gmail.com +// -#include "testlayers.h" +#include #include #include +#include +#include #include #include #include #include -#include -#include +#include +#include #include #include -#include #include -#include -#include -#include -#include +#include +#include + +#include "testlayers.h" using namespace sd; using namespace sd::graph; class CudaBasicsTests1 : public testing::Test { -public: - + public: }; - ////////////////////////////////////////////////////////////////////////// -static cudaError_t allocateDeviceMem(LaunchContext& lc, std::vector& devicePtrs, const std::vector>& hostData) { - - if(devicePtrs.size() != hostData.size()) - throw std::invalid_argument("prepareDataForCuda: two input sts::vectors should same sizes !"); - - cudaError_t cudaResult; - - void* reductionPointer; - cudaResult = cudaMalloc(reinterpret_cast(&reductionPointer), 1024*1024); if(cudaResult != 0) return cudaResult; - int* allocationPointer; - cudaResult = cudaMalloc(reinterpret_cast(&allocationPointer), 1024*1024); if(cudaResult != 0) return cudaResult; - - lc.setReductionPointer(reductionPointer); - lc.setAllocationPointer(allocationPointer); - cudaStream_t stream = *lc.getCudaStream(); - - for(int i = 0; i < devicePtrs.size(); ++i) { - - cudaResult = cudaMalloc(reinterpret_cast(&devicePtrs[i]), hostData[i].second); if(cudaResult != 0) return cudaResult; - cudaMemcpyAsync(devicePtrs[i], hostData[i].first, hostData[i].second, cudaMemcpyHostToDevice, stream); - } - return cudaResult; +static cudaError_t allocateDeviceMem( + LaunchContext &lc, std::vector &devicePtrs, + const std::vector> &hostData) { + if (devicePtrs.size() != hostData.size()) + throw std::invalid_argument( + "prepareDataForCuda: two input sts::vectors should same sizes !"); + + cudaError_t cudaResult; + + void *reductionPointer; + cudaResult = + cudaMalloc(reinterpret_cast(&reductionPointer), 1024 * 1024); + if (cudaResult != 0) return cudaResult; + int *allocationPointer; + cudaResult = + cudaMalloc(reinterpret_cast(&allocationPointer), 1024 * 1024); + if (cudaResult != 0) return cudaResult; + + lc.setReductionPointer(reductionPointer); + lc.setAllocationPointer(allocationPointer); + cudaStream_t stream = *lc.getCudaStream(); + + for (int i = 0; i < devicePtrs.size(); ++i) { + cudaResult = cudaMalloc(reinterpret_cast(&devicePtrs[i]), + hostData[i].second); + if (cudaResult != 0) return cudaResult; + cudaMemcpyAsync(devicePtrs[i], hostData[i].first, hostData[i].second, + cudaMemcpyHostToDevice, stream); + } + return cudaResult; } ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, TestPairwise_1) { - // allocating host-side arrays - auto x = NDArrayFactory::create('c', { 5 }, { 1, 2, 3, 4, 5}); - auto z = NDArrayFactory::create('c', { 5 }, {0,0,0,0,0}); - - auto exp = NDArrayFactory::create('c', { 5 }, { 2, 4, 6, 8, 10 }); - - // making raw buffers - Nd4jPointer devBufferPtrX, devBufferPtrZ, devShapePtrX; - cudaError_t res = cudaMalloc(reinterpret_cast(&devBufferPtrX), x.lengthOf() * x.sizeOfT()); - ASSERT_EQ(0, res); - res = cudaMalloc(reinterpret_cast(&devBufferPtrZ), x.lengthOf() * x.sizeOfT()); - ASSERT_EQ(0, res); - res = cudaMalloc(reinterpret_cast(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); - ASSERT_EQ(0, res); - - Nd4jPointer nativeStream = (Nd4jPointer)malloc(sizeof(cudaStream_t)); - CHECK_ALLOC(nativeStream, "Failed to allocate memory for new CUDA stream", sizeof(cudaStream_t)); - cudaError_t dZ = cudaStreamCreate(reinterpret_cast(&nativeStream)); - auto stream = reinterpret_cast(&nativeStream); - - x.dataBuffer()->allocatePrimary(); - x.syncToHost(); - - cudaMemcpyAsync(devBufferPtrX, x.buffer(), x.lengthOf() * x.sizeOfT(), cudaMemcpyHostToDevice, *stream); - cudaMemcpyAsync(devShapePtrX, x.shapeInfo(), shape::shapeInfoByteLength(x.shapeInfo()), cudaMemcpyHostToDevice, *stream); - res = cudaStreamSynchronize(*stream); - ASSERT_EQ(0, res); - - LaunchContext lc(stream, nullptr, nullptr); - NativeOpExecutioner::execPairwiseTransform(&lc, pairwise::Add, nullptr, x.shapeInfo(), devBufferPtrX, reinterpret_cast(devShapePtrX), nullptr, x.shapeInfo(), devBufferPtrX, reinterpret_cast(devShapePtrX), nullptr, z.shapeInfo(), devBufferPtrZ, reinterpret_cast(devShapePtrX), nullptr); - res = cudaStreamSynchronize(*stream); - ASSERT_EQ(0, res); - - z.dataBuffer()->allocatePrimary(); - - cudaMemcpyAsync(z.buffer(), devBufferPtrZ, z.lengthOf() * x.sizeOfT(), cudaMemcpyDeviceToHost, *stream); - res = cudaStreamSynchronize(*stream); - ASSERT_EQ(0, res); - - cudaFree(devBufferPtrX); - cudaFree(devBufferPtrZ); - cudaFree(devShapePtrX); - - // needed due to memcpy - z.tickWriteHost(); - - for (int e = 0; e < z.lengthOf(); e++) { - //nd4j_printf("step %i\n", e); - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - } + // allocating host-side arrays + auto x = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); + auto z = NDArrayFactory::create('c', {5}, {0, 0, 0, 0, 0}); + + auto exp = NDArrayFactory::create('c', {5}, {2, 4, 6, 8, 10}); + + // making raw buffers + Nd4jPointer devBufferPtrX, devBufferPtrZ, devShapePtrX; + cudaError_t res = cudaMalloc(reinterpret_cast(&devBufferPtrX), + x.lengthOf() * x.sizeOfT()); + ASSERT_EQ(0, res); + res = cudaMalloc(reinterpret_cast(&devBufferPtrZ), + x.lengthOf() * x.sizeOfT()); + ASSERT_EQ(0, res); + res = cudaMalloc(reinterpret_cast(&devShapePtrX), + shape::shapeInfoByteLength(x.shapeInfo())); + ASSERT_EQ(0, res); + + Nd4jPointer nativeStream = (Nd4jPointer)malloc(sizeof(cudaStream_t)); + CHECK_ALLOC(nativeStream, "Failed to allocate memory for new CUDA stream", + sizeof(cudaStream_t)); + cudaError_t dZ = + cudaStreamCreate(reinterpret_cast(&nativeStream)); + auto stream = reinterpret_cast(&nativeStream); + + x.dataBuffer()->allocatePrimary(); + x.syncToHost(); + + cudaMemcpyAsync(devBufferPtrX, x.buffer(), x.lengthOf() * x.sizeOfT(), + cudaMemcpyHostToDevice, *stream); + cudaMemcpyAsync(devShapePtrX, x.shapeInfo(), + shape::shapeInfoByteLength(x.shapeInfo()), + cudaMemcpyHostToDevice, *stream); + res = cudaStreamSynchronize(*stream); + ASSERT_EQ(0, res); + + LaunchContext lc(stream, nullptr, nullptr); + NativeOpExecutioner::execPairwiseTransform( + &lc, pairwise::Add, nullptr, x.shapeInfo(), devBufferPtrX, + reinterpret_cast(devShapePtrX), nullptr, x.shapeInfo(), + devBufferPtrX, reinterpret_cast(devShapePtrX), nullptr, + z.shapeInfo(), devBufferPtrZ, reinterpret_cast(devShapePtrX), + nullptr); + res = cudaStreamSynchronize(*stream); + ASSERT_EQ(0, res); + + z.dataBuffer()->allocatePrimary(); + + cudaMemcpyAsync(z.buffer(), devBufferPtrZ, z.lengthOf() * x.sizeOfT(), + cudaMemcpyDeviceToHost, *stream); + res = cudaStreamSynchronize(*stream); + ASSERT_EQ(0, res); + + cudaFree(devBufferPtrX); + cudaFree(devBufferPtrZ); + cudaFree(devShapePtrX); + + // needed due to memcpy + z.tickWriteHost(); + + for (int e = 0; e < z.lengthOf(); e++) { + //nd4j_printf("step %i\n", e); + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + } } - //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execIndexReduceScalar_1) { + NDArray x1('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::INT32); + NDArray x2('c', {2, 2}, {0.5, 1.5, -4.5, 3.5}, sd::DataType::BFLOAT16); + NDArray x3('c', {2, 2}, {0, -1, 0, 1}, sd::DataType::BOOL); + + NDArray scalar('c', {}, std::vector{0}, sd::DataType::INT64); + + NDArray exp1('c', {}, std::vector{3}, sd::DataType::INT64); + NDArray exp2('c', {}, std::vector{2}, sd::DataType::INT64); + NDArray exp3('c', {}, std::vector{1}, sd::DataType::INT64); + + void *dX1, *dX2, *dX3, *dZ; + Nd4jLong *dX1ShapeInfo, *dX2ShapeInfo, *dX3ShapeInfo, *dZShapeInfo; + + cudaError_t cudaResult; + + cudaResult = + cudaMalloc(reinterpret_cast(&dX1), x1.lengthOf() * x1.sizeOfT()); + ASSERT_EQ(0, cudaResult); + cudaResult = + cudaMalloc(reinterpret_cast(&dX2), x2.lengthOf() * x2.sizeOfT()); + ASSERT_EQ(0, cudaResult); + cudaResult = + cudaMalloc(reinterpret_cast(&dX3), x3.lengthOf() * x3.sizeOfT()); + ASSERT_EQ(0, cudaResult); + cudaResult = cudaMalloc(reinterpret_cast(&dZ), + scalar.lengthOf() * scalar.sizeOfT()); + ASSERT_EQ(0, cudaResult); + cudaResult = cudaMalloc(reinterpret_cast(&dX1ShapeInfo), + shape::shapeInfoByteLength(x1.shapeInfo())); + ASSERT_EQ(0, cudaResult); + cudaResult = cudaMalloc(reinterpret_cast(&dX2ShapeInfo), + shape::shapeInfoByteLength(x2.shapeInfo())); + ASSERT_EQ(0, cudaResult); + cudaResult = cudaMalloc(reinterpret_cast(&dX3ShapeInfo), + shape::shapeInfoByteLength(x3.shapeInfo())); + ASSERT_EQ(0, cudaResult); + cudaResult = cudaMalloc(reinterpret_cast(&dZShapeInfo), + shape::shapeInfoByteLength(scalar.shapeInfo())); + ASSERT_EQ(0, cudaResult); + + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + + x1.syncToHost(); + x2.syncToHost(); + x3.syncToHost(); + scalar.syncToHost(); + + cudaMemcpyAsync(dX1, x1.buffer(), x1.lengthOf() * x1.sizeOfT(), + cudaMemcpyHostToDevice, stream); + cudaMemcpyAsync(dX2, x2.buffer(), x2.lengthOf() * x2.sizeOfT(), + cudaMemcpyHostToDevice, stream); + cudaMemcpyAsync(dX3, x3.buffer(), x3.lengthOf() * x3.sizeOfT(), + cudaMemcpyHostToDevice, stream); + cudaMemcpyAsync(dX1ShapeInfo, x1.shapeInfo(), + shape::shapeInfoByteLength(x1.shapeInfo()), + cudaMemcpyHostToDevice, stream); + cudaMemcpyAsync(dX2ShapeInfo, x2.shapeInfo(), + shape::shapeInfoByteLength(x2.shapeInfo()), + cudaMemcpyHostToDevice, stream); + cudaMemcpyAsync(dX3ShapeInfo, x3.shapeInfo(), + shape::shapeInfoByteLength(x3.shapeInfo()), + cudaMemcpyHostToDevice, stream); + cudaMemcpyAsync(dZShapeInfo, scalar.shapeInfo(), + shape::shapeInfoByteLength(scalar.shapeInfo()), + cudaMemcpyHostToDevice, stream); + + void *reductionPointer = nullptr; + cudaResult = + cudaMalloc(reinterpret_cast(&reductionPointer), 1024 * 1024); + ASSERT_EQ(0, cudaResult); + cudaResult = cudaMemset(reductionPointer, 0, 1024 * 1024); + ASSERT_EQ(0, cudaResult); + + LaunchContext lc(&stream, + LaunchContext::defaultContext()->getReductionPointer(), + LaunchContext::defaultContext()->getScalarPointer(), + LaunchContext::defaultContext()->getAllocationPointer()); + + /***************************************/ + + NativeOpExecutioner::execIndexReduceScalar( + &lc, sd::indexreduce::IndexAbsoluteMax, x1.buffer(), x1.shapeInfo(), dX1, + dX1ShapeInfo, nullptr, scalar.buffer(), scalar.shapeInfo(), dZ, + dZShapeInfo); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + + cudaMemcpyAsync(scalar.buffer(), dZ, scalar.lengthOf() * scalar.sizeOfT(), + cudaMemcpyDeviceToHost, stream); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + + scalar.tickWriteHost(); + + ASSERT_NEAR(exp1.e(0), scalar.e(0), 1e-5); + + /***************************************/ + + NativeOpExecutioner::execIndexReduceScalar( + &lc, sd::indexreduce::IndexAbsoluteMax, nullptr, x2.shapeInfo(), dX2, + dX2ShapeInfo, nullptr, nullptr, scalar.shapeInfo(), dZ, dZShapeInfo); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + + cudaMemcpyAsync(scalar.buffer(), dZ, scalar.lengthOf() * scalar.sizeOfT(), + cudaMemcpyDeviceToHost, stream); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + + ASSERT_NEAR(exp2.e(0), scalar.e(0), 1e-5); + + // ************************************* + + NativeOpExecutioner::execIndexReduceScalar( + &lc, sd::indexreduce::IndexAbsoluteMax, nullptr, x3.shapeInfo(), dX3, + dX3ShapeInfo, nullptr, nullptr, scalar.shapeInfo(), dZ, dZShapeInfo); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + + cudaMemcpyAsync(scalar.buffer(), dZ, scalar.lengthOf() * scalar.sizeOfT(), + cudaMemcpyDeviceToHost, stream); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + + ASSERT_NEAR(exp3.e(0), scalar.e(0), 1e-5); + + /***************************************/ - NDArray x1('c', {2,2}, {0, 1, 2, 3}, sd::DataType::INT32); - NDArray x2('c', {2,2}, {0.5, 1.5, -4.5, 3.5}, sd::DataType::BFLOAT16); - NDArray x3('c', {2,2}, {0, -1, 0, 1}, sd::DataType::BOOL); - - NDArray scalar('c', {}, std::vector{0}, sd::DataType::INT64); - - NDArray exp1('c', {}, std::vector{3}, sd::DataType::INT64); - NDArray exp2('c', {}, std::vector{2}, sd::DataType::INT64); - NDArray exp3('c', {}, std::vector{1}, sd::DataType::INT64); - - void *dX1, *dX2, *dX3, *dZ; - Nd4jLong *dX1ShapeInfo, *dX2ShapeInfo, *dX3ShapeInfo, *dZShapeInfo; - - cudaError_t cudaResult; - - cudaResult = cudaMalloc(reinterpret_cast(&dX1), x1.lengthOf() * x1.sizeOfT()); ASSERT_EQ(0, cudaResult); - cudaResult = cudaMalloc(reinterpret_cast(&dX2), x2.lengthOf() * x2.sizeOfT()); ASSERT_EQ(0, cudaResult); - cudaResult = cudaMalloc(reinterpret_cast(&dX3), x3.lengthOf() * x3.sizeOfT()); ASSERT_EQ(0, cudaResult); - cudaResult = cudaMalloc(reinterpret_cast(&dZ), scalar.lengthOf() * scalar.sizeOfT()); ASSERT_EQ(0, cudaResult); - cudaResult = cudaMalloc(reinterpret_cast(&dX1ShapeInfo), shape::shapeInfoByteLength(x1.shapeInfo())); ASSERT_EQ(0, cudaResult); - cudaResult = cudaMalloc(reinterpret_cast(&dX2ShapeInfo), shape::shapeInfoByteLength(x2.shapeInfo())); ASSERT_EQ(0, cudaResult); - cudaResult = cudaMalloc(reinterpret_cast(&dX3ShapeInfo), shape::shapeInfoByteLength(x3.shapeInfo())); ASSERT_EQ(0, cudaResult); - cudaResult = cudaMalloc(reinterpret_cast(&dZShapeInfo), shape::shapeInfoByteLength(scalar.shapeInfo())); ASSERT_EQ(0, cudaResult); - - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); - ASSERT_EQ(0, cudaResult); - - x1.syncToHost(); - x2.syncToHost(); - x3.syncToHost(); - scalar.syncToHost(); - - cudaMemcpyAsync(dX1, x1.buffer(), x1.lengthOf() * x1.sizeOfT(), cudaMemcpyHostToDevice, stream); - cudaMemcpyAsync(dX2, x2.buffer(), x2.lengthOf() * x2.sizeOfT(), cudaMemcpyHostToDevice, stream); - cudaMemcpyAsync(dX3, x3.buffer(), x3.lengthOf() * x3.sizeOfT(), cudaMemcpyHostToDevice, stream); - cudaMemcpyAsync(dX1ShapeInfo, x1.shapeInfo(), shape::shapeInfoByteLength(x1.shapeInfo()), cudaMemcpyHostToDevice, stream); - cudaMemcpyAsync(dX2ShapeInfo, x2.shapeInfo(), shape::shapeInfoByteLength(x2.shapeInfo()), cudaMemcpyHostToDevice, stream); - cudaMemcpyAsync(dX3ShapeInfo, x3.shapeInfo(), shape::shapeInfoByteLength(x3.shapeInfo()), cudaMemcpyHostToDevice, stream); - cudaMemcpyAsync(dZShapeInfo, scalar.shapeInfo(), shape::shapeInfoByteLength(scalar.shapeInfo()), cudaMemcpyHostToDevice, stream); - - void* reductionPointer = nullptr; - cudaResult = cudaMalloc(reinterpret_cast(&reductionPointer), 1024*1024); - ASSERT_EQ(0, cudaResult); - cudaResult = cudaMemset(reductionPointer, 0, 1024 * 1024); - ASSERT_EQ(0, cudaResult); - - LaunchContext lc(&stream, LaunchContext::defaultContext()->getReductionPointer(), LaunchContext::defaultContext()->getScalarPointer(), LaunchContext::defaultContext()->getAllocationPointer()); - - /***************************************/ - - NativeOpExecutioner::execIndexReduceScalar(&lc, - sd::indexreduce::IndexAbsoluteMax, - x1.buffer(), x1.shapeInfo(), - dX1, dX1ShapeInfo, - nullptr, - scalar.buffer(), scalar.shapeInfo(), - dZ, dZShapeInfo); - - cudaResult = cudaStreamSynchronize(stream); - ASSERT_EQ(0, cudaResult); - - cudaMemcpyAsync(scalar.buffer(), dZ, scalar.lengthOf() * scalar.sizeOfT(), cudaMemcpyDeviceToHost, stream); - - cudaResult = cudaStreamSynchronize(stream); - ASSERT_EQ(0, cudaResult); - - scalar.tickWriteHost(); - - ASSERT_NEAR(exp1.e(0), scalar.e(0), 1e-5); - - /***************************************/ - - NativeOpExecutioner::execIndexReduceScalar(&lc, - sd::indexreduce::IndexAbsoluteMax, - nullptr, x2.shapeInfo(), - dX2, dX2ShapeInfo, - nullptr, - nullptr, scalar.shapeInfo(), - dZ, dZShapeInfo); - - cudaResult = cudaStreamSynchronize(stream); - ASSERT_EQ(0, cudaResult); - - cudaMemcpyAsync(scalar.buffer(), dZ, scalar.lengthOf() * scalar.sizeOfT(), cudaMemcpyDeviceToHost, stream); - - cudaResult = cudaStreamSynchronize(stream); - ASSERT_EQ(0, cudaResult); - - ASSERT_NEAR(exp2.e(0), scalar.e(0), 1e-5); - - // ************************************* - - NativeOpExecutioner::execIndexReduceScalar(&lc, - sd::indexreduce::IndexAbsoluteMax, - nullptr, x3.shapeInfo(), - dX3, dX3ShapeInfo, - nullptr, - nullptr, scalar.shapeInfo(), - dZ, dZShapeInfo); - - cudaResult = cudaStreamSynchronize(stream); - ASSERT_EQ(0, cudaResult); - - cudaMemcpyAsync(scalar.buffer(), dZ, scalar.lengthOf() * scalar.sizeOfT(), cudaMemcpyDeviceToHost, stream); - - cudaResult = cudaStreamSynchronize(stream); - ASSERT_EQ(0, cudaResult); - - ASSERT_NEAR(exp3.e(0), scalar.e(0), 1e-5); - - /***************************************/ - - cudaFree(dX1); cudaFree(dX2); cudaFree(dX3); cudaFree(dZ); - cudaFree(dX1ShapeInfo); cudaFree(dX2ShapeInfo); cudaFree(dX3ShapeInfo); cudaFree(dZShapeInfo); - - /***************************************/ - - cudaResult = cudaStreamDestroy(stream); - ASSERT_EQ(0, cudaResult); - + cudaFree(dX1); + cudaFree(dX2); + cudaFree(dX3); + cudaFree(dZ); + cudaFree(dX1ShapeInfo); + cudaFree(dX2ShapeInfo); + cudaFree(dX3ShapeInfo); + cudaFree(dZShapeInfo); + + /***************************************/ + + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execReduce3Scalar_1) { - - if (!Environment::getInstance().isExperimentalBuild()) - return; - - NDArray x1('c', {2,2}, {1,2,3,4}, sd::DataType::INT32); - NDArray x2('c', {2,2}, {-1,-2,-3,-4}, sd::DataType::INT32); - NDArray x3('c', {2,2}, {1.5,1.5,1.5,1.5}, sd::DataType::DOUBLE); - NDArray x4('c', {2,2}, {1,2,3,4}, sd::DataType::DOUBLE); - - NDArray exp1('c', {}, std::vector{-30.f}, sd::DataType::FLOAT32); - NDArray exp2('c', {}, std::vector{15.}, sd::DataType::DOUBLE); - - NDArray scalar1('c', {}, std::vector{100.f}, sd::DataType::FLOAT32); - NDArray scalar2('c', {}, std::vector{100.}, sd::DataType::DOUBLE); - - void *dX1, *dX2, *dX3, *dX4, *dZ1, *dZ2; - Nd4jLong *dX1ShapeInfo, *dX3ShapeInfo, *dZ1ShapeInfo, *dZ2ShapeInfo; - - cudaError_t cudaResult; - - cudaResult = cudaMalloc(reinterpret_cast(&dX1), x1.lengthOf() * x1.sizeOfT()); ASSERT_EQ(0, cudaResult); - cudaResult = cudaMalloc(reinterpret_cast(&dX2), x2.lengthOf() * x2.sizeOfT()); ASSERT_EQ(0, cudaResult); - cudaResult = cudaMalloc(reinterpret_cast(&dX3), x3.lengthOf() * x3.sizeOfT()); ASSERT_EQ(0, cudaResult); - cudaResult = cudaMalloc(reinterpret_cast(&dX4), x4.lengthOf() * x4.sizeOfT()); ASSERT_EQ(0, cudaResult); - cudaResult = cudaMalloc(reinterpret_cast(&dZ1), scalar1.lengthOf() * scalar1.sizeOfT()); ASSERT_EQ(0, cudaResult); - cudaResult = cudaMalloc(reinterpret_cast(&dZ2), scalar2.lengthOf() * scalar2.sizeOfT()); ASSERT_EQ(0, cudaResult); - cudaResult = cudaMalloc(reinterpret_cast(&dX1ShapeInfo), shape::shapeInfoByteLength(x1.shapeInfo())); ASSERT_EQ(0, cudaResult); - cudaResult = cudaMalloc(reinterpret_cast(&dX3ShapeInfo), shape::shapeInfoByteLength(x3.shapeInfo())); ASSERT_EQ(0, cudaResult); - cudaResult = cudaMalloc(reinterpret_cast(&dZ1ShapeInfo), shape::shapeInfoByteLength(scalar1.shapeInfo())); ASSERT_EQ(0, cudaResult); - cudaResult = cudaMalloc(reinterpret_cast(&dZ2ShapeInfo), shape::shapeInfoByteLength(scalar2.shapeInfo())); ASSERT_EQ(0, cudaResult); - - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); - ASSERT_EQ(0, cudaResult); - - x1.syncToHost(); - x2.syncToHost(); - x3.syncToHost(); - x4.syncToHost(); - scalar1.syncToHost(); - scalar2.syncToHost(); - - cudaMemcpyAsync(dX1, x1.buffer(), x1.lengthOf() * x1.sizeOfT(), cudaMemcpyHostToDevice, stream); - cudaMemcpyAsync(dX2, x2.buffer(), x2.lengthOf() * x2.sizeOfT(), cudaMemcpyHostToDevice, stream); - cudaMemcpyAsync(dX3, x3.buffer(), x3.lengthOf() * x3.sizeOfT(), cudaMemcpyHostToDevice, stream); - cudaMemcpyAsync(dX4, x4.buffer(), x4.lengthOf() * x4.sizeOfT(), cudaMemcpyHostToDevice, stream); - cudaMemcpyAsync(dX1ShapeInfo, x1.shapeInfo(), shape::shapeInfoByteLength(x1.shapeInfo()), cudaMemcpyHostToDevice, stream); - cudaMemcpyAsync(dX3ShapeInfo, x3.shapeInfo(), shape::shapeInfoByteLength(x3.shapeInfo()), cudaMemcpyHostToDevice, stream); - cudaMemcpyAsync(dZ1ShapeInfo, scalar1.shapeInfo(), shape::shapeInfoByteLength(scalar1.shapeInfo()), cudaMemcpyHostToDevice, stream); - cudaMemcpyAsync(dZ2ShapeInfo, scalar2.shapeInfo(), shape::shapeInfoByteLength(scalar2.shapeInfo()), cudaMemcpyHostToDevice, stream); - - /***************************************/ - - void* reductionPointer = nullptr; - int* allocationPointer = nullptr; - - cudaResult = cudaMalloc(reinterpret_cast(&reductionPointer), 1024*1024); ASSERT_EQ(0, cudaResult); - cudaResult = cudaMalloc(reinterpret_cast(&allocationPointer), 1024*1024); ASSERT_EQ(0, cudaResult); - - LaunchContext lc(&stream, reductionPointer, nullptr, allocationPointer); - - /***************************************/ - - NativeOpExecutioner::execReduce3Scalar(&lc, sd::reduce3::Dot,nullptr, x1.shapeInfo(),dX1, dX1ShapeInfo, nullptr, nullptr, x2.shapeInfo(),dX2, dX1ShapeInfo,nullptr, scalar1.shapeInfo(),dZ1, dZ1ShapeInfo); - - cudaResult = cudaStreamSynchronize(stream); - ASSERT_EQ(0, cudaResult); - - scalar1.tickWriteHost(); - scalar2.tickWriteHost(); - - cudaMemcpyAsync(scalar1.buffer(), dZ1, scalar1.lengthOf() * scalar1.sizeOfT(), cudaMemcpyDeviceToHost, stream); - - cudaResult = cudaStreamSynchronize(stream); - ASSERT_EQ(0, cudaResult); - - ASSERT_NEAR(exp1.e(0), scalar1.e(0), 1e-5); - - /***************************************/ - - NativeOpExecutioner::execReduce3Scalar(&lc, sd::reduce3::Dot,nullptr, x3.shapeInfo(),dX3, dX3ShapeInfo, nullptr, nullptr, x4.shapeInfo(),dX4, dX3ShapeInfo,nullptr, scalar2.shapeInfo(),dZ2, dZ2ShapeInfo); - - cudaResult = cudaStreamSynchronize(stream); - ASSERT_EQ(0, cudaResult); - - cudaMemcpyAsync(scalar2.buffer(), dZ2, scalar2.lengthOf() * scalar2.sizeOfT(), cudaMemcpyDeviceToHost, stream); - - cudaResult = cudaStreamSynchronize(stream); - ASSERT_EQ(0, cudaResult); - - ASSERT_NEAR(exp2.e(0), scalar2.e(0), 1e-5); - - /***************************************/ - - cudaFree(dX1); cudaFree(dX2); cudaFree(dX3); cudaFree(dX4); cudaFree(dZ1); cudaFree(dZ2); - cudaFree(dX1ShapeInfo); cudaFree(dX3ShapeInfo); cudaFree(dZ1ShapeInfo); cudaFree(dZ2ShapeInfo); - - /***************************************/ - - cudaResult = cudaStreamDestroy(stream); - ASSERT_EQ(0, cudaResult); + if (!Environment::getInstance().isExperimentalBuild()) return; + + NDArray x1('c', {2, 2}, {1, 2, 3, 4}, sd::DataType::INT32); + NDArray x2('c', {2, 2}, {-1, -2, -3, -4}, sd::DataType::INT32); + NDArray x3('c', {2, 2}, {1.5, 1.5, 1.5, 1.5}, sd::DataType::DOUBLE); + NDArray x4('c', {2, 2}, {1, 2, 3, 4}, sd::DataType::DOUBLE); + + NDArray exp1('c', {}, std::vector{-30.f}, sd::DataType::FLOAT32); + NDArray exp2('c', {}, std::vector{15.}, sd::DataType::DOUBLE); + + NDArray scalar1('c', {}, std::vector{100.f}, sd::DataType::FLOAT32); + NDArray scalar2('c', {}, std::vector{100.}, sd::DataType::DOUBLE); + + void *dX1, *dX2, *dX3, *dX4, *dZ1, *dZ2; + Nd4jLong *dX1ShapeInfo, *dX3ShapeInfo, *dZ1ShapeInfo, *dZ2ShapeInfo; + + cudaError_t cudaResult; + + cudaResult = + cudaMalloc(reinterpret_cast(&dX1), x1.lengthOf() * x1.sizeOfT()); + ASSERT_EQ(0, cudaResult); + cudaResult = + cudaMalloc(reinterpret_cast(&dX2), x2.lengthOf() * x2.sizeOfT()); + ASSERT_EQ(0, cudaResult); + cudaResult = + cudaMalloc(reinterpret_cast(&dX3), x3.lengthOf() * x3.sizeOfT()); + ASSERT_EQ(0, cudaResult); + cudaResult = + cudaMalloc(reinterpret_cast(&dX4), x4.lengthOf() * x4.sizeOfT()); + ASSERT_EQ(0, cudaResult); + cudaResult = cudaMalloc(reinterpret_cast(&dZ1), + scalar1.lengthOf() * scalar1.sizeOfT()); + ASSERT_EQ(0, cudaResult); + cudaResult = cudaMalloc(reinterpret_cast(&dZ2), + scalar2.lengthOf() * scalar2.sizeOfT()); + ASSERT_EQ(0, cudaResult); + cudaResult = cudaMalloc(reinterpret_cast(&dX1ShapeInfo), + shape::shapeInfoByteLength(x1.shapeInfo())); + ASSERT_EQ(0, cudaResult); + cudaResult = cudaMalloc(reinterpret_cast(&dX3ShapeInfo), + shape::shapeInfoByteLength(x3.shapeInfo())); + ASSERT_EQ(0, cudaResult); + cudaResult = cudaMalloc(reinterpret_cast(&dZ1ShapeInfo), + shape::shapeInfoByteLength(scalar1.shapeInfo())); + ASSERT_EQ(0, cudaResult); + cudaResult = cudaMalloc(reinterpret_cast(&dZ2ShapeInfo), + shape::shapeInfoByteLength(scalar2.shapeInfo())); + ASSERT_EQ(0, cudaResult); + + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + + x1.syncToHost(); + x2.syncToHost(); + x3.syncToHost(); + x4.syncToHost(); + scalar1.syncToHost(); + scalar2.syncToHost(); + + cudaMemcpyAsync(dX1, x1.buffer(), x1.lengthOf() * x1.sizeOfT(), + cudaMemcpyHostToDevice, stream); + cudaMemcpyAsync(dX2, x2.buffer(), x2.lengthOf() * x2.sizeOfT(), + cudaMemcpyHostToDevice, stream); + cudaMemcpyAsync(dX3, x3.buffer(), x3.lengthOf() * x3.sizeOfT(), + cudaMemcpyHostToDevice, stream); + cudaMemcpyAsync(dX4, x4.buffer(), x4.lengthOf() * x4.sizeOfT(), + cudaMemcpyHostToDevice, stream); + cudaMemcpyAsync(dX1ShapeInfo, x1.shapeInfo(), + shape::shapeInfoByteLength(x1.shapeInfo()), + cudaMemcpyHostToDevice, stream); + cudaMemcpyAsync(dX3ShapeInfo, x3.shapeInfo(), + shape::shapeInfoByteLength(x3.shapeInfo()), + cudaMemcpyHostToDevice, stream); + cudaMemcpyAsync(dZ1ShapeInfo, scalar1.shapeInfo(), + shape::shapeInfoByteLength(scalar1.shapeInfo()), + cudaMemcpyHostToDevice, stream); + cudaMemcpyAsync(dZ2ShapeInfo, scalar2.shapeInfo(), + shape::shapeInfoByteLength(scalar2.shapeInfo()), + cudaMemcpyHostToDevice, stream); + + /***************************************/ + + void *reductionPointer = nullptr; + int *allocationPointer = nullptr; + + cudaResult = + cudaMalloc(reinterpret_cast(&reductionPointer), 1024 * 1024); + ASSERT_EQ(0, cudaResult); + cudaResult = + cudaMalloc(reinterpret_cast(&allocationPointer), 1024 * 1024); + ASSERT_EQ(0, cudaResult); + + LaunchContext lc(&stream, reductionPointer, nullptr, allocationPointer); + + /***************************************/ + + NativeOpExecutioner::execReduce3Scalar( + &lc, sd::reduce3::Dot, nullptr, x1.shapeInfo(), dX1, dX1ShapeInfo, + nullptr, nullptr, x2.shapeInfo(), dX2, dX1ShapeInfo, nullptr, + scalar1.shapeInfo(), dZ1, dZ1ShapeInfo); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + + scalar1.tickWriteHost(); + scalar2.tickWriteHost(); + + cudaMemcpyAsync(scalar1.buffer(), dZ1, scalar1.lengthOf() * scalar1.sizeOfT(), + cudaMemcpyDeviceToHost, stream); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + + ASSERT_NEAR(exp1.e(0), scalar1.e(0), 1e-5); + + /***************************************/ + + NativeOpExecutioner::execReduce3Scalar( + &lc, sd::reduce3::Dot, nullptr, x3.shapeInfo(), dX3, dX3ShapeInfo, + nullptr, nullptr, x4.shapeInfo(), dX4, dX3ShapeInfo, nullptr, + scalar2.shapeInfo(), dZ2, dZ2ShapeInfo); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + + cudaMemcpyAsync(scalar2.buffer(), dZ2, scalar2.lengthOf() * scalar2.sizeOfT(), + cudaMemcpyDeviceToHost, stream); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + + ASSERT_NEAR(exp2.e(0), scalar2.e(0), 1e-5); + + /***************************************/ + + cudaFree(dX1); + cudaFree(dX2); + cudaFree(dX3); + cudaFree(dX4); + cudaFree(dZ1); + cudaFree(dZ2); + cudaFree(dX1ShapeInfo); + cudaFree(dX3ShapeInfo); + cudaFree(dZ1ShapeInfo); + cudaFree(dZ2ShapeInfo); + + /***************************************/ + + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } - //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execReduce3_1) { - - NDArray x('c', {2,2}, {1,2,3,4}, sd::DataType::INT32); - NDArray y('c', {2,2}, {-1,-2,-3,-4}, sd::DataType::INT32); - - NDArray exp('c', {}, std::vector{-30.f}, sd::DataType::FLOAT32); - NDArray z('c', {}, std::vector{100.f}, sd::DataType::FLOAT32); - - std::vector dimensions = {0, 1}; - - x.syncToHost(); - y.syncToHost(); - z.syncToHost(); - - - std::vector> hostData; - hostData.emplace_back(dimensions.data(), dimensions.size() * sizeof(int)); // 0 -- dimensions - std::vector devicePtrs(hostData.size(), nullptr); - - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - - // allocate required amount of global device memory and copy host data to it - cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); ASSERT_EQ(0, cudaResult); - - // call cuda kernel which calculates result - NativeOpExecutioner::execReduce3(&lc, sd::reduce3::Dot, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, - nullptr, y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - (int*)devicePtrs[0], dimensions.size(), - nullptr, nullptr, nullptr, nullptr); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // free allocated global device memory - for(int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + NDArray x('c', {2, 2}, {1, 2, 3, 4}, sd::DataType::INT32); + NDArray y('c', {2, 2}, {-1, -2, -3, -4}, sd::DataType::INT32); + + NDArray exp('c', {}, std::vector{-30.f}, sd::DataType::FLOAT32); + NDArray z('c', {}, std::vector{100.f}, sd::DataType::FLOAT32); + + std::vector dimensions = {0, 1}; + + x.syncToHost(); + y.syncToHost(); + z.syncToHost(); + + std::vector> hostData; + hostData.emplace_back(dimensions.data(), + dimensions.size() * sizeof(int)); // 0 -- dimensions + std::vector devicePtrs(hostData.size(), nullptr); + + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // allocate required amount of global device memory and copy host data to it + cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); + ASSERT_EQ(0, cudaResult); + + // call cuda kernel which calculates result + NativeOpExecutioner::execReduce3( + &lc, sd::reduce3::Dot, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, nullptr, y.shapeInfo(), y.specialBuffer(), + y.specialShapeInfo(), nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo(), (int *)devicePtrs[0], dimensions.size(), nullptr, + nullptr, nullptr, nullptr); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // free allocated global device memory + for (int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } - //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execReduce3_2) { - - NDArray x('c', {2,2}, {1.5,1.5,1.5,1.5}, sd::DataType::DOUBLE); - NDArray y('c', {2,2}, {1,2,3,4}, sd::DataType::DOUBLE); - - NDArray exp('c', {}, std::vector{15.}, sd::DataType::DOUBLE); - NDArray z('c', {}, std::vector{100.}, sd::DataType::DOUBLE); - - std::vector dimensions = {0, 1}; - - // prepare input arrays for prepareDataForCuda function - std::vector> hostData; - hostData.emplace_back(dimensions.data(), dimensions.size() * sizeof(int)); // 0 -- dimensions - std::vector devicePtrs(hostData.size(), nullptr); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - - // allocate required amount of global device memory and copy host data to it - cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); ASSERT_EQ(0, cudaResult); - - // call cuda kernel which calculates result - NativeOpExecutioner::execReduce3(&lc, sd::reduce3::Dot, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, - nullptr, y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - (int*)devicePtrs[0], dimensions.size(), - nullptr, nullptr, nullptr, nullptr); - - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // free allocated global device memory - for(int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + NDArray x('c', {2, 2}, {1.5, 1.5, 1.5, 1.5}, sd::DataType::DOUBLE); + NDArray y('c', {2, 2}, {1, 2, 3, 4}, sd::DataType::DOUBLE); + + NDArray exp('c', {}, std::vector{15.}, sd::DataType::DOUBLE); + NDArray z('c', {}, std::vector{100.}, sd::DataType::DOUBLE); + + std::vector dimensions = {0, 1}; + + // prepare input arrays for prepareDataForCuda function + std::vector> hostData; + hostData.emplace_back(dimensions.data(), + dimensions.size() * sizeof(int)); // 0 -- dimensions + std::vector devicePtrs(hostData.size(), nullptr); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // allocate required amount of global device memory and copy host data to it + cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); + ASSERT_EQ(0, cudaResult); + + // call cuda kernel which calculates result + NativeOpExecutioner::execReduce3( + &lc, sd::reduce3::Dot, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, nullptr, y.shapeInfo(), y.specialBuffer(), + y.specialShapeInfo(), nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo(), (int *)devicePtrs[0], dimensions.size(), nullptr, + nullptr, nullptr, nullptr); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // free allocated global device memory + for (int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execReduce3_3) { - - NDArray x('c', {2,3}, {1,2,3,4,5,6}, sd::DataType::INT32); - NDArray y('c', {2,3}, {-6,-5,-4,-3,-2,-1}, sd::DataType::INT32); - - NDArray exp('c', {3}, {-18,-20,-18}, sd::DataType::FLOAT32); - NDArray z('c', {3}, {100,100,100}, sd::DataType::FLOAT32); - - std::vector dimensions = {0}; - - // evaluate xTad data - shape::TAD xTad; - xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); - xTad.createTadOnlyShapeInfo(); - xTad.createOffsets(); - - // evaluate yTad data - shape::TAD yTad; - yTad.init(y.shapeInfo(), dimensions.data(), dimensions.size()); - yTad.createTadOnlyShapeInfo(); - yTad.createOffsets(); - - // prepare input arrays for prepareDataForCuda function - std::vector> hostData; - hostData.emplace_back(dimensions.data(), dimensions.size() * sizeof(int)); // 0 -- dimensions - hostData.emplace_back(xTad.tadOnlyShapeInfo, shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo));// 1 -- xTadShapeInfo - hostData.emplace_back(xTad.tadOffsets, xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets - hostData.emplace_back(yTad.tadOnlyShapeInfo, shape::shapeInfoByteLength(yTad.tadOnlyShapeInfo));// 3 -- yTadShapeInfo - hostData.emplace_back(yTad.tadOffsets, yTad.numTads * sizeof(Nd4jLong)); // 4-- yTadOffsets - std::vector devicePtrs(hostData.size(), nullptr); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - - // allocate required amount of global device memory and copy host data to it - - cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); ASSERT_EQ(0, cudaResult); - - // call cuda kernel which calculates result - NativeOpExecutioner::execReduce3(&lc, sd::reduce3::Dot, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, - nullptr, y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - (int*)devicePtrs[0], dimensions.size(), - (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2], - (Nd4jLong*)devicePtrs[3], (Nd4jLong*)devicePtrs[4]); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // free allocated global device memory - for(int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + NDArray x('c', {2, 3}, {1, 2, 3, 4, 5, 6}, sd::DataType::INT32); + NDArray y('c', {2, 3}, {-6, -5, -4, -3, -2, -1}, sd::DataType::INT32); + + NDArray exp('c', {3}, {-18, -20, -18}, sd::DataType::FLOAT32); + NDArray z('c', {3}, {100, 100, 100}, sd::DataType::FLOAT32); + + std::vector dimensions = {0}; + + // evaluate xTad data + shape::TAD xTad; + xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); + xTad.createTadOnlyShapeInfo(); + xTad.createOffsets(); + + // evaluate yTad data + shape::TAD yTad; + yTad.init(y.shapeInfo(), dimensions.data(), dimensions.size()); + yTad.createTadOnlyShapeInfo(); + yTad.createOffsets(); + + // prepare input arrays for prepareDataForCuda function + std::vector> hostData; + hostData.emplace_back(dimensions.data(), + dimensions.size() * sizeof(int)); // 0 -- dimensions + hostData.emplace_back( + xTad.tadOnlyShapeInfo, + shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo)); // 1 -- xTadShapeInfo + hostData.emplace_back(xTad.tadOffsets, + xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets + hostData.emplace_back( + yTad.tadOnlyShapeInfo, + shape::shapeInfoByteLength(yTad.tadOnlyShapeInfo)); // 3 -- yTadShapeInfo + hostData.emplace_back(yTad.tadOffsets, + yTad.numTads * sizeof(Nd4jLong)); // 4-- yTadOffsets + std::vector devicePtrs(hostData.size(), nullptr); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // allocate required amount of global device memory and copy host data to it + + cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); + ASSERT_EQ(0, cudaResult); + + // call cuda kernel which calculates result + NativeOpExecutioner::execReduce3( + &lc, sd::reduce3::Dot, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, nullptr, y.shapeInfo(), y.specialBuffer(), + y.specialShapeInfo(), nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo(), (int *)devicePtrs[0], dimensions.size(), + (Nd4jLong *)devicePtrs[1], (Nd4jLong *)devicePtrs[2], + (Nd4jLong *)devicePtrs[3], (Nd4jLong *)devicePtrs[4]); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // free allocated global device memory + for (int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execReduce3_4) { - - NDArray x('c', {2,3}, {1,2,3,4,5,6}, sd::DataType::DOUBLE); - NDArray y('c', {2,3}, {1.5,1.5,1.5,1.5,1.5,1.5}, sd::DataType::DOUBLE); - - NDArray exp('c', {2}, {9,22.5}, sd::DataType::DOUBLE); - NDArray z('c', {2}, {100,100}, sd::DataType::DOUBLE); - - std::vector dimensions = {1}; - - // evaluate xTad data - shape::TAD xTad; - xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); - xTad.createTadOnlyShapeInfo(); - xTad.createOffsets(); - - // evaluate yTad data - shape::TAD yTad; - yTad.init(y.shapeInfo(), dimensions.data(), dimensions.size()); - yTad.createTadOnlyShapeInfo(); - yTad.createOffsets(); - - // prepare input arrays for prepareDataForCuda function - std::vector> hostData; - hostData.emplace_back(dimensions.data(), dimensions.size() * sizeof(int)); // 0 -- dimensions - hostData.emplace_back(xTad.tadOnlyShapeInfo, shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo));// 1 -- xTadShapeInfo - hostData.emplace_back(xTad.tadOffsets, xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets - hostData.emplace_back(yTad.tadOnlyShapeInfo, shape::shapeInfoByteLength(yTad.tadOnlyShapeInfo));// 3 -- yTadShapeInfo - hostData.emplace_back(yTad.tadOffsets, yTad.numTads * sizeof(Nd4jLong)); // 4-- yTadOffsets - std::vector devicePtrs(hostData.size(), nullptr); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - - // allocate required amount of global device memory and copy host data to it - - cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); ASSERT_EQ(0, cudaResult); - - // call cuda kernel which calculates result - NativeOpExecutioner::execReduce3(&lc, sd::reduce3::Dot, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, - nullptr, y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - (int*)devicePtrs[0], dimensions.size(), - (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2], - (Nd4jLong*)devicePtrs[3], (Nd4jLong*)devicePtrs[4]); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // free allocated global device memory - for(int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + NDArray x('c', {2, 3}, {1, 2, 3, 4, 5, 6}, sd::DataType::DOUBLE); + NDArray y('c', {2, 3}, {1.5, 1.5, 1.5, 1.5, 1.5, 1.5}, sd::DataType::DOUBLE); + + NDArray exp('c', {2}, {9, 22.5}, sd::DataType::DOUBLE); + NDArray z('c', {2}, {100, 100}, sd::DataType::DOUBLE); + + std::vector dimensions = {1}; + + // evaluate xTad data + shape::TAD xTad; + xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); + xTad.createTadOnlyShapeInfo(); + xTad.createOffsets(); + + // evaluate yTad data + shape::TAD yTad; + yTad.init(y.shapeInfo(), dimensions.data(), dimensions.size()); + yTad.createTadOnlyShapeInfo(); + yTad.createOffsets(); + + // prepare input arrays for prepareDataForCuda function + std::vector> hostData; + hostData.emplace_back(dimensions.data(), + dimensions.size() * sizeof(int)); // 0 -- dimensions + hostData.emplace_back( + xTad.tadOnlyShapeInfo, + shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo)); // 1 -- xTadShapeInfo + hostData.emplace_back(xTad.tadOffsets, + xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets + hostData.emplace_back( + yTad.tadOnlyShapeInfo, + shape::shapeInfoByteLength(yTad.tadOnlyShapeInfo)); // 3 -- yTadShapeInfo + hostData.emplace_back(yTad.tadOffsets, + yTad.numTads * sizeof(Nd4jLong)); // 4-- yTadOffsets + std::vector devicePtrs(hostData.size(), nullptr); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // allocate required amount of global device memory and copy host data to it + + cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); + ASSERT_EQ(0, cudaResult); + + // call cuda kernel which calculates result + NativeOpExecutioner::execReduce3( + &lc, sd::reduce3::Dot, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, nullptr, y.shapeInfo(), y.specialBuffer(), + y.specialShapeInfo(), nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo(), (int *)devicePtrs[0], dimensions.size(), + (Nd4jLong *)devicePtrs[1], (Nd4jLong *)devicePtrs[2], + (Nd4jLong *)devicePtrs[3], (Nd4jLong *)devicePtrs[4]); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // free allocated global device memory + for (int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execReduce3_5) { - - NDArray x('c', {2,2,3}, {1.5,1.5,1.5,1.5,1.5,1.5,1.5,1.5,1.5,1.5,1.5,1.5}, sd::DataType::FLOAT32); - NDArray y('c', {2,2,3}, {1,2,3,4,5,6,7,8,9,10,11,12}, sd::DataType::FLOAT32); - - NDArray exp('c', {2,3}, {7.5, 10.5, 13.5, 25.5, 28.5, 31.5}, sd::DataType::FLOAT32); - NDArray z('c', {2,3}, {100,100,100,100,100,100}, sd::DataType::FLOAT32); - - std::vector dimensions = {1}; - - // evaluate xTad data - shape::TAD xTad; - xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); - xTad.createTadOnlyShapeInfo(); - xTad.createOffsets(); - - // evaluate yTad data - shape::TAD yTad; - yTad.init(y.shapeInfo(), dimensions.data(), dimensions.size()); - yTad.createTadOnlyShapeInfo(); - yTad.createOffsets(); - - // prepare input arrays for prepareDataForCuda function - std::vector> hostData; - hostData.emplace_back(dimensions.data(), dimensions.size() * sizeof(int)); // 0 -- dimensions - hostData.emplace_back(xTad.tadOnlyShapeInfo, shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo));// 1 -- xTadShapeInfo - hostData.emplace_back(xTad.tadOffsets, xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets - hostData.emplace_back(yTad.tadOnlyShapeInfo, shape::shapeInfoByteLength(yTad.tadOnlyShapeInfo));// 3 -- yTadShapeInfo - hostData.emplace_back(yTad.tadOffsets, yTad.numTads * sizeof(Nd4jLong)); // 4-- yTadOffsets - std::vector devicePtrs(hostData.size(), nullptr); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - - // allocate required amount of global device memory and copy host data to it - - cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); ASSERT_EQ(0, cudaResult); - - // call cuda kernel which calculates result - NativeOpExecutioner::execReduce3(&lc, sd::reduce3::Dot, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, - nullptr, y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - (int*)devicePtrs[0], dimensions.size(), - (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2], - (Nd4jLong*)devicePtrs[3], (Nd4jLong*)devicePtrs[4]); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // free allocated global device memory - for(int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + NDArray x('c', {2, 2, 3}, + {1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5}, + sd::DataType::FLOAT32); + NDArray y('c', {2, 2, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, + sd::DataType::FLOAT32); + + NDArray exp('c', {2, 3}, {7.5, 10.5, 13.5, 25.5, 28.5, 31.5}, + sd::DataType::FLOAT32); + NDArray z('c', {2, 3}, {100, 100, 100, 100, 100, 100}, sd::DataType::FLOAT32); + + std::vector dimensions = {1}; + + // evaluate xTad data + shape::TAD xTad; + xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); + xTad.createTadOnlyShapeInfo(); + xTad.createOffsets(); + + // evaluate yTad data + shape::TAD yTad; + yTad.init(y.shapeInfo(), dimensions.data(), dimensions.size()); + yTad.createTadOnlyShapeInfo(); + yTad.createOffsets(); + + // prepare input arrays for prepareDataForCuda function + std::vector> hostData; + hostData.emplace_back(dimensions.data(), + dimensions.size() * sizeof(int)); // 0 -- dimensions + hostData.emplace_back( + xTad.tadOnlyShapeInfo, + shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo)); // 1 -- xTadShapeInfo + hostData.emplace_back(xTad.tadOffsets, + xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets + hostData.emplace_back( + yTad.tadOnlyShapeInfo, + shape::shapeInfoByteLength(yTad.tadOnlyShapeInfo)); // 3 -- yTadShapeInfo + hostData.emplace_back(yTad.tadOffsets, + yTad.numTads * sizeof(Nd4jLong)); // 4-- yTadOffsets + std::vector devicePtrs(hostData.size(), nullptr); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // allocate required amount of global device memory and copy host data to it + + cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); + ASSERT_EQ(0, cudaResult); + + // call cuda kernel which calculates result + NativeOpExecutioner::execReduce3( + &lc, sd::reduce3::Dot, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, nullptr, y.shapeInfo(), y.specialBuffer(), + y.specialShapeInfo(), nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo(), (int *)devicePtrs[0], dimensions.size(), + (Nd4jLong *)devicePtrs[1], (Nd4jLong *)devicePtrs[2], + (Nd4jLong *)devicePtrs[3], (Nd4jLong *)devicePtrs[4]); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // free allocated global device memory + for (int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execReduce3All_1) { - - NDArray x('c', {2,2}, {1,2,3,4}, sd::DataType::INT32); - NDArray y('c', {2,3}, {-1,1,-1,1,-1,1}, sd::DataType::INT32); - - NDArray exp('c', {2,3}, {2,-2,2,2,-2,2}, sd::DataType::FLOAT32); - NDArray z('c', {2,3}, {100,100,100,100,100,100}, sd::DataType::FLOAT32); - - std::vector dimensions = {0}; - - // evaluate xTad data - shape::TAD xTad; - xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); - xTad.createTadOnlyShapeInfo(); - xTad.createOffsets(); - - // evaluate yTad data - shape::TAD yTad; - yTad.init(y.shapeInfo(), dimensions.data(), dimensions.size()); - yTad.createTadOnlyShapeInfo(); - yTad.createOffsets(); - - // prepare input arrays for prepareDataForCuda function - std::vector> hostData; - hostData.emplace_back(dimensions.data(), dimensions.size() * sizeof(int)); // 0 -- dimensions - hostData.emplace_back(xTad.tadOnlyShapeInfo, shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo));// 1 -- xTadShapeInfo - hostData.emplace_back(xTad.tadOffsets, xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets - hostData.emplace_back(yTad.tadOnlyShapeInfo, shape::shapeInfoByteLength(yTad.tadOnlyShapeInfo));// 3 -- yTadShapeInfo - hostData.emplace_back(yTad.tadOffsets, yTad.numTads * sizeof(Nd4jLong)); // 4 -- yTadOffsets - std::vector devicePtrs(hostData.size(), nullptr); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - - // allocate required amount of global device memory and copy host data to it - - cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); ASSERT_EQ(0, cudaResult); - - // call cuda kernel which calculates result - NativeOpExecutioner::execReduce3All(&lc, sd::reduce3::Dot, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, - nullptr, y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - (int*)devicePtrs[0], dimensions.size(), - (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2], - (Nd4jLong*)devicePtrs[3], (Nd4jLong*)devicePtrs[4]); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // free allocated global device memory - for(int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + NDArray x('c', {2, 2}, {1, 2, 3, 4}, sd::DataType::INT32); + NDArray y('c', {2, 3}, {-1, 1, -1, 1, -1, 1}, sd::DataType::INT32); + + NDArray exp('c', {2, 3}, {2, -2, 2, 2, -2, 2}, sd::DataType::FLOAT32); + NDArray z('c', {2, 3}, {100, 100, 100, 100, 100, 100}, sd::DataType::FLOAT32); + + std::vector dimensions = {0}; + + // evaluate xTad data + shape::TAD xTad; + xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); + xTad.createTadOnlyShapeInfo(); + xTad.createOffsets(); + + // evaluate yTad data + shape::TAD yTad; + yTad.init(y.shapeInfo(), dimensions.data(), dimensions.size()); + yTad.createTadOnlyShapeInfo(); + yTad.createOffsets(); + + // prepare input arrays for prepareDataForCuda function + std::vector> hostData; + hostData.emplace_back(dimensions.data(), + dimensions.size() * sizeof(int)); // 0 -- dimensions + hostData.emplace_back( + xTad.tadOnlyShapeInfo, + shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo)); // 1 -- xTadShapeInfo + hostData.emplace_back(xTad.tadOffsets, + xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets + hostData.emplace_back( + yTad.tadOnlyShapeInfo, + shape::shapeInfoByteLength(yTad.tadOnlyShapeInfo)); // 3 -- yTadShapeInfo + hostData.emplace_back(yTad.tadOffsets, + yTad.numTads * sizeof(Nd4jLong)); // 4 -- yTadOffsets + std::vector devicePtrs(hostData.size(), nullptr); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // allocate required amount of global device memory and copy host data to it + + cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); + ASSERT_EQ(0, cudaResult); + + // call cuda kernel which calculates result + NativeOpExecutioner::execReduce3All( + &lc, sd::reduce3::Dot, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, nullptr, y.shapeInfo(), y.specialBuffer(), + y.specialShapeInfo(), nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo(), (int *)devicePtrs[0], dimensions.size(), + (Nd4jLong *)devicePtrs[1], (Nd4jLong *)devicePtrs[2], + (Nd4jLong *)devicePtrs[3], (Nd4jLong *)devicePtrs[4]); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // free allocated global device memory + for (int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execReduce3All_2) { - - NDArray x('c', {2,2}, {1,2,3,4}, sd::DataType::DOUBLE); - NDArray y('c', {2,3}, {1.5,1.5,1.5,1.5,1.5,1.5}, sd::DataType::DOUBLE); - - NDArray exp('c', {2,3}, {6,6,6,9,9,9}, sd::DataType::DOUBLE); - NDArray z('c', {2,3}, {100,100,100,100,100,100,},sd::DataType::DOUBLE); - - std::vector dimensions = {0}; - - // evaluate xTad data - shape::TAD xTad; - xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); - xTad.createTadOnlyShapeInfo(); - xTad.createOffsets(); - - // evaluate yTad data - shape::TAD yTad; - yTad.init(y.shapeInfo(), dimensions.data(), dimensions.size()); - yTad.createTadOnlyShapeInfo(); - yTad.createOffsets(); - - // prepare input arrays for prepareDataForCuda function - std::vector> hostData; - hostData.emplace_back(dimensions.data(), dimensions.size() * sizeof(int)); // 0 -- dimensions - hostData.emplace_back(xTad.tadOnlyShapeInfo, shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo));// 1 -- xTadShapeInfo - hostData.emplace_back(xTad.tadOffsets, xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets - hostData.emplace_back(yTad.tadOnlyShapeInfo, shape::shapeInfoByteLength(yTad.tadOnlyShapeInfo));// 3 -- yTadShapeInfo - hostData.emplace_back(yTad.tadOffsets, yTad.numTads * sizeof(Nd4jLong)); // 4-- yTadOffsets - std::vector devicePtrs(hostData.size(), nullptr); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - - // allocate required amount of global device memory and copy host data to it - - cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); ASSERT_EQ(0, cudaResult); - - // call cuda kernel which calculates result - NativeOpExecutioner::execReduce3All(&lc, sd::reduce3::Dot, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, - nullptr, y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - (int*)devicePtrs[0], dimensions.size(), - (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2], - (Nd4jLong*)devicePtrs[3], (Nd4jLong*)devicePtrs[4]); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // free allocated global device memory - for(int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + NDArray x('c', {2, 2}, {1, 2, 3, 4}, sd::DataType::DOUBLE); + NDArray y('c', {2, 3}, {1.5, 1.5, 1.5, 1.5, 1.5, 1.5}, sd::DataType::DOUBLE); + + NDArray exp('c', {2, 3}, {6, 6, 6, 9, 9, 9}, sd::DataType::DOUBLE); + NDArray z('c', {2, 3}, + { + 100, + 100, + 100, + 100, + 100, + 100, + }, + sd::DataType::DOUBLE); + + std::vector dimensions = {0}; + + // evaluate xTad data + shape::TAD xTad; + xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); + xTad.createTadOnlyShapeInfo(); + xTad.createOffsets(); + + // evaluate yTad data + shape::TAD yTad; + yTad.init(y.shapeInfo(), dimensions.data(), dimensions.size()); + yTad.createTadOnlyShapeInfo(); + yTad.createOffsets(); + + // prepare input arrays for prepareDataForCuda function + std::vector> hostData; + hostData.emplace_back(dimensions.data(), + dimensions.size() * sizeof(int)); // 0 -- dimensions + hostData.emplace_back( + xTad.tadOnlyShapeInfo, + shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo)); // 1 -- xTadShapeInfo + hostData.emplace_back(xTad.tadOffsets, + xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets + hostData.emplace_back( + yTad.tadOnlyShapeInfo, + shape::shapeInfoByteLength(yTad.tadOnlyShapeInfo)); // 3 -- yTadShapeInfo + hostData.emplace_back(yTad.tadOffsets, + yTad.numTads * sizeof(Nd4jLong)); // 4-- yTadOffsets + std::vector devicePtrs(hostData.size(), nullptr); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // allocate required amount of global device memory and copy host data to it + + cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); + ASSERT_EQ(0, cudaResult); + + // call cuda kernel which calculates result + NativeOpExecutioner::execReduce3All( + &lc, sd::reduce3::Dot, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, nullptr, y.shapeInfo(), y.specialBuffer(), + y.specialShapeInfo(), nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo(), (int *)devicePtrs[0], dimensions.size(), + (Nd4jLong *)devicePtrs[1], (Nd4jLong *)devicePtrs[2], + (Nd4jLong *)devicePtrs[3], (Nd4jLong *)devicePtrs[4]); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // free allocated global device memory + for (int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execIndexReduce_1) { - - NDArray x('c', {2,3}, {100,100,100,100,100,100}, sd::DataType::DOUBLE); - x.linspace(-2.); x.syncToDevice(); - NDArray exp('c', {2}, {2, 2}, sd::DataType::INT64); - NDArray z('c', {2}, {100,100}, sd::DataType::INT64); - - std::vector dimensions = {1}; - - // evaluate xTad data - shape::TAD xTad; - xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); - xTad.createTadOnlyShapeInfo(); - xTad.createOffsets(); - - // prepare input arrays for prepareDataForCuda function - std::vector> hostData; - hostData.emplace_back(dimensions.data(), dimensions.size() * sizeof(int)); // 0 -- dimensions - hostData.emplace_back(xTad.tadOnlyShapeInfo, shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo));// 1 -- xTadShapeInfo - hostData.emplace_back(xTad.tadOffsets, xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets - std::vector devicePtrs(hostData.size(), nullptr); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - - // allocate required amount of global device memory and copy host data to it - - cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); ASSERT_EQ(0, cudaResult); - - // call cuda kernel which calculates result - NativeOpExecutioner::execIndexReduce(&lc, sd::indexreduce::IndexMax, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - (int*)devicePtrs[0], dimensions.size(), - (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2]); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // free allocated global device memory - for(int i = 0; i < devicePtrs.size(); ++i) - cudaFree(devicePtrs[i]); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + NDArray x('c', {2, 3}, {100, 100, 100, 100, 100, 100}, sd::DataType::DOUBLE); + x.linspace(-2.); + x.syncToDevice(); + NDArray exp('c', {2}, {2, 2}, sd::DataType::INT64); + NDArray z('c', {2}, {100, 100}, sd::DataType::INT64); + + std::vector dimensions = {1}; + + // evaluate xTad data + shape::TAD xTad; + xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); + xTad.createTadOnlyShapeInfo(); + xTad.createOffsets(); + + // prepare input arrays for prepareDataForCuda function + std::vector> hostData; + hostData.emplace_back(dimensions.data(), + dimensions.size() * sizeof(int)); // 0 -- dimensions + hostData.emplace_back( + xTad.tadOnlyShapeInfo, + shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo)); // 1 -- xTadShapeInfo + hostData.emplace_back(xTad.tadOffsets, + xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets + std::vector devicePtrs(hostData.size(), nullptr); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // allocate required amount of global device memory and copy host data to it + + cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); + ASSERT_EQ(0, cudaResult); + + // call cuda kernel which calculates result + NativeOpExecutioner::execIndexReduce( + &lc, sd::indexreduce::IndexMax, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo(), (int *)devicePtrs[0], dimensions.size(), + (Nd4jLong *)devicePtrs[1], (Nd4jLong *)devicePtrs[2]); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // free allocated global device memory + for (int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execIndexReduce_2) { - - NDArray x('c', {2,3,4,5}, {100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100, - 100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100, - 100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100, - 100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100, - 100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100, - 100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100}, sd::DataType::FLOAT32); - x.linspace(-2.f); x.syncToDevice(); - NDArray exp('c', {2,5}, {11,11,11,11,11,11,11,11,11,11}, sd::DataType::INT64); - NDArray z('c', {2,5}, {100,100,100,100,100,100,100,100,100,100}, sd::DataType::INT64); - - std::vector dimensions = {1,2}; - - // evaluate xTad data - shape::TAD xTad; - xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); - xTad.createTadOnlyShapeInfo(); - xTad.createOffsets(); - - // prepare input arrays for prepareDataForCuda function - - std::vector> hostData; - hostData.emplace_back(dimensions.data(), dimensions.size() * sizeof(int)); // 0 -- dimensions - hostData.emplace_back(xTad.tadOnlyShapeInfo, shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo));// 1 -- xTadShapeInfo - hostData.emplace_back(xTad.tadOffsets, xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets - std::vector devicePtrs(hostData.size(), nullptr); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - - // allocate required amount of global device memory and copy host data to it - - cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); ASSERT_EQ(0, cudaResult); - - // call cuda kernel which calculates result - NativeOpExecutioner::execIndexReduce(&lc, sd::indexreduce::IndexMax, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - (int*)devicePtrs[0], dimensions.size(), - (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2]); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // free allocated global device memory - for(int i = 0; i < devicePtrs.size(); ++i) - cudaFree(devicePtrs[i]); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + NDArray x( + 'c', {2, 3, 4, 5}, + {100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, + 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, + 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, + 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, + 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, + 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, + 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, + 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, + 100, 100, 100, 100, 100, 100, 100, 100}, + sd::DataType::FLOAT32); + x.linspace(-2.f); + x.syncToDevice(); + NDArray exp('c', {2, 5}, {11, 11, 11, 11, 11, 11, 11, 11, 11, 11}, + sd::DataType::INT64); + NDArray z('c', {2, 5}, {100, 100, 100, 100, 100, 100, 100, 100, 100, 100}, + sd::DataType::INT64); + + std::vector dimensions = {1, 2}; + + // evaluate xTad data + shape::TAD xTad; + xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); + xTad.createTadOnlyShapeInfo(); + xTad.createOffsets(); + + // prepare input arrays for prepareDataForCuda function + + std::vector> hostData; + hostData.emplace_back(dimensions.data(), + dimensions.size() * sizeof(int)); // 0 -- dimensions + hostData.emplace_back( + xTad.tadOnlyShapeInfo, + shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo)); // 1 -- xTadShapeInfo + hostData.emplace_back(xTad.tadOffsets, + xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets + std::vector devicePtrs(hostData.size(), nullptr); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // allocate required amount of global device memory and copy host data to it + + cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); + ASSERT_EQ(0, cudaResult); + + // call cuda kernel which calculates result + NativeOpExecutioner::execIndexReduce( + &lc, sd::indexreduce::IndexMax, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo(), (int *)devicePtrs[0], dimensions.size(), + (Nd4jLong *)devicePtrs[1], (Nd4jLong *)devicePtrs[2]); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // free allocated global device memory + for (int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execIndexReduce_3) { - - NDArray x('c', {2,3,4,5}, {100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100, - 100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100, - 100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100, - 100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100, - 100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100, - 100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100}, sd::DataType::DOUBLE); - x.linspace(-2.); x.syncToDevice(); - NDArray exp('c', {3}, {39, 39, 39}, sd::DataType::INT64); - NDArray z('c', {3}, {100,100,100}, sd::DataType::INT64); - - std::vector dimensions = {0,2,3}; - - // evaluate xTad data - shape::TAD xTad; - xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); - xTad.createTadOnlyShapeInfo(); - xTad.createOffsets(); - - // prepare input arrays for prepareDataForCuda function - std::vector> hostData; - hostData.emplace_back(dimensions.data(), dimensions.size() * sizeof(int)); // 0 -- dimensions - hostData.emplace_back(xTad.tadOnlyShapeInfo, shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo));// 1 -- xTadShapeInfo - hostData.emplace_back(xTad.tadOffsets, xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets - std::vector devicePtrs(hostData.size(), nullptr); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - - // allocate required amount of global device memory and copy host data to it - - cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); ASSERT_EQ(0, cudaResult); - - // call cuda kernel which calculates result - NativeOpExecutioner::execIndexReduce(&lc, sd::indexreduce::IndexMax, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - (int*)devicePtrs[0], dimensions.size(), - (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2]); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // free allocated global device memory - for(int i = 0; i < devicePtrs.size(); ++i) - cudaFree(devicePtrs[i]); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + NDArray x( + 'c', {2, 3, 4, 5}, + {100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, + 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, + 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, + 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, + 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, + 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, + 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, + 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, + 100, 100, 100, 100, 100, 100, 100, 100}, + sd::DataType::DOUBLE); + x.linspace(-2.); + x.syncToDevice(); + NDArray exp('c', {3}, {39, 39, 39}, sd::DataType::INT64); + NDArray z('c', {3}, {100, 100, 100}, sd::DataType::INT64); + + std::vector dimensions = {0, 2, 3}; + + // evaluate xTad data + shape::TAD xTad; + xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); + xTad.createTadOnlyShapeInfo(); + xTad.createOffsets(); + + // prepare input arrays for prepareDataForCuda function + std::vector> hostData; + hostData.emplace_back(dimensions.data(), + dimensions.size() * sizeof(int)); // 0 -- dimensions + hostData.emplace_back( + xTad.tadOnlyShapeInfo, + shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo)); // 1 -- xTadShapeInfo + hostData.emplace_back(xTad.tadOffsets, + xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets + std::vector devicePtrs(hostData.size(), nullptr); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // allocate required amount of global device memory and copy host data to it + + cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); + ASSERT_EQ(0, cudaResult); + + // call cuda kernel which calculates result + NativeOpExecutioner::execIndexReduce( + &lc, sd::indexreduce::IndexMax, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo(), (int *)devicePtrs[0], dimensions.size(), + (Nd4jLong *)devicePtrs[1], (Nd4jLong *)devicePtrs[2]); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // free allocated global device memory + for (int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execScalar_1) { - - if (!Environment::getInstance().isExperimentalBuild()) - return; - - NDArray x('c', {2,3}, {0,1,2,3,4,5}, sd::DataType::INT64); - NDArray exp('c',{2,3}, {0,0,1,1,2,2}, sd::DataType::INT64); - NDArray scalar('c',{}, std::vector{2.f}, sd::DataType::FLOAT32); - NDArray z('c', {2,3}, {100,100,100,100,100,100}, sd::DataType::INT64); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - - // call cuda kernel which calculates result - NativeOpExecutioner::execScalar(&lc, sd::scalar::Divide, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - nullptr, scalar.shapeInfo(), scalar.specialBuffer(), scalar.specialShapeInfo(), - nullptr); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + if (!Environment::getInstance().isExperimentalBuild()) return; + + NDArray x('c', {2, 3}, {0, 1, 2, 3, 4, 5}, sd::DataType::INT64); + NDArray exp('c', {2, 3}, {0, 0, 1, 1, 2, 2}, sd::DataType::INT64); + NDArray scalar('c', {}, std::vector{2.f}, sd::DataType::FLOAT32); + NDArray z('c', {2, 3}, {100, 100, 100, 100, 100, 100}, sd::DataType::INT64); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // call cuda kernel which calculates result + NativeOpExecutioner::execScalar( + &lc, sd::scalar::Divide, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo(), nullptr, scalar.shapeInfo(), scalar.specialBuffer(), + scalar.specialShapeInfo(), nullptr); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execScalar_2) { - - if (!Environment::getInstance().isExperimentalBuild()) - return; - - NDArray x('c', {2,3}, {-1,-2,-3,-4,-5,-6}, sd::DataType::INT64); - NDArray exp('c',{2,3}, {10,10,10,10,10,10}, sd::DataType::FLOAT32); - NDArray scalar('c',{}, std::vector{10.f}, sd::DataType::FLOAT32); - NDArray z('c', {2,3}, {100,100,100,100,100,100}, sd::DataType::FLOAT32); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - - // call cuda kernel which calculates result - NativeOpExecutioner::execScalar(&lc, sd::scalar::CopyPws, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - nullptr, scalar.shapeInfo(), scalar.specialBuffer(), scalar.specialShapeInfo(), - nullptr); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + if (!Environment::getInstance().isExperimentalBuild()) return; + + NDArray x('c', {2, 3}, {-1, -2, -3, -4, -5, -6}, sd::DataType::INT64); + NDArray exp('c', {2, 3}, {10, 10, 10, 10, 10, 10}, sd::DataType::FLOAT32); + NDArray scalar('c', {}, std::vector{10.f}, sd::DataType::FLOAT32); + NDArray z('c', {2, 3}, {100, 100, 100, 100, 100, 100}, sd::DataType::FLOAT32); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // call cuda kernel which calculates result + NativeOpExecutioner::execScalar( + &lc, sd::scalar::CopyPws, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo(), nullptr, scalar.shapeInfo(), scalar.specialBuffer(), + scalar.specialShapeInfo(), nullptr); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execScalar_3) { - - if (!Environment::getInstance().isExperimentalBuild()) - return; - - NDArray x('c', {2,3,2}, {0,1,2,3,4,5,6,7,8,9,10,11}, sd::DataType::INT64); - NDArray scalars('c',{2,2}, {1,2,3,4}, sd::DataType::FLOAT32); - NDArray exp('c', {2,3,2}, {0,0,2,1,4,2, 2,1,2,2,3,2}, sd::DataType::INT64); - NDArray z('c', {2,3,2}, {100,100,100,100,100,100,100,100,100,100,100,100}, sd::DataType::INT64); - - std::vector dimensions = {1}; - - // evaluate xTad data - shape::TAD xTad; - xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); - xTad.createTadOnlyShapeInfo(); - xTad.createOffsets(); - - // prepare input arrays for prepareDataForCuda function - std::vector> hostData; - hostData.emplace_back(dimensions.data(), dimensions.size() * sizeof(int)); // 0 -- dimensions - hostData.emplace_back(xTad.tadOnlyShapeInfo, shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo)); // 1 -- xTadShapeInfo - hostData.emplace_back(xTad.tadOffsets, xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets - std::vector devicePtrs(hostData.size(), nullptr); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - - // allocate required amount of global device memory and copy host data to it - - cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); ASSERT_EQ(0, cudaResult); - - // call cuda kernel which calculates result - NativeOpExecutioner::execScalar(&lc, sd::scalar::Divide, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - nullptr, scalars.shapeInfo(), scalars.specialBuffer(), scalars.specialShapeInfo(), - (int*)devicePtrs[0], dimensions.size(), - (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2], - nullptr, nullptr); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // free allocated global device memory - for(int i = 0; i < devicePtrs.size(); ++i) - cudaFree(devicePtrs[i]); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + if (!Environment::getInstance().isExperimentalBuild()) return; + + NDArray x('c', {2, 3, 2}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, + sd::DataType::INT64); + NDArray scalars('c', {2, 2}, {1, 2, 3, 4}, sd::DataType::FLOAT32); + NDArray exp('c', {2, 3, 2}, {0, 0, 2, 1, 4, 2, 2, 1, 2, 2, 3, 2}, + sd::DataType::INT64); + NDArray z('c', {2, 3, 2}, + {100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100}, + sd::DataType::INT64); + + std::vector dimensions = {1}; + + // evaluate xTad data + shape::TAD xTad; + xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); + xTad.createTadOnlyShapeInfo(); + xTad.createOffsets(); + + // prepare input arrays for prepareDataForCuda function + std::vector> hostData; + hostData.emplace_back(dimensions.data(), + dimensions.size() * sizeof(int)); // 0 -- dimensions + hostData.emplace_back( + xTad.tadOnlyShapeInfo, + shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo)); // 1 -- xTadShapeInfo + hostData.emplace_back(xTad.tadOffsets, + xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets + std::vector devicePtrs(hostData.size(), nullptr); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // allocate required amount of global device memory and copy host data to it + + cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); + ASSERT_EQ(0, cudaResult); + + // call cuda kernel which calculates result + NativeOpExecutioner::execScalar( + &lc, sd::scalar::Divide, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo(), nullptr, scalars.shapeInfo(), + scalars.specialBuffer(), scalars.specialShapeInfo(), (int *)devicePtrs[0], + dimensions.size(), (Nd4jLong *)devicePtrs[1], (Nd4jLong *)devicePtrs[2], + nullptr, nullptr); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // free allocated global device memory + for (int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execScalarBool_1) { - - NDArray x('c', {2,3}, {-1,-2,0,1,2,3}, sd::DataType::BFLOAT16); - NDArray scalar('c',{}, std::vector{0}, sd::DataType::BFLOAT16); - NDArray exp('c',{2,3}, {0,0,0,1,1,1}, sd::DataType::BOOL); - NDArray z('c', {2,3}, {100,100,100,100,100,100,}, sd::DataType::BOOL); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - - // call cuda kernel which calculates result - // call cuda kernel which calculates result - NativeOpExecutioner::execScalarBool(&lc, sd::scalar::GreaterThan, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - nullptr, scalar.shapeInfo(), scalar.specialBuffer(), scalar.specialShapeInfo(), - nullptr); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + NDArray x('c', {2, 3}, {-1, -2, 0, 1, 2, 3}, sd::DataType::BFLOAT16); + NDArray scalar('c', {}, std::vector{0}, sd::DataType::BFLOAT16); + NDArray exp('c', {2, 3}, {0, 0, 0, 1, 1, 1}, sd::DataType::BOOL); + NDArray z('c', {2, 3}, + { + 100, + 100, + 100, + 100, + 100, + 100, + }, + sd::DataType::BOOL); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // call cuda kernel which calculates result + // call cuda kernel which calculates result + NativeOpExecutioner::execScalarBool( + &lc, sd::scalar::GreaterThan, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo(), nullptr, scalar.shapeInfo(), scalar.specialBuffer(), + scalar.specialShapeInfo(), nullptr); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execScalarBool_2) { - - NDArray x('c', {2,3}, {0,1,2,3,4,5}, sd::DataType::FLOAT32); - NDArray scalars('c',{2}, {-1,4}, sd::DataType::FLOAT32); - NDArray exp('c', {2,3}, {1,1,1,0,0,1}, sd::DataType::BOOL); - NDArray z('c', {2,3}, {100,100,100,100,100,100}, sd::DataType::BOOL); - - std::vector dimensions = {1}; - - // evaluate xTad data - shape::TAD xTad; - xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); - xTad.createTadOnlyShapeInfo(); - xTad.createOffsets(); - - // prepare input arrays for prepareDataForCuda function - std::vector> hostData; - hostData.emplace_back(dimensions.data(), dimensions.size() * sizeof(int)); // 0 -- dimensions - hostData.emplace_back(xTad.tadOnlyShapeInfo, shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo)); // 1 -- xTadShapeInfo - hostData.emplace_back(xTad.tadOffsets, xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets - std::vector devicePtrs(hostData.size(), nullptr); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - - // allocate required amount of global device memory and copy host data to it - cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); ASSERT_EQ(0, cudaResult); - - // call cuda kernel which calculates result - NativeOpExecutioner::execScalarBool(&lc, sd::scalar::GreaterThan, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - nullptr, scalars.shapeInfo(), scalars.specialBuffer(), scalars.specialShapeInfo(), - (int*)devicePtrs[0], dimensions.size(), - (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2], - nullptr, nullptr); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // free allocated global device memory - for(int i = 0; i < devicePtrs.size(); ++i) - cudaFree(devicePtrs[i]); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + NDArray x('c', {2, 3}, {0, 1, 2, 3, 4, 5}, sd::DataType::FLOAT32); + NDArray scalars('c', {2}, {-1, 4}, sd::DataType::FLOAT32); + NDArray exp('c', {2, 3}, {1, 1, 1, 0, 0, 1}, sd::DataType::BOOL); + NDArray z('c', {2, 3}, {100, 100, 100, 100, 100, 100}, sd::DataType::BOOL); + + std::vector dimensions = {1}; + + // evaluate xTad data + shape::TAD xTad; + xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); + xTad.createTadOnlyShapeInfo(); + xTad.createOffsets(); + + // prepare input arrays for prepareDataForCuda function + std::vector> hostData; + hostData.emplace_back(dimensions.data(), + dimensions.size() * sizeof(int)); // 0 -- dimensions + hostData.emplace_back( + xTad.tadOnlyShapeInfo, + shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo)); // 1 -- xTadShapeInfo + hostData.emplace_back(xTad.tadOffsets, + xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets + std::vector devicePtrs(hostData.size(), nullptr); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // allocate required amount of global device memory and copy host data to it + cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); + ASSERT_EQ(0, cudaResult); + + // call cuda kernel which calculates result + NativeOpExecutioner::execScalarBool( + &lc, sd::scalar::GreaterThan, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo(), nullptr, scalars.shapeInfo(), + scalars.specialBuffer(), scalars.specialShapeInfo(), (int *)devicePtrs[0], + dimensions.size(), (Nd4jLong *)devicePtrs[1], (Nd4jLong *)devicePtrs[2], + nullptr, nullptr); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // free allocated global device memory + for (int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execBroadcast_1) { - - if (!Environment::getInstance().isExperimentalBuild()) - return; - - NDArray x('c', {2,3,4}, {100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100}, sd::DataType::INT32); - NDArray y('c', {3}, {10, 20, 30}, sd::DataType::INT64); - NDArray z('c', {2,3,4}, {100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100}, sd::DataType::INT32); - NDArray exp('c', {2,3,4}, {10, 11, 12, 13,24, 25, 26, 27,38, 39, 40, 41,22, 23, 24, 25,36, 37, 38, 39,50, 51, 52, 53}, sd::DataType::INT32); - x.linspace(0); x.syncToDevice(); - - std::vector dimensions = {1}; - - // evaluate xTad data - shape::TAD xTad; - xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); - xTad.createTadOnlyShapeInfo(); - xTad.createOffsets(); - - // prepare input arrays for prepareDataForCuda function - std::vector> hostData; - hostData.emplace_back(dimensions.data(), dimensions.size() * sizeof(int)); // 0 -- dimensions - hostData.emplace_back(xTad.tadOnlyShapeInfo, shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo)); // 1 -- xTadShapeInfo - hostData.emplace_back(xTad.tadOffsets, xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets - std::vector devicePtrs(hostData.size(), nullptr); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - - // allocate required amount of global device memory and copy host data to it - cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); ASSERT_EQ(0, cudaResult); - - // call cuda kernel which calculates result - NativeOpExecutioner::execBroadcast(&lc, sd::broadcast::Add, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - (int*)devicePtrs[0], dimensions.size(), - (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2], - nullptr, nullptr); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // free allocated global device memory - for(int i = 0; i < devicePtrs.size(); ++i) - cudaFree(devicePtrs[i]); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + if (!Environment::getInstance().isExperimentalBuild()) return; + + NDArray x('c', {2, 3, 4}, + {100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, + 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100}, + sd::DataType::INT32); + NDArray y('c', {3}, {10, 20, 30}, sd::DataType::INT64); + NDArray z('c', {2, 3, 4}, + {100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, + 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100}, + sd::DataType::INT32); + NDArray exp('c', {2, 3, 4}, {10, 11, 12, 13, 24, 25, 26, 27, 38, 39, 40, 41, + 22, 23, 24, 25, 36, 37, 38, 39, 50, 51, 52, 53}, + sd::DataType::INT32); + x.linspace(0); + x.syncToDevice(); + + std::vector dimensions = {1}; + + // evaluate xTad data + shape::TAD xTad; + xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); + xTad.createTadOnlyShapeInfo(); + xTad.createOffsets(); + + // prepare input arrays for prepareDataForCuda function + std::vector> hostData; + hostData.emplace_back(dimensions.data(), + dimensions.size() * sizeof(int)); // 0 -- dimensions + hostData.emplace_back( + xTad.tadOnlyShapeInfo, + shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo)); // 1 -- xTadShapeInfo + hostData.emplace_back(xTad.tadOffsets, + xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets + std::vector devicePtrs(hostData.size(), nullptr); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // allocate required amount of global device memory and copy host data to it + cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); + ASSERT_EQ(0, cudaResult); + + // call cuda kernel which calculates result + NativeOpExecutioner::execBroadcast( + &lc, sd::broadcast::Add, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, y.shapeInfo(), y.specialBuffer(), + y.specialShapeInfo(), nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo(), (int *)devicePtrs[0], dimensions.size(), + (Nd4jLong *)devicePtrs[1], (Nd4jLong *)devicePtrs[2], nullptr, nullptr); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // free allocated global device memory + for (int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execBroadcast_2) { - - if (!Environment::getInstance().isExperimentalBuild()) - return; - - NDArray x('c', {2,3,4}, {100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100}, sd::DataType::INT32); - NDArray y('c', {2,4}, {10,20,30,40,50,60,70,80}, sd::DataType::FLOAT32); - NDArray z('c', {2,3,4}, {100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100}, sd::DataType::FLOAT32); - NDArray exp('c', {2,3,4}, {10., 21., 32., 43., 14., 25., 36., 47., 18., 29., 40., 51., 62., 73., 84., 95., 66., 77., 88., 99., 70., 81., 92., 103}, sd::DataType::FLOAT32); - x.linspace(0); x.syncToDevice(); - - std::vector dimensions = {0,2}; - - // evaluate xTad data - shape::TAD xTad; - xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); - xTad.createTadOnlyShapeInfo(); - xTad.createOffsets(); - - // prepare input arrays for prepareDataForCuda function - std::vector> hostData; - hostData.emplace_back(dimensions.data(), dimensions.size() * sizeof(int)); // 0 -- dimensions - hostData.emplace_back(xTad.tadOnlyShapeInfo, shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo)); // 1 -- xTadShapeInfo - hostData.emplace_back(xTad.tadOffsets, xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets - std::vector devicePtrs(hostData.size(), nullptr); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - - // allocate required amount of global device memory and copy host data to it - cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); ASSERT_EQ(0, cudaResult); - - // call cuda kernel which calculates result - NativeOpExecutioner::execBroadcast(&lc, sd::broadcast::Add, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - (int*)devicePtrs[0], dimensions.size(), - (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2], - nullptr, nullptr); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // free allocated global device memory - for(int i = 0; i < devicePtrs.size(); ++i) - cudaFree(devicePtrs[i]); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + if (!Environment::getInstance().isExperimentalBuild()) return; + + NDArray x('c', {2, 3, 4}, + {100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, + 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100}, + sd::DataType::INT32); + NDArray y('c', {2, 4}, {10, 20, 30, 40, 50, 60, 70, 80}, + sd::DataType::FLOAT32); + NDArray z('c', {2, 3, 4}, + {100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, + 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100}, + sd::DataType::FLOAT32); + NDArray exp('c', {2, 3, 4}, + {10., 21., 32., 43., 14., 25., 36., 47., 18., 29., 40., 51., + 62., 73., 84., 95., 66., 77., 88., 99., 70., 81., 92., 103}, + sd::DataType::FLOAT32); + x.linspace(0); + x.syncToDevice(); + + std::vector dimensions = {0, 2}; + + // evaluate xTad data + shape::TAD xTad; + xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); + xTad.createTadOnlyShapeInfo(); + xTad.createOffsets(); + + // prepare input arrays for prepareDataForCuda function + std::vector> hostData; + hostData.emplace_back(dimensions.data(), + dimensions.size() * sizeof(int)); // 0 -- dimensions + hostData.emplace_back( + xTad.tadOnlyShapeInfo, + shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo)); // 1 -- xTadShapeInfo + hostData.emplace_back(xTad.tadOffsets, + xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets + std::vector devicePtrs(hostData.size(), nullptr); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // allocate required amount of global device memory and copy host data to it + cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); + ASSERT_EQ(0, cudaResult); + + // call cuda kernel which calculates result + NativeOpExecutioner::execBroadcast( + &lc, sd::broadcast::Add, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, y.shapeInfo(), y.specialBuffer(), + y.specialShapeInfo(), nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo(), (int *)devicePtrs[0], dimensions.size(), + (Nd4jLong *)devicePtrs[1], (Nd4jLong *)devicePtrs[2], nullptr, nullptr); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // free allocated global device memory + for (int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execBroadcastBool_1) { - - NDArray x('c', {2,3,4}, {100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100}, sd::DataType::INT32); - NDArray y('c', {3}, {2, 12, 22}, sd::DataType::INT32); - NDArray z('c', {2,3,4}, {100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,}, sd::DataType::BOOL); - NDArray exp('c', {2,3,4}, {0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0}, sd::DataType::BOOL); - x.linspace(1); x.syncToDevice(); - - std::vector dimensions = {1}; - - // evaluate xTad data - shape::TAD xTad; - xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); - xTad.createTadOnlyShapeInfo(); - xTad.createOffsets(); - - // prepare input arrays for prepareDataForCuda function - std::vector> hostData; - hostData.emplace_back(dimensions.data(), dimensions.size() * sizeof(int)); // 0 -- dimensions - hostData.emplace_back(xTad.tadOnlyShapeInfo, shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo)); // 1 -- xTadShapeInfo - hostData.emplace_back(xTad.tadOffsets, xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets - std::vector devicePtrs(hostData.size(), nullptr); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - - // allocate required amount of global device memory and copy host data to it - cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); ASSERT_EQ(0, cudaResult); - - // call cuda kernel which calculates result - NativeOpExecutioner::execBroadcastBool(&lc, sd::broadcast::EqualTo, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - nullptr, - (int*)devicePtrs[0], dimensions.size(), - (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2], - nullptr, nullptr); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // free allocated global device memory - for(int i = 0; i < devicePtrs.size(); ++i) - cudaFree(devicePtrs[i]); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + NDArray x('c', {2, 3, 4}, + {100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, + 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100}, + sd::DataType::INT32); + NDArray y('c', {3}, {2, 12, 22}, sd::DataType::INT32); + NDArray z('c', {2, 3, 4}, + { + 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, + 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, + }, + sd::DataType::BOOL); + NDArray exp('c', {2, 3, 4}, {0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0}, + sd::DataType::BOOL); + x.linspace(1); + x.syncToDevice(); + + std::vector dimensions = {1}; + + // evaluate xTad data + shape::TAD xTad; + xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); + xTad.createTadOnlyShapeInfo(); + xTad.createOffsets(); + + // prepare input arrays for prepareDataForCuda function + std::vector> hostData; + hostData.emplace_back(dimensions.data(), + dimensions.size() * sizeof(int)); // 0 -- dimensions + hostData.emplace_back( + xTad.tadOnlyShapeInfo, + shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo)); // 1 -- xTadShapeInfo + hostData.emplace_back(xTad.tadOffsets, + xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets + std::vector devicePtrs(hostData.size(), nullptr); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // allocate required amount of global device memory and copy host data to it + cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); + ASSERT_EQ(0, cudaResult); + + // call cuda kernel which calculates result + NativeOpExecutioner::execBroadcastBool( + &lc, sd::broadcast::EqualTo, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, y.shapeInfo(), y.specialBuffer(), + y.specialShapeInfo(), nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo(), nullptr, (int *)devicePtrs[0], dimensions.size(), + (Nd4jLong *)devicePtrs[1], (Nd4jLong *)devicePtrs[2], nullptr, nullptr); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // free allocated global device memory + for (int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execBroadcastBool_2) { - - NDArray x('c', {2,3,4}, {100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100},sd::DataType::FLOAT32); - NDArray y('c', {2,4}, {1,10,10,15,20,20,20,24}, sd::DataType::FLOAT32); - NDArray z('c', {2,3,4}, {100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100}, sd::DataType::BOOL); - NDArray exp('c', {2,3,4}, {1, 0, 0, 0,0, 0, 0, 0,0, 1, 0, 0,0, 0, 0, 0,0, 0, 0, 0,0, 0, 0, 1}, sd::DataType::BOOL); - x.linspace(1); x.syncToDevice(); - - std::vector dimensions = {0,2}; - - // evaluate xTad data - shape::TAD xTad; - xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); - xTad.createTadOnlyShapeInfo(); - xTad.createOffsets(); - - // prepare input arrays for prepareDataForCuda function - std::vector> hostData; - hostData.emplace_back(dimensions.data(), dimensions.size() * sizeof(int)); // 0 -- dimensions - hostData.emplace_back(xTad.tadOnlyShapeInfo, shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo)); // 1 -- xTadShapeInfo - hostData.emplace_back(xTad.tadOffsets, xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets - - std::vector devicePtrs(hostData.size(), nullptr); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - - // allocate required amount of global device memory and copy host data to it - cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); ASSERT_EQ(0, cudaResult); - - // call cuda kernel which calculates result - NativeOpExecutioner::execBroadcastBool(&lc, sd::broadcast::EqualTo, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - nullptr, - (int*)devicePtrs[0], dimensions.size(), - (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2], - nullptr, nullptr); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // free allocated global device memory - for(int i = 0; i < devicePtrs.size(); ++i) - cudaFree(devicePtrs[i]); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + NDArray x('c', {2, 3, 4}, + {100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, + 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100}, + sd::DataType::FLOAT32); + NDArray y('c', {2, 4}, {1, 10, 10, 15, 20, 20, 20, 24}, + sd::DataType::FLOAT32); + NDArray z('c', {2, 3, 4}, + {100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, + 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100}, + sd::DataType::BOOL); + NDArray exp('c', {2, 3, 4}, {1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}, + sd::DataType::BOOL); + x.linspace(1); + x.syncToDevice(); + + std::vector dimensions = {0, 2}; + + // evaluate xTad data + shape::TAD xTad; + xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); + xTad.createTadOnlyShapeInfo(); + xTad.createOffsets(); + + // prepare input arrays for prepareDataForCuda function + std::vector> hostData; + hostData.emplace_back(dimensions.data(), + dimensions.size() * sizeof(int)); // 0 -- dimensions + hostData.emplace_back( + xTad.tadOnlyShapeInfo, + shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo)); // 1 -- xTadShapeInfo + hostData.emplace_back(xTad.tadOffsets, + xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets + + std::vector devicePtrs(hostData.size(), nullptr); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // allocate required amount of global device memory and copy host data to it + cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); + ASSERT_EQ(0, cudaResult); + + // call cuda kernel which calculates result + NativeOpExecutioner::execBroadcastBool( + &lc, sd::broadcast::EqualTo, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, y.shapeInfo(), y.specialBuffer(), + y.specialShapeInfo(), nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo(), nullptr, (int *)devicePtrs[0], dimensions.size(), + (Nd4jLong *)devicePtrs[1], (Nd4jLong *)devicePtrs[2], nullptr, nullptr); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // free allocated global device memory + for (int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execPairwiseTransform_1) { - - if (!Environment::getInstance().isExperimentalBuild()) - return; - - NDArray x('c', {2,2,2}, {1,5,3,7,2,6,4,8}, sd::DataType::INT32); - NDArray y('c', {4,2}, {0.1,0.2,0.3,0.4,1.5,0.6,0.7,1.8}, sd::DataType::DOUBLE); - NDArray z('c', {8}, {100,100,100,100,100,100,100,100}, sd::DataType::INT32); - NDArray exp('c', {8}, {0,1,2,3,3,5,6,6}, sd::DataType::INT32); - x.permutei({2,1,0}); // -> {1,2,3,4,5,6,7,8} - x.syncShape(); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - - // call cuda kernel which calculates result - NativeOpExecutioner::execPairwiseTransform(&lc, sd::pairwise::Subtract, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - nullptr); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + if (!Environment::getInstance().isExperimentalBuild()) return; + + NDArray x('c', {2, 2, 2}, {1, 5, 3, 7, 2, 6, 4, 8}, sd::DataType::INT32); + NDArray y('c', {4, 2}, {0.1, 0.2, 0.3, 0.4, 1.5, 0.6, 0.7, 1.8}, + sd::DataType::DOUBLE); + NDArray z('c', {8}, {100, 100, 100, 100, 100, 100, 100, 100}, + sd::DataType::INT32); + NDArray exp('c', {8}, {0, 1, 2, 3, 3, 5, 6, 6}, sd::DataType::INT32); + x.permutei({2, 1, 0}); // -> {1,2,3,4,5,6,7,8} + x.syncShape(); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // call cuda kernel which calculates result + NativeOpExecutioner::execPairwiseTransform( + &lc, sd::pairwise::Subtract, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, y.shapeInfo(), y.specialBuffer(), + y.specialShapeInfo(), nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo(), nullptr); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execPairwiseBoolTransform_1) { - - NDArray x('c', {2,2,2}, {1,5,3,7,2,6,4,8}, sd::DataType::INT64); - NDArray y('c', {4,2}, {0,2,0,4,0,6,0,8}, sd::DataType::INT64); - NDArray z('c', {8}, {100,100,100,100,100,100,100,100}, sd::DataType::BOOL); - NDArray exp('c', {8}, {0,1,0,1,0,1,0,1}, sd::DataType::BOOL); - x.permutei({2,1,0}); // -> {1,2,3,4,5,6,7,8} - x.syncShape(); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - - // call cuda kernel which calculates result - NativeOpExecutioner::execPairwiseBoolTransform(&lc, sd::pairwise::EqualTo, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - nullptr); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + NDArray x('c', {2, 2, 2}, {1, 5, 3, 7, 2, 6, 4, 8}, sd::DataType::INT64); + NDArray y('c', {4, 2}, {0, 2, 0, 4, 0, 6, 0, 8}, sd::DataType::INT64); + NDArray z('c', {8}, {100, 100, 100, 100, 100, 100, 100, 100}, + sd::DataType::BOOL); + NDArray exp('c', {8}, {0, 1, 0, 1, 0, 1, 0, 1}, sd::DataType::BOOL); + x.permutei({2, 1, 0}); // -> {1,2,3,4,5,6,7,8} + x.syncShape(); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // call cuda kernel which calculates result + NativeOpExecutioner::execPairwiseBoolTransform( + &lc, sd::pairwise::EqualTo, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, y.shapeInfo(), y.specialBuffer(), + y.specialShapeInfo(), nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo(), nullptr); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } - //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execTransformFloat_1) { - - NDArray x('c', {2,2}, {0, 6.25, 2.25, 12.25}, sd::DataType::DOUBLE); - NDArray z('c', {4}, {100,100,100,100}, sd::DataType::FLOAT32); - NDArray exp('c', {4}, {0, 1.5, 2.5, 3.5}, sd::DataType::FLOAT32); - x.permutei({1,0}); - x.syncShape(); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - - // call cuda kernel which calculates result - NativeOpExecutioner::execTransformFloat(&lc, sd::transform::Sqrt, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - nullptr, nullptr, nullptr); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + NDArray x('c', {2, 2}, {0, 6.25, 2.25, 12.25}, sd::DataType::DOUBLE); + NDArray z('c', {4}, {100, 100, 100, 100}, sd::DataType::FLOAT32); + NDArray exp('c', {4}, {0, 1.5, 2.5, 3.5}, sd::DataType::FLOAT32); + x.permutei({1, 0}); + x.syncShape(); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // call cuda kernel which calculates result + NativeOpExecutioner::execTransformFloat( + &lc, sd::transform::Sqrt, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo(), nullptr, nullptr, nullptr); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execTransformFloat_2) { - - NDArray x('c', {1,4}, {0, 4, 9, 16}, sd::DataType::INT64); - NDArray z('c', {2,2}, {100,100,100,100}, sd::DataType::DOUBLE); - NDArray exp('c', {2,2}, {0, 2, 3, 4}, sd::DataType::DOUBLE); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - - // call cuda kernel which calculates result - NativeOpExecutioner::execTransformFloat(&lc, sd::transform::Sqrt, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - nullptr, nullptr, nullptr); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + NDArray x('c', {1, 4}, {0, 4, 9, 16}, sd::DataType::INT64); + NDArray z('c', {2, 2}, {100, 100, 100, 100}, sd::DataType::DOUBLE); + NDArray exp('c', {2, 2}, {0, 2, 3, 4}, sd::DataType::DOUBLE); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // call cuda kernel which calculates result + NativeOpExecutioner::execTransformFloat( + &lc, sd::transform::Sqrt, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo(), nullptr, nullptr, nullptr); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execTransformAny_1) { - - NDArray x('c', {2,2}, {0, 6.25, 2.25, 12.25}, sd::DataType::DOUBLE); - NDArray z('c', {4,1}, {100,100,100,100}, sd::DataType::INT32); - NDArray exp('c', {4,1}, {0, 2, 6, 12}, sd::DataType::INT32); - x.permutei({1,0}); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - - // call cuda kernel which calculates result - NativeOpExecutioner::execTransformAny(&lc, sd::transform::Assign, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - nullptr, nullptr, nullptr); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + NDArray x('c', {2, 2}, {0, 6.25, 2.25, 12.25}, sd::DataType::DOUBLE); + NDArray z('c', {4, 1}, {100, 100, 100, 100}, sd::DataType::INT32); + NDArray exp('c', {4, 1}, {0, 2, 6, 12}, sd::DataType::INT32); + x.permutei({1, 0}); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // call cuda kernel which calculates result + NativeOpExecutioner::execTransformAny( + &lc, sd::transform::Assign, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo(), nullptr, nullptr, nullptr); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execTransformAny_2) { - - NDArray x('c', {1,4}, {0, 6.25, 2.25, 12.25}, sd::DataType::BFLOAT16); - NDArray z('c', {2,2}, {100,100,100,100}, sd::DataType::FLOAT32); - NDArray exp('c', {2,2}, {0, 6.25, 2.25, 12.25}, sd::DataType::FLOAT32); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - - // call cuda kernel which calculates result - NativeOpExecutioner::execTransformAny(&lc, sd::transform::Assign, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - nullptr, nullptr, nullptr); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + NDArray x('c', {1, 4}, {0, 6.25, 2.25, 12.25}, sd::DataType::BFLOAT16); + NDArray z('c', {2, 2}, {100, 100, 100, 100}, sd::DataType::FLOAT32); + NDArray exp('c', {2, 2}, {0, 6.25, 2.25, 12.25}, sd::DataType::FLOAT32); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // call cuda kernel which calculates result + NativeOpExecutioner::execTransformAny( + &lc, sd::transform::Assign, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo(), nullptr, nullptr, nullptr); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execTransformStrict_1) { - - NDArray x('c', {2,3}, {0,2,4,1,3,5}, sd::DataType::DOUBLE); - NDArray z('c', {3,2}, {100,100,100,100,100,100}, sd::DataType::DOUBLE); - NDArray exp('c', {3,2}, {0, 3, 12, 27, 48, 75}, sd::DataType::DOUBLE); - x.permutei({1,0}); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - - // call cuda kernel which calculates result - NativeOpExecutioner::execTransformStrict(&lc, sd::transform::CubeDerivative, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - nullptr, nullptr, nullptr); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + NDArray x('c', {2, 3}, {0, 2, 4, 1, 3, 5}, sd::DataType::DOUBLE); + NDArray z('c', {3, 2}, {100, 100, 100, 100, 100, 100}, sd::DataType::DOUBLE); + NDArray exp('c', {3, 2}, {0, 3, 12, 27, 48, 75}, sd::DataType::DOUBLE); + x.permutei({1, 0}); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // call cuda kernel which calculates result + NativeOpExecutioner::execTransformStrict( + &lc, sd::transform::CubeDerivative, nullptr, x.shapeInfo(), + x.specialBuffer(), x.specialShapeInfo(), nullptr, z.shapeInfo(), + z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr, nullptr); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execTransformStrict_2) { - - NDArray x('c', {6}, {0,1,2,3,4,5}, sd::DataType::FLOAT32); - NDArray z('c', {3,2}, {100,100,100,100,100,100}, sd::DataType::FLOAT32); - NDArray exp('c', {3,2}, {0, 3, 12, 27, 48, 75}, sd::DataType::FLOAT32); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - - // call cuda kernel which calculates result - NativeOpExecutioner::execTransformStrict(&lc, sd::transform::CubeDerivative, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - nullptr, nullptr, nullptr); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + NDArray x('c', {6}, {0, 1, 2, 3, 4, 5}, sd::DataType::FLOAT32); + NDArray z('c', {3, 2}, {100, 100, 100, 100, 100, 100}, sd::DataType::FLOAT32); + NDArray exp('c', {3, 2}, {0, 3, 12, 27, 48, 75}, sd::DataType::FLOAT32); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // call cuda kernel which calculates result + NativeOpExecutioner::execTransformStrict( + &lc, sd::transform::CubeDerivative, nullptr, x.shapeInfo(), + x.specialBuffer(), x.specialShapeInfo(), nullptr, z.shapeInfo(), + z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr, nullptr); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execTransformSame_1) { - - NDArray x('c', {2,3}, {0,2.5,4.5,1.5,3.5,5.5}, sd::DataType::DOUBLE); - NDArray z('c', {1,6}, {100,100,100,100,100,100}, sd::DataType::DOUBLE); - NDArray exp('c', {1,6}, {0,2.25,6.25,12.25,20.25,30.25}, sd::DataType::DOUBLE); - x.permutei({1,0}); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - - // call cuda kernel which calculates result - NativeOpExecutioner::execTransformSame(&lc, sd::transform::Square, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - nullptr, nullptr, nullptr); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + NDArray x('c', {2, 3}, {0, 2.5, 4.5, 1.5, 3.5, 5.5}, sd::DataType::DOUBLE); + NDArray z('c', {1, 6}, {100, 100, 100, 100, 100, 100}, sd::DataType::DOUBLE); + NDArray exp('c', {1, 6}, {0, 2.25, 6.25, 12.25, 20.25, 30.25}, + sd::DataType::DOUBLE); + x.permutei({1, 0}); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // call cuda kernel which calculates result + NativeOpExecutioner::execTransformSame( + &lc, sd::transform::Square, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo(), nullptr, nullptr, nullptr); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execTransformSame_2) { - - NDArray x('c', {6}, {0,1,2,3,4,5}, sd::DataType::INT32); - NDArray z('c', {3,2}, {100,100,100,100,100,100}, sd::DataType::INT32); - NDArray exp('c', {3,2}, {0,1,4,9,16,25}, sd::DataType::INT32); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - - // call cuda kernel which calculates result - NativeOpExecutioner::execTransformSame(&lc, sd::transform::Square, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - nullptr, nullptr, nullptr); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + NDArray x('c', {6}, {0, 1, 2, 3, 4, 5}, sd::DataType::INT32); + NDArray z('c', {3, 2}, {100, 100, 100, 100, 100, 100}, sd::DataType::INT32); + NDArray exp('c', {3, 2}, {0, 1, 4, 9, 16, 25}, sd::DataType::INT32); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // call cuda kernel which calculates result + NativeOpExecutioner::execTransformSame( + &lc, sd::transform::Square, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo(), nullptr, nullptr, nullptr); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execTransformBool_1) { - - NDArray x('c', {2,3}, {0,2,4,-1,-3,-5}, sd::DataType::DOUBLE); - NDArray z('c', {1,6}, {100,100,100,100,100,100}, sd::DataType::BOOL); - NDArray exp('c', {1,6}, {0,0,1,0,1,0}, sd::DataType::BOOL); - x.permutei({1,0}); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - - // call cuda kernel which calculates result - NativeOpExecutioner::execTransformBool(&lc, sd::transform::IsPositive, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - nullptr, nullptr, nullptr); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + NDArray x('c', {2, 3}, {0, 2, 4, -1, -3, -5}, sd::DataType::DOUBLE); + NDArray z('c', {1, 6}, {100, 100, 100, 100, 100, 100}, sd::DataType::BOOL); + NDArray exp('c', {1, 6}, {0, 0, 1, 0, 1, 0}, sd::DataType::BOOL); + x.permutei({1, 0}); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // call cuda kernel which calculates result + NativeOpExecutioner::execTransformBool( + &lc, sd::transform::IsPositive, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo(), nullptr, nullptr, nullptr); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execTransformBool_2) { - - NDArray x('c', {6}, {0,-1,2,-3,4,-5}, sd::DataType::INT32); - NDArray z('c', {3,2}, {100,100,100,100,100,100}, sd::DataType::BOOL); - NDArray exp('c', {3,2}, {0,0,1,0,1,0}, sd::DataType::BOOL); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - - // call cuda kernel which calculates result - NativeOpExecutioner::execTransformBool(&lc, sd::transform::IsPositive, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - nullptr, nullptr, nullptr); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + NDArray x('c', {6}, {0, -1, 2, -3, 4, -5}, sd::DataType::INT32); + NDArray z('c', {3, 2}, {100, 100, 100, 100, 100, 100}, sd::DataType::BOOL); + NDArray exp('c', {3, 2}, {0, 0, 1, 0, 1, 0}, sd::DataType::BOOL); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // call cuda kernel which calculates result + NativeOpExecutioner::execTransformBool( + &lc, sd::transform::IsPositive, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo(), nullptr, nullptr, nullptr); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execReduceFloat_1) { - - NDArray x('c', {2,3,4}, {-5,-4,-3,-2,-1,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18}, sd::DataType::INT32); - NDArray z('c', {3}, {100,100,100}, sd::DataType::FLOAT32); - NDArray exp('c', {3}, {2.5, 6.5, 10.5}, sd::DataType::FLOAT32); - x.permutei({2,1,0}); - - std::vector dimensions = {0,2}; - - // evaluate xTad data - shape::TAD xTad; - xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); - xTad.createTadOnlyShapeInfo(); - xTad.createOffsets(); - - // prepare input arrays for prepareDataForCuda function - std::vector> hostData; - hostData.emplace_back(dimensions.data(), dimensions.size() * sizeof(int)); // 0 -- dimensions - hostData.emplace_back(xTad.tadOnlyShapeInfo, shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo));// 1 -- xTadShapeInfo - hostData.emplace_back(xTad.tadOffsets, xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets - std::vector devicePtrs(hostData.size(), nullptr); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - - // allocate required amount of global device memory and copy host data to it - cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); ASSERT_EQ(0, cudaResult); - - // call cuda kernel which calculates result - NativeOpExecutioner::execReduceFloat(&lc, sd::reduce::Mean, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - (int*)devicePtrs[0], dimensions.size(), - (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2]); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // free allocated global device memory - for(int i = 0; i < devicePtrs.size(); ++i) - cudaFree(devicePtrs[i]); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + NDArray x('c', {2, 3, 4}, {-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, + 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18}, + sd::DataType::INT32); + NDArray z('c', {3}, {100, 100, 100}, sd::DataType::FLOAT32); + NDArray exp('c', {3}, {2.5, 6.5, 10.5}, sd::DataType::FLOAT32); + x.permutei({2, 1, 0}); + + std::vector dimensions = {0, 2}; + + // evaluate xTad data + shape::TAD xTad; + xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); + xTad.createTadOnlyShapeInfo(); + xTad.createOffsets(); + + // prepare input arrays for prepareDataForCuda function + std::vector> hostData; + hostData.emplace_back(dimensions.data(), + dimensions.size() * sizeof(int)); // 0 -- dimensions + hostData.emplace_back( + xTad.tadOnlyShapeInfo, + shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo)); // 1 -- xTadShapeInfo + hostData.emplace_back(xTad.tadOffsets, + xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets + std::vector devicePtrs(hostData.size(), nullptr); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // allocate required amount of global device memory and copy host data to it + cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); + ASSERT_EQ(0, cudaResult); + + // call cuda kernel which calculates result + NativeOpExecutioner::execReduceFloat( + &lc, sd::reduce::Mean, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo(), (int *)devicePtrs[0], dimensions.size(), + (Nd4jLong *)devicePtrs[1], (Nd4jLong *)devicePtrs[2]); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // free allocated global device memory + for (int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execReduceFloat_2) { - - NDArray x('c', {2,3,4}, {-5,-4,-3,-2,-1,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18}, sd::DataType::INT32); - NDArray z('c', {2,4}, {100,100,100,100,100,100,100,100}, sd::DataType::DOUBLE); - NDArray exp('c', {2,4}, {-1., 0., 1., 2.,11., 12., 13., 14.}, sd::DataType::DOUBLE); - - std::vector dimensions = {1}; - - // evaluate xTad data - shape::TAD xTad; - xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); - xTad.createTadOnlyShapeInfo(); - xTad.createOffsets(); - - // prepare input arrays for prepareDataForCuda function - std::vector> hostData; - hostData.emplace_back(dimensions.data(), dimensions.size() * sizeof(int)); // 0 -- dimensions - hostData.emplace_back(xTad.tadOnlyShapeInfo, shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo));// 1 -- xTadShapeInfo - hostData.emplace_back(xTad.tadOffsets, xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets - std::vector devicePtrs(hostData.size(), nullptr); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - - // allocate required amount of global device memory and copy host data to it - cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); ASSERT_EQ(0, cudaResult); - - // call cuda kernel which calculates result - NativeOpExecutioner::execReduceFloat(&lc, sd::reduce::Mean, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - (int*)devicePtrs[0], dimensions.size(), - (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2]); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // free allocated global device memory - for(int i = 0; i < devicePtrs.size(); ++i) - cudaFree(devicePtrs[i]); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + NDArray x('c', {2, 3, 4}, {-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, + 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18}, + sd::DataType::INT32); + NDArray z('c', {2, 4}, {100, 100, 100, 100, 100, 100, 100, 100}, + sd::DataType::DOUBLE); + NDArray exp('c', {2, 4}, {-1., 0., 1., 2., 11., 12., 13., 14.}, + sd::DataType::DOUBLE); + + std::vector dimensions = {1}; + + // evaluate xTad data + shape::TAD xTad; + xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); + xTad.createTadOnlyShapeInfo(); + xTad.createOffsets(); + + // prepare input arrays for prepareDataForCuda function + std::vector> hostData; + hostData.emplace_back(dimensions.data(), + dimensions.size() * sizeof(int)); // 0 -- dimensions + hostData.emplace_back( + xTad.tadOnlyShapeInfo, + shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo)); // 1 -- xTadShapeInfo + hostData.emplace_back(xTad.tadOffsets, + xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets + std::vector devicePtrs(hostData.size(), nullptr); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // allocate required amount of global device memory and copy host data to it + cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); + ASSERT_EQ(0, cudaResult); + + // call cuda kernel which calculates result + NativeOpExecutioner::execReduceFloat( + &lc, sd::reduce::Mean, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo(), (int *)devicePtrs[0], dimensions.size(), + (Nd4jLong *)devicePtrs[1], (Nd4jLong *)devicePtrs[2]); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // free allocated global device memory + for (int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execReduceSame_1) { - - NDArray x('c', {2,3,4}, {-5,-4,-3,-2,-1,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18}, sd::DataType::INT32); - NDArray z('c', {3}, {100,100,100}, sd::DataType::INT32); - NDArray exp('c', {3}, {20, 52, 84}, sd::DataType::INT32); - x.permutei({2,1,0}); - - std::vector dimensions = {0,2}; - - // evaluate xTad data - shape::TAD xTad; - xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); - xTad.createTadOnlyShapeInfo(); - xTad.createOffsets(); - - // prepare input arrays for prepareDataForCuda function - std::vector> hostData; - hostData.emplace_back(dimensions.data(), dimensions.size() * sizeof(int)); // 0 -- dimensions - hostData.emplace_back(xTad.tadOnlyShapeInfo, shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo));// 1 -- xTadShapeInfo - hostData.emplace_back(xTad.tadOffsets, xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets - std::vector devicePtrs(hostData.size(), nullptr); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - - // allocate required amount of global device memory and copy host data to it - cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); ASSERT_EQ(0, cudaResult); - - // call cuda kernel which calculates result - NativeOpExecutioner::execReduceSame(&lc, sd::reduce::Sum, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - (int*)devicePtrs[0], dimensions.size(), - (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2]); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // free allocated global device memory - for(int i = 0; i < devicePtrs.size(); ++i) - cudaFree(devicePtrs[i]); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + NDArray x('c', {2, 3, 4}, {-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, + 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18}, + sd::DataType::INT32); + NDArray z('c', {3}, {100, 100, 100}, sd::DataType::INT32); + NDArray exp('c', {3}, {20, 52, 84}, sd::DataType::INT32); + x.permutei({2, 1, 0}); + + std::vector dimensions = {0, 2}; + + // evaluate xTad data + shape::TAD xTad; + xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); + xTad.createTadOnlyShapeInfo(); + xTad.createOffsets(); + + // prepare input arrays for prepareDataForCuda function + std::vector> hostData; + hostData.emplace_back(dimensions.data(), + dimensions.size() * sizeof(int)); // 0 -- dimensions + hostData.emplace_back( + xTad.tadOnlyShapeInfo, + shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo)); // 1 -- xTadShapeInfo + hostData.emplace_back(xTad.tadOffsets, + xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets + std::vector devicePtrs(hostData.size(), nullptr); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // allocate required amount of global device memory and copy host data to it + cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); + ASSERT_EQ(0, cudaResult); + + // call cuda kernel which calculates result + NativeOpExecutioner::execReduceSame( + &lc, sd::reduce::Sum, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo(), (int *)devicePtrs[0], dimensions.size(), + (Nd4jLong *)devicePtrs[1], (Nd4jLong *)devicePtrs[2]); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // free allocated global device memory + for (int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execReduceSame_2) { - - NDArray x('c', {2,3,4}, {-5,-4,-3,-2,-1,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18}, sd::DataType::FLOAT32); - NDArray z('c', {2,4}, {100,100,100,100,100,100,100,100}, sd::DataType::FLOAT32); - NDArray exp('c', {2,4}, {-3., 0., 3., 6.,33., 36., 39., 42.}, sd::DataType::FLOAT32); - - std::vector dimensions = {1}; - - // evaluate xTad data - shape::TAD xTad; - xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); - xTad.createTadOnlyShapeInfo(); - xTad.createOffsets(); - - // prepare input arrays for prepareDataForCuda function - std::vector> hostData; - hostData.emplace_back(dimensions.data(), dimensions.size() * sizeof(int)); // 0 -- dimensions - hostData.emplace_back(xTad.tadOnlyShapeInfo, shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo));// 1 -- xTadShapeInfo - hostData.emplace_back(xTad.tadOffsets, xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets - std::vector devicePtrs(hostData.size(), nullptr); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - - // allocate required amount of global device memory and copy host data to it - cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); ASSERT_EQ(0, cudaResult); - - // call cuda kernel which calculates result - NativeOpExecutioner::execReduceSame(&lc, sd::reduce::Sum, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - (int*)devicePtrs[0], dimensions.size(), - (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2]); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // free allocated global device memory - for(int i = 0; i < devicePtrs.size(); ++i) - cudaFree(devicePtrs[i]); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + NDArray x('c', {2, 3, 4}, {-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, + 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18}, + sd::DataType::FLOAT32); + NDArray z('c', {2, 4}, {100, 100, 100, 100, 100, 100, 100, 100}, + sd::DataType::FLOAT32); + NDArray exp('c', {2, 4}, {-3., 0., 3., 6., 33., 36., 39., 42.}, + sd::DataType::FLOAT32); + + std::vector dimensions = {1}; + + // evaluate xTad data + shape::TAD xTad; + xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); + xTad.createTadOnlyShapeInfo(); + xTad.createOffsets(); + + // prepare input arrays for prepareDataForCuda function + std::vector> hostData; + hostData.emplace_back(dimensions.data(), + dimensions.size() * sizeof(int)); // 0 -- dimensions + hostData.emplace_back( + xTad.tadOnlyShapeInfo, + shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo)); // 1 -- xTadShapeInfo + hostData.emplace_back(xTad.tadOffsets, + xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets + std::vector devicePtrs(hostData.size(), nullptr); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // allocate required amount of global device memory and copy host data to it + cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); + ASSERT_EQ(0, cudaResult); + + // call cuda kernel which calculates result + NativeOpExecutioner::execReduceSame( + &lc, sd::reduce::Sum, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo(), (int *)devicePtrs[0], dimensions.size(), + (Nd4jLong *)devicePtrs[1], (Nd4jLong *)devicePtrs[2]); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // free allocated global device memory + for (int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execReduceBool_1) { - - NDArray x('c', {2,3,4}, {-5,-4,-3,-2,-1,0,1,2,3,4,5,6,-7,-8,-9,-10,-11,-12,-13,-14,-15,-16,-17,-18}, sd::DataType::INT32); - NDArray z('c', {3}, {100,100,100}, sd::DataType::BOOL); - NDArray exp('c', {3}, {0, 1, 1}, sd::DataType::BOOL); - x.permutei({2,1,0}); - - - std::vector dimensions = {0,2}; - - // evaluate xTad data - shape::TAD xTad; - xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); - xTad.createTadOnlyShapeInfo(); - xTad.createOffsets(); - - // prepare input arrays for prepareDataForCuda function - std::vector> hostData; - hostData.emplace_back(dimensions.data(), dimensions.size() * sizeof(int)); // 0 -- dimensions - hostData.emplace_back(xTad.tadOnlyShapeInfo, shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo));// 1 -- xTadShapeInfo - hostData.emplace_back(xTad.tadOffsets, xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets - std::vector devicePtrs(hostData.size(), nullptr); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - - // allocate required amount of global device memory and copy host data to it - cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); ASSERT_EQ(0, cudaResult); - - // call cuda kernel which calculates result - NativeOpExecutioner::execReduceBool(&lc, sd::reduce::IsPositive, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - (int*)devicePtrs[0], dimensions.size(), - (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2]); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // free allocated global device memory - for(int i = 0; i < devicePtrs.size(); ++i) - cudaFree(devicePtrs[i]); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + NDArray x('c', {2, 3, 4}, + {-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, + -7, -8, -9, -10, -11, -12, -13, -14, -15, -16, -17, -18}, + sd::DataType::INT32); + NDArray z('c', {3}, {100, 100, 100}, sd::DataType::BOOL); + NDArray exp('c', {3}, {0, 1, 1}, sd::DataType::BOOL); + x.permutei({2, 1, 0}); + + std::vector dimensions = {0, 2}; + + // evaluate xTad data + shape::TAD xTad; + xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); + xTad.createTadOnlyShapeInfo(); + xTad.createOffsets(); + + // prepare input arrays for prepareDataForCuda function + std::vector> hostData; + hostData.emplace_back(dimensions.data(), + dimensions.size() * sizeof(int)); // 0 -- dimensions + hostData.emplace_back( + xTad.tadOnlyShapeInfo, + shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo)); // 1 -- xTadShapeInfo + hostData.emplace_back(xTad.tadOffsets, + xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets + std::vector devicePtrs(hostData.size(), nullptr); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // allocate required amount of global device memory and copy host data to it + cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); + ASSERT_EQ(0, cudaResult); + + // call cuda kernel which calculates result + NativeOpExecutioner::execReduceBool( + &lc, sd::reduce::IsPositive, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo(), (int *)devicePtrs[0], dimensions.size(), + (Nd4jLong *)devicePtrs[1], (Nd4jLong *)devicePtrs[2]); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // free allocated global device memory + for (int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execReduceBool_2) { - - NDArray x('c', {2,3,4}, {-5,-4,-3,-2,-1,0,1,2,3,4,5,6,-7,-8,-9,-10,-11,-12,-13,-14,-15,-16,-17,-18}, sd::DataType::FLOAT32); - NDArray z('c', {2,4}, {100,100,100,100,100,100,100,100}, sd::DataType::BOOL); - NDArray exp('c', {2,4}, {1, 1, 1, 1, 0, 0, 0, 0}, sd::DataType::BOOL); - - std::vector dimensions = {1}; - - // evaluate xTad data - shape::TAD xTad; - xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); - xTad.createTadOnlyShapeInfo(); - xTad.createOffsets(); - - // prepare input arrays for prepareDataForCuda function - std::vector> hostData; - hostData.emplace_back(dimensions.data(), dimensions.size() * sizeof(int)); // 0 -- dimensions - hostData.emplace_back(xTad.tadOnlyShapeInfo, shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo));// 1 -- xTadShapeInfo - hostData.emplace_back(xTad.tadOffsets, xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets - std::vector devicePtrs(hostData.size(), nullptr); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - - // allocate required amount of global device memory and copy host data to it - cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); ASSERT_EQ(0, cudaResult); - - // call cuda kernel which calculates result - NativeOpExecutioner::execReduceBool(&lc, sd::reduce::IsPositive, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - (int*)devicePtrs[0], dimensions.size(), - (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2]); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // free allocated global device memory - for(int i = 0; i < devicePtrs.size(); ++i) - cudaFree(devicePtrs[i]); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + NDArray x('c', {2, 3, 4}, + {-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, + -7, -8, -9, -10, -11, -12, -13, -14, -15, -16, -17, -18}, + sd::DataType::FLOAT32); + NDArray z('c', {2, 4}, {100, 100, 100, 100, 100, 100, 100, 100}, + sd::DataType::BOOL); + NDArray exp('c', {2, 4}, {1, 1, 1, 1, 0, 0, 0, 0}, sd::DataType::BOOL); + + std::vector dimensions = {1}; + + // evaluate xTad data + shape::TAD xTad; + xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); + xTad.createTadOnlyShapeInfo(); + xTad.createOffsets(); + + // prepare input arrays for prepareDataForCuda function + std::vector> hostData; + hostData.emplace_back(dimensions.data(), + dimensions.size() * sizeof(int)); // 0 -- dimensions + hostData.emplace_back( + xTad.tadOnlyShapeInfo, + shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo)); // 1 -- xTadShapeInfo + hostData.emplace_back(xTad.tadOffsets, + xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets + std::vector devicePtrs(hostData.size(), nullptr); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // allocate required amount of global device memory and copy host data to it + cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); + ASSERT_EQ(0, cudaResult); + + // call cuda kernel which calculates result + NativeOpExecutioner::execReduceBool( + &lc, sd::reduce::IsPositive, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo(), (int *)devicePtrs[0], dimensions.size(), + (Nd4jLong *)devicePtrs[1], (Nd4jLong *)devicePtrs[2]); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // free allocated global device memory + for (int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execReduceLong_1) { - - NDArray x('c', {2,3,4}, {-5,0,-3,0,-1,0,1,2,3,4,5,6,7,0,9,10,11,0,13,14,0,16,0,18}, sd::DataType::INT32); - NDArray z('c', {3}, {100,100,100}, sd::DataType::INT64); - NDArray exp('c', {3}, {5,6,6}, sd::DataType::INT64); - x.permutei({2,1,0}); - - std::vector dimensions = {0,2}; - - // evaluate xTad data - shape::TAD xTad; - xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); - xTad.createTadOnlyShapeInfo(); - xTad.createOffsets(); - - // prepare input arrays for prepareDataForCuda function - std::vector> hostData; - hostData.emplace_back(dimensions.data(), dimensions.size() * sizeof(int)); // 0 -- dimensions - hostData.emplace_back(xTad.tadOnlyShapeInfo, shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo));// 1 -- xTadShapeInfo - hostData.emplace_back(xTad.tadOffsets, xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets - std::vector devicePtrs(hostData.size(), nullptr); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - - // allocate required amount of global device memory and copy host data to it - cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); ASSERT_EQ(0, cudaResult); - - // call cuda kernel which calculates result - NativeOpExecutioner::execReduceLong(&lc, sd::reduce::CountNonZero, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - (int*)devicePtrs[0], dimensions.size(), - (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2]); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // free allocated global device memory - for(int i = 0; i < devicePtrs.size(); ++i) - cudaFree(devicePtrs[i]); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + NDArray x('c', {2, 3, 4}, {-5, 0, -3, 0, -1, 0, 1, 2, 3, 4, 5, 6, + 7, 0, 9, 10, 11, 0, 13, 14, 0, 16, 0, 18}, + sd::DataType::INT32); + NDArray z('c', {3}, {100, 100, 100}, sd::DataType::INT64); + NDArray exp('c', {3}, {5, 6, 6}, sd::DataType::INT64); + x.permutei({2, 1, 0}); + + std::vector dimensions = {0, 2}; + + // evaluate xTad data + shape::TAD xTad; + xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); + xTad.createTadOnlyShapeInfo(); + xTad.createOffsets(); + + // prepare input arrays for prepareDataForCuda function + std::vector> hostData; + hostData.emplace_back(dimensions.data(), + dimensions.size() * sizeof(int)); // 0 -- dimensions + hostData.emplace_back( + xTad.tadOnlyShapeInfo, + shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo)); // 1 -- xTadShapeInfo + hostData.emplace_back(xTad.tadOffsets, + xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets + std::vector devicePtrs(hostData.size(), nullptr); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // allocate required amount of global device memory and copy host data to it + cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); + ASSERT_EQ(0, cudaResult); + + // call cuda kernel which calculates result + NativeOpExecutioner::execReduceLong( + &lc, sd::reduce::CountNonZero, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo(), (int *)devicePtrs[0], dimensions.size(), + (Nd4jLong *)devicePtrs[1], (Nd4jLong *)devicePtrs[2]); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // free allocated global device memory + for (int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execReduceLong_2) { - - NDArray x('c', {2,3,4}, {-5,0,-3,0,-1,0,1,2,3,4,5,6,7,0,9,10,11,0,13,14,0,16,0,18}, sd::DataType::FLOAT32); - NDArray z('c', {2,4}, {100,100,100,100,100,100,100,100}, sd::DataType::INT64); - NDArray exp('c', {2,4}, {3, 1, 3, 2, 2, 1, 2, 3}, sd::DataType::INT64); - - std::vector dimensions = {1}; - - // evaluate xTad data - shape::TAD xTad; - xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); - xTad.createTadOnlyShapeInfo(); - xTad.createOffsets(); - - // prepare input arrays for prepareDataForCuda function - std::vector> hostData; - hostData.emplace_back(dimensions.data(), dimensions.size() * sizeof(int)); // 0 -- dimensions - hostData.emplace_back(xTad.tadOnlyShapeInfo, shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo));// 1 -- xTadShapeInfo - hostData.emplace_back(xTad.tadOffsets, xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets - std::vector devicePtrs(hostData.size(), nullptr); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - - // allocate required amount of global device memory and copy host data to it - cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); ASSERT_EQ(0, cudaResult); - - // call cuda kernel which calculates result - NativeOpExecutioner::execReduceLong(&lc, sd::reduce::CountNonZero, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - (int*)devicePtrs[0], dimensions.size(), - (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2]); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // free allocated global device memory - for(int i = 0; i < devicePtrs.size(); ++i) - cudaFree(devicePtrs[i]); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + NDArray x('c', {2, 3, 4}, {-5, 0, -3, 0, -1, 0, 1, 2, 3, 4, 5, 6, + 7, 0, 9, 10, 11, 0, 13, 14, 0, 16, 0, 18}, + sd::DataType::FLOAT32); + NDArray z('c', {2, 4}, {100, 100, 100, 100, 100, 100, 100, 100}, + sd::DataType::INT64); + NDArray exp('c', {2, 4}, {3, 1, 3, 2, 2, 1, 2, 3}, sd::DataType::INT64); + + std::vector dimensions = {1}; + + // evaluate xTad data + shape::TAD xTad; + xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); + xTad.createTadOnlyShapeInfo(); + xTad.createOffsets(); + + // prepare input arrays for prepareDataForCuda function + std::vector> hostData; + hostData.emplace_back(dimensions.data(), + dimensions.size() * sizeof(int)); // 0 -- dimensions + hostData.emplace_back( + xTad.tadOnlyShapeInfo, + shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo)); // 1 -- xTadShapeInfo + hostData.emplace_back(xTad.tadOffsets, + xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets + std::vector devicePtrs(hostData.size(), nullptr); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // allocate required amount of global device memory and copy host data to it + cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); + ASSERT_EQ(0, cudaResult); + + // call cuda kernel which calculates result + NativeOpExecutioner::execReduceLong( + &lc, sd::reduce::CountNonZero, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo(), (int *)devicePtrs[0], dimensions.size(), + (Nd4jLong *)devicePtrs[1], (Nd4jLong *)devicePtrs[2]); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // free allocated global device memory + for (int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execReduceFloatScalar_1) { - - NDArray x('c', {2,3,4}, {-5,-4,-3,-2,-1,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18}, sd::DataType::INT32); - NDArray z('c', {}, std::vector{100}, sd::DataType::FLOAT32); - NDArray exp('c', {}, std::vector{6.5}, sd::DataType::FLOAT32); - x.permutei({2,1,0}); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - void* reductionPointer; - cudaResult = cudaMalloc(reinterpret_cast(&reductionPointer), 1024*1024); ASSERT_EQ(0, cudaResult); - int* allocationPointer; - cudaResult = cudaMalloc(reinterpret_cast(&allocationPointer), 1024*1024); ASSERT_EQ(0, cudaResult); - lc.setReductionPointer(reductionPointer); - lc.setAllocationPointer(allocationPointer); - - // call cuda kernel which calculates result - NativeOpExecutioner::execReduceFloatScalar(&lc, sd::reduce::Mean, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo()); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + NDArray x('c', {2, 3, 4}, {-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, + 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18}, + sd::DataType::INT32); + NDArray z('c', {}, std::vector{100}, sd::DataType::FLOAT32); + NDArray exp('c', {}, std::vector{6.5}, sd::DataType::FLOAT32); + x.permutei({2, 1, 0}); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + void *reductionPointer; + cudaResult = + cudaMalloc(reinterpret_cast(&reductionPointer), 1024 * 1024); + ASSERT_EQ(0, cudaResult); + int *allocationPointer; + cudaResult = + cudaMalloc(reinterpret_cast(&allocationPointer), 1024 * 1024); + ASSERT_EQ(0, cudaResult); + lc.setReductionPointer(reductionPointer); + lc.setAllocationPointer(allocationPointer); + + // call cuda kernel which calculates result + NativeOpExecutioner::execReduceFloatScalar( + &lc, sd::reduce::Mean, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo()); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execReduceFloatScalar_2) { - - NDArray x('c', {2,3,4}, {-5,-4,-3,-2,-1,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18}, sd::DataType::INT32); - NDArray z('c', {}, std::vector{100}, sd::DataType::DOUBLE); - NDArray exp('c', {}, std::vector{6.5}, sd::DataType::DOUBLE); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - void* reductionPointer; - cudaResult = cudaMalloc(reinterpret_cast(&reductionPointer), 1024*1024); ASSERT_EQ(0, cudaResult); - int* allocationPointer; - cudaResult = cudaMalloc(reinterpret_cast(&allocationPointer), 1024*1024); ASSERT_EQ(0, cudaResult); - lc.setReductionPointer(reductionPointer); - lc.setAllocationPointer(allocationPointer); - - // call cuda kernel which calculates result - NativeOpExecutioner::execReduceFloatScalar(&lc, sd::reduce::Mean, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo()); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + NDArray x('c', {2, 3, 4}, {-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, + 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18}, + sd::DataType::INT32); + NDArray z('c', {}, std::vector{100}, sd::DataType::DOUBLE); + NDArray exp('c', {}, std::vector{6.5}, sd::DataType::DOUBLE); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + void *reductionPointer; + cudaResult = + cudaMalloc(reinterpret_cast(&reductionPointer), 1024 * 1024); + ASSERT_EQ(0, cudaResult); + int *allocationPointer; + cudaResult = + cudaMalloc(reinterpret_cast(&allocationPointer), 1024 * 1024); + ASSERT_EQ(0, cudaResult); + lc.setReductionPointer(reductionPointer); + lc.setAllocationPointer(allocationPointer); + + // call cuda kernel which calculates result + NativeOpExecutioner::execReduceFloatScalar( + &lc, sd::reduce::Mean, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo()); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execReduceSameScalar_1) { - - NDArray x('c', {2,3,4}, {-5,-4,-3,-2,-1,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18}, sd::DataType::INT32); - NDArray z('c', {}, std::vector{100}, sd::DataType::INT32); - NDArray exp('c', {}, std::vector{156}, sd::DataType::INT32); - x.permutei({2,1,0}); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - void* reductionPointer; - cudaResult = cudaMalloc(reinterpret_cast(&reductionPointer), 1024*1024); ASSERT_EQ(0, cudaResult); - int* allocationPointer; - cudaResult = cudaMalloc(reinterpret_cast(&allocationPointer), 1024*1024); ASSERT_EQ(0, cudaResult); - lc.setReductionPointer(reductionPointer); - lc.setAllocationPointer(allocationPointer); - - // call cuda kernel which calculates result - NativeOpExecutioner::execReduceSameScalar(&lc, sd::reduce::Sum, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo()); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + NDArray x('c', {2, 3, 4}, {-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, + 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18}, + sd::DataType::INT32); + NDArray z('c', {}, std::vector{100}, sd::DataType::INT32); + NDArray exp('c', {}, std::vector{156}, sd::DataType::INT32); + x.permutei({2, 1, 0}); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + void *reductionPointer; + cudaResult = + cudaMalloc(reinterpret_cast(&reductionPointer), 1024 * 1024); + ASSERT_EQ(0, cudaResult); + int *allocationPointer; + cudaResult = + cudaMalloc(reinterpret_cast(&allocationPointer), 1024 * 1024); + ASSERT_EQ(0, cudaResult); + lc.setReductionPointer(reductionPointer); + lc.setAllocationPointer(allocationPointer); + + // call cuda kernel which calculates result + NativeOpExecutioner::execReduceSameScalar( + &lc, sd::reduce::Sum, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo()); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execReduceSameScalar_2) { - - NDArray x('c', {2,3,4}, {-5,-4,-3,-2,-1,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18}, sd::DataType::DOUBLE); - NDArray z('c', {}, std::vector{100}, sd::DataType::DOUBLE); - NDArray exp('c', {}, std::vector{156}, sd::DataType::DOUBLE); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - void* reductionPointer; - cudaResult = cudaMalloc(reinterpret_cast(&reductionPointer), 1024*1024); ASSERT_EQ(0, cudaResult); - int* allocationPointer; - cudaResult = cudaMalloc(reinterpret_cast(&allocationPointer), 1024*1024); ASSERT_EQ(0, cudaResult); - lc.setReductionPointer(reductionPointer); - lc.setAllocationPointer(allocationPointer); - - // call cuda kernel which calculates result - NativeOpExecutioner::execReduceSameScalar(&lc, sd::reduce::Sum, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo()); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + NDArray x('c', {2, 3, 4}, {-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, + 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18}, + sd::DataType::DOUBLE); + NDArray z('c', {}, std::vector{100}, sd::DataType::DOUBLE); + NDArray exp('c', {}, std::vector{156}, sd::DataType::DOUBLE); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + void *reductionPointer; + cudaResult = + cudaMalloc(reinterpret_cast(&reductionPointer), 1024 * 1024); + ASSERT_EQ(0, cudaResult); + int *allocationPointer; + cudaResult = + cudaMalloc(reinterpret_cast(&allocationPointer), 1024 * 1024); + ASSERT_EQ(0, cudaResult); + lc.setReductionPointer(reductionPointer); + lc.setAllocationPointer(allocationPointer); + + // call cuda kernel which calculates result + NativeOpExecutioner::execReduceSameScalar( + &lc, sd::reduce::Sum, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo()); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execReduceBoolScalar_1) { - - NDArray x('c', {2,3,4}, {-5,-4,-3,-2,-1,0,1,2,3,4,5,6,-7,-8,-9,-10,-11,-12,-13,-14,-15,-16,-17,-18}, sd::DataType::INT32); - NDArray z('c', {}, std::vector{100}, sd::DataType::BOOL); - NDArray exp('c', {}, std::vector{1}, sd::DataType::BOOL); - x.permutei({2,1,0}); - x.syncShape(); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - void* reductionPointer; - cudaResult = cudaMalloc(reinterpret_cast(&reductionPointer), 1024*1024); ASSERT_EQ(0, cudaResult); - int* allocationPointer; - cudaResult = cudaMalloc(reinterpret_cast(&allocationPointer), 1024*1024); ASSERT_EQ(0, cudaResult); - lc.setReductionPointer(reductionPointer); - lc.setAllocationPointer(allocationPointer); - - // call cuda kernel which calculates result - NativeOpExecutioner::execReduceBoolScalar(&lc, sd::reduce::IsPositive, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo()); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + NDArray x('c', {2, 3, 4}, + {-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, + -7, -8, -9, -10, -11, -12, -13, -14, -15, -16, -17, -18}, + sd::DataType::INT32); + NDArray z('c', {}, std::vector{100}, sd::DataType::BOOL); + NDArray exp('c', {}, std::vector{1}, sd::DataType::BOOL); + x.permutei({2, 1, 0}); + x.syncShape(); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + void *reductionPointer; + cudaResult = + cudaMalloc(reinterpret_cast(&reductionPointer), 1024 * 1024); + ASSERT_EQ(0, cudaResult); + int *allocationPointer; + cudaResult = + cudaMalloc(reinterpret_cast(&allocationPointer), 1024 * 1024); + ASSERT_EQ(0, cudaResult); + lc.setReductionPointer(reductionPointer); + lc.setAllocationPointer(allocationPointer); + + // call cuda kernel which calculates result + NativeOpExecutioner::execReduceBoolScalar( + &lc, sd::reduce::IsPositive, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo()); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execReduceBoolScalar_2) { - - NDArray x('c', {2,3,4}, {-5,-4,-3,-2,-1,0,1,2,3,4,5,6,-7,-8,-9,-10,-11,-12,-13,-14,-15,-16,-17,-18}, sd::DataType::DOUBLE); - NDArray z('c', {}, std::vector{100}, sd::DataType::BOOL); - NDArray exp('c', {}, std::vector{1}, sd::DataType::BOOL); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - void* reductionPointer; - cudaResult = cudaMalloc(reinterpret_cast(&reductionPointer), 1024*1024); ASSERT_EQ(0, cudaResult); - int* allocationPointer; - cudaResult = cudaMalloc(reinterpret_cast(&allocationPointer), 1024*1024); ASSERT_EQ(0, cudaResult); - lc.setReductionPointer(reductionPointer); - lc.setAllocationPointer(allocationPointer); - - // call cuda kernel which calculates result - NativeOpExecutioner::execReduceBoolScalar(&lc, sd::reduce::IsPositive, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo()); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + NDArray x('c', {2, 3, 4}, + {-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, + -7, -8, -9, -10, -11, -12, -13, -14, -15, -16, -17, -18}, + sd::DataType::DOUBLE); + NDArray z('c', {}, std::vector{100}, sd::DataType::BOOL); + NDArray exp('c', {}, std::vector{1}, sd::DataType::BOOL); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + void *reductionPointer; + cudaResult = + cudaMalloc(reinterpret_cast(&reductionPointer), 1024 * 1024); + ASSERT_EQ(0, cudaResult); + int *allocationPointer; + cudaResult = + cudaMalloc(reinterpret_cast(&allocationPointer), 1024 * 1024); + ASSERT_EQ(0, cudaResult); + lc.setReductionPointer(reductionPointer); + lc.setAllocationPointer(allocationPointer); + + // call cuda kernel which calculates result + NativeOpExecutioner::execReduceBoolScalar( + &lc, sd::reduce::IsPositive, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo()); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execReduceLongScalar_1) { - - NDArray x('c', {2,3,4}, {-5,0,-3,0,-1,0,1,2,3,4,5,6,7,0,9,10,11,0,13,14,0,16,0,18}, sd::DataType::INT32); - NDArray z('c', {}, std::vector{100}, sd::DataType::INT64); - NDArray exp('c', {}, std::vector{17}, sd::DataType::INT64); - x.permutei({2,1,0}); - x.syncShape(); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - void* reductionPointer; - cudaResult = cudaMalloc(reinterpret_cast(&reductionPointer), 1024*1024); ASSERT_EQ(0, cudaResult); - int* allocationPointer; - cudaResult = cudaMalloc(reinterpret_cast(&allocationPointer), 1024*1024); ASSERT_EQ(0, cudaResult); - lc.setReductionPointer(reductionPointer); - lc.setAllocationPointer(allocationPointer); - - // call cuda kernel which calculates result - NativeOpExecutioner::execReduceLongScalar(&lc, sd::reduce::CountNonZero, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo()); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + NDArray x('c', {2, 3, 4}, {-5, 0, -3, 0, -1, 0, 1, 2, 3, 4, 5, 6, + 7, 0, 9, 10, 11, 0, 13, 14, 0, 16, 0, 18}, + sd::DataType::INT32); + NDArray z('c', {}, std::vector{100}, sd::DataType::INT64); + NDArray exp('c', {}, std::vector{17}, sd::DataType::INT64); + x.permutei({2, 1, 0}); + x.syncShape(); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + void *reductionPointer; + cudaResult = + cudaMalloc(reinterpret_cast(&reductionPointer), 1024 * 1024); + ASSERT_EQ(0, cudaResult); + int *allocationPointer; + cudaResult = + cudaMalloc(reinterpret_cast(&allocationPointer), 1024 * 1024); + ASSERT_EQ(0, cudaResult); + lc.setReductionPointer(reductionPointer); + lc.setAllocationPointer(allocationPointer); + + // call cuda kernel which calculates result + NativeOpExecutioner::execReduceLongScalar( + &lc, sd::reduce::CountNonZero, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo()); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execReduceLongScalar_2) { - - NDArray x('c', {2,3,4}, {-5,0,-3,0,-1,0,1,2,3,4,5,6,7,0,9,10,11,0,13,14,0,16,0,18}, sd::DataType::DOUBLE); - NDArray z('c', {}, std::vector{100}, sd::DataType::INT64); - NDArray exp('c', {}, std::vector{17}, sd::DataType::INT64); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - void* reductionPointer; - cudaResult = cudaMalloc(reinterpret_cast(&reductionPointer), 1024*1024); ASSERT_EQ(0, cudaResult); - int* allocationPointer; - cudaResult = cudaMalloc(reinterpret_cast(&allocationPointer), 1024*1024); ASSERT_EQ(0, cudaResult); - lc.setReductionPointer(reductionPointer); - lc.setAllocationPointer(allocationPointer); - - // call cuda kernel which calculates result - NativeOpExecutioner::execReduceLongScalar(&lc, sd::reduce::CountNonZero, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo()); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + NDArray x('c', {2, 3, 4}, {-5, 0, -3, 0, -1, 0, 1, 2, 3, 4, 5, 6, + 7, 0, 9, 10, 11, 0, 13, 14, 0, 16, 0, 18}, + sd::DataType::DOUBLE); + NDArray z('c', {}, std::vector{100}, sd::DataType::INT64); + NDArray exp('c', {}, std::vector{17}, sd::DataType::INT64); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + void *reductionPointer; + cudaResult = + cudaMalloc(reinterpret_cast(&reductionPointer), 1024 * 1024); + ASSERT_EQ(0, cudaResult); + int *allocationPointer; + cudaResult = + cudaMalloc(reinterpret_cast(&allocationPointer), 1024 * 1024); + ASSERT_EQ(0, cudaResult); + lc.setReductionPointer(reductionPointer); + lc.setAllocationPointer(allocationPointer); + + // call cuda kernel which calculates result + NativeOpExecutioner::execReduceLongScalar( + &lc, sd::reduce::CountNonZero, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo()); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execReduce3TAD_1) { - - NDArray x('c', {2,2,3}, {-5,-4,-3,-2,-1,0,1,2,3,4,5,6}, sd::DataType::FLOAT32); - NDArray y('c', {2,2}, {1,2,3,4}, sd::DataType::FLOAT32); - NDArray exp('c', {3}, {10,20,30}, sd::DataType::DOUBLE); - NDArray z('c', {3}, {100,100,100}, sd::DataType::DOUBLE); - - std::vector dimensions = {0,1}; - auto packX = ConstantTadHelper::getInstance().tadForDimensions(x.shapeInfo(), dimensions); - LaunchContext* context = x.getContext(); - - x.syncToDevice(); - y.syncToDevice(); - PointersManager pm(context, "execReduce3TAD_1"); - // call cuda kernel which calculates result - NativeOpExecutioner::execReduce3TAD(context, sd::reduce3::Dot, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, - nullptr, y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - nullptr, dimensions.size(), - packX.specialShapeInfo(), packX.specialOffsets(), nullptr, nullptr); - pm.synchronize(); -// cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); -// z.printIndexedBuffer("OutputReduce3TAD"); - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - + NDArray x('c', {2, 2, 3}, {-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6}, + sd::DataType::FLOAT32); + NDArray y('c', {2, 2}, {1, 2, 3, 4}, sd::DataType::FLOAT32); + NDArray exp('c', {3}, {10, 20, 30}, sd::DataType::DOUBLE); + NDArray z('c', {3}, {100, 100, 100}, sd::DataType::DOUBLE); + + std::vector dimensions = {0, 1}; + auto packX = ConstantTadHelper::getInstance().tadForDimensions(x.shapeInfo(), + dimensions); + LaunchContext *context = x.getContext(); + + x.syncToDevice(); + y.syncToDevice(); + PointersManager pm(context, "execReduce3TAD_1"); + // call cuda kernel which calculates result + NativeOpExecutioner::execReduce3TAD( + context, sd::reduce3::Dot, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, nullptr, y.shapeInfo(), y.specialBuffer(), + y.specialShapeInfo(), nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo(), nullptr, dimensions.size(), + packX.specialShapeInfo(), packX.specialOffsets(), nullptr, nullptr); + pm.synchronize(); + // cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + // z.printIndexedBuffer("OutputReduce3TAD"); + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); } //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execReduce3TAD_2) { - - NDArray x('c', {2,2,3}, {-5,-4,-3,-2,-1,0,1,2,3,4,5,6}, sd::DataType::INT64); - NDArray y('c', {2,3}, {1,2,3,4,5,6}, sd::DataType::INT64); - NDArray exp('c', {2}, {10,73}, sd::DataType::FLOAT32); - NDArray z('c', {2}, {100,100}, sd::DataType::FLOAT32); - - std::vector dimensions = {0,2}; - - // evaluate xTad data - shape::TAD xTad; - xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); - xTad.createTadOnlyShapeInfo(); - xTad.createOffsets(); - - // prepare input arrays for prepareDataForCuda function - std::vector> hostData; - hostData.emplace_back(dimensions.data(), dimensions.size() * sizeof(int)); // 0 -- dimensions - hostData.emplace_back(xTad.tadOnlyShapeInfo, shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo));// 1 -- xTadShapeInfo - hostData.emplace_back(xTad.tadOffsets, xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets - std::vector devicePtrs(hostData.size(), nullptr); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - - // allocate required amount of global device memory and copy host data to it - - cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); ASSERT_EQ(0, cudaResult); - - // call cuda kernel which calculates result - NativeOpExecutioner::execReduce3TAD(&lc, sd::reduce3::Dot, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, - nullptr, y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - (int*)devicePtrs[0], dimensions.size(), - (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2], nullptr, nullptr); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // free allocated global device memory - for(int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + NDArray x('c', {2, 2, 3}, {-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6}, + sd::DataType::INT64); + NDArray y('c', {2, 3}, {1, 2, 3, 4, 5, 6}, sd::DataType::INT64); + NDArray exp('c', {2}, {10, 73}, sd::DataType::FLOAT32); + NDArray z('c', {2}, {100, 100}, sd::DataType::FLOAT32); + + std::vector dimensions = {0, 2}; + + // evaluate xTad data + shape::TAD xTad; + xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); + xTad.createTadOnlyShapeInfo(); + xTad.createOffsets(); + + // prepare input arrays for prepareDataForCuda function + std::vector> hostData; + hostData.emplace_back(dimensions.data(), + dimensions.size() * sizeof(int)); // 0 -- dimensions + hostData.emplace_back( + xTad.tadOnlyShapeInfo, + shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo)); // 1 -- xTadShapeInfo + hostData.emplace_back(xTad.tadOffsets, + xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets + std::vector devicePtrs(hostData.size(), nullptr); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // allocate required amount of global device memory and copy host data to it + + cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); + ASSERT_EQ(0, cudaResult); + + // call cuda kernel which calculates result + NativeOpExecutioner::execReduce3TAD( + &lc, sd::reduce3::Dot, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, nullptr, y.shapeInfo(), y.specialBuffer(), + y.specialShapeInfo(), nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo(), (int *)devicePtrs[0], dimensions.size(), + (Nd4jLong *)devicePtrs[1], (Nd4jLong *)devicePtrs[2], nullptr, nullptr); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // free allocated global device memory + for (int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execReduce3TAD_3) { - - NDArray x('c', {2,2,3}, {-5,-4,-3,-2,-1,0,1,2,3,4,5,6}, sd::DataType::INT64); - NDArray y('c', {3}, {1,2,3}, sd::DataType::INT64); - NDArray exp('c', {2,2}, {-22,-4,14,32}, sd::DataType::FLOAT32); - NDArray z('c', {2,2}, {100,100,100,100}, sd::DataType::FLOAT32); - - std::vector dimensions = {2}; - - // evaluate xTad data - shape::TAD xTad; - xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); - xTad.createTadOnlyShapeInfo(); - xTad.createOffsets(); - - // prepare input arrays for prepareDataForCuda function - std::vector> hostData; - hostData.emplace_back(dimensions.data(), dimensions.size() * sizeof(int)); // 0 -- dimensions - hostData.emplace_back(xTad.tadOnlyShapeInfo, shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo));// 1 -- xTadShapeInfo - hostData.emplace_back(xTad.tadOffsets, xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets - std::vector devicePtrs(hostData.size(), nullptr); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - - // allocate required amount of global device memory and copy host data to it - - cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); ASSERT_EQ(0, cudaResult); - - // call cuda kernel which calculates result - NativeOpExecutioner::execReduce3TAD(&lc, sd::reduce3::Dot, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, - nullptr, y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - (int*)devicePtrs[0], dimensions.size(), - (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2], (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2]); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // free allocated global device memory - for(int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + NDArray x('c', {2, 2, 3}, {-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6}, + sd::DataType::INT64); + NDArray y('c', {3}, {1, 2, 3}, sd::DataType::INT64); + NDArray exp('c', {2, 2}, {-22, -4, 14, 32}, sd::DataType::FLOAT32); + NDArray z('c', {2, 2}, {100, 100, 100, 100}, sd::DataType::FLOAT32); + + std::vector dimensions = {2}; + + // evaluate xTad data + shape::TAD xTad; + xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); + xTad.createTadOnlyShapeInfo(); + xTad.createOffsets(); + + // prepare input arrays for prepareDataForCuda function + std::vector> hostData; + hostData.emplace_back(dimensions.data(), + dimensions.size() * sizeof(int)); // 0 -- dimensions + hostData.emplace_back( + xTad.tadOnlyShapeInfo, + shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo)); // 1 -- xTadShapeInfo + hostData.emplace_back(xTad.tadOffsets, + xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets + std::vector devicePtrs(hostData.size(), nullptr); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // allocate required amount of global device memory and copy host data to it + + cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); + ASSERT_EQ(0, cudaResult); + + // call cuda kernel which calculates result + NativeOpExecutioner::execReduce3TAD( + &lc, sd::reduce3::Dot, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, nullptr, y.shapeInfo(), y.specialBuffer(), + y.specialShapeInfo(), nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo(), (int *)devicePtrs[0], dimensions.size(), + (Nd4jLong *)devicePtrs[1], (Nd4jLong *)devicePtrs[2], + (Nd4jLong *)devicePtrs[1], (Nd4jLong *)devicePtrs[2]); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // free allocated global device memory + for (int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execReduce3TAD_4) { - - NDArray x('c', {2,2,3}, {-5,-4,-3,-2,-1,0,1,2,3,4,5,6}, sd::DataType::DOUBLE); - NDArray y('c', {2,2,3}, {10,20,30,40,50,60,70,80,90,100,110,120}, sd::DataType::DOUBLE); - NDArray exp('c', {}, std::vector{1820}, sd::DataType::FLOAT32); - NDArray z('c', {}, std::vector{100}, sd::DataType::FLOAT32); - - std::vector dimensions = {0,1,2}; - - // evaluate xTad data - shape::TAD xTad; - xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); - xTad.createTadOnlyShapeInfo(); - xTad.createOffsets(); - - // prepare input arrays for prepareDataForCuda function - std::vector> hostData; - hostData.emplace_back(dimensions.data(), dimensions.size() * sizeof(int)); // 0 -- dimensions - hostData.emplace_back(xTad.tadOnlyShapeInfo, shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo));// 1 -- xTadShapeInfo - hostData.emplace_back(xTad.tadOffsets, xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets - std::vector devicePtrs(hostData.size(), nullptr); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - - // allocate required amount of global device memory and copy host data to it - cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); ASSERT_EQ(0, cudaResult); - - // call cuda kernel which calculates result - NativeOpExecutioner::execReduce3TAD(&lc, sd::reduce3::Dot, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, - nullptr, y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - (int*)devicePtrs[0], dimensions.size(), - (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2], (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2]); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // free allocated global device memory - for(int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + NDArray x('c', {2, 2, 3}, {-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6}, + sd::DataType::DOUBLE); + NDArray y('c', {2, 2, 3}, {10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120}, + sd::DataType::DOUBLE); + NDArray exp('c', {}, std::vector{1820}, sd::DataType::FLOAT32); + NDArray z('c', {}, std::vector{100}, sd::DataType::FLOAT32); + + std::vector dimensions = {0, 1, 2}; + + // evaluate xTad data + shape::TAD xTad; + xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); + xTad.createTadOnlyShapeInfo(); + xTad.createOffsets(); + + // prepare input arrays for prepareDataForCuda function + std::vector> hostData; + hostData.emplace_back(dimensions.data(), + dimensions.size() * sizeof(int)); // 0 -- dimensions + hostData.emplace_back( + xTad.tadOnlyShapeInfo, + shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo)); // 1 -- xTadShapeInfo + hostData.emplace_back(xTad.tadOffsets, + xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets + std::vector devicePtrs(hostData.size(), nullptr); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // allocate required amount of global device memory and copy host data to it + cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); + ASSERT_EQ(0, cudaResult); + + // call cuda kernel which calculates result + NativeOpExecutioner::execReduce3TAD( + &lc, sd::reduce3::Dot, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, nullptr, y.shapeInfo(), y.specialBuffer(), + y.specialShapeInfo(), nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo(), (int *)devicePtrs[0], dimensions.size(), + (Nd4jLong *)devicePtrs[1], (Nd4jLong *)devicePtrs[2], + (Nd4jLong *)devicePtrs[1], (Nd4jLong *)devicePtrs[2]); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // free allocated global device memory + for (int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execSummaryStats_1) { // FIXME: Yurii, this test should be fixed if (1 > 0) - return; - - NDArray x('c', {2,2,3}, {-5,-4,-3,-2,-1,0,1,2,3,4,5,6}, sd::DataType::INT64); - NDArray exp('c', {}, std::vector{3.605551}, sd::DataType::FLOAT32); - NDArray z('c', {}, std::vector{100}, sd::DataType::FLOAT32); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - void* reductionPointer; - cudaResult = cudaMalloc(reinterpret_cast(&reductionPointer), 1024*1024); ASSERT_EQ(0, cudaResult); - lc.setReductionPointer(reductionPointer); - - // call cuda kernel which calculates result - NativeOpExecutioner::execSummaryStats(&lc, sd::variance::SummaryStatsStandardDeviation, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - true); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + return;NDArray x('c', {2, 2, 3}, {-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6}, + sd::DataType::INT64); + NDArray exp('c', {}, std::vector{3.605551}, sd::DataType::FLOAT32); + NDArray z('c', {}, std::vector{100}, sd::DataType::FLOAT32); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + void *reductionPointer; + cudaResult = + cudaMalloc(reinterpret_cast(&reductionPointer), 1024 * 1024); + ASSERT_EQ(0, cudaResult); + lc.setReductionPointer(reductionPointer); + + // call cuda kernel which calculates result + NativeOpExecutioner::execSummaryStats( + &lc, sd::variance::SummaryStatsStandardDeviation, nullptr, x.shapeInfo(), + x.specialBuffer(), x.specialShapeInfo(), nullptr, nullptr, z.shapeInfo(), + z.specialBuffer(), z.specialShapeInfo(), true); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execSummaryStats_2) { - - NDArray x('c', {2,2,3}, {-5,-4,-3,-20,-1,0,1,2,3,4,5,6}, sd::DataType::DOUBLE); - NDArray exp('c', {2}, {3.405877, 9.715966}, sd::DataType::FLOAT32); - NDArray z('c', {2}, {100,100}, sd::DataType::FLOAT32); - - std::vector dimensions = {0,2}; - - // evaluate xTad data - shape::TAD xTad; - xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); - xTad.createTadOnlyShapeInfo(); - xTad.createOffsets(); - - // prepare input arrays for prepareDataForCuda function - std::vector> hostData; - hostData.emplace_back(dimensions.data(), dimensions.size() * sizeof(int)); // 0 -- dimensions - hostData.emplace_back(xTad.tadOnlyShapeInfo, shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo));// 1 -- xTadShapeInfo - hostData.emplace_back(xTad.tadOffsets, xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets - std::vector devicePtrs(hostData.size(), nullptr); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - - // allocate required amount of global device memory and copy host data to it - cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); ASSERT_EQ(0, cudaResult); - - // call cuda kernel which calculates result - NativeOpExecutioner::execSummaryStats(&lc, sd::variance::SummaryStatsStandardDeviation, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - (int*)devicePtrs[0], dimensions.size(), - (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2], - true); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // free allocated global device memory - for(int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + NDArray x('c', {2, 2, 3}, {-5, -4, -3, -20, -1, 0, 1, 2, 3, 4, 5, 6}, + sd::DataType::DOUBLE); + NDArray exp('c', {2}, {3.405877, 9.715966}, sd::DataType::FLOAT32); + NDArray z('c', {2}, {100, 100}, sd::DataType::FLOAT32); + + std::vector dimensions = {0, 2}; + + // evaluate xTad data + shape::TAD xTad; + xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); + xTad.createTadOnlyShapeInfo(); + xTad.createOffsets(); + + // prepare input arrays for prepareDataForCuda function + std::vector> hostData; + hostData.emplace_back(dimensions.data(), + dimensions.size() * sizeof(int)); // 0 -- dimensions + hostData.emplace_back( + xTad.tadOnlyShapeInfo, + shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo)); // 1 -- xTadShapeInfo + hostData.emplace_back(xTad.tadOffsets, + xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets + std::vector devicePtrs(hostData.size(), nullptr); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // allocate required amount of global device memory and copy host data to it + cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); + ASSERT_EQ(0, cudaResult); + + // call cuda kernel which calculates result + NativeOpExecutioner::execSummaryStats( + &lc, sd::variance::SummaryStatsStandardDeviation, nullptr, x.shapeInfo(), + x.specialBuffer(), x.specialShapeInfo(), nullptr, nullptr, z.shapeInfo(), + z.specialBuffer(), z.specialShapeInfo(), (int *)devicePtrs[0], + dimensions.size(), (Nd4jLong *)devicePtrs[1], (Nd4jLong *)devicePtrs[2], + true); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // free allocated global device memory + for (int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } /* //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execSummaryStats_3) { - - NDArray x('c', {2,2,3}, {-5,-4,-3,-20,-1,0,1,2,3,4,5,6}, sd::DataType::DOUBLE); - NDArray exp('c', {2}, {10.606602, 2.121320}, sd::DataType::FLOAT32); - NDArray z('c', {2}, {100,100}, sd::DataType::FLOAT32); - - std::vector dimensions = {1}; - - // evaluate xTad data - shape::TAD xTad; - xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); - xTad.createTadOnlyShapeInfo(); - xTad.createOffsets(); - - // prepare input arrays for prepareDataForCuda function - std::vector> hostData; - hostData.emplace_back(dimensions.data(), dimensions.size() * sizeof(int)); // 0 -- dimensions - hostData.emplace_back(xTad.tadOnlyShapeInfo, shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo));// 1 -- xTadShapeInfo - hostData.emplace_back(xTad.tadOffsets, xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets - std::vector devicePtrs(hostData.size(), nullptr); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - - // allocate required amount of global device memory and copy host data to it - cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); ASSERT_EQ(0, cudaResult); - - // call cuda kernel which calculates result - NativeOpExecutioner::execSummaryStats(&lc, sd::variance::SummaryStatsStandardDeviation, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, - nullptr, z.shapeInfo(), z.specialBuffer(), z.special(), - (int*)devicePtrs[0], dimensions.size(), - (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2], - true); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // free allocated global device memory - for(int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + NDArray x('c', {2, 2, 3}, {-5, -4, -3, -20, -1, 0, 1, 2, 3, 4, 5, 6}, + sd::DataType::DOUBLE); + NDArray exp('c', {2}, {10.606602, 2.121320}, sd::DataType::FLOAT32); + NDArray z('c', {2}, {100, 100}, sd::DataType::FLOAT32); + + std::vector dimensions = {1}; + + // evaluate xTad data + shape::TAD xTad; + xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); + xTad.createTadOnlyShapeInfo(); + xTad.createOffsets(); + + // prepare input arrays for prepareDataForCuda function + std::vector> hostData; + hostData.emplace_back(dimensions.data(), + dimensions.size() * sizeof(int)); // 0 -- dimensions + hostData.emplace_back( + xTad.tadOnlyShapeInfo, + shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo)); // 1 -- xTadShapeInfo + hostData.emplace_back(xTad.tadOffsets, + xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets + std::vector devicePtrs(hostData.size(), nullptr); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // allocate required amount of global device memory and copy host data to it + cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); + ASSERT_EQ(0, cudaResult); + + // call cuda kernel which calculates result + NativeOpExecutioner::execSummaryStats( + &lc, sd::variance::SummaryStatsStandardDeviation, nullptr, x.shapeInfo(), + x.specialBuffer(), x.specialShapeInfo(), nullptr, nullptr, z.shapeInfo(), + z.specialBuffer(), z.special(), (int *)devicePtrs[0], + dimensions.size(), (Nd4jLong *)devicePtrs[1], (Nd4jLong *)devicePtrs[2], + true); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // free allocated global device memory + for (int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } */ //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execSummaryStatsScalar_1) { - - NDArray x('c', {2,2,3}, {-5,-4,-3,-2,-1,0,1,2,3,4,5,6}, sd::DataType::INT64); - NDArray exp('c', {}, std::vector{3.605551}, sd::DataType::FLOAT32); - NDArray z('c', {}, std::vector{100}, sd::DataType::FLOAT32); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - void* reductionPointer; - cudaResult = cudaMalloc(reinterpret_cast(&reductionPointer), 1024*1024); ASSERT_EQ(0, cudaResult); - lc.setReductionPointer(reductionPointer); - - // call cuda kernel which calculates result - NativeOpExecutioner::execSummaryStatsScalar(&lc, sd::variance::SummaryStatsStandardDeviation, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - true); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + NDArray x('c', {2, 2, 3}, {-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6}, + sd::DataType::INT64); + NDArray exp('c', {}, std::vector{3.605551}, sd::DataType::FLOAT32); + NDArray z('c', {}, std::vector{100}, sd::DataType::FLOAT32); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + void *reductionPointer; + cudaResult = + cudaMalloc(reinterpret_cast(&reductionPointer), 1024 * 1024); + ASSERT_EQ(0, cudaResult); + lc.setReductionPointer(reductionPointer); + + // call cuda kernel which calculates result + NativeOpExecutioner::execSummaryStatsScalar( + &lc, sd::variance::SummaryStatsStandardDeviation, nullptr, x.shapeInfo(), + x.specialBuffer(), x.specialShapeInfo(), nullptr, nullptr, z.shapeInfo(), + z.specialBuffer(), z.specialShapeInfo(), true); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execRandom_1) { - -// NDArray z('c', {10}, {100,0,0,0,0,0,0,0,0,0}, sd::DataType::DOUBLE); - NDArray z('c', {10}, {100,0,0,0,0,0,0,0,0,100}, sd::DataType::FLOAT32); - NDArray exp('c', {10}, {0.050942, -0.183229, -0.093921, 0.075469, 0.257166, -0.254838, 0.342227, -0.682188, -0.004345, 0.464633}, sd::DataType::FLOAT32); - - sd::graph::RandomGenerator gen(119,5); - - cudaError_t cudaResult; - NDArray* array = &z; - ExtraArguments arguments({0.f, 0.5f}); - auto context = z.getContext(); - PointersManager pm(context, "tests::execRandom_1"); -// z.printIndexedBuffer("Input data"); -// z.syncToDevice(); - NativeOpExecutioner::execRandom(context, random::GaussianDistribution, &gen, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), arguments.argumentsAsT(array->dataType())); - pm.synchronize(); - z.tickWriteDevice(); -// z.printIndexedBuffer("Output Gaussian"); -// RandomLauncher::fillGaussian(context, gen, &z, 0.f, 0.5f); -// pm.synchronize(); -// z.tickWriteDevice(); -// z.printIndexedBuffer("Output Gaussian"); - -// cudaStream_t stream; -// cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); -// LaunchContext lc(&stream); -// -// // ::execRandom(extraPointers, random::GaussianDistribution, &gen, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.special(), &extra); -// // call cuda kernel which calculates result -// NativeOpExecutioner::execRandom(&lc, sd::random::GaussianDistribution, -// &gen, -// nullptr, z.shapeInfo(), z.specialBuffer(), z.special(), -// nullptr, z.shapeInfo(), z.specialBuffer(), z.special(), -// nullptr, z.shapeInfo(), z.specialBuffer(), z.special(), -// extraArguments.argumentsAsT(z.dataType())); -// -// cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); -// ASSERT_EQ(cudaResult, 0); -// z.tickWriteDevice(); -// z.syncToHost(); -// z.printIndexedBuffer("Random1"); - ASSERT_EQ(exp, z); -// // verify results -// for (int e = 0; e < z.lengthOf(); e++) -// ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); -// cudaFree(dExtraArgs); - // free allocated global device memory -// cudaFree(dGen); - // delete cuda stream -// cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + // NDArray z('c', {10}, {100,0,0,0,0,0,0,0,0,0}, sd::DataType::DOUBLE); + NDArray z('c', {10}, {100, 0, 0, 0, 0, 0, 0, 0, 0, 100}, + sd::DataType::FLOAT32); + NDArray exp('c', {10}, + {0.050942, -0.183229, -0.093921, 0.075469, 0.257166, -0.254838, + 0.342227, -0.682188, -0.004345, 0.464633}, + sd::DataType::FLOAT32); + + sd::graph::RandomGenerator gen(119, 5); + + cudaError_t cudaResult; + NDArray *array = &z; + ExtraArguments arguments({0.f, 0.5f}); + auto context = z.getContext(); + PointersManager pm(context, "tests::execRandom_1"); + // z.printIndexedBuffer("Input data"); + // z.syncToDevice(); + NativeOpExecutioner::execRandom( + context, random::GaussianDistribution, &gen, array->buffer(), + array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), + array->buffer(), array->shapeInfo(), array->specialBuffer(), + array->specialShapeInfo(), array->buffer(), array->shapeInfo(), + array->specialBuffer(), array->specialShapeInfo(), + arguments.argumentsAsT(array->dataType())); + pm.synchronize(); + z.tickWriteDevice(); + // z.printIndexedBuffer("Output Gaussian"); + // RandomLauncher::fillGaussian(context, gen, &z, 0.f, 0.5f); + // pm.synchronize(); + // z.tickWriteDevice(); + // z.printIndexedBuffer("Output Gaussian"); + + // cudaStream_t stream; + // cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); + // LaunchContext lc(&stream); + // + // // ::execRandom(extraPointers, random::GaussianDistribution, &gen, + // z.buffer(), z.shapeInfo(), z.specialBuffer(), z.special(), + // &extra); + // // call cuda kernel which calculates result + // NativeOpExecutioner::execRandom(&lc, sd::random::GaussianDistribution, + // &gen, + // nullptr, z.shapeInfo(), + //z.specialBuffer(), + // z.special(), nullptr, z.shapeInfo(), + // z.specialBuffer(), + // z.special(), nullptr, z.shapeInfo(), + // z.specialBuffer(), z.special(), + // extraArguments.argumentsAsT(z.dataType())); + // + // cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); + // ASSERT_EQ(cudaResult, 0); + // z.tickWriteDevice(); + // z.syncToHost(); + // z.printIndexedBuffer("Random1"); + ASSERT_EQ(exp, z); + // // verify results + // for (int e = 0; e < z.lengthOf(); e++) + // ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + // cudaFree(dExtraArgs); + // free allocated global device memory + // cudaFree(dGen); + // delete cuda stream + // cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); } ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execRandom_2) { - - NDArray x('c', {10}, {0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1}, sd::DataType::DOUBLE); - NDArray z('c', {2,5}, {100,100,100,100,100,100,100,100,100,100}, sd::DataType::DOUBLE); - NDArray exp('c', {10}, {0., 0., 0.3, 0., 0.5, 0., 0.7, 0., 0., 1.}, sd::DataType::DOUBLE); - - ExtraArguments extraArguments({0.7}); - sd::graph::RandomGenerator gen(119,5); - -// // prepare input arrays for prepareDataForCuda function -// std::vector> hostData; -// hostData.emplace_back(extraArguments.data(), extraArguments.size() * sizeof(double)); // 0 -- dimensions -// std::vector devicePtrs(hostData.size(), nullptr); -// - // create cuda stream and LaunchContext - cudaError_t cudaResult; -// cudaStream_t stream; -// cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext* lc = x.getContext(); //(&stream); - - // allocate required amount of global device memory and copy host data to it -// cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); ASSERT_EQ(0, cudaResult); - - // call cuda kernel which calculates result - NativeOpExecutioner::execRandom(lc, sd::random::DropOut, - &gen, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - extraArguments.argumentsAsT(z.dataType())); - - cudaResult = cudaStreamSynchronize(*lc->getCudaStream()); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - z.syncToHost(); - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // free allocated global device memory -// for(int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); - - // delete cuda stream -// cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + NDArray x('c', {10}, {0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1}, + sd::DataType::DOUBLE); + NDArray z('c', {2, 5}, {100, 100, 100, 100, 100, 100, 100, 100, 100, 100}, + sd::DataType::DOUBLE); + NDArray exp('c', {10}, {0., 0., 0.3, 0., 0.5, 0., 0.7, 0., 0., 1.}, + sd::DataType::DOUBLE); + + ExtraArguments extraArguments({0.7}); + sd::graph::RandomGenerator gen(119, 5); + + // // prepare input arrays for prepareDataForCuda function + // std::vector> hostData; + // hostData.emplace_back(extraArguments.data(), extraArguments.size() * + // sizeof(double)); // 0 -- dimensions std::vector + // devicePtrs(hostData.size(), nullptr); + // + // create cuda stream and LaunchContext + cudaError_t cudaResult; + // cudaStream_t stream; + // cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); + LaunchContext *lc = x.getContext(); //(&stream); + + // allocate required amount of global device memory and copy host data to it + // cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); + // ASSERT_EQ(0, cudaResult); + + // call cuda kernel which calculates result + NativeOpExecutioner::execRandom( + lc, sd::random::DropOut, &gen, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo(), extraArguments.argumentsAsT(z.dataType())); + + cudaResult = cudaStreamSynchronize(*lc->getCudaStream()); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + z.syncToHost(); + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // free allocated global device memory + // for(int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); + + // delete cuda stream + // cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); } ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execRandom_3) { - - NDArray z('c', {10}, {100,100,100,100,100,100,100,100,100,100}, sd::DataType::DOUBLE); - NDArray exp('c', {10}, {2.373649, 2.239791, 1.887353, 2.488636, 2.068904, 2.281399, 1.828228, 2.228222, 2.490847, 1.669537}, sd::DataType::DOUBLE); - - std::vector extraArguments = {1.5, 2.5}; - sd::graph::RandomGenerator gen(119,5); - - // prepare input arrays for prepareDataForCuda function - std::vector> hostData; - hostData.emplace_back(extraArguments.data(), extraArguments.size() * sizeof(double)); // 0 -- dimensions - std::vector devicePtrs(hostData.size(), nullptr); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - - // allocate required amount of global device memory and copy host data to it - cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); ASSERT_EQ(0, cudaResult); - - // call cuda kernel which calculates result - NativeOpExecutioner::execRandom(&lc, sd::random::UniformDistribution, - &gen, - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - devicePtrs[0]); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // free allocated global device memory - for(int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + NDArray z('c', {10}, {100, 100, 100, 100, 100, 100, 100, 100, 100, 100}, + sd::DataType::DOUBLE); + NDArray exp('c', {10}, + {2.373649, 2.239791, 1.887353, 2.488636, 2.068904, 2.281399, + 1.828228, 2.228222, 2.490847, 1.669537}, + sd::DataType::DOUBLE); + + std::vector extraArguments = {1.5, 2.5}; + sd::graph::RandomGenerator gen(119, 5); + + // prepare input arrays for prepareDataForCuda function + std::vector> hostData; + hostData.emplace_back( + extraArguments.data(), + extraArguments.size() * sizeof(double)); // 0 -- dimensions + std::vector devicePtrs(hostData.size(), nullptr); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // allocate required amount of global device memory and copy host data to it + cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); + ASSERT_EQ(0, cudaResult); + + // call cuda kernel which calculates result + NativeOpExecutioner::execRandom(&lc, sd::random::UniformDistribution, &gen, + nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo(), devicePtrs[0]); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // free allocated global device memory + for (int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execRandom_4) { - - NDArray z('c', {2,5}, {1,2,3,4,5,6,7,8,9,10}, sd::DataType::FLOAT32); - NDArray exp('c', {10}, {2.373649, 2.281399, 2.239791, 1.828228, 1.887353, 2.228222, 2.488636, 2.490847, 2.068904, 1.669537}, sd::DataType::FLOAT32); - z.permutei({1,0}); - - ExtraArguments extraArguments({1.5, 2.5}); - sd::graph::RandomGenerator gen(119,5); - -// // prepare input arrays for prepareDataForCuda function -// std::vector> hostData; -// hostData.emplace_back(extraArguments.data(), extraArguments.size() * sizeof(double)); // 0 -- dimensions -// std::vector devicePtrs(hostData.size(), nullptr); - - // create cuda stream and LaunchContext -// cudaError_t cudaResult; -// cudaStream_t stream; -// cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); -// LaunchContext lc(&stream); -// -// // allocate required amount of global device memory and copy host data to it -// cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); ASSERT_EQ(0, cudaResult); - auto context = z.getContext(); - PointersManager pm(context, "execRandom4"); - // call cuda kernel which calculates result - NativeOpExecutioner::execRandom(context, sd::random::UniformDistribution, - &gen, - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - extraArguments.argumentsAsT(z.dataType())); - -// cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); -// z.printIndexedBuffer("Output Uniform4"); - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // free allocated global device memory -// for(int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); - - // delete cuda stream -// cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + NDArray z('c', {2, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + sd::DataType::FLOAT32); + NDArray exp('c', {10}, + {2.373649, 2.281399, 2.239791, 1.828228, 1.887353, 2.228222, + 2.488636, 2.490847, 2.068904, 1.669537}, + sd::DataType::FLOAT32); + z.permutei({1, 0}); + + ExtraArguments extraArguments({1.5, 2.5}); + sd::graph::RandomGenerator gen(119, 5); + + // // prepare input arrays for prepareDataForCuda function + // std::vector> hostData; + // hostData.emplace_back(extraArguments.data(), extraArguments.size() * + // sizeof(double)); // 0 -- dimensions std::vector + // devicePtrs(hostData.size(), nullptr); + + // create cuda stream and LaunchContext + // cudaError_t cudaResult; + // cudaStream_t stream; + // cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); + // LaunchContext lc(&stream); + // + // // allocate required amount of global device memory and copy host data + // to it cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); + // ASSERT_EQ(0, cudaResult); + auto context = z.getContext(); + PointersManager pm(context, "execRandom4"); + // call cuda kernel which calculates result + NativeOpExecutioner::execRandom(context, sd::random::UniformDistribution, + &gen, nullptr, z.shapeInfo(), + z.specialBuffer(), z.specialShapeInfo(), + extraArguments.argumentsAsT(z.dataType())); + + // cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + // z.printIndexedBuffer("Output Uniform4"); + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // free allocated global device memory + // for(int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); + + // delete cuda stream + // cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); } - diff --git a/libnd4j/tests_cpu/layers_tests/CudaBasicsTests2.cu b/libnd4j/tests_cpu/layers_tests/CudaBasicsTests2.cu index 28102cad59da..532a672b1aae 100644 --- a/libnd4j/tests_cpu/layers_tests/CudaBasicsTests2.cu +++ b/libnd4j/tests_cpu/layers_tests/CudaBasicsTests2.cu @@ -14,973 +14,1199 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - // - // @author raver119@gmail.com - // +// +// @author raver119@gmail.com +// -#include "testlayers.h" #include #include +#include #include #include #include #include -#include -#include #include +#include +#include -#include +#include "testlayers.h" using namespace sd; using namespace sd::graph; class CudaBasicsTests2 : public testing::Test { -public: - + public: }; TEST_F(CudaBasicsTests2, test_devices_1) { - auto caps = Environment::getInstance().capabilities(); - ASSERT_FALSE(caps.empty()); + auto caps = Environment::getInstance().capabilities(); + ASSERT_FALSE(caps.empty()); } ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests2, mmulMxM_1) { - - const Nd4jLong M = 3; - const Nd4jLong K = 4; - const Nd4jLong N = 5; - - NDArray a('f', {M,K}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::FLOAT32); - NDArray b('f', {K,N}, {1,-2,3,-4,5,-6,7,-8,9,-10,11,-12,13,-14,15,-16,17,-18,19,-20}, sd::DataType::FLOAT32); - NDArray c('f', {M,N}, sd::DataType::FLOAT32); - - NDArray exp('f', {M,N}, {0.1, 0.3, 0.5, 2.5, 2.7, 2.9, 4.9, 5.1, 5.3, 7.3, 7.5, 7.7, 9.7, 9.9, 10.1}, sd::DataType::FLOAT32); - - sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); - // c.printIndexedBuffer(); - - ASSERT_TRUE(c.equalsTo(&exp)); + const Nd4jLong M = 3; + const Nd4jLong K = 4; + const Nd4jLong N = 5; + + NDArray a('f', {M, K}, + {1.2, 1.1, 1.0, 0.9, 0.8, 0.7, 0.5, 0.4, 0.3, 0.2, 0.1, 0}, + sd::DataType::FLOAT32); + NDArray b('f', {K, N}, {1, -2, 3, -4, 5, -6, 7, -8, 9, -10, + 11, -12, 13, -14, 15, -16, 17, -18, 19, -20}, + sd::DataType::FLOAT32); + NDArray c('f', {M, N}, sd::DataType::FLOAT32); + + NDArray exp('f', {M, N}, + {0.1, 0.3, 0.5, 2.5, 2.7, 2.9, 4.9, 5.1, 5.3, 7.3, 7.5, 7.7, 9.7, + 9.9, 10.1}, + sd::DataType::FLOAT32); + + sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); + // c.printIndexedBuffer(); + + ASSERT_TRUE(c.equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests2, mmulMxM_2) { - - const Nd4jLong M = 3; - const Nd4jLong K = 4; - const Nd4jLong N = 5; - - NDArray a('c', {M,K}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); - NDArray b('f', {K,N}, {1,-2,3,-4,5,-6,7,-8,9,-10,11,-12,13,-14,15,-16,17,-18,19,-20}, sd::DataType::DOUBLE); - NDArray c('f', {M,N}, sd::DataType::DOUBLE); - NDArray exp('f', {M,N}, {-1.6, -0.7, 0.2, -0.8, 0.1, 1., -0., 0.9, 1.8, 0.8, 1.7, 2.6, 1.6, 2.5, 3.4}, sd::DataType::DOUBLE); - - sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); - - ASSERT_TRUE(c.equalsTo(&exp)); + const Nd4jLong M = 3; + const Nd4jLong K = 4; + const Nd4jLong N = 5; + + NDArray a('c', {M, K}, + {1.2, 1.1, 1.0, 0.9, 0.8, 0.7, 0.5, 0.4, 0.3, 0.2, 0.1, 0}, + sd::DataType::DOUBLE); + NDArray b('f', {K, N}, {1, -2, 3, -4, 5, -6, 7, -8, 9, -10, + 11, -12, 13, -14, 15, -16, 17, -18, 19, -20}, + sd::DataType::DOUBLE); + NDArray c('f', {M, N}, sd::DataType::DOUBLE); + NDArray exp('f', {M, N}, + {-1.6, -0.7, 0.2, -0.8, 0.1, 1., -0., 0.9, 1.8, 0.8, 1.7, 2.6, + 1.6, 2.5, 3.4}, + sd::DataType::DOUBLE); + + sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); + + ASSERT_TRUE(c.equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests2, mmulMxM_3) { - - const Nd4jLong M = 3; - const Nd4jLong K = 4; - const Nd4jLong N = 5; - - NDArray a('f', {M,K}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); - NDArray b('c', {K,N}, {1,-2,3,-4,5,-6,7,-8,9,-10,11,-12,13,-14,15,-16,17,-18,19,-20}, sd::DataType::DOUBLE); - NDArray c('f', {M,N}, sd::DataType::DOUBLE); - - NDArray exp('f', {M,N}, {-1.9, -0.9, 0.1, 1.3, 0.3, -0.7, -0.7, 0.3, 1.3, 0.1, -0.9, -1.9, 0.5, 1.5, 2.5}, sd::DataType::DOUBLE); - - sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); - - ASSERT_TRUE(c.equalsTo(&exp)); + const Nd4jLong M = 3; + const Nd4jLong K = 4; + const Nd4jLong N = 5; + + NDArray a('f', {M, K}, + {1.2, 1.1, 1.0, 0.9, 0.8, 0.7, 0.5, 0.4, 0.3, 0.2, 0.1, 0}, + sd::DataType::DOUBLE); + NDArray b('c', {K, N}, {1, -2, 3, -4, 5, -6, 7, -8, 9, -10, + 11, -12, 13, -14, 15, -16, 17, -18, 19, -20}, + sd::DataType::DOUBLE); + NDArray c('f', {M, N}, sd::DataType::DOUBLE); + + NDArray exp('f', {M, N}, + {-1.9, -0.9, 0.1, 1.3, 0.3, -0.7, -0.7, 0.3, 1.3, 0.1, -0.9, -1.9, + 0.5, 1.5, 2.5}, + sd::DataType::DOUBLE); + + sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); + + ASSERT_TRUE(c.equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests2, mmulMxM_4) { - - const Nd4jLong M = 3; - const Nd4jLong K = 4; - const Nd4jLong N = 5; - - NDArray a('f', {M,K}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); - NDArray b('f', {K,N}, {1,-2,3,-4,5,-6,7,-8,9,-10,11,-12,13,-14,15,-16,17,-18,19,-20}, sd::DataType::DOUBLE); - NDArray c('c', {M,N}, sd::DataType::DOUBLE); - - NDArray exp('c', {M,N}, {0.1, 2.5, 4.9, 7.3, 9.7,0.3, 2.7, 5.1, 7.5, 9.9,0.5, 2.9, 5.3, 7.7, 10.1}, sd::DataType::DOUBLE); - - sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); - ASSERT_TRUE(c.equalsTo(&exp)); - - - // NDArray* pA = a.permute({1,0}); - // NDArray* pB = b.permute({1,0}); - // NDArray* pC = c.permute({1,0}); - - // sd::MmulHelper::mmul(pB, pA, pC, 1., 0.); - // ASSERT_TRUE(c.equalsTo(&exp)); - - // delete pA; - // delete pB; - // delete pC; + const Nd4jLong M = 3; + const Nd4jLong K = 4; + const Nd4jLong N = 5; + + NDArray a('f', {M, K}, + {1.2, 1.1, 1.0, 0.9, 0.8, 0.7, 0.5, 0.4, 0.3, 0.2, 0.1, 0}, + sd::DataType::DOUBLE); + NDArray b('f', {K, N}, {1, -2, 3, -4, 5, -6, 7, -8, 9, -10, + 11, -12, 13, -14, 15, -16, 17, -18, 19, -20}, + sd::DataType::DOUBLE); + NDArray c('c', {M, N}, sd::DataType::DOUBLE); + + NDArray exp('c', {M, N}, + {0.1, 2.5, 4.9, 7.3, 9.7, 0.3, 2.7, 5.1, 7.5, 9.9, 0.5, 2.9, 5.3, + 7.7, 10.1}, + sd::DataType::DOUBLE); + + sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); + ASSERT_TRUE(c.equalsTo(&exp)); + + // NDArray* pA = a.permute({1,0}); + // NDArray* pB = b.permute({1,0}); + // NDArray* pC = c.permute({1,0}); + + // sd::MmulHelper::mmul(pB, pA, pC, 1., 0.); + // ASSERT_TRUE(c.equalsTo(&exp)); + + // delete pA; + // delete pB; + // delete pC; } ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests2, mmulMxM_5) { - - const Nd4jLong M = 3; - const Nd4jLong K = 4; - const Nd4jLong N = 5; - - NDArray a('c', {M,K}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); - NDArray b('c', {K,N}, {1,-2,3,-4,5,-6,7,-8,9,-10,11,-12,13,-14,15,-16,17,-18,19,-20}, sd::DataType::DOUBLE); - NDArray c('f', {M,N}, sd::DataType::DOUBLE); - - NDArray exp('f', {M,N}, {-8.8, -4.3, 0.2, 8.6, 4.1, -0.4, -8.4, -3.9, 0.6, 8.2, 3.7, -0.8, -8.0, -3.5, 1.}, sd::DataType::DOUBLE); - - sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); - - ASSERT_TRUE(c.equalsTo(&exp)); + const Nd4jLong M = 3; + const Nd4jLong K = 4; + const Nd4jLong N = 5; + + NDArray a('c', {M, K}, + {1.2, 1.1, 1.0, 0.9, 0.8, 0.7, 0.5, 0.4, 0.3, 0.2, 0.1, 0}, + sd::DataType::DOUBLE); + NDArray b('c', {K, N}, {1, -2, 3, -4, 5, -6, 7, -8, 9, -10, + 11, -12, 13, -14, 15, -16, 17, -18, 19, -20}, + sd::DataType::DOUBLE); + NDArray c('f', {M, N}, sd::DataType::DOUBLE); + + NDArray exp('f', {M, N}, + {-8.8, -4.3, 0.2, 8.6, 4.1, -0.4, -8.4, -3.9, 0.6, 8.2, 3.7, -0.8, + -8.0, -3.5, 1.}, + sd::DataType::DOUBLE); + + sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); + + ASSERT_TRUE(c.equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests2, mmulMxM_6) { - - const Nd4jLong M = 3; - const Nd4jLong K = 4; - const Nd4jLong N = 5; - - NDArray a('c', {M,K}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); - NDArray b('f', {K,N}, {1,-2,3,-4,5,-6,7,-8,9,-10,11,-12,13,-14,15,-16,17,-18,19,-20}, sd::DataType::DOUBLE); - NDArray c('c', {M,N}, sd::DataType::DOUBLE); - - NDArray exp('c', {M,N}, {-1.6, -0.8, -0.0, 0.8, 1.6, -0.7, 0.1, 0.9, 1.7, 2.5, 0.2, 1.0, 1.8, 2.6, 3.4}, sd::DataType::DOUBLE); - - sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); - - ASSERT_TRUE(c.equalsTo(&exp)); + const Nd4jLong M = 3; + const Nd4jLong K = 4; + const Nd4jLong N = 5; + + NDArray a('c', {M, K}, + {1.2, 1.1, 1.0, 0.9, 0.8, 0.7, 0.5, 0.4, 0.3, 0.2, 0.1, 0}, + sd::DataType::DOUBLE); + NDArray b('f', {K, N}, {1, -2, 3, -4, 5, -6, 7, -8, 9, -10, + 11, -12, 13, -14, 15, -16, 17, -18, 19, -20}, + sd::DataType::DOUBLE); + NDArray c('c', {M, N}, sd::DataType::DOUBLE); + + NDArray exp('c', {M, N}, + {-1.6, -0.8, -0.0, 0.8, 1.6, -0.7, 0.1, 0.9, 1.7, 2.5, 0.2, 1.0, + 1.8, 2.6, 3.4}, + sd::DataType::DOUBLE); + + sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); + + ASSERT_TRUE(c.equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests2, mmulMxM_7) { - - const Nd4jLong M = 3; - const Nd4jLong K = 4; - const Nd4jLong N = 5; - - NDArray a('f', {M,K}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); - NDArray b('c', {K,N}, {1,-2,3,-4,5,-6,7,-8,9,-10,11,-12,13,-14,15,-16,17,-18,19,-20}, sd::DataType::DOUBLE); - NDArray c('c', {M,N}, sd::DataType::DOUBLE); - - NDArray exp('c', {M,N}, {-1.9, 1.3, -0.7, 0.1, 0.5, -0.9, 0.3, 0.3, -0.9, 1.5, 0.1, -0.7, 1.3, -1.9, 2.5}, sd::DataType::DOUBLE); - - sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); - - ASSERT_TRUE(c.equalsTo(&exp)); + const Nd4jLong M = 3; + const Nd4jLong K = 4; + const Nd4jLong N = 5; + + NDArray a('f', {M, K}, + {1.2, 1.1, 1.0, 0.9, 0.8, 0.7, 0.5, 0.4, 0.3, 0.2, 0.1, 0}, + sd::DataType::DOUBLE); + NDArray b('c', {K, N}, {1, -2, 3, -4, 5, -6, 7, -8, 9, -10, + 11, -12, 13, -14, 15, -16, 17, -18, 19, -20}, + sd::DataType::DOUBLE); + NDArray c('c', {M, N}, sd::DataType::DOUBLE); + + NDArray exp('c', {M, N}, + {-1.9, 1.3, -0.7, 0.1, 0.5, -0.9, 0.3, 0.3, -0.9, 1.5, 0.1, -0.7, + 1.3, -1.9, 2.5}, + sd::DataType::DOUBLE); + + sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); + + ASSERT_TRUE(c.equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests2, mmulMxM_8) { - - const Nd4jLong M = 3; - const Nd4jLong K = 4; - const Nd4jLong N = 5; - - NDArray a('c', {M,K}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); - NDArray b('c', {K,N}, {1,-2,3,-4,5,-6,7,-8,9,-10,11,-12,13,-14,15,-16,17,-18,19,-20}, sd::DataType::DOUBLE); - NDArray c('c', {M,N}, sd::DataType::DOUBLE); - - NDArray exp('c', {M,N}, {-8.8, 8.6, -8.4, 8.2, -8.0, -4.3, 4.1, -3.9, 3.7, -3.5, 0.2, -0.4, 0.6, -0.8, 1.}, sd::DataType::DOUBLE); - - sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); - - ASSERT_TRUE(c.equalsTo(&exp)); + const Nd4jLong M = 3; + const Nd4jLong K = 4; + const Nd4jLong N = 5; + + NDArray a('c', {M, K}, + {1.2, 1.1, 1.0, 0.9, 0.8, 0.7, 0.5, 0.4, 0.3, 0.2, 0.1, 0}, + sd::DataType::DOUBLE); + NDArray b('c', {K, N}, {1, -2, 3, -4, 5, -6, 7, -8, 9, -10, + 11, -12, 13, -14, 15, -16, 17, -18, 19, -20}, + sd::DataType::DOUBLE); + NDArray c('c', {M, N}, sd::DataType::DOUBLE); + + NDArray exp('c', {M, N}, + {-8.8, 8.6, -8.4, 8.2, -8.0, -4.3, 4.1, -3.9, 3.7, -3.5, 0.2, + -0.4, 0.6, -0.8, 1.}, + sd::DataType::DOUBLE); + + sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); + + ASSERT_TRUE(c.equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests2, mmulMxM_9) { - - const Nd4jLong M = 3; - const Nd4jLong K = 4; - const Nd4jLong N = 5; - - NDArray a('c', {M,K}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::FLOAT32); - NDArray b('c', {K,N}, {1,-2,3,-4,5,-6,7,-8,9,-10,11,-12,13,-14,15,-16,17,-18,19,-20}, sd::DataType::FLOAT32); - NDArray c('c', {M,N}, sd::DataType::FLOAT32); - - NDArray exp('c', {M,N}, {-8.8, 8.6, -8.4, 8.2, -8.0, -4.3, 4.1, -3.9, 3.7, -3.5, 0.2, -0.4, 0.6, -0.8, 1.}, sd::DataType::FLOAT32); - - sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); - - ASSERT_TRUE(c.equalsTo(&exp)); + const Nd4jLong M = 3; + const Nd4jLong K = 4; + const Nd4jLong N = 5; + + NDArray a('c', {M, K}, + {1.2, 1.1, 1.0, 0.9, 0.8, 0.7, 0.5, 0.4, 0.3, 0.2, 0.1, 0}, + sd::DataType::FLOAT32); + NDArray b('c', {K, N}, {1, -2, 3, -4, 5, -6, 7, -8, 9, -10, + 11, -12, 13, -14, 15, -16, 17, -18, 19, -20}, + sd::DataType::FLOAT32); + NDArray c('c', {M, N}, sd::DataType::FLOAT32); + + NDArray exp('c', {M, N}, + {-8.8, 8.6, -8.4, 8.2, -8.0, -4.3, 4.1, -3.9, 3.7, -3.5, 0.2, + -0.4, 0.6, -0.8, 1.}, + sd::DataType::FLOAT32); + + sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); + + ASSERT_TRUE(c.equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests2, mmulMxM_10) { - - const Nd4jLong M = 3; - const Nd4jLong K = 4; - const Nd4jLong N = 5; - - NDArray a('f', {M,K}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::FLOAT32); - NDArray b('f', {K,N}, {1,-2,3,-4,5,-6,7,-8,9,-10,11,-12,13,-14,15,-16,17,-18,19,-20}, sd::DataType::FLOAT32); - NDArray c('f', {M,N}, sd::DataType::FLOAT32); - - NDArray exp('f', {M,N}, {0.1, 0.3, 0.5, 2.5, 2.7, 2.9, 4.9, 5.1, 5.3, 7.3, 7.5, 7.7, 9.7, 9.9, 10.1}, sd::DataType::FLOAT32); - - sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); - // c.printIndexedBuffer(); - - ASSERT_TRUE(c.equalsTo(&exp)); + const Nd4jLong M = 3; + const Nd4jLong K = 4; + const Nd4jLong N = 5; + + NDArray a('f', {M, K}, + {1.2, 1.1, 1.0, 0.9, 0.8, 0.7, 0.5, 0.4, 0.3, 0.2, 0.1, 0}, + sd::DataType::FLOAT32); + NDArray b('f', {K, N}, {1, -2, 3, -4, 5, -6, 7, -8, 9, -10, + 11, -12, 13, -14, 15, -16, 17, -18, 19, -20}, + sd::DataType::FLOAT32); + NDArray c('f', {M, N}, sd::DataType::FLOAT32); + + NDArray exp('f', {M, N}, + {0.1, 0.3, 0.5, 2.5, 2.7, 2.9, 4.9, 5.1, 5.3, 7.3, 7.5, 7.7, 9.7, + 9.9, 10.1}, + sd::DataType::FLOAT32); + + sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); + // c.printIndexedBuffer(); + + ASSERT_TRUE(c.equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests2, mmulMxM_11) { - - const Nd4jLong M = 3; - const Nd4jLong K = 4; - const Nd4jLong N = 5; - - NDArray a('f', {M,K}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::FLOAT32); - NDArray b('c', {K,N}, {1,-2,3,-4,5,-6,7,-8,9,-10,11,-12,13,-14,15,-16,17,-18,19,-20}, sd::DataType::FLOAT32); - NDArray c('f', {M,N}, sd::DataType::FLOAT32); - - NDArray exp('f', {M,N}, {-1.9, -0.9, 0.1, 1.3, 0.3, -0.7, -0.7, 0.3, 1.3, 0.1, -0.9, -1.9, 0.5, 1.5, 2.5}, sd::DataType::FLOAT32); - - sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); - - ASSERT_TRUE(c.equalsTo(&exp)); + const Nd4jLong M = 3; + const Nd4jLong K = 4; + const Nd4jLong N = 5; + + NDArray a('f', {M, K}, + {1.2, 1.1, 1.0, 0.9, 0.8, 0.7, 0.5, 0.4, 0.3, 0.2, 0.1, 0}, + sd::DataType::FLOAT32); + NDArray b('c', {K, N}, {1, -2, 3, -4, 5, -6, 7, -8, 9, -10, + 11, -12, 13, -14, 15, -16, 17, -18, 19, -20}, + sd::DataType::FLOAT32); + NDArray c('f', {M, N}, sd::DataType::FLOAT32); + + NDArray exp('f', {M, N}, + {-1.9, -0.9, 0.1, 1.3, 0.3, -0.7, -0.7, 0.3, 1.3, 0.1, -0.9, -1.9, + 0.5, 1.5, 2.5}, + sd::DataType::FLOAT32); + + sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); + + ASSERT_TRUE(c.equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests2, mmulMxM_12) { - - int devCnt = 0; - cudaGetDevice(&devCnt); - if(Environment::getInstance().capabilities()[devCnt].first() < 5) return; - - const Nd4jLong M = 4; - const Nd4jLong K = 4; - const Nd4jLong N = 4; - - NDArray a('f', {M,K}, {1.,2,3,4,5,6,7,8,9,2,3,2,1,0,4,7.}, sd::DataType::INT8); - NDArray b('f', {K,N}, {-2,-3,0,1,5,-6,7,-8,9,-1,2,-2,3,-4,5,-6.}, sd::DataType::INT8); - NDArray c('f', {M,N}, sd::DataType::FLOAT32); - - NDArray exp('f', {M,N}, {-16., -22., -23., -25., 30., -12., -38., -70., 20., 16., 18., 18., 22., -8., -28., -52.}, sd::DataType::FLOAT32); - - sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); - // c.printBuffer(); - - ASSERT_TRUE(c.equalsTo(&exp)); + int devCnt = 0; + cudaGetDevice(&devCnt); + if (Environment::getInstance().capabilities()[devCnt].first() < 5) return; + + const Nd4jLong M = 4; + const Nd4jLong K = 4; + const Nd4jLong N = 4; + + NDArray a('f', {M, K}, {1., 2, 3, 4, 5, 6, 7, 8, 9, 2, 3, 2, 1, 0, 4, 7.}, + sd::DataType::INT8); + NDArray b('f', {K, N}, + {-2, -3, 0, 1, 5, -6, 7, -8, 9, -1, 2, -2, 3, -4, 5, -6.}, + sd::DataType::INT8); + NDArray c('f', {M, N}, sd::DataType::FLOAT32); + + NDArray exp('f', {M, N}, + {-16., -22., -23., -25., 30., -12., -38., -70., 20., 16., 18., + 18., 22., -8., -28., -52.}, + sd::DataType::FLOAT32); + + sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); + // c.printBuffer(); + + ASSERT_TRUE(c.equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests2, mmulMxM_13) { + int devCnt = 0; + cudaGetDevice(&devCnt); + if (Environment::getInstance().capabilities()[devCnt].first() < 5) return; - int devCnt = 0; - cudaGetDevice(&devCnt); - if(Environment::getInstance().capabilities()[devCnt].first() < 5) return; - - const Nd4jLong M = 3; - const Nd4jLong K = 4; - const Nd4jLong N = 5; + const Nd4jLong M = 3; + const Nd4jLong K = 4; + const Nd4jLong N = 5; - NDArray a('f', {M,K}, {1.,2,3,4,5,6,7,8,9,10,11,12}, sd::DataType::INT8); - NDArray b('c', {K,N}, {-2,-3,0,1,5,-6,7,-8,9,-10,11,-12,13,-14,15,-16,17,-18,19,-20}, sd::DataType::INT8); - NDArray c('f', {M,N}, sd::DataType::FLOAT32); + NDArray a('f', {M, K}, {1., 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, + sd::DataType::INT8); + NDArray b('c', {K, N}, {-2, -3, 0, 1, 5, -6, 7, -8, 9, -10, + 11, -12, 13, -14, 15, -16, 17, -18, 19, -20}, + sd::DataType::INT8); + NDArray c('f', {M, N}, sd::DataType::FLOAT32); - NDArray exp('f', {M,N}, {-109., -122., -135., 111., 120., 129., -121., -134., -147., 129., 144., 159., -130., -140., -150.}, sd::DataType::FLOAT32); + NDArray exp('f', {M, N}, + {-109., -122., -135., 111., 120., 129., -121., -134., -147., 129., + 144., 159., -130., -140., -150.}, + sd::DataType::FLOAT32); - sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); + sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); - ASSERT_TRUE(c.equalsTo(&exp)); + ASSERT_TRUE(c.equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests2, mmulMxM_14) { + int devCnt = 0; + cudaGetDevice(&devCnt); + if (Environment::getInstance().capabilities()[devCnt].first() < 5) return; - int devCnt = 0; - cudaGetDevice(&devCnt); - if(Environment::getInstance().capabilities()[devCnt].first() < 5) return; + const Nd4jLong M = 3; + const Nd4jLong K = 4; + const Nd4jLong N = 5; - const Nd4jLong M = 3; - const Nd4jLong K = 4; - const Nd4jLong N = 5; + NDArray a('c', {M, K}, {1., 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, + sd::DataType::INT8); + NDArray b('c', {K, N}, {-2, -3, 0, 1, 5, -6, 7, -8, 9, -10, + 11, -12, 13, -14, 15, -16, 17, -18, 19, -20}, + sd::DataType::INT8); + NDArray c('c', {M, N}, sd::DataType::FLOAT32); - NDArray a('c', {M,K}, {1.,2,3,4,5,6,7,8,9,10,11,12}, sd::DataType::INT8); - NDArray b('c', {K,N}, {-2,-3,0,1,5,-6,7,-8,9,-10,11,-12,13,-14,15,-16,17,-18,19,-20}, sd::DataType::INT8); - NDArray c('c', {M,N}, sd::DataType::FLOAT32); + NDArray exp('c', {M, N}, + {-45., 43., -49., 53., -50., -97., 79., -101., 113., -90., -149., + 115., -153., 173., -130.}, + sd::DataType::FLOAT32); - NDArray exp('c', {M,N}, {-45., 43., -49., 53., -50., -97., 79., -101., 113., -90., -149., 115., -153., 173., -130.}, sd::DataType::FLOAT32); + sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); - sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); - - ASSERT_TRUE(c.equalsTo(&exp)); + ASSERT_TRUE(c.equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests2, mmulMxM_15) { - - int devCnt = 0; - cudaGetDevice(&devCnt); - if(Environment::getInstance().capabilities()[devCnt].first() < 5) return; - - const Nd4jLong M = 3; - const Nd4jLong K = 4; - const Nd4jLong N = 5; - - NDArray a('f', {M,K}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::HALF); - NDArray b('f', {K,N}, {1,-2,3,-4,5,-6,7,-8,9,-10,11,-12,13,-14,15,-16,17,-18,19,-20}, sd::DataType::HALF); - NDArray c('f', {M,N}, sd::DataType::FLOAT32); - - NDArray exp('f', {M,N}, {0.1, 0.3, 0.5, 2.5, 2.7, 2.9, 4.9, 5.1, 5.3, 7.3, 7.5, 7.7, 9.7, 9.9, 10.1}, sd::DataType::FLOAT32); - - sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); - // c.printBuffer(); - - ASSERT_TRUE(c.equalsTo(&exp, 0.01)); + int devCnt = 0; + cudaGetDevice(&devCnt); + if (Environment::getInstance().capabilities()[devCnt].first() < 5) return; + + const Nd4jLong M = 3; + const Nd4jLong K = 4; + const Nd4jLong N = 5; + + NDArray a('f', {M, K}, + {1.2, 1.1, 1.0, 0.9, 0.8, 0.7, 0.5, 0.4, 0.3, 0.2, 0.1, 0}, + sd::DataType::HALF); + NDArray b('f', {K, N}, {1, -2, 3, -4, 5, -6, 7, -8, 9, -10, + 11, -12, 13, -14, 15, -16, 17, -18, 19, -20}, + sd::DataType::HALF); + NDArray c('f', {M, N}, sd::DataType::FLOAT32); + + NDArray exp('f', {M, N}, + {0.1, 0.3, 0.5, 2.5, 2.7, 2.9, 4.9, 5.1, 5.3, 7.3, 7.5, 7.7, 9.7, + 9.9, 10.1}, + sd::DataType::FLOAT32); + + sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); + // c.printBuffer(); + + ASSERT_TRUE(c.equalsTo(&exp, 0.01)); } ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests2, mmulMxM_16) { - - int devCnt = 0; - cudaGetDevice(&devCnt); - if(Environment::getInstance().capabilities()[devCnt].first() < 5) return; - - const Nd4jLong M = 3; - const Nd4jLong K = 4; - const Nd4jLong N = 5; - - NDArray a('f', {M,K}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::HALF); - NDArray b('c', {K,N}, {1,-2,3,-4,5,-6,7,-8,9,-10,11,-12,13,-14,15,-16,17,-18,19,-20}, sd::DataType::HALF); - NDArray c('f', {M,N}, sd::DataType::FLOAT32); - - NDArray exp('f', {M,N}, {-1.9, -0.9, 0.1, 1.3, 0.3, -0.7, -0.7, 0.3, 1.3, 0.1, -0.9, -1.9, 0.5, 1.5, 2.5}, sd::DataType::FLOAT32); - - sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); - - ASSERT_TRUE(c.equalsTo(&exp, 0.01)); + int devCnt = 0; + cudaGetDevice(&devCnt); + if (Environment::getInstance().capabilities()[devCnt].first() < 5) return; + + const Nd4jLong M = 3; + const Nd4jLong K = 4; + const Nd4jLong N = 5; + + NDArray a('f', {M, K}, + {1.2, 1.1, 1.0, 0.9, 0.8, 0.7, 0.5, 0.4, 0.3, 0.2, 0.1, 0}, + sd::DataType::HALF); + NDArray b('c', {K, N}, {1, -2, 3, -4, 5, -6, 7, -8, 9, -10, + 11, -12, 13, -14, 15, -16, 17, -18, 19, -20}, + sd::DataType::HALF); + NDArray c('f', {M, N}, sd::DataType::FLOAT32); + + NDArray exp('f', {M, N}, + {-1.9, -0.9, 0.1, 1.3, 0.3, -0.7, -0.7, 0.3, 1.3, 0.1, -0.9, -1.9, + 0.5, 1.5, 2.5}, + sd::DataType::FLOAT32); + + sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); + + ASSERT_TRUE(c.equalsTo(&exp, 0.01)); } ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests2, mmulMxM_17) { - - int devCnt = 0; - cudaGetDevice(&devCnt); - if(Environment::getInstance().capabilities()[devCnt].first() < 5) return; - - const Nd4jLong M = 3; - const Nd4jLong K = 4; - const Nd4jLong N = 5; - - NDArray a('c', {M,K}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::HALF); - NDArray b('c', {K,N}, {1,-2,3,-4,5,-6,7,-8,9,-10,11,-12,13,-14,15,-16,17,-18,19,-20}, sd::DataType::HALF); - NDArray c('c', {M,N}, sd::DataType::FLOAT32); - - NDArray exp('c', {M,N}, {-8.8, 8.6, -8.4, 8.2, -8.0, -4.3, 4.1, -3.9, 3.7, -3.5, 0.2, -0.4, 0.6, -0.8, 1.}, sd::DataType::FLOAT32); - - sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); - - ASSERT_TRUE(c.equalsTo(&exp, 0.01)); + int devCnt = 0; + cudaGetDevice(&devCnt); + if (Environment::getInstance().capabilities()[devCnt].first() < 5) return; + + const Nd4jLong M = 3; + const Nd4jLong K = 4; + const Nd4jLong N = 5; + + NDArray a('c', {M, K}, + {1.2, 1.1, 1.0, 0.9, 0.8, 0.7, 0.5, 0.4, 0.3, 0.2, 0.1, 0}, + sd::DataType::HALF); + NDArray b('c', {K, N}, {1, -2, 3, -4, 5, -6, 7, -8, 9, -10, + 11, -12, 13, -14, 15, -16, 17, -18, 19, -20}, + sd::DataType::HALF); + NDArray c('c', {M, N}, sd::DataType::FLOAT32); + + NDArray exp('c', {M, N}, + {-8.8, 8.6, -8.4, 8.2, -8.0, -4.3, 4.1, -3.9, 3.7, -3.5, 0.2, + -0.4, 0.6, -0.8, 1.}, + sd::DataType::FLOAT32); + + sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); + + ASSERT_TRUE(c.equalsTo(&exp, 0.01)); } ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests2, mmulMxM_18) { - - int devCnt = 0; - cudaGetDevice(&devCnt); - if(Environment::getInstance().capabilities()[devCnt].first() < 5.3) return; - - const Nd4jLong M = 3; - const Nd4jLong K = 4; - const Nd4jLong N = 5; - - NDArray a('f', {M,K}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::HALF); - NDArray b('f', {K,N}, {1,-2,3,-4,5,-6,7,-8,9,-10,11,-12,13,-14,15,-16,17,-18,19,-20}, sd::DataType::HALF); - NDArray c('f', {M,N}, sd::DataType::HALF); - - NDArray exp('f', {M,N}, {0.1, 0.3, 0.5, 2.5, 2.7, 2.9, 4.9, 5.1, 5.3, 7.3, 7.5, 7.7, 9.7, 9.9, 10.1}, sd::DataType::HALF); - - sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); - - ASSERT_TRUE(c.equalsTo(&exp, 1e-1)); + int devCnt = 0; + cudaGetDevice(&devCnt); + if (Environment::getInstance().capabilities()[devCnt].first() < 5.3) return; + + const Nd4jLong M = 3; + const Nd4jLong K = 4; + const Nd4jLong N = 5; + + NDArray a('f', {M, K}, + {1.2, 1.1, 1.0, 0.9, 0.8, 0.7, 0.5, 0.4, 0.3, 0.2, 0.1, 0}, + sd::DataType::HALF); + NDArray b('f', {K, N}, {1, -2, 3, -4, 5, -6, 7, -8, 9, -10, + 11, -12, 13, -14, 15, -16, 17, -18, 19, -20}, + sd::DataType::HALF); + NDArray c('f', {M, N}, sd::DataType::HALF); + + NDArray exp('f', {M, N}, + {0.1, 0.3, 0.5, 2.5, 2.7, 2.9, 4.9, 5.1, 5.3, 7.3, 7.5, 7.7, 9.7, + 9.9, 10.1}, + sd::DataType::HALF); + + sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); + + ASSERT_TRUE(c.equalsTo(&exp, 1e-1)); } ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests2, mmulMxM_19) { - - int devCnt = 0; - cudaGetDevice(&devCnt); - if(Environment::getInstance().capabilities()[devCnt].first() < 5.3) return; - - const Nd4jLong M = 3; - const Nd4jLong K = 4; - const Nd4jLong N = 5; - - NDArray a('f', {M,K}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::HALF); - NDArray b('c', {K,N}, {1,-2,3,-4,5,-6,7,-8,9,-10,11,-12,13,-14,15,-16,17,-18,19,-20}, sd::DataType::HALF); - NDArray c('f', {M,N}, sd::DataType::HALF); - - NDArray exp('f', {M,N}, {-1.9, -0.9, 0.1, 1.3, 0.3, -0.7, -0.7, 0.3, 1.3, 0.1, -0.9, -1.9, 0.5, 1.5, 2.5}, sd::DataType::HALF); - - sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); - - ASSERT_TRUE(c.equalsTo(&exp, 1e-1)); + int devCnt = 0; + cudaGetDevice(&devCnt); + if (Environment::getInstance().capabilities()[devCnt].first() < 5.3) return; + + const Nd4jLong M = 3; + const Nd4jLong K = 4; + const Nd4jLong N = 5; + + NDArray a('f', {M, K}, + {1.2, 1.1, 1.0, 0.9, 0.8, 0.7, 0.5, 0.4, 0.3, 0.2, 0.1, 0}, + sd::DataType::HALF); + NDArray b('c', {K, N}, {1, -2, 3, -4, 5, -6, 7, -8, 9, -10, + 11, -12, 13, -14, 15, -16, 17, -18, 19, -20}, + sd::DataType::HALF); + NDArray c('f', {M, N}, sd::DataType::HALF); + + NDArray exp('f', {M, N}, + {-1.9, -0.9, 0.1, 1.3, 0.3, -0.7, -0.7, 0.3, 1.3, 0.1, -0.9, -1.9, + 0.5, 1.5, 2.5}, + sd::DataType::HALF); + + sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); + + ASSERT_TRUE(c.equalsTo(&exp, 1e-1)); } ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests2, mmulMxM_20) { - - int devCnt = 0; - cudaGetDevice(&devCnt); - if(Environment::getInstance().capabilities()[devCnt].first() < 5.3) return; - - const Nd4jLong M = 3; - const Nd4jLong K = 4; - const Nd4jLong N = 5; - - NDArray a('c', {M,K}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::HALF); - NDArray b('c', {K,N}, {1,-2,3,-4,5,-6,7,-8,9,-10,11,-12,13,-14,15,-16,17,-18,19,-20}, sd::DataType::HALF); - NDArray c('c', {M,N}, sd::DataType::HALF); - - NDArray exp('c', {M,N}, {-8.8, 8.6, -8.4, 8.2, -8.0, -4.3, 4.1, -3.9, 3.7, -3.5, 0.2, -0.4, 0.6, -0.8, 1.}, sd::DataType::HALF); - - sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); - - ASSERT_TRUE(c.equalsTo(&exp, 1e-1)); + int devCnt = 0; + cudaGetDevice(&devCnt); + if (Environment::getInstance().capabilities()[devCnt].first() < 5.3) return; + + const Nd4jLong M = 3; + const Nd4jLong K = 4; + const Nd4jLong N = 5; + + NDArray a('c', {M, K}, + {1.2, 1.1, 1.0, 0.9, 0.8, 0.7, 0.5, 0.4, 0.3, 0.2, 0.1, 0}, + sd::DataType::HALF); + NDArray b('c', {K, N}, {1, -2, 3, -4, 5, -6, 7, -8, 9, -10, + 11, -12, 13, -14, 15, -16, 17, -18, 19, -20}, + sd::DataType::HALF); + NDArray c('c', {M, N}, sd::DataType::HALF); + + NDArray exp('c', {M, N}, + {-8.8, 8.6, -8.4, 8.2, -8.0, -4.3, 4.1, -3.9, 3.7, -3.5, 0.2, + -0.4, 0.6, -0.8, 1.}, + sd::DataType::HALF); + + sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); + + ASSERT_TRUE(c.equalsTo(&exp, 1e-1)); } /* ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests2, mmulMxM_21) { - const Nd4jLong M = 3; - const Nd4jLong K = 4; - const Nd4jLong N = 5; + const Nd4jLong M = 3; + const Nd4jLong K = 4; + const Nd4jLong N = 5; - NDArray a('c', {M,K}, {1.,2,3,4,5,6,7,8,9,10,11,12}, sd::DataType::INT8); - NDArray b('c', {K,N}, {-2,-3,0,1,5,-6,7,-8,9,-10,11,-12,13,-14,15,-16,17,-18,19,-20}, sd::DataType::FLOAT32); - NDArray c('c', {M,N}, sd::DataType::DOUBLE); + NDArray a('c', {M,K}, {1.,2,3,4,5,6,7,8,9,10,11,12}, +sd::DataType::INT8); NDArray b('c', {K,N}, +{-2,-3,0,1,5,-6,7,-8,9,-10,11,-12,13,-14,15,-16,17,-18,19,-20}, +sd::DataType::FLOAT32); NDArray c('c', {M,N}, sd::DataType::DOUBLE); - NDArray exp('c', {M,N}, {-45., 43., -49., 53., -50., -97., 79., -101., 113., -90., -149., 115., -153., 173., -130.}, sd::DataType::DOUBLE); + NDArray exp('c', {M,N}, {-45., 43., -49., 53., -50., -97., 79., -101., +113., -90., -149., 115., -153., 173., -130.}, sd::DataType::DOUBLE); - sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); + sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); - ASSERT_TRUE(c.equalsTo(&exp)); + ASSERT_TRUE(c.equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests2, mmulMxM_22) { - const Nd4jLong M = 3; - const Nd4jLong K = 4; - const Nd4jLong N = 5; + const Nd4jLong M = 3; + const Nd4jLong K = 4; + const Nd4jLong N = 5; - NDArray a('c', {M,K}, {1.,2,3,4,5,6,7,8,9,10,11,12}, sd::DataType::FLOAT32); - NDArray b('c', {K,N}, {-2,-3,0,1,5,-6,7,-8,9,-10,11,-12,13,-14,15,-16,17,-18,19,-20}, sd::DataType::HALF); - NDArray c('c', {M,N}, sd::DataType::FLOAT32); + NDArray a('c', {M,K}, {1.,2,3,4,5,6,7,8,9,10,11,12}, +sd::DataType::FLOAT32); NDArray b('c', {K,N}, +{-2,-3,0,1,5,-6,7,-8,9,-10,11,-12,13,-14,15,-16,17,-18,19,-20}, +sd::DataType::HALF); NDArray c('c', {M,N}, sd::DataType::FLOAT32); - NDArray exp('c', {M,N}, {-45., 43., -49., 53., -50., -97., 79., -101., 113., -90., -149., 115., -153., 173., -130.}, sd::DataType::FLOAT32); + NDArray exp('c', {M,N}, {-45., 43., -49., 53., -50., -97., 79., -101., +113., -90., -149., 115., -153., 173., -130.}, sd::DataType::FLOAT32); - sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); - // c.printBuffer(); + sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); + // c.printBuffer(); - ASSERT_TRUE(c.equalsTo(&exp)); + ASSERT_TRUE(c.equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests2, mmulMxM_23) { - const Nd4jLong M = 3; - const Nd4jLong K = 4; - const Nd4jLong N = 5; + const Nd4jLong M = 3; + const Nd4jLong K = 4; + const Nd4jLong N = 5; - NDArray a('c', {M,K}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::HALF); - NDArray b('c', {K,N}, {1,-2,3,-4,5,-6,7,-8,9,-10,11,-12,13,-14,15,-16,17,-18,19,-20}, sd::DataType::HALF); - NDArray c('c', {M,N}, sd::DataType::DOUBLE); + NDArray a('c', {M,K}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, +sd::DataType::HALF); NDArray b('c', {K,N}, +{1,-2,3,-4,5,-6,7,-8,9,-10,11,-12,13,-14,15,-16,17,-18,19,-20}, +sd::DataType::HALF); NDArray c('c', {M,N}, sd::DataType::DOUBLE); - NDArray exp('c', {M,N}, {-8.8, 8.6, -8.4, 8.2, -8.0, -4.3, 4.1, -3.9, 3.7, -3.5, 0.2, -0.4, 0.6, -0.8, 1.}, sd::DataType::DOUBLE); + NDArray exp('c', {M,N}, {-8.8, 8.6, -8.4, 8.2, -8.0, -4.3, 4.1, +-3.9, 3.7, -3.5, 0.2, -0.4, 0.6, -0.8, 1.}, sd::DataType::DOUBLE); - sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); + sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); - ASSERT_TRUE(c.equalsTo(&exp, 0.01)); + ASSERT_TRUE(c.equalsTo(&exp, 0.01)); } ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests2, mmulMxM_24) { - const Nd4jLong M = 3; - const Nd4jLong K = 4; - const Nd4jLong N = 5; + const Nd4jLong M = 3; + const Nd4jLong K = 4; + const Nd4jLong N = 5; - NDArray a('c', {M,K}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::HALF); - NDArray b('c', {K,N}, {1,-2,3,-4,5,-6,7,-8,9,-10,11,-12,13,-14,15,-16,17,-18,19,-20}, sd::DataType::FLOAT32); - NDArray c('c', {M,N}, sd::DataType::DOUBLE); + NDArray a('c', {M,K}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, +sd::DataType::HALF); NDArray b('c', {K,N}, +{1,-2,3,-4,5,-6,7,-8,9,-10,11,-12,13,-14,15,-16,17,-18,19,-20}, +sd::DataType::FLOAT32); NDArray c('c', {M,N}, sd::DataType::DOUBLE); - NDArray exp('c', {M,N}, {-8.8, 8.6, -8.4, 8.2, -8.0, -4.3, 4.1, -3.9, 3.7, -3.5, 0.2, -0.4, 0.6, -0.8, 1.}, sd::DataType::DOUBLE); + NDArray exp('c', {M,N}, {-8.8, 8.6, -8.4, 8.2, -8.0, -4.3, 4.1, +-3.9, 3.7, -3.5, 0.2, -0.4, 0.6, -0.8, 1.}, sd::DataType::DOUBLE); - sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); + sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); - ASSERT_TRUE(c.equalsTo(&exp, 0.01)); + ASSERT_TRUE(c.equalsTo(&exp, 0.01)); } ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests2, mmulMxM_25) { - const Nd4jLong M = 3; - const Nd4jLong K = 4; - const Nd4jLong N = 5; + const Nd4jLong M = 3; + const Nd4jLong K = 4; + const Nd4jLong N = 5; - NDArray a('c', {M,K}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); - NDArray b('c', {K,N}, {1,-2,3,-4,5,-6,7,-8,9,-10,11,-12,13,-14,15,-16,17,-18,19,-20}, sd::DataType::FLOAT32); - NDArray c('c', {M,N}, sd::DataType::HALF); + NDArray a('c', {M,K}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, +sd::DataType::DOUBLE); NDArray b('c', {K,N}, +{1,-2,3,-4,5,-6,7,-8,9,-10,11,-12,13,-14,15,-16,17,-18,19,-20}, +sd::DataType::FLOAT32); NDArray c('c', {M,N}, sd::DataType::HALF); - NDArray exp('c', {M,N}, {-8.8, 8.6, -8.4, 8.2, -8.0, -4.3, 4.1, -3.9, 3.7, -3.5, 0.2, -0.4, 0.6, -0.8, 1.}, sd::DataType::HALF); + NDArray exp('c', {M,N}, {-8.8, 8.6, -8.4, 8.2, -8.0, -4.3, 4.1, +-3.9, 3.7, -3.5, 0.2, -0.4, 0.6, -0.8, 1.}, sd::DataType::HALF); - sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); + sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); - ASSERT_TRUE(c.equalsTo(&exp, 0.01)); + ASSERT_TRUE(c.equalsTo(&exp, 0.01)); } ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests2, mmulMxM_26) { - const Nd4jLong M = 3; - const Nd4jLong K = 4; - const Nd4jLong N = 5; + const Nd4jLong M = 3; + const Nd4jLong K = 4; + const Nd4jLong N = 5; - // 3x4 * 4x5 = 3x5 - NDArray a('c', {M,K}, {1.,2,3,4,5,6,7,8,9,10,11,12}, sd::DataType::INT64); - NDArray b('c', {K,N}, {-2,-3,0,1,5,-6,7,-8,9,-10,11,-12,13,-14,15,-16,17,-18,19,-20}, sd::DataType::FLOAT32); - NDArray c('c', {M,N}, sd::DataType::DOUBLE); + // 3x4 * 4x5 = 3x5 + NDArray a('c', {M,K}, {1.,2,3,4,5,6,7,8,9,10,11,12}, +sd::DataType::INT64); NDArray b('c', {K,N}, +{-2,-3,0,1,5,-6,7,-8,9,-10,11,-12,13,-14,15,-16,17,-18,19,-20}, +sd::DataType::FLOAT32); NDArray c('c', {M,N}, sd::DataType::DOUBLE); - NDArray exp('c', {M,N}, {-45., 43., -49., 53., -50., -97., 79., -101., 113., -90., -149., 115., -153., 173., -130.}, sd::DataType::DOUBLE); + NDArray exp('c', {M,N}, {-45., 43., -49., 53., -50., -97., 79., -101., +113., -90., -149., 115., -153., 173., -130.}, sd::DataType::DOUBLE); - sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); - // c.printBuffer(); + sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); + // c.printBuffer(); - ASSERT_TRUE(c.equalsTo(&exp)); + ASSERT_TRUE(c.equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests2, mmulMxM_27) { - const Nd4jLong M = 3; - const Nd4jLong K = 4; - const Nd4jLong N = 5; + const Nd4jLong M = 3; + const Nd4jLong K = 4; + const Nd4jLong N = 5; - NDArray a('f', {M,K}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::HALF); - NDArray b('f', {K,N}, {1,-2,3,-4,5,-6,7,-8,9,-10,11,-12,13,-14,15,-16,17,-18,19,-20}, sd::DataType::DOUBLE); - NDArray c('f', {M,N}, sd::DataType::FLOAT32); + NDArray a('f', {M,K}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, +sd::DataType::HALF); NDArray b('f', {K,N}, +{1,-2,3,-4,5,-6,7,-8,9,-10,11,-12,13,-14,15,-16,17,-18,19,-20}, +sd::DataType::DOUBLE); NDArray c('f', {M,N}, sd::DataType::FLOAT32); - NDArray exp('f', {M,N}, {0.1, 0.3, 0.5, 2.5, 2.7, 2.9, 4.9, 5.1, 5.3, 7.3, 7.5, 7.7, 9.7, 9.9, 10.1}, sd::DataType::FLOAT32); + NDArray exp('f', {M,N}, {0.1, 0.3, +0.5, 2.5, 2.7, 2.9, 4.9, 5.1, 5.3, 7.3, 7.5, 7.7, 9.7, 9.9, 10.1}, +sd::DataType::FLOAT32); - sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); - // c.printBuffer(); + sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); + // c.printBuffer(); - ASSERT_TRUE(c.equalsTo(&exp, 0.01)); + ASSERT_TRUE(c.equalsTo(&exp, 0.01)); } ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests2, mmulMxM_28) { - const Nd4jLong M = 3; - const Nd4jLong K = 4; - const Nd4jLong N = 5; + const Nd4jLong M = 3; + const Nd4jLong K = 4; + const Nd4jLong N = 5; - NDArray a('c', {M,K}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::FLOAT32); - NDArray b('f', {K,N}, {1,-2,3,-4,5,-6,7,-8,9,-10,11,-12,13,-14,15,-16,17,-18,19,-20}, sd::DataType::DOUBLE); - NDArray c('f', {M,N}, sd::DataType::FLOAT32); + NDArray a('c', {M,K}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, +sd::DataType::FLOAT32); NDArray b('f', {K,N}, +{1,-2,3,-4,5,-6,7,-8,9,-10,11,-12,13,-14,15,-16,17,-18,19,-20}, +sd::DataType::DOUBLE); NDArray c('f', {M,N}, sd::DataType::FLOAT32); - NDArray exp('f', {M,N}, {-1.6, -0.7, 0.2, -0.8, 0.1, 1., -0., 0.9, 1.8, 0.8, 1.7, 2.6, 1.6, 2.5, 3.4}, sd::DataType::FLOAT32); + NDArray exp('f', {M,N}, {-1.6, -0.7, 0.2, -0.8, 0.1, 1., -0., 0.9, 1.8, +0.8, 1.7, 2.6, 1.6, 2.5, 3.4}, sd::DataType::FLOAT32); - sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); + sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); - ASSERT_TRUE(c.equalsTo(&exp)); + ASSERT_TRUE(c.equalsTo(&exp)); } */ ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests2, mmulMxV_1) { + const Nd4jLong M = 3; + const Nd4jLong N = 4; - const Nd4jLong M = 3; - const Nd4jLong N = 4; - - NDArray a('f', {M,N}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); - NDArray x('f', {N}, {1,-2,3,-4}, sd::DataType::DOUBLE); - NDArray y('f', {M}, sd::DataType::DOUBLE); + NDArray a('f', {M, N}, + {1.2, 1.1, 1.0, 0.9, 0.8, 0.7, 0.5, 0.4, 0.3, 0.2, 0.1, 0}, + sd::DataType::DOUBLE); + NDArray x('f', {N}, {1, -2, 3, -4}, sd::DataType::DOUBLE); + NDArray y('f', {M}, sd::DataType::DOUBLE); - NDArray exp('f', {M}, {0.1, 0.3, 0.5}, sd::DataType::DOUBLE); + NDArray exp('f', {M}, {0.1, 0.3, 0.5}, sd::DataType::DOUBLE); - sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); - ASSERT_TRUE(y.equalsTo(&exp)); + sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); + ASSERT_TRUE(y.equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests2, mmulMxV_2) { + const Nd4jLong M = 3; + const Nd4jLong N = 4; - const Nd4jLong M = 3; - const Nd4jLong N = 4; + NDArray a('c', {M, N}, + {1.2, 1.1, 1.0, 0.9, 0.8, 0.7, 0.5, 0.4, 0.3, 0.2, 0.1, 0}, + sd::DataType::DOUBLE); + NDArray x('f', {N}, {1, -2, 3, -4}, sd::DataType::DOUBLE); + NDArray y('f', {M}, sd::DataType::DOUBLE); - NDArray a('c', {M,N}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); - NDArray x('f', {N}, {1,-2,3,-4}, sd::DataType::DOUBLE); - NDArray y('f', {M}, sd::DataType::DOUBLE); + NDArray exp('f', {M}, {-1.6, -0.7, 0.2}, sd::DataType::DOUBLE); - NDArray exp('f', {M}, {-1.6, -0.7, 0.2}, sd::DataType::DOUBLE); - - sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); - ASSERT_TRUE(y.equalsTo(&exp)); + sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); + ASSERT_TRUE(y.equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests2, mmulMxV_3) { + const Nd4jLong M = 3; + const Nd4jLong N = 4; - const Nd4jLong M = 3; - const Nd4jLong N = 4; - - NDArray a('c', {M,N}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); - NDArray x('c', {N}, {1,-2,3,-4}, sd::DataType::DOUBLE); - NDArray y('f', {M}, sd::DataType::DOUBLE); + NDArray a('c', {M, N}, + {1.2, 1.1, 1.0, 0.9, 0.8, 0.7, 0.5, 0.4, 0.3, 0.2, 0.1, 0}, + sd::DataType::DOUBLE); + NDArray x('c', {N}, {1, -2, 3, -4}, sd::DataType::DOUBLE); + NDArray y('f', {M}, sd::DataType::DOUBLE); - NDArray exp('f', {M}, {-1.6, -0.7, 0.2}, sd::DataType::DOUBLE); + NDArray exp('f', {M}, {-1.6, -0.7, 0.2}, sd::DataType::DOUBLE); - sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); - ASSERT_TRUE(y.equalsTo(&exp)); + sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); + ASSERT_TRUE(y.equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests2, mmulMxV_4) { + const Nd4jLong M = 3; + const Nd4jLong N = 4; - const Nd4jLong M = 3; - const Nd4jLong N = 4; + NDArray a('c', {M, N}, + {1.2, 1.1, 1.0, 0.9, 0.8, 0.7, 0.5, 0.4, 0.3, 0.2, 0.1, 0}, + sd::DataType::DOUBLE); + NDArray x('c', {N}, {1, -2, 3, -4}, sd::DataType::DOUBLE); + NDArray y('c', {M}, sd::DataType::DOUBLE); - NDArray a('c', {M,N}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); - NDArray x('c', {N}, {1,-2,3,-4}, sd::DataType::DOUBLE); - NDArray y('c', {M}, sd::DataType::DOUBLE); + NDArray exp('c', {M}, {-1.6, -0.7, 0.2}, sd::DataType::DOUBLE); - NDArray exp('c', {M}, {-1.6, -0.7, 0.2}, sd::DataType::DOUBLE); - - sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); - ASSERT_TRUE(y.equalsTo(&exp)); + sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); + ASSERT_TRUE(y.equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests2, mmulMxV_5) { + const Nd4jLong M = 3; + const Nd4jLong N = 4; - const Nd4jLong M = 3; - const Nd4jLong N = 4; - - NDArray a('f', {M,N}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); - NDArray x('c', {N}, {1,-2,3,-4}, sd::DataType::DOUBLE); - NDArray y('c', {M}, sd::DataType::DOUBLE); + NDArray a('f', {M, N}, + {1.2, 1.1, 1.0, 0.9, 0.8, 0.7, 0.5, 0.4, 0.3, 0.2, 0.1, 0}, + sd::DataType::DOUBLE); + NDArray x('c', {N}, {1, -2, 3, -4}, sd::DataType::DOUBLE); + NDArray y('c', {M}, sd::DataType::DOUBLE); - NDArray exp('c', {M}, {0.1, 0.3, 0.5}, sd::DataType::DOUBLE); + NDArray exp('c', {M}, {0.1, 0.3, 0.5}, sd::DataType::DOUBLE); - sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); - ASSERT_TRUE(y.equalsTo(&exp)); + sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); + ASSERT_TRUE(y.equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests2, mmulMxV_6) { - - const Nd4jLong M = 3; - const Nd4jLong N = 4; - - NDArray a('f', {M,N}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); - NDArray temp('f', {M,N,5}, {16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::DOUBLE); - NDArray x = temp(6, {0,2}); - NDArray y('f', {M}, sd::DataType::DOUBLE); - - NDArray exp('f', {M}, {5.5, 5.1, 4.7}, sd::DataType::DOUBLE); - - sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); - ASSERT_TRUE(y.equalsTo(&exp)); + const Nd4jLong M = 3; + const Nd4jLong N = 4; + + NDArray a('f', {M, N}, + {1.2, 1.1, 1.0, 0.9, 0.8, 0.7, 0.5, 0.4, 0.3, 0.2, 0.1, 0}, + sd::DataType::DOUBLE); + NDArray temp('f', {M, N, 5}, + {16, 2, -6, 7, 2, -2, 4, -7, 6, 4, 4, 6, -3, 1, 3, + 9, 1, 4, 9, 10, -10, -3, -8, 7, -7, -7, 6, 9, 7, -6, + 8, 7, -3, -3, 4, -2, 5, -3, -3, 4, 6, -5, -1, 7, -5, + 4, -10, -1, 8, 0, -7, 4, -10, -7, -8, -9, 2, 9, 7, 9}, + sd::DataType::DOUBLE); + NDArray x = temp(6, {0, 2}); + NDArray y('f', {M}, sd::DataType::DOUBLE); + + NDArray exp('f', {M}, {5.5, 5.1, 4.7}, sd::DataType::DOUBLE); + + sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); + ASSERT_TRUE(y.equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests2, mmulMxV_7) { - - const Nd4jLong M = 3; - const Nd4jLong N = 4; - - NDArray a('f', {N,M}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); - a.permutei({1,0}); - NDArray temp('f', {M,N,5}, {16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::DOUBLE); - NDArray x = temp(6, {0,2}); - NDArray y('f', {M}, sd::DataType::DOUBLE); - - NDArray exp('f', {M}, {5.1, 3.3, 1.5}, sd::DataType::DOUBLE); - - sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); - ASSERT_TRUE(y.equalsTo(&exp)); + const Nd4jLong M = 3; + const Nd4jLong N = 4; + + NDArray a('f', {N, M}, + {1.2, 1.1, 1.0, 0.9, 0.8, 0.7, 0.5, 0.4, 0.3, 0.2, 0.1, 0}, + sd::DataType::DOUBLE); + a.permutei({1, 0}); + NDArray temp('f', {M, N, 5}, + {16, 2, -6, 7, 2, -2, 4, -7, 6, 4, 4, 6, -3, 1, 3, + 9, 1, 4, 9, 10, -10, -3, -8, 7, -7, -7, 6, 9, 7, -6, + 8, 7, -3, -3, 4, -2, 5, -3, -3, 4, 6, -5, -1, 7, -5, + 4, -10, -1, 8, 0, -7, 4, -10, -7, -8, -9, 2, 9, 7, 9}, + sd::DataType::DOUBLE); + NDArray x = temp(6, {0, 2}); + NDArray y('f', {M}, sd::DataType::DOUBLE); + + NDArray exp('f', {M}, {5.1, 3.3, 1.5}, sd::DataType::DOUBLE); + + sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); + ASSERT_TRUE(y.equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests2, mmulMxV_8) { - - const Nd4jLong M = 3; - const Nd4jLong N = 4; - - NDArray a('f', {N,M}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); - a.permutei({1,0}); - NDArray temp('f', {N,M,5}, {16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::DOUBLE); - NDArray x = temp(4, {1,2}); - NDArray y('f', {M}, sd::DataType::DOUBLE); - - NDArray exp('f', {M}, {6.2, 4.5, 1.7}, sd::DataType::DOUBLE); - - sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); - ASSERT_TRUE(y.equalsTo(&exp)); + const Nd4jLong M = 3; + const Nd4jLong N = 4; + + NDArray a('f', {N, M}, + {1.2, 1.1, 1.0, 0.9, 0.8, 0.7, 0.5, 0.4, 0.3, 0.2, 0.1, 0}, + sd::DataType::DOUBLE); + a.permutei({1, 0}); + NDArray temp('f', {N, M, 5}, + {16, 2, -6, 7, 2, -2, 4, -7, 6, 4, 4, 6, -3, 1, 3, + 9, 1, 4, 9, 10, -10, -3, -8, 7, -7, -7, 6, 9, 7, -6, + 8, 7, -3, -3, 4, -2, 5, -3, -3, 4, 6, -5, -1, 7, -5, + 4, -10, -1, 8, 0, -7, 4, -10, -7, -8, -9, 2, 9, 7, 9}, + sd::DataType::DOUBLE); + NDArray x = temp(4, {1, 2}); + NDArray y('f', {M}, sd::DataType::DOUBLE); + + NDArray exp('f', {M}, {6.2, 4.5, 1.7}, sd::DataType::DOUBLE); + + sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); + ASSERT_TRUE(y.equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests2, mmulMxV_9) { - - const Nd4jLong M = 3; - const Nd4jLong N = 4; - - NDArray a('f', {N,M}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); - a.permutei({1,0}); - NDArray temp('f', {5,M,N}, {16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::DOUBLE); - NDArray x = temp(3, {0,1}); - NDArray y('f', {M}, sd::DataType::DOUBLE); - - NDArray exp('f', {M}, {1.5, 1.8, 1.5}, sd::DataType::DOUBLE); - - sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); - ASSERT_TRUE(y.equalsTo(&exp)); + const Nd4jLong M = 3; + const Nd4jLong N = 4; + + NDArray a('f', {N, M}, + {1.2, 1.1, 1.0, 0.9, 0.8, 0.7, 0.5, 0.4, 0.3, 0.2, 0.1, 0}, + sd::DataType::DOUBLE); + a.permutei({1, 0}); + NDArray temp('f', {5, M, N}, + {16, 2, -6, 7, 2, -2, 4, -7, 6, 4, 4, 6, -3, 1, 3, + 9, 1, 4, 9, 10, -10, -3, -8, 7, -7, -7, 6, 9, 7, -6, + 8, 7, -3, -3, 4, -2, 5, -3, -3, 4, 6, -5, -1, 7, -5, + 4, -10, -1, 8, 0, -7, 4, -10, -7, -8, -9, 2, 9, 7, 9}, + sd::DataType::DOUBLE); + NDArray x = temp(3, {0, 1}); + NDArray y('f', {M}, sd::DataType::DOUBLE); + + NDArray exp('f', {M}, {1.5, 1.8, 1.5}, sd::DataType::DOUBLE); + + sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); + ASSERT_TRUE(y.equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests2, mmulMxV_10) { - - const Nd4jLong M = 3; - const Nd4jLong N = 4; - - NDArray a('c', {N,M}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); - a.permutei({1,0}); - NDArray temp('f', {5,M,N}, {16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::DOUBLE); - NDArray x = temp(2, {0,1}); - NDArray y('f', {M}, sd::DataType::DOUBLE); - - NDArray exp('f', {M}, {-0.3, 0.3, 0.9}, sd::DataType::DOUBLE); - - sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); - ASSERT_TRUE(y.equalsTo(&exp)); + const Nd4jLong M = 3; + const Nd4jLong N = 4; + + NDArray a('c', {N, M}, + {1.2, 1.1, 1.0, 0.9, 0.8, 0.7, 0.5, 0.4, 0.3, 0.2, 0.1, 0}, + sd::DataType::DOUBLE); + a.permutei({1, 0}); + NDArray temp('f', {5, M, N}, + {16, 2, -6, 7, 2, -2, 4, -7, 6, 4, 4, 6, -3, 1, 3, + 9, 1, 4, 9, 10, -10, -3, -8, 7, -7, -7, 6, 9, 7, -6, + 8, 7, -3, -3, 4, -2, 5, -3, -3, 4, 6, -5, -1, 7, -5, + 4, -10, -1, 8, 0, -7, 4, -10, -7, -8, -9, 2, 9, 7, 9}, + sd::DataType::DOUBLE); + NDArray x = temp(2, {0, 1}); + NDArray y('f', {M}, sd::DataType::DOUBLE); + + NDArray exp('f', {M}, {-0.3, 0.3, 0.9}, sd::DataType::DOUBLE); + + sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); + ASSERT_TRUE(y.equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests2, mmulMxV_11) { - - const Nd4jLong M = 3; - const Nd4jLong N = 4; - - NDArray a('c', {N,M}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); - a.permutei({1,0}); - NDArray temp('c', {5,N,M}, {16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::DOUBLE); - NDArray x = temp(13, {0,2}); - NDArray y('f', {M}, sd::DataType::DOUBLE); - - NDArray exp('f', {M}, {-12.1, -10.9, -9.7}, sd::DataType::DOUBLE); - - sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); - ASSERT_TRUE(y.equalsTo(&exp)); + const Nd4jLong M = 3; + const Nd4jLong N = 4; + + NDArray a('c', {N, M}, + {1.2, 1.1, 1.0, 0.9, 0.8, 0.7, 0.5, 0.4, 0.3, 0.2, 0.1, 0}, + sd::DataType::DOUBLE); + a.permutei({1, 0}); + NDArray temp('c', {5, N, M}, + {16, 2, -6, 7, 2, -2, 4, -7, 6, 4, 4, 6, -3, 1, 3, + 9, 1, 4, 9, 10, -10, -3, -8, 7, -7, -7, 6, 9, 7, -6, + 8, 7, -3, -3, 4, -2, 5, -3, -3, 4, 6, -5, -1, 7, -5, + 4, -10, -1, 8, 0, -7, 4, -10, -7, -8, -9, 2, 9, 7, 9}, + sd::DataType::DOUBLE); + NDArray x = temp(13, {0, 2}); + NDArray y('f', {M}, sd::DataType::DOUBLE); + + NDArray exp('f', {M}, {-12.1, -10.9, -9.7}, sd::DataType::DOUBLE); + + sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); + ASSERT_TRUE(y.equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests2, mmulMxV_12) { - - const Nd4jLong M = 3; - const Nd4jLong N = 4; - - NDArray a('c', {N,M}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); - a.permutei({1,0}); - NDArray temp('c', {5,N,M}, {16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::DOUBLE); - NDArray x = temp(10, {0,2}); - NDArray y('c', {M}, sd::DataType::DOUBLE); - - NDArray exp('c', {M}, {3.3, 3.3, 3.3}, sd::DataType::DOUBLE); - - sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); - ASSERT_TRUE(y.equalsTo(&exp)); + const Nd4jLong M = 3; + const Nd4jLong N = 4; + + NDArray a('c', {N, M}, + {1.2, 1.1, 1.0, 0.9, 0.8, 0.7, 0.5, 0.4, 0.3, 0.2, 0.1, 0}, + sd::DataType::DOUBLE); + a.permutei({1, 0}); + NDArray temp('c', {5, N, M}, + {16, 2, -6, 7, 2, -2, 4, -7, 6, 4, 4, 6, -3, 1, 3, + 9, 1, 4, 9, 10, -10, -3, -8, 7, -7, -7, 6, 9, 7, -6, + 8, 7, -3, -3, 4, -2, 5, -3, -3, 4, 6, -5, -1, 7, -5, + 4, -10, -1, 8, 0, -7, 4, -10, -7, -8, -9, 2, 9, 7, 9}, + sd::DataType::DOUBLE); + NDArray x = temp(10, {0, 2}); + NDArray y('c', {M}, sd::DataType::DOUBLE); + + NDArray exp('c', {M}, {3.3, 3.3, 3.3}, sd::DataType::DOUBLE); + + sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); + ASSERT_TRUE(y.equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests2, mmulMxV_13) { - - const Nd4jLong M = 3; - const Nd4jLong N = 4; - - NDArray a('c', {N,M}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); - a.permutei({1,0}); - NDArray temp('f', {5,M,N}, {16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::DOUBLE); - NDArray x = temp(2, {0,1}, true); - NDArray y('f', {M}, sd::DataType::DOUBLE); - - NDArray exp('f', {M}, {-0.3, 0.3, 0.9}, sd::DataType::DOUBLE); - - sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); - ASSERT_TRUE(y.equalsTo(&exp)); + const Nd4jLong M = 3; + const Nd4jLong N = 4; + + NDArray a('c', {N, M}, + {1.2, 1.1, 1.0, 0.9, 0.8, 0.7, 0.5, 0.4, 0.3, 0.2, 0.1, 0}, + sd::DataType::DOUBLE); + a.permutei({1, 0}); + NDArray temp('f', {5, M, N}, + {16, 2, -6, 7, 2, -2, 4, -7, 6, 4, 4, 6, -3, 1, 3, + 9, 1, 4, 9, 10, -10, -3, -8, 7, -7, -7, 6, 9, 7, -6, + 8, 7, -3, -3, 4, -2, 5, -3, -3, 4, 6, -5, -1, 7, -5, + 4, -10, -1, 8, 0, -7, 4, -10, -7, -8, -9, 2, 9, 7, 9}, + sd::DataType::DOUBLE); + NDArray x = temp(2, {0, 1}, true); + NDArray y('f', {M}, sd::DataType::DOUBLE); + + NDArray exp('f', {M}, {-0.3, 0.3, 0.9}, sd::DataType::DOUBLE); + + sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); + ASSERT_TRUE(y.equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests2, mmulMxV_14) { - - const Nd4jLong M = 3; - const Nd4jLong N = 4; - - NDArray a('c', {N,M}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); - a.permutei({1,0}); - NDArray temp('c', {5,N,M}, {16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::DOUBLE); - NDArray x = temp(10, {0,2}, true); - NDArray y('c', {M}, sd::DataType::DOUBLE); - - NDArray exp('c', {M}, {3.3, 3.3, 3.3}, sd::DataType::DOUBLE); - - sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); - ASSERT_TRUE(y.equalsTo(&exp)); + const Nd4jLong M = 3; + const Nd4jLong N = 4; + + NDArray a('c', {N, M}, + {1.2, 1.1, 1.0, 0.9, 0.8, 0.7, 0.5, 0.4, 0.3, 0.2, 0.1, 0}, + sd::DataType::DOUBLE); + a.permutei({1, 0}); + NDArray temp('c', {5, N, M}, + {16, 2, -6, 7, 2, -2, 4, -7, 6, 4, 4, 6, -3, 1, 3, + 9, 1, 4, 9, 10, -10, -3, -8, 7, -7, -7, 6, 9, 7, -6, + 8, 7, -3, -3, 4, -2, 5, -3, -3, 4, 6, -5, -1, 7, -5, + 4, -10, -1, 8, 0, -7, 4, -10, -7, -8, -9, 2, 9, 7, 9}, + sd::DataType::DOUBLE); + NDArray x = temp(10, {0, 2}, true); + NDArray y('c', {M}, sd::DataType::DOUBLE); + + NDArray exp('c', {M}, {3.3, 3.3, 3.3}, sd::DataType::DOUBLE); + + sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); + ASSERT_TRUE(y.equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests2, mmulMxV_15) { - - const Nd4jLong M = 3; - const Nd4jLong N = 4; - - NDArray a('c', {N,M}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); - a.permutei({1,0}); - NDArray temp('f', {5,M,N}, {16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::DOUBLE); - NDArray x = temp(2, {0,1}); - NDArray y = temp(17, {0,2}); - - NDArray exp('f', {M}, {-0.3, 0.3, 0.9}, sd::DataType::DOUBLE); - - sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); - ASSERT_TRUE(y.equalsTo(&exp)); + const Nd4jLong M = 3; + const Nd4jLong N = 4; + + NDArray a('c', {N, M}, + {1.2, 1.1, 1.0, 0.9, 0.8, 0.7, 0.5, 0.4, 0.3, 0.2, 0.1, 0}, + sd::DataType::DOUBLE); + a.permutei({1, 0}); + NDArray temp('f', {5, M, N}, + {16, 2, -6, 7, 2, -2, 4, -7, 6, 4, 4, 6, -3, 1, 3, + 9, 1, 4, 9, 10, -10, -3, -8, 7, -7, -7, 6, 9, 7, -6, + 8, 7, -3, -3, 4, -2, 5, -3, -3, 4, 6, -5, -1, 7, -5, + 4, -10, -1, 8, 0, -7, 4, -10, -7, -8, -9, 2, 9, 7, 9}, + sd::DataType::DOUBLE); + NDArray x = temp(2, {0, 1}); + NDArray y = temp(17, {0, 2}); + + NDArray exp('f', {M}, {-0.3, 0.3, 0.9}, sd::DataType::DOUBLE); + + sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); + ASSERT_TRUE(y.equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests2, mmulMxV_16) { - - const Nd4jLong M = 3; - const Nd4jLong N = 4; - - NDArray a('c', {N,M}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); - a.permutei({1,0}); - NDArray temp('f', {5,M,N}, {16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::DOUBLE); - NDArray temp1('c', {5,M,N}, {16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::DOUBLE); - NDArray x = temp(2, {0,1}); - NDArray y = temp1(17, {0,2}); - - NDArray exp('c', {M}, {-0.3, 0.3, 0.9}, sd::DataType::DOUBLE); - - sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); - ASSERT_TRUE(y.equalsTo(&exp)); + const Nd4jLong M = 3; + const Nd4jLong N = 4; + + NDArray a('c', {N, M}, + {1.2, 1.1, 1.0, 0.9, 0.8, 0.7, 0.5, 0.4, 0.3, 0.2, 0.1, 0}, + sd::DataType::DOUBLE); + a.permutei({1, 0}); + NDArray temp('f', {5, M, N}, + {16, 2, -6, 7, 2, -2, 4, -7, 6, 4, 4, 6, -3, 1, 3, + 9, 1, 4, 9, 10, -10, -3, -8, 7, -7, -7, 6, 9, 7, -6, + 8, 7, -3, -3, 4, -2, 5, -3, -3, 4, 6, -5, -1, 7, -5, + 4, -10, -1, 8, 0, -7, 4, -10, -7, -8, -9, 2, 9, 7, 9}, + sd::DataType::DOUBLE); + NDArray temp1('c', {5, M, N}, + {16, 2, -6, 7, 2, -2, 4, -7, 6, 4, 4, 6, -3, 1, 3, + 9, 1, 4, 9, 10, -10, -3, -8, 7, -7, -7, 6, 9, 7, -6, + 8, 7, -3, -3, 4, -2, 5, -3, -3, 4, 6, -5, -1, 7, -5, + 4, -10, -1, 8, 0, -7, 4, -10, -7, -8, -9, 2, 9, 7, 9}, + sd::DataType::DOUBLE); + NDArray x = temp(2, {0, 1}); + NDArray y = temp1(17, {0, 2}); + + NDArray exp('c', {M}, {-0.3, 0.3, 0.9}, sd::DataType::DOUBLE); + + sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); + ASSERT_TRUE(y.equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests2, mmulMxV_17) { - - const Nd4jLong M = 3; - const Nd4jLong N = 4; - - NDArray a('c', {N,M}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); - a.permutei({1,0}); - NDArray temp('f', {5,M,N}, {16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::DOUBLE); - NDArray x = temp(2, {0,1}); - NDArray y = temp(17, {0,2}, true); - // y.printShapeInfo(); - - NDArray exp('f', {1,M,1}, {-0.3, 0.3, 0.9}, sd::DataType::DOUBLE); - - sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); - ASSERT_TRUE(y.equalsTo(&exp)); + const Nd4jLong M = 3; + const Nd4jLong N = 4; + + NDArray a('c', {N, M}, + {1.2, 1.1, 1.0, 0.9, 0.8, 0.7, 0.5, 0.4, 0.3, 0.2, 0.1, 0}, + sd::DataType::DOUBLE); + a.permutei({1, 0}); + NDArray temp('f', {5, M, N}, + {16, 2, -6, 7, 2, -2, 4, -7, 6, 4, 4, 6, -3, 1, 3, + 9, 1, 4, 9, 10, -10, -3, -8, 7, -7, -7, 6, 9, 7, -6, + 8, 7, -3, -3, 4, -2, 5, -3, -3, 4, 6, -5, -1, 7, -5, + 4, -10, -1, 8, 0, -7, 4, -10, -7, -8, -9, 2, 9, 7, 9}, + sd::DataType::DOUBLE); + NDArray x = temp(2, {0, 1}); + NDArray y = temp(17, {0, 2}, true); + // y.printShapeInfo(); + + NDArray exp('f', {1, M, 1}, {-0.3, 0.3, 0.9}, sd::DataType::DOUBLE); + + sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); + ASSERT_TRUE(y.equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests2, mmulMxV_18) { - - const Nd4jLong M = 3; - const Nd4jLong N = 4; - - NDArray a('c', {N,M}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); - a.permutei({1,0}); - NDArray temp('f', {5,M,N}, {16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::DOUBLE); - NDArray temp1('c', {5,M,N}, {16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::DOUBLE); - NDArray x = temp(2, {0,1},true); - NDArray y = temp1(17, {0,2},true); - - NDArray exp('c', {1,M,1}, {-0.3, 0.3, 0.9}, sd::DataType::DOUBLE); - - sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); - ASSERT_TRUE(y.equalsTo(&exp)); + const Nd4jLong M = 3; + const Nd4jLong N = 4; + + NDArray a('c', {N, M}, + {1.2, 1.1, 1.0, 0.9, 0.8, 0.7, 0.5, 0.4, 0.3, 0.2, 0.1, 0}, + sd::DataType::DOUBLE); + a.permutei({1, 0}); + NDArray temp('f', {5, M, N}, + {16, 2, -6, 7, 2, -2, 4, -7, 6, 4, 4, 6, -3, 1, 3, + 9, 1, 4, 9, 10, -10, -3, -8, 7, -7, -7, 6, 9, 7, -6, + 8, 7, -3, -3, 4, -2, 5, -3, -3, 4, 6, -5, -1, 7, -5, + 4, -10, -1, 8, 0, -7, 4, -10, -7, -8, -9, 2, 9, 7, 9}, + sd::DataType::DOUBLE); + NDArray temp1('c', {5, M, N}, + {16, 2, -6, 7, 2, -2, 4, -7, 6, 4, 4, 6, -3, 1, 3, + 9, 1, 4, 9, 10, -10, -3, -8, 7, -7, -7, 6, 9, 7, -6, + 8, 7, -3, -3, 4, -2, 5, -3, -3, 4, 6, -5, -1, 7, -5, + 4, -10, -1, 8, 0, -7, 4, -10, -7, -8, -9, 2, 9, 7, 9}, + sd::DataType::DOUBLE); + NDArray x = temp(2, {0, 1}, true); + NDArray y = temp1(17, {0, 2}, true); + + NDArray exp('c', {1, M, 1}, {-0.3, 0.3, 0.9}, sd::DataType::DOUBLE); + + sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); + ASSERT_TRUE(y.equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////////// /* TEST_F(CudaBasicsTests2, mmulMxV_19) { - const Nd4jLong M = 3; - const Nd4jLong N = 4; + const Nd4jLong M = 3; + const Nd4jLong N = 4; - NDArray a('f', {M,N}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::FLOAT32); - NDArray x('f', {N}, {1,-2,3,-4}, sd::DataType::DOUBLE); - NDArray y('f', {M}, sd::DataType::FLOAT32); + NDArray a('f', {M,N}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, +sd::DataType::FLOAT32); NDArray x('f', {N}, {1,-2,3,-4}, sd::DataType::DOUBLE); + NDArray y('f', {M}, sd::DataType::FLOAT32); - NDArray exp('f', {M}, {0.1, 0.3, 0.5}, sd::DataType::FLOAT32); + NDArray exp('f', {M}, {0.1, 0.3, 0.5}, sd::DataType::FLOAT32); - sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); - ASSERT_TRUE(y.equalsTo(&exp)); + sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); + ASSERT_TRUE(y.equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests2, mmulMxV_20) { - const Nd4jLong M = 3; - const Nd4jLong N = 4; + const Nd4jLong M = 3; + const Nd4jLong N = 4; - NDArray a('c', {M,N}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::FLOAT32); - NDArray x('f', {N}, {1,-2,3,-4}, sd::DataType::DOUBLE); - NDArray y('f', {M}, sd::DataType::FLOAT32); - NDArray exp('f', {M}, {-1.6, -0.7, 0.2}, sd::DataType::FLOAT32); + NDArray a('c', {M,N}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, +sd::DataType::FLOAT32); NDArray x('f', {N}, {1,-2,3,-4}, sd::DataType::DOUBLE); + NDArray y('f', {M}, sd::DataType::FLOAT32); + NDArray exp('f', {M}, {-1.6, -0.7, 0.2}, sd::DataType::FLOAT32); - sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); - ASSERT_TRUE(y.equalsTo(&exp)); + sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); + ASSERT_TRUE(y.equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests2, mmulMxV_21) { - const Nd4jLong M = 3; - const Nd4jLong N = 4; + const Nd4jLong M = 3; + const Nd4jLong N = 4; - NDArray a('c', {M,N}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::FLOAT32); - NDArray x('c', {N}, {1,-2,3,-4}, sd::DataType::DOUBLE); - NDArray y('c', {M}, sd::DataType::FLOAT32); - NDArray exp('c', {M}, {-1.6, -0.7, 0.2}, sd::DataType::FLOAT32); + NDArray a('c', {M,N}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, +sd::DataType::FLOAT32); NDArray x('c', {N}, {1,-2,3,-4}, sd::DataType::DOUBLE); + NDArray y('c', {M}, sd::DataType::FLOAT32); + NDArray exp('c', {M}, {-1.6, -0.7, 0.2}, sd::DataType::FLOAT32); - sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); - ASSERT_TRUE(y.equalsTo(&exp)); + sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); + ASSERT_TRUE(y.equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests2, mmulMxV_22) { - const Nd4jLong M = 3; - const Nd4jLong N = 4; + const Nd4jLong M = 3; + const Nd4jLong N = 4; - NDArray a('f', {M,N}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::FLOAT32); - NDArray temp('f', {M,N,5}, {16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::DOUBLE); - NDArray x = temp(6, {0,2}); - NDArray y('f', {M}, sd::DataType::FLOAT32); + NDArray a('f', {M,N}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, +sd::DataType::FLOAT32); NDArray temp('f', {M,N,5}, +{16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, +sd::DataType::DOUBLE); NDArray x = temp(6, {0,2}); NDArray y('f', {M}, +sd::DataType::FLOAT32); - NDArray exp('f', {M}, {5.5, 5.1, 4.7}, sd::DataType::FLOAT32); + NDArray exp('f', {M}, {5.5, 5.1, 4.7}, sd::DataType::FLOAT32); - sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); - ASSERT_TRUE(y.equalsTo(&exp)); + sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); + ASSERT_TRUE(y.equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////////// @@ -989,11 +1215,11 @@ TEST_F(CudaBasicsTests2, mmulMxV_23) { const Nd4jLong M = 3; const Nd4jLong N = 4; - NDArray a('f', {N,M}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::FLOAT32); - a.permutei({1,0}); - NDArray temp('f', {5,M,N}, {16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::DOUBLE); - NDArray x = temp(3, {0,1}); - NDArray y('f', {M}, sd::DataType::FLOAT32); + NDArray a('f', {N,M}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, +sd::DataType::FLOAT32); a.permutei({1,0}); NDArray temp('f', {5,M,N}, +{16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, +sd::DataType::DOUBLE); NDArray x = temp(3, {0,1}); NDArray y('f', {M}, +sd::DataType::FLOAT32); NDArray exp('f', {M}, {1.5, 1.8, 1.5}, sd::DataType::FLOAT32); @@ -1004,18 +1230,19 @@ TEST_F(CudaBasicsTests2, mmulMxV_23) { ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests2, mmulMxV_24) { - const Nd4jLong M = 3; - const Nd4jLong N = 4; + const Nd4jLong M = 3; + const Nd4jLong N = 4; - NDArray a('f', {M,N}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::FLOAT32); - NDArray temp('f', {M,N,5}, {16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::DOUBLE); - NDArray x = temp(6, {0,2},true); - NDArray y('f', {M}, sd::DataType::FLOAT32); + NDArray a('f', {M,N}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, +sd::DataType::FLOAT32); NDArray temp('f', {M,N,5}, +{16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, +sd::DataType::DOUBLE); NDArray x = temp(6, {0,2},true); NDArray y('f', {M}, +sd::DataType::FLOAT32); - NDArray exp('f', {M}, {5.5, 5.1, 4.7}, sd::DataType::FLOAT32); + NDArray exp('f', {M}, {5.5, 5.1, 4.7}, sd::DataType::FLOAT32); - sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); - ASSERT_TRUE(y.equalsTo(&exp)); + sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); + ASSERT_TRUE(y.equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////////// @@ -1024,11 +1251,11 @@ TEST_F(CudaBasicsTests2, mmulMxV_25) { const Nd4jLong M = 3; const Nd4jLong N = 4; - NDArray a('f', {N,M}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::FLOAT32); - a.permutei({1,0}); - NDArray temp('f', {5,M,N}, {16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::DOUBLE); - NDArray x = temp(3, {0,1}, true); - NDArray y('f', {M}, sd::DataType::FLOAT32); + NDArray a('f', {N,M}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, +sd::DataType::FLOAT32); a.permutei({1,0}); NDArray temp('f', {5,M,N}, +{16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, +sd::DataType::DOUBLE); NDArray x = temp(3, {0,1}, true); NDArray y('f', {M}, +sd::DataType::FLOAT32); NDArray exp('f', {M}, {1.5, 1.8, 1.5}, sd::DataType::FLOAT32); @@ -1042,12 +1269,13 @@ TEST_F(CudaBasicsTests2, mmulMxV_26) { const Nd4jLong M = 3; const Nd4jLong N = 4; - NDArray a('c', {N,M}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::FLOAT32); - a.permutei({1,0}); - NDArray temp('f', {5,M,N}, {16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::DOUBLE); - NDArray temp1('c', {5,M,N}, {16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::FLOAT32); - NDArray x = temp(2, {0,1}); - NDArray y = temp1(17, {0,2}); + NDArray a('c', {N,M}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, +sd::DataType::FLOAT32); a.permutei({1,0}); NDArray temp('f', {5,M,N}, +{16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, +sd::DataType::DOUBLE); NDArray temp1('c', {5,M,N}, +{16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, +sd::DataType::FLOAT32); NDArray x = temp(2, {0,1}); NDArray y = temp1(17, +{0,2}); NDArray exp('c', {M}, {-0.3, 0.3, 0.9}, sd::DataType::FLOAT32); @@ -1061,12 +1289,13 @@ TEST_F(CudaBasicsTests2, mmulMxV_27) { const Nd4jLong M = 3; const Nd4jLong N = 4; - NDArray a('c', {N,M}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::FLOAT32); - a.permutei({1,0}); - NDArray temp('f', {5,M,N}, {16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::DOUBLE); - NDArray temp1('c', {5,M,N}, {16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::FLOAT32); - NDArray x = temp(2, {0,1},true); - NDArray y = temp1(17, {0,2},true); + NDArray a('c', {N,M}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, +sd::DataType::FLOAT32); a.permutei({1,0}); NDArray temp('f', {5,M,N}, +{16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, +sd::DataType::DOUBLE); NDArray temp1('c', {5,M,N}, +{16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, +sd::DataType::FLOAT32); NDArray x = temp(2, {0,1},true); NDArray y = temp1(17, +{0,2},true); NDArray exp('c', {1,M,1}, {-0.3, 0.3, 0.9}, sd::DataType::FLOAT32); @@ -1080,10 +1309,11 @@ TEST_F(CudaBasicsTests2, mmulMxV_28) { const Nd4jLong M = 3; const Nd4jLong N = 4; - NDArray a('c', {M,N}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::FLOAT32); - NDArray temp('f', {M,N,5}, {16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::DOUBLE); - NDArray x = temp(6, {0,2}); - NDArray y('f', {M}, sd::DataType::FLOAT32); + NDArray a('c', {M,N}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, +sd::DataType::FLOAT32); NDArray temp('f', {M,N,5}, +{16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, +sd::DataType::DOUBLE); NDArray x = temp(6, {0,2}); NDArray y('f', {M}, +sd::DataType::FLOAT32); NDArray exp('f', {M}, {5.1, 3.3, 1.5}, sd::DataType::FLOAT32); @@ -1127,10 +1357,9 @@ TEST_F(CudaBasicsTests2, mmulDot_3) { const Nd4jLong N = 4; NDArray xBig('c', {4,2}, {1, 0, 2, 0, 3, 0, 4, 0}, sd::DataType::INT32); - NDArray yBig('c', {4,3}, {0.1, 0, 0, 0.2, 0, 0, 0.3, 0, 0, 0.4, 0,0}, sd::DataType::FLOAT32); - NDArray x = xBig(0, {1}, true); - NDArray y = yBig(0, {1}, true); - NDArray z(sd::DataType::DOUBLE); + NDArray yBig('c', {4,3}, {0.1, 0, 0, 0.2, 0, 0, 0.3, 0, 0, 0.4, 0,0}, +sd::DataType::FLOAT32); NDArray x = xBig(0, {1}, true); NDArray y = yBig(0, {1}, +true); NDArray z(sd::DataType::DOUBLE); NDArray exp('c', {}, {3}, sd::DataType::DOUBLE); @@ -1144,10 +1373,9 @@ TEST_F(CudaBasicsTests2, mmulDot_4) { const Nd4jLong N = 4; NDArray xBig('f', {4,2}, {1, 2, 3, 4, 0, 0, 0, 0}, sd::DataType::INT32); - NDArray yBig('c', {4,3}, {0.1, 0, 0, 0.2, 0, 0, 0.3, 0, 0, 0.4, 0,0}, sd::DataType::FLOAT32); - NDArray x = xBig(0, {1}, true); - NDArray y = yBig(0, {1}); - NDArray z(sd::DataType::DOUBLE); + NDArray yBig('c', {4,3}, {0.1, 0, 0, 0.2, 0, 0, 0.3, 0, 0, 0.4, 0,0}, +sd::DataType::FLOAT32); NDArray x = xBig(0, {1}, true); NDArray y = yBig(0, +{1}); NDArray z(sd::DataType::DOUBLE); NDArray exp('c', {}, {3}, sd::DataType::DOUBLE); diff --git a/libnd4j/tests_cpu/layers_tests/CudaExtraArgumentsTests.cu b/libnd4j/tests_cpu/layers_tests/CudaExtraArgumentsTests.cu index 30d58946c51b..bd66220c07aa 100644 --- a/libnd4j/tests_cpu/layers_tests/CudaExtraArgumentsTests.cu +++ b/libnd4j/tests_cpu/layers_tests/CudaExtraArgumentsTests.cu @@ -18,57 +18,56 @@ // @author raver119@gmail.com // -#include "testlayers.h" #include -#include #include #include +#include + +#include "testlayers.h" + using namespace sd; class CudaExtraArgumentsTests : public testing::Test { -public: - - CudaExtraArgumentsTests() { - printf("\n"); - fflush(stdout); - } + public: + CudaExtraArgumentsTests() { + printf("\n"); + fflush(stdout); + } }; TEST_F(CudaExtraArgumentsTests, Basic_Test_1) { - ExtraArguments args({1.0, 2.0, 3.0}); + ExtraArguments args({1.0, 2.0, 3.0}); - float ef[] = {1.f, 2.f, 3.f}; - double ed[] = {1., 2., 3.}; + float ef[] = {1.f, 2.f, 3.f}; + double ed[] = {1., 2., 3.}; - auto ptrFloat = reinterpret_cast(args.argumentsAsT()); - auto ptrDouble = reinterpret_cast(args.argumentsAsT()); - ASSERT_TRUE(ptrFloat != nullptr); - ASSERT_TRUE(ptrDouble != nullptr); + auto ptrFloat = reinterpret_cast(args.argumentsAsT()); + auto ptrDouble = reinterpret_cast(args.argumentsAsT()); + ASSERT_TRUE(ptrFloat != nullptr); + ASSERT_TRUE(ptrDouble != nullptr); - auto tmpFloat = new float[3]; - auto tmpDouble = new double[3]; + auto tmpFloat = new float[3]; + auto tmpDouble = new double[3]; - cudaMemcpy(tmpFloat, ptrFloat, 3 * sizeof(float), cudaMemcpyDeviceToHost); - cudaMemcpy(tmpDouble, ptrDouble, 3 * sizeof(double), cudaMemcpyDeviceToHost); + cudaMemcpy(tmpFloat, ptrFloat, 3 * sizeof(float), cudaMemcpyDeviceToHost); + cudaMemcpy(tmpDouble, ptrDouble, 3 * sizeof(double), cudaMemcpyDeviceToHost); - for (int e = 0; e < 3; e++) { - ASSERT_NEAR(ef[e], tmpFloat[e], 1e-5f); - } + for (int e = 0; e < 3; e++) { + ASSERT_NEAR(ef[e], tmpFloat[e], 1e-5f); + } - for (int e = 0; e < 3; e++) { - ASSERT_NEAR(ed[e], tmpDouble[e], 1e-5); - } + for (int e = 0; e < 3; e++) { + ASSERT_NEAR(ed[e], tmpDouble[e], 1e-5); + } - delete[] tmpFloat; - delete[] tmpDouble; + delete[] tmpFloat; + delete[] tmpDouble; } - TEST_F(CudaExtraArgumentsTests, Basic_Test_2) { - ExtraArguments args; + ExtraArguments args; - auto ptrInt = args.argumentsAsT(); - ASSERT_TRUE(ptrInt == nullptr); + auto ptrInt = args.argumentsAsT(); + ASSERT_TRUE(ptrInt == nullptr); } - diff --git a/libnd4j/tests_cpu/layers_tests/CudaLaunchHelperTests.cpp b/libnd4j/tests_cpu/layers_tests/CudaLaunchHelperTests.cpp index 66cc024bf98d..6e9979d33c5d 100644 --- a/libnd4j/tests_cpu/layers_tests/CudaLaunchHelperTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/CudaLaunchHelperTests.cpp @@ -18,29 +18,29 @@ // Created by raver on 11/26/2018. // -#include "testlayers.h" #include +#include "testlayers.h" + using namespace sd; using namespace sd::graph; class CudaLaunchHelperTests : public testing::Test { -public: - + public: }; TEST_F(CudaLaunchHelperTests, test_reduction_blocks_1) { - ASSERT_EQ(1, CudaLaunchHelper::getReductionBlocks(512)); + ASSERT_EQ(1, CudaLaunchHelper::getReductionBlocks(512)); } TEST_F(CudaLaunchHelperTests, test_reduction_blocks_2) { - ASSERT_EQ(1, CudaLaunchHelper::getReductionBlocks(121)); + ASSERT_EQ(1, CudaLaunchHelper::getReductionBlocks(121)); } TEST_F(CudaLaunchHelperTests, test_reduction_blocks_3) { - ASSERT_EQ(2, CudaLaunchHelper::getReductionBlocks(513)); + ASSERT_EQ(2, CudaLaunchHelper::getReductionBlocks(513)); } TEST_F(CudaLaunchHelperTests, test_reduction_blocks_4) { - ASSERT_EQ(3, CudaLaunchHelper::getReductionBlocks(1225)); + ASSERT_EQ(3, CudaLaunchHelper::getReductionBlocks(1225)); } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/DataBufferTests.cpp b/libnd4j/tests_cpu/layers_tests/DataBufferTests.cpp index b22f9e765582..ed53afd8c045 100644 --- a/libnd4j/tests_cpu/layers_tests/DataBufferTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/DataBufferTests.cpp @@ -18,61 +18,64 @@ // @author raver119@gmail.com // -#include "testlayers.h" #include #include #include #include #include +#include #include -#include #include -#include +#include + +#include "testlayers.h" using namespace sd; using namespace sd::graph; using namespace sd::memory; class DataBufferTests : public testing::Test { -public: - + public: }; TEST_F(DataBufferTests, test_alloc_limit_1) { - if (!Environment::getInstance().isCPU()) - return; - - auto deviceId = AffinityManager::currentDeviceId(); - auto odLimit = MemoryCounter::getInstance().deviceLimit(deviceId); - auto ogLimit = MemoryCounter::getInstance().groupLimit(MemoryType::HOST); - auto odUse = MemoryCounter::getInstance().allocatedDevice(deviceId); - auto ogUse = MemoryCounter::getInstance().allocatedGroup(MemoryType::HOST); + if (!Environment::getInstance().isCPU()) return; - auto limitSize = odUse + (150 * 1024 * 1024); - auto allocSize = 100000000; + auto deviceId = AffinityManager::currentDeviceId(); + auto odLimit = MemoryCounter::getInstance().deviceLimit(deviceId); + auto ogLimit = MemoryCounter::getInstance().groupLimit(MemoryType::HOST); + auto odUse = MemoryCounter::getInstance().allocatedDevice(deviceId); + auto ogUse = MemoryCounter::getInstance().allocatedGroup(MemoryType::HOST); - MemoryCounter::getInstance().setDeviceLimit(deviceId, odLimit + limitSize); - MemoryCounter::getInstance().setGroupLimit(MemoryType::HOST, odLimit + limitSize); + auto limitSize = odUse + (150 * 1024 * 1024); + auto allocSize = 100000000; - DataBuffer buffer(allocSize, DataType::INT32); + MemoryCounter::getInstance().setDeviceLimit(deviceId, odLimit + limitSize); + MemoryCounter::getInstance().setGroupLimit(MemoryType::HOST, + odLimit + limitSize); - // separately testing per-device limits and group limits - ASSERT_EQ(odUse + allocSize, MemoryCounter::getInstance().allocatedDevice(deviceId)); - ASSERT_EQ(ogUse + allocSize, MemoryCounter::getInstance().allocatedGroup(MemoryType::HOST)); + DataBuffer buffer(allocSize, DataType::INT32); + // separately testing per-device limits and group limits + ASSERT_EQ(odUse + allocSize, + MemoryCounter::getInstance().allocatedDevice(deviceId)); + ASSERT_EQ(ogUse + allocSize, + MemoryCounter::getInstance().allocatedGroup(MemoryType::HOST)); - // setting smaller limits, to make sure next allocation fails with OOM exception - MemoryCounter::getInstance().setDeviceLimit(deviceId, allocSize - 100); - MemoryCounter::getInstance().setGroupLimit(MemoryType::HOST, allocSize - 100); + // setting smaller limits, to make sure next allocation fails with OOM + // exception + MemoryCounter::getInstance().setDeviceLimit(deviceId, allocSize - 100); + MemoryCounter::getInstance().setGroupLimit(MemoryType::HOST, + allocSize - 100); - try { - DataBuffer bufferFailed(allocSize, DataType::INT32); - ASSERT_TRUE(false); - } catch (allocation_exception &e) { - // we expect exception here - } + try { + DataBuffer bufferFailed(allocSize, DataType::INT32); + ASSERT_TRUE(false); + } catch (allocation_exception &e) { + // we expect exception here + } - // restore original limits, so subsequent tests do not fail - MemoryCounter::getInstance().setDeviceLimit(deviceId, odLimit); - MemoryCounter::getInstance().setGroupLimit(MemoryType::HOST, odLimit); + // restore original limits, so subsequent tests do not fail + MemoryCounter::getInstance().setDeviceLimit(deviceId, odLimit); + MemoryCounter::getInstance().setGroupLimit(MemoryType::HOST, odLimit); } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/DataBufferTestsCuda.cu b/libnd4j/tests_cpu/layers_tests/DataBufferTestsCuda.cu index 6f7d38ede0b7..59c88c934ee7 100644 --- a/libnd4j/tests_cpu/layers_tests/DataBufferTestsCuda.cu +++ b/libnd4j/tests_cpu/layers_tests/DataBufferTestsCuda.cu @@ -18,24 +18,24 @@ // @author raver119@gmail.com // -#include "testlayers.h" #include #include #include #include #include +#include #include -#include #include -#include +#include + +#include "testlayers.h" using namespace sd; using namespace sd::graph; using namespace sd::memory; class DataBufferTestsCuda : public testing::Test { -public: - + public: }; /* @@ -50,25 +50,30 @@ TEST_F(DataBufferTestsCuda, test_alloc_limit_1) { auto odUse = MemoryCounter::getInstance().allocatedDevice(deviceId); auto opUse = MemoryCounter::getInstance().allocatedGroup(MemoryType::HOST); - auto osUse = MemoryCounter::getInstance().allocatedGroup(MemoryType::DEVICE); + auto osUse = +MemoryCounter::getInstance().allocatedGroup(MemoryType::DEVICE); auto limitSize = odUse + 150000000; auto allocSize = 100000000; MemoryCounter::getInstance().setDeviceLimit(deviceId, odLimit + limitSize); - MemoryCounter::getInstance().setGroupLimit(MemoryType::HOST, opLimit + limitSize); - MemoryCounter::getInstance().setGroupLimit(MemoryType::DEVICE, osLimit + limitSize); + MemoryCounter::getInstance().setGroupLimit(MemoryType::HOST, opLimit + +limitSize); MemoryCounter::getInstance().setGroupLimit(MemoryType::DEVICE, +osLimit + limitSize); DataBuffer buffer(allocSize, DataType::INT32, nullptr, true); // separately testing per-device limits and group limits - ASSERT_EQ(odUse + allocSize, MemoryCounter::getInstance().allocatedDevice(deviceId)); - ASSERT_EQ(opUse + allocSize, MemoryCounter::getInstance().allocatedGroup(MemoryType::HOST)); - ASSERT_EQ(osUse + allocSize, MemoryCounter::getInstance().allocatedGroup(MemoryType::DEVICE)); - - // setting smaller limits, to make sure next allocation fails with OOM exception - MemoryCounter::getInstance().setDeviceLimit(deviceId, allocSize - 100); - MemoryCounter::getInstance().setGroupLimit(MemoryType::DEVICE, allocSize - 100); + ASSERT_EQ(odUse + allocSize, +MemoryCounter::getInstance().allocatedDevice(deviceId)); ASSERT_EQ(opUse + +allocSize, MemoryCounter::getInstance().allocatedGroup(MemoryType::HOST)); + ASSERT_EQ(osUse + allocSize, +MemoryCounter::getInstance().allocatedGroup(MemoryType::DEVICE)); + + // setting smaller limits, to make sure next allocation fails with OOM +exception MemoryCounter::getInstance().setDeviceLimit(deviceId, allocSize - +100); MemoryCounter::getInstance().setGroupLimit(MemoryType::DEVICE, allocSize +- 100); // this allocation should fail, since we're allocating too much diff --git a/libnd4j/tests_cpu/layers_tests/DataTypesValidationTests.cpp b/libnd4j/tests_cpu/layers_tests/DataTypesValidationTests.cpp index 83f3a15f58ef..ad4740a880f4 100644 --- a/libnd4j/tests_cpu/layers_tests/DataTypesValidationTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/DataTypesValidationTests.cpp @@ -18,139 +18,150 @@ // @author raver119@gmail.com // -#include "testlayers.h" #include #include #include #include #include +#include #include -#include #include -#include +#include + +#include "testlayers.h" using namespace sd; using namespace sd::graph; class DataTypesValidationTests : public testing::Test { -public: - + public: }; TEST_F(DataTypesValidationTests, Basic_Test_1) { - auto input = NDArrayFactory::create('c', {1, 1, 1, 4}); - auto weights = NDArrayFactory::create('c', {1, 1, 1, 4}); - auto exp = NDArrayFactory::create('c', {1, 4, 1, 4}, {2., 4., 6., 8., 2., 4., 6., 8., 2., 4., 6., 8., 2., 4., 6., 8.}); + auto input = NDArrayFactory::create('c', {1, 1, 1, 4}); + auto weights = NDArrayFactory::create('c', {1, 1, 1, 4}); + auto exp = NDArrayFactory::create( + 'c', {1, 4, 1, 4}, + {2., 4., 6., 8., 2., 4., 6., 8., 2., 4., 6., 8., 2., 4., 6., 8.}); - weights.assign(2.0); - input.linspace(1); + weights.assign(2.0); + input.linspace(1); - sd::ops::conv2d op; - auto result = op.evaluate({&input, &weights}, {1, 1, 1, 1, 0, 0, 1, 1, 0, 0}); + sd::ops::conv2d op; + auto result = op.evaluate({&input, &weights}, {1, 1, 1, 1, 0, 0, 1, 1, 0, 0}); - ASSERT_EQ(ND4J_STATUS_VALIDATION, result.status()); + ASSERT_EQ(ND4J_STATUS_VALIDATION, result.status()); } TEST_F(DataTypesValidationTests, Basic_Test_2) { - auto input = NDArrayFactory::create('c', {1, 1, 1, 4}); - auto weights = NDArrayFactory::create('c', {1, 1, 1, 4}); - auto exp = NDArrayFactory::create('c', {1, 4, 1, 4}, {2., 4., 6., 8., 2., 4., 6., 8., 2., 4., 6., 8., 2., 4., 6., 8.}); + auto input = NDArrayFactory::create('c', {1, 1, 1, 4}); + auto weights = NDArrayFactory::create('c', {1, 1, 1, 4}); + auto exp = NDArrayFactory::create( + 'c', {1, 4, 1, 4}, + {2., 4., 6., 8., 2., 4., 6., 8., 2., 4., 6., 8., 2., 4., 6., 8.}); - weights.assign(2.0); - input.linspace(1); + weights.assign(2.0); + input.linspace(1); - sd::ops::conv2d op; - auto result = op.evaluate({&input, &weights}, {1, 1, 1, 1, 0, 0, 1, 1, 0, 0}); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::conv2d op; + auto result = op.evaluate({&input, &weights}, {1, 1, 1, 1, 0, 0, 1, 1, 0, 0}); + ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); - - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(DataTypesValidationTests, Basic_Test_3) { - auto input = NDArrayFactory::create('c', {1, 1, 1, 4}); - auto weights = NDArrayFactory::create('c', {1, 1, 1, 4}); - auto exp = NDArrayFactory::create('c', {1, 4, 1, 4}, {2., 4., 6., 8., 2., 4., 6., 8., 2., 4., 6., 8., 2., 4., 6., 8.}); - auto out = NDArrayFactory::create('c', {1, 4, 1, 4}); - - weights.assign(2.0); - input.linspace(1); - - sd::ops::conv2d op; - auto result = op.execute({&input, &weights}, {&out}, {}, {1, 1, 1, 1, 0, 0, 1, 1, 0, 0}, {}); - ASSERT_EQ(Status::OK(), result); - - ASSERT_EQ(exp, out); + auto input = NDArrayFactory::create('c', {1, 1, 1, 4}); + auto weights = NDArrayFactory::create('c', {1, 1, 1, 4}); + auto exp = NDArrayFactory::create( + 'c', {1, 4, 1, 4}, + {2., 4., 6., 8., 2., 4., 6., 8., 2., 4., 6., 8., 2., 4., 6., 8.}); + auto out = NDArrayFactory::create('c', {1, 4, 1, 4}); + + weights.assign(2.0); + input.linspace(1); + + sd::ops::conv2d op; + auto result = op.execute({&input, &weights}, {&out}, {}, + {1, 1, 1, 1, 0, 0, 1, 1, 0, 0}, {}); + ASSERT_EQ(Status::OK(), result); + + ASSERT_EQ(exp, out); } TEST_F(DataTypesValidationTests, Basic_Test_4) { - auto input = NDArrayFactory::create('c', {1, 1, 1, 4}); - auto weights = NDArrayFactory::create('c', {1, 1, 1, 4}); - auto exp = NDArrayFactory::create('c', {1, 4, 1, 4}, {2., 4., 6., 8., 2., 4., 6., 8., 2., 4., 6., 8., 2., 4., 6., 8.}); - auto out = NDArrayFactory::create('c', {1, 4, 1, 4}); - - weights.assign(2.0); - input.linspace(1); - - sd::ops::conv2d op; - auto result = op.execute({&input, &weights}, {&out}, {}, {1, 1, 1, 1, 0, 0, 1, 1, 0, 0}, {}); - ASSERT_EQ(ND4J_STATUS_VALIDATION, result); + auto input = NDArrayFactory::create('c', {1, 1, 1, 4}); + auto weights = NDArrayFactory::create('c', {1, 1, 1, 4}); + auto exp = NDArrayFactory::create( + 'c', {1, 4, 1, 4}, + {2., 4., 6., 8., 2., 4., 6., 8., 2., 4., 6., 8., 2., 4., 6., 8.}); + auto out = NDArrayFactory::create('c', {1, 4, 1, 4}); + + weights.assign(2.0); + input.linspace(1); + + sd::ops::conv2d op; + auto result = op.execute({&input, &weights}, {&out}, {}, + {1, 1, 1, 1, 0, 0, 1, 1, 0, 0}, {}); + ASSERT_EQ(ND4J_STATUS_VALIDATION, result); } TEST_F(DataTypesValidationTests, test_bfloat16_rand_1) { - auto x = NDArrayFactory::create('c', {5, 10}); - RandomGenerator gen(119, 120); - RandomLauncher::fillUniform(LaunchContext::defaultContext(), gen, &x, 1, 6); + auto x = NDArrayFactory::create('c', {5, 10}); + RandomGenerator gen(119, 120); + RandomLauncher::fillUniform(LaunchContext::defaultContext(), gen, &x, 1, 6); - ASSERT_TRUE(x.sumNumber().e(0) != 0.f); + ASSERT_TRUE(x.sumNumber().e(0) != 0.f); } TEST_F(DataTypesValidationTests, test_bfloat16_rand_2) { - auto x = NDArrayFactory::create('c', {5, 10}); - RandomGenerator gen(119, 120); - RandomLauncher::fillGaussian(LaunchContext::defaultContext(), gen, &x, 0, 1); + auto x = NDArrayFactory::create('c', {5, 10}); + RandomGenerator gen(119, 120); + RandomLauncher::fillGaussian(LaunchContext::defaultContext(), gen, &x, 0, 1); - ASSERT_TRUE(x.sumNumber().e(0) != 0.f); + ASSERT_TRUE(x.sumNumber().e(0) != 0.f); } TEST_F(DataTypesValidationTests, cast_1) { + float16 x = static_cast(1.f); + float y = static_cast(x); - float16 x = static_cast(1.f); - float y = static_cast(x); - - ASSERT_TRUE(static_cast(1.f) == x); - ASSERT_TRUE(y == static_cast(x)); + ASSERT_TRUE(static_cast(1.f) == x); + ASSERT_TRUE(y == static_cast(x)); } TEST_F(DataTypesValidationTests, test_bits_hamming_distance_1) { - auto x = NDArrayFactory::create('c', {3}, {0b01011000, 0b01011111, 0b01111110}); - auto y = NDArrayFactory::create('c', {3}, {0b00010110, 0b01011000, 0b01011000}); - auto z = NDArrayFactory::create(0); - - Context ctx(1); - ctx.setInputArray(0, &x); - ctx.setInputArray(1, &y); - ctx.setOutputArray(0, &z); - - sd::ops::bits_hamming_distance op; - auto status = op.execute(&ctx); - ASSERT_NE(Status::OK(), status); + auto x = NDArrayFactory::create('c', {3}, + {0b01011000, 0b01011111, 0b01111110}); + auto y = NDArrayFactory::create('c', {3}, + {0b00010110, 0b01011000, 0b01011000}); + auto z = NDArrayFactory::create(0); + + Context ctx(1); + ctx.setInputArray(0, x); + ctx.setInputArray(1, y); + ctx.setOutputArray(0, z); + + sd::ops::bits_hamming_distance op; + auto status = op.execute(&ctx); + ASSERT_NE(Status::OK(), status); } TEST_F(DataTypesValidationTests, test_bits_hamming_distance_2) { - auto x = NDArrayFactory::create('c', {3}, {0b01011000, 0b01011111, 0b01111110}); - auto y = NDArrayFactory::create('c', {3}, {0b00010110, 0b01011000, 0b01011000}); - auto z = NDArrayFactory::create(0); - - Context ctx(1); - ctx.setInputArray(0, &x); - ctx.setInputArray(1, &y); - ctx.setOutputArray(0, &z); - - sd::ops::bits_hamming_distance op; - auto status = op.execute(&ctx); - ASSERT_EQ(Status::OK(), status); + auto x = NDArrayFactory::create('c', {3}, + {0b01011000, 0b01011111, 0b01111110}); + auto y = NDArrayFactory::create('c', {3}, + {0b00010110, 0b01011000, 0b01011000}); + auto z = NDArrayFactory::create(0); + + Context ctx(1); + ctx.setInputArray(0, x); + ctx.setInputArray(1, y); + ctx.setOutputArray(0, z); + + sd::ops::bits_hamming_distance op; + auto status = op.execute(&ctx); + ASSERT_EQ(Status::OK(), status); } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp index a5715fd01c55..491b369f5b01 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp @@ -14,78 +14,80 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - // - // @author raver119@gmail.com - // +// +// @author raver119@gmail.com +// -#include "testlayers.h" +#include +#include #include -#include #include #include -#include -#include +#include #include -#include -#include #include +#include +#include #include -#include + +#include + +#include "testlayers.h" using namespace sd; using namespace sd::graph; class DeclarableOpsTests1 : public testing::Test { -public: - - const int bS = 2; // batch size - const int iD = 1; // input depth (number of picture channels, for example rgb=3) - const int iH = 28; // picture height in pixels - const int iW = 28; // picture width in pixels - const int oD = 3; // output depth (= N for dense layer) - const int kH = 5; // kernel height in pixels - const int kW = 5; // kernel width in pixels - const int sH = 1; // stride step in horizontal direction - const int sW = 1; // stride step in vertical direction - const int pH = 0; // padding height - const int pW = 0; // padding width - const int dH = 2; // dilation height - const int dW = 2; // dilation width - const int oH = (iH - kH - (kH - 1) * (dH - 1) + 2 * pH) / sH + 1; // output height - const int oW = (iW - kW - (kW - 1) * (dW - 1) + 2 * pW) / sW + 1; // output width - - DeclarableOpsTests1() { - sd::memory::MemoryTracker::getInstance().reset(); - } - - ~DeclarableOpsTests1() { - sd::memory::MemoryTracker::getInstance().summarize(); - } + public: + const int bS = 2; // batch size + const int iD = + 1; // input depth (number of picture channels, for example rgb=3) + const int iH = 28; // picture height in pixels + const int iW = 28; // picture width in pixels + const int oD = 3; // output depth (= N for dense layer) + const int kH = 5; // kernel height in pixels + const int kW = 5; // kernel width in pixels + const int sH = 1; // stride step in horizontal direction + const int sW = 1; // stride step in vertical direction + const int pH = 0; // padding height + const int pW = 0; // padding width + const int dH = 2; // dilation height + const int dW = 2; // dilation width + const int oH = + (iH - kH - (kH - 1) * (dH - 1) + 2 * pH) / sH + 1; // output height + const int oW = + (iW - kW - (kW - 1) * (dW - 1) + 2 * pW) / sW + 1; // output width + + DeclarableOpsTests1() { sd::memory::MemoryTracker::getInstance().reset(); } + + ~DeclarableOpsTests1() { + sd::memory::MemoryTracker::getInstance().summarize(); + } }; template class TypedDeclarableOpsTests1 : public testing::Test { -public: - - const int bS = 2; // batch size - const int iD = 1; // input depth (number of picture channels, for example rgb=3) - const int iH = 28; // picture height in pixels - const int iW = 28; // picture width in pixels - const int oD = 3; // output depth (= N for dense layer) - const int kH = 5; // kernel height in pixels - const int kW = 5; // kernel width in pixels - const int sH = 1; // stride step in horizontal direction - const int sW = 1; // stride step in vertical direction - const int pH = 0; // padding height - const int pW = 0; // padding width - const int dH = 2; // dilation height - const int dW = 2; // dilation width - const int oH = (iH - kH - (kH - 1) * (dH - 1) + 2 * pH) / sH + 1; // output height - const int oW = (iW - kW - (kW - 1) * (dW - 1) + 2 * pW) / sW + 1; // output width - - TypedDeclarableOpsTests1() { - printf("\n"); - } + public: + const int bS = 2; // batch size + const int iD = + 1; // input depth (number of picture channels, for example rgb=3) + const int iH = 28; // picture height in pixels + const int iW = 28; // picture width in pixels + const int oD = 3; // output depth (= N for dense layer) + const int kH = 5; // kernel height in pixels + const int kW = 5; // kernel width in pixels + const int sH = 1; // stride step in horizontal direction + const int sW = 1; // stride step in vertical direction + const int pH = 0; // padding height + const int pW = 0; // padding width + const int dH = 2; // dilation height + const int dW = 2; // dilation width + const int oH = + (iH - kH - (kH - 1) * (dH - 1) + 2 * pH) / sH + 1; // output height + const int oW = + (iW - kW - (kW - 1) * (dW - 1) + 2 * pW) / sW + 1; // output width + + TypedDeclarableOpsTests1() { printf("\n"); } }; typedef ::testing::Types TestingTypes; @@ -93,1609 +95,1651 @@ TYPED_TEST_CASE(TypedDeclarableOpsTests1, TestingTypes); ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, BasicInitialization1) { - auto concat = new sd::ops::concat(); - std::string expName("concat"); - ASSERT_EQ(expName, *(concat->getOpName())); - - auto x0 = NDArrayFactory::create_('c', { 1, 5 }); - auto x1 = NDArrayFactory::create_('c', { 1, 5 }); - auto x2 = NDArrayFactory::create_('c', { 1, 5 }); - auto x3 = NDArrayFactory::create_('c', { 1, 5 }); - auto x4 = NDArrayFactory::create_('c', { 1, 5 }); - - x0->assign(1.0f); - x1->assign(1.0f); - x2->assign(1.0f); - x3->assign(1.0f); - x4->assign(1.0f); + auto concat = new sd::ops::concat(); + std::string expName("concat"); + ASSERT_EQ(expName, concat->getOpName()); - auto variableSpace = new VariableSpace(); + auto x0 = NDArrayFactory::create('c', {1, 5}); + auto x1 = NDArrayFactory::create('c', {1, 5}); + auto x2 = NDArrayFactory::create('c', {1, 5}); + auto x3 = NDArrayFactory::create('c', {1, 5}); + auto x4 = NDArrayFactory::create('c', {1, 5}); - variableSpace->putVariable(-1, x0); - variableSpace->putVariable(-2, x1); - variableSpace->putVariable(-3, x2); - variableSpace->putVariable(-4, x3); - variableSpace->putVariable(-5, x4); + x0.assign(1.0f); + x1.assign(1.0f); + x2.assign(1.0f); + x3.assign(1.0f); + x4.assign(1.0f); - auto nodeVar = new Variable(); + auto variableSpace = new VariableSpace(); - variableSpace->putVariable(1, nodeVar); + variableSpace->putVariable(-1, x0); + variableSpace->putVariable(-2, x1); + variableSpace->putVariable(-3, x2); + variableSpace->putVariable(-4, x3); + variableSpace->putVariable(-5, x4); - Context block(1, variableSpace); - block.getIArguments()->push_back(1); - block.fillInputs({ -1, -2, -3, -4, -5 }); + auto nodeVar = std::make_shared(); - ASSERT_FALSE(nodeVar->hasNDArray()); + variableSpace->putVariable(1, nodeVar); - Nd4jStatus result = concat->execute(&block); + Context block(1, variableSpace); + block.appendI(1); + block.fillInputs({-1, -2, -3, -4, -5}); - ASSERT_TRUE(nodeVar->hasNDArray()); + ASSERT_FALSE(nodeVar->hasNDArray()); - ASSERT_EQ(25, nodeVar->getNDArray()->lengthOf()); + Nd4jStatus result = concat->execute(&block); - ASSERT_NEAR(25.0, nodeVar->getNDArray()->reduceNumber(reduce::Sum).e(0), 1e-5); + ASSERT_TRUE(nodeVar->hasNDArray()); - ASSERT_EQ(ND4J_STATUS_OK, result); + ASSERT_EQ(25, nodeVar->getNDArray()->lengthOf()); + ASSERT_NEAR(25.0, + nodeVar->getNDArray()->reduceNumber(reduce::Sum).e(0), + 1e-5); - delete variableSpace; - delete concat; + ASSERT_EQ(ND4J_STATUS_OK, result); + + delete variableSpace; + delete concat; } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, BasicInitialization2) { - auto op = sd::ops::OpRegistrator::getInstance().getOperation("concat"); + auto op = sd::ops::OpRegistrator::getInstance().getOperation("concat"); - ASSERT_TRUE(op != nullptr); - std::string expName("concat"); - ASSERT_EQ(expName, *(op->getOpName())); + ASSERT_TRUE(op != nullptr); + std::string expName("concat"); + ASSERT_EQ(expName, op->getOpName()); - ASSERT_EQ(-1, op->getOpDescriptor()->getNumberOfInputs()); - ASSERT_EQ(1, op->getOpDescriptor()->getNumberOfOutputs()); + ASSERT_EQ(-1, op->getOpDescriptor()->getNumberOfInputs()); + ASSERT_EQ(1, op->getOpDescriptor()->getNumberOfOutputs()); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, ApplyGradientDescent_1) { - auto x = NDArrayFactory::create('c', { 3,4 }, { 1,2,3,4,5,6,7,8,9,10,11,12 }); - auto y = NDArrayFactory::create('c', { 3,4 }, { 0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1.0,1.1,1.2 }); - auto exp = NDArrayFactory::create('c', { 3,4 }); - exp.linspace(0.9, 0.9); - sd::ops::apply_sgd op; - auto result = op.evaluate({ &x, &y }, { 1. }, {}); - ASSERT_EQ(result.status(), ND4J_STATUS_OK); - auto z = result.at(0); - - ASSERT_TRUE(z->equalsTo(exp)); + auto x = NDArrayFactory::create( + 'c', {3, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + auto y = NDArrayFactory::create( + 'c', {3, 4}, + {0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2}); + auto exp = NDArrayFactory::create('c', {3, 4}); + exp.linspace(0.9, 0.9); + sd::ops::apply_sgd op; + auto result = op.evaluate({&x, &y}, {1.}, {}); + ASSERT_EQ(result.status(), ND4J_STATUS_OK); + auto z = result.at(0); + ASSERT_TRUE(z.equalsTo(exp)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, AssignBroadcastTest_1) { - auto x = NDArrayFactory::create('c', { 3,4 }, { 1,2,3,4,5,6,7,8,9,10,11,12 }); - auto y = NDArrayFactory::create('c', { 1,4 }, { 0.1,0.2,0.3,0.4 }); - auto exp = NDArrayFactory::create('c', { 3,4 }, { 0.1, 0.2, 0.3, 0.4, 0.1, 0.2, 0.3, 0.4, 0.1, 0.2, 0.3, 0.4 }); - sd::ops::assign op; - auto result = op.evaluate({ &x, &y }); - ASSERT_EQ(result.status(), ND4J_STATUS_OK); - auto z = result.at(0); - - ASSERT_TRUE(z->equalsTo(exp)); + auto x = NDArrayFactory::create( + 'c', {3, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + auto y = NDArrayFactory::create('c', {1, 4}, {0.1, 0.2, 0.3, 0.4}); + auto exp = NDArrayFactory::create( + 'c', {3, 4}, + {0.1, 0.2, 0.3, 0.4, 0.1, 0.2, 0.3, 0.4, 0.1, 0.2, 0.3, 0.4}); + sd::ops::assign op; + auto result = op.evaluate({&x, &y}); + ASSERT_EQ(result.status(), ND4J_STATUS_OK); + auto z = result.at(0); + ASSERT_TRUE(z.equalsTo(exp)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, AssignBroadcastTest_2) { - auto x = NDArrayFactory::create('c', { 3,4 }, { 1,2,3,4,5,6,7,8,9,10,11,12 }); - auto y = NDArrayFactory::create('c', { 1,4 }, { 0.1,0.2,0.3,0.4 }); - auto eps = NDArrayFactory::create('c', { 3,4 }, { 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4 }); - auto exp1 = NDArrayFactory::create('c', { 3,4 }); // zero - auto exp2 = NDArrayFactory::create('c', { 1,4 }, { 3, 6, 9, 12 }); - sd::ops::assign_bp op; - auto result = op.evaluate({ &x, &y, &eps }); - ASSERT_EQ(result.status(), ND4J_STATUS_OK); - auto z1 = result.at(0); - auto z2 = result.at(1); - - ASSERT_TRUE(z1->equalsTo(exp1)); - ASSERT_TRUE(z2->equalsTo(exp2)); + auto x = NDArrayFactory::create( + 'c', {3, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + auto y = NDArrayFactory::create('c', {1, 4}, {0.1, 0.2, 0.3, 0.4}); + auto eps = NDArrayFactory::create( + 'c', {3, 4}, {1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4}); + auto exp1 = NDArrayFactory::create('c', {3, 4}); // zero + auto exp2 = NDArrayFactory::create('c', {1, 4}, {3, 6, 9, 12}); + sd::ops::assign_bp op; + auto result = op.evaluate({&x, &y, &eps}); + ASSERT_EQ(result.status(), ND4J_STATUS_OK); + auto z1 = result.at(0); + auto z2 = result.at(1); + ASSERT_TRUE(z1.equalsTo(exp1)); + ASSERT_TRUE(z2.equalsTo(exp2)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, AXpY_Test_1) { - auto x = NDArrayFactory::create('c', { 3,4 }, { 1,2,3,4,5,6,7,8,9,10,11,12 }); - auto y = NDArrayFactory::create('c', { 3,4 }, { 1,2,3,4,5,6,7,8,9,10,11,12 }); - auto exp = NDArrayFactory::create('c', { 3,4 }); - exp.linspace(3, 3); - sd::ops::axpy op; - auto result = op.evaluate({ &x, &y }, { 2. }); - ASSERT_EQ(result.status(), ND4J_STATUS_OK); - auto z = result.at(0); - - ASSERT_TRUE(z->equalsTo(exp)); + auto x = NDArrayFactory::create( + 'c', {3, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + auto y = NDArrayFactory::create( + 'c', {3, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + auto exp = NDArrayFactory::create('c', {3, 4}); + exp.linspace(3, 3); + sd::ops::axpy op; + auto result = op.evaluate({&x, &y}, {2.}); + ASSERT_EQ(result.status(), ND4J_STATUS_OK); + auto z = result.at(0); + ASSERT_TRUE(z.equalsTo(exp)); } TEST_F(DeclarableOpsTests1, BasicInitialization3) { - auto op1 = sd::ops::OpRegistrator::getInstance().getOperation("concat"); - std::string expName("concat"); - auto hash = sd::ops::HashHelper::getInstance().getLongHash(expName); + auto op1 = sd::ops::OpRegistrator::getInstance().getOperation("concat"); + std::string expName("concat"); + auto hash = sd::ops::HashHelper::getInstance().getLongHash(expName); - auto op2 = sd::ops::OpRegistrator::getInstance().getOperation(hash); + auto op2 = sd::ops::OpRegistrator::getInstance().getOperation(hash); - ASSERT_TRUE(op1 == op2); + ASSERT_TRUE(op1 == op2); } - TEST_F(DeclarableOpsTests1, SynonymInitialization2) { - auto op = sd::ops::OpRegistrator::getInstance().getOperation("Mul"); - auto op2 = sd::ops::OpRegistrator::getInstance().getOperation("multiply"); + auto op = sd::ops::OpRegistrator::getInstance().getOperation("Mul"); + auto op2 = sd::ops::OpRegistrator::getInstance().getOperation("multiply"); - ASSERT_TRUE(op != nullptr); - std::string expName("multiply"); - ASSERT_EQ(expName, *(op->getOpName())); - ASSERT_TRUE(op == op2); + ASSERT_TRUE(op != nullptr); + std::string expName("multiply"); + ASSERT_EQ(expName, op->getOpName()); + ASSERT_TRUE(op == op2); } - TEST_F(DeclarableOpsTests1, TestTensorMmul1) { + NDArray x('c', {2, 3, 4}, sd::DataType::FLOAT32); + NDArray y('c', {2, 3, 4}, sd::DataType::FLOAT32); - NDArray x('c', { 2, 3, 4 }, sd::DataType::FLOAT32); - NDArray y('c', { 2, 3, 4 }, sd::DataType::FLOAT32); - - x.linspace(1); - y.linspace(1); - - NDArray exp('c', { 2, 2 }, { 650.0, 1586.0, 1586.0, 4250.0 }, sd::DataType::FLOAT32); - - sd::ops::tensormmul op; - auto results = op.evaluate({ &x, &y }, {}, { 2,1,2,2,1,2 }); + x.linspace(1); + y.linspace(1); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + NDArray exp('c', {2, 2}, {650.0, 1586.0, 1586.0, 4250.0}, + sd::DataType::FLOAT32); - auto* out = results.at(0); + sd::ops::tensormmul op; + auto results = op.evaluate({&x, &y}, {}, {2, 1, 2, 2, 1, 2}); - ASSERT_TRUE(exp.isSameShape(out)); - ASSERT_TRUE(exp.equalsTo(out)); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto out = results.at(0); + ASSERT_TRUE(exp.isSameShape(out)); + ASSERT_TRUE(exp.equalsTo(out)); } TEST_F(DeclarableOpsTests1, TestTensorDot2) { + NDArray x('f', {2, 3, 4}, + {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., + 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24.}, + sd::DataType::FLOAT32); + NDArray y('f', {2, 3, 4}, + {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., + 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24.}, + sd::DataType::FLOAT32); - NDArray x('f', { 2, 3, 4 }, { 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24. }, sd::DataType::FLOAT32); - NDArray y('f', { 2, 3, 4 }, { 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24. }, sd::DataType::FLOAT32); + NDArray exp('c', {2, 2}, {2300.0, 2444.0, 2444.0, 2600.0}, + sd::DataType::FLOAT32); - NDArray exp('c', { 2, 2 }, { 2300.0, 2444.0, 2444.0, 2600.0 }, sd::DataType::FLOAT32); + sd::ops::tensormmul op; + auto results = op.evaluate({&x, &y}, {}, {2, 1, 2, 2, 1, 2}); - sd::ops::tensormmul op; - auto results = op.evaluate({ &x, &y }, {}, { 2,1,2,2,1,2 }); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto* out = results.at(0); - - ASSERT_TRUE(exp.isSameShape(out)); - ASSERT_TRUE(exp.equalsTo(out)); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto out = results.at(0); + ASSERT_TRUE(exp.isSameShape(out)); + ASSERT_TRUE(exp.equalsTo(out)); } TEST_F(DeclarableOpsTests1, TestTensorDot3) { + NDArray x('c', {2, 3, 4}, + {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., + 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24.}, + sd::DataType::FLOAT32); + NDArray y('f', {2, 3, 4}, + {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., + 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24.}, + sd::DataType::FLOAT32); - NDArray x('c', { 2, 3, 4 }, { 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24. }, sd::DataType::FLOAT32); - NDArray y('f', { 2, 3, 4 }, { 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24. }, sd::DataType::FLOAT32); - - NDArray exp('f', { 2, 2 }, { 1090.0, 2818.0, 1168.0, 3040.0 }, sd::DataType::FLOAT32); + NDArray exp('f', {2, 2}, {1090.0, 2818.0, 1168.0, 3040.0}, + sd::DataType::FLOAT32); - sd::ops::tensormmul op; - auto results = op.evaluate({ &x, &y }, {}, { 2,1,2,2,1,2 }); + sd::ops::tensormmul op; + auto results = op.evaluate({&x, &y}, {}, {2, 1, 2, 2, 1, 2}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto* out = results.at(0); - - ASSERT_TRUE(exp.isSameShape(out)); - ASSERT_TRUE(exp.equalsTo(out)); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto out = results.at(0); + ASSERT_TRUE(exp.isSameShape(out)); + ASSERT_TRUE(exp.equalsTo(out)); } TEST_F(DeclarableOpsTests1, TestTensorDot4) { + NDArray x('f', {2, 3, 4}, + {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., + 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24.}, + sd::DataType::FLOAT32); + NDArray y('c', {2, 3, 4}, + {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., + 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24.}, + sd::DataType::FLOAT32); - NDArray x('f', { 2, 3, 4 }, { 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24. }, sd::DataType::FLOAT32); - NDArray y('c', { 2, 3, 4 }, { 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24. }, sd::DataType::FLOAT32); - - NDArray exp('f', { 2, 2 }, { 1090.0, 1168.0, 2818.0, 3040.0 }, sd::DataType::FLOAT32); - - sd::ops::tensormmul op; - auto results = op.evaluate({ &x, &y }, {}, { 2,1,2,2,1,2 }); + NDArray exp('f', {2, 2}, {1090.0, 1168.0, 2818.0, 3040.0}, + sd::DataType::FLOAT32); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::tensormmul op; + auto results = op.evaluate({&x, &y}, {}, {2, 1, 2, 2, 1, 2}); - auto* out = results.at(0); - - ASSERT_TRUE(exp.isSameShape(out)); - ASSERT_TRUE(exp.equalsTo(out)); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto out = results.at(0); + ASSERT_TRUE(exp.isSameShape(out)); + ASSERT_TRUE(exp.equalsTo(out)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, TestTensorDot5) { + auto x = NDArrayFactory::create( + 'c', {2, 3, 4}, {1, 3, 5, 7, 9, 11, 13, 15, 1, 3, 5, 7, + 9, 11, 13, 15, 1, 3, 5, 7, 9, 11, 13, 15}); + auto y = NDArrayFactory::create( + 'c', {2, 4, 3}, {2, 4, 6, 8, 10, 12, 14, 16, 2, 4, 6, 8, + 10, 12, 14, 16, 2, 4, 6, 8, 10, 12, 14, 16}); + auto expected = NDArrayFactory::create( + 'c', {2, 4, 2, 4}, + {44, 110, 160, 66, 132, 38, 88, 154, 68, 170, 224, 102, 204, + 82, 136, 238, 92, 230, 288, 138, 276, 126, 184, 322, 116, 290, + 352, 174, 348, 170, 232, 406, 76, 190, 160, 114, 228, 182, 152, + 266, 100, 250, 224, 150, 300, 226, 200, 350, 124, 310, 288, 186, + 372, 270, 248, 434, 148, 370, 352, 222, 444, 314, 296, 518}); - auto x = NDArrayFactory::create('c', { 2,3,4 }, { 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15 }); - auto y = NDArrayFactory::create('c', { 2,4,3 }, { 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16 }); - auto expected = NDArrayFactory::create('c', { 2,4,2,4 }, { 44,110,160, 66,132, 38, 88,154, 68,170,224,102,204, 82,136,238, 92,230,288,138,276,126,184,322, 116,290,352,174,348,170,232,406, 76,190,160,114,228,182,152,266, 100,250,224,150,300,226,200,350, 124,310,288,186,372,270,248,434, 148,370,352,222,444,314,296,518 }); - - sd::ops::tensormmul op; - auto results = op.evaluate({ &x, &y }, {}, { 1,1,1,2 }); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto* result = results.at(0); - - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); + sd::ops::tensormmul op; + auto results = op.evaluate({&x, &y}, {}, {1, 1, 1, 2}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto result = results.at(0); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } - //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, TestTensorDot6) { + auto x = NDArrayFactory::create( + 'c', {2, 3, 4}, {1, 3, 5, 7, 9, 11, 13, 15, 1, 3, 5, 7, + 9, 11, 13, 15, 1, 3, 5, 7, 9, 11, 13, 15}); + auto y = NDArrayFactory::create( + 'f', {2, 4, 3}, {2, 4, 6, 8, 10, 12, 14, 16, 2, 4, 6, 8, + 10, 12, 14, 16, 2, 4, 6, 8, 10, 12, 14, 16}); + auto expected = NDArrayFactory::create( + 'c', {2, 4, 2, 4}, + {22, 66, 110, 154, 44, 88, 132, 176, 34, 102, 170, 238, 68, + 136, 204, 272, 46, 138, 230, 322, 92, 184, 276, 368, 58, 174, + 290, 406, 116, 232, 348, 464, 38, 114, 190, 266, 76, 152, 228, + 304, 50, 150, 250, 350, 100, 200, 300, 400, 62, 186, 310, 434, + 124, 248, 372, 496, 74, 222, 370, 518, 148, 296, 444, 592}); - auto x = NDArrayFactory::create('c', { 2,3,4 }, { 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15 }); - auto y = NDArrayFactory::create('f', { 2,4,3 }, { 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16 }); - auto expected = NDArrayFactory::create('c', { 2,4,2,4 }, { 22, 66,110,154, 44, 88,132,176, 34,102,170,238, 68,136,204,272, 46,138,230,322, 92,184,276,368, 58,174,290,406,116,232,348,464, 38,114,190,266, 76,152,228,304, 50,150,250,350,100,200,300,400, 62,186,310,434,124,248,372,496, 74,222,370,518,148,296,444,592 }); - - sd::ops::tensormmul op; - auto results = op.evaluate({ &x, &y }, {}, { 1,1,1,2 }); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto* result = results.at(0); - - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); + sd::ops::tensormmul op; + auto results = op.evaluate({&x, &y}, {}, {1, 1, 1, 2}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto result = results.at(0); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, TestTensorDot7) { + auto x = NDArrayFactory::create( + 'f', {2, 3, 4}, {1, 3, 5, 7, 9, 11, 13, 15, 1, 3, 5, 7, + 9, 11, 13, 15, 1, 3, 5, 7, 9, 11, 13, 15}); + auto y = NDArrayFactory::create( + 'c', {2, 4, 3}, {2, 4, 6, 8, 10, 12, 14, 16, 2, 4, 6, 8, + 10, 12, 14, 16, 2, 4, 6, 8, 10, 12, 14, 16}); + auto expected = NDArrayFactory::create( + 'c', {2, 4, 2, 4}, + {76, 166, 112, 106, 196, 62, 136, 226, 60, 174, 208, 98, 212, + 230, 136, 250, 76, 214, 336, 122, 260, 174, 168, 306, 124, 286, + 240, 178, 340, 150, 232, 394, 100, 226, 176, 142, 268, 106, 184, + 310, 84, 234, 272, 134, 284, 274, 184, 334, 100, 274, 400, 158, + 332, 218, 216, 390, 148, 346, 304, 214, 412, 194, 280, 478}); - auto x = NDArrayFactory::create('f', { 2,3,4 }, { 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15 }); - auto y = NDArrayFactory::create('c', { 2,4,3 }, { 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16 }); - auto expected = NDArrayFactory::create('c', { 2,4,2,4 }, { 76,166,112,106,196, 62,136,226, 60,174,208, 98,212,230,136,250, 76,214,336,122,260,174,168,306, 124,286,240,178,340,150,232,394, 100,226,176,142,268,106,184,310, 84,234,272,134,284,274,184,334, 100,274,400,158,332,218,216,390, 148,346,304,214,412,194,280,478 }); - - sd::ops::tensormmul op; - auto results = op.evaluate({ &x, &y }, {}, { 1,1,1,2 }); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto* result = results.at(0); - - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); + sd::ops::tensormmul op; + auto results = op.evaluate({&x, &y}, {}, {1, 1, 1, 2}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto result = results.at(0); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, TestTensorDot8) { + auto x = NDArrayFactory::create( + 'f', {2, 3, 4}, {1, 3, 5, 7, 9, 11, 13, 15, 1, 3, 5, 7, + 9, 11, 13, 15, 1, 3, 5, 7, 9, 11, 13, 15}); + auto y = NDArrayFactory::create( + 'f', {2, 4, 3}, {2, 4, 6, 8, 10, 12, 14, 16, 2, 4, 6, 8, + 10, 12, 14, 16, 2, 4, 6, 8, 10, 12, 14, 16}); + auto expected = NDArrayFactory::create( + 'c', {2, 4, 2, 4}, + {30, 90, 150, 210, 60, 120, 180, 240, 38, 114, 190, 266, 76, + 152, 228, 304, 46, 138, 230, 322, 92, 184, 276, 368, 54, 162, + 270, 378, 108, 216, 324, 432, 42, 126, 210, 294, 84, 168, 252, + 336, 50, 150, 250, 350, 100, 200, 300, 400, 58, 174, 290, 406, + 116, 232, 348, 464, 66, 198, 330, 462, 132, 264, 396, 528}); - auto x = NDArrayFactory::create('f', { 2,3,4 }, { 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15 }); - auto y = NDArrayFactory::create('f', { 2,4,3 }, { 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16 }); - auto expected = NDArrayFactory::create('c', { 2,4,2,4 }, { 30, 90,150,210, 60,120,180,240, 38,114,190,266, 76,152,228,304, 46,138,230,322, 92,184,276,368, 54,162,270,378,108,216,324,432, 42,126,210,294, 84,168,252,336, 50,150,250,350,100,200,300,400, 58,174,290,406,116,232,348,464, 66,198,330,462,132,264,396,528 }); - - sd::ops::tensormmul op; - auto results = op.evaluate({ &x, &y }, {}, { 1,1,1,2 }); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto* result = results.at(0); - - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); + sd::ops::tensormmul op; + auto results = op.evaluate({&x, &y}, {}, {1, 1, 1, 2}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto result = results.at(0); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, TestTensorDot9) { - - // NDArray z('f',{2,2,3}, sd::DataType::DOUBLE); - // z.linspace(1); - // z.printShapeInfo(); - // z.printIndexedBuffer(); - // z.reshapei('c', {4,3}); - // z.printShapeInfo(); - // z.printIndexedBuffer(); - - auto x = NDArrayFactory::create('f', { 2,3,4 }, { 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15 }); - auto y = NDArrayFactory::create('f', { 2,4,3 }, { 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16 }); - auto expected = NDArrayFactory::create('c', { 3,4,4,3 }, { 14, 14, 14, 30, 30, 30, 46, 46, 46, 62, 62, 62, 86, 86, 86,198,198,198,310,310,310,422,422,422, 62, 62, 62,142,142,142,222,222,222,302,302,302, 38, 38, 38, 86, 86, 86,134,134,134,182,182,182, 38, 38, 38, 86, 86, 86,134,134,134,182,182,182, 14, 14, 14, 30, 30, 30, 46, 46, 46, 62, 62, 62, 86, 86, 86,198,198,198,310,310,310,422,422,422, 62, 62, 62,142,142,142,222,222,222,302,302,302, 62, 62, 62,142,142,142,222,222,222,302,302,302, 38, 38, 38, 86, 86, 86,134,134,134,182,182,182, 14, 14, 14, 30, 30, 30, 46, 46, 46, 62, 62, 62, 86, 86, 86,198,198,198,310,310,310,422,422,422 }); - - sd::ops::tensormmul op; - auto results = op.evaluate({ &x, &y }, {}, { 1,0,1,0 }); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto* result = results.at(0); - - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); - - + // NDArray z('f',{2,2,3}, sd::DataType::DOUBLE); + // z.linspace(1); + // z.printShapeInfo(); + // z.printIndexedBuffer(); + // z.reshapei('c', {4,3}); + // z.printShapeInfo(); + // z.printIndexedBuffer(); + + auto x = NDArrayFactory::create( + 'f', {2, 3, 4}, {1, 3, 5, 7, 9, 11, 13, 15, 1, 3, 5, 7, + 9, 11, 13, 15, 1, 3, 5, 7, 9, 11, 13, 15}); + auto y = NDArrayFactory::create( + 'f', {2, 4, 3}, {2, 4, 6, 8, 10, 12, 14, 16, 2, 4, 6, 8, + 10, 12, 14, 16, 2, 4, 6, 8, 10, 12, 14, 16}); + auto expected = NDArrayFactory::create( + 'c', {3, 4, 4, 3}, + {14, 14, 14, 30, 30, 30, 46, 46, 46, 62, 62, 62, 86, 86, + 86, 198, 198, 198, 310, 310, 310, 422, 422, 422, 62, 62, 62, 142, + 142, 142, 222, 222, 222, 302, 302, 302, 38, 38, 38, 86, 86, 86, + 134, 134, 134, 182, 182, 182, 38, 38, 38, 86, 86, 86, 134, 134, + 134, 182, 182, 182, 14, 14, 14, 30, 30, 30, 46, 46, 46, 62, + 62, 62, 86, 86, 86, 198, 198, 198, 310, 310, 310, 422, 422, 422, + 62, 62, 62, 142, 142, 142, 222, 222, 222, 302, 302, 302, 62, 62, + 62, 142, 142, 142, 222, 222, 222, 302, 302, 302, 38, 38, 38, 86, + 86, 86, 134, 134, 134, 182, 182, 182, 14, 14, 14, 30, 30, 30, + 46, 46, 46, 62, 62, 62, 86, 86, 86, 198, 198, 198, 310, 310, + 310, 422, 422, 422}); + + sd::ops::tensormmul op; + auto results = op.evaluate({&x, &y}, {}, {1, 0, 1, 0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } - //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, TestTensorDot10) { + auto x = NDArrayFactory::create( + 'f', {2, 3, 4}, {1, 3, 5, 7, 9, 11, 13, 15, 1, 3, 5, 7, + 9, 11, 13, 15, 1, 3, 5, 7, 9, 11, 13, 15}); + auto y = NDArrayFactory::create( + 'f', {2, 4, 3}, {2, 4, 6, 8, 10, 12, 14, 16, 2, 4, 6, 8, + 10, 12, 14, 16, 2, 4, 6, 8, 10, 12, 14, 16}); + auto expected = + NDArrayFactory::create('c', {4, 4}, + {114, 258, 402, 546, 138, 314, 490, 666, + 162, 370, 578, 786, 186, 426, 666, 906}); - auto x = NDArrayFactory::create('f', { 2,3,4 }, { 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15 }); - auto y = NDArrayFactory::create('f', { 2,4,3 }, { 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16 }); - auto expected = NDArrayFactory::create('c', { 4,4 }, { 114,258,402,546, 138,314,490,666, 162,370,578,786, 186,426,666,906 }); - - sd::ops::tensormmul op; - auto results = op.evaluate({ &x, &y }, {}, { 2,0,1, 2,0,2 }); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto* result = results.at(0); - - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); + sd::ops::tensormmul op; + auto results = op.evaluate({&x, &y}, {}, {2, 0, 1, 2, 0, 2}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto result = results.at(0); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } - //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, TestTensorDot11) { + auto x = NDArrayFactory::create( + 'c', {2, 3, 4}, {1, 3, 5, 7, 9, 11, 13, 15, 1, 3, 5, 7, + 9, 11, 13, 15, 1, 3, 5, 7, 9, 11, 13, 15}); + auto y = NDArrayFactory::create( + 'f', {2, 4, 3}, {2, 4, 6, 8, 10, 12, 14, 16, 2, 4, 6, 8, + 10, 12, 14, 16, 2, 4, 6, 8, 10, 12, 14, 16}); + auto expected = + NDArrayFactory::create('c', {4, 4}, + {98, 218, 338, 458, 134, 302, 470, 638, + 170, 386, 602, 818, 206, 470, 734, 998}); - auto x = NDArrayFactory::create('c', { 2,3,4 }, { 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15 }); - auto y = NDArrayFactory::create('f', { 2,4,3 }, { 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16 }); - auto expected = NDArrayFactory::create('c', { 4,4 }, { 98,218,338,458, 134,302,470,638, 170,386,602,818, 206,470,734,998 }); - - sd::ops::tensormmul op; - auto results = op.evaluate({ &x, &y }, {}, { 2,0,1, 2,0,2 }); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto* result = results.at(0); - - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); + sd::ops::tensormmul op; + auto results = op.evaluate({&x, &y}, {}, {2, 0, 1, 2, 0, 2}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto result = results.at(0); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, TestTensorDot12) { + auto x = NDArrayFactory::create( + 'c', {2, 3, 4}, {1, 3, 5, 7, 9, 11, 13, 15, 1, 3, 5, 7, + 9, 11, 13, 15, 1, 3, 5, 7, 9, 11, 13, 15}); + auto y = NDArrayFactory::create( + 'c', {2, 4, 3}, {2, 4, 6, 8, 10, 12, 14, 16, 2, 4, 6, 8, + 10, 12, 14, 16, 2, 4, 6, 8, 10, 12, 14, 16}); + auto expected = + NDArrayFactory::create('c', {4, 4}, + {272, 292, 312, 332, 368, 396, 424, 452, + 464, 500, 536, 572, 560, 604, 648, 692}); - auto x = NDArrayFactory::create('c', { 2,3,4 }, { 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15 }); - auto y = NDArrayFactory::create('c', { 2,4,3 }, { 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16 }); - auto expected = NDArrayFactory::create('c', { 4,4 }, { 272,292,312,332, 368,396,424,452, 464,500,536,572, 560,604,648,692 }); - - sd::ops::tensormmul op; - auto results = op.evaluate({ &x, &y }, {}, { 2,0,1, 2,0,2 }); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto* result = results.at(0); - - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); + sd::ops::tensormmul op; + auto results = op.evaluate({&x, &y}, {}, {2, 0, 1, 2, 0, 2}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto result = results.at(0); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, TestTensorDot13) { + auto x = NDArrayFactory::create( + 'c', {2, 3, 4}, {1, 3, 5, 7, 9, 11, 13, 15, 1, 3, 5, 7, + 9, 11, 13, 15, 1, 3, 5, 7, 9, 11, 13, 15}); + auto y = NDArrayFactory::create( + 'c', {4, 2, 3}, {2, 4, 6, 8, 10, 12, 14, 16, 2, 4, 6, 8, + 10, 12, 14, 16, 2, 4, 6, 8, 10, 12, 14, 16}); + auto expected = NDArrayFactory::create( + 'c', {3, 3}, {640, 560, 640, 576, 624, 576, 640, 560, 640}); - auto x = NDArrayFactory::create('c', { 2,3,4 }, { 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15 }); - auto y = NDArrayFactory::create('c', { 4,2,3 }, { 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16 }); - auto expected = NDArrayFactory::create('c', { 3,3 }, { 640,560,640, 576,624,576, 640,560,640 }); - - sd::ops::tensormmul op; - auto results = op.evaluate({ &x, &y }, {}, { 2,0,2, 2,1,0 }); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto* result = results.at(0); - - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); + sd::ops::tensormmul op; + auto results = op.evaluate({&x, &y}, {}, {2, 0, 2, 2, 1, 0}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto result = results.at(0); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, TestTensorDot14) { + auto x = NDArrayFactory::create( + 'f', {2, 3, 4}, {1, 3, 5, 7, 9, 11, 13, 15, 1, 3, 5, 7, + 9, 11, 13, 15, 1, 3, 5, 7, 9, 11, 13, 15}); + auto y = NDArrayFactory::create( + 'c', {4, 2, 3}, {2, 4, 6, 8, 10, 12, 14, 16, 2, 4, 6, 8, + 10, 12, 14, 16, 2, 4, 6, 8, 10, 12, 14, 16}); + auto expected = NDArrayFactory::create( + 'c', {3, 3}, {648, 600, 520, 648, 536, 648, 520, 600, 648}); - auto x = NDArrayFactory::create('f', { 2,3,4 }, { 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15 }); - auto y = NDArrayFactory::create('c', { 4,2,3 }, { 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16 }); - auto expected = NDArrayFactory::create('c', { 3,3 }, { 648,600,520, 648,536,648, 520,600,648 }); - - sd::ops::tensormmul op; - auto results = op.evaluate({ &x, &y }, {}, { 2,0,2, 2,1,0 }); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto* result = results.at(0); - - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); + sd::ops::tensormmul op; + auto results = op.evaluate({&x, &y}, {}, {2, 0, 2, 2, 1, 0}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto result = results.at(0); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, TestTensorDot15) { + auto x = NDArrayFactory::create( + 'f', {2, 3, 4}, {1, 3, 5, 7, 9, 11, 13, 15, 1, 3, 5, 7, + 9, 11, 13, 15, 1, 3, 5, 7, 9, 11, 13, 15}); + auto y = NDArrayFactory::create( + 'f', {4, 2, 3}, {2, 4, 6, 8, 10, 12, 14, 16, 2, 4, 6, 8, + 10, 12, 14, 16, 2, 4, 6, 8, 10, 12, 14, 16}); + auto expected = NDArrayFactory::create( + 'c', {3, 3}, {624, 624, 624, 656, 656, 656, 624, 624, 624}); - auto x = NDArrayFactory::create('f', { 2,3,4 }, { 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15 }); - auto y = NDArrayFactory::create('f', { 4,2,3 }, { 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16 }); - auto expected = NDArrayFactory::create('c', { 3,3 }, { 624,624,624, 656,656,656, 624,624,624 }); - - sd::ops::tensormmul op; - auto results = op.evaluate({ &x, &y }, {}, { 2,0,2, 2,1,0 }); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto* result = results.at(0); - - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); + sd::ops::tensormmul op; + auto results = op.evaluate({&x, &y}, {}, {2, 0, 2, 2, 1, 0}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto result = results.at(0); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, TestTensorDot16) { + NDArray x('c', {1}, std::vector{2}, sd::DataType::FLOAT32); + NDArray y('c', {2, 1, 2}, {1, 2, 3, 4}, sd::DataType::FLOAT32); + NDArray exp('c', {2, 2}, {2, 4, 6, 8}, sd::DataType::FLOAT32); - NDArray x('c', { 1 }, std::vector{2}, sd::DataType::FLOAT32); - NDArray y('c', { 2,1,2 }, { 1,2,3,4 }, sd::DataType::FLOAT32); - NDArray exp('c', { 2,2 }, { 2,4,6,8 }, sd::DataType::FLOAT32); + sd::ops::tensormmul op; + auto results = op.evaluate({&x, &y}, {}, {1, 0, 1, 1}); - sd::ops::tensormmul op; - auto results = op.evaluate({ &x, &y }, {}, { 1,0, 1,1 }); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto* result = results.at(0); - - ASSERT_TRUE(exp.isSameShape(result)); - ASSERT_TRUE(exp.equalsTo(result)); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto result = results.at(0); + ASSERT_TRUE(exp.isSameShape(result)); + ASSERT_TRUE(exp.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, TestTensorDot17) { + NDArray x('f', {16, 16}, sd::DataType::FLOAT32); + NDArray y('f', {1000, 16}, sd::DataType::FLOAT32); + NDArray z('c', {16, 1000}, sd::DataType::FLOAT32); - NDArray x('f', { 16,16 }, sd::DataType::FLOAT32); - NDArray y('f', { 1000,16 }, sd::DataType::FLOAT32); - NDArray z('c', { 16,1000 }, sd::DataType::FLOAT32); + sd::ops::tensormmul op; + auto status = op.execute({&x, &y}, {&z}, {}, {1, 1, 1, 1}, {}); - sd::ops::tensormmul op; - auto status = op.execute({ &x, &y }, { &z }, {}, { 1,1, 1,1 }, {}); - - ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_EQ(ND4J_STATUS_OK, status); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, DivergentCheck1) { - auto op = sd::ops::OpRegistrator::getInstance().getOperation("switch"); + auto op = sd::ops::OpRegistrator::getInstance().getOperation("switch"); - ASSERT_TRUE(op != nullptr); - std::string expName("Switch"); - ASSERT_EQ(expName, *(op->getOpName())); - ASSERT_TRUE(op->getOpDescriptor()->isDivergent()); - ASSERT_EQ(2, op->getOpDescriptor()->getNumberOfOutputs()); + ASSERT_TRUE(op != nullptr); + std::string expName("Switch"); + ASSERT_EQ(expName, op->getOpName()); + ASSERT_TRUE(op->getOpDescriptor()->isDivergent()); + ASSERT_EQ(2, op->getOpDescriptor()->getNumberOfOutputs()); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, AddMatrices1) { + auto x = NDArrayFactory::create('c', {5, 3}); + auto y = NDArrayFactory::create('c', {5, 3}); + auto exp = NDArrayFactory::create('c', {5, 3}); + x.assign(2); + y.assign(1); + exp.assign(3); - auto x = NDArrayFactory::create_('c', { 5, 3 }); - auto y = NDArrayFactory::create_('c', { 5, 3 }); - auto exp = NDArrayFactory::create_('c', { 5, 3 }); - x->assign(2); - y->assign(1); - exp->assign(3); + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + variableSpace->putVariable(-2, y); - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, x); - variableSpace->putVariable(-2, y); - auto block = new Context(1, variableSpace, true); - block->fillInputs({ -1, -2 }); + auto block = new Context(1, variableSpace, false); + block->fillInputs({-1, -2}); - sd::ops::add addOp; + sd::ops::add addOp; - addOp.execute(block); - - ASSERT_TRUE(x->equalsTo(exp)); - - delete exp; - delete block; - delete variableSpace; + addOp.execute(block); + ASSERT_TRUE(block->getVariableSpace()->hasVariable(1)); + auto z = block->getVariableSpace()->getVariable(1)->getNDArray().get(); + ASSERT_TRUE(exp.equalsTo(z)); + delete block; + delete variableSpace; } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, AddVectorVector1) { + auto x = NDArrayFactory::create('c', {1, 15}); + auto y = NDArrayFactory::create('c', {1, 15}); + auto exp = NDArrayFactory::create('c', {1, 15}); + x.assign(2); + y.assign(1); + exp.assign(3); - auto x = NDArrayFactory::create_('c', { 1, 15 }); - auto y = NDArrayFactory::create_('c', { 1, 15 }); - auto exp = NDArrayFactory::create_('c', { 1, 15 }); - x->assign(2); - y->assign(1); - exp->assign(3); - - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, x); - variableSpace->putVariable(-2, y); - auto block = new Context(1, variableSpace, true); - block->fillInputs({ -1, -2 }); + auto variableSpace = new VariableSpace(); - sd::ops::add addOp; + variableSpace->putVariable(-1, x); + variableSpace->putVariable(-2, y); + auto block = new Context(1, variableSpace, false); - addOp.execute(block); + block->fillInputs({-1, -2}); + sd::ops::add addOp; - ASSERT_TRUE(x->equalsTo(exp)); + addOp.execute(block); + ASSERT_TRUE(block->getVariableSpace()->hasVariable(1)); + auto z = block->getVariableSpace()->getVariable(1)->getNDArray().get(); + ASSERT_TRUE(exp.equalsTo(z)); - delete exp; - delete block; - delete variableSpace; + delete block; + delete variableSpace; } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, AddMatrixScalar1) { + auto x = NDArrayFactory::create('c', {5, 3}); + auto y = NDArrayFactory::create('c', {1, 1}); + auto exp = NDArrayFactory::create('c', {5, 3}); + x.assign(2); + y.assign(1); + exp.assign(3); - auto x = NDArrayFactory::create_('c', { 5, 3 }); - auto y = NDArrayFactory::create_('c', { 1, 1 }); - auto exp = NDArrayFactory::create('c', { 5, 3 }); - x->assign(2); - y->assign(1); - exp.assign(3); - - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, x); - variableSpace->putVariable(-2, y); - auto block = new Context(1, variableSpace, true); - block->fillInputs({ -1, -2 }); + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + variableSpace->putVariable(-2, y); - sd::ops::add addOp; + auto block = new Context(1, variableSpace, false); + block->fillInputs({-1, -2}); - addOp.execute(block); + sd::ops::add addOp; - ASSERT_TRUE(x->equalsTo(&exp)); + addOp.execute(block); + ASSERT_TRUE(block->getVariableSpace()->hasVariable(1)); + auto z = block->getVariableSpace()->getVariable(1)->getNDArray().get(); + ASSERT_TRUE(exp.equalsTo(z)); - delete variableSpace; - delete block; + delete variableSpace; + delete block; } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, AddScalarScalar1) { + auto x = NDArrayFactory::create('c', {1, 1}); + auto y = NDArrayFactory::create('c', {1, 1}); + auto exp = NDArrayFactory::create('c', {1, 1}); + x.assign(2); + y.assign(1); + exp.assign(3); - auto x = NDArrayFactory::create_('c', { 1, 1 }); - auto y = NDArrayFactory::create_('c', { 1, 1 }); - auto exp = NDArrayFactory::create('c', { 1, 1 }); - x->assign(2); - y->assign(1); - exp.assign(3); + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + variableSpace->putVariable(-2, y); - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, x); - variableSpace->putVariable(-2, y); - auto block = new Context(1, variableSpace, true); - block->fillInputs({ -1, -2 }); + auto block = new Context(1, variableSpace, false); + block->fillInputs({-1, -2}); - sd::ops::add addOp; + sd::ops::add addOp; - addOp.execute(block); + addOp.execute(block); - ASSERT_TRUE(x->equalsTo(&exp)); + ASSERT_TRUE(block->getVariableSpace()->hasVariable(1)); + auto z = block->getVariableSpace()->getVariable(1)->getNDArray().get(); + ASSERT_TRUE(exp.equalsTo(z)); - delete variableSpace; - delete block; + delete variableSpace; + delete block; } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, SubtractMatrices1) { + auto x = NDArrayFactory::create('c', {5, 3}); + auto y = NDArrayFactory::create('c', {5, 3}); + auto exp = NDArrayFactory::create('c', {5, 3}); + x.assign(3); + y.assign(1); + exp.assign(2); - auto x = NDArrayFactory::create_('c', { 5, 3 }); - auto y = NDArrayFactory::create_('c', { 5, 3 }); - auto exp = NDArrayFactory::create('c', { 5, 3 }); - x->assign(3); - y->assign(1); - exp.assign(2); - - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, x); - variableSpace->putVariable(-2, y); - auto block = new Context(1, variableSpace, true); - block->fillInputs({ -1, -2 }); - - sd::ops::subtract subOp; + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + variableSpace->putVariable(-2, y); - subOp.execute(block); + auto block = new Context(1, variableSpace, false); + block->fillInputs({-1, -2}); - ASSERT_TRUE(x->equalsTo(&exp)); + sd::ops::subtract subOp; + subOp.execute(block); + ASSERT_TRUE(block->getVariableSpace()->hasVariable(1)); + auto z = block->getVariableSpace()->getVariable(1)->getNDArray().get(); + ASSERT_TRUE(exp.equalsTo(z)); - delete variableSpace; - delete block; + delete variableSpace; + delete block; } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, SubtractTest_1) { + auto x = NDArrayFactory::create('c', {1, 6}); + auto y = NDArrayFactory::create('c', {1, 6}); + auto exp = NDArrayFactory::create('c', {1, 6}); + x.assign(3); + y.assign(1); + exp.assign(2); - auto x = NDArrayFactory::create_('c', { 1, 6 }); - auto y = NDArrayFactory::create_('c', { 1, 6 }); - auto exp = NDArrayFactory::create('c', { 1, 6 }); - x->assign(3); - y->assign(1); - exp.assign(2); - - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, x); - variableSpace->putVariable(-2, y); - auto block = new Context(1, variableSpace, true); - block->fillInputs({ -1, -2 }); + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + variableSpace->putVariable(-2, y); - sd::ops::subtract subOp; + auto block = new Context(1, variableSpace, false); + block->fillInputs({-1, -2}); - subOp.execute(block); + sd::ops::subtract subOp; - ASSERT_TRUE(x->equalsTo(&exp)); + subOp.execute(block); + ASSERT_TRUE(block->getVariableSpace()->hasVariable(1)); + auto z = block->getVariableSpace()->getVariable(1)->getNDArray().get(); + ASSERT_TRUE(exp.equalsTo(z)); - delete variableSpace; - delete block; + delete variableSpace; + delete block; } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, SubtractTest_2) { + auto x = NDArrayFactory::create('c', {3, 4, 5, 1}); + auto y = NDArrayFactory::create('c', {1, 6}); + // auto y({6}, {1,1,1,1,1,1}); + auto exp = NDArrayFactory::create('c', {3, 4, 5, 6}); + x.assign(3); + y.assign(1); + exp.assign(2); - auto x = NDArrayFactory::create('c', { 3, 4, 5, 1 }); - auto y = NDArrayFactory::create('c', { 1, 6 }); - // auto y({6}, {1,1,1,1,1,1}); - auto exp = NDArrayFactory::create('c', { 3, 4, 5, 6 }); - x.assign(3); - y.assign(1); - exp.assign(2); - + sd::ops::subtract subOp; - sd::ops::subtract subOp; - - auto res = subOp.evaluate({ &x, &y }); - - ASSERT_TRUE(res.status() == ND4J_STATUS_OK); - - ASSERT_TRUE(res.at(0)->equalsTo(&exp)); + auto res = subOp.evaluate({&x, &y}); + ASSERT_TRUE(res.status() == ND4J_STATUS_OK); + ASSERT_TRUE(res.at(0).equalsTo(&exp)); } TEST_F(DeclarableOpsTests1, TestRng1) { - /* - Nd4jLong *buffer = new Nd4jLong[100000]; + /* + Nd4jLong *buffer = new Nd4jLong[100000]; - sd::random::RandomBuffer *rng = (sd::random::RandomBuffer *) initRandom(nullptr, 123, 100000, (Nd4jPointer) buffer); + sd::random::RandomBuffer *rng = (sd::random::RandomBuffer *) + initRandom(nullptr, 123, 100000, (Nd4jPointer) buffer); - if (rng == nullptr) - throw std::runtime_error("RNG initialization failed"); + if (rng == nullptr) + throw std::runtime_error("RNG initialization failed"); - auto x = NDArrayFactory::create_('c', {5, 3}); - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, x); - auto block = new Context(1, variableSpace, true); - block->fillInputs({-1}); - block->setRNG(rng); - block->getTArguments()->push_back(0.0f); - block->getTArguments()->push_back(1.0f); + auto x = NDArrayFactory::create_('c', {5, 3}); + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + auto block = new Context(1, variableSpace, true); + block->fillInputs({-1}); + block->setRNG(rng); + block->getTArguments()->push_back(0.0f); + block->getTArguments()->push_back(1.0f); - sd::ops::randomuniform uniform; + sd::ops::randomuniform uniform; - Nd4jStatus status = uniform.execute(block); + Nd4jStatus status = uniform.execute(block); - ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(x->sumNumber() > 0.0); + ASSERT_TRUE(x->sumNumber() > 0.0); - destroyRandom((Nd4jPointer) rng); - delete[] buffer; + destroyRandom((Nd4jPointer) rng); + delete[] buffer; - delete variableSpace; - delete block; - */ + delete variableSpace; + delete block; + */ } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, MergeSumTest1) { + auto x = NDArrayFactory::create('c', {5, 5}); + auto y = NDArrayFactory::create('c', {5, 5}); + auto z = NDArrayFactory::create('c', {5, 5}); + auto exp = NDArrayFactory::create('c', {5, 5}); + x.assign(3); + y.assign(1); + z.assign(2); + exp.assign(6); - auto x = NDArrayFactory::create_('c', { 5, 5 }); - auto y = NDArrayFactory::create_('c', { 5, 5 }); - auto z = NDArrayFactory::create_('c', { 5, 5 }); - auto exp = NDArrayFactory::create('c', { 5, 5 }); - x->assign(3); - y->assign(1); - z->assign(2); - exp.assign(6); - - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, x); - variableSpace->putVariable(-2, y); - variableSpace->putVariable(-3, z); - variableSpace->putVariable(1, new Variable(NDArrayFactory::create_('c', { 5, 5 }))); - auto block = new Context(1, variableSpace, false); - block->fillInputs({ -1, -2, -3 }); + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + variableSpace->putVariable(-2, y); + variableSpace->putVariable(-3, z); + variableSpace->putVariable(1, NDArrayFactory::create('c', {5, 5})); + auto block = new Context(1, variableSpace, false); + block->fillInputs({-1, -2, -3}); - sd::ops::mergeadd merge; + sd::ops::mergeadd merge; - merge.execute(block); + merge.execute(block); - auto res = variableSpace->getVariable(1)->getNDArray(); + auto res = variableSpace->getVariable(1)->getNDArray(); - ASSERT_TRUE(res->equalsTo(&exp)); + ASSERT_TRUE(res->equalsTo(&exp)); - delete variableSpace; - delete block; + delete variableSpace; + delete block; } - ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, ClipByValue1) { + auto x = NDArrayFactory::create('c', {5, 5}); + auto exp = NDArrayFactory::create('c', {5, 5}); + x.assign(4); + x.p(0, -1); + x.p(1, 2); + exp.assign(3); + exp.p(0, 0); + exp.p(1, 2); - auto x = NDArrayFactory::create_('c', { 5, 5 }); - auto exp = NDArrayFactory::create('c', { 5, 5 }); - x->assign(4); - x->p(0, -1); - x->p(1, 2); - exp.assign(3); - exp.p(0, 0); - exp.p(1, 2); + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + auto block = new Context(1, variableSpace, false); - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, x); - variableSpace->putVariable(1, new Variable()); - auto block = new Context(1, variableSpace, true); - block->getTArguments()->push_back(0.0f); - block->getTArguments()->push_back(3.0f); - block->fillInputs({ -1 }); + block->appendT(0.0f); + block->appendT(3.0f); - sd::ops::clipbyvalue clip; + block->fillInputs({-1}); - clip.execute(block); + sd::ops::clipbyvalue clip; - ASSERT_TRUE(x->equalsTo(&exp)); + clip.execute(block); + ASSERT_TRUE(block->getVariableSpace()->hasVariable(1)); + auto z = block->getVariableSpace()->getVariable(1)->getNDArray().get(); + ASSERT_TRUE(exp.equalsTo(z)); - delete variableSpace; - delete block; + delete variableSpace; + delete block; } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, MergeAvgTest1) { + auto x = NDArrayFactory::create('c', {5, 5}); + auto y = NDArrayFactory::create('c', {5, 5}); + auto z = NDArrayFactory::create('c', {5, 5}); + auto exp = NDArrayFactory::create('c', {5, 5}); + x.assign(3); + y.assign(1); + z.assign(2); + exp.assign(2); - auto x = NDArrayFactory::create_('c', { 5, 5 }); - auto y = NDArrayFactory::create_('c', { 5, 5 }); - auto z = NDArrayFactory::create_('c', { 5, 5 }); - auto exp = NDArrayFactory::create('c', { 5, 5 }); - x->assign(3); - y->assign(1); - z->assign(2); - exp.assign(2); + auto zu = NDArrayFactory::create('c', {5, 5}); - auto zu = NDArrayFactory::create('c', { 5, 5 }); + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + variableSpace->putVariable(-2, y); + variableSpace->putVariable(-3, z); + auto block = new Context(1, variableSpace, false); + block->fillInputs({-1, -2, -3}); - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, x); - variableSpace->putVariable(-2, y); - variableSpace->putVariable(-3, z); - variableSpace->putVariable(1, new Variable(NDArrayFactory::create_('c', { 5, 5 }))); - auto block = new Context(1, variableSpace, false); - block->fillInputs({ -1, -2, -3 }); + sd::ops::mergeavg merge; - sd::ops::mergeavg merge; + merge.execute(block); - merge.execute(block); + auto res = variableSpace->getVariable(1)->getNDArray(); - auto res = variableSpace->getVariable(1)->getNDArray(); + ASSERT_TRUE(res->equalsTo(exp)); - ASSERT_TRUE(res->equalsTo(&exp)); - - delete block; - delete variableSpace; + delete block; + delete variableSpace; } - ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, SubtractVectorVector1) { + auto x = NDArrayFactory::create('c', {1, 15}); + auto y = NDArrayFactory::create('c', {1, 15}); + auto exp = NDArrayFactory::create('c', {1, 15}); + x.assign(3); + y.assign(1); + exp.assign(2); - auto x = NDArrayFactory::create_('c', { 1, 15 }); - auto y = NDArrayFactory::create_('c', { 1, 15 }); - auto exp = NDArrayFactory::create('c', { 1, 15 }); - x->assign(3); - y->assign(1); - exp.assign(2); - - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, x); - variableSpace->putVariable(-2, y); - auto block = new Context(1, variableSpace, true); - block->fillInputs({ -1, -2 }); + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + variableSpace->putVariable(-2, y); - sd::ops::subtract subOp; + auto block = new Context(1, variableSpace, false); + block->fillInputs({-1, -2}); - subOp.execute(block); + sd::ops::subtract subOp; - ASSERT_TRUE(x->equalsTo(&exp)); + subOp.execute(block); - delete block; - delete variableSpace; + ASSERT_TRUE(block->getVariableSpace()->hasVariable(1)); + auto z = block->getVariableSpace()->getVariable(1)->getNDArray().get(); + ASSERT_TRUE(exp.equalsTo(z)); + delete block; + delete variableSpace; } - ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, SubtractMatrixScalar1) { + auto x = NDArrayFactory::create('c', {5, 3}); + auto y = NDArrayFactory::create('c', {1, 1}); + auto exp = NDArrayFactory::create('c', {5, 3}); + x.assign(3); + y.assign(1); + exp.assign(2); - auto x = NDArrayFactory::create_('c', { 5, 3 }); - auto y = NDArrayFactory::create_('c', { 1, 1 }); - auto exp = NDArrayFactory::create('c', { 5, 3 }); - x->assign(3); - y->assign(1); - exp.assign(2); + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + variableSpace->putVariable(-2, y); - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, x); - variableSpace->putVariable(-2, y); - auto block = new Context(1, variableSpace, true); - block->fillInputs({ -1, -2 }); + auto block = new Context(1, variableSpace, false); + block->fillInputs({-1, -2}); - sd::ops::subtract subOp; + sd::ops::subtract subOp; - subOp.execute(block); + subOp.execute(block); - ASSERT_TRUE(x->equalsTo(&exp)); + ASSERT_TRUE(block->getVariableSpace()->hasVariable(1)); + auto z = block->getVariableSpace()->getVariable(1)->getNDArray().get(); + ASSERT_TRUE(exp.equalsTo(z)); - delete block; - delete variableSpace; + delete block; + delete variableSpace; } - ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, SubtractScalarScalar1) { + auto x = NDArrayFactory::create('c', {1, 1}); + auto y = NDArrayFactory::create('c', {1, 1}); + auto exp = NDArrayFactory::create('c', {1, 1}); + x.assign(3); + y.assign(1); + exp.assign(2); - auto x = NDArrayFactory::create_('c', { 1, 1 }); - auto y = NDArrayFactory::create_('c', { 1, 1 }); - auto exp = NDArrayFactory::create('c', { 1, 1 }); - x->assign(3); - y->assign(1); - exp.assign(2); + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + variableSpace->putVariable(-2, y); - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, x); - variableSpace->putVariable(-2, y); - auto block = new Context(1, variableSpace, true); - block->fillInputs({ -1, -2 }); + auto block = new Context(1, variableSpace, false); + block->fillInputs({-1, -2}); - sd::ops::subtract subOp; + sd::ops::subtract subOp; - subOp.execute(block); + subOp.execute(block); - ASSERT_TRUE(x->equalsTo(&exp)); + ASSERT_TRUE(block->getVariableSpace()->hasVariable(1)); + auto z = block->getVariableSpace()->getVariable(1)->getNDArray().get(); + ASSERT_TRUE(exp.equalsTo(z)); - delete block; - delete variableSpace; + delete block; + delete variableSpace; } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, ReverseSubtractMatrices1) { + auto x = NDArrayFactory::create('c', {5, 3}); + auto y = NDArrayFactory::create('c', {5, 3}); + auto exp = NDArrayFactory::create('c', {5, 3}); + x.assign(3.f); + y.assign(1.f); + exp.assign(-2.f); - auto x = NDArrayFactory::create_('c', { 5, 3 }); - auto y = NDArrayFactory::create_('c', { 5, 3 }); - auto exp = NDArrayFactory::create('c', { 5, 3 }); - x->assign(3.f); - y->assign(1.f); - exp.assign(-2.f); + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + variableSpace->putVariable(-2, y); - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, x); - variableSpace->putVariable(-2, y); - auto block = new Context(1, variableSpace, true); - block->fillInputs({ -1, -2 }); + auto block = new Context(1, variableSpace, false); + block->fillInputs({-1, -2}); - sd::ops::reversesubtract subOp; + sd::ops::reversesubtract subOp; - subOp.execute(block); + subOp.execute(block); - ASSERT_TRUE(x->equalsTo(&exp)); + ASSERT_TRUE(block->getVariableSpace()->hasVariable(1)); + auto z = block->getVariableSpace()->getVariable(1)->getNDArray().get(); + ASSERT_TRUE(exp.equalsTo(z)); - delete variableSpace; - delete block; + delete variableSpace; + delete block; } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, ReverseSubtractTest_1) { + auto x = NDArrayFactory::create('c', {1, 6}); + auto y = NDArrayFactory::create('c', {1, 6}); + auto exp = NDArrayFactory::create('c', {1, 6}); + x.assign(3.f); + y.assign(1.f); + exp.assign(-2.f); - auto x = NDArrayFactory::create('c', { 1, 6 }); - auto y = NDArrayFactory::create('c', { 1, 6 }); - auto exp = NDArrayFactory::create('c', { 1, 6 }); - x.assign(3.f); - y.assign(1.f); - exp.assign(-2.f); - - sd::ops::reversesubtract subOp; - - auto res = subOp.evaluate({ &x, &y }); - - ASSERT_TRUE(res.status() == ND4J_STATUS_OK); - ASSERT_TRUE(res.at(0)->equalsTo(&exp)); + sd::ops::reversesubtract subOp; + auto res = subOp.evaluate({&x, &y}); + ASSERT_TRUE(res.status() == ND4J_STATUS_OK); + ASSERT_TRUE(res.at(0).equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, ReverseSubtractTest_2) { + // auto x('c', {1, 6}); + auto x = NDArrayFactory::create('c', {1, 6}); + auto y = NDArrayFactory::create('c', {3, 4, 5, 1}); + auto exp = NDArrayFactory::create('c', {3, 4, 5, 6}); + auto z(exp); + x.assign(3.f); + y.assign(1.f); + exp.assign(-2.f); + x.applyTrueBroadcast(BROADCAST(ReverseSubtract), y, z, true); - // auto x('c', {1, 6}); - auto x = NDArrayFactory::create('c', { 1, 6 }); - auto y = NDArrayFactory::create('c', { 3, 4, 5, 1 }); - auto exp = NDArrayFactory::create('c', { 3, 4, 5, 6 }); - auto z(exp); - x.assign(3.f); - y.assign(1.f); - exp.assign(-2.f); - x.applyTrueBroadcast(BROADCAST(ReverseSubtract), y, z, true); - - ASSERT_TRUE(exp.equalsTo(&z)); - - sd::ops::reversesubtract subOp; + ASSERT_TRUE(exp.equalsTo(z)); - auto res = subOp.evaluate({ &x, &y }); - - ASSERT_TRUE(res.status() == ND4J_STATUS_OK); - ASSERT_TRUE(res.at(0)->equalsTo(&exp)); + sd::ops::reversesubtract subOp; + auto res = subOp.evaluate({&x, &y}); + ASSERT_TRUE(res.status() == ND4J_STATUS_OK); + ASSERT_TRUE(res.at(0).equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, ReverseSubtractTest_3) { + // auto x('c', {1, 6}); + auto x = NDArrayFactory::create('c', {6}); + auto y = NDArrayFactory::create('c', {3, 4, 5, 1}); + auto exp = NDArrayFactory::create('c', {3, 4, 5, 6}); + auto z(exp); + x.assign(1); + y.assign(3); + exp.assign(2); + x.applyTrueBroadcast(BROADCAST(ReverseSubtract), y, z, true); + ASSERT_TRUE(z.equalsTo(&exp)); + sd::ops::reversesubtract subOp; - // auto x('c', {1, 6}); - auto x = NDArrayFactory::create('c', { 6 }); - auto y = NDArrayFactory::create('c', { 3, 4, 5, 1 }); - auto exp = NDArrayFactory::create('c', { 3, 4, 5, 6 }); - auto z(exp); - x.assign(1); - y.assign(3); - exp.assign(2); - x.applyTrueBroadcast(BROADCAST(ReverseSubtract), y, z, true); - ASSERT_TRUE(z.equalsTo(&exp)); - sd::ops::reversesubtract subOp; - - auto res = subOp.evaluate({ &x, &y }); - ASSERT_TRUE(res.status() == ND4J_STATUS_OK); - ASSERT_TRUE(res.at(0)->equalsTo(&exp)); - - + auto res = subOp.evaluate({&x, &y}); + ASSERT_TRUE(res.status() == ND4J_STATUS_OK); + ASSERT_TRUE(res.at(0).equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, ReverseModTest_1) { + // auto x('c', {1, 6}); + auto x = NDArrayFactory::create('c', {6}); + auto y = NDArrayFactory::create('c', {3, 4, 5, 1}); + auto exp = NDArrayFactory::create('c', {3, 4, 5, 6}); + auto z(exp); + x.assign(2.); + y.assign(9.f); + exp.assign(1.f); + y.applyTrueBroadcast(BROADCAST(Mod), x, z, true); + ASSERT_TRUE(exp.equalsTo(&z)); - // auto x('c', {1, 6}); - auto x = NDArrayFactory::create('c', { 6 }); - auto y = NDArrayFactory::create('c', { 3, 4, 5, 1 }); - auto exp = NDArrayFactory::create('c', { 3, 4, 5, 6 }); - auto z(exp); - x.assign(2.); - y.assign(9.f); - exp.assign(1.f); - y.applyTrueBroadcast(BROADCAST(Mod), x, z, true); - ASSERT_TRUE(exp.equalsTo(&z)); - - x.applyTrueBroadcast(BROADCAST(ReverseMod), y, exp, true); - ASSERT_TRUE(exp.equalsTo(&z)); - - sd::ops::reversemod subOp; - - auto res = subOp.evaluate({ &x, &y }); + x.applyTrueBroadcast(BROADCAST(ReverseMod), y, exp, true); + ASSERT_TRUE(exp.equalsTo(&z)); - ASSERT_TRUE(res.status() == ND4J_STATUS_OK); - ASSERT_TRUE(res.at(0)->equalsTo(&exp)); - ASSERT_TRUE(exp.equalsTo(&z)); + sd::ops::reversemod subOp; + auto res = subOp.evaluate({&x, &y}); + ASSERT_TRUE(res.status() == ND4J_STATUS_OK); + ASSERT_TRUE(res.at(0).equalsTo(&exp)); + ASSERT_TRUE(exp.equalsTo(&z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, ReverseModTest_2) { + // auto x('c', {1, 6}); + auto x = NDArrayFactory::create('c', {3, 4, 5}); + auto y = NDArrayFactory::create('c', {3, 4, 5}); + auto exp = NDArrayFactory::create('c', {3, 4, 5}); + auto z(exp); + x.assign(2.f); + y.assign(9.f); + exp.assign(1.f); + x.applyTrueBroadcast(BROADCAST(ReverseMod), y, z, true); + ASSERT_TRUE(z.equalsTo(&exp)); + x.applyTrueBroadcast(BROADCAST(ReverseMod), y, exp, true); + ASSERT_TRUE(z.equalsTo(&exp)); - // auto x('c', {1, 6}); - auto x = NDArrayFactory::create('c', { 3, 4, 5 }); - auto y = NDArrayFactory::create('c', { 3, 4, 5 }); - auto exp = NDArrayFactory::create('c', { 3, 4, 5 }); - auto z(exp); - x.assign(2.f); - y.assign(9.f); - exp.assign(1.f); - x.applyTrueBroadcast(BROADCAST(ReverseMod), y, z, true); - ASSERT_TRUE(z.equalsTo(&exp)); - x.applyTrueBroadcast(BROADCAST(ReverseMod), y, exp, true); - ASSERT_TRUE(z.equalsTo(&exp)); - - sd::ops::reversemod subOp; - - auto res = subOp.evaluate({ &x, &y }); - - ASSERT_TRUE(res.status() == ND4J_STATUS_OK); - ASSERT_TRUE(res.at(0)->equalsTo(&exp)); + sd::ops::reversemod subOp; + auto res = subOp.evaluate({&x, &y}); + ASSERT_TRUE(res.status() == ND4J_STATUS_OK); + ASSERT_TRUE(res.at(0).equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, ReverseSubtractVectorVector1) { + auto x = NDArrayFactory::create('c', {1, 15}); + auto y = NDArrayFactory::create('c', {1, 15}); + auto exp = NDArrayFactory::create('c', {1, 15}); + x.assign(3); + y.assign(1); + exp.assign(-2); - auto x = NDArrayFactory::create_('c', { 1, 15 }); - auto y = NDArrayFactory::create_('c', { 1, 15 }); - auto exp = NDArrayFactory::create_('c', { 1, 15 }); - x->assign(3); - y->assign(1); - exp->assign(-2); + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + variableSpace->putVariable(-2, y); - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, x); - variableSpace->putVariable(-2, y); - auto block = new Context(1, variableSpace, true); - block->fillInputs({ -1, -2 }); + auto block = new Context(1, variableSpace, false); + block->fillInputs({-1, -2}); - sd::ops::reversesubtract subOp; + sd::ops::reversesubtract subOp; - subOp.execute(block); + subOp.execute(block); - ASSERT_TRUE(x->equalsTo(exp)); + ASSERT_TRUE(block->getVariableSpace()->hasVariable(1)); + auto z = block->getVariableSpace()->getVariable(1)->getNDArray().get(); + ASSERT_TRUE(exp.equalsTo(z)); - delete variableSpace; - delete block; - delete exp; + delete variableSpace; + delete block; } - ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, ReverseSubtractMatrixScalar1) { + auto x = NDArrayFactory::create('c', {5, 3}); + auto y = NDArrayFactory::create('c', {1, 1}); + auto exp = NDArrayFactory::create('c', {5, 3}); + x.assign(3); + y.assign(1); + exp.assign(-2); - auto x = NDArrayFactory::create_('c', { 5, 3 }); - auto y = NDArrayFactory::create_('c', { 1, 1 }); - auto exp = NDArrayFactory::create_('c', { 5, 3 }); - x->assign(3); - y->assign(1); - exp->assign(-2); + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + variableSpace->putVariable(-2, y); - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, x); - variableSpace->putVariable(-2, y); - auto block = new Context(1, variableSpace, true); - block->fillInputs({ -1, -2 }); + auto block = new Context(1, variableSpace, false); + block->fillInputs({-1, -2}); - sd::ops::reversesubtract subOp; + sd::ops::reversesubtract subOp; - subOp.execute(block); + subOp.execute(block); - ASSERT_TRUE(x->equalsTo(exp)); + ASSERT_TRUE(block->getVariableSpace()->hasVariable(1)); + auto z = block->getVariableSpace()->getVariable(1)->getNDArray().get(); + ASSERT_TRUE(exp.equalsTo(z)); - delete variableSpace; - delete block; - delete exp; + delete variableSpace; + delete block; } - ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, ReverseSubtractScalarScalar1) { + auto x = NDArrayFactory::create('c', {1, 1}); + auto y = NDArrayFactory::create('c', {1, 1}); + auto exp = NDArrayFactory::create('c', {1, 1}); + x.assign(3); + y.assign(1); + exp.assign(-2); - auto x = NDArrayFactory::create_('c', { 1, 1 }); - auto y = NDArrayFactory::create_('c', { 1, 1 }); - auto exp = NDArrayFactory::create_('c', { 1, 1 }); - x->assign(3); - y->assign(1); - exp->assign(-2); + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + variableSpace->putVariable(-2, y); - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, x); - variableSpace->putVariable(-2, y); - auto block = new Context(1, variableSpace, true); - block->fillInputs({ -1, -2 }); + auto block = new Context(1, variableSpace, false); + block->fillInputs({-1, -2}); - sd::ops::reversesubtract subOp; + sd::ops::reversesubtract subOp; - subOp.execute(block); + subOp.execute(block); - ASSERT_TRUE(x->equalsTo(exp)); + ASSERT_TRUE(block->getVariableSpace()->hasVariable(1)); + auto z = block->getVariableSpace()->getVariable(1)->getNDArray().get(); + ASSERT_TRUE(exp.equalsTo(z)); - delete variableSpace; - delete block; - delete exp; + delete variableSpace; + delete block; } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, MultiplyMatrices1) { + auto x = NDArrayFactory::create('c', {5, 3}); + auto y = NDArrayFactory::create('c', {5, 3}); + auto exp = NDArrayFactory::create('c', {5, 3}); + x.assign(2); + y.assign(3); + exp.assign(6); - auto x = NDArrayFactory::create_('c', { 5, 3 }); - auto y = NDArrayFactory::create_('c', { 5, 3 }); - auto exp = NDArrayFactory::create_('c', { 5, 3 }); - x->assign(2); - y->assign(3); - exp->assign(6); + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + variableSpace->putVariable(-2, y); - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, x); - variableSpace->putVariable(-2, y); - auto block = new Context(1, variableSpace, true); - block->fillInputs({ -1, -2 }); + auto block = new Context(1, variableSpace, false); + block->fillInputs({-1, -2}); - sd::ops::multiply mul; + sd::ops::multiply mul; - mul.execute(block); + mul.execute(block); - ASSERT_TRUE(x->equalsTo(exp)); + ASSERT_TRUE(block->getVariableSpace()->hasVariable(1)); + auto z = block->getVariableSpace()->getVariable(1)->getNDArray().get(); + ASSERT_TRUE(exp.equalsTo(z)); - delete variableSpace; - delete block; - delete exp; + delete variableSpace; + delete block; } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, MultiplyVectorVector1) { + auto x = NDArrayFactory::create('c', {1, 15}); + auto y = NDArrayFactory::create('c', {1, 15}); + auto exp = NDArrayFactory::create('c', {1, 15}); + x.assign(2); + y.assign(3); + exp.assign(6); - auto x = NDArrayFactory::create_('c', { 1, 15 }); - auto y = NDArrayFactory::create_('c', { 1, 15 }); - auto exp = NDArrayFactory::create_('c', { 1, 15 }); - x->assign(2); - y->assign(3); - exp->assign(6); + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + variableSpace->putVariable(-2, y); - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, x); - variableSpace->putVariable(-2, y); - auto block = new Context(1, variableSpace, true); - block->fillInputs({ -1, -2 }); + auto block = new Context(1, variableSpace, false); + block->fillInputs({-1, -2}); - sd::ops::multiply mul; + sd::ops::multiply mul; - mul.execute(block); + mul.execute(block); - ASSERT_TRUE(x->equalsTo(exp)); + ASSERT_TRUE(block->getVariableSpace()->hasVariable(1)); + auto z = block->getVariableSpace()->getVariable(1)->getNDArray().get(); + ASSERT_TRUE(exp.equalsTo(z)); - delete variableSpace; - delete block; - delete exp; + delete variableSpace; + delete block; } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, MultiplyMatrixScalar) { + auto x = NDArrayFactory::create('c', {5, 3}); + auto y = NDArrayFactory::create('c', {1, 1}); + auto exp = NDArrayFactory::create('c', {5, 3}); + x.assign(2); + y.assign(3); + exp.assign(6); - auto x = NDArrayFactory::create_('c', { 5, 3 }); - auto y = NDArrayFactory::create_('c', { 1, 1 }); - auto exp = NDArrayFactory::create_('c', { 5, 3 }); - x->assign(2); - y->assign(3); - exp->assign(6); + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + variableSpace->putVariable(-2, y); - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, x); - variableSpace->putVariable(-2, y); - auto block = new Context(1, variableSpace, true); - block->fillInputs({ -1, -2 }); + auto block = new Context(1, variableSpace, false); + block->fillInputs({-1, -2}); - sd::ops::multiply mul; + sd::ops::multiply mul; - mul.execute(block); + mul.execute(block); - ASSERT_TRUE(x->equalsTo(exp)); + ASSERT_TRUE(block->getVariableSpace()->hasVariable(1)); + auto z = block->getVariableSpace()->getVariable(1)->getNDArray().get(); + ASSERT_TRUE(exp.equalsTo(z)); - delete variableSpace; - delete block; - delete exp; + delete variableSpace; + delete block; } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, MultiplyScalarScalar1) { + auto x = NDArrayFactory::create('c', {1, 1}); + auto y = NDArrayFactory::create('c', {1, 1}); + auto exp = NDArrayFactory::create('c', {1, 1}); + x.assign(2); + y.assign(3); + exp.assign(6); - auto x = NDArrayFactory::create_('c', { 1, 1 }); - auto y = NDArrayFactory::create_('c', { 1, 1 }); - auto exp = NDArrayFactory::create_('c', { 1, 1 }); - x->assign(2); - y->assign(3); - exp->assign(6); + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + variableSpace->putVariable(-2, y); - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, x); - variableSpace->putVariable(-2, y); - auto block = new Context(1, variableSpace, true); - block->fillInputs({ -1, -2 }); + auto block = new Context(1, variableSpace, false); + block->fillInputs({-1, -2}); - sd::ops::multiply mul; + sd::ops::multiply mul; - mul.execute(block); + mul.execute(block); - ASSERT_TRUE(x->equalsTo(exp)); + ASSERT_TRUE(block->getVariableSpace()->hasVariable(1)); + auto z = block->getVariableSpace()->getVariable(1)->getNDArray().get(); + ASSERT_TRUE(exp.equalsTo(z)); - delete block; - delete variableSpace; - delete exp; + delete block; + delete variableSpace; } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, TestSoftMax_bp_1) { + auto input = NDArrayFactory::create('c', {2, 2}); + for (int e = 0; e < input.lengthOf(); e++) input.p(e, e + 1); - auto input = NDArrayFactory::create_('c', { 2, 2 }); - for (int e = 0; e < input->lengthOf(); e++) - input->p(e, e + 1); - - auto epsilon = NDArrayFactory::create_('c', { 2, 2 }); - epsilon->p(0, 0.1f); - epsilon->p(1, 0.2f); - epsilon->p(2, 0.3f); - epsilon->p(3, 0.4f); + auto epsilon = NDArrayFactory::create('c', {2, 2}); + epsilon.p(0, 0.1f); + epsilon.p(1, 0.2f); + epsilon.p(2, 0.3f); + epsilon.p(3, 0.4f); - auto output = NDArrayFactory::create_('c', { 2, 2 }); - output->assign(1.0f); + auto output = NDArrayFactory::create('c', {2, 2}); + output.assign(1.0f); - auto exp = NDArrayFactory::create_('c', { 2, 2 }); - exp->p(0, -0.019661194f); - exp->p(1, 0.019661194f); - exp->p(2, -0.019661194f); - exp->p(3, 0.019661194f); + auto exp = NDArrayFactory::create('c', {2, 2}); + exp.p(0, -0.019661194f); + exp.p(1, 0.019661194f); + exp.p(2, -0.019661194f); + exp.p(3, 0.019661194f); - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, input); - variableSpace->putVariable(-2, epsilon); - variableSpace->putVariable(1, output); - //variableSpace->putVariable(42, exp); + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, input); + variableSpace->putVariable(-2, epsilon); + variableSpace->putVariable(1, output); + // variableSpace->putVariable(42, exp); - auto block = new Context(1, variableSpace, false); - block->fillInputs({ -1, -2 }); + auto block = new Context(1, variableSpace, false); + block->fillInputs({-1, -2}); - sd::ops::softmax_bp op; + sd::ops::softmax_bp op; - Nd4jStatus status = op.execute(block); - ASSERT_EQ(ND4J_STATUS_OK, status); + Nd4jStatus status = op.execute(block); + ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(output->equalsTo(exp)); - - delete variableSpace; - delete block; - delete exp; + ASSERT_TRUE(output.equalsTo(exp)); + delete variableSpace; + delete block; } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, BroadcastDivideTest_1) { + auto x = NDArrayFactory::create('c', {3, 4, 5, 1}); + auto y = NDArrayFactory::create('c', {1, 6}); + auto exp = NDArrayFactory::create('c', {3, 4, 5, 6}); + x.assign(6); + y.assign(2); + exp.assign(3); - auto x = NDArrayFactory::create('c', { 3, 4, 5, 1 }); - auto y = NDArrayFactory::create('c', { 1, 6 }); - auto exp = NDArrayFactory::create('c', { 3, 4, 5, 6 }); - x.assign(6); - y.assign(2); - exp.assign(3); - - sd::ops::divide div; - - auto res = div.evaluate({ &x, &y }); - - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - ASSERT_TRUE(res.at(0)->equalsTo(exp)); + sd::ops::divide div; + auto res = div.evaluate({&x, &y}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + ASSERT_TRUE(res.at(0).equalsTo(exp)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, BroadcastDivideTest_2) { + auto x = NDArrayFactory::create('c', {3, 4, 5, 1}); + auto y = NDArrayFactory::create('c', {1, 6}); + auto exp = NDArrayFactory::create('c', {3, 4, 5, 6}); + x.assign(6); + y.assign(2); + exp.assign(3); - auto x = NDArrayFactory::create('c', { 3, 4, 5, 1 }); - auto y = NDArrayFactory::create('c', { 1, 6 }); - auto exp = NDArrayFactory::create('c', { 3, 4, 5, 6 }); - x.assign(6); - y.assign(2); - exp.assign(3); - - sd::ops::divide_no_nan div; - auto res = div.evaluate({ &x, &y }); - - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - ASSERT_TRUE(res.at(0)->equalsTo(exp)); - + sd::ops::divide_no_nan div; + auto res = div.evaluate({&x, &y}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + ASSERT_TRUE(res.at(0).equalsTo(exp)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, BroadcastDivideTest_3) { + auto x = NDArrayFactory::create({6, 6, 6, 6, 6}); + auto y = NDArrayFactory::create({3, 3, 0, 3, 3}); + auto exp = NDArrayFactory::create({2, 2, 0, 2, 2}); - auto x = NDArrayFactory::create({ 6,6,6,6,6 }); - auto y = NDArrayFactory::create({ 3,3,0,3,3 }); - auto exp = NDArrayFactory::create({ 2, 2, 0, 2, 2 }); - - sd::ops::divide_no_nan div; - auto res = div.evaluate({ &x, &y }); - - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - ASSERT_TRUE(res.at(0)->equalsTo(exp)); - + sd::ops::divide_no_nan div; + auto res = div.evaluate({&x, &y}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + ASSERT_TRUE(res.at(0).equalsTo(exp)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, BroadcastReverseDivideTest_1) { + auto x = NDArrayFactory::create('c', {3, 4, 5, 1}); + auto y = NDArrayFactory::create('c', {1, 6}); + auto exp = NDArrayFactory::create('c', {3, 4, 5, 6}); + x.assign(3.f); + y.assign(6.f); + exp.assign(2.f); - auto x = NDArrayFactory::create('c', { 3, 4, 5, 1 }); - auto y = NDArrayFactory::create('c', { 1, 6 }); - auto exp = NDArrayFactory::create('c', { 3, 4, 5, 6 }); - x.assign(3.f); - y.assign(6.f); - exp.assign(2.f); + sd::ops::reversedivide div; - sd::ops::reversedivide div; + auto res = div.evaluate({&x, &y}); - auto res = div.evaluate({ &x, &y }); - - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - - ASSERT_TRUE(res.at(0)->equalsTo(exp)); - auto z(exp); - x.applyTrueBroadcast(BROADCAST(ReverseDivide), y, z, true); - y.applyTrueBroadcast(BROADCAST(Divide), x, exp, true); - - ASSERT_TRUE(z.equalsTo(&exp)); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + ASSERT_TRUE(res.at(0).equalsTo(exp)); + auto z(exp); + x.applyTrueBroadcast(BROADCAST(ReverseDivide), y, z, true); + y.applyTrueBroadcast(BROADCAST(Divide), x, exp, true); + ASSERT_TRUE(z.equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, DivideMatrices1) { + auto x = NDArrayFactory::create('c', {5, 3}); + auto y = NDArrayFactory::create('c', {5, 3}); + auto exp = NDArrayFactory::create('c', {5, 3}); + x.assign(6); + y.assign(2); + exp.assign(3); - auto x = NDArrayFactory::create_('c', { 5, 3 }); - auto y = NDArrayFactory::create_('c', { 5, 3 }); - auto exp = NDArrayFactory::create_('c', { 5, 3 }); - x->assign(6); - y->assign(2); - exp->assign(3); + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + variableSpace->putVariable(-2, y); - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, x); - variableSpace->putVariable(-2, y); - auto block = new Context(1, variableSpace, true); - block->fillInputs({ -1, -2 }); + variableSpace->putVariable(1, x); + auto block = new Context(1, variableSpace, false); + block->fillInputs({-1, -2}); - sd::ops::divide div; + sd::ops::divide div; - div.execute(block); + div.execute(block); - ASSERT_TRUE(x->equalsTo(exp)); + ASSERT_TRUE(x.equalsTo(exp)); - delete variableSpace; - delete block; - delete exp; + delete variableSpace; + delete block; } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, DivideVectorVector1) { + auto x = NDArrayFactory::create('c', {1, 15}); + auto y = NDArrayFactory::create('c', {1, 15}); + auto exp = NDArrayFactory::create('c', {1, 15}); + x.assign(6); + y.assign(2); + exp.assign(3); - auto x = NDArrayFactory::create_('c', { 1, 15 }); - auto y = NDArrayFactory::create_('c', { 1, 15 }); - auto exp = NDArrayFactory::create('c', { 1, 15 }); - x->assign(6); - y->assign(2); - exp.assign(3); + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + variableSpace->putVariable(-2, y); - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, x); - variableSpace->putVariable(-2, y); - auto block = new Context(1, variableSpace, true); - block->fillInputs({ -1, -2 }); + auto block = new Context(1, variableSpace, false); + block->fillInputs({-1, -2}); - sd::ops::divide div; + sd::ops::divide div; - div.execute(block); + div.execute(block); - ASSERT_TRUE(x->equalsTo(&exp)); + ASSERT_TRUE(block->getVariableSpace()->hasVariable(1)); + auto z = block->getVariableSpace()->getVariable(1)->getNDArray().get(); + ASSERT_TRUE(exp.equalsTo(z)); - delete variableSpace; - delete block; + delete variableSpace; + delete block; } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, DivideMatrixScalar1) { + auto x = NDArrayFactory::create('c', {5, 3}); + auto y = NDArrayFactory::create('c', {1, 1}); + auto exp = NDArrayFactory::create('c', {5, 3}); + x.assign(6); + y.assign(2); + exp.assign(3); - auto x = NDArrayFactory::create_('c', { 5, 3 }); - auto y = NDArrayFactory::create_('c', { 1, 1 }); - auto exp = NDArrayFactory::create('c', { 5, 3 }); - x->assign(6); - y->assign(2); - exp.assign(3); + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + variableSpace->putVariable(-2, y); - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, x); - variableSpace->putVariable(-2, y); - auto block = new Context(1, variableSpace, true); - block->fillInputs({ -1, -2 }); + auto block = new Context(1, variableSpace, false); + block->fillInputs({-1, -2}); - sd::ops::divide div; + sd::ops::divide div; - div.execute(block); + div.execute(block); - ASSERT_TRUE(x->equalsTo(&exp)); + ASSERT_TRUE(block->getVariableSpace()->hasVariable(1)); + auto z = block->getVariableSpace()->getVariable(1)->getNDArray().get(); + ASSERT_TRUE(exp.equalsTo(z)); - delete block; - delete variableSpace; + delete block; + delete variableSpace; } - ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, DivideScalarScalar1) { + auto x = NDArrayFactory::create('c', {5, 1}); + auto y = NDArrayFactory::create('c', {5, 1}); + auto exp = NDArrayFactory::create('c', {5, 1}); + x.assign(6); + y.assign(2); + exp.assign(3); - auto x = NDArrayFactory::create_('c', { 5, 1 }); - auto y = NDArrayFactory::create_('c', { 5, 1 }); - auto exp = NDArrayFactory::create('c', { 5, 1 }); - x->assign(6); - y->assign(2); - exp.assign(3); + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + variableSpace->putVariable(-2, y); - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, x); - variableSpace->putVariable(-2, y); - auto block = new Context(1, variableSpace, true); - block->fillInputs({ -1, -2 }); + variableSpace->putVariable(1, x); + auto block = new Context(1, variableSpace, false); + block->fillInputs({-1, -2}); - sd::ops::divide div; + sd::ops::divide div; - div.execute(block); + div.execute(block); - ASSERT_TRUE(x->equalsTo(&exp)); + ASSERT_TRUE(x.equalsTo(&exp)); - delete variableSpace; - delete block; + delete variableSpace; + delete block; } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, ReverseDivideMatrices1) { + auto x = NDArrayFactory::create('c', {5, 3}); + auto y = NDArrayFactory::create('c', {5, 3}); + auto exp = NDArrayFactory::create('c', {5, 3}); + x.assign(2); + y.assign(6); + exp.assign(3); - auto x = NDArrayFactory::create_('c', { 5, 3 }); - auto y = NDArrayFactory::create_('c', { 5, 3 }); - auto exp = NDArrayFactory::create('c', { 5, 3 }); - x->assign(2); - y->assign(6); - exp.assign(3); + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + variableSpace->putVariable(-2, y); - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, x); - variableSpace->putVariable(-2, y); - auto block = new Context(1, variableSpace, true); - block->fillInputs({ -1, -2 }); + auto block = new Context(1, variableSpace, false); + block->fillInputs({-1, -2}); - sd::ops::reversedivide div; + sd::ops::reversedivide div; - div.execute(block); + div.execute(block); - ASSERT_TRUE(x->equalsTo(&exp)); + ASSERT_TRUE(block->getVariableSpace()->hasVariable(1)); + auto z = block->getVariableSpace()->getVariable(1)->getNDArray().get(); + ASSERT_TRUE(exp.equalsTo(z)); - delete variableSpace; - delete block; + delete variableSpace; + delete block; } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, ReverseDivideVectorVector1) { + auto x = NDArrayFactory::create('c', {1, 15}); + auto y = NDArrayFactory::create('c', {1, 15}); + auto exp = NDArrayFactory::create('c', {1, 15}); + x.assign(2); + y.assign(6); + exp.assign(3); - auto x = NDArrayFactory::create_('c', { 1, 15 }); - auto y = NDArrayFactory::create_('c', { 1, 15 }); - auto exp = NDArrayFactory::create('c', { 1, 15 }); - x->assign(2); - y->assign(6); - exp.assign(3); + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + variableSpace->putVariable(-2, y); - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, x); - variableSpace->putVariable(-2, y); - auto block = new Context(1, variableSpace, true); - block->fillInputs({ -1, -2 }); + auto block = new Context(1, variableSpace, false); + block->fillInputs({-1, -2}); - sd::ops::reversedivide div; + sd::ops::reversedivide div; - div.execute(block); + div.execute(block); - ASSERT_TRUE(x->equalsTo(&exp)); + ASSERT_TRUE(block->getVariableSpace()->hasVariable(1)); + auto z = block->getVariableSpace()->getVariable(1)->getNDArray().get(); + ASSERT_TRUE(exp.equalsTo(z)); - delete variableSpace; - delete block; + delete variableSpace; + delete block; } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, ReverseDivideMatrixScalar1) { + auto x = NDArrayFactory::create('c', {5, 3}); + auto y = NDArrayFactory::create('c', {1, 1}); + auto exp = NDArrayFactory::create('c', {5, 3}); + x.assign(2); + y.assign(6); + exp.assign(3); - auto x = NDArrayFactory::create_('c', { 5, 3 }); - auto y = NDArrayFactory::create_('c', { 1, 1 }); - auto exp = NDArrayFactory::create('c', { 5, 3 }); - x->assign(2); - y->assign(6); - exp.assign(3); + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + variableSpace->putVariable(-2, y); - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, x); - variableSpace->putVariable(-2, y); - auto block = new Context(1, variableSpace, true); - block->fillInputs({ -1, -2 }); + auto block = new Context(1, variableSpace, false); + block->fillInputs({-1, -2}); - sd::ops::reversedivide div; + sd::ops::reversedivide div; - div.execute(block); + div.execute(block); - ASSERT_TRUE(x->equalsTo(&exp)); + ASSERT_TRUE(block->getVariableSpace()->hasVariable(1)); + auto z = block->getVariableSpace()->getVariable(1)->getNDArray().get(); + ASSERT_TRUE(exp.equalsTo(z)); - delete variableSpace; - delete block; + delete variableSpace; + delete block; } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, ReverseDivideScalarScalar1) { + auto x = NDArrayFactory::create('c', {1, 1}); + auto y = NDArrayFactory::create('c', {1, 1}); + auto exp = NDArrayFactory::create('c', {1, 1}); + x.assign(2); + y.assign(6); + exp.assign(3); - auto x = NDArrayFactory::create_('c', { 1, 1 }); - auto y = NDArrayFactory::create_('c', { 1, 1 }); - auto exp = NDArrayFactory::create('c', { 1, 1 }); - x->assign(2); - y->assign(6); - exp.assign(3); + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + variableSpace->putVariable(-2, y); - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, x); - variableSpace->putVariable(-2, y); - auto block = new Context(1, variableSpace, true); - block->fillInputs({ -1, -2 }); + variableSpace->putVariable(1, x); + auto block = new Context(1, variableSpace, false); + block->fillInputs({-1, -2}); - sd::ops::reversedivide div; + sd::ops::reversedivide div; - div.execute(block); + div.execute(block); - ASSERT_TRUE(x->equalsTo(&exp)); + ASSERT_TRUE(x.equalsTo(&exp)); - delete variableSpace; - delete block; + delete variableSpace; + delete block; } TEST_F(DeclarableOpsTests1, Test_Cast_1) { - // TODO: right now there's no real cast implementation, but genera idea should be the same: arrays equality to be expected - auto x = NDArrayFactory::create('c', { 5, 5 }); - auto yExp = NDArrayFactory::create('c', { 5, 5 }); - x.linspace(1); - yExp.linspace(1); - sd::ops::cast op; - - auto result = op.evaluate({ &x }, {}, { 3 }); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - ASSERT_TRUE(yExp.equalsTo(z)); + // TODO: right now there's no real cast implementation, but genera idea should + // be the same: arrays equality to be expected + auto x = NDArrayFactory::create('c', {5, 5}); + auto yExp = NDArrayFactory::create('c', {5, 5}); + x.linspace(1); + yExp.linspace(1); + sd::ops::cast op; + auto result = op.evaluate({&x}, {}, {3}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(yExp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, TestRegistrator1) { - auto res = sd::ops::OpRegistrator::getInstance().getAllCustomOperations(); + auto res = sd::ops::OpRegistrator::getInstance().getAllCustomOperations(); } // ////////////////////////////////////////////////////////////////////// @@ -1713,7 +1757,8 @@ TEST_F(DeclarableOpsTests1, TestRegistrator1) { // z->assign(120.0f); // std::string opName("add"); -// auto hash = sd::ops::HashHelper::getInstance().getInstance()->getLongHash(opName); +// auto hash = +// sd::ops::HashHelper::getInstance().getInstance().getLongHash(opName); // auto inputBuffers = new Nd4jPointer[2]; // auto inputShapes = new Nd4jPointer[2]; @@ -1730,13 +1775,14 @@ TEST_F(DeclarableOpsTests1, TestRegistrator1) { // outputBuffers[0] = (Nd4jPointer) z->buffer(); // outputShapes[0] = (Nd4jPointer) z->shapeInfo(); - -// //auto status = execCustomOp(nullptr, hash, inputBuffers, inputShapes, 2, outputBuffers, outputShapes, 1, nullptr, 0, nullptr, 0, false); -// auto status = execCustomOp(nullptr, hash, inputBuffers, inputShapes, 2, outputBuffers, outputShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false); -// ASSERT_EQ(ND4J_STATUS_OK, status); -// ASSERT_NEAR(2.0f, y->meanNumber().e(0), 1e-5); -// ASSERT_NEAR(1.0f, x->meanNumber().e(0), 1e-5); -// ASSERT_NEAR(3.0f, z->meanNumber().e(0), 1e-5); +// //auto status = execCustomOp(nullptr, hash, inputBuffers, inputShapes, 2, +// outputBuffers, outputShapes, 1, nullptr, 0, nullptr, 0, false); auto +// status = execCustomOp(nullptr, hash, inputBuffers, inputShapes, 2, +// outputBuffers, outputShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, +// false); ASSERT_EQ(ND4J_STATUS_OK, status); ASSERT_NEAR(2.0f, +// y->meanNumber().e(0), 1e-5); ASSERT_NEAR(1.0f, +// x->meanNumber().e(0), 1e-5); ASSERT_NEAR(3.0f, +// z->meanNumber().e(0), 1e-5); // delete x; // delete y; @@ -1763,7 +1809,8 @@ TEST_F(DeclarableOpsTests1, TestRegistrator1) { // std::string opName("add"); -// auto hash = sd::ops::HashHelper::getInstance().getInstance()->getLongHash(opName); +// auto hash = +// sd::ops::HashHelper::getInstance().getInstance().getLongHash(opName); // auto inputBuffers = new Nd4jPointer[2]; // auto inputShapes = new Nd4jPointer[2]; @@ -1777,12 +1824,12 @@ TEST_F(DeclarableOpsTests1, TestRegistrator1) { // auto outputBuffers = new Nd4jPointer[1]; // auto outputShapes = new Nd4jPointer[1]; -// execCustomOp(nullptr, hash, inputBuffers, inputShapes, 2, outputBuffers, outputShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, true); +// execCustomOp(nullptr, hash, inputBuffers, inputShapes, 2, outputBuffers, +// outputShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, true); // ASSERT_NEAR(2.0, y->meanNumber().e(0), 1e-5); // ASSERT_NEAR(3.0, x->meanNumber().e(0), 1e-5); - // delete x; // delete y; // delete z; @@ -1796,220 +1843,212 @@ TEST_F(DeclarableOpsTests1, TestRegistrator1) { #ifndef __CUDABLAS__ ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, TestGemv1) { - /* - auto xBuffer = new float[15]{1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f}; - auto xShape = new Nd4jLong[8] {2, 5, 3, 3, 1, 0, 1, 99}; - ArrayOptions::setDataType(xShape, sd::DataType::FLOAT32); - auto x = new NDArray(xBuffer, xShape); + /* + auto xBuffer = new + float[15]{1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, + 15.f}; auto xShape = new Nd4jLong[8] {2, 5, 3, 3, 1, 0, 1, 99}; + ArrayOptions::setDataType(xShape, sd::DataType::FLOAT32); + auto x = new NDArray(xBuffer, xShape); - auto yBuffer = new float[3]{2.f, 4.f, 6.f}; - auto yShape = new Nd4jLong[8] {2, 3, 1, 1, 1, 0, 1, 99}; - ArrayOptions::setDataType(yShape, sd::DataType::FLOAT32); + auto yBuffer = new float[3]{2.f, 4.f, 6.f}; + auto yShape = new Nd4jLong[8] {2, 3, 1, 1, 1, 0, 1, 99}; + ArrayOptions::setDataType(yShape, sd::DataType::FLOAT32); - auto y = new NDArray(yBuffer, yShape); + auto y = new NDArray(yBuffer, yShape); - auto z = NDArrayFactory::create_('f', {5, 1}); + auto z = NDArrayFactory::create_('f', {5, 1}); - auto expBuffer = new float[5]{28.00f,64.00f,100.00f,136.00f,172.00f}; - auto exp = new NDArray(expBuffer, z->shapeInfo()); + auto expBuffer = new float[5]{28.00f,64.00f,100.00f,136.00f,172.00f}; + auto exp = new NDArray(expBuffer, z->shapeInfo()); - sd::blas::GEMV::op('f', x->rows(), x->columns(), 1.0f, x->buffer(), y->rows(), y->buffer(), 1, 0.0, z->buffer(), 1); + sd::blas::GEMV::op('f', x->rows(), x->columns(), 1.0f, + x->buffer(), y->rows(), y->buffer(), 1, 0.0, z->buffer(), 1); - ASSERT_TRUE(z->equalsTo(exp)); + ASSERT_TRUE(z->equalsTo(exp)); - delete []xBuffer; delete []xShape; delete x; delete []yBuffer; delete []yShape; delete y; delete z; delete []expBuffer; delete exp; - */ + delete []xBuffer; delete []xShape; delete x; delete []yBuffer; delete + []yShape; delete y; delete z; delete []expBuffer; delete exp; + */ } #endif - ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, Transpose1) { - auto x = NDArrayFactory::create_('c', { 3,5,2 }); - auto exp = NDArrayFactory::create_('c', { 2,5,3 }); + auto x = NDArrayFactory::create('c', {3, 5, 2}); + auto exp = NDArrayFactory::create('c', {2, 5, 3}); - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, x); - variableSpace->putVariable(1, new Variable()); + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); - auto block = new Context(1, variableSpace, false); // not-in-place - block->fillInputs({ -1 }); - sd::ops::transpose transpose; + auto block = new Context(1, variableSpace, false); // not-in-place + block->fillInputs({-1}); + sd::ops::transpose transpose; - Nd4jStatus status = transpose.execute(block); - ASSERT_EQ(ND4J_STATUS_OK, status); + Nd4jStatus status = transpose.execute(block); + ASSERT_EQ(ND4J_STATUS_OK, status); - auto result = variableSpace->getVariable(block->getNodeId())->getNDArray(); + auto result = variableSpace->getVariable(block->getNodeId())->getNDArray(); - ASSERT_TRUE(exp->isSameShape(result)); - ASSERT_TRUE(exp->dataType() == result->dataType()); - ASSERT_TRUE(exp->ordering() == result->ordering()); + ASSERT_TRUE(exp.isSameShape(*result)); + ASSERT_TRUE(exp.dataType() == result->dataType()); + ASSERT_TRUE(exp.ordering() == result->ordering()); - delete exp; - delete block; - delete variableSpace; + delete block; + delete variableSpace; } ////////////////////////////////////////////////////////////////////// // not-in-place TEST_F(DeclarableOpsTests1, Permute1) { + Nd4jLong shapeX[] = {3, 5, 10, 15, 150, 15, 1, 0, 1, 99}; + Nd4jLong shapeExp[] = {3, 15, 5, 10, 50, 10, 1, 0, 1, 99}; + const std::vector perm = {2, 0, 1}; - Nd4jLong shapeX[] = { 3, 5,10,15, 150,15,1, 0,1,99 }; - Nd4jLong shapeExp[] = { 3, 15,5,10, 50,10,1, 0,1,99 }; - const std::vector perm = { 2, 0, 1 }; + ArrayOptions::setDataType(shapeX, sd::DataType::FLOAT32); + ArrayOptions::setDataType(shapeExp, sd::DataType::FLOAT32); - ArrayOptions::setDataType(shapeX, sd::DataType::FLOAT32); - ArrayOptions::setDataType(shapeExp, sd::DataType::FLOAT32); + NDArray x(shapeX, true); + NDArray exp(shapeExp, true); - auto x = new NDArray(shapeX, true); - auto exp = new NDArray(shapeExp, true); + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, x); - variableSpace->putVariable(1, new Variable()); + auto block = new Context(1, variableSpace, false); // not-in-place + block->fillInputs({-1}); - auto block = new Context(1, variableSpace, false); // not-in-place - block->fillInputs({ -1 }); - auto arguments = block->getIArguments(); - *arguments = perm; // set dimensions to be permuted + block->appendI(perm); // set dimensions to be permuted - sd::ops::permute permute; - Nd4jStatus status = permute.execute(block); - auto result = variableSpace->getVariable(block->getNodeId())->getNDArray(); + sd::ops::permute permute; + Nd4jStatus status = permute.execute(block); + auto result = variableSpace->getVariable(block->getNodeId())->getNDArray(); - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(result->isSameShapeStrict(*exp)); + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(result->isSameShapeStrict(exp)); - delete block; - delete variableSpace; - delete exp; + delete block; + delete variableSpace; } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, TestArgumentsValidation1) { - Nd4jLong shapeX[] = { 3, 5, 10, 15, 150, 15, 1, 0, 1, 99 }; - Nd4jLong shapeExp[] = { 3, 15, 5, 10, 1, 150, 15, 0, -1, 99 }; + Nd4jLong shapeX[] = {3, 5, 10, 15, 150, 15, 1, 0, 1, 99}; + Nd4jLong shapeExp[] = {3, 15, 5, 10, 1, 150, 15, 0, -1, 99}; - ArrayOptions::setDataType(shapeX, sd::DataType::FLOAT32); - ArrayOptions::setDataType(shapeExp, sd::DataType::FLOAT32); + ArrayOptions::setDataType(shapeX, sd::DataType::FLOAT32); + ArrayOptions::setDataType(shapeExp, sd::DataType::FLOAT32); - const std::vector perm = { 2, 0, 1 }; - auto x = new NDArray(shapeX); - auto exp = new NDArray(shapeExp); + const std::vector perm = {2, 0, 1}; + NDArray x(shapeX); + NDArray exp(shapeExp); - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, x); - variableSpace->putVariable(1, new Variable()); + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); - auto block = new Context(1, variableSpace, false); // not-in-place - block->fillInputs({ -1 }); + auto block = new Context(1, variableSpace, false); // not-in-place + block->fillInputs({-1}); - sd::ops::im2col permute; - Nd4jStatus status = permute.execute(block); + sd::ops::im2col permute; + Nd4jStatus status = permute.execute(block); - ASSERT_TRUE(status != 0); + ASSERT_TRUE(status != 0); - delete exp; - delete block; - delete variableSpace; + delete block; + delete variableSpace; } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, TestReductionShape1) { - auto input = NDArrayFactory::create_('c', { 4, 5, 5, 10, 10 }); + auto input = NDArrayFactory::create('c', {4, 5, 5, 10, 10}); - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, input); + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, input); - auto block = new Context(1, variableSpace, false); // not-in-place - block->fillInputs({ -1 }); + auto block = new Context(1, variableSpace, false); // not-in-place + block->fillInputs({-1}); - // kernel params - block->getIArguments()->push_back(MAX_INT); + // kernel params + block->appendI(MAX_INT); - sd::ops::testreduction testop; + sd::ops::testreduction testop; - auto inP = new Nd4jLong[shape::shapeInfoLength(input->shapeInfo())]; - memcpy(inP, input->shapeInfo(), shape::shapeInfoByteLength(input->rankOf())); + auto inP = new Nd4jLong[shape::shapeInfoLength(input.shapeInfo())]; + memcpy(inP, input.shapeInfo(), shape::shapeInfoByteLength(input.rankOf())); - auto inshape = new ShapeList(inP); + auto inshape = new ShapeList(inP); - auto shapes = testop.calculateOutputShape(inshape, *block); + auto shapes = testop.calculateOutputShape(inshape, *block); - ASSERT_EQ(1, shapes->size()); - ASSERT_EQ(0, shapes->at(0)[0]); // scalar shape has rank 0 - ASSERT_EQ(8192, shapes->at(0)[1]); - ASSERT_EQ(1, shapes->at(0)[2]); - - delete[] inP; - delete variableSpace; - delete block; - delete inshape; - delete shapes; + ASSERT_EQ(1, shapes->size()); + ASSERT_EQ(0, shapes->at(0)[0]); // scalar shape has rank 0 + ASSERT_EQ(8192, shapes->at(0)[1]); + ASSERT_EQ(1, shapes->at(0)[2]); + delete[] inP; + delete variableSpace; + delete block; + delete inshape; + delete shapes; } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, TestReductionShape2) { - auto input = NDArrayFactory::create_('c', { 4, 5, 5, 10, 10 }); + auto input = NDArrayFactory::create('c', {4, 5, 5, 10, 10}); - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, input); + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, input); - auto block = new Context(1, variableSpace, false); // not-in-place - block->fillInputs({ -1 }); + auto block = new Context(1, variableSpace, false); // not-in-place + block->fillInputs({-1}); - // kernel params - //block->getIArguments()->push_back(4); - block->getIArguments()->push_back(1); - block->getIArguments()->push_back(2); - block->getIArguments()->push_back(3); - block->getIArguments()->push_back(4); + // kernel params + // block->getIArguments()->push_back(4); + block->appendI(1); + block->appendI(2); + block->appendI(3); + block->appendI(4); - sd::ops::testreduction testop; + sd::ops::testreduction testop; - auto inshapes = new ShapeList(input->shapeInfo()); - auto shapes = testop.calculateOutputShape(inshapes, *block); - ASSERT_EQ(1, shapes->size()); - ASSERT_EQ(1, shapes->at(0)[0]); - ASSERT_EQ(4, shapes->at(0)[1]); - ASSERT_EQ(1, shapes->at(0)[2]); + auto inshapes = new ShapeList(input.shapeInfo()); + auto shapes = testop.calculateOutputShape(inshapes, *block); + ASSERT_EQ(1, shapes->size()); + ASSERT_EQ(1, shapes->at(0)[0]); + ASSERT_EQ(4, shapes->at(0)[1]); + ASSERT_EQ(1, shapes->at(0)[2]); - delete variableSpace; - delete block; - delete shapes; - delete inshapes; + delete variableSpace; + delete block; + delete shapes; + delete inshapes; } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, TestCustomShape1) { - auto input = NDArrayFactory::create_('c', { 2, 3, 4 }); - - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, input); + auto input = NDArrayFactory::create('c', {2, 3, 4}); - auto block = new Context(1, variableSpace, false); // not-in-place - block->fillInputs({ -1 }); + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, input); - sd::ops::testcustom test; + auto block = new Context(1, variableSpace, false); // not-in-place + block->fillInputs({-1}); - auto inshapes = new ShapeList(input->shapeInfo()); - auto shapes = test.calculateOutputShape(inshapes, *block); + sd::ops::testcustom test; + auto inshapes = new ShapeList(input.shapeInfo()); + auto shapes = test.calculateOutputShape(inshapes, *block); - ASSERT_EQ(input->shapeInfo()[0], shapes->at(0)[0]); - ASSERT_EQ(input->shapeInfo()[1] * 2, shapes->at(0)[1]); - ASSERT_EQ(input->shapeInfo()[2] * 2, shapes->at(0)[2]); - ASSERT_EQ(input->shapeInfo()[3] * 2, shapes->at(0)[3]); + ASSERT_EQ(input.shapeInfo()[0], shapes->at(0)[0]); + ASSERT_EQ(input.shapeInfo()[1] * 2, shapes->at(0)[1]); + ASSERT_EQ(input.shapeInfo()[2] * 2, shapes->at(0)[2]); + ASSERT_EQ(input.shapeInfo()[3] * 2, shapes->at(0)[3]); - delete variableSpace; - delete block; - delete shapes; - delete inshapes; + delete variableSpace; + delete block; + delete shapes; + delete inshapes; } - ////////////////////////////////////////////////////////////////////// /* TEST_F(DeclarableOpsTests1, Sum1) { @@ -2049,29 +2088,29 @@ TEST_F(DeclarableOpsTests1, Sum1) { ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, Pnormpool2d1) { + auto x = NDArrayFactory::create('c', {bS, iD, iH, iW}); + auto exp = NDArrayFactory::create('c', {bS, iD, oH, oW}); + // auto z('c',{bS,iD,oH,oW}); - auto x = NDArrayFactory::create_('c', { bS,iD,iH,iW }); - auto exp = NDArrayFactory::create('c', { bS,iD,oH,oW }); - // auto z('c',{bS,iD,oH,oW}); + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, x); - // variableSpace->putVariable(1, &z); + auto block = new Context(1, variableSpace, false); + block->fillInputs({-1}); + block->appendI({kH, kW, sH, sW, pH, pW, dW, dH, 0, 1, + 0}); // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; + // 4,5 - pad Height/Width; 6,7 - dilation Height/Width; + // 8 - same mode; 9 - extraParam0 for pnorm case; - auto block = new Context(1, variableSpace, false); - block->fillInputs({ -1 }); - std::vector* argI = block->getIArguments(); - *argI = { kH,kW, sH,sW, pH,pW, dW,dH, 0, 1, 0 }; // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; 9 - extraParam0 for pnorm case; + sd::ops::pnormpool2d pooling; + Nd4jStatus status = pooling.execute(block); + ASSERT_EQ(ND4J_STATUS_OK, status); - sd::ops::pnormpool2d pooling; - Nd4jStatus status = pooling.execute(block); - ASSERT_EQ(ND4J_STATUS_OK, status); - - auto result = variableSpace->getVariable(block->getNodeId())->getNDArray(); - ASSERT_TRUE(exp.isSameShape(result)); + auto result = variableSpace->getVariable(block->getNodeId())->getNDArray(); + ASSERT_TRUE(exp.isSameShape(*result)); - delete variableSpace; - delete block; + delete variableSpace; + delete block; } /*///////////////////////////////////////////////////////////////////// @@ -2092,7 +2131,8 @@ TEST_F(DeclarableOpsTests1, IsMax1) { block->fillInputs({-1}); std::vector* argI = block->getIArguments(); // *argI = {1}; // dimensions - argI->push_back(1); // = {1}; // dimensions + argI->push_back(1); // = {1}; // +dimensions sd::ops::ismax ismaxOp; Nd4jStatus status = ismaxOp.execute(block); @@ -2109,81 +2149,78 @@ TEST_F(DeclarableOpsTests1, IsMax1) { ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, IsMax1) { - NDArray x('c', { 3, 3 }, sd::DataType::FLOAT32); - // NDArray exp('c', {3, 3}, sd::DataType::BOOL); - NDArray exp('c', { 3, 3 }, sd::DataType::FLOAT32); - x.linspace(1); - exp.p(0, 2, true); - exp.p(1, 2, true); - exp.p(2, 2, true); - - sd::ops::ismax ismaxOp; - auto result = ismaxOp.evaluate({ &x }, {}, { 1 }); + NDArray x('c', {3, 3}, sd::DataType::FLOAT32); + // NDArray exp('c', {3, 3}, sd::DataType::BOOL); + NDArray exp('c', {3, 3}, sd::DataType::FLOAT32); + x.linspace(1); + exp.p(0, 2, true); + exp.p(1, 2, true); + exp.p(2, 2, true); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto res = result.at(0); - //res->printIndexedBuffer("IS_MAX"); - ASSERT_TRUE(exp.equalsTo(res)); + sd::ops::ismax ismaxOp; + auto result = ismaxOp.evaluate({&x}, {1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto res = result.at(0); + res.printIndexedBuffer("IS_MAX"); + ASSERT_TRUE(exp.equalsTo(res)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, IsMax2) { - NDArray x('c', { 3, 3 }, sd::DataType::FLOAT32); - // NDArray exp('c', {3, 3}, sd::DataType::BOOL); - NDArray exp('c', { 3, 3 }, sd::DataType::FLOAT32); - x.linspace(1); - //exp.p(0, 2, true); - //exp.p(1, 2, true); - exp.p(2, 2, true); - - sd::ops::ismax ismaxOp; - auto result = ismaxOp.evaluate({ &x }, {}, { 0, 1 }); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + NDArray x('c', {3, 3}, sd::DataType::FLOAT32); + // NDArray exp('c', {3, 3}, sd::DataType::BOOL); + NDArray exp('c', {3, 3}, sd::DataType::FLOAT32); + x.linspace(1); + // exp.p(0, 2, true); + // exp.p(1, 2, true); + exp.p(2, 2, true); - auto res = result.at(0); - //res->printIndexedBuffer("IS_MAX"); - ASSERT_TRUE(exp.equalsTo(res)); + sd::ops::ismax ismaxOp; + auto result = ismaxOp.evaluate({&x}, {}, {0, 1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto res = result.at(0); + // res->printIndexedBuffer("IS_MAX"); + ASSERT_TRUE(exp.equalsTo(res)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, IsMax3) { - NDArray x = NDArrayFactory::create(120.f); //('c', {3, 3}, sd::DataType::FLOAT32); -// NDArray exp('c', {3, 3}, sd::DataType::BOOL); - NDArray exp = NDArrayFactory::create(1.f);//, sd::DataType::FLOAT32); //'c', {3, 3}, sd::DataType::FLOAT32); - x.linspace(1); - //exp.p(0, 2, true); - //exp.p(1, 2, true); - //exp.p(2, 2, true); + NDArray x = NDArrayFactory::create( + 120.f); //('c', {3, 3}, sd::DataType::FLOAT32); + // NDArray exp('c', {3, 3}, sd::DataType::BOOL); + NDArray exp = NDArrayFactory::create( + 1.f); //, sd::DataType::FLOAT32); //'c', {3, 3}, sd::DataType::FLOAT32); + x.linspace(1); + // exp.p(0, 2, true); + // exp.p(1, 2, true); + // exp.p(2, 2, true); - sd::ops::ismax ismaxOp; - auto result = ismaxOp.evaluate({ &x }, {}, { 0 }); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto res = result.at(0); - //res->printIndexedBuffer("IS_MAX"); - ASSERT_TRUE(exp.equalsTo(res)); + sd::ops::ismax ismaxOp; + auto result = ismaxOp.evaluate({&x}, {}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto res = result.at(0); + // res->printIndexedBuffer("IS_MAX"); + ASSERT_TRUE(exp.equalsTo(res)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, IsMax4) { - auto x = NDArrayFactory::create('c', { 6 }, { 0, 0, 0, 2, 2, 0 }); - auto z = NDArrayFactory::create('c', { 6 }); - auto e = NDArrayFactory::create('c', { 6 }, { false, false, false, true, false, false }); + auto x = NDArrayFactory::create('c', {6}, {0, 0, 0, 2, 2, 0}); + auto z = NDArrayFactory::create('c', {6}); + auto e = NDArrayFactory::create( + 'c', {6}, {false, false, false, true, false, false}); - sd::ops::ismax op; - auto result = op.execute({ &x }, { &z }); - ASSERT_EQ(Status::OK(), result); + sd::ops::ismax op; + auto result = op.execute({&x}, {&z}); + ASSERT_EQ(Status::OK(), result); - ASSERT_EQ(e, z); + ASSERT_EQ(e, z); } //////////////////////////////////////////////////////////////////// @@ -2198,8 +2235,15 @@ TEST_F(DeclarableOpsTests1, IsMax4) { // NDArray bias('c', {1,2*K}, sd::DataType::DOUBLE); // NDArray init('c', {bS,K}, sd::DataType::DOUBLE); // NDArray mask('c', {bS,K}, sd::DataType::DOUBLE); -// NDArray expState('c', {bS,K,N}, {0.847983, 0.874549, 0.896109, 0.913715, 0.847983, 0.874549, 0.896109, 0.913715, 0.847983, 0.874549, 0.896109, 0.913715, 0.847983, 0.874549, 0.896109, 0.913715, 0.847983, 0.874549, 0.896109, 0.913715, 0.847983, 0.874549, 0.896109, 0.913715}, sd::DataType::DOUBLE); -// NDArray expOut('c', {bS,K,N}, {1.090533, 1.174509, 1.252403, 1.324656, 1.090533, 1.174509, 1.252403, 1.324656, 1.090533, 1.174509, 1.252403, 1.324656, 1.090533, 1.174509, 1.252403, 1.324656, 1.090533, 1.174509, 1.252403, 1.324656, 1.090533, 1.174509, 1.252403, 1.324656}, sd::DataType::DOUBLE); +// NDArray expState('c', {bS,K,N}, {0.847983, 0.874549, 0.896109, 0.913715, +// 0.847983, 0.874549, 0.896109, 0.913715, 0.847983, 0.874549, 0.896109, +// 0.913715, 0.847983, 0.874549, 0.896109, 0.913715, 0.847983, 0.874549, +// 0.896109, 0.913715, 0.847983, 0.874549, 0.896109, 0.913715}, +// sd::DataType::DOUBLE); NDArray expOut('c', {bS,K,N}, +// {1.090533, 1.174509, 1.252403, 1.324656, 1.090533, 1.174509, 1.252403, 1.324656, +// 1.090533, 1.174509, 1.252403, 1.324656, 1.090533, 1.174509, 1.252403, 1.324656, +// 1.090533, 1.174509, 1.252403, 1.324656, 1.090533, 1.174509, 1.252403, 1.324656}, +// sd::DataType::DOUBLE); // input.assign(1.5); // weights.assign(0.5); @@ -2208,8 +2252,8 @@ TEST_F(DeclarableOpsTests1, IsMax4) { // mask.assign(1.); // sd::ops::sru_old op; -// auto results = op.execute({&input, &weights, &bias, &init, &mask}, {}, {}); -// ASSERT_TRUE(results.size() == 2); +// auto results = op.execute({&input, &weights, &bias, &init, &mask}, {}, +// {}); ASSERT_TRUE(results.size() == 2); // auto state = results.at(0); // auto output = results.at(1); @@ -2224,1145 +2268,1273 @@ TEST_F(DeclarableOpsTests1, IsMax4) { ////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, sru_test1) { - - const int bS = 2; - const int K = 3; - const int N = 4; - - NDArray input('c', { bS,K,N }, sd::DataType::DOUBLE); - NDArray weights('c', { 3 * K,K }, sd::DataType::DOUBLE); - NDArray bias('c', { 2 * K }, sd::DataType::DOUBLE); - NDArray init('c', { bS,K }, sd::DataType::DOUBLE); - NDArray mask('c', { bS,K }, sd::DataType::DOUBLE); - NDArray expState('c', { bS,K,N }, { 1.090533, 1.174509, 1.252403, 1.324656, 1.090533, 1.174509, 1.252403, 1.324656, 1.090533, 1.174509, 1.252403, 1.324656, 1.090533, 1.174509, 1.252403, 1.324656, 1.090533, 1.174509, 1.252403, 1.324656, 1.090533, 1.174509, 1.252403, 1.324656 }, sd::DataType::DOUBLE); - NDArray expOut('c', { bS,K,N }, { 0.847983, 0.874549, 0.896109, 0.913715, 0.847983, 0.874549, 0.896109, 0.913715, 0.847983, 0.874549, 0.896109, 0.913715, 0.847983, 0.874549, 0.896109, 0.913715, 0.847983, 0.874549, 0.896109, 0.913715, 0.847983, 0.874549, 0.896109, 0.913715 }, sd::DataType::DOUBLE); - - input.assign(1.5); - weights.assign(0.5); - bias.assign(0.3); - init.assign(1.); - mask.assign(1.); - - sd::ops::sru op; - auto results = op.evaluate({ &input, &weights, &bias, &init, &mask }); - ASSERT_TRUE(results.size() == 2); - - auto output = results.at(0); - auto state = results.at(1); - - ASSERT_TRUE(expState.equalsTo(state)); - ASSERT_TRUE(expOut.equalsTo(output)); - - + const int bS = 2; + const int K = 3; + const int N = 4; + + NDArray input('c', {bS, K, N}, sd::DataType::DOUBLE); + NDArray weights('c', {3 * K, K}, sd::DataType::DOUBLE); + NDArray bias('c', {2 * K}, sd::DataType::DOUBLE); + NDArray init('c', {bS, K}, sd::DataType::DOUBLE); + NDArray mask('c', {bS, K}, sd::DataType::DOUBLE); + NDArray expState('c', {bS, K, N}, + {1.090533, 1.174509, 1.252403, 1.324656, 1.090533, 1.174509, + 1.252403, 1.324656, 1.090533, 1.174509, 1.252403, 1.324656, + 1.090533, 1.174509, 1.252403, 1.324656, 1.090533, 1.174509, + 1.252403, 1.324656, 1.090533, 1.174509, 1.252403, 1.324656}, + sd::DataType::DOUBLE); + NDArray expOut('c', {bS, K, N}, + {0.847983, 0.874549, 0.896109, 0.913715, 0.847983, 0.874549, + 0.896109, 0.913715, 0.847983, 0.874549, 0.896109, 0.913715, + 0.847983, 0.874549, 0.896109, 0.913715, 0.847983, 0.874549, + 0.896109, 0.913715, 0.847983, 0.874549, 0.896109, 0.913715}, + sd::DataType::DOUBLE); + + input.assign(1.5); + weights.assign(0.5); + bias.assign(0.3); + init.assign(1.); + mask.assign(1.); + + sd::ops::sru op; + auto results = op.evaluate({&input, &weights, &bias, &init, &mask}); + ASSERT_TRUE(results.size() == 2); + + auto output = results.at(0); + auto state = results.at(1); + + ASSERT_TRUE(expState.equalsTo(state)); + ASSERT_TRUE(expOut.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, sru_bp) { - - const int bS = 2; - const int K = 3; - const int N = 4; - std::vector expGradXBuff = { -0.0259303, -0.03869125, -0.0302272, -0.02299165, -0.0259303, -0.03869125, -0.0302272, -0.02299165, -0.0259303, -0.03869125, -0.0302272, -0.02299165, -0.0259303, -0.03869125, -0.0302272, -0.02299165, -0.0259303, -0.03869125, -0.0302272, -0.02299165, -0.0259303, -0.03869125, -0.0302272, -0.02299165 }; - std::vector expGradWBuff = { 0.42526005,0.42526005,0.42526005, 0.42526005,0.42526005,0.42526005, 0.42526005,0.42526005,0.42526005, -0.5282811 , -0.5282811 , -0.5282811 , -0.5282811 , -0.5282811 , -0.5282811 , -0.5282811 , -0.5282811 , -0.5282811 , -0.15967215, -0.15967215, -0.15967215, -0.15967215, -0.15967215, -0.15967215, -0.15967215, -0.15967215, -0.15967215, 0.42526005,0.42526005,0.42526005, 0.42526005,0.42526005,0.42526005, 0.42526005,0.42526005,0.42526005, -0.5282811 , -0.5282811 , -0.5282811 , -0.5282811 , -0.5282811 , -0.5282811 , -0.5282811 , -0.5282811 , -0.5282811 , -0.15967215, -0.15967215, -0.15967215, -0.15967215, -0.15967215, -0.15967215, -0.15967215, -0.15967215, -0.15967215 }; - std::vector expGradBBuff = { -0.7043748, -0.7043748, -0.7043748, -0.2128962, -0.2128962, -0.2128962 }; - std::vector expGradInitBuff = { 1.1421, 1.1421, 1.1421, 1.1421, 1.1421, 1.1421 }; - std::vector stateBuff = { 0.847983, 0.874549, 0.896109, 0.913715, 0.847983, 0.874549, 0.896109, 0.913715, 0.847983, 0.874549, 0.896109, 0.913715, 0.847983, 0.874549, 0.896109, 0.913715, 0.847983, 0.874549, 0.896109, 0.913715, 0.847983, 0.874549, 0.896109, 0.913715 }; - - auto input = NDArrayFactory::create('c', { bS,K,N }); - auto weights = NDArrayFactory::create('c', { 3 * K,K }); - auto bias = NDArrayFactory::create('c', { 1,2 * K }); - auto init = NDArrayFactory::create('c', { bS,K }); - auto mask = NDArrayFactory::create('c', { bS,K }); - auto state = NDArrayFactory::create('c', { bS,K,N }, stateBuff); - auto inGradCt = NDArrayFactory::create('c', { bS,K }); - auto inGradH = NDArrayFactory::create('c', { bS,K,N }); - - auto expGradX = NDArrayFactory::create('c', { bS,K,N }, expGradXBuff); - auto expGradW = NDArrayFactory::create('c', { bS,3 * K,K }, expGradWBuff); - auto expGradB = NDArrayFactory::create('c', { 1,2 * K }, expGradBBuff); - auto expGradInit = NDArrayFactory::create('c', { bS,K }, expGradInitBuff); - - input.assign(1.5); - weights.assign(0.5); - bias.assign(0.3); - mask.assign(1.); - init.assign(1.); - inGradCt.assign(0.5); - inGradH.assign(0.5); - - sd::ops::sru_bp bp; - auto resultsBP = bp.evaluate({ &input, &weights, &bias, &init, &state, &inGradCt, &inGradH, &mask }, {}, {}); - ASSERT_TRUE(resultsBP.size() == 4); - - auto gradX = resultsBP.at(0); - auto gradW = resultsBP.at(1); - auto gradB = resultsBP.at(2); - auto gradInit = resultsBP.at(3); - // expGradX.printBuffer("Exp GRAD"); - // gradX->printBuffer("Res GRAD"); - ASSERT_TRUE(expGradX.equalsTo(gradX, 1e-4)); - ASSERT_TRUE(expGradW.equalsTo(gradW)); - ASSERT_TRUE(expGradB.equalsTo(gradB)); - ASSERT_TRUE(expGradInit.equalsTo(gradInit)); - - + const int bS = 2; + const int K = 3; + const int N = 4; + std::vector expGradXBuff = { + -0.0259303, -0.03869125, -0.0302272, -0.02299165, -0.0259303, + -0.03869125, -0.0302272, -0.02299165, -0.0259303, -0.03869125, + -0.0302272, -0.02299165, -0.0259303, -0.03869125, -0.0302272, + -0.02299165, -0.0259303, -0.03869125, -0.0302272, -0.02299165, + -0.0259303, -0.03869125, -0.0302272, -0.02299165}; + std::vector expGradWBuff = { + 0.42526005, 0.42526005, 0.42526005, 0.42526005, 0.42526005, + 0.42526005, 0.42526005, 0.42526005, 0.42526005, -0.5282811, + -0.5282811, -0.5282811, -0.5282811, -0.5282811, -0.5282811, + -0.5282811, -0.5282811, -0.5282811, -0.15967215, -0.15967215, + -0.15967215, -0.15967215, -0.15967215, -0.15967215, -0.15967215, + -0.15967215, -0.15967215, 0.42526005, 0.42526005, 0.42526005, + 0.42526005, 0.42526005, 0.42526005, 0.42526005, 0.42526005, + 0.42526005, -0.5282811, -0.5282811, -0.5282811, -0.5282811, + -0.5282811, -0.5282811, -0.5282811, -0.5282811, -0.5282811, + -0.15967215, -0.15967215, -0.15967215, -0.15967215, -0.15967215, + -0.15967215, -0.15967215, -0.15967215, -0.15967215}; + std::vector expGradBBuff = {-0.7043748, -0.7043748, -0.7043748, + -0.2128962, -0.2128962, -0.2128962}; + std::vector expGradInitBuff = {1.1421, 1.1421, 1.1421, + 1.1421, 1.1421, 1.1421}; + std::vector stateBuff = { + 0.847983, 0.874549, 0.896109, 0.913715, 0.847983, 0.874549, + 0.896109, 0.913715, 0.847983, 0.874549, 0.896109, 0.913715, + 0.847983, 0.874549, 0.896109, 0.913715, 0.847983, 0.874549, + 0.896109, 0.913715, 0.847983, 0.874549, 0.896109, 0.913715}; + + auto input = NDArrayFactory::create('c', {bS, K, N}); + auto weights = NDArrayFactory::create('c', {3 * K, K}); + auto bias = NDArrayFactory::create('c', {1, 2 * K}); + auto init = NDArrayFactory::create('c', {bS, K}); + auto mask = NDArrayFactory::create('c', {bS, K}); + auto state = NDArrayFactory::create('c', {bS, K, N}, stateBuff); + auto inGradCt = NDArrayFactory::create('c', {bS, K}); + auto inGradH = NDArrayFactory::create('c', {bS, K, N}); + + auto expGradX = NDArrayFactory::create('c', {bS, K, N}, expGradXBuff); + auto expGradW = + NDArrayFactory::create('c', {bS, 3 * K, K}, expGradWBuff); + auto expGradB = NDArrayFactory::create('c', {1, 2 * K}, expGradBBuff); + auto expGradInit = + NDArrayFactory::create('c', {bS, K}, expGradInitBuff); + + input.assign(1.5); + weights.assign(0.5); + bias.assign(0.3); + mask.assign(1.); + init.assign(1.); + inGradCt.assign(0.5); + inGradH.assign(0.5); + + sd::ops::sru_bp bp; + auto resultsBP = bp.evaluate( + {&input, &weights, &bias, &init, &state, &inGradCt, &inGradH, &mask}, {}, + {}); + ASSERT_TRUE(resultsBP.size() == 4); + + auto gradX = resultsBP.at(0); + auto gradW = resultsBP.at(1); + auto gradB = resultsBP.at(2); + auto gradInit = resultsBP.at(3); + // expGradX.printBuffer("Exp GRAD"); + // gradX->printBuffer("Res GRAD"); + ASSERT_TRUE(expGradX.equalsTo(gradX, 1e-4)); + ASSERT_TRUE(expGradW.equalsTo(gradW)); + ASSERT_TRUE(expGradB.equalsTo(gradB)); + ASSERT_TRUE(expGradInit.equalsTo(gradInit)); } ////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, sru_bi_1) { - - const int bS = 2; - const int K = 3; - const int N = 4; - - NDArray input('c', { N,bS,2 * K }, sd::DataType::DOUBLE); - NDArray weights('c', { 2 * K,6 * K }, sd::DataType::DOUBLE); - NDArray bias('c', { 4 * K }, sd::DataType::DOUBLE); - NDArray init('c', { bS,2 * K }, sd::DataType::DOUBLE); - NDArray mask('c', { bS,2 * K }, sd::DataType::DOUBLE); - NDArray expState('c', { N,bS,2 * K }, { 1.02857, 1.02857, 1.02857, 1.11288, 1.11288, 1.11288, 1.02857, 1.02857, 1.02857, 1.11288, 1.11288, 1.11288, 1.0569, 1.0569, 1.0569, 1.08501, 1.08501, 1.08501, 1.0569, 1.0569, 1.0569, 1.08501, 1.08501, 1.08501, 1.08501, 1.08501, 1.08501, 1.0569, 1.0569, 1.0569, 1.08501, 1.08501, 1.08501, 1.0569, 1.0569, 1.0569, 1.11288, 1.11288, 1.11288, 1.02857, 1.02857, 1.02857, 1.11288, 1.11288, 1.11288, 1.02857, 1.02857, 1.02857 }); - NDArray expOut('c', { N,bS,2 * K }, { 0.779265, 0.779265, 0.779265, 0.810752, 0.810752, 0.810752, 0.779265, 0.779265, 0.779265, 0.810752, 0.810752, 0.810752, 0.790317, 0.790317, 0.790317, 0.800804, 0.800804, 0.800804, 0.790317, 0.790317, 0.790317, 0.800804, 0.800804, 0.800804, 0.800804, 0.800804, 0.800804, 0.790317, 0.790317, 0.790317, 0.800804, 0.800804, 0.800804, 0.790317, 0.790317, 0.790317, 0.810752, 0.810752, 0.810752, 0.779265, 0.779265, 0.779265, 0.810752, 0.810752, 0.810752, 0.779265, 0.779265, 0.779265 }); - - input.assign(1.5); - weights.assign(0.5); - bias.assign(0.3); - init.assign(1.); - mask.assign(1.); - - sd::ops::sru_bi op; - auto results = op.evaluate({ &input, &weights, &bias, &init, &mask }, {}, {}); - ASSERT_TRUE(results.size() == 2); - - auto output = results.at(0); - auto state = results.at(1); - // state->printBuffer(); - // output->printBuffer(); - - ASSERT_TRUE(expState.equalsTo(state)); - ASSERT_TRUE(expOut.equalsTo(output)); - - + const int bS = 2; + const int K = 3; + const int N = 4; + + NDArray input('c', {N, bS, 2 * K}, sd::DataType::DOUBLE); + NDArray weights('c', {2 * K, 6 * K}, sd::DataType::DOUBLE); + NDArray bias('c', {4 * K}, sd::DataType::DOUBLE); + NDArray init('c', {bS, 2 * K}, sd::DataType::DOUBLE); + NDArray mask('c', {bS, 2 * K}, sd::DataType::DOUBLE); + NDArray expState( + 'c', {N, bS, 2 * K}, + {1.02857, 1.02857, 1.02857, 1.11288, 1.11288, 1.11288, 1.02857, 1.02857, + 1.02857, 1.11288, 1.11288, 1.11288, 1.0569, 1.0569, 1.0569, 1.08501, + 1.08501, 1.08501, 1.0569, 1.0569, 1.0569, 1.08501, 1.08501, 1.08501, + 1.08501, 1.08501, 1.08501, 1.0569, 1.0569, 1.0569, 1.08501, 1.08501, + 1.08501, 1.0569, 1.0569, 1.0569, 1.11288, 1.11288, 1.11288, 1.02857, + 1.02857, 1.02857, 1.11288, 1.11288, 1.11288, 1.02857, 1.02857, 1.02857}); + NDArray expOut( + 'c', {N, bS, 2 * K}, + {0.779265, 0.779265, 0.779265, 0.810752, 0.810752, 0.810752, 0.779265, + 0.779265, 0.779265, 0.810752, 0.810752, 0.810752, 0.790317, 0.790317, + 0.790317, 0.800804, 0.800804, 0.800804, 0.790317, 0.790317, 0.790317, + 0.800804, 0.800804, 0.800804, 0.800804, 0.800804, 0.800804, 0.790317, + 0.790317, 0.790317, 0.800804, 0.800804, 0.800804, 0.790317, 0.790317, + 0.790317, 0.810752, 0.810752, 0.810752, 0.779265, 0.779265, 0.779265, + 0.810752, 0.810752, 0.810752, 0.779265, 0.779265, 0.779265}); + + input.assign(1.5); + weights.assign(0.5); + bias.assign(0.3); + init.assign(1.); + mask.assign(1.); + + sd::ops::sru_bi op; + auto results = op.evaluate({&input, &weights, &bias, &init, &mask}, {}, {}); + ASSERT_TRUE(results.size() == 2); + + auto output = results.at(0); + auto state = results.at(1); + // state->printBuffer(); + // output->printBuffer(); + + ASSERT_TRUE(expState.equalsTo(state)); + ASSERT_TRUE(expOut.equalsTo(output)); } TEST_F(DeclarableOpsTests1, sru_bi_bp_1) { - - const int bS = 2; - const int K = 3; - const int N = 3; - std::vector expGradXBuff = { 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129 }; - std::vector expGradInitBuff = { 1.05121, 1.05121, 1.05121, 1.02676, 1.02676, 1.02676, 1.05121, 1.05121, 1.05121, 1.02676, 1.02676, 1.02676 }; - std::vector expGradWBuff = { 0.02595354,-0.090096 ,-0.00882456,0.02595354,-0.090096 ,-0.0088245, 0.02595354,-0.090096 ,-0.00882456,0.01651665,-0.0559437,-0.0084390, 0.01651665,-0.0559437,-0.00843906,0.01651665,-0.0559437,-0.00843906, 0.02595354,-0.090096 ,-0.00882456,0.02595354,-0.090096 ,-0.0088245, 0.02595354,-0.090096 ,-0.00882456,0.01651665,-0.0559437,-0.0084390, 0.01651665,-0.0559437,-0.00843906,0.01651665,-0.0559437,-0.00843906, 0.02595354,-0.090096 ,-0.00882456,0.02595354,-0.090096 ,-0.0088245, 0.02595354,-0.090096 ,-0.00882456,0.01651665,-0.0559437,-0.0084390, 0.01651665,-0.0559437,-0.00843906,0.01651665,-0.0559437,-0.00843906, 0.02595354,-0.090096 ,-0.00882456,0.02595354,-0.090096 ,-0.0088245, 0.02595354,-0.090096 ,-0.00882456,0.01651665,-0.0559437,-0.0084390, 0.01651665,-0.0559437,-0.00843906,0.01651665,-0.0559437,-0.00843906, 0.02595354,-0.090096 ,-0.00882456,0.02595354,-0.090096 ,-0.0088245, 0.02595354,-0.090096 ,-0.00882456,0.01651665,-0.0559437,-0.0084390, 0.01651665,-0.0559437,-0.00843906,0.01651665,-0.0559437,-0.00843906, 0.02595354,-0.090096 ,-0.00882456,0.02595354,-0.090096 ,-0.0088245, 0.02595354,-0.090096 ,-0.00882456,0.01651665,-0.0559437,-0.0084390, 0.01651665,-0.0559437,-0.00843906,0.01651665,-0.0559437,-0.00843906, 0.02124567,-0.0731508,-0.00868926,0.02124567,-0.0731508,-0.0086892, 0.02124567,-0.0731508,-0.00868926,0.02084955,-0.0712011,-0.0085608, 0.02084955,-0.0712011,-0.00856086,0.02084955,-0.0712011,-0.00856086, 0.02124567,-0.0731508,-0.00868926,0.02124567,-0.0731508,-0.0086892, 0.02124567,-0.0731508,-0.00868926,0.02084955,-0.0712011,-0.0085608, 0.02084955,-0.0712011,-0.00856086,0.02084955,-0.0712011,-0.00856086, 0.02124567,-0.0731508,-0.00868926,0.02124567,-0.0731508,-0.0086892, 0.02124567,-0.0731508,-0.00868926,0.02084955,-0.0712011,-0.0085608, 0.02084955,-0.0712011,-0.00856086,0.02084955,-0.0712011,-0.00856086, 0.02124567,-0.0731508,-0.00868926,0.02124567,-0.0731508,-0.0086892, 0.02124567,-0.0731508,-0.00868926,0.02084955,-0.0712011,-0.0085608, 0.02084955,-0.0712011,-0.00856086,0.02084955,-0.0712011,-0.00856086, 0.02124567,-0.0731508,-0.00868926,0.02124567,-0.0731508,-0.0086892, 0.02124567,-0.0731508,-0.00868926,0.02084955,-0.0712011,-0.0085608, 0.02084955,-0.0712011,-0.00856086,0.02084955,-0.0712011,-0.00856086, 0.02124567,-0.0731508,-0.00868926,0.02124567,-0.0731508,-0.0086892, 0.02124567,-0.0731508,-0.00868926,0.02084955,-0.0712011,-0.0085608, 0.02084955,-0.0712011,-0.00856086,0.02084955,-0.0712011,-0.00856086, 0.01671156,-0.0570699,-0.00856086,0.01671156,-0.0570699,-0.0085608, 0.01671156,-0.0570699,-0.00856086,0.02534988,-0.0880002,-0.0086892, 0.02534988,-0.0880002,-0.00868926,0.02534988,-0.0880002,-0.00868926, 0.01671156,-0.0570699,-0.00856086,0.01671156,-0.0570699,-0.0085608, 0.01671156,-0.0570699,-0.00856086,0.02534988,-0.0880002,-0.0086892, 0.02534988,-0.0880002,-0.00868926,0.02534988,-0.0880002,-0.00868926, 0.01671156,-0.0570699,-0.00856086,0.01671156,-0.0570699,-0.0085608, 0.01671156,-0.0570699,-0.00856086,0.02534988,-0.0880002,-0.0086892, 0.02534988,-0.0880002,-0.00868926,0.02534988,-0.0880002,-0.00868926, 0.01671156,-0.0570699,-0.00856086,0.01671156,-0.0570699,-0.0085608, 0.01671156,-0.0570699,-0.00856086,0.02534988,-0.0880002,-0.0086892, 0.02534988,-0.0880002,-0.00868926,0.02534988,-0.0880002,-0.00868926, 0.01671156,-0.0570699,-0.00856086,0.01671156,-0.0570699,-0.0085608, 0.01671156,-0.0570699,-0.00856086,0.02534988,-0.0880002,-0.0086892, 0.02534988,-0.0880002,-0.00868926,0.02534988,-0.0880002,-0.00868926, 0.01671156,-0.0570699,-0.00856086,0.01671156,-0.0570699,-0.0085608, 0.01671156,-0.0570699,-0.00856086,0.02534988,-0.0880002,-0.0086892, 0.02534988,-0.0880002,-0.00868926,0.02534988,-0.0880002,-0.00868926 }; - std::vector expGradBBuff = { -0.0734389, -0.0734389, -0.0734389, -0.0717151, -0.0717151, -0.0717151, -0.0734389, -0.0734389, -0.0734389, -0.0717151, -0.0717151, -0.0717151, -0.00869156, -0.00869156, -0.00869156, -0.00856306, -0.00856306, -0.00856306, -0.00869156, -0.00869156, -0.00869156, -0.00856306, -0.00856306, -0.00856306 }; - std::vector stateBuff = { 1.028569, 1.028569, 1.028569, 1.112884, 1.112884, 1.112884, 1.028569, 1.028569, 1.028569, 1.112884,1.112884, 1.112884, 1.056905, 1.056905, 1.056905, 1.085009, 1.085009, 1.085009, 1.056905, 1.056905,1.056905, 1.085009, 1.085009, 1.085009, 1.085009, 1.085009, 1.085009, 1.056905, 1.056905, 1.056905,1.085009, 1.085009, 1.085009, 1.056905, 1.056905, 1.056905 }; - - auto input = NDArrayFactory::create('c', { N,bS,2 * K }); - auto weights = NDArrayFactory::create('c', { 2 * K,6 * K }); - auto bias = NDArrayFactory::create('c', { 4 * K }); - auto init = NDArrayFactory::create('c', { bS,2 * K }); - auto mask = NDArrayFactory::create('c', { bS,2 * K }); - NDArray state('c', { N,bS,2 * K }, stateBuff); - auto inGradCt = NDArrayFactory::create('c', { bS,2 * K }); - auto inGradH = NDArrayFactory::create('c', { N,bS,2 * K }); - - NDArray gradBias('c', { bS,4 * K }, expGradBBuff); - - NDArray expGradX('c', { N,bS,2 * K }, expGradXBuff); - NDArray expGradW('c', { N,2 * K,6 * K }, expGradWBuff); - auto expGradB = NDArrayFactory::create('c', { 4 * K }); - gradBias.reduceAlongDimension(reduce::Sum, expGradB, { 0 }); // [bS, 4K] -> [4K] - NDArray expGradInit('c', { bS,2 * K }, expGradInitBuff); - - input.assign(1.5); - weights.assign(0.5); - bias.assign(0.3); - mask.assign(1.); - init.assign(1.); - inGradCt.assign(0.5); - inGradH.assign(0.5); - - sd::ops::sru_bi_bp bp; - auto resultsBP = bp.evaluate({ &input, &weights, &bias, &init, &state, &inGradCt, &inGradH, &mask }, {}, {}); - ASSERT_TRUE(resultsBP.size() == 4); - - auto gradX = resultsBP.at(0); - auto gradW = resultsBP.at(1); - auto gradB = resultsBP.at(2); - auto gradInit = resultsBP.at(3); - - ASSERT_TRUE(expGradX.equalsTo(gradX)); - ASSERT_TRUE(expGradW.equalsTo(gradW)); - ASSERT_TRUE(expGradB.equalsTo(gradB)); - ASSERT_TRUE(expGradInit.equalsTo(gradInit)); - - + const int bS = 2; + const int K = 3; + const int N = 3; + std::vector expGradXBuff = { + 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, + 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, + 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, + 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, + 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, + 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129}; + std::vector expGradInitBuff = {1.05121, 1.05121, 1.05121, 1.02676, + 1.02676, 1.02676, 1.05121, 1.05121, + 1.05121, 1.02676, 1.02676, 1.02676}; + std::vector expGradWBuff = { + 0.02595354, -0.090096, -0.00882456, 0.02595354, -0.090096, -0.0088245, + 0.02595354, -0.090096, -0.00882456, 0.01651665, -0.0559437, -0.0084390, + 0.01651665, -0.0559437, -0.00843906, 0.01651665, -0.0559437, -0.00843906, + 0.02595354, -0.090096, -0.00882456, 0.02595354, -0.090096, -0.0088245, + 0.02595354, -0.090096, -0.00882456, 0.01651665, -0.0559437, -0.0084390, + 0.01651665, -0.0559437, -0.00843906, 0.01651665, -0.0559437, -0.00843906, + 0.02595354, -0.090096, -0.00882456, 0.02595354, -0.090096, -0.0088245, + 0.02595354, -0.090096, -0.00882456, 0.01651665, -0.0559437, -0.0084390, + 0.01651665, -0.0559437, -0.00843906, 0.01651665, -0.0559437, -0.00843906, + 0.02595354, -0.090096, -0.00882456, 0.02595354, -0.090096, -0.0088245, + 0.02595354, -0.090096, -0.00882456, 0.01651665, -0.0559437, -0.0084390, + 0.01651665, -0.0559437, -0.00843906, 0.01651665, -0.0559437, -0.00843906, + 0.02595354, -0.090096, -0.00882456, 0.02595354, -0.090096, -0.0088245, + 0.02595354, -0.090096, -0.00882456, 0.01651665, -0.0559437, -0.0084390, + 0.01651665, -0.0559437, -0.00843906, 0.01651665, -0.0559437, -0.00843906, + 0.02595354, -0.090096, -0.00882456, 0.02595354, -0.090096, -0.0088245, + 0.02595354, -0.090096, -0.00882456, 0.01651665, -0.0559437, -0.0084390, + 0.01651665, -0.0559437, -0.00843906, 0.01651665, -0.0559437, -0.00843906, + 0.02124567, -0.0731508, -0.00868926, 0.02124567, -0.0731508, -0.0086892, + 0.02124567, -0.0731508, -0.00868926, 0.02084955, -0.0712011, -0.0085608, + 0.02084955, -0.0712011, -0.00856086, 0.02084955, -0.0712011, -0.00856086, + 0.02124567, -0.0731508, -0.00868926, 0.02124567, -0.0731508, -0.0086892, + 0.02124567, -0.0731508, -0.00868926, 0.02084955, -0.0712011, -0.0085608, + 0.02084955, -0.0712011, -0.00856086, 0.02084955, -0.0712011, -0.00856086, + 0.02124567, -0.0731508, -0.00868926, 0.02124567, -0.0731508, -0.0086892, + 0.02124567, -0.0731508, -0.00868926, 0.02084955, -0.0712011, -0.0085608, + 0.02084955, -0.0712011, -0.00856086, 0.02084955, -0.0712011, -0.00856086, + 0.02124567, -0.0731508, -0.00868926, 0.02124567, -0.0731508, -0.0086892, + 0.02124567, -0.0731508, -0.00868926, 0.02084955, -0.0712011, -0.0085608, + 0.02084955, -0.0712011, -0.00856086, 0.02084955, -0.0712011, -0.00856086, + 0.02124567, -0.0731508, -0.00868926, 0.02124567, -0.0731508, -0.0086892, + 0.02124567, -0.0731508, -0.00868926, 0.02084955, -0.0712011, -0.0085608, + 0.02084955, -0.0712011, -0.00856086, 0.02084955, -0.0712011, -0.00856086, + 0.02124567, -0.0731508, -0.00868926, 0.02124567, -0.0731508, -0.0086892, + 0.02124567, -0.0731508, -0.00868926, 0.02084955, -0.0712011, -0.0085608, + 0.02084955, -0.0712011, -0.00856086, 0.02084955, -0.0712011, -0.00856086, + 0.01671156, -0.0570699, -0.00856086, 0.01671156, -0.0570699, -0.0085608, + 0.01671156, -0.0570699, -0.00856086, 0.02534988, -0.0880002, -0.0086892, + 0.02534988, -0.0880002, -0.00868926, 0.02534988, -0.0880002, -0.00868926, + 0.01671156, -0.0570699, -0.00856086, 0.01671156, -0.0570699, -0.0085608, + 0.01671156, -0.0570699, -0.00856086, 0.02534988, -0.0880002, -0.0086892, + 0.02534988, -0.0880002, -0.00868926, 0.02534988, -0.0880002, -0.00868926, + 0.01671156, -0.0570699, -0.00856086, 0.01671156, -0.0570699, -0.0085608, + 0.01671156, -0.0570699, -0.00856086, 0.02534988, -0.0880002, -0.0086892, + 0.02534988, -0.0880002, -0.00868926, 0.02534988, -0.0880002, -0.00868926, + 0.01671156, -0.0570699, -0.00856086, 0.01671156, -0.0570699, -0.0085608, + 0.01671156, -0.0570699, -0.00856086, 0.02534988, -0.0880002, -0.0086892, + 0.02534988, -0.0880002, -0.00868926, 0.02534988, -0.0880002, -0.00868926, + 0.01671156, -0.0570699, -0.00856086, 0.01671156, -0.0570699, -0.0085608, + 0.01671156, -0.0570699, -0.00856086, 0.02534988, -0.0880002, -0.0086892, + 0.02534988, -0.0880002, -0.00868926, 0.02534988, -0.0880002, -0.00868926, + 0.01671156, -0.0570699, -0.00856086, 0.01671156, -0.0570699, -0.0085608, + 0.01671156, -0.0570699, -0.00856086, 0.02534988, -0.0880002, -0.0086892, + 0.02534988, -0.0880002, -0.00868926, 0.02534988, -0.0880002, -0.00868926}; + std::vector expGradBBuff = { + -0.0734389, -0.0734389, -0.0734389, -0.0717151, -0.0717151, + -0.0717151, -0.0734389, -0.0734389, -0.0734389, -0.0717151, + -0.0717151, -0.0717151, -0.00869156, -0.00869156, -0.00869156, + -0.00856306, -0.00856306, -0.00856306, -0.00869156, -0.00869156, + -0.00869156, -0.00856306, -0.00856306, -0.00856306}; + std::vector stateBuff = { + 1.028569, 1.028569, 1.028569, 1.112884, 1.112884, 1.112884, + 1.028569, 1.028569, 1.028569, 1.112884, 1.112884, 1.112884, + 1.056905, 1.056905, 1.056905, 1.085009, 1.085009, 1.085009, + 1.056905, 1.056905, 1.056905, 1.085009, 1.085009, 1.085009, + 1.085009, 1.085009, 1.085009, 1.056905, 1.056905, 1.056905, + 1.085009, 1.085009, 1.085009, 1.056905, 1.056905, 1.056905}; + + auto input = NDArrayFactory::create('c', {N, bS, 2 * K}); + auto weights = NDArrayFactory::create('c', {2 * K, 6 * K}); + auto bias = NDArrayFactory::create('c', {4 * K}); + auto init = NDArrayFactory::create('c', {bS, 2 * K}); + auto mask = NDArrayFactory::create('c', {bS, 2 * K}); + NDArray state('c', {N, bS, 2 * K}, stateBuff); + auto inGradCt = NDArrayFactory::create('c', {bS, 2 * K}); + auto inGradH = NDArrayFactory::create('c', {N, bS, 2 * K}); + + NDArray gradBias('c', {bS, 4 * K}, expGradBBuff); + + NDArray expGradX('c', {N, bS, 2 * K}, expGradXBuff); + NDArray expGradW('c', {N, 2 * K, 6 * K}, expGradWBuff); + auto expGradB = NDArrayFactory::create('c', {4 * K}); + gradBias.reduceAlongDimension(reduce::Sum, expGradB, + {0}); // [bS, 4K] -> [4K] + NDArray expGradInit('c', {bS, 2 * K}, expGradInitBuff); + + input.assign(1.5); + weights.assign(0.5); + bias.assign(0.3); + mask.assign(1.); + init.assign(1.); + inGradCt.assign(0.5); + inGradH.assign(0.5); + + sd::ops::sru_bi_bp bp; + auto resultsBP = bp.evaluate( + {&input, &weights, &bias, &init, &state, &inGradCt, &inGradH, &mask}, {}, + {}); + ASSERT_TRUE(resultsBP.size() == 4); + + auto gradX = resultsBP.at(0); + auto gradW = resultsBP.at(1); + auto gradB = resultsBP.at(2); + auto gradInit = resultsBP.at(3); + + ASSERT_TRUE(expGradX.equalsTo(gradX)); + ASSERT_TRUE(expGradW.equalsTo(gradW)); + ASSERT_TRUE(expGradB.equalsTo(gradB)); + ASSERT_TRUE(expGradInit.equalsTo(gradInit)); } TEST_F(DeclarableOpsTests1, ArgMax1) { - auto x = NDArrayFactory::create('c', { 3, 5 }); - x.linspace(1); - auto exp = NDArrayFactory::create('c', { 3 }); - exp.assign(4); - - sd::ops::argmax op; + auto x = NDArrayFactory::create('c', {3, 5}); + x.linspace(1); + auto exp = NDArrayFactory::create('c', {3}); + exp.assign(4); - auto result = op.evaluate({ &x }, {}, { 1 }); + sd::ops::argmax op; - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto result = op.evaluate({&x}, {}, {1}); - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(DeclarableOpsTests1, ArgMax2) { - auto x = NDArrayFactory::create('c', { 3, 5 }); - x.linspace(1); - auto exp = NDArrayFactory::create('c', { 5 }); - exp.assign(2); - - sd::ops::argmax op; - - auto result = op.evaluate({ &x }, {}, { 0 }); + auto x = NDArrayFactory::create('c', {3, 5}); + x.linspace(1); + auto exp = NDArrayFactory::create('c', {5}); + exp.assign(2); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::argmax op; - auto z = result.at(0); + auto result = op.evaluate({&x}, {}, {0}); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(DeclarableOpsTests1, ArgMax3) { - auto x = NDArrayFactory::create('c', { 3, 5 }); - auto dim = NDArrayFactory::create('c', { 1, 1 }, { 0. }); - x.linspace(1); - auto exp = NDArrayFactory::create('c', { 5 }); - exp.assign(2); - - sd::ops::argmax op; + auto x = NDArrayFactory::create('c', {3, 5}); + auto dim = NDArrayFactory::create('c', {1, 1}, {0.}); + x.linspace(1); + auto exp = NDArrayFactory::create('c', {5}); + exp.assign(2); - auto result = op.evaluate({ &x, &dim }, {}, {}); + sd::ops::argmax op; - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto result = op.evaluate({&x, &dim}, {}, {}); - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests1, ArgMax4) { - auto x = NDArrayFactory::create('c', { 3, 5 }); - auto dim = NDArrayFactory::create('c', { 1, 1 }, { 1 }); - x.linspace(1); - auto exp = NDArrayFactory::create('c', { 3 }); - exp.assign(4); - - sd::ops::argmax op; + auto x = NDArrayFactory::create('c', {3, 5}); + auto dim = NDArrayFactory::create('c', {1, 1}, {1}); + x.linspace(1); + auto exp = NDArrayFactory::create('c', {3}); + exp.assign(4); - auto result = op.evaluate({ &x, &dim }, {}, {}); + sd::ops::argmax op; - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto result = op.evaluate({&x, &dim}, {}, {}); - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(DeclarableOpsTests1, ArgMax5) { - auto x = NDArrayFactory::create('c', { 3, 5 }); - auto dim = NDArrayFactory::create('c', { 1, 2 }, { 0, 1 }); - x.linspace(1); - auto exp = NDArrayFactory::create(14); - - - sd::ops::argmax op; + auto x = NDArrayFactory::create('c', {3, 5}); + auto dim = NDArrayFactory::create('c', {1, 2}, {0, 1}); + x.linspace(1); + auto exp = NDArrayFactory::create(14); - auto result = op.evaluate({ &x, &dim }, {}, {}); + sd::ops::argmax op; - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto result = op.evaluate({&x, &dim}, {}, {}); - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests1, ArgMax6) { - auto x = NDArrayFactory::create('c', { 3, 4, 5 }); - auto dim = NDArrayFactory::create(-1.f); - x.linspace(1); - - - sd::ops::argmax op; + auto x = NDArrayFactory::create('c', {3, 4, 5}); + auto dim = NDArrayFactory::create(-1.f); + x.linspace(1); - auto expected = op.evaluate({ &x }, {}, { 2 }); - ASSERT_EQ(Status::OK(), expected.status()); - auto exp = expected.at(0); + sd::ops::argmax op; + auto expected = op.evaluate({&x}, {}, {2}); + ASSERT_EQ(Status::OK(), expected.status()); + auto exp = expected.at(0); - auto result = op.evaluate({ &x, &dim }, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); + auto result = op.evaluate({&x, &dim}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_EQ(*exp, *z); + ASSERT_EQ(exp, z); } - TEST_F(DeclarableOpsTests1, ArgMin1) { - auto x = NDArrayFactory::create('c', { 3, 5 }); - x.linspace(1); - // auto exp('c', {3, 1}); - auto exp = NDArrayFactory::create('c', { 3 }); - exp.assign(0.0f); - - sd::ops::argmin op; + auto x = NDArrayFactory::create('c', {3, 5}); + x.linspace(1); + // auto exp('c', {3, 1}); + auto exp = NDArrayFactory::create('c', {3}); + exp.assign(0.0f); - auto result = op.evaluate({ &x }, {}, { 1 }); + sd::ops::argmin op; - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto result = op.evaluate({&x}, {}, {1}); - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(DeclarableOpsTests1, SquareTests1) { - auto x = NDArrayFactory::create('c', { 3, 5 }); - x.linspace(1); + auto x = NDArrayFactory::create('c', {3, 5}); + x.linspace(1); - auto exp = NDArrayFactory::create('c', { 3, 5 }); - exp.linspace(1); - exp *= exp; + auto exp = NDArrayFactory::create('c', {3, 5}); + exp.linspace(1); + exp *= exp; - sd::ops::square op; + sd::ops::square op; - auto result = op.evaluate({ &x }, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(exp.equalsTo(z)); + auto result = op.evaluate({&x}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests1, OneHotTests_1) { + auto indices = + NDArrayFactory::create('c', {1, 4}, {0.0f, 2.0f, -1.0f, 1.0f}); - auto indices = NDArrayFactory::create('c', { 1, 4 }, { 0.0f, 2.0f, -1.0f, 1.0f }); - - auto exp = NDArrayFactory::create('c', { 1, 4, 3 }, { 1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f }); - - sd::ops::onehot op; + auto exp = NDArrayFactory::create( + 'c', {1, 4, 3}, + {1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f}); - auto result = op.evaluate({ &indices }, { 1.0f, 0.0f }, { -1, 3 }); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::onehot op; - auto z = result.at(0); - // z->printBuffer(); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto result = op.evaluate({&indices}, {1.0f, 0.0f}, {-1, 3}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + // z->printBuffer(); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests1, OneHotTests_2) { - auto indices = NDArrayFactory::create('c', { 2, 2 }, { 0.f, 2.f, 1.f, -1.f }); - - auto exp = NDArrayFactory::create('c', { 2, 2, 3 }, { 1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f, 1.f, 0.f, 0.f, 0.f, 0.f }); + auto indices = + NDArrayFactory::create('c', {2, 2}, {0.f, 2.f, 1.f, -1.f}); - sd::ops::onehot op; - auto result = op.evaluate({ &indices }, { 1.0f, 0.0f }, { -1, 3 }); + auto exp = NDArrayFactory::create( + 'c', {2, 2, 3}, + {1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f, 1.f, 0.f, 0.f, 0.f, 0.f}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::onehot op; + auto result = op.evaluate({&indices}, {1.0f, 0.0f}, {-1, 3}); - auto z = result.at(0); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(z)); - - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests1, OneHotTests_3) { - auto indices = NDArrayFactory::create('c', { 4 }, { 0.0f, 2.0f, -1.0f, 1.0f }); - - auto exp = NDArrayFactory::create('c', { 4, 3 }, { 1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f }); + auto indices = + NDArrayFactory::create('c', {4}, {0.0f, 2.0f, -1.0f, 1.0f}); - sd::ops::onehot op; + auto exp = NDArrayFactory::create( + 'c', {4, 3}, + {1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f}); - auto result = op.evaluate({ &indices }, { 1.0f, 0.0f }, { -1, 3 }); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::onehot op; - auto z = result.at(0); + auto result = op.evaluate({&indices}, {1.0f, 0.0f}, {-1, 3}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - // z->printIndexedBuffer("z"); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + // z->printIndexedBuffer("z"); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests1, OneHotTests_4) { - auto indices = NDArrayFactory::create('c', { 4 }, { 0.0f, 2.0f, -1.0f, 1.0f }); - auto depth = NDArrayFactory::create(3.0f); - - auto exp = NDArrayFactory::create('c', { 4, 3 }, { 1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f }); - - sd::ops::onehot op; + auto indices = + NDArrayFactory::create('c', {4}, {0.0f, 2.0f, -1.0f, 1.0f}); + auto depth = NDArrayFactory::create(3.0f); - auto result = op.evaluate({ &indices, &depth }, { 1.0f, 0.0f }, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto exp = NDArrayFactory::create( + 'c', {4, 3}, + {1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f}); - auto z = result.at(0); + sd::ops::onehot op; - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto result = op.evaluate({&indices, &depth}, {1.0f, 0.0f}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests1, OneHotTests_5) { - auto indices = NDArrayFactory::create('c', { 4 }, { 0.0f, 2.0f, -1.0f, 1.0f }); - auto depth = NDArrayFactory::create(3.0f); - auto on = NDArrayFactory::create(1.0f); - auto off = NDArrayFactory::create(0.0f); + auto indices = + NDArrayFactory::create('c', {4}, {0.0f, 2.0f, -1.0f, 1.0f}); + auto depth = NDArrayFactory::create(3.0f); + auto on = NDArrayFactory::create(1.0f); + auto off = NDArrayFactory::create(0.0f); - auto exp = NDArrayFactory::create('c', { 4, 3 }, { 1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f }); + auto exp = NDArrayFactory::create( + 'c', {4, 3}, + {1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f}); - sd::ops::onehot op; + sd::ops::onehot op; - auto result = op.evaluate({ &indices, &depth, &on, &off }, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto result = op.evaluate({&indices, &depth, &on, &off}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests1, OneHotTests_6) { - auto indices = NDArrayFactory::create('c', { 3 }, { 0.f, 1.f, 2.f }); - auto e = NDArrayFactory::create('c', { 3, 3 }, { 1.f, 0.f, 0.f, 0.f, 1.f, 0.f, 0.f, 0.f, 1.f }); - - sd::ops::onehot op; - auto result = op.evaluate({ &indices }, { 1.0, 0.0 }, { 0, 3 }); - auto z = result.at(0); - - ASSERT_EQ(e, *z); + auto indices = NDArrayFactory::create('c', {3}, {0.f, 1.f, 2.f}); + auto e = NDArrayFactory::create( + 'c', {3, 3}, {1.f, 0.f, 0.f, 0.f, 1.f, 0.f, 0.f, 0.f, 1.f}); + sd::ops::onehot op; + auto result = op.evaluate({&indices}, {1.0, 0.0}, {0, 3}); + auto z = result.at(0); + ASSERT_EQ(e, z); } TEST_F(DeclarableOpsTests1, OneHotTests_7) { - auto indices = NDArrayFactory::create('c', { 3 }, { 0, 1, 2 }); - auto e = NDArrayFactory::create('c', { 3, 3 }, { 1., 0., 0., 0., 1., 0., 0., 0., 1. }); - - sd::ops::onehot op; - auto result = op.evaluate({ &indices }, { 1.0, 0.0 }, { 0, 3 }, {}, { sd::DataType::HALF }, false); - auto z = result.at(0); - - ASSERT_EQ(e, *z); + auto indices = NDArrayFactory::create('c', {3}, {0, 1, 2}); + auto e = NDArrayFactory::create( + 'c', {3, 3}, {1., 0., 0., 0., 1., 0., 0., 0., 1.}); + sd::ops::onehot op; + auto result = op.evaluate({&indices}, {1.0, 0.0}, {0, 3}, {}, + {sd::DataType::HALF}, false); + auto z = result.at(0); + ASSERT_EQ(e, z); } TEST_F(DeclarableOpsTests1, FillAs_1) { - auto x = NDArrayFactory::create('c', { 2, 2 }); - x.assign(117); - - float scalar = 119.f; - - sd::ops::fill_as op; - auto result = op.evaluate({ &x }, { scalar }, {}); + auto x = NDArrayFactory::create('c', {2, 2}); + x.assign(117); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + float scalar = 119.f; - ASSERT_TRUE(x.isSameShape(result.at(0))); + sd::ops::fill_as op; + auto result = op.evaluate({&x}, {scalar}, {}); - ASSERT_NEAR(scalar, result.at(0)->meanNumber().e(0), 1e-5f); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_TRUE(x.isSameShape(result.at(0))); + ASSERT_NEAR(scalar, result.at(0).meanNumber().e(0), 1e-5f); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, LRN1) { - sd::ops::lrn lrn; + sd::ops::lrn lrn; - lrn.getOpName(); + lrn.getOpName(); } TEST_F(DeclarableOpsTests1, Test_Range_Integer_1) { - auto exp = NDArrayFactory::create('c', { 4 }); - exp.linspace(1); + auto exp = NDArrayFactory::create('c', {4}); + exp.linspace(1); - sd::ops::range op; + sd::ops::range op; - auto result = op.evaluate({}, {}, { 1, 5, 1 }); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - ASSERT_EQ(1, result.size()); - - auto array = result.at(0); - // array->printIndexedBuffer("Range integer 1"); - ASSERT_TRUE(exp.isSameShape(array)); - ASSERT_TRUE(exp.equalsTo(array)); + auto result = op.evaluate({}, {}, {1, 5, 1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(1, result.size()); + auto array = result.at(0); + // array->printIndexedBuffer("Range integer 1"); + ASSERT_TRUE(exp.isSameShape(array)); + ASSERT_TRUE(exp.equalsTo(array)); } - TEST_F(DeclarableOpsTests1, Test_Range_Integer_2) { - auto exp = NDArrayFactory::create('c', { 4 }); - exp.linspace(1); - - auto start = NDArrayFactory::create('c', { 1, 1 }); - auto stop = NDArrayFactory::create('c', { 1, 1 }); - auto step = NDArrayFactory::create('c', { 1, 1 }); - start.p(0, 1.f); - stop.p(0, 5.f); - step.p(0, 1.f); + auto exp = NDArrayFactory::create('c', {4}); + exp.linspace(1); - sd::ops::range op; + auto start = NDArrayFactory::create('c', {1, 1}); + auto stop = NDArrayFactory::create('c', {1, 1}); + auto step = NDArrayFactory::create('c', {1, 1}); + start.p(0, 1.f); + stop.p(0, 5.f); + step.p(0, 1.f); - auto result = op.evaluate({ &start, &stop, &step }, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::range op; - ASSERT_EQ(1, result.size()); + auto result = op.evaluate({&start, &stop, &step}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto array = result.at(0); - - ASSERT_TRUE(exp.isSameShape(array)); - ASSERT_TRUE(exp.equalsTo(array)); + ASSERT_EQ(1, result.size()); + auto array = result.at(0); + ASSERT_TRUE(exp.isSameShape(array)); + ASSERT_TRUE(exp.equalsTo(array)); } - TEST_F(DeclarableOpsTests1, Test_Range_Integer_3) { - auto exp = NDArrayFactory::create('c', { 4 }); - exp.linspace(1); - - sd::ops::range op; + auto exp = NDArrayFactory::create('c', {4}); + exp.linspace(1); - auto result = op.evaluate({}, { 1.f, 5.f, 1.f }, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::range op; - ASSERT_EQ(1, result.size()); + auto result = op.evaluate({}, {1.f, 5.f, 1.f}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto array = result.at(0); - - ASSERT_TRUE(exp.isSameShape(array)); - ASSERT_TRUE(exp.equalsTo(array)); + ASSERT_EQ(1, result.size()); + auto array = result.at(0); + ASSERT_TRUE(exp.isSameShape(array)); + ASSERT_TRUE(exp.equalsTo(array)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, softmax_test1) { + NDArray input('c', {3, 3}, {-1.f, 1.f, -2.f, 2.f, -3.f, 3.f, -4.f, 4.f, 5.f}, + sd::DataType::FLOAT32); - NDArray input('c', { 3, 3 }, { -1.f, 1.f, -2.f, 2.f, -3.f, 3.f, -4.f, 4.f, 5.f }, sd::DataType::FLOAT32); - - NDArray expOutput('c', { 3, 3 }, { 1.14195199e-01, 8.43794734e-01, 4.20100661e-02, 2.68454951e-01, 1.80883523e-03, 7.29736214e-01, 9.02116571e-05, 2.68917160e-01, 7.30992629e-01 }, sd::DataType::FLOAT32); - - sd::ops::softmax op; - auto results = op.evaluate({ &input }, {}, {}, {}); - auto z = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expOutput.isSameShape(z)); - ASSERT_TRUE(expOutput.equalsTo(z)); + NDArray expOutput('c', {3, 3}, + {1.14195199e-01, 8.43794734e-01, 4.20100661e-02, + 2.68454951e-01, 1.80883523e-03, 7.29736214e-01, + 9.02116571e-05, 2.68917160e-01, 7.30992629e-01}, + sd::DataType::FLOAT32); + sd::ops::softmax op; + auto results = op.evaluate({&input}, {}, {}, {}); + auto z = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expOutput.isSameShape(z)); + ASSERT_TRUE(expOutput.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, softmax_test2) { - NDArray input('c', { 3, 3, 3 }, { -1, 1, -2, 2, -3, 3, -4, 4, -5,5 ,-6,6, -7,7, -8,8, -9,9, -10,10, -11,11, -12,12, -13,13, 14 }, sd::DataType::FLOAT32); - NDArray expOutput('c', { 3, 3, 3 }, { 4.73142e-02,4.73847e-02,6.69062e-03, 9.50330e-01,8.67881e-04,9.92976e-01, 2.35563e-03,9.51747e-01,3.33106e-04, 4.74259e-02,2.26032e-06,4.74259e-02, 2.91395e-07,9.99998e-01,3.94360e-08, 9.52574e-01,1.12535e-07,9.52574e-01, 7.58256e-10,4.74259e-02,1.22325e-11, 1.00000e+00,1.32293e-11,1.19203e-01, 3.77513e-11,9.52574e-01,8.80797e-01 }, sd::DataType::FLOAT32); - - sd::ops::softmax op; - auto results = op.evaluate({ &input }, {}, { 1 }, {}); - auto z = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expOutput.isSameShape(z)); - ASSERT_TRUE(expOutput.equalsTo(z)); - - + NDArray input('c', {3, 3, 3}, + {-1, 1, -2, 2, -3, 3, -4, 4, -5, 5, -6, 6, -7, 7, + -8, 8, -9, 9, -10, 10, -11, 11, -12, 12, -13, 13, 14}, + sd::DataType::FLOAT32); + NDArray expOutput( + 'c', {3, 3, 3}, + {4.73142e-02, 4.73847e-02, 6.69062e-03, 9.50330e-01, 8.67881e-04, + 9.92976e-01, 2.35563e-03, 9.51747e-01, 3.33106e-04, 4.74259e-02, + 2.26032e-06, 4.74259e-02, 2.91395e-07, 9.99998e-01, 3.94360e-08, + 9.52574e-01, 1.12535e-07, 9.52574e-01, 7.58256e-10, 4.74259e-02, + 1.22325e-11, 1.00000e+00, 1.32293e-11, 1.19203e-01, 3.77513e-11, + 9.52574e-01, 8.80797e-01}, + sd::DataType::FLOAT32); + + sd::ops::softmax op; + auto results = op.evaluate({&input}, {}, {1}, {}); + auto z = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expOutput.isSameShape(z)); + ASSERT_TRUE(expOutput.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, softmax_test3) { - NDArray input('c', { 3, 3, 3 }, { -1, 1, -2, 2, -3, 3, -4, 4, -5,5 ,-6,6, -7,7, -8,8, -9,9, -10,10, -11,11, -12,12, -13,13, 14 }, sd::DataType::FLOAT32); - NDArray expOutput('c', { 3, 3, 3 }, { 2.47262e-03,1.23395e-04,3.35350e-04, 1.23395e-04,4.53979e-05,1.23395e-04, 6.14417e-06,1.23395e-04,5.56530e-09, 9.97527e-01,1.12521e-07,9.99665e-01, 1.52281e-08,9.99955e-01,2.06090e-09, 9.99994e-01,2.78912e-10,6.69285e-03, 3.05146e-07,9.99876e-01,4.13855e-08, 9.99877e-01,5.60254e-09,9.99877e-01, 7.58251e-10,9.99877e-01,9.93307e-01 }, sd::DataType::FLOAT32); - - sd::ops::softmax op; - auto results = op.evaluate({ &input }, {}, { 0 }, {}); - auto z = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expOutput.isSameShape(z)); - ASSERT_TRUE(expOutput.equalsTo(z)); - - + NDArray input('c', {3, 3, 3}, + {-1, 1, -2, 2, -3, 3, -4, 4, -5, 5, -6, 6, -7, 7, + -8, 8, -9, 9, -10, 10, -11, 11, -12, 12, -13, 13, 14}, + sd::DataType::FLOAT32); + NDArray expOutput( + 'c', {3, 3, 3}, + {2.47262e-03, 1.23395e-04, 3.35350e-04, 1.23395e-04, 4.53979e-05, + 1.23395e-04, 6.14417e-06, 1.23395e-04, 5.56530e-09, 9.97527e-01, + 1.12521e-07, 9.99665e-01, 1.52281e-08, 9.99955e-01, 2.06090e-09, + 9.99994e-01, 2.78912e-10, 6.69285e-03, 3.05146e-07, 9.99876e-01, + 4.13855e-08, 9.99877e-01, 5.60254e-09, 9.99877e-01, 7.58251e-10, + 9.99877e-01, 9.93307e-01}, + sd::DataType::FLOAT32); + + sd::ops::softmax op; + auto results = op.evaluate({&input}, {}, {0}, {}); + auto z = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expOutput.isSameShape(z)); + ASSERT_TRUE(expOutput.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, softmax_test4) { - NDArray input('c', { 1, 5 }, { -1, 1, -2, 2, 3 }, sd::DataType::FLOAT32); - NDArray expOutput('c', { 1, 5 }, { 0.01198,0.08855,0.00441,0.24072,0.65434 }, sd::DataType::FLOAT32); - - sd::ops::softmax op; - auto results = op.evaluate({ &input }, {}, { 1 }, {}); - auto z = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expOutput.isSameShape(z)); - ASSERT_TRUE(expOutput.equalsTo(z)); + NDArray input('c', {1, 5}, {-1, 1, -2, 2, 3}, sd::DataType::FLOAT32); + NDArray expOutput('c', {1, 5}, {0.01198, 0.08855, 0.00441, 0.24072, 0.65434}, + sd::DataType::FLOAT32); + sd::ops::softmax op; + auto results = op.evaluate({&input}, {}, {1}, {}); + auto z = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expOutput.isSameShape(z)); + ASSERT_TRUE(expOutput.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, softmax_test5) { - NDArray input('c', { 1, 5 }, { -1, 1, -2, 2, 3 }, sd::DataType::FLOAT32); - NDArray expOutput('c', { 1, 5 }, { 1,1,1,1,1 }, sd::DataType::FLOAT32); - - sd::ops::softmax op; - auto results = op.evaluate({ &input }, {}, { 0 }); - auto z = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expOutput.isSameShape(z)); - ASSERT_TRUE(expOutput.equalsTo(z)); + NDArray input('c', {1, 5}, {-1, 1, -2, 2, 3}, sd::DataType::FLOAT32); + NDArray expOutput('c', {1, 5}, {1, 1, 1, 1, 1}, sd::DataType::FLOAT32); + sd::ops::softmax op; + auto results = op.evaluate({&input}, {}, {0}); + auto z = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expOutput.isSameShape(z)); + ASSERT_TRUE(expOutput.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, softmax_test6) { - NDArray input('c', { 5, 1 }, { -1, 1, -2, 2, 3 }, sd::DataType::FLOAT32); - NDArray expOutput('c', { 5, 1 }, { 0.01198,0.08855,0.00441,0.24072,0.65434 }, sd::DataType::FLOAT32); - - sd::ops::softmax op; - auto results = op.evaluate({ &input }, {}, { 0 }, {}); - auto z = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expOutput.isSameShape(z)); - ASSERT_TRUE(expOutput.equalsTo(z)); + NDArray input('c', {5, 1}, {-1, 1, -2, 2, 3}, sd::DataType::FLOAT32); + NDArray expOutput('c', {5, 1}, {0.01198, 0.08855, 0.00441, 0.24072, 0.65434}, + sd::DataType::FLOAT32); + sd::ops::softmax op; + auto results = op.evaluate({&input}, {}, {0}, {}); + auto z = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expOutput.isSameShape(z)); + ASSERT_TRUE(expOutput.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, softmax_test7) { - NDArray input('c', { 5, 1 }, { -1, 1, -2, 2, 3 }, sd::DataType::FLOAT32); - NDArray expOutput('c', { 5, 1 }, { 1,1,1,1,1 }, sd::DataType::FLOAT32); - - sd::ops::softmax op; - auto results = op.evaluate({ &input }, {}, { 1 }, {}); - auto z = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expOutput.isSameShape(z)); - ASSERT_TRUE(expOutput.equalsTo(z)); + NDArray input('c', {5, 1}, {-1, 1, -2, 2, 3}, sd::DataType::FLOAT32); + NDArray expOutput('c', {5, 1}, {1, 1, 1, 1, 1}, sd::DataType::FLOAT32); + sd::ops::softmax op; + auto results = op.evaluate({&input}, {}, {1}, {}); + auto z = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expOutput.isSameShape(z)); + ASSERT_TRUE(expOutput.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, softmax_test8) { - NDArray input('c', { 5 }, { -1, 1, -2, 2, 3 }, sd::DataType::FLOAT32); - NDArray expOutput('c', { 5 }, { 0.01198,0.08855,0.00441,0.24072,0.65434 }, sd::DataType::FLOAT32); - - sd::ops::softmax op; - auto results = op.evaluate({ &input }, {}, {}, {}); - auto z = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expOutput.isSameShape(z)); - ASSERT_TRUE(expOutput.equalsTo(z)); + NDArray input('c', {5}, {-1, 1, -2, 2, 3}, sd::DataType::FLOAT32); + NDArray expOutput('c', {5}, {0.01198, 0.08855, 0.00441, 0.24072, 0.65434}, + sd::DataType::FLOAT32); + sd::ops::softmax op; + auto results = op.evaluate({&input}, {}, {}, {}); + auto z = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expOutput.isSameShape(z)); + ASSERT_TRUE(expOutput.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, softmax_test9) { - NDArray input('c', { 2, 2, 2, 2 }, { -1, 1, -2, 2, -3, 3, -4, 4, -5,5 ,-6,6, -7,7, -8,8 }, sd::DataType::FLOAT32); - NDArray expOutput('c', { 2, 2, 2, 2 }, { 0.731059, 0.268941, 0.268941, 0.731059, 0.731059, 0.268941, 0.268941, 0.731059, 0.731059, 0.268941, 0.268941, 0.731059, 0.731059, 0.268941, 0.268941, 0.731059 }, sd::DataType::FLOAT32); - - sd::ops::softmax op; - auto results = op.evaluate({ &input }, {}, { 2 }, {}); - auto z = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expOutput.isSameShape(z)); - ASSERT_TRUE(expOutput.equalsTo(z)); + NDArray input('c', {2, 2, 2, 2}, + {-1, 1, -2, 2, -3, 3, -4, 4, -5, 5, -6, 6, -7, 7, -8, 8}, + sd::DataType::FLOAT32); + NDArray expOutput('c', {2, 2, 2, 2}, + {0.731059, 0.268941, 0.268941, 0.731059, 0.731059, 0.268941, + 0.268941, 0.731059, 0.731059, 0.268941, 0.268941, 0.731059, + 0.731059, 0.268941, 0.268941, 0.731059}, + sd::DataType::FLOAT32); + sd::ops::softmax op; + auto results = op.evaluate({&input}, {}, {2}, {}); + auto z = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expOutput.isSameShape(z)); + ASSERT_TRUE(expOutput.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, softmax_test10) { - NDArray input('c', { 2, 2, 2, 2, 2 }, { -1, 1, -2, 2, -3, 3, -4, 4, -5,5 ,-6,6, -7,7, -8,8, -9,9, -10,10, -11,11, -12,12, -13,13, 14, -14, 15, -15, 16,-16 }, sd::DataType::FLOAT32); - NDArray expOutput('c', { 2, 2, 2, 2, 2 }, { 0.119203, 0.880797, 0.017986, 0.982014, 0.002473, 0.997527, 0.000335, 0.999665, 0.000045, 0.999955, 0.000006, 0.999994, 0.000001, 0.999999, 0.000000, 1.000000, 0.000000, 1.000000, 0.000000, 1.000000, 0.000000, 1.000000, 0.000000, 1.000000, 0.000000, 1.000000, 1.000000, 0.000000, 1.000000, 0.000000, 1.000000, 0.00000 }, sd::DataType::FLOAT32); - - sd::ops::softmax op; - auto results = op.evaluate({ &input }, {}, { 4 }, {}); - auto z = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expOutput.isSameShape(z)); - ASSERT_TRUE(expOutput.equalsTo(z)); - - + NDArray input( + 'c', {2, 2, 2, 2, 2}, + {-1, 1, -2, 2, -3, 3, -4, 4, -5, 5, -6, 6, -7, 7, -8, 8, + -9, 9, -10, 10, -11, 11, -12, 12, -13, 13, 14, -14, 15, -15, 16, -16}, + sd::DataType::FLOAT32); + NDArray expOutput( + 'c', {2, 2, 2, 2, 2}, + {0.119203, 0.880797, 0.017986, 0.982014, 0.002473, 0.997527, 0.000335, + 0.999665, 0.000045, 0.999955, 0.000006, 0.999994, 0.000001, 0.999999, + 0.000000, 1.000000, 0.000000, 1.000000, 0.000000, 1.000000, 0.000000, + 1.000000, 0.000000, 1.000000, 0.000000, 1.000000, 1.000000, 0.000000, + 1.000000, 0.000000, 1.000000, 0.00000}, + sd::DataType::FLOAT32); + + sd::ops::softmax op; + auto results = op.evaluate({&input}, {}, {4}, {}); + auto z = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expOutput.isSameShape(z)); + ASSERT_TRUE(expOutput.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, softmax_test11) { - NDArray input('c', { 2, 2, 2, 2, 2, 2 }, { -1, 1, -2, 2, -3, 3, -4, 4, -5,5 ,-6,6, -7,7, -8,8, -9,9, -10,10, -11,11, -12,12, -13,13, 14, -14, 15, -15, 16,-16, -2.1, 2.1, -2.2, 2.2, -2.3, 2.3, -2.4, 2.4, -2.5,2.5 ,-2.6,2.6, -2.7,2.7, -2.8,2.8, -2.9,2.9, -3.0,3.0, -3.1,3.1, -3.2,3.2, -3.3,3.3, 3.4, -3.4, 3.5, -3.5, 3.6,-3.6 }, sd::DataType::FLOAT32); - NDArray expOutput('c', { 2, 2, 2, 2, 2, 2 }, { 0.731059, 0.268941, 0.268941, 0.731059, 0.731059, 0.268941, 0.268941, 0.731059, 0.731059, 0.268941, 0.268941, 0.731059, 0.731059, 0.268941, 0.268941, 0.731059, 0.731059, 0.268941, 0.268941, 0.731059, 0.731059, 0.268941, 0.268941, 0.731059, 0.000000, 1.000000, 1.000000, 0.000000, 0.268941, 0.731059, 0.731059, 0.268941, 0.524979, 0.475021, 0.475021, 0.524979, 0.524979, 0.475021, 0.475021, 0.524979, 0.524979, 0.475021, 0.475021, 0.524979, 0.524979, 0.475021, 0.475021, 0.524979, 0.524979, 0.475021, 0.475021, 0.524979, 0.524979, 0.475021, 0.475021, 0.524979, 0.001229, 0.998771, 0.998771, 0.001229, 0.475021, 0.524979, 0.524979, 0.475021 }, sd::DataType::FLOAT32); - - sd::ops::softmax op; - auto results = op.evaluate({ &input }, {}, { 4 }, {}); - auto z = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expOutput.isSameShape(z)); - ASSERT_TRUE(expOutput.equalsTo(z)); - - + NDArray input( + 'c', {2, 2, 2, 2, 2, 2}, + {-1, 1, -2, 2, -3, 3, -4, 4, -5, 5, -6, + 6, -7, 7, -8, 8, -9, 9, -10, 10, -11, 11, + -12, 12, -13, 13, 14, -14, 15, -15, 16, -16, -2.1, + 2.1, -2.2, 2.2, -2.3, 2.3, -2.4, 2.4, -2.5, 2.5, -2.6, 2.6, + -2.7, 2.7, -2.8, 2.8, -2.9, 2.9, -3.0, 3.0, -3.1, 3.1, -3.2, + 3.2, -3.3, 3.3, 3.4, -3.4, 3.5, -3.5, 3.6, -3.6}, + sd::DataType::FLOAT32); + NDArray expOutput( + 'c', {2, 2, 2, 2, 2, 2}, + {0.731059, 0.268941, 0.268941, 0.731059, 0.731059, 0.268941, 0.268941, + 0.731059, 0.731059, 0.268941, 0.268941, 0.731059, 0.731059, 0.268941, + 0.268941, 0.731059, 0.731059, 0.268941, 0.268941, 0.731059, 0.731059, + 0.268941, 0.268941, 0.731059, 0.000000, 1.000000, 1.000000, 0.000000, + 0.268941, 0.731059, 0.731059, 0.268941, 0.524979, 0.475021, 0.475021, + 0.524979, 0.524979, 0.475021, 0.475021, 0.524979, 0.524979, 0.475021, + 0.475021, 0.524979, 0.524979, 0.475021, 0.475021, 0.524979, 0.524979, + 0.475021, 0.475021, 0.524979, 0.524979, 0.475021, 0.475021, 0.524979, + 0.001229, 0.998771, 0.998771, 0.001229, 0.475021, 0.524979, 0.524979, + 0.475021}, + sd::DataType::FLOAT32); + + sd::ops::softmax op; + auto results = op.evaluate({&input}, {}, {4}, {}); + auto z = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expOutput.isSameShape(z)); + ASSERT_TRUE(expOutput.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, softmax_test12) { - NDArray input('f', { 2, 2, 2, 2, 2, 2 }, { -1, 1, -2, 2, -3, 3, -4, 4, -5,5 ,-6,6, -7,7, -8,8, -9,9, -10,10, -11,11, -12,12, -13,13, 14, -14, 15, -15, 16,-16, -2.1, 2.1, -2.2, 2.2, -2.3, 2.3, -2.4, 2.4, -2.5,2.5 ,-2.6,2.6, -2.7,2.7, -2.8,2.8, -2.9,2.9, -3.0,3.0, -3.1,3.1, -3.2,3.2, -3.3,3.3, 3.4, -3.4, 3.5, -3.5, 3.6,-3.6 }, sd::DataType::FLOAT32); - NDArray exp('c', { 2, 2, 2, 2, 2, 2 }, { 0.982014, 0.598688, 0.982014, 0.598688, 0.017986, 0.401312, 0.017986, 0.401312, 0.982014, 0.598688, 0.000000, 0.001359, 0.017986, 0.401312, 1.000000, 0.998641, 0.982014, 0.598688, 0.000000, 0.001659, 0.017986, 0.401312, 1.000000, 0.998341, 0.982014, 0.598688, 0.000000, 0.001113, 0.017986, 0.401312, 1.000000, 0.998887, 0.017986, 0.401312, 0.017986, 0.401312, 0.982014, 0.598688, 0.982014, 0.598688, 0.017986, 0.401312, 1.000000, 0.998641, 0.982014, 0.598688, 0.000000, 0.001359, 0.017986, 0.401312, 1.000000, 0.998341, 0.982014, 0.598688, 0.000000, 0.001659, 0.017986, 0.401312, 1.000000, 0.998887, 0.982014, 0.598688, 0.000000, 0.001113 }, sd::DataType::FLOAT32); - - auto expOutput = NDArray('f', { 2, 2, 2, 2, 2, 2 }, sd::DataType::FLOAT32); - expOutput.assign(exp); - - sd::ops::softmax op; - auto results = op.evaluate({ &input }, {}, { 3 }, {}); - auto z = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expOutput.isSameShape(z)); - ASSERT_TRUE(expOutput.equalsTo(z)); - - + NDArray input( + 'f', {2, 2, 2, 2, 2, 2}, + {-1, 1, -2, 2, -3, 3, -4, 4, -5, 5, -6, + 6, -7, 7, -8, 8, -9, 9, -10, 10, -11, 11, + -12, 12, -13, 13, 14, -14, 15, -15, 16, -16, -2.1, + 2.1, -2.2, 2.2, -2.3, 2.3, -2.4, 2.4, -2.5, 2.5, -2.6, 2.6, + -2.7, 2.7, -2.8, 2.8, -2.9, 2.9, -3.0, 3.0, -3.1, 3.1, -3.2, + 3.2, -3.3, 3.3, 3.4, -3.4, 3.5, -3.5, 3.6, -3.6}, + sd::DataType::FLOAT32); + NDArray exp( + 'c', {2, 2, 2, 2, 2, 2}, + {0.982014, 0.598688, 0.982014, 0.598688, 0.017986, 0.401312, 0.017986, + 0.401312, 0.982014, 0.598688, 0.000000, 0.001359, 0.017986, 0.401312, + 1.000000, 0.998641, 0.982014, 0.598688, 0.000000, 0.001659, 0.017986, + 0.401312, 1.000000, 0.998341, 0.982014, 0.598688, 0.000000, 0.001113, + 0.017986, 0.401312, 1.000000, 0.998887, 0.017986, 0.401312, 0.017986, + 0.401312, 0.982014, 0.598688, 0.982014, 0.598688, 0.017986, 0.401312, + 1.000000, 0.998641, 0.982014, 0.598688, 0.000000, 0.001359, 0.017986, + 0.401312, 1.000000, 0.998341, 0.982014, 0.598688, 0.000000, 0.001659, + 0.017986, 0.401312, 1.000000, 0.998887, 0.982014, 0.598688, 0.000000, + 0.001113}, + sd::DataType::FLOAT32); + + auto expOutput = NDArray('f', {2, 2, 2, 2, 2, 2}, sd::DataType::FLOAT32); + expOutput.assign(exp); + + sd::ops::softmax op; + auto results = op.evaluate({&input}, {}, {3}, {}); + auto z = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expOutput.isSameShape(z)); + ASSERT_TRUE(expOutput.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, Reverse_1) { + float inBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}; + float expBuff[] = {24., 23., 22., 21., 20., 19., 18., 17., 16., 15., 14., 13., + 12., 11., 10., 9., 8., 7., 6., 5., 4., 3., 2., 1.}; + Nd4jLong shapeInfo[] = {3, 2, 3, 4, 12, 4, 1, 0, 1, 99}; + ArrayOptions::setDataType(shapeInfo, sd::DataType::FLOAT32); - float inBuff[] = { 1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24 }; - float expBuff[] = { 24., 23., 22., 21., 20., 19., 18., 17., 16., 15., 14., 13., 12., 11., 10., 9., 8., 7., 6., 5., 4., 3., 2., 1. }; - Nd4jLong shapeInfo[] = { 3, 2, 3, 4, 12, 4, 1, 0, 1, 99 }; - ArrayOptions::setDataType(shapeInfo, sd::DataType::FLOAT32); - - NDArray input(inBuff, shapeInfo); - NDArray expected(expBuff, shapeInfo); - NDArray output(shapeInfo); + NDArray input(inBuff, shapeInfo); + NDArray expected(expBuff, shapeInfo); + NDArray output(shapeInfo); - sd::ops::reverse op; - auto results = op.evaluate({ &input }, {}, { 0,1,2 }); + sd::ops::reverse op; + auto results = op.evaluate({&input}, {}, {0, 1, 2}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); - - ASSERT_TRUE(expected.isSameShapeStrict(*result)); - ASSERT_TRUE(expected.equalsTo(result)); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto result = results.at(0); + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, Reverse_2) { + float inBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}; + float expBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}; + Nd4jLong shapeInfo[] = {3, 2, 3, 4, 12, 4, 1, 0, 1, 99}; + ArrayOptions::setDataType(shapeInfo, sd::DataType::FLOAT32); - float inBuff[] = { 1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24 }; - float expBuff[] = { 1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24 }; - Nd4jLong shapeInfo[] = { 3, 2, 3, 4, 12, 4, 1, 0, 1, 99 }; - ArrayOptions::setDataType(shapeInfo, sd::DataType::FLOAT32); - - NDArray input(inBuff, shapeInfo); - NDArray expected(expBuff, shapeInfo); - NDArray output(shapeInfo); - - sd::ops::reverse op; - auto results = op.evaluate({ &input }, {}, {}, {}, {}, true); + NDArray input(inBuff, shapeInfo); + NDArray expected(expBuff, shapeInfo); + NDArray output(shapeInfo); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::reverse op; + auto results = op.evaluate({&input}, {}, {}, {}, {}, true); - auto result = results.at(0); - - ASSERT_TRUE(expected.isSameShapeStrict(input)); - ASSERT_TRUE(expected.equalsTo(&input)); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto result = results.at(0); + ASSERT_TRUE(expected.isSameShapeStrict(input)); + ASSERT_TRUE(expected.equalsTo(&input)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, Reverse_3) { + float inBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}; + float expBuff[] = {12., 11., 10., 9., 8., 7., 6., 5., + 4., 3., 2., 1., 24., 23., 22., 21., + 20., 19., 18., 17., 16., 15., 14., 13.}; + Nd4jLong shapeInfo[] = {3, 2, 3, 4, 12, 4, 1, 0, 1, 99}; + ArrayOptions::setDataType(shapeInfo, sd::DataType::FLOAT32); - float inBuff[] = { 1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24 }; - float expBuff[] = { 12., 11., 10., 9., 8., 7., 6., 5., 4., 3., 2., 1., 24., 23., 22., 21., 20., 19., 18., 17., 16., 15., 14., 13. }; - Nd4jLong shapeInfo[] = { 3, 2, 3, 4, 12, 4, 1, 0, 1, 99 }; - ArrayOptions::setDataType(shapeInfo, sd::DataType::FLOAT32); - - NDArray input(inBuff, shapeInfo); - NDArray expected(expBuff, shapeInfo); - NDArray output(shapeInfo); - - sd::ops::reverse op; - auto results = op.evaluate({ &input }, {}, { 1,2 }); + NDArray input(inBuff, shapeInfo); + NDArray expected(expBuff, shapeInfo); + NDArray output(shapeInfo); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::reverse op; + auto results = op.evaluate({&input}, {}, {1, 2}); - auto result = results.at(0); - // result->printBuffer(); - - ASSERT_TRUE(expected.isSameShapeStrict(*result)); - ASSERT_TRUE(expected.equalsTo(result)); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto result = results.at(0); + // result->printBuffer(); + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, Reverse_4) { + float inBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}; + float expBuff[] = { + 16, 15, 14, 13, 20, 19, 18, 17, 24, 23, 22, 21, + 4, 3, 2, 1, 8, 7, 6, 5, 12, 11, 10, 9, + }; + Nd4jLong shapeInfo[] = {3, 2, 3, 4, 12, 4, 1, 0, 1, 99}; + ArrayOptions::setDataType(shapeInfo, sd::DataType::FLOAT32); - float inBuff[] = { 1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24 }; - float expBuff[] = { 16,15,14,13,20,19,18,17,24,23,22,21,4,3,2,1,8,7,6,5,12,11,10,9, }; - Nd4jLong shapeInfo[] = { 3, 2, 3, 4, 12, 4, 1, 0, 1, 99 }; - ArrayOptions::setDataType(shapeInfo, sd::DataType::FLOAT32); - - NDArray input(inBuff, shapeInfo); - NDArray expected(expBuff, shapeInfo); - NDArray output(shapeInfo); - - sd::ops::reverse op; - auto results = op.evaluate({ &input }, {}, { 0,2 }); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + NDArray input(inBuff, shapeInfo); + NDArray expected(expBuff, shapeInfo); + NDArray output(shapeInfo); - auto result = results.at(0); - // result->printBuffer(); + sd::ops::reverse op; + auto results = op.evaluate({&input}, {}, {0, 2}); - ASSERT_TRUE(expected.isSameShapeStrict(*result)); - ASSERT_TRUE(expected.equalsTo(result)); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto result = results.at(0); + // result->printBuffer(); + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, Reverse_5) { + float inBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}; + float expBuff[] = {21., 22., 23., 24., 17., 18., 19., 20., 13., 14., 15., 16., + 9., 10., 11., 12., 5., 6., 7., 8., 1., 2., 3., 4.}; + Nd4jLong shapeInfo[] = {3, 2, 3, 4, 12, 4, 1, 0, 1, 99}; + ArrayOptions::setDataType(shapeInfo, sd::DataType::FLOAT32); - float inBuff[] = { 1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24 }; - float expBuff[] = { 21., 22., 23., 24., 17., 18., 19., 20., 13., 14., 15., 16., 9., 10., 11., 12., 5., 6., 7., 8., 1., 2., 3., 4. }; - Nd4jLong shapeInfo[] = { 3, 2, 3, 4, 12, 4, 1, 0, 1, 99 }; - ArrayOptions::setDataType(shapeInfo, sd::DataType::FLOAT32); + NDArray input(inBuff, shapeInfo); + NDArray expected(expBuff, shapeInfo); + NDArray output(shapeInfo); - NDArray input(inBuff, shapeInfo); - NDArray expected(expBuff, shapeInfo); - NDArray output(shapeInfo); + sd::ops::reverse op; + auto results = op.evaluate({&input}, {}, {0, 1}); - sd::ops::reverse op; - auto results = op.evaluate({ &input }, {}, { 0,1 }); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); - - ASSERT_TRUE(expected.isSameShapeStrict(*result)); - ASSERT_TRUE(expected.equalsTo(result)); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto result = results.at(0); + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, Reverse_6) { + float inBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}; + float expBuff[] = {4., 3., 2., 1., 8., 7., 6., 5., + 12., 11., 10., 9., 16., 15., 14., 13., + 20., 19., 18., 17., 24., 23., 22., 21.}; + Nd4jLong shapeInfo[] = {3, 2, 3, 4, 12, 4, 1, 0, 1, 99}; + ArrayOptions::setDataType(shapeInfo, sd::DataType::FLOAT32); - float inBuff[] = { 1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24 }; - float expBuff[] = { 4., 3., 2., 1., 8., 7., 6., 5., 12., 11., 10., 9., 16., 15., 14., 13., 20., 19., 18., 17., 24., 23., 22., 21. }; - Nd4jLong shapeInfo[] = { 3, 2, 3, 4, 12, 4, 1, 0, 1, 99 }; - ArrayOptions::setDataType(shapeInfo, sd::DataType::FLOAT32); - - NDArray input(inBuff, shapeInfo); - NDArray expected(expBuff, shapeInfo); - NDArray output(shapeInfo); + NDArray input(inBuff, shapeInfo); + NDArray expected(expBuff, shapeInfo); + NDArray output(shapeInfo); - sd::ops::reverse op; - auto results = op.evaluate({ &input }, {}, { 2 }, {}, {}, true); + sd::ops::reverse op; + auto results = op.evaluate({&input}, {}, {2}, {}, {}, true); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); - // result->printBuffer(); - - ASSERT_TRUE(expected.isSameShapeStrict(input)); - ASSERT_TRUE(expected.equalsTo(&input)); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto result = results.at(0); + // result->printBuffer(); + ASSERT_TRUE(expected.isSameShapeStrict(input)); + ASSERT_TRUE(expected.equalsTo(&input)); } - //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, Reverse_7) { + float inBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}; + float expBuff[] = {9., 10., 11., 12., 5., 6., 7., 8., + 1., 2., 3., 4., 21., 22., 23., 24., + 17., 18., 19., 20., 13., 14., 15., 16.}; + Nd4jLong shapeInfo[] = {3, 2, 3, 4, 12, 4, 1, 0, 1, 99}; + ArrayOptions::setDataType(shapeInfo, sd::DataType::FLOAT32); - float inBuff[] = { 1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24 }; - float expBuff[] = { 9., 10., 11., 12., 5., 6., 7., 8., 1., 2., 3., 4., 21., 22., 23., 24., 17., 18., 19., 20., 13., 14., 15., 16. }; - Nd4jLong shapeInfo[] = { 3, 2, 3, 4, 12, 4, 1, 0, 1, 99 }; - ArrayOptions::setDataType(shapeInfo, sd::DataType::FLOAT32); - - NDArray input(inBuff, shapeInfo); - NDArray expected(expBuff, shapeInfo); - NDArray output(shapeInfo); + NDArray input(inBuff, shapeInfo); + NDArray expected(expBuff, shapeInfo); + NDArray output(shapeInfo); - sd::ops::reverse op; - auto results = op.evaluate({ &input }, {}, { 1 }); + sd::ops::reverse op; + auto results = op.evaluate({&input}, {}, {1}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); - //expected.printIndexedBuffer("E"); - //result->printIndexedBuffer("R"); - - ASSERT_TRUE(expected.isSameShapeStrict(*result)); - ASSERT_TRUE(expected.equalsTo(result)); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto result = results.at(0); + // expected.printIndexedBuffer("E"); + // result->printIndexedBuffer("R"); + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } - - ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, Reverse_8) { + float inBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}; + float expBuff[] = {12., 11., 10., 9., 8., 7., 6., 5., + 4., 3., 2., 1., 24., 23., 22., 21., + 20., 19., 18., 17., 16., 15., 14., 13.}; + Nd4jLong shapeInfo[] = {3, 2, 3, 4, 12, 4, 1, 0, 1, 99}; + ArrayOptions::setDataType(shapeInfo, sd::DataType::FLOAT32); - float inBuff[] = { 1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24 }; - float expBuff[] = { 12., 11., 10., 9., 8., 7., 6., 5., 4., 3., 2., 1., 24., 23., 22., 21., 20., 19., 18., 17., 16., 15., 14., 13. }; - Nd4jLong shapeInfo[] = { 3, 2, 3, 4, 12, 4, 1, 0, 1, 99 }; - ArrayOptions::setDataType(shapeInfo, sd::DataType::FLOAT32); - - NDArray input(inBuff, shapeInfo); - NDArray expected(expBuff, shapeInfo); - NDArray output(shapeInfo); - - sd::ops::reverse op; - auto results = op.evaluate({ &input }, {}, { 2,1 }); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + NDArray input(inBuff, shapeInfo); + NDArray expected(expBuff, shapeInfo); + NDArray output(shapeInfo); - auto result = results.at(0); - // result->printBuffer(); + sd::ops::reverse op; + auto results = op.evaluate({&input}, {}, {2, 1}); - ASSERT_TRUE(expected.isSameShapeStrict(*result)); - ASSERT_TRUE(expected.equalsTo(result)); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto result = results.at(0); + // result->printBuffer(); + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, Reverse_9) { + float inBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}; + float expBuff[] = {13., 14., 15., 16., 17., 18., 19., 20., + 21., 22., 23., 24., 1., 2., 3., 4., + 5., 6., 7., 8., 9., 10., 11., 12.}; + Nd4jLong shapeInfo[] = {3, 2, 3, 4, 12, 4, 1, 0, 1, 99}; + ArrayOptions::setDataType(shapeInfo, sd::DataType::FLOAT32); - float inBuff[] = { 1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24 }; - float expBuff[] = { 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12. }; - Nd4jLong shapeInfo[] = { 3, 2, 3, 4, 12, 4, 1, 0, 1, 99 }; - ArrayOptions::setDataType(shapeInfo, sd::DataType::FLOAT32); + NDArray input(inBuff, shapeInfo); + NDArray expected(expBuff, shapeInfo); + NDArray output(shapeInfo); - NDArray input(inBuff, shapeInfo); - NDArray expected(expBuff, shapeInfo); - NDArray output(shapeInfo); + sd::ops::reverse op; + auto results = op.evaluate({&input}, {}, {0}); - sd::ops::reverse op; - auto results = op.evaluate({ &input }, {}, { 0 }); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); - - ASSERT_TRUE(expected.isSameShapeStrict(*result)); - ASSERT_TRUE(expected.equalsTo(result)); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto result = results.at(0); + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } TEST_F(DeclarableOpsTests1, Reverse_10) { - auto x = NDArrayFactory::create('c', { 4, 3 }, { 1.5375735, 0.1592365, 0.09966054, 0.677872, 1.144433, -1.0355669, 0.48456487, -0.67863184, 0.85020787, 0.13950661, 0.20998026, -1.1660044 }); - auto i = NDArrayFactory::create('c', { 1 }, { -1 }); - auto e = NDArrayFactory::create('c', { 4, 3 }, { 0.09966054, 0.1592365, 1.5375735, -1.0355669, 1.144433, 0.677872,0.85020787, -0.67863184, 0.48456487, -1.1660044, 0.20998026, 0.13950661 }); - - sd::ops::reverse op; - auto result = op.evaluate({ &x, &i }, {}, {}, {}); + auto x = NDArrayFactory::create( + 'c', {4, 3}, + {1.5375735, 0.1592365, 0.09966054, 0.677872, 1.144433, -1.0355669, + 0.48456487, -0.67863184, 0.85020787, 0.13950661, 0.20998026, + -1.1660044}); + auto i = NDArrayFactory::create('c', {1}, {-1}); + auto e = NDArrayFactory::create( + 'c', {4, 3}, + {0.09966054, 0.1592365, 1.5375735, -1.0355669, 1.144433, 0.677872, + 0.85020787, -0.67863184, 0.48456487, -1.1660044, 0.20998026, + 0.13950661}); - auto z = result.at(0); - - ASSERT_TRUE(e.isSameShape(z)); - ASSERT_TRUE(e.equalsTo(z)); + sd::ops::reverse op; + auto result = op.evaluate({&x, &i}, {}, {}, {}); + auto z = result.at(0); + ASSERT_TRUE(e.isSameShape(z)); + ASSERT_TRUE(e.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, Reverse_11) { + auto input = NDArrayFactory::create('c', {2, 3, 4}); + auto expected = NDArrayFactory::create( + 'c', {2, 3, 4}, + {24.f, 23.f, 22.f, 21.f, 20.f, 19.f, 18.f, 17.f, 16.f, 15.f, 14.f, 13.f, + 12.f, 11.f, 10.f, 9.f, 8.f, 7.f, 6.f, 5.f, 4.f, 3.f, 2.f, 1.f}); + input.linspace(1); + sd::ops::reverse op; + auto results = op.evaluate({&input}, {}, {0, 1, 2}); - auto input = NDArrayFactory::create('c', { 2,3,4 }); - auto expected = NDArrayFactory::create('c', { 2,3,4 }, { 24.f, 23.f, 22.f, 21.f, 20.f, 19.f, 18.f, 17.f, 16.f, - 15.f, 14.f, 13.f, 12.f, 11.f, 10.f, 9.f, 8.f, 7.f, - 6.f, 5.f, 4.f, 3.f, 2.f, 1.f }); - - input.linspace(1); - sd::ops::reverse op; - auto results = op.evaluate({ &input }, {}, { 0, 1, 2 }); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); - - ASSERT_TRUE(expected.isSameShapeStrict(*result)); - ASSERT_TRUE(expected.equalsTo(result)); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto result = results.at(0); + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, Reverse_12) { + auto input = NDArrayFactory::create({0.f, 1.f, 2.f, 3.f, 4.f}); + auto expected = NDArrayFactory::create({4.f, 3.f, 2.f, 1.f, 0.f}); + // input.linspace(1); + sd::ops::reverse op; + auto results = op.evaluate({&input}, {}, {0}); - auto input = NDArrayFactory::create({ 0.f, 1.f, 2.f, 3.f, 4.f }); - auto expected = NDArrayFactory::create({ 4.f, 3.f, 2.f, 1.f, 0.f }); - - //input.linspace(1); - sd::ops::reverse op; - auto results = op.evaluate({ &input }, {}, { 0 }); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); - //result->printIndexedBuffer("Result reverse"); - //expected.printIndexedBuffer("Expected reverse"); - ASSERT_TRUE(expected.isSameShapeStrict(*result)); - ASSERT_TRUE(expected.equalsTo(result)); - + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto result = results.at(0); + // result->printIndexedBuffer("Result reverse"); + // expected.printIndexedBuffer("Expected reverse"); + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, Reverse_13) { + auto input = NDArrayFactory::create({0.f, 1.f, 2.f, 3.f, 4.f}); + auto expected = NDArrayFactory::create({4.f, 3.f, 2.f, 1.f, 0.f}); + // input.linspace(1); + sd::ops::reverse op; + auto results = op.evaluate({&input}, {}, {-1}); - auto input = NDArrayFactory::create({ 0.f, 1.f, 2.f, 3.f, 4.f }); - auto expected = NDArrayFactory::create({ 4.f, 3.f, 2.f, 1.f, 0.f }); - - //input.linspace(1); - sd::ops::reverse op; - auto results = op.evaluate({ &input }, {}, { -1 }); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); - - ASSERT_TRUE(expected.isSameShapeStrict(*result)); - ASSERT_TRUE(expected.equalsTo(result)); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto result = results.at(0); + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, Reverse_14) { + auto input = NDArrayFactory::create({0.f, 1.f, 2.f, 3.f, 4.f}); + auto expected = NDArrayFactory::create({0.f, 1.f, 2.f, 3.f, 4.f}); + // input.linspace(1); + sd::ops::reverse op; + auto results = op.evaluate({&input}, {}, {}, {}); - auto input = NDArrayFactory::create({ 0.f, 1.f, 2.f, 3.f, 4.f }); - auto expected = NDArrayFactory::create({ 0.f, 1.f, 2.f, 3.f, 4.f }); - - //input.linspace(1); - sd::ops::reverse op; - auto results = op.evaluate({ &input }, {}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); - - ASSERT_TRUE(expected.isSameShapeStrict(*result)); - ASSERT_TRUE(expected.equalsTo(result)); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto result = results.at(0); + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } TEST_F(DeclarableOpsTests1, Test_Expose_1) { - auto input0 = NDArrayFactory::create('c', { 2, 3 }, { 1, 2, 3, 6, 5, 4 }); - auto input1 = NDArrayFactory::create('c', { 2, 3 }, { 3, 2, 1, 4, 5, 6 }); - - sd::ops::expose op; + auto input0 = NDArrayFactory::create('c', {2, 3}, {1, 2, 3, 6, 5, 4}); + auto input1 = NDArrayFactory::create('c', {2, 3}, {3, 2, 1, 4, 5, 6}); - auto result = op.evaluate({ &input0, &input1 }); + sd::ops::expose op; - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto result = op.evaluate({&input0, &input1}); - auto z0 = result.at(0); - auto z1 = result.at(1); - - ASSERT_TRUE(input0.equalsTo(z0)); - ASSERT_TRUE(input1.equalsTo(z1)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z0 = result.at(0); + auto z1 = result.at(1); + ASSERT_TRUE(input0.equalsTo(z0)); + ASSERT_TRUE(input1.equalsTo(z1)); } TEST_F(DeclarableOpsTests1, Test_Expose_2) { - auto list = new NDArrayList(0, true); - - auto var = new Variable(nullptr, "arraylist", -1, 0); - var->setNDArrayList(list); - - VariableSpace variableSpace; - variableSpace.putVariable(-1, var); - variableSpace.trackList(list); + auto list = std::make_shared(0, true); + auto var = std::make_shared(list, "arraylist", -1, 0); - Context block(1, &variableSpace); - block.pickInput(-1); + VariableSpace variableSpace; + variableSpace.putVariable(-1, var); - sd::ops::expose op; - auto result = op.execute(&block); + Context block(1, &variableSpace); + block.pickInput(-1); - ASSERT_EQ(ND4J_STATUS_OK, result); - ASSERT_TRUE(variableSpace.hasVariable(1)); + sd::ops::expose op; + auto result = op.execute(&block); - auto var1 = variableSpace.getVariable(1); + ASSERT_EQ(ND4J_STATUS_OK, result); + ASSERT_TRUE(variableSpace.hasVariable(1)); - ASSERT_EQ(var->variableType(), var1->variableType()); + auto var1 = variableSpace.getVariable(1); - auto list1 = var1->getNDArrayList(); + ASSERT_EQ(var->variableType(), var1->variableType()); - ASSERT_TRUE(list == list1); + auto list1 = var1->getNDArrayList(); + ASSERT_TRUE(list.get() == list1.get()); } TEST_F(DeclarableOpsTests1, Test_Release) { - auto x = NDArrayFactory::create('c', { 8, 8 }); - // x.printShapeInfo("x shape"); + auto x = NDArrayFactory::create('c', {8, 8}); + // x.printShapeInfo("x shape"); } diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp index 2ffc2c22d39a..f2e7196a8c22 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp @@ -15,385 +15,389 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - // // Created by raver on 8/4/2018. // -#include "testlayers.h" -#include #include -#include #include +#include +#include +#include "testlayers.h" using namespace sd; - class DeclarableOpsTests10 : public testing::Test { -public: - - DeclarableOpsTests10() { - printf("\n"); - fflush(stdout); - } + public: + DeclarableOpsTests10() { + printf("\n"); + fflush(stdout); + } }; template class TypedDeclarableOpsTests10 : public testing::Test { -public: - - TypedDeclarableOpsTests10() { - printf("\n"); - fflush(stdout); - } + public: + TypedDeclarableOpsTests10() { + printf("\n"); + fflush(stdout); + } }; typedef ::testing::Types TestingTypes; -TYPED_TEST_CASE(TypedDeclarableOpsTests10, TestingTypes); +TYPED_TEST_SUITE(TypedDeclarableOpsTests10, TestingTypes); TEST_F(DeclarableOpsTests10, Test_ArgMax_1) { - auto x = NDArrayFactory::create('c', {3, 3}); - auto e = NDArrayFactory::create(8); + auto x = NDArrayFactory::create('c', {3, 3}); + auto e = NDArrayFactory::create(8); - x.linspace(1.0); + x.linspace(1.0); + sd::ops::argmax op; + auto result = op.evaluate({&x}); + ASSERT_EQ(Status::OK(), result.status()); - sd::ops::argmax op; - auto result = op.evaluate({&x}); - ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); - - auto z = *result.at(0); - - ASSERT_EQ(e, z); + ASSERT_EQ(e, z); } TEST_F(DeclarableOpsTests10, Test_ArgMax_2) { - auto x = NDArrayFactory::create('c', {3, 3}); - auto y = NDArrayFactory::create('c', {1}, {1}); - auto e = NDArrayFactory::create('c', {3}, {2, 2, 2}); + auto x = NDArrayFactory::create('c', {3, 3}); + auto y = NDArrayFactory::create('c', {1}, {1}); + auto e = NDArrayFactory::create('c', {3}, {2, 2, 2}); - x.linspace(1.0); + x.linspace(1.0); - sd::ops::argmax op; - auto result = op.evaluate({&x, &y}); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::argmax op; + auto result = op.evaluate({&x, &y}); + ASSERT_EQ(Status::OK(), result.status()); - auto z = *result.at(0); + auto z = result.at(0); - //z.printIndexedBuffer("z"); - //z.printShapeInfo("z shape"); + // z.printIndexedBuffer("z"); + // z.printShapeInfo("z shape"); - ASSERT_EQ(e, z); + ASSERT_EQ(e, z); } TEST_F(DeclarableOpsTests10, Test_And_1) { - auto x = NDArrayFactory::create('c', {4}, {1, 1, 0, 1}); - auto y = NDArrayFactory::create('c', {4}, {0, 0, 0, 1}); - auto e = NDArrayFactory::create('c', {4}, {0, 0, 0, 1}); + auto x = NDArrayFactory::create('c', {4}, {1, 1, 0, 1}); + auto y = NDArrayFactory::create('c', {4}, {0, 0, 0, 1}); + auto e = NDArrayFactory::create('c', {4}, {0, 0, 0, 1}); - sd::ops::boolean_and op; - auto result = op.evaluate({&x, &y}); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::boolean_and op; + auto result = op.evaluate({&x, &y}); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_EQ(e, *result.at(0)); + ASSERT_EQ(e, result.at(0)); } TEST_F(DeclarableOpsTests10, Test_Or_1) { - auto x = NDArrayFactory::create('c', {4}, {1, 1, 0, 1}); - auto y = NDArrayFactory::create('c', {4}, {0, 0, 0, 1}); - auto e = NDArrayFactory::create('c', {4}, {1, 1, 0, 1}); + auto x = NDArrayFactory::create('c', {4}, {1, 1, 0, 1}); + auto y = NDArrayFactory::create('c', {4}, {0, 0, 0, 1}); + auto e = NDArrayFactory::create('c', {4}, {1, 1, 0, 1}); - sd::ops::boolean_or op; - auto result = op.evaluate({&x, &y}); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::boolean_or op; + auto result = op.evaluate({&x, &y}); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_EQ(e, *result.at(0)); + ASSERT_EQ(e, result.at(0)); } TEST_F(DeclarableOpsTests10, Test_Not_1) { - auto x = NDArrayFactory::create('c', {4}, {true, true, false, true}); - auto y = NDArrayFactory::create('c', {4}, {false, false, false, true}); -// auto e = NDArrayFactory::create('c', {4}, {1, 1, 1, 0}); - auto e = NDArrayFactory::create('c', {4}, {false, false, true, false}); + auto x = NDArrayFactory::create('c', {4}, {true, true, false, true}); + auto y = NDArrayFactory::create('c', {4}, {false, false, false, true}); + // auto e = NDArrayFactory::create('c', {4}, {1, 1, 1, 0}); + auto e = NDArrayFactory::create('c', {4}, {false, false, true, false}); - sd::ops::boolean_not op; - auto result = op.evaluate({&x, &y}); - ASSERT_EQ(Status::OK(), result.status()); - auto res = result.at(0); + sd::ops::boolean_not op; + auto result = op.evaluate({&x, &y}); + ASSERT_EQ(Status::OK(), result.status()); + auto res = result.at(0); - ASSERT_TRUE(e.equalsTo(res)); + ASSERT_TRUE(e.equalsTo(res)); } TEST_F(DeclarableOpsTests10, Test_Size_at_1) { - auto x = NDArrayFactory::create('c', {10, 20, 30}); - auto e = NDArrayFactory::create(20); - - sd::ops::size_at op; - auto result = op.evaluate({&x}, {1}); - ASSERT_EQ(Status::OK(), result.status()); + auto x = NDArrayFactory::create('c', {10, 20, 30}); + auto e = NDArrayFactory::create(20); - ASSERT_EQ(e, *result.at(0)); + sd::ops::size_at op; + auto result = op.evaluate({&x}, {1}); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_EQ(e, result.at(0)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, MirrorPad_SGO_Test_1) { + auto in = NDArrayFactory::create({1., 2., 3., 4., 5.}); + // auto pad('c', {1, 2}, {1., 1.});// = Nd4j.create(new double[]{1, 1}, new + // long[]{1, 2}); + auto pad = NDArrayFactory::create('c', {1, 2}, {1, 1}); + // auto value(10.0); - auto in = NDArrayFactory::create({1., 2., 3., 4., 5.}); -// auto pad('c', {1, 2}, {1., 1.});// = Nd4j.create(new double[]{1, 1}, new long[]{1, 2}); - auto pad = NDArrayFactory::create('c', {1, 2}, {1, 1}); -// auto value(10.0); + auto exp = NDArrayFactory::create({2., 1., 2., 3., 4., 5., 4.}); - auto exp = NDArrayFactory::create({2., 1., 2., 3., 4., 5., 4.}); + sd::ops::mirror_pad op; - sd::ops::mirror_pad op; + auto res = op.evaluate({&in, &pad}, {10.0}, {0}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); - auto res = op.evaluate({&in, &pad}, {10.0}, {0}); - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - - ASSERT_TRUE(exp.equalsTo(res.at(0))); + ASSERT_TRUE(exp.equalsTo(res.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, Unique_SGO_Test_1) { - auto input = NDArrayFactory::create({3., 4., 3., 1., 3., 0., 2., 4., 2., 4.}); - auto expIdx = NDArrayFactory::create({0, 1, 0, 2, 0, 3, 4, 1, 4, 1}); - auto exp = NDArrayFactory::create({3., 4., 1., 0., 2.}); - - sd::ops::unique op; - auto res = op.evaluate({&input}, {}, {}); - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - auto res1 = res.at(0); - auto res2 = res.at(1); + auto input = + NDArrayFactory::create({3., 4., 3., 1., 3., 0., 2., 4., 2., 4.}); + auto expIdx = + NDArrayFactory::create({0, 1, 0, 2, 0, 3, 4, 1, 4, 1}); + auto exp = NDArrayFactory::create({3., 4., 1., 0., 2.}); - ASSERT_TRUE(exp.equalsTo(res1)); - ASSERT_TRUE(expIdx.equalsTo(res2)); + sd::ops::unique op; + auto res = op.evaluate({&input}, {}, {}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto res1 = res.at(0); + auto res2 = res.at(1); + ASSERT_TRUE(exp.equalsTo(res1)); + ASSERT_TRUE(expIdx.equalsTo(res2)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, Where_SGO_Test_1) { - auto input = NDArrayFactory::create('c', {3, 3}, {true, false, false, true, true, false, true, true, true}); - //auto expIdx({0., 1., 0., 2., 0., 3., 4., 1., 4., 1.}); - auto exp = NDArrayFactory::create('c', {6, 2}, {0LL, 0LL, 1LL, 0LL, 1LL, 1LL, 2LL, 0LL, 2LL, 1LL, 2LL, 2LL}); - - sd::ops::Where op; - auto res = op.evaluate({&input}, {}, {}); - ASSERT_TRUE(res.status() == ND4J_STATUS_OK); - auto resA = res.at(0); + auto input = NDArrayFactory::create( + 'c', {3, 3}, {true, false, false, true, true, false, true, true, true}); + // auto expIdx({0., 1., 0., 2., 0., 3., 4., 1., 4., 1.}); + auto exp = NDArrayFactory::create( + 'c', {6, 2}, + {0LL, 0LL, 1LL, 0LL, 1LL, 1LL, 2LL, 0LL, 2LL, 1LL, 2LL, 2LL}); - ASSERT_TRUE(exp.isSameShape(resA)); - ASSERT_TRUE(exp.equalsTo(resA)); -// ASSERT_TRUE(expIdx.equalsTo(res.at(1))); + sd::ops::Where op; + auto res = op.evaluate({&input}, {}, {}); + ASSERT_TRUE(res.status() == ND4J_STATUS_OK); + auto resA = res.at(0); + ASSERT_TRUE(exp.isSameShape(resA)); + ASSERT_TRUE(exp.equalsTo(resA)); + // ASSERT_TRUE(expIdx.equalsTo(res.at(1))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, Where_SGO_Test_02) { - auto input = NDArrayFactory::create('c', {2, 2, 2}, {true, false, false, true, true, true, true, false}); - //auto expIdx({0., 1., 0., 2., 0., 3., 4., 1., 4., 1.}); - auto exp = NDArrayFactory::create('c', {5, 3}, {0LL, 0LL, 0LL, 0LL, 1LL, 1LL, 1LL, 0LL, 0LL, 1LL, 0LL, 1LL, 1LL, 1LL, 0LL}); + auto input = NDArrayFactory::create( + 'c', {2, 2, 2}, {true, false, false, true, true, true, true, false}); + // auto expIdx({0., 1., 0., 2., 0., 3., 4., 1., 4., 1.}); + auto exp = + NDArrayFactory::create('c', {5, 3}, + {0LL, 0LL, 0LL, 0LL, 1LL, 1LL, 1LL, 0LL, + 0LL, 1LL, 0LL, 1LL, 1LL, 1LL, 0LL}); - sd::ops::Where op; - auto res = op.evaluate({&input}, {}, {}); - ASSERT_TRUE(res.status() == ND4J_STATUS_OK); - auto resA = res.at(0); - - ASSERT_TRUE(exp.equalsTo(resA)); - ASSERT_TRUE(exp.isSameShape(resA)); -// ASSERT_TRUE(expIdx.equalsTo(res.at(1))); + sd::ops::Where op; + auto res = op.evaluate({&input}, {}, {}); + ASSERT_TRUE(res.status() == ND4J_STATUS_OK); + auto resA = res.at(0); + ASSERT_TRUE(exp.equalsTo(resA)); + ASSERT_TRUE(exp.isSameShape(resA)); + // ASSERT_TRUE(expIdx.equalsTo(res.at(1))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, WhereNP_SGO_Test_1) { - auto cond3d = NDArrayFactory::create('c', {2, 2, 2}, {true, false, false, true, true, true, true, false}); -// auto expIdx({0., 1., 0., 2., 0., 3., 4., 1., 4., 1.}); - auto exp1 = NDArrayFactory::create({0, 0, 1, 1, 1}); - auto exp2 = NDArrayFactory::create({0, 1, 0, 0, 1}); - auto exp3 = NDArrayFactory::create({0, 1, 0, 1, 0}); - sd::ops::where_np op; - auto res = op.evaluate({&cond3d}, {}, {}); - ASSERT_TRUE(res.size() == 3); - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - auto res1 = res.at(0); - auto res2 = res.at(1); - auto res3 = res.at(2); -// res1->printShapeInfo("Res1 shape"); res1->printBuffer("Res1"); -// res2->printShapeInfo("Res2 shape"); res2->printBuffer("Res2"); -// res3->printShapeInfo("Res3 shape"); res3->printBuffer("Res3"); - ASSERT_TRUE(exp1.equalsTo(res1)); - ASSERT_TRUE(exp2.equalsTo(res2)); - ASSERT_TRUE(exp3.equalsTo(res3)); - //ASSERT_TRUE(expIdx.equalsTo(res.at(1))); - + auto cond3d = NDArrayFactory::create( + 'c', {2, 2, 2}, {true, false, false, true, true, true, true, false}); + // auto expIdx({0., 1., 0., 2., 0., 3., 4., 1., 4., 1.}); + auto exp1 = NDArrayFactory::create({0, 0, 1, 1, 1}); + auto exp2 = NDArrayFactory::create({0, 1, 0, 0, 1}); + auto exp3 = NDArrayFactory::create({0, 1, 0, 1, 0}); + sd::ops::where_np op; + auto res = op.evaluate({&cond3d}, {}, {}); + ASSERT_TRUE(res.size() == 3); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto res1 = res.at(0); + auto res2 = res.at(1); + auto res3 = res.at(2); + // res1->printShapeInfo("Res1 shape"); res1->printBuffer("Res1"); + // res2->printShapeInfo("Res2 shape"); res2->printBuffer("Res2"); + // res3->printShapeInfo("Res3 shape"); res3->printBuffer("Res3"); + ASSERT_TRUE(exp1.equalsTo(res1)); + ASSERT_TRUE(exp2.equalsTo(res2)); + ASSERT_TRUE(exp3.equalsTo(res3)); + // ASSERT_TRUE(expIdx.equalsTo(res.at(1))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, WhereNP_SGO_Test_2) { - auto cond2d = NDArrayFactory::create('c', {3, 5}, {true, true, false, false, true, true, true, - true, true, true, false, true, true, true, true}); -// auto expIdx({0, 1, 0, 2, 0, 3, 4, 1, 4, 1}); - auto exp1 = NDArrayFactory::create({0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2}); - auto exp2 = NDArrayFactory::create({0, 1, 4, 0, 1, 2, 3, 4, 1, 2, 3, 4}); - sd::ops::where_np op; - auto res = op.evaluate({&cond2d}, {}, {}); - ASSERT_TRUE(res.size() == 2); - ASSERT_TRUE(res.status() == ND4J_STATUS_OK); - ASSERT_TRUE(exp1.equalsTo(res.at(0))); - ASSERT_TRUE(exp2.equalsTo(res.at(1))); - //ASSERT_TRUE(expIdx.equalsTo(res.at(1))); - + auto cond2d = NDArrayFactory::create( + 'c', {3, 5}, + {true, true, false, false, true, true, true, true, true, true, false, + true, true, true, true}); + // auto expIdx({0, 1, 0, 2, 0, 3, 4, 1, 4, 1}); + auto exp1 = + NDArrayFactory::create({0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2}); + auto exp2 = + NDArrayFactory::create({0, 1, 4, 0, 1, 2, 3, 4, 1, 2, 3, 4}); + sd::ops::where_np op; + auto res = op.evaluate({&cond2d}, {}, {}); + ASSERT_TRUE(res.size() == 2); + ASSERT_TRUE(res.status() == ND4J_STATUS_OK); + ASSERT_TRUE(exp1.equalsTo(res.at(0))); + ASSERT_TRUE(exp2.equalsTo(res.at(1))); + // ASSERT_TRUE(expIdx.equalsTo(res.at(1))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, Where_SGO_Test_2) { - auto input = NDArrayFactory::create({true, false, true, true, true}); - //auto expIdx({0., 1., 0., 2., 0., 3., 4., 1., 4., 1.}); - auto exp = NDArrayFactory::create('c', {4,1}, {0, 2, 3, 4}); - - sd::ops::Where op; - auto res = op.evaluate({&input}); - ASSERT_TRUE(res.status() == ND4J_STATUS_OK); - auto resA = res.at(0); -// resA->printIndexedBuffer("Result A"); -// resA->printShapeInfo("ShapeA"); - ASSERT_TRUE(exp.equalsTo(resA)); - ASSERT_TRUE(exp.isSameShape(resA)); -// ASSERT_TRUE(expIdx.equalsTo(res.at(1))); + auto input = NDArrayFactory::create({true, false, true, true, true}); + // auto expIdx({0., 1., 0., 2., 0., 3., 4., 1., 4., 1.}); + auto exp = NDArrayFactory::create('c', {4, 1}, {0, 2, 3, 4}); + sd::ops::Where op; + auto res = op.evaluate({&input}); + ASSERT_TRUE(res.status() == ND4J_STATUS_OK); + auto resA = res.at(0); + // resA->printIndexedBuffer("Result A"); + // resA->printShapeInfo("ShapeA"); + ASSERT_TRUE(exp.equalsTo(resA)); + ASSERT_TRUE(exp.isSameShape(resA)); + // ASSERT_TRUE(expIdx.equalsTo(res.at(1))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, Where_SGO_Test_3) { - auto input = NDArrayFactory::create('c', {5, 1}, {true, false, true, true, true}); - //auto expIdx({0., 1., 0., 2., 0., 3., 4., 1., 4., 1.}); - auto exp = NDArrayFactory::create('c', {4, 2}, {0, 0, 2, 0, 3, 0, 4, 0}); - - sd::ops::Where op; - auto res = op.evaluate({&input}, {}, {}); - ASSERT_TRUE(res.status() == ND4J_STATUS_OK); - auto resA = res.at(0); - //resA->printIndexedBuffer("Result A"); - //resA->printShapeInfo("ShapeA"); - ASSERT_TRUE(exp.equalsTo(resA)); - ASSERT_TRUE(exp.isSameShape(resA)); -// ASSERT_TRUE(expIdx.equalsTo(res.at(1))); - + auto input = NDArrayFactory::create('c', {5, 1}, + {true, false, true, true, true}); + // auto expIdx({0., 1., 0., 2., 0., 3., 4., 1., 4., 1.}); + auto exp = + NDArrayFactory::create('c', {4, 2}, {0, 0, 2, 0, 3, 0, 4, 0}); + + sd::ops::Where op; + auto res = op.evaluate({&input}, {}, {}); + ASSERT_TRUE(res.status() == ND4J_STATUS_OK); + auto resA = res.at(0); + // resA->printIndexedBuffer("Result A"); + // resA->printShapeInfo("ShapeA"); + ASSERT_TRUE(exp.equalsTo(resA)); + ASSERT_TRUE(exp.isSameShape(resA)); + // ASSERT_TRUE(expIdx.equalsTo(res.at(1))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, Where_SGO_Test_4) { - auto input = NDArrayFactory::create('c', {5, 1}, {false, false, false, false, false}); - //auto expIdx({0., 1., 0., 2., 0., 3., 4., 1., 4., 1.}); - auto exp = NDArrayFactory::create('c', {4, 2}, {0, 0, 2, 0, 3, 0, 4, 0}); - - sd::ops::Where op; - auto res = op.evaluate({&input}, {}, {}); - ASSERT_TRUE(res.status() == ND4J_STATUS_OK); - auto resA = res.at(0); - ASSERT_TRUE(resA->isEmpty()); - //resA->printIndexedBuffer("Result A"); - //resA->printShapeInfo("ShapeA"); - //ASSERT_TRUE(exp.equalsTo(resA)); - //ASSERT_TRUE(exp.isSameShape(resA)); -// ASSERT_TRUE(expIdx.equalsTo(res.at(1))); - + auto input = NDArrayFactory::create( + 'c', {5, 1}, {false, false, false, false, false}); + // auto expIdx({0., 1., 0., 2., 0., 3., 4., 1., 4., 1.}); + auto exp = + NDArrayFactory::create('c', {4, 2}, {0, 0, 2, 0, 3, 0, 4, 0}); + + sd::ops::Where op; + auto res = op.evaluate({&input}, {}, {}); + ASSERT_TRUE(res.status() == ND4J_STATUS_OK); + auto resA = res.at(0); + ASSERT_TRUE(resA.isEmpty()); + // resA->printIndexedBuffer("Result A"); + // resA->printShapeInfo("ShapeA"); + // ASSERT_TRUE(exp.equalsTo(resA)); + // ASSERT_TRUE(exp.isSameShape(resA)); + // ASSERT_TRUE(expIdx.equalsTo(res.at(1))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, Where_SGO_Test_5) { - auto input = NDArrayFactory::create('c', {5}, {1, 0, 0, 2, 3}); - //auto expIdx({0., 1., 0., 2., 0., 3., 4., 1., 4., 1.}); - auto exp = NDArrayFactory::create('c', {3, 1}, {0, 3, 4}); - - sd::ops::Where op; - auto res = op.evaluate({&input}, {}, {}); - ASSERT_TRUE(res.status() == ND4J_STATUS_OK); - auto resA = res.at(0); - //ASSERT_TRUE(resA->isEmpty()); + auto input = NDArrayFactory::create('c', {5}, {1, 0, 0, 2, 3}); + // auto expIdx({0., 1., 0., 2., 0., 3., 4., 1., 4., 1.}); + auto exp = NDArrayFactory::create('c', {3, 1}, {0, 3, 4}); - ASSERT_TRUE(exp.equalsTo(resA)); - ASSERT_TRUE(exp.isSameShape(resA)); -// ASSERT_TRUE(expIdx.equalsTo(res.at(1))); + sd::ops::Where op; + auto res = op.evaluate({&input}, {}, {}); + ASSERT_TRUE(res.status() == ND4J_STATUS_OK); + auto resA = res.at(0); + // ASSERT_TRUE(resA->isEmpty()); + ASSERT_TRUE(exp.equalsTo(resA)); + ASSERT_TRUE(exp.isSameShape(resA)); + // ASSERT_TRUE(expIdx.equalsTo(res.at(1))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, WhereNP_SGO_Test_4) { - auto input = NDArrayFactory::create('c', {5, 1}, {false, false, false, false, false}); - //auto expIdx({0., 1., 0., 2., 0., 3., 4., 1., 4., 1.}); - auto exp = NDArrayFactory::create('c', {4, 2}, {0, 0, 2, 0, 3, 0, 4, 0}); - - sd::ops::where_np op; - auto res = op.evaluate({&input}, {}, {}); - ASSERT_TRUE(res.status() == ND4J_STATUS_OK); - auto resA = res.at(0); - ASSERT_TRUE(resA->isEmpty()); - //resA->printIndexedBuffer("Result A"); - //resA->printShapeInfo("ShapeA"); - //ASSERT_TRUE(exp.equalsTo(resA)); - //ASSERT_TRUE(exp.isSameShape(resA)); -// ASSERT_TRUE(expIdx.equalsTo(res.at(1))); - + auto input = NDArrayFactory::create( + 'c', {5, 1}, {false, false, false, false, false}); + // auto expIdx({0., 1., 0., 2., 0., 3., 4., 1., 4., 1.}); + auto exp = + NDArrayFactory::create('c', {4, 2}, {0, 0, 2, 0, 3, 0, 4, 0}); + + sd::ops::where_np op; + auto res = op.evaluate({&input}, {}, {}); + ASSERT_TRUE(res.status() == ND4J_STATUS_OK); + auto resA = res.at(0); + ASSERT_TRUE(resA.isEmpty()); + // resA->printIndexedBuffer("Result A"); + // resA->printShapeInfo("ShapeA"); + // ASSERT_TRUE(exp.equalsTo(resA)); + // ASSERT_TRUE(exp.isSameShape(resA)); + // ASSERT_TRUE(expIdx.equalsTo(res.at(1))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, CosineDistance_SGO_Test_1) { - auto labels = NDArrayFactory::create('c', {2, 3}, {1.0, 2.0, 3.0, -1.0, 2.0, 1.0}); - //auto expIdx({0., 1., 0., 2., 0., 3., 4., 1., 4., 1.}); - auto predictions = NDArrayFactory::create('c', {2, 3}, {-0.3, -0.2, -0.1, 0, 0.1, 0.2}); - auto weights = NDArrayFactory::create('c', {2, 1}, {0., 1.}); - auto exp = NDArrayFactory::create(0.6); - - sd::ops::cosine_distance_loss op; - auto res = op.evaluate({&predictions, &weights, &labels}, {}, {3, 1}); - ASSERT_TRUE(res.status() == ND4J_STATUS_OK); - auto resA = res.at(0); + auto labels = NDArrayFactory::create('c', {2, 3}, + {1.0, 2.0, 3.0, -1.0, 2.0, 1.0}); + // auto expIdx({0., 1., 0., 2., 0., 3., 4., 1., 4., 1.}); + auto predictions = NDArrayFactory::create( + 'c', {2, 3}, {-0.3, -0.2, -0.1, 0, 0.1, 0.2}); + auto weights = NDArrayFactory::create('c', {2, 1}, {0., 1.}); + auto exp = NDArrayFactory::create(0.6); - ASSERT_TRUE(exp.equalsTo(resA)); + sd::ops::cosine_distance_loss op; + auto res = op.evaluate({&predictions, &weights, &labels}, {}, {3, 1}); + ASSERT_TRUE(res.status() == ND4J_STATUS_OK); + auto resA = res.at(0); + ASSERT_TRUE(exp.equalsTo(resA)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, CosineDistance_SGO_Test_2) { - auto labels = NDArrayFactory::create('c', {2, 3}, {1.0, 2.0, 3.0, -1.0, 2.0, 1.0}); - //auto expIdx({0., 1., 0., 2., 0., 3., 4., 1., 4., 1.}); - auto predictions = NDArrayFactory::create('c', {2, 3}, {-0.3, -0.2, -0.1, 0, 0.1, 0.2}); - auto weights = NDArrayFactory::create('c', {2, 1}, {0., 1.}); - auto exp = NDArrayFactory::create(0.6); + auto labels = NDArrayFactory::create('c', {2, 3}, + {1.0, 2.0, 3.0, -1.0, 2.0, 1.0}); + // auto expIdx({0., 1., 0., 2., 0., 3., 4., 1., 4., 1.}); + auto predictions = NDArrayFactory::create( + 'c', {2, 3}, {-0.3, -0.2, -0.1, 0, 0.1, 0.2}); + auto weights = NDArrayFactory::create('c', {2, 1}, {0., 1.}); + auto exp = NDArrayFactory::create(0.6); - sd::ops::cosine_distance_loss op; - auto res = op.evaluate({&predictions, &weights, &labels}, {}, {2, 1}); - ASSERT_TRUE(res.status() == ND4J_STATUS_OK); - auto resA = res.at(0); - - ASSERT_TRUE(exp.equalsTo(resA)); + sd::ops::cosine_distance_loss op; + auto res = op.evaluate({&predictions, &weights, &labels}, {}, {2, 1}); + ASSERT_TRUE(res.status() == ND4J_STATUS_OK); + auto resA = res.at(0); + ASSERT_TRUE(exp.equalsTo(resA)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, TestMarixBandPart_Test_1) { + auto x = NDArrayFactory::create('c', {2, 3, 3}); - auto x = NDArrayFactory::create('c', {2, 3, 3}); - - auto exp = NDArrayFactory::create('c', {2, 3, 3}); - x.linspace(1); - exp.linspace(1); - exp.p(0, 0, 2, 0.); - exp.p(1, 0, 2, 0.); - exp.p(0, 2, 0, 0.); - exp.p(1, 2, 0, 0.); + auto exp = NDArrayFactory::create('c', {2, 3, 3}); + x.linspace(1); + exp.linspace(1); + exp.p(0, 0, 2, 0.); + exp.p(1, 0, 2, 0.); + exp.p(0, 2, 0, 0.); + exp.p(1, 2, 0, 0.); - sd::ops::matrix_band_part op; - auto results = op.evaluate({&x}, {}, {1, 1}); + sd::ops::matrix_band_part op; + auto results = op.evaluate({&x}, {}, {1, 1}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - //results.at(0)->printIndexedBuffer("MBP Test1"); - //exp.printIndexedBuffer("MBP Expec"); - ASSERT_TRUE(exp.equalsTo(results.at(0))); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + // results.at(0)->printIndexedBuffer("MBP Test1"); + // exp.printIndexedBuffer("MBP Expec"); + ASSERT_TRUE(exp.equalsTo(results.at(0))); } /////////////////////////////////////////////////////////////////// @@ -421,1134 +425,1195 @@ TEST_F(DeclarableOpsTests10, TestMarixBandPart_Test_2) { ////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, atan2_test1) { - - auto y = NDArrayFactory::create('c', {2, 3, 4}, {-1.001 ,-0.915 ,-0.829 ,-0.743 ,-0.657 ,-0.571 ,-0.485 ,-0.399 ,-0.313 ,-0.227 ,-0.141 ,-0.055 ,0.031 ,0.117 ,0.203 ,0.289 ,0.375 ,0.461 ,0.547 ,0.633 ,0.719 ,0.805 ,0.891 ,0.977}); - auto x = NDArrayFactory::create('c', {2, 3, 4}, {-0.51, -0.46, -0.41, -0.36, -0.31, -0.26, -0.21, -0.16, -0.11, -0.06, -0.01, 0.04, 0.09, 0.14, 0.19, 0.24, 0.29, 0.34, 0.39, 0.44, 0.49, 0.54, 0.59, 0.61}); - - auto exp = NDArrayFactory::create('c', {2,3,4}, {-2.04201, -2.03663, -2.03009, -2.02199,-2.01166, -1.99808, -1.97941, -1.95217,-1.90875, -1.8292 , -1.6416 , -0.942 , - 0.33172, 0.69614, 0.81846, 0.87776, 0.91253, 0.93533, 0.95141, 0.96336, 0.97259, 0.97993, 0.98591, 1.01266,}); - - sd::ops::tf_atan2 op; - auto result = op.evaluate({&y, &x}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto y = NDArrayFactory::create( + 'c', {2, 3, 4}, + {-1.001, -0.915, -0.829, -0.743, -0.657, -0.571, -0.485, -0.399, + -0.313, -0.227, -0.141, -0.055, 0.031, 0.117, 0.203, 0.289, + 0.375, 0.461, 0.547, 0.633, 0.719, 0.805, 0.891, 0.977}); + auto x = NDArrayFactory::create( + 'c', {2, 3, 4}, {-0.51, -0.46, -0.41, -0.36, -0.31, -0.26, -0.21, -0.16, + -0.11, -0.06, -0.01, 0.04, 0.09, 0.14, 0.19, 0.24, + 0.29, 0.34, 0.39, 0.44, 0.49, 0.54, 0.59, 0.61}); + + auto exp = NDArrayFactory::create( + 'c', {2, 3, 4}, + { + -2.04201, -2.03663, -2.03009, -2.02199, -2.01166, -1.99808, + -1.97941, -1.95217, -1.90875, -1.8292, -1.6416, -0.942, + 0.33172, 0.69614, 0.81846, 0.87776, 0.91253, 0.93533, + 0.95141, 0.96336, 0.97259, 0.97993, 0.98591, 1.01266, + }); + + sd::ops::tf_atan2 op; + auto result = op.evaluate({&y, &x}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, atan2_test2) { - - auto y = NDArrayFactory::create('c', {2, 3, 4}, {-1.001 ,-0.915 ,-0.829 ,-0.743 ,-0.657 ,-0.571 ,-0.485 ,-0.399 ,-0.313 ,-0.227 ,-0.141 ,-0.055 ,0.031 ,0.117 ,0.203 ,0.289 ,0.375 ,0.461 ,0.547 ,0.633 ,0.719 ,0.805 ,0.891 ,0.977}); - auto x = NDArrayFactory::create('c', { 3, 4}, {-1.05, -0.82, -0.639, -0.458, -0.277, -0.096, 0.085, 0.266, 0.447, 0.628, 0.809, 0.99}); - - auto exp = NDArrayFactory::create('c', {2,3,4}, {-2.38008, -2.30149, -2.22748, -2.1232 ,-1.96979, -1.73736, -1.3973 , -0.98279,-0.61088, -0.34685, -0.17256, -0.0555 , - 3.11208, 2.99987, 2.83399, 2.57869, 2.207 , 1.77611, 1.41664, 1.17298, 1.01458, 0.90829, 0.8336 , 0.77879}); - - sd::ops::tf_atan2 op; - auto result = op.evaluate({&y, &x}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); - // z->printIndexedBuffer(); - - // x.applyTrueBroadcast(sd::BroadcastOpsTuple::custom(scalar::Atan2, pairwise::Atan2, broadcast::Atan2), &y, &z, true); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto y = NDArrayFactory::create( + 'c', {2, 3, 4}, + {-1.001, -0.915, -0.829, -0.743, -0.657, -0.571, -0.485, -0.399, + -0.313, -0.227, -0.141, -0.055, 0.031, 0.117, 0.203, 0.289, + 0.375, 0.461, 0.547, 0.633, 0.719, 0.805, 0.891, 0.977}); + auto x = NDArrayFactory::create( + 'c', {3, 4}, + {-1.05, -0.82, -0.639, -0.458, -0.277, -0.096, 0.085, 0.266, 0.447, 0.628, + 0.809, 0.99}); + + auto exp = NDArrayFactory::create( + 'c', {2, 3, 4}, + {-2.38008, -2.30149, -2.22748, -2.1232, -1.96979, -1.73736, + -1.3973, -0.98279, -0.61088, -0.34685, -0.17256, -0.0555, + 3.11208, 2.99987, 2.83399, 2.57869, 2.207, 1.77611, + 1.41664, 1.17298, 1.01458, 0.90829, 0.8336, 0.77879}); + + sd::ops::tf_atan2 op; + auto result = op.evaluate({&y, &x}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + // z->printIndexedBuffer(); + + // x.applyTrueBroadcast(sd::BroadcastOpsTuple::custom(scalar::Atan2, + // pairwise::Atan2, broadcast::Atan2), &y, &z, true); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, atan2_test3) { - - auto y = NDArrayFactory::create('c', {2, 3, 4}, {-1.001 ,-0.915 ,-0.829 ,-0.743 ,-0.657 ,-0.571 ,-0.485 ,-0.399 ,-0.313 ,-0.227 ,-0.141 ,-0.055 ,0.031 ,0.117 ,0.203 ,0.289 ,0.375 ,0.461 ,0.547 ,0.633 ,0.719 ,0.805 ,0.891 ,0.977}); - auto x = NDArrayFactory::create('c', { 3, 4}, {-1.05, -0.82, -0.639, -0.458, -0.277, -0.096, 0.085, 0.266, 0.447, 0.628, 0.809, 0.99}); - - auto exp = NDArrayFactory::create('c', {2,3,4}, {-2.33231, -2.41089, -2.48491, -2.58919,-2.74259, -2.97502, 2.9681 , 2.55359, 2.18167, 1.91765, 1.74335, 1.62629, - -1.54128, -1.42907, -1.2632 , -1.00789,-0.63621, -0.20531, 0.15416, 0.39782, 0.55622, 0.6625 , 0.7372 , 0.79201}); - - sd::ops::tf_atan2 op; - auto result = op.evaluate({&x, &y}, {}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto y = NDArrayFactory::create( + 'c', {2, 3, 4}, + {-1.001, -0.915, -0.829, -0.743, -0.657, -0.571, -0.485, -0.399, + -0.313, -0.227, -0.141, -0.055, 0.031, 0.117, 0.203, 0.289, + 0.375, 0.461, 0.547, 0.633, 0.719, 0.805, 0.891, 0.977}); + auto x = NDArrayFactory::create( + 'c', {3, 4}, + {-1.05, -0.82, -0.639, -0.458, -0.277, -0.096, 0.085, 0.266, 0.447, 0.628, + 0.809, 0.99}); + + auto exp = NDArrayFactory::create( + 'c', {2, 3, 4}, + {-2.33231, -2.41089, -2.48491, -2.58919, -2.74259, -2.97502, + 2.9681, 2.55359, 2.18167, 1.91765, 1.74335, 1.62629, + -1.54128, -1.42907, -1.2632, -1.00789, -0.63621, -0.20531, + 0.15416, 0.39782, 0.55622, 0.6625, 0.7372, 0.79201}); + + sd::ops::tf_atan2 op; + auto result = op.evaluate({&x, &y}, {}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, atan2_test4) { + auto y = NDArrayFactory::create( + 'c', {1, 3, 4}, + {-1.001, -0.829, -0.657, -0.485, -0.313, -0.141, 0.031, 0.203, 0.375, + 0.547, 0.719, 0.891}); + auto x = NDArrayFactory::create( + 'c', {2, 3, 1}, {-0.82, -0.458, -0.096, 0.085, 0.447, 0.809}); - auto y = NDArrayFactory::create('c', {1, 3, 4}, {-1.001 ,-0.829 ,-0.657 ,-0.485 ,-0.313 ,-0.141 ,0.031 ,0.203 ,0.375 ,0.547 ,0.719 ,0.891}); - auto x = NDArrayFactory::create('c', {2, 3, 1}, {-0.82, -0.458, -0.096, 0.085, 0.447, 0.809}); - - auto exp = NDArrayFactory::create('c', {2,3,4}, {-2.45527, -2.36165, -2.24628, -2.10492,-2.1703 , -1.86945, -1.50321, -1.15359,-0.25062, -0.17373, -0.13273, -0.10733, - 3.05688, 3.03942, 3.01293, 2.9681 , 2.18167, 1.87635, 1.50156, 1.14451, 1.13674, 0.97626, 0.84423, 0.7372 }); + auto exp = NDArrayFactory::create( + 'c', {2, 3, 4}, + {-2.45527, -2.36165, -2.24628, -2.10492, -2.1703, -1.86945, + -1.50321, -1.15359, -0.25062, -0.17373, -0.13273, -0.10733, + 3.05688, 3.03942, 3.01293, 2.9681, 2.18167, 1.87635, + 1.50156, 1.14451, 1.13674, 0.97626, 0.84423, 0.7372}); - sd::ops::tf_atan2 op; - auto result = op.evaluate({&x, &y}, {}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); + sd::ops::tf_atan2 op; + auto result = op.evaluate({&x, &y}, {}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, atan2_test5) { + auto y = NDArrayFactory::create( + 'c', {1, 3, 4}, + {-1.001, -0.829, -0.657, -0.485, -0.313, -0.141, 0.031, 0.203, 0.375, + 0.547, 0.719, 0.891}); + auto x = NDArrayFactory::create( + 'c', {2, 3, 1}, {-0.82, -0.458, -0.096, 0.085, 0.447, 0.809}); - auto y = NDArrayFactory::create('c', {1, 3, 4}, {-1.001 ,-0.829 ,-0.657 ,-0.485 ,-0.313 ,-0.141 ,0.031 ,0.203 ,0.375 ,0.547 ,0.719 ,0.891}); - auto x = NDArrayFactory::create('c', {2, 3, 1}, {-0.82, -0.458, -0.096, 0.085, 0.447, 0.809}); - - auto exp = NDArrayFactory::create('c', {2,3,4}, {-2.25712, -2.35074, -2.46611, -2.60747,-2.54209, -2.84294, 3.07401, 2.72438, 1.82141, 1.74453, 1.70353, 1.67813, - -1.48608, -1.46862, -1.44214, -1.3973 ,-0.61088, -0.30556, 0.06924, 0.42629, 0.43405, 0.59453, 0.72657, 0.8336 }); + auto exp = NDArrayFactory::create( + 'c', {2, 3, 4}, + {-2.25712, -2.35074, -2.46611, -2.60747, -2.54209, -2.84294, + 3.07401, 2.72438, 1.82141, 1.74453, 1.70353, 1.67813, + -1.48608, -1.46862, -1.44214, -1.3973, -0.61088, -0.30556, + 0.06924, 0.42629, 0.43405, 0.59453, 0.72657, 0.8336}); - sd::ops::tf_atan2 op; - auto result = op.evaluate({&y, &x}, {}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); + sd::ops::tf_atan2 op; + auto result = op.evaluate({&y, &x}, {}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, atan2_test6) { + auto y = NDArrayFactory::create( + 'c', {1, 3, 4}, + {-1.001, -0.829, -0.657, -0.485, -0.313, -0.141, 0.031, 0.203, 0.375, + 0.547, 0.719, 0.891}); + auto x = + NDArrayFactory::create('c', {4}, {-0.82, -0.096, 0.085, 0.809}); - auto y = NDArrayFactory::create('c', {1, 3, 4}, {-1.001 ,-0.829 ,-0.657 ,-0.485 ,-0.313 ,-0.141 ,0.031 ,0.203 ,0.375 ,0.547 ,0.719 ,0.891}); - auto x = NDArrayFactory::create('c', { 4}, {-0.82, -0.096, 0.085, 0.809}); + auto exp = NDArrayFactory::create( + 'c', {1, 3, 4}, + {-2.25712, -1.68608, -1.44214, -0.54006, -2.77695, -2.16855, 0.34972, + 0.24585, 2.71267, 1.74453, 1.45312, 0.8336}); - auto exp = NDArrayFactory::create('c', {1,3,4}, {-2.25712, -1.68608, -1.44214, -0.54006,-2.77695, -2.16855, 0.34972, 0.24585, 2.71267, 1.74453, 1.45312, 0.8336 }); + sd::ops::tf_atan2 op; + auto result = op.evaluate({&y, &x}, {}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); - sd::ops::tf_atan2 op; - auto result = op.evaluate({&y, &x}, {}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, IGamma_Test1) { - - auto y = NDArrayFactory::create('c', {1, 3, 4}, {1.1 , 2.1 , 3.1 ,4.1 , 5.1 , 6.1 ,7.1 ,8.1 ,9.1 ,10.1,11.1 ,12.1}); - auto x = NDArrayFactory::create('c', { 4}, {1.2, 2.2, 3.2, 4.2}); - - auto exp = NDArrayFactory::create('c', {1,3,4}, { - 0.659917, 0.61757898, 0.59726304, 0.58478117, - 0.0066205109, 0.022211598, 0.040677428, 0.059117373, - 0.0000039433403, 0.000086064574, 0.000436067, 0.0012273735}); - - sd::ops::igamma op; - auto result = op.evaluate({&y, &x}, {}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); -// z->printBuffer("OUtput"); -// exp.printBuffer("EXpect"); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto y = NDArrayFactory::create( + 'c', {1, 3, 4}, + {1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1, 8.1, 9.1, 10.1, 11.1, 12.1}); + auto x = NDArrayFactory::create('c', {4}, {1.2, 2.2, 3.2, 4.2}); + + auto exp = NDArrayFactory::create( + 'c', {1, 3, 4}, + {0.659917, 0.61757898, 0.59726304, 0.58478117, 0.0066205109, 0.022211598, + 0.040677428, 0.059117373, 0.0000039433403, 0.000086064574, 0.000436067, + 0.0012273735}); + + sd::ops::igamma op; + auto result = op.evaluate({&y, &x}, {}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + // z->printBuffer("OUtput"); + // exp.printBuffer("EXpect"); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, IGamma_Test2) { - - auto y = NDArrayFactory::create('c', {1, 3, 4}, {1.1 , 2.1 , 3.1 ,4.1 , 5.1 , 6.1 , - 7.1 ,8.1 ,9.1 ,10.1,11.1 ,12.1}); - auto x = NDArrayFactory::create('c', { 4}, {1.2, 2.2, 3.2, 4.2}); - auto exp = NDArrayFactory::create('c', {1,3,4}, {0.340083, 0.382421, 0.402737, 0.415221, - 0.993379, 0.977788, 0.959323, 0.940883, - 0.999996, 0.999914, 0.999564, 0.998773}); - - sd::ops::igammac op; - auto result = op.evaluate({&y, &x}, {}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); -// z->printBuffer("OUtput"); -// exp.printBuffer("EXpect"); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto y = NDArrayFactory::create( + 'c', {1, 3, 4}, + {1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1, 8.1, 9.1, 10.1, 11.1, 12.1}); + auto x = NDArrayFactory::create('c', {4}, {1.2, 2.2, 3.2, 4.2}); + auto exp = NDArrayFactory::create( + 'c', {1, 3, 4}, + {0.340083, 0.382421, 0.402737, 0.415221, 0.993379, 0.977788, 0.959323, + 0.940883, 0.999996, 0.999914, 0.999564, 0.998773}); + + sd::ops::igammac op; + auto result = op.evaluate({&y, &x}, {}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + // z->printBuffer("OUtput"); + // exp.printBuffer("EXpect"); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, LGamma_Test1) { + auto x = NDArrayFactory::create( + 'c', {3, 3}, {0.1, 0.5, 0.7, 1.5, 1.7, 2.0, 2.5, 2.7, 3.}); - auto x = NDArrayFactory::create('c', {3, 3}, {0.1, 0.5, 0.7, 1.5, 1.7, 2.0, 2.5, 2.7, 3.}); - - auto exp = NDArrayFactory::create('c', {3,3}, { - 2.2527127 , 0.5723649 , 0.26086727, - -0.12078223, -0.09580769, 0., - 0.28468287, 0.4348206 , 0.6931472 - }); - - sd::ops::lgamma op; - auto result = op.evaluate({&x}, {}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); -// z->printBuffer("OUtput"); -// exp.printBuffer("EXpect"); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - + auto exp = NDArrayFactory::create( + 'c', {3, 3}, + {2.2527127, 0.5723649, 0.26086727, -0.12078223, -0.09580769, 0., + 0.28468287, 0.4348206, 0.6931472}); + sd::ops::lgamma op; + auto result = op.evaluate({&x}, {}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + // z->printBuffer("OUtput"); + // exp.printBuffer("EXpect"); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, range_test10) { + auto limit = NDArrayFactory::create('c', {1, 3, 4}); + limit = 5.; + auto exp = NDArrayFactory::create('c', {5}, {0., 1., 2., 3., 4.}); - auto limit = NDArrayFactory::create('c', {1, 3, 4}); - limit = 5.; - auto exp = NDArrayFactory::create('c', {5}, {0.,1.,2.,3.,4.}); - - sd::ops::range op; - auto result = op.evaluate({&limit}, {}, {}, {}); + sd::ops::range op; + auto result = op.evaluate({&limit}, {}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, range_test11) { + auto limit = NDArrayFactory::create('c', {1, 3, 4}); + auto start = NDArrayFactory::create('c', {2, 4}); + limit = 5.; + start = 0.5; + auto exp = + NDArrayFactory::create('c', {5}, {0.5, 1.5, 2.5, 3.5, 4.5}); - auto limit = NDArrayFactory::create('c', {1, 3, 4}); - auto start = NDArrayFactory::create('c', {2, 4}); - limit = 5.; - start = 0.5; - auto exp = NDArrayFactory::create('c', {5}, {0.5,1.5,2.5,3.5,4.5}); - - sd::ops::range op; - auto result = op.evaluate({&start, &limit}, {}, {}, {}); + sd::ops::range op; + auto result = op.evaluate({&start, &limit}, {}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, range_test12) { + auto exp = NDArrayFactory::create( + 'c', {9}, {0.5f, 1.f, 1.5f, 2.f, 2.5f, 3.f, 3.5f, 4.f, 4.5f}); - auto exp = NDArrayFactory::create('c', {9}, {0.5f, 1.f , 1.5f, 2.f , 2.5f, 3.f , 3.5f, 4.f , 4.5f}); + sd::ops::range op; + auto result = op.evaluate({}, {0.5, 5, 0.5}, {}, {}); - sd::ops::range op; - auto result = op.evaluate({}, {0.5, 5, 0.5}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, top_k_permuted_test1) { + auto x = + NDArrayFactory::create({7., 3., 1., 2., 5., 0., 4., 6., 9., 8.}); + auto expUnsorted = + NDArrayFactory::create({7., 6., 9., 8.}); // Sorted = False + auto expSorted = + NDArrayFactory::create({9., 8., 7., 6., 5.}); // Sorted = False - auto x = NDArrayFactory::create({7., 3., 1., 2., 5., 0., 4., 6., 9., 8.}); - auto expUnsorted = NDArrayFactory::create({7., 6., 9., 8.}); // Sorted = False - auto expSorted = NDArrayFactory::create({9., 8., 7., 6., 5.}); // Sorted = False - - - sd::ops::top_k op; - auto result = op.evaluate({&x}, {}, {4}, {false}); + sd::ops::top_k op; + auto result = op.evaluate({&x}, {}, {4}, {false}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); - auto zI = result.at(1); + auto z = result.at(0); + auto zI = result.at(1); - ASSERT_TRUE(expUnsorted.isSameShape(z)); - ASSERT_TRUE(expUnsorted.equalsTo(z)); + ASSERT_TRUE(expUnsorted.isSameShape(z)); + ASSERT_TRUE(expUnsorted.equalsTo(z)); - auto result2 = op.evaluate({&x}, {}, {5}, {true}); + auto result2 = op.evaluate({&x}, {}, {5}, {true}); - ASSERT_EQ(ND4J_STATUS_OK, result2.status()); + ASSERT_EQ(ND4J_STATUS_OK, result2.status()); - z = result2.at(0); - zI = result2.at(1); + z = result2.at(0); + zI = result2.at(1); - ASSERT_TRUE(expSorted.isSameShape(z)); - ASSERT_TRUE(expSorted.equalsTo(z)); + ASSERT_TRUE(expSorted.isSameShape(z)); + ASSERT_TRUE(expSorted.equalsTo(z)); } ////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, top_k_permuted_test2) { + auto x = + NDArrayFactory::create({7., 3., 1., 2., 5., 0., 4., 6., 9., 8.}); + auto expUnsorted = + NDArrayFactory::create({7., 5., 6., 9., 8.}); // Sorted = False + auto expSorted = + NDArrayFactory::create({9., 8., 7., 6., 5.}); // Sorted = False - auto x = NDArrayFactory::create({7., 3., 1., 2., 5., 0., 4., 6., 9., 8.}); - auto expUnsorted = NDArrayFactory::create({7., 5., 6., 9., 8.}); // Sorted = False - auto expSorted = NDArrayFactory::create({9., 8., 7., 6., 5.}); // Sorted = False + sd::ops::top_k op; + auto result = op.evaluate({&x}, {}, {5}, {false}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - sd::ops::top_k op; - auto result = op.evaluate({&x}, {}, {5}, {false}); + auto z = result.at(0); + auto zI = result.at(1); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_TRUE(expUnsorted.isSameShape(z)); + ASSERT_TRUE(expUnsorted.equalsTo(z)); - auto z = result.at(0); - auto zI = result.at(1); + auto result2 = op.evaluate({&x}, {}, {5}, {true}); - ASSERT_TRUE(expUnsorted.isSameShape(z)); - ASSERT_TRUE(expUnsorted.equalsTo(z)); + ASSERT_EQ(ND4J_STATUS_OK, result2.status()); - auto result2 = op.evaluate({&x}, {}, {5}, {true}); + z = result2.at(0); + zI = result2.at(1); - ASSERT_EQ(ND4J_STATUS_OK, result2.status()); - - z = result2.at(0); - zI = result2.at(1); - - ASSERT_TRUE(expSorted.isSameShape(z)); - ASSERT_TRUE(expSorted.equalsTo(z)); + ASSERT_TRUE(expSorted.isSameShape(z)); + ASSERT_TRUE(expSorted.equalsTo(z)); } /////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests10, sparse_softmax_cross_entropy_loss_with_logits_test1) { - - auto labels = NDArrayFactory::create('c', {2,3},{3, 2, 1, 0, 1, 2}); - auto logits = NDArrayFactory::create('c', {2,3,4}); - auto expected = NDArrayFactory::create('c', {2,3}, {1.24254, 1.34254, 1.44254, 1.54254, 1.44254, 1.34254}); +TEST_F(DeclarableOpsTests10, + sparse_softmax_cross_entropy_loss_with_logits_test1) { + auto labels = NDArrayFactory::create('c', {2, 3}, {3, 2, 1, 0, 1, 2}); + auto logits = NDArrayFactory::create('c', {2, 3, 4}); + auto expected = NDArrayFactory::create( + 'c', {2, 3}, {1.24254, 1.34254, 1.44254, 1.54254, 1.44254, 1.34254}); - logits.linspace(0.1, 0.1); + logits.linspace(0.1, 0.1); - sd::ops::sparse_softmax_cross_entropy_loss_with_logits op; - auto results = op.evaluate({&labels, &logits}); + sd::ops::sparse_softmax_cross_entropy_loss_with_logits op; + auto results = op.evaluate({&labels, &logits}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto output = results.at(0); + auto output = results.at(0); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } /////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests10, sparse_softmax_cross_entropy_loss_with_logits_test2) { - - auto labels = NDArrayFactory::create('c', {2},{1, 0}); - auto logits = NDArrayFactory::create('c', {2,3}); - auto expected = NDArrayFactory::create('c', {2}, {1.10194, 1.20194}); +TEST_F(DeclarableOpsTests10, + sparse_softmax_cross_entropy_loss_with_logits_test2) { + auto labels = NDArrayFactory::create('c', {2}, {1, 0}); + auto logits = NDArrayFactory::create('c', {2, 3}); + auto expected = NDArrayFactory::create('c', {2}, {1.10194, 1.20194}); - logits.linspace(0.1, 0.1); + logits.linspace(0.1, 0.1); - sd::ops::sparse_softmax_cross_entropy_loss_with_logits op; - auto results = op.evaluate({&labels, &logits}); + sd::ops::sparse_softmax_cross_entropy_loss_with_logits op; + auto results = op.evaluate({&labels, &logits}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto output = results.at(0); + auto output = results.at(0); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } /////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests10, sparse_softmax_cross_entropy_loss_with_logits_test3) { +TEST_F(DeclarableOpsTests10, + sparse_softmax_cross_entropy_loss_with_logits_test3) { + NDArray labels('c', {1}, std::vector{0}, sd::DataType::INT32); + auto logits = NDArrayFactory::create('c', {1, 3}); + auto expected = NDArrayFactory::create('c', {1}, {1.20194}); - NDArray labels('c', {1}, std::vector{0}, sd::DataType::INT32); - auto logits = NDArrayFactory::create('c', {1,3}); - auto expected = NDArrayFactory::create('c', {1}, {1.20194}); + logits.linspace(0.1, 0.1); - logits.linspace(0.1, 0.1); + sd::ops::sparse_softmax_cross_entropy_loss_with_logits op; + auto results = op.evaluate({&labels, &logits}); - sd::ops::sparse_softmax_cross_entropy_loss_with_logits op; - auto results = op.evaluate({&labels, &logits}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto output = results.at(0); + auto output = results.at(0); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } /////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests10, sparse_softmax_cross_entropy_loss_with_logits_test4) { - - auto labels = NDArrayFactory::create('c', {2},{0, 0}); - auto logits = NDArrayFactory::create('c', {2,1}); - auto expected = NDArrayFactory::create('c', {2}, {0., 0.}); +TEST_F(DeclarableOpsTests10, + sparse_softmax_cross_entropy_loss_with_logits_test4) { + auto labels = NDArrayFactory::create('c', {2}, {0, 0}); + auto logits = NDArrayFactory::create('c', {2, 1}); + auto expected = NDArrayFactory::create('c', {2}, {0., 0.}); - logits.linspace(0.1, 0.1); + logits.linspace(0.1, 0.1); - sd::ops::sparse_softmax_cross_entropy_loss_with_logits op; - auto results = op.evaluate({&labels, &logits}); + sd::ops::sparse_softmax_cross_entropy_loss_with_logits op; + auto results = op.evaluate({&labels, &logits}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto output = results.at(0); + auto output = results.at(0); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, histogram_fixed_width_test1) { + auto input = NDArrayFactory::create( + 'c', {2, 3}, {-1.f, 0.f, 1.5f, 2.f, 5.f, 15.f}); + auto range = NDArrayFactory::create('c', {2}, {0, 5}); + auto exp = NDArrayFactory::create('c', {5}, {2, 1, 1, 0, 2}); - auto input = NDArrayFactory::create('c', {2,3},{-1.f, 0.f, 1.5f, 2.f, 5.f, 15.f}); - auto range = NDArrayFactory::create('c', {2}, {0, 5}); - auto exp = NDArrayFactory::create('c', {5}, {2, 1, 1, 0, 2}); - - sd::ops::histogram_fixed_width op; - auto results = op.evaluate({&input, &range}, {}, {5}, {}); + sd::ops::histogram_fixed_width op; + auto results = op.evaluate({&input, &range}, {}, {5}, {}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto out = results.at(0); + auto out = results.at(0); - ASSERT_TRUE(exp.isSameShape(out)); - ASSERT_TRUE(exp.equalsTo(out)); + ASSERT_TRUE(exp.isSameShape(out)); + ASSERT_TRUE(exp.equalsTo(out)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, histogram_fixed_width_test2) { + auto input = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0.f, 5.f, 2.f, 1.f, -1.f, 2.f, 5.f, 3.f, 2.f, 3.f, -1.f, 5.f, + 3.f, 2.f, 1.f, 4.f, 2.f, 5.f, 5.f, 5.f, 6.f, 6.f, -1.f, 0.f}); + auto range = NDArrayFactory::create('c', {2}, {0, 5}); + auto exp = NDArrayFactory::create('c', {5}, {5, 2, 5, 3, 9}); - auto input = NDArrayFactory::create('c', {2,3,4},{0.f, 5.f, 2.f, 1.f, -1.f, 2.f, 5.f, 3.f, 2.f, 3.f, -1.f, 5.f, 3.f, 2.f, 1.f, 4.f, 2.f, 5.f, 5.f, 5.f, 6.f, 6.f, -1.f, 0.f}); - auto range = NDArrayFactory::create('c', {2}, {0, 5}); - auto exp = NDArrayFactory::create('c', {5}, {5, 2, 5, 3, 9}); + sd::ops::histogram_fixed_width op; + auto results = op.evaluate({&input, &range}, {}, {5}, {}); - sd::ops::histogram_fixed_width op; - auto results = op.evaluate({&input, &range}, {}, {5}, {}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto out = results.at(0); + auto out = results.at(0); - ASSERT_TRUE(exp.isSameShape(out)); - ASSERT_TRUE(exp.equalsTo(out)); + ASSERT_TRUE(exp.isSameShape(out)); + ASSERT_TRUE(exp.equalsTo(out)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, histogram_fixed_width_test3) { + auto input = NDArrayFactory::create( + 'c', {2, 3, 1, 4, 1}, + {0.f, 5.f, 2.001f, 1.f, -1.f, 2.f, 5.f, 3.f, + 2.999f, 3.00001f, -1.f, 3.99999f, 3.f, 2.f, 1.f, 4.f, + 2.f, 5.f, 5.f, 5.f, 6.f, 6.f, -1.f, 0.00001f}); + auto range = NDArrayFactory::create('c', {1, 2, 1}, {0, 5}); + auto exp = NDArrayFactory::create('c', {5}, {5, 2, 5, 4, 8}); - auto input = NDArrayFactory::create('c', {2,3,1,4,1},{0.f, 5.f, 2.001f, 1.f, -1.f, 2.f, 5.f, 3.f, 2.999f, 3.00001f, -1.f, 3.99999f, 3.f, 2.f, 1.f, 4.f, 2.f, 5.f, 5.f, 5.f, 6.f, 6.f, -1.f, 0.00001f}); - auto range = NDArrayFactory::create('c', {1,2,1}, {0, 5}); - auto exp = NDArrayFactory::create('c', {5}, {5, 2, 5, 4, 8}); - - sd::ops::histogram_fixed_width op; - auto results = op.evaluate({&input, &range}, {}, {5}, {}); + sd::ops::histogram_fixed_width op; + auto results = op.evaluate({&input, &range}, {}, {5}, {}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto out = results.at(0); + auto out = results.at(0); - ASSERT_TRUE(exp.isSameShape(out)); - ASSERT_TRUE(exp.equalsTo(out)); + ASSERT_TRUE(exp.isSameShape(out)); + ASSERT_TRUE(exp.equalsTo(out)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, histogram_fixed_width_test4) { - - auto input = NDArrayFactory::create('c', {20,5},{13.8387f,0.1509f,50.39f,30.403f,13.5174f,9.7351f,37.6652f,28.9215f,22.7011f,45.2834f,40.7628f,50.4995f,26.8003f,27.479f,44.633f,6.9109f,48.5004f, - 46.5971f,1.6203f,23.6381f,38.9661f,50.8146f,17.2482f,8.0429f,7.5666f,7.9709f,21.8403f,20.1694f,23.3004f,50.9151f,46.239f,38.7323f,29.6946f,32.9876f, - 23.0013f,39.7318f,19.4486f,37.6147f,-0.1506f,5.3246f,3.6173f,24.2573f,4.3941f,9.7105f,24.0364f,35.3681f,17.7805f,35.7681f,16.4144f,17.4362f,8.4987f, - 26.8108f,36.2937f,31.6442f,29.7221f,8.7445f,33.3301f,4.0939f,13.078f,45.1481f,29.0172f,21.6548f,35.408f,27.1861f,2.2576f,40.6804f,36.2201f,29.7352f, - 29.1244f,38.7444f,5.8721f,33.5983f,48.2694f,34.4161f,19.7148f,13.8085f,13.6075f,22.5042f,37.8002f,50.0543f,48.5314f,20.3694f,28.5042f,-0.4679f,4.4245f, - 18.9837f,40.7724f,2.7611f,44.0431f,37.186f,27.7361f,14.6001f,9.1721f,14.6087f,21.4072f,49.3344f,11.4668f,14.6171f,15.2502f,5.244f}); - auto range = NDArrayFactory::create('c', {1,2}, {0, 50}); - auto exp = NDArrayFactory::create('c', {5}, {22, 17, 24, 19, 18}); - - sd::ops::histogram_fixed_width op; - auto results = op.evaluate({&input, &range}, {}, {5}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto out = results.at(0); - - ASSERT_TRUE(exp.isSameShape(out)); - ASSERT_TRUE(exp.equalsTo(out)); + auto input = NDArrayFactory::create( + 'c', {20, 5}, + {13.8387f, 0.1509f, 50.39f, 30.403f, 13.5174f, 9.7351f, 37.6652f, + 28.9215f, 22.7011f, 45.2834f, 40.7628f, 50.4995f, 26.8003f, 27.479f, + 44.633f, 6.9109f, 48.5004f, 46.5971f, 1.6203f, 23.6381f, 38.9661f, + 50.8146f, 17.2482f, 8.0429f, 7.5666f, 7.9709f, 21.8403f, 20.1694f, + 23.3004f, 50.9151f, 46.239f, 38.7323f, 29.6946f, 32.9876f, 23.0013f, + 39.7318f, 19.4486f, 37.6147f, -0.1506f, 5.3246f, 3.6173f, 24.2573f, + 4.3941f, 9.7105f, 24.0364f, 35.3681f, 17.7805f, 35.7681f, 16.4144f, + 17.4362f, 8.4987f, 26.8108f, 36.2937f, 31.6442f, 29.7221f, 8.7445f, + 33.3301f, 4.0939f, 13.078f, 45.1481f, 29.0172f, 21.6548f, 35.408f, + 27.1861f, 2.2576f, 40.6804f, 36.2201f, 29.7352f, 29.1244f, 38.7444f, + 5.8721f, 33.5983f, 48.2694f, 34.4161f, 19.7148f, 13.8085f, 13.6075f, + 22.5042f, 37.8002f, 50.0543f, 48.5314f, 20.3694f, 28.5042f, -0.4679f, + 4.4245f, 18.9837f, 40.7724f, 2.7611f, 44.0431f, 37.186f, 27.7361f, + 14.6001f, 9.1721f, 14.6087f, 21.4072f, 49.3344f, 11.4668f, 14.6171f, + 15.2502f, 5.244f}); + auto range = NDArrayFactory::create('c', {1, 2}, {0, 50}); + auto exp = NDArrayFactory::create('c', {5}, {22, 17, 24, 19, 18}); + + sd::ops::histogram_fixed_width op; + auto results = op.evaluate({&input, &range}, {}, {5}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto out = results.at(0); + + ASSERT_TRUE(exp.isSameShape(out)); + ASSERT_TRUE(exp.equalsTo(out)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, histogram_fixed_width_test5) { - - auto input = NDArrayFactory::create('c', {5,20},{20.f, 0.f, 60.f, 40.f, 20.f, 0.f, 40.f, 0.f, 40.f, 40.f,40.f,60.f, 20.f, 20.f, 60.f, 0.f, 40.f, - 46.5971f,1.6203f,23.6381f,38.9661f,50.8146f,17.2482f,8.0429f,7.5666f,7.9709f,21.8403f,20.1694f,23.3004f,50.9151f,46.239f,38.7323f,29.6946f,32.9876f, - 23.0013f,39.7318f,19.4486f,37.6147f,-0.1506f,5.3246f,3.6173f,24.2573f,4.3941f,9.7105f,24.0364f,35.3681f,17.7805f,35.7681f,16.4144f,17.4362f,8.4987f, - 26.8108f,36.2937f,31.6442f,29.7221f,8.7445f,33.3301f,4.0939f,13.078f,45.1481f,29.0172f,21.6548f,35.408f,27.1861f,2.2576f,40.6804f,36.2201f,29.7352f, - 29.1244f,38.7444f,5.8721f,33.5983f,48.2694f,34.4161f,19.7148f,13.8085f,13.6075f,22.5042f,37.8002f,50.0543f,48.5314f,20.3694f,28.5042f,-0.4679f,4.4245f, - 18.9837f,40.7724f,2.7611f,44.0431f,37.186f,27.7361f,14.6001f,9.1721f,14.6087f,21.4072f,49.3344f,11.4668f,14.6171f,15.2502f,5.244f}); - auto range = NDArrayFactory::create('c', {1,2}, {0, 50}); -// auto exp = NDArrayFactory::create('c', {5}, {23, 19, 20, 23, 15}); // 23, 15, 24, 17, 21 - auto exp = NDArrayFactory::create('c', {5}, {23, 15, 24, 17, 21}); - - sd::ops::histogram_fixed_width op; - auto results = op.evaluate({&input, &range}, {}, {5}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto *out = results.at(0); - - ASSERT_TRUE(exp.isSameShape(out)); - // out->printBuffer("5HIST"); - ASSERT_TRUE(exp.equalsTo(out)); + auto input = NDArrayFactory::create( + 'c', {5, 20}, + {20.f, 0.f, 60.f, 40.f, 20.f, 0.f, 40.f, + 0.f, 40.f, 40.f, 40.f, 60.f, 20.f, 20.f, + 60.f, 0.f, 40.f, 46.5971f, 1.6203f, 23.6381f, 38.9661f, + 50.8146f, 17.2482f, 8.0429f, 7.5666f, 7.9709f, 21.8403f, 20.1694f, + 23.3004f, 50.9151f, 46.239f, 38.7323f, 29.6946f, 32.9876f, 23.0013f, + 39.7318f, 19.4486f, 37.6147f, -0.1506f, 5.3246f, 3.6173f, 24.2573f, + 4.3941f, 9.7105f, 24.0364f, 35.3681f, 17.7805f, 35.7681f, 16.4144f, + 17.4362f, 8.4987f, 26.8108f, 36.2937f, 31.6442f, 29.7221f, 8.7445f, + 33.3301f, 4.0939f, 13.078f, 45.1481f, 29.0172f, 21.6548f, 35.408f, + 27.1861f, 2.2576f, 40.6804f, 36.2201f, 29.7352f, 29.1244f, 38.7444f, + 5.8721f, 33.5983f, 48.2694f, 34.4161f, 19.7148f, 13.8085f, 13.6075f, + 22.5042f, 37.8002f, 50.0543f, 48.5314f, 20.3694f, 28.5042f, -0.4679f, + 4.4245f, 18.9837f, 40.7724f, 2.7611f, 44.0431f, 37.186f, 27.7361f, + 14.6001f, 9.1721f, 14.6087f, 21.4072f, 49.3344f, 11.4668f, 14.6171f, + 15.2502f, 5.244f}); + auto range = NDArrayFactory::create('c', {1, 2}, {0, 50}); + // auto exp = NDArrayFactory::create('c', {5}, {23, 19, 20, 23, + // 15}); // 23, 15, 24, 17, 21 + auto exp = NDArrayFactory::create('c', {5}, {23, 15, 24, 17, 21}); + + sd::ops::histogram_fixed_width op; + auto results = op.evaluate({&input, &range}, {}, {5}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto out = results.at(0); + + ASSERT_TRUE(exp.isSameShape(out)); + // out->printBuffer("5HIST"); + ASSERT_TRUE(exp.equalsTo(out)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, histogram_fixed_width_test6) { + auto input = NDArrayFactory::create( + 'c', {7}, {0.0, 0.1, 0.1, 0.3, 0.5, 0.5, 0.9}); + auto range = NDArrayFactory::create('c', {2}, {0, 1}); + auto bins = NDArrayFactory::create(5); - auto input = NDArrayFactory::create('c', {7},{0.0, 0.1, 0.1, 0.3, 0.5, 0.5, 0.9}); - auto range = NDArrayFactory::create('c', {2}, {0, 1}); - auto bins = NDArrayFactory::create(5); - - auto exp = NDArrayFactory::create('c', {5}, {3, 1, 2, 0, 1}); + auto exp = NDArrayFactory::create('c', {5}, {3, 1, 2, 0, 1}); - sd::ops::histogram_fixed_width op; - auto results = op.evaluate({&input, &range, &bins}, {}, {}, {}); + sd::ops::histogram_fixed_width op; + auto results = op.evaluate({&input, &range, &bins}, {}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto out = results.at(0); - // out->printShapeInfo(); - // out->printIndexedBuffer(); + auto out = results.at(0); + // out->printShapeInfo(); + // out->printIndexedBuffer(); - ASSERT_TRUE(exp.isSameShape(out)); - ASSERT_TRUE(exp.equalsTo(out)); + ASSERT_TRUE(exp.isSameShape(out)); + ASSERT_TRUE(exp.equalsTo(out)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, NTH_Element_Test_1) { + NDArray input = NDArrayFactory::create( + 'c', {12}, {10, 1, 9, 8, 11, 7, 6, 5, 12, 3, 2, 4}); + NDArray n = NDArrayFactory::create(4.f); + NDArray exp = NDArrayFactory::create(5.f); - NDArray input = NDArrayFactory::create('c', {12}, {10, 1, 9, 8, 11, 7, 6, 5, 12, 3, 2, 4}); - NDArray n = NDArrayFactory::create(4.f); - NDArray exp = NDArrayFactory::create(5.f); + // input.linspace(1.f); - //input.linspace(1.f); + sd::ops::nth_element op; + auto results = op.evaluate({&input, &n}, {}, {}); - sd::ops::nth_element op; - auto results = op.evaluate({&input, &n}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - NDArray* output = results.at(0); + auto output = results.at(0); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, NTH_Element_Test_2) { + NDArray input = NDArrayFactory::create( + 'c', {3, 4}, {10, 11, 9, 12, 8, 7, 6, 5, 1, 3, 2, 4}); + NDArray n = NDArrayFactory::create(3); + NDArray exp = NDArrayFactory::create({12.f, 8.f, 4.f}); - NDArray input = NDArrayFactory::create('c', {3, 4}, {10, 11, 9, 12, 8, 7, 6, 5, 1, 3, 2, 4}); - NDArray n = NDArrayFactory::create(3); - NDArray exp = NDArrayFactory::create({12.f, 8.f, 4.f}); - -// input.linspace(1.f); + // input.linspace(1.f); - sd::ops::nth_element op; - auto results = op.evaluate({&input, &n}, {}, {}); + sd::ops::nth_element op; + auto results = op.evaluate({&input, &n}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - NDArray* output = results.at(0); + auto output = results.at(0); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, NTH_Element_Test_3) { + NDArray input = NDArrayFactory::create( + 'c', {3, 4}, {10, 1, 9, 8, 11, 7, 6, 5, 12, 3, 2, 4}); + NDArray n = NDArrayFactory::create(3); + NDArray exp = NDArrayFactory::create({1.f, 5.f, 2.f}); - NDArray input = NDArrayFactory::create('c', {3,4}, {10, 1, 9, 8, 11, 7, 6, 5, 12, 3, 2, 4}); - NDArray n = NDArrayFactory::create(3); - NDArray exp = NDArrayFactory::create({1.f, 5.f, 2.f}); - - //input.linspace(1.f); + // input.linspace(1.f); - sd::ops::nth_element op; - auto results = op.evaluate({&input, &n}, {}, {1}); // with reverse = true + sd::ops::nth_element op; + auto results = op.evaluate({&input, &n}, {}, {1}); // with reverse = true - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - NDArray* output = results.at(0); + auto output = results.at(0); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, NTH_Element_Test_4) { + NDArray input = NDArrayFactory::create( + 'c', {2, 2, 3}, {10, 1, 9, 8, 11, 7, 6, 5, 12, 3, 2, 4}); + NDArray n = NDArrayFactory::create(2); + NDArray exp = + NDArrayFactory::create('c', {2, 2}, {10.f, 11.f, 12.f, 4.f}); - NDArray input = NDArrayFactory::create('c', {2, 2, 3}, {10, 1, 9, 8, 11, 7, 6, 5, 12, 3, 2, 4}); - NDArray n = NDArrayFactory::create(2); - NDArray exp = NDArrayFactory::create('c', {2,2}, {10.f, 11.f, 12.f, 4.f}); + // input.linspace(1.f); - //input.linspace(1.f); + sd::ops::nth_element op; + auto results = op.evaluate({&input, &n}, {}, {}); - sd::ops::nth_element op; - auto results = op.evaluate({&input, &n}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - NDArray* output = results.at(0); + auto output = results.at(0); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, NTH_Element_Test_04) { + NDArray input = NDArrayFactory::create('c', {6, 15}); + NDArray n = NDArrayFactory::create(4); + NDArray exp = NDArrayFactory::create( + 'c', {6}, {5.f, 20.f, 35.f, 50.f, 65.f, 80.f}); - NDArray input = NDArrayFactory::create('c', {6, 15}); - NDArray n = NDArrayFactory::create(4); - NDArray exp = NDArrayFactory::create('c', {6}, {5.f, 20.f, 35.f, 50.f, 65.f, 80.f}); - - input.linspace(1.f); + input.linspace(1.f); - sd::ops::nth_element op; - auto results = op.evaluate({&input, &n}, {}, {}); + sd::ops::nth_element op; + auto results = op.evaluate({&input, &n}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - NDArray* output = results.at(0); + auto output = results.at(0); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, NTH_Element_Test_5) { + NDArray input = NDArrayFactory::create( + 'c', {2, 2, 3}, {10, 1, 9, 8, 11, 7, 6, 5, 12, 3, 2, 4}); + NDArray n = NDArrayFactory::create(2); + NDArray exp = + NDArrayFactory::create('c', {2, 2}, {1.f, 7.f, 5.f, 2.f}); - NDArray input = NDArrayFactory::create('c', {2, 2, 3}, {10, 1, 9, 8, 11, 7, 6, 5, 12, 3, 2, 4}); - NDArray n = NDArrayFactory::create(2); - NDArray exp = NDArrayFactory::create('c', {2,2}, {1.f, 7.f, 5.f, 2.f}); - -// input.linspace(1.f); + // input.linspace(1.f); - sd::ops::nth_element op; - auto results = op.evaluate({&input, &n}, {}, {1}); + sd::ops::nth_element op; + auto results = op.evaluate({&input, &n}, {}, {1}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - NDArray* output = results.at(0); + auto output = results.at(0); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, NTH_Element_Test_6) { + NDArray input = NDArrayFactory::create( + 'c', {12}, {10, 1, 9, 8, 11, 7, 6, 5, 12, 3, 2, 4}); + NDArray n = NDArrayFactory::create(0); + NDArray exp = + NDArrayFactory::create(1.f); // NDArrayFactory::create('c', + // {2,2}, {1.f, 4.f, 7.f, 10.f}); - NDArray input = NDArrayFactory::create('c', {12}, {10, 1, 9, 8, 11, 7, 6, 5, 12, 3, 2, 4}); - NDArray n = NDArrayFactory::create(0); - NDArray exp = NDArrayFactory::create(1.f);//NDArrayFactory::create('c', {2,2}, {1.f, 4.f, 7.f, 10.f}); + // input.linspace(1.f); -// input.linspace(1.f); + sd::ops::nth_element op; + auto results = op.evaluate({&input, &n}, {}, {0}); - sd::ops::nth_element op; - auto results = op.evaluate({&input, &n}, {}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - NDArray* output = results.at(0); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + auto output = results.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, NTH_Element_Test_06) { + NDArray input = NDArrayFactory::create( + 'c', {12}, {10, 1, 9, 8, 11, 7, 6, 5, 12, 3, 2, 4}); + NDArray n = NDArrayFactory::create(4); + NDArray exp = + NDArrayFactory::create(8.f); // NDArrayFactory::create('c', + // {2,2}, {1.f, 4.f, 7.f, 10.f}); - NDArray input = NDArrayFactory::create('c', {12}, {10, 1, 9, 8, 11, 7, 6, 5, 12, 3, 2, 4}); - NDArray n = NDArrayFactory::create(4); - NDArray exp = NDArrayFactory::create(8.f);//NDArrayFactory::create('c', {2,2}, {1.f, 4.f, 7.f, 10.f}); - -// input.linspace(1.f); + // input.linspace(1.f); - sd::ops::nth_element op; - auto results = op.evaluate({&input, &n}, {}, {1}); + sd::ops::nth_element op; + auto results = op.evaluate({&input, &n}, {}, {1}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - NDArray* output = results.at(0); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + auto output = results.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, NTH_Element_Test_7) { + NDArray input = NDArrayFactory::create( + 'c', {2, 3, 4}, {0.7788f, 0.8012f, 0.7244f, 0.2309f, 0.7271f, 0.1804f, + 0.5056f, 0.8925f, 0.5461f, 0.9234f, 0.0856f, 0.7938f, - NDArray input = NDArrayFactory::create('c', {2, 3, 4}, {0.7788f, 0.8012f, 0.7244f, 0.2309f, - 0.7271f, 0.1804f, 0.5056f, 0.8925f, - 0.5461f, 0.9234f, 0.0856f, 0.7938f, - - 0.6591f, 0.5555f, 0.1596f, 0.3087f, - 0.1548f, 0.4695f, 0.9939f, 0.6113f, - 0.6765f, 0.1800f, 0.6750f, 0.2246f}); - NDArray n = NDArrayFactory::create(2); - NDArray exp = NDArrayFactory::create('c', {2,3}, {0.7788f, 0.7271f, 0.7938f, 0.5555f, 0.6113f, 0.675f}); + 0.6591f, 0.5555f, 0.1596f, 0.3087f, 0.1548f, 0.4695f, + 0.9939f, 0.6113f, 0.6765f, 0.1800f, 0.6750f, 0.2246f}); + NDArray n = NDArrayFactory::create(2); + NDArray exp = NDArrayFactory::create( + 'c', {2, 3}, {0.7788f, 0.7271f, 0.7938f, 0.5555f, 0.6113f, 0.675f}); - //input.linspace(1.f); + // input.linspace(1.f); - sd::ops::nth_element op; - auto results = op.evaluate({&input, &n}, {}, {0}); + sd::ops::nth_element op; + auto results = op.evaluate({&input, &n}, {}, {0}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - NDArray* output = results.at(0); + auto output = results.at(0); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, NTH_Element_Test_8) { + NDArray input = NDArrayFactory::create( + 'c', {2, 3, 4}, {0.7788f, 0.8012f, 0.7244f, 0.2309f, 0.7271f, 0.1804f, + 0.5056f, 0.8925f, 0.5461f, 0.9234f, 0.0856f, 0.7938f, - NDArray input = NDArrayFactory::create('c', {2, 3, 4}, {0.7788f, 0.8012f, 0.7244f, 0.2309f, - 0.7271f, 0.1804f, 0.5056f, 0.8925f, - 0.5461f, 0.9234f, 0.0856f, 0.7938f, + 0.6591f, 0.5555f, 0.1596f, 0.3087f, 0.1548f, 0.4695f, + 0.9939f, 0.6113f, 0.6765f, 0.1800f, 0.6750f, 0.2246f}); + NDArray n = NDArrayFactory::create(2); + NDArray exp = NDArrayFactory::create( + 'c', {2, 3}, {0.7244f, 0.5056f, 0.5461f, 0.3087f, 0.4695f, 0.2246f}); - 0.6591f, 0.5555f, 0.1596f, 0.3087f, - 0.1548f, 0.4695f, 0.9939f, 0.6113f, - 0.6765f, 0.1800f, 0.6750f, 0.2246f}); - NDArray n = NDArrayFactory::create(2); - NDArray exp = NDArrayFactory::create('c', {2,3}, {0.7244f, 0.5056f, 0.5461f, 0.3087f, 0.4695f, 0.2246f}); + // input.linspace(1.f); - //input.linspace(1.f); + sd::ops::nth_element op; + auto results = op.evaluate({&input, &n}, {}, {1}); - sd::ops::nth_element op; - auto results = op.evaluate({&input, &n}, {}, {1}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - NDArray* output = results.at(0); + auto output = results.at(0); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, broadcast_to_test1) { + auto input = NDArrayFactory::create('c', {3}); + auto shape = NDArrayFactory::create('c', {2}, {3, 3}); + auto exp = NDArrayFactory::create('c', {3, 3}, + {1, 2, 3, 1, 2, 3, 1, 2, 3}); - auto input = NDArrayFactory::create('c', {3}); - auto shape = NDArrayFactory::create('c', {2}, {3, 3}); - auto exp = NDArrayFactory::create('c', {3,3}, {1, 2, 3,1, 2, 3, 1, 2, 3}); - - input.linspace(1.f); + input.linspace(1.f); - sd::ops::broadcast_to op; - auto results = op.evaluate({&input, &shape}, {}, {}, {}); + sd::ops::broadcast_to op; + auto results = op.evaluate({&input, &shape}, {}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *output = results.at(0); + auto output = results.at(0); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, broadcast_to_test2) { + auto input = NDArrayFactory::create('c', {1, 3}); + auto shape = NDArrayFactory::create('c', {2}, {3.f, 3.f}); + auto exp = NDArrayFactory::create( + 'c', {3, 3}, {1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f}); - auto input = NDArrayFactory::create('c', {1,3}); - auto shape = NDArrayFactory::create('c', {2}, {3.f, 3.f}); - auto exp = NDArrayFactory::create('c', {3,3}, {1.f, 2.f, 3.f,1.f, 2.f, 3.f,1.f, 2.f, 3.f}); - - input.linspace(1.f); + input.linspace(1.f); - sd::ops::broadcast_to op; - auto results = op.evaluate({&input, &shape}, {}, {}, {}); + sd::ops::broadcast_to op; + auto results = op.evaluate({&input, &shape}, {}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *output = results.at(0); + auto output = results.at(0); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, broadcast_to_test3) { + auto input = NDArrayFactory::create('c', {3, 1}); + auto shape = NDArrayFactory::create('c', {2}, {3.f, 3.f}); + auto exp = NDArrayFactory::create( + 'c', {3, 3}, {1.f, 1.f, 1.f, 2.f, 2.f, 2.f, 3.f, 3.f, 3.f}); - auto input = NDArrayFactory::create('c', {3,1}); - auto shape = NDArrayFactory::create('c', {2}, {3.f, 3.f}); - auto exp = NDArrayFactory::create('c', {3,3}, {1.f, 1.f, 1.f,2.f, 2.f, 2.f,3.f, 3.f, 3.f}); + input.linspace(1.f); - input.linspace(1.f); + sd::ops::broadcast_to op; + auto results = op.evaluate({&input, &shape}, {}, {}, {}); - sd::ops::broadcast_to op; - auto results = op.evaluate({&input, &shape}, {}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto *output = results.at(0); + auto output = results.at(0); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, broadcast_to_test4) { + auto input = NDArrayFactory::create(10.); + auto shape = NDArrayFactory::create('c', {2}, {3.f, 3.f}); + auto exp = NDArrayFactory::create( + 'c', {3, 3}, {10.f, 10.f, 10.f, 10.f, 10.f, 10.f, 10.f, 10.f, 10.f}); - auto input = NDArrayFactory::create(10.); - auto shape = NDArrayFactory::create('c', {2}, {3.f, 3.f}); - auto exp = NDArrayFactory::create('c', {3,3}, {10.f, 10.f, 10.f,10.f, 10.f, 10.f, 10.f, 10.f, 10.f}); - - sd::ops::broadcast_to op; - auto results = op.evaluate({&input, &shape}, {}, {}, {}); + sd::ops::broadcast_to op; + auto results = op.evaluate({&input, &shape}, {}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *output = results.at(0); + auto output = results.at(0); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, broadcast_to_test5) { + auto input = NDArrayFactory::create(10.f); + auto shape = NDArrayFactory::create('c', {1}, {3.f}); + auto exp = NDArrayFactory::create('c', {3}, {10.f, 10.f, 10.f}); - auto input = NDArrayFactory::create(10.f); - auto shape = NDArrayFactory::create('c', {1}, {3.f}); - auto exp = NDArrayFactory::create('c', {3}, {10.f, 10.f, 10.f}); - - sd::ops::broadcast_to op; - auto results = op.evaluate({&input, &shape}, {}, {}, {}); + sd::ops::broadcast_to op; + auto results = op.evaluate({&input, &shape}, {}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *output = results.at(0); + auto output = results.at(0); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, broadcast_to_test6) { + auto input = NDArrayFactory::create(10.f); + auto shape = NDArrayFactory::create(1.f); + auto exp = NDArrayFactory::create('c', {1}, {10.f}); - auto input = NDArrayFactory::create(10.f); - auto shape = NDArrayFactory::create(1.f); - auto exp = NDArrayFactory::create('c', {1}, {10.f}); + sd::ops::broadcast_to op; + auto results = op.evaluate({&input, &shape}, {}, {}, {}); - sd::ops::broadcast_to op; - auto results = op.evaluate({&input, &shape}, {}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto *output = results.at(0); + auto output = results.at(0); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, broadcast_to_test7) { + auto input = NDArrayFactory::create(10.f); + auto shape = NDArrayFactory::create(1); + auto exp = NDArrayFactory::create('c', {1}, {10.}); - auto input = NDArrayFactory::create(10.f); - auto shape = NDArrayFactory::create(1); - auto exp = NDArrayFactory::create('c', {1}, {10.}); - - sd::ops::broadcast_to op; - auto results = op.evaluate({&input, &shape}, {}, {}, {}); + sd::ops::broadcast_to op; + auto results = op.evaluate({&input, &shape}, {}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *output = results.at(0); + auto output = results.at(0); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, broadcast_to_test8) { + auto input = NDArrayFactory::create('c', {3}); + auto shape = NDArrayFactory::create('c', {3}, {1.f, 3.f, 3.f}); + auto exp = NDArrayFactory::create( + 'c', {1, 3, 3}, {1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f}); - auto input = NDArrayFactory::create('c', {3}); - auto shape = NDArrayFactory::create('c', {3}, {1.f, 3.f, 3.f}); - auto exp = NDArrayFactory::create('c', {1,3,3}, {1.f, 2.f, 3.f,1.f, 2.f, 3.f,1.f, 2.f, 3.f}); - - input.linspace(1.f); + input.linspace(1.f); - sd::ops::broadcast_to op; - auto results = op.evaluate({&input, &shape}, {}, {}, {}); + sd::ops::broadcast_to op; + auto results = op.evaluate({&input, &shape}, {}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *output = results.at(0); + auto output = results.at(0); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, broadcast_to_test9) { + auto input = NDArrayFactory::create('c', {5, 1, 1}); + auto shape = + NDArrayFactory::create('c', {5}, {2.f, 1.f, 5.f, 1.f, 3.f}); + auto exp = NDArrayFactory::create( + 'c', {2, 1, 5, 1, 3}, {1.f, 1.f, 1.f, 2.f, 2.f, 2.f, 3.f, 3.f, 3.f, 4.f, + 4.f, 4.f, 5.f, 5.f, 5.f, 1.f, 1.f, 1.f, 2.f, 2.f, + 2.f, 3.f, 3.f, 3.f, 4.f, 4.f, 4.f, 5.f, 5.f, 5.f}); + input.linspace(1.f); - auto input = NDArrayFactory::create('c', {5,1,1}); - auto shape = NDArrayFactory::create('c', {5}, {2.f,1.f,5.f,1.f,3.f}); - auto exp = NDArrayFactory::create('c', {2,1,5,1,3}, {1.f, 1.f, 1.f,2.f, 2.f, 2.f,3.f, 3.f, 3.f,4.f, 4.f, 4.f,5.f, 5.f, 5.f, - 1.f, 1.f, 1.f,2.f, 2.f, 2.f,3.f, 3.f, 3.f,4.f, 4.f, 4.f,5.f, 5.f, 5.f}); - input.linspace(1.f); + sd::ops::broadcast_to op; + auto results = op.evaluate({&input, &shape}, {}, {}, {}); - sd::ops::broadcast_to op; - auto results = op.evaluate({&input, &shape}, {}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto *output = results.at(0); + auto output = results.at(0); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, broadcast_to_test10) { + auto input = NDArrayFactory::create('c', {5, 1, 3}); + auto shape = + NDArrayFactory::create('c', {5}, {2.f, 1.f, 5.f, 1.f, 3.f}); + auto exp = NDArrayFactory::create( + 'c', {2, 1, 5, 1, 3}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, + 11.f, 12.f, 13.f, 14.f, 15.f, 1.f, 2.f, 3.f, 4.f, 5.f, + 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f}); + input.linspace(1.f); - auto input = NDArrayFactory::create('c', {5,1,3}); - auto shape = NDArrayFactory::create('c', {5}, {2.f,1.f,5.f,1.f,3.f}); - auto exp = NDArrayFactory::create('c', {2,1,5,1,3}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f,10.f, 11.f, 12.f,13.f, 14.f, 15.f, - 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f,10.f, 11.f, 12.f,13.f, 14.f, 15.f}); - input.linspace(1.f); + sd::ops::broadcast_to op; + auto results = op.evaluate({&input, &shape}, {}, {}, {}); - sd::ops::broadcast_to op; - auto results = op.evaluate({&input, &shape}, {}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto output = results.at(0); - auto *output = results.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test1) { - - NDArray input = NDArrayFactory::create('c', {1, 2, 3, 4}); - //NDArray paddings('c', {3,2}, {0,0, 0,1, 0,0}); - //NDArray expected('c', {2,4,4}, {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.}); - NDArray expected = NDArrayFactory::create('c', {1, 10, 10, 4}, {1., 2., 3., 4., 2.2, 3.2, 4.2, 5.2, 3.4, 4.4, 5.4, 6.4, - 4.6, 5.6, 6.6, 7.6, 5.8, 6.8, 7.8, 8.8, 7., 8., 9., 10., - 8.2, 9.2, 10.2, 11.2, 9., 10., 11., 12., 9., 10., 11., 12., - 9., 10., 11., 12., 3.4, 4.4, 5.4, 6.4, 4.6, 5.6, 6.6, 7.6, - 5.8, 6.8, 7.8, 8.8, 7.0, 8., 9., 10., 8.2, 9.2, 10.2, 11.2, - 9.4,10.4, 11.4, 12.4,10.6, 11.6,12.6, 13.6,11.4, 12.4, 13.4, 14.4, - 11.4,12.4, 13.4, 14.4,11.4, 12.4,13.4, 14.4, 5.8, 6.8, 7.8, 8.8, - 7., 8., 9., 10., 8.2, 9.2,10.2, 11.2, 9.4, 10.4, 11.4, 12.4, - 10.6,11.6, 12.6, 13.6,11.8, 12.8,13.8, 14.8,13.0, 14.0, 15.0, 16., - 13.8,14.8, 15.8, 16.8,13.8, 14.8,15.8, 16.8,13.8, 14.8, 15.8, 16.8, - 8.2, 9.2, 10.2, 11.2, 9.4, 10.4,11.4, 12.4,10.6, 11.6, 12.6, 13.6, - 11.8,12.8, 13.8, 14.8,13., 14., 15., 16., 14.2, 15.2, 16.2, 17.2, - 15.4,16.4, 17.4, 18.4,16.2, 17.2,18.2, 19.2,16.2, 17.2, 18.2, 19.2, - 16.2,17.2, 18.2, 19.2,10.6, 11.6,12.6, 13.6,11.8, 12.8, 13.8, 14.8, - 13., 14., 15., 16., 14.2, 15.2,16.2, 17.2,15.4, 16.4, 17.4, 18.4, - 16.6,17.6, 18.6, 19.6,17.8, 18.8,19.8, 20.8,18.6, 19.6, 20.6, 21.6, - 18.6,19.6, 20.6, 21.6,18.6, 19.6,20.6, 21.6,13., 14., 15., 16., - 14.2,15.2, 16.2, 17.2,15.4, 16.4,17.4, 18.4,16.6, 17.6, 18.6, 19.6, - 17.8,18.8, 19.8, 20.8,19., 20., 21., 22., 20.2, 21.2, 22.2, 23.2, - 21., 22., 23., 24., 21., 22., 23., 24., 21., 22., 23., 24., - 13., 14., 15., 16., 14.2, 15.2,16.2, 17.2,15.4, 16.4, 17.4, 18.4, - 16.6,17.6, 18.6, 19.6,17.8, 18.8, 19.8, 20.8,19., 20., 21., 22., - 20.2,21.2, 22.2, 23.2,21., 22., 23., 24., 21., 22., 23., 24., - 21., 22., 23., 24., 13., 14., 15., 16., 14.2, 15.2, 16.2, 17.2, - 15.4,16.4, 17.4, 18.4,16.6, 17.6, 18.6, 19.6,17.8, 18.8, 19.8, 20.8, - 19., 20., 21., 22., 20.2, 21.2, 22.2, 23.2,21., 22., 23., 24., - 21., 22., 23., 24., 21., 22., 23., 24., 13., 14., 15., 16., - 14.2,15.2, 16.2, 17.2,15.4, 16.4, 17.4, 18.4,16.6, 17.6, 18.6, 19.6, - 17.8,18.8, 19.8, 20.8,19., 20., 21., 22., 20.2, 21.2, 22.2, 23.2, - 21., 22., 23., 24., 21., 22., 23., 24., 21., 22., 23., 24., - 13., 14., 15., 16., 14.2, 15.2, 16.2, 17.2,15.4, 16.4, 17.4, 18.4, - 16.6,17.6, 18.6, 19.6,17.8, 18.8, 19.8, 20.8,19., 20., 21., 22., - 20.2,21.2, 22.2, 23.2, - 21., 22., 23., 24., 21., 22., 23., 24., 21., 22., 23., 24.}); - //input = 1.f; - input.linspace(1); - - sd::ops::resize_bilinear op; - auto results = op.evaluate({&input}, {}, {10, 10}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - NDArray* result = results.at(0); - - //result.printIndexedBuffer("Resized to 10x10"); - //expected.printIndexedBuffer("Expect for 10x10"); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); - + NDArray input = NDArrayFactory::create('c', {1, 2, 3, 4}); + // NDArray paddings('c', {3,2}, {0,0, 0,1, 0,0}); + // NDArray expected('c', {2,4,4}, + // {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.}); + NDArray expected = NDArrayFactory::create( + 'c', {1, 10, 10, 4}, + {1., 2., 3., 4., 2.2, 3.2, 4.2, 5.2, 3.4, 4.4, 5.4, 6.4, + 4.6, 5.6, 6.6, 7.6, 5.8, 6.8, 7.8, 8.8, 7., 8., 9., 10., + 8.2, 9.2, 10.2, 11.2, 9., 10., 11., 12., 9., 10., 11., 12., + 9., 10., 11., 12., 3.4, 4.4, 5.4, 6.4, 4.6, 5.6, 6.6, 7.6, + 5.8, 6.8, 7.8, 8.8, 7.0, 8., 9., 10., 8.2, 9.2, 10.2, 11.2, + 9.4, 10.4, 11.4, 12.4, 10.6, 11.6, 12.6, 13.6, 11.4, 12.4, 13.4, 14.4, + 11.4, 12.4, 13.4, 14.4, 11.4, 12.4, 13.4, 14.4, 5.8, 6.8, 7.8, 8.8, + 7., 8., 9., 10., 8.2, 9.2, 10.2, 11.2, 9.4, 10.4, 11.4, 12.4, + 10.6, 11.6, 12.6, 13.6, 11.8, 12.8, 13.8, 14.8, 13.0, 14.0, 15.0, 16., + 13.8, 14.8, 15.8, 16.8, 13.8, 14.8, 15.8, 16.8, 13.8, 14.8, 15.8, 16.8, + 8.2, 9.2, 10.2, 11.2, 9.4, 10.4, 11.4, 12.4, 10.6, 11.6, 12.6, 13.6, + 11.8, 12.8, 13.8, 14.8, 13., 14., 15., 16., 14.2, 15.2, 16.2, 17.2, + 15.4, 16.4, 17.4, 18.4, 16.2, 17.2, 18.2, 19.2, 16.2, 17.2, 18.2, 19.2, + 16.2, 17.2, 18.2, 19.2, 10.6, 11.6, 12.6, 13.6, 11.8, 12.8, 13.8, 14.8, + 13., 14., 15., 16., 14.2, 15.2, 16.2, 17.2, 15.4, 16.4, 17.4, 18.4, + 16.6, 17.6, 18.6, 19.6, 17.8, 18.8, 19.8, 20.8, 18.6, 19.6, 20.6, 21.6, + 18.6, 19.6, 20.6, 21.6, 18.6, 19.6, 20.6, 21.6, 13., 14., 15., 16., + 14.2, 15.2, 16.2, 17.2, 15.4, 16.4, 17.4, 18.4, 16.6, 17.6, 18.6, 19.6, + 17.8, 18.8, 19.8, 20.8, 19., 20., 21., 22., 20.2, 21.2, 22.2, 23.2, + 21., 22., 23., 24., 21., 22., 23., 24., 21., 22., 23., 24., + 13., 14., 15., 16., 14.2, 15.2, 16.2, 17.2, 15.4, 16.4, 17.4, 18.4, + 16.6, 17.6, 18.6, 19.6, 17.8, 18.8, 19.8, 20.8, 19., 20., 21., 22., + 20.2, 21.2, 22.2, 23.2, 21., 22., 23., 24., 21., 22., 23., 24., + 21., 22., 23., 24., 13., 14., 15., 16., 14.2, 15.2, 16.2, 17.2, + 15.4, 16.4, 17.4, 18.4, 16.6, 17.6, 18.6, 19.6, 17.8, 18.8, 19.8, 20.8, + 19., 20., 21., 22., 20.2, 21.2, 22.2, 23.2, 21., 22., 23., 24., + 21., 22., 23., 24., 21., 22., 23., 24., 13., 14., 15., 16., + 14.2, 15.2, 16.2, 17.2, 15.4, 16.4, 17.4, 18.4, 16.6, 17.6, 18.6, 19.6, + 17.8, 18.8, 19.8, 20.8, 19., 20., 21., 22., 20.2, 21.2, 22.2, 23.2, + 21., 22., 23., 24., 21., 22., 23., 24., 21., 22., 23., 24., + 13., 14., 15., 16., 14.2, 15.2, 16.2, 17.2, 15.4, 16.4, 17.4, 18.4, + 16.6, 17.6, 18.6, 19.6, 17.8, 18.8, 19.8, 20.8, 19., 20., 21., 22., + 20.2, 21.2, 22.2, 23.2, 21., 22., 23., 24., 21., 22., 23., 24., + 21., 22., 23., 24.}); + // input = 1.f; + input.linspace(1); + + sd::ops::resize_bilinear op; + auto results = op.evaluate({&input}, {}, {10, 10}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + + // result.printIndexedBuffer("Resized to 10x10"); + // expected.printIndexedBuffer("Expect for 10x10"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test_11) { + NDArray input = NDArrayFactory::create('c', {1, 1, 1, 256}); - NDArray input = NDArrayFactory::create('c', {1, 1, 1, 256}); + input.assign(0.8f); // linspace(1); + auto size = NDArrayFactory::create({65, 65}); + auto ex = NDArrayFactory::create('c', {1, 65, 65, 256}); + sd::ops::resize_bilinear op; + auto results = op.evaluate({&input, &size}, {}, {}, {false}); - input.assign(0.8f); //linspace(1); - auto size = NDArrayFactory::create({65,65}); - auto ex = NDArrayFactory::create('c', {1,65,65,256}); - sd::ops::resize_bilinear op; - auto results = op.evaluate({&input, &size}, {}, {}, {false}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - NDArray* result = results.at(0); - ASSERT_NE(*result, ex); + auto result = results.at(0); + ASSERT_NE(result, ex); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test_12) { + NDArray input = NDArrayFactory::create('c', {1, 1, 1, 256}); - NDArray input = NDArrayFactory::create('c', {1, 1, 1, 256}); - - input.assign(0.8f); //linspace(1); - auto size = NDArrayFactory::create({65,65}); - auto ex = NDArrayFactory::create('c', {1,65,65,256}); - sd::ops::resize_bilinear op; - auto results = op.evaluate({&input, &size}, {}, {}, {true}); + input.assign(0.8f); // linspace(1); + auto size = NDArrayFactory::create({65, 65}); + auto ex = NDArrayFactory::create('c', {1, 65, 65, 256}); + sd::ops::resize_bilinear op; + auto results = op.evaluate({&input, &size}, {}, {}, {true}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - NDArray* result = results.at(0); - ASSERT_NE(*result, ex); + auto result = results.at(0); + ASSERT_NE(result, ex); } TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test1_1) { + NDArray input = NDArrayFactory::create('c', {1, 2, 3, 4}); + // NDArray paddings('c', {3,2}, {0,0, 0,1, 0,0}); + // NDArray expected('c', {2,4,4}, + // {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.}); + NDArray expected = NDArrayFactory::create( + 'c', {1, 4, 5, 4}, + {1., 2., 3., 4., 2.6, 3.6, 4.6, 5.6, 5., 6., + 7., 8., 7.4, 8.4, 9.4, 10.4, 9., 10., 11., 12., - NDArray input = NDArrayFactory::create('c', {1, 2, 3, 4}); - //NDArray paddings('c', {3,2}, {0,0, 0,1, 0,0}); - //NDArray expected('c', {2,4,4}, {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.}); - NDArray expected = NDArrayFactory::create('c', {1, 4, 5, 4}, { - 1., 2., 3., 4., - 2.6, 3.6, 4.6, 5.6, - 5., 6., 7., 8., - 7.4, 8.4, 9.4, 10.4, - 9., 10., 11., 12., - - 4., 5., 6., 7., - 5.6, 6.6, 7.6, 8.6, - 8., 9., 10., 11., - 10.4, 11.4, 12.4, 13.4, - 12., 13., 14., 15., - - 10., 11., 12., 13., - 11.6, 12.6, 13.6, 14.6, - 14., 15., 16., 17., - 16.4, 17.4, 18.4, 19.4, - 18., 19., 20., 21., - - 13., 14., 15., 16., - 14.6, 15.6, 16.6, 17.6, - 17., 18., 19., 20., - 19.4, 20.4, 21.4, 22.4, - 21., 22., 23., 24. - }); - //input = 1.f; - input.linspace(1); + 4., 5., 6., 7., 5.6, 6.6, 7.6, 8.6, 8., 9., + 10., 11., 10.4, 11.4, 12.4, 13.4, 12., 13., 14., 15., - sd::ops::resize_bilinear op; - auto results = op.evaluate({&input}, {}, {4, 5}, {false, true}); + 10., 11., 12., 13., 11.6, 12.6, 13.6, 14.6, 14., 15., + 16., 17., 16.4, 17.4, 18.4, 19.4, 18., 19., 20., 21., - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + 13., 14., 15., 16., 14.6, 15.6, 16.6, 17.6, 17., 18., + 19., 20., 19.4, 20.4, 21.4, 22.4, 21., 22., 23., 24.}); + // input = 1.f; + input.linspace(1); - NDArray* result = results.at(0); + sd::ops::resize_bilinear op; + auto results = op.evaluate({&input}, {}, {4, 5}, {false, true}); -// result.printIndexedBuffer("Resized to 4x5 bilinear with half pixels"); - //expected.printIndexedBuffer("Expect for 10x10"); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + // result.printIndexedBuffer("Resized to 4x5 bilinear with half pixels"); + // expected.printIndexedBuffer("Expect for 10x10"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test1_2) { + NDArray input = NDArrayFactory::create('c', {1, 2, 3, 4}); + // NDArray paddings('c', {3,2}, {0,0, 0,1, 0,0}); + // NDArray expected('c', {2,4,4}, + // {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.}); + NDArray expected = NDArrayFactory::create( + 'c', {1, 4, 5, 4}, + {1.f, 2.f, 3.f, 4.f, 2.6f, 3.6f, 4.6f, 5.6f, 5.f, 6.f, + 7.f, 8.f, 7.4f, 8.4f, 9.4f, 10.4f, 9.f, 10.f, 11.f, 12.f, - NDArray input = NDArrayFactory::create('c', {1, 2, 3, 4}); - //NDArray paddings('c', {3,2}, {0,0, 0,1, 0,0}); - //NDArray expected('c', {2,4,4}, {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.}); - NDArray expected = NDArrayFactory::create('c', {1, 4, 5, 4}, { - 1.f, 2.f, 3.f, 4.f, - 2.6f, 3.6f, 4.6f, 5.6f, - 5.f, 6.f, 7.f, 8.f, - 7.4f, 8.4f, 9.4f, 10.4f, - 9.f, 10.f, 11.f, 12.f, - - 4.f, 5.f, 6.f, 7.f, - 5.6f, 6.6f, 7.6f, 8.6f, - 8.f, 9.f, 10.f, 11.f, - 10.4f, 11.4f, 12.4f, 13.4f, - 12.f, 13.f, 14.f, 15.f, - - 10.f, 11.f, 12.f, 13.f, - 11.6f, 12.6f, 13.6f, 14.6f, - 14.f, 15.f, 16.f, 17.f, - 16.4f, 17.4f, 18.4f, 19.4f, - 18.f, 19.f, 20.f, 21.f, - - 13.f, 14.f, 15.f, 16.f, - 14.6f, 15.6f, 16.6f, 17.6f, - 17.f, 18.f, 19.f, 20.f, - 19.4f, 20.4f, 21.4f, 22.4f, - 21.f, 22.f, 23.f, 24.f - }); - //input = 1.f; - input.linspace(1); + 4.f, 5.f, 6.f, 7.f, 5.6f, 6.6f, 7.6f, 8.6f, 8.f, 9.f, + 10.f, 11.f, 10.4f, 11.4f, 12.4f, 13.4f, 12.f, 13.f, 14.f, 15.f, - sd::ops::resize_bilinear op; - auto results = op.evaluate({&input}, {}, {4, 5}, {false, true}); + 10.f, 11.f, 12.f, 13.f, 11.6f, 12.6f, 13.6f, 14.6f, 14.f, 15.f, + 16.f, 17.f, 16.4f, 17.4f, 18.4f, 19.4f, 18.f, 19.f, 20.f, 21.f, - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + 13.f, 14.f, 15.f, 16.f, 14.6f, 15.6f, 16.6f, 17.6f, 17.f, 18.f, + 19.f, 20.f, 19.4f, 20.4f, 21.4f, 22.4f, 21.f, 22.f, 23.f, 24.f}); + // input = 1.f; + input.linspace(1); - NDArray* result = results.at(0); + sd::ops::resize_bilinear op; + auto results = op.evaluate({&input}, {}, {4, 5}, {false, true}); -// result.printBuffer("Resized to 4x5"); -// expected.printBuffer("Expect for 4x5"); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + // result.printBuffer("Resized to 4x5"); + // expected.printBuffer("Expect for 4x5"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test01) { - - NDArray input = NDArrayFactory::create('c', {2,3,4}); - //NDArray paddings('c', {3,2}, {0,0, 0,1, 0,0}); - //NDArray expected('c', {2,4,4}, {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.}); - NDArray expected = NDArrayFactory::create('c', {10, 10, 4}, {1., 2., 3., 4., 2.2, 3.2, 4.2, 5.2, 3.4, 4.4, 5.4, 6.4, - 4.6, 5.6, 6.6, 7.6, 5.8, 6.8, 7.8, 8.8, 7., 8., 9., 10., - 8.2, 9.2, 10.2, 11.2, 9., 10., 11., 12., 9., 10., 11., 12., - 9., 10., 11., 12., 3.4, 4.4, 5.4, 6.4, 4.6, 5.6, 6.6, 7.6, - 5.8, 6.8, 7.8, 8.8, 7.0, 8., 9., 10., 8.2, 9.2, 10.2, 11.2, - 9.4,10.4, 11.4, 12.4,10.6, 11.6,12.6, 13.6,11.4, 12.4, 13.4, 14.4, - 11.4,12.4, 13.4, 14.4,11.4, 12.4,13.4, 14.4, 5.8, 6.8, 7.8, 8.8, - 7., 8., 9., 10., 8.2, 9.2,10.2, 11.2, 9.4, 10.4, 11.4, 12.4, - 10.6,11.6, 12.6, 13.6,11.8, 12.8,13.8, 14.8,13.0, 14.0, 15.0, 16., - 13.8,14.8, 15.8, 16.8,13.8, 14.8,15.8, 16.8,13.8, 14.8, 15.8, 16.8, - 8.2, 9.2, 10.2, 11.2, 9.4, 10.4,11.4, 12.4,10.6, 11.6, 12.6, 13.6, - 11.8,12.8, 13.8, 14.8,13., 14., 15., 16., 14.2, 15.2, 16.2, 17.2, - 15.4,16.4, 17.4, 18.4,16.2, 17.2,18.2, 19.2,16.2, 17.2, 18.2, 19.2, - 16.2,17.2, 18.2, 19.2,10.6, 11.6,12.6, 13.6,11.8, 12.8, 13.8, 14.8, - 13., 14., 15., 16., 14.2, 15.2,16.2, 17.2,15.4, 16.4, 17.4, 18.4, - 16.6,17.6, 18.6, 19.6,17.8, 18.8,19.8, 20.8,18.6, 19.6, 20.6, 21.6, - 18.6,19.6, 20.6, 21.6,18.6, 19.6,20.6, 21.6,13., 14., 15., 16., - 14.2,15.2, 16.2, 17.2,15.4, 16.4,17.4, 18.4,16.6, 17.6, 18.6, 19.6, - 17.8,18.8, 19.8, 20.8,19., 20., 21., 22., 20.2, 21.2, 22.2, 23.2, - 21., 22., 23., 24., 21., 22., 23., 24., 21., 22., 23., 24., - 13., 14., 15., 16., 14.2, 15.2,16.2, 17.2,15.4, 16.4, 17.4, 18.4, - 16.6,17.6, 18.6, 19.6,17.8, 18.8, 19.8, 20.8,19., 20., 21., 22., - 20.2,21.2, 22.2, 23.2,21., 22., 23., 24., 21., 22., 23., 24., - 21., 22., 23., 24., 13., 14., 15., 16., 14.2, 15.2, 16.2, 17.2, - 15.4,16.4, 17.4, 18.4,16.6, 17.6, 18.6, 19.6,17.8, 18.8, 19.8, 20.8, - 19., 20., 21., 22., 20.2, 21.2, 22.2, 23.2,21., 22., 23., 24., - 21., 22., 23., 24., 21., 22., 23., 24., 13., 14., 15., 16., - 14.2,15.2, 16.2, 17.2,15.4, 16.4, 17.4, 18.4,16.6, 17.6, 18.6, 19.6, - 17.8,18.8, 19.8, 20.8,19., 20., 21., 22., 20.2, 21.2, 22.2, 23.2, - 21., 22., 23., 24., 21., 22., 23., 24., 21., 22., 23., 24., - 13., 14., 15., 16., 14.2, 15.2, 16.2, 17.2,15.4, 16.4, 17.4, 18.4, - 16.6,17.6, 18.6, 19.6,17.8, 18.8, 19.8, 20.8,19., 20., 21., 22., - 20.2,21.2, 22.2, 23.2, - 21., 22., 23., 24., 21., 22., 23., 24., 21., 22., 23., 24.}); - //input = 1.f; - input.linspace(1); - - sd::ops::resize_bilinear op; - auto results = op.evaluate({&input}, {}, {10, 10}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - NDArray* result = results.at(0); - - //result.printIndexedBuffer("Resized to 10x10"); - //expected.printIndexedBuffer("Expect for 10x10"); -// result.printShapeInfo("Output shape"); -// expected.printShapeInfo("Expect shape"); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); - + NDArray input = NDArrayFactory::create('c', {2, 3, 4}); + // NDArray paddings('c', {3,2}, {0,0, 0,1, 0,0}); + // NDArray expected('c', {2,4,4}, + // {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.}); + NDArray expected = NDArrayFactory::create( + 'c', {10, 10, 4}, + {1., 2., 3., 4., 2.2, 3.2, 4.2, 5.2, 3.4, 4.4, 5.4, 6.4, + 4.6, 5.6, 6.6, 7.6, 5.8, 6.8, 7.8, 8.8, 7., 8., 9., 10., + 8.2, 9.2, 10.2, 11.2, 9., 10., 11., 12., 9., 10., 11., 12., + 9., 10., 11., 12., 3.4, 4.4, 5.4, 6.4, 4.6, 5.6, 6.6, 7.6, + 5.8, 6.8, 7.8, 8.8, 7.0, 8., 9., 10., 8.2, 9.2, 10.2, 11.2, + 9.4, 10.4, 11.4, 12.4, 10.6, 11.6, 12.6, 13.6, 11.4, 12.4, 13.4, 14.4, + 11.4, 12.4, 13.4, 14.4, 11.4, 12.4, 13.4, 14.4, 5.8, 6.8, 7.8, 8.8, + 7., 8., 9., 10., 8.2, 9.2, 10.2, 11.2, 9.4, 10.4, 11.4, 12.4, + 10.6, 11.6, 12.6, 13.6, 11.8, 12.8, 13.8, 14.8, 13.0, 14.0, 15.0, 16., + 13.8, 14.8, 15.8, 16.8, 13.8, 14.8, 15.8, 16.8, 13.8, 14.8, 15.8, 16.8, + 8.2, 9.2, 10.2, 11.2, 9.4, 10.4, 11.4, 12.4, 10.6, 11.6, 12.6, 13.6, + 11.8, 12.8, 13.8, 14.8, 13., 14., 15., 16., 14.2, 15.2, 16.2, 17.2, + 15.4, 16.4, 17.4, 18.4, 16.2, 17.2, 18.2, 19.2, 16.2, 17.2, 18.2, 19.2, + 16.2, 17.2, 18.2, 19.2, 10.6, 11.6, 12.6, 13.6, 11.8, 12.8, 13.8, 14.8, + 13., 14., 15., 16., 14.2, 15.2, 16.2, 17.2, 15.4, 16.4, 17.4, 18.4, + 16.6, 17.6, 18.6, 19.6, 17.8, 18.8, 19.8, 20.8, 18.6, 19.6, 20.6, 21.6, + 18.6, 19.6, 20.6, 21.6, 18.6, 19.6, 20.6, 21.6, 13., 14., 15., 16., + 14.2, 15.2, 16.2, 17.2, 15.4, 16.4, 17.4, 18.4, 16.6, 17.6, 18.6, 19.6, + 17.8, 18.8, 19.8, 20.8, 19., 20., 21., 22., 20.2, 21.2, 22.2, 23.2, + 21., 22., 23., 24., 21., 22., 23., 24., 21., 22., 23., 24., + 13., 14., 15., 16., 14.2, 15.2, 16.2, 17.2, 15.4, 16.4, 17.4, 18.4, + 16.6, 17.6, 18.6, 19.6, 17.8, 18.8, 19.8, 20.8, 19., 20., 21., 22., + 20.2, 21.2, 22.2, 23.2, 21., 22., 23., 24., 21., 22., 23., 24., + 21., 22., 23., 24., 13., 14., 15., 16., 14.2, 15.2, 16.2, 17.2, + 15.4, 16.4, 17.4, 18.4, 16.6, 17.6, 18.6, 19.6, 17.8, 18.8, 19.8, 20.8, + 19., 20., 21., 22., 20.2, 21.2, 22.2, 23.2, 21., 22., 23., 24., + 21., 22., 23., 24., 21., 22., 23., 24., 13., 14., 15., 16., + 14.2, 15.2, 16.2, 17.2, 15.4, 16.4, 17.4, 18.4, 16.6, 17.6, 18.6, 19.6, + 17.8, 18.8, 19.8, 20.8, 19., 20., 21., 22., 20.2, 21.2, 22.2, 23.2, + 21., 22., 23., 24., 21., 22., 23., 24., 21., 22., 23., 24., + 13., 14., 15., 16., 14.2, 15.2, 16.2, 17.2, 15.4, 16.4, 17.4, 18.4, + 16.6, 17.6, 18.6, 19.6, 17.8, 18.8, 19.8, 20.8, 19., 20., 21., 22., + 20.2, 21.2, 22.2, 23.2, 21., 22., 23., 24., 21., 22., 23., 24., + 21., 22., 23., 24.}); + // input = 1.f; + input.linspace(1); + + sd::ops::resize_bilinear op; + auto results = op.evaluate({&input}, {}, {10, 10}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + + // result.printIndexedBuffer("Resized to 10x10"); + // expected.printIndexedBuffer("Expect for 10x10"); + // result.printShapeInfo("Output shape"); + // expected.printShapeInfo("Expect shape"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } TEST_F(DeclarableOpsTests10, ResizeImages_Test1) { @@ -1607,7 +1672,7 @@ TEST_F(DeclarableOpsTests10, ResizeImages_Test1) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - NDArray *result = results.at(0); + auto result = results.at(0); // result->printBuffer("Resized to 7x9"); // expected.printBuffer("Expect for 7x9"); @@ -1617,1619 +1682,1530 @@ TEST_F(DeclarableOpsTests10, ResizeImages_Test1) { ASSERT_TRUE(expected.equalsTo(result)); } TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test02) { - - NDArray input = NDArrayFactory::create('c', {2, 5,5,3}, { - 0.7788f, 0.8012f, 0.7244f, - 0.2309f, 0.7271f, 0.1804f, - 0.5056f, 0.8925f, 0.5461f, - 0.9234f, 0.0856f, 0.7938f, - 0.6591f, 0.5555f, 0.1596f, - 0.3087f, 0.1548f, 0.4695f, - 0.9939f, 0.6113f, 0.6765f, - 0.1800f, 0.6750f, 0.2246f, - 0.0509f, 0.4601f, 0.8284f, - 0.2354f, 0.9752f, 0.8361f, - 0.2585f, 0.4189f, 0.7028f, - 0.7679f, 0.5373f, 0.7234f, - 0.2690f, 0.0062f, 0.0327f, - 0.0644f, 0.8428f, 0.7494f, - 0.0755f, 0.6245f, 0.3491f, - 0.5793f, 0.5730f, 0.1822f, - 0.6420f, 0.9143f, 0.3019f, - 0.3574f, 0.1704f, 0.8395f, - 0.5468f, 0.0744f, 0.9011f, - 0.6574f, 0.4124f, 0.2445f, - 0.4248f, 0.5219f, 0.6952f, - 0.4900f, 0.2158f, 0.9549f, - 0.1386f, 0.1544f, 0.5365f, - 0.0134f, 0.4163f, 0.1456f, - 0.4109f, 0.2484f, 0.3330f, - 0.2974f, 0.6636f, 0.3808f, - 0.8664f, 0.1896f, 0.7530f, - 0.7215f, 0.6612f, 0.7270f, - 0.5704f, 0.2666f, 0.7453f, - 0.0444f, 0.3024f, 0.4850f, - 0.7982f, 0.0965f, 0.7843f, - 0.5075f, 0.0844f, 0.8370f, - 0.6103f, 0.4604f, 0.6087f, - 0.8594f, 0.4599f, 0.6714f, - 0.2744f, 0.1981f, 0.4143f, - 0.7821f, 0.3505f, 0.5040f, - 0.1180f, 0.8307f, 0.1817f, - 0.8442f, 0.5074f, 0.4471f, - 0.5105f, 0.6666f, 0.2576f, - 0.2341f, 0.6801f, 0.2652f, - 0.5394f, 0.4690f, 0.6146f, - 0.1210f, 0.2576f, 0.0769f, - 0.4643f, 0.1628f, 0.2026f, - 0.3774f, 0.0506f, 0.3462f, - 0.5720f, 0.0838f, 0.4228f, - 0.0588f, 0.5362f, 0.4756f, - 0.2530f, 0.1778f, 0.0751f, - 0.8977f, 0.3648f, 0.3065f, - 0.4739f, 0.7014f, 0.4473f, - 0.5171f, 0.1744f, 0.3487f}); - - NDArray expected = NDArrayFactory::create('c', {2, 9, 9, 3}, { - 0.7788f, 0.8012f, 0.7244f, 0.4744111f, 0.7600333f, 0.42217776f, - 0.26142225f, 0.7454778f, 0.22103335f, 0.41403335f, 0.8373667f, 0.42420003f, - 0.59844446f, 0.71318877f, 0.6011445f, 0.83055556f, 0.264911f, 0.7387556f, - 0.83529997f, 0.2422334f, 0.5823999f, 0.6884666f, 0.5032889f, 0.23006654f, - 0.6591f, 0.5555f, 0.1596f, 0.5176333f, 0.44208887f , 0.5827889f, - 0.5938309f, 0.5646876f, 0.5123568f, 0.61811364f, 0.6748667f, 0.44617534f, - 0.43473703f, 0.7353667f, 0.3969963f, 0.35003704f, 0.6654419f, 0.46649635f, - 0.41335183f, 0.39988017f, 0.7140149f, 0.43368888f, 0.45865932f, 0.72049254f, - 0.42537406f, 0.73366547f, 0.5662765f, 0.42371112f, 0.78866667f, 0.53543335f, - 0.30312222f, 0.18414445f, 0.49542224f, 0.67293704f, 0.4168852f, 0.59891605f, - 0.8822444f, 0.60281235f, 0.62855184f, 0.4495222f, 0.6014852f, 0.36275554f, - 0.15933579f, 0.5788963f, 0.34024328f, 0.08295307f, 0.52441484f, 0.6826569f, - 0.10747781f, 0.64715934f, 0.80707777f, 0.19927411f, 0.8880544f, 0.7861703f, - 0.21763334f, 0.9362333f, 0.78198886f, 0.27523333f, 0.3308667f, 0.6250333f, - 0.5907889f, 0.45925558f, 0.6709963f, 0.7761333f, 0.5249852f, 0.63986665f, - 0.4406333f, 0.34007773f, 0.3003666f, 0.19945924f, 0.33715558f, 0.24757043f, - 0.09977405f, 0.60721123f, 0.6248297f, 0.08286668f, 0.7239556f, 0.6876333f, - 0.12114445f, 0.73849255f ,0.54079986f, 0.12879999f, 0.74139994f, 0.51143324f, - 0.32978892f, 0.45314446f, 0.58711106f, 0.5576408f, 0.5464408f, 0.6107901f, - 0.68978024f, 0.55681235f, 0.5833172f, 0.43907034f, 0.23548517f, 0.35123706f, - 0.26263458f, 0.18254575f, 0.33890504f, 0.1976099f, 0.5321877f, 0.65619516f, - 0.18267044f, 0.6404851f, 0.63069254f, 0.20112106f, 0.58788633f, 0.37666163f, - 0.20481117f, 0.57736665f, 0.32585555f, 0.50801116f, 0.5387556f, 0.29788882f, - 0.59799266f, 0.7008482f, 0.35215425f, 0.6330642f, 0.753121f, 0.42497158f, - 0.44849625f, 0.36611477f, 0.5719964f, 0.36038768f, 0.1586321f, 0.70625067f, - 0.416968f, 0.22043455f, 0.82134944f, 0.4690964f, 0.31661478f, 0.6675073f, - 0.5182569f, 0.4357136f, 0.33437145f, 0.528089f, 0.4595333f, 0.26774442f, - 0.52779996f, 0.5559667f, 0.35320008f, 0.5630963f, 0.62568885f, 0.44562602f, - 0.557237f, 0.62408876f, 0.5438927f, 0.3867555f, 0.3371999f, 0.6655223f, - 0.30325183f, 0.17024446f, 0.71867025f, 0.35021478f, 0.18318895f, 0.6690962f, - 0.4377444f, 0.24482228f, 0.5241777f, 0.5523185f, 0.33891484f, 0.3156962f, - 0.5752333f, 0.3577333f, 0.27400002f, 0.44196665f, 0.52757776f, 0.6382001f, - 0.47803456f, 0.3974851f, 0.7738359f, 0.4686691f, 0.27816284f, 0.8476581f, - 0.2775703f, 0.20192216f, 0.6742259f, 0.14285672f, 0.20554078f, 0.4944727f, - 0.0927209f, 0.32894826f, 0.30523813f, 0.19454071f, 0.3410815f, 0.26075178f, - 0.3976642f, 0.27903205f, 0.31276423f, 0.43828884f, 0.2666222f, 0.32316667f, - 0.4248f, 0.5219f, 0.6952f, 0.46102223f, 0.35184443f, 0.8394778f, - 0.45095554f, 0.20897777f, 0.9084111f, 0.2557333f, 0.17486666f, 0.6759666f, - 0.11077777f, 0.21260004f, 0.44963327f, 0.04122221f, 0.35810006f, 0.23246664f, - 0.14590007f, 0.36033332f, 0.2080667f, 0.3667334f, 0.2670555f, 0.31217784f, - 0.4109f, 0.2484f, 0.333f, 0.2974f, 0.6636f, 0.3808f, - 0.6135111f, 0.40026665f, 0.5875778f, 0.8503f, 0.24200003f, 0.7501111f, - 0.76979995f, 0.50400007f, 0.7356667f, 0.6879222f, 0.57351106f, 0.73106664f, - 0.60397774f, 0.35428885f, 0.74123335f, 0.39506656f, 0.27853334f, 0.6585333f, - 0.10284433f, 0.29842222f, 0.5139222f, 0.0444f, 0.3024f, 0.485f, - 0.5756222f, 0.34854442f, 0.6049667f, 0.6263938f, 0.22777282f, 0.71313334f, - 0.66620123f, 0.17765433f, 0.78429013f, 0.6621518f, 0.41014817f, 0.7074074f, - 0.67555183f, 0.51060987f, 0.6708259f, 0.7151259f, 0.41302344f, 0.6946963f, - 0.5446962f, 0.33081108f, 0.6180703f, 0.23426408f, 0.25884813f, 0.4744469f, - 0.17217779f, 0.24445555f, 0.44572222f, 0.7964111f, 0.12472223f, 0.7531556f, - 0.6118617f, 0.1483889f, 0.75928515f, 0.4833407f, 0.2004667f, 0.7449173f, - 0.57893336f, 0.3661889f, 0.6485592f, 0.6772543f, 0.46945432f, 0.5984506f, - 0.7796679f, 0.47903457f, 0.617716f, 0.63706285f, 0.40579626f, 0.54952586f, - 0.33111224f, 0.27734566f, 0.42303205f, 0.26992223f, 0.25165558f, 0.39773333f, - 0.7874667f, 0.26583335f, 0.5974333f, 0.4876703f, 0.44144446f, 0.48782218f, - 0.30543333f, 0.57191116f, 0.41133702f, 0.5934334f, 0.5218f, 0.46735552f, - 0.73524815f, 0.5152815f, 0.47753704f, 0.6577852f, 0.5741519f, 0.41896293f, - 0.50037766f, 0.57161117f, 0.3686555f, 0.28967398f, 0.5281297f, 0.3238592f, - 0.24753332f, 0.5194334f, 0.31489998f, 0.72816664f, 0.37683335f, 0.5285778f, - 0.3895555f, 0.5582283f, 0.32292962f, 0.18990126f, 0.6730641f, 0.18445063f, - 0.5460741f, 0.5216629f, 0.31464812f, 0.6978098f, 0.45279747f, 0.36710492f, - 0.5428901f, 0.5077358f, 0.30295062f, 0.42367774f, 0.53567034f, 0.28493333f, - 0.32827038f, 0.54560244f, 0.2976741f, 0.30918893f, 0.5475888f, 0.30022222f, - 0.5933333f, 0.44266668f, 0.59002227f, 0.3305555f, 0.4106049f, 0.31789258f, - 0.16793211f, 0.36878017f, 0.11760493f, 0.40592593f, 0.28790364f, 0.20468517f, - 0.5172234f, 0.22784683f, 0.27239504f, 0.4384765f, 0.19901967f, 0.3110494f, - 0.43695557f, 0.19709623f, 0.34693336f, 0.4869186f, 0.21310854f, 0.38097042f, - 0.49691117f, 0.21631104f, 0.3877778f, 0.37919992f, 0.4914f, 0.56826663f, - 0.26019996f, 0.34673333f, 0.29495183f, 0.21430746f, 0.23090371f, 0.09418149f, - 0.46084452f, 0.23042224f, 0.1835889f, 0.56450003f, 0.23844449f, 0.26893705f, - 0.45383334f, 0.2592223f, 0.34819633f, 0.45761114f, 0.21635559f, 0.38596666f, - 0.5376852f, 0.13105926f, 0.39607778f, 0.55370003f, 0.11400001f, 0.3981f, - 0.11219993f, 0.5287333f, 0.49104443f, 0.18227404f, 0.3386963f, 0.26007527f, - 0.30624574f, 0.20396544f, 0.09970618f, 0.6458075f, 0.2904593f, 0.22173704f, - 0.7636852f, 0.40607417f, 0.32631359f, 0.549037f, 0.5653705f, 0.40470868f, - 0.4831852f, 0.47417036f, 0.40968886f, 0.5165309f, 0.21597281f, 0.3657259f, - 0.5232f, 0.16433334f, 0.3569333f, 0.0588f, 0.5362f, 0.4756f, - 0.16668889f, 0.33708888f, 0.25309998f, 0.32463336f, 0.19857779f, 0.10081112f, - 0.68280005f, 0.3024667f, 0.22936666f, 0.80352217f, 0.43960005f, 0.33778888f, - 0.5680777f, 0.6266f, 0.41601112f, 0.4883f, 0.52573323f, 0.4144333f, - 0.5123f, 0.23295549f, 0.35965553f, 0.5171f, 0.1744f, 0.3487f - }); - //input.linspace(1); - - sd::ops::resize_bilinear op; - auto results = op.evaluate({&input}, {}, {9, 9}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - NDArray* result = results.at(0); - -// result.printBuffer("Resized to 9x9"); -// expected.printBuffer("Expect for 9x9"); -// result.printShapeInfo("Output shape"); -// expected.printShapeInfo("Expect shape"); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); - + NDArray input = NDArrayFactory::create( + 'c', {2, 5, 5, 3}, + {0.7788f, 0.8012f, 0.7244f, 0.2309f, 0.7271f, 0.1804f, 0.5056f, 0.8925f, + 0.5461f, 0.9234f, 0.0856f, 0.7938f, 0.6591f, 0.5555f, 0.1596f, 0.3087f, + 0.1548f, 0.4695f, 0.9939f, 0.6113f, 0.6765f, 0.1800f, 0.6750f, 0.2246f, + 0.0509f, 0.4601f, 0.8284f, 0.2354f, 0.9752f, 0.8361f, 0.2585f, 0.4189f, + 0.7028f, 0.7679f, 0.5373f, 0.7234f, 0.2690f, 0.0062f, 0.0327f, 0.0644f, + 0.8428f, 0.7494f, 0.0755f, 0.6245f, 0.3491f, 0.5793f, 0.5730f, 0.1822f, + 0.6420f, 0.9143f, 0.3019f, 0.3574f, 0.1704f, 0.8395f, 0.5468f, 0.0744f, + 0.9011f, 0.6574f, 0.4124f, 0.2445f, 0.4248f, 0.5219f, 0.6952f, 0.4900f, + 0.2158f, 0.9549f, 0.1386f, 0.1544f, 0.5365f, 0.0134f, 0.4163f, 0.1456f, + 0.4109f, 0.2484f, 0.3330f, 0.2974f, 0.6636f, 0.3808f, 0.8664f, 0.1896f, + 0.7530f, 0.7215f, 0.6612f, 0.7270f, 0.5704f, 0.2666f, 0.7453f, 0.0444f, + 0.3024f, 0.4850f, 0.7982f, 0.0965f, 0.7843f, 0.5075f, 0.0844f, 0.8370f, + 0.6103f, 0.4604f, 0.6087f, 0.8594f, 0.4599f, 0.6714f, 0.2744f, 0.1981f, + 0.4143f, 0.7821f, 0.3505f, 0.5040f, 0.1180f, 0.8307f, 0.1817f, 0.8442f, + 0.5074f, 0.4471f, 0.5105f, 0.6666f, 0.2576f, 0.2341f, 0.6801f, 0.2652f, + 0.5394f, 0.4690f, 0.6146f, 0.1210f, 0.2576f, 0.0769f, 0.4643f, 0.1628f, + 0.2026f, 0.3774f, 0.0506f, 0.3462f, 0.5720f, 0.0838f, 0.4228f, 0.0588f, + 0.5362f, 0.4756f, 0.2530f, 0.1778f, 0.0751f, 0.8977f, 0.3648f, 0.3065f, + 0.4739f, 0.7014f, 0.4473f, 0.5171f, 0.1744f, 0.3487f}); + + NDArray expected = NDArrayFactory::create( + 'c', {2, 9, 9, 3}, + {0.7788f, 0.8012f, 0.7244f, 0.4744111f, 0.7600333f, + 0.42217776f, 0.26142225f, 0.7454778f, 0.22103335f, 0.41403335f, + 0.8373667f, 0.42420003f, 0.59844446f, 0.71318877f, 0.6011445f, + 0.83055556f, 0.264911f, 0.7387556f, 0.83529997f, 0.2422334f, + 0.5823999f, 0.6884666f, 0.5032889f, 0.23006654f, 0.6591f, + 0.5555f, 0.1596f, 0.5176333f, 0.44208887f, 0.5827889f, + 0.5938309f, 0.5646876f, 0.5123568f, 0.61811364f, 0.6748667f, + 0.44617534f, 0.43473703f, 0.7353667f, 0.3969963f, 0.35003704f, + 0.6654419f, 0.46649635f, 0.41335183f, 0.39988017f, 0.7140149f, + 0.43368888f, 0.45865932f, 0.72049254f, 0.42537406f, 0.73366547f, + 0.5662765f, 0.42371112f, 0.78866667f, 0.53543335f, 0.30312222f, + 0.18414445f, 0.49542224f, 0.67293704f, 0.4168852f, 0.59891605f, + 0.8822444f, 0.60281235f, 0.62855184f, 0.4495222f, 0.6014852f, + 0.36275554f, 0.15933579f, 0.5788963f, 0.34024328f, 0.08295307f, + 0.52441484f, 0.6826569f, 0.10747781f, 0.64715934f, 0.80707777f, + 0.19927411f, 0.8880544f, 0.7861703f, 0.21763334f, 0.9362333f, + 0.78198886f, 0.27523333f, 0.3308667f, 0.6250333f, 0.5907889f, + 0.45925558f, 0.6709963f, 0.7761333f, 0.5249852f, 0.63986665f, + 0.4406333f, 0.34007773f, 0.3003666f, 0.19945924f, 0.33715558f, + 0.24757043f, 0.09977405f, 0.60721123f, 0.6248297f, 0.08286668f, + 0.7239556f, 0.6876333f, 0.12114445f, 0.73849255f, 0.54079986f, + 0.12879999f, 0.74139994f, 0.51143324f, 0.32978892f, 0.45314446f, + 0.58711106f, 0.5576408f, 0.5464408f, 0.6107901f, 0.68978024f, + 0.55681235f, 0.5833172f, 0.43907034f, 0.23548517f, 0.35123706f, + 0.26263458f, 0.18254575f, 0.33890504f, 0.1976099f, 0.5321877f, + 0.65619516f, 0.18267044f, 0.6404851f, 0.63069254f, 0.20112106f, + 0.58788633f, 0.37666163f, 0.20481117f, 0.57736665f, 0.32585555f, + 0.50801116f, 0.5387556f, 0.29788882f, 0.59799266f, 0.7008482f, + 0.35215425f, 0.6330642f, 0.753121f, 0.42497158f, 0.44849625f, + 0.36611477f, 0.5719964f, 0.36038768f, 0.1586321f, 0.70625067f, + 0.416968f, 0.22043455f, 0.82134944f, 0.4690964f, 0.31661478f, + 0.6675073f, 0.5182569f, 0.4357136f, 0.33437145f, 0.528089f, + 0.4595333f, 0.26774442f, 0.52779996f, 0.5559667f, 0.35320008f, + 0.5630963f, 0.62568885f, 0.44562602f, 0.557237f, 0.62408876f, + 0.5438927f, 0.3867555f, 0.3371999f, 0.6655223f, 0.30325183f, + 0.17024446f, 0.71867025f, 0.35021478f, 0.18318895f, 0.6690962f, + 0.4377444f, 0.24482228f, 0.5241777f, 0.5523185f, 0.33891484f, + 0.3156962f, 0.5752333f, 0.3577333f, 0.27400002f, 0.44196665f, + 0.52757776f, 0.6382001f, 0.47803456f, 0.3974851f, 0.7738359f, + 0.4686691f, 0.27816284f, 0.8476581f, 0.2775703f, 0.20192216f, + 0.6742259f, 0.14285672f, 0.20554078f, 0.4944727f, 0.0927209f, + 0.32894826f, 0.30523813f, 0.19454071f, 0.3410815f, 0.26075178f, + 0.3976642f, 0.27903205f, 0.31276423f, 0.43828884f, 0.2666222f, + 0.32316667f, 0.4248f, 0.5219f, 0.6952f, 0.46102223f, + 0.35184443f, 0.8394778f, 0.45095554f, 0.20897777f, 0.9084111f, + 0.2557333f, 0.17486666f, 0.6759666f, 0.11077777f, 0.21260004f, + 0.44963327f, 0.04122221f, 0.35810006f, 0.23246664f, 0.14590007f, + 0.36033332f, 0.2080667f, 0.3667334f, 0.2670555f, 0.31217784f, + 0.4109f, 0.2484f, 0.333f, 0.2974f, 0.6636f, + 0.3808f, 0.6135111f, 0.40026665f, 0.5875778f, 0.8503f, + 0.24200003f, 0.7501111f, 0.76979995f, 0.50400007f, 0.7356667f, + 0.6879222f, 0.57351106f, 0.73106664f, 0.60397774f, 0.35428885f, + 0.74123335f, 0.39506656f, 0.27853334f, 0.6585333f, 0.10284433f, + 0.29842222f, 0.5139222f, 0.0444f, 0.3024f, 0.485f, + 0.5756222f, 0.34854442f, 0.6049667f, 0.6263938f, 0.22777282f, + 0.71313334f, 0.66620123f, 0.17765433f, 0.78429013f, 0.6621518f, + 0.41014817f, 0.7074074f, 0.67555183f, 0.51060987f, 0.6708259f, + 0.7151259f, 0.41302344f, 0.6946963f, 0.5446962f, 0.33081108f, + 0.6180703f, 0.23426408f, 0.25884813f, 0.4744469f, 0.17217779f, + 0.24445555f, 0.44572222f, 0.7964111f, 0.12472223f, 0.7531556f, + 0.6118617f, 0.1483889f, 0.75928515f, 0.4833407f, 0.2004667f, + 0.7449173f, 0.57893336f, 0.3661889f, 0.6485592f, 0.6772543f, + 0.46945432f, 0.5984506f, 0.7796679f, 0.47903457f, 0.617716f, + 0.63706285f, 0.40579626f, 0.54952586f, 0.33111224f, 0.27734566f, + 0.42303205f, 0.26992223f, 0.25165558f, 0.39773333f, 0.7874667f, + 0.26583335f, 0.5974333f, 0.4876703f, 0.44144446f, 0.48782218f, + 0.30543333f, 0.57191116f, 0.41133702f, 0.5934334f, 0.5218f, + 0.46735552f, 0.73524815f, 0.5152815f, 0.47753704f, 0.6577852f, + 0.5741519f, 0.41896293f, 0.50037766f, 0.57161117f, 0.3686555f, + 0.28967398f, 0.5281297f, 0.3238592f, 0.24753332f, 0.5194334f, + 0.31489998f, 0.72816664f, 0.37683335f, 0.5285778f, 0.3895555f, + 0.5582283f, 0.32292962f, 0.18990126f, 0.6730641f, 0.18445063f, + 0.5460741f, 0.5216629f, 0.31464812f, 0.6978098f, 0.45279747f, + 0.36710492f, 0.5428901f, 0.5077358f, 0.30295062f, 0.42367774f, + 0.53567034f, 0.28493333f, 0.32827038f, 0.54560244f, 0.2976741f, + 0.30918893f, 0.5475888f, 0.30022222f, 0.5933333f, 0.44266668f, + 0.59002227f, 0.3305555f, 0.4106049f, 0.31789258f, 0.16793211f, + 0.36878017f, 0.11760493f, 0.40592593f, 0.28790364f, 0.20468517f, + 0.5172234f, 0.22784683f, 0.27239504f, 0.4384765f, 0.19901967f, + 0.3110494f, 0.43695557f, 0.19709623f, 0.34693336f, 0.4869186f, + 0.21310854f, 0.38097042f, 0.49691117f, 0.21631104f, 0.3877778f, + 0.37919992f, 0.4914f, 0.56826663f, 0.26019996f, 0.34673333f, + 0.29495183f, 0.21430746f, 0.23090371f, 0.09418149f, 0.46084452f, + 0.23042224f, 0.1835889f, 0.56450003f, 0.23844449f, 0.26893705f, + 0.45383334f, 0.2592223f, 0.34819633f, 0.45761114f, 0.21635559f, + 0.38596666f, 0.5376852f, 0.13105926f, 0.39607778f, 0.55370003f, + 0.11400001f, 0.3981f, 0.11219993f, 0.5287333f, 0.49104443f, + 0.18227404f, 0.3386963f, 0.26007527f, 0.30624574f, 0.20396544f, + 0.09970618f, 0.6458075f, 0.2904593f, 0.22173704f, 0.7636852f, + 0.40607417f, 0.32631359f, 0.549037f, 0.5653705f, 0.40470868f, + 0.4831852f, 0.47417036f, 0.40968886f, 0.5165309f, 0.21597281f, + 0.3657259f, 0.5232f, 0.16433334f, 0.3569333f, 0.0588f, + 0.5362f, 0.4756f, 0.16668889f, 0.33708888f, 0.25309998f, + 0.32463336f, 0.19857779f, 0.10081112f, 0.68280005f, 0.3024667f, + 0.22936666f, 0.80352217f, 0.43960005f, 0.33778888f, 0.5680777f, + 0.6266f, 0.41601112f, 0.4883f, 0.52573323f, 0.4144333f, + 0.5123f, 0.23295549f, 0.35965553f, 0.5171f, 0.1744f, + 0.3487f}); + // input.linspace(1); + + sd::ops::resize_bilinear op; + auto results = op.evaluate({&input}, {}, {9, 9}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + + // result.printBuffer("Resized to 9x9"); + // expected.printBuffer("Expect for 9x9"); + // result.printShapeInfo("Output shape"); + // expected.printShapeInfo("Expect shape"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test2) { - - NDArray input = NDArrayFactory::create('c', {1, 2,3,4}); - NDArray size = NDArrayFactory::create({10, 10}); - //NDArray expected('c', {2,4,4}, {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.}); - NDArray expected = NDArrayFactory::create('c', {1, 10, 10, 4}, {1., 2., 3., 4., 2.2, 3.2, 4.2, 5.2, 3.4, 4.4, 5.4, 6.4, - 4.6, 5.6, 6.6, 7.6, 5.8, 6.8, 7.8, 8.8, 7., 8., 9., 10., - 8.2, 9.2, 10.2, 11.2, 9., 10., 11., 12., 9., 10., 11., 12., - 9., 10., 11., 12., 3.4, 4.4, 5.4, 6.4, 4.6, 5.6, 6.6, 7.6, - 5.8, 6.8, 7.8, 8.8, 7.0, 8., 9., 10., 8.2, 9.2, 10.2, 11.2, - 9.4,10.4, 11.4, 12.4,10.6, 11.6,12.6, 13.6,11.4, 12.4, 13.4, 14.4, - 11.4,12.4, 13.4, 14.4,11.4, 12.4,13.4, 14.4, 5.8, 6.8, 7.8, 8.8, - 7., 8., 9., 10., 8.2, 9.2,10.2, 11.2, 9.4, 10.4, 11.4, 12.4, - 10.6,11.6, 12.6, 13.6,11.8, 12.8,13.8, 14.8,13.0, 14.0, 15.0, 16., - 13.8,14.8, 15.8, 16.8,13.8, 14.8,15.8, 16.8,13.8, 14.8, 15.8, 16.8, - 8.2, 9.2, 10.2, 11.2, 9.4, 10.4,11.4, 12.4,10.6, 11.6, 12.6, 13.6, - 11.8,12.8, 13.8, 14.8,13., 14., 15., 16., 14.2, 15.2, 16.2, 17.2, - 15.4,16.4, 17.4, 18.4,16.2, 17.2,18.2, 19.2,16.2, 17.2, 18.2, 19.2, - 16.2,17.2, 18.2, 19.2,10.6, 11.6,12.6, 13.6,11.8, 12.8, 13.8, 14.8, - 13., 14., 15., 16., 14.2, 15.2,16.2, 17.2,15.4, 16.4, 17.4, 18.4, - 16.6,17.6, 18.6, 19.6,17.8, 18.8,19.8, 20.8,18.6, 19.6, 20.6, 21.6, - 18.6,19.6, 20.6, 21.6,18.6, 19.6,20.6, 21.6,13., 14., 15., 16., - 14.2,15.2, 16.2, 17.2,15.4, 16.4,17.4, 18.4,16.6, 17.6, 18.6, 19.6, - 17.8,18.8, 19.8, 20.8,19., 20., 21., 22., 20.2, 21.2, 22.2, 23.2, - 21., 22., 23., 24., 21., 22., 23., 24., 21., 22., 23., 24., - 13., 14., 15., 16., 14.2, 15.2,16.2, 17.2,15.4, 16.4, 17.4, 18.4, - 16.6,17.6, 18.6, 19.6,17.8, 18.8, 19.8, 20.8,19., 20., 21., 22., - 20.2,21.2, 22.2, 23.2,21., 22., 23., 24., 21., 22., 23., 24., - 21., 22., 23., 24., 13., 14., 15., 16., 14.2, 15.2, 16.2, 17.2, - 15.4,16.4, 17.4, 18.4,16.6, 17.6, 18.6, 19.6,17.8, 18.8, 19.8, 20.8, - 19., 20., 21., 22., 20.2, 21.2, 22.2, 23.2,21., 22., 23., 24., - 21., 22., 23., 24., 21., 22., 23., 24., 13., 14., 15., 16., - 14.2,15.2, 16.2, 17.2,15.4, 16.4, 17.4, 18.4,16.6, 17.6, 18.6, 19.6, - 17.8,18.8, 19.8, 20.8,19., 20., 21., 22., 20.2, 21.2, 22.2, 23.2, - 21., 22., 23., 24., 21., 22., 23., 24., 21., 22., 23., 24., - 13., 14., 15., 16., 14.2, 15.2, 16.2, 17.2,15.4, 16.4, 17.4, 18.4, - 16.6,17.6, 18.6, 19.6,17.8, 18.8, 19.8, 20.8,19., 20., 21., 22., - 20.2,21.2, 22.2, 23.2, - 21., 22., 23., 24., 21., 22., 23., 24., 21., 22., 23., 24.}); - //input = 1.f; - input.linspace(1); - - sd::ops::resize_bilinear op; - auto results = op.evaluate({&input, &size}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - NDArray* result = results.at(0); - - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); + NDArray input = NDArrayFactory::create('c', {1, 2, 3, 4}); + NDArray size = NDArrayFactory::create({10, 10}); + // NDArray expected('c', {2,4,4}, + // {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.}); + NDArray expected = NDArrayFactory::create( + 'c', {1, 10, 10, 4}, + {1., 2., 3., 4., 2.2, 3.2, 4.2, 5.2, 3.4, 4.4, 5.4, 6.4, + 4.6, 5.6, 6.6, 7.6, 5.8, 6.8, 7.8, 8.8, 7., 8., 9., 10., + 8.2, 9.2, 10.2, 11.2, 9., 10., 11., 12., 9., 10., 11., 12., + 9., 10., 11., 12., 3.4, 4.4, 5.4, 6.4, 4.6, 5.6, 6.6, 7.6, + 5.8, 6.8, 7.8, 8.8, 7.0, 8., 9., 10., 8.2, 9.2, 10.2, 11.2, + 9.4, 10.4, 11.4, 12.4, 10.6, 11.6, 12.6, 13.6, 11.4, 12.4, 13.4, 14.4, + 11.4, 12.4, 13.4, 14.4, 11.4, 12.4, 13.4, 14.4, 5.8, 6.8, 7.8, 8.8, + 7., 8., 9., 10., 8.2, 9.2, 10.2, 11.2, 9.4, 10.4, 11.4, 12.4, + 10.6, 11.6, 12.6, 13.6, 11.8, 12.8, 13.8, 14.8, 13.0, 14.0, 15.0, 16., + 13.8, 14.8, 15.8, 16.8, 13.8, 14.8, 15.8, 16.8, 13.8, 14.8, 15.8, 16.8, + 8.2, 9.2, 10.2, 11.2, 9.4, 10.4, 11.4, 12.4, 10.6, 11.6, 12.6, 13.6, + 11.8, 12.8, 13.8, 14.8, 13., 14., 15., 16., 14.2, 15.2, 16.2, 17.2, + 15.4, 16.4, 17.4, 18.4, 16.2, 17.2, 18.2, 19.2, 16.2, 17.2, 18.2, 19.2, + 16.2, 17.2, 18.2, 19.2, 10.6, 11.6, 12.6, 13.6, 11.8, 12.8, 13.8, 14.8, + 13., 14., 15., 16., 14.2, 15.2, 16.2, 17.2, 15.4, 16.4, 17.4, 18.4, + 16.6, 17.6, 18.6, 19.6, 17.8, 18.8, 19.8, 20.8, 18.6, 19.6, 20.6, 21.6, + 18.6, 19.6, 20.6, 21.6, 18.6, 19.6, 20.6, 21.6, 13., 14., 15., 16., + 14.2, 15.2, 16.2, 17.2, 15.4, 16.4, 17.4, 18.4, 16.6, 17.6, 18.6, 19.6, + 17.8, 18.8, 19.8, 20.8, 19., 20., 21., 22., 20.2, 21.2, 22.2, 23.2, + 21., 22., 23., 24., 21., 22., 23., 24., 21., 22., 23., 24., + 13., 14., 15., 16., 14.2, 15.2, 16.2, 17.2, 15.4, 16.4, 17.4, 18.4, + 16.6, 17.6, 18.6, 19.6, 17.8, 18.8, 19.8, 20.8, 19., 20., 21., 22., + 20.2, 21.2, 22.2, 23.2, 21., 22., 23., 24., 21., 22., 23., 24., + 21., 22., 23., 24., 13., 14., 15., 16., 14.2, 15.2, 16.2, 17.2, + 15.4, 16.4, 17.4, 18.4, 16.6, 17.6, 18.6, 19.6, 17.8, 18.8, 19.8, 20.8, + 19., 20., 21., 22., 20.2, 21.2, 22.2, 23.2, 21., 22., 23., 24., + 21., 22., 23., 24., 21., 22., 23., 24., 13., 14., 15., 16., + 14.2, 15.2, 16.2, 17.2, 15.4, 16.4, 17.4, 18.4, 16.6, 17.6, 18.6, 19.6, + 17.8, 18.8, 19.8, 20.8, 19., 20., 21., 22., 20.2, 21.2, 22.2, 23.2, + 21., 22., 23., 24., 21., 22., 23., 24., 21., 22., 23., 24., + 13., 14., 15., 16., 14.2, 15.2, 16.2, 17.2, 15.4, 16.4, 17.4, 18.4, + 16.6, 17.6, 18.6, 19.6, 17.8, 18.8, 19.8, 20.8, 19., 20., 21., 22., + 20.2, 21.2, 22.2, 23.2, 21., 22., 23., 24., 21., 22., 23., 24., + 21., 22., 23., 24.}); + // input = 1.f; + input.linspace(1); + + sd::ops::resize_bilinear op; + auto results = op.evaluate({&input, &size}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test3) { - - NDArray input = NDArrayFactory::create('c', {1, 2,3,4}); - //NDArray paddings('c', {3,2}, {0,0, 0,1, 0,0}); - //NDArray expected('c', {2,4,4}, {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.}); - NDArray expected = NDArrayFactory::create('c', {1, 10, 10, 4}, - { 1., 2., 3., 4. , - 1.8888888, 2.8888888, 3.8888888, 4.888889, - 2.7777777, 3.7777777, 4.7777777, 5.7777777, - 3.6666667, 4.666667 , 5.666667, 6.666667 , - 4.5555553, 5.5555553, 6.5555553, 7.5555553, - 5.4444447, 6.4444447, 7.4444447, 8.444445, - 6.3333335, 7.3333335, 8.333334, 9.333334, - 7.2222223, 8.222222, 9.222222, 10.222222, - 8.111111, 9.111111, 10.111111, 11.111111, - 9., 10., 11., 12., - - 2.3333335, 3.3333335, 4.3333335, 5.3333335, - 3.2222223, 4.2222223, 5.2222223, 6.2222223, - 4.111111, 5.111111, 6.111111, 7.111111, - 5., 6., 7., 8., - 5.888889, 6.888889, 7.888889, 8.888888, - 6.777778, 7.777778, 8.777778, 9.777778, - 7.666667, 8.666667, 9.666667, 10.666667, - 8.555555, 9.555555, 10.555555, 11.555555, - 9.444444, 10.444444, 11.444444, 12.444444, - 10.333333, 11.333333, 12.333333, 13.333333, - - 3.6666667, 4.666667, 5.666667, 6.666667, - 4.5555553, 5.5555553, 6.5555553, 7.5555553, - 5.4444447, 6.4444447, 7.4444447, 8.444445 , - 6.3333335, 7.3333335, 8.333334, 9.333334 , - 7.2222223, 8.222222, 9.222222, 10.222222 , - 8.111112, 9.111112, 10.111112, 11.111112 , - 9., 10., 11.000001, 12.000001 , - 9.888889, 10.888889, 11.888889, 12.888889 , - 10.777778, 11.777778, 12.777778, 13.777778 , - 11.666667, 12.666667, 13.666667, 14.666667, - - 5., 6., 7., 8., - 5.888889, 6.888889, 7.888889, 8.888889, - 6.7777777, 7.7777777, 8.777779, 9.777779, - 7.666667, 8.666667, 9.666667, 10.666667, - 8.555555, 9.555555, 10.555555, 11.555555, - 9.444445, 10.444445, 11.444445, 12.444445, - 10.333334, 11.333334, 12.333334, 13.333334, - 11.222222, 12.222222, 13.222222, 14.222222, - 12.111111, 13.111111, 14.111111, 15.111111, - 13., 14., 15., 16., - - 6.3333335, 7.3333335, 8.333334, 9.333334, - 7.2222223, 8.222222, 9.222222, 10.222222, - 8.111111, 9.111111, 10.111112, 11.111112, - 9., 10., 11., 12., - 9.888889, 10.888889, 11.888889, 12.888889, - 10.777779, 11.777779, 12.777779, 13.777779, - 11.666667, 12.666667, 13.666668, 14.666668, - 12.555555, 13.555555, 14.555555, 15.555555, - 13.444445, 14.444445, 15.444445, 16.444445, - 14.333334, 15.333334, 16.333334, 17.333334, - 7.666667, 8.666667, 9.666667, 10.666667, - 8.555555, 9.555555, 10.555555, 11.555555, - 9.444445, 10.444445, 11.444445, 12.444445, - 10.333334, 11.333334, 12.333334, 13.333334, - 11.222222, 12.222222, 13.222222, 14.222222, - 12.111112, 13.111112, 14.111112, 15.111112, - 13., 14., 15.0, 16., - 13.888889, 14.888889, 15.888889, 16.88889, - 14.777778, 15.777778, 16.777779, 17.777779, - 15.666667, 16.666668, 17.666668, 18.666668, - - 9., 10., 11., 12., - 9.888889, 10.888889, 11.888889, 12.888889, - 10.777778, 11.777778, 12.777779, 13.777779, - 11.666667, 12.666666, 13.666666, 14.666666, - 12.555555, 13.555555, 14.555555, 15.555555, - 13.444445, 14.444445, 15.444445, 16.444445, - 14.333334, 15.333334, 16.333334, 17.333334, - 15.222221, 16.222221, 17.222221, 18.222221, - 16.11111, 17.11111, 18.11111, 19.11111, - 17., 18., 19., 20., - - 10.333334, 11.333334, 12.333334, 13.333334, - 11.222223, 12.222223, 13.222223, 14.222223, - 12.111112, 13.111112, 14.111112, 15.111112, - 13.000001, 14., 15., 16., - 13.888889, 14.888889, 15.888889, 16.88889, - 14.777779, 15.777779, 16.777779, 17.777779, - 15.666668, 16.666668, 17.666668, 18.666668, - 16.555555, 17.555555, 18.555555, 19.555555, - 17.444445, 18.444445, 19.444445, 20.444445, - 18.333334, 19.333334, 20.333334, 21.333334, - 11.666667, 12.666667, 13.666667, 14.666667, - 12.555555, 13.555555, 14.555555, 15.555555, - 13.444445, 14.444445, 15.444446, 16.444447, - 14.333334, 15.333333, 16.333332, 17.333332, - 15.222222, 16.222221, 17.222221, 18.222221, - 16.11111, 17.11111, 18.11111, 19.11111, - 17., 18., 19., 20., - 17.88889, 18.88889, 19.88889, 20.88889, - 18.777779, 19.777779, 20.777779, 21.777779, - 19.666668, 20.666668, 21.666668, 22.666668, - - 13., 14., 15., 16., - 13.888889, 14.888889, 15.888889, 16.88889, - 14.777778, 15.777778, 16.777779, 17.777779, - 15.666667, 16.666666, 17.666666, 18.666666, - 16.555555, 17.555555, 18.555555, 19.555555, - 17.444445, 18.444445, 19.444445, 20.444445, - 18.333334, 19.333334, 20.333334, 21.333334, - 19.222221, 20.222221, 21.222221, 22.222221, - 20.11111, 21.11111, 22.11111, 23.11111, - 21., 22., 23., 24.}); - //input = 1.f; - input.linspace(1); - - sd::ops::resize_bilinear op; - auto results = op.evaluate({&input}, {}, {10, 10}, {true}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - NDArray* result = results.at(0); - - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); + NDArray input = NDArrayFactory::create('c', {1, 2, 3, 4}); + // NDArray paddings('c', {3,2}, {0,0, 0,1, 0,0}); + // NDArray expected('c', {2,4,4}, + // {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.}); + NDArray expected = NDArrayFactory::create( + 'c', {1, 10, 10, 4}, + {1., 2., 3., 4., 1.8888888, 2.8888888, + 3.8888888, 4.888889, 2.7777777, 3.7777777, 4.7777777, 5.7777777, + 3.6666667, 4.666667, 5.666667, 6.666667, 4.5555553, 5.5555553, + 6.5555553, 7.5555553, 5.4444447, 6.4444447, 7.4444447, 8.444445, + 6.3333335, 7.3333335, 8.333334, 9.333334, 7.2222223, 8.222222, + 9.222222, 10.222222, 8.111111, 9.111111, 10.111111, 11.111111, + 9., 10., 11., 12., + + 2.3333335, 3.3333335, 4.3333335, 5.3333335, 3.2222223, 4.2222223, + 5.2222223, 6.2222223, 4.111111, 5.111111, 6.111111, 7.111111, + 5., 6., 7., 8., 5.888889, 6.888889, + 7.888889, 8.888888, 6.777778, 7.777778, 8.777778, 9.777778, + 7.666667, 8.666667, 9.666667, 10.666667, 8.555555, 9.555555, + 10.555555, 11.555555, 9.444444, 10.444444, 11.444444, 12.444444, + 10.333333, 11.333333, 12.333333, 13.333333, + + 3.6666667, 4.666667, 5.666667, 6.666667, 4.5555553, 5.5555553, + 6.5555553, 7.5555553, 5.4444447, 6.4444447, 7.4444447, 8.444445, + 6.3333335, 7.3333335, 8.333334, 9.333334, 7.2222223, 8.222222, + 9.222222, 10.222222, 8.111112, 9.111112, 10.111112, 11.111112, + 9., 10., 11.000001, 12.000001, 9.888889, 10.888889, + 11.888889, 12.888889, 10.777778, 11.777778, 12.777778, 13.777778, + 11.666667, 12.666667, 13.666667, 14.666667, + + 5., 6., 7., 8., 5.888889, 6.888889, + 7.888889, 8.888889, 6.7777777, 7.7777777, 8.777779, 9.777779, + 7.666667, 8.666667, 9.666667, 10.666667, 8.555555, 9.555555, + 10.555555, 11.555555, 9.444445, 10.444445, 11.444445, 12.444445, + 10.333334, 11.333334, 12.333334, 13.333334, 11.222222, 12.222222, + 13.222222, 14.222222, 12.111111, 13.111111, 14.111111, 15.111111, + 13., 14., 15., 16., + + 6.3333335, 7.3333335, 8.333334, 9.333334, 7.2222223, 8.222222, + 9.222222, 10.222222, 8.111111, 9.111111, 10.111112, 11.111112, + 9., 10., 11., 12., 9.888889, 10.888889, + 11.888889, 12.888889, 10.777779, 11.777779, 12.777779, 13.777779, + 11.666667, 12.666667, 13.666668, 14.666668, 12.555555, 13.555555, + 14.555555, 15.555555, 13.444445, 14.444445, 15.444445, 16.444445, + 14.333334, 15.333334, 16.333334, 17.333334, 7.666667, 8.666667, + 9.666667, 10.666667, 8.555555, 9.555555, 10.555555, 11.555555, + 9.444445, 10.444445, 11.444445, 12.444445, 10.333334, 11.333334, + 12.333334, 13.333334, 11.222222, 12.222222, 13.222222, 14.222222, + 12.111112, 13.111112, 14.111112, 15.111112, 13., 14., + 15.0, 16., 13.888889, 14.888889, 15.888889, 16.88889, + 14.777778, 15.777778, 16.777779, 17.777779, 15.666667, 16.666668, + 17.666668, 18.666668, + + 9., 10., 11., 12., 9.888889, 10.888889, + 11.888889, 12.888889, 10.777778, 11.777778, 12.777779, 13.777779, + 11.666667, 12.666666, 13.666666, 14.666666, 12.555555, 13.555555, + 14.555555, 15.555555, 13.444445, 14.444445, 15.444445, 16.444445, + 14.333334, 15.333334, 16.333334, 17.333334, 15.222221, 16.222221, + 17.222221, 18.222221, 16.11111, 17.11111, 18.11111, 19.11111, + 17., 18., 19., 20., + + 10.333334, 11.333334, 12.333334, 13.333334, 11.222223, 12.222223, + 13.222223, 14.222223, 12.111112, 13.111112, 14.111112, 15.111112, + 13.000001, 14., 15., 16., 13.888889, 14.888889, + 15.888889, 16.88889, 14.777779, 15.777779, 16.777779, 17.777779, + 15.666668, 16.666668, 17.666668, 18.666668, 16.555555, 17.555555, + 18.555555, 19.555555, 17.444445, 18.444445, 19.444445, 20.444445, + 18.333334, 19.333334, 20.333334, 21.333334, 11.666667, 12.666667, + 13.666667, 14.666667, 12.555555, 13.555555, 14.555555, 15.555555, + 13.444445, 14.444445, 15.444446, 16.444447, 14.333334, 15.333333, + 16.333332, 17.333332, 15.222222, 16.222221, 17.222221, 18.222221, + 16.11111, 17.11111, 18.11111, 19.11111, 17., 18., + 19., 20., 17.88889, 18.88889, 19.88889, 20.88889, + 18.777779, 19.777779, 20.777779, 21.777779, 19.666668, 20.666668, + 21.666668, 22.666668, + + 13., 14., 15., 16., 13.888889, 14.888889, + 15.888889, 16.88889, 14.777778, 15.777778, 16.777779, 17.777779, + 15.666667, 16.666666, 17.666666, 18.666666, 16.555555, 17.555555, + 18.555555, 19.555555, 17.444445, 18.444445, 19.444445, 20.444445, + 18.333334, 19.333334, 20.333334, 21.333334, 19.222221, 20.222221, + 21.222221, 22.222221, 20.11111, 21.11111, 22.11111, 23.11111, + 21., 22., 23., 24.}); + // input = 1.f; + input.linspace(1); + + sd::ops::resize_bilinear op; + auto results = op.evaluate({&input}, {}, {10, 10}, {true}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test4) { - - NDArray input = NDArrayFactory::create('c', {1, 2,3,4}); - NDArray size = NDArrayFactory::create({10, 10}); - NDArray expected = NDArrayFactory::create('c', {1, 10, 10, 4}, - { 1., 2., 3., 4. , - 1.8888888, 2.8888888, 3.8888888, 4.888889, - 2.7777777, 3.7777777, 4.7777777, 5.7777777, - 3.6666667, 4.666667 , 5.666667, 6.666667 , - 4.5555553, 5.5555553, 6.5555553, 7.5555553, - 5.4444447, 6.4444447, 7.4444447, 8.444445, - 6.3333335, 7.3333335, 8.333334, 9.333334, - 7.2222223, 8.222222, 9.222222, 10.222222, - 8.111111, 9.111111, 10.111111, 11.111111, - 9., 10., 11., 12., - - 2.3333335, 3.3333335, 4.3333335, 5.3333335, - 3.2222223, 4.2222223, 5.2222223, 6.2222223, - 4.111111, 5.111111, 6.111111, 7.111111, - 5., 6., 7., 8., - 5.888889, 6.888889, 7.888889, 8.888888, - 6.777778, 7.777778, 8.777778, 9.777778, - 7.666667, 8.666667, 9.666667, 10.666667, - 8.555555, 9.555555, 10.555555, 11.555555, - 9.444444, 10.444444, 11.444444, 12.444444, - 10.333333, 11.333333, 12.333333, 13.333333, - - 3.6666667, 4.666667, 5.666667, 6.666667, - 4.5555553, 5.5555553, 6.5555553, 7.5555553, - 5.4444447, 6.4444447, 7.4444447, 8.444445 , - 6.3333335, 7.3333335, 8.333334, 9.333334 , - 7.2222223, 8.222222, 9.222222, 10.222222 , - 8.111112, 9.111112, 10.111112, 11.111112 , - 9., 10., 11.000001, 12.000001 , - 9.888889, 10.888889, 11.888889, 12.888889 , - 10.777778, 11.777778, 12.777778, 13.777778 , - 11.666667, 12.666667, 13.666667, 14.666667, - - 5., 6., 7., 8., - 5.888889, 6.888889, 7.888889, 8.888889, - 6.7777777, 7.7777777, 8.777779, 9.777779, - 7.666667, 8.666667, 9.666667, 10.666667, - 8.555555, 9.555555, 10.555555, 11.555555, - 9.444445, 10.444445, 11.444445, 12.444445, - 10.333334, 11.333334, 12.333334, 13.333334, - 11.222222, 12.222222, 13.222222, 14.222222, - 12.111111, 13.111111, 14.111111, 15.111111, - 13., 14., 15., 16., - - 6.3333335, 7.3333335, 8.333334, 9.333334, - 7.2222223, 8.222222, 9.222222, 10.222222, - 8.111111, 9.111111, 10.111112, 11.111112, - 9., 10., 11., 12., - 9.888889, 10.888889, 11.888889, 12.888889, - 10.777779, 11.777779, 12.777779, 13.777779, - 11.666667, 12.666667, 13.666668, 14.666668, - 12.555555, 13.555555, 14.555555, 15.555555, - 13.444445, 14.444445, 15.444445, 16.444445, - 14.333334, 15.333334, 16.333334, 17.333334, - 7.666667, 8.666667, 9.666667, 10.666667, - 8.555555, 9.555555, 10.555555, 11.555555, - 9.444445, 10.444445, 11.444445, 12.444445, - 10.333334, 11.333334, 12.333334, 13.333334, - 11.222222, 12.222222, 13.222222, 14.222222, - 12.111112, 13.111112, 14.111112, 15.111112, - 13., 14., 15.0, 16., - 13.888889, 14.888889, 15.888889, 16.88889, - 14.777778, 15.777778, 16.777779, 17.777779, - 15.666667, 16.666668, 17.666668, 18.666668, - - 9., 10., 11., 12., - 9.888889, 10.888889, 11.888889, 12.888889, - 10.777778, 11.777778, 12.777779, 13.777779, - 11.666667, 12.666666, 13.666666, 14.666666, - 12.555555, 13.555555, 14.555555, 15.555555, - 13.444445, 14.444445, 15.444445, 16.444445, - 14.333334, 15.333334, 16.333334, 17.333334, - 15.222221, 16.222221, 17.222221, 18.222221, - 16.11111, 17.11111, 18.11111, 19.11111, - 17., 18., 19., 20., - - 10.333334, 11.333334, 12.333334, 13.333334, - 11.222223, 12.222223, 13.222223, 14.222223, - 12.111112, 13.111112, 14.111112, 15.111112, - 13.000001, 14., 15., 16., - 13.888889, 14.888889, 15.888889, 16.88889, - 14.777779, 15.777779, 16.777779, 17.777779, - 15.666668, 16.666668, 17.666668, 18.666668, - 16.555555, 17.555555, 18.555555, 19.555555, - 17.444445, 18.444445, 19.444445, 20.444445, - 18.333334, 19.333334, 20.333334, 21.333334, - 11.666667, 12.666667, 13.666667, 14.666667, - 12.555555, 13.555555, 14.555555, 15.555555, - 13.444445, 14.444445, 15.444446, 16.444447, - 14.333334, 15.333333, 16.333332, 17.333332, - 15.222222, 16.222221, 17.222221, 18.222221, - 16.11111, 17.11111, 18.11111, 19.11111, - 17., 18., 19., 20., - 17.88889, 18.88889, 19.88889, 20.88889, - 18.777779, 19.777779, 20.777779, 21.777779, - 19.666668, 20.666668, 21.666668, 22.666668, - - 13., 14., 15., 16., - 13.888889, 14.888889, 15.888889, 16.88889, - 14.777778, 15.777778, 16.777779, 17.777779, - 15.666667, 16.666666, 17.666666, 18.666666, - 16.555555, 17.555555, 18.555555, 19.555555, - 17.444445, 18.444445, 19.444445, 20.444445, - 18.333334, 19.333334, 20.333334, 21.333334, - 19.222221, 20.222221, 21.222221, 22.222221, - 20.11111, 21.11111, 22.11111, 23.11111, - 21., 22., 23., 24.}); - //input = 1.f; - input.linspace(1); - - sd::ops::resize_bilinear op; - auto results = op.evaluate({&input, &size}, {}, {}, {true}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - NDArray* result = results.at(0); -// result.printIndexedBuffer("Resized to 10x10"); -// expected.printIndexedBuffer("Expected of 10x10"); -// result.printShapeInfo("Resized to 10x10 shape"); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); + NDArray input = NDArrayFactory::create('c', {1, 2, 3, 4}); + NDArray size = NDArrayFactory::create({10, 10}); + NDArray expected = NDArrayFactory::create( + 'c', {1, 10, 10, 4}, + {1., 2., 3., 4., 1.8888888, 2.8888888, + 3.8888888, 4.888889, 2.7777777, 3.7777777, 4.7777777, 5.7777777, + 3.6666667, 4.666667, 5.666667, 6.666667, 4.5555553, 5.5555553, + 6.5555553, 7.5555553, 5.4444447, 6.4444447, 7.4444447, 8.444445, + 6.3333335, 7.3333335, 8.333334, 9.333334, 7.2222223, 8.222222, + 9.222222, 10.222222, 8.111111, 9.111111, 10.111111, 11.111111, + 9., 10., 11., 12., + + 2.3333335, 3.3333335, 4.3333335, 5.3333335, 3.2222223, 4.2222223, + 5.2222223, 6.2222223, 4.111111, 5.111111, 6.111111, 7.111111, + 5., 6., 7., 8., 5.888889, 6.888889, + 7.888889, 8.888888, 6.777778, 7.777778, 8.777778, 9.777778, + 7.666667, 8.666667, 9.666667, 10.666667, 8.555555, 9.555555, + 10.555555, 11.555555, 9.444444, 10.444444, 11.444444, 12.444444, + 10.333333, 11.333333, 12.333333, 13.333333, + + 3.6666667, 4.666667, 5.666667, 6.666667, 4.5555553, 5.5555553, + 6.5555553, 7.5555553, 5.4444447, 6.4444447, 7.4444447, 8.444445, + 6.3333335, 7.3333335, 8.333334, 9.333334, 7.2222223, 8.222222, + 9.222222, 10.222222, 8.111112, 9.111112, 10.111112, 11.111112, + 9., 10., 11.000001, 12.000001, 9.888889, 10.888889, + 11.888889, 12.888889, 10.777778, 11.777778, 12.777778, 13.777778, + 11.666667, 12.666667, 13.666667, 14.666667, + + 5., 6., 7., 8., 5.888889, 6.888889, + 7.888889, 8.888889, 6.7777777, 7.7777777, 8.777779, 9.777779, + 7.666667, 8.666667, 9.666667, 10.666667, 8.555555, 9.555555, + 10.555555, 11.555555, 9.444445, 10.444445, 11.444445, 12.444445, + 10.333334, 11.333334, 12.333334, 13.333334, 11.222222, 12.222222, + 13.222222, 14.222222, 12.111111, 13.111111, 14.111111, 15.111111, + 13., 14., 15., 16., + + 6.3333335, 7.3333335, 8.333334, 9.333334, 7.2222223, 8.222222, + 9.222222, 10.222222, 8.111111, 9.111111, 10.111112, 11.111112, + 9., 10., 11., 12., 9.888889, 10.888889, + 11.888889, 12.888889, 10.777779, 11.777779, 12.777779, 13.777779, + 11.666667, 12.666667, 13.666668, 14.666668, 12.555555, 13.555555, + 14.555555, 15.555555, 13.444445, 14.444445, 15.444445, 16.444445, + 14.333334, 15.333334, 16.333334, 17.333334, 7.666667, 8.666667, + 9.666667, 10.666667, 8.555555, 9.555555, 10.555555, 11.555555, + 9.444445, 10.444445, 11.444445, 12.444445, 10.333334, 11.333334, + 12.333334, 13.333334, 11.222222, 12.222222, 13.222222, 14.222222, + 12.111112, 13.111112, 14.111112, 15.111112, 13., 14., + 15.0, 16., 13.888889, 14.888889, 15.888889, 16.88889, + 14.777778, 15.777778, 16.777779, 17.777779, 15.666667, 16.666668, + 17.666668, 18.666668, + + 9., 10., 11., 12., 9.888889, 10.888889, + 11.888889, 12.888889, 10.777778, 11.777778, 12.777779, 13.777779, + 11.666667, 12.666666, 13.666666, 14.666666, 12.555555, 13.555555, + 14.555555, 15.555555, 13.444445, 14.444445, 15.444445, 16.444445, + 14.333334, 15.333334, 16.333334, 17.333334, 15.222221, 16.222221, + 17.222221, 18.222221, 16.11111, 17.11111, 18.11111, 19.11111, + 17., 18., 19., 20., + + 10.333334, 11.333334, 12.333334, 13.333334, 11.222223, 12.222223, + 13.222223, 14.222223, 12.111112, 13.111112, 14.111112, 15.111112, + 13.000001, 14., 15., 16., 13.888889, 14.888889, + 15.888889, 16.88889, 14.777779, 15.777779, 16.777779, 17.777779, + 15.666668, 16.666668, 17.666668, 18.666668, 16.555555, 17.555555, + 18.555555, 19.555555, 17.444445, 18.444445, 19.444445, 20.444445, + 18.333334, 19.333334, 20.333334, 21.333334, 11.666667, 12.666667, + 13.666667, 14.666667, 12.555555, 13.555555, 14.555555, 15.555555, + 13.444445, 14.444445, 15.444446, 16.444447, 14.333334, 15.333333, + 16.333332, 17.333332, 15.222222, 16.222221, 17.222221, 18.222221, + 16.11111, 17.11111, 18.11111, 19.11111, 17., 18., + 19., 20., 17.88889, 18.88889, 19.88889, 20.88889, + 18.777779, 19.777779, 20.777779, 21.777779, 19.666668, 20.666668, + 21.666668, 22.666668, + + 13., 14., 15., 16., 13.888889, 14.888889, + 15.888889, 16.88889, 14.777778, 15.777778, 16.777779, 17.777779, + 15.666667, 16.666666, 17.666666, 18.666666, 16.555555, 17.555555, + 18.555555, 19.555555, 17.444445, 18.444445, 19.444445, 20.444445, + 18.333334, 19.333334, 20.333334, 21.333334, 19.222221, 20.222221, + 21.222221, 22.222221, 20.11111, 21.11111, 22.11111, 23.11111, + 21., 22., 23., 24.}); + // input = 1.f; + input.linspace(1); + + sd::ops::resize_bilinear op; + auto results = op.evaluate({&input, &size}, {}, {}, {true}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + // result.printIndexedBuffer("Resized to 10x10"); + // expected.printIndexedBuffer("Expected of 10x10"); + // result.printShapeInfo("Resized to 10x10 shape"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, LinSpace_Test1) { + NDArray start = NDArrayFactory::create(1.); + NDArray finish = NDArrayFactory::create(12.); + NDArray num = NDArrayFactory::create(23); + NDArray expect = NDArrayFactory::create( + {1., 1.5, 2., 2.5, 3., 3.5, 4., 4.5, 5., 5.5, 6., 6.5, + 7., 7.5, 8., 8.5, 9., 9.5, 10., 10.5, 11., 11.5, 12.}); - NDArray start = NDArrayFactory::create(1.); - NDArray finish = NDArrayFactory::create(12.); - NDArray num = NDArrayFactory::create(23); - NDArray expect = NDArrayFactory::create({1., 1.5, 2., 2.5, 3., 3.5, 4., 4.5, 5., 5.5, 6., 6.5, 7., 7.5, - 8., 8.5, 9., 9.5, 10., 10.5, 11., 11.5, 12.}); - - sd::ops::lin_space op; - auto result = op.evaluate({&start, &finish, &num}, {}, {}); - ASSERT_EQ(result.status(), ND4J_STATUS_OK); - auto res = result.at(0); - - ASSERT_TRUE(expect.equalsTo(res)); + sd::ops::lin_space op; + auto result = op.evaluate({&start, &finish, &num}, {}, {}); + ASSERT_EQ(result.status(), ND4J_STATUS_OK); + auto res = result.at(0); + ASSERT_TRUE(expect.equalsTo(res)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, LinSpace_Test2) { + NDArray expect = NDArrayFactory::create( + {1., 1.5, 2., 2.5, 3., 3.5, 4., 4.5, 5., 5.5, 6., 6.5, + 7., 7.5, 8., 8.5, 9., 9.5, 10., 10.5, 11., 11.5, 12.}); - NDArray expect = NDArrayFactory::create({1., 1.5, 2., 2.5, 3., 3.5, 4., 4.5, 5., 5.5, 6., 6.5, 7., 7.5, - 8., 8.5, 9., 9.5, 10., 10.5, 11., 11.5, 12.}); - - sd::ops::lin_space op; - auto result = op.evaluate({}, {1, 12}, {23}); - ASSERT_EQ(result.status(), ND4J_STATUS_OK); - auto res = result.at(0); - ASSERT_EQ( res->dataType(), sd::DataType::FLOAT32 ); - ASSERT_TRUE(expect.equalsTo(res)); - + sd::ops::lin_space op; + auto result = op.evaluate({}, {1, 12}, {23}); + ASSERT_EQ(result.status(), ND4J_STATUS_OK); + auto res = result.at(0); + ASSERT_EQ(res.dataType(), sd::DataType::FLOAT32); + ASSERT_TRUE(expect.equalsTo(res)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, LinSpace_Test3) { + NDArray expect('c', {23}, + {1., 1.5, 2., 2.5, 3., 3.5, 4., 4.5, 5., 5.5, 6., 6.5, + 7., 7.5, 8., 8.5, 9., 9.5, 10., 10.5, 11., 11.5, 12.}, + sd::DataType::DOUBLE); - NDArray expect('c', { 23 }, {1., 1.5, 2., 2.5, 3., 3.5, 4., 4.5, 5., 5.5, 6., 6.5, 7., 7.5, 8., 8.5, 9., 9.5, 10., 10.5, 11., 11.5, 12.}, sd::DataType::DOUBLE ); - - sd::ops::lin_space op; - auto result = op.evaluate({}, {1, 12}, {23}, {}, { sd::DOUBLE }); - ASSERT_EQ(result.status(), ND4J_STATUS_OK); - auto res = result.at(0); - - ASSERT_EQ( res->dataType(), expect.dataType()); - ASSERT_TRUE(expect.equalsTo(res)); + sd::ops::lin_space op; + auto result = op.evaluate({}, {1, 12}, {23}, {}, {sd::DOUBLE}); + ASSERT_EQ(result.status(), ND4J_STATUS_OK); + auto res = result.at(0); + ASSERT_EQ(res.dataType(), expect.dataType()); + ASSERT_TRUE(expect.equalsTo(res)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, ImageResizeNeighbor_Test1) { + NDArray input = NDArrayFactory::create('c', {1, 2, 3, 4}); + // NDArray paddings('c', {3,2}, {0,0, 0,1, 0,0}); + // NDArray expected('c', {2,4,4}, + // {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.}); + NDArray expected = NDArrayFactory::create( + 'c', {1, 4, 5, 4}, {1, 2, 3, 4, 1, 2, 3, 4, 5, 6, + 7, 8, 5, 6, 7, 8, 9, 10, 11, 12, - NDArray input = NDArrayFactory::create('c', {1, 2, 3, 4}); - //NDArray paddings('c', {3,2}, {0,0, 0,1, 0,0}); - //NDArray expected('c', {2,4,4}, {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.}); - NDArray expected = NDArrayFactory::create('c', {1, 4, 5, 4}, { - 1, 2, 3, 4, - 1, 2, 3, 4, - 5, 6, 7, 8, - 5, 6, 7, 8, - 9, 10, 11, 12, - - 1, 2, 3, 4, - 1, 2, 3, 4, - 5, 6, 7, 8, - 5, 6, 7, 8, - 9, 10, 11, 12, - - 13, 14, 15, 16, - 13, 14, 15, 16, - 17, 18, 19, 20, - 17, 18, 19, 20, - 21, 22, 23, 24, - - 13, 14, 15, 16, - 13, 14, 15, 16, - 17, 18, 19, 20, - 17, 18, 19, 20, - 21, 22, 23, 24 - }); - //input = 1.f; - input.linspace(1); + 1, 2, 3, 4, 1, 2, 3, 4, 5, 6, + 7, 8, 5, 6, 7, 8, 9, 10, 11, 12, - sd::ops::resize_nearest_neighbor op; - auto results = op.evaluate({&input}, {}, {4, 5}, {false, false}); + 13, 14, 15, 16, 13, 14, 15, 16, 17, 18, + 19, 20, 17, 18, 19, 20, 21, 22, 23, 24, - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + 13, 14, 15, 16, 13, 14, 15, 16, 17, 18, + 19, 20, 17, 18, 19, 20, 21, 22, 23, 24}); + // input = 1.f; + input.linspace(1); - NDArray* result = results.at(0); + sd::ops::resize_nearest_neighbor op; + auto results = op.evaluate({&input}, {}, {4, 5}, {false, false}); -// result.printIndexedBuffer("Resized to 4x5"); -// expected.printIndexedBuffer("Expect for 4x5"); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + + // result.printIndexedBuffer("Resized to 4x5"); + // expected.printIndexedBuffer("Expect for 4x5"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } TEST_F(DeclarableOpsTests10, ImageResizeNeighbor_Test1_1) { + NDArray input = NDArrayFactory::create('c', {1, 2, 3, 4}); + // NDArray paddings('c', {3,2}, {0,0, 0,1, 0,0}); + // NDArray expected('c', {2,4,4}, + // {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.}); + NDArray expected = NDArrayFactory::create( + 'c', {1, 4, 5, 4}, {1, 2, 3, 4, 1, 2, 3, 4, 5, 6, + 7, 8, 5, 6, 7, 8, 9, 10, 11, 12, - NDArray input = NDArrayFactory::create('c', {1, 2, 3, 4}); - //NDArray paddings('c', {3,2}, {0,0, 0,1, 0,0}); - //NDArray expected('c', {2,4,4}, {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.}); - NDArray expected = NDArrayFactory::create('c', {1, 4, 5, 4}, { - 1, 2, 3, 4, - 1, 2, 3, 4, - 5, 6, 7, 8, - 5, 6, 7, 8, - 9, 10, 11, 12, - - 1, 2, 3, 4, - 1, 2, 3, 4, - 5, 6, 7, 8, - 5, 6, 7, 8, - 9, 10, 11, 12, - - 13, 14, 15, 16, - 13, 14, 15, 16, - 17, 18, 19, 20, - 17, 18, 19, 20, - 21, 22, 23, 24, - - 13, 14, 15, 16, - 13, 14, 15, 16, - 17, 18, 19, 20, - 17, 18, 19, 20, - 21, 22, 23, 24 - }); - //input = 1.f; - input.linspace(1); + 1, 2, 3, 4, 1, 2, 3, 4, 5, 6, + 7, 8, 5, 6, 7, 8, 9, 10, 11, 12, - sd::ops::resize_nearest_neighbor op; - auto results = op.evaluate({&input}, {}, {4, 5}); + 13, 14, 15, 16, 13, 14, 15, 16, 17, 18, + 19, 20, 17, 18, 19, 20, 21, 22, 23, 24, - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + 13, 14, 15, 16, 13, 14, 15, 16, 17, 18, + 19, 20, 17, 18, 19, 20, 21, 22, 23, 24}); + // input = 1.f; + input.linspace(1); - NDArray* result = results.at(0); + sd::ops::resize_nearest_neighbor op; + auto results = op.evaluate({&input}, {}, {4, 5}); -// result.printIndexedBuffer("Resized to 4x5"); -// expected.printIndexedBuffer("Expect for 4x5"); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + + // result.printIndexedBuffer("Resized to 4x5"); + // expected.printIndexedBuffer("Expect for 4x5"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } TEST_F(DeclarableOpsTests10, ImageResizeNeighbor_Test1_1_1) { + NDArray input = NDArrayFactory::create('c', {1, 2, 3, 4}); + // NDArray paddings('c', {3,2}, {0,0, 0,1, 0,0}); + // NDArray expected('c', {2,4,4}, + // {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.}); + NDArray expected = NDArrayFactory::create( + 'c', {1, 4, 5, 4}, + {1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, + 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 9.f, 10.f, 11.f, 12.f, - NDArray input = NDArrayFactory::create('c', {1, 2, 3, 4}); - //NDArray paddings('c', {3,2}, {0,0, 0,1, 0,0}); - //NDArray expected('c', {2,4,4}, {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.}); - NDArray expected = NDArrayFactory::create('c', {1, 4, 5, 4}, { - 1.f, 2.f, 3.f, 4.f, - 1.f, 2.f, 3.f, 4.f, - 5.f, 6.f, 7.f, 8.f, - 9.f, 10.f, 11.f, 12.f, - 9.f, 10.f, 11.f, 12.f, - - 1.f, 2.f, 3.f, 4.f, - 1.f, 2.f, 3.f, 4.f, - 5.f, 6.f, 7.f, 8.f, - 9.f, 10.f, 11.f, 12.f, - 9.f, 10.f, 11.f, 12.f, - - 13.f, 14.f, 15.f, 16.f, - 13.f, 14.f, 15.f, 16.f, - 17.f, 18.f, 19.f, 20.f, - 21.f, 22.f, 23.f, 24.f, - 21.f, 22.f, 23.f, 24.f, - - 13.f, 14.f, 15.f, 16.f, - 13.f, 14.f, 15.f, 16.f, - 17.f, 18.f, 19.f, 20.f, - 21.f, 22.f, 23.f, 24.f, - 21.f, 22.f, 23.f, 24.f - }); - //input = 1.f; - input.linspace(1); + 1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, + 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 9.f, 10.f, 11.f, 12.f, - sd::ops::resize_nearest_neighbor op; - auto results = op.evaluate({&input}, {}, {4,5}, {false, true}); + 13.f, 14.f, 15.f, 16.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, + 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 21.f, 22.f, 23.f, 24.f, - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + 13.f, 14.f, 15.f, 16.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, + 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 21.f, 22.f, 23.f, 24.f}); + // input = 1.f; + input.linspace(1); - NDArray* result = results.at(0); + sd::ops::resize_nearest_neighbor op; + auto results = op.evaluate({&input}, {}, {4, 5}, {false, true}); -// result.printIndexedBuffer("Resized to 4x5"); -// expected.printBuffer("Expect for 4x5"); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + + // result.printIndexedBuffer("Resized to 4x5"); + // expected.printBuffer("Expect for 4x5"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } TEST_F(DeclarableOpsTests10, ImageResizeNeighbor_Test01) { + NDArray input = NDArrayFactory::create('c', {2, 3, 4}); + // NDArray paddings('c', {3,2}, {0,0, 0,1, 0,0}); + // NDArray expected('c', {2,4,4}, + // {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.}); + NDArray expected = NDArrayFactory::create( + 'c', {4, 5, 4}, {1, 2, 3, 4, 1, 2, 3, 4, 5, 6, + 7, 8, 5, 6, 7, 8, 9, 10, 11, 12, - NDArray input = NDArrayFactory::create('c', {2, 3, 4}); - //NDArray paddings('c', {3,2}, {0,0, 0,1, 0,0}); - //NDArray expected('c', {2,4,4}, {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.}); - NDArray expected = NDArrayFactory::create('c', {4, 5, 4}, { 1, 2, 3, 4, - 1, 2, 3, 4, - 5, 6, 7, 8, - 5, 6, 7, 8, - 9, 10, 11, 12, - - 1, 2, 3, 4, - 1, 2, 3, 4, - 5, 6, 7, 8, - 5, 6, 7, 8, - 9, 10, 11, 12, - - 13, 14, 15, 16, - 13, 14, 15, 16, - 17, 18, 19, 20, - 17, 18, 19, 20, - 21, 22, 23, 24, - - 13, 14, 15, 16, - 13, 14, 15, 16, - 17, 18, 19, 20, - 17, 18, 19, 20, - 21, 22, 23, 24 - }); - //input = 1.f; - input.linspace(1); + 1, 2, 3, 4, 1, 2, 3, 4, 5, 6, + 7, 8, 5, 6, 7, 8, 9, 10, 11, 12, - sd::ops::resize_nearest_neighbor op; - auto results = op.evaluate({&input}, {}, {4, 5}); + 13, 14, 15, 16, 13, 14, 15, 16, 17, 18, + 19, 20, 17, 18, 19, 20, 21, 22, 23, 24, - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + 13, 14, 15, 16, 13, 14, 15, 16, 17, 18, + 19, 20, 17, 18, 19, 20, 21, 22, 23, 24}); + // input = 1.f; + input.linspace(1); - NDArray* result = results.at(0); + sd::ops::resize_nearest_neighbor op; + auto results = op.evaluate({&input}, {}, {4, 5}); - //result.printIndexedBuffer("Resized to 4x5"); - //expected.printIndexedBuffer("Expect for 4x5"); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + + // result.printIndexedBuffer("Resized to 4x5"); + // expected.printIndexedBuffer("Expect for 4x5"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, ReduceLogSumExpTest_1) { + NDArray input = + NDArrayFactory::create('c', {3, 3}, {0, 1, 0, 0, 1, 0, 0, 0, 0}); - NDArray input = NDArrayFactory::create ('c', {3,3}, {0, 1, 0, 0, 1, 0, 0, 0, 0}); + NDArray expected = NDArrayFactory::create(2.5206409f); - NDArray expected = NDArrayFactory::create(2.5206409f); + sd::ops::reduce_logsumexp op; + auto results = op.evaluate({&input}, {}, {}); - sd::ops::reduce_logsumexp op; - auto results = op.evaluate({&input}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); - - ASSERT_TRUE(expected.isSameShapeStrict(*result)); - ASSERT_TRUE(expected.equalsTo(result)); + auto result = results.at(0); + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, ReduceLogSumExpTest_2) { + NDArray input = + NDArrayFactory::create('c', {3, 3}, {0, 1, 0, 0, 1, 0, 0, 0, 0}); - NDArray input = NDArrayFactory::create('c', {3,3}, {0, 1, 0, 0, 1, 0, 0, 0, 0}); - - NDArray expected = NDArrayFactory::create({1.0986123f, 1.8619947f, 1.0986123f}); - - sd::ops::reduce_logsumexp op; - auto results = op.evaluate({&input}, {}, {0}); + NDArray expected = + NDArrayFactory::create({1.0986123f, 1.8619947f, 1.0986123f}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); -// result.printIndexedBuffer("REDUCE_LOGSUMEXP"); -// expected.printIndexedBuffer("LSE EXPECTED"); - ASSERT_TRUE(expected.isSameShapeStrict(*result)); - ASSERT_TRUE(expected.equalsTo(result)); + sd::ops::reduce_logsumexp op; + auto results = op.evaluate({&input}, {}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto result = results.at(0); + // result.printIndexedBuffer("REDUCE_LOGSUMEXP"); + // expected.printIndexedBuffer("LSE EXPECTED"); + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, ReduceLogSumExpTest_3) { + NDArray input = + NDArrayFactory::create('c', {3, 3}, {0, 1, 0, 0, 1, 0, 0, 0, 0}); - NDArray input = NDArrayFactory::create('c', {3,3}, {0, 1, 0, 0, 1, 0, 0, 0, 0}); - - NDArray expected = NDArrayFactory::create('c', {1,3}, {1.0986123f, 1.8619947f, 1.0986123f}); + NDArray expected = NDArrayFactory::create( + 'c', {1, 3}, {1.0986123f, 1.8619947f, 1.0986123f}); - sd::ops::reduce_logsumexp op; - auto results = op.evaluate({&input}, {1.f}, {0}); + sd::ops::reduce_logsumexp op; + auto results = op.evaluate({&input}, {1.f}, {0}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto result = results.at(0); -// result.printIndexedBuffer("REDUCE_LOGSUMEXP"); -// expected.printIndexedBuffer("LSE EXPECTED"); - ASSERT_TRUE(expected.isSameShapeStrict(*result)); - ASSERT_TRUE(expected.equalsTo(result)); + auto result = results.at(0); + // result.printIndexedBuffer("REDUCE_LOGSUMEXP"); + // expected.printIndexedBuffer("LSE EXPECTED"); + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_1) { + NDArray boxes = NDArrayFactory::create('c', {3, 4}); + NDArray scores = NDArrayFactory::create('c', {3}, {1, 2, 3}); + NDArray expected = NDArrayFactory::create('c', {3}, {2, 1, 0}); + boxes.linspace(1.f); - NDArray boxes = NDArrayFactory::create('c', {3,4}); - NDArray scores = NDArrayFactory::create('c', {3}, {1, 2, 3}); - NDArray expected = NDArrayFactory::create('c', {3}, {2, 1, 0}); - boxes.linspace(1.f); - - sd::ops::non_max_suppression op; - auto results = op.evaluate({&boxes, &scores}, {}, {3}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - NDArray* result = results.at(0); - //result.printIndexedBuffer("OOOOUUUUTTT"); + sd::ops::non_max_suppression op; + auto results = op.evaluate({&boxes, &scores}, {}, {3}); - ASSERT_TRUE(expected.isSameShapeStrict(*result)); - ASSERT_TRUE(expected.equalsTo(result)); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto result = results.at(0); + // result.printIndexedBuffer("OOOOUUUUTTT"); + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_2) { + NDArray boxes = NDArrayFactory::create( + 'c', {6, 4}, {0, 0, 1, 1, 0, 0.1f, 1, 1.1f, 0, -0.1f, 1.f, 0.9f, + 0, 10, 1, 11, 0, 10.1f, 1.f, 11.1f, 0, 100, 1, 101}); + NDArray scales = NDArrayFactory::create( + 'c', {6}, {0.9f, .75f, .6f, .95f, .5f, .3f}); // 3, 0, 1, 2, 4, 5 + NDArray expected = NDArrayFactory::create('c', {3}, {3, 0, 5}); - NDArray boxes = NDArrayFactory::create('c', {6,4}, {0, 0, 1, 1, 0, 0.1f, 1, 1.1f, 0, -0.1f, 1.f, 0.9f, - 0, 10, 1, 11, 0, 10.1f, 1.f, 11.1f, 0, 100, 1, 101}); - NDArray scales = NDArrayFactory::create('c', {6}, {0.9f, .75f, .6f, .95f, .5f, .3f}); //3, 0, 1, 2, 4, 5 - NDArray expected = NDArrayFactory::create('c', {3}, {3,0,5}); - - sd::ops::non_max_suppression op; - auto results = op.evaluate({&boxes, &scales}, {0.5}, {3}); + sd::ops::non_max_suppression op; + auto results = op.evaluate({&boxes, &scales}, {0.5}, {3}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - NDArray* result = results.at(0); -// result.printBuffer("NonMaxSuppression OUtput2"); - ASSERT_TRUE(expected.isSameShapeStrict(*result)); - ASSERT_TRUE(expected.equalsTo(result)); + auto result = results.at(0); + // result.printBuffer("NonMaxSuppression OUtput2"); + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_3) { + NDArray boxes = NDArrayFactory::create( + 'c', {3, 4}, + {0.8115f, 0.4121f, 0.0771f, 0.4863f, 0.7412f, 0.7607f, 0.1543f, 0.5479f, + 0.8223f, 0.2246f, 0.0049f, 0.6465f}); + NDArray scales = NDArrayFactory::create( + 'c', {3}, {0.0029f, 0.8135f, 0.4873f}); // 3, 0, 1, 2, 4, 5 + NDArray expected = NDArrayFactory::create('c', {1}, {1}); - NDArray boxes = NDArrayFactory::create('c', {3, 4}, {0.8115f, 0.4121f, 0.0771f, 0.4863f, - 0.7412f, 0.7607f, 0.1543f, 0.5479f, - 0.8223f, 0.2246f, 0.0049f, 0.6465f}); - NDArray scales = NDArrayFactory::create('c', {3}, {0.0029f, 0.8135f, 0.4873f}); //3, 0, 1, 2, 4, 5 - NDArray expected = NDArrayFactory::create('c', {1}, {1}); - - sd::ops::non_max_suppression op; - auto results = op.evaluate({&boxes, &scales}, {0.5, 0.5}, {2}); + sd::ops::non_max_suppression op; + auto results = op.evaluate({&boxes, &scales}, {0.5, 0.5}, {2}); - ASSERT_EQ(Status::OK(), results.status()); + ASSERT_EQ(Status::OK(), results.status()); - NDArray* result = results.at(0); -// result.printBuffer("NonMaxSuppression OUtput3"); - ASSERT_TRUE(expected.isSameShapeStrict(*result)); - ASSERT_TRUE(expected.equalsTo(result)); + auto result = results.at(0); + // result.printBuffer("NonMaxSuppression OUtput3"); + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_4) { - - NDArray boxes = NDArrayFactory::create('c', {3, 4}, {0.8115f, 0.4121f, 0.0771f, 0.4863f, - 0.7412f, 0.7607f, 0.1543f, 0.5479f, - 0.8223f, 0.2246f, 0.0049f, 0.6465f}); - NDArray scales = NDArrayFactory::create('c', {3}, {0.0029f, 0.8135f, 0.4873f}); //3, 0, 1, 2, 4, 5 - NDArray expected = NDArrayFactory::create('c', {1}, {1}); - NDArray maxSize = NDArrayFactory::create(2); - NDArray threshold = NDArrayFactory::create(0.5f); - NDArray scoreThreshold = NDArrayFactory::create(0.5); - sd::ops::non_max_suppression op; - auto results = op.evaluate({&boxes, &scales, &maxSize, &threshold, &scoreThreshold}, {}, {}); - - ASSERT_EQ(Status::OK(), results.status()); - - NDArray* result = results.at(0); -// result.printBuffer("NonMaxSuppression OUtput4"); - ASSERT_TRUE(expected.isSameShapeStrict(*result)); - ASSERT_TRUE(expected.equalsTo(result)); + NDArray boxes = NDArrayFactory::create( + 'c', {3, 4}, + {0.8115f, 0.4121f, 0.0771f, 0.4863f, 0.7412f, 0.7607f, 0.1543f, 0.5479f, + 0.8223f, 0.2246f, 0.0049f, 0.6465f}); + NDArray scales = NDArrayFactory::create( + 'c', {3}, {0.0029f, 0.8135f, 0.4873f}); // 3, 0, 1, 2, 4, 5 + NDArray expected = NDArrayFactory::create('c', {1}, {1}); + NDArray maxSize = NDArrayFactory::create(2); + NDArray threshold = NDArrayFactory::create(0.5f); + NDArray scoreThreshold = NDArrayFactory::create(0.5); + sd::ops::non_max_suppression op; + auto results = op.evaluate( + {&boxes, &scales, &maxSize, &threshold, &scoreThreshold}, {}, {}); + + ASSERT_EQ(Status::OK(), results.status()); + + auto result = results.at(0); + // result.printBuffer("NonMaxSuppression OUtput4"); + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_5) { - - NDArray boxes = NDArrayFactory::create('c', {3, 4}, {0.8115f, 0.4121f, 0.0771f, 0.4863f, - 0.7412f, 0.7607f, 0.1543f, 0.5479f, - 0.8223f, 0.2246f, 0.0049f, 0.6465f}); - NDArray scales = NDArrayFactory::create('c', {3}, {0.0029f, 0.8135f, 0.4873f}); //3, 0, 1, 2, 4, 5 - NDArray expected = NDArrayFactory::create('c', {2}, {1, 2}); - NDArray maxSize = NDArrayFactory::create(2); - NDArray threshold = NDArrayFactory::create(0.5f); - NDArray scoreThreshold = NDArrayFactory::create(-DataTypeUtils::infOrMax()); - sd::ops::non_max_suppression op; - auto results = op.evaluate({&boxes, &scales, &maxSize, &threshold, &scoreThreshold}, {}, {}); - - ASSERT_EQ(Status::OK(), results.status()); - - NDArray* result = results.at(0); -// result.printBuffer("NonMaxSuppression OUtput4"); - ASSERT_TRUE(expected.isSameShapeStrict(*result)); - ASSERT_TRUE(expected.equalsTo(result)); + NDArray boxes = NDArrayFactory::create( + 'c', {3, 4}, + {0.8115f, 0.4121f, 0.0771f, 0.4863f, 0.7412f, 0.7607f, 0.1543f, 0.5479f, + 0.8223f, 0.2246f, 0.0049f, 0.6465f}); + NDArray scales = NDArrayFactory::create( + 'c', {3}, {0.0029f, 0.8135f, 0.4873f}); // 3, 0, 1, 2, 4, 5 + NDArray expected = NDArrayFactory::create('c', {2}, {1, 2}); + NDArray maxSize = NDArrayFactory::create(2); + NDArray threshold = NDArrayFactory::create(0.5f); + NDArray scoreThreshold = + NDArrayFactory::create(-DataTypeUtils::infOrMax()); + sd::ops::non_max_suppression op; + auto results = op.evaluate( + {&boxes, &scales, &maxSize, &threshold, &scoreThreshold}, {}, {}); + + ASSERT_EQ(Status::OK(), results.status()); + + auto result = results.at(0); + // result.printBuffer("NonMaxSuppression OUtput4"); + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_6) { - - NDArray boxes = NDArrayFactory::create('c', {3, 4}, {0.8115f, 0.4121f, 0.0771f, 0.4863f, - 0.7412f, 0.7607f, 0.1543f, 0.5479f, - 0.8223f, 0.2246f, 0.0049f, 0.6465f}); - NDArray scales = NDArrayFactory::create('c', {3}, {0.0029f, 0.8135f, 0.4873f}); //3, 0, 1, 2, 4, 5 - NDArray expected = NDArrayFactory::create('c', {2}, {1,2}); - NDArray maxSize = NDArrayFactory::create(2); - NDArray threshold = NDArrayFactory::create(0.5f); - NDArray scoreThreshold = NDArrayFactory::create(-DataTypeUtils::infOrMax()); - sd::ops::non_max_suppression_v3 op; - auto results = op.evaluate({&boxes, &scales, &maxSize, &threshold, &scoreThreshold}, {}, {}); - - ASSERT_EQ(Status::OK(), results.status()); - - NDArray* result = results.at(0); -// result.printBuffer("NonMaxSuppression OUtput6"); -// result.printShapeInfo("Ouput6 shape is"); - ASSERT_TRUE(expected.isSameShapeStrict(*result)); - ASSERT_TRUE(expected.equalsTo(result)); + NDArray boxes = NDArrayFactory::create( + 'c', {3, 4}, + {0.8115f, 0.4121f, 0.0771f, 0.4863f, 0.7412f, 0.7607f, 0.1543f, 0.5479f, + 0.8223f, 0.2246f, 0.0049f, 0.6465f}); + NDArray scales = NDArrayFactory::create( + 'c', {3}, {0.0029f, 0.8135f, 0.4873f}); // 3, 0, 1, 2, 4, 5 + NDArray expected = NDArrayFactory::create('c', {2}, {1, 2}); + NDArray maxSize = NDArrayFactory::create(2); + NDArray threshold = NDArrayFactory::create(0.5f); + NDArray scoreThreshold = + NDArrayFactory::create(-DataTypeUtils::infOrMax()); + sd::ops::non_max_suppression_v3 op; + auto results = op.evaluate( + {&boxes, &scales, &maxSize, &threshold, &scoreThreshold}, {}, {}); + + ASSERT_EQ(Status::OK(), results.status()); + + auto result = results.at(0); + // result.printBuffer("NonMaxSuppression OUtput6"); + // result.printShapeInfo("Ouput6 shape is"); + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_06) { - - NDArray boxes = NDArrayFactory::create('c', {3, 4}, {0.8115f, 0.4121f, 0.0771f, 0.4863f, - 0.7412f, 0.7607f, 0.1543f, 0.5479f, - 0.8223f, 0.2246f, 0.0049f, 0.6465f}); - NDArray scales = NDArrayFactory::create('c', {3}, {0.0029f, 0.8135f, 0.4873f}); //3, 0, 1, 2, 4, 5 - NDArray expected = NDArrayFactory::create('c', {2}, {1,2}); - NDArray maxSize = NDArrayFactory::create(2); - NDArray threshold = NDArrayFactory::create(0.5f); - NDArray scoreThreshold = NDArrayFactory::create(-DataTypeUtils::infOrMax()); - sd::ops::non_max_suppression_v3 op; - auto results = op.evaluate({&boxes, &scales, &maxSize, &threshold, &scoreThreshold}, {}, {}); - - ASSERT_EQ(Status::OK(), results.status()); - - NDArray* result = results.at(0); -// result.printBuffer("NonMaxSuppression OUtput06"); -// result.printShapeInfo("Ouput06 shape is"); - ASSERT_TRUE(expected.isSameShapeStrict(*result)); - ASSERT_TRUE(expected.equalsTo(result)); + NDArray boxes = NDArrayFactory::create( + 'c', {3, 4}, + {0.8115f, 0.4121f, 0.0771f, 0.4863f, 0.7412f, 0.7607f, 0.1543f, 0.5479f, + 0.8223f, 0.2246f, 0.0049f, 0.6465f}); + NDArray scales = NDArrayFactory::create( + 'c', {3}, {0.0029f, 0.8135f, 0.4873f}); // 3, 0, 1, 2, 4, 5 + NDArray expected = NDArrayFactory::create('c', {2}, {1, 2}); + NDArray maxSize = NDArrayFactory::create(2); + NDArray threshold = NDArrayFactory::create(0.5f); + NDArray scoreThreshold = + NDArrayFactory::create(-DataTypeUtils::infOrMax()); + sd::ops::non_max_suppression_v3 op; + auto results = op.evaluate( + {&boxes, &scales, &maxSize, &threshold, &scoreThreshold}, {}, {}); + + ASSERT_EQ(Status::OK(), results.status()); + + auto result = results.at(0); + // result.printBuffer("NonMaxSuppression OUtput06"); + // result.printShapeInfo("Ouput06 shape is"); + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_7) { - - NDArray boxes = NDArrayFactory::create('c', {3, 4}, {0.7788f, 0.8012f, 0.7244f, 0.2329f, - 0.7271f, 0.1804f, 0.5056f, 0.8929f, - 0.5461f, 0.9234f, 0.0856f, 0.7938f}); - NDArray scales = NDArrayFactory::create('c', {3}, {0.7717f, 0.9281f, 0.9846f}); //3, 0, 1, 2, 4, 5 - NDArray maxSize = NDArrayFactory::create(0); - NDArray threshold = NDArrayFactory::create(0.5f); - NDArray scoreThreshold = NDArrayFactory::create(0.5f); - sd::ops::non_max_suppression_v3 op; - auto results = op.evaluate({&boxes, &scales, &maxSize, &threshold, &scoreThreshold}, {}, {}); - - ASSERT_EQ(Status::OK(), results.status()); - - NDArray* result = results.at(0); -// result.printBuffer("NonMaxSuppression OUtput7"); -// result.printShapeInfo("Ouput6 shape is"); - ASSERT_TRUE(result->isEmpty()); + NDArray boxes = NDArrayFactory::create( + 'c', {3, 4}, + {0.7788f, 0.8012f, 0.7244f, 0.2329f, 0.7271f, 0.1804f, 0.5056f, 0.8929f, + 0.5461f, 0.9234f, 0.0856f, 0.7938f}); + NDArray scales = NDArrayFactory::create( + 'c', {3}, {0.7717f, 0.9281f, 0.9846f}); // 3, 0, 1, 2, 4, 5 + NDArray maxSize = NDArrayFactory::create(0); + NDArray threshold = NDArrayFactory::create(0.5f); + NDArray scoreThreshold = NDArrayFactory::create(0.5f); + sd::ops::non_max_suppression_v3 op; + auto results = op.evaluate( + {&boxes, &scales, &maxSize, &threshold, &scoreThreshold}, {}, {}); + + ASSERT_EQ(Status::OK(), results.status()); + + auto result = results.at(0); + // result.printBuffer("NonMaxSuppression OUtput7"); + // result.printShapeInfo("Ouput6 shape is"); + ASSERT_TRUE(result.isEmpty()); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressingOverlap_1) { + auto boxes = NDArrayFactory::create('c', {4, 4}, + {0, 0, 1, 1, 0, 0.1, 1, 1.1, 0, -0.1, 1, 0.9, 0, 10, 1, 11}); + auto scores = NDArrayFactory::create('c', {4}, {0.9, .75, .6, .95}); // 3 + auto max_num = NDArrayFactory::create(3); + auto expected = NDArrayFactory::create('c',{1}, {3}); - NDArray boxes = NDArrayFactory::create('c', {4,4}, { - 0, 0, 1, 1, - 0, 0.1, 1, 1.1, - 0, -0.1, 1, 0.9, - 0, 10, 1, 11}); - NDArray scores = NDArrayFactory::create('c', {4}, {0.9, .75, .6, .95}); //3 - NDArray max_num = NDArrayFactory::create(3); - NDArray expected = NDArrayFactory::create('c', {1,}, {3}); + sd::ops::non_max_suppression_overlaps op; + auto results = op.evaluate({&boxes, &scores, &max_num}, {0.5, 0.}, {}); - sd::ops::non_max_suppression_overlaps op; - auto results = op.evaluate({&boxes, &scores, &max_num}, {0.5, 0.}, {}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - NDArray* result = results.at(0); -// result.printBuffer("NonMaxSuppressionOverlap1 Output"); - ASSERT_TRUE(expected.isSameShapeStrict(*result)); - ASSERT_TRUE(expected.equalsTo(result)); + auto result = results.at(0); + // result.printBuffer("NonMaxSuppressionOverlap1 Output"); + ASSERT_EQ(expected, result); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressingOverlap_2) { + auto boxes = NDArrayFactory::create( + 'c', {4, 4}, {0, 0, 1, 1, 0, 0.1, 1, 1.1, 0, -0.1, 1, 0.9, 0, 10, 1, 11}); + auto scores = NDArrayFactory::create('c', {4}, {0.9, .95, .6, .75}); // 3 + auto max_num = NDArrayFactory::create(3); + auto expected = NDArrayFactory::create('c', {3}, {1, 1, 1}); - NDArray boxes = NDArrayFactory::create('c', {4,4}, { - 0, 0, 1, 1, - 0, 0.1, 1, 1.1, - 0, -0.1, 1, 0.9, - 0, 10, 1, 11}); - NDArray scores = NDArrayFactory::create('c', {4}, {0.9, .95, .6, .75}); //3 - NDArray max_num = NDArrayFactory::create(3); - NDArray expected = NDArrayFactory::create('c', {3,}, {1,1,1}); - - sd::ops::non_max_suppression_overlaps op; - auto results = op.evaluate({&boxes, &scores, &max_num}, {0.5, 0.}, {}); + sd::ops::non_max_suppression_overlaps op; + auto results = op.evaluate({&boxes, &scores, &max_num}, {0.5, 0.}, {}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - NDArray* result = results.at(0); -// result.printBuffer("NonMaxSuppressionOverlap Output"); - ASSERT_TRUE(expected.isSameShapeStrict(*result)); - ASSERT_TRUE(expected.equalsTo(result)); + auto result = results.at(0); + // result.printBuffer("NonMaxSuppressionOverlap Output"); + ASSERT_EQ(expected, result); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressingOverlap_3) { + auto boxes = NDArrayFactory::create('c', {4, 4}, + {0, 0, 1, 1, 0, 0.1, 1, 1.1, 0, -0.1, 1, 0.9, 0, 10, 1, 11}); + auto scores = NDArrayFactory::create('c', {4}, {0.5, .95, -.6, .75}); // 3 + auto max_num = NDArrayFactory::create(5); + auto expected = NDArrayFactory::create('c', {5}, {1, 1, 1, 1, 1}); - NDArray boxes = NDArrayFactory::create('c', {4,4}, { - 0, 0, 1, 1, - 0, 0.1, 1, 1.1, - 0, -0.1, 1, 0.9, - 0, 10, 1, 11}); - NDArray scores = NDArrayFactory::create('c', {4}, {0.5, .95, -.6, .75}); //3 - NDArray max_num = NDArrayFactory::create(5); - NDArray expected = NDArrayFactory::create('c', {5,}, {1,1,1,1,1}); + sd::ops::non_max_suppression_overlaps op; + auto results = op.evaluate({&boxes, &scores, &max_num}, {0.5, 0.}, {}); - sd::ops::non_max_suppression_overlaps op; - auto results = op.evaluate({&boxes, &scores, &max_num}, {0.5, 0.}, {}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - NDArray* result = results.at(0); -// result.printBuffer("NonMaxSuppressionOverlap Output"); - ASSERT_TRUE(expected.isSameShapeStrict(*result)); - ASSERT_TRUE(expected.equalsTo(result)); + auto result = results.at(0); + // result.printBuffer("NonMaxSuppressionOverlap Output"); + ASSERT_EQ(expected, result); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, Image_CropAndResize_1) { - int axis = 0; - NDArray images = NDArrayFactory::create('c', {1,2,2,1}, {1,2,3,4}); - NDArray boxes = NDArrayFactory::create('c', {1,4}, {0,0,1,1}); - NDArray boxI = NDArrayFactory::create('c', {1}, {axis}); - NDArray cropSize = NDArrayFactory::create({1, 1}); + int axis = 0; + auto images = NDArrayFactory::create('c', {1, 2, 2, 1}, {1, 2, 3, 4}); + auto boxes = NDArrayFactory::create('c', {1, 4}, {0, 0, 1, 1}); + auto boxI = NDArrayFactory::create('c', {1}, {axis}); + auto cropSize = NDArrayFactory::create({1, 1}); - //NDArray ('c', {6}, {0.9f, .75f, .6f, .95f, .5f, .3f}); - NDArray expected = NDArrayFactory::create('c', {1,1,1,1}, {2.5f}); + // NDArray ('c', {6}, {0.9f, .75f, .6f, .95f, .5f, .3f}); + NDArray expected = NDArrayFactory::create('c', {1, 1, 1, 1}, {2.5f}); - sd::ops::crop_and_resize op; - auto results = op.evaluate({&images, &boxes, &boxI, &cropSize}, {}, {}); + sd::ops::crop_and_resize op; + auto results = op.evaluate({&images, &boxes, &boxI, &cropSize}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto result = results.at(0); -// result.printIndexedBuffer("Cropped and Resized"); + auto result = results.at(0); + // result.printIndexedBuffer("Cropped and Resized"); - ASSERT_TRUE(expected.isSameShapeStrict(*result)); - ASSERT_TRUE(expected.equalsTo(result)); + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, Image_CropAndResize_2) { - int axis = 0; - NDArray images = NDArrayFactory::create('c', {1,2,2,1}, {1.f, 2.f, 3.f, 4.f}); - NDArray boxes = NDArrayFactory::create('c', {1,4}, {0.f, 0.f, 1.f, 1.f}); - NDArray boxI = NDArrayFactory::create('c', {1}, {axis}); - NDArray cropSize = NDArrayFactory::create({1, 1}); + int axis = 0; + NDArray images = + NDArrayFactory::create('c', {1, 2, 2, 1}, {1.f, 2.f, 3.f, 4.f}); + NDArray boxes = + NDArrayFactory::create('c', {1, 4}, {0.f, 0.f, 1.f, 1.f}); + NDArray boxI = NDArrayFactory::create('c', {1}, {axis}); + NDArray cropSize = NDArrayFactory::create({1, 1}); - //NDArray ('c', {6}, {0.9f, .75f, .6f, .95f, .5f, .3f}); - NDArray expected = NDArrayFactory::create('c', {1,1,1,1}, {4.f}); + // NDArray ('c', {6}, {0.9f, .75f, .6f, .95f, .5f, .3f}); + NDArray expected = NDArrayFactory::create('c', {1, 1, 1, 1}, {4.f}); - sd::ops::crop_and_resize op; - auto results = op.evaluate({&images, &boxes, &boxI, &cropSize}, {}, {1}); + sd::ops::crop_and_resize op; + auto results = op.evaluate({&images, &boxes, &boxI, &cropSize}, {}, {1}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(expected.isSameShapeStrict(*result)); - ASSERT_TRUE(expected.equalsTo(result)); + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, Image_CropAndResize_3) { + NDArray images('c', {1, 2, 2, 1}, {1, 2, 3, 4}, sd::DataType::FLOAT32); + NDArray boxes('c', {1, 4}, {0, 0, 1, 1}, sd::DataType::FLOAT32); + NDArray boxI('c', {1}, std::vector{0}, sd::DataType::INT64); + NDArray cropSize = NDArrayFactory::create({3, 3}); - NDArray images ('c', {1,2,2,1}, {1,2,3,4}, sd::DataType::FLOAT32); - NDArray boxes('c', {1,4}, {0,0,1,1}, sd::DataType::FLOAT32); - NDArray boxI('c', {1}, std::vector{0}, sd::DataType::INT64); - NDArray cropSize = NDArrayFactory::create({3, 3}); + // NDArray ('c', {6}, {0.9f, .75f, .6f, .95f, .5f, .3f}); + NDArray expected('c', {1, 3, 3, 1}, + {1.f, 1.5f, 2., 2.f, 2.5f, 3.f, 3.f, 3.5f, 4.f}, + sd::DataType::FLOAT32); - //NDArray ('c', {6}, {0.9f, .75f, .6f, .95f, .5f, .3f}); - NDArray expected('c', {1,3,3,1}, {1.f, 1.5f, 2., 2.f, 2.5f, 3.f, 3.f, 3.5f, 4.f}, sd::DataType::FLOAT32); + sd::ops::crop_and_resize op; + auto results = op.evaluate({&images, &boxes, &boxI, &cropSize}, {}, {0}); - sd::ops::crop_and_resize op; - auto results = op.evaluate({&images, &boxes, &boxI, &cropSize}, {}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(expected.isSameShapeStrict(*result)); - ASSERT_TRUE(expected.equalsTo(result)); + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, Image_CropAndResize_4) { + NDArray images('c', {1, 2, 2, 1}, {1, 2, 3, 4}, sd::DataType::FLOAT32); + NDArray boxes('c', {1, 4}, {0, 0, 1, 1}, sd::DataType::FLOAT32); + NDArray boxI('c', {1}, std::vector({0.}), sd::DataType::INT32); + NDArray cropSize = NDArrayFactory::create({3, 3}); - NDArray images('c', {1,2,2,1}, {1, 2, 3, 4}, sd::DataType::FLOAT32); - NDArray boxes('c', {1,4}, {0,0,1,1}, sd::DataType::FLOAT32); - NDArray boxI('c', {1}, std::vector({0.}), sd::DataType::INT32); - NDArray cropSize = NDArrayFactory::create({3, 3}); - - //NDArray ('c', {6}, {0.9f, .75f, .6f, .95f, .5f, .3f}); - NDArray expected('c', {1,3,3,1}, {1.f, 2.f, 2.f, 3.f, 4, 4.f, 3.f, 4.f, 4.f}, sd::DataType::FLOAT32); + // NDArray ('c', {6}, {0.9f, .75f, .6f, .95f, .5f, .3f}); + NDArray expected('c', {1, 3, 3, 1}, + {1.f, 2.f, 2.f, 3.f, 4, 4.f, 3.f, 4.f, 4.f}, + sd::DataType::FLOAT32); - sd::ops::crop_and_resize op; - auto results = op.evaluate({&images, &boxes, &boxI, &cropSize}, {}, {1}); + sd::ops::crop_and_resize op; + auto results = op.evaluate({&images, &boxes, &boxI, &cropSize}, {}, {1}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto result = results.at(0); - // result.printIndexedBuffer("Cropped and Resized"); - ASSERT_TRUE(expected.isSameShapeStrict(*result)); - ASSERT_TRUE(expected.equalsTo(result)); + auto result = results.at(0); + // result.printIndexedBuffer("Cropped and Resized"); + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, Image_CropAndResize_5) { + NDArray images('c', {1, 100, 100, 3}, sd::DataType::FLOAT32); + NDArray boxes('c', {1, 4}, {0, 0, 1, 1}, sd::DataType::FLOAT32); + NDArray boxI('c', {2}, {1, 1}, sd::DataType::INT32); + NDArray cropSize = NDArrayFactory::create({10, 10}); - NDArray images('c', {1, 100, 100, 3}, sd::DataType::FLOAT32); - NDArray boxes('c', {1,4}, {0,0,1,1}, sd::DataType::FLOAT32); - NDArray boxI('c', {2}, {1,1}, sd::DataType::INT32); - NDArray cropSize = NDArrayFactory::create({10, 10}); - - //NDArray ('c', {6}, {0.9f, .75f, .6f, .95f, .5f, .3f}); - NDArray expected('c', {1, 10, 10,3}, sd::DataType::FLOAT32); + // NDArray ('c', {6}, {0.9f, .75f, .6f, .95f, .5f, .3f}); + NDArray expected('c', {1, 10, 10, 3}, sd::DataType::FLOAT32); - sd::ops::crop_and_resize op; - auto results = op.evaluate({&images, &boxes, &boxI, &cropSize}, {}, {1}); + sd::ops::crop_and_resize op; + auto results = op.evaluate({&images, &boxes, &boxI, &cropSize}, {}, {1}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(expected.isSameShapeStrict(*result)); - //ASSERT_TRUE(expected.equalsTo(result)); + ASSERT_TRUE(expected.isSameShapeStrict(result)); + // ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, Image_DrawBoundingBoxes_1) { - NDArray images = NDArrayFactory::create('c', {2,4,5,3}); - NDArray boxes = NDArrayFactory::create('c', {2, 2, 4}, { - 0.f , 0.f , 1.f , 1.f , 0.1f, 0.2f, 0.9f, 0.8f, - 0.3f, 0.3f, 0.7f, 0.7f, 0.4f, 0.4f, 0.6f, 0.6f - }); - - NDArray colors = NDArrayFactory::create('c', {2, 3}, {201.f, 202.f, 203.f, 127.f, 128.f, 129.f}); - - //NDArray ('c', {6}, {0.9f, .75f, .6f, .95f, .5f, .3f}); - NDArray expected = NDArrayFactory::create('c', {2,4,5,3}, { - 127.f, 128.f, 129.f, 127.f, 128.f, 129.f, 127.f, 128.f, 129.f, 127.f, 128.f, 129.f, 201.f, 202.f, 203.f, - 127.f, 128.f, 129.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 127.f, 128.f, 129.f, 201.f, 202.f, 203.f, - 127.f, 128.f, 129.f, 127.f, 128.f, 129.f, 127.f, 128.f, 129.f, 127.f, 128.f, 129.f, 201.f, 202.f, 203.f, - 201.f, 202.f, 203.f, 201.f, 202.f, 203.f, 201.f, 202.f, 203.f, 201.f, 202.f, 203.f, 201.f, 202.f, 203.f, - - 61.f, 62.f, 63.f, 201.f, 202.f, 203.f, 201.f, 202.f, 203.f, 70.f, 71.f, 72.f, 73.f, 74.f, 75.f, - 76.f, 77.f, 78.f, 127.f, 128.f, 129.f, 127.f, 128.f, 129.f, 85.f, 86.f, 87.f, 88.f, 89.f, 90.f, - 91.f, 92.f, 93.f, 201.f, 202.f, 203.f, 201.f, 202.f, 203.f, 100.f, 101.f, 102.f, 103.f, 104.f, 105.f, - 106.f, 107.f, 108.f, 109.f, 110.f, 111.f, 112.f, 113.f, 114.f, 115.f, 116.f, 117.f, 118.f, 119.f, 120.f - }); - images.linspace(1.); - sd::ops::draw_bounding_boxes op; - auto results = op.evaluate({&images, &boxes, &colors}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); - result->syncToHost(); -// result.printBuffer("Bounded boxes"); -// expected.printBuffer("Bounded expec"); - ASSERT_TRUE(expected.isSameShapeStrict(*result)); - ASSERT_TRUE(expected.equalsTo(result)); + NDArray images = NDArrayFactory::create('c', {2, 4, 5, 3}); + NDArray boxes = NDArrayFactory::create( + 'c', {2, 2, 4}, + {0.f, 0.f, 1.f, 1.f, 0.1f, 0.2f, 0.9f, 0.8f, 0.3f, 0.3f, 0.7f, 0.7f, 0.4f, 0.4f, 0.6f, 0.6f}); + + NDArray colors = NDArrayFactory::create( + 'c', {2, 3}, {201.f, 202.f, 203.f, 127.f, 128.f, 129.f}); + + // NDArray ('c', {6}, {0.9f, .75f, .6f, .95f, .5f, .3f}); + NDArray expected = NDArrayFactory::create( + 'c', {2, 4, 5, 3}, + {127.f, 128.f, 129.f, 127.f, 128.f, 129.f, 127.f, 128.f, 129.f, 127.f, + 128.f, 129.f, 201.f, 202.f, 203.f, 127.f, 128.f, 129.f, 19.f, 20.f, + 21.f, 22.f, 23.f, 24.f, 127.f, 128.f, 129.f, 201.f, 202.f, 203.f, + 127.f, 128.f, 129.f, 127.f, 128.f, 129.f, 127.f, 128.f, 129.f, 127.f, + 128.f, 129.f, 201.f, 202.f, 203.f, 201.f, 202.f, 203.f, 201.f, 202.f, + 203.f, 201.f, 202.f, 203.f, 201.f, 202.f, 203.f, 201.f, 202.f, 203.f, + + 61.f, 62.f, 63.f, 201.f, 202.f, 203.f, 201.f, 202.f, 203.f, 70.f, + 71.f, 72.f, 73.f, 74.f, 75.f, 76.f, 77.f, 78.f, 127.f, 128.f, + 129.f, 127.f, 128.f, 129.f, 85.f, 86.f, 87.f, 88.f, 89.f, 90.f, + 91.f, 92.f, 93.f, 201.f, 202.f, 203.f, 201.f, 202.f, 203.f, 100.f, + 101.f, 102.f, 103.f, 104.f, 105.f, 106.f, 107.f, 108.f, 109.f, 110.f, + 111.f, 112.f, 113.f, 114.f, 115.f, 116.f, 117.f, 118.f, 119.f, 120.f}); + + images.linspace(1.); + sd::ops::draw_bounding_boxes op; + auto results = op.evaluate({&images, &boxes, &colors}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + result.syncToHost(); + // result.printBuffer("Bounded boxes"); + // expected.printBuffer("Bounded expec"); + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, Image_DrawBoundingBoxes_2) { - NDArray images = NDArrayFactory::create('c', {1,9,9,1}); - NDArray boxes = NDArrayFactory::create('c', {1, 1, 4}, {0.2f, 0.2f, 0.7f, 0.7f}); - NDArray colors = NDArrayFactory::create('c', {1, 1}, {0.95f}); - - //NDArray ('c', {6}, {0.9f, .75f, .6f, .95f, .5f, .3f}); - NDArray expected = NDArrayFactory::create('c', {1,9,9,1}, { - 1.1f , 2.1f, 3.1f, 4.1f, 5.1f, 6.1f, 7.1f , 8.1f , 9.1f , - 10.1f , 0.95f, 0.95f, 0.95f, 0.95f, 0.95f, 16.1f , 17.1f , 18.1f , - 19.1f , 0.95f, 21.1f, 22.1f, 23.1f, 0.95f, 25.1f , 26.1f , 27.1f , - 28.1f , 0.95f, 30.1f, 31.1f, 32.1f, 0.95f, 34.1f , 35.1f , 36.1f , - 37.1f , 0.95f, 39.1f, 40.1f, 41.1f, 0.95f, 43.1f , 44.1f , 45.1f , - 46.1f , 0.95f, 0.95f, 0.95f, 0.95f, 0.95f, 52.1f , 53.1f , 54.1f , - 55.1f , 56.1f, 57.1f, 58.1f, 59.1f , 60.1f, 61.1f , 62.1f , 63.1f , - 64.1f , 65.1f, 66.1f, 67.1f, 68.1f , 69.1f, 70.1f , 71.1f , 72.1f , - 73.1f , 74.1f, 75.1f, 76.1f, 77.1f , 78.1f, 79.1f , 80.1f , 81.1f }); - images.linspace(1.1); - sd::ops::draw_bounding_boxes op; - auto results = op.evaluate({&images, &boxes, &colors}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); -// result.syncToHost(); -// result.printBuffer("Bounded boxes 2"); -// expected.printBuffer("Bounded expec 2"); - ASSERT_TRUE(expected.isSameShapeStrict(*result)); - ASSERT_TRUE(expected.equalsTo(result)); + NDArray images = NDArrayFactory::create('c', {1, 9, 9, 1}); + NDArray boxes = + NDArrayFactory::create('c', {1, 1, 4}, {0.2f, 0.2f, 0.7f, 0.7f}); + NDArray colors = NDArrayFactory::create('c', {1, 1}, {0.95f}); + + // NDArray ('c', {6}, {0.9f, .75f, .6f, .95f, .5f, .3f}); + NDArray expected = NDArrayFactory::create( + 'c', {1, 9, 9, 1}, + {1.1f, 2.1f, 3.1f, 4.1f, 5.1f, 6.1f, 7.1f, 8.1f, 9.1f, + 10.1f, 0.95f, 0.95f, 0.95f, 0.95f, 0.95f, 16.1f, 17.1f, 18.1f, + 19.1f, 0.95f, 21.1f, 22.1f, 23.1f, 0.95f, 25.1f, 26.1f, 27.1f, + 28.1f, 0.95f, 30.1f, 31.1f, 32.1f, 0.95f, 34.1f, 35.1f, 36.1f, + 37.1f, 0.95f, 39.1f, 40.1f, 41.1f, 0.95f, 43.1f, 44.1f, 45.1f, + 46.1f, 0.95f, 0.95f, 0.95f, 0.95f, 0.95f, 52.1f, 53.1f, 54.1f, + 55.1f, 56.1f, 57.1f, 58.1f, 59.1f, 60.1f, 61.1f, 62.1f, 63.1f, + 64.1f, 65.1f, 66.1f, 67.1f, 68.1f, 69.1f, 70.1f, 71.1f, 72.1f, + 73.1f, 74.1f, 75.1f, 76.1f, 77.1f, 78.1f, 79.1f, 80.1f, 81.1f}); + images.linspace(1.1); + sd::ops::draw_bounding_boxes op; + auto results = op.evaluate({&images, &boxes, &colors}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + // result.syncToHost(); + // result.printBuffer("Bounded boxes 2"); + // expected.printBuffer("Bounded expec 2"); + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, Image_DrawBoundingBoxes_3) { - NDArray images = NDArrayFactory::create('c', {2,5,5,1}, {0.7788f, 0.8012f, 0.7244f, 0.2309f, 0.7271f, 0.1804f, - 0.5056f, 0.8925f, 0.5461f, 0.9234f, 0.0856f, 0.7938f, - 0.6591f, 0.5555f, 0.1596f, 0.3087f, 0.1548f, 0.4695f, - 0.9939f, 0.6113f, 0.6765f, 0.1800f, 0.6750f, 0.2246f, - 0.0509f, 0.4601f, 0.8284f, 0.2354f, 0.9752f, 0.8361f, - 0.2585f, 0.4189f, 0.7028f, 0.7679f, 0.5373f, 0.7234f, - 0.2690f, 0.0062f, 0.0327f, 0.0644f, 0.8428f, 0.7494f, - 0.0755f, 0.6245f, 0.3491f, 0.5793f, 0.5730f, 0.1822f, - 0.6420f, 0.9143f}); - - NDArray boxes = NDArrayFactory::create('c', {2, 2, 4}, {0.7717f, 0.9281f, 0.9846f, 0.4838f, - 0.6433f, 0.6041f, 0.6501f, 0.7612f, - 0.7605f, 0.3948f, 0.9493f, 0.8600f, - 0.7876f, 0.8945f, 0.4638f, 0.7157f}); - NDArray colors = NDArrayFactory::create('c', {1, 2}, {0.9441f, 0.5957f}); - - //NDArray ('c', {6}, {0.9f, .75f, .6f, .95f, .5f, .3f}); -// NDArray expected = NDArrayFactory::create('c', {2,5,5,1}, { -// 0.7788f, 0.8012f, 0.7244f, 0.2309f, 0.7271f, -// 0.1804f, 0.5056f, 0.8925f, 0.5461f, 0.9234f, 0.0856f, 0.7938f, 0.9441f, -// 0.9441f, 0.1596f, 0.3087f, 0.1548f, 0.4695f, 0.9939f, 0.6113f, 0.6765f, -// 0.1800f, 0.6750f, 0.2246f, 0.0509f, 0.4601f, 0.8284f, 0.2354f, 0.9752f, 0.8361f, -// 0.2585f, 0.4189f,0.7028f,0.7679f,0.5373f,0.7234f,0.2690f,0.0062f,0.0327f,0.0644f, -// 0.8428f, 0.9441f,0.9441f,0.9441f,0.3491f,0.5793f,0.5730f,0.1822f,0.6420f,0.9143f }); - NDArray expected = NDArrayFactory::create('c', {2,5,5,1}, { - 0.7788f, 0.8012f, 0.7244f, 0.2309f, 0.7271f, - 0.1804f, 0.5056f, 0.8925f, 0.5461f, 0.9234f, - 0.0856f, 0.7938f, 0.9441f, 0.9441f, 0.1596f, - 0.3087f, 0.1548f, 0.4695f, 0.9939f, 0.6113f, - 0.6765f, 0.18f , 0.675f , 0.2246f, 0.0509f, - - 0.4601f, 0.8284f, 0.2354f, 0.9752f, 0.8361f, - 0.2585f, 0.4189f, 0.7028f, 0.7679f, 0.5373f, - 0.7234f, 0.269f , 0.0062f, 0.0327f, 0.0644f, - 0.8428f, 0.9441f, 0.9441f, 0.9441f, 0.3491f, - 0.5793f, 0.573f , 0.1822f, 0.642f , 0.9143f}); - sd::ops::draw_bounding_boxes op; - auto results = op.evaluate({&images, &boxes, &colors}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); -// result.printBuffer("Boxes3 output"); -// expected.printBuffer("Boxes3 expect"); - -// result.syncToHost(); -// result.printBuffer("Bounded boxes 2"); -// expected.printBuffer("Bounded expec 2"); - ASSERT_TRUE(expected.isSameShapeStrict(*result)); - ASSERT_TRUE(expected.equalsTo(result)); + NDArray images = NDArrayFactory::create( + 'c', {2, 5, 5, 1}, + {0.7788f, 0.8012f, 0.7244f, 0.2309f, 0.7271f, 0.1804f, 0.5056f, 0.8925f, + 0.5461f, 0.9234f, 0.0856f, 0.7938f, 0.6591f, 0.5555f, 0.1596f, 0.3087f, + 0.1548f, 0.4695f, 0.9939f, 0.6113f, 0.6765f, 0.1800f, 0.6750f, 0.2246f, + 0.0509f, 0.4601f, 0.8284f, 0.2354f, 0.9752f, 0.8361f, 0.2585f, 0.4189f, + 0.7028f, 0.7679f, 0.5373f, 0.7234f, 0.2690f, 0.0062f, 0.0327f, 0.0644f, + 0.8428f, 0.7494f, 0.0755f, 0.6245f, 0.3491f, 0.5793f, 0.5730f, 0.1822f, + 0.6420f, 0.9143f}); + + NDArray boxes = NDArrayFactory::create( + 'c', {2, 2, 4}, + {0.7717f, 0.9281f, 0.9846f, 0.4838f, 0.6433f, 0.6041f, 0.6501f, 0.7612f, + 0.7605f, 0.3948f, 0.9493f, 0.8600f, 0.7876f, 0.8945f, 0.4638f, 0.7157f}); + NDArray colors = + NDArrayFactory::create('c', {1, 2}, {0.9441f, 0.5957f}); + + // NDArray ('c', {6}, {0.9f, .75f, .6f, .95f, .5f, .3f}); + // NDArray expected = NDArrayFactory::create('c', {2,5,5,1}, { + // 0.7788f, 0.8012f, 0.7244f, 0.2309f, 0.7271f, + // 0.1804f, 0.5056f, 0.8925f, 0.5461f, 0.9234f, 0.0856f, 0.7938f, + // 0.9441f, 0.9441f, 0.1596f, 0.3087f, 0.1548f, 0.4695f, 0.9939f, + // 0.6113f, 0.6765f, 0.1800f, 0.6750f, 0.2246f, 0.0509f, 0.4601f, + // 0.8284f, 0.2354f, 0.9752f, 0.8361f, 0.2585f, + // 0.4189f,0.7028f,0.7679f,0.5373f,0.7234f,0.2690f,0.0062f,0.0327f,0.0644f, + // 0.8428f, + // 0.9441f,0.9441f,0.9441f,0.3491f,0.5793f,0.5730f,0.1822f,0.6420f,0.9143f + // }); + NDArray expected = NDArrayFactory::create( + 'c', {2, 5, 5, 1}, + {0.7788f, 0.8012f, 0.7244f, 0.2309f, 0.7271f, 0.1804f, 0.5056f, + 0.8925f, 0.5461f, 0.9234f, 0.0856f, 0.7938f, 0.9441f, 0.9441f, + 0.1596f, 0.3087f, 0.1548f, 0.4695f, 0.9939f, 0.6113f, 0.6765f, + 0.18f, 0.675f, 0.2246f, 0.0509f, + + 0.4601f, 0.8284f, 0.2354f, 0.9752f, 0.8361f, 0.2585f, 0.4189f, + 0.7028f, 0.7679f, 0.5373f, 0.7234f, 0.269f, 0.0062f, 0.0327f, + 0.0644f, 0.8428f, 0.9441f, 0.9441f, 0.9441f, 0.3491f, 0.5793f, + 0.573f, 0.1822f, 0.642f, 0.9143f}); + sd::ops::draw_bounding_boxes op; + auto results = op.evaluate({&images, &boxes, &colors}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + // result.printBuffer("Boxes3 output"); + // expected.printBuffer("Boxes3 expect"); + + // result.syncToHost(); + // result.printBuffer("Bounded boxes 2"); + // expected.printBuffer("Bounded expec 2"); + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_1) { + NDArray x('c', {2, 3}, {-63.80f, -63.75f, -63.70f, -63.5f, 0.0f, 0.1f}, + sd::DataType::FLOAT32); + NDArray exp('c', {2, 3}, {-63.75f, -63.75f, -63.75f, -63.5f, 0.f, 0.f}, + sd::DataType::FLOAT32); + NDArray min('c', {}, std::vector{-63.65f}, sd::DataType::FLOAT32); + NDArray max('c', {}, std::vector{0.1f}, sd::DataType::FLOAT32); - NDArray x('c', {2,3}, {-63.80f, -63.75f, -63.70f, -63.5f, 0.0f, 0.1f}, sd::DataType::FLOAT32); - NDArray exp('c', {2,3}, {-63.75f, -63.75f, -63.75f, -63.5f, 0.f, 0.f}, sd::DataType::FLOAT32); - NDArray min('c', {}, std::vector{-63.65f}, sd::DataType::FLOAT32); - NDArray max('c', {}, std::vector{0.1f}, sd::DataType::FLOAT32); + sd::ops::fake_quant_with_min_max_vars op; + auto results = op.evaluate({&x, &min, &max}, {}, {}); - sd::ops::fake_quant_with_min_max_vars op; - auto results = op.evaluate({&x, &min, &max}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); -// result.printBuffer("Quantized"); -// exp.printBuffer("Expected"); - ASSERT_TRUE(exp.isSameShapeStrict(*result)); - ASSERT_TRUE(exp.equalsTo(result)); + auto result = results.at(0); + // result.printBuffer("Quantized"); + // exp.printBuffer("Expected"); + ASSERT_TRUE(exp.isSameShapeStrict(result)); + ASSERT_TRUE(exp.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_2) { + NDArray x = NDArrayFactory::create( + 'c', {2, 3}, {-63.80, -63.75, -63.4, -63.5, 0.0, 0.1}); + NDArray exp = NDArrayFactory::create( + 'c', {2, 3}, {-63.75, -63.75, -63.5, -63.5, 0., 0.}); + NDArray min = NDArrayFactory::create(-63.65); + NDArray max = NDArrayFactory::create(0.1); - NDArray x = NDArrayFactory::create('c', {2,3}, {-63.80, -63.75, -63.4, -63.5, 0.0, 0.1}); - NDArray exp = NDArrayFactory::create('c', {2,3}, {-63.75, -63.75, -63.5 , -63.5 , 0. , 0. }); - NDArray min = NDArrayFactory::create(-63.65); - NDArray max = NDArrayFactory::create(0.1); + sd::ops::fake_quant_with_min_max_vars op; + auto results = op.evaluate({&x, &min, &max}, {}, {}); - sd::ops::fake_quant_with_min_max_vars op; - auto results = op.evaluate({&x, &min, &max}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); - // result.printIndexedBuffer("Quantized2"); - ASSERT_TRUE(exp.isSameShapeStrict(*result)); - ASSERT_TRUE(exp.equalsTo(result)); + auto result = results.at(0); + // result.printIndexedBuffer("Quantized2"); + ASSERT_TRUE(exp.isSameShapeStrict(result)); + ASSERT_TRUE(exp.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_3) { + NDArray x = NDArrayFactory::create( + 'c', {1, 2, 3, 1}, {-63.80, -63.75, -63.4, -63.5, 0.0, 0.1}); + NDArray exp = NDArrayFactory::create( + 'c', {1, 2, 3, 1}, {-63.75, -63.75, -63.5, -63.5, 0., 0.}); + NDArray min = NDArrayFactory::create('c', {1}, {-63.65}); + NDArray max = NDArrayFactory::create('c', {1}, {0.1}); - NDArray x = NDArrayFactory::create('c', {1,2,3,1}, {-63.80, -63.75, -63.4, -63.5, 0.0, 0.1}); - NDArray exp = NDArrayFactory::create('c', {1,2,3,1}, {-63.75, -63.75, -63.5 , -63.5 , 0. , 0. }); - NDArray min = NDArrayFactory::create('c', {1},{-63.65}); - NDArray max = NDArrayFactory::create('c', {1}, {0.1}); + sd::ops::fake_quant_with_min_max_vars_per_channel op; + auto results = op.evaluate({&x, &min, &max}, {}, {}); - sd::ops::fake_quant_with_min_max_vars_per_channel op; - auto results = op.evaluate({&x, &min, &max}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); - // result.printIndexedBuffer("Quantized2"); - ASSERT_TRUE(exp.isSameShapeStrict(*result)); - ASSERT_TRUE(exp.equalsTo(result)); + auto result = results.at(0); + // result.printIndexedBuffer("Quantized2"); + ASSERT_TRUE(exp.isSameShapeStrict(result)); + ASSERT_TRUE(exp.equalsTo(result)); } TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_03) { - NDArray x = NDArrayFactory::create('c', {3,5}, {0.7788f,0.8012f, 0.7244f, 0.2309f,0.7271f, - 0.1804f, 0.5056f, 0.8925f, 0.5461f, 0.9234f, - 0.0856f, 0.7938f, 0.6591f, 0.5555f, 0.1596f}); - NDArray exp = NDArrayFactory::create('c', {3,5}, { - 0.777002f, 0.596913f, 0.72314f, 0.231040f, 0.509824f, - 0.179308f, 0.505282f, 0.86846f, 0.349958f, 0.509824f, - 0.087355f, 0.596913f, 0.65740f, 0.349958f, 0.159745f}); - NDArray min = NDArrayFactory::create({-0.2283f, -0.0719f, -0.0154f, -0.5162f, -0.3567f}); - NDArray max = NDArrayFactory::create({0.9441f, 0.5957f, 0.8669f, 0.3502f, 0.5100f}); - - sd::ops::fake_quant_with_min_max_vars_per_channel op; - auto results = op.evaluate({&x, &min, &max}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); -// result.printIndexedBuffer("Quantized03"); - ASSERT_TRUE(exp.isSameShapeStrict(*result)); - ASSERT_TRUE(exp.equalsTo(result)); + NDArray x = NDArrayFactory::create( + 'c', {3, 5}, + {0.7788f, 0.8012f, 0.7244f, 0.2309f, 0.7271f, 0.1804f, 0.5056f, 0.8925f, + 0.5461f, 0.9234f, 0.0856f, 0.7938f, 0.6591f, 0.5555f, 0.1596f}); + NDArray exp = NDArrayFactory::create( + 'c', {3, 5}, + {0.777002f, 0.596913f, 0.72314f, 0.231040f, 0.509824f, 0.179308f, + 0.505282f, 0.86846f, 0.349958f, 0.509824f, 0.087355f, 0.596913f, + 0.65740f, 0.349958f, 0.159745f}); + NDArray min = NDArrayFactory::create( + {-0.2283f, -0.0719f, -0.0154f, -0.5162f, -0.3567f}); + NDArray max = NDArrayFactory::create( + {0.9441f, 0.5957f, 0.8669f, 0.3502f, 0.5100f}); + + sd::ops::fake_quant_with_min_max_vars_per_channel op; + auto results = op.evaluate({&x, &min, &max}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + // result.printIndexedBuffer("Quantized03"); + ASSERT_TRUE(exp.isSameShapeStrict(result)); + ASSERT_TRUE(exp.equalsTo(result)); } TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_03_1) { - NDArray x = NDArrayFactory::create('c', {3,5}, {0.7788f,0.8012f, 0.7244f, 0.2309f,0.7271f, - 0.1804f, 0.5056f, 0.8925f, 0.5461f, 0.9234f, - 0.0856f, 0.7938f, 0.6591f, 0.5555f, 0.1596f}); - NDArray exp = NDArrayFactory::create('c', {3,5}, { - 0.780061f, 0.596635f, 0.725987f, 0.231950f, 0.508419f, - 0.180014f, 0.504643f, 0.868406f, 0.351335f, 0.508419f, - 0.087699f, 0.596635f, 0.659988f, 0.351335f, 0.160374f}); - NDArray min = NDArrayFactory::create({-0.2283f, -0.0719f, -0.0154f, -0.5162f, -0.3567f}); - NDArray max = NDArrayFactory::create({0.9441f, 0.5957f, 0.8669f, 0.3502f, 0.5100f}); - - sd::ops::fake_quant_with_min_max_vars_per_channel op; - auto results = op.evaluate({&x, &min, &max}, {}, {8}, {true}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); -// result.printIndexedBuffer("Quantized03_1"); - ASSERT_TRUE(exp.isSameShapeStrict(*result)); - ASSERT_TRUE(exp.equalsTo(result)); + NDArray x = NDArrayFactory::create( + 'c', {3, 5}, + {0.7788f, 0.8012f, 0.7244f, 0.2309f, 0.7271f, 0.1804f, 0.5056f, 0.8925f, + 0.5461f, 0.9234f, 0.0856f, 0.7938f, 0.6591f, 0.5555f, 0.1596f}); + NDArray exp = NDArrayFactory::create( + 'c', {3, 5}, + {0.780061f, 0.596635f, 0.725987f, 0.231950f, 0.508419f, 0.180014f, + 0.504643f, 0.868406f, 0.351335f, 0.508419f, 0.087699f, 0.596635f, + 0.659988f, 0.351335f, 0.160374f}); + NDArray min = NDArrayFactory::create( + {-0.2283f, -0.0719f, -0.0154f, -0.5162f, -0.3567f}); + NDArray max = NDArrayFactory::create( + {0.9441f, 0.5957f, 0.8669f, 0.3502f, 0.5100f}); + + sd::ops::fake_quant_with_min_max_vars_per_channel op; + auto results = op.evaluate({&x, &min, &max}, {}, {8}, {true}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + // result.printIndexedBuffer("Quantized03_1"); + ASSERT_TRUE(exp.isSameShapeStrict(result)); + ASSERT_TRUE(exp.equalsTo(result)); } TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_03_2) { - NDArray x = NDArrayFactory::create('c', {3,5}, {0.7788f,0.8012f, 0.7244f, 0.2309f,0.7271f, - 0.1804f, 0.5056f, 0.8925f, 0.5461f, 0.9234f, - 0.0856f, 0.7938f, 0.6591f, 0.5555f, 0.1596f}); - NDArray exp = NDArrayFactory::create('c', {3,5}, { - 0.775297f, 0.592226f, 0.725763f, 0.237561f, 0.503245f, - 0.189097f, 0.506084f, 0.868069f, 0.349355f, 0.503245f, - 0.094548f, 0.592226f, 0.654610f, 0.349355f, 0.153769f}); - NDArray min = NDArrayFactory::create({-0.2283f, -0.0719f, -0.0154f, -0.5162f, -0.3567f}); - NDArray max = NDArrayFactory::create({0.9441f, 0.5957f, 0.8669f, 0.3502f, 0.5100f}); - - sd::ops::fake_quant_with_min_max_vars_per_channel op; - auto results = op.evaluate({&x, &min, &max}, {}, {6}, {true}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); - result->printIndexedBuffer("Quantized03_2"); - ASSERT_TRUE(exp.isSameShapeStrict(*result)); - ASSERT_TRUE(exp.equalsTo(result)); + NDArray x = NDArrayFactory::create( + 'c', {3, 5}, + {0.7788f, 0.8012f, 0.7244f, 0.2309f, 0.7271f, 0.1804f, 0.5056f, 0.8925f, + 0.5461f, 0.9234f, 0.0856f, 0.7938f, 0.6591f, 0.5555f, 0.1596f}); + NDArray exp = NDArrayFactory::create( + 'c', {3, 5}, + {0.775297f, 0.592226f, 0.725763f, 0.237561f, 0.503245f, 0.189097f, + 0.506084f, 0.868069f, 0.349355f, 0.503245f, 0.094548f, 0.592226f, + 0.654610f, 0.349355f, 0.153769f}); + NDArray min = NDArrayFactory::create( + {-0.2283f, -0.0719f, -0.0154f, -0.5162f, -0.3567f}); + NDArray max = NDArrayFactory::create( + {0.9441f, 0.5957f, 0.8669f, 0.3502f, 0.5100f}); + + sd::ops::fake_quant_with_min_max_vars_per_channel op; + auto results = op.evaluate({&x, &min, &max}, {}, {6}, {true}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + // result.printIndexedBuffer("Quantized03_2"); + ASSERT_TRUE(exp.isSameShapeStrict(result)); + ASSERT_TRUE(exp.equalsTo(result)); } TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_03_3) { - NDArray x = NDArrayFactory::create('c', {3,5}, {0.7788f,0.8012f, 0.7244f, 0.2309f,0.7271f, - 0.1804f, 0.5056f, 0.8925f, 0.5461f, 0.9234f, - 0.0856f, 0.7938f, 0.6591f, 0.5555f, 0.1596f}); - NDArray exp = NDArrayFactory::create('c', {3,5}, { - 0.781600f, 0.593422f, 0.728248f, 0.233790f, 0.509014f, 0.186095f, 0.508648f, 0.868295f, 0.343809f, - 0.509014f, 0.093048f, 0.593422f, 0.658224f, 0.343809f, 0.165086f}); - NDArray min = NDArrayFactory::create({-0.2283f, -0.0719f, -0.0154f, -0.5162f, -0.3567f}); - NDArray max = NDArrayFactory::create({0.9441f, 0.5957f, 0.8669f, 0.3502f, 0.5100f}); - - sd::ops::fake_quant_with_min_max_vars_per_channel op; - auto results = op.evaluate({&x, &min, &max}, {}, {6}, {false}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); - result->printIndexedBuffer("Quantized03_3"); - ASSERT_TRUE(exp.isSameShapeStrict(*result)); - ASSERT_TRUE(exp.equalsTo(result)); + NDArray x = NDArrayFactory::create( + 'c', {3, 5}, + {0.7788f, 0.8012f, 0.7244f, 0.2309f, 0.7271f, 0.1804f, 0.5056f, 0.8925f, + 0.5461f, 0.9234f, 0.0856f, 0.7938f, 0.6591f, 0.5555f, 0.1596f}); + NDArray exp = NDArrayFactory::create( + 'c', {3, 5}, + {0.781600f, 0.593422f, 0.728248f, 0.233790f, 0.509014f, 0.186095f, + 0.508648f, 0.868295f, 0.343809f, 0.509014f, 0.093048f, 0.593422f, + 0.658224f, 0.343809f, 0.165086f}); + NDArray min = NDArrayFactory::create( + {-0.2283f, -0.0719f, -0.0154f, -0.5162f, -0.3567f}); + NDArray max = NDArrayFactory::create( + {0.9441f, 0.5957f, 0.8669f, 0.3502f, 0.5100f}); + + sd::ops::fake_quant_with_min_max_vars_per_channel op; + auto results = op.evaluate({&x, &min, &max}, {}, {6}, {false}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + // result->printIndexedBuffer("Quantized03_3"); + ASSERT_TRUE(exp.isSameShapeStrict(result)); + ASSERT_TRUE(exp.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_4) { #ifdef FFAST_MATH - if (1 > 0) - return; + if (1 > 0) return; #endif - NDArray x = NDArrayFactory::create('c', {2,4,5,3}); - NDArray exp = NDArrayFactory::create('c', {2,4,5,3},{ - 1.0588236f, 1.9607843f, 3.019608f, 4.0588236f, 5.098039f, 6.039216f, 7.0588236f, 8.039216f, 9.058824f, - 10.058824f, 10.980392f, 12.078432f, 13.058824f, 13.921569f, 15.09804f, 16.058825f, 17.058825f, 18.117647f, - 19.058825f, 20.f, 21.137257f, 22.058825f, 22.941177f, 23.882355f, 25.058825f, 26.078432f, 26.901962f, - 28.058825f, 29.019608f, 29.92157f, 31.058825f, 31.960785f, 32.941177f, 34.058823f, 35.09804f, 35.960785f, - 37.058823f, 38.039215f, 38.980392f, 40.058823f, 40.980392f, 42.000004f, 43.058826f, 43.92157f, 45.01961f, - 45.f, 47.058823f, 48.03922f, 45.f, 50.f, 51.058826f, 45.f, 50.f, 54.078434f, - 45.f, 50.f, 57.09804f, 45.f, 50.f, 60.11765f, 45.f, 50.f, 62.862747f, - 45.f, 50.f, 65.882355f, 45.f, 50.f, 68.90196f, 45.f, 50.f, 70.f, - 45.f, 50.f, 70.f, 45.f, 50.f, 70.f, 45.f, 50.f, 70.f, - 45.f, 50.f, 70.f, 45.f, 50.f, 70.f, 45.f, 50.f, 70.f, - 45.f, 50.f, 70.f, 45.f, 50.f, 70.f, 45.f, 50.f, 70.f, - 45.f, 50.f, 70.f, 45.f, 50.f, 70.f, 45.f, 50.f, 70.f, - 45.f, 50.f, 70.f, 45.f, 50.f, 70.f, 45.f, 50.f, 70.f, - 45.f, 50.f, 70.f}); - NDArray min = NDArrayFactory::create({20.f, 20.f, 20.f}); - NDArray max = NDArrayFactory::create({65.f, 70.f, 90.f}); - x.linspace(1.); - sd::ops::fake_quant_with_min_max_vars_per_channel op; - auto results = op.evaluate({&x, &min, &max}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); -// result.printBuffer("Quantized per channels 4"); -// exp.printBuffer("Quantized per channest E"); -// auto diff = *result - exp; -// diff.printIndexedBuffer("Difference"); - ASSERT_TRUE(exp.isSameShapeStrict(*result)); - ASSERT_TRUE(exp.equalsTo(result)); + NDArray x = NDArrayFactory::create('c', {2, 4, 5, 3}); + NDArray exp = NDArrayFactory::create( + 'c', {2, 4, 5, 3}, + {1.0588236f, 1.9607843f, 3.019608f, 4.0588236f, 5.098039f, 6.039216f, + 7.0588236f, 8.039216f, 9.058824f, 10.058824f, 10.980392f, 12.078432f, + 13.058824f, 13.921569f, 15.09804f, 16.058825f, 17.058825f, 18.117647f, + 19.058825f, 20.f, 21.137257f, 22.058825f, 22.941177f, 23.882355f, + 25.058825f, 26.078432f, 26.901962f, 28.058825f, 29.019608f, 29.92157f, + 31.058825f, 31.960785f, 32.941177f, 34.058823f, 35.09804f, 35.960785f, + 37.058823f, 38.039215f, 38.980392f, 40.058823f, 40.980392f, 42.000004f, + 43.058826f, 43.92157f, 45.01961f, 45.f, 47.058823f, 48.03922f, + 45.f, 50.f, 51.058826f, 45.f, 50.f, 54.078434f, + 45.f, 50.f, 57.09804f, 45.f, 50.f, 60.11765f, + 45.f, 50.f, 62.862747f, 45.f, 50.f, 65.882355f, + 45.f, 50.f, 68.90196f, 45.f, 50.f, 70.f, + 45.f, 50.f, 70.f, 45.f, 50.f, 70.f, + 45.f, 50.f, 70.f, 45.f, 50.f, 70.f, + 45.f, 50.f, 70.f, 45.f, 50.f, 70.f, + 45.f, 50.f, 70.f, 45.f, 50.f, 70.f, + 45.f, 50.f, 70.f, 45.f, 50.f, 70.f, + 45.f, 50.f, 70.f, 45.f, 50.f, 70.f, + 45.f, 50.f, 70.f, 45.f, 50.f, 70.f, + 45.f, 50.f, 70.f, 45.f, 50.f, 70.f}); + NDArray min = NDArrayFactory::create({20.f, 20.f, 20.f}); + NDArray max = NDArrayFactory::create({65.f, 70.f, 90.f}); + x.linspace(1.); + sd::ops::fake_quant_with_min_max_vars_per_channel op; + auto results = op.evaluate({&x, &min, &max}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + // result.printBuffer("Quantized per channels 4"); + // exp.printBuffer("Quantized per channest E"); + // auto diff = result - exp; + // diff.printIndexedBuffer("Difference"); + ASSERT_TRUE(exp.isSameShapeStrict(result)); + ASSERT_TRUE(exp.equalsTo(result)); } TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_5) { - NDArray x = NDArrayFactory::create('c', {2, 3, 5, 4}); - NDArray exp = NDArrayFactory::create('c', {2, 3, 5, 4},{ - -19.92157f, -18.980392f, -18.039217f, -16.941177f, - -19.92157f, -18.980392f, -18.039217f, -16.941177f, - -19.92157f, -18.980392f, -18.039217f, -16.941177f, - -19.92157f, -18.980392f, -18.039217f, -16.941177f, - -19.92157f, -18.980392f, -18.039217f, -16.941177f, - -19.92157f, -18.980392f, -18.039217f, -16.941177f, - -19.92157f, -18.980392f, -18.039217f, -16.941177f, - -19.92157f, -18.980392f, -18.039217f, -16.941177f, - -19.92157f, -18.980392f, -18.039217f, -16.941177f, - -19.92157f, -18.980392f, -18.039217f, -16.941177f, - -19.92157f, -18.980392f, -18.039217f, -16.941177f, - -16.f, -15.058824f, -13.960785f, -13.0196085f, - -11.92157f, -10.980392f, -10.039217f, -8.941177f, - -8.000001f, -7.0588236f, -5.960785f, -5.0196085f, - -3.9215698f, -2.9803925f, -2.039217f, -0.94117737f, - 0.f, 0.94117737f, 2.039215f, 2.9803925f, - 4.07843f, 5.0196075f, 5.960783f, 7.0588226f, - 8.f, 8.941177f, 10.039215f, 10.980392f, - 12.07843f, 13.019608f, 13.960783f, 15.058823f, - 16.f, 16.941177f, 18.039217f, 18.980392f, - 20.07843f, 21.019608f, 21.960783f, 23.058823f, - 20.07843f, 21.019608f, 21.960783f, 23.058823f, - 20.07843f, 21.019608f, 21.960783f, 23.058823f, - 20.07843f, 21.019608f, 21.960783f, 23.058823f, - 20.07843f, 21.019608f, 21.960783f, 23.058823f, - 20.07843f, 21.019608f, 21.960783f, 23.058823f, - 20.07843f, 21.019608f, 21.960783f, 23.058823f, - 20.07843f, 21.019608f, 21.960783f, 23.058823f, - 20.07843f, 21.019608f, 21.960783f, 23.058823f, - 20.07843f, 21.019608f, 21.960783f, 23.058823f - }); - NDArray min = NDArrayFactory::create({-20.f, -19.f, -18.f, -17.f}); - NDArray max = NDArrayFactory::create({20.f, 21.f, 22.f, 23.f}); - x.linspace(-60.); - sd::ops::fake_quant_with_min_max_vars_per_channel op; - auto results = op.evaluate({&x, &min, &max}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); -// result.printBuffer("Quantized per channels 5"); -// exp.printBuffer("Quantized per channest E"); -// auto diff = *result - exp; -// diff.printIndexedBuffer("Difference"); - - ASSERT_TRUE(exp.isSameShapeStrict(*result)); - ASSERT_TRUE(exp.equalsTo(result)); + NDArray x = NDArrayFactory::create('c', {2, 3, 5, 4}); + NDArray exp = NDArrayFactory::create( + 'c', {2, 3, 5, 4}, + {-19.92157f, -18.980392f, -18.039217f, -16.941177f, -19.92157f, + -18.980392f, -18.039217f, -16.941177f, -19.92157f, -18.980392f, + -18.039217f, -16.941177f, -19.92157f, -18.980392f, -18.039217f, + -16.941177f, -19.92157f, -18.980392f, -18.039217f, -16.941177f, + -19.92157f, -18.980392f, -18.039217f, -16.941177f, -19.92157f, + -18.980392f, -18.039217f, -16.941177f, -19.92157f, -18.980392f, + -18.039217f, -16.941177f, -19.92157f, -18.980392f, -18.039217f, + -16.941177f, -19.92157f, -18.980392f, -18.039217f, -16.941177f, + -19.92157f, -18.980392f, -18.039217f, -16.941177f, -16.f, + -15.058824f, -13.960785f, -13.0196085f, -11.92157f, -10.980392f, + -10.039217f, -8.941177f, -8.000001f, -7.0588236f, -5.960785f, + -5.0196085f, -3.9215698f, -2.9803925f, -2.039217f, -0.94117737f, + 0.f, 0.94117737f, 2.039215f, 2.9803925f, 4.07843f, + 5.0196075f, 5.960783f, 7.0588226f, 8.f, 8.941177f, + 10.039215f, 10.980392f, 12.07843f, 13.019608f, 13.960783f, + 15.058823f, 16.f, 16.941177f, 18.039217f, 18.980392f, + 20.07843f, 21.019608f, 21.960783f, 23.058823f, 20.07843f, + 21.019608f, 21.960783f, 23.058823f, 20.07843f, 21.019608f, + 21.960783f, 23.058823f, 20.07843f, 21.019608f, 21.960783f, + 23.058823f, 20.07843f, 21.019608f, 21.960783f, 23.058823f, + 20.07843f, 21.019608f, 21.960783f, 23.058823f, 20.07843f, + 21.019608f, 21.960783f, 23.058823f, 20.07843f, 21.019608f, + 21.960783f, 23.058823f, 20.07843f, 21.019608f, 21.960783f, + 23.058823f, 20.07843f, 21.019608f, 21.960783f, 23.058823f}); + NDArray min = NDArrayFactory::create({-20.f, -19.f, -18.f, -17.f}); + NDArray max = NDArrayFactory::create({20.f, 21.f, 22.f, 23.f}); + x.linspace(-60.); + sd::ops::fake_quant_with_min_max_vars_per_channel op; + auto results = op.evaluate({&x, &min, &max}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + // result.printBuffer("Quantized per channels 5"); + // exp.printBuffer("Quantized per channest E"); + // auto diff = result - exp; + // diff.printIndexedBuffer("Difference"); + + ASSERT_TRUE(exp.isSameShapeStrict(result)); + ASSERT_TRUE(exp.equalsTo(result)); } TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_6) { - NDArray x = NDArrayFactory::create('c', {3, 5}, {0.7788f,0.8012f, 0.7244f, 0.2309f,0.7271f, - 0.1804f, 0.5056f, 0.8925f, 0.5461f, 0.9234f, - 0.0856f, 0.7938f, 0.6591f, 0.5555f, 0.1596f}); -// NDArray exp = NDArrayFactory::create('c', {3, 5},{ -// 0.7801f, 0.5966f, 0.7260f, 0.2320f, 0.5084f, -// 0.1800f, 0.5046f, 0.8684f, 0.3513f, 0.5084f, -// 0.0877f, 0.5966f, 0.6600f, 0.3513f, 0.1604f -// }); - - NDArray exp = NDArrayFactory::create('c', {3,5}, { - 0.77700233f, 0.596913f, 0.72314f, 0.23104f, 0.50982356f, - 0.17930824f, 0.50528157f, 0.86846f, 0.34995764f, 0.50982356f, - 0.08735529f, 0.596913f, 0.6574f, 0.34995764f, 0.15974471f}); - NDArray min = NDArrayFactory::create('c', {5}, {-0.2283f, -0.0719f, -0.0154f, -0.5162f, -0.3567f}); - NDArray max = NDArrayFactory::create('c', {5}, {0.9441f, 0.5957f, 0.8669f, 0.3502f, 0.5100f}); - // x.linspace(-60.); - sd::ops::fake_quant_with_min_max_vars_per_channel op; - auto results = op.evaluate({&x, &min, &max}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); -// result.printBuffer("Quantized per channels 5"); -// exp.printBuffer("Quantized per channest E"); -// auto diff = *result - exp; -// diff.printIndexedBuffer("Difference"); - - ASSERT_TRUE(exp.isSameShapeStrict(*result)); - ASSERT_TRUE(exp.equalsTo(result)); + NDArray x = NDArrayFactory::create( + 'c', {3, 5}, + {0.7788f, 0.8012f, 0.7244f, 0.2309f, 0.7271f, 0.1804f, 0.5056f, 0.8925f, + 0.5461f, 0.9234f, 0.0856f, 0.7938f, 0.6591f, 0.5555f, 0.1596f}); + // NDArray exp = NDArrayFactory::create('c', {3, 5},{ + // 0.7801f, 0.5966f, 0.7260f, 0.2320f, 0.5084f, + // 0.1800f, 0.5046f, 0.8684f, 0.3513f, 0.5084f, + // 0.0877f, 0.5966f, 0.6600f, 0.3513f, 0.1604f + // }); + + NDArray exp = NDArrayFactory::create( + 'c', {3, 5}, + {0.77700233f, 0.596913f, 0.72314f, 0.23104f, 0.50982356f, 0.17930824f, + 0.50528157f, 0.86846f, 0.34995764f, 0.50982356f, 0.08735529f, 0.596913f, + 0.6574f, 0.34995764f, 0.15974471f}); + NDArray min = NDArrayFactory::create( + 'c', {5}, {-0.2283f, -0.0719f, -0.0154f, -0.5162f, -0.3567f}); + NDArray max = NDArrayFactory::create( + 'c', {5}, {0.9441f, 0.5957f, 0.8669f, 0.3502f, 0.5100f}); + // x.linspace(-60.); + sd::ops::fake_quant_with_min_max_vars_per_channel op; + auto results = op.evaluate({&x, &min, &max}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + // result.printBuffer("Quantized per channels 5"); + // exp.printBuffer("Quantized per channest E"); + // auto diff = result - exp; + // diff.printIndexedBuffer("Difference"); + + ASSERT_TRUE(exp.isSameShapeStrict(result)); + ASSERT_TRUE(exp.equalsTo(result)); } ////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_7) { - - NDArray x = NDArrayFactory::create('c', {100}); - NDArray exp = NDArrayFactory::create('c', {100}, { - 0.f, 0.01176471f, 0.01960784f, 0.03137255f, 0.03921569f, - 0.0509804f, 0.05882353f, 0.07058824f, 0.07843138f, 0.09019608f, - 0.09803922f, 0.10980393f, 0.12156864f, 0.12941177f, 0.14117648f, - 0.14901961f, 0.16078432f, 0.16862746f, 0.18039216f, 0.18823531f, - 0.20000002f, 0.21176472f, 0.21960786f, 0.23137257f, 0.2392157f, - 0.2509804f, 0.25882354f, 0.27058825f, 0.2784314f, 0.2901961f, - 0.3019608f, 0.30980393f, 0.32156864f, 0.32941177f, 0.34117648f, - 0.34901962f, 0.36078432f, 0.36862746f, 0.3803922f, 0.38823533f, - 0.40000004f, 0.41176474f, 0.41960788f, 0.43137258f, 0.43921572f, - 0.45098042f, 0.45882356f, 0.47058827f, 0.4784314f, 0.4901961f, - 0.49803925f, 0.50980395f, 0.52156866f, 0.5294118f, 0.5411765f, - 0.54901963f, 0.56078434f, 0.5686275f, 0.5803922f, 0.5882353f, - 0.6f, 0.6117647f, 0.61960787f, 0.6313726f, 0.6392157f, - 0.6509804f, 0.65882355f, 0.67058825f, 0.6784314f, 0.6901961f, - 0.69803923f, 0.70980394f, 0.72156864f, 0.7294118f, 0.7411765f, - 0.7490196f, 0.7607844f, 0.7686275f, 0.7803922f, 0.78823537f, - 0.8000001f, 0.8117648f, 0.8196079f, 0.8313726f, 0.83921576f, - 0.85098046f, 0.8588236f, 0.8705883f, 0.87843144f, 0.89019614f, - 0.8980393f, 0.909804f, 0.9215687f, 0.9294118f, 0.94117653f, - 0.9490197f, 0.9607844f, 0.9686275f, 0.9803922f, 0.98823535f - }); - NDArray min = NDArrayFactory::create('c', {1},{0.0f}); - NDArray max = NDArrayFactory::create('c', {1}, {1.f}); - x.linspace(0., 0.01); - sd::ops::fake_quant_with_min_max_vars op; - auto results = op.evaluate({&x, &min, &max}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); -// result.printBuffer("Quantized7"); -// exp.printBuffer("Expected 7"); - ASSERT_TRUE(exp.isSameShapeStrict(*result)); - ASSERT_TRUE(exp.equalsTo(result)); + NDArray x = NDArrayFactory::create('c', {100}); + NDArray exp = NDArrayFactory::create( + 'c', {100}, + {0.f, 0.01176471f, 0.01960784f, 0.03137255f, 0.03921569f, + 0.0509804f, 0.05882353f, 0.07058824f, 0.07843138f, 0.09019608f, + 0.09803922f, 0.10980393f, 0.12156864f, 0.12941177f, 0.14117648f, + 0.14901961f, 0.16078432f, 0.16862746f, 0.18039216f, 0.18823531f, + 0.20000002f, 0.21176472f, 0.21960786f, 0.23137257f, 0.2392157f, + 0.2509804f, 0.25882354f, 0.27058825f, 0.2784314f, 0.2901961f, + 0.3019608f, 0.30980393f, 0.32156864f, 0.32941177f, 0.34117648f, + 0.34901962f, 0.36078432f, 0.36862746f, 0.3803922f, 0.38823533f, + 0.40000004f, 0.41176474f, 0.41960788f, 0.43137258f, 0.43921572f, + 0.45098042f, 0.45882356f, 0.47058827f, 0.4784314f, 0.4901961f, + 0.49803925f, 0.50980395f, 0.52156866f, 0.5294118f, 0.5411765f, + 0.54901963f, 0.56078434f, 0.5686275f, 0.5803922f, 0.5882353f, + 0.6f, 0.6117647f, 0.61960787f, 0.6313726f, 0.6392157f, + 0.6509804f, 0.65882355f, 0.67058825f, 0.6784314f, 0.6901961f, + 0.69803923f, 0.70980394f, 0.72156864f, 0.7294118f, 0.7411765f, + 0.7490196f, 0.7607844f, 0.7686275f, 0.7803922f, 0.78823537f, + 0.8000001f, 0.8117648f, 0.8196079f, 0.8313726f, 0.83921576f, + 0.85098046f, 0.8588236f, 0.8705883f, 0.87843144f, 0.89019614f, + 0.8980393f, 0.909804f, 0.9215687f, 0.9294118f, 0.94117653f, + 0.9490197f, 0.9607844f, 0.9686275f, 0.9803922f, 0.98823535f}); + NDArray min = NDArrayFactory::create('c', {1}, {0.0f}); + NDArray max = NDArrayFactory::create('c', {1}, {1.f}); + x.linspace(0., 0.01); + sd::ops::fake_quant_with_min_max_vars op; + auto results = op.evaluate({&x, &min, &max}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + // result.printBuffer("Quantized7"); + // exp.printBuffer("Expected 7"); + ASSERT_TRUE(exp.isSameShapeStrict(result)); + ASSERT_TRUE(exp.equalsTo(result)); } ////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_8) { - - NDArray x = NDArrayFactory::create('c', {10}); - NDArray exp = NDArrayFactory::create('c', {10}, { - 0.f, 0.09803922f, 0.20000002f, 0.3019608f, 0.40000004f, 0.49803925f, - 0.6f, 0.69803923f, 0.8000001f, 0.8980393f - }); - NDArray min = NDArrayFactory::create('c', {1},{0.0f}); - NDArray max = NDArrayFactory::create('c', {1}, {1.f}); - x.linspace(0., 0.1); - sd::ops::fake_quant_with_min_max_vars op; - auto results = op.evaluate({&x, &min, &max}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); -// x.printBuffer("SourInput8"); -// result.printBuffer("Quantized8"); -// exp.printBuffer("Expected 8"); - ASSERT_TRUE(exp.isSameShapeStrict(*result)); - ASSERT_TRUE(exp.equalsTo(result)); + NDArray x = NDArrayFactory::create('c', {10}); + NDArray exp = NDArrayFactory::create( + 'c', {10}, + {0.f, 0.09803922f, 0.20000002f, 0.3019608f, 0.40000004f, 0.49803925f, + 0.6f, 0.69803923f, 0.8000001f, 0.8980393f}); + NDArray min = NDArrayFactory::create('c', {1}, {0.0f}); + NDArray max = NDArrayFactory::create('c', {1}, {1.f}); + x.linspace(0., 0.1); + sd::ops::fake_quant_with_min_max_vars op; + auto results = op.evaluate({&x, &min, &max}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + // x.printBuffer("SourInput8"); + // result.printBuffer("Quantized8"); + // exp.printBuffer("Expected 8"); + ASSERT_TRUE(exp.isSameShapeStrict(result)); + ASSERT_TRUE(exp.equalsTo(result)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, bool_broadcast_test_1) { + NDArray arr1('c', {2, 2, 1}, {1, 2, 3, 4}, sd::DataType::INT32); + NDArray arr2('c', {2, 2}, {0, 1, 0, 4}, sd::DataType::INT32); - NDArray arr1('c', {2,2,1}, {1, 2, 3, 4}, sd::DataType::INT32); - NDArray arr2('c', { 2,2}, {0, 1, 0, 4}, sd::DataType::INT32); - - NDArray expd('c', {2,2,2}, {false, true, false, false, false, false, false, true}, sd::DataType::BOOL); + NDArray expd('c', {2, 2, 2}, + {false, true, false, false, false, false, false, true}, + sd::DataType::BOOL); - NDArray result('c', {2,2,2}, sd::DataType::BOOL); + NDArray result('c', {2, 2, 2}, sd::DataType::BOOL); - arr1.applyTrueBroadcast(sd::BroadcastBoolOpsTuple::custom(scalar::EqualTo, pairwise::EqualTo, broadcast::EqualTo), arr2, result, true); - // result.printIndexedBuffer(); - // expd.printIndexedBuffer(); + arr1.applyTrueBroadcast( + sd::BroadcastBoolOpsTuple::custom(scalar::EqualTo, pairwise::EqualTo, + broadcast::EqualTo), + arr2, result, true); + // result.printIndexedBuffer(); + // expd.printIndexedBuffer(); - ASSERT_TRUE(expd.isSameShape(result)); - ASSERT_TRUE(expd.equalsTo(result)); + ASSERT_TRUE(expd.isSameShape(result)); + ASSERT_TRUE(expd.equalsTo(result)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, printIndexedTest_1) { - - NDArray arr('c', {2,2,2,2}, {1, 2, 3, 4, 5, 6, 7, 8,9, 10, 11, 12, 13, 14, 15, 16}, sd::DataType::INT32); -// NDArray arr2('c', { 2,2}, {0, 1, 0, 4}, sd::DataType::INT32); - -// NDArray expd('c', {2,2,2}, {0,1,0,0, 0,0,0,1}, sd::DataType::BOOL); - -// NDArray result('c', {2,2,2}, sd::DataType::BOOL); - -// arr1.applyTrueBroadcast(sd::BroadcastBoolOpsTuple::custom(scalar::EqualTo, pairwise::EqualTo, broadcast::EqualTo), &arr2, &result, true, nullptr); - // result.printIndexedBuffer(); - // expd.printIndexedBuffer(); - -// ASSERT_TRUE(expd.isSameShape(result)); -// ASSERT_TRUE(expd.equalsTo(result)); - // arr.printIndexedBuffer("Test Print"); // output as [1, 2, 3, 4, 5, 6, 7, 8] -// -// we want output as -// [[[1 2] -// [3 4]] -// -// [[5 6] -// [7 8]]] -// - ResultSet lastDims = arr.allTensorsAlongDimension({3}); // last dim - size_t k = 0; // k from 0 to lastDims->size() - Nd4jLong rank = 4; // in this case + NDArray arr('c', {2, 2, 2, 2}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, + sd::DataType::INT32); + // NDArray arr2('c', { 2,2}, {0, 1, 0, 4}, sd::DataType::INT32); + + // NDArray expd('c', {2,2,2}, {0,1,0,0, 0,0,0,1}, sd::DataType::BOOL); + + // NDArray result('c', {2,2,2}, sd::DataType::BOOL); + + // arr1.applyTrueBroadcast(sd::BroadcastBoolOpsTuple::custom(scalar::EqualTo, + // pairwise::EqualTo, broadcast::EqualTo), &arr2, &result, true, nullptr); + // result.printIndexedBuffer(); + // expd.printIndexedBuffer(); + + // ASSERT_TRUE(expd.isSameShape(result)); + // ASSERT_TRUE(expd.equalsTo(result)); + // arr.printIndexedBuffer("Test Print"); // output as [1, 2, 3, 4, 5, 6, 7, 8] + // + // we want output as + // [[[1 2] + // [3 4]] + // + // [[5 6] + // [7 8]]] + // + ResultSet lastDims = arr.allTensorsAlongDimension({3}); // last dim + size_t k = 0; // k from 0 to lastDims->size() + Nd4jLong rank = 4; // in this case + printf("["); + for (Nd4jLong i = 0; i < rank - 1; i++) { + for (Nd4jLong l = 0; l < i; ++l) printf("\n"); printf("["); - for (Nd4jLong i = 0; i < rank - 1; i++) { - - for (Nd4jLong l = 0; l < i; ++l) - printf("\n"); - printf("["); - for (Nd4jLong j = 0; j < arr.sizeAt(i); j++) { - // if (!i) - // printf("["); - // else - // printf(" "); - lastDims.at(k++)->printBuffer(); - //if (k == arr.sizeAt(i)) - // printf("]\n"); - } - printf("]\n"); + for (Nd4jLong j = 0; j < arr.sizeAt(i); j++) { + // if (!i) + // printf("["); + // else + // printf(" "); + lastDims.at(k++).printBuffer(); + // if (k == arr.sizeAt(i)) + // printf("]\n"); } printf("]\n"); + } + printf("]\n"); } - diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests11.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests11.cpp index 97dcf7574e66..b2cb13f5db6f 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests11.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests11.cpp @@ -14,4015 +14,4528 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - // // Created by raver on 8/4/2018. // -#include "testlayers.h" -#include #include -#include #include #include #include +#include +#include -using namespace sd; +#include "testlayers.h" +using namespace sd; class DeclarableOpsTests11 : public testing::Test { -public: - - DeclarableOpsTests11() { - printf("\n"); - fflush(stdout); - } + public: + DeclarableOpsTests11() { + printf("\n"); + fflush(stdout); + } }; - TEST_F(DeclarableOpsTests11, test_listdiff_1) { - auto x = NDArrayFactory::create('c', {4}, {0, 1, 2, 3}); - auto y = NDArrayFactory::create('c',{2}, {3, 1}); - - sd::ops::listdiff op; - auto result = op.evaluate({&x, &y}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); - + auto x = NDArrayFactory::create('c', {4}, {0, 1, 2, 3}); + auto y = NDArrayFactory::create('c', {2}, {3, 1}); + sd::ops::listdiff op; + auto result = op.evaluate({&x, &y}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, log_loss_grad_test1) { - - NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights('c', {2,3,4}, sd::DataType::DOUBLE); - - NDArray dLdpExp('c', {2,3,4}, {-12.49997,-13.04346, -13.63635, -14.28571,-14.99999,-15.78947, -16.66666, -17.64705,-18.75 ,-20. , -21.42857, -23.07692, - -24.99999,-27.27272, -29.99999, -33.33332,-37.49999,-42.85713, -49.99998, -59.99998,-74.99995,-99.99992,-149.99986,-299.99911}); - NDArray dLdwExp('c', {2,3,4}, {3.21887, 4.96807, 6.10512, 6.80726, 7.15461, 7.19051, 6.93973, 6.41584, 5.62456, 4.56548, 3.2326 , 1.61444, - -0.30659, -2.55529, -5.16569, -8.18417,-11.67468,-15.72734,-20.47379,-26.11644,-32.9902 ,-41.71318,-53.64824,-73.05434}); - NDArray dLdlExp('c', {2,3,4}, {1.58903, 1.22117, 0.99621, 0.82911, 0.69315, 0.57634, 0.47223, 0.37689, 0.28768, 0.20273, 0.12058, 0.04002, - -0.04002,-0.12058,-0.20273,-0.28768,-0.37689,-0.47223,-0.57634,-0.69315,-0.82911,-0.99621,-1.22117,-1.58903}); - - predictions.linspace(0.04, 0.04); - labels.linspace(1); - weights.assign(0.5); - - sd::ops::log_loss_grad op; - auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {0}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto *dLdp = results.at(0); - auto *dLdw = results.at(1); - auto *dLdl = results.at(2); - - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); - ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); - ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); + NDArray labels('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {2, 3, 4}, sd::DataType::DOUBLE); + + NDArray dLdpExp( + 'c', {2, 3, 4}, + {-12.49997, -13.04346, -13.63635, -14.28571, -14.99999, -15.78947, + -16.66666, -17.64705, -18.75, -20., -21.42857, -23.07692, + -24.99999, -27.27272, -29.99999, -33.33332, -37.49999, -42.85713, + -49.99998, -59.99998, -74.99995, -99.99992, -149.99986, -299.99911}); + NDArray dLdwExp( + 'c', {2, 3, 4}, + {3.21887, 4.96807, 6.10512, 6.80726, 7.15461, 7.19051, + 6.93973, 6.41584, 5.62456, 4.56548, 3.2326, 1.61444, + -0.30659, -2.55529, -5.16569, -8.18417, -11.67468, -15.72734, + -20.47379, -26.11644, -32.9902, -41.71318, -53.64824, -73.05434}); + NDArray dLdlExp('c', {2, 3, 4}, + {1.58903, 1.22117, 0.99621, 0.82911, 0.69315, 0.57634, + 0.47223, 0.37689, 0.28768, 0.20273, 0.12058, 0.04002, + -0.04002, -0.12058, -0.20273, -0.28768, -0.37689, -0.47223, + -0.57634, -0.69315, -0.82911, -0.99621, -1.22117, -1.58903}); + + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); + + sd::ops::log_loss_grad op; + auto results = + op.evaluate({&predictions, &weights, &labels}, {1e-7}, {0}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); + ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, log_loss_grad_test2) { + NDArray labels('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {2, 1, 4}, sd::DataType::DOUBLE); - NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights('c', {2,1,4}, sd::DataType::DOUBLE); - - NDArray dLdwExp('c', {2,1,4}, {15.99805, 16.72406, 16.27746, 14.83754,-44.97147,-59.99582,-79.28771,-107.35497}); + NDArray dLdwExp('c', {2, 1, 4}, + {15.99805, 16.72406, 16.27746, 14.83754, -44.97147, -59.99582, + -79.28771, -107.35497}); - predictions.linspace(0.04, 0.04); - labels.linspace(1); - weights.assign(0.5); + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); - sd::ops::log_loss_grad op; - auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {0}); + sd::ops::log_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {0}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *dLdw = results.at(1); + auto dLdw = results.at(1); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, log_loss_grad_test3) { - - NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights(sd::DataType::DOUBLE); - - NDArray dLdpExp('c', {2,3,4}, {-12.49997,-13.04346, -13.63635, -14.28571,-14.99999,-15.78947, -16.66666, -17.64705,-18.75 ,-20. , -21.42857, -23.07692, - -24.99999,-27.27272, -29.99999, -33.33332,-37.49999,-42.85713, -49.99998, -59.99998,-74.99995,-99.99992,-149.99986,-299.99911}); - NDArray dLdwExp('c', {}, std::vector{-227.77286}); - NDArray dLdlExp('c', {2,3,4}, {1.58903, 1.22117, 0.99621, 0.82911, 0.69315, 0.57634, 0.47223, 0.37689, 0.28768, 0.20273, 0.12058, 0.04002, - -0.04002,-0.12058,-0.20273,-0.28768,-0.37689,-0.47223,-0.57634,-0.69315,-0.82911,-0.99621,-1.22117,-1.58903}); - - predictions.linspace(0.04, 0.04); - labels.linspace(1); - weights.assign(0.5); - - sd::ops::log_loss_grad op; - auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {1}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto *dLdp = results.at(0); - auto *dLdw = results.at(1); - auto *dLdl = results.at(2); - - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); - ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); - ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); + NDArray labels('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights(sd::DataType::DOUBLE); + + NDArray dLdpExp( + 'c', {2, 3, 4}, + {-12.49997, -13.04346, -13.63635, -14.28571, -14.99999, -15.78947, + -16.66666, -17.64705, -18.75, -20., -21.42857, -23.07692, + -24.99999, -27.27272, -29.99999, -33.33332, -37.49999, -42.85713, + -49.99998, -59.99998, -74.99995, -99.99992, -149.99986, -299.99911}); + NDArray dLdwExp('c', {}, std::vector{-227.77286}); + NDArray dLdlExp('c', {2, 3, 4}, + {1.58903, 1.22117, 0.99621, 0.82911, 0.69315, 0.57634, + 0.47223, 0.37689, 0.28768, 0.20273, 0.12058, 0.04002, + -0.04002, -0.12058, -0.20273, -0.28768, -0.37689, -0.47223, + -0.57634, -0.69315, -0.82911, -0.99621, -1.22117, -1.58903}); + + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); + + sd::ops::log_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); + ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, log_loss_grad_test4) { + NDArray labels('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {1, 3, 1}, sd::DataType::DOUBLE); - NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights('c', {1,3,1}, sd::DataType::DOUBLE); - - NDArray dLdwExp('c', {1,3,1}, {4.8876 , -46.29156, -186.36887}); + NDArray dLdwExp('c', {1, 3, 1}, {4.8876, -46.29156, -186.36887}); - predictions.linspace(0.04, 0.04); - labels.linspace(1); - weights.assign(0.5); + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); - sd::ops::log_loss_grad op; - auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {1}); + sd::ops::log_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {1}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *dLdw = results.at(1); - // dLdw->printIndexedBuffer(); - // dLdw->printShapeInfo(); + auto dLdw = results.at(1); + // dLdw->printIndexedBuffer(); + // dLdw->printShapeInfo(); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, log_loss_grad_test5) { - - NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights('c', {2,3,4}, sd::DataType::DOUBLE); - - NDArray dLdpExp('c', {2,3,4}, {-1.04166,-1.08696, -1.13636, -1.19048,-1.25 ,-1.31579, -1.38889, -1.47059,-1.5625 ,-1.66667, -1.78571, -1.92308, - -2.08333,-2.27273, -2.5 , -2.77778,-3.125 ,-3.57143, -4.16667, -5. ,-6.25 ,-8.33333,-12.49999,-24.99993}); - NDArray dLdwExp('c', {2,3,4}, {1.05912, 1.20488, 1.29964, 1.35815, 1.3871 , 1.39009, 1.36919, 1.32553, 1.25959, 1.17133, 1.06026, 0.92541, - 0.76533, 0.57794, 0.3604 , 0.10886,-0.18201,-0.51973,-0.91527,-1.38549,-1.95831,-2.68522,-3.67981,-5.29698}); - NDArray dLdlExp('c', {2,3,4}, {0.13242, 0.10176, 0.08302, 0.06909, 0.05776, 0.04803, 0.03935, 0.03141, 0.02397, 0.01689, 0.01005, 0.00334, - -0.00334,-0.01005,-0.01689,-0.02397,-0.03141,-0.03935,-0.04803,-0.05776,-0.06909,-0.08302,-0.10176,-0.13242}); - - predictions.linspace(0.04, 0.04); - labels.linspace(1); - weights.assign(0.5); - - sd::ops::log_loss_grad op; - auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {2}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto *dLdp = results.at(0); - auto *dLdw = results.at(1); - auto *dLdl = results.at(2); - - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); - ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); - ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); + NDArray labels('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {2, 3, 4}, sd::DataType::DOUBLE); + + NDArray dLdpExp( + 'c', {2, 3, 4}, + {-1.04166, -1.08696, -1.13636, -1.19048, -1.25, -1.31579, + -1.38889, -1.47059, -1.5625, -1.66667, -1.78571, -1.92308, + -2.08333, -2.27273, -2.5, -2.77778, -3.125, -3.57143, + -4.16667, -5., -6.25, -8.33333, -12.49999, -24.99993}); + NDArray dLdwExp('c', {2, 3, 4}, + {1.05912, 1.20488, 1.29964, 1.35815, 1.3871, 1.39009, + 1.36919, 1.32553, 1.25959, 1.17133, 1.06026, 0.92541, + 0.76533, 0.57794, 0.3604, 0.10886, -0.18201, -0.51973, + -0.91527, -1.38549, -1.95831, -2.68522, -3.67981, -5.29698}); + NDArray dLdlExp('c', {2, 3, 4}, + {0.13242, 0.10176, 0.08302, 0.06909, 0.05776, 0.04803, + 0.03935, 0.03141, 0.02397, 0.01689, 0.01005, 0.00334, + -0.00334, -0.01005, -0.01689, -0.02397, -0.03141, -0.03935, + -0.04803, -0.05776, -0.06909, -0.08302, -0.10176, -0.13242}); + + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); + + sd::ops::log_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {2}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); + ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, log_loss_grad_test6) { + NDArray labels('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {1, 3, 1}, sd::DataType::DOUBLE); - NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights('c', {1,3,1}, sd::DataType::DOUBLE); - - NDArray dLdwExp('c', {1,3,1}, {6.73432, 2.46939,-9.20372}); + NDArray dLdwExp('c', {1, 3, 1}, {6.73432, 2.46939, -9.20372}); - predictions.linspace(0.04, 0.04); - labels.linspace(1); - weights.assign(0.5); + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); - sd::ops::log_loss_grad op; - auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {2}); + sd::ops::log_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {2}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *dLdw = results.at(1); + auto dLdw = results.at(1); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, log_loss_grad_test7) { + NDArray labels('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights(sd::DataType::DOUBLE); - NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights(sd::DataType::DOUBLE); - - NDArray dLdwExp('c', {}, std::vector{0.}); + NDArray dLdwExp('c', {}, std::vector{0.}); - predictions.linspace(0.04, 0.04); - labels.linspace(1); - weights.assign(0.5); + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); - sd::ops::log_loss_grad op; - auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {2}); + sd::ops::log_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {2}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *dLdw = results.at(1); + auto dLdw = results.at(1); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, log_loss_grad_test8) { - - NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights('c', {2,3,4}, sd::DataType::DOUBLE); - - NDArray dLdpExp('c', {2,3,4}, {0. , 0. , 0. , 0. ,-1.5 ,-1.57895, -1.66667, -1.76471,-1.875 ,-2. , -2.14286, -2.30769, - -2.5 ,-2.72727, -3. , -3.33333,-3.75 ,-4.28571, -5. , -6. ,-7.49999,-9.99999,-14.99999,-29.99991}); - NDArray dLdwExp('c', {2,3,4}, {1.56625, 1.74117, 1.85487, 1.92509, 1.95982, 1.96341, 1.93833, 1.88594, 1.80682, 1.70091, 1.56762, 1.4058 , - 1.2137 , 0.98883, 0.72779, 0.42594, 0.07689,-0.32837,-0.80302,-1.36728,-2.05466,-2.92696,-4.12046,-6.06107}); - NDArray dLdlExp('c', {2,3,4}, {0. , 0. , 0. , 0. , 0.06931, 0.05763, 0.04722, 0.03769, 0.02877, 0.02027, 0.01206, 0.004, - -0.004 ,-0.01206,-0.02027,-0.02877,-0.03769,-0.04722,-0.05763,-0.06931,-0.08291,-0.09962,-0.12212,-0.1589}); - - predictions.linspace(0.04, 0.04); - labels.linspace(1); - weights.assign(0.5); - weights.p(0, 0.); - weights.p(1, 0.); - weights.p(2, 0.); - weights.p(3, 0.); - - sd::ops::log_loss_grad op; - auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {2}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto *dLdp = results.at(0); - auto *dLdw = results.at(1); - auto *dLdl = results.at(2); - - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); - ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); - ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); + NDArray labels('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {2, 3, 4}, sd::DataType::DOUBLE); + + NDArray dLdpExp( + 'c', {2, 3, 4}, + {0., 0., 0., 0., -1.5, -1.57895, + -1.66667, -1.76471, -1.875, -2., -2.14286, -2.30769, + -2.5, -2.72727, -3., -3.33333, -3.75, -4.28571, + -5., -6., -7.49999, -9.99999, -14.99999, -29.99991}); + NDArray dLdwExp('c', {2, 3, 4}, + {1.56625, 1.74117, 1.85487, 1.92509, 1.95982, 1.96341, + 1.93833, 1.88594, 1.80682, 1.70091, 1.56762, 1.4058, + 1.2137, 0.98883, 0.72779, 0.42594, 0.07689, -0.32837, + -0.80302, -1.36728, -2.05466, -2.92696, -4.12046, -6.06107}); + NDArray dLdlExp('c', {2, 3, 4}, + {0., 0., 0., 0., 0.06931, 0.05763, + 0.04722, 0.03769, 0.02877, 0.02027, 0.01206, 0.004, + -0.004, -0.01206, -0.02027, -0.02877, -0.03769, -0.04722, + -0.05763, -0.06931, -0.08291, -0.09962, -0.12212, -0.1589}); + + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); + weights.p(0, 0.); + weights.p(1, 0.); + weights.p(2, 0.); + weights.p(3, 0.); + + sd::ops::log_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {2}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); + ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, log_loss_grad_test9) { + NDArray labels('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {2, 3, 4}, sd::DataType::DOUBLE); + + NDArray dLdpExp( + 'c', {2, 3, 4}, + {-0.52083, -0.54348, -0.56818, -0.59524, -0.625, -0.65789, + -0.69444, -0.73529, -0.78125, -0.83333, -0.89286, -0.96154, + -1.04167, -1.13636, -1.25, -1.38889, -1.5625, -1.78571, + -2.08333, -2.5, -3.125, -4.16666, -6.24999, -12.49996}); + NDArray dLdwExp('c', {2, 3, 4}, + {0.13412, 0.207, 0.25438, 0.28364, 0.29811, 0.2996, + 0.28916, 0.26733, 0.23436, 0.19023, 0.13469, 0.06727, + -0.01277, -0.10647, -0.21524, -0.34101, -0.48645, -0.65531, + -0.85307, -1.08819, -1.37459, -1.73805, -2.23534, -3.04393}); + NDArray dLdlExp('c', {2, 3, 4}, + {0.06621, 0.05088, 0.04151, 0.03455, 0.02888, 0.02401, + 0.01968, 0.0157, 0.01199, 0.00845, 0.00502, 0.00167, + -0.00167, -0.00502, -0.00845, -0.01199, -0.0157, -0.01968, + -0.02401, -0.02888, -0.03455, -0.04151, -0.05088, -0.06621}); + + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); + + sd::ops::log_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {3}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); + ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); +} - NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights('c', {2,3,4}, sd::DataType::DOUBLE); +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, log_loss_grad_test10) { + NDArray labels('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {1, 1}, sd::DataType::DOUBLE); - NDArray dLdpExp('c', {2,3,4}, {-0.52083,-0.54348,-0.56818, -0.59524,-0.625 ,-0.65789,-0.69444, -0.73529,-0.78125,-0.83333,-0.89286, -0.96154, - -1.04167,-1.13636,-1.25 , -1.38889,-1.5625 ,-1.78571,-2.08333, -2.5 ,-3.125 ,-4.16666,-6.24999,-12.49996}); - NDArray dLdwExp('c', {2,3,4}, {0.13412, 0.207 , 0.25438, 0.28364, 0.29811, 0.2996 , 0.28916, 0.26733, 0.23436, 0.19023, 0.13469, 0.06727, - -0.01277,-0.10647,-0.21524,-0.34101,-0.48645,-0.65531,-0.85307,-1.08819,-1.37459,-1.73805,-2.23534,-3.04393}); - NDArray dLdlExp('c', {2,3,4}, {0.06621, 0.05088, 0.04151, 0.03455, 0.02888, 0.02401, 0.01968, 0.0157 , 0.01199, 0.00845, 0.00502, 0.00167, - -0.00167,-0.00502,-0.00845,-0.01199,-0.0157 ,-0.01968,-0.02401,-0.02888,-0.03455,-0.04151,-0.05088,-0.06621}); + NDArray dLdwExp('c', {1, 1}, std::vector{-9.49054}); - predictions.linspace(0.04, 0.04); - labels.linspace(1); - weights.assign(0.5); + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); - sd::ops::log_loss_grad op; - auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {3}); + sd::ops::log_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {3}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *dLdp = results.at(0); - auto *dLdw = results.at(1); - auto *dLdl = results.at(2); + auto dLdw = results.at(1); - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); - ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); - ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); } /////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests11, log_loss_grad_test10) { +TEST_F(DeclarableOpsTests11, log_loss_grad_test11) { + NDArray labels('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {1, 3, 1}, sd::DataType::DOUBLE); - NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights('c', {1,1}, sd::DataType::DOUBLE); + NDArray dLdwExp('c', {1, 3, 1}, {0.20365, -1.92882, -7.76537}); - NDArray dLdwExp('c', {1,1}, std::vector{-9.49054}); + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); - predictions.linspace(0.04, 0.04); - labels.linspace(1); - weights.assign(0.5); + sd::ops::log_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {3}); - sd::ops::log_loss_grad op; - auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {3}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto dLdw = results.at(1); - auto *dLdw = results.at(1); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); +} - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, log_loss_grad_test12) { + NDArray labels('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {2, 3, 4}, sd::DataType::DOUBLE); + + NDArray dLdpExp( + 'c', {2, 3, 4}, + {0., 0., 0., 0., -0.75, -0.789473, + -0.833333, -0.882353, -0.9375, -1., -1.071428, -1.153846, + -1.25, -1.363636, -1.5, -1.666666, -1.875, -2.142857, + -2.499999, -2.999999, -3.749997, -4.999997, -7.499993, -14.999956}); + NDArray dLdwExp('c', {2, 3, 4}, + {0.16094, 0.2484, 0.30526, 0.34036, 0.35773, 0.35953, + 0.34699, 0.32079, 0.28123, 0.22827, 0.16163, 0.08072, + -0.01533, -0.12776, -0.25828, -0.40921, -0.58373, -0.78637, + -1.02369, -1.30582, -1.64951, -2.08566, -2.68241, -3.65272}); + NDArray dLdlExp('c', {2, 3, 4}, + {0., 0., 0., 0., 0.03466, 0.02882, + 0.02361, 0.01884, 0.01438, 0.01014, 0.00603, 0.002, + -0.002, -0.00603, -0.01014, -0.01438, -0.01884, -0.02361, + -0.02882, -0.03466, -0.04146, -0.04981, -0.06106, -0.07945}); + + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); + weights.r(0) = 0.; + weights.r(1) = 0.; + weights.r(2) = 0.; + weights.r(3) = 0.; + + sd::ops::log_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {3}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); + ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); } /////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests11, log_loss_grad_test11) { +TEST_F(DeclarableOpsTests11, log_loss_grad_test13) { + NDArray labels('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {2, 3, 1}, sd::DataType::DOUBLE); + + NDArray dLdpExp('c', {2, 3, 4}, + {0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., + -2.08333, -2.27273, -2.5, -2.77778, -3.125, -3.57143, + -4.16667, -5., -6.25, -8.33333, -12.49999, -24.99993}); + NDArray dLdwExp('c', {2, 3, 1}, + {1.75828, 2.30839, 1.25309, -1.35098, -6.16602, -16.78383}); + NDArray dLdlExp('c', {2, 3, 4}, + {0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., + -0.00334, -0.01005, -0.01689, -0.02397, -0.03141, -0.03935, + -0.04803, -0.05776, -0.06909, -0.08302, -0.10176, -0.13242}); + + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); + weights.r(0) = 0.; + weights.r(1) = 0.; + weights.r(2) = 0.; + + sd::ops::log_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {3}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); + ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); +} - NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights('c', {1,3,1}, sd::DataType::DOUBLE); +TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test1) { + NDArray input = NDArrayFactory::create( + 'c', {1, 7, 7, 1}, + {1.f, 2.1f, 3.15f, 4.2f, 5.15f, 6.1f, 7.f, 8.f, 9.1f, 10.f, + 11.f, 12.9f, 13.1f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, + 21.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 28.f, 30.f, 31.f, + 32.f, 33.f, 34.f, 35.f, 36.f, 37.f, 38.f, 39.f, 40.f, 41.f, + 42.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 49.f, 50.f}); + NDArray expected = NDArrayFactory::create( + 'c', {1, 30, 30, 1}, + {1.f, 1.1976162f, 1.4174359f, 1.6775769f, 1.9961575f, + 2.3283265f, 2.550918f, 2.7360606f, 2.9655411f, 3.2929654f, + 3.5441515f, 3.7380352f, 3.948995f, 4.248106f, 4.5073795f, + 4.6843743f, 4.8572845f, 5.104302f, 5.3869915f, 5.581401f, + 5.7539616f, 5.974285f, 6.272836f, 6.5204263f, 6.718899f, + 6.8871036f, 7.039068f, 7.099216f, 7.0784245f, 7.0281887f, + 2.247592f, 2.446947f, 2.6694887f, 2.9312382f, 3.248216f, + 3.5745337f, 3.78931f, 3.9656973f, 4.186417f, 4.5046535f, + 4.740569f, 4.9217057f, 5.133866f, 5.459533f, 5.7744613f, + 6.0197873f, 6.254011f, 6.535633f, 6.8097296f, 6.9607787f, + 7.0749416f, 7.241601f, 7.5094895f, 7.7499495f, 7.954571f, + 8.131972f, 8.286526f, 8.346463f, 8.325745f, 8.275683f, + 3.6286845f, 3.830573f, 4.0569587f, 4.3211575f, 4.6364856f, + 4.9556503f, 5.160583f, 5.3258467f, 5.535462f, 5.84216f, + 6.058749f, 6.223753f, 6.437597f, 6.797369f, 7.1836042f, + 7.5164022f, 7.8290343f, 8.154773f, 8.417635f, 8.512958f, + 8.5521f, 8.649708f, 8.87788f, 9.108794f, 9.320926f, + 9.509781f, 9.667375f, 9.72694f, 9.706349f, 9.656599f, + 5.276778f, 5.480438f, 5.709702f, 5.9754477f, 6.288551f, + 6.6005697f, 6.796207f, 6.9511423f, 7.1503997f, 7.4461427f, + 7.644651f, 7.794562f, 8.009684f, 8.400473f, 8.851847f, + 9.26469f, 9.649218f, 10.015648f, 10.268647f, 10.313368f, + 10.2843275f, 10.319379f, 10.512033f, 10.734956f, 10.954604f, + 11.154507f, 11.315369f, 11.374779f, 11.354242f, 11.304622f, + 7.325373f, 7.5284843f, 7.757575f, 8.022221f, 8.331997f, + 8.638187f, 8.827649f, 8.976217f, 9.168955f, 9.45726f, + 9.6442375f, 9.784517f, 9.999621f, 10.407702f, 10.896234f, + 11.355122f, 11.781423f, 12.172186f, 12.420712f, 12.4374485f, + 12.370511f, 12.371386f, 12.545973f, 12.766424f, 12.992249f, + 13.20012f, 13.364252f, 13.424109f, 13.40342f, 13.353425f, + 9.493208f, 9.692467f, 9.9169445f, 10.176801f, 10.482199f, + 10.78547f, 10.974367f, 11.123442f, 11.31637f, 11.603645f, + 11.790616f, 11.930889f, 12.144082f, 12.546447f, 13.024898f, + 13.4723f, 13.889232f, 14.276275f, 14.528972f, 14.555555f, + 14.50145f, 14.515459f, 14.700572f, 14.927055f, 15.156046f, + 15.366046f, 15.532901f, 15.594008f, 15.5728855f, 15.521847f, + 10.970133f, 11.163599f, 11.380694f, 11.633735f, 11.935032f, + 12.238887f, 12.43254f, 12.588294f, 12.787534f, 13.079956f, + 13.27752f, 13.426631f, 13.636713f, 14.013844f, 14.441672f, + 14.827978f, 15.191209f, 15.549808f, 15.81343f, 15.881828f, + 15.883522f, 15.950411f, 16.16933f, 16.40794f, 16.636436f, + 16.842583f, 17.010887f, 17.07363f, 17.05194f, 16.999537f, + 12.219155f, 12.406129f, 12.614796f, 12.860335f, 13.157928f, + 13.464224f, 13.665207f, 13.830567f, 14.039036f, 14.339629f, + 14.552863f, 14.715049f, 14.921564f, 15.264454f, 15.622843f, + 15.924977f, 16.213829f, 16.532364f, 16.8099f, 16.934835f, + 17.012146f, 17.150164f, 17.413412f, 17.666712f, 17.892765f, + 18.09207f, 18.261044f, 18.325531f, 18.303238f, 18.249378f, + 13.7663965f, 13.947391f, 14.148263f, 14.386917f, 14.681246f, + 14.990087f, 15.198166f, 15.372728f, 15.590062f, 15.898583f, + 16.126892f, 16.301655f, 16.50487f, 16.815214f, 17.107498f, + 17.329458f, 17.547403f, 17.827654f, 18.118288f, 18.296928f, + 18.4461f, 18.651634f, 18.956806f, 19.22382f, 19.447308f, + 19.639887f, 19.809319f, 19.875397f, 19.852556f, 19.797365f, + 15.9419365f, 16.118704f, 16.314133f, 16.547867f, 16.839561f, + 17.14954f, 17.361883f, 17.542162f, 17.764957f, 18.078188f, + 18.315733f, 18.498205f, 18.699116f, 18.988684f, 19.238989f, + 19.410137f, 19.583265f, 19.839512f, 20.13878f, 20.35177f, + 20.546844f, 20.795671f, 21.128067f, 21.404358f, 21.626736f, + 21.8155f, 21.98561f, 22.052843f, 22.029604f, 21.973448f, + 17.53522f, 17.71077f, 17.904636f, 18.13695f, 18.42784f, + 18.738056f, 18.951529f, 19.133352f, 19.357613f, 19.672083f, + 19.912102f, 20.096638f, 20.296894f, 20.580765f, 20.819603f, + 20.976887f, 21.137802f, 21.387535f, 21.689209f, 21.911621f, + 22.119276f, 22.37999f, 22.71991f, 22.998823f, 23.22097f, + 23.40876f, 23.57911f, 23.646685f, 23.623325f, 23.566887f, + 18.746353f, 18.922657f, 19.117487f, 19.350685f, 19.64207f, + 19.952137f, 20.164913f, 20.345781f, 20.569134f, 20.88284f, + 21.12133f, 21.30459f, 21.505253f, 21.792645f, 22.038572f, + 22.204426f, 22.37289f, 22.626648f, 22.926834f, 23.143423f, + 23.343302f, 23.596668f, 23.931936f, 24.209232f, 24.431519f, + 24.619913f, 24.79011f, 24.857473f, 24.83419f, 24.777927f, + 20.16656f, 20.344206f, 20.540766f, 20.775532f, 21.067804f, + 21.377607f, 21.589132f, 21.768297f, 21.99003f, 22.302366f, + 22.538124f, 22.719105f, 22.920494f, 23.214176f, 23.472767f, + 23.653934f, 23.83589f, 24.096842f, 24.394371f, 24.600555f, + 24.786541f, 25.026773f, 25.353731f, 25.62813f, 25.850672f, + 26.04014f, 26.210072f, 26.277063f, 26.253906f, 26.197956f, + 22.363024f, 22.54125f, 22.738552f, 22.973991f, 23.266647f, + 23.57634f, 23.787327f, 23.96576f, 24.186796f, 24.498543f, + 24.733124f, 24.913122f, 25.114826f, 25.411213f, 25.675262f, + 25.863028f, 26.050789f, 26.314838f, 26.611223f, 26.812925f, + 26.992926f, 27.227505f, 27.550882f, 27.824034f, 28.046684f, + 28.236614f, 28.406433f, 28.473265f, 28.450163f, 28.394344f, + 24.429443f, 24.60767f, 24.80497f, 25.04041f, 25.333065f, + 25.642756f, 25.853743f, 26.032173f, 26.25321f, 26.564959f, + 26.79954f, 26.97954f, 27.181242f, 27.47763f, 27.74168f, + 27.929441f, 28.117207f, 28.381254f, 28.677637f, 28.879343f, + 29.059345f, 29.293922f, 29.617298f, 29.890451f, 30.113104f, + 30.303034f, 30.472853f, 30.539684f, 30.516582f, 30.460762f, + 26.f, 26.178228f, 26.375526f, 26.61097f, 26.903624f, + 27.213314f, 27.424305f, 27.602734f, 27.823772f, 28.135519f, + 28.3701f, 28.550098f, 28.7518f, 29.04819f, 29.312237f, + 29.5f, 29.687763f, 29.951813f, 30.2482f, 30.449903f, + 30.629902f, 30.864483f, 31.187859f, 31.461012f, 31.683659f, + 31.873592f, 32.043407f, 32.11024f, 32.087135f, 32.03132f, + 27.570559f, 27.748787f, 27.946087f, 28.181528f, 28.474184f, + 28.783876f, 28.994865f, 29.173294f, 29.39433f, 29.70608f, + 29.940659f, 30.120655f, 30.32236f, 30.618746f, 30.882797f, + 31.070557f, 31.25832f, 31.522371f, 31.818754f, 32.02046f, + 32.20046f, 32.43504f, 32.758415f, 33.031567f, 33.25422f, + 33.44415f, 33.613964f, 33.680794f, 33.657696f, 33.60188f, + 29.636976f, 29.815207f, 30.0125f, 30.247944f, 30.5406f, + 30.85029f, 31.061283f, 31.239712f, 31.46075f, 31.7725f, + 32.00708f, 32.187077f, 32.38878f, 32.685165f, 32.949215f, + 33.13698f, 33.32474f, 33.58879f, 33.885178f, 34.086884f, + 34.26688f, 34.501457f, 34.824837f, 35.09799f, 35.320637f, + 35.510574f, 35.68039f, 35.747215f, 35.724117f, 35.6683f, + 31.83344f, 32.011665f, 32.20897f, 32.444412f, 32.73707f, + 33.046757f, 33.257744f, 33.436176f, 33.657207f, 33.96896f, + 34.203537f, 34.383537f, 34.58524f, 34.88163f, 35.145676f, + 35.33344f, 35.521206f, 35.785255f, 36.081642f, 36.28334f, + 36.46334f, 36.69792f, 37.021297f, 37.294453f, 37.517097f, + 37.707027f, 37.876846f, 37.94368f, 37.920578f, 37.864758f, + 33.253647f, 33.431873f, 33.62917f, 33.864613f, 34.15727f, + 34.466957f, 34.677948f, 34.856377f, 35.077415f, 35.38916f, + 35.623745f, 35.803745f, 36.005447f, 36.301834f, 36.565884f, + 36.753647f, 36.941406f, 37.205456f, 37.50184f, 37.703545f, + 37.883545f, 38.118122f, 38.4415f, 38.714653f, 38.9373f, + 39.127235f, 39.297054f, 39.363884f, 39.340782f, 39.28496f, + 34.464783f, 34.64301f, 34.840305f, 35.075752f, 35.368404f, + 35.6781f, 35.889088f, 36.067516f, 36.28855f, 36.6003f, + 36.834885f, 37.014877f, 37.216583f, 37.51297f, 37.77702f, + 37.964783f, 38.152546f, 38.416595f, 38.71298f, 38.914684f, + 39.094685f, 39.32926f, 39.652645f, 39.925793f, 40.14844f, + 40.338375f, 40.508194f, 40.575024f, 40.55192f, 40.496105f, + 36.058067f, 36.23629f, 36.43359f, 36.669033f, 36.961685f, + 37.271378f, 37.48237f, 37.6608f, 37.881836f, 38.19359f, + 38.42817f, 38.608162f, 38.809868f, 39.10625f, 39.3703f, + 39.558064f, 39.74583f, 40.00988f, 40.306267f, 40.50797f, + 40.68797f, 40.92255f, 41.245926f, 41.519077f, 41.741722f, + 41.931652f, 42.101475f, 42.168304f, 42.145203f, 42.089386f, + 38.315002f, 38.493233f, 38.690533f, 38.925976f, 39.218628f, + 39.52832f, 39.739307f, 39.917736f, 40.138775f, 40.45052f, + 40.685104f, 40.865097f, 41.066803f, 41.36319f, 41.627243f, + 41.815002f, 42.002766f, 42.26682f, 42.5632f, 42.764908f, + 42.944904f, 43.179485f, 43.50286f, 43.776016f, 43.998665f, + 44.188595f, 44.358418f, 44.425247f, 44.402145f, 44.34633f, + 40.22708f, 40.40531f, 40.602608f, 40.83805f, 41.130707f, + 41.440395f, 41.651382f, 41.82982f, 42.050854f, 42.3626f, + 42.597183f, 42.77718f, 42.97888f, 43.27527f, 43.53932f, + 43.72708f, 43.914845f, 44.178894f, 44.47528f, 44.676983f, + 44.856983f, 45.09156f, 45.41494f, 45.68809f, 45.91074f, + 46.100674f, 46.270493f, 46.337322f, 46.31422f, 46.2584f, + 41.785618f, 41.963844f, 42.161144f, 42.396584f, 42.68924f, + 42.998936f, 43.209923f, 43.388355f, 43.609394f, 43.921143f, + 44.15572f, 44.335716f, 44.53742f, 44.833805f, 45.09786f, + 45.285614f, 45.473377f, 45.737427f, 46.033817f, 46.235523f, + 46.415524f, 46.650105f, 46.973476f, 47.24663f, 47.469276f, + 47.65921f, 47.82903f, 47.895855f, 47.872753f, 47.81694f, + 43.11514f, 43.293365f, 43.490665f, 43.726105f, 44.018764f, + 44.328457f, 44.539444f, 44.717873f, 44.93891f, 45.25066f, + 45.48524f, 45.665237f, 45.86694f, 46.163326f, 46.427376f, + 46.615143f, 46.802902f, 47.066956f, 47.363342f, 47.56505f, + 47.74505f, 47.979626f, 48.302998f, 48.576153f, 48.798798f, + 48.98873f, 49.158546f, 49.225376f, 49.202282f, 49.146458f, + 44.303867f, 44.482094f, 44.679394f, 44.914833f, 45.207493f, + 45.51718f, 45.72817f, 45.9066f, 46.12764f, 46.439384f, + 46.673965f, 46.853966f, 47.055668f, 47.352055f, 47.6161f, + 47.803867f, 47.99163f, 48.25568f, 48.552063f, 48.75377f, + 48.933773f, 49.16835f, 49.491726f, 49.764877f, 49.987526f, + 50.17746f, 50.347275f, 50.4141f, 50.391006f, 50.335186f, + 44.771675f, 44.949905f, 45.1472f, 45.382645f, 45.6753f, + 45.98499f, 46.195976f, 46.374413f, 46.595448f, 46.907196f, + 47.141773f, 47.321774f, 47.523476f, 47.819862f, 48.08391f, + 48.27168f, 48.459446f, 48.72349f, 49.019882f, 49.22158f, + 49.401585f, 49.63616f, 49.959538f, 50.232693f, 50.455338f, + 50.64527f, 50.81509f, 50.88192f, 50.858818f, 50.803f, + 44.609966f, 44.788193f, 44.985493f, 45.220936f, 45.51359f, + 45.82328f, 46.03427f, 46.2127f, 46.433743f, 46.74549f, + 46.98007f, 47.160065f, 47.36177f, 47.658157f, 47.922207f, + 48.10997f, 48.297733f, 48.561783f, 48.858166f, 49.059875f, + 49.239872f, 49.47445f, 49.79783f, 50.07098f, 50.293625f, + 50.48356f, 50.653378f, 50.720203f, 50.6971f, 50.64128f, + 44.219246f, 44.397472f, 44.594772f, 44.83021f, 45.122868f, + 45.43256f, 45.643543f, 45.82198f, 46.04302f, 46.354763f, + 46.589344f, 46.76934f, 46.971046f, 47.267433f, 47.531483f, + 47.719242f, 47.907005f, 48.17105f, 48.467438f, 48.66914f, + 48.849144f, 49.08372f, 49.4071f, 49.680256f, 49.902905f, + 50.092834f, 50.262653f, 50.329483f, 50.30638f, 50.25057f}); + + auto size = NDArrayFactory::create({30, 30}); + sd::ops::resize_bicubic op; + auto results = op.evaluate({&input, &size}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto result = results.at(0); + + // result.printBuffer("Resized to 30x30"); + // expected.printBuffer("Expect for 30x30"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); +} +TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test2) { + NDArray input = NDArrayFactory::create('c', {2, 5, 4, 3}); + NDArray expected = NDArrayFactory::create( + 'c', {2, 10, 8, 3}, + {1.000000f, 2.000000f, 3.000000f, 2.218750f, + 3.218750f, 4.218750f, 4.000000f, 5.000000f, + 6.000000f, 5.500000f, 6.500000f, 7.500000f, + 7.000000f, 8.000000f, 9.000000f, 8.781250f, + 9.781250f, 10.781250f, 10.000000f, 11.000000f, + 12.000000f, 10.281250f, 11.281250f, 12.281250f, + 5.875000f, 6.875000f, 7.875000f, 7.093750f, + 8.093750f, 9.093750f, 8.875000f, 9.875000f, + 10.875000f, 10.375000f, 11.375000f, 12.375000f, + 11.875000f, 12.875000f, 13.875000f, 13.656250f, + 14.656250f, 15.656250f, 14.875000f, 15.875000f, + 16.875000f, 15.156250f, 16.156250f, 17.156250f, + 13.000000f, 14.000000f, 15.000000f, 14.218750f, + 15.218750f, 16.218750f, 16.000000f, 17.000000f, + 18.000000f, 17.500000f, 18.500000f, 19.500000f, + 19.000000f, 20.000000f, 21.000000f, 20.781250f, + 21.781250f, 22.781250f, 22.000000f, 23.000000f, + 24.000000f, 22.281250f, 23.281250f, 24.281250f, + 19.000000f, 20.000000f, 21.000000f, 20.218750f, + 21.218750f, 22.218750f, 22.000000f, 23.000000f, + 24.000000f, 23.500000f, 24.500000f, 25.500000f, + 25.000000f, 26.000000f, 27.000000f, 26.781250f, + 27.781250f, 28.781250f, 28.000000f, 29.000000f, + 30.000000f, 28.281250f, 29.281250f, 30.281250f, + 25.000000f, 26.000000f, 27.000000f, 26.218750f, + 27.218750f, 28.218750f, 28.000000f, 29.000000f, + 30.000000f, 29.500000f, 30.500000f, 31.500000f, + 31.000000f, 32.000000f, 33.000000f, 32.781250f, + 33.781250f, 34.781250f, 34.000000f, 35.000000f, + 36.000000f, 34.281250f, 35.281250f, 36.281250f, + 31.000000f, 32.000000f, 33.000000f, 32.218750f, + 33.218750f, 34.218750f, 34.000000f, 35.000000f, + 36.000000f, 35.500000f, 36.500000f, 37.500000f, + 37.000000f, 38.000000f, 39.000000f, 38.781250f, + 39.781250f, 40.781250f, 40.000000f, 41.000000f, + 42.000000f, 40.281250f, 41.281250f, 42.281250f, + 37.000000f, 38.000000f, 39.000000f, 38.218750f, + 39.218750f, 40.218750f, 40.000000f, 41.000000f, + 42.000000f, 41.500000f, 42.500000f, 43.500000f, + 43.000000f, 44.000000f, 45.000000f, 44.781250f, + 45.781250f, 46.781250f, 46.000000f, 47.000000f, + 48.000000f, 46.281250f, 47.281250f, 48.281250f, + 44.125000f, 45.125000f, 46.125000f, 45.343750f, + 46.343750f, 47.343750f, 47.125000f, 48.125000f, + 49.125000f, 48.625000f, 49.625000f, 50.625000f, + 50.125000f, 51.125000f, 52.125000f, 51.906250f, + 52.906250f, 53.906250f, 53.125000f, 54.125000f, + 55.125000f, 53.406250f, 54.406250f, 55.406250f, + 49.000000f, 50.000000f, 51.000000f, 50.218750f, + 51.218750f, 52.218750f, 52.000000f, 53.000000f, + 54.000000f, 53.500000f, 54.500000f, 55.500000f, + 55.000000f, 56.000000f, 57.000000f, 56.781250f, + 57.781250f, 58.781250f, 58.000000f, 59.000000f, + 60.000000f, 58.281250f, 59.281250f, 60.281250f, + 50.125000f, 51.125000f, 52.125000f, 51.343750f, + 52.343750f, 53.343750f, 53.125000f, 54.125000f, + 55.125000f, 54.625000f, 55.625000f, 56.625000f, + 56.125000f, 57.125000f, 58.125000f, 57.906250f, + 58.906250f, 59.906250f, 59.125000f, 60.125000f, + 61.125000f, 59.406250f, 60.406250f, 61.406250f, + 61.000000f, 62.000000f, 63.000000f, 62.218750f, + 63.218750f, 64.218750f, 64.000000f, 65.000000f, + 66.000000f, 65.500000f, 66.500000f, 67.500000f, + 67.000000f, 68.000000f, 69.000000f, 68.781250f, + 69.781250f, 70.781250f, 70.000000f, 71.000000f, + 72.000000f, 70.281250f, 71.281250f, 72.281250f, + 65.875000f, 66.875000f, 67.875000f, 67.093750f, + 68.093750f, 69.093750f, 68.875000f, 69.875000f, + 70.875000f, 70.375000f, 71.375000f, 72.375000f, + 71.875000f, 72.875000f, 73.875000f, 73.656250f, + 74.656250f, 75.656250f, 74.875000f, 75.875000f, + 76.875000f, 75.156250f, 76.156250f, 77.156250f, + 73.000000f, 74.000000f, 75.000000f, 74.218750f, + 75.218750f, 76.218750f, 76.000000f, 77.000000f, + 78.000000f, 77.500000f, 78.500000f, 79.500000f, + 79.000000f, 80.000000f, 81.000000f, 80.781250f, + 81.781250f, 82.781250f, 82.000000f, 83.000000f, + 84.000000f, 82.281250f, 83.281250f, 84.281250f, + 79.000000f, 80.000000f, 81.000000f, 80.218750f, + 81.218750f, 82.218750f, 82.000000f, 83.000000f, + 84.000000f, 83.500000f, 84.500000f, 85.500000f, + 85.000000f, 86.000000f, 87.000000f, 86.781250f, + 87.781250f, 88.781250f, 88.000000f, 89.000000f, + 90.000000f, 88.281250f, 89.281250f, 90.281250f, + 85.000000f, 86.000000f, 87.000000f, 86.218750f, + 87.218750f, 88.218750f, 88.000000f, 89.000000f, + 90.000000f, 89.500000f, 90.500000f, 91.500000f, + 91.000000f, 92.000000f, 93.000000f, 92.781250f, + 93.781250f, 94.781250f, 94.000000f, 95.000000f, + 96.000000f, 94.281250f, 95.281250f, 96.281250f, + 91.000000f, 92.000000f, 93.000000f, 92.218750f, + 93.218750f, 94.218750f, 94.000000f, 95.000000f, + 96.000000f, 95.500000f, 96.500000f, 97.500000f, + 97.000000f, 98.000000f, 99.000000f, 98.781250f, + 99.781250f, 100.781250f, 100.000000f, 101.000000f, + 102.000000f, 100.281250f, 101.281250f, 102.281250f, + 97.000000f, 98.000000f, 99.000000f, 98.218750f, + 99.218750f, 100.218750f, 100.000000f, 101.000000f, + 102.000000f, 101.500000f, 102.500000f, 103.500000f, + 103.000000f, 104.000000f, 105.000000f, 104.781250f, + 105.781250f, 106.781250f, 106.000000f, 107.000000f, + 108.000000f, 106.281250f, 107.281250f, 108.281250f, + 104.125000f, 105.125000f, 106.125000f, 105.343750f, + 106.343750f, 107.343750f, 107.125000f, 108.125000f, + 109.125000f, 108.625000f, 109.625000f, 110.625000f, + 110.125000f, 111.125000f, 112.125000f, 111.906250f, + 112.906250f, 113.906250f, 113.125000f, 114.125000f, + 115.125000f, 113.406250f, 114.406250f, 115.406250f, + 109.000000f, 110.000000f, 111.000000f, 110.218750f, + 111.218750f, 112.218750f, 112.000000f, 113.000000f, + 114.000000f, 113.500000f, 114.500000f, 115.500000f, + 115.000000f, 116.000000f, 117.000000f, 116.781250f, + 117.781250f, 118.781250f, 118.000000f, 119.000000f, + 120.000000f, 118.281250f, 119.281250f, 120.281250f, + 110.125000f, 111.125000f, 112.125000f, 111.343750f, + 112.343750f, 113.343750f, 113.125000f, 114.125000f, + 115.125000f, 114.625000f, 115.625000f, 116.625000f, + 116.125000f, 117.125000f, 118.125000f, 117.906250f, + 118.906250f, 119.906250f, 119.125000f, 120.125000f, + 121.125000f, 119.406250f, 120.406250f, 121.406250f}); // input = 1.f; + input.linspace(1); + auto size = NDArrayFactory::create({10, 8}); + sd::ops::resize_bicubic op; + auto results = op.evaluate({&input, &size}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + + // result.printBuffer("Resized to 10x8"); + // expected.printBuffer("Expect for 10x8"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); +} - NDArray dLdwExp('c', {1,3,1}, {0.20365,-1.92882,-7.76537}); +TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test3) { + NDArray input = NDArrayFactory::create('c', {1, 3, 3, 4}); + NDArray expected = NDArrayFactory::create( + 'c', {1, 6, 6, 4}, + {1.000000f, 2.000000f, 3.000000f, 4.000000f, 2.625000f, 3.625000f, + 4.625000f, 5.625000f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, + 7.375000f, 8.375000f, 9.375000f, 10.375000f, 9.000000f, 10.000000f, + 11.000000f, 12.000000f, 9.375000f, 10.375000f, 11.375000f, 12.375000f, + 5.875000f, 6.875000f, 7.875000f, 8.875000f, 7.500000f, 8.500000f, + 9.500000f, 10.500000f, 9.875000f, 10.875000f, 11.875000f, 12.875000f, + 12.250000f, 13.250000f, 14.250000f, 15.250000f, 13.875000f, 14.875000f, + 15.875000f, 16.875000f, 14.250000f, 15.250000f, 16.250000f, 17.250000f, + 13.000000f, 14.000000f, 15.000000f, 16.000000f, 14.625000f, 15.625000f, + 16.625000f, 17.625000f, 17.000000f, 18.000000f, 19.000000f, 20.000000f, + 19.375000f, 20.375000f, 21.375000f, 22.375000f, 21.000000f, 22.000000f, + 23.000000f, 24.000000f, 21.375000f, 22.375000f, 23.375000f, 24.375000f, + 20.125000f, 21.125000f, 22.125000f, 23.125000f, 21.750000f, 22.750000f, + 23.750000f, 24.750000f, 24.125000f, 25.125000f, 26.125000f, 27.125000f, + 26.500000f, 27.500000f, 28.500000f, 29.500000f, 28.125000f, 29.125000f, + 30.125000f, 31.125000f, 28.500000f, 29.500000f, 30.500000f, 31.500000f, + 25.000000f, 26.000000f, 27.000000f, 28.000000f, 26.625000f, 27.625000f, + 28.625000f, 29.625000f, 29.000000f, 30.000000f, 31.000000f, 32.000000f, + 31.375000f, 32.375000f, 33.375000f, 34.375000f, 33.000000f, 34.000000f, + 35.000000f, 36.000000f, 33.375000f, 34.375000f, 35.375000f, 36.375000f, + 26.125000f, 27.125000f, 28.125000f, 29.125000f, 27.750000f, 28.750000f, + 29.750000f, 30.750000f, 30.125000f, 31.125000f, 32.125000f, 33.125000f, + 32.500000f, 33.500000f, 34.500000f, 35.500000f, 34.125000f, 35.125000f, + 36.125000f, 37.125000f, 34.500000f, 35.500000f, 36.500000f, 37.500000f}); + input.linspace(1); + auto size = NDArrayFactory::create({6, 6}); + sd::ops::resize_bicubic op; + auto results = op.evaluate({&input, &size}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + + // result.printBuffer("Resized to 6x6"); + // expected.printBuffer("Expect for 6x6"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); +} - predictions.linspace(0.04, 0.04); - labels.linspace(1); - weights.assign(0.5); +TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test4) { + NDArray input = NDArrayFactory::create('c', {1, 3, 4, 3}); + NDArray expected = NDArrayFactory::create( + 'c', {1, 6, 8, 3}, + {1.000000f, 2.000000f, 3.000000f, 2.218750f, 3.218750f, 4.218750f, + 4.000000f, 5.000000f, 6.000000f, 5.500000f, 6.500000f, 7.500000f, + 7.000000f, 8.000000f, 9.000000f, 8.781250f, 9.781250f, 10.781250f, + 10.000000f, 11.000000f, 12.000000f, 10.281250f, 11.281250f, 12.281250f, + 5.875000f, 6.875000f, 7.875000f, 7.093750f, 8.093750f, 9.093750f, + 8.875000f, 9.875000f, 10.875000f, 10.375000f, 11.375000f, 12.375000f, + 11.875000f, 12.875000f, 13.875000f, 13.656250f, 14.656250f, 15.656250f, + 14.875000f, 15.875000f, 16.875000f, 15.156250f, 16.156250f, 17.156250f, + 13.000000f, 14.000000f, 15.000000f, 14.218750f, 15.218750f, 16.218750f, + 16.000000f, 17.000000f, 18.000000f, 17.500000f, 18.500000f, 19.500000f, + 19.000000f, 20.000000f, 21.000000f, 20.781250f, 21.781250f, 22.781250f, + 22.000000f, 23.000000f, 24.000000f, 22.281250f, 23.281250f, 24.281250f, + 20.125000f, 21.125000f, 22.125000f, 21.343750f, 22.343750f, 23.343750f, + 23.125000f, 24.125000f, 25.125000f, 24.625000f, 25.625000f, 26.625000f, + 26.125000f, 27.125000f, 28.125000f, 27.906250f, 28.906250f, 29.906250f, + 29.125000f, 30.125000f, 31.125000f, 29.406250f, 30.406250f, 31.406250f, + 25.000000f, 26.000000f, 27.000000f, 26.218750f, 27.218750f, 28.218750f, + 28.000000f, 29.000000f, 30.000000f, 29.500000f, 30.500000f, 31.500000f, + 31.000000f, 32.000000f, 33.000000f, 32.781250f, 33.781250f, 34.781250f, + 34.000000f, 35.000000f, 36.000000f, 34.281250f, 35.281250f, 36.281250f, + 26.125000f, 27.125000f, 28.125000f, 27.343750f, 28.343750f, 29.343750f, + 29.125000f, 30.125000f, 31.125000f, 30.625000f, 31.625000f, 32.625000f, + 32.125000f, 33.125000f, 34.125000f, 33.906250f, 34.906250f, 35.906250f, + 35.125000f, 36.125000f, 37.125000f, 35.406250f, 36.406250f, 37.406250f}); + input.linspace(1); + auto size = NDArrayFactory::create({6, 8}); + sd::ops::resize_bicubic op; + auto results = op.evaluate({&input, &size}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + + // result.printBuffer("Resized to 6x8"); + // expected.printBuffer("Expect for 6x8"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); +} - sd::ops::log_loss_grad op; - auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {3}); +TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test5) { + NDArray input = NDArrayFactory::create('c', {1, 4, 4, 3}); + NDArray expected = NDArrayFactory::create( + 'c', {1, 8, 8, 3}, + { + 1.000000f, 2.000000f, 3.000000f, 2.218750f, 3.218750f, + 4.218750f, 4.000000f, 5.000000f, 6.000000f, 5.500000f, + 6.500000f, 7.500000f, 7.000000f, 8.000000f, 9.000000f, + 8.781250f, 9.781250f, 10.781250f, 10.000000f, 11.000000f, + 12.000000f, 10.281250f, 11.281250f, 12.281250f, 5.875000f, + 6.875000f, 7.875000f, 7.093750f, 8.093750f, 9.093750f, + 8.875000f, 9.875000f, 10.875000f, 10.375000f, 11.375000f, + 12.375000f, 11.875000f, 12.875000f, 13.875000f, 13.656250f, + 14.656250f, 15.656250f, 14.875000f, 15.875000f, 16.875000f, + 15.156250f, 16.156250f, 17.156250f, 13.000000f, 14.000000f, + 15.000000f, 14.218750f, 15.218750f, 16.218750f, 16.000000f, + 17.000000f, 18.000000f, 17.500000f, 18.500000f, 19.500000f, + 19.000000f, 20.000000f, 21.000000f, 20.781250f, 21.781250f, + 22.781250f, 22.000000f, 23.000000f, 24.000000f, 22.281250f, + 23.281250f, 24.281250f, 19.000000f, 20.000000f, 21.000000f, + 20.218750f, 21.218750f, 22.218750f, 22.000000f, 23.000000f, + 24.000000f, 23.500000f, 24.500000f, 25.500000f, 25.000000f, + 26.000000f, 27.000000f, 26.781250f, 27.781250f, 28.781250f, + 28.000000f, 29.000000f, 30.000000f, 28.281250f, 29.281250f, + 30.281250f, 25.000000f, 26.000000f, 27.000000f, 26.218750f, + 27.218750f, 28.218750f, 28.000000f, 29.000000f, 30.000000f, + 29.500000f, 30.500000f, 31.500000f, 31.000000f, 32.000000f, + 33.000000f, 32.781250f, 33.781250f, 34.781250f, 34.000000f, + 35.000000f, 36.000000f, 34.281250f, 35.281250f, 36.281250f, + 32.125000f, 33.125000f, 34.125000f, 33.343750f, 34.343750f, + 35.343750f, 35.125000f, 36.125000f, 37.125000f, 36.625000f, + 37.625000f, 38.625000f, 38.125000f, 39.125000f, 40.125000f, + 39.906250f, 40.906250f, 41.906250f, 41.125000f, 42.125000f, + 43.125000f, 41.406250f, 42.406250f, 43.406250f, 37.000000f, + 38.000000f, 39.000000f, 38.218750f, 39.218750f, 40.218750f, + 40.000000f, 41.000000f, 42.000000f, 41.500000f, 42.500000f, + 43.500000f, 43.000000f, 44.000000f, 45.000000f, 44.781250f, + 45.781250f, 46.781250f, 46.000000f, 47.000000f, 48.000000f, + 46.281250f, 47.281250f, 48.281250f, 38.125000f, 39.125000f, + 40.125000f, 39.343750f, 40.343750f, 41.343750f, 41.125000f, + 42.125000f, 43.125000f, 42.625000f, 43.625000f, 44.625000f, + 44.125000f, 45.125000f, 46.125000f, 45.906250f, 46.906250f, + 47.906250f, 47.125000f, 48.125000f, 49.125000f, 47.406250f, + 48.406250f, 49.406250f, + }); + input.linspace(1); + auto size = NDArrayFactory::create({8, 8}); + sd::ops::resize_bicubic op; + auto results = op.evaluate({&input, &size}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + + // result.printBuffer("Resized to 8x8"); + // expected.printBuffer("Expect for 8x8"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); +} - ASSERT_EQ(ND4J_STATUS_OK, results.status()); +TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test6) { + NDArray input = NDArrayFactory::create( + 'c', {7, 7, 1}, + {1.f, 2.1f, 3.15f, 4.2f, 5.15f, 6.1f, 7.f, 8.f, 9.1f, 10.f, + 11.f, 12.9f, 13.1f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, + 21.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 28.f, 30.f, 31.f, + 32.f, 33.f, 34.f, 35.f, 36.f, 37.f, 38.f, 39.f, 40.f, 41.f, + 42.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 49.f, 50.f}); + + NDArray expected = NDArrayFactory::create( + 'c', {30, 30, 1}, + {1.000000f, 1.197616f, 1.417436f, 1.677577f, 1.996158f, 2.328327f, + 2.550918f, 2.736061f, 2.965541f, 3.292965f, 3.544151f, 3.738035f, + 3.948995f, 4.248106f, 4.507379f, 4.684374f, 4.857284f, 5.104302f, + 5.386991f, 5.581401f, 5.753962f, 5.974285f, 6.272836f, 6.520426f, + 6.718899f, 6.887104f, 7.039068f, 7.099216f, 7.078424f, 7.028189f, + 2.247592f, 2.446947f, 2.669489f, 2.931238f, 3.248216f, 3.574534f, + 3.789310f, 3.965697f, 4.186417f, 4.504653f, 4.740569f, 4.921706f, + 5.133866f, 5.459533f, 5.774461f, 6.019787f, 6.254011f, 6.535633f, + 6.809730f, 6.960779f, 7.074942f, 7.241601f, 7.509489f, 7.749949f, + 7.954571f, 8.131972f, 8.286526f, 8.346463f, 8.325745f, 8.275683f, + 3.628684f, 3.830573f, 4.056959f, 4.321157f, 4.636486f, 4.955650f, + 5.160583f, 5.325847f, 5.535462f, 5.842160f, 6.058749f, 6.223753f, + 6.437597f, 6.797369f, 7.183604f, 7.516402f, 7.829034f, 8.154773f, + 8.417635f, 8.512958f, 8.552100f, 8.649708f, 8.877880f, 9.108794f, + 9.320926f, 9.509781f, 9.667375f, 9.726940f, 9.706349f, 9.656599f, + 5.276778f, 5.480438f, 5.709702f, 5.975448f, 6.288551f, 6.600570f, + 6.796207f, 6.951142f, 7.150400f, 7.446143f, 7.644651f, 7.794562f, + 8.009684f, 8.400473f, 8.851847f, 9.264690f, 9.649218f, 10.015648f, + 10.268647f, 10.313368f, 10.284327f, 10.319379f, 10.512033f, 10.734956f, + 10.954604f, 11.154507f, 11.315369f, 11.374779f, 11.354242f, 11.304622f, + 7.325373f, 7.528484f, 7.757575f, 8.022221f, 8.331997f, 8.638187f, + 8.827649f, 8.976217f, 9.168955f, 9.457260f, 9.644237f, 9.784517f, + 9.999621f, 10.407702f, 10.896234f, 11.355122f, 11.781423f, 12.172186f, + 12.420712f, 12.437449f, 12.370511f, 12.371386f, 12.545973f, 12.766424f, + 12.992249f, 13.200120f, 13.364252f, 13.424109f, 13.403420f, 13.353425f, + 9.493208f, 9.692467f, 9.916944f, 10.176801f, 10.482199f, 10.785470f, + 10.974367f, 11.123442f, 11.316370f, 11.603645f, 11.790616f, 11.930889f, + 12.144082f, 12.546447f, 13.024898f, 13.472300f, 13.889232f, 14.276275f, + 14.528972f, 14.555555f, 14.501450f, 14.515459f, 14.700572f, 14.927055f, + 15.156046f, 15.366046f, 15.532901f, 15.594008f, 15.572885f, 15.521847f, + 10.970133f, 11.163599f, 11.380694f, 11.633735f, 11.935032f, 12.238887f, + 12.432540f, 12.588294f, 12.787534f, 13.079956f, 13.277520f, 13.426631f, + 13.636713f, 14.013844f, 14.441672f, 14.827978f, 15.191209f, 15.549808f, + 15.813430f, 15.881828f, 15.883522f, 15.950411f, 16.169330f, 16.407940f, + 16.636436f, 16.842583f, 17.010887f, 17.073630f, 17.051940f, 16.999537f, + 12.219155f, 12.406129f, 12.614796f, 12.860335f, 13.157928f, 13.464224f, + 13.665207f, 13.830567f, 14.039036f, 14.339629f, 14.552863f, 14.715049f, + 14.921564f, 15.264454f, 15.622843f, 15.924977f, 16.213829f, 16.532364f, + 16.809900f, 16.934835f, 17.012146f, 17.150164f, 17.413412f, 17.666712f, + 17.892765f, 18.092070f, 18.261044f, 18.325531f, 18.303238f, 18.249378f, + 13.766397f, 13.947391f, 14.148263f, 14.386917f, 14.681246f, 14.990087f, + 15.198166f, 15.372728f, 15.590062f, 15.898583f, 16.126892f, 16.301655f, + 16.504870f, 16.815214f, 17.107498f, 17.329458f, 17.547403f, 17.827654f, + 18.118288f, 18.296928f, 18.446100f, 18.651634f, 18.956806f, 19.223820f, + 19.447308f, 19.639887f, 19.809319f, 19.875397f, 19.852556f, 19.797365f, + 15.941937f, 16.118704f, 16.314133f, 16.547867f, 16.839561f, 17.149540f, + 17.361883f, 17.542162f, 17.764957f, 18.078188f, 18.315733f, 18.498205f, + 18.699116f, 18.988684f, 19.238989f, 19.410137f, 19.583265f, 19.839512f, + 20.138780f, 20.351770f, 20.546844f, 20.795671f, 21.128067f, 21.404358f, + 21.626736f, 21.815500f, 21.985610f, 22.052843f, 22.029604f, 21.973448f, + 17.535220f, 17.710770f, 17.904636f, 18.136950f, 18.427840f, 18.738056f, + 18.951529f, 19.133352f, 19.357613f, 19.672083f, 19.912102f, 20.096638f, + 20.296894f, 20.580765f, 20.819603f, 20.976887f, 21.137802f, 21.387535f, + 21.689209f, 21.911621f, 22.119276f, 22.379990f, 22.719910f, 22.998823f, + 23.220970f, 23.408760f, 23.579110f, 23.646685f, 23.623325f, 23.566887f, + 18.746353f, 18.922657f, 19.117487f, 19.350685f, 19.642070f, 19.952137f, + 20.164913f, 20.345781f, 20.569134f, 20.882840f, 21.121330f, 21.304590f, + 21.505253f, 21.792645f, 22.038572f, 22.204426f, 22.372890f, 22.626648f, + 22.926834f, 23.143423f, 23.343302f, 23.596668f, 23.931936f, 24.209232f, + 24.431519f, 24.619913f, 24.790110f, 24.857473f, 24.834190f, 24.777927f, + 20.166560f, 20.344206f, 20.540766f, 20.775532f, 21.067804f, 21.377607f, + 21.589132f, 21.768297f, 21.990030f, 22.302366f, 22.538124f, 22.719105f, + 22.920494f, 23.214176f, 23.472767f, 23.653934f, 23.835890f, 24.096842f, + 24.394371f, 24.600555f, 24.786541f, 25.026773f, 25.353731f, 25.628130f, + 25.850672f, 26.040140f, 26.210072f, 26.277063f, 26.253906f, 26.197956f, + 22.363024f, 22.541250f, 22.738552f, 22.973991f, 23.266647f, 23.576340f, + 23.787327f, 23.965760f, 24.186796f, 24.498543f, 24.733124f, 24.913122f, + 25.114826f, 25.411213f, 25.675262f, 25.863028f, 26.050789f, 26.314838f, + 26.611223f, 26.812925f, 26.992926f, 27.227505f, 27.550882f, 27.824034f, + 28.046684f, 28.236614f, 28.406433f, 28.473265f, 28.450163f, 28.394344f, + 24.429443f, 24.607670f, 24.804970f, 25.040410f, 25.333065f, 25.642756f, + 25.853743f, 26.032173f, 26.253210f, 26.564959f, 26.799540f, 26.979540f, + 27.181242f, 27.477630f, 27.741680f, 27.929441f, 28.117207f, 28.381254f, + 28.677637f, 28.879343f, 29.059345f, 29.293922f, 29.617298f, 29.890451f, + 30.113104f, 30.303034f, 30.472853f, 30.539684f, 30.516582f, 30.460762f, + 26.000000f, 26.178228f, 26.375526f, 26.610970f, 26.903624f, 27.213314f, + 27.424305f, 27.602734f, 27.823772f, 28.135519f, 28.370100f, 28.550098f, + 28.751800f, 29.048190f, 29.312237f, 29.500000f, 29.687763f, 29.951813f, + 30.248200f, 30.449903f, 30.629902f, 30.864483f, 31.187859f, 31.461012f, + 31.683659f, 31.873592f, 32.043407f, 32.110240f, 32.087135f, 32.031320f, + 27.570559f, 27.748787f, 27.946087f, 28.181528f, 28.474184f, 28.783876f, + 28.994865f, 29.173294f, 29.394330f, 29.706080f, 29.940659f, 30.120655f, + 30.322360f, 30.618746f, 30.882797f, 31.070557f, 31.258320f, 31.522371f, + 31.818754f, 32.020460f, 32.200460f, 32.435040f, 32.758415f, 33.031567f, + 33.254220f, 33.444150f, 33.613964f, 33.680794f, 33.657696f, 33.601880f, + 29.636976f, 29.815207f, 30.012500f, 30.247944f, 30.540600f, 30.850290f, + 31.061283f, 31.239712f, 31.460750f, 31.772500f, 32.007080f, 32.187077f, + 32.388780f, 32.685165f, 32.949215f, 33.136980f, 33.324740f, 33.588790f, + 33.885178f, 34.086884f, 34.266880f, 34.501457f, 34.824837f, 35.097990f, + 35.320637f, 35.510574f, 35.680390f, 35.747215f, 35.724117f, 35.668300f, + 31.833440f, 32.011665f, 32.208970f, 32.444412f, 32.737070f, 33.046757f, + 33.257744f, 33.436176f, 33.657207f, 33.968960f, 34.203537f, 34.383537f, + 34.585240f, 34.881630f, 35.145676f, 35.333440f, 35.521206f, 35.785255f, + 36.081642f, 36.283340f, 36.463340f, 36.697920f, 37.021297f, 37.294453f, + 37.517097f, 37.707027f, 37.876846f, 37.943680f, 37.920578f, 37.864758f, + 33.253647f, 33.431873f, 33.629170f, 33.864613f, 34.157270f, 34.466957f, + 34.677948f, 34.856377f, 35.077415f, 35.389160f, 35.623745f, 35.803745f, + 36.005447f, 36.301834f, 36.565884f, 36.753647f, 36.941406f, 37.205456f, + 37.501840f, 37.703545f, 37.883545f, 38.118122f, 38.441500f, 38.714653f, + 38.937300f, 39.127235f, 39.297054f, 39.363884f, 39.340782f, 39.284960f, + 34.464783f, 34.643010f, 34.840305f, 35.075752f, 35.368404f, 35.678100f, + 35.889088f, 36.067516f, 36.288550f, 36.600300f, 36.834885f, 37.014877f, + 37.216583f, 37.512970f, 37.777020f, 37.964783f, 38.152546f, 38.416595f, + 38.712980f, 38.914684f, 39.094685f, 39.329260f, 39.652645f, 39.925793f, + 40.148440f, 40.338375f, 40.508194f, 40.575024f, 40.551920f, 40.496105f, + 36.058067f, 36.236290f, 36.433590f, 36.669033f, 36.961685f, 37.271378f, + 37.482370f, 37.660800f, 37.881836f, 38.193590f, 38.428170f, 38.608162f, + 38.809868f, 39.106250f, 39.370300f, 39.558064f, 39.745830f, 40.009880f, + 40.306267f, 40.507970f, 40.687970f, 40.922550f, 41.245926f, 41.519077f, + 41.741722f, 41.931652f, 42.101475f, 42.168304f, 42.145203f, 42.089386f, + 38.315002f, 38.493233f, 38.690533f, 38.925976f, 39.218628f, 39.528320f, + 39.739307f, 39.917736f, 40.138775f, 40.450520f, 40.685104f, 40.865097f, + 41.066803f, 41.363190f, 41.627243f, 41.815002f, 42.002766f, 42.266820f, + 42.563200f, 42.764908f, 42.944904f, 43.179485f, 43.502860f, 43.776016f, + 43.998665f, 44.188595f, 44.358418f, 44.425247f, 44.402145f, 44.346330f, + 40.227080f, 40.405310f, 40.602608f, 40.838050f, 41.130707f, 41.440395f, + 41.651382f, 41.829820f, 42.050854f, 42.362600f, 42.597183f, 42.777180f, + 42.978880f, 43.275270f, 43.539320f, 43.727080f, 43.914845f, 44.178894f, + 44.475280f, 44.676983f, 44.856983f, 45.091560f, 45.414940f, 45.688090f, + 45.910740f, 46.100674f, 46.270493f, 46.337322f, 46.314220f, 46.258400f, + 41.785618f, 41.963844f, 42.161144f, 42.396584f, 42.689240f, 42.998936f, + 43.209923f, 43.388355f, 43.609394f, 43.921143f, 44.155720f, 44.335716f, + 44.537420f, 44.833805f, 45.097860f, 45.285614f, 45.473377f, 45.737427f, + 46.033817f, 46.235523f, 46.415524f, 46.650105f, 46.973476f, 47.246630f, + 47.469276f, 47.659210f, 47.829030f, 47.895855f, 47.872753f, 47.816940f, + 43.115140f, 43.293365f, 43.490665f, 43.726105f, 44.018764f, 44.328457f, + 44.539444f, 44.717873f, 44.938910f, 45.250660f, 45.485240f, 45.665237f, + 45.866940f, 46.163326f, 46.427376f, 46.615143f, 46.802902f, 47.066956f, + 47.363342f, 47.565050f, 47.745050f, 47.979626f, 48.302998f, 48.576153f, + 48.798798f, 48.988730f, 49.158546f, 49.225376f, 49.202282f, 49.146458f, + 44.303867f, 44.482094f, 44.679394f, 44.914833f, 45.207493f, 45.517180f, + 45.728170f, 45.906600f, 46.127640f, 46.439384f, 46.673965f, 46.853966f, + 47.055668f, 47.352055f, 47.616100f, 47.803867f, 47.991630f, 48.255680f, + 48.552063f, 48.753770f, 48.933773f, 49.168350f, 49.491726f, 49.764877f, + 49.987526f, 50.177460f, 50.347275f, 50.414100f, 50.391006f, 50.335186f, + 44.771675f, 44.949905f, 45.147200f, 45.382645f, 45.675300f, 45.984990f, + 46.195976f, 46.374413f, 46.595448f, 46.907196f, 47.141773f, 47.321774f, + 47.523476f, 47.819862f, 48.083910f, 48.271680f, 48.459446f, 48.723490f, + 49.019882f, 49.221580f, 49.401585f, 49.636160f, 49.959538f, 50.232693f, + 50.455338f, 50.645270f, 50.815090f, 50.881920f, 50.858818f, 50.803000f, + 44.609966f, 44.788193f, 44.985493f, 45.220936f, 45.513590f, 45.823280f, + 46.034270f, 46.212700f, 46.433743f, 46.745490f, 46.980070f, 47.160065f, + 47.361770f, 47.658157f, 47.922207f, 48.109970f, 48.297733f, 48.561783f, + 48.858166f, 49.059875f, 49.239872f, 49.474450f, 49.797830f, 50.070980f, + 50.293625f, 50.483560f, 50.653378f, 50.720203f, 50.697100f, 50.641280f, + 44.219246f, 44.397472f, 44.594772f, 44.830210f, 45.122868f, 45.432560f, + 45.643543f, 45.821980f, 46.043020f, 46.354763f, 46.589344f, 46.769340f, + 46.971046f, 47.267433f, 47.531483f, 47.719242f, 47.907005f, 48.171050f, + 48.467438f, 48.669140f, 48.849144f, 49.083720f, 49.407100f, 49.680256f, + 49.902905f, 50.092834f, 50.262653f, 50.329483f, 50.306380f, 50.250570f}); + + auto size = NDArrayFactory::create({30, 30}); + sd::ops::resize_bicubic op; + auto results = op.evaluate({&input, &size}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto result = results.at(0); + + // result.printBuffer("Resized to 30x30"); + // expected.printBuffer("Expect for 30x30"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); +} - auto *dLdw = results.at(1); +TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test7) { + NDArray input = NDArrayFactory::create( + 'c', {2, 5, 5, 1}, + {0.2303, 0.7950, 0.8171, 0.0451, 0.3690, 0.6846, 0.2727, 0.2770, 0.2381, + 0.9511, 0.4116, 0.3997, 0.4075, 0.6275, 0.8018, 0.0678, 0.6221, 0.2982, + 0.1524, 0.2613, 0.7425, 0.6036, 0.7926, 0.5838, 0.1361, 0.4154, 0.3634, + 0.3741, 0.2088, 0.2989, 0.3982, 0.5618, 0.7266, 0.1089, 0.2922, 0.3306, + 0.2869, 0.6638, 0.3091, 0.9312, 0.0240, 0.2893, 0.5632, 0.9625, 0.4189, + 0.3854, 0.2743, 0.6754, 0.8820, 0.8699}); + + NDArray expected = NDArrayFactory::create( + 'c', {2, 9, 9, 1}, + {0.2303f, 0.54569f, 0.840649f, 0.92725444f, 0.65660673f, + 0.16641647f, 0.06117659f, 0.33279106f, 0.4023279f, 0.5139505f, + 0.49821317f, 0.4906872f, 0.537642f, 0.4070102f, 0.13030615f, + 0.258801f, 0.65352744f, 0.773368f, 0.69225276f, 0.44177493f, + 0.21910316f, 0.22368976f, 0.24221404f, 0.21399781f, 0.5114972f, + 0.9169859f, 1.0511527f, 0.5608501f, 0.41315168f, 0.2913824f, + 0.2966933f, 0.38585684f, 0.48849702f, 0.71013063f, 0.9086001f, + 0.9794303f, 0.29625386f, 0.39427578f, 0.45971435f, 0.39693952f, + 0.40860707f, 0.51061106f, 0.6181093f, 0.67309624f, 0.69564015f, + 0.06012487f, 0.3863805f, 0.58993465f, 0.40679216f, 0.22607432f, + 0.20093678f, 0.25901243f, 0.3615362f, 0.39371052f, 0.24176767f, + 0.4868709f, 0.650651f, 0.5493148f, 0.3825456f, 0.27788478f, + 0.18927254f, 0.16692996f, 0.15432167f, 0.677519f, 0.6236242f, + 0.61700624f, 0.7214321f, 0.7307374f, 0.6251454f, 0.3924176f, + 0.17802659f, 0.10231908f, 0.81192374f, 0.66878575f, 0.6118803f, + 0.7797006f, 0.8396968f, 0.72889954f, 0.44547448f, 0.16794783f, + 0.07125802f, 0.4154f, 0.38504714f, 0.3623221f, 0.3862173f, + 0.3397379f, 0.23285517f, 0.21876639f, 0.2892362f, 0.30817088f, + 0.41268015f, 0.45587808f, 0.51991886f, 0.60977113f, 0.49489656f, + 0.21313031f, 0.11297428f, 0.2167207f, 0.23940037f, 0.39337245f, + 0.46112412f, 0.583034f, 0.76207364f, 0.6326203f, 0.22189438f, + 0.12071565f, 0.3275853f, 0.3794855f, 0.38497013f, 0.35049653f, + 0.41895086f, 0.671095f, 0.62119365f, 0.22362521f, 0.30189657f, + 0.72530353f, 0.85048175f, 0.2524255f, 0.2182264f, 0.2964637f, + 0.5361996f, 0.6255393f, 0.46424767f, 0.5741281f, 0.8408146f, + 0.92403257f, 0.04648584f, 0.14959256f, 0.32215607f, 0.46194845f, + 0.6642166f, 0.83560026f, 0.7663391f, 0.5284251f, 0.4573109f, + 0.10357999f, 0.17442937f, 0.32116935f, 0.45530772f, 0.7163773f, + 0.9856574f, 0.8976148f, 0.5538923f, 0.45173654f, 0.34958175f, + 0.2680429f, 0.30470955f, 0.51233786f, 0.75128907f, 0.86736864f, + 0.8982046f, 0.83254474f, 0.8168574f, 0.4225865f, 0.2956836f, + 0.29948136f, 0.5276342f, 0.76461166f, 0.8442875f, 0.907862f, + 0.9139262f, 0.92068815f}); + auto size = NDArrayFactory::create({9, 9}); + sd::ops::resize_bicubic op; + auto results = op.evaluate({&input, &size}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + + // result.printBuffer("Resized to 9x9"); + // expected.printBuffer("Expect for 9x9"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); +} - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); +TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test8) { + NDArray input = NDArrayFactory::create( + 'c', {2, 5, 5, 1}, + {0.23028551377579154, 0.7949972231516509, 0.8171307820461517, + 0.04507309923418412, 0.3689673597428338, 0.6845757584903018, + 0.27268547668219667, 0.2770196372806053, 0.2381478370531429, + 0.9511201914609859, 0.41160882670429033, 0.3997152563642703, + 0.4074505147711718, 0.6274595060113246, 0.8017922711300232, + 0.06782045852179475, 0.6220772280691722, 0.2982335327629251, + 0.1523603480424196, 0.2612986044295986, 0.7424762244324299, + 0.6036156464824591, 0.7926371071102005, 0.5838270656432538, + 0.13607200219168547, 0.4154002170215956, 0.36340617544852116, + 0.37405031188276827, 0.20880251686544882, 0.298919946410666, + 0.39820758164277126, 0.5617728968896589, 0.72660225993937, + 0.10888245916813699, 0.29215797784445496, 0.3305531351746034, + 0.28693451964931715, 0.6637635348315494, 0.30913418229827583, + 0.9312186188801752, 0.0239594182399363, 0.2892942758780874, + 0.5631691110629038, 0.9625499752246309, 0.4189439089689968, + 0.3854304088214935, 0.27426304203925045, 0.6754051704648238, + 0.8820362490795286, 0.8699337744328859}); + + auto testData = NDArrayFactory::create( + 'c', {2, 9, 9, 1}, + {0.230286514f, 0.510566354f, 0.794997215f, 0.931386113f, 0.817130804f, + 0.402811885f, 0.045073099f, 0.134639814f, 0.368967354f, 0.483021289f, + 0.501266003f, 0.521932304f, 0.572325349f, 0.534847379f, 0.267853439f, + 0.105112493f, 0.349290252f, 0.674043298f, 0.684575737f, 0.478224277f, + 0.272685468f, 0.239882097f, 0.27701965f, 0.191148892f, 0.23814784f, + 0.590989769f, 0.951120198f, 0.622912169f, 0.441326082f, 0.266387194f, + 0.232538164f, 0.301838756f, 0.356378645f, 0.495445013f, 0.756725252f, + 0.981704295f, 0.411608815f, 0.40493685f, 0.399715245f, 0.381842017f, + 0.407450527f, 0.501836538f, 0.627459526f, 0.735251725f, 0.801792264f, + 0.150875032f, 0.357000858f, 0.524536073f, 0.450354964f, 0.318719596f, + 0.319606483f, 0.385957927f, 0.46392554f, 0.529285908f, 0.06782046f, + 0.375309169f, 0.622077227f, 0.525792599f, 0.298233539f, 0.184723631f, + 0.15236035f, 0.193153858f, 0.261298597f, + + 0.372918189f, 0.512539625f, 0.63369292f, 0.628733814f, 0.535196245f, + 0.436597466f, 0.323553175f, 0.215942055f, 0.148014024f, 0.742476225f, + 0.655325174f, 0.603615642f, 0.704684138f, 0.79263711f, 0.747929871f, + 0.583827078f, 0.340373576f, 0.136071995f, 0.415400207f, 0.388405323f, + 0.363406181f, 0.379345775f, 0.374050319f, 0.28397581f, 0.208802521f, + 0.238369256f, 0.298919946f, 0.413146496f, 0.444389015f, 0.488355637f, + 0.568351328f, 0.556217432f, 0.345546633f, 0.140068889f, 0.148834035f, + 0.23562704f, 0.398207575f, 0.464537472f, 0.561772883f, 0.717433035f, + 0.726602256f, 0.416013002f, 0.108882457f, 0.142608985f, 0.292157978f, + 0.391511708f, 0.389470309f, 0.442729384f, 0.651181757f, 0.737665415f, + 0.41685915f, 0.138383076f, 0.342548877f, 0.659080088f, + + 0.330553144f, 0.273416102f, 0.286934525f, 0.50450629f, 0.663763523f, + 0.463456154f, 0.309134185f, 0.586929917f, 0.931218624f, 0.137025774f, + 0.169145152f, 0.263757467f, 0.436182201f, 0.597053051f, 0.657990932f, + 0.662163854f, 0.68354249f, 0.692712903f, 0.023959421f, 0.130951077f, + 0.289294273f, 0.413664877f, 0.563169122f, 0.839498401f, 0.962549984f, + 0.728188932f, 0.418943912f, 0.175951749f, 0.198239252f, 0.281999886f, + 0.420836329f, 0.609856486f, 0.863734365f, 0.983550847f, 0.825015843f, + 0.596413136f, 0.385430396f, 0.292239636f, 0.274263054f, 0.445040524f, + 0.675405145f, 0.817462444f, 0.882036269f, 0.895356655f, 0.869933784f}); + + auto size = NDArrayFactory::create({9, 9}); + sd::ops::resize_bicubic op; + auto results = op.evaluate({&input, &size}, {}, {}, {true, false}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + + // result.printBuffer("Resized to 9x9"); + // testData.printBuffer("Expect for 9x9"); + ASSERT_TRUE(testData.isSameShape(result)); + ASSERT_TRUE(testData.equalsTo(result)); } -/////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests11, log_loss_grad_test12) { - NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights('c', {2,3,4}, sd::DataType::DOUBLE); +TEST_F(DeclarableOpsTests11, ImageResizeArea_Test1) { + NDArray input = NDArrayFactory::create('c', {1, 3, 3, 4}); + NDArray expected = NDArrayFactory::create( + 'c', {1, 6, 6, 4}, + {1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, + 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 9.f, 10.f, 11.f, 12.f, - NDArray dLdpExp('c', {2,3,4}, { 0. , 0. , 0. , 0. ,-0.75 ,-0.789473,-0.833333, -0.882353,-0.9375 ,-1. ,-1.071428, -1.153846, - -1.25 ,-1.363636,-1.5 , -1.666666,-1.875 ,-2.142857,-2.499999, -2.999999,-3.749997,-4.999997,-7.499993,-14.999956}); - NDArray dLdwExp('c', {2,3,4}, {0.16094, 0.2484 , 0.30526, 0.34036, 0.35773, 0.35953, 0.34699, 0.32079, 0.28123, 0.22827, 0.16163, 0.08072, - -0.01533,-0.12776,-0.25828,-0.40921,-0.58373,-0.78637,-1.02369,-1.30582,-1.64951,-2.08566,-2.68241,-3.65272}); - NDArray dLdlExp('c', {2,3,4}, {0. , 0. , 0. , 0. , 0.03466, 0.02882, 0.02361, 0.01884, 0.01438, 0.01014, 0.00603, 0.002 , - -0.002 ,-0.00603,-0.01014,-0.01438,-0.01884,-0.02361,-0.02882,-0.03466,-0.04146,-0.04981,-0.06106,-0.07945}); + 1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, + 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 9.f, 10.f, 11.f, 12.f, + 13.f, 14.f, 15.f, 16.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, + 17.f, 18.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 21.f, 22.f, 23.f, 24.f, - predictions.linspace(0.04, 0.04); - labels.linspace(1); - weights.assign(0.5); - weights.r(0) = 0.; - weights.r(1) = 0.; - weights.r(2) = 0.; - weights.r(3) = 0.; + 13.f, 14.f, 15.f, 16.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, + 17.f, 18.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 21.f, 22.f, 23.f, 24.f, + 25.f, 26.f, 27.f, 28.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 31.f, 32.f, + 29.f, 30.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 33.f, 34.f, 35.f, 36.f, - sd::ops::log_loss_grad op; - auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {3}); + 25.f, 26.f, 27.f, 28.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 31.f, 32.f, + 29.f, 30.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 33.f, 34.f, 35.f, 36.f}); + input.linspace(1); + auto size = NDArrayFactory::create({6, 6}); + sd::ops::resize_area op; + auto results = op.evaluate({&input, &size}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *dLdp = results.at(0); - auto *dLdw = results.at(1); - auto *dLdl = results.at(2); + auto result = results.at(0); - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); - ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); - ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); + // result.printBuffer("Area Resized to 6x6"); + // expected.printBuffer("Area Expect for 6x6"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } -/////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests11, log_loss_grad_test13) { - NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights('c', {2,3,1}, sd::DataType::DOUBLE); +TEST_F(DeclarableOpsTests11, ImageResizeArea_Test2) { + NDArray input = NDArrayFactory::create('c', {1, 3, 3, 1}); + NDArray expected = NDArrayFactory::create( + 'c', {1, 6, 6, 1}, + {1.f, 1.f, 2.f, 2.f, 3.f, 3.f, 1.f, 1.f, 2.f, 2.f, 3.f, 3.f, + 4.f, 4.f, 5.f, 5.f, 6.f, 6.f, 4.f, 4.f, 5.f, 5.f, 6.f, 6.f, + 7.f, 7.f, 8.f, 8.f, 9.f, 9.f, 7.f, 7.f, 8.f, 8.f, 9.f, 9.f}); + input.linspace(1); + auto size = NDArrayFactory::create({6, 6}); + sd::ops::resize_area op; + auto results = op.evaluate({&input, &size}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + + // result.printBuffer("Area Resized to 6x6"); + // expected.printBuffer("Area Expect for 6x6"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); +} - NDArray dLdpExp('c', {2,3,4}, {0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , - -2.08333,-2.27273, -2.5 , -2.77778,-3.125 ,-3.57143, -4.16667, -5. ,-6.25 ,-8.33333,-12.49999,-24.99993}); - NDArray dLdwExp('c', {2,3,1}, {1.75828, 2.30839, 1.25309, -1.35098, -6.16602,-16.78383}); - NDArray dLdlExp('c', {2,3,4}, {0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , - -0.00334,-0.01005,-0.01689,-0.02397,-0.03141,-0.03935,-0.04803,-0.05776,-0.06909,-0.08302,-0.10176,-0.13242}); +TEST_F(DeclarableOpsTests11, ImageResizeArea_Test3) { + NDArray input = NDArrayFactory::create('c', {1, 3, 3, 3}); + NDArray expected = NDArrayFactory::create( + 'c', {1, 6, 6, 3}, + {1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, + 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, + 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, + 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, + 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, + 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, + 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, + 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, + 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f}); + input.linspace(1); + auto size = NDArrayFactory::create({6, 6}); + sd::ops::resize_area op; + auto results = op.evaluate({&input, &size}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + + // result.printBuffer("Area Resized to 6x6"); + // expected.printBuffer("Area Expect for 6x6"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); +} - predictions.linspace(0.04, 0.04); - labels.linspace(1); - weights.assign(0.5); - weights.r(0) = 0.; - weights.r(1) = 0.; - weights.r(2) = 0.; +TEST_F(DeclarableOpsTests11, ImageResizeArea_Test4) { + NDArray input = NDArrayFactory::create( + 'c', {2, 3, 3, 3}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, + 19, 20, 21, 22, 23, 24, 25, 26, 27, 1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27}); + + NDArray expected = NDArrayFactory::create( + 'c', {2, 6, 6, 3}, + {1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, + 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, + 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, + 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, + 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, + 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, + 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, + 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, + 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, + + 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, + 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, + 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, + 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, + 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, + 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, + 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, + 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, + 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f}); + // input.linspace(1); + auto size = NDArrayFactory::create({6, 6}); + sd::ops::resize_area op; + auto results = op.evaluate({&input, &size}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + + // result.printBuffer("Area Resized to 6x6"); + // expected.printBuffer("Area Expect for 6x6"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); +} - sd::ops::log_loss_grad op; - auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {3}); +TEST_F(DeclarableOpsTests11, ImageResizeArea_Test5) { + NDArray input = NDArrayFactory::create( + 'c', {2, 3, 3, 3}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, + 19, 20, 21, 22, 23, 24, 25, 26, 27, 1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27}); + + NDArray expected = NDArrayFactory::create( + 'c', {2, 6, 6, 3}, + {1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, + 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, + 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, + 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, + 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, + 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, + 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, + 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, + 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, + + 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, + 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, + 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, + 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, + 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, + 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, + 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, + 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, + 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f}); + // input.linspace(1); + auto size = NDArrayFactory::create({6, 6}); + sd::ops::resize_area op; + auto results = op.evaluate({&input, &size}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + + // result.printBuffer("Area Resized to 6x6"); + // expected.printBuffer("Area Expect for 6x6"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); +} - ASSERT_EQ(ND4J_STATUS_OK, results.status()); +TEST_F(DeclarableOpsTests11, ImageResizeArea_Test6) { + NDArray input = NDArrayFactory::create( + 'c', {2, 3, 3, 1}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9}); + + NDArray expected = NDArrayFactory::create( + 'c', {2, 6, 6, 1}, + {1.f, 1.f, 1.5f, 2.f, 2.f, 3.f, 1.f, 1.f, 1.5f, 2.f, 2.f, 3.f, + 2.5f, 2.5f, 3.f, 3.5f, 3.5f, 4.5f, 4.f, 4.f, 4.5f, 5.f, 5.f, 6.f, + 4.f, 4.f, 4.5f, 5.f, 5.f, 6.f, 7.f, 7.f, 7.5f, 8.f, 8.f, 9.f, + + 1.f, 1.f, 1.5f, 2.f, 2.f, 3.f, 1.f, 1.f, 1.5f, 2.f, 2.f, 3.f, + 2.5f, 2.5f, 3.f, 3.5f, 3.5f, 4.5f, 4.f, 4.f, 4.5f, 5.f, 5.f, 6.f, + 4.f, 4.f, 4.5f, 5.f, 5.f, 6.f, 7.f, 7.f, 7.5f, 8.f, 8.f, 9.f}); + // input.linspace(1); + auto size = NDArrayFactory::create({6, 6}); + sd::ops::resize_area op; + auto results = op.evaluate({&input, &size}, {}, {}, {true}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + + // result.printBuffer("Area Resized to 6x6"); + // expected.printBuffer("Area Expect for 6x6"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); +} - auto *dLdp = results.at(0); - auto *dLdw = results.at(1); - auto *dLdl = results.at(2); +TEST_F(DeclarableOpsTests11, ImageResizeArea_Test7) { + NDArray input = NDArrayFactory::create( + 'c', {2, 3, 3, 1}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9}); + + NDArray expected = NDArrayFactory::create( + 'c', {2, 6, 6, 1}, + {1.f, 1.f, 1.5f, 2.f, 2.f, 3.f, 1.f, 1.f, 1.5f, 2.f, 2.f, 3.f, + 2.5f, 2.5f, 3.f, 3.5f, 3.5f, 4.5f, 4.f, 4.f, 4.5f, 5.f, 5.f, 6.f, + 4.f, 4.f, 4.5f, 5.f, 5.f, 6.f, 7.f, 7.f, 7.5f, 8.f, 8.f, 9.f, + + 1.f, 1.f, 1.5f, 2.f, 2.f, 3.f, 1.f, 1.f, 1.5f, 2.f, 2.f, 3.f, + 2.5f, 2.5f, 3.f, 3.5f, 3.5f, 4.5f, 4.f, 4.f, 4.5f, 5.f, 5.f, 6.f, + 4.f, 4.f, 4.5f, 5.f, 5.f, 6.f, 7.f, 7.f, 7.5f, 8.f, 8.f, 9.f}); + // input.linspace(1); + // auto size = NDArrayFactory::create({6, 6}); + sd::ops::resize_area op; + auto results = op.evaluate({&input}, {}, {6, 6}, {true}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + + // result.printBuffer("Area Resized to 6x6"); + // expected.printBuffer("Area Expect for 6x6"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); +} - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); - ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); - ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); +TEST_F(DeclarableOpsTests11, ImageResizeArea_Test8) { + NDArray input = NDArrayFactory::create('c', {1, 3, 3, 1}, + {1, 2, 3, 4, 5, 6, 7, 8, 9}); + + NDArray expected = NDArrayFactory::create( + 'c', {1, 6, 6, 1}, + {1.f, 1.f, 1.5f, 2.f, 2.f, 3.f, 1.f, 1.f, 1.5f, 2.f, 2.f, 3.f, + 2.5f, 2.5f, 3.f, 3.5f, 3.5f, 4.5f, 4.f, 4.f, 4.5f, 5.f, 5.f, 6.f, + 4.f, 4.f, 4.5f, 5.f, 5.f, 6.f, 7.f, 7.f, 7.5f, 8.f, 8.f, 9.f}); + // input.linspace(1); + // auto size = NDArrayFactory::create({6, 6}); + sd::ops::resize_area op; + auto results = op.evaluate({&input}, {}, {6, 6}, {true}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + + // result.printBuffer("Area Resized to 6x6"); + // expected.printBuffer("Area Expect for 6x6"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } -TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test1) { +TEST_F(DeclarableOpsTests11, ResizeImages_Test8) { - NDArray input = NDArrayFactory::create('c', {1, 7, 7, 1}, { - 1.f, 2.1f, 3.15f, 4.2f, 5.15f, 6.1f, 7.f, - 8.f, 9.1f, 10.f, 11.f, 12.9f, 13.1f, 14.f, - 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, 21.f, - 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 28.f, - 30.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, - 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 43.f, - 44.f, 45.f, 46.f, 47.f, 48.f, 49.f, 50.f + NDArray input = NDArrayFactory::create('c', {1, 3, 3, 1}, { + 1, 2, 3, 4, 5, 6, 7, 8, 9 + }); + + NDArray expected = NDArrayFactory::create('c', {1, 6, 6, 1}, { +// 1.f, 1.f, 2.f, 2.f, 3.f, 3.f, 1.f, 1.f, 2.f, 2.f, 3.f, 3.f, 4.f, 4.f, 5.f, 5.f, 6.f, 6.f, 4.f, 4.f, 5.f, 5.f, +// 6.f, 6.f, 7.f, 7.f, 8.f, 8.f, 9.f, 9.f, 7.f, 7.f, 8.f, 8.f, 9.f, 9.f + 1.f , 1.f , 1.5f, 2.f , 2.f, 3.f, 1.f , 1.f , 1.5f, 2.f , 2.f, 3.f, + 2.5f, 2.5f, 3.f, 3.5f, 3.5f, 4.5f, 4.f , 4.f , 4.5f , 5.f, 5.f, 6.f , + 4.f, 4.f, 4.5f , 5.f, 5.f, 6.f, 7.f , 7.f , 7.5f , 8.f , 8.f , 9.f }); - NDArray expected = NDArrayFactory::create('c', {1, 30, 30, 1}, { - 1.f, 1.1976162f, 1.4174359f, 1.6775769f, 1.9961575f, 2.3283265f, - 2.550918f, 2.7360606f, 2.9655411f, 3.2929654f, 3.5441515f, 3.7380352f, - 3.948995f, 4.248106f, 4.5073795f, 4.6843743f, 4.8572845f, 5.104302f, - 5.3869915f, 5.581401f, 5.7539616f, 5.974285f, 6.272836f, 6.5204263f, - 6.718899f, 6.8871036f, 7.039068f, 7.099216f, 7.0784245f, 7.0281887f, - 2.247592f, 2.446947f, 2.6694887f, 2.9312382f, 3.248216f, 3.5745337f, - 3.78931f, 3.9656973f, 4.186417f, 4.5046535f, 4.740569f, 4.9217057f, - 5.133866f, 5.459533f, 5.7744613f, 6.0197873f, 6.254011f, 6.535633f, - 6.8097296f, 6.9607787f, 7.0749416f, 7.241601f, 7.5094895f, 7.7499495f, - 7.954571f, 8.131972f, 8.286526f, 8.346463f, 8.325745f, 8.275683f, - 3.6286845f, 3.830573f, 4.0569587f, 4.3211575f, 4.6364856f, 4.9556503f, - 5.160583f, 5.3258467f, 5.535462f, 5.84216f, 6.058749f, 6.223753f, - 6.437597f, 6.797369f, 7.1836042f, 7.5164022f, 7.8290343f, 8.154773f, - 8.417635f, 8.512958f, 8.5521f, 8.649708f, 8.87788f, 9.108794f, - 9.320926f, 9.509781f, 9.667375f, 9.72694f, 9.706349f, 9.656599f, - 5.276778f, 5.480438f, 5.709702f, 5.9754477f, 6.288551f, 6.6005697f, - 6.796207f, 6.9511423f, 7.1503997f, 7.4461427f, 7.644651f, 7.794562f, - 8.009684f, 8.400473f, 8.851847f, 9.26469f, 9.649218f, 10.015648f, - 10.268647f, 10.313368f, 10.2843275f, 10.319379f, 10.512033f, 10.734956f, - 10.954604f, 11.154507f, 11.315369f, 11.374779f, 11.354242f, 11.304622f, - 7.325373f, 7.5284843f, 7.757575f, 8.022221f, 8.331997f, 8.638187f, - 8.827649f, 8.976217f, 9.168955f, 9.45726f, 9.6442375f, 9.784517f, - 9.999621f, 10.407702f, 10.896234f, 11.355122f, 11.781423f, 12.172186f, - 12.420712f, 12.4374485f, 12.370511f, 12.371386f, 12.545973f, 12.766424f, - 12.992249f, 13.20012f, 13.364252f, 13.424109f, 13.40342f, 13.353425f, - 9.493208f, 9.692467f, 9.9169445f, 10.176801f, 10.482199f, 10.78547f, - 10.974367f, 11.123442f, 11.31637f, 11.603645f, 11.790616f, 11.930889f, - 12.144082f, 12.546447f, 13.024898f, 13.4723f, 13.889232f, 14.276275f, - 14.528972f, 14.555555f, 14.50145f, 14.515459f, 14.700572f, 14.927055f, - 15.156046f, 15.366046f, 15.532901f, 15.594008f, 15.5728855f, 15.521847f, - 10.970133f, 11.163599f, 11.380694f, 11.633735f, 11.935032f, 12.238887f, - 12.43254f, 12.588294f, 12.787534f, 13.079956f, 13.27752f, 13.426631f, - 13.636713f, 14.013844f, 14.441672f, 14.827978f, 15.191209f, 15.549808f, - 15.81343f, 15.881828f, 15.883522f, 15.950411f, 16.16933f, 16.40794f, - 16.636436f, 16.842583f, 17.010887f, 17.07363f, 17.05194f, 16.999537f, - 12.219155f, 12.406129f, 12.614796f, 12.860335f, 13.157928f, 13.464224f, - 13.665207f, 13.830567f, 14.039036f, 14.339629f, 14.552863f, 14.715049f, - 14.921564f, 15.264454f, 15.622843f, 15.924977f, 16.213829f, 16.532364f, - 16.8099f, 16.934835f, 17.012146f, 17.150164f, 17.413412f, 17.666712f, - 17.892765f, 18.09207f, 18.261044f, 18.325531f, 18.303238f, 18.249378f, - 13.7663965f, 13.947391f, 14.148263f, 14.386917f, 14.681246f, 14.990087f, - 15.198166f, 15.372728f, 15.590062f, 15.898583f, 16.126892f, 16.301655f, - 16.50487f, 16.815214f, 17.107498f, 17.329458f, 17.547403f, 17.827654f, - 18.118288f, 18.296928f, 18.4461f, 18.651634f, 18.956806f, 19.22382f, - 19.447308f, 19.639887f, 19.809319f, 19.875397f, 19.852556f, 19.797365f, - 15.9419365f, 16.118704f, 16.314133f, 16.547867f, 16.839561f, 17.14954f, - 17.361883f, 17.542162f, 17.764957f, 18.078188f, 18.315733f, 18.498205f, - 18.699116f, 18.988684f, 19.238989f, 19.410137f, 19.583265f, 19.839512f, - 20.13878f, 20.35177f, 20.546844f, 20.795671f, 21.128067f, 21.404358f, - 21.626736f, 21.8155f, 21.98561f, 22.052843f, 22.029604f, 21.973448f, - 17.53522f, 17.71077f, 17.904636f, 18.13695f, 18.42784f, 18.738056f, - 18.951529f, 19.133352f, 19.357613f, 19.672083f, 19.912102f, 20.096638f, - 20.296894f, 20.580765f, 20.819603f, 20.976887f, 21.137802f, 21.387535f, - 21.689209f, 21.911621f, 22.119276f, 22.37999f, 22.71991f, 22.998823f, - 23.22097f, 23.40876f, 23.57911f, 23.646685f, 23.623325f, 23.566887f, - 18.746353f, 18.922657f, 19.117487f, 19.350685f, 19.64207f, 19.952137f, - 20.164913f, 20.345781f, 20.569134f, 20.88284f, 21.12133f, 21.30459f, - 21.505253f, 21.792645f, 22.038572f, 22.204426f, 22.37289f, 22.626648f, - 22.926834f, 23.143423f, 23.343302f, 23.596668f, 23.931936f, 24.209232f, - 24.431519f, 24.619913f, 24.79011f, 24.857473f, 24.83419f, 24.777927f, - 20.16656f, 20.344206f, 20.540766f, 20.775532f, 21.067804f, 21.377607f, - 21.589132f, 21.768297f, 21.99003f, 22.302366f, 22.538124f, 22.719105f, - 22.920494f, 23.214176f, 23.472767f, 23.653934f, 23.83589f, 24.096842f, - 24.394371f, 24.600555f, 24.786541f, 25.026773f, 25.353731f, 25.62813f, - 25.850672f, 26.04014f, 26.210072f, 26.277063f, 26.253906f, 26.197956f, - 22.363024f, 22.54125f, 22.738552f, 22.973991f, 23.266647f, 23.57634f, - 23.787327f, 23.96576f, 24.186796f, 24.498543f, 24.733124f, 24.913122f, - 25.114826f, 25.411213f, 25.675262f, 25.863028f, 26.050789f, 26.314838f, - 26.611223f, 26.812925f, 26.992926f, 27.227505f, 27.550882f, 27.824034f, - 28.046684f, 28.236614f, 28.406433f, 28.473265f, 28.450163f, 28.394344f, - 24.429443f, 24.60767f, 24.80497f, 25.04041f, 25.333065f, 25.642756f, - 25.853743f, 26.032173f, 26.25321f, 26.564959f, 26.79954f, 26.97954f, - 27.181242f, 27.47763f, 27.74168f, 27.929441f, 28.117207f, 28.381254f, - 28.677637f, 28.879343f, 29.059345f, 29.293922f, 29.617298f, 29.890451f, - 30.113104f, 30.303034f, 30.472853f, 30.539684f, 30.516582f, 30.460762f, - 26.f, 26.178228f, 26.375526f, 26.61097f, 26.903624f, 27.213314f, - 27.424305f, 27.602734f, 27.823772f, 28.135519f, 28.3701f, 28.550098f, - 28.7518f, 29.04819f, 29.312237f, 29.5f, 29.687763f, 29.951813f, - 30.2482f, 30.449903f, 30.629902f, 30.864483f, 31.187859f, 31.461012f, - 31.683659f, 31.873592f, 32.043407f, 32.11024f, 32.087135f, 32.03132f, - 27.570559f, 27.748787f, 27.946087f, 28.181528f, 28.474184f, 28.783876f, - 28.994865f, 29.173294f, 29.39433f, 29.70608f, 29.940659f, 30.120655f, - 30.32236f, 30.618746f, 30.882797f, 31.070557f, 31.25832f, 31.522371f, - 31.818754f, 32.02046f, 32.20046f, 32.43504f, 32.758415f, 33.031567f, - 33.25422f, 33.44415f, 33.613964f, 33.680794f, 33.657696f, 33.60188f, - 29.636976f, 29.815207f, 30.0125f, 30.247944f, 30.5406f, 30.85029f, - 31.061283f, 31.239712f, 31.46075f, 31.7725f, 32.00708f, 32.187077f, - 32.38878f, 32.685165f, 32.949215f, 33.13698f, 33.32474f, 33.58879f, - 33.885178f, 34.086884f, 34.26688f, 34.501457f, 34.824837f, 35.09799f, - 35.320637f, 35.510574f, 35.68039f, 35.747215f, 35.724117f, 35.6683f, - 31.83344f, 32.011665f, 32.20897f, 32.444412f, 32.73707f, 33.046757f, - 33.257744f, 33.436176f, 33.657207f, 33.96896f, 34.203537f, 34.383537f, - 34.58524f, 34.88163f, 35.145676f, 35.33344f, 35.521206f, 35.785255f, - 36.081642f, 36.28334f, 36.46334f, 36.69792f, 37.021297f, 37.294453f, - 37.517097f, 37.707027f, 37.876846f, 37.94368f, 37.920578f, 37.864758f, - 33.253647f, 33.431873f, 33.62917f, 33.864613f, 34.15727f, 34.466957f, - 34.677948f, 34.856377f, 35.077415f, 35.38916f, 35.623745f, 35.803745f, - 36.005447f, 36.301834f, 36.565884f, 36.753647f, 36.941406f, 37.205456f, - 37.50184f, 37.703545f, 37.883545f, 38.118122f, 38.4415f, 38.714653f, - 38.9373f, 39.127235f, 39.297054f, 39.363884f, 39.340782f, 39.28496f, - 34.464783f, 34.64301f, 34.840305f, 35.075752f, 35.368404f, 35.6781f, - 35.889088f, 36.067516f, 36.28855f, 36.6003f, 36.834885f, 37.014877f, - 37.216583f, 37.51297f, 37.77702f, 37.964783f, 38.152546f, 38.416595f, - 38.71298f, 38.914684f, 39.094685f, 39.32926f, 39.652645f, 39.925793f, - 40.14844f, 40.338375f, 40.508194f, 40.575024f, 40.55192f, 40.496105f, - 36.058067f, 36.23629f, 36.43359f, 36.669033f, 36.961685f, 37.271378f, - 37.48237f, 37.6608f, 37.881836f, 38.19359f, 38.42817f, 38.608162f, - 38.809868f, 39.10625f, 39.3703f, 39.558064f, 39.74583f, 40.00988f, - 40.306267f, 40.50797f, 40.68797f, 40.92255f, 41.245926f, 41.519077f, - 41.741722f, 41.931652f, 42.101475f, 42.168304f, 42.145203f, 42.089386f, - 38.315002f, 38.493233f, 38.690533f, 38.925976f, 39.218628f, 39.52832f, - 39.739307f, 39.917736f, 40.138775f, 40.45052f, 40.685104f, 40.865097f, - 41.066803f, 41.36319f, 41.627243f, 41.815002f, 42.002766f, 42.26682f, - 42.5632f, 42.764908f, 42.944904f, 43.179485f, 43.50286f, 43.776016f, - 43.998665f, 44.188595f, 44.358418f, 44.425247f, 44.402145f, 44.34633f, - 40.22708f, 40.40531f, 40.602608f, 40.83805f, 41.130707f, 41.440395f, - 41.651382f, 41.82982f, 42.050854f, 42.3626f, 42.597183f, 42.77718f, - 42.97888f, 43.27527f, 43.53932f, 43.72708f, 43.914845f, 44.178894f, - 44.47528f, 44.676983f, 44.856983f, 45.09156f, 45.41494f, 45.68809f, - 45.91074f, 46.100674f, 46.270493f, 46.337322f, 46.31422f, 46.2584f, - 41.785618f, 41.963844f, 42.161144f, 42.396584f, 42.68924f, 42.998936f, - 43.209923f, 43.388355f, 43.609394f, 43.921143f, 44.15572f, 44.335716f, - 44.53742f, 44.833805f, 45.09786f, 45.285614f, 45.473377f, 45.737427f, - 46.033817f, 46.235523f, 46.415524f, 46.650105f, 46.973476f, 47.24663f, - 47.469276f, 47.65921f, 47.82903f, 47.895855f, 47.872753f, 47.81694f, - 43.11514f, 43.293365f, 43.490665f, 43.726105f, 44.018764f, 44.328457f, - 44.539444f, 44.717873f, 44.93891f, 45.25066f, 45.48524f, 45.665237f, - 45.86694f, 46.163326f, 46.427376f, 46.615143f, 46.802902f, 47.066956f, - 47.363342f, 47.56505f, 47.74505f, 47.979626f, 48.302998f, 48.576153f, - 48.798798f, 48.98873f, 49.158546f, 49.225376f, 49.202282f, 49.146458f, - 44.303867f, 44.482094f, 44.679394f, 44.914833f, 45.207493f, 45.51718f, - 45.72817f, 45.9066f, 46.12764f, 46.439384f, 46.673965f, 46.853966f, - 47.055668f, 47.352055f, 47.6161f, 47.803867f, 47.99163f, 48.25568f, - 48.552063f, 48.75377f, 48.933773f, 49.16835f, 49.491726f, 49.764877f, - 49.987526f, 50.17746f, 50.347275f, 50.4141f, 50.391006f, 50.335186f, - 44.771675f, 44.949905f, 45.1472f, 45.382645f, 45.6753f, 45.98499f, - 46.195976f, 46.374413f, 46.595448f, 46.907196f, 47.141773f, 47.321774f, - 47.523476f, 47.819862f, 48.08391f, 48.27168f, 48.459446f, 48.72349f, - 49.019882f, 49.22158f, 49.401585f, 49.63616f, 49.959538f, 50.232693f, - 50.455338f, 50.64527f, 50.81509f, 50.88192f, 50.858818f, 50.803f, - 44.609966f, 44.788193f, 44.985493f, 45.220936f, 45.51359f, 45.82328f, - 46.03427f, 46.2127f, 46.433743f, 46.74549f, 46.98007f, 47.160065f, - 47.36177f, 47.658157f, 47.922207f, 48.10997f, 48.297733f, 48.561783f, - 48.858166f, 49.059875f, 49.239872f, 49.47445f, 49.79783f, 50.07098f, - 50.293625f, 50.48356f, 50.653378f, 50.720203f, 50.6971f, 50.64128f, - 44.219246f, 44.397472f, 44.594772f, 44.83021f, 45.122868f, 45.43256f, - 45.643543f, 45.82198f, 46.04302f, 46.354763f, 46.589344f, 46.76934f, - 46.971046f, 47.267433f, 47.531483f, 47.719242f, 47.907005f, 48.17105f, - 48.467438f, 48.66914f, 48.849144f, 49.08372f, 49.4071f, 49.680256f, - 49.902905f, 50.092834f, 50.262653f, 50.329483f, 50.30638f, 50.25057f}); - - auto size = NDArrayFactory::create({30, 30}); - sd::ops::resize_bicubic op; - auto results = op.evaluate({&input, &size}, {}, {}); + //input.linspace(1); +// auto size = NDArrayFactory::create({6, 6}); + sd::ops::resize_images op; + auto results = op.evaluate({&input}, {}, {6, 8, ops::helpers::kResizeArea}, {true, true}); // resize_area to 6x8 with align corners and preserve aspect ratio of input image ASSERT_EQ(ND4J_STATUS_OK, results.status()); - NDArray* result = results.at(0); -// result.printBuffer("Resized to 30x30"); -// expected.printBuffer("Expect for 30x30"); + auto result = results.at(0); + +// result->printBuffer("Area Resized to 6x6"); +// expected.printBuffer("Area Expect for 6x6"); ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); } -TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test2) { - - NDArray input = NDArrayFactory::create('c', {2, 5, 4, 3}); - NDArray expected = NDArrayFactory::create('c', {2, 10, 8, 3}, { - 1.000000f, 2.000000f, 3.000000f, 2.218750f, 3.218750f, 4.218750f, 4.000000f, 5.000000f, 6.000000f, - 5.500000f, 6.500000f, 7.500000f, 7.000000f, 8.000000f, 9.000000f, 8.781250f, 9.781250f, 10.781250f, - 10.000000f, 11.000000f, 12.000000f, 10.281250f, 11.281250f, 12.281250f, 5.875000f, 6.875000f, 7.875000f, - 7.093750f, 8.093750f, 9.093750f, 8.875000f, 9.875000f, 10.875000f, 10.375000f, 11.375000f, 12.375000f, - 11.875000f, 12.875000f, 13.875000f, 13.656250f, 14.656250f, 15.656250f, 14.875000f, 15.875000f, 16.875000f, - 15.156250f, 16.156250f, 17.156250f, 13.000000f, 14.000000f, 15.000000f, 14.218750f, 15.218750f, 16.218750f, - 16.000000f, 17.000000f, 18.000000f, 17.500000f, 18.500000f, 19.500000f, 19.000000f, 20.000000f, 21.000000f, - 20.781250f, 21.781250f, 22.781250f, 22.000000f, 23.000000f, 24.000000f, 22.281250f, 23.281250f, 24.281250f, - 19.000000f, 20.000000f, 21.000000f, 20.218750f, 21.218750f, 22.218750f, 22.000000f, 23.000000f, 24.000000f, - 23.500000f, 24.500000f, 25.500000f, 25.000000f, 26.000000f, 27.000000f, 26.781250f, 27.781250f, 28.781250f, - 28.000000f, 29.000000f, 30.000000f, 28.281250f, 29.281250f, 30.281250f, 25.000000f, 26.000000f, 27.000000f, - 26.218750f, 27.218750f, 28.218750f, 28.000000f, 29.000000f, 30.000000f, 29.500000f, 30.500000f, 31.500000f, - 31.000000f, 32.000000f, 33.000000f, 32.781250f, 33.781250f, 34.781250f, 34.000000f, 35.000000f, 36.000000f, - 34.281250f, 35.281250f, 36.281250f, 31.000000f, 32.000000f, 33.000000f, 32.218750f, 33.218750f, 34.218750f, - 34.000000f, 35.000000f, 36.000000f, 35.500000f, 36.500000f, 37.500000f, 37.000000f, 38.000000f, 39.000000f, - 38.781250f, 39.781250f, 40.781250f, 40.000000f, 41.000000f, 42.000000f, 40.281250f, 41.281250f, 42.281250f, - 37.000000f, 38.000000f, 39.000000f, 38.218750f, 39.218750f, 40.218750f, 40.000000f, 41.000000f, 42.000000f, - 41.500000f, 42.500000f, 43.500000f, 43.000000f, 44.000000f, 45.000000f, 44.781250f, 45.781250f, 46.781250f, - 46.000000f, 47.000000f, 48.000000f, 46.281250f, 47.281250f, 48.281250f, 44.125000f, 45.125000f, 46.125000f, - 45.343750f, 46.343750f, 47.343750f, 47.125000f, 48.125000f, 49.125000f, 48.625000f, 49.625000f, 50.625000f, - 50.125000f, 51.125000f, 52.125000f, 51.906250f, 52.906250f, 53.906250f, 53.125000f, 54.125000f, 55.125000f, - 53.406250f, 54.406250f, 55.406250f, 49.000000f, 50.000000f, 51.000000f, 50.218750f, 51.218750f, 52.218750f, - 52.000000f, 53.000000f, 54.000000f, 53.500000f, 54.500000f, 55.500000f, 55.000000f, 56.000000f, 57.000000f, - 56.781250f, 57.781250f, 58.781250f, 58.000000f, 59.000000f, 60.000000f, 58.281250f, 59.281250f, 60.281250f, - 50.125000f, 51.125000f, 52.125000f, 51.343750f, 52.343750f, 53.343750f, 53.125000f, 54.125000f, 55.125000f, - 54.625000f, 55.625000f, 56.625000f, 56.125000f, 57.125000f, 58.125000f, 57.906250f, 58.906250f, 59.906250f, - 59.125000f, 60.125000f, 61.125000f, 59.406250f, 60.406250f, 61.406250f, 61.000000f, 62.000000f, 63.000000f, - 62.218750f, 63.218750f, 64.218750f, 64.000000f, 65.000000f, 66.000000f, 65.500000f, 66.500000f, 67.500000f, - 67.000000f, 68.000000f, 69.000000f, 68.781250f, 69.781250f, 70.781250f, 70.000000f, 71.000000f, 72.000000f, - 70.281250f, 71.281250f, 72.281250f, 65.875000f, 66.875000f, 67.875000f, 67.093750f, 68.093750f, 69.093750f, - 68.875000f, 69.875000f, 70.875000f, 70.375000f, 71.375000f, 72.375000f, 71.875000f, 72.875000f, 73.875000f, - 73.656250f, 74.656250f, 75.656250f, 74.875000f, 75.875000f, 76.875000f, 75.156250f, 76.156250f, 77.156250f, - 73.000000f, 74.000000f, 75.000000f, 74.218750f, 75.218750f, 76.218750f, 76.000000f, 77.000000f, 78.000000f, - 77.500000f, 78.500000f, 79.500000f, 79.000000f, 80.000000f, 81.000000f, 80.781250f, 81.781250f, 82.781250f, - 82.000000f, 83.000000f, 84.000000f, 82.281250f, 83.281250f, 84.281250f, 79.000000f, 80.000000f, 81.000000f, - 80.218750f, 81.218750f, 82.218750f, 82.000000f, 83.000000f, 84.000000f, 83.500000f, 84.500000f, 85.500000f, - 85.000000f, 86.000000f, 87.000000f, 86.781250f, 87.781250f, 88.781250f, 88.000000f, 89.000000f, 90.000000f, - 88.281250f, 89.281250f, 90.281250f, 85.000000f, 86.000000f, 87.000000f, 86.218750f, 87.218750f, 88.218750f, - 88.000000f, 89.000000f, 90.000000f, 89.500000f, 90.500000f, 91.500000f, 91.000000f, 92.000000f, 93.000000f, - 92.781250f, 93.781250f, 94.781250f, 94.000000f, 95.000000f, 96.000000f, 94.281250f, 95.281250f, 96.281250f, - 91.000000f, 92.000000f, 93.000000f, 92.218750f, 93.218750f, 94.218750f, 94.000000f, 95.000000f, 96.000000f, - 95.500000f, 96.500000f, 97.500000f, 97.000000f, 98.000000f, 99.000000f, 98.781250f, 99.781250f, 100.781250f, - 100.000000f, 101.000000f, 102.000000f, 100.281250f, 101.281250f, 102.281250f, 97.000000f, 98.000000f, - 99.000000f, 98.218750f, 99.218750f, 100.218750f, 100.000000f, 101.000000f, 102.000000f, 101.500000f, - 102.500000f, 103.500000f, 103.000000f, 104.000000f, 105.000000f, 104.781250f, 105.781250f, 106.781250f, - 106.000000f, 107.000000f, 108.000000f, 106.281250f, 107.281250f, 108.281250f, 104.125000f, 105.125000f, - 106.125000f, 105.343750f, 106.343750f, 107.343750f, 107.125000f, 108.125000f, 109.125000f, 108.625000f, - 109.625000f, 110.625000f, 110.125000f, 111.125000f, 112.125000f, 111.906250f, 112.906250f, 113.906250f, - 113.125000f, 114.125000f, 115.125000f, 113.406250f, 114.406250f, 115.406250f, 109.000000f, 110.000000f, - 111.000000f, 110.218750f, 111.218750f, 112.218750f, 112.000000f, 113.000000f, 114.000000f, 113.500000f, - 114.500000f, 115.500000f, 115.000000f, 116.000000f, 117.000000f, 116.781250f, 117.781250f, 118.781250f, - 118.000000f, 119.000000f, 120.000000f, 118.281250f, 119.281250f, 120.281250f, 110.125000f, 111.125000f, - 112.125000f, 111.343750f, 112.343750f, 113.343750f, 113.125000f, 114.125000f, 115.125000f, 114.625000f, - 115.625000f, 116.625000f, 116.125000f, 117.125000f, 118.125000f, 117.906250f, 118.906250f, 119.906250f, - 119.125000f, 120.125000f, 121.125000f, 119.406250f, 120.406250f, 121.406250f - }); //input = 1.f; - input.linspace(1); - auto size = NDArrayFactory::create({10, 8}); - sd::ops::resize_bicubic op; - auto results = op.evaluate({&input, &size}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - NDArray* result = results.at(0); +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, ImageResizeArea_Test9) { + NDArray input = NDArrayFactory::create( + 'c', {1, 2, 3, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}); + + NDArray expected = NDArrayFactory::create( + 'c', {1, 10, 10, 4}, + {1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, + 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, + 3.666667f, 4.666667f, 5.666667f, 6.666667f, 5.000000f, 6.000000f, + 7.000000f, 8.000000f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, + 6.333336f, 7.333336f, 8.333336f, 9.333337f, 9.000000f, 10.000000f, + 11.000000f, 12.000000f, 9.000000f, 10.000000f, 11.000000f, 12.000000f, + 8.999998f, 9.999998f, 10.999998f, 11.999998f, 1.000000f, 2.000000f, + 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, + 1.000000f, 2.000000f, 3.000000f, 4.000000f, 3.666667f, 4.666667f, + 5.666667f, 6.666667f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, + 5.000000f, 6.000000f, 7.000000f, 8.000000f, 6.333336f, 7.333336f, + 8.333336f, 9.333337f, 9.000000f, 10.000000f, 11.000000f, 12.000000f, + 9.000000f, 10.000000f, 11.000000f, 12.000000f, 8.999998f, 9.999998f, + 10.999998f, 11.999998f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, + 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, + 3.000000f, 4.000000f, 3.666667f, 4.666667f, 5.666667f, 6.666667f, + 5.000000f, 6.000000f, 7.000000f, 8.000000f, 5.000000f, 6.000000f, + 7.000000f, 8.000000f, 6.333336f, 7.333336f, 8.333336f, 9.333337f, + 9.000000f, 10.000000f, 11.000000f, 12.000000f, 9.000000f, 10.000000f, + 11.000000f, 12.000000f, 8.999998f, 9.999998f, 10.999998f, 11.999998f, + 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, + 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, + 3.666667f, 4.666667f, 5.666667f, 6.666667f, 5.000000f, 6.000000f, + 7.000000f, 8.000000f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, + 6.333336f, 7.333336f, 8.333336f, 9.333337f, 9.000000f, 10.000000f, + 11.000000f, 12.000000f, 9.000000f, 10.000000f, 11.000000f, 12.000000f, + 8.999998f, 9.999998f, 10.999998f, 11.999998f, 1.000000f, 2.000000f, + 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, + 1.000000f, 2.000000f, 3.000000f, 4.000000f, 3.666667f, 4.666667f, + 5.666667f, 6.666667f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, + 5.000000f, 6.000000f, 7.000000f, 8.000000f, 6.333336f, 7.333336f, + 8.333336f, 9.333336f, 8.999999f, 9.999999f, 11.000000f, 11.999999f, + 8.999999f, 9.999999f, 11.000000f, 11.999999f, 8.999998f, 9.999997f, + 10.999997f, 11.999997f, 13.000003f, 14.000004f, 15.000003f, 16.000004f, + 13.000003f, 14.000004f, 15.000003f, 16.000004f, 13.000003f, 14.000004f, + 15.000003f, 16.000004f, 15.666671f, 16.666672f, 17.666672f, 18.666672f, + 17.000006f, 18.000004f, 19.000006f, 20.000004f, 17.000006f, 18.000004f, + 19.000006f, 20.000004f, 18.333344f, 19.333344f, 20.333345f, 21.333344f, + 21.000006f, 22.000006f, 23.000006f, 24.000006f, 21.000006f, 22.000006f, + 23.000006f, 24.000006f, 21.000002f, 22.000000f, 23.000002f, 24.000000f, + 13.000000f, 14.000001f, 15.000000f, 16.000000f, 13.000000f, 14.000001f, + 15.000000f, 16.000000f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, + 15.666667f, 16.666668f, 17.666668f, 18.666668f, 17.000002f, 18.000000f, + 19.000002f, 20.000000f, 17.000002f, 18.000000f, 19.000002f, 20.000000f, + 18.333340f, 19.333340f, 20.333342f, 21.333340f, 21.000002f, 22.000000f, + 22.999998f, 24.000000f, 21.000002f, 22.000000f, 22.999998f, 24.000000f, + 20.999996f, 21.999996f, 22.999994f, 23.999996f, 13.000000f, 14.000001f, + 15.000000f, 16.000000f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, + 13.000000f, 14.000001f, 15.000000f, 16.000000f, 15.666667f, 16.666668f, + 17.666668f, 18.666668f, 17.000002f, 18.000000f, 19.000002f, 20.000000f, + 17.000002f, 18.000000f, 19.000002f, 20.000000f, 18.333340f, 19.333340f, + 20.333342f, 21.333340f, 21.000002f, 22.000000f, 22.999998f, 24.000000f, + 21.000002f, 22.000000f, 22.999998f, 24.000000f, 20.999996f, 21.999996f, + 22.999994f, 23.999996f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, + 13.000000f, 14.000001f, 15.000000f, 16.000000f, 13.000000f, 14.000001f, + 15.000000f, 16.000000f, 15.666667f, 16.666668f, 17.666668f, 18.666668f, + 17.000002f, 18.000000f, 19.000002f, 20.000000f, 17.000002f, 18.000000f, + 19.000002f, 20.000000f, 18.333340f, 19.333340f, 20.333342f, 21.333340f, + 21.000002f, 22.000000f, 22.999998f, 24.000000f, 21.000002f, 22.000000f, + 22.999998f, 24.000000f, 20.999996f, 21.999996f, 22.999994f, 23.999996f, + 12.999995f, 13.999995f, 14.999994f, 15.999994f, 12.999995f, 13.999995f, + 14.999994f, 15.999994f, 12.999995f, 13.999995f, 14.999994f, 15.999994f, + 15.666661f, 16.666662f, 17.666660f, 18.666660f, 16.999994f, 17.999994f, + 18.999992f, 19.999992f, 16.999994f, 17.999994f, 18.999992f, 19.999992f, + 18.333334f, 19.333332f, 20.333334f, 21.333332f, 20.999992f, 21.999992f, + 22.999990f, 23.999992f, 20.999992f, 21.999992f, 22.999990f, 23.999992f, + 20.999989f, 21.999989f, 22.999987f, 23.999987f + + }); + // input.linspace(1); + auto size = NDArrayFactory::create({10, 10}); + sd::ops::resize_area op; + auto results = op.evaluate({&input, &size}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + + // result.printBuffer("Area Resized to 10x10"); + // expected.printBuffer("Area Expect for 6x6"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); +} -// result.printBuffer("Resized to 10x8"); -// expected.printBuffer("Expect for 10x8"); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, ImageResizeArea_Test10) { + NDArray input = NDArrayFactory::create( + 'c', {1, 2, 3, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}); + + NDArray expected = NDArrayFactory::create( + 'c', {1, 10, 10, 4}, + {1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, + 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, + 3.666667f, 4.666667f, 5.666667f, 6.666667f, 5.000000f, 6.000000f, + 7.000000f, 8.000000f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, + 6.333336f, 7.333336f, 8.333336f, 9.333337f, 9.000000f, 10.000000f, + 11.000000f, 12.000000f, 9.000000f, 10.000000f, 11.000000f, 12.000000f, + 8.999998f, 9.999998f, 10.999998f, 11.999998f, 1.000000f, 2.000000f, + 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, + 1.000000f, 2.000000f, 3.000000f, 4.000000f, 3.666667f, 4.666667f, + 5.666667f, 6.666667f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, + 5.000000f, 6.000000f, 7.000000f, 8.000000f, 6.333336f, 7.333336f, + 8.333336f, 9.333337f, 9.000000f, 10.000000f, 11.000000f, 12.000000f, + 9.000000f, 10.000000f, 11.000000f, 12.000000f, 8.999998f, 9.999998f, + 10.999998f, 11.999998f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, + 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, + 3.000000f, 4.000000f, 3.666667f, 4.666667f, 5.666667f, 6.666667f, + 5.000000f, 6.000000f, 7.000000f, 8.000000f, 5.000000f, 6.000000f, + 7.000000f, 8.000000f, 6.333336f, 7.333336f, 8.333336f, 9.333337f, + 9.000000f, 10.000000f, 11.000000f, 12.000000f, 9.000000f, 10.000000f, + 11.000000f, 12.000000f, 8.999998f, 9.999998f, 10.999998f, 11.999998f, + 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, + 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, + 3.666667f, 4.666667f, 5.666667f, 6.666667f, 5.000000f, 6.000000f, + 7.000000f, 8.000000f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, + 6.333336f, 7.333336f, 8.333336f, 9.333337f, 9.000000f, 10.000000f, + 11.000000f, 12.000000f, 9.000000f, 10.000000f, 11.000000f, 12.000000f, + 8.999998f, 9.999998f, 10.999998f, 11.999998f, 1.000000f, 2.000000f, + 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, + 1.000000f, 2.000000f, 3.000000f, 4.000000f, 3.666667f, 4.666667f, + 5.666667f, 6.666667f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, + 5.000000f, 6.000000f, 7.000000f, 8.000000f, 6.333336f, 7.333336f, + 8.333336f, 9.333336f, 8.999999f, 9.999999f, 11.000000f, 11.999999f, + 8.999999f, 9.999999f, 11.000000f, 11.999999f, 8.999998f, 9.999997f, + 10.999997f, 11.999997f, 13.000003f, 14.000004f, 15.000003f, 16.000004f, + 13.000003f, 14.000004f, 15.000003f, 16.000004f, 13.000003f, 14.000004f, + 15.000003f, 16.000004f, 15.666671f, 16.666672f, 17.666672f, 18.666672f, + 17.000006f, 18.000004f, 19.000006f, 20.000004f, 17.000006f, 18.000004f, + 19.000006f, 20.000004f, 18.333344f, 19.333344f, 20.333345f, 21.333344f, + 21.000006f, 22.000006f, 23.000006f, 24.000006f, 21.000006f, 22.000006f, + 23.000006f, 24.000006f, 21.000002f, 22.000000f, 23.000002f, 24.000000f, + 13.000000f, 14.000001f, 15.000000f, 16.000000f, 13.000000f, 14.000001f, + 15.000000f, 16.000000f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, + 15.666667f, 16.666668f, 17.666668f, 18.666668f, 17.000002f, 18.000000f, + 19.000002f, 20.000000f, 17.000002f, 18.000000f, 19.000002f, 20.000000f, + 18.333340f, 19.333340f, 20.333342f, 21.333340f, 21.000002f, 22.000000f, + 22.999998f, 24.000000f, 21.000002f, 22.000000f, 22.999998f, 24.000000f, + 20.999996f, 21.999996f, 22.999994f, 23.999996f, 13.000000f, 14.000001f, + 15.000000f, 16.000000f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, + 13.000000f, 14.000001f, 15.000000f, 16.000000f, 15.666667f, 16.666668f, + 17.666668f, 18.666668f, 17.000002f, 18.000000f, 19.000002f, 20.000000f, + 17.000002f, 18.000000f, 19.000002f, 20.000000f, 18.333340f, 19.333340f, + 20.333342f, 21.333340f, 21.000002f, 22.000000f, 22.999998f, 24.000000f, + 21.000002f, 22.000000f, 22.999998f, 24.000000f, 20.999996f, 21.999996f, + 22.999994f, 23.999996f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, + 13.000000f, 14.000001f, 15.000000f, 16.000000f, 13.000000f, 14.000001f, + 15.000000f, 16.000000f, 15.666667f, 16.666668f, 17.666668f, 18.666668f, + 17.000002f, 18.000000f, 19.000002f, 20.000000f, 17.000002f, 18.000000f, + 19.000002f, 20.000000f, 18.333340f, 19.333340f, 20.333342f, 21.333340f, + 21.000002f, 22.000000f, 22.999998f, 24.000000f, 21.000002f, 22.000000f, + 22.999998f, 24.000000f, 20.999996f, 21.999996f, 22.999994f, 23.999996f, + 12.999995f, 13.999995f, 14.999994f, 15.999994f, 12.999995f, 13.999995f, + 14.999994f, 15.999994f, 12.999995f, 13.999995f, 14.999994f, 15.999994f, + 15.666661f, 16.666662f, 17.666660f, 18.666660f, 16.999994f, 17.999994f, + 18.999992f, 19.999992f, 16.999994f, 17.999994f, 18.999992f, 19.999992f, + 18.333334f, 19.333332f, 20.333334f, 21.333332f, 20.999992f, 21.999992f, + 22.999990f, 23.999992f, 20.999992f, 21.999992f, 22.999990f, 23.999992f, + 20.999989f, 21.999989f, 22.999987f, 23.999987f + + }); + // input.linspace(1); + // auto size = NDArrayFactory::create({10, 10}); + sd::ops::resize_area op; + auto results = op.evaluate({&input}, {}, {10, 10}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + + // result.printBuffer("Area Resized to 10x10"); + // expected.printBuffer("Area Expect for 6x6"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } -TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test3) { +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, ImageResizeArea_Test11) { + NDArray input = NDArrayFactory::create( + 'c', {1, 2, 3, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}); + + // NDArray expected = NDArrayFactory::create('c', {1, 6, 9, 4}, { + // 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, + // 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 3.666667, 4.666667, + // 5.666667, 6.666667, 5.000000, 6.000000, 7.000000, 8.000000, 5.000000, + // 6.000000, 7.000000, 8.000000, 6.333336, 7.333336, 8.333336, 9.333337, + // 9.000000, 10.000000, 11.000000, 12.000000, 9.000000, 10.000000, 11.000000, + // 12.000000, 8.999998, 9.999998, 10.999998, 11.999998, 1.000000, 2.000000, + // 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, + // 2.000000, 3.000000, 4.000000, 3.666667, 4.666667, 5.666667, 6.666667, + // 5.000000, 6.000000, 7.000000, 8.000000, 5.000000, 6.000000, 7.000000, + // 8.000000, 6.333336, 7.333336, 8.333336, 9.333337, 9.000000, 10.000000, + // 11.000000, 12.000000, 9.000000, 10.000000, 11.000000, 12.000000, 8.999998, + // 9.999998, 10.999998, 11.999998, 1.000000, 2.000000, 3.000000, 4.000000, + // 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, + // 4.000000, 3.666667, 4.666667, 5.666667, 6.666667, 5.000000, 6.000000, + // 7.000000, 8.000000, 5.000000, 6.000000, 7.000000, 8.000000, 6.333336, + // 7.333336, 8.333336, 9.333337, 9.000000, 10.000000, 11.000000, 12.000000, + // 9.000000, 10.000000, 11.000000, 12.000000, 8.999998, 9.999998, 10.999998, + // 11.999998, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, + // 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 3.666667, + // 4.666667, 5.666667, 6.666667, 5.000000, 6.000000, 7.000000, 8.000000, + // 5.000000, 6.000000, 7.000000, 8.000000, 6.333336, 7.333336, 8.333336, + // 9.333337, 9.000000, 10.000000, 11.000000, 12.000000, 9.000000, 10.000000, + // 11.000000, 12.000000, 8.999998, 9.999998, 10.999998, 11.999998, 1.000000, + // 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, + // 1.000000, 2.000000, 3.000000, 4.000000, 3.666667, 4.666667, 5.666667, + // 6.666667, 5.000000, 6.000000, 7.000000, 8.000000, 5.000000, 6.000000, + // 7.000000, 8.000000, 6.333336, 7.333336, 8.333336, 9.333336, 8.999999, + // 9.999999, 11.000000, 11.999999, 8.999999, 9.999999, 11.000000, 11.999999, + // 8.999998, 9.999997, 10.999997, 11.999997, 13.000003, 14.000004, 15.000003, + // 16.000004, 13.000003, 14.000004, 15.000003, 16.000004, 13.000003, + // 14.000004, 15.000003, 16.000004, 15.666671, 16.666672, 17.666672, + // 18.666672, 17.000006, 18.000004, 19.000006, 20.000004, 17.000006, + // 18.000004, 19.000006, 20.000004, 18.333344, 19.333344, 20.333345, + // 21.333344, 21.000006, 22.000006, 23.000006, 24.000006, 21.000006, + // 22.000006, 23.000006, 24.000006, 21.000002, 22.000000, 23.000002, + // 24.000000, 13.000000, 14.000001, 15.000000, 16.000000, 13.000000, + // 14.000001, 15.000000, 16.000000, 13.000000, 14.000001, 15.000000, + // 16.000000, 15.666667, 16.666668, 17.666668, 18.666668, 17.000002, + // 18.000000, 19.000002, 20.000000, 17.000002, 18.000000, 19.000002, + // 20.000000, 18.333340, 19.333340, 20.333342, 21.333340, 21.000002, + // 22.000000, 22.999998, 24.000000, 21.000002, 22.000000, 22.999998, + // 24.000000, 20.999996, 21.999996, 22.999994, 23.999996, 13.000000, + // 14.000001, 15.000000, 16.000000, 13.000000, 14.000001, 15.000000, + // 16.000000, 13.000000, 14.000001, 15.000000, 16.000000, 15.666667, + // 16.666668, 17.666668, 18.666668, 17.000002, 18.000000, 19.000002, + // 20.000000, 17.000002, 18.000000, 19.000002, 20.000000, 18.333340, + // 19.333340, 20.333342, 21.333340, 21.000002, 22.000000, 22.999998, + // 24.000000, 21.000002, 22.000000, 22.999998, 24.000000, 20.999996, + // 21.999996, 22.999994, 23.999996, 13.000000, 14.000001, 15.000000, + // 16.000000, 13.000000, 14.000001, 15.000000, 16.000000, 13.000000, + // 14.000001, 15.000000, 16.000000, 15.666667, 16.666668, 17.666668, + // 18.666668, 17.000002, 18.000000, 19.000002, 20.000000, 17.000002, + // 18.000000, 19.000002, 20.000000, 18.333340, 19.333340, 20.333342, + // 21.333340, 21.000002, 22.000000, 22.999998, 24.000000, 21.000002, + // 22.000000, 22.999998, 24.000000, 20.999996, 21.999996, 22.999994, + // 23.999996, 12.999995, 13.999995, 14.999994, 15.999994, 12.999995, + // 13.999995, 14.999994, 15.999994, 12.999995, 13.999995, 14.999994, + // 15.999994, 15.666661, 16.666662, 17.666660, 18.666660, 16.999994, + // 17.999994, 18.999992, 19.999992, 16.999994, 17.999994, 18.999992, + // 19.999992, 18.333334, 19.333332, 20.333334, 21.333332, 20.999992, + // 21.999992, 22.999990, 23.999992, 20.999992, 21.999992, 22.999990, + // 23.999992, 20.999989, 21.999989, 22.999987, 23.999987 + // + // }); + // input.linspace(1); + // auto size = NDArrayFactory::create({10, 10}); + sd::ops::resize_area op; + auto results = op.evaluate({&input}, {}, {6, 9}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + + // result.printBuffer("Area Resized to 6x9"); + // expected.printBuffer("Area Expect for 6x6"); + // ASSERT_TRUE(expected.isSameShape(result)); + // ASSERT_TRUE(expected.equalsTo(result)); +} - NDArray input = NDArrayFactory::create('c', {1, 3, 3, 4}); - NDArray expected = NDArrayFactory::create('c', {1, 6, 6, 4}, { - 1.000000f, 2.000000f, 3.000000f, 4.000000f, 2.625000f, 3.625000f, 4.625000f, 5.625000f, 5.000000f, - 6.000000f, 7.000000f, 8.000000f, 7.375000f, 8.375000f, 9.375000f, 10.375000f, 9.000000f, 10.000000f, - 11.000000f, 12.000000f, 9.375000f, 10.375000f, 11.375000f, 12.375000f, 5.875000f, 6.875000f, 7.875000f, - 8.875000f, 7.500000f, 8.500000f, 9.500000f, 10.500000f, 9.875000f, 10.875000f, 11.875000f, 12.875000f, - 12.250000f, 13.250000f, 14.250000f, 15.250000f, 13.875000f, 14.875000f, 15.875000f, 16.875000f, 14.250000f, - 15.250000f, 16.250000f, 17.250000f, 13.000000f, 14.000000f, 15.000000f, 16.000000f, 14.625000f, 15.625000f, - 16.625000f, 17.625000f, 17.000000f, 18.000000f, 19.000000f, 20.000000f, 19.375000f, 20.375000f, 21.375000f, - 22.375000f, 21.000000f, 22.000000f, 23.000000f, 24.000000f, 21.375000f, 22.375000f, 23.375000f, 24.375000f, - 20.125000f, 21.125000f, 22.125000f, 23.125000f, 21.750000f, 22.750000f, 23.750000f, 24.750000f, 24.125000f, - 25.125000f, 26.125000f, 27.125000f, 26.500000f, 27.500000f, 28.500000f, 29.500000f, 28.125000f, 29.125000f, - 30.125000f, 31.125000f, 28.500000f, 29.500000f, 30.500000f, 31.500000f, 25.000000f, 26.000000f, 27.000000f, - 28.000000f, 26.625000f, 27.625000f, 28.625000f, 29.625000f, 29.000000f, 30.000000f, 31.000000f, 32.000000f, - 31.375000f, 32.375000f, 33.375000f, 34.375000f, 33.000000f, 34.000000f, 35.000000f, 36.000000f, 33.375000f, - 34.375000f, 35.375000f, 36.375000f, 26.125000f, 27.125000f, 28.125000f, 29.125000f, 27.750000f, 28.750000f, - 29.750000f, 30.750000f, 30.125000f, 31.125000f, 32.125000f, 33.125000f, 32.500000f, 33.500000f, 34.500000f, - 35.500000f, 34.125000f, 35.125000f, 36.125000f, 37.125000f, 34.500000f, 35.500000f, 36.500000f, 37.500000f - }); - input.linspace(1); - auto size = NDArrayFactory::create({6, 6}); - sd::ops::resize_bicubic op; - auto results = op.evaluate({&input, &size}, {}, {}); +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, ImageResizeArea_Test12) { + NDArray input = NDArrayFactory::create( + 'c', {1, 2, 3, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}); + + // NDArray expected = NDArrayFactory::create('c', {1, 6, 9, 4}, { + // 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, + // 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 3.666667, 4.666667, + // 5.666667, 6.666667, 5.000000, 6.000000, 7.000000, 8.000000, 5.000000, + // 6.000000, 7.000000, 8.000000, 6.333336, 7.333336, 8.333336, 9.333337, + // 9.000000, 10.000000, 11.000000, 12.000000, 9.000000, 10.000000, 11.000000, + // 12.000000, 8.999998, 9.999998, 10.999998, 11.999998, 1.000000, 2.000000, + // 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, + // 2.000000, 3.000000, 4.000000, 3.666667, 4.666667, 5.666667, 6.666667, + // 5.000000, 6.000000, 7.000000, 8.000000, 5.000000, 6.000000, 7.000000, + // 8.000000, 6.333336, 7.333336, 8.333336, 9.333337, 9.000000, 10.000000, + // 11.000000, 12.000000, 9.000000, 10.000000, 11.000000, 12.000000, 8.999998, + // 9.999998, 10.999998, 11.999998, 1.000000, 2.000000, 3.000000, 4.000000, + // 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, + // 4.000000, 3.666667, 4.666667, 5.666667, 6.666667, 5.000000, 6.000000, + // 7.000000, 8.000000, 5.000000, 6.000000, 7.000000, 8.000000, 6.333336, + // 7.333336, 8.333336, 9.333337, 9.000000, 10.000000, 11.000000, 12.000000, + // 9.000000, 10.000000, 11.000000, 12.000000, 8.999998, 9.999998, 10.999998, + // 11.999998, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, + // 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 3.666667, + // 4.666667, 5.666667, 6.666667, 5.000000, 6.000000, 7.000000, 8.000000, + // 5.000000, 6.000000, 7.000000, 8.000000, 6.333336, 7.333336, 8.333336, + // 9.333337, 9.000000, 10.000000, 11.000000, 12.000000, 9.000000, 10.000000, + // 11.000000, 12.000000, 8.999998, 9.999998, 10.999998, 11.999998, 1.000000, + // 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, + // 1.000000, 2.000000, 3.000000, 4.000000, 3.666667, 4.666667, 5.666667, + // 6.666667, 5.000000, 6.000000, 7.000000, 8.000000, 5.000000, 6.000000, + // 7.000000, 8.000000, 6.333336, 7.333336, 8.333336, 9.333336, 8.999999, + // 9.999999, 11.000000, 11.999999, 8.999999, 9.999999, 11.000000, 11.999999, + // 8.999998, 9.999997, 10.999997, 11.999997, 13.000003, 14.000004, 15.000003, + // 16.000004, 13.000003, 14.000004, 15.000003, 16.000004, 13.000003, + // 14.000004, 15.000003, 16.000004, 15.666671, 16.666672, 17.666672, + // 18.666672, 17.000006, 18.000004, 19.000006, 20.000004, 17.000006, + // 18.000004, 19.000006, 20.000004, 18.333344, 19.333344, 20.333345, + // 21.333344, 21.000006, 22.000006, 23.000006, 24.000006, 21.000006, + // 22.000006, 23.000006, 24.000006, 21.000002, 22.000000, 23.000002, + // 24.000000, 13.000000, 14.000001, 15.000000, 16.000000, 13.000000, + // 14.000001, 15.000000, 16.000000, 13.000000, 14.000001, 15.000000, + // 16.000000, 15.666667, 16.666668, 17.666668, 18.666668, 17.000002, + // 18.000000, 19.000002, 20.000000, 17.000002, 18.000000, 19.000002, + // 20.000000, 18.333340, 19.333340, 20.333342, 21.333340, 21.000002, + // 22.000000, 22.999998, 24.000000, 21.000002, 22.000000, 22.999998, + // 24.000000, 20.999996, 21.999996, 22.999994, 23.999996, 13.000000, + // 14.000001, 15.000000, 16.000000, 13.000000, 14.000001, 15.000000, + // 16.000000, 13.000000, 14.000001, 15.000000, 16.000000, 15.666667, + // 16.666668, 17.666668, 18.666668, 17.000002, 18.000000, 19.000002, + // 20.000000, 17.000002, 18.000000, 19.000002, 20.000000, 18.333340, + // 19.333340, 20.333342, 21.333340, 21.000002, 22.000000, 22.999998, + // 24.000000, 21.000002, 22.000000, 22.999998, 24.000000, 20.999996, + // 21.999996, 22.999994, 23.999996, 13.000000, 14.000001, 15.000000, + // 16.000000, 13.000000, 14.000001, 15.000000, 16.000000, 13.000000, + // 14.000001, 15.000000, 16.000000, 15.666667, 16.666668, 17.666668, + // 18.666668, 17.000002, 18.000000, 19.000002, 20.000000, 17.000002, + // 18.000000, 19.000002, 20.000000, 18.333340, 19.333340, 20.333342, + // 21.333340, 21.000002, 22.000000, 22.999998, 24.000000, 21.000002, + // 22.000000, 22.999998, 24.000000, 20.999996, 21.999996, 22.999994, + // 23.999996, 12.999995, 13.999995, 14.999994, 15.999994, 12.999995, + // 13.999995, 14.999994, 15.999994, 12.999995, 13.999995, 14.999994, + // 15.999994, 15.666661, 16.666662, 17.666660, 18.666660, 16.999994, + // 17.999994, 18.999992, 19.999992, 16.999994, 17.999994, 18.999992, + // 19.999992, 18.333334, 19.333332, 20.333334, 21.333332, 20.999992, + // 21.999992, 22.999990, 23.999992, 20.999992, 21.999992, 22.999990, + // 23.999992, 20.999989, 21.999989, 22.999987, 23.999987 + // + // }); + // input.linspace(1); + // auto size = NDArrayFactory::create({10, 10}); + sd::ops::resize_area op; + auto results = op.evaluate({&input}, {}, {10, 15}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + + // result.printBuffer("Area Resized to 6x9"); + // expected.printBuffer("Area Expect for 6x6"); + // ASSERT_TRUE(expected.isSameShape(result)); + // ASSERT_TRUE(expected.equalsTo(result)); +} - ASSERT_EQ(ND4J_STATUS_OK, results.status()); +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, ImageResizeArea_Test13) { + NDArray input = NDArrayFactory::create( + 'c', {1, 2, 3, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}); + + // NDArray expected = NDArrayFactory::create('c', {1, 8, 8, 4}, { + // 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, + // 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 3.666667, 4.666667, + // 5.666667, 6.666667, 5.000000, 6.000000, 7.000000, 8.000000, 5.000000, + // 6.000000, 7.000000, 8.000000, 6.333336, 7.333336, 8.333336, 9.333337, + // 9.000000, 10.000000, 11.000000, 12.000000, 9.000000, 10.000000, 11.000000, + // 12.000000, 8.999998, 9.999998, 10.999998, 11.999998, 1.000000, 2.000000, + // 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, + // 2.000000, 3.000000, 4.000000, 3.666667, 4.666667, 5.666667, 6.666667, + // 5.000000, 6.000000, 7.000000, 8.000000, 5.000000, 6.000000, 7.000000, + // 8.000000, 6.333336, 7.333336, 8.333336, 9.333337, 9.000000, 10.000000, + // 11.000000, 12.000000, 9.000000, 10.000000, 11.000000, 12.000000, 8.999998, + // 9.999998, 10.999998, 11.999998, 1.000000, 2.000000, 3.000000, 4.000000, + // 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, + // 4.000000, 3.666667, 4.666667, 5.666667, 6.666667, 5.000000, 6.000000, + // 7.000000, 8.000000, 5.000000, 6.000000, 7.000000, 8.000000, 6.333336, + // 7.333336, 8.333336, 9.333337, 9.000000, 10.000000, 11.000000, 12.000000, + // 9.000000, 10.000000, 11.000000, 12.000000, 8.999998, 9.999998, 10.999998, + // 11.999998, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, + // 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 3.666667, + // 4.666667, 5.666667, 6.666667, 5.000000, 6.000000, 7.000000, 8.000000, + // 5.000000, 6.000000, 7.000000, 8.000000, 6.333336, 7.333336, 8.333336, + // 9.333337, 9.000000, 10.000000, 11.000000, 12.000000, 9.000000, 10.000000, + // 11.000000, 12.000000, 8.999998, 9.999998, 10.999998, 11.999998, 1.000000, + // 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, + // 1.000000, 2.000000, 3.000000, 4.000000, 3.666667, 4.666667, 5.666667, + // 6.666667, 5.000000, 6.000000, 7.000000, 8.000000, 5.000000, 6.000000, + // 7.000000, 8.000000, 6.333336, 7.333336, 8.333336, 9.333336, 8.999999, + // 9.999999, 11.000000, 11.999999, 8.999999, 9.999999, 11.000000, 11.999999, + // 8.999998, 9.999997, 10.999997, 11.999997, 13.000003, 14.000004, 15.000003, + // 16.000004, 13.000003, 14.000004, 15.000003, 16.000004, 13.000003, + // 14.000004, 15.000003, 16.000004, 15.666671, 16.666672, 17.666672, + // 18.666672, 17.000006, 18.000004, 19.000006, 20.000004, 17.000006, + // 18.000004, 19.000006, 20.000004, 18.333344, 19.333344, 20.333345, + // 21.333344, 21.000006, 22.000006, 23.000006, 24.000006, 21.000006, + // 22.000006, 23.000006, 24.000006, 21.000002, 22.000000, 23.000002, + // 24.000000, 13.000000, 14.000001, 15.000000, 16.000000, 13.000000, + // 14.000001, 15.000000, 16.000000, 13.000000, 14.000001, 15.000000, + // 16.000000, 15.666667, 16.666668, 17.666668, 18.666668, 17.000002, + // 18.000000, 19.000002, 20.000000, 17.000002, 18.000000, 19.000002, + // 20.000000, 18.333340, 19.333340, 20.333342, 21.333340, 21.000002, + // 22.000000, 22.999998, 24.000000, 21.000002, 22.000000, 22.999998, + // 24.000000, 20.999996, 21.999996, 22.999994, 23.999996, 13.000000, + // 14.000001, 15.000000, 16.000000, 13.000000, 14.000001, 15.000000, + // 16.000000, 13.000000, 14.000001, 15.000000, 16.000000, 15.666667, + // 16.666668, 17.666668, 18.666668, 17.000002, 18.000000, 19.000002, + // 20.000000, 17.000002, 18.000000, 19.000002, 20.000000, 18.333340, + // 19.333340, 20.333342, 21.333340, 21.000002, 22.000000, 22.999998, + // 24.000000, 21.000002, 22.000000, 22.999998, 24.000000, 20.999996, + // 21.999996, 22.999994, 23.999996, 13.000000, 14.000001, 15.000000, + // 16.000000, 13.000000, 14.000001, 15.000000, 16.000000, 13.000000, + // 14.000001, 15.000000, 16.000000, 15.666667, 16.666668, 17.666668, + // 18.666668, 17.000002, 18.000000, 19.000002, 20.000000, 17.000002, + // 18.000000, 19.000002, 20.000000, 18.333340, 19.333340, 20.333342, + // 21.333340, 21.000002, 22.000000, 22.999998, 24.000000, 21.000002, + // 22.000000, 22.999998, 24.000000, 20.999996, 21.999996, 22.999994, + // 23.999996, 12.999995, 13.999995, 14.999994, 15.999994, 12.999995, + // 13.999995, 14.999994, 15.999994, 12.999995, 13.999995, 14.999994, + // 15.999994, 15.666661, 16.666662, 17.666660, 18.666660, 16.999994, + // 17.999994, 18.999992, 19.999992, 16.999994, 17.999994, 18.999992, + // 19.999992, 18.333334, 19.333332, 20.333334, 21.333332, 20.999992, + // 21.999992, 22.999990, 23.999992, 20.999992, 21.999992, 22.999990, + // 23.999992, 20.999989, 21.999989, 22.999987, 23.999987 + // + // }); + // input.linspace(1); + // auto size = NDArrayFactory::create({10, 10}); + sd::ops::resize_area op; + auto results = op.evaluate({&input}, {}, {9, 9}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + + // result.printBuffer("Area Resized to 8x8"); + // expected.printBuffer("Area Expect for 6x6"); + // ASSERT_TRUE(expected.isSameShape(result)); + // ASSERT_TRUE(expected.equalsTo(result)); +} - NDArray* result = results.at(0); +TEST_F(DeclarableOpsTests11, ImageResizeArea_Test14) { + NDArray input = NDArrayFactory::create( + 'c', {1, 5, 5, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, + 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25}); + auto size = NDArrayFactory::create({8, 7}); + NDArray expected = NDArrayFactory::create( + 'c', {1, 8, 7, 1}, + {1.f, 1.6f, 2.1999993f, 2.9999995f, 3.8f, 4.399997f, + 5.f, 2.9999995f, 3.5999997f, 4.199999f, 4.9999995f, 5.8f, + 6.3999963f, 7.f, 5.999999f, 6.6f, 7.1999984f, 7.9999995f, + 8.8f, 9.399994f, 10.f, 10.f, 10.6f, 11.199998f, + 12.f, 12.8f, 13.399992f, 14.f, 12.f, 12.599999f, + 13.199998f, 13.999998f, 14.800002f, 15.399991f, 16.f, 15.999999f, + 16.599998f, 17.199995f, 18.f, 18.800003f, 19.399986f, 20.000002f, + 19.f, 19.599998f, 20.199997f, 20.999998f, 21.800003f, 22.399984f, + 23.000002f, 20.999998f, 21.599998f, 22.199995f, 22.999998f, 23.800001f, + 24.399984f, 25.f}); // input.linspace(1); + // auto size = NDArrayFactory::create({6, 6}); + sd::ops::resize_area op; + auto results = op.evaluate({&input, &size}, {}, {false}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + // result.printBuffer("Area Resized to 8x7"); + // expected.printBuffer("Area Expect for 8x7"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); +} -// result.printBuffer("Resized to 6x6"); -// expected.printBuffer("Expect for 6x6"); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); +TEST_F(DeclarableOpsTests11, ImageResizeArea_Test15) { + NDArray input = NDArrayFactory::create( + 'c', {1, 5, 5, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, + 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25}); + // auto size = NDArrayFactory::create({8, 7}); + NDArray expected = NDArrayFactory::create( + 'c', {1, 8, 7, 1}, + {1.f, 1.6f, 2.1999993f, 2.9999995f, 3.8f, 4.399997f, + 5.f, 2.9999995f, 3.5999997f, 4.199999f, 4.9999995f, 5.8f, + 6.3999963f, 7.f, 5.999999f, 6.6f, 7.1999984f, 7.9999995f, + 8.8f, 9.399994f, 10.f, 10.f, 10.6f, 11.199998f, + 12.f, 12.8f, 13.399992f, 14.f, 12.f, 12.599999f, + 13.199998f, 13.999998f, 14.800002f, 15.399991f, 16.f, 15.999999f, + 16.599998f, 17.199995f, 18.f, 18.800003f, 19.399986f, 20.000002f, + 19.f, 19.599998f, 20.199997f, 20.999998f, 21.800003f, 22.399984f, + 23.000002f, 20.999998f, 21.599998f, 22.199995f, 22.999998f, 23.800001f, + 24.399984f, 25.f}); + + sd::ops::resize_area op; + auto results = op.evaluate({&input}, {}, {8, 7}, {false}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + // result.printBuffer("Area Resized to 8x7"); + // expected.printBuffer("Area Expect for 8x7"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } -TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test4) { - NDArray input = NDArrayFactory::create('c', {1, 3, 4, 3}); - NDArray expected = NDArrayFactory::create('c', {1, 6, 8, 3}, { - 1.000000f, 2.000000f, 3.000000f, 2.218750f, 3.218750f, 4.218750f, 4.000000f, 5.000000f, 6.000000f, - 5.500000f, 6.500000f, 7.500000f, 7.000000f, 8.000000f, 9.000000f, 8.781250f, 9.781250f, 10.781250f, - 10.000000f, 11.000000f, 12.000000f, 10.281250f, 11.281250f, 12.281250f, 5.875000f, 6.875000f, 7.875000f, - 7.093750f, 8.093750f, 9.093750f, 8.875000f, 9.875000f, 10.875000f, 10.375000f, 11.375000f, 12.375000f, - 11.875000f, 12.875000f, 13.875000f, 13.656250f, 14.656250f, 15.656250f, 14.875000f, 15.875000f, 16.875000f, - 15.156250f, 16.156250f, 17.156250f, 13.000000f, 14.000000f, 15.000000f, 14.218750f, 15.218750f, 16.218750f, - 16.000000f, 17.000000f, 18.000000f, 17.500000f, 18.500000f, 19.500000f, 19.000000f, 20.000000f, 21.000000f, - 20.781250f, 21.781250f, 22.781250f, 22.000000f, 23.000000f, 24.000000f, 22.281250f, 23.281250f, 24.281250f, - 20.125000f, 21.125000f, 22.125000f, 21.343750f, 22.343750f, 23.343750f, 23.125000f, 24.125000f, 25.125000f, - 24.625000f, 25.625000f, 26.625000f, 26.125000f, 27.125000f, 28.125000f, 27.906250f, 28.906250f, 29.906250f, - 29.125000f, 30.125000f, 31.125000f, 29.406250f, 30.406250f, 31.406250f, 25.000000f, 26.000000f, 27.000000f, - 26.218750f, 27.218750f, 28.218750f, 28.000000f, 29.000000f, 30.000000f, 29.500000f, 30.500000f, 31.500000f, - 31.000000f, 32.000000f, 33.000000f, 32.781250f, 33.781250f, 34.781250f, 34.000000f, 35.000000f, 36.000000f, - 34.281250f, 35.281250f, 36.281250f, 26.125000f, 27.125000f, 28.125000f, 27.343750f, 28.343750f, 29.343750f, - 29.125000f, 30.125000f, 31.125000f, 30.625000f, 31.625000f, 32.625000f, 32.125000f, 33.125000f, 34.125000f, - 33.906250f, 34.906250f, 35.906250f, 35.125000f, 36.125000f, 37.125000f, 35.406250f, 36.406250f, 37.406250f - }); - input.linspace(1); - auto size = NDArrayFactory::create({6, 8}); - sd::ops::resize_bicubic op; - auto results = op.evaluate({&input, &size}, {}, {}); +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, summaryStatsData_test1) { + functions::summarystats::SummaryStatsData var1; + functions::summarystats::SummaryStatsData var2; + var2.n = var2.mean = var2.M2 = var2.M3 = var2.M4 = var2.bias = 5; + + functions::summarystats::SummaryStatsData* arr = + new functions::summarystats::SummaryStatsData[2]; + arr[0] = var1; + arr[1] = var2; + arr[0] = arr[1]; - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + functions::summarystats::SummaryStatsData var3(var1); - NDArray* result = results.at(0); + ASSERT_TRUE(arr[0].n == arr[0].mean && arr[0].M2 == arr[0].M3 && + arr[0].n == 5); + ASSERT_TRUE(arr[1].n == arr[1].mean && arr[1].M2 == arr[1].M3 && + arr[1].n == 5); + ASSERT_TRUE(var3.n == var3.mean && var3.M2 == var3.M3 && var3.n == 0); -// result.printBuffer("Resized to 6x8"); -// expected.printBuffer("Expect for 6x8"); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); + delete[] arr; } -TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test5) { +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, Solve_Test_1) { + auto a = NDArrayFactory::create( + 'c', {3, 3}, {2.f, -1.f, -2.f, -4.f, 6.f, 3.f, -4.f, -2.f, 8.f}); - NDArray input = NDArrayFactory::create('c', {1, 4, 4, 3}); - NDArray expected = NDArrayFactory::create('c', {1, 8, 8, 3}, { - 1.000000f, 2.000000f, 3.000000f, 2.218750f, 3.218750f, 4.218750f, 4.000000f, 5.000000f, 6.000000f, - 5.500000f, 6.500000f, 7.500000f, 7.000000f, 8.000000f, 9.000000f, 8.781250f, 9.781250f, 10.781250f, - 10.000000f, 11.000000f, 12.000000f, 10.281250f, 11.281250f, 12.281250f, 5.875000f, 6.875000f, 7.875000f, - 7.093750f, 8.093750f, 9.093750f, 8.875000f, 9.875000f, 10.875000f, 10.375000f, 11.375000f, 12.375000f, - 11.875000f, 12.875000f, 13.875000f, 13.656250f, 14.656250f, 15.656250f, 14.875000f, 15.875000f, 16.875000f, - 15.156250f, 16.156250f, 17.156250f, 13.000000f, 14.000000f, 15.000000f, 14.218750f, 15.218750f, 16.218750f, - 16.000000f, 17.000000f, 18.000000f, 17.500000f, 18.500000f, 19.500000f, 19.000000f, 20.000000f, 21.000000f, - 20.781250f, 21.781250f, 22.781250f, 22.000000f, 23.000000f, 24.000000f, 22.281250f, 23.281250f, 24.281250f, - 19.000000f, 20.000000f, 21.000000f, 20.218750f, 21.218750f, 22.218750f, 22.000000f, 23.000000f, 24.000000f, - 23.500000f, 24.500000f, 25.500000f, 25.000000f, 26.000000f, 27.000000f, 26.781250f, 27.781250f, 28.781250f, - 28.000000f, 29.000000f, 30.000000f, 28.281250f, 29.281250f, 30.281250f, 25.000000f, 26.000000f, 27.000000f, - 26.218750f, 27.218750f, 28.218750f, 28.000000f, 29.000000f, 30.000000f, 29.500000f, 30.500000f, 31.500000f, - 31.000000f, 32.000000f, 33.000000f, 32.781250f, 33.781250f, 34.781250f, 34.000000f, 35.000000f, 36.000000f, - 34.281250f, 35.281250f, 36.281250f, 32.125000f, 33.125000f, 34.125000f, 33.343750f, 34.343750f, 35.343750f, - 35.125000f, 36.125000f, 37.125000f, 36.625000f, 37.625000f, 38.625000f, 38.125000f, 39.125000f, 40.125000f, - 39.906250f, 40.906250f, 41.906250f, 41.125000f, 42.125000f, 43.125000f, 41.406250f, 42.406250f, 43.406250f, - 37.000000f, 38.000000f, 39.000000f, 38.218750f, 39.218750f, 40.218750f, 40.000000f, 41.000000f, 42.000000f, - 41.500000f, 42.500000f, 43.500000f, 43.000000f, 44.000000f, 45.000000f, 44.781250f, 45.781250f, 46.781250f, - 46.000000f, 47.000000f, 48.000000f, 46.281250f, 47.281250f, 48.281250f, 38.125000f, 39.125000f, 40.125000f, - 39.343750f, 40.343750f, 41.343750f, 41.125000f, 42.125000f, 43.125000f, 42.625000f, 43.625000f, 44.625000f, - 44.125000f, 45.125000f, 46.125000f, 45.906250f, 46.906250f, 47.906250f, 47.125000f, 48.125000f, 49.125000f, - 47.406250f, 48.406250f, 49.406250f, - }); - input.linspace(1); - auto size = NDArrayFactory::create({8, 8}); - sd::ops::resize_bicubic op; - auto results = op.evaluate({&input, &size}, {}, {}); + auto b = NDArrayFactory::create('c', {3, 1}, {2.f, 4.f, 3.f}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto exp = NDArrayFactory::create('c', {3, 1}, {7.625f, 3.25f, 5.f}); - NDArray* result = results.at(0); + sd::ops::solve op; -// result.printBuffer("Resized to 8x8"); -// expected.printBuffer("Expect for 8x8"); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); -} + auto res = op.evaluate({&a, &b}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); -TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test6) { + // z->printIndexedBuffer("Solve of 3x3"); - NDArray input = NDArrayFactory::create('c', {7, 7, 1}, { - 1.f, 2.1f, 3.15f, 4.2f, 5.15f, 6.1f, 7.f, - 8.f, 9.1f, 10.f, 11.f, 12.9f, 13.1f, 14.f, - 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, 21.f, - 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 28.f, - 30.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, - 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 43.f, - 44.f, 45.f, 46.f, 47.f, 48.f, 49.f, 50.f - }); + ASSERT_TRUE(exp.equalsTo(z)); +} - NDArray expected = NDArrayFactory::create('c', {30, 30, 1}, { - 1.000000f, 1.197616f, 1.417436f, 1.677577f, 1.996158f, 2.328327f, 2.550918f, 2.736061f, 2.965541f, - 3.292965f, 3.544151f, 3.738035f, 3.948995f, 4.248106f, 4.507379f, 4.684374f, 4.857284f, 5.104302f, - 5.386991f, 5.581401f, 5.753962f, 5.974285f, 6.272836f, 6.520426f, 6.718899f, 6.887104f, 7.039068f, - 7.099216f, 7.078424f, 7.028189f, 2.247592f, 2.446947f, 2.669489f, 2.931238f, 3.248216f, 3.574534f, - 3.789310f, 3.965697f, 4.186417f, 4.504653f, 4.740569f, 4.921706f, 5.133866f, 5.459533f, 5.774461f, - 6.019787f, 6.254011f, 6.535633f, 6.809730f, 6.960779f, 7.074942f, 7.241601f, 7.509489f, 7.749949f, - 7.954571f, 8.131972f, 8.286526f, 8.346463f, 8.325745f, 8.275683f, 3.628684f, 3.830573f, 4.056959f, - 4.321157f, 4.636486f, 4.955650f, 5.160583f, 5.325847f, 5.535462f, 5.842160f, 6.058749f, 6.223753f, - 6.437597f, 6.797369f, 7.183604f, 7.516402f, 7.829034f, 8.154773f, 8.417635f, 8.512958f, 8.552100f, - 8.649708f, 8.877880f, 9.108794f, 9.320926f, 9.509781f, 9.667375f, 9.726940f, 9.706349f, 9.656599f, - 5.276778f, 5.480438f, 5.709702f, 5.975448f, 6.288551f, 6.600570f, 6.796207f, 6.951142f, 7.150400f, - 7.446143f, 7.644651f, 7.794562f, 8.009684f, 8.400473f, 8.851847f, 9.264690f, 9.649218f, 10.015648f, - 10.268647f, 10.313368f, 10.284327f, 10.319379f, 10.512033f, 10.734956f, 10.954604f, 11.154507f, 11.315369f, - 11.374779f, 11.354242f, 11.304622f, 7.325373f, 7.528484f, 7.757575f, 8.022221f, 8.331997f, 8.638187f, - 8.827649f, 8.976217f, 9.168955f, 9.457260f, 9.644237f, 9.784517f, 9.999621f, 10.407702f, 10.896234f, - 11.355122f, 11.781423f, 12.172186f, 12.420712f, 12.437449f, 12.370511f, 12.371386f, 12.545973f, 12.766424f, - 12.992249f, 13.200120f, 13.364252f, 13.424109f, 13.403420f, 13.353425f, 9.493208f, 9.692467f, 9.916944f, - 10.176801f, 10.482199f, 10.785470f, 10.974367f, 11.123442f, 11.316370f, 11.603645f, 11.790616f, 11.930889f, - 12.144082f, 12.546447f, 13.024898f, 13.472300f, 13.889232f, 14.276275f, 14.528972f, 14.555555f, 14.501450f, - 14.515459f, 14.700572f, 14.927055f, 15.156046f, 15.366046f, 15.532901f, 15.594008f, 15.572885f, 15.521847f, - 10.970133f, 11.163599f, 11.380694f, 11.633735f, 11.935032f, 12.238887f, 12.432540f, 12.588294f, 12.787534f, - 13.079956f, 13.277520f, 13.426631f, 13.636713f, 14.013844f, 14.441672f, 14.827978f, 15.191209f, 15.549808f, - 15.813430f, 15.881828f, 15.883522f, 15.950411f, 16.169330f, 16.407940f, 16.636436f, 16.842583f, 17.010887f, - 17.073630f, 17.051940f, 16.999537f, 12.219155f, 12.406129f, 12.614796f, 12.860335f, 13.157928f, 13.464224f, - 13.665207f, 13.830567f, 14.039036f, 14.339629f, 14.552863f, 14.715049f, 14.921564f, 15.264454f, 15.622843f, - 15.924977f, 16.213829f, 16.532364f, 16.809900f, 16.934835f, 17.012146f, 17.150164f, 17.413412f, 17.666712f, - 17.892765f, 18.092070f, 18.261044f, 18.325531f, 18.303238f, 18.249378f, 13.766397f, 13.947391f, 14.148263f, - 14.386917f, 14.681246f, 14.990087f, 15.198166f, 15.372728f, 15.590062f, 15.898583f, 16.126892f, 16.301655f, - 16.504870f, 16.815214f, 17.107498f, 17.329458f, 17.547403f, 17.827654f, 18.118288f, 18.296928f, 18.446100f, - 18.651634f, 18.956806f, 19.223820f, 19.447308f, 19.639887f, 19.809319f, 19.875397f, 19.852556f, 19.797365f, - 15.941937f, 16.118704f, 16.314133f, 16.547867f, 16.839561f, 17.149540f, 17.361883f, 17.542162f, 17.764957f, - 18.078188f, 18.315733f, 18.498205f, 18.699116f, 18.988684f, 19.238989f, 19.410137f, 19.583265f, 19.839512f, - 20.138780f, 20.351770f, 20.546844f, 20.795671f, 21.128067f, 21.404358f, 21.626736f, 21.815500f, 21.985610f, - 22.052843f, 22.029604f, 21.973448f, 17.535220f, 17.710770f, 17.904636f, 18.136950f, 18.427840f, 18.738056f, - 18.951529f, 19.133352f, 19.357613f, 19.672083f, 19.912102f, 20.096638f, 20.296894f, 20.580765f, 20.819603f, - 20.976887f, 21.137802f, 21.387535f, 21.689209f, 21.911621f, 22.119276f, 22.379990f, 22.719910f, 22.998823f, - 23.220970f, 23.408760f, 23.579110f, 23.646685f, 23.623325f, 23.566887f, 18.746353f, 18.922657f, 19.117487f, - 19.350685f, 19.642070f, 19.952137f, 20.164913f, 20.345781f, 20.569134f, 20.882840f, 21.121330f, 21.304590f, - 21.505253f, 21.792645f, 22.038572f, 22.204426f, 22.372890f, 22.626648f, 22.926834f, 23.143423f, 23.343302f, - 23.596668f, 23.931936f, 24.209232f, 24.431519f, 24.619913f, 24.790110f, 24.857473f, 24.834190f, 24.777927f, - 20.166560f, 20.344206f, 20.540766f, 20.775532f, 21.067804f, 21.377607f, 21.589132f, 21.768297f, 21.990030f, - 22.302366f, 22.538124f, 22.719105f, 22.920494f, 23.214176f, 23.472767f, 23.653934f, 23.835890f, 24.096842f, - 24.394371f, 24.600555f, 24.786541f, 25.026773f, 25.353731f, 25.628130f, 25.850672f, 26.040140f, 26.210072f, - 26.277063f, 26.253906f, 26.197956f, 22.363024f, 22.541250f, 22.738552f, 22.973991f, 23.266647f, 23.576340f, - 23.787327f, 23.965760f, 24.186796f, 24.498543f, 24.733124f, 24.913122f, 25.114826f, 25.411213f, 25.675262f, - 25.863028f, 26.050789f, 26.314838f, 26.611223f, 26.812925f, 26.992926f, 27.227505f, 27.550882f, 27.824034f, - 28.046684f, 28.236614f, 28.406433f, 28.473265f, 28.450163f, 28.394344f, 24.429443f, 24.607670f, 24.804970f, - 25.040410f, 25.333065f, 25.642756f, 25.853743f, 26.032173f, 26.253210f, 26.564959f, 26.799540f, 26.979540f, - 27.181242f, 27.477630f, 27.741680f, 27.929441f, 28.117207f, 28.381254f, 28.677637f, 28.879343f, 29.059345f, - 29.293922f, 29.617298f, 29.890451f, 30.113104f, 30.303034f, 30.472853f, 30.539684f, 30.516582f, 30.460762f, - 26.000000f, 26.178228f, 26.375526f, 26.610970f, 26.903624f, 27.213314f, 27.424305f, 27.602734f, 27.823772f, - 28.135519f, 28.370100f, 28.550098f, 28.751800f, 29.048190f, 29.312237f, 29.500000f, 29.687763f, 29.951813f, - 30.248200f, 30.449903f, 30.629902f, 30.864483f, 31.187859f, 31.461012f, 31.683659f, 31.873592f, 32.043407f, - 32.110240f, 32.087135f, 32.031320f, 27.570559f, 27.748787f, 27.946087f, 28.181528f, 28.474184f, 28.783876f, - 28.994865f, 29.173294f, 29.394330f, 29.706080f, 29.940659f, 30.120655f, 30.322360f, 30.618746f, 30.882797f, - 31.070557f, 31.258320f, 31.522371f, 31.818754f, 32.020460f, 32.200460f, 32.435040f, 32.758415f, 33.031567f, - 33.254220f, 33.444150f, 33.613964f, 33.680794f, 33.657696f, 33.601880f, 29.636976f, 29.815207f, 30.012500f, - 30.247944f, 30.540600f, 30.850290f, 31.061283f, 31.239712f, 31.460750f, 31.772500f, 32.007080f, 32.187077f, - 32.388780f, 32.685165f, 32.949215f, 33.136980f, 33.324740f, 33.588790f, 33.885178f, 34.086884f, 34.266880f, - 34.501457f, 34.824837f, 35.097990f, 35.320637f, 35.510574f, 35.680390f, 35.747215f, 35.724117f, 35.668300f, - 31.833440f, 32.011665f, 32.208970f, 32.444412f, 32.737070f, 33.046757f, 33.257744f, 33.436176f, 33.657207f, - 33.968960f, 34.203537f, 34.383537f, 34.585240f, 34.881630f, 35.145676f, 35.333440f, 35.521206f, 35.785255f, - 36.081642f, 36.283340f, 36.463340f, 36.697920f, 37.021297f, 37.294453f, 37.517097f, 37.707027f, 37.876846f, - 37.943680f, 37.920578f, 37.864758f, 33.253647f, 33.431873f, 33.629170f, 33.864613f, 34.157270f, 34.466957f, - 34.677948f, 34.856377f, 35.077415f, 35.389160f, 35.623745f, 35.803745f, 36.005447f, 36.301834f, 36.565884f, - 36.753647f, 36.941406f, 37.205456f, 37.501840f, 37.703545f, 37.883545f, 38.118122f, 38.441500f, 38.714653f, - 38.937300f, 39.127235f, 39.297054f, 39.363884f, 39.340782f, 39.284960f, 34.464783f, 34.643010f, 34.840305f, - 35.075752f, 35.368404f, 35.678100f, 35.889088f, 36.067516f, 36.288550f, 36.600300f, 36.834885f, 37.014877f, - 37.216583f, 37.512970f, 37.777020f, 37.964783f, 38.152546f, 38.416595f, 38.712980f, 38.914684f, 39.094685f, - 39.329260f, 39.652645f, 39.925793f, 40.148440f, 40.338375f, 40.508194f, 40.575024f, 40.551920f, 40.496105f, - 36.058067f, 36.236290f, 36.433590f, 36.669033f, 36.961685f, 37.271378f, 37.482370f, 37.660800f, 37.881836f, - 38.193590f, 38.428170f, 38.608162f, 38.809868f, 39.106250f, 39.370300f, 39.558064f, 39.745830f, 40.009880f, - 40.306267f, 40.507970f, 40.687970f, 40.922550f, 41.245926f, 41.519077f, 41.741722f, 41.931652f, 42.101475f, - 42.168304f, 42.145203f, 42.089386f, 38.315002f, 38.493233f, 38.690533f, 38.925976f, 39.218628f, 39.528320f, - 39.739307f, 39.917736f, 40.138775f, 40.450520f, 40.685104f, 40.865097f, 41.066803f, 41.363190f, 41.627243f, - 41.815002f, 42.002766f, 42.266820f, 42.563200f, 42.764908f, 42.944904f, 43.179485f, 43.502860f, 43.776016f, - 43.998665f, 44.188595f, 44.358418f, 44.425247f, 44.402145f, 44.346330f, 40.227080f, 40.405310f, 40.602608f, - 40.838050f, 41.130707f, 41.440395f, 41.651382f, 41.829820f, 42.050854f, 42.362600f, 42.597183f, 42.777180f, - 42.978880f, 43.275270f, 43.539320f, 43.727080f, 43.914845f, 44.178894f, 44.475280f, 44.676983f, 44.856983f, - 45.091560f, 45.414940f, 45.688090f, 45.910740f, 46.100674f, 46.270493f, 46.337322f, 46.314220f, 46.258400f, - 41.785618f, 41.963844f, 42.161144f, 42.396584f, 42.689240f, 42.998936f, 43.209923f, 43.388355f, 43.609394f, - 43.921143f, 44.155720f, 44.335716f, 44.537420f, 44.833805f, 45.097860f, 45.285614f, 45.473377f, 45.737427f, - 46.033817f, 46.235523f, 46.415524f, 46.650105f, 46.973476f, 47.246630f, 47.469276f, 47.659210f, 47.829030f, - 47.895855f, 47.872753f, 47.816940f, 43.115140f, 43.293365f, 43.490665f, 43.726105f, 44.018764f, 44.328457f, - 44.539444f, 44.717873f, 44.938910f, 45.250660f, 45.485240f, 45.665237f, 45.866940f, 46.163326f, 46.427376f, - 46.615143f, 46.802902f, 47.066956f, 47.363342f, 47.565050f, 47.745050f, 47.979626f, 48.302998f, 48.576153f, - 48.798798f, 48.988730f, 49.158546f, 49.225376f, 49.202282f, 49.146458f, 44.303867f, 44.482094f, 44.679394f, - 44.914833f, 45.207493f, 45.517180f, 45.728170f, 45.906600f, 46.127640f, 46.439384f, 46.673965f, 46.853966f, - 47.055668f, 47.352055f, 47.616100f, 47.803867f, 47.991630f, 48.255680f, 48.552063f, 48.753770f, 48.933773f, - 49.168350f, 49.491726f, 49.764877f, 49.987526f, 50.177460f, 50.347275f, 50.414100f, 50.391006f, 50.335186f, - 44.771675f, 44.949905f, 45.147200f, 45.382645f, 45.675300f, 45.984990f, 46.195976f, 46.374413f, 46.595448f, - 46.907196f, 47.141773f, 47.321774f, 47.523476f, 47.819862f, 48.083910f, 48.271680f, 48.459446f, 48.723490f, - 49.019882f, 49.221580f, 49.401585f, 49.636160f, 49.959538f, 50.232693f, 50.455338f, 50.645270f, 50.815090f, - 50.881920f, 50.858818f, 50.803000f, 44.609966f, 44.788193f, 44.985493f, 45.220936f, 45.513590f, 45.823280f, - 46.034270f, 46.212700f, 46.433743f, 46.745490f, 46.980070f, 47.160065f, 47.361770f, 47.658157f, 47.922207f, - 48.109970f, 48.297733f, 48.561783f, 48.858166f, 49.059875f, 49.239872f, 49.474450f, 49.797830f, 50.070980f, - 50.293625f, 50.483560f, 50.653378f, 50.720203f, 50.697100f, 50.641280f, 44.219246f, 44.397472f, 44.594772f, - 44.830210f, 45.122868f, 45.432560f, 45.643543f, 45.821980f, 46.043020f, 46.354763f, 46.589344f, 46.769340f, - 46.971046f, 47.267433f, 47.531483f, 47.719242f, 47.907005f, 48.171050f, 48.467438f, 48.669140f, 48.849144f, - 49.083720f, 49.407100f, 49.680256f, 49.902905f, 50.092834f, 50.262653f, 50.329483f, 50.306380f, 50.250570f - }); +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, Solve_Test_2) { + auto a = NDArrayFactory::create('c', {4, 4}, + { + 1.f, + 1.f, + 1.f, + 1.f, + 0.f, + 1.f, + 1.f, + 0.f, + 0.f, + 0.f, + 2.f, + 1.f, + 0.f, + 0.f, + 0.f, + 3.f, + }); + + auto b = NDArrayFactory::create('c', {4, 1}, {2.f, 4.f, 2.f, 4.f}); + + auto exp = NDArrayFactory::create( + 'c', {4, 1}, {-3.3333333f, 3.6666666f, 0.333333f, 1.3333333f}); + + sd::ops::solve op; + + auto res = op.evaluate({&a, &b}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); + + // z->printIndexedBuffer("Solve 4x4"); + + ASSERT_TRUE(exp.equalsTo(z)); +} +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, Solve_Test_3) { + auto a = NDArrayFactory::create( + 'c', {2, 4, 4}, + {1.f, 1.f, 1.f, 1.f, 0.f, 1.f, 1.f, 0.f, 0.f, + 0.f, 2.f, 1.f, 0.f, 0.f, 0.f, 3.f, - auto size = NDArrayFactory::create({30, 30}); - sd::ops::resize_bicubic op; - auto results = op.evaluate({&input, &size}, {}, {}); + 3.f, 0.f, 0.f, 0.f, 2.f, 1.f, 0.f, 0.f, 1.f, + 0.f, 1.f, 0.f, 1.f, 1.f, 1.f, 1.f - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - NDArray* result = results.at(0); + }); -// result.printBuffer("Resized to 30x30"); -// expected.printBuffer("Expect for 30x30"); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); -} + auto b = NDArrayFactory::create( + 'c', {2, 4, 1}, {2.f, 4.f, 2.f, 4.f, 4.f, 2.f, 4.f, 2.f}); -TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test7) { + auto exp = NDArrayFactory::create( + 'c', {2, 4, 1}, + {-3.3333333f, 3.6666666f, 0.333333f, 1.3333333f, 1.333333f, -0.6666667f, + 2.6666667f, -1.3333333f}); - NDArray input = NDArrayFactory::create('c', {2, 5, 5, 1}, { - 0.2303, 0.7950, 0.8171, 0.0451, 0.3690, 0.6846, 0.2727, 0.2770, 0.2381, 0.9511, - 0.4116, 0.3997, 0.4075, 0.6275, 0.8018, 0.0678, 0.6221, 0.2982, 0.1524, 0.2613, - 0.7425, 0.6036, 0.7926, 0.5838, 0.1361, 0.4154, 0.3634, 0.3741, 0.2088, 0.2989, - 0.3982, 0.5618, 0.7266, 0.1089, 0.2922, 0.3306, 0.2869, 0.6638, 0.3091, 0.9312, - 0.0240, 0.2893, 0.5632, 0.9625, 0.4189, 0.3854, 0.2743, 0.6754, 0.8820, 0.8699}); - - NDArray expected = NDArrayFactory::create('c', {2, 9, 9, 1}, { - 0.2303f, 0.54569f, 0.840649f, 0.92725444f, 0.65660673f, - 0.16641647f, 0.06117659f, 0.33279106f, 0.4023279f, 0.5139505f, - 0.49821317f, 0.4906872f, 0.537642f, 0.4070102f, 0.13030615f, - 0.258801f, 0.65352744f, 0.773368f, 0.69225276f, 0.44177493f, - 0.21910316f, 0.22368976f, 0.24221404f, 0.21399781f, 0.5114972f, - 0.9169859f, 1.0511527f, 0.5608501f, 0.41315168f, 0.2913824f, - 0.2966933f, 0.38585684f, 0.48849702f, 0.71013063f, 0.9086001f, - 0.9794303f, 0.29625386f, 0.39427578f, 0.45971435f, 0.39693952f, - 0.40860707f, 0.51061106f, 0.6181093f, 0.67309624f, 0.69564015f, - 0.06012487f, 0.3863805f, 0.58993465f, 0.40679216f, 0.22607432f, - 0.20093678f, 0.25901243f, 0.3615362f, 0.39371052f, 0.24176767f, - 0.4868709f, 0.650651f, 0.5493148f, 0.3825456f, 0.27788478f, - 0.18927254f, 0.16692996f, 0.15432167f, 0.677519f, 0.6236242f, - 0.61700624f, 0.7214321f, 0.7307374f, 0.6251454f, 0.3924176f, - 0.17802659f, 0.10231908f, 0.81192374f, 0.66878575f, 0.6118803f, - 0.7797006f, 0.8396968f, 0.72889954f, 0.44547448f, 0.16794783f, - 0.07125802f, 0.4154f, 0.38504714f, 0.3623221f, 0.3862173f, - 0.3397379f, 0.23285517f, 0.21876639f, 0.2892362f, 0.30817088f, - 0.41268015f, 0.45587808f, 0.51991886f, 0.60977113f, 0.49489656f, - 0.21313031f, 0.11297428f, 0.2167207f, 0.23940037f, 0.39337245f, - 0.46112412f, 0.583034f, 0.76207364f, 0.6326203f, 0.22189438f, - 0.12071565f, 0.3275853f, 0.3794855f, 0.38497013f, 0.35049653f, - 0.41895086f, 0.671095f, 0.62119365f, 0.22362521f, 0.30189657f, - 0.72530353f, 0.85048175f, 0.2524255f, 0.2182264f, 0.2964637f, - 0.5361996f, 0.6255393f, 0.46424767f, 0.5741281f, 0.8408146f, - 0.92403257f, 0.04648584f, 0.14959256f, 0.32215607f, 0.46194845f, - 0.6642166f, 0.83560026f, 0.7663391f, 0.5284251f, 0.4573109f, - 0.10357999f, 0.17442937f, 0.32116935f, 0.45530772f, 0.7163773f, - 0.9856574f, 0.8976148f, 0.5538923f, 0.45173654f, 0.34958175f, - 0.2680429f, 0.30470955f, 0.51233786f, 0.75128907f, 0.86736864f, - 0.8982046f, 0.83254474f, 0.8168574f, 0.4225865f, 0.2956836f, - 0.29948136f, 0.5276342f, 0.76461166f, 0.8442875f, 0.907862f, - 0.9139262f, 0.92068815f - }); - auto size = NDArrayFactory::create({9, 9}); - sd::ops::resize_bicubic op; - auto results = op.evaluate({&input, &size}, {}, {}); + sd::ops::solve op; - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto res = op.evaluate({&a, &b}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); - NDArray* result = results.at(0); + // z->printIndexedBuffer("Solve 4x4"); -// result.printBuffer("Resized to 9x9"); -// expected.printBuffer("Expect for 9x9"); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); + ASSERT_TRUE(exp.equalsTo(z)); } -TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test8) { +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, Solve_Test_4) { + auto a = NDArrayFactory::create( + 'c', {2, 2, 2}, + {0.7788f, 0.8012f, 0.7244f, 0.2309f, 0.7271f, 0.1804f, 0.5056f, 0.8925f}); - NDArray input = NDArrayFactory::create('c', {2, 5, 5, 1}, { - 0.23028551377579154, 0.7949972231516509, 0.8171307820461517, 0.04507309923418412, 0.3689673597428338, - 0.6845757584903018, 0.27268547668219667, 0.2770196372806053, 0.2381478370531429, 0.9511201914609859, - 0.41160882670429033, 0.3997152563642703, 0.4074505147711718, 0.6274595060113246, 0.8017922711300232, - 0.06782045852179475, 0.6220772280691722, 0.2982335327629251, 0.1523603480424196, 0.2612986044295986, - 0.7424762244324299, 0.6036156464824591, 0.7926371071102005, 0.5838270656432538, 0.13607200219168547, - 0.4154002170215956, 0.36340617544852116, 0.37405031188276827, 0.20880251686544882, 0.298919946410666, - 0.39820758164277126, 0.5617728968896589, 0.72660225993937, 0.10888245916813699, 0.29215797784445496, - 0.3305531351746034, 0.28693451964931715, 0.6637635348315494, 0.30913418229827583, 0.9312186188801752, - 0.0239594182399363, 0.2892942758780874, 0.5631691110629038, 0.9625499752246309, 0.4189439089689968, - 0.3854304088214935, 0.27426304203925045, 0.6754051704648238, 0.8820362490795286, 0.8699337744328859}); - - - auto testData = NDArrayFactory::create('c', {2,9,9,1}, { - 0.230286514f, 0.510566354f, 0.794997215f, 0.931386113f, 0.817130804f, 0.402811885f, 0.045073099f, 0.134639814f, 0.368967354f, - 0.483021289f, 0.501266003f, 0.521932304f, 0.572325349f, 0.534847379f, 0.267853439f, 0.105112493f, 0.349290252f, 0.674043298f, - 0.684575737f, 0.478224277f, 0.272685468f, 0.239882097f, 0.27701965f, 0.191148892f, 0.23814784f, 0.590989769f, 0.951120198f, - 0.622912169f, 0.441326082f, 0.266387194f, 0.232538164f, 0.301838756f, 0.356378645f, 0.495445013f, 0.756725252f, 0.981704295f, - 0.411608815f, 0.40493685f, 0.399715245f, 0.381842017f, 0.407450527f, 0.501836538f, 0.627459526f, 0.735251725f, 0.801792264f, - 0.150875032f, 0.357000858f, 0.524536073f, 0.450354964f, 0.318719596f, 0.319606483f, 0.385957927f, 0.46392554f, 0.529285908f, - 0.06782046f, 0.375309169f, 0.622077227f, 0.525792599f, 0.298233539f, 0.184723631f, 0.15236035f, 0.193153858f, 0.261298597f, - - 0.372918189f, 0.512539625f, 0.63369292f, 0.628733814f, 0.535196245f, 0.436597466f, 0.323553175f, 0.215942055f, 0.148014024f, - 0.742476225f, 0.655325174f, 0.603615642f, 0.704684138f, 0.79263711f, 0.747929871f, 0.583827078f, 0.340373576f, 0.136071995f, - 0.415400207f, 0.388405323f, 0.363406181f, 0.379345775f, 0.374050319f, 0.28397581f, 0.208802521f, 0.238369256f, 0.298919946f, - 0.413146496f, 0.444389015f, 0.488355637f, 0.568351328f, 0.556217432f, 0.345546633f, 0.140068889f, 0.148834035f, 0.23562704f, - 0.398207575f, 0.464537472f, 0.561772883f, 0.717433035f, 0.726602256f, 0.416013002f, 0.108882457f, 0.142608985f, 0.292157978f, - 0.391511708f, 0.389470309f, 0.442729384f, 0.651181757f, 0.737665415f, 0.41685915f, 0.138383076f, 0.342548877f, 0.659080088f, - - 0.330553144f, 0.273416102f, 0.286934525f, 0.50450629f, 0.663763523f, 0.463456154f, 0.309134185f, 0.586929917f, 0.931218624f, - 0.137025774f, 0.169145152f, 0.263757467f, 0.436182201f, 0.597053051f, 0.657990932f, 0.662163854f, 0.68354249f, 0.692712903f, - 0.023959421f, 0.130951077f, 0.289294273f, 0.413664877f, 0.563169122f, 0.839498401f, 0.962549984f, 0.728188932f, 0.418943912f, - 0.175951749f, 0.198239252f, 0.281999886f, 0.420836329f, 0.609856486f, 0.863734365f, 0.983550847f, 0.825015843f, 0.596413136f, - 0.385430396f, 0.292239636f, 0.274263054f, 0.445040524f, 0.675405145f, 0.817462444f, 0.882036269f, 0.895356655f, 0.869933784f - }); + auto b = NDArrayFactory::create( + 'c', {2, 2, 2}, + {0.7717f, 0.9281f, 0.9846f, 0.4838f, 0.6433f, 0.6041f, 0.6501f, 0.7612f}); - auto size = NDArrayFactory::create({9, 9}); - sd::ops::resize_bicubic op; - auto results = op.evaluate({&input, &size}, {}, {}, {true, false}); + auto exp = NDArrayFactory::create( + 'c', {2, 2, 2}, + {// 1.524494767f, 0.432706356f,-0.518630624f, 0.737760842f, + // 0.819143713f, 0.720401764f, 0.264349997f, 0.444699198f + 1.5245394f, 0.4326952f, -0.51873577f, 0.7377896f, 0.81915987f, + 0.72049433f, 0.2643504f, 0.44472617f}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::solve op; - NDArray* result = results.at(0); + auto res = op.evaluate({&a, &b}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); -// result.printBuffer("Resized to 9x9"); -// testData.printBuffer("Expect for 9x9"); - ASSERT_TRUE(testData.isSameShape(result)); - ASSERT_TRUE(testData.equalsTo(result)); + // z->printBuffer("4 Solve 4x4"); + // exp.printBuffer("4 Expec 4x4"); + + ASSERT_TRUE(exp.equalsTo(z)); } +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, Solve_Test_4_1) { + auto a = NDArrayFactory::create( + 'c', {2, 2, 2}, + {0.7788f, 0.8012f, 0.7244f, 0.2309f, 0.7271f, 0.1804f, 0.5056f, 0.8925f}); + auto b = NDArrayFactory::create( + 'c', {2, 2, 2}, + {0.7717f, 0.9281f, 0.9846f, 0.4838f, 0.6433f, 0.6041f, 0.6501f, 0.7612f}); -TEST_F(DeclarableOpsTests11, ImageResizeArea_Test1) { + auto exp = NDArrayFactory::create( + 'c', {2, 2, 2}, + {1.3357621f, 0.3399364f, -0.37077796f, 0.91573375f, 0.4400987f, + 0.2766527f, 0.6394467f, 0.79696566f}); - NDArray input = NDArrayFactory::create('c', {1, 3, 3, 4}); - NDArray expected = NDArrayFactory::create('c', {1, 6, 6, 4}, { - 1.f, 2.f, 3.f, 4.f, - 1.f, 2.f, 3.f, 4.f, - 5.f, 6.f, 7.f, 8.f, - 5.f, 6.f, 7.f, 8.f, - 9.f, 10.f, 11.f, 12.f, - 9.f, 10.f, 11.f, 12.f, - - 1.f, 2.f, 3.f, 4.f, - 1.f, 2.f, 3.f, 4.f, - 5.f, 6.f, 7.f, 8.f, - 5.f, 6.f, 7.f, 8.f, - 9.f, 10.f, 11.f, 12.f, - 9.f, 10.f, 11.f, 12.f, - - 13.f, 14.f, 15.f, 16.f, - 13.f, 14.f, 15.f, 16.f, - 17.f, 18.f, 19.f, 20.f, - 17.f, 18.f, 19.f, 20.f, - 21.f, 22.f, 23.f, 24.f, - 21.f, 22.f, 23.f, 24.f, - - 13.f, 14.f, 15.f, 16.f, - 13.f, 14.f, 15.f, 16.f, - 17.f, 18.f, 19.f, 20.f, - 17.f, 18.f, 19.f, 20.f, - 21.f, 22.f, 23.f, 24.f, - 21.f, 22.f, 23.f, 24.f, - - 25.f, 26.f, 27.f, 28.f, - 25.f, 26.f, 27.f, 28.f, - 29.f, 30.f, 31.f, 32.f, - 29.f, 30.f, 31.f, 32.f, - 33.f, 34.f, 35.f, 36.f, - 33.f, 34.f, 35.f, 36.f, - - 25.f, 26.f, 27.f, 28.f, - 25.f, 26.f, 27.f, 28.f, - 29.f, 30.f, 31.f, 32.f, - 29.f, 30.f, 31.f, 32.f, - 33.f, 34.f, 35.f, 36.f, - 33.f, 34.f, 35.f, 36.f }); - input.linspace(1); - auto size = NDArrayFactory::create({6, 6}); - sd::ops::resize_area op; - auto results = op.evaluate({&input, &size}, {}, {}); + sd::ops::solve op; - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto res = op.evaluate({&a, &b}, {true}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); - NDArray* result = results.at(0); + // z->printBuffer("4 Solve 4x4"); + // exp.printBuffer("4 Expec 4x4"); -// result.printBuffer("Area Resized to 6x6"); -// expected.printBuffer("Area Expect for 6x6"); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); + ASSERT_TRUE(exp.equalsTo(z)); } +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, Solve_Test_4_2) { + auto a = NDArrayFactory::create( + 'c', {3, 3}, + {0.7788f, 0.8012f, 0.7244f, 0.2309f, 0.7271f, 0.1804f, 0.5056f, 0.8925f, + 0.5461f}); + auto b = NDArrayFactory::create( + 'c', {3, 3}, + {0.7717f, 0.9281f, 0.9846f, 0.4838f, 0.6433f, 0.6041f, 0.6501f, 0.7612f, + 0.7605f}); -TEST_F(DeclarableOpsTests11, ImageResizeArea_Test2) { + auto exp = NDArrayFactory::create( + 'c', {3, 3}, + {0.99088347f, 1.1917052f, 1.2642528f, 0.35071516f, 0.50630623f, + 0.42935497f, -0.30013534f, -0.53690606f, -0.47959247f}); - NDArray input = NDArrayFactory::create('c', {1, 3, 3, 1}); - NDArray expected = NDArrayFactory::create('c', {1, 6, 6, 1}, { - 1.f, 1.f, 2.f, 2.f, 3.f, 3.f, - 1.f, 1.f, 2.f, 2.f, 3.f, 3.f, - 4.f, 4.f, 5.f, 5.f, 6.f, 6.f, - 4.f, 4.f, 5.f, 5.f, 6.f, 6.f, - 7.f, 7.f, 8.f, 8.f, 9.f, 9.f, - 7.f, 7.f, 8.f, 8.f, 9.f, 9.f - }); - input.linspace(1); - auto size = NDArrayFactory::create({6, 6}); - sd::ops::resize_area op; - auto results = op.evaluate({&input, &size}, {}, {}); + sd::ops::triangular_solve op; - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto res = op.evaluate({&a, &b}, {true, false}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); - NDArray* result = results.at(0); + // z->printBuffer("4_2 Triangular_Solve 3x3"); + // exp.printBuffer("4_2 Triangular_Expec 3x3"); -// result.printBuffer("Area Resized to 6x6"); -// expected.printBuffer("Area Expect for 6x6"); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); + ASSERT_TRUE(exp.equalsTo(z)); } +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, Solve_Test_4_3) { + auto a = NDArrayFactory::create( + 'c', {3, 3}, + {0.7788f, 0.8012f, 0.7244f, 0.2309f, 0.7271f, 0.1804f, 0.5056f, 0.8925f, + 0.5461f}); -TEST_F(DeclarableOpsTests11, ImageResizeArea_Test3) { + auto b = NDArrayFactory::create( + 'c', {3, 3}, + {0.7717f, 0.9281f, 0.9846f, 0.4838f, 0.6433f, 0.6041f, 0.6501f, 0.7612f, + 0.7605f}); - NDArray input = NDArrayFactory::create('c', {1, 3, 3, 3}); - NDArray expected = NDArrayFactory::create('c', {1, 6, 6, 3}, { - 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, - 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, - 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, - 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, - 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, - 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f - }); - input.linspace(1); - auto size = NDArrayFactory::create({6, 6}); - sd::ops::resize_area op; - auto results = op.evaluate({&input, &size}, {}, {}); + auto exp = NDArrayFactory::create( + 'c', {3, 3}, + {0.45400196f, 0.53174824f, 0.62064564f, -0.79585856f, -0.82621557f, + -0.87855506f, 1.1904413f, 1.3938838f, 1.3926021f}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::triangular_solve op; - NDArray* result = results.at(0); + auto res = op.evaluate({&a, &b}, {true, true}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); -// result.printBuffer("Area Resized to 6x6"); -// expected.printBuffer("Area Expect for 6x6"); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); + // z->printBuffer("4_3 Triangular_Solve 3x3"); + // exp.printBuffer("4_3 Triangular_Expec 3x3"); + + ASSERT_TRUE(exp.equalsTo(z)); } -TEST_F(DeclarableOpsTests11, ImageResizeArea_Test4) { +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, Solve_Test_4_4) { + auto a = NDArrayFactory::create( + 'c', {3, 3}, + {0.7788f, 0.8012f, 0.7244f, 0.2309f, 0.7271f, 0.1804f, 0.5056f, 0.8925f, + 0.5461f}); - NDArray input = NDArrayFactory::create('c', {2, 3, 3, 3}, { - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27, - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27 - }); + auto b = NDArrayFactory::create( + 'c', {3, 3}, + {0.7717f, 0.9281f, 0.9846f, 0.4838f, 0.6433f, 0.6041f, 0.6501f, 0.7612f, + 0.7605f}); - NDArray expected = NDArrayFactory::create('c', {2, 6, 6, 3}, { - 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, - 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, - 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, - 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, - 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, - 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, - - 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, - 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, - 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, - 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, - 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, - 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f - }); - //input.linspace(1); - auto size = NDArrayFactory::create({6, 6}); - sd::ops::resize_area op; - auto results = op.evaluate({&input, &size}, {}, {}); + auto exp = NDArrayFactory::create( + 'c', {3, 3}, + {0.8959121f, 1.6109066f, 1.7501404f, 0.49000582f, 0.66842675f, 0.5577021f, + -0.4398522f, -1.1899745f, -1.1392052f}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::solve op; - NDArray* result = results.at(0); + auto res = op.evaluate({&a, &b}, {false}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); -// result.printBuffer("Area Resized to 6x6"); -// expected.printBuffer("Area Expect for 6x6"); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); + // z->printBuffer("4_4 Solve 3x3"); + // exp.printBuffer("4_4 Expec 3x3"); + + ASSERT_TRUE(exp.equalsTo(z)); } -TEST_F(DeclarableOpsTests11, ImageResizeArea_Test5) { +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, Solve_Test_4_5) { + auto a = NDArrayFactory::create( + 'c', {3, 3}, + {0.7788f, 0.8012f, 0.7244f, 0.2309f, 0.7271f, 0.1804f, 0.5056f, 0.8925f, + 0.5461f}); - NDArray input = NDArrayFactory::create('c', {2, 3, 3, 3}, { - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27, - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27 - }); + auto b = NDArrayFactory::create( + 'c', {3, 3}, + {0.7717f, 0.9281f, 0.9846f, 0.4838f, 0.6433f, 0.6041f, 0.6501f, 0.7612f, + 0.7605f}); - NDArray expected = NDArrayFactory::create('c', {2, 6, 6, 3}, { - 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, - 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, - 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, - 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, - 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, - 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, - - 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, - 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, - 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, - 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, - 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, - 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f - }); - //input.linspace(1); - auto size = NDArrayFactory::create({6, 6}); - sd::ops::resize_area op; - auto results = op.evaluate({&input, &size}, {}, {}); + auto exp = NDArrayFactory::create( + 'c', {3, 3}, + {1.5504692f, 1.8953944f, 2.2765768f, 0.03399149f, 0.2883001f, 0.5377323f, + -0.8774802f, -1.2155888f, -1.8049058f}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::solve op; - NDArray* result = results.at(0); + auto res = op.evaluate({&a, &b}, {true, true}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); -// result.printBuffer("Area Resized to 6x6"); -// expected.printBuffer("Area Expect for 6x6"); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); + // z->printBuffer("4_5 Solve 3x3"); + // exp.printBuffer("4_5 Expec 3x3"); + + ASSERT_TRUE(exp.equalsTo(z)); } -TEST_F(DeclarableOpsTests11, ImageResizeArea_Test6) { +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, Solve_Test_4_6) { + auto a = NDArrayFactory::create( + 'c', {3, 3}, + {0.7788f, 0.8012f, 0.7244f, 0.2309f, 0.7271f, 0.1804f, 0.5056f, 0.8925f, + 0.5461f}); - NDArray input = NDArrayFactory::create('c', {2, 3, 3, 1}, { - 1, 2, 3, 4, 5, 6, 7, 8, 9, - 1, 2, 3, 4, 5, 6, 7, 8, 9 - }); + auto b = NDArrayFactory::create( + 'c', {3, 3}, + {0.7717f, 0.9281f, 0.9846f, 0.4838f, 0.6433f, 0.6041f, 0.6501f, 0.7612f, + 0.7605f}); - NDArray expected = NDArrayFactory::create('c', {2, 6, 6, 1}, { - 1.f, 1.f, 1.5f, 2.f, 2.f, 3.f, - 1.f, 1.f, 1.5f, 2.f, 2.f, 3.f, - 2.5f, 2.5f, 3.f, 3.5f, 3.5f, 4.5f, - 4.f, 4.f, 4.5f, 5.f, 5.f, 6.f, - 4.f, 4.f, 4.5f, 5.f, 5.f, 6.f, - 7.f, 7.f, 7.5f, 8.f, 8.f, 9.f, - - 1.f, 1.f, 1.5f, 2.f, 2.f, 3.f, - 1.f, 1.f, 1.5f, 2.f, 2.f, 3.f, - 2.5f, 2.5f, 3.f, 3.5f, 3.5f, 4.5f, - 4.f, 4.f, 4.5f, 5.f, 5.f, 6.f, - 4.f, 4.f, 4.5f, 5.f, 5.f, 6.f, - 7.f, 7.f, 7.5f, 8.f, 8.f, 9.f - }); - //input.linspace(1); - auto size = NDArrayFactory::create({6, 6}); - sd::ops::resize_area op; - auto results = op.evaluate({&input, &size}, {}, {}, {true}); + auto exp = NDArrayFactory::create( + 'c', {3, 3}, + {0.99088347f, 1.1917052f, 1.2642528f, -0.426483f, -0.42840624f, + -0.5622601f, 0.01692283f, -0.04538865f, -0.09868701f}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::triangular_solve op; - NDArray* result = results.at(0); + auto res = op.evaluate({&a, &b}, {false, true}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); -// result.printBuffer("Area Resized to 6x6"); -// expected.printBuffer("Area Expect for 6x6"); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); + // z->printBuffer("4_6 Solve 3x3"); + // exp.printBuffer("4_6 Expec 3x3"); + + ASSERT_TRUE(exp.equalsTo(z)); } +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, Solve_Test_4_7) { + auto a = NDArrayFactory::create( + 'c', {3, 3}, + {// 0.7788f, 0.2309f, 0.5056f, + // 0.8012f, 0.7271f, 0.8925f, + // 0.7244f, 0.1804f, 0.5461f -TEST_F(DeclarableOpsTests11, ImageResizeArea_Test7) { + 0.7788f, 0.2309f, 0.5056f, 0.8012f, 0.7271f, 0.8925f, 0.7244f, 0.1804f, + 0.5461f}); - NDArray input = NDArrayFactory::create('c', {2, 3, 3, 1}, { - 1, 2, 3, 4, 5, 6, 7, 8, 9, - 1, 2, 3, 4, 5, 6, 7, 8, 9 - }); + auto b = NDArrayFactory::create( + 'c', {3, 3}, + {0.7717f, 0.9281f, 0.9846f, 0.4838f, 0.6433f, 0.6041f, 0.6501f, 0.7612f, + 0.7605f}); - NDArray expected = NDArrayFactory::create('c', {2, 6, 6, 1}, { - 1.f, 1.f, 1.5f, 2.f, 2.f, 3.f, - 1.f, 1.f, 1.5f, 2.f, 2.f, 3.f, - 2.5f, 2.5f, 3.f, 3.5f, 3.5f, 4.5f, - 4.f, 4.f, 4.5f, 5.f, 5.f, 6.f, - 4.f, 4.f, 4.5f, 5.f, 5.f, 6.f, - 7.f, 7.f, 7.5f, 8.f, 8.f, 9.f, - - 1.f, 1.f, 1.5f, 2.f, 2.f, 3.f, - 1.f, 1.f, 1.5f, 2.f, 2.f, 3.f, - 2.5f, 2.5f, 3.f, 3.5f, 3.5f, 4.5f, - 4.f, 4.f, 4.5f, 5.f, 5.f, 6.f, - 4.f, 4.f, 4.5f, 5.f, 5.f, 6.f, - 7.f, 7.f, 7.5f, 8.f, 8.f, 9.f - }); - //input.linspace(1); -// auto size = NDArrayFactory::create({6, 6}); - sd::ops::resize_area op; - auto results = op.evaluate({&input}, {}, {6, 6}, {true}); + auto exp = NDArrayFactory::create( + 'c', {3, 3}, + {0.99088347f, 1.1917052f, 1.2642528f, -0.426483f, -0.42840624f, + -0.5622601f, 0.01692283f, -0.04538865f, -0.09868701f}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::triangular_solve op; - NDArray* result = results.at(0); + auto res = op.evaluate({&a, &b}, {true, false}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); -// result.printBuffer("Area Resized to 6x6"); -// expected.printBuffer("Area Expect for 6x6"); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); -} - -TEST_F(DeclarableOpsTests11, ImageResizeArea_Test8) { - - NDArray input = NDArrayFactory::create('c', {1, 3, 3, 1}, { - 1, 2, 3, 4, 5, 6, 7, 8, 9 - }); - - NDArray expected = NDArrayFactory::create('c', {1, 6, 6, 1}, { - 1.f, 1.f, 1.5f, 2.f, 2.f, 3.f, - 1.f, 1.f, 1.5f, 2.f, 2.f, 3.f, - 2.5f, 2.5f, 3.f, 3.5f, 3.5f, 4.5f, - 4.f, 4.f, 4.5f, 5.f, 5.f, 6.f, - 4.f, 4.f, 4.5f, 5.f, 5.f, 6.f, - 7.f, 7.f, 7.5f, 8.f, 8.f, 9.f - }); - //input.linspace(1); -// auto size = NDArrayFactory::create({6, 6}); - sd::ops::resize_area op; - auto results = op.evaluate({&input}, {}, {6, 6}, {true}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - NDArray* result = results.at(0); - -// result.printBuffer("Area Resized to 6x6"); -// expected.printBuffer("Area Expect for 6x6"); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); -} - -TEST_F(DeclarableOpsTests11, ResizeImages_Test8) { - - NDArray input = NDArrayFactory::create('c', {1, 3, 3, 1}, { - 1, 2, 3, 4, 5, 6, 7, 8, 9 - }); - - NDArray expected = NDArrayFactory::create('c', {1, 6, 6, 1}, { -// 1.f, 1.f, 2.f, 2.f, 3.f, 3.f, 1.f, 1.f, 2.f, 2.f, 3.f, 3.f, 4.f, 4.f, 5.f, 5.f, 6.f, 6.f, 4.f, 4.f, 5.f, 5.f, -// 6.f, 6.f, 7.f, 7.f, 8.f, 8.f, 9.f, 9.f, 7.f, 7.f, 8.f, 8.f, 9.f, 9.f - 1.f , 1.f , 1.5f, 2.f , 2.f, 3.f, 1.f , 1.f , 1.5f, 2.f , 2.f, 3.f, - 2.5f, 2.5f, 3.f, 3.5f, 3.5f, 4.5f, 4.f , 4.f , 4.5f , 5.f, 5.f, 6.f , - 4.f, 4.f, 4.5f , 5.f, 5.f, 6.f, 7.f , 7.f , 7.5f , 8.f , 8.f , 9.f - }); - //input.linspace(1); -// auto size = NDArrayFactory::create({6, 6}); - sd::ops::resize_images op; - auto results = op.evaluate({&input}, {}, {6, 8, ops::helpers::kResizeArea}, {true, true}); // resize_area to 6x8 with align corners and preserve aspect ratio of input image - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - NDArray* result = results.at(0); - -// result->printBuffer("Area Resized to 6x6"); -// expected.printBuffer("Area Expect for 6x6"); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); -} - -/////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests11, ImageResizeArea_Test9) { - - NDArray input = NDArrayFactory::create('c', {1, 2, 3, 4}, { - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,12,13,14,15,16,17,18,19,20,21,22,23,24 - }); - - NDArray expected = NDArrayFactory::create('c', {1, 10, 10, 4}, { - 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, - 3.000000f, 4.000000f, 3.666667f, 4.666667f, 5.666667f, 6.666667f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, - 5.000000f, 6.000000f, 7.000000f, 8.000000f, 6.333336f, 7.333336f, 8.333336f, 9.333337f, 9.000000f, 10.000000f, - 11.000000f, 12.000000f, 9.000000f, 10.000000f, 11.000000f, 12.000000f, 8.999998f, 9.999998f, 10.999998f, 11.999998f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 3.666667f, 4.666667f, 5.666667f, 6.666667f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 6.333336f, 7.333336f, 8.333336f, 9.333337f, 9.000000f, 10.000000f, 11.000000f, 12.000000f, 9.000000f, 10.000000f, 11.000000f, 12.000000f, 8.999998f, 9.999998f, 10.999998f, 11.999998f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 3.666667f, 4.666667f, 5.666667f, 6.666667f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 6.333336f, 7.333336f, 8.333336f, 9.333337f, 9.000000f, 10.000000f, 11.000000f, 12.000000f, 9.000000f, 10.000000f, 11.000000f, 12.000000f, 8.999998f, 9.999998f, 10.999998f, 11.999998f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 3.666667f, 4.666667f, 5.666667f, 6.666667f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 6.333336f, 7.333336f, 8.333336f, 9.333337f, 9.000000f, 10.000000f, 11.000000f, 12.000000f, 9.000000f, 10.000000f, 11.000000f, 12.000000f, 8.999998f, 9.999998f, 10.999998f, 11.999998f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 3.666667f, 4.666667f, 5.666667f, 6.666667f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 6.333336f, 7.333336f, 8.333336f, 9.333336f, 8.999999f, 9.999999f, 11.000000f, 11.999999f, 8.999999f, 9.999999f, 11.000000f, 11.999999f, 8.999998f, 9.999997f, 10.999997f, 11.999997f, 13.000003f, 14.000004f, 15.000003f, 16.000004f, 13.000003f, 14.000004f, 15.000003f, 16.000004f, 13.000003f, 14.000004f, 15.000003f, 16.000004f, 15.666671f, 16.666672f, 17.666672f, 18.666672f, 17.000006f, 18.000004f, 19.000006f, 20.000004f, 17.000006f, 18.000004f, 19.000006f, 20.000004f, 18.333344f, 19.333344f, 20.333345f, 21.333344f, 21.000006f, 22.000006f, 23.000006f, 24.000006f, 21.000006f, 22.000006f, 23.000006f, 24.000006f, 21.000002f, 22.000000f, 23.000002f, 24.000000f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 15.666667f, 16.666668f, 17.666668f, 18.666668f, 17.000002f, 18.000000f, 19.000002f, 20.000000f, 17.000002f, 18.000000f, 19.000002f, 20.000000f, 18.333340f, 19.333340f, 20.333342f, 21.333340f, 21.000002f, 22.000000f, 22.999998f, 24.000000f, 21.000002f, 22.000000f, 22.999998f, 24.000000f, 20.999996f, 21.999996f, 22.999994f, 23.999996f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 15.666667f, 16.666668f, 17.666668f, 18.666668f, 17.000002f, 18.000000f, 19.000002f, 20.000000f, 17.000002f, 18.000000f, 19.000002f, 20.000000f, 18.333340f, 19.333340f, 20.333342f, 21.333340f, 21.000002f, 22.000000f, 22.999998f, 24.000000f, 21.000002f, 22.000000f, 22.999998f, 24.000000f, 20.999996f, 21.999996f, 22.999994f, 23.999996f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 15.666667f, 16.666668f, 17.666668f, 18.666668f, 17.000002f, 18.000000f, 19.000002f, 20.000000f, 17.000002f, 18.000000f, 19.000002f, 20.000000f, 18.333340f, 19.333340f, 20.333342f, 21.333340f, 21.000002f, 22.000000f, 22.999998f, 24.000000f, 21.000002f, 22.000000f, 22.999998f, 24.000000f, 20.999996f, 21.999996f, 22.999994f, 23.999996f, 12.999995f, 13.999995f, 14.999994f, 15.999994f, 12.999995f, 13.999995f, 14.999994f, 15.999994f, 12.999995f, 13.999995f, 14.999994f, 15.999994f, 15.666661f, 16.666662f, 17.666660f, 18.666660f, 16.999994f, 17.999994f, 18.999992f, 19.999992f, 16.999994f, 17.999994f, 18.999992f, 19.999992f, 18.333334f, 19.333332f, 20.333334f, 21.333332f, 20.999992f, 21.999992f, 22.999990f, 23.999992f, 20.999992f, 21.999992f, 22.999990f, 23.999992f, 20.999989f, 21.999989f, 22.999987f, 23.999987f - - }); - //input.linspace(1); - auto size = NDArrayFactory::create({10, 10}); - sd::ops::resize_area op; - auto results = op.evaluate({&input, &size}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - NDArray* result = results.at(0); - -// result.printBuffer("Area Resized to 10x10"); - // expected.printBuffer("Area Expect for 6x6"); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); -} - -/////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests11, ImageResizeArea_Test10) { - - NDArray input = NDArrayFactory::create('c', {1, 2, 3, 4}, { - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,12,13,14,15,16,17,18,19,20,21,22,23,24 - }); - - NDArray expected = NDArrayFactory::create('c', {1, 10, 10, 4}, { - 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 3.666667f, 4.666667f, 5.666667f, 6.666667f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 6.333336f, 7.333336f, 8.333336f, 9.333337f, 9.000000f, 10.000000f, 11.000000f, 12.000000f, 9.000000f, 10.000000f, 11.000000f, 12.000000f, 8.999998f, 9.999998f, 10.999998f, 11.999998f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 3.666667f, 4.666667f, 5.666667f, 6.666667f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 6.333336f, 7.333336f, 8.333336f, 9.333337f, 9.000000f, 10.000000f, 11.000000f, 12.000000f, 9.000000f, 10.000000f, 11.000000f, 12.000000f, 8.999998f, 9.999998f, 10.999998f, 11.999998f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 3.666667f, 4.666667f, 5.666667f, 6.666667f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 6.333336f, 7.333336f, 8.333336f, 9.333337f, 9.000000f, 10.000000f, 11.000000f, 12.000000f, 9.000000f, 10.000000f, 11.000000f, 12.000000f, 8.999998f, 9.999998f, 10.999998f, 11.999998f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 3.666667f, 4.666667f, 5.666667f, 6.666667f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 6.333336f, 7.333336f, 8.333336f, 9.333337f, 9.000000f, 10.000000f, 11.000000f, 12.000000f, 9.000000f, 10.000000f, 11.000000f, 12.000000f, 8.999998f, 9.999998f, 10.999998f, 11.999998f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 3.666667f, 4.666667f, 5.666667f, 6.666667f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 6.333336f, 7.333336f, 8.333336f, 9.333336f, 8.999999f, 9.999999f, 11.000000f, 11.999999f, 8.999999f, 9.999999f, 11.000000f, 11.999999f, 8.999998f, 9.999997f, 10.999997f, 11.999997f, 13.000003f, 14.000004f, 15.000003f, 16.000004f, 13.000003f, 14.000004f, 15.000003f, 16.000004f, 13.000003f, 14.000004f, 15.000003f, 16.000004f, 15.666671f, 16.666672f, 17.666672f, 18.666672f, 17.000006f, 18.000004f, 19.000006f, 20.000004f, 17.000006f, 18.000004f, 19.000006f, 20.000004f, 18.333344f, 19.333344f, 20.333345f, 21.333344f, 21.000006f, 22.000006f, 23.000006f, 24.000006f, 21.000006f, 22.000006f, 23.000006f, 24.000006f, 21.000002f, 22.000000f, 23.000002f, 24.000000f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 15.666667f, 16.666668f, 17.666668f, 18.666668f, 17.000002f, 18.000000f, 19.000002f, 20.000000f, 17.000002f, 18.000000f, 19.000002f, 20.000000f, 18.333340f, 19.333340f, 20.333342f, 21.333340f, 21.000002f, 22.000000f, 22.999998f, 24.000000f, 21.000002f, 22.000000f, 22.999998f, 24.000000f, 20.999996f, 21.999996f, 22.999994f, 23.999996f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 15.666667f, 16.666668f, 17.666668f, 18.666668f, 17.000002f, 18.000000f, 19.000002f, 20.000000f, 17.000002f, 18.000000f, 19.000002f, 20.000000f, 18.333340f, 19.333340f, 20.333342f, 21.333340f, 21.000002f, 22.000000f, 22.999998f, 24.000000f, 21.000002f, 22.000000f, 22.999998f, 24.000000f, 20.999996f, 21.999996f, 22.999994f, 23.999996f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 15.666667f, 16.666668f, 17.666668f, 18.666668f, 17.000002f, 18.000000f, 19.000002f, 20.000000f, 17.000002f, 18.000000f, 19.000002f, 20.000000f, 18.333340f, 19.333340f, 20.333342f, 21.333340f, 21.000002f, 22.000000f, 22.999998f, 24.000000f, 21.000002f, 22.000000f, 22.999998f, 24.000000f, 20.999996f, 21.999996f, 22.999994f, 23.999996f, 12.999995f, 13.999995f, 14.999994f, 15.999994f, 12.999995f, 13.999995f, 14.999994f, 15.999994f, 12.999995f, 13.999995f, 14.999994f, 15.999994f, 15.666661f, 16.666662f, 17.666660f, 18.666660f, 16.999994f, 17.999994f, 18.999992f, 19.999992f, 16.999994f, 17.999994f, 18.999992f, 19.999992f, 18.333334f, 19.333332f, 20.333334f, 21.333332f, 20.999992f, 21.999992f, 22.999990f, 23.999992f, 20.999992f, 21.999992f, 22.999990f, 23.999992f, 20.999989f, 21.999989f, 22.999987f, 23.999987f - - }); - //input.linspace(1); - //auto size = NDArrayFactory::create({10, 10}); - sd::ops::resize_area op; - auto results = op.evaluate({&input}, {}, {10, 10}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - NDArray* result = results.at(0); - -// result.printBuffer("Area Resized to 10x10"); - // expected.printBuffer("Area Expect for 6x6"); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); -} - -/////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests11, ImageResizeArea_Test11) { - - NDArray input = NDArrayFactory::create('c', {1, 2, 3, 4}, { - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,12,13,14,15,16,17,18,19,20,21,22,23,24 - }); - -// NDArray expected = NDArrayFactory::create('c', {1, 6, 9, 4}, { -// 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 3.666667, 4.666667, 5.666667, 6.666667, 5.000000, 6.000000, 7.000000, 8.000000, 5.000000, 6.000000, 7.000000, 8.000000, 6.333336, 7.333336, 8.333336, 9.333337, 9.000000, 10.000000, 11.000000, 12.000000, 9.000000, 10.000000, 11.000000, 12.000000, 8.999998, 9.999998, 10.999998, 11.999998, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 3.666667, 4.666667, 5.666667, 6.666667, 5.000000, 6.000000, 7.000000, 8.000000, 5.000000, 6.000000, 7.000000, 8.000000, 6.333336, 7.333336, 8.333336, 9.333337, 9.000000, 10.000000, 11.000000, 12.000000, 9.000000, 10.000000, 11.000000, 12.000000, 8.999998, 9.999998, 10.999998, 11.999998, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 3.666667, 4.666667, 5.666667, 6.666667, 5.000000, 6.000000, 7.000000, 8.000000, 5.000000, 6.000000, 7.000000, 8.000000, 6.333336, 7.333336, 8.333336, 9.333337, 9.000000, 10.000000, 11.000000, 12.000000, 9.000000, 10.000000, 11.000000, 12.000000, 8.999998, 9.999998, 10.999998, 11.999998, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 3.666667, 4.666667, 5.666667, 6.666667, 5.000000, 6.000000, 7.000000, 8.000000, 5.000000, 6.000000, 7.000000, 8.000000, 6.333336, 7.333336, 8.333336, 9.333337, 9.000000, 10.000000, 11.000000, 12.000000, 9.000000, 10.000000, 11.000000, 12.000000, 8.999998, 9.999998, 10.999998, 11.999998, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 3.666667, 4.666667, 5.666667, 6.666667, 5.000000, 6.000000, 7.000000, 8.000000, 5.000000, 6.000000, 7.000000, 8.000000, 6.333336, 7.333336, 8.333336, 9.333336, 8.999999, 9.999999, 11.000000, 11.999999, 8.999999, 9.999999, 11.000000, 11.999999, 8.999998, 9.999997, 10.999997, 11.999997, 13.000003, 14.000004, 15.000003, 16.000004, 13.000003, 14.000004, 15.000003, 16.000004, 13.000003, 14.000004, 15.000003, 16.000004, 15.666671, 16.666672, 17.666672, 18.666672, 17.000006, 18.000004, 19.000006, 20.000004, 17.000006, 18.000004, 19.000006, 20.000004, 18.333344, 19.333344, 20.333345, 21.333344, 21.000006, 22.000006, 23.000006, 24.000006, 21.000006, 22.000006, 23.000006, 24.000006, 21.000002, 22.000000, 23.000002, 24.000000, 13.000000, 14.000001, 15.000000, 16.000000, 13.000000, 14.000001, 15.000000, 16.000000, 13.000000, 14.000001, 15.000000, 16.000000, 15.666667, 16.666668, 17.666668, 18.666668, 17.000002, 18.000000, 19.000002, 20.000000, 17.000002, 18.000000, 19.000002, 20.000000, 18.333340, 19.333340, 20.333342, 21.333340, 21.000002, 22.000000, 22.999998, 24.000000, 21.000002, 22.000000, 22.999998, 24.000000, 20.999996, 21.999996, 22.999994, 23.999996, 13.000000, 14.000001, 15.000000, 16.000000, 13.000000, 14.000001, 15.000000, 16.000000, 13.000000, 14.000001, 15.000000, 16.000000, 15.666667, 16.666668, 17.666668, 18.666668, 17.000002, 18.000000, 19.000002, 20.000000, 17.000002, 18.000000, 19.000002, 20.000000, 18.333340, 19.333340, 20.333342, 21.333340, 21.000002, 22.000000, 22.999998, 24.000000, 21.000002, 22.000000, 22.999998, 24.000000, 20.999996, 21.999996, 22.999994, 23.999996, 13.000000, 14.000001, 15.000000, 16.000000, 13.000000, 14.000001, 15.000000, 16.000000, 13.000000, 14.000001, 15.000000, 16.000000, 15.666667, 16.666668, 17.666668, 18.666668, 17.000002, 18.000000, 19.000002, 20.000000, 17.000002, 18.000000, 19.000002, 20.000000, 18.333340, 19.333340, 20.333342, 21.333340, 21.000002, 22.000000, 22.999998, 24.000000, 21.000002, 22.000000, 22.999998, 24.000000, 20.999996, 21.999996, 22.999994, 23.999996, 12.999995, 13.999995, 14.999994, 15.999994, 12.999995, 13.999995, 14.999994, 15.999994, 12.999995, 13.999995, 14.999994, 15.999994, 15.666661, 16.666662, 17.666660, 18.666660, 16.999994, 17.999994, 18.999992, 19.999992, 16.999994, 17.999994, 18.999992, 19.999992, 18.333334, 19.333332, 20.333334, 21.333332, 20.999992, 21.999992, 22.999990, 23.999992, 20.999992, 21.999992, 22.999990, 23.999992, 20.999989, 21.999989, 22.999987, 23.999987 -// -// }); - //input.linspace(1); - //auto size = NDArrayFactory::create({10, 10}); - sd::ops::resize_area op; - auto results = op.evaluate({&input}, {}, {6, 9}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - NDArray* result = results.at(0); - -// result.printBuffer("Area Resized to 6x9"); - // expected.printBuffer("Area Expect for 6x6"); -// ASSERT_TRUE(expected.isSameShape(result)); -// ASSERT_TRUE(expected.equalsTo(result)); -} - -/////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests11, ImageResizeArea_Test12) { - - NDArray input = NDArrayFactory::create('c', {1, 2, 3, 4}, { - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,12,13,14,15,16,17,18,19,20,21,22,23,24 - }); - -// NDArray expected = NDArrayFactory::create('c', {1, 6, 9, 4}, { -// 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 3.666667, 4.666667, 5.666667, 6.666667, 5.000000, 6.000000, 7.000000, 8.000000, 5.000000, 6.000000, 7.000000, 8.000000, 6.333336, 7.333336, 8.333336, 9.333337, 9.000000, 10.000000, 11.000000, 12.000000, 9.000000, 10.000000, 11.000000, 12.000000, 8.999998, 9.999998, 10.999998, 11.999998, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 3.666667, 4.666667, 5.666667, 6.666667, 5.000000, 6.000000, 7.000000, 8.000000, 5.000000, 6.000000, 7.000000, 8.000000, 6.333336, 7.333336, 8.333336, 9.333337, 9.000000, 10.000000, 11.000000, 12.000000, 9.000000, 10.000000, 11.000000, 12.000000, 8.999998, 9.999998, 10.999998, 11.999998, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 3.666667, 4.666667, 5.666667, 6.666667, 5.000000, 6.000000, 7.000000, 8.000000, 5.000000, 6.000000, 7.000000, 8.000000, 6.333336, 7.333336, 8.333336, 9.333337, 9.000000, 10.000000, 11.000000, 12.000000, 9.000000, 10.000000, 11.000000, 12.000000, 8.999998, 9.999998, 10.999998, 11.999998, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 3.666667, 4.666667, 5.666667, 6.666667, 5.000000, 6.000000, 7.000000, 8.000000, 5.000000, 6.000000, 7.000000, 8.000000, 6.333336, 7.333336, 8.333336, 9.333337, 9.000000, 10.000000, 11.000000, 12.000000, 9.000000, 10.000000, 11.000000, 12.000000, 8.999998, 9.999998, 10.999998, 11.999998, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 3.666667, 4.666667, 5.666667, 6.666667, 5.000000, 6.000000, 7.000000, 8.000000, 5.000000, 6.000000, 7.000000, 8.000000, 6.333336, 7.333336, 8.333336, 9.333336, 8.999999, 9.999999, 11.000000, 11.999999, 8.999999, 9.999999, 11.000000, 11.999999, 8.999998, 9.999997, 10.999997, 11.999997, 13.000003, 14.000004, 15.000003, 16.000004, 13.000003, 14.000004, 15.000003, 16.000004, 13.000003, 14.000004, 15.000003, 16.000004, 15.666671, 16.666672, 17.666672, 18.666672, 17.000006, 18.000004, 19.000006, 20.000004, 17.000006, 18.000004, 19.000006, 20.000004, 18.333344, 19.333344, 20.333345, 21.333344, 21.000006, 22.000006, 23.000006, 24.000006, 21.000006, 22.000006, 23.000006, 24.000006, 21.000002, 22.000000, 23.000002, 24.000000, 13.000000, 14.000001, 15.000000, 16.000000, 13.000000, 14.000001, 15.000000, 16.000000, 13.000000, 14.000001, 15.000000, 16.000000, 15.666667, 16.666668, 17.666668, 18.666668, 17.000002, 18.000000, 19.000002, 20.000000, 17.000002, 18.000000, 19.000002, 20.000000, 18.333340, 19.333340, 20.333342, 21.333340, 21.000002, 22.000000, 22.999998, 24.000000, 21.000002, 22.000000, 22.999998, 24.000000, 20.999996, 21.999996, 22.999994, 23.999996, 13.000000, 14.000001, 15.000000, 16.000000, 13.000000, 14.000001, 15.000000, 16.000000, 13.000000, 14.000001, 15.000000, 16.000000, 15.666667, 16.666668, 17.666668, 18.666668, 17.000002, 18.000000, 19.000002, 20.000000, 17.000002, 18.000000, 19.000002, 20.000000, 18.333340, 19.333340, 20.333342, 21.333340, 21.000002, 22.000000, 22.999998, 24.000000, 21.000002, 22.000000, 22.999998, 24.000000, 20.999996, 21.999996, 22.999994, 23.999996, 13.000000, 14.000001, 15.000000, 16.000000, 13.000000, 14.000001, 15.000000, 16.000000, 13.000000, 14.000001, 15.000000, 16.000000, 15.666667, 16.666668, 17.666668, 18.666668, 17.000002, 18.000000, 19.000002, 20.000000, 17.000002, 18.000000, 19.000002, 20.000000, 18.333340, 19.333340, 20.333342, 21.333340, 21.000002, 22.000000, 22.999998, 24.000000, 21.000002, 22.000000, 22.999998, 24.000000, 20.999996, 21.999996, 22.999994, 23.999996, 12.999995, 13.999995, 14.999994, 15.999994, 12.999995, 13.999995, 14.999994, 15.999994, 12.999995, 13.999995, 14.999994, 15.999994, 15.666661, 16.666662, 17.666660, 18.666660, 16.999994, 17.999994, 18.999992, 19.999992, 16.999994, 17.999994, 18.999992, 19.999992, 18.333334, 19.333332, 20.333334, 21.333332, 20.999992, 21.999992, 22.999990, 23.999992, 20.999992, 21.999992, 22.999990, 23.999992, 20.999989, 21.999989, 22.999987, 23.999987 -// -// }); - //input.linspace(1); - //auto size = NDArrayFactory::create({10, 10}); - sd::ops::resize_area op; - auto results = op.evaluate({&input}, {}, {10, 15}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - NDArray* result = results.at(0); - -// result.printBuffer("Area Resized to 6x9"); - // expected.printBuffer("Area Expect for 6x6"); -// ASSERT_TRUE(expected.isSameShape(result)); -// ASSERT_TRUE(expected.equalsTo(result)); -} - -/////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests11, ImageResizeArea_Test13) { - - NDArray input = NDArrayFactory::create('c', {1, 2, 3, 4}, { - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,12,13,14,15,16,17,18,19,20,21,22,23,24 - }); - -// NDArray expected = NDArrayFactory::create('c', {1, 8, 8, 4}, { -// 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 3.666667, 4.666667, 5.666667, 6.666667, 5.000000, 6.000000, 7.000000, 8.000000, 5.000000, 6.000000, 7.000000, 8.000000, 6.333336, 7.333336, 8.333336, 9.333337, 9.000000, 10.000000, 11.000000, 12.000000, 9.000000, 10.000000, 11.000000, 12.000000, 8.999998, 9.999998, 10.999998, 11.999998, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 3.666667, 4.666667, 5.666667, 6.666667, 5.000000, 6.000000, 7.000000, 8.000000, 5.000000, 6.000000, 7.000000, 8.000000, 6.333336, 7.333336, 8.333336, 9.333337, 9.000000, 10.000000, 11.000000, 12.000000, 9.000000, 10.000000, 11.000000, 12.000000, 8.999998, 9.999998, 10.999998, 11.999998, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 3.666667, 4.666667, 5.666667, 6.666667, 5.000000, 6.000000, 7.000000, 8.000000, 5.000000, 6.000000, 7.000000, 8.000000, 6.333336, 7.333336, 8.333336, 9.333337, 9.000000, 10.000000, 11.000000, 12.000000, 9.000000, 10.000000, 11.000000, 12.000000, 8.999998, 9.999998, 10.999998, 11.999998, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 3.666667, 4.666667, 5.666667, 6.666667, 5.000000, 6.000000, 7.000000, 8.000000, 5.000000, 6.000000, 7.000000, 8.000000, 6.333336, 7.333336, 8.333336, 9.333337, 9.000000, 10.000000, 11.000000, 12.000000, 9.000000, 10.000000, 11.000000, 12.000000, 8.999998, 9.999998, 10.999998, 11.999998, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 3.666667, 4.666667, 5.666667, 6.666667, 5.000000, 6.000000, 7.000000, 8.000000, 5.000000, 6.000000, 7.000000, 8.000000, 6.333336, 7.333336, 8.333336, 9.333336, 8.999999, 9.999999, 11.000000, 11.999999, 8.999999, 9.999999, 11.000000, 11.999999, 8.999998, 9.999997, 10.999997, 11.999997, 13.000003, 14.000004, 15.000003, 16.000004, 13.000003, 14.000004, 15.000003, 16.000004, 13.000003, 14.000004, 15.000003, 16.000004, 15.666671, 16.666672, 17.666672, 18.666672, 17.000006, 18.000004, 19.000006, 20.000004, 17.000006, 18.000004, 19.000006, 20.000004, 18.333344, 19.333344, 20.333345, 21.333344, 21.000006, 22.000006, 23.000006, 24.000006, 21.000006, 22.000006, 23.000006, 24.000006, 21.000002, 22.000000, 23.000002, 24.000000, 13.000000, 14.000001, 15.000000, 16.000000, 13.000000, 14.000001, 15.000000, 16.000000, 13.000000, 14.000001, 15.000000, 16.000000, 15.666667, 16.666668, 17.666668, 18.666668, 17.000002, 18.000000, 19.000002, 20.000000, 17.000002, 18.000000, 19.000002, 20.000000, 18.333340, 19.333340, 20.333342, 21.333340, 21.000002, 22.000000, 22.999998, 24.000000, 21.000002, 22.000000, 22.999998, 24.000000, 20.999996, 21.999996, 22.999994, 23.999996, 13.000000, 14.000001, 15.000000, 16.000000, 13.000000, 14.000001, 15.000000, 16.000000, 13.000000, 14.000001, 15.000000, 16.000000, 15.666667, 16.666668, 17.666668, 18.666668, 17.000002, 18.000000, 19.000002, 20.000000, 17.000002, 18.000000, 19.000002, 20.000000, 18.333340, 19.333340, 20.333342, 21.333340, 21.000002, 22.000000, 22.999998, 24.000000, 21.000002, 22.000000, 22.999998, 24.000000, 20.999996, 21.999996, 22.999994, 23.999996, 13.000000, 14.000001, 15.000000, 16.000000, 13.000000, 14.000001, 15.000000, 16.000000, 13.000000, 14.000001, 15.000000, 16.000000, 15.666667, 16.666668, 17.666668, 18.666668, 17.000002, 18.000000, 19.000002, 20.000000, 17.000002, 18.000000, 19.000002, 20.000000, 18.333340, 19.333340, 20.333342, 21.333340, 21.000002, 22.000000, 22.999998, 24.000000, 21.000002, 22.000000, 22.999998, 24.000000, 20.999996, 21.999996, 22.999994, 23.999996, 12.999995, 13.999995, 14.999994, 15.999994, 12.999995, 13.999995, 14.999994, 15.999994, 12.999995, 13.999995, 14.999994, 15.999994, 15.666661, 16.666662, 17.666660, 18.666660, 16.999994, 17.999994, 18.999992, 19.999992, 16.999994, 17.999994, 18.999992, 19.999992, 18.333334, 19.333332, 20.333334, 21.333332, 20.999992, 21.999992, 22.999990, 23.999992, 20.999992, 21.999992, 22.999990, 23.999992, 20.999989, 21.999989, 22.999987, 23.999987 -// -// }); - //input.linspace(1); - //auto size = NDArrayFactory::create({10, 10}); - sd::ops::resize_area op; - auto results = op.evaluate({&input}, {}, {9, 9}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - NDArray* result = results.at(0); - -// result.printBuffer("Area Resized to 8x8"); - // expected.printBuffer("Area Expect for 6x6"); -// ASSERT_TRUE(expected.isSameShape(result)); -// ASSERT_TRUE(expected.equalsTo(result)); -} - -TEST_F(DeclarableOpsTests11, ImageResizeArea_Test14) { - - NDArray input = NDArrayFactory::create('c', {1, 5, 5, 1}, { - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25 - }); - auto size = NDArrayFactory::create({8, 7}); - NDArray expected = NDArrayFactory::create('c', {1, 8, 7, 1}, { - 1.f, 1.6f , 2.1999993f, 2.9999995f , 3.8f , 4.399997f, 5.f , 2.9999995f , 3.5999997f , 4.199999f, - 4.9999995f, 5.8f , 6.3999963f , 7.f , 5.999999f , 6.6f, 7.1999984f , 7.9999995f , 8.8f, - 9.399994f, 10.f , 10.f, 10.6f , 11.199998f, 12.f, 12.8f, 13.399992f , 14.f, 12.f , 12.599999f, - 13.199998f , 13.999998f , 14.800002f , 15.399991f , 16.f , 15.999999f , 16.599998f , 17.199995f, - 18.f , 18.800003f , 19.399986f , 20.000002f , 19.f , 19.599998f , 20.199997f , - 20.999998f , 21.800003f , 22.399984f , 23.000002f , 20.999998f , - 21.599998f , 22.199995f , 22.999998f , 23.800001f , 24.399984f , - 25.f - }); //input.linspace(1); -// auto size = NDArrayFactory::create({6, 6}); - sd::ops::resize_area op; - auto results = op.evaluate({&input, &size}, {}, {false}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - NDArray* result = results.at(0); -// result.printBuffer("Area Resized to 8x7"); -// expected.printBuffer("Area Expect for 8x7"); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); -} - -TEST_F(DeclarableOpsTests11, ImageResizeArea_Test15) { - - NDArray input = NDArrayFactory::create('c', {1, 5, 5, 1}, { - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25 - }); - //auto size = NDArrayFactory::create({8, 7}); - NDArray expected = NDArrayFactory::create('c', {1, 8, 7, 1}, { - 1.f, 1.6f , 2.1999993f, 2.9999995f , 3.8f , 4.399997f, 5.f , 2.9999995f , 3.5999997f , 4.199999f, - 4.9999995f, 5.8f , 6.3999963f , 7.f , 5.999999f , 6.6f, 7.1999984f , 7.9999995f , 8.8f, - 9.399994f, 10.f , 10.f, 10.6f , 11.199998f, 12.f, 12.8f, 13.399992f , 14.f, 12.f , 12.599999f, - 13.199998f , 13.999998f , 14.800002f , 15.399991f , 16.f , 15.999999f , 16.599998f , 17.199995f, - 18.f , 18.800003f , 19.399986f , 20.000002f , 19.f , 19.599998f , 20.199997f , - 20.999998f , 21.800003f , 22.399984f , 23.000002f , 20.999998f , 21.599998f , 22.199995f , - 22.999998f , 23.800001f , 24.399984f , 25.f - }); - - sd::ops::resize_area op; - auto results = op.evaluate({&input}, {}, {8, 7}, {false}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - NDArray* result = results.at(0); -// result.printBuffer("Area Resized to 8x7"); -// expected.printBuffer("Area Expect for 8x7"); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); -} - - -/////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests11, summaryStatsData_test1) { - - functions::summarystats::SummaryStatsData var1; - functions::summarystats::SummaryStatsData var2; - var2.n = var2.mean = var2.M2 = var2.M3 = var2.M4 = var2.bias = 5; - - functions::summarystats::SummaryStatsData* arr = new functions::summarystats::SummaryStatsData[2]; - arr[0] = var1; - arr[1] = var2; - arr[0] = arr[1]; - - functions::summarystats::SummaryStatsData var3(var1); - - ASSERT_TRUE(arr[0].n == arr[0].mean && arr[0].M2 == arr[0].M3 && arr[0].n == 5); - ASSERT_TRUE(arr[1].n == arr[1].mean && arr[1].M2 == arr[1].M3 && arr[1].n == 5); - ASSERT_TRUE(var3.n == var3.mean && var3.M2 == var3.M3 && var3.n == 0); - - delete []arr; -} - -//////////////////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests11, Solve_Test_1) { - - auto a = NDArrayFactory::create('c', {3, 3}, { - 2.f, -1.f, -2.f, -4.f, 6.f, 3.f, -4.f, -2.f, 8.f - }); - - auto b = NDArrayFactory::create('c', {3, 1}, { - 2.f, 4.f, 3.f - }); - - auto exp = NDArrayFactory::create('c', {3, 1}, { - 7.625f, 3.25f, 5.f - }); - - sd::ops::solve op; - - auto res = op.evaluate({&a, &b}); - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - auto z = res.at(0); - -// z->printIndexedBuffer("Solve of 3x3"); - - ASSERT_TRUE(exp.equalsTo(z)); -} - -//////////////////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests11, Solve_Test_2) { - - auto a = NDArrayFactory::create('c', {4, 4}, { - 1.f, 1.f, 1.f, 1.f, - 0.f, 1.f, 1.f, 0.f, - 0.f, 0.f, 2.f, 1.f, - 0.f, 0.f, 0.f, 3.f, - }); - - auto b = NDArrayFactory::create('c', {4, 1}, { - 2.f, 4.f, 2.f, 4.f - }); - - auto exp = NDArrayFactory::create('c', {4, 1}, { - -3.3333333f, 3.6666666f, 0.333333f, 1.3333333f - }); - - sd::ops::solve op; - - auto res = op.evaluate({&a, &b}); - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - auto z = res.at(0); - -// z->printIndexedBuffer("Solve 4x4"); - - ASSERT_TRUE(exp.equalsTo(z)); - -} -//////////////////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests11, Solve_Test_3) { - - auto a = NDArrayFactory::create('c', {2, 4, 4}, { - 1.f, 1.f, 1.f, 1.f, - 0.f, 1.f, 1.f, 0.f, - 0.f, 0.f, 2.f, 1.f, - 0.f, 0.f, 0.f, 3.f, - - 3.f, 0.f, 0.f, 0.f, - 2.f, 1.f, 0.f, 0.f, - 1.f, 0.f, 1.f, 0.f, - 1.f, 1.f, 1.f, 1.f - - }); - - auto b = NDArrayFactory::create('c', {2, 4, 1}, { - 2.f, 4.f, 2.f, 4.f, - 4.f, 2.f, 4.f, 2.f - }); - - auto exp = NDArrayFactory::create('c', {2, 4, 1}, { - -3.3333333f, 3.6666666f, 0.333333f, 1.3333333f, - 1.333333f, -0.6666667f, 2.6666667f, -1.3333333f - }); - - sd::ops::solve op; - - auto res = op.evaluate({&a, &b}); - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - auto z = res.at(0); - -// z->printIndexedBuffer("Solve 4x4"); - - ASSERT_TRUE(exp.equalsTo(z)); - -} - -//////////////////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests11, Solve_Test_4) { - - auto a = NDArrayFactory::create('c', {2, 2, 2}, { - 0.7788f, 0.8012f, 0.7244f, 0.2309f, - 0.7271f, 0.1804f, 0.5056f, 0.8925f - }); - - auto b = NDArrayFactory::create('c', {2, 2, 2}, { - 0.7717f, 0.9281f, 0.9846f, 0.4838f, - 0.6433f, 0.6041f, 0.6501f, 0.7612f - }); - - auto exp = NDArrayFactory::create('c', {2, 2, 2}, { -// 1.524494767f, 0.432706356f,-0.518630624f, 0.737760842f, -// 0.819143713f, 0.720401764f, 0.264349997f, 0.444699198f - 1.5245394f, 0.4326952f, -0.51873577f, 0.7377896f, - 0.81915987f, 0.72049433f, 0.2643504f, 0.44472617f - }); - - sd::ops::solve op; - - auto res = op.evaluate({&a, &b}); - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - auto z = res.at(0); - -// z->printBuffer("4 Solve 4x4"); -// exp.printBuffer("4 Expec 4x4"); - - ASSERT_TRUE(exp.equalsTo(z)); - -} -//////////////////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests11, Solve_Test_4_1) { - - auto a = NDArrayFactory::create('c', {2, 2, 2}, { - 0.7788f, 0.8012f, 0.7244f, 0.2309f, - 0.7271f, 0.1804f, 0.5056f, 0.8925f - }); - - auto b = NDArrayFactory::create('c', {2, 2, 2}, { - 0.7717f, 0.9281f, 0.9846f, 0.4838f, 0.6433f, 0.6041f, 0.6501f, 0.7612f - }); - - auto exp = NDArrayFactory::create('c', {2, 2, 2}, { - 1.3357621f, 0.3399364f, -0.37077796f, 0.91573375f, - 0.4400987f, 0.2766527f, 0.6394467f, 0.79696566f - }); - - sd::ops::solve op; - - auto res = op.evaluate({&a, &b}, {true}); - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - auto z = res.at(0); - -// z->printBuffer("4 Solve 4x4"); -// exp.printBuffer("4 Expec 4x4"); - - ASSERT_TRUE(exp.equalsTo(z)); - -} -//////////////////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests11, Solve_Test_4_2) { - - auto a = NDArrayFactory::create('c', {3, 3}, { - 0.7788f, 0.8012f, 0.7244f, - 0.2309f, 0.7271f, 0.1804f, - 0.5056f, 0.8925f, 0.5461f - }); - - auto b = NDArrayFactory::create('c', {3, 3}, { - 0.7717f, 0.9281f, 0.9846f, - 0.4838f, 0.6433f, 0.6041f, - 0.6501f, 0.7612f, 0.7605f - }); - - auto exp = NDArrayFactory::create('c', {3, 3}, { - 0.99088347f, 1.1917052f, 1.2642528f, - 0.35071516f, 0.50630623f, 0.42935497f, - -0.30013534f, -0.53690606f, -0.47959247f - }); - - sd::ops::triangular_solve op; - - auto res = op.evaluate({&a, &b}, {true, false}); - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - auto z = res.at(0); - -// z->printBuffer("4_2 Triangular_Solve 3x3"); -// exp.printBuffer("4_2 Triangular_Expec 3x3"); - - ASSERT_TRUE(exp.equalsTo(z)); + // z->printBuffer("4_7 Solve 3x3"); + // exp.printBuffer("4_7 Expec 3x3"); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests11, Solve_Test_4_3) { - - auto a = NDArrayFactory::create('c', {3, 3}, { - 0.7788f, 0.8012f, 0.7244f, - 0.2309f, 0.7271f, 0.1804f, - 0.5056f, 0.8925f, 0.5461f - }); - - auto b = NDArrayFactory::create('c', {3, 3}, { - 0.7717f, 0.9281f, 0.9846f, - 0.4838f, 0.6433f, 0.6041f, - 0.6501f, 0.7612f, 0.7605f - }); +TEST_F(DeclarableOpsTests11, Solve_Test_5) { + auto a = NDArrayFactory::create( + 'c', {3, 3}, + {0.7788f, 0.8012f, 0.7244f, 0.2309f, 0.7271f, 0.1804f, 0.5056f, 0.8925f, + 0.5461f}); - auto exp = NDArrayFactory::create('c', {3, 3}, { - 0.45400196f, 0.53174824f, 0.62064564f, - -0.79585856f, -0.82621557f, -0.87855506f, - 1.1904413f, 1.3938838f, 1.3926021f - }); + auto b = NDArrayFactory::create( + 'c', {3, 3}, + {0.7717f, 0.9281f, 0.9846f, 0.4838f, 0.6433f, 0.6041f, 0.6501f, 0.7612f, + 0.7605f}); - sd::ops::triangular_solve op; + auto exp = NDArrayFactory::create( + 'c', {3, 3}, + {1.5504692f, 1.8953944f, 2.2765768f, 0.03399149f, 0.2883001f, 0.5377323f, + -0.8774802f, -1.2155888f, -1.8049058f}); - auto res = op.evaluate({&a, &b}, {true, true}); - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - auto z = res.at(0); + sd::ops::solve op; -// z->printBuffer("4_3 Triangular_Solve 3x3"); -// exp.printBuffer("4_3 Triangular_Expec 3x3"); + auto res = op.evaluate({&a, &b}, {true}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); - ASSERT_TRUE(exp.equalsTo(z)); - -} - -//////////////////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests11, Solve_Test_4_4) { - - auto a = NDArrayFactory::create('c', {3, 3}, { - 0.7788f, 0.8012f, 0.7244f, - 0.2309f, 0.7271f, 0.1804f, - 0.5056f, 0.8925f, 0.5461f - }); - - auto b = NDArrayFactory::create('c', {3, 3}, { - 0.7717f, 0.9281f, 0.9846f, - 0.4838f, 0.6433f, 0.6041f, - 0.6501f, 0.7612f, 0.7605f - }); - - auto exp = NDArrayFactory::create('c', {3, 3}, { - 0.8959121f, 1.6109066f, 1.7501404f, - 0.49000582f, 0.66842675f, 0.5577021f, - -0.4398522f, -1.1899745f, -1.1392052f - }); - - sd::ops::solve op; - - auto res = op.evaluate({&a, &b}, {false}); - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - auto z = res.at(0); - -// z->printBuffer("4_4 Solve 3x3"); -// exp.printBuffer("4_4 Expec 3x3"); - - ASSERT_TRUE(exp.equalsTo(z)); - -} - -//////////////////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests11, Solve_Test_4_5) { - - auto a = NDArrayFactory::create('c', {3, 3}, { - 0.7788f, 0.8012f, 0.7244f, - 0.2309f, 0.7271f, 0.1804f, - 0.5056f, 0.8925f, 0.5461f - }); - - auto b = NDArrayFactory::create('c', {3, 3}, { - 0.7717f, 0.9281f, 0.9846f, - 0.4838f, 0.6433f, 0.6041f, - 0.6501f, 0.7612f, 0.7605f - }); - - auto exp = NDArrayFactory::create('c', {3, 3}, { - 1.5504692f, 1.8953944f, 2.2765768f, - 0.03399149f, 0.2883001f, 0.5377323f, - -0.8774802f, -1.2155888f, -1.8049058f - }); - - sd::ops::solve op; - - auto res = op.evaluate({&a, &b}, {true, true}); - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - auto z = res.at(0); - -// z->printBuffer("4_5 Solve 3x3"); -// exp.printBuffer("4_5 Expec 3x3"); - - ASSERT_TRUE(exp.equalsTo(z)); - -} - -//////////////////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests11, Solve_Test_4_6) { - - auto a = NDArrayFactory::create('c', {3, 3}, { - 0.7788f, 0.8012f, 0.7244f, - 0.2309f, 0.7271f, 0.1804f, - 0.5056f, 0.8925f, 0.5461f - }); - - auto b = NDArrayFactory::create('c', {3, 3}, { - 0.7717f, 0.9281f, 0.9846f, - 0.4838f, 0.6433f, 0.6041f, - 0.6501f, 0.7612f, 0.7605f - }); - - auto exp = NDArrayFactory::create('c', {3, 3}, { - 0.99088347f, 1.1917052f, 1.2642528f, - -0.426483f, -0.42840624f, -0.5622601f, - 0.01692283f, -0.04538865f, -0.09868701f - }); - - sd::ops::triangular_solve op; - - auto res = op.evaluate({&a, &b}, {false, true}); - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - auto z = res.at(0); - -// z->printBuffer("4_6 Solve 3x3"); -// exp.printBuffer("4_6 Expec 3x3"); - - ASSERT_TRUE(exp.equalsTo(z)); - -} -//////////////////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests11, Solve_Test_4_7) { - - auto a = NDArrayFactory::create('c', {3, 3}, { -// 0.7788f, 0.2309f, 0.5056f, -// 0.8012f, 0.7271f, 0.8925f, -// 0.7244f, 0.1804f, 0.5461f - - 0.7788f, 0.2309f, 0.5056f, - 0.8012f, 0.7271f, 0.8925f, - 0.7244f, 0.1804f, 0.5461f - }); - - auto b = NDArrayFactory::create('c', {3, 3}, { - 0.7717f, 0.9281f, 0.9846f, - 0.4838f, 0.6433f, 0.6041f, - 0.6501f, 0.7612f, 0.7605f - }); - - auto exp = NDArrayFactory::create('c', {3, 3}, { - 0.99088347f, 1.1917052f, 1.2642528f, - -0.426483f, -0.42840624f, -0.5622601f, - 0.01692283f, -0.04538865f, -0.09868701f - }); - - sd::ops::triangular_solve op; - - auto res = op.evaluate({&a, &b}, {true, false}); - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - auto z = res.at(0); - -// z->printBuffer("4_7 Solve 3x3"); -// exp.printBuffer("4_7 Expec 3x3"); - - ASSERT_TRUE(exp.equalsTo(z)); - -} - -//////////////////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests11, Solve_Test_5) { - - auto a = NDArrayFactory::create('c', {3, 3}, { - 0.7788f, 0.8012f, 0.7244f, - 0.2309f, 0.7271f, 0.1804f, - 0.5056f, 0.8925f, 0.5461f - }); - - auto b = NDArrayFactory::create('c', {3, 3}, { - 0.7717f, 0.9281f, 0.9846f, - 0.4838f, 0.6433f, 0.6041f, - 0.6501f, 0.7612f, 0.7605f - }); - - auto exp = NDArrayFactory::create('c', {3, 3}, { - 1.5504692f, 1.8953944f, 2.2765768f, - 0.03399149f, 0.2883001f, 0.5377323f, - -0.8774802f, -1.2155888f, -1.8049058f - }); - - sd::ops::solve op; - - auto res = op.evaluate({&a, &b}, {true}); - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - auto z = res.at(0); - -// z->printBuffer("4 Solve 4x4"); -// exp.printBuffer("4 Expec 4x4"); - - ASSERT_TRUE(exp.equalsTo(z)); + // z->printBuffer("4 Solve 4x4"); + // exp.printBuffer("4 Expec 4x4"); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, SolveLS_Test_1) { + auto a = NDArrayFactory::create( + 'c', {2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f}); - auto a = NDArrayFactory::create('c', {2,2, 2}, { - 1.f, 2.f, 3.f, 4.f, - 5.f, 6.f, 7.f, 8.f - }); - - auto b = NDArrayFactory::create('c', {2, 2, 1}, { - 3.f, 7.f, 11.f, 15.f - }); - - auto exp = NDArrayFactory::create('c', {2, 2, 1}, { - 0.8311695f, 1.0909086f, 0.9205573f, 1.0630057f - }); - - sd::ops::lstsq op; - - auto res = op.evaluate({&a, &b}, {0.5}, {}, {true}); - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - auto z = res.at(0); - -// z->printIndexedBuffer("LS Solve 2x2"); -// exp.printIndexedBuffer("LS Expec 2x2"); - - ASSERT_TRUE(exp.equalsTo(z, 1.e-4)); -} - -TEST_F(DeclarableOpsTests11, SolveLS_Test_2) { - - auto a = NDArrayFactory::create('c', {2,2, 2}, { - 1.f, 2.f, 3.f, 4.f, - 5.f, 6.f, 7.f, 8.f - }); - - auto b = NDArrayFactory::create('c', {2, 2, 1}, { - 3.f, 7.f, 11.f, 15.f - }); - - auto exp = NDArrayFactory::create('c', {2, 2, 1}, { - 0.8311695f, 1.0909086f, 0.9205573f, 1.0630057f - }); - - sd::ops::lstsq op; - - auto res = op.evaluate({&a, &b}, {0.5}, {}, {true}); - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - auto z = res.at(0); - -// z->printIndexedBuffer("2LS Solve 2x2"); -// exp.printIndexedBuffer("2LS Expec 2x2"); - - ASSERT_TRUE(exp.equalsTo(z, 1.e-4)); -} - -//////////////////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests11, Cholesky_Test_2x2x2) { - - auto a = NDArrayFactory::create('c', {2,2, 2}, { - 10.f, 14.f, - 14.f, 20.f, - - 74.f, 86.f, - 86.f, 100.f - }); - - auto exp = NDArrayFactory::create('c', {2, 2, 2}, { - 3.1622777f, 0.f, 4.427189f, 0.6324552f, - 8.602325f, 0.f, 9.997296f, 0.23252854f - }); - - sd::ops::cholesky op; - - auto res = op.evaluate({&a}); - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - auto z = res.at(0); - - z->printIndexedBuffer("L matrix is"); - exp.printIndexedBuffer("L expected is"); - - ASSERT_TRUE(exp.equalsTo(z)); - -} - -//////////////////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests11, Cholesky_Test_2x2x2_2) { - - auto a = NDArrayFactory::create('c', {2,2, 2}, { - 10.5f, 14.f, - 14.f, 20.5f, - - 74.5f, 86.f, - 86.f, 100.5f - }); - - auto exp = NDArrayFactory::create('c', {2, 2, 2}, { - 3.2403703f, 0.f, 4.3204937f, 1.3540066f, - 8.631338f, 0.f, 9.963693f, 1.1067207f - }); - - sd::ops::cholesky op; - - auto res = op.evaluate({&a}); - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - auto z = res.at(0); - -// z->printIndexedBuffer("L matrix is"); -// exp.printIndexedBuffer("L expected is"); - MmulHelper::matmul(z, z, &exp, false, true); - ASSERT_TRUE(exp.equalsTo(a)); -} - -/////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test1) { - - NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights('c', {2,3,4}, sd::DataType::DOUBLE); - - NDArray dLdpExp('c', {2,3,4}, {-0.96, -1.92, -2.88, -3.84, -4.8 , -5.76, -6.72, -7.68, -8.64, -9.6 ,-10.56,-11.52, - -12.48,-13.44,-14.4 ,-15.36,-16.32,-17.28,-18.24,-19.2 ,-20.16,-21.12,-22.08,-23.04}); - NDArray dLdwExp('c', {2,3,4}, {0.9216 , 3.6864 , 8.2944 , 14.7456 , 23.04 , 33.1776 , 45.1584 , 58.9824 , 74.6496 , 92.16 ,111.51361,132.7104 , - 155.75038,180.63359,207.35999,235.9296 ,266.34238,298.59842,332.6976 ,368.64001,406.4256 ,446.05444,487.5264 ,530.84161}); - - predictions.linspace(0.04, 0.04); - labels.linspace(1); - weights.assign(0.5); - - sd::ops::mean_sqerr_loss_grad op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto dLdp = results.at(0); - auto dLdw = results.at(1); - auto dLdl = results.at(2); - - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); - ASSERT_TRUE(dLdpExp.isSameShape(-*dLdl)); - ASSERT_TRUE(dLdpExp.equalsTo(-*dLdl)); - -} - -/////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test2) { - - NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights('c', {2,1,4}, sd::DataType::DOUBLE); - - NDArray dLdwExp('c', {2,1,4}, {98.61121,129.024 , 164.9664 , 206.4384 , 828.51837,925.28644,1027.58398,1135.41113}); - - predictions.linspace(0.04, 0.04); - labels.linspace(1); - weights.assign(0.5); - - sd::ops::mean_sqerr_loss_grad op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto *dLdw = results.at(1); - - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); - -} - -/////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test3) { - - NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights(sd::DataType::DOUBLE); - - NDArray dLdpExp('c', {2,3,4}, {-0.96, -1.92, -2.88, -3.84, -4.8 , -5.76, -6.72, -7.68, -8.64, -9.6 ,-10.56,-11.52, - -12.48,-13.44,-14.4 ,-15.36,-16.32,-17.28,-18.24,-19.2 ,-20.16,-21.12,-22.08,-23.04}); - NDArray dLdwExp('c', {}, std::vector{4515.84}); - - predictions.linspace(0.04, 0.04); - labels.linspace(1); - weights.assign(0.5); - - sd::ops::mean_sqerr_loss_grad op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto *dLdp = results.at(0); - auto *dLdw = results.at(1); - auto *dLdl = results.at(2); - - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); - ASSERT_TRUE(dLdpExp.isSameShape(-*dLdl)); - ASSERT_TRUE(dLdpExp.equalsTo(-*dLdl)); - -} - -/////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test4) { - - NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights('c', {1,3,1}, sd::DataType::DOUBLE); - - NDArray dLdwExp('c', {1,3,1}, {807.32153, 1426.63684, 2281.88159}); - - predictions.linspace(0.04, 0.04); - labels.linspace(1); - weights.assign(0.5); - - sd::ops::mean_sqerr_loss_grad op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto *dLdw = results.at(1); - - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); - -} - -/////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test5) { - - NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights('c', {2,3,4}, sd::DataType::DOUBLE); - - NDArray dLdpExp('c', {2,3,4}, {-0.08,-0.16,-0.24,-0.32,-0.4 ,-0.48,-0.56,-0.64,-0.72,-0.8 ,-0.88,-0.96, - -1.04,-1.12,-1.2 ,-1.28,-1.36,-1.44,-1.52,-1.6 ,-1.68,-1.76,-1.84,-1.92}); - NDArray dLdwExp('c', {2,3,4}, {-15.6032,-15.3728,-14.9888,-14.4512,-13.76 ,-12.9152,-11.9168,-10.7648, -9.4592, -8. , -6.3872, -4.6208, - -2.7008, -0.6272, 1.6 , 3.9808, 6.5152, 9.2032, 12.0448, 15.04 , 18.1888, 21.4912, 24.9472, 28.5568}); - - predictions.linspace(0.04, 0.04); - labels.linspace(1); - weights.assign(0.5); - - sd::ops::mean_sqerr_loss_grad op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto *dLdp = results.at(0); - auto *dLdw = results.at(1); - auto *dLdl = results.at(2); - - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); - ASSERT_TRUE(dLdpExp.isSameShape(-*dLdl)); - ASSERT_TRUE(dLdpExp.equalsTo(-*dLdl)); - -} - -/////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test6) { - - NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights('c', {1,3,1}, sd::DataType::DOUBLE); - - NDArray dLdwExp('c', {1,3,1}, {-58.16319, -6.5536 , 64.71682}); - - predictions.linspace(0.04, 0.04); - labels.linspace(1); - weights.assign(0.5); - - sd::ops::mean_sqerr_loss_grad op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto *dLdw = results.at(1); - - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); - -} - -/////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test7) { - - NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights(sd::DataType::DOUBLE); - - NDArray dLdwExp('c', {}, std::vector{0.}); - - predictions.linspace(0.04, 0.04); - labels.linspace(1); - weights.assign(0.5); - - sd::ops::mean_sqerr_loss_grad op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto *dLdw = results.at(1); - - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); - -} - -/////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test8) { - - NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights('c', {2,3,4}, sd::DataType::DOUBLE); - - NDArray dLdpExp('c', {2,3,4}, {0. ,0. ,0. ,0. ,-0.48 ,-0.576,-0.672,-0.768,-0.864,-0.96 ,-1.056,-1.152, - -1.248,-1.344,-1.44 ,-1.536,-1.632,-1.728,-1.824,-1.92 ,-2.016,-2.112,-2.208,-2.304}); - NDArray dLdwExp('c', {2,3,4}, {-22.3488 ,-22.07232,-21.61152,-20.9664 ,-20.13696,-19.1232 ,-17.92512,-16.54272,-14.976 ,-13.22496,-11.2896 , -9.16992, - -6.86592, -4.3776 , -1.70496, 1.152 , 4.19328, 7.41888, 10.8288 , 14.42304, 18.2016 , 22.16449, 26.31168, 30.6432 }); - - predictions.linspace(0.04, 0.04); - labels.linspace(1); - weights.assign(0.5); - weights.p(0, 0.); - weights.p(1, 0.); - weights.p(2, 0.); - weights.p(3, 0.); - - sd::ops::mean_sqerr_loss_grad op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto *dLdp = results.at(0); - auto *dLdw = results.at(1); - auto *dLdl = results.at(2); - - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); - ASSERT_TRUE(dLdpExp.isSameShape(-*dLdl)); - ASSERT_TRUE(dLdpExp.equalsTo(-*dLdl)); - -} - -/////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test9) { - - NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights('c', {2,3,4}, sd::DataType::DOUBLE); - - NDArray dLdpExp('c', {2,3,4}, {-0.04,-0.08,-0.12,-0.16,-0.2 ,-0.24,-0.28,-0.32,-0.36,-0.4 ,-0.44,-0.48, - -0.52,-0.56,-0.6 ,-0.64,-0.68,-0.72,-0.76,-0.8 ,-0.84,-0.88,-0.92,-0.96}); - NDArray dLdwExp('c', {2,3,4}, {0.0384, 0.1536, 0.3456, 0.6144, 0.96 , 1.3824, 1.8816, 2.4576, 3.1104, 3.84 , 4.6464, 5.5296, - 6.4896, 7.5264, 8.64 , 9.8304,11.0976,12.4416,13.8624,15.36 ,16.9344,18.5856,20.3136,22.1184}); - - predictions.linspace(0.04, 0.04); - labels.linspace(1); - weights.assign(0.5); - - sd::ops::mean_sqerr_loss_grad op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto *dLdp = results.at(0); - auto *dLdw = results.at(1); - auto *dLdl = results.at(2); - - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); - ASSERT_TRUE(dLdpExp.isSameShape(-*dLdl)); - ASSERT_TRUE(dLdpExp.equalsTo(-*dLdl)); - -} - -/////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test10) { - - NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights('c', {1,1}, sd::DataType::DOUBLE); - - NDArray dLdwExp('c', {1,1}, std::vector{188.16}); - - predictions.linspace(0.04, 0.04); - labels.linspace(1); - weights.assign(0.5); - - sd::ops::mean_sqerr_loss_grad op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto *dLdw = results.at(1); - - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); - -} - -/////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test11) { - - NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights('c', {1,3,1}, sd::DataType::DOUBLE); - - NDArray dLdwExp('c', {1,3,1}, {33.6384 ,59.4432 ,95.07841}); - - predictions.linspace(0.04, 0.04); - labels.linspace(1); - weights.assign(0.5); - - sd::ops::mean_sqerr_loss_grad op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto *dLdw = results.at(1); - - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); - -} - -/////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test12) { - - NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights('c', {2,3,4}, sd::DataType::DOUBLE); - - NDArray dLdpExp('c', {2,3,4}, {0.,0.,0.,0., -0.24 ,-0.288,-0.336,-0.384,-0.432,-0.48 ,-0.528,-0.576, - -0.624,-0.672,-0.72 ,-0.768,-0.816,-0.864,-0.912,-0.96 ,-1.008,-1.056,-1.104,-1.152}); - NDArray dLdwExp('c', {2,3,4}, {0.04608, 0.18432, 0.41472, 0.73728, 1.152 , 1.65888, 2.25792, 2.94912, 3.73248, 4.608 , 5.57568, 6.63552, - 7.78752, 9.03168,10.368 ,11.79648,13.31712,14.92992,16.63488,18.432 ,20.32128,22.30272,24.37632,26.54208}); - - predictions.linspace(0.04, 0.04); - labels.linspace(1); - weights.assign(0.5); - weights.r(0) = 0.; - weights.r(1) = 0.; - weights.r(2) = 0.; - weights.r(3) = 0.; - - sd::ops::mean_sqerr_loss_grad op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto *dLdp = results.at(0); - auto *dLdw = results.at(1); - auto *dLdl = results.at(2); - - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); - ASSERT_TRUE(dLdpExp.isSameShape(-*dLdl)); - ASSERT_TRUE(dLdpExp.equalsTo(-*dLdl)); - -} - -/////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test13) { - - NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights('c', {2,3,1}, sd::DataType::DOUBLE); - - NDArray dLdpExp('c', {2,3,4}, {0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., - -1.04,-1.12,-1.2 ,-1.28,-1.36,-1.44,-1.52,-1.6 ,-1.68,-1.76,-1.84,-1.92}); - NDArray dLdwExp('c', {2,3,1}, {2.304 , 13.3632 , 34.2528 , 64.97279,105.5232 ,155.90401}); - - predictions.linspace(0.04, 0.04); - labels.linspace(1); - weights.assign(0.5); - weights.r(0) = 0.; - weights.r(1) = 0.; - weights.r(2) = 0.; - - sd::ops::mean_sqerr_loss_grad op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto *dLdp = results.at(0); - auto *dLdw = results.at(1); - auto *dLdl = results.at(2); - - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); - ASSERT_TRUE(dLdpExp.isSameShape(-*dLdl)); - ASSERT_TRUE(dLdpExp.equalsTo(-*dLdl)); - -} - -TEST_F(DeclarableOpsTests11, SquaredSubtractTest_Test1) { - auto x = NDArrayFactory::create('c', {4}, {0, 1, 2, 3}); - auto y = NDArrayFactory::create('c',{4}, {3, 2, 1, 0}); - auto exp = NDArrayFactory::create('c', {4}, {9, 1,1, 9}); - sd::ops::squaredsubtract op; - auto result = op.evaluate({&x, &y}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - - -} - -TEST_F(DeclarableOpsTests11, SquaredSubtractTest_Test2) { - auto x = NDArrayFactory::create('c', {2, 4}, {0, 1, 2, 3, 0, 1, 2, 3}); - auto y = NDArrayFactory::create('c',{4}, {3, 2, 1, 0}); - auto exp = NDArrayFactory::create('c', {2, 4}, {9, 1,1, 9, 9, 1, 1, 9}); - sd::ops::squaredsubtract op; - auto result = op.evaluate({&x, &y}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - -} - -TEST_F(DeclarableOpsTests11, SquaredSubtractTest_Test3) { - auto x = NDArrayFactory::create('c', {2, 4}, {0, 1, 2, 3, 0, 1, 2, 3}); - auto y = NDArrayFactory::create('c',{4}, {3, 2, 1, 0}); - auto exp = NDArrayFactory::create('c', {2, 4}, {-6, -4, 6, 24, -30, -12, 14, 48}); - auto eps = NDArrayFactory::create('c', {2, 4}, {1,2,3,4,5,6,7,8}); - sd::ops::squaredsubtract_bp op; - auto result = op.evaluate({&x, &y, &eps}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - -} - -/////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test1) { - - NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights('c', {2,3,4}, sd::DataType::DOUBLE); - - NDArray dLdpExp('c', {2,3,4}, {-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5, - -0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5}); - NDArray dLdwExp('c', {2,3,4}, {0.96, 1.92, 2.88, 3.84, 4.8 , 5.76, 6.72, 7.68, 8.64, 9.6 ,10.56,11.52, - 12.48,13.44,14.4 ,15.36,16.32,17.28,18.24,19.2 ,20.16,21.12,22.08,23.04}); - - predictions.linspace(0.04, 0.04); - labels.linspace(1); - weights.assign(0.5); - - sd::ops::absolute_difference_loss_grad op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto dLdp = results.at(0); - auto dLdw = results.at(1); - auto dLdl = results.at(2); - - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); - ASSERT_TRUE(dLdpExp.isSameShape(-*dLdl)); - ASSERT_TRUE(dLdpExp.equalsTo(-*dLdl)); -} - -/////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test2) { - - NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights('c', {2,1,4}, sd::DataType::DOUBLE); - - NDArray dLdwExp('c', {2,1,4}, {14.4 , 17.28, 20.16, 23.04, 48.96, 51.84, 54.72, 57.6}); - - predictions.linspace(0.04, 0.04); - labels.linspace(1); - weights.assign(0.5); - - sd::ops::absolute_difference_loss_grad op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto *dLdw = results.at(1); - - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); - -} - -/////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test3) { - - NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights(sd::DataType::DOUBLE); - - NDArray dLdpExp('c', {2,3,4}, {-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5, - -0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5}); - NDArray dLdwExp('c', {}, std::vector{288.}); - - predictions.linspace(0.04, 0.04); - labels.linspace(1); - weights.assign(0.5); - - sd::ops::absolute_difference_loss_grad op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto *dLdp = results.at(0); - auto *dLdw = results.at(1); - auto *dLdl = results.at(2); - - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); - ASSERT_TRUE(dLdpExp.isSameShape(-*dLdl)); - ASSERT_TRUE(dLdpExp.equalsTo(-*dLdl)); - -} - -/////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test4) { - - NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights('c', {1,3,1}, sd::DataType::DOUBLE); - - NDArray dLdwExp('c', {1,3,1}, {65.28, 96., 126.72001}); - - predictions.linspace(0.04, 0.04); - labels.linspace(1); - weights.assign(0.5); - - sd::ops::absolute_difference_loss_grad op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto *dLdw = results.at(1); - - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); - -} - -/////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test5) { - - NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights('c', {2,3,4}, sd::DataType::DOUBLE); - - NDArray dLdpExp('c', {2,3,4}, {-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167, - -0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167}); - NDArray dLdwExp('c', {2,3,4}, {-0.92,-0.84,-0.76,-0.68,-0.6 ,-0.52,-0.44,-0.36,-0.28,-0.2 ,-0.12,-0.04, - 0.04, 0.12, 0.2 , 0.28, 0.36, 0.44, 0.52, 0.6 , 0.68, 0.76, 0.84, 0.92}); - - predictions.linspace(0.04, 0.04); - labels.linspace(1); - weights.assign(0.5); - - sd::ops::absolute_difference_loss_grad op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto *dLdp = results.at(0); - auto *dLdw = results.at(1); - auto *dLdl = results.at(2); - - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); - ASSERT_TRUE(dLdpExp.isSameShape(-*dLdl)); - ASSERT_TRUE(dLdpExp.equalsTo(-*dLdl)); - -} - -/////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test6) { - - NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights('c', {1,3,1}, sd::DataType::DOUBLE); - - NDArray dLdwExp('c', {1,3,1}, {-2.56, 0., 2.56}); - - predictions.linspace(0.04, 0.04); - labels.linspace(1); - weights.assign(0.5); - - sd::ops::absolute_difference_loss_grad op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto *dLdw = results.at(1); - - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); - -} - -/////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test7) { - - NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights(sd::DataType::DOUBLE); + auto b = + NDArrayFactory::create('c', {2, 2, 1}, {3.f, 7.f, 11.f, 15.f}); - NDArray dLdwExp('c', {}, std::vector{0.}); + auto exp = NDArrayFactory::create( + 'c', {2, 2, 1}, {0.8311695f, 1.0909086f, 0.9205573f, 1.0630057f}); - predictions.linspace(0.04, 0.04); - labels.linspace(1); - weights.assign(0.5); + sd::ops::lstsq op; - sd::ops::absolute_difference_loss_grad op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); + auto res = op.evaluate({&a, &b}, {0.5}, {}, {true}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + // z->printIndexedBuffer("LS Solve 2x2"); + // exp.printIndexedBuffer("LS Expec 2x2"); - auto *dLdw = results.at(1); + ASSERT_TRUE(exp.equalsTo(z, 1.e-4)); +} - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); +TEST_F(DeclarableOpsTests11, SolveLS_Test_2) { + auto a = NDArrayFactory::create( + 'c', {2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f}); -} + auto b = + NDArrayFactory::create('c', {2, 2, 1}, {3.f, 7.f, 11.f, 15.f}); -/////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test8) { + auto exp = NDArrayFactory::create( + 'c', {2, 2, 1}, {0.8311695f, 1.0909086f, 0.9205573f, 1.0630057f}); - NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights('c', {2,3,4}, sd::DataType::DOUBLE); + sd::ops::lstsq op; - NDArray dLdpExp('c', {2,3,4}, {-0. ,-0. ,-0. ,-0. ,-0.05,-0.05,-0.05,-0.05,-0.05,-0.05,-0.05,-0.05, - -0.05,-0.05,-0.05,-0.05,-0.05,-0.05,-0.05,-0.05,-0.05,-0.05,-0.05,-0.05}); - NDArray dLdwExp('c', {2,3,4}, {-1.296,-1.2 ,-1.104,-1.008,-0.912,-0.816,-0.72 ,-0.624,-0.528,-0.432,-0.336,-0.24 , - -0.144,-0.048, 0.048, 0.144, 0.24 , 0.336, 0.432, 0.528, 0.624, 0.72 , 0.816, 0.912}); + auto res = op.evaluate({&a, &b}, {0.5}, {}, {true}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); - predictions.linspace(0.04, 0.04); - labels.linspace(1); - weights.assign(0.5); - weights.p(0, 0.); - weights.p(1, 0.); - weights.p(2, 0.); - weights.p(3, 0.); + // z->printIndexedBuffer("2LS Solve 2x2"); + // exp.printIndexedBuffer("2LS Expec 2x2"); - sd::ops::absolute_difference_loss_grad op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); + ASSERT_TRUE(exp.equalsTo(z, 1.e-4)); +} - ASSERT_EQ(ND4J_STATUS_OK, results.status()); +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, Cholesky_Test_2x2x2) { + auto a = NDArrayFactory::create('c', {2, 2, 2}, + {10.f, 14.f, 14.f, 20.f, - auto *dLdp = results.at(0); - auto *dLdw = results.at(1); - auto *dLdl = results.at(2); + 74.f, 86.f, 86.f, 100.f}); - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); - ASSERT_TRUE(dLdpExp.isSameShape(-*dLdl)); - ASSERT_TRUE(dLdpExp.equalsTo(-*dLdl)); + auto exp = + NDArrayFactory::create('c', {2, 2, 2}, + {3.1622777f, 0.f, 4.427189f, 0.6324552f, + 8.602325f, 0.f, 9.997296f, 0.23252854f}); -} + sd::ops::cholesky op; -/////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test9) { + auto res = op.evaluate({&a}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); - NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights('c', {2,3,4}, sd::DataType::DOUBLE); + // z->printIndexedBuffer("L matrix is"); + // exp.printIndexedBuffer("L expected is"); - NDArray dLdpExp('c', {2,3,4}, {-0.02083, -0.02083, -0.02083, -0.02083,-0.02083, -0.02083, -0.02083, -0.02083,-0.02083, -0.02083, -0.02083, -0.02083, - -0.02083, -0.02083, -0.02083, -0.02083,-0.02083, -0.02083, -0.02083, -0.02083,-0.02083, -0.02083, -0.02083, -0.02083}); - NDArray dLdwExp('c', {2,3,4}, {0.04, 0.08, 0.12, 0.16, 0.2 , 0.24, 0.28, 0.32,0.36, 0.4 , 0.44, 0.48, - 0.52, 0.56, 0.6 , 0.64,0.68, 0.72, 0.76, 0.8 ,0.84, 0.88, 0.92, 0.96}); + ASSERT_TRUE(exp.equalsTo(z)); +} - predictions.linspace(0.04, 0.04); - labels.linspace(1); - weights.assign(0.5); +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, Cholesky_Test_2x2x2_2) { + auto a = NDArrayFactory::create('c', {2, 2, 2}, + {10.5f, 14.f, 14.f, 20.5f, - sd::ops::absolute_difference_loss_grad op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); + 74.5f, 86.f, 86.f, 100.5f}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto exp = + NDArrayFactory::create('c', {2, 2, 2}, + {3.2403703f, 0.f, 4.3204937f, 1.3540066f, + 8.631338f, 0.f, 9.963693f, 1.1067207f}); - auto *dLdp = results.at(0); - auto *dLdw = results.at(1); - auto *dLdl = results.at(2); + sd::ops::cholesky op; - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); - ASSERT_TRUE(dLdpExp.isSameShape(-*dLdl)); - ASSERT_TRUE(dLdpExp.equalsTo(-*dLdl)); + auto res = op.evaluate({&a}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); + // z->printIndexedBuffer("L matrix is"); + // exp.printIndexedBuffer("L expected is"); + MmulHelper::matmul(&z, &z, &exp, false, true); + ASSERT_TRUE(exp.equalsTo(a)); } /////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test10) { - - NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights('c', {1,1}, sd::DataType::DOUBLE); +TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test1) { + NDArray labels('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {2, 3, 4}, sd::DataType::DOUBLE); + + NDArray dLdpExp( + 'c', {2, 3, 4}, + {-0.96, -1.92, -2.88, -3.84, -4.8, -5.76, -6.72, -7.68, + -8.64, -9.6, -10.56, -11.52, -12.48, -13.44, -14.4, -15.36, + -16.32, -17.28, -18.24, -19.2, -20.16, -21.12, -22.08, -23.04}); + NDArray dLdwExp( + 'c', {2, 3, 4}, + {0.9216, 3.6864, 8.2944, 14.7456, 23.04, 33.1776, + 45.1584, 58.9824, 74.6496, 92.16, 111.51361, 132.7104, + 155.75038, 180.63359, 207.35999, 235.9296, 266.34238, 298.59842, + 332.6976, 368.64001, 406.4256, 446.05444, 487.5264, 530.84161}); + + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); + + sd::ops::mean_sqerr_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdpExp.isSameShape(-dLdl)); + ASSERT_TRUE(dLdpExp.equalsTo(-dLdl)); +} - NDArray dLdwExp('c', {1,1}, std::vector{12.}); +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test2) { + NDArray labels('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {2, 1, 4}, sd::DataType::DOUBLE); - predictions.linspace(0.04, 0.04); - labels.linspace(1); - weights.assign(0.5); + NDArray dLdwExp('c', {2, 1, 4}, + {98.61121, 129.024, 164.9664, 206.4384, 828.51837, 925.28644, + 1027.58398, 1135.41113}); - sd::ops::absolute_difference_loss_grad op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::mean_sqerr_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0}); - auto *dLdw = results.at(1); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + auto dLdw = results.at(1); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); } /////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test11) { - - NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights('c', {1,3,1}, sd::DataType::DOUBLE); +TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test3) { + NDArray labels('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights(sd::DataType::DOUBLE); + + NDArray dLdpExp( + 'c', {2, 3, 4}, + {-0.96, -1.92, -2.88, -3.84, -4.8, -5.76, -6.72, -7.68, + -8.64, -9.6, -10.56, -11.52, -12.48, -13.44, -14.4, -15.36, + -16.32, -17.28, -18.24, -19.2, -20.16, -21.12, -22.08, -23.04}); + NDArray dLdwExp('c', {}, std::vector{4515.84}); + + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); + + sd::ops::mean_sqerr_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdpExp.isSameShape(-dLdl)); + ASSERT_TRUE(dLdpExp.equalsTo(-dLdl)); +} - NDArray dLdwExp('c', {1,3,1}, {2.72, 4., 5.28}); +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test4) { + NDArray labels('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {1, 3, 1}, sd::DataType::DOUBLE); - predictions.linspace(0.04, 0.04); - labels.linspace(1); - weights.assign(0.5); + NDArray dLdwExp('c', {1, 3, 1}, {807.32153, 1426.63684, 2281.88159}); - sd::ops::absolute_difference_loss_grad op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::mean_sqerr_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1}); - auto *dLdw = results.at(1); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + auto dLdw = results.at(1); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); } /////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test12) { - - NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights('c', {2,3,4}, sd::DataType::DOUBLE); +TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test5) { + NDArray labels('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {2, 3, 4}, sd::DataType::DOUBLE); + + NDArray dLdpExp('c', {2, 3, 4}, + {-0.08, -0.16, -0.24, -0.32, -0.4, -0.48, -0.56, -0.64, + -0.72, -0.8, -0.88, -0.96, -1.04, -1.12, -1.2, -1.28, + -1.36, -1.44, -1.52, -1.6, -1.68, -1.76, -1.84, -1.92}); + NDArray dLdwExp('c', {2, 3, 4}, + {-15.6032, -15.3728, -14.9888, -14.4512, -13.76, -12.9152, + -11.9168, -10.7648, -9.4592, -8., -6.3872, -4.6208, + -2.7008, -0.6272, 1.6, 3.9808, 6.5152, 9.2032, + 12.0448, 15.04, 18.1888, 21.4912, 24.9472, 28.5568}); + + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); + + sd::ops::mean_sqerr_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdpExp.isSameShape(-dLdl)); + ASSERT_TRUE(dLdpExp.equalsTo(-dLdl)); +} - NDArray dLdpExp('c', {2,3,4}, {0., 0., 0., 0., -0.025, -0.025, -0.025, -0.025,-0.025, -0.025, -0.025, -0.025, - -0.025, -0.025, -0.025, -0.025,-0.025, -0.025, -0.025, -0.025,-0.025, -0.025, -0.025, -0.025}); - NDArray dLdwExp('c', {2,3,4}, {0.048, 0.096, 0.144, 0.192,0.24 , 0.288, 0.336, 0.384,0.432, 0.48 , 0.528, 0.576, - 0.624, 0.672, 0.72 , 0.768,0.816, 0.864, 0.912, 0.96 ,1.008, 1.056, 1.104, 1.152}); +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test6) { + NDArray labels('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {1, 3, 1}, sd::DataType::DOUBLE); - predictions.linspace(0.04, 0.04); - labels.linspace(1); - weights.assign(0.5); - weights.r(0) = 0.; - weights.r(1) = 0.; - weights.r(2) = 0.; - weights.r(3) = 0.; + NDArray dLdwExp('c', {1, 3, 1}, {-58.16319, -6.5536, 64.71682}); - sd::ops::absolute_difference_loss_grad op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::mean_sqerr_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); - auto *dLdp = results.at(0); - auto *dLdw = results.at(1); - auto *dLdl = results.at(2); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); - ASSERT_TRUE(dLdpExp.isSameShape(-*dLdl)); - ASSERT_TRUE(dLdpExp.equalsTo(-*dLdl)); + auto dLdw = results.at(1); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); } /////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test13) { +TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test7) { + NDArray labels('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights(sd::DataType::DOUBLE); - NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights('c', {2,3,1}, sd::DataType::DOUBLE); + NDArray dLdwExp('c', {}, std::vector{0.}); - NDArray dLdpExp('c', {2,3,4}, {0., 0., 0., 0., 0., 0., 0., 0.,0., 0., 0., 0., - -0.04167, -0.04167, -0.04167, -0.04167,-0.04167, -0.04167, -0.04167, -0.04167,-0.04167, -0.04167, -0.04167, -0.04167}); - NDArray dLdwExp('c', {2,3,1}, {0.8 ,2.08,3.36,4.64,5.92,7.2 }); + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); - predictions.linspace(0.04, 0.04); - labels.linspace(1); - weights.assign(0.5); - weights.r(0) = 0.; - weights.r(1) = 0.; - weights.r(2) = 0.; + sd::ops::mean_sqerr_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); - sd::ops::absolute_difference_loss_grad op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto dLdw = results.at(1); - auto *dLdp = results.at(0); - auto *dLdw = results.at(1); - auto *dLdl = results.at(2); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); +} - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); - ASSERT_TRUE(dLdpExp.isSameShape(-*dLdl)); - ASSERT_TRUE(dLdpExp.equalsTo(-*dLdl)); +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test8) { + NDArray labels('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {2, 3, 4}, sd::DataType::DOUBLE); + + NDArray dLdpExp( + 'c', {2, 3, 4}, + {0., 0., 0., 0., -0.48, -0.576, -0.672, -0.768, + -0.864, -0.96, -1.056, -1.152, -1.248, -1.344, -1.44, -1.536, + -1.632, -1.728, -1.824, -1.92, -2.016, -2.112, -2.208, -2.304}); + NDArray dLdwExp( + 'c', {2, 3, 4}, + {-22.3488, -22.07232, -21.61152, -20.9664, -20.13696, -19.1232, + -17.92512, -16.54272, -14.976, -13.22496, -11.2896, -9.16992, + -6.86592, -4.3776, -1.70496, 1.152, 4.19328, 7.41888, + 10.8288, 14.42304, 18.2016, 22.16449, 26.31168, 30.6432}); + + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); + weights.p(0, 0.); + weights.p(1, 0.); + weights.p(2, 0.); + weights.p(3, 0.); + + sd::ops::mean_sqerr_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdpExp.isSameShape(-dLdl)); + ASSERT_TRUE(dLdpExp.equalsTo(-dLdl)); +} +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test9) { + NDArray labels('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {2, 3, 4}, sd::DataType::DOUBLE); + + NDArray dLdpExp('c', {2, 3, 4}, + {-0.04, -0.08, -0.12, -0.16, -0.2, -0.24, -0.28, -0.32, + -0.36, -0.4, -0.44, -0.48, -0.52, -0.56, -0.6, -0.64, + -0.68, -0.72, -0.76, -0.8, -0.84, -0.88, -0.92, -0.96}); + NDArray dLdwExp( + 'c', {2, 3, 4}, + {0.0384, 0.1536, 0.3456, 0.6144, 0.96, 1.3824, 1.8816, 2.4576, + 3.1104, 3.84, 4.6464, 5.5296, 6.4896, 7.5264, 8.64, 9.8304, + 11.0976, 12.4416, 13.8624, 15.36, 16.9344, 18.5856, 20.3136, 22.1184}); + + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); + + sd::ops::mean_sqerr_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdpExp.isSameShape(-dLdl)); + ASSERT_TRUE(dLdpExp.equalsTo(-dLdl)); } /////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests11, BFloat16_Test_1) { +TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test10) { + NDArray labels('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {1, 1}, sd::DataType::DOUBLE); - NDArray x = NDArrayFactory::create('c', {2,3,4}); - NDArray y = NDArrayFactory::create('c', {2,3,4});//('c', {2,3,4}, sd::DataType::BFLOAT16); - NDArray exp = NDArrayFactory::create('c', {2,3,4});//('c', {2,3,4}, sd::DataType::BFLOAT16); + NDArray dLdwExp('c', {1, 1}, std::vector{188.16}); - x.linspace(1); - y.linspace(1); - exp.linspace(2,2); - sd::ops::add op; - auto results = op.evaluate({&x, &y}, {}, {}); + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::mean_sqerr_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); - auto res = results.at(0); - ASSERT_TRUE(res->equalsTo(exp)); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto dLdw = results.at(1); + + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); } /////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests11, BFloat16_Test_2) { +TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test11) { + NDArray labels('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {1, 3, 1}, sd::DataType::DOUBLE); - NDArray x = NDArrayFactory::create('c', {2,3,4}); - NDArray y = NDArrayFactory::create('c', {2,3,4});//('c', {2,3,4}, sd::DataType::BFLOAT16); - NDArray exp = NDArrayFactory::create('c', {2,3,4});//('c', {2,3,4}, sd::DataType::BFLOAT16); + NDArray dLdwExp('c', {1, 3, 1}, {33.6384, 59.4432, 95.07841}); - x.linspace(1); - y.linspace(1); - exp.linspace(2,2); - sd::ops::add op; - auto results = op.evaluate({&x, &y}, {}, {}); + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::mean_sqerr_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); - auto res = results.at(0); - ASSERT_TRUE(res->equalsTo(exp)); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto dLdw = results.at(1); + + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); } /////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests11, BFloat16_Test_3) { +TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test12) { + NDArray labels('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {2, 3, 4}, sd::DataType::DOUBLE); + + NDArray dLdpExp( + 'c', {2, 3, 4}, + {0., 0., 0., 0., -0.24, -0.288, -0.336, -0.384, + -0.432, -0.48, -0.528, -0.576, -0.624, -0.672, -0.72, -0.768, + -0.816, -0.864, -0.912, -0.96, -1.008, -1.056, -1.104, -1.152}); + NDArray dLdwExp('c', {2, 3, 4}, + {0.04608, 0.18432, 0.41472, 0.73728, 1.152, 1.65888, + 2.25792, 2.94912, 3.73248, 4.608, 5.57568, 6.63552, + 7.78752, 9.03168, 10.368, 11.79648, 13.31712, 14.92992, + 16.63488, 18.432, 20.32128, 22.30272, 24.37632, 26.54208}); + + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); + weights.r(0) = 0.; + weights.r(1) = 0.; + weights.r(2) = 0.; + weights.r(3) = 0.; + + sd::ops::mean_sqerr_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdpExp.isSameShape(-dLdl)); + ASSERT_TRUE(dLdpExp.equalsTo(-dLdl)); +} - NDArray x('c', {2,3,4}, sd::DataType::BFLOAT16); - NDArray y('c', {2,3,4}, sd::DataType::BFLOAT16); - NDArray exp('c', {2,3,4}, sd::DataType::BFLOAT16); +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test13) { + NDArray labels('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {2, 3, 1}, sd::DataType::DOUBLE); + + NDArray dLdpExp('c', {2, 3, 4}, + {0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., -1.04, -1.12, -1.2, -1.28, + -1.36, -1.44, -1.52, -1.6, -1.68, -1.76, -1.84, -1.92}); + NDArray dLdwExp('c', {2, 3, 1}, + {2.304, 13.3632, 34.2528, 64.97279, 105.5232, 155.90401}); + + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); + weights.r(0) = 0.; + weights.r(1) = 0.; + weights.r(2) = 0.; + + sd::ops::mean_sqerr_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdpExp.isSameShape(-dLdl)); + ASSERT_TRUE(dLdpExp.equalsTo(-dLdl)); +} - x.linspace(1); - y.linspace(1); - exp.linspace(2,2); - sd::ops::add op; - auto results = op.evaluate({&x, &y}, {}, {}); +TEST_F(DeclarableOpsTests11, SquaredSubtractTest_Test1) { + auto x = NDArrayFactory::create('c', {4}, {0, 1, 2, 3}); + auto y = NDArrayFactory::create('c', {4}, {3, 2, 1, 0}); + auto exp = NDArrayFactory::create('c', {4}, {9, 1, 1, 9}); + sd::ops::squaredsubtract op; + auto result = op.evaluate({&x, &y}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(exp.equalsTo(result.at(0))); +} - ASSERT_EQ(ND4J_STATUS_OK, results.status()); +TEST_F(DeclarableOpsTests11, SquaredSubtractTest_Test2) { + auto x = NDArrayFactory::create('c', {2, 4}, {0, 1, 2, 3, 0, 1, 2, 3}); + auto y = NDArrayFactory::create('c', {4}, {3, 2, 1, 0}); + auto exp = + NDArrayFactory::create('c', {2, 4}, {9, 1, 1, 9, 9, 1, 1, 9}); + sd::ops::squaredsubtract op; + auto result = op.evaluate({&x, &y}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(exp.equalsTo(result.at(0))); +} - auto res = results.at(0); - ASSERT_TRUE(res->equalsTo(exp)); +TEST_F(DeclarableOpsTests11, SquaredSubtractTest_Test3) { + auto x = NDArrayFactory::create('c', {2, 4}, {0, 1, 2, 3, 0, 1, 2, 3}); + auto y = NDArrayFactory::create('c', {4}, {3, 2, 1, 0}); + auto exp = NDArrayFactory::create('c', {2, 4}, + {-6, -4, 6, 24, -30, -12, 14, 48}); + auto eps = + NDArrayFactory::create('c', {2, 4}, {1, 2, 3, 4, 5, 6, 7, 8}); + sd::ops::squaredsubtract_bp op; + auto result = op.evaluate({&x, &y, &eps}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } /////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test1) { - - NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray logits('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights('c', {2,3,4}, sd::DataType::DOUBLE); +TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test1) { + NDArray labels('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {2, 3, 4}, sd::DataType::DOUBLE); + + NDArray dLdpExp( + 'c', {2, 3, 4}, + {-0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, + -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5}); + NDArray dLdwExp('c', {2, 3, 4}, + {0.96, 1.92, 2.88, 3.84, 4.8, 5.76, 6.72, 7.68, + 8.64, 9.6, 10.56, 11.52, 12.48, 13.44, 14.4, 15.36, + 16.32, 17.28, 18.24, 19.2, 20.16, 21.12, 22.08, 23.04}); + + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); + + sd::ops::absolute_difference_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdpExp.isSameShape(-dLdl)); + ASSERT_TRUE(dLdpExp.equalsTo(-dLdl)); +} - NDArray dLdpExp('c', {2,3,4}, {-0.25999, -0.755 , -1.25 , -1.745 , -2.24001, -2.73502, -3.23004, -3.72508, -4.22014, -4.71523, -5.21034, -5.70548, - -6.20066, -6.69587, -7.19113, -7.68643, -8.18177, -8.67717, -9.17262, -9.66813,-10.1637 ,-10.65932,-11.15501,-11.65077}); - NDArray dLdwExp('c', {2,3,4}, {0.73395, 0.75335, 0.69315, 0.55335, 0.33395, 0.03495, -0.34366, -0.80186, -1.33967, -1.95708, -2.65411, -3.43074, - -4.28698, -5.22285, -6.23833, -7.33343, -8.50815, -9.76251,-11.0965 ,-12.51013,-14.00341,-15.57633,-17.2289 ,-18.96113}); - NDArray dLdlExp('c', {2,3,4}, {0.04, 0.02,-0. ,-0.02,-0.04,-0.06,-0.08,-0.1 ,-0.12,-0.14,-0.16,-0.18, - -0.2 ,-0.22,-0.24,-0.26,-0.28,-0.3 ,-0.32,-0.34,-0.36,-0.38,-0.4 ,-0.42}); +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test2) { + NDArray labels('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {2, 1, 4}, sd::DataType::DOUBLE); - logits.linspace(-0.08, 0.04); - labels.linspace(1); - weights.assign(0.5); + NDArray dLdwExp('c', {2, 1, 4}, + {14.4, 17.28, 20.16, 23.04, 48.96, 51.84, 54.72, 57.6}); - sd::ops::sigm_cross_entropy_loss_grad op; - auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {0}); + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::absolute_difference_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0}); - auto *dLdp = results.at(0); - auto *dLdw = results.at(1); - auto *dLdl = results.at(2); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); - ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); - ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); + auto dLdw = results.at(1); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); } /////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test2) { - - NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray logits('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights('c', {2,1,4}, sd::DataType::DOUBLE); +TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test3) { + NDArray labels('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights(sd::DataType::DOUBLE); + + NDArray dLdpExp( + 'c', {2, 3, 4}, + {-0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, + -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5}); + NDArray dLdwExp('c', {}, std::vector{288.}); + + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); + + sd::ops::absolute_difference_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdpExp.isSameShape(-dLdl)); + ASSERT_TRUE(dLdpExp.equalsTo(-dLdl)); +} - NDArray dLdpExp('c', {2,3,4}, {-0.18499,-0.53 ,-0.875 ,-1.22 ,-1.56501,-1.91002,-2.25504,-2.60008,-2.94514,-3.29023,-3.63534,-3.98048, - -4.32566,-4.67087,-5.01613,-5.36143,-5.70677,-6.05217,-6.39762,-6.74313,-7.0887 ,-7.43432,-7.78001,-8.12577}); - NDArray dLdwExp('c', {2,1,4}, {0.43622, -0.19079, -0.98462, -1.94525,-18.09855,-20.72768,-23.52373,-26.48669}); - NDArray dLdlExp('c', {2,3,4}, {0.028, 0.014, -0. , -0.014,-0.028, -0.042, -0.056, -0.07 ,-0.084, -0.098, -0.112, -0.126, - -0.14 , -0.154, -0.168, -0.182,-0.196, -0.21 , -0.224, -0.238,-0.252, -0.266, -0.28 , -0.294}); +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test4) { + NDArray labels('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {1, 3, 1}, sd::DataType::DOUBLE); - logits.linspace(-0.08, 0.04); - labels.linspace(1); - weights.assign(0.5); + NDArray dLdwExp('c', {1, 3, 1}, {65.28, 96., 126.72001}); - sd::ops::sigm_cross_entropy_loss_grad op; - auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {0}); + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::absolute_difference_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1}); - auto *dLdp = results.at(0); - auto *dLdw = results.at(1); - auto *dLdl = results.at(2); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); - ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); - ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); + auto dLdw = results.at(1); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); } /////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test3) { - - NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray logits('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights(sd::DataType::DOUBLE); +TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test5) { + NDArray labels('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {2, 3, 4}, sd::DataType::DOUBLE); + + NDArray dLdpExp('c', {2, 3, 4}, + {-0.04167, -0.04167, -0.04167, -0.04167, -0.04167, -0.04167, + -0.04167, -0.04167, -0.04167, -0.04167, -0.04167, -0.04167, + -0.04167, -0.04167, -0.04167, -0.04167, -0.04167, -0.04167, + -0.04167, -0.04167, -0.04167, -0.04167, -0.04167, -0.04167}); + NDArray dLdwExp('c', {2, 3, 4}, + {-0.92, -0.84, -0.76, -0.68, -0.6, -0.52, -0.44, -0.36, + -0.28, -0.2, -0.12, -0.04, 0.04, 0.12, 0.2, 0.28, + 0.36, 0.44, 0.52, 0.6, 0.68, 0.76, 0.84, 0.92}); + + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); + + sd::ops::absolute_difference_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdpExp.isSameShape(-dLdl)); + ASSERT_TRUE(dLdpExp.equalsTo(-dLdl)); +} - NDArray dLdpExp('c', {2,3,4}, {-0.18499,-0.53 ,-0.875 ,-1.22 ,-1.56501,-1.91002,-2.25504,-2.60008,-2.94514,-3.29023,-3.63534,-3.98048, - -4.32566,-4.67087,-5.01613,-5.36143,-5.70677,-6.05217,-6.39762,-6.74313,-7.0887 ,-7.43432,-7.78001,-8.12577}); - NDArray dLdwExp('c', {}, std::vector{-91.52109}); - NDArray dLdlExp('c', {2,3,4}, {0.028, 0.014, -0., -0.014,-0.028, -0.042, -0.056, -0.07 ,-0.084, -0.098, -0.112, -0.126, - -0.14 , -0.154, -0.168, -0.182,-0.196, -0.21 , -0.224, -0.238,-0.252, -0.266, -0.28 , -0.294}); +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test6) { + NDArray labels('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {1, 3, 1}, sd::DataType::DOUBLE); - logits.linspace(-0.08, 0.04); - labels.linspace(1); - weights.assign(0.5); + NDArray dLdwExp('c', {1, 3, 1}, {-2.56, 0., 2.56}); - sd::ops::sigm_cross_entropy_loss_grad op; - auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {1}); + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::absolute_difference_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); - auto *dLdp = results.at(0); - auto *dLdw = results.at(1); - auto *dLdl = results.at(2); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); - ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); - ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); + auto dLdw = results.at(1); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); } /////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test4) { - - NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray logits('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights('c', {1,3,1}, sd::DataType::DOUBLE); - - NDArray dLdwExp('c', {1,3,1}, {-12.54779,-28.13393,-50.83936}); +TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test7) { + NDArray labels('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights(sd::DataType::DOUBLE); - logits.linspace(-0.08, 0.04); - labels.linspace(1); - weights.assign(0.5); + NDArray dLdwExp('c', {}, std::vector{0.}); - sd::ops::sigm_cross_entropy_loss_grad op; - auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {1}); + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::absolute_difference_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); - auto *dLdw = results.at(1); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + auto dLdw = results.at(1); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); } /////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test5) { +TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test8) { + NDArray labels('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {2, 3, 4}, sd::DataType::DOUBLE); + + NDArray dLdpExp('c', {2, 3, 4}, + {-0., -0., -0., -0., -0.05, -0.05, -0.05, -0.05, + -0.05, -0.05, -0.05, -0.05, -0.05, -0.05, -0.05, -0.05, + -0.05, -0.05, -0.05, -0.05, -0.05, -0.05, -0.05, -0.05}); + NDArray dLdwExp( + 'c', {2, 3, 4}, + {-1.296, -1.2, -1.104, -1.008, -0.912, -0.816, -0.72, -0.624, + -0.528, -0.432, -0.336, -0.24, -0.144, -0.048, 0.048, 0.144, + 0.24, 0.336, 0.432, 0.528, 0.624, 0.72, 0.816, 0.912}); + + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); + weights.p(0, 0.); + weights.p(1, 0.); + weights.p(2, 0.); + weights.p(3, 0.); + + sd::ops::absolute_difference_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdpExp.isSameShape(-dLdl)); + ASSERT_TRUE(dLdpExp.equalsTo(-dLdl)); +} - NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray logits('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights('c', {2,3,4}, sd::DataType::DOUBLE); +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test9) { + NDArray labels('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {2, 3, 4}, sd::DataType::DOUBLE); + + NDArray dLdpExp('c', {2, 3, 4}, + {-0.02083, -0.02083, -0.02083, -0.02083, -0.02083, -0.02083, + -0.02083, -0.02083, -0.02083, -0.02083, -0.02083, -0.02083, + -0.02083, -0.02083, -0.02083, -0.02083, -0.02083, -0.02083, + -0.02083, -0.02083, -0.02083, -0.02083, -0.02083, -0.02083}); + NDArray dLdwExp( + 'c', {2, 3, 4}, + {0.04, 0.08, 0.12, 0.16, 0.2, 0.24, 0.28, 0.32, 0.36, 0.4, 0.44, 0.48, + 0.52, 0.56, 0.6, 0.64, 0.68, 0.72, 0.76, 0.8, 0.84, 0.88, 0.92, 0.96}); + + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); + + sd::ops::absolute_difference_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdpExp.isSameShape(-dLdl)); + ASSERT_TRUE(dLdpExp.equalsTo(-dLdl)); +} - NDArray dLdpExp('c', {2,3,4}, {-0.01542,-0.04417,-0.07292,-0.10167,-0.13042,-0.15917,-0.18792,-0.21667,-0.24543,-0.27419,-0.30294,-0.33171, - -0.36047,-0.38924,-0.41801,-0.44679,-0.47556,-0.50435,-0.53314,-0.56193,-0.59072,-0.61953,-0.64833,-0.67715}); - NDArray dLdwExp('c', {2,3,4}, {0.37794, 0.37906, 0.37554, 0.36739, 0.35461, 0.33719, 0.31514, 0.28846, 0.25714, 0.22119, 0.18061, 0.13539, - 0.08553, 0.03104,-0.02808,-0.09184,-0.16023,-0.23326,-0.31093,-0.39323,-0.48017,-0.57175,-0.66796,-0.76881}); - NDArray dLdlExp('c', {2,3,4}, {0.00233, 0.00117,-0.,-0.00117,-0.00233,-0.0035 ,-0.00467,-0.00583,-0.007 ,-0.00817,-0.00933,-0.0105, - -0.01167,-0.01283,-0.014 ,-0.01517,-0.01633,-0.0175 ,-0.01867,-0.01983,-0.021 ,-0.02217,-0.02333,-0.0245}); +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test10) { + NDArray labels('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {1, 1}, sd::DataType::DOUBLE); - logits.linspace(-0.08, 0.04); - labels.linspace(1); - weights.assign(0.5); + NDArray dLdwExp('c', {1, 1}, std::vector{12.}); - sd::ops::sigm_cross_entropy_loss_grad op; - auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {2}); + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::absolute_difference_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); - auto *dLdp = results.at(0); - auto *dLdw = results.at(1); - auto *dLdl = results.at(2); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); - ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); - ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); + auto dLdw = results.at(1); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); } /////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test6) { +TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test11) { + NDArray labels('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {1, 3, 1}, sd::DataType::DOUBLE); - NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray logits('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights('c', {1,3,1}, sd::DataType::DOUBLE); + NDArray dLdwExp('c', {1, 3, 1}, {2.72, 4., 5.28}); - NDArray dLdwExp('c', {1,3,1}, {1.4966 , 0.19776,-1.69436}); + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); - logits.linspace(-0.08, 0.04); - labels.linspace(1); - weights.assign(0.5); + sd::ops::absolute_difference_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); - sd::ops::sigm_cross_entropy_loss_grad op; - auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {2}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto dLdw = results.at(1); - auto *dLdw = results.at(1); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); +} - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test12) { + NDArray labels('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {2, 3, 4}, sd::DataType::DOUBLE); + + NDArray dLdpExp( + 'c', {2, 3, 4}, + {0., 0., 0., 0., -0.025, -0.025, -0.025, -0.025, + -0.025, -0.025, -0.025, -0.025, -0.025, -0.025, -0.025, -0.025, + -0.025, -0.025, -0.025, -0.025, -0.025, -0.025, -0.025, -0.025}); + NDArray dLdwExp('c', {2, 3, 4}, + {0.048, 0.096, 0.144, 0.192, 0.24, 0.288, 0.336, 0.384, + 0.432, 0.48, 0.528, 0.576, 0.624, 0.672, 0.72, 0.768, + 0.816, 0.864, 0.912, 0.96, 1.008, 1.056, 1.104, 1.152}); + + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); + weights.r(0) = 0.; + weights.r(1) = 0.; + weights.r(2) = 0.; + weights.r(3) = 0.; + + sd::ops::absolute_difference_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdpExp.isSameShape(-dLdl)); + ASSERT_TRUE(dLdpExp.equalsTo(-dLdl)); +} +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test13) { + NDArray labels('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {2, 3, 1}, sd::DataType::DOUBLE); + + NDArray dLdpExp('c', {2, 3, 4}, + {0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., + -0.04167, -0.04167, -0.04167, -0.04167, -0.04167, -0.04167, + -0.04167, -0.04167, -0.04167, -0.04167, -0.04167, -0.04167}); + NDArray dLdwExp('c', {2, 3, 1}, {0.8, 2.08, 3.36, 4.64, 5.92, 7.2}); + + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); + weights.r(0) = 0.; + weights.r(1) = 0.; + weights.r(2) = 0.; + + sd::ops::absolute_difference_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdpExp.isSameShape(-dLdl)); + ASSERT_TRUE(dLdpExp.equalsTo(-dLdl)); } /////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test7) { +TEST_F(DeclarableOpsTests11, BFloat16_Test_1) { + NDArray x = NDArrayFactory::create('c', {2, 3, 4}); + NDArray y = NDArrayFactory::create( + 'c', {2, 3, 4}); //('c', {2,3,4}, sd::DataType::BFLOAT16); + NDArray exp = NDArrayFactory::create( + 'c', {2, 3, 4}); //('c', {2,3,4}, sd::DataType::BFLOAT16); + + x.linspace(1); + y.linspace(1); + exp.linspace(2, 2); + sd::ops::add op; + auto results = op.evaluate({&x, &y}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto res = results.at(0); + ASSERT_TRUE(res.equalsTo(exp)); +} - NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray logits('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights(sd::DataType::DOUBLE); +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, BFloat16_Test_2) { + NDArray x = NDArrayFactory::create('c', {2, 3, 4}); + NDArray y = NDArrayFactory::create( + 'c', {2, 3, 4}); //('c', {2,3,4}, sd::DataType::BFLOAT16); + NDArray exp = NDArrayFactory::create( + 'c', {2, 3, 4}); //('c', {2,3,4}, sd::DataType::BFLOAT16); + + x.linspace(1); + y.linspace(1); + exp.linspace(2, 2); + sd::ops::add op; + auto results = op.evaluate({&x, &y}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto res = results.at(0); + ASSERT_TRUE(res.equalsTo(exp)); +} - NDArray dLdwExp('c', {}, std::vector{0.}); +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, BFloat16_Test_3) { + NDArray x('c', {2, 3, 4}, sd::DataType::BFLOAT16); + NDArray y('c', {2, 3, 4}, sd::DataType::BFLOAT16); + NDArray exp('c', {2, 3, 4}, sd::DataType::BFLOAT16); - logits.linspace(-0.08, 0.04); - labels.linspace(1); - weights.assign(0.5); + x.linspace(1); + y.linspace(1); + exp.linspace(2, 2); + sd::ops::add op; + auto results = op.evaluate({&x, &y}, {}, {}); - sd::ops::sigm_cross_entropy_loss_grad op; - auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {2}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto res = results.at(0); + ASSERT_TRUE(res.equalsTo(exp)); +} - auto *dLdw = results.at(1); +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test1) { + NDArray labels('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray logits('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {2, 3, 4}, sd::DataType::DOUBLE); + + NDArray dLdpExp( + 'c', {2, 3, 4}, + {-0.25999, -0.755, -1.25, -1.745, -2.24001, -2.73502, + -3.23004, -3.72508, -4.22014, -4.71523, -5.21034, -5.70548, + -6.20066, -6.69587, -7.19113, -7.68643, -8.18177, -8.67717, + -9.17262, -9.66813, -10.1637, -10.65932, -11.15501, -11.65077}); + NDArray dLdwExp( + 'c', {2, 3, 4}, + {0.73395, 0.75335, 0.69315, 0.55335, 0.33395, 0.03495, + -0.34366, -0.80186, -1.33967, -1.95708, -2.65411, -3.43074, + -4.28698, -5.22285, -6.23833, -7.33343, -8.50815, -9.76251, + -11.0965, -12.51013, -14.00341, -15.57633, -17.2289, -18.96113}); + NDArray dLdlExp('c', {2, 3, 4}, + {0.04, 0.02, -0., -0.02, -0.04, -0.06, -0.08, -0.1, + -0.12, -0.14, -0.16, -0.18, -0.2, -0.22, -0.24, -0.26, + -0.28, -0.3, -0.32, -0.34, -0.36, -0.38, -0.4, -0.42}); + + logits.linspace(-0.08, 0.04); + labels.linspace(1); + weights.assign(0.5); + + sd::ops::sigm_cross_entropy_loss_grad op; + auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); + ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); +} - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test2) { + NDArray labels('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray logits('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {2, 1, 4}, sd::DataType::DOUBLE); + + NDArray dLdpExp('c', {2, 3, 4}, + {-0.18499, -0.53, -0.875, -1.22, -1.56501, -1.91002, + -2.25504, -2.60008, -2.94514, -3.29023, -3.63534, -3.98048, + -4.32566, -4.67087, -5.01613, -5.36143, -5.70677, -6.05217, + -6.39762, -6.74313, -7.0887, -7.43432, -7.78001, -8.12577}); + NDArray dLdwExp('c', {2, 1, 4}, + {0.43622, -0.19079, -0.98462, -1.94525, -18.09855, -20.72768, + -23.52373, -26.48669}); + NDArray dLdlExp( + 'c', {2, 3, 4}, + {0.028, 0.014, -0., -0.014, -0.028, -0.042, -0.056, -0.07, + -0.084, -0.098, -0.112, -0.126, -0.14, -0.154, -0.168, -0.182, + -0.196, -0.21, -0.224, -0.238, -0.252, -0.266, -0.28, -0.294}); + + logits.linspace(-0.08, 0.04); + labels.linspace(1); + weights.assign(0.5); + + sd::ops::sigm_cross_entropy_loss_grad op; + auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); + ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); +} +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test3) { + NDArray labels('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray logits('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights(sd::DataType::DOUBLE); + + NDArray dLdpExp('c', {2, 3, 4}, + {-0.18499, -0.53, -0.875, -1.22, -1.56501, -1.91002, + -2.25504, -2.60008, -2.94514, -3.29023, -3.63534, -3.98048, + -4.32566, -4.67087, -5.01613, -5.36143, -5.70677, -6.05217, + -6.39762, -6.74313, -7.0887, -7.43432, -7.78001, -8.12577}); + NDArray dLdwExp('c', {}, std::vector{-91.52109}); + NDArray dLdlExp( + 'c', {2, 3, 4}, + {0.028, 0.014, -0., -0.014, -0.028, -0.042, -0.056, -0.07, + -0.084, -0.098, -0.112, -0.126, -0.14, -0.154, -0.168, -0.182, + -0.196, -0.21, -0.224, -0.238, -0.252, -0.266, -0.28, -0.294}); + + logits.linspace(-0.08, 0.04); + labels.linspace(1); + weights.assign(0.5); + + sd::ops::sigm_cross_entropy_loss_grad op; + auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); + ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); } /////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test8) { +TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test4) { + NDArray labels('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray logits('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {1, 3, 1}, sd::DataType::DOUBLE); - NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray logits('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights('c', {2,3,4}, sd::DataType::DOUBLE); - - NDArray dLdpExp('c', {2,3,4}, { 0. , 0. , 0. , 0. ,-0.1565 ,-0.191 ,-0.2255 ,-0.26001,-0.29451,-0.32902,-0.36353,-0.39805, - -0.43257,-0.46709,-0.50161,-0.53614,-0.57068,-0.60522,-0.63976,-0.67431,-0.70887,-0.74343,-0.778 ,-0.81258}); - NDArray dLdwExp('c', {2,3,4}, {0.54353, 0.54487, 0.54065, 0.53087, 0.51553, 0.49463, 0.46817, 0.43615, 0.39857, 0.35543, 0.30672, 0.25246, - 0.19264, 0.12725, 0.0563 ,-0.02021,-0.10228,-0.18992,-0.28312,-0.38188,-0.48621,-0.5961 ,-0.71156,-0.83258}); - NDArray dLdlExp('c', {2,3,4}, {-0. ,-0. , 0. , 0. ,-0.0028,-0.0042,-0.0056,-0.007 ,-0.0084,-0.0098,-0.0112,-0.0126, - -0.014 ,-0.0154,-0.0168,-0.0182,-0.0196,-0.021 ,-0.0224,-0.0238,-0.0252,-0.0266,-0.028 ,-0.0294}); - logits.linspace(-0.08, 0.04); - labels.linspace(1); - weights.assign(0.5); - weights.p(0, 0.); - weights.p(1, 0.); - weights.p(2, 0.); - weights.p(3, 0.); - - sd::ops::sigm_cross_entropy_loss_grad op; - auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {2}); + NDArray dLdwExp('c', {1, 3, 1}, {-12.54779, -28.13393, -50.83936}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + logits.linspace(-0.08, 0.04); + labels.linspace(1); + weights.assign(0.5); - auto *dLdp = results.at(0); - auto *dLdw = results.at(1); - auto *dLdl = results.at(2); + sd::ops::sigm_cross_entropy_loss_grad op; + auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {1}); - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); - ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); - ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto dLdw = results.at(1); + + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); } /////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test9) { +TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test5) { + NDArray labels('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray logits('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {2, 3, 4}, sd::DataType::DOUBLE); + + NDArray dLdpExp('c', {2, 3, 4}, + {-0.01542, -0.04417, -0.07292, -0.10167, -0.13042, -0.15917, + -0.18792, -0.21667, -0.24543, -0.27419, -0.30294, -0.33171, + -0.36047, -0.38924, -0.41801, -0.44679, -0.47556, -0.50435, + -0.53314, -0.56193, -0.59072, -0.61953, -0.64833, -0.67715}); + NDArray dLdwExp('c', {2, 3, 4}, + {0.37794, 0.37906, 0.37554, 0.36739, 0.35461, 0.33719, + 0.31514, 0.28846, 0.25714, 0.22119, 0.18061, 0.13539, + 0.08553, 0.03104, -0.02808, -0.09184, -0.16023, -0.23326, + -0.31093, -0.39323, -0.48017, -0.57175, -0.66796, -0.76881}); + NDArray dLdlExp('c', {2, 3, 4}, + {0.00233, 0.00117, -0., -0.00117, -0.00233, -0.0035, + -0.00467, -0.00583, -0.007, -0.00817, -0.00933, -0.0105, + -0.01167, -0.01283, -0.014, -0.01517, -0.01633, -0.0175, + -0.01867, -0.01983, -0.021, -0.02217, -0.02333, -0.0245}); + + logits.linspace(-0.08, 0.04); + labels.linspace(1); + weights.assign(0.5); + + sd::ops::sigm_cross_entropy_loss_grad op; + auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {2}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); + ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); +} - NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray logits('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights('c', {2,3,4}, sd::DataType::DOUBLE); +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test6) { + NDArray labels('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray logits('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {1, 3, 1}, sd::DataType::DOUBLE); - NDArray dLdpExp('c', {2,3,4}, {-0.00771, -0.02208, -0.03646, -0.05083,-0.06521, -0.07958, -0.09396, -0.10834,-0.12271, -0.13709, -0.15147, -0.16585, - -0.18024, -0.19462, -0.20901, -0.22339,-0.23778, -0.25217, -0.26657, -0.28096,-0.29536, -0.30976, -0.32417, -0.33857}); - NDArray dLdwExp('c', {2,3,4}, {0.03008, 0.03064, 0.02888, 0.02481, 0.01841, 0.00971, -0.00132, -0.01466,-0.03032, -0.0483 , -0.06859, -0.0912 , - -0.11612, -0.14337, -0.17293, -0.20481,-0.23901, -0.27552, -0.31435, -0.35551,-0.39898, -0.44476, -0.49287, -0.5433 }); - NDArray dLdlExp('c', {2,3,4}, {0.00117, 0.00058, -0. , -0.00058,-0.00117, -0.00175, -0.00233, -0.00292,-0.0035 , -0.00408, -0.00467, -0.00525, - -0.00583, -0.00642, -0.007 , -0.00758,-0.00817, -0.00875, -0.00933, -0.00992,-0.0105 , -0.01108, -0.01167, -0.01225}); - logits.linspace(-0.08, 0.04); - labels.linspace(1); - weights.assign(0.5); + NDArray dLdwExp('c', {1, 3, 1}, {1.4966, 0.19776, -1.69436}); - sd::ops::sigm_cross_entropy_loss_grad op; - auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {3}); + logits.linspace(-0.08, 0.04); + labels.linspace(1); + weights.assign(0.5); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::sigm_cross_entropy_loss_grad op; + auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {2}); - auto *dLdp = results.at(0); - auto *dLdw = results.at(1); - auto *dLdl = results.at(2); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); - ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); - ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); + auto dLdw = results.at(1); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); } /////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test10) { - - NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray logits('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights('c', {1,1}, sd::DataType::DOUBLE); - - NDArray dLdwExp('c', {1,1}, std::vector{-3.81338}); +TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test7) { + NDArray labels('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray logits('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights(sd::DataType::DOUBLE); - logits.linspace(-0.08, 0.04); - labels.linspace(1); - weights.assign(0.5); + NDArray dLdwExp('c', {}, std::vector{0.}); - sd::ops::sigm_cross_entropy_loss_grad op; - auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {3}); + logits.linspace(-0.08, 0.04); + labels.linspace(1); + weights.assign(0.5); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::sigm_cross_entropy_loss_grad op; + auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {2}); - auto *dLdw = results.at(1); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + auto dLdw = results.at(1); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); } /////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test11) { +TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test8) { + NDArray labels('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray logits('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {2, 3, 4}, sd::DataType::DOUBLE); + + NDArray dLdpExp('c', {2, 3, 4}, + {0., 0., 0., 0., -0.1565, -0.191, + -0.2255, -0.26001, -0.29451, -0.32902, -0.36353, -0.39805, + -0.43257, -0.46709, -0.50161, -0.53614, -0.57068, -0.60522, + -0.63976, -0.67431, -0.70887, -0.74343, -0.778, -0.81258}); + NDArray dLdwExp('c', {2, 3, 4}, + {0.54353, 0.54487, 0.54065, 0.53087, 0.51553, 0.49463, + 0.46817, 0.43615, 0.39857, 0.35543, 0.30672, 0.25246, + 0.19264, 0.12725, 0.0563, -0.02021, -0.10228, -0.18992, + -0.28312, -0.38188, -0.48621, -0.5961, -0.71156, -0.83258}); + NDArray dLdlExp( + 'c', {2, 3, 4}, + {-0., -0., 0., 0., -0.0028, -0.0042, -0.0056, -0.007, + -0.0084, -0.0098, -0.0112, -0.0126, -0.014, -0.0154, -0.0168, -0.0182, + -0.0196, -0.021, -0.0224, -0.0238, -0.0252, -0.0266, -0.028, -0.0294}); + logits.linspace(-0.08, 0.04); + labels.linspace(1); + weights.assign(0.5); + weights.p(0, 0.); + weights.p(1, 0.); + weights.p(2, 0.); + weights.p(3, 0.); + + sd::ops::sigm_cross_entropy_loss_grad op; + auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {2}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); + ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); +} - NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray logits('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights('c', {1,3,1}, sd::DataType::DOUBLE); +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test9) { + NDArray labels('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray logits('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {2, 3, 4}, sd::DataType::DOUBLE); + + NDArray dLdpExp('c', {2, 3, 4}, + {-0.00771, -0.02208, -0.03646, -0.05083, -0.06521, -0.07958, + -0.09396, -0.10834, -0.12271, -0.13709, -0.15147, -0.16585, + -0.18024, -0.19462, -0.20901, -0.22339, -0.23778, -0.25217, + -0.26657, -0.28096, -0.29536, -0.30976, -0.32417, -0.33857}); + NDArray dLdwExp('c', {2, 3, 4}, + {0.03008, 0.03064, 0.02888, 0.02481, 0.01841, 0.00971, + -0.00132, -0.01466, -0.03032, -0.0483, -0.06859, -0.0912, + -0.11612, -0.14337, -0.17293, -0.20481, -0.23901, -0.27552, + -0.31435, -0.35551, -0.39898, -0.44476, -0.49287, -0.5433}); + NDArray dLdlExp('c', {2, 3, 4}, + {0.00117, 0.00058, -0., -0.00058, -0.00117, -0.00175, + -0.00233, -0.00292, -0.0035, -0.00408, -0.00467, -0.00525, + -0.00583, -0.00642, -0.007, -0.00758, -0.00817, -0.00875, + -0.00933, -0.00992, -0.0105, -0.01108, -0.01167, -0.01225}); + logits.linspace(-0.08, 0.04); + labels.linspace(1); + weights.assign(0.5); + + sd::ops::sigm_cross_entropy_loss_grad op; + auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {3}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); + ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); +} - NDArray dLdwExp('c', {1,3,1}, {-0.52282,-1.17225,-2.11831}); +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test10) { + NDArray labels('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray logits('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {1, 1}, sd::DataType::DOUBLE); - logits.linspace(-0.08, 0.04); - labels.linspace(1); - weights.assign(0.5); + NDArray dLdwExp('c', {1, 1}, std::vector{-3.81338}); - sd::ops::sigm_cross_entropy_loss_grad op; - auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {3}); + logits.linspace(-0.08, 0.04); + labels.linspace(1); + weights.assign(0.5); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::sigm_cross_entropy_loss_grad op; + auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {3}); - auto *dLdw = results.at(1); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + auto dLdw = results.at(1); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); } /////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test12) { - - NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray logits('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights('c', {2,3,4}, sd::DataType::DOUBLE); +TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test11) { + NDArray labels('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray logits('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {1, 3, 1}, sd::DataType::DOUBLE); - NDArray dLdpExp('c', {2,3,4}, {0. , 0. , 0. , 0. ,-0.07825, -0.0955 , -0.11275, -0.13 ,-0.14726, -0.16451, -0.18177, -0.19902, - -0.21628, -0.23354, -0.25081, -0.26807,-0.28534, -0.30261, -0.31988, -0.33716,-0.35443, -0.37172, -0.389 , -0.40629}); - NDArray dLdwExp('c', {2,3,4}, {0.0361 , 0.03677, 0.03466, 0.02977, 0.0221 , 0.01165, -0.00158, -0.01759,-0.03638, -0.05795, -0.08231, -0.10944, - -0.13935, -0.17204, -0.20752, -0.24577,-0.28681, -0.33063, -0.37723, -0.42661,-0.47877, -0.53372, -0.59144, -0.65196}); - NDArray dLdlExp('c', {2,3,4}, {-0. , -0. , 0. , 0. ,-0.0014, -0.0021, -0.0028, -0.0035,-0.0042, -0.0049, -0.0056, -0.0063, - -0.007 , -0.0077, -0.0084, -0.0091,-0.0098, -0.0105, -0.0112, -0.0119,-0.0126, -0.0133, -0.014 , -0.0147}); - logits.linspace(-0.08, 0.04); - labels.linspace(1); - weights.assign(0.5); - weights.r(0) = 0.; - weights.r(1) = 0.; - weights.r(2) = 0.; - weights.r(3) = 0.; + NDArray dLdwExp('c', {1, 3, 1}, {-0.52282, -1.17225, -2.11831}); + logits.linspace(-0.08, 0.04); + labels.linspace(1); + weights.assign(0.5); - sd::ops::sigm_cross_entropy_loss_grad op; - auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {3}); + sd::ops::sigm_cross_entropy_loss_grad op; + auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {3}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *dLdp = results.at(0); - auto *dLdw = results.at(1); - auto *dLdl = results.at(2); + auto dLdw = results.at(1); - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); - ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); - ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); +} +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test12) { + NDArray labels('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray logits('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {2, 3, 4}, sd::DataType::DOUBLE); + + NDArray dLdpExp('c', {2, 3, 4}, + {0., 0., 0., 0., -0.07825, -0.0955, + -0.11275, -0.13, -0.14726, -0.16451, -0.18177, -0.19902, + -0.21628, -0.23354, -0.25081, -0.26807, -0.28534, -0.30261, + -0.31988, -0.33716, -0.35443, -0.37172, -0.389, -0.40629}); + NDArray dLdwExp('c', {2, 3, 4}, + {0.0361, 0.03677, 0.03466, 0.02977, 0.0221, 0.01165, + -0.00158, -0.01759, -0.03638, -0.05795, -0.08231, -0.10944, + -0.13935, -0.17204, -0.20752, -0.24577, -0.28681, -0.33063, + -0.37723, -0.42661, -0.47877, -0.53372, -0.59144, -0.65196}); + NDArray dLdlExp( + 'c', {2, 3, 4}, + {-0., -0., 0., 0., -0.0014, -0.0021, -0.0028, -0.0035, + -0.0042, -0.0049, -0.0056, -0.0063, -0.007, -0.0077, -0.0084, -0.0091, + -0.0098, -0.0105, -0.0112, -0.0119, -0.0126, -0.0133, -0.014, -0.0147}); + logits.linspace(-0.08, 0.04); + labels.linspace(1); + weights.assign(0.5); + weights.r(0) = 0.; + weights.r(1) = 0.; + weights.r(2) = 0.; + weights.r(3) = 0.; + + sd::ops::sigm_cross_entropy_loss_grad op; + auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {3}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); + ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test13) { - - NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray logits('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights('c', {2,3,1}, sd::DataType::DOUBLE); - - NDArray dLdpExp('c', {2,3,4}, {0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , - -0.36047, -0.38924, -0.41801, -0.44679,-0.47556, -0.50435, -0.53314, -0.56193,-0.59072, -0.61953, -0.64833, -0.67715}); - NDArray dLdwExp('c', {2,3,1}, {0.22882, 0.02428,-0.4768 ,-1.27447,-2.36878,-3.75981,}); - NDArray dLdlExp('c', {2,3,4}, {-0. , -0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0., - -0.01167, -0.01283, -0.014 , -0.01517,-0.01633, -0.0175 , -0.01867, -0.01983,-0.021 , -0.02217, -0.02333, -0.0245}); - logits.linspace(-0.08, 0.04); - labels.linspace(1); - weights.assign(0.5); - weights.r(0) = 0.; - weights.r(1) = 0.; - weights.r(2) = 0.; - - sd::ops::sigm_cross_entropy_loss_grad op; - auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {3}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto *dLdp = results.at(0); - auto *dLdw = results.at(1); - auto *dLdl = results.at(2); - - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); - ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); - ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); - + NDArray labels('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray logits('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {2, 3, 1}, sd::DataType::DOUBLE); + + NDArray dLdpExp('c', {2, 3, 4}, + {0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., + -0.36047, -0.38924, -0.41801, -0.44679, -0.47556, -0.50435, + -0.53314, -0.56193, -0.59072, -0.61953, -0.64833, -0.67715}); + NDArray dLdwExp('c', {2, 3, 1}, + { + 0.22882, + 0.02428, + -0.4768, + -1.27447, + -2.36878, + -3.75981, + }); + NDArray dLdlExp('c', {2, 3, 4}, + {-0., -0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., + -0.01167, -0.01283, -0.014, -0.01517, -0.01633, -0.0175, + -0.01867, -0.01983, -0.021, -0.02217, -0.02333, -0.0245}); + logits.linspace(-0.08, 0.04); + labels.linspace(1); + weights.assign(0.5); + weights.r(0) = 0.; + weights.r(1) = 0.; + weights.r(2) = 0.; + + sd::ops::sigm_cross_entropy_loss_grad op; + auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {3}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); + ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, BFloat16_Test_4) { - - NDArray x = NDArrayFactory::create('c', {2,3,4}); - NDArray y = NDArrayFactory::create('c', {2,3,4});//('c', {2,3,4}, sd::DataType::BFLOAT16); - NDArray exp = NDArrayFactory::create('c', {2,3,4});//('c', {2,3,4}, sd::DataType::BFLOAT16); - - x.linspace(1); - y.linspace(1); - exp.linspace(2,2); - sd::ops::add op; - auto results = op.evaluate({&x, &y}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto res = results.at(0); - ASSERT_TRUE(res->equalsTo(exp)); + NDArray x = NDArrayFactory::create('c', {2, 3, 4}); + NDArray y = NDArrayFactory::create( + 'c', {2, 3, 4}); //('c', {2,3,4}, sd::DataType::BFLOAT16); + NDArray exp = NDArrayFactory::create( + 'c', {2, 3, 4}); //('c', {2,3,4}, sd::DataType::BFLOAT16); + + x.linspace(1); + y.linspace(1); + exp.linspace(2, 2); + sd::ops::add op; + auto results = op.evaluate({&x, &y}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto res = results.at(0); + ASSERT_TRUE(res.equalsTo(exp)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, BFloat16_Test_5) { - - NDArray x = NDArrayFactory::create('c', {2,3,4}); - NDArray y = NDArrayFactory::create('c', {2,3,4});//('c', {2,3,4}, sd::DataType::BFLOAT16); - NDArray exp = NDArrayFactory::create('c', {2,3,4});//('c', {2,3,4}, sd::DataType::BFLOAT16); - - x.linspace(2, 2); - y.linspace(1); - exp.linspace(1); - sd::ops::subtract op; - auto results = op.evaluate({&x, &y}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto res = results.at(0); - ASSERT_TRUE(res->equalsTo(exp)); - + NDArray x = NDArrayFactory::create('c', {2, 3, 4}); + NDArray y = NDArrayFactory::create( + 'c', {2, 3, 4}); //('c', {2,3,4}, sd::DataType::BFLOAT16); + NDArray exp = NDArrayFactory::create( + 'c', {2, 3, 4}); //('c', {2,3,4}, sd::DataType::BFLOAT16); + + x.linspace(2, 2); + y.linspace(1); + exp.linspace(1); + sd::ops::subtract op; + auto results = op.evaluate({&x, &y}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto res = results.at(0); + ASSERT_TRUE(res.equalsTo(exp)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, BFloat16_Test_6) { - - NDArray x = NDArrayFactory::create('c', {2,3,4}); - NDArray y = NDArrayFactory::create('c', {2,3,4});//('c', {2,3,4}, sd::DataType::BFLOAT16); - NDArray exp = NDArrayFactory::create('c', {2,3,4});//('c', {2,3,4}, sd::DataType::BFLOAT16); - - x.linspace(2, 2); - y.linspace(1); - exp.linspace(1); - sd::ops::subtract op; - auto results = op.evaluate({&x, &y}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto res = results.at(0); - ASSERT_TRUE(res->equalsTo(exp)); + NDArray x = NDArrayFactory::create('c', {2, 3, 4}); + NDArray y = NDArrayFactory::create( + 'c', {2, 3, 4}); //('c', {2,3,4}, sd::DataType::BFLOAT16); + NDArray exp = NDArrayFactory::create( + 'c', {2, 3, 4}); //('c', {2,3,4}, sd::DataType::BFLOAT16); + + x.linspace(2, 2); + y.linspace(1); + exp.linspace(1); + sd::ops::subtract op; + auto results = op.evaluate({&x, &y}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto res = results.at(0); + ASSERT_TRUE(res.equalsTo(exp)); } ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test1) { + NDArray labels('c', {2, 4}, {0, 0, 1, 0, 0, 1, 0, 0}, sd::DataType::INT32); + NDArray logits('c', {2, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {2}, sd::DataType::DOUBLE); - NDArray labels('c', {2,4}, {0,0,1,0, 0,1,0,0}, sd::DataType::INT32); - NDArray logits('c', {2,4}, sd::DataType::DOUBLE); - NDArray weights('c', {2}, sd::DataType::DOUBLE); + NDArray dLdpExp( + 'c', {2, 4}, + {0.1176, 0.1224, -0.3726, 0.1326, 0.1176, -0.3776, 0.1274, 0.1326}); + NDArray dLdwExp('c', {2}, {1.36729, 1.40729}); - NDArray dLdpExp('c', {2,4}, {0.1176, 0.1224, -0.3726, 0.1326, 0.1176, -0.3776, 0.1274, 0.1326}); - NDArray dLdwExp('c', {2}, {1.36729, 1.40729}); + logits.linspace(-0.08, 0.04); + weights.assign(0.5); - logits.linspace(-0.08, 0.04); - weights.assign(0.5); + sd::ops::softmax_cross_entropy_loss_grad op; - sd::ops::softmax_cross_entropy_loss_grad op; + auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {0}); - auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {0}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *dLdp = results.at(0); - auto *dLdw = results.at(1); - auto *dLdl = results.at(2); + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); } ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test2) { + NDArray labels('c', {4}, {0, 0, 1, 0}, sd::DataType::INT32); + NDArray logits('c', {4}, sd::DataType::DOUBLE); + NDArray weights('c', {1}, sd::DataType::DOUBLE); - NDArray labels('c', {4}, {0,0,1,0}, sd::DataType::INT32); - NDArray logits('c', {4}, sd::DataType::DOUBLE); - NDArray weights('c', {1}, sd::DataType::DOUBLE); + NDArray dLdpExp('c', {4}, {0.125, 0.125, -0.375, 0.125}); + NDArray dLdwExp('c', {1}, std::vector{1.38629}); - NDArray dLdpExp('c', {4}, {0.125, 0.125, -0.375, 0.125}); - NDArray dLdwExp('c', {1}, std::vector{1.38629}); + logits = 2.; + weights.assign(0.5); - logits = 2.; - weights.assign(0.5); + sd::ops::softmax_cross_entropy_loss_grad op; - sd::ops::softmax_cross_entropy_loss_grad op; + auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {1}); - auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {1}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto *dLdp = results.at(0); - auto *dLdw = results.at(1); - auto *dLdl = results.at(2); + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); } ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test3) { + NDArray labels('c', {4}, {0, 0, 1, 0}, sd::DataType::INT32); + NDArray logits('c', {4}, sd::DataType::DOUBLE); + NDArray weights('c', {}, std::vector{0}, sd::DataType::DOUBLE); - NDArray labels('c', {4}, {0,0,1,0}, sd::DataType::INT32); - NDArray logits('c', {4}, sd::DataType::DOUBLE); - NDArray weights('c', {}, std::vector{0}, sd::DataType::DOUBLE); + NDArray dLdpExp('c', {4}, {0.125, 0.125, -0.375, 0.125}); + NDArray dLdwExp('c', {}, std::vector{1.38629}); - NDArray dLdpExp('c', {4}, {0.125, 0.125, -0.375, 0.125}); - NDArray dLdwExp('c', {}, std::vector{1.38629}); + logits = 2.; + weights.assign(0.5); - logits = 2.; - weights.assign(0.5); + sd::ops::softmax_cross_entropy_loss_grad op; - sd::ops::softmax_cross_entropy_loss_grad op; + auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {1}); - auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {1}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *dLdp = results.at(0); - auto *dLdw = results.at(1); - auto *dLdl = results.at(2); + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); } ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test4) { + NDArray labels('c', {4}, {0, 0, 1, 0}, sd::DataType::INT32); + NDArray logits('c', {4}, sd::DataType::DOUBLE); + NDArray weights('c', {}, std::vector{0}, sd::DataType::DOUBLE); - NDArray labels('c', {4}, {0,0,1,0}, sd::DataType::INT32); - NDArray logits('c', {4}, sd::DataType::DOUBLE); - NDArray weights('c', {}, std::vector{0}, sd::DataType::DOUBLE); + NDArray dLdpExp('c', {4}, {0.23521, 0.2448, -0.7452, 0.26519}); + NDArray dLdwExp('c', {}, std::vector{0.}); - NDArray dLdpExp('c', {4}, {0.23521, 0.2448 , -0.7452 , 0.26519}); - NDArray dLdwExp('c', {}, std::vector{0.}); + logits.linspace(-0.08, 0.04); + weights = 0.5; - logits.linspace(-0.08, 0.04); - weights = 0.5; + sd::ops::softmax_cross_entropy_loss_grad op; - sd::ops::softmax_cross_entropy_loss_grad op; + auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {2}); - auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {2}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto *dLdp = results.at(0); - auto *dLdw = results.at(1); - auto *dLdl = results.at(2); - - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); } ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test5) { + NDArray labels('c', {4}, {0, 0, 1, 0}, sd::DataType::INT32); + NDArray logits('c', {4}, sd::DataType::DOUBLE); + NDArray weights('c', {1}, sd::DataType::DOUBLE); - NDArray labels('c', {4}, {0,0,1,0}, sd::DataType::INT32); - NDArray logits('c', {4}, sd::DataType::DOUBLE); - NDArray weights('c', {1}, sd::DataType::DOUBLE); - - NDArray dLdpExp('c', {4}, {0.1176, 0.1224, -0.3726, 0.1326}); - NDArray dLdwExp('c', {1}, std::vector{1.36729}); + NDArray dLdpExp('c', {4}, {0.1176, 0.1224, -0.3726, 0.1326}); + NDArray dLdwExp('c', {1}, std::vector{1.36729}); - logits.linspace(-0.08, 0.04); - weights = 0.5; + logits.linspace(-0.08, 0.04); + weights = 0.5; - sd::ops::softmax_cross_entropy_loss_grad op; + sd::ops::softmax_cross_entropy_loss_grad op; - auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {3}); + auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {3}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto *dLdp = results.at(0); - auto *dLdw = results.at(1); - auto *dLdl = results.at(2); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); } ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test6) { + NDArray labels('c', {2, 4}, {0, 0, 1, 0, 0, 1, 0, 0}, sd::DataType::INT32); + NDArray logits('c', {2, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {2}, sd::DataType::DOUBLE); - NDArray labels('c', {2,4}, {0,0,1,0, 0,1,0,0}, sd::DataType::INT32); - NDArray logits('c', {2,4}, sd::DataType::DOUBLE); - NDArray weights('c', {2}, sd::DataType::DOUBLE); - - NDArray dLdpExp('c', {2,4}, {0.0801, 0.0849, -0.2601, 0.0951, 0.0801, -0.2651, 0.0899, 0.0951}); - NDArray dLdwExp('c', {2}, {-0.014000, 0.014000}); - - logits.linspace(-0.08, 0.04); - weights.assign(0.5); + NDArray dLdpExp( + 'c', {2, 4}, + {0.0801, 0.0849, -0.2601, 0.0951, 0.0801, -0.2651, 0.0899, 0.0951}); + NDArray dLdwExp('c', {2}, {-0.014000, 0.014000}); - sd::ops::softmax_cross_entropy_loss_grad op; + logits.linspace(-0.08, 0.04); + weights.assign(0.5); - auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {2}); + sd::ops::softmax_cross_entropy_loss_grad op; - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {2}); - auto *dLdp = results.at(0); - auto *dLdw = results.at(1); - auto *dLdl = results.at(2); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); } ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test7) { + NDArray labels('c', {2, 3, 4}, {1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, + 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0}, + sd::DataType::INT32); + NDArray logits('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {1, 3}, {0.5, 0., 1.5}); - NDArray labels('c', {2,3,4}, {1,0,0,0, 0,1,0,0, 0,0,1,0, 0,0,0,1, 1,0,0,0, 0,1,0,0}, sd::DataType::INT32); - NDArray logits('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights('c', {1,3}, {0.5, 0., 1.5}); - - NDArray dLdpExp('c', {2,3,4}, {-0.0956 , 0.0306 , 0.03185, 0.03315, 0.,-0., 0., 0., 0.0882 , 0.0918 ,-0.27945, 0.09945, - 0.0294 , 0.0306 , 0.03185,-0.09185,-0., 0., 0., 0., 0.0882 ,-0.2832 , 0.09555, 0.09945}); - NDArray dLdwExp('c', {1,3}, {0.69365, 0.71365, 0.69365}); + NDArray dLdpExp( + 'c', {2, 3, 4}, + {-0.0956, 0.0306, 0.03185, 0.03315, 0., -0., 0., 0., + 0.0882, 0.0918, -0.27945, 0.09945, 0.0294, 0.0306, 0.03185, -0.09185, + -0., 0., 0., 0., 0.0882, -0.2832, 0.09555, 0.09945}); + NDArray dLdwExp('c', {1, 3}, {0.69365, 0.71365, 0.69365}); - logits.linspace(-0.08, 0.04); + logits.linspace(-0.08, 0.04); - sd::ops::softmax_cross_entropy_loss_grad op; + sd::ops::softmax_cross_entropy_loss_grad op; - auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {3}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {3}); - auto *dLdp = results.at(0); - auto *dLdw = results.at(1); - auto *dLdl = results.at(2); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); } ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test8) { - - NDArray labels('c', {2,3,4,5}, {1,0,0,0,0,0,1,0,0,0,0,0,1,0,0,0,0,0,1,0,0,0,0,0,1,1,0,0,0,0,0,1,0,0,0,0,0,1,0,0,0, - 0,0,1,0,0,0,0,0,1,1,0,0,0,0,0,1,0,0,0,0,0,1,0,0,0,0,0,1,0,0,0,0,0,1,1,0,0,0,0,0,1, - 0,0,0,0,0,1,0,0,0,0,0,1,0,0,0,0,0,1,1,0,0,0,0,0,1,0,0,0,0,0,1,0,0,0,0,0,1,0}, sd::DataType::INT32); - - NDArray logits('c', {2,3,4,5}, sd::DataType::DOUBLE); - NDArray weights('c', {1,1,4}, sd::DataType::DOUBLE); - - NDArray dLdpExp('c', {2,3,4,5}, {-0.03399, 0.00799, 0.00832, 0.00866, 0.00901, 0.00768,-0.03367, 0.00832, 0.00866, 0.00901, 0.00768, 0.00799,-0.03335, - 0.00866, 0.00901, 0.00768, 0.00799, 0.00832,-0.03301, 0.00901, 0.00768, 0.00799, 0.00832, 0.00866,-0.03265,-0.03399, - 0.00799, 0.00832, 0.00866, 0.00901, 0.00768,-0.03367, 0.00832, 0.00866, 0.00901, 0.00768, 0.00799,-0.03335, 0.00866, - 0.00901, 0.00768, 0.00799, 0.00832,-0.03301, 0.00901, 0.00768, 0.00799, 0.00832, 0.00866,-0.03265,-0.03399, 0.00799, - 0.00832, 0.00866, 0.00901, 0.00768,-0.03367, 0.00832, 0.00866, 0.00901, 0.00768, 0.00799,-0.03335, 0.00866, 0.00901, - 0.00768, 0.00799, 0.00832,-0.03301, 0.00901, 0.00768, 0.00799, 0.00832, 0.00866,-0.03265,-0.03399, 0.00799, 0.00832, - 0.00866, 0.00901, 0.00768,-0.03367, 0.00832, 0.00866, 0.00901, 0.00768, 0.00799,-0.03335, 0.00866, 0.00901, 0.00768, - 0.00799, 0.00832,-0.03301, 0.00901, 0.00768, 0.00799, 0.00832, 0.00866,-0.03265,-0.03399, 0.00799, 0.00832, 0.00866, - 0.00901, 0.00768,-0.03367, 0.00832, 0.00866, 0.00901, 0.00768, 0.00799,-0.03335, 0.00866, 0.00901, 0.00768, 0.00799, 0.00832,-0.03301, 0.00901}); - - NDArray dLdwExp('c', {1,1,4}, {0.005, 0.00167, -0.00167, -0.005}); - logits.linspace(-0.08, 0.04); - weights.assign(0.5); - - sd::ops::softmax_cross_entropy_loss_grad op; - - auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {2}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto *dLdp = results.at(0); - auto *dLdw = results.at(1); - auto *dLdl = results.at(2); - - // dLdp->printIndexedBuffer(); - - // ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - // ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); - + NDArray labels( + 'c', {2, 3, 4, 5}, + {1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, + 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, + 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, + 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, + 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0}, + sd::DataType::INT32); + + NDArray logits('c', {2, 3, 4, 5}, sd::DataType::DOUBLE); + NDArray weights('c', {1, 1, 4}, sd::DataType::DOUBLE); + + NDArray dLdpExp( + 'c', {2, 3, 4, 5}, + {-0.03399, 0.00799, 0.00832, 0.00866, 0.00901, 0.00768, -0.03367, + 0.00832, 0.00866, 0.00901, 0.00768, 0.00799, -0.03335, 0.00866, + 0.00901, 0.00768, 0.00799, 0.00832, -0.03301, 0.00901, 0.00768, + 0.00799, 0.00832, 0.00866, -0.03265, -0.03399, 0.00799, 0.00832, + 0.00866, 0.00901, 0.00768, -0.03367, 0.00832, 0.00866, 0.00901, + 0.00768, 0.00799, -0.03335, 0.00866, 0.00901, 0.00768, 0.00799, + 0.00832, -0.03301, 0.00901, 0.00768, 0.00799, 0.00832, 0.00866, + -0.03265, -0.03399, 0.00799, 0.00832, 0.00866, 0.00901, 0.00768, + -0.03367, 0.00832, 0.00866, 0.00901, 0.00768, 0.00799, -0.03335, + 0.00866, 0.00901, 0.00768, 0.00799, 0.00832, -0.03301, 0.00901, + 0.00768, 0.00799, 0.00832, 0.00866, -0.03265, -0.03399, 0.00799, + 0.00832, 0.00866, 0.00901, 0.00768, -0.03367, 0.00832, 0.00866, + 0.00901, 0.00768, 0.00799, -0.03335, 0.00866, 0.00901, 0.00768, + 0.00799, 0.00832, -0.03301, 0.00901, 0.00768, 0.00799, 0.00832, + 0.00866, -0.03265, -0.03399, 0.00799, 0.00832, 0.00866, 0.00901, + 0.00768, -0.03367, 0.00832, 0.00866, 0.00901, 0.00768, 0.00799, + -0.03335, 0.00866, 0.00901, 0.00768, 0.00799, 0.00832, -0.03301, + 0.00901}); + + NDArray dLdwExp('c', {1, 1, 4}, {0.005, 0.00167, -0.00167, -0.005}); + logits.linspace(-0.08, 0.04); + weights.assign(0.5); + + sd::ops::softmax_cross_entropy_loss_grad op; + + auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {2}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); + + // dLdp->printIndexedBuffer(); + + // ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + // ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, SafeDivideMixed_Test1) { + NDArray labels('c', {2, 3}, {1.0, 2.0, 3.0, -1.0, 2.0, 1.0}); + auto sumDiff = labels.reduceAlongDimension(reduce::Sum, {1}, true); - NDArray labels('c', {2, 3}, {1.0, 2.0, 3.0, -1.0, 2.0, 1.0}); - auto sumDiff = labels.reduceAlongDimension(reduce::Sum, {1}, true); - - NDArray numOfNonZero(sumDiff.shapeInfo(), sd::DataType::INT64, false); - numOfNonZero.assign(1); - sumDiff.applyPairwiseTransform(pairwise::SafeDivide, numOfNonZero, sumDiff); + NDArray numOfNonZero(sumDiff.shapeInfo(), sd::DataType::INT64, false); + numOfNonZero.assign(1); + sumDiff.applyPairwiseTransform(pairwise::SafeDivide, numOfNonZero, sumDiff); } ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test1) { + NDArray labels('c', {2, 3, 4}, {1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, + 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0}); + NDArray logits('c', {2, 3, 4}, sd::DataType::DOUBLE); - NDArray labels('c', {2,3,4}, {1,0,0,0, 0,1,0,0, 0,0,1,0, 0,0,0,1, 1,0,0,0, 0,1,0,0}); - NDArray logits('c', {2,3,4}, sd::DataType::DOUBLE); - - NDArray dLdpExp('c', {2,3,4}, {-0.76479, 0.2448, 0.2548, 0.26519, 0.23521,-0.7552, 0.2548, 0.26519, 0.23521, 0.2448,-0.7452, 0.26519, - 0.23521, 0.2448, 0.2548,-0.73481,-0.76479, 0.2448, 0.2548, 0.26519, 0.23521,-0.7552, 0.2548, 0.26519}); - logits.linspace(-0.08, 0.04); + NDArray dLdpExp( + 'c', {2, 3, 4}, + {-0.76479, 0.2448, 0.2548, 0.26519, 0.23521, -0.7552, 0.2548, 0.26519, + 0.23521, 0.2448, -0.7452, 0.26519, 0.23521, 0.2448, 0.2548, -0.73481, + -0.76479, 0.2448, 0.2548, 0.26519, 0.23521, -0.7552, 0.2548, 0.26519}); + logits.linspace(-0.08, 0.04); - sd::ops::softmax_cross_entropy_loss_with_logits_grad op; + sd::ops::softmax_cross_entropy_loss_with_logits_grad op; - auto results = op.evaluate({&logits, &labels}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto results = op.evaluate({&logits, &labels}, {}, {}); - auto *dLdp = results.at(0); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + auto dLdp = results.at(0); + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); } ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test2) { + NDArray labels('c', {2, 3, 4}, {1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, + 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0}); + NDArray logits('c', {2, 3, 4}, sd::DataType::DOUBLE); - NDArray labels('c', {2,3,4}, {1,0,0,0, 0,1,0,1, 0,0,1,0, 0,0,0,1, 1,0,1,0, 0,1,0,0}); - NDArray logits('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray dLdpExp('c', {2, 3, 4}, + {-0.71836, 0.28164, 0.28164, 0.28164, 0.33051, -0.66949, + 0.33051, -0.66949, 0.38785, 0.38785, -0.61215, 0.38785, + 0.28164, 0.28164, 0.28164, -0.71836, -0.66949, 0.33051, + -0.66949, 0.33051, 0.38785, -0.61215, 0.38785, 0.38785}); + logits.linspace(-0.08, 0.04); - NDArray dLdpExp('c', {2,3,4}, {-0.71836, 0.28164, 0.28164, 0.28164, 0.33051, -0.66949, 0.33051, -0.66949, 0.38785, 0.38785, -0.61215, 0.38785, - 0.28164, 0.28164, 0.28164, -0.71836,-0.66949, 0.33051, -0.66949, 0.33051, 0.38785, -0.61215, 0.38785, 0.38785}); - logits.linspace(-0.08, 0.04); + sd::ops::softmax_cross_entropy_loss_with_logits_grad op; - sd::ops::softmax_cross_entropy_loss_with_logits_grad op; + auto results = op.evaluate({&logits, &labels}, {}, {1}); - auto results = op.evaluate({&logits, &labels}, {}, {1}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto *dLdp = results.at(0); - - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + auto dLdp = results.at(0); + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); } ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test3) { + NDArray labels('c', {2, 3}, {1, 0, 0, 0, 1, 1}); + NDArray logits('c', {2, 3}, sd::DataType::DOUBLE); - NDArray labels('c', {2,3}, {1,0,0, 0,1,1}); - NDArray logits('c', {2,3}, sd::DataType::DOUBLE); - - NDArray dLdpExp('c', {2,3}, {-0.52996, 0.47004, 0.47004, 0.52996, -0.47004, -0.47004}); - logits.linspace(-0.08, 0.04); + NDArray dLdpExp('c', {2, 3}, + {-0.52996, 0.47004, 0.47004, 0.52996, -0.47004, -0.47004}); + logits.linspace(-0.08, 0.04); - sd::ops::softmax_cross_entropy_loss_with_logits_grad op; + sd::ops::softmax_cross_entropy_loss_with_logits_grad op; - auto results = op.evaluate({&logits, &labels}, {}, {0}); + auto results = op.evaluate({&logits, &labels}, {}, {0}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto *dLdp = results.at(0); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + auto dLdp = results.at(0); + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); } ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test4) { + NDArray labels('c', {2, 1}, {1, 1}); + NDArray logits('c', {2, 1}, {-0.04, 0.04}); - NDArray labels('c', {2,1}, {1,1}); - NDArray logits('c', {2,1}, {-0.04, 0.04}); - - NDArray dLdpExp('c', {2,1}, {0., 0.}); - - sd::ops::softmax_cross_entropy_loss_with_logits_grad op; + NDArray dLdpExp('c', {2, 1}, {0., 0.}); - auto results = op.evaluate({&logits, &labels}, {}, {1}); + sd::ops::softmax_cross_entropy_loss_with_logits_grad op; - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto results = op.evaluate({&logits, &labels}, {}, {1}); - auto *dLdp = results.at(0); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + auto dLdp = results.at(0); + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); } ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test5) { + NDArray labels('c', {2, 1}, std::vector{1, 0}); + NDArray logits('c', {2, 1}, {-0.04, 0.04}); - NDArray labels('c', {2,1}, std::vector{1,0}); - NDArray logits('c', {2,1}, {-0.04, 0.04}); - - NDArray dLdpExp('c', {2,1}, {-0.51999, 0.51999}); + NDArray dLdpExp('c', {2, 1}, {-0.51999, 0.51999}); - sd::ops::softmax_cross_entropy_loss_with_logits_grad op; + sd::ops::softmax_cross_entropy_loss_with_logits_grad op; - auto results = op.evaluate({&logits, &labels}, {}, {0}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto results = op.evaluate({&logits, &labels}, {}, {0}); - auto *dLdp = results.at(0); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + auto dLdp = results.at(0); + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); } ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test6) { + NDArray labels('c', {1, 2}, {1, 1.}); + NDArray logits('c', {1, 2}, {-0.04, 0.04}); - NDArray labels('c', {1,2}, {1,1.}); - NDArray logits('c', {1,2}, {-0.04, 0.04}); + NDArray dLdpExp('c', {1, 2}, {0, 0.}); - NDArray dLdpExp('c', {1,2}, {0, 0.}); + sd::ops::softmax_cross_entropy_loss_with_logits_grad op; - sd::ops::softmax_cross_entropy_loss_with_logits_grad op; + auto results = op.evaluate({&logits, &labels}, {}, {0}); - auto results = op.evaluate({&logits, &labels}, {}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto *dLdp = results.at(0); - - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + auto dLdp = results.at(0); + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); } ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test7) { + NDArray labels('c', {2}, {0, 1}); + NDArray logits('c', {2}, {-0.04, 0.04}); - NDArray labels('c', {2}, {0,1}); - NDArray logits('c', {2}, {-0.04, 0.04}); - - NDArray dLdpExp('c', {2}, {0.48001, -0.48001}); + NDArray dLdpExp('c', {2}, {0.48001, -0.48001}); - sd::ops::softmax_cross_entropy_loss_with_logits_grad op; + sd::ops::softmax_cross_entropy_loss_with_logits_grad op; - auto results = op.evaluate({&logits, &labels}, {}, {0}); + auto results = op.evaluate({&logits, &labels}, {}, {0}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto *dLdp = results.at(0); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + auto dLdp = results.at(0); + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); } ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test8) { + NDArray labels('c', {1}, std::vector{1}); + NDArray logits('c', {1}, std::vector{0.04}); - NDArray labels('c', {1}, std::vector{1}); - NDArray logits('c', {1}, std::vector{0.04}); - - NDArray dLdpExp('c', {1}, std::vector{0}); - - sd::ops::softmax_cross_entropy_loss_with_logits_grad op; + NDArray dLdpExp('c', {1}, std::vector{0}); - auto results = op.evaluate({&logits, &labels}, {}, {0}); + sd::ops::softmax_cross_entropy_loss_with_logits_grad op; - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto results = op.evaluate({&logits, &labels}, {}, {0}); - auto *dLdp = results.at(0); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + auto dLdp = results.at(0); + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); } ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, Multiply_BP_Test1) { + NDArray x('c', {3, 4, 5}, sd::DataType::DOUBLE); + NDArray y('c', {1, 1, 1}, sd::DataType::DOUBLE); - NDArray x('c', {3,4,5}, sd::DataType::DOUBLE); - NDArray y('c', {1,1,1}, sd::DataType::DOUBLE); - - NDArray dLdp('c', {3,4,5}, sd::DataType::DOUBLE); - NDArray dLdpExp('c', {3,4,5}, sd::DataType::DOUBLE); + NDArray dLdp('c', {3, 4, 5}, sd::DataType::DOUBLE); + NDArray dLdpExp('c', {3, 4, 5}, sd::DataType::DOUBLE); - x.assign(1.0);//linspace(0.1, 0.1); - y.assign(1.0); - dLdp.assign(1.0); - dLdpExp.assign(1.0); - sd::ops::multiply_bp op; + x.assign(1.0); // linspace(0.1, 0.1); + y.assign(1.0); + dLdp.assign(1.0); + dLdpExp.assign(1.0); + sd::ops::multiply_bp op; - auto results = op.evaluate({&x, &y, &dLdp}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto results = op.evaluate({&x, &y, &dLdp}, {}, {}); - auto *dLdo = results.at(0); - ASSERT_TRUE(dLdpExp.isSameShape(dLdo)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdo)); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto dLdo = results.at(0); + ASSERT_TRUE(dLdpExp.isSameShape(dLdo)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdo)); } ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, sparseSoftmaxCrossEntropyWithLogits_grad_test1) { + NDArray labels('c', {2}, {2, 1}, sd::DataType::INT64); + NDArray logits('c', {2, 3}, sd::DataType::DOUBLE); - NDArray labels('c', {2}, {2,1}, sd::DataType::INT64); - NDArray logits('c', {2,3}, sd::DataType::DOUBLE); + NDArray dLdpExp('c', {2, 3}, + {0.30061, 0.33222, -0.63283, 0.30061, -0.66778, 0.36717}); - NDArray dLdpExp('c', {2,3}, {0.30061, 0.33222, -0.63283, 0.30061, -0.66778, 0.36717}); + logits.linspace(0.1, 0.1); - logits.linspace(0.1, 0.1); + sd::ops::sparse_softmax_cross_entropy_loss_with_logits_grad op; - sd::ops::sparse_softmax_cross_entropy_loss_with_logits_grad op; + auto results = op.evaluate({&labels, &logits}, {}, {}); - auto results = op.evaluate({&labels, &logits}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto *dLdp = results.at(0); - - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + auto dLdp = results.at(0); + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); } ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, sparseSoftmaxCrossEntropyWithLogits_grad_test2) { + NDArray labels('c', {2}, {0, 1}, sd::DataType::INT64); + NDArray logits('c', {2, 3}, sd::DataType::DOUBLE); - NDArray labels('c', {2}, {0,1}, sd::DataType::INT64); - NDArray logits('c', {2,3}, sd::DataType::DOUBLE); - - NDArray dLdpExp('c', {2,3}, {-0.69939, 0.33222, 0.36717, 0.30061, -0.66778, 0.36717}); + NDArray dLdpExp('c', {2, 3}, + {-0.69939, 0.33222, 0.36717, 0.30061, -0.66778, 0.36717}); - logits.linspace(-0.1, 0.1); + logits.linspace(-0.1, 0.1); - sd::ops::sparse_softmax_cross_entropy_loss_with_logits_grad op; + sd::ops::sparse_softmax_cross_entropy_loss_with_logits_grad op; - auto results = op.evaluate({&labels, &logits}, {}, {}); + auto results = op.evaluate({&labels, &logits}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto *dLdp = results.at(0); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + auto dLdp = results.at(0); + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); } ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, sparseSoftmaxCrossEntropyWithLogits_grad_test3) { + NDArray labels('c', {}, std::vector{1}, sd::DataType::INT64); + NDArray logits('c', {2}, {-0.2, 0.3}); - NDArray labels('c', {}, std::vector{1}, sd::DataType::INT64); - NDArray logits('c', {2}, {-0.2, 0.3}); - - NDArray dLdpExp('c', {2}, {0.37754, -0.37754}); - - sd::ops::sparse_softmax_cross_entropy_loss_with_logits_grad op; + NDArray dLdpExp('c', {2}, {0.37754, -0.37754}); - auto results = op.evaluate({&labels, &logits}, {}, {}); + sd::ops::sparse_softmax_cross_entropy_loss_with_logits_grad op; - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto results = op.evaluate({&labels, &logits}, {}, {}); - auto *dLdp = results.at(0); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + auto dLdp = results.at(0); + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); } ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, sparseSoftmaxCrossEntropyWithLogits_grad_test4) { + NDArray labels('c', {2, 3}, {0, 1, 1, 3, 3, 2}, sd::DataType::INT64); + NDArray logits('c', {2, 3, 4}, sd::DataType::DOUBLE); - NDArray labels('c', {2,3}, {0,1,1, 3,3,2}, sd::DataType::INT64); - NDArray logits('c', {2,3,4}, sd::DataType::DOUBLE); - - NDArray dLdpExp('c', {2,3,4}, {-0.78616, 0.23633, 0.26118, 0.28865, 0.21384, -0.76367, 0.26118, 0.28865, 0.21384, -0.76367, 0.26118, 0.28865, - 0.21384, 0.23633, 0.26118, -0.71135, 0.21384, 0.23633, 0.26118, -0.71135, 0.21384, 0.23633, -0.73882, 0.28865}); - logits.linspace(-0.5, 0.1); + NDArray dLdpExp('c', {2, 3, 4}, + {-0.78616, 0.23633, 0.26118, 0.28865, 0.21384, -0.76367, + 0.26118, 0.28865, 0.21384, -0.76367, 0.26118, 0.28865, + 0.21384, 0.23633, 0.26118, -0.71135, 0.21384, 0.23633, + 0.26118, -0.71135, 0.21384, 0.23633, -0.73882, 0.28865}); + logits.linspace(-0.5, 0.1); - sd::ops::sparse_softmax_cross_entropy_loss_with_logits_grad op; + sd::ops::sparse_softmax_cross_entropy_loss_with_logits_grad op; - auto results = op.evaluate({&labels, &logits}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto results = op.evaluate({&labels, &logits}, {}, {}); - auto *dLdp = results.at(0); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + auto dLdp = results.at(0); + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); } ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, sparseSoftmaxCrossEntropyWithLogits_grad_test5) { + NDArray labels('c', {1, 1}, std::vector({0}), sd::DataType::INT64); + NDArray logits('c', {1, 1, 2}, {-0.3, 0.2}); - NDArray labels('c', {1,1}, std::vector({0}), sd::DataType::INT64); - NDArray logits('c', {1,1,2}, {-0.3,0.2}); + NDArray dLdpExp('c', {1, 1, 2}, {-0.62246, 0.62246}); - NDArray dLdpExp('c', {1,1,2}, {-0.62246, 0.62246}); + sd::ops::sparse_softmax_cross_entropy_loss_with_logits_grad op; - sd::ops::sparse_softmax_cross_entropy_loss_with_logits_grad op; + auto results = op.evaluate({&labels, &logits}, {}, {}); - auto results = op.evaluate({&labels, &logits}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto *dLdp = results.at(0); - - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + auto dLdp = results.at(0); + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); } - diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests12.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests12.cpp index 9e5281afe51b..fcab5d25cc85 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests12.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests12.cpp @@ -14,2812 +14,3140 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - // // Created by raver on 8/4/2018. // -#include "testlayers.h" -#include #include -#include -#include #include +#include +#include #include #include #include +#include +#include -using namespace sd; +#include "testlayers.h" +using namespace sd; class DeclarableOpsTests12 : public testing::Test { -public: - - DeclarableOpsTests12() { - printf("\n"); - fflush(stdout); - } + public: + DeclarableOpsTests12() { + } }; TEST_F(DeclarableOpsTests12, test_any_validation_1) { - auto x = NDArrayFactory::create('c', {2, 1}, {1.0, 2.0}); - auto y = NDArrayFactory::create('c', {2}, {1, 0}); - - sd::ops::transpose op; - auto result = op.evaluate({&x, &y}); - ASSERT_EQ(Status::OK(), result.status()); + auto x = NDArrayFactory::create('c', {2, 1}, {1.0, 2.0}); + auto y = NDArrayFactory::create('c', {2}, {1, 0}); - auto z = result.at(0); - ASSERT_EQ(x.dataType(), z->dataType()); + sd::ops::transpose op; + auto result = op.evaluate({&x, &y}); + ASSERT_EQ(Status::OK(), result.status()); - + auto z = result.at(0); + ASSERT_EQ(x.dataType(), z.dataType()); } - ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, cosine_distance_loss_grad_test1) { + NDArray labels('c', {2, 4}, {0, 1, 1, 0, 1, 0, 1, 0}); + NDArray predictions('c', {2, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {2, 1}, sd::DataType::DOUBLE); - NDArray labels('c', {2,4}, {0,1,1,0,1,0,1,0}); - NDArray predictions('c', {2,4}, sd::DataType::DOUBLE); - NDArray weights('c', {2,1}, sd::DataType::DOUBLE); + NDArray dLdpExp('c', {2, 4}, {-0., -0.5, -0.5, -0., -0.5, -0., -0.5, -0.}); + NDArray dLdwExp('c', {2, 1}, {1.2, -0.2}); - NDArray dLdpExp('c', {2,4}, {-0. , -0.5, -0.5, -0., -0.5, -0. , -0.5, -0.}); - NDArray dLdwExp('c', {2,1}, {1.2, -0.2}); + predictions.linspace(-0.4, 0.2); + weights.assign(0.5); - predictions.linspace(-0.4, 0.2); - weights.assign(0.5); + sd::ops::cosine_distance_loss_grad op; - sd::ops::cosine_distance_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0, -1}); - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0, -1}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto *dLdp = results.at(0); - auto *dLdw = results.at(1); - auto *dLdl = results.at(2); + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); - - + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); } ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, cosine_distance_loss_grad_test2) { + NDArray labels('c', {2, 4}, {-0.1, 0.3, 2, -1.4, 2.5, -3, 1.2, 2.2}); + NDArray predictions('c', {2, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {1, 4}, sd::DataType::DOUBLE); - NDArray labels('c', {2,4}, {-0.1, 0.3, 2, -1.4, 2.5, -3, 1.2, 2.2}); - NDArray predictions('c', {2,4}, sd::DataType::DOUBLE); - NDArray weights('c', {1,4}, sd::DataType::DOUBLE); + NDArray dLdpExp('c', {2, 4}, {0.05, -0.15, -1., 0.7, -1.25, 1.5, -0.6, -1.1}); + NDArray dLdwExp('c', {1, 4}, {-0.04, 2.86, 0.04, -0.92}); + NDArray dLdlExp('c', {2, 4}, {0.2, 0.1, 0., -0.1, -0.2, -0.3, -0.4, -0.5}); - NDArray dLdpExp('c', {2,4}, {0.05, -0.15, -1. , 0.7 ,-1.25, 1.5 , -0.6 , -1.1 }); - NDArray dLdwExp('c', {1,4}, {-0.04, 2.86, 0.04, -0.92}); - NDArray dLdlExp('c', {2,4}, {0.2, 0.1, 0. , -0.1, -0.2, -0.3, -0.4, -0.5}); + predictions.linspace(-0.4, 0.2); + weights.assign(0.5); - predictions.linspace(-0.4, 0.2); - weights.assign(0.5); + sd::ops::cosine_distance_loss_grad op; - sd::ops::cosine_distance_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0, 0}); - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0, 0}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); - auto *dLdp = results.at(0); - auto *dLdw = results.at(1); - auto *dLdl = results.at(2); - - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); - ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); - ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); - - + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); + ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); } ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, cosine_distance_loss_grad_test3) { + NDArray labels('c', {4}, {-0.1, 0.3, 2, -1.4}); + NDArray predictions('c', {4}, sd::DataType::DOUBLE); + NDArray weights('c', {1}, sd::DataType::DOUBLE); - NDArray labels('c', {4}, {-0.1, 0.3, 2, -1.4}); - NDArray predictions('c', {4}, sd::DataType::DOUBLE); - NDArray weights('c', {1}, sd::DataType::DOUBLE); - - NDArray dLdpExp('c', {4}, {0.05, -0.15, -1., 0.7}); - NDArray dLdwExp('c', {1}, std::vector{1.3}); - NDArray dLdlExp('c', {4}, {0.2, 0.1, -0. , -0.1}); + NDArray dLdpExp('c', {4}, {0.05, -0.15, -1., 0.7}); + NDArray dLdwExp('c', {1}, std::vector{1.3}); + NDArray dLdlExp('c', {4}, {0.2, 0.1, -0., -0.1}); - predictions.linspace(-0.4, 0.2); - weights.assign(0.5); + predictions.linspace(-0.4, 0.2); + weights.assign(0.5); - sd::ops::cosine_distance_loss_grad op; + sd::ops::cosine_distance_loss_grad op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0, 0}); + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0, 0}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto *dLdp = results.at(0); - auto *dLdw = results.at(1); - auto *dLdl = results.at(2); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); - ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); - ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); - + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); + ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); } ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, cosine_distance_loss_grad_test4) { + NDArray labels('c', {1, 4}, {-0.1, 0.3, 2, -1.4}); + NDArray predictions('c', {1, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {}, std::vector{0.}, sd::DataType::DOUBLE); - NDArray labels('c', {1,4}, {-0.1, 0.3, 2, -1.4}); - NDArray predictions('c', {1,4}, sd::DataType::DOUBLE); - NDArray weights('c', {}, std::vector{0.}, sd::DataType::DOUBLE); + NDArray dLdpExp('c', {1, 4}, {0.05, -0.15, -1., 0.7}); + NDArray dLdwExp('c', {}, std::vector{1.3}); + NDArray dLdlExp('c', {1, 4}, {0.2, 0.1, -0., -0.1}); - NDArray dLdpExp('c', {1,4}, {0.05, -0.15, -1., 0.7}); - NDArray dLdwExp('c', {}, std::vector{1.3}); - NDArray dLdlExp('c', {1,4}, {0.2, 0.1, -0. , -0.1}); + predictions.linspace(-0.4, 0.2); + weights.assign(0.5); - predictions.linspace(-0.4, 0.2); - weights.assign(0.5); + sd::ops::cosine_distance_loss_grad op; - sd::ops::cosine_distance_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1, 1}); - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1, 1}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *dLdp = results.at(0); - auto *dLdw = results.at(1); - auto *dLdl = results.at(2); + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); - ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); - ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); - - + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); + ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); } - ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, cosine_distance_loss_grad_test5) { + NDArray labels('c', {4}, {-0.1, 0.3, 2, -1.4}, sd::DataType::DOUBLE); + NDArray predictions('c', {4}, sd::DataType::DOUBLE); + NDArray weights('c', {1, 1}, sd::DataType::DOUBLE); - NDArray labels('c', {4}, {-0.1, 0.3, 2, -1.4}, sd::DataType::DOUBLE); - NDArray predictions('c', {4}, sd::DataType::DOUBLE); - NDArray weights('c', {1,1}, sd::DataType::DOUBLE); + NDArray dLdpExp('c', {4}, {0.1, -0.3, -2., 1.4}); + NDArray dLdwExp('c', {1, 1}, std::vector{0.}); + NDArray dLdlExp('c', {4}, {0.4, 0.2, -0., -0.2}); - NDArray dLdpExp('c', {4}, {0.1, -0.3, -2. , 1.4}); - NDArray dLdwExp('c', {1,1}, std::vector{0.}); - NDArray dLdlExp('c', {4}, {0.4, 0.2, -0. , -0.2}); + predictions.linspace(-0.4, 0.2); + weights = 0.5; - predictions.linspace(-0.4, 0.2); - weights = 0.5; + sd::ops::cosine_distance_loss_grad op; - sd::ops::cosine_distance_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2, 0}); - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2, 0}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); - auto *dLdp = results.at(0); - auto *dLdw = results.at(1); - auto *dLdl = results.at(2); - - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); - ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); - ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); - - + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); + ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); } ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, cosine_distance_loss_grad_test6) { + NDArray labels('c', {4, 1}, {-0.1, 0.3, 2, -1.4}, sd::DataType::DOUBLE); + NDArray predictions('c', {4, 1}, sd::DataType::DOUBLE); + NDArray weights('c', {4, 1}, sd::DataType::DOUBLE); - NDArray labels('c', {4,1}, {-0.1, 0.3, 2, -1.4}, sd::DataType::DOUBLE); - NDArray predictions('c', {4,1}, sd::DataType::DOUBLE); - NDArray weights('c', {4,1}, sd::DataType::DOUBLE); + NDArray dLdpExp('c', {4, 1}, {0.0125, -0.0375, -0.25, 0.175}); + NDArray dLdwExp('c', {4, 1}, {0.24, 0.265, 0.25, 0.32}); + NDArray dLdlExp('c', {4, 1}, {0.05, 0.025, -0., -0.025}); - NDArray dLdpExp('c', {4,1}, {0.0125, -0.0375, -0.25 , 0.175}); - NDArray dLdwExp('c', {4,1}, {0.24 , 0.265, 0.25 , 0.32}); - NDArray dLdlExp('c', {4,1}, {0.05 , 0.025, -0. , -0.025}); + predictions.linspace(-0.4, 0.2); + weights = 0.5; - predictions.linspace(-0.4, 0.2); - weights = 0.5; + sd::ops::cosine_distance_loss_grad op; - sd::ops::cosine_distance_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3, 1}); - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3, 1}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *dLdp = results.at(0); - auto *dLdw = results.at(1); - auto *dLdl = results.at(2); + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); - ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); - ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); - - + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); + ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); } ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, cosine_distance_loss_grad_test7) { + NDArray labels('c', {2, 3, 4}, {-0.1, 0.3, 2, -1.4, 2.5, -3, 1.2, 2.2, + -0.1, 0.3, 2, -3.4, 2.5, -3, 1.2, 2.2, + -0.2, 0.3, 2, -1.4, 2.7, -3, 1.2, 4.2}); + NDArray predictions('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {1, 3, 1}, sd::DataType::DOUBLE); - NDArray labels('c', {2,3,4}, {-0.1, 0.3, 2, -1.4, 2.5, -3, 1.2, 2.2,-0.1, 0.3, 2, -3.4, 2.5, -3, 1.2, 2.2,-0.2, 0.3, 2, -1.4, 2.7, -3, 1.2, 4.2}); - NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights('c', {1,3,1}, sd::DataType::DOUBLE); - - NDArray dLdpExp('c', {2,3,4}, {0.00833, -0.025 , -0.16667, 0.11667,-0.20833, 0.25 , -0.1 , -0.18333, 0.00833, -0.025 , -0.16667, 0.28333, - -0.20833, 0.25 , -0.1 , -0.18333, 0.01667, -0.025 , -0.16667, 0.11667,-0.225 , 0.25 , -0.1 , -0.35 }); - NDArray dLdwExp('c', {1,3,1}, {0.50444, 0.89778, -1.40222}); - NDArray dLdlExp('c', {2,3,4}, {0.03333, 0.01667, -0. , -0.01667,-0.03333, -0.05 , -0.06667, -0.08333,-0.1, -0.11667, -0.13333, -0.15, - -0.16667, -0.18333, -0.2 , -0.21667,-0.23333, -0.25 , -0.26667, -0.28333,-0.3, -0.31667, -0.33333, -0.35 }); - - predictions.linspace(-0.4, 0.2); - weights = 0.5; + NDArray dLdpExp( + 'c', {2, 3, 4}, + {0.00833, -0.025, -0.16667, 0.11667, -0.20833, 0.25, -0.1, -0.18333, + 0.00833, -0.025, -0.16667, 0.28333, -0.20833, 0.25, -0.1, -0.18333, + 0.01667, -0.025, -0.16667, 0.11667, -0.225, 0.25, -0.1, -0.35}); + NDArray dLdwExp('c', {1, 3, 1}, {0.50444, 0.89778, -1.40222}); + NDArray dLdlExp('c', {2, 3, 4}, + {0.03333, 0.01667, -0., -0.01667, -0.03333, -0.05, + -0.06667, -0.08333, -0.1, -0.11667, -0.13333, -0.15, + -0.16667, -0.18333, -0.2, -0.21667, -0.23333, -0.25, + -0.26667, -0.28333, -0.3, -0.31667, -0.33333, -0.35}); - sd::ops::cosine_distance_loss_grad op; + predictions.linspace(-0.4, 0.2); + weights = 0.5; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2, 0}); + sd::ops::cosine_distance_loss_grad op; - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2, 0}); - auto *dLdp = results.at(0); - auto *dLdw = results.at(1); - auto *dLdl = results.at(2); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); - ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); - ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); - + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); + ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); } ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, cosine_distance_loss_grad_test8) { + NDArray labels('c', {2, 3, 4}, {-0.1, 0.3, 2, -1.4, 2.5, -3, 1.2, 2.2, + -0.1, 0.3, 2, -3.4, 2.5, -3, 1.2, 2.2, + -0.2, 0.3, 2, -1.4, 2.7, -3, 1.2, 4.2}); + NDArray predictions('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {2, 1, 1}, sd::DataType::DOUBLE); - NDArray labels('c', {2,3,4}, {-0.1, 0.3, 2, -1.4, 2.5, -3, 1.2, 2.2,-0.1, 0.3, 2, -3.4, 2.5, -3, 1.2, 2.2,-0.2, 0.3, 2, -1.4, 2.7, -3, 1.2, 4.2}); - NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights('c', {2,1,1}, sd::DataType::DOUBLE); - - NDArray dLdpExp('c', {2,3,4}, {0.00625, -0.01875, -0.125 , 0.0875,-0.15625, 0.1875 , -0.075 , -0.1375, 0.00625, -0.01875, -0.125 , 0.2125, - -0.15625, 0.1875 , -0.075 , -0.1375, 0.0125 , -0.01875, -0.125 , 0.0875,-0.16875, 0.1875 , -0.075 , -0.2625}); - NDArray dLdwExp('c', {2,1,1}, {0.57, -3.2175}); - NDArray dLdlExp('c', {2,3,4}, {0.025, 0.0125, -0. , -0.0125,-0.025, -0.0375, -0.05, -0.0625,-0.075, -0.0875, -0.1 , -0.1125, - -0.125, -0.1375, -0.15, -0.1625,-0.175, -0.1875, -0.2 , -0.2125,-0.225, -0.2375, -0.25, -0.2625}); + NDArray dLdpExp( + 'c', {2, 3, 4}, + {0.00625, -0.01875, -0.125, 0.0875, -0.15625, 0.1875, -0.075, -0.1375, + 0.00625, -0.01875, -0.125, 0.2125, -0.15625, 0.1875, -0.075, -0.1375, + 0.0125, -0.01875, -0.125, 0.0875, -0.16875, 0.1875, -0.075, -0.2625}); + NDArray dLdwExp('c', {2, 1, 1}, {0.57, -3.2175}); + NDArray dLdlExp( + 'c', {2, 3, 4}, + {0.025, 0.0125, -0., -0.0125, -0.025, -0.0375, -0.05, -0.0625, + -0.075, -0.0875, -0.1, -0.1125, -0.125, -0.1375, -0.15, -0.1625, + -0.175, -0.1875, -0.2, -0.2125, -0.225, -0.2375, -0.25, -0.2625}); - predictions.linspace(-0.4, 0.2); - weights = 0.5; + predictions.linspace(-0.4, 0.2); + weights = 0.5; - sd::ops::cosine_distance_loss_grad op; + sd::ops::cosine_distance_loss_grad op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3, 1}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3, 1}); - auto *dLdp = results.at(0); - auto *dLdw = results.at(1); - auto *dLdl = results.at(2); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); - ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); - ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); - + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); + ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); } ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, cosine_distance_loss_grad_test9) { + NDArray labels('c', {2, 3, 4}, {-0.1, 0.3, 2, -1.4, 2.5, -3, 1.2, 2.2, + -0.1, 0.3, 2, -3.4, 2.5, -3, 1.2, 2.2, + -0.2, 0.3, 2, -1.4, 2.7, -3, 1.2, 4.2}); + NDArray predictions('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {2, 3, 1}, sd::DataType::DOUBLE); - NDArray labels('c', {2,3,4}, {-0.1, 0.3, 2, -1.4, 2.5, -3, 1.2, 2.2,-0.1, 0.3, 2, -3.4, 2.5, -3, 1.2, 2.2,-0.2, 0.3, 2, -1.4, 2.7, -3, 1.2, 4.2}); - NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights('c', {2,3,1}, sd::DataType::DOUBLE); + NDArray dLdpExp('c', {2, 3, 4}, + {0.05, -0.15, -1., 0.7, -1.25, 1.5, -0.6, -1.1, + 0.05, -0.15, -1., 1.7, -1.25, 1.5, -0.6, -1.1, + 0.1, -0.15, -1., 0.7, -1.35, 1.5, -0.6, -2.1}); + NDArray dLdwExp('c', {2, 3, 1}, {1.3, -1.36, 3.62, -6., -0.98, -19.76}); + NDArray dLdlExp( + 'c', {2, 3, 4}, + {0.2, 0.1, -0., -0.1, -0.2, -0.3, -0.4, -0.5, -0.6, -0.7, -0.8, -0.9, + -1., -1.1, -1.2, -1.3, -1.4, -1.5, -1.6, -1.7, -1.8, -1.9, -2., -2.1}); - NDArray dLdpExp('c', {2,3,4}, {0.05, -0.15, -1. , 0.7,-1.25, 1.5 , -0.6 , -1.1, 0.05, -0.15, -1. , 1.7, - -1.25, 1.5 , -0.6 , -1.1, 0.1 , -0.15, -1. , 0.7,-1.35, 1.5 , -0.6 , -2.1}); - NDArray dLdwExp('c', {2,3,1}, {1.3 , -1.36, 3.62, -6. , -0.98,-19.76}); - NDArray dLdlExp('c', {2,3,4}, {0.2, 0.1, -0. , -0.1,-0.2, -0.3, -0.4, -0.5,-0.6, -0.7, -0.8, -0.9, - -1. , -1.1, -1.2, -1.3,-1.4, -1.5, -1.6, -1.7,-1.8, -1.9, -2. , -2.1}); + predictions.linspace(-0.4, 0.2); + weights = 0.5; - predictions.linspace(-0.4, 0.2); - weights = 0.5; + sd::ops::cosine_distance_loss_grad op; - sd::ops::cosine_distance_loss_grad op; - - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0, 2}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0, 2}); - auto *dLdp = results.at(0); - auto *dLdw = results.at(1); - auto *dLdl = results.at(2); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); - ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); - ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); - + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); + ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); } - ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, hinge_loss_14) { + NDArray logits('c', {3, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {}, std::vector{1.}); + NDArray labels('c', {3, 4}, {0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0}); - NDArray logits('c', {3,4}, sd::DataType::DOUBLE); - NDArray weights('c', {}, std::vector{1.}); - NDArray labels('c', {3,4}, {0,1,1,0,1,0,1,0,1,0,1,0}); - - NDArray output('c', {}, std::vector{0.}, sd::DataType::DOUBLE); + NDArray output('c', {}, std::vector{0.}, sd::DataType::DOUBLE); - logits.linspace(1.); - weights.assign(1.); + logits.linspace(1.); + weights.assign(1.); - sd::ops::hinge_loss op; - Nd4jStatus status = op.execute({&logits, &weights, &labels}, {&output}, {}, {1}, {}); + sd::ops::hinge_loss op; + Nd4jStatus status = + op.execute({&logits, &weights, &labels}, {&output}, {}, {1}, {}); - ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(output.e(0) == 47.); + ASSERT_TRUE(output.e(0) == 47.); } ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, TestDivideBP_1) { + NDArray x('c', {3, 4}, sd::DataType::DOUBLE); + NDArray y = NDArrayFactory::create(2.); + NDArray eps('c', {3, 4}, sd::DataType::DOUBLE); - NDArray x('c', {3,4}, sd::DataType::DOUBLE); - NDArray y = NDArrayFactory::create(2.); - NDArray eps('c', {3,4}, sd::DataType::DOUBLE); + NDArray output1('c', {3, 4}, sd::DataType::DOUBLE); + NDArray output2(sd::DataType::DOUBLE); - NDArray output1('c', {3, 4}, sd::DataType::DOUBLE); - NDArray output2(sd::DataType::DOUBLE); + x.linspace(2., 2.); + eps.linspace(1.); - x.linspace(2., 2.); - eps.linspace(1.); + sd::ops::divide_bp op; + Nd4jStatus status = + op.execute({&x, &y, &eps}, {&output1, &output2}, {}, {}, {}); - sd::ops::divide_bp op; - Nd4jStatus status = op.execute({&x, &y, &eps}, {&output1, &output2}, {}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, status); - //ASSERT_TRUE(output.e(0) == 47.); + ASSERT_EQ(ND4J_STATUS_OK, status); + // ASSERT_TRUE(output.e(0) == 47.); } ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, TestDivideBP_2) { - - NDArray x('c', {3,4}, sd::DataType::DOUBLE); - NDArray y = NDArrayFactory::create('c', {3,4}); - NDArray eps('c', {3,4}, sd::DataType::DOUBLE); - NDArray exp1('c', {3,4}, sd::DataType::DOUBLE); - NDArray exp2('c', {3,4}, sd::DataType::DOUBLE); - NDArray output1('c', {3, 4}, sd::DataType::DOUBLE); - NDArray output2('c', {3, 4}, sd::DataType::DOUBLE); - exp1.assign(1.); - exp2.assign(-2.); - x.linspace(2., 2.); - y.linspace(1.); - eps.linspace(1.); - - sd::ops::divide_bp op; - Nd4jStatus status = op.execute({&x, &y, &eps}, std::vector{&output1, &output2}, {}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(output1.equalsTo(exp1)); - ASSERT_TRUE(output2.equalsTo(exp2)); + NDArray x('c', {3, 4}, sd::DataType::DOUBLE); + NDArray y = NDArrayFactory::create('c', {3, 4}); + NDArray eps('c', {3, 4}, sd::DataType::DOUBLE); + NDArray exp1('c', {3, 4}, sd::DataType::DOUBLE); + NDArray exp2('c', {3, 4}, sd::DataType::DOUBLE); + NDArray output1('c', {3, 4}, sd::DataType::DOUBLE); + NDArray output2('c', {3, 4}, sd::DataType::DOUBLE); + exp1.assign(1.); + exp2.assign(-2.); + x.linspace(2., 2.); + y.linspace(1.); + eps.linspace(1.); + + sd::ops::divide_bp op; + Nd4jStatus status = op.execute( + {&x, &y, &eps}, std::vector{&output1, &output2}, {}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(output1.equalsTo(exp1)); + ASSERT_TRUE(output2.equalsTo(exp2)); } ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, TestReverseDivideBP_1) { + NDArray x('c', {3, 4}, sd::DataType::DOUBLE); + NDArray y = NDArrayFactory::create(2.); + NDArray eps('c', {3, 4}, sd::DataType::DOUBLE); - NDArray x('c', {3,4}, sd::DataType::DOUBLE); - NDArray y = NDArrayFactory::create(2.); - NDArray eps('c', {3,4}, sd::DataType::DOUBLE); - - NDArray output1('c', {3, 4}, sd::DataType::DOUBLE); - NDArray output2(sd::DataType::DOUBLE); + NDArray output1('c', {3, 4}, sd::DataType::DOUBLE); + NDArray output2(sd::DataType::DOUBLE); - x.linspace(2., 2.); - eps.linspace(1.); + x.linspace(2., 2.); + eps.linspace(1.); - sd::ops::reversedivide_bp op; - Nd4jStatus status = op.execute({&y, &x, &eps}, std::vector{&output2, &output1}, {}, {}, {}); + sd::ops::reversedivide_bp op; + Nd4jStatus status = op.execute( + {&y, &x, &eps}, std::vector{&output2, &output1}, {}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, status); - //ASSERT_TRUE(output.e(0) == 47.); + ASSERT_EQ(ND4J_STATUS_OK, status); + // ASSERT_TRUE(output.e(0) == 47.); } ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, TestReverseDivideBP_2) { + NDArray x('c', {3, 4}, sd::DataType::DOUBLE); + NDArray y = NDArrayFactory::create('c', {3, 4}); + NDArray eps('c', {3, 4}, sd::DataType::DOUBLE); + NDArray exp1('c', {3, 4}, sd::DataType::DOUBLE); + NDArray exp2('c', {3, 4}, sd::DataType::DOUBLE); - NDArray x('c', {3,4}, sd::DataType::DOUBLE); - NDArray y = NDArrayFactory::create('c', {3,4}); - NDArray eps('c', {3,4}, sd::DataType::DOUBLE); - NDArray exp1('c', {3,4}, sd::DataType::DOUBLE); - NDArray exp2('c', {3,4}, sd::DataType::DOUBLE); + NDArray output1('c', {3, 4}, sd::DataType::DOUBLE); + NDArray output2('c', {3, 4}, sd::DataType::DOUBLE); - NDArray output1('c', {3, 4}, sd::DataType::DOUBLE); - NDArray output2('c', {3, 4}, sd::DataType::DOUBLE); + x.linspace(2., 2.); + y.linspace(1.); + eps.linspace(1.); + exp1.assign(1.); + exp2.assign(-2.); + sd::ops::reversedivide_bp op; + Nd4jStatus status = op.execute( + {&y, &x, &eps}, std::vector{&output2, &output1}, {}, {}, {}); - x.linspace(2., 2.); - y.linspace(1.); - eps.linspace(1.); - exp1.assign(1.); - exp2.assign(-2.); - sd::ops::reversedivide_bp op; - Nd4jStatus status = op.execute({&y, &x, &eps}, std::vector{&output2, &output1}, {}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(output1.equalsTo(exp1)); - ASSERT_TRUE(output2.equalsTo(exp2)); + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(output1.equalsTo(exp1)); + ASSERT_TRUE(output2.equalsTo(exp2)); } ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, TestSliceBP_1) { - - NDArray x('c', {3,4}, sd::DataType::DOUBLE); - NDArray eps('c', {2,2}, sd::DataType::DOUBLE); - NDArray exp('c', {3,4}, {0., 0., 0., 0., 0., 1.,1., 0., 0., 1., 1., 0.}); - //NDArray exp2('c', {3,4}, sd::DataType::DOUBLE); - - NDArray output('c', {3, 4}, sd::DataType::DOUBLE); - //NDArray output2('c', {3, 4}, sd::DataType::DOUBLE); - output.assign(119.113); - x.linspace(1.); - eps.assign(1.); - //exp1.assign(1.); - //exp2.assign(-2.); - sd::ops::slice_bp op; - Nd4jStatus status = op.execute({&x, &eps}, {&output}, {}, {1,1,2,2}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(output.equalsTo(exp)); - //ASSERT_TRUE(output2.equalsTo(exp2)); + NDArray x('c', {3, 4}, sd::DataType::DOUBLE); + NDArray eps('c', {2, 2}, sd::DataType::DOUBLE); + NDArray exp('c', {3, 4}, {0., 0., 0., 0., 0., 1., 1., 0., 0., 1., 1., 0.}); + // NDArray exp2('c', {3,4}, sd::DataType::DOUBLE); + + NDArray output('c', {3, 4}, sd::DataType::DOUBLE); + // NDArray output2('c', {3, 4}, sd::DataType::DOUBLE); + output.assign(119.113); + x.linspace(1.); + eps.assign(1.); + // exp1.assign(1.); + // exp2.assign(-2.); + sd::ops::slice_bp op; + Nd4jStatus status = op.execute({&x, &eps}, {&output}, {}, {1, 1, 2, 2}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(output.equalsTo(exp)); + // ASSERT_TRUE(output2.equalsTo(exp2)); } ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, TestConfusionZero_1) { - - NDArray x('c', {2}, {1,2}, sd::DataType::INT64); - NDArray i('c', {2}, {0,2}, sd::DataType::INT64); - //NDArray eps('c', {2,2}, sd::DataType::DOUBLE); - NDArray exp('c', {4,4}, {0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0}, sd::DataType::INT64); - //NDArray exp2('c', {3,4}, sd::DataType::DOUBLE); - - NDArray output('c', {4, 4}, sd::DataType::INT64); - //NDArray output2('c', {3, 4}, sd::DataType::DOUBLE); - output.assign(119.113); - x.linspace(1.); - //eps.assign(1.); - //exp1.assign(1.); - //exp2.assign(-2.); - sd::ops::confusion_matrix op; - Nd4jStatus status = op.execute({&x, &i}, {&output}, {}, {4}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(output.equalsTo(exp)); - //ASSERT_TRUE(output2.equalsTo(exp2)); + NDArray x('c', {2}, {1, 2}, sd::DataType::INT64); + NDArray i('c', {2}, {0, 2}, sd::DataType::INT64); + // NDArray eps('c', {2,2}, sd::DataType::DOUBLE); + NDArray exp('c', {4, 4}, {0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0}, + sd::DataType::INT64); + // NDArray exp2('c', {3,4}, sd::DataType::DOUBLE); + + NDArray output('c', {4, 4}, sd::DataType::INT64); + // NDArray output2('c', {3, 4}, sd::DataType::DOUBLE); + output.assign(119.113); + x.linspace(1.); + // eps.assign(1.); + // exp1.assign(1.); + // exp2.assign(-2.); + sd::ops::confusion_matrix op; + Nd4jStatus status = op.execute({&x, &i}, {&output}, {}, {4}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(output.equalsTo(exp)); + // ASSERT_TRUE(output2.equalsTo(exp2)); } ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, TestMaximumBP_1) { - - NDArray x('c', {3,4}, sd::DataType::DOUBLE); - NDArray y('c', {3,4}, sd::DataType::DOUBLE); - NDArray eps('c', {3,4}, sd::DataType::DOUBLE); - NDArray exp1('c', {3,4}, {0, 0, 0, 0, 0, 0, 7, 8, 9, 10, 11, 12}, sd::DataType::DOUBLE); - NDArray exp2('c', {3,4}, {1, 2, 3, 4, 5, 6, 0, 0, 0, 0, 0, 0}, sd::DataType::DOUBLE); - - NDArray output1('c', {3, 4}, sd::DataType::DOUBLE); - NDArray output2('c', {3, 4}, sd::DataType::DOUBLE); - output1.assign(119); - x.linspace(1.); - y.linspace(12., -1.); - eps.linspace(1.); - //exp1.assign(1.); - //exp2.assign(-2.); - sd::ops::maximum_bp op; - Nd4jStatus status = op.execute({&x, &y, &eps}, std::vector{&output1, &output2}, {}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(output1.equalsTo(exp1)); - ASSERT_TRUE(output2.equalsTo(exp2)); + NDArray x('c', {3, 4}, sd::DataType::DOUBLE); + NDArray y('c', {3, 4}, sd::DataType::DOUBLE); + NDArray eps('c', {3, 4}, sd::DataType::DOUBLE); + NDArray exp1('c', {3, 4}, {0, 0, 0, 0, 0, 0, 7, 8, 9, 10, 11, 12}, + sd::DataType::DOUBLE); + NDArray exp2('c', {3, 4}, {1, 2, 3, 4, 5, 6, 0, 0, 0, 0, 0, 0}, + sd::DataType::DOUBLE); + + NDArray output1('c', {3, 4}, sd::DataType::DOUBLE); + NDArray output2('c', {3, 4}, sd::DataType::DOUBLE); + output1.assign(119); + x.linspace(1.); + y.linspace(12., -1.); + eps.linspace(1.); + // exp1.assign(1.); + // exp2.assign(-2.); + sd::ops::maximum_bp op; + Nd4jStatus status = op.execute( + {&x, &y, &eps}, std::vector{&output1, &output2}, {}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(output1.equalsTo(exp1)); + ASSERT_TRUE(output2.equalsTo(exp2)); } ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, TestMinimumBP_1) { - - NDArray x('c', {3,4}, sd::DataType::DOUBLE); - NDArray y('c', {3,4}, sd::DataType::DOUBLE); - NDArray eps('c', {3,4}, sd::DataType::DOUBLE); - NDArray exp1('c', {3,4}, {0, 0, 0, 0, 0, 0, 7, 8, 9, 10, 11, 12}, sd::DataType::DOUBLE); - NDArray exp2('c', {3,4}, {1, 2, 3, 4, 5, 6, 0, 0, 0, 0, 0, 0}, sd::DataType::DOUBLE); - - NDArray output1('c', {3, 4}, sd::DataType::DOUBLE); - NDArray output2('c', {3, 4}, sd::DataType::DOUBLE); - output1.assign(119); - x.linspace(1.); - y.linspace(12., -1.); - eps.linspace(1.); - //exp1.assign(1.); - //exp2.assign(-2.); - sd::ops::minimum_bp op; - Nd4jStatus status = op.execute({&x, &y, &eps}, std::vector{&output2, &output1}, {}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(output1.equalsTo(exp1)); - ASSERT_TRUE(output2.equalsTo(exp2)); + NDArray x('c', {3, 4}, sd::DataType::DOUBLE); + NDArray y('c', {3, 4}, sd::DataType::DOUBLE); + NDArray eps('c', {3, 4}, sd::DataType::DOUBLE); + NDArray exp1('c', {3, 4}, {0, 0, 0, 0, 0, 0, 7, 8, 9, 10, 11, 12}, + sd::DataType::DOUBLE); + NDArray exp2('c', {3, 4}, {1, 2, 3, 4, 5, 6, 0, 0, 0, 0, 0, 0}, + sd::DataType::DOUBLE); + + NDArray output1('c', {3, 4}, sd::DataType::DOUBLE); + NDArray output2('c', {3, 4}, sd::DataType::DOUBLE); + output1.assign(119); + x.linspace(1.); + y.linspace(12., -1.); + eps.linspace(1.); + // exp1.assign(1.); + // exp2.assign(-2.); + sd::ops::minimum_bp op; + Nd4jStatus status = op.execute( + {&x, &y, &eps}, std::vector{&output2, &output1}, {}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(output1.equalsTo(exp1)); + ASSERT_TRUE(output2.equalsTo(exp2)); } - ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, reverse_test15) { + NDArray x('c', {5}, {1, 2, 3, 4, 5}, sd::DataType::DOUBLE); + NDArray axis('c', {}, std::vector{0}, sd::DataType::INT32); + NDArray z('c', {5}, sd::DataType::DOUBLE); + NDArray exp('c', {5}, {5, 4, 3, 2, 1}, sd::DataType::DOUBLE); - NDArray x('c', {5}, {1,2,3,4,5}, sd::DataType::DOUBLE); - NDArray axis('c', {}, std::vector{0}, sd::DataType::INT32); - NDArray z('c', {5}, sd::DataType::DOUBLE); - NDArray exp('c', {5}, {5,4,3,2,1}, sd::DataType::DOUBLE); + sd::ops::reverse op; + // auto result = op.execute({&x, &axis}, {}, {1}, {}); + Nd4jStatus status = op.execute({&x, &axis}, {&z}, {}, {1}, {}); + // auto z = result.at(0); + // z->printIndexedBuffer(); - - sd::ops::reverse op; - // auto result = op.execute({&x, &axis}, {}, {1}, {}); - Nd4jStatus status = op.execute({&x, &axis}, {&z}, {}, {1}, {}); - // auto z = result.at(0); - // z->printIndexedBuffer(); - - ASSERT_EQ(Status::OK(), status); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - // + ASSERT_EQ(Status::OK(), status); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + // } ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, mirrorPad_test17) { + NDArray x('c', {2, 3}, {1, 2, 3, 4, 5, 6}, sd::DataType::DOUBLE); + NDArray padding('c', {2, 2}, {1, 1, 2, 2}, sd::DataType::INT64); + NDArray z('c', {4, 7}, sd::DataType::DOUBLE); + NDArray exp1('c', {4, 7}, {6, 5, 4, 5, 6, 5, 4, 3, 2, 1, 2, 3, 2, 1, + 6, 5, 4, 5, 6, 5, 4, 3, 2, 1, 2, 3, 2, 1}, + sd::DataType::DOUBLE); + NDArray exp2('c', {4, 7}, {2, 1, 1, 2, 3, 3, 2, 2, 1, 1, 2, 3, 3, 2, + 5, 4, 4, 5, 6, 6, 5, 5, 4, 4, 5, 6, 6, 5}, + sd::DataType::DOUBLE); - NDArray x('c', {2,3}, {1,2,3,4,5,6}, sd::DataType::DOUBLE); - NDArray padding('c', {2,2}, {1,1,2,2}, sd::DataType::INT64); - NDArray z('c', {4,7}, sd::DataType::DOUBLE); - NDArray exp1('c', {4,7}, {6, 5, 4, 5, 6, 5, 4,3, 2, 1, 2, 3, 2, 1,6, 5, 4, 5, 6, 5, 4,3, 2, 1, 2, 3, 2, 1}, sd::DataType::DOUBLE); - NDArray exp2('c', {4,7}, {2, 1, 1, 2, 3, 3, 2,2, 1, 1, 2, 3, 3, 2,5, 4, 4, 5, 6, 6, 5,5, 4, 4, 5, 6, 6, 5}, sd::DataType::DOUBLE); + sd::ops::mirror_pad op; + Nd4jStatus status = op.execute({&x, &padding}, {&z}, {}, {0}, {}); // reflect - sd::ops::mirror_pad op; - Nd4jStatus status = op.execute({&x, &padding}, {&z}, {}, {0}, {}); // reflect + ASSERT_EQ(Status::OK(), status); + ASSERT_TRUE(exp1.isSameShape(z)); + ASSERT_TRUE(exp1.equalsTo(z)); - ASSERT_EQ(Status::OK(), status); - ASSERT_TRUE(exp1.isSameShape(z)); - ASSERT_TRUE(exp1.equalsTo(z)); + z = 0.; + status = op.execute({&x, &padding}, {&z}, {}, {1}, {}); // symmetric - z = 0.; - status = op.execute({&x, &padding}, {&z}, {}, {1}, {}); // symmetric - - ASSERT_EQ(Status::OK(), status); - ASSERT_TRUE(exp2.isSameShape(z)); - ASSERT_TRUE(exp2.equalsTo(z)); + ASSERT_EQ(Status::OK(), status); + ASSERT_TRUE(exp2.isSameShape(z)); + ASSERT_TRUE(exp2.equalsTo(z)); } ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, mirrorPad_test18) { + NDArray x('c', {3}, {1, 2, 3}, sd::DataType::DOUBLE); + NDArray padding('c', {1, 2}, {1, 1}, sd::DataType::INT32); + NDArray z('c', {5}, sd::DataType::DOUBLE); + NDArray exp('c', {5}, {2, 1, 2, 3, 2}, sd::DataType::DOUBLE); - NDArray x('c', {3}, {1,2,3}, sd::DataType::DOUBLE); - NDArray padding('c', {1, 2}, {1,1}, sd::DataType::INT32); - NDArray z('c', {5}, sd::DataType::DOUBLE); - NDArray exp('c', {5}, {2,1,2,3,2}, sd::DataType::DOUBLE); - - sd::ops::mirror_pad op; - Nd4jStatus status = op.execute({&x, &padding}, {&z}, {}, {0}, {}); // reflect + sd::ops::mirror_pad op; + Nd4jStatus status = op.execute({&x, &padding}, {&z}, {}, {0}, {}); // reflect - ASSERT_EQ(Status::OK(), status); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_EQ(Status::OK(), status); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, relu_1) { - - NDArray input('c', {1,5,5,6}, { 0.557449, 0.768277, 1.094015, -0.557449, -0.768277, -1.094015,0.563735, 0.900299, 0.789979, -0.563735, -0.900299, -0.789979, - 0.142528, 0.959611, 0.877506, -0.142528, -0.959611, -0.877506,0.448742, 0.995377, 1.171543, -0.448742, -0.995377, -1.171543, - 0.603772, 0.799391, 0.560310, -0.603772, -0.799391, -0.560310,0.529753, 0.906786, 0.737630, -0.529753, -0.906786, -0.737630, - 0.221464, 0.824996, 0.472221, -0.221464, -0.824996, -0.472221,0.427730, 0.397933, 0.714365, -0.427730, -0.397933, -0.714365, - 0.488365, 1.016589, 0.744197, -0.488365, -1.016589, -0.744197,0.789846, 0.940837, 0.838412, -0.789846, -0.940837, -0.838412, - 0.404485, 0.677328, 0.754997, -0.404485, -0.677328, -0.754997,0.436760, 0.794765, 0.729766, -0.436760, -0.794765, -0.729766, - 0.588081, 0.652226, 0.725522, -0.588081, -0.652226, -0.725522,0.374457, 1.225813, 1.053411, -0.374457, -1.225813, -1.053411, - 0.300958, 0.599417, 0.633234, -0.300958, -0.599417, -0.633234,0.241993, 1.025464, 0.695378, -0.241993, -1.025464, -0.695378, - 0.236289, 0.907919, 1.012100, -0.236289, -0.907919, -1.012100,0.627402, 0.565187, 0.766926, -0.627402, -0.565187, -0.766926, - 0.133276, 0.326284, 0.102804, -0.133276, -0.326284, -0.102804,0.426913, 0.256251, 0.305241, -0.426913, -0.256251, -0.305241, - 0.177977, 0.841799, 0.800615, -0.177977, -0.841799, -0.800615,0.001991, 0.518389, 0.439322, -0.001991, -0.518389, -0.439322, - 0.166846, 0.508224, 0.486687, -0.166846, -0.508224, -0.486687,0.167493, 0.930932, 0.868717, -0.167493, -0.930932, -0.868717, - 0.174864, 0.444607, 0.445000, -0.174864, -0.444607, -0.445000}, sd::DataType::FLOAT32); - - NDArray expected('c', {1,5,5,6}, { 0.557449, 0.768277, 1.094015, 0., 0., 0., 0.563735, 0.900299, 0.789979, 0., 0., 0., - 0.142528, 0.959611, 0.877506, 0., 0., 0., 0.448742, 0.995377, 1.171543, 0., 0., 0., - 0.603772, 0.799391, 0.560310, 0., 0., 0., 0.529753, 0.906786, 0.737630, 0., 0., 0., - 0.221464, 0.824996, 0.472221, 0., 0., 0., 0.427730, 0.397933, 0.714365, 0., 0., 0., - 0.488365, 1.016589, 0.744197, 0., 0., 0., 0.789846, 0.940837, 0.838412, 0., 0., 0., - 0.404485, 0.677328, 0.754997, 0., 0., 0., 0.436760, 0.794765, 0.729766, 0., 0., 0., - 0.588081, 0.652226, 0.725522, 0., 0., 0., 0.374457, 1.225813, 1.053411, 0., 0., 0., - 0.300958, 0.599417, 0.633234, 0., 0., 0., 0.241993, 1.025464, 0.695378, 0., 0., 0., - 0.236289, 0.907919, 1.012100, 0., 0., 0., 0.627402, 0.565187, 0.766926, 0., 0., 0., - 0.133276, 0.326284, 0.102804, 0., 0., 0., 0.426913, 0.256251, 0.305241, 0., 0., 0., - 0.177977, 0.841799, 0.800615, 0., 0., 0., 0.001991, 0.518389, 0.439322, 0., 0., 0., - 0.166846, 0.508224, 0.486687, 0., 0., 0., 0.167493, 0.930932, 0.868717, 0., 0., 0., - 0.174864, 0.444607, 0.445000, 0., 0., 0.}, sd::DataType::FLOAT32); - - NDArray z('c', {1,5,5,6}, sd::DataType::FLOAT32); - - sd::ops::relu op; - Nd4jStatus status = op.execute({&input}, {&z}, {0}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(expected.isSameShapeStrict(z)); - ASSERT_TRUE(expected.equalsTo(z)); + NDArray input('c', {1, 5, 5, 6}, + {0.557449, 0.768277, 1.094015, -0.557449, -0.768277, -1.094015, + 0.563735, 0.900299, 0.789979, -0.563735, -0.900299, -0.789979, + 0.142528, 0.959611, 0.877506, -0.142528, -0.959611, -0.877506, + 0.448742, 0.995377, 1.171543, -0.448742, -0.995377, -1.171543, + 0.603772, 0.799391, 0.560310, -0.603772, -0.799391, -0.560310, + 0.529753, 0.906786, 0.737630, -0.529753, -0.906786, -0.737630, + 0.221464, 0.824996, 0.472221, -0.221464, -0.824996, -0.472221, + 0.427730, 0.397933, 0.714365, -0.427730, -0.397933, -0.714365, + 0.488365, 1.016589, 0.744197, -0.488365, -1.016589, -0.744197, + 0.789846, 0.940837, 0.838412, -0.789846, -0.940837, -0.838412, + 0.404485, 0.677328, 0.754997, -0.404485, -0.677328, -0.754997, + 0.436760, 0.794765, 0.729766, -0.436760, -0.794765, -0.729766, + 0.588081, 0.652226, 0.725522, -0.588081, -0.652226, -0.725522, + 0.374457, 1.225813, 1.053411, -0.374457, -1.225813, -1.053411, + 0.300958, 0.599417, 0.633234, -0.300958, -0.599417, -0.633234, + 0.241993, 1.025464, 0.695378, -0.241993, -1.025464, -0.695378, + 0.236289, 0.907919, 1.012100, -0.236289, -0.907919, -1.012100, + 0.627402, 0.565187, 0.766926, -0.627402, -0.565187, -0.766926, + 0.133276, 0.326284, 0.102804, -0.133276, -0.326284, -0.102804, + 0.426913, 0.256251, 0.305241, -0.426913, -0.256251, -0.305241, + 0.177977, 0.841799, 0.800615, -0.177977, -0.841799, -0.800615, + 0.001991, 0.518389, 0.439322, -0.001991, -0.518389, -0.439322, + 0.166846, 0.508224, 0.486687, -0.166846, -0.508224, -0.486687, + 0.167493, 0.930932, 0.868717, -0.167493, -0.930932, -0.868717, + 0.174864, 0.444607, 0.445000, -0.174864, -0.444607, -0.445000}, + sd::DataType::FLOAT32); + + NDArray expected( + 'c', {1, 5, 5, 6}, + {0.557449, 0.768277, 1.094015, 0., 0., 0., 0.563735, + 0.900299, 0.789979, 0., 0., 0., 0.142528, 0.959611, + 0.877506, 0., 0., 0., 0.448742, 0.995377, 1.171543, + 0., 0., 0., 0.603772, 0.799391, 0.560310, 0., + 0., 0., 0.529753, 0.906786, 0.737630, 0., 0., + 0., 0.221464, 0.824996, 0.472221, 0., 0., 0., + 0.427730, 0.397933, 0.714365, 0., 0., 0., 0.488365, + 1.016589, 0.744197, 0., 0., 0., 0.789846, 0.940837, + 0.838412, 0., 0., 0., 0.404485, 0.677328, 0.754997, + 0., 0., 0., 0.436760, 0.794765, 0.729766, 0., + 0., 0., 0.588081, 0.652226, 0.725522, 0., 0., + 0., 0.374457, 1.225813, 1.053411, 0., 0., 0., + 0.300958, 0.599417, 0.633234, 0., 0., 0., 0.241993, + 1.025464, 0.695378, 0., 0., 0., 0.236289, 0.907919, + 1.012100, 0., 0., 0., 0.627402, 0.565187, 0.766926, + 0., 0., 0., 0.133276, 0.326284, 0.102804, 0., + 0., 0., 0.426913, 0.256251, 0.305241, 0., 0., + 0., 0.177977, 0.841799, 0.800615, 0., 0., 0., + 0.001991, 0.518389, 0.439322, 0., 0., 0., 0.166846, + 0.508224, 0.486687, 0., 0., 0., 0.167493, 0.930932, + 0.868717, 0., 0., 0., 0.174864, 0.444607, 0.445000, + 0., 0., 0.}, + sd::DataType::FLOAT32); + + NDArray z('c', {1, 5, 5, 6}, sd::DataType::FLOAT32); + + sd::ops::relu op; + Nd4jStatus status = op.execute({&input}, {&z}, {0}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.isSameShapeStrict(z)); + ASSERT_TRUE(expected.equalsTo(z)); } #include "ops/declarable/helpers/multiUnique.h" //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, multiUnique_1) { + NDArray input1('c', {3, 5}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + sd::DataType::INT32); + NDArray input2('c', {3, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, + sd::DataType::INT32); + NDArray input3('c', {2, 3}, {10, 11, 12, 13, 14, 15}, sd::DataType::INT32); + NDArray input4('c', {1, 5}, {7, 8, 9, 10, 11}, sd::DataType::INT32); + NDArray input5('c', {5, 3}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + sd::DataType::INT32); - NDArray input1('c', {3,5}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, sd::DataType::INT32); - NDArray input2('c', {3,4}, {1,2,3,4,5,6,7,8,9,10,11,12}, sd::DataType::INT32); - NDArray input3('c', {2,3}, {10,11,12,13,14,15}, sd::DataType::INT32); - NDArray input4('c', {1,5}, {7,8,9,10,11}, sd::DataType::INT32); - NDArray input5('c', {5,3}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, sd::DataType::INT32); - - //NDArray indices('c', {1}, {2}, sd::DataType::INT32); - //NDArray expected('c', {1,5}, {11, 12, 13, 14, 15.}, sd::DataType::FLOAT32); + // NDArray indices('c', {1}, {2}, sd::DataType::INT32); + // NDArray expected('c', {1,5}, {11, 12, 13, 14, 15.}, sd::DataType::FLOAT32); - std::vector arrayList({&input1, &input2, &input3, &input4, &input5}); + std::vector arrayList( + {&input1, &input2, &input3, &input4, &input5}); - ASSERT_FALSE(sd::ops::helpers::multiUnique(arrayList)); + ASSERT_FALSE(sd::ops::helpers::multiUnique(arrayList)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, multiUnique_2) { - - NDArray input1('c', {3,5}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, sd::DataType::INT32); - NDArray input2('c', {3,4}, {21,22,23,24,25,26,27,28,29,210,211,212}, sd::DataType::INT32); - NDArray input3('c', {2,3}, {310,311,312,313,314,315}, sd::DataType::INT32); - NDArray input4('c', {1,5}, {47,48,49,410,411}, sd::DataType::INT32); - NDArray input5('c', {5,3}, {51,52,53,54,55,56,57,58,59,510,511,512,513,514,515}, sd::DataType::INT32); - - //NDArray indices('c', {1}, {2}, sd::DataType::INT32); - //NDArray expected('c', {1,5}, {11, 12, 13, 14, 15.}, sd::DataType::FLOAT32); - - std::vector arrayList({&input1, &input2, &input3, &input4, &input5}); - ASSERT_TRUE(sd::ops::helpers::multiUnique(arrayList)); + NDArray input1('c', {3, 5}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + sd::DataType::INT32); + NDArray input2('c', {3, 4}, + {21, 22, 23, 24, 25, 26, 27, 28, 29, 210, 211, 212}, + sd::DataType::INT32); + NDArray input3('c', {2, 3}, {310, 311, 312, 313, 314, 315}, + sd::DataType::INT32); + NDArray input4('c', {1, 5}, {47, 48, 49, 410, 411}, sd::DataType::INT32); + NDArray input5( + 'c', {5, 3}, + {51, 52, 53, 54, 55, 56, 57, 58, 59, 510, 511, 512, 513, 514, 515}, + sd::DataType::INT32); + + // NDArray indices('c', {1}, {2}, sd::DataType::INT32); + // NDArray expected('c', {1,5}, {11, 12, 13, 14, 15.}, sd::DataType::FLOAT32); + + std::vector arrayList( + {&input1, &input2, &input3, &input4, &input5}); + ASSERT_TRUE(sd::ops::helpers::multiUnique(arrayList)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, reduceMeanBp_4) { + NDArray x('c', {3, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}); + NDArray gradO('c', {5}, sd::DataType::DOUBLE); + NDArray exp('c', {3, 5}, sd::DataType::DOUBLE); - NDArray x('c', {3,5}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}); - NDArray gradO('c', {5}, sd::DataType::DOUBLE); - NDArray exp('c', {3,5}, sd::DataType::DOUBLE); - - gradO = 1.; - exp = 0.333333; - - sd::ops::reduce_mean_bp op; - auto result = op.evaluate({&x, &gradO}, {}, {0}); - auto output = result.at(0); + gradO = 1.; + exp = 0.333333; - // output->printShapeInfo(); - // output->printIndexedBuffer(); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::reduce_mean_bp op; + auto result = op.evaluate({&x, &gradO}, {}, {0}); + auto output = result.at(0); - + // output->printShapeInfo(); + // output->printIndexedBuffer(); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, reduceMeanBp_5) { + NDArray x('c', {3, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}); + NDArray gradO('c', {3}, sd::DataType::DOUBLE); + NDArray exp('c', {3, 5}, sd::DataType::DOUBLE); - NDArray x('c', {3,5}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}); - NDArray gradO('c', {3}, sd::DataType::DOUBLE); - NDArray exp('c', {3,5}, sd::DataType::DOUBLE); + gradO = 1.; + exp = 0.2; - gradO = 1.; - exp = 0.2; + sd::ops::reduce_mean_bp op; + auto result = op.evaluate({&x, &gradO}, {}, {1}); + auto output = result.at(0); - sd::ops::reduce_mean_bp op; - auto result = op.evaluate({&x, &gradO}, {}, {1}); - auto output = result.at(0); - - // output->printShapeInfo(); - // output->printIndexedBuffer(); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + // output->printShapeInfo(); + // output->printIndexedBuffer(); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, reduceSqnormBp_1) { + NDArray x('c', {8, 6, 4}, sd::DataType::DOUBLE); + NDArray gradO('c', {8, 6, 1}, sd::DataType::DOUBLE); - NDArray x('c', {8,6,4}, sd::DataType::DOUBLE); - NDArray gradO('c', {8,6,1}, sd::DataType::DOUBLE); - - sd::ops::reduce_sqnorm_bp op; - auto result = op.evaluate({&x, &gradO}, {1}, {2}); - ASSERT_EQ(Status::OK(), result.status()); - - + sd::ops::reduce_sqnorm_bp op; + auto result = op.evaluate({&x, &gradO}, {1}, {2}); + ASSERT_EQ(Status::OK(), result.status()); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, pullRows_1) { + NDArray x('c', {5, 1}, {0, 1, 2, 3, 4}); + NDArray z('c', {4, 1}, sd::DataType::DOUBLE); + NDArray exp('c', {4, 1}, {0, 2, 3, 4}); - NDArray x('c', {5, 1}, {0,1,2,3,4}); - NDArray z('c', {4, 1}, sd::DataType::DOUBLE); - NDArray exp('c', {4, 1}, {0,2,3,4}); - - Nd4jLong indexes[] = {0,2,3,4}; - PointersManager pm(LaunchContext::defaultContext(), "pullRows"); - auto pidx = reinterpret_cast(pm.replicatePointer(indexes, 4 * sizeof(Nd4jLong))); + Nd4jLong indexes[] = {0, 2, 3, 4}; + PointersManager pm(LaunchContext::defaultContext(), "pullRows"); + auto pidx = reinterpret_cast( + pm.replicatePointer(indexes, 4 * sizeof(Nd4jLong))); - std::vector dims = {1}; + std::vector dims = {1}; - auto xTadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(x.shapeInfo(), dims); - auto zTadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(z.shapeInfo(), dims); + auto xTadPack = sd::ConstantTadHelper::getInstance().tadForDimensions( + x.shapeInfo(), dims); + auto zTadPack = sd::ConstantTadHelper::getInstance().tadForDimensions( + z.shapeInfo(), dims); - Nd4jPointer nativeStart[2]; + Nd4jPointer nativeStart[2]; #ifdef __CUDABLAS__ - nativeStart[1] = (x.getContext()->getCudaStream()); + nativeStart[1] = (x.getContext()->getCudaStream()); #endif - OpaqueDataBuffer xBuf(x.dataBuffer()); - OpaqueDataBuffer zBuf(z.dataBuffer()); - pullRows(nativeStart, &xBuf, x.shapeInfo(), x.specialShapeInfo(), - &zBuf, z.shapeInfo(), z.specialShapeInfo(), - 4, pidx, - xTadPack.platformShapeInfo(), xTadPack.platformOffsets(), - zTadPack.platformShapeInfo(), zTadPack.platformOffsets()); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer zBuf(z.dataBuffer()); + pullRows(nativeStart, &xBuf, x.shapeInfo(), x.specialShapeInfo(), &zBuf, + z.shapeInfo(), z.specialShapeInfo(), 4, pidx, + xTadPack.platformShapeInfo(), xTadPack.platformOffsets(), + zTadPack.platformShapeInfo(), zTadPack.platformOffsets()); - ASSERT_TRUE(z.equalsTo(exp)); - pm.synchronize(); + ASSERT_TRUE(z.equalsTo(exp)); + pm.synchronize(); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, pullRows_2) { + NDArray arr('f', {5, 2}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); + NDArray* y = new NDArray(arr.dup('c')); + NDArray x = (*y)({0, 0, 0, 1}, + true); // view, points on first column of y, shape is {5,1} - NDArray arr('f', {5, 2}, {0,1,2,3,4,5,6,7,8,9}); - NDArray* y = new NDArray(arr.dup('c')); - NDArray x = (*y)({0,0, 0,1}, true); // view, points on first column of y, shape is {5,1} + NDArray z('c', {4, 1}, sd::DataType::DOUBLE); + NDArray exp('c', {4, 1}, {0, 2, 3, 4}); - NDArray z('c', {4, 1}, sd::DataType::DOUBLE); - NDArray exp('c', {4, 1}, {0,2,3,4}); + Nd4jLong indexes[] = {0, 2, 3, 4}; + PointersManager pm(LaunchContext::defaultContext(), "pullRows"); + auto pidx = reinterpret_cast( + pm.replicatePointer(indexes, 4 * sizeof(Nd4jLong))); - Nd4jLong indexes[] = {0,2,3,4}; - PointersManager pm(LaunchContext::defaultContext(), "pullRows"); - auto pidx = reinterpret_cast(pm.replicatePointer(indexes, 4 * sizeof(Nd4jLong))); + std::vector dims = {1}; - std::vector dims = {1}; + auto xTadPack = sd::ConstantTadHelper::getInstance().tadForDimensions( + x.shapeInfo(), dims); + auto zTadPack = sd::ConstantTadHelper::getInstance().tadForDimensions( + z.shapeInfo(), dims); - auto xTadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(x.shapeInfo(), dims); - auto zTadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(z.shapeInfo(), dims); - - Nd4jPointer nativeStart[2]; + Nd4jPointer nativeStart[2]; #ifdef __CUDABLAS__ - nativeStart[1] = (x.getContext()->getCudaStream()); + nativeStart[1] = (x.getContext()->getCudaStream()); #endif - OpaqueDataBuffer xBuf(x.dataBuffer()); - OpaqueDataBuffer zBuf(z.dataBuffer()); - pullRows(nativeStart, &xBuf, x.shapeInfo(), x.specialShapeInfo(), - &zBuf, z.shapeInfo(), z.specialShapeInfo(), - 4, pidx, - xTadPack.platformShapeInfo(), xTadPack.platformOffsets(), - zTadPack.platformShapeInfo(), zTadPack.platformOffsets()); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer zBuf(z.dataBuffer()); + pullRows(nativeStart, &xBuf, x.shapeInfo(), x.specialShapeInfo(), &zBuf, + z.shapeInfo(), z.specialShapeInfo(), 4, pidx, + xTadPack.platformShapeInfo(), xTadPack.platformOffsets(), + zTadPack.platformShapeInfo(), zTadPack.platformOffsets()); - ASSERT_TRUE(z.equalsTo(exp)); - pm.synchronize(); - delete y; + ASSERT_TRUE(z.equalsTo(exp)); + pm.synchronize(); + delete y; } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, softmax_9) { - NDArray arrC('c', {5,2}, {-0.1, 0.2, -0.3, 0.4, -0.5, 0.6, -0.7, 0.8, -0.9, 1}, sd::DataType::FLOAT32); - NDArray* arrF = new NDArray(arrC.dup('f')); - - NDArray outCC('c', {5,2}, sd::DataType::FLOAT32); - NDArray outCF('f', {5,2}, sd::DataType::FLOAT32); - NDArray outFC('c', {5,2}, sd::DataType::FLOAT32); - NDArray outFF('c', {5,2}, sd::DataType::FLOAT32); - - sd::ops::softmax op; - auto status1 = op.execute({&arrC}, {&outCC}, {}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, status1); - auto status2 = op.execute({&arrC}, {&outCF}, {}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, status2); - auto status3 = op.execute({arrF}, {&outFC}, {}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, status3); - auto status4 = op.execute({arrF}, {&outFF}, {}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, status4); - - // outCC.printIndexedBuffer("\n"); - // outCF.printIndexedBuffer("\n"); - // outFC.printIndexedBuffer("\n"); - // outFF.printIndexedBuffer("\n"); - - ASSERT_EQ(outCC, outCF); - ASSERT_EQ(outCC, outFC); - ASSERT_EQ(outCC, outFF); - - delete arrF; + NDArray arrC('c', {5, 2}, + {-0.1, 0.2, -0.3, 0.4, -0.5, 0.6, -0.7, 0.8, -0.9, 1}, + sd::DataType::FLOAT32); + NDArray* arrF = new NDArray(arrC.dup('f')); + + NDArray outCC('c', {5, 2}, sd::DataType::FLOAT32); + NDArray outCF('f', {5, 2}, sd::DataType::FLOAT32); + NDArray outFC('c', {5, 2}, sd::DataType::FLOAT32); + NDArray outFF('c', {5, 2}, sd::DataType::FLOAT32); + + sd::ops::softmax op; + auto status1 = op.execute({&arrC}, {&outCC}, {}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status1); + auto status2 = op.execute({&arrC}, {&outCF}, {}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status2); + auto status3 = op.execute({arrF}, {&outFC}, {}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status3); + auto status4 = op.execute({arrF}, {&outFF}, {}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status4); + + // outCC.printIndexedBuffer("\n"); + // outCF.printIndexedBuffer("\n"); + // outFC.printIndexedBuffer("\n"); + // outFF.printIndexedBuffer("\n"); + + ASSERT_EQ(outCC, outCF); + ASSERT_EQ(outCC, outFC); + ASSERT_EQ(outCC, outFF); + + delete arrF; } TEST_F(DeclarableOpsTests12, maxpool_bp_half_1) { - auto x = NDArrayFactory::create('c', {2, 3, 10, 1}, {0.2019043f, 0.6464844f, 0.9116211f, 0.60058594f, 0.34033203f, 0.7036133f, 0.6772461f, 0.3815918f, 0.87353516f, 0.04650879f, 0.67822266f, 0.8618164f, 0.88378906f, 0.7573242f, 0.66796875f, 0.63427734f, 0.33764648f, 0.46923828f, 0.62939453f, 0.76464844f, -0.8618164f, -0.94873047f, -0.9902344f, -0.88916016f, -0.86572266f, -0.92089844f, -0.90722656f, -0.96533203f, -0.97509766f, -0.4975586f, -0.84814453f, -0.984375f, -0.98828125f, -0.95458984f, -0.9472656f, -0.91064453f, -0.80859375f, -0.83496094f, -0.9140625f, -0.82470703f, 0.4802246f, 0.45361328f, 0.28125f, 0.28320312f, 0.79345703f, 0.44604492f, -0.30273438f, 0.11730957f, 0.56396484f, 0.73583984f, 0.1418457f, -0.44848633f, 0.6923828f, -0.40234375f, 0.40185547f, 0.48632812f, 0.14538574f, 0.4638672f, 0.13000488f, 0.5058594f}); - auto y = NDArrayFactory::create('c', {2, 3, 10, 1}, {0.0f, -0.13391113f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, -0.1751709f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.51904297f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.5107422f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f}); - auto z = NDArrayFactory::create('c', {2, 3, 10, 1}); - - sd::ops::maxpool2d_bp op; - Context ctx(1); - Nd4jLong iArgs[] = {5,1,1, 2,2,0, 1,1,1, 0,0}; - ctx.setIArguments(iArgs, 11); - ctx.setInputArray(0, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo()); - ctx.setInputArray(1, y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo()); - ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo()); - - - auto status = op.execute(&ctx); - ASSERT_EQ(Status::OK(), status); - + auto x = NDArrayFactory::create( + 'c', {2, 3, 10, 1}, + {0.2019043f, 0.6464844f, 0.9116211f, 0.60058594f, 0.34033203f, + 0.7036133f, 0.6772461f, 0.3815918f, 0.87353516f, 0.04650879f, + 0.67822266f, 0.8618164f, 0.88378906f, 0.7573242f, 0.66796875f, + 0.63427734f, 0.33764648f, 0.46923828f, 0.62939453f, 0.76464844f, + -0.8618164f, -0.94873047f, -0.9902344f, -0.88916016f, -0.86572266f, + -0.92089844f, -0.90722656f, -0.96533203f, -0.97509766f, -0.4975586f, + -0.84814453f, -0.984375f, -0.98828125f, -0.95458984f, -0.9472656f, + -0.91064453f, -0.80859375f, -0.83496094f, -0.9140625f, -0.82470703f, + 0.4802246f, 0.45361328f, 0.28125f, 0.28320312f, 0.79345703f, + 0.44604492f, -0.30273438f, 0.11730957f, 0.56396484f, 0.73583984f, + 0.1418457f, -0.44848633f, 0.6923828f, -0.40234375f, 0.40185547f, + 0.48632812f, 0.14538574f, 0.4638672f, 0.13000488f, 0.5058594f}); + auto y = NDArrayFactory::create( + 'c', {2, 3, 10, 1}, + {0.0f, -0.13391113f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, + 0.0f, -0.1751709f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 0.51904297f, 0.0f, 0.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, + 0.0f, 0.0f, 0.5107422f, 0.0f, 0.0f, 0.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 0.0f}); + auto z = NDArrayFactory::create('c', {2, 3, 10, 1}); + + sd::ops::maxpool2d_bp op; + Context ctx(1); + Nd4jLong iArgs[] = {5, 1, 1, 2, 2, 0, 1, 1, 1, 0, 0}; + ctx.setIArguments(iArgs, 11); + ctx.setInputArray(0, x.buffer(), x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo()); + ctx.setInputArray(1, y.buffer(), y.shapeInfo(), y.specialBuffer(), + y.specialShapeInfo()); + ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo()); + + auto status = op.execute(&ctx); + ASSERT_EQ(Status::OK(), status); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, lrn_bp_1) { - - NDArray input('c', {2,3,4,10}); - NDArray gradO('c', {2,3,4,10}); - NDArray exp('c', {2,3,4,10}, {1.00438418e-02, 5.25184907e-03, 1.78685773e-03, -1.14537543e-03, -4.00071684e-03, -5.31899510e-03, -4.97647980e-03, -4.42161644e-03, -3.95395281e-03, -3.59310722e-03, 2.91823584e-04, -2.18498681e-05, -3.12092161e-04, -6.07360795e-04, -9.36298165e-04, - -1.02553482e-03, -7.91735307e-04, -6.15672267e-04, -4.71792649e-04, -3.42114770e-04, 4.29357824e-05, -5.46473675e-05, -1.48361753e-04, -2.47166492e-04, -3.61090642e-04, -3.81607766e-04, -2.89086485e-04, -2.17203109e-04, -1.56231865e-04, -9.91634734e-05, - 8.99407951e-06, -3.76849275e-05, -8.32021178e-05, -1.31939698e-04, -1.89008832e-04, -1.96661276e-04, -1.47534331e-04, -1.08789405e-04, -7.53896020e-05, -4.36357586e-05, - 1.23124300e-06, -2.60028974e-05, -5.27824741e-05, -8.17063192e-05, -1.15871291e-04, -1.19515295e-04, -8.91248055e-05, -6.49499125e-05, -4.39216528e-05, -2.37579407e-05, -9.34046056e-07, -1.87477999e-05, -3.63574763e-05, -5.54830040e-05, -7.82010393e-05, - -8.02115537e-05, -5.95739621e-05, -4.30659420e-05, -2.86241393e-05, -1.47010251e-05, -1.52835810e-06, -1.40790498e-05, -2.65316012e-05, -4.01083526e-05, -5.62983550e-05, -5.75223821e-05, -4.25982689e-05, -3.06141737e-05, -2.00884024e-05, -9.90276021e-06, - -1.61666367e-06, -1.09328157e-05, -2.02010433e-05, -3.03347279e-05, -4.24536738e-05, -4.32532870e-05, -3.19610226e-05, -2.28673853e-05, -1.48570880e-05, -7.08444895e-06, - -1.53552355e-06, -8.72318924e-06, -1.58886232e-05, -2.37402273e-05, -3.31507035e-05, -3.37014644e-05, -2.48602537e-05, -1.77248403e-05, -1.14254890e-05, -5.30027773e-06, -1.40318230e-06, -7.11624580e-06, -1.28209140e-05, -1.90826468e-05, -2.66006646e-05, - -2.69959855e-05, -1.98865000e-05, -1.41387427e-05, -9.05554589e-06, -4.10473058e-06, -1.26330860e-06, -5.91293519e-06, -1.05618501e-05, -1.56718652e-05, -2.18157675e-05, -2.21090413e-05, -1.62681827e-05, -1.15394150e-05, -7.35144840e-06, -3.26711961e-06, - -1.13179840e-06, -4.98940426e-06, -8.85062400e-06, -1.30997241e-05, -1.82144904e-05, -1.84380206e-05, -1.35542105e-05, -9.59566933e-06, -6.08572736e-06, -2.65887866e-06, - -1.01367493e-06, -4.26561428e-06, -7.52358210e-06, -1.11123145e-05, -1.54364170e-05, -1.56106762e-05, -1.14666063e-05, -8.10436813e-06, -5.12021325e-06, -2.20401580e-06, -9.09635219e-07, -3.68808492e-06, -6.47385696e-06, -9.54499774e-06, -1.32485484e-05, - -1.33870126e-05, -9.82651000e-06, -6.93532820e-06, -4.36710525e-06, -1.85539375e-06, -8.18735487e-07, -3.22003825e-06, -5.62928972e-06, -8.28724023e-06, -1.14948289e-05, -1.16066676e-05, -8.51461300e-06, -6.00201292e-06, -3.76846447e-06, -1.58258263e-06, - -7.39498375e-07, -2.83553072e-06, -4.93973403e-06, -7.26259532e-06, -1.00675643e-05, -1.01591886e-05, -7.44886802e-06, -5.24508141e-06, -3.28481428e-06, -1.36524977e-06, - -6.70378654e-07, -2.51585061e-06, -4.36947221e-06, -6.41683391e-06, -8.89049170e-06, -8.96649362e-06, -6.57134478e-06, -4.62275193e-06, -2.88851857e-06, -1.18941352e-06, -6.09944266e-07, -2.24723408e-06, -3.89250545e-06, -5.71062310e-06, -7.90838203e-06, - -7.97212033e-06, -5.84020108e-06, -4.10491293e-06, -2.55976192e-06, -1.04521314e-06, -5.56935277e-07, -2.01937837e-06, -3.48954882e-06, -5.11487451e-06, -7.08044308e-06, -7.13442114e-06, -5.22460778e-06, -3.66942504e-06, -2.28403951e-06, -9.25535005e-07, - -5.10270809e-07, -1.82444705e-06, -3.14605040e-06, -4.60769843e-06, -6.37601988e-06, -6.42213308e-06, -4.70144141e-06, -3.29971408e-06, -2.05053857e-06, -8.25151346e-07, - -4.69036365e-07, -1.65639949e-06, -2.85086708e-06, -4.17237243e-06, -5.77171340e-06, -5.81141694e-06, -4.25308644e-06, -2.98317354e-06, -1.85106614e-06, -7.40148607e-07, -4.32460268e-07, -1.51051631e-06, -2.59534818e-06, -3.79594053e-06, -5.24941379e-06, - -5.28384317e-06, -3.86593183e-06, -2.71007866e-06, -1.67932183e-06, -6.67554332e-07, -3.99893480e-07, -1.38306928e-06, -2.37269478e-06, -3.46823890e-06, -4.79492701e-06, -4.82497671e-06, -3.52932648e-06, -2.47282924e-06, -1.53039912e-06, -6.05077048e-07, - -3.70789934e-07, -1.27108103e-06, -2.17750403e-06, -3.18120783e-06, -4.39700398e-06, -4.42338614e-06, -3.23483960e-06, -2.26541715e-06, -1.40042869e-06, -5.50929371e-07}); - input.linspace(1); - gradO = 1; - - sd::ops::lrn_bp op; - - auto results = op.evaluate({&input, &gradO}, {1., 1., 1}, {5}); - auto gradI = results.at(0); - - ASSERT_EQ(*gradI, exp); - - + NDArray input('c', {2, 3, 4, 10}); + NDArray gradO('c', {2, 3, 4, 10}); + NDArray exp( + 'c', {2, 3, 4, 10}, + {1.00438418e-02, 5.25184907e-03, 1.78685773e-03, -1.14537543e-03, + -4.00071684e-03, -5.31899510e-03, -4.97647980e-03, -4.42161644e-03, + -3.95395281e-03, -3.59310722e-03, 2.91823584e-04, -2.18498681e-05, + -3.12092161e-04, -6.07360795e-04, -9.36298165e-04, -1.02553482e-03, + -7.91735307e-04, -6.15672267e-04, -4.71792649e-04, -3.42114770e-04, + 4.29357824e-05, -5.46473675e-05, -1.48361753e-04, -2.47166492e-04, + -3.61090642e-04, -3.81607766e-04, -2.89086485e-04, -2.17203109e-04, + -1.56231865e-04, -9.91634734e-05, 8.99407951e-06, -3.76849275e-05, + -8.32021178e-05, -1.31939698e-04, -1.89008832e-04, -1.96661276e-04, + -1.47534331e-04, -1.08789405e-04, -7.53896020e-05, -4.36357586e-05, + 1.23124300e-06, -2.60028974e-05, -5.27824741e-05, -8.17063192e-05, + -1.15871291e-04, -1.19515295e-04, -8.91248055e-05, -6.49499125e-05, + -4.39216528e-05, -2.37579407e-05, -9.34046056e-07, -1.87477999e-05, + -3.63574763e-05, -5.54830040e-05, -7.82010393e-05, -8.02115537e-05, + -5.95739621e-05, -4.30659420e-05, -2.86241393e-05, -1.47010251e-05, + -1.52835810e-06, -1.40790498e-05, -2.65316012e-05, -4.01083526e-05, + -5.62983550e-05, -5.75223821e-05, -4.25982689e-05, -3.06141737e-05, + -2.00884024e-05, -9.90276021e-06, -1.61666367e-06, -1.09328157e-05, + -2.02010433e-05, -3.03347279e-05, -4.24536738e-05, -4.32532870e-05, + -3.19610226e-05, -2.28673853e-05, -1.48570880e-05, -7.08444895e-06, + -1.53552355e-06, -8.72318924e-06, -1.58886232e-05, -2.37402273e-05, + -3.31507035e-05, -3.37014644e-05, -2.48602537e-05, -1.77248403e-05, + -1.14254890e-05, -5.30027773e-06, -1.40318230e-06, -7.11624580e-06, + -1.28209140e-05, -1.90826468e-05, -2.66006646e-05, -2.69959855e-05, + -1.98865000e-05, -1.41387427e-05, -9.05554589e-06, -4.10473058e-06, + -1.26330860e-06, -5.91293519e-06, -1.05618501e-05, -1.56718652e-05, + -2.18157675e-05, -2.21090413e-05, -1.62681827e-05, -1.15394150e-05, + -7.35144840e-06, -3.26711961e-06, -1.13179840e-06, -4.98940426e-06, + -8.85062400e-06, -1.30997241e-05, -1.82144904e-05, -1.84380206e-05, + -1.35542105e-05, -9.59566933e-06, -6.08572736e-06, -2.65887866e-06, + -1.01367493e-06, -4.26561428e-06, -7.52358210e-06, -1.11123145e-05, + -1.54364170e-05, -1.56106762e-05, -1.14666063e-05, -8.10436813e-06, + -5.12021325e-06, -2.20401580e-06, -9.09635219e-07, -3.68808492e-06, + -6.47385696e-06, -9.54499774e-06, -1.32485484e-05, -1.33870126e-05, + -9.82651000e-06, -6.93532820e-06, -4.36710525e-06, -1.85539375e-06, + -8.18735487e-07, -3.22003825e-06, -5.62928972e-06, -8.28724023e-06, + -1.14948289e-05, -1.16066676e-05, -8.51461300e-06, -6.00201292e-06, + -3.76846447e-06, -1.58258263e-06, -7.39498375e-07, -2.83553072e-06, + -4.93973403e-06, -7.26259532e-06, -1.00675643e-05, -1.01591886e-05, + -7.44886802e-06, -5.24508141e-06, -3.28481428e-06, -1.36524977e-06, + -6.70378654e-07, -2.51585061e-06, -4.36947221e-06, -6.41683391e-06, + -8.89049170e-06, -8.96649362e-06, -6.57134478e-06, -4.62275193e-06, + -2.88851857e-06, -1.18941352e-06, -6.09944266e-07, -2.24723408e-06, + -3.89250545e-06, -5.71062310e-06, -7.90838203e-06, -7.97212033e-06, + -5.84020108e-06, -4.10491293e-06, -2.55976192e-06, -1.04521314e-06, + -5.56935277e-07, -2.01937837e-06, -3.48954882e-06, -5.11487451e-06, + -7.08044308e-06, -7.13442114e-06, -5.22460778e-06, -3.66942504e-06, + -2.28403951e-06, -9.25535005e-07, -5.10270809e-07, -1.82444705e-06, + -3.14605040e-06, -4.60769843e-06, -6.37601988e-06, -6.42213308e-06, + -4.70144141e-06, -3.29971408e-06, -2.05053857e-06, -8.25151346e-07, + -4.69036365e-07, -1.65639949e-06, -2.85086708e-06, -4.17237243e-06, + -5.77171340e-06, -5.81141694e-06, -4.25308644e-06, -2.98317354e-06, + -1.85106614e-06, -7.40148607e-07, -4.32460268e-07, -1.51051631e-06, + -2.59534818e-06, -3.79594053e-06, -5.24941379e-06, -5.28384317e-06, + -3.86593183e-06, -2.71007866e-06, -1.67932183e-06, -6.67554332e-07, + -3.99893480e-07, -1.38306928e-06, -2.37269478e-06, -3.46823890e-06, + -4.79492701e-06, -4.82497671e-06, -3.52932648e-06, -2.47282924e-06, + -1.53039912e-06, -6.05077048e-07, -3.70789934e-07, -1.27108103e-06, + -2.17750403e-06, -3.18120783e-06, -4.39700398e-06, -4.42338614e-06, + -3.23483960e-06, -2.26541715e-06, -1.40042869e-06, -5.50929371e-07}); + input.linspace(1); + gradO = 1; + + sd::ops::lrn_bp op; + + auto results = op.evaluate({&input, &gradO}, {1., 1., 1}, {5}); + auto gradI = results.at(0); + + ASSERT_EQ(gradI, exp); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, lrn_bp_2) { - - NDArray input('c', {2,3,4,10}); - NDArray gradO('c', {2,3,4,10}); - NDArray exp('c', {2,3,4,10}, {-1.06179598e-03, -2.70050880e-03, -4.02126182e-03, -2.58826977e-03, -2.16024881e-03, -2.20575323e-03, -2.75954953e-03, -4.42477595e-03, -2.89176637e-03, -9.46942251e-04, -1.32603094e-03, -3.34868953e-03, -4.98152524e-03, -3.21313459e-03, -2.68880837e-03, -2.75207381e-03, -3.45109636e-03, -5.54159656e-03, -3.61320702e-03, -1.16457068e-03, - -1.70158676e-03, -4.26037982e-03, -6.33032294e-03, -4.09416296e-03, -3.43742501e-03, -3.52900685e-03, -4.43827361e-03, -7.13911094e-03, -4.64041065e-03, -1.46419462e-03, -2.26016506e-03, -5.59943309e-03, -8.30824208e-03, -5.39253885e-03, -4.54709725e-03, -4.68666852e-03, -5.91615774e-03, -9.53640230e-03, -6.17204653e-03, -1.89000927e-03, - -3.14102764e-03, -7.67878769e-03, -1.13740638e-02, -7.41857197e-03, -6.29213545e-03, -6.51977258e-03, -8.27047508e-03, -1.33656031e-02, -8.59564263e-03, -2.51553906e-03, -4.64272872e-03, -1.11560747e-02, -1.64905936e-02, -1.08321551e-02, -9.26420093e-03, -9.67171416e-03, -1.23506878e-02, -2.00199075e-02, -1.27442302e-02, -3.45497206e-03, - -7.49545777e-03, -1.76018942e-02, -2.59558801e-02, -1.72390267e-02, -1.49321631e-02, -1.57669969e-02, -2.03234926e-02, -3.30405571e-02, -2.06389092e-02, -4.78462130e-03, -1.38390735e-02, -3.14943902e-02, -4.63354364e-02, -3.13667879e-02, -2.77508944e-02, -2.98541505e-02, -3.89749333e-02, -6.32867143e-02, -3.77952419e-02, -5.26650995e-03, - -3.16195861e-02, -6.90807998e-02, -1.01725549e-01, -7.13700354e-02, -6.54785037e-02, -7.25797564e-02, -9.49372798e-02, -1.47399038e-01, -7.21285641e-02, 2.15010419e-02, -8.06625858e-02, -1.79638922e-01, -2.66877055e-01, -1.64447501e-01, -1.00968637e-01, -2.75682062e-02, 1.13596700e-01, 3.32260162e-01, 5.96845448e-01, 8.13161016e-01, - 9.52381015e-01, 8.13161016e-01, 5.96845508e-01, 3.32260162e-01, 1.13596708e-01, -2.75682174e-02, -1.37202948e-01, -2.71326721e-01, -1.84127048e-01, -7.94974267e-02, 3.29870060e-02, -7.39035010e-02, -1.60488203e-01, -1.04997143e-01, -8.06594491e-02, -7.25797564e-02, -7.87955597e-02, -1.11791104e-01, -7.58660138e-02, -3.48676592e-02, - -4.96974029e-03, -4.04525958e-02, -6.82792515e-02, -4.20900472e-02, -3.21968049e-02, -2.98541524e-02, -3.36477235e-02, -4.95737195e-02, -3.37007530e-02, -1.48636252e-02, -4.92655952e-03, -2.17927732e-02, -3.49853337e-02, -2.15152260e-02, -1.66727621e-02, -1.57669988e-02, -1.81730352e-02, -2.73226351e-02, -1.85334161e-02, -7.91355036e-03, - -3.57114570e-03, -1.33136865e-02, -2.09431648e-02, -1.29161589e-02, -1.01064872e-02, -9.67171136e-03, -1.12970043e-02, -1.71830691e-02, -1.16271935e-02, -4.84848116e-03, -2.59314431e-03, -8.91274121e-03, -1.38697922e-02, -8.58002994e-03, -6.75992295e-03, -6.51977304e-03, -7.68158771e-03, -1.17703741e-02, -7.94785097e-03, -3.25604435e-03, - -1.94202550e-03, -6.36530807e-03, -9.84015409e-03, -6.10316684e-03, -4.83274320e-03, -4.68666898e-03, -5.55526093e-03, -8.55536573e-03, -5.76688722e-03, -2.33053416e-03, -1.50016253e-03, -4.76644421e-03, -7.33569637e-03, -4.55961144e-03, -3.62428720e-03, -3.52900638e-03, -4.20164689e-03, -6.49448857e-03, -4.37143166e-03, -1.74761284e-03, - -1.19028054e-03, -3.69978836e-03, -5.67591935e-03, -3.53418733e-03, -2.81759514e-03, -2.75207404e-03, -3.28776496e-03, -5.09600528e-03, -3.42601724e-03, -1.35771628e-03, -9.65878542e-04, -2.95373448e-03, -4.52052988e-03, -2.81889434e-03, -2.25270819e-03, -2.20575323e-03, -2.64216494e-03, -4.10421193e-03, -2.75646802e-03, -1.08450721e-03, - -7.98697409e-04, -2.41194153e-03, -3.68447183e-03, -2.30037421e-03, -1.84193184e-03, -1.80714857e-03, -2.16938392e-03, -3.37567786e-03, -2.26523401e-03, -8.85842834e-04, -6.71049987e-04, -2.00629188e-03, -3.06024216e-03, -1.91263494e-03, -1.53396139e-03, -1.50748459e-03, -1.81288645e-03, -2.82496959e-03, -1.89429161e-03, -7.36965681e-04, - -5.71501616e-04, -1.69480499e-03, -2.58198148e-03, -1.61517004e-03, -1.29717519e-03, -1.27655920e-03, -1.53747783e-03, -2.39865575e-03, -1.60740130e-03, -6.22576685e-04, -4.92433901e-04, -1.45049067e-03, -2.20754091e-03, -1.38200901e-03, -1.11122860e-03, -1.09486456e-03, -1.32032647e-03, -2.06194492e-03, -1.38099224e-03, -5.32818493e-04}); - - input.linspace(-10, 0.1); - gradO = 1; - - sd::ops::lrn_bp op; - - auto results = op.evaluate({&input, &gradO}, {1., 1., 1}, {2}); - auto gradI = results.at(0); - - ASSERT_EQ(*gradI, exp); - - + NDArray input('c', {2, 3, 4, 10}); + NDArray gradO('c', {2, 3, 4, 10}); + NDArray exp( + 'c', {2, 3, 4, 10}, + {-1.06179598e-03, -2.70050880e-03, -4.02126182e-03, -2.58826977e-03, + -2.16024881e-03, -2.20575323e-03, -2.75954953e-03, -4.42477595e-03, + -2.89176637e-03, -9.46942251e-04, -1.32603094e-03, -3.34868953e-03, + -4.98152524e-03, -3.21313459e-03, -2.68880837e-03, -2.75207381e-03, + -3.45109636e-03, -5.54159656e-03, -3.61320702e-03, -1.16457068e-03, + -1.70158676e-03, -4.26037982e-03, -6.33032294e-03, -4.09416296e-03, + -3.43742501e-03, -3.52900685e-03, -4.43827361e-03, -7.13911094e-03, + -4.64041065e-03, -1.46419462e-03, -2.26016506e-03, -5.59943309e-03, + -8.30824208e-03, -5.39253885e-03, -4.54709725e-03, -4.68666852e-03, + -5.91615774e-03, -9.53640230e-03, -6.17204653e-03, -1.89000927e-03, + -3.14102764e-03, -7.67878769e-03, -1.13740638e-02, -7.41857197e-03, + -6.29213545e-03, -6.51977258e-03, -8.27047508e-03, -1.33656031e-02, + -8.59564263e-03, -2.51553906e-03, -4.64272872e-03, -1.11560747e-02, + -1.64905936e-02, -1.08321551e-02, -9.26420093e-03, -9.67171416e-03, + -1.23506878e-02, -2.00199075e-02, -1.27442302e-02, -3.45497206e-03, + -7.49545777e-03, -1.76018942e-02, -2.59558801e-02, -1.72390267e-02, + -1.49321631e-02, -1.57669969e-02, -2.03234926e-02, -3.30405571e-02, + -2.06389092e-02, -4.78462130e-03, -1.38390735e-02, -3.14943902e-02, + -4.63354364e-02, -3.13667879e-02, -2.77508944e-02, -2.98541505e-02, + -3.89749333e-02, -6.32867143e-02, -3.77952419e-02, -5.26650995e-03, + -3.16195861e-02, -6.90807998e-02, -1.01725549e-01, -7.13700354e-02, + -6.54785037e-02, -7.25797564e-02, -9.49372798e-02, -1.47399038e-01, + -7.21285641e-02, 2.15010419e-02, -8.06625858e-02, -1.79638922e-01, + -2.66877055e-01, -1.64447501e-01, -1.00968637e-01, -2.75682062e-02, + 1.13596700e-01, 3.32260162e-01, 5.96845448e-01, 8.13161016e-01, + 9.52381015e-01, 8.13161016e-01, 5.96845508e-01, 3.32260162e-01, + 1.13596708e-01, -2.75682174e-02, -1.37202948e-01, -2.71326721e-01, + -1.84127048e-01, -7.94974267e-02, 3.29870060e-02, -7.39035010e-02, + -1.60488203e-01, -1.04997143e-01, -8.06594491e-02, -7.25797564e-02, + -7.87955597e-02, -1.11791104e-01, -7.58660138e-02, -3.48676592e-02, + -4.96974029e-03, -4.04525958e-02, -6.82792515e-02, -4.20900472e-02, + -3.21968049e-02, -2.98541524e-02, -3.36477235e-02, -4.95737195e-02, + -3.37007530e-02, -1.48636252e-02, -4.92655952e-03, -2.17927732e-02, + -3.49853337e-02, -2.15152260e-02, -1.66727621e-02, -1.57669988e-02, + -1.81730352e-02, -2.73226351e-02, -1.85334161e-02, -7.91355036e-03, + -3.57114570e-03, -1.33136865e-02, -2.09431648e-02, -1.29161589e-02, + -1.01064872e-02, -9.67171136e-03, -1.12970043e-02, -1.71830691e-02, + -1.16271935e-02, -4.84848116e-03, -2.59314431e-03, -8.91274121e-03, + -1.38697922e-02, -8.58002994e-03, -6.75992295e-03, -6.51977304e-03, + -7.68158771e-03, -1.17703741e-02, -7.94785097e-03, -3.25604435e-03, + -1.94202550e-03, -6.36530807e-03, -9.84015409e-03, -6.10316684e-03, + -4.83274320e-03, -4.68666898e-03, -5.55526093e-03, -8.55536573e-03, + -5.76688722e-03, -2.33053416e-03, -1.50016253e-03, -4.76644421e-03, + -7.33569637e-03, -4.55961144e-03, -3.62428720e-03, -3.52900638e-03, + -4.20164689e-03, -6.49448857e-03, -4.37143166e-03, -1.74761284e-03, + -1.19028054e-03, -3.69978836e-03, -5.67591935e-03, -3.53418733e-03, + -2.81759514e-03, -2.75207404e-03, -3.28776496e-03, -5.09600528e-03, + -3.42601724e-03, -1.35771628e-03, -9.65878542e-04, -2.95373448e-03, + -4.52052988e-03, -2.81889434e-03, -2.25270819e-03, -2.20575323e-03, + -2.64216494e-03, -4.10421193e-03, -2.75646802e-03, -1.08450721e-03, + -7.98697409e-04, -2.41194153e-03, -3.68447183e-03, -2.30037421e-03, + -1.84193184e-03, -1.80714857e-03, -2.16938392e-03, -3.37567786e-03, + -2.26523401e-03, -8.85842834e-04, -6.71049987e-04, -2.00629188e-03, + -3.06024216e-03, -1.91263494e-03, -1.53396139e-03, -1.50748459e-03, + -1.81288645e-03, -2.82496959e-03, -1.89429161e-03, -7.36965681e-04, + -5.71501616e-04, -1.69480499e-03, -2.58198148e-03, -1.61517004e-03, + -1.29717519e-03, -1.27655920e-03, -1.53747783e-03, -2.39865575e-03, + -1.60740130e-03, -6.22576685e-04, -4.92433901e-04, -1.45049067e-03, + -2.20754091e-03, -1.38200901e-03, -1.11122860e-03, -1.09486456e-03, + -1.32032647e-03, -2.06194492e-03, -1.38099224e-03, -5.32818493e-04}); + + input.linspace(-10, 0.1); + gradO = 1; + + sd::ops::lrn_bp op; + + auto results = op.evaluate({&input, &gradO}, {1., 1., 1}, {2}); + auto gradI = results.at(0); + + ASSERT_EQ(gradI, exp); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, lrn_bp_3) { - - NDArray input('c', {2,3,4,10}); - NDArray gradO('c', {2,3,4,10}); - NDArray exp('c', {2,3,4,10}, {-6.78180193e-04, -1.06947345e-03, -1.50362519e-03, -1.47711602e-03, -1.45060697e-03, -1.42409769e-03, -1.39758852e-03, -1.37107936e-03, -8.79839936e-04, -4.27795108e-04, -8.62496032e-04, -1.34585891e-03, -1.88281795e-03, -1.84591592e-03, -1.80901436e-03, -1.77211256e-03, -1.73521065e-03, -1.69830909e-03, -1.08184782e-03, -5.13895764e-04, - -1.13227055e-03, -1.74428569e-03, -2.42520543e-03, -2.37169350e-03, -2.31818156e-03, -2.26466986e-03, -2.21115816e-03, -2.15764646e-03, -1.36136822e-03, -6.26647263e-04, -1.54878304e-03, -2.34815548e-03, -3.23930010e-03, -3.15753091e-03, -3.07576265e-03, -2.99399323e-03, -2.91222427e-03, -2.83045508e-03, -1.76287338e-03, -7.75904860e-04, - -2.23870482e-03, -3.32566188e-03, -4.54067392e-03, -4.40674182e-03, -4.27281018e-03, -4.13887901e-03, -4.00494691e-03, -3.87101574e-03, -2.36659218e-03, -9.72117065e-04, -3.49745504e-03, -5.05724549e-03, -6.80746930e-03, -6.56589260e-03, -6.32431870e-03, -6.08274434e-03, -5.84116904e-03, -5.59959421e-03, -3.32604628e-03, -1.21081201e-03, - -6.14068285e-03, -8.55270587e-03, -1.12749329e-02, -1.07723922e-02, -1.02698486e-02, -9.76730697e-03, -9.26476624e-03, -8.76222178e-03, -4.94601438e-03, -1.37539487e-03, -1.30690653e-02, -1.72132626e-02, -2.19351258e-02, -2.06174850e-02, -1.92998387e-02, -1.79821979e-02, -1.66645572e-02, -1.53469117e-02, -7.72346184e-03, -5.22134826e-04, - -3.99478227e-02, -4.78655733e-02, -5.70126995e-02, -5.16961850e-02, -4.63796593e-02, -4.10631336e-02, -3.57466117e-02, -3.04300785e-02, -9.11374856e-03, 1.14024431e-02, -2.35893592e-01, -2.17480078e-01, -1.88097835e-01, -1.38812393e-01, -8.95269737e-02, -4.02415469e-02, 9.04385652e-03, 5.83292767e-02, 1.78530529e-01, 2.96026409e-01, - 4.16666657e-01, 2.79557735e-01, 1.36546940e-01, 7.49502778e-02, 1.33536234e-02, -4.82430384e-02, -1.09839723e-01, -1.71436355e-01, -2.33033031e-01, -2.74476141e-01, 1.54189002e-02, -8.10869783e-03, -3.24862264e-02, -3.88403721e-02, -4.51945364e-02, -5.15486896e-02, -5.79028539e-02, -6.42570183e-02, -5.45457527e-02, -4.61437553e-02, - -2.29711179e-04, -8.06892477e-03, -1.63567103e-02, -1.78351123e-02, -1.93135180e-02, -2.07919199e-02, -2.22703181e-02, -2.37487257e-02, -1.87229179e-02, -1.43175106e-02, -1.37000845e-03, -5.16320160e-03, -9.21433326e-03, -9.76086594e-03, -1.03073996e-02, -1.08539313e-02, -1.14004640e-02, -1.19469995e-02, -9.08647850e-03, -6.55380823e-03, - -1.23490533e-03, -3.45137389e-03, -5.83263952e-03, -6.09064987e-03, -6.34865928e-03, -6.60666777e-03, -6.86467718e-03, -7.12268520e-03, -5.30054048e-03, -3.67741752e-03, -9.94500006e-04, -2.44303374e-03, -4.00528917e-03, -4.14666394e-03, -4.28803731e-03, -4.42941114e-03, -4.57078544e-03, -4.71215881e-03, -3.45545518e-03, -2.33156094e-03, - -7.93270417e-04, -1.81236281e-03, -2.91444198e-03, -3.00004939e-03, -3.08565609e-03, -3.17126350e-03, -3.25687067e-03, -3.34247784e-03, -2.42513884e-03, -1.60246110e-03, -6.39747130e-04, -1.39506557e-03, -2.21352675e-03, -2.26921216e-03, -2.32489733e-03, -2.38058274e-03, -2.43626791e-03, -2.49195332e-03, -1.79354590e-03, -1.16592250e-03, - -5.23828785e-04, -1.10576022e-03, -1.73730974e-03, -1.77553250e-03, -1.81375467e-03, -1.85197743e-03, -1.89020019e-03, -1.92842260e-03, -1.37922564e-03, -8.84913374e-04, -4.35433642e-04, -8.97393096e-04, -1.39935245e-03, -1.42670958e-03, -1.45406683e-03, -1.48142409e-03, -1.50878134e-03, -1.53613824e-03, -1.09309505e-03, -6.93831593e-04, - -3.66991735e-04, -7.42538832e-04, -1.15100679e-03, -1.17125409e-03, -1.19150116e-03, -1.21174823e-03, -1.23199564e-03, -1.25224248e-03, -8.87364266e-04, -5.58210537e-04, -3.13144788e-04, -6.24410110e-04, -9.63238359e-04, -9.78639582e-04, -9.94040747e-04, -1.00944215e-03, -1.02484343e-03, -1.04024459e-03, -7.34565372e-04, -4.58585098e-04, - -2.70129647e-04, -5.32291830e-04, -8.17865424e-04, -8.29851197e-04, -8.41836852e-04, -8.53822567e-04, -8.65808397e-04, -8.77794111e-04, -6.18013146e-04, -3.83307983e-04, -2.35282409e-04, -4.59096394e-04, -7.03040219e-04, -7.12549896e-04, -7.22059398e-04, -7.31569016e-04, -7.41078693e-04, -7.50588137e-04, -5.27105702e-04, -3.25074652e-04}); - - input.linspace(-10, 0.1); - gradO = 1; - - sd::ops::lrn_bp op; - - auto results = op.evaluate({&input, &gradO}, {1., 1., 1}, {7}); - auto gradI = results.at(0); - - ASSERT_EQ(*gradI, exp); - - + NDArray input('c', {2, 3, 4, 10}); + NDArray gradO('c', {2, 3, 4, 10}); + NDArray exp( + 'c', {2, 3, 4, 10}, + {-6.78180193e-04, -1.06947345e-03, -1.50362519e-03, -1.47711602e-03, + -1.45060697e-03, -1.42409769e-03, -1.39758852e-03, -1.37107936e-03, + -8.79839936e-04, -4.27795108e-04, -8.62496032e-04, -1.34585891e-03, + -1.88281795e-03, -1.84591592e-03, -1.80901436e-03, -1.77211256e-03, + -1.73521065e-03, -1.69830909e-03, -1.08184782e-03, -5.13895764e-04, + -1.13227055e-03, -1.74428569e-03, -2.42520543e-03, -2.37169350e-03, + -2.31818156e-03, -2.26466986e-03, -2.21115816e-03, -2.15764646e-03, + -1.36136822e-03, -6.26647263e-04, -1.54878304e-03, -2.34815548e-03, + -3.23930010e-03, -3.15753091e-03, -3.07576265e-03, -2.99399323e-03, + -2.91222427e-03, -2.83045508e-03, -1.76287338e-03, -7.75904860e-04, + -2.23870482e-03, -3.32566188e-03, -4.54067392e-03, -4.40674182e-03, + -4.27281018e-03, -4.13887901e-03, -4.00494691e-03, -3.87101574e-03, + -2.36659218e-03, -9.72117065e-04, -3.49745504e-03, -5.05724549e-03, + -6.80746930e-03, -6.56589260e-03, -6.32431870e-03, -6.08274434e-03, + -5.84116904e-03, -5.59959421e-03, -3.32604628e-03, -1.21081201e-03, + -6.14068285e-03, -8.55270587e-03, -1.12749329e-02, -1.07723922e-02, + -1.02698486e-02, -9.76730697e-03, -9.26476624e-03, -8.76222178e-03, + -4.94601438e-03, -1.37539487e-03, -1.30690653e-02, -1.72132626e-02, + -2.19351258e-02, -2.06174850e-02, -1.92998387e-02, -1.79821979e-02, + -1.66645572e-02, -1.53469117e-02, -7.72346184e-03, -5.22134826e-04, + -3.99478227e-02, -4.78655733e-02, -5.70126995e-02, -5.16961850e-02, + -4.63796593e-02, -4.10631336e-02, -3.57466117e-02, -3.04300785e-02, + -9.11374856e-03, 1.14024431e-02, -2.35893592e-01, -2.17480078e-01, + -1.88097835e-01, -1.38812393e-01, -8.95269737e-02, -4.02415469e-02, + 9.04385652e-03, 5.83292767e-02, 1.78530529e-01, 2.96026409e-01, + 4.16666657e-01, 2.79557735e-01, 1.36546940e-01, 7.49502778e-02, + 1.33536234e-02, -4.82430384e-02, -1.09839723e-01, -1.71436355e-01, + -2.33033031e-01, -2.74476141e-01, 1.54189002e-02, -8.10869783e-03, + -3.24862264e-02, -3.88403721e-02, -4.51945364e-02, -5.15486896e-02, + -5.79028539e-02, -6.42570183e-02, -5.45457527e-02, -4.61437553e-02, + -2.29711179e-04, -8.06892477e-03, -1.63567103e-02, -1.78351123e-02, + -1.93135180e-02, -2.07919199e-02, -2.22703181e-02, -2.37487257e-02, + -1.87229179e-02, -1.43175106e-02, -1.37000845e-03, -5.16320160e-03, + -9.21433326e-03, -9.76086594e-03, -1.03073996e-02, -1.08539313e-02, + -1.14004640e-02, -1.19469995e-02, -9.08647850e-03, -6.55380823e-03, + -1.23490533e-03, -3.45137389e-03, -5.83263952e-03, -6.09064987e-03, + -6.34865928e-03, -6.60666777e-03, -6.86467718e-03, -7.12268520e-03, + -5.30054048e-03, -3.67741752e-03, -9.94500006e-04, -2.44303374e-03, + -4.00528917e-03, -4.14666394e-03, -4.28803731e-03, -4.42941114e-03, + -4.57078544e-03, -4.71215881e-03, -3.45545518e-03, -2.33156094e-03, + -7.93270417e-04, -1.81236281e-03, -2.91444198e-03, -3.00004939e-03, + -3.08565609e-03, -3.17126350e-03, -3.25687067e-03, -3.34247784e-03, + -2.42513884e-03, -1.60246110e-03, -6.39747130e-04, -1.39506557e-03, + -2.21352675e-03, -2.26921216e-03, -2.32489733e-03, -2.38058274e-03, + -2.43626791e-03, -2.49195332e-03, -1.79354590e-03, -1.16592250e-03, + -5.23828785e-04, -1.10576022e-03, -1.73730974e-03, -1.77553250e-03, + -1.81375467e-03, -1.85197743e-03, -1.89020019e-03, -1.92842260e-03, + -1.37922564e-03, -8.84913374e-04, -4.35433642e-04, -8.97393096e-04, + -1.39935245e-03, -1.42670958e-03, -1.45406683e-03, -1.48142409e-03, + -1.50878134e-03, -1.53613824e-03, -1.09309505e-03, -6.93831593e-04, + -3.66991735e-04, -7.42538832e-04, -1.15100679e-03, -1.17125409e-03, + -1.19150116e-03, -1.21174823e-03, -1.23199564e-03, -1.25224248e-03, + -8.87364266e-04, -5.58210537e-04, -3.13144788e-04, -6.24410110e-04, + -9.63238359e-04, -9.78639582e-04, -9.94040747e-04, -1.00944215e-03, + -1.02484343e-03, -1.04024459e-03, -7.34565372e-04, -4.58585098e-04, + -2.70129647e-04, -5.32291830e-04, -8.17865424e-04, -8.29851197e-04, + -8.41836852e-04, -8.53822567e-04, -8.65808397e-04, -8.77794111e-04, + -6.18013146e-04, -3.83307983e-04, -2.35282409e-04, -4.59096394e-04, + -7.03040219e-04, -7.12549896e-04, -7.22059398e-04, -7.31569016e-04, + -7.41078693e-04, -7.50588137e-04, -5.27105702e-04, -3.25074652e-04}); + + input.linspace(-10, 0.1); + gradO = 1; + + sd::ops::lrn_bp op; + + auto results = op.evaluate({&input, &gradO}, {1., 1., 1}, {7}); + auto gradI = results.at(0); + + ASSERT_EQ(gradI, exp); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, lrn_bp_4) { - - NDArray input('c', {2,3,4,10}); - NDArray gradO('c', {2,3,4,10}); - NDArray exp('c', {2,3,4,10}, {-0.00119282, -0.00116995, -0.00114708, -0.00112421, -0.00110134, -0.00107847, -0.00105559, -0.00103272, -0.00100985, -0.00098698, -0.00150102, -0.00146918, -0.00143734, -0.0014055 , -0.00137366, -0.00134182, -0.00130998, -0.00127814, -0.0012463 , -0.00121446, - -0.00194534,-0.00189916, -0.00185299, -0.00180681, -0.00176064, -0.00171446, -0.00166829, -0.00162211, -0.00157593, -0.00152976, -0.0026189 , -0.00254833, -0.00247776, -0.00240719, -0.00233662, -0.00226605, -0.00219548, -0.00212491, -0.00205434, -0.00198377, - -0.00370962, -0.00359401, -0.00347839, -0.00336277, -0.00324716, -0.00313154, -0.00301593, -0.00290031, -0.00278469, -0.00266908, -0.00564327, -0.00543464, -0.00522602, -0.00501739, -0.00480876, -0.00460013, -0.0043915 , -0.00418288, -0.00397425, -0.00376562, - -0.00955302, -0.00911865, -0.00868428, -0.00824992, -0.00781555, -0.00738118, -0.00694682, -0.00651245, -0.00607808, -0.00564371, -0.01927758, -0.01813637, -0.01699515, -0.01585394, -0.01471272, -0.01357151, -0.01243029, -0.01128908, -0.01014786, -0.00900664, - -0.05409876, -0.04945958, -0.04482041, -0.04018124, -0.03554206, -0.03090289, -0.02626371, -0.02162454, -0.01698537, -0.01234619, -0.26145172, -0.214688 , -0.16792431, -0.12116055, -0.07439683, -0.02763309, 0.01913062, 0.06589434, 0.11265809, 0.15942183, - 0.25974026, 0.19902176, 0.13830325, 0.07758474, 0.01686624, -0.04385226, -0.10457078, -0.16528927, -0.22600779, -0.2867263 , -0.01177884, -0.0173331 , -0.02288735, -0.02844159, -0.03399584, -0.0395501 , -0.04510435, -0.05065861, -0.05621284, -0.0617671 , - -0.00944993, -0.01073084, -0.01201174, -0.01329265, -0.01457355, -0.01585446, -0.01713536, -0.01841627, -0.01969717, -0.02097807, -0.00589878, -0.00637122, -0.00684368, -0.00731612, -0.00778858, -0.00826102, -0.00873347, -0.00920592, -0.00967837, -0.01015082, - -0.00390961, -0.00413245, -0.00435528, -0.00457812, -0.00480095, -0.00502378, -0.00524662, -0.00546945, -0.00569229, -0.00591512, -0.00275609, -0.00287813, -0.00300018, -0.00312222, -0.00324427, -0.00336631, -0.00348836, -0.0036104 , -0.00373245, -0.00385449, - -0.00203982, -0.00211371, -0.00218759, -0.00226147, -0.00233536, -0.00240924, -0.00248312, -0.00255701, -0.00263089, -0.00270478, -0.00156781, -0.00161586, -0.00166391, -0.00171197, -0.00176002, -0.00180807, -0.00185612, -0.00190417, -0.00195223, -0.00200028, - -0.00124141, -0.00127439, -0.00130737, -0.00134035, -0.00137333, -0.00140631, -0.00143929, -0.00147227, -0.00150525, -0.00153822, -0.00100674, -0.00103034, -0.00105394, -0.00107754, -0.00110115, -0.00112475, -0.00114835, -0.00117195, -0.00119556, -0.00121916, - -0.00083255, -0.00085002, -0.00086748, -0.00088495, -0.00090242, -0.00091989, -0.00093735, -0.00095482, -0.00097229, -0.00098976, -0.0006998 , -0.00071308, -0.00072637, -0.00073965, -0.00075294, -0.00076623, -0.00077951, -0.0007928 , -0.00080609, -0.00081937, - -0.00059635, -0.00060669, -0.00061703, -0.00062737, -0.00063771, -0.00064805, -0.00065839, -0.00066873, -0.00067906, -0.0006894 , -0.0005142 , -0.0005224 , -0.00053061, -0.00053881, -0.00054701, -0.00055522, -0.00056342, -0.00057162, -0.00057983, -0.00058803}); - - input.linspace(-10, 0.1); - gradO = 1; - - sd::ops::lrn_bp op; - - auto results = op.evaluate({&input, &gradO}, {1., 1., 1}, {12}); - auto gradI = results.at(0); - - ASSERT_EQ(*gradI, exp); - - + NDArray input('c', {2, 3, 4, 10}); + NDArray gradO('c', {2, 3, 4, 10}); + NDArray exp( + 'c', {2, 3, 4, 10}, + {-0.00119282, -0.00116995, -0.00114708, -0.00112421, -0.00110134, + -0.00107847, -0.00105559, -0.00103272, -0.00100985, -0.00098698, + -0.00150102, -0.00146918, -0.00143734, -0.0014055, -0.00137366, + -0.00134182, -0.00130998, -0.00127814, -0.0012463, -0.00121446, + -0.00194534, -0.00189916, -0.00185299, -0.00180681, -0.00176064, + -0.00171446, -0.00166829, -0.00162211, -0.00157593, -0.00152976, + -0.0026189, -0.00254833, -0.00247776, -0.00240719, -0.00233662, + -0.00226605, -0.00219548, -0.00212491, -0.00205434, -0.00198377, + -0.00370962, -0.00359401, -0.00347839, -0.00336277, -0.00324716, + -0.00313154, -0.00301593, -0.00290031, -0.00278469, -0.00266908, + -0.00564327, -0.00543464, -0.00522602, -0.00501739, -0.00480876, + -0.00460013, -0.0043915, -0.00418288, -0.00397425, -0.00376562, + -0.00955302, -0.00911865, -0.00868428, -0.00824992, -0.00781555, + -0.00738118, -0.00694682, -0.00651245, -0.00607808, -0.00564371, + -0.01927758, -0.01813637, -0.01699515, -0.01585394, -0.01471272, + -0.01357151, -0.01243029, -0.01128908, -0.01014786, -0.00900664, + -0.05409876, -0.04945958, -0.04482041, -0.04018124, -0.03554206, + -0.03090289, -0.02626371, -0.02162454, -0.01698537, -0.01234619, + -0.26145172, -0.214688, -0.16792431, -0.12116055, -0.07439683, + -0.02763309, 0.01913062, 0.06589434, 0.11265809, 0.15942183, + 0.25974026, 0.19902176, 0.13830325, 0.07758474, 0.01686624, + -0.04385226, -0.10457078, -0.16528927, -0.22600779, -0.2867263, + -0.01177884, -0.0173331, -0.02288735, -0.02844159, -0.03399584, + -0.0395501, -0.04510435, -0.05065861, -0.05621284, -0.0617671, + -0.00944993, -0.01073084, -0.01201174, -0.01329265, -0.01457355, + -0.01585446, -0.01713536, -0.01841627, -0.01969717, -0.02097807, + -0.00589878, -0.00637122, -0.00684368, -0.00731612, -0.00778858, + -0.00826102, -0.00873347, -0.00920592, -0.00967837, -0.01015082, + -0.00390961, -0.00413245, -0.00435528, -0.00457812, -0.00480095, + -0.00502378, -0.00524662, -0.00546945, -0.00569229, -0.00591512, + -0.00275609, -0.00287813, -0.00300018, -0.00312222, -0.00324427, + -0.00336631, -0.00348836, -0.0036104, -0.00373245, -0.00385449, + -0.00203982, -0.00211371, -0.00218759, -0.00226147, -0.00233536, + -0.00240924, -0.00248312, -0.00255701, -0.00263089, -0.00270478, + -0.00156781, -0.00161586, -0.00166391, -0.00171197, -0.00176002, + -0.00180807, -0.00185612, -0.00190417, -0.00195223, -0.00200028, + -0.00124141, -0.00127439, -0.00130737, -0.00134035, -0.00137333, + -0.00140631, -0.00143929, -0.00147227, -0.00150525, -0.00153822, + -0.00100674, -0.00103034, -0.00105394, -0.00107754, -0.00110115, + -0.00112475, -0.00114835, -0.00117195, -0.00119556, -0.00121916, + -0.00083255, -0.00085002, -0.00086748, -0.00088495, -0.00090242, + -0.00091989, -0.00093735, -0.00095482, -0.00097229, -0.00098976, + -0.0006998, -0.00071308, -0.00072637, -0.00073965, -0.00075294, + -0.00076623, -0.00077951, -0.0007928, -0.00080609, -0.00081937, + -0.00059635, -0.00060669, -0.00061703, -0.00062737, -0.00063771, + -0.00064805, -0.00065839, -0.00066873, -0.00067906, -0.0006894, + -0.0005142, -0.0005224, -0.00053061, -0.00053881, -0.00054701, + -0.00055522, -0.00056342, -0.00057162, -0.00057983, -0.00058803}); + + input.linspace(-10, 0.1); + gradO = 1; + + sd::ops::lrn_bp op; + + auto results = op.evaluate({&input, &gradO}, {1., 1., 1}, {12}); + auto gradI = results.at(0); + + ASSERT_EQ(gradI, exp); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, lrn_bp_5) { - - NDArray input('c', {2,2,2,5}); - NDArray gradO('c', {2,2,2,5}); - NDArray exp('c', {2,2,2,5}, {6.2497472e-03, -3.4008762e-03, -1.5232352e-02, 2.3018382e-04, 1.3257053e-02, 7.1492628e-03, -5.4330104e-03, -2.0878183e-02, 1.5153568e-03, 2.0571884e-02, - 6.7926152e-03, -1.0990440e-02, -3.2685306e-02, 7.2436016e-03, 4.2120241e-02, -1.3439789e-02, -3.4284033e-02, -4.4852167e-02, 8.8073254e-02, 2.2223940e-01, - 4.0824831e-01, 2.1201703e-01, 3.8555145e-02, -3.1969927e-02, -3.0673094e-02, 5.2034661e-02, 1.0463811e-02, -3.6619946e-02, -1.3280880e-02, 5.9767403e-03, - 2.3028374e-02, 2.0452859e-03, -2.2533152e-02, -6.1039329e-03, 7.2805062e-03, 1.4290780e-02, 3.8017845e-04, -1.6107092e-02,-3.6896234e-03, 6.4357026e-03}); - input.linspace(-20, 1); - // gradO.linspace(0.1, 0.1); - gradO = 1; - - sd::ops::lrn_bp op; - - auto results = op.evaluate({&input, &gradO}, {1., 1., 0.5}, {2}); - auto gradI = results.at(0); - - ASSERT_EQ(*gradI, exp); - - + NDArray input('c', {2, 2, 2, 5}); + NDArray gradO('c', {2, 2, 2, 5}); + NDArray exp('c', {2, 2, 2, 5}, + {6.2497472e-03, -3.4008762e-03, -1.5232352e-02, 2.3018382e-04, + 1.3257053e-02, 7.1492628e-03, -5.4330104e-03, -2.0878183e-02, + 1.5153568e-03, 2.0571884e-02, 6.7926152e-03, -1.0990440e-02, + -3.2685306e-02, 7.2436016e-03, 4.2120241e-02, -1.3439789e-02, + -3.4284033e-02, -4.4852167e-02, 8.8073254e-02, 2.2223940e-01, + 4.0824831e-01, 2.1201703e-01, 3.8555145e-02, -3.1969927e-02, + -3.0673094e-02, 5.2034661e-02, 1.0463811e-02, -3.6619946e-02, + -1.3280880e-02, 5.9767403e-03, 2.3028374e-02, 2.0452859e-03, + -2.2533152e-02, -6.1039329e-03, 7.2805062e-03, 1.4290780e-02, + 3.8017845e-04, -1.6107092e-02, -3.6896234e-03, 6.4357026e-03}); + input.linspace(-20, 1); + // gradO.linspace(0.1, 0.1); + gradO = 1; + + sd::ops::lrn_bp op; + + auto results = op.evaluate({&input, &gradO}, {1., 1., 0.5}, {2}); + auto gradI = results.at(0); + + ASSERT_EQ(gradI, exp); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, lrn_bp_6) { + NDArray input('c', {1, 1, 1, 5}, {1, 2., 3, 4, 5}); + NDArray gradO('c', {1, 1, 1, 5}); + NDArray exp('c', {1, 1, 1, 5}, + {0.06926288, 0.04360996, 0.01795704, -0.00769587, -0.0333488}); + // gradO.linspace(-1.5, 0.1); + gradO = 1; - NDArray input('c', {1,1,1,5}, {1, 2., 3, 4, 5}); - NDArray gradO('c', {1,1,1,5}); - NDArray exp('c', {1,1,1,5}, {0.06926288, 0.04360996, 0.01795704, -0.00769587, -0.0333488}); - // gradO.linspace(-1.5, 0.1); - gradO = 1; - - sd::ops::lrn_bp op; + sd::ops::lrn_bp op; - auto results = op.evaluate({&input, &gradO}, {1., 2., 0.5}, {10}); - auto gradI = results.at(0); + auto results = op.evaluate({&input, &gradO}, {1., 2., 0.5}, {10}); + auto gradI = results.at(0); - ASSERT_EQ(*gradI, exp); - - + ASSERT_EQ(gradI, exp); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, lrn_bp_7) { + NDArray input('c', {2, 2, 2, 5}); + NDArray gradO('c', {2, 2, 2, 5}); - NDArray input('c', {2,2,2,5}); - NDArray gradO('c', {2,2,2,5}); - - input.linspace(-20, 1); - gradO.linspace(-1.5, 0.1); + input.linspace(-20, 1); + gradO.linspace(-1.5, 0.1); - const OpArgsHolder argsHolderFF({&input}, {1,2,0.5}, {2}); - const OpArgsHolder argsHolderBP({&input, &gradO}, {1,2,0.5}, {2}); + const OpArgsHolder argsHolderFF({&input}, {1, 2, 0.5}, {2}); + const OpArgsHolder argsHolderBP({&input, &gradO}, {1, 2, 0.5}, {2}); - sd::ops::lrn opFF; - sd::ops::lrn_bp opBP; + sd::ops::lrn opFF; + sd::ops::lrn_bp opBP; - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); + const bool isGradCorrect = + GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); - ASSERT_TRUE(isGradCorrect); + ASSERT_TRUE(isGradCorrect); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, lrn_bp_8) { + NDArray input('c', {1, 1, 1, 5}, {1, 2, 3, 4, 5}); + NDArray gradO('c', {1, 1, 1, 5}, {2, 3, 4, 5, 6}); - NDArray input('c', {1,1,1,5}, {1, 2, 3, 4, 5}); - NDArray gradO('c', {1,1,1,5}, {2, 3, 4, 5, 6}); + const OpArgsHolder argsHolderFF({&input}, {1, 2, 0.5}, {2}); + const OpArgsHolder argsHolderBP({&input, &gradO}, {1, 2, 0.5}, {2}); - const OpArgsHolder argsHolderFF({&input}, {1,2,0.5}, {2}); - const OpArgsHolder argsHolderBP({&input, &gradO}, {1,2,0.5}, {2}); + sd::ops::lrn opFF; + sd::ops::lrn_bp opBP; - sd::ops::lrn opFF; - sd::ops::lrn_bp opBP; + const bool isGradCorrect = + GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); - - ASSERT_TRUE(isGradCorrect); + ASSERT_TRUE(isGradCorrect); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, lrn_bp_9) { + NDArray input('c', {1, 1, 1, 5}, {1, 2, 3, 4, 5}); + NDArray gradO('c', {1, 1, 1, 5}, {1, 1, 1, 1, 1}); + NDArray exp('c', {1, 1, 1, 5}, + {0.1084472, 0.03816165, 0.00978456, -0.01859251, -0.02511311}); - NDArray input('c', {1,1,1,5}, {1,2,3,4,5}); - NDArray gradO('c', {1,1,1,5}, {1, 1, 1, 1, 1}); - NDArray exp('c', {1,1,1,5}, {0.1084472 , 0.03816165, 0.00978456, -0.01859251,-0.02511311}); - - sd::ops::lrn_bp op; - - auto results = op.evaluate({&input, &gradO}, {1., 2., 0.5}, {3}); - auto gradI = results.at(0); + sd::ops::lrn_bp op; - // for (int i = 0; i < exp.lengthOf(); ++i) - // printf("%10.5f %10.5f\n", exp.e(i), gradI->e(i)); + auto results = op.evaluate({&input, &gradO}, {1., 2., 0.5}, {3}); + auto gradI = results.at(0); - ASSERT_EQ(*gradI, exp); + // for (int i = 0; i < exp.lengthOf(); ++i) + // printf("%10.5f %10.5f\n", exp.e(i), gradI->e(i)); - + ASSERT_EQ(gradI, exp); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, lrn_bp_10) { + NDArray input('c', {1, 1, 1, 1}, std::vector{1}); + NDArray gradO('c', {1, 1, 1, 1}, std::vector{1}); + NDArray exp('c', {1, 1, 1, 1}, std::vector{0.19245008}); - NDArray input('c', {1,1,1,1}, std::vector{1}); - NDArray gradO('c', {1,1,1,1}, std::vector{1}); - NDArray exp('c', {1,1,1,1}, std::vector{0.19245008}); + sd::ops::lrn_bp op; - sd::ops::lrn_bp op; + auto results = op.evaluate({&input, &gradO}, {1., 2., 0.5}, {1}); + auto gradI = results.at(0); - auto results = op.evaluate({&input, &gradO}, {1., 2., 0.5}, {1}); - auto gradI = results.at(0); - - ASSERT_EQ(*gradI, exp); - - + ASSERT_EQ(gradI, exp); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, lrn_1) { + NDArray input('c', {2, 2, 2, 5}); + NDArray exp('c', {2, 2, 2, 5}, + {-0.42923987, -0.3623817, -0.3152079, -0.34268343, -0.3836809, + -0.43648192, -0.3652726, -0.31428117, -0.3379276, -0.3731494, + -0.45129365, -0.37083852, -0.3111639, -0.3260225, -0.34698898, + -0.4975186, -0.3831305, -0.2847474, -0.25607377, -0.18569534, + 0., 0.18569534, 0.25607377, 0.38411066, 0.52075565, + 0.33633637, 0.32117262, 0.30966178, 0.37259716, 0.45631808, + 0.36986336, 0.33643705, 0.31394684, 0.36608824, 0.43857202, + 0.3821113, 0.34197718, 0.31508508, 0.36284128, 0.4303756}); - NDArray input('c', {2,2,2,5}); - NDArray exp('c', {2,2,2,5}, {-0.42923987, -0.3623817 , -0.3152079 , -0.34268343, -0.3836809, -0.43648192, -0.3652726 , -0.31428117, -0.3379276 , -0.3731494 , - -0.45129365, -0.37083852, -0.3111639 , -0.3260225 , -0.34698898, -0.4975186 , -0.3831305 , -0.2847474 , -0.25607377, -0.18569534, - 0., 0.18569534, 0.25607377, 0.38411066, 0.52075565,0.33633637, 0.32117262, 0.30966178, 0.37259716, 0.45631808, - 0.36986336, 0.33643705, 0.31394684, 0.36608824, 0.43857202, 0.3821113 , 0.34197718, 0.31508508, 0.36284128, 0.4303756 }); - - input.linspace(-20, 1); - - sd::ops::lrn op; + input.linspace(-20, 1); - auto results = op.evaluate({&input}, {1., 2., 0.5}, {2}); - auto output = results.at(0); + sd::ops::lrn op; - ASSERT_EQ(*output, exp); + auto results = op.evaluate({&input}, {1., 2., 0.5}, {2}); + auto output = results.at(0); - + ASSERT_EQ(output, exp); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, lrn_2) { + NDArray input('c', {1, 1, 1, 5}, {1, 2., 3, 4, 5}); + NDArray exp('c', {1, 1, 1, 5}, + {0.09530295, 0.1906059, 0.28590885, 0.3812118, 0.47651473}); - NDArray input('c', {1,1,1,5}, {1, 2., 3, 4, 5}); - NDArray exp('c', {1,1,1,5}, {0.09530295, 0.1906059 , 0.28590885, 0.3812118 , 0.47651473}); + sd::ops::lrn op; - sd::ops::lrn op; - - auto results = op.evaluate({&input}, {0.1, 2., 0.5}, {5}); - auto output = results.at(0); - ASSERT_EQ(*output, exp); - - + auto results = op.evaluate({&input}, {0.1, 2., 0.5}, {5}); + auto output = results.at(0); + ASSERT_EQ(output, exp); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, lrn_3) { + NDArray input('c', {1, 1, 1, 1}, std::vector{1.}); + NDArray exp('c', {1, 1, 1, 1}, std::vector{0.69006556}); - NDArray input('c', {1,1,1,1}, std::vector{1.}); - NDArray exp('c', {1,1,1,1}, std::vector{0.69006556}); - - sd::ops::lrn op; + sd::ops::lrn op; - auto results = op.evaluate({&input}, {0.1, 2., 0.5}, {5}); - auto output = results.at(0); - ASSERT_EQ(*output, exp); - - + auto results = op.evaluate({&input}, {0.1, 2., 0.5}, {5}); + auto output = results.at(0); + ASSERT_EQ(output, exp); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, lrn_4) { + NDArray input('c', {1, 1, 1, 1}, std::vector{1.}); + NDArray exp('c', {1, 1, 1, 1}, std::vector{0.69006556}); - NDArray input('c', {1,1,1,1}, std::vector{1.}); - NDArray exp('c', {1,1,1,1}, std::vector{0.69006556}); - - sd::ops::lrn op; + sd::ops::lrn op; - auto results = op.evaluate({&input}, {0.1, 2., 0.5}, {0}); - auto output = results.at(0); - ASSERT_EQ(*output, exp); - - + auto results = op.evaluate({&input}, {0.1, 2., 0.5}, {0}); + auto output = results.at(0); + ASSERT_EQ(output, exp); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, lrn_5) { + NDArray input('c', {1, 1, 1, 5}, {1, 2., 3, 4, 5}); + NDArray exp('c', {1, 1, 1, 5}, + {0.69006556, 0.70272833, 0.7051508, 0.7060045, 0.7064008}); - NDArray input('c', {1,1,1,5}, {1, 2., 3, 4, 5}); - NDArray exp('c', {1,1,1,5}, {0.69006556, 0.70272833, 0.7051508 , 0.7060045 , 0.7064008}); - - sd::ops::lrn op; + sd::ops::lrn op; - auto results = op.evaluate({&input}, {0.1, 2., 0.5}, {0}); - auto output = results.at(0); - ASSERT_EQ(*output, exp); - - + auto results = op.evaluate({&input}, {0.1, 2., 0.5}, {0}); + auto output = results.at(0); + ASSERT_EQ(output, exp); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, inTopK_1) { + NDArray x('c', {4, 5}, + {11.0, 14.0, 6.0, 9.0, 3.5, 7.0, 21.0, 3.0, 15.0, 6.0, + 9.0, 3.5, 7.0, 11.0, 13.0, 5.0, 16.0, 9.0, 13.5, 7.0}); + NDArray y('c', {4}, {0., 0, 0, 0}, sd::DataType::INT64); + NDArray z('c', {4}, {1., 1, 1, 1}, sd::DataType::BOOL); - NDArray x('c', {4, 5}, {11.0, 14.0, 6.0, 9.0, 3.5, 7.0, 21.0, 3.0, 15.0, 6.0, 9.0, 3.5, 7.0, 11.0, 13.0, 5.0, 16.0, 9.0, 13.5, 7.0}); - NDArray y('c', {4}, {0., 0, 0, 0}, sd::DataType::INT64); - NDArray z('c', {4}, {1., 1, 1, 1}, sd::DataType::BOOL); - - NDArray expV('c', {4}, {1., 0, 0, 0}, sd::DataType::BOOL); + NDArray expV('c', {4}, {1., 0, 0, 0}, sd::DataType::BOOL); - sd::ops::in_top_k op; - Nd4jStatus status = op.execute({&x, &y, }, {&z}, {}, {2}, {}); + sd::ops::in_top_k op; + Nd4jStatus status = op.execute( + { + &x, + &y, + }, + {&z}, {}, {2}, {}); - // z.printIndexedBuffer(); - ASSERT_EQ(ND4J_STATUS_OK, status); + // z.printIndexedBuffer(); + ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(expV.isSameShape(z)); - ASSERT_TRUE(expV.equalsTo(z)); + ASSERT_TRUE(expV.isSameShape(z)); + ASSERT_TRUE(expV.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, inTopK_2) { + auto input = NDArrayFactory::create('c', {4, 5}); + auto idx = NDArrayFactory::create('c', {4}); - auto input = NDArrayFactory::create('c', {4, 5}); - auto idx = NDArrayFactory::create('c', {4}); - - auto exp = NDArrayFactory::create({false, false, false, true}); + auto exp = NDArrayFactory::create({false, false, false, true}); - int exclusive, reverse; - input.linspace(1); - idx.linspace(1); + int exclusive, reverse; + input.linspace(1); + idx.linspace(1); - sd::ops::in_top_k op; + sd::ops::in_top_k op; - auto res = op.evaluate({&input, &idx}, {}, {1}); + auto res = op.evaluate({&input, &idx}, {}, {1}); - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - //res.at(0)->printIndexedBuffer("IN_TOP_K output"); - ASSERT_TRUE(res.at(0)->equalsTo(&exp)); - + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + // res.at(0)->printIndexedBuffer("IN_TOP_K output"); + ASSERT_TRUE(res.at(0).equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, inTopK_3) { - auto x = NDArrayFactory::create('c', {2, 3}, {1.0, 11.0, 3.0, 14.0, 5.0, 6.0}); - auto y = NDArrayFactory::create('c', {2}, {1, 1}); - auto expV = NDArrayFactory::create('c', {2}, {true, false}); + auto x = NDArrayFactory::create('c', {2, 3}, + {1.0, 11.0, 3.0, 14.0, 5.0, 6.0}); + auto y = NDArrayFactory::create('c', {2}, {1, 1}); + auto expV = NDArrayFactory::create('c', {2}, {true, false}); - sd::ops::in_top_k op; - auto result = op.evaluate({&x, &y}, {}, {2}); + sd::ops::in_top_k op; + auto result = op.evaluate({&x, &y}, {}, {2}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(1, result.size()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(1, result.size()); - auto v = result.at(0); + auto v = result.at(0); - ASSERT_TRUE(expV.isSameShape(v)); - ASSERT_TRUE(expV.equalsTo(v)); - - + ASSERT_TRUE(expV.isSameShape(v)); + ASSERT_TRUE(expV.equalsTo(v)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, inTopK_4) { - auto x = NDArrayFactory::create('c', {6, 4}, {11.0, 3.0, 14.0, 5.0, 6.0, 9.0, 3.5, 7.0, 21.0, 3.0, 14.0, 15.0, 6.0, 9.0, 3.5, 7.0, 11.0, 13.0, 14.0, 5.0, 16.0, 9.0, 13.5, 7.0} ); - auto y = NDArrayFactory::create('c', {6}, {0, 0, 0, 0, 0, 0}); - auto expV = NDArrayFactory::create('c', {6}, {true, false, true, false, false, true}); - - sd::ops::in_top_k op; - auto result = op.evaluate({&x, &y}, {}, {2}); + auto x = NDArrayFactory::create( + 'c', {6, 4}, + {11.0, 3.0, 14.0, 5.0, 6.0, 9.0, 3.5, 7.0, 21.0, 3.0, 14.0, 15.0, + 6.0, 9.0, 3.5, 7.0, 11.0, 13.0, 14.0, 5.0, 16.0, 9.0, 13.5, 7.0}); + auto y = NDArrayFactory::create('c', {6}, {0, 0, 0, 0, 0, 0}); + auto expV = NDArrayFactory::create( + 'c', {6}, {true, false, true, false, false, true}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(1, result.size()); + sd::ops::in_top_k op; + auto result = op.evaluate({&x, &y}, {}, {2}); - auto v = result.at(0); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(1, result.size()); - ASSERT_TRUE(expV.isSameShape(v)); - ASSERT_TRUE(expV.equalsTo(v)); - - + auto v = result.at(0); + ASSERT_TRUE(expV.isSameShape(v)); + ASSERT_TRUE(expV.equalsTo(v)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, inTopK_5) { - auto x = NDArrayFactory::create('f', {6, 4}, {11.0, 3.0, 14.0, 5.0, 6.0, 9.0, 3.5, 7.0, 21.0, 3.0, 14.0, 15.0, 6.0, 9.0, 3.5, 7.0, 11.0, 13.0, 14.0, 5.0, 16.0, 9.0, 13.5, 7.0} ); - auto y = NDArrayFactory::create('f', {6}, {0, 0, 0, 0, 0, 0}); - auto expV = NDArrayFactory::create('f', {6}, {true, false, false, false, false, false }); - - sd::ops::in_top_k op; - auto result = op.evaluate({&x, &y}, {}, {2}); + auto x = NDArrayFactory::create( + 'f', {6, 4}, + {11.0, 3.0, 14.0, 5.0, 6.0, 9.0, 3.5, 7.0, 21.0, 3.0, 14.0, 15.0, + 6.0, 9.0, 3.5, 7.0, 11.0, 13.0, 14.0, 5.0, 16.0, 9.0, 13.5, 7.0}); + auto y = NDArrayFactory::create('f', {6}, {0, 0, 0, 0, 0, 0}); + auto expV = NDArrayFactory::create( + 'f', {6}, {true, false, false, false, false, false}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(1, result.size()); + sd::ops::in_top_k op; + auto result = op.evaluate({&x, &y}, {}, {2}); - auto v = result.at(0); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(1, result.size()); - ASSERT_TRUE(expV.isSameShape(v)); - ASSERT_TRUE(expV.equalsTo(v)); + auto v = result.at(0); - + ASSERT_TRUE(expV.isSameShape(v)); + ASSERT_TRUE(expV.equalsTo(v)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, cube_1) { + NDArray x('c', {2, 3}, {1., 2., 3., 4., 5, 6}); + NDArray exp('c', {2, 3}, {1., 8., 27., 64., 125, 216}); - NDArray x('c', {2, 3}, {1., 2., 3., 4., 5, 6}); - NDArray exp('c', {2, 3}, {1., 8., 27., 64., 125, 216}); + sd::ops::cube op; - sd::ops::cube op; + auto result = op.evaluate({&x}); - auto result = op.evaluate({&x}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - - + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, cube_bp_1) { + NDArray x('c', {2, 3}, {1., 2., 3., 4., 5, 6}); + NDArray gradO('c', {2, 3}, sd::DataType::DOUBLE); + NDArray exp('c', {2, 3}, {1.5, 6., 13.5, 24., 37.5, 54}); - NDArray x('c', {2, 3}, {1., 2., 3., 4., 5, 6}); - NDArray gradO('c', {2, 3}, sd::DataType::DOUBLE); - NDArray exp('c', {2, 3}, {1.5, 6., 13.5, 24., 37.5, 54}); - - gradO = 0.5; + gradO = 0.5; - sd::ops::cube_bp op; + sd::ops::cube_bp op; - auto result = op.evaluate({&x, &gradO}); + auto result = op.evaluate({&x, &gradO}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); - // z->printIndexedBuffer(); + auto z = result.at(0); + // z->printIndexedBuffer(); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - - + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////// // CONSTANT mode 2D TEST_F(DeclarableOpsTests12, pad_tests1) { + NDArray input('c', {2, 3}, {1, 2, 3, 4, 5, 6}, sd::DataType::FLOAT32); + NDArray paddings('c', {2, 2}, {1, 1, 2, 2}, sd::DataType::INT32); + NDArray expected('c', {4, 7}, {0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 0, 0, + 0, 0, 4, 5, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + sd::DataType::FLOAT32); + sd::ops::pad op; + auto results = op.evaluate({&input, &paddings}, {}, {0}); - NDArray input('c', {2,3}, {1,2,3,4,5,6}, sd::DataType::FLOAT32); - NDArray paddings('c', {2,2}, {1,1,2,2}, sd::DataType::INT32); - NDArray expected('c', {4,7}, {0,0,0,0,0,0,0, 0,0,1,2,3,0,0, 0,0,4,5,6,0,0, 0,0,0,0,0,0,0}, sd::DataType::FLOAT32); - - sd::ops::pad op; - auto results = op.evaluate({&input, &paddings}, {}, {0}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto result = results.at(0); - // result->printIndexedBuffer(); + auto result = results.at(0); + // result->printIndexedBuffer(); - ASSERT_TRUE(expected.isSameShapeStrict(*result)); - ASSERT_TRUE(expected.equalsTo(result)); - - + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } - //////////////////////////////////////////////////////////////////// // REFLECT mode 2D TEST_F(DeclarableOpsTests12, pad_tests2) { + float inBuff[] = {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}; + int padBuff[] = {1, 1, 2, 2}; + float expBuff[] = {6.f, 5.f, 4.f, 5.f, 6.f, 5.f, 4.f, 3.f, 2.f, 1.f, + 2.f, 3.f, 2.f, 1.f, 6.f, 5.f, 4.f, 5.f, 6.f, 5.f, + 4.f, 3.f, 2.f, 1.f, 2.f, 3.f, 2.f, 1.f}; - float inBuff[] = {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}; - int padBuff[] = {1,1,2,2}; - float expBuff[] = {6.f, 5.f, 4.f, 5.f, 6.f, 5.f, 4.f, 3.f, 2.f, 1.f, 2.f, 3.f, 2.f, 1.f, 6.f, 5.f, 4.f, 5.f, 6.f, 5.f, 4.f, 3.f, 2.f, 1.f, 2.f, 3.f, 2.f, 1.f}; + auto input = NDArrayFactory::create(inBuff, 'c', {2, 3}); + auto paddings = NDArrayFactory::create(padBuff, 'c', {2, 2}); + auto expected = NDArrayFactory::create(expBuff, 'c', {4, 7}); - auto input = NDArrayFactory::create(inBuff, 'c', {2,3}); - auto paddings = NDArrayFactory::create(padBuff, 'c', {2,2}); - auto expected = NDArrayFactory::create(expBuff, 'c', {4,7}); + sd::ops::pad op; + auto results = op.evaluate({&input, &paddings}, {}, {1}); - sd::ops::pad op; - auto results = op.evaluate({&input, &paddings}, {}, {1}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); - // result->printIndexedBuffer(); - - ASSERT_TRUE(expected.isSameShapeStrict(*result)); - ASSERT_TRUE(expected.equalsTo(result)); + auto result = results.at(0); + // result->printIndexedBuffer(); - + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } - //////////////////////////////////////////////////////////////////// // SYMMETRIC mode 2D TEST_F(DeclarableOpsTests12, pad_tests3) { + float inBuff[] = {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}; + Nd4jLong padBuff[] = {1, 1, 2, 2}; + float expBuff[] = {2.f, 1.f, 1.f, 2.f, 3.f, 3.f, 2.f, 2.f, 1.f, 1.f, + 2.f, 3.f, 3.f, 2.f, 5.f, 4.f, 4.f, 5.f, 6.f, 6.f, + 5.f, 5.f, 4.f, 4.f, 5.f, 6.f, 6.f, 5.f}; - float inBuff[] = {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}; - Nd4jLong padBuff[] = {1,1,2,2}; - float expBuff[] = {2.f, 1.f, 1.f, 2.f, 3.f, 3.f, 2.f, 2.f,1.f,1.f,2.f,3.f,3.f,2.f, 5.f,4.f,4.f,5.f,6.f,6.f,5.f, 5.f,4.f,4.f,5.f,6.f,6.f,5.f}; - - auto input = NDArrayFactory::create(inBuff, 'c', {2,3}); - auto paddings = NDArrayFactory::create(padBuff, 'c', {2,2}); - auto expected = NDArrayFactory::create(expBuff, 'c', {4,7}); - - sd::ops::pad op; - auto results = op.evaluate({&input, &paddings}, {}, {2}); + auto input = NDArrayFactory::create(inBuff, 'c', {2, 3}); + auto paddings = NDArrayFactory::create(padBuff, 'c', {2, 2}); + auto expected = NDArrayFactory::create(expBuff, 'c', {4, 7}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::pad op; + auto results = op.evaluate({&input, &paddings}, {}, {2}); - auto result = results.at(0); - // result->printIndexedBuffer(); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(expected.isSameShapeStrict(*result)); - ASSERT_TRUE(expected.equalsTo(result)); + auto result = results.at(0); + // result->printIndexedBuffer(); - + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } - //////////////////////////////////////////////////////////////////// // CONSTANT mode 3D TEST_F(DeclarableOpsTests12, pad_tests4) { - - float inBuff[] = {1.f,2.f,3.f,4.f,5.f,6.f,7.f,8.f,9.f,10.f,11.f,12.f,13.f,14.f,15.f,16.f,17.f,18.f}; - int padBuff[] = {1,1,2,2,2,2}; - float expBuff[] = {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, - 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 1.f, 2.f, 3.f, 0.f, 0.f, 0.f, 0.f, 4.f, 5.f, 6.f, 0.f, 0.f, 0.f, 0.f, - 7.f, 8.f, 9.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 10.f, 11.f, 12.f, 0.f, - 0.f, 0.f, 0.f, 13.f, 14.f, 15.f, 0.f, 0.f, 0.f, 0.f, 16.f, 17.f, 18.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, - 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}; - - auto input = NDArrayFactory::create(inBuff, 'c', {2,3,3}); - auto paddings = NDArrayFactory::create(padBuff, 'c', {3,2}); - auto expected = NDArrayFactory::create(expBuff, 'c', {4,7,7}); - - sd::ops::pad op; - auto results = op.evaluate({&input, &paddings}, {}, {0}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); - // result->printIndexedBuffer(); - - ASSERT_TRUE(expected.isSameShapeStrict(*result)); - ASSERT_TRUE(expected.equalsTo(result)); - - // for(int i = 0; i < expected.lengthOf(); ++i) { - // float one = expected.e(i); - // float two = result->e(i); - // if(one != two) - // printf("%i : %f, %f\n", i, one, two); - // } - - + float inBuff[] = {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, + 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f}; + int padBuff[] = {1, 1, 2, 2, 2, 2}; + float expBuff[] = { + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 1.f, 2.f, 3.f, 0.f, 0.f, 0.f, 0.f, 4.f, 5.f, 6.f, 0.f, 0.f, 0.f, + 0.f, 7.f, 8.f, 9.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 10.f, 11.f, 12.f, + 0.f, 0.f, 0.f, 0.f, 13.f, 14.f, 15.f, 0.f, 0.f, 0.f, 0.f, 16.f, 17.f, + 18.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f}; + + auto input = NDArrayFactory::create(inBuff, 'c', {2, 3, 3}); + auto paddings = NDArrayFactory::create(padBuff, 'c', {3, 2}); + auto expected = NDArrayFactory::create(expBuff, 'c', {4, 7, 7}); + + sd::ops::pad op; + auto results = op.evaluate({&input, &paddings}, {}, {0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + // result->printIndexedBuffer(); + + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); + + // for(int i = 0; i < expected.lengthOf(); ++i) { + // float one = expected.e(i); + // float two = result->e(i); + // if(one != two) + // printf("%i : %f, %f\n", i, one, two); + // } } - - //////////////////////////////////////////////////////////////////// // REFLECT mode 3D TEST_F(DeclarableOpsTests12, pad_tests5) { - - double inBuff[] = {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18}; - int padBuff[] = {1,1,2,2,2,2}; - double expBuff[] = {18,17,16,17,18,17,16, 15,14,13,14,15,14,13, 12,11,10,11,12,11,10, 15,14,13,14,15,14,13, 18,17,16,17,18,17,16, 15,14,13,14,15,14,13, 12,11,10,11,12,11,10, 9, 8, 7, 8, 9, 8, 7, 6, 5, 4, 5, 6, 5, 4, 3, 2, 1, 2, 3, 2, 1, 6, 5, 4, 5, 6, 5, 4, 9, 8, 7, 8, 9, 8, 7, 6, 5, 4, 5, 6, 5, 4, 3, 2, 1, 2, 3, 2, 1, 18,17,16,17,18,17,16, 15,14,13,14,15,14,13, 12,11,10,11,12,11,10, 15,14,13,14,15,14,13, 18,17,16,17,18,17,16, 15,14,13,14,15,14,13, 12,11,10,11,12,11,10, 9, 8, 7, 8, 9, 8, 7, 6, 5, 4, 5, 6, 5, 4, 3, 2, 1, 2, 3, 2, 1, 6, 5, 4, 5, 6, 5, 4, 9, 8, 7, 8, 9, 8, 7, 6, 5, 4, 5, 6, 5, 4, 3, 2, 1, 2, 3, 2, 1}; - auto input = NDArrayFactory::create(inBuff, 'c', {2,3,3}); - auto paddings = NDArrayFactory::create(padBuff, 'c', {3,2}); - auto expected = NDArrayFactory::create(expBuff, 'c', {4,7,7}); - - sd::ops::pad op; - auto results = op.evaluate({&input, &paddings}, {}, {1}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); - // result->printIndexedBuffer(); - - ASSERT_TRUE(expected.isSameShapeStrict(*result)); - ASSERT_TRUE(expected.equalsTo(result)); - - + double inBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 16, 17, 18}; + int padBuff[] = {1, 1, 2, 2, 2, 2}; + double expBuff[] = { + 18, 17, 16, 17, 18, 17, 16, 15, 14, 13, 14, 15, 14, 13, 12, 11, 10, 11, + 12, 11, 10, 15, 14, 13, 14, 15, 14, 13, 18, 17, 16, 17, 18, 17, 16, 15, + 14, 13, 14, 15, 14, 13, 12, 11, 10, 11, 12, 11, 10, 9, 8, 7, 8, 9, + 8, 7, 6, 5, 4, 5, 6, 5, 4, 3, 2, 1, 2, 3, 2, 1, 6, 5, + 4, 5, 6, 5, 4, 9, 8, 7, 8, 9, 8, 7, 6, 5, 4, 5, 6, 5, + 4, 3, 2, 1, 2, 3, 2, 1, 18, 17, 16, 17, 18, 17, 16, 15, 14, 13, + 14, 15, 14, 13, 12, 11, 10, 11, 12, 11, 10, 15, 14, 13, 14, 15, 14, 13, + 18, 17, 16, 17, 18, 17, 16, 15, 14, 13, 14, 15, 14, 13, 12, 11, 10, 11, + 12, 11, 10, 9, 8, 7, 8, 9, 8, 7, 6, 5, 4, 5, 6, 5, 4, 3, + 2, 1, 2, 3, 2, 1, 6, 5, 4, 5, 6, 5, 4, 9, 8, 7, 8, 9, + 8, 7, 6, 5, 4, 5, 6, 5, 4, 3, 2, 1, 2, 3, 2, 1}; + auto input = NDArrayFactory::create(inBuff, 'c', {2, 3, 3}); + auto paddings = NDArrayFactory::create(padBuff, 'c', {3, 2}); + auto expected = NDArrayFactory::create(expBuff, 'c', {4, 7, 7}); + + sd::ops::pad op; + auto results = op.evaluate({&input, &paddings}, {}, {1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + // result->printIndexedBuffer(); + + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } - //////////////////////////////////////////////////////////////////// // SYMMETRIC mode 3D TEST_F(DeclarableOpsTests12, pad_tests6) { - - double inBuff[] = {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18}; - int padBuff[] = {1,1,2,2,2,2}; - double expBuff[] = {5, 4, 4, 5, 6, 6, 5, 2, 1, 1, 2, 3, 3, 2, 2, 1, 1, 2, 3, 3, 2, 5, 4, 4, 5, 6, 6, 5, 8, 7, 7, 8, 9, 9, 8, 8, 7, 7, 8, 9, 9, 8, 5, 4, 4, 5, 6, 6, 5, 5, 4, 4, 5, 6, 6, 5, 2, 1, 1, 2, 3, 3, 2, 2, 1, 1, 2, 3, 3, 2, 5, 4, 4, 5, 6, 6, 5, 8, 7, 7, 8, 9, 9, 8, 8, 7, 7, 8, 9, 9, 8, 5, 4, 4, 5, 6, 6, 5, 14,13,13,14,15,15,14, 11,10,10,11,12,12,11, 11,10,10,11,12,12,11, 14,13,13,14,15,15,14, 17,16,16,17,18,18,17, 17,16,16,17,18,18,17, 14,13,13,14,15,15,14, 14,13,13,14,15,15,14, 11,10,10,11,12,12,11, 11,10,10,11,12,12,11, 14,13,13,14,15,15,14, 17,16,16,17,18,18,17, 17,16,16,17,18,18,17, 14,13,13,14,15,15,14}; - - auto input = NDArrayFactory::create(inBuff, 'c', {2,3,3}); - auto paddings = NDArrayFactory::create(padBuff, 'c', {3,2}); - auto expected = NDArrayFactory::create(expBuff, 'c', {4,7,7}); - - sd::ops::pad op; - auto results = op.evaluate({&input, &paddings}, {}, {2}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); - // result->printIndexedBuffer(); - - ASSERT_TRUE(expected.isSameShapeStrict(*result)); - ASSERT_TRUE(expected.equalsTo(result)); - - + double inBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 16, 17, 18}; + int padBuff[] = {1, 1, 2, 2, 2, 2}; + double expBuff[] = { + 5, 4, 4, 5, 6, 6, 5, 2, 1, 1, 2, 3, 3, 2, 2, 1, 1, 2, + 3, 3, 2, 5, 4, 4, 5, 6, 6, 5, 8, 7, 7, 8, 9, 9, 8, 8, + 7, 7, 8, 9, 9, 8, 5, 4, 4, 5, 6, 6, 5, 5, 4, 4, 5, 6, + 6, 5, 2, 1, 1, 2, 3, 3, 2, 2, 1, 1, 2, 3, 3, 2, 5, 4, + 4, 5, 6, 6, 5, 8, 7, 7, 8, 9, 9, 8, 8, 7, 7, 8, 9, 9, + 8, 5, 4, 4, 5, 6, 6, 5, 14, 13, 13, 14, 15, 15, 14, 11, 10, 10, + 11, 12, 12, 11, 11, 10, 10, 11, 12, 12, 11, 14, 13, 13, 14, 15, 15, 14, + 17, 16, 16, 17, 18, 18, 17, 17, 16, 16, 17, 18, 18, 17, 14, 13, 13, 14, + 15, 15, 14, 14, 13, 13, 14, 15, 15, 14, 11, 10, 10, 11, 12, 12, 11, 11, + 10, 10, 11, 12, 12, 11, 14, 13, 13, 14, 15, 15, 14, 17, 16, 16, 17, 18, + 18, 17, 17, 16, 16, 17, 18, 18, 17, 14, 13, 13, 14, 15, 15, 14}; + + auto input = NDArrayFactory::create(inBuff, 'c', {2, 3, 3}); + auto paddings = NDArrayFactory::create(padBuff, 'c', {3, 2}); + auto expected = NDArrayFactory::create(expBuff, 'c', {4, 7, 7}); + + sd::ops::pad op; + auto results = op.evaluate({&input, &paddings}, {}, {2}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + // result->printIndexedBuffer(); + + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// // CONSTANT mode 4D -TEST_F(DeclarableOpsTests12, pad_tests7) -{ - - double inBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; - int padBuff[] = {1, 1, 1, 1, 1, 1, 1, 1}; - double expBuff[] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0, 3, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 6, 0, 0, 7, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 9, 10, 0, 0, 11, 12, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 13, 14, 0, 0, 15, 16, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; - auto input = NDArrayFactory::create(inBuff, 'c', {2, 2, 2, 2}); - auto paddings = NDArrayFactory::create(padBuff, 'c', {4, 2}); - auto expected = NDArrayFactory::create(expBuff, 'c', {4, 4, 4, 4}); - - sd::ops::pad op; - auto results = op.evaluate({&input, &paddings}, {}, {0}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto *result = results.at(0); - // result->printIndexedBuffer(); - - ASSERT_TRUE(expected.isSameShapeStrict(*result)); - ASSERT_TRUE(expected.equalsTo(result)); - - +TEST_F(DeclarableOpsTests12, pad_tests7) { + double inBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; + int padBuff[] = {1, 1, 1, 1, 1, 1, 1, 1}; + double expBuff[] = { + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, + 0, 3, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 6, 0, 0, 7, 8, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 9, 10, 0, 0, 11, + 12, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 13, 14, 0, 0, 15, 16, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + auto input = NDArrayFactory::create(inBuff, 'c', {2, 2, 2, 2}); + auto paddings = NDArrayFactory::create(padBuff, 'c', {4, 2}); + auto expected = NDArrayFactory::create(expBuff, 'c', {4, 4, 4, 4}); + + sd::ops::pad op; + auto results = op.evaluate({&input, &paddings}, {}, {0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + // result->printIndexedBuffer(); + + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// // REFLECT mode 4D -TEST_F(DeclarableOpsTests12, pad_tests8) -{ - - double inBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; - int padBuff[] = {1, 1, 1, 1, 1, 1, 1, 1}; - double expBuff[] = {16, 15, 16, 15, 14, 13, 14, 13, 16, 15, 16, 15, 14, 13, 14, 13, 12, 11, 12, 11, 10, 9, 10, 9, 12, 11, 12, 11, 10, 9, 10, 9, 16, 15, 16, 15, 14, 13, 14, 13, 16, 15, 16, 15, 14, 13, 14, 13, 12, 11, 12, 11, 10, 9, 10, 9, 12, 11, 12, 11, 10, 9, 10, 9, 8, 7, 8, 7, 6, 5, 6, 5, 8, 7, 8, 7, 6, 5, 6, 5, 4, 3, 4, 3, 2, 1, 2, 1, 4, 3, 4, 3, 2, 1, 2, 1, 8, 7, 8, 7, 6, 5, 6, 5, 8, 7, 8, 7, 6, 5, 6, 5, 4, 3, 4, 3, 2, 1, 2, 1, 4, 3, 4, 3, 2, 1, 2, 1, 16, 15, 16, 15, 14, 13, 14, 13, 16, 15, 16, 15, 14, 13, 14, 13, 12, 11, 12, 11, 10, 9, 10, 9, 12, 11, 12, 11, 10, 9, 10, 9, 16, 15, 16, 15, 14, 13, 14, 13, 16, 15, 16, 15, 14, 13, 14, 13, 12, 11, 12, 11, 10, 9, 10, 9, 12, 11, 12, 11, 10, 9, 10, 9, 8, 7, 8, 7, 6, 5, 6, 5, 8, 7, 8, 7, 6, 5, 6, 5, 4, 3, 4, 3, 2, 1, 2, 1, 4, 3, 4, 3, 2, 1, 2, 1, 8, 7, 8, 7, 6, 5, 6, 5, 8, 7, 8, 7, 6, 5, 6, 5, 4, 3, 4, 3, 2, 1, 2, 1, 4, 3, 4, 3, 2, 1, 2, 1}; - auto input = NDArrayFactory::create(inBuff, 'c', {2, 2, 2, 2}); - auto paddings = NDArrayFactory::create(padBuff, 'c', {4, 2}); - auto expected = NDArrayFactory::create(expBuff, 'c', {4, 4, 4, 4}); - - sd::ops::pad op; - auto results = op.evaluate({&input, &paddings}, {}, {1}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto *result = results.at(0); - // result->printIndexedBuffer(); - - ASSERT_TRUE(expected.isSameShapeStrict(*result)); - ASSERT_TRUE(expected.equalsTo(result)); - - +TEST_F(DeclarableOpsTests12, pad_tests8) { + double inBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; + int padBuff[] = {1, 1, 1, 1, 1, 1, 1, 1}; + double expBuff[] = { + 16, 15, 16, 15, 14, 13, 14, 13, 16, 15, 16, 15, 14, 13, 14, 13, 12, 11, + 12, 11, 10, 9, 10, 9, 12, 11, 12, 11, 10, 9, 10, 9, 16, 15, 16, 15, + 14, 13, 14, 13, 16, 15, 16, 15, 14, 13, 14, 13, 12, 11, 12, 11, 10, 9, + 10, 9, 12, 11, 12, 11, 10, 9, 10, 9, 8, 7, 8, 7, 6, 5, 6, 5, + 8, 7, 8, 7, 6, 5, 6, 5, 4, 3, 4, 3, 2, 1, 2, 1, 4, 3, + 4, 3, 2, 1, 2, 1, 8, 7, 8, 7, 6, 5, 6, 5, 8, 7, 8, 7, + 6, 5, 6, 5, 4, 3, 4, 3, 2, 1, 2, 1, 4, 3, 4, 3, 2, 1, + 2, 1, 16, 15, 16, 15, 14, 13, 14, 13, 16, 15, 16, 15, 14, 13, 14, 13, + 12, 11, 12, 11, 10, 9, 10, 9, 12, 11, 12, 11, 10, 9, 10, 9, 16, 15, + 16, 15, 14, 13, 14, 13, 16, 15, 16, 15, 14, 13, 14, 13, 12, 11, 12, 11, + 10, 9, 10, 9, 12, 11, 12, 11, 10, 9, 10, 9, 8, 7, 8, 7, 6, 5, + 6, 5, 8, 7, 8, 7, 6, 5, 6, 5, 4, 3, 4, 3, 2, 1, 2, 1, + 4, 3, 4, 3, 2, 1, 2, 1, 8, 7, 8, 7, 6, 5, 6, 5, 8, 7, + 8, 7, 6, 5, 6, 5, 4, 3, 4, 3, 2, 1, 2, 1, 4, 3, 4, 3, + 2, 1, 2, 1}; + auto input = NDArrayFactory::create(inBuff, 'c', {2, 2, 2, 2}); + auto paddings = NDArrayFactory::create(padBuff, 'c', {4, 2}); + auto expected = NDArrayFactory::create(expBuff, 'c', {4, 4, 4, 4}); + + sd::ops::pad op; + auto results = op.evaluate({&input, &paddings}, {}, {1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + // result->printIndexedBuffer(); + + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } ////////////////////////////////////////////////////////////////// // SYMMETRIC mode 4D -TEST_F(DeclarableOpsTests12, pad_tests9) -{ - - double inBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; - int padBuff[] = {1, 1, 1, 1, 1, 1, 1, 1}; - double expBuff[] = {1, 1, 2, 2, 1, 1, 2, 2, 3, 3, 4, 4, 3, 3, 4, 4, 1, 1, 2, 2, 1, 1, 2, 2, 3, 3, 4, 4, 3, 3, 4, 4, 5, 5, 6, 6, 5, 5, 6, 6, 7, 7, 8, 8, 7, 7, 8, 8, 5, 5, 6, 6, 5, 5, 6, 6, 7, 7, 8, 8, 7, 7, 8, 8, 1, 1, 2, 2, 1, 1, 2, 2, 3, 3, 4, 4, 3, 3, 4, 4, 1, 1, 2, 2, 1, 1, 2, 2, 3, 3, 4, 4, 3, 3, 4, 4, 5, 5, 6, 6, 5, 5, 6, 6, 7, 7, 8, 8, 7, 7, 8, 8, 5, 5, 6, 6, 5, 5, 6, 6, 7, 7, 8, 8, 7, 7, 8, 8, 9, 9, 10, 10, 9, 9, 10, 10, 11, 11, 12, 12, 11, 11, 12, 12, 9, 9, 10, 10, 9, 9, 10, 10, 11, 11, 12, 12, 11, 11, 12, 12, 13, 13, 14, 14, 13, 13, 14, 14, 15, 15, 16, 16, 15, 15, 16, 16, 13, 13, 14, 14, 13, 13, 14, 14, 15, 15, 16, 16, 15, 15, 16, 16, 9, 9, 10, 10, 9, 9, 10, 10, 11, 11, 12, 12, 11, 11, 12, 12, 9, 9, 10, 10, 9, 9, 10, 10, 11, 11, 12, 12, 11, 11, 12, 12, 13, 13, 14, 14, 13, 13, 14, 14, 15, 15, 16, 16, 15, 15, 16, 16, 13, 13, 14, 14, 13, 13, 14, 14, 15, 15, 16, 16, 15, 15, 16, 16}; - auto input = NDArrayFactory::create(inBuff, 'c', {2, 2, 2, 2}); - auto paddings = NDArrayFactory::create(padBuff, 'c', {4, 2}); - auto expected = NDArrayFactory::create(expBuff, 'c', {4, 4, 4, 4}); - - sd::ops::pad op; - auto results = op.evaluate({&input, &paddings}, {}, {2}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto *result = results.at(0); - // result->printIndexedBuffer(); - - ASSERT_TRUE(expected.isSameShapeStrict(*result)); - ASSERT_TRUE(expected.equalsTo(result)); - - +TEST_F(DeclarableOpsTests12, pad_tests9) { + double inBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; + int padBuff[] = {1, 1, 1, 1, 1, 1, 1, 1}; + double expBuff[] = { + 1, 1, 2, 2, 1, 1, 2, 2, 3, 3, 4, 4, 3, 3, 4, 4, 1, 1, + 2, 2, 1, 1, 2, 2, 3, 3, 4, 4, 3, 3, 4, 4, 5, 5, 6, 6, + 5, 5, 6, 6, 7, 7, 8, 8, 7, 7, 8, 8, 5, 5, 6, 6, 5, 5, + 6, 6, 7, 7, 8, 8, 7, 7, 8, 8, 1, 1, 2, 2, 1, 1, 2, 2, + 3, 3, 4, 4, 3, 3, 4, 4, 1, 1, 2, 2, 1, 1, 2, 2, 3, 3, + 4, 4, 3, 3, 4, 4, 5, 5, 6, 6, 5, 5, 6, 6, 7, 7, 8, 8, + 7, 7, 8, 8, 5, 5, 6, 6, 5, 5, 6, 6, 7, 7, 8, 8, 7, 7, + 8, 8, 9, 9, 10, 10, 9, 9, 10, 10, 11, 11, 12, 12, 11, 11, 12, 12, + 9, 9, 10, 10, 9, 9, 10, 10, 11, 11, 12, 12, 11, 11, 12, 12, 13, 13, + 14, 14, 13, 13, 14, 14, 15, 15, 16, 16, 15, 15, 16, 16, 13, 13, 14, 14, + 13, 13, 14, 14, 15, 15, 16, 16, 15, 15, 16, 16, 9, 9, 10, 10, 9, 9, + 10, 10, 11, 11, 12, 12, 11, 11, 12, 12, 9, 9, 10, 10, 9, 9, 10, 10, + 11, 11, 12, 12, 11, 11, 12, 12, 13, 13, 14, 14, 13, 13, 14, 14, 15, 15, + 16, 16, 15, 15, 16, 16, 13, 13, 14, 14, 13, 13, 14, 14, 15, 15, 16, 16, + 15, 15, 16, 16}; + auto input = NDArrayFactory::create(inBuff, 'c', {2, 2, 2, 2}); + auto paddings = NDArrayFactory::create(padBuff, 'c', {4, 2}); + auto expected = NDArrayFactory::create(expBuff, 'c', {4, 4, 4, 4}); + + sd::ops::pad op; + auto results = op.evaluate({&input, &paddings}, {}, {2}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + // result->printIndexedBuffer(); + + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, pad_tests10) { + auto input = NDArrayFactory::create('c', {2, 3, 4}); + auto paddings = NDArrayFactory::create('c', {3, 2}, {0, 0, 0, 1, 0, 0}); + auto expected = NDArrayFactory::create( + 'c', {2, 4, 4}, + {1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., + 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0.}); - auto input = NDArrayFactory::create('c', {2,3,4}); - auto paddings = NDArrayFactory::create('c', {3,2}, {0,0, 0,1, 0,0}); - auto expected = NDArrayFactory::create('c', {2,4,4}, {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.}); - - input = 1.f; - //input.assign(1.); - sd::ops::pad op; - auto results = op.evaluate({&input, &paddings}, {}, {0}); + input = 1.f; + // input.assign(1.); + sd::ops::pad op; + auto results = op.evaluate({&input, &paddings}, {}, {0}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(expected.isSameShapeStrict(*result)); - ASSERT_TRUE(expected.equalsTo(result)); - - + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, pad_tests11) { + auto input = NDArrayFactory::create('c', {2, 3, 4}); + auto paddings = NDArrayFactory::create('c', {3, 2}, {0, 0, 0, 1, 0, 0}); + auto expected = NDArrayFactory::create( + 'c', {2, 4, 4}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., + 12., 5., 6., 7., 8., 13., 14., 15., 16., 17., 18., + 19., 20., 21., 22., 23., 24., 17., 18., 19., 20.}); - auto input = NDArrayFactory::create('c', {2,3,4}); - auto paddings = NDArrayFactory::create('c', {3,2}, {0,0, 0,1, 0,0}); - auto expected = NDArrayFactory::create('c', {2,4,4}, {1., 2., 3., 4., 5., 6., 7., 8., 9.,10.,11.,12., 5., 6., 7., 8.,13.,14.,15.,16.,17.,18.,19.,20.,21.,22.,23.,24.,17.,18.,19.,20.}); - - input.linspace(1.f); - - sd::ops::pad op; - auto results = op.evaluate({&input, &paddings}, {}, {1}); + input.linspace(1.f); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::pad op; + auto results = op.evaluate({&input, &paddings}, {}, {1}); - auto result = results.at(0); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(expected.isSameShapeStrict(*result)); - ASSERT_TRUE(expected.equalsTo(result)); + auto result = results.at(0); - + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, pad_tests12) { - - auto input = NDArrayFactory::create('c', {2,3,4,5}); - auto paddings = NDArrayFactory::create('c', {4,2}, {0,0, 0,1, 0,1, 0,0}); - auto expected = NDArrayFactory::create('c', {2,4,5,5}, { 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 16., 17., 18., 19., 20., - 21., 22., 23., 24., 25., 26., 27., 28., 29., 30., 31., 32., 33., 34., 35., 36., 37., 38., 39., 40., 36., 37., 38., 39., 40., - 41., 42., 43., 44., 45., 46., 47., 48., 49., 50., 51., 52., 53., 54., 55., 56., 57., 58., 59., 60., 56., 57., 58., 59., 60., - 41., 42., 43., 44., 45., 46., 47., 48., 49., 50., 51., 52., 53., 54., 55., 56., 57., 58., 59., 60., 56., 57., 58., 59., 60., - 61., 62., 63., 64., 65., 66., 67., 68., 69., 70., 71., 72., 73., 74., 75., 76., 77., 78., 79., 80., 76., 77., 78., 79., 80., - 81., 82., 83., 84., 85., 86., 87., 88., 89., 90., 91., 92., 93., 94., 95., 96., 97., 98., 99.,100., 96., 97., 98., 99.,100., - 101.,102.,103.,104.,105.,106.,107.,108.,109.,110.,111.,112.,113.,114.,115.,116.,117.,118.,119.,120.,116.,117.,118.,119.,120., - 101.,102.,103.,104.,105.,106.,107.,108.,109.,110.,111.,112.,113.,114.,115.,116.,117.,118.,119.,120.,116.,117.,118.,119.,120.}); - input.linspace(1.f); - - sd::ops::pad op; - auto results = op.evaluate({&input, &paddings}, {}, {2}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); - // result->printIndexedBuffer(); - - ASSERT_TRUE(expected.isSameShapeStrict(*result)); - ASSERT_TRUE(expected.equalsTo(result)); - - + auto input = NDArrayFactory::create('c', {2, 3, 4, 5}); + auto paddings = + NDArrayFactory::create('c', {4, 2}, {0, 0, 0, 1, 0, 1, 0, 0}); + auto expected = NDArrayFactory::create( + 'c', {2, 4, 5, 5}, + {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., + 13., 14., 15., 16., 17., 18., 19., 20., 16., 17., 18., 19., + 20., 21., 22., 23., 24., 25., 26., 27., 28., 29., 30., 31., + 32., 33., 34., 35., 36., 37., 38., 39., 40., 36., 37., 38., + 39., 40., 41., 42., 43., 44., 45., 46., 47., 48., 49., 50., + 51., 52., 53., 54., 55., 56., 57., 58., 59., 60., 56., 57., + 58., 59., 60., 41., 42., 43., 44., 45., 46., 47., 48., 49., + 50., 51., 52., 53., 54., 55., 56., 57., 58., 59., 60., 56., + 57., 58., 59., 60., 61., 62., 63., 64., 65., 66., 67., 68., + 69., 70., 71., 72., 73., 74., 75., 76., 77., 78., 79., 80., + 76., 77., 78., 79., 80., 81., 82., 83., 84., 85., 86., 87., + 88., 89., 90., 91., 92., 93., 94., 95., 96., 97., 98., 99., + 100., 96., 97., 98., 99., 100., 101., 102., 103., 104., 105., 106., + 107., 108., 109., 110., 111., 112., 113., 114., 115., 116., 117., 118., + 119., 120., 116., 117., 118., 119., 120., 101., 102., 103., 104., 105., + 106., 107., 108., 109., 110., 111., 112., 113., 114., 115., 116., 117., + 118., 119., 120., 116., 117., 118., 119., 120.}); + input.linspace(1.f); + + sd::ops::pad op; + auto results = op.evaluate({&input, &paddings}, {}, {2}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + // result->printIndexedBuffer(); + + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, pad_tests13) { + auto input = NDArrayFactory::create('c', {5}); + auto paddings = NDArrayFactory::create('c', {1, 2}, {2, 3}); + auto expected = NDArrayFactory::create( + 'c', {10}, {3., 2., 1., 2., 3., 4., 5., 4., 3., 2.}); + input.linspace(1.f); - auto input = NDArrayFactory::create('c', {5}); - auto paddings = NDArrayFactory::create('c', {1,2}, {2,3}); - auto expected = NDArrayFactory::create('c', {10}, {3., 2., 1., 2., 3., 4., 5., 4., 3., 2.}); - input.linspace(1.f); - - sd::ops::pad op; - auto results = op.evaluate({&input, &paddings}, {}, {1}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::pad op; + auto results = op.evaluate({&input, &paddings}, {}, {1}); - auto result = results.at(0); - // result->printIndexedBuffer(); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(expected.isSameShapeStrict(*result)); - ASSERT_TRUE(expected.equalsTo(result)); + auto result = results.at(0); + // result->printIndexedBuffer(); - + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, pad_tests14) { + auto input = NDArrayFactory::create('c', {1, 5}); + auto paddings = NDArrayFactory::create('c', {2, 2}, {0, 0, 2, 3}); + auto expected = NDArrayFactory::create( + 'c', {1, 10}, {2., 1., 1., 2., 3., 4., 5., 5., 4., 3.}); + input.linspace(1.f); - auto input = NDArrayFactory::create('c', {1,5}); - auto paddings = NDArrayFactory::create('c', {2,2}, {0,0,2,3}); - auto expected = NDArrayFactory::create('c', {1,10}, {2., 1., 1., 2., 3., 4., 5., 5., 4., 3.}); - input.linspace(1.f); + sd::ops::pad op; + auto results = op.evaluate({&input, &paddings}, {}, {2}); - sd::ops::pad op; - auto results = op.evaluate({&input, &paddings}, {}, {2}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(expected.isSameShapeStrict(*result)); - ASSERT_TRUE(expected.equalsTo(result)); - - + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, pad_tests15) { + auto input = NDArrayFactory::create('c', {1, 5}); + auto paddings = NDArrayFactory::create('c', {2, 2}, {1, 1, 0, 0}); + auto expected = NDArrayFactory::create( + 'c', {3, 5}, + {1., 2., 3., 4., 5., 1., 2., 3., 4., 5., 1., 2., 3., 4., 5.}); + input.linspace(1.f); - auto input = NDArrayFactory::create('c', {1,5}); - auto paddings = NDArrayFactory::create('c', {2,2}, {1,1,0,0}); - auto expected = NDArrayFactory::create('c', {3,5}, {1., 2., 3., 4., 5., 1., 2., 3., 4., 5., 1., 2., 3., 4., 5.}); - input.linspace(1.f); + sd::ops::pad op; + auto results = op.evaluate({&input, &paddings}, {}, {2}); - sd::ops::pad op; - auto results = op.evaluate({&input, &paddings}, {}, {2}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(expected.isSameShapeStrict(*result)); - ASSERT_TRUE(expected.equalsTo(result)); - - + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, pad_tests16) { + auto input = NDArrayFactory::create('c', {5, 1}); + auto paddings = NDArrayFactory::create('c', {2, 2}, {2, 3, 0, 0}); + auto expected = NDArrayFactory::create( + 'c', {10, 1}, {3., 2., 1., 2., 3., 4., 5., 4., 3., 2.}); + input.linspace(1.f); - auto input = NDArrayFactory::create('c', {5,1}); - auto paddings = NDArrayFactory::create('c', {2,2}, {2,3,0,0}); - auto expected = NDArrayFactory::create('c', {10,1}, {3., 2., 1., 2., 3., 4., 5., 4., 3., 2.}); - input.linspace(1.f); + sd::ops::pad op; + auto results = op.evaluate({&input, &paddings}, {}, {1}); - sd::ops::pad op; - auto results = op.evaluate({&input, &paddings}, {}, {1}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto result = results.at(0); - auto result = results.at(0); - - ASSERT_TRUE(expected.isSameShapeStrict(*result)); - ASSERT_TRUE(expected.equalsTo(result)); - - + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, pad_tests17) { + auto input = NDArrayFactory::create('c', {5, 1}); + auto paddings = NDArrayFactory::create('c', {2, 2}, {0, 0, 1, 0}); + auto expected = NDArrayFactory::create( + 'c', {5, 2}, {1., 1., 2., 2., 3., 3., 4., 4., 5., 5.}); + input.linspace(1.f); - auto input = NDArrayFactory::create('c', {5,1}); - auto paddings = NDArrayFactory::create('c', {2,2}, {0,0,1,0}); - auto expected = NDArrayFactory::create('c', {5,2}, {1.,1., 2.,2., 3.,3., 4.,4., 5.,5.}); - input.linspace(1.f); + sd::ops::pad op; + auto results = op.evaluate({&input, &paddings}, {}, {2}); - sd::ops::pad op; - auto results = op.evaluate({&input, &paddings}, {}, {2}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); - - ASSERT_TRUE(expected.isSameShapeStrict(*result)); - ASSERT_TRUE(expected.equalsTo(result)); + auto result = results.at(0); - + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, pad_tests18) { + auto input = NDArrayFactory::create('c', {5}); + auto paddings = NDArrayFactory::create('c', {1, 2}, {0, 0}); + auto expected = + NDArrayFactory::create('c', {5}, {1., 2., 3., 4., 5.}); + input.linspace(1.f); - auto input = NDArrayFactory::create('c', {5}); - auto paddings = NDArrayFactory::create('c', {1,2}, {0,0}); - auto expected = NDArrayFactory::create('c', {5}, {1.,2.,3.,4.,5.}); - input.linspace(1.f); + sd::ops::pad op; + auto results = op.evaluate({&input, &paddings}, {}, {1}); - sd::ops::pad op; - auto results = op.evaluate({&input, &paddings}, {}, {1}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); - - ASSERT_TRUE(expected.isSameShapeStrict(*result)); - ASSERT_TRUE(expected.equalsTo(result)); + auto result = results.at(0); - + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, pad_tests19) { + auto input = NDArrayFactory::create('c', {5, 1}); + auto paddings = NDArrayFactory::create('c', {2, 2}, {0, 0, 0, 0}); + auto expected = + NDArrayFactory::create('c', {5, 1}, {1., 2., 3., 4., 5.}); + input.linspace(1.f); - auto input = NDArrayFactory::create('c', {5,1}); - auto paddings = NDArrayFactory::create('c', {2,2}, {0,0,0,0}); - auto expected = NDArrayFactory::create('c', {5,1}, {1., 2., 3., 4., 5.}); - input.linspace(1.f); + sd::ops::pad op; + auto results = op.evaluate({&input, &paddings}, {}, {1}); - sd::ops::pad op; - auto results = op.evaluate({&input, &paddings}, {}, {1}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(expected.isSameShapeStrict(*result)); - ASSERT_TRUE(expected.equalsTo(result)); - - + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, pad_tests20) { + auto input = NDArrayFactory::create('c', {1, 5}); + auto paddings = NDArrayFactory::create('c', {2, 2}, {0, 0, 0, 0}); + auto expected = + NDArrayFactory::create('c', {1, 5}, {1., 2., 3., 4., 5.}); + input.linspace(1.f); - auto input = NDArrayFactory::create('c', {1,5}); - auto paddings = NDArrayFactory::create('c', {2,2}, {0,0,0,0}); - auto expected = NDArrayFactory::create('c', {1,5}, {1., 2., 3., 4., 5.}); - input.linspace(1.f); - - sd::ops::pad op; - auto results = op.evaluate({&input, &paddings}, {}, {1}); + sd::ops::pad op; + auto results = op.evaluate({&input, &paddings}, {}, {1}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(expected.isSameShapeStrict(*result)); - ASSERT_TRUE(expected.equalsTo(result)); - - + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, pad_tests21) { + auto input = NDArrayFactory::create('c', {1, 3, 1, 5}); + auto paddings = + NDArrayFactory::create('c', {4, 2}, {0, 0, 0, 1, 0, 1, 0, 0}); + auto expected = NDArrayFactory::create( + 'c', {1, 4, 2, 5}, + {1., 2., 3., 4., 5., 1., 2., 3., 4., 5., 6., 7., 8., 9., + 10., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 11., 12., 13., + 14., 15., 11., 12., 13., 14., 15., 11., 12., 13., 14., 15.}); + input.linspace(1.f); - auto input = NDArrayFactory::create('c', {1,3,1,5}); - auto paddings = NDArrayFactory::create('c', {4,2}, {0,0, 0,1, 0,1, 0,0}); - auto expected = NDArrayFactory::create('c', {1,4,2,5}, {1., 2., 3., 4., 5., 1., 2., 3., 4., 5., 6., 7., 8., 9.,10., 6., 7., 8., 9.,10., - 11.,12.,13.,14.,15.,11.,12.,13.,14.,15.,11.,12.,13.,14.,15.,11.,12.,13.,14.,15.}); - input.linspace(1.f); - - sd::ops::pad op; - auto results = op.evaluate({&input, &paddings}, {}, {2}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::pad op; + auto results = op.evaluate({&input, &paddings}, {}, {2}); - auto result = results.at(0); - // result->printIndexedBuffer(); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(expected.isSameShapeStrict(*result)); - ASSERT_TRUE(expected.equalsTo(result)); + auto result = results.at(0); + // result->printIndexedBuffer(); - + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, pad_tests22) { + auto input = NDArrayFactory::create('c', {1, 1}); + auto paddings = NDArrayFactory::create('c', {2, 2}, {0, 0, 0, 0}); + auto expected = NDArrayFactory::create('c', {1, 1}, {1.}); - auto input = NDArrayFactory::create('c', {1,1}); - auto paddings = NDArrayFactory::create('c', {2,2}, {0,0, 0,0}); - auto expected = NDArrayFactory::create('c', {1,1}, {1.}); + input.linspace(1.f); - input.linspace(1.f); + sd::ops::pad op; + auto results = op.evaluate({&input, &paddings}, {}, {0}); - sd::ops::pad op; - auto results = op.evaluate({&input, &paddings}, {}, {0}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto result = results.at(0); - // result->printIndexedBuffer(); + auto result = results.at(0); + // result->printIndexedBuffer(); - ASSERT_TRUE(expected.isSameShapeStrict(*result)); - ASSERT_TRUE(expected.equalsTo(result)); - - + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, pad_tests23) { + auto input = NDArrayFactory::create('c', {1, 1}); + auto paddings = NDArrayFactory::create('c', {2, 2}, {0, 0, 1, 0}); + auto expected = NDArrayFactory::create('c', {1, 2}, {0., 1.}); - auto input = NDArrayFactory::create('c', {1,1}); - auto paddings = NDArrayFactory::create('c', {2,2}, {0,0, 1,0}); - auto expected = NDArrayFactory::create('c', {1,2}, {0.,1.}); - - input.linspace(1.f); + input.linspace(1.f); - sd::ops::pad op; - auto results = op.evaluate({&input, &paddings}, {}, {0}); + sd::ops::pad op; + auto results = op.evaluate({&input, &paddings}, {}, {0}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); - // result->printShapeInfo("r"); - // expected.printShapeInfo("e"); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(expected.isSameShapeStrict(*result)); - ASSERT_TRUE(expected.equalsTo(result)); + auto result = results.at(0); + // result->printShapeInfo("r"); + // expected.printShapeInfo("e"); - + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, pad_tests24) { + auto input = NDArrayFactory::create('c', {1}); + auto paddings = NDArrayFactory::create('c', {1, 2}, {0, 0}); + auto expected = NDArrayFactory::create('c', {1}, {1.}); - auto input = NDArrayFactory::create('c', {1}); - auto paddings = NDArrayFactory::create('c', {1,2}, {0,0}); - auto expected = NDArrayFactory::create('c', {1}, {1.}); + input.linspace(1.f); - input.linspace(1.f); + sd::ops::pad op; + auto results = op.evaluate({&input, &paddings}, {}, {0}); - sd::ops::pad op; - auto results = op.evaluate({&input, &paddings}, {}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); - - ASSERT_TRUE(expected.isSameShapeStrict(*result)); - ASSERT_TRUE(expected.equalsTo(result)); + auto result = results.at(0); - + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, pad_tests25) { + auto input = NDArrayFactory::create('c', {1}); + auto paddings = NDArrayFactory::create('c', {1, 2}, {1, 1}); + auto expected = NDArrayFactory::create('c', {3}, {1., 1., 1}); - auto input = NDArrayFactory::create('c', {1}); - auto paddings = NDArrayFactory::create('c', {1,2}, {1,1}); - auto expected = NDArrayFactory::create('c', {3}, {1.,1.,1}); + input.linspace(1.f); - input.linspace(1.f); + sd::ops::pad op; + auto results = op.evaluate({&input, &paddings}, {}, {2}); - sd::ops::pad op; - auto results = op.evaluate({&input, &paddings}, {}, {2}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); - - ASSERT_TRUE(expected.isSameShapeStrict(*result)); - ASSERT_TRUE(expected.equalsTo(result)); + auto result = results.at(0); - + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, pad_tests26) { + auto input = NDArrayFactory::create('c', {1}); + auto paddings = NDArrayFactory::create('c', {1, 2}, {3, 2}); + auto expected = + NDArrayFactory::create('c', {6}, {0., 0., 0., 1., 0., 0.}); - auto input = NDArrayFactory::create('c', {1}); - auto paddings = NDArrayFactory::create('c', {1,2}, {3,2}); - auto expected = NDArrayFactory::create('c', {6}, {0., 0., 0., 1., 0., 0.}); + input.linspace(1.f); - input.linspace(1.f); + sd::ops::pad op; + auto results = op.evaluate({&input, &paddings}, {}, {0}); - sd::ops::pad op; - auto results = op.evaluate({&input, &paddings}, {}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(expected.isSameShapeStrict(*result)); - ASSERT_TRUE(expected.equalsTo(result)); - - + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, pad_tests27) { + NDArray input('c', {2, 3}, sd::DataType::FLOAT32); + NDArray paddings('c', {2, 2}, {0, 0, 0, 1}, sd::DataType::INT32); + NDArray exp('c', {2, 4}, {1, 1, 1, 0, 1, 1, 1, 0}, sd::DataType::FLOAT32); + NDArray z('c', {2, 4}, sd::DataType::FLOAT32); + input = 1.; - NDArray input('c', {2,3}, sd::DataType::FLOAT32); - NDArray paddings('c', {2,2}, {0,0,0,1}, sd::DataType::INT32); - NDArray exp('c', {2,4}, {1,1,1,0,1,1,1,0}, sd::DataType::FLOAT32); - NDArray z('c', {2,4}, sd::DataType::FLOAT32); - input = 1.; - - sd::ops::pad op; - Nd4jStatus status = op.execute({&input, &paddings}, {&z}, {0}, {0}, {}); // constant - // z.printIndexedBuffer(); + sd::ops::pad op; + Nd4jStatus status = + op.execute({&input, &paddings}, {&z}, {0}, {0}, {}); // constant + // z.printIndexedBuffer(); - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(exp.isSameShapeStrict(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(exp.isSameShapeStrict(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, pad_tests28) { + NDArray input('c', {1, 111, 111, 32}, sd::DataType::FLOAT32); + NDArray paddings('c', {4, 2}, {0, 0, 0, 1, 0, 1, 0, 0}, sd::DataType::INT32); + NDArray z('c', {1, 112, 112, 32}, sd::DataType::FLOAT32); + input = 1.; - NDArray input('c', {1,111,111,32}, sd::DataType::FLOAT32); - NDArray paddings('c', {4,2}, {0,0,0,1,0,1,0,0}, sd::DataType::INT32); - NDArray z('c', {1,112,112,32}, sd::DataType::FLOAT32); - input = 1.; + sd::ops::pad op; + Nd4jStatus status = + op.execute({&input, &paddings}, {&z}, {0}, {0}, {}); // constant + // z.printIndexedBuffer(); - sd::ops::pad op; - Nd4jStatus status = op.execute({&input, &paddings}, {&z}, {0}, {0}, {}); // constant - // z.printIndexedBuffer(); + NDArray sum = z.reduceNumber(sd::reduce::Sum); - NDArray sum = z.reduceNumber(sd::reduce::Sum); - - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_EQ(sum.e(0), 111*111*32); + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_EQ(sum.e(0), 111 * 111 * 32); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, pad_tests29) { + auto in = NDArrayFactory::create({1., 1., 1., 1., 1.}); + // auto pad = NDArrayFactory::create('c', {1, 2}, {1., 1.});// = + // Nd4j.create(new double[]{1, 1}, new long[]{1, 2}); + auto pad = NDArrayFactory::create('c', {1, 2}, {1, 1}); + // auto value(10.0); - auto in = NDArrayFactory::create({1., 1., 1., 1., 1.}); -// auto pad = NDArrayFactory::create('c', {1, 2}, {1., 1.});// = Nd4j.create(new double[]{1, 1}, new long[]{1, 2}); - auto pad = NDArrayFactory::create('c', {1, 2}, {1, 1}); -// auto value(10.0); - - auto exp = NDArrayFactory::create({10., 1., 1., 1., 1., 1., 10.}); + auto exp = NDArrayFactory::create({10., 1., 1., 1., 1., 1., 10.}); - sd::ops::pad op; + sd::ops::pad op; - auto res = op.evaluate({&in, &pad}, {10.0}, {0}); - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - ASSERT_TRUE(exp.equalsTo(res.at(0))); - + auto res = op.evaluate({&in, &pad}, {10.0}, {0}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + ASSERT_TRUE(exp.equalsTo(res.at(0))); } - //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, pad_tests30) { + auto in = NDArrayFactory::create({1., 11., 111., 11., 1.}); + auto pad = NDArrayFactory::create('c', {1, 2}, {1, 1}); - auto in = NDArrayFactory::create({1., 11., 111., 11., 1.}); - auto pad = NDArrayFactory::create('c', {1, 2}, {1, 1}); - - auto exp = NDArrayFactory::create({1., 1., 11., 111., 11., 1., 1.}); + auto exp = NDArrayFactory::create({1., 1., 11., 111., 11., 1., 1.}); - sd::ops::pad op; + sd::ops::pad op; - auto res = op.evaluate({&in, &pad}, {10.0}, {2}); - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - ASSERT_TRUE(exp.equalsTo(res.at(0))); - + auto res = op.evaluate({&in, &pad}, {10.0}, {2}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + ASSERT_TRUE(exp.equalsTo(res.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, pad_tests31) { + auto in = NDArrayFactory::create({1., 11., 111., 1111., 11111.}); + // auto pad = NDArrayFactory::create('c', {1, 2}, {1., 1.});// = + // Nd4j.create(new double[]{1, 1}, new long[]{1, 2}); + auto pad = NDArrayFactory::create('c', {1, 2}, {1, 1}); + // auto value(10.0); - auto in = NDArrayFactory::create({1., 11., 111., 1111., 11111.}); -// auto pad = NDArrayFactory::create('c', {1, 2}, {1., 1.});// = Nd4j.create(new double[]{1, 1}, new long[]{1, 2}); - auto pad = NDArrayFactory::create('c', {1, 2}, {1, 1}); -// auto value(10.0); - - auto exp = NDArrayFactory::create({11., 1., 11., 111., 1111., 11111., 1111.}); + auto exp = NDArrayFactory::create( + {11., 1., 11., 111., 1111., 11111., 1111.}); - sd::ops::pad op; + sd::ops::pad op; - auto res = op.evaluate({&in, &pad}, {10.0}, {1}); - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - ASSERT_TRUE(exp.equalsTo(res.at(0))); - + auto res = op.evaluate({&in, &pad}, {10.0}, {1}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + ASSERT_TRUE(exp.equalsTo(res.at(0))); } /////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, pad_tests32) { + auto in = NDArrayFactory::create('c', {3, 3}, + {1., 2., 3., 4., 5., 6, 7, 8, 9}); + auto pad = NDArrayFactory::create('c', {2, 2}, {1, 2, 2, 3}); - auto in = NDArrayFactory::create('c', {3,3}, {1., 2., 3., 4., 5.,6,7,8,9}); - auto pad = NDArrayFactory::create('c', {2,2}, {1, 2, 2, 3}); + auto exp = NDArrayFactory::create( + 'c', {6, 8}, + {2, 1, 1, 2, 3, 3, 2, 1, 2, 1, 1, 2, 3, 3, 2, 1, 5, 4, 4, 5, 6, 6, 5, 4, + 8, 7, 7, 8, 9, 9, 8, 7, 8, 7, 7, 8, 9, 9, 8, 7, 5, 4, 4, 5, 6, 6, 5, 4}); - auto exp = NDArrayFactory::create('c', {6,8}, {2, 1, 1, 2, 3, 3, 2, 1, 2, 1, 1, 2, 3, 3, 2, 1, 5, 4, 4, 5, 6, 6, 5, 4, 8, 7, 7, 8, 9, 9, 8, 7, 8, 7, 7, 8, 9, 9, 8, 7, 5, 4, 4, 5, 6, 6, 5, 4}); + sd::ops::pad op; - sd::ops::pad op; - - auto res = op.evaluate({&in, &pad}, {10.0}, {2}); - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - ASSERT_TRUE(exp.equalsTo(res.at(0))); - + auto res = op.evaluate({&in, &pad}, {10.0}, {2}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + ASSERT_TRUE(exp.equalsTo(res.at(0))); } /////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, pad_tests33) { - - auto in = NDArrayFactory::create('c', {2,3,4}, {1, 2, 3, 4,5, 6, 7, 8,9,10,11,12,13, 14, 15, 16,17, 18, 19, 20,21, 22, 23, 24}); - - auto pad = NDArrayFactory::create('c', {3,2}, {1, 2, 2, 3, 3,3}); - - auto exp = NDArrayFactory::create('c', {5,8,10}, { 7,6,5,5,6,7,8,8,7,6., 3,2,1,1,2,3,4,4,3,2., 3,2,1,1,2,3,4,4,3,2., 7,6,5,5,6,7,8,8,7,6., 11,10,9,9,10,11,12,12,11,10., - 11,10,9,9,10,11,12,12,11,10., 7,6,5,5,6,7,8,8,7,6., 3,2,1,1,2,3,4,4,3,2., 7,6,5,5,6,7,8,8,7,6., 3,2,1,1,2,3,4,4,3,2., - 3,2,1,1,2,3,4,4,3,2., 7,6,5,5,6,7,8,8,7,6., 11,10,9,9,10,11,12,12,11,10., 11,10,9,9,10,11,12,12,11,10.,7,6,5,5,6,7,8,8,7,6., - 3,2,1,1,2,3,4,4,3,2., 19,18,17,17,18,19,20,20,19,18., 15,14,13,13,14,15,16,16,15,14., 15,14,13,13,14,15,16,16,15,14., - 19,18,17,17,18,19,20,20,19,18., 23,22,21,21,22,23,24,24,23,22., 23,22,21,21,22,23,24,24,23,22., 19,18,17,17,18,19,20,20,19,18., - 15,14,13,13,14,15,16,16,15,14., 19,18,17,17,18,19,20,20,19,18., 15,14,13,13,14,15,16,16,15,14., 15,14,13,13,14,15,16,16,15,14., - 19,18,17,17,18,19,20,20,19,18., 23,22,21,21,22,23,24,24,23,22., 23,22,21,21,22,23,24,24,23,22., 19,18,17,17,18,19,20,20,19,18., - 15,14,13,13,14,15,16,16,15,14., 7,6,5,5,6,7,8,8,7,6., 3,2,1,1,2,3,4,4,3,2., 3,2,1,1,2,3,4,4,3,2., 7,6,5,5,6,7,8,8,7,6., - 11,10,9,9,10,11,12,12,11,10., 11,10,9,9,10,11,12,12,11,10., 7,6,5,5,6,7,8,8,7,6., 3,2,1,1,2,3,4,4,3,2.}); - sd::ops::pad op; - - auto res = op.evaluate({&in, &pad}, {10.0}, {2}); - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - ASSERT_TRUE(exp.equalsTo(res.at(0))); - + auto in = NDArrayFactory::create( + 'c', {2, 3, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}); + + auto pad = NDArrayFactory::create('c', {3, 2}, {1, 2, 2, 3, 3, 3}); + + auto exp = NDArrayFactory::create( + 'c', {5, 8, 10}, + {7, 6, 5, 5, 6, 7, 8, 8, 7, 6., 3, 2, 1, 1, 2, 3, + 4, 4, 3, 2., 3, 2, 1, 1, 2, 3, 4, 4, 3, 2., 7, 6, + 5, 5, 6, 7, 8, 8, 7, 6., 11, 10, 9, 9, 10, 11, 12, 12, + 11, 10., 11, 10, 9, 9, 10, 11, 12, 12, 11, 10., 7, 6, 5, 5, + 6, 7, 8, 8, 7, 6., 3, 2, 1, 1, 2, 3, 4, 4, 3, 2., + 7, 6, 5, 5, 6, 7, 8, 8, 7, 6., 3, 2, 1, 1, 2, 3, + 4, 4, 3, 2., 3, 2, 1, 1, 2, 3, 4, 4, 3, 2., 7, 6, + 5, 5, 6, 7, 8, 8, 7, 6., 11, 10, 9, 9, 10, 11, 12, 12, + 11, 10., 11, 10, 9, 9, 10, 11, 12, 12, 11, 10., 7, 6, 5, 5, + 6, 7, 8, 8, 7, 6., 3, 2, 1, 1, 2, 3, 4, 4, 3, 2., + 19, 18, 17, 17, 18, 19, 20, 20, 19, 18., 15, 14, 13, 13, 14, 15, + 16, 16, 15, 14., 15, 14, 13, 13, 14, 15, 16, 16, 15, 14., 19, 18, + 17, 17, 18, 19, 20, 20, 19, 18., 23, 22, 21, 21, 22, 23, 24, 24, + 23, 22., 23, 22, 21, 21, 22, 23, 24, 24, 23, 22., 19, 18, 17, 17, + 18, 19, 20, 20, 19, 18., 15, 14, 13, 13, 14, 15, 16, 16, 15, 14., + 19, 18, 17, 17, 18, 19, 20, 20, 19, 18., 15, 14, 13, 13, 14, 15, + 16, 16, 15, 14., 15, 14, 13, 13, 14, 15, 16, 16, 15, 14., 19, 18, + 17, 17, 18, 19, 20, 20, 19, 18., 23, 22, 21, 21, 22, 23, 24, 24, + 23, 22., 23, 22, 21, 21, 22, 23, 24, 24, 23, 22., 19, 18, 17, 17, + 18, 19, 20, 20, 19, 18., 15, 14, 13, 13, 14, 15, 16, 16, 15, 14., + 7, 6, 5, 5, 6, 7, 8, 8, 7, 6., 3, 2, 1, 1, 2, 3, + 4, 4, 3, 2., 3, 2, 1, 1, 2, 3, 4, 4, 3, 2., 7, 6, + 5, 5, 6, 7, 8, 8, 7, 6., 11, 10, 9, 9, 10, 11, 12, 12, + 11, 10., 11, 10, 9, 9, 10, 11, 12, 12, 11, 10., 7, 6, 5, 5, + 6, 7, 8, 8, 7, 6., 3, 2, 1, 1, 2, 3, 4, 4, 3, 2.}); + sd::ops::pad op; + + auto res = op.evaluate({&in, &pad}, {10.0}, {2}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + ASSERT_TRUE(exp.equalsTo(res.at(0))); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, pad_tests34) { + NDArray input('c', {5}, {0.778786, 0.801198, 0.724375, 0.230894, 0.727141}, + sd::DataType::FLOAT32); + NDArray paddings('c', {1, 2}, {1, 1}, sd::DataType::INT32); + NDArray expected('c', {7}, + {10., 0.778786, 0.801198, 0.724375, 0.230894, 0.727141, 10.}, + sd::DataType::FLOAT32); + NDArray z('c', {7}, sd::DataType::FLOAT32); - NDArray input('c', {5}, {0.778786, 0.801198, 0.724375, 0.230894, 0.727141}, sd::DataType::FLOAT32); - NDArray paddings('c', {1,2}, {1,1}, sd::DataType::INT32); - NDArray expected('c', {7}, {10., 0.778786, 0.801198, 0.724375, 0.230894, 0.727141, 10.}, sd::DataType::FLOAT32); - NDArray z('c', {7}, sd::DataType::FLOAT32); - - sd::ops::pad op; - Nd4jStatus status = op.execute({&input, &paddings}, {&z}, {10}, {0}, {}); // constant + sd::ops::pad op; + Nd4jStatus status = + op.execute({&input, &paddings}, {&z}, {10}, {0}, {}); // constant - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(expected.isSameShapeStrict(z)); - ASSERT_TRUE(expected.equalsTo(z)); + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.isSameShapeStrict(z)); + ASSERT_TRUE(expected.equalsTo(z)); } //////////////////////////////////////////////////////////////////// // CONSTANT mode 2D TEST_F(DeclarableOpsTests12, Pad_1) { + double inBuff[] = {1, 2, 3, 4, 5, 6}; + int padBuff[] = {1, 1, 2, 2}; + double expBuff[] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 0, 0, + 0, 0, 4, 5, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0}; - double inBuff[] = {1,2,3,4,5,6}; - int padBuff[] = {1,1,2,2}; - double expBuff[] = {0,0,0,0,0,0,0, 0,0,1,2,3,0,0, 0,0,4,5,6,0,0, 0,0,0,0,0,0,0}; - - auto input = NDArrayFactory::create(inBuff, 'c', {2,3}); - auto paddings = NDArrayFactory::create(padBuff, 'c', {2,2}); - auto expected = NDArrayFactory::create(expBuff, 'c', {4,7}); + auto input = NDArrayFactory::create(inBuff, 'c', {2, 3}); + auto paddings = NDArrayFactory::create(padBuff, 'c', {2, 2}); + auto expected = NDArrayFactory::create(expBuff, 'c', {4, 7}); - sd::ops::pad op; - auto results = op.evaluate({&input, &paddings}, {}, {0}); + sd::ops::pad op; + auto results = op.evaluate({&input, &paddings}, {}, {0}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); - // result->printIndexedBuffer(); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(expected.isSameShapeStrict(*result)); - ASSERT_TRUE(expected.equalsTo(result)); + auto result = results.at(0); + // result->printIndexedBuffer(); - + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } - //////////////////////////////////////////////////////////////////// // REFLECT mode 2D TEST_F(DeclarableOpsTests12, Pad_2) { + double inBuff[] = {1, 2, 3, 4, 5, 6}; + int padBuff[] = {1, 1, 2, 2}; + double expBuff[] = {6, 5, 4, 5, 6, 5, 4, 3, 2, 1, 2, 3, 2, 1, + 6, 5, 4, 5, 6, 5, 4, 3, 2, 1, 2, 3, 2, 1}; - double inBuff[] = {1,2,3,4,5,6}; - int padBuff[] = {1,1,2,2}; - double expBuff[] = {6,5,4,5,6,5,4, 3,2,1,2,3,2,1, 6,5,4,5,6,5,4, 3,2,1,2,3,2,1}; + auto input = NDArrayFactory::create(inBuff, 'c', {2, 3}); + auto paddings = NDArrayFactory::create(padBuff, 'c', {2, 2}); + auto expected = NDArrayFactory::create(expBuff, 'c', {4, 7}); - auto input = NDArrayFactory::create(inBuff, 'c', {2,3}); - auto paddings = NDArrayFactory::create(padBuff, 'c', {2,2}); - auto expected = NDArrayFactory::create(expBuff, 'c', {4,7}); - - sd::ops::pad op; - auto results = op.evaluate({&input, &paddings}, {}, {1}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::pad op; + auto results = op.evaluate({&input, &paddings}, {}, {1}); - auto result = results.at(0); - // result->printIndexedBuffer(); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(expected.isSameShapeStrict(*result)); - ASSERT_TRUE(expected.equalsTo(result)); + auto result = results.at(0); + // result->printIndexedBuffer(); - + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } - //////////////////////////////////////////////////////////////////// // SYMMETRIC mode 2D TEST_F(DeclarableOpsTests12, Pad_3) { + double inBuff[] = {1, 2, 3, 4, 5, 6}; + int padBuff[] = {1, 1, 2, 2}; + double expBuff[] = {2, 1, 1, 2, 3, 3, 2, 2, 1, 1, 2, 3, 3, 2, + 5, 4, 4, 5, 6, 6, 5, 5, 4, 4, 5, 6, 6, 5}; - double inBuff[] = {1,2,3,4,5,6}; - int padBuff[] = {1,1,2,2}; - double expBuff[] = {2,1,1,2,3,3,2, 2,1,1,2,3,3,2, 5,4,4,5,6,6,5, 5,4,4,5,6,6,5}; + auto input = NDArrayFactory::create(inBuff, 'c', {2, 3}); + auto paddings = NDArrayFactory::create(padBuff, 'c', {2, 2}); + auto expected = NDArrayFactory::create(expBuff, 'c', {4, 7}); - auto input = NDArrayFactory::create(inBuff, 'c', {2,3}); - auto paddings = NDArrayFactory::create(padBuff, 'c', {2,2}); - auto expected = NDArrayFactory::create(expBuff, 'c', {4,7}); + sd::ops::pad op; + auto results = op.evaluate({&input, &paddings}, {}, {2}); - sd::ops::pad op; - auto results = op.evaluate({&input, &paddings}, {}, {2}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto result = results.at(0); - // result->printIndexedBuffer(); + auto result = results.at(0); + // result->printIndexedBuffer(); - ASSERT_TRUE(expected.isSameShapeStrict(*result)); - ASSERT_TRUE(expected.equalsTo(result)); - - + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } - //////////////////////////////////////////////////////////////////// // CONSTANT mode 3D TEST_F(DeclarableOpsTests12, Pad_4) { + double inBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 16, 17, 18}; + int padBuff[] = {1, 1, 2, 2, 2, 2}; + double expBuff[] = { + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, + 2, 3, 0, 0, 0, 0, 4, 5, 6, 0, 0, 0, 0, 7, 8, 9, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 10, 11, 12, 0, 0, 0, 0, 13, 14, 15, 0, 0, 0, 0, 16, 17, 18, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; - double inBuff[] = {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18}; - int padBuff[] = {1,1,2,2,2,2}; - double expBuff[] = {0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 1, 2, 3,0,0,0,0, 4, 5, 6,0,0,0,0, 7, 8, 9,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0,10,11,12,0,0,0,0,13,14,15,0,0,0,0,16,17,18,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0}; + auto input = NDArrayFactory::create(inBuff, 'c', {2, 3, 3}); + auto paddings = NDArrayFactory::create(padBuff, 'c', {3, 2}); + auto expected = NDArrayFactory::create(expBuff, 'c', {4, 7, 7}); - auto input = NDArrayFactory::create(inBuff, 'c', {2,3,3}); - auto paddings = NDArrayFactory::create(padBuff, 'c', {3,2}); - auto expected = NDArrayFactory::create(expBuff, 'c', {4,7,7}); + sd::ops::pad op; + auto results = op.evaluate({&input, &paddings}, {}, {0}); - sd::ops::pad op; - auto results = op.evaluate({&input, &paddings}, {}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); - // result->printIndexedBuffer(); - - ASSERT_TRUE(expected.isSameShapeStrict(*result)); - ASSERT_TRUE(expected.equalsTo(result)); + auto result = results.at(0); + // result->printIndexedBuffer(); - + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } - - //////////////////////////////////////////////////////////////////// // REFLECT mode 3D TEST_F(DeclarableOpsTests12, Pad_5) { - - double inBuff[] = {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18}; - int padBuff[] = {1,1,2,2,2,2}; - double expBuff[] = {18,17,16,17,18,17,16, 15,14,13,14,15,14,13, 12,11,10,11,12,11,10, 15,14,13,14,15,14,13, 18,17,16,17,18,17,16, 15,14,13,14,15,14,13, 12,11,10,11,12,11,10, 9, 8, 7, 8, 9, 8, 7, 6, 5, 4, 5, 6, 5, 4, 3, 2, 1, 2, 3, 2, 1, 6, 5, 4, 5, 6, 5, 4, 9, 8, 7, 8, 9, 8, 7, 6, 5, 4, 5, 6, 5, 4, 3, 2, 1, 2, 3, 2, 1, 18,17,16,17,18,17,16, 15,14,13,14,15,14,13, 12,11,10,11,12,11,10, 15,14,13,14,15,14,13, 18,17,16,17,18,17,16, 15,14,13,14,15,14,13, 12,11,10,11,12,11,10, 9, 8, 7, 8, 9, 8, 7, 6, 5, 4, 5, 6, 5, 4, 3, 2, 1, 2, 3, 2, 1, 6, 5, 4, 5, 6, 5, 4, 9, 8, 7, 8, 9, 8, 7, 6, 5, 4, 5, 6, 5, 4, 3, 2, 1, 2, 3, 2, 1}; - auto input = NDArrayFactory::create(inBuff, 'c', {2,3,3}); - auto paddings = NDArrayFactory::create(padBuff, 'c', {3,2}); - auto expected = NDArrayFactory::create(expBuff, 'c', {4,7,7}); - - sd::ops::pad op; - auto results = op.evaluate({&input, &paddings}, {}, {1}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); - // result->printIndexedBuffer(); - - ASSERT_TRUE(expected.isSameShapeStrict(*result)); - ASSERT_TRUE(expected.equalsTo(result)); - - + double inBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 16, 17, 18}; + int padBuff[] = {1, 1, 2, 2, 2, 2}; + double expBuff[] = { + 18, 17, 16, 17, 18, 17, 16, 15, 14, 13, 14, 15, 14, 13, 12, 11, 10, 11, + 12, 11, 10, 15, 14, 13, 14, 15, 14, 13, 18, 17, 16, 17, 18, 17, 16, 15, + 14, 13, 14, 15, 14, 13, 12, 11, 10, 11, 12, 11, 10, 9, 8, 7, 8, 9, + 8, 7, 6, 5, 4, 5, 6, 5, 4, 3, 2, 1, 2, 3, 2, 1, 6, 5, + 4, 5, 6, 5, 4, 9, 8, 7, 8, 9, 8, 7, 6, 5, 4, 5, 6, 5, + 4, 3, 2, 1, 2, 3, 2, 1, 18, 17, 16, 17, 18, 17, 16, 15, 14, 13, + 14, 15, 14, 13, 12, 11, 10, 11, 12, 11, 10, 15, 14, 13, 14, 15, 14, 13, + 18, 17, 16, 17, 18, 17, 16, 15, 14, 13, 14, 15, 14, 13, 12, 11, 10, 11, + 12, 11, 10, 9, 8, 7, 8, 9, 8, 7, 6, 5, 4, 5, 6, 5, 4, 3, + 2, 1, 2, 3, 2, 1, 6, 5, 4, 5, 6, 5, 4, 9, 8, 7, 8, 9, + 8, 7, 6, 5, 4, 5, 6, 5, 4, 3, 2, 1, 2, 3, 2, 1}; + auto input = NDArrayFactory::create(inBuff, 'c', {2, 3, 3}); + auto paddings = NDArrayFactory::create(padBuff, 'c', {3, 2}); + auto expected = NDArrayFactory::create(expBuff, 'c', {4, 7, 7}); + + sd::ops::pad op; + auto results = op.evaluate({&input, &paddings}, {}, {1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + // result->printIndexedBuffer(); + + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } - //////////////////////////////////////////////////////////////////// // SYMMETRIC mode 3D TEST_F(DeclarableOpsTests12, Pad_6) { - - double inBuff[] = {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18}; - int padBuff[] = {1,1,2,2,2,2}; - double expBuff[] = {5, 4, 4, 5, 6, 6, 5, 2, 1, 1, 2, 3, 3, 2, 2, 1, 1, 2, 3, 3, 2, 5, 4, 4, 5, 6, 6, 5, 8, 7, 7, 8, 9, 9, 8, 8, 7, 7, 8, 9, 9, 8, 5, 4, 4, 5, 6, 6, 5, 5, 4, 4, 5, 6, 6, 5, 2, 1, 1, 2, 3, 3, 2, 2, 1, 1, 2, 3, 3, 2, 5, 4, 4, 5, 6, 6, 5, 8, 7, 7, 8, 9, 9, 8, 8, 7, 7, 8, 9, 9, 8, 5, 4, 4, 5, 6, 6, 5, 14,13,13,14,15,15,14, 11,10,10,11,12,12,11, 11,10,10,11,12,12,11, 14,13,13,14,15,15,14, 17,16,16,17,18,18,17, 17,16,16,17,18,18,17, 14,13,13,14,15,15,14, 14,13,13,14,15,15,14, 11,10,10,11,12,12,11, 11,10,10,11,12,12,11, 14,13,13,14,15,15,14, 17,16,16,17,18,18,17, 17,16,16,17,18,18,17, 14,13,13,14,15,15,14}; - - auto input = NDArrayFactory::create(inBuff, 'c', {2,3,3}); - auto paddings = NDArrayFactory::create(padBuff, 'c', {3,2}); - auto expected = NDArrayFactory::create(expBuff, 'c', {4,7,7}); - - sd::ops::pad op; - auto results = op.evaluate({&input, &paddings}, {}, {2}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); - // result->printIndexedBuffer(); - - ASSERT_TRUE(expected.isSameShapeStrict(*result)); - ASSERT_TRUE(expected.equalsTo(result)); - - + double inBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 16, 17, 18}; + int padBuff[] = {1, 1, 2, 2, 2, 2}; + double expBuff[] = { + 5, 4, 4, 5, 6, 6, 5, 2, 1, 1, 2, 3, 3, 2, 2, 1, 1, 2, + 3, 3, 2, 5, 4, 4, 5, 6, 6, 5, 8, 7, 7, 8, 9, 9, 8, 8, + 7, 7, 8, 9, 9, 8, 5, 4, 4, 5, 6, 6, 5, 5, 4, 4, 5, 6, + 6, 5, 2, 1, 1, 2, 3, 3, 2, 2, 1, 1, 2, 3, 3, 2, 5, 4, + 4, 5, 6, 6, 5, 8, 7, 7, 8, 9, 9, 8, 8, 7, 7, 8, 9, 9, + 8, 5, 4, 4, 5, 6, 6, 5, 14, 13, 13, 14, 15, 15, 14, 11, 10, 10, + 11, 12, 12, 11, 11, 10, 10, 11, 12, 12, 11, 14, 13, 13, 14, 15, 15, 14, + 17, 16, 16, 17, 18, 18, 17, 17, 16, 16, 17, 18, 18, 17, 14, 13, 13, 14, + 15, 15, 14, 14, 13, 13, 14, 15, 15, 14, 11, 10, 10, 11, 12, 12, 11, 11, + 10, 10, 11, 12, 12, 11, 14, 13, 13, 14, 15, 15, 14, 17, 16, 16, 17, 18, + 18, 17, 17, 16, 16, 17, 18, 18, 17, 14, 13, 13, 14, 15, 15, 14}; + + auto input = NDArrayFactory::create(inBuff, 'c', {2, 3, 3}); + auto paddings = NDArrayFactory::create(padBuff, 'c', {3, 2}); + auto expected = NDArrayFactory::create(expBuff, 'c', {4, 7, 7}); + + sd::ops::pad op; + auto results = op.evaluate({&input, &paddings}, {}, {2}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + // result->printIndexedBuffer(); + + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// // CONSTANT mode 4D -TEST_F(DeclarableOpsTests12, Pad_7) -{ - - double inBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; - int padBuff[] = {1, 1, 1, 1, 1, 1, 1, 1}; - double expBuff[] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0, 3, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 6, 0, 0, 7, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 9, 10, 0, 0, 11, 12, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 13, 14, 0, 0, 15, 16, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; - auto input = NDArrayFactory::create(inBuff, 'c', {2, 2, 2, 2}); - auto paddings = NDArrayFactory::create(padBuff, 'c', {4, 2}); - auto expected = NDArrayFactory::create(expBuff, 'c', {4, 4, 4, 4}); - - sd::ops::pad op; - auto results = op.evaluate({&input, &paddings}, {}, {0}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto *result = results.at(0); - // result->printIndexedBuffer(); - - ASSERT_TRUE(expected.isSameShapeStrict(*result)); - ASSERT_TRUE(expected.equalsTo(result)); - - +TEST_F(DeclarableOpsTests12, Pad_7) { + double inBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; + int padBuff[] = {1, 1, 1, 1, 1, 1, 1, 1}; + double expBuff[] = { + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, + 0, 3, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 6, 0, 0, 7, 8, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 9, 10, 0, 0, 11, + 12, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 13, 14, 0, 0, 15, 16, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + auto input = NDArrayFactory::create(inBuff, 'c', {2, 2, 2, 2}); + auto paddings = NDArrayFactory::create(padBuff, 'c', {4, 2}); + auto expected = NDArrayFactory::create(expBuff, 'c', {4, 4, 4, 4}); + + sd::ops::pad op; + auto results = op.evaluate({&input, &paddings}, {}, {0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + // result->printIndexedBuffer(); + + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// // REFLECT mode 4D -TEST_F(DeclarableOpsTests12, Pad_8) -{ - - double inBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; - int padBuff[] = {1, 1, 1, 1, 1, 1, 1, 1}; - double expBuff[] = {16, 15, 16, 15, 14, 13, 14, 13, 16, 15, 16, 15, 14, 13, 14, 13, 12, 11, 12, 11, 10, 9, 10, 9, 12, 11, 12, 11, 10, 9, 10, 9, 16, 15, 16, 15, 14, 13, 14, 13, 16, 15, 16, 15, 14, 13, 14, 13, 12, 11, 12, 11, 10, 9, 10, 9, 12, 11, 12, 11, 10, 9, 10, 9, 8, 7, 8, 7, 6, 5, 6, 5, 8, 7, 8, 7, 6, 5, 6, 5, 4, 3, 4, 3, 2, 1, 2, 1, 4, 3, 4, 3, 2, 1, 2, 1, 8, 7, 8, 7, 6, 5, 6, 5, 8, 7, 8, 7, 6, 5, 6, 5, 4, 3, 4, 3, 2, 1, 2, 1, 4, 3, 4, 3, 2, 1, 2, 1, 16, 15, 16, 15, 14, 13, 14, 13, 16, 15, 16, 15, 14, 13, 14, 13, 12, 11, 12, 11, 10, 9, 10, 9, 12, 11, 12, 11, 10, 9, 10, 9, 16, 15, 16, 15, 14, 13, 14, 13, 16, 15, 16, 15, 14, 13, 14, 13, 12, 11, 12, 11, 10, 9, 10, 9, 12, 11, 12, 11, 10, 9, 10, 9, 8, 7, 8, 7, 6, 5, 6, 5, 8, 7, 8, 7, 6, 5, 6, 5, 4, 3, 4, 3, 2, 1, 2, 1, 4, 3, 4, 3, 2, 1, 2, 1, 8, 7, 8, 7, 6, 5, 6, 5, 8, 7, 8, 7, 6, 5, 6, 5, 4, 3, 4, 3, 2, 1, 2, 1, 4, 3, 4, 3, 2, 1, 2, 1}; - auto input = NDArrayFactory::create(inBuff, 'c', {2, 2, 2, 2}); - auto paddings = NDArrayFactory::create(padBuff, 'c', {4, 2}); - auto expected = NDArrayFactory::create(expBuff, 'c', {4, 4, 4, 4}); - - sd::ops::pad op; - auto results = op.evaluate({&input, &paddings}, {}, {1}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto *result = results.at(0); - // result->printIndexedBuffer(); - - ASSERT_TRUE(expected.isSameShapeStrict(*result)); - ASSERT_TRUE(expected.equalsTo(result)); - - +TEST_F(DeclarableOpsTests12, Pad_8) { + double inBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; + int padBuff[] = {1, 1, 1, 1, 1, 1, 1, 1}; + double expBuff[] = { + 16, 15, 16, 15, 14, 13, 14, 13, 16, 15, 16, 15, 14, 13, 14, 13, 12, 11, + 12, 11, 10, 9, 10, 9, 12, 11, 12, 11, 10, 9, 10, 9, 16, 15, 16, 15, + 14, 13, 14, 13, 16, 15, 16, 15, 14, 13, 14, 13, 12, 11, 12, 11, 10, 9, + 10, 9, 12, 11, 12, 11, 10, 9, 10, 9, 8, 7, 8, 7, 6, 5, 6, 5, + 8, 7, 8, 7, 6, 5, 6, 5, 4, 3, 4, 3, 2, 1, 2, 1, 4, 3, + 4, 3, 2, 1, 2, 1, 8, 7, 8, 7, 6, 5, 6, 5, 8, 7, 8, 7, + 6, 5, 6, 5, 4, 3, 4, 3, 2, 1, 2, 1, 4, 3, 4, 3, 2, 1, + 2, 1, 16, 15, 16, 15, 14, 13, 14, 13, 16, 15, 16, 15, 14, 13, 14, 13, + 12, 11, 12, 11, 10, 9, 10, 9, 12, 11, 12, 11, 10, 9, 10, 9, 16, 15, + 16, 15, 14, 13, 14, 13, 16, 15, 16, 15, 14, 13, 14, 13, 12, 11, 12, 11, + 10, 9, 10, 9, 12, 11, 12, 11, 10, 9, 10, 9, 8, 7, 8, 7, 6, 5, + 6, 5, 8, 7, 8, 7, 6, 5, 6, 5, 4, 3, 4, 3, 2, 1, 2, 1, + 4, 3, 4, 3, 2, 1, 2, 1, 8, 7, 8, 7, 6, 5, 6, 5, 8, 7, + 8, 7, 6, 5, 6, 5, 4, 3, 4, 3, 2, 1, 2, 1, 4, 3, 4, 3, + 2, 1, 2, 1}; + auto input = NDArrayFactory::create(inBuff, 'c', {2, 2, 2, 2}); + auto paddings = NDArrayFactory::create(padBuff, 'c', {4, 2}); + auto expected = NDArrayFactory::create(expBuff, 'c', {4, 4, 4, 4}); + + sd::ops::pad op; + auto results = op.evaluate({&input, &paddings}, {}, {1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + // result->printIndexedBuffer(); + + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } ////////////////////////////////////////////////////////////////// // SYMMETRIC mode 4D -TEST_F(DeclarableOpsTests12, Pad_9) -{ - - double inBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; - int padBuff[] = {1, 1, 1, 1, 1, 1, 1, 1}; - double expBuff[] = {1, 1, 2, 2, 1, 1, 2, 2, 3, 3, 4, 4, 3, 3, 4, 4, 1, 1, 2, 2, 1, 1, 2, 2, 3, 3, 4, 4, 3, 3, 4, 4, 5, 5, 6, 6, 5, 5, 6, 6, 7, 7, 8, 8, 7, 7, 8, 8, 5, 5, 6, 6, 5, 5, 6, 6, 7, 7, 8, 8, 7, 7, 8, 8, 1, 1, 2, 2, 1, 1, 2, 2, 3, 3, 4, 4, 3, 3, 4, 4, 1, 1, 2, 2, 1, 1, 2, 2, 3, 3, 4, 4, 3, 3, 4, 4, 5, 5, 6, 6, 5, 5, 6, 6, 7, 7, 8, 8, 7, 7, 8, 8, 5, 5, 6, 6, 5, 5, 6, 6, 7, 7, 8, 8, 7, 7, 8, 8, 9, 9, 10, 10, 9, 9, 10, 10, 11, 11, 12, 12, 11, 11, 12, 12, 9, 9, 10, 10, 9, 9, 10, 10, 11, 11, 12, 12, 11, 11, 12, 12, 13, 13, 14, 14, 13, 13, 14, 14, 15, 15, 16, 16, 15, 15, 16, 16, 13, 13, 14, 14, 13, 13, 14, 14, 15, 15, 16, 16, 15, 15, 16, 16, 9, 9, 10, 10, 9, 9, 10, 10, 11, 11, 12, 12, 11, 11, 12, 12, 9, 9, 10, 10, 9, 9, 10, 10, 11, 11, 12, 12, 11, 11, 12, 12, 13, 13, 14, 14, 13, 13, 14, 14, 15, 15, 16, 16, 15, 15, 16, 16, 13, 13, 14, 14, 13, 13, 14, 14, 15, 15, 16, 16, 15, 15, 16, 16}; - auto input = NDArrayFactory::create(inBuff, 'c', {2, 2, 2, 2}); - auto paddings = NDArrayFactory::create(padBuff, 'c', {4, 2}); - auto expected = NDArrayFactory::create(expBuff, 'c', {4, 4, 4, 4}); - - sd::ops::pad op; - auto results = op.evaluate({&input, &paddings}, {}, {2}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto *result = results.at(0); - // result->printIndexedBuffer(); - - ASSERT_TRUE(expected.isSameShapeStrict(*result)); - ASSERT_TRUE(expected.equalsTo(result)); - - +TEST_F(DeclarableOpsTests12, Pad_9) { + double inBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; + int padBuff[] = {1, 1, 1, 1, 1, 1, 1, 1}; + double expBuff[] = { + 1, 1, 2, 2, 1, 1, 2, 2, 3, 3, 4, 4, 3, 3, 4, 4, 1, 1, + 2, 2, 1, 1, 2, 2, 3, 3, 4, 4, 3, 3, 4, 4, 5, 5, 6, 6, + 5, 5, 6, 6, 7, 7, 8, 8, 7, 7, 8, 8, 5, 5, 6, 6, 5, 5, + 6, 6, 7, 7, 8, 8, 7, 7, 8, 8, 1, 1, 2, 2, 1, 1, 2, 2, + 3, 3, 4, 4, 3, 3, 4, 4, 1, 1, 2, 2, 1, 1, 2, 2, 3, 3, + 4, 4, 3, 3, 4, 4, 5, 5, 6, 6, 5, 5, 6, 6, 7, 7, 8, 8, + 7, 7, 8, 8, 5, 5, 6, 6, 5, 5, 6, 6, 7, 7, 8, 8, 7, 7, + 8, 8, 9, 9, 10, 10, 9, 9, 10, 10, 11, 11, 12, 12, 11, 11, 12, 12, + 9, 9, 10, 10, 9, 9, 10, 10, 11, 11, 12, 12, 11, 11, 12, 12, 13, 13, + 14, 14, 13, 13, 14, 14, 15, 15, 16, 16, 15, 15, 16, 16, 13, 13, 14, 14, + 13, 13, 14, 14, 15, 15, 16, 16, 15, 15, 16, 16, 9, 9, 10, 10, 9, 9, + 10, 10, 11, 11, 12, 12, 11, 11, 12, 12, 9, 9, 10, 10, 9, 9, 10, 10, + 11, 11, 12, 12, 11, 11, 12, 12, 13, 13, 14, 14, 13, 13, 14, 14, 15, 15, + 16, 16, 15, 15, 16, 16, 13, 13, 14, 14, 13, 13, 14, 14, 15, 15, 16, 16, + 15, 15, 16, 16}; + auto input = NDArrayFactory::create(inBuff, 'c', {2, 2, 2, 2}); + auto paddings = NDArrayFactory::create(padBuff, 'c', {4, 2}); + auto expected = NDArrayFactory::create(expBuff, 'c', {4, 4, 4, 4}); + + sd::ops::pad op; + auto results = op.evaluate({&input, &paddings}, {}, {2}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + // result->printIndexedBuffer(); + + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } TEST_F(DeclarableOpsTests12, Test_Expose_1) { - auto input0 = NDArrayFactory::create('c', {2, 3}, {1, 2, 3, 6, 5, 4}); - auto input1 = NDArrayFactory::create('c', {2, 3}, {3, 2, 1, 4, 5, 6}); - - sd::ops::expose op; + auto input0 = NDArrayFactory::create('c', {2, 3}, {1, 2, 3, 6, 5, 4}); + auto input1 = NDArrayFactory::create('c', {2, 3}, {3, 2, 1, 4, 5, 6}); - auto result = op.evaluate({&input0, &input1}); + sd::ops::expose op; - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto result = op.evaluate({&input0, &input1}); - auto z0 = result.at(0); - auto z1 = result.at(1); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(input0.equalsTo(z0)); - ASSERT_TRUE(input1.equalsTo(z1)); + auto z0 = result.at(0); + auto z1 = result.at(1); - + ASSERT_TRUE(input0.equalsTo(z0)); + ASSERT_TRUE(input1.equalsTo(z1)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, Pad_SGO_Test_1) { + auto in = NDArrayFactory::create({1., 1., 1., 1., 1.}); + // auto pad = NDArrayFactory::create('c', {1, 2}, {1., 1.});// = + // Nd4j.create(new double[]{1, 1}, new long[]{1, 2}); + auto pad = NDArrayFactory::create('c', {1, 2}, {1, 1}); + // auto value(10.0); - auto in = NDArrayFactory::create({1., 1., 1., 1., 1.}); -// auto pad = NDArrayFactory::create('c', {1, 2}, {1., 1.});// = Nd4j.create(new double[]{1, 1}, new long[]{1, 2}); - auto pad = NDArrayFactory::create('c', {1, 2}, {1, 1}); -// auto value(10.0); + auto exp = NDArrayFactory::create({10., 1., 1., 1., 1., 1., 10.}); - auto exp = NDArrayFactory::create({10., 1., 1., 1., 1., 1., 10.}); + sd::ops::pad op; - sd::ops::pad op; - - auto res = op.evaluate({&in, &pad}, {10.0}, {0}); - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - // res.at(0)->printIndexedBuffer("PAD_SGO"); - // exp.printIndexedBuffer("PAD_EXP"); - ASSERT_TRUE(exp.equalsTo(res.at(0))); - + auto res = op.evaluate({&in, &pad}, {10.0}, {0}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + // res.at(0)->printIndexedBuffer("PAD_SGO"); + // exp.printIndexedBuffer("PAD_EXP"); + ASSERT_TRUE(exp.equalsTo(res.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, LU_Test_1) { + auto in = NDArrayFactory::create( + 'c', {3, 3}, {1., 2., 3., 0., 2., 3., 0., 0., 7.}); + auto exp = NDArrayFactory::create( + 'c', {3, 3}, {1., 2., 3., 0., 2., 3., 0., 0., 7}); + auto pExp = NDArrayFactory::create('c', {3}, {0, 1, 2}); + sd::ops::lu op; - auto in = NDArrayFactory::create('c', {3,3}, {1., 2., 3., 0., 2., 3., 0., 0., 7.}); - auto exp = NDArrayFactory::create('c', {3,3}, {1., 2., 3., 0., 2., 3., 0., 0., 7}); - auto pExp = NDArrayFactory::create('c', {3}, {0, 1, 2}); - sd::ops::lu op; - - auto res = op.evaluate({&in}); - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - auto z = res.at(0); - auto p = res.at(1); -// z->printIndexedBuffer("Triangulars"); -// p->printIndexedBuffer("Permutaions"); + auto res = op.evaluate({&in}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); + auto p = res.at(1); + // z->printIndexedBuffer("Triangulars"); + // p->printIndexedBuffer("Permutaions"); - ASSERT_TRUE(exp.equalsTo(z)); - ASSERT_TRUE(pExp.equalsTo(p)); - - + ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(pExp.equalsTo(p)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, LU_Test_2) { - auto in = NDArrayFactory::create('c', {3,3}, {1, 0, 0, 2, 3, 0, 4, 5, 6}); + auto in = + NDArrayFactory::create('c', {3, 3}, {1, 0, 0, 2, 3, 0, 4, 5, 6}); - auto expLU = NDArrayFactory::create('c', {3,3}, {4., 5., 6., 0.25, -1.25, -1.5, 0.5, -0.4, -3.6}); - auto expP = NDArrayFactory::create({2, 0, 1}); - sd::ops::lu op; + auto expLU = NDArrayFactory::create( + 'c', {3, 3}, {4., 5., 6., 0.25, -1.25, -1.5, 0.5, -0.4, -3.6}); + auto expP = NDArrayFactory::create({2, 0, 1}); + sd::ops::lu op; - auto res = op.evaluate({&in}); - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - auto z = res.at(0); - auto p = res.at(1); -// z->printIndexedBuffer("Triangulars2"); -// p->printIndexedBuffer("Permutaions2"); - ASSERT_TRUE(expLU.equalsTo(z)); - ASSERT_TRUE(expP.equalsTo(p)); - + auto res = op.evaluate({&in}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); + auto p = res.at(1); + // z->printIndexedBuffer("Triangulars2"); + // p->printIndexedBuffer("Permutaions2"); + ASSERT_TRUE(expLU.equalsTo(z)); + ASSERT_TRUE(expP.equalsTo(p)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, LU_Test_3) { - auto in = NDArrayFactory::create('c', {3,3}, {1,2,3,4,7,9, 11, 12, 13}); + auto in = NDArrayFactory::create('c', {3, 3}, + {1, 2, 3, 4, 7, 9, 11, 12, 13}); - auto expLU = NDArrayFactory::create('c', {3,3}, { - 11., 12., 13., - 0.36363637, 2.6363635, 4.272727, - 0.09090909, 0.3448276, 0.34482753}); + auto expLU = NDArrayFactory::create( + 'c', {3, 3}, + {11., 12., 13., 0.36363637, 2.6363635, 4.272727, 0.09090909, 0.3448276, + 0.34482753}); - auto expP = NDArrayFactory::create({2, 1, 0}); - sd::ops::lu op; + auto expP = NDArrayFactory::create({2, 1, 0}); + sd::ops::lu op; - auto res = op.evaluate({&in}); - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - auto z = res.at(0); - auto p = res.at(1); -// z->printIndexedBuffer("Triangulars3"); -// p->printIndexedBuffer("Permutaions3"); - ASSERT_TRUE(expLU.equalsTo(z)); - ASSERT_TRUE(expP.equalsTo(p)); - + auto res = op.evaluate({&in}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); + auto p = res.at(1); + // z->printIndexedBuffer("Triangulars3"); + // p->printIndexedBuffer("Permutaions3"); + ASSERT_TRUE(expLU.equalsTo(z)); + ASSERT_TRUE(expP.equalsTo(p)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, LU_Test_4) { - - auto in = NDArrayFactory::create('c', {10,10}, { - 1., 2., 3., 4., 5., 6., 7., 8., 1., 15., - 5., 1., 13., 4., 15., 1., 17., 9., 11., 25., - 1., 9., 1., 4., 5., 2., 13., 10, 21., 15., - 3., 9., 4., 1., 5., 3., 7., 1, 1., 5., - 2., 3., 2., 5., 4., 4., 7., 3, 3., 4., - 0., 1., 3., 3., 5., 1., 3., 1, 31., 15., - 2., 1., 4., 3., 1., 5., 1., 2, 31., 35., - 3., 4., 3., 3., 4., 4., 4., 1., 3., 1., - 1., 1., 1., 1., 5., 6., 5., 4., 3., 2., - 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.}); - - auto expLU = NDArrayFactory::create('c', {10,10}, { - 5.0, 1.0, 13.0, 4.0, 15.0, 1.0, 17.0, 9.0, 11.0, 25.0, - 0.2, 8.8, -1.6, 3.2, 2.0, 1.8, 9.6, 8.2, 18.8, 10.0, - 0.6, 0.386364, -4.181818, -0.636364, -5.772727, 2.704545, -9.909091, -7.568182, -10.863636, -17.863636, - 0.6, 0.954545, 0.543478, -4.108696, -2.771739, -0.788043, -6.978261, -8.114130, -17.641304, -9.836957, - 0.4, 0.068182, 0.260870, -0.328042, -4.539683, 3.513228, -6.158730, -2.846561, 22.365079, 25.751323, - 0.2, 0.090909, 0.347826, -0.031746, -0.823427, 7.563520, -1.118881, 1.485431, 20.725524, 23.196387, - 0.0, 0.113636, -0.760870, -0.523810, 0.236014, 0.213036, -7.593805, -9.585099, 1.663379, -15.900300, - 0.4, 0.295455, 0.652174, -0.698413, 0.167832, 0.021727, -0.001360, -3.321530, -16.392106, - 9.022119, - 0.2, 0.204545, -0.173913, -0.592593, 0.232517, 0.610602, 0.277466, -0.244631, -39.715757, -18.928178, - 0.2, 0.090909, 0.347826, -0.031746, 0.057692, -0.070344, -0.030154, -0.243578, 0.087256, 0.112695 - }); - - auto expP = NDArrayFactory::create({1, 2, 7, 3, 6, 8, 5, 4, 0, 9}); - sd::ops::lu op; - - auto res = op.evaluate({&in}); - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - auto z = res.at(0); - auto p = res.at(1); -// z->printBuffer("Triangulars4"); -// expLU.printBuffer("TriangulExp4"); -// p->printBuffer("Permutaions4"); - - ASSERT_TRUE(expLU.equalsTo(z)); - ASSERT_TRUE(expP.equalsTo(p)); - + auto in = NDArrayFactory::create( + 'c', {10, 10}, + {1., 2., 3., 4., 5., 6., 7., 8., 1., 15., 5., 1., 13., 4., 15., + 1., 17., 9., 11., 25., 1., 9., 1., 4., 5., 2., 13., 10, 21., 15., + 3., 9., 4., 1., 5., 3., 7., 1, 1., 5., 2., 3., 2., 5., 4., + 4., 7., 3, 3., 4., 0., 1., 3., 3., 5., 1., 3., 1, 31., 15., + 2., 1., 4., 3., 1., 5., 1., 2, 31., 35., 3., 4., 3., 3., 4., + 4., 4., 1., 3., 1., 1., 1., 1., 1., 5., 6., 5., 4., 3., 2., + 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.}); + + auto expLU = NDArrayFactory::create( + 'c', {10, 10}, + {5.0, 1.0, 13.0, 4.0, 15.0, 1.0, + 17.0, 9.0, 11.0, 25.0, 0.2, 8.8, + -1.6, 3.2, 2.0, 1.8, 9.6, 8.2, + 18.8, 10.0, 0.6, 0.386364, -4.181818, -0.636364, + -5.772727, 2.704545, -9.909091, -7.568182, -10.863636, -17.863636, + 0.6, 0.954545, 0.543478, -4.108696, -2.771739, -0.788043, + -6.978261, -8.114130, -17.641304, -9.836957, 0.4, 0.068182, + 0.260870, -0.328042, -4.539683, 3.513228, -6.158730, -2.846561, + 22.365079, 25.751323, 0.2, 0.090909, 0.347826, -0.031746, + -0.823427, 7.563520, -1.118881, 1.485431, 20.725524, 23.196387, + 0.0, 0.113636, -0.760870, -0.523810, 0.236014, 0.213036, + -7.593805, -9.585099, 1.663379, -15.900300, 0.4, 0.295455, + 0.652174, -0.698413, 0.167832, 0.021727, -0.001360, -3.321530, + -16.392106, -9.022119, 0.2, 0.204545, -0.173913, -0.592593, + 0.232517, 0.610602, 0.277466, -0.244631, -39.715757, -18.928178, + 0.2, 0.090909, 0.347826, -0.031746, 0.057692, -0.070344, + -0.030154, -0.243578, 0.087256, 0.112695}); + + auto expP = NDArrayFactory::create({1, 2, 7, 3, 6, 8, 5, 4, 0, 9}); + sd::ops::lu op; + + auto res = op.evaluate({&in}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); + auto p = res.at(1); + // z->printBuffer("Triangulars4"); + // expLU.printBuffer("TriangulExp4"); + // p->printBuffer("Permutaions4"); + + ASSERT_TRUE(expLU.equalsTo(z)); + ASSERT_TRUE(expP.equalsTo(p)); } TEST_F(DeclarableOpsTests12, LU_Test_5) { - - auto in = NDArrayFactory::create('c', {2, 10,10}, { - 1., 2., 3., 4., 5., 6., 7., 8., 1., 15., - 5., 1., 13., 4., 15., 1., 17., 9., 11., 25., - 1., 9., 1., 4., 5., 2., 13., 10, 21., 15., - 3., 9., 4., 1., 5., 3., 7., 1, 1., 5., - 2., 3., 2., 5., 4., 4., 7., 3, 3., 4., - 0., 1., 3., 3., 5., 1., 3., 1, 31., 15., - 2., 1., 4., 3., 1., 5., 1., 2, 31., 35., - 3., 4., 3., 3., 4., 4., 4., 1., 3., 1., - 1., 1., 1., 1., 5., 6., 5., 4., 3., 2., - 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., - - 1., 2., 3., 4., 5., 6., 7., 8., 1., 15., - 5., 1., 13., 4., 15., 1., 17., 9., 11., 25., - 1., 9., 1., 4., 5., 2., 13., 10, 21., 15., - 3., 9., 4., 1., 5., 3., 7., 1, 1., 5., - 2., 3., 2., 5., 4., 4., 7., 3, 3., 4., - 0., 1., 3., 3., 5., 1., 3., 1, 31., 15., - 2., 1., 4., 3., 1., 5., 1., 2, 31., 35., - 3., 4., 3., 3., 4., 4., 4., 1., 3., 1., - 1., 1., 1., 1., 5., 6., 5., 4., 3., 2., - 1., 1., 1., 1., 1., 1., 1., 1., 1., 1. - }); - - auto expLU = NDArrayFactory::create('c', {2, 10,10}, { - 5.0, 1.0, 13.0, 4.0, 15.0, 1.0, 17.0, 9.0, 11.0, 25.0, - 0.2, 8.8, -1.6, 3.2, 2.0, 1.8, 9.6, 8.2, 18.8, 10.0, - 0.6, 0.386364, -4.181818, -0.636364, -5.772727, 2.704545, -9.909091, -7.568182, -10.863636, -17.863636, - 0.6, 0.954545, 0.543478, -4.108696, -2.771739, -0.788043, -6.978261, -8.114130, -17.641304, -9.836957, - 0.4, 0.068182, 0.260870, -0.328042, -4.539683, 3.513228, -6.158730, -2.846561, 22.365079, 25.751323, - 0.2, 0.090909, 0.347826, -0.031746, -0.823427, 7.563520, -1.118881, 1.485431, 20.725524, 23.196387, - 0.0, 0.113636, -0.760870, -0.523810, 0.236014, 0.213036, -7.593805, -9.585099, 1.663379, -15.900300, - 0.4, 0.295455, 0.652174, -0.698413, 0.167832, 0.021727, -0.001360, -3.321530, -16.392106, - 9.022119, - 0.2, 0.204545, -0.173913, -0.592593, 0.232517, 0.610602, 0.277466, -0.244631, -39.715757, -18.928178, - 0.2, 0.090909, 0.347826, -0.031746, 0.057692, -0.070344, -0.030154, -0.243578, 0.087256, 0.112695, - - 5.0, 1.0, 13.0, 4.0, 15.0, 1.0, 17.0, 9.0, 11.0, 25.0, - 0.2, 8.8, -1.6, 3.2, 2.0, 1.8, 9.6, 8.2, 18.8, 10.0, - 0.6, 0.386364, -4.181818, -0.636364, -5.772727, 2.704545, -9.909091, -7.568182, -10.863636, -17.863636, - 0.6, 0.954545, 0.543478, -4.108696, -2.771739, -0.788043, -6.978261, -8.114130, -17.641304, -9.836957, - 0.4, 0.068182, 0.260870, -0.328042, -4.539683, 3.513228, -6.158730, -2.846561, 22.365079, 25.751323, - 0.2, 0.090909, 0.347826, -0.031746, -0.823427, 7.563520, -1.118881, 1.485431, 20.725524, 23.196387, - 0.0, 0.113636, -0.760870, -0.523810, 0.236014, 0.213036, -7.593805, -9.585099, 1.663379, -15.900300, - 0.4, 0.295455, 0.652174, -0.698413, 0.167832, 0.021727, -0.001360, -3.321530, -16.392106, - 9.022119, - 0.2, 0.204545, -0.173913, -0.592593, 0.232517, 0.610602, 0.277466, -0.244631, -39.715757, -18.928178, - 0.2, 0.090909, 0.347826, -0.031746, 0.057692, -0.070344, -0.030154, -0.243578, 0.087256, 0.112695 - - }); - - auto expP = NDArrayFactory::create('c', {2, 10}, { - 1, 2, 7, 3, 6, 8, 5, 4, 0, 9, - 1, 2, 7, 3, 6, 8, 5, 4, 0, 9 - }); - sd::ops::lu op; - - auto res = op.evaluate({&in}); - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - auto z = res.at(0); - auto p = res.at(1); -// z->printBuffer("Triangulars5"); -// expLU.printBuffer("TriangulExp5"); -// p->printBuffer("Permutaions5"); - - ASSERT_TRUE(expLU.equalsTo(z)); - ASSERT_TRUE(expP.equalsTo(p)); - + auto in = NDArrayFactory::create( + 'c', {2, 10, 10}, + {1., 2., 3., 4., 5., 6., 7., 8., 1., 15., 5., 1., 13., 4., 15., + 1., 17., 9., 11., 25., 1., 9., 1., 4., 5., 2., 13., 10, 21., 15., + 3., 9., 4., 1., 5., 3., 7., 1, 1., 5., 2., 3., 2., 5., 4., + 4., 7., 3, 3., 4., 0., 1., 3., 3., 5., 1., 3., 1, 31., 15., + 2., 1., 4., 3., 1., 5., 1., 2, 31., 35., 3., 4., 3., 3., 4., + 4., 4., 1., 3., 1., 1., 1., 1., 1., 5., 6., 5., 4., 3., 2., + 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., + + 1., 2., 3., 4., 5., 6., 7., 8., 1., 15., 5., 1., 13., 4., 15., + 1., 17., 9., 11., 25., 1., 9., 1., 4., 5., 2., 13., 10, 21., 15., + 3., 9., 4., 1., 5., 3., 7., 1, 1., 5., 2., 3., 2., 5., 4., + 4., 7., 3, 3., 4., 0., 1., 3., 3., 5., 1., 3., 1, 31., 15., + 2., 1., 4., 3., 1., 5., 1., 2, 31., 35., 3., 4., 3., 3., 4., + 4., 4., 1., 3., 1., 1., 1., 1., 1., 5., 6., 5., 4., 3., 2., + 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.}); + + auto expLU = NDArrayFactory::create( + 'c', {2, 10, 10}, + {5.0, 1.0, 13.0, 4.0, 15.0, 1.0, + 17.0, 9.0, 11.0, 25.0, 0.2, 8.8, + -1.6, 3.2, 2.0, 1.8, 9.6, 8.2, + 18.8, 10.0, 0.6, 0.386364, -4.181818, -0.636364, + -5.772727, 2.704545, -9.909091, -7.568182, -10.863636, -17.863636, + 0.6, 0.954545, 0.543478, -4.108696, -2.771739, -0.788043, + -6.978261, -8.114130, -17.641304, -9.836957, 0.4, 0.068182, + 0.260870, -0.328042, -4.539683, 3.513228, -6.158730, -2.846561, + 22.365079, 25.751323, 0.2, 0.090909, 0.347826, -0.031746, + -0.823427, 7.563520, -1.118881, 1.485431, 20.725524, 23.196387, + 0.0, 0.113636, -0.760870, -0.523810, 0.236014, 0.213036, + -7.593805, -9.585099, 1.663379, -15.900300, 0.4, 0.295455, + 0.652174, -0.698413, 0.167832, 0.021727, -0.001360, -3.321530, + -16.392106, -9.022119, 0.2, 0.204545, -0.173913, -0.592593, + 0.232517, 0.610602, 0.277466, -0.244631, -39.715757, -18.928178, + 0.2, 0.090909, 0.347826, -0.031746, 0.057692, -0.070344, + -0.030154, -0.243578, 0.087256, 0.112695, + + 5.0, 1.0, 13.0, 4.0, 15.0, 1.0, + 17.0, 9.0, 11.0, 25.0, 0.2, 8.8, + -1.6, 3.2, 2.0, 1.8, 9.6, 8.2, + 18.8, 10.0, 0.6, 0.386364, -4.181818, -0.636364, + -5.772727, 2.704545, -9.909091, -7.568182, -10.863636, -17.863636, + 0.6, 0.954545, 0.543478, -4.108696, -2.771739, -0.788043, + -6.978261, -8.114130, -17.641304, -9.836957, 0.4, 0.068182, + 0.260870, -0.328042, -4.539683, 3.513228, -6.158730, -2.846561, + 22.365079, 25.751323, 0.2, 0.090909, 0.347826, -0.031746, + -0.823427, 7.563520, -1.118881, 1.485431, 20.725524, 23.196387, + 0.0, 0.113636, -0.760870, -0.523810, 0.236014, 0.213036, + -7.593805, -9.585099, 1.663379, -15.900300, 0.4, 0.295455, + 0.652174, -0.698413, 0.167832, 0.021727, -0.001360, -3.321530, + -16.392106, -9.022119, 0.2, 0.204545, -0.173913, -0.592593, + 0.232517, 0.610602, 0.277466, -0.244631, -39.715757, -18.928178, + 0.2, 0.090909, 0.347826, -0.031746, 0.057692, -0.070344, + -0.030154, -0.243578, 0.087256, 0.112695 + + }); + + auto expP = NDArrayFactory::create( + 'c', {2, 10}, + {1, 2, 7, 3, 6, 8, 5, 4, 0, 9, 1, 2, 7, 3, 6, 8, 5, 4, 0, 9}); + sd::ops::lu op; + + auto res = op.evaluate({&in}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); + auto p = res.at(1); + // z->printBuffer("Triangulars5"); + // expLU.printBuffer("TriangulExp5"); + // p->printBuffer("Permutaions5"); + + ASSERT_TRUE(expLU.equalsTo(z)); + ASSERT_TRUE(expP.equalsTo(p)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, LU_Test_1_2) { + auto in = NDArrayFactory::create( + 'c', {2, 3, 3}, + {1., 2., 3., 0., 2., 3., 0., 0., 7., 1., 2., 3., 0., 2., 3., 0., 0., 7.}); + auto exp = NDArrayFactory::create( + 'c', {2, 3, 3}, + {1., 2., 3., 0., 2., 3., 0., 0., 7, 1., 2., 3., 0., 2., 3., 0., 0., 7.}); - auto in = NDArrayFactory::create('c', {2, 3,3}, {1., 2., 3., 0., 2., 3., 0., 0., 7.,1., 2., 3., 0., 2., 3., 0., 0., 7.}); - auto exp = NDArrayFactory::create('c', {2, 3,3}, {1., 2., 3., 0., 2., 3., 0., 0., 7, 1., 2., 3., 0., 2., 3., 0., 0., 7.}); - - sd::ops::lu op; + sd::ops::lu op; - auto res = op.evaluate({&in}); - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - auto z = res.at(0); - auto p = res.at(1); -// z->printIndexedBuffer("Triangulars (2,3,3)"); -// p->printIndexedBuffer("Permutaions (2,3,3)"); - ASSERT_TRUE(exp.equalsTo(res.at(0))); - + auto res = op.evaluate({&in}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); + auto p = res.at(1); + // z->printIndexedBuffer("Triangulars (2,3,3)"); + // p->printIndexedBuffer("Permutaions (2,3,3)"); + ASSERT_TRUE(exp.equalsTo(res.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, LU_Test_3_2) { + auto in = NDArrayFactory::create( + 'c', {2, 3, 3}, + {1, 2, 3, 4, 7, 9, 11, 12, 13, 1, 2, 3, 4, 7, 9, 11, 12, 13}); - auto in = NDArrayFactory::create('c', {2, 3,3}, {1,2,3,4,7,9, 11, 12, 13,1,2,3,4,7,9, 11, 12, 13}); + auto expLU = NDArrayFactory::create( + 'c', {2, 3, 3}, + {11., 12., 13., 0.36363637, 2.6363635, 4.272727, 0.09090909, 0.3448276, + 0.34482753, - auto expLU = NDArrayFactory::create('c', {2, 3,3}, { - 11., 12., 13., - 0.36363637, 2.6363635, 4.272727, - 0.09090909, 0.3448276, 0.34482753, + 11., 12., 13., 0.36363637, 2.6363635, 4.272727, 0.09090909, 0.3448276, + 0.34482753}); - 11., 12., 13., - 0.36363637, 2.6363635, 4.272727, - 0.09090909, 0.3448276, 0.34482753 - }); - - auto expP = NDArrayFactory::create('c', {2,3}, {2, 1, 0, 2, 1, 0}); - sd::ops::lu op; + auto expP = NDArrayFactory::create('c', {2, 3}, {2, 1, 0, 2, 1, 0}); + sd::ops::lu op; - auto res = op.evaluate({&in}); - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - auto z = res.at(0); - auto p = res.at(1); -// z->printIndexedBuffer("Triangulars3_2"); -// p->printIndexedBuffer("Permutaions3_2"); + auto res = op.evaluate({&in}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); + auto p = res.at(1); + // z->printIndexedBuffer("Triangulars3_2"); + // p->printIndexedBuffer("Permutaions3_2"); - ASSERT_TRUE(expLU.equalsTo(z)); - ASSERT_TRUE(expP.equalsTo(p)); - + ASSERT_TRUE(expLU.equalsTo(z)); + ASSERT_TRUE(expP.equalsTo(p)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, LU_Test_3_3) { + auto in = NDArrayFactory::create( + 'c', {2, 3, 3}, + {1, 2, 3, 4, 7, 9, 11, 12, 13, 13, 2, 3, 4, 7, 9, 11, 12, 1}); + auto expLU = NDArrayFactory::create( + 'c', {2, 3, 3}, + {11., 12., 13., 0.36363637, 2.6363635, 4.272727, 0.09090909, 0.3448276, + 0.34482753, - auto in = NDArrayFactory::create('c', {2, 3,3}, {1,2,3,4,7,9, 11, 12, 13,13,2,3,4,7,9, 11, 12, 1}); - auto expLU = NDArrayFactory::create('c', {2, 3,3}, { - 11., 12., 13., - 0.36363637, 2.6363635, 4.272727, - 0.09090909, 0.3448276, 0.34482753, - - 13., 2., 3., - 0.84615386, 10.307693, -1.5384617, - 0.30769232, 0.619403, 9.029851}); + 13., 2., 3., 0.84615386, 10.307693, -1.5384617, 0.30769232, 0.619403, + 9.029851}); - auto expP = NDArrayFactory::create('c', {2,3}, {2, 1, 0, 0, 2, 1}); - sd::ops::lu op; + auto expP = NDArrayFactory::create('c', {2, 3}, {2, 1, 0, 0, 2, 1}); + sd::ops::lu op; - auto res = op.evaluate({&in}); - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - auto z = res.at(0); - auto p = res.at(1); -// z->printIndexedBuffer("Triangulars3_3"); -// p->printIndexedBuffer("Permutaions3_3"); + auto res = op.evaluate({&in}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); + auto p = res.at(1); + // z->printIndexedBuffer("Triangulars3_3"); + // p->printIndexedBuffer("Permutaions3_3"); - ASSERT_TRUE(expLU.equalsTo(z)); - ASSERT_TRUE(expP.equalsTo(p)); - + ASSERT_TRUE(expLU.equalsTo(z)); + ASSERT_TRUE(expP.equalsTo(p)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, LU_Test_4_1) { + auto in = NDArrayFactory::create( + 'c', {2, 2, 2}, + {0.7788f, 0.8012f, 0.7244f, 0.2309f, 0.7271f, 0.1804f, 0.5056f, 0.8925f}); - auto in = NDArrayFactory::create('c', {2, 2,2}, { - 0.7788f, 0.8012f, 0.7244f, 0.2309f, - 0.7271f, 0.1804f, 0.5056f, 0.8925f - }); - - auto expLU = NDArrayFactory::create('c', {2, 2,2}, { - 0.7788f, 0.8012f, 0.930149f, -0.514335f, - 0.7271f, 0.1804f, 0.695365f, 0.767056f - }); + auto expLU = + NDArrayFactory::create('c', {2, 2, 2}, + {0.7788f, 0.8012f, 0.930149f, -0.514335f, + 0.7271f, 0.1804f, 0.695365f, 0.767056f}); - auto expP = NDArrayFactory::create('c', {2,2}, {0, 1, 0, 1}); - sd::ops::lu op; + auto expP = NDArrayFactory::create('c', {2, 2}, {0, 1, 0, 1}); + sd::ops::lu op; - auto res = op.evaluate({&in}); - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - auto z = res.at(0); - auto p = res.at(1); -// z->printIndexedBuffer("Triangulars4_1"); -// p->printIndexedBuffer("Permutaions4_1"); + auto res = op.evaluate({&in}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); + auto p = res.at(1); + // z->printIndexedBuffer("Triangulars4_1"); + // p->printIndexedBuffer("Permutaions4_1"); - ASSERT_TRUE(expLU.equalsTo(z)); - ASSERT_TRUE(expP.equalsTo(p)); - + ASSERT_TRUE(expLU.equalsTo(z)); + ASSERT_TRUE(expP.equalsTo(p)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, LU_Test_4_2) { + auto in = NDArrayFactory::create( + 'c', {2, 2, 2}, + {0.7788f, 0.8012f, 0.7244f, 0.2309f, 0.7271f, 0.1804f, 0.5056f, 0.8925f}); - auto in = NDArrayFactory::create('c', {2, 2,2}, { - 0.7788f, 0.8012f, 0.7244f, 0.2309f, - 0.7271f, 0.1804f, 0.5056f, 0.8925f - }); + auto expLU = + NDArrayFactory::create('c', {2, 2, 2}, + {0.7788f, 0.8012f, 0.930149f, -0.514335f, + 0.7271f, 0.1804f, 0.695365f, 0.767056f}); - auto expLU = NDArrayFactory::create('c', {2, 2,2}, { - 0.7788f, 0.8012f, 0.930149f, -0.514335f, - 0.7271f, 0.1804f, 0.695365f, 0.767056f - }); - - auto expP = NDArrayFactory::create('c', {2,2}, {0, 1, 0, 1}); - sd::ops::lu op; + auto expP = NDArrayFactory::create('c', {2, 2}, {0, 1, 0, 1}); + sd::ops::lu op; - auto res = op.evaluate({&in}, {}, {sd::DataType::INT64}); - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - auto z = res.at(0); - auto p = res.at(1); + auto res = op.evaluate({&in}, {}, {sd::DataType::INT64}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); + auto p = res.at(1); -// z->printIndexedBuffer("Triangulars4_2"); -// p->printIndexedBuffer("Permutaions4_2"); + // z->printIndexedBuffer("Triangulars4_2"); + // p->printIndexedBuffer("Permutaions4_2"); - ASSERT_TRUE(expLU.equalsTo(z)); - ASSERT_TRUE(expP.equalsTo(p)); - + ASSERT_TRUE(expLU.equalsTo(z)); + ASSERT_TRUE(expP.equalsTo(p)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, QR_Test_1) { - - auto in = NDArrayFactory::create('c', {5,3}, { - 12., -51., 4., 6., 167., -68., -4., 24., -41., -1., 1., 0., 2., 0., 3. - }); - auto expQ = NDArrayFactory::create('c', {5, 5}, { - 0.8464148, 0.3912908, -0.3431241, 0.06613743, -0.09146205, -0.42320737, -0.9040873, 0.02927014, 0.01737854, -0.04861044, 0.28213826, -0.17042054, -0.93285596, -0.02194202, 0.14371186, 0.07053456, -0.01404065, 0.00109937, 0.99740064, 0.00429488, -0.14106913, 0.0166551, 0.10577161, 0.00585613, 0.98417485 - }); - - auto expR = NDArrayFactory::create('c', {5,3}, { - -14.177447, -20.666622, 13.401566, 0., -175.04254, 70.080315, 0., 0., 35.201546, 0., 0., 0., 0., 0., 0. }); - sd::ops::qr op; - auto res = op.evaluate({&in}, {}, {}, {true}); - - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - auto q = res.at(0); - auto r = res.at(1); -// q->printIndexedBuffer("Orthogonal 5x5"); -// expQ.printBuffer("Orthogonal Exp"); -// r->printIndexedBuffer("Upper triangular 5x3"); -// expR.printBuffer("Upper triangular Exp"); -// q->printShapeInfo("Q shape"); -// r->printShapeInfo("R shape"); - sd::ops::matmul opMul; - auto res2 = opMul.evaluate({q, r}); //MmulHelper::matmul(q, r, &in, false, false); - auto exp = res2.at(0);//->printIndexedBuffer("Result as result"); - ASSERT_TRUE(exp->isSameShape(in)); -// ASSERT_TRUE(q->isSameShape(expQ)); - - //ASSERT_TRUE(expQ.equalsTo(q)); - ASSERT_TRUE(exp->equalsTo(in)); - - + auto in = NDArrayFactory::create( + 'c', {5, 3}, + {12., -51., 4., 6., 167., -68., -4., 24., -41., -1., 1., 0., 2., 0., 3.}); + auto expQ = NDArrayFactory::create( + 'c', {5, 5}, + {0.8464148, 0.3912908, -0.3431241, 0.06613743, -0.09146205, + -0.42320737, -0.9040873, 0.02927014, 0.01737854, -0.04861044, + 0.28213826, -0.17042054, -0.93285596, -0.02194202, 0.14371186, + 0.07053456, -0.01404065, 0.00109937, 0.99740064, 0.00429488, + -0.14106913, 0.0166551, 0.10577161, 0.00585613, 0.98417485}); + + auto expR = NDArrayFactory::create( + 'c', {5, 3}, + {-14.177447, -20.666622, 13.401566, 0., -175.04254, 70.080315, 0., 0., + 35.201546, 0., 0., 0., 0., 0., 0.}); + sd::ops::qr op; + auto res = op.evaluate({&in}, {}, {}, {true}); + + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto q = res.at(0); + auto r = res.at(1); + q.printIndexedBuffer("Orthogonal 5x5"); + expQ.printBuffer("Orthogonal Exp"); + r.printIndexedBuffer("Upper triangular 5x3"); + expR.printBuffer("Upper triangular Exp"); + q.printShapeInfo("Q shape"); + r.printShapeInfo("R shape"); + sd::ops::matmul opMul; + auto res2 = + opMul.evaluate({&q, &r}); // MmulHelper::matmul(q, r, &in, false, false); + auto exp = res2.at(0); //->printIndexedBuffer("Result as result"); + ASSERT_TRUE(exp.isSameShape(in)); + // ASSERT_TRUE(q->isSameShape(expQ)); + + // ASSERT_TRUE(expQ.equalsTo(q)); + ASSERT_TRUE(exp.equalsTo(in)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, QR_Test_1_1) { - - auto in = NDArrayFactory::create('c', {4, 5, 3}, { - 12., -51., 4., 6., 167., -68., -4., 24., -41., -1., 1., 0., 2., 0., 3., - 12., -51., 4., 6., 167., -68., -4., 24., -41., -1., 1., 0., 2., 0., 3., - 12., -51., 4., 6., 167., -68., -4., 24., -41., -1., 1., 0., 2., 0., 3., - 12., -51., 4., 6., 167., -68., -4., 24., -41., -1., 1., 0., 2., 0., 3. - }); - auto expQ = NDArrayFactory::create('c', {4, 5, 5}, { - 0.8464148, 0.3912908, -0.3431241, 0.06613743, -0.09146205, -0.42320737, -0.9040873, 0.02927014, 0.01737854, -0.04861044, 0.28213826, -0.17042054, -0.93285596, -0.02194202, 0.14371186, 0.07053456, -0.01404065, 0.00109937, 0.99740064, 0.00429488, -0.14106913, 0.0166551, 0.10577161, 0.00585613, 0.98417485, - 0.8464148, 0.3912908, -0.3431241, 0.06613743, -0.09146205, -0.42320737, -0.9040873, 0.02927014, 0.01737854, -0.04861044, 0.28213826, -0.17042054, -0.93285596, -0.02194202, 0.14371186, 0.07053456, -0.01404065, 0.00109937, 0.99740064, 0.00429488, -0.14106913, 0.0166551, 0.10577161, 0.00585613, 0.98417485, - 0.8464148, 0.3912908, -0.3431241, 0.06613743, -0.09146205, -0.42320737, -0.9040873, 0.02927014, 0.01737854, -0.04861044, 0.28213826, -0.17042054, -0.93285596, -0.02194202, 0.14371186, 0.07053456, -0.01404065, 0.00109937, 0.99740064, 0.00429488, -0.14106913, 0.0166551, 0.10577161, 0.00585613, 0.98417485, - 0.8464148, 0.3912908, -0.3431241, 0.06613743, -0.09146205, -0.42320737, -0.9040873, 0.02927014, 0.01737854, -0.04861044, 0.28213826, -0.17042054, -0.93285596, -0.02194202, 0.14371186, 0.07053456, -0.01404065, 0.00109937, 0.99740064, 0.00429488, -0.14106913, 0.0166551, 0.10577161, 0.00585613, 0.98417485 - }); - - auto expR = NDArrayFactory::create('c', {4, 5,3}, { - -14.177447, -20.666622, 13.401566, 0., -175.04254, 70.080315, 0., 0., 35.201546, 0., 0., 0., 0., 0., 0., - -14.177447, -20.666622, 13.401566, 0., -175.04254, 70.080315, 0., 0., 35.201546, 0., 0., 0., 0., 0., 0., - -14.177447, -20.666622, 13.401566, 0., -175.04254, 70.080315, 0., 0., 35.201546, 0., 0., 0., 0., 0., 0., - -14.177447, -20.666622, 13.401566, 0., -175.04254, 70.080315, 0., 0., 35.201546, 0., 0., 0., 0., 0., 0. - }); - sd::ops::qr op; - auto res = op.evaluate({&in}, {}, {}, {true}); - - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - auto q = res.at(0); - auto r = res.at(1); -// q->printIndexedBuffer("Orthogonal 5x5"); -// expQ.printBuffer("Orthogonal Exp"); -// r->printIndexedBuffer("Upper triangular 5x3"); -// expR.printBuffer("Upper triangular Exp"); -// q->printShapeInfo("Q shape"); -// r->printShapeInfo("R shape"); - sd::ops::matmul opMul; - auto res2 = opMul.evaluate({q, r}); //MmulHelper::matmul(q, r, &in, false, false); - auto exp = res2.at(0);//->printIndexedBuffer("Result as result"); - ASSERT_TRUE(exp->isSameShape(in)); -// ASSERT_TRUE(q->isSameShape(expQ)); - - //ASSERT_TRUE(expQ.equalsTo(q)); - ASSERT_TRUE(exp->equalsTo(in)); - - + auto in = NDArrayFactory::create( + 'c', {4, 5, 3}, + {12., -51., 4., 6., 167., -68., -4., 24., -41., -1., 1., 0., 2., 0., 3., + 12., -51., 4., 6., 167., -68., -4., 24., -41., -1., 1., 0., 2., 0., 3., + 12., -51., 4., 6., 167., -68., -4., 24., -41., -1., 1., 0., 2., 0., 3., + 12., -51., 4., 6., 167., -68., -4., 24., -41., -1., 1., 0., 2., 0., 3.}); + auto expQ = NDArrayFactory::create( + 'c', {4, 5, 5}, + {0.8464148, 0.3912908, -0.3431241, 0.06613743, -0.09146205, + -0.42320737, -0.9040873, 0.02927014, 0.01737854, -0.04861044, + 0.28213826, -0.17042054, -0.93285596, -0.02194202, 0.14371186, + 0.07053456, -0.01404065, 0.00109937, 0.99740064, 0.00429488, + -0.14106913, 0.0166551, 0.10577161, 0.00585613, 0.98417485, + 0.8464148, 0.3912908, -0.3431241, 0.06613743, -0.09146205, + -0.42320737, -0.9040873, 0.02927014, 0.01737854, -0.04861044, + 0.28213826, -0.17042054, -0.93285596, -0.02194202, 0.14371186, + 0.07053456, -0.01404065, 0.00109937, 0.99740064, 0.00429488, + -0.14106913, 0.0166551, 0.10577161, 0.00585613, 0.98417485, + 0.8464148, 0.3912908, -0.3431241, 0.06613743, -0.09146205, + -0.42320737, -0.9040873, 0.02927014, 0.01737854, -0.04861044, + 0.28213826, -0.17042054, -0.93285596, -0.02194202, 0.14371186, + 0.07053456, -0.01404065, 0.00109937, 0.99740064, 0.00429488, + -0.14106913, 0.0166551, 0.10577161, 0.00585613, 0.98417485, + 0.8464148, 0.3912908, -0.3431241, 0.06613743, -0.09146205, + -0.42320737, -0.9040873, 0.02927014, 0.01737854, -0.04861044, + 0.28213826, -0.17042054, -0.93285596, -0.02194202, 0.14371186, + 0.07053456, -0.01404065, 0.00109937, 0.99740064, 0.00429488, + -0.14106913, 0.0166551, 0.10577161, 0.00585613, 0.98417485}); + + auto expR = NDArrayFactory::create( + 'c', {4, 5, 3}, + {-14.177447, -20.666622, 13.401566, 0., -175.04254, 70.080315, + 0., 0., 35.201546, 0., 0., 0., + 0., 0., 0., -14.177447, -20.666622, 13.401566, + 0., -175.04254, 70.080315, 0., 0., 35.201546, + 0., 0., 0., 0., 0., 0., + -14.177447, -20.666622, 13.401566, 0., -175.04254, 70.080315, + 0., 0., 35.201546, 0., 0., 0., + 0., 0., 0., -14.177447, -20.666622, 13.401566, + 0., -175.04254, 70.080315, 0., 0., 35.201546, + 0., 0., 0., 0., 0., 0.}); + sd::ops::qr op; + auto res = op.evaluate({&in}, {}, {}, {true}); + + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto q = res.at(0); + auto r = res.at(1); + // q->printIndexedBuffer("Orthogonal 5x5"); + // expQ.printBuffer("Orthogonal Exp"); + // r->printIndexedBuffer("Upper triangular 5x3"); + // expR.printBuffer("Upper triangular Exp"); + // q->printShapeInfo("Q shape"); + // r->printShapeInfo("R shape"); + sd::ops::matmul opMul; + auto res2 = + opMul.evaluate({&q, &r}); // MmulHelper::matmul(q, r, &in, false, false); + auto exp = res2.at(0); //->printIndexedBuffer("Result as result"); + ASSERT_TRUE(exp.isSameShape(in)); + // ASSERT_TRUE(q->isSameShape(expQ)); + + // ASSERT_TRUE(expQ.equalsTo(q)); + ASSERT_TRUE(exp.equalsTo(in)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, QR_Test_2) { - - auto in = NDArrayFactory::create('c', {5,3}, {12., -51., 4., 6., 167., -68., -4., 24., -41., -1., 1., 0., 2., 0., 3.}); - auto expQ = NDArrayFactory::create('c', {5, 3}, {0.8464148,0.3912908,-0.3431241,-0.42320737, -0.9040873,0.02927014,0.28213826, -0.17042054, -0.93285596,0.07053456, -0.01404065,0.00109937,-0.14106913,0.0166551,0.10577161}); - auto expR = NDArrayFactory::create('c', {3,3}, {-14.177447,-20.666622,13.401566,0.,-175.04254,70.080315,0.,0.,35.201546}); - - sd::ops::qr op; - auto res = op.evaluate({&in}, {}, {}, {false}); - - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - auto q = res.at(0); - auto r = res.at(1); - ASSERT_TRUE(q->isSameShape(expQ)); - ASSERT_TRUE(r->isSameShape(expR)); - - sd::ops::matmul opMul; - auto res2 = opMul.evaluate({q, r}); //MmulHelper::matmul(q, r, &in, false, false); - auto exp = res2.at(0);//->printIndexedBuffer("Result as result"); - ASSERT_TRUE(exp->isSameShape(in)); - ASSERT_TRUE(exp->equalsTo(in)); - + auto in = NDArrayFactory::create( + 'c', {5, 3}, + {12., -51., 4., 6., 167., -68., -4., 24., -41., -1., 1., 0., 2., 0., 3.}); + auto expQ = NDArrayFactory::create( + 'c', {5, 3}, + {0.8464148, 0.3912908, -0.3431241, -0.42320737, -0.9040873, 0.02927014, + 0.28213826, -0.17042054, -0.93285596, 0.07053456, -0.01404065, + 0.00109937, -0.14106913, 0.0166551, 0.10577161}); + auto expR = NDArrayFactory::create( + 'c', {3, 3}, + {-14.177447, -20.666622, 13.401566, 0., -175.04254, 70.080315, 0., 0., + 35.201546}); + + sd::ops::qr op; + auto res = op.evaluate({&in}, {}, {}, {false}); + + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto q = res.at(0); + auto r = res.at(1); + ASSERT_TRUE(q.isSameShape(expQ)); + ASSERT_TRUE(r.isSameShape(expR)); + + sd::ops::matmul opMul; + auto res2 = + opMul.evaluate({&q, &r}); // MmulHelper::matmul(q, r, &in, false, false); + auto exp = res2.at(0); //->printIndexedBuffer("Result as result"); + ASSERT_TRUE(exp.isSameShape(in)); + ASSERT_TRUE(exp.equalsTo(in)); } TEST_F(DeclarableOpsTests12, ImageResize_Test1) { @@ -3148,320 +3476,291 @@ TEST_F(DeclarableOpsTests12, ImageResize_Test11) { //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, TriangularSolve_Test_1) { + auto a = + NDArrayFactory::create('c', {4, 4}, + {3.f, 0.f, 0.f, 0.f, 2.f, 1.f, 0.f, 0.f, + 1.f, 0.f, 1.f, 0.f, 1.f, 1.f, 1.f, 1.f}); - auto a = NDArrayFactory::create('c', {4, 4}, { - 3.f, 0.f, 0.f, 0.f, - 2.f, 1.f, 0.f, 0.f, - 1.f, 0.f, 1.f, 0.f, - 1.f, 1.f, 1.f, 1.f - }); + auto b = NDArrayFactory::create('c', {4, 1}, {4.f, 2.f, 4.f, 2.f}); - auto b = NDArrayFactory::create('c', {4, 1}, { - 4.f, 2.f, 4.f, 2.f - }); + auto exp = NDArrayFactory::create( + 'c', {4, 1}, {1.333333f, -0.6666667f, 2.6666667f, -1.3333333f}); - auto exp = NDArrayFactory::create('c', {4, 1}, { - 1.333333f, -0.6666667f, 2.6666667f, -1.3333333f }); + sd::ops::triangular_solve op; - sd::ops::triangular_solve op; + auto res = op.evaluate({&a, &b}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); - auto res = op.evaluate({&a, &b}); - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - auto z = res.at(0); + // z->printIndexedBuffer("TriangularSolve"); -// z->printIndexedBuffer("TriangularSolve"); - - ASSERT_TRUE(exp.equalsTo(z)); - + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, TriangularSolve_Test_2) { - - auto a = NDArrayFactory::create('c', {4, 4}, { - 1.f, 1.f, 1.f, 1.f, - 0.f, 1.f, 1.f, 0.f, - 0.f, 0.f, 2.f, 1.f, - 0.f, 0.f, 0.f, 3.f, - }); - - auto b = NDArrayFactory::create('c', {4, 1}, { - 2.f, 4.f, 2.f, 4.f - }); - - auto exp = NDArrayFactory::create('c', {4, 1}, { - 2.f, 4.f, 1.f, 1.3333333f }); - - sd::ops::triangular_solve op; - - auto res = op.evaluate({&a, &b}); - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - auto z = res.at(0); - -// z->printIndexedBuffer("TriangularSolve"); - - ASSERT_TRUE(exp.equalsTo(z)); - + auto a = NDArrayFactory::create('c', {4, 4}, + { + 1.f, + 1.f, + 1.f, + 1.f, + 0.f, + 1.f, + 1.f, + 0.f, + 0.f, + 0.f, + 2.f, + 1.f, + 0.f, + 0.f, + 0.f, + 3.f, + }); + + auto b = NDArrayFactory::create('c', {4, 1}, {2.f, 4.f, 2.f, 4.f}); + + auto exp = + NDArrayFactory::create('c', {4, 1}, {2.f, 4.f, 1.f, 1.3333333f}); + + sd::ops::triangular_solve op; + + auto res = op.evaluate({&a, &b}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); + + // z->printIndexedBuffer("TriangularSolve"); + + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, TriangularSolve_Test_3) { + auto a = NDArrayFactory::create( + 'c', {2, 4, 4}, {3.f, 0.f, 0.f, 0.f, 2.f, 1.f, 0.f, 0.f, + 1.f, 0.f, 1.f, 0.f, 1.f, 1.f, 1.f, 1.f, - auto a = NDArrayFactory::create('c', {2, 4, 4}, { - 3.f, 0.f, 0.f, 0.f, - 2.f, 1.f, 0.f, 0.f, - 1.f, 0.f, 1.f, 0.f, - 1.f, 1.f, 1.f, 1.f, + 3.f, 0.f, 0.f, 0.f, 2.f, 1.f, 0.f, 0.f, + 1.f, 0.f, 1.f, 0.f, 1.f, 1.f, 1.f, 1.f}); - 3.f, 0.f, 0.f, 0.f, - 2.f, 1.f, 0.f, 0.f, - 1.f, 0.f, 1.f, 0.f, - 1.f, 1.f, 1.f, 1.f - }); - - auto b = NDArrayFactory::create('c', {2, 4, 1}, { - 4.f, 2.f, 4.f, 2.f, - 4.f, 2.f, 4.f, 2.f - }); + auto b = NDArrayFactory::create( + 'c', {2, 4, 1}, {4.f, 2.f, 4.f, 2.f, 4.f, 2.f, 4.f, 2.f}); - auto exp = NDArrayFactory::create('c', {2, 4, 1}, { - 1.333333f, -0.6666667f, 2.6666667f, -1.3333333f, - 1.333333f, -0.6666667f, 2.6666667f, -1.3333333f - }); + auto exp = NDArrayFactory::create( + 'c', {2, 4, 1}, + {1.333333f, -0.6666667f, 2.6666667f, -1.3333333f, 1.333333f, -0.6666667f, + 2.6666667f, -1.3333333f}); - sd::ops::triangular_solve op; + sd::ops::triangular_solve op; - auto res = op.evaluate({&a, &b}); - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - auto z = res.at(0); + auto res = op.evaluate({&a, &b}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); -// z->printIndexedBuffer("TriangularSolve"); + // z->printIndexedBuffer("TriangularSolve"); - ASSERT_TRUE(exp.equalsTo(z)); - + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, TriangularSolve_Test_4) { - - auto a = NDArrayFactory::create('c', {4, 4}, { - 1.f, 1.f, 1.f, 1.f, - 0.f, 1.f, 1.f, 0.f, - 0.f, 0.f, 2.f, 1.f, - 0.f, 0.f, 0.f, 3.f, - }); - - auto b = NDArrayFactory::create('c', {4, 1}, { - 2.f, 4.f, 2.f, 4.f - }); - - auto exp = NDArrayFactory::create('c', {4, 1}, { - -3.3333333f, 3.6666666f, 0.333333f, 1.3333333f - }); - - sd::ops::triangular_solve op; - - auto res = op.evaluate({&a, &b}, {false}); - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - auto z = res.at(0); - -// z->printIndexedBuffer("TriangularSolve"); - - ASSERT_TRUE(exp.equalsTo(z)); - + auto a = NDArrayFactory::create('c', {4, 4}, + { + 1.f, + 1.f, + 1.f, + 1.f, + 0.f, + 1.f, + 1.f, + 0.f, + 0.f, + 0.f, + 2.f, + 1.f, + 0.f, + 0.f, + 0.f, + 3.f, + }); + + auto b = NDArrayFactory::create('c', {4, 1}, {2.f, 4.f, 2.f, 4.f}); + + auto exp = NDArrayFactory::create( + 'c', {4, 1}, {-3.3333333f, 3.6666666f, 0.333333f, 1.3333333f}); + + sd::ops::triangular_solve op; + + auto res = op.evaluate({&a, &b}, {false}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); + + // z->printIndexedBuffer("TriangularSolve"); + + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, TriangularSolve_Test_5) { + auto a = + NDArrayFactory::create('c', {4, 4}, + {5.f, 1., -3.f, 3.f, 0.f, 1.f, 1.f, -1.f, + 0.f, 0.f, 2.f, -9.f, 0.f, 0.f, 0.f, 4.f}); - auto a = NDArrayFactory::create('c', {4, 4}, { - 5.f, 1., -3.f, 3.f, - 0.f, 1.f, 1.f, -1.f, - 0.f, 0.f, 2.f, -9.f, - 0.f, 0.f, 0.f, 4.f - }); - - auto b = NDArrayFactory::create('c', {4, 1}, { - 5.f, 2.f, 0.f, -3.f - }); + auto b = NDArrayFactory::create('c', {4, 1}, {5.f, 2.f, 0.f, -3.f}); - auto exp = NDArrayFactory::create('c', {4, 1}, { - 1.f, 1.f, 1.f, 1.f - }); + auto exp = NDArrayFactory::create('c', {4, 1}, {1.f, 1.f, 1.f, 1.f}); - sd::ops::triangular_solve op; + sd::ops::triangular_solve op; - auto res = op.evaluate({&a, &b}, {false, true}); - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - auto z = res.at(0); + auto res = op.evaluate({&a, &b}, {false, true}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); -// z->printIndexedBuffer("TriangularSolve with adjoint"); + // z->printIndexedBuffer("TriangularSolve with adjoint"); - ASSERT_TRUE(exp.equalsTo(z)); - + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, SolveLs_Test_1) { + auto a = + NDArrayFactory::create('c', {4, 4}, + {3.f, 0.f, 0.f, 0.f, 2.f, 1.f, 0.f, 0.f, + 1.f, 0.f, 1.f, 0.f, 1.f, 1.f, 1.f, 1.f}); - auto a = NDArrayFactory::create('c', {4, 4}, { - 3.f, 0.f, 0.f, 0.f, - 2.f, 1.f, 0.f, 0.f, - 1.f, 0.f, 1.f, 0.f, - 1.f, 1.f, 1.f, 1.f - }); - - auto b = NDArrayFactory::create('c', {4, 1}, { - 4.f, 2.f, 4.f, 2.f - }); + auto b = NDArrayFactory::create('c', {4, 1}, {4.f, 2.f, 4.f, 2.f}); - auto exp = NDArrayFactory::create('c', {4, 1}, { - 1.333333f, -0.6666667f, 2.6666667f, -1.3333333f }); + auto exp = NDArrayFactory::create( + 'c', {4, 1}, {1.333333f, -0.6666667f, 2.6666667f, -1.3333333f}); - sd::ops::lstsq op; + sd::ops::lstsq op; - auto res = op.evaluate({&a, &b}); - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - auto z = res.at(0); + auto res = op.evaluate({&a, &b}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); -// z->printIndexedBuffer("MatrixSolveLS"); - MmulHelper::matmul(&a, z, &exp, false, false); + // z->printIndexedBuffer("MatrixSolveLS"); + MmulHelper::matmul(&a, &z, &exp, false, false); - ASSERT_TRUE(exp.equalsTo(b)); - + ASSERT_TRUE(exp.equalsTo(b)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, SolveLs_Test_2) { + auto a = NDArrayFactory::create( + 'c', {3, 3}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 11.f, 8.f, 21.f}); - auto a = NDArrayFactory::create('c', {3, 3}, { - 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 11.f, 8.f, 21.f - }); - - auto b = NDArrayFactory::create('c', {3, 1}, { 1.f, 2.f, 3.f }); + auto b = NDArrayFactory::create('c', {3, 1}, {1.f, 2.f, 3.f}); - auto exp = NDArrayFactory::create('c', {3, 1}, { -0.24999914f, 0.4999994f, 0.08333314f }); + auto exp = NDArrayFactory::create( + 'c', {3, 1}, {-0.24999914f, 0.4999994f, 0.08333314f}); - sd::ops::lstsq op; + sd::ops::lstsq op; - auto res = op.evaluate({&a, &b}); - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - auto z = res.at(0); + auto res = op.evaluate({&a, &b}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); - MmulHelper::matmul(&a, z, &exp, false, false); + MmulHelper::matmul(&a, &z, &exp, false, false); -// z->printIndexedBuffer("MatrixSolveLS2"); + // z->printIndexedBuffer("MatrixSolveLS2"); - ASSERT_TRUE(exp.equalsTo(b)); - + ASSERT_TRUE(exp.equalsTo(b)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, SolveLs_Test_3) { + auto a = NDArrayFactory::create( + 'c', {3, 4}, + {1.f, 1.f, 0.f, 0.f, -1.f, 1.f, 0.f, 0.f, 1.f, 1.f, -1.f, -1.f}); - auto a = NDArrayFactory::create('c', {3, 4}, { - 1.f,1.f,0.f,0.f,-1.f,1.f,0.f,0.f,1.f,1.f,-1.f,-1.f - }); - - auto b = NDArrayFactory::create('c', {3, 1}, { 1.f, 2.f, 3.f }); + auto b = NDArrayFactory::create('c', {3, 1}, {1.f, 2.f, 3.f}); - auto exp = NDArrayFactory::create('c', {3, 1}, { -0.5f, 1.5f, -2.f }); + auto exp = NDArrayFactory::create('c', {3, 1}, {-0.5f, 1.5f, -2.f}); - sd::ops::lstsq op; + sd::ops::lstsq op; - auto res = op.evaluate({&a, &b}); - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - auto z = res.at(0); + auto res = op.evaluate({&a, &b}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); -// z->printIndexedBuffer("MatrixSolveLS3"); - MmulHelper::matmul(&a, z, &exp, false, false); - ASSERT_TRUE(exp.equalsTo(b)); - + // z->printIndexedBuffer("MatrixSolveLS3"); + MmulHelper::matmul(&a, &z, &exp, false, false); + ASSERT_TRUE(exp.equalsTo(b)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, SolveLs_Test_4) { + auto a = NDArrayFactory::create( + 'c', {3, 4}, + {1.f, 1.f, 0.f, 0.f, -1.f, 1.f, 0.f, 0.f, 1.f, 1.f, -1.f, -1.f}); - auto a = NDArrayFactory::create('c', {3, 4}, { - 1.f,1.f,0.f,0.f,-1.f,1.f,0.f,0.f,1.f,1.f,-1.f,-1.f - }); - - auto b = NDArrayFactory::create('c', {3, 1}, { 1.f, 2.f, 3.f }); + auto b = NDArrayFactory::create('c', {3, 1}, {1.f, 2.f, 3.f}); - auto exp = NDArrayFactory::create('c', {4, 1}, { -0.5f, 1.5f, -2.f, 0.f}); + auto exp = + NDArrayFactory::create('c', {4, 1}, {-0.5f, 1.5f, -2.f, 0.f}); - sd::ops::lstsq op; + sd::ops::lstsq op; - auto res = op.evaluate({&a, &b}, {false}); - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - auto z = res.at(0); -// z->printIndexedBuffer("Output_12.4"); -// z->printShapeInfo("Output_12.4 shape"); -// MmulHelper::matmul(&a, z, &exp, false, false); + auto res = op.evaluate({&a, &b}, {false}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); + // z->printIndexedBuffer("Output_12.4"); + // z->printShapeInfo("Output_12.4 shape"); + // MmulHelper::matmul(&a, z, &exp, false, false); -// z->printIndexedBuffer("MatrixSolveLS4"); + // z->printIndexedBuffer("MatrixSolveLS4"); - ASSERT_TRUE(exp.equalsTo(z)); - + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, SolveLs_Test_5) { + auto a = NDArrayFactory::create('c', {1, 0, 3, 4}); + auto b = NDArrayFactory::create('c', {1, 0, 3, 1}); - auto a = NDArrayFactory::create('c', {1, 0, 3, 4}); - auto b = NDArrayFactory::create('c', {1, 0, 3, 1}); - - sd::ops::lstsq op; + sd::ops::lstsq op; - auto res = op.evaluate({&a, &b}, {false}); - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - auto z = res.at(0); - ASSERT_TRUE(z->isEmpty()); - - + auto res = op.evaluate({&a, &b}, {false}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); + ASSERT_TRUE(z.isEmpty()); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, Solve_Test_6) { + auto a = NDArrayFactory::create('c', {1, 0, 3, 3}); + auto b = NDArrayFactory::create('c', {1, 0, 3, 1}); - auto a = NDArrayFactory::create('c', {1, 0, 3, 3}); - auto b = NDArrayFactory::create('c', {1, 0, 3, 1}); - - sd::ops::solve op; + sd::ops::solve op; - auto res = op.evaluate({&a, &b}, {true}); - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - auto z = res.at(0); - ASSERT_TRUE(z->isEmpty()); - - + auto res = op.evaluate({&a, &b}, {true}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); + ASSERT_TRUE(z.isEmpty()); } ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, TriangularSolve_Test_6) { + auto a = + NDArrayFactory::create('c', {4, 4}, + {5.f, 1.f, -3.f, 3.f, 0.f, 1.f, 1.f, -1.f, + 0.f, 0.f, 2.f, -9.f, 0.f, 0.f, 0.f, 4.f}); - auto a = NDArrayFactory::create('c', {4, 4}, { - 5.f, 1.f, -3.f, 3.f, - 0.f, 1.f, 1.f, -1.f, - 0.f, 0.f, 2.f, -9.f, - 0.f, 0.f, 0.f, 4.f - }); + auto b = NDArrayFactory::create( + 'c', {4, 2}, {5.f, 1.f, 2.f, 1.f, 0.f, 1.f, -3.f, 1.f}); - auto b = NDArrayFactory::create('c', {4, 2}, { - 5.f, 1.f, 2.f, 1.f, 0.f, 1.f, -3.f, 1.f - }); - - auto exp = NDArrayFactory::create('c', {4, 2}, { - 1.f,0.2f, 1.f,0.8f, 1.f,0.4f, 1.f,1.2f - }); + auto exp = NDArrayFactory::create( + 'c', {4, 2}, {1.f, 0.2f, 1.f, 0.8f, 1.f, 0.4f, 1.f, 1.2f}); - sd::ops::triangular_solve op; + sd::ops::triangular_solve op; - auto res = op.evaluate({&a, &b}, {}, {}, {false, true}); - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - auto z = res.at(0); + auto res = op.evaluate({&a, &b}, {}, {}, {false, true}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); - z->printIndexedBuffer("TriangularSolve with adjoint"); + // z.printIndexedBuffer("TriangularSolve with adjoint"); - ASSERT_TRUE(exp.equalsTo(z)); - + ASSERT_TRUE(exp.equalsTo(z)); } diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp index 639d90389dd7..04d9cb615996 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp @@ -14,2845 +14,3387 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - // // Created by raver on 8/4/2018. // -#include "testlayers.h" -#include #include -#include #include -#include #include +#include +#include -using namespace sd; +#include +#include "testlayers.h" -class DeclarableOpsTests13 : public testing::Test { -public: +using namespace sd; - DeclarableOpsTests13() { - //printf("\n"); - //fflush(stdout); - } +class DeclarableOpsTests13 : public testing::Test { + public: + DeclarableOpsTests13() { + } }; template class TypedDeclarableOpsTests13 : public testing::Test { -public: - - TypedDeclarableOpsTests13() { - printf("\n"); - fflush(stdout); - } + public: + TypedDeclarableOpsTests13() { + printf("\n"); + fflush(stdout); + } }; typedef ::testing::Types TestingTypes; -TYPED_TEST_CASE(TypedDeclarableOpsTests13, TestingTypes); +TYPED_TEST_SUITE(TypedDeclarableOpsTests13, TestingTypes); TEST_F(DeclarableOpsTests13, test_pow_1) { - auto x = NDArrayFactory::create('c', {2, 2}, {2.f, 2.f, 2.f, 2.f}); - auto y = NDArrayFactory::create('c', {2}, {3, 3}); - auto e = NDArrayFactory::create('c', {2, 2}, {8.f, 8.f, 8.f, 8.f}); + auto x = NDArrayFactory::create('c', {2, 2}, {2.f, 2.f, 2.f, 2.f}); + auto y = NDArrayFactory::create('c', {2}, {3, 3}); + auto e = NDArrayFactory::create('c', {2, 2}, {8.f, 8.f, 8.f, 8.f}); - sd::ops::Pow op; - auto result = op.evaluate({&x, &y}); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::Pow op; + auto result = op.evaluate({&x, &y}); + ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_EQ(e, *z); + ASSERT_EQ(e, z); } TEST_F(DeclarableOpsTests13, test_empty_range_1) { - auto start = NDArrayFactory::create(0); - auto limit = NDArrayFactory::create(0); - - sd::ops::range op; - auto result = op.evaluate({&start, &limit}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - ASSERT_TRUE(z->isEmpty()); + auto start = NDArrayFactory::create(0); + auto limit = NDArrayFactory::create(0); + sd::ops::range op; + auto result = op.evaluate({&start, &limit}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + ASSERT_TRUE(z.isEmpty()); } TEST_F(DeclarableOpsTests13, test_empty_range_2) { + sd::ops::range op; + auto result = op.evaluate({}, {1.0, 1.0}); + ASSERT_EQ(Status::OK(), result.status()); - sd::ops::range op; - auto result = op.evaluate({}, {1.0, 1.0}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - ASSERT_TRUE(z->isEmpty()); + auto z = result.at(0); + ASSERT_TRUE(z.isEmpty()); } TEST_F(DeclarableOpsTests13, test_empty_range_3) { + sd::ops::range op; + auto result = op.evaluate({}, {1, 1}); + ASSERT_EQ(Status::OK(), result.status()); - sd::ops::range op; - auto result = op.evaluate({}, {1, 1}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - ASSERT_TRUE(z->isEmpty()); + auto z = result.at(0); + ASSERT_TRUE(z.isEmpty()); } TEST_F(DeclarableOpsTests13, test_argmax_edge_1) { - auto ctx = new Context(1); - auto arr = NDArrayFactory::create_('c', {1024,1}); + auto ctx = Context(1); + auto arr = NDArrayFactory::create('c', {1024, 1}); - ctx->setInputArray(0, arr, true); - ctx->setOutputArray(0, NDArrayFactory::create_('c', {1}), true); - ctx->setInputArray(1, NDArrayFactory::create_(0), true); //Axis 0 + ctx.setInputArray(0, arr); + ctx.setOutputArray(0, NDArrayFactory::create('c', {1})); + ctx.setInputArray(1, NDArrayFactory::create(0)); // Axis 0 - - sd::ops::argmax op; - auto result = op.execute(ctx); - ASSERT_EQ(Status::OK(), result); - - //nd4j_printf("Done\n",""); - delete ctx; + sd::ops::argmax op; + auto result = op.execute(&ctx); + ASSERT_EQ(Status::OK(), result); } TEST_F(DeclarableOpsTests13, test_add_1) { - auto x = NDArrayFactory::create('c', {1, 768}); - auto y = NDArrayFactory::create('c', {768}); - auto e = NDArrayFactory::create('c', {1, 768});; - y. assign(1.0f); - e.assign(1.0f); + auto x = NDArrayFactory::create('c', {1, 768}); + auto y = NDArrayFactory::create('c', {768}); + auto e = NDArrayFactory::create('c', {1, 768}); + ; + y.assign(1.0f); + e.assign(1.0f); - x += y; + x += y; - ASSERT_EQ(e, x); + ASSERT_EQ(e, x); } TEST_F(DeclarableOpsTests13, test_listdiff_1) { - auto x = NDArrayFactory::create('c', {4}, {0, 1, 2, 3}); - auto y = NDArrayFactory::create('c', {2}, {3, 1}); + auto x = NDArrayFactory::create('c', {4}, {0, 1, 2, 3}); + auto y = NDArrayFactory::create('c', {2}, {3, 1}); - auto od = NDArrayFactory::create('c', {2}); - auto oi = NDArrayFactory::create('c', {2}); + auto od = NDArrayFactory::create('c', {2}); + auto oi = NDArrayFactory::create('c', {2}); - sd::ops::listdiff op; - auto result = op.execute({&x, &y}, std::vector{&od, &oi}, {}, {}, {}); - ASSERT_EQ(Status::OK(), result); + sd::ops::listdiff op; + auto result = + op.execute({&x, &y}, std::vector{&od, &oi}, {}, {}, {}); + ASSERT_EQ(Status::OK(), result); } TEST_F(DeclarableOpsTests13, test_greater_1) { - auto x = NDArrayFactory::create('c', {3, 1}); - auto y = NDArrayFactory::create('c', {1, 4}); + auto x = NDArrayFactory::create('c', {3, 1}); + auto y = NDArrayFactory::create('c', {1, 4}); - sd::ops::greater op; - auto result = op.evaluate({&x, &y}); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::greater op; + auto result = op.evaluate({&x, &y}); + ASSERT_EQ(Status::OK(), result.status()); } TEST_F(DeclarableOpsTests13, test_eval_reduction_shape_1) { - Nd4jLong axis = 0L; - auto x = NDArrayFactory::create('c', {2}, {4, 2}); - auto y = NDArrayFactory::create('c', {1}, {axis}); - auto exp = NDArrayFactory::create('c', {2}, {1, 2}); + Nd4jLong axis = 0L; + auto x = NDArrayFactory::create('c', {2}, {4, 2}); + auto y = NDArrayFactory::create('c', {1}, {axis}); + auto exp = NDArrayFactory::create('c', {2}, {1, 2}); - sd::ops::evaluate_reduction_shape op; - auto result = op.evaluate({&x, &y}, {true}); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::evaluate_reduction_shape op; + auto result = op.evaluate({&x, &y}, {true}); + ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_EQ(exp, *z); + ASSERT_EQ(exp, z); } TEST_F(DeclarableOpsTests13, test_or_1) { + NDArray x('c', {4}, {false, true, false, true}, sd::DataType::BOOL); + NDArray y('c', {4}, {false, false, true, true}, sd::DataType::BOOL); + NDArray e('c', {4}, {false, true, true, true}, sd::DataType::BOOL); - NDArray x('c', {4}, {false, true, false, true}, sd::DataType::BOOL); - NDArray y('c', {4}, {false, false, true, true}, sd::DataType::BOOL); - NDArray e('c', {4}, {false, true, true, true}, sd::DataType::BOOL); + NDArray z('c', {4}, sd::DataType::BOOL); - NDArray z('c', {4}, sd::DataType::BOOL); + x.applyPairwiseTransform(pairwise::Or, y, z); - x.applyPairwiseTransform(pairwise::Or, y, z); - - ASSERT_EQ(e, z); + ASSERT_EQ(e, z); } TEST_F(DeclarableOpsTests13, test_and_1) { - auto x = NDArrayFactory::create('c', {4}, {false, true, false, true}); - auto y = NDArrayFactory::create('c', {4}, {false, false, true, true}); - auto e = NDArrayFactory::create('c', {4}, {false, false, false, true}); + auto x = NDArrayFactory::create('c', {4}, {false, true, false, true}); + auto y = NDArrayFactory::create('c', {4}, {false, false, true, true}); + auto e = NDArrayFactory::create('c', {4}, {false, false, false, true}); - auto z = NDArrayFactory::create('c', {4}); + auto z = NDArrayFactory::create('c', {4}); - x.applyPairwiseTransform(pairwise::And, y, z); + x.applyPairwiseTransform(pairwise::And, y, z); - ASSERT_EQ(e, z); + ASSERT_EQ(e, z); } TEST_F(DeclarableOpsTests13, test_xor_1) { - auto x = NDArrayFactory::create('c', {4}, {false, true, false, true}); - auto y = NDArrayFactory::create('c', {4}, {false, false, true, true}); - auto e = NDArrayFactory::create('c', {4}, {false, true, true, false}); + auto x = NDArrayFactory::create('c', {4}, {false, true, false, true}); + auto y = NDArrayFactory::create('c', {4}, {false, false, true, true}); + auto e = NDArrayFactory::create('c', {4}, {false, true, true, false}); - auto z = NDArrayFactory::create('c', {4}); + auto z = NDArrayFactory::create('c', {4}); - x.applyPairwiseTransform(pairwise::Xor, y, z); + x.applyPairwiseTransform(pairwise::Xor, y, z); - ASSERT_EQ(e, z); + ASSERT_EQ(e, z); } TEST_F(DeclarableOpsTests13, BarnesHutTsne_GainsTest_1) { - auto x = NDArrayFactory::create('c', {2,3}, {1,2,3, 4, 5, 6}); - auto y = NDArrayFactory::create('c', {2,3}, {1,-2,3, -4, 5, -6}); - auto eps = NDArrayFactory::create('c', {2,3}, {-0.1, 0.2, -0.3, 0.4, -0.5, 0.6}); - auto exp = NDArrayFactory::create('c', {2,3}, {1.2,2.2,3.2,4.2,5.2,6.2}); - sd::ops::barnes_gains op; - auto result = op.evaluate({&x, &y, &eps}); - ASSERT_EQ(result.status(), Status::OK()); - //result.at(0)->printBuffer("Gains out"); - ASSERT_TRUE(exp.equalsTo(result.at(0))); + auto x = NDArrayFactory::create('c', {2, 3}, {1, 2, 3, 4, 5, 6}); + auto y = NDArrayFactory::create('c', {2, 3}, {1, -2, 3, -4, 5, -6}); + auto eps = NDArrayFactory::create('c', {2, 3}, + {-0.1, 0.2, -0.3, 0.4, -0.5, 0.6}); + auto exp = NDArrayFactory::create('c', {2, 3}, + {1.2, 2.2, 3.2, 4.2, 5.2, 6.2}); + sd::ops::barnes_gains op; + auto result = op.evaluate({&x, &y, &eps}); + ASSERT_EQ(result.status(), Status::OK()); + // result.at(0)->printBuffer("Gains out"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } TEST_F(DeclarableOpsTests13, BarnesHutTsne_GainsTest_2) { - auto x = NDArrayFactory::create('c', {2,3}, {1, -2, 3, -4, 5, -6}); - auto y = NDArrayFactory::create('c', {2,3}, {1, -2, 3, -4, 5, -6}); - auto eps = NDArrayFactory::create('c', {2,3}, {-0.1, 0.2, -0.3, 0.4, -0.5, 0.6}); - auto exp = NDArrayFactory::create('c', {2,3}, {1.2, 0.01, 3.2, 0.01, 5.2, 0.01}); - sd::ops::barnes_gains op; - auto result = op.evaluate({&x, &y, &eps}, {}, {}); - ASSERT_EQ(result.status(), Status::OK()); - //result.at(0)->printBuffer("Gains out"); - ASSERT_TRUE(exp.equalsTo(result.at(0))); + auto x = NDArrayFactory::create('c', {2, 3}, {1, -2, 3, -4, 5, -6}); + auto y = NDArrayFactory::create('c', {2, 3}, {1, -2, 3, -4, 5, -6}); + auto eps = NDArrayFactory::create('c', {2, 3}, + {-0.1, 0.2, -0.3, 0.4, -0.5, 0.6}); + auto exp = NDArrayFactory::create('c', {2, 3}, + {1.2, 0.01, 3.2, 0.01, 5.2, 0.01}); + sd::ops::barnes_gains op; + auto result = op.evaluate({&x, &y, &eps}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + // result.at(0)->printBuffer("Gains out"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); - //ASSERT_EQ(e, z); + // ASSERT_EQ(e, z); } TEST_F(DeclarableOpsTests13, BarnesHutTsne_GainsTest_3) { - auto x = NDArrayFactory::create('c', {2,3}, {-1, 2, -3, 4, -5, 6}); - auto y = NDArrayFactory::create('c', {2,3}, {-0.1,-2,3, -4, -0.5, -6}); - auto eps = NDArrayFactory::create('c', {2,3}, {-0.1, 0.2, -0.3, 0.4, -0.5, 0.6}); - auto exp = NDArrayFactory::create('c', {2,3}, {0.01, 2.2, 0.01, 4.2, 0.01, 6.2}); - sd::ops::barnes_gains op; - auto result = op.evaluate({&x, &y, &eps}, {}, {}); - ASSERT_EQ(result.status(), Status::OK()); - //result.at(0)->printBuffer("Gains out"); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - + auto x = NDArrayFactory::create('c', {2, 3}, {-1, 2, -3, 4, -5, 6}); + auto y = + NDArrayFactory::create('c', {2, 3}, {-0.1, -2, 3, -4, -0.5, -6}); + auto eps = NDArrayFactory::create('c', {2, 3}, + {-0.1, 0.2, -0.3, 0.4, -0.5, 0.6}); + auto exp = NDArrayFactory::create('c', {2, 3}, + {0.01, 2.2, 0.01, 4.2, 0.01, 6.2}); + sd::ops::barnes_gains op; + auto result = op.evaluate({&x, &y, &eps}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + // result.at(0)->printBuffer("Gains out"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } TEST_F(DeclarableOpsTests13, BarnesHutTsne_EdgeForceTest_1) { - auto data = NDArrayFactory::create('c', {5,4}); - auto rows = NDArrayFactory::create('c', {2}, {2, 3}); - auto cols = NDArrayFactory::create('c', {5}, {0, 2, 1, 4, 3}); - auto vals = NDArrayFactory::create('c', {5}, {10., 20., 30., 40., 50.}); - //auto buf = NDArrayFactory::create('c', {4}); - auto exp1 = NDArrayFactory::create('c', {5,4}, {-1.846154, -1.846154, -1.846154, -1.846154, 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.}); - //auto exp2 = NDArrayFactory::create({-4., -4., -4., -4. - //std::vector exp({&exp1, &exp2}); - data.linspace(1); - -// auto y = NDArrayFactory::create('c', {2,3}, {-0.1,-2,3, -4, -0.5, -6}); -// auto eps = NDArrayFactory::create('c', {2,3}, {-0.1, 0.2, -0.3, 0.4, -0.5, 0.6}); -// auto exp = NDArrayFactory::create('c', {2,3}, {1, 2, 1, 2, 2, 2}); - sd::ops::barnes_edge_forces op; - auto result = op.evaluate({&rows, &cols, &vals, &data}, {}, {1}); - - - ASSERT_EQ(result.status(), Status::OK()); - //result.at(0)->printBuffer("Output"); - ASSERT_TRUE(exp1.equalsTo(result.at(0))); - + auto data = NDArrayFactory::create('c', {5, 4}); + auto rows = NDArrayFactory::create('c', {2}, {2, 3}); + auto cols = NDArrayFactory::create('c', {5}, {0, 2, 1, 4, 3}); + auto vals = + NDArrayFactory::create('c', {5}, {10., 20., 30., 40., 50.}); + // auto buf = NDArrayFactory::create('c', {4}); + auto exp1 = NDArrayFactory::create( + 'c', {5, 4}, + {-1.846154, -1.846154, -1.846154, -1.846154, 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.}); + // auto exp2 = NDArrayFactory::create({-4., -4., -4., -4. + // std::vector exp({&exp1, &exp2}); + data.linspace(1); + + // auto y = NDArrayFactory::create('c', {2,3}, {-0.1,-2,3, -4, + // -0.5, -6}); auto eps = NDArrayFactory::create('c', {2,3}, {-0.1, + // 0.2, -0.3, 0.4, -0.5, 0.6}); auto exp = + // NDArrayFactory::create('c', {2,3}, {1, 2, 1, 2, 2, 2}); + sd::ops::barnes_edge_forces op; + auto result = op.evaluate({&rows, &cols, &vals, &data}, {}, {1}); + + ASSERT_EQ(result.status(), Status::OK()); + // result.at(0)->printBuffer("Output"); + ASSERT_TRUE(exp1.equalsTo(result.at(0))); } TEST_F(DeclarableOpsTests13, BarnesHutTsne_EdgeForceTest_2) { - auto data = NDArrayFactory::create('c', {5,4}); - auto rows = NDArrayFactory::create('c', {3}, {1,2,3}); - auto cols = NDArrayFactory::create('c', {5}, {1, 2, 0, 4, 3}); - auto vals = NDArrayFactory::create('c', {5}, {10., 20., 30., 40., 50.}); - //auto buf = NDArrayFactory::create('c', {4}); - auto exp = NDArrayFactory::create('c', {5,4}, {-0.622568, -0.622568, -0.622568, -0.622568, 1.846154, 1.846154, 1.846154, 1.846154, 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.}); - //auto exp2 = NDArrayFactory::create({-4., -4., -4., -4. - //std::vector exp({&exp1, &exp2}); - data.linspace(1); - -// auto y = NDArrayFactory::create('c', {2,3}, {-0.1,-2,3, -4, -0.5, -6}); -// auto eps = NDArrayFactory::create('c', {2,3}, {-0.1, 0.2, -0.3, 0.4, -0.5, 0.6}); -// auto exp = NDArrayFactory::create('c', {2,3}, {1, 2, 1, 2, 2, 2}); - sd::ops::barnes_edge_forces op; - auto result = op.evaluate({&rows, &cols, &vals, &data}, {}, {2}); - - - ASSERT_EQ(result.status(), Status::OK()); - //result.at(0)->printBuffer("Output"); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - + auto data = NDArrayFactory::create('c', {5, 4}); + auto rows = NDArrayFactory::create('c', {3}, {1, 2, 3}); + auto cols = NDArrayFactory::create('c', {5}, {1, 2, 0, 4, 3}); + auto vals = + NDArrayFactory::create('c', {5}, {10., 20., 30., 40., 50.}); + // auto buf = NDArrayFactory::create('c', {4}); + auto exp = NDArrayFactory::create( + 'c', {5, 4}, + {-0.622568, -0.622568, -0.622568, -0.622568, 1.846154, 1.846154, 1.846154, + 1.846154, 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0.}); + // auto exp2 = NDArrayFactory::create({-4., -4., -4., -4. + // std::vector exp({&exp1, &exp2}); + data.linspace(1); + + // auto y = NDArrayFactory::create('c', {2,3}, {-0.1,-2,3, -4, + // -0.5, -6}); auto eps = NDArrayFactory::create('c', {2,3}, {-0.1, + // 0.2, -0.3, 0.4, -0.5, 0.6}); auto exp = + // NDArrayFactory::create('c', {2,3}, {1, 2, 1, 2, 2, 2}); + sd::ops::barnes_edge_forces op; + auto result = op.evaluate({&rows, &cols, &vals, &data}, {}, {2}); + + ASSERT_EQ(result.status(), Status::OK()); + // result.at(0)->printBuffer("Output"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } TEST_F(DeclarableOpsTests13, BarnesHutTsne_EdgeForceTest_3) { - auto data = NDArrayFactory::create('c', {11, 5}, {0.3, 0.2625, 0.2674, 0.8604, 0.4803, 0.1096, 0.795, 0.5918, 0.2738, 0.952, 0.969, 0.8586, 0.8088, 0.5338, 0.5961, 0.7187, 0.463, 0.0867, 0.7748, 0.4802, 0.2493, 0.3227, 0.3064, 0.698, 0.7977, 0.7674, 0.168, 0.3107, 0.0217, 0.138, 0.8619, 0.8413, 0.5285, 0.9703, 0.6774, 0.2624, 0.4374, 0.1569, 0.1107, 0.0601, 0.4094, 0.9564, 0.5994, 0.8279, 0.3859, 0.6202, 0.7604, 0.0788, 0.0865, 0.7445, 0.6548, 0.3385, 0.0582, 0.6249, 0.7432}); - auto rows = NDArrayFactory::create({0, 9, 18, 27, 36, 45, 54, 63, 72, 81, 90, 99}); - auto cols = NDArrayFactory::create({4, 3, 10, 8, 6, 7, 1, 5, 9, 4, 9, 8, 10, 2, 0, 6, 7, 3, 6, 8, 3, 9, 10, 1, 4, 0, 5, 10, 0, 4, 6, 8, 9, 2, 5, 7, 0, 10, 3, 1, 8, 9, 6, 7, 2, 7, 9, 3, 10, 0, 4, 2, 8, 1, 2, 8, 3, 10, 0, 4, 9, 1, 5, 5, 9, 0, 3, 10, 4, 8, 1, 2, 6, 2, 0, 3, 4, 1, 10, 9, 7, 10, 1, 3, 7, 4, 5, 2, 8, 6, 3, 4, 0, 9, 6, 5, 8, 7, 1}); - auto vals = NDArrayFactory::create({0.6199614579042966, 0.19644097697184246, 0.13824979367331638, 0.01949900138247239, 0.008923198738222747, 0.008392793826291798, 0.0033348224714784204, 0.0026246189757042166, 0.0025733360563748838, 0.5877136110798608, 0.28250257562439585, 0.08098135424273815, 0.014862718272075049, 0.01219187321450782, 0.01152346362368888, 0.004243137936786281, 0.0034626999030188577, 0.0025185661029283168, 0.6777005651521399, 0.18321248222489303, 0.04018202465629351, 0.02941935889988646, 0.02164146250842832, 0.019898422145651618, 0.011683461395713935, 0.008439076090480863, 0.007823146926512332, 0.6770900431883232, 0.16617511239723026, 0.06039349887686468, 0.04650913399744179, 0.016886531410284355, 0.014591049666869658, 0.006407638669806174, 0.006074413005122801, 0.0058725787880570205, 0.6278185083409108, 0.235127797795446, 0.07023700015217448, 0.030885483448633774, 0.01229522088606573, 0.009238279699136107, 0.008219511168822047, 0.004303744819835723, 0.0018744536889749907, 0.7122603898978483, 0.07862620103245824, 0.07061257369349086, 0.06721483653169834, 0.028957853952131768, 0.01778978123182596, 0.01481713955181034, 0.005492728917348627, 0.0042284951913875955, 0.5266844101016999, 0.3304104787383107, 0.10930017433210941, 0.018514917515240075, 0.006969360999637938, 0.0063776901975396, 0.0010590388116165708, 6.526830884629785E-4, 3.1246215383067865E-5, 0.7176179284835663, 0.08741734015883978, 0.05927699083866909, 0.04663169573956976, 0.03287576269194147, 0.02993912340339554, 0.013365238657916641, 0.010616858763291145, 0.002259061262810172, 0.6891905160321706, 0.1397658294110526, 0.05438284759722162, 0.05437184733708826, 0.028683289714498808, 0.020986120697576355, 0.007218358114741088, 0.0032834770669826364, 0.002117714028667893, 0.6823873496503976, 0.1345267083671607, 0.08712863515505885, 0.04286621088946242, 0.02544804597749639, 0.01689343932533317, 0.007219134659004873, 0.0019232929717404616, 0.0016071830043453991, 0.6425809622897437, 0.18474464886441516, 0.10897036475298316, 0.03466939253836615, 0.013288054277817787, 0.005149178177380355, 0.0037974063158903518, 0.0037851733015991287, 0.0030148194818042273}); - //auto buf = NDArrayFactory::create('c', {4}); - auto exp = NDArrayFactory::create('c', {11, 5}, {-0.080205, -0.085862, 0.024045, 0.133551, -0.199896, -0.170597, 0.187301, 0.205824, -0.165268, 0.131228, 0.155135, 0.021446, 0.217583, -0.262873, -0.021075, 0.114537, 0.088023, -0.039205, 0.087984, -0.179565, -0.132683, 0.003677, 0.072081, -0.068737, 0.204481, 0.287223, -0.193989, 0.104569, -0.123401, -0.036368, 0.086745, 0.002961, -0.091327, 0.234853, 0.120270, -0.304006, 0.128305, -0.084867, -0.017550, -0.130837, -0.288569, 0.124679, 0.054078, -0.034187, -0.192599, 0.033196, 0.228182, -0.044972, -0.314217, 0.020287, 0.054427, -0.078887, -0.078246, -0.104543, 0.169803}); - //auto exp2 = NDArrayFactory::create({-4., -4., -4., -4. - //std::vector exp({&exp1, &exp2}); - //data.assign(1.0); //linspace(1); - -// auto y = NDArrayFactory::create('c', {2,3}, {-0.1,-2,3, -4, -0.5, -6}); -// auto eps = NDArrayFactory::create('c', {2,3}, {-0.1, 0.2, -0.3, 0.4, -0.5, 0.6}); -// auto exp = NDArrayFactory::create('c', {2,3}, {1, 2, 1, 2, 2, 2}); - sd::ops::barnes_edge_forces op; - auto result = op.evaluate({&rows, &cols, &vals, &data}, {}, {11}); - - //nd4j_printf("rows %lld, cols %lld, vals %lld, res full %lld\n", rows.lengthOf(), cols.lengthOf(), vals.lengthOf(), exp1.lengthOf()); - ASSERT_EQ(result.status(), Status::OK()); - //result.at(0)->printBuffer("Output"); - //exp.printBuffer("Expect"); - //result.at(0)->printShapeInfo("Shape output"); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - + auto data = NDArrayFactory::create( + 'c', {11, 5}, + {0.3, 0.2625, 0.2674, 0.8604, 0.4803, 0.1096, 0.795, 0.5918, + 0.2738, 0.952, 0.969, 0.8586, 0.8088, 0.5338, 0.5961, 0.7187, + 0.463, 0.0867, 0.7748, 0.4802, 0.2493, 0.3227, 0.3064, 0.698, + 0.7977, 0.7674, 0.168, 0.3107, 0.0217, 0.138, 0.8619, 0.8413, + 0.5285, 0.9703, 0.6774, 0.2624, 0.4374, 0.1569, 0.1107, 0.0601, + 0.4094, 0.9564, 0.5994, 0.8279, 0.3859, 0.6202, 0.7604, 0.0788, + 0.0865, 0.7445, 0.6548, 0.3385, 0.0582, 0.6249, 0.7432}); + auto rows = NDArrayFactory::create( + {0, 9, 18, 27, 36, 45, 54, 63, 72, 81, 90, 99}); + auto cols = NDArrayFactory::create( + {4, 3, 10, 8, 6, 7, 1, 5, 9, 4, 9, 8, 10, 2, 0, 6, 7, 3, 6, 8, + 3, 9, 10, 1, 4, 0, 5, 10, 0, 4, 6, 8, 9, 2, 5, 7, 0, 10, 3, 1, + 8, 9, 6, 7, 2, 7, 9, 3, 10, 0, 4, 2, 8, 1, 2, 8, 3, 10, 0, 4, + 9, 1, 5, 5, 9, 0, 3, 10, 4, 8, 1, 2, 6, 2, 0, 3, 4, 1, 10, 9, + 7, 10, 1, 3, 7, 4, 5, 2, 8, 6, 3, 4, 0, 9, 6, 5, 8, 7, 1}); + auto vals = NDArrayFactory::create( + {0.6199614579042966, 0.19644097697184246, 0.13824979367331638, + 0.01949900138247239, 0.008923198738222747, 0.008392793826291798, + 0.0033348224714784204, 0.0026246189757042166, 0.0025733360563748838, + 0.5877136110798608, 0.28250257562439585, 0.08098135424273815, + 0.014862718272075049, 0.01219187321450782, 0.01152346362368888, + 0.004243137936786281, 0.0034626999030188577, 0.0025185661029283168, + 0.6777005651521399, 0.18321248222489303, 0.04018202465629351, + 0.02941935889988646, 0.02164146250842832, 0.019898422145651618, + 0.011683461395713935, 0.008439076090480863, 0.007823146926512332, + 0.6770900431883232, 0.16617511239723026, 0.06039349887686468, + 0.04650913399744179, 0.016886531410284355, 0.014591049666869658, + 0.006407638669806174, 0.006074413005122801, 0.0058725787880570205, + 0.6278185083409108, 0.235127797795446, 0.07023700015217448, + 0.030885483448633774, 0.01229522088606573, 0.009238279699136107, + 0.008219511168822047, 0.004303744819835723, 0.0018744536889749907, + 0.7122603898978483, 0.07862620103245824, 0.07061257369349086, + 0.06721483653169834, 0.028957853952131768, 0.01778978123182596, + 0.01481713955181034, 0.005492728917348627, 0.0042284951913875955, + 0.5266844101016999, 0.3304104787383107, 0.10930017433210941, + 0.018514917515240075, 0.006969360999637938, 0.0063776901975396, + 0.0010590388116165708, 6.526830884629785E-4, 3.1246215383067865E-5, + 0.7176179284835663, 0.08741734015883978, 0.05927699083866909, + 0.04663169573956976, 0.03287576269194147, 0.02993912340339554, + 0.013365238657916641, 0.010616858763291145, 0.002259061262810172, + 0.6891905160321706, 0.1397658294110526, 0.05438284759722162, + 0.05437184733708826, 0.028683289714498808, 0.020986120697576355, + 0.007218358114741088, 0.0032834770669826364, 0.002117714028667893, + 0.6823873496503976, 0.1345267083671607, 0.08712863515505885, + 0.04286621088946242, 0.02544804597749639, 0.01689343932533317, + 0.007219134659004873, 0.0019232929717404616, 0.0016071830043453991, + 0.6425809622897437, 0.18474464886441516, 0.10897036475298316, + 0.03466939253836615, 0.013288054277817787, 0.005149178177380355, + 0.0037974063158903518, 0.0037851733015991287, 0.0030148194818042273}); + // auto buf = NDArrayFactory::create('c', {4}); + auto exp = NDArrayFactory::create( + 'c', {11, 5}, + {-0.080205, -0.085862, 0.024045, 0.133551, -0.199896, -0.170597, + 0.187301, 0.205824, -0.165268, 0.131228, 0.155135, 0.021446, + 0.217583, -0.262873, -0.021075, 0.114537, 0.088023, -0.039205, + 0.087984, -0.179565, -0.132683, 0.003677, 0.072081, -0.068737, + 0.204481, 0.287223, -0.193989, 0.104569, -0.123401, -0.036368, + 0.086745, 0.002961, -0.091327, 0.234853, 0.120270, -0.304006, + 0.128305, -0.084867, -0.017550, -0.130837, -0.288569, 0.124679, + 0.054078, -0.034187, -0.192599, 0.033196, 0.228182, -0.044972, + -0.314217, 0.020287, 0.054427, -0.078887, -0.078246, -0.104543, + 0.169803}); + // auto exp2 = NDArrayFactory::create({-4., -4., -4., -4. + // std::vector exp({&exp1, &exp2}); + // data.assign(1.0); //linspace(1); + + // auto y = NDArrayFactory::create('c', {2,3}, {-0.1,-2,3, -4, + // -0.5, -6}); auto eps = NDArrayFactory::create('c', {2,3}, {-0.1, + // 0.2, -0.3, 0.4, -0.5, 0.6}); auto exp = + // NDArrayFactory::create('c', {2,3}, {1, 2, 1, 2, 2, 2}); + sd::ops::barnes_edge_forces op; + auto result = op.evaluate({&rows, &cols, &vals, &data}, {}, {11}); + + // nd4j_printf("rows %lld, cols %lld, vals %lld, res full %lld\n", + // rows.lengthOf(), cols.lengthOf(), vals.lengthOf(), exp1.lengthOf()); + ASSERT_EQ(result.status(), Status::OK()); + // result.at(0)->printBuffer("Output"); + // exp.printBuffer("Expect"); + // result.at(0)->printShapeInfo("Shape output"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } TEST_F(DeclarableOpsTests13, BarnesHutTsne_symmetrized_1) { -// auto data = NDArrayFactory::create('c', {5,4}); - auto rows = NDArrayFactory::create('c', {2}, {0, 1}); - auto cols = NDArrayFactory::create('c', {4}, {0, 1, 1, 0}); - auto vals = NDArrayFactory::create('c', {4}, {20., 30., 40., 50.}); - auto exp = NDArrayFactory::create('c', {1,1}, {20.}); -// data.linspace(1); - -// auto y = NDArrayFactory::create('c', {2,3}, {-0.1,-2,3, -4, -0.5, -6}); -// auto eps = NDArrayFactory::create('c', {2,3}, {-0.1, 0.2, -0.3, 0.4, -0.5, 0.6}); -// auto exp = NDArrayFactory::create('c', {2,3}, {1, 2, 1, 2, 2, 2}); - sd::ops::barnes_symmetrized op; - auto result = op.evaluate({&rows, &cols, &vals}, {}, {1}); - ASSERT_EQ(result.status(), Status::OK()); - //result.at(2)->printBuffer("Symmetrized1"); - ASSERT_TRUE(exp.equalsTo(result.at(2))); + // auto data = NDArrayFactory::create('c', {5,4}); + auto rows = NDArrayFactory::create('c', {2}, {0, 1}); + auto cols = NDArrayFactory::create('c', {4}, {0, 1, 1, 0}); + auto vals = NDArrayFactory::create('c', {4}, {20., 30., 40., 50.}); + auto exp = NDArrayFactory::create('c', {1, 1}, {20.}); + // data.linspace(1); + + // auto y = NDArrayFactory::create('c', {2,3}, {-0.1,-2,3, -4, + // -0.5, -6}); auto eps = NDArrayFactory::create('c', {2,3}, {-0.1, + // 0.2, -0.3, 0.4, -0.5, 0.6}); auto exp = + // NDArrayFactory::create('c', {2,3}, {1, 2, 1, 2, 2, 2}); + sd::ops::barnes_symmetrized op; + auto result = op.evaluate({&rows, &cols, &vals}, {}, {1}); + ASSERT_EQ(result.status(), Status::OK()); + // result.at(2)->printBuffer("Symmetrized1"); + ASSERT_TRUE(exp.equalsTo(result.at(2))); } TEST_F(DeclarableOpsTests13, BarnesHutTsne_symmetrized_2) { - auto rows = NDArrayFactory::create('c', {4}, {0, 2, 2, 3}); - auto cols = NDArrayFactory::create('c', {8}, {0, 1, 1, 0, 0, 1, 1, 1}); - auto vals = NDArrayFactory::create('c', {8}, {20., 30., 40., 50., 120., 130., 140., 150.}); - auto exp = NDArrayFactory::create('c', {1,5}, {20., 15., 15., 20., 20.}); -// data.linspace(1); - -// auto y = NDArrayFactory::create('c', {2,3}, {-0.1,-2,3, -4, -0.5, -6}); -// auto eps = NDArrayFactory::create('c', {2,3}, {-0.1, 0.2, -0.3, 0.4, -0.5, 0.6}); -// auto exp = NDArrayFactory::create('c', {2,3}, {1, 2, 1, 2, 2, 2}); - sd::ops::barnes_symmetrized op; - auto result = op.evaluate({&rows, &cols, &vals}, {}, {3}); - ASSERT_EQ(result.status(), Status::OK()); - //result.at(2)->printBuffer("Symmetrized2"); - // ASSERT_TRUE(exp[i]->equalsTo(result.at(i))); - ASSERT_TRUE(exp.equalsTo(result.at(2))); - + auto rows = NDArrayFactory::create('c', {4}, {0, 2, 2, 3}); + auto cols = NDArrayFactory::create('c', {8}, {0, 1, 1, 0, 0, 1, 1, 1}); + auto vals = NDArrayFactory::create( + 'c', {8}, {20., 30., 40., 50., 120., 130., 140., 150.}); + auto exp = + NDArrayFactory::create('c', {1, 5}, {20., 15., 15., 20., 20.}); + // data.linspace(1); + + // auto y = NDArrayFactory::create('c', {2,3}, {-0.1,-2,3, -4, + // -0.5, -6}); auto eps = NDArrayFactory::create('c', {2,3}, {-0.1, + // 0.2, -0.3, 0.4, -0.5, 0.6}); auto exp = + // NDArrayFactory::create('c', {2,3}, {1, 2, 1, 2, 2, 2}); + sd::ops::barnes_symmetrized op; + auto result = op.evaluate({&rows, &cols, &vals}, {}, {3}); + ASSERT_EQ(result.status(), Status::OK()); + // result.at(2)->printBuffer("Symmetrized2"); + // ASSERT_TRUE(exp[i]->equalsTo(result.at(i))); + ASSERT_TRUE(exp.equalsTo(result.at(2))); } TEST_F(DeclarableOpsTests13, BarnesHutTsne_symmetrized_3) { - auto rows = NDArrayFactory::create('c', {12}, {0, 2, 3, 5, 7, 8, 9, 11, 12, 14, 18, 21}); - auto cols = NDArrayFactory::create('c', {24}, {0, 1, 2, 3, 4, 5, 4, 3, 2, 1, 0, 1, 0, 2, 4, 3, 2, 1, 0, 1, 2, 3, 4, 5}); - auto vals = NDArrayFactory::create('c', {24}, {20., 30., 40., 50., 120., 130., 140., 150.,220., 230., 240., 250., 2120., 2130., 2140., 2150., 320., 330., 340., 350., 3120., 3130., 3140., 3150.}); - auto exp = NDArrayFactory::create('c', {1, 39}, {15.000000, 0.000000, 0.000000, 65.000000, 60.000000, 145.000000, 20.000000, 25.000000, 65.000000, 145.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000}); -// data.linspace(1); - -// auto y = NDArrayFactory::create('c', {2,3}, {-0.1,-2,3, -4, -0.5, -6}); -// auto eps = NDArrayFactory::create('c', {2,3}, {-0.1, 0.2, -0.3, 0.4, -0.5, 0.6}); -// auto exp = NDArrayFactory::create('c', {2,3}, {1, 2, 1, 2, 2, 2}); - sd::ops::barnes_symmetrized op; - auto result = op.evaluate({&rows, &cols, &vals}, {}, {11}); - ASSERT_EQ(result.status(), Status::OK()); - //result.at(2)->printBuffer("Symmetrized3"); - //exp.printBuffer("EXPect symm3"); - // ASSERT_TRUE(exp[i]->equalsTo(result.at(i))); - //ASSERT_TRUE(exp.equalsTo(result.at(0))); - + auto rows = NDArrayFactory::create( + 'c', {12}, {0, 2, 3, 5, 7, 8, 9, 11, 12, 14, 18, 21}); + auto cols = NDArrayFactory::create( + 'c', {24}, + {0, 1, 2, 3, 4, 5, 4, 3, 2, 1, 0, 1, 0, 2, 4, 3, 2, 1, 0, 1, 2, 3, 4, 5}); + auto vals = NDArrayFactory::create( + 'c', {24}, {20., 30., 40., 50., 120., 130., 140., 150., + 220., 230., 240., 250., 2120., 2130., 2140., 2150., + 320., 330., 340., 350., 3120., 3130., 3140., 3150.}); + auto exp = NDArrayFactory::create( + 'c', {1, 39}, + {15.000000, 0.000000, 0.000000, 65.000000, 60.000000, 145.000000, + 20.000000, 25.000000, 65.000000, 145.000000, 0.000000, 0.000000, + 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, + 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, + 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, + 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, + 0.000000, 0.000000, 0.000000}); + // data.linspace(1); + + // auto y = NDArrayFactory::create('c', {2,3}, {-0.1,-2,3, -4, + // -0.5, -6}); auto eps = NDArrayFactory::create('c', {2,3}, {-0.1, + // 0.2, -0.3, 0.4, -0.5, 0.6}); auto exp = + // NDArrayFactory::create('c', {2,3}, {1, 2, 1, 2, 2, 2}); + sd::ops::barnes_symmetrized op; + auto result = op.evaluate({&rows, &cols, &vals}, {}, {11}); + ASSERT_EQ(result.status(), Status::OK()); + // result.at(2)->printBuffer("Symmetrized3"); + // exp.printBuffer("EXPect symm3"); + // ASSERT_TRUE(exp[i]->equalsTo(result.at(i))); + // ASSERT_TRUE(exp.equalsTo(result.at(0))); } TEST_F(DeclarableOpsTests13, BarnesHutTsne_symmetrized_4) { - auto rows = NDArrayFactory::create({0, 9, 18, 27, 36, 45, 54, 63, 72, 81, 90, 99}); - auto cols = NDArrayFactory::create({4, 3, 10, 8, 6, 7, 1, 5, 9, 4, 9, 8, 10, 2, 0, 6, 7, 3, 6, 8, 3, 9, 10, 1, 4, 0, 5, 10, 0, 4, 6, 8, 9, 2, 5, 7, 0, 10, 3, 1, 8, 9, 6, 7, 2, 7, 9, 3, 10, 0, 4, 2, 8, 1, 2, 8, 3, 10, 0, 4, 9, 1, 5, 5, 9, 0, 3, 10, 4, 8, 1, 2, 6, 2, 0, 3, 4, 1, 10, 9, 7, 10, 1, 3, 7, 4, 5, 2, 8, 6, 3, 4, 0, 9, 6, 5, 8, 7, 1}); - auto vals = NDArrayFactory::create( {0.6200, 0.1964, 0.1382, 0.0195, 0.0089, 0.0084, 0.0033, 0.0026, 0.0026, 0.5877, 0.2825, 0.0810, 0.0149, 0.0122, 0.0115, 0.0042, 0.0035, 0.0025, 0.6777, 0.1832, 0.0402, 0.0294, 0.0216, 0.0199, 0.0117, 0.0084, 0.0078, 0.6771, 0.1662, 0.0604, 0.0465, 0.0169, 0.0146, 0.0064, 0.0061, 0.0059, 0.6278, 0.2351, 0.0702, 0.0309, 0.0123, 0.0092, 0.0082, 0.0043, 0.0019, 0.7123, 0.0786, 0.0706, 0.0672, 0.0290, 0.0178, 0.0148, 0.0055, 0.0042, 0.5267, 0.3304, 0.1093, 0.0185, 0.0070, 0.0064, 0.0011, 0.0007, 3.1246e-5, 0.7176, 0.0874, 0.0593, 0.0466, 0.0329, 0.0299, 0.0134, 0.0106, 0.0023, 0.6892, 0.1398, 0.0544, 0.0544, 0.0287, 0.0210, 0.0072, 0.0033, 0.0021, 0.6824, 0.1345, 0.0871, 0.0429, 0.0254, 0.0169, 0.0072, 0.0019, 0.0016, 0.6426, 0.1847, 0.1090, 0.0347, 0.0133, 0.0051, 0.0038, 0.0038, 0.0030}); - //auto exp = NDArrayFactory::create('c', {1, 39}, {15.000000, 0.000000, 0.000000, 65.000000, 60.000000, 145.000000, 20.000000, 25.000000, 65.000000, 145.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000}); -// data.linspace(1); - auto exp4 = NDArrayFactory::create('c', {1, 108}, {0.6239, 0.1813, 0.1236, 0.03695, 0.00795, 0.03385, 0.0074, 0.0158, 0.0013, 0.0042, 0.0074, 0.3093, 0.2085, 0.051, 0.00895, 0.01605, 0.00245, 0.00705, 0.00125, 0.0021, 0.01605, 0.6022, 0.1615, 0.0233, - 0.0183, 0.0108, 0.0068, 0.0042, 0.0113, 0.00115, 0.1813, 0.00125, 0.0233, 0.65985, 0.0653, 0.0779, 0.03565, 0.05085, 0.03835, 0.02625, 0.6239, 0.3093, 0.0068, 0.0653, 0.2099, 0.0205, 0.0173, 0.0073, - 0.0171, 0.0089, 0.0158, 0.0113, 0.03835, 0.71495, 0.04775, 0.03615, 0.0089, 0.00275, 0.0021, 1.5623E-5, 0.00795, 0.00245, 0.6022, 0.0779, 0.0073, 0.5098, 0.0159, 0.00135, 1.5623E-5, 0.03385, 0.00705, - 0.02625, 0.0171, 0.71495, 0.06515, 0.01835, 0.00775, 0.00115, 0.03695, 0.051, 0.1615, 0.03565, 0.0205, 0.00275, 0.5098, 0.00775, 0.0055, 0.0026, 0.0013, 0.2085, 0.0183, 0.05085, 0.0173, 0.04775, - 0.00135, 0.06515, 0.0026, 0.35855, 0.1236, 0.00895, 0.0108, 0.65985, 0.2099, 0.03615, 0.0159, 0.01835, 0.0055, 0.35855}); -// auto y = NDArrayFactory::create('c', {2,3}, {-0.1,-2,3, -4, -0.5, -6}); -// auto eps = NDArrayFactory::create('c', {2,3}, {-0.1, 0.2, -0.3, 0.4, -0.5, 0.6}); -// auto exp = NDArrayFactory::create('c', {2,3}, {1, 2, 1, 2, 2, 2}); - sd::ops::barnes_symmetrized op; - auto result = op.evaluate({&rows, &cols, &vals}, {}, {11}); - ASSERT_EQ(result.status(), Status::OK()); - auto res = result.at(2); + auto rows = NDArrayFactory::create( + {0, 9, 18, 27, 36, 45, 54, 63, 72, 81, 90, 99}); + auto cols = NDArrayFactory::create( + {4, 3, 10, 8, 6, 7, 1, 5, 9, 4, 9, 8, 10, 2, 0, 6, 7, 3, 6, 8, + 3, 9, 10, 1, 4, 0, 5, 10, 0, 4, 6, 8, 9, 2, 5, 7, 0, 10, 3, 1, + 8, 9, 6, 7, 2, 7, 9, 3, 10, 0, 4, 2, 8, 1, 2, 8, 3, 10, 0, 4, + 9, 1, 5, 5, 9, 0, 3, 10, 4, 8, 1, 2, 6, 2, 0, 3, 4, 1, 10, 9, + 7, 10, 1, 3, 7, 4, 5, 2, 8, 6, 3, 4, 0, 9, 6, 5, 8, 7, 1}); + auto vals = NDArrayFactory::create( + {0.6200, 0.1964, 0.1382, 0.0195, 0.0089, 0.0084, 0.0033, 0.0026, + 0.0026, 0.5877, 0.2825, 0.0810, 0.0149, 0.0122, 0.0115, 0.0042, + 0.0035, 0.0025, 0.6777, 0.1832, 0.0402, 0.0294, 0.0216, 0.0199, + 0.0117, 0.0084, 0.0078, 0.6771, 0.1662, 0.0604, 0.0465, 0.0169, + 0.0146, 0.0064, 0.0061, 0.0059, 0.6278, 0.2351, 0.0702, 0.0309, + 0.0123, 0.0092, 0.0082, 0.0043, 0.0019, 0.7123, 0.0786, 0.0706, + 0.0672, 0.0290, 0.0178, 0.0148, 0.0055, 0.0042, 0.5267, 0.3304, + 0.1093, 0.0185, 0.0070, 0.0064, 0.0011, 0.0007, 3.1246e-5, 0.7176, + 0.0874, 0.0593, 0.0466, 0.0329, 0.0299, 0.0134, 0.0106, 0.0023, + 0.6892, 0.1398, 0.0544, 0.0544, 0.0287, 0.0210, 0.0072, 0.0033, + 0.0021, 0.6824, 0.1345, 0.0871, 0.0429, 0.0254, 0.0169, 0.0072, + 0.0019, 0.0016, 0.6426, 0.1847, 0.1090, 0.0347, 0.0133, 0.0051, + 0.0038, 0.0038, 0.0030}); + // auto exp = NDArrayFactory::create('c', {1, 39}, {15.000000, + // 0.000000, 0.000000, 65.000000, 60.000000, + // 145.000000, 20.000000, 25.000000, 65.000000, 145.000000, 0.000000, + // 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, + // 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, + // 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, + // 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000}); + // data.linspace(1); + auto exp4 = NDArrayFactory::create( + 'c', {1, 108}, + {0.6239, 0.1813, 0.1236, 0.03695, 0.00795, 0.03385, 0.0074, + 0.0158, 0.0013, 0.0042, 0.0074, 0.3093, 0.2085, 0.051, + 0.00895, 0.01605, 0.00245, 0.00705, 0.00125, 0.0021, 0.01605, + 0.6022, 0.1615, 0.0233, 0.0183, 0.0108, 0.0068, 0.0042, + 0.0113, 0.00115, 0.1813, 0.00125, 0.0233, 0.65985, 0.0653, + 0.0779, 0.03565, 0.05085, 0.03835, 0.02625, 0.6239, 0.3093, + 0.0068, 0.0653, 0.2099, 0.0205, 0.0173, 0.0073, 0.0171, + 0.0089, 0.0158, 0.0113, 0.03835, 0.71495, 0.04775, 0.03615, + 0.0089, 0.00275, 0.0021, 1.5623E-5, 0.00795, 0.00245, 0.6022, + 0.0779, 0.0073, 0.5098, 0.0159, 0.00135, 1.5623E-5, 0.03385, + 0.00705, 0.02625, 0.0171, 0.71495, 0.06515, 0.01835, 0.00775, + 0.00115, 0.03695, 0.051, 0.1615, 0.03565, 0.0205, 0.00275, + 0.5098, 0.00775, 0.0055, 0.0026, 0.0013, 0.2085, 0.0183, + 0.05085, 0.0173, 0.04775, 0.00135, 0.06515, 0.0026, 0.35855, + 0.1236, 0.00895, 0.0108, 0.65985, 0.2099, 0.03615, 0.0159, + 0.01835, 0.0055, 0.35855}); + // auto y = NDArrayFactory::create('c', {2,3}, {-0.1,-2,3, -4, + // -0.5, -6}); auto eps = NDArrayFactory::create('c', {2,3}, {-0.1, + // 0.2, -0.3, 0.4, -0.5, 0.6}); auto exp = + // NDArrayFactory::create('c', {2,3}, {1, 2, 1, 2, 2, 2}); + sd::ops::barnes_symmetrized op; + auto result = op.evaluate({&rows, &cols, &vals}, {}, {11}); + ASSERT_EQ(result.status(), Status::OK()); + auto res = result.at(2); // res->printBuffer("Symmetrized4"); // exp4.printBuffer("Expected sym"); // nd4j_printf("Total res is {1, %lld}\n", res->lengthOf()); // nd4j_printf("Expected is {1, %lld}\n", exp4.lengthOf()); - //exp.printBuffer("EXPect symm3"); - // ASSERT_TRUE(exp[i]->equalsTo(result.at(i))); - ASSERT_TRUE(exp4.equalsTo(res)); - + // exp.printBuffer("EXPect symm3"); + // ASSERT_TRUE(exp[i]->equalsTo(result.at(i))); + ASSERT_TRUE(exp4.equalsTo(res)); } TEST_F(DeclarableOpsTests13, CellContains_test_1) { - - auto corners = NDArrayFactory::create( {0.5384, 0.5640, 0.3449, 0.5257, 0.5505}); - auto width = NDArrayFactory::create({0.4306, 0.3960, 0.4639, 0.5040, 0.4904}); - auto point = NDArrayFactory::create({0.3000, 0.2625, 0.2674, 0.8604, 0.4803}); - //auto exp = NDArrayFactory::create('c', {1, 39}, {15.000000, 0.000000, 0.000000, 65.000000, 60.000000, 145.000000, 20.000000, 25.000000, 65.000000, 145.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000}); - // data.linspace(1); - - // auto y = NDArrayFactory::create('c', {2,3}, {-0.1,-2,3, -4, -0.5, -6}); - // auto eps = NDArrayFactory::create('c', {2,3}, {-0.1, 0.2, -0.3, 0.4, -0.5, 0.6}); - // auto exp = NDArrayFactory::create('c', {2,3}, {1, 2, 1, 2, 2, 2}); - sd::ops::cell_contains op; - auto result = op.evaluate({&corners, &width, &point}, {}, {5}); - ASSERT_EQ(result.status(), Status::OK()); - ASSERT_TRUE(result.at(0)->e(0)); - //result.at(2)->printBuffer("Symmetrized3"); - //exp.printBuffer("EXPect symm3"); - // ASSERT_TRUE(exp[i]->equalsTo(result.at(i))); - //ASSERT_TRUE(exp.equalsTo(result.at(0))); - + auto corners = + NDArrayFactory::create({0.5384, 0.5640, 0.3449, 0.5257, 0.5505}); + auto width = + NDArrayFactory::create({0.4306, 0.3960, 0.4639, 0.5040, 0.4904}); + auto point = + NDArrayFactory::create({0.3000, 0.2625, 0.2674, 0.8604, 0.4803}); + // auto exp = NDArrayFactory::create('c', {1, 39}, {15.000000, + // 0.000000, 0.000000, 65.000000, 60.000000, + // 145.000000, 20.000000, 25.000000, 65.000000, 145.000000, 0.000000, + // 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, + // 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, + // 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, + // 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000}); + // data.linspace(1); + + // auto y = NDArrayFactory::create('c', {2,3}, {-0.1,-2,3, -4, + // -0.5, -6}); auto eps = NDArrayFactory::create('c', {2,3}, {-0.1, + // 0.2, -0.3, 0.4, -0.5, 0.6}); auto exp = + // NDArrayFactory::create('c', {2,3}, {1, 2, 1, 2, 2, 2}); + sd::ops::cell_contains op; + auto result = op.evaluate({&corners, &width, &point}, {}, {5}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_TRUE(result.at(0).e(0)); + // result.at(2)->printBuffer("Symmetrized3"); + // exp.printBuffer("EXPect symm3"); + // ASSERT_TRUE(exp[i]->equalsTo(result.at(i))); + // ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, adjustHue_1) { + NDArray input('c', {2, 2, 3}, + {0, 100, 56, 17, 220, 5, 150, 97, 230, 255, 2, 13}, + sd::DataType::FLOAT32); + NDArray factor = NDArrayFactory::create(0.5); + NDArray exp('c', {2, 2, 3}, + {100, 0, 44, 208, 5, 220, 177, 230, 97, 2, 255, 244}, + sd::DataType::FLOAT32); - NDArray input('c', {2,2,3}, {0,100,56, 17,220,5, 150,97,230, 255,2,13}, sd::DataType::FLOAT32); - NDArray factor = NDArrayFactory::create(0.5); - NDArray exp ('c', {2,2,3}, {100,0,44, 208,5,220, 177,230,97, 2,255,244}, sd::DataType::FLOAT32); + sd::ops::adjust_hue op; + auto results(op.evaluate({&input, &factor}, {}, {2})); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - sd::ops::adjust_hue op; - auto results (op.evaluate({&input, &factor}, {}, {2})); + auto result = results.at(0); + // result.printIndexedBuffer(); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); - // result.printIndexedBuffer(); - - ASSERT_TRUE(exp.isSameShape(result)); - ASSERT_TRUE(exp.equalsTo(result)); + ASSERT_TRUE(exp.isSameShape(result)); + ASSERT_TRUE(exp.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, adjustHue_2) { + NDArray input('c', {2, 2, 3}, + {0.f, 100.f / 255.f, 56.f / 255.f, 17.f / 255.f, 220.f / 255.f, + 5.f / 255.f, 150.f / 255.f, 97.f / 255.f, 230.f / 255.f, + 255.f / 255.f, 2.f / 255.f, 13.f / 255.f}, + sd::DataType::FLOAT32); + NDArray exp('c', {2, 2, 3}, + {4.f / 255.f, 100.f / 255.f, 0.f, 146.f / 255.f, 220.f / 255.f, + 5.f / 255.f, 97.f / 255.f, 123.8f / 255.f, 230.f / 255.f, + 255.f / 255.f, 2.f / 255.f, 164.8f / 255.f}, + sd::DataType::FLOAT32); - NDArray input('c', { 2,2,3 }, { 0.f,100.f / 255.f,56.f / 255.f, 17.f / 255.f,220.f / 255.f,5.f / 255.f, 150.f / 255.f,97.f / 255.f,230.f / 255.f, 255.f / 255.f,2.f / 255.f,13.f / 255.f }, sd::DataType::FLOAT32); - NDArray exp('c', { 2,2,3 }, { 4.f / 255.f,100.f / 255.f,0.f, 146.f / 255.f,220.f / 255.f,5.f / 255.f, 97.f / 255.f,123.8f / 255.f,230.f / 255.f, 255.f / 255.f,2.f / 255.f,164.8f / 255.f }, sd::DataType::FLOAT32); - + sd::ops::adjust_hue op; + auto results(op.evaluate({&input}, {0.9}, {2})); - sd::ops::adjust_hue op; - auto results(op.evaluate({&input}, {0.9}, {2})); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto result = results.at(0); - auto result = results.at(0); - - ASSERT_TRUE(exp.isSameShape(result)); - ASSERT_TRUE(exp.equalsTo(result)); + ASSERT_TRUE(exp.isSameShape(result)); + ASSERT_TRUE(exp.equalsTo(result)); } - //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, adjustHue_3) { + NDArray input('c', {2, 2, 3}, + {0, 100, 56, 17, 220, 5, 150, 97, 230, 255, 2, 13}, + sd::DataType::FLOAT32); + NDArray exp( + 'c', {2, 2, 3}, + {0., 84., 100., 5., 220., 122.0001, 229.8, 97., 230., 255., 142.8002, 2.}, + sd::DataType::FLOAT32); - NDArray input('c', {2,2,3}, {0,100,56, 17,220,5, 150,97,230, 255,2,13}, sd::DataType::FLOAT32); - NDArray exp ('c', {2,2,3}, {0.,84.,100., 5.,220.,122.0001, 229.8,97.,230., 255.,142.8002,2.}, sd::DataType::FLOAT32); - - sd::ops::adjust_hue op; - auto results(op.evaluate({&input}, {-0.9}, {2})); + sd::ops::adjust_hue op; + auto results(op.evaluate({&input}, {-0.9}, {2})); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(exp.isSameShape(result)); - ASSERT_TRUE(exp.equalsTo(result)); + ASSERT_TRUE(exp.isSameShape(result)); + ASSERT_TRUE(exp.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, adjustHue_4) { + NDArray input('c', {2, 3, 2}, + {0, 17, 100, 220, 56, 5, 150, 255, 97, 2, 230, 13}, + sd::DataType::FLOAT32); + NDArray exp('c', {2, 3, 2}, + {100, 208, 0, 5, 44, 220, 177, 2, 230, 255, 97, 244}, + sd::DataType::FLOAT32); - NDArray input('c', {2,3,2}, {0,17, 100,220, 56,5, 150,255, 97,2, 230,13}, sd::DataType::FLOAT32); - NDArray exp ('c', {2,3,2}, {100,208, 0,5, 44,220, 177,2, 230,255, 97,244}, sd::DataType::FLOAT32); + sd::ops::adjust_hue op; + auto results(op.evaluate({&input}, {0.5}, {1})); - sd::ops::adjust_hue op; - auto results(op.evaluate({&input}, {0.5}, {1})); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto result = results.at(0); - auto result = results.at(0); - - ASSERT_TRUE(exp.isSameShape(result)); - ASSERT_TRUE(exp.equalsTo(result)); + ASSERT_TRUE(exp.isSameShape(result)); + ASSERT_TRUE(exp.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, adjustHue_5) { + NDArray input('c', {3, 2, 2}, + {0, 17, 150, 255, 100, 220, 97, 2, 56, 5, 230, 13}, + sd::DataType::FLOAT32); + NDArray exp('c', {3, 2, 2}, + {100, 208, 177, 2, 0, 5, 230, 255, 44, 220, 97, 244}, + sd::DataType::FLOAT32); - NDArray input('c', {3,2,2}, {0,17, 150,255, 100,220, 97,2, 56,5, 230,13}, sd::DataType::FLOAT32); - NDArray exp ('c', {3,2,2}, {100,208, 177,2, 0,5, 230,255, 44,220, 97,244}, sd::DataType::FLOAT32); - - sd::ops::adjust_hue op; - auto results(op.evaluate({&input}, {0.5}, {0})); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::adjust_hue op; + auto results(op.evaluate({&input}, {0.5}, {0})); - auto result = results.at(0); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(exp.isSameShape(result)); - ASSERT_TRUE(exp.equalsTo(result)); + auto result = results.at(0); + ASSERT_TRUE(exp.isSameShape(result)); + ASSERT_TRUE(exp.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, adjustSaturation_1) { + NDArray input('c', {2, 2, 3}, + {0, 100, 56, 17, 220, 5, 150, 97, 230, 255, 2, 13}, + sd::DataType::FLOAT32); + NDArray factor = NDArrayFactory::create(0.5); + NDArray exp( + 'c', {2, 2, 3}, + {50, 100, 78, 118.5, 220, 112.5, 190, 163.5, 230, 255, 128.5, 134}, + sd::DataType::FLOAT32); - NDArray input('c', {2,2,3}, {0,100,56, 17,220,5, 150,97,230, 255,2,13}, sd::DataType::FLOAT32); - NDArray factor = NDArrayFactory::create(0.5); - NDArray exp ('c', {2,2,3}, {50,100,78, 118.5,220,112.5, 190,163.5,230, 255,128.5,134}, sd::DataType::FLOAT32); - - sd::ops::adjust_saturation op; - auto results = op.evaluate({&input, &factor}, {}, {2}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::adjust_saturation op; + auto results = op.evaluate({&input, &factor}, {}, {2}); - auto result = results.at(0); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(exp.isSameShape(result)); - ASSERT_TRUE(exp.equalsTo(result)); + auto result = results.at(0); + ASSERT_TRUE(exp.isSameShape(result)); + ASSERT_TRUE(exp.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, adjustSaturation_2) { + NDArray input('c', {2, 2, 3}, + {0, 100, 56, 17, 220, 5, 150, 97, 230, 255, 2, 13}, + sd::DataType::DOUBLE); + NDArray exp('c', {2, 2, 3}, + {0., 100., 56., 12.279087, 220., 0., 91.654228, 0., 230., 255., + 0., 11.087015}, + sd::DataType::DOUBLE); - NDArray input('c', {2,2,3}, {0,100,56, 17,220,5, 150,97,230, 255,2,13}, sd::DataType::DOUBLE); - NDArray exp ('c', {2,2,3}, {0.,100.,56., 12.279087,220.,0., 91.654228,0.,230., 255.,0.,11.087015}, sd::DataType::DOUBLE); + sd::ops::adjust_saturation op; + auto results = op.evaluate({&input}, {10}, {2}); - sd::ops::adjust_saturation op; - auto results = op.evaluate({&input}, {10}, {2}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); -// result.printIndexedBuffer("Result2"); -// exp.printIndexedBuffer("Expect2"); - - ASSERT_TRUE(exp.isSameShape(result)); - ASSERT_TRUE(exp.equalsTo(result)); + auto result = results.at(0); + // result.printIndexedBuffer("Result2"); + // exp.printIndexedBuffer("Expect2"); + ASSERT_TRUE(exp.isSameShape(result)); + ASSERT_TRUE(exp.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, adjustSaturation_3) { + NDArray input('c', {2, 2, 3}, + {0, 100, 56, 17, 220, 5, 150, 97, 230, 255, 2, 13}, + sd::DataType::FLOAT32); + NDArray exp( + 'c', {2, 2, 3}, + {100., 100., 100., 220., 220., 220., 230., 230., 230., 255., 255., 255.}, + sd::DataType::FLOAT32); - NDArray input('c', {2,2,3}, {0,100,56, 17,220,5, 150,97,230, 255,2,13}, sd::DataType::FLOAT32); - NDArray exp ('c', {2,2,3}, {100.,100.,100., 220.,220.,220., 230.,230.,230., 255., 255., 255.}, sd::DataType::FLOAT32); - - sd::ops::adjust_saturation op; - auto results = op.evaluate({&input}, {-10}, {2}); + sd::ops::adjust_saturation op; + auto results = op.evaluate({&input}, {-10}, {2}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto result = results.at(0); - - ASSERT_TRUE(exp.isSameShape(result)); - ASSERT_TRUE(exp.equalsTo(result)); + auto result = results.at(0); + ASSERT_TRUE(exp.isSameShape(result)); + ASSERT_TRUE(exp.equalsTo(result)); } - //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, adjustSaturation_4) { + NDArray input('c', {2, 3, 2}, + {0, 17, 100, 220, 56, 5, 150, 255, 97, 2, 230, 13}, + sd::DataType::FLOAT32); + NDArray exp( + 'c', {2, 3, 2}, + {50, 118.5, 100, 220, 78, 112.5, 190, 255, 163.5, 128.5, 230, 134}, + sd::DataType::FLOAT32); - NDArray input('c', {2,3,2}, {0,17, 100,220, 56,5, 150,255, 97,2, 230,13}, sd::DataType::FLOAT32); - NDArray exp ('c', {2,3,2}, {50,118.5, 100,220, 78,112.5, 190,255, 163.5,128.5, 230,134}, sd::DataType::FLOAT32); - - sd::ops::adjust_saturation op; - auto results = op.evaluate({&input}, {0.5}, {1}); + sd::ops::adjust_saturation op; + auto results = op.evaluate({&input}, {0.5}, {1}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto result = results.at(0); - // result.printIndexedBuffer(); - - ASSERT_TRUE(exp.isSameShape(result)); - ASSERT_TRUE(exp.equalsTo(result)); + auto result = results.at(0); + // result.printIndexedBuffer(); + ASSERT_TRUE(exp.isSameShape(result)); + ASSERT_TRUE(exp.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, adjustSaturation_5) { + NDArray input('c', {3, 2, 2}, + {0, 17, 150, 255, 100, 220, 97, 2, 56, 5, 230, 13}, + sd::DataType::FLOAT32); + NDArray exp( + 'c', {3, 2, 2}, + {50, 118.5, 190, 255, 100, 220, 163.5, 128.5, 78, 112.5, 230, 134}, + sd::DataType::FLOAT32); - NDArray input('c', {3,2,2}, {0,17, 150,255, 100,220, 97,2, 56,5, 230,13}, sd::DataType::FLOAT32); - NDArray exp ('c', {3,2,2}, {50,118.5, 190,255, 100,220, 163.5,128.5, 78,112.5, 230,134}, sd::DataType::FLOAT32); - - sd::ops::adjust_saturation op; - auto results = op.evaluate({&input}, {0.5}, {0}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::adjust_saturation op; + auto results = op.evaluate({&input}, {0.5}, {0}); - auto result = results.at(0); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(exp.isSameShape(result)); - ASSERT_TRUE(exp.equalsTo(result)); + auto result = results.at(0); + ASSERT_TRUE(exp.isSameShape(result)); + ASSERT_TRUE(exp.equalsTo(result)); } - TEST_F(DeclarableOpsTests13, shift_bits_1) { - auto x = NDArrayFactory::create('c', {5}); - auto y = NDArrayFactory::create(4); - auto e = x.ulike(); - x.assign(32); - e.assign(512); - - sd::ops::shift_bits op; - auto result = op.evaluate({&x, &y}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); + auto x = NDArrayFactory::create('c', {5}); + auto y = NDArrayFactory::create(4); + auto e = x.ulike(); + x.assign(32); + e.assign(512); - auto z = result.at(0); + sd::ops::shift_bits op; + auto result = op.evaluate({&x, &y}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_EQ(e, *z); + auto z = result.at(0); + ASSERT_EQ(e, z); } TEST_F(DeclarableOpsTests13, rshift_bits_1) { - auto x = NDArrayFactory::create('c', {5}); - auto y = NDArrayFactory::create(4); - auto e = x.ulike(); - x.assign(512); - e.assign(32); - - sd::ops::rshift_bits op; - auto result = op.evaluate({&x, &y}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); + auto x = NDArrayFactory::create('c', {5}); + auto y = NDArrayFactory::create(4); + auto e = x.ulike(); + x.assign(512); + e.assign(32); - auto z = result.at(0); + sd::ops::rshift_bits op; + auto result = op.evaluate({&x, &y}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_EQ(e, *z); + auto z = result.at(0); + ASSERT_EQ(e, z); } TEST_F(DeclarableOpsTests13, cyclic_shift_bits_1) { - auto x = NDArrayFactory::create('c', {5}); - auto y = NDArrayFactory::create(4); - auto e = x.ulike(); - x.assign(32); - e.assign(512); + auto x = NDArrayFactory::create('c', {5}); + auto y = NDArrayFactory::create(4); + auto e = x.ulike(); + x.assign(32); + e.assign(512); - sd::ops::cyclic_shift_bits op; - auto result = op.evaluate({&x, &y}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::cyclic_shift_bits op; + auto result = op.evaluate({&x, &y}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); - - ASSERT_EQ(e, *z); + auto z = result.at(0); + ASSERT_EQ(e, z); } TEST_F(DeclarableOpsTests13, cyclic_rshift_bits_1) { - auto x = NDArrayFactory::create('c', {5}); - auto y = NDArrayFactory::create(4); - auto e = x.ulike(); - x.assign(512); - e.assign(32); + auto x = NDArrayFactory::create('c', {5}); + auto y = NDArrayFactory::create(4); + auto e = x.ulike(); + x.assign(512); + e.assign(32); - sd::ops::cyclic_rshift_bits op; - auto result = op.evaluate({&x, &y}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::cyclic_rshift_bits op; + auto result = op.evaluate({&x, &y}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); - - ASSERT_EQ(e, *z); + auto z = result.at(0); + ASSERT_EQ(e, z); } TEST_F(DeclarableOpsTests13, shift_bits_2) { - auto x = NDArrayFactory::create('c', {5}); - auto y = NDArrayFactory::create('c', {5}); - auto e = x.ulike(); - x.assign(32); - y.assign(4); - e.assign(512); - - sd::ops::shift_bits op; - auto result = op.evaluate({&x, &y}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); + auto x = NDArrayFactory::create('c', {5}); + auto y = NDArrayFactory::create('c', {5}); + auto e = x.ulike(); + x.assign(32); + y.assign(4); + e.assign(512); - auto z = result.at(0); + sd::ops::shift_bits op; + auto result = op.evaluate({&x, &y}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_EQ(e, *z); + auto z = result.at(0); + ASSERT_EQ(e, z); } TEST_F(DeclarableOpsTests13, rshift_bits_2) { - auto x = NDArrayFactory::create('c', {5}); - auto y = NDArrayFactory::create('c', {5}); - auto e = x.ulike(); - x.assign(512); - y.assign(4); - e.assign(32); - - sd::ops::rshift_bits op; - auto result = op.evaluate({&x, &y}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); + auto x = NDArrayFactory::create('c', {5}); + auto y = NDArrayFactory::create('c', {5}); + auto e = x.ulike(); + x.assign(512); + y.assign(4); + e.assign(32); - ASSERT_EQ(e, *z); + sd::ops::rshift_bits op; + auto result = op.evaluate({&x, &y}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + ASSERT_EQ(e, z); } TEST_F(DeclarableOpsTests13, cyclic_shift_bits_2) { - auto x = NDArrayFactory::create('c', {5}); - auto y = NDArrayFactory::create('c', {5}); - auto e = x.ulike(); - x.assign(32); - y.assign(4); - e.assign(512); - - sd::ops::cyclic_shift_bits op; - auto result = op.evaluate({&x, &y}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); + auto x = NDArrayFactory::create('c', {5}); + auto y = NDArrayFactory::create('c', {5}); + auto e = x.ulike(); + x.assign(32); + y.assign(4); + e.assign(512); - auto z = result.at(0); + sd::ops::cyclic_shift_bits op; + auto result = op.evaluate({&x, &y}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_EQ(e, *z); + auto z = result.at(0); + ASSERT_EQ(e, z); } TEST_F(DeclarableOpsTests13, cyclic_rshift_bits_2) { - auto x = NDArrayFactory::create('c', {5}); - auto y = NDArrayFactory::create('c', {5}); - auto e = x.ulike(); - x.assign(512); - y.assign(4); - e.assign(32); + auto x = NDArrayFactory::create('c', {5}); + auto y = NDArrayFactory::create('c', {5}); + auto e = x.ulike(); + x.assign(512); + y.assign(4); + e.assign(32); - sd::ops::cyclic_rshift_bits op; - auto result = op.evaluate({&x, &y}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::cyclic_rshift_bits op; + auto result = op.evaluate({&x, &y}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); - - ASSERT_EQ(e, *z); + auto z = result.at(0); + ASSERT_EQ(e, z); } TEST_F(DeclarableOpsTests13, shift_bits_3) { - auto x = NDArrayFactory::create('c', {5, 5}); - auto y = NDArrayFactory::create('c', {1, 5}); - auto e = x.ulike(); - x.assign(32); - y.assign(4); - e.assign(512); + auto x = NDArrayFactory::create('c', {5, 5}); + auto y = NDArrayFactory::create('c', {1, 5}); + auto e = x.ulike(); + x.assign(32); + y.assign(4); + e.assign(512); - sd::ops::shift_bits op; - auto result = op.evaluate({&x, &y}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::shift_bits op; + auto result = op.evaluate({&x, &y}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); - - ASSERT_EQ(e, *z); + auto z = result.at(0); + ASSERT_EQ(e, z); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, space_to_batch_nd_1) { + NDArray x('c', {1, 2, 2, 2, 3}, sd::DataType::FLOAT32); + NDArray blockShape('c', {3}, {2, 2, 2}, + sd::DataType::INT32); // three spatial dimensions + NDArray paddings('c', {3, 2}, std::vector{0, 0, 0, 0, 0, 0}, + sd::DataType::INT32); - NDArray x('c', {1, 2, 2, 2, 3}, sd::DataType::FLOAT32); - NDArray blockShape('c', {3}, {2, 2, 2} , sd::DataType::INT32); // three spatial dimensions - NDArray paddings('c', {3, 2}, std::vector{0, 0, 0, 0, 0, 0} , sd::DataType::INT32); - - NDArray exp('c', {8, 1, 1, 1, 3}, sd::DataType::FLOAT32); + NDArray exp('c', {8, 1, 1, 1, 3}, sd::DataType::FLOAT32); - x.linspace(1); - exp.linspace(1); + x.linspace(1); + exp.linspace(1); - sd::ops::space_to_batch_nd op; - auto result = op.evaluate({&x, &blockShape, &paddings}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::space_to_batch_nd op; + auto result = op.evaluate({&x, &blockShape, &paddings}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, space_to_batch_nd_2) { - - NDArray x('c', {2, 2,4,3, 1}, sd::DataType::FLOAT32); - NDArray blockShape('c', {3}, {2, 2, 3} , sd::DataType::INT32); // three spatial dimensions - NDArray paddings('c', {3, 2}, {0,0, 0,2, 2,1} , sd::DataType::INT32); - - NDArray exp('c', {24, 1,3,2, 1}, { 0, 2, 0, 8, 0, 0, 0, 26, 0, 32, 0, 0, 0, 3, 0, 9, 0, 0, 0, 27, 0, 33, 0, 0, 1, - 0, 7, 0, 0, 0, 25, 0, 31, 0, 0, 0, 0, 5, 0, 11, 0, 0, 0, 29, 0, 35, 0, 0, 0, 6, - 0, 12, 0, 0, 0, 30, 0, 36, 0, 0, 4, 0, 10, 0, 0, 0, 28, 0, 34, 0, 0, 0, 0, 14, - 0, 20, 0, 0, 0, 38, 0, 44, 0, 0, 0, 15, 0, 21, 0, 0, 0, 39, 0, 45, 0, 0, 13, 0, - 19, 0, 0, 0, 37, 0, 43, 0, 0, 0, 0, 17, 0, 23, 0, 0, 0, 41, 0, 47, 0, 0, 0, 18, - 0, 24, 0, 0, 0, 42, 0, 48, 0, 0, 16, 0, 22, 0, 0, 0, 40, 0, 46, 0, 0, 0}, sd::DataType::FLOAT32); - x.linspace(1); - - sd::ops::space_to_batch_nd op; - auto result = op.evaluate({&x, &blockShape, &paddings}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - // z->printBuffer(); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - + NDArray x('c', {2, 2, 4, 3, 1}, sd::DataType::FLOAT32); + NDArray blockShape('c', {3}, {2, 2, 3}, + sd::DataType::INT32); // three spatial dimensions + NDArray paddings('c', {3, 2}, {0, 0, 0, 2, 2, 1}, sd::DataType::INT32); + + NDArray exp('c', {24, 1, 3, 2, 1}, + {0, 2, 0, 8, 0, 0, 0, 26, 0, 32, 0, 0, 0, 3, 0, 9, 0, 0, + 0, 27, 0, 33, 0, 0, 1, 0, 7, 0, 0, 0, 25, 0, 31, 0, 0, 0, + 0, 5, 0, 11, 0, 0, 0, 29, 0, 35, 0, 0, 0, 6, 0, 12, 0, 0, + 0, 30, 0, 36, 0, 0, 4, 0, 10, 0, 0, 0, 28, 0, 34, 0, 0, 0, + 0, 14, 0, 20, 0, 0, 0, 38, 0, 44, 0, 0, 0, 15, 0, 21, 0, 0, + 0, 39, 0, 45, 0, 0, 13, 0, 19, 0, 0, 0, 37, 0, 43, 0, 0, 0, + 0, 17, 0, 23, 0, 0, 0, 41, 0, 47, 0, 0, 0, 18, 0, 24, 0, 0, + 0, 42, 0, 48, 0, 0, 16, 0, 22, 0, 0, 0, 40, 0, 46, 0, 0, 0}, + sd::DataType::FLOAT32); + x.linspace(1); + + sd::ops::space_to_batch_nd op; + auto result = op.evaluate({&x, &blockShape, &paddings}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + // z->printBuffer(); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, space_to_batch_nd_3) { - - NDArray x('c', {2, 2,4,3, 1}, sd::DataType::FLOAT32); - NDArray blockShape('c', {3}, {2, 2, 3} , sd::DataType::INT32); // three spatial dimensions - NDArray paddings('c', {3, 2}, {1,1, 0,2, 2,1} , sd::DataType::INT32); - - NDArray exp('c', {24, 2,3,2, 1}, { 0, 0, 0, 0, 0, 0, 0, 14, 0, 20, 0, 0, 0, 0, 0, 0, 0, 0, 0, 38, 0, 44, 0, 0, 0, 0, 0, 0, 0, 0, 0, 15, - 0, 21, 0, 0, 0, 0, 0, 0, 0, 0, 0, 39, 0, 45, 0, 0, 0, 0, 0, 0, 0, 0, 13, 0, 19, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 37, 0, 43, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 17, 0, 23, 0, 0, 0, 0, 0, 0, 0, 0, 0, 41, 0, 47, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 18, 0, 24, 0, 0, 0, 0, 0, 0, 0, 0, 0, 42, 0, 48, 0, 0, 0, 0, 0, 0, 0, 0, 16, 0, - 22, 0, 0, 0, 0, 0, 0, 0, 0, 0, 40, 0, 46, 0, 0, 0, 0, 2, 0, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 26, 0, 32, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 27, 0, 33, 0, 0, 0, 0, 0, 0, 0, 0, 1, - 0, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 25, 0, 31, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 11, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 29, 0, 35, 0, 0, 0, 0, 0, 0, 0, 0, 0, 6, 0, 12, 0, 0, 0, 0, 0, 0, 0, 0, 0, 30, 0, 36, 0, 0, - 0, 0, 0, 0, 0, 0, 4, 0, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 28, 0, 34, 0, 0, 0, 0, 0, 0, 0, 0, 0}, sd::DataType::FLOAT32); - x.linspace(1); - - sd::ops::space_to_batch_nd op; - auto result = op.evaluate({&x, &blockShape, &paddings}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - // z->printBuffer(); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - + NDArray x('c', {2, 2, 4, 3, 1}, sd::DataType::FLOAT32); + NDArray blockShape('c', {3}, {2, 2, 3}, + sd::DataType::INT32); // three spatial dimensions + NDArray paddings('c', {3, 2}, {1, 1, 0, 2, 2, 1}, sd::DataType::INT32); + + NDArray exp( + 'c', {24, 2, 3, 2, 1}, + {0, 0, 0, 0, 0, 0, 0, 14, 0, 20, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 38, 0, 44, 0, 0, 0, 0, 0, 0, 0, 0, 0, 15, 0, 21, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 39, 0, 45, 0, 0, 0, 0, 0, 0, 0, 0, + 13, 0, 19, 0, 0, 0, 0, 0, 0, 0, 0, 0, 37, 0, 43, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 17, 0, 23, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 41, 0, 47, 0, 0, 0, 0, 0, 0, 0, 0, 0, 18, 0, 24, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 42, 0, 48, 0, 0, 0, 0, 0, 0, 0, 0, + 16, 0, 22, 0, 0, 0, 0, 0, 0, 0, 0, 0, 40, 0, 46, 0, 0, 0, + 0, 2, 0, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 26, 0, 32, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 3, 0, 9, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 27, 0, 33, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 7, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 25, 0, 31, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 5, 0, 11, 0, 0, 0, 0, 0, 0, 0, 0, 0, 29, 0, 35, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 6, 0, 12, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 30, 0, 36, 0, 0, 0, 0, 0, 0, 0, 0, 4, 0, 10, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 28, 0, 34, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + sd::DataType::FLOAT32); + x.linspace(1); + + sd::ops::space_to_batch_nd op; + auto result = op.evaluate({&x, &blockShape, &paddings}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + // z->printBuffer(); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, batch_to_space_nd_1) { + NDArray x('c', {8, 1, 1, 1, 3}, sd::DataType::FLOAT32); - NDArray x('c', {8, 1, 1, 1, 3}, sd::DataType::FLOAT32); + NDArray blockShape('c', {3}, {2., 2, 2}, + sd::DataType::INT32); // three spatial dimensions + NDArray crop('c', {3, 2}, {0., 0, 0, 0, 0, 0}, sd::DataType::INT32); - NDArray blockShape('c', {3}, {2., 2, 2} , sd::DataType::INT32); // three spatial dimensions - NDArray crop('c', {3, 2}, {0., 0, 0, 0, 0, 0} , sd::DataType::INT32); + NDArray exp('c', {1, 2, 2, 2, 3}, sd::DataType::FLOAT32); - NDArray exp('c', {1, 2, 2, 2, 3}, sd::DataType::FLOAT32); + x.linspace(1); + exp.linspace(1); - x.linspace(1); - exp.linspace(1); + sd::ops::batch_to_space_nd op; + auto result = op.evaluate({&x, &blockShape, &crop}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); - sd::ops::batch_to_space_nd op; - auto result = op.evaluate({&x, &blockShape, &crop}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, batch_to_space_nd_2) { + NDArray x('c', {24, 1, 3, 2, 1}, sd::DataType::FLOAT32); + NDArray blockShape('c', {3}, {2, 2, 3}, + sd::DataType::INT32); // three spatial dimensions + NDArray crop('c', {3, 2}, {0, 0, 0, 2, 2, 1}, sd::DataType::INT32); - NDArray x('c', {24, 1,3,2, 1}, sd::DataType::FLOAT32); - NDArray blockShape('c', {3}, {2, 2, 3} , sd::DataType::INT32); // three spatial dimensions - NDArray crop('c', {3, 2}, {0,0, 0,2, 2,1} , sd::DataType::INT32); + NDArray exp('c', {2, 2, 4, 3, 1}, + {25, 2, 14, 61, 38, 50, 27, 4, 16, 63, 40, 52, + 97, 74, 86, 133, 110, 122, 99, 76, 88, 135, 112, 124, + 31, 8, 20, 67, 44, 56, 33, 10, 22, 69, 46, 58, + 103, 80, 92, 139, 116, 128, 105, 82, 94, 141, 118, 130}, + sd::DataType::FLOAT32); + x.linspace(1); - NDArray exp('c', {2, 2,4,3, 1}, {25, 2, 14, 61, 38, 50, 27, 4, 16, 63, 40, 52, 97, 74, 86, 133, 110, 122, 99, 76, 88, 135, 112, 124, - 31, 8, 20, 67, 44, 56, 33, 10, 22, 69, 46, 58, 103, 80, 92, 139, 116, 128, 105, 82, 94, 141, 118, 130}, sd::DataType::FLOAT32); - x.linspace(1); - - sd::ops::batch_to_space_nd op; - auto result = op.evaluate({&x, &blockShape, &crop}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - // z->printBuffer(); + sd::ops::batch_to_space_nd op; + auto result = op.evaluate({&x, &blockShape, &crop}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + // z->printBuffer(); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, batch_to_space_nd_3) { + NDArray x('c', {24, 2, 3, 2, 1}, sd::DataType::FLOAT32); + NDArray blockShape('c', {3}, {2, 2, 3}, + sd::DataType::INT32); // three spatial dimensions + NDArray crop('c', {3, 2}, {1, 1, 0, 2, 2, 1}, sd::DataType::INT32); - NDArray x('c', {24, 2,3,2, 1}, sd::DataType::FLOAT32); - NDArray blockShape('c', {3}, {2, 2, 3} , sd::DataType::INT32); // three spatial dimensions - NDArray crop('c', {3, 2}, {1,1, 0,2, 2,1} , sd::DataType::INT32); + NDArray exp('c', {2, 2, 4, 3, 1}, + {193, 146, 170, 265, 218, 242, 195, 148, 172, 267, 220, 244, + 55, 8, 32, 127, 80, 104, 57, 10, 34, 129, 82, 106, + 205, 158, 182, 277, 230, 254, 207, 160, 184, 279, 232, 256, + 67, 20, 44, 139, 92, 116, 69, 22, 46, 141, 94, 118}, + sd::DataType::FLOAT32); + x.linspace(1); - NDArray exp('c', {2, 2,4,3, 1}, {193, 146, 170, 265, 218, 242, 195, 148, 172, 267, 220, 244, 55, 8, 32, 127, 80, 104, 57, 10, 34, 129, 82, - 106, 205, 158, 182, 277, 230, 254, 207, 160, 184, 279, 232, 256, 67, 20, 44, 139, 92, 116, 69, 22, 46, 141, 94, 118}, sd::DataType::FLOAT32); - x.linspace(1); + sd::ops::batch_to_space_nd op; + auto result = op.evaluate({&x, &blockShape, &crop}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); - sd::ops::batch_to_space_nd op; - auto result = op.evaluate({&x, &blockShape, &crop}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - // z->printBuffer(); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + // z->printBuffer(); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, mergemax_1) { + NDArray x1('c', {5, 5}, sd::DataType::FLOAT32); + NDArray x2('c', {5, 5}, sd::DataType::FLOAT32); + NDArray x3('c', {5, 5}, sd::DataType::FLOAT32); + NDArray e('c', {5, 5}, sd::DataType::FLOAT32); + x1.assign(3); + x2.assign(1); + x3.assign(2); + e.assign(3); - NDArray x1('c', {5, 5}, sd::DataType::FLOAT32); - NDArray x2('c', {5, 5}, sd::DataType::FLOAT32); - NDArray x3('c', {5, 5}, sd::DataType::FLOAT32); - NDArray e('c', {5, 5}, sd::DataType::FLOAT32); - x1.assign(3); - x2.assign(1); - x3.assign(2); - e.assign(3); - - - sd::ops::mergemax op; - auto result = op.evaluate({&x1, &x2, &x3}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::mergemax op; + auto result = op.evaluate({&x1, &x2, &x3}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); - // z->printBuffer(); - - ASSERT_TRUE(e.isSameShape(z)); - ASSERT_TRUE(e.equalsTo(z)); + auto z = result.at(0); + // z->printBuffer(); + ASSERT_TRUE(e.isSameShape(z)); + ASSERT_TRUE(e.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, mergemax_2) { + NDArray x1('c', {1, 3}, {0., 1, 2}, sd::DataType::FLOAT32); + NDArray x2('c', {1, 1}, std::vector{1.}, sd::DataType::FLOAT32); + NDArray out('c', {1, 3}, {-1., -1, -1}, sd::DataType::FLOAT32); - NDArray x1('c', {1, 3}, {0., 1, 2}, sd::DataType::FLOAT32); - NDArray x2('c', {1, 1}, std::vector{1.}, sd::DataType::FLOAT32); - NDArray out('c', {1, 3}, {-1., -1, -1}, sd::DataType::FLOAT32); - - sd::ops::mergemax op; - auto status = op.execute({&x1, &x2}, {&out}, {}, {}, {}); + sd::ops::mergemax op; + auto status = op.execute({&x1, &x2}, {&out}, {}, {}, {}); - ASSERT_EQ(20, status); + ASSERT_EQ(20, status); } ///////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, mergemax_bp_1) { + NDArray x1('c', {5, 5}, sd::DataType::FLOAT32); + NDArray x2('c', {5, 5}, sd::DataType::FLOAT32); + NDArray x3('c', {5, 5}, sd::DataType::FLOAT32); + NDArray grad('c', {5, 5}, sd::DataType::FLOAT32); - NDArray x1('c', { 5, 5 }, sd::DataType::FLOAT32); - NDArray x2('c', { 5, 5 }, sd::DataType::FLOAT32); - NDArray x3('c', { 5, 5 }, sd::DataType::FLOAT32); - NDArray grad('c', { 5, 5 }, sd::DataType::FLOAT32); + x1.assign(3); + x2.assign(1); + x3.assign(2); + grad.linspace(.1, .1); - x1.assign(3); - x2.assign(1); - x3.assign(2); - grad.linspace(.1, .1); + sd::ops::mergemax_bp op; + auto result = op.evaluate({&x1, &x2, &x3, &grad}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_EQ(3, result.size()); + auto z = result.at(0); - sd::ops::mergemax_bp op; - auto result = op.evaluate({ &x1, &x2, &x3, &grad }, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); - ASSERT_EQ(3, result.size()); - - auto z = result.at(0); - - ASSERT_TRUE(grad.isSameShape(z)); - ASSERT_TRUE(grad.equalsTo(z)); - + ASSERT_TRUE(grad.isSameShape(z)); + ASSERT_TRUE(grad.equalsTo(z)); } ///////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, mergemax_bp_2) { - - NDArray x1('c', { 2, 5 }, { 1,2,3,4,5,4,3,2,1,0 }, sd::DataType::FLOAT32); - NDArray x2('c', { 2, 5 }, { 0,1,2,3,4,5,6,7,8,9 }, sd::DataType::FLOAT32); - NDArray x3('c', { 2, 5 }, { 0,1,1,2,3,4,7,5,8,10 }, sd::DataType::FLOAT32); - NDArray grad('c', { 2, 5 }, sd::DataType::FLOAT32); - - grad.linspace(.1, .1); - - NDArray exp1('c', { 2, 5 }, { 0.1, 0.2, 0.3, 0.4, 0.5, 0.0, 0.0, 0.0, 0.0, 0.0 }, sd::DataType::FLOAT32); - NDArray exp2('c', { 2, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0, 0.6, 0.0, 0.8, 0.9, 0.0 }, sd::DataType::FLOAT32); - NDArray exp3('c', { 2, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.7, 0.0, 0.0, 1.0 }, sd::DataType::FLOAT32); - - sd::ops::mergemax_bp op; - auto result = op.evaluate({ &x1, &x2, &x3, &grad }, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); - ASSERT_EQ(3, result.size()); - - auto z1 = result.at(0); - auto z2 = result.at(1); - auto z3 = result.at(2); - - ASSERT_TRUE(exp1.isSameShape(z1)); - ASSERT_TRUE(exp1.equalsTo(z1)); - ASSERT_TRUE(exp2.isSameShape(z2)); - ASSERT_TRUE(exp2.equalsTo(z2)); - ASSERT_TRUE(exp3.isSameShape(z3)); - ASSERT_TRUE(exp3.equalsTo(z3)); - + NDArray x1('c', {2, 5}, {1, 2, 3, 4, 5, 4, 3, 2, 1, 0}, + sd::DataType::FLOAT32); + NDArray x2('c', {2, 5}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}, + sd::DataType::FLOAT32); + NDArray x3('c', {2, 5}, {0, 1, 1, 2, 3, 4, 7, 5, 8, 10}, + sd::DataType::FLOAT32); + NDArray grad('c', {2, 5}, sd::DataType::FLOAT32); + + grad.linspace(.1, .1); + + NDArray exp1('c', {2, 5}, {0.1, 0.2, 0.3, 0.4, 0.5, 0.0, 0.0, 0.0, 0.0, 0.0}, + sd::DataType::FLOAT32); + NDArray exp2('c', {2, 5}, {0.0, 0.0, 0.0, 0.0, 0.0, 0.6, 0.0, 0.8, 0.9, 0.0}, + sd::DataType::FLOAT32); + NDArray exp3('c', {2, 5}, {0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.7, 0.0, 0.0, 1.0}, + sd::DataType::FLOAT32); + + sd::ops::mergemax_bp op; + auto result = op.evaluate({&x1, &x2, &x3, &grad}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_EQ(3, result.size()); + + auto z1 = result.at(0); + auto z2 = result.at(1); + auto z3 = result.at(2); + + ASSERT_TRUE(exp1.isSameShape(z1)); + ASSERT_TRUE(exp1.equalsTo(z1)); + ASSERT_TRUE(exp2.isSameShape(z2)); + ASSERT_TRUE(exp2.equalsTo(z2)); + ASSERT_TRUE(exp3.isSameShape(z3)); + ASSERT_TRUE(exp3.equalsTo(z3)); } ///////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, mergemax_bp_3) { - - NDArray x1C('c', { 2, 5 }, { 1,2,3,4,5,4,3,2,1,0 }, sd::DataType::FLOAT32); - NDArray x2C('c', { 2, 5 }, { 0,1,2,3,4,5,6,7,8,9 }, sd::DataType::FLOAT32); - NDArray x3C('c', { 2, 5 }, { 0,1,1,2,3,4,7,5,8,10 }, sd::DataType::FLOAT32); - NDArray grad('c', { 2, 5 }, sd::DataType::FLOAT32); - - grad.linspace(.1, .1); - - NDArray x1('f', { 2, 5 }, sd::DataType::FLOAT32); - NDArray x2('f', { 2, 5 }, sd::DataType::FLOAT32); - NDArray x3('f', { 2, 5 }, sd::DataType::FLOAT32); - - NDArray exp1C('c', { 2, 5 }, { 0.1, 0.2, 0.3, 0.4, 0.5, 0.0, 0.0, 0.0, 0.0, 0.0 }, sd::DataType::FLOAT32); - NDArray exp2C('c', { 2, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0, 0.6, 0.0, 0.8, 0.9, 0.0 }, sd::DataType::FLOAT32); - NDArray exp3C('c', { 2, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.7, 0.0, 0.0, 1.0 }, sd::DataType::FLOAT32); - - NDArray exp1('f', { 2, 5 }, sd::DataType::FLOAT32); - NDArray exp2('f', { 2, 5 }, sd::DataType::FLOAT32); - NDArray exp3('f', { 2, 5 }, sd::DataType::FLOAT32); - - x1.assign(x1C); - x2.assign(x2C); - x3.assign(x3C); - - exp1.assign(exp1C); - exp2.assign(exp2C); - exp3.assign(exp3C); - - sd::ops::mergemax_bp op; - auto result = op.evaluate({ &x1, &x2, &x3, &grad }, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); - ASSERT_EQ(3, result.size()); - - auto z1 = result.at(0); - auto z2 = result.at(1); - auto z3 = result.at(2); - - ASSERT_TRUE(exp1.isSameShape(z1)); - ASSERT_TRUE(exp1.equalsTo(z1)); - ASSERT_TRUE(exp2.isSameShape(z2)); - ASSERT_TRUE(exp2.equalsTo(z2)); - ASSERT_TRUE(exp3.isSameShape(z3)); - ASSERT_TRUE(exp3.equalsTo(z3)); - + NDArray x1C('c', {2, 5}, {1, 2, 3, 4, 5, 4, 3, 2, 1, 0}, + sd::DataType::FLOAT32); + NDArray x2C('c', {2, 5}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}, + sd::DataType::FLOAT32); + NDArray x3C('c', {2, 5}, {0, 1, 1, 2, 3, 4, 7, 5, 8, 10}, + sd::DataType::FLOAT32); + NDArray grad('c', {2, 5}, sd::DataType::FLOAT32); + + grad.linspace(.1, .1); + + NDArray x1('f', {2, 5}, sd::DataType::FLOAT32); + NDArray x2('f', {2, 5}, sd::DataType::FLOAT32); + NDArray x3('f', {2, 5}, sd::DataType::FLOAT32); + + NDArray exp1C('c', {2, 5}, {0.1, 0.2, 0.3, 0.4, 0.5, 0.0, 0.0, 0.0, 0.0, 0.0}, + sd::DataType::FLOAT32); + NDArray exp2C('c', {2, 5}, {0.0, 0.0, 0.0, 0.0, 0.0, 0.6, 0.0, 0.8, 0.9, 0.0}, + sd::DataType::FLOAT32); + NDArray exp3C('c', {2, 5}, {0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.7, 0.0, 0.0, 1.0}, + sd::DataType::FLOAT32); + + NDArray exp1('f', {2, 5}, sd::DataType::FLOAT32); + NDArray exp2('f', {2, 5}, sd::DataType::FLOAT32); + NDArray exp3('f', {2, 5}, sd::DataType::FLOAT32); + + x1.assign(x1C); + x2.assign(x2C); + x3.assign(x3C); + + exp1.assign(exp1C); + exp2.assign(exp2C); + exp3.assign(exp3C); + + sd::ops::mergemax_bp op; + auto result = op.evaluate({&x1, &x2, &x3, &grad}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_EQ(3, result.size()); + + auto z1 = result.at(0); + auto z2 = result.at(1); + auto z3 = result.at(2); + + ASSERT_TRUE(exp1.isSameShape(z1)); + ASSERT_TRUE(exp1.equalsTo(z1)); + ASSERT_TRUE(exp2.isSameShape(z2)); + ASSERT_TRUE(exp2.equalsTo(z2)); + ASSERT_TRUE(exp3.isSameShape(z3)); + ASSERT_TRUE(exp3.equalsTo(z3)); } ///////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, mergeadd_bp_1) { - - NDArray x1('c', { 5, 5 }, sd::DataType::FLOAT32); - NDArray x2('c', { 5, 5 }, sd::DataType::FLOAT32); - NDArray x3('c', { 5, 5 }, sd::DataType::FLOAT32); - NDArray grad('c', { 5, 5 }, sd::DataType::FLOAT32); - - x1.assign(3); - x2.assign(1); - x3.assign(2); - grad.linspace(.1, .1); - - sd::ops::mergeadd_bp op; - auto result = op.evaluate({ &x1, &x2, &x3, &grad }, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); - ASSERT_EQ(3, result.size()); - - for (int i = 0; i < 3; i++) { - auto z = result.at(0); - ASSERT_TRUE(grad.isSameShape(z)); - ASSERT_TRUE(grad.equalsTo(z)); - } + NDArray x1('c', {5, 5}, sd::DataType::FLOAT32); + NDArray x2('c', {5, 5}, sd::DataType::FLOAT32); + NDArray x3('c', {5, 5}, sd::DataType::FLOAT32); + NDArray grad('c', {5, 5}, sd::DataType::FLOAT32); + + x1.assign(3); + x2.assign(1); + x3.assign(2); + grad.linspace(.1, .1); + + sd::ops::mergeadd_bp op; + auto result = op.evaluate({&x1, &x2, &x3, &grad}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_EQ(3, result.size()); + + for (int i = 0; i < 3; i++) { + auto z = result.at(0); + ASSERT_TRUE(grad.isSameShape(z)); + ASSERT_TRUE(grad.equalsTo(z)); + } } ///////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, mergeavg_bp_1) { + NDArray x1('c', {5, 5}, sd::DataType::FLOAT32); + NDArray x2('c', {5, 5}, sd::DataType::FLOAT32); + NDArray x3('c', {5, 5}, sd::DataType::FLOAT32); + NDArray grad('c', {5, 5}, sd::DataType::FLOAT32); - NDArray x1('c', { 5, 5 }, sd::DataType::FLOAT32); - NDArray x2('c', { 5, 5 }, sd::DataType::FLOAT32); - NDArray x3('c', { 5, 5 }, sd::DataType::FLOAT32); - NDArray grad('c', { 5, 5 }, sd::DataType::FLOAT32); - - x1.assign(3); - x2.assign(1); - x3.assign(2); - grad.linspace(.1, .1); - - sd::ops::mergeavg_bp op; - auto result = op.evaluate({ &x1, &x2, &x3, &grad }, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); - ASSERT_EQ(3, result.size()); + x1.assign(3); + x2.assign(1); + x3.assign(2); + grad.linspace(.1, .1); - grad.applyScalar(sd::scalar::Divide, 3, grad); + sd::ops::mergeavg_bp op; + auto result = op.evaluate({&x1, &x2, &x3, &grad}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_EQ(3, result.size()); - for (int i = 0; i < 3; i++) { - auto z = result.at(i); - ASSERT_TRUE(grad.isSameShape(z)); - ASSERT_TRUE(grad.equalsTo(z)); - } + grad.applyScalar(sd::scalar::Divide, 3, grad); + for (int i = 0; i < 3; i++) { + auto z = result.at(i); + ASSERT_TRUE(grad.isSameShape(z)); + ASSERT_TRUE(grad.equalsTo(z)); + } } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, lstmLayer_1) { - - const int sL = 5; - const int bS = 3; - const int nIn = 3; - const int nOut = 3; - - // input arguments - - const int dataFormat = 0; // [sL,bS,nIn] - const int directionMode = 0; // forward - const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates - const int cellAct = 0; // tanh activation for cell state - const int outAct = 0; // tanh activation for output - - const bool hasBiases = true; // biases array is provided - const bool hasSeqLen = false; // seqLen array is not provided - const auto hasInitH = true; // initial output is provided - const auto hasInitC = true; // initial cell state is provided - const auto hasPH = false; // peephole connections are absent - const auto retFullSeq = true; // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut] - const auto retLastH = true; // do not return output at last time step - const auto retLastC = true; // return cells state at last time step - - const double cellClip = 0; // do not apply clipping - - NDArray x('c', {sL, bS, nIn}, sd::DataType::FLOAT32); - NDArray Wx('c', {nIn, 4*nOut}, sd::DataType::FLOAT32); - NDArray Wr('c', {nOut, 4*nOut}, sd::DataType::FLOAT32); - NDArray b('c', {4*nOut}, sd::DataType::FLOAT32); - NDArray hI('c', {bS, nOut}, sd::DataType::FLOAT32); - NDArray cI('c', {bS, nOut}, sd::DataType::FLOAT32); - - x.linspace(0.5, 0.5); - Wx = 0.003; - Wr = 0.006; - b = 0.5; - hI = 1.; - cI = 2.; - - std::vector tArgs = {cellClip}; - std::vector iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; - std::vector bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; - - auto expH = NDArrayFactory::create('c', {sL, bS, nOut}, {0.57574f, 0.57574f, 0.57574f, 0.58006f, 0.58006f, 0.58006f, 0.58434f, 0.58434f, 0.58434f, - 0.55114f, 0.55114f, 0.55114f, 0.55732f, 0.55732f, 0.55732f, 0.56338f, 0.56338f, 0.56338f, - 0.53763f, 0.53763f, 0.53763f, 0.54534f, 0.54534f, 0.54534f, 0.55287f, 0.55287f, 0.55287f, - 0.53626f, 0.53626f, 0.53626f, 0.54487f, 0.54487f, 0.54487f, 0.55327f, 0.55327f, 0.55327f, - 0.54484f, 0.54484f, 0.54484f, 0.55379f, 0.55379f, 0.55379f, 0.5625f, 0.5625f, 0.5625f}); - - auto expClast = NDArrayFactory::create('c', {bS, nOut}, {1.1589154f, 1.1589154f, 1.1589154f, 1.1892855f, 1.1892855f, 1.1892855f, 1.219861f, 1.219861f, 1.219861f}); - - sd::ops::lstmLayer op; - auto results = op.evaluate({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto *h = results.at(0); - auto *cL = results.at(2); - - ASSERT_TRUE(expH.isSameShape(h)); - ASSERT_TRUE(expH.equalsTo(h)); - - ASSERT_TRUE(expClast.isSameShape(cL)); - ASSERT_TRUE(expClast.equalsTo(cL)); - + const int sL = 5; + const int bS = 3; + const int nIn = 3; + const int nOut = 3; + + // input arguments + + const int dataFormat = 0; // [sL,bS,nIn] + const int directionMode = 0; // forward + const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates + const int cellAct = 0; // tanh activation for cell state + const int outAct = 0; // tanh activation for output + + const bool hasBiases = true; // biases array is provided + const bool hasSeqLen = false; // seqLen array is not provided + const auto hasInitH = true; // initial output is provided + const auto hasInitC = true; // initial cell state is provided + const auto hasPH = false; // peephole connections are absent + const auto retFullSeq = true; // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut] + const auto retLastH = true; // do not return output at last time step + const auto retLastC = true; // return cells state at last time step + + const double cellClip = 0; // do not apply clipping + + NDArray x('c', {sL, bS, nIn}, sd::DataType::FLOAT32); + NDArray Wx('c', {nIn, 4 * nOut}, sd::DataType::FLOAT32); + NDArray Wr('c', {nOut, 4 * nOut}, sd::DataType::FLOAT32); + NDArray b('c', {4 * nOut}, sd::DataType::FLOAT32); + NDArray hI('c', {bS, nOut}, sd::DataType::FLOAT32); + NDArray cI('c', {bS, nOut}, sd::DataType::FLOAT32); + + x.linspace(0.5, 0.5); + Wx = 0.003; + Wr = 0.006; + b = 0.5; + hI = 1.; + cI = 2.; + + std::vector tArgs = {cellClip}; + std::vector iArgs = {dataFormat, directionMode, gateAct, cellAct, + outAct}; + std::vector bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, + hasPH, retFullSeq, retLastH, retLastC}; + + auto expH = NDArrayFactory::create( + 'c', {sL, bS, nOut}, + {0.57574f, 0.57574f, 0.57574f, 0.58006f, 0.58006f, 0.58006f, 0.58434f, + 0.58434f, 0.58434f, 0.55114f, 0.55114f, 0.55114f, 0.55732f, 0.55732f, + 0.55732f, 0.56338f, 0.56338f, 0.56338f, 0.53763f, 0.53763f, 0.53763f, + 0.54534f, 0.54534f, 0.54534f, 0.55287f, 0.55287f, 0.55287f, 0.53626f, + 0.53626f, 0.53626f, 0.54487f, 0.54487f, 0.54487f, 0.55327f, 0.55327f, + 0.55327f, 0.54484f, 0.54484f, 0.54484f, 0.55379f, 0.55379f, 0.55379f, + 0.5625f, 0.5625f, 0.5625f}); + + auto expClast = NDArrayFactory::create( + 'c', {bS, nOut}, + {1.1589154f, 1.1589154f, 1.1589154f, 1.1892855f, 1.1892855f, 1.1892855f, + 1.219861f, 1.219861f, 1.219861f}); + + sd::ops::lstmLayer op; + auto results = op.evaluate({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto h = results.at(0); + auto cL = results.at(2); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + + ASSERT_TRUE(expClast.isSameShape(cL)); + ASSERT_TRUE(expClast.equalsTo(cL)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, lstmLayer_2) { - - const int sL = 5; - const int bS = 3; - const int nIn = 3; - const int nOut = 3; - - // input arguments - - const int dataFormat = 1; // [bS,sL,nIn] - const int directionMode = 0; // forward - const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates - const int cellAct = 0; // tanh activation for cell state - const int outAct = 0; // tanh activation for output - - const bool hasBiases = true; // biases array is provided - const bool hasSeqLen = false; // seqLen array is not provided - const auto hasInitH = true; // initial output is provided - const auto hasInitC = true; // initial cell state is provided - const auto hasPH = false; // peephole connections are absent - const auto retFullSeq = true; // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut] - const auto retLastH = true; // return output at last time step - const auto retLastC = true; // return cells state at last time step - - const double cellClip = 0; // do not apply clipping - - NDArray x('c', {bS, sL, nIn}, sd::DataType::FLOAT32); - NDArray Wx('c', {nIn, 4*nOut}, sd::DataType::FLOAT32); - NDArray Wr('c', {nOut, 4*nOut}, sd::DataType::FLOAT32); - NDArray b('c', {4*nOut}, sd::DataType::FLOAT32); - NDArray hI('c', {bS, nOut}, sd::DataType::FLOAT32); - NDArray cI('c', {bS, nOut}, sd::DataType::FLOAT32); - - x.linspace(0.5, 0.5); - Wx = 0.003; - Wr = 0.006; - b = 0.5; - hI = 1.; - cI = 2.; - - std::vector tArgs = {cellClip}; - std::vector iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; - std::vector bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; - - auto expH = NDArrayFactory::create('c', {bS, sL, nOut}, {0.575735f, 0.575735f, 0.575735f, 0.541562f, 0.541562f, 0.541562f, 0.514003f, 0.514003f, 0.514003f, 0.495597f, 0.495597f, 0.495597f, 0.485999f, 0.485999f, 0.485999f, - 0.596965f, 0.596965f, 0.596965f, 0.571978f, 0.571978f, 0.571978f, 0.552888f, 0.552888f, 0.552888f, 0.540606f, 0.540606f, 0.540606f, 0.534764f, 0.534764f, 0.534764f, - 0.61725f, 0.61725f, 0.61725f, 0.599828f, 0.599828f, 0.599828f, 0.587627f, 0.587627f, 0.587627f, 0.580408f, 0.580408f, 0.580408f, 0.577735f, 0.577735f, 0.577735f}); - - auto expClast = NDArrayFactory::create('c', {bS, nOut}, {0.996965f, 0.996965f, 0.996965f, 1.146756f, 1.146756f, 1.146756f, 1.301922f, 1.301922f, 1.301922f}); - - sd::ops::lstmLayer op; - auto results = op.evaluate({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto *h = results.at(0); - auto *cL = results.at(2); - - ASSERT_TRUE(expH.isSameShape(h)); - ASSERT_TRUE(expH.equalsTo(h)); - - ASSERT_TRUE(expClast.isSameShape(cL)); - ASSERT_TRUE(expClast.equalsTo(cL)); - + const int sL = 5; + const int bS = 3; + const int nIn = 3; + const int nOut = 3; + + // input arguments + + const int dataFormat = 1; // [bS,sL,nIn] + const int directionMode = 0; // forward + const int gateAct = + 2; // sigmoid activation for input (i), forget (f) and output (o) gates + const int cellAct = 0; // tanh activation for cell state + const int outAct = 0; // tanh activation for output + + const bool hasBiases = true; // biases array is provided + const bool hasSeqLen = false; // seqLen array is not provided + const auto hasInitH = true; // initial output is provided + const auto hasInitC = true; // initial cell state is provided + const auto hasPH = false; // peephole connections are absent + const auto retFullSeq = + true; // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut] + const auto retLastH = true; // return output at last time step + const auto retLastC = true; // return cells state at last time step + + const double cellClip = 0; // do not apply clipping + + NDArray x('c', {bS, sL, nIn}, sd::DataType::FLOAT32); + NDArray Wx('c', {nIn, 4 * nOut}, sd::DataType::FLOAT32); + NDArray Wr('c', {nOut, 4 * nOut}, sd::DataType::FLOAT32); + NDArray b('c', {4 * nOut}, sd::DataType::FLOAT32); + NDArray hI('c', {bS, nOut}, sd::DataType::FLOAT32); + NDArray cI('c', {bS, nOut}, sd::DataType::FLOAT32); + + x.linspace(0.5, 0.5); + Wx = 0.003; + Wr = 0.006; + b = 0.5; + hI = 1.; + cI = 2.; + + std::vector tArgs = {cellClip}; + std::vector iArgs = {dataFormat, directionMode, gateAct, cellAct, + outAct}; + std::vector bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, + hasPH, retFullSeq, retLastH, retLastC}; + + auto expH = NDArrayFactory::create( + 'c', {bS, sL, nOut}, + {0.575735f, 0.575735f, 0.575735f, 0.541562f, 0.541562f, 0.541562f, + 0.514003f, 0.514003f, 0.514003f, 0.495597f, 0.495597f, 0.495597f, + 0.485999f, 0.485999f, 0.485999f, 0.596965f, 0.596965f, 0.596965f, + 0.571978f, 0.571978f, 0.571978f, 0.552888f, 0.552888f, 0.552888f, + 0.540606f, 0.540606f, 0.540606f, 0.534764f, 0.534764f, 0.534764f, + 0.61725f, 0.61725f, 0.61725f, 0.599828f, 0.599828f, 0.599828f, + 0.587627f, 0.587627f, 0.587627f, 0.580408f, 0.580408f, 0.580408f, + 0.577735f, 0.577735f, 0.577735f}); + + auto expClast = NDArrayFactory::create( + 'c', {bS, nOut}, + {0.996965f, 0.996965f, 0.996965f, 1.146756f, 1.146756f, 1.146756f, + 1.301922f, 1.301922f, 1.301922f}); + + sd::ops::lstmLayer op; + auto results = op.evaluate({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto h = results.at(0); + auto cL = results.at(2); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + + ASSERT_TRUE(expClast.isSameShape(cL)); + ASSERT_TRUE(expClast.equalsTo(cL)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, lstmLayer_3) { - - const int sL = 5; - const int bS = 2; - const int nIn = 4; - const int nOut = 3; - - // input arguments - - const int dataFormat = 0; // [sL,bS,nIn] - const int directionMode = 1; // backward - const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates - const int cellAct = 0; // tanh activation for cell state - const int outAct = 0; // tanh activation for output - - const bool hasBiases = true; // biases array is provided - const bool hasSeqLen = false; // seqLen array is not provided - const auto hasInitH = true; // initial output is provided - const auto hasInitC = true; // initial cell state is provided - const auto hasPH = false; // peephole connections are absent - const auto retFullSeq = true; // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut] - const auto retLastH = true; // do not return output at last time step - const auto retLastC = true; // return cells state at last time step - - const double cellClip = 0; // do not apply clipping - - NDArray x('c', {sL,bS, nIn}, sd::DataType::FLOAT32); - NDArray Wx('c', {nIn, 4*nOut}, sd::DataType::FLOAT32); - NDArray Wr('c', {nOut, 4*nOut}, sd::DataType::FLOAT32); - NDArray b('c', {4*nOut}, sd::DataType::FLOAT32); - NDArray hI('c', {bS, nOut}, sd::DataType::FLOAT32); - NDArray cI('c', {bS, nOut}, sd::DataType::FLOAT32); - - x.linspace(0.5, 0.5); - Wx = 0.003; - Wr = 0.006; - b = 0.5; - hI = 1.; - cI = 2.; - - std::vector tArgs = {cellClip}; - std::vector iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; - std::vector bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; - - NDArray expH('c', {sL, bS, nOut}, {0.493883f, 0.493883f, 0.493883f, 0.510990f, 0.510990f, 0.510990f, 0.534701f, 0.534701f, 0.534701f, 0.549139f, - 0.549139f, 0.549139f, 0.571900f, 0.571900f, 0.571900f, 0.583561f, 0.583561f, 0.583561f, 0.605106f, 0.605106f, - 0.605106f, 0.614114f, 0.614114f, 0.614114f, 0.635354f, 0.635354f, 0.635354f, 0.642045f, 0.642045f, 0.642045f}, sd::DataType::FLOAT32); - - NDArray expHL('c', {bS, nOut}, {0.493883f, 0.493883f, 0.493883f, 0.510990f, 0.510990f, 0.510990f}, sd::DataType::FLOAT32); - NDArray expCL('c', {bS, nOut}, {1.061274f, 1.061274f, 1.061274f, 1.115888f, 1.115888f, 1.115888f}, sd::DataType::FLOAT32); - - sd::ops::lstmLayer op; - auto results = op.evaluate({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto h = results.at(0); - auto hL = results.at(1); - auto cL = results.at(2); - - ASSERT_TRUE(expH.isSameShape(h)); - ASSERT_TRUE(expH.equalsTo(h)); - - ASSERT_TRUE(expHL.isSameShape(hL)); - ASSERT_TRUE(expHL.equalsTo(hL)); - - ASSERT_TRUE(expCL.isSameShape(cL)); - ASSERT_TRUE(expCL.equalsTo(cL)); - + const int sL = 5; + const int bS = 2; + const int nIn = 4; + const int nOut = 3; + + // input arguments + + const int dataFormat = 0; // [sL,bS,nIn] + const int directionMode = 1; // backward + const int gateAct = + 2; // sigmoid activation for input (i), forget (f) and output (o) gates + const int cellAct = 0; // tanh activation for cell state + const int outAct = 0; // tanh activation for output + + const bool hasBiases = true; // biases array is provided + const bool hasSeqLen = false; // seqLen array is not provided + const auto hasInitH = true; // initial output is provided + const auto hasInitC = true; // initial cell state is provided + const auto hasPH = false; // peephole connections are absent + const auto retFullSeq = + true; // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut] + const auto retLastH = true; // do not return output at last time step + const auto retLastC = true; // return cells state at last time step + + const double cellClip = 0; // do not apply clipping + + NDArray x('c', {sL, bS, nIn}, sd::DataType::FLOAT32); + NDArray Wx('c', {nIn, 4 * nOut}, sd::DataType::FLOAT32); + NDArray Wr('c', {nOut, 4 * nOut}, sd::DataType::FLOAT32); + NDArray b('c', {4 * nOut}, sd::DataType::FLOAT32); + NDArray hI('c', {bS, nOut}, sd::DataType::FLOAT32); + NDArray cI('c', {bS, nOut}, sd::DataType::FLOAT32); + + x.linspace(0.5, 0.5); + Wx = 0.003; + Wr = 0.006; + b = 0.5; + hI = 1.; + cI = 2.; + + std::vector tArgs = {cellClip}; + std::vector iArgs = {dataFormat, directionMode, gateAct, cellAct, + outAct}; + std::vector bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, + hasPH, retFullSeq, retLastH, retLastC}; + + NDArray expH( + 'c', {sL, bS, nOut}, + {0.493883f, 0.493883f, 0.493883f, 0.510990f, 0.510990f, 0.510990f, + 0.534701f, 0.534701f, 0.534701f, 0.549139f, 0.549139f, 0.549139f, + 0.571900f, 0.571900f, 0.571900f, 0.583561f, 0.583561f, 0.583561f, + 0.605106f, 0.605106f, 0.605106f, 0.614114f, 0.614114f, 0.614114f, + 0.635354f, 0.635354f, 0.635354f, 0.642045f, 0.642045f, 0.642045f}, + sd::DataType::FLOAT32); + + NDArray expHL( + 'c', {bS, nOut}, + {0.493883f, 0.493883f, 0.493883f, 0.510990f, 0.510990f, 0.510990f}, + sd::DataType::FLOAT32); + NDArray expCL( + 'c', {bS, nOut}, + {1.061274f, 1.061274f, 1.061274f, 1.115888f, 1.115888f, 1.115888f}, + sd::DataType::FLOAT32); + + sd::ops::lstmLayer op; + auto results = op.evaluate({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto h = results.at(0); + auto hL = results.at(1); + auto cL = results.at(2); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + + ASSERT_TRUE(expHL.isSameShape(hL)); + ASSERT_TRUE(expHL.equalsTo(hL)); + + ASSERT_TRUE(expCL.isSameShape(cL)); + ASSERT_TRUE(expCL.equalsTo(cL)); } - /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, lstmLayer_4) { - - const int sL = 5; - const int bS = 2; - const int nIn = 4; - const int nOut = 3; - - // input arguments - const int dataFormat = 0; // [sL,bS,nIn] - const int directionMode = 3; // bidirectional concat - const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates - const int cellAct = 0; // tanh activation for cell state - const int outAct = 0; // tanh activation for output - - const bool hasBiases = true; // biases array is provided - const bool hasSeqLen = false; // seqLen array is not provided - const auto hasInitH = true; // initial output is provided - const auto hasInitC = true; // initial cell state is provided - const auto hasPH = false; // peephole connections are absent - const auto retFullSeq = true; // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut] - const auto retLastH = true; // do not return output at last time step - const auto retLastC = true; // return cells state at last time step - - const double cellClip = 0; // do not apply clipping - - NDArray x('c', {sL, bS, nIn}, sd::DataType::FLOAT32); - NDArray Wx('c', {2,nIn, 4*nOut}, sd::DataType::FLOAT32); - NDArray Wr('c', {2,nOut, 4*nOut}, sd::DataType::FLOAT32); - NDArray b('c', {2,4*nOut}, sd::DataType::FLOAT32); - NDArray hI('c', {2,bS, nOut}, sd::DataType::FLOAT32); - NDArray cI('c', {2,bS, nOut}, sd::DataType::FLOAT32); - - x.linspace(0.5, 0.5); - Wx({0,1, 0,0, 0,0}) = 0.003f; - Wx({1,2, 0,0, 0,0}) = -0.003f; - Wr({0,1, 0,0, 0,0}) = 0.006f; - Wr({1,2, 0,0, 0,0}) = -0.006f; - b({0,1, 0,0}) = 0.5f; - b({1,2, 0,0}) = -0.5f; - hI({0,1, 0,0, 0,0}) = 1; - hI({1,2, 0,0, 0,0}) = -1; - cI({0,1, 0,0, 0,0}) = 2; - cI({1,2, 0,0, 0,0}) = -2; - - std::vector tArgs = {cellClip}; - std::vector iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; - std::vector bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; - - NDArray expH('c', {sL, bS, 2 * nOut}, { - 0.577661f, 0.577661f, 0.577661f, -0.107642f, -0.107642f, -0.107642f, 0.585289f, 0.585289f, 0.585289f, - -0.106937f, -0.106937f, -0.106937f, 0.556517f, 0.556517f, 0.556517f, -0.111647f, -0.111647f, -0.111647f, - 0.567274f, 0.567274f, 0.567274f, -0.110214f, -0.110214f, -0.110214f, 0.547395f, 0.547395f, 0.547395f, - -0.123305f, -0.123305f, -0.123305f, 0.560640f, 0.560640f, 0.560640f, -0.120862f, -0.120862f, -0.120862f, - 0.550714f, 0.550714f, 0.550714f, -0.156223f, -0.156223f, -0.156223f, 0.565308f, 0.565308f, 0.565308f, - -0.152313f, -0.152313f, -0.152313f, 0.563741f, 0.563741f, 0.563741f, -0.234128f, -0.234128f, -0.234128f, - 0.578676f, 0.578676f, 0.578676f, -0.228917f, -0.228917f, -0.228917f}, sd::DataType::FLOAT32); - - NDArray expHL('c', {2,bS, nOut}, {0.563741f, 0.563741f, 0.563741f, 0.578676f, 0.578676f, 0.578676f, -0.107642f, - -0.107642f, -0.107642f, -0.106937f, -0.106937f, -0.106937f}, sd::DataType::FLOAT32); - NDArray expCL('c', {2,bS, nOut}, {1.217757f, 1.217757f, 1.217757f, 1.272398f, 1.272398f, 1.272398f, -0.295768f, - -0.295768f, -0.295768f, -0.298453f, -0.298453f, -0.298453f}, sd::DataType::FLOAT32); - - sd::ops::lstmLayer op; - auto results = op.evaluate({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto h = results.at(0); - auto hL = results.at(1); - auto cL = results.at(2); - - ASSERT_TRUE(expH.isSameShape(h)); - ASSERT_TRUE(expH.equalsTo(h)); - - ASSERT_TRUE(expHL.isSameShape(hL)); - ASSERT_TRUE(expHL.equalsTo(hL)); - - ASSERT_TRUE(expCL.isSameShape(cL)); - ASSERT_TRUE(expCL.equalsTo(cL)); - + const int sL = 5; + const int bS = 2; + const int nIn = 4; + const int nOut = 3; + + // input arguments + const int dataFormat = 0; // [sL,bS,nIn] + const int directionMode = 3; // bidirectional concat + const int gateAct = + 2; // sigmoid activation for input (i), forget (f) and output (o) gates + const int cellAct = 0; // tanh activation for cell state + const int outAct = 0; // tanh activation for output + + const bool hasBiases = true; // biases array is provided + const bool hasSeqLen = false; // seqLen array is not provided + const auto hasInitH = true; // initial output is provided + const auto hasInitC = true; // initial cell state is provided + const auto hasPH = false; // peephole connections are absent + const auto retFullSeq = + true; // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut] + const auto retLastH = true; // do not return output at last time step + const auto retLastC = true; // return cells state at last time step + + const double cellClip = 0; // do not apply clipping + + NDArray x('c', {sL, bS, nIn}, sd::DataType::FLOAT32); + NDArray Wx('c', {2, nIn, 4 * nOut}, sd::DataType::FLOAT32); + NDArray Wr('c', {2, nOut, 4 * nOut}, sd::DataType::FLOAT32); + NDArray b('c', {2, 4 * nOut}, sd::DataType::FLOAT32); + NDArray hI('c', {2, bS, nOut}, sd::DataType::FLOAT32); + NDArray cI('c', {2, bS, nOut}, sd::DataType::FLOAT32); + + x.linspace(0.5, 0.5); + Wx({0, 1, 0, 0, 0, 0}) = 0.003f; + Wx({1, 2, 0, 0, 0, 0}) = -0.003f; + Wr({0, 1, 0, 0, 0, 0}) = 0.006f; + Wr({1, 2, 0, 0, 0, 0}) = -0.006f; + b({0, 1, 0, 0}) = 0.5f; + b({1, 2, 0, 0}) = -0.5f; + hI({0, 1, 0, 0, 0, 0}) = 1; + hI({1, 2, 0, 0, 0, 0}) = -1; + cI({0, 1, 0, 0, 0, 0}) = 2; + cI({1, 2, 0, 0, 0, 0}) = -2; + + std::vector tArgs = {cellClip}; + std::vector iArgs = {dataFormat, directionMode, gateAct, cellAct, + outAct}; + std::vector bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, + hasPH, retFullSeq, retLastH, retLastC}; + + NDArray expH( + 'c', {sL, bS, 2 * nOut}, + {0.577661f, 0.577661f, 0.577661f, -0.107642f, -0.107642f, -0.107642f, + 0.585289f, 0.585289f, 0.585289f, -0.106937f, -0.106937f, -0.106937f, + 0.556517f, 0.556517f, 0.556517f, -0.111647f, -0.111647f, -0.111647f, + 0.567274f, 0.567274f, 0.567274f, -0.110214f, -0.110214f, -0.110214f, + 0.547395f, 0.547395f, 0.547395f, -0.123305f, -0.123305f, -0.123305f, + 0.560640f, 0.560640f, 0.560640f, -0.120862f, -0.120862f, -0.120862f, + 0.550714f, 0.550714f, 0.550714f, -0.156223f, -0.156223f, -0.156223f, + 0.565308f, 0.565308f, 0.565308f, -0.152313f, -0.152313f, -0.152313f, + 0.563741f, 0.563741f, 0.563741f, -0.234128f, -0.234128f, -0.234128f, + 0.578676f, 0.578676f, 0.578676f, -0.228917f, -0.228917f, -0.228917f}, + sd::DataType::FLOAT32); + + NDArray expHL( + 'c', {2, bS, nOut}, + {0.563741f, 0.563741f, 0.563741f, 0.578676f, 0.578676f, 0.578676f, + -0.107642f, -0.107642f, -0.107642f, -0.106937f, -0.106937f, -0.106937f}, + sd::DataType::FLOAT32); + NDArray expCL( + 'c', {2, bS, nOut}, + {1.217757f, 1.217757f, 1.217757f, 1.272398f, 1.272398f, 1.272398f, + -0.295768f, -0.295768f, -0.295768f, -0.298453f, -0.298453f, -0.298453f}, + sd::DataType::FLOAT32); + + sd::ops::lstmLayer op; + auto results = op.evaluate({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto h = results.at(0); + auto hL = results.at(1); + auto cL = results.at(2); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + + ASSERT_TRUE(expHL.isSameShape(hL)); + ASSERT_TRUE(expHL.equalsTo(hL)); + + ASSERT_TRUE(expCL.isSameShape(cL)); + ASSERT_TRUE(expCL.equalsTo(cL)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, lstmLayer_5) { - - const int sL = 5; - const int bS = 2; - const int nIn = 4; - const int nOut = 3; - - // input arguments - const int dataFormat = 1; // [bS,sL,nIn] - const int directionMode = 3; // bidirectional concat - const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates - const int cellAct = 0; // tanh activation for cell state - const int outAct = 0; // tanh activation for output - - const bool hasBiases = true; // biases array is provided - const bool hasSeqLen = false; // seqLen array is not provided - const auto hasInitH = true; // initial output is provided - const auto hasInitC = true; // initial cell state is provided - const auto hasPH = false; // peephole connections are absent - const auto retFullSeq = true; // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut] - const auto retLastH = true; // do not return output at last time step - const auto retLastC = true; // return cells state at last time step - - const double cellClip = 0; // do not apply clipping - - NDArray x('c', {bS, sL, nIn}, sd::DataType::FLOAT32); - NDArray Wx('c', {2,nIn, 4*nOut}, sd::DataType::FLOAT32); - NDArray Wr('c', {2,nOut, 4*nOut}, sd::DataType::FLOAT32); - NDArray b('c', {2,4*nOut}, sd::DataType::FLOAT32); - NDArray hI('c', {2,bS, nOut}, sd::DataType::FLOAT32); - NDArray cI('c', {2,bS, nOut}, sd::DataType::FLOAT32); - - x.linspace(0.5, 0.5); - Wx({0,1, 0,0, 0,0}) = 0.003; - Wx({1,2, 0,0, 0,0}) = -0.003; - Wr({0,1, 0,0, 0,0}) = 0.006; - Wr({1,2, 0,0, 0,0}) = -0.006; - b({0,1, 0,0}) = 0.5; - b({1,2, 0,0}) = -0.5; - hI({0,1, 0,0, 0,0}) = 1; - hI({1,2, 0,0, 0,0}) = -1; - cI({0,1, 0,0, 0,0}) = 2; - cI({1,2, 0,0, 0,0}) = -2; - - std::vector tArgs = {cellClip}; - std::vector iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; - std::vector bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; - - NDArray expH('c', {bS, sL, 2*nOut}, { - 0.577661f, 0.577661f, 0.577661f, -0.107659f, -0.107659f, -0.107659f, 0.548099f, 0.548099f, 0.548099f, -0.113406f, -0.113406f, -0.113406f, - 0.526881f, 0.526881f, 0.526881f, -0.12883f, -0.12883f, -0.12883f, 0.515882f, 0.515882f, 0.515882f, -0.16868f, -0.16868f, -0.16868f, - 0.51409f, 0.51409f, 0.51409f, -0.255185f, -0.255185f, -0.255185f, 0.614599f, 0.614599f, 0.614599f, -0.102739f, -0.102739f, -0.102739f, - 0.599572f, 0.599572f, 0.599572f, -0.105802f, -0.105802f, -0.105802f, 0.591089f, 0.591089f, 0.591089f, -0.116681f, -0.116681f, -0.116681f, - 0.588694f, 0.588694f, 0.588694f, -0.149201f, -0.149201f, -0.149201f, 0.591492f, 0.591492f, 0.591492f, -0.228917f, -0.228917f, -0.228917f}, sd::DataType::FLOAT32); - - NDArray expHL('c', {2,bS, nOut}, {0.51409f, 0.51409f, 0.51409f, 0.591492f, 0.591492f, 0.591492f, - -0.107659f, -0.107659f, -0.107659f, -0.102739f, -0.102739f, -0.102739f}, sd::DataType::FLOAT32); - NDArray expCL('c', {2,bS, nOut}, {1.07293f , 1.07293f , 1.07293f, 1.346609f, 1.346609f, 1.346609f, - -0.295811f, -0.295811f, -0.295811f, -0.305394f, -0.305394f, -0.305394f}, sd::DataType::FLOAT32); - - sd::ops::lstmLayer op; - auto results = op.evaluate({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto h = results.at(0); - auto hL = results.at(1); - auto cL = results.at(2); - - // h->printBuffer(); - // hL->printBuffer(); - // cL->printBuffer(); - - ASSERT_TRUE(expH.isSameShape(h)); - ASSERT_TRUE(expH.equalsTo(h)); - - ASSERT_TRUE(expHL.isSameShape(hL)); - ASSERT_TRUE(expHL.equalsTo(hL)); - - ASSERT_TRUE(expCL.isSameShape(cL)); - ASSERT_TRUE(expCL.equalsTo(cL)); - + const int sL = 5; + const int bS = 2; + const int nIn = 4; + const int nOut = 3; + + // input arguments + const int dataFormat = 1; // [bS,sL,nIn] + const int directionMode = 3; // bidirectional concat + const int gateAct = + 2; // sigmoid activation for input (i), forget (f) and output (o) gates + const int cellAct = 0; // tanh activation for cell state + const int outAct = 0; // tanh activation for output + + const bool hasBiases = true; // biases array is provided + const bool hasSeqLen = false; // seqLen array is not provided + const auto hasInitH = true; // initial output is provided + const auto hasInitC = true; // initial cell state is provided + const auto hasPH = false; // peephole connections are absent + const auto retFullSeq = + true; // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut] + const auto retLastH = true; // do not return output at last time step + const auto retLastC = true; // return cells state at last time step + + const double cellClip = 0; // do not apply clipping + + NDArray x('c', {bS, sL, nIn}, sd::DataType::FLOAT32); + NDArray Wx('c', {2, nIn, 4 * nOut}, sd::DataType::FLOAT32); + NDArray Wr('c', {2, nOut, 4 * nOut}, sd::DataType::FLOAT32); + NDArray b('c', {2, 4 * nOut}, sd::DataType::FLOAT32); + NDArray hI('c', {2, bS, nOut}, sd::DataType::FLOAT32); + NDArray cI('c', {2, bS, nOut}, sd::DataType::FLOAT32); + + x.linspace(0.5, 0.5); + Wx({0, 1, 0, 0, 0, 0}) = 0.003; + Wx({1, 2, 0, 0, 0, 0}) = -0.003; + Wr({0, 1, 0, 0, 0, 0}) = 0.006; + Wr({1, 2, 0, 0, 0, 0}) = -0.006; + b({0, 1, 0, 0}) = 0.5; + b({1, 2, 0, 0}) = -0.5; + hI({0, 1, 0, 0, 0, 0}) = 1; + hI({1, 2, 0, 0, 0, 0}) = -1; + cI({0, 1, 0, 0, 0, 0}) = 2; + cI({1, 2, 0, 0, 0, 0}) = -2; + + std::vector tArgs = {cellClip}; + std::vector iArgs = {dataFormat, directionMode, gateAct, cellAct, + outAct}; + std::vector bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, + hasPH, retFullSeq, retLastH, retLastC}; + + NDArray expH( + 'c', {bS, sL, 2 * nOut}, + {0.577661f, 0.577661f, 0.577661f, -0.107659f, -0.107659f, -0.107659f, + 0.548099f, 0.548099f, 0.548099f, -0.113406f, -0.113406f, -0.113406f, + 0.526881f, 0.526881f, 0.526881f, -0.12883f, -0.12883f, -0.12883f, + 0.515882f, 0.515882f, 0.515882f, -0.16868f, -0.16868f, -0.16868f, + 0.51409f, 0.51409f, 0.51409f, -0.255185f, -0.255185f, -0.255185f, + 0.614599f, 0.614599f, 0.614599f, -0.102739f, -0.102739f, -0.102739f, + 0.599572f, 0.599572f, 0.599572f, -0.105802f, -0.105802f, -0.105802f, + 0.591089f, 0.591089f, 0.591089f, -0.116681f, -0.116681f, -0.116681f, + 0.588694f, 0.588694f, 0.588694f, -0.149201f, -0.149201f, -0.149201f, + 0.591492f, 0.591492f, 0.591492f, -0.228917f, -0.228917f, -0.228917f}, + sd::DataType::FLOAT32); + + NDArray expHL( + 'c', {2, bS, nOut}, + {0.51409f, 0.51409f, 0.51409f, 0.591492f, 0.591492f, 0.591492f, + -0.107659f, -0.107659f, -0.107659f, -0.102739f, -0.102739f, -0.102739f}, + sd::DataType::FLOAT32); + NDArray expCL( + 'c', {2, bS, nOut}, + {1.07293f, 1.07293f, 1.07293f, 1.346609f, 1.346609f, 1.346609f, + -0.295811f, -0.295811f, -0.295811f, -0.305394f, -0.305394f, -0.305394f}, + sd::DataType::FLOAT32); + + sd::ops::lstmLayer op; + auto results = op.evaluate({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto h = results.at(0); + auto hL = results.at(1); + auto cL = results.at(2); + + // h->printBuffer(); + // hL->printBuffer(); + // cL->printBuffer(); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + + ASSERT_TRUE(expHL.isSameShape(hL)); + ASSERT_TRUE(expHL.equalsTo(hL)); + + ASSERT_TRUE(expCL.isSameShape(cL)); + ASSERT_TRUE(expCL.equalsTo(cL)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, lstmLayer_6) { - - const int sL = 5; - const int bS = 2; - const int nIn = 4; - const int nOut = 3; - - // input arguments - const int dataFormat = 0; // [sL,bS,nIn] - const int directionMode = 2; // bidirectional sum - const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates - const int cellAct = 0; // tanh activation for cell state - const int outAct = 0; // tanh activation for output - - const bool hasBiases = true; // biases array is provided - const bool hasSeqLen = false; // seqLen array is not provided - const auto hasInitH = true; // initial output is provided - const auto hasInitC = true; // initial cell state is provided - const auto hasPH = false; // peephole connections are absent - const auto retFullSeq = true; // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut] - const auto retLastH = true; // do not return output at last time step - const auto retLastC = true; // return cells state at last time step - - const double cellClip = 0; // do not apply clipping - - NDArray x('c', {sL, bS, nIn}, sd::DataType::FLOAT32); - NDArray Wx('c', {2,nIn, 4*nOut}, sd::DataType::FLOAT32); - NDArray Wr('c', {2,nOut, 4*nOut}, sd::DataType::FLOAT32); - NDArray b('c', {2,4*nOut}, sd::DataType::FLOAT32); - NDArray hI('c', {2,bS, nOut}, sd::DataType::FLOAT32); - NDArray cI('c', {2,bS, nOut}, sd::DataType::FLOAT32); - - x.linspace(0.5, 0.5); - Wx({0,1, 0,0, 0,0}) = 0.003f; - Wx({1,2, 0,0, 0,0}) = -0.003f; - Wr({0,1, 0,0, 0,0}) = 0.006f; - Wr({1,2, 0,0, 0,0}) = -0.006f; - b({0,1, 0,0}) = 0.5f; - b({1,2, 0,0}) = -0.5f; - hI({0,1, 0,0, 0,0}) = 1; - hI({1,2, 0,0, 0,0}) = -1; - cI({0,1, 0,0, 0,0}) = 2; - cI({1,2, 0,0, 0,0}) = -2; - - std::vector tArgs = {cellClip}; - std::vector iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; - std::vector bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; - - NDArray expH('c', {sL, bS, nOut}, { - 0.470019f, 0.470019f, 0.470019f, 0.478352f, 0.478352f, 0.478352f, 0.444871f, 0.444871f, 0.444871f, 0.457060f, - 0.457060f, 0.457060f, 0.424090f, 0.424090f, 0.424090f, 0.439778f, 0.439778f, 0.439778f, 0.394491f, 0.394491f, - 0.394491f, 0.412995f, 0.412995f, 0.412995f, 0.329613f, 0.329613f, 0.329613f, 0.349760f, 0.349760f, 0.349760f}, sd::DataType::FLOAT32); - - NDArray expHL('c', {2,bS, nOut}, {0.563741f, 0.563741f, 0.563741f, 0.578676f, 0.578676f, 0.578676f, - -0.107642f, -0.107642f, -0.107642f, -0.106937f, -0.106937f, -0.106937f}, - sd::DataType::FLOAT32); - NDArray expCL('c', {2,bS, nOut}, {1.217757f, 1.217757f, 1.217757f, 1.272398f, 1.272398f, 1.272398f, - -0.295768f, -0.295768f, -0.295768f, -0.298453f, -0.298453f, -0.298453f}, - sd::DataType::FLOAT32); - - sd::ops::lstmLayer op; - auto results = op.evaluate({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto h = results.at(0); - auto hL = results.at(1); - auto cL = results.at(2); - - ASSERT_TRUE(expH.isSameShape(h)); - ASSERT_TRUE(expH.equalsTo(h)); - - ASSERT_TRUE(expHL.isSameShape(hL)); - ASSERT_TRUE(expHL.equalsTo(hL)); - - ASSERT_TRUE(expCL.isSameShape(cL)); - ASSERT_TRUE(expCL.equalsTo(cL)); - + const int sL = 5; + const int bS = 2; + const int nIn = 4; + const int nOut = 3; + + // input arguments + const int dataFormat = 0; // [sL,bS,nIn] + const int directionMode = 2; // bidirectional sum + const int gateAct = + 2; // sigmoid activation for input (i), forget (f) and output (o) gates + const int cellAct = 0; // tanh activation for cell state + const int outAct = 0; // tanh activation for output + + const bool hasBiases = true; // biases array is provided + const bool hasSeqLen = false; // seqLen array is not provided + const auto hasInitH = true; // initial output is provided + const auto hasInitC = true; // initial cell state is provided + const auto hasPH = false; // peephole connections are absent + const auto retFullSeq = + true; // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut] + const auto retLastH = true; // do not return output at last time step + const auto retLastC = true; // return cells state at last time step + + const double cellClip = 0; // do not apply clipping + + NDArray x('c', {sL, bS, nIn}, sd::DataType::FLOAT32); + NDArray Wx('c', {2, nIn, 4 * nOut}, sd::DataType::FLOAT32); + NDArray Wr('c', {2, nOut, 4 * nOut}, sd::DataType::FLOAT32); + NDArray b('c', {2, 4 * nOut}, sd::DataType::FLOAT32); + NDArray hI('c', {2, bS, nOut}, sd::DataType::FLOAT32); + NDArray cI('c', {2, bS, nOut}, sd::DataType::FLOAT32); + + x.linspace(0.5, 0.5); + Wx({0, 1, 0, 0, 0, 0}) = 0.003f; + Wx({1, 2, 0, 0, 0, 0}) = -0.003f; + Wr({0, 1, 0, 0, 0, 0}) = 0.006f; + Wr({1, 2, 0, 0, 0, 0}) = -0.006f; + b({0, 1, 0, 0}) = 0.5f; + b({1, 2, 0, 0}) = -0.5f; + hI({0, 1, 0, 0, 0, 0}) = 1; + hI({1, 2, 0, 0, 0, 0}) = -1; + cI({0, 1, 0, 0, 0, 0}) = 2; + cI({1, 2, 0, 0, 0, 0}) = -2; + + std::vector tArgs = {cellClip}; + std::vector iArgs = {dataFormat, directionMode, gateAct, cellAct, + outAct}; + std::vector bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, + hasPH, retFullSeq, retLastH, retLastC}; + + NDArray expH( + 'c', {sL, bS, nOut}, + {0.470019f, 0.470019f, 0.470019f, 0.478352f, 0.478352f, 0.478352f, + 0.444871f, 0.444871f, 0.444871f, 0.457060f, 0.457060f, 0.457060f, + 0.424090f, 0.424090f, 0.424090f, 0.439778f, 0.439778f, 0.439778f, + 0.394491f, 0.394491f, 0.394491f, 0.412995f, 0.412995f, 0.412995f, + 0.329613f, 0.329613f, 0.329613f, 0.349760f, 0.349760f, 0.349760f}, + sd::DataType::FLOAT32); + + NDArray expHL( + 'c', {2, bS, nOut}, + {0.563741f, 0.563741f, 0.563741f, 0.578676f, 0.578676f, 0.578676f, + -0.107642f, -0.107642f, -0.107642f, -0.106937f, -0.106937f, -0.106937f}, + sd::DataType::FLOAT32); + NDArray expCL( + 'c', {2, bS, nOut}, + {1.217757f, 1.217757f, 1.217757f, 1.272398f, 1.272398f, 1.272398f, + -0.295768f, -0.295768f, -0.295768f, -0.298453f, -0.298453f, -0.298453f}, + sd::DataType::FLOAT32); + + sd::ops::lstmLayer op; + auto results = op.evaluate({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto h = results.at(0); + auto hL = results.at(1); + auto cL = results.at(2); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + + ASSERT_TRUE(expHL.isSameShape(hL)); + ASSERT_TRUE(expHL.equalsTo(hL)); + + ASSERT_TRUE(expCL.isSameShape(cL)); + ASSERT_TRUE(expCL.equalsTo(cL)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, lstmLayer_7) { - #ifndef HAVE_MKLDNN - - const int sL = 5; - const int bS = 2; - const int nIn = 4; - const int nOut = 3; - - // input arguments - - const int dataFormat = 0; // [sL,bS,nIn] - const int directionMode = 0; // forward - const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates - const int cellAct = 0; // tanh activation for cell state - const int outAct = 0; // tanh activation for output - - const bool hasBiases = true; // biases array is provided - const bool hasSeqLen = false; // seqLen array is not provided - const auto hasInitH = true; // initial output is provided - const auto hasInitC = true; // initial cell state is provided - const auto hasPH = true; // peephole connections are absent - const auto retFullSeq = true; // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut] - const auto retLastH = true; // do not return output at last time step - const auto retLastC = true; // return cells state at last time step - - const double cellClip = 0; // do not apply clipping - - NDArray x('c', {sL, bS, nIn}, sd::DataType::FLOAT32); - NDArray Wx('c', {nIn, 4*nOut}, sd::DataType::FLOAT32); - NDArray Wr('c', {nOut, 4*nOut}, sd::DataType::FLOAT32); - NDArray b('c', {4*nOut}, sd::DataType::FLOAT32); - NDArray hI('c', {bS, nOut}, sd::DataType::FLOAT32); - NDArray cI('c', {bS, nOut}, sd::DataType::FLOAT32); - NDArray Wp('c', {3*nOut}, sd::DataType::FLOAT32); - - x.linspace(0.5, 0.5); - Wx = 0.003; - Wr = 0.006; - b = 0.5; - hI = 1.; - cI = 2.; - Wp = -0.05; - - std::initializer_list tArgs = {cellClip}; - std::initializer_list iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; - std::initializer_list bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; - - NDArray expH('c', {sL, bS, nOut}, {0.55533 , 0.55533 , 0.55533 , 0.562925, 0.562925, 0.562925, 0.531795, 0.531795, 0.531795, 0.542556, - 0.542556, 0.542556, 0.521466, 0.521466, 0.521466, 0.534638, 0.534638, 0.534638, 0.524805, 0.524805, - 0.524805, 0.539187, 0.539187, 0.539187, 0.538309, 0.538309, 0.538309, 0.552923, 0.552923, 0.552923}, sd::DataType::FLOAT32); - - NDArray expHL('c', {bS, nOut}, {0.538309, 0.538309, 0.538309,0.552923, 0.552923, 0.552923}, sd::DataType::FLOAT32); - NDArray expCL('c', {bS, nOut}, {1.147089, 1.147089, 1.147089,1.197228, 1.197228, 1.197228}, sd::DataType::FLOAT32); - - sd::ops::lstmLayer op; - auto results = op.evaluate({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto h = results.at(0); - auto hL = results.at(1); - auto cL = results.at(2); - - ASSERT_TRUE(expH.isSameShape(h)); - ASSERT_TRUE(expH.equalsTo(h)); - - ASSERT_TRUE(expHL.isSameShape(hL)); - ASSERT_TRUE(expHL.equalsTo(hL)); - - ASSERT_TRUE(expCL.isSameShape(cL)); - ASSERT_TRUE(expCL.equalsTo(cL)); +#ifndef HAVE_MKLDNN + + const int sL = 5; + const int bS = 2; + const int nIn = 4; + const int nOut = 3; + + // input arguments + + const int dataFormat = 0; // [sL,bS,nIn] + const int directionMode = 0; // forward + const int gateAct = + 2; // sigmoid activation for input (i), forget (f) and output (o) gates + const int cellAct = 0; // tanh activation for cell state + const int outAct = 0; // tanh activation for output + + const bool hasBiases = true; // biases array is provided + const bool hasSeqLen = false; // seqLen array is not provided + const auto hasInitH = true; // initial output is provided + const auto hasInitC = true; // initial cell state is provided + const auto hasPH = true; // peephole connections are absent + const auto retFullSeq = + true; // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut] + const auto retLastH = true; // do not return output at last time step + const auto retLastC = true; // return cells state at last time step + + const double cellClip = 0; // do not apply clipping + + NDArray x('c', {sL, bS, nIn}, sd::DataType::FLOAT32); + NDArray Wx('c', {nIn, 4 * nOut}, sd::DataType::FLOAT32); + NDArray Wr('c', {nOut, 4 * nOut}, sd::DataType::FLOAT32); + NDArray b('c', {4 * nOut}, sd::DataType::FLOAT32); + NDArray hI('c', {bS, nOut}, sd::DataType::FLOAT32); + NDArray cI('c', {bS, nOut}, sd::DataType::FLOAT32); + NDArray Wp('c', {3 * nOut}, sd::DataType::FLOAT32); + + x.linspace(0.5, 0.5); + Wx = 0.003; + Wr = 0.006; + b = 0.5; + hI = 1.; + cI = 2.; + Wp = -0.05; + + std::initializer_list tArgs = {cellClip}; + std::initializer_list iArgs = {dataFormat, directionMode, gateAct, + cellAct, outAct}; + std::initializer_list bArgs = {hasBiases, hasSeqLen, hasInitH, + hasInitC, hasPH, retFullSeq, + retLastH, retLastC}; + + NDArray expH('c', {sL, bS, nOut}, + {0.55533, 0.55533, 0.55533, 0.562925, 0.562925, 0.562925, + 0.531795, 0.531795, 0.531795, 0.542556, 0.542556, 0.542556, + 0.521466, 0.521466, 0.521466, 0.534638, 0.534638, 0.534638, + 0.524805, 0.524805, 0.524805, 0.539187, 0.539187, 0.539187, + 0.538309, 0.538309, 0.538309, 0.552923, 0.552923, 0.552923}, + sd::DataType::FLOAT32); + + NDArray expHL('c', {bS, nOut}, + {0.538309, 0.538309, 0.538309, 0.552923, 0.552923, 0.552923}, + sd::DataType::FLOAT32); + NDArray expCL('c', {bS, nOut}, + {1.147089, 1.147089, 1.147089, 1.197228, 1.197228, 1.197228}, + sd::DataType::FLOAT32); + + sd::ops::lstmLayer op; + auto results = + op.evaluate({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto h = results.at(0); + auto hL = results.at(1); + auto cL = results.at(2); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + + ASSERT_TRUE(expHL.isSameShape(hL)); + ASSERT_TRUE(expHL.equalsTo(hL)); + + ASSERT_TRUE(expCL.isSameShape(cL)); + ASSERT_TRUE(expCL.equalsTo(cL)); - - #endif +#endif } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, lstmLayer_8) { - #ifndef HAVE_MKLDNN - - const int sL = 5; - const int bS = 2; - const int nIn = 4; - const int nOut = 3; - - // input arguments - - const int dataFormat = 0; // [sL,bS,nIn] - const int directionMode = 1; // backward - const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates - const int cellAct = 0; // tanh activation for cell state - const int outAct = 0; // tanh activation for output - - const bool hasBiases = true; // biases array is provided - const bool hasSeqLen = false; // seqLen array is not provided - const auto hasInitH = true; // initial output is provided - const auto hasInitC = true; // initial cell state is provided - const auto hasPH = true; // peephole connections are absent - const auto retFullSeq = true; // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut] - const auto retLastH = true; // do not return output at last time step - const auto retLastC = true; // return cells state at last time step - - const double cellClip = 1.; // do not apply clipping - - NDArray x('c', {sL, bS, nIn}, sd::DataType::FLOAT32); - NDArray Wx('c', {nIn, 4*nOut}, sd::DataType::FLOAT32); - NDArray Wr('c', {nOut, 4*nOut}, sd::DataType::FLOAT32); - NDArray b('c', {4*nOut}, sd::DataType::FLOAT32); - NDArray hI('c', {bS, nOut}, sd::DataType::FLOAT32); - NDArray cI('c', {bS, nOut}, sd::DataType::FLOAT32); - NDArray Wp('c', {3*nOut}, sd::DataType::FLOAT32); - - x.linspace(0.5, 0.5); - Wx = 0.003; - Wr = 0.006; - b = 0.5; - hI = 1.; - cI = 2.; - Wp = -0.05; - - std::initializer_list tArgs = {cellClip}; - std::initializer_list iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; - std::initializer_list bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; - - NDArray expH('c', {sL, bS, nOut}, { - 0.436221f, 0.436221f, 0.436221f, 0.450573f, 0.450573f, 0.450573f, 0.463602f, 0.463602f, 0.463602f, 0.474674f, 0.474674f, 0.474674f, - 0.484039f, 0.484039f, 0.484039f, 0.490679f, 0.490679f, 0.490679f, 0.494871f, 0.494871f, 0.494871f, 0.499028f, 0.499028f, 0.499028f, - 0.504649f, 0.504649f, 0.504649f, 0.508719f, 0.508719f, 0.508719f}, sd::DataType::FLOAT32); +#ifndef HAVE_MKLDNN + + const int sL = 5; + const int bS = 2; + const int nIn = 4; + const int nOut = 3; + + // input arguments + + const int dataFormat = 0; // [sL,bS,nIn] + const int directionMode = 1; // backward + const int gateAct = + 2; // sigmoid activation for input (i), forget (f) and output (o) gates + const int cellAct = 0; // tanh activation for cell state + const int outAct = 0; // tanh activation for output + + const bool hasBiases = true; // biases array is provided + const bool hasSeqLen = false; // seqLen array is not provided + const auto hasInitH = true; // initial output is provided + const auto hasInitC = true; // initial cell state is provided + const auto hasPH = true; // peephole connections are absent + const auto retFullSeq = + true; // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut] + const auto retLastH = true; // do not return output at last time step + const auto retLastC = true; // return cells state at last time step + + const double cellClip = 1.; // do not apply clipping + + NDArray x('c', {sL, bS, nIn}, sd::DataType::FLOAT32); + NDArray Wx('c', {nIn, 4 * nOut}, sd::DataType::FLOAT32); + NDArray Wr('c', {nOut, 4 * nOut}, sd::DataType::FLOAT32); + NDArray b('c', {4 * nOut}, sd::DataType::FLOAT32); + NDArray hI('c', {bS, nOut}, sd::DataType::FLOAT32); + NDArray cI('c', {bS, nOut}, sd::DataType::FLOAT32); + NDArray Wp('c', {3 * nOut}, sd::DataType::FLOAT32); + + x.linspace(0.5, 0.5); + Wx = 0.003; + Wr = 0.006; + b = 0.5; + hI = 1.; + cI = 2.; + Wp = -0.05; + + std::initializer_list tArgs = {cellClip}; + std::initializer_list iArgs = {dataFormat, directionMode, gateAct, + cellAct, outAct}; + std::initializer_list bArgs = {hasBiases, hasSeqLen, hasInitH, + hasInitC, hasPH, retFullSeq, + retLastH, retLastC}; + + NDArray expH( + 'c', {sL, bS, nOut}, + {0.436221f, 0.436221f, 0.436221f, 0.450573f, 0.450573f, 0.450573f, + 0.463602f, 0.463602f, 0.463602f, 0.474674f, 0.474674f, 0.474674f, + 0.484039f, 0.484039f, 0.484039f, 0.490679f, 0.490679f, 0.490679f, + 0.494871f, 0.494871f, 0.494871f, 0.499028f, 0.499028f, 0.499028f, + 0.504649f, 0.504649f, 0.504649f, 0.508719f, 0.508719f, 0.508719f}, + sd::DataType::FLOAT32); + + NDArray expHL( + 'c', {bS, nOut}, + {0.436221f, 0.436221f, 0.436221f, 0.450573f, 0.450573f, 0.450573f}, + sd::DataType::FLOAT32); + NDArray expCL( + 'c', {bS, nOut}, + {0.879804f, 0.879804f, 0.879804f, 0.914666f, 0.914666f, 0.914666f}, + sd::DataType::FLOAT32); + + sd::ops::lstmLayer op; + auto results = + op.evaluate({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto h = results.at(0); + auto hL = results.at(1); + auto cL = results.at(2); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + + ASSERT_TRUE(expHL.isSameShape(hL)); + ASSERT_TRUE(expHL.equalsTo(hL)); + + ASSERT_TRUE(expCL.isSameShape(cL)); + ASSERT_TRUE(expCL.equalsTo(cL)); - NDArray expHL('c', {bS, nOut}, {0.436221f, 0.436221f, 0.436221f, 0.450573f, 0.450573f, 0.450573f}, sd::DataType::FLOAT32); - NDArray expCL('c', {bS, nOut}, {0.879804f, 0.879804f, 0.879804f, 0.914666f, 0.914666f, 0.914666f}, sd::DataType::FLOAT32); - - sd::ops::lstmLayer op; - auto results = op.evaluate({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto h = results.at(0); - auto hL = results.at(1); - auto cL = results.at(2); - - ASSERT_TRUE(expH.isSameShape(h)); - ASSERT_TRUE(expH.equalsTo(h)); - - ASSERT_TRUE(expHL.isSameShape(hL)); - ASSERT_TRUE(expHL.equalsTo(hL)); - - ASSERT_TRUE(expCL.isSameShape(cL)); - ASSERT_TRUE(expCL.equalsTo(cL)); - - - #endif +#endif } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, lstmLayer_9) { - #ifndef HAVE_MKLDNN - - const int sL = 5; - const int bS = 2; - const int nIn = 4; - const int nOut = 3; - - // input arguments - const int dataFormat = 0; // [sL,bS,nIn] - const int directionMode = 3; // bidirectional concat - const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates - const int cellAct = 0; // tanh activation for cell state - const int outAct = 0; // tanh activation for output - - const bool hasBiases = true; // biases array is provided - const bool hasSeqLen = false; // seqLen array is not provided - const auto hasInitH = true; // initial output is provided - const auto hasInitC = true; // initial cell state is provided - const auto hasPH = true; // peephole connections are absent - const auto retFullSeq = true; // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut] - const auto retLastH = true; // do not return output at last time step - const auto retLastC = true; // return cells state at last time step - - const double cellClip = 0; // do not apply clipping - - NDArray x('c', {sL, bS, nIn}, sd::DataType::FLOAT32); - NDArray Wx('c', {2,nIn, 4*nOut}, sd::DataType::FLOAT32); - NDArray Wr('c', {2,nOut, 4*nOut}, sd::DataType::FLOAT32); - NDArray b('c', {2,4*nOut}, sd::DataType::FLOAT32); - NDArray hI('c', {2,bS, nOut}, sd::DataType::FLOAT32); - NDArray cI('c', {2,bS, nOut}, sd::DataType::FLOAT32); - NDArray Wp('c', {2,3*nOut}, sd::DataType::FLOAT32); - - x.linspace(0.5, 0.5); - Wx({0,1, 0,0, 0,0}) = 0.003; - Wx({1,2, 0,0, 0,0}) = -0.003; - Wr({0,1, 0,0, 0,0}) = 0.006; - Wr({1,2, 0,0, 0,0}) = -0.006; - b({0,1, 0,0}) = 0.5; - b({1,2, 0,0}) = -0.5; - hI({0,1, 0,0, 0,0}) = 1; - hI({1,2, 0,0, 0,0}) = -1; - cI({0,1, 0,0, 0,0}) = 2; - cI({1,2, 0,0, 0,0}) = -2; - Wp({0,1, 0,0}) = -0.05; - Wp({1,2, 0,0}) = 0.05; - - std::initializer_list tArgs = {cellClip}; - std::initializer_list iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; - std::initializer_list bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; - - NDArray expH('c', {sL, bS, 2*nOut}, { - 0.55533f, 0.55533f, 0.55533f, -0.104502f, -0.104502f, -0.104502f, 0.562925f, 0.562925f, 0.562925f, -0.103843f, -0.103843f, -0.103843f, - 0.531795f, 0.531795f, 0.531795f, -0.107456f, -0.107456f, -0.107456f, 0.542556f, 0.542556f, 0.542556f, -0.106139f, -0.106139f, -0.106139f, - 0.521466f, 0.521466f, 0.521466f, -0.11681f, -0.11681f, -0.11681f, 0.534638f, 0.534638f, 0.534638f, -0.11458f, -0.11458f, -0.11458f, - 0.524805f, 0.524805f, 0.524805f, -0.145177f, -0.145177f, -0.145177f, 0.539187f, 0.539187f, 0.539187f, -0.14157f, -0.14157f, -0.14157f, - 0.538309f, 0.538309f, 0.538309f, -0.218056f, -0.218056f, -0.218056f, 0.552923f, 0.552923f, 0.552923f, -0.213068f, -0.213068f, -0.213068f}, sd::DataType::FLOAT32); - - NDArray expHL('c', {2,bS, nOut}, {0.538309f, 0.538309f, 0.538309f, 0.552923f, 0.552923f, 0.552923f, -0.104502f, -0.104502f, -0.104502f, - -0.103843f, -0.103843f, -0.103843f}, sd::DataType::FLOAT32); - NDArray expCL('c', {2,bS, nOut}, {1.147089f, 1.147089f, 1.147089f, 1.197228f, 1.197228f, 1.197228f, -0.289425f, -0.289425f, -0.289425f, - -0.292174f, -0.292174f, -0.292174f}, sd::DataType::FLOAT32); - - sd::ops::lstmLayer op; - auto results = op.evaluate({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto h = results.at(0); - auto hL = results.at(1); - auto cL = results.at(2); - - ASSERT_TRUE(expH.isSameShape(h)); - ASSERT_TRUE(expH.equalsTo(h)); - - ASSERT_TRUE(expHL.isSameShape(hL)); - ASSERT_TRUE(expHL.equalsTo(hL)); - - ASSERT_TRUE(expCL.isSameShape(cL)); - ASSERT_TRUE(expCL.equalsTo(cL)); - - - #endif +#ifndef HAVE_MKLDNN + + const int sL = 5; + const int bS = 2; + const int nIn = 4; + const int nOut = 3; + + // input arguments + const int dataFormat = 0; // [sL,bS,nIn] + const int directionMode = 3; // bidirectional concat + const int gateAct = + 2; // sigmoid activation for input (i), forget (f) and output (o) gates + const int cellAct = 0; // tanh activation for cell state + const int outAct = 0; // tanh activation for output + + const bool hasBiases = true; // biases array is provided + const bool hasSeqLen = false; // seqLen array is not provided + const auto hasInitH = true; // initial output is provided + const auto hasInitC = true; // initial cell state is provided + const auto hasPH = true; // peephole connections are absent + const auto retFullSeq = + true; // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut] + const auto retLastH = true; // do not return output at last time step + const auto retLastC = true; // return cells state at last time step + + const double cellClip = 0; // do not apply clipping + + NDArray x('c', {sL, bS, nIn}, sd::DataType::FLOAT32); + NDArray Wx('c', {2, nIn, 4 * nOut}, sd::DataType::FLOAT32); + NDArray Wr('c', {2, nOut, 4 * nOut}, sd::DataType::FLOAT32); + NDArray b('c', {2, 4 * nOut}, sd::DataType::FLOAT32); + NDArray hI('c', {2, bS, nOut}, sd::DataType::FLOAT32); + NDArray cI('c', {2, bS, nOut}, sd::DataType::FLOAT32); + NDArray Wp('c', {2, 3 * nOut}, sd::DataType::FLOAT32); + + x.linspace(0.5, 0.5); + Wx({0, 1, 0, 0, 0, 0}) = 0.003; + Wx({1, 2, 0, 0, 0, 0}) = -0.003; + Wr({0, 1, 0, 0, 0, 0}) = 0.006; + Wr({1, 2, 0, 0, 0, 0}) = -0.006; + b({0, 1, 0, 0}) = 0.5; + b({1, 2, 0, 0}) = -0.5; + hI({0, 1, 0, 0, 0, 0}) = 1; + hI({1, 2, 0, 0, 0, 0}) = -1; + cI({0, 1, 0, 0, 0, 0}) = 2; + cI({1, 2, 0, 0, 0, 0}) = -2; + Wp({0, 1, 0, 0}) = -0.05; + Wp({1, 2, 0, 0}) = 0.05; + + std::initializer_list tArgs = {cellClip}; + std::initializer_list iArgs = {dataFormat, directionMode, gateAct, + cellAct, outAct}; + std::initializer_list bArgs = {hasBiases, hasSeqLen, hasInitH, + hasInitC, hasPH, retFullSeq, + retLastH, retLastC}; + + NDArray expH( + 'c', {sL, bS, 2 * nOut}, + {0.55533f, 0.55533f, 0.55533f, -0.104502f, -0.104502f, -0.104502f, + 0.562925f, 0.562925f, 0.562925f, -0.103843f, -0.103843f, -0.103843f, + 0.531795f, 0.531795f, 0.531795f, -0.107456f, -0.107456f, -0.107456f, + 0.542556f, 0.542556f, 0.542556f, -0.106139f, -0.106139f, -0.106139f, + 0.521466f, 0.521466f, 0.521466f, -0.11681f, -0.11681f, -0.11681f, + 0.534638f, 0.534638f, 0.534638f, -0.11458f, -0.11458f, -0.11458f, + 0.524805f, 0.524805f, 0.524805f, -0.145177f, -0.145177f, -0.145177f, + 0.539187f, 0.539187f, 0.539187f, -0.14157f, -0.14157f, -0.14157f, + 0.538309f, 0.538309f, 0.538309f, -0.218056f, -0.218056f, -0.218056f, + 0.552923f, 0.552923f, 0.552923f, -0.213068f, -0.213068f, -0.213068f}, + sd::DataType::FLOAT32); + + NDArray expHL( + 'c', {2, bS, nOut}, + {0.538309f, 0.538309f, 0.538309f, 0.552923f, 0.552923f, 0.552923f, + -0.104502f, -0.104502f, -0.104502f, -0.103843f, -0.103843f, -0.103843f}, + sd::DataType::FLOAT32); + NDArray expCL( + 'c', {2, bS, nOut}, + {1.147089f, 1.147089f, 1.147089f, 1.197228f, 1.197228f, 1.197228f, + -0.289425f, -0.289425f, -0.289425f, -0.292174f, -0.292174f, -0.292174f}, + sd::DataType::FLOAT32); + + sd::ops::lstmLayer op; + auto results = + op.evaluate({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto h = results.at(0); + auto hL = results.at(1); + auto cL = results.at(2); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + + ASSERT_TRUE(expHL.isSameShape(hL)); + ASSERT_TRUE(expHL.equalsTo(hL)); + + ASSERT_TRUE(expCL.isSameShape(cL)); + ASSERT_TRUE(expCL.equalsTo(cL)); + +#endif } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, lstmLayer_10) { - #ifndef HAVE_MKLDNN - - const int sL = 6; - const int bS = 5; - const int nIn = 4; - const int nOut = 3; - - // input arguments - const int dataFormat = 0; // [sL,bS,nIn] - const int directionMode = 0; // forward - const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates - const int cellAct = 0; // tanh activation for cell state - const int outAct = 0; // tanh activation for output - - const bool hasBiases = true; // biases array is provided - const bool hasSeqLen = true; // seqLen array is not provided - const auto hasInitH = true; // initial output is provided - const auto hasInitC = true; // initial cell state is provided - const auto hasPH = true; // peephole connections are absent - const auto retFullSeq = true; // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut] - const auto retLastH = true; // do not return output at last time step - const auto retLastC = true; // return cells state at last time step - - const double cellClip = 0; // do not apply clipping - - NDArray x('c', {sL, bS, nIn}, sd::DataType::FLOAT32); - NDArray Wx('c', {nIn, 4*nOut}, sd::DataType::FLOAT32); - NDArray Wr('c', {nOut, 4*nOut}, sd::DataType::FLOAT32); - NDArray b('c', {4*nOut}, sd::DataType::FLOAT32); - NDArray hI('c', {bS, nOut}, sd::DataType::FLOAT32); - NDArray cI('c', {bS, nOut}, sd::DataType::FLOAT32); - NDArray seqLen('c', {bS}, {0,1,2,3,5}, sd::DataType::FLOAT32); - NDArray Wp('c', {3*nOut}, sd::DataType::FLOAT32); - - x.linspace(0.5, 0.5); - Wx = 0.003; - Wr = 0.006; - b = 0.5; - hI = 1.; - cI = 2.; - Wp = -0.05; - - std::initializer_list tArgs = {cellClip}; - std::initializer_list iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; - std::initializer_list bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; - - NDArray expH('c', {sL, bS, nOut}, { - 0.f, 0.f, 0.f, 0.562925f, 0.562925f, 0.562925f, 0.570404f, 0.570404f, 0.570404f, 0.57777f, - 0.57777f, 0.57777f, 0.585023f, 0.585023f, 0.585023f, 0.f, 0.f, 0.f, 0.f, 0.f, - 0.f, 0.576568f, 0.576568f, 0.576568f, 0.586163f, 0.586163f, 0.586163f, 0.595462f, 0.595462f, 0.595462f, - 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.611224f, - 0.611224f, 0.611224f, 0.621298f, 0.621298f, 0.621298f, 0.f, 0.f, 0.f, 0.f, 0.f, - 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.655858f, 0.655858f, 0.655858f, - 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, - 0.f, 0.f, 0.692315f, 0.692315f, 0.692315f, 0.f, 0.f, 0.f, 0.f, 0.f, - 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}, - sd::DataType::FLOAT32); - - NDArray expHL('c', {bS, nOut}, {0.f, 0.f, 0.f, 0.562925f, 0.562925f, 0.562925f, 0.576568f, 0.576568f, 0.576568f, 0.611224f, 0.611224f, 0.611224f, 0.692315f, 0.692315f, 0.692315f}, sd::DataType::FLOAT32); - NDArray expCL('c', {bS, nOut}, {0.f, 0.f, 0.f, 1.534275f, 1.534275f, 1.534275f, 1.40183f, 1.40183f, 1.40183f, 1.449675f, 1.449675f, 1.449675f, 1.767702f, 1.767702f, 1.767702f}, sd::DataType::FLOAT32); - - sd::ops::lstmLayer op; - auto results = op.evaluate({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto h = results.at(0); - auto hL = results.at(1); - auto cL = results.at(2); - - ASSERT_TRUE(expH.isSameShape(h)); - ASSERT_TRUE(expH.equalsTo(h)); - - ASSERT_TRUE(expHL.isSameShape(hL)); - ASSERT_TRUE(expHL.equalsTo(hL)); - - ASSERT_TRUE(expCL.isSameShape(cL)); - ASSERT_TRUE(expCL.equalsTo(cL)); - - - #endif +#ifndef HAVE_MKLDNN + + const int sL = 6; + const int bS = 5; + const int nIn = 4; + const int nOut = 3; + + // input arguments + const int dataFormat = 0; // [sL,bS,nIn] + const int directionMode = 0; // forward + const int gateAct = + 2; // sigmoid activation for input (i), forget (f) and output (o) gates + const int cellAct = 0; // tanh activation for cell state + const int outAct = 0; // tanh activation for output + + const bool hasBiases = true; // biases array is provided + const bool hasSeqLen = true; // seqLen array is not provided + const auto hasInitH = true; // initial output is provided + const auto hasInitC = true; // initial cell state is provided + const auto hasPH = true; // peephole connections are absent + const auto retFullSeq = + true; // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut] + const auto retLastH = true; // do not return output at last time step + const auto retLastC = true; // return cells state at last time step + + const double cellClip = 0; // do not apply clipping + + NDArray x('c', {sL, bS, nIn}, sd::DataType::FLOAT32); + NDArray Wx('c', {nIn, 4 * nOut}, sd::DataType::FLOAT32); + NDArray Wr('c', {nOut, 4 * nOut}, sd::DataType::FLOAT32); + NDArray b('c', {4 * nOut}, sd::DataType::FLOAT32); + NDArray hI('c', {bS, nOut}, sd::DataType::FLOAT32); + NDArray cI('c', {bS, nOut}, sd::DataType::FLOAT32); + NDArray seqLen('c', {bS}, {0, 1, 2, 3, 5}, sd::DataType::FLOAT32); + NDArray Wp('c', {3 * nOut}, sd::DataType::FLOAT32); + + x.linspace(0.5, 0.5); + Wx = 0.003; + Wr = 0.006; + b = 0.5; + hI = 1.; + cI = 2.; + Wp = -0.05; + + std::initializer_list tArgs = {cellClip}; + std::initializer_list iArgs = {dataFormat, directionMode, gateAct, + cellAct, outAct}; + std::initializer_list bArgs = {hasBiases, hasSeqLen, hasInitH, + hasInitC, hasPH, retFullSeq, + retLastH, retLastC}; + + NDArray expH( + 'c', {sL, bS, nOut}, + {0.f, 0.f, 0.f, 0.562925f, 0.562925f, 0.562925f, + 0.570404f, 0.570404f, 0.570404f, 0.57777f, 0.57777f, 0.57777f, + 0.585023f, 0.585023f, 0.585023f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.576568f, 0.576568f, 0.576568f, + 0.586163f, 0.586163f, 0.586163f, 0.595462f, 0.595462f, 0.595462f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.611224f, 0.611224f, 0.611224f, + 0.621298f, 0.621298f, 0.621298f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.655858f, 0.655858f, 0.655858f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.692315f, 0.692315f, 0.692315f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}, + sd::DataType::FLOAT32); + + NDArray expHL('c', {bS, nOut}, + {0.f, 0.f, 0.f, 0.562925f, 0.562925f, 0.562925f, 0.576568f, + 0.576568f, 0.576568f, 0.611224f, 0.611224f, 0.611224f, + 0.692315f, 0.692315f, 0.692315f}, + sd::DataType::FLOAT32); + NDArray expCL('c', {bS, nOut}, + {0.f, 0.f, 0.f, 1.534275f, 1.534275f, 1.534275f, 1.40183f, + 1.40183f, 1.40183f, 1.449675f, 1.449675f, 1.449675f, 1.767702f, + 1.767702f, 1.767702f}, + sd::DataType::FLOAT32); + + sd::ops::lstmLayer op; + auto results = op.evaluate({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, + iArgs, bArgs); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto h = results.at(0); + auto hL = results.at(1); + auto cL = results.at(2); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + + ASSERT_TRUE(expHL.isSameShape(hL)); + ASSERT_TRUE(expHL.equalsTo(hL)); + + ASSERT_TRUE(expCL.isSameShape(cL)); + ASSERT_TRUE(expCL.equalsTo(cL)); + +#endif } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, lstmLayer_11) { - #ifndef HAVE_MKLDNN - - const int sL = 6; - const int bS = 5; - const int nIn = 4; - const int nOut = 3; - - // input arguments - const int dataFormat = 0; // [sL,bS,nIn] - const int directionMode = 1; // backward - const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates - const int cellAct = 0; // tanh activation for cell state - const int outAct = 0; // tanh activation for output - - const bool hasBiases = true; // biases array is provided - const bool hasSeqLen = true; // seqLen array is not provided - const auto hasInitH = true; // initial output is provided - const auto hasInitC = true; // initial cell state is provided - const auto hasPH = true; // peephole connections are absent - const auto retFullSeq = true; // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut] - const auto retLastH = true; // do not return output at last time step - const auto retLastC = true; // return cells state at last time step - - const double cellClip = 0; // do not apply clipping - - NDArray x('c', {sL, bS, nIn}, sd::DataType::FLOAT32); - NDArray Wx('c', {nIn, 4*nOut}, sd::DataType::FLOAT32); - NDArray Wr('c', {nOut, 4*nOut}, sd::DataType::FLOAT32); - NDArray b('c', {4*nOut}, sd::DataType::FLOAT32); - NDArray hI('c', {bS, nOut}, sd::DataType::FLOAT32); - NDArray cI('c', {bS, nOut}, sd::DataType::FLOAT32); - NDArray seqLen('c', {bS}, {0,1,2,3,5}, sd::DataType::FLOAT32); - NDArray Wp('c', {3*nOut}, sd::DataType::FLOAT32); - - x.linspace(0.5, 0.5); - Wx = 0.003f; - Wr = 0.006f; - b = 0.5f; - hI = 1.f; - cI = 2.f; - Wp = -0.05f; - - std::initializer_list tArgs = {cellClip}; - std::initializer_list iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; - std::initializer_list bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; - - NDArray expH('c', {sL, bS, nOut}, { - 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.61209f, - 0.61209f, 0.61209f,0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.652042f, 0.652042f, 0.652042f, 0.f, 0.f, 0.f, 0.f, 0.f, - 0.f, 0.f, 0.f, 0.f, 0.677708f, 0.677708f, 0.677708f, 0.684177f, 0.684177f, 0.684177f, 0.f, 0.f, 0.f,0.f, 0.f, 0.f, 0.699627f, 0.699627f, - 0.699627f, 0.705371f, 0.705371f, 0.705371f, 0.710989f, 0.710989f, 0.710989f, 0., 0., 0., 0.719014, 0.719014, 0.719014, 0.724087, - 0.724087f, 0.724087f, 0.729084f, 0.729084f, 0.729084f, 0.734004f, 0.734004f, 0.734004f }, sd::DataType::FLOAT32); - - NDArray expHL('c', {bS, nOut}, {0.f, 0.f, 0.f, 0.719014f, 0.719014f, 0.719014f, 0.699627f, 0.699627f, 0.699627f, 0.677708f, 0.677708f, 0.677708f, 0.61209f, 0.61209f, 0.61209f}, sd::DataType::FLOAT32); - NDArray expCL('c', {bS, nOut}, {0.f, 0.f, 0.f, 2.092814f, 2.092814f, 2.092814f, 2.08832f, 2.08832f, 2.08832f, 2.009851f, 2.009851f, 2.009851f, 1.646034f, 1.646034f, 1.646034f}, sd::DataType::FLOAT32); +#ifndef HAVE_MKLDNN + + const int sL = 6; + const int bS = 5; + const int nIn = 4; + const int nOut = 3; + + // input arguments + const int dataFormat = 0; // [sL,bS,nIn] + const int directionMode = 1; // backward + const int gateAct = + 2; // sigmoid activation for input (i), forget (f) and output (o) gates + const int cellAct = 0; // tanh activation for cell state + const int outAct = 0; // tanh activation for output + + const bool hasBiases = true; // biases array is provided + const bool hasSeqLen = true; // seqLen array is not provided + const auto hasInitH = true; // initial output is provided + const auto hasInitC = true; // initial cell state is provided + const auto hasPH = true; // peephole connections are absent + const auto retFullSeq = + true; // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut] + const auto retLastH = true; // do not return output at last time step + const auto retLastC = true; // return cells state at last time step + + const double cellClip = 0; // do not apply clipping + + NDArray x('c', {sL, bS, nIn}, sd::DataType::FLOAT32); + NDArray Wx('c', {nIn, 4 * nOut}, sd::DataType::FLOAT32); + NDArray Wr('c', {nOut, 4 * nOut}, sd::DataType::FLOAT32); + NDArray b('c', {4 * nOut}, sd::DataType::FLOAT32); + NDArray hI('c', {bS, nOut}, sd::DataType::FLOAT32); + NDArray cI('c', {bS, nOut}, sd::DataType::FLOAT32); + NDArray seqLen('c', {bS}, {0, 1, 2, 3, 5}, sd::DataType::FLOAT32); + NDArray Wp('c', {3 * nOut}, sd::DataType::FLOAT32); + + x.linspace(0.5, 0.5); + Wx = 0.003f; + Wr = 0.006f; + b = 0.5f; + hI = 1.f; + cI = 2.f; + Wp = -0.05f; + + std::initializer_list tArgs = {cellClip}; + std::initializer_list iArgs = {dataFormat, directionMode, gateAct, + cellAct, outAct}; + std::initializer_list bArgs = {hasBiases, hasSeqLen, hasInitH, + hasInitC, hasPH, retFullSeq, + retLastH, retLastC}; + + NDArray expH( + 'c', {sL, bS, nOut}, + {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.61209f, 0.61209f, 0.61209f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.652042f, 0.652042f, 0.652042f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.677708f, 0.677708f, 0.677708f, 0.684177f, 0.684177f, 0.684177f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.699627f, 0.699627f, 0.699627f, 0.705371f, 0.705371f, 0.705371f, + 0.710989f, 0.710989f, 0.710989f, 0., 0., 0., + 0.719014, 0.719014, 0.719014, 0.724087, 0.724087f, 0.724087f, + 0.729084f, 0.729084f, 0.729084f, 0.734004f, 0.734004f, 0.734004f}, + sd::DataType::FLOAT32); + + NDArray expHL('c', {bS, nOut}, + {0.f, 0.f, 0.f, 0.719014f, 0.719014f, 0.719014f, 0.699627f, + 0.699627f, 0.699627f, 0.677708f, 0.677708f, 0.677708f, + 0.61209f, 0.61209f, 0.61209f}, + sd::DataType::FLOAT32); + NDArray expCL('c', {bS, nOut}, + {0.f, 0.f, 0.f, 2.092814f, 2.092814f, 2.092814f, 2.08832f, + 2.08832f, 2.08832f, 2.009851f, 2.009851f, 2.009851f, 1.646034f, + 1.646034f, 1.646034f}, + sd::DataType::FLOAT32); + + sd::ops::lstmLayer op; + auto results = op.evaluate({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, + iArgs, bArgs); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto h = results.at(0); + auto hL = results.at(1); + auto cL = results.at(2); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + + ASSERT_TRUE(expHL.isSameShape(hL)); + ASSERT_TRUE(expHL.equalsTo(hL)); + + ASSERT_TRUE(expCL.isSameShape(cL)); + ASSERT_TRUE(expCL.equalsTo(cL)); - sd::ops::lstmLayer op; - auto results = op.evaluate({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto h = results.at(0); - auto hL = results.at(1); - auto cL = results.at(2); - - ASSERT_TRUE(expH.isSameShape(h)); - ASSERT_TRUE(expH.equalsTo(h)); - - ASSERT_TRUE(expHL.isSameShape(hL)); - ASSERT_TRUE(expHL.equalsTo(hL)); - - ASSERT_TRUE(expCL.isSameShape(cL)); - ASSERT_TRUE(expCL.equalsTo(cL)); - - - #endif +#endif } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, lstmLayer_12) { - #ifndef HAVE_MKLDNN - - const int sL = 6; - const int bS = 5; - const int nIn = 4; - const int nOut = 3; - - // input arguments - const int dataFormat = 0; // [sL,bS,nIn] - const int directionMode = 3; // bidirectional concat - const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates - const int cellAct = 0; // tanh activation for cell state - const int outAct = 0; // tanh activation for output - - const bool hasBiases = true; // biases array is provided - const bool hasSeqLen = true; // seqLen array is not provided - const auto hasInitH = true; // initial output is provided - const auto hasInitC = true; // initial cell state is provided - const auto hasPH = true; // peephole connections are absent - const auto retFullSeq = true; // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut] - const auto retLastH = true; // do not return output at last time step - const auto retLastC = true; // return cells state at last time step - - const double cellClip = 0; // do not apply clipping - - NDArray x('c', {sL, bS, nIn}, sd::DataType::FLOAT32); - NDArray Wx('c', {2,nIn, 4*nOut}, sd::DataType::FLOAT32); - NDArray Wr('c', {2,nOut, 4*nOut}, sd::DataType::FLOAT32); - NDArray b('c', {2,4*nOut}, sd::DataType::FLOAT32); - NDArray hI('c', {2,bS, nOut}, sd::DataType::FLOAT32); - NDArray cI('c', {2,bS, nOut}, sd::DataType::FLOAT32); - NDArray seqLen('c', {bS}, {0,1,2,3,5}, sd::DataType::FLOAT32); - NDArray Wp('c', {2,3*nOut}, sd::DataType::FLOAT32); - - x.linspace(0.5, 0.5); - Wx({0,1, 0,0, 0,0}) = 0.003f; - Wx({1,2, 0,0, 0,0}) = -0.003f; - Wr({0,1, 0,0, 0,0}) = 0.006f; - Wr({1,2, 0,0, 0,0}) = -0.006f; - b({0,1, 0,0}) = 0.5f; - b({1,2, 0,0}) = -0.5f; - hI({0,1, 0,0, 0,0}) = 1; - hI({1,2, 0,0, 0,0}) = -1; - cI({0,1, 0,0, 0,0}) = 2; - cI({1,2, 0,0, 0,0}) = -2; - Wp({0,1, 0,0}) = -0.05f; - Wp({1,2, 0,0}) = 0.05f; - - std::initializer_list tArgs = {cellClip}; - std::initializer_list iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; - std::initializer_list bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; - - NDArray expH('c', {sL, bS, 2*nOut}, {0., 0., 0., 0., 0., 0., 0.562925, 0.562925, 0.562925, -0.25361 , -0.25361 , -0.25361 , 0.570404, 0.570404, 0.570404, -0.157103, - -0.157103, -0.157103, 0.57777 , 0.57777 , 0.57777 , -0.116502, -0.116502, -0.116502,0.585023, 0.585023, 0.585023, -0.100025, - -0.100025, -0.100025, 0., 0., 0., 0., 0., 0.,0., 0., 0., 0., 0., 0., 0.576568, 0.576568, 0.576568, -0.223072, -0.223072, -0.223072, - 0.586163, 0.586163, 0.586163, -0.135714, -0.135714, -0.135714,0.595462, 0.595462, 0.595462, -0.094438, -0.094438, -0.094438, - 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.611224, 0.611224, 0.611224, -0.193473, -0.193473, -0.193473, - 0.621298, 0.621298, 0.621298, -0.090626, -0.090626, -0.090626, 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., - 0., 0., 0., 0., 0., 0., 0.655858, 0.655858, 0.655858, -0.098015, -0.098015, -0.098015, 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., - 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.692315, 0.692315, 0.692315, -0.143704, -0.143704, -0.143704, 0., 0., 0., 0., 0., 0., - 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.}, sd::DataType::FLOAT32); - - NDArray expHL('c', {2,bS, nOut}, {0.f, 0.f, 0.f, 0.562925f, 0.562925f, 0.562925f, 0.576568f, 0.576568f, 0.576568f, 0.611224f, 0.611224f, 0.611224f, 0.692315f, 0.692315f, 0.692315f, - 0.f, 0.f, 0.f, -0.25361f, -0.25361f, -0.25361f, -0.157103f, -0.157103f, -0.157103f, -0.116502f, -0.116502f, -0.116502f, -0.100025f, -0.100025f, -0.100025f}, sd::DataType::FLOAT32); - NDArray expCL('c', {2,bS, nOut}, {0.f, 0.f, 0.f, 1.534275f, 1.534275f, 1.534275f, 1.40183f, 1.40183f, 1.40183f, 1.449675f, 1.449675f, 1.449675f, 1.767702f, 1.767702f, 1.767702f, - 0.f, 0.f, 0.f, -0.86636f, -0.86636f, -0.86636f, -0.470245f, -0.470245f, -0.470245f, -0.341856f, -0.341856f, -0.341856f, -0.294986f, -0.294986f, -0.294986f}, sd::DataType::FLOAT32); - - sd::ops::lstmLayer op; - auto results = op.evaluate({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto h = results.at(0); - auto hL = results.at(1); - auto cL = results.at(2); - - ASSERT_TRUE(expH.isSameShape(h)); - ASSERT_TRUE(expH.equalsTo(h)); - - ASSERT_TRUE(expHL.isSameShape(hL)); - ASSERT_TRUE(expHL.equalsTo(hL)); - - ASSERT_TRUE(expCL.isSameShape(cL)); - ASSERT_TRUE(expCL.equalsTo(cL)); - - #endif +#ifndef HAVE_MKLDNN + + const int sL = 6; + const int bS = 5; + const int nIn = 4; + const int nOut = 3; + + // input arguments + const int dataFormat = 0; // [sL,bS,nIn] + const int directionMode = 3; // bidirectional concat + const int gateAct = + 2; // sigmoid activation for input (i), forget (f) and output (o) gates + const int cellAct = 0; // tanh activation for cell state + const int outAct = 0; // tanh activation for output + + const bool hasBiases = true; // biases array is provided + const bool hasSeqLen = true; // seqLen array is not provided + const auto hasInitH = true; // initial output is provided + const auto hasInitC = true; // initial cell state is provided + const auto hasPH = true; // peephole connections are absent + const auto retFullSeq = + true; // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut] + const auto retLastH = true; // do not return output at last time step + const auto retLastC = true; // return cells state at last time step + + const double cellClip = 0; // do not apply clipping + + NDArray x('c', {sL, bS, nIn}, sd::DataType::FLOAT32); + NDArray Wx('c', {2, nIn, 4 * nOut}, sd::DataType::FLOAT32); + NDArray Wr('c', {2, nOut, 4 * nOut}, sd::DataType::FLOAT32); + NDArray b('c', {2, 4 * nOut}, sd::DataType::FLOAT32); + NDArray hI('c', {2, bS, nOut}, sd::DataType::FLOAT32); + NDArray cI('c', {2, bS, nOut}, sd::DataType::FLOAT32); + NDArray seqLen('c', {bS}, {0, 1, 2, 3, 5}, sd::DataType::FLOAT32); + NDArray Wp('c', {2, 3 * nOut}, sd::DataType::FLOAT32); + + x.linspace(0.5, 0.5); + Wx({0, 1, 0, 0, 0, 0}) = 0.003f; + Wx({1, 2, 0, 0, 0, 0}) = -0.003f; + Wr({0, 1, 0, 0, 0, 0}) = 0.006f; + Wr({1, 2, 0, 0, 0, 0}) = -0.006f; + b({0, 1, 0, 0}) = 0.5f; + b({1, 2, 0, 0}) = -0.5f; + hI({0, 1, 0, 0, 0, 0}) = 1; + hI({1, 2, 0, 0, 0, 0}) = -1; + cI({0, 1, 0, 0, 0, 0}) = 2; + cI({1, 2, 0, 0, 0, 0}) = -2; + Wp({0, 1, 0, 0}) = -0.05f; + Wp({1, 2, 0, 0}) = 0.05f; + + std::initializer_list tArgs = {cellClip}; + std::initializer_list iArgs = {dataFormat, directionMode, gateAct, + cellAct, outAct}; + std::initializer_list bArgs = {hasBiases, hasSeqLen, hasInitH, + hasInitC, hasPH, retFullSeq, + retLastH, retLastC}; + + NDArray expH('c', {sL, bS, 2 * nOut}, + {0., 0., 0., 0., 0., 0., + 0.562925, 0.562925, 0.562925, -0.25361, -0.25361, -0.25361, + 0.570404, 0.570404, 0.570404, -0.157103, -0.157103, -0.157103, + 0.57777, 0.57777, 0.57777, -0.116502, -0.116502, -0.116502, + 0.585023, 0.585023, 0.585023, -0.100025, -0.100025, -0.100025, + 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., + 0.576568, 0.576568, 0.576568, -0.223072, -0.223072, -0.223072, + 0.586163, 0.586163, 0.586163, -0.135714, -0.135714, -0.135714, + 0.595462, 0.595462, 0.595462, -0.094438, -0.094438, -0.094438, + 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., + 0.611224, 0.611224, 0.611224, -0.193473, -0.193473, -0.193473, + 0.621298, 0.621298, 0.621298, -0.090626, -0.090626, -0.090626, + 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., + 0.655858, 0.655858, 0.655858, -0.098015, -0.098015, -0.098015, + 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., + 0.692315, 0.692315, 0.692315, -0.143704, -0.143704, -0.143704, + 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0.}, + sd::DataType::FLOAT32); + + NDArray expHL( + 'c', {2, bS, nOut}, + {0.f, 0.f, 0.f, 0.562925f, 0.562925f, 0.562925f, + 0.576568f, 0.576568f, 0.576568f, 0.611224f, 0.611224f, 0.611224f, + 0.692315f, 0.692315f, 0.692315f, 0.f, 0.f, 0.f, + -0.25361f, -0.25361f, -0.25361f, -0.157103f, -0.157103f, -0.157103f, + -0.116502f, -0.116502f, -0.116502f, -0.100025f, -0.100025f, -0.100025f}, + sd::DataType::FLOAT32); + NDArray expCL( + 'c', {2, bS, nOut}, + {0.f, 0.f, 0.f, 1.534275f, 1.534275f, 1.534275f, + 1.40183f, 1.40183f, 1.40183f, 1.449675f, 1.449675f, 1.449675f, + 1.767702f, 1.767702f, 1.767702f, 0.f, 0.f, 0.f, + -0.86636f, -0.86636f, -0.86636f, -0.470245f, -0.470245f, -0.470245f, + -0.341856f, -0.341856f, -0.341856f, -0.294986f, -0.294986f, -0.294986f}, + sd::DataType::FLOAT32); + + sd::ops::lstmLayer op; + auto results = op.evaluate({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, + iArgs, bArgs); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto h = results.at(0); + auto hL = results.at(1); + auto cL = results.at(2); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + + ASSERT_TRUE(expHL.isSameShape(hL)); + ASSERT_TRUE(expHL.equalsTo(hL)); + + ASSERT_TRUE(expCL.isSameShape(cL)); + ASSERT_TRUE(expCL.equalsTo(cL)); + +#endif } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, batchnorm_test1) { + NDArray input('c', {2, 4}, sd::DataType::FLOAT32); + NDArray mean('c', {4}, {1.05f, 1.15f, 1.2f, 1.3f}, sd::DataType::FLOAT32); + NDArray variance('c', {4}, {0.5f, 0.7f, 0.9f, 1.1f}, sd::DataType::FLOAT32); + NDArray gamma('c', {4}, {-1.2f, 1.3f, -1.4f, 1.5f}, sd::DataType::FLOAT32); + NDArray beta('c', {4}, {10.f, 20.f, -10.f, -20.f}, sd::DataType::FLOAT32); - NDArray input ('c', {2,4}, sd::DataType::FLOAT32); - NDArray mean ('c', {4}, {1.05f, 1.15f, 1.2f, 1.3f}, sd::DataType::FLOAT32); - NDArray variance('c', {4}, {0.5f, 0.7f, 0.9f, 1.1f}, sd::DataType::FLOAT32); - NDArray gamma ('c', {4}, {-1.2f, 1.3f, -1.4f, 1.5f}, sd::DataType::FLOAT32); - NDArray beta ('c', {4}, {10.f, 20.f, -10.f, -20.f}, sd::DataType::FLOAT32); - - NDArray expected('c', {2,4}, {11.61218734f, 18.52390321f, -8.67185076f, -21.28716864f, 10.93337162f, 19.14541765f, -9.26213931f, -20.71509369f}, sd::DataType::FLOAT32); + NDArray expected('c', {2, 4}, + {11.61218734f, 18.52390321f, -8.67185076f, -21.28716864f, + 10.93337162f, 19.14541765f, -9.26213931f, -20.71509369f}, + sd::DataType::FLOAT32); - input.linspace(0.1, 0.1); + input.linspace(0.1, 0.1); - sd::ops::batchnorm op; + sd::ops::batchnorm op; - auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1}); + auto results = + op.evaluate({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1, 1}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto output = results.at(0); - // output->printBuffer(); - - ASSERT_TRUE(expected.isSameShapeStrict(*output)); - ASSERT_TRUE(expected.equalsTo(output)); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto output = results.at(0); + // output->printBuffer(); + ASSERT_TRUE(expected.isSameShapeStrict(output)); + ASSERT_TRUE(expected.equalsTo(output)); } //////////////////////////////////////////////////////////////////// TYPED_TEST(TypedDeclarableOpsTests13, batchnorm_test2) { + auto input = NDArrayFactory::create('c', {2, 3, 4}); + auto mean = NDArrayFactory::create('c', {4}); + auto variance = NDArrayFactory::create('c', {4}); + auto gamma = NDArrayFactory::create('c', {4}); + auto beta = NDArrayFactory::create('c', {4}); - auto input = NDArrayFactory::create('c', {2,3,4}); - auto mean = NDArrayFactory::create('c', {4}); - auto variance = NDArrayFactory::create('c', {4}); - auto gamma = NDArrayFactory::create('c', {4}); - auto beta = NDArrayFactory::create('c', {4}); - - auto expected = NDArrayFactory::create('c', {2,3,4}, {-0.52733537f, -0.35763144f, -0.18792751f, -0.01822358f, 0.15148035f, 0.32118428f, 0.49088821f, 0.66059214f, 0.83029607f, 1.f, 1.16970393f, 1.33940786f, - 1.50911179f, 1.67881572f, 1.84851965f, 2.01822358f, 2.18792751f, 2.35763144f, 2.52733537f, 2.6970393f, 2.86674323f, 3.03644717f, 3.2061511f, 3.37585503f}); + auto expected = NDArrayFactory::create( + 'c', {2, 3, 4}, + {-0.52733537f, -0.35763144f, -0.18792751f, -0.01822358f, 0.15148035f, + 0.32118428f, 0.49088821f, 0.66059214f, 0.83029607f, 1.f, + 1.16970393f, 1.33940786f, 1.50911179f, 1.67881572f, 1.84851965f, + 2.01822358f, 2.18792751f, 2.35763144f, 2.52733537f, 2.6970393f, + 2.86674323f, 3.03644717f, 3.2061511f, 3.37585503f}); - input.linspace(0.1, 0.1); - mean.assign(1.); - variance.assign(0.5); - gamma.assign(1.2); - beta.assign(1.); + input.linspace(0.1, 0.1); + mean.assign(1.); + variance.assign(0.5); + gamma.assign(1.2); + beta.assign(1.); - sd::ops::batchnorm op; + sd::ops::batchnorm op; - auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1}); + auto results = + op.evaluate({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1, 1}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto output = results.at(0); - // output->printBuffer(); - - ASSERT_TRUE(expected.isSameShapeStrict(*output)); - ASSERT_TRUE(expected.equalsTo(output)); + auto output = results.at(0); + // output->printBuffer(); + ASSERT_TRUE(expected.isSameShapeStrict(output)); + ASSERT_TRUE(expected.equalsTo(output)); } //////////////////////////////////////////////////////////////////// TYPED_TEST(TypedDeclarableOpsTests13, batchnorm_test3) { + auto input = NDArrayFactory::create('c', {2, 3, 4}); + auto mean = NDArrayFactory::create('c', {3}, {1.05f, 1.1f, 1.15f}); + auto variance = + NDArrayFactory::create('c', {3}, {0.5f, 0.6f, 0.7f}); + auto gamma = NDArrayFactory::create('c', {3}, {1.2f, 1.3f, 1.4f}); + auto beta = NDArrayFactory::create('c', {3}, {0.1f, 0.2f, 0.3f}); - auto input = NDArrayFactory::create('c', {2,3,4}); - auto mean = NDArrayFactory::create('c', {3}, {1.05f, 1.1f, 1.15f}); - auto variance = NDArrayFactory::create('c', {3}, {0.5f, 0.6f, 0.7f}); - auto gamma = NDArrayFactory::create('c', {3}, {1.2f, 1.3f, 1.4f}); - auto beta = NDArrayFactory::create('c', {3}, {0.1f, 0.2f, 0.3f}); - - auto expected = NDArrayFactory::create('c', {2,3,4}, {-1.51218734f, -1.34248341f, -1.17277948f, -1.00307555f, -0.80696728f, -0.6391394f, -0.47131152f, -0.30348364f, -0.11832703f, 0.04900378f, 0.21633459f, 0.38366541f, - 0.52425983f, 0.69396376f, 0.86366769f, 1.03337162f, 1.20696728f, 1.37479516f, 1.54262304f, 1.71045092f, 1.8896427f, 2.05697351f, 2.22430432f, 2.39163513f}); - - input.linspace(0.1, 0.1); + auto expected = NDArrayFactory::create( + 'c', {2, 3, 4}, + {-1.51218734f, -1.34248341f, -1.17277948f, -1.00307555f, -0.80696728f, + -0.6391394f, -0.47131152f, -0.30348364f, -0.11832703f, 0.04900378f, + 0.21633459f, 0.38366541f, 0.52425983f, 0.69396376f, 0.86366769f, + 1.03337162f, 1.20696728f, 1.37479516f, 1.54262304f, 1.71045092f, + 1.8896427f, 2.05697351f, 2.22430432f, 2.39163513f}); - sd::ops::batchnorm op; + input.linspace(0.1, 0.1); - auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1,1}); + sd::ops::batchnorm op; - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto results = + op.evaluate({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1, 1, 1}); - auto output = results.at(0); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(expected.isSameShapeStrict(*output)); - ASSERT_TRUE(expected.equalsTo(output)); + auto output = results.at(0); + ASSERT_TRUE(expected.isSameShapeStrict(output)); + ASSERT_TRUE(expected.equalsTo(output)); } //////////////////////////////////////////////////////////////////// TYPED_TEST(TypedDeclarableOpsTests13, batchnorm_test4) { + auto input = NDArrayFactory::create('c', {2, 3, 4}); + auto mean = NDArrayFactory::create( + 'c', {2, 1, 4}, {1.05f, 1.1f, 1.15f, 1.2f, 1.25f, 1.3f, 1.35f, 1.4f}); + auto variance = NDArrayFactory::create( + 'c', {2, 1, 4}, {0.5f, 0.6f, 0.7f, 0.8f, 0.9f, 1.f, 1.1f, 1.2f}); + auto gamma = NDArrayFactory::create( + 'c', {2, 1, 4}, {1.2f, 1.3f, 1.4f, 1.5f, 1.6f, 1.7f, 1.8f, 1.9f}); + auto beta = NDArrayFactory::create( + 'c', {2, 1, 4}, {0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.66f, 0.7f, 0.8f}); - auto input = NDArrayFactory::create('c', {2,3,4}); - auto mean = NDArrayFactory::create('c', {2,1,4}, {1.05f, 1.1f, 1.15f, 1.2f, 1.25f, 1.3f, 1.35f, 1.4f}); - auto variance = NDArrayFactory::create('c', {2,1,4}, {0.5f, 0.6f, 0.7f, 0.8f, 0.9f, 1.f, 1.1f, 1.2f}); - auto gamma = NDArrayFactory::create('c', {2,1,4}, {1.2f, 1.3f, 1.4f, 1.5f, 1.6f, 1.7f, 1.8f, 1.9f}); - auto beta = NDArrayFactory::create('c', {2,1,4}, {0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.66f, 0.7f, 0.8f}); + auto expected = NDArrayFactory::create( + 'c', {2, 3, 4}, + {-1.51218734f, -1.31045092f, -1.12231189f, -0.9416324f, -0.83337162f, + -0.6391394f, -0.45298865f, -0.2708162f, -0.1545559f, 0.03217212f, + 0.21633459f, 0.4f, 0.58432694f, 0.82999915f, 0.95743373f, + 1.14688951f, 1.25894242f, 1.50999575f, 1.64392367f, 1.84066852f, + 1.93355791f, 2.18999235f, 2.33041362f, 2.53444754f}); - auto expected = NDArrayFactory::create('c', {2,3,4}, {-1.51218734f, -1.31045092f, -1.12231189f, -0.9416324f, -0.83337162f, -0.6391394f, -0.45298865f, -0.2708162f, -0.1545559f, 0.03217212f, 0.21633459f, 0.4f, - 0.58432694f, 0.82999915f, 0.95743373f, 1.14688951f, 1.25894242f, 1.50999575f, 1.64392367f, 1.84066852f, 1.93355791f, 2.18999235f, 2.33041362f, 2.53444754f}); + input.linspace(0.1, 0.1); - input.linspace(0.1, 0.1); + sd::ops::batchnorm op; - sd::ops::batchnorm op; + auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta}, {1e-5}, + {1, 1, 0, 2}); - auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1,0,2}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto output = results.at(0); - - ASSERT_TRUE(expected.isSameShapeStrict(*output)); - ASSERT_TRUE(expected.equalsTo(output)); + auto output = results.at(0); + ASSERT_TRUE(expected.isSameShapeStrict(output)); + ASSERT_TRUE(expected.equalsTo(output)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, batchnorm_test5) { + NDArray input('c', {2, 4, 2, 2}, sd::DataType::FLOAT32); + NDArray mean('c', {4}, {1.05f, 1.15f, 1.2f, 1.3f}, sd::DataType::FLOAT32); + NDArray variance('c', {4}, {0.5f, 0.7f, 0.9f, 1.1f}, sd::DataType::FLOAT32); + NDArray gamma('c', {4}, {-1.2f, 1.3f, -1.4f, 1.5f}, sd::DataType::FLOAT32); + NDArray beta('c', {4}, {10.f, 20.f, -10.f, -20.f}, sd::DataType::FLOAT32); - NDArray input ('c', {2,4,2,2}, sd::DataType::FLOAT32); - NDArray mean ('c', {4}, {1.05f, 1.15f, 1.2f, 1.3f}, sd::DataType::FLOAT32); - NDArray variance('c', {4}, {0.5f, 0.7f, 0.9f, 1.1f}, sd::DataType::FLOAT32); - NDArray gamma ('c', {4}, {-1.2f, 1.3f, -1.4f, 1.5f}, sd::DataType::FLOAT32); - NDArray beta ('c', {4}, {10.f, 20.f, -10.f, -20.f}, sd::DataType::FLOAT32); - - NDArray expected('c', {2,4,2,2}, { 11.612187f, 11.442483f, 11.272779f, 11.103076f, 18.990039f, 19.145418f, 19.300796f, 19.456175f, -9.557284f, -9.704856f, -9.852428f, -10.f, -20.f, - -19.856981f, -19.713963f, -19.570944f, 8.896924f, 8.727221f, 8.557517f, 8.387813f, 21.476097f, 21.631475f, 21.786854f, 21.942233f, -11.918438f, - -12.06601f, -12.213582f, -12.361154f, -17.7117f, -17.568681f, -17.425663f, -17.282644f}, sd::DataType::FLOAT32); - input.linspace(0.1, 0.1); + NDArray expected( + 'c', {2, 4, 2, 2}, + {11.612187f, 11.442483f, 11.272779f, 11.103076f, 18.990039f, + 19.145418f, 19.300796f, 19.456175f, -9.557284f, -9.704856f, + -9.852428f, -10.f, -20.f, -19.856981f, -19.713963f, + -19.570944f, 8.896924f, 8.727221f, 8.557517f, 8.387813f, + 21.476097f, 21.631475f, 21.786854f, 21.942233f, -11.918438f, + -12.06601f, -12.213582f, -12.361154f, -17.7117f, -17.568681f, + -17.425663f, -17.282644f}, + sd::DataType::FLOAT32); + input.linspace(0.1, 0.1); - sd::ops::batchnorm op; + sd::ops::batchnorm op; - auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1, 1, 1}); + auto results = + op.evaluate({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1, 1, 1}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto output = results.at(0); - // output->printBuffer(); - - ASSERT_TRUE(expected.isSameShapeStrict(*output)); - ASSERT_TRUE(expected.equalsTo(output)); + auto output = results.at(0); + // output->printBuffer(); + ASSERT_TRUE(expected.isSameShapeStrict(output)); + ASSERT_TRUE(expected.equalsTo(output)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, batchnorm_test6) { + NDArray input('c', {2, 2, 2, 4}, sd::DataType::FLOAT32); + NDArray mean('c', {4}, {1.05f, 1.15f, 1.2f, 1.3f}, sd::DataType::FLOAT32); + NDArray variance('c', {4}, {0.5f, 0.7f, 0.9, 1.1f}, sd::DataType::FLOAT32); + NDArray gamma('c', {4}, {-1.2f, 1.3f, -1.4f, 1.5f}, sd::DataType::FLOAT32); + NDArray beta('c', {4}, {10.f, 20.f, -10.f, -20.f}, sd::DataType::FLOAT32); - NDArray input ('c', {2,2,2,4}, sd::DataType::FLOAT32); - NDArray mean ('c', {4}, {1.05f, 1.15f, 1.2f, 1.3f}, sd::DataType::FLOAT32); - NDArray variance('c', {4}, {0.5f, 0.7f, 0.9, 1.1f}, sd::DataType::FLOAT32); - NDArray gamma ('c', {4}, {-1.2f, 1.3f, -1.4f, 1.5f}, sd::DataType::FLOAT32); - NDArray beta ('c', {4}, {10.f, 20.f, -10.f, -20.f}, sd::DataType::FLOAT32); - - NDArray expected('c', {2,2,2,4}, {11.612187f, 18.523903f, -8.671851f, -21.287169f, 10.933372f, 19.145418f, -9.262139f, -20.715094f, 10.254556f, 19.766932f, -9.852428f, -20.143019f, 9.57574f, - 20.388447f, -10.442716f, -19.570944f, 8.896924f, 21.009961f, -11.033005f, -18.998869f, 8.218109f, 21.631475f, -11.623294f, -18.426794f, 7.539293f, 22.25299f, - -12.213582f, -17.854719f, 6.860477f, 22.874504f, -12.803871f, -17.282644f}, sd::DataType::FLOAT32); - input.linspace(0.1, 0.1); + NDArray expected( + 'c', {2, 2, 2, 4}, + {11.612187f, 18.523903f, -8.671851f, -21.287169f, 10.933372f, + 19.145418f, -9.262139f, -20.715094f, 10.254556f, 19.766932f, + -9.852428f, -20.143019f, 9.57574f, 20.388447f, -10.442716f, + -19.570944f, 8.896924f, 21.009961f, -11.033005f, -18.998869f, + 8.218109f, 21.631475f, -11.623294f, -18.426794f, 7.539293f, + 22.25299f, -12.213582f, -17.854719f, 6.860477f, 22.874504f, + -12.803871f, -17.282644f}, + sd::DataType::FLOAT32); + input.linspace(0.1, 0.1); - sd::ops::batchnorm op; + sd::ops::batchnorm op; - auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1,3}); + auto results = + op.evaluate({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1, 1, 3}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto output = results.at(0); - - ASSERT_TRUE(expected.isSameShapeStrict(*output)); - ASSERT_TRUE(expected.equalsTo(output)); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto output = results.at(0); + ASSERT_TRUE(expected.isSameShapeStrict(output)); + ASSERT_TRUE(expected.equalsTo(output)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, batchnorm_test7) { + NDArray input1('c', {3, 3, 15, 15}, sd::DataType::FLOAT32); + NDArray input2('c', {3, 15, 15, 3}, sd::DataType::FLOAT32); + input2.permutei({0, 3, 1, 2}); - NDArray input1('c', {3,3,15,15}, sd::DataType::FLOAT32); - NDArray input2('c', {3,15,15,3}, sd::DataType::FLOAT32); - input2.permutei({0,3,1,2}); - - NDArray mean ('c', {3}, {0., 0, 0}, sd::DataType::FLOAT32); - NDArray variance('c', {3}, {1., 1, 1}, sd::DataType::FLOAT32); - NDArray gamma ('c', {3}, {1., 1, 1}, sd::DataType::FLOAT32); - NDArray beta ('c', {3}, {0., 0, 0}, sd::DataType::FLOAT32); + NDArray mean('c', {3}, {0., 0, 0}, sd::DataType::FLOAT32); + NDArray variance('c', {3}, {1., 1, 1}, sd::DataType::FLOAT32); + NDArray gamma('c', {3}, {1., 1, 1}, sd::DataType::FLOAT32); + NDArray beta('c', {3}, {0., 0, 0}, sd::DataType::FLOAT32); - NDArray out1('c', {3,3,15,15}, sd::DataType::FLOAT32); - NDArray out2('c', {3,3,15,15}, sd::DataType::FLOAT32); + NDArray out1('c', {3, 3, 15, 15}, sd::DataType::FLOAT32); + NDArray out2('c', {3, 3, 15, 15}, sd::DataType::FLOAT32); - input1.linspace(-1012, 1); - input2.assign(input1); + input1.linspace(-1012, 1); + input2.assign(input1); - sd::ops::batchnorm op; + sd::ops::batchnorm op; - auto res1 = op.execute({&input1, &mean, &variance, &gamma, &beta}, {&out1}, {1e-5}, {1,1,1}, {}); - ASSERT_EQ(ND4J_STATUS_OK, res1); + auto res1 = op.execute({&input1, &mean, &variance, &gamma, &beta}, {&out1}, + {1e-5}, {1, 1, 1}, {}); + ASSERT_EQ(ND4J_STATUS_OK, res1); - auto res2 = op.execute({&input2, &mean, &variance, &gamma, &beta}, {&out2}, {1e-5}, {1,1,1}, {}); - ASSERT_EQ(ND4J_STATUS_OK, res2); + auto res2 = op.execute({&input2, &mean, &variance, &gamma, &beta}, {&out2}, + {1e-5}, {1, 1, 1}, {}); + ASSERT_EQ(ND4J_STATUS_OK, res2); - ASSERT_TRUE(out1.equalsTo(out2)); + ASSERT_TRUE(out1.equalsTo(out2)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, batchnorm_test8) { - - NDArray input('c', {2,3,4,5}, sd::DataType::FLOAT32); - - NDArray mean ('c', {1,3,4,5}, sd::DataType::FLOAT32); - NDArray variance('c', {1,3,4,5}, sd::DataType::FLOAT32); - NDArray gamma ('c', {1,3,4,5}, sd::DataType::FLOAT32); - NDArray beta ('c', {1,3,4,5}, sd::DataType::FLOAT32); - - NDArray expected('c', {2,3,4,5}, {-105.019394, -103.322357, -101.625313, -99.928276, -98.231239, -96.534195, -94.837158, -93.140121, -91.443077, -89.746040, -88.049004, -86.351959, -84.654922, - -82.957886, -81.260841, -79.563805, -77.866768, -76.169724, -74.472687, -72.775650, -71.078606, -69.381569, -67.684532, -65.987488, -64.290451, -62.593414, - -60.896374, -59.199333, -57.502296, -55.805256, -54.108215, -52.411179, -50.714138, -49.017097, -47.320061, -45.623020, -43.925980, -42.228943, -40.531902, - -38.834862, -37.137825, -35.440784, -33.743744, -32.046707, -30.349667, -28.652628, -26.955589, -25.258549, -23.561510, -21.864471, -20.167431, -18.470392, - -16.773354, -15.076314, -13.379274, -11.682236, -9.985196, -8.288157, -6.591118, -4.894078, -3.197039, -1.500000, 0.197039, 1.894078, 3.591118, 5.288157, - 6.985196, 8.682236, 10.379274, 12.076314, 13.773354, 15.470392, 17.167431, 18.864471, 20.561510, 22.258549, 23.955589, 25.652628, 27.349667, 29.046707, 30.743744, - 32.440784, 34.137825, 35.834862, 37.531902, 39.228943, 40.925980, 42.623020, 44.320061, 46.017097, 47.714138, 49.411179, 51.108215, 52.805256, 54.502296, 56.199333, - 57.896374, 59.593414, 61.290451, 62.987488, 64.684532, 66.381569, 68.078606, 69.775650, 71.472687, 73.169724, 74.866768, 76.563805, 78.260841, 79.957886, 81.654922, - 83.351959, 85.049004, 86.746040, 88.443077, 90.140121, 91.837158, 93.534195, 95.231239, 96.928276}, sd::DataType::FLOAT32); - - input.linspace(-60, 1); - mean.assign(1.); - variance.assign(0.5); - gamma.assign(1.2); - beta.assign(-1.5); - - sd::ops::batchnorm op; - - auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1, 1,2,3}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto output = results.at(0); - - ASSERT_TRUE(expected.isSameShape(*output)); - ASSERT_TRUE(expected.equalsTo(output)); - + NDArray input('c', {2, 3, 4, 5}, sd::DataType::FLOAT32); + + NDArray mean('c', {1, 3, 4, 5}, sd::DataType::FLOAT32); + NDArray variance('c', {1, 3, 4, 5}, sd::DataType::FLOAT32); + NDArray gamma('c', {1, 3, 4, 5}, sd::DataType::FLOAT32); + NDArray beta('c', {1, 3, 4, 5}, sd::DataType::FLOAT32); + + NDArray expected( + 'c', {2, 3, 4, 5}, + {-105.019394, -103.322357, -101.625313, -99.928276, -98.231239, + -96.534195, -94.837158, -93.140121, -91.443077, -89.746040, + -88.049004, -86.351959, -84.654922, -82.957886, -81.260841, + -79.563805, -77.866768, -76.169724, -74.472687, -72.775650, + -71.078606, -69.381569, -67.684532, -65.987488, -64.290451, + -62.593414, -60.896374, -59.199333, -57.502296, -55.805256, + -54.108215, -52.411179, -50.714138, -49.017097, -47.320061, + -45.623020, -43.925980, -42.228943, -40.531902, -38.834862, + -37.137825, -35.440784, -33.743744, -32.046707, -30.349667, + -28.652628, -26.955589, -25.258549, -23.561510, -21.864471, + -20.167431, -18.470392, -16.773354, -15.076314, -13.379274, + -11.682236, -9.985196, -8.288157, -6.591118, -4.894078, + -3.197039, -1.500000, 0.197039, 1.894078, 3.591118, + 5.288157, 6.985196, 8.682236, 10.379274, 12.076314, + 13.773354, 15.470392, 17.167431, 18.864471, 20.561510, + 22.258549, 23.955589, 25.652628, 27.349667, 29.046707, + 30.743744, 32.440784, 34.137825, 35.834862, 37.531902, + 39.228943, 40.925980, 42.623020, 44.320061, 46.017097, + 47.714138, 49.411179, 51.108215, 52.805256, 54.502296, + 56.199333, 57.896374, 59.593414, 61.290451, 62.987488, + 64.684532, 66.381569, 68.078606, 69.775650, 71.472687, + 73.169724, 74.866768, 76.563805, 78.260841, 79.957886, + 81.654922, 83.351959, 85.049004, 86.746040, 88.443077, + 90.140121, 91.837158, 93.534195, 95.231239, 96.928276}, + sd::DataType::FLOAT32); + + input.linspace(-60, 1); + mean.assign(1.); + variance.assign(0.5); + gamma.assign(1.2); + beta.assign(-1.5); + + sd::ops::batchnorm op; + + auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta}, {1e-5}, + {1, 1, 1, 2, 3}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto output = results.at(0); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, batchnorm_test9) { - - NDArray input('c', {2,3,3,3,3}, sd::DataType::FLOAT32); - - NDArray mean ('c', {1,3,3,3,3}, sd::DataType::FLOAT32); - NDArray variance('c', {1,3,3,3,3}, sd::DataType::FLOAT32); - NDArray gamma ('c', {1,3,3,3,3}, sd::DataType::FLOAT32); - NDArray beta ('c', {1,3,3,3,3}, sd::DataType::FLOAT32); - - NDArray expected('c', {2,3,3,3,3}, {-138.960175, -137.263138, -135.566101, -133.869064, -132.172028, -130.474976, -128.777954, -127.080902, -125.383865, -123.686829, -121.989784, -120.292747, - -118.595711, -116.898666, -115.201630, -113.504593, -111.807549, -110.110512, -108.413475, -106.716431, -105.019394, -103.322357, -101.625313, -99.928276, - -98.231239, -96.534195, -94.837158, -93.140121, -91.443077, -89.746040, -88.049004, -86.351959, -84.654922, -82.957886, -81.260841, -79.563805, -77.866768, - -76.169724, -74.472687, -72.775650, -71.078606, -69.381569, -67.684532, -65.987488, -64.290451, -62.593414, -60.896374, -59.199333, -57.502296, -55.805256, - -54.108215, -52.411179, -50.714138, -49.017097, -47.320061, -45.623020, -43.925980, -42.228943, -40.531902, -38.834862, -37.137825, -35.440784, -33.743744, - -32.046707, -30.349667, -28.652628, -26.955589, -25.258549, -23.561510, -21.864471, -20.167431, -18.470392, -16.773354, -15.076314, -13.379274, -11.682236, - -9.985196, -8.288157, -6.591118, -4.894078, -3.197039, -1.500000, 0.197039, 1.894078, 3.591118, 5.288157, 6.985196, 8.682236, 10.379274, 12.076314, 13.773354, - 15.470392, 17.167431, 18.864471, 20.561510, 22.258549, 23.955589, 25.652628, 27.349667, 29.046707, 30.743744, 32.440784, 34.137825, 35.834862, 37.531902, 39.228943, - 40.925980, 42.623020, 44.320061, 46.017097, 47.714138, 49.411179, 51.108215, 52.805256, 54.502296, 56.199333, 57.896374, 59.593414, 61.290451, 62.987488, 64.684532, - 66.381569, 68.078606, 69.775650, 71.472687, 73.169724, 74.866768, 76.563805, 78.260841, 79.957886, 81.654922, 83.351959, 85.049004, 86.746040, 88.443077, 90.140121, - 91.837158, 93.534195, 95.231239, 96.928276, 98.625313, 100.322357, 102.019394, 103.716431, 105.413475, 107.110512, 108.807549, 110.504593, 112.201630, 113.898666, - 115.595711, 117.292747, 118.989784, 120.686829, 122.383865, 124.080902, 125.777946, 127.474976, 129.172028, 130.869064, 132.566101, 134.263138}, sd::DataType::FLOAT32); - - input.linspace(-80, 1); - mean.assign(1.); - variance.assign(0.5); - gamma.assign(1.2); - beta.assign(-1.5); - - sd::ops::batchnorm op; - - auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1, 1,2,3,4}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto output = results.at(0); - // output->printBuffer(); - - ASSERT_TRUE(expected.isSameShape(*output)); - ASSERT_TRUE(expected.equalsTo(output)); - - + NDArray input('c', {2, 3, 3, 3, 3}, sd::DataType::FLOAT32); + + NDArray mean('c', {1, 3, 3, 3, 3}, sd::DataType::FLOAT32); + NDArray variance('c', {1, 3, 3, 3, 3}, sd::DataType::FLOAT32); + NDArray gamma('c', {1, 3, 3, 3, 3}, sd::DataType::FLOAT32); + NDArray beta('c', {1, 3, 3, 3, 3}, sd::DataType::FLOAT32); + + NDArray expected( + 'c', {2, 3, 3, 3, 3}, + {-138.960175, -137.263138, -135.566101, -133.869064, -132.172028, + -130.474976, -128.777954, -127.080902, -125.383865, -123.686829, + -121.989784, -120.292747, -118.595711, -116.898666, -115.201630, + -113.504593, -111.807549, -110.110512, -108.413475, -106.716431, + -105.019394, -103.322357, -101.625313, -99.928276, -98.231239, + -96.534195, -94.837158, -93.140121, -91.443077, -89.746040, + -88.049004, -86.351959, -84.654922, -82.957886, -81.260841, + -79.563805, -77.866768, -76.169724, -74.472687, -72.775650, + -71.078606, -69.381569, -67.684532, -65.987488, -64.290451, + -62.593414, -60.896374, -59.199333, -57.502296, -55.805256, + -54.108215, -52.411179, -50.714138, -49.017097, -47.320061, + -45.623020, -43.925980, -42.228943, -40.531902, -38.834862, + -37.137825, -35.440784, -33.743744, -32.046707, -30.349667, + -28.652628, -26.955589, -25.258549, -23.561510, -21.864471, + -20.167431, -18.470392, -16.773354, -15.076314, -13.379274, + -11.682236, -9.985196, -8.288157, -6.591118, -4.894078, + -3.197039, -1.500000, 0.197039, 1.894078, 3.591118, + 5.288157, 6.985196, 8.682236, 10.379274, 12.076314, + 13.773354, 15.470392, 17.167431, 18.864471, 20.561510, + 22.258549, 23.955589, 25.652628, 27.349667, 29.046707, + 30.743744, 32.440784, 34.137825, 35.834862, 37.531902, + 39.228943, 40.925980, 42.623020, 44.320061, 46.017097, + 47.714138, 49.411179, 51.108215, 52.805256, 54.502296, + 56.199333, 57.896374, 59.593414, 61.290451, 62.987488, + 64.684532, 66.381569, 68.078606, 69.775650, 71.472687, + 73.169724, 74.866768, 76.563805, 78.260841, 79.957886, + 81.654922, 83.351959, 85.049004, 86.746040, 88.443077, + 90.140121, 91.837158, 93.534195, 95.231239, 96.928276, + 98.625313, 100.322357, 102.019394, 103.716431, 105.413475, + 107.110512, 108.807549, 110.504593, 112.201630, 113.898666, + 115.595711, 117.292747, 118.989784, 120.686829, 122.383865, + 124.080902, 125.777946, 127.474976, 129.172028, 130.869064, + 132.566101, 134.263138}, + sd::DataType::FLOAT32); + + input.linspace(-80, 1); + mean.assign(1.); + variance.assign(0.5); + gamma.assign(1.2); + beta.assign(-1.5); + + sd::ops::batchnorm op; + + auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta}, {1e-5}, + {1, 1, 1, 2, 3, 4}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto output = results.at(0); + // output->printBuffer(); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, batchnorm_bp_test1) { + NDArray input('c', {2, 3, 4}, sd::DataType::FLOAT32); + NDArray mean('c', {4}, {1.1, 1.2, 1.3, 1.4}, sd::DataType::FLOAT32); + NDArray variance('c', {4}, sd::DataType::FLOAT32); + NDArray gamma('c', {4}, sd::DataType::FLOAT32); + NDArray beta('c', {4}, sd::DataType::FLOAT32); + NDArray gradO('c', {2, 3, 4}, sd::DataType::FLOAT32); - NDArray input ('c', {2,3,4}, sd::DataType::FLOAT32); - NDArray mean ('c', {4}, {1.1, 1.2, 1.3, 1.4}, sd::DataType::FLOAT32); - NDArray variance('c', {4}, sd::DataType::FLOAT32); - NDArray gamma ('c', {4}, sd::DataType::FLOAT32); - NDArray beta ('c', {4}, sd::DataType::FLOAT32); - NDArray gradO ('c', {2,3,4}, sd::DataType::FLOAT32); - - NDArray expdLdI('c', {2,3,4}, {-0.000056, -0.000056, -0.000056, -0.000056, -0.000034, -0.000034, -0.000034, -0.000034, -0.000011, -0.000011, -0.000011, -0.000011, 0.000011, 0.000011, 0.000011, 0.000011, 0.000034, 0.000034, 0.000034, 0.000034, 0.000056, 0.000056, 0.000056, 0.000056}, sd::DataType::FLOAT32); - NDArray expdLdG('c', {4}, {6.148104, 6.148104, 6.148105, 6.148105}, sd::DataType::FLOAT32); - NDArray expdLdB('c', {4}, {3.6, 4.5, 5.4, 6.3}, sd::DataType::FLOAT32); + NDArray expdLdI( + 'c', {2, 3, 4}, + {-0.000056, -0.000056, -0.000056, -0.000056, -0.000034, -0.000034, + -0.000034, -0.000034, -0.000011, -0.000011, -0.000011, -0.000011, + 0.000011, 0.000011, 0.000011, 0.000011, 0.000034, 0.000034, + 0.000034, 0.000034, 0.000056, 0.000056, 0.000056, 0.000056}, + sd::DataType::FLOAT32); + NDArray expdLdG('c', {4}, {6.148104, 6.148104, 6.148105, 6.148105}, + sd::DataType::FLOAT32); + NDArray expdLdB('c', {4}, {3.6, 4.5, 5.4, 6.3}, sd::DataType::FLOAT32); - input.linspace(0.1, 0.1); - variance.assign(0.46666667); - gamma.assign(1.2); - beta.assign(1.); // has no effect on gradient calculations - gradO.linspace(-0.9, 0.15); + input.linspace(0.1, 0.1); + variance.assign(0.46666667); + gamma.assign(1.2); + beta.assign(1.); // has no effect on gradient calculations + gradO.linspace(-0.9, 0.15); - sd::ops::batchnorm_bp op; + sd::ops::batchnorm_bp op; - auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1}); + auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta, &gradO}, + {1e-5}, {1, 1}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto dLdI = results.at(0); - auto dLdG = results.at(3); - auto dLdB = results.at(4); + auto dLdI = results.at(0); + auto dLdG = results.at(3); + auto dLdB = results.at(4); - ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI)); - ASSERT_TRUE(expdLdI.equalsTo(dLdI)); - - ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG)); - ASSERT_TRUE(expdLdG.equalsTo(dLdG)); - - ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB)); - ASSERT_TRUE(expdLdB.equalsTo(dLdB)); + ASSERT_TRUE(expdLdI.isSameShapeStrict(dLdI)); + ASSERT_TRUE(expdLdI.equalsTo(dLdI)); + ASSERT_TRUE(expdLdG.isSameShapeStrict(dLdG)); + ASSERT_TRUE(expdLdG.equalsTo(dLdG)); + ASSERT_TRUE(expdLdB.isSameShapeStrict(dLdB)); + ASSERT_TRUE(expdLdB.equalsTo(dLdB)); } - //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, batchnorm_bp_test2) { + NDArray input('c', {2, 3, 4}, sd::DataType::FLOAT32); + NDArray mean('c', {3}, {1.05, 1.1, 1.15}, sd::DataType::FLOAT32); + NDArray variance('c', {3}, {0.5, 0.6, 0.7}, sd::DataType::FLOAT32); + NDArray gamma('c', {3}, {1.2, 1.3, 1.4}, sd::DataType::FLOAT32); + NDArray beta('c', {3}, sd::DataType::FLOAT32); + NDArray gradO('c', {2, 3, 4}, sd::DataType::FLOAT32); - NDArray input ('c', {2,3,4}, sd::DataType::FLOAT32); - NDArray mean ('c', {3}, {1.05, 1.1, 1.15}, sd::DataType::FLOAT32); - NDArray variance('c', {3}, {0.5, 0.6, 0.7}, sd::DataType::FLOAT32); - NDArray gamma ('c', {3}, {1.2, 1.3, 1.4}, sd::DataType::FLOAT32); - NDArray beta ('c', {3}, sd::DataType::FLOAT32); - NDArray gradO ('c', {2,3,4}, sd::DataType::FLOAT32); + NDArray expdLdI( + 'c', {2, 3, 4}, + {-0.601415, -0.521226, -0.441037, -0.360849, -0.456306, -0.395465, + -0.334624, -0.273784, 0.396631, 0.343747, 0.290863, 0.237978, + 0.360849, 0.441037, 0.521226, 0.601415, 0.273784, 0.334625, + 0.395465, 0.456306, -0.237978, -0.290863, -0.343746, -0.396631}, + sd::DataType::FLOAT32); + NDArray expdLdG('c', {3}, {5.81236, 7.048771, 12.155388}, + sd::DataType::FLOAT32); + NDArray expdLdB('c', {3}, {1.8, 6.6, 11.4}, sd::DataType::FLOAT32); - NDArray expdLdI('c', {2,3,4}, {-0.601415, -0.521226, -0.441037, -0.360849, -0.456306, -0.395465, -0.334624, -0.273784, 0.396631, 0.343747, - 0.290863, 0.237978, 0.360849, 0.441037, 0.521226, 0.601415, 0.273784, 0.334625, 0.395465, 0.456306, -0.237978, - -0.290863, -0.343746, -0.396631}, sd::DataType::FLOAT32); - NDArray expdLdG('c', {3}, {5.81236 , 7.048771, 12.155388}, sd::DataType::FLOAT32); - NDArray expdLdB('c', {3}, {1.8, 6.6, 11.4}, sd::DataType::FLOAT32); + input.linspace(0.1, 0.1); + // beta.assign(1.); // has no effect on gradient calculations + gradO.linspace(-0.9, 0.15); - input.linspace(0.1, 0.1); - // beta.assign(1.); // has no effect on gradient calculations - gradO.linspace(-0.9, 0.15); + sd::ops::batchnorm_bp op; - sd::ops::batchnorm_bp op; + auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta, &gradO}, + {1e-5}, {1, 1, 1}); - auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,1}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto dLdI = results.at(0); + auto dLdG = results.at(3); + auto dLdB = results.at(4); - auto dLdI = results.at(0); - auto dLdG = results.at(3); - auto dLdB = results.at(4); + ASSERT_TRUE(expdLdI.isSameShapeStrict(dLdI)); + ASSERT_TRUE(expdLdI.equalsTo(dLdI)); - ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI)); - ASSERT_TRUE(expdLdI.equalsTo(dLdI)); - - ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG)); - ASSERT_TRUE(expdLdG.equalsTo(dLdG)); - - ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB)); - ASSERT_TRUE(expdLdB.equalsTo(dLdB)); + ASSERT_TRUE(expdLdG.isSameShapeStrict(dLdG)); + ASSERT_TRUE(expdLdG.equalsTo(dLdG)); + ASSERT_TRUE(expdLdB.isSameShapeStrict(dLdB)); + ASSERT_TRUE(expdLdB.equalsTo(dLdB)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, batchnorm_bp_test3) { - - NDArray input ('c', {2,3,4}, sd::DataType::FLOAT32); - NDArray mean ('c', {2,1,4}, {1.05, 1.1, 1.15, 1.2, 1.25, 1.3, 1.35, 1.4}, sd::DataType::FLOAT32); - NDArray variance('c', {2,1,4}, {0.5, 0.6, 0.7, 0.8, 0.9, 1., 1.1, 1.2}, sd::DataType::FLOAT32); - NDArray gamma ('c', {2,1,4}, {1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9}, sd::DataType::FLOAT32); - NDArray beta ('c', {2,1,4}, sd::DataType::FLOAT32); - NDArray gradO ('c', {2,3,4}, sd::DataType::FLOAT32); - - NDArray expdLdI('c', {2,3,4}, {-0.577002, -0.744041, -0.850999, -0.922373, -0.000000, -0.000000, -0.000000, -0.000000, 0.577002, - 0.744041, 0.850999, 0.922373, -0.386037, -0.350205, -0.312047, -0.271737, -0.000000, -0.000000, - -0.000000, -0.000000, 0.386037, 0.350205, 0.312047, 0.271736}, sd::DataType::FLOAT32); - NDArray expdLdG('c', {2,1,4}, {1.378844, 0.910144, 0.573706, 0.335408, 2.640487, 2.954985, 3.289431, 3.64234 }, sd::DataType::FLOAT32); - NDArray expdLdB('c', {2,1,4}, {-0.9 , -0.45, 0. , 0.45, 4.5 , 4.95, 5.4 , 5.85}, sd::DataType::FLOAT32); - - input.linspace(0.1, 0.1); - // beta.assign(1.); // has no effect on gradient calculations - gradO.linspace(-0.9, 0.15); - - sd::ops::batchnorm_bp op; - - auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,0,2}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto dLdI = results.at(0); - auto dLdG = results.at(3); - auto dLdB = results.at(4); - - ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI)); - ASSERT_TRUE(expdLdI.equalsTo(dLdI)); - - ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG)); - ASSERT_TRUE(expdLdG.equalsTo(dLdG)); - - ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB)); - ASSERT_TRUE(expdLdB.equalsTo(dLdB)); - + NDArray input('c', {2, 3, 4}, sd::DataType::FLOAT32); + NDArray mean('c', {2, 1, 4}, {1.05, 1.1, 1.15, 1.2, 1.25, 1.3, 1.35, 1.4}, + sd::DataType::FLOAT32); + NDArray variance('c', {2, 1, 4}, {0.5, 0.6, 0.7, 0.8, 0.9, 1., 1.1, 1.2}, + sd::DataType::FLOAT32); + NDArray gamma('c', {2, 1, 4}, {1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9}, + sd::DataType::FLOAT32); + NDArray beta('c', {2, 1, 4}, sd::DataType::FLOAT32); + NDArray gradO('c', {2, 3, 4}, sd::DataType::FLOAT32); + + NDArray expdLdI( + 'c', {2, 3, 4}, + {-0.577002, -0.744041, -0.850999, -0.922373, -0.000000, -0.000000, + -0.000000, -0.000000, 0.577002, 0.744041, 0.850999, 0.922373, + -0.386037, -0.350205, -0.312047, -0.271737, -0.000000, -0.000000, + -0.000000, -0.000000, 0.386037, 0.350205, 0.312047, 0.271736}, + sd::DataType::FLOAT32); + NDArray expdLdG('c', {2, 1, 4}, + {1.378844, 0.910144, 0.573706, 0.335408, 2.640487, 2.954985, + 3.289431, 3.64234}, + sd::DataType::FLOAT32); + NDArray expdLdB('c', {2, 1, 4}, {-0.9, -0.45, 0., 0.45, 4.5, 4.95, 5.4, 5.85}, + sd::DataType::FLOAT32); + + input.linspace(0.1, 0.1); + // beta.assign(1.); // has no effect on gradient calculations + gradO.linspace(-0.9, 0.15); + + sd::ops::batchnorm_bp op; + + auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta, &gradO}, + {1e-5}, {1, 1, 0, 2}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto dLdI = results.at(0); + auto dLdG = results.at(3); + auto dLdB = results.at(4); + + ASSERT_TRUE(expdLdI.isSameShapeStrict(dLdI)); + ASSERT_TRUE(expdLdI.equalsTo(dLdI)); + + ASSERT_TRUE(expdLdG.isSameShapeStrict(dLdG)); + ASSERT_TRUE(expdLdG.equalsTo(dLdG)); + + ASSERT_TRUE(expdLdB.isSameShapeStrict(dLdB)); + ASSERT_TRUE(expdLdB.equalsTo(dLdB)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, batchnorm_bp_test4) { + NDArray input('c', {2, 4}, sd::DataType::FLOAT32); + NDArray mean('c', {4}, {1.05, 1.15, 1.2, 1.3}, sd::DataType::FLOAT32); + NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, sd::DataType::FLOAT32); + NDArray gamma('c', {4}, {-1.2, 1.3, -1.4, 1.5}, sd::DataType::FLOAT32); + NDArray beta('c', {4}, sd::DataType::FLOAT32); + NDArray gradO('c', {2, 4}, sd::DataType::FLOAT32); - NDArray input ('c', {2,4}, sd::DataType::FLOAT32); - NDArray mean ('c', {4}, {1.05, 1.15, 1.2, 1.3}, sd::DataType::FLOAT32); - NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, sd::DataType::FLOAT32); - NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, sd::DataType::FLOAT32); - NDArray beta ('c', {4}, sd::DataType::FLOAT32); - NDArray gradO ('c', {2,4}, sd::DataType::FLOAT32); - - NDArray expdLdI('c', {2,4}, {0.162923, -0.289673, 0.354174, -0.386151, -0.162923, 0.289673, -0.354174, 0.386151}, sd::DataType::FLOAT32); - NDArray expdLdG('c', {4}, {1.442483, 0.950200, 0.569207, 0.314641}, sd::DataType::FLOAT32); - NDArray expdLdB('c', {4}, {-1.2, -0.9, -0.6, -0.3}, sd::DataType::FLOAT32); + NDArray expdLdI('c', {2, 4}, + {0.162923, -0.289673, 0.354174, -0.386151, -0.162923, + 0.289673, -0.354174, 0.386151}, + sd::DataType::FLOAT32); + NDArray expdLdG('c', {4}, {1.442483, 0.950200, 0.569207, 0.314641}, + sd::DataType::FLOAT32); + NDArray expdLdB('c', {4}, {-1.2, -0.9, -0.6, -0.3}, sd::DataType::FLOAT32); - input.linspace(0.1, 0.1); - gradO.linspace(-0.9, 0.15); + input.linspace(0.1, 0.1); + gradO.linspace(-0.9, 0.15); - sd::ops::batchnorm_bp op; + sd::ops::batchnorm_bp op; - auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1}); + auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta, &gradO}, + {1e-5}, {1, 1}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto dLdI = results.at(0); - auto dLdG = results.at(3); - auto dLdB = results.at(4); + auto dLdI = results.at(0); + auto dLdG = results.at(3); + auto dLdB = results.at(4); - ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI)); - ASSERT_TRUE(expdLdI.equalsTo(dLdI)); + ASSERT_TRUE(expdLdI.isSameShapeStrict(dLdI)); + ASSERT_TRUE(expdLdI.equalsTo(dLdI)); - ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG)); - ASSERT_TRUE(expdLdG.equalsTo(dLdG)); + ASSERT_TRUE(expdLdG.isSameShapeStrict(dLdG)); + ASSERT_TRUE(expdLdG.equalsTo(dLdG)); - ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB)); - ASSERT_TRUE(expdLdB.equalsTo(dLdB)); + ASSERT_TRUE(expdLdB.isSameShapeStrict(dLdB)); + ASSERT_TRUE(expdLdB.equalsTo(dLdB)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, batchnorm_bp_test5) { - #if defined(HAVE_CUDNN) -return; + return; #endif - NDArray input ('c', {2,4,2,2}, sd::DataType::FLOAT32); - NDArray mean ('c', {4}, {1.05, 1.15, 1.2, 1.3}, sd::DataType::FLOAT32); - NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, sd::DataType::FLOAT32); - NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, sd::DataType::FLOAT32); - NDArray beta ('c', {4}, sd::DataType::FLOAT32); - NDArray gradO ('c', {2,4,2,2}, sd::DataType::FLOAT32); + NDArray input('c', {2, 4, 2, 2}, sd::DataType::FLOAT32); + NDArray mean('c', {4}, {1.05, 1.15, 1.2, 1.3}, sd::DataType::FLOAT32); + NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, sd::DataType::FLOAT32); + NDArray gamma('c', {4}, {-1.2, 1.3, -1.4, 1.5}, sd::DataType::FLOAT32); + NDArray beta('c', {4}, sd::DataType::FLOAT32); + NDArray gradO('c', {2, 4, 2, 2}, sd::DataType::FLOAT32); - NDArray expdLdI('c', {2,4,2,2}, {-0.737512, -0.659880, -0.582247, -0.504614, 0.561404, 0.502309, 0.443214, 0.384118, -1.168243, - -1.045270, -0.922297, -0.799324, 1.899026, 1.699128, 1.499231, 1.299333, 0.504614, 0.582247, 0.659880, 0.737512, -0.384118, - -0.443214, -0.502308, -0.561404, 0.799324, 0.922297, 1.045270, 1.168243, -1.299334, -1.499231, -1.699129, -1.899026}, sd::DataType::FLOAT32); - NDArray expdLdG('c', {4}, {11.073181, 12.585667, 17.708657, 24.313186}, sd::DataType::FLOAT32); - NDArray expdLdB('c', {4}, {4.2, 9. , 13.8, 18.6}, sd::DataType::FLOAT32); + NDArray expdLdI( + 'c', {2, 4, 2, 2}, + {-0.737512, -0.659880, -0.582247, -0.504614, 0.561404, 0.502309, + 0.443214, 0.384118, -1.168243, -1.045270, -0.922297, -0.799324, + 1.899026, 1.699128, 1.499231, 1.299333, 0.504614, 0.582247, + 0.659880, 0.737512, -0.384118, -0.443214, -0.502308, -0.561404, + 0.799324, 0.922297, 1.045270, 1.168243, -1.299334, -1.499231, + -1.699129, -1.899026}, + sd::DataType::FLOAT32); + NDArray expdLdG('c', {4}, {11.073181, 12.585667, 17.708657, 24.313186}, + sd::DataType::FLOAT32); + NDArray expdLdB('c', {4}, {4.2, 9., 13.8, 18.6}, sd::DataType::FLOAT32); - input.linspace(0.1, 0.1); - gradO.linspace(-0.9, 0.15); + input.linspace(0.1, 0.1); + gradO.linspace(-0.9, 0.15); - sd::ops::batchnorm_bp op; + sd::ops::batchnorm_bp op; - auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,1}); + auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta, &gradO}, + {1e-5}, {1, 1, 1}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto dLdI = results.at(0); - auto dLdG = results.at(3); - auto dLdB = results.at(4); + auto dLdI = results.at(0); + auto dLdG = results.at(3); + auto dLdB = results.at(4); - ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI)); - ASSERT_TRUE(expdLdI.equalsTo(dLdI)); + ASSERT_TRUE(expdLdI.isSameShapeStrict(dLdI)); + ASSERT_TRUE(expdLdI.equalsTo(dLdI)); - ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG)); - ASSERT_TRUE(expdLdG.equalsTo(dLdG)); + ASSERT_TRUE(expdLdG.isSameShapeStrict(dLdG)); + ASSERT_TRUE(expdLdG.equalsTo(dLdG)); - ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB)); - ASSERT_TRUE(expdLdB.equalsTo(dLdB)); + ASSERT_TRUE(expdLdB.isSameShapeStrict(dLdB)); + ASSERT_TRUE(expdLdB.equalsTo(dLdB)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, batchnorm_bp_test6) { - #if defined(HAVE_CUDNN) -return; + return; #endif - NDArray input ('c', {2,2,2,4}, sd::DataType::FLOAT32); - NDArray mean ('c', {4}, {1.05, 1.15, 1.2, 1.3}, sd::DataType::FLOAT32); - NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, sd::DataType::FLOAT32); - NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, sd::DataType::FLOAT32); - NDArray beta ('c', {4}, sd::DataType::FLOAT32); - NDArray gradO ('c', {2,2,2,4}, sd::DataType::FLOAT32); + NDArray input('c', {2, 2, 2, 4}, sd::DataType::FLOAT32); + NDArray mean('c', {4}, {1.05, 1.15, 1.2, 1.3}, sd::DataType::FLOAT32); + NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, sd::DataType::FLOAT32); + NDArray gamma('c', {4}, {-1.2, 1.3, -1.4, 1.5}, sd::DataType::FLOAT32); + NDArray beta('c', {4}, sd::DataType::FLOAT32); + NDArray gradO('c', {2, 2, 2, 4}, sd::DataType::FLOAT32); - NDArray expdLdI('c', {2,2,2,4}, {-4.989124, 2.540357, -1.515022, 0.791769, -3.563660, 1.814540, -1.082159, 0.565549, -2.138196, 1.088724, -0.649295, - 0.339329, -0.712732, 0.362908, -0.216432, 0.113110, 0.712732, -0.362908, 0.216432, -0.113110, 2.138195, -1.088724, 0.649295, - -0.339330, 3.563660,-1.814540, 1.082159, -0.565549, 4.989125, -2.540356, 1.515022, -0.791770}, sd::DataType::FLOAT32); - NDArray expdLdG('c', {4}, {20.364472, 17.856588, 16.949714, 15.903684}, sd::DataType::FLOAT32); - NDArray expdLdB('c', {4}, {9.6, 10.8, 12. , 13.2}, sd::DataType::FLOAT32); + NDArray expdLdI( + 'c', {2, 2, 2, 4}, + {-4.989124, 2.540357, -1.515022, 0.791769, -3.563660, 1.814540, + -1.082159, 0.565549, -2.138196, 1.088724, -0.649295, 0.339329, + -0.712732, 0.362908, -0.216432, 0.113110, 0.712732, -0.362908, + 0.216432, -0.113110, 2.138195, -1.088724, 0.649295, -0.339330, + 3.563660, -1.814540, 1.082159, -0.565549, 4.989125, -2.540356, + 1.515022, -0.791770}, + sd::DataType::FLOAT32); + NDArray expdLdG('c', {4}, {20.364472, 17.856588, 16.949714, 15.903684}, + sd::DataType::FLOAT32); + NDArray expdLdB('c', {4}, {9.6, 10.8, 12., 13.2}, sd::DataType::FLOAT32); - input.linspace(0.1, 0.1); - gradO.linspace(-0.9, 0.15); + input.linspace(0.1, 0.1); + gradO.linspace(-0.9, 0.15); - sd::ops::batchnorm_bp op; + sd::ops::batchnorm_bp op; - auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,3}); + auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta, &gradO}, + {1e-5}, {1, 1, 3}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto dLdI = results.at(0); - auto dLdG = results.at(3); - auto dLdB = results.at(4); + auto dLdI = results.at(0); + auto dLdG = results.at(3); + auto dLdB = results.at(4); - ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI)); - ASSERT_TRUE(expdLdI.equalsTo(dLdI)); + ASSERT_TRUE(expdLdI.isSameShapeStrict(dLdI)); + ASSERT_TRUE(expdLdI.equalsTo(dLdI)); - ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG)); - ASSERT_TRUE(expdLdG.equalsTo(dLdG)); + ASSERT_TRUE(expdLdG.isSameShapeStrict(dLdG)); + ASSERT_TRUE(expdLdG.equalsTo(dLdG)); - ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB)); - ASSERT_TRUE(expdLdB.equalsTo(dLdB)); + ASSERT_TRUE(expdLdB.isSameShapeStrict(dLdB)); + ASSERT_TRUE(expdLdB.equalsTo(dLdB)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, batchnorm_bp_test7) { - #if defined(HAVE_CUDNN) -return; + return; #endif - NDArray input ('c', {2,2,2,2,4}, sd::DataType::FLOAT32); - NDArray mean ('c', {4}, {1.05, 1.15, 1.2, 1.3}, sd::DataType::FLOAT32); - NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, sd::DataType::FLOAT32); - NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, sd::DataType::FLOAT32); - NDArray beta ('c', {4}, sd::DataType::FLOAT32); - NDArray gradO ('c', {2,2,2,2,4}, sd::DataType::FLOAT32); - - NDArray expdLdI('c', {2,2,2,2,4}, {-119.435059, 78.159744, -58.732986, 46.630123, -103.510391, 67.738441, -50.901920, 40.412773, -87.585716, 57.317142, - -43.070854, 34.195419, -71.661041, 46.895844, -35.239792, 27.978071, -55.736359, 36.474548, -27.408726, 21.760721, -39.811687, 26.053242, -19.577662, - 15.543370, -23.887009, 15.631950, -11.746595, 9.326023, -7.962326, 5.210644, -3.915531, 3.108671, 7.962341, -5.210655, 3.915535, -3.108677, 23.887032, - -15.631958, 11.746601, -9.326031, 39.811691, -26.053246, 19.577671, -15.543377, 55.736382, -36.474548, 27.408726, -21.760731, 71.661064, -46.895851, 35.239788, - -27.978077, 87.585732, -57.317154, 43.070866, -34.195431, 103.510384, -67.738464, 50.901920, -40.412777, 119.435097, -78.159744, 58.732998, -46.630131}, sd::DataType::FLOAT32); - NDArray expdLdG('c', {4}, {282.38734 , 244.542027, 224.140995, 207.548793}, sd::DataType::FLOAT32); - NDArray expdLdB('c', {4}, {57.6, 60. , 62.4, 64.8}, sd::DataType::FLOAT32); - - input.linspace(0.1, 0.1); - gradO.linspace(-0.9, 0.15); + NDArray input('c', {2, 2, 2, 2, 4}, sd::DataType::FLOAT32); + NDArray mean('c', {4}, {1.05, 1.15, 1.2, 1.3}, sd::DataType::FLOAT32); + NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, sd::DataType::FLOAT32); + NDArray gamma('c', {4}, {-1.2, 1.3, -1.4, 1.5}, sd::DataType::FLOAT32); + NDArray beta('c', {4}, sd::DataType::FLOAT32); + NDArray gradO('c', {2, 2, 2, 2, 4}, sd::DataType::FLOAT32); + NDArray expdLdI( + 'c', {2, 2, 2, 2, 4}, + {-119.435059, 78.159744, -58.732986, 46.630123, -103.510391, 67.738441, + -50.901920, 40.412773, -87.585716, 57.317142, -43.070854, 34.195419, + -71.661041, 46.895844, -35.239792, 27.978071, -55.736359, 36.474548, + -27.408726, 21.760721, -39.811687, 26.053242, -19.577662, 15.543370, + -23.887009, 15.631950, -11.746595, 9.326023, -7.962326, 5.210644, + -3.915531, 3.108671, 7.962341, -5.210655, 3.915535, -3.108677, + 23.887032, -15.631958, 11.746601, -9.326031, 39.811691, -26.053246, + 19.577671, -15.543377, 55.736382, -36.474548, 27.408726, -21.760731, + 71.661064, -46.895851, 35.239788, -27.978077, 87.585732, -57.317154, + 43.070866, -34.195431, 103.510384, -67.738464, 50.901920, -40.412777, + 119.435097, -78.159744, 58.732998, -46.630131}, + sd::DataType::FLOAT32); + NDArray expdLdG('c', {4}, {282.38734, 244.542027, 224.140995, 207.548793}, + sd::DataType::FLOAT32); + NDArray expdLdB('c', {4}, {57.6, 60., 62.4, 64.8}, sd::DataType::FLOAT32); - sd::ops::batchnorm_bp op; + input.linspace(0.1, 0.1); + gradO.linspace(-0.9, 0.15); - auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,4}); + sd::ops::batchnorm_bp op; - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta, &gradO}, + {1e-5}, {1, 1, 4}); - auto dLdI = results.at(0); - auto dLdG = results.at(3); - auto dLdB = results.at(4); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - // dLdI->printBuffer(); + auto dLdI = results.at(0); + auto dLdG = results.at(3); + auto dLdB = results.at(4); - ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI)); - ASSERT_TRUE(expdLdI.equalsTo(dLdI)); + // dLdI->printBuffer(); - ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG)); - ASSERT_TRUE(expdLdG.equalsTo(dLdG)); - - ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB)); - ASSERT_TRUE(expdLdB.equalsTo(dLdB)); + ASSERT_TRUE(expdLdI.isSameShapeStrict(dLdI)); + ASSERT_TRUE(expdLdI.equalsTo(dLdI)); + ASSERT_TRUE(expdLdG.isSameShapeStrict(dLdG)); + ASSERT_TRUE(expdLdG.equalsTo(dLdG)); + ASSERT_TRUE(expdLdB.isSameShapeStrict(dLdB)); + ASSERT_TRUE(expdLdB.equalsTo(dLdB)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, batchnorm_bp_test8) { - #if defined(HAVE_CUDNN) -return; + return; #endif - NDArray input ('c', {2,4,2,2,2}, sd::DataType::FLOAT32); - NDArray mean ('c', {4}, {1.05, 1.15, 1.2, 1.3}, sd::DataType::FLOAT32); - NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, sd::DataType::FLOAT32); - NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, sd::DataType::FLOAT32); - NDArray beta ('c', {4}, sd::DataType::FLOAT32); - NDArray gradO ('c', {2,4,2,2,2}, sd::DataType::FLOAT32); - - NDArray expdLdI('c', {2,4,2,2,2}, {-34.373802, -32.611046, -30.848286, -29.085529, -27.322769, -25.560009, -23.797251, -22.034491, 36.146996, 34.293301, - 32.439610, 30.585917, 28.732227, 26.878534, 25.024841, 23.171150, -42.876553, -40.677757, -38.478958, -36.280159, -34.081367, -31.882565, -29.683767, - -27.484968, 50.674446, 48.075760, 45.477066, 42.878380, 40.279686, 37.681000, 35.082310, 32.483616, 22.034489, 23.797249, 25.560009, 27.322765, 29.085526, - 30.848286, 32.611046, 34.373802, -23.171146, -25.024837, -26.878536, -28.732231, -30.585918, -32.439613, -34.293297, -36.146996, 27.484982, 29.683773, - 31.882572, 34.081364, 36.280178, 38.478970, 40.677776, 42.876560, -32.483627, -35.082329, -37.681023, -40.279701, -42.878403, -45.477081, -48.075775, -50.674484}, sd::DataType::FLOAT32); - NDArray expdLdG('c', {4}, {134.490365, 179.785003, 248.933114, 330.087248}, sd::DataType::FLOAT32); - NDArray expdLdB('c', {4}, {32.4, 51.6, 70.8, 90.}, sd::DataType::FLOAT32); + NDArray input('c', {2, 4, 2, 2, 2}, sd::DataType::FLOAT32); + NDArray mean('c', {4}, {1.05, 1.15, 1.2, 1.3}, sd::DataType::FLOAT32); + NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, sd::DataType::FLOAT32); + NDArray gamma('c', {4}, {-1.2, 1.3, -1.4, 1.5}, sd::DataType::FLOAT32); + NDArray beta('c', {4}, sd::DataType::FLOAT32); + NDArray gradO('c', {2, 4, 2, 2, 2}, sd::DataType::FLOAT32); - input.linspace(0.1, 0.1); - gradO.linspace(-0.9, 0.15); + NDArray expdLdI( + 'c', {2, 4, 2, 2, 2}, + {-34.373802, -32.611046, -30.848286, -29.085529, -27.322769, -25.560009, + -23.797251, -22.034491, 36.146996, 34.293301, 32.439610, 30.585917, + 28.732227, 26.878534, 25.024841, 23.171150, -42.876553, -40.677757, + -38.478958, -36.280159, -34.081367, -31.882565, -29.683767, -27.484968, + 50.674446, 48.075760, 45.477066, 42.878380, 40.279686, 37.681000, + 35.082310, 32.483616, 22.034489, 23.797249, 25.560009, 27.322765, + 29.085526, 30.848286, 32.611046, 34.373802, -23.171146, -25.024837, + -26.878536, -28.732231, -30.585918, -32.439613, -34.293297, -36.146996, + 27.484982, 29.683773, 31.882572, 34.081364, 36.280178, 38.478970, + 40.677776, 42.876560, -32.483627, -35.082329, -37.681023, -40.279701, + -42.878403, -45.477081, -48.075775, -50.674484}, + sd::DataType::FLOAT32); + NDArray expdLdG('c', {4}, {134.490365, 179.785003, 248.933114, 330.087248}, + sd::DataType::FLOAT32); + NDArray expdLdB('c', {4}, {32.4, 51.6, 70.8, 90.}, sd::DataType::FLOAT32); - sd::ops::batchnorm_bp op; + input.linspace(0.1, 0.1); + gradO.linspace(-0.9, 0.15); - auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,1}); + sd::ops::batchnorm_bp op; - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta, &gradO}, + {1e-5}, {1, 1, 1}); - auto dLdI = results.at(0); - auto dLdG = results.at(3); - auto dLdB = results.at(4); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - // dLdI->printBuffer(); + auto dLdI = results.at(0); + auto dLdG = results.at(3); + auto dLdB = results.at(4); - ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI)); - ASSERT_TRUE(expdLdI.equalsTo(dLdI)); + // dLdI->printBuffer(); - ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG)); - ASSERT_TRUE(expdLdG.equalsTo(dLdG)); + ASSERT_TRUE(expdLdI.isSameShapeStrict(dLdI)); + ASSERT_TRUE(expdLdI.equalsTo(dLdI)); - ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB)); - ASSERT_TRUE(expdLdB.equalsTo(dLdB)); + ASSERT_TRUE(expdLdG.isSameShapeStrict(dLdG)); + ASSERT_TRUE(expdLdG.equalsTo(dLdG)); + ASSERT_TRUE(expdLdB.isSameShapeStrict(dLdB)); + ASSERT_TRUE(expdLdB.equalsTo(dLdB)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, batchnorm_bp_test9) { - - NDArray input ('c', {2,4,2,2}, sd::DataType::FLOAT32); - NDArray mean ('c', {4}, sd::DataType::FLOAT32); - NDArray variance('c', {4}, sd::DataType::FLOAT32); - NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, sd::DataType::FLOAT32); - NDArray beta ('c', {4}, sd::DataType::FLOAT32); - NDArray gradO ('c', {2,4,2,2}, sd::DataType::FLOAT32); - - NDArray expdLdI('c', {2,4,2,2}, {0.032378, 0.028967, 0.025558, 0.022147, -0.035056, -0.031364, -0.027669, -0.024006, 0.037742, 0.033766, 0.029791, 0.025818, - -0.040429, -0.036172, -0.031913, -0.027656, -0.022155, -0.025564, -0.028974, -0.032359, 0.023982, 0.027677, 0.031373, 0.035063, - -0.025822, -0.029794, -0.033770, -0.037747, 0.027653, 0.031913, 0.036168, 0.040426}, sd::DataType::FLOAT32); - NDArray expdLdG('c', {4}, {9.685875, 9.685880, 9.685887, 9.685891}, sd::DataType::FLOAT32); - NDArray expdLdB('c', {4}, {4.2, 9. , 13.8, 18.6}, sd::DataType::FLOAT32); - - input.linspace(1,0.01); - gradO.linspace(-0.9, 0.15); - - // calculate mean and variance of input - PointersManager manager(input.getContext(), "DeclarableOpsTests13.batchnorm_bp_test9"); - std::vector dimensions = {0,2,3}; - int* dims = reinterpret_cast(manager.replicatePointer(dimensions.data(), dimensions.size() * sizeof(int))); - input.reduceAlongDimension(sd::reduce::Mean, mean, dimensions); - NDArray::prepareSpecialUse({&variance}, {&input}); - auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(input.shapeInfo(), dimensions); - NativeOpExecutioner::execSummaryStats(input.getContext(), 0,input.buffer(), input.shapeInfo(),input.specialBuffer(), input.specialShapeInfo(),nullptr,variance.buffer(), variance.shapeInfo(),variance.specialBuffer(), variance.specialShapeInfo(), dims, dimensions.size(),packX.platformShapeInfo(), packX.platformOffsets(),false); - manager.synchronize(); - NDArray::registerSpecialUse({&variance}, {&input}); - - sd::ops::batchnorm_bp op; - - auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,1}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto dLdI = results.at(0); - auto dLdG = results.at(3); - auto dLdB = results.at(4); - - ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI)); - ASSERT_TRUE(expdLdI.equalsTo(dLdI, 1e-4)); - - ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG)); - ASSERT_TRUE(expdLdG.equalsTo(dLdG)); - - ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB)); - ASSERT_TRUE(expdLdB.equalsTo(dLdB)); - + NDArray input('c', {2, 4, 2, 2}, sd::DataType::FLOAT32); + NDArray mean('c', {4}, sd::DataType::FLOAT32); + NDArray variance('c', {4}, sd::DataType::FLOAT32); + NDArray gamma('c', {4}, {-1.2, 1.3, -1.4, 1.5}, sd::DataType::FLOAT32); + NDArray beta('c', {4}, sd::DataType::FLOAT32); + NDArray gradO('c', {2, 4, 2, 2}, sd::DataType::FLOAT32); + + NDArray expdLdI( + 'c', {2, 4, 2, 2}, + {0.032378, 0.028967, 0.025558, 0.022147, -0.035056, -0.031364, + -0.027669, -0.024006, 0.037742, 0.033766, 0.029791, 0.025818, + -0.040429, -0.036172, -0.031913, -0.027656, -0.022155, -0.025564, + -0.028974, -0.032359, 0.023982, 0.027677, 0.031373, 0.035063, + -0.025822, -0.029794, -0.033770, -0.037747, 0.027653, 0.031913, + 0.036168, 0.040426}, + sd::DataType::FLOAT32); + NDArray expdLdG('c', {4}, {9.685875, 9.685880, 9.685887, 9.685891}, + sd::DataType::FLOAT32); + NDArray expdLdB('c', {4}, {4.2, 9., 13.8, 18.6}, sd::DataType::FLOAT32); + + input.linspace(1, 0.01); + gradO.linspace(-0.9, 0.15); + + // calculate mean and variance of input + PointersManager manager(input.getContext(), + "DeclarableOpsTests13.batchnorm_bp_test9"); + std::vector dimensions = {0, 2, 3}; + int* dims = reinterpret_cast(manager.replicatePointer( + dimensions.data(), dimensions.size() * sizeof(int))); + input.reduceAlongDimension(sd::reduce::Mean, mean, dimensions); + NDArray::prepareSpecialUse({&variance}, {&input}); + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions( + input.shapeInfo(), dimensions); + NativeOpExecutioner::execSummaryStats( + input.getContext(), 0, input.buffer(), input.shapeInfo(), + input.specialBuffer(), input.specialShapeInfo(), nullptr, + variance.buffer(), variance.shapeInfo(), variance.specialBuffer(), + variance.specialShapeInfo(), dims, dimensions.size(), + packX.platformShapeInfo(), packX.platformOffsets(), false); + manager.synchronize(); + NDArray::registerSpecialUse({&variance}, {&input}); + + sd::ops::batchnorm_bp op; + + auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta, &gradO}, + {1e-5}, {1, 1, 1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto dLdI = results.at(0); + auto dLdG = results.at(3); + auto dLdB = results.at(4); + + ASSERT_TRUE(expdLdI.isSameShapeStrict(dLdI)); + ASSERT_TRUE(expdLdI.equalsTo(dLdI, 1e-4)); + + ASSERT_TRUE(expdLdG.isSameShapeStrict(dLdG)); + ASSERT_TRUE(expdLdG.equalsTo(dLdG)); + + ASSERT_TRUE(expdLdB.isSameShapeStrict(dLdB)); + ASSERT_TRUE(expdLdB.equalsTo(dLdB)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, batchnorm_bp_test10) { - - NDArray input ('c', {2,2,2,4}, sd::DataType::FLOAT32); - NDArray mean ('c', {4}, sd::DataType::FLOAT32); - NDArray variance('c', {4}, sd::DataType::FLOAT32); - NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, sd::DataType::FLOAT32); - NDArray beta ('c', {4}, sd::DataType::FLOAT32); - NDArray gradO ('c', {2,2,2,4}, sd::DataType::FLOAT32); - - NDArray expdLdI('c', {2,2,2,4}, {0.032634, -0.035423, 0.038110, -0.040864, 0.023302, -0.025294, 0.027213, -0.029205, 0.013996, -0.015192, 0.016343, - -0.017519, 0.004664, -0.005062, 0.005445, -0.005833, -0.004668, 0.005067, -0.005452, 0.005824, -0.013974, 0.015171, - -0.016325, 0.017508, -0.023309, 0.025301, -0.027221, 0.029197, -0.032639, 0.035428, -0.038118, 0.040878}, sd::DataType::FLOAT32); - NDArray expdLdG('c', {4}, {10.991656, 10.991631, 10.991643, 10.991632}, sd::DataType::FLOAT32); - NDArray expdLdB('c', {4}, {9.6, 10.8, 12., 13.2}, sd::DataType::FLOAT32); - - input.linspace(1,0.01); - gradO.linspace(-0.9, 0.15); - - // calculate mean and variance of input - PointersManager manager(input.getContext(), "DeclarableOpsTests13.batchnorm_bp_test9"); - std::vector dimensions = {0,1,2}; - int* dims = reinterpret_cast(manager.replicatePointer(dimensions.data(), dimensions.size() * sizeof(int))); - input.reduceAlongDimension(sd::reduce::Mean, mean, dimensions); - NDArray::prepareSpecialUse({&variance}, {&input}); - auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(input.shapeInfo(), dimensions); - NativeOpExecutioner::execSummaryStats(input.getContext(), 0,input.buffer(), input.shapeInfo(),input.specialBuffer(), input.specialShapeInfo(),nullptr,variance.buffer(), variance.shapeInfo(),variance.specialBuffer(), variance.specialShapeInfo(), dims, dimensions.size(),packX.platformShapeInfo(), packX.platformOffsets(),false); - manager.synchronize(); - NDArray::registerSpecialUse({&variance}, {&input}); - - sd::ops::batchnorm_bp op; - - auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,3}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto dLdI = results.at(0); - auto dLdG = results.at(3); - auto dLdB = results.at(4); - - ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI)); - ASSERT_TRUE(expdLdI.equalsTo(dLdI, 1e-4)); - - ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG)); - ASSERT_TRUE(expdLdG.equalsTo(dLdG)); - - ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB)); - ASSERT_TRUE(expdLdB.equalsTo(dLdB)); - + NDArray input('c', {2, 2, 2, 4}, sd::DataType::FLOAT32); + NDArray mean('c', {4}, sd::DataType::FLOAT32); + NDArray variance('c', {4}, sd::DataType::FLOAT32); + NDArray gamma('c', {4}, {-1.2, 1.3, -1.4, 1.5}, sd::DataType::FLOAT32); + NDArray beta('c', {4}, sd::DataType::FLOAT32); + NDArray gradO('c', {2, 2, 2, 4}, sd::DataType::FLOAT32); + + NDArray expdLdI( + 'c', {2, 2, 2, 4}, + {0.032634, -0.035423, 0.038110, -0.040864, 0.023302, -0.025294, + 0.027213, -0.029205, 0.013996, -0.015192, 0.016343, -0.017519, + 0.004664, -0.005062, 0.005445, -0.005833, -0.004668, 0.005067, + -0.005452, 0.005824, -0.013974, 0.015171, -0.016325, 0.017508, + -0.023309, 0.025301, -0.027221, 0.029197, -0.032639, 0.035428, + -0.038118, 0.040878}, + sd::DataType::FLOAT32); + NDArray expdLdG('c', {4}, {10.991656, 10.991631, 10.991643, 10.991632}, + sd::DataType::FLOAT32); + NDArray expdLdB('c', {4}, {9.6, 10.8, 12., 13.2}, sd::DataType::FLOAT32); + + input.linspace(1, 0.01); + gradO.linspace(-0.9, 0.15); + + // calculate mean and variance of input + PointersManager manager(input.getContext(), + "DeclarableOpsTests13.batchnorm_bp_test9"); + std::vector dimensions = {0, 1, 2}; + int* dims = reinterpret_cast(manager.replicatePointer( + dimensions.data(), dimensions.size() * sizeof(int))); + input.reduceAlongDimension(sd::reduce::Mean, mean, dimensions); + NDArray::prepareSpecialUse({&variance}, {&input}); + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions( + input.shapeInfo(), dimensions); + NativeOpExecutioner::execSummaryStats( + input.getContext(), 0, input.buffer(), input.shapeInfo(), + input.specialBuffer(), input.specialShapeInfo(), nullptr, + variance.buffer(), variance.shapeInfo(), variance.specialBuffer(), + variance.specialShapeInfo(), dims, dimensions.size(), + packX.platformShapeInfo(), packX.platformOffsets(), false); + manager.synchronize(); + NDArray::registerSpecialUse({&variance}, {&input}); + + sd::ops::batchnorm_bp op; + + auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta, &gradO}, + {1e-5}, {1, 1, 3}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto dLdI = results.at(0); + auto dLdG = results.at(3); + auto dLdB = results.at(4); + + ASSERT_TRUE(expdLdI.isSameShapeStrict(dLdI)); + ASSERT_TRUE(expdLdI.equalsTo(dLdI, 1e-4)); + + ASSERT_TRUE(expdLdG.isSameShapeStrict(dLdG)); + ASSERT_TRUE(expdLdG.equalsTo(dLdG)); + + ASSERT_TRUE(expdLdB.isSameShapeStrict(dLdB)); + ASSERT_TRUE(expdLdB.equalsTo(dLdB)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, batchnorm_bp_test11) { - - NDArray input ('c', {2,3,4,5}, sd::DataType::FLOAT32); - NDArray mean ('c', {1,3,4,5}, sd::DataType::FLOAT32); - NDArray variance('c', {1,3,4,5}, sd::DataType::FLOAT32); - NDArray gamma ('c', {1,3,4,5}, sd::DataType::FLOAT32); - NDArray beta ('c', {1,3,4,5}, sd::DataType::FLOAT32); - NDArray gradO ('c', {2,3,4,5}, sd::DataType::FLOAT32); - - NDArray expdLdI('c', {2,3,4,5}, {0.004981, 0.004818, 0.004652, 0.004483, 0.004319, 0.004153, 0.003985, 0.003832, 0.003661, 0.003505, 0.003340, 0.003171, 0.003001, 0.002837, - 0.002670, 0.002505, 0.002337, 0.002167, 0.002003, 0.001835, 0.001666, 0.001499, 0.001327, 0.001162, 0.000996, 0.000830, 0.000664, 0.000498, - 0.000332, 0.000166, -0.0, -0.000166, -0.000333, -0.000500, -0.000668, -0.000835, -0.001003, -0.001168, -0.001337, -0.001502, -0.001670, - -0.001838, -0.002003, -0.002172, -0.002330, -0.002499, -0.002669, -0.002832, -0.003002, -0.003162, -0.003332, -0.003495, -0.003665, -0.003821, - -0.004001, -0.004163, -0.004324, -0.004516, -0.004678, -0.004851, -0.004981, -0.004818, -0.004652, -0.004483, -0.004319, -0.004151, -0.003985, - -0.003836, -0.003661, -0.003505, -0.003338, -0.003171, -0.003004, -0.002837, -0.002670, -0.002503, -0.002337, -0.002170, -0.002003, -0.001835, - -0.001664, -0.001499, -0.001328, -0.001162, -0.000996, -0.000829, -0.000664, -0.000498, -0.000332, -0.000166, 0.0, 0.000166, 0.000334, - 0.000500, 0.000668, 0.000834, 0.001003, 0.001170, 0.001337, 0.001502, 0.001669, 0.001838, 0.002005, 0.002172, 0.002330, 0.002496, 0.002669, - 0.002836, 0.003002, 0.003162, 0.003328, 0.003495, 0.003670, 0.003828, 0.003992, 0.004158, 0.004324, 0.004522, 0.004689, 0.004843}, sd::DataType::FLOAT32); - NDArray expdLdG('c', {1,3,4,5}, {8.999503, 8.999502, 8.999502, 8.999503, 8.999502, 8.999503, 8.999503, 8.999499, 8.999501, 8.999498, 8.999498, 8.999498, 8.999498, 8.999498, 8.999498, - 8.999498, 8.999498, 8.999498, 8.999498, 8.999499, 8.999501, 8.999500, 8.999503, 8.999503, 8.999503, 8.999504, 8.999503, 8.999503, 8.999504, 8.999503, - 8.999504, 8.999504, 8.999499, 8.999500, 8.999497, 8.999498, 8.999496, 8.999496, 8.999496, 8.999498, 8.999498, 8.999496, 8.999496, 8.999496, 8.999501, - 8.999501, 8.999499, 8.999499, 8.999499, 8.999501, 8.999501, 8.999501, 8.999499, 8.999500, 8.999501, 8.999501, 8.999501, 8.999495, 8.999495, 8.999497}, sd::DataType::FLOAT32); - NDArray expdLdB('c', {1,3,4,5}, {7.2, 7.5, 7.8, 8.1, 8.4, 8.7, 9.0, 9.3, 9.6, 9.9, 10.2, 10.5, 10.8, 11.1, 11.4, 11.7, 12.0, 12.3, 12.6, 12.9, 13.2, 13.5, 13.8, 14.1, 14.4, 14.7, 15.0, - 15.3, 15.6, 15.9, 16.2, 16.5, 16.8, 17.1, 17.4, 17.7, 18.0, 18.3, 18.6, 18.9, 19.2, 19.5, 19.8, 20.1, 20.4, 20.7, 21.0, 21.3, 21.6, 21.9, 22.2, 22.5, - 22.8, 23.1, 23.4, 23.7, 24.0, 24.3, 24.6, 24.9}, sd::DataType::FLOAT32); - - input.linspace(1,0.01); - gradO.linspace(-0.9, 0.15); - gamma.linspace(-3, 0.1); - - // calculate mean and variance of input - PointersManager manager(input.getContext(), "DeclarableOpsTests13.batchnorm_bp_test9"); - std::vector dimensions = {0}; - int* dims = reinterpret_cast(manager.replicatePointer(dimensions.data(), dimensions.size() * sizeof(int))); - input.reduceAlongDimension(sd::reduce::Mean, mean, dimensions, true); - NDArray::prepareSpecialUse({&variance}, {&input}); - auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(input.shapeInfo(), dimensions); - NativeOpExecutioner::execSummaryStats(input.getContext(), 0,input.buffer(), input.shapeInfo(),input.specialBuffer(), input.specialShapeInfo(),nullptr,variance.buffer(), variance.shapeInfo(),variance.specialBuffer(), variance.specialShapeInfo(), dims, dimensions.size(),packX.platformShapeInfo(), packX.platformOffsets(),false); - manager.synchronize(); - NDArray::registerSpecialUse({&variance}, {&input}); - - sd::ops::batchnorm_bp op; - - auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1, 1,2,3}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto dLdI = results.at(0); - auto dLdG = results.at(3); - auto dLdB = results.at(4); - - ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI)); - ASSERT_TRUE(expdLdI.equalsTo(dLdI, 1e-4)); - - ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG)); - ASSERT_TRUE(expdLdG.equalsTo(dLdG)); - - ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB)); - ASSERT_TRUE(expdLdB.equalsTo(dLdB)); - + NDArray input('c', {2, 3, 4, 5}, sd::DataType::FLOAT32); + NDArray mean('c', {1, 3, 4, 5}, sd::DataType::FLOAT32); + NDArray variance('c', {1, 3, 4, 5}, sd::DataType::FLOAT32); + NDArray gamma('c', {1, 3, 4, 5}, sd::DataType::FLOAT32); + NDArray beta('c', {1, 3, 4, 5}, sd::DataType::FLOAT32); + NDArray gradO('c', {2, 3, 4, 5}, sd::DataType::FLOAT32); + + NDArray expdLdI( + 'c', {2, 3, 4, 5}, + {0.004981, 0.004818, 0.004652, 0.004483, 0.004319, 0.004153, + 0.003985, 0.003832, 0.003661, 0.003505, 0.003340, 0.003171, + 0.003001, 0.002837, 0.002670, 0.002505, 0.002337, 0.002167, + 0.002003, 0.001835, 0.001666, 0.001499, 0.001327, 0.001162, + 0.000996, 0.000830, 0.000664, 0.000498, 0.000332, 0.000166, + -0.0, -0.000166, -0.000333, -0.000500, -0.000668, -0.000835, + -0.001003, -0.001168, -0.001337, -0.001502, -0.001670, -0.001838, + -0.002003, -0.002172, -0.002330, -0.002499, -0.002669, -0.002832, + -0.003002, -0.003162, -0.003332, -0.003495, -0.003665, -0.003821, + -0.004001, -0.004163, -0.004324, -0.004516, -0.004678, -0.004851, + -0.004981, -0.004818, -0.004652, -0.004483, -0.004319, -0.004151, + -0.003985, -0.003836, -0.003661, -0.003505, -0.003338, -0.003171, + -0.003004, -0.002837, -0.002670, -0.002503, -0.002337, -0.002170, + -0.002003, -0.001835, -0.001664, -0.001499, -0.001328, -0.001162, + -0.000996, -0.000829, -0.000664, -0.000498, -0.000332, -0.000166, + 0.0, 0.000166, 0.000334, 0.000500, 0.000668, 0.000834, + 0.001003, 0.001170, 0.001337, 0.001502, 0.001669, 0.001838, + 0.002005, 0.002172, 0.002330, 0.002496, 0.002669, 0.002836, + 0.003002, 0.003162, 0.003328, 0.003495, 0.003670, 0.003828, + 0.003992, 0.004158, 0.004324, 0.004522, 0.004689, 0.004843}, + sd::DataType::FLOAT32); + NDArray expdLdG( + 'c', {1, 3, 4, 5}, + {8.999503, 8.999502, 8.999502, 8.999503, 8.999502, 8.999503, 8.999503, + 8.999499, 8.999501, 8.999498, 8.999498, 8.999498, 8.999498, 8.999498, + 8.999498, 8.999498, 8.999498, 8.999498, 8.999498, 8.999499, 8.999501, + 8.999500, 8.999503, 8.999503, 8.999503, 8.999504, 8.999503, 8.999503, + 8.999504, 8.999503, 8.999504, 8.999504, 8.999499, 8.999500, 8.999497, + 8.999498, 8.999496, 8.999496, 8.999496, 8.999498, 8.999498, 8.999496, + 8.999496, 8.999496, 8.999501, 8.999501, 8.999499, 8.999499, 8.999499, + 8.999501, 8.999501, 8.999501, 8.999499, 8.999500, 8.999501, 8.999501, + 8.999501, 8.999495, 8.999495, 8.999497}, + sd::DataType::FLOAT32); + NDArray expdLdB( + 'c', {1, 3, 4, 5}, + {7.2, 7.5, 7.8, 8.1, 8.4, 8.7, 9.0, 9.3, 9.6, 9.9, 10.2, 10.5, + 10.8, 11.1, 11.4, 11.7, 12.0, 12.3, 12.6, 12.9, 13.2, 13.5, 13.8, 14.1, + 14.4, 14.7, 15.0, 15.3, 15.6, 15.9, 16.2, 16.5, 16.8, 17.1, 17.4, 17.7, + 18.0, 18.3, 18.6, 18.9, 19.2, 19.5, 19.8, 20.1, 20.4, 20.7, 21.0, 21.3, + 21.6, 21.9, 22.2, 22.5, 22.8, 23.1, 23.4, 23.7, 24.0, 24.3, 24.6, 24.9}, + sd::DataType::FLOAT32); + + input.linspace(1, 0.01); + gradO.linspace(-0.9, 0.15); + gamma.linspace(-3, 0.1); + + // calculate mean and variance of input + PointersManager manager(input.getContext(), + "DeclarableOpsTests13.batchnorm_bp_test9"); + std::vector dimensions = {0}; + int* dims = reinterpret_cast(manager.replicatePointer( + dimensions.data(), dimensions.size() * sizeof(int))); + input.reduceAlongDimension(sd::reduce::Mean, mean, dimensions, true); + NDArray::prepareSpecialUse({&variance}, {&input}); + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions( + input.shapeInfo(), dimensions); + NativeOpExecutioner::execSummaryStats( + input.getContext(), 0, input.buffer(), input.shapeInfo(), + input.specialBuffer(), input.specialShapeInfo(), nullptr, + variance.buffer(), variance.shapeInfo(), variance.specialBuffer(), + variance.specialShapeInfo(), dims, dimensions.size(), + packX.platformShapeInfo(), packX.platformOffsets(), false); + manager.synchronize(); + NDArray::registerSpecialUse({&variance}, {&input}); + + sd::ops::batchnorm_bp op; + + auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta, &gradO}, + {1e-5}, {1, 1, 1, 2, 3}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto dLdI = results.at(0); + auto dLdG = results.at(3); + auto dLdB = results.at(4); + + ASSERT_TRUE(expdLdI.isSameShapeStrict(dLdI)); + ASSERT_TRUE(expdLdI.equalsTo(dLdI, 1e-4)); + + ASSERT_TRUE(expdLdG.isSameShapeStrict(dLdG)); + ASSERT_TRUE(expdLdG.equalsTo(dLdG)); + + ASSERT_TRUE(expdLdB.isSameShapeStrict(dLdB)); + ASSERT_TRUE(expdLdB.equalsTo(dLdB)); } diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests14.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests14.cpp index ef35bfa72988..5a47ac2254e0 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests14.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests14.cpp @@ -14,2411 +14,2704 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - // // Created by raver on 8/4/2018. // -#include "testlayers.h" -#include #include -#include #include +#include +#include +#include "testlayers.h" using namespace sd; - class DeclarableOpsTests14 : public testing::Test { -public: - - DeclarableOpsTests14() { - printf("\n"); - fflush(stdout); - } + public: + DeclarableOpsTests14() { + } }; TEST_F(DeclarableOpsTests14, Test_Validation_Edge_1) { - auto x = NDArrayFactory::create('c', {2}, {2, 2}); - auto exp = NDArrayFactory::create('c', {2, 2}, Environment::getInstance().defaultFloatDataType()); - exp.assign(4.0f); - - sd::ops::fill op; - auto result = op.evaluate({&x}, {4.0f}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); + auto x = NDArrayFactory::create('c', {2}, {2, 2}); + auto exp = NDArrayFactory::create( + 'c', {2, 2}, Environment::getInstance().defaultFloatDataType()); + exp.assign(4.0f); - ASSERT_EQ(exp, *z); + sd::ops::fill op; + auto result = op.evaluate({&x}, {4.0f}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + ASSERT_EQ(exp, z); } TEST_F(DeclarableOpsTests14, Test_Inf_Comparison_1) { - auto x = NDArrayFactory::create('c', {5}, {1, 2, 3, std::numeric_limits::infinity(), 5}); - auto y = NDArrayFactory::create('c', {5}, {1, 2, 3, std::numeric_limits::infinity(), 5}); + auto x = NDArrayFactory::create( + 'c', {5}, {1, 2, 3, std::numeric_limits::infinity(), 5}); + auto y = NDArrayFactory::create( + 'c', {5}, {1, 2, 3, std::numeric_limits::infinity(), 5}); - ASSERT_EQ(x, y); + ASSERT_EQ(x, y); } TEST_F(DeclarableOpsTests14, Test_Inf_Comparison_2) { #ifdef FFAST_MATH - if (1 > 0) - return; + if (1 > 0) return; #endif - auto x = NDArrayFactory::create('c', {5}, {1, 2, 3, std::numeric_limits::infinity(), 5}); - auto y = NDArrayFactory::create('c', {5}, {1, 2, 3, -std::numeric_limits::infinity(), 5}); + auto x = NDArrayFactory::create( + 'c', {5}, {1, 2, 3, std::numeric_limits::infinity(), 5}); + auto y = NDArrayFactory::create( + 'c', {5}, {1, 2, 3, -std::numeric_limits::infinity(), 5}); - ASSERT_NE(x, y); + ASSERT_NE(x, y); } TEST_F(DeclarableOpsTests14, Multiply_test) { + for (int k = 2; k < 10; k++) { + // nd4j_printf("k=%d\n", k); + NDArray x = NDArrayFactory::create('c', {k, 1}); + NDArray y = NDArrayFactory::create('c', {k}); + NDArray e = NDArrayFactory::create('c', {k, k}); + x.assign(1.0); + y.assign(1.0); + e.assign(1.0); - for(int k=2;k<10;k++){ - //nd4j_printf("k=%d\n", k); - NDArray x = NDArrayFactory::create('c', {k, 1}); - NDArray y = NDArrayFactory::create('c', {k}); - NDArray e = NDArrayFactory::create('c', {k, k}); - x.assign(1.0); - y.assign(1.0); - e.assign(1.0); - - sd::ops::multiply op; - auto result = op.evaluate({&x, &y}); - auto f = result.at(0); - NDArray r = *f; - - ASSERT_EQ(e, r); - ASSERT_EQ(e, *f); - + sd::ops::multiply op; + auto result = op.evaluate({&x, &y}); + auto f = result.at(0); - } + ASSERT_EQ(e, f); + } } TEST_F(DeclarableOpsTests14, Test_EvalReductionShape_1) { - auto x = NDArrayFactory::create('c', {3}, {5, 3, 4}); - auto y = NDArrayFactory::create('c', {1}, {1}); - auto e = NDArrayFactory::create('c', {2}, {5, 4}); - - sd::ops::evaluate_reduction_shape op; - auto result = op.evaluate({&x, &y}, {}, {}, {false, false}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - ASSERT_EQ(e, *z); + auto x = NDArrayFactory::create('c', {3}, {5, 3, 4}); + auto y = NDArrayFactory::create('c', {1}, {1}); + auto e = NDArrayFactory::create('c', {2}, {5, 4}); + sd::ops::evaluate_reduction_shape op; + auto result = op.evaluate({&x, &y}, {}, {}, {false, false}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + ASSERT_EQ(e, z); } TEST_F(DeclarableOpsTests14, Test_EvalReductionShape_2) { - auto x = NDArrayFactory::create('c', {3}, {5, 3, 4}); - auto y = NDArrayFactory::create('c', {1}, {1}); - auto e = NDArrayFactory::create('c', {3}, {5, 1, 4}); - - sd::ops::evaluate_reduction_shape op; - auto result = op.evaluate({&x, &y}, {}, {}, {true, false}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - ASSERT_EQ(e, *z); + auto x = NDArrayFactory::create('c', {3}, {5, 3, 4}); + auto y = NDArrayFactory::create('c', {1}, {1}); + auto e = NDArrayFactory::create('c', {3}, {5, 1, 4}); + sd::ops::evaluate_reduction_shape op; + auto result = op.evaluate({&x, &y}, {}, {}, {true, false}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + ASSERT_EQ(e, z); } TEST_F(DeclarableOpsTests14, Test_Reduce_Min_Small_0) { - auto x = NDArrayFactory::create('c', {3, 4}, {-999.f, 0.2236f, 0.7973f, 0.0962f, 0.7231f, 0.3381f, -0.7301f, 0.9115f, -0.5094f, 0.9749f, -2.1340f, 0.6023f}); - auto z = NDArrayFactory::create('c', {4}); - auto e = NDArrayFactory::create('c', {4}, {-999.f, 0.2236f, -2.1340f, 0.0962f}); + auto x = NDArrayFactory::create( + 'c', {3, 4}, + {-999.f, 0.2236f, 0.7973f, 0.0962f, 0.7231f, 0.3381f, -0.7301f, 0.9115f, + -0.5094f, 0.9749f, -2.1340f, 0.6023f}); + auto z = NDArrayFactory::create('c', {4}); + auto e = NDArrayFactory::create('c', {4}, + {-999.f, 0.2236f, -2.1340f, 0.0962f}); - sd::ops::reduce_min op; - op.execute({&x}, {&z}, {}, {0}, {}); + sd::ops::reduce_min op; + op.execute({&x}, {&z}, {}, {0}, {}); - //z.printIndexedBuffer("Z"); + // z.printIndexedBuffer("Z"); - ASSERT_EQ(e, z); + ASSERT_EQ(e, z); } TEST_F(DeclarableOpsTests14, Test_Reduce_Min_Small_1) { - auto x = NDArrayFactory::create('c', {3, 4}, {-999.f, 0.2236f, 0.7973f, 0.0962f, 0.7231f, 0.3381f, -0.7301f, 0.9115f, -0.5094f, 0.9749f, -2.1340f, 0.6023f}); - auto z = NDArrayFactory::create('c', {3}); - auto e = NDArrayFactory::create('c', {3}, {-999.f, -0.7301f, -2.1340f}); + auto x = NDArrayFactory::create( + 'c', {3, 4}, + {-999.f, 0.2236f, 0.7973f, 0.0962f, 0.7231f, 0.3381f, -0.7301f, 0.9115f, + -0.5094f, 0.9749f, -2.1340f, 0.6023f}); + auto z = NDArrayFactory::create('c', {3}); + auto e = + NDArrayFactory::create('c', {3}, {-999.f, -0.7301f, -2.1340f}); - sd::ops::reduce_min op; - op.execute({&x}, {&z}, {}, {1}, {}); + sd::ops::reduce_min op; + op.execute({&x}, {&z}, {}, {1}, {}); - //z.printIndexedBuffer("Z"); + // z.printIndexedBuffer("Z"); - ASSERT_EQ(e, z); + ASSERT_EQ(e, z); } TEST_F(DeclarableOpsTests14, Test_Diag_Zeros_1) { - auto x = NDArrayFactory::create('c', {2}, {1, 2}); - auto z = NDArrayFactory::create('c', {2, 2}, {-119, -119, -119, -119}); - auto exp = NDArrayFactory::create('c', {2, 2}, {1, 0, 0, 2}); + auto x = NDArrayFactory::create('c', {2}, {1, 2}); + auto z = + NDArrayFactory::create('c', {2, 2}, {-119, -119, -119, -119}); + auto exp = NDArrayFactory::create('c', {2, 2}, {1, 0, 0, 2}); - sd::ops::diag op; - auto status = op.execute({&x}, {&z}, {}, {}, {}); - ASSERT_EQ(Status::OK(), status); + sd::ops::diag op; + auto status = op.execute({&x}, {&z}, {}, {}, {}); + ASSERT_EQ(Status::OK(), status); - ASSERT_EQ(exp, z); + ASSERT_EQ(exp, z); } TEST_F(DeclarableOpsTests14, Test_scalar_broadcast_1) { - auto x = NDArrayFactory::create(1.0f); - auto y = NDArrayFactory::create('c', {5, 10}); - auto e = NDArrayFactory::create('c', {5, 10}); - e.assign(1.0); - - - sd::ops::add op; - auto result = op.evaluate({&x, &y}); - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_EQ(e, *result.at(0)); + auto x = NDArrayFactory::create(1.0f); + auto y = NDArrayFactory::create('c', {5, 10}); + auto e = NDArrayFactory::create('c', {5, 10}); + e.assign(1.0); + sd::ops::add op; + auto result = op.evaluate({&x, &y}); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_EQ(e, result.at(0)); } TEST_F(DeclarableOpsTests14, Test_scalar_broadcast_2) { - auto x = NDArrayFactory::create(1.0f); - auto y = NDArrayFactory::create('c', {5, 10}); - auto e = NDArrayFactory::create('c', {5, 10}); - y.assign(2.0f); - e.assign(-1.0f); - - - sd::ops::subtract op; - auto result = op.evaluate({&x, &y}); - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_EQ(e, *result.at(0)); + auto x = NDArrayFactory::create(1.0f); + auto y = NDArrayFactory::create('c', {5, 10}); + auto e = NDArrayFactory::create('c', {5, 10}); + y.assign(2.0f); + e.assign(-1.0f); + sd::ops::subtract op; + auto result = op.evaluate({&x, &y}); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_EQ(e, result.at(0)); } TEST_F(DeclarableOpsTests14, test_empty_fill_1) { - auto x = NDArrayFactory::empty(); - auto y = NDArrayFactory::create(1); - - sd::ops::fill op; - auto result = op.evaluate({&x, &y}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - ASSERT_EQ(y, *z); + auto x = NDArrayFactory::empty(); + auto y = NDArrayFactory::create(1); + sd::ops::fill op; + auto result = op.evaluate({&x, &y}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + ASSERT_EQ(y, z); } TEST_F(DeclarableOpsTests14, test_lstmBlockCell_1) { - auto a = NDArrayFactory::create('c', {1, 5}, {0.7787856f, 0.80119777f, 0.72437465f, 0.23089433f, 0.72714126f}); - auto b = NDArrayFactory::create('c', {1, 3}); - auto c = NDArrayFactory::create('c', {1, 3}); - auto d = NDArrayFactory::create('c', {8, 12}, {-0.15320599,-0.120416045,0.33126968,0.13921785,-0.32313538,-0.43956736,0.4756174,0.4335605,-0.5450856,-0.3943429,-0.28687626,0.068032146,-0.2793799,0.17298919,-0.36553562,-0.097853184,-0.2544747,-0.39872527,-0.14556861,-0.31479517,0.2559092,0.47166896,-0.31330687,0.47313118,0.5134543,-0.4678212,-0.12853557,0.26142156,0.43472284,-0.42842552,-0.1895876,0.538689,0.508651,-0.020272732,0.112327516,0.2704304,-0.046546757,0.32570732,-0.15148133,-0.19145513,0.18631572,-0.024152994,0.41603214,-0.3421499,0.0106860995,-0.2966229,-0.36713937,0.25841123,0.0843398,0.49082482,0.10800403,0.1874243,-0.26379472,-0.22531849,0.24924624,0.23119557,0.49940765,-0.051413506,0.20315129,-0.41888732,0.44097036,0.40453392,0.013338983,0.23434466,0.23942488,0.47894,-0.19898453,0.09253675,-0.032358468,-0.15213022,-0.3441009,-0.15600958,-0.08235118,0.12165731,-0.4481289,-0.4842423,-0.45797008,-0.4606034,0.08163166,-0.2981107,0.50207126,0.44195646,0.13850057,0.072246075,-0.34388685,0.030900061,0.35821778,0.47900867,0.5094063,0.23683065,0.18020362,-0.1369732,0.015235603,0.2786904,0.07954317,0.12543976}); - auto e = NDArrayFactory::create('c', {3}); - auto f = NDArrayFactory::create('c', {3}); - auto g = NDArrayFactory::create('c', {3}); - auto h = NDArrayFactory::create('c', {12}); - - auto z0 = NDArrayFactory::create('c', {1, 3}); - auto z1 = NDArrayFactory::create('c', {1, 3}); - auto z2 = NDArrayFactory::create('c', {1, 3}); - auto z3 = NDArrayFactory::create('c', {1, 3}); - auto z4 = NDArrayFactory::create('c', {1, 3}); - auto z5 = NDArrayFactory::create('c', {1, 3}); - auto z6 = NDArrayFactory::create('c', {1, 3}); - - sd::ops::lstmBlockCell op; - auto result = op.execute({&a, &b, &c, &d, &e, &f, &g, &h}, {&z0, &z1, &z2, &z3, &z4, &z5, &z6}, {1.0, -1.0}, {0}, {}); - ASSERT_EQ(Status::OK(), result); + auto a = NDArrayFactory::create( + 'c', {1, 5}, + {0.7787856f, 0.80119777f, 0.72437465f, 0.23089433f, 0.72714126f}); + auto b = NDArrayFactory::create('c', {1, 3}); + auto c = NDArrayFactory::create('c', {1, 3}); + auto d = NDArrayFactory::create( + 'c', {8, 12}, + {-0.15320599, -0.120416045, 0.33126968, 0.13921785, -0.32313538, + -0.43956736, 0.4756174, 0.4335605, -0.5450856, -0.3943429, + -0.28687626, 0.068032146, -0.2793799, 0.17298919, -0.36553562, + -0.097853184, -0.2544747, -0.39872527, -0.14556861, -0.31479517, + 0.2559092, 0.47166896, -0.31330687, 0.47313118, 0.5134543, + -0.4678212, -0.12853557, 0.26142156, 0.43472284, -0.42842552, + -0.1895876, 0.538689, 0.508651, -0.020272732, 0.112327516, + 0.2704304, -0.046546757, 0.32570732, -0.15148133, -0.19145513, + 0.18631572, -0.024152994, 0.41603214, -0.3421499, 0.0106860995, + -0.2966229, -0.36713937, 0.25841123, 0.0843398, 0.49082482, + 0.10800403, 0.1874243, -0.26379472, -0.22531849, 0.24924624, + 0.23119557, 0.49940765, -0.051413506, 0.20315129, -0.41888732, + 0.44097036, 0.40453392, 0.013338983, 0.23434466, 0.23942488, + 0.47894, -0.19898453, 0.09253675, -0.032358468, -0.15213022, + -0.3441009, -0.15600958, -0.08235118, 0.12165731, -0.4481289, + -0.4842423, -0.45797008, -0.4606034, 0.08163166, -0.2981107, + 0.50207126, 0.44195646, 0.13850057, 0.072246075, -0.34388685, + 0.030900061, 0.35821778, 0.47900867, 0.5094063, 0.23683065, + 0.18020362, -0.1369732, 0.015235603, 0.2786904, 0.07954317, + 0.12543976}); + auto e = NDArrayFactory::create('c', {3}); + auto f = NDArrayFactory::create('c', {3}); + auto g = NDArrayFactory::create('c', {3}); + auto h = NDArrayFactory::create('c', {12}); + + auto z0 = NDArrayFactory::create('c', {1, 3}); + auto z1 = NDArrayFactory::create('c', {1, 3}); + auto z2 = NDArrayFactory::create('c', {1, 3}); + auto z3 = NDArrayFactory::create('c', {1, 3}); + auto z4 = NDArrayFactory::create('c', {1, 3}); + auto z5 = NDArrayFactory::create('c', {1, 3}); + auto z6 = NDArrayFactory::create('c', {1, 3}); + + sd::ops::lstmBlockCell op; + auto result = + op.execute({&a, &b, &c, &d, &e, &f, &g, &h}, + {&z0, &z1, &z2, &z3, &z4, &z5, &z6}, {1.0, -1.0}, {0}, {}); + ASSERT_EQ(Status::OK(), result); } TEST_F(DeclarableOpsTests14, test_empty_reduce_min_1) { + auto e = NDArrayFactory::create('c', {1, 0}); + sd::ops::reduce_min sumOp; + auto res2 = sumOp.evaluate({&e}, {1.}, {1}); + ASSERT_EQ(res2.status(), Status::OK()); + auto out = res2.at(0); - auto e = NDArrayFactory::create('c', {1, 0}); - sd::ops::reduce_min sumOp; - auto res2 = sumOp.evaluate({&e}, {1.}, {1}); - ASSERT_EQ(res2.status(), Status::OK()); - auto out = res2.at(0); - - ASSERT_EQ(out->e(0), DataTypeUtils::infOrMax()); - + ASSERT_EQ(out.e(0), DataTypeUtils::infOrMax()); } TEST_F(DeclarableOpsTests14, test_empty_reduce_max_1) { + auto e = NDArrayFactory::create('c', {1, 0}); + sd::ops::reduce_max sumOp; + auto res2 = sumOp.evaluate({&e}, {1.}, {1}); + ASSERT_EQ(res2.status(), Status::OK()); + auto out = res2.at(0); - auto e = NDArrayFactory::create('c', {1, 0}); - sd::ops::reduce_max sumOp; - auto res2 = sumOp.evaluate({&e}, {1.}, {1}); - ASSERT_EQ(res2.status(), Status::OK()); - auto out = res2.at(0); - - ASSERT_EQ(out->e(0), -DataTypeUtils::infOrMax()); - + ASSERT_EQ(out.e(0), -DataTypeUtils::infOrMax()); } TEST_F(DeclarableOpsTests14, test_empty_reduce_sum_1) { #ifdef FFAST_MATH - if (1 > 0) - return; + if (1 > 0) return; #endif - auto e = NDArrayFactory::create('c', {1, 0}); - sd::ops::reduce_sum sumOp; - auto res2 = sumOp.evaluate({&e}, {1.}, {1}); - ASSERT_EQ(res2.status(), Status::OK()); - auto out = res2.at(0); - ASSERT_EQ(out->e(0), 0.f); - + auto e = NDArrayFactory::create('c', {1, 0}); + sd::ops::reduce_sum sumOp; + auto res2 = sumOp.evaluate({&e}, {1.}, {1}); + ASSERT_EQ(res2.status(), Status::OK()); + auto out = res2.at(0); + ASSERT_EQ(out.e(0), 0.f); } TEST_F(DeclarableOpsTests14, test_empty_reduce_mean_1) { #ifdef FFAST_MATH - if (1 > 0) - return; + if (1 > 0) return; #endif - auto e = NDArrayFactory::create('c', {1, 0}); - sd::ops::reduce_mean sumOp; - auto res2 = sumOp.evaluate({&e}, {1.}, {1}); - ASSERT_EQ(res2.status(), Status::OK()); - auto out = res2.at(0); - // out->printShapeInfo("ReduceMean empty shape with keep dims"); - // out->printIndexedBuffer("ReduceMean scalar"); - ASSERT_TRUE(std::isnan(out->e(0))); - + auto e = NDArrayFactory::create('c', {1, 0}); + sd::ops::reduce_mean sumOp; + auto res2 = sumOp.evaluate({&e}, {1.}, {1}); + ASSERT_EQ(res2.status(), Status::OK()); + auto out = res2.at(0); + // out->printShapeInfo("ReduceMean empty shape with keep dims"); + // out->printIndexedBuffer("ReduceMean scalar"); + ASSERT_TRUE(std::isnan(out.e(0))); } TEST_F(DeclarableOpsTests14, Test_StridedSliceZeros_1) { - auto matrix = NDArrayFactory::create('c', {1, 2, 0, 4}); - auto b = NDArrayFactory::create('c', {3}, {0, 0, 0}); - auto e = NDArrayFactory::create('c', {3}, {2,0,2}); - auto s = NDArrayFactory::create('c', {3}, {1,1,1}); - - auto exp = NDArrayFactory::create('c', {1,0,0,4}); - - matrix.linspace(1); + auto matrix = NDArrayFactory::create('c', {1, 2, 0, 4}); + auto b = NDArrayFactory::create('c', {3}, {0, 0, 0}); + auto e = NDArrayFactory::create('c', {3}, {2, 0, 2}); + auto s = NDArrayFactory::create('c', {3}, {1, 1, 1}); - sd::ops::strided_slice op; - auto result = op.evaluate({&matrix, &b, &e, &s}, {}, {0, 0, 0, 0, 0}); - ASSERT_EQ(Status::OK(), result.status()); + auto exp = NDArrayFactory::create('c', {1, 0, 0, 4}); - auto z = result.at(0); + matrix.linspace(1); - ASSERT_TRUE(exp.isSameShape(z)); + sd::ops::strided_slice op; + auto result = op.evaluate({&matrix, &b, &e, &s}, {}, {0, 0, 0, 0, 0}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); } TEST_F(DeclarableOpsTests14, Test_StridedSliceZeros_2) { - auto matrix = NDArrayFactory::create('c', {1, 2, 0, 4}); - auto b = NDArrayFactory::create('c', {3}, {0, 0, 0}); - auto e = NDArrayFactory::create('c', {3}, {2,0,2}); - auto s = NDArrayFactory::create('c', {3}, {1,1,1}); + auto matrix = NDArrayFactory::create('c', {1, 2, 0, 4}); + auto b = NDArrayFactory::create('c', {3}, {0, 0, 0}); + auto e = NDArrayFactory::create('c', {3}, {2, 0, 2}); + auto s = NDArrayFactory::create('c', {3}, {1, 1, 1}); - auto exp = NDArrayFactory::create('c', {0,0,4}); + auto exp = NDArrayFactory::create('c', {0, 0, 4}); - matrix.linspace(1); + matrix.linspace(1); - sd::ops::strided_slice op; - auto result = op.evaluate({&matrix, &b, &e, &s}, {}, {0, 0, 0, 0, 1}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); + sd::ops::strided_slice op; + auto result = op.evaluate({&matrix, &b, &e, &s}, {}, {0, 0, 0, 0, 1}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); } TEST_F(DeclarableOpsTests14, test_empty_argmax_1) { - auto x = NDArrayFactory::create('c', {1, 0}); - auto y = NDArrayFactory::create(0); - auto e = NDArrayFactory::create('c', {0}); - - sd::ops::argmax op; - //sd::ops::reduce_max op; + auto x = NDArrayFactory::create('c', {1, 0}); + auto y = NDArrayFactory::create(0); + auto e = NDArrayFactory::create('c', {0}); - auto result = op.evaluate({&x, &y}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::argmax op; + // sd::ops::reduce_max op; - auto z = result.at(0); - - ASSERT_EQ(e, *z); + auto result = op.evaluate({&x, &y}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + ASSERT_EQ(e, z); } TEST_F(DeclarableOpsTests14, test_empty_argmax_2) { - auto x = NDArrayFactory::create('c', {1, 0}); - auto y = NDArrayFactory::create(1); + auto x = NDArrayFactory::create('c', {1, 0}); + auto y = NDArrayFactory::create(1); - sd::ops::argmax op; - try { - auto result = op.execute({&x, &y}, {&y}, {}, {}, {}); - ASSERT_TRUE(false); - } catch (std::exception &e) { - // - } + sd::ops::argmax op; + try { + auto result = op.execute({&x, &y}, {&y}, {}, {}, {}); + ASSERT_TRUE(false); + } catch (std::exception &e) { + // + } } TEST_F(DeclarableOpsTests14, test_empty_tanh_5) { - auto x = NDArrayFactory::create('c', {32, 0}); - - sd::ops::tanh op; - auto result = op.evaluate({&x}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); + auto x = NDArrayFactory::create('c', {32, 0}); - ASSERT_TRUE(x.isSameShape(z)); - ASSERT_EQ(x, *z); + sd::ops::tanh op; + auto result = op.evaluate({&x}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + ASSERT_TRUE(x.isSameShape(z)); + ASSERT_EQ(x, z); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, repeat_1) { + NDArray x('c', {2, 3}, {1, 2, 3, 4, 5, 6}); + NDArray e('c', {4, 3}, {1, 2, 3, 1, 2, 3, 4, 5, 6, 4, 5, 6}); - NDArray x('c', {2, 3}, {1, 2, 3, 4, 5, 6}); - NDArray e('c', {4, 3}, {1, 2, 3, 1, 2, 3, 4, 5, 6, 4, 5, 6}); - - sd::ops::repeat op; - auto result = op.evaluate({&x}, {}, {2, 0}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(e.isSameShape(z)); - ASSERT_TRUE(e.equalsTo(z)); + sd::ops::repeat op; + auto result = op.evaluate({&x}, {}, {2, 0}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + ASSERT_TRUE(e.isSameShape(z)); + ASSERT_TRUE(e.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, repeat_2) { + NDArray x('c', {2, 3}, {1, 2, 3, 4, 5, 6}); + NDArray e('c', {2, 6}, {1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6}); - NDArray x('c', {2, 3}, {1, 2, 3, 4, 5, 6}); - NDArray e('c', {2, 6}, {1, 1, 2, 2, 3, 3,4, 4, 5, 5, 6, 6}); - - sd::ops::repeat op; - auto result = op.evaluate({&x}, {}, {2, 1}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(e.isSameShape(z)); - ASSERT_TRUE(e.equalsTo(z)); + sd::ops::repeat op; + auto result = op.evaluate({&x}, {}, {2, 1}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + ASSERT_TRUE(e.isSameShape(z)); + ASSERT_TRUE(e.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, repeat_3) { + NDArray x('c', {2, 3}, {1, 2, 3, 4, 5, 6}); + NDArray e('c', {2, 6}, {1, 2, 2, 3, 3, 3, 4, 5, 5, 6, 6, 6}); - NDArray x('c', {2, 3}, {1, 2, 3, 4, 5, 6}); - NDArray e('c', {2, 6}, {1, 2, 2, 3, 3, 3,4, 5, 5, 6, 6, 6}); - - sd::ops::repeat op; - auto result = op.evaluate({&x}, {}, {1,2,3, 1}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(e.isSameShape(z)); - ASSERT_TRUE(e.equalsTo(z)); + sd::ops::repeat op; + auto result = op.evaluate({&x}, {}, {1, 2, 3, 1}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + ASSERT_TRUE(e.isSameShape(z)); + ASSERT_TRUE(e.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, repeat_4) { + NDArray x('c', {2, 3}, {1, 2, 3, 4, 5, 6}); + NDArray e('c', {7, 3}, + {1, 2, 3, 1, 2, 3, 1, 2, 3, 4, 5, 6, 4, 5, 6, 4, 5, 6, 4, 5, 6}); - NDArray x('c', {2, 3}, {1, 2, 3, 4, 5, 6}); - NDArray e('c', {7, 3}, {1, 2, 3, 1, 2, 3, 1, 2, 3, 4, 5, 6, 4, 5, 6, 4, 5, 6, 4, 5, 6}); - - sd::ops::repeat op; - auto result = op.evaluate({&x}, {}, {3,4, 0}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(e.isSameShape(z)); - ASSERT_TRUE(e.equalsTo(z)); + sd::ops::repeat op; + auto result = op.evaluate({&x}, {}, {3, 4, 0}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + ASSERT_TRUE(e.isSameShape(z)); + ASSERT_TRUE(e.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, repeat_5) { + NDArray x('c', {2, 3, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}); + NDArray e('c', {2, 4, 4}, + {1, 2, 3, 4, 5, 6, 7, 8, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 17, 18, 19, 20, 21, 22, 23, 24}); - NDArray x('c', {2, 3, 4}, {1, 2, 3, 4, 5, 6, 7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24}); - NDArray e('c', {2, 4, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 17, 18, 19, 20, 21, 22, 23, 24}); - - sd::ops::repeat op; - auto result = op.evaluate({&x}, {}, {1,2,1, 1}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(e.isSameShape(z)); - ASSERT_TRUE(e.equalsTo(z)); + sd::ops::repeat op; + auto result = op.evaluate({&x}, {}, {1, 2, 1, 1}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + ASSERT_TRUE(e.isSameShape(z)); + ASSERT_TRUE(e.equalsTo(z)); } ///////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, Test_broadcast_SpecialCaseTest) { + auto y = NDArray('c', {3}, sd::DataType::FLOAT32); + auto x = NDArray('c', {5, 2, 1}, sd::DataType::FLOAT32); - auto y = NDArray('c', { 3 }, sd::DataType::FLOAT32); - auto x = NDArray('c', { 5, 2, 1 }, sd::DataType::FLOAT32); + auto e = NDArray( + 'c', {5, 2, 3}, + {2., 2., 2., 3., 3., 3., 4., 4., 4., 5., 5., 5., 6., 6., 6., + 7., 7., 7., 8., 8., 8., 9., 9., 9., 10., 10., 10., 11., 11., 11.}, + sd::DataType::FLOAT32); - auto e = NDArray('c', { 5, 2, 3 }, { 2., 2., 2., 3., 3., 3., 4., 4., 4., 5., 5., 5., 6., 6., 6., 7., 7., 7., 8., 8., 8., 9., 9., 9., 10., 10., 10., 11., 11., 11. }, sd::DataType::FLOAT32); + y.assign(1.0); + x.linspace(1.0); - y.assign(1.0); - x.linspace(1.0); - - sd::ops::add op; - auto result = op.evaluate({ &x, &y }); - ASSERT_EQ(Status::OK(), result.status()); - - auto res = *result.at(0); - - ASSERT_EQ(e, res); + sd::ops::add op; + auto result = op.evaluate({&x, &y}); + ASSERT_EQ(Status::OK(), result.status()); + auto res = result.at(0); + ASSERT_EQ(e, res); } ///////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, Test_broadcast_SpecialCaseTest2) { + auto y = NDArray('c', {1, 3}, sd::DataType::FLOAT32); + auto x = NDArray('c', {5, 2, 1}, sd::DataType::FLOAT32); - auto y = NDArray('c', { 1, 3 }, sd::DataType::FLOAT32); - auto x = NDArray('c', { 5, 2, 1 }, sd::DataType::FLOAT32); - - auto e = NDArray('c', { 5, 2, 3 }, { 2., 2., 2., 3., 3., 3., 4., 4., 4., 5., 5., 5., 6., 6., 6., 7., 7., 7., 8., 8., 8., 9., 9., 9., 10., 10., 10., 11., 11., 11. }, sd::DataType::FLOAT32); + auto e = NDArray( + 'c', {5, 2, 3}, + {2., 2., 2., 3., 3., 3., 4., 4., 4., 5., 5., 5., 6., 6., 6., + 7., 7., 7., 8., 8., 8., 9., 9., 9., 10., 10., 10., 11., 11., 11.}, + sd::DataType::FLOAT32); - y.assign(1.0); - x.linspace(1.0); - - sd::ops::add op; - auto result = op.evaluate({ &x, &y }); - ASSERT_EQ(Status::OK(), result.status()); - - auto res = *result.at(0); + y.assign(1.0); + x.linspace(1.0); - ASSERT_EQ(e, res); + sd::ops::add op; + auto result = op.evaluate({&x, &y}); + ASSERT_EQ(Status::OK(), result.status()); + auto res = result.at(0); + ASSERT_EQ(e, res); } /////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, Test_broadcast_SpecialCaseTest3) { - - auto x = NDArray('c', { 3, 5, 1 }, sd::DataType::FLOAT32); - auto y = NDArray('c', { 3, 1, 4 }, sd::DataType::FLOAT32); - auto z = NDArray('c', { 3, 5, 4 }, sd::DataType::FLOAT32); - // recieved by main algorithm - auto e = NDArray('c', { 3, 5, 4 }, { 10., 11., 12., 13., 20., 22., 24., 26., 30., 33., 36., 39., 40., 44., 48., 52., 50., 55., 60., 65., 84., 90., 96., 102., 98., 105., 112., 119., 112., 120., 128., 136., 126., 135., 144., 153., 140., 150., 160., 170., 198., 209., 220., 231., 216., 228., 240., 252., 234., 247., 260., 273., 252., 266., 280., 294., 270., 285., 300., 315. }, sd::DataType::FLOAT32); - - x.linspace(1.f); - y.linspace(10.f); - z.assign(0.f); - - x.applyTrueBroadcast(BroadcastOpsTuple::Multiply(), y, z); - ASSERT_EQ(e, z); + auto x = NDArray('c', {3, 5, 1}, sd::DataType::FLOAT32); + auto y = NDArray('c', {3, 1, 4}, sd::DataType::FLOAT32); + auto z = NDArray('c', {3, 5, 4}, sd::DataType::FLOAT32); + // recieved by main algorithm + auto e = NDArray( + 'c', {3, 5, 4}, + {10., 11., 12., 13., 20., 22., 24., 26., 30., 33., 36., 39., + 40., 44., 48., 52., 50., 55., 60., 65., 84., 90., 96., 102., + 98., 105., 112., 119., 112., 120., 128., 136., 126., 135., 144., 153., + 140., 150., 160., 170., 198., 209., 220., 231., 216., 228., 240., 252., + 234., 247., 260., 273., 252., 266., 280., 294., 270., 285., 300., 315.}, + sd::DataType::FLOAT32); + + x.linspace(1.f); + y.linspace(10.f); + z.assign(0.f); + + x.applyTrueBroadcast(BroadcastOpsTuple::Multiply(), y, z); + ASSERT_EQ(e, z); } /////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, Test_broadcast_SpecialCaseTest4) { - - auto x = NDArray('c', { 2, 3, 5, 1 }, sd::DataType::FLOAT32); - auto y = NDArray('c', { 2, 3, 1, 4 }, sd::DataType::FLOAT32); - auto z = NDArray('c', { 2, 3, 5, 4 }, sd::DataType::FLOAT32); - // recieved by main algorithm - auto e = NDArray('c', { 2, 3, 5, 4 }, { 10., 11., 12., 13.,20., 22., 24., 26.,30., 33., 36., 39.,40., 44., 48., 52.,50., 55., 60., 65.,84., 90., 96., 102.,98., 105., 112., 119.,112., 120., 128., 136.,126., 135., 144., 153.,140., 150., 160., 170.,198., 209., 220., 231.,216., 228., 240., 252.,234., 247., 260., 273.,252., 266., 280., 294.,270., 285., 300., 315.,352., 368., 384., 400.,374., 391., 408., 425.,396., 414., 432., 450.,418., 437., 456., 475.,440., 460., 480., 500.,546., 567., 588., 609.,572., 594., 616., 638.,598., 621., 644., 667.,624., 648., 672., 696.,650., 675., 700., 725.,780., 806., 832., 858.,810., 837., 864., 891.,840., 868., 896., 924.,870., 899., 928., 957.,900., 930., 960., 990. }, sd::DataType::FLOAT32); - x.linspace(1.f); - y.linspace(10.f); - z.assign(0.f); - - x.applyTrueBroadcast(BroadcastOpsTuple::Multiply(), y, z); - ASSERT_EQ(e, z); + auto x = NDArray('c', {2, 3, 5, 1}, sd::DataType::FLOAT32); + auto y = NDArray('c', {2, 3, 1, 4}, sd::DataType::FLOAT32); + auto z = NDArray('c', {2, 3, 5, 4}, sd::DataType::FLOAT32); + // recieved by main algorithm + auto e = NDArray( + 'c', {2, 3, 5, 4}, + {10., 11., 12., 13., 20., 22., 24., 26., 30., 33., 36., 39., + 40., 44., 48., 52., 50., 55., 60., 65., 84., 90., 96., 102., + 98., 105., 112., 119., 112., 120., 128., 136., 126., 135., 144., 153., + 140., 150., 160., 170., 198., 209., 220., 231., 216., 228., 240., 252., + 234., 247., 260., 273., 252., 266., 280., 294., 270., 285., 300., 315., + 352., 368., 384., 400., 374., 391., 408., 425., 396., 414., 432., 450., + 418., 437., 456., 475., 440., 460., 480., 500., 546., 567., 588., 609., + 572., 594., 616., 638., 598., 621., 644., 667., 624., 648., 672., 696., + 650., 675., 700., 725., 780., 806., 832., 858., 810., 837., 864., 891., + 840., 868., 896., 924., 870., 899., 928., 957., 900., 930., 960., 990.}, + sd::DataType::FLOAT32); + x.linspace(1.f); + y.linspace(10.f); + z.assign(0.f); + + x.applyTrueBroadcast(BroadcastOpsTuple::Multiply(), y, z); + ASSERT_EQ(e, z); } /////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, Test_broadcast_SpecialCaseTest5) { - - auto x = NDArray('c', { 3, 5, 1 }, sd::DataType::FLOAT32); - auto y = NDArray('c', { 3, 1, 4 }, sd::DataType::FLOAT32); - auto z = NDArray('c', { 3, 5, 4 }, sd::DataType::FLOAT32); - // recieved by main algorithm - auto e = NDArray('c', { 3, 5, 4 }, { 0.1, 0.090909, 0.083333, 0.076923,0.2, 0.181818, 0.166667, 0.153846,0.3, 0.272727, 0.250000, 0.230769,0.4, 0.363636, 0.333333, 0.307692,0.5, 0.454545, 0.416667, 0.384615, 0.428571, 0.400000, 0.375000, 0.352941, 0.500000, 0.466667, 0.437500, 0.411765, 0.571429, 0.533333, 0.500000, 0.470588, 0.642857, 0.600000, 0.562500, 0.529412, 0.714286, 0.666667, 0.625000, 0.588235, 0.611111, 0.578947, 0.550000, 0.523810, 0.666667, 0.631579, 0.600000, 0.571429, 0.722222, 0.684211, 0.650000, 0.619048, 0.777778, 0.736842, 0.700000, 0.666667, 0.833333, 0.789474, 0.750000, 0.714286 }, sd::DataType::FLOAT32); - x.linspace(1.f); - y.linspace(10.f); - z.assign(0.f); - - x.applyTrueBroadcast(BroadcastOpsTuple::Divide(), y, z); - ASSERT_EQ(e, z); + auto x = NDArray('c', {3, 5, 1}, sd::DataType::FLOAT32); + auto y = NDArray('c', {3, 1, 4}, sd::DataType::FLOAT32); + auto z = NDArray('c', {3, 5, 4}, sd::DataType::FLOAT32); + // recieved by main algorithm + auto e = NDArray( + 'c', {3, 5, 4}, + {0.1, 0.090909, 0.083333, 0.076923, 0.2, 0.181818, 0.166667, + 0.153846, 0.3, 0.272727, 0.250000, 0.230769, 0.4, 0.363636, + 0.333333, 0.307692, 0.5, 0.454545, 0.416667, 0.384615, 0.428571, + 0.400000, 0.375000, 0.352941, 0.500000, 0.466667, 0.437500, 0.411765, + 0.571429, 0.533333, 0.500000, 0.470588, 0.642857, 0.600000, 0.562500, + 0.529412, 0.714286, 0.666667, 0.625000, 0.588235, 0.611111, 0.578947, + 0.550000, 0.523810, 0.666667, 0.631579, 0.600000, 0.571429, 0.722222, + 0.684211, 0.650000, 0.619048, 0.777778, 0.736842, 0.700000, 0.666667, + 0.833333, 0.789474, 0.750000, 0.714286}, + sd::DataType::FLOAT32); + x.linspace(1.f); + y.linspace(10.f); + z.assign(0.f); + + x.applyTrueBroadcast(BroadcastOpsTuple::Divide(), y, z); + ASSERT_EQ(e, z); } /////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, Test_broadcast_SpecialCaseTest6) { - - auto x = NDArray('c', { 2, 3, 5, 1 }, sd::DataType::FLOAT32); - auto y = NDArray('c', { 2, 3, 1, 4 }, sd::DataType::FLOAT32); - auto z = NDArray('c', { 2, 3, 5, 4 }, sd::DataType::FLOAT32); - // recieved by main algorithm - auto e = NDArray('c', { 2, 3, 5, 4 }, { 0.1, 0.090909, 0.083333, 0.076923,0.2, 0.181818, 0.166667, 0.153846,0.3, 0.272727, 0.250000, 0.230769,0.4, 0.363636, 0.333333, 0.307692,0.5, 0.454545, 0.416667, 0.384615, 0.428571, 0.400000, 0.375000, 0.352941, 0.500000, 0.466667, 0.437500, 0.411765, 0.571429, 0.533333, 0.500000, 0.470588, 0.642857, 0.600000, 0.562500, 0.529412, 0.714286, 0.666667, 0.625000, 0.588235,0.611111, 0.578947, 0.550000, 0.523810,0.666667, 0.631579, 0.600000, 0.571429,0.722222, 0.684211, 0.650000, 0.619048,0.777778, 0.736842, 0.700000, 0.666667,0.833333, 0.789474, 0.750000, 0.714286, 0.727273, 0.695652, 0.666667, 0.64, 0.772727, 0.739130, 0.708333, 0.68, 0.818182, 0.782609, 0.750000, 0.72, 0.863636, 0.826087, 0.791667, 0.76, 0.909091, 0.869565, 0.833333, 0.80, 0.807692, 0.777778, 0.750000, 0.724138, 0.846154, 0.814815, 0.785714, 0.758621, 0.884615, 0.851852, 0.821429, 0.793103, 0.923077, 0.888889, 0.857143, 0.827586, 0.961538, 0.925926, 0.892857, 0.862069, 0.866667, 0.838710, 0.812500, 0.787879, 0.900000, 0.870968, 0.843750, 0.818182, 0.933333, 0.903226, 0.875000, 0.848485, 0.966667, 0.935484, 0.906250, 0.878788, 1.000000, 0.967742, 0.937500, 0.909091 }, sd::DataType::FLOAT32); - - x.linspace(1.f); - y.linspace(10.f); - z.assign(0.f); - - x.applyTrueBroadcast(BroadcastOpsTuple::Divide(), y, z); - ASSERT_EQ(e, z); + auto x = NDArray('c', {2, 3, 5, 1}, sd::DataType::FLOAT32); + auto y = NDArray('c', {2, 3, 1, 4}, sd::DataType::FLOAT32); + auto z = NDArray('c', {2, 3, 5, 4}, sd::DataType::FLOAT32); + // recieved by main algorithm + auto e = NDArray( + 'c', {2, 3, 5, 4}, + {0.1, 0.090909, 0.083333, 0.076923, 0.2, 0.181818, 0.166667, + 0.153846, 0.3, 0.272727, 0.250000, 0.230769, 0.4, 0.363636, + 0.333333, 0.307692, 0.5, 0.454545, 0.416667, 0.384615, 0.428571, + 0.400000, 0.375000, 0.352941, 0.500000, 0.466667, 0.437500, 0.411765, + 0.571429, 0.533333, 0.500000, 0.470588, 0.642857, 0.600000, 0.562500, + 0.529412, 0.714286, 0.666667, 0.625000, 0.588235, 0.611111, 0.578947, + 0.550000, 0.523810, 0.666667, 0.631579, 0.600000, 0.571429, 0.722222, + 0.684211, 0.650000, 0.619048, 0.777778, 0.736842, 0.700000, 0.666667, + 0.833333, 0.789474, 0.750000, 0.714286, 0.727273, 0.695652, 0.666667, + 0.64, 0.772727, 0.739130, 0.708333, 0.68, 0.818182, 0.782609, + 0.750000, 0.72, 0.863636, 0.826087, 0.791667, 0.76, 0.909091, + 0.869565, 0.833333, 0.80, 0.807692, 0.777778, 0.750000, 0.724138, + 0.846154, 0.814815, 0.785714, 0.758621, 0.884615, 0.851852, 0.821429, + 0.793103, 0.923077, 0.888889, 0.857143, 0.827586, 0.961538, 0.925926, + 0.892857, 0.862069, 0.866667, 0.838710, 0.812500, 0.787879, 0.900000, + 0.870968, 0.843750, 0.818182, 0.933333, 0.903226, 0.875000, 0.848485, + 0.966667, 0.935484, 0.906250, 0.878788, 1.000000, 0.967742, 0.937500, + 0.909091}, + sd::DataType::FLOAT32); + + x.linspace(1.f); + y.linspace(10.f); + z.assign(0.f); + + x.applyTrueBroadcast(BroadcastOpsTuple::Divide(), y, z); + ASSERT_EQ(e, z); } /////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, Test_broadcast_SpecialCaseTest7) { - - auto x = NDArray('c', { 3, 5, 1 }, sd::DataType::FLOAT32); - auto y = NDArray('c', { 3, 1, 4 }, sd::DataType::FLOAT32); - auto z = NDArray('c', { 3, 5, 4 }, sd::DataType::FLOAT32); - // recieved by main algorithm - auto e = NDArray('c', { 3, 5, 4 }, { -9., -10., -11., -12.,-8., -9., -10., -11., -7., -8., -9., -10.,-6., -7., -8., -9.,-5., -6., -7., -8.,-8., -9., -10., -11.,-7., -8., -9., -10.,-6., -7., -8., -9.,-5., -6., -7., -8.,-4., -5., -6., -7.,-7., -8.000000, -9.000000, -10.00,-6.000000, -7.000000, -8.000000, -9.000,-5.000000, -6.000000, -7.000000, -8.000,-4.000000, -5.000000, -6.000000, -7.000,-3.000000, -4.000000, -5.000000, -6.000 }, sd::DataType::FLOAT32); - x.linspace(1.f); - y.linspace(10.f); - z.assign(0.f); - - x.applyTrueBroadcast(BroadcastOpsTuple::Subtract(), y, z); - ASSERT_EQ(e, z); + auto x = NDArray('c', {3, 5, 1}, sd::DataType::FLOAT32); + auto y = NDArray('c', {3, 1, 4}, sd::DataType::FLOAT32); + auto z = NDArray('c', {3, 5, 4}, sd::DataType::FLOAT32); + // recieved by main algorithm + auto e = + NDArray('c', {3, 5, 4}, + {-9., -10., -11., -12., -8., -9., + -10., -11., -7., -8., -9., -10., + -6., -7., -8., -9., -5., -6., + -7., -8., -8., -9., -10., -11., + -7., -8., -9., -10., -6., -7., + -8., -9., -5., -6., -7., -8., + -4., -5., -6., -7., -7., -8.000000, + -9.000000, -10.00, -6.000000, -7.000000, -8.000000, -9.000, + -5.000000, -6.000000, -7.000000, -8.000, -4.000000, -5.000000, + -6.000000, -7.000, -3.000000, -4.000000, -5.000000, -6.000}, + sd::DataType::FLOAT32); + x.linspace(1.f); + y.linspace(10.f); + z.assign(0.f); + + x.applyTrueBroadcast(BroadcastOpsTuple::Subtract(), y, z); + ASSERT_EQ(e, z); } /////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, Test_broadcast_SpecialCaseTest8) { - - auto x = NDArray('c', { 2, 3, 5, 1 }, sd::DataType::FLOAT32); - auto y = NDArray('c', { 2, 3, 1, 4 }, sd::DataType::FLOAT32); - auto z = NDArray('c', { 2, 3, 5, 4 }, sd::DataType::FLOAT32); - // recieved by main algorithm - auto e = NDArray('c', { 2, 3, 5, 4 }, { -9.0, -10., -11., -12.,-8., -9., -10., -11.0,-7., -8., -9., -10.,-6., -7., -8., -9.,-5., -6., -7., -8.,-8., -9., -10., -11.,-7., -8., -9., -10.,-6., -7., -8., -9.,-5., -6., -7., -8.,-4., -5., -6., -7.,-7., -8., -9., -10.,-6., -7., -8., -9.,-5., -6., -7., -8.,-4., -5., -6., -7.,-3., -4., -5., -6.,-6., -7., -8., -9.,-5., -6., -7., -8.,-4., -5., -6., -7.,-3., -4., -5., -6.,-2., -3., -4., -5.,-5., -6., -7., -8.,-4., -5., -6., -7.,-3., -4., -5., -6.,-2., -3., -4., -5.,-1., -2., -3., -4.,-4., -5., -6., -7.,-3., -4., -5., -6.,-2., -3., -4., -5.,-1., -2., -3., -4., 0., -1., -2., -3. }, sd::DataType::FLOAT32); - - x.linspace(1.f); - y.linspace(10.f); - z.assign(0.f); - - x.applyTrueBroadcast(BroadcastOpsTuple::Subtract(), y, z); - ASSERT_EQ(e, z); + auto x = NDArray('c', {2, 3, 5, 1}, sd::DataType::FLOAT32); + auto y = NDArray('c', {2, 3, 1, 4}, sd::DataType::FLOAT32); + auto z = NDArray('c', {2, 3, 5, 4}, sd::DataType::FLOAT32); + // recieved by main algorithm + auto e = NDArray( + 'c', {2, 3, 5, 4}, + {-9.0, -10., -11., -12., -8., -9., -10., -11.0, -7., -8., -9., -10., + -6., -7., -8., -9., -5., -6., -7., -8., -8., -9., -10., -11., + -7., -8., -9., -10., -6., -7., -8., -9., -5., -6., -7., -8., + -4., -5., -6., -7., -7., -8., -9., -10., -6., -7., -8., -9., + -5., -6., -7., -8., -4., -5., -6., -7., -3., -4., -5., -6., + -6., -7., -8., -9., -5., -6., -7., -8., -4., -5., -6., -7., + -3., -4., -5., -6., -2., -3., -4., -5., -5., -6., -7., -8., + -4., -5., -6., -7., -3., -4., -5., -6., -2., -3., -4., -5., + -1., -2., -3., -4., -4., -5., -6., -7., -3., -4., -5., -6., + -2., -3., -4., -5., -1., -2., -3., -4., 0., -1., -2., -3.}, + sd::DataType::FLOAT32); + + x.linspace(1.f); + y.linspace(10.f); + z.assign(0.f); + + x.applyTrueBroadcast(BroadcastOpsTuple::Subtract(), y, z); + ASSERT_EQ(e, z); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, matmul_test1) { + auto x = NDArrayFactory::create('c', {3, 4}); + auto y = NDArrayFactory::create('c', {4, 3}); + auto exp = NDArrayFactory::create( + 'f', {3, 3}, {35., 79., 123., 40., 92., 144., 45., 105., 165.}); - auto x = NDArrayFactory::create('c', {3, 4}); - auto y = NDArrayFactory::create('c', {4, 3}); - auto exp = NDArrayFactory::create('f', {3, 3}, {35., 79., 123., 40., 92., 144., 45., 105., 165.}); - - x.linspace(1.); - y.linspace(0.5, 0.5); - - sd::ops::matmul op; - auto results = op.evaluate({&x, &y}, {}, {}); - auto z = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + x.linspace(1.); + y.linspace(0.5, 0.5); + sd::ops::matmul op; + auto results = op.evaluate({&x, &y}, {}, {}); + auto z = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, matmul_test2) { + auto x = NDArrayFactory::create('c', {3, 4}); + auto y = NDArrayFactory::create('f', {4, 3}); + auto exp = NDArrayFactory::create( + 'f', {3, 3}, {35., 79., 123., 40., 92., 144., 45., 105., 165.}); - auto x = NDArrayFactory::create('c', {3, 4}); - auto y = NDArrayFactory::create('f', {4, 3}); - auto exp = NDArrayFactory::create('f', {3, 3}, {35., 79., 123.,40., 92., 144.,45.,105., 165.}); - - x.linspace(1.); - y.linspace(0.5, 0.5); - - sd::ops::matmul op; - auto results = op.evaluate({&x, &y}, {}, {}); - auto z = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + x.linspace(1.); + y.linspace(0.5, 0.5); + sd::ops::matmul op; + auto results = op.evaluate({&x, &y}, {}, {}); + auto z = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, matmul_test3) { + auto x = NDArrayFactory::create('f', {3, 4}); + auto y = NDArrayFactory::create('c', {4, 3}); + auto exp = NDArrayFactory::create( + 'f', {3, 3}, {35., 79., 123., 40., 92., 144., 45., 105., 165.}); - auto x = NDArrayFactory::create('f', {3, 4}); - auto y = NDArrayFactory::create('c', {4, 3}); - auto exp = NDArrayFactory::create('f', {3, 3}, {35., 79., 123.,40., 92., 144.,45.,105., 165.}); - - x.linspace(1.); - y.linspace(0.5, 0.5); - - sd::ops::matmul op; - auto results = op.evaluate({&x, &y}, {}, {}); - auto z = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + x.linspace(1.); + y.linspace(0.5, 0.5); + sd::ops::matmul op; + auto results = op.evaluate({&x, &y}, {}, {}); + auto z = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, matmul_test4) { + auto x = NDArrayFactory::create('f', {3, 4}); + auto y = NDArrayFactory::create('f', {4, 3}); + auto exp = NDArrayFactory::create( + 'f', {3, 3}, {35., 79., 123., 40., 92., 144., 45., 105., 165.}); - auto x = NDArrayFactory::create ('f', {3, 4}); - auto y = NDArrayFactory::create('f', {4, 3}); - auto exp = NDArrayFactory::create('f', {3, 3}, {35., 79., 123.,40., 92., 144.,45.,105., 165.}); - - x.linspace(1.); - y.linspace(0.5, 0.5); - - sd::ops::matmul op; - auto results = op.evaluate({&x, &y}, {}, {}); - auto z = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + x.linspace(1.); + y.linspace(0.5, 0.5); + sd::ops::matmul op; + auto results = op.evaluate({&x, &y}, {}, {}); + auto z = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, matmul_test5) { + auto x = NDArrayFactory::create('c', {4, 3}); + auto y = NDArrayFactory::create('c', {4, 3}); + auto exp = NDArrayFactory::create( + 'f', {3, 3}, {83., 94., 105., 94., 107., 120., 105., 120., 135.}); - auto x = NDArrayFactory::create('c', {4, 3}); - auto y = NDArrayFactory::create('c', {4, 3}); - auto exp = NDArrayFactory::create('f', {3, 3}, {83., 94., 105., 94., 107., 120., 105., 120., 135.}); - - x.linspace(1.); - y.linspace(0.5, 0.5); - - sd::ops::matmul op; - auto results = op.evaluate({&x, &y}, {}, {1}); - auto z = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + x.linspace(1.); + y.linspace(0.5, 0.5); + sd::ops::matmul op; + auto results = op.evaluate({&x, &y}, {}, {1}); + auto z = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, matmul_test6) { + auto x = NDArrayFactory::create('c', {4, 3}); + auto y = NDArrayFactory::create('f', {3, 4}); + auto exp = NDArrayFactory::create( + 'f', {3, 3}, {35., 40., 45., 79., 92., 105., 123., 144., 165.}); - auto x = NDArrayFactory::create('c', {4, 3}); - auto y = NDArrayFactory::create('f', {3, 4}); - auto exp = NDArrayFactory::create('f', {3, 3}, {35., 40., 45., 79., 92., 105., 123., 144., 165.}); - - x.linspace(1.); - y.linspace(0.5, 0.5); - - sd::ops::matmul op; - auto results = op.evaluate({&x, &y}, {}, {1, 1}); - auto z = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + x.linspace(1.); + y.linspace(0.5, 0.5); + sd::ops::matmul op; + auto results = op.evaluate({&x, &y}, {}, {1, 1}); + auto z = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, matmul_test7) { - - auto x = NDArrayFactory::create('c', {5, 3,4}); - auto y = NDArrayFactory::create('f', {5, 3,4}); - auto exp = NDArrayFactory::create('f',{5, 3,3}, {3. , 84.6, 281.4, 593.4, 1020.6, 7. , 107.8, 323.8, 655. , 1101.4,11. , 131. , 366.2, 716.6, 1182.2, - 7. , 107.8, 323.8, 655. , 1101.4,17.4, 137.4, 372.6, 723. , 1188.6,27.8, 167. , 421.4, 791. , 1275.8, - 11. , 131. , 366.2, 716.6, 1182.2,27.8, 167. , 421.4, 791. , 1275.8,44.6, 203. , 476.6, 865.4, 1369.4,}); - - x.linspace(1.); - y.linspace(0.1, 0.1); - - sd::ops::matmul op; - auto results = op.evaluate({&x, &y}, {}, {0, 1}); - auto z = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - - + auto x = NDArrayFactory::create('c', {5, 3, 4}); + auto y = NDArrayFactory::create('f', {5, 3, 4}); + auto exp = NDArrayFactory::create( + 'f', {5, 3, 3}, + { + 3., 84.6, 281.4, 593.4, 1020.6, 7., 107.8, 323.8, 655., 1101.4, + 11., 131., 366.2, 716.6, 1182.2, 7., 107.8, 323.8, 655., 1101.4, + 17.4, 137.4, 372.6, 723., 1188.6, 27.8, 167., 421.4, 791., 1275.8, + 11., 131., 366.2, 716.6, 1182.2, 27.8, 167., 421.4, 791., 1275.8, + 44.6, 203., 476.6, 865.4, 1369.4, + }); + + x.linspace(1.); + y.linspace(0.1, 0.1); + + sd::ops::matmul op; + auto results = op.evaluate({&x, &y}, {}, {0, 1}); + auto z = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, matmul_test8) { - - auto x = NDArrayFactory::create('c', {2,5, 3,4}); - auto y = NDArrayFactory::create('f', {2,5, 3,4}); - auto exp = NDArrayFactory::create('f',{2,5, 3,3}, {3. , 1563. , 84.6, 2220.6, 281.4, 2993.4, 593.4, 3881.4,1020.6, 4884.6, 7. , 1663. , 107.8, 2339.8, 323.8, 3131.8, 655. , 4039. ,1101.4, 5061.4, - 11. , 1763. , 131. , 2459. , 366.2, 3270.2, 716.6, 4196.6,1182.2, 5238.2, 7. , 1663. , 107.8, 2339.8, 323.8, 3131.8, 655. , 4039. ,1101.4, 5061.4, - 17.4, 1769.4, 137.4, 2465.4, 372.6, 3276.6, 723. , 4203. ,1188.6, 5244.6, 27.8, 1875.8, 167. , 2591. , 421.4, 3421.4, 791. , 4367. ,1275.8, 5427.8, - 11. , 1763. , 131. , 2459. , 366.2, 3270.2, 716.6, 4196.6,1182.2, 5238.2, 27.8, 1875.8, 167. , 2591. , 421.4, 3421.4, 791. , 4367. ,1275.8, 5427.8, - 44.6, 1988.6, 203. , 2723. , 476.6, 3572.6, 865.4, 4537.4,1369.4, 5617.4}); - - x.linspace(1.); - y.linspace(0.1, 0.1); - - sd::ops::matmul op; - auto results = op.evaluate({&x, &y}, {}, {0, 1}); - auto z = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - - + auto x = NDArrayFactory::create('c', {2, 5, 3, 4}); + auto y = NDArrayFactory::create('f', {2, 5, 3, 4}); + auto exp = NDArrayFactory::create( + 'f', {2, 5, 3, 3}, + {3., 1563., 84.6, 2220.6, 281.4, 2993.4, 593.4, 3881.4, 1020.6, + 4884.6, 7., 1663., 107.8, 2339.8, 323.8, 3131.8, 655., 4039., + 1101.4, 5061.4, 11., 1763., 131., 2459., 366.2, 3270.2, 716.6, + 4196.6, 1182.2, 5238.2, 7., 1663., 107.8, 2339.8, 323.8, 3131.8, + 655., 4039., 1101.4, 5061.4, 17.4, 1769.4, 137.4, 2465.4, 372.6, + 3276.6, 723., 4203., 1188.6, 5244.6, 27.8, 1875.8, 167., 2591., + 421.4, 3421.4, 791., 4367., 1275.8, 5427.8, 11., 1763., 131., + 2459., 366.2, 3270.2, 716.6, 4196.6, 1182.2, 5238.2, 27.8, 1875.8, + 167., 2591., 421.4, 3421.4, 791., 4367., 1275.8, 5427.8, 44.6, + 1988.6, 203., 2723., 476.6, 3572.6, 865.4, 4537.4, 1369.4, 5617.4}); + + x.linspace(1.); + y.linspace(0.1, 0.1); + + sd::ops::matmul op; + auto results = op.evaluate({&x, &y}, {}, {0, 1}); + auto z = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, matmul_test9) { - - auto x = NDArrayFactory::create('c', {2,5, 4,3}); - auto y = NDArrayFactory::create('f', {2,5, 3,4}); - auto exp = NDArrayFactory::create('f',{2,5, 3,3}, {7. , 1639. , 103. , 2311. , 314.2, 3098.2, 640.6, 4000.6,1082.2, 5018.2, 8. , 1664. , 108.8, 2340.8, 324.8, 3132.8, 656. , 4040. ,1102.4, 5062.4, - 9. , 1689. , 114.6, 2370.6, 335.4, 3167.4, 671.4, 4079.4,1122.6, 5106.6, 15.8, 1743.8, 131. , 2435. , 361.4, 3241.4, 707. , 4163. ,1167.8, 5199.8, - 18.4, 1770.4, 138.4, 2466.4, 373.6, 3277.6, 724. , 4204. ,1189.6, 5245.6, 21. , 1797. , 145.8, 2497.8, 385.8, 3313.8, 741. , 4245. ,1211.4, 5291.4, - 24.6, 1848.6, 159. , 2559. , 408.6, 3384.6, 773.4, 4325.4,1253.4, 5381.4, 28.8, 1876.8, 168. , 2592. , 422.4, 3422.4, 792. , 4368. ,1276.8, 5428.8, - 33. , 1905. , 177. , 2625. , 436.2, 3460.2, 810.6, 4410.6,1300.2, 5476.2}); - - x.linspace(1.); - y.linspace(0.1, 0.1); - - sd::ops::matmul op; - auto results = op.evaluate({&x, &y}, {}, {1, 1}); - auto z = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - - + auto x = NDArrayFactory::create('c', {2, 5, 4, 3}); + auto y = NDArrayFactory::create('f', {2, 5, 3, 4}); + auto exp = NDArrayFactory::create( + 'f', {2, 5, 3, 3}, + {7., 1639., 103., 2311., 314.2, 3098.2, 640.6, 4000.6, 1082.2, + 5018.2, 8., 1664., 108.8, 2340.8, 324.8, 3132.8, 656., 4040., + 1102.4, 5062.4, 9., 1689., 114.6, 2370.6, 335.4, 3167.4, 671.4, + 4079.4, 1122.6, 5106.6, 15.8, 1743.8, 131., 2435., 361.4, 3241.4, + 707., 4163., 1167.8, 5199.8, 18.4, 1770.4, 138.4, 2466.4, 373.6, + 3277.6, 724., 4204., 1189.6, 5245.6, 21., 1797., 145.8, 2497.8, + 385.8, 3313.8, 741., 4245., 1211.4, 5291.4, 24.6, 1848.6, 159., + 2559., 408.6, 3384.6, 773.4, 4325.4, 1253.4, 5381.4, 28.8, 1876.8, + 168., 2592., 422.4, 3422.4, 792., 4368., 1276.8, 5428.8, 33., + 1905., 177., 2625., 436.2, 3460.2, 810.6, 4410.6, 1300.2, 5476.2}); + + x.linspace(1.); + y.linspace(0.1, 0.1); + + sd::ops::matmul op; + auto results = op.evaluate({&x, &y}, {}, {1, 1}); + auto z = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests14, matmul_test10) { + auto x = NDArrayFactory::create('c', {3, 5}); + x.linspace(1); - auto x = NDArrayFactory::create_('c', {3, 5}); - x->linspace(1); - - auto y = NDArrayFactory::create_('c', {5, 3}); - y->linspace(1); + auto y = NDArrayFactory::create('c', {5, 3}); + y.linspace(1); - float _expB[]{135.0f, 310.0f, 485.0f, 150.0f, 350.0f, 550.0f, 165.0f, 390.0f, 615.0f}; - Nd4jLong _expS[] {2, 3, 3, 1, 3, 0, 1, 102}; // expected shape - ArrayOptions::setDataType(_expS, sd::DataType::FLOAT32); - NDArray exp(_expB, _expS); + float _expB[]{135.0f, 310.0f, 485.0f, 150.0f, 350.0f, + 550.0f, 165.0f, 390.0f, 615.0f}; + Nd4jLong _expS[]{2, 3, 3, 1, 3, 0, 1, 102}; // expected shape + ArrayOptions::setDataType(_expS, sd::DataType::FLOAT32); + NDArray exp(_expB, _expS); - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, x); - variableSpace->putVariable(-2, y); - variableSpace->putVariable(1, new Variable()); + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + variableSpace->putVariable(-2, y); + variableSpace->putVariable(1, std::make_shared()); - auto block = new Context(1, variableSpace, false); - block->fillInputs({-1, -2}); + auto block = new Context(1, variableSpace, false); + block->fillInputs({-1, -2}); - sd::ops::matmul op; + sd::ops::matmul op; - Nd4jStatus status = op.execute(block); - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(variableSpace->hasVariable(1)); + Nd4jStatus status = op.execute(block); + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(variableSpace->hasVariable(1)); - auto result = variableSpace->getVariable(1)->getNDArray(); + auto result = variableSpace->getVariable(1)->getNDArray(); - ASSERT_TRUE(result->equalsTo(&exp)); + ASSERT_TRUE(result->equalsTo(exp)); - delete block; - delete variableSpace; + delete block; + delete variableSpace; } TEST_F(DeclarableOpsTests14, matmul_test11) { - auto A = NDArrayFactory::create('c', {3, 3}); - auto B = NDArrayFactory::create('c', {3, 1}); - auto exp = NDArrayFactory::create('c', {3, 1}, {14.00f, 32.00f, 50.00f}); - - A.linspace(1); - B.linspace(1); + auto A = NDArrayFactory::create('c', {3, 3}); + auto B = NDArrayFactory::create('c', {3, 1}); + auto exp = + NDArrayFactory::create('c', {3, 1}, {14.00f, 32.00f, 50.00f}); - sd::ops::matmul op; + A.linspace(1); + B.linspace(1); - auto result = op.evaluate({&A, &B}, {}, {}); + sd::ops::matmul op; - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto result = op.evaluate({&A, &B}, {}, {}); - auto z = result.at(0); - - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests14, matmul_test12) { - auto x= NDArrayFactory::create('c', {3, 4}, {1, 2, 3, 4, 5, 6, 7, 8 , 9, 10, 11, 12}); - auto y= NDArrayFactory::create('c', {4, 3}, {1, 2, 3, 4, 5, 6, 7, 8 , 9, 10, 11, 12}); - auto exp= NDArrayFactory::create('f', {4, 4}, {38.0, 44.0, 50.0, 56.0, 83.0, 98.0, 113.0, 128.0, 128.0, 152.0, 176.0, 200.0, 173.0, 206.0, 239.0, 272.0}); - - sd::ops::matmul op; - auto result = op.evaluate({&x, &y}, {}, {1, 1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto x = NDArrayFactory::create( + 'c', {3, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + auto y = NDArrayFactory::create( + 'c', {4, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + auto exp = NDArrayFactory::create( + 'f', {4, 4}, + {38.0, 44.0, 50.0, 56.0, 83.0, 98.0, 113.0, 128.0, 128.0, 152.0, 176.0, + 200.0, 173.0, 206.0, 239.0, 272.0}); + sd::ops::matmul op; + auto result = op.evaluate({&x, &y}, {}, {1, 1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(DeclarableOpsTests14, matmul_test13) { - auto x= NDArrayFactory::create('c', {1, 3}, {1, 2, 3}); - auto y= NDArrayFactory::create('c', {1, 4}, {1, 2, 3, 4}); - auto exp= NDArrayFactory::create('f', {3, 4}, {1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 3.0, 6.0, 9.0, 4.0, 8.0, 12.0}); - - sd::ops::matmul op; - auto result = op.evaluate({&x, &y}, {}, {1, 0}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); + auto x = NDArrayFactory::create('c', {1, 3}, {1, 2, 3}); + auto y = NDArrayFactory::create('c', {1, 4}, {1, 2, 3, 4}); + auto exp = NDArrayFactory::create( + 'f', {3, 4}, + {1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 3.0, 6.0, 9.0, 4.0, 8.0, 12.0}); - //z->printIndexedBuffer("z"); + sd::ops::matmul op; + auto result = op.evaluate({&x, &y}, {}, {1, 0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + // z->printIndexedBuffer("z"); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests14, matmul_test14) { - auto x= NDArrayFactory::create('c', {3, 1}, {1, 2, 3}); - auto y= NDArrayFactory::create('c', {4, 1}, {1, 2, 3, 4}); - auto exp= NDArrayFactory::create('f', {3, 4}, {1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 3.0, 6.0, 9.0, 4.0, 8.0, 12.0}); + auto x = NDArrayFactory::create('c', {3, 1}, {1, 2, 3}); + auto y = NDArrayFactory::create('c', {4, 1}, {1, 2, 3, 4}); + auto exp = NDArrayFactory::create( + 'f', {3, 4}, + {1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 3.0, 6.0, 9.0, 4.0, 8.0, 12.0}); - sd::ops::matmul op; - auto result = op.evaluate({&x, &y}, {}, {0, 1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::matmul op; + auto result = op.evaluate({&x, &y}, {}, {0, 1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); - - //z->printIndexedBuffer("z"); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + // z->printIndexedBuffer("z"); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests14, matmul_test15) { - auto x= NDArrayFactory::create('c', {3, 1}, {1, 2, 3}); - auto y= NDArrayFactory::create('c', {1, 4}, {1, 2, 3, 4}); - auto exp= NDArrayFactory::create('f', {3, 4}, {1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 3.0, 6.0, 9.0, 4.0, 8.0, 12.0}); - - sd::ops::matmul op; - auto result = op.evaluate({&x, &y}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto x = NDArrayFactory::create('c', {3, 1}, {1, 2, 3}); + auto y = NDArrayFactory::create('c', {1, 4}, {1, 2, 3, 4}); + auto exp = NDArrayFactory::create( + 'f', {3, 4}, + {1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 3.0, 6.0, 9.0, 4.0, 8.0, 12.0}); - auto z = result.at(0); + sd::ops::matmul op; + auto result = op.evaluate({&x, &y}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - //z->printIndexedBuffer("z"); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + // z->printIndexedBuffer("z"); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests14, matmul_test16) { - auto x= NDArrayFactory::create('c', {4, 1}, {1, 2, 3, 4}); - auto y= NDArrayFactory::create('c', {1, 4}, {1, 2, 3, 4}); - auto exp= NDArrayFactory::create('f', {4, 4}, {1,2, 3, 4,2,4, 6, 8,3,6, 9,12,4,8,12,16}); - - sd::ops::matmul op; - auto result = op.evaluate({&x, &y}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); + auto x = NDArrayFactory::create('c', {4, 1}, {1, 2, 3, 4}); + auto y = NDArrayFactory::create('c', {1, 4}, {1, 2, 3, 4}); + auto exp = NDArrayFactory::create( + 'f', {4, 4}, {1, 2, 3, 4, 2, 4, 6, 8, 3, 6, 9, 12, 4, 8, 12, 16}); - //z->printIndexedBuffer("z"); + sd::ops::matmul op; + auto result = op.evaluate({&x, &y}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + // z->printIndexedBuffer("z"); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests14, matmul_test17) { - auto x = NDArrayFactory::create('c', {1, 2}, {2.0f, 2.0f}); - auto y = NDArrayFactory::create('c', {2, 1}, {2.0f, 2.0f}); - auto exp = NDArrayFactory::create('c', {1, 1}, {8.0f}); - - sd::ops::matmul op; - auto result = op.evaluate({&x, &y}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_EQ(exp, *result.at(0)); + auto x = NDArrayFactory::create('c', {1, 2}, {2.0f, 2.0f}); + auto y = NDArrayFactory::create('c', {2, 1}, {2.0f, 2.0f}); + auto exp = NDArrayFactory::create('c', {1, 1}, {8.0f}); + sd::ops::matmul op; + auto result = op.evaluate({&x, &y}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_EQ(exp, result.at(0)); } - ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, matmul_test18) { + auto x = NDArrayFactory::create('c', {1, 4, 3}); + auto y = NDArrayFactory::create('f', {1, 3, 4}); + auto exp = NDArrayFactory::create( + 'f', {1, 3, 3}, {35., 40., 45., 79., 92., 105., 123., 144., 165.}); - auto x = NDArrayFactory::create('c', {1, 4, 3}); - auto y = NDArrayFactory::create('f', {1, 3, 4}); - auto exp = NDArrayFactory::create('f', {1, 3, 3}, {35., 40., 45., 79., 92., 105., 123., 144., 165.}); - - x.linspace(1.); - y.linspace(0.5, 0.5); - - sd::ops::matmul op; - auto results = op.evaluate({&x, &y}, {}, {1, 1}); - auto z = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + x.linspace(1.); + y.linspace(0.5, 0.5); + sd::ops::matmul op; + auto results = op.evaluate({&x, &y}, {}, {1, 1}); + auto z = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, matmul_test19) { + auto x = NDArrayFactory::create('c', {4, 1}); + auto y = NDArrayFactory::create('f', {1, 4}); + auto exp = NDArrayFactory::create('f', {1, 1}, {15}); - auto x = NDArrayFactory::create('c', {4, 1}); - auto y = NDArrayFactory::create('f', {1, 4}); - auto exp = NDArrayFactory::create('f', {1, 1}, {15}); - - x.linspace(1.); - y.linspace(0.5, 0.5); - sd::ops::matmul op; - auto results = op.evaluate({&x, &y}, {}, {1, 1}); - ASSERT_EQ(Status::OK(), results.status()); - - auto z = results.at(0); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - + x.linspace(1.); + y.linspace(0.5, 0.5); + sd::ops::matmul op; + auto results = op.evaluate({&x, &y}, {}, {1, 1}); + ASSERT_EQ(Status::OK(), results.status()); + auto z = results.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, matmul_test20) { + auto x = NDArrayFactory::create('c', {1, 4, 1}); + auto y = NDArrayFactory::create('f', {1, 1, 4}); + auto exp = NDArrayFactory::create('f', {1, 1, 1}, {15}); - auto x = NDArrayFactory::create('c', {1, 4, 1}); - auto y = NDArrayFactory::create('f', {1, 1, 4}); - auto exp = NDArrayFactory::create('f', {1, 1, 1}, {15}); - - x.linspace(1.); - y.linspace(0.5, 0.5); - - sd::ops::matmul op; - auto results = op.evaluate({&x, &y}, {}, {1, 1}); - - ASSERT_EQ(Status::OK(), results.status()); - auto z = results.at(0); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + x.linspace(1.); + y.linspace(0.5, 0.5); + sd::ops::matmul op; + auto results = op.evaluate({&x, &y}, {}, {1, 1}); + ASSERT_EQ(Status::OK(), results.status()); + auto z = results.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, matmul_test21) { + auto x = NDArrayFactory::create('c', {2, 3}); + auto y = NDArrayFactory::create('c', {3, 5}); + auto exp = NDArrayFactory::create( + 'f', {5, 2}, {23., 26., 29., 32., 35., 50., 57.5, 65., 72.5, 80.}); - auto x = NDArrayFactory::create('c', {2, 3}); - auto y = NDArrayFactory::create('c', {3, 5}); - auto exp = NDArrayFactory::create('f', {5, 2}, {23. , 26. , 29. , 32. , 35., 50. , 57.5, 65. , 72.5, 80.}); - - x.linspace(1.); - y.linspace(0.5, 0.5); - - sd::ops::matmul op; - auto results = op.evaluate({&x, &y}, {}, {0, 0, 1}); - auto z = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + x.linspace(1.); + y.linspace(0.5, 0.5); + sd::ops::matmul op; + auto results = op.evaluate({&x, &y}, {}, {0, 0, 1}); + auto z = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, matmul_test22) { + auto x = NDArrayFactory::create('c', {3, 2}); + auto y = NDArrayFactory::create('c', {3, 5}); + auto exp = NDArrayFactory::create( + 'f', {5, 2}, {37., 41.5, 46., 50.5, 55., 46., 52., 58., 64., 70.}); - auto x = NDArrayFactory::create('c', {3, 2}); - auto y = NDArrayFactory::create('c', {3, 5}); - auto exp = NDArrayFactory::create('f', {5, 2}, {37. , 41.5, 46. , 50.5, 55., 46. , 52. , 58. , 64. , 70.}); - - x.linspace(1.); - y.linspace(0.5, 0.5); - - sd::ops::matmul op; - auto results = op.evaluate({&x, &y}, {}, {1, 0, 1}); - auto z = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + x.linspace(1.); + y.linspace(0.5, 0.5); + sd::ops::matmul op; + auto results = op.evaluate({&x, &y}, {}, {1, 0, 1}); + auto z = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, matmul_test23) { + auto x = NDArrayFactory::create('c', {3, 2}); + auto y = NDArrayFactory::create('c', {3, 5}); + auto exp = NDArrayFactory::create( + 'f', {5, 2}, {37., 41.5, 46., 50.5, 55., 46., 52., 58., 64., 70.}); - auto x = NDArrayFactory::create('c', {3, 2}); - auto y = NDArrayFactory::create('c', {3, 5}); - auto exp = NDArrayFactory::create('f', {5, 2}, {37. , 41.5, 46. , 50.5, 55., 46. , 52. , 58. , 64. , 70.}); - - x.linspace(1.); - y.linspace(0.5, 0.5); - - sd::ops::matmul op; - auto results = op.evaluate({&x, &y}, {}, {1, 0, 1}); - auto z = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + x.linspace(1.); + y.linspace(0.5, 0.5); + sd::ops::matmul op; + auto results = op.evaluate({&x, &y}, {}, {1, 0, 1}); + auto z = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, matmul_test24) { - - auto x = NDArrayFactory::create('c', {2,2, 3,5}); - auto y = NDArrayFactory::create('c', {2,2, 4,3}); - auto exp = NDArrayFactory::create('f',{2,2, 4,5}, {4.6, 281.8, 89.2, 582.4, 10. , 314.2,108.1, 628.3, 15.4, 346.6,127. , 674.2, 20.8, 379. ,145.9, 720.1, 5.2, 289.6, 93.4, 593.8, - 11.5, 322.9,113.2, 640.6, 17.8, 356.2,133. , 687.4, 24.1, 389.5,152.8, 734.2, 5.8, 297.4, 97.6, 605.2, 13. , 331.6,118.3, 652.9, - 20.2, 365.8,139. , 700.6, 27.4, 400. ,159.7, 748.3, 6.4, 305.2,101.8, 616.6, 14.5, 340.3,123.4, 665.2, 22.6, 375.4,145. , 713.8, - 30.7, 410.5,166.6, 762.4, 7. , 313. ,106. , 628. , 16. , 349. ,128.5, 677.5, 25. , 385. ,151. , 727. , 34. , 421. ,173.5, 776.5}); - - x.linspace(1.); - y.linspace(0.1, 0.1); - - sd::ops::matmul op; - auto results = op.evaluate({&x, &y}, {}, {1, 1, 1}); - auto z = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - - + auto x = NDArrayFactory::create('c', {2, 2, 3, 5}); + auto y = NDArrayFactory::create('c', {2, 2, 4, 3}); + auto exp = NDArrayFactory::create( + 'f', {2, 2, 4, 5}, + {4.6, 281.8, 89.2, 582.4, 10., 314.2, 108.1, 628.3, 15.4, 346.6, + 127., 674.2, 20.8, 379., 145.9, 720.1, 5.2, 289.6, 93.4, 593.8, + 11.5, 322.9, 113.2, 640.6, 17.8, 356.2, 133., 687.4, 24.1, 389.5, + 152.8, 734.2, 5.8, 297.4, 97.6, 605.2, 13., 331.6, 118.3, 652.9, + 20.2, 365.8, 139., 700.6, 27.4, 400., 159.7, 748.3, 6.4, 305.2, + 101.8, 616.6, 14.5, 340.3, 123.4, 665.2, 22.6, 375.4, 145., 713.8, + 30.7, 410.5, 166.6, 762.4, 7., 313., 106., 628., 16., 349., + 128.5, 677.5, 25., 385., 151., 727., 34., 421., 173.5, 776.5}); + + x.linspace(1.); + y.linspace(0.1, 0.1); + + sd::ops::matmul op; + auto results = op.evaluate({&x, &y}, {}, {1, 1, 1}); + auto z = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, matmul_test25) { + auto x = NDArrayFactory::create('f', {4, 3}); + auto y = NDArrayFactory::create('c', {4}); + auto exp = NDArrayFactory::create('f', {3}, {7., 8., 9.}); - auto x = NDArrayFactory::create('f', {4, 3}); - auto y = NDArrayFactory::create('c', {4}); - auto exp = NDArrayFactory::create('f',{3}, {7., 8., 9.}); - - x.linspace(1.); - y.linspace(0.1, 0.1); - - sd::ops::matmul op; - auto results = op.evaluate({&x, &y}, {}, {1, 0}); - auto z = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + x.linspace(1.); + y.linspace(0.1, 0.1); + sd::ops::matmul op; + auto results = op.evaluate({&x, &y}, {}, {1, 0}); + auto z = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, matmul_test26) { + auto x = NDArrayFactory::create('f', {3}); + auto y = NDArrayFactory::create('c', {4, 3}); + auto exp = NDArrayFactory::create('f', {4}, {1.4, 3.2, 5., 6.8}); - auto x = NDArrayFactory::create('f', {3}); - auto y = NDArrayFactory::create('c', {4, 3}); - auto exp = NDArrayFactory::create('f',{4}, {1.4, 3.2, 5., 6.8}); - - x.linspace(1.); - y.linspace(0.1, 0.1); - - sd::ops::matmul op; - auto results = op.evaluate({&x, &y}, {}, {0, 1}); - auto z = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + x.linspace(1.); + y.linspace(0.1, 0.1); + sd::ops::matmul op; + auto results = op.evaluate({&x, &y}, {}, {0, 1}); + auto z = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, matmul_test27) { + auto x = NDArrayFactory::create('f', {1, 1}); + auto y = NDArrayFactory::create('c', {1, 1}); + auto exp = NDArrayFactory::create('f', {1, 1}, {0.2}); - auto x = NDArrayFactory::create('f', {1, 1}); - auto y = NDArrayFactory::create('c', {1, 1}); - auto exp = NDArrayFactory::create('f',{1, 1}, {0.2}); - - x.linspace(2.); - y.linspace(0.1, 0.1); - - sd::ops::matmul op; - auto results = op.evaluate({&x, &y}, {}, {}); - auto z = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + x.linspace(2.); + y.linspace(0.1, 0.1); + sd::ops::matmul op; + auto results = op.evaluate({&x, &y}, {}, {}); + auto z = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, matmul_test28) { + auto x = NDArrayFactory::create('f', {1, 1}); + auto y = NDArrayFactory::create('c', {1, 1}); + auto exp = NDArrayFactory::create('f', {1, 1}, {0.2}); - auto x = NDArrayFactory::create('f', {1, 1}); - auto y = NDArrayFactory::create('c', {1, 1}); - auto exp = NDArrayFactory::create('f',{1, 1}, {0.2}); - - x.linspace(2.); - y.linspace(0.1, 0.1); - - sd::ops::matmul op; - auto results = op.evaluate({&x, &y}, {}, {1,1,1}); - auto z = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + x.linspace(2.); + y.linspace(0.1, 0.1); + sd::ops::matmul op; + auto results = op.evaluate({&x, &y}, {}, {1, 1, 1}); + auto z = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, matmul_test29) { + auto x = NDArrayFactory::create('f', {1}); + auto y = NDArrayFactory::create('c', {1, 1}); + auto exp = NDArrayFactory::create('f', {1}, {0.2}); - auto x = NDArrayFactory::create('f', {1}); - auto y = NDArrayFactory::create('c', {1, 1}); - auto exp = NDArrayFactory::create('f',{1}, {0.2}); - - x.linspace(2.); - y.linspace(0.1, 0.1); - - sd::ops::matmul op; - auto results = op.evaluate({&x, &y}, {}, {}); - auto z = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + x.linspace(2.); + y.linspace(0.1, 0.1); + sd::ops::matmul op; + auto results = op.evaluate({&x, &y}, {}, {}); + auto z = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, matmul_test30) { + auto x = NDArrayFactory::create('f', {1, 1}); + auto y = NDArrayFactory::create('c', {1}); + auto exp = NDArrayFactory::create('f', {1}, {0.2}); - auto x = NDArrayFactory::create('f', {1,1}); - auto y = NDArrayFactory::create('c', {1}); - auto exp = NDArrayFactory::create('f',{1}, {0.2}); - - x.linspace(2.); - y.linspace(0.1, 0.1); - - sd::ops::matmul op; - auto results = op.evaluate({&x, &y}, {}, {1}); - auto z = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + x.linspace(2.); + y.linspace(0.1, 0.1); + sd::ops::matmul op; + auto results = op.evaluate({&x, &y}, {}, {1}); + auto z = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, matmul_test31) { + auto x = NDArrayFactory::create('f', {4}); + auto y = NDArrayFactory::create('c', {4}); + auto exp = NDArrayFactory::create(3.); - auto x = NDArrayFactory::create('f', {4}); - auto y = NDArrayFactory::create('c', {4}); - auto exp = NDArrayFactory::create(3.); - - x.linspace(1.); - y.linspace(0.1, 0.1); - - sd::ops::matmul op; - auto results = op.evaluate({&x, &y}, {}, {1, 1}); - auto z = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + x.linspace(1.); + y.linspace(0.1, 0.1); + sd::ops::matmul op; + auto results = op.evaluate({&x, &y}, {}, {1, 1}); + auto z = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, matmul_test32) { + auto x = NDArrayFactory::create('f', {1}, {2.}); + auto y = NDArrayFactory::create('c', {1}, {3.}); + auto exp = NDArrayFactory::create(6.); - auto x = NDArrayFactory::create('f', {1}, {2.}); - auto y = NDArrayFactory::create('c', {1}, {3.}); - auto exp = NDArrayFactory::create(6.); - - sd::ops::matmul op; - auto results = op.evaluate({&x, &y}, {}, {1, 1}); - auto z = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - + sd::ops::matmul op; + auto results = op.evaluate({&x, &y}, {}, {1, 1}); + auto z = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ///////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, matmul_test33) { - auto x = NDArrayFactory::create('c', {4, 3}); - auto y = NDArrayFactory::create('c', {4, 1}); - auto exp = NDArrayFactory::create('c',{ 3, 1}, {70, 80, 90}); - - x.linspace(1); - y.linspace(1); + auto x = NDArrayFactory::create('c', {4, 3}); + auto y = NDArrayFactory::create('c', {4, 1}); + auto exp = NDArrayFactory::create('c', {3, 1}, {70, 80, 90}); - sd::ops::matmul op; - auto result = op.evaluate({&x, &y}, {}, {1, 0}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + x.linspace(1); + y.linspace(1); - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::matmul op; + auto result = op.evaluate({&x, &y}, {}, {1, 0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, matmul_test34) { - auto a = NDArrayFactory::create('c', {3, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); - auto b = NDArrayFactory::create('c', {4}, {1, 2, 3, 4}); - auto exp = NDArrayFactory::create('c', {3}, {30, 70, 110}); - - sd::ops::matmul op; - auto result = op.evaluate({&a, &b}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); + auto a = NDArrayFactory::create( + 'c', {3, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + auto b = NDArrayFactory::create('c', {4}, {1, 2, 3, 4}); + auto exp = NDArrayFactory::create('c', {3}, {30, 70, 110}); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::matmul op; + auto result = op.evaluate({&a, &b}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ///////////////////////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, matmul_test35) { - auto a = NDArrayFactory::create('c', {4}, {1, 2, 3, 4}); - auto b = NDArrayFactory::create('c', {4, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); - auto exp = NDArrayFactory::create('c', {3}, {70, 80, 90}); - - sd::ops::matmul op; - auto result = op.evaluate({&a, &b}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); + auto a = NDArrayFactory::create('c', {4}, {1, 2, 3, 4}); + auto b = NDArrayFactory::create( + 'c', {4, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + auto exp = NDArrayFactory::create('c', {3}, {70, 80, 90}); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::matmul op; + auto result = op.evaluate({&a, &b}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, matmul_test36) { - auto a = NDArrayFactory::create('c', {1, 4}, {1, 2, 3, 4}); - auto b = NDArrayFactory::create('c', {4, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); - auto exp = NDArrayFactory::create('c', {1, 3}, {70, 80, 90}); + auto a = NDArrayFactory::create('c', {1, 4}, {1, 2, 3, 4}); + auto b = NDArrayFactory::create( + 'c', {4, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + auto exp = NDArrayFactory::create('c', {1, 3}, {70, 80, 90}); - sd::ops::matmul op; - auto result = op.evaluate({&a, &b}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::matmul op; + auto result = op.evaluate({&a, &b}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, matmul_test37) { + NDArray a('c', {32, 12, 128, 64}, sd::DataType::FLOAT32); + NDArray b('c', {32, 12, 128, 64}, sd::DataType::FLOAT32); + NDArray c('c', {32, 12, 128, 128}, sd::DataType::FLOAT32); + NDArray cExp('c', {32, 12, 128, 128}, sd::DataType::FLOAT32); - NDArray a('c', {32, 12, 128, 64}, sd::DataType::FLOAT32); - NDArray b('c', {32, 12, 128, 64}, sd::DataType::FLOAT32); - NDArray c('c', {32,12,128,128}, sd::DataType::FLOAT32); - NDArray cExp('c', {32,12,128,128}, sd::DataType::FLOAT32); - - a = 1; - b = 1; - cExp = 64; //Each entry in output c is sum of 64 (1.0 x 1.0) multiplications + a = 1; + b = 1; + cExp = 64; // Each entry in output c is sum of 64 (1.0 x 1.0) multiplications - sd::ops::matmul op; - auto status = op.execute({&a, &b}, {&c}, {}, {0,1}); + sd::ops::matmul op; + auto status = op.execute({&a, &b}, {&c}, {}, {0, 1}); - ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(cExp.isSameShape(c)); - ASSERT_TRUE(cExp.equalsTo(c)); + ASSERT_TRUE(cExp.isSameShape(c)); + ASSERT_TRUE(cExp.equalsTo(c)); } /////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, Test_broadcast_3D_1) { - - // x[4, 12, 128] * y[4, 128] = z[4, 12, 128] - - auto x = NDArray('c', { 2, 3, 5 }, sd::DataType::FLOAT32); - auto y = NDArray('c', { 2, 5 }, sd::DataType::FLOAT32); - auto z = NDArray('c', { 2, 3, 5 }, sd::DataType::FLOAT32); - // recieved by main algorithm - auto e = NDArray('c', { 2, 3, 5 }, { 10.000000, 22.000000, 36.000000, 52.000000, 70.000000, 60.000000, 77.000000, 96.000000, 117.000000, 140.000000, 110.000000, 132.000000, 156.000000, 182.000000, 210.000000, 240.000000, 272.000000, 306.000000, 342.000000, 380.000000, 315.000000, 352.000000, 391.000000, 432.000000, 475.000000, 390.000000, 432.000000, 476.000000, 522.000000, 570.000000 }, sd::DataType::FLOAT32); - - x.linspace(1.f); - y.linspace(10.f); - z.assign(0.f); - - x.applyBroadcast(sd::broadcast::Multiply, { 0,2 }, y, z); - //z.printBuffer(); - ASSERT_EQ(e, z); + // x[4, 12, 128] * y[4, 128] = z[4, 12, 128] + + auto x = NDArray('c', {2, 3, 5}, sd::DataType::FLOAT32); + auto y = NDArray('c', {2, 5}, sd::DataType::FLOAT32); + auto z = NDArray('c', {2, 3, 5}, sd::DataType::FLOAT32); + // recieved by main algorithm + auto e = NDArray( + 'c', {2, 3, 5}, + {10.000000, 22.000000, 36.000000, 52.000000, 70.000000, 60.000000, + 77.000000, 96.000000, 117.000000, 140.000000, 110.000000, 132.000000, + 156.000000, 182.000000, 210.000000, 240.000000, 272.000000, 306.000000, + 342.000000, 380.000000, 315.000000, 352.000000, 391.000000, 432.000000, + 475.000000, 390.000000, 432.000000, 476.000000, 522.000000, 570.000000}, + sd::DataType::FLOAT32); + + x.linspace(1.f); + y.linspace(10.f); + z.assign(0.f); + + x.applyBroadcast(sd::broadcast::Multiply, {0, 2}, y, z); + // z.printBuffer(); + ASSERT_EQ(e, z); } /////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, Test_broadcast_3D_2) { + auto x = NDArray('f', {2, 3, 5}, sd::DataType::FLOAT32); + auto y = NDArray('f', {2, 5}, sd::DataType::FLOAT32); + auto z = NDArray('f', {2, 3, 5}, sd::DataType::FLOAT32); + // recieved by main algorithm + auto eC = + NDArray('c', {2, 3, 5}, + {0.100000, 0.181818, 0.250000, 0.307692, 0.357143, 0.600000, + 0.636364, 0.666667, 0.692308, 0.714286, 1.100000, 1.090909, + 1.083333, 1.076923, 1.071429, 1.066667, 1.062500, 1.058824, + 1.055556, 1.052632, 1.400000, 1.375000, 1.352941, 1.333333, + 1.315789, 1.733333, 1.687500, 1.647059, 1.611111, 1.578947}, + sd::DataType::FLOAT32); - auto x = NDArray('f', { 2, 3, 5 }, sd::DataType::FLOAT32); - auto y = NDArray('f', { 2, 5 }, sd::DataType::FLOAT32); - auto z = NDArray('f', { 2, 3, 5 }, sd::DataType::FLOAT32); - // recieved by main algorithm - auto eC = NDArray('c', { 2, 3, 5 }, { 0.100000, 0.181818, 0.250000, 0.307692, 0.357143, 0.600000, 0.636364, 0.666667, 0.692308, 0.714286, 1.100000, 1.090909, 1.083333, 1.076923, 1.071429, 1.066667, 1.062500, 1.058824, 1.055556, 1.052632, 1.400000, 1.375000, 1.352941, 1.333333, 1.315789, 1.733333, 1.687500, 1.647059, 1.611111, 1.578947 }, sd::DataType::FLOAT32); - - auto e = NDArray('f', { 2, 3, 5 }, sd::DataType::FLOAT32); + auto e = NDArray('f', {2, 3, 5}, sd::DataType::FLOAT32); - e.assign(eC); + e.assign(eC); - x.linspace(1.f); - y.linspace(10.f); - z.assign(0.f); + x.linspace(1.f); + y.linspace(10.f); + z.assign(0.f); - x.applyBroadcast(sd::broadcast::Divide, { 0,2 }, y, z); + x.applyBroadcast(sd::broadcast::Divide, {0, 2}, y, z); - ASSERT_EQ(e, z); + ASSERT_EQ(e, z); } /////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, Test_broadcast_4D_1) { - - auto x = NDArray('c', { 2, 3, 5, 4 }, sd::DataType::FLOAT32); - auto y = NDArray('c', { 2, 5, 4 }, sd::DataType::FLOAT32); - auto z = NDArray('c', { 2, 3, 5, 4 }, sd::DataType::FLOAT32); - // recieved by main algorithm - auto e = NDArray('c', { 2, 3, 5, 4 }, { 10.000000, 22.000000, 36.000000, 52.000000, 70.000000, 90.000000, 112.000000, 136.000000, 162.000000, 190.000000, 220.000000, 252.000000, 286.000000, 322.000000, 360.000000, 400.000000, 442.000000, 486.000000, 532.000000, 580.000000, 210.000000, 242.000000, 276.000000, 312.000000, 350.000000, 390.000000, 432.000000, 476.000000, 522.000000, 570.000000, 620.000000, 672.000000, 726.000000, 782.000000, 840.000000, 900.000000, 962.000000, 1026.000000, 1092.000000, 1160.000000, 410.000000, 462.000000, 516.000000, 572.000000, 630.000000, 690.000000, 752.000000, 816.000000, 882.000000, 950.000000, 1020.000000, 1092.000000, 1166.000000, 1242.000000, 1320.000000, 1400.000000, 1482.000000, 1566.000000, 1652.000000, 1740.000000, 1830.000000, 1922.000000, 2016.000000, 2112.000000, 2210.000000, 2310.000000, 2412.000000, 2516.000000, 2622.000000, 2730.000000, 2840.000000, 2952.000000, 3066.000000, 3182.000000, 3300.000000, 3420.000000, 3542.000000, 3666.000000, 3792.000000, 3920.000000, 2430.000000, 2542.000000, 2656.000000, 2772.000000, 2890.000000, 3010.000000, 3132.000000, 3256.000000, 3382.000000, 3510.000000, 3640.000000, 3772.000000, 3906.000000, 4042.000000, 4180.000000, 4320.000000, 4462.000000, 4606.000000, 4752.000000, 4900.000000, 3030.000000, 3162.000000, 3296.000000, 3432.000000, 3570.000000, 3710.000000, 3852.000000, 3996.000000, 4142.000000, 4290.000000, 4440.000000, 4592.000000, 4746.000000, 4902.000000, 5060.000000, 5220.000000, 5382.000000, 5546.000000, 5712.000000, 5880.000000 }, sd::DataType::FLOAT32); - - x.linspace(1.f); - y.linspace(10.f); - z.assign(0.f); - - x.applyBroadcast(sd::broadcast::Multiply, { 0,2,3 }, y, z); - - ASSERT_EQ(e, z); + auto x = NDArray('c', {2, 3, 5, 4}, sd::DataType::FLOAT32); + auto y = NDArray('c', {2, 5, 4}, sd::DataType::FLOAT32); + auto z = NDArray('c', {2, 3, 5, 4}, sd::DataType::FLOAT32); + // recieved by main algorithm + auto e = + NDArray('c', {2, 3, 5, 4}, + {10.000000, 22.000000, 36.000000, 52.000000, 70.000000, + 90.000000, 112.000000, 136.000000, 162.000000, 190.000000, + 220.000000, 252.000000, 286.000000, 322.000000, 360.000000, + 400.000000, 442.000000, 486.000000, 532.000000, 580.000000, + 210.000000, 242.000000, 276.000000, 312.000000, 350.000000, + 390.000000, 432.000000, 476.000000, 522.000000, 570.000000, + 620.000000, 672.000000, 726.000000, 782.000000, 840.000000, + 900.000000, 962.000000, 1026.000000, 1092.000000, 1160.000000, + 410.000000, 462.000000, 516.000000, 572.000000, 630.000000, + 690.000000, 752.000000, 816.000000, 882.000000, 950.000000, + 1020.000000, 1092.000000, 1166.000000, 1242.000000, 1320.000000, + 1400.000000, 1482.000000, 1566.000000, 1652.000000, 1740.000000, + 1830.000000, 1922.000000, 2016.000000, 2112.000000, 2210.000000, + 2310.000000, 2412.000000, 2516.000000, 2622.000000, 2730.000000, + 2840.000000, 2952.000000, 3066.000000, 3182.000000, 3300.000000, + 3420.000000, 3542.000000, 3666.000000, 3792.000000, 3920.000000, + 2430.000000, 2542.000000, 2656.000000, 2772.000000, 2890.000000, + 3010.000000, 3132.000000, 3256.000000, 3382.000000, 3510.000000, + 3640.000000, 3772.000000, 3906.000000, 4042.000000, 4180.000000, + 4320.000000, 4462.000000, 4606.000000, 4752.000000, 4900.000000, + 3030.000000, 3162.000000, 3296.000000, 3432.000000, 3570.000000, + 3710.000000, 3852.000000, 3996.000000, 4142.000000, 4290.000000, + 4440.000000, 4592.000000, 4746.000000, 4902.000000, 5060.000000, + 5220.000000, 5382.000000, 5546.000000, 5712.000000, 5880.000000}, + sd::DataType::FLOAT32); + + x.linspace(1.f); + y.linspace(10.f); + z.assign(0.f); + + x.applyBroadcast(sd::broadcast::Multiply, {0, 2, 3}, y, z); + + ASSERT_EQ(e, z); } /////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, Test_broadcast_4D_2) { - - auto x = NDArray('f', { 2, 3, 5, 4 }, sd::DataType::FLOAT32); - auto y = NDArray('f', { 2, 5, 4 }, sd::DataType::FLOAT32); - auto z = NDArray('f', { 2, 3, 5, 4 }, sd::DataType::FLOAT32); - // recieved by main algorithm - auto eC = NDArray('c', { 2, 3, 5, 4 }, { 0.100000,0.181818,0.250000,0.307692,0.357143,0.400000,0.437500,0.470588,0.500000,0.526316,0.550000,0.571429, 0.590909,0.608696,0.625000,0.640000, 0.653846,0.666667,0.678571,0.689655, 2.100000,2.000000,1.916667, 1.846154, 1.785714, 1.733333,1.687500, 1.647059,1.611111, 1.578947,1.550000, 1.523810,1.500000, 1.478261,1.458333, 1.440000,1.423077, 1.407407,1.392857, 1.379310,4.100000, 3.818182,3.583333, 3.384615, 3.214286, 3.066667,2.937500, 2.823529,2.722222, 2.631579,2.550000, 2.476191,2.409091, 2.347826,2.291667, 2.240000,2.192308, 2.148148,2.107143, 2.068965,2.033333, 2.000000,1.968750, 1.939394,1.911765, 1.885714,1.861111, 1.837838,1.815789, 1.794872,1.775000, 1.756098,1.738095, 1.720930,1.704545, 1.688889,1.673913, 1.659575,1.645833,1.632653,2.700000,2.645161,2.593750,2.545455,2.500000,2.457143,2.416667,2.378378,2.342105,2.307692,2.275000,2.243902,2.214286,2.186047,2.159091,2.133333,2.108696,2.085106,2.062500,2.040816,3.366667,3.290323,3.218750,3.151515,3.088235,3.028571,2.972222,2.918919,2.868421,2.820513,2.775000,2.731707,2.690476,2.651163,2.613636,2.577778,2.543478,2.510638,2.479167,2.448980 }, sd::DataType::FLOAT32); - - auto e = NDArray('f', { 2, 3, 5, 4 }, sd::DataType::FLOAT32); - - e.assign(eC); - - x.linspace(1.f); - y.linspace(10.f); - z.assign(0.f); - - x.applyBroadcast(sd::broadcast::Divide, { 0,2,3 }, y, z); - - ASSERT_EQ(e, z); + auto x = NDArray('f', {2, 3, 5, 4}, sd::DataType::FLOAT32); + auto y = NDArray('f', {2, 5, 4}, sd::DataType::FLOAT32); + auto z = NDArray('f', {2, 3, 5, 4}, sd::DataType::FLOAT32); + // recieved by main algorithm + auto eC = NDArray( + 'c', {2, 3, 5, 4}, + {0.100000, 0.181818, 0.250000, 0.307692, 0.357143, 0.400000, 0.437500, + 0.470588, 0.500000, 0.526316, 0.550000, 0.571429, 0.590909, 0.608696, + 0.625000, 0.640000, 0.653846, 0.666667, 0.678571, 0.689655, 2.100000, + 2.000000, 1.916667, 1.846154, 1.785714, 1.733333, 1.687500, 1.647059, + 1.611111, 1.578947, 1.550000, 1.523810, 1.500000, 1.478261, 1.458333, + 1.440000, 1.423077, 1.407407, 1.392857, 1.379310, 4.100000, 3.818182, + 3.583333, 3.384615, 3.214286, 3.066667, 2.937500, 2.823529, 2.722222, + 2.631579, 2.550000, 2.476191, 2.409091, 2.347826, 2.291667, 2.240000, + 2.192308, 2.148148, 2.107143, 2.068965, 2.033333, 2.000000, 1.968750, + 1.939394, 1.911765, 1.885714, 1.861111, 1.837838, 1.815789, 1.794872, + 1.775000, 1.756098, 1.738095, 1.720930, 1.704545, 1.688889, 1.673913, + 1.659575, 1.645833, 1.632653, 2.700000, 2.645161, 2.593750, 2.545455, + 2.500000, 2.457143, 2.416667, 2.378378, 2.342105, 2.307692, 2.275000, + 2.243902, 2.214286, 2.186047, 2.159091, 2.133333, 2.108696, 2.085106, + 2.062500, 2.040816, 3.366667, 3.290323, 3.218750, 3.151515, 3.088235, + 3.028571, 2.972222, 2.918919, 2.868421, 2.820513, 2.775000, 2.731707, + 2.690476, 2.651163, 2.613636, 2.577778, 2.543478, 2.510638, 2.479167, + 2.448980}, + sd::DataType::FLOAT32); + + auto e = NDArray('f', {2, 3, 5, 4}, sd::DataType::FLOAT32); + + e.assign(eC); + + x.linspace(1.f); + y.linspace(10.f); + z.assign(0.f); + + x.applyBroadcast(sd::broadcast::Divide, {0, 2, 3}, y, z); + + ASSERT_EQ(e, z); } /////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, Test_broadcast_4D_3) { - - auto x = NDArray('f', { 2, 3, 5, 4 }, sd::DataType::FLOAT32); - auto y = NDArray('f', { 2, 5 }, sd::DataType::FLOAT32); - auto z = NDArray('f', { 2, 3, 5, 4 }, sd::DataType::FLOAT32); - // recieved by main algorithm - auto eC = NDArray('c', { 2, 3, 5, 4 }, { 0.100000, 0.200000, 0.300000, 0.400000, 0.454545, 0.545455, 0.636364, 0.727273, 0.750000, 0.833333, 0.916667, 1.000000, 1.000000, 1.076923, 1.153846, 1.230769, 1.214286, 1.285714, 1.357143, 1.428571, 2.100000, 2.200000, 2.300000, 2.400000, 2.272727, 2.363636, 2.454545, 2.545455, 2.416667, 2.500000, 2.583333, 2.666667, 2.538461, 2.615385, 2.692308, 2.769231, 2.642857, 2.714286, 2.785714, 2.857143, 4.100000, 4.200000, 4.300000, 4.400000, 4.090909, 4.181818, 4.272727, 4.363636, 4.083333, 4.166667, 4.250000, 4.333333, 4.076923, 4.153846, 4.230769, 4.307693, 4.071429, 4.142857, 4.214286, 4.285714, 4.066667, 4.133333, 4.200000, 4.266667, 4.062500, 4.125000, 4.187500, 4.250000, 4.058824, 4.117647, 4.176471, 4.235294, 4.055555, 4.111111, 4.166667, 4.222222, 4.052631, 4.105263, 4.157895, 4.210526, 5.400000, 5.466667, 5.533333, 5.600000, 5.312500, 5.375000, 5.437500, 5.500000, 5.235294, 5.294117, 5.352941, 5.411765, 5.166667, 5.222222, 5.277778, 5.333333, 5.105263, 5.157895, 5.210526, 5.263158, 6.733333, 6.800000, 6.866667, 6.933333, 6.562500, 6.625000, 6.687500, 6.750000, 6.411765, 6.470588, 6.529412, 6.588235, 6.277778, 6.333333, 6.388889, 6.444445, 6.157895, 6.210526, 6.263158, 6.315790 }, sd::DataType::FLOAT32); - - auto e = NDArray('f', { 2, 3, 5, 4 }, sd::DataType::FLOAT32); - - e.assign(eC); - - x.linspace(1.f); - y.linspace(10.f); - z.assign(0.f); - - x.applyBroadcast(sd::broadcast::Divide, { 0,2 }, y, z); - - ASSERT_EQ(e, z); + auto x = NDArray('f', {2, 3, 5, 4}, sd::DataType::FLOAT32); + auto y = NDArray('f', {2, 5}, sd::DataType::FLOAT32); + auto z = NDArray('f', {2, 3, 5, 4}, sd::DataType::FLOAT32); + // recieved by main algorithm + auto eC = NDArray( + 'c', {2, 3, 5, 4}, + {0.100000, 0.200000, 0.300000, 0.400000, 0.454545, 0.545455, 0.636364, + 0.727273, 0.750000, 0.833333, 0.916667, 1.000000, 1.000000, 1.076923, + 1.153846, 1.230769, 1.214286, 1.285714, 1.357143, 1.428571, 2.100000, + 2.200000, 2.300000, 2.400000, 2.272727, 2.363636, 2.454545, 2.545455, + 2.416667, 2.500000, 2.583333, 2.666667, 2.538461, 2.615385, 2.692308, + 2.769231, 2.642857, 2.714286, 2.785714, 2.857143, 4.100000, 4.200000, + 4.300000, 4.400000, 4.090909, 4.181818, 4.272727, 4.363636, 4.083333, + 4.166667, 4.250000, 4.333333, 4.076923, 4.153846, 4.230769, 4.307693, + 4.071429, 4.142857, 4.214286, 4.285714, 4.066667, 4.133333, 4.200000, + 4.266667, 4.062500, 4.125000, 4.187500, 4.250000, 4.058824, 4.117647, + 4.176471, 4.235294, 4.055555, 4.111111, 4.166667, 4.222222, 4.052631, + 4.105263, 4.157895, 4.210526, 5.400000, 5.466667, 5.533333, 5.600000, + 5.312500, 5.375000, 5.437500, 5.500000, 5.235294, 5.294117, 5.352941, + 5.411765, 5.166667, 5.222222, 5.277778, 5.333333, 5.105263, 5.157895, + 5.210526, 5.263158, 6.733333, 6.800000, 6.866667, 6.933333, 6.562500, + 6.625000, 6.687500, 6.750000, 6.411765, 6.470588, 6.529412, 6.588235, + 6.277778, 6.333333, 6.388889, 6.444445, 6.157895, 6.210526, 6.263158, + 6.315790}, + sd::DataType::FLOAT32); + + auto e = NDArray('f', {2, 3, 5, 4}, sd::DataType::FLOAT32); + + e.assign(eC); + + x.linspace(1.f); + y.linspace(10.f); + z.assign(0.f); + + x.applyBroadcast(sd::broadcast::Divide, {0, 2}, y, z); + + ASSERT_EQ(e, z); } /////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, Test_broadcast_4D_4) { - - // x[4, 12, 128, 128] * y[4, 1, 128, 1] = z[4, 12, 128, 128] - - auto x = NDArray('f', { 2, 3, 5, 4 }, sd::DataType::FLOAT32); - auto y = NDArray('f', { 2, 1, 5, 1 }, sd::DataType::FLOAT32); - auto z = NDArray('f', { 2, 3, 5, 4 }, sd::DataType::FLOAT32); - // recieved by main algorithm - auto eC = NDArray('c', { 2, 3, 5, 4 }, { 0.100000, 0.200000, 0.300000, 0.400000, 0.454545, 0.545455, 0.636364, 0.727273, 0.750000, 0.833333, 0.916667, 1.000000, 1.000000, 1.076923, 1.153846, 1.230769, 1.214286, 1.285714, 1.357143, 1.428571, 2.100000, 2.200000, 2.300000, 2.400000, 2.272727, 2.363636, 2.454545, 2.545455, 2.416667, 2.500000, 2.583333, 2.666667, 2.538461, 2.615385, 2.692308, 2.769231, 2.642857, 2.714286, 2.785714, 2.857143, 4.100000, 4.200000, 4.300000, 4.400000, 4.090909, 4.181818, 4.272727, 4.363636, 4.083333, 4.166667, 4.250000, 4.333333, 4.076923, 4.153846, 4.230769, 4.307693, 4.071429, 4.142857, 4.214286, 4.285714, 4.066667, 4.133333, 4.200000, 4.266667, 4.062500, 4.125000, 4.187500, 4.250000, 4.058824, 4.117647, 4.176471, 4.235294, 4.055555, 4.111111, 4.166667, 4.222222, 4.052631, 4.105263, 4.157895, 4.210526, 5.400000, 5.466667, 5.533333, 5.600000, 5.312500, 5.375000, 5.437500, 5.500000, 5.235294, 5.294117, 5.352941, 5.411765, 5.166667, 5.222222, 5.277778, 5.333333, 5.105263, 5.157895, 5.210526, 5.263158, 6.733333, 6.800000, 6.866667, 6.933333, 6.562500, 6.625000, 6.687500, 6.750000, 6.411765, 6.470588, 6.529412, 6.588235, 6.277778, 6.333333, 6.388889, 6.444445, 6.157895, 6.210526, 6.263158, 6.315790 }, sd::DataType::FLOAT32); - - auto e = NDArray('f', { 2, 3, 5, 4 }, sd::DataType::FLOAT32); - e.assign(eC); - - x.linspace(1.f); - y.linspace(10.f); - z.assign(0.f); - - x.applyTrueBroadcast(BroadcastOpsTuple::Divide(), y, z); - - ASSERT_EQ(e, z); + // x[4, 12, 128, 128] * y[4, 1, 128, 1] = z[4, 12, 128, 128] + + auto x = NDArray('f', {2, 3, 5, 4}, sd::DataType::FLOAT32); + auto y = NDArray('f', {2, 1, 5, 1}, sd::DataType::FLOAT32); + auto z = NDArray('f', {2, 3, 5, 4}, sd::DataType::FLOAT32); + // recieved by main algorithm + auto eC = NDArray( + 'c', {2, 3, 5, 4}, + {0.100000, 0.200000, 0.300000, 0.400000, 0.454545, 0.545455, 0.636364, + 0.727273, 0.750000, 0.833333, 0.916667, 1.000000, 1.000000, 1.076923, + 1.153846, 1.230769, 1.214286, 1.285714, 1.357143, 1.428571, 2.100000, + 2.200000, 2.300000, 2.400000, 2.272727, 2.363636, 2.454545, 2.545455, + 2.416667, 2.500000, 2.583333, 2.666667, 2.538461, 2.615385, 2.692308, + 2.769231, 2.642857, 2.714286, 2.785714, 2.857143, 4.100000, 4.200000, + 4.300000, 4.400000, 4.090909, 4.181818, 4.272727, 4.363636, 4.083333, + 4.166667, 4.250000, 4.333333, 4.076923, 4.153846, 4.230769, 4.307693, + 4.071429, 4.142857, 4.214286, 4.285714, 4.066667, 4.133333, 4.200000, + 4.266667, 4.062500, 4.125000, 4.187500, 4.250000, 4.058824, 4.117647, + 4.176471, 4.235294, 4.055555, 4.111111, 4.166667, 4.222222, 4.052631, + 4.105263, 4.157895, 4.210526, 5.400000, 5.466667, 5.533333, 5.600000, + 5.312500, 5.375000, 5.437500, 5.500000, 5.235294, 5.294117, 5.352941, + 5.411765, 5.166667, 5.222222, 5.277778, 5.333333, 5.105263, 5.157895, + 5.210526, 5.263158, 6.733333, 6.800000, 6.866667, 6.933333, 6.562500, + 6.625000, 6.687500, 6.750000, 6.411765, 6.470588, 6.529412, 6.588235, + 6.277778, 6.333333, 6.388889, 6.444445, 6.157895, 6.210526, 6.263158, + 6.315790}, + sd::DataType::FLOAT32); + + auto e = NDArray('f', {2, 3, 5, 4}, sd::DataType::FLOAT32); + e.assign(eC); + + x.linspace(1.f); + y.linspace(10.f); + z.assign(0.f); + + x.applyTrueBroadcast(BroadcastOpsTuple::Divide(), y, z); + + ASSERT_EQ(e, z); } /////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, Test_broadcast_5D_1) { - // x[4, 12, 128, 128, 128] * y[4, 1, 128, 128, 128] = z[4, 12, 128, 128, 128] - auto x = NDArray('c', { 2, 3, 5, 4, 3 }, sd::DataType::FLOAT32); - auto y = NDArray('c', { 2, 1, 5, 4, 3 }, sd::DataType::FLOAT32); - auto z = NDArray('c', { 2, 3, 5, 4, 3 }, sd::DataType::FLOAT32); - // recieved by main algorithm - auto e = NDArray('c', { 2, 3, 5, 4, 3 }, { 10.000000, 22.000000, 36.000000, 52.000000, 70.000000, 90.000000, 112.000000, 136.000000, 162.000000, 190.000000, 220.000000, 252.000000, 286.000000, 322.000000, 360.000000, 400.000000, 442.000000, 486.000000, 532.000000, 580.000000, 630.000000, 682.000000, 736.000000, 792.000000, 850.000000, 910.000000, 972.000000, 1036.000000, 1102.000000, 1170.000000, 1240.000000, 1312.000000, 1386.000000, 1462.000000, 1540.000000, 1620.000000, 1702.000000, 1786.000000, 1872.000000, 1960.000000, 2050.000000, 2142.000000, 2236.000000, 2332.000000, 2430.000000, 2530.000000, 2632.000000, 2736.000000, 2842.000000, 2950.000000, 3060.000000, 3172.000000, 3286.000000, 3402.000000, 3520.000000, 3640.000000, 3762.000000, 3886.000000, 4012.000000, 4140.000000, 610.000000, 682.000000, 756.000000, 832.000000, 910.000000, 990.000000, 1072.000000, 1156.000000, 1242.000000, 1330.000000, 1420.000000, 1512.000000, 1606.000000, 1702.000000, 1800.000000, 1900.000000, 2002.000000, 2106.000000, 2212.000000, 2320.000000, 2430.000000, 2542.000000, 2656.000000, 2772.000000, 2890.000000, 3010.000000, 3132.000000, 3256.000000, 3382.000000, 3510.000000, 3640.000000, 3772.000000, 3906.000000, 4042.000000, 4180.000000, 4320.000000, 4462.000000, 4606.000000, 4752.000000, 4900.000000, 5050.000000, 5202.000000, 5356.000000, 5512.000000, 5670.000000, 5830.000000, 5992.000000, 6156.000000, 6322.000000, 6490.000000, 6660.000000, 6832.000000, 7006.000000, 7182.000000, 7360.000000, 7540.000000, 7722.000000, 7906.000000, 8092.000000, 8280.000000, 1210.000000, 1342.000000, 1476.000000, 1612.000000, 1750.000000, 1890.000000, 2032.000000, 2176.000000, 2322.000000, 2470.000000, 2620.000000, 2772.000000, 2926.000000, 3082.000000, 3240.000000, 3400.000000, 3562.000000, 3726.000000, 3892.000000, 4060.000000, 4230.000000, 4402.000000, 4576.000000, 4752.000000, 4930.000000, 5110.000000, 5292.000000, 5476.000000, 5662.000000, 5850.000000, 6040.000000, 6232.000000, 6426.000000, 6622.000000, 6820.000000, 7020.000000, 7222.000000, 7426.000000, 7632.000000, 7840.000000, 8050.000000, 8262.000000, 8476.000000, 8692.000000, 8910.000000, 9130.000000, 9352.000000, 9576.000000, 9802.000000, 10030.000000, 10260.000000, 10492.000000, 10726.000000, 10962.000000, 11200.000000, 11440.000000, 11682.000000, 11926.000000, 12172.000000, 12420.000000, 12670.000000, 12922.000000, 13176.000000, 13432.000000, 13690.000000, 13950.000000, 14212.000000, 14476.000000, 14742.000000, 15010.000000, 15280.000000, 15552.000000, 15826.000000, 16102.000000, 16380.000000, 16660.000000, 16942.000000, 17226.000000, 17512.000000, 17800.000000, 18090.000000, 18382.000000, 18676.000000, 18972.000000, 19270.000000, 19570.000000, 19872.000000, 20176.000000, 20482.000000, 20790.000000, 21100.000000, 21412.000000, 21726.000000, 22042.000000, 22360.000000, 22680.000000, 23002.000000, 23326.000000, 23652.000000, 23980.000000, 24310.000000, 24642.000000, 24976.000000, 25312.000000, 25650.000000, 25990.000000, 26332.000000, 26676.000000, 27022.000000, 27370.000000, 27720.000000, 28072.000000, 28426.000000, 28782.000000, 29140.000000, 29500.000000, 29862.000000, 30226.000000, 30592.000000, 30960.000000, 16870.000000, 17182.000000, 17496.000000, 17812.000000, 18130.000000, 18450.000000, 18772.000000, 19096.000000, 19422.000000, 19750.000000, 20080.000000, 20412.000000, 20746.000000, 21082.000000, 21420.000000, 21760.000000, 22102.000000, 22446.000000, 22792.000000, 23140.000000, 23490.000000, 23842.000000, 24196.000000, 24552.000000, 24910.000000, 25270.000000, 25632.000000, 25996.000000, 26362.000000, 26730.000000, 27100.000000, 27472.000000, 27846.000000, 28222.000000, 28600.000000, 28980.000000, 29362.000000, 29746.000000, 30132.000000, 30520.000000, 30910.000000, 31302.000000, 31696.000000, 32092.000000, 32490.000000, 32890.000000, 33292.000000, 33696.000000, 34102.000000, 34510.000000, 34920.000000, 35332.000000, 35746.000000, 36162.000000, 36580.000000, 37000.000000, 37422.000000, 37846.000000, 38272.000000, 38700.000000, 21070.000000, 21442.000000, 21816.000000, 22192.000000, 22570.000000, 22950.000000, 23332.000000, 23716.000000, 24102.000000, 24490.000000, 24880.000000, 25272.000000, 25666.000000, 26062.000000, 26460.000000, 26860.000000, 27262.000000, 27666.000000, 28072.000000, 28480.000000, 28890.000000, 29302.000000, 29716.000000, 30132.000000, 30550.000000, 30970.000000, 31392.000000, 31816.000000, 32242.000000, 32670.000000, 33100.000000, 33532.000000, 33966.000000, 34402.000000, 34840.000000, 35280.000000, 35722.000000, 36166.000000, 36612.000000, 37060.000000, 37510.000000, 37962.000000, 38416.000000, 38872.000000, 39330.000000, 39790.000000, 40252.000000, 40716.000000, 41182.000000, 41650.000000, 42120.000000, 42592.000000, 43066.000000, 43542.000000, 44020.000000, 44500.000000, 44982.000000, 45466.000000, 45952.000000, 46440.000000 }, sd::DataType::FLOAT32); - - x.linspace(1.f); - y.linspace(10.f); - z.assign(0.f); - - x.applyTrueBroadcast(BroadcastOpsTuple::Multiply(), y, z); - // z.printBuffer(); - ASSERT_EQ(e, z); + // x[4, 12, 128, 128, 128] * y[4, 1, 128, 128, 128] = z[4, 12, 128, 128, 128] + auto x = NDArray('c', {2, 3, 5, 4, 3}, sd::DataType::FLOAT32); + auto y = NDArray('c', {2, 1, 5, 4, 3}, sd::DataType::FLOAT32); + auto z = NDArray('c', {2, 3, 5, 4, 3}, sd::DataType::FLOAT32); + // recieved by main algorithm + auto e = NDArray( + 'c', {2, 3, 5, 4, 3}, + {10.000000, 22.000000, 36.000000, 52.000000, 70.000000, + 90.000000, 112.000000, 136.000000, 162.000000, 190.000000, + 220.000000, 252.000000, 286.000000, 322.000000, 360.000000, + 400.000000, 442.000000, 486.000000, 532.000000, 580.000000, + 630.000000, 682.000000, 736.000000, 792.000000, 850.000000, + 910.000000, 972.000000, 1036.000000, 1102.000000, 1170.000000, + 1240.000000, 1312.000000, 1386.000000, 1462.000000, 1540.000000, + 1620.000000, 1702.000000, 1786.000000, 1872.000000, 1960.000000, + 2050.000000, 2142.000000, 2236.000000, 2332.000000, 2430.000000, + 2530.000000, 2632.000000, 2736.000000, 2842.000000, 2950.000000, + 3060.000000, 3172.000000, 3286.000000, 3402.000000, 3520.000000, + 3640.000000, 3762.000000, 3886.000000, 4012.000000, 4140.000000, + 610.000000, 682.000000, 756.000000, 832.000000, 910.000000, + 990.000000, 1072.000000, 1156.000000, 1242.000000, 1330.000000, + 1420.000000, 1512.000000, 1606.000000, 1702.000000, 1800.000000, + 1900.000000, 2002.000000, 2106.000000, 2212.000000, 2320.000000, + 2430.000000, 2542.000000, 2656.000000, 2772.000000, 2890.000000, + 3010.000000, 3132.000000, 3256.000000, 3382.000000, 3510.000000, + 3640.000000, 3772.000000, 3906.000000, 4042.000000, 4180.000000, + 4320.000000, 4462.000000, 4606.000000, 4752.000000, 4900.000000, + 5050.000000, 5202.000000, 5356.000000, 5512.000000, 5670.000000, + 5830.000000, 5992.000000, 6156.000000, 6322.000000, 6490.000000, + 6660.000000, 6832.000000, 7006.000000, 7182.000000, 7360.000000, + 7540.000000, 7722.000000, 7906.000000, 8092.000000, 8280.000000, + 1210.000000, 1342.000000, 1476.000000, 1612.000000, 1750.000000, + 1890.000000, 2032.000000, 2176.000000, 2322.000000, 2470.000000, + 2620.000000, 2772.000000, 2926.000000, 3082.000000, 3240.000000, + 3400.000000, 3562.000000, 3726.000000, 3892.000000, 4060.000000, + 4230.000000, 4402.000000, 4576.000000, 4752.000000, 4930.000000, + 5110.000000, 5292.000000, 5476.000000, 5662.000000, 5850.000000, + 6040.000000, 6232.000000, 6426.000000, 6622.000000, 6820.000000, + 7020.000000, 7222.000000, 7426.000000, 7632.000000, 7840.000000, + 8050.000000, 8262.000000, 8476.000000, 8692.000000, 8910.000000, + 9130.000000, 9352.000000, 9576.000000, 9802.000000, 10030.000000, + 10260.000000, 10492.000000, 10726.000000, 10962.000000, 11200.000000, + 11440.000000, 11682.000000, 11926.000000, 12172.000000, 12420.000000, + 12670.000000, 12922.000000, 13176.000000, 13432.000000, 13690.000000, + 13950.000000, 14212.000000, 14476.000000, 14742.000000, 15010.000000, + 15280.000000, 15552.000000, 15826.000000, 16102.000000, 16380.000000, + 16660.000000, 16942.000000, 17226.000000, 17512.000000, 17800.000000, + 18090.000000, 18382.000000, 18676.000000, 18972.000000, 19270.000000, + 19570.000000, 19872.000000, 20176.000000, 20482.000000, 20790.000000, + 21100.000000, 21412.000000, 21726.000000, 22042.000000, 22360.000000, + 22680.000000, 23002.000000, 23326.000000, 23652.000000, 23980.000000, + 24310.000000, 24642.000000, 24976.000000, 25312.000000, 25650.000000, + 25990.000000, 26332.000000, 26676.000000, 27022.000000, 27370.000000, + 27720.000000, 28072.000000, 28426.000000, 28782.000000, 29140.000000, + 29500.000000, 29862.000000, 30226.000000, 30592.000000, 30960.000000, + 16870.000000, 17182.000000, 17496.000000, 17812.000000, 18130.000000, + 18450.000000, 18772.000000, 19096.000000, 19422.000000, 19750.000000, + 20080.000000, 20412.000000, 20746.000000, 21082.000000, 21420.000000, + 21760.000000, 22102.000000, 22446.000000, 22792.000000, 23140.000000, + 23490.000000, 23842.000000, 24196.000000, 24552.000000, 24910.000000, + 25270.000000, 25632.000000, 25996.000000, 26362.000000, 26730.000000, + 27100.000000, 27472.000000, 27846.000000, 28222.000000, 28600.000000, + 28980.000000, 29362.000000, 29746.000000, 30132.000000, 30520.000000, + 30910.000000, 31302.000000, 31696.000000, 32092.000000, 32490.000000, + 32890.000000, 33292.000000, 33696.000000, 34102.000000, 34510.000000, + 34920.000000, 35332.000000, 35746.000000, 36162.000000, 36580.000000, + 37000.000000, 37422.000000, 37846.000000, 38272.000000, 38700.000000, + 21070.000000, 21442.000000, 21816.000000, 22192.000000, 22570.000000, + 22950.000000, 23332.000000, 23716.000000, 24102.000000, 24490.000000, + 24880.000000, 25272.000000, 25666.000000, 26062.000000, 26460.000000, + 26860.000000, 27262.000000, 27666.000000, 28072.000000, 28480.000000, + 28890.000000, 29302.000000, 29716.000000, 30132.000000, 30550.000000, + 30970.000000, 31392.000000, 31816.000000, 32242.000000, 32670.000000, + 33100.000000, 33532.000000, 33966.000000, 34402.000000, 34840.000000, + 35280.000000, 35722.000000, 36166.000000, 36612.000000, 37060.000000, + 37510.000000, 37962.000000, 38416.000000, 38872.000000, 39330.000000, + 39790.000000, 40252.000000, 40716.000000, 41182.000000, 41650.000000, + 42120.000000, 42592.000000, 43066.000000, 43542.000000, 44020.000000, + 44500.000000, 44982.000000, 45466.000000, 45952.000000, 46440.000000}, + sd::DataType::FLOAT32); + + x.linspace(1.f); + y.linspace(10.f); + z.assign(0.f); + + x.applyTrueBroadcast(BroadcastOpsTuple::Multiply(), y, z); + // z.printBuffer(); + ASSERT_EQ(e, z); } /////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, Test_broadcast_5D_2) { - - auto x = NDArray('f', { 2, 3, 5, 4, 3 }, sd::DataType::FLOAT32); - auto y = NDArray('f', { 2, 5, 4, 3 }, sd::DataType::FLOAT32); - auto z = NDArray('f', { 2, 3, 5, 4, 3 }, sd::DataType::FLOAT32); - // recieved by main algorithm - auto eC = NDArray('c', { 2, 3, 5, 4, 3 }, { 0.100000, 0.181818, 0.250000, 0.307692, 0.357143, 0.400000, 0.437500, 0.470588, 0.500000, 0.526316, 0.550000, 0.571429, 0.590909, 0.608696, 0.625000, 0.640000, 0.653846, 0.666667, 0.678571, 0.689655, 0.700000, 0.709677, 0.718750, 0.727273, 0.735294, 0.742857, 0.750000, 0.756757, 0.763158, 0.769231, 0.775000, 0.780488, 0.785714, 0.790698, 0.795455, 0.800000, 0.804348, 0.808511, 0.812500, 0.816327, 0.820000, 0.823529, 0.826923, 0.830189, 0.833333, 0.836364, 0.839286, 0.842105, 0.844828, 0.847458, 0.850000, 0.852459, 0.854839, 0.857143, 0.859375, 0.861538, 0.863636, 0.865672, 0.867647, 0.869565, 6.100000, 5.636364, 5.250000, 4.923077, 4.642857, 4.400000, 4.187500, 4.000000, 3.833333, 3.684211, 3.550000, 3.428571, 3.318182, 3.217391, 3.125000, 3.040000, 2.961539, 2.888889, 2.821429, 2.758621, 2.700000, 2.645161, 2.593750, 2.545455, 2.500000, 2.457143, 2.416667, 2.378378, 2.342105, 2.307692, 2.275000, 2.243902, 2.214286, 2.186047, 2.159091, 2.133333, 2.108696, 2.085106, 2.062500, 2.040816, 2.020000, 2.000000, 1.980769, 1.962264, 1.944444, 1.927273, 1.910714, 1.894737, 1.879310, 1.864407, 1.850000, 1.836066, 1.822581, 1.809524, 1.796875, 1.784615, 1.772727, 1.761194, 1.750000, 1.739130, 12.100000, 11.090909, 10.250000, 9.538462, 8.928572, 8.400000, 7.937500, 7.529412, 7.166667, 6.842105, 6.550000, 6.285714, 6.045455, 5.826087, 5.625000, 5.440000, 5.269231, 5.111111, 4.964286, 4.827586, 4.700000, 4.580645, 4.468750, 4.363636, 4.264706, 4.171429, 4.083333, 4.000000, 3.921053, 3.846154, 3.775000, 3.707317, 3.642857, 3.581395, 3.522727, 3.466667, 3.413043, 3.361702, 3.312500, 3.265306, 3.220000, 3.176471, 3.134615, 3.094340, 3.055556, 3.018182, 2.982143, 2.947368, 2.913793, 2.881356, 2.850000, 2.819672, 2.790323, 2.761905, 2.734375, 2.707692, 2.681818, 2.656716, 2.632353, 2.608696, 2.585714, 2.563380, 2.541667, 2.520548, 2.500000, 2.480000, 2.460526, 2.441558, 2.423077, 2.405063, 2.387500, 2.370370, 2.353658, 2.337349, 2.321429, 2.305882, 2.290698, 2.275862, 2.261364, 2.247191, 2.233333, 2.219780, 2.206522, 2.193548, 2.180851, 2.168421, 2.156250, 2.144330, 2.132653, 2.121212, 2.110000, 2.099010, 2.088235, 2.077670, 2.067308, 2.057143, 2.047170, 2.037383, 2.027778, 2.018349, 2.009091, 2.000000, 1.991071, 1.982301, 1.973684, 1.965217, 1.956897, 1.948718, 1.940678, 1.932773, 1.925000, 1.917355, 1.909836, 1.902439, 1.895161, 1.888000, 1.880952, 1.874016, 1.867188, 1.860465, 3.442857, 3.408451, 3.375000, 3.342466, 3.310811, 3.280000, 3.250000, 3.220779, 3.192308, 3.164557, 3.137500, 3.111111, 3.085366, 3.060241, 3.035714, 3.011765, 2.988372, 2.965517, 2.943182, 2.921348, 2.900000, 2.879121, 2.858696, 2.838710, 2.819149, 2.800000, 2.781250, 2.762887, 2.744898, 2.727273, 2.710000, 2.693069, 2.676471, 2.660194, 2.644231, 2.628572, 2.613208, 2.598131, 2.583333, 2.568807, 2.554545, 2.540540, 2.526786, 2.513274, 2.500000, 2.486957, 2.474138, 2.461539, 2.449152, 2.436975, 2.425000, 2.413223, 2.401639, 2.390244, 2.379032, 2.368000, 2.357143, 2.346457, 2.335938, 2.325581, 4.300000, 4.253521, 4.208333, 4.164383, 4.121622, 4.080000, 4.039474, 4.000000, 3.961539, 3.924051, 3.887500, 3.851852, 3.817073, 3.783133, 3.750000, 3.717647, 3.686047, 3.655172, 3.625000, 3.595506, 3.566667, 3.538461, 3.510870, 3.483871, 3.457447, 3.431579, 3.406250, 3.381443, 3.357143, 3.333333, 3.310000, 3.287129, 3.264706, 3.242718, 3.221154, 3.200000, 3.179245, 3.158879, 3.138889, 3.119266, 3.100000, 3.081081, 3.062500, 3.044248, 3.026316, 3.008696, 2.991379, 2.974359, 2.957627, 2.941176, 2.925000, 2.909091, 2.893443, 2.878049, 2.862903, 2.848000, 2.833333, 2.818898, 2.804688, 2.790698 }, sd::DataType::FLOAT32); - - auto e = NDArray('f', { 2, 3, 5, 4, 3 }, sd::DataType::FLOAT32); - - e.assign(eC); - - x.linspace(1.f); - y.linspace(10.f); - z.assign(0.f); - - x.applyBroadcast(sd::broadcast::Divide, { 0,2,3,4 }, y, z); - - ASSERT_EQ(e, z); + auto x = NDArray('f', {2, 3, 5, 4, 3}, sd::DataType::FLOAT32); + auto y = NDArray('f', {2, 5, 4, 3}, sd::DataType::FLOAT32); + auto z = NDArray('f', {2, 3, 5, 4, 3}, sd::DataType::FLOAT32); + // recieved by main algorithm + auto eC = NDArray( + 'c', {2, 3, 5, 4, 3}, + {0.100000, 0.181818, 0.250000, 0.307692, 0.357143, 0.400000, 0.437500, + 0.470588, 0.500000, 0.526316, 0.550000, 0.571429, 0.590909, 0.608696, + 0.625000, 0.640000, 0.653846, 0.666667, 0.678571, 0.689655, 0.700000, + 0.709677, 0.718750, 0.727273, 0.735294, 0.742857, 0.750000, 0.756757, + 0.763158, 0.769231, 0.775000, 0.780488, 0.785714, 0.790698, 0.795455, + 0.800000, 0.804348, 0.808511, 0.812500, 0.816327, 0.820000, 0.823529, + 0.826923, 0.830189, 0.833333, 0.836364, 0.839286, 0.842105, 0.844828, + 0.847458, 0.850000, 0.852459, 0.854839, 0.857143, 0.859375, 0.861538, + 0.863636, 0.865672, 0.867647, 0.869565, 6.100000, 5.636364, 5.250000, + 4.923077, 4.642857, 4.400000, 4.187500, 4.000000, 3.833333, 3.684211, + 3.550000, 3.428571, 3.318182, 3.217391, 3.125000, 3.040000, 2.961539, + 2.888889, 2.821429, 2.758621, 2.700000, 2.645161, 2.593750, 2.545455, + 2.500000, 2.457143, 2.416667, 2.378378, 2.342105, 2.307692, 2.275000, + 2.243902, 2.214286, 2.186047, 2.159091, 2.133333, 2.108696, 2.085106, + 2.062500, 2.040816, 2.020000, 2.000000, 1.980769, 1.962264, 1.944444, + 1.927273, 1.910714, 1.894737, 1.879310, 1.864407, 1.850000, 1.836066, + 1.822581, 1.809524, 1.796875, 1.784615, 1.772727, 1.761194, 1.750000, + 1.739130, 12.100000, 11.090909, 10.250000, 9.538462, 8.928572, 8.400000, + 7.937500, 7.529412, 7.166667, 6.842105, 6.550000, 6.285714, 6.045455, + 5.826087, 5.625000, 5.440000, 5.269231, 5.111111, 4.964286, 4.827586, + 4.700000, 4.580645, 4.468750, 4.363636, 4.264706, 4.171429, 4.083333, + 4.000000, 3.921053, 3.846154, 3.775000, 3.707317, 3.642857, 3.581395, + 3.522727, 3.466667, 3.413043, 3.361702, 3.312500, 3.265306, 3.220000, + 3.176471, 3.134615, 3.094340, 3.055556, 3.018182, 2.982143, 2.947368, + 2.913793, 2.881356, 2.850000, 2.819672, 2.790323, 2.761905, 2.734375, + 2.707692, 2.681818, 2.656716, 2.632353, 2.608696, 2.585714, 2.563380, + 2.541667, 2.520548, 2.500000, 2.480000, 2.460526, 2.441558, 2.423077, + 2.405063, 2.387500, 2.370370, 2.353658, 2.337349, 2.321429, 2.305882, + 2.290698, 2.275862, 2.261364, 2.247191, 2.233333, 2.219780, 2.206522, + 2.193548, 2.180851, 2.168421, 2.156250, 2.144330, 2.132653, 2.121212, + 2.110000, 2.099010, 2.088235, 2.077670, 2.067308, 2.057143, 2.047170, + 2.037383, 2.027778, 2.018349, 2.009091, 2.000000, 1.991071, 1.982301, + 1.973684, 1.965217, 1.956897, 1.948718, 1.940678, 1.932773, 1.925000, + 1.917355, 1.909836, 1.902439, 1.895161, 1.888000, 1.880952, 1.874016, + 1.867188, 1.860465, 3.442857, 3.408451, 3.375000, 3.342466, 3.310811, + 3.280000, 3.250000, 3.220779, 3.192308, 3.164557, 3.137500, 3.111111, + 3.085366, 3.060241, 3.035714, 3.011765, 2.988372, 2.965517, 2.943182, + 2.921348, 2.900000, 2.879121, 2.858696, 2.838710, 2.819149, 2.800000, + 2.781250, 2.762887, 2.744898, 2.727273, 2.710000, 2.693069, 2.676471, + 2.660194, 2.644231, 2.628572, 2.613208, 2.598131, 2.583333, 2.568807, + 2.554545, 2.540540, 2.526786, 2.513274, 2.500000, 2.486957, 2.474138, + 2.461539, 2.449152, 2.436975, 2.425000, 2.413223, 2.401639, 2.390244, + 2.379032, 2.368000, 2.357143, 2.346457, 2.335938, 2.325581, 4.300000, + 4.253521, 4.208333, 4.164383, 4.121622, 4.080000, 4.039474, 4.000000, + 3.961539, 3.924051, 3.887500, 3.851852, 3.817073, 3.783133, 3.750000, + 3.717647, 3.686047, 3.655172, 3.625000, 3.595506, 3.566667, 3.538461, + 3.510870, 3.483871, 3.457447, 3.431579, 3.406250, 3.381443, 3.357143, + 3.333333, 3.310000, 3.287129, 3.264706, 3.242718, 3.221154, 3.200000, + 3.179245, 3.158879, 3.138889, 3.119266, 3.100000, 3.081081, 3.062500, + 3.044248, 3.026316, 3.008696, 2.991379, 2.974359, 2.957627, 2.941176, + 2.925000, 2.909091, 2.893443, 2.878049, 2.862903, 2.848000, 2.833333, + 2.818898, 2.804688, 2.790698}, + sd::DataType::FLOAT32); + + auto e = NDArray('f', {2, 3, 5, 4, 3}, sd::DataType::FLOAT32); + + e.assign(eC); + + x.linspace(1.f); + y.linspace(10.f); + z.assign(0.f); + + x.applyBroadcast(sd::broadcast::Divide, {0, 2, 3, 4}, y, z); + + ASSERT_EQ(e, z); } /////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, Test_broadcast_5D_3) { - - auto x = NDArray('f', { 2, 3, 5, 4, 3 }, sd::DataType::FLOAT32); - auto y = NDArray('f', { 2, 5 }, sd::DataType::FLOAT32); - auto z = NDArray('f', { 2, 3, 5, 4, 3 }, sd::DataType::FLOAT32); - // recieved by main algorithm - auto eC = NDArray('c', { 2, 3, 5, 4, 3 }, { 0.100000, 0.200000, 0.300000, 0.400000, 0.500000, 0.600000, 0.700000, 0.800000, 0.900000, 1.000000, 1.100000, 1.200000, 1.181818, 1.272727, 1.363636, 1.454545, 1.545455, 1.636364, 1.727273, 1.818182, 1.909091, 2.000000, 2.090909, 2.181818, 2.083333, 2.166667, 2.250000, 2.333333, 2.416667, 2.500000, 2.583333, 2.666667, 2.750000, 2.833333, 2.916667, 3.000000, 2.846154, 2.923077, 3.000000, 3.076923, 3.153846, 3.230769, 3.307692, 3.384615, 3.461539, 3.538461, 3.615385, 3.692308, 3.500000, 3.571429, 3.642857, 3.714286, 3.785714, 3.857143, 3.928571, 4.000000, 4.071429, 4.142857, 4.214286, 4.285714, 6.100000, 6.200000, 6.300000, 6.400000, 6.500000, 6.600000, 6.700000, 6.800000, 6.900000, 7.000000, 7.100000, 7.200000, 6.636364, 6.727273, 6.818182, 6.909091, 7.000000, 7.090909, 7.181818, 7.272727, 7.363636, 7.454545, 7.545455, 7.636364, 7.083333, 7.166667, 7.250000, 7.333333, 7.416667, 7.500000, 7.583333, 7.666667, 7.750000, 7.833333, 7.916667, 8.000000, 7.461538, 7.538462, 7.615385, 7.692307, 7.769231, 7.846154, 7.923077, 8.000000, 8.076923, 8.153846, 8.230769, 8.307693, 7.785714, 7.857143, 7.928571, 8.000000, 8.071428, 8.142858, 8.214286, 8.285714, 8.357142, 8.428572, 8.500000, 8.571428, 12.100000, 12.200000, 12.300000, 12.400000, 12.500000, 12.600000, 12.700000, 12.800000, 12.900000, 13.000000, 13.100000, 13.200000, 12.090909, 12.181818, 12.272727, 12.363636, 12.454545, 12.545455, 12.636364, 12.727273, 12.818182, 12.909091, 13.000000, 13.090909, 12.083333, 12.166667, 12.250000, 12.333333, 12.416667, 12.500000, 12.583333, 12.666667, 12.750000, 12.833333, 12.916667, 13.000000, 12.076923, 12.153846, 12.230769, 12.307693, 12.384615, 12.461538, 12.538462, 12.615385, 12.692307, 12.769231, 12.846154, 12.923077, 12.071428, 12.142858, 12.214286, 12.285714, 12.357142, 12.428572, 12.500000, 12.571428, 12.642858, 12.714286, 12.785714, 12.857142, 12.066667, 12.133333, 12.200000, 12.266666, 12.333333, 12.400000, 12.466666, 12.533334, 12.600000, 12.666667, 12.733334, 12.800000, 12.062500, 12.125000, 12.187500, 12.250000, 12.312500, 12.375000, 12.437500, 12.500000, 12.562500, 12.625000, 12.687500, 12.750000, 12.058824, 12.117647, 12.176471, 12.235294, 12.294118, 12.352942, 12.411765, 12.470589, 12.529411, 12.588235, 12.647058, 12.705882, 12.055555, 12.111111, 12.166667, 12.222222, 12.277778, 12.333333, 12.388889, 12.444445, 12.500000, 12.555555, 12.611111, 12.666667, 12.052631, 12.105263, 12.157895, 12.210526, 12.263158, 12.315789, 12.368421, 12.421053, 12.473684, 12.526316, 12.578947, 12.631579, 16.066668, 16.133333, 16.200001, 16.266666, 16.333334, 16.400000, 16.466667, 16.533333, 16.600000, 16.666666, 16.733334, 16.799999, 15.812500, 15.875000, 15.937500, 16.000000, 16.062500, 16.125000, 16.187500, 16.250000, 16.312500, 16.375000, 16.437500, 16.500000, 15.588235, 15.647058, 15.705882, 15.764706, 15.823529, 15.882353, 15.941176, 16.000000, 16.058823, 16.117647, 16.176470, 16.235294, 15.388889, 15.444445, 15.500000, 15.555555, 15.611111, 15.666667, 15.722222, 15.777778, 15.833333, 15.888889, 15.944445, 16.000000, 15.210526, 15.263158, 15.315789, 15.368421, 15.421053, 15.473684, 15.526316, 15.578947, 15.631579, 15.684211, 15.736842, 15.789474, 20.066668, 20.133333, 20.200001, 20.266666, 20.333334, 20.400000, 20.466667, 20.533333, 20.600000, 20.666666, 20.733334, 20.799999, 19.562500, 19.625000, 19.687500, 19.750000, 19.812500, 19.875000, 19.937500, 20.000000, 20.062500, 20.125000, 20.187500, 20.250000, 19.117647, 19.176470, 19.235294, 19.294117, 19.352942, 19.411764, 19.470589, 19.529411, 19.588236, 19.647058, 19.705883, 19.764706, 18.722221, 18.777779, 18.833334, 18.888889, 18.944445, 19.000000, 19.055555, 19.111111, 19.166666, 19.222221, 19.277779, 19.333334, 18.368422, 18.421053, 18.473684, 18.526316, 18.578947, 18.631578, 18.684210, 18.736841, 18.789474, 18.842106, 18.894737, 18.947369 }, sd::DataType::FLOAT32); - - auto e = NDArray('f', { 2, 3, 5, 4, 3 }, sd::DataType::FLOAT32); - - e.assign(eC); - - x.linspace(1.f); - y.linspace(10.f); - z.assign(0.f); - - x.applyBroadcast(sd::broadcast::Divide, { 0,2 }, y, z); - - ASSERT_EQ(e, z); + auto x = NDArray('f', {2, 3, 5, 4, 3}, sd::DataType::FLOAT32); + auto y = NDArray('f', {2, 5}, sd::DataType::FLOAT32); + auto z = NDArray('f', {2, 3, 5, 4, 3}, sd::DataType::FLOAT32); + // recieved by main algorithm + auto eC = NDArray( + 'c', {2, 3, 5, 4, 3}, + {0.100000, 0.200000, 0.300000, 0.400000, 0.500000, 0.600000, + 0.700000, 0.800000, 0.900000, 1.000000, 1.100000, 1.200000, + 1.181818, 1.272727, 1.363636, 1.454545, 1.545455, 1.636364, + 1.727273, 1.818182, 1.909091, 2.000000, 2.090909, 2.181818, + 2.083333, 2.166667, 2.250000, 2.333333, 2.416667, 2.500000, + 2.583333, 2.666667, 2.750000, 2.833333, 2.916667, 3.000000, + 2.846154, 2.923077, 3.000000, 3.076923, 3.153846, 3.230769, + 3.307692, 3.384615, 3.461539, 3.538461, 3.615385, 3.692308, + 3.500000, 3.571429, 3.642857, 3.714286, 3.785714, 3.857143, + 3.928571, 4.000000, 4.071429, 4.142857, 4.214286, 4.285714, + 6.100000, 6.200000, 6.300000, 6.400000, 6.500000, 6.600000, + 6.700000, 6.800000, 6.900000, 7.000000, 7.100000, 7.200000, + 6.636364, 6.727273, 6.818182, 6.909091, 7.000000, 7.090909, + 7.181818, 7.272727, 7.363636, 7.454545, 7.545455, 7.636364, + 7.083333, 7.166667, 7.250000, 7.333333, 7.416667, 7.500000, + 7.583333, 7.666667, 7.750000, 7.833333, 7.916667, 8.000000, + 7.461538, 7.538462, 7.615385, 7.692307, 7.769231, 7.846154, + 7.923077, 8.000000, 8.076923, 8.153846, 8.230769, 8.307693, + 7.785714, 7.857143, 7.928571, 8.000000, 8.071428, 8.142858, + 8.214286, 8.285714, 8.357142, 8.428572, 8.500000, 8.571428, + 12.100000, 12.200000, 12.300000, 12.400000, 12.500000, 12.600000, + 12.700000, 12.800000, 12.900000, 13.000000, 13.100000, 13.200000, + 12.090909, 12.181818, 12.272727, 12.363636, 12.454545, 12.545455, + 12.636364, 12.727273, 12.818182, 12.909091, 13.000000, 13.090909, + 12.083333, 12.166667, 12.250000, 12.333333, 12.416667, 12.500000, + 12.583333, 12.666667, 12.750000, 12.833333, 12.916667, 13.000000, + 12.076923, 12.153846, 12.230769, 12.307693, 12.384615, 12.461538, + 12.538462, 12.615385, 12.692307, 12.769231, 12.846154, 12.923077, + 12.071428, 12.142858, 12.214286, 12.285714, 12.357142, 12.428572, + 12.500000, 12.571428, 12.642858, 12.714286, 12.785714, 12.857142, + 12.066667, 12.133333, 12.200000, 12.266666, 12.333333, 12.400000, + 12.466666, 12.533334, 12.600000, 12.666667, 12.733334, 12.800000, + 12.062500, 12.125000, 12.187500, 12.250000, 12.312500, 12.375000, + 12.437500, 12.500000, 12.562500, 12.625000, 12.687500, 12.750000, + 12.058824, 12.117647, 12.176471, 12.235294, 12.294118, 12.352942, + 12.411765, 12.470589, 12.529411, 12.588235, 12.647058, 12.705882, + 12.055555, 12.111111, 12.166667, 12.222222, 12.277778, 12.333333, + 12.388889, 12.444445, 12.500000, 12.555555, 12.611111, 12.666667, + 12.052631, 12.105263, 12.157895, 12.210526, 12.263158, 12.315789, + 12.368421, 12.421053, 12.473684, 12.526316, 12.578947, 12.631579, + 16.066668, 16.133333, 16.200001, 16.266666, 16.333334, 16.400000, + 16.466667, 16.533333, 16.600000, 16.666666, 16.733334, 16.799999, + 15.812500, 15.875000, 15.937500, 16.000000, 16.062500, 16.125000, + 16.187500, 16.250000, 16.312500, 16.375000, 16.437500, 16.500000, + 15.588235, 15.647058, 15.705882, 15.764706, 15.823529, 15.882353, + 15.941176, 16.000000, 16.058823, 16.117647, 16.176470, 16.235294, + 15.388889, 15.444445, 15.500000, 15.555555, 15.611111, 15.666667, + 15.722222, 15.777778, 15.833333, 15.888889, 15.944445, 16.000000, + 15.210526, 15.263158, 15.315789, 15.368421, 15.421053, 15.473684, + 15.526316, 15.578947, 15.631579, 15.684211, 15.736842, 15.789474, + 20.066668, 20.133333, 20.200001, 20.266666, 20.333334, 20.400000, + 20.466667, 20.533333, 20.600000, 20.666666, 20.733334, 20.799999, + 19.562500, 19.625000, 19.687500, 19.750000, 19.812500, 19.875000, + 19.937500, 20.000000, 20.062500, 20.125000, 20.187500, 20.250000, + 19.117647, 19.176470, 19.235294, 19.294117, 19.352942, 19.411764, + 19.470589, 19.529411, 19.588236, 19.647058, 19.705883, 19.764706, + 18.722221, 18.777779, 18.833334, 18.888889, 18.944445, 19.000000, + 19.055555, 19.111111, 19.166666, 19.222221, 19.277779, 19.333334, + 18.368422, 18.421053, 18.473684, 18.526316, 18.578947, 18.631578, + 18.684210, 18.736841, 18.789474, 18.842106, 18.894737, 18.947369}, + sd::DataType::FLOAT32); + + auto e = NDArray('f', {2, 3, 5, 4, 3}, sd::DataType::FLOAT32); + + e.assign(eC); + + x.linspace(1.f); + y.linspace(10.f); + z.assign(0.f); + + x.applyBroadcast(sd::broadcast::Divide, {0, 2}, y, z); + + ASSERT_EQ(e, z); } /////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, Test_broadcast_5D_4) { - - auto x = NDArray('f', { 2, 3, 5, 4, 3 }, sd::DataType::FLOAT32); - auto y = NDArray('f', { 2, 1, 5, 1, 1 }, sd::DataType::FLOAT32); - auto z = NDArray('f', { 2, 3, 5, 4, 3 }, sd::DataType::FLOAT32); - // recieved by main algorithm - auto eC = NDArray('c', { 2, 3, 5, 4, 3 }, { 0.100000, 0.200000, 0.300000, 0.400000, 0.500000, 0.600000, 0.700000, 0.800000, 0.900000, 1.000000, 1.100000, 1.200000, 1.181818, 1.272727, 1.363636, 1.454545, 1.545455, 1.636364, 1.727273, 1.818182, 1.909091, 2.000000, 2.090909, 2.181818, 2.083333, 2.166667, 2.250000, 2.333333, 2.416667, 2.500000, 2.583333, 2.666667, 2.750000, 2.833333, 2.916667, 3.000000, 2.846154, 2.923077, 3.000000, 3.076923, 3.153846, 3.230769, 3.307692, 3.384615, 3.461539, 3.538461, 3.615385, 3.692308, 3.500000, 3.571429, 3.642857, 3.714286, 3.785714, 3.857143, 3.928571, 4.000000, 4.071429, 4.142857, 4.214286, 4.285714, 6.100000, 6.200000, 6.300000, 6.400000, 6.500000, 6.600000, 6.700000, 6.800000, 6.900000, 7.000000, 7.100000, 7.200000, 6.636364, 6.727273, 6.818182, 6.909091, 7.000000, 7.090909, 7.181818, 7.272727, 7.363636, 7.454545, 7.545455, 7.636364, 7.083333, 7.166667, 7.250000, 7.333333, 7.416667, 7.500000, 7.583333, 7.666667, 7.750000, 7.833333, 7.916667, 8.000000, 7.461538, 7.538462, 7.615385, 7.692307, 7.769231, 7.846154, 7.923077, 8.000000, 8.076923, 8.153846, 8.230769, 8.307693, 7.785714, 7.857143, 7.928571, 8.000000, 8.071428, 8.142858, 8.214286, 8.285714, 8.357142, 8.428572, 8.500000, 8.571428, 12.100000, 12.200000, 12.300000, 12.400000, 12.500000, 12.600000, 12.700000, 12.800000, 12.900000, 13.000000, 13.100000, 13.200000, 12.090909, 12.181818, 12.272727, 12.363636, 12.454545, 12.545455, 12.636364, 12.727273, 12.818182, 12.909091, 13.000000, 13.090909, 12.083333, 12.166667, 12.250000, 12.333333, 12.416667, 12.500000, 12.583333, 12.666667, 12.750000, 12.833333, 12.916667, 13.000000, 12.076923, 12.153846, 12.230769, 12.307693, 12.384615, 12.461538, 12.538462, 12.615385, 12.692307, 12.769231, 12.846154, 12.923077, 12.071428, 12.142858, 12.214286, 12.285714, 12.357142, 12.428572, 12.500000, 12.571428, 12.642858, 12.714286, 12.785714, 12.857142, 12.066667, 12.133333, 12.200000, 12.266666, 12.333333, 12.400000, 12.466666, 12.533334, 12.600000, 12.666667, 12.733334, 12.800000, 12.062500, 12.125000, 12.187500, 12.250000, 12.312500, 12.375000, 12.437500, 12.500000, 12.562500, 12.625000, 12.687500, 12.750000, 12.058824, 12.117647, 12.176471, 12.235294, 12.294118, 12.352942, 12.411765, 12.470589, 12.529411, 12.588235, 12.647058, 12.705882, 12.055555, 12.111111, 12.166667, 12.222222, 12.277778, 12.333333, 12.388889, 12.444445, 12.500000, 12.555555, 12.611111, 12.666667, 12.052631, 12.105263, 12.157895, 12.210526, 12.263158, 12.315789, 12.368421, 12.421053, 12.473684, 12.526316, 12.578947, 12.631579, 16.066668, 16.133333, 16.200001, 16.266666, 16.333334, 16.400000, 16.466667, 16.533333, 16.600000, 16.666666, 16.733334, 16.799999, 15.812500, 15.875000, 15.937500, 16.000000, 16.062500, 16.125000, 16.187500, 16.250000, 16.312500, 16.375000, 16.437500, 16.500000, 15.588235, 15.647058, 15.705882, 15.764706, 15.823529, 15.882353, 15.941176, 16.000000, 16.058823, 16.117647, 16.176470, 16.235294, 15.388889, 15.444445, 15.500000, 15.555555, 15.611111, 15.666667, 15.722222, 15.777778, 15.833333, 15.888889, 15.944445, 16.000000, 15.210526, 15.263158, 15.315789, 15.368421, 15.421053, 15.473684, 15.526316, 15.578947, 15.631579, 15.684211, 15.736842, 15.789474, 20.066668, 20.133333, 20.200001, 20.266666, 20.333334, 20.400000, 20.466667, 20.533333, 20.600000, 20.666666, 20.733334, 20.799999, 19.562500, 19.625000, 19.687500, 19.750000, 19.812500, 19.875000, 19.937500, 20.000000, 20.062500, 20.125000, 20.187500, 20.250000, 19.117647, 19.176470, 19.235294, 19.294117, 19.352942, 19.411764, 19.470589, 19.529411, 19.588236, 19.647058, 19.705883, 19.764706, 18.722221, 18.777779, 18.833334, 18.888889, 18.944445, 19.000000, 19.055555, 19.111111, 19.166666, 19.222221, 19.277779, 19.333334, 18.368422, 18.421053, 18.473684, 18.526316, 18.578947, 18.631578, 18.684210, 18.736841, 18.789474, 18.842106, 18.894737, 18.947369 }, sd::DataType::FLOAT32); - - auto e = NDArray('f', { 2, 3, 5, 4, 3 }, sd::DataType::FLOAT32); - e.assign(eC); - - x.linspace(1.f); - y.linspace(10.f); - z.assign(0.f); - - x.applyTrueBroadcast(BroadcastOpsTuple::Divide(), y, z); - - ASSERT_EQ(e, z); + auto x = NDArray('f', {2, 3, 5, 4, 3}, sd::DataType::FLOAT32); + auto y = NDArray('f', {2, 1, 5, 1, 1}, sd::DataType::FLOAT32); + auto z = NDArray('f', {2, 3, 5, 4, 3}, sd::DataType::FLOAT32); + // recieved by main algorithm + auto eC = NDArray( + 'c', {2, 3, 5, 4, 3}, + {0.100000, 0.200000, 0.300000, 0.400000, 0.500000, 0.600000, + 0.700000, 0.800000, 0.900000, 1.000000, 1.100000, 1.200000, + 1.181818, 1.272727, 1.363636, 1.454545, 1.545455, 1.636364, + 1.727273, 1.818182, 1.909091, 2.000000, 2.090909, 2.181818, + 2.083333, 2.166667, 2.250000, 2.333333, 2.416667, 2.500000, + 2.583333, 2.666667, 2.750000, 2.833333, 2.916667, 3.000000, + 2.846154, 2.923077, 3.000000, 3.076923, 3.153846, 3.230769, + 3.307692, 3.384615, 3.461539, 3.538461, 3.615385, 3.692308, + 3.500000, 3.571429, 3.642857, 3.714286, 3.785714, 3.857143, + 3.928571, 4.000000, 4.071429, 4.142857, 4.214286, 4.285714, + 6.100000, 6.200000, 6.300000, 6.400000, 6.500000, 6.600000, + 6.700000, 6.800000, 6.900000, 7.000000, 7.100000, 7.200000, + 6.636364, 6.727273, 6.818182, 6.909091, 7.000000, 7.090909, + 7.181818, 7.272727, 7.363636, 7.454545, 7.545455, 7.636364, + 7.083333, 7.166667, 7.250000, 7.333333, 7.416667, 7.500000, + 7.583333, 7.666667, 7.750000, 7.833333, 7.916667, 8.000000, + 7.461538, 7.538462, 7.615385, 7.692307, 7.769231, 7.846154, + 7.923077, 8.000000, 8.076923, 8.153846, 8.230769, 8.307693, + 7.785714, 7.857143, 7.928571, 8.000000, 8.071428, 8.142858, + 8.214286, 8.285714, 8.357142, 8.428572, 8.500000, 8.571428, + 12.100000, 12.200000, 12.300000, 12.400000, 12.500000, 12.600000, + 12.700000, 12.800000, 12.900000, 13.000000, 13.100000, 13.200000, + 12.090909, 12.181818, 12.272727, 12.363636, 12.454545, 12.545455, + 12.636364, 12.727273, 12.818182, 12.909091, 13.000000, 13.090909, + 12.083333, 12.166667, 12.250000, 12.333333, 12.416667, 12.500000, + 12.583333, 12.666667, 12.750000, 12.833333, 12.916667, 13.000000, + 12.076923, 12.153846, 12.230769, 12.307693, 12.384615, 12.461538, + 12.538462, 12.615385, 12.692307, 12.769231, 12.846154, 12.923077, + 12.071428, 12.142858, 12.214286, 12.285714, 12.357142, 12.428572, + 12.500000, 12.571428, 12.642858, 12.714286, 12.785714, 12.857142, + 12.066667, 12.133333, 12.200000, 12.266666, 12.333333, 12.400000, + 12.466666, 12.533334, 12.600000, 12.666667, 12.733334, 12.800000, + 12.062500, 12.125000, 12.187500, 12.250000, 12.312500, 12.375000, + 12.437500, 12.500000, 12.562500, 12.625000, 12.687500, 12.750000, + 12.058824, 12.117647, 12.176471, 12.235294, 12.294118, 12.352942, + 12.411765, 12.470589, 12.529411, 12.588235, 12.647058, 12.705882, + 12.055555, 12.111111, 12.166667, 12.222222, 12.277778, 12.333333, + 12.388889, 12.444445, 12.500000, 12.555555, 12.611111, 12.666667, + 12.052631, 12.105263, 12.157895, 12.210526, 12.263158, 12.315789, + 12.368421, 12.421053, 12.473684, 12.526316, 12.578947, 12.631579, + 16.066668, 16.133333, 16.200001, 16.266666, 16.333334, 16.400000, + 16.466667, 16.533333, 16.600000, 16.666666, 16.733334, 16.799999, + 15.812500, 15.875000, 15.937500, 16.000000, 16.062500, 16.125000, + 16.187500, 16.250000, 16.312500, 16.375000, 16.437500, 16.500000, + 15.588235, 15.647058, 15.705882, 15.764706, 15.823529, 15.882353, + 15.941176, 16.000000, 16.058823, 16.117647, 16.176470, 16.235294, + 15.388889, 15.444445, 15.500000, 15.555555, 15.611111, 15.666667, + 15.722222, 15.777778, 15.833333, 15.888889, 15.944445, 16.000000, + 15.210526, 15.263158, 15.315789, 15.368421, 15.421053, 15.473684, + 15.526316, 15.578947, 15.631579, 15.684211, 15.736842, 15.789474, + 20.066668, 20.133333, 20.200001, 20.266666, 20.333334, 20.400000, + 20.466667, 20.533333, 20.600000, 20.666666, 20.733334, 20.799999, + 19.562500, 19.625000, 19.687500, 19.750000, 19.812500, 19.875000, + 19.937500, 20.000000, 20.062500, 20.125000, 20.187500, 20.250000, + 19.117647, 19.176470, 19.235294, 19.294117, 19.352942, 19.411764, + 19.470589, 19.529411, 19.588236, 19.647058, 19.705883, 19.764706, + 18.722221, 18.777779, 18.833334, 18.888889, 18.944445, 19.000000, + 19.055555, 19.111111, 19.166666, 19.222221, 19.277779, 19.333334, + 18.368422, 18.421053, 18.473684, 18.526316, 18.578947, 18.631578, + 18.684210, 18.736841, 18.789474, 18.842106, 18.894737, 18.947369}, + sd::DataType::FLOAT32); + + auto e = NDArray('f', {2, 3, 5, 4, 3}, sd::DataType::FLOAT32); + e.assign(eC); + + x.linspace(1.f); + y.linspace(10.f); + z.assign(0.f); + + x.applyTrueBroadcast(BroadcastOpsTuple::Divide(), y, z); + + ASSERT_EQ(e, z); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, Stack_1) { + float buff1[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; + float buff2[] = {13, 14, 16, 16, 17, 18, 19, 20, 21, 22, 23, 24}; + float expBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 16, 16, 17, 18, 19, 20, 21, 22, 23, 24}; + Nd4jLong shape1[] = {2, 3, 4, 4, 1, 0, 1, 99}; + Nd4jLong shape2[] = {2, 3, 4, 4, 1, 0, 1, 99}; + Nd4jLong expShape[] = {3, 2, 3, 4, 12, 4, 1, 0, 1, 99}; + ArrayOptions::setDataType(shape1, sd::DataType::FLOAT32); + ArrayOptions::setDataType(shape2, sd::DataType::FLOAT32); + ArrayOptions::setDataType(expShape, sd::DataType::FLOAT32); - float buff1[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10,11,12}; - float buff2[] = {13,14,16,16,17,18,19,20,21,22,23,24}; - float expBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10,11,12,13,14,16,16,17,18,19,20,21,22,23,24}; - Nd4jLong shape1[] = {2, 3, 4, 4, 1, 0, 1, 99}; - Nd4jLong shape2[] = {2, 3, 4, 4, 1, 0, 1, 99}; - Nd4jLong expShape[] = {3, 2, 3, 4, 12, 4, 1, 0, 1, 99}; - ArrayOptions::setDataType(shape1, sd::DataType::FLOAT32); - ArrayOptions::setDataType(shape2, sd::DataType::FLOAT32); - ArrayOptions::setDataType(expShape, sd::DataType::FLOAT32); - - NDArray input1(buff1, shape1); - NDArray input2(buff2, shape2); - NDArray expected(expBuff, expShape); - - sd::ops::stack op; - auto results = op.evaluate({&input1, &input2}, {}, {0}); - auto output = results.at(0); - - ASSERT_TRUE(expected.isSameShapeStrict(*output)); - ASSERT_TRUE(expected.equalsTo(output)); - + NDArray input1(buff1, shape1); + NDArray input2(buff2, shape2); + NDArray expected(expBuff, expShape); + sd::ops::stack op; + auto results = op.evaluate({&input1, &input2}, {}, {0}); + auto output = results.at(0); + ASSERT_TRUE(expected.isSameShapeStrict(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, Stack_2) { + float buff1[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; + float buff2[] = {13, 14, 16, 16, 17, 18, 19, 20, 21, 22, 23, 24}; + float expBuff[] = {1, 2, 3, 4, 13, 14, 16, 16, 5, 6, 7, 8, + 17, 18, 19, 20, 9, 10, 11, 12, 21, 22, 23, 24}; + Nd4jLong shape1[] = {2, 3, 4, 4, 1, 0, 1, 99}; + Nd4jLong shape2[] = {2, 3, 4, 4, 1, 0, 1, 99}; + Nd4jLong expShape[] = {3, 3, 2, 4, 8, 4, 1, 0, 1, 99}; + ArrayOptions::setDataType(shape1, sd::DataType::FLOAT32); + ArrayOptions::setDataType(shape2, sd::DataType::FLOAT32); + ArrayOptions::setDataType(expShape, sd::DataType::FLOAT32); - float buff1[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10,11,12}; - float buff2[] = {13,14,16,16,17,18,19,20,21,22,23,24}; - float expBuff[] = {1,2,3,4, 13, 14, 16, 16, 5,6,7,8, 17, 18, 19, 20, 9, 10, 11, 12, 21, 22, 23, 24}; - Nd4jLong shape1[] = {2, 3, 4, 4, 1, 0, 1, 99}; - Nd4jLong shape2[] = {2, 3, 4, 4, 1, 0, 1, 99}; - Nd4jLong expShape[] = {3, 3, 2, 4, 8, 4, 1, 0, 1, 99}; - ArrayOptions::setDataType(shape1, sd::DataType::FLOAT32); - ArrayOptions::setDataType(shape2, sd::DataType::FLOAT32); - ArrayOptions::setDataType(expShape, sd::DataType::FLOAT32); - - NDArray input1(buff1, shape1); - NDArray input2(buff2, shape2); - NDArray expected(expBuff, expShape); - - sd::ops::stack op; - auto results = op.evaluate({&input1, &input2}, {}, {1}); - auto output = results.at(0); - - ASSERT_TRUE(expected.isSameShapeStrict(*output)); - ASSERT_TRUE(expected.equalsTo(output)); + NDArray input1(buff1, shape1); + NDArray input2(buff2, shape2); + NDArray expected(expBuff, expShape); + sd::ops::stack op; + auto results = op.evaluate({&input1, &input2}, {}, {1}); + auto output = results.at(0); + ASSERT_TRUE(expected.isSameShapeStrict(output)); + ASSERT_TRUE(expected.equalsTo(output)); } - ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, Stack_3) { + float buff1[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; + float buff2[] = {13, 14, 16, 16, 17, 18, 19, 20, 21, 22, 23, 24}; + float expBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 16, 16, 17, 18, 19, 20, 21, 22, 23, 24}; + Nd4jLong shape1[] = {2, 1, 12, 12, 1, 0, 1, 99}; + Nd4jLong shape2[] = {2, 1, 12, 12, 1, 0, 1, 99}; + Nd4jLong expShape[] = {3, 2, 1, 12, 12, 12, 1, 0, 1, 99}; + ArrayOptions::setDataType(shape1, sd::DataType::FLOAT32); + ArrayOptions::setDataType(shape2, sd::DataType::FLOAT32); + ArrayOptions::setDataType(expShape, sd::DataType::FLOAT32); - float buff1[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10,11,12}; - float buff2[] = {13,14,16,16,17,18,19,20,21,22,23,24}; - float expBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10,11,12,13,14,16,16,17,18,19,20,21,22,23,24}; - Nd4jLong shape1[] = {2, 1, 12, 12, 1, 0, 1, 99}; - Nd4jLong shape2[] = {2, 1, 12, 12, 1, 0, 1, 99}; - Nd4jLong expShape[] = {3, 2, 1, 12, 12, 12, 1, 0, 1, 99}; - ArrayOptions::setDataType(shape1, sd::DataType::FLOAT32); - ArrayOptions::setDataType(shape2, sd::DataType::FLOAT32); - ArrayOptions::setDataType(expShape, sd::DataType::FLOAT32); - - NDArray input1(buff1, shape1); - NDArray input2(buff2, shape2); - NDArray expected(expBuff, expShape); - - sd::ops::stack op; - auto results = op.evaluate({&input1, &input2}, {}, {0}); - auto output = results.at(0); - - ASSERT_TRUE(expected.isSameShapeStrict(*output)); - ASSERT_TRUE(expected.equalsTo(output)); + NDArray input1(buff1, shape1); + NDArray input2(buff2, shape2); + NDArray expected(expBuff, expShape); + sd::ops::stack op; + auto results = op.evaluate({&input1, &input2}, {}, {0}); + auto output = results.at(0); + ASSERT_TRUE(expected.isSameShapeStrict(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, Stack_4) { + float buff1[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; + float buff2[] = {13, 14, 16, 16, 17, 18, 19, 20, 21, 22, 23, 24}; + float expBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 16, 16, 17, 18, 19, 20, 21, 22, 23, 24}; + Nd4jLong shape1[] = {2, 1, 12, 12, 1, 0, 1, 99}; + Nd4jLong shape2[] = {2, 1, 12, 12, 1, 0, 1, 99}; + Nd4jLong expShape[] = {3, 1, 2, 12, 24, 12, 1, 0, 1, 99}; + ArrayOptions::setDataType(shape1, sd::DataType::FLOAT32); + ArrayOptions::setDataType(shape2, sd::DataType::FLOAT32); + ArrayOptions::setDataType(expShape, sd::DataType::FLOAT32); - float buff1[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10,11,12}; - float buff2[] = {13,14,16,16,17,18,19,20,21,22,23,24}; - float expBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10,11,12,13,14,16,16,17,18,19,20,21,22,23,24}; - Nd4jLong shape1[] = {2, 1, 12, 12, 1, 0, 1, 99}; - Nd4jLong shape2[] = {2, 1, 12, 12, 1, 0, 1, 99}; - Nd4jLong expShape[] = {3, 1, 2, 12, 24, 12, 1, 0, 1, 99}; - ArrayOptions::setDataType(shape1, sd::DataType::FLOAT32); - ArrayOptions::setDataType(shape2, sd::DataType::FLOAT32); - ArrayOptions::setDataType(expShape, sd::DataType::FLOAT32); - - NDArray input1(buff1, shape1); - NDArray input2(buff2, shape2); - NDArray expected(expBuff, expShape); - - sd::ops::stack op; - auto results = op.evaluate({&input1, &input2}, {}, {1}); - auto output = results.at(0); - - ASSERT_TRUE(expected.isSameShapeStrict(*output)); - ASSERT_TRUE(expected.equalsTo(output)); + NDArray input1(buff1, shape1); + NDArray input2(buff2, shape2); + NDArray expected(expBuff, expShape); + sd::ops::stack op; + auto results = op.evaluate({&input1, &input2}, {}, {1}); + auto output = results.at(0); + ASSERT_TRUE(expected.isSameShapeStrict(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, Stack_5) { + float buff1[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; + float buff2[] = {13, 14, 16, 16, 17, 18, 19, 20, 21, 22, 23, 24}; + float expBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 16, 16, 17, 18, 19, 20, 21, 22, 23, 24}; + Nd4jLong shape1[] = {2, 12, 1, 1, 1, 0, 1, 99}; + Nd4jLong shape2[] = {2, 12, 1, 1, 1, 0, 1, 99}; + Nd4jLong expShape[] = {3, 2, 12, 1, 12, 1, 1, 0, 1, 99}; + ArrayOptions::setDataType(shape1, sd::DataType::FLOAT32); + ArrayOptions::setDataType(shape2, sd::DataType::FLOAT32); + ArrayOptions::setDataType(expShape, sd::DataType::FLOAT32); - float buff1[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10,11,12}; - float buff2[] = {13,14,16,16,17,18,19,20,21,22,23,24}; - float expBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10,11,12,13,14,16,16,17,18,19,20,21,22,23,24}; - Nd4jLong shape1[] = {2, 12, 1, 1,1, 0, 1, 99}; - Nd4jLong shape2[] = {2, 12, 1, 1,1, 0, 1, 99}; - Nd4jLong expShape[] = {3, 2, 12, 1, 12, 1, 1, 0, 1, 99}; - ArrayOptions::setDataType(shape1, sd::DataType::FLOAT32); - ArrayOptions::setDataType(shape2, sd::DataType::FLOAT32); - ArrayOptions::setDataType(expShape, sd::DataType::FLOAT32); - - NDArray input1(buff1, shape1); - NDArray input2(buff2, shape2); - NDArray expected(expBuff, expShape); - - sd::ops::stack op; - auto results = op.evaluate({&input1, &input2}, {}, {0}); - auto output = results.at(0); - - ASSERT_TRUE(expected.isSameShapeStrict(*output)); - ASSERT_TRUE(expected.equalsTo(output)); + NDArray input1(buff1, shape1); + NDArray input2(buff2, shape2); + NDArray expected(expBuff, expShape); + sd::ops::stack op; + auto results = op.evaluate({&input1, &input2}, {}, {0}); + auto output = results.at(0); + ASSERT_TRUE(expected.isSameShapeStrict(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, Stack_6) { + float buff1[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; + float buff2[] = {13, 14, 16, 16, 17, 18, 19, 20, 21, 22, 23, 24}; + float expBuff[] = {1, 13, 2, 14, 3, 16, 4, 16, 5, 17, 6, 18, + 7, 19, 8, 20, 9, 21, 10, 22, 11, 23, 12, 24}; + Nd4jLong shape1[] = {2, 12, 1, 1, 12, 0, 1, 99}; + Nd4jLong shape2[] = {2, 12, 1, 1, 12, 0, 1, 99}; + Nd4jLong expShape[] = {3, 12, 2, 1, 2, 1, 1, 0, 1, 99}; + ArrayOptions::setDataType(shape1, sd::DataType::FLOAT32); + ArrayOptions::setDataType(shape2, sd::DataType::FLOAT32); + ArrayOptions::setDataType(expShape, sd::DataType::FLOAT32); - float buff1[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10,11,12}; - float buff2[] = {13,14,16,16,17,18,19,20,21,22,23,24}; - float expBuff[] = {1 ,13 ,2 ,14 ,3 ,16 ,4 ,16 ,5 ,17 ,6 ,18 ,7 ,19 ,8 ,20 ,9 ,21 ,10 ,22 ,11 ,23 ,12 ,24}; - Nd4jLong shape1[] = {2, 12, 1, 1, 12, 0, 1, 99}; - Nd4jLong shape2[] = {2, 12, 1, 1, 12, 0, 1, 99}; - Nd4jLong expShape[] = {3, 12, 2, 1, 2, 1, 1, 0, 1, 99}; - ArrayOptions::setDataType(shape1, sd::DataType::FLOAT32); - ArrayOptions::setDataType(shape2, sd::DataType::FLOAT32); - ArrayOptions::setDataType(expShape, sd::DataType::FLOAT32); - - NDArray input1(buff1, shape1); - NDArray input2(buff2, shape2); - NDArray expected(expBuff, expShape); - - sd::ops::stack op; - auto results = op.evaluate({&input1, &input2}, {}, {1}); - auto output = results.at(0); - - ASSERT_TRUE(expected.isSameShapeStrict(*output)); - ASSERT_TRUE(expected.equalsTo(output)); + NDArray input1(buff1, shape1); + NDArray input2(buff2, shape2); + NDArray expected(expBuff, expShape); + sd::ops::stack op; + auto results = op.evaluate({&input1, &input2}, {}, {1}); + auto output = results.at(0); + ASSERT_TRUE(expected.isSameShapeStrict(output)); + ASSERT_TRUE(expected.equalsTo(output)); } - ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, Stack_7) { + float buff1[] = {1}; + float expBuff[] = {1, 1, 1}; + Nd4jLong shape1[] = {2, 1, 1, 1, 1, 0, 1, 99}; + Nd4jLong expShape[] = {3, 3, 1, 1, 1, 1, 1, 0, 1, 99}; + ArrayOptions::setDataType(shape1, sd::DataType::FLOAT32); + ArrayOptions::setDataType(expShape, sd::DataType::FLOAT32); - float buff1[] = {1}; - float expBuff[] = {1, 1, 1}; - Nd4jLong shape1[] = {2, 1, 1, 1, 1, 0, 1, 99}; - Nd4jLong expShape[] = {3, 3, 1, 1, 1, 1, 1, 0, 1, 99}; - ArrayOptions::setDataType(shape1, sd::DataType::FLOAT32); - ArrayOptions::setDataType(expShape, sd::DataType::FLOAT32); - - NDArray input1(buff1, shape1); - NDArray expected(expBuff, expShape); - - sd::ops::stack op; - auto results = op.evaluate({&input1, &input1, &input1}, {}, {0}); - auto output = results.at(0); - - ASSERT_TRUE(expected.isSameShapeStrict(*output)); - ASSERT_TRUE(expected.equalsTo(output)); + NDArray input1(buff1, shape1); + NDArray expected(expBuff, expShape); + sd::ops::stack op; + auto results = op.evaluate({&input1, &input1, &input1}, {}, {0}); + auto output = results.at(0); + ASSERT_TRUE(expected.isSameShapeStrict(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, Stack_8) { + float buff1[] = {1}; + float expBuff[] = {1, 1, 1}; + Nd4jLong shape1[] = {1, 1, 1, 0, 1, 99}; + Nd4jLong expShape[] = {2, 3, 1, 1, 1, 0, 1, 99}; + ArrayOptions::setDataType(shape1, sd::DataType::FLOAT32); + ArrayOptions::setDataType(expShape, sd::DataType::FLOAT32); - float buff1[] = {1}; - float expBuff[] = {1, 1, 1}; - Nd4jLong shape1[] = {1, 1, 1, 0, 1, 99}; - Nd4jLong expShape[] = {2, 3, 1, 1, 1, 0, 1, 99}; - ArrayOptions::setDataType(shape1, sd::DataType::FLOAT32); - ArrayOptions::setDataType(expShape, sd::DataType::FLOAT32); - - NDArray input1(buff1, shape1); - NDArray expected(expBuff, expShape); - - sd::ops::stack op; - auto results = op.evaluate({&input1, &input1, &input1}, {}, {0}); - auto output = results.at(0); - - ASSERT_TRUE(expected.isSameShapeStrict(*output)); - ASSERT_TRUE(expected.equalsTo(output)); + NDArray input1(buff1, shape1); + NDArray expected(expBuff, expShape); + sd::ops::stack op; + auto results = op.evaluate({&input1, &input1, &input1}, {}, {0}); + auto output = results.at(0); + ASSERT_TRUE(expected.isSameShapeStrict(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, Stack_9) { + float buff1[] = {1}; + float expBuff[] = {1, 1, 1}; + Nd4jLong shape1[] = {2, 1, 1, 1, 1, 0, 1, 99}; + Nd4jLong expShape[] = {3, 1, 3, 1, 3, 1, 1, 0, 1, 99}; + ArrayOptions::setDataType(shape1, sd::DataType::FLOAT32); + ArrayOptions::setDataType(expShape, sd::DataType::FLOAT32); - float buff1[] = {1}; - float expBuff[] = {1, 1, 1}; - Nd4jLong shape1[] = {2, 1, 1, 1, 1, 0, 1, 99}; - Nd4jLong expShape[] = {3, 1, 3, 1, 3, 1, 1, 0, 1, 99}; - ArrayOptions::setDataType(shape1, sd::DataType::FLOAT32); - ArrayOptions::setDataType(expShape, sd::DataType::FLOAT32); - - NDArray input1(buff1, shape1); - NDArray expected(expBuff, expShape); - - sd::ops::stack op; - auto results = op.evaluate({&input1, &input1, &input1}, {}, {1}); - auto output = results.at(0); - - ASSERT_TRUE(expected.isSameShapeStrict(*output)); - ASSERT_TRUE(expected.equalsTo(output)); + NDArray input1(buff1, shape1); + NDArray expected(expBuff, expShape); + sd::ops::stack op; + auto results = op.evaluate({&input1, &input1, &input1}, {}, {1}); + auto output = results.at(0); + ASSERT_TRUE(expected.isSameShapeStrict(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, Stack_10) { + float buff1[] = {1}; + float expBuff[] = {1, 1, 1}; + Nd4jLong shape1[] = {1, 1, 1, 0, 1, 99}; + Nd4jLong expShape[] = {2, 1, 3, 3, 1, 0, 1, 99}; + ArrayOptions::setDataType(shape1, sd::DataType::FLOAT32); + ArrayOptions::setDataType(expShape, sd::DataType::FLOAT32); - float buff1[] = {1}; - float expBuff[] = {1, 1, 1}; - Nd4jLong shape1[] = {1, 1, 1, 0, 1, 99}; - Nd4jLong expShape[] = {2, 1, 3, 3, 1, 0, 1, 99}; - ArrayOptions::setDataType(shape1, sd::DataType::FLOAT32); - ArrayOptions::setDataType(expShape, sd::DataType::FLOAT32); - - NDArray input1(buff1, shape1); - NDArray expected(expBuff, expShape); - - sd::ops::stack op; - auto results = op.evaluate({&input1, &input1, &input1}, {}, {1}); - auto output = results.at(0); + NDArray input1(buff1, shape1); + NDArray expected(expBuff, expShape); - //expected.printShapeInfo("exp"); - //output->printShapeInfo("out"); - - ASSERT_TRUE(expected.isSameShapeStrict(*output)); - ASSERT_TRUE(expected.equalsTo(output)); + sd::ops::stack op; + auto results = op.evaluate({&input1, &input1, &input1}, {}, {1}); + auto output = results.at(0); + // expected.printShapeInfo("exp"); + // output->printShapeInfo("out"); + ASSERT_TRUE(expected.isSameShapeStrict(output)); + ASSERT_TRUE(expected.equalsTo(output)); } TEST_F(DeclarableOpsTests14, Stack_11) { + float buff1[] = {1}; + float expBuff[] = {1, 1, 1}; + Nd4jLong shape1[] = {1, 1, 1, 0, 1, 99}; + Nd4jLong expShape[] = {2, 3, 1, 1, 1, 0, 1, 99}; + ArrayOptions::setDataType(shape1, sd::DataType::FLOAT32); + ArrayOptions::setDataType(expShape, sd::DataType::FLOAT32); - float buff1[] = {1}; - float expBuff[] = {1, 1, 1}; - Nd4jLong shape1[] = {1, 1, 1, 0, 1, 99}; - Nd4jLong expShape[] = {2, 3, 1, 1, 1, 0, 1, 99}; - ArrayOptions::setDataType(shape1, sd::DataType::FLOAT32); - ArrayOptions::setDataType(expShape, sd::DataType::FLOAT32); - - NDArray input1(buff1, shape1); - NDArray expected(expBuff, expShape); - - sd::ops::stack op; - auto results = op.evaluate({&input1, &input1, &input1}, {}, {}); - auto output = results.at(0); - - ASSERT_TRUE(expected.isSameShapeStrict(*output)); - ASSERT_TRUE(expected.equalsTo(output)); + NDArray input1(buff1, shape1); + NDArray expected(expBuff, expShape); + sd::ops::stack op; + auto results = op.evaluate({&input1, &input1, &input1}, {}, {}); + auto output = results.at(0); + ASSERT_TRUE(expected.isSameShapeStrict(output)); + ASSERT_TRUE(expected.equalsTo(output)); } - ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, Stack_12) { - float inBuff[] = {1.0f, 2.0f, 3.0f}; - float expBuff[] = {1.0f, 2.0f, 3.0f}; + float inBuff[] = {1.0f, 2.0f, 3.0f}; + float expBuff[] = {1.0f, 2.0f, 3.0f}; - auto input = NDArrayFactory::create(inBuff, 'c', {1, 3}); + auto input = NDArrayFactory::create(inBuff, 'c', {1, 3}); - auto exp = NDArrayFactory::create(expBuff, 'c', {1, 1, 3}); + auto exp = NDArrayFactory::create(expBuff, 'c', {1, 1, 3}); - sd::ops::stack op; + sd::ops::stack op; - auto result = op.evaluate({&input}, {}, {0}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto result = op.evaluate({&input}, {}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, Stack_13) { - float inBuff[] = {1.0f, 2.0f, 3.0f}; - float expBuff[] = {1.0f, 2.0f, 3.0f}; - - auto input = NDArrayFactory::create(inBuff, 'c', {1, 1, 3}); + float inBuff[] = {1.0f, 2.0f, 3.0f}; + float expBuff[] = {1.0f, 2.0f, 3.0f}; - auto exp = NDArrayFactory::create(expBuff, 'c', {1, 1, 1, 3}); + auto input = NDArrayFactory::create(inBuff, 'c', {1, 1, 3}); - sd::ops::stack op; + auto exp = NDArrayFactory::create(expBuff, 'c', {1, 1, 1, 3}); - auto result = op.evaluate({&input}, {}, {0}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::stack op; - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto result = op.evaluate({&input}, {}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, Stack_14) { - float inBuff[] = {1.0f, 2.0f, 3.0f}; - float expBuff[] = {1.0f, 2.0f, 3.0f}; - - auto input = NDArrayFactory::create(inBuff, 'c', {1, 3}); + float inBuff[] = {1.0f, 2.0f, 3.0f}; + float expBuff[] = {1.0f, 2.0f, 3.0f}; - auto exp = NDArrayFactory::create(expBuff, 'c', {1, 1, 3}); + auto input = NDArrayFactory::create(inBuff, 'c', {1, 3}); - sd::ops::stack op; + auto exp = NDArrayFactory::create(expBuff, 'c', {1, 1, 3}); - auto result = op.evaluate({&input}, {}, {1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::stack op; - auto z = result.at(0); + auto result = op.evaluate({&input}, {}, {1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - //z->printShapeInfo(); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + // z->printShapeInfo(); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests14, Stack_15) { - auto t = NDArrayFactory::create('c', {2, 3, 5}); - auto u = NDArrayFactory::create('c', {2, 3, 5}); - auto v = NDArrayFactory::create('c', {2, 3, 5}); - auto exp = NDArrayFactory::create('c', {3, 2, 3, 5}); - - sd::ops::stack op; - auto result = op.evaluate({&t, &u, &v}, {}, {-4}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - + auto t = NDArrayFactory::create('c', {2, 3, 5}); + auto u = NDArrayFactory::create('c', {2, 3, 5}); + auto v = NDArrayFactory::create('c', {2, 3, 5}); + auto exp = NDArrayFactory::create('c', {3, 2, 3, 5}); - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); + sd::ops::stack op; + auto result = op.evaluate({&t, &u, &v}, {}, {-4}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); } - TEST_F(DeclarableOpsTests14, Stack_16) { - auto t = NDArrayFactory::create(1.0f); - auto u = NDArrayFactory::create(2.0f); - auto v = NDArrayFactory::create(3.0f); - auto exp = NDArrayFactory::create('c', {3}, {1, 2, 3}); + auto t = NDArrayFactory::create(1.0f); + auto u = NDArrayFactory::create(2.0f); + auto v = NDArrayFactory::create(3.0f); + auto exp = NDArrayFactory::create('c', {3}, {1, 2, 3}); - sd::ops::stack op; - auto result = op.evaluate({&t, &u, &v}, {}, {0}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::stack op; + auto result = op.evaluate({&t, &u, &v}, {}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests14, Stack_17) { - auto t = NDArrayFactory::create('c', {1, 1}, {1.0f}); - auto u = NDArrayFactory::create('c', {1, 1}, {2.0f}); - auto v = NDArrayFactory::create('c', {1, 1}, {3.0f}); - auto w = NDArrayFactory::create('c', {1, 1}, {4.0f}); - auto exp = NDArrayFactory::create('c', {4, 1, 1}, {1, 2, 3, 4}); - - sd::ops::stack op; - auto result = op.evaluate({&t, &u, &v, &w}, {}, {0}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto t = NDArrayFactory::create('c', {1, 1}, {1.0f}); + auto u = NDArrayFactory::create('c', {1, 1}, {2.0f}); + auto v = NDArrayFactory::create('c', {1, 1}, {3.0f}); + auto w = NDArrayFactory::create('c', {1, 1}, {4.0f}); + auto exp = NDArrayFactory::create('c', {4, 1, 1}, {1, 2, 3, 4}); - auto z = result.at(0); + sd::ops::stack op; + auto result = op.evaluate({&t, &u, &v, &w}, {}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - // z->printShapeInfo("z shape"); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + // z->printShapeInfo("z shape"); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests14, Stack_18) { - auto x = NDArrayFactory::create('c', {0}); - auto e = NDArrayFactory::create('c', {1, 0}); - - sd::ops::stack op; - auto result = op.evaluate({&x}, {}, {0}); - ASSERT_EQ(Status::OK(), result.status()); + auto x = NDArrayFactory::create('c', {0}); + auto e = NDArrayFactory::create('c', {1, 0}); - auto z = result.at(0); - ASSERT_EQ(e, *z); - sd::ops::reduce_min sumOp; - auto res2 = sumOp.evaluate({&e}, {1.}, {1}); - ASSERT_EQ(res2.status(), Status::OK()); - auto out = res2.at(0); - - ASSERT_EQ(out->e(0), DataTypeUtils::infOrMax()); + sd::ops::stack op; + auto result = op.evaluate({&x}, {}, {0}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + ASSERT_EQ(e, z); + sd::ops::reduce_min sumOp; + auto res2 = sumOp.evaluate({&e}, {1.}, {1}); + ASSERT_EQ(res2.status(), Status::OK()); + auto out = res2.at(0); + ASSERT_EQ(out.e(0), DataTypeUtils::infOrMax()); } TEST_F(DeclarableOpsTests14, Stack_19) { - auto x = NDArrayFactory::empty(); - auto e = NDArrayFactory::create('c', {0}); - - sd::ops::stack op; - auto result = op.evaluate({&x}, {}, {0}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - ASSERT_EQ(e, *z); + auto x = NDArrayFactory::empty(); + auto e = NDArrayFactory::create('c', {0}); + sd::ops::stack op; + auto result = op.evaluate({&x}, {}, {0}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + ASSERT_EQ(e, z); } TEST_F(DeclarableOpsTests14, Stack_20) { - auto x = NDArrayFactory::empty(); - auto e = NDArrayFactory::create('c', {2, 0}); - - sd::ops::stack op; - auto result = op.evaluate({&x, &x}, {}, {0}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - ASSERT_EQ(e, *z); + auto x = NDArrayFactory::empty(); + auto e = NDArrayFactory::create('c', {2, 0}); + sd::ops::stack op; + auto result = op.evaluate({&x, &x}, {}, {0}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + ASSERT_EQ(e, z); } TEST_F(DeclarableOpsTests14, Stack_21) { + NDArray x1('c', {3, 2}, sd::DataType::FLOAT32); + NDArray x2('c', {3, 2}, sd::DataType::FLOAT32); + x1.linspace(0); + x2.linspace(6); - NDArray x1('c', {3,2}, sd::DataType::FLOAT32); - NDArray x2('c', {3,2}, sd::DataType::FLOAT32); - x1.linspace(0); - x2.linspace(6); + sd::ops::stack opStack; + auto resultStack = opStack.evaluate({&x1, &x2}, {}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, resultStack.status()); - sd::ops::stack opStack; - auto resultStack = opStack.evaluate({&x1, &x2}, {}, {0}); - ASSERT_EQ(ND4J_STATUS_OK, resultStack.status()); + sd::ops::concat opConcat; + auto resultConcat = opConcat.evaluate({&x1, &x2}, {}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, resultConcat.status()); + auto outStack = resultStack.at(0); + auto outConcat = resultConcat.at(0); - sd::ops::concat opConcat; - auto resultConcat = opConcat.evaluate({&x1, &x2}, {}, {0}); - ASSERT_EQ(ND4J_STATUS_OK, resultConcat.status()); + outConcat.reshapei({2, 3, 2}); - auto outStack = resultStack.at(0); - auto outConcat = resultConcat.at(0); - - outConcat->reshapei({2,3,2}); - - ASSERT_TRUE(outStack->isSameShape(outConcat)); - ASSERT_TRUE(outStack->equalsTo(outConcat)); + ASSERT_TRUE(outStack.isSameShape(outConcat)); + ASSERT_TRUE(outStack.equalsTo(outConcat)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, Reshape1) { - const std::vector xShape = { 5,4,3 }; - const std::vector yShape = { 3,5,4 }; + const std::vector xShape = {5, 4, 3}; + const std::vector yShape = {3, 5, 4}; - auto x = NDArrayFactory::create_('f', xShape); - auto y = NDArrayFactory::create_('f', yShape); + auto x = NDArrayFactory::create('f', xShape); + auto y = NDArrayFactory::create('f', yShape); + VariableSpace variableSpace; + variableSpace.putVariable(-1, x); + variableSpace.putVariable(-2, y); - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, x); - variableSpace->putVariable(-2, y); - auto block = new Context(1, variableSpace, true); - block->fillInputs({ -1, -2 }); + Context block(1, &variableSpace, false); + block.fillInputs({-1, -2}); - sd::ops::reshapeas reshape; + sd::ops::reshapeas reshape; - reshape.execute(block); + reshape.execute(&block); - ASSERT_TRUE(x->isSameShape(y)); + ASSERT_TRUE(variableSpace.hasVariable(1)); + auto z = variableSpace.getVariable(1)->getNDArray().get(); - delete variableSpace; - delete block; + ASSERT_TRUE(y.isSameShape(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, Reshape2) { - const std::vector xShape = { 5,4,3 }; - const std::vector yShape = { 3,5,4 }; + const std::vector xShape = {5, 4, 3}; + const std::vector yShape = {3, 5, 4}; - auto x = NDArrayFactory::create_('c', xShape); - auto y = NDArrayFactory::create_('c', yShape); + auto x = NDArrayFactory::create('c', xShape); + auto y = NDArrayFactory::create('c', yShape); - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, x); - variableSpace->putVariable(1, new Variable()); + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + variableSpace->putVariable(1, std::make_shared()); - auto block = new Context(1, variableSpace, false); - block->fillInputs({ -1 }); - std::vector* arguments = block->getIArguments(); - arguments->push_back(-y->ordering()); - arguments->push_back(3); - arguments->push_back(5); - arguments->push_back(4); + auto block = new Context(1, variableSpace, false); + block->fillInputs({-1}); + block->appendI(-y.ordering()); + block->appendI(3); + block->appendI(5); + block->appendI(4); - sd::ops::reshape reshape; + sd::ops::reshape reshape; - Nd4jStatus status = reshape.execute(block); - ASSERT_EQ(ND4J_STATUS_OK, status); - auto result = variableSpace->getVariable(block->getNodeId())->getNDArray(); + Nd4jStatus status = reshape.execute(block); + ASSERT_EQ(ND4J_STATUS_OK, status); + auto result = variableSpace->getVariable(block->getNodeId())->getNDArray(); - ASSERT_TRUE(result->isSameShape(y)); + ASSERT_TRUE(result->isSameShape(y)); - delete y; - delete block; - delete variableSpace; + delete block; + delete variableSpace; } TEST_F(DeclarableOpsTests14, Reshape3) { - auto x = NDArrayFactory::create('c', { 3, 4, 5 }); + auto x = NDArrayFactory::create('c', {3, 4, 5}); - sd::ops::reshape op; - auto result = op.evaluate({ &x }, {}, { -99, 3, 4, 5 }); + sd::ops::reshape op; + auto result = op.evaluate({&x}, {}, {-99, 3, 4, 5}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_TRUE(x.isSameShape(z)); + ASSERT_TRUE(x.isSameShape(z)); } TEST_F(DeclarableOpsTests14, Reshape4) { - auto x = NDArrayFactory::create('c', { 3, 4, 5 }); + auto x = NDArrayFactory::create('c', {3, 4, 5}); - sd::ops::reshape op; - auto result = op.evaluate({ &x }, {}, { 3, 4, 5 }); + sd::ops::reshape op; + auto result = op.evaluate({&x}, {}, {3, 4, 5}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_TRUE(x.isSameShape(z)); + ASSERT_TRUE(x.isSameShape(z)); } TEST_F(DeclarableOpsTests14, Reshape5) { - auto x = NDArrayFactory::create('c', { 3, 4, 5 }); + auto x = NDArrayFactory::create('c', {3, 4, 5}); - sd::ops::reshape op; - auto result = op.evaluate({ &x }, {}, { 5, 4, 3 }); + sd::ops::reshape op; + auto result = op.evaluate({&x}, {}, {5, 4, 3}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); } TEST_F(DeclarableOpsTests14, Reshape6) { - auto x = NDArrayFactory::create('c', { 3, 4, 5 }); - auto exp = NDArrayFactory::create('c', { 4, 15 }); + auto x = NDArrayFactory::create('c', {3, 4, 5}); + auto exp = NDArrayFactory::create('c', {4, 15}); - sd::ops::reshape op; - auto result = op.evaluate({ &x }, {}, { 4, -1 }); + sd::ops::reshape op; + auto result = op.evaluate({&x}, {}, {4, -1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_TRUE(z->isSameShape(exp)); + ASSERT_TRUE(z.isSameShape(exp)); } TEST_F(DeclarableOpsTests14, Reshape7) { - auto x = NDArrayFactory::create('c', { 3, 4, 5 }); - auto exp = NDArrayFactory::create('c', { 60 }); + auto x = NDArrayFactory::create('c', {3, 4, 5}); + auto exp = NDArrayFactory::create('c', {60}); - sd::ops::reshape op; - auto result = op.evaluate({ &x }, {}, { -1 }); + sd::ops::reshape op; + auto result = op.evaluate({&x}, {}, {-1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_TRUE(z->isSameShape(exp)); + ASSERT_TRUE(z.isSameShape(exp)); } TEST_F(DeclarableOpsTests14, Reshape8) { - auto x = NDArrayFactory::create('f', {2, 3}, {1.0, 4.0, 2.0, 5.0, 3.0, 6.0}); - auto e = NDArrayFactory::create('f', {3, 2}, {1.0, 3.0, 5.0, 2.0, 4.0, 6.0}); + auto x = NDArrayFactory::create('f', {2, 3}, + {1.0, 4.0, 2.0, 5.0, 3.0, 6.0}); + auto e = NDArrayFactory::create('f', {3, 2}, + {1.0, 3.0, 5.0, 2.0, 4.0, 6.0}); - auto r = x.reshape('c', {3, 2});; - r.streamline('f'); + auto r = x.reshape('c', {3, 2}); + ; + r.streamline('f'); - sd::ops::reshape op; - auto result = op.evaluate({&x}, {3, 2}); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::reshape op; + auto result = op.evaluate({&x}, {3, 2}); + ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); + auto z = result.at(0); } TEST_F(DeclarableOpsTests14, Reshape9) { - auto array = NDArrayFactory::create(119.f); - auto e = NDArrayFactory::create('c', {1, 1}, {119.f}); + auto array = NDArrayFactory::create(119.f); + auto e = NDArrayFactory::create('c', {1, 1}, {119.f}); - sd::ops::reshape op; - auto result = op.evaluate({&array}, {}, {1, 1}); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::reshape op; + auto result = op.evaluate({&array}, {}, {1, 1}); + ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_EQ(e, *z); + ASSERT_EQ(e, z); } TEST_F(DeclarableOpsTests14, Reshape10) { - auto array = NDArrayFactory::create(119.f); - auto e = NDArrayFactory::create('c', {1, 1}, {119.f}); - auto z = NDArrayFactory::create('c', {1, 1}); + auto array = NDArrayFactory::create(119.f); + auto e = NDArrayFactory::create('c', {1, 1}, {119.f}); + auto z = NDArrayFactory::create('c', {1, 1}); - sd::ops::reshape op; - auto result = op.execute({&array}, {&z}, {}, {1, 1}, {}); - ASSERT_EQ(Status::OK(), result); - ASSERT_EQ(e, z); + sd::ops::reshape op; + auto result = op.execute({&array}, {&z}, {}, {1, 1}, {}); + ASSERT_EQ(Status::OK(), result); + ASSERT_EQ(e, z); } TEST_F(DeclarableOpsTests14, Reshape11) { - auto x = NDArrayFactory::create('c', {4, 3}); - auto exp = NDArrayFactory::create('c', {4, 3}); + auto x = NDArrayFactory::create('c', {4, 3}); + auto exp = NDArrayFactory::create('c', {4, 3}); - x.linspace(1); - exp.linspace(1); + x.linspace(1); + exp.linspace(1); - sd::ops::reshape op; - auto result = op.evaluate({&x}, {-99, 4, 3}); + sd::ops::reshape op; + auto result = op.evaluate({&x}, {-99, 4, 3}); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests14, Reshape12) { - auto x = NDArrayFactory::create('c', {2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8}); - auto shape = NDArrayFactory::create('c', {2}, {-1, 2}); - auto exp = NDArrayFactory::create('c', {4, 2}, {1, 2, 3, 4, 5, 6, 7, 8}); + auto x = + NDArrayFactory::create('c', {2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8}); + auto shape = NDArrayFactory::create('c', {2}, {-1, 2}); + auto exp = + NDArrayFactory::create('c', {4, 2}, {1, 2, 3, 4, 5, 6, 7, 8}); - sd::ops::reshape op; - auto result = op.evaluate({&x, &shape}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::reshape op; + auto result = op.evaluate({&x, &shape}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests14, Reshape13) { - auto vector = NDArrayFactory::create('c', {1}, {119.0f}); - auto exp = NDArrayFactory::create(119.f); - auto empty = NDArrayFactory::empty_(); + auto vector = NDArrayFactory::create('c', {1}, {119.0f}); + auto exp = NDArrayFactory::create(119.f); + auto empty = NDArrayFactory::empty_(); - sd::ops::reshape op; - auto result = op.evaluate({&vector, empty}, {}, {}); + sd::ops::reshape op; + auto result = op.evaluate({&vector, empty}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_EQ(exp, *result.at(0)); + ASSERT_EQ(exp, result.at(0)); - delete empty; + delete empty; } TEST_F(DeclarableOpsTests14, Reshape14) { - auto x = NDArrayFactory::create('c', {1, 0, 0, 2}); - auto y = NDArrayFactory::create('c', {2}, {10, 0}); - auto e = NDArrayFactory::create('c', {10, 0}); - - sd::ops::reshape op; - auto result = op.evaluate({&x, &y}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); + auto x = NDArrayFactory::create('c', {1, 0, 0, 2}); + auto y = NDArrayFactory::create('c', {2}, {10, 0}); + auto e = NDArrayFactory::create('c', {10, 0}); - auto z = result.at(0); + sd::ops::reshape op; + auto result = op.evaluate({&x, &y}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(e.isSameShape(z)); - ASSERT_EQ(e, *z); + auto z = result.at(0); + ASSERT_TRUE(e.isSameShape(z)); + ASSERT_EQ(e, z); } TEST_F(DeclarableOpsTests14, Reshape15) { + auto x0 = NDArrayFactory::create('c', {2, 0}); + auto x1 = NDArrayFactory::create('c', {0, 1, 2}); - auto x0 = NDArrayFactory::create('c', {2, 0}); - auto x1 = NDArrayFactory::create('c', {0, 1, 2}); + auto shape0 = NDArrayFactory::create('c', {3}, {2, 0, -1}); + auto shape1 = NDArrayFactory::create('c', {2}, {-1, 1}); - auto shape0 = NDArrayFactory::create('c', {3}, {2, 0, -1}); - auto shape1 = NDArrayFactory::create('c', {2}, {-1, 1}); + auto e0 = NDArrayFactory::create('c', {2, 0, 1}); + auto e1 = NDArrayFactory::create('c', {0, 1}); - auto e0 = NDArrayFactory::create('c', {2, 0, 1}); - auto e1 = NDArrayFactory::create('c', {0, 1}); + sd::ops::reshape op; + auto result0 = op.evaluate({&x0, &shape0}, {}, {}); + ASSERT_EQ(Status::OK(), result0.status()); + auto z0 = result0.at(0); + ASSERT_EQ(e0, z0); - sd::ops::reshape op; - auto result0 = op.evaluate({&x0, &shape0}, {}, {}); - ASSERT_EQ(Status::OK(), result0.status()); - auto z0 = result0.at(0); - ASSERT_EQ(e0, *z0); - - auto result1 = op.evaluate({&x1, &shape1}, {}, {}); - ASSERT_EQ(Status::OK(), result1.status()); - auto z1 = result1.at(0); - ASSERT_EQ(e1, *z1); + auto result1 = op.evaluate({&x1, &shape1}, {}, {}); + ASSERT_EQ(Status::OK(), result1.status()); + auto z1 = result1.at(0); + ASSERT_EQ(e1, z1); } TEST_F(DeclarableOpsTests14, Reshape16) { - auto x = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); - auto shape = NDArrayFactory::create('c', {1, 3}, {1, 2, 2}); + auto x = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); + auto shape = NDArrayFactory::create('c', {1, 3}, {1, 2, 2}); - auto exp = NDArrayFactory::create('c', {1, 2, 2}, {1, 2, 3, 4}); + auto exp = NDArrayFactory::create('c', {1, 2, 2}, {1, 2, 3, 4}); - sd::ops::reshape op; + sd::ops::reshape op; - auto result = op.evaluate({&x, &shape}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto result = op.evaluate({&x, &shape}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests14, Reshape17) { - auto x = NDArrayFactory::create(2.0f); - auto exp = NDArrayFactory::create('c', {1, 1, 1}, {2.0f}); + auto x = NDArrayFactory::create(2.0f); + auto exp = NDArrayFactory::create('c', {1, 1, 1}, {2.0f}); - sd::ops::reshape op; - auto result = op.evaluate({&x}, {}, {-99, 1, 1, 1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::reshape op; + auto result = op.evaluate({&x}, {}, {-99, 1, 1, 1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests14, Reshape18) { - auto x = NDArrayFactory::create('c', {1, 3}, {1, 2, 3}); - auto exp = NDArrayFactory::create('c', {3}, {1, 2, 3}); - - sd::ops::reshape op; - auto result = op.evaluate({&x}, {}, {-99, 3}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto x = NDArrayFactory::create('c', {1, 3}, {1, 2, 3}); + auto exp = NDArrayFactory::create('c', {3}, {1, 2, 3}); - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::reshape op; + auto result = op.evaluate({&x}, {}, {-99, 3}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests14, Reshape19) { - auto x = NDArrayFactory::create('c', {3}, {1, 2, 3}); - auto exp = NDArrayFactory::create('c', {1, 3}, {1, 2, 3}); + auto x = NDArrayFactory::create('c', {3}, {1, 2, 3}); + auto exp = NDArrayFactory::create('c', {1, 3}, {1, 2, 3}); - sd::ops::reshape op; - auto result = op.evaluate({&x}, {}, {-99, 1, 3}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::reshape op; + auto result = op.evaluate({&x}, {}, {-99, 1, 3}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(DeclarableOpsTests14, Reshape20) { - - NDArray x1('c', {2,0}, sd::DataType::FLOAT32); - NDArray x2('c', {10,0}, sd::DataType::FLOAT32); - NDArray x3('c', {2,0,0,10}, sd::DataType::FLOAT32); - NDArray x4('c', {0,0,10}, sd::DataType::FLOAT32); - NDArray x5('c', {0,2,10}, sd::DataType::FLOAT32); - NDArray x6('c', {0,10,0}, sd::DataType::FLOAT32); - NDArray x7('c', {0,1,2}, sd::DataType::FLOAT32); - NDArray x8('c', {1,2,0}, sd::DataType::FLOAT32); - - sd::ops::reshape op; - - auto result = op.evaluate({&x1}, {}, {2, -1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(result.at(0)->isSameShape({2,0})); - - result = op.evaluate({&x2}, {}, {2, 0, -1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(result.at(0)->isSameShape({2,0,5})); - - result = op.evaluate({&x2}, {}, {5, 2, -1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(result.at(0)->isSameShape({5,2,0})); - - result = op.evaluate({&x2}, {}, {-1, 2, 0}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(result.at(0)->isSameShape({5,2,0})); - - result = op.evaluate({&x3}, {}, {2, 0, -1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(result.at(0)->isSameShape({2,0,10})); - - result = op.evaluate({&x4}, {}, {2, -1, 0}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(result.at(0)->isSameShape({2,5,0})); - - result = op.evaluate({&x5}, {}, {2, 0, 0, 0, -1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(result.at(0)->isSameShape({2,0,0,0,10})); - - result = op.evaluate({&x6}, {}, {-1, 2, 0}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(result.at(0)->isSameShape({5, 2, 0})); - - result = op.evaluate({&x7}, {}, {-1, 0}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(result.at(0)->isSameShape({2, 0})); - - result = op.evaluate({&x7}, {}, {10,0,50,100}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(result.at(0)->isSameShape({10,0,50,100})); - - result = op.evaluate({&x7}, {}, {2,0,-1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(result.at(0)->isSameShape({2,0,1})); + NDArray x1('c', {2, 0}, sd::DataType::FLOAT32); + NDArray x2('c', {10, 0}, sd::DataType::FLOAT32); + NDArray x3('c', {2, 0, 0, 10}, sd::DataType::FLOAT32); + NDArray x4('c', {0, 0, 10}, sd::DataType::FLOAT32); + NDArray x5('c', {0, 2, 10}, sd::DataType::FLOAT32); + NDArray x6('c', {0, 10, 0}, sd::DataType::FLOAT32); + NDArray x7('c', {0, 1, 2}, sd::DataType::FLOAT32); + NDArray x8('c', {1, 2, 0}, sd::DataType::FLOAT32); + + sd::ops::reshape op; + + auto result = op.evaluate({&x1}, {}, {2, -1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_TRUE(result.at(0).isSameShape({2, 0})); + + result = op.evaluate({&x2}, {}, {2, 0, -1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_TRUE(result.at(0).isSameShape({2, 0, 5})); + + result = op.evaluate({&x2}, {}, {5, 2, -1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_TRUE(result.at(0).isSameShape({5, 2, 0})); + + result = op.evaluate({&x2}, {}, {-1, 2, 0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_TRUE(result.at(0).isSameShape({5, 2, 0})); + + result = op.evaluate({&x3}, {}, {2, 0, -1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_TRUE(result.at(0).isSameShape({2, 0, 10})); + + result = op.evaluate({&x4}, {}, {2, -1, 0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_TRUE(result.at(0).isSameShape({2, 5, 0})); + + result = op.evaluate({&x5}, {}, {2, 0, 0, 0, -1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_TRUE(result.at(0).isSameShape({2, 0, 0, 0, 10})); + + result = op.evaluate({&x6}, {}, {-1, 2, 0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_TRUE(result.at(0).isSameShape({5, 2, 0})); + + result = op.evaluate({&x7}, {}, {-1, 0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_TRUE(result.at(0).isSameShape({2, 0})); + + result = op.evaluate({&x7}, {}, {10, 0, 50, 100}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_TRUE(result.at(0).isSameShape({10, 0, 50, 100})); + + result = op.evaluate({&x7}, {}, {2, 0, -1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_TRUE(result.at(0).isSameShape({2, 0, 1})); } diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp index e01900e87311..2f446570309c 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp @@ -14,184 +14,186 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - // // Created by raver on 8/4/2018. // -#include "testlayers.h" -#include #include -#include #include +#include +#include + #include +#include "testlayers.h" using namespace sd; - class DeclarableOpsTests15 : public testing::Test { -public: - - DeclarableOpsTests15() { - printf("\n"); - fflush(stdout); - } + public: + DeclarableOpsTests15() { + printf("\n"); + fflush(stdout); + } }; TEST_F(DeclarableOpsTests15, Test_NormalizeMoments_1) { - auto d = NDArrayFactory::create('c', {10, 10}); - auto w = NDArrayFactory::create(10); - auto x = NDArrayFactory::create('c', {10}); - auto y = NDArrayFactory::create('c', {10}); + auto d = NDArrayFactory::create('c', {10, 10}); + auto w = NDArrayFactory::create(10); + auto x = NDArrayFactory::create('c', {10}); + auto y = NDArrayFactory::create('c', {10}); - auto z0 = NDArrayFactory::create('c', {10}); - auto z1 = NDArrayFactory::create('c', {10}); + auto z0 = NDArrayFactory::create('c', {10}); + auto z1 = NDArrayFactory::create('c', {10}); - sd::ops::normalize_moments op; - auto result = op.execute({&w, &x, &y}, std::vector{&z0, &z1}, {1e-4}, {}, {}); - ASSERT_EQ(Status::OK(), result); + sd::ops::normalize_moments op; + auto result = + op.execute({&w, &x, &y}, std::vector{&z0, &z1}, {1e-4}, {}, {}); + ASSERT_EQ(Status::OK(), result); } TEST_F(DeclarableOpsTests15, Test_Add_1) { - auto x = NDArrayFactory::create('c', {5}, {1, 1, 1, 1, 1}); - auto y = NDArrayFactory::create('c', {5}, {1, 1, 1, 1, 1}); - auto e = NDArrayFactory::create('c', {5}, {2, 2, 2, 2, 2}); + auto x = NDArrayFactory::create('c', {5}, {1, 1, 1, 1, 1}); + auto y = NDArrayFactory::create('c', {5}, {1, 1, 1, 1, 1}); + auto e = NDArrayFactory::create('c', {5}, {2, 2, 2, 2, 2}); - sd::ops::add op; - auto result = op.execute({&x, &y}, {&x}, {}, {}, {}); - ASSERT_EQ(Status::OK(), result); - ASSERT_EQ(e, x); + sd::ops::add op; + auto result = op.execute({&x, &y}, {&x}, {}, {}, {}); + ASSERT_EQ(Status::OK(), result); + ASSERT_EQ(e, x); } TEST_F(DeclarableOpsTests15, Test_Half_assign_1) { - auto x = NDArrayFactory::create('c', {2, 5}); - int y = 1; - x.assign(y); + auto x = NDArrayFactory::create('c', {2, 5}); + int y = 1; + x.assign(y); - ASSERT_EQ(10, x.sumNumber().e(0)); + ASSERT_EQ(10, x.sumNumber().e(0)); } TEST_F(DeclarableOpsTests15, Test_standarize_1) { - auto x = NDArrayFactory::create('c', {5}, {1.f, 1.f, 1.f, 1.f, 1.f}); - auto e = NDArrayFactory::create('c', {5}, {0.f, 0.f, 0.f, 0.f, 0.f}); + auto x = NDArrayFactory::create('c', {5}, {1.f, 1.f, 1.f, 1.f, 1.f}); + auto e = NDArrayFactory::create('c', {5}, {0.f, 0.f, 0.f, 0.f, 0.f}); - sd::ops::standardize op; - auto result = op.execute({&x}, {&x}, {}, {0}, {}); - ASSERT_EQ(Status::OK(), result); - ASSERT_EQ(e, x); + sd::ops::standardize op; + auto result = op.execute({&x}, {&x}, {}, {0}, {}); + ASSERT_EQ(Status::OK(), result); + ASSERT_EQ(e, x); } TEST_F(DeclarableOpsTests15, Test_standarize_bp_1) { - auto x = NDArrayFactory::create('c', {5}, {1.f, 1.f, 1.f, 1.f, 1.f}); - auto eps = NDArrayFactory::create('c', {5}, {0.f, 0.f, 0.f, 0.f, 0.f}); - - sd::ops::standardize_bp op; - auto result = op.evaluate({&x, &eps}, {0}); - ASSERT_EQ(Status::OK(), result.status()); + auto x = NDArrayFactory::create('c', {5}, {1.f, 1.f, 1.f, 1.f, 1.f}); + auto eps = NDArrayFactory::create('c', {5}, {0.f, 0.f, 0.f, 0.f, 0.f}); + sd::ops::standardize_bp op; + auto result = op.evaluate({&x, &eps}, {0}); + ASSERT_EQ(Status::OK(), result.status()); } TEST_F(DeclarableOpsTests15, Test_AdjustContrast_1) { - auto x = NDArrayFactory::create('c', {4,4,3}); - NDArray factor = NDArrayFactory::create(2.); - auto e = NDArrayFactory::create('c', {4,4,3}, {-21.5, -20.5, -19.5, -15.5, -14.5, -13.5, -9.5, -8.5, -7.5, -3.5, -2.5, -1.5, - 2.5, 3.5, 4.5, 8.5, 9.5, 10.5, 14.5, 15.5, 16.5, 20.5, 21.5, 22.5, - 26.5, 27.5, 28.5, 32.5, 33.5, 34.5, 38.5, 39.5, 40.5, 44.5, 45.5, 46.5, - 50.5, 51.5, 52.5, 56.5, 57.5, 58.5, 62.5, 63.5, 64.5, 68.5, 69.5, 70.5}); + auto x = NDArrayFactory::create('c', {4, 4, 3}); + NDArray factor = NDArrayFactory::create(2.); + auto e = NDArrayFactory::create( + 'c', {4, 4, 3}, + {-21.5, -20.5, -19.5, -15.5, -14.5, -13.5, -9.5, -8.5, -7.5, -3.5, + -2.5, -1.5, 2.5, 3.5, 4.5, 8.5, 9.5, 10.5, 14.5, 15.5, + 16.5, 20.5, 21.5, 22.5, 26.5, 27.5, 28.5, 32.5, 33.5, 34.5, + 38.5, 39.5, 40.5, 44.5, 45.5, 46.5, 50.5, 51.5, 52.5, 56.5, + 57.5, 58.5, 62.5, 63.5, 64.5, 68.5, 69.5, 70.5}); + x.linspace(1.); + sd::ops::adjust_contrast op; + auto result = op.evaluate({&x, &factor}, {}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + auto out = result.at(0); - x.linspace(1.); - sd::ops::adjust_contrast op; - auto result = op.evaluate({&x, &factor}, {}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); - auto out = result.at(0); - - ASSERT_TRUE(e.equalsTo(out)); - + ASSERT_TRUE(e.equalsTo(out)); } TEST_F(DeclarableOpsTests15, Test_AdjustContrast_2) { - auto x = NDArrayFactory::create('c', {1, 4,4,3}); - auto e = NDArrayFactory::create('c', {1, 4,4,3}, { - -21.5f, -20.5f, -19.5f, -15.5f, -14.5f, -13.5f, -9.5f, -8.5f, -7.5f, -3.5f, -2.5f, -1.5f, - 2.5f, 3.5f, 4.5f, 8.5f, 9.5f, 10.5f, 14.5f, 15.5f, 16.5f, 20.5f, 21.5f, 22.5f, - 26.5f, 27.5f, 28.5f, 32.5f, 33.5f, 34.5f, 38.5f, 39.5f, 40.5f, 44.5f, 45.5f, 46.5f, - 50.5f, 51.5f, 52.5f, 56.5f, 57.5f, 58.5f, 62.5f, 63.5f, 64.5f, 68.5f, 69.5f, 70.5f - }); - x.linspace(1.); - sd::ops::adjust_contrast op; - auto result = op.evaluate({&x}, {2.}); - ASSERT_EQ(Status::OK(), result.status()); - auto out = result.at(0); -// out->printIndexedBuffer("Adjusted Constrast"); - ASSERT_TRUE(e.equalsTo(out)); - + auto x = NDArrayFactory::create('c', {1, 4, 4, 3}); + auto e = NDArrayFactory::create( + 'c', {1, 4, 4, 3}, + {-21.5f, -20.5f, -19.5f, -15.5f, -14.5f, -13.5f, -9.5f, -8.5f, + -7.5f, -3.5f, -2.5f, -1.5f, 2.5f, 3.5f, 4.5f, 8.5f, + 9.5f, 10.5f, 14.5f, 15.5f, 16.5f, 20.5f, 21.5f, 22.5f, + 26.5f, 27.5f, 28.5f, 32.5f, 33.5f, 34.5f, 38.5f, 39.5f, + 40.5f, 44.5f, 45.5f, 46.5f, 50.5f, 51.5f, 52.5f, 56.5f, + 57.5f, 58.5f, 62.5f, 63.5f, 64.5f, 68.5f, 69.5f, 70.5f}); + x.linspace(1.); + sd::ops::adjust_contrast op; + auto result = op.evaluate({&x}, {2.}); + ASSERT_EQ(Status::OK(), result.status()); + auto out = result.at(0); + // out->printIndexedBuffer("Adjusted Constrast"); + ASSERT_TRUE(e.equalsTo(out)); } TEST_F(DeclarableOpsTests15, Test_AdjustContrast_3) { - auto x = NDArrayFactory::create('c', {1, 4,4,3}); - auto e = NDArrayFactory::create('c', {1, 4,4,3}, { - -21.5f, -20.5f, -19.5f, -15.5f, -14.5f, -13.5f, -9.5f, -8.5f, -7.5f, -3.5f, -2.5f, -1.5f, - 2.5f, 3.5f, 4.5f, 8.5f, 9.5f, 10.5f, 14.5f, 15.5f, 16.5f, 20.5f, 21.5f, 22.5f, - 26.5f, 27.5f, 28.5f, 32.5f, 33.5f, 34.5f, 38.5f, 39.5f, 40.5f, 44.5f, 45.5f, 46.5f, - 50.5f, 51.5f, 52.5f, 56.5f, 57.5f, 58.5f, 62.5f, 63.5f, 64.5f, 68.5f, 69.5f, 70.5f - }); - x.linspace(1.); - sd::ops::adjust_contrast_v2 op; - auto result = op.evaluate({&x}, {2.}); - ASSERT_EQ(Status::OK(), result.status()); - auto out = result.at(0); -// out->printIndexedBuffer("Adjusted Constrast"); - ASSERT_TRUE(e.equalsTo(out)); - + auto x = NDArrayFactory::create('c', {1, 4, 4, 3}); + auto e = NDArrayFactory::create( + 'c', {1, 4, 4, 3}, + {-21.5f, -20.5f, -19.5f, -15.5f, -14.5f, -13.5f, -9.5f, -8.5f, + -7.5f, -3.5f, -2.5f, -1.5f, 2.5f, 3.5f, 4.5f, 8.5f, + 9.5f, 10.5f, 14.5f, 15.5f, 16.5f, 20.5f, 21.5f, 22.5f, + 26.5f, 27.5f, 28.5f, 32.5f, 33.5f, 34.5f, 38.5f, 39.5f, + 40.5f, 44.5f, 45.5f, 46.5f, 50.5f, 51.5f, 52.5f, 56.5f, + 57.5f, 58.5f, 62.5f, 63.5f, 64.5f, 68.5f, 69.5f, 70.5f}); + x.linspace(1.); + sd::ops::adjust_contrast_v2 op; + auto result = op.evaluate({&x}, {2.}); + ASSERT_EQ(Status::OK(), result.status()); + auto out = result.at(0); + // out->printIndexedBuffer("Adjusted Constrast"); + ASSERT_TRUE(e.equalsTo(out)); } TEST_F(DeclarableOpsTests15, Test_AdjustContrast_4) { - auto x = NDArrayFactory::create('c', {4, 4, 3}); - auto e = NDArrayFactory::create('c', {4, 4, 3}, { - -21.5, -20.5, -19.5, -15.5, -14.5, -13.5, -9.5, -8.5, -7.5, -3.5, -2.5, -1.5, - 2.5, 3.5, 4.5, 8.5, 9.5, 10.5, 14.5, 15.5, 16.5, 20.5, 21.5, 22.5, - 26.5, 27.5, 28.5, 32.5, 33.5, 34.5, 38.5, 39.5, 40.5, 44.5, 45.5, 46.5, - 50.5, 51.5, 52.5, 56.5, 57.5, 58.5, 62.5, 63.5, 64.5, 68.5, 69.5, 70.5 - }); - x.linspace(1.); - sd::ops::adjust_contrast_v2 op; - auto result = op.evaluate({&x}, {2.}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); - auto out = result.at(0); -// out->printIndexedBuffer("Adjusted Constrast"); - ASSERT_TRUE(e.equalsTo(out)); - + auto x = NDArrayFactory::create('c', {4, 4, 3}); + auto e = NDArrayFactory::create( + 'c', {4, 4, 3}, + {-21.5, -20.5, -19.5, -15.5, -14.5, -13.5, -9.5, -8.5, -7.5, -3.5, + -2.5, -1.5, 2.5, 3.5, 4.5, 8.5, 9.5, 10.5, 14.5, 15.5, + 16.5, 20.5, 21.5, 22.5, 26.5, 27.5, 28.5, 32.5, 33.5, 34.5, + 38.5, 39.5, 40.5, 44.5, 45.5, 46.5, 50.5, 51.5, 52.5, 56.5, + 57.5, 58.5, 62.5, 63.5, 64.5, 68.5, 69.5, 70.5}); + x.linspace(1.); + sd::ops::adjust_contrast_v2 op; + auto result = op.evaluate({&x}, {2.}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + auto out = result.at(0); + // out->printIndexedBuffer("Adjusted Constrast"); + ASSERT_TRUE(e.equalsTo(out)); } TEST_F(DeclarableOpsTests15, Test_AdjustContrast_5) { - auto x = NDArrayFactory::create('c', {1, 3, 4}); - auto e = NDArrayFactory::create('c', {1, 3, 4}, { - -3., -2., -1., 0., 5., 6., 7., 8., 13., 14., 15., 16. - }); - x.linspace(1.); - sd::ops::adjust_contrast_v2 op; - auto result = op.evaluate({&x}, {2.}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); - auto out = result.at(0); -// out->printIndexedBuffer("Adjusted Constrast"); - ASSERT_TRUE(e.equalsTo(out)); - + auto x = NDArrayFactory::create('c', {1, 3, 4}); + auto e = NDArrayFactory::create( + 'c', {1, 3, 4}, {-3., -2., -1., 0., 5., 6., 7., 8., 13., 14., 15., 16.}); + x.linspace(1.); + sd::ops::adjust_contrast_v2 op; + auto result = op.evaluate({&x}, {2.}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + auto out = result.at(0); + // out->printIndexedBuffer("Adjusted Constrast"); + ASSERT_TRUE(e.equalsTo(out)); } /* * public void testAdjustContrast1() { - INDArray in = Nd4j.createFromArray(new float[]{0.7788f,0.8012f,0.7244f,0.2309f,0.7271f,0.1804f, + INDArray in = Nd4j.createFromArray(new + float[]{0.7788f,0.8012f,0.7244f,0.2309f,0.7271f,0.1804f, 0.5056f,0.8925f,0.5461f,0.9234f,0.0856f,0.7938f,0.6591f,0.5555f,0.1596f,0.3087f,0.1548f,0.4695f, 0.9939f,0.6113f,0.6765f,0.1800f,0.6750f,0.2246f,0.0509f,0.4601f,0.8284f,0.2354f,0.9752f,0.8361f, 0.2585f,0.4189f,0.7028f,0.7679f,0.5373f,0.7234f,0.2690f,0.0062f,0.0327f,0.0644f,0.8428f,0.7494f, 0.0755f,0.6245f,0.3491f,0.5793f,0.5730f,0.1822f,0.6420f,0.9143f,0.3019f, 0.3574f,0.1704f,0.8395f,0.5468f,0.0744f,0.9011f,0.6574f,0.4124f,0.2445f,0.4248f,0.5219f, 0.6952f,0.4900f,0.2158f,0.9549f,0.1386f,0.1544f,0.5365f,0.0134f,0.4163f,0.1456f,0.4109f, - 0.2484f, 0.3330f,0.2974f,0.6636f,0.3808f,0.8664f, 0.1896f, 0.7530f, 0.7215f, 0.6612f, 0.7270f, + 0.2484f, 0.3330f,0.2974f,0.6636f,0.3808f,0.8664f, 0.1896f, + 0.7530f, 0.7215f, 0.6612f, 0.7270f, 0.5704f,0.2666f,0.7453f,0.0444f,0.3024f,0.4850f,0.7982f,0.0965f,0.7843f,0.5075f, - 0.0844f,0.8370f,0.6103f,0.4604f,0.6087f, 0.8594f, 0.4599f, 0.6714f, 0.2744f, 0.1981f, 0.4143f, + 0.0844f,0.8370f,0.6103f,0.4604f,0.6087f, 0.8594f, 0.4599f, 0.6714f, + 0.2744f, 0.1981f, 0.4143f, 0.7821f,0.3505f,0.5040f,0.1180f,0.8307f,0.1817f,0.8442f,0.5074f,0.4471f,0.5105f,0.6666f, 0.2576f,0.2341f,0.6801f,0.2652f,0.5394f,0.4690f,0.6146f,0.1210f,0.2576f,0.0769f,0.4643f, 0.1628f,0.2026f,0.3774f,0.0506f,0.3462f,0.5720f,0.0838f,0.4228f,0.0588f,0.5362f,0.4756f, @@ -209,1807 +211,2169 @@ TEST_F(DeclarableOpsTests15, Test_AdjustContrast_5) { * */ TEST_F(DeclarableOpsTests15, Test_AdjustContrast_6) { - auto x = NDArrayFactory::create('c', {8,8, 3, 1}, {0.7788f,0.8012f,0.7244f,0.2309f,0.7271f,0.1804f, - 0.5056f,0.8925f,0.5461f,0.9234f,0.0856f,0.7938f,0.6591f,0.5555f,0.1596f,0.3087f,0.1548f,0.4695f, - 0.9939f,0.6113f,0.6765f,0.1800f,0.6750f,0.2246f,0.0509f,0.4601f,0.8284f,0.2354f,0.9752f,0.8361f, - 0.2585f,0.4189f,0.7028f,0.7679f,0.5373f,0.7234f,0.2690f,0.0062f,0.0327f,0.0644f,0.8428f,0.7494f, - 0.0755f,0.6245f,0.3491f,0.5793f,0.5730f,0.1822f,0.6420f,0.9143f,0.3019f, - 0.3574f,0.1704f,0.8395f,0.5468f,0.0744f,0.9011f,0.6574f,0.4124f,0.2445f,0.4248f,0.5219f, - 0.6952f,0.4900f,0.2158f,0.9549f,0.1386f,0.1544f,0.5365f,0.0134f,0.4163f,0.1456f,0.4109f, - 0.2484f, 0.3330f,0.2974f,0.6636f,0.3808f,0.8664f, 0.1896f, 0.7530f, 0.7215f, 0.6612f, 0.7270f, - 0.5704f,0.2666f,0.7453f,0.0444f,0.3024f,0.4850f,0.7982f,0.0965f,0.7843f,0.5075f, - 0.0844f,0.8370f,0.6103f,0.4604f,0.6087f, 0.8594f, 0.4599f, 0.6714f, 0.2744f, 0.1981f, 0.4143f, - 0.7821f,0.3505f,0.5040f,0.1180f,0.8307f,0.1817f,0.8442f,0.5074f,0.4471f,0.5105f,0.6666f, - 0.2576f,0.2341f,0.6801f,0.2652f,0.5394f,0.4690f,0.6146f,0.1210f,0.2576f,0.0769f,0.4643f, - 0.1628f,0.2026f,0.3774f,0.0506f,0.3462f,0.5720f,0.0838f,0.4228f,0.0588f,0.5362f,0.4756f, - 0.2530f,0.1778f,0.0751f,0.8977f,0.3648f,0.3065f,0.4739f,0.7014f,0.4473f,0.5171f,0.1744f, - 0.3487f,0.7759f,0.9491f,0.2072f,0.2182f,0.6520f,0.3092f,0.9545f,0.1881f,0.9579f,0.1785f, - 0.9636f,0.4830f,0.6569f,0.3353f,0.9997f,0.5869f,0.5747f,0.0238f,0.2943f,0.5248f,0.5879f, - .7266f,0.1965f,0.9167f,0.9726f,0.9206f,0.0519f,0.2997f,0.0039f,0.7652f,0.5498f, - 0.3794f,0.3791f,0.3528f,0.2873f,0.8082f,0.4732f,0.4399f,0.6606f,0.5991f,0.0034f,0.4874f}); - auto e = NDArrayFactory::create('c', {8, 8, 3, 1}, { - 1.0218375f, 1.0666375f, 0.9130375f, - -0.07396251f, 0.91843754f, -0.17496246f, - 0.47543746f, 1.2492375f, 0.55643755f, - 1.3110375f, -0.36456245f, 1.0518374f, - 0.7824375f, 0.57523745f, -0.21656245f, - 0.0816375f, -0.2261625f, 0.40323752f, - 1.4520376f, 0.6868375f, 0.81723756f, - -0.17576247f, 0.81423753f, -0.08656245f, - - -0.36249164f, 0.45590833f, 1.1925083f, - 0.00650835f, 1.4861084f, 1.2079083f, - 0.05270836f, 0.37350836f, 0.94130826f, - 1.0715083f, 0.6103083f, 0.9825083f, - 0.07370833f, -0.4518917f, -0.39889166f, - -0.3354917f, 1.2213084f, 1.0345083f, - -0.3132917f, 0.78470826f, 0.23390833f, - 0.6943083f, 0.68170834f, -0.09989169f, - - 0.8352709f, 1.3798709f, 0.15507084f, - 0.26607084f, -0.10792917f, 1.2302709f, - 0.6448709f, -0.29992914f, 1.3534708f, - 0.86607087f, 0.37607086f, 0.04027084f, - 0.40087086f, 0.59507084f, 0.9416709f, - 0.53127086f, -0.01712915f, 1.4610709f, - -0.17152917f, -0.13992918f, 0.6242708f, - -0.42192918f, 0.38387084f, -0.15752912f, - - 0.3311833f, 0.00618333f, 0.17538333f, - 0.10418332f, 0.8365834f, 0.27098334f, - 1.2421833f, -0.1114167f, 1.0153834f, - 0.9523833f, 0.8317833f, 0.9633833f, - 0.6501833f, 0.04258335f, 0.9999833f, - -0.40181667f, 0.11418331f, 0.47938335f, - 1.1057833f, -0.29761666f, 1.0779834f, - 0.5243833f, -0.32181668f, 1.1833833f, - - 0.73157084f, 0.4317708f, 0.7283708f, - 1.2297708f, 0.4307708f, 0.85377085f, - 0.05977082f, -0.09282917f, 0.33957082f, - 1.0751709f, 0.2119708f, 0.51897085f, - -0.25302917f, 1.1723708f, -0.12562919f, - 1.1993709f, 0.5257708f, 0.40517086f, - 0.53197086f, 0.8441708f, 0.02617085f, - -0.0208292f, 0.8711709f, 0.04137081f, - - 0.74936247f, 0.6085625f, 0.8997625f, - -0.08743751f, 0.18576252f, -0.17563748f, - 0.5991625f, -0.0038375f, 0.07576251f, - 0.42536253f, -0.22823751f, 0.36296248f, - 0.81456256f, -0.16183749f, 0.5161625f, - -0.21183747f, 0.7429625f, 0.6217625f, - 0.17656249f, 0.02616251f, -0.17923748f, - 1.4659625f, 0.40016252f, 0.28356248f, - - 0.4195791f, 0.8745791f, 0.36637908f, - 0.50597906f, -0.17942089f, 0.16917908f, - 1.0235791f, 1.3699791f, -0.11382091f, - -0.0918209f, 0.7757791f, 0.09017909f, - 1.3807791f, -0.15202093f, 1.3875791f, - -0.1712209f, 1.3989791f, 0.43777913f, - 0.7855791f, 0.1423791f, 1.4711791f, - 0.6455791f, 0.6211791f, -0.48062086f, - - 0.10189578f, 0.5628958f, 0.68909574f, - 0.96649575f, -0.09370419f, 1.3466958f, - 1.4584957f, 1.3544958f, -0.3829042f, - 0.11269578f, -0.47890422f, 1.0436958f, - 0.6128957f, 0.27209583f, 0.2714958f, - 0.21889582f, 0.08789578f, 1.1296958f, - 0.4596958f, 0.39309582f, 0.8344958f, - 0.71149576f, -0.4799042f, 0.4880958f - }); - - sd::ops::adjust_contrast op; - auto result = op.evaluate({&x}, {2.}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); - auto out = result.at(0); -// out->printBuffer("Adjusted Constrast6"); -// e.printBuffer("Adjusted Expected 6"); -// ASSERT_TRUE(e.equalsTo(out)); - + auto x = NDArrayFactory::create( + 'c', {8, 8, 3, 1}, + {0.7788f, 0.8012f, 0.7244f, 0.2309f, 0.7271f, 0.1804f, 0.5056f, 0.8925f, + 0.5461f, 0.9234f, 0.0856f, 0.7938f, 0.6591f, 0.5555f, 0.1596f, 0.3087f, + 0.1548f, 0.4695f, 0.9939f, 0.6113f, 0.6765f, 0.1800f, 0.6750f, 0.2246f, + 0.0509f, 0.4601f, 0.8284f, 0.2354f, 0.9752f, 0.8361f, 0.2585f, 0.4189f, + 0.7028f, 0.7679f, 0.5373f, 0.7234f, 0.2690f, 0.0062f, 0.0327f, 0.0644f, + 0.8428f, 0.7494f, 0.0755f, 0.6245f, 0.3491f, 0.5793f, 0.5730f, 0.1822f, + 0.6420f, 0.9143f, 0.3019f, 0.3574f, 0.1704f, 0.8395f, 0.5468f, 0.0744f, + 0.9011f, 0.6574f, 0.4124f, 0.2445f, 0.4248f, 0.5219f, 0.6952f, 0.4900f, + 0.2158f, 0.9549f, 0.1386f, 0.1544f, 0.5365f, 0.0134f, 0.4163f, 0.1456f, + 0.4109f, 0.2484f, 0.3330f, 0.2974f, 0.6636f, 0.3808f, 0.8664f, 0.1896f, + 0.7530f, 0.7215f, 0.6612f, 0.7270f, 0.5704f, 0.2666f, 0.7453f, 0.0444f, + 0.3024f, 0.4850f, 0.7982f, 0.0965f, 0.7843f, 0.5075f, 0.0844f, 0.8370f, + 0.6103f, 0.4604f, 0.6087f, 0.8594f, 0.4599f, 0.6714f, 0.2744f, 0.1981f, + 0.4143f, 0.7821f, 0.3505f, 0.5040f, 0.1180f, 0.8307f, 0.1817f, 0.8442f, + 0.5074f, 0.4471f, 0.5105f, 0.6666f, 0.2576f, 0.2341f, 0.6801f, 0.2652f, + 0.5394f, 0.4690f, 0.6146f, 0.1210f, 0.2576f, 0.0769f, 0.4643f, 0.1628f, + 0.2026f, 0.3774f, 0.0506f, 0.3462f, 0.5720f, 0.0838f, 0.4228f, 0.0588f, + 0.5362f, 0.4756f, 0.2530f, 0.1778f, 0.0751f, 0.8977f, 0.3648f, 0.3065f, + 0.4739f, 0.7014f, 0.4473f, 0.5171f, 0.1744f, 0.3487f, 0.7759f, 0.9491f, + 0.2072f, 0.2182f, 0.6520f, 0.3092f, 0.9545f, 0.1881f, 0.9579f, 0.1785f, + 0.9636f, 0.4830f, 0.6569f, 0.3353f, 0.9997f, 0.5869f, 0.5747f, 0.0238f, + 0.2943f, 0.5248f, 0.5879f, .7266f, 0.1965f, 0.9167f, 0.9726f, 0.9206f, + 0.0519f, 0.2997f, 0.0039f, 0.7652f, 0.5498f, 0.3794f, 0.3791f, 0.3528f, + 0.2873f, 0.8082f, 0.4732f, 0.4399f, 0.6606f, 0.5991f, 0.0034f, 0.4874f}); + auto e = NDArrayFactory::create( + 'c', {8, 8, 3, 1}, + {1.0218375f, 1.0666375f, 0.9130375f, -0.07396251f, 0.91843754f, + -0.17496246f, 0.47543746f, 1.2492375f, 0.55643755f, 1.3110375f, + -0.36456245f, 1.0518374f, 0.7824375f, 0.57523745f, -0.21656245f, + 0.0816375f, -0.2261625f, 0.40323752f, 1.4520376f, 0.6868375f, + 0.81723756f, -0.17576247f, 0.81423753f, -0.08656245f, + + -0.36249164f, 0.45590833f, 1.1925083f, 0.00650835f, 1.4861084f, + 1.2079083f, 0.05270836f, 0.37350836f, 0.94130826f, 1.0715083f, + 0.6103083f, 0.9825083f, 0.07370833f, -0.4518917f, -0.39889166f, + -0.3354917f, 1.2213084f, 1.0345083f, -0.3132917f, 0.78470826f, + 0.23390833f, 0.6943083f, 0.68170834f, -0.09989169f, + + 0.8352709f, 1.3798709f, 0.15507084f, 0.26607084f, -0.10792917f, + 1.2302709f, 0.6448709f, -0.29992914f, 1.3534708f, 0.86607087f, + 0.37607086f, 0.04027084f, 0.40087086f, 0.59507084f, 0.9416709f, + 0.53127086f, -0.01712915f, 1.4610709f, -0.17152917f, -0.13992918f, + 0.6242708f, -0.42192918f, 0.38387084f, -0.15752912f, + + 0.3311833f, 0.00618333f, 0.17538333f, 0.10418332f, 0.8365834f, + 0.27098334f, 1.2421833f, -0.1114167f, 1.0153834f, 0.9523833f, + 0.8317833f, 0.9633833f, 0.6501833f, 0.04258335f, 0.9999833f, + -0.40181667f, 0.11418331f, 0.47938335f, 1.1057833f, -0.29761666f, + 1.0779834f, 0.5243833f, -0.32181668f, 1.1833833f, + + 0.73157084f, 0.4317708f, 0.7283708f, 1.2297708f, 0.4307708f, + 0.85377085f, 0.05977082f, -0.09282917f, 0.33957082f, 1.0751709f, + 0.2119708f, 0.51897085f, -0.25302917f, 1.1723708f, -0.12562919f, + 1.1993709f, 0.5257708f, 0.40517086f, 0.53197086f, 0.8441708f, + 0.02617085f, -0.0208292f, 0.8711709f, 0.04137081f, + + 0.74936247f, 0.6085625f, 0.8997625f, -0.08743751f, 0.18576252f, + -0.17563748f, 0.5991625f, -0.0038375f, 0.07576251f, 0.42536253f, + -0.22823751f, 0.36296248f, 0.81456256f, -0.16183749f, 0.5161625f, + -0.21183747f, 0.7429625f, 0.6217625f, 0.17656249f, 0.02616251f, + -0.17923748f, 1.4659625f, 0.40016252f, 0.28356248f, + + 0.4195791f, 0.8745791f, 0.36637908f, 0.50597906f, -0.17942089f, + 0.16917908f, 1.0235791f, 1.3699791f, -0.11382091f, -0.0918209f, + 0.7757791f, 0.09017909f, 1.3807791f, -0.15202093f, 1.3875791f, + -0.1712209f, 1.3989791f, 0.43777913f, 0.7855791f, 0.1423791f, + 1.4711791f, 0.6455791f, 0.6211791f, -0.48062086f, + + 0.10189578f, 0.5628958f, 0.68909574f, 0.96649575f, -0.09370419f, + 1.3466958f, 1.4584957f, 1.3544958f, -0.3829042f, 0.11269578f, + -0.47890422f, 1.0436958f, 0.6128957f, 0.27209583f, 0.2714958f, + 0.21889582f, 0.08789578f, 1.1296958f, 0.4596958f, 0.39309582f, + 0.8344958f, 0.71149576f, -0.4799042f, 0.4880958f}); + + sd::ops::adjust_contrast op; + auto result = op.evaluate({&x}, {2.}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + auto out = result.at(0); + // out->printBuffer("Adjusted Constrast6"); + // e.printBuffer("Adjusted Expected 6"); + // ASSERT_TRUE(e.equalsTo(out)); } TEST_F(DeclarableOpsTests15, Test_AdjustContrast_7) { - auto x = NDArrayFactory::create('c', {8,8, 3, 1}, {0.7788f,0.8012f,0.7244f,0.2309f,0.7271f,0.1804f, - 0.5056f,0.8925f,0.5461f,0.9234f,0.0856f,0.7938f,0.6591f,0.5555f,0.1596f,0.3087f,0.1548f,0.4695f, - 0.9939f,0.6113f,0.6765f,0.1800f,0.6750f,0.2246f,0.0509f,0.4601f,0.8284f,0.2354f,0.9752f,0.8361f, - 0.2585f,0.4189f,0.7028f,0.7679f,0.5373f,0.7234f,0.2690f,0.0062f,0.0327f,0.0644f,0.8428f,0.7494f, - 0.0755f,0.6245f,0.3491f,0.5793f,0.5730f,0.1822f,0.6420f,0.9143f,0.3019f, - 0.3574f,0.1704f,0.8395f,0.5468f,0.0744f,0.9011f,0.6574f,0.4124f,0.2445f,0.4248f,0.5219f, - 0.6952f,0.4900f,0.2158f,0.9549f,0.1386f,0.1544f,0.5365f,0.0134f,0.4163f,0.1456f,0.4109f, - 0.2484f, 0.3330f,0.2974f,0.6636f,0.3808f,0.8664f, 0.1896f, 0.7530f, 0.7215f, 0.6612f, 0.7270f, - 0.5704f,0.2666f,0.7453f,0.0444f,0.3024f,0.4850f,0.7982f,0.0965f,0.7843f,0.5075f, - 0.0844f,0.8370f,0.6103f,0.4604f,0.6087f, 0.8594f, 0.4599f, 0.6714f, 0.2744f, 0.1981f, 0.4143f, - 0.7821f,0.3505f,0.5040f,0.1180f,0.8307f,0.1817f,0.8442f,0.5074f,0.4471f,0.5105f,0.6666f, - 0.2576f,0.2341f,0.6801f,0.2652f,0.5394f,0.4690f,0.6146f,0.1210f,0.2576f,0.0769f,0.4643f, - 0.1628f,0.2026f,0.3774f,0.0506f,0.3462f,0.5720f,0.0838f,0.4228f,0.0588f,0.5362f,0.4756f, - 0.2530f,0.1778f,0.0751f,0.8977f,0.3648f,0.3065f,0.4739f,0.7014f,0.4473f,0.5171f,0.1744f, - 0.3487f,0.7759f,0.9491f,0.2072f,0.2182f,0.6520f,0.3092f,0.9545f,0.1881f,0.9579f,0.1785f, - 0.9636f,0.4830f,0.6569f,0.3353f,0.9997f,0.5869f,0.5747f,0.0238f,0.2943f,0.5248f,0.5879f, - .7266f,0.1965f,0.9167f,0.9726f,0.9206f,0.0519f,0.2997f,0.0039f,0.7652f,0.5498f, - 0.3794f,0.3791f,0.3528f,0.2873f,0.8082f,0.4732f,0.4399f,0.6606f,0.5991f,0.0034f,0.4874f}); - auto e = NDArrayFactory::create('c', {8, 8, 3, 1}, { - 1.0218375, 1.0666375 , 0.9130375 , - -0.07396251, 0.91843754, -0.17496246, - 0.47543746, 1.2492375 , 0.55643755, - 1.3110375 , -0.36456245, 1.0518374 , - 0.7824375 , 0.57523745, -0.21656245, - 0.0816375 , -0.2261625 , 0.40323752, - 1.4520376 , 0.6868375 , 0.81723756, - -0.17576247, 0.81423753, -0.08656245, - - -0.36249164, 0.45590833, 1.1925083 , - 0.00650835, 1.4861084 , 1.2079083 , - 0.05270836, 0.37350836, 0.94130826, - 1.0715083 , 0.6103083 , 0.9825083 , - 0.07370833, -0.4518917 , -0.39889166, - -0.3354917 , 1.2213084 , 1.0345083 , - -0.3132917 , 0.78470826, 0.23390833, - 0.6943083 , 0.68170834, -0.09989169, - - 0.8352709 , 1.3798709 , 0.15507084, - 0.26607084, -0.10792917, 1.2302709 , - 0.6448709 , -0.29992914, 1.3534708 , - 0.86607087, 0.37607086, 0.04027084, - 0.40087086, 0.59507084, 0.9416709 , - 0.53127086, -0.01712915, 1.4610709 , - -0.17152917, -0.13992918, 0.6242708 , - -0.42192918, 0.38387084, -0.15752912, - - - 0.3311833 , 0.00618333, 0.17538333, - 0.10418332, 0.8365834 , 0.27098334, - 1.2421833 , -0.1114167 , 1.0153834 , - 0.9523833 , 0.8317833 , 0.9633833 , - 0.6501833 , 0.04258335, 0.9999833 , - -0.40181667, 0.11418331, 0.47938335, - 1.1057833 , -0.29761666, 1.0779834 , - 0.5243833 , -0.32181668, 1.1833833 , - - 0.73157084, 0.4317708 , 0.7283708 , - 1.2297708 , 0.4307708 , 0.85377085, - 0.05977082, -0.09282917, 0.33957082, - 1.0751709 , 0.2119708 , 0.51897085, - -0.25302917, 1.1723708 , -0.12562919, - 1.1993709 , 0.5257708 , 0.40517086, - 0.53197086, 0.8441708 , 0.02617085, - -0.0208292 , 0.8711709 , 0.04137081, - - 0.74936247, 0.6085625 , 0.8997625 , - -0.08743751, 0.18576252, -0.17563748, - 0.5991625 , -0.0038375 , 0.07576251, - 0.42536253, -0.22823751, 0.36296248, - 0.81456256, -0.16183749, 0.5161625 , - -0.21183747, 0.7429625 , 0.6217625 , - 0.17656249, 0.02616251, -0.17923748, - 1.4659625 , 0.40016252, 0.28356248, - - 0.4195791 , 0.8745791 , 0.36637908, - 0.50597906, -0.17942089, 0.16917908, - 1.0235791 , 1.3699791 , -0.11382091, - -0.0918209 , 0.7757791 , 0.09017909, - 1.3807791 , -0.15202093, 1.3875791 , - -0.1712209 , 1.3989791 , 0.43777913, - 0.7855791 , 0.1423791 , 1.4711791 , - 0.6455791 , 0.6211791 , -0.48062086, - - - 0.10189578, 0.5628958 , 0.68909574, - 0.96649575, -0.09370419, 1.3466958 , - 1.4584957 , 1.3544958 , -0.3829042 , - 0.11269578, -0.47890422, 1.0436958 , - 0.6128957 , 0.27209583, 0.2714958 , - 0.21889582, 0.08789578, 1.1296958 , - 0.4596958 , 0.39309582, 0.8344958 , - 0.71149576, -0.4799042, 0.4880958 - }); -// x.linspace(1.); - sd::ops::adjust_contrast_v2 op; - auto result = op.evaluate({&x}, {2.}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); - auto out = result.at(0); -// out->printBuffer("Adjusted Constrast7"); -// e.printBuffer("Adjusted expected 7"); - auto diff = e - *out; -// diff.printBuffer("Adjusted subtract 7"); - ASSERT_TRUE(e.equalsTo(out)); - + auto x = NDArrayFactory::create( + 'c', {8, 8, 3, 1}, + {0.7788f, 0.8012f, 0.7244f, 0.2309f, 0.7271f, 0.1804f, 0.5056f, 0.8925f, + 0.5461f, 0.9234f, 0.0856f, 0.7938f, 0.6591f, 0.5555f, 0.1596f, 0.3087f, + 0.1548f, 0.4695f, 0.9939f, 0.6113f, 0.6765f, 0.1800f, 0.6750f, 0.2246f, + 0.0509f, 0.4601f, 0.8284f, 0.2354f, 0.9752f, 0.8361f, 0.2585f, 0.4189f, + 0.7028f, 0.7679f, 0.5373f, 0.7234f, 0.2690f, 0.0062f, 0.0327f, 0.0644f, + 0.8428f, 0.7494f, 0.0755f, 0.6245f, 0.3491f, 0.5793f, 0.5730f, 0.1822f, + 0.6420f, 0.9143f, 0.3019f, 0.3574f, 0.1704f, 0.8395f, 0.5468f, 0.0744f, + 0.9011f, 0.6574f, 0.4124f, 0.2445f, 0.4248f, 0.5219f, 0.6952f, 0.4900f, + 0.2158f, 0.9549f, 0.1386f, 0.1544f, 0.5365f, 0.0134f, 0.4163f, 0.1456f, + 0.4109f, 0.2484f, 0.3330f, 0.2974f, 0.6636f, 0.3808f, 0.8664f, 0.1896f, + 0.7530f, 0.7215f, 0.6612f, 0.7270f, 0.5704f, 0.2666f, 0.7453f, 0.0444f, + 0.3024f, 0.4850f, 0.7982f, 0.0965f, 0.7843f, 0.5075f, 0.0844f, 0.8370f, + 0.6103f, 0.4604f, 0.6087f, 0.8594f, 0.4599f, 0.6714f, 0.2744f, 0.1981f, + 0.4143f, 0.7821f, 0.3505f, 0.5040f, 0.1180f, 0.8307f, 0.1817f, 0.8442f, + 0.5074f, 0.4471f, 0.5105f, 0.6666f, 0.2576f, 0.2341f, 0.6801f, 0.2652f, + 0.5394f, 0.4690f, 0.6146f, 0.1210f, 0.2576f, 0.0769f, 0.4643f, 0.1628f, + 0.2026f, 0.3774f, 0.0506f, 0.3462f, 0.5720f, 0.0838f, 0.4228f, 0.0588f, + 0.5362f, 0.4756f, 0.2530f, 0.1778f, 0.0751f, 0.8977f, 0.3648f, 0.3065f, + 0.4739f, 0.7014f, 0.4473f, 0.5171f, 0.1744f, 0.3487f, 0.7759f, 0.9491f, + 0.2072f, 0.2182f, 0.6520f, 0.3092f, 0.9545f, 0.1881f, 0.9579f, 0.1785f, + 0.9636f, 0.4830f, 0.6569f, 0.3353f, 0.9997f, 0.5869f, 0.5747f, 0.0238f, + 0.2943f, 0.5248f, 0.5879f, .7266f, 0.1965f, 0.9167f, 0.9726f, 0.9206f, + 0.0519f, 0.2997f, 0.0039f, 0.7652f, 0.5498f, 0.3794f, 0.3791f, 0.3528f, + 0.2873f, 0.8082f, 0.4732f, 0.4399f, 0.6606f, 0.5991f, 0.0034f, 0.4874f}); + auto e = NDArrayFactory::create( + 'c', {8, 8, 3, 1}, + {1.0218375, 1.0666375, 0.9130375, -0.07396251, 0.91843754, + -0.17496246, 0.47543746, 1.2492375, 0.55643755, 1.3110375, + -0.36456245, 1.0518374, 0.7824375, 0.57523745, -0.21656245, + 0.0816375, -0.2261625, 0.40323752, 1.4520376, 0.6868375, + 0.81723756, -0.17576247, 0.81423753, -0.08656245, + + -0.36249164, 0.45590833, 1.1925083, 0.00650835, 1.4861084, + 1.2079083, 0.05270836, 0.37350836, 0.94130826, 1.0715083, + 0.6103083, 0.9825083, 0.07370833, -0.4518917, -0.39889166, + -0.3354917, 1.2213084, 1.0345083, -0.3132917, 0.78470826, + 0.23390833, 0.6943083, 0.68170834, -0.09989169, + + 0.8352709, 1.3798709, 0.15507084, 0.26607084, -0.10792917, + 1.2302709, 0.6448709, -0.29992914, 1.3534708, 0.86607087, + 0.37607086, 0.04027084, 0.40087086, 0.59507084, 0.9416709, + 0.53127086, -0.01712915, 1.4610709, -0.17152917, -0.13992918, + 0.6242708, -0.42192918, 0.38387084, -0.15752912, + + 0.3311833, 0.00618333, 0.17538333, 0.10418332, 0.8365834, + 0.27098334, 1.2421833, -0.1114167, 1.0153834, 0.9523833, + 0.8317833, 0.9633833, 0.6501833, 0.04258335, 0.9999833, + -0.40181667, 0.11418331, 0.47938335, 1.1057833, -0.29761666, + 1.0779834, 0.5243833, -0.32181668, 1.1833833, + + 0.73157084, 0.4317708, 0.7283708, 1.2297708, 0.4307708, + 0.85377085, 0.05977082, -0.09282917, 0.33957082, 1.0751709, + 0.2119708, 0.51897085, -0.25302917, 1.1723708, -0.12562919, + 1.1993709, 0.5257708, 0.40517086, 0.53197086, 0.8441708, + 0.02617085, -0.0208292, 0.8711709, 0.04137081, + + 0.74936247, 0.6085625, 0.8997625, -0.08743751, 0.18576252, + -0.17563748, 0.5991625, -0.0038375, 0.07576251, 0.42536253, + -0.22823751, 0.36296248, 0.81456256, -0.16183749, 0.5161625, + -0.21183747, 0.7429625, 0.6217625, 0.17656249, 0.02616251, + -0.17923748, 1.4659625, 0.40016252, 0.28356248, + + 0.4195791, 0.8745791, 0.36637908, 0.50597906, -0.17942089, + 0.16917908, 1.0235791, 1.3699791, -0.11382091, -0.0918209, + 0.7757791, 0.09017909, 1.3807791, -0.15202093, 1.3875791, + -0.1712209, 1.3989791, 0.43777913, 0.7855791, 0.1423791, + 1.4711791, 0.6455791, 0.6211791, -0.48062086, + + 0.10189578, 0.5628958, 0.68909574, 0.96649575, -0.09370419, + 1.3466958, 1.4584957, 1.3544958, -0.3829042, 0.11269578, + -0.47890422, 1.0436958, 0.6128957, 0.27209583, 0.2714958, + 0.21889582, 0.08789578, 1.1296958, 0.4596958, 0.39309582, + 0.8344958, 0.71149576, -0.4799042, 0.4880958}); + // x.linspace(1.); + sd::ops::adjust_contrast_v2 op; + auto result = op.evaluate({&x}, {2.}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + auto out = result.at(0); + // out->printBuffer("Adjusted Constrast7"); + // e.printBuffer("Adjusted expected 7"); + auto diff = e - out; + // diff.printBuffer("Adjusted subtract 7"); + ASSERT_TRUE(e.equalsTo(out)); } TEST_F(DeclarableOpsTests15, Test_BitCast_1) { - auto x = NDArrayFactory::create('c', {2, 2, 2}); - auto e = NDArrayFactory::create('c', {2, 2}, {2., 512., 8192., 131072.032 }); - x.linspace(1.); - - sd::ops::bitcast op; - auto result = op.evaluate({&x}, {(int) sd::DataType::DOUBLE}); - ASSERT_EQ(Status::OK(), result.status()); - auto out = result.at(0); -// out->printIndexedBuffer("Casted result"); - ASSERT_TRUE(e.equalsTo(out)); + auto x = NDArrayFactory::create('c', {2, 2, 2}); + auto e = NDArrayFactory::create('c', {2, 2}, + {2., 512., 8192., 131072.032}); + x.linspace(1.); + sd::ops::bitcast op; + auto result = op.evaluate({&x}, {(int)sd::DataType::DOUBLE}); + ASSERT_EQ(Status::OK(), result.status()); + auto out = result.at(0); + // out->printIndexedBuffer("Casted result"); + ASSERT_TRUE(e.equalsTo(out)); } TEST_F(DeclarableOpsTests15, Test_BitCast_2) { - auto x = NDArrayFactory::create('c', {2, 4}); - auto e = NDArrayFactory::create('c', {2, 4, 2}, {0.f, 1.875f, 0.f, 2.f, 0.f, 2.125f, 0.f, 2.25f, - 0.f, 2.312f, 0.f, 2.375f, 0.f, 2.438f, 0.f, 2.5f}); - x.linspace(1.); + auto x = NDArrayFactory::create('c', {2, 4}); + auto e = NDArrayFactory::create( + 'c', {2, 4, 2}, + {0.f, 1.875f, 0.f, 2.f, 0.f, 2.125f, 0.f, 2.25f, 0.f, 2.312f, 0.f, 2.375f, + 0.f, 2.438f, 0.f, 2.5f}); + x.linspace(1.); - sd::ops::bitcast op; - auto result = op.evaluate({&x}, {(int) sd::DataType::HALF}); - ASSERT_EQ(Status::OK(), result.status()); - auto out = result.at(0); - - ASSERT_TRUE(e.equalsTo(out)); + sd::ops::bitcast op; + auto result = op.evaluate({&x}, {(int)sd::DataType::HALF}); + ASSERT_EQ(Status::OK(), result.status()); + auto out = result.at(0); + ASSERT_TRUE(e.equalsTo(out)); } TEST_F(DeclarableOpsTests15, Test_BitCast_3) { - auto x = NDArrayFactory::create('c', {1, 4}); + auto x = NDArrayFactory::create('c', {1, 4}); - x.linspace(1.); - sd::ops::bitcast op; - try { - auto result = op.evaluate({&x}, {(int) sd::DataType::INT64}); - ASSERT_NE(Status::OK(), result.status()); + x.linspace(1.); + sd::ops::bitcast op; + try { + auto result = op.evaluate({&x}, {(int)sd::DataType::INT64}); + ASSERT_NE(Status::OK(), result.status()); - } catch (std::exception& e) { - nd4j_printf("Error should be here `%s'. It's OK.\n", e.what()); - } + } catch (std::exception& e) { + nd4j_printf("Error should be here `%s'. It's OK.\n", e.what()); + } } TEST_F(DeclarableOpsTests15, Test_BitCast_4) { - auto x = NDArrayFactory::create('c', {1, 4}); - auto e = NDArrayFactory::create('c', {1, 2}, {1234567890LL, 2468013579LL}); - x.linspace(1.); - sd::ops::bitcast op; - try { - auto result = op.execute({&x}, {&e}, {}, {sd::DataType::INT64}, {}); - ASSERT_NE(Status::OK(), result); - } catch(std::exception& e) { - nd4j_printf("Error `%s' should be here. It's OK.\n",e.what()); - } - + auto x = NDArrayFactory::create('c', {1, 4}); + auto e = NDArrayFactory::create('c', {1, 2}, + {1234567890LL, 2468013579LL}); + x.linspace(1.); + sd::ops::bitcast op; + try { + auto result = op.execute({&x}, {&e}, {}, {sd::DataType::INT64}, {}); + ASSERT_NE(Status::OK(), result); + } catch (std::exception& e) { + nd4j_printf("Error `%s' should be here. It's OK.\n", e.what()); + } } TEST_F(DeclarableOpsTests15, Test_BitCast_4_1) { - auto x = NDArrayFactory::create('c', {1, 2}); - auto e = NDArrayFactory::create('c', {1, 2}, {4607182418800017408LL, 4611686018427387904LL}); // as TF 4607182418800017408, 4611686018427387904 - x.linspace(1.); - sd::ops::bitcast op; - - auto result = op.evaluate({&x}, {}, {sd::DataType::INT64}, {}); - ASSERT_EQ(Status::OK(), result.status()); + auto x = NDArrayFactory::create('c', {1, 2}); + auto e = NDArrayFactory::create( + 'c', {1, 2}, + {4607182418800017408LL, + 4611686018427387904LL}); // as TF 4607182418800017408, + // 4611686018427387904 + x.linspace(1.); + sd::ops::bitcast op; - // e.printIndexedBuffer("Double to int64"); - auto res = result.at(0); - ASSERT_EQ(*res, e); + auto result = op.evaluate({&x}, {}, {sd::DataType::INT64}, {}); + ASSERT_EQ(Status::OK(), result.status()); + // e.printIndexedBuffer("Double to int64"); + auto res = result.at(0); + ASSERT_EQ(res, e); } - TEST_F(DeclarableOpsTests15, Test_BitCast_5) { - auto x = NDArrayFactory::create('c', {4, 4}, { - 0.4922f, 0.2969f, 0.6172f, 0.8906f, - 0.9297f, 0.0859f, 0.2344f, 0.3828f, - 0.5781f, 0.7969f, 0.0391f, 0.1719f, - 0.8359f, 0.9297f, 0.3438f, 0.0938f}); + auto x = NDArrayFactory::create( + 'c', {4, 4}, + {0.4922f, 0.2969f, 0.6172f, 0.8906f, 0.9297f, 0.0859f, 0.2344f, 0.3828f, + 0.5781f, 0.7969f, 0.0391f, 0.1719f, 0.8359f, 0.9297f, 0.3438f, 0.0938f}); - auto e = NDArrayFactory::create('c', {4}, {4260467851820808160LL, 3900173902914993008LL, 3566895990128523424LL, - 3314989625590692528LL}); + auto e = NDArrayFactory::create( + 'c', {4}, + {4260467851820808160LL, 3900173902914993008LL, 3566895990128523424LL, + 3314989625590692528LL}); - sd::ops::bitcast op; - auto result = op.evaluate({&x}, {}, {sd::DataType::INT64}, {}); - ASSERT_EQ(Status::OK(), result.status()); - auto res = result.at(0); - -// res->printIndexedBuffer("BITCAST5"); - ASSERT_TRUE(e.equalsTo(res)); + sd::ops::bitcast op; + auto result = op.evaluate({&x}, {}, {sd::DataType::INT64}, {}); + ASSERT_EQ(Status::OK(), result.status()); + auto res = result.at(0); + // res->printIndexedBuffer("BITCAST5"); + ASSERT_TRUE(e.equalsTo(res)); } TEST_F(DeclarableOpsTests15, Test_BitCast_6) { - auto x = NDArrayFactory::create('c', {4, 4}, { - 1.f, 2.f, 3.f, 4.f, - 5.f, 6.f, 7.f, 8.f, - 9.f, 10.f, 11.f, 12.f, - 13.f, 14.f, 15.f, 16.f}); - - auto e = NDArrayFactory::create('c', {4}, {4899988963420290048LL, 5188224837230806272LL, 5332342774136064128LL, - 5476460161268730496LL}); + auto x = NDArrayFactory::create( + 'c', {4, 4}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 13.f, + 14.f, 15.f, 16.f}); - sd::ops::bitcast op; - auto result = op.evaluate({&x}, {}, {sd::DataType::INT64}, {}); - ASSERT_EQ(Status::OK(), result.status()); - auto res = result.at(0); + auto e = NDArrayFactory::create( + 'c', {4}, + {4899988963420290048LL, 5188224837230806272LL, 5332342774136064128LL, + 5476460161268730496LL}); -// res->printIndexedBuffer("BITCAST6"); - ASSERT_TRUE(e.equalsTo(res)); + sd::ops::bitcast op; + auto result = op.evaluate({&x}, {}, {sd::DataType::INT64}, {}); + ASSERT_EQ(Status::OK(), result.status()); + auto res = result.at(0); + // res->printIndexedBuffer("BITCAST6"); + ASSERT_TRUE(e.equalsTo(res)); } TEST_F(DeclarableOpsTests15, Test_BitCast_7) { - auto x = NDArrayFactory::create('c', {4, 4}, { - 1.1f, 2.2f, 3.3f, 4.4f, - 5.1f, 6.2f, 7.3f, 8.4f, - 9.1f, 10.2f, 11.3f, 12.4f, - 13.f, 14.2f, 15.3f, 16.4f}); - - auto e = NDArrayFactory::create('c', {4}, { - 4928700072476425318LL, 5202580391758873882LL, 5346698272827918477LL, 5483778673873668736LL}); + auto x = NDArrayFactory::create( + 'c', {4, 4}, + {1.1f, 2.2f, 3.3f, 4.4f, 5.1f, 6.2f, 7.3f, 8.4f, 9.1f, 10.2f, 11.3f, + 12.4f, 13.f, 14.2f, 15.3f, 16.4f}); - sd::ops::bitcast op; - auto result = op.evaluate({&x}, {}, {sd::DataType::INT64}, {}); - ASSERT_EQ(Status::OK(), result.status()); - auto res = result.at(0); + auto e = NDArrayFactory::create( + 'c', {4}, + {4928700072476425318LL, 5202580391758873882LL, 5346698272827918477LL, + 5483778673873668736LL}); -// res->printIndexedBuffer("BITCAST7"); - ASSERT_TRUE(e.equalsTo(res)); + sd::ops::bitcast op; + auto result = op.evaluate({&x}, {}, {sd::DataType::INT64}, {}); + ASSERT_EQ(Status::OK(), result.status()); + auto res = result.at(0); + // res->printIndexedBuffer("BITCAST7"); + ASSERT_TRUE(e.equalsTo(res)); } TEST_F(DeclarableOpsTests15, test_matmul_bp_1) { - auto a = NDArrayFactory::create('c', {1, 3}); - auto b = NDArrayFactory::create('c', {1, 4}); - auto gI = NDArrayFactory::create('c', {3, 4}); + auto a = NDArrayFactory::create('c', {1, 3}); + auto b = NDArrayFactory::create('c', {1, 4}); + auto gI = NDArrayFactory::create('c', {3, 4}); - auto gA = NDArrayFactory::create('c', {1, 3}); - auto gB = NDArrayFactory::create('c', {1, 4}); + auto gA = NDArrayFactory::create('c', {1, 3}); + auto gB = NDArrayFactory::create('c', {1, 4}); - sd::ops::matmul_bp op; - auto status = op.execute({&a, &b, &gI}, std::vector{&gA, &gB}, {}, {1, 0, 0}, {}); - ASSERT_EQ(Status::OK(), status); + sd::ops::matmul_bp op; + auto status = op.execute({&a, &b, &gI}, std::vector{&gA, &gB}, {}, + {1, 0, 0}, {}); + ASSERT_EQ(Status::OK(), status); } TEST_F(DeclarableOpsTests15, test_non_decreasing_1) { - auto x = NDArrayFactory::create(1.0); - auto z = NDArrayFactory::create(false); - auto e = NDArrayFactory::create(true); + auto x = NDArrayFactory::create(1.0); + auto z = NDArrayFactory::create(false); + auto e = NDArrayFactory::create(true); - sd::ops::is_non_decreasing op; - Context ctx(1); - ctx.setInputArray(0, &x); - ctx.setOutputArray(0, &z); + sd::ops::is_non_decreasing op; + Context ctx(1); + ctx.setInputArray(0, x); + ctx.setOutputArray(0, z); - auto status = op.execute(&ctx); - ASSERT_EQ(Status::OK(), status); - ASSERT_EQ(e, z); + auto status = op.execute(&ctx); + ASSERT_EQ(Status::OK(), status); + ASSERT_EQ(e, z); } TEST_F(DeclarableOpsTests15, test_check_numeric_1) { - auto x = NDArrayFactory::create('c', {3},{1.f, 2.f, 3.f}); - auto y = NDArrayFactory::string("shouldn't ever trigger"); + auto x = NDArrayFactory::create('c', {3}, {1.f, 2.f, 3.f}); + auto y = NDArrayFactory::string("shouldn't ever trigger"); - sd::ops::check_numerics op; - auto result = op.evaluate({&x, &y}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::check_numerics op; + auto result = op.evaluate({&x, &y}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_EQ(x, *z); + ASSERT_EQ(x, z); } TEST_F(DeclarableOpsTests15, test_check_numeric_2) { #ifdef FFAST_MATH - if (1 > 0) - return; + if (1 > 0) return; #endif - auto x = NDArrayFactory::create('c', {3},{1.f, 2.f, std::numeric_limits::infinity()}); - auto y = NDArrayFactory::string("should trigger"); - auto z = NDArrayFactory::create('c', {3} ); + auto x = NDArrayFactory::create( + 'c', {3}, {1.f, 2.f, std::numeric_limits::infinity()}); + auto y = NDArrayFactory::string("should trigger"); + auto z = NDArrayFactory::create('c', {3}); - sd::ops::check_numerics op; - try { - auto status = op.execute({&x, &y}, {&z}, {}, {}, {}); - ASSERT_TRUE(false); - } catch (std::invalid_argument &e) { - // - } + sd::ops::check_numerics op; + try { + auto status = op.execute({&x, &y}, {&z}, {}, {}, {}); + ASSERT_TRUE(false); + } catch (std::invalid_argument& e) { + // + } } TEST_F(DeclarableOpsTests15, test_check_numeric_3) { #ifdef FFAST_MATH - if (1 > 0) - return; + if (1 > 0) return; #endif - auto x = NDArrayFactory::create('c', {3},{1.f, 2.f, std::numeric_limits::quiet_NaN()}); - auto y = NDArrayFactory::string("should trigger"); - auto z = NDArrayFactory::create('c', {3} ); + auto x = NDArrayFactory::create( + 'c', {3}, {1.f, 2.f, std::numeric_limits::quiet_NaN()}); + auto y = NDArrayFactory::string("should trigger"); + auto z = NDArrayFactory::create('c', {3}); - sd::ops::check_numerics op; - try { - auto status = op.execute({&x, &y}, {&z}, {}, {}, {}); - ASSERT_TRUE(false); - } catch (std::invalid_argument &e) { - // - } + sd::ops::check_numerics op; + try { + auto status = op.execute({&x, &y}, {&z}, {}, {}, {}); + ASSERT_TRUE(false); + } catch (std::invalid_argument& e) { + // + } } TEST_F(DeclarableOpsTests15, Test_layer_norm_1) { - auto x = NDArrayFactory::create('c', {1, 5}, {1.f, 2.f, 3.f, 4.f, 5.f}); - auto g = NDArrayFactory::create('c', {5}, {1.f, 2.f, 3.f, 4.f, 5.f}); - auto b = NDArrayFactory::create('c', {5}, {1.f, 2.f, 3.f, 4.f, 5.f}); - - sd::ops::layer_norm op; - auto result = op.evaluate({&x, &g, &b}, {}, {0}, {false}); - ASSERT_EQ(Status::OK(), result.status()); + auto x = + NDArrayFactory::create('c', {1, 5}, {1.f, 2.f, 3.f, 4.f, 5.f}); + auto g = NDArrayFactory::create('c', {5}, {1.f, 2.f, 3.f, 4.f, 5.f}); + auto b = NDArrayFactory::create('c', {5}, {1.f, 2.f, 3.f, 4.f, 5.f}); + sd::ops::layer_norm op; + auto result = op.evaluate({&x, &g, &b}, {}, {0}, {false}); + ASSERT_EQ(Status::OK(), result.status()); } TEST_F(DeclarableOpsTests15, Test_layer_norm_bp_1) { - auto x = NDArrayFactory::create('c', {1, 5}, {1.f, 2.f, 3.f, 4.f, 5.f}); - auto g = NDArrayFactory::create('c', {5}, {1.f, 2.f, 3.f, 4.f, 5.f}); - auto b = NDArrayFactory::create('c', {5}, {1.f, 2.f, 3.f, 4.f, 5.f}); - auto eps = NDArrayFactory::create('c', {1, 5}, {0.f, 0.f, 0.f, 0.f, 0.f}); - - sd::ops::layer_norm_bp op; - auto result = op.evaluate({&x, &g, &b, &eps}, {}, {0}, {false}); - ASSERT_EQ(Status::OK(), result.status()); + auto x = + NDArrayFactory::create('c', {1, 5}, {1.f, 2.f, 3.f, 4.f, 5.f}); + auto g = NDArrayFactory::create('c', {5}, {1.f, 2.f, 3.f, 4.f, 5.f}); + auto b = NDArrayFactory::create('c', {5}, {1.f, 2.f, 3.f, 4.f, 5.f}); + auto eps = + NDArrayFactory::create('c', {1, 5}, {0.f, 0.f, 0.f, 0.f, 0.f}); + sd::ops::layer_norm_bp op; + auto result = op.evaluate({&x, &g, &b, &eps}, {}, {0}, {false}); + ASSERT_EQ(Status::OK(), result.status()); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests15, Test_layer_norm_bp_2) { + NDArray x('c', {3, 4, 8, 8}, sd::DataType::FLOAT32); + NDArray gain('c', {4}, {-0.1, 0.1, -0.2, 0.2}, sd::DataType::FLOAT32); + NDArray bias('c', {4}, {-0.05, 0.05, -1.05, 1.05}, sd::DataType::FLOAT32); + NDArray gradO('c', {3, 4, 8, 8}, sd::DataType::FLOAT32); - NDArray x('c', {3, 4, 8, 8}, sd::DataType::FLOAT32); - NDArray gain('c', {4}, {-0.1, 0.1, -0.2, 0.2}, sd::DataType::FLOAT32); - NDArray bias('c', {4}, {-0.05, 0.05, -1.05, 1.05}, sd::DataType::FLOAT32); - NDArray gradO('c', {3, 4, 8, 8}, sd::DataType::FLOAT32); - - NDArray gradI('c', {3, 4, 8, 8}, sd::DataType::FLOAT32); - NDArray gradG('c', {4}, sd::DataType::FLOAT32); - NDArray gradB('c', {4}, sd::DataType::FLOAT32); + NDArray gradI('c', {3, 4, 8, 8}, sd::DataType::FLOAT32); + NDArray gradG('c', {4}, sd::DataType::FLOAT32); + NDArray gradB('c', {4}, sd::DataType::FLOAT32); - x.linspace(-20, 0.5); - gradO.linspace(-4, 0.05); + x.linspace(-20, 0.5); + gradO.linspace(-4, 0.05); - sd::ops::layer_norm_bp op; - auto status = op.execute({&x, &gain, &bias, &gradO}, {&gradI, &gradG, &gradB}, {}, {1,2,3}, {true}); - ASSERT_EQ(Status::OK(), status); + sd::ops::layer_norm_bp op; + auto status = op.execute({&x, &gain, &bias, &gradO}, {&gradI, &gradG, &gradB}, + {}, {1, 2, 3}, {true}); + ASSERT_EQ(Status::OK(), status); } TEST_F(DeclarableOpsTests15, test_hashCode_1) { - auto x = NDArrayFactory::create('c', {10}); - auto y = NDArrayFactory::create('c', {10}); + auto x = NDArrayFactory::create('c', {10}); + auto y = NDArrayFactory::create('c', {10}); - x.linspace(1.); - y.linspace(2.); + x.linspace(1.); + y.linspace(2.); - sd::ops::hashcode op; - auto resultA0 = op.evaluate({&x}); - auto resultA1 = op.evaluate({&x}); - auto resultB0 = op.evaluate({&y}); -// resultA0->at(0)->printIndexedBuffer("A0"); -// resultA1->at(0)->printIndexedBuffer("A1"); -// resultB0->at(0)->printIndexedBuffer("B0"); - ASSERT_EQ(*resultA0.at(0), *resultA1.at(0)); - ASSERT_NE(*resultA0.at(0), *resultB0.at(0)); + sd::ops::hashcode op; + auto resultA0 = op.evaluate({&x}); + auto resultA1 = op.evaluate({&x}); + auto resultB0 = op.evaluate({&y}); + // resultA0->at(0)->printIndexedBuffer("A0"); + // resultA1->at(0)->printIndexedBuffer("A1"); + // resultB0->at(0)->printIndexedBuffer("B0"); + ASSERT_EQ(resultA0.at(0), resultA1.at(0)); + ASSERT_NE(resultA0.at(0), resultB0.at(0)); } TEST_F(DeclarableOpsTests15, test_hashCode_2) { - auto x = NDArrayFactory::create('c', {1027}); - auto y = NDArrayFactory::create('c', {1027}); + auto x = NDArrayFactory::create('c', {1027}); + auto y = NDArrayFactory::create('c', {1027}); - x.linspace(1.); - y.linspace(2.); + x.linspace(1.); + y.linspace(2.); - sd::ops::hashcode op; - auto resultA0 = op.evaluate({&x}); - auto resultA1 = op.evaluate({&x}); - auto resultB0 = op.evaluate({&y}); + sd::ops::hashcode op; + auto resultA0 = op.evaluate({&x}); + auto resultA1 = op.evaluate({&x}); + auto resultB0 = op.evaluate({&y}); -// resultA0->at(0)->printIndexedBuffer("A0"); -// resultA1->at(0)->printIndexedBuffer("A1"); -// resultB0->at(0)->printIndexedBuffer("B0"); + // resultA0->at(0)->printIndexedBuffer("A0"); + // resultA1->at(0)->printIndexedBuffer("A1"); + // resultB0->at(0)->printIndexedBuffer("B0"); - ASSERT_EQ(*resultA0.at(0), *resultA1.at(0)); - ASSERT_NE(*resultA0.at(0), *resultB0.at(0)); + ASSERT_EQ(resultA0.at(0), resultA1.at(0)); + ASSERT_NE(resultA0.at(0), resultB0.at(0)); } TEST_F(DeclarableOpsTests15, test_rank_1) { - auto array = NDArrayFactory::create('c', {4, 64}); - auto e = NDArrayFactory::create('c', {}, {2}); - auto z = NDArrayFactory::create('c', {}); + auto array = NDArrayFactory::create('c', {4, 64}); + auto e = NDArrayFactory::create('c', {}, {2}); + auto z = NDArrayFactory::create('c', {}); - sd::ops::rank op; - auto result = op.execute({&array}, {&z}, {}, {}, {}); - ASSERT_EQ(Status::OK(), result); - ASSERT_EQ(e, z); + sd::ops::rank op; + auto result = op.execute({&array}, {&z}, {}, {}, {}); + ASSERT_EQ(Status::OK(), result); + ASSERT_EQ(e, z); } TEST_F(DeclarableOpsTests15, test_rank_2) { - auto array = NDArrayFactory::create('c', {4, 64}); - auto e = NDArrayFactory::create('c', {}, {2}); - - sd::ops::rank op; - auto result = op.evaluate({&array}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); + auto array = NDArrayFactory::create('c', {4, 64}); + auto e = NDArrayFactory::create('c', {}, {2}); - auto z = result.at(0); - - ASSERT_EQ(e, *z); + sd::ops::rank op; + auto result = op.evaluate({&array}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + ASSERT_EQ(e, z); } TEST_F(DeclarableOpsTests15, test_lstmBlock_1) { - auto x0 = NDArrayFactory::create(5); - auto x1 = NDArrayFactory::create('c', {5, 1, 4}, {0.7787856f, 0.80119777f, 0.72437465f, 0.23089433f, 0.72714126f, 0.18039072f, 0.50563407f, 0.89252293f, 0.5461209f, 0.92336726f, 0.085571885f, 0.7937801f, 0.65908563f, 0.55552566f, 0.15962744f, 0.30874777f, 0.15476847f, 0.46954823f, 0.9938899f, 0.6112741f}); - auto x2 = NDArrayFactory::create('c', {1, 3}, {0.7717289f, 0.9280778f, 0.98455656f}); - auto x3 = NDArrayFactory::create('c', {1, 3}, {0.94414854f, 0.5956861f, 0.8668989f}); - auto x4 = NDArrayFactory::create('c', {7, 12}, {0.460692f, 0.042572856f, 0.08420354f, -0.09538093f, -0.11416581f, -0.53166187f, 0.40133476f, -0.24381405f, 0.30778718f, 0.52713746f, 0.16253126f, -0.034891903f, 0.011679292f, -0.19076681f, 0.14710993f, -0.3704369f, 0.51872355f, 0.13536876f, -0.5568739f, -0.08727971f, 0.07601875f, -0.074174374f, -0.5345982f, -0.3581748f, -0.28263924f, -0.25141674f, 0.43328637f, -0.50227314f, -0.26641843f, -0.38241976f, -0.19636461f, -0.04020852f, -0.27312332f, 0.5207915f, -0.37247592f, -0.4713087f, -0.25670746f, -0.14942765f, -0.015806139f, -0.22531253f, 0.5582536f, 0.3093416f, 0.3221351f, -0.0964683f, 0.14318448f, 0.42279094f, -0.46992f, -0.43399644f, -0.51704615f, -0.11854091f, 0.21697259f, -0.049382925f, 0.14059627f, 0.3912331f, -0.41345632f, 0.5067368f, -0.3420229f, 0.485789f, 0.044918716f, 0.26209074f, 0.12357575f, 0.21778125f, -0.53791714f, 0.18346387f, 0.054183125f, 0.5480431f, 0.03675288f, -0.26656917f, -0.018610716f, 0.19917983f, 0.5566165f, 0.43570566f, -0.35720813f, 0.31097364f, -0.47134516f, -0.289197f, 0.091138184f, 0.13300979f, -0.36592877f, -0.17540845f, 0.21732038f, 0.4393713f, 0.42800313f, 0.5006979f}); - auto x5 = NDArrayFactory::create('c', {1, 3}); - auto x6 = NDArrayFactory::create('c', {1, 3}); - auto x7 = NDArrayFactory::create('c', {1, 3}); - auto x8 = NDArrayFactory::create('c', {12}); - - sd::ops::lstmBlock op; - auto result = op.evaluate({&x0, &x1, &x2, &x3, &x4, &x5, &x6, &x7, &x8}, {2.0, 0.3}, {0, 0}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - - // z->printIndexedBuffer("Z"); + auto x0 = NDArrayFactory::create(5); + auto x1 = NDArrayFactory::create( + 'c', {5, 1, 4}, + {0.7787856f, 0.80119777f, 0.72437465f, 0.23089433f, 0.72714126f, + 0.18039072f, 0.50563407f, 0.89252293f, 0.5461209f, 0.92336726f, + 0.085571885f, 0.7937801f, 0.65908563f, 0.55552566f, 0.15962744f, + 0.30874777f, 0.15476847f, 0.46954823f, 0.9938899f, 0.6112741f}); + auto x2 = NDArrayFactory::create( + 'c', {1, 3}, {0.7717289f, 0.9280778f, 0.98455656f}); + auto x3 = NDArrayFactory::create( + 'c', {1, 3}, {0.94414854f, 0.5956861f, 0.8668989f}); + auto x4 = NDArrayFactory::create( + 'c', {7, 12}, + {0.460692f, 0.042572856f, 0.08420354f, -0.09538093f, -0.11416581f, + -0.53166187f, 0.40133476f, -0.24381405f, 0.30778718f, 0.52713746f, + 0.16253126f, -0.034891903f, 0.011679292f, -0.19076681f, 0.14710993f, + -0.3704369f, 0.51872355f, 0.13536876f, -0.5568739f, -0.08727971f, + 0.07601875f, -0.074174374f, -0.5345982f, -0.3581748f, -0.28263924f, + -0.25141674f, 0.43328637f, -0.50227314f, -0.26641843f, -0.38241976f, + -0.19636461f, -0.04020852f, -0.27312332f, 0.5207915f, -0.37247592f, + -0.4713087f, -0.25670746f, -0.14942765f, -0.015806139f, -0.22531253f, + 0.5582536f, 0.3093416f, 0.3221351f, -0.0964683f, 0.14318448f, + 0.42279094f, -0.46992f, -0.43399644f, -0.51704615f, -0.11854091f, + 0.21697259f, -0.049382925f, 0.14059627f, 0.3912331f, -0.41345632f, + 0.5067368f, -0.3420229f, 0.485789f, 0.044918716f, 0.26209074f, + 0.12357575f, 0.21778125f, -0.53791714f, 0.18346387f, 0.054183125f, + 0.5480431f, 0.03675288f, -0.26656917f, -0.018610716f, 0.19917983f, + 0.5566165f, 0.43570566f, -0.35720813f, 0.31097364f, -0.47134516f, + -0.289197f, 0.091138184f, 0.13300979f, -0.36592877f, -0.17540845f, + 0.21732038f, 0.4393713f, 0.42800313f, 0.5006979f}); + auto x5 = NDArrayFactory::create('c', {1, 3}); + auto x6 = NDArrayFactory::create('c', {1, 3}); + auto x7 = NDArrayFactory::create('c', {1, 3}); + auto x8 = NDArrayFactory::create('c', {12}); + + sd::ops::lstmBlock op; + auto result = op.evaluate({&x0, &x1, &x2, &x3, &x4, &x5, &x6, &x7, &x8}, + {2.0, 0.3}, {0, 0}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + // z->printIndexedBuffer("Z"); } TEST_F(DeclarableOpsTests15, test_lstmBlock_2) { - int seqLen = 8; - int bS = 16; - int nIn = 8; - - auto x0 = NDArrayFactory::create(5); - auto x1 = NDArrayFactory::create('f', {bS, nIn, seqLen}); - auto x2 = NDArrayFactory::create('f', {bS, nIn}); // nIn == nOut - auto x3 = NDArrayFactory::create('f', {bS, nIn}); - auto x4 = NDArrayFactory::create('f', {2 * nIn, 4 * nIn}); - auto x5 = NDArrayFactory::create('f', {nIn}); - auto x6 = NDArrayFactory::create('f', {nIn}); - auto x7 = NDArrayFactory::create('f', {nIn}); - auto x8 = NDArrayFactory::create('f', {4 * nIn}); + int seqLen = 8; + int bS = 16; + int nIn = 8; - sd::ops::lstmBlock op; - auto result = op.evaluate({&x0, &x1, &x2, &x3, &x4, &x5, &x6, &x7, &x8}, {1.0, 0.0}, {0, 1}); - ASSERT_EQ(Status::OK(), result.status()); + auto x0 = NDArrayFactory::create(5); + auto x1 = NDArrayFactory::create('f', {bS, nIn, seqLen}); + auto x2 = NDArrayFactory::create('f', {bS, nIn}); // nIn == nOut + auto x3 = NDArrayFactory::create('f', {bS, nIn}); + auto x4 = NDArrayFactory::create('f', {2 * nIn, 4 * nIn}); + auto x5 = NDArrayFactory::create('f', {nIn}); + auto x6 = NDArrayFactory::create('f', {nIn}); + auto x7 = NDArrayFactory::create('f', {nIn}); + auto x8 = NDArrayFactory::create('f', {4 * nIn}); - auto z = result.at(0); + sd::ops::lstmBlock op; + auto result = op.evaluate({&x0, &x1, &x2, &x3, &x4, &x5, &x6, &x7, &x8}, + {1.0, 0.0}, {0, 1}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); } TEST_F(DeclarableOpsTests15, test_lstmBlock_3) { + int seqLen = 3; + int bS = 2; + int nIn = 4; - int seqLen = 3; - int bS = 2; - int nIn = 4; - - NDArray f('f', {bS, nIn, seqLen}, sd::DataType::FLOAT32); - NDArray cLast('f', {bS, nIn}, sd::DataType::FLOAT32); + NDArray f('f', {bS, nIn, seqLen}, sd::DataType::FLOAT32); + NDArray cLast('f', {bS, nIn}, sd::DataType::FLOAT32); - f = 2; - cLast = 3; + f = 2; + cLast = 3; - for (int t = 0; t < seqLen; ++t) { + for (int t = 0; t < seqLen; ++t) { + // section 1 + // auto ft = f({0,0, 0,0, t,t+1}); + // auto temp = ft * cLast; - //section 1 - //auto ft = f({0,0, 0,0, t,t+1}); - //auto temp = ft * cLast; - - - // section 2 - auto ft = f({0,0, 0,0, t,t+1}); - auto temp1 = ft.reshape('f', {bS, nIn}); - auto temp2 = temp1 * cLast; - } + // section 2 + auto ft = f({0, 0, 0, 0, t, t + 1}); + auto temp1 = ft.reshape('f', {bS, nIn}); + auto temp2 = temp1 * cLast; + } } TEST_F(DeclarableOpsTests15, test_empty_increasing_1) { - auto x = NDArrayFactory::create('c', {1, 0, 3}); - auto z = NDArrayFactory::create(false); + auto x = NDArrayFactory::create('c', {1, 0, 3}); + auto z = NDArrayFactory::create(false); - Context ctx(1); - ctx.setInputArray(0, &x); - ctx.setOutputArray(0, &z); + Context ctx(1); + ctx.setInputArray(0, x); + ctx.setOutputArray(0, z); - sd::ops::is_strictly_increasing op; - auto status = op.execute(&ctx); - ASSERT_EQ(Status::OK(), status); + sd::ops::is_strictly_increasing op; + auto status = op.execute(&ctx); + ASSERT_EQ(Status::OK(), status); - ASSERT_EQ(true, z.e(0)); + ASSERT_EQ(true, z.e(0)); } TEST_F(DeclarableOpsTests15, test_empty_decreasing_1) { - auto x = NDArrayFactory::create('c', {1, 0, 3}); - auto z = NDArrayFactory::create(false); + auto x = NDArrayFactory::create('c', {1, 0, 3}); + auto z = NDArrayFactory::create(false); - Context ctx(1); - ctx.setInputArray(0, &x); - ctx.setOutputArray(0, &z); + Context ctx(1); + ctx.setInputArray(0, x); + ctx.setOutputArray(0, z); - sd::ops::is_non_decreasing op; - auto status = op.execute(&ctx); - ASSERT_EQ(Status::OK(), status); + sd::ops::is_non_decreasing op; + auto status = op.execute(&ctx); + ASSERT_EQ(Status::OK(), status); - ASSERT_EQ(true, z.e(0)); + ASSERT_EQ(true, z.e(0)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests15, test_rgb_to_grs_1) { - // rank 1 - NDArray rgbs('c', { 3 }, { 10, 50, 200 }, sd::DataType::INT32); - NDArray expected('c', { 1 }, std::vector{ 55 }, sd::DataType::INT32); - sd::ops::rgb_to_grs op; - auto result = op.evaluate({&rgbs}, {}, {}); - auto output = result.at(0); + // rank 1 + NDArray rgbs('c', {3}, {10, 50, 200}, sd::DataType::INT32); + NDArray expected('c', {1}, std::vector{55}, sd::DataType::INT32); + sd::ops::rgb_to_grs op; + auto result = op.evaluate({&rgbs}, {}, {}); + auto output = result.at(0); - ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests15, test_rgb_to_grs_2) { - // rank 1 - auto rgbs = NDArrayFactory::create('f', { 3 }, { 1, 120, -25 }); - auto expected = NDArrayFactory::create('f', { 1 }, { 67 }); - sd::ops::rgb_to_grs op; - auto result = op.evaluate({ &rgbs }, {}, {}); - auto output = result.at(0); + // rank 1 + auto rgbs = NDArrayFactory::create('f', {3}, {1, 120, -25}); + auto expected = NDArrayFactory::create('f', {1}, {67}); + sd::ops::rgb_to_grs op; + auto result = op.evaluate({&rgbs}, {}, {}); + auto output = result.at(0); - ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests15, test_rgb_to_grs_3) { - // rank 2 - NDArray rgbs('c', { 4, 3 }, { -94, 99, 97, 90, 114, 101, 111, 96, 105, 100, 103, 102 }, sd::DataType::INT32); - NDArray expected('c', { 4, 1 }, { 41, 105, 101, 101 }, sd::DataType::INT32); - sd::ops::rgb_to_grs op; - auto result = op.evaluate({ &rgbs }, {}, {}); - auto output = result.at(0); + // rank 2 + NDArray rgbs('c', {4, 3}, + {-94, 99, 97, 90, 114, 101, 111, 96, 105, 100, 103, 102}, + sd::DataType::INT32); + NDArray expected('c', {4, 1}, {41, 105, 101, 101}, sd::DataType::INT32); + sd::ops::rgb_to_grs op; + auto result = op.evaluate({&rgbs}, {}, {}); + auto output = result.at(0); - ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests15, test_rgb_to_grs_4) { + NDArray rgbs('c', {3, 2}, {14, 99, 207, 10, 114, 201}, sd::DataType::INT32); - NDArray rgbs('c', { 3, 2 }, {14, 99, 207, 10, 114, 201 }, sd::DataType::INT32); - - rgbs.permutei({1,0}); - NDArray expected('c', { 2, 1 }, { 138, 58 }, sd::DataType::INT32); - sd::ops::rgb_to_grs op; - auto result = op.evaluate({ &rgbs }, {}, {}); - auto output = result.at(0); + rgbs.permutei({1, 0}); + NDArray expected('c', {2, 1}, {138, 58}, sd::DataType::INT32); + sd::ops::rgb_to_grs op; + auto result = op.evaluate({&rgbs}, {}, {}); + auto output = result.at(0); - ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests15, test_rgb_to_grs_5) { - // rank 2 - NDArray rgbs('c', { 3, 4 }, { -94, 99, 97, 90, 114, 101, 111, 96, 105, 100, 103, 102 }, sd::DataType::INT32); - NDArray expected('c', { 1, 4 }, { 50, 100, 105, 94 }, sd::DataType::INT32); - sd::ops::rgb_to_grs op; - auto result = op.evaluate({ &rgbs }, {}, {0}); - auto output = result.at(0); + // rank 2 + NDArray rgbs('c', {3, 4}, + {-94, 99, 97, 90, 114, 101, 111, 96, 105, 100, 103, 102}, + sd::DataType::INT32); + NDArray expected('c', {1, 4}, {50, 100, 105, 94}, sd::DataType::INT32); + sd::ops::rgb_to_grs op; + auto result = op.evaluate({&rgbs}, {}, {0}); + auto output = result.at(0); - ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests15, test_rgb_to_grs_6) { - // rank 3 - auto rgbs = NDArrayFactory::create('c', { 5,4,3 }, {1.7750e+01f, -7.1062e+01f, -1.0019e+02f,-2.3406e+01f, 5.2094e+01f, 9.5438e+01f, -6.7461e+00f, 3.8562e+01f, 6.5078e+00f,3.3562e+01f, -5.8844e+01f, 2.2750e+01f, -1.0477e+01f, 7.7344e+00f, 9.5469e+00f,2.1391e+01f, -8.5312e+01f, 7.5830e-01f,2.3125e+01f, 1.8145e+00f, 1.4602e+01f,-4.5859e+00f, 3.9344e+01f, 1.1617e+01f,-8.6562e+01f, 1.0038e+02f, 6.7938e+01f,5.9961e+00f, 6.7812e+01f, 2.9734e+01f,2.9609e+01f, -6.1438e+01f, 1.7750e+01f,6.8562e+01f, -7.4414e+00f, 3.9656e+01f,1.1641e+01f, -2.7516e+01f, 6.7562e+01f,7.8438e+01f, 5.4883e+00f, 2.9438e+01f,-3.1344e+01f, 6.5125e+01f, 1.2695e+01f,4.0531e+01f, -6.1211e+00f, 6.2219e+01f,4.6812e+01f, 5.2250e+01f, -1.1414e+01f,1.5404e-02f, 2.9938e+01f, 5.6719e+00f,-2.0125e+01f, 2.1531e+01f, 6.2500e+01f,7.2188e+01f, 9.3750e+00f, -4.8125e+01f}); - auto expected = NDArrayFactory::create('c', { 5,4,1 }, {-47.82958221f, 34.46305847f, 21.36137581f, -21.91625023f,2.49686432f, -43.59792709f, 9.64180183f, 23.04854202f,40.7946167f, 44.98754883f, -25.19047546f, 20.64586449f,-4.97033119f, 30.0226841f, 30.30688286f, 15.61459541f,43.36166f, 18.22480774f, 13.74833488f, 21.59387016f}); - - sd::ops::rgb_to_grs op; - auto result = op.evaluate({ &rgbs }, {}, {}); - auto output = result.at(0); - - ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + // rank 3 + auto rgbs = NDArrayFactory::create( + 'c', {5, 4, 3}, + {1.7750e+01f, -7.1062e+01f, -1.0019e+02f, -2.3406e+01f, 5.2094e+01f, + 9.5438e+01f, -6.7461e+00f, 3.8562e+01f, 6.5078e+00f, 3.3562e+01f, + -5.8844e+01f, 2.2750e+01f, -1.0477e+01f, 7.7344e+00f, 9.5469e+00f, + 2.1391e+01f, -8.5312e+01f, 7.5830e-01f, 2.3125e+01f, 1.8145e+00f, + 1.4602e+01f, -4.5859e+00f, 3.9344e+01f, 1.1617e+01f, -8.6562e+01f, + 1.0038e+02f, 6.7938e+01f, 5.9961e+00f, 6.7812e+01f, 2.9734e+01f, + 2.9609e+01f, -6.1438e+01f, 1.7750e+01f, 6.8562e+01f, -7.4414e+00f, + 3.9656e+01f, 1.1641e+01f, -2.7516e+01f, 6.7562e+01f, 7.8438e+01f, + 5.4883e+00f, 2.9438e+01f, -3.1344e+01f, 6.5125e+01f, 1.2695e+01f, + 4.0531e+01f, -6.1211e+00f, 6.2219e+01f, 4.6812e+01f, 5.2250e+01f, + -1.1414e+01f, 1.5404e-02f, 2.9938e+01f, 5.6719e+00f, -2.0125e+01f, + 2.1531e+01f, 6.2500e+01f, 7.2188e+01f, 9.3750e+00f, -4.8125e+01f}); + auto expected = NDArrayFactory::create( + 'c', {5, 4, 1}, + {-47.82958221f, 34.46305847f, 21.36137581f, -21.91625023f, 2.49686432f, + -43.59792709f, 9.64180183f, 23.04854202f, 40.7946167f, 44.98754883f, + -25.19047546f, 20.64586449f, -4.97033119f, 30.0226841f, 30.30688286f, + 15.61459541f, 43.36166f, 18.22480774f, 13.74833488f, 21.59387016f}); + + sd::ops::rgb_to_grs op; + auto result = op.evaluate({&rgbs}, {}, {}); + auto output = result.at(0); + + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests15, test_rgb_to_grs_7) { - // rank 3 - auto rgbs = NDArrayFactory::create('c', { 5,3,4 }, { 1.7750e+01f, -7.1062e+01f, -1.0019e+02f,-2.3406e+01f, 5.2094e+01f, 9.5438e+01f, -6.7461e+00f, 3.8562e+01f, 6.5078e+00f,3.3562e+01f, -5.8844e+01f, 2.2750e+01f, -1.0477e+01f, 7.7344e+00f, 9.5469e+00f,2.1391e+01f, -8.5312e+01f, 7.5830e-01f,2.3125e+01f, 1.8145e+00f, 1.4602e+01f,-4.5859e+00f, 3.9344e+01f, 1.1617e+01f,-8.6562e+01f, 1.0038e+02f, 6.7938e+01f,5.9961e+00f, 6.7812e+01f, 2.9734e+01f,2.9609e+01f, -6.1438e+01f, 1.7750e+01f,6.8562e+01f, -7.4414e+00f, 3.9656e+01f,1.1641e+01f, -2.7516e+01f, 6.7562e+01f,7.8438e+01f, 5.4883e+00f, 2.9438e+01f,-3.1344e+01f, 6.5125e+01f, 1.2695e+01f,4.0531e+01f, -6.1211e+00f, 6.2219e+01f,4.6812e+01f, 5.2250e+01f, -1.1414e+01f,1.5404e-02f, 2.9938e+01f, 5.6719e+00f,-2.0125e+01f, 2.1531e+01f, 6.2500e+01f,7.2188e+01f, 9.3750e+00f, -4.8125e+01f}); - auto expected = NDArrayFactory::create('c', { 5,1,4 }, { 36.626545f, 38.607746f, -40.614971f, 18.233341f, -51.545094f,2.234142f, 20.913160f, 8.783220f, 15.955761f, 55.273506f, 36.838833f, -29.751089f, 8.148357f, 13.676106f, 1.097548f, 68.766457f, 38.690712f, 27.176361f, -14.156269f, 7.157052f }); - - sd::ops::rgb_to_grs op; - auto result = op.evaluate({ &rgbs }, {}, {1}); - auto output = result.at(0); - - ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + // rank 3 + auto rgbs = NDArrayFactory::create( + 'c', {5, 3, 4}, + {1.7750e+01f, -7.1062e+01f, -1.0019e+02f, -2.3406e+01f, 5.2094e+01f, + 9.5438e+01f, -6.7461e+00f, 3.8562e+01f, 6.5078e+00f, 3.3562e+01f, + -5.8844e+01f, 2.2750e+01f, -1.0477e+01f, 7.7344e+00f, 9.5469e+00f, + 2.1391e+01f, -8.5312e+01f, 7.5830e-01f, 2.3125e+01f, 1.8145e+00f, + 1.4602e+01f, -4.5859e+00f, 3.9344e+01f, 1.1617e+01f, -8.6562e+01f, + 1.0038e+02f, 6.7938e+01f, 5.9961e+00f, 6.7812e+01f, 2.9734e+01f, + 2.9609e+01f, -6.1438e+01f, 1.7750e+01f, 6.8562e+01f, -7.4414e+00f, + 3.9656e+01f, 1.1641e+01f, -2.7516e+01f, 6.7562e+01f, 7.8438e+01f, + 5.4883e+00f, 2.9438e+01f, -3.1344e+01f, 6.5125e+01f, 1.2695e+01f, + 4.0531e+01f, -6.1211e+00f, 6.2219e+01f, 4.6812e+01f, 5.2250e+01f, + -1.1414e+01f, 1.5404e-02f, 2.9938e+01f, 5.6719e+00f, -2.0125e+01f, + 2.1531e+01f, 6.2500e+01f, 7.2188e+01f, 9.3750e+00f, -4.8125e+01f}); + auto expected = NDArrayFactory::create( + 'c', {5, 1, 4}, + {36.626545f, 38.607746f, -40.614971f, 18.233341f, -51.545094f, + 2.234142f, 20.913160f, 8.783220f, 15.955761f, 55.273506f, + 36.838833f, -29.751089f, 8.148357f, 13.676106f, 1.097548f, + 68.766457f, 38.690712f, 27.176361f, -14.156269f, 7.157052f}); + + sd::ops::rgb_to_grs op; + auto result = op.evaluate({&rgbs}, {}, {1}); + auto output = result.at(0); + + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests15, test_rgb_to_grs_8) { - // rank 3 - auto rgbs = NDArrayFactory::create('c', { 3,5,4 }, {1.7750e+01f, -7.1062e+01f, -1.0019e+02f,-2.3406e+01f, 5.2094e+01f, 9.5438e+01f, -6.7461e+00f, 3.8562e+01f, 6.5078e+00f,3.3562e+01f, -5.8844e+01f, 2.2750e+01f, -1.0477e+01f, 7.7344e+00f, 9.5469e+00f,2.1391e+01f, -8.5312e+01f, 7.5830e-01f,2.3125e+01f, 1.8145e+00f, 1.4602e+01f,-4.5859e+00f, 3.9344e+01f, 1.1617e+01f,-8.6562e+01f, 1.0038e+02f, 6.7938e+01f,5.9961e+00f, 6.7812e+01f, 2.9734e+01f,2.9609e+01f, -6.1438e+01f, 1.7750e+01f,6.8562e+01f, -7.4414e+00f, 3.9656e+01f,1.1641e+01f, -2.7516e+01f, 6.7562e+01f,7.8438e+01f, 5.4883e+00f, 2.9438e+01f,-3.1344e+01f, 6.5125e+01f, 1.2695e+01f,4.0531e+01f, -6.1211e+00f, 6.2219e+01f,4.6812e+01f, 5.2250e+01f, -1.1414e+01f,1.5404e-02f, 2.9938e+01f, 5.6719e+00f,-2.0125e+01f, 2.1531e+01f, 6.2500e+01f,7.2188e+01f, 9.3750e+00f, -4.8125e+01f}); - try { - sd::ops::rgb_to_grs op; - auto result = op.evaluate({ &rgbs }, {}, {}); - ASSERT_EQ(Status::THROW(), result.status()); - - } catch (std::exception& e) { - nd4j_printf("Error should be here `%s'. It's OK.\n", e.what()); - } + // rank 3 + auto rgbs = NDArrayFactory::create( + 'c', {3, 5, 4}, + {1.7750e+01f, -7.1062e+01f, -1.0019e+02f, -2.3406e+01f, 5.2094e+01f, + 9.5438e+01f, -6.7461e+00f, 3.8562e+01f, 6.5078e+00f, 3.3562e+01f, + -5.8844e+01f, 2.2750e+01f, -1.0477e+01f, 7.7344e+00f, 9.5469e+00f, + 2.1391e+01f, -8.5312e+01f, 7.5830e-01f, 2.3125e+01f, 1.8145e+00f, + 1.4602e+01f, -4.5859e+00f, 3.9344e+01f, 1.1617e+01f, -8.6562e+01f, + 1.0038e+02f, 6.7938e+01f, 5.9961e+00f, 6.7812e+01f, 2.9734e+01f, + 2.9609e+01f, -6.1438e+01f, 1.7750e+01f, 6.8562e+01f, -7.4414e+00f, + 3.9656e+01f, 1.1641e+01f, -2.7516e+01f, 6.7562e+01f, 7.8438e+01f, + 5.4883e+00f, 2.9438e+01f, -3.1344e+01f, 6.5125e+01f, 1.2695e+01f, + 4.0531e+01f, -6.1211e+00f, 6.2219e+01f, 4.6812e+01f, 5.2250e+01f, + -1.1414e+01f, 1.5404e-02f, 2.9938e+01f, 5.6719e+00f, -2.0125e+01f, + 2.1531e+01f, 6.2500e+01f, 7.2188e+01f, 9.3750e+00f, -4.8125e+01f}); + try { + sd::ops::rgb_to_grs op; + auto result = op.evaluate({&rgbs}, {}, {}); + ASSERT_EQ(Status::THROW(), result.status()); + + } catch (std::exception& e) { + nd4j_printf("Error should be here `%s'. It's OK.\n", e.what()); + } } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests15, test_rgb_to_grs_9) { - // rank 3 - auto rgbs = NDArrayFactory::create('f', { 2, 2, 3 }, { 1.7750e+01f,-7.1062e+01f, -1.0019e+02f, -2.3406e+01f,5.2094e+01f,9.5438e+01f, -6.7461e+00f,3.8562e+01f, 6.5078e+00f, 3.3562e+01f,-5.8844e+01f,2.2750e+01f}); - auto expected = NDArrayFactory::create('f', { 2,2,1 }, { 36.626545f, 38.607746f, -40.614971f, 18.233341f }); + // rank 3 + auto rgbs = NDArrayFactory::create( + 'f', {2, 2, 3}, + {1.7750e+01f, -7.1062e+01f, -1.0019e+02f, -2.3406e+01f, 5.2094e+01f, + 9.5438e+01f, -6.7461e+00f, 3.8562e+01f, 6.5078e+00f, 3.3562e+01f, + -5.8844e+01f, 2.2750e+01f}); + auto expected = NDArrayFactory::create( + 'f', {2, 2, 1}, {36.626545f, 38.607746f, -40.614971f, 18.233341f}); - sd::ops::rgb_to_grs op; - auto result = op.evaluate({ &rgbs }, {}, {}); - auto output = result.at(0); + sd::ops::rgb_to_grs op; + auto result = op.evaluate({&rgbs}, {}, {}); + auto output = result.at(0); - ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests15, test_rgb_to_yuv_1) { - // rank 1 - NDArray rgbs('f', { 3 }, { 10, 50, 200 }, sd::DataType::FLOAT32); - NDArray expected('f', { 3 }, { 55.14 , 71.2872001, -39.6005542 }, sd::DataType::FLOAT32); - sd::ops::rgb_to_yuv op; - auto result = op.evaluate({ &rgbs }, {}, {}); - auto output = result.at(0); + // rank 1 + NDArray rgbs('f', {3}, {10, 50, 200}, sd::DataType::FLOAT32); + NDArray expected('f', {3}, {55.14, 71.2872001, -39.6005542}, + sd::DataType::FLOAT32); + sd::ops::rgb_to_yuv op; + auto result = op.evaluate({&rgbs}, {}, {}); + auto output = result.at(0); - ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests15, test_rgb_to_yuv_2) { + NDArray rgbs('c', {3, 2}, {14., 99., 207., 10., 114., 201.}, + sd::DataType::FLOAT32); + rgbs.permutei({1, 0}); - NDArray rgbs('c', { 3, 2 }, { 14., 99., 207., 10., 114., 201. }, sd::DataType::FLOAT32); - rgbs.permutei({ 1,0 }); - - NDArray expected('c', { 2, 3 }, { 138.691, -12.150713, -109.38929, 58.385, 70.18241, 35.63085 }, sd::DataType::FLOAT32); - sd::ops::rgb_to_yuv op; + NDArray expected( + 'c', {2, 3}, + {138.691, -12.150713, -109.38929, 58.385, 70.18241, 35.63085}, + sd::DataType::FLOAT32); + sd::ops::rgb_to_yuv op; - auto result = op.evaluate({ &rgbs }, {}, {}); - auto output = result.at(0); + auto result = op.evaluate({&rgbs}, {}, {}); + auto output = result.at(0); - ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests15, test_rgb_to_yuv_3) { - // rank 2 - NDArray rgbs('c', { 3, 4 }, { -9.4, 9.9, 9.7, 9.0, 1.14, 1.01, 1.11, 9.6, 1.05, 10.0, 1.03, 10.22 }, sd::DataType::FLOAT32); - NDArray expected('c', { 3, 4 }, { -2.021720, 4.692970, 3.669290, 9.491281, 1.511627, 2.611648, -1.298824, 0.358612, -6.472839, 4.568039, 5.290639, -0.430992 }, sd::DataType::FLOAT32); - - sd::ops::rgb_to_yuv op; - auto result = op.evaluate({ &rgbs }, {}, { 0 }); - auto output = result.at(0); - ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + // rank 2 + NDArray rgbs( + 'c', {3, 4}, + {-9.4, 9.9, 9.7, 9.0, 1.14, 1.01, 1.11, 9.6, 1.05, 10.0, 1.03, 10.22}, + sd::DataType::FLOAT32); + NDArray expected( + 'c', {3, 4}, + {-2.021720, 4.692970, 3.669290, 9.491281, 1.511627, 2.611648, -1.298824, + 0.358612, -6.472839, 4.568039, 5.290639, -0.430992}, + sd::DataType::FLOAT32); + + sd::ops::rgb_to_yuv op; + auto result = op.evaluate({&rgbs}, {}, {0}); + auto output = result.at(0); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests15, test_rgb_to_yuv_4) { - // rank 3 - NDArray rgbs('c', { 5,4,3 }, { 1.7750e+01, 1.4602e+01, 5.4883e+00, 9.5438e+01, 1.0038e+02, 4.0531e+01, -5.8844e+01, 2.9609e+01, -1.1414e+01, 2.1391e+01, 3.9656e+01, 2.1531e+01, -7.1062e+01, -4.5859e+00, 2.9438e+01, -6.7461e+00, 6.7938e+01, -6.1211e+00, 2.2750e+01, -6.1438e+01, 1.5404e-02, -8.5312e+01, 1.1641e+01, 6.2500e+01, -1.0019e+02, 3.9344e+01, -3.1344e+01, 3.8562e+01, 5.9961e+00, 6.2219e+01, -1.0477e+01, 1.7750e+01, 2.9938e+01, 7.5830e-01, -2.7516e+01, 7.2188e+01, -2.3406e+01, 1.1617e+01, 6.5125e+01, 6.5078e+00, 6.7812e+01, 4.6812e+01, 7.7344e+00, 6.8562e+01, 5.6719e+00, 2.3125e+01, 6.7562e+01, 9.3750e+00, 5.2094e+01, -8.6562e+01, 1.2695e+01, 3.3562e+01, 2.9734e+01, 5.2250e+01, 9.5469e+00, -7.4414e+00, -2.0125e+01, 1.8145e+00, 7.8438e+01, -4.8125e+01 }, sd::DataType::FLOAT32); - NDArray expected('c', { 5,4,3 }, { 14.5042902, -4.43686799, 2.847406, 92.079556, -25.36761168, 2.94630572, -1.515069, -4.87137291, -50.29369639, 32.128515, -5.21515376, -9.41983935,-20.5835293, 24.61614501, -44.28390394, 37.1647167, -21.30142676, -38.52221293, -29.26009994, 14.40679768, 45.62757638, -11.550021, 36.44083018, -64.71012983,-10.435098, - 10.28950082, - 78.74044941, 22.1427147, 19.72198103, 14.40435988, 10.699559, 9.46744852, - 18.5778351 , -7.6957283, 39.31166179, 7.41657542, 7.245035, 28.48336771, - 26.88963173, 47.0880442, - 0.13584441, - 35.60035823, 43.2050762, - 18.47048906, - 31.11782117, 47.642019, - 18.83162118, - 21.50836396,-33.788558, 22.87507047, 75.34330791, 33.445396, 9.25395257, 0.10229474, -3.8078287, -8.02985955, 11.71587638, 41.0993915, -43.90830496, -34.46396749 }, sd::DataType::FLOAT32); - - sd::ops::rgb_to_yuv op; - auto result = op.evaluate({ &rgbs }, {}, {}); - auto output = result.at(0); - - ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + // rank 3 + NDArray rgbs( + 'c', {5, 4, 3}, + {1.7750e+01, 1.4602e+01, 5.4883e+00, 9.5438e+01, 1.0038e+02, + 4.0531e+01, -5.8844e+01, 2.9609e+01, -1.1414e+01, 2.1391e+01, + 3.9656e+01, 2.1531e+01, -7.1062e+01, -4.5859e+00, 2.9438e+01, + -6.7461e+00, 6.7938e+01, -6.1211e+00, 2.2750e+01, -6.1438e+01, + 1.5404e-02, -8.5312e+01, 1.1641e+01, 6.2500e+01, -1.0019e+02, + 3.9344e+01, -3.1344e+01, 3.8562e+01, 5.9961e+00, 6.2219e+01, + -1.0477e+01, 1.7750e+01, 2.9938e+01, 7.5830e-01, -2.7516e+01, + 7.2188e+01, -2.3406e+01, 1.1617e+01, 6.5125e+01, 6.5078e+00, + 6.7812e+01, 4.6812e+01, 7.7344e+00, 6.8562e+01, 5.6719e+00, + 2.3125e+01, 6.7562e+01, 9.3750e+00, 5.2094e+01, -8.6562e+01, + 1.2695e+01, 3.3562e+01, 2.9734e+01, 5.2250e+01, 9.5469e+00, + -7.4414e+00, -2.0125e+01, 1.8145e+00, 7.8438e+01, -4.8125e+01}, + sd::DataType::FLOAT32); + NDArray expected( + 'c', {5, 4, 3}, + {14.5042902, -4.43686799, 2.847406, 92.079556, -25.36761168, + 2.94630572, -1.515069, -4.87137291, -50.29369639, 32.128515, + -5.21515376, -9.41983935, -20.5835293, 24.61614501, -44.28390394, + 37.1647167, -21.30142676, -38.52221293, -29.26009994, 14.40679768, + 45.62757638, -11.550021, 36.44083018, -64.71012983, -10.435098, + -10.28950082, -78.74044941, 22.1427147, 19.72198103, 14.40435988, + 10.699559, 9.46744852, -18.5778351, -7.6957283, 39.31166179, + 7.41657542, 7.245035, 28.48336771, -26.88963173, 47.0880442, + -0.13584441, -35.60035823, 43.2050762, -18.47048906, -31.11782117, + 47.642019, -18.83162118, -21.50836396, -33.788558, 22.87507047, + 75.34330791, 33.445396, 9.25395257, 0.10229474, -3.8078287, + -8.02985955, 11.71587638, 41.0993915, -43.90830496, -34.46396749}, + sd::DataType::FLOAT32); + + sd::ops::rgb_to_yuv op; + auto result = op.evaluate({&rgbs}, {}, {}); + auto output = result.at(0); + + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests15, test_rgb_to_yuv_5) { - // rank 3 - NDArray rgbs('c', { 5,3,4 }, { 1.7750e+01f, -7.1062e+01f, -1.0019e+02f,-2.3406e+01f, 5.2094e+01f, 9.5438e+01f, -6.7461e+00f, 3.8562e+01f, 6.5078e+00f,3.3562e+01f, -5.8844e+01f, 2.2750e+01f, -1.0477e+01f, 7.7344e+00f, 9.5469e+00f,2.1391e+01f, -8.5312e+01f, 7.5830e-01f,2.3125e+01f, 1.8145e+00f, 1.4602e+01f,-4.5859e+00f, 3.9344e+01f, 1.1617e+01f,-8.6562e+01f, 1.0038e+02f, 6.7938e+01f,5.9961e+00f, 6.7812e+01f, 2.9734e+01f,2.9609e+01f, -6.1438e+01f, 1.7750e+01f,6.8562e+01f, -7.4414e+00f, 3.9656e+01f,1.1641e+01f, -2.7516e+01f, 6.7562e+01f,7.8438e+01f, 5.4883e+00f, 2.9438e+01f,-3.1344e+01f, 6.5125e+01f, 1.2695e+01f,4.0531e+01f, -6.1211e+00f, 6.2219e+01f,4.6812e+01f, 5.2250e+01f, -1.1414e+01f,1.5404e-02f, 2.9938e+01f, 5.6719e+00f,-2.0125e+01f, 2.1531e+01f, 6.2500e+01f,7.2188e+01f, 9.3750e+00f, -4.8125e+01f }, sd::DataType::FLOAT32); - NDArray expected('c', { 5,3,4 }, { 36.628319, 38.600643,-40.624989, 18.231001, - 14.822637, - 2.479566, - 8.965780, 2.223851, -16.561626,-96.205162,-52.255379,-36.527435,-51.546139,2.234915, 20.914114, 8.785358, 32.552223, -3.356598, 9.069552, 1.393482,36.029255, 4.824605,- 9.972263,11.058715, 15.947105, 55.283543, 36.845627, -29.750486,0.887228, 6.534475, -21.794132,34.155693, -89.929497,39.562351, 27.276817,31.359871, 8.149521, 13.673355, 1.104303, 68.774300, 2.236881, 13.216944, - 3.555702,- 3.225931,3.063015, - 36.134724,58.302204, 8.477802, 38.695396,27.181587, - 14.157411,7.157054, 11.714512, 22.148155, 11.580557, - 27.204905,7.120562, 21.992094, 2.406748, - 6.265247, }, sd::DataType::FLOAT32); - - sd::ops::rgb_to_yuv op; - auto result = op.evaluate({ &rgbs }, {}, { 1 }); - auto output = result.at(0); - - ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - + // rank 3 + NDArray rgbs( + 'c', {5, 3, 4}, + {1.7750e+01f, -7.1062e+01f, -1.0019e+02f, -2.3406e+01f, 5.2094e+01f, + 9.5438e+01f, -6.7461e+00f, 3.8562e+01f, 6.5078e+00f, 3.3562e+01f, + -5.8844e+01f, 2.2750e+01f, -1.0477e+01f, 7.7344e+00f, 9.5469e+00f, + 2.1391e+01f, -8.5312e+01f, 7.5830e-01f, 2.3125e+01f, 1.8145e+00f, + 1.4602e+01f, -4.5859e+00f, 3.9344e+01f, 1.1617e+01f, -8.6562e+01f, + 1.0038e+02f, 6.7938e+01f, 5.9961e+00f, 6.7812e+01f, 2.9734e+01f, + 2.9609e+01f, -6.1438e+01f, 1.7750e+01f, 6.8562e+01f, -7.4414e+00f, + 3.9656e+01f, 1.1641e+01f, -2.7516e+01f, 6.7562e+01f, 7.8438e+01f, + 5.4883e+00f, 2.9438e+01f, -3.1344e+01f, 6.5125e+01f, 1.2695e+01f, + 4.0531e+01f, -6.1211e+00f, 6.2219e+01f, 4.6812e+01f, 5.2250e+01f, + -1.1414e+01f, 1.5404e-02f, 2.9938e+01f, 5.6719e+00f, -2.0125e+01f, + 2.1531e+01f, 6.2500e+01f, 7.2188e+01f, 9.3750e+00f, -4.8125e+01f}, + sd::DataType::FLOAT32); + NDArray expected( + 'c', {5, 3, 4}, + { + 36.628319, 38.600643, -40.624989, 18.231001, -14.822637, + -2.479566, -8.965780, 2.223851, -16.561626, -96.205162, + -52.255379, -36.527435, -51.546139, 2.234915, 20.914114, + 8.785358, 32.552223, -3.356598, 9.069552, 1.393482, + 36.029255, 4.824605, -9.972263, 11.058715, 15.947105, + 55.283543, 36.845627, -29.750486, 0.887228, 6.534475, + -21.794132, 34.155693, -89.929497, 39.562351, 27.276817, + 31.359871, 8.149521, 13.673355, 1.104303, 68.774300, + 2.236881, 13.216944, -3.555702, -3.225931, 3.063015, + -36.134724, 58.302204, 8.477802, 38.695396, 27.181587, + -14.157411, 7.157054, 11.714512, 22.148155, 11.580557, + -27.204905, 7.120562, 21.992094, 2.406748, -6.265247, + }, + sd::DataType::FLOAT32); + + sd::ops::rgb_to_yuv op; + auto result = op.evaluate({&rgbs}, {}, {1}); + auto output = result.at(0); + + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests15, test_rgb_to_yuv_6) { - // rank 3 - NDArray rgbs('c', { 3,5,4 }, { 1.7750e+01f, -7.1062e+01f, -1.0019e+02f,-2.3406e+01f, 5.2094e+01f, 9.5438e+01f, -6.7461e+00f, 3.8562e+01f, 6.5078e+00f,3.3562e+01f, -5.8844e+01f, 2.2750e+01f, -1.0477e+01f, 7.7344e+00f, 9.5469e+00f,2.1391e+01f, -8.5312e+01f, 7.5830e-01f,2.3125e+01f, 1.8145e+00f, 1.4602e+01f,-4.5859e+00f, 3.9344e+01f, 1.1617e+01f,-8.6562e+01f, 1.0038e+02f, 6.7938e+01f,5.9961e+00f, 6.7812e+01f, 2.9734e+01f,2.9609e+01f, -6.1438e+01f, 1.7750e+01f,6.8562e+01f, -7.4414e+00f, 3.9656e+01f,1.1641e+01f, -2.7516e+01f, 6.7562e+01f,7.8438e+01f, 5.4883e+00f, 2.9438e+01f,-3.1344e+01f, 6.5125e+01f, 1.2695e+01f,4.0531e+01f, -6.1211e+00f, 6.2219e+01f,4.6812e+01f, 5.2250e+01f, -1.1414e+01f,1.5404e-02f, 2.9938e+01f, 5.6719e+00f,-2.0125e+01f, 2.1531e+01f, 6.2500e+01f,7.2188e+01f, 9.3750e+00f, -4.8125e+01f }, sd::DataType::FLOAT32); - try { - sd::ops::rgb_to_yuv op; - auto result = op.evaluate({ &rgbs }, {}, {}); - ASSERT_EQ(Status::THROW(), result.status()); + // rank 3 + NDArray rgbs( + 'c', {3, 5, 4}, + {1.7750e+01f, -7.1062e+01f, -1.0019e+02f, -2.3406e+01f, 5.2094e+01f, + 9.5438e+01f, -6.7461e+00f, 3.8562e+01f, 6.5078e+00f, 3.3562e+01f, + -5.8844e+01f, 2.2750e+01f, -1.0477e+01f, 7.7344e+00f, 9.5469e+00f, + 2.1391e+01f, -8.5312e+01f, 7.5830e-01f, 2.3125e+01f, 1.8145e+00f, + 1.4602e+01f, -4.5859e+00f, 3.9344e+01f, 1.1617e+01f, -8.6562e+01f, + 1.0038e+02f, 6.7938e+01f, 5.9961e+00f, 6.7812e+01f, 2.9734e+01f, + 2.9609e+01f, -6.1438e+01f, 1.7750e+01f, 6.8562e+01f, -7.4414e+00f, + 3.9656e+01f, 1.1641e+01f, -2.7516e+01f, 6.7562e+01f, 7.8438e+01f, + 5.4883e+00f, 2.9438e+01f, -3.1344e+01f, 6.5125e+01f, 1.2695e+01f, + 4.0531e+01f, -6.1211e+00f, 6.2219e+01f, 4.6812e+01f, 5.2250e+01f, + -1.1414e+01f, 1.5404e-02f, 2.9938e+01f, 5.6719e+00f, -2.0125e+01f, + 2.1531e+01f, 6.2500e+01f, 7.2188e+01f, 9.3750e+00f, -4.8125e+01f}, + sd::DataType::FLOAT32); + try { + sd::ops::rgb_to_yuv op; + auto result = op.evaluate({&rgbs}, {}, {}); + ASSERT_EQ(Status::THROW(), result.status()); - } - catch (std::exception & e) { - nd4j_printf("Error should be here `%s'. It's OK.\n", e.what()); - } + } catch (std::exception& e) { + nd4j_printf("Error should be here `%s'. It's OK.\n", e.what()); + } } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests15, test_rgb_to_yuv_7) { - // rank 3 - NDArray rgbs('f', { 2, 2, 3 }, { 1.7750e+01f,-7.1062e+01f, -1.0019e+02f, -2.3406e+01f,5.2094e+01f,9.5438e+01f, -6.7461e+00f,3.8562e+01f, 6.5078e+00f, 3.3562e+01f,-5.8844e+01f,2.2750e+01f }, sd::DataType::FLOAT32); - NDArray expected('f', { 2,2,3 }, { 36.628319,38.600643, -40.624989,18.231001, -14.822637,-2.479566, -8.965780, 2.223851, -16.561626,- 96.205162,-52.255379, -36.527435 }, sd::DataType::FLOAT32); - - sd::ops::rgb_to_yuv op; - auto result = op.evaluate({ &rgbs }, {}, {}); - auto output = result.at(0); - - ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + // rank 3 + NDArray rgbs('f', {2, 2, 3}, + {1.7750e+01f, -7.1062e+01f, -1.0019e+02f, -2.3406e+01f, + 5.2094e+01f, 9.5438e+01f, -6.7461e+00f, 3.8562e+01f, + 6.5078e+00f, 3.3562e+01f, -5.8844e+01f, 2.2750e+01f}, + sd::DataType::FLOAT32); + NDArray expected( + 'f', {2, 2, 3}, + {36.628319, 38.600643, -40.624989, 18.231001, -14.822637, -2.479566, + -8.965780, 2.223851, -16.561626, -96.205162, -52.255379, -36.527435}, + sd::DataType::FLOAT32); + + sd::ops::rgb_to_yuv op; + auto result = op.evaluate({&rgbs}, {}, {}); + auto output = result.at(0); + + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests15, test_yuv_to_rgb_1) { - // rank 1 - NDArray yuv('c', { 3 }, { 55.14 , 71.2872001, -39.6005542 }, sd::DataType::FLOAT32); - NDArray expected('c', { 3 }, { 10, 50, 200 }, sd::DataType::FLOAT32); - sd::ops::yuv_to_rgb op; - auto result = op.evaluate({ &yuv }, {}, {}); - auto output = result.at(0); - - ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - + // rank 1 + NDArray yuv('c', {3}, {55.14, 71.2872001, -39.6005542}, + sd::DataType::FLOAT32); + NDArray expected('c', {3}, {10, 50, 200}, sd::DataType::FLOAT32); + sd::ops::yuv_to_rgb op; + auto result = op.evaluate({&yuv}, {}, {}); + auto output = result.at(0); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests15, test_yuv_to_rgb_2) { - // rank 1 - NDArray yuv('f', { 3 }, { 55.14, 71.2872001, -39.6005542 }, sd::DataType::FLOAT32); - NDArray expected('f', { 3 }, { 10, 50, 200 }, sd::DataType::FLOAT32); - sd::ops::yuv_to_rgb op; - auto result = op.evaluate({ &yuv }, {}, {}); - auto output = result.at(0); + // rank 1 + NDArray yuv('f', {3}, {55.14, 71.2872001, -39.6005542}, + sd::DataType::FLOAT32); + NDArray expected('f', {3}, {10, 50, 200}, sd::DataType::FLOAT32); + sd::ops::yuv_to_rgb op; + auto result = op.evaluate({&yuv}, {}, {}); + auto output = result.at(0); - ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests15, test_yuv_to_rgb_3) { - // rank 2 - NDArray expected('c', { 3, 4 }, { -9.4, 9.9, 9.7, 9.0, 1.14, 1.01, 1.11, 9.6, 1.05, 10.0, 1.03, 10.22 }, sd::DataType::FLOAT32); - NDArray yuv('c', { 3, 4 }, { -2.021720, 4.692970, 3.669290, 9.491281, 1.511627, 2.611648, -1.298824, 0.358612, -6.472839, 4.568039, 5.290639, -0.430992 }, sd::DataType::FLOAT32); - - sd::ops::yuv_to_rgb op; - auto result = op.evaluate({ &yuv }, {}, { 0 }); - auto output = result.at(0); - ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + // rank 2 + NDArray expected( + 'c', {3, 4}, + {-9.4, 9.9, 9.7, 9.0, 1.14, 1.01, 1.11, 9.6, 1.05, 10.0, 1.03, 10.22}, + sd::DataType::FLOAT32); + NDArray yuv('c', {3, 4}, + {-2.021720, 4.692970, 3.669290, 9.491281, 1.511627, 2.611648, + -1.298824, 0.358612, -6.472839, 4.568039, 5.290639, -0.430992}, + sd::DataType::FLOAT32); + + sd::ops::yuv_to_rgb op; + auto result = op.evaluate({&yuv}, {}, {0}); + auto output = result.at(0); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests15, test_yuv_to_rgb_4) { - // rank 3 - NDArray expected('c', { 5,4,3 }, { 1.7750e+01, 1.4602e+01, 5.4883e+00, 9.5438e+01, 1.0038e+02, 4.0531e+01, -5.8844e+01, 2.9609e+01, -1.1414e+01, 2.1391e+01, 3.9656e+01, 2.1531e+01, -7.1062e+01, -4.5859e+00, 2.9438e+01, -6.7461e+00, 6.7938e+01, -6.1211e+00, 2.2750e+01, -6.1438e+01, 1.5404e-02, -8.5312e+01, 1.1641e+01, 6.2500e+01, -1.0019e+02, 3.9344e+01, -3.1344e+01, 3.8562e+01, 5.9961e+00, 6.2219e+01, -1.0477e+01, 1.7750e+01, 2.9938e+01, 7.5830e-01, -2.7516e+01, 7.2188e+01, -2.3406e+01, 1.1617e+01, 6.5125e+01, 6.5078e+00, 6.7812e+01, 4.6812e+01, 7.7344e+00, 6.8562e+01, 5.6719e+00, 2.3125e+01, 6.7562e+01, 9.3750e+00, 5.2094e+01, -8.6562e+01, 1.2695e+01, 3.3562e+01, 2.9734e+01, 5.2250e+01, 9.5469e+00, -7.4414e+00, -2.0125e+01, 1.8145e+00, 7.8438e+01, -4.8125e+01 }, sd::DataType::FLOAT32); - NDArray yuv('c', { 5,4,3 }, { 14.5042902, -4.43686799, 2.847406, 92.079556, -25.36761168, 2.94630572, -1.515069, -4.87137291, -50.29369639, 32.128515, -5.21515376, -9.41983935,-20.5835293, 24.61614501, -44.28390394, 37.1647167, -21.30142676, -38.52221293, -29.26009994, 14.40679768, 45.62757638, -11.550021, 36.44083018, -64.71012983,-10.435098, -10.28950082, -78.74044941, 22.1427147, 19.72198103, 14.40435988, 10.699559, 9.46744852, -18.5778351 , -7.6957283, 39.31166179, 7.41657542, 7.245035, 28.48336771, -26.88963173, 47.0880442, -0.13584441, -35.60035823, 43.2050762, -18.47048906, -31.11782117, 47.642019, -18.83162118, -21.50836396,-33.788558, 22.87507047, 75.34330791, 33.445396, 9.25395257, 0.10229474, -3.8078287, -8.02985955, 11.71587638, 41.0993915, -43.90830496, -34.46396749 }, sd::DataType::FLOAT32); - - sd::ops::yuv_to_rgb op; - auto result = op.evaluate({ &yuv }, {}, {}); - auto output = result.at(0); - - ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + // rank 3 + NDArray expected( + 'c', {5, 4, 3}, + {1.7750e+01, 1.4602e+01, 5.4883e+00, 9.5438e+01, 1.0038e+02, + 4.0531e+01, -5.8844e+01, 2.9609e+01, -1.1414e+01, 2.1391e+01, + 3.9656e+01, 2.1531e+01, -7.1062e+01, -4.5859e+00, 2.9438e+01, + -6.7461e+00, 6.7938e+01, -6.1211e+00, 2.2750e+01, -6.1438e+01, + 1.5404e-02, -8.5312e+01, 1.1641e+01, 6.2500e+01, -1.0019e+02, + 3.9344e+01, -3.1344e+01, 3.8562e+01, 5.9961e+00, 6.2219e+01, + -1.0477e+01, 1.7750e+01, 2.9938e+01, 7.5830e-01, -2.7516e+01, + 7.2188e+01, -2.3406e+01, 1.1617e+01, 6.5125e+01, 6.5078e+00, + 6.7812e+01, 4.6812e+01, 7.7344e+00, 6.8562e+01, 5.6719e+00, + 2.3125e+01, 6.7562e+01, 9.3750e+00, 5.2094e+01, -8.6562e+01, + 1.2695e+01, 3.3562e+01, 2.9734e+01, 5.2250e+01, 9.5469e+00, + -7.4414e+00, -2.0125e+01, 1.8145e+00, 7.8438e+01, -4.8125e+01}, + sd::DataType::FLOAT32); + NDArray yuv( + 'c', {5, 4, 3}, + {14.5042902, -4.43686799, 2.847406, 92.079556, -25.36761168, + 2.94630572, -1.515069, -4.87137291, -50.29369639, 32.128515, + -5.21515376, -9.41983935, -20.5835293, 24.61614501, -44.28390394, + 37.1647167, -21.30142676, -38.52221293, -29.26009994, 14.40679768, + 45.62757638, -11.550021, 36.44083018, -64.71012983, -10.435098, + -10.28950082, -78.74044941, 22.1427147, 19.72198103, 14.40435988, + 10.699559, 9.46744852, -18.5778351, -7.6957283, 39.31166179, + 7.41657542, 7.245035, 28.48336771, -26.88963173, 47.0880442, + -0.13584441, -35.60035823, 43.2050762, -18.47048906, -31.11782117, + 47.642019, -18.83162118, -21.50836396, -33.788558, 22.87507047, + 75.34330791, 33.445396, 9.25395257, 0.10229474, -3.8078287, + -8.02985955, 11.71587638, 41.0993915, -43.90830496, -34.46396749}, + sd::DataType::FLOAT32); + + sd::ops::yuv_to_rgb op; + auto result = op.evaluate({&yuv}, {}, {}); + auto output = result.at(0); + + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests15, test_yuv_to_rgb_5) { - // rank 3 - NDArray expected('c', { 5,3,4 }, { 1.7750e+01f, -7.1062e+01f, -1.0019e+02f,-2.3406e+01f, 5.2094e+01f, 9.5438e+01f, -6.7461e+00f, 3.8562e+01f, 6.5078e+00f,3.3562e+01f, -5.8844e+01f, 2.2750e+01f, -1.0477e+01f, 7.7344e+00f, 9.5469e+00f,2.1391e+01f, -8.5312e+01f, 7.5830e-01f,2.3125e+01f, 1.8145e+00f, 1.4602e+01f,-4.5859e+00f, 3.9344e+01f, 1.1617e+01f,-8.6562e+01f, 1.0038e+02f, 6.7938e+01f,5.9961e+00f, 6.7812e+01f, 2.9734e+01f,2.9609e+01f, -6.1438e+01f, 1.7750e+01f,6.8562e+01f, -7.4414e+00f, 3.9656e+01f,1.1641e+01f, -2.7516e+01f, 6.7562e+01f,7.8438e+01f, 5.4883e+00f, 2.9438e+01f,-3.1344e+01f, 6.5125e+01f, 1.2695e+01f,4.0531e+01f, -6.1211e+00f, 6.2219e+01f,4.6812e+01f, 5.2250e+01f, -1.1414e+01f,1.5404e-02f, 2.9938e+01f, 5.6719e+00f,-2.0125e+01f, 2.1531e+01f, 6.2500e+01f,7.2188e+01f, 9.3750e+00f, -4.8125e+01f }, sd::DataType::FLOAT32); - NDArray yuv('c', { 5,3,4 }, { 36.628319, 38.600643,-40.624989, 18.231001, -14.822637, -2.479566, -8.965780, 2.223851, -16.561626,-96.205162,-52.255379,-36.527435,-51.546139,2.234915, 20.914114, 8.785358, 32.552223, -3.356598, 9.069552, 1.393482,36.029255, 4.824605,-9.972263,11.058715, 15.947105, 55.283543, 36.845627, -29.750486,0.887228, 6.534475, -21.794132,34.155693, -89.929497,39.562351, 27.276817,31.359871, 8.149521, 13.673355, 1.104303, 68.774300, 2.236881, 13.216944, -3.555702,-3.225931,3.063015, -36.134724,58.302204, 8.477802, 38.695396,27.181587, -14.157411,7.157054, 11.714512, 22.148155, 11.580557, -27.204905,7.120562, 21.992094, 2.406748, -6.265247, }, sd::DataType::FLOAT32); - - sd::ops::yuv_to_rgb op; - auto result = op.evaluate({ &yuv }, {}, { 1 }); - auto output = result.at(0); - - ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - + // rank 3 + NDArray expected( + 'c', {5, 3, 4}, + {1.7750e+01f, -7.1062e+01f, -1.0019e+02f, -2.3406e+01f, 5.2094e+01f, + 9.5438e+01f, -6.7461e+00f, 3.8562e+01f, 6.5078e+00f, 3.3562e+01f, + -5.8844e+01f, 2.2750e+01f, -1.0477e+01f, 7.7344e+00f, 9.5469e+00f, + 2.1391e+01f, -8.5312e+01f, 7.5830e-01f, 2.3125e+01f, 1.8145e+00f, + 1.4602e+01f, -4.5859e+00f, 3.9344e+01f, 1.1617e+01f, -8.6562e+01f, + 1.0038e+02f, 6.7938e+01f, 5.9961e+00f, 6.7812e+01f, 2.9734e+01f, + 2.9609e+01f, -6.1438e+01f, 1.7750e+01f, 6.8562e+01f, -7.4414e+00f, + 3.9656e+01f, 1.1641e+01f, -2.7516e+01f, 6.7562e+01f, 7.8438e+01f, + 5.4883e+00f, 2.9438e+01f, -3.1344e+01f, 6.5125e+01f, 1.2695e+01f, + 4.0531e+01f, -6.1211e+00f, 6.2219e+01f, 4.6812e+01f, 5.2250e+01f, + -1.1414e+01f, 1.5404e-02f, 2.9938e+01f, 5.6719e+00f, -2.0125e+01f, + 2.1531e+01f, 6.2500e+01f, 7.2188e+01f, 9.3750e+00f, -4.8125e+01f}, + sd::DataType::FLOAT32); + NDArray yuv('c', {5, 3, 4}, + { + 36.628319, 38.600643, -40.624989, 18.231001, -14.822637, + -2.479566, -8.965780, 2.223851, -16.561626, -96.205162, + -52.255379, -36.527435, -51.546139, 2.234915, 20.914114, + 8.785358, 32.552223, -3.356598, 9.069552, 1.393482, + 36.029255, 4.824605, -9.972263, 11.058715, 15.947105, + 55.283543, 36.845627, -29.750486, 0.887228, 6.534475, + -21.794132, 34.155693, -89.929497, 39.562351, 27.276817, + 31.359871, 8.149521, 13.673355, 1.104303, 68.774300, + 2.236881, 13.216944, -3.555702, -3.225931, 3.063015, + -36.134724, 58.302204, 8.477802, 38.695396, 27.181587, + -14.157411, 7.157054, 11.714512, 22.148155, 11.580557, + -27.204905, 7.120562, 21.992094, 2.406748, -6.265247, + }, + sd::DataType::FLOAT32); + + sd::ops::yuv_to_rgb op; + auto result = op.evaluate({&yuv}, {}, {1}); + auto output = result.at(0); + + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests15, test_yuv_to_rgb_6) { - // rank 3 - NDArray yuv('c', { 3,5,4 }, { 1.7750e+01f, -7.1062e+01f, -1.0019e+02f,-2.3406e+01f, 5.2094e+01f, 9.5438e+01f, -6.7461e+00f, 3.8562e+01f, 6.5078e+00f,3.3562e+01f, -5.8844e+01f, 2.2750e+01f, -1.0477e+01f, 7.7344e+00f, 9.5469e+00f,2.1391e+01f, -8.5312e+01f, 7.5830e-01f,2.3125e+01f, 1.8145e+00f, 1.4602e+01f,-4.5859e+00f, 3.9344e+01f, 1.1617e+01f,-8.6562e+01f, 1.0038e+02f, 6.7938e+01f,5.9961e+00f, 6.7812e+01f, 2.9734e+01f,2.9609e+01f, -6.1438e+01f, 1.7750e+01f,6.8562e+01f, -7.4414e+00f, 3.9656e+01f,1.1641e+01f, -2.7516e+01f, 6.7562e+01f,7.8438e+01f, 5.4883e+00f, 2.9438e+01f,-3.1344e+01f, 6.5125e+01f, 1.2695e+01f,4.0531e+01f, -6.1211e+00f, 6.2219e+01f,4.6812e+01f, 5.2250e+01f, -1.1414e+01f,1.5404e-02f, 2.9938e+01f, 5.6719e+00f,-2.0125e+01f, 2.1531e+01f, 6.2500e+01f,7.2188e+01f, 9.3750e+00f, -4.8125e+01f }, sd::DataType::FLOAT32); - try { - sd::ops::yuv_to_rgb op; - auto result = op.evaluate({ &yuv }, {}, {}); - ASSERT_EQ(Status::THROW(), result.status()); + // rank 3 + NDArray yuv( + 'c', {3, 5, 4}, + {1.7750e+01f, -7.1062e+01f, -1.0019e+02f, -2.3406e+01f, 5.2094e+01f, + 9.5438e+01f, -6.7461e+00f, 3.8562e+01f, 6.5078e+00f, 3.3562e+01f, + -5.8844e+01f, 2.2750e+01f, -1.0477e+01f, 7.7344e+00f, 9.5469e+00f, + 2.1391e+01f, -8.5312e+01f, 7.5830e-01f, 2.3125e+01f, 1.8145e+00f, + 1.4602e+01f, -4.5859e+00f, 3.9344e+01f, 1.1617e+01f, -8.6562e+01f, + 1.0038e+02f, 6.7938e+01f, 5.9961e+00f, 6.7812e+01f, 2.9734e+01f, + 2.9609e+01f, -6.1438e+01f, 1.7750e+01f, 6.8562e+01f, -7.4414e+00f, + 3.9656e+01f, 1.1641e+01f, -2.7516e+01f, 6.7562e+01f, 7.8438e+01f, + 5.4883e+00f, 2.9438e+01f, -3.1344e+01f, 6.5125e+01f, 1.2695e+01f, + 4.0531e+01f, -6.1211e+00f, 6.2219e+01f, 4.6812e+01f, 5.2250e+01f, + -1.1414e+01f, 1.5404e-02f, 2.9938e+01f, 5.6719e+00f, -2.0125e+01f, + 2.1531e+01f, 6.2500e+01f, 7.2188e+01f, 9.3750e+00f, -4.8125e+01f}, + sd::DataType::FLOAT32); + try { + sd::ops::yuv_to_rgb op; + auto result = op.evaluate({&yuv}, {}, {}); + ASSERT_EQ(Status::THROW(), result.status()); - } - catch (std::exception & e) { - nd4j_printf("Error should be here `%s'. It's OK.\n", e.what()); - } + } catch (std::exception& e) { + nd4j_printf("Error should be here `%s'. It's OK.\n", e.what()); + } } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests15, test_yuv_to_rgb_7) { - // rank 3 - NDArray expected('f', { 2, 2, 3 }, { 1.7750e+01f,-7.1062e+01f, -1.0019e+02f, -2.3406e+01f,5.2094e+01f,9.5438e+01f, -6.7461e+00f,3.8562e+01f, 6.5078e+00f, 3.3562e+01f,-5.8844e+01f,2.2750e+01f }, sd::DataType::FLOAT32); - NDArray yuv('f', { 2,2,3 }, { 36.628319, 38.600643, -40.624989, 18.231001, -14.822637, -2.479566, -8.965780, 2.223851, -16.561626, -96.205162, -52.255379, -36.527435 }, sd::DataType::FLOAT32); - - sd::ops::yuv_to_rgb op; - auto result = op.evaluate({ &yuv }, {}, {}); - auto output = result.at(0); - - ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + // rank 3 + NDArray expected('f', {2, 2, 3}, + {1.7750e+01f, -7.1062e+01f, -1.0019e+02f, -2.3406e+01f, + 5.2094e+01f, 9.5438e+01f, -6.7461e+00f, 3.8562e+01f, + 6.5078e+00f, 3.3562e+01f, -5.8844e+01f, 2.2750e+01f}, + sd::DataType::FLOAT32); + NDArray yuv( + 'f', {2, 2, 3}, + {36.628319, 38.600643, -40.624989, 18.231001, -14.822637, -2.479566, + -8.965780, 2.223851, -16.561626, -96.205162, -52.255379, -36.527435}, + sd::DataType::FLOAT32); + + sd::ops::yuv_to_rgb op; + auto result = op.evaluate({&yuv}, {}, {}); + auto output = result.at(0); + + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests15, Pow_BP_Test1) { + // same shape + NDArray x('c', {2, 2, 2}, {4, 3, 2, 5, 7, 8, -9, -12}, sd::DataType::FLOAT32); + NDArray y('c', {2, 2, 2}, {2, 3, -2, 4, -1, -4, 10, 8}, + sd::DataType::FLOAT32); - // same shape - NDArray x('c', { 2,2,2 }, { 4,3,2,5,7,8,-9,-12 }, sd::DataType::FLOAT32); - NDArray y('c', { 2,2,2 }, { 2,3,-2,4,-1,-4,10,8 }, sd::DataType::FLOAT32); + NDArray dLdz('c', {2, 2, 2}, sd::DataType::FLOAT32); + NDArray dLdxExp( + 'c', {2, 2, 2}, + {8, 27, -0.25, 500, -0.0204082, -0.000122, -3.87420e+09, -2.86654e+08}, + sd::DataType::FLOAT32); + NDArray dLdyExp( + 'c', {2, 2, 2}, + {22.18071, 29.66253, 0.17329, 1005.89874, 0.27799, 0.00051, 0, 0}, + sd::DataType::FLOAT32); + dLdz.assign(1.0); - NDArray dLdz('c', { 2,2,2 }, sd::DataType::FLOAT32); - NDArray dLdxExp('c', { 2,2,2 }, { 8, 27, -0.25, 500, -0.0204082, -0.000122, -3.87420e+09, -2.86654e+08 }, sd::DataType::FLOAT32); - NDArray dLdyExp('c', { 2,2,2 }, { 22.18071, 29.66253, 0.17329, 1005.89874, 0.27799, 0.00051, 0, 0 }, sd::DataType::FLOAT32); + sd::ops::Pow_bp op; + auto results = op.evaluate({&x, &y, &dLdz}, {}, {}); - dLdz.assign(1.0); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - sd::ops::Pow_bp op; - auto results = op.evaluate({ &x, &y, &dLdz }, {}, {}); + auto dLdx = results.at(0); + auto dLdy = results.at(1); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto* dLdx = results.at(0); - auto* dLdy = results.at(1); - - ASSERT_TRUE(dLdxExp.isSameShape(dLdx)); - ASSERT_TRUE(dLdxExp.equalsTo(dLdx)); - ASSERT_TRUE(dLdyExp.isSameShape(dLdy)); - ASSERT_TRUE(dLdyExp.equalsTo(dLdy)); + ASSERT_TRUE(dLdxExp.isSameShape(dLdx)); + ASSERT_TRUE(dLdxExp.equalsTo(dLdx)); + ASSERT_TRUE(dLdyExp.isSameShape(dLdy)); + ASSERT_TRUE(dLdyExp.equalsTo(dLdy)); } TEST_F(DeclarableOpsTests15, Pow_BP_Test2) { + NDArray x('c', {1, 2, 3}, sd::DataType::FLOAT32); + NDArray y('c', {3, 2, 1}, sd::DataType::FLOAT32); + NDArray dLdz('c', {3, 2, 3}, sd::DataType::FLOAT32); - NDArray x('c', { 1,2,3 }, sd::DataType::FLOAT32); - NDArray y('c', { 3,2,1 }, sd::DataType::FLOAT32); - NDArray dLdz('c', { 3,2,3 }, sd::DataType::FLOAT32); + NDArray dLdxExp('c', {1, 2, 3}, {16.8, 19.2, 21.6, 24., 26.4, 28.8}, + sd::DataType::FLOAT32); + NDArray dLdyExp('c', {3, 2, 1}, + {13.30843, 33.27106, 53.2337, 73.19634, 93.15898, 113.12162}, + sd::DataType::FLOAT32); - NDArray dLdxExp('c', { 1,2,3 }, { 16.8, 19.2, 21.6, 24., 26.4, 28.8 }, sd::DataType::FLOAT32); - NDArray dLdyExp('c', { 3,2,1 }, { 13.30843, 33.27106, 53.2337, 73.19634, 93.15898, 113.12162 }, sd::DataType::FLOAT32); - - x.assign(4.0); - y.assign(2.0); - dLdz.linspace(0.1, 0.1); - - sd::ops::Pow_bp op; - auto results = op.evaluate({ &x, &y, &dLdz }, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + x.assign(4.0); + y.assign(2.0); + dLdz.linspace(0.1, 0.1); - auto* dLdx = results.at(0); - auto* dLdy = results.at(1); + sd::ops::Pow_bp op; + auto results = op.evaluate({&x, &y, &dLdz}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(dLdxExp.isSameShape(dLdx)); - ASSERT_TRUE(dLdxExp.equalsTo(dLdx)); - ASSERT_TRUE(dLdyExp.isSameShape(dLdy)); - ASSERT_TRUE(dLdyExp.equalsTo(dLdy)); + auto dLdx = results.at(0); + auto dLdy = results.at(1); + ASSERT_TRUE(dLdxExp.isSameShape(dLdx)); + ASSERT_TRUE(dLdxExp.equalsTo(dLdx)); + ASSERT_TRUE(dLdyExp.isSameShape(dLdy)); + ASSERT_TRUE(dLdyExp.equalsTo(dLdy)); } TEST_F(DeclarableOpsTests15, Pow_BP_Test3) { + // y - same shape as dLdz + NDArray xY('c', {1, 2, 3}, sd::DataType::FLOAT32); + NDArray yY('c', {3, 2, 3}, sd::DataType::FLOAT32); - // y - same shape as dLdz - NDArray xY('c', { 1,2,3 }, sd::DataType::FLOAT32); - NDArray yY('c', { 3,2,3 }, sd::DataType::FLOAT32); + NDArray dLdxExpY('c', {1, 2, 3}, {16.8, 19.2, 21.6, 24., 26.4, 28.8}, + sd::DataType::FLOAT32); + NDArray dLdyExpY('c', {3, 2, 3}, + {2.21807, 4.43614, 6.65421, 8.87228, 11.09035, 13.30843, + 15.5265, 17.74457, 19.96264, 22.18071, 24.39878, 26.61685, + 28.83492, 31.05299, 33.27106, 35.48914, 37.70721, 39.92528}, + sd::DataType::FLOAT32); + NDArray dLdz('c', {3, 2, 3}, sd::DataType::FLOAT32); - NDArray dLdxExpY('c', { 1,2,3 }, { 16.8, 19.2, 21.6, 24. , 26.4, 28.8 }, sd::DataType::FLOAT32); - NDArray dLdyExpY('c', { 3,2,3 }, { 2.21807, 4.43614, 6.65421, 8.87228, 11.09035, 13.30843, 15.5265 , 17.74457, 19.96264, 22.18071, 24.39878, 26.61685, 28.83492, 31.05299, 33.27106, 35.48914, 37.70721, 39.92528 }, sd::DataType::FLOAT32); - NDArray dLdz('c', { 3,2,3 }, sd::DataType::FLOAT32); + xY.assign(4.0); + yY.assign(2.0); + dLdz.linspace(0.1, 0.1); - xY.assign(4.0); - yY.assign(2.0); - dLdz.linspace(0.1, 0.1); + sd::ops::Pow_bp op; + auto resultsY = op.evaluate({&xY, &yY, &dLdz}, {}, {}); - sd::ops::Pow_bp op; - auto resultsY = op.evaluate({ &xY, &yY, &dLdz }, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, resultsY.status()); - ASSERT_EQ(ND4J_STATUS_OK, resultsY.status()); + auto dLdxY = resultsY.at(0); + auto dLdyY = resultsY.at(1); - auto* dLdxY = resultsY.at(0); - auto* dLdyY = resultsY.at(1); - - ASSERT_TRUE(dLdxExpY.isSameShape(dLdxY)); - ASSERT_TRUE(dLdxExpY.equalsTo(dLdxY)); - ASSERT_TRUE(dLdyExpY.isSameShape(dLdyY)); - ASSERT_TRUE(dLdyExpY.equalsTo(dLdyY)); + ASSERT_TRUE(dLdxExpY.isSameShape(dLdxY)); + ASSERT_TRUE(dLdxExpY.equalsTo(dLdxY)); + ASSERT_TRUE(dLdyExpY.isSameShape(dLdyY)); + ASSERT_TRUE(dLdyExpY.equalsTo(dLdyY)); } TEST_F(DeclarableOpsTests15, Pow_BP_Test4) { + // x - same shape ad dLdz + NDArray yX('c', {1, 2, 3}, sd::DataType::FLOAT32); + NDArray xX('c', {3, 2, 3}, sd::DataType::FLOAT32); - // x - same shape ad dLdz - NDArray yX('c', { 1,2,3 }, sd::DataType::FLOAT32); - NDArray xX('c', { 3,2,3 }, sd::DataType::FLOAT32); - - NDArray dLdxExpX('c', { 3,2,3 }, { 3.2, 6.4, 9.6, 12.8, 16. , 19.2, 22.4, 25.6, 28.8, 32. , 35.2, 38.4, 41.6, 44.8, 48., 51.2, 54.4, 57.6 }, sd::DataType::FLOAT32); - NDArray dLdyExpX('c', { 1,2,3 }, { 23.28975, 26.61685, 29.94396, 33.27106, 36.59817, 39.92528 }, sd::DataType::FLOAT32); + NDArray dLdxExpX('c', {3, 2, 3}, + {3.2, 6.4, 9.6, 12.8, 16., 19.2, 22.4, 25.6, 28.8, 32., 35.2, + 38.4, 41.6, 44.8, 48., 51.2, 54.4, 57.6}, + sd::DataType::FLOAT32); + NDArray dLdyExpX('c', {1, 2, 3}, + {23.28975, 26.61685, 29.94396, 33.27106, 36.59817, 39.92528}, + sd::DataType::FLOAT32); - NDArray dLdz('c', { 3,2,3 }, sd::DataType::FLOAT32); - dLdz.linspace(0.1, 0.1); + NDArray dLdz('c', {3, 2, 3}, sd::DataType::FLOAT32); + dLdz.linspace(0.1, 0.1); - sd::ops::Pow_bp op; + sd::ops::Pow_bp op; - xX.assign(2.0); - yX.assign(4.0); + xX.assign(2.0); + yX.assign(4.0); - auto resultsX = op.evaluate({ &xX, &yX, &dLdz }, {}, {}); + auto resultsX = op.evaluate({&xX, &yX, &dLdz}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, resultsX.status()); + ASSERT_EQ(ND4J_STATUS_OK, resultsX.status()); - auto* dLdxX = resultsX.at(0); - auto* dLdyX = resultsX.at(1); + auto dLdxX = resultsX.at(0); + auto dLdyX = resultsX.at(1); - ASSERT_TRUE(dLdxExpX.isSameShape(dLdxX)); - ASSERT_TRUE(dLdxExpX.equalsTo(dLdxX)); - ASSERT_TRUE(dLdyExpX.isSameShape(dLdyX)); - ASSERT_TRUE(dLdyExpX.equalsTo(dLdyX)); + ASSERT_TRUE(dLdxExpX.isSameShape(dLdxX)); + ASSERT_TRUE(dLdxExpX.equalsTo(dLdxX)); + ASSERT_TRUE(dLdyExpX.isSameShape(dLdyX)); + ASSERT_TRUE(dLdyExpX.equalsTo(dLdyX)); } TEST_F(DeclarableOpsTests15, Pow_BP_Test5) { + // both single array + NDArray xConst('c', {1}, sd::DataType::FLOAT32); + NDArray yConst('c', {1}, sd::DataType::FLOAT32); + NDArray dLdz('c', {1}, sd::DataType::FLOAT32); + NDArray dLdxExp('c', {1}, sd::DataType::FLOAT32); + NDArray dLdyExp('c', {1}, sd::DataType::FLOAT32); - // both single array - NDArray xConst('c', { 1 }, sd::DataType::FLOAT32); - NDArray yConst('c', { 1 }, sd::DataType::FLOAT32); - NDArray dLdz('c', { 1 }, sd::DataType::FLOAT32); - NDArray dLdxExp('c', { 1 }, sd::DataType::FLOAT32); - NDArray dLdyExp('c', { 1 }, sd::DataType::FLOAT32); - - xConst.assign(3.0); - yConst.assign(4.0); - dLdz.assign(1.0); + xConst.assign(3.0); + yConst.assign(4.0); + dLdz.assign(1.0); - dLdxExp.assign(4.0 * pow(3, 3)); - dLdyExp.assign(pow(3, 4) * log(3)); + dLdxExp.assign(4.0 * pow(3, 3)); + dLdyExp.assign(pow(3, 4) * log(3)); - sd::ops::Pow_bp op; - auto results = op.evaluate({ &xConst, &yConst, &dLdz }, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::Pow_bp op; + auto results = op.evaluate({&xConst, &yConst, &dLdz}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto* dLdx = results.at(0); - auto* dLdy = results.at(1); + auto dLdx = results.at(0); + auto dLdy = results.at(1); - ASSERT_TRUE(dLdxExp.isSameShape(dLdx)); - ASSERT_TRUE(dLdxExp.equalsTo(dLdx)); + ASSERT_TRUE(dLdxExp.isSameShape(dLdx)); + ASSERT_TRUE(dLdxExp.equalsTo(dLdx)); - ASSERT_TRUE(dLdyExp.isSameShape(dLdy)); - ASSERT_TRUE(dLdyExp.equalsTo(dLdy)); + ASSERT_TRUE(dLdyExp.isSameShape(dLdy)); + ASSERT_TRUE(dLdyExp.equalsTo(dLdy)); } TEST_F(DeclarableOpsTests15, Pow_BP_Test6) { + // x single array + NDArray xConst('c', {1}, sd::DataType::FLOAT32); + NDArray y('c', {2, 2, 2}, sd::DataType::FLOAT32); + NDArray dLdzC('c', {2, 2, 2}, sd::DataType::FLOAT32); - // x single array - NDArray xConst('c', { 1 }, sd::DataType::FLOAT32); - NDArray y('c', { 2, 2, 2 }, sd::DataType::FLOAT32); - NDArray dLdzC('c', { 2, 2, 2 }, sd::DataType::FLOAT32); - - xConst.assign(2.0); - y.assign(4.0); - dLdzC.linspace(0.1, 0.1); - - NDArray dLdxExpXC('c', { 1 }, std::vector{ 115.2 }, sd::DataType::FLOAT32); - NDArray dLdyExpXC('c', { 2, 2, 2 }, { 1.10904, 2.21807, 3.32711, 4.43614, 5.54518, 6.65421, 7.76325, 8.87228 }, sd::DataType::FLOAT32); + xConst.assign(2.0); + y.assign(4.0); + dLdzC.linspace(0.1, 0.1); - sd::ops::Pow_bp op; - auto resultsXC = op.evaluate({ &xConst, &y, &dLdzC }, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, resultsXC.status()); + NDArray dLdxExpXC('c', {1}, std::vector{115.2}, + sd::DataType::FLOAT32); + NDArray dLdyExpXC( + 'c', {2, 2, 2}, + {1.10904, 2.21807, 3.32711, 4.43614, 5.54518, 6.65421, 7.76325, 8.87228}, + sd::DataType::FLOAT32); - auto* dLdxXC = resultsXC.at(0); - auto* dLdyXC = resultsXC.at(1); + sd::ops::Pow_bp op; + auto resultsXC = op.evaluate({&xConst, &y, &dLdzC}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, resultsXC.status()); - ASSERT_TRUE(dLdxExpXC.isSameShape(dLdxXC)); - ASSERT_TRUE(dLdxExpXC.equalsTo(dLdxXC)); - ASSERT_TRUE(dLdyExpXC.isSameShape(dLdyXC)); - ASSERT_TRUE(dLdyExpXC.equalsTo(dLdyXC)); + auto dLdxXC = resultsXC.at(0); + auto dLdyXC = resultsXC.at(1); + ASSERT_TRUE(dLdxExpXC.isSameShape(dLdxXC)); + ASSERT_TRUE(dLdxExpXC.equalsTo(dLdxXC)); + ASSERT_TRUE(dLdyExpXC.isSameShape(dLdyXC)); + ASSERT_TRUE(dLdyExpXC.equalsTo(dLdyXC)); } TEST_F(DeclarableOpsTests15, Pow_BP_Test7) { + // Y - scalar + auto Y = NDArrayFactory::create(2.f); + NDArray x('c', {2, 2, 2}, sd::DataType::FLOAT32); + NDArray dLdzC('c', {2, 2, 2}, sd::DataType::FLOAT32); - // Y - scalar - auto Y = NDArrayFactory::create(2.f); - NDArray x('c', { 2, 2, 2 }, sd::DataType::FLOAT32); - NDArray dLdzC('c', { 2, 2, 2 }, sd::DataType::FLOAT32); + dLdzC.linspace(0.1, 0.1); + x = 4.f; - dLdzC.linspace(0.1, 0.1); - x = 4.f; + NDArray dLdxExpYs('c', {2, 2, 2}, {0.8, 1.6, 2.4, 3.2, 4., 4.8, 5.6, 6.4}, + sd::DataType::FLOAT32); - NDArray dLdxExpYs('c', { 2, 2, 2 }, { 0.8, 1.6, 2.4, 3.2, 4., 4.8, 5.6, 6.4 }, sd::DataType::FLOAT32); + auto dLdyExpYs = NDArrayFactory::create(79.85056f); - auto dLdyExpYs = NDArrayFactory::create(79.85056f); + sd::ops::Pow_bp op; + auto resultsYs = op.evaluate({&x, &Y, &dLdzC}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, resultsYs.status()); - sd::ops::Pow_bp op; - auto resultsYs = op.evaluate({ &x, &Y, &dLdzC }, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, resultsYs.status()); + auto dLdxY = resultsYs.at(0); + auto dLdyY = resultsYs.at(1); - auto* dLdxY = resultsYs.at(0); - auto* dLdyY = resultsYs.at(1); - - ASSERT_TRUE(dLdxExpYs.isSameShape(dLdxY)); - ASSERT_TRUE(dLdxExpYs.equalsTo(dLdxY)); - ASSERT_TRUE(dLdyExpYs.isSameShape(dLdyY)); - ASSERT_TRUE(dLdyExpYs.equalsTo(dLdyY)); + ASSERT_TRUE(dLdxExpYs.isSameShape(dLdxY)); + ASSERT_TRUE(dLdxExpYs.equalsTo(dLdxY)); + ASSERT_TRUE(dLdyExpYs.isSameShape(dLdyY)); + ASSERT_TRUE(dLdyExpYs.equalsTo(dLdyY)); } TEST_F(DeclarableOpsTests15, Pow_BP_Test8) { - // both scalars - - auto X = NDArrayFactory::create(4.f); - auto Y = NDArrayFactory::create(2.f); - NDArray dLdz = NDArrayFactory::create(0.1f); + // both scalars - NDArray dLdxExp = NDArrayFactory::create(2.f*4.f*0.1f); + auto X = NDArrayFactory::create(4.f); + auto Y = NDArrayFactory::create(2.f); + NDArray dLdz = NDArrayFactory::create(0.1f); - NDArray dLdyExp = NDArrayFactory::create(pow(4.f, 2.f) * log(4.f) * 0.1f); + NDArray dLdxExp = NDArrayFactory::create(2.f * 4.f * 0.1f); - sd::ops::Pow_bp op; - auto results = op.evaluate({ &X, &Y, &dLdz }, {}, {}); + NDArray dLdyExp = + NDArrayFactory::create(pow(4.f, 2.f) * log(4.f) * 0.1f); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::Pow_bp op; + auto results = op.evaluate({&X, &Y, &dLdz}, {}, {}); - auto* dLdx = results.at(0); - auto* dLdy = results.at(1); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(dLdxExp.isSameShape(dLdx)); - ASSERT_TRUE(dLdxExp.equalsTo(dLdx)); - ASSERT_TRUE(dLdyExp.isSameShape(dLdy)); - ASSERT_TRUE(dLdyExp.equalsTo(dLdy)); + auto dLdx = results.at(0); + auto dLdy = results.at(1); + ASSERT_TRUE(dLdxExp.isSameShape(dLdx)); + ASSERT_TRUE(dLdxExp.equalsTo(dLdx)); + ASSERT_TRUE(dLdyExp.isSameShape(dLdy)); + ASSERT_TRUE(dLdyExp.equalsTo(dLdy)); } TEST_F(DeclarableOpsTests15, Pow_BP_Test9) { + sd::ops::Pow_bp op; + // diff shapes + NDArray x('c', {3, 2, 1}, sd::DataType::FLOAT32); + NDArray y('c', {1, 2, 3}, sd::DataType::FLOAT32); + NDArray dLdz('c', {3, 2, 3}, sd::DataType::FLOAT32); - sd::ops::Pow_bp op; - // diff shapes - NDArray x('c', { 3,2,1 }, sd::DataType::FLOAT32); - NDArray y('c', { 1,2,3 }, sd::DataType::FLOAT32); - NDArray dLdz('c', { 3,2,3 }, sd::DataType::FLOAT32); + NDArray dLdxExp('c', {3, 2, 1}, {4.8, 12., 19.2, 26.4, 33.6, 40.8}, + sd::DataType::FLOAT32); + NDArray dLdyExp('c', {1, 2, 3}, + {46.57949, 53.2337, 59.88792, 66.54213, 73.19634, 79.85056}, + sd::DataType::FLOAT32); - NDArray dLdxExp('c', { 3,2,1 }, { 4.8, 12., 19.2, 26.4, 33.6, 40.8 }, sd::DataType::FLOAT32); - NDArray dLdyExp('c', { 1,2,3 }, { 46.57949, 53.2337 , 59.88792, 66.54213, 73.19634, 79.85056 }, sd::DataType::FLOAT32); + x.assign(4.0); + y.assign(2.0); + dLdz.linspace(0.1, 0.1); - x.assign(4.0); - y.assign(2.0); - dLdz.linspace(0.1, 0.1); + auto results = op.evaluate({&x, &y, &dLdz}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto results = op.evaluate({ &x, &y, &dLdz }, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto* dLdx = results.at(0); - auto* dLdy = results.at(1); + auto dLdx = results.at(0); + auto dLdy = results.at(1); - ASSERT_TRUE(dLdxExp.isSameShape(dLdx)); - ASSERT_TRUE(dLdxExp.equalsTo(dLdx)); - ASSERT_TRUE(dLdyExp.isSameShape(dLdy)); - ASSERT_TRUE(dLdyExp.equalsTo(dLdy)); + ASSERT_TRUE(dLdxExp.isSameShape(dLdx)); + ASSERT_TRUE(dLdxExp.equalsTo(dLdx)); + ASSERT_TRUE(dLdyExp.isSameShape(dLdy)); + ASSERT_TRUE(dLdyExp.equalsTo(dLdy)); } TEST_F(DeclarableOpsTests15, Pow_BP_Test10) { + // diff shapes broadcastable + NDArray yB('c', {1, 2, 3, 1}, sd::DataType::FLOAT32); + NDArray xB('c', {2, 3, 1}, sd::DataType::FLOAT32); - // diff shapes broadcastable - NDArray yB('c', { 1,2,3,1 }, sd::DataType::FLOAT32); - NDArray xB('c', { 2,3,1 }, sd::DataType::FLOAT32); - - NDArray dLdyExpB('c', { 1,2,3,1 }, { 2.21807, 4.43614, 6.65421, 8.87228, 11.09035, 13.30843 }, sd::DataType::FLOAT32); - NDArray dLdxExpB('c', { 2,3,1 }, { 0.8, 1.6, 2.4, 3.2, 4., 4.8 }, sd::DataType::FLOAT32); - NDArray dLdzB('c', { 1,2,3,1 }, sd::DataType::FLOAT32); + NDArray dLdyExpB('c', {1, 2, 3, 1}, + {2.21807, 4.43614, 6.65421, 8.87228, 11.09035, 13.30843}, + sd::DataType::FLOAT32); + NDArray dLdxExpB('c', {2, 3, 1}, {0.8, 1.6, 2.4, 3.2, 4., 4.8}, + sd::DataType::FLOAT32); + NDArray dLdzB('c', {1, 2, 3, 1}, sd::DataType::FLOAT32); - dLdzB.linspace(0.1, 0.1); - xB.assign(4.0); - yB.assign(2.0); + dLdzB.linspace(0.1, 0.1); + xB.assign(4.0); + yB.assign(2.0); - sd::ops::Pow_bp op; - auto resultsB = op.evaluate({ &xB, &yB, &dLdzB }, {}, {}); + sd::ops::Pow_bp op; + auto resultsB = op.evaluate({&xB, &yB, &dLdzB}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, resultsB.status()); + ASSERT_EQ(ND4J_STATUS_OK, resultsB.status()); - auto* dLdxB = resultsB.at(0); - auto* dLdyB = resultsB.at(1); + auto dLdxB = resultsB.at(0); + auto dLdyB = resultsB.at(1); - ASSERT_TRUE(dLdxExpB.isSameShape(dLdxB)); - ASSERT_TRUE(dLdxExpB.equalsTo(dLdxB)); + ASSERT_TRUE(dLdxExpB.isSameShape(dLdxB)); + ASSERT_TRUE(dLdxExpB.equalsTo(dLdxB)); - ASSERT_TRUE(dLdyExpB.isSameShape(dLdyB)); - ASSERT_TRUE(dLdyExpB.equalsTo(dLdyB)); + ASSERT_TRUE(dLdyExpB.isSameShape(dLdyB)); + ASSERT_TRUE(dLdyExpB.equalsTo(dLdyB)); } TEST_F(DeclarableOpsTests15, Pow_BP_Test11) { #ifdef FFAST_MATH - if (1 > 0) - return; + if (1 > 0) return; #endif - NDArray xB('c', { 3,2,1 }, { .4, 3, 5, .8, -9, -12 }, sd::DataType::FLOAT32); - NDArray yB('c', { 1,2,3 }, { 3, -2, .4, -4, 10, .8 }, sd::DataType::FLOAT32); - - NDArray dLdxExpB('c', { 3,2,1 }, { -5.994056, 39366.191406, 7.508829, -2.223537, -std::numeric_limits::quiet_NaN(), -std::numeric_limits::quiet_NaN() }, sd::DataType::FLOAT32); - NDArray dLdyExpB('c', { 1,2,3 }, { 20.11211, -1.119612, -std::numeric_limits::quiet_NaN(), -0.1076, 12974.389648, -std::numeric_limits::quiet_NaN() }, sd::DataType::FLOAT32); - - NDArray dLdzB('c', { 3,2,3 }, { .1,.2,.3, .1,.2,.3, .1,.4,.1, .2,.1,.1, .3,.1,.5, .1, .7, .1 }, sd::DataType::FLOAT32); - - sd::ops::Pow_bp op; - auto resultsB = op.evaluate({ &xB, &yB, &dLdzB }, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, resultsB.status()); - auto* dLdxB = resultsB.at(0); - auto* dLdyB = resultsB.at(1); - - ASSERT_TRUE(dLdxExpB.isSameShape(dLdxB)); - for (int i = 0; i < dLdxB->lengthOf(); ++i) { - if (!sd::math::nd4j_isnan(dLdxB->e(i)) && !sd::math::nd4j_isnan(dLdxExpB.e(i))) - ASSERT_NEAR(dLdxB->e(i), dLdxExpB.e(i), 0.00001); - } - - ASSERT_TRUE(dLdyExpB.isSameShape(dLdyB)); - for (int i = 0; i < dLdyB->lengthOf(); ++i) { - if (!sd::math::nd4j_isnan(dLdyB->e(i)) && !sd::math::nd4j_isnan(dLdyExpB.e(i))) - ASSERT_NEAR(dLdyB->e(i), dLdyExpB.e(i), 0.00001); - } - - + NDArray xB('c', {3, 2, 1}, {.4, 3, 5, .8, -9, -12}, sd::DataType::FLOAT32); + NDArray yB('c', {1, 2, 3}, {3, -2, .4, -4, 10, .8}, sd::DataType::FLOAT32); + + NDArray dLdxExpB('c', {3, 2, 1}, + {-5.994056, 39366.191406, 7.508829, -2.223537, + -std::numeric_limits::quiet_NaN(), + -std::numeric_limits::quiet_NaN()}, + sd::DataType::FLOAT32); + NDArray dLdyExpB( + 'c', {1, 2, 3}, + {20.11211, -1.119612, -std::numeric_limits::quiet_NaN(), -0.1076, + 12974.389648, -std::numeric_limits::quiet_NaN()}, + sd::DataType::FLOAT32); + + NDArray dLdzB( + 'c', {3, 2, 3}, + {.1, .2, .3, .1, .2, .3, .1, .4, .1, .2, .1, .1, .3, .1, .5, .1, .7, .1}, + sd::DataType::FLOAT32); + + sd::ops::Pow_bp op; + auto resultsB = op.evaluate({&xB, &yB, &dLdzB}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, resultsB.status()); + auto dLdxB = resultsB.at(0); + auto dLdyB = resultsB.at(1); + + ASSERT_TRUE(dLdxExpB.isSameShape(dLdxB)); + for (int i = 0; i < dLdxB.lengthOf(); ++i) { + if (!sd::math::nd4j_isnan(dLdxB.e(i)) && + !sd::math::nd4j_isnan(dLdxExpB.e(i))) + ASSERT_NEAR(dLdxB.e(i), dLdxExpB.e(i), 0.00001); + } + + ASSERT_TRUE(dLdyExpB.isSameShape(dLdyB)); + for (int i = 0; i < dLdyB.lengthOf(); ++i) { + if (!sd::math::nd4j_isnan(dLdyB.e(i)) && + !sd::math::nd4j_isnan(dLdyExpB.e(i))) + ASSERT_NEAR(dLdyB.e(i), dLdyExpB.e(i), 0.00001); + } } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests15, TestTensorMmul_BP1) { + NDArray A('c', {1, 2, 3}, {2.1, 2.2, 2.3, 2.4, 2.5, 2.6}, + sd::DataType::FLOAT32); + NDArray B('c', {1, 2, 4}, {3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8}, + sd::DataType::FLOAT32); + NDArray dLdC('c', {3, 4}, {.1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.1}, + sd::DataType::FLOAT32); - NDArray A('c', { 1, 2, 3 }, { 2.1, 2.2, 2.3, 2.4, 2.5, 2.6 }, sd::DataType::FLOAT32); - NDArray B('c', { 1, 2, 4 }, { 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8 }, sd::DataType::FLOAT32); - NDArray dLdC('c', { 3, 4 }, { .1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.1 }, sd::DataType::FLOAT32); - - NDArray dLdA('c', { 1, 2, 3 }, { 3.3, 8.5, 13.36, 3.7, 9.54, 15. }, sd::DataType::FLOAT32); - NDArray dLdB('c', { 1, 2, 4 }, { 3.38, 4.04, 4.7, 5.13, 3.83, 4.58, 5.33, 5.82 }, sd::DataType::FLOAT32); - - sd::ops::tensormmul_bp op_bp; + NDArray dLdA('c', {1, 2, 3}, {3.3, 8.5, 13.36, 3.7, 9.54, 15.}, + sd::DataType::FLOAT32); + NDArray dLdB('c', {1, 2, 4}, {3.38, 4.04, 4.7, 5.13, 3.83, 4.58, 5.33, 5.82}, + sd::DataType::FLOAT32); - auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 2,0,1, 2,0,1 }, {}); + sd::ops::tensormmul_bp op_bp; - ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status()); + auto resultsBP = op_bp.evaluate({&A, &B, &dLdC}, {}, {2, 0, 1, 2, 0, 1}, {}); - auto* dLdAbp = resultsBP.at(0); - auto* dLdBbp = resultsBP.at(1); + ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status()); - ASSERT_TRUE(dLdA.isSameShape(*dLdAbp)); - ASSERT_TRUE(dLdA.equalsTo(*dLdAbp)); + auto dLdAbp = resultsBP.at(0); + auto dLdBbp = resultsBP.at(1); - ASSERT_TRUE(dLdB.isSameShape(*dLdBbp)); - ASSERT_TRUE(dLdB.equalsTo(*dLdBbp)); + ASSERT_TRUE(dLdA.isSameShape(dLdAbp)); + ASSERT_TRUE(dLdA.equalsTo(dLdAbp)); + ASSERT_TRUE(dLdB.isSameShape(dLdBbp)); + ASSERT_TRUE(dLdB.equalsTo(dLdBbp)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests15, TestTensorMmul_BP2) { + NDArray A('c', {1, 2, 3}, {2, 2, 2, 2, 2, 2}, sd::DataType::FLOAT32); + NDArray B('c', {1, 2, 3}, {3, 3, 3, 3, 3, 3}, sd::DataType::FLOAT32); + NDArray dLdC('c', {1}, {1}, sd::DataType::FLOAT32); - NDArray A('c', { 1, 2, 3 }, { 2,2,2, 2,2,2 }, sd::DataType::FLOAT32); - NDArray B('c', { 1, 2, 3 }, { 3,3,3,3, 3,3 }, sd::DataType::FLOAT32); - NDArray dLdC('c', { 1 }, { 1 }, sd::DataType::FLOAT32); + sd::ops::tensormmul_bp op_bp; + auto resultsBP = op_bp.evaluate({&A, &B, &dLdC}, {}, {2, 1, 2, 2, 1, 2}, {}); - sd::ops::tensormmul_bp op_bp; - auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 2,1,2, 2,1,2 }, {}); + ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status()); - ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status()); + auto dLdAbp = resultsBP.at(0); + auto dLdBbp = resultsBP.at(1); - auto* dLdAbp = resultsBP.at(0); - auto* dLdBbp = resultsBP.at(1); - - ASSERT_TRUE(B.isSameShape(*dLdAbp)); - ASSERT_TRUE(B.equalsTo(*dLdAbp)); - - ASSERT_TRUE(A.isSameShape(*dLdBbp)); - ASSERT_TRUE(A.equalsTo(*dLdBbp)); + ASSERT_TRUE(B.isSameShape(dLdAbp)); + ASSERT_TRUE(B.equalsTo(dLdAbp)); + ASSERT_TRUE(A.isSameShape(dLdBbp)); + ASSERT_TRUE(A.equalsTo(dLdBbp)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests15, TestTensorMmul_BP3) { + NDArray A('c', {3, 2, 2}, + {2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9, 3., 3.1, 3.2}, + sd::DataType::FLOAT32); + NDArray B('c', {4, 2, 2}, + {3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8, 3.9, 4., 4.1, 4.2, 4.3, + 4.4, 4.5, 4.6}, + sd::DataType::FLOAT32); + NDArray dLdC('c', {3, 4}, {.1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.2}, + sd::DataType::FLOAT32); - NDArray A('c', { 3, 2, 2 }, { 2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9, 3., 3.1, 3.2 }, sd::DataType::FLOAT32); - NDArray B('c', { 4, 2, 2 }, { 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8, 3.9, 4., 4.1, 4.2, 4.3, 4.4, 4.5, 4.6 }, sd::DataType::FLOAT32); - NDArray dLdC('c', { 3, 4 }, { .1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.2 }, sd::DataType::FLOAT32); - - NDArray dA('c', { 3, 2, 2 }, { 3.9, 4., 4.1, 4.2, 9.82, 10.08, 10.34, 10.6, 15.74, 16.16, 16.58, 17. }, sd::DataType::FLOAT32); - NDArray dB('c', { 4, 2, 2 }, { 4.07, 4.22, 4.37, 4.52, 4.82, 5., 5.18, 5.36, 5.57, 5.78, 5.99, 6.2, 6.32, 6.56, 6.8, 7.04 }, sd::DataType::FLOAT32); + NDArray dA( + 'c', {3, 2, 2}, + {3.9, 4., 4.1, 4.2, 9.82, 10.08, 10.34, 10.6, 15.74, 16.16, 16.58, 17.}, + sd::DataType::FLOAT32); + NDArray dB('c', {4, 2, 2}, + {4.07, 4.22, 4.37, 4.52, 4.82, 5., 5.18, 5.36, 5.57, 5.78, 5.99, + 6.2, 6.32, 6.56, 6.8, 7.04}, + sd::DataType::FLOAT32); - sd::ops::tensormmul_bp op_bp; + sd::ops::tensormmul_bp op_bp; - auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 2,1,2, 2,1,2 }, {}); + auto resultsBP = op_bp.evaluate({&A, &B, &dLdC}, {}, {2, 1, 2, 2, 1, 2}, {}); - ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status()); + ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status()); - auto* dLdAbp = resultsBP.at(0); - auto* dLdBbp = resultsBP.at(1); + auto dLdAbp = resultsBP.at(0); + auto dLdBbp = resultsBP.at(1); - ASSERT_TRUE(dA.isSameShape(*dLdAbp)); - ASSERT_TRUE(dA.equalsTo(*dLdAbp)); - - ASSERT_TRUE(dB.isSameShape(*dLdBbp)); - ASSERT_TRUE(dB.equalsTo(*dLdBbp)); + ASSERT_TRUE(dA.isSameShape(dLdAbp)); + ASSERT_TRUE(dA.equalsTo(dLdAbp)); + ASSERT_TRUE(dB.isSameShape(dLdBbp)); + ASSERT_TRUE(dB.equalsTo(dLdBbp)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests15, TestTensorMmul_BP4) { + NDArray A('c', {3, 4, 1}, {0.4, 3, 5, 9, 23, 0.12, 8, 9, 0.1, 0, 124, 3}, + sd::DataType::FLOAT32); + NDArray B('c', {2, 4, 1}, {4, 13, .5, 19, 2.3, 1.2, 18, .9}, + sd::DataType::FLOAT32); + NDArray dLdC('c', {3, 2}, {1.1, 1.2, 1.3, 1.4, 1.5, 1.6}, + sd::DataType::FLOAT32); - NDArray A('c', { 3, 4, 1 }, { 0.4, 3, 5, 9, 23, 0.12, 8, 9, 0.1, 0, 124, 3 }, sd::DataType::FLOAT32); - NDArray B('c', { 2, 4, 1 }, { 4, 13, .5, 19, 2.3, 1.2, 18, .9 }, sd::DataType::FLOAT32); - NDArray dLdC('c', { 3, 2 }, { 1.1, 1.2, 1.3, 1.4, 1.5, 1.6 }, sd::DataType::FLOAT32); - - NDArray dLdA('c', { 3, 4, 1 }, { 7.16, 15.74, 22.15, 21.98, 8.42, 18.58, 25.85, 25.96, 9.68, 21.42, 29.55, 29.94 }, sd::DataType::FLOAT32); - NDArray dLdB('c', { 2, 4, 1 }, { 30.49, 3.456, 201.9, 26.1, 32.84 , 3.768, 215.6, 28.2 }, sd::DataType::FLOAT32); - - sd::ops::tensormmul_bp op_bp; + NDArray dLdA('c', {3, 4, 1}, + {7.16, 15.74, 22.15, 21.98, 8.42, 18.58, 25.85, 25.96, 9.68, + 21.42, 29.55, 29.94}, + sd::DataType::FLOAT32); + NDArray dLdB('c', {2, 4, 1}, + {30.49, 3.456, 201.9, 26.1, 32.84, 3.768, 215.6, 28.2}, + sd::DataType::FLOAT32); - auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 2,1,2, 2,1,2 }, {}); + sd::ops::tensormmul_bp op_bp; - ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status()); + auto resultsBP = op_bp.evaluate({&A, &B, &dLdC}, {}, {2, 1, 2, 2, 1, 2}, {}); - auto* dLdAbp = resultsBP.at(0); - auto* dLdBbp = resultsBP.at(1); + ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status()); - ASSERT_TRUE(dLdA.isSameShape(*dLdAbp)); - ASSERT_TRUE(dLdA.equalsTo(*dLdAbp)); + auto dLdAbp = resultsBP.at(0); + auto dLdBbp = resultsBP.at(1); - ASSERT_TRUE(dLdB.isSameShape(*dLdBbp)); - ASSERT_TRUE(dLdB.equalsTo(*dLdBbp)); + ASSERT_TRUE(dLdA.isSameShape(dLdAbp)); + ASSERT_TRUE(dLdA.equalsTo(dLdAbp)); + ASSERT_TRUE(dLdB.isSameShape(dLdBbp)); + ASSERT_TRUE(dLdB.equalsTo(dLdBbp)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests15, TestTensorMmul_BP5) { + NDArray A('c', {3, 4, 1, 1}, {0.4, 3, 5, 9, 23, 0.12, 8, 9, 0.1, 0, 124, 3}, + sd::DataType::FLOAT32); + NDArray B('c', {2, 4, 1, 1}, {4, 13, .5, 19, 2.3, 1.2, 18, .9}, + sd::DataType::FLOAT32); + NDArray dLdC('c', {3, 1, 2, 1}, {1.1, 1.2, 1.3, 1.4, 1.5, 1.6}, + sd::DataType::FLOAT32); - NDArray A('c', { 3, 4, 1, 1 }, { 0.4, 3, 5, 9, 23, 0.12, 8, 9, 0.1, 0, 124, 3 }, sd::DataType::FLOAT32); - NDArray B('c', { 2, 4, 1, 1 }, { 4, 13, .5, 19, 2.3, 1.2, 18, .9 }, sd::DataType::FLOAT32); - NDArray dLdC('c', { 3, 1, 2, 1 }, { 1.1,1.2,1.3,1.4,1.5,1.6 }, sd::DataType::FLOAT32); + NDArray dLdA('c', {3, 4, 1, 1}, + {7.16, 15.74, 22.15, 21.98, 8.42, 18.58, 25.85, 25.96, 9.68, + 21.42, 29.55, 29.94}, + sd::DataType::FLOAT32); + NDArray dLdB('c', {2, 4, 1, 1}, + {30.49, 3.456, 201.9, 26.1, 32.84, 3.768, 215.6, 28.2}, + sd::DataType::FLOAT32); - NDArray dLdA('c', { 3, 4, 1, 1 }, { 7.16, 15.74, 22.15, 21.98, 8.42, 18.58, 25.85, 25.96, 9.68, 21.42, 29.55, 29.94 }, sd::DataType::FLOAT32); - NDArray dLdB('c', { 2, 4, 1, 1 }, { 30.49, 3.456, 201.9, 26.1, 32.84, 3.768, 215.6, 28.2 }, sd::DataType::FLOAT32); + sd::ops::tensormmul_bp op_bp; - sd::ops::tensormmul_bp op_bp; + auto resultsBP = op_bp.evaluate({&A, &B, &dLdC}, {}, {2, 1, 2, 2, 1, 2}, {}); - auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 2,1,2, 2,1,2 }, {}); + ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status()); - ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status()); + auto dLdAbp = resultsBP.at(0); + auto dLdBbp = resultsBP.at(1); - auto* dLdAbp = resultsBP.at(0); - auto* dLdBbp = resultsBP.at(1); - - ASSERT_TRUE(dLdA.isSameShape(*dLdAbp)); - ASSERT_TRUE(dLdA.equalsTo(*dLdAbp)); - - ASSERT_TRUE(dLdB.isSameShape(*dLdBbp)); - ASSERT_TRUE(dLdB.equalsTo(*dLdBbp)); + ASSERT_TRUE(dLdA.isSameShape(dLdAbp)); + ASSERT_TRUE(dLdA.equalsTo(dLdAbp)); + ASSERT_TRUE(dLdB.isSameShape(dLdBbp)); + ASSERT_TRUE(dLdB.equalsTo(dLdBbp)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests15, TestTensorMmul_BP6) { + NDArray A('c', {2, 2, 2}, {2, 2, 2, 2, 2, 2, 2, 2}, sd::DataType::FLOAT32); + NDArray B('c', {2, 2, 2}, {3, 3, 3, 3, 3, 3, 3, 3}, sd::DataType::FLOAT32); - NDArray A('c', { 2, 2, 2 }, { 2,2, 2,2, 2,2, 2,2 }, sd::DataType::FLOAT32); - NDArray B('c', { 2, 2, 2 }, { 3,3, 3,3, 3,3, 3,3 }, sd::DataType::FLOAT32); - - auto dLdC = NDArrayFactory::create(1.f); + auto dLdC = NDArrayFactory::create(1.f); - sd::ops::tensormmul_bp op_bp; - auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 3,0,1,2, 3,0,1,2 }, {}); + sd::ops::tensormmul_bp op_bp; + auto resultsBP = + op_bp.evaluate({&A, &B, &dLdC}, {}, {3, 0, 1, 2, 3, 0, 1, 2}, {}); - ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status()); + ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status()); - auto* dLdAbp = resultsBP.at(0); - auto* dLdBbp = resultsBP.at(1); + auto dLdAbp = resultsBP.at(0); + auto dLdBbp = resultsBP.at(1); - ASSERT_TRUE(B.isSameShape(*dLdAbp)); - ASSERT_TRUE(B.equalsTo(*dLdAbp)); - - ASSERT_TRUE(A.isSameShape(*dLdBbp)); - ASSERT_TRUE(A.equalsTo(*dLdBbp)); + ASSERT_TRUE(B.isSameShape(dLdAbp)); + ASSERT_TRUE(B.equalsTo(dLdAbp)); + ASSERT_TRUE(A.isSameShape(dLdBbp)); + ASSERT_TRUE(A.equalsTo(dLdBbp)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests15, TestTensorMmul_BP7) { + NDArray A('c', {3, 4, 1}, {0.4, 3, 5, 9, 23, 0.12, 8, 9, 0.1, 0, 124, 3}, + sd::DataType::FLOAT32); + NDArray B('c', {2, 4, 1}, {4, 13, .5, 19, 2.3, 1.2, 18, .9}, + sd::DataType::FLOAT32); + NDArray dLdC('c', {3, 1, 2, 1}, {1.1, 1.2, 1.3, 1.4, 1.5, 1.6}, + sd::DataType::FLOAT32); - NDArray A('c', { 3, 4, 1 }, { 0.4, 3, 5, 9, 23, 0.12, 8, 9, 0.1, 0, 124, 3 }, sd::DataType::FLOAT32); - NDArray B('c', { 2, 4, 1 }, { 4, 13, .5, 19, 2.3, 1.2, 18, .9 }, sd::DataType::FLOAT32); - NDArray dLdC('c', { 3, 1, 2, 1 }, { 1.1, 1.2, 1.3, 1.4, 1.5, 1.6 }, sd::DataType::FLOAT32); - - NDArray dLdA('c', { 3, 4, 1 }, { 7.16, 15.74, 22.15, 21.98, 8.42, 18.58, 25.85, 25.96, 9.68, 21.42, 29.55, 29.94 }, sd::DataType::FLOAT32); - NDArray dLdB('c', { 2, 4, 1 }, { 30.49, 3.456, 201.9, 26.1, 32.84, 3.768, 215.6, 28.2 }, sd::DataType::FLOAT32); + NDArray dLdA('c', {3, 4, 1}, + {7.16, 15.74, 22.15, 21.98, 8.42, 18.58, 25.85, 25.96, 9.68, + 21.42, 29.55, 29.94}, + sd::DataType::FLOAT32); + NDArray dLdB('c', {2, 4, 1}, + {30.49, 3.456, 201.9, 26.1, 32.84, 3.768, 215.6, 28.2}, + sd::DataType::FLOAT32); - sd::ops::tensormmul_bp op_bp; + sd::ops::tensormmul_bp op_bp; - auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 1,1, 1,1 }, {}); + auto resultsBP = op_bp.evaluate({&A, &B, &dLdC}, {}, {1, 1, 1, 1}, {}); - ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status()); - auto* dLdAbp = resultsBP.at(0); - auto* dLdBbp = resultsBP.at(1); + ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status()); + auto dLdAbp = resultsBP.at(0); + auto dLdBbp = resultsBP.at(1); - ASSERT_TRUE(dLdA.isSameShape(*dLdAbp)); - ASSERT_TRUE(dLdA.equalsTo(*dLdAbp)); - - ASSERT_TRUE(dLdB.isSameShape(*dLdBbp)); - ASSERT_TRUE(dLdB.equalsTo(*dLdBbp)); + ASSERT_TRUE(dLdA.isSameShape(dLdAbp)); + ASSERT_TRUE(dLdA.equalsTo(dLdAbp)); + ASSERT_TRUE(dLdB.isSameShape(dLdBbp)); + ASSERT_TRUE(dLdB.equalsTo(dLdBbp)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests15, TestTensorMmul_BP8) { + NDArray A('c', {1, 1, 4, 3}, {0.4, 3, 5, 9, 23, 0.12, 8, 9, 0.1, 0, 124, 3}, + sd::DataType::FLOAT32); + NDArray B('c', {1, 1, 4, 2}, {4, 13, .5, 19, 2.3, 1.2, 18, .9}, + sd::DataType::FLOAT32); + NDArray dLdC('c', {3, 2}, {1.1, 1.2, 1.3, 1.4, 1.5, 1.6}, + sd::DataType::FLOAT32); - NDArray A('c', { 1, 1, 4, 3 }, { 0.4, 3, 5, 9, 23, 0.12, 8, 9, 0.1, 0, 124, 3 }, sd::DataType::FLOAT32); - NDArray B('c', { 1, 1, 4, 2 }, { 4, 13, .5, 19, 2.3, 1.2, 18, .9 }, sd::DataType::FLOAT32); - NDArray dLdC('c', { 3, 2 }, { 1.1,1.2,1.3,1.4,1.5,1.6 }, sd::DataType::FLOAT32); - - NDArray dLdA('c', { 1, 1, 4, 3 }, { 20., 23.4, 26.8, 23.35, 27.25, 31.15, 3.97, 4.67, 5.37, 20.88, 24.66, 28.44 }, sd::DataType::FLOAT32); - NDArray dLdB('c', { 1, 1, 4, 2 }, { 11.84, 12.68, 39.98, 43.192, 20.65, 22.36, 165.7, 178.4 }, sd::DataType::FLOAT32); - - sd::ops::tensormmul_bp op_bp; + NDArray dLdA('c', {1, 1, 4, 3}, + {20., 23.4, 26.8, 23.35, 27.25, 31.15, 3.97, 4.67, 5.37, 20.88, + 24.66, 28.44}, + sd::DataType::FLOAT32); + NDArray dLdB('c', {1, 1, 4, 2}, + {11.84, 12.68, 39.98, 43.192, 20.65, 22.36, 165.7, 178.4}, + sd::DataType::FLOAT32); - auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 3,0,1,2, 3,0,1,2 }, {}); + sd::ops::tensormmul_bp op_bp; - ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status()); + auto resultsBP = + op_bp.evaluate({&A, &B, &dLdC}, {}, {3, 0, 1, 2, 3, 0, 1, 2}, {}); - auto* dLdAbp = resultsBP.at(0); - auto* dLdBbp = resultsBP.at(1); + ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status()); - ASSERT_TRUE(dLdA.isSameShape(*dLdAbp)); - ASSERT_TRUE(dLdA.equalsTo(*dLdAbp)); + auto dLdAbp = resultsBP.at(0); + auto dLdBbp = resultsBP.at(1); - ASSERT_TRUE(dLdB.isSameShape(*dLdBbp)); - ASSERT_TRUE(dLdB.equalsTo(*dLdBbp)); + ASSERT_TRUE(dLdA.isSameShape(dLdAbp)); + ASSERT_TRUE(dLdA.equalsTo(dLdAbp)); + ASSERT_TRUE(dLdB.isSameShape(dLdBbp)); + ASSERT_TRUE(dLdB.equalsTo(dLdBbp)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests15, TestTensorMmul_BP9) { + NDArray A('c', {3, 2, 2, 1}, + {2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9, 3., 3.1, 3.2}, + sd::DataType::FLOAT32); + NDArray B('c', {4, 2, 2, 1}, + {3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8, 3.9, 4., 4.1, 4.2, 4.3, + 4.4, 4.5, 4.6}, + sd::DataType::FLOAT32); + NDArray dLdC('c', {3, 1, 4, 1}, + {.1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.2}, + sd::DataType::FLOAT32); - NDArray A('c', { 3, 2, 2, 1 }, { 2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9, 3., 3.1, 3.2 }, sd::DataType::FLOAT32); - NDArray B('c', { 4, 2, 2 ,1 }, { 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8, 3.9, 4., 4.1, 4.2, 4.3, 4.4, 4.5, 4.6 }, sd::DataType::FLOAT32); - NDArray dLdC('c', { 3, 1, 4, 1 }, { .1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.2 }, sd::DataType::FLOAT32); + NDArray dA( + 'c', {3, 2, 2, 1}, + {3.9, 4., 4.1, 4.2, 9.82, 10.08, 10.34, 10.6, 15.74, 16.16, 16.58, 17.}, + sd::DataType::FLOAT32); + NDArray dB('c', {4, 2, 2, 1}, + {4.07, 4.22, 4.37, 4.52, 4.82, 5., 5.18, 5.36, 5.57, 5.78, 5.99, + 6.2, 6.32, 6.56, 6.8, 7.04}, + sd::DataType::FLOAT32); - NDArray dA('c', { 3, 2, 2, 1 }, { 3.9, 4., 4.1, 4.2, 9.82, 10.08, 10.34, 10.6, 15.74, 16.16, 16.58, 17. }, sd::DataType::FLOAT32); - NDArray dB('c', { 4, 2, 2, 1 }, { 4.07, 4.22, 4.37, 4.52, 4.82, 5., 5.18, 5.36, 5.57, 5.78, 5.99, 6.2, 6.32, 6.56, 6.8, 7.04 }, sd::DataType::FLOAT32); + sd::ops::tensormmul_bp op_bp; - sd::ops::tensormmul_bp op_bp; + auto resultsBP = op_bp.evaluate({&A, &B, &dLdC}, {}, {2, 1, 2, 2, 1, 2}, {}); - auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 2,1,2, 2,1,2 }, {}); + ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status()); - ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status()); + auto dLdAbp = resultsBP.at(0); + auto dLdBbp = resultsBP.at(1); - auto* dLdAbp = resultsBP.at(0); - auto* dLdBbp = resultsBP.at(1); - - ASSERT_TRUE(dA.isSameShape(*dLdAbp)); - ASSERT_TRUE(dA.equalsTo(*dLdAbp)); - - ASSERT_TRUE(dB.isSameShape(*dLdBbp)); - ASSERT_TRUE(dB.equalsTo(*dLdBbp)); + ASSERT_TRUE(dA.isSameShape(dLdAbp)); + ASSERT_TRUE(dA.equalsTo(dLdAbp)); + ASSERT_TRUE(dB.isSameShape(dLdBbp)); + ASSERT_TRUE(dB.equalsTo(dLdBbp)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests15, TestTensorMmul_BP10) { + NDArray A('c', {1, 2, 2, 3}, + {2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9, 3., 3.1, 3.2}, + sd::DataType::FLOAT32); + NDArray B('c', {1, 2, 2, 4}, + {3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8, 3.9, 4., 4.1, 4.2, 4.3, + 4.4, 4.5, 4.6}, + sd::DataType::FLOAT32); + NDArray dLdC('c', {1, 3, 1, 4}, + {.1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.2}, + sd::DataType::FLOAT32); - NDArray A('c', { 1, 2, 2, 3 }, { 2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9, 3., 3.1, 3.2 }, sd::DataType::FLOAT32); - NDArray B('c', { 1, 2, 2 ,4 }, { 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8, 3.9, 4., 4.1, 4.2, 4.3, 4.4, 4.5, 4.6 }, sd::DataType::FLOAT32); - NDArray dLdC('c', { 1, 3, 1, 4 }, { .1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.2 }, sd::DataType::FLOAT32); - + NDArray dA( + 'c', {1, 2, 2, 3}, + {3.3, 8.5, 13.7, 3.7, 9.54, 15.38, 4.1, 10.58, 17.06, 4.5, 11.62, 18.74}, + sd::DataType::FLOAT32); + NDArray dB('c', {1, 2, 2, 4}, + {3.38, 4.04, 4.7, 5.36, 3.83, 4.58, 5.33, 6.08, 4.28, 5.12, 5.96, + 6.8, 4.73, 5.66, 6.59, 7.52}, + sd::DataType::FLOAT32); - NDArray dA('c', { 1, 2, 2, 3 }, { 3.3, 8.5, 13.7, 3.7, 9.54, 15.38, 4.1, 10.58, 17.06, 4.5, 11.62, 18.74 }, sd::DataType::FLOAT32); - NDArray dB('c', { 1, 2, 2, 4 }, { 3.38, 4.04, 4.7, 5.36, 3.83, 4.58, 5.33, 6.08, 4.28, 5.12, 5.96, 6.8, 4.73, 5.66, 6.59, 7.52 }, sd::DataType::FLOAT32); + sd::ops::tensormmul_bp op_bp; - sd::ops::tensormmul_bp op_bp; + auto resultsBP = op_bp.evaluate({&A, &B, &dLdC}, {}, {2, 1, 2, 2, 1, 2}, {}); - auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 2,1,2, 2,1,2 }, {}); + ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status()); - ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status()); + auto dLdAbp = resultsBP.at(0); + auto dLdBbp = resultsBP.at(1); - auto* dLdAbp = resultsBP.at(0); - auto* dLdBbp = resultsBP.at(1); - - ASSERT_TRUE(dA.isSameShape(*dLdAbp)); - ASSERT_TRUE(dA.equalsTo(*dLdAbp)); - - ASSERT_TRUE(dB.isSameShape(*dLdBbp)); - ASSERT_TRUE(dB.equalsTo(*dLdBbp)); + ASSERT_TRUE(dA.isSameShape(dLdAbp)); + ASSERT_TRUE(dA.equalsTo(dLdAbp)); + ASSERT_TRUE(dB.isSameShape(dLdBbp)); + ASSERT_TRUE(dB.equalsTo(dLdBbp)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests15, TestTensorMmul_BP11) { + NDArray A('c', {2, 2, 3}, + {2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9, 3., 3.1, 3.2}, + sd::DataType::FLOAT32); + NDArray B('c', {2, 2, 4}, + {3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8, 3.9, 4., 4.1, 4.2, 4.3, + 4.4, 4.5, 4.6}, + sd::DataType::FLOAT32); + NDArray dLdC('c', {3, 4}, {.1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.2}, + sd::DataType::FLOAT32); - NDArray A('c', { 2, 2, 3 }, { 2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9, 3., 3.1, 3.2 }, sd::DataType::FLOAT32); - NDArray B('c', { 2, 2 ,4 }, { 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8, 3.9, 4., 4.1, 4.2, 4.3, 4.4, 4.5, 4.6 }, sd::DataType::FLOAT32); - NDArray dLdC('c', { 3, 4 }, { .1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.2 }, sd::DataType::FLOAT32); - + NDArray dA( + 'c', {2, 2, 3}, + {3.3, 8.5, 13.7, 3.7, 9.54, 15.38, 4.1, 10.58, 17.06, 4.5, 11.62, 18.74}, + sd::DataType::FLOAT32); + NDArray dB('c', {2, 2, 4}, + {3.38, 4.04, 4.7, 5.36, 3.83, 4.58, 5.33, 6.08, 4.28, 5.12, 5.96, + 6.8, 4.73, 5.66, 6.59, 7.52}, + sd::DataType::FLOAT32); - NDArray dA('c', { 2, 2, 3 }, { 3.3, 8.5, 13.7, 3.7, 9.54, 15.38, 4.1, 10.58, 17.06, 4.5, 11.62, 18.74 }, sd::DataType::FLOAT32); - NDArray dB('c', { 2, 2, 4 }, { 3.38, 4.04, 4.7, 5.36, 3.83, 4.58, 5.33, 6.08, 4.28, 5.12, 5.96, 6.8, 4.73, 5.66, 6.59, 7.52 }, sd::DataType::FLOAT32); + sd::ops::tensormmul_bp op_bp; - sd::ops::tensormmul_bp op_bp; + auto resultsBP = op_bp.evaluate({&A, &B, &dLdC}, {}, {2, 0, 1, 2, 0, 1}, {}); - auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 2,0,1, 2,0,1 }, {}); + ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status()); - ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status()); + auto dLdAbp = resultsBP.at(0); + auto dLdBbp = resultsBP.at(1); - auto* dLdAbp = resultsBP.at(0); - auto* dLdBbp = resultsBP.at(1); - - ASSERT_TRUE(dA.isSameShape(*dLdAbp)); - ASSERT_TRUE(dA.equalsTo(*dLdAbp)); - - ASSERT_TRUE(dB.isSameShape(*dLdBbp)); - ASSERT_TRUE(dB.equalsTo(*dLdBbp)); + ASSERT_TRUE(dA.isSameShape(dLdAbp)); + ASSERT_TRUE(dA.equalsTo(dLdAbp)); + ASSERT_TRUE(dB.isSameShape(dLdBbp)); + ASSERT_TRUE(dB.equalsTo(dLdBbp)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests15, TestTensorMmul_BP12) { - - NDArray A('c', { 2, 2, 3 }, { 2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9, 3., 3.1, 3.2 }, sd::DataType::FLOAT32); - NDArray B('c', { 2, 2 ,3 }, { 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8, 3.9, 4., 4.1, 4.2 }, sd::DataType::FLOAT32); - NDArray dLdC('c', { 2, 3, 2, 3 }, { .1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.2, - 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2., 2.1, 2.2, 2.3, 2.4, - 2.5, 2.6, 2.7, 2.8, 2.9, 3., 3.1, 3.2, 3.3, 3.4, 3.5, 3.6 }, sd::DataType::FLOAT32); - - NDArray dA('c', { 2, 2, 3 }, { 7.66, 20.26, 32.86, 8.29, 21.97, 35.65, 45.46, 58.06, 70.66, 49.33, 63.01, 76.69 }, sd::DataType::FLOAT32); - NDArray dB('c', { 2, 2, 3 }, { 25.86, 27.36, 28.86, 28.74, 30.42, 32.1, 30.36, 31.86, 33.36, 33.78, 35.46, 37.14 }, sd::DataType::FLOAT32); - - sd::ops::tensormmul_bp op_bp; - - auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 1,1, 1,1 }, {}); - - ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status()); - - auto* dLdAbp = resultsBP.at(0); - auto* dLdBbp = resultsBP.at(1); - - ASSERT_TRUE(dA.isSameShape(*dLdAbp)); - ASSERT_TRUE(dA.equalsTo(*dLdAbp)); - - ASSERT_TRUE(dB.isSameShape(*dLdBbp)); - ASSERT_TRUE(dB.equalsTo(*dLdBbp)); - + NDArray A('c', {2, 2, 3}, + {2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9, 3., 3.1, 3.2}, + sd::DataType::FLOAT32); + NDArray B('c', {2, 2, 3}, + {3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8, 3.9, 4., 4.1, 4.2}, + sd::DataType::FLOAT32); + NDArray dLdC('c', {2, 3, 2, 3}, + {.1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.2, + 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2., 2.1, 2.2, 2.3, 2.4, + 2.5, 2.6, 2.7, 2.8, 2.9, 3., 3.1, 3.2, 3.3, 3.4, 3.5, 3.6}, + sd::DataType::FLOAT32); + + NDArray dA('c', {2, 2, 3}, + {7.66, 20.26, 32.86, 8.29, 21.97, 35.65, 45.46, 58.06, 70.66, + 49.33, 63.01, 76.69}, + sd::DataType::FLOAT32); + NDArray dB('c', {2, 2, 3}, + {25.86, 27.36, 28.86, 28.74, 30.42, 32.1, 30.36, 31.86, 33.36, + 33.78, 35.46, 37.14}, + sd::DataType::FLOAT32); + + sd::ops::tensormmul_bp op_bp; + + auto resultsBP = op_bp.evaluate({&A, &B, &dLdC}, {}, {1, 1, 1, 1}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status()); + + auto dLdAbp = resultsBP.at(0); + auto dLdBbp = resultsBP.at(1); + + ASSERT_TRUE(dA.isSameShape(dLdAbp)); + ASSERT_TRUE(dA.equalsTo(dLdAbp)); + + ASSERT_TRUE(dB.isSameShape(dLdBbp)); + ASSERT_TRUE(dB.equalsTo(dLdBbp)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests15, TestTensorMmul_BP13) { - - NDArray A('c', { 3, 2, 2 }, { 2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9, 3., 3.1, 3.2 }, sd::DataType::DOUBLE); - NDArray B('c', { 3, 2, 2 }, { 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8, 3.9, 4., 4.1, 4.2 }, sd::DataType::DOUBLE); - NDArray dLdC('c', { 3, 2, 3, 2 }, { .1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.2, - 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2., 2.1, 2.2, 2.3, 2.4, - 2.5, 2.6, 2.7, 2.8, 2.9, 3., 3.1, 3.2, 3.3, 3.4, 3.5, 3.6 }, sd::DataType::DOUBLE); - - NDArray dA('c', { 3, 2, 2 }, { 7.79, 20.57, 8.21, 21.71, 33.35, 46.13, 35.21, 48.71, 58.91, 71.69, 62.21, 75.71 }, sd::DataType::DOUBLE); - NDArray dB('c', { 3, 2, 2 }, { 26.49, 28.02, 28.41, 30.06, 29.55, 31.08, 31.71, 33.36, 32.61, 34.14, 35.01, 36.66 }, sd::DataType::DOUBLE); - - sd::ops::tensormmul_bp op_bp; - - auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 1,1, 1,1 }, {}); - - ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status()); - - auto* dLdAbp = resultsBP.at(0); - auto* dLdBbp = resultsBP.at(1); - - ASSERT_TRUE(dA.isSameShape(*dLdAbp)); - ASSERT_TRUE(dA.equalsTo(*dLdAbp)); - - ASSERT_TRUE(dB.isSameShape(*dLdBbp)); - ASSERT_TRUE(dB.equalsTo(*dLdBbp)); - + NDArray A('c', {3, 2, 2}, + {2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9, 3., 3.1, 3.2}, + sd::DataType::DOUBLE); + NDArray B('c', {3, 2, 2}, + {3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8, 3.9, 4., 4.1, 4.2}, + sd::DataType::DOUBLE); + NDArray dLdC('c', {3, 2, 3, 2}, + {.1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.2, + 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2., 2.1, 2.2, 2.3, 2.4, + 2.5, 2.6, 2.7, 2.8, 2.9, 3., 3.1, 3.2, 3.3, 3.4, 3.5, 3.6}, + sd::DataType::DOUBLE); + + NDArray dA('c', {3, 2, 2}, + {7.79, 20.57, 8.21, 21.71, 33.35, 46.13, 35.21, 48.71, 58.91, + 71.69, 62.21, 75.71}, + sd::DataType::DOUBLE); + NDArray dB('c', {3, 2, 2}, + {26.49, 28.02, 28.41, 30.06, 29.55, 31.08, 31.71, 33.36, 32.61, + 34.14, 35.01, 36.66}, + sd::DataType::DOUBLE); + + sd::ops::tensormmul_bp op_bp; + + auto resultsBP = op_bp.evaluate({&A, &B, &dLdC}, {}, {1, 1, 1, 1}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status()); + + auto dLdAbp = resultsBP.at(0); + auto dLdBbp = resultsBP.at(1); + + ASSERT_TRUE(dA.isSameShape(dLdAbp)); + ASSERT_TRUE(dA.equalsTo(dLdAbp)); + + ASSERT_TRUE(dB.isSameShape(dLdBbp)); + ASSERT_TRUE(dB.equalsTo(dLdBbp)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests15, TestTensorMmul_BP14) { + NDArray A('c', {2, 2, 2, 2}, + {2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9, 3., 3.1, 3.2, 3.3, + 3.4, 3.5, 3.6}, + sd::DataType::DOUBLE); - NDArray A('c', { 2, 2, 2, 2 }, { 2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9, 3., 3.1, 3.2, 3.3, 3.4, 3.5, 3.6 }, sd::DataType::DOUBLE); - - NDArray B('c', { 2, 2, 2, 2 }, { 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8, 3.9, 4., 4.1, 4.2, 4.3, 4.4, 4.5, 4.6 }, sd::DataType::DOUBLE); + NDArray B('c', {2, 2, 2, 2}, + {3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8, 3.9, 4., 4.1, 4.2, 4.3, + 4.4, 4.5, 4.6}, + sd::DataType::DOUBLE); - NDArray dLdC('c', { 2, 2, 2, 2, 2, 2 }, { .1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.2, - 1.3, 1.4, 1.5, 1.6, .1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.2, - 1.3, 1.4, 1.5, 1.6, .1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.2, - 1.3, 1.4, 1.5, 1.6, .1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.2, - 1.3, 1.4, 1.5, 1.6 }, sd::DataType::DOUBLE); + NDArray dLdC( + 'c', {2, 2, 2, 2, 2, 2}, + {.1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, + .1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, + .1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, + .1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6}, + sd::DataType::DOUBLE); - NDArray dA('c', { 2, 2, 2, 2 }, { 13.88, 37.24, 13.88, 37.24, 15.32, 41.24, 15.32, 41.24, 13.88, 37.24, 13.88, 37.24, 15.32, 41.24, 15.32, 41.24 }, sd::DataType::DOUBLE); - NDArray dB('c', { 2, 2, 2, 2 }, { 10.76, 12.88, 15., 17.12, 12.36, 14.8, 17.24, 19.68, 19.24, 21.36, 23.48, 25.6, 22.12, 24.56, 27., 29.44 }, sd::DataType::DOUBLE); + NDArray dA('c', {2, 2, 2, 2}, + {13.88, 37.24, 13.88, 37.24, 15.32, 41.24, 15.32, 41.24, 13.88, + 37.24, 13.88, 37.24, 15.32, 41.24, 15.32, 41.24}, + sd::DataType::DOUBLE); + NDArray dB('c', {2, 2, 2, 2}, + {10.76, 12.88, 15., 17.12, 12.36, 14.8, 17.24, 19.68, 19.24, 21.36, + 23.48, 25.6, 22.12, 24.56, 27., 29.44}, + sd::DataType::DOUBLE); - sd::ops::tensormmul_bp op_bp; + sd::ops::tensormmul_bp op_bp; - auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 1,1, 1,1 }, {}); + auto resultsBP = op_bp.evaluate({&A, &B, &dLdC}, {}, {1, 1, 1, 1}, {}); - ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status()); + ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status()); - auto* dLdAbp = resultsBP.at(0); - auto* dLdBbp = resultsBP.at(1); + auto dLdAbp = resultsBP.at(0); + auto dLdBbp = resultsBP.at(1); - ASSERT_TRUE(dA.isSameShape(*dLdAbp)); - ASSERT_TRUE(dA.equalsTo(*dLdAbp)); - - ASSERT_TRUE(dB.isSameShape(*dLdBbp)); - ASSERT_TRUE(dB.equalsTo(*dLdBbp)); + ASSERT_TRUE(dA.isSameShape(dLdAbp)); + ASSERT_TRUE(dA.equalsTo(dLdAbp)); + ASSERT_TRUE(dB.isSameShape(dLdBbp)); + ASSERT_TRUE(dB.equalsTo(dLdBbp)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests15, TestTensorMmul_BP15) { + NDArray A('c', {2, 2, 3}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.}, + sd::DataType::FLOAT32); + NDArray B('f', {2, 2, 3}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.}, + sd::DataType::FLOAT32); - NDArray A('c', { 2, 2, 3 }, { 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12. }, sd::DataType::FLOAT32); - NDArray B('f', { 2, 2, 3 }, { 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12. }, sd::DataType::FLOAT32); - - NDArray dLdC('f', { 2, 2 }, { 23.0, 24.44, 2.0, 26. }, sd::DataType::FLOAT32); - - NDArray dA('c', { 2, 2, 3 }, { 27., 127., 227., 77., 177., 277., 76.44, 278.20001, 479.96002, 177.32, 379.08001, 580.839966 }, sd::DataType::FLOAT32); - NDArray dB('f', { 2, 2, 3 }, { 194.08, 184., 336.4, 268., 241.52, 212., 383.839996, 296., 288.96002, 240., 431.27999, 324. }, sd::DataType::FLOAT32); + NDArray dLdC('f', {2, 2}, {23.0, 24.44, 2.0, 26.}, sd::DataType::FLOAT32); - sd::ops::tensormmul_bp op; - auto results = op.evaluate({ &A, &B, &dLdC }, {}, { 2,1,2,2,1,2 }); + NDArray dA('c', {2, 2, 3}, + {27., 127., 227., 77., 177., 277., 76.44, 278.20001, 479.96002, + 177.32, 379.08001, 580.839966}, + sd::DataType::FLOAT32); + NDArray dB('f', {2, 2, 3}, + {194.08, 184., 336.4, 268., 241.52, 212., 383.839996, 296., + 288.96002, 240., 431.27999, 324.}, + sd::DataType::FLOAT32); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::tensormmul_bp op; + auto results = op.evaluate({&A, &B, &dLdC}, {}, {2, 1, 2, 2, 1, 2}); - auto* dLdA = results.at(0); - auto* dLdB = results.at(1); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(dA.isSameShape(*dLdA)); - ASSERT_TRUE(dA.equalsTo(*dLdA)); + auto dLdA = results.at(0); + auto dLdB = results.at(1); - ASSERT_TRUE(dB.isSameShape(*dLdB)); - ASSERT_TRUE(dB.equalsTo(*dLdB)); + ASSERT_TRUE(dA.isSameShape(dLdA)); + ASSERT_TRUE(dA.equalsTo(dLdA)); + ASSERT_TRUE(dB.isSameShape(dLdB)); + ASSERT_TRUE(dB.equalsTo(dLdB)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests15, TestTensorMmul_BP16) { + NDArray A('f', {2, 2, 3}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.}, + sd::DataType::DOUBLE); + NDArray B('c', {2, 2, 3}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.}, + sd::DataType::DOUBLE); - NDArray A('f', { 2, 2, 3 }, { 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12. }, sd::DataType::DOUBLE); - NDArray B('c', { 2, 2, 3 }, { 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12. }, sd::DataType::DOUBLE); - - NDArray dLdC('c', { 2, 2 }, sd::DataType::DOUBLE); + NDArray dLdC('c', {2, 2}, sd::DataType::DOUBLE); - const OpArgsHolder argsHolderFF({ &A, &B }, {}, { 2,1,2, 2,1,2 }); - const OpArgsHolder argsHolderBP({ &A, &B, &dLdC }, {}, { 2,1,2, 2,1,2 }); + const OpArgsHolder argsHolderFF({&A, &B}, {}, {2, 1, 2, 2, 1, 2}); + const OpArgsHolder argsHolderBP({&A, &B, &dLdC}, {}, {2, 1, 2, 2, 1, 2}); - sd::ops::tensormmul op; - sd::ops::tensormmul_bp op_bp; + sd::ops::tensormmul op; + sd::ops::tensormmul_bp op_bp; - const bool isGradCorrect = GradCheck::checkGrad(op, op_bp, argsHolderFF, argsHolderBP, {1,0}); - ASSERT_TRUE(isGradCorrect); + const bool isGradCorrect = + GradCheck::checkGrad(op, op_bp, argsHolderFF, argsHolderBP, {1, 0}); + ASSERT_TRUE(isGradCorrect); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests15, TestTensorMmul_BP17) { + NDArray A('f', {2, 2, 3}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.}, + sd::DataType::DOUBLE); + NDArray B('f', {2, 2, 3}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.}, + sd::DataType::DOUBLE); - NDArray A('f', { 2, 2, 3 }, { 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12. }, sd::DataType::DOUBLE); - NDArray B('f', { 2, 2, 3 }, { 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12. }, sd::DataType::DOUBLE); - - NDArray dLdC('c', { 2, 2 }, sd::DataType::DOUBLE); + NDArray dLdC('c', {2, 2}, sd::DataType::DOUBLE); - const OpArgsHolder argsHolderFF({ &A, &B }, {}, { 2,1,2, 2,1,2 }); - const OpArgsHolder argsHolderBP({ &A, &B, &dLdC }, {}, { 2,1,2, 2,1,2 }); + const OpArgsHolder argsHolderFF({&A, &B}, {}, {2, 1, 2, 2, 1, 2}); + const OpArgsHolder argsHolderBP({&A, &B, &dLdC}, {}, {2, 1, 2, 2, 1, 2}); - sd::ops::tensormmul op; - sd::ops::tensormmul_bp op_bp; + sd::ops::tensormmul op; + sd::ops::tensormmul_bp op_bp; - const bool isGradCorrect = GradCheck::checkGrad(op, op_bp, argsHolderFF, argsHolderBP, { 1,0 }); - ASSERT_TRUE(isGradCorrect); + const bool isGradCorrect = + GradCheck::checkGrad(op, op_bp, argsHolderFF, argsHolderBP, {1, 0}); + ASSERT_TRUE(isGradCorrect); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests15, gru_1) { + const int sL = 3; + const int bS = 2; + const int nIn = 5; + const int nOut = 4; - const int sL = 3; - const int bS = 2; - const int nIn = 5; - const int nOut = 4; + NDArray x('c', {sL, bS, nIn}, + {0.5, 1., 1.5, 2., 2.5, 3., 3.5, 4., 4.5, 5., + 5.5, 6., 6.5, 7., 7.5, 8., 8.5, 9., 9.5, 10., + 10.5, 11., 11.5, 12., 12.5, 13., 13.5, 14., 14.5, 15.}, + sd::DataType::FLOAT32); + NDArray hI('c', {bS, nOut}, {-3, -2, -1, 0, 1, 2, 3, 4}, + sd::DataType::FLOAT32); + NDArray Wx('c', {nIn, 3 * nOut}, sd::DataType::FLOAT32); + NDArray Wh('c', {nOut, 3 * nOut}, sd::DataType::FLOAT32); + NDArray b('c', {3 * nOut}, sd::DataType::FLOAT32); + NDArray expH( + 'c', {sL, bS, nOut}, + {-1.681847, -1.062565, -0.443283, 0.175998, 0.837823, 1.488041, + 2.13826, 2.788478, -0.888747, -0.491826, -0.094907, 0.302014, + 0.751355, 1.182715, 1.614075, 2.045434, -0.388876, -0.126716, + 0.135444, 0.397604, 0.710558, 1.002922, 1.295287, 1.587651}, + sd::DataType::FLOAT32); - NDArray x('c', {sL, bS, nIn}, {0.5, 1. , 1.5, 2. , 2.5, 3. , 3.5, 4. , 4.5, 5. , 5.5, 6. , 6.5, 7. , 7.5, 8. , 8.5, 9. , 9.5, 10. , 10.5, 11. , 11.5, 12. , 12.5, 13. , 13.5, 14. , 14.5, 15.}, sd::DataType::FLOAT32); - NDArray hI('c', {bS, nOut}, {-3,-2,-1,0,1,2,3,4}, sd::DataType::FLOAT32); - NDArray Wx('c', {nIn, 3*nOut}, sd::DataType::FLOAT32); - NDArray Wh('c', {nOut, 3*nOut}, sd::DataType::FLOAT32); - NDArray b('c', {3*nOut}, sd::DataType::FLOAT32); + Wx = 0.003; + Wh = 0.006; + b = 0.5; - NDArray expH('c', {sL, bS, nOut}, {-1.681847, -1.062565, -0.443283, 0.175998,0.837823, 1.488041, 2.13826 , 2.788478, -0.888747, -0.491826, -0.094907, 0.302014, - 0.751355, 1.182715, 1.614075, 2.045434, -0.388876, -0.126716, 0.135444, 0.397604,0.710558, 1.002922, 1.295287, 1.587651}, sd::DataType::FLOAT32); + NDArray dLdC('c', {2, 2}, sd::DataType::DOUBLE); - Wx = 0.003; - Wh = 0.006; - b = 0.5; + sd::ops::gru op; + auto results = op.evaluate({&x, &hI, &Wx, &Wh, &b}, {}, {}); - NDArray dLdC('c', { 2, 2 }, sd::DataType::DOUBLE); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - sd::ops::gru op; - auto results = op.evaluate({&x, &hI, &Wx, &Wh, &b}, {}, {}); + auto h = results.at(0); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto* h = results.at(0); - - ASSERT_TRUE(expH.isSameShape(h)); - ASSERT_TRUE(expH.equalsTo(h)); + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests15, sqrtm_1) { - - NDArray x1('c', {1,1}, {4.}, sd::DataType::DOUBLE); - NDArray x2('c', {2,2}, {1.3,2,0.3,.5}, sd::DataType::DOUBLE); - NDArray x3('c', {3,3}, {0.5 ,-0.4 ,1.2 ,-2.8 ,-0.2 ,-2.1 ,-2.4 ,-2.0 ,1.1}, sd::DataType::DOUBLE); - NDArray x4('c', {4,4}, {0.33 ,-7.25 ,1.71 ,6.20 ,1.34 ,5.38 ,-2.76 ,-8.51 ,7.59 ,3.44 ,2.24 ,-6.82 ,-1.15 ,4.80 ,-4.67 ,2.14}, sd::DataType::DOUBLE); - NDArray x5('c', {5,5}, {2.4 ,0.3 ,0.0 ,1.1 ,1.8 ,0.1 ,1.7 ,2.7 ,1.5 ,2.6 ,0.6 ,2.1 ,2.2 ,1.0 ,0.2 ,1.2 ,2.8 ,1.9 ,0.8 ,2.0 ,0.5 ,1.6 ,0.9 ,1.4 ,2.5}, sd::DataType::DOUBLE); - - NDArray exp1('c', {1,1}, {2.}, sd::DataType::DOUBLE); - NDArray exp2('c', {2,2}, {1.0163674, 1.3341597,0.200124, 0.4827035}, sd::DataType::DOUBLE); - NDArray exp3('c', {3,3}, {6.5692188, 2.6273616,-0.1387864,-16.8404762,-7.0296495, 0.9204148,-11.4664296,-5.834273 , 2.2087478}, sd::DataType::DOUBLE); - NDArray exp4('c', {4,4}, {1.161387 ,-1.9343154, 0.230372 , 0.8660897,0.80588 , 3.4045446,-1.0152824,-2.0369467,2.2589629, 1.9674252, 1.5109997,-1.4283141,0.0226356, 1.3032279,-1.00396 , 1.8278487}, sd::DataType::DOUBLE); - NDArray exp5('c', {5,5}, {1.4175046,-0.4425298, 0.1846149, 0.3166522, 0.9140631,-0.1929139, 0.2889113, 1.4045273, 0.2600026, 1.552021 , 0.1372758, 0.5703854, 1.3336126, 0.3869317,-0.082492 , + NDArray x1('c', {1,1}, {4.}, sd::DataType::DOUBLE); + NDArray x2('c', {2,2}, {1.3,2,0.3,.5}, sd::DataType::DOUBLE); + NDArray x3('c', {3,3}, {0.5 ,-0.4 ,1.2 ,-2.8 ,-0.2 ,-2.1 ,-2.4 ,-2.0 ,1.1}, sd::DataType::DOUBLE); + NDArray x4('c', {4,4}, {0.33 ,-7.25 ,1.71 ,6.20 ,1.34 ,5.38 ,-2.76 ,-8.51 ,7.59 ,3.44 ,2.24 ,-6.82 ,-1.15 ,4.80 ,-4.67 ,2.14}, sd::DataType::DOUBLE); + NDArray x5('c', {5,5}, {2.4 ,0.3 ,0.0 ,1.1 ,1.8 ,0.1 ,1.7 ,2.7 ,1.5 ,2.6 ,0.6 ,2.1 ,2.2 ,1.0 ,0.2 ,1.2 ,2.8 ,1.9 ,0.8 ,2.0 ,0.5 ,1.6 ,0.9 ,1.4 ,2.5}, sd::DataType::DOUBLE); + + NDArray exp1('c', {1,1}, {2.}, sd::DataType::DOUBLE); + NDArray exp2('c', {2,2}, {1.0163674, 1.3341597,0.200124, 0.4827035}, sd::DataType::DOUBLE); + NDArray exp3('c', {3,3}, {6.5692188, 2.6273616,-0.1387864,-16.8404762,-7.0296495, 0.9204148,-11.4664296,-5.834273 , 2.2087478}, sd::DataType::DOUBLE); + NDArray exp4('c', {4,4}, {1.161387 ,-1.9343154, 0.230372 , 0.8660897,0.80588 , 3.4045446,-1.0152824,-2.0369467,2.2589629, 1.9674252, 1.5109997,-1.4283141,0.0226356, 1.3032279,-1.00396 , 1.8278487}, sd::DataType::DOUBLE); + NDArray exp5('c', {5,5}, {1.4175046,-0.4425298, 0.1846149, 0.3166522, 0.9140631,-0.1929139, 0.2889113, 1.4045273, 0.2600026, 1.552021 , 0.1372758, 0.5703854, 1.3336126, 0.3869317,-0.082492 , 0.8607272, 3.1792474,-0.9499947, 0.8541668,-1.4243879, 0.0081136,-0.0622248, 0.4534325, 0.4641865, 1.8132138}, sd::DataType::DOUBLE); - sd::ops::sqrtm op; + sd::ops::sqrtm op; - auto results = op.evaluate({&x1}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(exp1.isSameShape(results.at(0))); - ASSERT_TRUE(exp1.equalsTo(results.at(0))); + auto results = op.evaluate({&x1}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_TRUE(exp1.isSameShape(results.at(0))); + ASSERT_EQ(exp1, results.at(0)); - results = op.evaluate({&x2}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(exp2.isSameShape(results.at(0))); - ASSERT_TRUE(exp2.equalsTo(results.at(0))); + results = op.evaluate({&x2}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_TRUE(exp2.isSameShape(results.at(0))); + ASSERT_EQ(exp2, results.at( 0)); - results = op.evaluate({&x3}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(exp3.isSameShape(results.at(0))); - ASSERT_TRUE(exp3.equalsTo(results.at(0))); + results = op.evaluate({&x3}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_TRUE(exp3.isSameShape(results.at(0))); + ASSERT_EQ(exp3, results.at(0)); - results = op.evaluate({&x4}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(exp4.isSameShape(results.at(0))); - ASSERT_TRUE(exp4.equalsTo(results.at(0))); + results = op.evaluate({&x4}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_TRUE(exp4.isSameShape(results.at(0))); + ASSERT_EQ(exp4, results.at(0)); - results = op.evaluate({&x5}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(exp5.isSameShape(results.at(0))); - ASSERT_TRUE(exp5.equalsTo(results.at(0))); + results = op.evaluate({&x5}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_TRUE(exp5.isSameShape(results.at( 0))); + ASSERT_EQ(exp5, results.at(0)); } -////////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests15, sqrtm_2) { + ////////////////////////////////////////////////////////////////////// + TEST_F(DeclarableOpsTests15, sqrtm_2) { - NDArray x('c', {10,10}, {-0.3 ,2.7 ,4.9 ,7.0 ,7.3 ,-1.3 ,0.5 ,9.9 ,-9.4 ,8.4 ,2.2 ,5.2 ,7.6 ,1.2 ,2.0 ,-3.8 ,2.1 ,6.1 ,1.6 ,6.9 ,5.1 ,5.3 ,6.4 ,8.7 ,0.1 ,8.5 , - 3.3 ,1.0 ,6.8 ,0.4 ,0.7 ,3.2 ,7.4 ,6.7 ,1.1 ,7.2 ,6.0 ,7.5 ,9.7 ,5.4 ,9.0 ,6.3 ,0.0 ,4.5 ,8.3 ,7.9 ,3.0 ,6.5 ,0.6 ,8.0 ,9.5 ,3.6 ,1.9 ,6.2 ,0.9 ,4.0 ,4.1 , + NDArray x('c', {10,10}, {-0.3 ,2.7 ,4.9 ,7.0 ,7.3 ,-1.3 ,0.5 ,9.9 ,-9.4 ,8.4 ,2.2 ,5.2 ,7.6 ,1.2 ,2.0 ,-3.8 ,2.1 ,6.1 ,1.6 ,6.9 ,5.1 ,5.3 ,6.4 ,8.7 ,0.1 ,8.5 , + 3.3 ,1.0 ,6.8 ,0.4 ,0.7 ,3.2 ,7.4 ,6.7 ,1.1 ,7.2 ,6.0 ,7.5 ,9.7 ,5.4 ,9.0 ,6.3 ,0.0 ,4.5 ,8.3 ,7.9 ,3.0 ,6.5 ,0.6 ,8.0 ,9.5 ,3.6 ,1.9 ,6.2 ,0.9 ,4.0 ,4.1 , 8.1 ,3.9 ,4.3 ,4.7 ,3.7 ,3.4 ,5.8 ,10.0 ,8.6 ,9.3 ,9.1 ,4.6 ,1.4 ,7.8 ,1.5 ,7.7 ,4.2 ,9.6 ,8.2 ,-7.1 ,5.7 ,5.5 ,2.6 ,8.8 ,2.9 ,0.2 ,5.6 ,-2.5 ,8.9 ,2.8 ,0.8 ,1.5 ,3.1 ,3.5 ,4.4 ,2.4 ,9.2 ,-4.8 ,1.7 ,6.6 ,9.8 ,1.8 ,5.9}, sd::DataType::DOUBLE); - NDArray expZ('c', {10,10}, {1.2779038, 0.0333321, 0.8215617, 0.5736392, 1.3973911, -1.1757741,0.1990005, 1.5893778, -3.0159568, 2.5829108,0.5692253, 2.219431 , 1.022612 , -0.3131795, -0.1957848, -1.7805065, + NDArray expZ('c', {10,10}, {1.2779038, 0.0333321, 0.8215617, 0.5736392, 1.3973911, -1.1757741,0.1990005, 1.5893778, -3.0159568, 2.5829108,0.5692253, 2.219431 , 1.022612 , -0.3131795, -0.1957848, -1.7805065, 0.6668489, 1.1968921, 0.9781974, 1.2007764,0.7028634, 0.7496937, 2.2511438, 2.1945378, 0.2559353, 2.8948612,-0.4306994, -0.9922216, 0.3884369, -1.4174481, -1.6060233, 0.1571057, 1.432471 , 0.4508346, 0.0618069, -2.4511742,2.0641709, 2.4751085, 1.84787 , 3.4146313,0.7774219, 0.768369 , -0.1417226, -0.3970577, 2.9512879, 0.5474537, 0.4991412, 0.7604095, 0.4523091, 1.7813704,2.5998339, 0.9402402, -0.82775 , 2.3637147, -0.6394584, 4.6181937,-0.1762181, -0.2820475, 0.9280713, -2.1876918, diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests16.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests16.cpp index cbec08c0ca5b..36b4134641a4 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests16.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests16.cpp @@ -14,1065 +14,1038 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ +// +// @author raver119@gmail.com +// - // - // @author raver119@gmail.com - // - -#include "testlayers.h" -#include #include -#include #include +#include +#include + #include +#include "testlayers.h" using namespace sd; - class DeclarableOpsTests16 : public testing::Test { -public: - - DeclarableOpsTests16() { - printf("\n"); - fflush(stdout); - } + public: + DeclarableOpsTests16() { + printf("\n"); + fflush(stdout); + } }; TEST_F(DeclarableOpsTests16, scatter_upd_1) { - auto x = NDArrayFactory::create('c', { 3 }, { 1.f, 1.f, 1.f }); - auto y = NDArrayFactory::create(0); - auto w = NDArrayFactory::create(3.0f); - auto e = NDArrayFactory::create('c', { 3 }, { 3.f, 1.f, 1.f }); + auto x = NDArrayFactory::create('c', {3}, {1.f, 1.f, 1.f}); + auto y = NDArrayFactory::create(0); + auto w = NDArrayFactory::create(3.0f); + auto e = NDArrayFactory::create('c', {3}, {3.f, 1.f, 1.f}); - sd::ops::scatter_upd op; - auto result = op.evaluate({ &x, &y, &w }); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::scatter_upd op; + auto result = op.evaluate({&x, &y, &w}); + ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_EQ(e, *z); + ASSERT_EQ(e, z); } TEST_F(DeclarableOpsTests16, scatter_upd_2) { + NDArray x('c', {10, 3}, sd::DataType::FLOAT32); + NDArray indices('c', {2}, {2, 5}, sd::DataType::INT32); + NDArray updates('c', {2, 3}, {100, 101, 102, 200, 201, 202}, + sd::DataType::FLOAT32); + NDArray e('c', {10, 3}, + {1, 2, 3, 4, 5, 6, 100, 101, 102, 10, 11, 12, 13, 14, 15, + 200, 201, 202, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30}, + sd::DataType::FLOAT32); - NDArray x('c', { 10, 3 }, sd::DataType::FLOAT32); - NDArray indices('c', { 2 }, { 2,5 }, sd::DataType::INT32); - NDArray updates('c', { 2, 3 }, { 100,101,102, 200,201,202 }, sd::DataType::FLOAT32); - NDArray e('c', { 10, 3 }, { 1,2,3, 4,5,6, 100,101,102, 10,11,12, 13,14,15, 200,201,202, 19,20,21, 22,23,24, 25,26,27, 28,29,30 }, sd::DataType::FLOAT32); + x.linspace(1); - x.linspace(1); + sd::ops::scatter_upd op; + auto result = op.evaluate({&x, &indices, &updates}); + ASSERT_EQ(Status::OK(), result.status()); - sd::ops::scatter_upd op; - auto result = op.evaluate({ &x, &indices, &updates }); - ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); - auto z = result.at(0); - - ASSERT_EQ(e, *z); + ASSERT_EQ(e, z); } TEST_F(DeclarableOpsTests16, scatter_upd_3) { - - NDArray x('c', { 10, 3 }, sd::DataType::FLOAT32); - NDArray indices('c', { 2 }, { 20,5 }, sd::DataType::INT32); - NDArray updates('c', { 2, 3 }, { 100,101,102, 200,201,202 }, sd::DataType::FLOAT32); - NDArray output('c', { 10, 3 }, sd::DataType::FLOAT32); - - sd::ops::scatter_upd op; - ASSERT_ANY_THROW(op.execute({ &x, &indices, &updates }, { &output }, {}, {}, { true, true })); + NDArray x('c', {10, 3}, sd::DataType::FLOAT32); + NDArray indices('c', {2}, {20, 5}, sd::DataType::INT32); + NDArray updates('c', {2, 3}, {100, 101, 102, 200, 201, 202}, + sd::DataType::FLOAT32); + NDArray output('c', {10, 3}, sd::DataType::FLOAT32); + + sd::ops::scatter_upd op; + ASSERT_ANY_THROW( + op.execute({&x, &indices, &updates}, {&output}, {}, {}, {true, true})); } TEST_F(DeclarableOpsTests16, test_size_dtype_1) { - auto x = NDArrayFactory::create('c', { 3 }, { 1, 1, 1 }); - auto z = NDArrayFactory::create(0.0f); - auto e = NDArrayFactory::create(3.0f); + auto x = NDArrayFactory::create('c', {3}, {1, 1, 1}); + auto z = NDArrayFactory::create(0.0f); + auto e = NDArrayFactory::create(3.0f); - sd::ops::size op; - auto status = op.execute({ &x }, { &z }, {}, {}, {}); - ASSERT_EQ(Status::OK(), status); + sd::ops::size op; + auto status = op.execute({&x}, {&z}, {}, {}, {}); + ASSERT_EQ(Status::OK(), status); - ASSERT_EQ(e, z); + ASSERT_EQ(e, z); } TEST_F(DeclarableOpsTests16, test_empty_noop_1) { - auto z = NDArrayFactory::empty(); + auto z = NDArrayFactory::empty(); - sd::ops::noop op; - auto status = op.execute({}, { &z }, {}, {}, {}); - ASSERT_EQ(Status::OK(), status); + sd::ops::noop op; + auto status = op.execute({}, {&z}, {}, {}, {}); + ASSERT_EQ(Status::OK(), status); } TEST_F(DeclarableOpsTests16, test_empty_noop_2) { - auto z = NDArrayFactory::empty(); + auto z = NDArrayFactory::empty(); - Context ctx(1); - ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo()); + Context ctx(1); + ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo()); - sd::ops::noop op; - auto status = op.execute(&ctx); + sd::ops::noop op; + auto status = op.execute(&ctx); - ASSERT_EQ(Status::OK(), status); + ASSERT_EQ(Status::OK(), status); } TEST_F(DeclarableOpsTests16, test_svd_1) { - auto x = NDArrayFactory::create('c', { 3, 3 }, { 0.7787856f, 0.80119777f, 0.72437465f, 0.23089433f, 0.72714126f, 0.18039072f,0.50563407f, 0.89252293f, 0.5461209f }); - auto z = NDArrayFactory::create('c', { 3 }); + auto x = NDArrayFactory::create( + 'c', {3, 3}, + {0.7787856f, 0.80119777f, 0.72437465f, 0.23089433f, 0.72714126f, + 0.18039072f, 0.50563407f, 0.89252293f, 0.5461209f}); + auto z = NDArrayFactory::create('c', {3}); - sd::ops::svd op; - auto status = op.execute({ &x }, { &z }, {}, { 0, 0, 16 }, {}); + sd::ops::svd op; + auto status = op.execute({&x}, {&z}, {}, {0, 0, 16}, {}); - ASSERT_EQ(Status::OK(), status); + ASSERT_EQ(Status::OK(), status); } TEST_F(DeclarableOpsTests16, test_hamming_distance_1) { - auto x = NDArrayFactory::create({ 37, 37, 37 }); - auto y = NDArrayFactory::create({ 8723, 8723, 8723 }); - auto e = NDArrayFactory::create(18); + auto x = NDArrayFactory::create({37, 37, 37}); + auto y = NDArrayFactory::create({8723, 8723, 8723}); + auto e = NDArrayFactory::create(18); - sd::ops::bits_hamming_distance op; - auto result = op.evaluate({ &x, &y }); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::bits_hamming_distance op; + auto result = op.evaluate({&x, &y}); + ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_EQ(e, *z); + ASSERT_EQ(e, z); } TEST_F(DeclarableOpsTests16, test_knn_mindistance_1) { - auto input = NDArrayFactory::create('c', { 512 }); - auto low = NDArrayFactory::create('c', { 512 }); - auto high = NDArrayFactory::create('c', { 512 }); + auto input = NDArrayFactory::create('c', {512}); + auto low = NDArrayFactory::create('c', {512}); + auto high = NDArrayFactory::create('c', {512}); - auto output = NDArrayFactory::create(0.0f); + auto output = NDArrayFactory::create(0.0f); - input.linspace(1.0); - low.linspace(1.0); - high.linspace(1.0); + input.linspace(1.0); + low.linspace(1.0); + high.linspace(1.0); - sd::ops::knn_mindistance op; - auto result = op.execute({ &input, &low, &high }, { &output }, {}, {}, {}); - ASSERT_EQ(Status::OK(), result); + sd::ops::knn_mindistance op; + auto result = op.execute({&input, &low, &high}, {&output}, {}, {}, {}); + ASSERT_EQ(Status::OK(), result); } TEST_F(DeclarableOpsTests16, test_empty_cast_1) { - auto x = NDArrayFactory::create('c', { 1, 0, 2 }); - auto e = NDArrayFactory::create('c', { 1, 0, 2 }); + auto x = NDArrayFactory::create('c', {1, 0, 2}); + auto e = NDArrayFactory::create('c', {1, 0, 2}); - sd::ops::cast op; - auto result = op.evaluate({&x}, {10}); - ASSERT_EQ(Status::OK(), result.status()); - ASSERT_EQ(e, *result.at(0)); + sd::ops::cast op; + auto result = op.evaluate({&x}, {10}); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_EQ(e, result.at(0)); } TEST_F(DeclarableOpsTests16, test_range_1) { - sd::ops::range op; - auto z = NDArrayFactory::create('c', { 200 }); + sd::ops::range op; + auto z = NDArrayFactory::create('c', {200}); - Context ctx(1); - ctx.setTArguments({ -1.0, 1.0, 0.01 }); - ctx.setOutputArray(0, &z); + Context ctx(1); + ctx.setTArguments({-1.0, 1.0, 0.01}); + ctx.setOutputArray(0, z); - auto status = op.execute(&ctx); - ASSERT_EQ(Status::OK(), status); + auto status = op.execute(&ctx); + ASSERT_EQ(Status::OK(), status); } TEST_F(DeclarableOpsTests16, test_range_2) { - sd::ops::range op; - auto z = NDArrayFactory::create('c', { 200 }); + sd::ops::range op; + auto z = NDArrayFactory::create('c', {200}); - double tArgs[] = { -1.0, 1.0, 0.01 }; + double tArgs[] = {-1.0, 1.0, 0.01}; - auto shapes = ::calculateOutputShapes2(nullptr, op.getOpHash(), nullptr, nullptr, 0, tArgs, 3, nullptr, 0, nullptr, 0, nullptr, 0); - // shape::printShapeInfoLinear("Result", shapes->at(0)); - ASSERT_TRUE(shape::shapeEquals(z.shapeInfo(), shapes->at(0))); + auto shapes = + ::calculateOutputShapes2(nullptr, op.getOpHash(), nullptr, nullptr, 0, + tArgs, 3, nullptr, 0, nullptr, 0, nullptr, 0); + //shape::printShapeInfoLinear("Result", shapes->at(0)); + ASSERT_TRUE(shape::shapeEquals(z.shapeInfo(), shapes->at(0))); - delete shapes; + delete shapes; } TEST_F(DeclarableOpsTests16, test_reverse_1) { - std::vector rows = { 3, 5, 7, 8, 9, 10, 119, 211 }; - std::vector columns = { 6, 5, 10, 100, 153, 171, 635 }; - - for (auto r : rows) { - for (auto c : columns) { - //nd4j_printf("Trying [%i, %i]\n", r, c); - auto array = NDArrayFactory::create('c', { r, c }); - auto exp = NDArrayFactory::create('c', { r, c }); - auto reversed = NDArrayFactory::create('c', { r, c }); - - auto rowOriginal = NDArrayFactory::create('c', { c }); - auto rowReversed = NDArrayFactory::create('c', { c }); - - for (int e = 0; e < c; e++) { - rowOriginal.p(e, (float)e); - rowReversed.p(c - e - 1, (float)e); - } - - - auto listI = array.allTensorsAlongDimension({ 1 }); - auto listE = exp.allTensorsAlongDimension({ 1 }); - - for (int e = 0; e < r; e++) { - listI.at(e)->assign(rowOriginal); - listE.at(e)->assign(rowReversed); - } - - sd::ops::reverse op; - Nd4jLong axis = 1; - auto status = op.execute({ &array }, { &reversed }, {}, { axis }, {}); - ASSERT_EQ(Status::OK(), status); - - ASSERT_EQ(exp, reversed); - } + std::vector rows = {3, 5, 7, 8, 9, 10, 119, 211}; + std::vector columns = {6, 5, 10, 100, 153, 171, 635}; + + for (auto r : rows) { + for (auto c : columns) { + // nd4j_printf("Trying [%i, %i]\n", r, c); + auto array = NDArrayFactory::create('c', {r, c}); + auto exp = NDArrayFactory::create('c', {r, c}); + auto reversed = NDArrayFactory::create('c', {r, c}); + + auto rowOriginal = NDArrayFactory::create('c', {c}); + auto rowReversed = NDArrayFactory::create('c', {c}); + + for (int e = 0; e < c; e++) { + rowOriginal.p(e, (float)e); + rowReversed.p(c - e - 1, (float)e); + } + + auto listI = array.allTensorsAlongDimension({1}); + auto listE = exp.allTensorsAlongDimension({1}); + + for (int e = 0; e < r; e++) { + listI.at(e).assign(rowOriginal); + listE.at(e).assign(rowReversed); + } + + sd::ops::reverse op; + Nd4jLong axis = 1; + auto status = op.execute({&array}, {&reversed}, {}, {axis}, {}); + ASSERT_EQ(Status::OK(), status); + + ASSERT_EQ(exp, reversed); } + } } TEST_F(DeclarableOpsTests16, test_rgb_to_hsv_1) { - /* - test case generated by python colorsys and scaled to suit our needs - from colorsys import * - from random import * - import numpy as np - rgbs = np.random.uniform(0,1, 5*4*3 ).astype('float32').reshape([5,4,3]) - hsvs=np.apply_along_axis(lambda x: np.array(rgb_to_hsv(x[0],x[1],x[2])),2,rgbs) - rgbs.ravel() - hsvs.ravel() - */ - auto rgbs = NDArrayFactory::create('c', { 5, 4, 3 }, { - 0.545678377f, 0.725874603f, 0.413571358f, 0.644941628f, 0.517642438f, - 0.890151322f, 0.461456001f, 0.0869259685f, 0.928968489f, 0.588904262f, - 0.54742825f, 0.684074104f, 0.52110225f, 0.761800349f, 0.486593395f, - 0.753103435f, 0.237176552f, 0.263826847f, 0.913557053f, 0.90049392f, - 0.290193319f, 0.46850124f, 0.965541422f, 0.148351923f, 0.674094439f, - 0.524110138f, 0.216262609f, 0.0361763388f, 0.2204483f, 0.279114306f, - 0.3721793f, 0.632020354f, 0.25007084f, 0.823592246f, 0.637001634f, - 0.30433768f, 0.0448598303f, 0.385092884f, 0.366362303f, 0.586083114f, - 0.218390301f, 0.931746006f, 0.978048146f, 0.762684941f, 0.00208298792f, - 0.91390729f, 0.505838513f, 0.875348926f, 0.428009957f, 0.367065936f, - 0.911922634f, 0.270003974f, 0.164243385f, 0.0581932105f, 0.313204288f, - 0.644775152f, 0.437950462f, 0.775881767f, 0.575452209f, 0.946475744f - }); - auto expected = NDArrayFactory::create('c', { 5, 4, 3 }, { - 0.262831867f, 0.430244058f, 0.725874603f, 0.723622441f, 0.418478161f, - 0.890151322f, 0.740797927f, 0.906427443f, 0.928968489f, 0.717254877f, - 0.199753001f, 0.684074104f, 0.312434604f, 0.361258626f, 0.761800349f, - 0.991390795f, 0.685067773f, 0.753103435f, 0.163174023f, 0.682347894f, - 0.913557053f, 0.268038541f, 0.84635365f, 0.965541422f, 0.112067183f, - 0.679180562f, 0.674094439f, 0.540247589f, 0.870388806f, 0.279114306f, - 0.280050347f, 0.604331017f, 0.632020354f, 0.106776128f, 0.630475283f, - 0.823592246f, 0.490824632f, 0.883509099f, 0.385092884f, 0.75257351f, - 0.765611768f, 0.931746006f, 0.129888852f, 0.997870266f, 0.978048146f, - 0.849081645f, 0.446510047f, 0.91390729f, 0.685308874f, 0.597481251f, - 0.911922634f, 0.0834472676f, 0.784472764f, 0.270003974f, 0.396037966f, - 0.514242649f, 0.644775152f, 0.756701186f, 0.392005324f, 0.946475744f - }); - - - auto actual = NDArrayFactory::create('c', { 5,4,3 }); - - Context ctx(1); - ctx.setInputArray(0, &rgbs); - ctx.setOutputArray(0, &actual); - - sd::ops::rgb_to_hsv op; - auto status = op.execute(&ctx); + /* + test case generated by python colorsys and scaled to suit our needs + from colorsys import * + from random import * + import numpy as np + rgbs = np.random.uniform(0,1, 5*4*3 ).astype('float32').reshape([5,4,3]) + hsvs=np.apply_along_axis(lambda x: + np.array(rgb_to_hsv(x[0],x[1],x[2])),2,rgbs) rgbs.ravel() hsvs.ravel() + */ + auto rgbs = NDArrayFactory::create( + 'c', {5, 4, 3}, + {0.545678377f, 0.725874603f, 0.413571358f, 0.644941628f, + 0.517642438f, 0.890151322f, 0.461456001f, 0.0869259685f, + 0.928968489f, 0.588904262f, 0.54742825f, 0.684074104f, + 0.52110225f, 0.761800349f, 0.486593395f, 0.753103435f, + 0.237176552f, 0.263826847f, 0.913557053f, 0.90049392f, + 0.290193319f, 0.46850124f, 0.965541422f, 0.148351923f, + 0.674094439f, 0.524110138f, 0.216262609f, 0.0361763388f, + 0.2204483f, 0.279114306f, 0.3721793f, 0.632020354f, + 0.25007084f, 0.823592246f, 0.637001634f, 0.30433768f, + 0.0448598303f, 0.385092884f, 0.366362303f, 0.586083114f, + 0.218390301f, 0.931746006f, 0.978048146f, 0.762684941f, + 0.00208298792f, 0.91390729f, 0.505838513f, 0.875348926f, + 0.428009957f, 0.367065936f, 0.911922634f, 0.270003974f, + 0.164243385f, 0.0581932105f, 0.313204288f, 0.644775152f, + 0.437950462f, 0.775881767f, 0.575452209f, 0.946475744f}); + auto expected = NDArrayFactory::create( + 'c', {5, 4, 3}, + {0.262831867f, 0.430244058f, 0.725874603f, 0.723622441f, 0.418478161f, + 0.890151322f, 0.740797927f, 0.906427443f, 0.928968489f, 0.717254877f, + 0.199753001f, 0.684074104f, 0.312434604f, 0.361258626f, 0.761800349f, + 0.991390795f, 0.685067773f, 0.753103435f, 0.163174023f, 0.682347894f, + 0.913557053f, 0.268038541f, 0.84635365f, 0.965541422f, 0.112067183f, + 0.679180562f, 0.674094439f, 0.540247589f, 0.870388806f, 0.279114306f, + 0.280050347f, 0.604331017f, 0.632020354f, 0.106776128f, 0.630475283f, + 0.823592246f, 0.490824632f, 0.883509099f, 0.385092884f, 0.75257351f, + 0.765611768f, 0.931746006f, 0.129888852f, 0.997870266f, 0.978048146f, + 0.849081645f, 0.446510047f, 0.91390729f, 0.685308874f, 0.597481251f, + 0.911922634f, 0.0834472676f, 0.784472764f, 0.270003974f, 0.396037966f, + 0.514242649f, 0.644775152f, 0.756701186f, 0.392005324f, 0.946475744f}); + + auto actual = NDArrayFactory::create('c', {5, 4, 3}); + + Context ctx(1); + ctx.setInputArray(0, rgbs); + ctx.setOutputArray(0, actual); + + sd::ops::rgb_to_hsv op; + auto status = op.execute(&ctx); #if 0 //visual check rgbs.printBuffer("rgbs "); actual.printBuffer("HSV "); expected.printBuffer("exp"); #endif - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(expected.equalsTo(actual)); - + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); } TEST_F(DeclarableOpsTests16, test_rgb_to_hsv_2) { - /* - swapped_rgbs=rgbs.swapaxes(1,2).ravel() - swapped_hsvs=hsvs.swapaxes(1,2).ravel() - */ - auto rgbs = NDArrayFactory::create('c', { 5, 3, 4 }, { - 0.545678377f, 0.644941628f, 0.461456001f, 0.588904262f, 0.725874603f, - 0.517642438f, 0.0869259685f, 0.54742825f, 0.413571358f, 0.890151322f, - 0.928968489f, 0.684074104f, 0.52110225f, 0.753103435f, 0.913557053f, - 0.46850124f, 0.761800349f, 0.237176552f, 0.90049392f, 0.965541422f, - 0.486593395f, 0.263826847f, 0.290193319f, 0.148351923f, 0.674094439f, - 0.0361763388f, 0.3721793f, 0.823592246f, 0.524110138f, 0.2204483f, - 0.632020354f, 0.637001634f, 0.216262609f, 0.279114306f, 0.25007084f, - 0.30433768f, 0.0448598303f, 0.586083114f, 0.978048146f, 0.91390729f, - 0.385092884f, 0.218390301f, 0.762684941f, 0.505838513f, 0.366362303f, - 0.931746006f, 0.00208298792f, 0.875348926f, 0.428009957f, 0.270003974f, - 0.313204288f, 0.775881767f, 0.367065936f, 0.164243385f, 0.644775152f, - 0.575452209f, 0.911922634f, 0.0581932105f, 0.437950462f, 0.946475744f - }); - auto expected = NDArrayFactory::create('c', { 5, 3, 4 }, { - 0.262831867f, 0.723622441f, 0.740797927f, 0.717254877f, 0.430244058f, - 0.418478161f, 0.906427443f, 0.199753001f, 0.725874603f, 0.890151322f, - 0.928968489f, 0.684074104f, 0.312434604f, 0.991390795f, 0.163174023f, - 0.268038541f, 0.361258626f, 0.685067773f, 0.682347894f, 0.84635365f, - 0.761800349f, 0.753103435f, 0.913557053f, 0.965541422f, 0.112067183f, - 0.540247589f, 0.280050347f, 0.106776128f, 0.679180562f, 0.870388806f, - 0.604331017f, 0.630475283f, 0.674094439f, 0.279114306f, 0.632020354f, - 0.823592246f, 0.490824632f, 0.75257351f, 0.129888852f, 0.849081645f, - 0.883509099f, 0.765611768f, 0.997870266f, 0.446510047f, 0.385092884f, - 0.931746006f, 0.978048146f, 0.91390729f, 0.685308874f, 0.0834472676f, - 0.396037966f, 0.756701186f, 0.597481251f, 0.784472764f, 0.514242649f, - 0.392005324f, 0.911922634f, 0.270003974f, 0.644775152f, 0.946475744f - }); - - - auto actual = NDArrayFactory::create('c', { 5,3,4 }); - - Context ctx(1); - ctx.setInputArray(0, &rgbs); - ctx.setOutputArray(0, &actual); - ctx.setIArguments({ 1 }); - sd::ops::rgb_to_hsv op; - auto status = op.execute(&ctx); - - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(expected.equalsTo(actual)); - + /* + swapped_rgbs=rgbs.swapaxes(1,2).ravel() + swapped_hsvs=hsvs.swapaxes(1,2).ravel() + */ + auto rgbs = NDArrayFactory::create( + 'c', {5, 3, 4}, + {0.545678377f, 0.644941628f, 0.461456001f, 0.588904262f, + 0.725874603f, 0.517642438f, 0.0869259685f, 0.54742825f, + 0.413571358f, 0.890151322f, 0.928968489f, 0.684074104f, + 0.52110225f, 0.753103435f, 0.913557053f, 0.46850124f, + 0.761800349f, 0.237176552f, 0.90049392f, 0.965541422f, + 0.486593395f, 0.263826847f, 0.290193319f, 0.148351923f, + 0.674094439f, 0.0361763388f, 0.3721793f, 0.823592246f, + 0.524110138f, 0.2204483f, 0.632020354f, 0.637001634f, + 0.216262609f, 0.279114306f, 0.25007084f, 0.30433768f, + 0.0448598303f, 0.586083114f, 0.978048146f, 0.91390729f, + 0.385092884f, 0.218390301f, 0.762684941f, 0.505838513f, + 0.366362303f, 0.931746006f, 0.00208298792f, 0.875348926f, + 0.428009957f, 0.270003974f, 0.313204288f, 0.775881767f, + 0.367065936f, 0.164243385f, 0.644775152f, 0.575452209f, + 0.911922634f, 0.0581932105f, 0.437950462f, 0.946475744f}); + auto expected = NDArrayFactory::create( + 'c', {5, 3, 4}, + {0.262831867f, 0.723622441f, 0.740797927f, 0.717254877f, 0.430244058f, + 0.418478161f, 0.906427443f, 0.199753001f, 0.725874603f, 0.890151322f, + 0.928968489f, 0.684074104f, 0.312434604f, 0.991390795f, 0.163174023f, + 0.268038541f, 0.361258626f, 0.685067773f, 0.682347894f, 0.84635365f, + 0.761800349f, 0.753103435f, 0.913557053f, 0.965541422f, 0.112067183f, + 0.540247589f, 0.280050347f, 0.106776128f, 0.679180562f, 0.870388806f, + 0.604331017f, 0.630475283f, 0.674094439f, 0.279114306f, 0.632020354f, + 0.823592246f, 0.490824632f, 0.75257351f, 0.129888852f, 0.849081645f, + 0.883509099f, 0.765611768f, 0.997870266f, 0.446510047f, 0.385092884f, + 0.931746006f, 0.978048146f, 0.91390729f, 0.685308874f, 0.0834472676f, + 0.396037966f, 0.756701186f, 0.597481251f, 0.784472764f, 0.514242649f, + 0.392005324f, 0.911922634f, 0.270003974f, 0.644775152f, 0.946475744f}); + + auto actual = NDArrayFactory::create('c', {5, 3, 4}); + + Context ctx(1); + ctx.setInputArray(0, rgbs); + ctx.setOutputArray(0, actual); + ctx.setIArguments({1}); + sd::ops::rgb_to_hsv op; + auto status = op.execute(&ctx); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); } TEST_F(DeclarableOpsTests16, test_rgb_to_hsv_3) { - - auto rgbs = NDArrayFactory::create('c', { 4, 3 }, { - 0.545678377f, 0.725874603f, 0.413571358f, 0.644941628f, 0.517642438f, - 0.890151322f, 0.461456001f, 0.0869259685f, 0.928968489f, 0.588904262f, - 0.54742825f, 0.684074104f - }); - auto expected = NDArrayFactory::create('c', { 4, 3 }, { - 0.262831867f, 0.430244058f, 0.725874603f, 0.723622441f, 0.418478161f, - 0.890151322f, 0.740797927f, 0.906427443f, 0.928968489f, 0.717254877f, - 0.199753001f, 0.684074104f - }); - - auto actual = NDArrayFactory::create('c', { 4, 3 }); - - Context ctx(1); - ctx.setInputArray(0, &rgbs); - ctx.setOutputArray(0, &actual); - - sd::ops::rgb_to_hsv op; - auto status = op.execute(&ctx); - - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(expected.equalsTo(actual)); - + auto rgbs = NDArrayFactory::create( + 'c', {4, 3}, + {0.545678377f, 0.725874603f, 0.413571358f, 0.644941628f, 0.517642438f, + 0.890151322f, 0.461456001f, 0.0869259685f, 0.928968489f, 0.588904262f, + 0.54742825f, 0.684074104f}); + auto expected = NDArrayFactory::create( + 'c', {4, 3}, + {0.262831867f, 0.430244058f, 0.725874603f, 0.723622441f, 0.418478161f, + 0.890151322f, 0.740797927f, 0.906427443f, 0.928968489f, 0.717254877f, + 0.199753001f, 0.684074104f}); + + auto actual = NDArrayFactory::create('c', {4, 3}); + + Context ctx(1); + ctx.setInputArray(0, rgbs); + ctx.setOutputArray(0, actual); + + sd::ops::rgb_to_hsv op; + auto status = op.execute(&ctx); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); } - TEST_F(DeclarableOpsTests16, test_rgb_to_hsv_4) { - auto rgbs = NDArrayFactory::create('c', { 3, 4 }, { - 0.545678377f, 0.644941628f, 0.461456001f, 0.588904262f, 0.725874603f, - 0.517642438f, 0.0869259685f, 0.54742825f, 0.413571358f, 0.890151322f, - 0.928968489f, 0.684074104f - }); - auto expected = NDArrayFactory::create('c', { 3, 4 }, { - 0.262831867f, 0.723622441f, 0.740797927f, 0.717254877f, 0.430244058f, - 0.418478161f, 0.906427443f, 0.199753001f, 0.725874603f, 0.890151322f, - 0.928968489f, 0.684074104f - }); - - auto actual = NDArrayFactory::create('c', { 3, 4 }); - - Context ctx(1); - ctx.setInputArray(0, &rgbs); - ctx.setOutputArray(0, &actual); - ctx.setIArguments({ 0 }); - sd::ops::rgb_to_hsv op; - auto status = op.execute(&ctx); - - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(expected.equalsTo(actual)); - + auto rgbs = NDArrayFactory::create( + 'c', {3, 4}, + {0.545678377f, 0.644941628f, 0.461456001f, 0.588904262f, 0.725874603f, + 0.517642438f, 0.0869259685f, 0.54742825f, 0.413571358f, 0.890151322f, + 0.928968489f, 0.684074104f}); + auto expected = NDArrayFactory::create( + 'c', {3, 4}, + {0.262831867f, 0.723622441f, 0.740797927f, 0.717254877f, 0.430244058f, + 0.418478161f, 0.906427443f, 0.199753001f, 0.725874603f, 0.890151322f, + 0.928968489f, 0.684074104f}); + + auto actual = NDArrayFactory::create('c', {3, 4}); + + Context ctx(1); + ctx.setInputArray(0, rgbs); + ctx.setOutputArray(0, actual); + ctx.setIArguments({0}); + sd::ops::rgb_to_hsv op; + auto status = op.execute(&ctx); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); } TEST_F(DeclarableOpsTests16, test_rgb_to_hsv_5) { - auto rgbs = NDArrayFactory::create('c', { 3 }, { - 0.545678377f, 0.725874603f, 0.413571358f - }); - auto expected = NDArrayFactory::create('c', { 3 }, { - 0.262831867f, 0.430244058f, 0.725874603f - }); - - auto actual = NDArrayFactory::create('c', { 3 }); + auto rgbs = NDArrayFactory::create( + 'c', {3}, {0.545678377f, 0.725874603f, 0.413571358f}); + auto expected = NDArrayFactory::create( + 'c', {3}, {0.262831867f, 0.430244058f, 0.725874603f}); - Context ctx(1); - ctx.setInputArray(0, &rgbs); - ctx.setOutputArray(0, &actual); + auto actual = NDArrayFactory::create('c', {3}); - sd::ops::rgb_to_hsv op; - auto status = op.execute(&ctx); + Context ctx(1); + ctx.setInputArray(0, rgbs); + ctx.setOutputArray(0, actual); - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(expected.equalsTo(actual)); + sd::ops::rgb_to_hsv op; + auto status = op.execute(&ctx); + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); } - TEST_F(DeclarableOpsTests16, test_rgb_to_hsv_6) { - auto rgbs = NDArrayFactory::create('c', { 3, 4 }, { - 0.545678377f, 0.644941628f, 0.461456001f, 0.588904262f, 0.725874603f, - 0.517642438f, 0.0869259685f, 0.54742825f, 0.413571358f, 0.890151322f, - 0.928968489f, 0.684074104f - }); - auto hsvs = NDArrayFactory::create('c', { 3, 4 }, { - 0.262831867f, 0.723622441f, 0.740797927f, 0.717254877f, 0.430244058f, - 0.418478161f, 0.906427443f, 0.199753001f, 0.725874603f, 0.890151322f, - 0.928968489f, 0.684074104f - }); - - //get subarray - //get subarray - NDArray subArrRgbs = rgbs.subarray({ NDIndex::all(), NDIndex::point(0) }); - NDArray expected = hsvs.subarray({ NDIndex::all(), NDIndex::point(0) }); - subArrRgbs.reshapei({ 3 }); - expected.reshapei({ 3 }); + auto rgbs = NDArrayFactory::create( + 'c', {3, 4}, + {0.545678377f, 0.644941628f, 0.461456001f, 0.588904262f, 0.725874603f, + 0.517642438f, 0.0869259685f, 0.54742825f, 0.413571358f, 0.890151322f, + 0.928968489f, 0.684074104f}); + auto hsvs = NDArrayFactory::create( + 'c', {3, 4}, + {0.262831867f, 0.723622441f, 0.740797927f, 0.717254877f, 0.430244058f, + 0.418478161f, 0.906427443f, 0.199753001f, 0.725874603f, 0.890151322f, + 0.928968489f, 0.684074104f}); + + // get subarray + // get subarray + NDArray subArrRgbs = rgbs.subarray({NDIndex::all(), NDIndex::point(0)}); + NDArray expected = hsvs.subarray({NDIndex::all(), NDIndex::point(0)}); + subArrRgbs.reshapei({3}); + expected.reshapei({3}); #if 0 //[RANK][SHAPE][STRIDES][OPTIONS][EWS][ORDER] subArrRgbs.printShapeInfo("subArrRgbs"); #endif - auto actual = NDArrayFactory::create('c', { 3 }); - - Context ctx(1); - ctx.setInputArray(0, &subArrRgbs); - ctx.setOutputArray(0, &actual); - sd::ops::rgb_to_hsv op; - auto status = op.execute(&ctx); + auto actual = NDArrayFactory::create('c', {3}); - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(expected.equalsTo(actual)); + Context ctx(1); + ctx.setInputArray(0, subArrRgbs); + ctx.setOutputArray(0, actual); + sd::ops::rgb_to_hsv op; + auto status = op.execute(&ctx); + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); } TEST_F(DeclarableOpsTests16, test_hsv_to_rgb_1) { - - auto hsvs = NDArrayFactory::create('c', { 5, 4, 3 }, { - 0.705504596f, 0.793608069f, 0.65870738f, 0.848827183f, 0.920532584f, - 0.887555957f, 0.72317636f, 0.563831031f, 0.773604929f, 0.269532293f, - 0.332347751f, 0.111181192f, 0.239250854f, 0.499201417f, 0.862712979f, - 0.0853395388f, 0.0810681432f, 0.226065159f, 0.851340771f, 0.602043271f, - 0.690895379f, 0.971996486f, 0.273846686f, 0.464318275f, 0.194078103f, - 0.219649255f, 0.616706491f, 0.847525477f, 0.653597355f, 0.700065672f, - 0.0299375951f, 0.184475258f, 0.274936169f, 0.196718201f, 0.179381892f, - 0.934476376f, 0.895766437f, 0.52967906f, 0.675635338f, 0.966644645f, - 0.770889699f, 0.556649387f, 0.13426739f, 0.899450243f, 0.817096591f, - 0.150202557f, 0.763557851f, 0.709604502f, 0.741747797f, 0.657703638f, - 0.167678103f, 0.828556478f, 0.615502477f, 0.478080243f, 0.447288662f, - 0.864299297f, 0.129833668f, 0.66402483f, 0.795475543f, 0.561332941f - }); - auto expected = NDArrayFactory::create('c', { 5, 4, 3 }, { - 0.257768334f, 0.135951888f, 0.65870738f, 0.887555957f, 0.0705317783f, - 0.811602857f, 0.485313689f, 0.337422464f, 0.773604929f, 0.0883753772f, - 0.111181192f, 0.074230373f, 0.675155059f, 0.862712979f, 0.432045438f, - 0.226065159f, 0.21712242f, 0.207738476f, 0.690895379f, 0.274946465f, - 0.645954334f, 0.464318275f, 0.337166255f, 0.358530475f, 0.594427716f, - 0.616706491f, 0.481247369f, 0.700065672f, 0.242504601f, 0.661103036f, - 0.274936169f, 0.233327664f, 0.224217249f, 0.904251479f, 0.934476376f, - 0.766848235f, 0.675635338f, 0.317765447f, 0.54157777f, 0.556649387f, - 0.127534108f, 0.213413864f, 0.817096591f, 0.674227886f, 0.0821588641f, - 0.709604502f, 0.656080596f, 0.167780413f, 0.107076412f, 0.0573956046f, - 0.167678103f, 0.46964643f, 0.183820669f, 0.478080243f, 0.01761852f, - 0.129833668f, 0.0943436049f, 0.114806315f, 0.121884218f, 0.561332941f - }); - - - auto actual = NDArrayFactory::create('c', { 5,4,3 }); - - Context ctx(1); - ctx.setInputArray(0, &hsvs); - ctx.setOutputArray(0, &actual); - - sd::ops::hsv_to_rgb op; - auto status = op.execute(&ctx); - - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(expected.equalsTo(actual)); - + auto hsvs = NDArrayFactory::create( + 'c', {5, 4, 3}, + {0.705504596f, 0.793608069f, 0.65870738f, 0.848827183f, 0.920532584f, + 0.887555957f, 0.72317636f, 0.563831031f, 0.773604929f, 0.269532293f, + 0.332347751f, 0.111181192f, 0.239250854f, 0.499201417f, 0.862712979f, + 0.0853395388f, 0.0810681432f, 0.226065159f, 0.851340771f, 0.602043271f, + 0.690895379f, 0.971996486f, 0.273846686f, 0.464318275f, 0.194078103f, + 0.219649255f, 0.616706491f, 0.847525477f, 0.653597355f, 0.700065672f, + 0.0299375951f, 0.184475258f, 0.274936169f, 0.196718201f, 0.179381892f, + 0.934476376f, 0.895766437f, 0.52967906f, 0.675635338f, 0.966644645f, + 0.770889699f, 0.556649387f, 0.13426739f, 0.899450243f, 0.817096591f, + 0.150202557f, 0.763557851f, 0.709604502f, 0.741747797f, 0.657703638f, + 0.167678103f, 0.828556478f, 0.615502477f, 0.478080243f, 0.447288662f, + 0.864299297f, 0.129833668f, 0.66402483f, 0.795475543f, 0.561332941f}); + auto expected = NDArrayFactory::create( + 'c', {5, 4, 3}, + {0.257768334f, 0.135951888f, 0.65870738f, 0.887555957f, 0.0705317783f, + 0.811602857f, 0.485313689f, 0.337422464f, 0.773604929f, 0.0883753772f, + 0.111181192f, 0.074230373f, 0.675155059f, 0.862712979f, 0.432045438f, + 0.226065159f, 0.21712242f, 0.207738476f, 0.690895379f, 0.274946465f, + 0.645954334f, 0.464318275f, 0.337166255f, 0.358530475f, 0.594427716f, + 0.616706491f, 0.481247369f, 0.700065672f, 0.242504601f, 0.661103036f, + 0.274936169f, 0.233327664f, 0.224217249f, 0.904251479f, 0.934476376f, + 0.766848235f, 0.675635338f, 0.317765447f, 0.54157777f, 0.556649387f, + 0.127534108f, 0.213413864f, 0.817096591f, 0.674227886f, 0.0821588641f, + 0.709604502f, 0.656080596f, 0.167780413f, 0.107076412f, 0.0573956046f, + 0.167678103f, 0.46964643f, 0.183820669f, 0.478080243f, 0.01761852f, + 0.129833668f, 0.0943436049f, 0.114806315f, 0.121884218f, 0.561332941f}); + + auto actual = NDArrayFactory::create('c', {5, 4, 3}); + + Context ctx(1); + ctx.setInputArray(0, hsvs); + ctx.setOutputArray(0, actual); + + sd::ops::hsv_to_rgb op; + auto status = op.execute(&ctx); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); } TEST_F(DeclarableOpsTests16, test_hsv_to_rgb_2) { - auto hsvs = NDArrayFactory::create('c', { 5, 3, 4 }, { - 0.705504596f, 0.848827183f, 0.72317636f, 0.269532293f, 0.793608069f, - 0.920532584f, 0.563831031f, 0.332347751f, 0.65870738f, 0.887555957f, - 0.773604929f, 0.111181192f, 0.239250854f, 0.0853395388f, 0.851340771f, - 0.971996486f, 0.499201417f, 0.0810681432f, 0.602043271f, 0.273846686f, - 0.862712979f, 0.226065159f, 0.690895379f, 0.464318275f, 0.194078103f, - 0.847525477f, 0.0299375951f, 0.196718201f, 0.219649255f, 0.653597355f, - 0.184475258f, 0.179381892f, 0.616706491f, 0.700065672f, 0.274936169f, - 0.934476376f, 0.895766437f, 0.966644645f, 0.13426739f, 0.150202557f, - 0.52967906f, 0.770889699f, 0.899450243f, 0.763557851f, 0.675635338f, - 0.556649387f, 0.817096591f, 0.709604502f, 0.741747797f, 0.828556478f, - 0.447288662f, 0.66402483f, 0.657703638f, 0.615502477f, 0.864299297f, - 0.795475543f, 0.167678103f, 0.478080243f, 0.129833668f, 0.561332941f - }); - auto expected = NDArrayFactory::create('c', { 5, 3, 4 }, { - 0.257768334f, 0.887555957f, 0.485313689f, 0.0883753772f, 0.135951888f, - 0.0705317783f, 0.337422464f, 0.111181192f, 0.65870738f, 0.811602857f, - 0.773604929f, 0.074230373f, 0.675155059f, 0.226065159f, 0.690895379f, - 0.464318275f, 0.862712979f, 0.21712242f, 0.274946465f, 0.337166255f, - 0.432045438f, 0.207738476f, 0.645954334f, 0.358530475f, 0.594427716f, - 0.700065672f, 0.274936169f, 0.904251479f, 0.616706491f, 0.242504601f, - 0.233327664f, 0.934476376f, 0.481247369f, 0.661103036f, 0.224217249f, - 0.766848235f, 0.675635338f, 0.556649387f, 0.817096591f, 0.709604502f, - 0.317765447f, 0.127534108f, 0.674227886f, 0.656080596f, 0.54157777f, - 0.213413864f, 0.0821588641f, 0.167780413f, 0.107076412f, 0.46964643f, - 0.01761852f, 0.114806315f, 0.0573956046f, 0.183820669f, 0.129833668f, - 0.121884218f, 0.167678103f, 0.478080243f, 0.0943436049f, 0.561332941f - }); - auto actual = NDArrayFactory::create('c', { 5,3,4 }); - - Context ctx(1); - ctx.setInputArray(0, &hsvs); - ctx.setOutputArray(0, &actual); - ctx.setIArguments({ 1 }); - sd::ops::hsv_to_rgb op; - auto status = op.execute(&ctx); - - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(expected.equalsTo(actual)); - + auto hsvs = NDArrayFactory::create( + 'c', {5, 3, 4}, + {0.705504596f, 0.848827183f, 0.72317636f, 0.269532293f, + 0.793608069f, 0.920532584f, 0.563831031f, 0.332347751f, + 0.65870738f, 0.887555957f, 0.773604929f, 0.111181192f, + 0.239250854f, 0.0853395388f, 0.851340771f, 0.971996486f, + 0.499201417f, 0.0810681432f, 0.602043271f, 0.273846686f, + 0.862712979f, 0.226065159f, 0.690895379f, 0.464318275f, + 0.194078103f, 0.847525477f, 0.0299375951f, 0.196718201f, + 0.219649255f, 0.653597355f, 0.184475258f, 0.179381892f, + 0.616706491f, 0.700065672f, 0.274936169f, 0.934476376f, + 0.895766437f, 0.966644645f, 0.13426739f, 0.150202557f, + 0.52967906f, 0.770889699f, 0.899450243f, 0.763557851f, + 0.675635338f, 0.556649387f, 0.817096591f, 0.709604502f, + 0.741747797f, 0.828556478f, 0.447288662f, 0.66402483f, + 0.657703638f, 0.615502477f, 0.864299297f, 0.795475543f, + 0.167678103f, 0.478080243f, 0.129833668f, 0.561332941f}); + auto expected = NDArrayFactory::create( + 'c', {5, 3, 4}, + {0.257768334f, 0.887555957f, 0.485313689f, 0.0883753772f, + 0.135951888f, 0.0705317783f, 0.337422464f, 0.111181192f, + 0.65870738f, 0.811602857f, 0.773604929f, 0.074230373f, + 0.675155059f, 0.226065159f, 0.690895379f, 0.464318275f, + 0.862712979f, 0.21712242f, 0.274946465f, 0.337166255f, + 0.432045438f, 0.207738476f, 0.645954334f, 0.358530475f, + 0.594427716f, 0.700065672f, 0.274936169f, 0.904251479f, + 0.616706491f, 0.242504601f, 0.233327664f, 0.934476376f, + 0.481247369f, 0.661103036f, 0.224217249f, 0.766848235f, + 0.675635338f, 0.556649387f, 0.817096591f, 0.709604502f, + 0.317765447f, 0.127534108f, 0.674227886f, 0.656080596f, + 0.54157777f, 0.213413864f, 0.0821588641f, 0.167780413f, + 0.107076412f, 0.46964643f, 0.01761852f, 0.114806315f, + 0.0573956046f, 0.183820669f, 0.129833668f, 0.121884218f, + 0.167678103f, 0.478080243f, 0.0943436049f, 0.561332941f}); + auto actual = NDArrayFactory::create('c', {5, 3, 4}); + + Context ctx(1); + ctx.setInputArray(0, hsvs); + ctx.setOutputArray(0, actual); + ctx.setIArguments({1}); + sd::ops::hsv_to_rgb op; + auto status = op.execute(&ctx); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); } TEST_F(DeclarableOpsTests16, test_hsv_to_rgb_3) { - auto hsvs = NDArrayFactory::create('c', { 4, 3 }, { - 0.705504596f, 0.793608069f, 0.65870738f, 0.848827183f, 0.920532584f, - 0.887555957f, 0.72317636f, 0.563831031f, 0.773604929f, 0.269532293f, - 0.332347751f, 0.111181192f - }); - auto expected = NDArrayFactory::create('c', { 4, 3 }, { - 0.257768334f, 0.135951888f, 0.65870738f, 0.887555957f, 0.0705317783f, - 0.811602857f, 0.485313689f, 0.337422464f, 0.773604929f, 0.0883753772f, - 0.111181192f, 0.074230373f - }); - auto actual = NDArrayFactory::create('c', { 4,3 }); - - Context ctx(1); - ctx.setInputArray(0, &hsvs); - ctx.setOutputArray(0, &actual); - - sd::ops::hsv_to_rgb op; - auto status = op.execute(&ctx); - - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(expected.equalsTo(actual)); - + auto hsvs = NDArrayFactory::create( + 'c', {4, 3}, + {0.705504596f, 0.793608069f, 0.65870738f, 0.848827183f, 0.920532584f, + 0.887555957f, 0.72317636f, 0.563831031f, 0.773604929f, 0.269532293f, + 0.332347751f, 0.111181192f}); + auto expected = NDArrayFactory::create( + 'c', {4, 3}, + {0.257768334f, 0.135951888f, 0.65870738f, 0.887555957f, 0.0705317783f, + 0.811602857f, 0.485313689f, 0.337422464f, 0.773604929f, 0.0883753772f, + 0.111181192f, 0.074230373f}); + auto actual = NDArrayFactory::create('c', {4, 3}); + + Context ctx(1); + ctx.setInputArray(0, hsvs); + ctx.setOutputArray(0, actual); + + sd::ops::hsv_to_rgb op; + auto status = op.execute(&ctx); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); } - TEST_F(DeclarableOpsTests16, test_hsv_to_rgb_4) { - auto hsvs = NDArrayFactory::create('c', { 3, 4 }, { - 0.705504596f, 0.848827183f, 0.72317636f, 0.269532293f, 0.793608069f, - 0.920532584f, 0.563831031f, 0.332347751f, 0.65870738f, 0.887555957f, - 0.773604929f, 0.111181192f - }); - auto expected = NDArrayFactory::create('c', { 3, 4 }, { - 0.257768334f, 0.887555957f, 0.485313689f, 0.0883753772f, 0.135951888f, - 0.0705317783f, 0.337422464f, 0.111181192f, 0.65870738f, 0.811602857f, - 0.773604929f, 0.074230373f - }); - auto actual = NDArrayFactory::create('c', { 3, 4 }); - - Context ctx(1); - ctx.setInputArray(0, &hsvs); - ctx.setOutputArray(0, &actual); - ctx.setIArguments({ 0 }); - sd::ops::hsv_to_rgb op; - auto status = op.execute(&ctx); - - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(expected.equalsTo(actual)); - + auto hsvs = NDArrayFactory::create( + 'c', {3, 4}, + {0.705504596f, 0.848827183f, 0.72317636f, 0.269532293f, 0.793608069f, + 0.920532584f, 0.563831031f, 0.332347751f, 0.65870738f, 0.887555957f, + 0.773604929f, 0.111181192f}); + auto expected = NDArrayFactory::create( + 'c', {3, 4}, + {0.257768334f, 0.887555957f, 0.485313689f, 0.0883753772f, 0.135951888f, + 0.0705317783f, 0.337422464f, 0.111181192f, 0.65870738f, 0.811602857f, + 0.773604929f, 0.074230373f}); + auto actual = NDArrayFactory::create('c', {3, 4}); + + Context ctx(1); + ctx.setInputArray(0, hsvs); + ctx.setOutputArray(0, actual); + ctx.setIArguments({0}); + sd::ops::hsv_to_rgb op; + auto status = op.execute(&ctx); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); } TEST_F(DeclarableOpsTests16, test_hsv_to_rgb_5) { + auto hsvs = NDArrayFactory::create( + 'c', {3}, {0.705504596f, 0.793608069f, 0.65870738f}); + auto expected = NDArrayFactory::create( + 'c', {3}, {0.257768334f, 0.135951888f, 0.65870738f}); - auto hsvs = NDArrayFactory::create('c', { 3 }, { - 0.705504596f, 0.793608069f, 0.65870738f - }); - auto expected = NDArrayFactory::create('c', { 3 }, { - 0.257768334f, 0.135951888f, 0.65870738f - }); + auto actual = NDArrayFactory::create('c', {3}); - auto actual = NDArrayFactory::create('c', { 3 }); + Context ctx(1); + ctx.setInputArray(0, hsvs); + ctx.setOutputArray(0, actual); - Context ctx(1); - ctx.setInputArray(0, &hsvs); - ctx.setOutputArray(0, &actual); - - sd::ops::hsv_to_rgb op; - auto status = op.execute(&ctx); - - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(expected.equalsTo(actual)); + sd::ops::hsv_to_rgb op; + auto status = op.execute(&ctx); + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); } - TEST_F(DeclarableOpsTests16, test_hsv_to_rgb_6) { - - auto hsvs = NDArrayFactory::create('c', { 3, 4 }, { - 0.705504596f, 0.848827183f, 0.72317636f, 0.269532293f, 0.793608069f, - 0.920532584f, 0.563831031f, 0.332347751f, 0.65870738f, 0.887555957f, - 0.773604929f, 0.111181192f - }); - auto rgbs = NDArrayFactory::create('c', { 3, 4 }, { - 0.257768334f, 0.887555957f, 0.485313689f, 0.0883753772f, 0.135951888f, - 0.0705317783f, 0.337422464f, 0.111181192f, 0.65870738f, 0.811602857f, - 0.773604929f, 0.074230373f - }); - - auto actual = NDArrayFactory::create('c', { 3 }); - //get subarray - NDArray subArrHsvs = hsvs.subarray({ NDIndex::all(), NDIndex::point(0) }); - subArrHsvs.reshapei({ 3 }); - NDArray expected = rgbs.subarray({ NDIndex::all(), NDIndex::point(0) }); - expected.reshapei({ 3 }); + auto hsvs = NDArrayFactory::create( + 'c', {3, 4}, + {0.705504596f, 0.848827183f, 0.72317636f, 0.269532293f, 0.793608069f, + 0.920532584f, 0.563831031f, 0.332347751f, 0.65870738f, 0.887555957f, + 0.773604929f, 0.111181192f}); + auto rgbs = NDArrayFactory::create( + 'c', {3, 4}, + {0.257768334f, 0.887555957f, 0.485313689f, 0.0883753772f, 0.135951888f, + 0.0705317783f, 0.337422464f, 0.111181192f, 0.65870738f, 0.811602857f, + 0.773604929f, 0.074230373f}); + + auto actual = NDArrayFactory::create('c', {3}); + // get subarray + NDArray subArrHsvs = hsvs.subarray({NDIndex::all(), NDIndex::point(0)}); + subArrHsvs.reshapei({3}); + NDArray expected = rgbs.subarray({NDIndex::all(), NDIndex::point(0)}); + expected.reshapei({3}); #if 0 //[RANK][SHAPE][STRIDES][OPTIONS][EWS][ORDER] subArrHsvs.printShapeInfo("subArrHsvs"); #endif - Context ctx(1); - ctx.setInputArray(0, &subArrHsvs); - ctx.setOutputArray(0, &actual); - sd::ops::hsv_to_rgb op; - auto status = op.execute(&ctx); - - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(expected.equalsTo(actual)); + Context ctx(1); + ctx.setInputArray(0, subArrHsvs); + ctx.setOutputArray(0, actual); + sd::ops::hsv_to_rgb op; + auto status = op.execute(&ctx); + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); } - - TEST_F(DeclarableOpsTests16, test_rgb_to_yiq_1) { - /** - generated using numpy - _rgb_to_yiq_kernel = np.array([[0.299f, 0.59590059f, 0.2115f], - [0.587f, -0.27455667f, -0.52273617f], - [0.114f, -0.32134392f, 0.31119955f]]) - nnrgbs = np.array([random() for x in range(0,3*4*5)],np.float32).reshape([5,4,3]) - out =np.tensordot(nnrgbs,_rgb_to_yiq_kernel,axes=[[len(nnrgbs.shape)-1],[0]]) - - #alternatively you could use just with apply - out_2=np.apply_along_axis(lambda x: _rgb_to_yiq_kernel.T @ x,len(nnrgbs.shape)-1,nnrgbs) - - */ - auto rgb = NDArrayFactory::create('c', { 5, 4 ,3 }, - { - 0.48055f , 0.80757356f, 0.2564435f , 0.94277316f, 0.17006584f, - 0.33366168f, 0.41727918f, 0.54528666f, 0.48942474f, 0.3305715f , - 0.98633456f, 0.00158441f, 0.97605824f, 0.02462568f, 0.14837205f, - 0.00112842f, 0.99260217f, 0.9585542f , 0.41196227f, 0.3095014f , - 0.6620493f , 0.30888894f, 0.3122602f , 0.7993488f , 0.86656475f, - 0.5997049f , 0.9776477f , 0.72481847f, 0.7835693f , 0.14649455f, - 0.3573504f , 0.33301765f, 0.7853056f , 0.25830218f, 0.59289205f, - 0.41357264f, 0.5934154f , 0.72647524f, 0.6623308f , 0.96197623f, - 0.0720306f , 0.23853847f, 0.1427159f , 0.19581454f, 0.06766324f, - 0.10614152f, 0.26093867f, 0.9584985f , 0.01258832f, 0.8160156f , - 0.56506383f, 0.08418505f, 0.86440504f, 0.6807802f , 0.20662387f, - 0.4153733f , 0.76146203f, 0.50057423f, 0.08274968f, 0.9521758f - }); - - auto expected = NDArrayFactory::create('c', { 5, 4 ,3 }, - { - 0.64696468f, -0.01777124f, -0.24070648f, 0.41975525f, 0.40788622f, - 0.21433232f, 0.50064416f, -0.05832884f, -0.04447775f, 0.67799989f, - -0.07432612f, -0.44518381f, 0.32321111f, 0.52719408f, 0.2397369f , - 0.69227005f, -0.57987869f, -0.22032876f, 0.38032767f, -0.05223263f, - 0.13137188f, 0.3667803f , -0.15853189f, 0.15085728f, 0.72258149f, - 0.03757231f, 0.17403452f, 0.69337627f, 0.16971045f, -0.21071186f, - 0.39185397f, -0.13084008f, 0.145886f , 0.47240727f, -0.1417591f , - -0.12659159f, 0.67937788f, -0.05867803f, -0.04813048f, 0.35710624f, - 0.47681283f, 0.24003804f, 0.1653288f , 0.00953913f, -0.05111816f, - 0.29417614f, -0.31640032f, 0.18433114f, 0.54718234f, -0.39812097f, - -0.24805083f, 0.61018603f, -0.40592682f, -0.22219216f, 0.39241133f, - -0.23560742f, 0.06353694f, 0.3067938f , -0.0304029f , 0.35893188f - }); - - auto actual = NDArrayFactory::create('c', { 5, 4, 3 }); - - Context ctx(1); - ctx.setInputArray(0, &rgb); - ctx.setOutputArray(0, &actual); - - sd::ops::rgb_to_yiq op; - auto status = op.execute(&ctx); - - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(expected.equalsTo(actual)); - + /** + generated using numpy + _rgb_to_yiq_kernel = np.array([[0.299f, 0.59590059f, 0.2115f], + [0.587f, -0.27455667f, -0.52273617f], + [0.114f, -0.32134392f, 0.31119955f]]) + nnrgbs = np.array([random() for x in + range(0,3*4*5)],np.float32).reshape([5,4,3]) out + =np.tensordot(nnrgbs,_rgb_to_yiq_kernel,axes=[[len(nnrgbs.shape)-1],[0]]) + + #alternatively you could use just with apply + out_2=np.apply_along_axis(lambda x: _rgb_to_yiq_kernel.T @ + x,len(nnrgbs.shape)-1,nnrgbs) + + */ + auto rgb = NDArrayFactory::create( + 'c', {5, 4, 3}, + {0.48055f, 0.80757356f, 0.2564435f, 0.94277316f, 0.17006584f, + 0.33366168f, 0.41727918f, 0.54528666f, 0.48942474f, 0.3305715f, + 0.98633456f, 0.00158441f, 0.97605824f, 0.02462568f, 0.14837205f, + 0.00112842f, 0.99260217f, 0.9585542f, 0.41196227f, 0.3095014f, + 0.6620493f, 0.30888894f, 0.3122602f, 0.7993488f, 0.86656475f, + 0.5997049f, 0.9776477f, 0.72481847f, 0.7835693f, 0.14649455f, + 0.3573504f, 0.33301765f, 0.7853056f, 0.25830218f, 0.59289205f, + 0.41357264f, 0.5934154f, 0.72647524f, 0.6623308f, 0.96197623f, + 0.0720306f, 0.23853847f, 0.1427159f, 0.19581454f, 0.06766324f, + 0.10614152f, 0.26093867f, 0.9584985f, 0.01258832f, 0.8160156f, + 0.56506383f, 0.08418505f, 0.86440504f, 0.6807802f, 0.20662387f, + 0.4153733f, 0.76146203f, 0.50057423f, 0.08274968f, 0.9521758f}); + + auto expected = NDArrayFactory::create( + 'c', {5, 4, 3}, + {0.64696468f, -0.01777124f, -0.24070648f, 0.41975525f, 0.40788622f, + 0.21433232f, 0.50064416f, -0.05832884f, -0.04447775f, 0.67799989f, + -0.07432612f, -0.44518381f, 0.32321111f, 0.52719408f, 0.2397369f, + 0.69227005f, -0.57987869f, -0.22032876f, 0.38032767f, -0.05223263f, + 0.13137188f, 0.3667803f, -0.15853189f, 0.15085728f, 0.72258149f, + 0.03757231f, 0.17403452f, 0.69337627f, 0.16971045f, -0.21071186f, + 0.39185397f, -0.13084008f, 0.145886f, 0.47240727f, -0.1417591f, + -0.12659159f, 0.67937788f, -0.05867803f, -0.04813048f, 0.35710624f, + 0.47681283f, 0.24003804f, 0.1653288f, 0.00953913f, -0.05111816f, + 0.29417614f, -0.31640032f, 0.18433114f, 0.54718234f, -0.39812097f, + -0.24805083f, 0.61018603f, -0.40592682f, -0.22219216f, 0.39241133f, + -0.23560742f, 0.06353694f, 0.3067938f, -0.0304029f, 0.35893188f}); + + auto actual = NDArrayFactory::create('c', {5, 4, 3}); + + Context ctx(1); + ctx.setInputArray(0, rgb); + ctx.setOutputArray(0, actual); + + sd::ops::rgb_to_yiq op; + auto status = op.execute(&ctx); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); } TEST_F(DeclarableOpsTests16, test_rgb_to_yiq_2) { - - auto rgb = NDArrayFactory::create('c', { 5, 3, 4 }, - { - 0.48055f , 0.94277316f, 0.41727918f, 0.3305715f , 0.80757356f, - 0.17006584f, 0.54528666f, 0.98633456f, 0.2564435f , 0.33366168f, - 0.48942474f, 0.00158441f, 0.97605824f, 0.00112842f, 0.41196227f, - 0.30888894f, 0.02462568f, 0.99260217f, 0.3095014f , 0.3122602f , - 0.14837205f, 0.9585542f , 0.6620493f , 0.7993488f , 0.86656475f, - 0.72481847f, 0.3573504f , 0.25830218f, 0.5997049f , 0.7835693f , - 0.33301765f, 0.59289205f, 0.9776477f , 0.14649455f, 0.7853056f , - 0.41357264f, 0.5934154f , 0.96197623f, 0.1427159f , 0.10614152f, - 0.72647524f, 0.0720306f , 0.19581454f, 0.26093867f, 0.6623308f , - 0.23853847f, 0.06766324f, 0.9584985f , 0.01258832f, 0.08418505f, - 0.20662387f, 0.50057423f, 0.8160156f , 0.86440504f, 0.4153733f , - 0.08274968f, 0.56506383f, 0.6807802f , 0.76146203f, 0.9521758f - }); - - auto expected = NDArrayFactory::create('c', { 5, 3, 4 }, - { - 0.64696468f, 0.41975525f, 0.50064416f, 0.67799989f, -0.01777124f, - 0.40788622f, -0.05832884f, -0.07432612f, -0.24070648f, 0.21433232f, - -0.04447775f, -0.44518381f, 0.32321111f, 0.69227005f, 0.38032767f, - 0.3667803f , 0.52719408f, -0.57987869f, -0.05223263f, -0.15853189f, - 0.2397369f , -0.22032876f, 0.13137188f, 0.15085728f, 0.72258149f, - 0.69337627f, 0.39185397f, 0.47240727f, 0.03757231f, 0.16971045f, - -0.13084008f, -0.1417591f , 0.17403452f, -0.21071186f, 0.145886f , - -0.12659159f, 0.67937788f, 0.35710624f, 0.1653288f , 0.29417614f, - -0.05867803f, 0.47681283f, 0.00953913f, -0.31640032f, -0.04813048f, - 0.24003804f, -0.05111816f, 0.18433114f, 0.54718234f, 0.61018603f, - 0.39241133f, 0.3067938f , -0.39812097f, -0.40592682f, -0.23560742f, - -0.0304029f , -0.24805083f, -0.22219216f, 0.06353694f, 0.35893188f - }); - - auto actual = NDArrayFactory::create('c', { 5, 3, 4 }); - - Context ctx(1); - ctx.setInputArray(0, &rgb); - ctx.setOutputArray(0, &actual); - ctx.setIArguments({ 1 }); - sd::ops::rgb_to_yiq op; - auto status = op.execute(&ctx); - - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(expected.equalsTo(actual)); - + auto rgb = NDArrayFactory::create( + 'c', {5, 3, 4}, + {0.48055f, 0.94277316f, 0.41727918f, 0.3305715f, 0.80757356f, + 0.17006584f, 0.54528666f, 0.98633456f, 0.2564435f, 0.33366168f, + 0.48942474f, 0.00158441f, 0.97605824f, 0.00112842f, 0.41196227f, + 0.30888894f, 0.02462568f, 0.99260217f, 0.3095014f, 0.3122602f, + 0.14837205f, 0.9585542f, 0.6620493f, 0.7993488f, 0.86656475f, + 0.72481847f, 0.3573504f, 0.25830218f, 0.5997049f, 0.7835693f, + 0.33301765f, 0.59289205f, 0.9776477f, 0.14649455f, 0.7853056f, + 0.41357264f, 0.5934154f, 0.96197623f, 0.1427159f, 0.10614152f, + 0.72647524f, 0.0720306f, 0.19581454f, 0.26093867f, 0.6623308f, + 0.23853847f, 0.06766324f, 0.9584985f, 0.01258832f, 0.08418505f, + 0.20662387f, 0.50057423f, 0.8160156f, 0.86440504f, 0.4153733f, + 0.08274968f, 0.56506383f, 0.6807802f, 0.76146203f, 0.9521758f}); + + auto expected = NDArrayFactory::create( + 'c', {5, 3, 4}, + {0.64696468f, 0.41975525f, 0.50064416f, 0.67799989f, -0.01777124f, + 0.40788622f, -0.05832884f, -0.07432612f, -0.24070648f, 0.21433232f, + -0.04447775f, -0.44518381f, 0.32321111f, 0.69227005f, 0.38032767f, + 0.3667803f, 0.52719408f, -0.57987869f, -0.05223263f, -0.15853189f, + 0.2397369f, -0.22032876f, 0.13137188f, 0.15085728f, 0.72258149f, + 0.69337627f, 0.39185397f, 0.47240727f, 0.03757231f, 0.16971045f, + -0.13084008f, -0.1417591f, 0.17403452f, -0.21071186f, 0.145886f, + -0.12659159f, 0.67937788f, 0.35710624f, 0.1653288f, 0.29417614f, + -0.05867803f, 0.47681283f, 0.00953913f, -0.31640032f, -0.04813048f, + 0.24003804f, -0.05111816f, 0.18433114f, 0.54718234f, 0.61018603f, + 0.39241133f, 0.3067938f, -0.39812097f, -0.40592682f, -0.23560742f, + -0.0304029f, -0.24805083f, -0.22219216f, 0.06353694f, 0.35893188f}); + + auto actual = NDArrayFactory::create('c', {5, 3, 4}); + + Context ctx(1); + ctx.setInputArray(0, rgb); + ctx.setOutputArray(0, actual); + ctx.setIArguments({1}); + sd::ops::rgb_to_yiq op; + auto status = op.execute(&ctx); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); } TEST_F(DeclarableOpsTests16, test_rgb_to_yiq_3) { + auto rgb = NDArrayFactory::create( + 'c', {4, 3}, + {0.48055f, 0.80757356f, 0.2564435f, 0.94277316f, 0.17006584f, 0.33366168f, + 0.41727918f, 0.54528666f, 0.48942474f, 0.3305715f, 0.98633456f, + 0.00158441f}); - auto rgb = NDArrayFactory::create('c', { 4, 3 }, - { - 0.48055f , 0.80757356f, 0.2564435f , 0.94277316f, 0.17006584f, - 0.33366168f, 0.41727918f, 0.54528666f, 0.48942474f, 0.3305715f , - 0.98633456f, 0.00158441f - }); - - auto expected = NDArrayFactory::create('c', { 4, 3 }, - { - 0.64696468f, -0.01777124f, -0.24070648f, 0.41975525f, 0.40788622f, - 0.21433232f, 0.50064416f, -0.05832884f, -0.04447775f, 0.67799989f, - -0.07432612f, -0.44518381f - }); + auto expected = NDArrayFactory::create( + 'c', {4, 3}, + {0.64696468f, -0.01777124f, -0.24070648f, 0.41975525f, 0.40788622f, + 0.21433232f, 0.50064416f, -0.05832884f, -0.04447775f, 0.67799989f, + -0.07432612f, -0.44518381f}); - auto actual = NDArrayFactory::create('c', { 4, 3 }); + auto actual = NDArrayFactory::create('c', {4, 3}); - Context ctx(1); - ctx.setInputArray(0, &rgb); - ctx.setOutputArray(0, &actual); + Context ctx(1); + ctx.setInputArray(0, rgb); + ctx.setOutputArray(0, actual); - sd::ops::rgb_to_yiq op; - auto status = op.execute(&ctx); - - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(expected.equalsTo(actual)); + sd::ops::rgb_to_yiq op; + auto status = op.execute(&ctx); + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); } TEST_F(DeclarableOpsTests16, test_rgb_to_yiq_4) { - - auto rgb = NDArrayFactory::create('c', { 3, 4 }, - { - 0.48055f , 0.94277316f, 0.41727918f, 0.3305715f , 0.80757356f, - 0.17006584f, 0.54528666f, 0.98633456f, 0.2564435f , 0.33366168f, - 0.48942474f, 0.00158441f - }); - - auto expected = NDArrayFactory::create('c', { 3, 4 }, - { - 0.64696468f, 0.41975525f, 0.50064416f, 0.67799989f, -0.01777124f, - 0.40788622f, -0.05832884f, -0.07432612f, -0.24070648f, 0.21433232f, - -0.04447775f, -0.44518381f - }); - - auto actual = NDArrayFactory::create('c', { 3, 4 }); - - Context ctx(1); - ctx.setInputArray(0, &rgb); - ctx.setOutputArray(0, &actual); - ctx.setIArguments({ 0 }); - sd::ops::rgb_to_yiq op; - auto status = op.execute(&ctx); - - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(expected.equalsTo(actual)); + auto rgb = NDArrayFactory::create( + 'c', {3, 4}, + {0.48055f, 0.94277316f, 0.41727918f, 0.3305715f, 0.80757356f, 0.17006584f, + 0.54528666f, 0.98633456f, 0.2564435f, 0.33366168f, 0.48942474f, + 0.00158441f}); + + auto expected = NDArrayFactory::create( + 'c', {3, 4}, + {0.64696468f, 0.41975525f, 0.50064416f, 0.67799989f, -0.01777124f, + 0.40788622f, -0.05832884f, -0.07432612f, -0.24070648f, 0.21433232f, + -0.04447775f, -0.44518381f}); + + auto actual = NDArrayFactory::create('c', {3, 4}); + + Context ctx(1); + ctx.setInputArray(0, rgb); + ctx.setOutputArray(0, actual); + ctx.setIArguments({0}); + sd::ops::rgb_to_yiq op; + auto status = op.execute(&ctx); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); } - - TEST_F(DeclarableOpsTests16, test_rgb_to_yiq_5) { - - auto rgbs = NDArrayFactory::create('c', { 3 }, - { 0.48055f , 0.80757356f, 0.2564435f }); - auto expected = NDArrayFactory::create('c', { 3 }, - { 0.64696468f, -0.01777124f, -0.24070648f, }); - - - auto actual = NDArrayFactory::create('c', { 3 }); - - Context ctx(1); - ctx.setInputArray(0, &rgbs); - ctx.setOutputArray(0, &actual); - - sd::ops::rgb_to_yiq op; - auto status = op.execute(&ctx); - - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(expected.equalsTo(actual)); + auto rgbs = NDArrayFactory::create( + 'c', {3}, {0.48055f, 0.80757356f, 0.2564435f}); + auto expected = NDArrayFactory::create('c', {3}, + { + 0.64696468f, + -0.01777124f, + -0.24070648f, + }); + + auto actual = NDArrayFactory::create('c', {3}); + + Context ctx(1); + ctx.setInputArray(0, rgbs); + ctx.setOutputArray(0, actual); + + sd::ops::rgb_to_yiq op; + auto status = op.execute(&ctx); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); } TEST_F(DeclarableOpsTests16, test_rgb_to_yiq_6) { - - auto rgbs = NDArrayFactory::create('c', { 3, 4 }, - { - 0.48055f , 0.94277316f, 0.41727918f, 0.3305715f , 0.80757356f, - 0.17006584f, 0.54528666f, 0.98633456f, 0.2564435f , 0.33366168f, - 0.48942474f, 0.00158441f - }); - - auto yiqs = NDArrayFactory::create('c', { 3, 4 }, - { - 0.64696468f, 0.41975525f, 0.50064416f, 0.67799989f, -0.01777124f, - 0.40788622f, -0.05832884f, -0.07432612f, -0.24070648f, 0.21433232f, - -0.04447775f, -0.44518381f - }); - - //get subarray - NDArray subArrRgbs = rgbs.subarray({ NDIndex::all(), NDIndex::point(0) }); - NDArray expected = yiqs.subarray({ NDIndex::all(), NDIndex::point(0) }); - subArrRgbs.reshapei({ 3 }); - expected.reshapei({ 3 }); + auto rgbs = NDArrayFactory::create( + 'c', {3, 4}, + {0.48055f, 0.94277316f, 0.41727918f, 0.3305715f, 0.80757356f, 0.17006584f, + 0.54528666f, 0.98633456f, 0.2564435f, 0.33366168f, 0.48942474f, + 0.00158441f}); + + auto yiqs = NDArrayFactory::create( + 'c', {3, 4}, + {0.64696468f, 0.41975525f, 0.50064416f, 0.67799989f, -0.01777124f, + 0.40788622f, -0.05832884f, -0.07432612f, -0.24070648f, 0.21433232f, + -0.04447775f, -0.44518381f}); + + // get subarray + NDArray subArrRgbs = rgbs.subarray({NDIndex::all(), NDIndex::point(0)}); + NDArray expected = yiqs.subarray({NDIndex::all(), NDIndex::point(0)}); + subArrRgbs.reshapei({3}); + expected.reshapei({3}); #if 0 //[RANK][SHAPE][STRIDES][OPTIONS][EWS][ORDER] subArrRgbs.printShapeInfo("subArrRgbs"); #endif - auto actual = NDArrayFactory::create('c', { 3 }); - - Context ctx(1); - ctx.setInputArray(0, &subArrRgbs); - ctx.setOutputArray(0, &actual); - sd::ops::rgb_to_yiq op; - auto status = op.execute(&ctx); + auto actual = NDArrayFactory::create('c', {3}); - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(expected.equalsTo(actual)); + Context ctx(1); + ctx.setInputArray(0, subArrRgbs); + ctx.setOutputArray(0, actual); + sd::ops::rgb_to_yiq op; + auto status = op.execute(&ctx); + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); } TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_1) { - - auto yiqs = NDArrayFactory::create('c', { 5, 4, 3 }, { - 0.775258899f, -0.288912386f, -0.132725924f, 0.0664454922f, -0.212469354f, - 0.455438733f, 0.418221354f, 0.349350512f, 0.145902053f, 0.947576523f, - -0.471601307f, 0.263960421f, 0.700227439f, 0.32434237f, -0.278446227f, - 0.130805135f, -0.438441873f, 0.187127829f, 0.0276055578f, -0.179727226f, - 0.305075705f, 0.716282248f, 0.278215706f, -0.44586885f, 0.76971364f, - 0.131288841f, -0.141177326f, 0.900081575f, -0.0788725987f, 0.14756602f, - 0.387832165f, 0.229834676f, 0.47921446f, 0.632930398f, 0.0443540029f, - -0.268817365f, 0.0977194682f, -0.141669706f, -0.140715122f, 0.946808815f, - -0.52525419f, -0.106209636f, 0.659476519f, 0.391066104f, 0.426448852f, - 0.496989518f, -0.283434421f, -0.177366048f, 0.715208411f, -0.496444523f, - 0.189553142f, 0.616444945f, 0.345852494f, 0.447739422f, 0.224696323f, - 0.451372236f, 0.298027098f, 0.446561724f, -0.187599331f, -0.448159873f - }); - auto expected = NDArrayFactory::create('c', { 5, 4, 3 }, { - 0.416663059f, 0.939747555f, 0.868814286f, 0.146075352f, -0.170521997f, - 1.07776645f, 0.842775284f, 0.228765106f, 0.280231822f, 0.660605291f, - 0.905021825f, 1.91936605f, 0.837427991f, 0.792213732f, -0.133271854f, - -0.17216571f, 0.128957025f, 0.934955336f, 0.0451873479f, -0.120952621f, - 0.746436225f, 0.705446224f, 0.929172217f, -0.351493549f, 0.807577594f, - 0.825371955f, 0.383812296f, 0.916293093f, 0.82603058f, 1.23885956f, - 0.905059196f, 0.015164554f, 0.950156781f, 0.508443732f, 0.794845279f, - 0.12571529f, -0.125074273f, 0.227326869f, 0.0147000261f, 0.378735409f, - 1.15842402f, 1.34712305f, 1.2980804f, 0.277102016f, 0.953435072f, - 0.115916842f, 0.688879376f, 0.508405162f, 0.35829352f, 0.727568094f, - 1.58768577f, 1.22504294f, 0.232589777f, 0.996727258f, 0.841224629f, - -0.0909671176f, 0.233051388f, -0.0110094378f, 0.787642119f, -0.109582274f - }); - auto actual = NDArrayFactory::create('c', { 5, 4, 3 }); - - Context ctx(1); - ctx.setInputArray(0, &yiqs); - ctx.setOutputArray(0, &actual); - - sd::ops::yiq_to_rgb op; - auto status = op.execute(&ctx); - - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(expected.equalsTo(actual)); - + auto yiqs = NDArrayFactory::create( + 'c', {5, 4, 3}, + {0.775258899f, -0.288912386f, -0.132725924f, 0.0664454922f, + -0.212469354f, 0.455438733f, 0.418221354f, 0.349350512f, + 0.145902053f, 0.947576523f, -0.471601307f, 0.263960421f, + 0.700227439f, 0.32434237f, -0.278446227f, 0.130805135f, + -0.438441873f, 0.187127829f, 0.0276055578f, -0.179727226f, + 0.305075705f, 0.716282248f, 0.278215706f, -0.44586885f, + 0.76971364f, 0.131288841f, -0.141177326f, 0.900081575f, + -0.0788725987f, 0.14756602f, 0.387832165f, 0.229834676f, + 0.47921446f, 0.632930398f, 0.0443540029f, -0.268817365f, + 0.0977194682f, -0.141669706f, -0.140715122f, 0.946808815f, + -0.52525419f, -0.106209636f, 0.659476519f, 0.391066104f, + 0.426448852f, 0.496989518f, -0.283434421f, -0.177366048f, + 0.715208411f, -0.496444523f, 0.189553142f, 0.616444945f, + 0.345852494f, 0.447739422f, 0.224696323f, 0.451372236f, + 0.298027098f, 0.446561724f, -0.187599331f, -0.448159873f}); + auto expected = NDArrayFactory::create( + 'c', {5, 4, 3}, + {0.416663059f, 0.939747555f, 0.868814286f, 0.146075352f, + -0.170521997f, 1.07776645f, 0.842775284f, 0.228765106f, + 0.280231822f, 0.660605291f, 0.905021825f, 1.91936605f, + 0.837427991f, 0.792213732f, -0.133271854f, -0.17216571f, + 0.128957025f, 0.934955336f, 0.0451873479f, -0.120952621f, + 0.746436225f, 0.705446224f, 0.929172217f, -0.351493549f, + 0.807577594f, 0.825371955f, 0.383812296f, 0.916293093f, + 0.82603058f, 1.23885956f, 0.905059196f, 0.015164554f, + 0.950156781f, 0.508443732f, 0.794845279f, 0.12571529f, + -0.125074273f, 0.227326869f, 0.0147000261f, 0.378735409f, + 1.15842402f, 1.34712305f, 1.2980804f, 0.277102016f, + 0.953435072f, 0.115916842f, 0.688879376f, 0.508405162f, + 0.35829352f, 0.727568094f, 1.58768577f, 1.22504294f, + 0.232589777f, 0.996727258f, 0.841224629f, -0.0909671176f, + 0.233051388f, -0.0110094378f, 0.787642119f, -0.109582274f}); + auto actual = NDArrayFactory::create('c', {5, 4, 3}); + + Context ctx(1); + ctx.setInputArray(0, yiqs); + ctx.setOutputArray(0, actual); + + sd::ops::yiq_to_rgb op; + auto status = op.execute(&ctx); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); } TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_2) { - - auto yiqs = NDArrayFactory::create('c', { 5, 3, 4 }, { - 0.775258899f, 0.0664454922f, 0.418221354f, 0.947576523f, -0.288912386f, - -0.212469354f, 0.349350512f, -0.471601307f, -0.132725924f, 0.455438733f, - 0.145902053f, 0.263960421f, 0.700227439f, 0.130805135f, 0.0276055578f, - 0.716282248f, 0.32434237f, -0.438441873f, -0.179727226f, 0.278215706f, - -0.278446227f, 0.187127829f, 0.305075705f, -0.44586885f, 0.76971364f, - 0.900081575f, 0.387832165f, 0.632930398f, 0.131288841f, -0.0788725987f, - 0.229834676f, 0.0443540029f, -0.141177326f, 0.14756602f, 0.47921446f, - -0.268817365f, 0.0977194682f, 0.946808815f, 0.659476519f, 0.496989518f, - -0.141669706f, -0.52525419f, 0.391066104f, -0.283434421f, -0.140715122f, - -0.106209636f, 0.426448852f, -0.177366048f, 0.715208411f, 0.616444945f, - 0.224696323f, 0.446561724f, -0.496444523f, 0.345852494f, 0.451372236f, - -0.187599331f, 0.189553142f, 0.447739422f, 0.298027098f, -0.448159873f - }); - auto expected = NDArrayFactory::create('c', { 5, 3, 4 }, { - 0.416663059f, 0.146075352f, 0.842775284f, 0.660605291f, 0.939747555f, - -0.170521997f, 0.228765106f, 0.905021825f, 0.868814286f, 1.07776645f, - 0.280231822f, 1.91936605f, 0.837427991f, -0.17216571f, 0.0451873479f, - 0.705446224f, 0.792213732f, 0.128957025f, -0.120952621f, 0.929172217f, - -0.133271854f, 0.934955336f, 0.746436225f, -0.351493549f, 0.807577594f, - 0.916293093f, 0.905059196f, 0.508443732f, 0.825371955f, 0.82603058f, - 0.015164554f, 0.794845279f, 0.383812296f, 1.23885956f, 0.950156781f, - 0.12571529f, -0.125074273f, 0.378735409f, 1.2980804f, 0.115916842f, - 0.227326869f, 1.15842402f, 0.277102016f, 0.688879376f, 0.0147000261f, - 1.34712305f, 0.953435072f, 0.508405162f, 0.35829352f, 1.22504294f, - 0.841224629f, -0.0110094378f, 0.727568094f, 0.232589777f, -0.0909671176f, - 0.787642119f, 1.58768577f, 0.996727258f, 0.233051388f, -0.109582274f - }); - auto actual = NDArrayFactory::create('c', { 5, 3, 4 }); - - Context ctx(1); - ctx.setInputArray(0, &yiqs); - ctx.setOutputArray(0, &actual); - ctx.setIArguments({ 1 }); - sd::ops::yiq_to_rgb op; - auto status = op.execute(&ctx); - - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(expected.equalsTo(actual)); - + auto yiqs = NDArrayFactory::create( + 'c', {5, 3, 4}, + {0.775258899f, 0.0664454922f, 0.418221354f, 0.947576523f, + -0.288912386f, -0.212469354f, 0.349350512f, -0.471601307f, + -0.132725924f, 0.455438733f, 0.145902053f, 0.263960421f, + 0.700227439f, 0.130805135f, 0.0276055578f, 0.716282248f, + 0.32434237f, -0.438441873f, -0.179727226f, 0.278215706f, + -0.278446227f, 0.187127829f, 0.305075705f, -0.44586885f, + 0.76971364f, 0.900081575f, 0.387832165f, 0.632930398f, + 0.131288841f, -0.0788725987f, 0.229834676f, 0.0443540029f, + -0.141177326f, 0.14756602f, 0.47921446f, -0.268817365f, + 0.0977194682f, 0.946808815f, 0.659476519f, 0.496989518f, + -0.141669706f, -0.52525419f, 0.391066104f, -0.283434421f, + -0.140715122f, -0.106209636f, 0.426448852f, -0.177366048f, + 0.715208411f, 0.616444945f, 0.224696323f, 0.446561724f, + -0.496444523f, 0.345852494f, 0.451372236f, -0.187599331f, + 0.189553142f, 0.447739422f, 0.298027098f, -0.448159873f}); + auto expected = NDArrayFactory::create( + 'c', {5, 3, 4}, + {0.416663059f, 0.146075352f, 0.842775284f, 0.660605291f, + 0.939747555f, -0.170521997f, 0.228765106f, 0.905021825f, + 0.868814286f, 1.07776645f, 0.280231822f, 1.91936605f, + 0.837427991f, -0.17216571f, 0.0451873479f, 0.705446224f, + 0.792213732f, 0.128957025f, -0.120952621f, 0.929172217f, + -0.133271854f, 0.934955336f, 0.746436225f, -0.351493549f, + 0.807577594f, 0.916293093f, 0.905059196f, 0.508443732f, + 0.825371955f, 0.82603058f, 0.015164554f, 0.794845279f, + 0.383812296f, 1.23885956f, 0.950156781f, 0.12571529f, + -0.125074273f, 0.378735409f, 1.2980804f, 0.115916842f, + 0.227326869f, 1.15842402f, 0.277102016f, 0.688879376f, + 0.0147000261f, 1.34712305f, 0.953435072f, 0.508405162f, + 0.35829352f, 1.22504294f, 0.841224629f, -0.0110094378f, + 0.727568094f, 0.232589777f, -0.0909671176f, 0.787642119f, + 1.58768577f, 0.996727258f, 0.233051388f, -0.109582274f}); + auto actual = NDArrayFactory::create('c', {5, 3, 4}); + + Context ctx(1); + ctx.setInputArray(0, yiqs); + ctx.setOutputArray(0, actual); + ctx.setIArguments({1}); + sd::ops::yiq_to_rgb op; + auto status = op.execute(&ctx); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); } TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_3) { - - auto yiqs = NDArrayFactory::create('c', { 4, 3 }, { - 0.775258899f, -0.288912386f, -0.132725924f, 0.0664454922f, -0.212469354f, - 0.455438733f, 0.418221354f, 0.349350512f, 0.145902053f, 0.947576523f, - -0.471601307f, 0.263960421f - }); - auto expected = NDArrayFactory::create('c', { 4, 3 }, { - 0.416663059f, 0.939747555f, 0.868814286f, 0.146075352f, -0.170521997f, - 1.07776645f, 0.842775284f, 0.228765106f, 0.280231822f, 0.660605291f, - 0.905021825f, 1.91936605f - }); - auto actual = NDArrayFactory::create('c', { 4, 3 }); - - Context ctx(1); - ctx.setInputArray(0, &yiqs); - ctx.setOutputArray(0, &actual); - - sd::ops::yiq_to_rgb op; - auto status = op.execute(&ctx); - - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(expected.equalsTo(actual)); - + auto yiqs = NDArrayFactory::create( + 'c', {4, 3}, + {0.775258899f, -0.288912386f, -0.132725924f, 0.0664454922f, -0.212469354f, + 0.455438733f, 0.418221354f, 0.349350512f, 0.145902053f, 0.947576523f, + -0.471601307f, 0.263960421f}); + auto expected = NDArrayFactory::create( + 'c', {4, 3}, + {0.416663059f, 0.939747555f, 0.868814286f, 0.146075352f, -0.170521997f, + 1.07776645f, 0.842775284f, 0.228765106f, 0.280231822f, 0.660605291f, + 0.905021825f, 1.91936605f}); + auto actual = NDArrayFactory::create('c', {4, 3}); + + Context ctx(1); + ctx.setInputArray(0, yiqs); + ctx.setOutputArray(0, actual); + + sd::ops::yiq_to_rgb op; + auto status = op.execute(&ctx); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); } TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_4) { - - auto yiqs = NDArrayFactory::create('c', { 3, 4 }, { - 0.775258899f, 0.0664454922f, 0.418221354f, 0.947576523f, -0.288912386f, - -0.212469354f, 0.349350512f, -0.471601307f, -0.132725924f, 0.455438733f, - 0.145902053f, 0.263960421f - }); - auto expected = NDArrayFactory::create('c', { 3, 4 }, { - 0.416663059f, 0.146075352f, 0.842775284f, 0.660605291f, 0.939747555f, - -0.170521997f, 0.228765106f, 0.905021825f, 0.868814286f, 1.07776645f, - 0.280231822f, 1.91936605f - }); - auto actual = NDArrayFactory::create('c', { 3, 4 }); - - Context ctx(1); - ctx.setInputArray(0, &yiqs); - ctx.setOutputArray(0, &actual); - ctx.setIArguments({ 0 }); - sd::ops::yiq_to_rgb op; - auto status = op.execute(&ctx); - - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(expected.equalsTo(actual)); - + auto yiqs = NDArrayFactory::create( + 'c', {3, 4}, + {0.775258899f, 0.0664454922f, 0.418221354f, 0.947576523f, -0.288912386f, + -0.212469354f, 0.349350512f, -0.471601307f, -0.132725924f, 0.455438733f, + 0.145902053f, 0.263960421f}); + auto expected = NDArrayFactory::create( + 'c', {3, 4}, + {0.416663059f, 0.146075352f, 0.842775284f, 0.660605291f, 0.939747555f, + -0.170521997f, 0.228765106f, 0.905021825f, 0.868814286f, 1.07776645f, + 0.280231822f, 1.91936605f}); + auto actual = NDArrayFactory::create('c', {3, 4}); + + Context ctx(1); + ctx.setInputArray(0, yiqs); + ctx.setOutputArray(0, actual); + ctx.setIArguments({0}); + sd::ops::yiq_to_rgb op; + auto status = op.execute(&ctx); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); } TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_5) { - - auto yiqs = NDArrayFactory::create('c', { 3 }, { - 0.775258899f, -0.288912386f, -0.132725924f - }); - auto expected = NDArrayFactory::create('c', { 3 }, { - 0.416663059f, 0.939747555f, 0.868814286f - }); - auto actual = NDArrayFactory::create('c', { 3 }); - - Context ctx(1); - ctx.setInputArray(0, &yiqs); - ctx.setOutputArray(0, &actual); - - sd::ops::yiq_to_rgb op; - auto status = op.execute(&ctx); + auto yiqs = NDArrayFactory::create( + 'c', {3}, {0.775258899f, -0.288912386f, -0.132725924f}); + auto expected = NDArrayFactory::create( + 'c', {3}, {0.416663059f, 0.939747555f, 0.868814286f}); + auto actual = NDArrayFactory::create('c', {3}); + + Context ctx(1); + ctx.setInputArray(0, yiqs); + ctx.setOutputArray(0, actual); + + sd::ops::yiq_to_rgb op; + auto status = op.execute(&ctx); #if 0 actual.printBuffer("actual"); expected.printBuffer("expected"); #endif - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(expected.equalsTo(actual)); + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); } TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_6) { - - auto yiqs = NDArrayFactory::create('c', { 3, 4 }, { - 0.775258899f, 0.0664454922f, 0.418221354f, 0.947576523f, -0.288912386f, - -0.212469354f, 0.349350512f, -0.471601307f, -0.132725924f, 0.455438733f, - 0.145902053f, 0.263960421f - }); - auto rgbs = NDArrayFactory::create('c', { 3, 4 }, { - 0.416663059f, 0.146075352f, 0.842775284f, 0.660605291f, 0.939747555f, - -0.170521997f, 0.228765106f, 0.905021825f, 0.868814286f, 1.07776645f, - 0.280231822f, 1.91936605f - }); - - //get subarray - NDArray subArrYiqs = yiqs.subarray({ NDIndex::all(), NDIndex::point(0) }); - NDArray expected = rgbs.subarray({ NDIndex::all(), NDIndex::point(0) }); - subArrYiqs.reshapei({ 3 }); - expected.reshapei({ 3 }); + auto yiqs = NDArrayFactory::create( + 'c', {3, 4}, + {0.775258899f, 0.0664454922f, 0.418221354f, 0.947576523f, -0.288912386f, + -0.212469354f, 0.349350512f, -0.471601307f, -0.132725924f, 0.455438733f, + 0.145902053f, 0.263960421f}); + auto rgbs = NDArrayFactory::create( + 'c', {3, 4}, + {0.416663059f, 0.146075352f, 0.842775284f, 0.660605291f, 0.939747555f, + -0.170521997f, 0.228765106f, 0.905021825f, 0.868814286f, 1.07776645f, + 0.280231822f, 1.91936605f}); + + // get subarray + NDArray subArrYiqs = yiqs.subarray({NDIndex::all(), NDIndex::point(0)}); + NDArray expected = rgbs.subarray({NDIndex::all(), NDIndex::point(0)}); + subArrYiqs.reshapei({3}); + expected.reshapei({3}); #if 0 //[RANK][SHAPE][STRIDES][OPTIONS][EWS][ORDER] subArrYiqs.printShapeInfo("subArrYiqs"); #endif - auto actual = NDArrayFactory::create('c', { 3 }); + auto actual = NDArrayFactory::create('c', {3}); - Context ctx(1); - ctx.setInputArray(0, &subArrYiqs); - ctx.setOutputArray(0, &actual); - sd::ops::yiq_to_rgb op; - auto status = op.execute(&ctx); + Context ctx(1); + ctx.setInputArray(0, subArrYiqs); + ctx.setOutputArray(0, actual); + sd::ops::yiq_to_rgb op; + auto status = op.execute(&ctx); - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(expected.equalsTo(actual)); + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); } //////////////////////////////////////////////////////////////////////////////// @@ -1127,7 +1100,7 @@ TEST_F(DeclarableOpsTests16, clipbynorm_3) { auto result = op.evaluate({&x}, {1.0}, {1}); auto z = result.at(0); - auto zNorm1 = z->reduceAlongDimension(reduce::Norm2, {1}, true); + auto zNorm1 = z.reduceAlongDimension(reduce::Norm2, {1}, true); auto exp = NDArrayFactory::create('c', {3, 1}, {1., 1., xNorm1.e(2)}); ASSERT_TRUE(exp.isSameShape(&zNorm1)); @@ -1282,7 +1255,7 @@ TEST_F(DeclarableOpsTests16, clipbynorm_12) { sd::ops::clipbynorm op; auto result = op.evaluate({&x}, {0.54}, {}); - ASSERT_EQ(e, *result.at(0)); + ASSERT_EQ(e, result.at(0)); } diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests17.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests17.cpp index 1341312f8deb..e0c0bd451984 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests17.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests17.cpp @@ -14,77 +14,73 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - // // @author raver119@gmail.com // -#include "testlayers.h" -#include #include -#include #include +#include +#include + #include +#include "testlayers.h" using namespace sd; - class DeclarableOpsTests17 : public testing::Test { -public: - - DeclarableOpsTests17() { - printf("\n"); - fflush(stdout); - } + public: + DeclarableOpsTests17() { + printf("\n"); + fflush(stdout); + } }; TEST_F(DeclarableOpsTests17, test_sparse_to_dense_1) { - auto values = NDArrayFactory::create({1.f, 2.f, 3.f}); - auto shape = NDArrayFactory::create({3, 3}); - auto ranges = NDArrayFactory::create({0,0, 1,1, 2,2}); - auto def = NDArrayFactory::create(0.f); - auto exp = NDArrayFactory::create('c', {3, 3}, {1.f,0.f,0.f, 0.f,2.f,0.f, 0.f,0.f,3.f}); - - - sd::ops::compat_sparse_to_dense op; - auto result = op.evaluate({&ranges, &shape, &values, &def}); - ASSERT_EQ(Status::OK(), result.status()); + auto values = NDArrayFactory::create({1.f, 2.f, 3.f}); + auto shape = NDArrayFactory::create({3, 3}); + auto ranges = NDArrayFactory::create({0, 0, 1, 1, 2, 2}); + auto def = NDArrayFactory::create(0.f); + auto exp = NDArrayFactory::create( + 'c', {3, 3}, {1.f, 0.f, 0.f, 0.f, 2.f, 0.f, 0.f, 0.f, 3.f}); + + sd::ops::compat_sparse_to_dense op; + auto result = op.evaluate({&ranges, &shape, &values, &def}); + ASSERT_EQ(Status::OK(), result.status()); } TEST_F(DeclarableOpsTests17, test_sparse_to_dense_2) { - auto values = NDArrayFactory::string({3}, {"alpha", "beta", "gamma"}); - auto shape = NDArrayFactory::create({3, 3}); - auto ranges = NDArrayFactory::create({0,0, 1,1, 2,2}); - auto def = NDArrayFactory::string("d"); - auto exp = NDArrayFactory::string( {3, 3}, {"alpha","d","d", "d","beta","d", "d","d","gamma"}); - - - sd::ops::compat_sparse_to_dense op; - auto result = op.evaluate({&ranges, &shape, &values, &def}); - ASSERT_EQ(Status::OK(), result.status()); - + auto values = NDArrayFactory::string({3}, {"alpha", "beta", "gamma"}); + auto shape = NDArrayFactory::create({3, 3}); + auto ranges = NDArrayFactory::create({0, 0, 1, 1, 2, 2}); + auto def = NDArrayFactory::string("d"); + auto exp = NDArrayFactory::string( + {3, 3}, {"alpha", "d", "d", "d", "beta", "d", "d", "d", "gamma"}); + + sd::ops::compat_sparse_to_dense op; + auto result = op.evaluate({&ranges, &shape, &values, &def}); + ASSERT_EQ(Status::OK(), result.status()); } TEST_F(DeclarableOpsTests17, test_compat_string_split_1) { - auto x = NDArrayFactory::string( {2}, {"first string", "second"}); - auto delimiter = NDArrayFactory::string(" "); - - auto exp0 = NDArrayFactory::create({0,0, 0,1, 1,0}); - auto exp1 = NDArrayFactory::string( {3}, {"first", "string", "second"}); + auto x = NDArrayFactory::string({2}, {"first string", "second"}); + auto delimiter = NDArrayFactory::string(" "); - sd::ops::compat_string_split op; - auto result = op.evaluate({&x, &delimiter}); - ASSERT_EQ(Status::OK(), result.status()); - ASSERT_EQ(2, result.size()); + auto exp0 = NDArrayFactory::create({0, 0, 0, 1, 1, 0}); + auto exp1 = NDArrayFactory::string({3}, {"first", "string", "second"}); - auto z0 = result.at(0); - auto z1 = result.at(1); + sd::ops::compat_string_split op; + auto result = op.evaluate({&x, &delimiter}); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_EQ(2, result.size()); - ASSERT_TRUE(exp0.isSameShape(z0)); - ASSERT_TRUE(exp1.isSameShape(z1)); + auto z0 = result.at(0); + auto z1 = result.at(1); - ASSERT_EQ(exp0, *z0); - ASSERT_EQ(exp1, *z1); + ASSERT_TRUE(exp0.isSameShape(z0)); + ASSERT_TRUE(exp1.isSameShape(z1)); + ASSERT_EQ(exp0, z0); + ASSERT_EQ(exp1, z1); } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests18.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests18.cpp index 1f36a8f2c77a..bcea71efc00b 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests18.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests18.cpp @@ -14,1596 +14,2613 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ +// +// @author raver119@gmail.com +// - // - // @author raver119@gmail.com - // - -#include "testlayers.h" -#include #include -#include #include +#include +#include + #include +#include "testlayers.h" using namespace sd; - class DeclarableOpsTests18 : public testing::Test { -public: - - DeclarableOpsTests18() { - printf("\n"); - fflush(stdout); - } + public: + DeclarableOpsTests18() { + printf("\n"); + fflush(stdout); + } }; TEST_F(DeclarableOpsTests18, test_bitcast_1) { - auto x = NDArrayFactory::create(0.23028551377579154); - auto z = NDArrayFactory::create(0); - auto e = NDArrayFactory::create(4597464930322771456L); + auto x = NDArrayFactory::create(0.23028551377579154); + auto z = NDArrayFactory::create(0); + auto e = NDArrayFactory::create(4597464930322771456L); - sd::ops::bitcast op; - auto status = op.execute({ &x }, { &z }, {}, { (Nd4jLong)sd::DataType::INT64 }, {}); - ASSERT_EQ(Status::OK(), status); + sd::ops::bitcast op; + auto status = op.execute({&x}, {&z}, {}, {(Nd4jLong)sd::DataType::INT64}, {}); + ASSERT_EQ(Status::OK(), status); - ASSERT_EQ(e, z); + ASSERT_EQ(e, z); } TEST_F(DeclarableOpsTests18, test_tanh_1) { - auto x = NDArrayFactory::create('c', { 8 }, { 0.23f, -0.23f, 0.35f, -0.35f, 0.64f, -0.64f, 100000.f, -100000.f }); - auto z = x.ulike(); - auto e = NDArrayFactory::create('c', { 8 }, { 0.226028f, -0.226028f, 0.336376f, -0.336376f, 0.564900f, -0.564900f, 1.f, -1.f }); - - sd::ops::tanh op; - op.execute({ &x }, { &z }); - - ASSERT_EQ(e, z); + auto x = NDArrayFactory::create( + 'c', {8}, + {0.23f, -0.23f, 0.35f, -0.35f, 0.64f, -0.64f, 100000.f, -100000.f}); + auto z = x.ulike(); + auto e = NDArrayFactory::create( + 'c', {8}, + {0.226028f, -0.226028f, 0.336376f, -0.336376f, 0.564900f, -0.564900f, 1.f, + -1.f}); + + sd::ops::tanh op; + op.execute({&x}, {&z}); + + ASSERT_EQ(e, z); } ///////////////////////////////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests18, test_tanh_2) { - - NDArray x('c', { 2, 2, 3, 3, 4, 4 }, sd::DataType::FLOAT32); - NDArray z('c', { 2, 2, 3, 3, 4, 4 }, sd::DataType::FLOAT32); - - x.linspace(-1., 0.003); - - NDArray e('c', { 2, 2, 3, 3, 4, 4 }, { -0.761594, -0.760331, -0.759063, -0.757788, -0.756508, -0.755222, -0.753930, -0.752633, -0.751329, -0.750020, -0.748704, -0.747383, -0.746056, -0.744723, -0.743383, -0.742038, -0.740687, -0.739330, -0.737967, -0.736598, -0.735222, -0.733841, -0.732453, -0.731060, -0.729660, -0.728254, -0.726842, -0.725424, -0.724000, -0.722569, -0.721132, -0.719689, -0.718240, -0.716784, -0.715323, -0.713854, -0.712380, -0.710899, -0.709412, -0.707919, -0.706419, -0.704913, -0.703401, -0.701882, -0.700357, -0.698825, -0.697287, -0.695742, -0.694191, -0.692634, -0.691069, -0.689499, -0.687922, -0.686338, -0.684748, -0.683152, -0.681548, -0.679939, -0.678322, -0.676699, -0.675070, -0.673434, -0.671791, -0.670142, -0.668486, -0.666823, -0.665153, -0.663477, -0.661795, -0.660105, -0.658409, -0.656706, -0.654997, -0.653280, -0.651557, -0.649827, -0.648091, -0.646348, -0.644597, -0.642841, -0.641077, -0.639306, -0.637529, -0.635745, -0.633954, -0.632157, -0.630352, -0.628541, -0.626722, -0.624897, -0.623065, -0.621227, -0.619381, -0.617528, -0.615669, -0.613803, -0.611929, -0.610049, -0.608162, -0.606269, -0.604368, -0.602460, -0.600546, -0.598624, -0.596696, -0.594760, -0.592818, -0.590869, -0.588913, -0.586950, -0.584980, -0.583003, -0.581019, -0.579029, -0.577031, -0.575026, -0.573015, -0.570996, -0.568971, -0.566939, -0.564900, -0.562853, -0.560800, -0.558740, -0.556674, -0.554600, -0.552519, -0.550431, -0.548337, -0.546235, -0.544127, -0.542012, -0.539890, -0.537761, -0.535625, -0.533482, -0.531332, -0.529176, -0.527013, -0.524842, -0.522665, -0.520482, -0.518291, -0.516093, -0.513889, -0.511678, -0.509460, -0.507235, -0.505004, -0.502765, -0.500520, -0.498268, -0.496010, -0.493745, -0.491472, -0.489194, -0.486908, -0.484616, -0.482318, -0.480012, -0.477700, -0.475381, -0.473056, -0.470724, -0.468385, -0.466040, -0.463689, -0.461330, -0.458966, -0.456594, -0.454216, -0.451832, -0.449441, -0.447044, -0.444640, -0.442230, -0.439814, -0.437391, -0.434962, -0.432526, -0.430084, -0.427636, -0.425181, -0.422721, -0.420254, -0.417780, -0.415301, -0.412815, -0.410323, -0.407825, -0.405321, -0.402811, -0.400295, -0.397773, -0.395244, -0.392710, -0.390170, -0.387623, -0.385071, -0.382513, -0.379949, -0.377379, -0.374803, -0.372222, -0.369635, -0.367042, -0.364443, -0.361839, -0.359229, -0.356613, -0.353992, -0.351365, -0.348732, -0.346095, -0.343451, -0.340802, -0.338148, -0.335488, -0.332823, -0.330153, -0.327477, -0.324796, -0.322110, -0.319419, -0.316723, -0.314021, -0.311314, -0.308602, -0.305886, -0.303164, -0.300437, -0.297705, -0.294969, -0.292227, -0.289481, -0.286730, -0.283975, -0.281214, -0.278449, -0.275679, -0.272905, -0.270126, -0.267343, -0.264555, -0.261763, -0.258966, -0.256165, -0.253360, -0.250550, -0.247737, -0.244919, -0.242097, -0.239270, -0.236440, -0.233606, -0.230768, -0.227925, -0.225079, -0.222229, -0.219376, -0.216518, -0.213657, -0.210792, -0.207923, -0.205051, -0.202176, -0.199297, -0.196414, -0.193528, -0.190639, -0.187746, -0.184850, -0.181951, -0.179049, -0.176144, -0.173235, -0.170324, -0.167409, -0.164492, -0.161572, -0.158649, -0.155723, -0.152794, -0.149863, -0.146929, -0.143992, -0.141053, -0.138112, -0.135168, -0.132221, -0.129273, -0.126322, -0.123368, -0.120413, -0.117455, -0.114496, -0.111534, -0.108570, -0.105605, -0.102637, -0.099668, -0.096697, -0.093724, -0.090750, -0.087774, -0.084796, -0.081817, -0.078836, -0.075854, -0.072871, -0.069886, -0.066900, -0.063913, -0.060924, -0.057935, -0.054945, -0.051953, -0.048961, -0.045968, -0.042974, -0.039979, -0.036983, -0.033987, -0.030990, -0.027993, -0.024995, -0.021996, -0.018998, -0.015999, -0.012999, -0.010000, -0.007000, -0.004000, -0.001000, 0.002000, 0.005000, 0.008000, 0.011000, 0.013999, 0.016998, 0.019997, 0.022996, 0.025994, 0.028992, 0.031989, 0.034986, 0.037982, 0.040977, 0.043972, 0.046965, 0.049958, 0.052950, 0.055942, 0.058932, 0.061921, 0.064909, 0.067895, 0.070881, 0.073865, 0.076848, 0.079830, 0.082810, 0.085789, 0.088766, 0.091741, 0.094715, 0.097687, 0.100658, 0.103627, 0.106594, 0.109558, 0.112521, 0.115482, 0.118441, 0.121398, 0.124353, 0.127305, 0.130256, 0.133204, 0.136149, 0.139092, 0.142033, 0.144971, 0.147907, 0.150840, 0.153771, 0.156698, 0.159623, 0.162545, 0.165465, 0.168381, 0.171294, 0.174205, 0.177112, 0.180017, 0.182918, 0.185816, 0.188711, 0.191602, 0.194490, 0.197375, 0.200257, 0.203135, 0.206009, 0.208880, 0.211747, 0.214611, 0.217471, 0.220327, 0.223180, 0.226028, 0.228873, 0.231714, 0.234551, 0.237384, 0.240213, 0.243038, 0.245858, 0.248675, 0.251487, 0.254296, 0.257099, 0.259899, 0.262694, 0.265485, 0.268271, 0.271053, 0.273830, 0.276603, 0.279371, 0.282135, 0.284894, 0.287648, 0.290397, 0.293142, 0.295882, 0.298617, 0.301347, 0.304072, 0.306792, 0.309507, 0.312217, 0.314922, 0.317622, 0.320317, 0.323006, 0.325691, 0.328370, 0.331044, 0.333712, 0.336376, 0.339033, 0.341686, 0.344333, 0.346974, 0.349611, 0.352241, 0.354866, 0.357485, 0.360099, 0.362707, 0.365310, 0.367907, 0.370498, 0.373083, 0.375663, 0.378236, 0.380804, 0.383366, 0.385922, 0.388473, 0.391017, 0.393555, 0.396088, 0.398614, 0.401134, 0.403649, 0.406157, 0.408659, 0.411155, 0.413644, 0.416128, 0.418605, 0.421077, 0.423542, 0.426000, 0.428453, 0.430899, 0.433339, 0.435772, 0.438199, 0.440620, 0.443034, 0.445442, 0.447844, 0.450239, 0.452628, 0.455010, 0.457385, 0.459755, 0.462117, 0.464473, 0.466823, 0.469166, 0.471502, 0.473832, 0.476155, 0.478471, 0.480781, 0.483085, 0.485381, 0.487671, 0.489954, 0.492231, 0.494500, 0.496763, 0.499020, 0.501269, 0.503512, 0.505748, 0.507977, 0.510200, 0.512416, 0.514624, 0.516827, 0.519022, 0.521210, 0.523392, 0.525567, 0.527735, 0.529896, 0.532050, 0.534197, 0.536338, 0.538471, 0.540598, 0.542718, 0.544831, 0.546937, 0.549036, 0.551128, 0.553213, 0.555292, 0.557363, 0.559428, 0.561486, 0.563536, 0.565580, 0.567617, 0.569647, 0.571670, 0.573686, 0.575695, 0.577697, 0.579693, 0.581681, 0.583663, 0.585637, 0.587605, 0.589566, 0.591519, 0.593466, 0.595406, 0.597339, 0.599265, 0.601184, 0.603097, 0.605002, 0.606901, 0.608792, 0.610677, 0.612555, 0.614425, 0.616289, 0.618147, 0.619997 }, sd::DataType::FLOAT32); - - sd::ops::tanh op; - op.execute({ &x }, { &z }); - ASSERT_EQ(e, z); + NDArray x('c', {2, 2, 3, 3, 4, 4}, sd::DataType::FLOAT32); + NDArray z('c', {2, 2, 3, 3, 4, 4}, sd::DataType::FLOAT32); + + x.linspace(-1., 0.003); + + NDArray e('c', {2, 2, 3, 3, 4, 4}, + {-0.761594, -0.760331, -0.759063, -0.757788, -0.756508, -0.755222, + -0.753930, -0.752633, -0.751329, -0.750020, -0.748704, -0.747383, + -0.746056, -0.744723, -0.743383, -0.742038, -0.740687, -0.739330, + -0.737967, -0.736598, -0.735222, -0.733841, -0.732453, -0.731060, + -0.729660, -0.728254, -0.726842, -0.725424, -0.724000, -0.722569, + -0.721132, -0.719689, -0.718240, -0.716784, -0.715323, -0.713854, + -0.712380, -0.710899, -0.709412, -0.707919, -0.706419, -0.704913, + -0.703401, -0.701882, -0.700357, -0.698825, -0.697287, -0.695742, + -0.694191, -0.692634, -0.691069, -0.689499, -0.687922, -0.686338, + -0.684748, -0.683152, -0.681548, -0.679939, -0.678322, -0.676699, + -0.675070, -0.673434, -0.671791, -0.670142, -0.668486, -0.666823, + -0.665153, -0.663477, -0.661795, -0.660105, -0.658409, -0.656706, + -0.654997, -0.653280, -0.651557, -0.649827, -0.648091, -0.646348, + -0.644597, -0.642841, -0.641077, -0.639306, -0.637529, -0.635745, + -0.633954, -0.632157, -0.630352, -0.628541, -0.626722, -0.624897, + -0.623065, -0.621227, -0.619381, -0.617528, -0.615669, -0.613803, + -0.611929, -0.610049, -0.608162, -0.606269, -0.604368, -0.602460, + -0.600546, -0.598624, -0.596696, -0.594760, -0.592818, -0.590869, + -0.588913, -0.586950, -0.584980, -0.583003, -0.581019, -0.579029, + -0.577031, -0.575026, -0.573015, -0.570996, -0.568971, -0.566939, + -0.564900, -0.562853, -0.560800, -0.558740, -0.556674, -0.554600, + -0.552519, -0.550431, -0.548337, -0.546235, -0.544127, -0.542012, + -0.539890, -0.537761, -0.535625, -0.533482, -0.531332, -0.529176, + -0.527013, -0.524842, -0.522665, -0.520482, -0.518291, -0.516093, + -0.513889, -0.511678, -0.509460, -0.507235, -0.505004, -0.502765, + -0.500520, -0.498268, -0.496010, -0.493745, -0.491472, -0.489194, + -0.486908, -0.484616, -0.482318, -0.480012, -0.477700, -0.475381, + -0.473056, -0.470724, -0.468385, -0.466040, -0.463689, -0.461330, + -0.458966, -0.456594, -0.454216, -0.451832, -0.449441, -0.447044, + -0.444640, -0.442230, -0.439814, -0.437391, -0.434962, -0.432526, + -0.430084, -0.427636, -0.425181, -0.422721, -0.420254, -0.417780, + -0.415301, -0.412815, -0.410323, -0.407825, -0.405321, -0.402811, + -0.400295, -0.397773, -0.395244, -0.392710, -0.390170, -0.387623, + -0.385071, -0.382513, -0.379949, -0.377379, -0.374803, -0.372222, + -0.369635, -0.367042, -0.364443, -0.361839, -0.359229, -0.356613, + -0.353992, -0.351365, -0.348732, -0.346095, -0.343451, -0.340802, + -0.338148, -0.335488, -0.332823, -0.330153, -0.327477, -0.324796, + -0.322110, -0.319419, -0.316723, -0.314021, -0.311314, -0.308602, + -0.305886, -0.303164, -0.300437, -0.297705, -0.294969, -0.292227, + -0.289481, -0.286730, -0.283975, -0.281214, -0.278449, -0.275679, + -0.272905, -0.270126, -0.267343, -0.264555, -0.261763, -0.258966, + -0.256165, -0.253360, -0.250550, -0.247737, -0.244919, -0.242097, + -0.239270, -0.236440, -0.233606, -0.230768, -0.227925, -0.225079, + -0.222229, -0.219376, -0.216518, -0.213657, -0.210792, -0.207923, + -0.205051, -0.202176, -0.199297, -0.196414, -0.193528, -0.190639, + -0.187746, -0.184850, -0.181951, -0.179049, -0.176144, -0.173235, + -0.170324, -0.167409, -0.164492, -0.161572, -0.158649, -0.155723, + -0.152794, -0.149863, -0.146929, -0.143992, -0.141053, -0.138112, + -0.135168, -0.132221, -0.129273, -0.126322, -0.123368, -0.120413, + -0.117455, -0.114496, -0.111534, -0.108570, -0.105605, -0.102637, + -0.099668, -0.096697, -0.093724, -0.090750, -0.087774, -0.084796, + -0.081817, -0.078836, -0.075854, -0.072871, -0.069886, -0.066900, + -0.063913, -0.060924, -0.057935, -0.054945, -0.051953, -0.048961, + -0.045968, -0.042974, -0.039979, -0.036983, -0.033987, -0.030990, + -0.027993, -0.024995, -0.021996, -0.018998, -0.015999, -0.012999, + -0.010000, -0.007000, -0.004000, -0.001000, 0.002000, 0.005000, + 0.008000, 0.011000, 0.013999, 0.016998, 0.019997, 0.022996, + 0.025994, 0.028992, 0.031989, 0.034986, 0.037982, 0.040977, + 0.043972, 0.046965, 0.049958, 0.052950, 0.055942, 0.058932, + 0.061921, 0.064909, 0.067895, 0.070881, 0.073865, 0.076848, + 0.079830, 0.082810, 0.085789, 0.088766, 0.091741, 0.094715, + 0.097687, 0.100658, 0.103627, 0.106594, 0.109558, 0.112521, + 0.115482, 0.118441, 0.121398, 0.124353, 0.127305, 0.130256, + 0.133204, 0.136149, 0.139092, 0.142033, 0.144971, 0.147907, + 0.150840, 0.153771, 0.156698, 0.159623, 0.162545, 0.165465, + 0.168381, 0.171294, 0.174205, 0.177112, 0.180017, 0.182918, + 0.185816, 0.188711, 0.191602, 0.194490, 0.197375, 0.200257, + 0.203135, 0.206009, 0.208880, 0.211747, 0.214611, 0.217471, + 0.220327, 0.223180, 0.226028, 0.228873, 0.231714, 0.234551, + 0.237384, 0.240213, 0.243038, 0.245858, 0.248675, 0.251487, + 0.254296, 0.257099, 0.259899, 0.262694, 0.265485, 0.268271, + 0.271053, 0.273830, 0.276603, 0.279371, 0.282135, 0.284894, + 0.287648, 0.290397, 0.293142, 0.295882, 0.298617, 0.301347, + 0.304072, 0.306792, 0.309507, 0.312217, 0.314922, 0.317622, + 0.320317, 0.323006, 0.325691, 0.328370, 0.331044, 0.333712, + 0.336376, 0.339033, 0.341686, 0.344333, 0.346974, 0.349611, + 0.352241, 0.354866, 0.357485, 0.360099, 0.362707, 0.365310, + 0.367907, 0.370498, 0.373083, 0.375663, 0.378236, 0.380804, + 0.383366, 0.385922, 0.388473, 0.391017, 0.393555, 0.396088, + 0.398614, 0.401134, 0.403649, 0.406157, 0.408659, 0.411155, + 0.413644, 0.416128, 0.418605, 0.421077, 0.423542, 0.426000, + 0.428453, 0.430899, 0.433339, 0.435772, 0.438199, 0.440620, + 0.443034, 0.445442, 0.447844, 0.450239, 0.452628, 0.455010, + 0.457385, 0.459755, 0.462117, 0.464473, 0.466823, 0.469166, + 0.471502, 0.473832, 0.476155, 0.478471, 0.480781, 0.483085, + 0.485381, 0.487671, 0.489954, 0.492231, 0.494500, 0.496763, + 0.499020, 0.501269, 0.503512, 0.505748, 0.507977, 0.510200, + 0.512416, 0.514624, 0.516827, 0.519022, 0.521210, 0.523392, + 0.525567, 0.527735, 0.529896, 0.532050, 0.534197, 0.536338, + 0.538471, 0.540598, 0.542718, 0.544831, 0.546937, 0.549036, + 0.551128, 0.553213, 0.555292, 0.557363, 0.559428, 0.561486, + 0.563536, 0.565580, 0.567617, 0.569647, 0.571670, 0.573686, + 0.575695, 0.577697, 0.579693, 0.581681, 0.583663, 0.585637, + 0.587605, 0.589566, 0.591519, 0.593466, 0.595406, 0.597339, + 0.599265, 0.601184, 0.603097, 0.605002, 0.606901, 0.608792, + 0.610677, 0.612555, 0.614425, 0.616289, 0.618147, 0.619997}, + sd::DataType::FLOAT32); + + sd::ops::tanh op; + op.execute({&x}, {&z}); + ASSERT_EQ(e, z); } ///////////////////////////////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests18, test_tanh_bp) { - - NDArray x('c', { 2, 3, 4 }, sd::DataType::FLOAT32); - NDArray dLdz('c', { 2, 3, 4 }, sd::DataType::FLOAT32); - NDArray dLdx('c', { 2, 3, 4 }, sd::DataType::FLOAT32); - - x.linspace(-1., 0.003); - dLdz.linspace(0.01, 0.01); - - NDArray e('c', { 2, 3, 4 }, { 0.004200, 0.008438, 0.012715, 0.017030, 0.021385, 0.025778, 0.030211, 0.034684, 0.039195, 0.043747, 0.048339, 0.052970, 0.057642, 0.062354, 0.067107, 0.071901, 0.076735, 0.081610, 0.086527, 0.091485, 0.096484, 0.101525, 0.106608, 0.111732 }, sd::DataType::FLOAT32); - - sd::ops::tanh_bp op; - op.execute({ &x, &dLdz }, { &dLdx }); - ASSERT_EQ(e, dLdx); + NDArray x('c', {2, 3, 4}, sd::DataType::FLOAT32); + NDArray dLdz('c', {2, 3, 4}, sd::DataType::FLOAT32); + NDArray dLdx('c', {2, 3, 4}, sd::DataType::FLOAT32); + + x.linspace(-1., 0.003); + dLdz.linspace(0.01, 0.01); + + NDArray e('c', {2, 3, 4}, + {0.004200, 0.008438, 0.012715, 0.017030, 0.021385, 0.025778, + 0.030211, 0.034684, 0.039195, 0.043747, 0.048339, 0.052970, + 0.057642, 0.062354, 0.067107, 0.071901, 0.076735, 0.081610, + 0.086527, 0.091485, 0.096484, 0.101525, 0.106608, 0.111732}, + sd::DataType::FLOAT32); + + sd::ops::tanh_bp op; + op.execute({&x, &dLdz}, {&dLdx}); + ASSERT_EQ(e, dLdx); } ///////////////////////////////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests18, test_tanh_bp2) { - - NDArray x('f', { 2, 3, 4 }, sd::DataType::FLOAT32); - NDArray dLdz('f', { 2, 3, 4 }, sd::DataType::FLOAT32); - NDArray dLdx('f', { 2, 3, 4 }, sd::DataType::FLOAT32); - - x.linspace(-1., 0.003); - dLdz.linspace(0.01, 0.01); - - NDArray exp('c', { 2, 3, 4 }, { 0.004200, 0.008438, 0.012715, 0.017030, 0.021385, 0.025778, 0.030211, 0.034684, 0.039195, 0.043747, 0.048339, 0.052970, 0.057642, 0.062354, 0.067107, 0.071901, 0.076735, 0.081610, 0.086527, 0.091485, 0.096484, 0.101525, 0.106608, 0.111732 }, sd::DataType::FLOAT32); - NDArray e('f', { 2, 3, 4 }, sd::DataType::FLOAT32); - e.assign(exp); - - sd::ops::tanh_bp op; - op.execute({ &x, &dLdz }, { &dLdx }); - ASSERT_EQ(e, dLdx); + NDArray x('f', {2, 3, 4}, sd::DataType::FLOAT32); + NDArray dLdz('f', {2, 3, 4}, sd::DataType::FLOAT32); + NDArray dLdx('f', {2, 3, 4}, sd::DataType::FLOAT32); + + x.linspace(-1., 0.003); + dLdz.linspace(0.01, 0.01); + + NDArray exp('c', {2, 3, 4}, + {0.004200, 0.008438, 0.012715, 0.017030, 0.021385, 0.025778, + 0.030211, 0.034684, 0.039195, 0.043747, 0.048339, 0.052970, + 0.057642, 0.062354, 0.067107, 0.071901, 0.076735, 0.081610, + 0.086527, 0.091485, 0.096484, 0.101525, 0.106608, 0.111732}, + sd::DataType::FLOAT32); + NDArray e('f', {2, 3, 4}, sd::DataType::FLOAT32); + e.assign(exp); + + sd::ops::tanh_bp op; + op.execute({&x, &dLdz}, {&dLdx}); + ASSERT_EQ(e, dLdx); } ///////////////////////////////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests18, test_tanh_bp3) { - - NDArray x('f', { 2, 2, 3, 3, 4, 4 }, sd::DataType::FLOAT32); - NDArray dLdz('f', { 2,2, 3,3, 4,4 }, sd::DataType::FLOAT32); - NDArray dLdx('f', { 2, 2, 3, 3, 4, 4 }, sd::DataType::FLOAT32); - - x.linspace(-1.5, 0.005); - dLdz.linspace(-1., 0.01); - - NDArray exp('c', { 2, 2, 3, 3, 4, 4 }, { -0.180707, -0.180525, -0.180324, -0.180103, -0.179861, -0.179599, -0.179315, -0.179009, -0.178682, -0.178333, -0.177961, -0.177566, -0.177148, -0.176706, -0.176240, -0.175750, -0.175236, -0.174696, -0.174130, -0.173539, -0.172922, -0.172278, -0.171607, -0.170909, -0.170183, -0.169429, -0.168646, -0.167834, -0.166993, -0.166123, -0.165222, -0.164290, -0.163327, -0.162334, -0.161308, -0.160250, -0.159159, -0.158035, -0.156877, -0.155686, -0.154460, -0.153199, -0.151903, -0.150571, -0.149203, -0.147798, -0.146356, -0.144876, -0.143359, -0.141803, -0.140207, -0.138573, -0.136898, -0.135183, -0.133428, -0.131630, -0.129792, -0.127910, -0.125986, -0.124019, -0.122008, -0.119953, -0.117853, -0.115708, -0.113517, -0.111279, -0.108996, -0.106665, -0.104286, -0.101859, -0.099383, -0.096859, -0.094284, -0.091660, -0.088984, -0.086258, -0.083480, -0.080649, -0.077766, -0.074830, -0.071840, -0.068796, -0.065697, -0.062543, -0.059334, -0.056068, -0.052745, -0.049365, -0.045928, -0.042432, -0.038878, -0.035264, -0.031591, -0.027858, -0.024064, -0.020209, -0.016292, -0.012313, -0.008272, -0.004168, 0.000000, 0.004232, 0.008528, 0.012889, 0.017316, 0.021808, 0.026367, 0.030992, 0.035684, 0.040444, 0.045272, 0.050169, 0.055134, 0.060168, 0.065273, 0.070447, 0.075692, 0.081007, 0.086394, 0.091853, 0.097383, 0.102986, 0.108662, 0.114411, 0.120233, 0.126129, 0.132099, 0.138144, 0.144263, 0.150457, 0.156727, 0.163072, 0.169493, 0.175990, 0.182564, 0.189214, 0.195941, 0.202745, 0.209627, 0.216585, 0.223622, 0.230736, 0.237929, 0.245200, 0.252549, 0.259976, 0.267482, 0.275066, 0.282730, 0.290472, 0.298293, 0.306193, 0.314172, 0.322230, 0.330366, 0.338582, 0.346877, 0.355250, 0.363703, 0.372234, 0.380844, 0.389532, 0.398299, 0.407144, 0.416067, 0.425068, 0.434147, 0.443303, 0.452537, 0.461848, 0.471235, 0.480699, 0.490240, 0.499856, 0.509548, 0.519314, 0.529156, 0.539072, 0.549062, 0.559126, 0.569262, 0.579471, 0.589753, 0.600106, 0.610530, 0.621024, 0.631588, 0.642222, 0.652924, 0.663694, 0.674532, 0.685436, 0.696406, 0.707441, 0.718541, 0.729704, 0.740931, 0.752219, 0.763568, 0.774978, 0.786448, 0.797976, 0.809561, 0.821203, 0.832901, 0.844654, 0.856460, 0.868319, 0.880230, 0.892191, 0.904201, 0.916260, 0.928366, 0.940518, 0.952715, 0.964955, 0.977238, 0.989561, 1.001925, 1.014327, 1.026767, 1.039242, 1.051752, 1.064295, 1.076870, 1.089475, 1.102109, 1.114771, 1.127459, 1.140171, 1.152907, 1.165664, 1.178441, 1.191237, 1.204050, 1.216878, 1.229720, 1.242573, 1.255438, 1.268311, 1.281192, 1.294078, 1.306968, 1.319860, 1.332753, 1.345644, 1.358533, 1.371417, 1.384294, 1.397163, 1.410022, 1.422870, 1.435704, 1.448522, 1.461323, 1.474105, 1.486867, 1.499606, 1.512321, 1.525009, 1.537669, 1.550299, 1.562897, 1.575462, 1.587991, 1.600483, 1.612935, 1.625347, 1.637715, 1.650040, 1.662317, 1.674545, 1.686724, 1.698850, 1.710922, 1.722939, 1.734897, 1.746797, 1.758635, 1.770409, 1.782119, 1.793762, 1.805337, 1.816842, 1.828274, 1.839633, 1.850916, 1.862121, 1.873248, 1.884294, 1.895258, 1.906137, 1.916931, 1.927637, 1.938255, 1.948782, 1.959216, 1.969557, 1.979802, 1.989950, 2.000000, 2.009950, 2.019798, 2.029543, 2.039184, 2.048719, 2.058147, 2.067466, 2.076675, 2.085773, 2.094759, 2.103630, 2.112386, 2.121026, 2.129548, 2.137952, 2.146235, 2.154397, 2.162437, 2.170354, 2.178146, 2.185813, 2.193353, 2.200766, 2.208051, 2.215207, 2.222232, 2.229127, 2.235889, 2.242520, 2.249017, 2.255379, 2.261607, 2.267699, 2.273656, 2.279475, 2.285158, 2.290702, 2.296108, 2.301376, 2.306503, 2.311491, 2.316339, 2.321046, 2.325613, 2.330038, 2.334321, 2.338464, 2.342464, 2.346322, 2.350037, 2.353610, 2.357041, 2.360329, 2.363475, 2.366478, 2.369338, 2.372056, 2.374632, 2.377065, 2.379356, 2.381505, 2.383512, 2.385378, 2.387103, 2.388686, 2.390128, 2.391431, 2.392593, 2.393615, 2.394499, 2.395244, 2.395850, 2.396319, 2.396650, 2.396845, 2.396904, 2.396826, 2.396615, 2.396268, 2.395789, 2.395176, 2.394431, 2.393554, 2.392547, 2.391410, 2.390144, 2.388749, 2.387227, 2.385578, 2.383804, 2.381904, 2.379880, 2.377734, 2.375465, 2.373075, 2.370565, 2.367936, 2.365188, 2.362324, 2.359343, 2.356247, 2.353038, 2.349715, 2.346280, 2.342735, 2.339080, 2.335316, 2.331445, 2.327468, 2.323386, 2.319200, 2.314912, 2.310522, 2.306031, 2.301442, 2.296754, 2.291970, 2.287090, 2.282116, 2.277049, 2.271890, 2.266641, 2.261302, 2.255876, 2.250362, 2.244763, 2.239080, 2.233314, 2.227467, 2.221538, 2.215531, 2.209445, 2.203284, 2.197047, 2.190736, 2.184352, 2.177897, 2.171371, 2.164777, 2.158115, 2.151386, 2.144592, 2.137735, 2.130815, 2.123833, 2.116792, 2.109692, 2.102533, 2.095320, 2.088051, 2.080727, 2.073352, 2.065925, 2.058447, 2.050921, 2.043347, 2.035727, 2.028061, 2.020351, 2.012599, 2.004804, 1.996969, 1.989094, 1.981181, 1.973232, 1.965246, 1.957225, 1.949171, 1.941084, 1.932965, 1.924816, 1.916638, 1.908432, 1.900198, 1.891938, 1.883654, 1.875345, 1.867014, 1.858661, 1.850286, 1.841892, 1.833479, 1.825048, 1.816600, 1.808136, 1.799657, 1.791165, 1.782659, 1.774141, 1.765612, 1.757073, 1.748523, 1.739967, 1.731401, 1.722829, 1.714251, 1.705668, 1.697082, 1.688491, 1.679897, 1.671302, 1.662707, 1.654110, 1.645514, 1.636920, 1.628328, 1.619738, 1.611152, 1.602570, 1.593993, 1.585422, 1.576857, 1.568299, 1.559749, 1.551207, 1.542674, 1.534151, 1.525638, 1.517136, 1.508645, 1.500167, 1.491701, 1.483248, 1.474810, 1.466385, 1.457976, 1.449581, 1.441203, 1.432841, 1.424496, 1.416169, 1.407860, 1.399569, 1.391297, 1.383045, 1.374812, 1.366600, 1.358408, 1.350237, 1.342088, 1.333961, 1.325856, 1.317774, 1.309715, 1.301679, 1.293668, 1.285680, 1.277718, 1.269780, 1.261867, 1.253980, 1.246119, 1.238283, 1.230474, 1.222692, 1.214937, 1.207210, 1.199510, 1.191837, 1.184193, 1.176577, 1.168990, 1.161430, 1.153901, 1.146401, 1.138930, 1.131489, 1.124077, 1.116696, 1.109345, 1.102024, 1.094734, 1.087475, 1.080246, 1.073049 }, sd::DataType::FLOAT32); - - NDArray e('f', { 2, 2, 3, 3, 4, 4 }, sd::DataType::FLOAT32); - e.assign(exp); - - sd::ops::tanh_bp op; - op.execute({ &x, &dLdz }, { &dLdx }); - ASSERT_EQ(e, dLdx); + NDArray x('f', {2, 2, 3, 3, 4, 4}, sd::DataType::FLOAT32); + NDArray dLdz('f', {2, 2, 3, 3, 4, 4}, sd::DataType::FLOAT32); + NDArray dLdx('f', {2, 2, 3, 3, 4, 4}, sd::DataType::FLOAT32); + + x.linspace(-1.5, 0.005); + dLdz.linspace(-1., 0.01); + + NDArray exp('c', {2, 2, 3, 3, 4, 4}, + {-0.180707, -0.180525, -0.180324, -0.180103, -0.179861, -0.179599, + -0.179315, -0.179009, -0.178682, -0.178333, -0.177961, -0.177566, + -0.177148, -0.176706, -0.176240, -0.175750, -0.175236, -0.174696, + -0.174130, -0.173539, -0.172922, -0.172278, -0.171607, -0.170909, + -0.170183, -0.169429, -0.168646, -0.167834, -0.166993, -0.166123, + -0.165222, -0.164290, -0.163327, -0.162334, -0.161308, -0.160250, + -0.159159, -0.158035, -0.156877, -0.155686, -0.154460, -0.153199, + -0.151903, -0.150571, -0.149203, -0.147798, -0.146356, -0.144876, + -0.143359, -0.141803, -0.140207, -0.138573, -0.136898, -0.135183, + -0.133428, -0.131630, -0.129792, -0.127910, -0.125986, -0.124019, + -0.122008, -0.119953, -0.117853, -0.115708, -0.113517, -0.111279, + -0.108996, -0.106665, -0.104286, -0.101859, -0.099383, -0.096859, + -0.094284, -0.091660, -0.088984, -0.086258, -0.083480, -0.080649, + -0.077766, -0.074830, -0.071840, -0.068796, -0.065697, -0.062543, + -0.059334, -0.056068, -0.052745, -0.049365, -0.045928, -0.042432, + -0.038878, -0.035264, -0.031591, -0.027858, -0.024064, -0.020209, + -0.016292, -0.012313, -0.008272, -0.004168, 0.000000, 0.004232, + 0.008528, 0.012889, 0.017316, 0.021808, 0.026367, 0.030992, + 0.035684, 0.040444, 0.045272, 0.050169, 0.055134, 0.060168, + 0.065273, 0.070447, 0.075692, 0.081007, 0.086394, 0.091853, + 0.097383, 0.102986, 0.108662, 0.114411, 0.120233, 0.126129, + 0.132099, 0.138144, 0.144263, 0.150457, 0.156727, 0.163072, + 0.169493, 0.175990, 0.182564, 0.189214, 0.195941, 0.202745, + 0.209627, 0.216585, 0.223622, 0.230736, 0.237929, 0.245200, + 0.252549, 0.259976, 0.267482, 0.275066, 0.282730, 0.290472, + 0.298293, 0.306193, 0.314172, 0.322230, 0.330366, 0.338582, + 0.346877, 0.355250, 0.363703, 0.372234, 0.380844, 0.389532, + 0.398299, 0.407144, 0.416067, 0.425068, 0.434147, 0.443303, + 0.452537, 0.461848, 0.471235, 0.480699, 0.490240, 0.499856, + 0.509548, 0.519314, 0.529156, 0.539072, 0.549062, 0.559126, + 0.569262, 0.579471, 0.589753, 0.600106, 0.610530, 0.621024, + 0.631588, 0.642222, 0.652924, 0.663694, 0.674532, 0.685436, + 0.696406, 0.707441, 0.718541, 0.729704, 0.740931, 0.752219, + 0.763568, 0.774978, 0.786448, 0.797976, 0.809561, 0.821203, + 0.832901, 0.844654, 0.856460, 0.868319, 0.880230, 0.892191, + 0.904201, 0.916260, 0.928366, 0.940518, 0.952715, 0.964955, + 0.977238, 0.989561, 1.001925, 1.014327, 1.026767, 1.039242, + 1.051752, 1.064295, 1.076870, 1.089475, 1.102109, 1.114771, + 1.127459, 1.140171, 1.152907, 1.165664, 1.178441, 1.191237, + 1.204050, 1.216878, 1.229720, 1.242573, 1.255438, 1.268311, + 1.281192, 1.294078, 1.306968, 1.319860, 1.332753, 1.345644, + 1.358533, 1.371417, 1.384294, 1.397163, 1.410022, 1.422870, + 1.435704, 1.448522, 1.461323, 1.474105, 1.486867, 1.499606, + 1.512321, 1.525009, 1.537669, 1.550299, 1.562897, 1.575462, + 1.587991, 1.600483, 1.612935, 1.625347, 1.637715, 1.650040, + 1.662317, 1.674545, 1.686724, 1.698850, 1.710922, 1.722939, + 1.734897, 1.746797, 1.758635, 1.770409, 1.782119, 1.793762, + 1.805337, 1.816842, 1.828274, 1.839633, 1.850916, 1.862121, + 1.873248, 1.884294, 1.895258, 1.906137, 1.916931, 1.927637, + 1.938255, 1.948782, 1.959216, 1.969557, 1.979802, 1.989950, + 2.000000, 2.009950, 2.019798, 2.029543, 2.039184, 2.048719, + 2.058147, 2.067466, 2.076675, 2.085773, 2.094759, 2.103630, + 2.112386, 2.121026, 2.129548, 2.137952, 2.146235, 2.154397, + 2.162437, 2.170354, 2.178146, 2.185813, 2.193353, 2.200766, + 2.208051, 2.215207, 2.222232, 2.229127, 2.235889, 2.242520, + 2.249017, 2.255379, 2.261607, 2.267699, 2.273656, 2.279475, + 2.285158, 2.290702, 2.296108, 2.301376, 2.306503, 2.311491, + 2.316339, 2.321046, 2.325613, 2.330038, 2.334321, 2.338464, + 2.342464, 2.346322, 2.350037, 2.353610, 2.357041, 2.360329, + 2.363475, 2.366478, 2.369338, 2.372056, 2.374632, 2.377065, + 2.379356, 2.381505, 2.383512, 2.385378, 2.387103, 2.388686, + 2.390128, 2.391431, 2.392593, 2.393615, 2.394499, 2.395244, + 2.395850, 2.396319, 2.396650, 2.396845, 2.396904, 2.396826, + 2.396615, 2.396268, 2.395789, 2.395176, 2.394431, 2.393554, + 2.392547, 2.391410, 2.390144, 2.388749, 2.387227, 2.385578, + 2.383804, 2.381904, 2.379880, 2.377734, 2.375465, 2.373075, + 2.370565, 2.367936, 2.365188, 2.362324, 2.359343, 2.356247, + 2.353038, 2.349715, 2.346280, 2.342735, 2.339080, 2.335316, + 2.331445, 2.327468, 2.323386, 2.319200, 2.314912, 2.310522, + 2.306031, 2.301442, 2.296754, 2.291970, 2.287090, 2.282116, + 2.277049, 2.271890, 2.266641, 2.261302, 2.255876, 2.250362, + 2.244763, 2.239080, 2.233314, 2.227467, 2.221538, 2.215531, + 2.209445, 2.203284, 2.197047, 2.190736, 2.184352, 2.177897, + 2.171371, 2.164777, 2.158115, 2.151386, 2.144592, 2.137735, + 2.130815, 2.123833, 2.116792, 2.109692, 2.102533, 2.095320, + 2.088051, 2.080727, 2.073352, 2.065925, 2.058447, 2.050921, + 2.043347, 2.035727, 2.028061, 2.020351, 2.012599, 2.004804, + 1.996969, 1.989094, 1.981181, 1.973232, 1.965246, 1.957225, + 1.949171, 1.941084, 1.932965, 1.924816, 1.916638, 1.908432, + 1.900198, 1.891938, 1.883654, 1.875345, 1.867014, 1.858661, + 1.850286, 1.841892, 1.833479, 1.825048, 1.816600, 1.808136, + 1.799657, 1.791165, 1.782659, 1.774141, 1.765612, 1.757073, + 1.748523, 1.739967, 1.731401, 1.722829, 1.714251, 1.705668, + 1.697082, 1.688491, 1.679897, 1.671302, 1.662707, 1.654110, + 1.645514, 1.636920, 1.628328, 1.619738, 1.611152, 1.602570, + 1.593993, 1.585422, 1.576857, 1.568299, 1.559749, 1.551207, + 1.542674, 1.534151, 1.525638, 1.517136, 1.508645, 1.500167, + 1.491701, 1.483248, 1.474810, 1.466385, 1.457976, 1.449581, + 1.441203, 1.432841, 1.424496, 1.416169, 1.407860, 1.399569, + 1.391297, 1.383045, 1.374812, 1.366600, 1.358408, 1.350237, + 1.342088, 1.333961, 1.325856, 1.317774, 1.309715, 1.301679, + 1.293668, 1.285680, 1.277718, 1.269780, 1.261867, 1.253980, + 1.246119, 1.238283, 1.230474, 1.222692, 1.214937, 1.207210, + 1.199510, 1.191837, 1.184193, 1.176577, 1.168990, 1.161430, + 1.153901, 1.146401, 1.138930, 1.131489, 1.124077, 1.116696, + 1.109345, 1.102024, 1.094734, 1.087475, 1.080246, 1.073049}, + sd::DataType::FLOAT32); + + NDArray e('f', {2, 2, 3, 3, 4, 4}, sd::DataType::FLOAT32); + e.assign(exp); + + sd::ops::tanh_bp op; + op.execute({&x, &dLdz}, {&dLdx}); + ASSERT_EQ(e, dLdx); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests18, TestSoftMax_bp_TEST) { + NDArray input('c', {2, 2}, {1, 2, 3, 4}, DataType::FLOAT32); + NDArray epsilon('c', {2, 2}, {.1, .2, .3, .4}, DataType::FLOAT32); - NDArray input('c', { 2, 2 }, { 1,2,3,4 }, DataType::FLOAT32); - NDArray epsilon('c', { 2, 2 }, { .1, .2, .3, .4 }, DataType::FLOAT32); - - int axis = 1; + int axis = 1; - NDArray output('c', { 2, 2 }, DataType::FLOAT32); + NDArray output('c', {2, 2}, DataType::FLOAT32); - NDArray exp('c', { 2, 2 }, { -0.019661, 0.019661, -0.019661, 0.019661 }, DataType::FLOAT32); + NDArray exp('c', {2, 2}, {-0.019661, 0.019661, -0.019661, 0.019661}, + DataType::FLOAT32); - sd::ops::softmax_bp op; - - Nd4jStatus status = op.execute({ &input, &epsilon }, { &output }, {}, { axis }); - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(output.equalsTo(exp)); + sd::ops::softmax_bp op; + Nd4jStatus status = op.execute({&input, &epsilon}, {&output}, {}, {axis}); + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(output.equalsTo(exp)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests18, TestSoftMax_bp_TEST2) { - - NDArray input('c', { 4, 5, 2, 3 }, DataType::FLOAT32); - NDArray epsilon('c', { 4, 5, 2, 3 }, { -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855 }, DataType::FLOAT32); - input.linspace(0.1, 0.2); - - int axis = -1; - - NDArray output('c', { 4, 5, 2, 3 }, DataType::FLOAT32); - NDArray exp('c', { 4, 5, 2, 3 }, { -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253 }, DataType::FLOAT32); - - sd::ops::softmax_bp op; - - Nd4jStatus status = op.execute({ &input, &epsilon }, { &output }, {}, { axis }); - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(output.equalsTo(exp)); - + NDArray input('c', {4, 5, 2, 3}, DataType::FLOAT32); + NDArray epsilon( + 'c', {4, 5, 2, 3}, + {-0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, + -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, + -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, + -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, + -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, + -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, + -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, + -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, + -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, + -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, + -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, + -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, + -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, + -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, + -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, + -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, + -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, + -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, + -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, + -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855}, + DataType::FLOAT32); + input.linspace(0.1, 0.2); + + int axis = -1; + + NDArray output('c', {4, 5, 2, 3}, DataType::FLOAT32); + NDArray exp('c', {4, 5, 2, 3}, + {-0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, + -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, + -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, + -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, + -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, + -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, + -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, + -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, + -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, + -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, + -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, + -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, + -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, + -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, + -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, + -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, + -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, + -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, + -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, + -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253}, + DataType::FLOAT32); + + sd::ops::softmax_bp op; + + Nd4jStatus status = op.execute({&input, &epsilon}, {&output}, {}, {axis}); + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(output.equalsTo(exp)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests18, TestSoftMax_bp_TEST3) { - - NDArray input('f', { 4, 5, 2, 3 }, DataType::FLOAT32); - NDArray epsilon('f', { 4, 5, 2, 3 }, { -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855 }, DataType::FLOAT32); - input.linspace(-5., 0.5); - - int axis = 1; - - NDArray output('f', { 4, 5, 2, 3 }, DataType::FLOAT32); - NDArray expC('c', { 4, 5, 2, 3 }, { -0.0, -0.0, 0.0, 0.0, -0.0, -0.0, 0.0, 0.0, -0.0, -0.0, 0.0, 0.0, 0.000095, -0.000149, 0.000054, 0.000054, 0.000095, -0.000149, -0.001183, -0.001760, 0.002943, 0.002943, -0.001183, -0.001760, 0.001088, 0.001909, -0.002997, -0.002997, 0.001088, 0.001909, -0.000000, 0.000000, -0.000000, -0.000000, -0.000000, 0.000000, 0.000000, -0.000000, 0.000000, 0.000000, 0.000000, -0.000000, -0.000149, 0.000054, 0.000095, 0.000095, -0.000149, 0.000054, -0.001760, 0.002943, -0.001183, -0.001183, -0.001760, 0.002943, 0.001909, -0.002997, 0.001088, 0.001088, 0.001909, -0.002997, 0.000000, -0.000000, -0.000000, -0.000000, 0.000000, -0.000000, -0.000000, 0.000000, 0.000000, 0.000000, -0.000000, 0.000000, 0.000054, 0.000095, -0.000149, -0.000149, 0.000054, 0.000095, 0.002943, -0.001183, -0.001760, -0.001760, 0.002943, -0.001183, -0.002997, 0.001088, 0.001909, 0.001909, -0.002997, 0.001088, -0.000000, -0.000000, 0.000000, 0.000000, -0.000000, -0.000000, 0.000000, 0.000000, -0.000000, -0.000000, 0.000000, 0.000000, 0.000095, -0.000149, 0.000054, 0.000054, 0.000095, -0.000149, -0.001183, -0.001760, 0.002943, 0.002943, -0.001183, -0.001760, 0.001088, 0.001909, -0.002997, -0.002997, 0.001088, 0.001909 }, DataType::FLOAT32); - - NDArray exp('f', { 4, 5, 2, 3 }, DataType::FLOAT32); - exp.assign(expC); - - sd::ops::softmax_bp op; - - Nd4jStatus status = op.execute({ &input, &epsilon }, { &output }, {}, { axis }); - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(output.equalsTo(exp)); + NDArray input('f', {4, 5, 2, 3}, DataType::FLOAT32); + NDArray epsilon( + 'f', {4, 5, 2, 3}, + {-0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, + -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, + -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, + -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, + -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, + -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, + -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, + -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, + -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, + -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, + -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, + -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, + -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, + -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, + -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, + -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, + -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, + -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, + -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, + -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855}, + DataType::FLOAT32); + input.linspace(-5., 0.5); + + int axis = 1; + + NDArray output('f', {4, 5, 2, 3}, DataType::FLOAT32); + NDArray expC( + 'c', {4, 5, 2, 3}, + {-0.0, -0.0, 0.0, 0.0, -0.0, -0.0, + 0.0, 0.0, -0.0, -0.0, 0.0, 0.0, + 0.000095, -0.000149, 0.000054, 0.000054, 0.000095, -0.000149, + -0.001183, -0.001760, 0.002943, 0.002943, -0.001183, -0.001760, + 0.001088, 0.001909, -0.002997, -0.002997, 0.001088, 0.001909, + -0.000000, 0.000000, -0.000000, -0.000000, -0.000000, 0.000000, + 0.000000, -0.000000, 0.000000, 0.000000, 0.000000, -0.000000, + -0.000149, 0.000054, 0.000095, 0.000095, -0.000149, 0.000054, + -0.001760, 0.002943, -0.001183, -0.001183, -0.001760, 0.002943, + 0.001909, -0.002997, 0.001088, 0.001088, 0.001909, -0.002997, + 0.000000, -0.000000, -0.000000, -0.000000, 0.000000, -0.000000, + -0.000000, 0.000000, 0.000000, 0.000000, -0.000000, 0.000000, + 0.000054, 0.000095, -0.000149, -0.000149, 0.000054, 0.000095, + 0.002943, -0.001183, -0.001760, -0.001760, 0.002943, -0.001183, + -0.002997, 0.001088, 0.001909, 0.001909, -0.002997, 0.001088, + -0.000000, -0.000000, 0.000000, 0.000000, -0.000000, -0.000000, + 0.000000, 0.000000, -0.000000, -0.000000, 0.000000, 0.000000, + 0.000095, -0.000149, 0.000054, 0.000054, 0.000095, -0.000149, + -0.001183, -0.001760, 0.002943, 0.002943, -0.001183, -0.001760, + 0.001088, 0.001909, -0.002997, -0.002997, 0.001088, 0.001909}, + DataType::FLOAT32); + + NDArray exp('f', {4, 5, 2, 3}, DataType::FLOAT32); + exp.assign(expC); + + sd::ops::softmax_bp op; + + Nd4jStatus status = op.execute({&input, &epsilon}, {&output}, {}, {axis}); + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(output.equalsTo(exp)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests18, XWPlusB_Bp_1) { - - auto x = NDArrayFactory::create('c', { 2,3 }, { 1.f, 11.f, 3.f, 14.f, 5.f, 6.f }); - auto w = NDArrayFactory::create('c', { 3,2 }, { 11.f, 3.f, 4.f, 5.f, 6.f, 2.f }); - auto b = NDArrayFactory::create({ 100.f, 200.f }); - - NDArray dLdz('c', { 2, 2 }, DataType::FLOAT32); - dLdz.linspace(1); - - sd::ops::xw_plus_b_bp op; - auto result = op.evaluate({ &x, &w, &b, &dLdz }, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto dLdx = result.at(0); - auto dLdw = result.at(1); - auto dLdb = result.at(2); - - auto edLdx = NDArrayFactory::create('c', { 2,3 }, { 17.f, 14.f, 10.f, 45.f, 32.f, 26.f }); - auto edLdw = NDArrayFactory::create('c', { 3,2 }, { 43.f, 58.f, 26.f, 42.f, 21.f, 30.f }); - auto edLdb = NDArrayFactory::create('c', { 2 }, { 4.f, 6.f }); - - ASSERT_TRUE(edLdx.isSameShape(dLdx)); - ASSERT_TRUE(edLdw.isSameShape(dLdw)); - ASSERT_TRUE(edLdb.isSameShape(dLdb)); - ASSERT_TRUE(edLdx.equalsTo(dLdx)); - ASSERT_TRUE(edLdw.equalsTo(dLdw)); - ASSERT_TRUE(edLdb.equalsTo(dLdb)); + auto x = NDArrayFactory::create('c', {2, 3}, + {1.f, 11.f, 3.f, 14.f, 5.f, 6.f}); + auto w = NDArrayFactory::create('c', {3, 2}, + {11.f, 3.f, 4.f, 5.f, 6.f, 2.f}); + auto b = NDArrayFactory::create({100.f, 200.f}); + + NDArray dLdz('c', {2, 2}, DataType::FLOAT32); + dLdz.linspace(1); + + sd::ops::xw_plus_b_bp op; + auto result = op.evaluate({&x, &w, &b, &dLdz}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto dLdx = result.at(0); + auto dLdw = result.at(1); + auto dLdb = result.at(2); + + auto edLdx = NDArrayFactory::create( + 'c', {2, 3}, {17.f, 14.f, 10.f, 45.f, 32.f, 26.f}); + auto edLdw = NDArrayFactory::create( + 'c', {3, 2}, {43.f, 58.f, 26.f, 42.f, 21.f, 30.f}); + auto edLdb = NDArrayFactory::create('c', {2}, {4.f, 6.f}); + + ASSERT_TRUE(edLdx.isSameShape(dLdx)); + ASSERT_TRUE(edLdw.isSameShape(dLdw)); + ASSERT_TRUE(edLdb.isSameShape(dLdb)); + ASSERT_TRUE(edLdx.equalsTo(dLdx)); + ASSERT_TRUE(edLdw.equalsTo(dLdw)); + ASSERT_TRUE(edLdb.equalsTo(dLdb)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests18, XWPlusB_Bp_2) { - - auto x = NDArrayFactory::create('c', { 6,3 }, { 1.f, 11.f, 3.f, 14.f, 5.f, 6.f, 1.f, 11.f, 3.f, 14.f, 5.f, 6.f, 1.f, 11.f, 3.f, 14.f, 5.f, 6.f }); - auto w = NDArrayFactory::create('c', { 3,4 }, { 11.f, 3.f, 4.f, 5.f, 6.f, 2.f, 11.f, 3.f, 4.f, 5.f, 6.f, 2.f }); - auto b = NDArrayFactory::create('c', { 4 }, { 100.f, 200.f, 100.f, 200.f }); - - NDArray dLdz('c', { 6, 4 }, DataType::FLOAT32); - dLdz.linspace(.1, .5); - - sd::ops::xw_plus_b_bp op; - auto result = op.evaluate({ &x, &w, &b, &dLdz }, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto dLdx = result.at(0); - auto dLdw = result.at(1); - auto dLdb = result.at(2); - - auto edLdx = NDArrayFactory::create('c', { 6,3 }, { 15.3f, 18.700001f, 13.2f, 61.299995f, 62.699997f, 47.200001f, 107.299995f, 106.699997f, 81.199997f, 153.299988f, 150.699997f, 115.199997f, 199.300018f, 194.700012f, 149.199997f, 245.300018f, 238.700012f, 183.199997f }); - auto edLdw = NDArrayFactory::create('c', { 3,4 }, { 268.5f, 291.f, 313.5f, 336.f, 226.800003f, 250.800003f, 274.799988f, 298.799988f, 146.699997f, 160.199997f, 173.700012f, 187.200012f }); - auto edLdb = NDArrayFactory::create('c', { 4 }, { 30.6f, 33.599998f, 36.599998f, 39.599998f }); - ASSERT_TRUE(edLdx.isSameShape(dLdx)); - ASSERT_TRUE(edLdw.isSameShape(dLdw)); - ASSERT_TRUE(edLdb.isSameShape(dLdb)); - ASSERT_TRUE(edLdx.equalsTo(dLdx)); - ASSERT_TRUE(edLdw.equalsTo(dLdw)); - ASSERT_TRUE(edLdb.equalsTo(dLdb)); + auto x = NDArrayFactory::create( + 'c', {6, 3}, + {1.f, 11.f, 3.f, 14.f, 5.f, 6.f, 1.f, 11.f, 3.f, 14.f, 5.f, 6.f, 1.f, + 11.f, 3.f, 14.f, 5.f, 6.f}); + auto w = NDArrayFactory::create( + 'c', {3, 4}, + {11.f, 3.f, 4.f, 5.f, 6.f, 2.f, 11.f, 3.f, 4.f, 5.f, 6.f, 2.f}); + auto b = + NDArrayFactory::create('c', {4}, {100.f, 200.f, 100.f, 200.f}); + + NDArray dLdz('c', {6, 4}, DataType::FLOAT32); + dLdz.linspace(.1, .5); + + sd::ops::xw_plus_b_bp op; + auto result = op.evaluate({&x, &w, &b, &dLdz}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto dLdx = result.at(0); + auto dLdw = result.at(1); + auto dLdb = result.at(2); + + auto edLdx = NDArrayFactory::create( + 'c', {6, 3}, + {15.3f, 18.700001f, 13.2f, 61.299995f, 62.699997f, 47.200001f, + 107.299995f, 106.699997f, 81.199997f, 153.299988f, 150.699997f, + 115.199997f, 199.300018f, 194.700012f, 149.199997f, 245.300018f, + 238.700012f, 183.199997f}); + auto edLdw = NDArrayFactory::create( + 'c', {3, 4}, + {268.5f, 291.f, 313.5f, 336.f, 226.800003f, 250.800003f, 274.799988f, + 298.799988f, 146.699997f, 160.199997f, 173.700012f, 187.200012f}); + auto edLdb = NDArrayFactory::create( + 'c', {4}, {30.6f, 33.599998f, 36.599998f, 39.599998f}); + ASSERT_TRUE(edLdx.isSameShape(dLdx)); + ASSERT_TRUE(edLdw.isSameShape(dLdw)); + ASSERT_TRUE(edLdb.isSameShape(dLdb)); + ASSERT_TRUE(edLdx.equalsTo(dLdx)); + ASSERT_TRUE(edLdw.equalsTo(dLdw)); + ASSERT_TRUE(edLdb.equalsTo(dLdb)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests18, XWPlusB_Bp_3) { - - auto x = NDArrayFactory::create('c', { 1, 2 }, { 1.f, 11.f }); - auto w = NDArrayFactory::create('c', { 2, 3 }, { 11.f, 3.f, 4.f, 5.f, 6.f, 2.f }); - auto b = NDArrayFactory::create({ 100.f, 200.f, 300.f }); - - auto dLdz = NDArrayFactory::create('c', { 1, 3 }, { 166.f, 269.f, 326.f }); - - sd::ops::xw_plus_b_bp op; - auto result = op.evaluate({ &x, &w, &b, &dLdz }, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto dLdx = result.at(0); - auto dLdw = result.at(1); - auto dLdb = result.at(2); - - auto edLdx = NDArrayFactory::create('c', { 1,2 }, { 3937.f, 3096.f }); - auto edLdw = NDArrayFactory::create('c', { 2,3 }, { 166.f, 269.f, 326.f, 1826.f, 2959.f, 3586.f }); - auto edLdb = NDArrayFactory::create('c', { 3 }, { 166.f, 269.f, 326.f }); - ASSERT_TRUE(edLdx.isSameShape(dLdx)); - ASSERT_TRUE(edLdw.isSameShape(dLdw)); - ASSERT_TRUE(edLdb.isSameShape(dLdb)); - ASSERT_TRUE(edLdx.equalsTo(dLdx)); - ASSERT_TRUE(edLdw.equalsTo(dLdw)); - ASSERT_TRUE(edLdb.equalsTo(dLdb)); - + auto x = NDArrayFactory::create('c', {1, 2}, {1.f, 11.f}); + auto w = NDArrayFactory::create('c', {2, 3}, + {11.f, 3.f, 4.f, 5.f, 6.f, 2.f}); + auto b = NDArrayFactory::create({100.f, 200.f, 300.f}); + + auto dLdz = NDArrayFactory::create('c', {1, 3}, {166.f, 269.f, 326.f}); + + sd::ops::xw_plus_b_bp op; + auto result = op.evaluate({&x, &w, &b, &dLdz}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto dLdx = result.at(0); + auto dLdw = result.at(1); + auto dLdb = result.at(2); + + auto edLdx = NDArrayFactory::create('c', {1, 2}, {3937.f, 3096.f}); + auto edLdw = NDArrayFactory::create( + 'c', {2, 3}, {166.f, 269.f, 326.f, 1826.f, 2959.f, 3586.f}); + auto edLdb = NDArrayFactory::create('c', {3}, {166.f, 269.f, 326.f}); + ASSERT_TRUE(edLdx.isSameShape(dLdx)); + ASSERT_TRUE(edLdw.isSameShape(dLdw)); + ASSERT_TRUE(edLdb.isSameShape(dLdb)); + ASSERT_TRUE(edLdx.equalsTo(dLdx)); + ASSERT_TRUE(edLdw.equalsTo(dLdw)); + ASSERT_TRUE(edLdb.equalsTo(dLdb)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests18, XWPlusB_Bp_4) { - - auto x = NDArrayFactory::create('c', { 1, 2 }, { 1.f, 11.f }); - auto w = NDArrayFactory::create('c', { 2, 1 }, { 11.f, 3.f }); - auto b = NDArrayFactory::create('c', { 1 }, { 200.f }); - - auto dLdz = NDArrayFactory::create('c', { 1,1 }, { 244.f }); - - sd::ops::xw_plus_b_bp op; - auto result = op.evaluate({ &x, &w, &b, &dLdz }, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto dLdx = result.at(0); - auto dLdw = result.at(1); - auto dLdb = result.at(2); - - auto edLdx = NDArrayFactory::create('c', { 1,2 }, { 2684.f, 732.f }); - auto edLdw = NDArrayFactory::create('c', { 2,1 }, { 244.f, 2684.f }); - auto edLdb = NDArrayFactory::create('c', { 1 }, { 244.f }); - ASSERT_TRUE(edLdx.isSameShape(dLdx)); - ASSERT_TRUE(edLdw.isSameShape(dLdw)); - ASSERT_TRUE(edLdb.isSameShape(dLdb)); - ASSERT_TRUE(edLdx.equalsTo(dLdx)); - ASSERT_TRUE(edLdw.equalsTo(dLdw)); - ASSERT_TRUE(edLdb.equalsTo(dLdb)); - + auto x = NDArrayFactory::create('c', {1, 2}, {1.f, 11.f}); + auto w = NDArrayFactory::create('c', {2, 1}, {11.f, 3.f}); + auto b = NDArrayFactory::create('c', {1}, {200.f}); + + auto dLdz = NDArrayFactory::create('c', {1, 1}, {244.f}); + + sd::ops::xw_plus_b_bp op; + auto result = op.evaluate({&x, &w, &b, &dLdz}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto dLdx = result.at(0); + auto dLdw = result.at(1); + auto dLdb = result.at(2); + + auto edLdx = NDArrayFactory::create('c', {1, 2}, {2684.f, 732.f}); + auto edLdw = NDArrayFactory::create('c', {2, 1}, {244.f, 2684.f}); + auto edLdb = NDArrayFactory::create('c', {1}, {244.f}); + ASSERT_TRUE(edLdx.isSameShape(dLdx)); + ASSERT_TRUE(edLdw.isSameShape(dLdw)); + ASSERT_TRUE(edLdb.isSameShape(dLdb)); + ASSERT_TRUE(edLdx.equalsTo(dLdx)); + ASSERT_TRUE(edLdw.equalsTo(dLdw)); + ASSERT_TRUE(edLdb.equalsTo(dLdb)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests18, XWPlusB_Bp_5) { - - auto x = NDArrayFactory::create('f', { 2,3 }, { 1.f, 11.f, 3.f, 14.f, 5.f, 6.f }); - auto w = NDArrayFactory::create('f', { 3,2 }, { 11.f, 3.f, 4.f, 5.f, 6.f, 2.f }); - auto b = NDArrayFactory::create({ 100.f, 200.f }); - - auto dLdz = NDArrayFactory::create('f', { 2,2 }, { 140.f, 287.f, 233.f, 351.f }); - - sd::ops::xw_plus_b_bp op; - auto result = op.evaluate({ &x, &w, &b, &dLdz }, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto dLdx = result.at(0); - auto dLdw = result.at(1); - auto dLdb = result.at(2); - - auto edLdxC = NDArrayFactory::create('c', { 2,3 }, { 2705.f, 1818.f, 1026.f, 4912.f, 2967.f, 1850.f }); - auto edLdwC = NDArrayFactory::create('c', { 3,2 }, { 3297.f, 4094.f, 4438.f, 5613.f, 2422.f, 3271.f }); - auto edLdbC = NDArrayFactory::create('c', { 2 }, { 427.f, 584.f }); - - auto edLdx = NDArrayFactory::create('f', { 2,3 }); - auto edLdw = NDArrayFactory::create('f', { 3,2 }); - auto edLdb = NDArrayFactory::create('f', { 2 }); - - edLdx.assign(edLdxC); - edLdw.assign(edLdwC); - edLdb.assign(edLdbC); - - ASSERT_TRUE(edLdx.isSameShape(dLdx)); - ASSERT_TRUE(edLdw.isSameShape(dLdw)); - ASSERT_TRUE(edLdb.isSameShape(dLdb)); - ASSERT_TRUE(edLdx.equalsTo(dLdx)); - ASSERT_TRUE(edLdw.equalsTo(dLdw)); - ASSERT_TRUE(edLdb.equalsTo(dLdb)); - + auto x = NDArrayFactory::create('f', {2, 3}, + {1.f, 11.f, 3.f, 14.f, 5.f, 6.f}); + auto w = NDArrayFactory::create('f', {3, 2}, + {11.f, 3.f, 4.f, 5.f, 6.f, 2.f}); + auto b = NDArrayFactory::create({100.f, 200.f}); + + auto dLdz = + NDArrayFactory::create('f', {2, 2}, {140.f, 287.f, 233.f, 351.f}); + + sd::ops::xw_plus_b_bp op; + auto result = op.evaluate({&x, &w, &b, &dLdz}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto dLdx = result.at(0); + auto dLdw = result.at(1); + auto dLdb = result.at(2); + + auto edLdxC = NDArrayFactory::create( + 'c', {2, 3}, {2705.f, 1818.f, 1026.f, 4912.f, 2967.f, 1850.f}); + auto edLdwC = NDArrayFactory::create( + 'c', {3, 2}, {3297.f, 4094.f, 4438.f, 5613.f, 2422.f, 3271.f}); + auto edLdbC = NDArrayFactory::create('c', {2}, {427.f, 584.f}); + + auto edLdx = NDArrayFactory::create('f', {2, 3}); + auto edLdw = NDArrayFactory::create('f', {3, 2}); + auto edLdb = NDArrayFactory::create('f', {2}); + + edLdx.assign(edLdxC); + edLdw.assign(edLdwC); + edLdb.assign(edLdbC); + + ASSERT_TRUE(edLdx.isSameShape(dLdx)); + ASSERT_TRUE(edLdw.isSameShape(dLdw)); + ASSERT_TRUE(edLdb.isSameShape(dLdb)); + ASSERT_TRUE(edLdx.equalsTo(dLdx)); + ASSERT_TRUE(edLdw.equalsTo(dLdw)); + ASSERT_TRUE(edLdb.equalsTo(dLdb)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests18, XWPlusB_Bp_6) { - - auto x = NDArrayFactory::create('c', { 2,3 }, { 1.f, 11.f, 3.f, 14.f, 5.f, 6.f }); - auto w = NDArrayFactory::create('c', { 3,2 }, { 11.f, 3.f, 4.f, 5.f, 6.f, 2.f }); - auto b = NDArrayFactory::create({ 100.f, 200.f }); - - auto dLdz = NDArrayFactory::create('c', { 2,2 }, { 173.f, 264.f, 310.f, 279.f }); - - // mkl-format - w.permutei({ 1,0 }); - - sd::ops::xw_plus_b_bp op; - auto result = op.evaluate({ &x, &w, &b, &dLdz }, {}, { 1 }); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto dLdx = result.at(0); - auto dLdw = result.at(1); - auto dLdb = result.at(2); - - auto edLdx = NDArrayFactory::create('c', { 2,3 }, { 2695.f, 2012.f, 1566.f, 4247.f, 2635.f, 2418.f }); - auto edLdwC = NDArrayFactory::create('c', { 3,2 }, { 4513.f, 3453.f, 2379.f, 4170.f, 4299.f, 2466.f }); - auto edLdb = NDArrayFactory::create('c', { 2 }, { 483.f, 543.f }); - auto edLdw = NDArrayFactory::create('c', { 3,2 }, { 4513.f, 3453.f, 2379.f, 4170.f, 4299.f, 2466.f }); - edLdw.permutei({ 1,0 }); - edLdw.assign(edLdwC); - - ASSERT_TRUE(edLdx.isSameShape(dLdx)); - ASSERT_TRUE(edLdw.isSameShape(dLdw)); - ASSERT_TRUE(edLdb.isSameShape(dLdb)); - ASSERT_TRUE(edLdx.equalsTo(dLdx)); - ASSERT_TRUE(edLdw.equalsTo(dLdw)); - ASSERT_TRUE(edLdb.equalsTo(dLdb)); + auto x = NDArrayFactory::create('c', {2, 3}, + {1.f, 11.f, 3.f, 14.f, 5.f, 6.f}); + auto w = NDArrayFactory::create('c', {3, 2}, + {11.f, 3.f, 4.f, 5.f, 6.f, 2.f}); + auto b = NDArrayFactory::create({100.f, 200.f}); + + auto dLdz = + NDArrayFactory::create('c', {2, 2}, {173.f, 264.f, 310.f, 279.f}); + + // mkl-format + w.permutei({1, 0}); + + sd::ops::xw_plus_b_bp op; + auto result = op.evaluate({&x, &w, &b, &dLdz}, {}, {1}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto dLdx = result.at(0); + auto dLdw = result.at(1); + auto dLdb = result.at(2); + + auto edLdx = NDArrayFactory::create( + 'c', {2, 3}, {2695.f, 2012.f, 1566.f, 4247.f, 2635.f, 2418.f}); + auto edLdwC = NDArrayFactory::create( + 'c', {3, 2}, {4513.f, 3453.f, 2379.f, 4170.f, 4299.f, 2466.f}); + auto edLdb = NDArrayFactory::create('c', {2}, {483.f, 543.f}); + auto edLdw = NDArrayFactory::create( + 'c', {3, 2}, {4513.f, 3453.f, 2379.f, 4170.f, 4299.f, 2466.f}); + edLdw.permutei({1, 0}); + edLdw.assign(edLdwC); + + ASSERT_TRUE(edLdx.isSameShape(dLdx)); + ASSERT_TRUE(edLdw.isSameShape(dLdw)); + ASSERT_TRUE(edLdb.isSameShape(dLdb)); + ASSERT_TRUE(edLdx.equalsTo(dLdx)); + ASSERT_TRUE(edLdw.equalsTo(dLdw)); + ASSERT_TRUE(edLdb.equalsTo(dLdb)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests18, TestUpdaterSgd1) { + NDArray gradient('c', {1, 5}, + {0.21138794720172882, 0.38947954773902893, + 0.2822134494781494, 0.4342866837978363, 0.7928546667098999}, + DataType::FLOAT32); + auto lr = NDArrayFactory::create(0.001f); - NDArray gradient('c', { 1, 5 }, { 0.21138794720172882, 0.38947954773902893, 0.2822134494781494, 0.4342866837978363, 0.7928546667098999 }, DataType::FLOAT32); - auto lr = NDArrayFactory::create(0.001f); - - NDArray update('c', { 1, 5 }, { 0.00021138794720173, 0.00038947954773903, 0.00028221344947815, 0.00043428668379784, 0.0007928546667099 }, DataType::FLOAT32); - - sd::ops::sgd_updater op; + NDArray update('c', {1, 5}, + {0.00021138794720173, 0.00038947954773903, 0.00028221344947815, + 0.00043428668379784, 0.0007928546667099}, + DataType::FLOAT32); - Nd4jStatus status = op.execute({ &gradient, &lr }, { &gradient }, {}, { }); + sd::ops::sgd_updater op; - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(update.equalsTo(gradient)); + Nd4jStatus status = op.execute({&gradient, &lr}, {&gradient}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(update.equalsTo(gradient)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests18, TestUpdaterSgd2) { + NDArray gradient('c', {1, 5}, + {0.21138794720172882, 0.38947954773902893, + 0.2822134494781494, 0.4342866837978363, 0.7928546667098999}, + DataType::FLOAT32); - NDArray gradient('c', { 1, 5 }, { 0.21138794720172882, 0.38947954773902893, 0.2822134494781494, 0.4342866837978363, 0.7928546667098999 }, DataType::FLOAT32); + NDArray update('c', {1, 5}, + {0.00021138794720173, 0.00038947954773903, 0.00028221344947815, + 0.00043428668379784, 0.0007928546667099}, + DataType::FLOAT32); - NDArray update('c', { 1, 5 }, { 0.00021138794720173, 0.00038947954773903, 0.00028221344947815, 0.00043428668379784, 0.0007928546667099 }, DataType::FLOAT32); - - sd::ops::sgd_updater op; - - Nd4jStatus status = op.execute({ &gradient }, { &gradient }, { 0.001f }, { }); - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(update.equalsTo(gradient)); + sd::ops::sgd_updater op; + Nd4jStatus status = op.execute({&gradient}, {&gradient}, {0.001f}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(update.equalsTo(gradient)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests18, TestUpdaterSgd3) { - - NDArray gradientC('c', { 1, 5 }, { 0.21138794720172882, 0.38947954773902893, 0.2822134494781494, 0.4342866837978363, 0.7928546667098999 }, DataType::FLOAT32); - - NDArray updateC('c', { 1, 5 }, { 0.00021138794720173, 0.00038947954773903, 0.00028221344947815, 0.00043428668379784, 0.0007928546667099 }, DataType::FLOAT32); - - NDArray gradient('f', { 1, 5 }, DataType::FLOAT32); - NDArray update('f', { 1, 5 }, DataType::FLOAT32); - - gradient.assign(gradientC); - update.assign(updateC); - - sd::ops::sgd_updater op; - - auto results = op.evaluate({ &gradient }, { 0.001f }, { }); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(update.isSameShape(results.at(0))); - ASSERT_TRUE(update.equalsTo(results.at(0))); + NDArray gradientC( + 'c', {1, 5}, + {0.21138794720172882, 0.38947954773902893, 0.2822134494781494, + 0.4342866837978363, 0.7928546667098999}, + DataType::FLOAT32); + + NDArray updateC( + 'c', {1, 5}, + {0.00021138794720173, 0.00038947954773903, 0.00028221344947815, + 0.00043428668379784, 0.0007928546667099}, + DataType::FLOAT32); + + NDArray gradient('f', {1, 5}, DataType::FLOAT32); + NDArray update('f', {1, 5}, DataType::FLOAT32); + + gradient.assign(gradientC); + update.assign(updateC); + + sd::ops::sgd_updater op; + + auto results = op.evaluate({&gradient}, {0.001f}, {}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests18, TestUpdaterRmsProm1) { - - NDArray grad0('c', { 1, 5 }, { 0.1811431348323822, 0.10499879717826843, 0.8736756443977356, 0.9707390666007996, 0.7415646314620972 }, DataType::FLOAT32); - NDArray init('c', { 1, 5 }, { 0.00000001, 0.00000001, 0.00000001, 0.00000001, 0.00000001 }, DataType::FLOAT32); - - auto lr = NDArrayFactory::create(0.1f); - auto decay = NDArrayFactory::create(0.95f); - auto epsilon = NDArrayFactory::create(1.e-8f); - - sd::ops::rms_prop_updater op; - - Nd4jStatus status = op.execute({ &grad0, &init, &lr, &decay, &epsilon }, { &grad0, &init }, { }, { }); - ASSERT_EQ(ND4J_STATUS_OK, status); - NDArray updateExp0('c', { 1, 5 }, { 0.4472121903197142, 0.4472095514452829, 0.4472135169488324, 0.44721352981195367, 0.44721349127249754 }, DataType::FLOAT32); - NDArray stateG0('c', { 1, 5 }, { 0.00164065126484513, 0.00055124687044416, 0.03816546608068996, 0.04711672627124962, 0.02749591463177582 }, DataType::FLOAT32); - - ASSERT_TRUE(grad0.equalsTo(updateExp0)); - ASSERT_TRUE(init.equalsTo(stateG0)); - - - NDArray grad1('c', { 1, 5 }, { 0.0139725673943758, 0.19333727657794952, 0.9288347363471985, 0.9253600239753723, 0.3578299283981323 }, DataType::FLOAT32); - status = op.execute({ &grad1, &init, &lr, &decay, &epsilon }, { &grad1, &init }, { }, { }); - ASSERT_EQ(ND4J_STATUS_OK, status); - NDArray updateExp1('c', { 1, 5 }, { 0.03528177364993147, 0.3952537075263024, 0.32964378302079766, 0.31269398966616074, 0.1984174163852542 }, DataType::FLOAT32); - NDArray stateG1('c', { 1, 5 }, { 0.00156838033358239, 0.00239264965265088, 0.07939389114891399, 0.08757544865627226, 0.03252323178305766 }, DataType::FLOAT32); - - ASSERT_TRUE(grad1.equalsTo(updateExp1)); - ASSERT_TRUE(init.equalsTo(stateG1)); - - NDArray grad2('c', { 1, 5 }, { 0.5442887544631958, 0.5386605262756348, 0.884294331073761, 0.15599730610847473, 0.7259345054626465 }, DataType::FLOAT32); - status = op.execute({ &grad2, &init, &lr, &decay, &epsilon }, { &grad2, &init }, { }, { }); - ASSERT_EQ(ND4J_STATUS_OK, status); - NDArray updateExp2('c', { 1, 5 }, { 0.4262874753567082, 0.41582357367557454, 0.2613066321005825, 0.05369221235564697, 0.3034061716240995 }, DataType::FLOAT32); - NDArray stateG2('c', { 1, 5 }, { 0.01630247372865814, 0.01678077529839554, 0.11452301978992785, 0.0844134341991137, 0.05724611550496966 }, DataType::FLOAT32); - - ASSERT_TRUE(grad2.equalsTo(updateExp2)); - ASSERT_TRUE(init.equalsTo(stateG2)); - + NDArray grad0('c', {1, 5}, + {0.1811431348323822, 0.10499879717826843, 0.8736756443977356, + 0.9707390666007996, 0.7415646314620972}, + DataType::FLOAT32); + NDArray init('c', {1, 5}, + {0.00000001, 0.00000001, 0.00000001, 0.00000001, 0.00000001}, + DataType::FLOAT32); + + auto lr = NDArrayFactory::create(0.1f); + auto decay = NDArrayFactory::create(0.95f); + auto epsilon = NDArrayFactory::create(1.e-8f); + + sd::ops::rms_prop_updater op; + + Nd4jStatus status = op.execute({&grad0, &init, &lr, &decay, &epsilon}, + {&grad0, &init}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); + NDArray updateExp0( + 'c', {1, 5}, + {0.4472121903197142, 0.4472095514452829, 0.4472135169488324, + 0.44721352981195367, 0.44721349127249754}, + DataType::FLOAT32); + NDArray stateG0( + 'c', {1, 5}, + {0.00164065126484513, 0.00055124687044416, 0.03816546608068996, + 0.04711672627124962, 0.02749591463177582}, + DataType::FLOAT32); + + ASSERT_TRUE(grad0.equalsTo(updateExp0)); + ASSERT_TRUE(init.equalsTo(stateG0)); + + NDArray grad1('c', {1, 5}, + {0.0139725673943758, 0.19333727657794952, 0.9288347363471985, + 0.9253600239753723, 0.3578299283981323}, + DataType::FLOAT32); + status = op.execute({&grad1, &init, &lr, &decay, &epsilon}, {&grad1, &init}, + {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); + NDArray updateExp1( + 'c', {1, 5}, + {0.03528177364993147, 0.3952537075263024, 0.32964378302079766, + 0.31269398966616074, 0.1984174163852542}, + DataType::FLOAT32); + NDArray stateG1( + 'c', {1, 5}, + {0.00156838033358239, 0.00239264965265088, 0.07939389114891399, + 0.08757544865627226, 0.03252323178305766}, + DataType::FLOAT32); + + ASSERT_TRUE(grad1.equalsTo(updateExp1)); + ASSERT_TRUE(init.equalsTo(stateG1)); + + NDArray grad2('c', {1, 5}, + {0.5442887544631958, 0.5386605262756348, 0.884294331073761, + 0.15599730610847473, 0.7259345054626465}, + DataType::FLOAT32); + status = op.execute({&grad2, &init, &lr, &decay, &epsilon}, {&grad2, &init}, + {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); + NDArray updateExp2( + 'c', {1, 5}, + {0.4262874753567082, 0.41582357367557454, 0.2613066321005825, + 0.05369221235564697, 0.3034061716240995}, + DataType::FLOAT32); + NDArray stateG2( + 'c', {1, 5}, + {0.01630247372865814, 0.01678077529839554, 0.11452301978992785, + 0.0844134341991137, 0.05724611550496966}, + DataType::FLOAT32); + + ASSERT_TRUE(grad2.equalsTo(updateExp2)); + ASSERT_TRUE(init.equalsTo(stateG2)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests18, TestUpdaterRmsProm2) { - - NDArray grad('c', { 1, 5 }, { 1, 2, 3, 4, 5 }, DataType::FLOAT32); - NDArray init('c', { 1, 5 }, { 0.00000001, 0.00000001, 0.00000001, 0.00000001, 0.00000001 }, DataType::FLOAT32); - - NDArray update('c', { 1, 5 }, DataType::FLOAT32); - - sd::ops::rms_prop_updater op; - - Nd4jStatus status = op.execute({ &grad, &init }, { &update, &init }, { 0.1f, 0.95f, 1.e-8 }, { }); - ASSERT_EQ(ND4J_STATUS_OK, status); - NDArray updateExp0('c', { 1, 5 }, { 0.4472135330146769, 0.44721357487863594, 0.44721358411270346, 0.4472135878446271, 0.447213589800546 }, DataType::FLOAT32); - NDArray stateG0('c', { 1, 5 }, { 0.05000000950000005, 0.2000000095000002, 0.4500000095000004, 0.8000000095000007, 1.250000009500001 }, DataType::FLOAT32); - - ASSERT_TRUE(update.equalsTo(updateExp0)); - ASSERT_TRUE(init.equalsTo(stateG0)); - - status = op.execute({ &grad, &init }, { &update, &init }, { 0.1f, 0.95f, 1.e-8 }, { }); - ASSERT_EQ(ND4J_STATUS_OK, status); - NDArray updateExp1('c', { 1, 5 }, { 0.32025628253164734, 0.3202562987764395, 0.32025630254446874, 0.3202563041196892, 0.3202563049660074 }, DataType::FLOAT32); - NDArray stateG1('c', { 1, 5 }, { 0.09750000902500008, 0.3900000090250003, 0.8775000090250007, 1.5600000090250012, 2.437500009025002 }, DataType::FLOAT32); - - ASSERT_TRUE(update.equalsTo(updateExp1)); - ASSERT_TRUE(init.equalsTo(stateG1)); - - status = op.execute({ &grad, &init }, { &update, &init }, { 0.1f, 0.95f, 1.e-8 }, { }); - ASSERT_EQ(ND4J_STATUS_OK, status); - NDArray updateExp2('c', { 1, 5 }, { 0.2647903457769699, 0.2647903552517623, 0.26479035752571606, 0.2647903584968847, 0.2647903590265272 }, DataType::FLOAT32); - NDArray stateG2('c', { 1, 5 }, { 0.1426250085737501, 0.5705000085737504, 1.283625008573751, 2.2820000085737515, 3.565625008573753 }, DataType::FLOAT32); - - ASSERT_TRUE(update.equalsTo(updateExp2)); - ASSERT_TRUE(init.equalsTo(stateG2)); - + NDArray grad('c', {1, 5}, {1, 2, 3, 4, 5}, DataType::FLOAT32); + NDArray init('c', {1, 5}, + {0.00000001, 0.00000001, 0.00000001, 0.00000001, 0.00000001}, + DataType::FLOAT32); + + NDArray update('c', {1, 5}, DataType::FLOAT32); + + sd::ops::rms_prop_updater op; + + Nd4jStatus status = + op.execute({&grad, &init}, {&update, &init}, {0.1f, 0.95f, 1.e-8}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); + NDArray updateExp0( + 'c', {1, 5}, + {0.4472135330146769, 0.44721357487863594, 0.44721358411270346, + 0.4472135878446271, 0.447213589800546}, + DataType::FLOAT32); + NDArray stateG0('c', {1, 5}, + {0.05000000950000005, 0.2000000095000002, 0.4500000095000004, + 0.8000000095000007, 1.250000009500001}, + DataType::FLOAT32); + + ASSERT_TRUE(update.equalsTo(updateExp0)); + ASSERT_TRUE(init.equalsTo(stateG0)); + + status = + op.execute({&grad, &init}, {&update, &init}, {0.1f, 0.95f, 1.e-8}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); + NDArray updateExp1( + 'c', {1, 5}, + {0.32025628253164734, 0.3202562987764395, 0.32025630254446874, + 0.3202563041196892, 0.3202563049660074}, + DataType::FLOAT32); + NDArray stateG1('c', {1, 5}, + {0.09750000902500008, 0.3900000090250003, 0.8775000090250007, + 1.5600000090250012, 2.437500009025002}, + DataType::FLOAT32); + + ASSERT_TRUE(update.equalsTo(updateExp1)); + ASSERT_TRUE(init.equalsTo(stateG1)); + + status = + op.execute({&grad, &init}, {&update, &init}, {0.1f, 0.95f, 1.e-8}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); + NDArray updateExp2( + 'c', {1, 5}, + {0.2647903457769699, 0.2647903552517623, 0.26479035752571606, + 0.2647903584968847, 0.2647903590265272}, + DataType::FLOAT32); + NDArray stateG2('c', {1, 5}, + {0.1426250085737501, 0.5705000085737504, 1.283625008573751, + 2.2820000085737515, 3.565625008573753}, + DataType::FLOAT32); + + ASSERT_TRUE(update.equalsTo(updateExp2)); + ASSERT_TRUE(init.equalsTo(stateG2)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests18, TestUpdaterRmsProm3) { - - NDArray gradC('c', { 1, 5 }, { 1, 2, 3, 4, 5 }, DataType::FLOAT32); - NDArray initC('c', { 1, 5 }, { 0.00000001, 0.00000001, 0.00000001, 0.00000001, 0.00000001 }, DataType::FLOAT32); - - NDArray grad('f', { 1, 5 }, DataType::FLOAT32); - NDArray init('f', { 1, 5 }, DataType::FLOAT32); - grad.assign(gradC); - init.assign(initC); - - sd::ops::rms_prop_updater op; - auto results = op.evaluate({ &grad, &init }, { 0.1f, 0.95f, 1.e-8 }, { }); - - NDArray updateC('c', { 1, 5 }, { 0.4472135330146769, 0.44721357487863594, 0.44721358411270346, 0.4472135878446271, 0.447213589800546 }, DataType::FLOAT32); - NDArray update('f', { 1, 5 }, DataType::FLOAT32); - - NDArray stateG0C('c', { 1, 5 }, { 0.05000000950000005, 0.2000000095000002, 0.4500000095000004, 0.8000000095000007, 1.250000009500001 }, DataType::FLOAT32); - NDArray stateG('f', { 1, 5 }, DataType::FLOAT32); - - update.assign(updateC); - stateG.assign(stateG0C); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(update.isSameShape(results.at(0))); - ASSERT_TRUE(update.equalsTo(results.at(0))); - ASSERT_TRUE(stateG.isSameShape(results.at(1))); - ASSERT_TRUE(stateG.equalsTo(results.at(1))); - - results = op.evaluate({ &grad, &stateG }, { 0.1f, 0.95f, 1.e-8 }, { }); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - NDArray update1C('c', { 1, 5 }, { 0.32025628253164734, 0.3202562987764395, 0.32025630254446874, 0.3202563041196892, 0.3202563049660074 }, DataType::FLOAT32); - NDArray stateG1C('c', { 1, 5 }, { 0.09750000902500008, 0.3900000090250003, 0.8775000090250007, 1.5600000090250012, 2.437500009025002 }, DataType::FLOAT32); - - update.assign(update1C); - stateG.assign(stateG1C); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(update.isSameShape(results.at(0))); - ASSERT_TRUE(update.equalsTo(results.at(0))); - ASSERT_TRUE(stateG.isSameShape(results.at(1))); - ASSERT_TRUE(stateG.equalsTo(results.at(1))); - - - results = op.evaluate({ &grad, &stateG }, { 0.1f, 0.95f, 1.e-8 }, { }); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - NDArray update2C('c', { 1, 5 }, { 0.2647903457769699, 0.2647903552517623, 0.26479035752571606, 0.2647903584968847, 0.2647903590265272 }, DataType::FLOAT32); - NDArray stateG2C('c', { 1, 5 }, { 0.1426250085737501, 0.5705000085737504, 1.283625008573751, 2.2820000085737515, 3.565625008573753 }, DataType::FLOAT32); - - update.assign(update2C); - stateG.assign(stateG2C); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(update.isSameShape(results.at(0))); - ASSERT_TRUE(update.equalsTo(results.at(0))); - ASSERT_TRUE(stateG.isSameShape(results.at(1))); - ASSERT_TRUE(stateG.equalsTo(results.at(1))); + NDArray gradC('c', {1, 5}, {1, 2, 3, 4, 5}, DataType::FLOAT32); + NDArray initC('c', {1, 5}, + {0.00000001, 0.00000001, 0.00000001, 0.00000001, 0.00000001}, + DataType::FLOAT32); + + NDArray grad('f', {1, 5}, DataType::FLOAT32); + NDArray init('f', {1, 5}, DataType::FLOAT32); + grad.assign(gradC); + init.assign(initC); + + sd::ops::rms_prop_updater op; + auto results = op.evaluate({&grad, &init}, {0.1f, 0.95f, 1.e-8}, {}); + + NDArray updateC('c', {1, 5}, + {0.4472135330146769, 0.44721357487863594, 0.44721358411270346, + 0.4472135878446271, 0.447213589800546}, + DataType::FLOAT32); + NDArray update('f', {1, 5}, DataType::FLOAT32); + + NDArray stateG0C('c', {1, 5}, + {0.05000000950000005, 0.2000000095000002, 0.4500000095000004, + 0.8000000095000007, 1.250000009500001}, + DataType::FLOAT32); + NDArray stateG('f', {1, 5}, DataType::FLOAT32); + + update.assign(updateC); + stateG.assign(stateG0C); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); + ASSERT_TRUE(stateG.isSameShape(results.at(1))); + ASSERT_TRUE(stateG.equalsTo(results.at(1))); + + results = op.evaluate({&grad, &stateG}, {0.1f, 0.95f, 1.e-8}, {}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + NDArray update1C( + 'c', {1, 5}, + {0.32025628253164734, 0.3202562987764395, 0.32025630254446874, + 0.3202563041196892, 0.3202563049660074}, + DataType::FLOAT32); + NDArray stateG1C('c', {1, 5}, + {0.09750000902500008, 0.3900000090250003, 0.8775000090250007, + 1.5600000090250012, 2.437500009025002}, + DataType::FLOAT32); + + update.assign(update1C); + stateG.assign(stateG1C); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); + ASSERT_TRUE(stateG.isSameShape(results.at(1))); + ASSERT_TRUE(stateG.equalsTo(results.at(1))); + + results = op.evaluate({&grad, &stateG}, {0.1f, 0.95f, 1.e-8}, {}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + NDArray update2C('c', {1, 5}, + {0.2647903457769699, 0.2647903552517623, 0.26479035752571606, + 0.2647903584968847, 0.2647903590265272}, + DataType::FLOAT32); + NDArray stateG2C('c', {1, 5}, + {0.1426250085737501, 0.5705000085737504, 1.283625008573751, + 2.2820000085737515, 3.565625008573753}, + DataType::FLOAT32); + + update.assign(update2C); + stateG.assign(stateG2C); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); + ASSERT_TRUE(stateG.isSameShape(results.at(1))); + ASSERT_TRUE(stateG.equalsTo(results.at(1))); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests18, TestUpdaterAdaGrad1) { + // need Java test - // need Java test + NDArray grad0('c', {1, 5}, + {0.1811431348323822, 0.10499879717826843, 0.8736756443977356, + 0.9707390666007996, 0.7415646314620972}, + DataType::FLOAT32); + NDArray init('c', {1, 5}, + {0.00000001, 0.00000001, 0.00000001, 0.00000001, 0.00000001}, + DataType::FLOAT32); - NDArray grad0('c', { 1, 5 }, { 0.1811431348323822, 0.10499879717826843, 0.8736756443977356, 0.9707390666007996, 0.7415646314620972 }, DataType::FLOAT32); - NDArray init('c', { 1, 5 }, { 0.00000001, 0.00000001, 0.00000001, 0.00000001, 0.00000001 }, DataType::FLOAT32); + auto lr = NDArrayFactory::create(0.1f); + auto epsilon = NDArrayFactory::create(1.e-8f); - auto lr = NDArrayFactory::create(0.1f); - auto epsilon = NDArrayFactory::create(1.e-8f); - - sd::ops::ada_grad_updater op; - - Nd4jStatus status = op.execute({ &grad0, &init, &lr, &epsilon }, { &grad0, &init }, { }, { }); - ASSERT_EQ(ND4J_STATUS_OK, status); + sd::ops::ada_grad_updater op; + Nd4jStatus status = + op.execute({&grad0, &init, &lr, &epsilon}, {&grad0, &init}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests18, TestUpdaterNesterovs1) { - - NDArray grad0('c', { 1, 5 }, { 0.6877592206001282, 0.7830561399459839, 0.7647699117660522, 0.6183066964149475, 0.3303879499435425 }, DataType::FLOAT32); - NDArray init('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); - - sd::ops::nesterovs_updater op; - - Nd4jStatus status = op.execute({ &grad0, &init }, { &grad0, &init }, { 0.1f, 0.9f }, { }); - ASSERT_EQ(ND4J_STATUS_OK, status); - - NDArray updateExp0('c', { 1, 5 }, { 0.13067425191402435, 0.14878066658973696, 0.14530628323554992, 0.11747827231884002, 0.06277371048927306 }, DataType::FLOAT32); - NDArray stateV0('c', { 1, 5 }, { -0.06877592206001282, -0.0783056139945984, -0.07647699117660522, -0.06183066964149475, -0.03303879499435425 }, DataType::FLOAT32); - - ASSERT_TRUE(grad0.equalsTo(updateExp0)); - ASSERT_TRUE(init.equalsTo(stateV0)); - - NDArray grad1('c', { 1, 5 }, { 0.3676236569881439, 0.07645636051893234, 0.45949840545654297, 0.6335387825965881, 0.2953402101993561 }, DataType::FLOAT32); - status = op.execute({ &grad1, &init }, { &grad1, &init }, { 0.1f, 0.9f }, { }); - ASSERT_EQ(ND4J_STATUS_OK, status); - NDArray updateExp1('c', { 1, 5 }, { 0.12555699169635773, 0.07795425583422186, 0.14925105988979342, 0.17045521110296247, 0.08287606388330458 }, DataType::FLOAT32); - NDArray stateV1('c', { 1, 5 }, { -0.09866069555282593, -0.0781206886470318, -0.11477913260459902, -0.11900148093700408, -0.05926893651485443 }, DataType::FLOAT32); - - ASSERT_TRUE(grad1.equalsTo(updateExp1)); - ASSERT_TRUE(init.equalsTo(stateV1)); - - NDArray grad2('c', { 1, 5 }, { 0.9874004125595093, 0.41817641258239746, 0.16838215291500092, 0.00803728867322206, 0.37015461921691895 }, DataType::FLOAT32); - status = op.execute({ &grad2, &init }, { &grad2, &init }, { 0.1f, 0.9f }, { }); - ASSERT_EQ(ND4J_STATUS_OK, status); - NDArray updateExp2('c', { 1, 5 }, { 0.26752124178409575, 0.1427312761947513, 0.12496370646357537, 0.09791828440688549, 0.11833721622824667 }, DataType::FLOAT32); - NDArray stateV2('c', { 1, 5 }, { -0.18753466725349427, -0.11212626104056837, -0.12013943463563921, -0.10790506171062587, -0.09035750478506088 }, DataType::FLOAT32); - - ASSERT_TRUE(grad2.equalsTo(updateExp2)); - ASSERT_TRUE(init.equalsTo(stateV2)); - + NDArray grad0('c', {1, 5}, + {0.6877592206001282, 0.7830561399459839, 0.7647699117660522, + 0.6183066964149475, 0.3303879499435425}, + DataType::FLOAT32); + NDArray init('c', {1, 5}, {0.0, 0.0, 0.0, 0.0, 0.0}, DataType::FLOAT32); + + sd::ops::nesterovs_updater op; + + Nd4jStatus status = + op.execute({&grad0, &init}, {&grad0, &init}, {0.1f, 0.9f}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp0( + 'c', {1, 5}, + {0.13067425191402435, 0.14878066658973696, 0.14530628323554992, + 0.11747827231884002, 0.06277371048927306}, + DataType::FLOAT32); + NDArray stateV0( + 'c', {1, 5}, + {-0.06877592206001282, -0.0783056139945984, -0.07647699117660522, + -0.06183066964149475, -0.03303879499435425}, + DataType::FLOAT32); + + ASSERT_TRUE(grad0.equalsTo(updateExp0)); + ASSERT_TRUE(init.equalsTo(stateV0)); + + NDArray grad1('c', {1, 5}, + {0.3676236569881439, 0.07645636051893234, 0.45949840545654297, + 0.6335387825965881, 0.2953402101993561}, + DataType::FLOAT32); + status = op.execute({&grad1, &init}, {&grad1, &init}, {0.1f, 0.9f}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); + NDArray updateExp1( + 'c', {1, 5}, + {0.12555699169635773, 0.07795425583422186, 0.14925105988979342, + 0.17045521110296247, 0.08287606388330458}, + DataType::FLOAT32); + NDArray stateV1( + 'c', {1, 5}, + {-0.09866069555282593, -0.0781206886470318, -0.11477913260459902, + -0.11900148093700408, -0.05926893651485443}, + DataType::FLOAT32); + + ASSERT_TRUE(grad1.equalsTo(updateExp1)); + ASSERT_TRUE(init.equalsTo(stateV1)); + + NDArray grad2('c', {1, 5}, + {0.9874004125595093, 0.41817641258239746, 0.16838215291500092, + 0.00803728867322206, 0.37015461921691895}, + DataType::FLOAT32); + status = op.execute({&grad2, &init}, {&grad2, &init}, {0.1f, 0.9f}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); + NDArray updateExp2( + 'c', {1, 5}, + {0.26752124178409575, 0.1427312761947513, 0.12496370646357537, + 0.09791828440688549, 0.11833721622824667}, + DataType::FLOAT32); + NDArray stateV2( + 'c', {1, 5}, + {-0.18753466725349427, -0.11212626104056837, -0.12013943463563921, + -0.10790506171062587, -0.09035750478506088}, + DataType::FLOAT32); + + ASSERT_TRUE(grad2.equalsTo(updateExp2)); + ASSERT_TRUE(init.equalsTo(stateV2)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests18, TestUpdaterNesterovs2) { - - NDArray grad('c', { 1, 5 }, { 1, 2, 3, 4, 5 }, DataType::FLOAT32); - NDArray init('c', { 1, 5 }, { 0.00000001, 0.00000001, 0.00000001, 0.00000001, 0.00000001 }, DataType::FLOAT32); - - NDArray update('c', { 1, 5 }, DataType::FLOAT32); - - auto lr = NDArrayFactory::create(0.1f); - auto momentum = NDArrayFactory::create(0.9f); - - sd::ops::nesterovs_updater op; - - Nd4jStatus status = op.execute({ &grad, &init, &lr, &momentum }, { &update, &init }, { }, { }); - ASSERT_EQ(ND4J_STATUS_OK, status); - - NDArray updateExp0('c', { 1, 5 }, { 0.19, 0.38, 0.5700000000000001, 0.76, 0.95 }, DataType::FLOAT32); - NDArray stateV0('c', { 1, 5 }, { -0.1, -0.2, -0.30000000000000004, -0.4, -0.5 }, DataType::FLOAT32); - - ASSERT_TRUE(update.equalsTo(updateExp0)); - ASSERT_TRUE(init.equalsTo(stateV0)); - - status = op.execute({ &grad, &init, &lr, &momentum }, { &update, &init }, { }, { }); - ASSERT_EQ(ND4J_STATUS_OK, status); - NDArray updateExp1('c', { 1, 5 }, { 0.27099999999999996, 0.5419999999999999, 0.813, 1.0839999999999999, 1.355 }, DataType::FLOAT32); - NDArray stateV1('c', { 1, 5 }, { -0.19, -0.38, -0.5700000000000001, -0.76, -0.95 }, DataType::FLOAT32); - - ASSERT_TRUE(update.equalsTo(updateExp1)); - ASSERT_TRUE(init.equalsTo(stateV1)); - - status = op.execute({ &grad, &init, &lr, &momentum }, { &update, &init }, { }, { }); - ASSERT_EQ(ND4J_STATUS_OK, status); - NDArray updateExp2('c', { 1, 5 }, { 0.3439, 0.6878, 1.0317, 1.3756, 1.7195 }, DataType::FLOAT32); - NDArray stateV2('c', { 1, 5 }, { -0.271, -0.542, -0.8130000000000002, -1.084, -1.355 }, DataType::FLOAT32); - - ASSERT_TRUE(update.equalsTo(updateExp2)); - ASSERT_TRUE(init.equalsTo(stateV2)); + NDArray grad('c', {1, 5}, {1, 2, 3, 4, 5}, DataType::FLOAT32); + NDArray init('c', {1, 5}, + {0.00000001, 0.00000001, 0.00000001, 0.00000001, 0.00000001}, + DataType::FLOAT32); + + NDArray update('c', {1, 5}, DataType::FLOAT32); + + auto lr = NDArrayFactory::create(0.1f); + auto momentum = NDArrayFactory::create(0.9f); + + sd::ops::nesterovs_updater op; + + Nd4jStatus status = + op.execute({&grad, &init, &lr, &momentum}, {&update, &init}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp0('c', {1, 5}, {0.19, 0.38, 0.5700000000000001, 0.76, 0.95}, + DataType::FLOAT32); + NDArray stateV0('c', {1, 5}, {-0.1, -0.2, -0.30000000000000004, -0.4, -0.5}, + DataType::FLOAT32); + + ASSERT_TRUE(update.equalsTo(updateExp0)); + ASSERT_TRUE(init.equalsTo(stateV0)); + + status = op.execute({&grad, &init, &lr, &momentum}, {&update, &init}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); + NDArray updateExp1('c', {1, 5}, + {0.27099999999999996, 0.5419999999999999, 0.813, + 1.0839999999999999, 1.355}, + DataType::FLOAT32); + NDArray stateV1('c', {1, 5}, + {-0.19, -0.38, -0.5700000000000001, -0.76, -0.95}, + DataType::FLOAT32); + + ASSERT_TRUE(update.equalsTo(updateExp1)); + ASSERT_TRUE(init.equalsTo(stateV1)); + + status = op.execute({&grad, &init, &lr, &momentum}, {&update, &init}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); + NDArray updateExp2('c', {1, 5}, {0.3439, 0.6878, 1.0317, 1.3756, 1.7195}, + DataType::FLOAT32); + NDArray stateV2('c', {1, 5}, + {-0.271, -0.542, -0.8130000000000002, -1.084, -1.355}, + DataType::FLOAT32); + + ASSERT_TRUE(update.equalsTo(updateExp2)); + ASSERT_TRUE(init.equalsTo(stateV2)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests18, TestUpdaterNesterovs3) { - - NDArray gradC('c', { 1, 5 }, { 1, 2, 3, 4, 5 }, DataType::FLOAT32); - NDArray initC('c', { 1, 5 }, { 0.00000001, 0.00000001, 0.00000001, 0.00000001, 0.00000001 }, DataType::FLOAT32); - - NDArray grad('f', { 1, 5 }, DataType::FLOAT32); - NDArray init('f', { 1, 5 }, DataType::FLOAT32); - grad.assign(gradC); - init.assign(initC); - - sd::ops::nesterovs_updater op; - auto results = op.evaluate({ &grad, &init }, { 0.1f, 0.9f }, { }); - - NDArray updateC('c', { 1, 5 }, { 0.19, 0.38, 0.5700000000000001, 0.76, 0.95 }, DataType::FLOAT32); - NDArray update('f', { 1, 5 }, DataType::FLOAT32); - - NDArray stateG0C('c', { 1, 5 }, { -0.1, -0.2, -0.30000000000000004, -0.4, -0.5 }, DataType::FLOAT32); - NDArray stateG('f', { 1, 5 }, DataType::FLOAT32); - - update.assign(updateC); - stateG.assign(stateG0C); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(update.isSameShape(results.at(0))); - ASSERT_TRUE(update.equalsTo(results.at(0))); - ASSERT_TRUE(stateG.isSameShape(results.at(1))); - ASSERT_TRUE(stateG.equalsTo(results.at(1))); - - results = op.evaluate({ &grad, &stateG }, { 0.1f, 0.9f }, { }); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - NDArray update1C('c', { 1, 5 }, { 0.27099999999999996, 0.5419999999999999, 0.813, 1.0839999999999999, 1.355 }, DataType::FLOAT32); - NDArray stateG1C('c', { 1, 5 }, { -0.19, -0.38, -0.5700000000000001, -0.76, -0.95 }, DataType::FLOAT32); - - update.assign(update1C); - stateG.assign(stateG1C); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(update.isSameShape(results.at(0))); - ASSERT_TRUE(update.equalsTo(results.at(0))); - ASSERT_TRUE(stateG.isSameShape(results.at(1))); - ASSERT_TRUE(stateG.equalsTo(results.at(1))); - - - results = op.evaluate({ &grad, &stateG }, { 0.1f, 0.9f }, { }); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - NDArray update2C('c', { 1, 5 }, { 0.3439, 0.6878, 1.0317, 1.3756, 1.7195 }, DataType::FLOAT32); - NDArray stateG2C('c', { 1, 5 }, { -0.271, -0.542, -0.8130000000000002, -1.084, -1.355 }, DataType::FLOAT32); - - update.assign(update2C); - stateG.assign(stateG2C); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(update.isSameShape(results.at(0))); - ASSERT_TRUE(update.equalsTo(results.at(0))); - ASSERT_TRUE(stateG.isSameShape(results.at(1))); - ASSERT_TRUE(stateG.equalsTo(results.at(1))); + NDArray gradC('c', {1, 5}, {1, 2, 3, 4, 5}, DataType::FLOAT32); + NDArray initC('c', {1, 5}, + {0.00000001, 0.00000001, 0.00000001, 0.00000001, 0.00000001}, + DataType::FLOAT32); + + NDArray grad('f', {1, 5}, DataType::FLOAT32); + NDArray init('f', {1, 5}, DataType::FLOAT32); + grad.assign(gradC); + init.assign(initC); + + sd::ops::nesterovs_updater op; + auto results = op.evaluate({&grad, &init}, {0.1f, 0.9f}, {}); + + NDArray updateC('c', {1, 5}, {0.19, 0.38, 0.5700000000000001, 0.76, 0.95}, + DataType::FLOAT32); + NDArray update('f', {1, 5}, DataType::FLOAT32); + + NDArray stateG0C('c', {1, 5}, {-0.1, -0.2, -0.30000000000000004, -0.4, -0.5}, + DataType::FLOAT32); + NDArray stateG('f', {1, 5}, DataType::FLOAT32); + + update.assign(updateC); + stateG.assign(stateG0C); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); + ASSERT_TRUE(stateG.isSameShape(results.at(1))); + ASSERT_TRUE(stateG.equalsTo(results.at(1))); + + results = op.evaluate({&grad, &stateG}, {0.1f, 0.9f}, {}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + NDArray update1C('c', {1, 5}, + {0.27099999999999996, 0.5419999999999999, 0.813, + 1.0839999999999999, 1.355}, + DataType::FLOAT32); + NDArray stateG1C('c', {1, 5}, + {-0.19, -0.38, -0.5700000000000001, -0.76, -0.95}, + DataType::FLOAT32); + + update.assign(update1C); + stateG.assign(stateG1C); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); + ASSERT_TRUE(stateG.isSameShape(results.at(1))); + ASSERT_TRUE(stateG.equalsTo(results.at(1))); + + results = op.evaluate({&grad, &stateG}, {0.1f, 0.9f}, {}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + NDArray update2C('c', {1, 5}, {0.3439, 0.6878, 1.0317, 1.3756, 1.7195}, + DataType::FLOAT32); + NDArray stateG2C('c', {1, 5}, + {-0.271, -0.542, -0.8130000000000002, -1.084, -1.355}, + DataType::FLOAT32); + + update.assign(update2C); + stateG.assign(stateG2C); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); + ASSERT_TRUE(stateG.isSameShape(results.at(1))); + ASSERT_TRUE(stateG.equalsTo(results.at(1))); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests18, TestUpdaterAdaMax1) { - - NDArray grad('c', { 1, 5 }, { 1,2,3,4,5 }, DataType::FLOAT32); - NDArray initU('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); - NDArray initM('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); - - NDArray update('c', { 1, 5 }, DataType::FLOAT32); - - sd::ops::ada_max_updater op; - - Nd4jStatus status = op.execute({ &grad, &initU, &initM }, { &update, &initU, &initM }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); - ASSERT_EQ(ND4J_STATUS_OK, status); - - NDArray updateExp0('c', { 1, 5 }, { 0.001, 0.001, 0.001, 0.001, 0.001 }, DataType::FLOAT32); - NDArray stateU('c', { 1, 5 }, { 1,2,3,4,5 }, DataType::FLOAT32); - NDArray stateM0('c', { 1, 5 }, { 0.09999999999999998, 0.19999999999999996, 0.29999999999999993, 0.3999999999999999, 0.4999999999999999 }, DataType::FLOAT32); - - ASSERT_TRUE(update.equalsTo(updateExp0)); - ASSERT_TRUE(initU.equalsTo(stateU)); - ASSERT_TRUE(initM.equalsTo(stateM0)); - - status = op.execute({ &grad, &initU, &initM }, { &update, &initU, &initM }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); - ASSERT_EQ(ND4J_STATUS_OK, status); - - NDArray updateExp1('c', { 1, 5 }, { 0.0019, 0.0019, 0.0019, 0.0019, 0.0019 }, DataType::FLOAT32); - NDArray stateM1('c', { 1, 5 }, { 0.18999999999999995, 0.3799999999999999, 0.5699999999999998, 0.7599999999999998, 0.9499999999999997 }, DataType::FLOAT32); - ASSERT_TRUE(update.equalsTo(updateExp1)); - ASSERT_TRUE(initU.equalsTo(stateU)); - ASSERT_TRUE(initM.equalsTo(stateM1)); - - status = op.execute({ &grad, &initU, &initM }, { &update, &initU, &initM }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); - ASSERT_EQ(ND4J_STATUS_OK, status); - - NDArray updateExp2('c', { 1, 5 }, { 0.00271, 0.00271, 0.00271, 0.00271, 0.00271 }, DataType::FLOAT32); - NDArray stateM2('c', { 1, 5 }, { 0.2709999999999999, 0.5419999999999998, 0.8129999999999998, 1.0839999999999996, 1.3549999999999995 }, DataType::FLOAT32); - - ASSERT_TRUE(update.equalsTo(updateExp2)); - ASSERT_TRUE(initU.equalsTo(stateU)); - ASSERT_TRUE(initM.equalsTo(stateM2)); + NDArray grad('c', {1, 5}, {1, 2, 3, 4, 5}, DataType::FLOAT32); + NDArray initU('c', {1, 5}, {0.0, 0.0, 0.0, 0.0, 0.0}, DataType::FLOAT32); + NDArray initM('c', {1, 5}, {0.0, 0.0, 0.0, 0.0, 0.0}, DataType::FLOAT32); + + NDArray update('c', {1, 5}, DataType::FLOAT32); + + sd::ops::ada_max_updater op; + + Nd4jStatus status = + op.execute({&grad, &initU, &initM}, {&update, &initU, &initM}, + {0.001f, 0.9f, 0.999f, 1.0e-8}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp0('c', {1, 5}, {0.001, 0.001, 0.001, 0.001, 0.001}, + DataType::FLOAT32); + NDArray stateU('c', {1, 5}, {1, 2, 3, 4, 5}, DataType::FLOAT32); + NDArray stateM0('c', {1, 5}, + {0.09999999999999998, 0.19999999999999996, + 0.29999999999999993, 0.3999999999999999, 0.4999999999999999}, + DataType::FLOAT32); + + ASSERT_TRUE(update.equalsTo(updateExp0)); + ASSERT_TRUE(initU.equalsTo(stateU)); + ASSERT_TRUE(initM.equalsTo(stateM0)); + + status = op.execute({&grad, &initU, &initM}, {&update, &initU, &initM}, + {0.001f, 0.9f, 0.999f, 1.0e-8}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp1('c', {1, 5}, {0.0019, 0.0019, 0.0019, 0.0019, 0.0019}, + DataType::FLOAT32); + NDArray stateM1('c', {1, 5}, + {0.18999999999999995, 0.3799999999999999, 0.5699999999999998, + 0.7599999999999998, 0.9499999999999997}, + DataType::FLOAT32); + ASSERT_TRUE(update.equalsTo(updateExp1)); + ASSERT_TRUE(initU.equalsTo(stateU)); + ASSERT_TRUE(initM.equalsTo(stateM1)); + + status = op.execute({&grad, &initU, &initM}, {&update, &initU, &initM}, + {0.001f, 0.9f, 0.999f, 1.0e-8}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp2('c', {1, 5}, {0.00271, 0.00271, 0.00271, 0.00271, 0.00271}, + DataType::FLOAT32); + NDArray stateM2('c', {1, 5}, + {0.2709999999999999, 0.5419999999999998, 0.8129999999999998, + 1.0839999999999996, 1.3549999999999995}, + DataType::FLOAT32); + + ASSERT_TRUE(update.equalsTo(updateExp2)); + ASSERT_TRUE(initU.equalsTo(stateU)); + ASSERT_TRUE(initM.equalsTo(stateM2)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests18, TestUpdaterAdaMax2) { - - NDArray grad0('c', { 1, 5 }, { 0.05387359112501144, 0.9700437784194946, 0.8912011384963989, 0.8891847729682922, 0.18823780119419098 }, DataType::FLOAT32); - NDArray initU('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); - NDArray initM('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); - - auto lr = NDArrayFactory::create(0.001f); - auto beta1 = NDArrayFactory::create(0.9f); - auto beta2 = NDArrayFactory::create(0.999f); - auto epsilon = NDArrayFactory::create(1.0e-8); - - sd::ops::ada_max_updater op; - - Nd4jStatus status = op.execute({ &grad0, &initU, &initM, &lr, &beta1, &beta2, &epsilon }, { &grad0, &initU, &initM }, { }, { }); - ASSERT_EQ(ND4J_STATUS_OK, status); - - NDArray updateExp0('c', { 1, 5 }, { 0.001, 0.001, 0.001, 0.001, 0.001 }, DataType::FLOAT32); - NDArray stateU0('c', { 1, 5 }, { 0.05387359112501144, 0.9700437784194946, 0.8912011384963989, 0.8891847729682922, 0.18823780119419098 }, DataType::FLOAT32); - NDArray stateM0('c', { 1, 5 }, { 0.00538735911250114, 0.09700437784194944, 0.08912011384963987, 0.08891847729682921, 0.01882378011941909 }, DataType::FLOAT32); - - ASSERT_TRUE(grad0.equalsTo(updateExp0)); - ASSERT_TRUE(initU.equalsTo(stateU0)); - ASSERT_TRUE(initM.equalsTo(stateM0)); - - NDArray grad1('c', { 1, 5 }, { 0.6400517821311951, 0.3779360353946686, 0.35128724575042725, 0.6554615497589111, 0.8420050740242004 }, DataType::FLOAT32); - - status = op.execute({ &grad1, &initU, &initM, &lr, &beta1, &beta2, &epsilon }, { &grad1, &initU, &initM }, { }, { }); - ASSERT_EQ(ND4J_STATUS_OK, status); - - NDArray updateExp1('c', { 1, 5 }, { 0.00107575360832691, 0.00129089809294599, 0.00129546826560191, 0.00163878765669416, 0.00120120308808246 }, DataType::FLOAT32); - NDArray stateU1('c', { 1, 5 }, { 0.6400517821311951, 0.9690737346410752, 0.8903099373579025, 0.888295588195324, 0.8420050740242004 }, DataType::FLOAT32); - NDArray stateM1('c', { 1, 5 }, { 0.06885380141437052, 0.12509754359722136, 0.11533682703971859, 0.1455727845430374, 0.10114190950989721 }, DataType::FLOAT32); - - ASSERT_TRUE(grad1.equalsTo(updateExp1)); - ASSERT_TRUE(initU.equalsTo(stateU1)); - ASSERT_TRUE(initM.equalsTo(stateM1)); - - NDArray grad2('c', { 1, 5 }, { 0.5984494686126709, 0.05978915095329285, 0.5749519467353821, 0.2804091274738312, 0.0192152876406908 }, DataType::FLOAT32); - - status = op.execute({ &grad2, &initU, &initM, &lr, &beta1, &beta2, &epsilon }, { &grad2, &initU, &initM }, { }, { }); - ASSERT_EQ(ND4J_STATUS_OK, status); - - NDArray updateExp2('c', { 1, 5 }, { 0.00190508497658779, 0.00122473022928962, 0.00181352349370876, 0.00179237223044249, 0.00110500865710834 }, DataType::FLOAT32); - NDArray stateU2('c', { 1, 5 }, { 0.6394117303490638, 0.9681046609064341, 0.8894196274205446, 0.8874072926071286, 0.8411630689501762 }, DataType::FLOAT32); - NDArray stateM2('c', { 1, 5 }, { 0.12181336813420054, 0.11856670433282851, 0.16129833900928492, 0.15905641883611676, 0.09294924732297657 }, DataType::FLOAT32); - - ASSERT_TRUE(grad2.equalsTo(updateExp2)); - ASSERT_TRUE(initU.equalsTo(stateU2)); - ASSERT_TRUE(initM.equalsTo(stateM2)); + NDArray grad0('c', {1, 5}, + {0.05387359112501144, 0.9700437784194946, 0.8912011384963989, + 0.8891847729682922, 0.18823780119419098}, + DataType::FLOAT32); + NDArray initU('c', {1, 5}, {0.0, 0.0, 0.0, 0.0, 0.0}, DataType::FLOAT32); + NDArray initM('c', {1, 5}, {0.0, 0.0, 0.0, 0.0, 0.0}, DataType::FLOAT32); + + auto lr = NDArrayFactory::create(0.001f); + auto beta1 = NDArrayFactory::create(0.9f); + auto beta2 = NDArrayFactory::create(0.999f); + auto epsilon = NDArrayFactory::create(1.0e-8); + + sd::ops::ada_max_updater op; + + Nd4jStatus status = + op.execute({&grad0, &initU, &initM, &lr, &beta1, &beta2, &epsilon}, + {&grad0, &initU, &initM}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp0('c', {1, 5}, {0.001, 0.001, 0.001, 0.001, 0.001}, + DataType::FLOAT32); + NDArray stateU0('c', {1, 5}, + {0.05387359112501144, 0.9700437784194946, 0.8912011384963989, + 0.8891847729682922, 0.18823780119419098}, + DataType::FLOAT32); + NDArray stateM0( + 'c', {1, 5}, + {0.00538735911250114, 0.09700437784194944, 0.08912011384963987, + 0.08891847729682921, 0.01882378011941909}, + DataType::FLOAT32); + + ASSERT_TRUE(grad0.equalsTo(updateExp0)); + ASSERT_TRUE(initU.equalsTo(stateU0)); + ASSERT_TRUE(initM.equalsTo(stateM0)); + + NDArray grad1('c', {1, 5}, + {0.6400517821311951, 0.3779360353946686, 0.35128724575042725, + 0.6554615497589111, 0.8420050740242004}, + DataType::FLOAT32); + + status = op.execute({&grad1, &initU, &initM, &lr, &beta1, &beta2, &epsilon}, + {&grad1, &initU, &initM}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp1( + 'c', {1, 5}, + {0.00107575360832691, 0.00129089809294599, 0.00129546826560191, + 0.00163878765669416, 0.00120120308808246}, + DataType::FLOAT32); + NDArray stateU1('c', {1, 5}, + {0.6400517821311951, 0.9690737346410752, 0.8903099373579025, + 0.888295588195324, 0.8420050740242004}, + DataType::FLOAT32); + NDArray stateM1( + 'c', {1, 5}, + {0.06885380141437052, 0.12509754359722136, 0.11533682703971859, + 0.1455727845430374, 0.10114190950989721}, + DataType::FLOAT32); + + ASSERT_TRUE(grad1.equalsTo(updateExp1)); + ASSERT_TRUE(initU.equalsTo(stateU1)); + ASSERT_TRUE(initM.equalsTo(stateM1)); + + NDArray grad2('c', {1, 5}, + {0.5984494686126709, 0.05978915095329285, 0.5749519467353821, + 0.2804091274738312, 0.0192152876406908}, + DataType::FLOAT32); + + status = op.execute({&grad2, &initU, &initM, &lr, &beta1, &beta2, &epsilon}, + {&grad2, &initU, &initM}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp2( + 'c', {1, 5}, + {0.00190508497658779, 0.00122473022928962, 0.00181352349370876, + 0.00179237223044249, 0.00110500865710834}, + DataType::FLOAT32); + NDArray stateU2('c', {1, 5}, + {0.6394117303490638, 0.9681046609064341, 0.8894196274205446, + 0.8874072926071286, 0.8411630689501762}, + DataType::FLOAT32); + NDArray stateM2( + 'c', {1, 5}, + {0.12181336813420054, 0.11856670433282851, 0.16129833900928492, + 0.15905641883611676, 0.09294924732297657}, + DataType::FLOAT32); + + ASSERT_TRUE(grad2.equalsTo(updateExp2)); + ASSERT_TRUE(initU.equalsTo(stateU2)); + ASSERT_TRUE(initM.equalsTo(stateM2)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests18, TestUpdaterAdaMax3) { - - NDArray gradC('c', { 1, 5 }, { 1, 2, 3, 4, 5 }, DataType::FLOAT32); - NDArray initVC('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); - NDArray initMC('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); - - NDArray grad('f', { 1, 5 }, DataType::FLOAT32); - NDArray initV('f', { 1, 5 }, DataType::FLOAT32); - NDArray initM('f', { 1, 5 }, DataType::FLOAT32); - - grad.assign(gradC); - initV.assign(initVC); - initM.assign(initMC); - - sd::ops::ada_max_updater op; - auto results = op.evaluate({ &grad, &initV, &initM }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); - - NDArray updateC('c', { 1, 5 }, { 0.001, 0.001, 0.001, 0.001, 0.001 }, DataType::FLOAT32); - NDArray update('f', { 1, 5 }, DataType::FLOAT32); - - NDArray stateV0C('c', { 1, 5 }, { 1,2,3,4,5 }, DataType::FLOAT32); - NDArray stateV('f', { 1, 5 }, DataType::FLOAT32); - - NDArray stateM0C('c', { 1, 5 }, { 0.09999999999999998, 0.19999999999999996, 0.29999999999999993, 0.3999999999999999, 0.4999999999999999 }, DataType::FLOAT32); - NDArray stateM('f', { 1, 5 }, DataType::FLOAT32); - - update.assign(updateC); - stateV.assign(stateV0C); - stateM.assign(stateM0C); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - ASSERT_TRUE(update.isSameShape(results.at(0))); - ASSERT_TRUE(update.equalsTo(results.at(0))); - ASSERT_TRUE(stateV.isSameShape(results.at(1))); - ASSERT_TRUE(stateV.equalsTo(results.at(1))); - ASSERT_TRUE(stateM.isSameShape(results.at(2))); - ASSERT_TRUE(stateM.equalsTo(results.at(2))); - - results = op.evaluate({ &grad, &stateV, &stateM }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - NDArray update1C('c', { 1, 5 }, { 0.0019, 0.0019, 0.0019, 0.0019, 0.0019 }, DataType::FLOAT32); - NDArray stateM1C('c', { 1, 5 }, { 0.18999999999999995, 0.3799999999999999, 0.5699999999999998, 0.7599999999999998, 0.9499999999999997 }, DataType::FLOAT32); - - update.assign(update1C); - stateM.assign(stateM1C); - - ASSERT_TRUE(update.isSameShape(results.at(0))); - ASSERT_TRUE(update.equalsTo(results.at(0))); - ASSERT_TRUE(stateV.isSameShape(results.at(1))); - ASSERT_TRUE(stateV.equalsTo(results.at(1))); - ASSERT_TRUE(stateM.isSameShape(results.at(2))); - ASSERT_TRUE(stateM.equalsTo(results.at(2))); - - - results = op.evaluate({ &grad, &stateV, &stateM }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - NDArray update2C('c', { 1, 5 }, { 0.00271, 0.00271, 0.00271, 0.00271, 0.00271 }, DataType::FLOAT32); - NDArray stateM2C('c', { 1, 5 }, { 0.2709999999999999, 0.5419999999999998, 0.8129999999999998, 1.0839999999999996, 1.3549999999999995 }, DataType::FLOAT32); - - update.assign(update2C); - stateM.assign(stateM2C); - - ASSERT_TRUE(update.isSameShape(results.at(0))); - ASSERT_TRUE(update.equalsTo(results.at(0))); - ASSERT_TRUE(stateV.isSameShape(results.at(1))); - ASSERT_TRUE(stateV.equalsTo(results.at(1))); - ASSERT_TRUE(stateM.isSameShape(results.at(2))); - ASSERT_TRUE(stateM.equalsTo(results.at(2))); + NDArray gradC('c', {1, 5}, {1, 2, 3, 4, 5}, DataType::FLOAT32); + NDArray initVC('c', {1, 5}, {0.0, 0.0, 0.0, 0.0, 0.0}, DataType::FLOAT32); + NDArray initMC('c', {1, 5}, {0.0, 0.0, 0.0, 0.0, 0.0}, DataType::FLOAT32); + + NDArray grad('f', {1, 5}, DataType::FLOAT32); + NDArray initV('f', {1, 5}, DataType::FLOAT32); + NDArray initM('f', {1, 5}, DataType::FLOAT32); + + grad.assign(gradC); + initV.assign(initVC); + initM.assign(initMC); + + sd::ops::ada_max_updater op; + auto results = + op.evaluate({&grad, &initV, &initM}, {0.001f, 0.9f, 0.999f, 1.0e-8}, {}); + + NDArray updateC('c', {1, 5}, {0.001, 0.001, 0.001, 0.001, 0.001}, + DataType::FLOAT32); + NDArray update('f', {1, 5}, DataType::FLOAT32); + + NDArray stateV0C('c', {1, 5}, {1, 2, 3, 4, 5}, DataType::FLOAT32); + NDArray stateV('f', {1, 5}, DataType::FLOAT32); + + NDArray stateM0C( + 'c', {1, 5}, + {0.09999999999999998, 0.19999999999999996, 0.29999999999999993, + 0.3999999999999999, 0.4999999999999999}, + DataType::FLOAT32); + NDArray stateM('f', {1, 5}, DataType::FLOAT32); + + update.assign(updateC); + stateV.assign(stateV0C); + stateM.assign(stateM0C); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); + ASSERT_TRUE(stateV.isSameShape(results.at(1))); + ASSERT_TRUE(stateV.equalsTo(results.at(1))); + ASSERT_TRUE(stateM.isSameShape(results.at(2))); + ASSERT_TRUE(stateM.equalsTo(results.at(2))); + + results = op.evaluate({&grad, &stateV, &stateM}, + {0.001f, 0.9f, 0.999f, 1.0e-8}, {}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + NDArray update1C('c', {1, 5}, {0.0019, 0.0019, 0.0019, 0.0019, 0.0019}, + DataType::FLOAT32); + NDArray stateM1C('c', {1, 5}, + {0.18999999999999995, 0.3799999999999999, 0.5699999999999998, + 0.7599999999999998, 0.9499999999999997}, + DataType::FLOAT32); + + update.assign(update1C); + stateM.assign(stateM1C); + + ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); + ASSERT_TRUE(stateV.isSameShape(results.at(1))); + ASSERT_TRUE(stateV.equalsTo(results.at(1))); + ASSERT_TRUE(stateM.isSameShape(results.at(2))); + ASSERT_TRUE(stateM.equalsTo(results.at(2))); + + results = op.evaluate({&grad, &stateV, &stateM}, + {0.001f, 0.9f, 0.999f, 1.0e-8}, {}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + NDArray update2C('c', {1, 5}, {0.00271, 0.00271, 0.00271, 0.00271, 0.00271}, + DataType::FLOAT32); + NDArray stateM2C('c', {1, 5}, + {0.2709999999999999, 0.5419999999999998, 0.8129999999999998, + 1.0839999999999996, 1.3549999999999995}, + DataType::FLOAT32); + + update.assign(update2C); + stateM.assign(stateM2C); + + ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); + ASSERT_TRUE(stateV.isSameShape(results.at(1))); + ASSERT_TRUE(stateV.equalsTo(results.at(1))); + ASSERT_TRUE(stateM.isSameShape(results.at(2))); + ASSERT_TRUE(stateM.equalsTo(results.at(2))); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests18, TestUpdaterAdam1) { - - NDArray grad('c', { 1, 5 }, { 1,2,3,4,5 }, DataType::FLOAT32); - NDArray initU('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); - NDArray initM('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); - - NDArray update('c', { 1, 5 }, DataType::FLOAT32); - - sd::ops::adam_updater op; - - Nd4jStatus status = op.execute({ &grad, &initU, &initM }, { &update, &initU, &initM }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); - ASSERT_EQ(ND4J_STATUS_OK, status); - - NDArray updateExp0('c', { 1, 5 }, { 0.00099999968377233, 0.00099999984188614, 0.00099999989459076, 0.00099999992094306, 0.00099999993675445 }, DataType::FLOAT32); - NDArray stateV('c', { 1, 5 }, { 0.001, 0.004, 0.00900000000000001, 0.01600000000000001, 0.02500000000000002 }, DataType::FLOAT32); - NDArray stateM0('c', { 1, 5 }, { 0.09999999999999998, 0.19999999999999996, 0.29999999999999993, 0.3999999999999999, 0.4999999999999999 }, DataType::FLOAT32); - - ASSERT_TRUE(update.equalsTo(updateExp0)); - ASSERT_TRUE(initU.equalsTo(stateV)); - ASSERT_TRUE(initM.equalsTo(stateM0)); - - status = op.execute({ &grad, &initU, &initM }, { &update, &initU, &initM }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); - ASSERT_EQ(ND4J_STATUS_OK, status); - - NDArray updateExp1('c', { 1, 5 }, { 0.00134383858541481, 0.00134383873569809, 0.00134383878579252, 0.00134383881083974, 0.00134383882586807 }, DataType::FLOAT32); - NDArray stateV1('c', { 1, 5 }, { 0.001999, 0.00799600000000001, 0.01799100000000001, 0.03198400000000003, 0.04997500000000005 }, DataType::FLOAT32); - NDArray stateM1('c', { 1, 5 }, { 0.18999999999999995, 0.3799999999999999, 0.5699999999999998, 0.7599999999999998, 0.9499999999999997 }, DataType::FLOAT32); - - ASSERT_TRUE(update.equalsTo(updateExp1)); - ASSERT_TRUE(initU.equalsTo(stateV1)); - ASSERT_TRUE(initM.equalsTo(stateM1)); - - status = op.execute({ &grad, &initU, &initM }, { &update, &initU, &initM }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); - ASSERT_EQ(ND4J_STATUS_OK, status); - - NDArray updateExp2('c', { 1, 5 }, { 0.00156540157923389, 0.00156540172220632, 0.0015654017698638, 0.00156540179369254, 0.00156540180798979 }, DataType::FLOAT32); - NDArray stateV2('c', { 1, 5 }, { 0.002997001, 0.01198800400000001, 0.02697300900000002, 0.04795201600000004, 0.07492502500000006 }, DataType::FLOAT32); - NDArray stateM2('c', { 1, 5 }, { 0.2709999999999999, 0.5419999999999998, 0.8129999999999998, 1.0839999999999996, 1.3549999999999995 }, DataType::FLOAT32); - - ASSERT_TRUE(update.equalsTo(updateExp2)); - ASSERT_TRUE(initU.equalsTo(stateV2)); - ASSERT_TRUE(initM.equalsTo(stateM2)); + NDArray grad('c', {1, 5}, {1, 2, 3, 4, 5}, DataType::FLOAT32); + NDArray initU('c', {1, 5}, {0.0, 0.0, 0.0, 0.0, 0.0}, DataType::FLOAT32); + NDArray initM('c', {1, 5}, {0.0, 0.0, 0.0, 0.0, 0.0}, DataType::FLOAT32); + + NDArray update('c', {1, 5}, DataType::FLOAT32); + + sd::ops::adam_updater op; + + Nd4jStatus status = + op.execute({&grad, &initU, &initM}, {&update, &initU, &initM}, + {0.001f, 0.9f, 0.999f, 1.0e-8}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp0( + 'c', {1, 5}, + {0.00099999968377233, 0.00099999984188614, 0.00099999989459076, + 0.00099999992094306, 0.00099999993675445}, + DataType::FLOAT32); + NDArray stateV('c', {1, 5}, + {0.001, 0.004, 0.00900000000000001, 0.01600000000000001, + 0.02500000000000002}, + DataType::FLOAT32); + NDArray stateM0('c', {1, 5}, + {0.09999999999999998, 0.19999999999999996, + 0.29999999999999993, 0.3999999999999999, 0.4999999999999999}, + DataType::FLOAT32); + + ASSERT_TRUE(update.equalsTo(updateExp0)); + ASSERT_TRUE(initU.equalsTo(stateV)); + ASSERT_TRUE(initM.equalsTo(stateM0)); + + status = op.execute({&grad, &initU, &initM}, {&update, &initU, &initM}, + {0.001f, 0.9f, 0.999f, 1.0e-8}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp1( + 'c', {1, 5}, + {0.00134383858541481, 0.00134383873569809, 0.00134383878579252, + 0.00134383881083974, 0.00134383882586807}, + DataType::FLOAT32); + NDArray stateV1('c', {1, 5}, + {0.001999, 0.00799600000000001, 0.01799100000000001, + 0.03198400000000003, 0.04997500000000005}, + DataType::FLOAT32); + NDArray stateM1('c', {1, 5}, + {0.18999999999999995, 0.3799999999999999, 0.5699999999999998, + 0.7599999999999998, 0.9499999999999997}, + DataType::FLOAT32); + + ASSERT_TRUE(update.equalsTo(updateExp1)); + ASSERT_TRUE(initU.equalsTo(stateV1)); + ASSERT_TRUE(initM.equalsTo(stateM1)); + + status = op.execute({&grad, &initU, &initM}, {&update, &initU, &initM}, + {0.001f, 0.9f, 0.999f, 1.0e-8}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp2( + 'c', {1, 5}, + {0.00156540157923389, 0.00156540172220632, 0.0015654017698638, + 0.00156540179369254, 0.00156540180798979}, + DataType::FLOAT32); + NDArray stateV2('c', {1, 5}, + {0.002997001, 0.01198800400000001, 0.02697300900000002, + 0.04795201600000004, 0.07492502500000006}, + DataType::FLOAT32); + NDArray stateM2('c', {1, 5}, + {0.2709999999999999, 0.5419999999999998, 0.8129999999999998, + 1.0839999999999996, 1.3549999999999995}, + DataType::FLOAT32); + + ASSERT_TRUE(update.equalsTo(updateExp2)); + ASSERT_TRUE(initU.equalsTo(stateV2)); + ASSERT_TRUE(initM.equalsTo(stateM2)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests18, TestUpdaterAdam2) { - - NDArray grad0('c', { 1, 5 }, { 0.7124611735343933, 0.7283763289451599, 0.8196553587913513, 0.9501070976257324, 0.2654055953025818 }, DataType::FLOAT32); - NDArray initU('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); - NDArray initM('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); - - auto lr = NDArrayFactory::create(0.001f); - auto beta1 = NDArrayFactory::create(0.9f); - auto beta2 = NDArrayFactory::create(0.999f); - auto epsilon = NDArrayFactory::create(1.0e-8); - - sd::ops::adam_updater op; - - Nd4jStatus status = op.execute({ &grad0, &initU, &initM, &lr, &beta1, &beta2, &epsilon }, { &grad0, &initU, &initM }, { }, { }); - ASSERT_EQ(ND4J_STATUS_OK, status); - - NDArray updateExp0('c', { 1, 5 }, { 0.00099999955614757, 0.00099999956584582, 0.00099999961419438, 0.0009999996671663, 0.00099999880851273 }, DataType::FLOAT32); - NDArray stateU0('c', { 1, 5 }, { 0.00050760092379401, 0.00053053207656763, 0.00067183490719538, 0.00090270349695879, 0.00007044013001792 }, DataType::FLOAT32); - NDArray stateM0('c', { 1, 5 }, { 0.07124611735343932, 0.07283763289451597, 0.08196553587913512, 0.09501070976257323, 0.02654055953025817 }, DataType::FLOAT32); - - ASSERT_TRUE(grad0.equalsTo(updateExp0)); - ASSERT_TRUE(initU.equalsTo(stateU0)); - ASSERT_TRUE(initM.equalsTo(stateM0)); - - NDArray grad1('c', { 1, 5 }, { 0.4374369978904724, 0.11488933861255646, 0.6765823364257812, 0.7659900188446045, 0.04410457238554955 }, DataType::FLOAT32); - - status = op.execute({ &grad1, &initU, &initM, &lr, &beta1, &beta2, &epsilon }, { &grad1, &initU, &initM }, { }, { }); - ASSERT_EQ(ND4J_STATUS_OK, status); - - NDArray updateExp1('c', { 1, 5 }, { 0.00129067017716555, 0.00104532555849556, 0.00133106720937621, 0.00132869584719374, 0.00105226561254395 }, DataType::FLOAT32); - NDArray stateU1('c', { 1, 5 }, { 0.00069844444999364, 0.00054320110461789, 0.00112892673025155, 0.00148854150243139, 0.00007231490319321 }, DataType::FLOAT32); - NDArray stateM1('c', { 1, 5 }, { 0.10786520540714262, 0.07704280346632002, 0.14142721593379973, 0.16210864067077635, 0.02829696081578731 }, DataType::FLOAT32); - - ASSERT_TRUE(grad1.equalsTo(updateExp1)); - ASSERT_TRUE(initU.equalsTo(stateU1)); - ASSERT_TRUE(initM.equalsTo(stateM1)); - - NDArray grad2('c', { 1, 5 }, { 0.496029257774353, 0.11621368676424026, 0.9112075567245483, 0.5717480182647705, 0.5975669026374817 }, DataType::FLOAT32); - - status = op.execute({ &grad2, &initU, &initM, &lr, &beta1, &beta2, &epsilon }, { &grad2, &initU, &initM }, { }, { }); - ASSERT_EQ(ND4J_STATUS_OK, status); - - NDArray updateExp2('c', { 1, 5 }, { 0.00150986322036664, 0.00108559662275258, 0.00156079502787382, 0.00150778241516558, 0.00130066803775601 }, DataType::FLOAT32); - NDArray stateU2('c', { 1, 5 }, { 0.00094379103011182, 0.00055616352450461, 0.00195809701495322, 0.00181394875731865, 0.00042932879141777 }, DataType::FLOAT32); - NDArray stateM2('c', { 1, 5 }, { 0.14668161064386365, 0.08095989179611204, 0.21840525001287456, 0.20307257843017573, 0.08522395499795674 }, DataType::FLOAT32); - - ASSERT_TRUE(grad2.equalsTo(updateExp2)); - ASSERT_TRUE(initU.equalsTo(stateU2)); - ASSERT_TRUE(initM.equalsTo(stateM2)); + NDArray grad0('c', {1, 5}, + {0.7124611735343933, 0.7283763289451599, 0.8196553587913513, + 0.9501070976257324, 0.2654055953025818}, + DataType::FLOAT32); + NDArray initU('c', {1, 5}, {0.0, 0.0, 0.0, 0.0, 0.0}, DataType::FLOAT32); + NDArray initM('c', {1, 5}, {0.0, 0.0, 0.0, 0.0, 0.0}, DataType::FLOAT32); + + auto lr = NDArrayFactory::create(0.001f); + auto beta1 = NDArrayFactory::create(0.9f); + auto beta2 = NDArrayFactory::create(0.999f); + auto epsilon = NDArrayFactory::create(1.0e-8); + + sd::ops::adam_updater op; + + Nd4jStatus status = + op.execute({&grad0, &initU, &initM, &lr, &beta1, &beta2, &epsilon}, + {&grad0, &initU, &initM}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp0( + 'c', {1, 5}, + {0.00099999955614757, 0.00099999956584582, 0.00099999961419438, + 0.0009999996671663, 0.00099999880851273}, + DataType::FLOAT32); + NDArray stateU0( + 'c', {1, 5}, + {0.00050760092379401, 0.00053053207656763, 0.00067183490719538, + 0.00090270349695879, 0.00007044013001792}, + DataType::FLOAT32); + NDArray stateM0( + 'c', {1, 5}, + {0.07124611735343932, 0.07283763289451597, 0.08196553587913512, + 0.09501070976257323, 0.02654055953025817}, + DataType::FLOAT32); + + ASSERT_TRUE(grad0.equalsTo(updateExp0)); + ASSERT_TRUE(initU.equalsTo(stateU0)); + ASSERT_TRUE(initM.equalsTo(stateM0)); + + NDArray grad1('c', {1, 5}, + {0.4374369978904724, 0.11488933861255646, 0.6765823364257812, + 0.7659900188446045, 0.04410457238554955}, + DataType::FLOAT32); + + status = op.execute({&grad1, &initU, &initM, &lr, &beta1, &beta2, &epsilon}, + {&grad1, &initU, &initM}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp1( + 'c', {1, 5}, + {0.00129067017716555, 0.00104532555849556, 0.00133106720937621, + 0.00132869584719374, 0.00105226561254395}, + DataType::FLOAT32); + NDArray stateU1( + 'c', {1, 5}, + {0.00069844444999364, 0.00054320110461789, 0.00112892673025155, + 0.00148854150243139, 0.00007231490319321}, + DataType::FLOAT32); + NDArray stateM1( + 'c', {1, 5}, + {0.10786520540714262, 0.07704280346632002, 0.14142721593379973, + 0.16210864067077635, 0.02829696081578731}, + DataType::FLOAT32); + + ASSERT_TRUE(grad1.equalsTo(updateExp1)); + ASSERT_TRUE(initU.equalsTo(stateU1)); + ASSERT_TRUE(initM.equalsTo(stateM1)); + + NDArray grad2('c', {1, 5}, + {0.496029257774353, 0.11621368676424026, 0.9112075567245483, + 0.5717480182647705, 0.5975669026374817}, + DataType::FLOAT32); + + status = op.execute({&grad2, &initU, &initM, &lr, &beta1, &beta2, &epsilon}, + {&grad2, &initU, &initM}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp2( + 'c', {1, 5}, + {0.00150986322036664, 0.00108559662275258, 0.00156079502787382, + 0.00150778241516558, 0.00130066803775601}, + DataType::FLOAT32); + NDArray stateU2( + 'c', {1, 5}, + {0.00094379103011182, 0.00055616352450461, 0.00195809701495322, + 0.00181394875731865, 0.00042932879141777}, + DataType::FLOAT32); + NDArray stateM2( + 'c', {1, 5}, + {0.14668161064386365, 0.08095989179611204, 0.21840525001287456, + 0.20307257843017573, 0.08522395499795674}, + DataType::FLOAT32); + + ASSERT_TRUE(grad2.equalsTo(updateExp2)); + ASSERT_TRUE(initU.equalsTo(stateU2)); + ASSERT_TRUE(initM.equalsTo(stateM2)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests18, TestUpdaterAdam3) { - - NDArray gradC('c', { 1, 5 }, { 1, 2, 3, 4, 5 }, DataType::FLOAT32); - NDArray initVC('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); - NDArray initMC('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); - - NDArray grad('f', { 1, 5 }, DataType::FLOAT32); - NDArray initV('f', { 1, 5 }, DataType::FLOAT32); - NDArray initM('f', { 1, 5 }, DataType::FLOAT32); - - grad.assign(gradC); - initV.assign(initVC); - initM.assign(initMC); - - sd::ops::adam_updater op; - auto results = op.evaluate({ &grad, &initV, &initM }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); - - NDArray updateC('c', { 1, 5 }, { 0.00099999968377233, 0.00099999984188614, 0.00099999989459076, 0.00099999992094306, 0.00099999993675445 }, DataType::FLOAT32); - NDArray update('f', { 1, 5 }, DataType::FLOAT32); - - NDArray stateV0C('c', { 1, 5 }, { 0.001, 0.004, 0.00900000000000001, 0.01600000000000001, 0.02500000000000002 }, DataType::FLOAT32); - NDArray stateV('f', { 1, 5 }, DataType::FLOAT32); - - NDArray stateM0C('c', { 1, 5 }, { 0.09999999999999998, 0.19999999999999996, 0.29999999999999993, 0.3999999999999999, 0.4999999999999999 }, DataType::FLOAT32); - NDArray stateM('f', { 1, 5 }, DataType::FLOAT32); - - update.assign(updateC); - stateV.assign(stateV0C); - stateM.assign(stateM0C); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - ASSERT_TRUE(update.isSameShape(results.at(0))); - ASSERT_TRUE(update.equalsTo(results.at(0))); - ASSERT_TRUE(stateV.isSameShape(results.at(1))); - ASSERT_TRUE(stateV.equalsTo(results.at(1))); - ASSERT_TRUE(stateM.isSameShape(results.at(2))); - ASSERT_TRUE(stateM.equalsTo(results.at(2))); - - results = op.evaluate({ &grad, &stateV, &stateM }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - NDArray update1C('c', { 1, 5 }, { 0.00134383858541481, 0.00134383873569809, 0.00134383878579252, 0.00134383881083974, 0.00134383882586807 }, DataType::FLOAT32); - NDArray stateV1C('c', { 1, 5 }, { 0.001999, 0.00799600000000001, 0.01799100000000001, 0.03198400000000003, 0.04997500000000005 }, DataType::FLOAT32); - NDArray stateM1C('c', { 1, 5 }, { 0.18999999999999995, 0.3799999999999999, 0.5699999999999998, 0.7599999999999998, 0.9499999999999997 }, DataType::FLOAT32); - - update.assign(update1C); - stateV.assign(stateV1C); - stateM.assign(stateM1C); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - ASSERT_TRUE(update.isSameShape(results.at(0))); - ASSERT_TRUE(update.equalsTo(results.at(0))); - ASSERT_TRUE(stateV.isSameShape(results.at(1))); - ASSERT_TRUE(stateV.equalsTo(results.at(1))); - ASSERT_TRUE(stateM.isSameShape(results.at(2))); - ASSERT_TRUE(stateM.equalsTo(results.at(2))); - - results = op.evaluate({ &grad, &stateV, &stateM }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); - - NDArray update2C('c', { 1, 5 }, { 0.00156540157923389, 0.00156540172220632, 0.0015654017698638, 0.00156540179369254, 0.00156540180798979 }, DataType::FLOAT32); - NDArray stateV2C('c', { 1, 5 }, { 0.002997001, 0.01198800400000001, 0.02697300900000002, 0.04795201600000004, 0.07492502500000006 }, DataType::FLOAT32); - NDArray stateM2C('c', { 1, 5 }, { 0.2709999999999999, 0.5419999999999998, 0.8129999999999998, 1.0839999999999996, 1.3549999999999995 }, DataType::FLOAT32); - - update.assign(update2C); - stateV.assign(stateV2C); - stateM.assign(stateM2C); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - ASSERT_TRUE(update.isSameShape(results.at(0))); - ASSERT_TRUE(update.equalsTo(results.at(0))); - ASSERT_TRUE(stateV.isSameShape(results.at(1))); - ASSERT_TRUE(stateV.equalsTo(results.at(1))); - ASSERT_TRUE(stateM.isSameShape(results.at(2))); - ASSERT_TRUE(stateM.equalsTo(results.at(2))); + NDArray gradC('c', {1, 5}, {1, 2, 3, 4, 5}, DataType::FLOAT32); + NDArray initVC('c', {1, 5}, {0.0, 0.0, 0.0, 0.0, 0.0}, DataType::FLOAT32); + NDArray initMC('c', {1, 5}, {0.0, 0.0, 0.0, 0.0, 0.0}, DataType::FLOAT32); + + NDArray grad('f', {1, 5}, DataType::FLOAT32); + NDArray initV('f', {1, 5}, DataType::FLOAT32); + NDArray initM('f', {1, 5}, DataType::FLOAT32); + + grad.assign(gradC); + initV.assign(initVC); + initM.assign(initMC); + + sd::ops::adam_updater op; + auto results = + op.evaluate({&grad, &initV, &initM}, {0.001f, 0.9f, 0.999f, 1.0e-8}, {}); + + NDArray updateC( + 'c', {1, 5}, + {0.00099999968377233, 0.00099999984188614, 0.00099999989459076, + 0.00099999992094306, 0.00099999993675445}, + DataType::FLOAT32); + NDArray update('f', {1, 5}, DataType::FLOAT32); + + NDArray stateV0C('c', {1, 5}, + {0.001, 0.004, 0.00900000000000001, 0.01600000000000001, + 0.02500000000000002}, + DataType::FLOAT32); + NDArray stateV('f', {1, 5}, DataType::FLOAT32); + + NDArray stateM0C( + 'c', {1, 5}, + {0.09999999999999998, 0.19999999999999996, 0.29999999999999993, + 0.3999999999999999, 0.4999999999999999}, + DataType::FLOAT32); + NDArray stateM('f', {1, 5}, DataType::FLOAT32); + + update.assign(updateC); + stateV.assign(stateV0C); + stateM.assign(stateM0C); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); + ASSERT_TRUE(stateV.isSameShape(results.at(1))); + ASSERT_TRUE(stateV.equalsTo(results.at(1))); + ASSERT_TRUE(stateM.isSameShape(results.at(2))); + ASSERT_TRUE(stateM.equalsTo(results.at(2))); + + results = op.evaluate({&grad, &stateV, &stateM}, + {0.001f, 0.9f, 0.999f, 1.0e-8}, {}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + NDArray update1C( + 'c', {1, 5}, + {0.00134383858541481, 0.00134383873569809, 0.00134383878579252, + 0.00134383881083974, 0.00134383882586807}, + DataType::FLOAT32); + NDArray stateV1C('c', {1, 5}, + {0.001999, 0.00799600000000001, 0.01799100000000001, + 0.03198400000000003, 0.04997500000000005}, + DataType::FLOAT32); + NDArray stateM1C('c', {1, 5}, + {0.18999999999999995, 0.3799999999999999, 0.5699999999999998, + 0.7599999999999998, 0.9499999999999997}, + DataType::FLOAT32); + + update.assign(update1C); + stateV.assign(stateV1C); + stateM.assign(stateM1C); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); + ASSERT_TRUE(stateV.isSameShape(results.at(1))); + ASSERT_TRUE(stateV.equalsTo(results.at(1))); + ASSERT_TRUE(stateM.isSameShape(results.at(2))); + ASSERT_TRUE(stateM.equalsTo(results.at(2))); + + results = op.evaluate({&grad, &stateV, &stateM}, + {0.001f, 0.9f, 0.999f, 1.0e-8}, {}); + + NDArray update2C( + 'c', {1, 5}, + {0.00156540157923389, 0.00156540172220632, 0.0015654017698638, + 0.00156540179369254, 0.00156540180798979}, + DataType::FLOAT32); + NDArray stateV2C('c', {1, 5}, + {0.002997001, 0.01198800400000001, 0.02697300900000002, + 0.04795201600000004, 0.07492502500000006}, + DataType::FLOAT32); + NDArray stateM2C('c', {1, 5}, + {0.2709999999999999, 0.5419999999999998, 0.8129999999999998, + 1.0839999999999996, 1.3549999999999995}, + DataType::FLOAT32); + + update.assign(update2C); + stateV.assign(stateV2C); + stateM.assign(stateM2C); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); + ASSERT_TRUE(stateV.isSameShape(results.at(1))); + ASSERT_TRUE(stateV.equalsTo(results.at(1))); + ASSERT_TRUE(stateM.isSameShape(results.at(2))); + ASSERT_TRUE(stateM.equalsTo(results.at(2))); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests18, TestUpdaterAdaDelta1) { - - NDArray grad('c', { 1, 5 }, { 1,2,3,4,5 }, DataType::FLOAT32); - NDArray initMsg('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); - NDArray initMsdx('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); - - NDArray update('c', { 1, 5 }, DataType::FLOAT32); - - sd::ops::ada_delta_updater op; - - Nd4jStatus status = op.execute({ &grad, &initMsg, &initMsdx }, { &update, &initMsg, &initMsdx }, { 0.95f, 1.0e-6 }, { }); - ASSERT_EQ(ND4J_STATUS_OK, status); - - NDArray updateExp0('c', { 1, 5 }, { 0.00447209123431084, 0.00447212477470162, 0.00447213098596791, 0.00447213315991723, 0.00447213416614627 }, DataType::FLOAT32); - NDArray stateMsg0('c', { 1, 5 }, { 0.05000000000000004, 0.20000000000000018, 0.4500000000000004, 0.8000000000000007, 1.250000000000001 }, DataType::FLOAT32); - NDArray stateMsdx0('c', { 1, 5 }, { 0.0000009999800004, 0.00000099999500002, 0.00000099999777778, 0.00000099999875, 0.0000009999992 }, DataType::FLOAT32); - - ASSERT_TRUE(update.equalsTo(updateExp0)); - ASSERT_TRUE(initMsg.equalsTo(stateMsg0)); - ASSERT_TRUE(initMsdx.equalsTo(stateMsdx0)); - - status = op.execute({ &grad, &initMsg, &initMsdx }, { &update, &initMsg, &initMsdx }, { 0.95f, 1.0e-6 }, { }); - ASSERT_EQ(ND4J_STATUS_OK, status); - - NDArray updateExp1('c', { 1, 5 }, { 0.0045290622655332, 0.00452909666868751, 0.00452910303972733, 0.00452910526959756, 0.00452910630171004 }, DataType::FLOAT32); - NDArray stateMsg1('c', { 1, 5 }, { 0.09750000000000009, 0.39000000000000035, 0.8775000000000008, 1.5600000000000014, 2.4375000000000018 }, DataType::FLOAT32); - NDArray stateMsdx1('c', { 1, 5 }, { 0.00000197560125063, 0.00000197563108174, 0.00000197563660612, 0.00000197563853966, 0.00000197563943461 }, DataType::FLOAT32); - - ASSERT_TRUE(update.equalsTo(updateExp1)); - ASSERT_TRUE(initMsg.equalsTo(stateMsg1)); - ASSERT_TRUE(initMsdx.equalsTo(stateMsdx1)); - - status = op.execute({ &grad, &initMsg, &initMsdx }, { &update, &initMsg, &initMsdx }, { 0.95f, 1.0e-6 }, { }); - ASSERT_EQ(ND4J_STATUS_OK, status); - - NDArray updateExp2('c', { 1, 5 }, { 0.00456759948242601, 0.00456763438748812, 0.00456764085147516, 0.00456764311387702, 0.004567644161047 }, DataType::FLOAT32); - NDArray stateMsg2('c', { 1, 5 }, { 0.1426250000000001, 0.5705000000000005, 1.2836250000000011, 2.282000000000002, 3.5656250000000025 }, DataType::FLOAT32); - NDArray stateMsdx2('c', { 1, 5 }, { 0.0000029199694397, 0.00000292001372254, 0.00000292002192321, 0.00000292002479346, 0.00000292002612198 }, DataType::FLOAT32); - - ASSERT_TRUE(update.equalsTo(updateExp2)); - ASSERT_TRUE(initMsg.equalsTo(stateMsg2)); - ASSERT_TRUE(initMsdx.equalsTo(stateMsdx2)); + NDArray grad('c', {1, 5}, {1, 2, 3, 4, 5}, DataType::FLOAT32); + NDArray initMsg('c', {1, 5}, {0.0, 0.0, 0.0, 0.0, 0.0}, DataType::FLOAT32); + NDArray initMsdx('c', {1, 5}, {0.0, 0.0, 0.0, 0.0, 0.0}, DataType::FLOAT32); + + NDArray update('c', {1, 5}, DataType::FLOAT32); + + sd::ops::ada_delta_updater op; + + Nd4jStatus status = + op.execute({&grad, &initMsg, &initMsdx}, {&update, &initMsg, &initMsdx}, + {0.95f, 1.0e-6}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp0( + 'c', {1, 5}, + {0.00447209123431084, 0.00447212477470162, 0.00447213098596791, + 0.00447213315991723, 0.00447213416614627}, + DataType::FLOAT32); + NDArray stateMsg0('c', {1, 5}, + {0.05000000000000004, 0.20000000000000018, + 0.4500000000000004, 0.8000000000000007, 1.250000000000001}, + DataType::FLOAT32); + NDArray stateMsdx0('c', {1, 5}, + {0.0000009999800004, 0.00000099999500002, + 0.00000099999777778, 0.00000099999875, 0.0000009999992}, + DataType::FLOAT32); + + ASSERT_TRUE(update.equalsTo(updateExp0)); + ASSERT_TRUE(initMsg.equalsTo(stateMsg0)); + ASSERT_TRUE(initMsdx.equalsTo(stateMsdx0)); + + status = op.execute({&grad, &initMsg, &initMsdx}, + {&update, &initMsg, &initMsdx}, {0.95f, 1.0e-6}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp1( + 'c', {1, 5}, + {0.0045290622655332, 0.00452909666868751, 0.00452910303972733, + 0.00452910526959756, 0.00452910630171004}, + DataType::FLOAT32); + NDArray stateMsg1( + 'c', {1, 5}, + {0.09750000000000009, 0.39000000000000035, 0.8775000000000008, + 1.5600000000000014, 2.4375000000000018}, + DataType::FLOAT32); + NDArray stateMsdx1( + 'c', {1, 5}, + {0.00000197560125063, 0.00000197563108174, 0.00000197563660612, + 0.00000197563853966, 0.00000197563943461}, + DataType::FLOAT32); + + ASSERT_TRUE(update.equalsTo(updateExp1)); + ASSERT_TRUE(initMsg.equalsTo(stateMsg1)); + ASSERT_TRUE(initMsdx.equalsTo(stateMsdx1)); + + status = op.execute({&grad, &initMsg, &initMsdx}, + {&update, &initMsg, &initMsdx}, {0.95f, 1.0e-6}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp2( + 'c', {1, 5}, + {0.00456759948242601, 0.00456763438748812, 0.00456764085147516, + 0.00456764311387702, 0.004567644161047}, + DataType::FLOAT32); + NDArray stateMsg2('c', {1, 5}, + {0.1426250000000001, 0.5705000000000005, 1.2836250000000011, + 2.282000000000002, 3.5656250000000025}, + DataType::FLOAT32); + NDArray stateMsdx2( + 'c', {1, 5}, + {0.0000029199694397, 0.00000292001372254, 0.00000292002192321, + 0.00000292002479346, 0.00000292002612198}, + DataType::FLOAT32); + + ASSERT_TRUE(update.equalsTo(updateExp2)); + ASSERT_TRUE(initMsg.equalsTo(stateMsg2)); + ASSERT_TRUE(initMsdx.equalsTo(stateMsdx2)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests18, TestUpdaterAdaDelta2) { - - NDArray grad0('c', { 1, 5 }, { 0.22060230374336243, 0.10593396425247192, 0.9027279019355774, 0.831809401512146, 0.2733047902584076 }, DataType::FLOAT32); - NDArray initMsg('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); - NDArray initMsdx('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); - - auto rho = NDArrayFactory::create(0.95f); - auto epsilon = NDArrayFactory::create(1.0e-6); - - sd::ops::ada_delta_updater op; - - Nd4jStatus status = op.execute({ &grad0, &initMsg, &initMsdx, &rho, &epsilon }, { &grad0, &initMsg, &initMsdx }, { }, { }); - ASSERT_EQ(ND4J_STATUS_OK, status); - - NDArray updateExp0('c', { 1, 5 }, { 0.0044712172817412, 0.00446815612502933, 0.00447208107763182, 0.004472071321461, 0.00447153735969189 }, DataType::FLOAT32); - NDArray stateMsg0('c', { 1, 5 }, { 0.00243326882084394, 0.0005611002391122, 0.04074588324665051, 0.03459534402219976, 0.00373477541890961 }, DataType::FLOAT32); - NDArray stateMsdx0('c', { 1, 5 }, { 0.00000099958919903, 0.00000099822095788, 0.00000099997545825, 0.00000099997109521, 0.00000099973231796 }, DataType::FLOAT32); - - ASSERT_TRUE(grad0.equalsTo(updateExp0)); - ASSERT_TRUE(initMsg.equalsTo(stateMsg0)); - ASSERT_TRUE(initMsdx.equalsTo(stateMsdx0)); - - NDArray grad1('c', { 1, 5 }, { 0.6351608633995056, 0.21878601610660553, 0.6470938920974731, 0.3742971122264862, 0.9453978538513184 }, DataType::FLOAT32); - - status = op.execute({ &grad1, &initMsg, &initMsdx, &rho, &epsilon }, { &grad1, &initMsg, &initMsdx }, { }, { }); - ASSERT_EQ(ND4J_STATUS_OK, status); - - NDArray updateExp1('c', { 1, 5 }, { 0.00598985959779411, 0.00571609509028959, 0.00374704195122062, 0.00265092283150538, 0.00608704322078556 }, DataType::FLOAT32); - NDArray stateMsg1('c', { 1, 5 }, { 0.02248307149952203, 0.00292641126934659, 0.05964511434381081, 0.03987049323214412, 0.0482368917512981 }, DataType::FLOAT32); - NDArray stateMsdx1('c', { 1, 5 }, { 0.00000274353063914, 0.00000258199706405, 0.00000165199285454, 0.00000130134213338, 0.00000280235046064 }, DataType::FLOAT32); - - ASSERT_TRUE(grad1.equalsTo(updateExp1)); - ASSERT_TRUE(initMsg.equalsTo(stateMsg1)); - ASSERT_TRUE(initMsdx.equalsTo(stateMsdx1)); - - NDArray grad2('c', { 1, 5 }, { 0.8484492301940918, 0.9634076952934265, 0.6676893830299377, 0.4450211524963379, 0.32364124059677124 }, DataType::FLOAT32); - - status = op.execute({ &grad2, &initMsg, &initMsdx, &rho, &epsilon }, { &grad2, &initMsg, &initMsdx }, { }, { }); - ASSERT_EQ(ND4J_STATUS_OK, status); - - NDArray updateExp2('c', { 1, 5 }, { 0.00685468722145889, 0.00822128238053265, 0.00386965914609878, 0.00308849888680941, 0.00279277397245112 }, DataType::FLOAT32); - NDArray stateMsg2('c', { 1, 5 }, { 0.05735222273539331, 0.04918781007340889, 0.07895331423716523, 0.04777915987899536, 0.05106222979448406 }, DataType::FLOAT32); - NDArray stateMsdx2('c', { 1, 5 }, { 0.00000495569095238, 0.00000583237140987, 0.00000231810630717, 0.0000017132162954, 0.00000305221226067 }, DataType::FLOAT32); - - ASSERT_TRUE(grad2.equalsTo(updateExp2)); - ASSERT_TRUE(initMsg.equalsTo(stateMsg2)); - ASSERT_TRUE(initMsdx.equalsTo(stateMsdx2)); + NDArray grad0('c', {1, 5}, + {0.22060230374336243, 0.10593396425247192, 0.9027279019355774, + 0.831809401512146, 0.2733047902584076}, + DataType::FLOAT32); + NDArray initMsg('c', {1, 5}, {0.0, 0.0, 0.0, 0.0, 0.0}, DataType::FLOAT32); + NDArray initMsdx('c', {1, 5}, {0.0, 0.0, 0.0, 0.0, 0.0}, DataType::FLOAT32); + + auto rho = NDArrayFactory::create(0.95f); + auto epsilon = NDArrayFactory::create(1.0e-6); + + sd::ops::ada_delta_updater op; + + Nd4jStatus status = op.execute({&grad0, &initMsg, &initMsdx, &rho, &epsilon}, + {&grad0, &initMsg, &initMsdx}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp0( + 'c', {1, 5}, + {0.0044712172817412, 0.00446815612502933, 0.00447208107763182, + 0.004472071321461, 0.00447153735969189}, + DataType::FLOAT32); + NDArray stateMsg0( + 'c', {1, 5}, + {0.00243326882084394, 0.0005611002391122, 0.04074588324665051, + 0.03459534402219976, 0.00373477541890961}, + DataType::FLOAT32); + NDArray stateMsdx0( + 'c', {1, 5}, + {0.00000099958919903, 0.00000099822095788, 0.00000099997545825, + 0.00000099997109521, 0.00000099973231796}, + DataType::FLOAT32); + + ASSERT_TRUE(grad0.equalsTo(updateExp0)); + ASSERT_TRUE(initMsg.equalsTo(stateMsg0)); + ASSERT_TRUE(initMsdx.equalsTo(stateMsdx0)); + + NDArray grad1('c', {1, 5}, + {0.6351608633995056, 0.21878601610660553, 0.6470938920974731, + 0.3742971122264862, 0.9453978538513184}, + DataType::FLOAT32); + + status = op.execute({&grad1, &initMsg, &initMsdx, &rho, &epsilon}, + {&grad1, &initMsg, &initMsdx}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp1( + 'c', {1, 5}, + {0.00598985959779411, 0.00571609509028959, 0.00374704195122062, + 0.00265092283150538, 0.00608704322078556}, + DataType::FLOAT32); + NDArray stateMsg1( + 'c', {1, 5}, + {0.02248307149952203, 0.00292641126934659, 0.05964511434381081, + 0.03987049323214412, 0.0482368917512981}, + DataType::FLOAT32); + NDArray stateMsdx1( + 'c', {1, 5}, + {0.00000274353063914, 0.00000258199706405, 0.00000165199285454, + 0.00000130134213338, 0.00000280235046064}, + DataType::FLOAT32); + + ASSERT_TRUE(grad1.equalsTo(updateExp1)); + ASSERT_TRUE(initMsg.equalsTo(stateMsg1)); + ASSERT_TRUE(initMsdx.equalsTo(stateMsdx1)); + + NDArray grad2('c', {1, 5}, + {0.8484492301940918, 0.9634076952934265, 0.6676893830299377, + 0.4450211524963379, 0.32364124059677124}, + DataType::FLOAT32); + + status = op.execute({&grad2, &initMsg, &initMsdx, &rho, &epsilon}, + {&grad2, &initMsg, &initMsdx}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp2( + 'c', {1, 5}, + {0.00685468722145889, 0.00822128238053265, 0.00386965914609878, + 0.00308849888680941, 0.00279277397245112}, + DataType::FLOAT32); + NDArray stateMsg2( + 'c', {1, 5}, + {0.05735222273539331, 0.04918781007340889, 0.07895331423716523, + 0.04777915987899536, 0.05106222979448406}, + DataType::FLOAT32); + NDArray stateMsdx2( + 'c', {1, 5}, + {0.00000495569095238, 0.00000583237140987, 0.00000231810630717, + 0.0000017132162954, 0.00000305221226067}, + DataType::FLOAT32); + + ASSERT_TRUE(grad2.equalsTo(updateExp2)); + ASSERT_TRUE(initMsg.equalsTo(stateMsg2)); + ASSERT_TRUE(initMsdx.equalsTo(stateMsdx2)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests18, TestUpdaterAdaDelta3) { - - NDArray gradC('c', { 1, 5 }, { 1, 2, 3, 4, 5 }, DataType::FLOAT32); - NDArray initVC('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); // Msg - NDArray initMC('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); // Msdx - - NDArray grad('f', { 1, 5 }, DataType::FLOAT32); - NDArray initMsg('f', { 1, 5 }, DataType::FLOAT32); - NDArray initMsdx('f', { 1, 5 }, DataType::FLOAT32); - - grad.assign(gradC); - initMsg.assign(initVC); - initMsdx.assign(initMC); - - sd::ops::ada_delta_updater op; - auto results = op.evaluate({ &grad, &initMsg, &initMsdx }, { 0.95f, 1.0e-6 }, { }); - - NDArray updateC('c', { 1, 5 }, { 0.00447209123431084, 0.00447212477470162, 0.00447213098596791, 0.00447213315991723, 0.00447213416614627 }, DataType::FLOAT32); - NDArray update('f', { 1, 5 }, DataType::FLOAT32); - - NDArray stateV0C('c', { 1, 5 }, { 0.05000000000000004, 0.20000000000000018, 0.4500000000000004, 0.8000000000000007, 1.250000000000001 }, DataType::FLOAT32); - NDArray stateMsg('f', { 1, 5 }, DataType::FLOAT32); - - NDArray stateM0C('c', { 1, 5 }, { 0.0000009999800004, 0.00000099999500002, 0.00000099999777778, 0.00000099999875, 0.0000009999992 }, DataType::FLOAT32); - NDArray stateMsdx('f', { 1, 5 }, DataType::FLOAT32); - - update.assign(updateC); - stateMsg.assign(stateV0C); - stateMsdx.assign(stateM0C); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - ASSERT_TRUE(update.isSameShape(results.at(0))); - ASSERT_TRUE(update.equalsTo(results.at(0))); - ASSERT_TRUE(stateMsg.isSameShape(results.at(1))); - ASSERT_TRUE(stateMsg.equalsTo(results.at(1))); - ASSERT_TRUE(stateMsdx.isSameShape(results.at(2))); - ASSERT_TRUE(stateMsdx.equalsTo(results.at(2))); - - results = op.evaluate({ &grad, results.at(1), results.at(2) }, { 0.95, 1.0e-6 }, { }); - - NDArray update1C('c', { 1, 5 }, { 0.0045290622655332, 0.00452909666868751, 0.00452910303972733, 0.00452910526959756, 0.00452910630171004 }, DataType::FLOAT32); - - NDArray stateV1C('c', { 1, 5 }, { 0.09750000000000009, 0.39000000000000035, 0.8775000000000008, 1.5600000000000014, 2.4375000000000018 }, DataType::FLOAT32); - NDArray stateM1C('c', { 1, 5 }, { 0.00000197560125063, 0.00000197563108174, 0.00000197563660612, 0.00000197563853966, 0.00000197563943461 }, DataType::FLOAT32); - - update.assign(update1C); - stateMsg.assign(stateV1C); - stateMsdx.assign(stateM1C); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - ASSERT_TRUE(update.isSameShape(results.at(0))); - ASSERT_TRUE(update.equalsTo(results.at(0))); - ASSERT_TRUE(stateMsg.isSameShape(results.at(1))); - ASSERT_TRUE(stateMsg.equalsTo(results.at(1))); - ASSERT_TRUE(stateMsdx.isSameShape(results.at(2))); - ASSERT_TRUE(stateMsdx.equalsTo(results.at(2))); - - results = op.evaluate({ &grad, &stateMsg, &stateMsdx }, { 0.95f, 1.0e-6 }, { }); - - NDArray update2C('c', { 1, 5 }, { 0.00456759948242601, 0.00456763438748812, 0.00456764085147516, 0.00456764311387702, 0.004567644161047 }, DataType::FLOAT32); - NDArray stateV2C('c', { 1, 5 }, { 0.1426250000000001, 0.5705000000000005, 1.2836250000000011, 2.282000000000002, 3.5656250000000025 }, DataType::FLOAT32); - NDArray stateM2C('c', { 1, 5 }, { 0.0000029199694397, 0.00000292001372254, 0.00000292002192321, 0.00000292002479346, 0.00000292002612198 }, DataType::FLOAT32); - - update.assign(update2C); - stateMsg.assign(stateV2C); - stateMsdx.assign(stateM2C); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - ASSERT_TRUE(update.isSameShape(results.at(0))); - ASSERT_TRUE(update.equalsTo(results.at(0))); - ASSERT_TRUE(stateMsg.isSameShape(results.at(1))); - ASSERT_TRUE(stateMsg.equalsTo(results.at(1))); - ASSERT_TRUE(stateMsdx.isSameShape(results.at(2))); - ASSERT_TRUE(stateMsdx.equalsTo(results.at(2))); + NDArray gradC('c', {1, 5}, {1, 2, 3, 4, 5}, DataType::FLOAT32); + NDArray initVC('c', {1, 5}, {0.0, 0.0, 0.0, 0.0, 0.0}, + DataType::FLOAT32); // Msg + NDArray initMC('c', {1, 5}, {0.0, 0.0, 0.0, 0.0, 0.0}, + DataType::FLOAT32); // Msdx + + NDArray grad('f', {1, 5}, DataType::FLOAT32); + NDArray initMsg('f', {1, 5}, DataType::FLOAT32); + NDArray initMsdx('f', {1, 5}, DataType::FLOAT32); + + grad.assign(gradC); + initMsg.assign(initVC); + initMsdx.assign(initMC); + + sd::ops::ada_delta_updater op; + auto results = op.evaluate({&grad, &initMsg, &initMsdx}, {0.95f, 1.0e-6}, {}); + + NDArray updateC( + 'c', {1, 5}, + {0.00447209123431084, 0.00447212477470162, 0.00447213098596791, + 0.00447213315991723, 0.00447213416614627}, + DataType::FLOAT32); + NDArray update('f', {1, 5}, DataType::FLOAT32); + + NDArray stateV0C('c', {1, 5}, + {0.05000000000000004, 0.20000000000000018, + 0.4500000000000004, 0.8000000000000007, 1.250000000000001}, + DataType::FLOAT32); + NDArray stateMsg('f', {1, 5}, DataType::FLOAT32); + + NDArray stateM0C('c', {1, 5}, + {0.0000009999800004, 0.00000099999500002, + 0.00000099999777778, 0.00000099999875, 0.0000009999992}, + DataType::FLOAT32); + NDArray stateMsdx('f', {1, 5}, DataType::FLOAT32); + + update.assign(updateC); + stateMsg.assign(stateV0C); + stateMsdx.assign(stateM0C); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); + ASSERT_TRUE(stateMsg.isSameShape(results.at(1))); + ASSERT_TRUE(stateMsg.equalsTo(results.at(1))); + ASSERT_TRUE(stateMsdx.isSameShape(results.at(2))); + ASSERT_TRUE(stateMsdx.equalsTo(results.at(2))); + + results = + op.evaluate({&grad, &results.at(1), &results.at(2)}, {0.95, 1.0e-6}, {}); + + NDArray update1C( + 'c', {1, 5}, + {0.0045290622655332, 0.00452909666868751, 0.00452910303972733, + 0.00452910526959756, 0.00452910630171004}, + DataType::FLOAT32); + + NDArray stateV1C('c', {1, 5}, + {0.09750000000000009, 0.39000000000000035, + 0.8775000000000008, 1.5600000000000014, 2.4375000000000018}, + DataType::FLOAT32); + NDArray stateM1C( + 'c', {1, 5}, + {0.00000197560125063, 0.00000197563108174, 0.00000197563660612, + 0.00000197563853966, 0.00000197563943461}, + DataType::FLOAT32); + + update.assign(update1C); + stateMsg.assign(stateV1C); + stateMsdx.assign(stateM1C); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); + ASSERT_TRUE(stateMsg.isSameShape(results.at(1))); + ASSERT_TRUE(stateMsg.equalsTo(results.at(1))); + ASSERT_TRUE(stateMsdx.isSameShape(results.at(2))); + ASSERT_TRUE(stateMsdx.equalsTo(results.at(2))); + + results = op.evaluate({&grad, &stateMsg, &stateMsdx}, {0.95f, 1.0e-6}, {}); + + NDArray update2C( + 'c', {1, 5}, + {0.00456759948242601, 0.00456763438748812, 0.00456764085147516, + 0.00456764311387702, 0.004567644161047}, + DataType::FLOAT32); + NDArray stateV2C('c', {1, 5}, + {0.1426250000000001, 0.5705000000000005, 1.2836250000000011, + 2.282000000000002, 3.5656250000000025}, + DataType::FLOAT32); + NDArray stateM2C( + 'c', {1, 5}, + {0.0000029199694397, 0.00000292001372254, 0.00000292002192321, + 0.00000292002479346, 0.00000292002612198}, + DataType::FLOAT32); + + update.assign(update2C); + stateMsg.assign(stateV2C); + stateMsdx.assign(stateM2C); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); + ASSERT_TRUE(stateMsg.isSameShape(results.at(1))); + ASSERT_TRUE(stateMsg.equalsTo(results.at(1))); + ASSERT_TRUE(stateMsdx.isSameShape(results.at(2))); + ASSERT_TRUE(stateMsdx.equalsTo(results.at(2))); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests18, TestUpdaterNadam1) { - - NDArray grad('c', { 1, 5 }, { 1,2,3,4,5 }, DataType::FLOAT32); - NDArray initV('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); - NDArray initM('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); - - NDArray update('c', { 1, 5 }, DataType::FLOAT32); - - sd::ops::nadam_updater op; - - Nd4jStatus status = op.execute({ &grad, &initV, &initM }, { &update, &initV, &initM }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); - ASSERT_EQ(ND4J_STATUS_OK, status); - - NDArray updateExp0('c', { 1, 5 }, { 0.06008325654320519, 0.06008326604320069, 0.06008326920986652, 0.06008327079319956, 0.0600832717431994 }, DataType::FLOAT32); - NDArray stateV('c', { 1, 5 }, { 0.001, 0.004, 0.00900000000000001, 0.01600000000000001, 0.02500000000000002 }, DataType::FLOAT32); - NDArray stateM0('c', { 1, 5 }, { 0.09999999999999998, 0.19999999999999996, 0.29999999999999993, 0.3999999999999999, 0.499999999999999 }, DataType::FLOAT32); - - ASSERT_TRUE(update.equalsTo(updateExp0)); - ASSERT_TRUE(initV.equalsTo(stateV)); - ASSERT_TRUE(initM.equalsTo(stateM0)); - - status = op.execute({ &grad, &initV, &initM }, { &update, &initV, &initM }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); - ASSERT_EQ(ND4J_STATUS_OK, status); - - NDArray updateExp1('c', { 1, 5 }, { 0.06061258367739481, 0.06061259045578174, 0.06061259271524436, 0.06061259384497576, 0.06061259452281461 }, DataType::FLOAT32); - NDArray stateV1('c', { 1, 5 }, { 0.001999, 0.00799600000000001, 0.01799100000000001, 0.03198400000000003, 0.04997500000000005 }, DataType::FLOAT32); - NDArray stateM1('c', { 1, 5 }, { 0.18999999999999995, 0.3799999999999999, 0.5699999999999998, 0.7599999999999998, 0.9499999999999997 }, DataType::FLOAT32); - - ASSERT_TRUE(update.equalsTo(updateExp1)); - ASSERT_TRUE(initV.equalsTo(stateV1)); - ASSERT_TRUE(initM.equalsTo(stateM1)); - - status = op.execute({ &grad, &initV, &initM }, { &update, &initV, &initM }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); - ASSERT_EQ(ND4J_STATUS_OK, status); - - NDArray updateExp2('c', { 1, 5 }, { 0.06281865774973168, 0.06281866348713228, 0.06281866539959938, 0.06281866635583296, 0.06281866692957314 }, DataType::FLOAT32); - NDArray stateV2('c', { 1, 5 }, { 0.002997001, 0.01198800400000001, 0.02697300900000002, 0.04795201600000004, 0.07492502500000006 }, DataType::FLOAT32); - NDArray stateM2('c', { 1, 5 }, { 0.2709999999999999, 0.5419999999999998, 0.8129999999999998, 1.0839999999999996, 1.3549999999999995 }, DataType::FLOAT32); - - ASSERT_TRUE(update.equalsTo(updateExp2)); - ASSERT_TRUE(initV.equalsTo(stateV2)); - ASSERT_TRUE(initM.equalsTo(stateM2)); + NDArray grad('c', {1, 5}, {1, 2, 3, 4, 5}, DataType::FLOAT32); + NDArray initV('c', {1, 5}, {0.0, 0.0, 0.0, 0.0, 0.0}, DataType::FLOAT32); + NDArray initM('c', {1, 5}, {0.0, 0.0, 0.0, 0.0, 0.0}, DataType::FLOAT32); + + NDArray update('c', {1, 5}, DataType::FLOAT32); + + sd::ops::nadam_updater op; + + Nd4jStatus status = + op.execute({&grad, &initV, &initM}, {&update, &initV, &initM}, + {0.001f, 0.9f, 0.999f, 1.0e-8}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp0( + 'c', {1, 5}, + {0.06008325654320519, 0.06008326604320069, 0.06008326920986652, + 0.06008327079319956, 0.0600832717431994}, + DataType::FLOAT32); + NDArray stateV('c', {1, 5}, + {0.001, 0.004, 0.00900000000000001, 0.01600000000000001, + 0.02500000000000002}, + DataType::FLOAT32); + NDArray stateM0('c', {1, 5}, + {0.09999999999999998, 0.19999999999999996, + 0.29999999999999993, 0.3999999999999999, 0.499999999999999}, + DataType::FLOAT32); + + ASSERT_TRUE(update.equalsTo(updateExp0)); + ASSERT_TRUE(initV.equalsTo(stateV)); + ASSERT_TRUE(initM.equalsTo(stateM0)); + + status = op.execute({&grad, &initV, &initM}, {&update, &initV, &initM}, + {0.001f, 0.9f, 0.999f, 1.0e-8}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp1( + 'c', {1, 5}, + {0.06061258367739481, 0.06061259045578174, 0.06061259271524436, + 0.06061259384497576, 0.06061259452281461}, + DataType::FLOAT32); + NDArray stateV1('c', {1, 5}, + {0.001999, 0.00799600000000001, 0.01799100000000001, + 0.03198400000000003, 0.04997500000000005}, + DataType::FLOAT32); + NDArray stateM1('c', {1, 5}, + {0.18999999999999995, 0.3799999999999999, 0.5699999999999998, + 0.7599999999999998, 0.9499999999999997}, + DataType::FLOAT32); + + ASSERT_TRUE(update.equalsTo(updateExp1)); + ASSERT_TRUE(initV.equalsTo(stateV1)); + ASSERT_TRUE(initM.equalsTo(stateM1)); + + status = op.execute({&grad, &initV, &initM}, {&update, &initV, &initM}, + {0.001f, 0.9f, 0.999f, 1.0e-8}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp2( + 'c', {1, 5}, + {0.06281865774973168, 0.06281866348713228, 0.06281866539959938, + 0.06281866635583296, 0.06281866692957314}, + DataType::FLOAT32); + NDArray stateV2('c', {1, 5}, + {0.002997001, 0.01198800400000001, 0.02697300900000002, + 0.04795201600000004, 0.07492502500000006}, + DataType::FLOAT32); + NDArray stateM2('c', {1, 5}, + {0.2709999999999999, 0.5419999999999998, 0.8129999999999998, + 1.0839999999999996, 1.3549999999999995}, + DataType::FLOAT32); + + ASSERT_TRUE(update.equalsTo(updateExp2)); + ASSERT_TRUE(initV.equalsTo(stateV2)); + ASSERT_TRUE(initM.equalsTo(stateM2)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests18, TestUpdaterNadam2) { - - NDArray grad0('c', { 1, 5 }, { 0.8047558665275574, 0.9653639197349548, 0.31240877509117126, 0.9530212879180908, 0.01295729912817478 }, DataType::FLOAT32); - NDArray initV('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); - NDArray initM('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); - - auto lr = NDArrayFactory::create(0.001f); - auto beta1 = NDArrayFactory::create(0.9f); - auto beta2 = NDArrayFactory::create(0.999f); - auto epsilon = NDArrayFactory::create(1.0e-8); - - sd::ops::nadam_updater op; - - Nd4jStatus status = op.execute({ &grad0, &initV, &initM, &lr, &beta1, &beta2, &epsilon }, { &grad0, &initV, &initM }, { }, { }); - ASSERT_EQ(ND4J_STATUS_OK, status); - - NDArray updateExp0('c', { 1, 5 }, { 0.06008325193356386, 0.0600832558615088, 0.06008321472550684, 0.06008325560661022, 0.0600818092240132 }, DataType::FLOAT32); - NDArray stateV0('c', { 1, 5 }, { 0.00064763200471052, 0.00093192749752604, 0.00009759924275397, 0.00090824957522506, 0.0000001678916007 }, DataType::FLOAT32); - NDArray stateM0('c', { 1, 5 }, { 0.08047558665275573, 0.09653639197349546, 0.03124087750911712, 0.09530212879180906, 0.00129572991281748 }, DataType::FLOAT32); - - ASSERT_TRUE(grad0.equalsTo(updateExp0)); - ASSERT_TRUE(initV.equalsTo(stateV0)); - ASSERT_TRUE(initM.equalsTo(stateM0)); - - NDArray grad1('c', { 1, 5 }, { 0.9839006662368774, 0.8964805603027344, 0.3631269931793213, 0.00931886397302151, 0.6320028901100159 }, DataType::FLOAT32); - - status = op.execute({ &grad1, &initV, &initM, &lr, &beta1, &beta2, &epsilon }, { &grad1, &initV, &initM }, { }, { }); - ASSERT_EQ(ND4J_STATUS_OK, status); - - NDArray updateExp1('c', { 1, 5 }, { 0.06273730114378717, 0.0596708938019245, 0.06226533928512862, 0.02621380498466489, 0.06059567064824535 }, DataType::FLOAT32); - NDArray stateV1('c', { 1, 5 }, { 0.00161504489372718, 0.00173467296502922, 0.00022936285668667, 0.00090742816687558, 0.0003995953768165 }, DataType::FLOAT32); - NDArray stateM1('c', { 1, 5 }, { 0.17081809461116787, 0.17653080880641933, 0.06442948907613753, 0.08670380230993031, 0.06436644593253729 }, DataType::FLOAT32); - - ASSERT_TRUE(grad1.equalsTo(updateExp1)); - ASSERT_TRUE(initV.equalsTo(stateV1)); - ASSERT_TRUE(initM.equalsTo(stateM1)); - - NDArray grad2('c', { 1, 5 }, { 0.7712154984474182, 0.1282273381948471, 0.7019220590591431, 0.8883536458015442, 0.33057701587677 }, DataType::FLOAT32); - - status = op.execute({ &grad2, &initV, &initM, &lr, &beta1, &beta2, &epsilon }, { &grad2, &initV, &initM }, { }, { }); - ASSERT_EQ(ND4J_STATUS_OK, status); - - NDArray updateExp2('c', { 1, 5 }, { 0.06062658222261493, 0.04001212712739213, 0.06906390273197544, 0.05804376499107734, 0.05097529565845974 }, DataType::FLOAT32); - NDArray stateV2('c', { 1, 5 }, { 0.00220820319387896, 0.00174938054232472, 0.00072182807082381, 0.0016956929387176, 0.00050847694486568 }, DataType::FLOAT32); - NDArray stateM2('c', { 1, 5 }, { 0.2308578349947929, 0.1717004617452621, 0.12817874607443808, 0.16686878665909166, 0.09098750292696056 }, DataType::FLOAT32); - - ASSERT_TRUE(grad2.equalsTo(updateExp2)); - ASSERT_TRUE(initV.equalsTo(stateV2)); - ASSERT_TRUE(initM.equalsTo(stateM2)); + NDArray grad0('c', {1, 5}, + {0.8047558665275574, 0.9653639197349548, 0.31240877509117126, + 0.9530212879180908, 0.01295729912817478}, + DataType::FLOAT32); + NDArray initV('c', {1, 5}, {0.0, 0.0, 0.0, 0.0, 0.0}, DataType::FLOAT32); + NDArray initM('c', {1, 5}, {0.0, 0.0, 0.0, 0.0, 0.0}, DataType::FLOAT32); + + auto lr = NDArrayFactory::create(0.001f); + auto beta1 = NDArrayFactory::create(0.9f); + auto beta2 = NDArrayFactory::create(0.999f); + auto epsilon = NDArrayFactory::create(1.0e-8); + + sd::ops::nadam_updater op; + + Nd4jStatus status = + op.execute({&grad0, &initV, &initM, &lr, &beta1, &beta2, &epsilon}, + {&grad0, &initV, &initM}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp0( + 'c', {1, 5}, + {0.06008325193356386, 0.0600832558615088, 0.06008321472550684, + 0.06008325560661022, 0.0600818092240132}, + DataType::FLOAT32); + NDArray stateV0( + 'c', {1, 5}, + {0.00064763200471052, 0.00093192749752604, 0.00009759924275397, + 0.00090824957522506, 0.0000001678916007}, + DataType::FLOAT32); + NDArray stateM0( + 'c', {1, 5}, + {0.08047558665275573, 0.09653639197349546, 0.03124087750911712, + 0.09530212879180906, 0.00129572991281748}, + DataType::FLOAT32); + + ASSERT_TRUE(grad0.equalsTo(updateExp0)); + ASSERT_TRUE(initV.equalsTo(stateV0)); + ASSERT_TRUE(initM.equalsTo(stateM0)); + + NDArray grad1('c', {1, 5}, + {0.9839006662368774, 0.8964805603027344, 0.3631269931793213, + 0.00931886397302151, 0.6320028901100159}, + DataType::FLOAT32); + + status = op.execute({&grad1, &initV, &initM, &lr, &beta1, &beta2, &epsilon}, + {&grad1, &initV, &initM}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp1( + 'c', {1, 5}, + {0.06273730114378717, 0.0596708938019245, 0.06226533928512862, + 0.02621380498466489, 0.06059567064824535}, + DataType::FLOAT32); + NDArray stateV1( + 'c', {1, 5}, + {0.00161504489372718, 0.00173467296502922, 0.00022936285668667, + 0.00090742816687558, 0.0003995953768165}, + DataType::FLOAT32); + NDArray stateM1( + 'c', {1, 5}, + {0.17081809461116787, 0.17653080880641933, 0.06442948907613753, + 0.08670380230993031, 0.06436644593253729}, + DataType::FLOAT32); + + ASSERT_TRUE(grad1.equalsTo(updateExp1)); + ASSERT_TRUE(initV.equalsTo(stateV1)); + ASSERT_TRUE(initM.equalsTo(stateM1)); + + NDArray grad2('c', {1, 5}, + {0.7712154984474182, 0.1282273381948471, 0.7019220590591431, + 0.8883536458015442, 0.33057701587677}, + DataType::FLOAT32); + + status = op.execute({&grad2, &initV, &initM, &lr, &beta1, &beta2, &epsilon}, + {&grad2, &initV, &initM}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp2( + 'c', {1, 5}, + {0.06062658222261493, 0.04001212712739213, 0.06906390273197544, + 0.05804376499107734, 0.05097529565845974}, + DataType::FLOAT32); + NDArray stateV2( + 'c', {1, 5}, + {0.00220820319387896, 0.00174938054232472, 0.00072182807082381, + 0.0016956929387176, 0.00050847694486568}, + DataType::FLOAT32); + NDArray stateM2('c', {1, 5}, + {0.2308578349947929, 0.1717004617452621, 0.12817874607443808, + 0.16686878665909166, 0.09098750292696056}, + DataType::FLOAT32); + + ASSERT_TRUE(grad2.equalsTo(updateExp2)); + ASSERT_TRUE(initV.equalsTo(stateV2)); + ASSERT_TRUE(initM.equalsTo(stateM2)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests18, TestUpdaterNadam3) { - - NDArray gradC('c', { 1, 5 }, { 1, 2, 3, 4, 5 }, DataType::FLOAT32); - NDArray initVC('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); - NDArray initMC('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); - - NDArray grad('f', { 1, 5 }, DataType::FLOAT32); - NDArray initV('f', { 1, 5 }, DataType::FLOAT32); - NDArray initM('f', { 1, 5 }, DataType::FLOAT32); - - grad.assign(gradC); - initV.assign(initVC); - initM.assign(initMC); - - sd::ops::nadam_updater op; - auto results = op.evaluate({ &grad, &initV, &initM }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); - - NDArray updateC('c', { 1, 5 }, { 0.06008325654320519, 0.06008326604320069, 0.06008326920986652, 0.06008327079319956, 0.0600832717431994 }, DataType::FLOAT32); - NDArray update('f', { 1, 5 }, DataType::FLOAT32); - - NDArray stateV0C('c', { 1, 5 }, { 0.001, 0.004, 0.00900000000000001, 0.01600000000000001, 0.02500000000000002 }, DataType::FLOAT32); - NDArray stateV('f', { 1, 5 }, DataType::FLOAT32); - - NDArray stateM0C('c', { 1, 5 }, { 0.09999999999999998, 0.19999999999999996, 0.29999999999999993, 0.3999999999999999, 0.499999999999999 }, DataType::FLOAT32); - NDArray stateM('f', { 1, 5 }, DataType::FLOAT32); - - update.assign(updateC); - stateV.assign(stateV0C); - stateM.assign(stateM0C); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - ASSERT_TRUE(update.isSameShape(results.at(0))); - ASSERT_TRUE(update.equalsTo(results.at(0))); - ASSERT_TRUE(stateV.isSameShape(results.at(1))); - ASSERT_TRUE(stateV.equalsTo(results.at(1))); - ASSERT_TRUE(stateM.isSameShape(results.at(2))); - ASSERT_TRUE(stateM.equalsTo(results.at(2))); - - results = op.evaluate({ &grad, &stateV, &stateM }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - NDArray update1C('c', { 1, 5 }, { 0.06061258367739481, 0.06061259045578174, 0.06061259271524436, 0.06061259384497576, 0.06061259452281461 }, DataType::FLOAT32); - NDArray stateV1C('c', { 1, 5 }, { 0.001999, 0.00799600000000001, 0.01799100000000001, 0.03198400000000003, 0.04997500000000005 }, DataType::FLOAT32); - NDArray stateM1C('c', { 1, 5 }, { 0.18999999999999995, 0.3799999999999999, 0.5699999999999998, 0.7599999999999998, 0.9499999999999997 }, DataType::FLOAT32); - - update.assign(update1C); - stateV.assign(stateV1C); - stateM.assign(stateM1C); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - ASSERT_TRUE(update.isSameShape(results.at(0))); - ASSERT_TRUE(update.equalsTo(results.at(0))); - ASSERT_TRUE(stateV.isSameShape(results.at(1))); - ASSERT_TRUE(stateV.equalsTo(results.at(1))); - ASSERT_TRUE(stateM.isSameShape(results.at(2))); - ASSERT_TRUE(stateM.equalsTo(results.at(2))); - - results = op.evaluate({ &grad, &stateV, &stateM }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); - - NDArray update2C('c', { 1, 5 }, { 0.06281865774973168, 0.06281866348713228, 0.06281866539959938, 0.06281866635583296, 0.06281866692957314 }, DataType::FLOAT32); - NDArray stateV2C('c', { 1, 5 }, { 0.002997001, 0.01198800400000001, 0.02697300900000002, 0.04795201600000004, 0.07492502500000006 }, DataType::FLOAT32); - NDArray stateM2C('c', { 1, 5 }, { 0.2709999999999999, 0.5419999999999998, 0.8129999999999998, 1.0839999999999996, 1.3549999999999995 }, DataType::FLOAT32); - - update.assign(update2C); - stateV.assign(stateV2C); - stateM.assign(stateM2C); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - ASSERT_TRUE(update.isSameShape(results.at(0))); - ASSERT_TRUE(update.equalsTo(results.at(0))); - ASSERT_TRUE(stateV.isSameShape(results.at(1))); - ASSERT_TRUE(stateV.equalsTo(results.at(1))); - ASSERT_TRUE(stateM.isSameShape(results.at(2))); - ASSERT_TRUE(stateM.equalsTo(results.at(2))); + NDArray gradC('c', {1, 5}, {1, 2, 3, 4, 5}, DataType::FLOAT32); + NDArray initVC('c', {1, 5}, {0.0, 0.0, 0.0, 0.0, 0.0}, DataType::FLOAT32); + NDArray initMC('c', {1, 5}, {0.0, 0.0, 0.0, 0.0, 0.0}, DataType::FLOAT32); + + NDArray grad('f', {1, 5}, DataType::FLOAT32); + NDArray initV('f', {1, 5}, DataType::FLOAT32); + NDArray initM('f', {1, 5}, DataType::FLOAT32); + + grad.assign(gradC); + initV.assign(initVC); + initM.assign(initMC); + + sd::ops::nadam_updater op; + auto results = + op.evaluate({&grad, &initV, &initM}, {0.001f, 0.9f, 0.999f, 1.0e-8}, {}); + + NDArray updateC( + 'c', {1, 5}, + {0.06008325654320519, 0.06008326604320069, 0.06008326920986652, + 0.06008327079319956, 0.0600832717431994}, + DataType::FLOAT32); + NDArray update('f', {1, 5}, DataType::FLOAT32); + + NDArray stateV0C('c', {1, 5}, + {0.001, 0.004, 0.00900000000000001, 0.01600000000000001, + 0.02500000000000002}, + DataType::FLOAT32); + NDArray stateV('f', {1, 5}, DataType::FLOAT32); + + NDArray stateM0C('c', {1, 5}, + {0.09999999999999998, 0.19999999999999996, + 0.29999999999999993, 0.3999999999999999, 0.499999999999999}, + DataType::FLOAT32); + NDArray stateM('f', {1, 5}, DataType::FLOAT32); + + update.assign(updateC); + stateV.assign(stateV0C); + stateM.assign(stateM0C); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); + ASSERT_TRUE(stateV.isSameShape(results.at(1))); + ASSERT_TRUE(stateV.equalsTo(results.at(1))); + ASSERT_TRUE(stateM.isSameShape(results.at(2))); + ASSERT_TRUE(stateM.equalsTo(results.at(2))); + + results = op.evaluate({&grad, &stateV, &stateM}, + {0.001f, 0.9f, 0.999f, 1.0e-8}, {}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + NDArray update1C( + 'c', {1, 5}, + {0.06061258367739481, 0.06061259045578174, 0.06061259271524436, + 0.06061259384497576, 0.06061259452281461}, + DataType::FLOAT32); + NDArray stateV1C('c', {1, 5}, + {0.001999, 0.00799600000000001, 0.01799100000000001, + 0.03198400000000003, 0.04997500000000005}, + DataType::FLOAT32); + NDArray stateM1C('c', {1, 5}, + {0.18999999999999995, 0.3799999999999999, 0.5699999999999998, + 0.7599999999999998, 0.9499999999999997}, + DataType::FLOAT32); + + update.assign(update1C); + stateV.assign(stateV1C); + stateM.assign(stateM1C); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); + ASSERT_TRUE(stateV.isSameShape(results.at(1))); + ASSERT_TRUE(stateV.equalsTo(results.at(1))); + ASSERT_TRUE(stateM.isSameShape(results.at(2))); + ASSERT_TRUE(stateM.equalsTo(results.at(2))); + + results = op.evaluate({&grad, &stateV, &stateM}, + {0.001f, 0.9f, 0.999f, 1.0e-8}, {}); + + NDArray update2C( + 'c', {1, 5}, + {0.06281865774973168, 0.06281866348713228, 0.06281866539959938, + 0.06281866635583296, 0.06281866692957314}, + DataType::FLOAT32); + NDArray stateV2C('c', {1, 5}, + {0.002997001, 0.01198800400000001, 0.02697300900000002, + 0.04795201600000004, 0.07492502500000006}, + DataType::FLOAT32); + NDArray stateM2C('c', {1, 5}, + {0.2709999999999999, 0.5419999999999998, 0.8129999999999998, + 1.0839999999999996, 1.3549999999999995}, + DataType::FLOAT32); + + update.assign(update2C); + stateV.assign(stateV2C); + stateM.assign(stateM2C); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); + ASSERT_TRUE(stateV.isSameShape(results.at(1))); + ASSERT_TRUE(stateV.equalsTo(results.at(1))); + ASSERT_TRUE(stateM.isSameShape(results.at(2))); + ASSERT_TRUE(stateM.equalsTo(results.at(2))); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests18, TestUpdaterAmsGrad1) { - - NDArray grad('c', { 1, 5 }, { 1,2,3,4,5 }, DataType::FLOAT32); - NDArray initV('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); - NDArray initM('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); - NDArray initH('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); - - NDArray update('c', { 1, 5 }, DataType::FLOAT32); - - sd::ops::ams_grad_updater op; - - Nd4jStatus status = op.execute({ &grad, &initV, &initM, &initH }, { &update, &initV, &initM, &initH }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); - ASSERT_EQ(ND4J_STATUS_OK, status); - - NDArray updateExp0('c', { 1, 5 }, { 0.00099999968377233, 0.00099999984188614, 0.00099999989459076, 0.00099999992094306, 0.00099999993675445 }, DataType::FLOAT32); - NDArray stateV0('c', { 1, 5 }, { 0.001, 0.004, 0.00900000000000001, 0.01600000000000001, 0.02500000000000002 }, DataType::FLOAT32); - NDArray stateH0('c', { 1, 5 }, { 0.001, 0.004, 0.00900000000000001, 0.01600000000000001, 0.02500000000000002 }, DataType::FLOAT32); - NDArray stateM0('c', { 1, 5 }, { 0.09999999999999998, 0.19999999999999996, 0.29999999999999993, 0.3999999999999999, 0.4999999999999999 }, DataType::FLOAT32); - - ASSERT_TRUE(update.equalsTo(updateExp0)); - ASSERT_TRUE(initV.equalsTo(stateV0)); - ASSERT_TRUE(initH.equalsTo(stateH0)); - ASSERT_TRUE(initM.equalsTo(stateM0)); - - status = op.execute({ &grad, &initV, &initM, &initH }, { &update, &initV, &initM, &initH }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); - ASSERT_EQ(ND4J_STATUS_OK, status); - - NDArray updateExp1('c', { 1, 5 }, { 0.00134383858541481, 0.00134383873569809, 0.00134383878579252, 0.00134383881083974, 0.00134383882586807 }, DataType::FLOAT32); - NDArray stateV1('c', { 1, 5 }, { 0.001999, 0.00799600000000001, 0.01799100000000001, 0.03198400000000003, 0.04997500000000005 }, DataType::FLOAT32); - NDArray stateH1('c', { 1, 5 }, { 0.001999, 0.00799600000000001, 0.01799100000000001, 0.03198400000000003, 0.04997500000000005 }, DataType::FLOAT32); - NDArray stateM1('c', { 1, 5 }, { 0.18999999999999995, 0.3799999999999999, 0.5699999999999998, 0.7599999999999998, 0.9499999999999997 }, DataType::FLOAT32); - - ASSERT_TRUE(update.equalsTo(updateExp1)); - ASSERT_TRUE(initV.equalsTo(stateV1)); - ASSERT_TRUE(initH.equalsTo(stateH1)); - ASSERT_TRUE(initM.equalsTo(stateM1)); - - status = op.execute({ &grad, &initV, &initM, &initH }, { &update, &initV, &initM, &initH }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); - ASSERT_EQ(ND4J_STATUS_OK, status); - - NDArray updateExp2('c', { 1, 5 }, { 0.00156540157923389, 0.00156540172220632, 0.0015654017698638, 0.00156540179369254, 0.00156540180798979 }, DataType::FLOAT32); - NDArray stateV2('c', { 1, 5 }, { 0.002997001, 0.01198800400000001, 0.02697300900000002, 0.04795201600000004, 0.07492502500000006 }, DataType::FLOAT32); - NDArray stateH2('c', { 1, 5 }, { 0.002997001, 0.01198800400000001, 0.02697300900000002, 0.04795201600000004, 0.07492502500000006 }, DataType::FLOAT32); - NDArray stateM2('c', { 1, 5 }, { 0.2709999999999999, 0.5419999999999998, 0.8129999999999998, 1.0839999999999996, 1.3549999999999995 }, DataType::FLOAT32); - - ASSERT_TRUE(update.equalsTo(updateExp2)); - ASSERT_TRUE(initV.equalsTo(stateV2)); - ASSERT_TRUE(initH.equalsTo(stateH2)); - ASSERT_TRUE(initM.equalsTo(stateM2)); + NDArray grad('c', {1, 5}, {1, 2, 3, 4, 5}, DataType::FLOAT32); + NDArray initV('c', {1, 5}, {0.0, 0.0, 0.0, 0.0, 0.0}, DataType::FLOAT32); + NDArray initM('c', {1, 5}, {0.0, 0.0, 0.0, 0.0, 0.0}, DataType::FLOAT32); + NDArray initH('c', {1, 5}, {0.0, 0.0, 0.0, 0.0, 0.0}, DataType::FLOAT32); + + NDArray update('c', {1, 5}, DataType::FLOAT32); + + sd::ops::ams_grad_updater op; + + Nd4jStatus status = op.execute({&grad, &initV, &initM, &initH}, + {&update, &initV, &initM, &initH}, + {0.001f, 0.9f, 0.999f, 1.0e-8}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp0( + 'c', {1, 5}, + {0.00099999968377233, 0.00099999984188614, 0.00099999989459076, + 0.00099999992094306, 0.00099999993675445}, + DataType::FLOAT32); + NDArray stateV0('c', {1, 5}, + {0.001, 0.004, 0.00900000000000001, 0.01600000000000001, + 0.02500000000000002}, + DataType::FLOAT32); + NDArray stateH0('c', {1, 5}, + {0.001, 0.004, 0.00900000000000001, 0.01600000000000001, + 0.02500000000000002}, + DataType::FLOAT32); + NDArray stateM0('c', {1, 5}, + {0.09999999999999998, 0.19999999999999996, + 0.29999999999999993, 0.3999999999999999, 0.4999999999999999}, + DataType::FLOAT32); + + ASSERT_TRUE(update.equalsTo(updateExp0)); + ASSERT_TRUE(initV.equalsTo(stateV0)); + ASSERT_TRUE(initH.equalsTo(stateH0)); + ASSERT_TRUE(initM.equalsTo(stateM0)); + + status = op.execute({&grad, &initV, &initM, &initH}, + {&update, &initV, &initM, &initH}, + {0.001f, 0.9f, 0.999f, 1.0e-8}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp1( + 'c', {1, 5}, + {0.00134383858541481, 0.00134383873569809, 0.00134383878579252, + 0.00134383881083974, 0.00134383882586807}, + DataType::FLOAT32); + NDArray stateV1('c', {1, 5}, + {0.001999, 0.00799600000000001, 0.01799100000000001, + 0.03198400000000003, 0.04997500000000005}, + DataType::FLOAT32); + NDArray stateH1('c', {1, 5}, + {0.001999, 0.00799600000000001, 0.01799100000000001, + 0.03198400000000003, 0.04997500000000005}, + DataType::FLOAT32); + NDArray stateM1('c', {1, 5}, + {0.18999999999999995, 0.3799999999999999, 0.5699999999999998, + 0.7599999999999998, 0.9499999999999997}, + DataType::FLOAT32); + + ASSERT_TRUE(update.equalsTo(updateExp1)); + ASSERT_TRUE(initV.equalsTo(stateV1)); + ASSERT_TRUE(initH.equalsTo(stateH1)); + ASSERT_TRUE(initM.equalsTo(stateM1)); + + status = op.execute({&grad, &initV, &initM, &initH}, + {&update, &initV, &initM, &initH}, + {0.001f, 0.9f, 0.999f, 1.0e-8}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp2( + 'c', {1, 5}, + {0.00156540157923389, 0.00156540172220632, 0.0015654017698638, + 0.00156540179369254, 0.00156540180798979}, + DataType::FLOAT32); + NDArray stateV2('c', {1, 5}, + {0.002997001, 0.01198800400000001, 0.02697300900000002, + 0.04795201600000004, 0.07492502500000006}, + DataType::FLOAT32); + NDArray stateH2('c', {1, 5}, + {0.002997001, 0.01198800400000001, 0.02697300900000002, + 0.04795201600000004, 0.07492502500000006}, + DataType::FLOAT32); + NDArray stateM2('c', {1, 5}, + {0.2709999999999999, 0.5419999999999998, 0.8129999999999998, + 1.0839999999999996, 1.3549999999999995}, + DataType::FLOAT32); + + ASSERT_TRUE(update.equalsTo(updateExp2)); + ASSERT_TRUE(initV.equalsTo(stateV2)); + ASSERT_TRUE(initH.equalsTo(stateH2)); + ASSERT_TRUE(initM.equalsTo(stateM2)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests18, TestUpdaterAmsGrad2) { - - NDArray grad0('c', { 1, 5 }, { 0.5730348229408264, 0.04330538213253021, 0.249028742313385, 0.6514443755149841, 0.7017051577568054 }, DataType::FLOAT32); - NDArray initH('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); - NDArray initV('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); - NDArray initM('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); - - auto lr = NDArrayFactory::create(0.001f); - auto beta1 = NDArrayFactory::create(0.9f); - auto beta2 = NDArrayFactory::create(0.999f); - auto epsilon = NDArrayFactory::create(1.0e-8); - - sd::ops::ams_grad_updater op; - - Nd4jStatus status = op.execute({ &grad0, &initV, &initM, &initH, &lr, &beta1, &beta2, &epsilon }, { &grad0, &initV, &initM, &initH }, { }, { }); - ASSERT_EQ(ND4J_STATUS_OK, status); - - NDArray updateExp0('c', { 1, 5 }, { 0.00099999944815292, 0.00099999269777932, 0.00099999873015716, 0.00099999951457465, 0.00099999954934402 }, DataType::FLOAT32); - NDArray stateV0('c', { 1, 5 }, { 0.00032836890830282, 0.00000187535612164, 0.00006201531449819, 0.00042437977439011, 0.0004923901284225 }, DataType::FLOAT32); - NDArray stateH0('c', { 1, 5 }, { 0.00032836890830282, 0.00000187535612164, 0.00006201531449819, 0.00042437977439011, 0.00049239012842255 }, DataType::FLOAT32); - NDArray stateM0('c', { 1, 5 }, { 0.05730348229408263, 0.00433053821325302, 0.0249028742313385, 0.0651444375514984, 0.07017051577568052 }, DataType::FLOAT32); - - ASSERT_TRUE(grad0.equalsTo(updateExp0)); - ASSERT_TRUE(initV.equalsTo(stateV0)); - ASSERT_TRUE(initH.equalsTo(stateH0)); - ASSERT_TRUE(initM.equalsTo(stateM0)); - - NDArray grad1('c', { 1, 5 }, { 0.6404328346252441, 0.9432603120803833, 0.45608729124069214, 0.9097326993942261, 0.748093843460083 }, DataType::FLOAT32); - - status = op.execute({ &grad1, &initV, &initM, &initH, &lr, &beta1, &beta2, &epsilon }, { &grad1, &initV, &initM, &initH }, { }, { }); - ASSERT_EQ(ND4J_STATUS_OK, status); - - NDArray updateExp1('c', { 1, 5 }, { 0.00134565543815267, 0.00104022434054697, 0.00130914539820157, 0.00133725290576052, 0.0013453914974122 }, DataType::FLOAT32); - NDArray stateV1('c', { 1, 5 }, { 0.00073819475506065, 0.00089161349711151, 0.00026996891641496, 0.00125156897896282, 0.00105154213691696 }, DataType::FLOAT32); - NDArray stateH1('c', { 1, 5 }, { 0.00073819475506065, 0.00089161349711151, 0.00026996891641496, 0.00125156897896282, 0.00105154213691696 }, DataType::FLOAT32); - NDArray stateM1('c', { 1, 5 }, { 0.11561641752719877, 0.09822351559996603, 0.06802131593227385, 0.14960326373577115, 0.13796284854412078 }, DataType::FLOAT32); - - ASSERT_TRUE(grad1.equalsTo(updateExp1)); - ASSERT_TRUE(initV.equalsTo(stateV1)); - ASSERT_TRUE(initH.equalsTo(stateH1)); - ASSERT_TRUE(initM.equalsTo(stateM1)); - - NDArray grad2('c', { 1, 5 }, { 0.46250319480895996, 0.09698919206857681, 0.21754667162895203, 0.46824514865875244, 0.6005083918571472 }, DataType::FLOAT32); - - status = op.execute({ &grad2, &initV, &initM, &initH, &lr, &beta1, &beta2, &epsilon }, { &grad2, &initV, &initM, &initH }, { }, { }); - ASSERT_EQ(ND4J_STATUS_OK, status); - - NDArray updateExp2('c', { 1, 5 }, { 0.00154098993679222, 0.00103399135000281, 0.00147364850040774, 0.00149693641196572, 0.00155078467854623 }, DataType::FLOAT32); - NDArray stateV2('c', { 1, 5 }, { 0.00095136576551408, 0.00090012878699251, 0.00031702550183538, 0.00146957092922632, 0.0014111009234709 }, DataType::FLOAT32); - NDArray stateH2('c', { 1, 5 }, { 0.00095136576551408, 0.00090012878699251, 0.00031702550183538, 0.00146957092922632, 0.0014111009234709 }, DataType::FLOAT32); - NDArray stateM2('c', { 1, 5 }, { 0.1503050952553749, 0.09810008324682712, 0.08297385150194167, 0.1814674522280693, 0.1842174028754234 }, DataType::FLOAT32); - - ASSERT_TRUE(grad2.equalsTo(updateExp2)); - ASSERT_TRUE(initV.equalsTo(stateV2)); - ASSERT_TRUE(initH.equalsTo(stateH2)); - ASSERT_TRUE(initM.equalsTo(stateM2)); + NDArray grad0('c', {1, 5}, + {0.5730348229408264, 0.04330538213253021, 0.249028742313385, + 0.6514443755149841, 0.7017051577568054}, + DataType::FLOAT32); + NDArray initH('c', {1, 5}, {0.0, 0.0, 0.0, 0.0, 0.0}, DataType::FLOAT32); + NDArray initV('c', {1, 5}, {0.0, 0.0, 0.0, 0.0, 0.0}, DataType::FLOAT32); + NDArray initM('c', {1, 5}, {0.0, 0.0, 0.0, 0.0, 0.0}, DataType::FLOAT32); + + auto lr = NDArrayFactory::create(0.001f); + auto beta1 = NDArrayFactory::create(0.9f); + auto beta2 = NDArrayFactory::create(0.999f); + auto epsilon = NDArrayFactory::create(1.0e-8); + + sd::ops::ams_grad_updater op; + + Nd4jStatus status = op.execute( + {&grad0, &initV, &initM, &initH, &lr, &beta1, &beta2, &epsilon}, + {&grad0, &initV, &initM, &initH}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp0( + 'c', {1, 5}, + {0.00099999944815292, 0.00099999269777932, 0.00099999873015716, + 0.00099999951457465, 0.00099999954934402}, + DataType::FLOAT32); + NDArray stateV0( + 'c', {1, 5}, + {0.00032836890830282, 0.00000187535612164, 0.00006201531449819, + 0.00042437977439011, 0.0004923901284225}, + DataType::FLOAT32); + NDArray stateH0( + 'c', {1, 5}, + {0.00032836890830282, 0.00000187535612164, 0.00006201531449819, + 0.00042437977439011, 0.00049239012842255}, + DataType::FLOAT32); + NDArray stateM0('c', {1, 5}, + {0.05730348229408263, 0.00433053821325302, 0.0249028742313385, + 0.0651444375514984, 0.07017051577568052}, + DataType::FLOAT32); + + ASSERT_TRUE(grad0.equalsTo(updateExp0)); + ASSERT_TRUE(initV.equalsTo(stateV0)); + ASSERT_TRUE(initH.equalsTo(stateH0)); + ASSERT_TRUE(initM.equalsTo(stateM0)); + + NDArray grad1('c', {1, 5}, + {0.6404328346252441, 0.9432603120803833, 0.45608729124069214, + 0.9097326993942261, 0.748093843460083}, + DataType::FLOAT32); + + status = op.execute( + {&grad1, &initV, &initM, &initH, &lr, &beta1, &beta2, &epsilon}, + {&grad1, &initV, &initM, &initH}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp1( + 'c', {1, 5}, + {0.00134565543815267, 0.00104022434054697, 0.00130914539820157, + 0.00133725290576052, 0.0013453914974122}, + DataType::FLOAT32); + NDArray stateV1( + 'c', {1, 5}, + {0.00073819475506065, 0.00089161349711151, 0.00026996891641496, + 0.00125156897896282, 0.00105154213691696}, + DataType::FLOAT32); + NDArray stateH1( + 'c', {1, 5}, + {0.00073819475506065, 0.00089161349711151, 0.00026996891641496, + 0.00125156897896282, 0.00105154213691696}, + DataType::FLOAT32); + NDArray stateM1( + 'c', {1, 5}, + {0.11561641752719877, 0.09822351559996603, 0.06802131593227385, + 0.14960326373577115, 0.13796284854412078}, + DataType::FLOAT32); + + ASSERT_TRUE(grad1.equalsTo(updateExp1)); + ASSERT_TRUE(initV.equalsTo(stateV1)); + ASSERT_TRUE(initH.equalsTo(stateH1)); + ASSERT_TRUE(initM.equalsTo(stateM1)); + + NDArray grad2('c', {1, 5}, + {0.46250319480895996, 0.09698919206857681, 0.21754667162895203, + 0.46824514865875244, 0.6005083918571472}, + DataType::FLOAT32); + + status = op.execute( + {&grad2, &initV, &initM, &initH, &lr, &beta1, &beta2, &epsilon}, + {&grad2, &initV, &initM, &initH}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp2( + 'c', {1, 5}, + {0.00154098993679222, 0.00103399135000281, 0.00147364850040774, + 0.00149693641196572, 0.00155078467854623}, + DataType::FLOAT32); + NDArray stateV2( + 'c', {1, 5}, + {0.00095136576551408, 0.00090012878699251, 0.00031702550183538, + 0.00146957092922632, 0.0014111009234709}, + DataType::FLOAT32); + NDArray stateH2( + 'c', {1, 5}, + {0.00095136576551408, 0.00090012878699251, 0.00031702550183538, + 0.00146957092922632, 0.0014111009234709}, + DataType::FLOAT32); + NDArray stateM2('c', {1, 5}, + {0.1503050952553749, 0.09810008324682712, 0.08297385150194167, + 0.1814674522280693, 0.1842174028754234}, + DataType::FLOAT32); + + ASSERT_TRUE(grad2.equalsTo(updateExp2)); + ASSERT_TRUE(initV.equalsTo(stateV2)); + ASSERT_TRUE(initH.equalsTo(stateH2)); + ASSERT_TRUE(initM.equalsTo(stateM2)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests18, TestUpdaterAmsGrad3) { - - NDArray gradC('c', { 1, 5 }, { 1, 2, 3, 4, 5 }, DataType::FLOAT32); - NDArray initVC('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); - NDArray initMC('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); - NDArray initHC('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); - - NDArray grad('f', { 1, 5 }, DataType::FLOAT32); - NDArray initV('f', { 1, 5 }, DataType::FLOAT32); - NDArray initM('f', { 1, 5 }, DataType::FLOAT32); - NDArray initH('f', { 1, 5 }, DataType::FLOAT32); - - grad.assign(gradC); - initV.assign(initVC); - initM.assign(initMC); - initH.assign(initHC); - - sd::ops::ams_grad_updater op; - auto results = op.evaluate({ &grad, &initV, &initM, &initH }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); - - NDArray updateC('c', { 1, 5 }, { 0.00099999968377233, 0.00099999984188614, 0.00099999989459076, 0.00099999992094306, 0.00099999993675445 }, DataType::FLOAT32); - NDArray update('f', { 1, 5 }, DataType::FLOAT32); - - NDArray stateV0C('c', { 1, 5 }, { 0.001, 0.004, 0.00900000000000001, 0.01600000000000001, 0.02500000000000002 }, DataType::FLOAT32); - NDArray stateV('f', { 1, 5 }, DataType::FLOAT32); - - NDArray stateM0C('c', { 1, 5 }, { 0.09999999999999998, 0.19999999999999996, 0.29999999999999993, 0.3999999999999999, 0.4999999999999999 }, DataType::FLOAT32); - NDArray stateM('f', { 1, 5 }, DataType::FLOAT32); - - NDArray stateH0C('c', { 1, 5 }, { 0.001, 0.004, 0.00900000000000001, 0.01600000000000001, 0.02500000000000002 }, DataType::FLOAT32); - NDArray stateH('f', { 1, 5 }, DataType::FLOAT32); - - update.assign(updateC); - stateV.assign(stateV0C); - stateM.assign(stateM0C); - stateH.assign(stateH0C); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - ASSERT_TRUE(update.isSameShape(results.at(0))); - ASSERT_TRUE(update.equalsTo(results.at(0))); - ASSERT_TRUE(stateV.isSameShape(results.at(1))); - ASSERT_TRUE(stateV.equalsTo(results.at(1))); - ASSERT_TRUE(stateM.isSameShape(results.at(2))); - ASSERT_TRUE(stateM.equalsTo(results.at(2))); - ASSERT_TRUE(stateH.isSameShape(results.at(3))); - ASSERT_TRUE(stateH.equalsTo(results.at(3))); - - results = op.evaluate({ &grad, &stateV, &stateM, &stateH }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - NDArray update1C('c', { 1, 5 }, { 0.00134383858541481, 0.00134383873569809, 0.00134383878579252, 0.00134383881083974, 0.00134383882586807 }, DataType::FLOAT32); - NDArray stateV1C('c', { 1, 5 }, { 0.001999, 0.00799600000000001, 0.01799100000000001, 0.03198400000000003, 0.04997500000000005 }, DataType::FLOAT32); - NDArray stateM1C('c', { 1, 5 }, { 0.18999999999999995, 0.3799999999999999, 0.5699999999999998, 0.7599999999999998, 0.9499999999999997 }, DataType::FLOAT32); - NDArray stateH1C('c', { 1, 5 }, { 0.001999, 0.00799600000000001, 0.01799100000000001, 0.03198400000000003, 0.04997500000000005 }, DataType::FLOAT32); - - - update.assign(update1C); - stateV.assign(stateV1C); - stateM.assign(stateM1C); - stateH.assign(stateH1C); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - ASSERT_TRUE(update.isSameShape(results.at(0))); - ASSERT_TRUE(update.equalsTo(results.at(0))); - ASSERT_TRUE(stateV.isSameShape(results.at(1))); - ASSERT_TRUE(stateV.equalsTo(results.at(1))); - ASSERT_TRUE(stateM.isSameShape(results.at(2))); - ASSERT_TRUE(stateM.equalsTo(results.at(2))); - ASSERT_TRUE(stateH.isSameShape(results.at(3))); - ASSERT_TRUE(stateH.equalsTo(results.at(3))); - - results = op.evaluate({ &grad, &stateV, &stateM, &stateH }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); - - - NDArray update2C('c', { 1, 5 }, { 0.00156540157923389, 0.00156540172220632, 0.0015654017698638, 0.00156540179369254, 0.00156540180798979 }, DataType::FLOAT32); - NDArray stateV2C('c', { 1, 5 }, { 0.002997001, 0.01198800400000001, 0.02697300900000002, 0.04795201600000004, 0.07492502500000006 }, DataType::FLOAT32); - NDArray stateM2C('c', { 1, 5 }, { 0.2709999999999999, 0.5419999999999998, 0.8129999999999998, 1.0839999999999996, 1.3549999999999995 }, DataType::FLOAT32); - NDArray stateH2C('c', { 1, 5 }, { 0.002997001, 0.01198800400000001, 0.02697300900000002, 0.04795201600000004, 0.07492502500000006 }, DataType::FLOAT32); - - - update.assign(update2C); - stateV.assign(stateV2C); - stateM.assign(stateM2C); - stateH.assign(stateH2C); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - ASSERT_TRUE(update.isSameShape(results.at(0))); - ASSERT_TRUE(update.equalsTo(results.at(0))); - ASSERT_TRUE(stateV.isSameShape(results.at(1))); - ASSERT_TRUE(stateV.equalsTo(results.at(1))); - ASSERT_TRUE(stateM.isSameShape(results.at(2))); - ASSERT_TRUE(stateM.equalsTo(results.at(2))); - ASSERT_TRUE(stateH.isSameShape(results.at(3))); - ASSERT_TRUE(stateH.equalsTo(results.at(3))); + NDArray gradC('c', {1, 5}, {1, 2, 3, 4, 5}, DataType::FLOAT32); + NDArray initVC('c', {1, 5}, {0.0, 0.0, 0.0, 0.0, 0.0}, DataType::FLOAT32); + NDArray initMC('c', {1, 5}, {0.0, 0.0, 0.0, 0.0, 0.0}, DataType::FLOAT32); + NDArray initHC('c', {1, 5}, {0.0, 0.0, 0.0, 0.0, 0.0}, DataType::FLOAT32); + + NDArray grad('f', {1, 5}, DataType::FLOAT32); + NDArray initV('f', {1, 5}, DataType::FLOAT32); + NDArray initM('f', {1, 5}, DataType::FLOAT32); + NDArray initH('f', {1, 5}, DataType::FLOAT32); + + grad.assign(gradC); + initV.assign(initVC); + initM.assign(initMC); + initH.assign(initHC); + + sd::ops::ams_grad_updater op; + auto results = op.evaluate({&grad, &initV, &initM, &initH}, + {0.001f, 0.9f, 0.999f, 1.0e-8}, {}); + + NDArray updateC( + 'c', {1, 5}, + {0.00099999968377233, 0.00099999984188614, 0.00099999989459076, + 0.00099999992094306, 0.00099999993675445}, + DataType::FLOAT32); + NDArray update('f', {1, 5}, DataType::FLOAT32); + + NDArray stateV0C('c', {1, 5}, + {0.001, 0.004, 0.00900000000000001, 0.01600000000000001, + 0.02500000000000002}, + DataType::FLOAT32); + NDArray stateV('f', {1, 5}, DataType::FLOAT32); + + NDArray stateM0C( + 'c', {1, 5}, + {0.09999999999999998, 0.19999999999999996, 0.29999999999999993, + 0.3999999999999999, 0.4999999999999999}, + DataType::FLOAT32); + NDArray stateM('f', {1, 5}, DataType::FLOAT32); + + NDArray stateH0C('c', {1, 5}, + {0.001, 0.004, 0.00900000000000001, 0.01600000000000001, + 0.02500000000000002}, + DataType::FLOAT32); + NDArray stateH('f', {1, 5}, DataType::FLOAT32); + + update.assign(updateC); + stateV.assign(stateV0C); + stateM.assign(stateM0C); + stateH.assign(stateH0C); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); + ASSERT_TRUE(stateV.isSameShape(results.at(1))); + ASSERT_TRUE(stateV.equalsTo(results.at(1))); + ASSERT_TRUE(stateM.isSameShape(results.at(2))); + ASSERT_TRUE(stateM.equalsTo(results.at(2))); + ASSERT_TRUE(stateH.isSameShape(results.at(3))); + ASSERT_TRUE(stateH.equalsTo(results.at(3))); + + results = op.evaluate({&grad, &stateV, &stateM, &stateH}, + {0.001f, 0.9f, 0.999f, 1.0e-8}, {}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + NDArray update1C( + 'c', {1, 5}, + {0.00134383858541481, 0.00134383873569809, 0.00134383878579252, + 0.00134383881083974, 0.00134383882586807}, + DataType::FLOAT32); + NDArray stateV1C('c', {1, 5}, + {0.001999, 0.00799600000000001, 0.01799100000000001, + 0.03198400000000003, 0.04997500000000005}, + DataType::FLOAT32); + NDArray stateM1C('c', {1, 5}, + {0.18999999999999995, 0.3799999999999999, 0.5699999999999998, + 0.7599999999999998, 0.9499999999999997}, + DataType::FLOAT32); + NDArray stateH1C('c', {1, 5}, + {0.001999, 0.00799600000000001, 0.01799100000000001, + 0.03198400000000003, 0.04997500000000005}, + DataType::FLOAT32); + + update.assign(update1C); + stateV.assign(stateV1C); + stateM.assign(stateM1C); + stateH.assign(stateH1C); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); + ASSERT_TRUE(stateV.isSameShape(results.at(1))); + ASSERT_TRUE(stateV.equalsTo(results.at(1))); + ASSERT_TRUE(stateM.isSameShape(results.at(2))); + ASSERT_TRUE(stateM.equalsTo(results.at(2))); + ASSERT_TRUE(stateH.isSameShape(results.at(3))); + ASSERT_TRUE(stateH.equalsTo(results.at(3))); + + results = op.evaluate({&grad, &stateV, &stateM, &stateH}, + {0.001f, 0.9f, 0.999f, 1.0e-8}, {}); + + NDArray update2C( + 'c', {1, 5}, + {0.00156540157923389, 0.00156540172220632, 0.0015654017698638, + 0.00156540179369254, 0.00156540180798979}, + DataType::FLOAT32); + NDArray stateV2C('c', {1, 5}, + {0.002997001, 0.01198800400000001, 0.02697300900000002, + 0.04795201600000004, 0.07492502500000006}, + DataType::FLOAT32); + NDArray stateM2C('c', {1, 5}, + {0.2709999999999999, 0.5419999999999998, 0.8129999999999998, + 1.0839999999999996, 1.3549999999999995}, + DataType::FLOAT32); + NDArray stateH2C('c', {1, 5}, + {0.002997001, 0.01198800400000001, 0.02697300900000002, + 0.04795201600000004, 0.07492502500000006}, + DataType::FLOAT32); + + update.assign(update2C); + stateV.assign(stateV2C); + stateM.assign(stateM2C); + stateH.assign(stateH2C); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); + ASSERT_TRUE(stateV.isSameShape(results.at(1))); + ASSERT_TRUE(stateV.equalsTo(results.at(1))); + ASSERT_TRUE(stateM.isSameShape(results.at(2))); + ASSERT_TRUE(stateM.equalsTo(results.at(2))); + ASSERT_TRUE(stateH.isSameShape(results.at(3))); + ASSERT_TRUE(stateH.equalsTo(results.at(3))); } diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests19.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests19.cpp index d3d1deed8932..cb05a3e632b7 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests19.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests19.cpp @@ -14,30 +14,28 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - // // @author raver119@gmail.com // -#include "testlayers.h" -#include #include -#include #include -#include #include +#include +#include +#include -using namespace sd; +#include "testlayers.h" +using namespace sd; class DeclarableOpsTests19 : public testing::Test { -public: - - DeclarableOpsTests19() { - printf("\n"); - fflush(stdout); - } + public: + DeclarableOpsTests19() { + printf("\n"); + fflush(stdout); + } }; @@ -54,195 +52,199 @@ TEST_F(DeclarableOpsTests19, test_argmax_maxint_vector_1) { TEST_F(DeclarableOpsTests19, test_threshold_encode_1) { - auto x = NDArrayFactory::create('c', {3}, {1.5, 2.5, -3.5}); - auto exp_encoded = NDArrayFactory::create('c', {7}, {3, 3, 1056964608, 0, 1, 2, -3}); - auto exp_gradients = NDArrayFactory::create('c', {3}, {1.0, 2.0, -3.0}); + auto x = NDArrayFactory::create('c', {3}, {1.5, 2.5, -3.5}); + auto exp_encoded = + NDArrayFactory::create('c', {7}, {3, 3, 1056964608, 0, 1, 2, -3}); + auto exp_gradients = + NDArrayFactory::create('c', {3}, {1.0, 2.0, -3.0}); - sd::ops::encode_threshold op; - auto result = op.evaluate({&x}, {0.5}); + sd::ops::encode_threshold op; + auto result = op.evaluate({&x}, {0.5}); - auto gradients = result.at(0); - auto encoded = result.at(1); + auto gradients = result.at(0); + auto encoded = result.at(1); - //encoded->printIndexedBuffer("ENC"); + // encoded->printIndexedBuffer("ENC"); - ASSERT_EQ(exp_encoded, *encoded); - ASSERT_EQ(exp_gradients, x); + ASSERT_EQ(exp_encoded, encoded); + ASSERT_EQ(exp_gradients, x); - // FIXME: we need to add a way to declare individual inplace outputs - //ASSERT_EQ(exp_gradients, *gradients); + // FIXME: we need to add a way to declare individual inplace outputs + // ASSERT_EQ(exp_gradients, *gradients); } TEST_F(DeclarableOpsTests19, test_threshold_encode_2) { - for (int length = 5; length < 35; length++) { - auto x = NDArrayFactory::create('c', {10000}); - auto exp_gradients = NDArrayFactory::create('c', {10000}); + for (int length = 5; length < 35; length++) { + auto x = NDArrayFactory::create('c', {10000}); + auto exp_gradients = NDArrayFactory::create('c', {10000}); - for (int e = 0; e < length; e++) { - x.p(e, 2e-3); - exp_gradients.p(e, 1e-3); - } + for (int e = 0; e < length; e++) { + x.p(e, 2e-3); + exp_gradients.p(e, 1e-3); + } - sd::ops::encode_threshold op; - auto result = op.evaluate({&x}, {1e-3}); + sd::ops::encode_threshold op; + auto result = op.evaluate({&x}, {1e-3}); - auto encoded = result.at(1); + auto encoded = result.at(1); - ASSERT_EQ(length + 4, encoded->lengthOf()); - ASSERT_EQ(exp_gradients, x); - } + ASSERT_EQ(length + 4, encoded.lengthOf()); + ASSERT_EQ(exp_gradients, x); + } } TEST_F(DeclarableOpsTests19, test_threshold_encode_boundary_1) { - auto x = NDArrayFactory::create('c', {6}); - x = 1.0f; + auto x = NDArrayFactory::create('c', {6}); + x = 1.0f; - sd::ops::encode_threshold op; - auto result = op.evaluate({&x}, {1.0}, {3}); + sd::ops::encode_threshold op; + auto result = op.evaluate({&x}, {1.0}, {3}); - auto gradients = result.at(0); - auto encoded = result.at(1); + auto gradients = result.at(0); + auto encoded = result.at(1); - ASSERT_EQ(7, encoded->lengthOf()); - ASSERT_EQ(3, x.sumNumber().e(0)); + ASSERT_EQ(7, encoded.lengthOf()); + ASSERT_EQ(3, x.sumNumber().e(0)); } TEST_F(DeclarableOpsTests19, test_threshold_encode_boundary_2) { - auto x = NDArrayFactory::create('c', {1000}); - x = 1.0f; + auto x = NDArrayFactory::create('c', {1000}); + x = 1.0f; - sd::ops::encode_threshold op; - auto result = op.evaluate({&x}, {1.0}, {100}); + sd::ops::encode_threshold op; + auto result = op.evaluate({&x}, {1.0}, {100}); - auto gradients = result.at(0); - auto encoded = result.at(1); + auto gradients = result.at(0); + auto encoded = result.at(1); - ASSERT_EQ(104, encoded->lengthOf()); + ASSERT_EQ(104, encoded.lengthOf()); - ASSERT_EQ(900, x.sumNumber().e(0)); + ASSERT_EQ(900, x.sumNumber().e(0)); } TEST_F(DeclarableOpsTests19, test_threshold_decode_1) { - auto x = NDArrayFactory::create('c', {3}, {1.0, 2.0, -3.0}); - auto y = NDArrayFactory::create('c', {7}, {3, 3, 1056964608, 0, 1, 2, -3}); - auto exp_gradients = NDArrayFactory::create('c', {3}, {1.5, 2.5, -3.5}); - - sd::ops::decode_threshold op; - auto status = op.execute({&x, &y}, {&x}); - ASSERT_EQ(Status::OK(), status); - ASSERT_EQ(exp_gradients, x); + auto x = NDArrayFactory::create('c', {3}, {1.0, 2.0, -3.0}); + auto y = + NDArrayFactory::create('c', {7}, {3, 3, 1056964608, 0, 1, 2, -3}); + auto exp_gradients = + NDArrayFactory::create('c', {3}, {1.5, 2.5, -3.5}); + + sd::ops::decode_threshold op; + auto status = op.execute({&x, &y}, {&x}); + ASSERT_EQ(Status::OK(), status); + ASSERT_EQ(exp_gradients, x); } TEST_F(DeclarableOpsTests19, test_bitmap_encode_1) { - auto initial = NDArrayFactory::create('c', {6}, {0.0f, 0.0f, 1e-3f, -1e-3f, 0.0f, 0.0f}); - auto exp_0 = initial.like(); - auto exp_1 = initial.dup(); - auto exp_c = NDArrayFactory::create(2L); + auto initial = NDArrayFactory::create( + 'c', {6}, {0.0f, 0.0f, 1e-3f, -1e-3f, 0.0f, 0.0f}); + auto exp_0 = initial.like(); + auto exp_1 = initial.dup(); + auto exp_c = NDArrayFactory::create(2L); - sd::ops::encode_bitmap enc; - auto enc_result = enc.evaluate({&initial}, {1e-3f}); - ASSERT_EQ(Status::OK(), enc_result.status()); + sd::ops::encode_bitmap enc; + auto enc_result = enc.evaluate({&initial}, {1e-3f}); + ASSERT_EQ(Status::OK(), enc_result.status()); - //initial.printIndexedBuffer("initial"); - ASSERT_EQ(exp_0, initial); + // initial.printIndexedBuffer("initial"); + ASSERT_EQ(exp_0, initial); - auto encoded = enc_result.at(1); - auto counter = enc_result.at(2); - - //encoded->printIndexedBuffer("encoded"); + auto encoded = enc_result.at(1); + auto counter = enc_result.at(2); - ASSERT_EQ(exp_c, *counter); + // encoded->printIndexedBuffer("encoded"); - sd::ops::decode_bitmap dec; - auto status = dec.execute({&initial, encoded}, {&initial}); - ASSERT_EQ(Status::OK(), status); + ASSERT_EQ(exp_c, counter); + sd::ops::decode_bitmap dec; + auto status = dec.execute({&initial, &encoded}, {&initial}); + ASSERT_EQ(Status::OK(), status); - //initial.printIndexedBuffer(); + // initial.printIndexedBuffer(); - ASSERT_EQ(exp_1, initial); + ASSERT_EQ(exp_1, initial); } TEST_F(DeclarableOpsTests19, test_bitmap_encode_decode) { - auto initial = NDArrayFactory::create('c', {256000}); - initial = 1.0f; - auto exp = initial.dup(); - auto neg = initial.like(); - neg = 0.5f; - - sd::ops::encode_bitmap enc; - auto enc_result = enc.evaluate({&initial}, {0.5f}); - auto encoded = enc_result.at(1); - - // checking equality of all encoded bits - for (int e = 5; e < encoded->lengthOf() - 1; e++) { - if (encoded->e(e) != encoded->e(e - 1)) - nd4j_printf("Non equal encoded values at E[%i]: %i;\n", e, encoded->e(e)); - } + auto initial = NDArrayFactory::create('c', {256000}); + initial = 1.0f; + auto exp = initial.dup(); + auto neg = initial.like(); + neg = 0.5f; - ASSERT_NE(exp, initial); - ASSERT_EQ(neg, initial); + sd::ops::encode_bitmap enc; + auto enc_result = enc.evaluate({&initial}, {0.5f}); + auto encoded = enc_result.at(1); - sd::ops::decode_bitmap dec; - auto status = dec.execute({&initial, encoded}, {&initial}); - ASSERT_EQ(Status::OK(), status); + // checking equality of all encoded bits + for (int e = 5; e < encoded.lengthOf() - 1; e++) { + if (encoded.e(e) != encoded.e(e - 1)) + nd4j_printf("Non equal encoded values at E[%i]: %i;\n", e, + encoded.e(e)); + } - // checking equality of all dedoded bits - for (int e = 0; e < initial.lengthOf(); e++) { - auto f = initial.e(e); - if (f != 1.0f) - nd4j_printf("initial[%i] = %f\n", e, f); - } + ASSERT_NE(exp, initial); + ASSERT_EQ(neg, initial); + sd::ops::decode_bitmap dec; + auto status = dec.execute({&initial, &encoded}, {&initial}); + ASSERT_EQ(Status::OK(), status); + + // checking equality of all dedoded bits + for (int e = 0; e < initial.lengthOf(); e++) { + auto f = initial.e(e); + if (f != 1.0f) nd4j_printf("initial[%i] = %f\n", e, f); + } - ASSERT_EQ(exp, initial); + ASSERT_EQ(exp, initial); } TEST_F(DeclarableOpsTests19, test_threshold_encode_decode) { - auto initial = NDArrayFactory::create('c', {256000}); - initial = 1.0f; - auto exp = initial.dup(); - auto neg = initial.like(); - neg = 0.5f; - - sd::ops::encode_threshold enc; - auto enc_result = enc.evaluate({&initial}, {0.5f}); - auto encoded = enc_result.at(1); - - ASSERT_EQ(256000 + 4, encoded->lengthOf()); - ASSERT_NE(exp, initial); - - for (int e = 0; e < initial.lengthOf(); e++) { - auto f = initial.e(e); - if (f != 0.5f) { - nd4j_printf("initial[%i] = %f\n", e, f); - throw std::runtime_error(""); - } - } - ASSERT_EQ(neg, initial); + auto initial = NDArrayFactory::create('c', {256000}); + initial = 1.0f; + auto exp = initial.dup(); + auto neg = initial.like(); + neg = 0.5f; - // checking equality of all encoded bits - //for (int e = 5; e < encoded->lengthOf() - 1; e++) { - //if (encoded->e(e) != encoded->e(e - 1) + 1) - //nd4j_printf("Non equal encoded values at E[%i]: %i;\n", e, encoded->e(e)); - //} + sd::ops::encode_threshold enc; + auto enc_result = enc.evaluate({&initial}, {0.5f}); + auto encoded = enc_result.at(1); - sd::ops::decode_threshold dec; - auto status = dec.execute({&initial, encoded}, {&initial}); - ASSERT_EQ(Status::OK(), status); + ASSERT_EQ(256000 + 4, encoded.lengthOf()); + ASSERT_NE(exp, initial); - // checking equality of all dedoded bits - for (int e = 0; e < initial.lengthOf(); e++) { - auto f = initial.e(e); - if (f != 1.0f) - nd4j_printf("initial[%i] = %f\n", e, f); + for (int e = 0; e < initial.lengthOf(); e++) { + auto f = initial.e(e); + if (f != 0.5f) { + nd4j_printf("initial[%i] = %f\n", e, f); + throw std::runtime_error(""); } + } + ASSERT_EQ(neg, initial); + + // checking equality of all encoded bits + // for (int e = 5; e < encoded->lengthOf() - 1; e++) { + // if (encoded->e(e) != encoded->e(e - 1) + 1) + // nd4j_printf("Non equal encoded values at E[%i]: %i;\n", e, + // encoded->e(e)); + //} - ASSERT_EQ(exp, initial); + sd::ops::decode_threshold dec; + auto status = dec.execute({&initial, &encoded}, {&initial}); + ASSERT_EQ(Status::OK(), status); + + // checking equality of all dedoded bits + for (int e = 0; e < initial.lengthOf(); e++) { + auto f = initial.e(e); + if (f != 1.0f) nd4j_printf("initial[%i] = %f\n", e, f); + } + + ASSERT_EQ(exp, initial); } #ifdef _RELEASE TEST_F(DeclarableOpsTests19, test_threshold_encode_decode_2) { +#ifdef _RELEASE // [2,1,135079944,1,1,8192,1,99] auto initial = NDArrayFactory::create('c', {1, 135079944}); initial = 1.0f; @@ -254,7 +256,7 @@ TEST_F(DeclarableOpsTests19, test_threshold_encode_decode_2) { auto enc_result = enc.evaluate({&initial}, {0.5f}); auto encoded = enc_result.at(1); - ASSERT_EQ(135079944 + 4, encoded->lengthOf()); + ASSERT_EQ(135079944 + 4, encoded.lengthOf()); ASSERT_NE(exp, initial); /* for (int e = 0; e < initial.lengthOf(); e++) { @@ -274,7 +276,7 @@ TEST_F(DeclarableOpsTests19, test_threshold_encode_decode_2) { //} sd::ops::decode_threshold dec; - auto status = dec.execute({&initial, encoded}, {&initial}); + auto status = dec.execute({&initial, &encoded}, {&initial}); ASSERT_EQ(Status::OK(), status); // checking equality of all dedoded bits @@ -287,136 +289,135 @@ TEST_F(DeclarableOpsTests19, test_threshold_encode_decode_2) { */ ASSERT_EQ(exp, initial); +#endif } #endif TEST_F(DeclarableOpsTests19, test_matmul_ccc) { - auto x = NDArrayFactory::create('c', {10, 10}); - auto y = NDArrayFactory::create('c', {10, 10}); - auto e = NDArrayFactory::create('c', {10, 10}); - auto z = NDArrayFactory::create('c', {10, 10}); - - z.assign(100.f); - e.assign(110.f); - x.assign(1.0f); - y.assign(1.0f); - - sd::ops::matmul op; - auto status = op.execute({&x, &y}, {&z}, {1.0, 1.0}); - ASSERT_EQ(Status::OK(), status); + auto x = NDArrayFactory::create('c', {10, 10}); + auto y = NDArrayFactory::create('c', {10, 10}); + auto e = NDArrayFactory::create('c', {10, 10}); + auto z = NDArrayFactory::create('c', {10, 10}); + + z.assign(100.f); + e.assign(110.f); + x.assign(1.0f); + y.assign(1.0f); + + sd::ops::matmul op; + auto status = op.execute({&x, &y}, {&z}, {1.0, 1.0}); + ASSERT_EQ(Status::OK(), status); - ASSERT_EQ(e, z); + ASSERT_EQ(e, z); } TEST_F(DeclarableOpsTests19, test_matmul_fcf) { - auto x = NDArrayFactory::create('f', {10, 10}); - auto y = NDArrayFactory::create('c', {10, 10}); - auto e = NDArrayFactory::create('f', {10, 10}); - auto z = NDArrayFactory::create('f', {10, 10}); - - z.assign(100.f); - e.assign(110.f); - x.assign(1.0f); - y.assign(1.0f); - - sd::ops::matmul op; - auto status = op.execute({&x, &y}, {&z}, {1.0, 1.0}); - ASSERT_EQ(Status::OK(), status); + auto x = NDArrayFactory::create('f', {10, 10}); + auto y = NDArrayFactory::create('c', {10, 10}); + auto e = NDArrayFactory::create('f', {10, 10}); + auto z = NDArrayFactory::create('f', {10, 10}); + + z.assign(100.f); + e.assign(110.f); + x.assign(1.0f); + y.assign(1.0f); + + sd::ops::matmul op; + auto status = op.execute({&x, &y}, {&z}, {1.0, 1.0}); + ASSERT_EQ(Status::OK(), status); - ASSERT_EQ(e, z); + ASSERT_EQ(e, z); } TEST_F(DeclarableOpsTests19, test_matmul_cff) { - auto x = NDArrayFactory::create('c', {10, 10}); - auto y = NDArrayFactory::create('f', {10, 10}); - auto e = NDArrayFactory::create('f', {10, 10}); - auto z = NDArrayFactory::create('f', {10, 10}); - - z.assign(100.f); - e.assign(110.f); - x.assign(1.0f); - y.assign(1.0f); - - sd::ops::matmul op; - auto status = op.execute({&x, &y}, {&z}, {1.0, 1.0}); - ASSERT_EQ(Status::OK(), status); + auto x = NDArrayFactory::create('c', {10, 10}); + auto y = NDArrayFactory::create('f', {10, 10}); + auto e = NDArrayFactory::create('f', {10, 10}); + auto z = NDArrayFactory::create('f', {10, 10}); + + z.assign(100.f); + e.assign(110.f); + x.assign(1.0f); + y.assign(1.0f); + + sd::ops::matmul op; + auto status = op.execute({&x, &y}, {&z}, {1.0, 1.0}); + ASSERT_EQ(Status::OK(), status); - ASSERT_EQ(e, z); + ASSERT_EQ(e, z); } - TEST_F(DeclarableOpsTests19, test_matmul_ccf) { - auto x = NDArrayFactory::create('c', {10, 10}); - auto y = NDArrayFactory::create('c', {10, 10}); - auto e = NDArrayFactory::create('f', {10, 10}); - auto z = NDArrayFactory::create('f', {10, 10}); - - z.assign(100.f); - e.assign(110.f); - x.assign(1.0f); - y.assign(1.0f); - - sd::ops::matmul op; - auto status = op.execute({&x, &y}, {&z}, {1.0, 1.0}); - ASSERT_EQ(Status::OK(), status); + auto x = NDArrayFactory::create('c', {10, 10}); + auto y = NDArrayFactory::create('c', {10, 10}); + auto e = NDArrayFactory::create('f', {10, 10}); + auto z = NDArrayFactory::create('f', {10, 10}); + + z.assign(100.f); + e.assign(110.f); + x.assign(1.0f); + y.assign(1.0f); + + sd::ops::matmul op; + auto status = op.execute({&x, &y}, {&z}, {1.0, 1.0}); + ASSERT_EQ(Status::OK(), status); - ASSERT_EQ(e, z); + ASSERT_EQ(e, z); } TEST_F(DeclarableOpsTests19, test_matmul_fff) { - auto x = NDArrayFactory::create('f', {10, 10}); - auto y = NDArrayFactory::create('f', {10, 10}); - auto e = NDArrayFactory::create('f', {10, 10}); - auto z = NDArrayFactory::create('f', {10, 10}); - - z.assign(100.f); - e.assign(110.f); - x.assign(1.0f); - y.assign(1.0f); - - sd::ops::matmul op; - auto status = op.execute({&x, &y}, {&z}, {1.0, 1.0}); - ASSERT_EQ(Status::OK(), status); + auto x = NDArrayFactory::create('f', {10, 10}); + auto y = NDArrayFactory::create('f', {10, 10}); + auto e = NDArrayFactory::create('f', {10, 10}); + auto z = NDArrayFactory::create('f', {10, 10}); + + z.assign(100.f); + e.assign(110.f); + x.assign(1.0f); + y.assign(1.0f); + + sd::ops::matmul op; + auto status = op.execute({&x, &y}, {&z}, {1.0, 1.0}); + ASSERT_EQ(Status::OK(), status); - ASSERT_EQ(e, z); + ASSERT_EQ(e, z); } TEST_F(DeclarableOpsTests19, test_conv1d_bp_1) { - /* - DynamicCustomOp op = DynamicCustomOp.builder("conv1d_bp") - .addInputs( - Nd4j.create(DataType.FLOAT, 2,2,12), - Nd4j.create(DataType.FLOAT, 3,2,3), - Nd4j.create(DataType.FLOAT, 2,3,6) - ) - .addOutputs( - Nd4j.create(DataType.FLOAT, 2,2,12), - Nd4j.create(DataType.FLOAT, 3,2,3)) - .addIntegerArguments(3,2,0,1,2,0) - .build(); - - Nd4j.exec(op); - */ - - auto t = NDArrayFactory::create('c', {2, 2, 12}); - auto u = NDArrayFactory::create('c', {3, 2, 3}); - auto v = NDArrayFactory::create('c', {2, 3, 6}); - - sd::ops::conv1d_bp op; - auto result = op.evaluate({&t, &u, &v}, {3, 2, 0, 1, 2,0}); - ASSERT_EQ(Status::OK(), result.status()); + /* + DynamicCustomOp op = DynamicCustomOp.builder("conv1d_bp") + .addInputs( + Nd4j.create(DataType.FLOAT, 2,2,12), + Nd4j.create(DataType.FLOAT, 3,2,3), + Nd4j.create(DataType.FLOAT, 2,3,6) + ) + .addOutputs( + Nd4j.create(DataType.FLOAT, 2,2,12), + Nd4j.create(DataType.FLOAT, 3,2,3)) + .addIntegerArguments(3,2,0,1,2,0) + .build(); + + Nd4j.exec(op); + */ + auto t = NDArrayFactory::create('c', {2, 2, 12}); + auto u = NDArrayFactory::create('c', {3, 2, 3}); + auto v = NDArrayFactory::create('c', {2, 3, 6}); + + sd::ops::conv1d_bp op; + auto result = op.evaluate({&t, &u, &v}, {3, 2, 0, 1, 2, 0}); + ASSERT_EQ(Status::OK(), result.status()); } TEST_F(DeclarableOpsTests19, test_squeeze_1) { - auto x = NDArrayFactory::create('c', {3, 4, 1}); - auto e = NDArrayFactory::create('c', {3, 4}); - int axis = 2; + auto x = NDArrayFactory::create('c', {3, 4, 1}); + auto e = NDArrayFactory::create('c', {3, 4}); + int axis = 2; - sd::ops::squeeze op; - auto status = op.execute({&x}, {&e}, {axis}); - ASSERT_EQ(Status::OK(), status); + sd::ops::squeeze op; + auto status = op.execute({&x}, {&e}, {axis}); + ASSERT_EQ(Status::OK(), status); } diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests2.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests2.cpp index f4847889b8fb..68fddd6f7a71 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests2.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests2.cpp @@ -14,4166 +14,4290 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -#include "testlayers.h" -#include -#include #include #include +#include +#include +#include "testlayers.h" using namespace sd; using namespace sd::graph; class DeclarableOpsTests2 : public testing::Test { -public: - - DeclarableOpsTests2() { - printf("\n"); - } + public: + DeclarableOpsTests2() { printf("\n"); } }; //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, gather_1) { + NDArray input('c', {2, 3, 4}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}, + sd::DataType::FLOAT32); + NDArray indices('c', {1, 6}, {0, 1, 2, 2, 1, 2}, sd::DataType::INT32); + NDArray expected( + 'c', {2, 1, 6, 4}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 9, 10, 11, 12, + 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, + 21, 22, 23, 24, 21, 22, 23, 24, 17, 18, 19, 20, 21, 22, 23, 24}, + sd::DataType::FLOAT32); - NDArray input('c', {2,3,4}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24}, sd::DataType::FLOAT32); - NDArray indices('c', {1,6}, {0,1, 2,2, 1,2}, sd::DataType::INT32); - NDArray expected('c', {2,1,6,4}, {1, 2, 3, 4, 5, 6, 7, 8, 9,10,11,12, 9,10,11,12, 5, 6, 7, 8, 9,10,11,12, 13,14,15,16, 17,18,19,20, 21,22,23,24, 21,22,23,24, 17,18,19,20, 21,22,23,24}, sd::DataType::FLOAT32); - - sd::ops::gather op; + sd::ops::gather op; - auto result = op.evaluate({&input, &indices}, {1}); + auto result = op.evaluate({&input, &indices}, {1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto* output = result.at(0); + auto output = result.at(0); - ASSERT_TRUE(expected.isSameShapeStrict(*output)); - ASSERT_TRUE(expected.equalsTo(output)); + ASSERT_TRUE(expected.isSameShapeStrict(output)); + ASSERT_TRUE(expected.equalsTo(output)); } TEST_F(DeclarableOpsTests2, gather_2) { + NDArray input('c', {2, 3, 4}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}); + // auto indices ('c', {1,6}, {0,1, 2,2, 1,2}); + NDArray expected( + 'c', {2, 6, 4}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 9, 10, 11, 12, + 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, + 21, 22, 23, 24, 21, 22, 23, 24, 17, 18, 19, 20, 21, 22, 23, 24}); - NDArray input('c', {2,3,4}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24}); - //auto indices ('c', {1,6}, {0,1, 2,2, 1,2}); - NDArray expected('c', {2,6,4}, {1, 2, 3, 4, 5, 6, 7, 8, 9,10,11,12, 9,10,11,12, 5, 6, 7, 8, 9,10,11,12, 13,14,15,16, 17,18,19,20, 21,22,23,24, 21,22,23,24, 17,18,19,20, 21,22,23,24}); + sd::ops::gather op; - sd::ops::gather op; + auto result = op.evaluate({&input}, {}, {1, 0, 1, 2, 2, 1, 2}, {true}); - auto result = op.evaluate({&input}, {}, {1, 0,1, 2,2, 1,2}, {true}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); - auto* output = result.at(0); - - ASSERT_TRUE(expected.isSameShapeStrict(*output)); - ASSERT_TRUE(expected.equalsTo(output)); + ASSERT_TRUE(expected.isSameShapeStrict(output)); + ASSERT_TRUE(expected.equalsTo(output)); } - //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, gather_3) { + NDArray input('c', {2, 3, 4}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}); + NDArray indices('c', {1, 1}, std::vector{2}, sd::DataType::INT32); + NDArray expected('c', {2, 1, 1, 4}, {9, 10, 11, 12, 21, 22, 23, 24}); - NDArray input ('c', {2,3,4}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24}); - NDArray indices ('c', {1,1}, std::vector{2}, sd::DataType::INT32); - NDArray expected('c', {2,1,1,4}, {9,10,11,12,21,22,23,24}); + sd::ops::gather op; - sd::ops::gather op; + auto result = op.evaluate({&input, &indices}, {}, {1}); - auto result = op.evaluate({&input, &indices}, {}, {1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); - auto* output = result.at(0); - - ASSERT_TRUE(expected.isSameShapeStrict(*output)); - ASSERT_TRUE(expected.equalsTo(output)); + ASSERT_TRUE(expected.isSameShapeStrict(output)); + ASSERT_TRUE(expected.equalsTo(output)); } TEST_F(DeclarableOpsTests2, gather_4) { + NDArray input('c', {2, 3, 4}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}); + // auto indices ('c', {1,1}, {2}); + NDArray expected('c', {2, 4}, {9, 10, 11, 12, 21, 22, 23, 24}); - NDArray input('c', {2,3,4}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24}); - //auto indices ('c', {1,1}, {2}); - NDArray expected('c', {2,4}, {9,10,11,12,21,22,23,24}); - - sd::ops::gather op; + sd::ops::gather op; - auto result = op.evaluate({&input}, {}, {1, 2}); + auto result = op.evaluate({&input}, {}, {1, 2}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto* output = result.at(0); + auto output = result.at(0); - ASSERT_TRUE(expected.isSameShapeStrict(*output)); - ASSERT_TRUE(expected.equalsTo(output)); + ASSERT_TRUE(expected.isSameShapeStrict(output)); + ASSERT_TRUE(expected.equalsTo(output)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, gather_5) { + NDArray input('c', {2, 3, 4}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}); + NDArray indices('c', {2, 3}, {0, 1, 2, 2, 1, 2}, sd::DataType::INT32); + NDArray expected( + 'c', {2, 2, 3, 4}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 9, 10, 11, 12, + 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, + 21, 22, 23, 24, 21, 22, 23, 24, 17, 18, 19, 20, 21, 22, 23, 24}); - NDArray input ('c', {2,3,4}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24}); - NDArray indices ('c', {2,3}, {0, 1, 2, 2, 1,2}, sd::DataType::INT32); - NDArray expected('c', {2,2,3,4}, {1, 2, 3, 4, 5, 6, 7, 8, 9,10,11,12, 9,10,11,12, 5, 6, 7, 8, 9,10,11,12, 13,14,15,16,17,18,19,20,21,22,23,24, 21,22,23,24,17,18,19,20,21,22,23,24}); + sd::ops::gather op; - sd::ops::gather op; + auto result = op.evaluate({&input, &indices}, {}, {1}, {true}); - auto result = op.evaluate({&input, &indices}, {}, {1}, {true}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); - auto* output = result.at(0); - - ASSERT_TRUE(expected.isSameShapeStrict(*output)); - ASSERT_TRUE(expected.equalsTo(output)); - - + ASSERT_TRUE(expected.isSameShapeStrict(output)); + ASSERT_TRUE(expected.equalsTo(output)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, gather_6) { + NDArray input( + 'c', {3, 3, 4}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, + 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36}); + NDArray indices('c', {2, 3}, {0, 1, 2, 2, 1, 2}, sd::DataType::INT32); + NDArray expected( + 'c', {2, 3, 3, 4}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, + 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, + 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 13, 14, 15, 16, 17, 18, + 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36}); - NDArray input ('c', {3,3,4}, {1, 2, 3, 4, 5, 6, 7, 8, 9,10,11,12, 13,14,15,16,17,18,19,20,21,22,23,24, 25,26,27,28,29,30,31,32,33,34,35,36}); - NDArray indices ('c', {2,3}, {0, 1, 2, 2, 1,2}, sd::DataType::INT32); - NDArray expected('c', {2,3,3,4}, {1, 2, 3, 4, 5, 6, 7, 8, 9,10,11,12, 13,14,15,16,17,18,19,20,21,22,23,24, 25,26,27,28,29,30,31,32,33,34,35,36, 25,26,27,28,29,30,31,32,33,34,35,36, 13,14,15,16,17,18,19,20,21,22,23,24, 25,26,27,28,29,30,31,32,33,34,35,36}); - - sd::ops::gather op; + sd::ops::gather op; - auto result = op.evaluate({&input, &indices}, {}, {0}); + auto result = op.evaluate({&input, &indices}, {}, {0}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto* output = result.at(0); - - ASSERT_TRUE(expected.isSameShapeStrict(*output)); - ASSERT_TRUE(expected.equalsTo(output)); + auto output = result.at(0); + ASSERT_TRUE(expected.isSameShapeStrict(output)); + ASSERT_TRUE(expected.equalsTo(output)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, gather_7) { + NDArray input('c', {2, 3, 4}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}); + NDArray indices('c', {2, 3}, {0, 1, 2, 2, 1, 2}, sd::DataType::INT64); + NDArray expected( + 'c', {2, 3, 2, 3}, + {1, 2, 3, 3, 2, 3, 5, 6, 7, 7, 6, 7, 9, 10, 11, 11, 10, 11, + 13, 14, 15, 15, 14, 15, 17, 18, 19, 19, 18, 19, 21, 22, 23, 23, 22, 23}); - NDArray input ('c', {2,3,4}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24}); - NDArray indices ('c', {2,3}, {0, 1, 2, 2, 1,2}, sd::DataType::INT64); - NDArray expected('c', {2,3,2,3}, {1, 2, 3, 3, 2, 3, 5, 6, 7, 7, 6, 7, 9,10,11,11,10,11, 13,14,15,15,14,15, 17,18,19,19,18,19, 21,22,23,23,22,23}); - - sd::ops::gather op; + sd::ops::gather op; - auto result = op.evaluate({&input, &indices}, {}, {2}); + auto result = op.evaluate({&input, &indices}, {}, {2}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto* output = result.at(0); + auto output = result.at(0); - ASSERT_TRUE(expected.isSameShapeStrict(*output)); - ASSERT_TRUE(expected.equalsTo(output)); - + ASSERT_TRUE(expected.isSameShapeStrict(output)); + ASSERT_TRUE(expected.equalsTo(output)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, gather_8) { + NDArray input('c', {3, 5}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + sd::DataType::FLOAT32); + NDArray indices('c', {1}, std::vector{2}, sd::DataType::INT32); + NDArray expected('c', {1, 5}, {11, 12, 13, 14, 15.}, sd::DataType::FLOAT32); - NDArray input('c', {3,5}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, sd::DataType::FLOAT32); - NDArray indices('c', {1}, std::vector{2}, sd::DataType::INT32); - NDArray expected('c', {1,5}, {11, 12, 13, 14, 15.}, sd::DataType::FLOAT32); - - sd::ops::gather op; - - auto result = op.evaluate({&input, &indices}, {}, {0}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto* output = result.at(0); - // output->printShapeInfo(); - // output->printIndexedBuffer(); + sd::ops::gather op; - ASSERT_TRUE(expected.isSameShapeStrict(*output)); - ASSERT_TRUE(expected.equalsTo(output)); + auto result = op.evaluate({&input, &indices}, {}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + // output->printShapeInfo(); + // output->printIndexedBuffer(); + ASSERT_TRUE(expected.isSameShapeStrict(output)); + ASSERT_TRUE(expected.equalsTo(output)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, gather_9) { - NDArray x('c', {2, 4, 3, 2}, sd::DataType::FLOAT32); - NDArray indices('c', {2}, std::vector{1, 0}, sd::DataType::INT32); + NDArray x('c', {2, 4, 3, 2}, sd::DataType::FLOAT32); + NDArray indices('c', {2}, std::vector{1, 0}, sd::DataType::INT32); - sd::ops::gather op; - auto result = op.evaluate({&x, &indices}, {}, {-2}); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::gather op; + auto result = op.evaluate({&x, &indices}, {}, {-2}); + ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); + auto z = result.at(0); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, gather_10) { - NDArray x('c', {2, 2}, {1, 2, 3, 4}); - NDArray e('c', {2, 2}, {3, 4, 1, 2}); + NDArray x('c', {2, 2}, {1, 2, 3, 4}); + NDArray e('c', {2, 2}, {3, 4, 1, 2}); - sd::ops::gather op; - auto result = op.evaluate({&x}, {}, {0, 1, 0}); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::gather op; + auto result = op.evaluate({&x}, {}, {0, 1, 0}); + ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); - - ASSERT_TRUE(e.isSameShape(z)); - ASSERT_TRUE(e.equalsTo(z)); + auto z = result.at(0); + ASSERT_TRUE(e.isSameShape(z)); + ASSERT_TRUE(e.equalsTo(z)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, gather_11) { + NDArray x('c', {2, 2}, {1, 2, 3, 4}); + NDArray indices('c', {2}, std::vector{1, 0}, sd::DataType::INT64); + NDArray e('c', {2, 2}, {3, 4, 1, 2}); - NDArray x('c', {2, 2}, {1, 2, 3, 4}); - NDArray indices('c', {2}, std::vector{1, 0}, sd::DataType::INT64); - NDArray e('c', {2, 2}, {3, 4, 1, 2}); - - sd::ops::gather op; - auto result = op.evaluate({&x, &indices}, {}, {0}); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::gather op; + auto result = op.evaluate({&x, &indices}, {}, {0}); + ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_TRUE(e.isSameShape(z)); - ASSERT_TRUE(e.equalsTo(z)); - + ASSERT_TRUE(e.isSameShape(z)); + ASSERT_TRUE(e.equalsTo(z)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, gather_12) { + NDArray input('c', {4}, {2.f, 3.f, 4.f, 5.f}); + NDArray indices('c', {2}, {0, 2}, sd::DataType::INT32); + NDArray exp('c', {2}, {2.f, 4.f}); - NDArray input('c', {4}, {2.f, 3.f, 4.f, 5.f}); - NDArray indices('c', {2}, {0, 2}, sd::DataType::INT32); - NDArray exp('c', {2}, {2.f, 4.f}); - - sd::ops::gather op; - auto result = op.evaluate({&input, &indices}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::gather op; + auto result = op.evaluate({&input, &indices}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, gather_13) { - - NDArray input ('c', {2,3,4,5}, sd::DataType::DOUBLE); - NDArray indices ('c', {2,3,4}, {0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3,0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3}, sd::DataType::INT32); - NDArray expected('c', {2,3, 2,3,4, 5}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, - 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, - 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, - 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, - 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, - 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, - 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, - 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, - 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, - 100,101,102,103,104, 105,106,107,108,109, 110,111,112,113,114, 115,116,117,118,119, 100,101,102,103,104, 105,106,107,108,109, 110,111,112,113,114, 115,116,117,118,119, 100,101,102,103,104, 105,106,107,108,109, 110,111,112,113,114, 115,116,117,118,119, - 100,101,102,103,104, 105,106,107,108,109, 110,111,112,113,114, 115,116,117,118,119, 100,101,102,103,104, 105,106,107,108,109, 110,111,112,113,114, 115,116,117,118,119, 100,101,102,103,104, 105,106,107,108,109, 110,111,112,113,114, 115,116,117,118,119}); - - input.linspace(0); - - sd::ops::gather op; - - auto result = op.evaluate({&input, &indices}, {}, {2}, {true}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto* output = result.at(0); - - ASSERT_TRUE(expected.isSameShapeStrict(*output)); - ASSERT_TRUE(expected.equalsTo(output)); - + NDArray input('c', {2, 3, 4, 5}, sd::DataType::DOUBLE); + NDArray indices('c', {2, 3, 4}, {0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, + 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3}, + sd::DataType::INT32); + NDArray expected( + 'c', {2, 3, 2, 3, 4, 5}, + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, + 14, 15, 16, 17, 18, 19, 0, 1, 2, 3, 4, 5, 6, 7, + 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 0, 1, + 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16, 17, 18, 19, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 0, 1, 2, 3, + 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, + 18, 19, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, + 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, + 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, + 34, 35, 36, 37, 38, 39, 20, 21, 22, 23, 24, 25, 26, 27, + 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 20, 21, + 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, + 36, 37, 38, 39, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 20, 21, 22, 23, + 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, + 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, + 52, 53, 54, 55, 56, 57, 58, 59, 40, 41, 42, 43, 44, 45, + 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, + 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, + 54, 55, 56, 57, 58, 59, 40, 41, 42, 43, 44, 45, 46, 47, + 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 40, 41, + 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, + 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, + 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, + 78, 79, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, + 72, 73, 74, 75, 76, 77, 78, 79, 60, 61, 62, 63, 64, 65, + 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, + 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, + 74, 75, 76, 77, 78, 79, 60, 61, 62, 63, 64, 65, 66, 67, + 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 60, 61, + 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, + 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, + 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 80, 81, 82, 83, + 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, + 98, 99, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, + 92, 93, 94, 95, 96, 97, 98, 99, 80, 81, 82, 83, 84, 85, + 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, + 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, + 94, 95, 96, 97, 98, 99, 80, 81, 82, 83, 84, 85, 86, 87, + 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, + 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, + 116, 117, 118, 119, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, + 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 100, 101, 102, 103, + 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, + 118, 119, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, + 112, 113, 114, 115, 116, 117, 118, 119, 100, 101, 102, 103, 104, 105, + 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, + 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, + 114, 115, 116, 117, 118, 119}); + + input.linspace(0); + + sd::ops::gather op; + + auto result = op.evaluate({&input, &indices}, {}, {2}, {true}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto output = result.at(0); + + ASSERT_TRUE(expected.isSameShapeStrict(output)); + ASSERT_TRUE(expected.equalsTo(output)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, gather_14) { + NDArray input('c', {2, 3, 4}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}); + NDArray indices('c', {2, 3}, {0, 10, 2, 20, 1, 2}, sd::DataType::INT32); + NDArray output('c', {2, 2, 3, 4}); - NDArray input ('c', {2,3,4}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24}); - NDArray indices ('c', {2,3}, {0, 10, 2, 20, 1,2}, sd::DataType::INT32); - NDArray output('c', {2,2,3,4}); + sd::ops::gather op; - sd::ops::gather op; - - ASSERT_ANY_THROW(op.execute({&input, &indices}, {&output}, {}, {1}, {true})); + ASSERT_ANY_THROW(op.execute({&input, &indices}, {&output}, {}, {1}, {true})); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, gather_15) { + NDArray input('c', {2, 3, 4, 5}, sd::DataType::DOUBLE); + NDArray indices('c', {2, 3, 4}, {0, 10, 2, 3, 0, 1, 20, 3, 0, 1, 2, 3, + 0, 1, 2, 3, 0, 1, 2, 30, 0, 1, 2, 3}, + sd::DataType::INT32); + NDArray output('c', {2, 3, 2, 3, 4, 5}); - NDArray input ('c', {2,3,4,5}, sd::DataType::DOUBLE); - NDArray indices ('c', {2,3,4}, {0, 10, 2, 3, 0, 1, 20, 3, 0, 1, 2, 3,0, 1, 2, 3, 0, 1, 2, 30, 0, 1, 2, 3}, sd::DataType::INT32); - NDArray output('c', {2,3, 2,3,4, 5}); - - sd::ops::gather op; + sd::ops::gather op; - ASSERT_ANY_THROW(op.execute({&input, &indices}, {&output}, {}, {2}, {true})); + ASSERT_ANY_THROW(op.execute({&input, &indices}, {&output}, {}, {2}, {true})); } TEST_F(DeclarableOpsTests2, BroadcastGradientArgs_1) { + NDArray input( + 'c', {3, 3, 4}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, + 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36}, + sd::DataType::INT32); + NDArray indices('c', {2, 3}, {0, 1, 2, 2, 1, 2}, sd::DataType::INT32); - NDArray input ('c', {3,3,4}, {1, 2, 3, 4, 5, 6, 7, 8, 9,10,11,12, 13,14,15,16,17,18,19,20,21,22,23,24, 25,26,27,28,29,30,31,32,33,34,35,36}, sd::DataType::INT32); - NDArray indices ('c', {2,3}, {0, 1, 2, 2, 1,2}, sd::DataType::INT32); + sd::ops::broadcastgradientargs op; - sd::ops::broadcastgradientargs op; + auto result = op.evaluate({&input, &indices}, {}, {}); - auto result = op.evaluate({&input, &indices}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_KERNEL_FAILURE, result.status()); - + ASSERT_EQ(ND4J_STATUS_KERNEL_FAILURE, result.status()); } TEST_F(DeclarableOpsTests2, NLP_Cbow_Test_1) { - auto exp0 = NDArrayFactory::create('c', {1, 10}); - auto exp1 = NDArrayFactory::create('c', {1, 10}); - auto exp2 = NDArrayFactory::create('c', {1, 10}); - - exp0.assign(0.0095); - exp1.assign(0.019875); - exp2.assign(0.02); - - auto target = NDArrayFactory::create(0); - auto ngStarter = NDArrayFactory::empty(); - auto context = NDArrayFactory::create('c', {3}, {0, 1, 2}); - auto locked = NDArrayFactory::create('c', {3}); - auto indices = NDArrayFactory::create('c', {2}, {4, 5}); - auto codes = NDArrayFactory::create('c', {2}, {1, 1}); - auto syn0 = NDArrayFactory::create('c', {100, 10}); - auto syn1 = NDArrayFactory::create('c', {100, 10}); - auto syn1Neg = NDArrayFactory::empty(); - auto expTable = NDArrayFactory::create('c', {10000}); - auto negTable = NDArrayFactory::empty(); - auto numWords = NDArrayFactory::create('c', {1}, {1}); - - syn0.assign(0.01); - syn1.assign(0.02); - expTable.assign(0.5); - - auto alpha = NDArrayFactory::create(0.025); - auto randomValue = NDArrayFactory::create(2L); - auto inferenceVector = NDArrayFactory::empty(); - - sd::ops::cbow op; - auto result = op.evaluate({&target, &ngStarter, &context, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &numWords, &locked, &inferenceVector}, {}, {}, {true}, {}, true); - ASSERT_EQ(Status::OK(), result.status()); - - auto row_s0_0 = syn0({0,1, 0,0}, true); - auto row_s0_1 = syn0({1,2, 0,0}, true); - auto row_s0_2 = syn0({2,3, 0,0}, true); - - auto row_s1_4 = syn1({4,5, 0,0}, true); - auto row_s1_5 = syn1({5,6, 0,0}, true); - auto row_s1_6 = syn1({6,7, 0,0}, true); - - ASSERT_EQ(exp0, row_s0_0); - ASSERT_EQ(exp0, row_s0_1); - ASSERT_EQ(exp0, row_s0_2); - - ASSERT_EQ(exp1, row_s1_4); - ASSERT_EQ(exp1, row_s1_5); - ASSERT_EQ(exp2, row_s1_6); - + auto exp0 = NDArrayFactory::create('c', {1, 10}); + auto exp1 = NDArrayFactory::create('c', {1, 10}); + auto exp2 = NDArrayFactory::create('c', {1, 10}); + + exp0.assign(0.0095); + exp1.assign(0.019875); + exp2.assign(0.02); + + auto target = NDArrayFactory::create(0); + auto ngStarter = NDArrayFactory::empty(); + auto context = NDArrayFactory::create('c', {3}, {0, 1, 2}); + auto locked = NDArrayFactory::create('c', {3}); + auto indices = NDArrayFactory::create('c', {2}, {4, 5}); + auto codes = NDArrayFactory::create('c', {2}, {1, 1}); + auto syn0 = NDArrayFactory::create('c', {100, 10}); + auto syn1 = NDArrayFactory::create('c', {100, 10}); + auto syn1Neg = NDArrayFactory::empty(); + auto expTable = NDArrayFactory::create('c', {10000}); + auto negTable = NDArrayFactory::empty(); + auto numWords = NDArrayFactory::create('c', {1}, {1}); + + syn0.assign(0.01); + syn1.assign(0.02); + expTable.assign(0.5); + + auto alpha = NDArrayFactory::create(0.025); + auto randomValue = NDArrayFactory::create(2L); + auto inferenceVector = NDArrayFactory::empty(); + + sd::ops::cbow op; + auto result = + op.evaluate({&target, &ngStarter, &context, &indices, &codes, &syn0, + &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, + &numWords, &locked, &inferenceVector}, + {}, {}, {true}, {}, true); + ASSERT_EQ(Status::OK(), result.status()); + + auto row_s0_0 = syn0({0, 1, 0, 0}, true); + auto row_s0_1 = syn0({1, 2, 0, 0}, true); + auto row_s0_2 = syn0({2, 3, 0, 0}, true); + + auto row_s1_4 = syn1({4, 5, 0, 0}, true); + auto row_s1_5 = syn1({5, 6, 0, 0}, true); + auto row_s1_6 = syn1({6, 7, 0, 0}, true); + + ASSERT_EQ(exp0, row_s0_0); + ASSERT_EQ(exp0, row_s0_1); + ASSERT_EQ(exp0, row_s0_2); + + ASSERT_EQ(exp1, row_s1_4); + ASSERT_EQ(exp1, row_s1_5); + ASSERT_EQ(exp2, row_s1_6); } TEST_F(DeclarableOpsTests2, Test_Squeeze_1) { - auto x = NDArrayFactory::create('c', {2, 1, 3, 1, 1, 1, 4}); - x.linspace(1); - auto exp = x.reshape('c', {2, 3, 4}); + auto x = NDArrayFactory::create('c', {2, 1, 3, 1, 1, 1, 4}); + x.linspace(1); + auto exp = x.reshape('c', {2, 3, 4}); - sd::ops::squeeze op; - auto result = op.evaluate({&x}, {}, {}); + sd::ops::squeeze op; + auto result = op.evaluate({&x}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(DeclarableOpsTests2, Test_Squeeze_2) { - auto x = NDArrayFactory::create('c', {2, 3, 4}); - x.linspace(1); - auto exp = new NDArray(x.dup()); + auto x = NDArrayFactory::create('c', {2, 3, 4}); + x.linspace(1); + auto exp = new NDArray(x.dup()); - sd::ops::squeeze op; - auto result = op.evaluate({&x}, {}, {}); + sd::ops::squeeze op; + auto result = op.evaluate({&x}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_TRUE(exp->isSameShape(z)); - ASSERT_TRUE(exp->equalsTo(z)); - - delete exp; + ASSERT_TRUE(exp->isSameShape(z)); + ASSERT_TRUE(exp->equalsTo(z)); + + delete exp; } TEST_F(DeclarableOpsTests2, Test_FloorMod_1) { - auto x = NDArrayFactory::create('c', {1, 3}, {2.0f, 6.0f, -3.0f}); - auto y = NDArrayFactory::create('c', {1, 3}, {-3.0f, 2.0f, -2.0f}); - auto exp = NDArrayFactory::create('c', {1, 3}, {-1.f, 0.f, -1.f}); + auto x = NDArrayFactory::create('c', {1, 3}, {2.0f, 6.0f, -3.0f}); + auto y = NDArrayFactory::create('c', {1, 3}, {-3.0f, 2.0f, -2.0f}); + auto exp = NDArrayFactory::create('c', {1, 3}, {-1.f, 0.f, -1.f}); - sd::ops::floormod op; + sd::ops::floormod op; - auto result = op.evaluate({&x, &y}, {}, {}); + auto result = op.evaluate({&x, &y}, {}, {}); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests2, Test_FloorDiv_1) { - auto x = NDArrayFactory::create('c', {1, 3}, {3.0f, 6.0f, -3.0f}); - auto y = NDArrayFactory::create('c', {1, 3}, {-2.0f, 2.0f, -2.0f}); - auto exp = NDArrayFactory::create('c', {1, 3}, {-2.f, 3.f, 1.f}); + auto x = NDArrayFactory::create('c', {1, 3}, {3.0f, 6.0f, -3.0f}); + auto y = NDArrayFactory::create('c', {1, 3}, {-2.0f, 2.0f, -2.0f}); + auto exp = NDArrayFactory::create('c', {1, 3}, {-2.f, 3.f, 1.f}); - sd::ops::floordiv op; + sd::ops::floordiv op; - auto result = op.evaluate({&x, &y}, {}, {}); + auto result = op.evaluate({&x, &y}, {}, {}); - auto z = result.at(0); -// z->printShapeInfo("FloorDiv1 shape"); -// z->printIndexedBuffer("FloorDiv1"); - ASSERT_TRUE(exp.isSameShape(z)); + auto z = result.at(0); + // z->printShapeInfo("FloorDiv1 shape"); + // z->printIndexedBuffer("FloorDiv1"); + ASSERT_TRUE(exp.isSameShape(z)); } TEST_F(DeclarableOpsTests2, Test_FloorDiv_2) { - auto x = NDArrayFactory::create('c', {1, 3}, {3.0f, 6.0f, -3.0f}); - auto y = NDArrayFactory::create('c', {1, 3}, {-2.0f, 2.0f, -2.0f}); - auto eps = NDArrayFactory::create('c', {1, 3}, {1.f, 2.f, 3.f}); - - auto exp1 = NDArrayFactory::create('c', {1, 3}, {0.f, 0.f, 0.f}); - auto exp2 = NDArrayFactory::create('c', {1, 3}, {0.f, 0.f, 0.f}); + auto x = NDArrayFactory::create('c', {1, 3}, {3.0f, 6.0f, -3.0f}); + auto y = NDArrayFactory::create('c', {1, 3}, {-2.0f, 2.0f, -2.0f}); + auto eps = NDArrayFactory::create('c', {1, 3}, {1.f, 2.f, 3.f}); - sd::ops::floordiv_bp op; + auto exp1 = NDArrayFactory::create('c', {1, 3}, {0.f, 0.f, 0.f}); + auto exp2 = NDArrayFactory::create('c', {1, 3}, {0.f, 0.f, 0.f}); - auto result = op.evaluate({&x, &y, &eps}, {}, {}); - ASSERT_EQ(result.status(), Status::OK()); - auto z1 = result.at(0); - auto z2 = result.at(1); -// z->printShapeInfo("FloorDiv1 shape"); -// z1->printIndexedBuffer("FloorDiv2_1"); -// z2->printIndexedBuffer("FloorDiv2_2"); + sd::ops::floordiv_bp op; - ASSERT_TRUE(exp1.equalsTo(z1)); - ASSERT_TRUE(exp2.equalsTo(z2)); + auto result = op.evaluate({&x, &y, &eps}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + auto z1 = result.at(0); + auto z2 = result.at(1); + // z->printShapeInfo("FloorDiv1 shape"); + // z1->printIndexedBuffer("FloorDiv2_1"); + // z2->printIndexedBuffer("FloorDiv2_2"); + ASSERT_TRUE(exp1.equalsTo(z1)); + ASSERT_TRUE(exp2.equalsTo(z2)); } TEST_F(DeclarableOpsTests2, Test_CRelu_1) { - auto x = NDArrayFactory::create('c', {2, 2}, {1.0f, 2.0f, 3.0f, 4.0f}); - auto exp = NDArrayFactory::create('c', {2, 4}, {1.0f, 2.0f, 0.f, 0.f, 3.0f, 4.0f, 0.f, 0.f}); - - sd::ops::crelu op; + auto x = NDArrayFactory::create('c', {2, 2}, {1.0f, 2.0f, 3.0f, 4.0f}); + auto exp = NDArrayFactory::create( + 'c', {2, 4}, {1.0f, 2.0f, 0.f, 0.f, 3.0f, 4.0f, 0.f, 0.f}); - auto result = op.evaluate({&x}, {}, {}); + sd::ops::crelu op; - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto result = op.evaluate({&x}, {}, {}); - auto z = result.at(0); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests2, Test_CRelu_BP_2) { - auto x = NDArrayFactory::create('c', {2, 2}, {1.0f, 2.0f, -3.0f, 4.0f}); - auto eps = NDArrayFactory::create('c', {2, 4}, {1.0f, 2.0f, 4.f, 3.f, 3.0f, 4.0f, 2.f, 1.f}); - auto exp = NDArrayFactory::create('c', {2, 2}, {1.f, 2.f, -2.f, 4.f}); + auto x = + NDArrayFactory::create('c', {2, 2}, {1.0f, 2.0f, -3.0f, 4.0f}); + auto eps = NDArrayFactory::create( + 'c', {2, 4}, {1.0f, 2.0f, 4.f, 3.f, 3.0f, 4.0f, 2.f, 1.f}); + auto exp = NDArrayFactory::create('c', {2, 2}, {1.f, 2.f, -2.f, 4.f}); - sd::ops::crelu_bp op; - auto result = op.evaluate({&x, &eps}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(1, result.size()); + sd::ops::crelu_bp op; + auto result = op.evaluate({&x, &eps}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(1, result.size()); - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests2, Test_Concat_BP_1) { - auto x = NDArrayFactory::create('c', {2, 2}); - auto y = NDArrayFactory::create('c', {2, 2}); - auto eps = NDArrayFactory::create('c', {2, 4}, {1.0f, 2.0f, 0.f, 1.f, 3.0f, 4.0f, 0.f, 1.f}); - auto expEX = NDArrayFactory::create('c', {2, 2}, {1.f, 2.f, 3.f, 4.f}); - auto expEY = NDArrayFactory::create('c', {2, 2}, {0.f, 1.f, 0.f, 1.f}); + auto x = NDArrayFactory::create('c', {2, 2}); + auto y = NDArrayFactory::create('c', {2, 2}); + auto eps = NDArrayFactory::create( + 'c', {2, 4}, {1.0f, 2.0f, 0.f, 1.f, 3.0f, 4.0f, 0.f, 1.f}); + auto expEX = NDArrayFactory::create('c', {2, 2}, {1.f, 2.f, 3.f, 4.f}); + auto expEY = NDArrayFactory::create('c', {2, 2}, {0.f, 1.f, 0.f, 1.f}); - sd::ops::concat_bp op; - auto result = op.evaluate({&x, &y, &eps}, {}, {-1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(2, result.size()); + sd::ops::concat_bp op; + auto result = op.evaluate({&x, &y, &eps}, {}, {-1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(2, result.size()); - auto epsX = result.at(0); - auto epsY = result.at(1); + auto epsX = result.at(0); + auto epsY = result.at(1); - ASSERT_TRUE(expEX.isSameShape(epsX)); - ASSERT_TRUE(expEX.equalsTo(epsX)); + ASSERT_TRUE(expEX.isSameShape(epsX)); + ASSERT_TRUE(expEX.equalsTo(epsX)); - ASSERT_TRUE(expEY.isSameShape(epsY)); - ASSERT_TRUE(expEY.equalsTo(epsY)); + ASSERT_TRUE(expEY.isSameShape(epsY)); + ASSERT_TRUE(expEY.equalsTo(epsY)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_1) { + auto labels = NDArrayFactory::create('c', {2, 3, 4, 5}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4, 5}); + auto weights = NDArrayFactory::create('c', {2, 3, 4, 5}); + auto expected = NDArrayFactory::create('c', {2, 3, 4, 5}); - auto labels = NDArrayFactory::create('c', {2,3,4,5}); - auto predictions = NDArrayFactory::create('c', {2,3,4,5}); - auto weights = NDArrayFactory::create('c', {2,3,4,5}); - auto expected = NDArrayFactory::create('c', {2,3,4,5}); - - labels.linspace(1); - predictions.linspace(2); - weights.assign(0.5f); - expected.assign(0.5f); + labels.linspace(1); + predictions.linspace(2); + weights.assign(0.5f); + expected.assign(0.5f); - sd::ops::absolute_difference_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0}); + sd::ops::absolute_difference_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); - - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); + auto result = results.at(0); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_2) { + auto labels = NDArrayFactory::create('c', {2, 3, 4, 5}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4, 5}); + auto weights = NDArrayFactory::create('c', {1, 1, 4, 5}); + auto expected = NDArrayFactory::create('c', {2, 3, 4, 5}); - auto labels = NDArrayFactory::create('c', {2,3,4,5}); - auto predictions = NDArrayFactory::create('c', {2,3,4,5}); - auto weights = NDArrayFactory::create('c', {1,1,4,5}); - auto expected = NDArrayFactory::create('c', {2,3,4,5}); - - labels.linspace(1); - predictions.linspace(2); - weights.assign(0.5f); - expected.assign(0.5f); + labels.linspace(1); + predictions.linspace(2); + weights.assign(0.5f); + expected.assign(0.5f); - sd::ops::absolute_difference_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0}); + sd::ops::absolute_difference_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto *result = results.at(0); - // result->printIndexedBuffer("ADL test2"); - // expected.printIndexedBuffer("ADL expec"); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto result = results.at(0); + // result.printIndexedBuffer("ADL test2"); + // expected.printIndexedBuffer("ADL expec"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_3) { + auto labels = NDArrayFactory::create('c', {2, 3, 4, 5}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4, 5}); + auto weights = NDArrayFactory::create('c', {1, 1, 1, 5}); + auto expected = NDArrayFactory::create('c', {2, 3, 4, 5}); - auto labels = NDArrayFactory::create('c', {2,3,4,5}); - auto predictions = NDArrayFactory::create('c', {2,3,4,5}); - auto weights = NDArrayFactory::create('c', {1,1,1,5}); - auto expected = NDArrayFactory::create('c', {2,3,4,5}); - - labels.linspace(1); - predictions.linspace(2); - weights.assign(0.5f); - expected.assign(0.5f); - - sd::ops::absolute_difference_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0}); + labels.linspace(1); + predictions.linspace(2); + weights.assign(0.5f); + expected.assign(0.5f); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::absolute_difference_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0}); - auto *result = results.at(0); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); + auto result = results.at(0); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_4) { + auto labels = NDArrayFactory::create('c', {2, 3, 4, 5}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4, 5}); + auto weights = NDArrayFactory::create('c', {2, 1, 1, 5}); + auto expected = NDArrayFactory::create('c', {2, 3, 4, 5}); - auto labels = NDArrayFactory::create('c', {2,3,4,5}); - auto predictions = NDArrayFactory::create('c', {2,3,4,5}); - auto weights = NDArrayFactory::create('c', {2,1,1,5}); - auto expected = NDArrayFactory::create('c', {2,3,4,5}); + labels.linspace(1); + predictions.linspace(2); + weights.assign(0.5f); + expected.assign(0.5f); - labels.linspace(1); - predictions.linspace(2); - weights.assign(0.5f); - expected.assign(0.5f); + sd::ops::absolute_difference_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0}); - sd::ops::absolute_difference_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto *result = results.at(0); - - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); + auto result = results.at(0); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_5) { + auto labels = NDArrayFactory::create('c', {2, 3, 4, 5}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4, 5}); + auto weights = NDArrayFactory::create('c', {1, 1}); + auto expected = NDArrayFactory::create('c', {2, 3, 4, 5}); - auto labels = NDArrayFactory::create('c', {2,3,4,5}); - auto predictions = NDArrayFactory::create('c', {2,3,4,5}); - auto weights = NDArrayFactory::create('c', {1,1}); - auto expected = NDArrayFactory::create('c', {2,3,4,5}); - - labels.linspace(1); - predictions.linspace(2); - weights.assign(0.5f); - expected.assign(0.5f); - - sd::ops::absolute_difference_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0}); + labels.linspace(1); + predictions.linspace(2); + weights.assign(0.5f); + expected.assign(0.5f); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::absolute_difference_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0}); - auto *result = results.at(0); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); + auto result = results.at(0); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_6) { + auto labels = NDArrayFactory::create('c', {2, 3, 4, 5}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4, 5}); + auto weights = NDArrayFactory::create('c', {1, 1}); + auto expected = NDArrayFactory::create('c', {2, 3, 4, 5}); - auto labels = NDArrayFactory::create('c', {2,3,4,5}); - auto predictions = NDArrayFactory::create('c', {2,3,4,5}); - auto weights = NDArrayFactory::create('c', {1,1}); - auto expected = NDArrayFactory::create('c', {2,3,4,5}); + labels.linspace(1); + predictions.linspace(2); + weights.assign(0.f); + expected.assign(0.f); - labels.linspace(1); - predictions.linspace(2); - weights.assign(0.f); - expected.assign(0.f); + sd::ops::absolute_difference_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0}); - sd::ops::absolute_difference_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto *result = results.at(0); - - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); + auto result = results.at(0); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_7) { + auto labels = NDArrayFactory::create('c', {2, 3, 4, 5}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4, 5}); + auto weights = NDArrayFactory::create('c', {2, 3, 4, 5}); - auto labels = NDArrayFactory::create('c', {2,3,4,5}); - auto predictions = NDArrayFactory::create('c', {2,3,4,5}); - auto weights = NDArrayFactory::create('c', {2,3,4,5}); - - labels.linspace(1); - predictions.linspace(2); - weights.assign(0.5f); + labels.linspace(1); + predictions.linspace(2); + weights.assign(0.5f); - sd::ops::absolute_difference_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1}); + sd::ops::absolute_difference_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); - - ASSERT_TRUE(result->isScalar()); - ASSERT_TRUE(result->e(0) == 60.f); + auto result = results.at(0); + ASSERT_TRUE(result.isScalar()); + ASSERT_TRUE(result.e(0) == 60.f); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_8) { + auto labels = NDArrayFactory::create('c', {2, 3, 4, 5}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4, 5}); + auto weights = NDArrayFactory::create('c', {2, 3, 4, 5}); - auto labels = NDArrayFactory::create('c', {2,3,4,5}); - auto predictions = NDArrayFactory::create('c', {2,3,4,5}); - auto weights = NDArrayFactory::create('c', {2,3,4,5}); - - labels.linspace(1); - predictions.linspace(2); - weights.assign(0.f); - - sd::ops::absolute_difference_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + labels.linspace(1); + predictions.linspace(2); + weights.assign(0.f); - auto *result = results.at(0); + sd::ops::absolute_difference_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1}); - ASSERT_TRUE(result->isScalar()); - ASSERT_TRUE(result->e(0) == 0.f); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - + auto result = results.at(0); + ASSERT_TRUE(result.isScalar()); + ASSERT_TRUE(result.e(0) == 0.f); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_9) { + auto labels = NDArrayFactory::create('c', {2, 3, 4, 5}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4, 5}); + auto weights = NDArrayFactory::create('c', {2, 1, 4, 1}); - auto labels = NDArrayFactory::create('c', {2,3,4,5}); - auto predictions = NDArrayFactory::create('c', {2,3,4,5}); - auto weights = NDArrayFactory::create('c', {2,1,4,1}); - - labels.linspace(1); - predictions.linspace(2); - weights.assign(0.5f); - - sd::ops::absolute_difference_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + labels.linspace(1); + predictions.linspace(2); + weights.assign(0.5f); - auto *result = results.at(0); + sd::ops::absolute_difference_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1}); - ASSERT_TRUE(result->isScalar()); - ASSERT_TRUE(result->e(0) == 60.); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - + auto result = results.at(0); + ASSERT_TRUE(result.isScalar()); + ASSERT_TRUE(result.e(0) == 60.); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_10) { + auto labels = NDArrayFactory::create('c', {2, 3, 4, 5}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4, 5}); + auto weights = NDArrayFactory::create('c', {1, 1}); - auto labels = NDArrayFactory::create('c', {2,3,4,5}); - auto predictions = NDArrayFactory::create('c', {2,3,4,5}); - auto weights = NDArrayFactory::create('c', {1,1}); + labels.linspace(1); + predictions.linspace(2); + weights.assign(0.5f); - labels.linspace(1); - predictions.linspace(2); - weights.assign(0.5f); + sd::ops::absolute_difference_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1}); - sd::ops::absolute_difference_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto *result = results.at(0); - - ASSERT_TRUE(result->isScalar()); - ASSERT_TRUE(result->e(0) == 60.f); - - + auto result = results.at(0); + ASSERT_TRUE(result.isScalar()); + ASSERT_TRUE(result.e(0) == 60.f); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_11) { + auto labels = NDArrayFactory::create('c', {2, 3, 4, 5}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4, 5}); + auto weights = NDArrayFactory::create('c', {1, 1}); - auto labels = NDArrayFactory::create('c', {2,3,4,5}); - auto predictions = NDArrayFactory::create('c', {2,3,4,5}); - auto weights = NDArrayFactory::create('c', {1,1}); - - labels.linspace(1); - predictions.linspace(2); - weights.assign(0.5f); + labels.linspace(1); + predictions.linspace(2); + weights.assign(0.5f); - sd::ops::absolute_difference_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); + sd::ops::absolute_difference_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); - - ASSERT_TRUE(result->isScalar()); - ASSERT_TRUE(result->e(0) == 1.f); - - + auto result = results.at(0); + ASSERT_TRUE(result.isScalar()); + ASSERT_TRUE(result.e(0) == 1.f); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_12) { + auto labels = NDArrayFactory::create('c', {2, 3, 4, 5}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4, 5}); + auto weights = NDArrayFactory::create('c', {1, 1}); - auto labels = NDArrayFactory::create('c', {2,3,4,5}); - auto predictions = NDArrayFactory::create('c', {2,3,4,5}); - auto weights = NDArrayFactory::create('c', {1,1}); - - labels.linspace(1); - predictions.linspace(2); - weights.assign(0.f); - - sd::ops::absolute_difference_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); + labels.linspace(1); + predictions.linspace(2); + weights.assign(0.f); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::absolute_difference_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); - auto *result = results.at(0); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(result->isScalar()); - ASSERT_TRUE(result->e(0) == 0.f); - - + auto result = results.at(0); + ASSERT_TRUE(result.isScalar()); + ASSERT_TRUE(result.e(0) == 0.f); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_13) { + auto labels = NDArrayFactory::create('c', {2, 3, 4, 5}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4, 5}); + auto weights = NDArrayFactory::create('c', {2, 3, 4, 5}); - auto labels = NDArrayFactory::create('c', {2,3,4,5}); - auto predictions = NDArrayFactory::create('c', {2,3,4,5}); - auto weights = NDArrayFactory::create('c', {2,3,4,5}); - - labels.linspace(1); - predictions.linspace(2); - weights.assign(0.5f); - - sd::ops::absolute_difference_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + labels.linspace(1); + predictions.linspace(2); + weights.assign(0.5f); - auto *result = results.at(0); + sd::ops::absolute_difference_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); - ASSERT_TRUE(result->isScalar()); - ASSERT_TRUE(result->e(0) == 1.f); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - + auto result = results.at(0); + ASSERT_TRUE(result.isScalar()); + ASSERT_TRUE(result.e(0) == 1.f); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_14) { + auto labels = NDArrayFactory::create('c', {2, 3, 4, 5}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4, 5}); + auto weights = NDArrayFactory::create('c', {2, 3, 4, 5}); - auto labels = NDArrayFactory::create('c', {2,3,4,5}); - auto predictions = NDArrayFactory::create('c', {2,3,4,5}); - auto weights = NDArrayFactory::create('c', {2,3,4,5}); - - labels.linspace(1); - predictions.linspace(2); - weights.assign(0.5); - weights.p(1, 0.f); - weights.p(2, 0.f); - - sd::ops::absolute_difference_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + labels.linspace(1); + predictions.linspace(2); + weights.assign(0.5); + weights.p(1, 0.f); + weights.p(2, 0.f); - auto *result = results.at(0); + sd::ops::absolute_difference_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); - ASSERT_TRUE(result->isScalar()); - ASSERT_TRUE(result->e(0) == 1.f); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - + auto result = results.at(0); + ASSERT_TRUE(result.isScalar()); + ASSERT_TRUE(result.e(0) == 1.f); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_15) { + auto labels = NDArrayFactory::create('c', {2, 3, 4, 5}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4, 5}); + auto weights = NDArrayFactory::create('c', {2, 3, 4, 5}); - auto labels = NDArrayFactory::create('c', {2,3,4,5}); - auto predictions = NDArrayFactory::create('c', {2,3,4,5}); - auto weights = NDArrayFactory::create('c', {2,3,4,5}); + labels.linspace(1); + predictions.linspace(3); + weights.assign(0.5f); - labels.linspace(1); - predictions.linspace(3); - weights.assign(0.5f); + sd::ops::absolute_difference_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); - sd::ops::absolute_difference_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto *result = results.at(0); - - ASSERT_TRUE(result->isScalar()); - ASSERT_TRUE(result->e(0) == 2.f); - - + auto result = results.at(0); + ASSERT_TRUE(result.isScalar()); + ASSERT_TRUE(result.e(0) == 2.f); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_16) { + auto labels = NDArrayFactory::create('c', {2, 3, 4, 5}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4, 5}); + auto weights = NDArrayFactory::create('c', {2, 3, 4, 5}); - auto labels = NDArrayFactory::create('c', {2,3,4,5}); - auto predictions = NDArrayFactory::create('c', {2,3,4,5}); - auto weights = NDArrayFactory::create('c', {2,3,4,5}); - - labels.linspace(1); - predictions.linspace(3); - weights.assign(0.5f); - predictions.p(0, 0.f); - predictions.p(1, 0.f); - predictions.p(2, 0.f); - predictions.p(3, 0.f); + labels.linspace(1); + predictions.linspace(3); + weights.assign(0.5f); + predictions.p(0, 0.f); + predictions.p(1, 0.f); + predictions.p(2, 0.f); + predictions.p(3, 0.f); - sd::ops::absolute_difference_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); + sd::ops::absolute_difference_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); - - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), 2.01667, 1e-5); - - + auto result = results.at(0); + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 2.01667, 1e-5); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_17) { + auto labels = NDArrayFactory::create('c', {2, 3, 4, 5}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4, 5}); + auto weights = NDArrayFactory::create('c', {2, 3, 4, 5}); - auto labels = NDArrayFactory::create('c', {2,3,4,5}); - auto predictions = NDArrayFactory::create('c', {2,3,4,5}); - auto weights = NDArrayFactory::create('c', {2,3,4,5}); - - labels.linspace(1); - predictions.linspace(3); - weights.assign(0.5f); - predictions.p(0, 0.f); - predictions.p(1, 0.f); - predictions.p(2, 0.f); - predictions.p(3, 0.f); - labels.p(0, 0.f); - labels.p(1, 0.f); - labels.p(2, 0.f); - labels.p(3, 0.f); + labels.linspace(1); + predictions.linspace(3); + weights.assign(0.5f); + predictions.p(0, 0.f); + predictions.p(1, 0.f); + predictions.p(2, 0.f); + predictions.p(3, 0.f); + labels.p(0, 0.f); + labels.p(1, 0.f); + labels.p(2, 0.f); + labels.p(3, 0.f); - sd::ops::absolute_difference_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); + sd::ops::absolute_difference_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); - - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), 1.93333, 1e-5); - - + auto result = results.at(0); + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 1.93333, 1e-5); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_18) { + auto labels = NDArrayFactory::create('c', {2, 3, 4, 5}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4, 5}); + auto weights = NDArrayFactory::create('c', {2, 1, 1, 5}); - auto labels = NDArrayFactory::create('c', {2,3,4,5}); - auto predictions = NDArrayFactory::create('c', {2,3,4,5}); - auto weights = NDArrayFactory::create('c', {2,1,1,5}); - - labels.linspace(1); - predictions.linspace(3); - weights.assign(0.5f); - predictions.p(0, 0.f); - predictions.p(1, 0.f); - predictions.p(2, 0.f); - predictions.p(3, 0.); - labels.p(0, 0.f); - labels.p(1, 0.f); - labels.p(2, 0.f); - labels.p(3, 0.f); - - sd::ops::absolute_difference_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); + labels.linspace(1); + predictions.linspace(3); + weights.assign(0.5f); + predictions.p(0, 0.f); + predictions.p(1, 0.f); + predictions.p(2, 0.f); + predictions.p(3, 0.); + labels.p(0, 0.f); + labels.p(1, 0.f); + labels.p(2, 0.f); + labels.p(3, 0.f); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::absolute_difference_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); - auto *result = results.at(0); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), 1.93333f, 1e-5); - - + auto result = results.at(0); + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 1.93333f, 1e-5); } - //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_19) { + auto labels = NDArrayFactory::create('c', {2, 3, 4, 5}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4, 5}); + auto weights = NDArrayFactory::create('c', {1, 1}); - auto labels = NDArrayFactory::create('c', {2,3,4,5}); - auto predictions = NDArrayFactory::create('c', {2,3,4,5}); - auto weights = NDArrayFactory::create('c', {1,1}); - - labels.linspace(1); - predictions.linspace(3); - weights.assign(0.5); + labels.linspace(1); + predictions.linspace(3); + weights.assign(0.5); - sd::ops::absolute_difference_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); + sd::ops::absolute_difference_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); - - ASSERT_TRUE(result->isScalar()); - ASSERT_TRUE(result->e(0) == 1.); - - + auto result = results.at(0); + ASSERT_TRUE(result.isScalar()); + ASSERT_TRUE(result.e(0) == 1.); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_20) { + auto labels = NDArrayFactory::create('c', {2, 3, 4, 5}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4, 5}); + auto weights = NDArrayFactory::create('c', {2, 3, 4, 5}); - auto labels = NDArrayFactory::create('c', {2,3,4,5}); - auto predictions = NDArrayFactory::create('c', {2,3,4,5}); - auto weights = NDArrayFactory::create('c', {2,3,4,5}); - - labels.linspace(1); - predictions.linspace(3); - weights.assign(0.5); - - sd::ops::absolute_difference_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); + labels.linspace(1); + predictions.linspace(3); + weights.assign(0.5); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::absolute_difference_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); - auto *result = results.at(0); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(result->isScalar()); - ASSERT_TRUE(result->e(0) == 1.); - - + auto result = results.at(0); + ASSERT_TRUE(result.isScalar()); + ASSERT_TRUE(result.e(0) == 1.); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_21) { + auto labels = NDArrayFactory::create('c', {2, 3, 4, 5}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4, 5}); + auto weights = NDArrayFactory::create('c', {2, 3, 1, 1}); - auto labels = NDArrayFactory::create('c', {2,3,4,5}); - auto predictions = NDArrayFactory::create('c', {2,3,4,5}); - auto weights = NDArrayFactory::create('c', {2,3,1,1}); - - labels.linspace(1); - predictions.linspace(3); - weights.assign(0.5); - - sd::ops::absolute_difference_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); + labels.linspace(1); + predictions.linspace(3); + weights.assign(0.5); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::absolute_difference_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); - auto *result = results.at(0); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(result->isScalar()); - ASSERT_TRUE(result->e(0) == 1.f); - - + auto result = results.at(0); + ASSERT_TRUE(result.isScalar()); + ASSERT_TRUE(result.e(0) == 1.f); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_22) { + auto labels = NDArrayFactory::create('c', {2, 3, 4, 5}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4, 5}); + auto weights = NDArrayFactory::create('c', {1, 1}); - auto labels = NDArrayFactory::create('c', {2,3,4,5}); - auto predictions = NDArrayFactory::create('c', {2,3,4,5}); - auto weights = NDArrayFactory::create('c', {1,1}); - - labels.linspace(1); - predictions.linspace(3); - weights.assign(0.); - - sd::ops::absolute_difference_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + labels.linspace(1); + predictions.linspace(3); + weights.assign(0.); - auto *result = results.at(0); + sd::ops::absolute_difference_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); - ASSERT_TRUE(result->isScalar()); - ASSERT_TRUE(result->e(0) == 0.); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - + auto result = results.at(0); + ASSERT_TRUE(result.isScalar()); + ASSERT_TRUE(result.e(0) == 0.); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_23) { + auto labels = NDArrayFactory::create('c', {2, 3, 4, 5}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4, 5}); + auto weights = NDArrayFactory::create('c', {2, 3, 4, 5}); - auto labels = NDArrayFactory::create('c', {2,3,4,5}); - auto predictions = NDArrayFactory::create('c', {2,3,4,5}); - auto weights = NDArrayFactory::create('c', {2,3,4,5}); - - labels.linspace(1); - predictions.linspace(3); - weights.assign(0.5); - predictions.p(0, 0.); - predictions.p(1, 0.); - predictions.p(2, 0.); - predictions.p(3, 0.); - labels.p(0, 0.); - labels.p(1, 0.); - labels.p(2, 0.); - labels.p(3, 0.); - weights.p(40+0, 0.); - weights.p(40+1, 0.); - weights.p(40+2, 0.); - weights.p(40+3, 0.); - - sd::ops::absolute_difference_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + labels.linspace(1); + predictions.linspace(3); + weights.assign(0.5); + predictions.p(0, 0.); + predictions.p(1, 0.); + predictions.p(2, 0.); + predictions.p(3, 0.); + labels.p(0, 0.); + labels.p(1, 0.); + labels.p(2, 0.); + labels.p(3, 0.); + weights.p(40 + 0, 0.); + weights.p(40 + 1, 0.); + weights.p(40 + 2, 0.); + weights.p(40 + 3, 0.); - auto *result = results.at(0); + sd::ops::absolute_difference_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), 0.965517, 1e-5); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - + auto result = results.at(0); + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 0.965517, 1e-5); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, cosine_distance_loss_test1) { + auto labels = NDArrayFactory::create('c', {2, 3, 4}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {1, 3, 4}); + auto expected = NDArrayFactory::create( + 'c', {1, 3, 4}, + {-91.5f, -107.5f, -125.5f, -145.5f, -167.5f, -191.5f, -217.5f, -245.5f, + -275.5f, -307.5f, -341.5f, -377.5f}); - auto labels = NDArrayFactory::create('c', {2,3,4}); - auto predictions = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {1,3,4}); - auto expected = NDArrayFactory::create('c', {1,3,4}, {-91.5f, -107.5f, -125.5f, -145.5f, -167.5f, -191.5f, -217.5f, -245.5f, -275.5f, -307.5f, -341.5f, -377.5f}); + labels.linspace(1); + predictions.linspace(2); + weights.assign(0.5); - labels.linspace(1); - predictions.linspace(2); - weights.assign(0.5); + sd::ops::cosine_distance_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0, 0}); - sd::ops::cosine_distance_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0,0}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto *result = results.at(0); - - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); - - + auto result = results.at(0); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, cosine_distance_loss_test2) { + auto labels = NDArrayFactory::create('c', {2, 3, 4}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 1, 4}); + auto expected = NDArrayFactory::create( + 'c', {2, 1, 4}, + {-3.25f, -4.f, -4.75f, -5.5f, -12.25f, -13.f, -13.75f, -14.5f}); - auto labels = NDArrayFactory::create('c', {2,3,4}); - auto predictions = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,1,4}); - auto expected = NDArrayFactory::create('c', {2,1,4}, {-3.25f, -4.f, -4.75f, -5.5f, -12.25f, -13.f, -13.75f, -14.5f}); - - labels.linspace(1); - weights.assign(0.5); - predictions.assign(0.5); + labels.linspace(1); + weights.assign(0.5); + predictions.assign(0.5); - sd::ops::cosine_distance_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0,1}); + sd::ops::cosine_distance_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0, 1}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); - - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); - - + auto result = results.at(0); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } - //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, cosine_distance_loss_test3) { + auto labels = NDArrayFactory::create('c', {2, 3, 4}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 3, 1}); + auto expected = NDArrayFactory::create( + 'c', {2, 3, 1}, {-2.f, -6.f, -10.f, -14.f, -18.f, -22.f}); - auto labels = NDArrayFactory::create('c', {2,3,4}); - auto predictions = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,3,1}); - auto expected = NDArrayFactory::create('c', {2,3,1}, {-2.f, -6.f,-10.f,-14.f,-18.f,-22.f}); + labels.linspace(1); + weights.assign(0.5); + predictions.assign(0.5); - labels.linspace(1); - weights.assign(0.5); - predictions.assign(0.5); + sd::ops::cosine_distance_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0, 2}); - sd::ops::cosine_distance_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0,2}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto result = results.at(0); - auto *result = results.at(0); - - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); - - + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, cosine_distance_loss_test4) { + auto labels = NDArrayFactory::create('c', {2, 3, 4}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {1, 1}); + auto expected = NDArrayFactory::create( + 'c', {2, 3, 1}, {-2.f, -6.f, -10.f, -14.f, -18.f, -22.f}); - auto labels = NDArrayFactory::create('c', {2,3,4}); - auto predictions = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {1,1}); - auto expected = NDArrayFactory::create('c', {2,3,1}, {-2.f, -6.f,-10.f,-14.f,-18.f,-22.f}); - - labels.linspace(1); - weights.assign(0.5); - predictions.assign(0.5); - - sd::ops::cosine_distance_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0,2}); + labels.linspace(1); + weights.assign(0.5); + predictions.assign(0.5); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::cosine_distance_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0, 2}); - auto *result = results.at(0); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); + auto result = results.at(0); - + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } - /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, cosine_distance_loss_test5) { + auto labels = NDArrayFactory::create('c', {2, 3, 4}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 1, 4}); - auto labels = NDArrayFactory::create('c', {2,3,4}); - auto predictions = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,1,4}); - - labels.linspace(1); - weights.assign(0.5); - predictions.assign(0.5); - - sd::ops::cosine_distance_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1,1}); + labels.linspace(1); + weights.assign(0.5); + predictions.assign(0.5); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::cosine_distance_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1, 1}); - auto *result = results.at(0); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(result->isScalar()); - ASSERT_TRUE(result->e(0) == -71.); - - + auto result = results.at(0); + ASSERT_TRUE(result.isScalar()); + ASSERT_TRUE(result.e(0) == -71.); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, cosine_distance_loss_test6) { + auto labels = NDArrayFactory::create('c', {2, 3, 4}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {1, 1}); - auto labels = NDArrayFactory::create('c', {2,3,4}); - auto predictions = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {1,1}); - - labels.linspace(1); - weights.assign(0.5); - predictions.assign(0.5); - - sd::ops::cosine_distance_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1,1}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + labels.linspace(1); + weights.assign(0.5); + predictions.assign(0.5); - auto *result = results.at(0); + sd::ops::cosine_distance_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1, 1}); - ASSERT_TRUE(result->isScalar()); - ASSERT_TRUE(result->e(0) == -71.f); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - + auto result = results.at(0); + ASSERT_TRUE(result.isScalar()); + ASSERT_TRUE(result.e(0) == -71.f); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, cosine_distance_loss_test7) { + auto labels = NDArrayFactory::create('c', {2, 3, 4}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {1, 1, 4}); - auto labels = NDArrayFactory::create('c', {2,3,4}); - auto predictions = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {1,1,4}); + labels.linspace(1); + weights.assign(0.5); + predictions.assign(0.5); - labels.linspace(1); - weights.assign(0.5); - predictions.assign(0.5); + sd::ops::cosine_distance_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1, 0}); - sd::ops::cosine_distance_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1,0}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto *result = results.at(0); - - ASSERT_TRUE(result->isScalar()); - ASSERT_TRUE(result->e(0) == -69.f); - - + auto result = results.at(0); + ASSERT_TRUE(result.isScalar()); + ASSERT_TRUE(result.e(0) == -69.f); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, cosine_distance_loss_test8) { + auto labels = NDArrayFactory::create('c', {2, 3, 4}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 3, 1}); - auto labels = NDArrayFactory::create('c', {2,3,4}); - auto predictions = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,3,1}); - - labels.linspace(1); - weights.assign(0.5f); - predictions.assign(0.5f); + labels.linspace(1); + weights.assign(0.5f); + predictions.assign(0.5f); - sd::ops::cosine_distance_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2,2}); + sd::ops::cosine_distance_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2, 2}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); - - ASSERT_TRUE(result->isScalar()); - ASSERT_TRUE(result->e(0) == -24.f); - - + auto result = results.at(0); + ASSERT_TRUE(result.isScalar()); + ASSERT_TRUE(result.e(0) == -24.f); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, cosine_distance_loss_test9) { + auto labels = NDArrayFactory::create('c', {2, 3, 4}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {1, 1}); - auto labels = NDArrayFactory::create('c', {2,3,4}); - auto predictions = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {1,1}); - - labels.linspace(1); - weights.assign(0.5f); - predictions.assign(0.5f); - - sd::ops::cosine_distance_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2,2}); + labels.linspace(1); + weights.assign(0.5f); + predictions.assign(0.5f); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::cosine_distance_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2, 2}); - auto *result = results.at(0); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(result->isScalar()); - ASSERT_TRUE(result->e(0) == -24.); - - + auto result = results.at(0); + ASSERT_TRUE(result.isScalar()); + ASSERT_TRUE(result.e(0) == -24.); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, cosine_distance_loss_test10) { + auto labels = NDArrayFactory::create('c', {2, 3, 4}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 3, 1}); - auto labels = NDArrayFactory::create('c', {2,3,4}); - auto predictions = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,3,1}); - - labels.linspace(1); - weights.assign(0.5f); - predictions.assign(0.5f); - weights.p(0, 0.f); - weights.p(1, 0.f); - - sd::ops::cosine_distance_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2,2}); + labels.linspace(1); + weights.assign(0.5f); + predictions.assign(0.5f); + weights.p(0, 0.f); + weights.p(1, 0.f); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::cosine_distance_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2, 2}); - auto *result = results.at(0); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(result->isScalar()); - ASSERT_TRUE(result->e(0) == -32.); - - + auto result = results.at(0); + ASSERT_TRUE(result.isScalar()); + ASSERT_TRUE(result.e(0) == -32.); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, hinge_loss_test1) { + auto labels = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0}); + auto logits = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 3, 4}); + auto expected = NDArrayFactory::create( + 'c', {2, 3, 4}, + {1., 0., 0., 2.5, 0., 3.5, 0., 4.5, 0., 5.5, 0., 6.5, + 0., 7.5, 0., 8.5, 0., 9.5, 10., 0., 0., 11.5, 0., 12.5}); - auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); - auto logits = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,3,4}); - auto expected = NDArrayFactory::create('c', {2,3,4}, {1., 0. , 0., 2.5,0., 3.5, 0., 4.5,0., 5.5, 0., 6.5, 0., 7.5, 0., 8.5,0., 9.5,10., 0. ,0.,11.5, 0.,12.5}); - - logits.linspace(1); - weights.assign(0.5); + logits.linspace(1); + weights.assign(0.5); - sd::ops::hinge_loss op; - auto results = op.evaluate({&logits, &weights, &labels}, {}, {0}); + sd::ops::hinge_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {}, {0}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); - // result->printBuffer(); + auto result = results.at(0); + // result.printBuffer(); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); - - + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, hinge_loss_test2) { + auto labels = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0}); + auto logits = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {1, 1}); + auto expected = NDArrayFactory::create( + 'c', {2, 3, 4}, + {1., 0., 0., 2.5, 0., 3.5, 0., 4.5, 0., 5.5, 0., 6.5, + 0., 7.5, 0., 8.5, 0., 9.5, 10., 0., 0., 11.5, 0., 12.5}); - auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); - auto logits = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {1,1}); - auto expected = NDArrayFactory::create('c', {2,3,4}, {1., 0. , 0., 2.5,0., 3.5, 0., 4.5,0., 5.5, 0., 6.5, 0., 7.5, 0., 8.5,0., 9.5,10., 0. ,0.,11.5, 0.,12.5}); - - logits.linspace(1); - weights.assign(0.5); + logits.linspace(1); + weights.assign(0.5); - sd::ops::hinge_loss op; - auto results = op.evaluate({&logits, &weights, &labels}, {}, {0}); + sd::ops::hinge_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {}, {0}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); - // result->printBuffer(); + auto result = results.at(0); + // result.printBuffer(); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); - - + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, hinge_loss_test3) { + auto labels = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0}); + auto logits = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {1, 3, 1}); + auto expected = NDArrayFactory::create( + 'c', {2, 3, 4}, + {1., 0., 0., 2.5, 0., 3.5, 0., 4.5, 0., 5.5, 0., 6.5, + 0., 7.5, 0., 8.5, 0., 9.5, 10., 0., 0., 11.5, 0., 12.5}); - auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); - auto logits = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {1,3,1}); - auto expected = NDArrayFactory::create('c', {2,3,4}, {1., 0. , 0., 2.5,0., 3.5, 0., 4.5,0., 5.5, 0., 6.5, 0., 7.5, 0., 8.5,0., 9.5,10., 0. ,0.,11.5, 0.,12.5}); - - logits.linspace(1); - weights.assign(0.5); - - sd::ops::hinge_loss op; - auto results = op.evaluate({&logits, &weights, &labels}, {}, {0}); + logits.linspace(1); + weights.assign(0.5); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::hinge_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {}, {0}); - auto *result = results.at(0); - // result->printBuffer(); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); + auto result = results.at(0); + // result.printBuffer(); - + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, hinge_loss_test4) { + auto labels = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0}); + auto logits = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 3, 4}); - auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); - auto logits = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,3,4}); + logits.linspace(1); + weights.assign(0.5); - logits.linspace(1); - weights.assign(0.5); + sd::ops::hinge_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {}, {1}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - sd::ops::hinge_loss op; - auto results = op.evaluate({&logits, &weights, &labels}, {}, {1}); + auto result = results.at(0); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto *result = results.at(0); - - ASSERT_TRUE(result->isScalar()); - ASSERT_TRUE(result->e(0) == 83.); - - + ASSERT_TRUE(result.isScalar()); + ASSERT_TRUE(result.e(0) == 83.); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, hinge_loss_test5) { + auto labels = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0}); + auto logits = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {1, 1}); - auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); - auto logits = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {1,1}); - - logits.linspace(1); - weights.assign(0.5); + logits.linspace(1); + weights.assign(0.5); + sd::ops::hinge_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {}, {1}); - sd::ops::hinge_loss op; - auto results = op.evaluate({&logits, &weights, &labels}, {}, {1}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto result = results.at(0); - auto *result = results.at(0); - - ASSERT_TRUE(result->isScalar()); - ASSERT_TRUE(result->e(0) == 83.); - - + ASSERT_TRUE(result.isScalar()); + ASSERT_TRUE(result.e(0) == 83.); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, hinge_loss_test6) { + auto labels = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0}); + auto logits = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 1, 1}); - auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); - auto logits = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,1,1}); - - logits.linspace(1); - weights.assign(0.5); + logits.linspace(1); + weights.assign(0.5); + sd::ops::hinge_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {}, {1}); - sd::ops::hinge_loss op; - auto results = op.evaluate({&logits, &weights, &labels}, {}, {1}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto result = results.at(0); - auto *result = results.at(0); - - ASSERT_TRUE(result->isScalar()); - ASSERT_TRUE(result->e(0) == 83.); - - + ASSERT_TRUE(result.isScalar()); + ASSERT_TRUE(result.e(0) == 83.); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, hinge_loss_test7) { + auto labels = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0}); + auto logits = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 3, 4}); - auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); - auto logits = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,3,4}); - - logits.linspace(1); - weights.assign(0.5); - + logits.linspace(1); + weights.assign(0.5); - sd::ops::hinge_loss op; - auto results = op.evaluate({&logits, &weights, &labels}, {}, {2}); + sd::ops::hinge_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {}, {2}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), 6.91667, 1e-5); - - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 6.91667, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, hinge_loss_test8) { + auto labels = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0}); + auto logits = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {1, 1}); - auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); - auto logits = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {1,1}); - - logits.linspace(1); - weights.assign(0.5); - + logits.linspace(1); + weights.assign(0.5); - sd::ops::hinge_loss op; - auto results = op.evaluate({&logits, &weights, &labels}, {}, {2}); + sd::ops::hinge_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {}, {2}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), 6.91667, 1e-5); - - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 6.91667, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, hinge_loss_test9) { + auto labels = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0}); + auto logits = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {1, 1, 4}); - auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); - auto logits = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {1,1,4}); - - logits.linspace(1); - weights.assign(0.5); - - - sd::ops::hinge_loss op; - auto results = op.evaluate({&logits, &weights, &labels}, {}, {2}); + logits.linspace(1); + weights.assign(0.5); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::hinge_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {}, {2}); - auto *result = results.at(0); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), 6.91667, 1e-5); + auto result = results.at(0); - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 6.91667, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, hinge_loss_test10) { + auto labels = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0}); + auto logits = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 3, 4}); - auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); - auto logits = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,3,4}); - - logits.linspace(1); - weights.assign(0.5); - - - sd::ops::hinge_loss op; - auto results = op.evaluate({&logits, &weights, &labels}, {}, {3}); + logits.linspace(1); + weights.assign(0.5); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::hinge_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {}, {3}); - auto *result = results.at(0); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), 3.45833, 1e-5); + auto result = results.at(0); - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 3.45833, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, hinge_loss_test11) { + auto labels = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0}); + auto logits = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 1, 4}); - auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); - auto logits = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,1,4}); + logits.linspace(1); + weights.assign(0.5); - logits.linspace(1); - weights.assign(0.5); + sd::ops::hinge_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {}, {3}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - sd::ops::hinge_loss op; - auto results = op.evaluate({&logits, &weights, &labels}, {}, {3}); + auto result = results.at(0); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto *result = results.at(0); - - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), 3.45833, 1e-5); - - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 3.45833, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, hinge_loss_test12) { + auto labels = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0}); + auto logits = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 3, 4}); - auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); - auto logits = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,3,4}); - - logits.linspace(1); - weights.assign(0.5); - weights.p(0, 0.); - weights.p(1, 0.); - weights.p(2, 0.); - weights.p(3, 0.); + logits.linspace(1); + weights.assign(0.5); + weights.p(0, 0.); + weights.p(1, 0.); + weights.p(2, 0.); + weights.p(3, 0.); + sd::ops::hinge_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {}, {3}); - sd::ops::hinge_loss op; - auto results = op.evaluate({&logits, &weights, &labels}, {}, {3}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto result = results.at(0); - auto *result = results.at(0); - - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), 3.975, 1e-5); - - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 3.975, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, hinge_loss_test13) { + auto labels = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0}); + auto logits = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {1, 1}); - auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); - auto logits = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {1,1}); - - logits.linspace(1); - weights.assign(0.); + logits.linspace(1); + weights.assign(0.); + sd::ops::hinge_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {}, {3}); - sd::ops::hinge_loss op; - auto results = op.evaluate({&logits, &weights, &labels}, {}, {3}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto result = results.at(0); - auto *result = results.at(0); - - ASSERT_TRUE(result->isScalar()); - ASSERT_TRUE(result->e(0) == 0.); - - + ASSERT_TRUE(result.isScalar()); + ASSERT_TRUE(result.e(0) == 0.); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, huber_loss_test1) { + auto labels = NDArrayFactory::create('c', {2, 3, 4}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 3, 4}); + auto expected = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0.0425, 0.0875, 0.13250001, 0.17749999, 0.22250001, 0.26750001, + 0.31250003, 0.35749999, 0.4025, 0.44749999, 0.49249998, 0.53750002, + 0.58249998, 0.6275, 0.67250001, 0.71749997, 0.76249999, 0.8075, + 0.85250002, 0.89749998, 0.9425, 0.98749995, 1.03250015, 1.0775001}); - auto labels = NDArrayFactory::create('c', {2,3,4}); - auto predictions = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,3,4}); - auto expected = NDArrayFactory::create('c', {2,3,4}, {0.0425 ,0.0875 ,0.13250001,0.17749999,0.22250001,0.26750001,0.31250003,0.35749999,0.4025 ,0.44749999,0.49249998,0.53750002, 0.58249998,0.6275 ,0.67250001,0.71749997,0.76249999,0.8075 ,0.85250002,0.89749998,0.9425 ,0.98749995,1.03250015,1.0775001}); - - labels.linspace(0.1, 0.1); - predictions.linspace(1); - weights.assign(0.5); - - sd::ops::huber_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {0.1}, {0}); + labels.linspace(0.1, 0.1); + predictions.linspace(1); + weights.assign(0.5); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::huber_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {0.1}, {0}); - auto *result = results.at(0); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); + auto result = results.at(0); - + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, huber_loss_test2) { + auto labels = NDArrayFactory::create('c', {2, 3, 4}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 3, 1}); + auto expected = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0.0425, 0.0875, 0.13250001, 0.17749999, 0.22250001, 0.26750001, + 0.31250003, 0.35749999, 0.4025, 0.44749999, 0.49249998, 0.53750002, + 0.58249998, 0.6275, 0.67250001, 0.71749997, 0.76249999, 0.8075, + 0.85250002, 0.89749998, 0.9425, 0.98749995, 1.03250015, 1.0775001}); - auto labels = NDArrayFactory::create('c', {2,3,4}); - auto predictions = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,3,1}); - auto expected = NDArrayFactory::create('c', {2,3,4}, {0.0425 ,0.0875 ,0.13250001,0.17749999,0.22250001,0.26750001,0.31250003,0.35749999,0.4025 ,0.44749999,0.49249998,0.53750002, 0.58249998,0.6275 ,0.67250001,0.71749997,0.76249999,0.8075 ,0.85250002,0.89749998,0.9425 ,0.98749995,1.03250015,1.0775001}); + labels.linspace(0.1, 0.1); + predictions.linspace(1); + weights.assign(0.5); - labels.linspace(0.1, 0.1); - predictions.linspace(1); - weights.assign(0.5); + sd::ops::huber_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {0.1}, {0}); - sd::ops::huber_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {0.1}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto result = results.at(0); - auto *result = results.at(0); - - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); - - + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, huber_loss_test3) { + auto labels = NDArrayFactory::create('c', {2, 3, 4}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {1, 1}); + auto expected = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0.0425, 0.0875, 0.13250001, 0.17749999, 0.22250001, 0.26750001, + 0.31250003, 0.35749999, 0.4025, 0.44749999, 0.49249998, 0.53750002, + 0.58249998, 0.6275, 0.67250001, 0.71749997, 0.76249999, 0.8075, + 0.85250002, 0.89749998, 0.9425, 0.98749995, 1.03250015, 1.0775001}); - auto labels = NDArrayFactory::create('c', {2,3,4}); - auto predictions = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {1,1}); - auto expected = NDArrayFactory::create('c', {2,3,4}, {0.0425 ,0.0875 ,0.13250001,0.17749999,0.22250001,0.26750001,0.31250003,0.35749999,0.4025 ,0.44749999,0.49249998,0.53750002, 0.58249998,0.6275 ,0.67250001,0.71749997,0.76249999,0.8075 ,0.85250002,0.89749998,0.9425 ,0.98749995,1.03250015,1.0775001}); - - labels.linspace(0.1, 0.1); - predictions.linspace(1); - weights.assign(0.5); + labels.linspace(0.1, 0.1); + predictions.linspace(1); + weights.assign(0.5); - sd::ops::huber_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {0.1}, {0}); + sd::ops::huber_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {0.1}, {0}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); - - + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, huber_loss_test4) { + auto labels = NDArrayFactory::create('c', {2, 3, 4}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 3, 4}); - auto labels = NDArrayFactory::create('c', {2,3,4}); - auto predictions = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,3,4}); - - labels.linspace(0.1, 0.1); - predictions.linspace(1); - weights.assign(0.5); + labels.linspace(0.1, 0.1); + predictions.linspace(1); + weights.assign(0.5); - sd::ops::huber_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {0.1}, {1}); + sd::ops::huber_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {0.1}, {1}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), 13.44, 1e-5); - - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 13.44, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, huber_loss_test5) { + auto labels = NDArrayFactory::create('c', {2, 3, 4}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {1, 1}); - auto labels = NDArrayFactory::create('c', {2,3,4}); - auto predictions = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {1,1}); - - labels.linspace(0.1, 0.1); - predictions.linspace(1); - weights.assign(0.5); + labels.linspace(0.1, 0.1); + predictions.linspace(1); + weights.assign(0.5); - sd::ops::huber_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {0.1}, {1}); + sd::ops::huber_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {0.1}, {1}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), 13.44, 1e-5); - - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 13.44, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, huber_loss_test6) { + auto labels = NDArrayFactory::create('c', {2, 3, 4}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 3, 4}); - auto labels = NDArrayFactory::create('c', {2,3,4}); - auto predictions = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,3,4}); - - labels.linspace(0.1, 0.1); - predictions.linspace(1); - weights.assign(0.5); + labels.linspace(0.1, 0.1); + predictions.linspace(1); + weights.assign(0.5); - sd::ops::huber_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {0.1}, {2}); + sd::ops::huber_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {0.1}, {2}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), 1.12, 1e-5); - - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 1.12, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, huber_loss_test7) { + auto labels = NDArrayFactory::create('c', {2, 3, 4}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 1, 1}); - auto labels = NDArrayFactory::create('c', {2,3,4}); - auto predictions = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,1,1}); - - labels.linspace(0.1, 0.1); - predictions.linspace(1); - weights.assign(0.5); + labels.linspace(0.1, 0.1); + predictions.linspace(1); + weights.assign(0.5); - sd::ops::huber_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {0.1}, {2}); + sd::ops::huber_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {0.1}, {2}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), 1.12, 1e-5); - - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 1.12, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, huber_loss_test8) { + auto labels = NDArrayFactory::create('c', {2, 3, 4}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 3, 4}); - auto labels = NDArrayFactory::create('c', {2,3,4}); - auto predictions = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,3,4}); - - labels.linspace(0.1, 0.1); - predictions.linspace(1); - weights.assign(0.5); - weights.p(0, 0.); - weights.p(1, 0.); - weights.p(2, 0.); - weights.p(3, 0.); + labels.linspace(0.1, 0.1); + predictions.linspace(1); + weights.assign(0.5); + weights.p(0, 0.); + weights.p(1, 0.); + weights.p(2, 0.); + weights.p(3, 0.); - sd::ops::huber_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {0.1}, {2}); + sd::ops::huber_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {0.1}, {2}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), 1.3, 1e-5); - - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 1.3, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, huber_loss_test9) { + auto labels = NDArrayFactory::create('c', {2, 3, 4}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 3, 4}); - auto labels = NDArrayFactory::create('c', {2,3,4}); - auto predictions = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,3,4}); - - labels.linspace(0.1, 0.1); - predictions.linspace(1); - weights.assign(0.5); + labels.linspace(0.1, 0.1); + predictions.linspace(1); + weights.assign(0.5); - sd::ops::huber_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {0.1}, {3}); + sd::ops::huber_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {0.1}, {3}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), 0.56, 1e-5); - - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 0.56, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, huber_loss_test10) { + auto labels = NDArrayFactory::create('c', {2, 3, 4}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {1, 1}); - auto labels = NDArrayFactory::create('c', {2,3,4}); - auto predictions = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {1,1}); - - labels.linspace(0.1, 0.1); - predictions.linspace(1); - weights.assign(0.5); - - sd::ops::huber_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {0.1}, {3}); + labels.linspace(0.1, 0.1); + predictions.linspace(1); + weights.assign(0.5); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::huber_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {0.1}, {3}); - auto *result = results.at(0); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), 0.56, 1e-5); + auto result = results.at(0); - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 0.56, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, huber_loss_test11) { + auto labels = NDArrayFactory::create('c', {2, 3, 4}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 3, 4}); - auto labels = NDArrayFactory::create('c', {2,3,4}); - auto predictions = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,3,4}); + labels.linspace(0.1, 0.1); + predictions.linspace(1); + weights.assign(0.5); + weights.p(0, 0.); + weights.p(1, 0.); + weights.p(2, 0.); + weights.p(3, 0.); - labels.linspace(0.1, 0.1); - predictions.linspace(1); - weights.assign(0.5); - weights.p(0, 0.); - weights.p(1, 0.); - weights.p(2, 0.); - weights.p(3, 0.); + sd::ops::huber_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {0.1}, {3}); - sd::ops::huber_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {0.1}, {3}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto result = results.at(0); - auto *result = results.at(0); - - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), 0.65, 1e-5); - - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 0.65, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, log_loss_test1) { + auto labels = NDArrayFactory::create('c', {2, 3, 4}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 3, 4}); + auto expected = NDArrayFactory::create( + 'c', {2, 3, 4}, + {1.60943663, 2.48403668, 3.05256081, 3.40363169, 3.57730675, + 3.59525585, 3.46986699, 3.20791793, 2.81228209, 2.28273821, + 1.61630058, 0.80721998, -0.15329313, -1.27764463, -2.5828433, + -4.09208679, -5.83734226, -7.8636713, -10.23689461, -13.05822182, + -16.49509811, -20.85659218, -26.82411766, -36.52717209}); - auto labels = NDArrayFactory::create('c', {2,3,4}); - auto predictions = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,3,4}); - auto expected = NDArrayFactory::create('c', {2,3,4}, {1.60943663, 2.48403668, 3.05256081, 3.40363169, 3.57730675, 3.59525585, 3.46986699, 3.20791793, 2.81228209, 2.28273821, 1.61630058, 0.80721998, -0.15329313, -1.27764463, -2.5828433 , -4.09208679, -5.83734226, -7.8636713 ,-10.23689461,-13.05822182,-16.49509811,-20.85659218,-26.82411766,-36.52717209}); - - predictions.linspace(0.04, 0.04); - labels.linspace(1); - weights.assign(0.5); - - sd::ops::log_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {0}); + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::log_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {0}); - auto *result = results.at(0); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); + auto result = results.at(0); - + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, log_loss_test2) { + auto labels = NDArrayFactory::create('c', {2, 3, 4}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 1, 4}); + auto expected = NDArrayFactory::create( + 'c', {2, 3, 4}, + {1.60943663, 2.48403668, 3.05256081, 3.40363169, 3.57730675, + 3.59525585, 3.46986699, 3.20791793, 2.81228209, 2.28273821, + 1.61630058, 0.80721998, -0.15329313, -1.27764463, -2.5828433, + -4.09208679, -5.83734226, -7.8636713, -10.23689461, -13.05822182, + -16.49509811, -20.85659218, -26.82411766, -36.52717209}); - auto labels = NDArrayFactory::create('c', {2,3,4}); - auto predictions = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,1,4}); - auto expected = NDArrayFactory::create('c', {2,3,4}, {1.60943663, 2.48403668, 3.05256081, 3.40363169, 3.57730675, 3.59525585, 3.46986699, 3.20791793, 2.81228209, 2.28273821, 1.61630058, 0.80721998, -0.15329313, -1.27764463, -2.5828433 , -4.09208679, -5.83734226, -7.8636713 ,-10.23689461,-13.05822182,-16.49509811,-20.85659218,-26.82411766,-36.52717209}); + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); - predictions.linspace(0.04, 0.04); - labels.linspace(1); - weights.assign(0.5); + sd::ops::log_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {0}); - sd::ops::log_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto result = results.at(0); - auto *result = results.at(0); - - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); - - + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, log_loss_test3) { + auto labels = NDArrayFactory::create('c', {2, 3, 4}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4}); + NDArray weights(sd::DataType::DOUBLE); + auto expected = NDArrayFactory::create( + 'c', {2, 3, 4}, + {1.60943663, 2.48403668, 3.05256081, 3.40363169, 3.57730675, + 3.59525585, 3.46986699, 3.20791793, 2.81228209, 2.28273821, + 1.61630058, 0.80721998, -0.15329313, -1.27764463, -2.5828433, + -4.09208679, -5.83734226, -7.8636713, -10.23689461, -13.05822182, + -16.49509811, -20.85659218, -26.82411766, -36.52717209}); - auto labels = NDArrayFactory::create('c', {2,3,4}); - auto predictions = NDArrayFactory::create('c', {2,3,4}); - NDArray weights(sd::DataType::DOUBLE); - auto expected = NDArrayFactory::create('c', {2,3,4}, {1.60943663, 2.48403668, 3.05256081, 3.40363169, 3.57730675, 3.59525585, 3.46986699, 3.20791793, 2.81228209, 2.28273821, 1.61630058, 0.80721998, -0.15329313, -1.27764463, -2.5828433 , -4.09208679, -5.83734226, -7.8636713 ,-10.23689461,-13.05822182,-16.49509811,-20.85659218,-26.82411766,-36.52717209}); - - predictions.linspace(0.04, 0.04); - labels.linspace(1); - weights.assign(0.5); + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); - sd::ops::log_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {0}); + sd::ops::log_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {0}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); - - + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, log_loss_test4) { + auto labels = NDArrayFactory::create('c', {2, 3, 4}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 3, 4}); - auto labels = NDArrayFactory::create('c', {2,3,4}); - auto predictions = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,3,4}); - - predictions.linspace(0.04, 0.04); - labels.linspace(1); - weights.assign(0.5); + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); - sd::ops::log_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {1}); + sd::ops::log_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {1}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), -113.886429, 1e-5); - - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), -113.886429, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, log_loss_test5) { + auto labels = NDArrayFactory::create('c', {2, 3, 4}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {1, 3, 1}); - auto labels = NDArrayFactory::create('c', {2,3,4}); - auto predictions = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {1,3,1}); - - predictions.linspace(0.04, 0.04); - labels.linspace(1); - weights.assign(0.5); + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); - sd::ops::log_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {1}); + sd::ops::log_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {1}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), -113.886429, 1e-5); - - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), -113.886429, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, log_loss_test6) { + auto labels = NDArrayFactory::create('c', {2, 3, 4}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4}); + NDArray weights(sd::DataType::DOUBLE); - auto labels = NDArrayFactory::create('c', {2,3,4}); - auto predictions = NDArrayFactory::create('c', {2,3,4}); - NDArray weights(sd::DataType::DOUBLE); - - predictions.linspace(0.04, 0.04); - labels.linspace(1); - weights.assign(0.5); - - sd::ops::log_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {1}); + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::log_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {1}); - auto *result = results.at(0); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), -113.886429, 1e-5); + auto result = results.at(0); - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), -113.886429, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, log_loss_test7) { + auto labels = NDArrayFactory::create('c', {2, 3, 4}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 3, 4}); - auto labels = NDArrayFactory::create('c', {2,3,4}); - auto predictions = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,3,4}); + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); - predictions.linspace(0.04, 0.04); - labels.linspace(1); - weights.assign(0.5); + sd::ops::log_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {2}); - sd::ops::log_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {2}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto result = results.at(0); - auto *result = results.at(0); - - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), -9.490536, 1e-5); - - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), -9.490536, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, log_loss_test8) { + auto labels = NDArrayFactory::create('c', {2, 3, 4}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {1, 3, 1}); - auto labels = NDArrayFactory::create('c', {2,3,4}); - auto predictions = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {1,3,1}); - - predictions.linspace(0.04, 0.04); - labels.linspace(1); - weights.assign(0.5); - - sd::ops::log_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {2}); + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::log_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {2}); - auto *result = results.at(0); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), -9.490536, 1e-5); + auto result = results.at(0); - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), -9.490536, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, log_loss_test9) { + auto labels = NDArrayFactory::create('c', {2, 3, 4}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4}); + NDArray weights(sd::DataType::DOUBLE); - auto labels = NDArrayFactory::create('c', {2,3,4}); - auto predictions = NDArrayFactory::create('c', {2,3,4}); - NDArray weights(sd::DataType::DOUBLE); + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); - predictions.linspace(0.04, 0.04); - labels.linspace(1); - weights.assign(0.5); + sd::ops::log_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {2}); - sd::ops::log_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {2}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto result = results.at(0); - auto *result = results.at(0); - - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), -9.490536, 1e-5); - - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), -9.490536, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, log_loss_test10) { + auto labels = NDArrayFactory::create('c', {2, 3, 4}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 3, 4}); - auto labels = NDArrayFactory::create('c', {2,3,4}); - auto predictions = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,3,4}); - - predictions.linspace(0.04, 0.04); - labels.linspace(1); - weights.assign(0.5); - weights.p(0, 0.); - weights.p(1, 0.); - weights.p(2, 0.); - weights.p(3, 0.); + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); + weights.p(0, 0.); + weights.p(1, 0.); + weights.p(2, 0.); + weights.p(3, 0.); - sd::ops::log_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {2}); + sd::ops::log_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {2}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), -12.443609, 1e-5); - - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), -12.443609, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, log_loss_test11) { + auto labels = NDArrayFactory::create('c', {2, 3, 4}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 3, 4}); - auto labels = NDArrayFactory::create('c', {2,3,4}); - auto predictions = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,3,4}); - - predictions.linspace(0.04, 0.04); - labels.linspace(1); - weights.assign(0.5); + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); - sd::ops::log_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {3}); + sd::ops::log_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {3}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), -4.745268, 1e-5); - - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), -4.745268, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, log_loss_test12) { + auto labels = NDArrayFactory::create('c', {2, 3, 4}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {1, 1}); - auto labels = NDArrayFactory::create('c', {2,3,4}); - auto predictions = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {1,1}); - - predictions.linspace(0.04, 0.04); - labels.linspace(1); - weights.assign(0.5); + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); - sd::ops::log_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {3}); + sd::ops::log_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {3}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), -4.745268, 1e-5); - - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), -4.745268, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, log_loss_test13) { + auto labels = NDArrayFactory::create('c', {2, 3, 4}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 3, 4}); - auto labels = NDArrayFactory::create('c', {2,3,4}); - auto predictions = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,3,4}); - - predictions.linspace(0.04, 0.04); - labels.linspace(1); - weights.assign(0.5); - weights.p(0, 0.); - weights.p(1, 0.); - weights.p(2, 0.); - weights.p(3, 0.); - - sd::ops::log_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {3}); + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); + weights.p(0, 0.); + weights.p(1, 0.); + weights.p(2, 0.); + weights.p(3, 0.); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::log_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {3}); - auto *result = results.at(0); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), -6.221805, 1e-5); + auto result = results.at(0); - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), -6.221805, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, mean_pairwssqerr_loss_test1) { - auto labels = NDArrayFactory::create('c', {1,3}, {0., 0.5, 1.}); - auto predictions = NDArrayFactory::create('c', {1,3}, {1., 1., 1.}); - auto weights = NDArrayFactory::create('c', {1,1}, {1}); - auto expected = NDArrayFactory::create('c', {1,1}, {1.}); + auto labels = NDArrayFactory::create('c', {1, 3}, {0., 0.5, 1.}); + auto predictions = NDArrayFactory::create('c', {1, 3}, {1., 1., 1.}); + auto weights = NDArrayFactory::create('c', {1, 1}, {1}); + auto expected = NDArrayFactory::create('c', {1, 1}, {1.}); - sd::ops::mean_pairwssqerr_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0}); + sd::ops::mean_pairwssqerr_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); - - + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, mean_pairwssqerr_loss_test2) { - auto labels = NDArrayFactory::create('c', {10,4}, {-0.5533444483384939, -0.4045807428083095, -0.38990808632111873, -1.3367815555936828, 2.2110825342567204, -0.3322538938773163, 0.5683588435736076, 1.401524673423209, -0.2216208609234102, -0.23645194877057543, -1.9319189398422172, 0.6106128799796062, 1.6973842275926025, -2.8306371397325553E-4, -1.1550401544465256, -0.08357706614294765, -0.27784822018757077, 0.8290894318337857, 1.6484476009013025, -0.7752524785358668, -0.9700596207063842, 3.0809371469543207, -0.23684959888998405, 0.22403535560739518, 0.6146150452128438, -1.1250088686147994, -0.5915314787415693, -0.0944090155356556, 0.7995514825959854, -1.2290496239142903, -1.8329592004926936, -0.1694821152623061, -1.7614978090471403, 0.07929168376086736, 0.4086255139492943, 2.045562727396195, -0.48701853719962834, 0.10304152395720723, -0.8993147347502636, -0.49078404206110715}); - auto predictions = NDArrayFactory::create('c', {10,4}, {-0.5982871220907984, 1.2010665656903237, 0.30243355682445544, -0.2070857400459659, 0.6962389393180044, -0.5878034128580758, 0.8325626284025988, -0.3555823702782838, -0.7099759151434476, 1.7971905051128672, -1.1018498592680859, 0.008705918349147959, -1.713038986676157, 0.5029671900704719, 0.7491261275031563, -0.34800067781360444, -1.3529065441284513, -0.6075230577852321, -0.6153583973120907, 1.6014780660677996, 0.6444219215516616, 0.7925830851904783, -0.5006063079380708, 1.7812300901376552, 0.4736193941708224, 1.411502849640833, 0.9555142545037492, -0.03936687661890644, 1.31661624967917, 0.7344531724786305, 0.8388550872918745, 0.7010030219905558, -0.5442944240155373, 0.4437344837841118, -1.7502823958671712, -1.9271369730241665, 0.9256612923554498, 1.9065401403827893, 0.42450175148842717, -0.11783183865542822}); - auto weights = NDArrayFactory::create('c', {1,1}, {1}); - auto expected = NDArrayFactory::create('c', {10,1}, {1.9665822560405073, 3.806679563402927, 6.185624212589066, 20.237895345263905, 16.739700814450472, 13.655430201400929, 6.473256392322658, 3.9337379694106325, 22.509455553531062, 1.4741234749089487}); - - sd::ops::mean_pairwssqerr_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto *result = results.at(0); - - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); - - + auto labels = NDArrayFactory::create( + 'c', {10, 4}, + {-0.5533444483384939, -0.4045807428083095, -0.38990808632111873, + -1.3367815555936828, 2.2110825342567204, -0.3322538938773163, + 0.5683588435736076, 1.401524673423209, -0.2216208609234102, + -0.23645194877057543, -1.9319189398422172, 0.6106128799796062, + 1.6973842275926025, -2.8306371397325553E-4, -1.1550401544465256, + -0.08357706614294765, -0.27784822018757077, 0.8290894318337857, + 1.6484476009013025, -0.7752524785358668, -0.9700596207063842, + 3.0809371469543207, -0.23684959888998405, 0.22403535560739518, + 0.6146150452128438, -1.1250088686147994, -0.5915314787415693, + -0.0944090155356556, 0.7995514825959854, -1.2290496239142903, + -1.8329592004926936, -0.1694821152623061, -1.7614978090471403, + 0.07929168376086736, 0.4086255139492943, 2.045562727396195, + -0.48701853719962834, 0.10304152395720723, -0.8993147347502636, + -0.49078404206110715}); + auto predictions = NDArrayFactory::create( + 'c', {10, 4}, + {-0.5982871220907984, 1.2010665656903237, 0.30243355682445544, + -0.2070857400459659, 0.6962389393180044, -0.5878034128580758, + 0.8325626284025988, -0.3555823702782838, -0.7099759151434476, + 1.7971905051128672, -1.1018498592680859, 0.008705918349147959, + -1.713038986676157, 0.5029671900704719, 0.7491261275031563, + -0.34800067781360444, -1.3529065441284513, -0.6075230577852321, + -0.6153583973120907, 1.6014780660677996, 0.6444219215516616, + 0.7925830851904783, -0.5006063079380708, 1.7812300901376552, + 0.4736193941708224, 1.411502849640833, 0.9555142545037492, + -0.03936687661890644, 1.31661624967917, 0.7344531724786305, + 0.8388550872918745, 0.7010030219905558, -0.5442944240155373, + 0.4437344837841118, -1.7502823958671712, -1.9271369730241665, + 0.9256612923554498, 1.9065401403827893, 0.42450175148842717, + -0.11783183865542822}); + auto weights = NDArrayFactory::create('c', {1, 1}, {1}); + auto expected = NDArrayFactory::create( + 'c', {10, 1}, + {1.9665822560405073, 3.806679563402927, 6.185624212589066, + 20.237895345263905, 16.739700814450472, 13.655430201400929, + 6.473256392322658, 3.9337379694106325, 22.509455553531062, + 1.4741234749089487}); + + sd::ops::mean_pairwssqerr_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, mean_pairwssqerr_loss_test3) { - auto labels = NDArrayFactory::create('c', {10,4}, {0.9165069946629816, 0.166426191704143, 0.13873357227527264, -0.5986162145785378, 0.4763504550662989, 1.2259816058633732, -0.4653205175596491, -1.7447031523970766, 1.349525448316014, 2.433089865629357, -2.54858150221601, -0.6060282162911894, 0.2625377104613349, -0.5007107584102752, 0.9576065700956302, -0.35787770401703584, -0.2608532564720665, 0.65688909921908, -0.1705876431948587, 1.2052884124800949, -0.976783296084278, 1.1163504624016534, -0.10545986164581109, -1.0632271027867568, 0.26460250034147065, -0.2299030354616135, -0.418989869909565, 0.7954060747536896, 0.37934127200736545, 0.8550487997440007, 0.2984909806904042, 0.1329065864221682, 1.478600294413247, 0.05421279873635542, -1.0552978360622536, -0.743808639782604, -1.3371851696151362, 2.7752972493355963, -1.6107187893743549, 1.5030902829432997}); - auto predictions = NDArrayFactory::create('c', {10,4}, {-3.398114657004427, 0.40587455906092945, 1.587706448479039, 0.27394335709083156, 1.0463122023764637, -0.6552570653663903, -0.26929204111727345, -2.710461824817806, 0.9141296064806023, -0.7632270851454939, -0.4077235519855459, 0.5555107559107472, -0.6776140976423888, 1.2422270521180823, 0.2372445100636733, 0.08522757123963924, -2.708523129389936, 0.09738215252575103, -0.8797837670498875, 0.8714091607391934, -0.628958978867591, 0.49380147969660415, -0.6663578349373824, 0.14570184758600965, -0.4710388511314244, 0.7708214742640788, 0.06836525442683238, -1.2786368797129386, -0.5077556003990912, 0.45383439418987664, 1.1686877788409553, -0.3078567969393852, -2.2375730522738198, 1.0108200459611192, 0.21955367964983963, 1.2268011099696847, 0.48061693077695455, -0.5306373077054981, 1.5005367299570744, -2.1005486985463966}); - auto weights = NDArrayFactory::create('c', {10,1}, {0.0, 0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0}); - auto expected = NDArrayFactory::create('c', {10,1}, {0.0, 0.0, 21.748459867092496, 6.090581568657439, 7.51315897553838, 5.999534225166869, 22.58050883748054, 6.8600435676788605, 107.5976928688877, 191.56864939172544}); - - sd::ops::mean_pairwssqerr_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto *result = results.at(0); - - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); - - + auto labels = NDArrayFactory::create( + 'c', {10, 4}, + {0.9165069946629816, 0.166426191704143, 0.13873357227527264, + -0.5986162145785378, 0.4763504550662989, 1.2259816058633732, + -0.4653205175596491, -1.7447031523970766, 1.349525448316014, + 2.433089865629357, -2.54858150221601, -0.6060282162911894, + 0.2625377104613349, -0.5007107584102752, 0.9576065700956302, + -0.35787770401703584, -0.2608532564720665, 0.65688909921908, + -0.1705876431948587, 1.2052884124800949, -0.976783296084278, + 1.1163504624016534, -0.10545986164581109, -1.0632271027867568, + 0.26460250034147065, -0.2299030354616135, -0.418989869909565, + 0.7954060747536896, 0.37934127200736545, 0.8550487997440007, + 0.2984909806904042, 0.1329065864221682, 1.478600294413247, + 0.05421279873635542, -1.0552978360622536, -0.743808639782604, + -1.3371851696151362, 2.7752972493355963, -1.6107187893743549, + 1.5030902829432997}); + auto predictions = NDArrayFactory::create( + 'c', {10, 4}, + {-3.398114657004427, 0.40587455906092945, 1.587706448479039, + 0.27394335709083156, 1.0463122023764637, -0.6552570653663903, + -0.26929204111727345, -2.710461824817806, 0.9141296064806023, + -0.7632270851454939, -0.4077235519855459, 0.5555107559107472, + -0.6776140976423888, 1.2422270521180823, 0.2372445100636733, + 0.08522757123963924, -2.708523129389936, 0.09738215252575103, + -0.8797837670498875, 0.8714091607391934, -0.628958978867591, + 0.49380147969660415, -0.6663578349373824, 0.14570184758600965, + -0.4710388511314244, 0.7708214742640788, 0.06836525442683238, + -1.2786368797129386, -0.5077556003990912, 0.45383439418987664, + 1.1686877788409553, -0.3078567969393852, -2.2375730522738198, + 1.0108200459611192, 0.21955367964983963, 1.2268011099696847, + 0.48061693077695455, -0.5306373077054981, 1.5005367299570744, + -2.1005486985463966}); + auto weights = NDArrayFactory::create( + 'c', {10, 1}, {0.0, 0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0}); + auto expected = NDArrayFactory::create( + 'c', {10, 1}, + {0.0, 0.0, 21.748459867092496, 6.090581568657439, 7.51315897553838, + 5.999534225166869, 22.58050883748054, 6.8600435676788605, + 107.5976928688877, 191.56864939172544}); + + sd::ops::mean_pairwssqerr_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, mean_pairwssqerr_loss_test4) { - auto labels = NDArrayFactory::create('c', {10,4}, {-1.9540657282602247, -0.37099621218123746, 0.24959541842365968, 0.4125896396216978, -0.8661959659606203, 0.3651479206362867, -1.7475031047706964, -1.0962133982440159, 0.8451229874730279, 0.6876932162478913, 1.2598782790596628, 0.9372328828104118, 1.383555504464105, -0.816048166961237, 0.009041816630426176, -0.004376554457540983, -0.2386352931506252, -0.6494407817111416, 1.7888273635934742, -1.2157303560822368, -0.2446697859467434, -0.3040881765177774, -0.25843499040765916, -0.16479617511053568, 1.8063435075905592, 0.36002291874022285, -0.43317974028771883, 1.070086390817373, -1.0788479808458253, -0.3364318348487324, -0.859106579072977, 0.43984270049845064, -0.23662331183489546, -1.263417124724063, -0.3123732566483939, -0.125249623799724, -1.951308433393268, -0.4925779190927575, -1.081735149025745, -1.9910331435034687}); - auto predictions = NDArrayFactory::create('c', {10,4}, {-1.7053977111021588, 1.7704125629388408, -0.0876171627499475, 0.9428762101237441, 0.9080108618240852, -0.478732892339118, -0.8189639230649537, 1.3359668242925342, -0.07499867017894829, 0.6169780756804321, -1.1891117691972148, -0.319354110980483, -1.4287263424900434, -0.3556443786879834, 0.6389682186473912, 0.3161742985911756, 0.9047447733840537, -1.9974117226910393, 2.1067775658502326, 0.17035521714679938, -1.1393894489992826, 1.4570837278971687, 0.6312249731754015, -0.42793125692777634, -1.0685964336386844, -0.3590636581851568, -0.19147354841437528, -0.10128937266756889, -0.5714869078294972, 0.2682604831358205, 0.6608524575561853, 0.35658907103040305, -0.7053263272861181, -0.6318441042427088, 2.131292677079184, -0.3624048087249232, 1.6008209804575328, 0.1245980660014825, 1.0685424462364297, -0.5672594432046791}); - auto weights = NDArrayFactory::create('c', {1,1}, {1}); - - sd::ops::mean_pairwssqerr_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto *result = results.at(0); - - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), 60.74394998193965, 1e-5); - - + auto labels = NDArrayFactory::create( + 'c', {10, 4}, + {-1.9540657282602247, -0.37099621218123746, 0.24959541842365968, + 0.4125896396216978, -0.8661959659606203, 0.3651479206362867, + -1.7475031047706964, -1.0962133982440159, 0.8451229874730279, + 0.6876932162478913, 1.2598782790596628, 0.9372328828104118, + 1.383555504464105, -0.816048166961237, 0.009041816630426176, + -0.004376554457540983, -0.2386352931506252, -0.6494407817111416, + 1.7888273635934742, -1.2157303560822368, -0.2446697859467434, + -0.3040881765177774, -0.25843499040765916, -0.16479617511053568, + 1.8063435075905592, 0.36002291874022285, -0.43317974028771883, + 1.070086390817373, -1.0788479808458253, -0.3364318348487324, + -0.859106579072977, 0.43984270049845064, -0.23662331183489546, + -1.263417124724063, -0.3123732566483939, -0.125249623799724, + -1.951308433393268, -0.4925779190927575, -1.081735149025745, + -1.9910331435034687}); + auto predictions = NDArrayFactory::create( + 'c', {10, 4}, + {-1.7053977111021588, 1.7704125629388408, -0.0876171627499475, + 0.9428762101237441, 0.9080108618240852, -0.478732892339118, + -0.8189639230649537, 1.3359668242925342, -0.07499867017894829, + 0.6169780756804321, -1.1891117691972148, -0.319354110980483, + -1.4287263424900434, -0.3556443786879834, 0.6389682186473912, + 0.3161742985911756, 0.9047447733840537, -1.9974117226910393, + 2.1067775658502326, 0.17035521714679938, -1.1393894489992826, + 1.4570837278971687, 0.6312249731754015, -0.42793125692777634, + -1.0685964336386844, -0.3590636581851568, -0.19147354841437528, + -0.10128937266756889, -0.5714869078294972, 0.2682604831358205, + 0.6608524575561853, 0.35658907103040305, -0.7053263272861181, + -0.6318441042427088, 2.131292677079184, -0.3624048087249232, + 1.6008209804575328, 0.1245980660014825, 1.0685424462364297, + -0.5672594432046791}); + auto weights = NDArrayFactory::create('c', {1, 1}, {1}); + + sd::ops::mean_pairwssqerr_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 60.74394998193965, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, mean_pairwssqerr_loss_test5) { - auto labels = NDArrayFactory::create('c', {10,4}, {0.9165069946629816, 0.166426191704143, 0.13873357227527264, -0.5986162145785378, 0.4763504550662989, 1.2259816058633732, -0.4653205175596491, -1.7447031523970766, 1.349525448316014, 2.433089865629357, -2.54858150221601, -0.6060282162911894, 0.2625377104613349, -0.5007107584102752, 0.9576065700956302, -0.35787770401703584, -0.2608532564720665, 0.65688909921908, -0.1705876431948587, 1.2052884124800949, -0.976783296084278, 1.1163504624016534, -0.10545986164581109, -1.0632271027867568, 0.26460250034147065, -0.2299030354616135, -0.418989869909565, 0.7954060747536896, 0.37934127200736545, 0.8550487997440007, 0.2984909806904042, 0.1329065864221682, 1.478600294413247, 0.05421279873635542, -1.0552978360622536, -0.743808639782604, -1.3371851696151362, 2.7752972493355963, -1.6107187893743549, 1.5030902829432997}); - auto predictions = NDArrayFactory::create('c', {10,4}, {-3.398114657004427, 0.40587455906092945, 1.587706448479039, 0.27394335709083156, 1.0463122023764637, -0.6552570653663903, -0.26929204111727345, -2.710461824817806, 0.9141296064806023, -0.7632270851454939, -0.4077235519855459, 0.5555107559107472, -0.6776140976423888, 1.2422270521180823, 0.2372445100636733, 0.08522757123963924, -2.708523129389936, 0.09738215252575103, -0.8797837670498875, 0.8714091607391934, -0.628958978867591, 0.49380147969660415, -0.6663578349373824, 0.14570184758600965, -0.4710388511314244, 0.7708214742640788, 0.06836525442683238, -1.2786368797129386, -0.5077556003990912, 0.45383439418987664, 1.1686877788409553, -0.3078567969393852, -2.2375730522738198, 1.0108200459611192, 0.21955367964983963, 1.2268011099696847, 0.48061693077695455, -0.5306373077054981, 1.5005367299570744, -2.1005486985463966}); - auto weights = NDArrayFactory::create('c', {1,1}, {1}); - - sd::ops::mean_pairwssqerr_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto *result = results.at(0); - - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), 15.189082270182983, 1e-5); - - + auto labels = NDArrayFactory::create( + 'c', {10, 4}, + {0.9165069946629816, 0.166426191704143, 0.13873357227527264, + -0.5986162145785378, 0.4763504550662989, 1.2259816058633732, + -0.4653205175596491, -1.7447031523970766, 1.349525448316014, + 2.433089865629357, -2.54858150221601, -0.6060282162911894, + 0.2625377104613349, -0.5007107584102752, 0.9576065700956302, + -0.35787770401703584, -0.2608532564720665, 0.65688909921908, + -0.1705876431948587, 1.2052884124800949, -0.976783296084278, + 1.1163504624016534, -0.10545986164581109, -1.0632271027867568, + 0.26460250034147065, -0.2299030354616135, -0.418989869909565, + 0.7954060747536896, 0.37934127200736545, 0.8550487997440007, + 0.2984909806904042, 0.1329065864221682, 1.478600294413247, + 0.05421279873635542, -1.0552978360622536, -0.743808639782604, + -1.3371851696151362, 2.7752972493355963, -1.6107187893743549, + 1.5030902829432997}); + auto predictions = NDArrayFactory::create( + 'c', {10, 4}, + {-3.398114657004427, 0.40587455906092945, 1.587706448479039, + 0.27394335709083156, 1.0463122023764637, -0.6552570653663903, + -0.26929204111727345, -2.710461824817806, 0.9141296064806023, + -0.7632270851454939, -0.4077235519855459, 0.5555107559107472, + -0.6776140976423888, 1.2422270521180823, 0.2372445100636733, + 0.08522757123963924, -2.708523129389936, 0.09738215252575103, + -0.8797837670498875, 0.8714091607391934, -0.628958978867591, + 0.49380147969660415, -0.6663578349373824, 0.14570184758600965, + -0.4710388511314244, 0.7708214742640788, 0.06836525442683238, + -1.2786368797129386, -0.5077556003990912, 0.45383439418987664, + 1.1686877788409553, -0.3078567969393852, -2.2375730522738198, + 1.0108200459611192, 0.21955367964983963, 1.2268011099696847, + 0.48061693077695455, -0.5306373077054981, 1.5005367299570744, + -2.1005486985463966}); + auto weights = NDArrayFactory::create('c', {1, 1}, {1}); + + sd::ops::mean_pairwssqerr_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 15.189082270182983, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, mean_pairwssqerr_loss_test6) { - auto labels = NDArrayFactory::create('c', {10,4}, {0.7712557146220891, 0.37344724586647443, -1.465944048516541, 0.3226845250222374, 0.3153238532645865, -0.6453963287132424, -1.7695663855309438, -0.31350813714835285, 0.6209850696184357, -1.0632582557661083, 0.8971205782356552, -0.7361143357044725, 0.4349813432397299, 1.1012674501462072, -1.846028584047857, -0.04711049067212126, 0.3511384383511822, -1.5908669452488973, 0.6271232025632083, -0.5370025878354387, 0.09775855957778733, 0.8465118033582384, -0.5118005514773271, -0.8215749768059044, -0.5154271246850248, -0.6614138367887438, -2.721743038982485, -0.20634785234624944, 1.074134378795222, -0.515671736473577, 0.33574452224656587, -0.4258992514621533, -1.6946210614398756, 2.0853105493575246, -0.23223717047374226, -1.3145231337861756, -0.307739072607248, -0.13713627422120406, -0.05615471338688221, -0.7031780205843188}); - auto predictions = NDArrayFactory::create('c', {10,4}, {-0.8253096544930751, 0.81324545672996, 1.2530858908292535, 0.6881658781201572, 0.11626814971230247, 0.810096847233213, -0.41726775033902014, -0.07246036077805246, -0.3491325803119671, -0.7381717490678714, -1.258884944199858, 2.6195012275145992, 0.3241066697239042, -1.3306435333372646, -0.3413119919683999, 0.13167356361127197, -0.3992424507051653, 0.14454163796541403, -2.4931643208872316, 1.8740911656038526, -2.3404306490682956, -0.8036392545918644, -1.9726177395274997, -0.20128619801149433, -1.0680828820641624, -0.6228179015361869, 1.0785520122486962, -0.26148573195062036, -0.9154287856620913, 0.6612224269248097, -0.21735407368781667, 0.5584864652543093, 1.0208212201167435, -0.7560947201084579, -0.9092906572495081, 0.47525819203475833, 1.2215678456801444, -0.39319465979983964, 1.9435677135606038, 1.4540100039010526}); - auto weights = NDArrayFactory::create('c', {1,1}, {1}); - - sd::ops::mean_pairwssqerr_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto *result = results.at(0); - - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), 13.568564090650312, 1e-5); - - + auto labels = NDArrayFactory::create( + 'c', {10, 4}, + {0.7712557146220891, 0.37344724586647443, -1.465944048516541, + 0.3226845250222374, 0.3153238532645865, -0.6453963287132424, + -1.7695663855309438, -0.31350813714835285, 0.6209850696184357, + -1.0632582557661083, 0.8971205782356552, -0.7361143357044725, + 0.4349813432397299, 1.1012674501462072, -1.846028584047857, + -0.04711049067212126, 0.3511384383511822, -1.5908669452488973, + 0.6271232025632083, -0.5370025878354387, 0.09775855957778733, + 0.8465118033582384, -0.5118005514773271, -0.8215749768059044, + -0.5154271246850248, -0.6614138367887438, -2.721743038982485, + -0.20634785234624944, 1.074134378795222, -0.515671736473577, + 0.33574452224656587, -0.4258992514621533, -1.6946210614398756, + 2.0853105493575246, -0.23223717047374226, -1.3145231337861756, + -0.307739072607248, -0.13713627422120406, -0.05615471338688221, + -0.7031780205843188}); + auto predictions = NDArrayFactory::create( + 'c', {10, 4}, + {-0.8253096544930751, 0.81324545672996, 1.2530858908292535, + 0.6881658781201572, 0.11626814971230247, 0.810096847233213, + -0.41726775033902014, -0.07246036077805246, -0.3491325803119671, + -0.7381717490678714, -1.258884944199858, 2.6195012275145992, + 0.3241066697239042, -1.3306435333372646, -0.3413119919683999, + 0.13167356361127197, -0.3992424507051653, 0.14454163796541403, + -2.4931643208872316, 1.8740911656038526, -2.3404306490682956, + -0.8036392545918644, -1.9726177395274997, -0.20128619801149433, + -1.0680828820641624, -0.6228179015361869, 1.0785520122486962, + -0.26148573195062036, -0.9154287856620913, 0.6612224269248097, + -0.21735407368781667, 0.5584864652543093, 1.0208212201167435, + -0.7560947201084579, -0.9092906572495081, 0.47525819203475833, + 1.2215678456801444, -0.39319465979983964, 1.9435677135606038, + 1.4540100039010526}); + auto weights = NDArrayFactory::create('c', {1, 1}, {1}); + + sd::ops::mean_pairwssqerr_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 13.568564090650312, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, mean_pairwssqerr_loss_test7) { - auto labels = NDArrayFactory::create('c', {10,4}, {-0.06125002348040258, 0.5143643450377119, 2.6790723358660036, -0.8032552006036418, -2.4374371040644163, -0.1562964773317163, -1.3957988654288038, 1.2791626503391635, -1.433421873294552, -1.1819478586737284, 0.05162930965054662, -0.538650473505593, -0.548171720093084, -0.3103900587344872, -2.3955103171953342, 0.7127238680062526, 0.7182079438418053, 1.1842662402382182, 0.09585189676958715, 0.9276146067349225, 0.7856673461867428, 0.41368195133354113, -0.2939280190178078, -2.400566355562181, -1.1841519118039245, -1.066170501847581, -0.9274507409610022, 1.7671863041813334, -1.2849985781031494, -1.275990164491566, -0.8866824403466698, -0.6074077385015517, 0.7647344603897107, -1.048099070426831, 0.9433828938345293, -0.5591415819237762, 1.7962773615541947, -0.42365710367758247, -0.0385518907389571, -1.109959713481321}); - auto predictions = NDArrayFactory::create('c', {10,4}, {-0.7445687252538243, 0.2293875300325241, -1.0231630280206505, -0.18532545069458992, -0.07797403344353356, -0.9132035669873787, 0.9352296415512886, -1.7406458535354787, 0.8578334648119594, -0.6186274065269556, 0.4874824473654153, -0.9285817343788997, 0.1654680500853023, -0.6371334533926012, 1.3115245864160707, -2.072558735678832, 0.660795731844733, -0.34942292767044864, 0.05787182311194333, -0.12939210444705632, -0.6457028552461069, -0.6048992126598505, -0.17179604529778109, 1.292989642826032, -0.28867767615688045, 0.7635565516046265, -1.5464151753137487, -1.273368390129285, -1.074046012825826, -0.3534580692302915, 0.5757285568118223, 1.823271242883469, 0.31618576929075215, 0.5422847605415213, -0.7836698021860683, -0.6292022623165172, 2.1114596721927508, 0.4634986528550097, 0.08922001427846013, 1.5767749644913223}); - auto weights = NDArrayFactory::create('c', {10,1}, {0.0, 0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0}); - - sd::ops::mean_pairwssqerr_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto *result = results.at(0); - - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), 198.318201904499, 1e-5); - - + auto labels = NDArrayFactory::create( + 'c', {10, 4}, + {-0.06125002348040258, 0.5143643450377119, 2.6790723358660036, + -0.8032552006036418, -2.4374371040644163, -0.1562964773317163, + -1.3957988654288038, 1.2791626503391635, -1.433421873294552, + -1.1819478586737284, 0.05162930965054662, -0.538650473505593, + -0.548171720093084, -0.3103900587344872, -2.3955103171953342, + 0.7127238680062526, 0.7182079438418053, 1.1842662402382182, + 0.09585189676958715, 0.9276146067349225, 0.7856673461867428, + 0.41368195133354113, -0.2939280190178078, -2.400566355562181, + -1.1841519118039245, -1.066170501847581, -0.9274507409610022, + 1.7671863041813334, -1.2849985781031494, -1.275990164491566, + -0.8866824403466698, -0.6074077385015517, 0.7647344603897107, + -1.048099070426831, 0.9433828938345293, -0.5591415819237762, + 1.7962773615541947, -0.42365710367758247, -0.0385518907389571, + -1.109959713481321}); + auto predictions = NDArrayFactory::create( + 'c', {10, 4}, + {-0.7445687252538243, 0.2293875300325241, -1.0231630280206505, + -0.18532545069458992, -0.07797403344353356, -0.9132035669873787, + 0.9352296415512886, -1.7406458535354787, 0.8578334648119594, + -0.6186274065269556, 0.4874824473654153, -0.9285817343788997, + 0.1654680500853023, -0.6371334533926012, 1.3115245864160707, + -2.072558735678832, 0.660795731844733, -0.34942292767044864, + 0.05787182311194333, -0.12939210444705632, -0.6457028552461069, + -0.6048992126598505, -0.17179604529778109, 1.292989642826032, + -0.28867767615688045, 0.7635565516046265, -1.5464151753137487, + -1.273368390129285, -1.074046012825826, -0.3534580692302915, + 0.5757285568118223, 1.823271242883469, 0.31618576929075215, + 0.5422847605415213, -0.7836698021860683, -0.6292022623165172, + 2.1114596721927508, 0.4634986528550097, 0.08922001427846013, + 1.5767749644913223}); + auto weights = NDArrayFactory::create( + 'c', {10, 1}, {0.0, 0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0}); + + sd::ops::mean_pairwssqerr_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 198.318201904499, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, mean_pairwssqerr_loss_test8) { - auto labels = NDArrayFactory::create('c', {10,4}, {1.2003157672694111, -1.0738078620687983, 1.4513396266923826, 0.5753935722952708, -0.5424028602429585, 0.9816221437385002, -1.0566397385428794, 1.503481308203513, -0.6543147953583112, 1.7453669976827346, -0.1557689124924227, 0.3387794658137257, -1.2306868494328145, -0.3299042398395769, 0.026464968146954395, -1.5077479623528403, -0.27514168845621795, 0.18739335150879793, 1.7319910646645431, 1.5228099405663476, 0.8522684742808536, 0.2362049362675063, 0.2610756525241469, 0.457998065505686, -2.7342179885912623, -0.10968795695808314, 0.581598742956297, -1.9309885922934567, -1.5775788440607954, -0.04254899350225641, -0.3125858556254039, -1.1328154327730207, 0.00566243314780096, 0.8492052576274621, 0.05945202212214481, 1.4976918834497108, 0.8869512918387292, 0.4014181932175132, -0.015512552855187248, -1.3609667909108454}); - auto predictions = NDArrayFactory::create('c', {10,4}, {-1.1088399463364795, 0.09302972835006071, 0.033839927431215555, -0.39567507675572494, 0.8269497207597863, 1.111162272517752, 0.4930937252630912, -1.4561668998323452, 0.9417715392862969, -1.0553855492735509, 0.05848285303876081, 0.8852337518047972, -0.7472824481835305, 0.404906922583895, -0.2198309547562547, 1.9536515925189717, 0.8165036568007779, -0.19524282774410398, -0.09111693087754393, 1.1604245932512238, -0.6243762858131077, 1.4297003275591034, -0.17220079411538428, -2.3139504326793032, 0.3839796486999712, 2.0287791964679234, 0.1534441713632995, -0.6062103319229825, -0.4965880982906036, -0.373907747810053, -1.6566345746154432, 0.17534987728494222, -1.6713458890334796, 1.254628987947714, 1.914596591838086, -1.0816010467183583, 0.25033738231939673, -1.605752685708275, 1.1029112741353981, 0.3237822320282494}); - auto weights = NDArrayFactory::create('c', {10,1}, {0.0, 0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0}); - - sd::ops::mean_pairwssqerr_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto *result = results.at(0); - - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), 10.709003499121707, 1e-5); - - + auto labels = NDArrayFactory::create( + 'c', {10, 4}, + {1.2003157672694111, -1.0738078620687983, 1.4513396266923826, + 0.5753935722952708, -0.5424028602429585, 0.9816221437385002, + -1.0566397385428794, 1.503481308203513, -0.6543147953583112, + 1.7453669976827346, -0.1557689124924227, 0.3387794658137257, + -1.2306868494328145, -0.3299042398395769, 0.026464968146954395, + -1.5077479623528403, -0.27514168845621795, 0.18739335150879793, + 1.7319910646645431, 1.5228099405663476, 0.8522684742808536, + 0.2362049362675063, 0.2610756525241469, 0.457998065505686, + -2.7342179885912623, -0.10968795695808314, 0.581598742956297, + -1.9309885922934567, -1.5775788440607954, -0.04254899350225641, + -0.3125858556254039, -1.1328154327730207, 0.00566243314780096, + 0.8492052576274621, 0.05945202212214481, 1.4976918834497108, + 0.8869512918387292, 0.4014181932175132, -0.015512552855187248, + -1.3609667909108454}); + auto predictions = NDArrayFactory::create( + 'c', {10, 4}, + {-1.1088399463364795, 0.09302972835006071, 0.033839927431215555, + -0.39567507675572494, 0.8269497207597863, 1.111162272517752, + 0.4930937252630912, -1.4561668998323452, 0.9417715392862969, + -1.0553855492735509, 0.05848285303876081, 0.8852337518047972, + -0.7472824481835305, 0.404906922583895, -0.2198309547562547, + 1.9536515925189717, 0.8165036568007779, -0.19524282774410398, + -0.09111693087754393, 1.1604245932512238, -0.6243762858131077, + 1.4297003275591034, -0.17220079411538428, -2.3139504326793032, + 0.3839796486999712, 2.0287791964679234, 0.1534441713632995, + -0.6062103319229825, -0.4965880982906036, -0.373907747810053, + -1.6566345746154432, 0.17534987728494222, -1.6713458890334796, + 1.254628987947714, 1.914596591838086, -1.0816010467183583, + 0.25033738231939673, -1.605752685708275, 1.1029112741353981, + 0.3237822320282494}); + auto weights = NDArrayFactory::create( + 'c', {10, 1}, {0.0, 0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0}); + + sd::ops::mean_pairwssqerr_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 10.709003499121707, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, mean_pairwssqerr_loss_test9) { - auto labels = NDArrayFactory::create('c', {10,4}, {0.054445708809271035, 2.107634671009908, -0.7906421810578572, -1.075840781788665, 0.11881403008710377, 0.8444812915085994, -0.305754504070933, 1.6429935026781464, 0.8155105031719394, 0.04900134907242568, 0.6847004530975871, 0.23315535615893132, 0.17011663306483038, -1.1865513655938285, 1.5931597087896407, -1.7937514075547496, -0.036695307704292295, -1.6416280650778925, 1.130578912176608, -1.1267224667674058, -0.8690453889645526, 0.6717944721406133, 0.0850200492927782, 1.1294419289013125, 0.2154793028698133, 0.4557382556428947, -0.7343674069166273, -0.20013117860162175, -0.6096905108192562, 0.42022878041905926, -0.7446306649741321, 0.01724811509597817, 1.843091605690758, 1.008879504632424, 1.198292190689489, -0.4474144618813475, 0.25202981742888664, 0.07036737843407408, 1.2400630276444486, -1.1072825235557615}); - auto predictions = NDArrayFactory::create('c', {10,4}, {-1.6788168943811437, 1.1823653279081687, -0.3580541857004183, -0.4449970504370699, -1.3031645333940127, 0.5755013195969282, -0.7997343141774744, -0.8806735270004084, 0.9705277499376251, -1.6360067944580943, 0.12579369136710156, 1.0525902242414313, -1.625751312422252, -0.03900152587147075, 0.4112500942756277, 0.6589999986358094, 0.6144107111689617, 2.8561269030217264, 1.5299963640392247, -0.314093051147705, 1.6523278218751989, -0.5504653447714114, 0.53395260877978, 0.409795577698306, 0.4466825218051794, 1.2382059301630401, 0.4834869732526594, -0.635409128905636, -1.9343816841697272, -0.4192523056060229, -1.0662979055059818, 0.4270901960618144, -0.7391311480757151, -0.8268168961897452, -1.0855715553457785, -9.410401291588706E-4, -0.7721838774717349, 0.4784019579457375, -0.6979798841469268, -0.319729737118584}); - auto weights = NDArrayFactory::create('c', {10,1}, {0.0, 0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0}); - - sd::ops::mean_pairwssqerr_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto *result = results.at(0); - - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), 17.686067864414472, 1e-5); - - + auto labels = NDArrayFactory::create( + 'c', {10, 4}, + {0.054445708809271035, 2.107634671009908, -0.7906421810578572, + -1.075840781788665, 0.11881403008710377, 0.8444812915085994, + -0.305754504070933, 1.6429935026781464, 0.8155105031719394, + 0.04900134907242568, 0.6847004530975871, 0.23315535615893132, + 0.17011663306483038, -1.1865513655938285, 1.5931597087896407, + -1.7937514075547496, -0.036695307704292295, -1.6416280650778925, + 1.130578912176608, -1.1267224667674058, -0.8690453889645526, + 0.6717944721406133, 0.0850200492927782, 1.1294419289013125, + 0.2154793028698133, 0.4557382556428947, -0.7343674069166273, + -0.20013117860162175, -0.6096905108192562, 0.42022878041905926, + -0.7446306649741321, 0.01724811509597817, 1.843091605690758, + 1.008879504632424, 1.198292190689489, -0.4474144618813475, + 0.25202981742888664, 0.07036737843407408, 1.2400630276444486, + -1.1072825235557615}); + auto predictions = NDArrayFactory::create( + 'c', {10, 4}, + {-1.6788168943811437, 1.1823653279081687, -0.3580541857004183, + -0.4449970504370699, -1.3031645333940127, 0.5755013195969282, + -0.7997343141774744, -0.8806735270004084, 0.9705277499376251, + -1.6360067944580943, 0.12579369136710156, 1.0525902242414313, + -1.625751312422252, -0.03900152587147075, 0.4112500942756277, + 0.6589999986358094, 0.6144107111689617, 2.8561269030217264, + 1.5299963640392247, -0.314093051147705, 1.6523278218751989, + -0.5504653447714114, 0.53395260877978, 0.409795577698306, + 0.4466825218051794, 1.2382059301630401, 0.4834869732526594, + -0.635409128905636, -1.9343816841697272, -0.4192523056060229, + -1.0662979055059818, 0.4270901960618144, -0.7391311480757151, + -0.8268168961897452, -1.0855715553457785, -9.410401291588706E-4, + -0.7721838774717349, 0.4784019579457375, -0.6979798841469268, + -0.319729737118584}); + auto weights = NDArrayFactory::create( + 'c', {10, 1}, {0.0, 0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0}); + + sd::ops::mean_pairwssqerr_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 17.686067864414472, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, mean_sqerr_loss_test1) { + auto labels = NDArrayFactory::create('c', {2, 3, 4}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 3, 4}); + auto expected = NDArrayFactory::create( + 'c', {2, 3, 4}, {0.125, 0.5, 1.125, 2., 3.125, 4.5, 6.125, 8., + 10.125, 12.5, 15.125, 18., 21.125, 24.5, 28.125, 32., + 36.125, 40.5, 45.125, 50., 55.125, 60.5, 66.125, 72.}); - auto labels = NDArrayFactory::create('c', {2,3,4}); - auto predictions = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,3,4}); - auto expected = NDArrayFactory::create('c', {2,3,4}, {0.125, 0.5, 1.125, 2., 3.125, 4.5, 6.125, 8.,10.125,12.5,15.125,18.,21.125,24.5,28.125,32.,36.125,40.5,45.125,50.,55.125,60.5,66.125,72.}); - - predictions.linspace(0.5, 0.5); - labels.linspace(1); - weights.assign(0.5); - - sd::ops::mean_sqerr_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0}); + predictions.linspace(0.5, 0.5); + labels.linspace(1); + weights.assign(0.5); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::mean_sqerr_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0}); - auto *result = results.at(0); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); + auto result = results.at(0); - + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, mean_sqerr_loss_test2) { + auto labels = NDArrayFactory::create('c', {2, 3, 4}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 1, 4}); + auto expected = NDArrayFactory::create( + 'c', {2, 3, 4}, {0.125, 0.5, 1.125, 2., 3.125, 4.5, 6.125, 8., + 10.125, 12.5, 15.125, 18., 21.125, 24.5, 28.125, 32., + 36.125, 40.5, 45.125, 50., 55.125, 60.5, 66.125, 72.}); - auto labels = NDArrayFactory::create('c', {2,3,4}); - auto predictions = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,1,4}); - auto expected = NDArrayFactory::create('c', {2,3,4}, {0.125, 0.5, 1.125, 2., 3.125, 4.5, 6.125, 8.,10.125,12.5,15.125,18.,21.125,24.5,28.125,32.,36.125,40.5,45.125,50.,55.125,60.5,66.125,72.}); + predictions.linspace(0.5, 0.5); + labels.linspace(1); + weights.assign(0.5); - predictions.linspace(0.5, 0.5); - labels.linspace(1); - weights.assign(0.5); + sd::ops::mean_sqerr_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0}); - sd::ops::mean_sqerr_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto result = results.at(0); - auto *result = results.at(0); - - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); - - + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, mean_sqerr_loss_test3) { + auto labels = NDArrayFactory::create('c', {2, 3, 4}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 1, 1}); + auto expected = NDArrayFactory::create( + 'c', {2, 3, 4}, {0.125, 0.5, 1.125, 2., 3.125, 4.5, 6.125, 8., + 10.125, 12.5, 15.125, 18., 21.125, 24.5, 28.125, 32., + 36.125, 40.5, 45.125, 50., 55.125, 60.5, 66.125, 72.}); - auto labels = NDArrayFactory::create('c', {2,3,4}); - auto predictions = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,1,1}); - auto expected = NDArrayFactory::create('c', {2,3,4}, {0.125, 0.5, 1.125, 2., 3.125, 4.5, 6.125, 8.,10.125,12.5,15.125,18.,21.125,24.5,28.125,32.,36.125,40.5,45.125,50.,55.125,60.5,66.125,72.}); - - predictions.linspace(0.5, 0.5); - labels.linspace(1); - weights.assign(0.5); - - sd::ops::mean_sqerr_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0}); + predictions.linspace(0.5, 0.5); + labels.linspace(1); + weights.assign(0.5); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::mean_sqerr_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0}); - auto *result = results.at(0); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); + auto result = results.at(0); - + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, mean_sqerr_loss_test4) { + auto labels = NDArrayFactory::create('c', {2, 3, 4}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 3, 4}); + auto expected = NDArrayFactory::create( + 'c', {2, 3, 4}, {0., 0., 0., 0., 3.125, 4.5, 6.125, 8., + 10.125, 12.5, 15.125, 18., 21.125, 24.5, 28.125, 32., + 36.125, 40.5, 45.125, 50., 55.125, 60.5, 66.125, 72.}); - auto labels = NDArrayFactory::create('c', {2,3,4}); - auto predictions = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,3,4}); - auto expected = NDArrayFactory::create('c', {2,3,4}, {0., 0., 0., 0., 3.125, 4.5, 6.125, 8.,10.125,12.5,15.125,18.,21.125,24.5,28.125,32.,36.125,40.5,45.125,50.,55.125,60.5,66.125,72.}); + predictions.linspace(0.5, 0.5); + labels.linspace(1); + weights.assign(0.5); + weights.p(0, 0.); + weights.p(1, 0.); + weights.p(2, 0.); + weights.p(3, 0.); - predictions.linspace(0.5, 0.5); - labels.linspace(1); - weights.assign(0.5); - weights.p(0, 0.); - weights.p(1, 0.); - weights.p(2, 0.); - weights.p(3, 0.); + sd::ops::mean_sqerr_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0}); - sd::ops::mean_sqerr_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto result = results.at(0); - auto *result = results.at(0); - - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); - - + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, mean_sqerr_loss_test5) { + auto labels = NDArrayFactory::create('c', {2, 3, 4}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 3, 4}); - auto labels = NDArrayFactory::create('c', {2,3,4}); - auto predictions = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,3,4}); - - predictions.linspace(0.5, 0.5); - labels.linspace(1); - weights.assign(0.5); + predictions.linspace(0.5, 0.5); + labels.linspace(1); + weights.assign(0.5); - sd::ops::mean_sqerr_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1}); + sd::ops::mean_sqerr_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), 612.5, 1e-5); - - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 612.5, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, mean_sqerr_loss_test6) { + auto labels = NDArrayFactory::create('c', {2, 3, 4}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {1, 1, 4}); - auto labels = NDArrayFactory::create('c', {2,3,4}); - auto predictions = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {1,1,4}); - - predictions.linspace(0.5, 0.5); - labels.linspace(1); - weights.assign(0.5); + predictions.linspace(0.5, 0.5); + labels.linspace(1); + weights.assign(0.5); - sd::ops::mean_sqerr_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1}); + sd::ops::mean_sqerr_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), 612.5, 1e-5); - - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 612.5, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, mean_sqerr_loss_test7) { + auto labels = NDArrayFactory::create('c', {2, 3, 4}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {1, 1}); - auto labels = NDArrayFactory::create('c', {2,3,4}); - auto predictions = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {1,1}); - - predictions.linspace(0.5, 0.5); - labels.linspace(1); - weights.assign(0.5); + predictions.linspace(0.5, 0.5); + labels.linspace(1); + weights.assign(0.5); - sd::ops::mean_sqerr_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1}); + sd::ops::mean_sqerr_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), 612.5, 1e-5); - - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 612.5, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, mean_sqerr_loss_test8) { + auto labels = NDArrayFactory::create('c', {2, 3, 4}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 3, 4}); - auto labels = NDArrayFactory::create('c', {2,3,4}); - auto predictions = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,3,4}); - - predictions.linspace(0.5, 0.5); - labels.linspace(1); - weights.assign(0.5); - weights.p(0, 0.); - weights.p(1, 0.); - weights.p(2, 0.); - weights.p(3, 0.); - - sd::ops::mean_sqerr_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1}); + predictions.linspace(0.5, 0.5); + labels.linspace(1); + weights.assign(0.5); + weights.p(0, 0.); + weights.p(1, 0.); + weights.p(2, 0.); + weights.p(3, 0.); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::mean_sqerr_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1}); - auto *result = results.at(0); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), 608.75, 1e-5); + auto result = results.at(0); - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 608.75, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, mean_sqerr_loss_test9) { + auto labels = NDArrayFactory::create('c', {2, 3, 4}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 3, 4}); - auto labels = NDArrayFactory::create('c', {2,3,4}); - auto predictions = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,3,4}); + predictions.linspace(0.5, 0.5); + labels.linspace(1); + weights.assign(0.5); - predictions.linspace(0.5, 0.5); - labels.linspace(1); - weights.assign(0.5); + sd::ops::mean_sqerr_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); - sd::ops::mean_sqerr_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto result = results.at(0); - auto *result = results.at(0); - - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), 51.041668, 1e-5); - - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 51.041668, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, mean_sqerr_loss_test10) { + auto labels = NDArrayFactory::create('c', {2, 3, 4}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {1, 3, 1}); - auto labels = NDArrayFactory::create('c', {2,3,4}); - auto predictions = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {1,3,1}); - - predictions.linspace(0.5, 0.5); - labels.linspace(1); - weights.assign(0.5); - - sd::ops::mean_sqerr_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); + predictions.linspace(0.5, 0.5); + labels.linspace(1); + weights.assign(0.5); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::mean_sqerr_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); - auto *result = results.at(0); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), 51.041668, 1e-5); + auto result = results.at(0); - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 51.041668, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, mean_sqerr_loss_test11) { + auto labels = NDArrayFactory::create('c', {2, 3, 4}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {1, 1}); - auto labels = NDArrayFactory::create('c', {2,3,4}); - auto predictions = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {1,1}); + predictions.linspace(0.5, 0.5); + labels.linspace(1); + weights.assign(0.5); - predictions.linspace(0.5, 0.5); - labels.linspace(1); - weights.assign(0.5); + sd::ops::mean_sqerr_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); - sd::ops::mean_sqerr_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto result = results.at(0); - auto *result = results.at(0); - - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), 51.041668, 1e-5); - - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 51.041668, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, mean_sqerr_loss_test12) { + auto labels = NDArrayFactory::create('c', {2, 3, 4}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 3, 1}); - auto labels = NDArrayFactory::create('c', {2,3,4}); - auto predictions = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,3,1}); - - predictions.linspace(0.5, 0.5); - labels.linspace(1); - weights.assign(0.5); - weights.p(0, 0.); - weights.p(1, 0.); - weights.p(2, 0.); + predictions.linspace(0.5, 0.5); + labels.linspace(1); + weights.assign(0.5); + weights.p(0, 0.); + weights.p(1, 0.); + weights.p(2, 0.); - sd::ops::mean_sqerr_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); + sd::ops::mean_sqerr_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), 88.541664, 1e-5); - - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 88.541664, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, mean_sqerr_loss_test13) { + auto labels = NDArrayFactory::create('c', {2, 3, 4}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 3, 4}); - auto labels = NDArrayFactory::create('c', {2,3,4}); - auto predictions = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,3,4}); - - predictions.linspace(0.5, 0.5); - labels.linspace(1); - weights.assign(0.5); + predictions.linspace(0.5, 0.5); + labels.linspace(1); + weights.assign(0.5); - sd::ops::mean_sqerr_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); + sd::ops::mean_sqerr_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), 25.520834, 1e-5); - - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 25.520834, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, mean_sqerr_loss_test14) { + auto labels = NDArrayFactory::create('c', {2, 3, 4}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 1, 4}); - auto labels = NDArrayFactory::create('c', {2,3,4}); - auto predictions = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,1,4}); - - predictions.linspace(0.5, 0.5); - labels.linspace(1); - weights.assign(0.5); + predictions.linspace(0.5, 0.5); + labels.linspace(1); + weights.assign(0.5); - sd::ops::mean_sqerr_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); + sd::ops::mean_sqerr_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), 25.520834, 1e-5); - - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 25.520834, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, mean_sqerr_loss_test15) { + auto labels = NDArrayFactory::create('c', {2, 3, 4}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {1, 1}); - auto labels = NDArrayFactory::create('c', {2,3,4}); - auto predictions = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {1,1}); - - predictions.linspace(0.5, 0.5); - labels.linspace(1); - weights.assign(0.5); - - sd::ops::mean_sqerr_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); + predictions.linspace(0.5, 0.5); + labels.linspace(1); + weights.assign(0.5); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::mean_sqerr_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); - auto *result = results.at(0); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), 25.520834, 1e-5); + auto result = results.at(0); - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 25.520834, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, mean_sqerr_loss_test16) { + auto labels = NDArrayFactory::create('c', {2, 3, 4}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 3, 1}); - auto labels = NDArrayFactory::create('c', {2,3,4}); - auto predictions = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,3,1}); + predictions.linspace(0.5, 0.5); + labels.linspace(1); + weights.assign(0.5); + weights.p(0, 0.); + weights.p(1, 0.); + weights.p(2, 0.); - predictions.linspace(0.5, 0.5); - labels.linspace(1); - weights.assign(0.5); - weights.p(0, 0.); - weights.p(1, 0.); - weights.p(2, 0.); + sd::ops::mean_sqerr_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); - sd::ops::mean_sqerr_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto result = results.at(0); - auto *result = results.at(0); - - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), 44.270832, 1e-5); - - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 44.270832, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, sigm_cross_entropy_loss_test1) { + auto labels = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0}); + auto logits = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 3, 4}); + auto expected = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0.37219834, 0.29906943, 0.27717763, 0.45650762, 0.23703849, 0.51874399, + 0.20159303, 0.58555031, 0.17057693, 0.65663081, 0.14366767, 0.73164123, + 0.12050423, 0.81020868, 0.10070664, 0.89195037, 0.08389302, 0.97648883, + 1.01969337, 0.06346401, 0.05775976, 1.15254164, 0.04777273, 1.2434181}); - auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); - auto logits = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,3,4}); - auto expected = NDArrayFactory::create('c', {2,3,4}, {0.37219834,0.29906943,0.27717763,0.45650762,0.23703849,0.51874399,0.20159303,0.58555031,0.17057693,0.65663081,0.14366767,0.73164123,0.12050423,0.81020868,0.10070664,0.89195037,0.08389302,0.97648883,1.01969337,0.06346401,0.05775976,1.15254164,0.04777273,1.2434181 }); - - logits.linspace(0.1, 0.1); - weights.assign(0.5); - - sd::ops::sigm_cross_entropy_loss op; - auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {0}); + logits.linspace(0.1, 0.1); + weights.assign(0.5); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::sigm_cross_entropy_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {0}); - auto *result = results.at(0); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); + auto result = results.at(0); - + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, sigm_cross_entropy_loss_test2) { + auto labels = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0}); + auto logits = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 1, 1}); + auto expected = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0.37219834, 0.29906943, 0.27717763, 0.45650762, 0.23703849, 0.51874399, + 0.20159303, 0.58555031, 0.17057693, 0.65663081, 0.14366767, 0.73164123, + 0.12050423, 0.81020868, 0.10070664, 0.89195037, 0.08389302, 0.97648883, + 1.01969337, 0.06346401, 0.05775976, 1.15254164, 0.04777273, 1.2434181}); - auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); - auto logits = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,1,1}); - auto expected = NDArrayFactory::create('c', {2,3,4}, {0.37219834,0.29906943,0.27717763,0.45650762,0.23703849,0.51874399,0.20159303,0.58555031,0.17057693,0.65663081,0.14366767,0.73164123,0.12050423,0.81020868,0.10070664,0.89195037,0.08389302,0.97648883,1.01969337,0.06346401,0.05775976,1.15254164,0.04777273,1.2434181 }); + logits.linspace(0.1, 0.1); + weights.assign(0.5); - logits.linspace(0.1, 0.1); - weights.assign(0.5); + sd::ops::sigm_cross_entropy_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {0}); - sd::ops::sigm_cross_entropy_loss op; - auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto result = results.at(0); - auto *result = results.at(0); - - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); - - + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, sigm_cross_entropy_loss_test3) { + auto labels = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0}); + auto logits = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {1, 1}); + auto expected = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0.37219834, 0.29906943, 0.27717763, 0.45650762, 0.23703849, 0.51874399, + 0.20159303, 0.58555031, 0.17057693, 0.65663081, 0.14366767, 0.73164123, + 0.12050423, 0.81020868, 0.10070664, 0.89195037, 0.08389302, 0.97648883, + 1.01969337, 0.06346401, 0.05775976, 1.15254164, 0.04777273, 1.2434181}); - auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); - auto logits = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {1,1}); - auto expected = NDArrayFactory::create('c', {2,3,4}, {0.37219834,0.29906943,0.27717763,0.45650762,0.23703849,0.51874399,0.20159303,0.58555031,0.17057693,0.65663081,0.14366767,0.73164123,0.12050423,0.81020868,0.10070664,0.89195037,0.08389302,0.97648883,1.01969337,0.06346401,0.05775976,1.15254164,0.04777273,1.2434181 }); - - logits.linspace(0.1, 0.1); - weights.assign(0.5); - - sd::ops::sigm_cross_entropy_loss op; - auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {0}); + logits.linspace(0.1, 0.1); + weights.assign(0.5); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::sigm_cross_entropy_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {0}); - auto *result = results.at(0); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); + auto result = results.at(0); - + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, sigm_cross_entropy_loss_test4) { + auto labels = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0}); + auto logits = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 3, 4}); + auto expected = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0.24719833, 0.54906946, 0.65217763, -0.04349237, 0.86203849, + -0.23125602, 1.07659304, -0.41444966, 1.29557693, -0.59336919, + 1.5186677, -0.76835877, 1.74550426, -0.93979132, 1.9757067, + -1.10804963, 2.20889306, -1.27351117, -1.35530663, 2.56346393, + 2.68275976, -1.59745836, 2.92277265, -1.7565819}); - auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); - auto logits = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,3,4}); - auto expected = NDArrayFactory::create('c', {2,3,4}, {0.24719833, 0.54906946, 0.65217763,-0.04349237,0.86203849,-0.23125602, 1.07659304,-0.41444966,1.29557693,-0.59336919, 1.5186677 ,-0.76835877,1.74550426,-0.93979132, 1.9757067 ,-1.10804963,2.20889306,-1.27351117,-1.35530663, 2.56346393,2.68275976,-1.59745836, 2.92277265,-1.7565819 }); + logits.linspace(0.1, 0.1); + weights.assign(0.5); - logits.linspace(0.1, 0.1); - weights.assign(0.5); + sd::ops::sigm_cross_entropy_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {0}); - sd::ops::sigm_cross_entropy_loss op; - auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto result = results.at(0); - auto *result = results.at(0); - - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); - - + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, sigm_cross_entropy_loss_test5) { + auto labels = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0}); + auto logits = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 3, 4}); - auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); - auto logits = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,3,4}); - - logits.linspace(0.1, 0.1); - weights.assign(0.5); - - sd::ops::sigm_cross_entropy_loss op; - auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {1}); + logits.linspace(0.1, 0.1); + weights.assign(0.5); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::sigm_cross_entropy_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {1}); - auto *result = results.at(0); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), 11.2187976837, 1e-5); + auto result = results.at(0); - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 11.2187976837, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, sigm_cross_entropy_loss_test6) { + auto labels = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0}); + auto logits = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 3, 1}); - auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); - auto logits = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,3,1}); + logits.linspace(0.1, 0.1); + weights.assign(0.5); - logits.linspace(0.1, 0.1); - weights.assign(0.5); + sd::ops::sigm_cross_entropy_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {1}); - sd::ops::sigm_cross_entropy_loss op; - auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {1}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto result = results.at(0); - auto *result = results.at(0); - - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), 11.2187976837, 1e-5); - - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 11.2187976837, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, sigm_cross_entropy_loss_test7) { + auto labels = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0}); + auto logits = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {1, 1}); - auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); - auto logits = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {1,1}); - - logits.linspace(0.1, 0.1); - weights.assign(0.5); + logits.linspace(0.1, 0.1); + weights.assign(0.5); - sd::ops::sigm_cross_entropy_loss op; - auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {1}); + sd::ops::sigm_cross_entropy_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {1}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), 11.2187976837, 1e-5); - - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 11.2187976837, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, sigm_cross_entropy_loss_test8) { + auto labels = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0}); + auto logits = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 3, 4}); - auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); - auto logits = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,3,4}); - - logits.linspace(0.1, 0.1); - weights.assign(0.5); + logits.linspace(0.1, 0.1); + weights.assign(0.5); - sd::ops::sigm_cross_entropy_loss op; - auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {1}); + sd::ops::sigm_cross_entropy_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {1}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), 10.2187976837, 1e-5); - - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 10.2187976837, 1e-5); } - /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, sigm_cross_entropy_loss_test9) { + auto labels = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0}); + auto logits = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 3, 1}); - auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); - auto logits = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,3,1}); + logits.linspace(0.1, 0.1); + weights.assign(0.5); + weights.p(0, 0.); + weights.p(1, 0.); + weights.p(2, 0.); - logits.linspace(0.1, 0.1); - weights.assign(0.5); - weights.p(0, 0.); - weights.p(1, 0.); - weights.p(2, 0.); + sd::ops::sigm_cross_entropy_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {1}); - sd::ops::sigm_cross_entropy_loss op; - auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {1}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto result = results.at(0); - auto *result = results.at(0); - - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), 6.06840181351, 1e-5); - - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 6.06840181351, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, sigm_cross_entropy_loss_test10) { + auto labels = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0}); + auto logits = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 3, 4}); - auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); - auto logits = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,3,4}); - - logits.linspace(0.1, 0.1); - weights.assign(0.5); + logits.linspace(0.1, 0.1); + weights.assign(0.5); - sd::ops::sigm_cross_entropy_loss op; - auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {2}); + sd::ops::sigm_cross_entropy_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {2}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), 0.934899806976, 1e-5); - - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 0.934899806976, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, sigm_cross_entropy_loss_test11) { + auto labels = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0}); + auto logits = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {1, 1, 4}); - auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); - auto logits = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {1,1,4}); - - logits.linspace(0.1, 0.1); - weights.assign(0.5); + logits.linspace(0.1, 0.1); + weights.assign(0.5); - sd::ops::sigm_cross_entropy_loss op; - auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {2}); + sd::ops::sigm_cross_entropy_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {2}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), 0.934899806976, 1e-5); - - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 0.934899806976, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, sigm_cross_entropy_loss_test12) { + auto labels = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0}); + auto logits = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {1, 1}); - auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); - auto logits = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {1,1}); - - logits.linspace(0.1, 0.1); - weights.assign(0.5); + logits.linspace(0.1, 0.1); + weights.assign(0.5); - sd::ops::sigm_cross_entropy_loss op; - auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {2}); + sd::ops::sigm_cross_entropy_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {2}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), 0.851566493511, 1e-5); - - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 0.851566493511, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, sigm_cross_entropy_loss_test13) { + auto labels = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0}); + auto logits = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 3, 1}); - auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); - auto logits = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,3,1}); - - logits.linspace(0.1, 0.1); - weights.assign(0.5); - weights.p(0, 0.); - weights.p(1, 0.); - weights.p(2, 0.); - - sd::ops::sigm_cross_entropy_loss op; - auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {2}); + logits.linspace(0.1, 0.1); + weights.assign(0.5); + weights.p(0, 0.); + weights.p(1, 0.); + weights.p(2, 0.); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::sigm_cross_entropy_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {2}); - auto *result = results.at(0); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), 1.01140034199, 1e-5); + auto result = results.at(0); - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 1.01140034199, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, sigm_cross_entropy_loss_test14) { + auto labels = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0}); + auto logits = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 3, 4}); - auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); - auto logits = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,3,4}); + logits.linspace(0.1, 0.1); + weights.assign(0.5); - logits.linspace(0.1, 0.1); - weights.assign(0.5); + sd::ops::sigm_cross_entropy_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {3}); - sd::ops::sigm_cross_entropy_loss op; - auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {3}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto result = results.at(0); - auto *result = results.at(0); - - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), 0.467449903488, 1e-5); - - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 0.467449903488, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, sigm_cross_entropy_loss_test15) { + auto labels = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0}); + auto logits = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {1, 3, 1}); - auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); - auto logits = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {1,3,1}); - - logits.linspace(0.1, 0.1); - weights.assign(0.5); - - sd::ops::sigm_cross_entropy_loss op; - auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {3}); + logits.linspace(0.1, 0.1); + weights.assign(0.5); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::sigm_cross_entropy_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {3}); - auto *result = results.at(0); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), 0.467449903488, 1e-5); + auto result = results.at(0); - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 0.467449903488, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, sigm_cross_entropy_loss_test16) { + auto labels = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0}); + auto logits = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {1, 1}); - auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); - auto logits = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {1,1}); + logits.linspace(0.1, 0.1); + weights.assign(0.5); - logits.linspace(0.1, 0.1); - weights.assign(0.5); + sd::ops::sigm_cross_entropy_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {3}); - sd::ops::sigm_cross_entropy_loss op; - auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {3}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto result = results.at(0); - auto *result = results.at(0); - - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), 0.425783246756, 1e-5); - - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 0.425783246756, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, sigm_cross_entropy_loss_test17) { + auto labels = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0}); + auto logits = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 3, 1}); - auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); - auto logits = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,3,1}); - - logits.linspace(0.1, 0.1); - weights.assign(0.5); - weights.p(0, 0.); - weights.p(1, 0.); - weights.p(2, 0.); + logits.linspace(0.1, 0.1); + weights.assign(0.5); + weights.p(0, 0.); + weights.p(1, 0.); + weights.p(2, 0.); - sd::ops::sigm_cross_entropy_loss op; - auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {3}); + sd::ops::sigm_cross_entropy_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {3}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), 0.505700170994, 1e-5); - - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 0.505700170994, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, softmax_cross_entropy_loss_test1) { + auto labels = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0}); + auto logits = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 3}); + auto expected = NDArrayFactory::create( + 'c', {2, 3}, + {1.39253557, 1.44253552, 1.44253552, 1.44253552, 1.39253557, 1.44253552}); - auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); - auto logits = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,3}); - auto expected = NDArrayFactory::create('c', {2,3}, {1.39253557,1.44253552,1.44253552,1.44253552,1.39253557,1.44253552}); - - logits.linspace(0.1, 0.1); - weights.assign(0.5); + logits.linspace(0.1, 0.1); + weights.assign(0.5); - sd::ops::softmax_cross_entropy_loss op; - auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {0}); + sd::ops::softmax_cross_entropy_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {0}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); - - + auto result = results.at(0); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, softmax_cross_entropy_loss_test2) { + auto labels = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0}); + auto logits = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 3}); + auto expected = + NDArrayFactory::create('c', {2, 3}, + {-0.92835701, -1.12835705, -1.12835705, + -1.12835705, -0.92835701, -1.12835705}); - auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); - auto logits = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,3}); - auto expected = NDArrayFactory::create('c', {2,3}, {-0.92835701,-1.12835705,-1.12835705,-1.12835705,-0.92835701,-1.12835705}); - - logits.linspace(0.1, 0.1); - weights.assign(0.5); + logits.linspace(0.1, 0.1); + weights.assign(0.5); - sd::ops::softmax_cross_entropy_loss op; - auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {0}, {}); + sd::ops::softmax_cross_entropy_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {0}, {}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); - - + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, softmax_cross_entropy_loss_test3) { + auto labels = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0}); + auto logits = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 1}); + auto expected = + NDArrayFactory::create('c', {2, 3}, + {-0.92835701, -1.12835705, -1.12835705, + -1.12835705, -0.92835701, -1.12835705}); - auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); - auto logits = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,1}); - auto expected = NDArrayFactory::create('c', {2,3}, {-0.92835701,-1.12835705,-1.12835705,-1.12835705,-0.92835701,-1.12835705}); - - logits.linspace(0.1, 0.1); - weights.assign(0.5); - - sd::ops::softmax_cross_entropy_loss op; - auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {0}); + logits.linspace(0.1, 0.1); + weights.assign(0.5); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::softmax_cross_entropy_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {0}); - auto *result = results.at(0); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); + auto result = results.at(0); - + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, softmax_cross_entropy_loss_test4) { + auto labels = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0}); + auto logits = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {1, 3}); + auto expected = + NDArrayFactory::create('c', {2, 3}, + {-0.92835701, -1.12835705, -1.12835705, + -1.12835705, -0.92835701, -1.12835705}); - auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); - auto logits = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {1,3}); - auto expected = NDArrayFactory::create('c', {2,3}, {-0.92835701,-1.12835705,-1.12835705,-1.12835705,-0.92835701,-1.12835705}); + logits.linspace(0.1, 0.1); + weights.assign(0.5); - logits.linspace(0.1, 0.1); - weights.assign(0.5); + sd::ops::softmax_cross_entropy_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {0}); - sd::ops::softmax_cross_entropy_loss op; - auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto result = results.at(0); - auto *result = results.at(0); - - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); - - + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, softmax_cross_entropy_loss_test5) { + auto labels = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0}); + auto logits = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {1, 1}); + auto expected = + NDArrayFactory::create('c', {2, 3}, + {-0.92835701, -1.12835705, -1.12835705, + -1.12835705, -0.92835701, -1.12835705}); - auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); - auto logits = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {1,1}); - auto expected = NDArrayFactory::create('c', {2,3}, {-0.92835701,-1.12835705,-1.12835705,-1.12835705,-0.92835701,-1.12835705}); - - logits.linspace(0.1, 0.1); - weights.assign(0.5); - - sd::ops::softmax_cross_entropy_loss op; - auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {0}); + logits.linspace(0.1, 0.1); + weights.assign(0.5); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::softmax_cross_entropy_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {0}); - auto *result = results.at(0); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); + auto result = results.at(0); - + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, softmax_cross_entropy_loss_test6) { + auto labels = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0}); + auto logits = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 3}); - auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); - auto logits = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,3}); + logits.linspace(0.1, 0.1); + weights.assign(0.5); - logits.linspace(0.1, 0.1); - weights.assign(0.5); + sd::ops::softmax_cross_entropy_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {1}); - sd::ops::softmax_cross_entropy_loss op; - auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {1}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto result = results.at(0); - auto *result = results.at(0); - - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), 8.55521392822, 1e-5); - - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 8.55521392822, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, softmax_cross_entropy_loss_test7) { + auto labels = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0}); + auto logits = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 3}); - auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); - auto logits = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,3}); - - logits.linspace(0.1, 0.1); - weights.assign(0.5); + logits.linspace(0.1, 0.1); + weights.assign(0.5); - sd::ops::softmax_cross_entropy_loss op; - auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {1}); + sd::ops::softmax_cross_entropy_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {1}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), -6.37014198303, 1e-5); - - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), -6.37014198303, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, softmax_cross_entropy_loss_test8) { + auto labels = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0}); + auto logits = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {1, 1}); - auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); - auto logits = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {1,1}); - - logits.linspace(0.1, 0.1); - weights.assign(0.5); + logits.linspace(0.1, 0.1); + weights.assign(0.5); - sd::ops::softmax_cross_entropy_loss op; - auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {1}); + sd::ops::softmax_cross_entropy_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {1}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), -6.37014198303, 1e-5); - - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), -6.37014198303, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, softmax_cross_entropy_loss_test9) { + auto labels = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0}); + auto logits = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {1, 3}); - auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); - auto logits = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {1,3}); - - logits.linspace(0.1, 0.1); - weights.assign(0.5); + logits.linspace(0.1, 0.1); + weights.assign(0.5); - sd::ops::softmax_cross_entropy_loss op; - auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {1}); + sd::ops::softmax_cross_entropy_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {1}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), -6.37014198303, 1e-5); - - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), -6.37014198303, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, softmax_cross_entropy_loss_test10) { + auto labels = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0}); + auto logits = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {1, 3}); - auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); - auto logits = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {1,3}); - - logits.linspace(0.1, 0.1); - weights.assign(0.5); - - sd::ops::softmax_cross_entropy_loss op; - auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {2}); + logits.linspace(0.1, 0.1); + weights.assign(0.5); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::softmax_cross_entropy_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {2}); - auto *result = results.at(0); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), -2.12338066101, 1e-5); + auto result = results.at(0); - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), -2.12338066101, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, softmax_cross_entropy_loss_test11) { + auto labels = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0}); + auto logits = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {1, 3}); - auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); - auto logits = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {1,3}); + logits.linspace(0.1, 0.1); + weights.assign(0.5); - logits.linspace(0.1, 0.1); - weights.assign(0.5); + sd::ops::softmax_cross_entropy_loss op; + auto results = + op.evaluate({&logits, &weights, &labels}, {5.}, {3}, {}, {}, false); - sd::ops::softmax_cross_entropy_loss op; - auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {3}, {}, {}, false); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto result = results.at(0); - auto *result = results.at(0); - - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), -1.06169033051, 1e-5); - - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), -1.06169033051, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, softmax_cross_entropy_loss_test12) { + auto labels = + NDArrayFactory::create('c', {2, 4}, {0, 1, 1, 0, 1, 0, 1, 0}); + auto logits = NDArrayFactory::create('c', {2, 4}); + auto weights = NDArrayFactory::create('c', {2, 1}); - auto labels = NDArrayFactory::create('c', {2,4},{0,1,1,0,1,0,1,0}); - auto logits = NDArrayFactory::create('c', {2,4}); - auto weights = NDArrayFactory::create('c', {2,1}); - - logits.linspace(0.1, 0.1); - weights.assign(0.5); - - sd::ops::softmax_cross_entropy_loss op; - auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {3}, {}, {}, false); + logits.linspace(0.1, 0.1); + weights.assign(0.5); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::softmax_cross_entropy_loss op; + auto results = + op.evaluate({&logits, &weights, &labels}, {5.}, {3}, {}, {}, false); - auto *result = results.at(0); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), -2.18880319595, 1e-5); + auto result = results.at(0); - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), -2.18880319595, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, softmax_cross_entropy_loss_test13) { + auto labels = + NDArrayFactory::create('c', {2, 4}, {0, 1, 1, 0, 1, 0, 1, 0}); + auto logits = NDArrayFactory::create('c', {2, 4}); + auto weights = NDArrayFactory::create('c', {2, 1}); + auto expected = + NDArrayFactory::create('c', {2, 1}, {1.39253557, 1.44253552}); - auto labels = NDArrayFactory::create('c', {2,4},{0,1,1,0,1,0,1,0}); - auto logits = NDArrayFactory::create('c', {2,4}); - auto weights = NDArrayFactory::create('c', {2,1}); - auto expected = NDArrayFactory::create('c', {2,1}, {1.39253557,1.44253552}); + logits.linspace(0.1, 0.1); + weights.assign(0.5); - logits.linspace(0.1, 0.1); - weights.assign(0.5); + sd::ops::softmax_cross_entropy_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {0}); - sd::ops::softmax_cross_entropy_loss op; - auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto result = results.at(0); - auto *result = results.at(0); - - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); - - + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } - - /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, softmax_cross_entropy_loss_test14) { + auto labels = + NDArrayFactory::create('c', {2, 4}, {0, 1, 1, 0, 1, 0, 1, 0}); + auto logits = NDArrayFactory::create('c', {2, 4}); + auto weights = NDArrayFactory::create('c', {2, 1}); + auto expected = + NDArrayFactory::create('c', {2, 1}, {-2.08880329, -2.28880334}); - auto labels = NDArrayFactory::create('c', {2,4},{0,1,1,0,1,0,1,0}); - auto logits = NDArrayFactory::create('c', {2,4}); - auto weights = NDArrayFactory::create('c', {2,1}); - auto expected = NDArrayFactory::create('c', {2,1}, {-2.08880329, -2.28880334}); + logits.linspace(0.1, 0.1); + weights.assign(0.5); - logits.linspace(0.1, 0.1); - weights.assign(0.5); + sd::ops::softmax_cross_entropy_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {0}); - sd::ops::softmax_cross_entropy_loss op; - auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto result = results.at(0); - auto *result = results.at(0); - - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); - - + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, softmax_cross_entropy_loss_test15) { + auto labels = + NDArrayFactory::create('c', {2, 4}, {0, 1, 1, 0, 1, 0, 1, 0}); + auto logits = NDArrayFactory::create('c', {2, 4}); + auto weights = NDArrayFactory::create('c', {1, 1}); + auto expected = + NDArrayFactory::create('c', {2, 1}, {-2.08880329, -2.28880334}); - auto labels = NDArrayFactory::create('c', {2,4},{0,1,1,0,1,0,1,0}); - auto logits = NDArrayFactory::create('c', {2,4}); - auto weights = NDArrayFactory::create('c', {1,1}); - auto expected = NDArrayFactory::create('c', {2,1}, {-2.08880329, -2.28880334}); - - logits.linspace(0.1, 0.1); - weights.assign(0.5); - - sd::ops::softmax_cross_entropy_loss op; - auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {0}); + logits.linspace(0.1, 0.1); + weights.assign(0.5); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::softmax_cross_entropy_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {0}); - auto *result = results.at(0); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); + auto result = results.at(0); - + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, lstmCell_test1) { - - const int batchSize = 2; - const int inSize = 10; - const int numProj = 4; - const int numUnits = 4; - - auto xt = NDArrayFactory::create('c', {batchSize, inSize}); - auto ht_1 = NDArrayFactory::create('c', {batchSize, numProj}); - auto ct_1 = NDArrayFactory::create('c', {batchSize, numUnits}); - auto Wx = NDArrayFactory::create('c', {inSize, 4*numUnits}); - auto Wh = NDArrayFactory::create('c', {numProj, 4*numUnits}); - auto Wc = NDArrayFactory::create('c', {3*numUnits}); - auto Wp = NDArrayFactory::create('c', {numUnits, numProj}); - auto b = NDArrayFactory::create('c', {4*numUnits}); - - xt.assign(1.); - ht_1.assign(2.); - ct_1.assign(3.); - Wx.assign(0.5); - Wh.assign(0.5); - Wc.assign(0.5); - Wp.assign(0.5); - b.assign(0.7); - - auto expHt = NDArrayFactory::create('c', {batchSize, numProj}, {0.99926789,0.99926789,0.99926789,0.99926789,0.99926789,0.99926789,0.99926789,0.99926789}); - auto expCt = NDArrayFactory::create('c', {batchSize, numUnits},{3.99987108,3.99987108,3.99987108,3.99987108,3.99987108,3.99987108,3.99987108,3.99987108}); - - sd::ops::lstmCell op; - auto results = op.evaluate({&xt, &ht_1, &ct_1, &Wx, &Wh, &Wc, &Wp, &b}, {0., 0., 1.}, {0, 0}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto *ht = results.at(0); - auto *ct = results.at(1); - - ASSERT_TRUE(expHt.isSameShape(ht)); - ASSERT_TRUE(expHt.equalsTo(ht)); - ASSERT_TRUE(expCt.isSameShape(ct)); - ASSERT_TRUE(expCt.equalsTo(ct)); - - + const int batchSize = 2; + const int inSize = 10; + const int numProj = 4; + const int numUnits = 4; + + auto xt = NDArrayFactory::create('c', {batchSize, inSize}); + auto ht_1 = NDArrayFactory::create('c', {batchSize, numProj}); + auto ct_1 = NDArrayFactory::create('c', {batchSize, numUnits}); + auto Wx = NDArrayFactory::create('c', {inSize, 4 * numUnits}); + auto Wh = NDArrayFactory::create('c', {numProj, 4 * numUnits}); + auto Wc = NDArrayFactory::create('c', {3 * numUnits}); + auto Wp = NDArrayFactory::create('c', {numUnits, numProj}); + auto b = NDArrayFactory::create('c', {4 * numUnits}); + + xt.assign(1.); + ht_1.assign(2.); + ct_1.assign(3.); + Wx.assign(0.5); + Wh.assign(0.5); + Wc.assign(0.5); + Wp.assign(0.5); + b.assign(0.7); + + auto expHt = NDArrayFactory::create( + 'c', {batchSize, numProj}, + {0.99926789, 0.99926789, 0.99926789, 0.99926789, 0.99926789, 0.99926789, + 0.99926789, 0.99926789}); + auto expCt = NDArrayFactory::create( + 'c', {batchSize, numUnits}, + {3.99987108, 3.99987108, 3.99987108, 3.99987108, 3.99987108, 3.99987108, + 3.99987108, 3.99987108}); + + sd::ops::lstmCell op; + auto results = op.evaluate({&xt, &ht_1, &ct_1, &Wx, &Wh, &Wc, &Wp, &b}, + {0., 0., 1.}, {0, 0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto ht = results.at(0); + auto ct = results.at(1); + + ASSERT_TRUE(expHt.isSameShape(ht)); + ASSERT_TRUE(expHt.equalsTo(ht)); + ASSERT_TRUE(expCt.isSameShape(ct)); + ASSERT_TRUE(expCt.equalsTo(ct)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, lstmCell_test2) { - - const int batchSize = 2; - const int inSize = 10; - const int numProj = 4; - const int numUnits = 4; - - auto xt = NDArrayFactory::create('c', {batchSize, inSize}); - auto ht_1 = NDArrayFactory::create('c', {batchSize, numProj}); - auto ct_1 = NDArrayFactory::create('c', {batchSize, numUnits}); - auto Wx = NDArrayFactory::create('c', {inSize, 4*numUnits}); - auto Wh = NDArrayFactory::create('c', {numProj, 4*numUnits}); - auto Wc = NDArrayFactory::create('c', {3*numUnits}); - auto Wp = NDArrayFactory::create('c', {numUnits, numProj}); - auto b = NDArrayFactory::create('c', {4*numUnits}); - - xt.assign(1.); - ht_1.assign(2.); - ct_1.assign(3.); - Wx.assign(0.5); - Wh.assign(0.5); - Wc.assign(0.5); - Wp.assign(0.5); - b.assign(0.7); - - auto expHt = NDArrayFactory::create('c', {batchSize, numProj}, {0.95867589,0.95867589,0.95867589,0.95867589,0.95867589,0.95867589,0.95867589,0.95867589}); - auto expCt = NDArrayFactory::create('c', {batchSize, numUnits},{1.93001527,1.93001527,1.93001527,1.93001527, 1.93001527,1.93001527,1.93001527,1.93001527}); - - sd::ops::lstmCell op; - auto results = op.evaluate({&xt, &ht_1, &ct_1, &Wx, &Wh, &Wc, &Wp, &b}, {0., 0., -10.5}, {0, 0}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto *ht = results.at(0); - auto *ct = results.at(1); - - ASSERT_TRUE(expHt.isSameShape(ht)); - ASSERT_TRUE(expHt.equalsTo(ht)); - ASSERT_TRUE(expCt.isSameShape(ct)); - ASSERT_TRUE(expCt.equalsTo(ct)); - - + const int batchSize = 2; + const int inSize = 10; + const int numProj = 4; + const int numUnits = 4; + + auto xt = NDArrayFactory::create('c', {batchSize, inSize}); + auto ht_1 = NDArrayFactory::create('c', {batchSize, numProj}); + auto ct_1 = NDArrayFactory::create('c', {batchSize, numUnits}); + auto Wx = NDArrayFactory::create('c', {inSize, 4 * numUnits}); + auto Wh = NDArrayFactory::create('c', {numProj, 4 * numUnits}); + auto Wc = NDArrayFactory::create('c', {3 * numUnits}); + auto Wp = NDArrayFactory::create('c', {numUnits, numProj}); + auto b = NDArrayFactory::create('c', {4 * numUnits}); + + xt.assign(1.); + ht_1.assign(2.); + ct_1.assign(3.); + Wx.assign(0.5); + Wh.assign(0.5); + Wc.assign(0.5); + Wp.assign(0.5); + b.assign(0.7); + + auto expHt = NDArrayFactory::create( + 'c', {batchSize, numProj}, + {0.95867589, 0.95867589, 0.95867589, 0.95867589, 0.95867589, 0.95867589, + 0.95867589, 0.95867589}); + auto expCt = NDArrayFactory::create( + 'c', {batchSize, numUnits}, + {1.93001527, 1.93001527, 1.93001527, 1.93001527, 1.93001527, 1.93001527, + 1.93001527, 1.93001527}); + + sd::ops::lstmCell op; + auto results = op.evaluate({&xt, &ht_1, &ct_1, &Wx, &Wh, &Wc, &Wp, &b}, + {0., 0., -10.5}, {0, 0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto ht = results.at(0); + auto ct = results.at(1); + + ASSERT_TRUE(expHt.isSameShape(ht)); + ASSERT_TRUE(expHt.equalsTo(ht)); + ASSERT_TRUE(expCt.isSameShape(ct)); + ASSERT_TRUE(expCt.equalsTo(ct)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, lstmCell_test3) { - - const int batchSize = 2; - const int inSize = 10; - const int numProj = 4; - const int numUnits = 4; - - auto xt = NDArrayFactory::create('c', {batchSize, inSize}); - auto ht_1 = NDArrayFactory::create('c', {batchSize, numProj}); - auto ct_1 = NDArrayFactory::create('c', {batchSize, numUnits}); - auto Wx = NDArrayFactory::create('c', {inSize, 4*numUnits}); - auto Wh = NDArrayFactory::create('c', {numProj, 4*numUnits}); - auto Wc = NDArrayFactory::create('c', {3*numUnits}); - auto Wp = NDArrayFactory::create('c', {numUnits, numProj}); - auto b = NDArrayFactory::create('c', {4*numUnits}); - - xt.assign(1.); - ht_1.assign(2.); - ct_1.assign(3.); - Wx.assign(0.5); - Wh.assign(0.5); - Wc.assign(0.5); - Wp.assign(0.5); - b.assign(0.7); - - auto expHt = NDArrayFactory::create('c', {batchSize, numProj}, {0.37992568,0.37992568,0.37992568,0.37992568,0.37992568,0.37992568,0.37992568,0.37992568}); - auto expCt = NDArrayFactory::create('c', {batchSize, numUnits},{0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4}); - - sd::ops::lstmCell op; - auto results = op.evaluate({&xt, &ht_1, &ct_1, &Wx, &Wh, &Wc, &Wp, &b}, {0.4, 0., 1.5}, {0, 0}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto *ht = results.at(0); - auto *ct = results.at(1); - - ASSERT_TRUE(expHt.isSameShape(ht)); - ASSERT_TRUE(expHt.equalsTo(ht)); - ASSERT_TRUE(expCt.isSameShape(ct)); - ASSERT_TRUE(expCt.equalsTo(ct)); - - + const int batchSize = 2; + const int inSize = 10; + const int numProj = 4; + const int numUnits = 4; + + auto xt = NDArrayFactory::create('c', {batchSize, inSize}); + auto ht_1 = NDArrayFactory::create('c', {batchSize, numProj}); + auto ct_1 = NDArrayFactory::create('c', {batchSize, numUnits}); + auto Wx = NDArrayFactory::create('c', {inSize, 4 * numUnits}); + auto Wh = NDArrayFactory::create('c', {numProj, 4 * numUnits}); + auto Wc = NDArrayFactory::create('c', {3 * numUnits}); + auto Wp = NDArrayFactory::create('c', {numUnits, numProj}); + auto b = NDArrayFactory::create('c', {4 * numUnits}); + + xt.assign(1.); + ht_1.assign(2.); + ct_1.assign(3.); + Wx.assign(0.5); + Wh.assign(0.5); + Wc.assign(0.5); + Wp.assign(0.5); + b.assign(0.7); + + auto expHt = NDArrayFactory::create( + 'c', {batchSize, numProj}, + {0.37992568, 0.37992568, 0.37992568, 0.37992568, 0.37992568, 0.37992568, + 0.37992568, 0.37992568}); + auto expCt = NDArrayFactory::create( + 'c', {batchSize, numUnits}, {0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4}); + + sd::ops::lstmCell op; + auto results = op.evaluate({&xt, &ht_1, &ct_1, &Wx, &Wh, &Wc, &Wp, &b}, + {0.4, 0., 1.5}, {0, 0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto ht = results.at(0); + auto ct = results.at(1); + + ASSERT_TRUE(expHt.isSameShape(ht)); + ASSERT_TRUE(expHt.equalsTo(ht)); + ASSERT_TRUE(expCt.isSameShape(ct)); + ASSERT_TRUE(expCt.equalsTo(ct)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, lstmCell_test4) { - - const int batchSize = 2; - const int inSize = 10; - const int numProj = 4; - const int numUnits = 4; - - auto xt = NDArrayFactory::create('c', {batchSize, inSize}); - auto ht_1 = NDArrayFactory::create('c', {batchSize, numProj}); - auto ct_1 = NDArrayFactory::create('c', {batchSize, numUnits}); - auto Wx = NDArrayFactory::create('c', {inSize, 4*numUnits}); - auto Wh = NDArrayFactory::create('c', {numProj, 4*numUnits}); - auto Wc = NDArrayFactory::create('c', {3*numUnits}); - auto Wp = NDArrayFactory::create('c', {numUnits, numProj}); - auto b = NDArrayFactory::create('c', {4*numUnits}); - - xt.assign(1.); - ht_1.assign(2.); - ct_1.assign(3.); - Wx.assign(0.5); - Wh.assign(0.5); - Wc.assign(0.5); - Wp.assign(0.5); - b.assign(0.7); - - auto expHt = NDArrayFactory::create('c', {batchSize, numProj}, {0.37992568,0.37992568,0.37992568,0.37992568,0.37992568,0.37992568,0.37992568,0.37992568}); - auto expCt = NDArrayFactory::create('c', {batchSize, numUnits},{0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4}); - - sd::ops::lstmCell op; - auto results = op.evaluate({&xt, &ht_1, &ct_1, &Wx, &Wh, &Wc, &Wp, &b}, {0.4, 0.3, 1.5}, {0, 0}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto *ht = results.at(0); - auto *ct = results.at(1); - - ASSERT_TRUE(expHt.isSameShape(ht)); - ASSERT_TRUE(expHt.equalsTo(ht)); - ASSERT_TRUE(expCt.isSameShape(ct)); - ASSERT_TRUE(expCt.equalsTo(ct)); - - + const int batchSize = 2; + const int inSize = 10; + const int numProj = 4; + const int numUnits = 4; + + auto xt = NDArrayFactory::create('c', {batchSize, inSize}); + auto ht_1 = NDArrayFactory::create('c', {batchSize, numProj}); + auto ct_1 = NDArrayFactory::create('c', {batchSize, numUnits}); + auto Wx = NDArrayFactory::create('c', {inSize, 4 * numUnits}); + auto Wh = NDArrayFactory::create('c', {numProj, 4 * numUnits}); + auto Wc = NDArrayFactory::create('c', {3 * numUnits}); + auto Wp = NDArrayFactory::create('c', {numUnits, numProj}); + auto b = NDArrayFactory::create('c', {4 * numUnits}); + + xt.assign(1.); + ht_1.assign(2.); + ct_1.assign(3.); + Wx.assign(0.5); + Wh.assign(0.5); + Wc.assign(0.5); + Wp.assign(0.5); + b.assign(0.7); + + auto expHt = NDArrayFactory::create( + 'c', {batchSize, numProj}, + {0.37992568, 0.37992568, 0.37992568, 0.37992568, 0.37992568, 0.37992568, + 0.37992568, 0.37992568}); + auto expCt = NDArrayFactory::create( + 'c', {batchSize, numUnits}, {0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4}); + + sd::ops::lstmCell op; + auto results = op.evaluate({&xt, &ht_1, &ct_1, &Wx, &Wh, &Wc, &Wp, &b}, + {0.4, 0.3, 1.5}, {0, 0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto ht = results.at(0); + auto ct = results.at(1); + + ASSERT_TRUE(expHt.isSameShape(ht)); + ASSERT_TRUE(expHt.equalsTo(ht)); + ASSERT_TRUE(expCt.isSameShape(ct)); + ASSERT_TRUE(expCt.equalsTo(ct)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, lstmCell_test5) { - - const int batchSize = 2; - const int inSize = 10; - const int numProj = 3; - const int numUnits = 4; - - auto xt = NDArrayFactory::create('c', {batchSize, inSize}); - auto ht_1 = NDArrayFactory::create('c', {batchSize, numProj}); - auto ct_1 = NDArrayFactory::create('c', {batchSize, numUnits}); - auto Wx = NDArrayFactory::create('c', {inSize, 4*numUnits}); - auto Wh = NDArrayFactory::create('c', {numProj, 4*numUnits}); - auto Wc = NDArrayFactory::create('c', {3*numUnits}); - auto Wp = NDArrayFactory::create('c', {numUnits, numProj}); - auto b = NDArrayFactory::create('c', {4*numUnits}); - - xt.assign(1.); - ht_1.assign(2.); - ct_1.assign(3.); - Wx.assign(0.5); - Wh.assign(0.5); - Wc.assign(0.5); - Wp.assign(0.5); - b.assign(0.7); - - auto expHt = NDArrayFactory::create('c', {batchSize, numProj}, {0.3,0.3,0.3,0.3,0.3,0.3}); - auto expCt = NDArrayFactory::create('c', {batchSize, numUnits},{0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4}); - - sd::ops::lstmCell op; - auto results = op.evaluate({&xt, &ht_1, &ct_1, &Wx, &Wh, &Wc, &Wp, &b}, {0.4, 0.3, 1.5}, {0, 1}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto *ht = results.at(0); - auto *ct = results.at(1); - - ASSERT_TRUE(expHt.isSameShape(ht)); - ASSERT_TRUE(expHt.equalsTo(ht)); - ASSERT_TRUE(expCt.isSameShape(ct)); - ASSERT_TRUE(expCt.equalsTo(ct)); - - + const int batchSize = 2; + const int inSize = 10; + const int numProj = 3; + const int numUnits = 4; + + auto xt = NDArrayFactory::create('c', {batchSize, inSize}); + auto ht_1 = NDArrayFactory::create('c', {batchSize, numProj}); + auto ct_1 = NDArrayFactory::create('c', {batchSize, numUnits}); + auto Wx = NDArrayFactory::create('c', {inSize, 4 * numUnits}); + auto Wh = NDArrayFactory::create('c', {numProj, 4 * numUnits}); + auto Wc = NDArrayFactory::create('c', {3 * numUnits}); + auto Wp = NDArrayFactory::create('c', {numUnits, numProj}); + auto b = NDArrayFactory::create('c', {4 * numUnits}); + + xt.assign(1.); + ht_1.assign(2.); + ct_1.assign(3.); + Wx.assign(0.5); + Wh.assign(0.5); + Wc.assign(0.5); + Wp.assign(0.5); + b.assign(0.7); + + auto expHt = NDArrayFactory::create('c', {batchSize, numProj}, + {0.3, 0.3, 0.3, 0.3, 0.3, 0.3}); + auto expCt = NDArrayFactory::create( + 'c', {batchSize, numUnits}, {0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4}); + + sd::ops::lstmCell op; + auto results = op.evaluate({&xt, &ht_1, &ct_1, &Wx, &Wh, &Wc, &Wp, &b}, + {0.4, 0.3, 1.5}, {0, 1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto ht = results.at(0); + auto ct = results.at(1); + + ASSERT_TRUE(expHt.isSameShape(ht)); + ASSERT_TRUE(expHt.equalsTo(ht)); + ASSERT_TRUE(expCt.isSameShape(ct)); + ASSERT_TRUE(expCt.equalsTo(ct)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, lstmCell_test6) { - - const int batchSize = 2; - const int inSize = 10; - const int numProj = 3; - const int numUnits = 4; - - auto xt = NDArrayFactory::create('c', {batchSize, inSize}); - auto ht_1 = NDArrayFactory::create('c', {batchSize, numProj}); - auto ct_1 = NDArrayFactory::create('c', {batchSize, numUnits}); - auto Wx = NDArrayFactory::create('c', {inSize, 4*numUnits}); - auto Wh = NDArrayFactory::create('c', {numProj, 4*numUnits}); - auto Wc = NDArrayFactory::create('c', {3*numUnits}); - auto Wp = NDArrayFactory::create('c', {numUnits, numProj}); - auto b = NDArrayFactory::create('c', {4*numUnits}); - - xt.assign(1.); - ht_1.assign(2.); - ct_1.assign(3.); - Wx.assign(0.5); - Wh.assign(0.5); - Wc.assign(0.5); - Wp.assign(0.5); - b.assign(0.7); - - auto expHt = NDArrayFactory::create('c', {batchSize, numProj}, {1.99832496,1.99832496,1.99832496,1.99832496,1.99832496,1.99832496}); - auto expCt = NDArrayFactory::create('c', {batchSize, numUnits},{3.99972188,3.99972188,3.99972188,3.99972188,3.99972188,3.99972188,3.99972188,3.99972188}); - - sd::ops::lstmCell op; - auto results = op.evaluate({&xt, &ht_1, &ct_1, &Wx, &Wh, &Wc, &Wp, &b}, {0., 0., 1.5}, {0, 1}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto *ht = results.at(0); - auto *ct = results.at(1); - - ASSERT_TRUE(expHt.isSameShape(ht)); - ASSERT_TRUE(expHt.equalsTo(ht)); - ASSERT_TRUE(expCt.isSameShape(ct)); - ASSERT_TRUE(expCt.equalsTo(ct)); - - + const int batchSize = 2; + const int inSize = 10; + const int numProj = 3; + const int numUnits = 4; + + auto xt = NDArrayFactory::create('c', {batchSize, inSize}); + auto ht_1 = NDArrayFactory::create('c', {batchSize, numProj}); + auto ct_1 = NDArrayFactory::create('c', {batchSize, numUnits}); + auto Wx = NDArrayFactory::create('c', {inSize, 4 * numUnits}); + auto Wh = NDArrayFactory::create('c', {numProj, 4 * numUnits}); + auto Wc = NDArrayFactory::create('c', {3 * numUnits}); + auto Wp = NDArrayFactory::create('c', {numUnits, numProj}); + auto b = NDArrayFactory::create('c', {4 * numUnits}); + + xt.assign(1.); + ht_1.assign(2.); + ct_1.assign(3.); + Wx.assign(0.5); + Wh.assign(0.5); + Wc.assign(0.5); + Wp.assign(0.5); + b.assign(0.7); + + auto expHt = NDArrayFactory::create( + 'c', {batchSize, numProj}, + {1.99832496, 1.99832496, 1.99832496, 1.99832496, 1.99832496, 1.99832496}); + auto expCt = NDArrayFactory::create( + 'c', {batchSize, numUnits}, + {3.99972188, 3.99972188, 3.99972188, 3.99972188, 3.99972188, 3.99972188, + 3.99972188, 3.99972188}); + + sd::ops::lstmCell op; + auto results = op.evaluate({&xt, &ht_1, &ct_1, &Wx, &Wh, &Wc, &Wp, &b}, + {0., 0., 1.5}, {0, 1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto ht = results.at(0); + auto ct = results.at(1); + + ASSERT_TRUE(expHt.isSameShape(ht)); + ASSERT_TRUE(expHt.equalsTo(ht)); + ASSERT_TRUE(expCt.isSameShape(ct)); + ASSERT_TRUE(expCt.equalsTo(ct)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, lstmCell_test7) { - - const int batchSize = 2; - const int inSize = 10; - const int numProj = 3; - const int numUnits = 4; - - auto xt = NDArrayFactory::create('c', {batchSize, inSize}); - auto ht_1 = NDArrayFactory::create('c', {batchSize, numProj}); - auto ct_1 = NDArrayFactory::create('c', {batchSize, numUnits}); - auto Wx = NDArrayFactory::create('c', {inSize, 4*numUnits}); - auto Wh = NDArrayFactory::create('c', {numProj, 4*numUnits}); - auto Wc = NDArrayFactory::create('c', {3*numUnits}); - auto Wp = NDArrayFactory::create('c', {numUnits, numProj}); - auto b = NDArrayFactory::create('c', {4*numUnits}); - - xt.assign(1.); - ht_1.assign(2.); - ct_1.assign(3.); - Wx.assign(0.5); - Wh.assign(0.5); - Wc.assign(0.5); - Wp.assign(0.5); - b.assign(0.7); - - auto expHt = NDArrayFactory::create('c', {batchSize, numProj}, {0.75977136,0.75977136,0.75977136,0.75977136,0.75977136,0.75977136}); - auto expCt = NDArrayFactory::create('c', {batchSize, numUnits},{0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4}); - - sd::ops::lstmCell op; - auto results = op.evaluate({&xt, &ht_1, &ct_1, &Wx, &Wh, &Wc, &Wp, &b}, {0.4, 0., 1.5}, {0, 1}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto *ht = results.at(0); - auto *ct = results.at(1); - - ASSERT_TRUE(expHt.isSameShape(ht)); - ASSERT_TRUE(expHt.equalsTo(ht)); - ASSERT_TRUE(expCt.isSameShape(ct)); - ASSERT_TRUE(expCt.equalsTo(ct)); - - + const int batchSize = 2; + const int inSize = 10; + const int numProj = 3; + const int numUnits = 4; + + auto xt = NDArrayFactory::create('c', {batchSize, inSize}); + auto ht_1 = NDArrayFactory::create('c', {batchSize, numProj}); + auto ct_1 = NDArrayFactory::create('c', {batchSize, numUnits}); + auto Wx = NDArrayFactory::create('c', {inSize, 4 * numUnits}); + auto Wh = NDArrayFactory::create('c', {numProj, 4 * numUnits}); + auto Wc = NDArrayFactory::create('c', {3 * numUnits}); + auto Wp = NDArrayFactory::create('c', {numUnits, numProj}); + auto b = NDArrayFactory::create('c', {4 * numUnits}); + + xt.assign(1.); + ht_1.assign(2.); + ct_1.assign(3.); + Wx.assign(0.5); + Wh.assign(0.5); + Wc.assign(0.5); + Wp.assign(0.5); + b.assign(0.7); + + auto expHt = NDArrayFactory::create( + 'c', {batchSize, numProj}, + {0.75977136, 0.75977136, 0.75977136, 0.75977136, 0.75977136, 0.75977136}); + auto expCt = NDArrayFactory::create( + 'c', {batchSize, numUnits}, {0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4}); + + sd::ops::lstmCell op; + auto results = op.evaluate({&xt, &ht_1, &ct_1, &Wx, &Wh, &Wc, &Wp, &b}, + {0.4, 0., 1.5}, {0, 1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto ht = results.at(0); + auto ct = results.at(1); + + ASSERT_TRUE(expHt.isSameShape(ht)); + ASSERT_TRUE(expHt.equalsTo(ht)); + ASSERT_TRUE(expCt.isSameShape(ct)); + ASSERT_TRUE(expCt.equalsTo(ct)); } - /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, lstmCell_test8) { - - const int batchSize = 2; - const int inSize = 10; - const int numProj = 4; - const int numUnits = 4; - - auto xt = NDArrayFactory::create('c', {batchSize, inSize}); - auto ht_1 = NDArrayFactory::create('c', {batchSize, numProj}); - auto ct_1 = NDArrayFactory::create('c', {batchSize, numUnits}); - auto Wx = NDArrayFactory::create('c', {inSize, 4*numUnits}); - auto Wh = NDArrayFactory::create('c', {numProj, 4*numUnits}); - auto Wc = NDArrayFactory::create('c', {3*numUnits}); - auto Wp = NDArrayFactory::create('c', {numUnits, numProj}); - auto b = NDArrayFactory::create('c', {4*numUnits}); - - xt.assign(1.); - ht_1.assign(2.); - ct_1.assign(3.); - Wx.assign(0.5); - Wh.assign(0.5); - Wc.assign(0.5); - Wp.assign(0.5); - b.assign(0.7); - - auto expHt = NDArrayFactory::create('c', {batchSize, numProj}, {0.99930672,0.99930672,0.99930672,0.99930672, 0.99930672,0.99930672,0.99930672,0.99930672}); - auto expCt = NDArrayFactory::create('c', {batchSize, numUnits},{3.99996277,3.99996277,3.99996277,3.99996277,3.99996277,3.99996277,3.99996277,3.99996277}); - - sd::ops::lstmCell op; - auto results = op.evaluate({&xt, &ht_1, &ct_1, &Wx, &Wh, &Wc, &Wp, &b}, {0., 0., 10.5}, {1, 0}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto *ht = results.at(0); - auto *ct = results.at(1); - - ASSERT_TRUE(expHt.isSameShape(ht)); - ASSERT_TRUE(expHt.equalsTo(ht,1e-4)); - ASSERT_TRUE(expCt.isSameShape(ct)); - ASSERT_TRUE(expCt.equalsTo(ct,1e-4)); - - + const int batchSize = 2; + const int inSize = 10; + const int numProj = 4; + const int numUnits = 4; + + auto xt = NDArrayFactory::create('c', {batchSize, inSize}); + auto ht_1 = NDArrayFactory::create('c', {batchSize, numProj}); + auto ct_1 = NDArrayFactory::create('c', {batchSize, numUnits}); + auto Wx = NDArrayFactory::create('c', {inSize, 4 * numUnits}); + auto Wh = NDArrayFactory::create('c', {numProj, 4 * numUnits}); + auto Wc = NDArrayFactory::create('c', {3 * numUnits}); + auto Wp = NDArrayFactory::create('c', {numUnits, numProj}); + auto b = NDArrayFactory::create('c', {4 * numUnits}); + + xt.assign(1.); + ht_1.assign(2.); + ct_1.assign(3.); + Wx.assign(0.5); + Wh.assign(0.5); + Wc.assign(0.5); + Wp.assign(0.5); + b.assign(0.7); + + auto expHt = NDArrayFactory::create( + 'c', {batchSize, numProj}, + {0.99930672, 0.99930672, 0.99930672, 0.99930672, 0.99930672, 0.99930672, + 0.99930672, 0.99930672}); + auto expCt = NDArrayFactory::create( + 'c', {batchSize, numUnits}, + {3.99996277, 3.99996277, 3.99996277, 3.99996277, 3.99996277, 3.99996277, + 3.99996277, 3.99996277}); + + sd::ops::lstmCell op; + auto results = op.evaluate({&xt, &ht_1, &ct_1, &Wx, &Wh, &Wc, &Wp, &b}, + {0., 0., 10.5}, {1, 0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto ht = results.at(0); + auto ct = results.at(1); + + ASSERT_TRUE(expHt.isSameShape(ht)); + ASSERT_TRUE(expHt.equalsTo(ht, 1e-4)); + ASSERT_TRUE(expCt.isSameShape(ct)); + ASSERT_TRUE(expCt.equalsTo(ct, 1e-4)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, lstmCell_test9) { - - const int batchSize = 2; - const int inSize = 10; - const int numProj = 4; - const int numUnits = 4; - - auto xt = NDArrayFactory::create('c', {batchSize, inSize}); - auto ht_1 = NDArrayFactory::create('c', {batchSize, numProj}); - auto ct_1 = NDArrayFactory::create('c', {batchSize, numUnits}); - auto Wx = NDArrayFactory::create('c', {inSize, 4*numUnits}); - auto Wh = NDArrayFactory::create('c', {numProj, 4*numUnits}); - auto Wc = NDArrayFactory::create('c', {3*numUnits}); - auto Wp = NDArrayFactory::create('c', {numUnits, numProj}); - auto b = NDArrayFactory::create('c', {4*numUnits}); - - xt.assign(1.); - ht_1.assign(2.); - ct_1.assign(3.); - Wx.assign(0.5); - Wh.assign(0.5); - Wc.assign(0.5); - Wp.assign(0.5); - b.assign(0.7); - - auto expHt = NDArrayFactory::create('c', {batchSize, numProj}, {0.99501777,0.99501777,0.99501777,0.99501777,0.99501777,0.99501777,0.99501777,0.99501777}); - auto expCt = NDArrayFactory::create('c', {batchSize, numUnits},{3.,3.,3.,3.,3.,3.,3.,3.}); - - sd::ops::lstmCell op; - auto results = op.evaluate({&xt, &ht_1, &ct_1, &Wx, &Wh, &Wc, &Wp, &b}, {3., 0., 10.5}, {1, 0}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto *ht = results.at(0); - auto *ct = results.at(1); - - ASSERT_TRUE(expHt.isSameShape(ht)); - ASSERT_TRUE(expHt.equalsTo(ht,1e-4)); - ASSERT_TRUE(expCt.isSameShape(ct)); - ASSERT_TRUE(expCt.equalsTo(ct)); - - + const int batchSize = 2; + const int inSize = 10; + const int numProj = 4; + const int numUnits = 4; + + auto xt = NDArrayFactory::create('c', {batchSize, inSize}); + auto ht_1 = NDArrayFactory::create('c', {batchSize, numProj}); + auto ct_1 = NDArrayFactory::create('c', {batchSize, numUnits}); + auto Wx = NDArrayFactory::create('c', {inSize, 4 * numUnits}); + auto Wh = NDArrayFactory::create('c', {numProj, 4 * numUnits}); + auto Wc = NDArrayFactory::create('c', {3 * numUnits}); + auto Wp = NDArrayFactory::create('c', {numUnits, numProj}); + auto b = NDArrayFactory::create('c', {4 * numUnits}); + + xt.assign(1.); + ht_1.assign(2.); + ct_1.assign(3.); + Wx.assign(0.5); + Wh.assign(0.5); + Wc.assign(0.5); + Wp.assign(0.5); + b.assign(0.7); + + auto expHt = NDArrayFactory::create( + 'c', {batchSize, numProj}, + {0.99501777, 0.99501777, 0.99501777, 0.99501777, 0.99501777, 0.99501777, + 0.99501777, 0.99501777}); + auto expCt = NDArrayFactory::create('c', {batchSize, numUnits}, + {3., 3., 3., 3., 3., 3., 3., 3.}); + + sd::ops::lstmCell op; + auto results = op.evaluate({&xt, &ht_1, &ct_1, &Wx, &Wh, &Wc, &Wp, &b}, + {3., 0., 10.5}, {1, 0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto ht = results.at(0); + auto ct = results.at(1); + + ASSERT_TRUE(expHt.isSameShape(ht)); + ASSERT_TRUE(expHt.equalsTo(ht, 1e-4)); + ASSERT_TRUE(expCt.isSameShape(ct)); + ASSERT_TRUE(expCt.equalsTo(ct)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, lstmCell_test10) { - - const int batchSize = 2; - const int inSize = 10; - const int numProj = 3; - const int numUnits = 4; - - auto xt = NDArrayFactory::create('c', {batchSize, inSize}); - auto ht_1 = NDArrayFactory::create('c', {batchSize, numProj}); - auto ct_1 = NDArrayFactory::create('c', {batchSize, numUnits}); - auto Wx = NDArrayFactory::create('c', {inSize, 4*numUnits}); - auto Wh = NDArrayFactory::create('c', {numProj, 4*numUnits}); - auto Wc = NDArrayFactory::create('c', {3*numUnits}); - auto Wp = NDArrayFactory::create('c', {numUnits, numProj}); - auto b = NDArrayFactory::create('c', {4*numUnits}); - - xt.assign(1.); - ht_1.assign(2.); - ct_1.assign(3.); - Wx.assign(0.5); - Wh.assign(0.5); - Wc.assign(0.5); - Wp.assign(0.5); - b.assign(0.7); - - auto expHt = NDArrayFactory::create('c', {batchSize, numProj}, {1.99861344,1.99861344,1.99861344,1.99861344,1.99861344,1.99861344}); - auto expCt = NDArrayFactory::create('c', {batchSize, numUnits},{3.99996277, 3.99996277, 3.99996277, 3.99996277,3.99996277, 3.99996277, 3.99996277, 3.99996277}); - - sd::ops::lstmCell op; - auto results = op.evaluate({&xt, &ht_1, &ct_1, &Wx, &Wh, &Wc, &Wp, &b}, {0., 0., 10.5}, {1, 1}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto *ht = results.at(0); - auto *ct = results.at(1); - - ASSERT_TRUE(expHt.isSameShape(ht)); - ASSERT_TRUE(expHt.equalsTo(ht)); - ASSERT_TRUE(expCt.isSameShape(ct)); - ASSERT_TRUE(expCt.equalsTo(ct)); - - + const int batchSize = 2; + const int inSize = 10; + const int numProj = 3; + const int numUnits = 4; + + auto xt = NDArrayFactory::create('c', {batchSize, inSize}); + auto ht_1 = NDArrayFactory::create('c', {batchSize, numProj}); + auto ct_1 = NDArrayFactory::create('c', {batchSize, numUnits}); + auto Wx = NDArrayFactory::create('c', {inSize, 4 * numUnits}); + auto Wh = NDArrayFactory::create('c', {numProj, 4 * numUnits}); + auto Wc = NDArrayFactory::create('c', {3 * numUnits}); + auto Wp = NDArrayFactory::create('c', {numUnits, numProj}); + auto b = NDArrayFactory::create('c', {4 * numUnits}); + + xt.assign(1.); + ht_1.assign(2.); + ct_1.assign(3.); + Wx.assign(0.5); + Wh.assign(0.5); + Wc.assign(0.5); + Wp.assign(0.5); + b.assign(0.7); + + auto expHt = NDArrayFactory::create( + 'c', {batchSize, numProj}, + {1.99861344, 1.99861344, 1.99861344, 1.99861344, 1.99861344, 1.99861344}); + auto expCt = NDArrayFactory::create( + 'c', {batchSize, numUnits}, + {3.99996277, 3.99996277, 3.99996277, 3.99996277, 3.99996277, 3.99996277, + 3.99996277, 3.99996277}); + + sd::ops::lstmCell op; + auto results = op.evaluate({&xt, &ht_1, &ct_1, &Wx, &Wh, &Wc, &Wp, &b}, + {0., 0., 10.5}, {1, 1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto ht = results.at(0); + auto ct = results.at(1); + + ASSERT_TRUE(expHt.isSameShape(ht)); + ASSERT_TRUE(expHt.equalsTo(ht)); + ASSERT_TRUE(expCt.isSameShape(ct)); + ASSERT_TRUE(expCt.equalsTo(ct)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, lstmCell_test11) { - - const int batchSize = 2; - const int inSize = 10; - const int numProj = 3; - const int numUnits = 4; - - auto xt = NDArrayFactory::create('c', {batchSize, inSize}); - auto ht_1 = NDArrayFactory::create('c', {batchSize, numProj}); - auto ct_1 = NDArrayFactory::create('c', {batchSize, numUnits}); - auto Wx = NDArrayFactory::create('c', {inSize, 4*numUnits}); - auto Wh = NDArrayFactory::create('c', {numProj, 4*numUnits}); - auto Wc = NDArrayFactory::create('c', {3*numUnits}); - auto Wp = NDArrayFactory::create('c', {numUnits, numProj}); - auto b = NDArrayFactory::create('c', {4*numUnits}); - - xt.assign(1.); - ht_1.assign(2.); - ct_1.assign(3.); - Wx.assign(0.5); - Wh.assign(0.5); - Wc.assign(0.5); - Wp.assign(0.5); - b.assign(0.7); - - auto expHt = NDArrayFactory::create('c', {batchSize, numProj}, {1.99003554,1.99003554,1.99003554,1.99003554,1.99003554,1.99003554}); - auto expCt = NDArrayFactory::create('c', {batchSize, numUnits},{3.,3.,3.,3.,3.,3.,3.,3.}); - - sd::ops::lstmCell op; - auto results = op.evaluate({&xt, &ht_1, &ct_1, &Wx, &Wh, &Wc, &Wp, &b}, {3., 0., 10.5}, {1, 1}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto *ht = results.at(0); - auto *ct = results.at(1); - - ASSERT_TRUE(expHt.isSameShape(ht)); - ASSERT_TRUE(expHt.equalsTo(ht)); - ASSERT_TRUE(expCt.isSameShape(ct)); - ASSERT_TRUE(expCt.equalsTo(ct)); - - + const int batchSize = 2; + const int inSize = 10; + const int numProj = 3; + const int numUnits = 4; + + auto xt = NDArrayFactory::create('c', {batchSize, inSize}); + auto ht_1 = NDArrayFactory::create('c', {batchSize, numProj}); + auto ct_1 = NDArrayFactory::create('c', {batchSize, numUnits}); + auto Wx = NDArrayFactory::create('c', {inSize, 4 * numUnits}); + auto Wh = NDArrayFactory::create('c', {numProj, 4 * numUnits}); + auto Wc = NDArrayFactory::create('c', {3 * numUnits}); + auto Wp = NDArrayFactory::create('c', {numUnits, numProj}); + auto b = NDArrayFactory::create('c', {4 * numUnits}); + + xt.assign(1.); + ht_1.assign(2.); + ct_1.assign(3.); + Wx.assign(0.5); + Wh.assign(0.5); + Wc.assign(0.5); + Wp.assign(0.5); + b.assign(0.7); + + auto expHt = NDArrayFactory::create( + 'c', {batchSize, numProj}, + {1.99003554, 1.99003554, 1.99003554, 1.99003554, 1.99003554, 1.99003554}); + auto expCt = NDArrayFactory::create('c', {batchSize, numUnits}, + {3., 3., 3., 3., 3., 3., 3., 3.}); + + sd::ops::lstmCell op; + auto results = op.evaluate({&xt, &ht_1, &ct_1, &Wx, &Wh, &Wc, &Wp, &b}, + {3., 0., 10.5}, {1, 1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto ht = results.at(0); + auto ct = results.at(1); + + ASSERT_TRUE(expHt.isSameShape(ht)); + ASSERT_TRUE(expHt.equalsTo(ht)); + ASSERT_TRUE(expCt.isSameShape(ct)); + ASSERT_TRUE(expCt.equalsTo(ct)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, lstmCell_test12) { - - const int batchSize = 2; - const int inSize = 10; - const int numProj = 3; - const int numUnits = 4; - - auto xt = NDArrayFactory::create('c', {batchSize, inSize}); - auto ht_1 = NDArrayFactory::create('c', {batchSize, numProj}); - auto ct_1 = NDArrayFactory::create('c', {batchSize, numUnits}); - auto Wx = NDArrayFactory::create('c', {inSize, 4*numUnits}); - auto Wh = NDArrayFactory::create('c', {numProj, 4*numUnits}); - auto Wc = NDArrayFactory::create('c', {3*numUnits}); - auto Wp = NDArrayFactory::create('c', {numUnits, numProj}); - auto b = NDArrayFactory::create('c', {4*numUnits}); - - xt.assign(1.); - ht_1.assign(2.); - ct_1.assign(3.); - Wx.assign(0.5); - Wh.assign(0.5); - Wc.assign(0.5); - Wp.assign(0.5); - b.assign(0.7); - - auto expHt = NDArrayFactory::create('c', {batchSize, numProj}, {1.,1.,1.,1.,1.,1.}); - auto expCt = NDArrayFactory::create('c', {batchSize, numUnits},{3.,3.,3.,3.,3.,3.,3.,3.}); - - sd::ops::lstmCell op; - auto results = op.evaluate({&xt, &ht_1, &ct_1, &Wx, &Wh, &Wc, &Wp, &b}, {3., 1.,-5.}, {1, 1}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto *ht = results.at(0); - auto *ct = results.at(1); - - ASSERT_TRUE(expHt.isSameShape(ht)); - ASSERT_TRUE(expHt.equalsTo(ht)); - ASSERT_TRUE(expCt.isSameShape(ct)); - ASSERT_TRUE(expCt.equalsTo(ct)); - - + const int batchSize = 2; + const int inSize = 10; + const int numProj = 3; + const int numUnits = 4; + + auto xt = NDArrayFactory::create('c', {batchSize, inSize}); + auto ht_1 = NDArrayFactory::create('c', {batchSize, numProj}); + auto ct_1 = NDArrayFactory::create('c', {batchSize, numUnits}); + auto Wx = NDArrayFactory::create('c', {inSize, 4 * numUnits}); + auto Wh = NDArrayFactory::create('c', {numProj, 4 * numUnits}); + auto Wc = NDArrayFactory::create('c', {3 * numUnits}); + auto Wp = NDArrayFactory::create('c', {numUnits, numProj}); + auto b = NDArrayFactory::create('c', {4 * numUnits}); + + xt.assign(1.); + ht_1.assign(2.); + ct_1.assign(3.); + Wx.assign(0.5); + Wh.assign(0.5); + Wc.assign(0.5); + Wp.assign(0.5); + b.assign(0.7); + + auto expHt = NDArrayFactory::create('c', {batchSize, numProj}, + {1., 1., 1., 1., 1., 1.}); + auto expCt = NDArrayFactory::create('c', {batchSize, numUnits}, + {3., 3., 3., 3., 3., 3., 3., 3.}); + + sd::ops::lstmCell op; + auto results = op.evaluate({&xt, &ht_1, &ct_1, &Wx, &Wh, &Wc, &Wp, &b}, + {3., 1., -5.}, {1, 1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto ht = results.at(0); + auto ct = results.at(1); + + ASSERT_TRUE(expHt.isSameShape(ht)); + ASSERT_TRUE(expHt.equalsTo(ht)); + ASSERT_TRUE(expCt.isSameShape(ct)); + ASSERT_TRUE(expCt.equalsTo(ct)); } diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests3.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests3.cpp index 2a099230efdb..b80d66bd0c87 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests3.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests3.cpp @@ -14,2749 +14,3239 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -#include "testlayers.h" -#include -#include #include #include #include #include +#include +#include +#include "testlayers.h" using namespace sd; using namespace sd::graph; class DeclarableOpsTests3 : public testing::Test { -public: - - DeclarableOpsTests3() { -// - } + public: + DeclarableOpsTests3() { + // + } }; - TEST_F(DeclarableOpsTests3, Test_Tile_1) { - auto x= NDArrayFactory::create('c', {2, 3}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); - auto rep_vector= NDArrayFactory::create('c', {1, 2}, {2, 2}); - std::vector reps({2, 2}); + auto x = NDArrayFactory::create('c', {2, 3}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); + auto rep_vector = NDArrayFactory::create('c', {1, 2}, {2, 2}); + std::vector reps({2, 2}); - auto exp = x.tile(reps); - - sd::ops::tile op; - auto result = op.evaluate({&x, &rep_vector}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto exp = x.tile(reps); + sd::ops::tile op; + auto result = op.evaluate({&x, &rep_vector}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(DeclarableOpsTests3, Test_Tile_2) { - auto x= NDArrayFactory::create('c', {2, 3}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); - std::vector reps({2, 2}); - - auto exp = x.tile(reps); + auto x = NDArrayFactory::create('c', {2, 3}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); + std::vector reps({2, 2}); - sd::ops::tile op; - auto result = op.evaluate({&x}, {}, {2, 2}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto exp = x.tile(reps); + sd::ops::tile op; + auto result = op.evaluate({&x}, {}, {2, 2}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests3, Test_Permute_1) { - auto x= NDArrayFactory::create('c', {2, 3, 4}); - auto permute= NDArrayFactory::create('c', {1, 3}, {0, 2, 1}); - auto exp= NDArrayFactory::create('c', {2, 4, 3}); + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto permute = NDArrayFactory::create('c', {1, 3}, {0, 2, 1}); + auto exp = NDArrayFactory::create('c', {2, 4, 3}); - sd::ops::permute op; - auto result = op.evaluate({&x, &permute}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::permute op; + auto result = op.evaluate({&x, &permute}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.isSameShape(z)); } TEST_F(DeclarableOpsTests3, Test_Permute_2) { - auto x= NDArrayFactory::create('c', {2, 3, 4}); - auto exp= NDArrayFactory::create('c', {4, 3, 2}); - - sd::ops::permute op; - auto result = op.evaluate({&x}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {4, 3, 2}); - auto z = result.at(0); + sd::ops::permute op; + auto result = op.evaluate({&x}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(z)); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); } - TEST_F(DeclarableOpsTests3, Test_Unique_1) { - auto x= NDArrayFactory::create('c', {1, 5}, {1.f, 2.f, 1.f, 2.f, 3.f}); - auto expV= NDArrayFactory::create('c', {3}, {1.f, 2.f, 3.f}); - auto expI= NDArrayFactory::create('c', {5}, {0, 1, 0, 1, 2}); -// auto expI= NDArrayFactory::create('c', {3}, {0, 1, 4}); - - sd::ops::unique op; - auto result = op.evaluate({&x}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(2, result.size()); - - auto v = result.at(0); - auto i = result.at(1); - // v->printIndexedBuffer("Values"); - // i->printIndexedBuffer("Indices"); - // i->printShapeInfo("Indices shape"); - ASSERT_TRUE(expV.isSameShape(v)); - ASSERT_TRUE(expV.equalsTo(v)); + auto x = + NDArrayFactory::create('c', {1, 5}, {1.f, 2.f, 1.f, 2.f, 3.f}); + auto expV = NDArrayFactory::create('c', {3}, {1.f, 2.f, 3.f}); + auto expI = NDArrayFactory::create('c', {5}, {0, 1, 0, 1, 2}); + // auto expI= NDArrayFactory::create('c', {3}, {0, 1, 4}); - ASSERT_TRUE(expI.isSameShape(i)); - ASSERT_TRUE(expI.equalsTo(i)); + sd::ops::unique op; + auto result = op.evaluate({&x}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(2, result.size()); + auto v = result.at(0); + auto i = result.at(1); + // v->printIndexedBuffer("Values"); + // i->printIndexedBuffer("Indices"); + // i->printShapeInfo("Indices shape"); + ASSERT_TRUE(expV.isSameShape(v)); + ASSERT_TRUE(expV.equalsTo(v)); + + ASSERT_TRUE(expI.isSameShape(i)); + ASSERT_TRUE(expI.equalsTo(i)); } TEST_F(DeclarableOpsTests3, Test_Unique_2) { - auto x= NDArrayFactory::create('c', {1, 5}, {1.f, 2.f, 1.f, 2.f, 3.f}); - auto expV= NDArrayFactory::create('c', {3}, {1.f, 2.f, 3.f}); - auto expI= NDArrayFactory::create('c', {5}, {0, 1, 0, 1, 2}); - auto expC= NDArrayFactory::create('c', {3}, {2, 2, 1}); + auto x = + NDArrayFactory::create('c', {1, 5}, {1.f, 2.f, 1.f, 2.f, 3.f}); + auto expV = NDArrayFactory::create('c', {3}, {1.f, 2.f, 3.f}); + auto expI = NDArrayFactory::create('c', {5}, {0, 1, 0, 1, 2}); + auto expC = NDArrayFactory::create('c', {3}, {2, 2, 1}); - sd::ops::unique_with_counts op; - auto result = op.evaluate({&x}, {}, {}); + sd::ops::unique_with_counts op; + auto result = op.evaluate({&x}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(3, result.size()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(3, result.size()); - auto v = result.at(0); - auto i = result.at(1); - auto c = result.at(2); + auto v = result.at(0); + auto i = result.at(1); + auto c = result.at(2); - // v->printShapeInfo(); - // v->printIndexedBuffer("Values"); - // i->printShapeInfo(); - // i->printIndexedBuffer("Indices"); - // c->printShapeInfo(); - // c->printIndexedBuffer("Counts"); + // v->printShapeInfo(); + // v->printIndexedBuffer("Values"); + // i->printShapeInfo(); + // i->printIndexedBuffer("Indices"); + // c->printShapeInfo(); + // c->printIndexedBuffer("Counts"); - ASSERT_TRUE(expV.isSameShape(v)); - ASSERT_TRUE(expV.equalsTo(v)); + ASSERT_TRUE(expV.isSameShape(v)); + ASSERT_TRUE(expV.equalsTo(v)); - ASSERT_TRUE(expI.isSameShape(i)); - ASSERT_TRUE(expI.equalsTo(i)); + ASSERT_TRUE(expI.isSameShape(i)); + ASSERT_TRUE(expI.equalsTo(i)); - ASSERT_TRUE(expC.isSameShape(c)); - ASSERT_TRUE(expC.equalsTo(c)); + ASSERT_TRUE(expC.isSameShape(c)); + ASSERT_TRUE(expC.equalsTo(c)); } TEST_F(DeclarableOpsTests3, Test_Rint_1) { - auto x= NDArrayFactory::create('c', {1, 7}, {-1.7f, -1.5f, -0.2f, 0.2f, 1.5f, 1.7f, 2.0f}); - auto exp= NDArrayFactory::create('c', {1, 7}, {-2.f, -2.f, -0.f, 0.f, 2.f, 2.f, 2.f}); - - sd::ops::rint op; - auto result = op.evaluate({&x}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto x = NDArrayFactory::create( + 'c', {1, 7}, {-1.7f, -1.5f, -0.2f, 0.2f, 1.5f, 1.7f, 2.0f}); + auto exp = NDArrayFactory::create( + 'c', {1, 7}, {-2.f, -2.f, -0.f, 0.f, 2.f, 2.f, 2.f}); - auto z = result.at(0); - - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::rint op; + auto result = op.evaluate({&x}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(DeclarableOpsTests3, Test_Norm_1) { - auto x = NDArrayFactory::create('c', {100, 100}); - x.linspace(1); + auto x = NDArrayFactory::create('c', {100, 100}); + x.linspace(1); - std::vector empty; - std::vector dims({1}); - sd::ops::norm op; + std::vector empty; + std::vector dims({1}); + sd::ops::norm op; - auto result0 = op.evaluate({&x}, {0.}, {}); + auto result0 = op.evaluate({&x}, {0.}, {}); - auto z0 = result0.at(0); - auto exp0 = x.reduceAlongDimension(reduce::NormFrobenius, empty, false, false); - ASSERT_TRUE(exp0.isSameShape(z0)); - ASSERT_TRUE(exp0.equalsTo(z0)); + auto z0 = result0.at(0); + auto exp0 = + x.reduceAlongDimension(reduce::NormFrobenius, empty, false, false); + ASSERT_TRUE(exp0.isSameShape(z0)); + ASSERT_TRUE(exp0.equalsTo(z0)); - auto result1 = op.evaluate({&x}, {1.}, {1}); - ASSERT_EQ(result1.status(), ND4J_STATUS_OK); - auto z1 = result1.at(0); - // z1->printIndexedBuffer("Z1"); - auto exp1 = x.reduceAlongDimension(reduce::Norm2, dims, false, false); - // exp1.printIndexedBuffer("EXP1"); - // z1->printShapeInfo("Z1 shape"); - // exp1.printShapeInfo("EXP1 shape"); - ASSERT_TRUE(exp1.isSameShape(z1)); - ASSERT_TRUE(exp1.equalsTo(z1)); + auto result1 = op.evaluate({&x}, {1.}, {1}); + ASSERT_EQ(result1.status(), ND4J_STATUS_OK); + auto z1 = result1.at(0); + // z1->printIndexedBuffer("Z1"); + auto exp1 = x.reduceAlongDimension(reduce::Norm2, dims, false, false); + // exp1.printIndexedBuffer("EXP1"); + // z1->printShapeInfo("Z1 shape"); + // exp1.printShapeInfo("EXP1 shape"); + ASSERT_TRUE(exp1.isSameShape(z1)); + ASSERT_TRUE(exp1.equalsTo(z1)); - auto result4 = op.evaluate({&x}, {4.}, {1}); + auto result4 = op.evaluate({&x}, {4.}, {1}); - auto z4 = result4.at(0); - auto exp4= x.reduceAlongDimension(reduce::NormMax, dims, false, false); - ASSERT_TRUE(exp4.isSameShape(z4)); - ASSERT_TRUE(exp4.equalsTo(z4)); + auto z4 = result4.at(0); + auto exp4 = x.reduceAlongDimension(reduce::NormMax, dims, false, false); + ASSERT_TRUE(exp4.isSameShape(z4)); + ASSERT_TRUE(exp4.equalsTo(z4)); } - TEST_F(DeclarableOpsTests3, Test_Norm_2) { - auto x = NDArrayFactory::create('c', {100, 100}); - x.linspace(1); - auto axis= NDArrayFactory::create('c', {1, 1}, {1}); - - std::vector empty; - std::vector dims({1}); - sd::ops::norm op; - - auto result0 = op.evaluate({&x}, {0}, {}); - - auto z0 = result0.at(0); - auto exp0 = x.reduceAlongDimension(reduce::NormFrobenius, empty, false, false); - ASSERT_TRUE(exp0.isSameShape(z0)); - ASSERT_TRUE(exp0.equalsTo(z0)); + auto x = NDArrayFactory::create('c', {100, 100}); + x.linspace(1); + auto axis = NDArrayFactory::create('c', {1, 1}, {1}); + std::vector empty; + std::vector dims({1}); + sd::ops::norm op; + auto result0 = op.evaluate({&x}, {0}, {}); - auto result1 = op.evaluate({&x, &axis}, {1}, {}); + auto z0 = result0.at(0); + auto exp0 = + x.reduceAlongDimension(reduce::NormFrobenius, empty, false, false); + ASSERT_TRUE(exp0.isSameShape(z0)); + ASSERT_TRUE(exp0.equalsTo(z0)); - auto z1 = result1.at(0); - auto exp1 = x.reduceAlongDimension(reduce::Norm2, dims, false, false); - ASSERT_TRUE(exp1.isSameShape(z1)); - ASSERT_TRUE(exp1.equalsTo(z1)); + auto result1 = op.evaluate({&x, &axis}, {1}, {}); - auto result4 = op.evaluate({&x, &axis}, {4}, {}); + auto z1 = result1.at(0); + auto exp1 = x.reduceAlongDimension(reduce::Norm2, dims, false, false); + ASSERT_TRUE(exp1.isSameShape(z1)); + ASSERT_TRUE(exp1.equalsTo(z1)); - auto z4 = result4.at(0); - auto exp4= x.reduceAlongDimension(reduce::NormMax, dims, false, false); - ASSERT_TRUE(exp4.isSameShape(z4)); - ASSERT_TRUE(exp4.equalsTo(z4)); + auto result4 = op.evaluate({&x, &axis}, {4}, {}); + auto z4 = result4.at(0); + auto exp4 = x.reduceAlongDimension(reduce::NormMax, dims, false, false); + ASSERT_TRUE(exp4.isSameShape(z4)); + ASSERT_TRUE(exp4.equalsTo(z4)); } TEST_F(DeclarableOpsTests3, Test_ListDiff_1) { - auto x= NDArrayFactory::create('c', {6}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); - auto y= NDArrayFactory::create('c', {3}, {1.f, 3.f, 5.f}); + auto x = + NDArrayFactory::create('c', {6}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); + auto y = NDArrayFactory::create('c', {3}, {1.f, 3.f, 5.f}); - auto exp0= NDArrayFactory::create('c', {3}, {2.f, 4.f, 6.f}); - auto exp1= NDArrayFactory::create('c', {3}, {1, 3, 5}); + auto exp0 = NDArrayFactory::create('c', {3}, {2.f, 4.f, 6.f}); + auto exp1 = NDArrayFactory::create('c', {3}, {1, 3, 5}); - sd::ops::listdiff op; - auto result = op.evaluate({&x, &y}); + sd::ops::listdiff op; + auto result = op.evaluate({&x, &y}); - ASSERT_EQ(Status::OK(), result.status()); + ASSERT_EQ(Status::OK(), result.status()); - auto z0 = result.at(0); - auto z1 = result.at(1); + auto z0 = result.at(0); + auto z1 = result.at(1); - z0->getDataBuffer()->syncToSpecial(true); // force sync - z1->getDataBuffer()->syncToSpecial(true); // force sync + z0.getDataBuffer()->syncToSpecial(true); // force sync + z1.getDataBuffer()->syncToSpecial(true); // force sync - ASSERT_TRUE(exp0.isSameShape(z0)); - ASSERT_TRUE(exp0.equalsTo(z0)); - - ASSERT_TRUE(exp1.isSameShape(z1)); - ASSERT_TRUE(exp1.equalsTo(z1)); + ASSERT_TRUE(exp0.isSameShape(z0)); + ASSERT_TRUE(exp0.equalsTo(z0)); + ASSERT_TRUE(exp1.isSameShape(z1)); + ASSERT_TRUE(exp1.equalsTo(z1)); } TEST_F(DeclarableOpsTests3, Test_Range_1) { - auto start = NDArrayFactory::create(0.3f); - auto stop = NDArrayFactory::create(-5.f); - auto step = NDArrayFactory::create(-0.33f); - auto exp= NDArrayFactory::create('c', {17}, { 0.3f, -0.03f, -0.36f, -0.69f, -1.02f, -1.35f, -1.68f, -2.01f, -2.34f, -2.67f, -3.f, -3.33f, -3.66f, -3.99f, -4.32f, -4.65f, -4.98f}); - - sd::ops::range op; - auto result = op.evaluate({&start, &stop, &step}); + auto start = NDArrayFactory::create(0.3f); + auto stop = NDArrayFactory::create(-5.f); + auto step = NDArrayFactory::create(-0.33f); + auto exp = NDArrayFactory::create( + 'c', {17}, + {0.3f, -0.03f, -0.36f, -0.69f, -1.02f, -1.35f, -1.68f, -2.01f, -2.34f, + -2.67f, -3.f, -3.33f, -3.66f, -3.99f, -4.32f, -4.65f, -4.98f}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::range op; + auto result = op.evaluate({&start, &stop, &step}); - auto z = result.at(0); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(DeclarableOpsTests3, Test_Range_2) { - auto start= NDArrayFactory::create('c', {1, 1}, {2.f}); - auto stop= NDArrayFactory::create('c', {1, 1}, {0.f}); - auto step= NDArrayFactory::create('c', {1, 1}, {-1.f}); - auto exp= NDArrayFactory::create('c', {2}, {2.f, 1.f}); - - sd::ops::range op; - auto result = op.evaluate({&start, &stop, &step}); + auto start = NDArrayFactory::create('c', {1, 1}, {2.f}); + auto stop = NDArrayFactory::create('c', {1, 1}, {0.f}); + auto step = NDArrayFactory::create('c', {1, 1}, {-1.f}); + auto exp = NDArrayFactory::create('c', {2}, {2.f, 1.f}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::range op; + auto result = op.evaluate({&start, &stop, &step}); - auto z = result.at(0); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests3, Test_Range_3) { - auto start= NDArrayFactory::create('c', {1, 1}, {0.f}); - auto stop= NDArrayFactory::create('c', {1, 1}, {2.f}); - auto step= NDArrayFactory::create('c', {1, 1}, {1.f}); - auto exp= NDArrayFactory::create('c', {2}, {0.f, 1.f}); + auto start = NDArrayFactory::create('c', {1, 1}, {0.f}); + auto stop = NDArrayFactory::create('c', {1, 1}, {2.f}); + auto step = NDArrayFactory::create('c', {1, 1}, {1.f}); + auto exp = NDArrayFactory::create('c', {2}, {0.f, 1.f}); - sd::ops::range op; - auto result = op.evaluate({&start, &stop, &step}); + sd::ops::range op; + auto result = op.evaluate({&start, &stop, &step}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(DeclarableOpsTests3, Test_Range_10) { - auto start= NDArrayFactory::create('c', {1, 1}, {0.f}); - auto stop= NDArrayFactory::create('c', {1, 1}, {2.f}); - auto step= NDArrayFactory::create('c', {1, 1}, {1.f}); - auto exp= NDArrayFactory::create('c', {2}, {0.f, 1.f}); + auto start = NDArrayFactory::create('c', {1, 1}, {0.f}); + auto stop = NDArrayFactory::create('c', {1, 1}, {2.f}); + auto step = NDArrayFactory::create('c', {1, 1}, {1.f}); + auto exp = NDArrayFactory::create('c', {2}, {0.f, 1.f}); - sd::ops::range op; - auto result = op.evaluate({&start, &stop, &step}, {sd::DataType::DOUBLE}); + sd::ops::range op; + auto result = op.evaluate({&start, &stop, &step}, {sd::DataType::DOUBLE}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(DeclarableOpsTests3, Test_Range_4) { - auto exp= NDArrayFactory::create('c', {13}, {-10.f, -8.334f, -6.668f, -5.002f, -3.336f, -1.67f, -0.004f, 1.662f, 3.328f, 4.994f, 6.66f, 8.326f, 9.992f}); + auto exp = NDArrayFactory::create( + 'c', {13}, + {-10.f, -8.334f, -6.668f, -5.002f, -3.336f, -1.67f, -0.004f, 1.662f, + 3.328f, 4.994f, 6.66f, 8.326f, 9.992f}); - sd::ops::range op; - auto result = op.evaluate({}, {-10., 10., 1.666}, {}); + sd::ops::range op; + auto result = op.evaluate({}, {-10., 10., 1.666}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(DeclarableOpsTests3, Test_Range_5) { - auto exp= NDArrayFactory::create('c', {2}, {2.f, 1.f}); - - sd::ops::range op; - auto result = op.evaluate({}, {2, 0, -1}, {}); + auto exp = NDArrayFactory::create('c', {2}, {2.f, 1.f}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::range op; + auto result = op.evaluate({}, {2, 0, -1}, {}); - auto z = result.at(0); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests3, Test_Range_6) { - auto exp= NDArrayFactory::create('c', {2}, {0.f, 1.f}); + auto exp = NDArrayFactory::create('c', {2}, {0.f, 1.f}); - sd::ops::range op; - auto result = op.evaluate({}, {0, 2, 1}, {}); + sd::ops::range op; + auto result = op.evaluate({}, {0, 2, 1}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests3, Test_Range_7) { - auto exp= NDArrayFactory::create('c', {10}, {10.f, 8.334f, 6.668f, 5.002f, 3.336f, 1.67f, 0.004f, -1.662f, -3.328f, -4.994f}); - - sd::ops::range op; - auto result = op.evaluate({}, {10,-5,-1.666}, {}); + auto exp = + NDArrayFactory::create('c', {10}, + {10.f, 8.334f, 6.668f, 5.002f, 3.336f, + 1.67f, 0.004f, -1.662f, -3.328f, -4.994f}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::range op; + auto result = op.evaluate({}, {10, -5, -1.666}, {}); - auto z = result.at(0); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - - TEST_F(DeclarableOpsTests3, Test_Range_8) { - auto exp= NDArrayFactory::create('c', {2}, {2, 1}); + auto exp = NDArrayFactory::create('c', {2}, {2, 1}); - sd::ops::range op; - auto result = op.evaluate({}, {}, {2, 0, -1}); + sd::ops::range op; + auto result = op.evaluate({}, {}, {2, 0, -1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests3, Test_Range_9) { - auto exp= NDArrayFactory::create('c', {2}, {0, 1}); - - sd::ops::range op; - auto result = op.evaluate({}, {}, {0, 2, 1}); + auto exp = NDArrayFactory::create('c', {2}, {0, 1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::range op; + auto result = op.evaluate({}, {}, {0, 2, 1}); - auto z = result.at(0); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_1) { - auto a= NDArrayFactory::create('c', {1, 3}, {1, 1, 1}); - auto b= NDArrayFactory::create('c', {1, 3}, {0, 0, 0}); - auto x= NDArrayFactory::create('f', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); - auto y= NDArrayFactory::create('f', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto a = NDArrayFactory::create('c', {1, 3}, {1, 1, 1}); + auto b = NDArrayFactory::create('c', {1, 3}, {0, 0, 0}); + auto x = + NDArrayFactory::create('f', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto y = + NDArrayFactory::create('f', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); - auto exp = MmulHelper::mmul(&x, &y); + auto exp = MmulHelper::mmul(&x, &y); - sd::ops::batched_gemm op; - auto result = op.evaluate({&a, &b, &x, &x, &x, &y, &y, &y}, {}, {111, 111, 3, 3, 3, 3, 3, 3, 3}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::batched_gemm op; + auto result = op.evaluate({&a, &b, &x, &x, &x, &y, &y, &y}, {}, + {111, 111, 3, 3, 3, 3, 3, 3, 3}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(3, result.size()); + ASSERT_EQ(3, result.size()); - for (int e = 0; e < 3; e++) { - auto z = result.at(e); + for (int e = 0; e < 3; e++) { + auto z = result.at(e); -// exp->printIndexedBuffer("e"); -// z->printIndexedBuffer("z"); + // exp->printIndexedBuffer("e"); + // z->printIndexedBuffer("z"); - ASSERT_TRUE(exp->isSameShape(z)); - ASSERT_TRUE(exp->equalsTo(z)); - } - - delete exp; + ASSERT_TRUE(exp->isSameShape(z)); + ASSERT_TRUE(exp->equalsTo(z)); + } + delete exp; } TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_2) { - auto a= NDArrayFactory::create('c', {1, 3}, {1, 1, 1}); - auto b= NDArrayFactory::create('c', {1, 3}, {0, 0, 0}); - auto x= NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); - auto y= NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); - - auto exp = MmulHelper::mmul(&x, &y); + auto a = NDArrayFactory::create('c', {1, 3}, {1, 1, 1}); + auto b = NDArrayFactory::create('c', {1, 3}, {0, 0, 0}); + auto x = + NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto y = + NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); - sd::ops::batched_gemm op; - auto result = op.evaluate({&a, &b, &x, &x, &x, &y, &y, &y}, {}, {112, 112, 3, 3, 3, 3, 3, 3, 3}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto exp = MmulHelper::mmul(&x, &y); - ASSERT_EQ(3, result.size()); + sd::ops::batched_gemm op; + auto result = op.evaluate({&a, &b, &x, &x, &x, &y, &y, &y}, {}, + {112, 112, 3, 3, 3, 3, 3, 3, 3}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - for (int e = 0; e < 3; e++) { - auto z = result.at(e); + ASSERT_EQ(3, result.size()); - //exp->printIndexedBuffer("e"); - //z->printIndexedBuffer("z"); + for (int e = 0; e < 3; e++) { + auto z = result.at(e); - ASSERT_TRUE(exp->isSameShape(z)); - ASSERT_TRUE(exp->equalsTo(z)); - } + // exp->printIndexedBuffer("e"); + // z->printIndexedBuffer("z"); - delete exp; + ASSERT_TRUE(exp->isSameShape(z)); + ASSERT_TRUE(exp->equalsTo(z)); + } + delete exp; } TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_3) { - auto a= NDArrayFactory::create('c', {1, 3}, {1, 1, 1}); - auto b= NDArrayFactory::create('c', {1, 3}, {0, 0, 0}); - auto x= NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); - auto y= NDArrayFactory::create('f', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto a = NDArrayFactory::create('c', {1, 3}, {1, 1, 1}); + auto b = NDArrayFactory::create('c', {1, 3}, {0, 0, 0}); + auto x = + NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto y = + NDArrayFactory::create('f', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); - auto exp = MmulHelper::mmul(&x, &y); + auto exp = MmulHelper::mmul(&x, &y); - sd::ops::batched_gemm op; - auto result = op.evaluate({&a, &b, &x, &x, &x, &y, &y, &y}, {}, {112, 111, 3, 3, 3, 3, 3, 3, 3}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::batched_gemm op; + auto result = op.evaluate({&a, &b, &x, &x, &x, &y, &y, &y}, {}, + {112, 111, 3, 3, 3, 3, 3, 3, 3}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(3, result.size()); + ASSERT_EQ(3, result.size()); - for (int e = 0; e < 3; e++) { - auto z = result.at(e); + for (int e = 0; e < 3; e++) { + auto z = result.at(e); -// exp->printIndexedBuffer("e"); -// z->printIndexedBuffer("z"); + // exp->printIndexedBuffer("e"); + // z->printIndexedBuffer("z"); - ASSERT_TRUE(exp->isSameShape(z)); - ASSERT_TRUE(exp->equalsTo(z)); - } - - delete exp; + ASSERT_TRUE(exp->isSameShape(z)); + ASSERT_TRUE(exp->equalsTo(z)); + } + delete exp; } TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_4) { - auto a= NDArrayFactory::create('c', {1, 3}, {1, 1, 1}); - auto b= NDArrayFactory::create('c', {1, 3}, {0, 0, 0}); - auto x= NDArrayFactory::create('f', {5, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}); - auto y= NDArrayFactory::create('f', {3, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); - - auto exp = MmulHelper::mmul(&x, &y); + auto a = NDArrayFactory::create('c', {1, 3}, {1, 1, 1}); + auto b = NDArrayFactory::create('c', {1, 3}, {0, 0, 0}); + auto x = NDArrayFactory::create( + 'f', {5, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}); + auto y = NDArrayFactory::create( + 'f', {3, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); - sd::ops::batched_gemm op; - auto result = op.evaluate({&a, &b, &x, &x, &x, &y, &y, &y}, {}, {111, 111, 5, 4, 3, 5, 3, 5, 3}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto exp = MmulHelper::mmul(&x, &y); - ASSERT_EQ(3, result.size()); + sd::ops::batched_gemm op; + auto result = op.evaluate({&a, &b, &x, &x, &x, &y, &y, &y}, {}, + {111, 111, 5, 4, 3, 5, 3, 5, 3}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - for (int e = 0; e < 3; e++) { - auto z = result.at(e); + ASSERT_EQ(3, result.size()); - //exp->printIndexedBuffer("e"); - //z->printIndexedBuffer("z"); + for (int e = 0; e < 3; e++) { + auto z = result.at(e); - ASSERT_TRUE(exp->isSameShape(z)); - ASSERT_TRUE(exp->equalsTo(z)); - } + // exp->printIndexedBuffer("e"); + // z->printIndexedBuffer("z"); - delete exp; + ASSERT_TRUE(exp->isSameShape(z)); + ASSERT_TRUE(exp->equalsTo(z)); + } + delete exp; } TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_5) { - auto a= NDArrayFactory::create('c', {1, 3}, {1, 1, 1}); - auto b= NDArrayFactory::create('c', {1, 3}, {0, 0, 0}); - auto x= NDArrayFactory::create('c', {5, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}); - auto y= NDArrayFactory::create('c', {3, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + auto a = NDArrayFactory::create('c', {1, 3}, {1, 1, 1}); + auto b = NDArrayFactory::create('c', {1, 3}, {0, 0, 0}); + auto x = NDArrayFactory::create( + 'c', {5, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}); + auto y = NDArrayFactory::create( + 'c', {3, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); - auto exp = MmulHelper::mmul(&x, &y); + auto exp = MmulHelper::mmul(&x, &y); - sd::ops::batched_gemm op; - auto result = op.evaluate({&a, &b, &x, &x, &x, &y, &y, &y}, {}, {112, 112, 5, 4, 3, 3, 4, 5, 3}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::batched_gemm op; + auto result = op.evaluate({&a, &b, &x, &x, &x, &y, &y, &y}, {}, + {112, 112, 5, 4, 3, 3, 4, 5, 3}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(3, result.size()); + ASSERT_EQ(3, result.size()); - for (int e = 0; e < 3; e++) { - auto z = result.at(e); + for (int e = 0; e < 3; e++) { + auto z = result.at(e); - //exp->printIndexedBuffer("e"); - //z->printIndexedBuffer("z"); + // exp->printIndexedBuffer("e"); + // z->printIndexedBuffer("z"); - ASSERT_TRUE(exp->isSameShape(z)); - ASSERT_TRUE(exp->equalsTo(z)); - } - - delete exp; + ASSERT_TRUE(exp->isSameShape(z)); + ASSERT_TRUE(exp->equalsTo(z)); + } + delete exp; } - TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_6) { - auto a= NDArrayFactory::create('c', {1, 3}, {1, 1, 1}); - auto b= NDArrayFactory::create('c', {1, 3}, {0, 0, 0}); - auto x= NDArrayFactory::create('f', {2, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}); - auto y= NDArrayFactory::create('f', {5, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}); + auto a = NDArrayFactory::create('c', {1, 3}, {1, 1, 1}); + auto b = NDArrayFactory::create('c', {1, 3}, {0, 0, 0}); + auto x = NDArrayFactory::create('f', {2, 5}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}); + auto y = NDArrayFactory::create( + 'f', {5, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}); - auto exp = MmulHelper::mmul(&x, &y); + auto exp = MmulHelper::mmul(&x, &y); - sd::ops::batched_gemm op; - auto result = op.evaluate({&a, &b, &x, &x, &x, &y, &y, &y}, {}, {111, 111, 2, 3, 5, 2, 5, 2, 3}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::batched_gemm op; + auto result = op.evaluate({&a, &b, &x, &x, &x, &y, &y, &y}, {}, + {111, 111, 2, 3, 5, 2, 5, 2, 3}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(3, result.size()); + ASSERT_EQ(3, result.size()); - for (int e = 0; e < 3; e++) { - auto z = result.at(e); + for (int e = 0; e < 3; e++) { + auto z = result.at(e); - //exp->printIndexedBuffer("e"); - //z->printIndexedBuffer("z"); + // exp->printIndexedBuffer("e"); + // z->printIndexedBuffer("z"); - ASSERT_TRUE(exp->isSameShape(z)); - ASSERT_TRUE(exp->equalsTo(z)); - } - - delete exp; + ASSERT_TRUE(exp->isSameShape(z)); + ASSERT_TRUE(exp->equalsTo(z)); + } + delete exp; } TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_7) { - auto a= NDArrayFactory::create('c', {1, 3}, {1, 1, 1}); - auto b= NDArrayFactory::create('c', {1, 3}, {0, 0, 0}); - auto x= NDArrayFactory::create('c', {2, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}); - auto y= NDArrayFactory::create('c', {5, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}); - - auto exp = MmulHelper::mmul(&x, &y); + auto a = NDArrayFactory::create('c', {1, 3}, {1, 1, 1}); + auto b = NDArrayFactory::create('c', {1, 3}, {0, 0, 0}); + auto x = NDArrayFactory::create('c', {2, 5}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}); + auto y = NDArrayFactory::create( + 'c', {5, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}); - // exp->printShapeInfo("exp shape"); + auto exp = MmulHelper::mmul(&x, &y); - sd::ops::batched_gemm op; - auto result = op.evaluate({&a, &b, &x, &x, &x, &y, &y, &y}, {}, {112, 112, 2, 3, 5, 5, 3, 2, 3}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + // exp->printShapeInfo("exp shape"); - ASSERT_EQ(3, result.size()); + sd::ops::batched_gemm op; + auto result = op.evaluate({&a, &b, &x, &x, &x, &y, &y, &y}, {}, + {112, 112, 2, 3, 5, 5, 3, 2, 3}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - for (int e = 0; e < 3; e++) { - auto z = result.at(e); + ASSERT_EQ(3, result.size()); - //exp->printIndexedBuffer("e"); - //z->printIndexedBuffer("z"); + for (int e = 0; e < 3; e++) { + auto z = result.at(e); - ASSERT_TRUE(exp->isSameShape(z)); - ASSERT_TRUE(exp->equalsTo(z)); - } + // exp->printIndexedBuffer("e"); + // z->printIndexedBuffer("z"); - delete exp; + ASSERT_TRUE(exp->isSameShape(z)); + ASSERT_TRUE(exp->equalsTo(z)); + } + delete exp; } TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_Validation_1) { - auto a = NDArrayFactory::create('c', {1, 3}, {1.f, 1.f, 1.f}); - auto b = NDArrayFactory::create('c', {1, 3}, {0.f, 0.f, 0.f}); - auto x = NDArrayFactory::create('c', {2, 5}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f}); - auto y = NDArrayFactory::create('c', {5, 3}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f}); - - sd::ops::batched_gemm op; - try { - auto result = op.evaluate({&a, &b, &x, &x, &x, &y, &y, &y}, {}, {112, 112, 2, 3, 5, 5, 3, 2, 3}); - - ASSERT_TRUE(false); - } catch (std::invalid_argument &e) { - // - } + auto a = NDArrayFactory::create('c', {1, 3}, {1.f, 1.f, 1.f}); + auto b = NDArrayFactory::create('c', {1, 3}, {0.f, 0.f, 0.f}); + auto x = NDArrayFactory::create( + 'c', {2, 5}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f}); + auto y = + NDArrayFactory::create('c', {5, 3}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, + 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f}); + + sd::ops::batched_gemm op; + try { + auto result = op.evaluate({&a, &b, &x, &x, &x, &y, &y, &y}, {}, + {112, 112, 2, 3, 5, 5, 3, 2, 3}); + + ASSERT_TRUE(false); + } catch (std::invalid_argument &e) { + // + } } TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_Validation_2) { - auto a = NDArrayFactory::create('c', {1, 3}, {1.f, 1.f, 1.f}); - auto b = NDArrayFactory::create('c', {1, 3}, {0.f, 0.f, 0.f}); - auto x = NDArrayFactory::create('c', {2, 5}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f}); - auto y = NDArrayFactory::create('c', {5, 3}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f}); - - auto z = NDArrayFactory::create('c', {2, 3}); - - sd::ops::batched_gemm op; - try { - auto result = op.execute({&a, &b, &x, &x, &x, &y, &y, &y}, {&z}, {}, {112, 112, 2, 3, 5, 5, 3, 2, 3}, {}); - ASSERT_TRUE(false); - } catch (std::invalid_argument &e) { - // - } + auto a = NDArrayFactory::create('c', {1, 3}, {1.f, 1.f, 1.f}); + auto b = NDArrayFactory::create('c', {1, 3}, {0.f, 0.f, 0.f}); + auto x = NDArrayFactory::create( + 'c', {2, 5}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f}); + auto y = + NDArrayFactory::create('c', {5, 3}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, + 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f}); + + auto z = NDArrayFactory::create('c', {2, 3}); + + sd::ops::batched_gemm op; + try { + auto result = op.execute({&a, &b, &x, &x, &x, &y, &y, &y}, {&z}, {}, + {112, 112, 2, 3, 5, 5, 3, 2, 3}, {}); + ASSERT_TRUE(false); + } catch (std::invalid_argument &e) { + // + } } TEST_F(DeclarableOpsTests3, Test_ReverseDivide_1) { - auto x= NDArrayFactory::create('c', {1, 3}, {2, 2, 2}); - auto y= NDArrayFactory::create('c', {1, 3}, {4, 6, 8}); - auto exp= NDArrayFactory::create('c', {1, 3}, {2, 3, 4}); + auto x = NDArrayFactory::create('c', {1, 3}, {2, 2, 2}); + auto y = NDArrayFactory::create('c', {1, 3}, {4, 6, 8}); + auto exp = NDArrayFactory::create('c', {1, 3}, {2, 3, 4}); - sd::ops::reversedivide op; - auto result = op.evaluate({&x, &y}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::reversedivide op; + auto result = op.evaluate({&x, &y}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, sruCell_test1) { + const int batchSize = 2; + const int inSize = 5; - const int batchSize = 2; - const int inSize = 5; - - auto xt = NDArrayFactory::create('c', {batchSize, inSize}); - auto ct_1= NDArrayFactory::create('c', {batchSize, inSize}); - auto w = NDArrayFactory::create('c', {inSize, 3*inSize}); - auto b = NDArrayFactory::create('c', {2*inSize}); + auto xt = NDArrayFactory::create('c', {batchSize, inSize}); + auto ct_1 = NDArrayFactory::create('c', {batchSize, inSize}); + auto w = NDArrayFactory::create('c', {inSize, 3 * inSize}); + auto b = NDArrayFactory::create('c', {2 * inSize}); - xt.assign(1.); - ct_1.assign(2.); - w.assign(0.5); - b.assign(0.7); + xt.assign(1.); + ct_1.assign(2.); + w.assign(0.5); + b.assign(0.7); - auto expHt= NDArrayFactory::create('c', {batchSize, inSize}, {0.96674103f, 0.96674103f, 0.96674103f, 0.96674103f, 0.96674103f, 0.96674103f, 0.96674103f, 0.96674103f, 0.96674103f, 0.96674103f}); - auto expCt= NDArrayFactory::create('c', {batchSize, inSize}, {2.01958286f, 2.01958286f, 2.01958286f, 2.01958286f, 2.01958286f, 2.01958286f, 2.01958286f, 2.01958286f, 2.01958286f, 2.01958286f}); + auto expHt = NDArrayFactory::create( + 'c', {batchSize, inSize}, + {0.96674103f, 0.96674103f, 0.96674103f, 0.96674103f, 0.96674103f, + 0.96674103f, 0.96674103f, 0.96674103f, 0.96674103f, 0.96674103f}); + auto expCt = NDArrayFactory::create( + 'c', {batchSize, inSize}, + {2.01958286f, 2.01958286f, 2.01958286f, 2.01958286f, 2.01958286f, + 2.01958286f, 2.01958286f, 2.01958286f, 2.01958286f, 2.01958286f}); - sd::ops::sruCell op; - auto results = op.evaluate({&xt, &ct_1, &w, &b}); + sd::ops::sruCell op; + auto results = op.evaluate({&xt, &ct_1, &w, &b}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *ht = results.at(0); - auto *ct = results.at(1); - - ASSERT_TRUE(expHt.isSameShape(ht)); - ASSERT_TRUE(expHt.equalsTo(ht)); - ASSERT_TRUE(expCt.isSameShape(ct)); - ASSERT_TRUE(expCt.equalsTo(ct)); + auto ht = results.at(0); + auto ct = results.at(1); + ASSERT_TRUE(expHt.isSameShape(ht)); + ASSERT_TRUE(expHt.equalsTo(ht)); + ASSERT_TRUE(expCt.isSameShape(ct)); + ASSERT_TRUE(expCt.equalsTo(ct)); } - //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, sruCell_test2) { + const int batchSize = 2; + const int inSize = 5; - const int batchSize = 2; - const int inSize = 5; + auto xt = NDArrayFactory::create('c', {batchSize, inSize}); + auto ct_1 = NDArrayFactory::create('c', {batchSize, inSize}); + auto w = NDArrayFactory::create('c', {inSize, 3 * inSize}); + auto b = NDArrayFactory::create('c', {2 * inSize}); - auto xt = NDArrayFactory::create('c', {batchSize, inSize}); - auto ct_1= NDArrayFactory::create('c', {batchSize, inSize}); - auto w = NDArrayFactory::create('c', {inSize, 3*inSize}); - auto b = NDArrayFactory::create('c', {2*inSize}); + xt.assign(1.); + ct_1.assign(2.); + w.assign(0.5); + b.assign(-1.); - xt.assign(1.); - ct_1.assign(2.); - w.assign(0.5); - b.assign(-1.); + auto expHt = NDArrayFactory::create( + 'c', {batchSize, inSize}, + {0.97542038f, 0.97542038f, 0.97542038f, 0.97542038f, 0.97542038f, + 0.97542038f, 0.97542038f, 0.97542038f, 0.97542038f, 0.97542038f}); + auto expCt = NDArrayFactory::create( + 'c', {batchSize, inSize}, + {2.09121276f, 2.09121276f, 2.09121276f, 2.09121276f, 2.09121276f, + 2.09121276f, 2.09121276f, 2.09121276f, 2.09121276f, 2.09121276f}); - auto expHt= NDArrayFactory::create('c', {batchSize, inSize}, {0.97542038f, 0.97542038f, 0.97542038f, 0.97542038f, 0.97542038f, 0.97542038f, 0.97542038f, 0.97542038f, 0.97542038f, 0.97542038f}); - auto expCt= NDArrayFactory::create('c', {batchSize, inSize}, {2.09121276f, 2.09121276f, 2.09121276f, 2.09121276f, 2.09121276f, 2.09121276f, 2.09121276f, 2.09121276f, 2.09121276f, 2.09121276f}); + sd::ops::sruCell op; + auto results = op.evaluate({&xt, &ct_1, &w, &b}); - sd::ops::sruCell op; - auto results = op.evaluate({&xt, &ct_1, &w, &b}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto *ht = results.at(0); - auto *ct = results.at(1); - - ASSERT_TRUE(expHt.isSameShape(ht)); - ASSERT_TRUE(expHt.equalsTo(ht)); - ASSERT_TRUE(expCt.isSameShape(ct)); - ASSERT_TRUE(expCt.equalsTo(ct)); + auto ht = results.at(0); + auto ct = results.at(1); + ASSERT_TRUE(expHt.isSameShape(ht)); + ASSERT_TRUE(expHt.equalsTo(ht)); + ASSERT_TRUE(expCt.isSameShape(ct)); + ASSERT_TRUE(expCt.equalsTo(ct)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, sruCell_test3) { + const int batchSize = 2; + const int inSize = 5; - const int batchSize = 2; - const int inSize = 5; - - auto xt = NDArrayFactory::create('c', {batchSize, inSize}); - auto ct_1= NDArrayFactory::create('c', {batchSize, inSize}); - auto w = NDArrayFactory::create('c', {inSize, 3*inSize}); - auto b = NDArrayFactory::create('c', {2*inSize}); - - xt.assign(10.); - ct_1.assign(1.); - w.assign(0.5); - b.assign(-1.); + auto xt = NDArrayFactory::create('c', {batchSize, inSize}); + auto ct_1 = NDArrayFactory::create('c', {batchSize, inSize}); + auto w = NDArrayFactory::create('c', {inSize, 3 * inSize}); + auto b = NDArrayFactory::create('c', {2 * inSize}); - auto expHt= NDArrayFactory::create('c', {batchSize, inSize}, {0.76159416f, 0.76159416f, 0.76159416f, 0.76159416f, 0.76159416f, 0.76159416f, 0.76159416f, 0.76159416f, 0.76159416f, 0.76159416f}); - auto expCt= NDArrayFactory::create('c', {batchSize, inSize}, {1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}); + xt.assign(10.); + ct_1.assign(1.); + w.assign(0.5); + b.assign(-1.); - sd::ops::sruCell op; - auto results = op.evaluate({&xt, &ct_1, &w, &b}); + auto expHt = NDArrayFactory::create( + 'c', {batchSize, inSize}, + {0.76159416f, 0.76159416f, 0.76159416f, 0.76159416f, 0.76159416f, + 0.76159416f, 0.76159416f, 0.76159416f, 0.76159416f, 0.76159416f}); + auto expCt = NDArrayFactory::create( + 'c', {batchSize, inSize}, + {1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::sruCell op; + auto results = op.evaluate({&xt, &ct_1, &w, &b}); - auto *ht = results.at(0); - auto *ct = results.at(1); - - ASSERT_TRUE(expHt.isSameShape(ht)); - ASSERT_TRUE(expHt.equalsTo(ht)); - ASSERT_TRUE(expCt.isSameShape(ct)); - ASSERT_TRUE(expCt.equalsTo(ct)); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto ht = results.at(0); + auto ct = results.at(1); + ASSERT_TRUE(expHt.isSameShape(ht)); + ASSERT_TRUE(expHt.equalsTo(ht)); + ASSERT_TRUE(expCt.isSameShape(ct)); + ASSERT_TRUE(expCt.equalsTo(ct)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, gruCell_test1) { + const int batchSize = 2; + const int inSize = 10; + const int numUnits = 4; - const int batchSize = 2; - const int inSize = 10; - const int numUnits = 4; - - auto xt = NDArrayFactory::create('c', {batchSize, inSize}); - auto ht_1 = NDArrayFactory::create('c', {batchSize, numUnits}); - auto Wru = NDArrayFactory::create('c', {(inSize+numUnits), 2*numUnits}); - auto Wc = NDArrayFactory::create('c', {(inSize+numUnits), numUnits}); - auto bru = NDArrayFactory::create('c', {2*numUnits}); - auto bc = NDArrayFactory::create('c', {numUnits}); + auto xt = NDArrayFactory::create('c', {batchSize, inSize}); + auto ht_1 = NDArrayFactory::create('c', {batchSize, numUnits}); + auto Wru = + NDArrayFactory::create('c', {(inSize + numUnits), 2 * numUnits}); + auto Wc = NDArrayFactory::create('c', {(inSize + numUnits), numUnits}); + auto bru = NDArrayFactory::create('c', {2 * numUnits}); + auto bc = NDArrayFactory::create('c', {numUnits}); - xt.assign(1.); - ht_1.assign(2.); - Wru.assign(0.5); - Wc.assign(0.5); - bru.assign(0.7); - bc.assign(0.7); + xt.assign(1.); + ht_1.assign(2.); + Wru.assign(0.5); + Wc.assign(0.5); + bru.assign(0.7); + bc.assign(0.7); - auto expHt = NDArrayFactory::create('c', {batchSize, numUnits}, {1.99993872f, 1.99993872f, 1.99993872f, 1.99993872f, 1.99993872f, 1.99993872f, 1.99993872f, 1.99993872f}); + auto expHt = NDArrayFactory::create( + 'c', {batchSize, numUnits}, + {1.99993872f, 1.99993872f, 1.99993872f, 1.99993872f, 1.99993872f, + 1.99993872f, 1.99993872f, 1.99993872f}); - sd::ops::gruCell op; - auto results = op.evaluate({&xt, &ht_1, &Wru, &Wc, &bru, &bc}); + sd::ops::gruCell op; + auto results = op.evaluate({&xt, &ht_1, &Wru, &Wc, &bru, &bc}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *ht = results.at(3); - - ASSERT_TRUE(expHt.isSameShape(ht)); - ASSERT_TRUE(expHt.equalsTo(ht)); + auto ht = results.at(3); + ASSERT_TRUE(expHt.isSameShape(ht)); + ASSERT_TRUE(expHt.equalsTo(ht)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, gruCell_test2) { + const int batchSize = 2; + const int inSize = 10; + const int numUnits = 4; - const int batchSize = 2; - const int inSize = 10; - const int numUnits = 4; - - auto xt = NDArrayFactory::create('c', {batchSize, inSize}); - auto ht_1 = NDArrayFactory::create('c', {batchSize, numUnits}); - auto Wru = NDArrayFactory::create('c', {(inSize+numUnits), 2*numUnits}); - auto Wc = NDArrayFactory::create('c', {(inSize+numUnits), numUnits}); - auto bru = NDArrayFactory::create('c', {2*numUnits}); - auto bc = NDArrayFactory::create('c', {numUnits}); + auto xt = NDArrayFactory::create('c', {batchSize, inSize}); + auto ht_1 = NDArrayFactory::create('c', {batchSize, numUnits}); + auto Wru = + NDArrayFactory::create('c', {(inSize + numUnits), 2 * numUnits}); + auto Wc = NDArrayFactory::create('c', {(inSize + numUnits), numUnits}); + auto bru = NDArrayFactory::create('c', {2 * numUnits}); + auto bc = NDArrayFactory::create('c', {numUnits}); - xt.assign(1.); - ht_1.assign(0.); - Wru.assign(1.5); - Wc.assign(1.5); - bru.assign(-10); - bc.assign(-10); + xt.assign(1.); + ht_1.assign(0.); + Wru.assign(1.5); + Wc.assign(1.5); + bru.assign(-10); + bc.assign(-10); - auto expHt= NDArrayFactory::create('c', {batchSize, numUnits}, {0.00669224f, 0.00669224f, 0.00669224f, 0.00669224f, 0.00669224f, 0.00669224f, 0.00669224f, 0.00669224f}); + auto expHt = NDArrayFactory::create( + 'c', {batchSize, numUnits}, + {0.00669224f, 0.00669224f, 0.00669224f, 0.00669224f, 0.00669224f, + 0.00669224f, 0.00669224f, 0.00669224f}); - sd::ops::gruCell op; - auto results = op.evaluate({&xt, &ht_1, &Wru, &Wc, &bru, &bc}); + sd::ops::gruCell op; + auto results = op.evaluate({&xt, &ht_1, &Wru, &Wc, &bru, &bc}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto *ht = results.at(3); - - ASSERT_TRUE(expHt.isSameShape(ht)); - ASSERT_TRUE(expHt.equalsTo(ht)); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto ht = results.at(3); + ASSERT_TRUE(expHt.isSameShape(ht)); + ASSERT_TRUE(expHt.equalsTo(ht)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, gruCell_test3) { + const int batchSize = 2; + const int inSize = 10; + const int numUnits = 4; - const int batchSize = 2; - const int inSize = 10; - const int numUnits = 4; - - auto xt = NDArrayFactory::create('c', {batchSize, inSize}); - auto ht_1= NDArrayFactory::create('c', {batchSize, numUnits}); - auto Wru = NDArrayFactory::create('c', {(inSize+numUnits), 2*numUnits}); - auto Wc = NDArrayFactory::create('c', {(inSize+numUnits), numUnits}); - auto bru = NDArrayFactory::create('c', {2*numUnits}); - auto bc = NDArrayFactory::create('c', {numUnits}); - - xt.assign(1.); - ht_1.assign(0.); - Wru.assign(0.1); - Wc.assign(0.1); - bru.assign(1); - bc.assign(1); - - auto expHt= NDArrayFactory::create('c', {batchSize, numUnits}, {0.1149149f, 0.1149149f, 0.1149149f, 0.1149149f, 0.1149149f, 0.1149149f, 0.1149149f, 0.1149149f}); + auto xt = NDArrayFactory::create('c', {batchSize, inSize}); + auto ht_1 = NDArrayFactory::create('c', {batchSize, numUnits}); + auto Wru = + NDArrayFactory::create('c', {(inSize + numUnits), 2 * numUnits}); + auto Wc = NDArrayFactory::create('c', {(inSize + numUnits), numUnits}); + auto bru = NDArrayFactory::create('c', {2 * numUnits}); + auto bc = NDArrayFactory::create('c', {numUnits}); + xt.assign(1.); + ht_1.assign(0.); + Wru.assign(0.1); + Wc.assign(0.1); + bru.assign(1); + bc.assign(1); - sd::ops::gruCell op; - auto result = op.evaluate({&xt, &ht_1, &Wru, &Wc, &bru, &bc}); + auto expHt = NDArrayFactory::create( + 'c', {batchSize, numUnits}, + {0.1149149f, 0.1149149f, 0.1149149f, 0.1149149f, 0.1149149f, 0.1149149f, + 0.1149149f, 0.1149149f}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::gruCell op; + auto result = op.evaluate({&xt, &ht_1, &Wru, &Wc, &bru, &bc}); - auto *ht = result.at(3); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(expHt.isSameShape(ht)); - ASSERT_TRUE(expHt.equalsTo(ht)); + auto ht = result.at(3); + ASSERT_TRUE(expHt.isSameShape(ht)); + ASSERT_TRUE(expHt.equalsTo(ht)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, invertPermutation_test1) { + auto input = + NDArrayFactory::create('c', {1, 8}, {5, 2, 7, 4, 6, 3, 1, 0}); + auto expected = + NDArrayFactory::create('c', {1, 8}, {7, 6, 1, 5, 3, 0, 4, 2}); - auto input= NDArrayFactory::create('c', {1, 8}, {5,2,7,4,6,3,1,0}); - auto expected= NDArrayFactory::create('c', {1, 8}, {7, 6, 1, 5, 3, 0, 4, 2}); + sd::ops::invert_permutation op; + auto result = op.evaluate({&input}); - sd::ops::invert_permutation op; - auto result = op.evaluate({&input}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto *output = result.at(0); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, invertPermutation_test2) { + auto input = + NDArrayFactory::create('c', {1, 8}, {5, 2, 7, 4, 6, 3, 1, 0}); + auto expected = + NDArrayFactory::create('c', {1, 8}, {7, 6, 1, 5, 3, 0, 4, 2}); - auto input= NDArrayFactory::create('c', {1, 8}, {5,2,7,4,6,3,1,0}); - auto expected= NDArrayFactory::create('c', {1, 8}, {7, 6, 1, 5, 3, 0, 4, 2}); - + sd::ops::invert_permutation op; + auto result = op.evaluate({&input}); - sd::ops::invert_permutation op; - auto result = op.evaluate({&input}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto *output = result.at(0); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, invertPermutation_test3) { + auto input = + NDArrayFactory::create('c', {1, 8}, {1, 2, 0, 4, 6, 3, 5, 7}); + auto expected = + NDArrayFactory::create('c', {1, 8}, {2, 0, 1, 5, 3, 6, 4, 7}); - auto input= NDArrayFactory::create('c', {1, 8}, {1,2,0,4,6,3,5,7}); - auto expected= NDArrayFactory::create('c', {1, 8}, {2, 0, 1, 5, 3, 6, 4, 7}); - - sd::ops::invert_permutation op; - auto result = op.evaluate({&input}); + sd::ops::invert_permutation op; + auto result = op.evaluate({&input}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto *output = result.at(0); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + auto output = result.at(0); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, diag_test1) { + auto input = NDArrayFactory::create('c', {3, 2}); + input.linspace(1); - auto input= NDArrayFactory::create('c', {3, 2}); - input.linspace(1); - - auto expected= NDArrayFactory::create('c', {3,2,3,2}, {1,0,0,0,0,0, 0,2,0,0,0,0, 0,0,3,0,0,0, 0,0,0,4,0,0, 0,0,0,0,5,0, 0,0,0,0,0,6}); + auto expected = NDArrayFactory::create( + 'c', {3, 2, 3, 2}, + {1, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, + 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 6}); - sd::ops::diag op; - auto result = op.evaluate({&input}); + sd::ops::diag op; + auto result = op.evaluate({&input}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto *output = result.at(0); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + auto output = result.at(0); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, diag_test2) { + auto input = NDArrayFactory::create('c', {2, 3}); + input.linspace(1); - auto input= NDArrayFactory::create('c', {2, 3}); - input.linspace(1); - - auto expected= NDArrayFactory::create('c', {2,3,2,3}, {1,0,0,0,0,0, 0,2,0,0,0,0, 0,0,3,0,0,0, 0,0,0,4,0,0, 0,0,0,0,5,0, 0,0,0,0,0,6}); + auto expected = NDArrayFactory::create( + 'c', {2, 3, 2, 3}, + {1, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, + 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 6}); - sd::ops::diag op; - auto result = op.evaluate({&input}); + sd::ops::diag op; + auto result = op.evaluate({&input}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto *output = result.at(0); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + auto output = result.at(0); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } - /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, diag_test_vector) { + auto input = NDArrayFactory::linspace(1, 4, 4); + auto expected = NDArrayFactory::create( + 'c', {4, 4}, {1, 0, 0, 0, 0, 2, 0, 0, 0, 0, 3, 0, 0, 0, 0, 4}); + sd::ops::diag op; + auto result = op.evaluate({input}); - auto input = NDArrayFactory::linspace(1,4,4); - auto expected= NDArrayFactory::create('c', {4,4}, {1,0,0,0, 0,2,0,0, 0,0,3,0,0,0,0,4}); - - sd::ops::diag op; - auto result = op.evaluate({input}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); + auto output = result.at(0); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); - delete input; + delete input; } TEST_F(DeclarableOpsTests3, diag_test_col_vector) { + auto input = NDArrayFactory::linspace(1, 4, 4); + input->reshapei({4, 1}); + auto expected = NDArrayFactory::create( + 'c', {4, 4}, {1, 0, 0, 0, 0, 2, 0, 0, 0, 0, 3, 0, 0, 0, 0, 4}); + sd::ops::diag op; + auto result = op.evaluate({input}, {}, {}); - auto input = NDArrayFactory::linspace(1,4,4); - input->reshapei({4,1}); - auto expected= NDArrayFactory::create('c', {4,4}, {1,0,0,0, 0,2,0,0, 0,0,3,0,0,0,0,4}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - sd::ops::diag op; - auto result = op.evaluate({input}, {}, {}); + auto output = result.at(0); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); - auto *output = result.at(0); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - - - delete input; + delete input; } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, diag_test3) { + auto input = NDArrayFactory::create('c', {1, 3}); + input.linspace(1); - auto input= NDArrayFactory::create('c', {1, 3}); - input.linspace(1); - - auto expected= NDArrayFactory::create('c', {3,3}, {1,0,0, 0,2,0, 0,0,3}); + auto expected = + NDArrayFactory::create('c', {3, 3}, {1, 0, 0, 0, 2, 0, 0, 0, 3}); - sd::ops::diag op; - auto result = op.evaluate({&input}, {}, {}); + sd::ops::diag op; + auto result = op.evaluate({&input}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto *output = result.at(0); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + auto output = result.at(0); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, diag_test4) { + auto input = NDArrayFactory::create('c', {3, 1}); + input.linspace(1); - auto input= NDArrayFactory::create('c', {3, 1}); - input.linspace(1); - - auto expected= NDArrayFactory::create('c', {3,3}, {1,0,0, 0,2,0, 0,0,3}); + auto expected = + NDArrayFactory::create('c', {3, 3}, {1, 0, 0, 0, 2, 0, 0, 0, 3}); - sd::ops::diag op; - auto result = op.evaluate({&input}, {}, {}); + sd::ops::diag op; + auto result = op.evaluate({&input}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto *output = result.at(0); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + auto output = result.at(0); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, diag_test5) { + auto input = NDArrayFactory::create('c', {1, 1}); + input.linspace(2); - auto input= NDArrayFactory::create('c', {1, 1}); - input.linspace(2); - - auto expected= NDArrayFactory::create('c', {1,1}, {2}); + auto expected = NDArrayFactory::create('c', {1, 1}, {2}); - sd::ops::diag op; - auto result = op.evaluate({&input}, {}, {}); + sd::ops::diag op; + auto result = op.evaluate({&input}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto *output = result.at(0); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + auto output = result.at(0); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, diag_test6) { + auto input = NDArrayFactory::create('c', {2, 2, 2}); + input.linspace(1); - auto input= NDArrayFactory::create('c', {2,2,2}); - input.linspace(1); - - auto expected= NDArrayFactory::create('c', {2,2,2,2,2,2}, {1,0,0,0, 0,0,0,0, 0,2,0,0, 0,0,0,0, 0,0,3,0, 0,0,0,0, 0,0,0,4, 0,0,0,0, 0,0,0,0, 5,0,0,0, 0,0,0,0, 0,6,0,0, 0,0,0,0, 0,0,7,0, 0,0,0,0, 0,0,0,8}); + auto expected = NDArrayFactory::create( + 'c', {2, 2, 2, 2, 2, 2}, + {1, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, + 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, + 0, 6, 0, 0, 0, 0, 0, 0, 0, 0, 7, 0, 0, 0, 0, 0, 0, 0, 0, 8}); - sd::ops::diag op; - auto result = op.evaluate({&input}, {}, {}); + sd::ops::diag op; + auto result = op.evaluate({&input}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto *output = result.at(0); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + auto output = result.at(0); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, matrixSetDiag_test1) { + auto input = NDArrayFactory::create('c', {4, 3, 2}); + auto diagonal = NDArrayFactory::create('c', {4, 2}); + input.assign(0.); + diagonal.assign(1.); - auto input= NDArrayFactory::create('c', {4,3,2}); - auto diagonal= NDArrayFactory::create('c', {4,2}); - input.assign(0.); - diagonal.assign(1.); - - auto expected= NDArrayFactory::create('c', {4,3,2}, {1,0,0,1,0,0, 1,0,0,1,0,0, 1,0,0,1,0,0, 1,0,0,1,0,0}); + auto expected = NDArrayFactory::create( + 'c', {4, 3, 2}, + {1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0}); - sd::ops::matrix_set_diag op; - auto result = op.evaluate({&input, &diagonal}, {}, {}); + sd::ops::matrix_set_diag op; + auto result = op.evaluate({&input, &diagonal}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto *output = result.at(0); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + auto output = result.at(0); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, matrixSetDiag_test2) { + auto input = NDArrayFactory::create('c', {1, 1, 2}); + auto diagonal = NDArrayFactory::create('c', {1, 1}); + input.assign(0.); + diagonal.assign(1.); - auto input= NDArrayFactory::create('c', {1,1,2}); - auto diagonal= NDArrayFactory::create('c', {1,1}); - input.assign(0.); - diagonal.assign(1.); - - auto expected= NDArrayFactory::create('c', {1,1,2}, {1.f, 0.f}); - - sd::ops::matrix_set_diag op; - auto result = op.evaluate({&input, &diagonal}, {}, {}); + auto expected = NDArrayFactory::create('c', {1, 1, 2}, {1.f, 0.f}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::matrix_set_diag op; + auto result = op.evaluate({&input, &diagonal}, {}, {}); - auto *output = result.at(0); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, matrixSetDiag_test3) { + auto input = NDArrayFactory::create('c', {2, 1, 4}); + auto diagonal = NDArrayFactory::create('c', {2, 1}); + input.assign(0.); + diagonal.assign(1.); - auto input= NDArrayFactory::create('c', {2,1,4}); - auto diagonal= NDArrayFactory::create('c', {2,1}); - input.assign(0.); - diagonal.assign(1.); - - auto expected= NDArrayFactory::create('c', {2,1,4}, {1,0,0,0,1,0,0,0}); + auto expected = + NDArrayFactory::create('c', {2, 1, 4}, {1, 0, 0, 0, 1, 0, 0, 0}); - sd::ops::matrix_set_diag op; - auto result = op.evaluate({&input, &diagonal}, {}, {}); + sd::ops::matrix_set_diag op; + auto result = op.evaluate({&input, &diagonal}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto *output = result.at(0); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + auto output = result.at(0); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, matrixSetDiag_test4) { + auto input = NDArrayFactory::create('c', {2, 1, 4, 1}); + auto diagonal = NDArrayFactory::create('c', {2, 1, 1}); + input.assign(0.); + diagonal.assign(1.); - auto input= NDArrayFactory::create('c', {2,1,4,1}); - auto diagonal= NDArrayFactory::create('c', {2,1,1}); - input.assign(0.); - diagonal.assign(1.); - - auto expected= NDArrayFactory::create('c', {2,1,4,1}, {1,0,0,0,1,0,0,0}); + auto expected = NDArrayFactory::create('c', {2, 1, 4, 1}, + {1, 0, 0, 0, 1, 0, 0, 0}); - sd::ops::matrix_set_diag op; - auto result = op.evaluate({&input, &diagonal}, {}, {}); + sd::ops::matrix_set_diag op; + auto result = op.evaluate({&input, &diagonal}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto *output = result.at(0); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + auto output = result.at(0); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, diagPart_test1) { + auto input = NDArrayFactory::create('c', {2, 2}); + input.linspace(1); - auto input= NDArrayFactory::create('c', {2,2}); - input.linspace(1); - - auto expected= NDArrayFactory::create('c', {2}, {1,4}); + auto expected = NDArrayFactory::create('c', {2}, {1, 4}); - sd::ops::diag_part op; - auto result = op.evaluate({&input}, {}, {}); + sd::ops::diag_part op; + auto result = op.evaluate({&input}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto *output = result.at(0); - // output->printBuffer(); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + auto output = result.at(0); + // output->printBuffer(); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, diagPart_test2) { + auto input = NDArrayFactory::create('c', {2, 2, 2, 2}); + input.linspace(1); - auto input= NDArrayFactory::create('c', {2,2,2,2}); - input.linspace(1); - - auto expected= NDArrayFactory::create('c', {2,2}, {1,6,11,16}); + auto expected = NDArrayFactory::create('c', {2, 2}, {1, 6, 11, 16}); - sd::ops::diag_part op; - auto result = op.evaluate({&input}, {}, {}); + sd::ops::diag_part op; + auto result = op.evaluate({&input}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto *output = result.at(0); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + auto output = result.at(0); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, diagPart_test3) { + auto input = NDArrayFactory::create('c', {2, 2, 2, 2, 2, 2}); + input.linspace(1); - auto input= NDArrayFactory::create('c', {2,2,2,2,2,2}); - input.linspace(1); - - auto expected= NDArrayFactory::create('c', {2,2,2}, {1,10,19,28,37,46,55,64}); + auto expected = NDArrayFactory::create( + 'c', {2, 2, 2}, {1, 10, 19, 28, 37, 46, 55, 64}); - sd::ops::diag_part op; - auto result = op.evaluate({&input}, {}, {}); + sd::ops::diag_part op; + auto result = op.evaluate({&input}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto *output = result.at(0); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + auto output = result.at(0); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, betainc_test1) { + auto a = NDArrayFactory::create('c', {3, 3}); + auto b = NDArrayFactory::create('c', {3, 3}); + auto x = NDArrayFactory::create('c', {3, 3}); - auto a = NDArrayFactory::create('c', {3,3}); - auto b = NDArrayFactory::create('c', {3,3}); - auto x = NDArrayFactory::create('c', {3,3}); - - a.linspace((float16)0.1, (float16)0.1); - b.linspace((float16)0.1, (float16)0.1); - x.assign(0.1); + a.linspace((float16)0.1, (float16)0.1); + b.linspace((float16)0.1, (float16)0.1); + x.assign(0.1); - auto expected = NDArrayFactory::create('c', {3,3}, {0.40638509f, 0.33668978f, 0.28271242f, 0.23973916f, 0.20483276f, 0.17604725f, 0.15203027f, 0.13180567f, 0.114647f}); + auto expected = NDArrayFactory::create( + 'c', {3, 3}, + {0.40638509f, 0.33668978f, 0.28271242f, 0.23973916f, 0.20483276f, + 0.17604725f, 0.15203027f, 0.13180567f, 0.114647f}); - sd::ops::betainc op; - auto result = op.evaluate({&a, &b, &x}, {}, {}); + sd::ops::betainc op; + auto result = op.evaluate({&a, &b, &x}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output, 1e-2)); + auto output = result.at(0); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output, 1e-2)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, betainc_test2) { + auto a = NDArrayFactory::create('c', {3, 3}); + auto b = NDArrayFactory::create('c', {3, 3}); + auto x = NDArrayFactory::create('c', {3, 3}); - auto a= NDArrayFactory::create('c', {3,3}); - auto b= NDArrayFactory::create('c', {3,3}); - auto x= NDArrayFactory::create('c', {3,3}); - - a.linspace(0.1, 0.1); - b.linspace(0.1, 0.1); - x.assign(0.1); + a.linspace(0.1, 0.1); + b.linspace(0.1, 0.1); + x.assign(0.1); - auto expected= NDArrayFactory::create('c', {3,3}, {0.40638509f, 0.33668978f, 0.28271242f, 0.23973916f, 0.20483276f, 0.17604725f, 0.15203027f, 0.13180567f, 0.114647f}); + auto expected = NDArrayFactory::create( + 'c', {3, 3}, + {0.40638509f, 0.33668978f, 0.28271242f, 0.23973916f, 0.20483276f, + 0.17604725f, 0.15203027f, 0.13180567f, 0.114647f}); - sd::ops::betainc op; - auto result = op.evaluate({&a, &b, &x}, {}, {}); + sd::ops::betainc op; + auto result = op.evaluate({&a, &b, &x}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto *output = result.at(0); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + auto output = result.at(0); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, betainc_test3) { + auto a = NDArrayFactory::create('c', {3, 3}); + auto b = NDArrayFactory::create('c', {3, 3}); + auto x = NDArrayFactory::create('c', {3, 3}); - auto a= NDArrayFactory::create('c', {3,3}); - auto b= NDArrayFactory::create('c', {3,3}); - auto x= NDArrayFactory::create('c', {3,3}); - - a.linspace(0.1, 0.1); - b.linspace(0.1, 0.1); - x.assign(0.1); + a.linspace(0.1, 0.1); + b.linspace(0.1, 0.1); + x.assign(0.1); - auto expected= NDArrayFactory::create('c', {3,3}, {0.40638509f, 0.33668978f, 0.28271242f, 0.23973916f, 0.20483276f, 0.17604725f, 0.15203027f, 0.13180567f, 0.114647f}); + auto expected = NDArrayFactory::create( + 'c', {3, 3}, + {0.40638509f, 0.33668978f, 0.28271242f, 0.23973916f, 0.20483276f, + 0.17604725f, 0.15203027f, 0.13180567f, 0.114647f}); - sd::ops::betainc op; - auto result = op.evaluate({&a, &b, &x}, {}, {}); + sd::ops::betainc op; + auto result = op.evaluate({&a, &b, &x}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto *output = result.at(0); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + auto output = result.at(0); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, betainc_test4) { + auto a = NDArrayFactory::create('c', {3, 3}); + auto b = NDArrayFactory::create('c', {3, 3}); + auto x = NDArrayFactory::create('c', {3, 3}); - auto a= NDArrayFactory::create('c', {3,3}); - auto b= NDArrayFactory::create('c', {3,3}); - auto x= NDArrayFactory::create('c', {3,3}); - - a.linspace(1); - b.linspace(1); - x.assign(0.1); + a.linspace(1); + b.linspace(1); + x.assign(0.1); - auto expected= NDArrayFactory::create('c', {3,3}, {1.00000000e-01f, 2.80000000e-02f, 8.56000000e-03f, 2.72800000e-03f, 8.90920000e-04f, 2.95706080e-04f, 9.92854864e-05f, 3.36248880e-05f, 1.14644360e-05f}); + auto expected = NDArrayFactory::create( + 'c', {3, 3}, + {1.00000000e-01f, 2.80000000e-02f, 8.56000000e-03f, 2.72800000e-03f, + 8.90920000e-04f, 2.95706080e-04f, 9.92854864e-05f, 3.36248880e-05f, + 1.14644360e-05f}); - sd::ops::betainc op; - auto result = op.evaluate({&a, &b, &x}, {}, {}); + sd::ops::betainc op; + auto result = op.evaluate({&a, &b, &x}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto *output = result.at(0); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output, 1e-6)); + auto output = result.at(0); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output, 1e-6)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, betainc_test5) { + auto a = NDArrayFactory::create('c', {3, 3}); + auto b = NDArrayFactory::create('c', {3, 3}); + auto x = NDArrayFactory::create('c', {3, 3}); - auto a= NDArrayFactory::create('c', {3,3}); - auto b= NDArrayFactory::create('c', {3,3}); - auto x= NDArrayFactory::create('c', {3,3}); - - a.linspace(3200.); - b.linspace(3200.); - x.assign(0.1); + a.linspace(3200.); + b.linspace(3200.); + x.assign(0.1); - auto expected= NDArrayFactory::create('c', {3,3}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}); + auto expected = NDArrayFactory::create( + 'c', {3, 3}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}); - sd::ops::betainc op; - auto result = op.evaluate({&a, &b, &x}, {}, {}); + sd::ops::betainc op; + auto result = op.evaluate({&a, &b, &x}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto *output = result.at(0); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output, 1e-6)); + auto output = result.at(0); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output, 1e-6)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, betainc_test6) { + auto a = NDArrayFactory::create('c', {3, 3}); + auto b = NDArrayFactory::create('c', {3, 3}); + auto x = NDArrayFactory::create('c', {3, 3}); - auto a= NDArrayFactory::create('c', {3,3}); - auto b= NDArrayFactory::create('c', {3,3}); - auto x= NDArrayFactory::create('c', {3,3}); - - a.linspace(10.); - b.linspace(10.); - x.assign(0.1); + a.linspace(10.); + b.linspace(10.); + x.assign(0.1); - auto expected= NDArrayFactory::create('c', {3,3}, {3.92988233e-06f, 1.35306497e-06f, 4.67576826e-07f, 1.62083416e-07f, 5.63356971e-08f, 1.96261318e-08f, 6.85120307e-09f, 2.39594668e-09f, 8.39227685e-10f}); + auto expected = NDArrayFactory::create( + 'c', {3, 3}, + {3.92988233e-06f, 1.35306497e-06f, 4.67576826e-07f, 1.62083416e-07f, + 5.63356971e-08f, 1.96261318e-08f, 6.85120307e-09f, 2.39594668e-09f, + 8.39227685e-10f}); - sd::ops::betainc op; - auto result = op.evaluate({&a, &b, &x}, {}, {}); + sd::ops::betainc op; + auto result = op.evaluate({&a, &b, &x}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto *output = result.at(0); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output, 1e-6)); + auto output = result.at(0); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output, 1e-6)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, betainc_test7) { + auto a = NDArrayFactory::create('c', {3, 3}); + auto b = NDArrayFactory::create('c', {3, 3}); + auto x = NDArrayFactory::create('c', {3, 3}); - auto a= NDArrayFactory::create('c', {3,3}); - auto b= NDArrayFactory::create('c', {3,3}); - auto x= NDArrayFactory::create('c', {3,3}); - - a.linspace(10.); - b.linspace(10.); - x.assign(0.9); + a.linspace(10.); + b.linspace(10.); + x.assign(0.9); - auto expected= NDArrayFactory::create('c', {3,3}, {0.99999607f, 0.99999865f, 0.99999953f, 0.99999984f, 0.99999994f, 0.99999998f, 0.99999999f, 1.f, 1.f}); + auto expected = NDArrayFactory::create( + 'c', {3, 3}, + {0.99999607f, 0.99999865f, 0.99999953f, 0.99999984f, 0.99999994f, + 0.99999998f, 0.99999999f, 1.f, 1.f}); - sd::ops::betainc op; - auto result = op.evaluate({&a, &b, &x}, {}, {}); + sd::ops::betainc op; + auto result = op.evaluate({&a, &b, &x}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto *output = result.at(0); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output, 1e-6)); + auto output = result.at(0); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output, 1e-6)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, betainc_test8) { + auto a = NDArrayFactory::create('c', {3, 3}); + auto b = NDArrayFactory::create('c', {3, 3}); + auto x = NDArrayFactory::create('c', {3, 3}); - auto a= NDArrayFactory::create('c', {3,3}); - auto b= NDArrayFactory::create('c', {3,3}); - auto x= NDArrayFactory::create('c', {3,3}); - - a.linspace(10.); - b.linspace(10.); - x.assign(1.); + a.linspace(10.); + b.linspace(10.); + x.assign(1.); - auto expected= NDArrayFactory::create('c', {3,3}, {1.f, 1.f, 1.f,1.f,1.f,1.f,1.f,1.f,1.f}); + auto expected = NDArrayFactory::create( + 'c', {3, 3}, {1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}); - sd::ops::betainc op; - auto result = op.evaluate({&a, &b, &x}, {}, {}); + sd::ops::betainc op; + auto result = op.evaluate({&a, &b, &x}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto *output = result.at(0); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output, 1e-6)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output, 1e-6)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, betainc_test9) { + auto a = NDArrayFactory::create('c', {3, 3}); + auto b = NDArrayFactory::create('c', {3, 3}); + auto x = NDArrayFactory::create('c', {3, 3}); - auto a= NDArrayFactory::create('c', {3,3}); - auto b= NDArrayFactory::create('c', {3,3}); - auto x= NDArrayFactory::create('c', {3,3}); - - a.linspace(10.); - b.linspace(10.); - x.assign(0.); - - auto expected= NDArrayFactory::create('c', {3,3}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}); + a.linspace(10.); + b.linspace(10.); + x.assign(0.); - sd::ops::betainc op; - auto result = op.evaluate({&a, &b, &x}, {}, {}); + auto expected = NDArrayFactory::create( + 'c', {3, 3}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::betainc op; + auto result = op.evaluate({&a, &b, &x}, {}, {}); - auto *output = result.at(0); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + auto output = result.at(0); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, betainc_test10) { + auto a = NDArrayFactory::create('c', {3, 3}); + auto b = NDArrayFactory::create('c', {3, 3}); + auto x = NDArrayFactory::create('c', {3, 3}); - auto a= NDArrayFactory::create('c', {3,3}); - auto b= NDArrayFactory::create('c', {3,3}); - auto x= NDArrayFactory::create('c', {3,3}); + a.linspace(10.); + b.linspace(10.); + x.assign(0.5); - a.linspace(10.); - b.linspace(10.); - x.assign(0.5); + auto expected = NDArrayFactory::create( + 'c', {3, 3}, {0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f}); - auto expected= NDArrayFactory::create('c', {3,3}, {0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f}); + sd::ops::betainc op; + auto result = op.evaluate({&a, &b, &x}, {}, {}); - sd::ops::betainc op; - auto result = op.evaluate({&a, &b, &x}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto *output = result.at(0); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + auto output = result.at(0); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, betainc_test11) { + NDArray a('c', {4}, {0.7788f, 0.8012f, 0.7244f, 0.2309f}, + sd::DataType::FLOAT32); + NDArray b('c', {4}, {0.7717f, 0.9281f, 0.9846f, 0.4838f}, + sd::DataType::FLOAT32); + NDArray x('c', {4}, {0.9441f, 0.5957f, 0.8669f, 0.3502f}, + sd::DataType::FLOAT32); - NDArray a('c', {4}, {0.7788f, 0.8012f, 0.7244f, 0.2309f}, sd::DataType::FLOAT32); - NDArray b('c', {4}, {0.7717f, 0.9281f, 0.9846f, 0.4838f}, sd::DataType::FLOAT32); - NDArray x('c', {4}, {0.9441f, 0.5957f, 0.8669f, 0.3502f}, sd::DataType::FLOAT32); - - NDArray expected('c', {4}, {0.912156, 0.634460, 0.898314, 0.624538}, sd::DataType::FLOAT32); - sd::ops::betainc op; - auto result = op.evaluate({&a, &b, &x}, {}, {}); + NDArray expected('c', {4}, {0.912156, 0.634460, 0.898314, 0.624538}, + sd::DataType::FLOAT32); + sd::ops::betainc op; + auto result = op.evaluate({&a, &b, &x}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto *output = result.at(0); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + auto output = result.at(0); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, betainc_test12) { + NDArray a('c', {4}, {8.0091f, 8.2108f, 7.5194f, 3.0780f}, + sd::DataType::FLOAT32); + NDArray b('c', {4}, {7.9456f, 9.3527f, 9.8610f, 5.3541f}, + sd::DataType::FLOAT32); + NDArray x('c', {4}, {0.9441f, 0.5957f, 0.8669f, 0.3502f}, + sd::DataType::FLOAT32); - NDArray a('c', {4}, {8.0091f, 8.2108f, 7.5194f, 3.0780f}, sd::DataType::FLOAT32); - NDArray b('c', {4}, {7.9456f, 9.3527f, 9.8610f, 5.3541f}, sd::DataType::FLOAT32); - NDArray x('c', {4}, {0.9441f, 0.5957f, 0.8669f, 0.3502f}, sd::DataType::FLOAT32); - - NDArray expected('c', {4}, {0.9999995 , 0.8594694 , 0.999988 , 0.49124345}, sd::DataType::FLOAT32); + NDArray expected('c', {4}, {0.9999995, 0.8594694, 0.999988, 0.49124345}, + sd::DataType::FLOAT32); - sd::ops::betainc op; - auto result = op.evaluate({&a, &b, &x}, {}, {}); + sd::ops::betainc op; + auto result = op.evaluate({&a, &b, &x}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto *output = result.at(0); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + auto output = result.at(0); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, zeta_test1) { + auto x = NDArrayFactory::create('c', {3, 3}); + auto q = NDArrayFactory::create('c', {3, 3}); - auto x= NDArrayFactory::create('c', {3,3}); - auto q= NDArrayFactory::create('c', {3,3}); - - q.linspace(1.); - x.assign(2.); + q.linspace(1.); + x.assign(2.); - auto expected= NDArrayFactory::create('c', {3,3}, {1.64493407f, 0.64493407f, 0.39493407f, 0.28382296f, 0.22132296f, 0.18132296f, 0.15354518f, 0.13313701f, 0.11751201f}); + auto expected = NDArrayFactory::create( + 'c', {3, 3}, + {1.64493407f, 0.64493407f, 0.39493407f, 0.28382296f, 0.22132296f, + 0.18132296f, 0.15354518f, 0.13313701f, 0.11751201f}); - sd::ops::zeta op; - auto result = op.evaluate({&x, &q}, {}, {}); + sd::ops::zeta op; + auto result = op.evaluate({&x, &q}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto *output = result.at(0); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + auto output = result.at(0); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, zeta_test2) { + auto x = NDArrayFactory::create('c', {3, 3}); + auto q = NDArrayFactory::create('c', {3, 3}); - auto x= NDArrayFactory::create('c', {3,3}); - auto q= NDArrayFactory::create('c', {3,3}); - - q.linspace(10.); - x.assign(2.); + q.linspace(10.); + x.assign(2.); - auto expected= NDArrayFactory::create('c', {3,3}, {0.10516634f, 0.09516634f, 0.08690187f, 0.07995743f, 0.07404027f, 0.06893823f, 0.06449378f, 0.06058753f, 0.05712733f}); + auto expected = NDArrayFactory::create( + 'c', {3, 3}, + {0.10516634f, 0.09516634f, 0.08690187f, 0.07995743f, 0.07404027f, + 0.06893823f, 0.06449378f, 0.06058753f, 0.05712733f}); - sd::ops::zeta op; - auto result = op.evaluate({&x, &q}, {}, {}); + sd::ops::zeta op; + auto result = op.evaluate({&x, &q}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto *output = result.at(0); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, zeta_test3) { + auto x = NDArrayFactory::create('c', {3, 3}); + auto q = NDArrayFactory::create('c', {3, 3}); - auto x= NDArrayFactory::create('c', {3,3}); - auto q= NDArrayFactory::create('c', {3,3}); - - q.linspace(100.); - x.assign(2.); - - auto expected= NDArrayFactory::create('c', {3,3}, {0.01005017f, 0.00995017f, 0.00985214f, 0.00975602f, 0.00966176f, 0.0095693f, 0.0094786f, 0.0093896f, 0.00930226f}); + q.linspace(100.); + x.assign(2.); - sd::ops::zeta op; - auto result = op.evaluate({&x, &q}, {}, {}); + auto expected = NDArrayFactory::create( + 'c', {3, 3}, + {0.01005017f, 0.00995017f, 0.00985214f, 0.00975602f, 0.00966176f, + 0.0095693f, 0.0094786f, 0.0093896f, 0.00930226f}); + sd::ops::zeta op; + auto result = op.evaluate({&x, &q}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto *output = result.at(0); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + auto output = result.at(0); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } - /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, zeta_test4) { + auto x = NDArrayFactory::create('c', {3, 3}); + auto q = NDArrayFactory::create('c', {3, 3}); - auto x= NDArrayFactory::create('c', {3,3}); - auto q= NDArrayFactory::create('c', {3,3}); + q.linspace(100.); + x.assign(2.); - q.linspace(100.); - x.assign(2.); + auto expected = NDArrayFactory::create( + 'c', {3, 3}, + {0.01005017f, 0.00995017f, 0.00985214f, 0.00975602f, 0.00966176f, + 0.0095693f, 0.0094786f, 0.0093896f, 0.00930226f}); - auto expected= NDArrayFactory::create('c', {3,3}, {0.01005017f, 0.00995017f, 0.00985214f, 0.00975602f, 0.00966176f, 0.0095693f, 0.0094786f, 0.0093896f, 0.00930226f}); + sd::ops::zeta op; + auto result = op.evaluate({&x, &q}, {}, {}); - sd::ops::zeta op; - auto result = op.evaluate({&x, &q}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); - auto *output = result.at(0); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, zeta_test5) { + auto x = NDArrayFactory::create('c', {3, 3}); + auto q = NDArrayFactory::create('c', {3, 3}); - auto x= NDArrayFactory::create('c', {3,3}); - auto q= NDArrayFactory::create('c', {3,3}); - - q.linspace(1.); - x.assign(1.1); + q.linspace(1.); + x.assign(1.1); - auto expected= NDArrayFactory::create('c', {3,3}, {10.58444846f, 9.58444846f, 9.11793197f, 8.81927915f, 8.60164151f, 8.43137352f, 8.29204706f, 8.17445116f, 8.07291961f}); + auto expected = NDArrayFactory::create( + 'c', {3, 3}, + {10.58444846f, 9.58444846f, 9.11793197f, 8.81927915f, 8.60164151f, + 8.43137352f, 8.29204706f, 8.17445116f, 8.07291961f}); + sd::ops::zeta op; + auto result = op.evaluate({&x, &q}, {}, {}); - sd::ops::zeta op; - auto result = op.evaluate({&x, &q}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto *output = result.at(0); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + auto output = result.at(0); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, zeta_test6) { + auto x = NDArrayFactory::create('c', {3, 3}); + auto q = NDArrayFactory::create('c', {3, 3}); - auto x= NDArrayFactory::create('c', {3,3}); - auto q= NDArrayFactory::create('c', {3,3}); - - q.linspace(1.); - x.assign(1.01); - - auto expected= NDArrayFactory::create('c', {3,3}, {100.57794334f, 99.57794334f, 99.08139709f, 98.75170576f, 98.50514758f, 98.30834069f, 98.1446337f, 98.00452955f, 97.88210202f}); + q.linspace(1.); + x.assign(1.01); - sd::ops::zeta op; - auto result = op.evaluate({&x, &q}, {}, {}); + auto expected = NDArrayFactory::create( + 'c', {3, 3}, + {100.57794334f, 99.57794334f, 99.08139709f, 98.75170576f, 98.50514758f, + 98.30834069f, 98.1446337f, 98.00452955f, 97.88210202f}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::zeta op; + auto result = op.evaluate({&x, &q}, {}, {}); - auto *output = result.at(0); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + auto output = result.at(0); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, zeta_test7) { + auto x = NDArrayFactory::create('c', {3, 3}); + auto q = NDArrayFactory::create('c', {3, 3}); - auto x= NDArrayFactory::create('c', {3,3}); - auto q= NDArrayFactory::create('c', {3,3}); + q.linspace(1.); + x.assign(10.); - q.linspace(1.); - x.assign(10.); + auto expected = NDArrayFactory::create( + 'c', {3, 3}, + {1.00099458e+00f, 9.94575128e-04f, 1.80126278e-05f, 1.07754001e-06f, + 1.23865693e-07f, 2.14656932e-08f, 4.92752156e-09f, 1.38738839e-09f, + 4.56065812e-10f}); - auto expected= NDArrayFactory::create('c', {3,3}, {1.00099458e+00f, 9.94575128e-04f, 1.80126278e-05f, 1.07754001e-06f, 1.23865693e-07f, 2.14656932e-08f, 4.92752156e-09f, 1.38738839e-09f, 4.56065812e-10f}); + sd::ops::zeta op; + auto result = op.evaluate({&x, &q}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - sd::ops::zeta op; - auto result = op.evaluate({&x, &q}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto *output = result.at(0); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + auto output = result.at(0); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, zeta_test8) { + auto x = NDArrayFactory::create( + 'c', {3, 4}, + {1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 1.01, 1.11, 1.12}); + auto q = NDArrayFactory::create( + 'c', {3, 4}, + {0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.01, 0.11, 0.12}); - auto x= NDArrayFactory::create('c', {3,4}, {1.1,1.2,1.3,1.4,1.5,1.6,1.7,1.8,1.9,1.01,1.11,1.12}); - auto q= NDArrayFactory::create('c', {3,4}, {0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.01, 0.11, 0.12}); - - //q.linspace(1.); - //x.assign(10.); + // q.linspace(1.); + // x.assign(10.); - auto expected= NDArrayFactory::create('c', {3,4}, {23.014574, 12.184081, 8.275731, 6.1532226, 4.776538, 3.7945523, 3.0541048, 2.4765317, 2.0163891, 205.27448, 21.090889, 19.477398}); + auto expected = NDArrayFactory::create( + 'c', {3, 4}, + {23.014574, 12.184081, 8.275731, 6.1532226, 4.776538, 3.7945523, + 3.0541048, 2.4765317, 2.0163891, 205.27448, 21.090889, 19.477398}); - sd::ops::zeta op; - auto result = op.evaluate({&x, &q}, {}, {}); + sd::ops::zeta op; + auto result = op.evaluate({&x, &q}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto *output = result.at(0); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, zeta_test9) { + auto x = NDArrayFactory::create( + 'c', {3, 4}, + {1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 1.01, 1.11, 1.12}); + auto q = NDArrayFactory::create( + 'c', {3, 4}, + {0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.01, 0.11, 0.12}); + auto z = NDArrayFactory::create( + 'c', {3, 4}, {1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.}); - auto x= NDArrayFactory::create('c', {3,4}, {1.1,1.2,1.3,1.4,1.5,1.6,1.7,1.8,1.9,1.01,1.11,1.12}); - auto q= NDArrayFactory::create('c', {3,4}, {0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.01, 0.11, 0.12}); - auto z= NDArrayFactory::create('c', {3,4}, {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.}); + // q.linspace(1.); + // x.assign(10.); - //q.linspace(1.); - //x.assign(10.); + auto expected = NDArrayFactory::create( + 'c', {3, 4}, + {23.014574, 12.184081, 8.275731, 6.1532226, 4.776538, 3.7945523, + 3.0541048, 2.4765317, 2.0163891, 205.27448, 21.090889, 19.477398}); - auto expected= NDArrayFactory::create('c', {3,4}, {23.014574, 12.184081, 8.275731, 6.1532226, 4.776538, 3.7945523, 3.0541048, 2.4765317, 2.0163891, 205.27448, 21.090889, 19.477398}); + sd::ops::zeta op; + auto results = op.execute({&x, &q}, {&z}, {}, {}, {}); - sd::ops::zeta op; - auto results = op.execute({&x, &q}, {&z}, {}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, results); - ASSERT_EQ(ND4J_STATUS_OK, results); + // auto output = result.at(0); + // z.printIndexedBuffer("Zeta output"); + ASSERT_TRUE(expected.isSameShape(z)); + ASSERT_TRUE(expected.equalsTo(z)); - //auto *output = result.at(0); - // z.printIndexedBuffer("Zeta output"); - ASSERT_TRUE(expected.isSameShape(z)); - ASSERT_TRUE(expected.equalsTo(z)); - -// + // } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, zeta_test10) { + auto x = NDArrayFactory::create( + 'c', {3, 4}, + {1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 1.01, 1.11, 1.12}); + auto q = NDArrayFactory::create( + 'c', {3, 4}, + {0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.01, 0.11, 0.12}); + auto z = NDArrayFactory::create('c', {3, 4}); - auto x= NDArrayFactory::create('c', {3,4}, {1.1,1.2,1.3,1.4,1.5,1.6,1.7,1.8,1.9,1.01,1.11,1.12}); - auto q= NDArrayFactory::create('c', {3,4}, {0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.01, 0.11, 0.12}); - auto z= NDArrayFactory::create('c', {3,4}); + // q.linspace(1.); + // x.assign(10.); - //q.linspace(1.); - //x.assign(10.); + auto expected = NDArrayFactory::create( + 'c', {3, 4}, + {23.014574, 12.184081, 8.275731, 6.1532226, 4.776538, 3.7945523, + 3.0541048, 2.4765317, 2.0163891, 205.27448, 21.090889, 19.477398}); - auto expected= NDArrayFactory::create('c', {3,4}, {23.014574, 12.184081, 8.275731, 6.1532226, 4.776538, 3.7945523, 3.0541048, 2.4765317, 2.0163891, 205.27448, 21.090889, 19.477398}); + sd::ops::zeta op; + auto results = op.execute({&x, &q}, {&z}, {}, {}, {}); - sd::ops::zeta op; - auto results = op.execute({&x, &q}, {&z}, {}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, results); - ASSERT_EQ(ND4J_STATUS_OK, results); + // auto output = result.at(0); + // z.printIndexedBuffer("Zeta output"); + ASSERT_TRUE(expected.isSameShape(z)); + ASSERT_TRUE(expected.equalsTo(z)); - //auto *output = result.at(0); - // z.printIndexedBuffer("Zeta output"); - ASSERT_TRUE(expected.isSameShape(z)); - ASSERT_TRUE(expected.equalsTo(z)); - -// + // } - TEST_F(DeclarableOpsTests3, Test_SplitV_Validation_1) { - auto x = NDArrayFactory::create('c', {8, 7}); - auto indices = NDArrayFactory::create('c',{2}, {5, 3}); - auto axis = NDArrayFactory::create(-2); + auto x = NDArrayFactory::create('c', {8, 7}); + auto indices = NDArrayFactory::create('c', {2}, {5, 3}); + auto axis = NDArrayFactory::create(-2); - auto z0 = NDArrayFactory::create('c', {5, 7}); - auto z1 = NDArrayFactory::create('c', {3, 7}); + auto z0 = NDArrayFactory::create('c', {5, 7}); + auto z1 = NDArrayFactory::create('c', {3, 7}); - sd::ops::split_v op; - auto status = op.execute({&x, &indices, &axis}, std::vector{&z0, &z1}, {}, {}, {}); - ASSERT_EQ(Status::OK(), status); + sd::ops::split_v op; + auto status = op.execute({&x, &indices, &axis}, + std::vector{&z0, &z1}, {}, {}, {}); + ASSERT_EQ(Status::OK(), status); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, polygamma_test1) { + auto n = NDArrayFactory::create('c', {3, 3}); + auto x = NDArrayFactory::create('c', {3, 3}); + // ASSERT_FALSE(true); + n.linspace(1.); + x.assign(0.5); - auto n= NDArrayFactory::create('c', {3,3}); - auto x= NDArrayFactory::create('c', {3,3}); -// ASSERT_FALSE(true); - n.linspace(1.); - x.assign(0.5); - - auto expected= NDArrayFactory::create('c', {3,3}, {4.934802, -16.828796, 97.409088, -771.474243, 7691.113770, -92203.460938, 1290440.250000, -20644900.000000, 3.71595e+08}); + auto expected = NDArrayFactory::create( + 'c', {3, 3}, + {4.934802, -16.828796, 97.409088, -771.474243, 7691.113770, -92203.460938, + 1290440.250000, -20644900.000000, 3.71595e+08}); - sd::ops::polygamma op; - auto result = op.evaluate({&n, &x}, {}, {}); + sd::ops::polygamma op; + auto result = op.evaluate({&n, &x}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto output = result.at(0); - // output->printBuffer(); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + // output->printBuffer(); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, polygamma_test2) { + auto n = NDArrayFactory::create('c', {3, 3}); + auto x = NDArrayFactory::create('c', {3, 3}); - auto n= NDArrayFactory::create('c', {3,3}); - auto x= NDArrayFactory::create('c', {3,3}); - - n.linspace(10.); - x.linspace(0.5); + n.linspace(10.); + x.linspace(0.5); - auto expected= NDArrayFactory::create('c', {3,3}, {-7.43182451e+09, 3.08334759e+05,-3.25669798e+03, 1.55186197e+02,-1.46220433e+01, 2.00905201e+00,-3.48791235e-01, 7.08016273e-02,-1.60476052e-02}); + auto expected = NDArrayFactory::create( + 'c', {3, 3}, + {-7.43182451e+09, 3.08334759e+05, -3.25669798e+03, 1.55186197e+02, + -1.46220433e+01, 2.00905201e+00, -3.48791235e-01, 7.08016273e-02, + -1.60476052e-02}); - //ASSERT_FALSE(true); + // ASSERT_FALSE(true); - sd::ops::polygamma op; - auto result = op.evaluate({&n, &x}, {}, {}); + sd::ops::polygamma op; + auto result = op.evaluate({&n, &x}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + auto output = result.at(0); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, polygamma_test3) { + auto n = NDArrayFactory::create('c', {3, 3}); + auto x = NDArrayFactory::create('c', {3, 3}); - auto n= NDArrayFactory::create('c', {3,3}); - auto x= NDArrayFactory::create('c', {3,3}); - - n.linspace(1.); - x.linspace(10.); + n.linspace(1.); + x.linspace(10.); - auto expected= NDArrayFactory::create('c', {3,3}, {1.05166336e-01,-9.04983497e-03, 1.31009323e-03,-2.44459433e-04, 5.31593880e-05,-1.28049888e-05, 3.31755364e-06,-9.07408791e-07, 2.58758130e-07}); + auto expected = NDArrayFactory::create( + 'c', {3, 3}, + {1.05166336e-01, -9.04983497e-03, 1.31009323e-03, -2.44459433e-04, + 5.31593880e-05, -1.28049888e-05, 3.31755364e-06, -9.07408791e-07, + 2.58758130e-07}); - sd::ops::polygamma op; - auto result = op.evaluate({&n, &x}, {}, {}); + sd::ops::polygamma op; + auto result = op.evaluate({&n, &x}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + auto output = result.at(0); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } TEST_F(DeclarableOpsTests3, polygamma_test4) { + NDArray n('c', {3, 4}, {/*0.7788*/ 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + sd::DataType::DOUBLE); + NDArray x('c', {3, 4}, + {0.7717, 0.9281, 0.9846, 0.4838, 0.6433, 0.6041, 0.6501, 0.7612, + 0.7605, 0.3948, 0.9493, 0.8600}, + sd::DataType::DOUBLE); - NDArray n('c', {3,4}, {/*0.7788*/0, 0,1,2,3,4,5,6,7,8,9,10}, sd::DataType::DOUBLE); - NDArray x('c', {3,4}, {0.7717,0.9281,0.9846,0.4838,0.6433,0.6041,0.6501,0.7612,0.7605,0.3948,0.9493,0.8600}, sd::DataType::DOUBLE); - - NDArray expected('c', {3,4}, {/*std::numeric_limits::quiet_NaN()*/-1.031918, -7.021327e-01, 1.682743e+00, -1.851378e+01,3.604167e+01, -3.008293e+02, - 1.596005e+03, -4.876665e+03,4.510025e+04, -1.730340e+08, 6.110257e+05, -1.907087e+07}, sd::DataType::DOUBLE); + NDArray expected( + 'c', {3, 4}, + {/*std::numeric_limits::quiet_NaN()*/ -1.031918, -7.021327e-01, + 1.682743e+00, -1.851378e+01, 3.604167e+01, -3.008293e+02, 1.596005e+03, + -4.876665e+03, 4.510025e+04, -1.730340e+08, 6.110257e+05, -1.907087e+07}, + sd::DataType::DOUBLE); - sd::ops::polygamma op; - auto result = op.evaluate({&n, &x}, {}, {}); + sd::ops::polygamma op; + auto result = op.evaluate({&n, &x}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + auto output = result.at(0); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } TEST_F(DeclarableOpsTests3, digamma_1) { + NDArray x('c', {18}, + {-25, -24.99999, -21.5, -21.2, -5.5, -4.1, -2.1, -0.5, -0.3, 0., + 0.2, 1, 1.5, 2.2, 5.2, 19., 21, 22.2}, + sd::DataType::DOUBLE); - NDArray x('c', {18}, {-25, -24.99999, -21.5, -21.2, -5.5, -4.1, -2.1, -0.5, -0.3, 0., 0.2, 1, 1.5, 2.2, 5.2, 19., 21, 22.2}, sd::DataType::DOUBLE); - - NDArray expected('c', {18}, {std::numeric_limits::infinity(), -99996.761229, 3.091129, 7.401432, 1.792911,11.196838,10.630354, 0.03649, 2.11331, - std::numeric_limits::infinity(),-5.28904,-0.577216, 0.03649, 0.544293, 1.549434,2.917892, 3.020524, 3.077401}, sd::DataType::DOUBLE); + NDArray expected( + 'c', {18}, + {std::numeric_limits::infinity(), -99996.761229, 3.091129, + 7.401432, 1.792911, 11.196838, 10.630354, 0.03649, 2.11331, + std::numeric_limits::infinity(), -5.28904, -0.577216, 0.03649, + 0.544293, 1.549434, 2.917892, 3.020524, 3.077401}, + sd::DataType::DOUBLE); - sd::ops::digamma op; - auto result = op.evaluate({&x}, {}, {}); + sd::ops::digamma op; + auto result = op.evaluate({&x}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + auto output = result.at(0); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, svd_test1) { - - auto x= NDArrayFactory::create('c', {6,6}, {0. ,-9. ,-6 ,9 ,-10 ,-12 ,2 ,13 ,5 ,-11 ,20 ,-17 ,1 ,-2 ,-11 ,3 ,-8 ,3 ,-14 ,19 ,-20 ,20 ,-17 ,-5 ,6 ,-16 ,0 ,-1 ,-16 ,11 ,7 ,-19 ,2 ,-17 ,17 ,-16}); - auto expS= NDArrayFactory::create('c', {6}, {54.12775, 38.79293, 25.89287, 9.82168, 6.07227, 2.91827}); - auto expU= NDArrayFactory::create('c', {6,6}, {0.14692,-0.11132,-0.69568, 0.59282,-0.14881, 0.32935,-0.38751, 0.60378,-0.04927,-0.01397,-0.69456,-0.01581, 0.19293,-0.12795,-0.18682,-0.69065,-0.20597, 0.62617, 0.66806, 0.4314 ,-0.33849,-0.22166, 0.04099,-0.44967, 0.11121,-0.64065,-0.02138,-0.07378,-0.60568,-0.45216,-0.5765 ,-0.1007 ,-0.60305,-0.34175, 0.29068,-0.3042}); - auto expV= NDArrayFactory::create('c', {6,6}, {-0.24577,-0.24512, 0.00401,-0.04585,-0.62058, 0.70162, 0.27937, 0.75961, 0.43885,-0.06857,-0.3839 , 0.01669,-0.35944,-0.09629, 0.44593, 0.78602,-0.09103,-0.19125, 0.53973, 0.07613,-0.10721, 0.49559, 0.35687, 0.56431,-0.6226 , 0.39742, 0.12785,-0.15716, 0.52372, 0.37297, 0.23113,-0.43578, 0.76204,-0.32414, 0.23996, 0.11543}); - - sd::ops::svd op; - auto result = op.evaluate({&x}, {}, {1, 1, 16}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto *s = result.at(0); - auto *u = result.at(1); - auto *v = result.at(2); - - ASSERT_TRUE(expS.isSameShape(s)); - ASSERT_TRUE(expU.isSameShape(u)); - ASSERT_TRUE(expV.isSameShape(v)); - - ASSERT_TRUE(expS.equalsTo(s)); - - if(sd::Environment::getInstance().isCPU()) { - ASSERT_TRUE(expU.equalsTo(u)); - ASSERT_TRUE(expV.equalsTo(v)); - } - else { - for(uint i = 0; i < expU.lengthOf(); ++i) - ASSERT_NEAR(sd::math::nd4j_abs(expU.e(i)), sd::math::nd4j_abs(u->e(i)), 1e-5); - for(uint i = 0; i < expV.lengthOf(); ++i) - ASSERT_NEAR(sd::math::nd4j_abs(expV.e(i)), sd::math::nd4j_abs(v->e(i)), 1e-5); - } - + auto x = NDArrayFactory::create( + 'c', {6, 6}, {0., -9., -6, 9, -10, -12, 2, 13, 5, -11, 20, -17, + 1, -2, -11, 3, -8, 3, -14, 19, -20, 20, -17, -5, + 6, -16, 0, -1, -16, 11, 7, -19, 2, -17, 17, -16}); + auto expS = NDArrayFactory::create( + 'c', {6}, {54.12775, 38.79293, 25.89287, 9.82168, 6.07227, 2.91827}); + auto expU = NDArrayFactory::create( + 'c', {6, 6}, {0.14692, -0.11132, -0.69568, 0.59282, -0.14881, 0.32935, + -0.38751, 0.60378, -0.04927, -0.01397, -0.69456, -0.01581, + 0.19293, -0.12795, -0.18682, -0.69065, -0.20597, 0.62617, + 0.66806, 0.4314, -0.33849, -0.22166, 0.04099, -0.44967, + 0.11121, -0.64065, -0.02138, -0.07378, -0.60568, -0.45216, + -0.5765, -0.1007, -0.60305, -0.34175, 0.29068, -0.3042}); + auto expV = NDArrayFactory::create( + 'c', {6, 6}, {-0.24577, -0.24512, 0.00401, -0.04585, -0.62058, 0.70162, + 0.27937, 0.75961, 0.43885, -0.06857, -0.3839, 0.01669, + -0.35944, -0.09629, 0.44593, 0.78602, -0.09103, -0.19125, + 0.53973, 0.07613, -0.10721, 0.49559, 0.35687, 0.56431, + -0.6226, 0.39742, 0.12785, -0.15716, 0.52372, 0.37297, + 0.23113, -0.43578, 0.76204, -0.32414, 0.23996, 0.11543}); + + sd::ops::svd op; + auto result = op.evaluate({&x}, {}, {1, 1, 16}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto s = result.at(0); + auto u = result.at(1); + auto v = result.at(2); + + ASSERT_TRUE(expS.isSameShape(s)); + ASSERT_TRUE(expU.isSameShape(u)); + ASSERT_TRUE(expV.isSameShape(v)); + + ASSERT_TRUE(expS.equalsTo(s)); + + if (sd::Environment::getInstance().isCPU()) { + ASSERT_TRUE(expU.equalsTo(u)); + ASSERT_TRUE(expV.equalsTo(v)); + } else { + for (uint i = 0; i < expU.lengthOf(); ++i) + ASSERT_NEAR(sd::math::nd4j_abs(expU.e(i)), + sd::math::nd4j_abs(u.e(i)), 1e-5); + for (uint i = 0; i < expV.lengthOf(); ++i) + ASSERT_NEAR(sd::math::nd4j_abs(expV.e(i)), + sd::math::nd4j_abs(v.e(i)), 1e-5); + } } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, svd_test2) { - - auto x = NDArrayFactory::create('c', {7,6}, {0. ,-9. ,-6 ,9 ,-10 ,-12 ,2 ,13 ,5 ,-11 ,20 ,-17 ,1 ,-2 ,-11 ,3 ,-8 ,3 ,-14 ,19 ,-20 ,20 ,-17 ,-5 ,6 ,-16 ,0 ,-1 ,-16 ,11 ,7 ,-19 ,2 ,-17 ,17 ,-16, 4, -9, 1, -15, 7, -2}); - auto expS= NDArrayFactory::create('c', {6}, {56.76573, 39.11776, 26.00713, 11.83606, 6.16578, 3.99672}); - auto expU= NDArrayFactory::create('c', {7,7}, {-0.13417,-0.12443, -0.68854, 0.5196 , 0.21706, 0.03974, 0.41683, 0.347 , 0.62666, -0.04964, -0.01912, 0.66932, 0.1457 , -0.12183,-0.17329,-0.14666, -0.19639, -0.55355, 0.0614 , 0.75729, 0.1619 ,-0.64703, 0.37056, -0.37398, -0.32922, -0.0186 , -0.35656, -0.26134,-0.08027,-0.64405, -0.0127 , -0.06934, 0.59287, -0.14956, -0.44712, 0.55906,-0.06235, -0.58017, -0.12911, -0.359 , -0.00393, -0.44877, 0.30645,-0.11953, -0.09083, -0.54163, 0.14283, -0.50417, 0.56178}); - auto expV= NDArrayFactory::create('c', {6,6}, {0.2508 ,-0.2265 , 0.01689, 0.04486, 0.53132, 0.77537,-0.32281, 0.74559, 0.41845, -0.13821, 0.37642, 0.06315, 0.33139,-0.05528, 0.47186, 0.73171, 0.18905, -0.3055 ,-0.57263, 0.06276,-0.09542, 0.59396, -0.36152, 0.419 , 0.59193, 0.4361 , 0.13557, -0.03632, -0.5755 , 0.32944,-0.21165,-0.44227, 0.75794, -0.29895, -0.27993, 0.13187}); - - sd::ops::svd op; - auto result = op.evaluate({&x}, {}, {1, 1, 16}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto *s = result.at(0); - auto *u = result.at(1); - auto *v = result.at(2); - - ASSERT_TRUE(expS.isSameShape(s)); - ASSERT_TRUE(expU.isSameShape(u)); - ASSERT_TRUE(expV.isSameShape(v)); - - ASSERT_TRUE(expS.equalsTo(s)); - - if(sd::Environment::getInstance().isCPU()) { - ASSERT_TRUE(expU.equalsTo(u)); - ASSERT_TRUE(expV.equalsTo(v)); - } - else { - for(uint i = 0; i < expU.lengthOf(); ++i) - ASSERT_NEAR(sd::math::nd4j_abs(expU.e(i)), sd::math::nd4j_abs(u->e(i)), 1e-5); - for(uint i = 0; i < expV.lengthOf(); ++i) - ASSERT_NEAR(sd::math::nd4j_abs(expV.e(i)), sd::math::nd4j_abs(v->e(i)), 1e-5); - } - + auto x = NDArrayFactory::create( + 'c', {7, 6}, + {0., -9., -6, 9, -10, -12, 2, 13, 5, -11, 20, -17, 1, -2, + -11, 3, -8, 3, -14, 19, -20, 20, -17, -5, 6, -16, 0, -1, + -16, 11, 7, -19, 2, -17, 17, -16, 4, -9, 1, -15, 7, -2}); + auto expS = NDArrayFactory::create( + 'c', {6}, {56.76573, 39.11776, 26.00713, 11.83606, 6.16578, 3.99672}); + auto expU = NDArrayFactory::create( + 'c', {7, 7}, + {-0.13417, -0.12443, -0.68854, 0.5196, 0.21706, 0.03974, 0.41683, + 0.347, 0.62666, -0.04964, -0.01912, 0.66932, 0.1457, -0.12183, + -0.17329, -0.14666, -0.19639, -0.55355, 0.0614, 0.75729, 0.1619, + -0.64703, 0.37056, -0.37398, -0.32922, -0.0186, -0.35656, -0.26134, + -0.08027, -0.64405, -0.0127, -0.06934, 0.59287, -0.14956, -0.44712, + 0.55906, -0.06235, -0.58017, -0.12911, -0.359, -0.00393, -0.44877, + 0.30645, -0.11953, -0.09083, -0.54163, 0.14283, -0.50417, 0.56178}); + auto expV = NDArrayFactory::create( + 'c', {6, 6}, {0.2508, -0.2265, 0.01689, 0.04486, 0.53132, 0.77537, + -0.32281, 0.74559, 0.41845, -0.13821, 0.37642, 0.06315, + 0.33139, -0.05528, 0.47186, 0.73171, 0.18905, -0.3055, + -0.57263, 0.06276, -0.09542, 0.59396, -0.36152, 0.419, + 0.59193, 0.4361, 0.13557, -0.03632, -0.5755, 0.32944, + -0.21165, -0.44227, 0.75794, -0.29895, -0.27993, 0.13187}); + + sd::ops::svd op; + auto result = op.evaluate({&x}, {}, {1, 1, 16}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto s = result.at(0); + auto u = result.at(1); + auto v = result.at(2); + + ASSERT_TRUE(expS.isSameShape(s)); + ASSERT_TRUE(expU.isSameShape(u)); + ASSERT_TRUE(expV.isSameShape(v)); + + ASSERT_TRUE(expS.equalsTo(s)); + + if (sd::Environment::getInstance().isCPU()) { + ASSERT_TRUE(expU.equalsTo(u)); + ASSERT_TRUE(expV.equalsTo(v)); + } else { + for (uint i = 0; i < expU.lengthOf(); ++i) + ASSERT_NEAR(sd::math::nd4j_abs(expU.e(i)), + sd::math::nd4j_abs(u.e(i)), 1e-5); + for (uint i = 0; i < expV.lengthOf(); ++i) + ASSERT_NEAR(sd::math::nd4j_abs(expV.e(i)), + sd::math::nd4j_abs(v.e(i)), 1e-5); + } } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, svd_test3) { - - auto x= NDArrayFactory::create('c', {7,6}, {0. ,-9. ,-6 ,9 ,-10 ,-12 ,2 ,13 ,5 ,-11 ,20 ,-17 ,1 ,-2 ,-11 ,3 ,-8 ,3 ,-14 ,19 ,-20 ,20 ,-17 ,-5 ,6 ,-16 ,0 ,-1 ,-16 ,11 ,7 ,-19 ,2 ,-17 ,17 ,-16, 4, -9, 1, -15, 7, -2}); - auto expS= NDArrayFactory::create('c', {6}, {56.76573, 39.11776, 26.00713, 11.83606, 6.16578, 3.99672}); - auto expU= NDArrayFactory::create('c', {7,6}, {-0.13417, -0.12443, -0.68854, 0.5196 , 0.21706, 0.03974, 0.347 , 0.62666, -0.04964, -0.01912, 0.66932, 0.1457 ,-0.17329, -0.14666, -0.19639, -0.55355, 0.0614 , 0.75729,-0.64703, 0.37056, -0.37398, -0.32922, -0.0186 , -0.35656,-0.08027, -0.64405, -0.0127 , -0.06934, 0.59287, -0.14956, 0.55906, -0.06235, -0.58017, -0.12911, -0.359 , -0.00393, 0.30645, -0.11953, -0.09083, -0.54163, 0.14283, -0.50417}); - auto expV= NDArrayFactory::create('c', {6,6}, {0.2508 ,-0.2265 , 0.01689, 0.04486, 0.53132, 0.77537,-0.32281, 0.74559, 0.41845, -0.13821, 0.37642, 0.06315, 0.33139,-0.05528, 0.47186, 0.73171, 0.18905, -0.3055 ,-0.57263, 0.06276,-0.09542, 0.59396, -0.36152, 0.419 , 0.59193, 0.4361 , 0.13557, -0.03632, -0.5755 , 0.32944,-0.21165,-0.44227, 0.75794, -0.29895, -0.27993, 0.13187}); - - sd::ops::svd op; - auto result = op.evaluate({&x}, {}, {0, 1, 16}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto *s = result.at(0); - auto *u = result.at(1); - auto *v = result.at(2); - - ASSERT_TRUE(expS.isSameShape(s)); - ASSERT_TRUE(expU.isSameShape(u)); - ASSERT_TRUE(expV.isSameShape(v)); - - ASSERT_TRUE(expS.equalsTo(s)); - - if(sd::Environment::getInstance().isCPU()) { - ASSERT_TRUE(expU.equalsTo(u)); - ASSERT_TRUE(expV.equalsTo(v)); - } - else { - for(uint i = 0; i < expU.lengthOf(); ++i) - ASSERT_NEAR(sd::math::nd4j_abs(expU.e(i)), sd::math::nd4j_abs(u->e(i)), 1e-5f); - for(uint i = 0; i < expV.lengthOf(); ++i) - ASSERT_NEAR(sd::math::nd4j_abs(expV.e(i)), sd::math::nd4j_abs(v->e(i)), 1e-5f); - } - + auto x = NDArrayFactory::create( + 'c', {7, 6}, + {0., -9., -6, 9, -10, -12, 2, 13, 5, -11, 20, -17, 1, -2, + -11, 3, -8, 3, -14, 19, -20, 20, -17, -5, 6, -16, 0, -1, + -16, 11, 7, -19, 2, -17, 17, -16, 4, -9, 1, -15, 7, -2}); + auto expS = NDArrayFactory::create( + 'c', {6}, {56.76573, 39.11776, 26.00713, 11.83606, 6.16578, 3.99672}); + auto expU = NDArrayFactory::create( + 'c', {7, 6}, + {-0.13417, -0.12443, -0.68854, 0.5196, 0.21706, 0.03974, 0.347, + 0.62666, -0.04964, -0.01912, 0.66932, 0.1457, -0.17329, -0.14666, + -0.19639, -0.55355, 0.0614, 0.75729, -0.64703, 0.37056, -0.37398, + -0.32922, -0.0186, -0.35656, -0.08027, -0.64405, -0.0127, -0.06934, + 0.59287, -0.14956, 0.55906, -0.06235, -0.58017, -0.12911, -0.359, + -0.00393, 0.30645, -0.11953, -0.09083, -0.54163, 0.14283, -0.50417}); + auto expV = NDArrayFactory::create( + 'c', {6, 6}, {0.2508, -0.2265, 0.01689, 0.04486, 0.53132, 0.77537, + -0.32281, 0.74559, 0.41845, -0.13821, 0.37642, 0.06315, + 0.33139, -0.05528, 0.47186, 0.73171, 0.18905, -0.3055, + -0.57263, 0.06276, -0.09542, 0.59396, -0.36152, 0.419, + 0.59193, 0.4361, 0.13557, -0.03632, -0.5755, 0.32944, + -0.21165, -0.44227, 0.75794, -0.29895, -0.27993, 0.13187}); + + sd::ops::svd op; + auto result = op.evaluate({&x}, {}, {0, 1, 16}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto s = result.at(0); + auto u = result.at(1); + auto v = result.at(2); + + ASSERT_TRUE(expS.isSameShape(s)); + ASSERT_TRUE(expU.isSameShape(u)); + ASSERT_TRUE(expV.isSameShape(v)); + + ASSERT_TRUE(expS.equalsTo(s)); + + if (sd::Environment::getInstance().isCPU()) { + ASSERT_TRUE(expU.equalsTo(u)); + ASSERT_TRUE(expV.equalsTo(v)); + } else { + for (uint i = 0; i < expU.lengthOf(); ++i) + ASSERT_NEAR(sd::math::nd4j_abs(expU.e(i)), + sd::math::nd4j_abs(u.e(i)), 1e-5f); + for (uint i = 0; i < expV.lengthOf(); ++i) + ASSERT_NEAR(sd::math::nd4j_abs(expV.e(i)), + sd::math::nd4j_abs(v.e(i)), 1e-5f); + } } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, svd_test4) { - - auto x= NDArrayFactory::create('c', {6,7}, {0. ,-9. ,-6 ,9 ,-10 ,-12 ,2 ,13 ,5 ,-11 ,20 ,-17 ,1 ,-2 ,-11 ,3 ,-8 ,3 ,-14 ,19 ,-20 ,20 ,-17 ,-5 ,6 ,-16 ,0 ,-1 ,-16 ,11 ,7 ,-19 ,2 ,-17 ,17 ,-16, 4, -9, 1, -15, 7, -2}); - auto expS= NDArrayFactory::create('c', {6}, {53.11053, 39.09542, 28.1987, 17.7468, 11.61684, 5.36217}); - auto expU= NDArrayFactory::create('c', {6,6}, {-0.16541, 0.21276, 0.51284, 0.20472, 0.74797, 0.25102,-0.49879, 0.12076, 0.37629, -0.7211 , -0.24585, 0.12086,-0.36569,-0.70218, -0.08012, 0.21274, -0.07314, 0.56231,-0.44508, 0.4329 , 0.1356 , 0.60909, -0.47398, -0.02164, 0.61238,-0.05674, 0.59489, 0.06588, -0.3874 , 0.33685,-0.13044,-0.50644, 0.46552, 0.13236, -0.00474, -0.70161}); - auto expV= NDArrayFactory::create('c', {7,7}, {-0.35914, 0.68966, -0.30077, -0.15238, -0.48179, 0.14716, -0.16709, 0.21989, -0.34343, 0.11086, -0.78381, -0.37902, 0.24224, -0.06862, 0.32179, 0.12812, -0.25812, 0.0691 , -0.12891, 0.26979, 0.84807,-0.50833, 0.13793, 0.06658, -0.53001, 0.52572, -0.16194, 0.36692, 0.48118, 0.15876, -0.65132, -0.24602, 0.3963 , -0.16651, -0.27155,-0.31605, -0.46947, -0.50195, 0.0378 , -0.34937, -0.53062, 0.15069, 0.35957, 0.35408, 0.38732, -0.12154, -0.22827, -0.7151 , 0.13065}); - - sd::ops::svd op; - auto result = op.evaluate({&x}, {}, {1, 1, 16}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto *s = result.at(0); - auto *u = result.at(1); - auto *v = result.at(2); - - ASSERT_TRUE(expS.isSameShape(s)); - ASSERT_TRUE(expU.isSameShape(u)); - ASSERT_TRUE(expV.isSameShape(v)); - - ASSERT_TRUE(expS.equalsTo(s)); - - if(sd::Environment::getInstance().isCPU()) { - ASSERT_TRUE(expU.equalsTo(u)); - ASSERT_TRUE(expV.equalsTo(v)); - } - else { - for(uint i = 0; i < expU.lengthOf(); ++i) - ASSERT_NEAR(sd::math::nd4j_abs(expU.e(i)), sd::math::nd4j_abs(u->e(i)), 1e-5f); - for(uint i = 0; i < expV.lengthOf(); ++i) - ASSERT_NEAR(sd::math::nd4j_abs(expV.e(i)), sd::math::nd4j_abs(v->e(i)), 1e-5f); - } - + auto x = NDArrayFactory::create( + 'c', {6, 7}, + {0., -9., -6, 9, -10, -12, 2, 13, 5, -11, 20, -17, 1, -2, + -11, 3, -8, 3, -14, 19, -20, 20, -17, -5, 6, -16, 0, -1, + -16, 11, 7, -19, 2, -17, 17, -16, 4, -9, 1, -15, 7, -2}); + auto expS = NDArrayFactory::create( + 'c', {6}, {53.11053, 39.09542, 28.1987, 17.7468, 11.61684, 5.36217}); + auto expU = NDArrayFactory::create( + 'c', {6, 6}, {-0.16541, 0.21276, 0.51284, 0.20472, 0.74797, 0.25102, + -0.49879, 0.12076, 0.37629, -0.7211, -0.24585, 0.12086, + -0.36569, -0.70218, -0.08012, 0.21274, -0.07314, 0.56231, + -0.44508, 0.4329, 0.1356, 0.60909, -0.47398, -0.02164, + 0.61238, -0.05674, 0.59489, 0.06588, -0.3874, 0.33685, + -0.13044, -0.50644, 0.46552, 0.13236, -0.00474, -0.70161}); + auto expV = NDArrayFactory::create( + 'c', {7, 7}, + {-0.35914, 0.68966, -0.30077, -0.15238, -0.48179, 0.14716, -0.16709, + 0.21989, -0.34343, 0.11086, -0.78381, -0.37902, 0.24224, -0.06862, + 0.32179, 0.12812, -0.25812, 0.0691, -0.12891, 0.26979, 0.84807, + -0.50833, 0.13793, 0.06658, -0.53001, 0.52572, -0.16194, 0.36692, + 0.48118, 0.15876, -0.65132, -0.24602, 0.3963, -0.16651, -0.27155, + -0.31605, -0.46947, -0.50195, 0.0378, -0.34937, -0.53062, 0.15069, + 0.35957, 0.35408, 0.38732, -0.12154, -0.22827, -0.7151, 0.13065}); + + sd::ops::svd op; + auto result = op.evaluate({&x}, {}, {1, 1, 16}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto s = result.at(0); + auto u = result.at(1); + auto v = result.at(2); + + ASSERT_TRUE(expS.isSameShape(s)); + ASSERT_TRUE(expU.isSameShape(u)); + ASSERT_TRUE(expV.isSameShape(v)); + + ASSERT_TRUE(expS.equalsTo(s)); + + if (sd::Environment::getInstance().isCPU()) { + ASSERT_TRUE(expU.equalsTo(u)); + ASSERT_TRUE(expV.equalsTo(v)); + } else { + for (uint i = 0; i < expU.lengthOf(); ++i) + ASSERT_NEAR(sd::math::nd4j_abs(expU.e(i)), + sd::math::nd4j_abs(u.e(i)), 1e-5f); + for (uint i = 0; i < expV.lengthOf(); ++i) + ASSERT_NEAR(sd::math::nd4j_abs(expV.e(i)), + sd::math::nd4j_abs(v.e(i)), 1e-5f); + } } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, svd_test5) { - - auto x= NDArrayFactory::create('c', {6,7}, {0. ,-9. ,-6 ,9 ,-10 ,-12 ,2 ,13 ,5 ,-11 ,20 ,-17 ,1 ,-2 ,-11 ,3 ,-8 ,3 ,-14 ,19 ,-20 ,20 ,-17 ,-5 ,6 ,-16 ,0 ,-1 ,-16 ,11 ,7 ,-19 ,2 ,-17 ,17 ,-16, 4, -9, 1, -15, 7, -2}); - auto expS= NDArrayFactory::create('c', {6}, {53.11053, 39.09542, 28.1987, 17.7468, 11.61684, 5.36217}); - auto expU= NDArrayFactory::create('c', {6,6}, {-0.16541, 0.21276, 0.51284, 0.20472, 0.74797, 0.25102,-0.49879, 0.12076, 0.37629, -0.7211 , -0.24585, 0.12086,-0.36569,-0.70218, -0.08012, 0.21274, -0.07314, 0.56231,-0.44508, 0.4329 , 0.1356 , 0.60909, -0.47398, -0.02164, 0.61238,-0.05674, 0.59489, 0.06588, -0.3874 , 0.33685,-0.13044,-0.50644, 0.46552, 0.13236, -0.00474, -0.70161}); - auto expV= NDArrayFactory::create('c', {7,6}, {-0.35914, 0.68966, -0.30077, -0.15238, -0.48179, 0.14716, 0.21989, -0.34343, 0.11086, -0.78381, -0.37902, 0.24224, 0.32179, 0.12812, -0.25812, 0.0691 , -0.12891, 0.26979,-0.50833, 0.13793, 0.06658, -0.53001, 0.52572, -0.16194, 0.48118, 0.15876, -0.65132, -0.24602, 0.3963 , -0.16651,-0.31605, -0.46947, -0.50195, 0.0378 , -0.34937, -0.53062, 0.35957, 0.35408, 0.38732, -0.12154, -0.22827, -0.7151}); - - sd::ops::svd op; - auto result = op.evaluate({&x}, {}, {0, 1, 16}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto *s = result.at(0); - auto *u = result.at(1); - auto *v = result.at(2); - - ASSERT_TRUE(expS.isSameShape(s)); - ASSERT_TRUE(expU.isSameShape(u)); - ASSERT_TRUE(expV.isSameShape(v)); - - ASSERT_TRUE(expS.equalsTo(s)); - - if(sd::Environment::getInstance().isCPU()) { - ASSERT_TRUE(expU.equalsTo(u)); - ASSERT_TRUE(expV.equalsTo(v)); - } - else { - for(uint i = 0; i < expU.lengthOf(); ++i) - ASSERT_NEAR(sd::math::nd4j_abs(expU.e(i)), sd::math::nd4j_abs(u->e(i)), 1e-5f); - for(uint i = 0; i < expV.lengthOf(); ++i) - ASSERT_NEAR(sd::math::nd4j_abs(expV.e(i)), sd::math::nd4j_abs(v->e(i)), 1e-5f); - } - + auto x = NDArrayFactory::create( + 'c', {6, 7}, + {0., -9., -6, 9, -10, -12, 2, 13, 5, -11, 20, -17, 1, -2, + -11, 3, -8, 3, -14, 19, -20, 20, -17, -5, 6, -16, 0, -1, + -16, 11, 7, -19, 2, -17, 17, -16, 4, -9, 1, -15, 7, -2}); + auto expS = NDArrayFactory::create( + 'c', {6}, {53.11053, 39.09542, 28.1987, 17.7468, 11.61684, 5.36217}); + auto expU = NDArrayFactory::create( + 'c', {6, 6}, {-0.16541, 0.21276, 0.51284, 0.20472, 0.74797, 0.25102, + -0.49879, 0.12076, 0.37629, -0.7211, -0.24585, 0.12086, + -0.36569, -0.70218, -0.08012, 0.21274, -0.07314, 0.56231, + -0.44508, 0.4329, 0.1356, 0.60909, -0.47398, -0.02164, + 0.61238, -0.05674, 0.59489, 0.06588, -0.3874, 0.33685, + -0.13044, -0.50644, 0.46552, 0.13236, -0.00474, -0.70161}); + auto expV = NDArrayFactory::create( + 'c', {7, 6}, + {-0.35914, 0.68966, -0.30077, -0.15238, -0.48179, 0.14716, 0.21989, + -0.34343, 0.11086, -0.78381, -0.37902, 0.24224, 0.32179, 0.12812, + -0.25812, 0.0691, -0.12891, 0.26979, -0.50833, 0.13793, 0.06658, + -0.53001, 0.52572, -0.16194, 0.48118, 0.15876, -0.65132, -0.24602, + 0.3963, -0.16651, -0.31605, -0.46947, -0.50195, 0.0378, -0.34937, + -0.53062, 0.35957, 0.35408, 0.38732, -0.12154, -0.22827, -0.7151}); + + sd::ops::svd op; + auto result = op.evaluate({&x}, {}, {0, 1, 16}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto s = result.at(0); + auto u = result.at(1); + auto v = result.at(2); + + ASSERT_TRUE(expS.isSameShape(s)); + ASSERT_TRUE(expU.isSameShape(u)); + ASSERT_TRUE(expV.isSameShape(v)); + + ASSERT_TRUE(expS.equalsTo(s)); + + if (sd::Environment::getInstance().isCPU()) { + ASSERT_TRUE(expU.equalsTo(u)); + ASSERT_TRUE(expV.equalsTo(v)); + } else { + for (uint i = 0; i < expU.lengthOf(); ++i) + ASSERT_NEAR(sd::math::nd4j_abs(expU.e(i)), + sd::math::nd4j_abs(u.e(i)), 1e-5f); + for (uint i = 0; i < expV.lengthOf(); ++i) + ASSERT_NEAR(sd::math::nd4j_abs(expV.e(i)), + sd::math::nd4j_abs(v.e(i)), 1e-5f); + } } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, svd_test6) { - - auto x= NDArrayFactory::create('c', {2,2,5,5}, {-7. ,17 ,4 ,-10 ,5 ,1 ,-5 ,-19 ,13 ,-8 ,9 ,13 ,19 ,13 ,-2 - ,-8 ,10 ,-9 ,0 ,-20 ,-2 ,14 ,19 ,5 ,-18 ,4 ,-13 ,12 ,-10 ,5 ,-10 ,-10 ,17 ,-5 ,-2 ,10 ,5 ,-4 ,-11 ,15 ,-3 ,15 ,-17 - ,-20 ,-10 ,-4 ,12 ,-9 ,16 ,13 ,10 ,-19 ,2 ,-9 ,-10 ,8 ,-2 ,-4 ,3 ,7 ,10 ,-19 ,-11 ,-4 ,-6 ,2 ,-12 ,6 ,-4 ,-14 ,14 - ,16 ,7 ,19 ,-17 ,2 ,-14 ,5 ,-1 ,16 ,19 ,-11 ,-14 ,-16 ,-19 ,15 ,-18 ,-12 ,-16 ,16 ,1 ,5 ,7 ,8 ,2 ,13 ,-3 ,6 ,2 ,-5}); - auto expS= NDArrayFactory::create('c', {2,2,5}, {40.95395, 31.46869, 24.79993, 12.33768, 1.80031, - 38.18412, 31.52287, 23.52755, 11.79484, 1.90195, - 39.34498, 32.54861, 17.52492, 7.03003, 2.2399, - 44.72126, 32.3164 , 16.60139, 6.88783, 0.78122}); - auto expU= NDArrayFactory::create('c', {2,2,5,5}, {0.25441, 0.16908, -0.68564, 0.58844, -0.30054, - -0.32285, -0.58332, 0.3451 , 0.4746 , -0.45953,0.58332, 0.10605, 0.51533, 0.50234, 0.36136,0.12588, -0.73123, -0.37812, -0.00215, 0.55361, - 0.68915, -0.2919 , 0.04767, -0.4197 , -0.51132,0.44464, -0.25326, -0.42493, -0.01712, -0.74653,0.516 , -0.16688, 0.1854 , -0.77155, 0.27611, - -0.19321, -0.14317, -0.85886, -0.15224, 0.42585,-0.60155, -0.68323, 0.18819, -0.29053, -0.22696,-0.36993, 0.64862, -0.10956, -0.54483, -0.36552, - -0.57697, -0.32277, 0.11229, 0.55495, 0.4923 ,-0.02937, 0.01689, -0.63257, 0.57075, -0.52245,-0.56002, -0.2036 , -0.53119, -0.6022 , 0.01017, - -0.33605, -0.35257, 0.53215, -0.04936, -0.69075,0.48958, -0.85427, -0.14796, -0.03449, 0.08633,0.15008, 0.60996, 0.31071, -0.67721, 0.22421, - 0.67717, -0.59857, 0.04372, -0.2565 , 0.33979,0.68116, 0.49852, -0.13441, 0.51374, -0.07421,-0.20066, 0.04504, 0.42865, 0.44418, 0.75939,0.12113, -0.13826, 0.83651, 0.11988, -0.50209}); - auto expV= NDArrayFactory::create('c', {2,2,5,5}, {0.01858, 0.17863, 0.51259, 0.14048, 0.82781, - 0.59651, -0.13439, -0.395 , 0.66979, 0.14654,0.73731, 0.47061, 0.19357, -0.41127, -0.16817,0.1047 , -0.29727, 0.73711, 0.38235, -0.45951, - -0.29873, 0.80012, -0.02078, 0.4651 , -0.23201,-0.05314, -0.0419 , -0.52146, 0.77792, 0.344 ,-0.66438, 0.05648, 0.03756, -0.31531, 0.67422, - 0.74471, 0.01504, -0.03081, -0.24335, 0.62049,0.03172, 0.91947, 0.30828, 0.23713, 0.04796,-0.01311, 0.38652, -0.79415, -0.42423, -0.19945, - -0.13783, -0.54667, -0.58527, 0.49955, 0.3001 ,0.85214, 0.01628, 0.02688, -0.02891, 0.52157,0.16608, -0.20181, 0.61371, 0.69894, -0.25794, - 0.45726, -0.33952, -0.32659, -0.18938, -0.73015,0.13486, 0.73816, -0.41646, 0.47458, -0.1956 ,0.5536 , -0.137 , 0.64688, 0.50536, 0.03017, - -0.51827, -0.31837, -0.16732, 0.71378, -0.30425,-0.39314, 0.15266, 0.63693, -0.30945, -0.5663 ,-0.51981, 0.03325, 0.37603, 0.05147, 0.76462,-0.01282, 0.92491, -0.08042, 0.36977, -0.03428}); - - sd::ops::svd op; - auto result = op.evaluate({&x}, {}, {1, 1, 16}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto *s = result.at(0); - auto *u = result.at(1); - auto *v = result.at(2); - - ASSERT_TRUE(expS.isSameShape(s)); - ASSERT_TRUE(expU.isSameShape(u)); - ASSERT_TRUE(expV.isSameShape(v)); - - ASSERT_TRUE(expS.equalsTo(s)); - - if(sd::Environment::getInstance().isCPU()) { - ASSERT_TRUE(expU.equalsTo(u)); - ASSERT_TRUE(expV.equalsTo(v)); - } - else { - for(uint i = 0; i < expU.lengthOf(); ++i) - ASSERT_NEAR(sd::math::nd4j_abs(expU.e(i)), sd::math::nd4j_abs(u->e(i)), 1e-5f); - for(uint i = 0; i < expV.lengthOf(); ++i) - ASSERT_NEAR(sd::math::nd4j_abs(expV.e(i)), sd::math::nd4j_abs(v->e(i)), 1e-5f); - } - + auto x = NDArrayFactory::create( + 'c', {2, 2, 5, 5}, + {-7., 17, 4, -10, 5, 1, -5, -19, 13, -8, 9, 13, 19, 13, -2, + -8, 10, -9, 0, -20, -2, 14, 19, 5, -18, 4, -13, 12, -10, 5, + -10, -10, 17, -5, -2, 10, 5, -4, -11, 15, -3, 15, -17, -20, -10, + -4, 12, -9, 16, 13, 10, -19, 2, -9, -10, 8, -2, -4, 3, 7, + 10, -19, -11, -4, -6, 2, -12, 6, -4, -14, 14, 16, 7, 19, -17, + 2, -14, 5, -1, 16, 19, -11, -14, -16, -19, 15, -18, -12, -16, 16, + 1, 5, 7, 8, 2, 13, -3, 6, 2, -5}); + auto expS = NDArrayFactory::create( + 'c', {2, 2, 5}, + {40.95395, 31.46869, 24.79993, 12.33768, 1.80031, 38.18412, 31.52287, + 23.52755, 11.79484, 1.90195, 39.34498, 32.54861, 17.52492, 7.03003, + 2.2399, 44.72126, 32.3164, 16.60139, 6.88783, 0.78122}); + auto expU = NDArrayFactory::create( + 'c', {2, 2, 5, 5}, + {0.25441, 0.16908, -0.68564, 0.58844, -0.30054, -0.32285, -0.58332, + 0.3451, 0.4746, -0.45953, 0.58332, 0.10605, 0.51533, 0.50234, + 0.36136, 0.12588, -0.73123, -0.37812, -0.00215, 0.55361, 0.68915, + -0.2919, 0.04767, -0.4197, -0.51132, 0.44464, -0.25326, -0.42493, + -0.01712, -0.74653, 0.516, -0.16688, 0.1854, -0.77155, 0.27611, + -0.19321, -0.14317, -0.85886, -0.15224, 0.42585, -0.60155, -0.68323, + 0.18819, -0.29053, -0.22696, -0.36993, 0.64862, -0.10956, -0.54483, + -0.36552, -0.57697, -0.32277, 0.11229, 0.55495, 0.4923, -0.02937, + 0.01689, -0.63257, 0.57075, -0.52245, -0.56002, -0.2036, -0.53119, + -0.6022, 0.01017, -0.33605, -0.35257, 0.53215, -0.04936, -0.69075, + 0.48958, -0.85427, -0.14796, -0.03449, 0.08633, 0.15008, 0.60996, + 0.31071, -0.67721, 0.22421, 0.67717, -0.59857, 0.04372, -0.2565, + 0.33979, 0.68116, 0.49852, -0.13441, 0.51374, -0.07421, -0.20066, + 0.04504, 0.42865, 0.44418, 0.75939, 0.12113, -0.13826, 0.83651, + 0.11988, -0.50209}); + auto expV = NDArrayFactory::create( + 'c', {2, 2, 5, 5}, + {0.01858, 0.17863, 0.51259, 0.14048, 0.82781, 0.59651, -0.13439, + -0.395, 0.66979, 0.14654, 0.73731, 0.47061, 0.19357, -0.41127, + -0.16817, 0.1047, -0.29727, 0.73711, 0.38235, -0.45951, -0.29873, + 0.80012, -0.02078, 0.4651, -0.23201, -0.05314, -0.0419, -0.52146, + 0.77792, 0.344, -0.66438, 0.05648, 0.03756, -0.31531, 0.67422, + 0.74471, 0.01504, -0.03081, -0.24335, 0.62049, 0.03172, 0.91947, + 0.30828, 0.23713, 0.04796, -0.01311, 0.38652, -0.79415, -0.42423, + -0.19945, -0.13783, -0.54667, -0.58527, 0.49955, 0.3001, 0.85214, + 0.01628, 0.02688, -0.02891, 0.52157, 0.16608, -0.20181, 0.61371, + 0.69894, -0.25794, 0.45726, -0.33952, -0.32659, -0.18938, -0.73015, + 0.13486, 0.73816, -0.41646, 0.47458, -0.1956, 0.5536, -0.137, + 0.64688, 0.50536, 0.03017, -0.51827, -0.31837, -0.16732, 0.71378, + -0.30425, -0.39314, 0.15266, 0.63693, -0.30945, -0.5663, -0.51981, + 0.03325, 0.37603, 0.05147, 0.76462, -0.01282, 0.92491, -0.08042, + 0.36977, -0.03428}); + + sd::ops::svd op; + auto result = op.evaluate({&x}, {}, {1, 1, 16}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto s = result.at(0); + auto u = result.at(1); + auto v = result.at(2); + + ASSERT_TRUE(expS.isSameShape(s)); + ASSERT_TRUE(expU.isSameShape(u)); + ASSERT_TRUE(expV.isSameShape(v)); + + ASSERT_TRUE(expS.equalsTo(s)); + + if (sd::Environment::getInstance().isCPU()) { + ASSERT_TRUE(expU.equalsTo(u)); + ASSERT_TRUE(expV.equalsTo(v)); + } else { + for (uint i = 0; i < expU.lengthOf(); ++i) + ASSERT_NEAR(sd::math::nd4j_abs(expU.e(i)), + sd::math::nd4j_abs(u.e(i)), 1e-5f); + for (uint i = 0; i < expV.lengthOf(); ++i) + ASSERT_NEAR(sd::math::nd4j_abs(expV.e(i)), + sd::math::nd4j_abs(v.e(i)), 1e-5f); + } } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, svd_test7) { + auto x = NDArrayFactory::create( + 'c', {2, 2, 5, 5}, + {-7., 17, 4, -10, 5, 1, -5, -19, 13, -8, 9, 13, 19, 13, -2, + -8, 10, -9, 0, -20, -2, 14, 19, 5, -18, 4, -13, 12, -10, 5, + -10, -10, 17, -5, -2, 10, 5, -4, -11, 15, -3, 15, -17, -20, -10, + -4, 12, -9, 16, 13, 10, -19, 2, -9, -10, 8, -2, -4, 3, 7, + 10, -19, -11, -4, -6, 2, -12, 6, -4, -14, 14, 16, 7, 19, -17, + 2, -14, 5, -1, 16, 19, -11, -14, -16, -19, 15, -18, -12, -16, 16, + 1, 5, 7, 8, 2, 13, -3, 6, 2, -5}); + auto expS = NDArrayFactory::create( + 'c', {2, 2, 5}, + {40.95395, 31.46869, 24.79993, 12.33768, 1.80031, 38.18412, 31.52287, + 23.52755, 11.79484, 1.90195, 39.34498, 32.54861, 17.52492, 7.03003, + 2.2399, 44.72126, 32.3164, 16.60139, 6.88783, 0.78122}); - auto x= NDArrayFactory::create('c', {2,2,5,5}, {-7. ,17 ,4 ,-10 ,5 ,1 ,-5 ,-19 ,13 ,-8 ,9 ,13 ,19 ,13 ,-2,-8 ,10 ,-9 ,0 ,-20 ,-2 ,14 ,19 ,5 ,-18 ,4 ,-13 ,12 ,-10 - ,5 ,-10 ,-10 ,17 ,-5 ,-2 ,10 ,5 ,-4 ,-11 ,15 ,-3 ,15 ,-17,-20 ,-10 ,-4 ,12 ,-9 ,16 ,13 ,10 ,-19 ,2 ,-9 ,-10 ,8 ,-2 - ,-4 ,3 ,7 ,10 ,-19 ,-11 ,-4 ,-6 ,2 ,-12 ,6 ,-4 ,-14 ,14,16 ,7 ,19 ,-17 ,2 ,-14 ,5 ,-1 ,16 ,19 ,-11 ,-14 ,-16,-19 ,15 ,-18 ,-12 ,-16 ,16 ,1 ,5 ,7 ,8 ,2 ,13 ,-3 ,6 ,2 ,-5}); - auto expS= NDArrayFactory::create('c', {2,2,5}, {40.95395, 31.46869, 24.79993, 12.33768, 1.80031,38.18412, 31.52287, 23.52755, 11.79484, 1.90195, - 39.34498, 32.54861, 17.52492, 7.03003, 2.2399,44.72126, 32.3164 , 16.60139, 6.88783, 0.78122}); + sd::ops::svd op; + auto result = op.evaluate({&x}, {}, {0, 0, 16}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - sd::ops::svd op; - auto result = op.evaluate({&x}, {}, {0, 0, 16}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto *s = result.at(0); - - ASSERT_TRUE(expS.equalsTo(s)); - ASSERT_TRUE(expS.isSameShape(s)); + auto s = result.at(0); + ASSERT_TRUE(expS.equalsTo(s)); + ASSERT_TRUE(expS.isSameShape(s)); } /////////////////////////////////////////////////////////////////// // TEST_F(DeclarableOpsTests3, svd_test8) { -// auto x= NDArrayFactory::create('c', {2,2,11,10}, {3 ,-8 ,0 ,3 ,-5 ,16 ,-3 ,7 ,-4 ,19 ,19 ,13 ,15 ,15 ,9 ,6 ,-7 ,-5 ,-9 ,-12 ,7 ,-1 ,-1 ,6 ,19 -// ,-6 ,16 ,0 ,16 ,16 ,7 ,14 ,18. ,0 ,18 ,-4 ,10 ,-16 ,-17 ,15 ,13 ,-17 ,-14 ,-17 ,-5 ,-9 ,-1 ,-19 -// ,-18 ,5 ,-5 ,-13 ,17 ,-19 ,-5 ,18 ,4 ,10 ,17 ,-7 ,-10 ,16 ,10 ,8 ,-10 ,-3 ,10 ,1 ,-4 ,-16 ,-1 -// ,-1 ,5 ,5 ,17 ,14 ,20 ,15 ,-6 ,19 ,14 ,17 ,0 ,-17 ,-16 ,-8 ,-6 ,3 ,-6 ,-11 ,-4 ,-2 ,-7 ,4 ,-6 -// ,-6 ,-17 ,16 ,-8 ,-20 ,2 ,7 ,-12 ,15 ,-15 ,-19 ,14 ,17 ,9 ,10 ,5 ,18 ,2 ,-6 ,0 ,2 ,-10 ,7 ,8 -// ,-13 ,2 ,8 ,20 ,11 ,-15 ,13 ,-10 ,-14 ,-2 ,20 ,5 ,2 ,16 ,18 ,-3 ,3 ,-18 ,15 ,-11 ,17 ,-8 ,-18 -// ,20 ,-12 ,20 ,20 ,-16 ,20 ,-8. ,19 ,-8 ,3 ,-3 ,17 ,7 ,13 ,9 ,-2 ,11 ,16 ,4 ,-18 ,5 ,0 ,-12 ,9 -// ,-6 ,6 ,0 ,-9 ,-13 ,13 ,17 ,-12 ,3 ,-13 ,17 ,-19 ,17 ,0 ,-8 ,4 ,-19 ,-9 ,-7 ,12 ,-1 ,-12 ,-1 -// ,7 ,2 ,19 ,10 ,19 ,-15 ,-18 ,17 ,-1 ,1 ,14 ,-7 ,-10 ,12 ,-20 ,6 ,-5 ,14 ,5 ,5 ,3 ,-18 ,5 ,17 -// ,-13 ,20 ,-1 ,-2 ,-11 ,-5 ,14 ,8 ,7 ,-13 ,-9 ,-12 ,11 ,3 ,14 ,-6 ,-2 ,13 ,8 ,-15 ,-5 ,-6 ,-7 ,19 -// ,-1 ,6 ,1 ,14 ,8 ,18 ,-20 ,-14 ,-3 ,-5 ,19 ,15 ,13 ,2 ,-20 ,2 ,14 ,13 ,4 ,-15 ,1 ,-14 -// ,0 ,9 ,-1 ,10 ,4 ,6 ,4 ,-7 ,-2 ,-1 ,-15 ,-1 ,-16 ,-5 ,-12 ,-10 ,16 ,-16 ,-15 ,-17 ,-5 ,-6 -// ,18 ,14 ,-3 ,-10 ,8 ,20 ,19 ,20 ,-3 ,-6 ,9 ,10 ,-1 ,-20 ,-5 ,5 ,12 ,8 ,17 ,13 ,-18 ,-14 ,0 -// ,4 ,-11 ,3 ,-12 ,-2 ,-5 ,19 ,-15 ,19 ,16 ,-16 ,13 ,-6 ,11 ,11 ,0 ,-18 ,4 ,5 ,6 ,-12 ,-10 -// ,-3 ,2 ,-18 ,16 ,-5 ,17 ,16 ,-16 ,-20 ,14 ,6 ,10 ,-5 ,-3 ,4 ,20 ,18 ,5 ,1 ,-10 ,15 ,10 ,16 -// ,-18 ,2 ,12 ,20 ,6 ,14 ,8 ,3 ,-2 ,9 ,15 ,-4 ,13 ,-19 ,-5 ,3 ,3 ,-20 ,-4 ,18 ,-11 ,11 ,-10 -// ,3 ,8 ,9 ,20 ,-19 ,6 ,18 ,9 ,20 ,-12 ,4 ,15 ,19 ,3 ,5 ,1 ,2 ,20 ,-3 ,-1 ,-8 ,-3 ,8 ,17 , -// -14 ,18 ,-10 ,4 ,13 ,-5 ,13 ,-6 ,12 ,-10 ,19 ,4 ,-7 ,-17 ,20 ,8 ,6 ,-3 ,3 ,-7 ,-18 ,17 , -// -13 ,18 ,-20 ,-16 ,-5 ,12 ,5 ,17 ,-4 ,4 ,7 ,8 ,17 ,-9 ,-12 ,-10 ,8 ,-14 ,-11 ,7 ,19 ,-17}); - -// auto expS= NDArrayFactory::create('c', {2,2,10}, { 64.12636, 54.37044, 50.63744, 48.10308, 33.7364 , 29.96456, +// auto x= NDArrayFactory::create('c', {2,2,11,10}, {3 ,-8 ,0 ,3 ,-5 +// ,16 ,-3 ,7 ,-4 ,19 ,19 ,13 ,15 ,15 ,9 ,6 ,-7 ,-5 ,-9 ,-12 ,7 ,-1 ,-1 ,6 +// ,19 +// ,-6 ,16 ,0 ,16 ,16 ,7 ,14 ,18. ,0 ,18 ,-4 ,10 ,-16 ,-17 ,15 ,13 +// ,-17 ,-14 ,-17 ,-5 ,-9 ,-1 ,-19 +// ,-18 ,5 ,-5 ,-13 ,17 ,-19 ,-5 ,18 ,4 ,10 ,17 ,-7 ,-10 ,16 ,10 ,8 +// ,-10 ,-3 ,10 ,1 ,-4 ,-16 ,-1 +// ,-1 ,5 ,5 ,17 ,14 ,20 ,15 ,-6 ,19 ,14 ,17 ,0 ,-17 ,-16 ,-8 ,-6 ,3 +// ,-6 ,-11 ,-4 ,-2 ,-7 ,4 ,-6 +// ,-6 ,-17 ,16 ,-8 ,-20 ,2 ,7 ,-12 ,15 ,-15 ,-19 ,14 ,17 ,9 ,10 ,5 +// ,18 ,2 ,-6 ,0 ,2 ,-10 ,7 ,8 +// ,-13 ,2 ,8 ,20 ,11 ,-15 ,13 ,-10 ,-14 ,-2 ,20 ,5 ,2 ,16 ,18 ,-3 ,3 +// ,-18 ,15 ,-11 ,17 ,-8 ,-18 ,20 ,-12 ,20 ,20 ,-16 ,20 ,-8. ,19 ,-8 +// ,3 ,-3 ,17 ,7 ,13 ,9 ,-2 ,11 ,16 ,4 ,-18 ,5 ,0 ,-12 ,9 +// ,-6 ,6 ,0 ,-9 ,-13 ,13 ,17 ,-12 ,3 ,-13 ,17 ,-19 ,17 ,0 ,-8 ,4 +// ,-19 ,-9 ,-7 ,12 ,-1 ,-12 ,-1 ,7 ,2 ,19 ,10 ,19 ,-15 ,-18 ,17 ,-1 +// ,1 ,14 ,-7 ,-10 ,12 ,-20 ,6 ,-5 ,14 ,5 ,5 ,3 ,-18 ,5 ,17 +// ,-13 ,20 ,-1 ,-2 ,-11 ,-5 ,14 ,8 ,7 ,-13 ,-9 ,-12 ,11 ,3 ,14 ,-6 +// ,-2 ,13 ,8 ,-15 ,-5 ,-6 ,-7 ,19 +// ,-1 ,6 ,1 ,14 ,8 ,18 ,-20 ,-14 ,-3 ,-5 ,19 ,15 ,13 ,2 ,-20 ,2 ,14 +// ,13 ,4 ,-15 ,1 ,-14 ,0 ,9 ,-1 ,10 ,4 ,6 ,4 ,-7 ,-2 ,-1 ,-15 ,-1 +// ,-16 ,-5 ,-12 ,-10 ,16 ,-16 ,-15 ,-17 ,-5 ,-6 ,18 ,14 ,-3 ,-10 ,8 +// ,20 ,19 ,20 ,-3 ,-6 ,9 ,10 ,-1 ,-20 ,-5 ,5 ,12 ,8 ,17 ,13 ,-18 +// ,-14 ,0 ,4 ,-11 ,3 ,-12 ,-2 ,-5 ,19 ,-15 ,19 ,16 ,-16 ,13 ,-6 ,11 +// ,11 ,0 ,-18 ,4 ,5 ,6 ,-12 ,-10 +// ,-3 ,2 ,-18 ,16 ,-5 ,17 ,16 ,-16 ,-20 ,14 ,6 ,10 ,-5 ,-3 ,4 ,20 +// ,18 ,5 ,1 ,-10 ,15 ,10 ,16 +// ,-18 ,2 ,12 ,20 ,6 ,14 ,8 ,3 ,-2 ,9 ,15 ,-4 ,13 ,-19 ,-5 ,3 ,3 +// ,-20 ,-4 ,18 ,-11 ,11 ,-10 ,3 ,8 ,9 ,20 ,-19 ,6 ,18 ,9 ,20 ,-12 ,4 +// ,15 ,19 ,3 ,5 ,1 ,2 ,20 ,-3 ,-1 ,-8 ,-3 ,8 ,17 , +// -14 ,18 ,-10 ,4 ,13 ,-5 ,13 ,-6 ,12 +// ,-10 ,19 ,4 ,-7 ,-17 ,20 ,8 ,6 ,-3 ,3 +// ,-7 ,-18 ,17 , -13 ,18 ,-20 ,-16 ,-5 +// ,12 ,5 ,17 ,-4 ,4 ,7 ,8 ,17 ,-9 ,-12 +// ,-10 ,8 ,-14 ,-11 ,7 ,19 ,-17}); + +// auto expS= NDArrayFactory::create('c', {2,2,10}, +// { 64.12636, 54.37044, 50.63744, 48.10308, 33.7364 , 29.96456, // 25.53945, 19.31856, 15.30939, 9.31349, -// 67.41342, 59.64963, 58.72687, 39.22496, 32.39772, 29.30833, -// 23.1491 , 16.92442, 6.38613, 3.49563, -// 74.37477, 52.07016, 46.10758, 39.10742, 32.02261, 27.05888, -// 20.54921, 13.17989, 8.4158 , 4.39974, -// 65.47447, 56.31305, 54.13371, 46.26955, 43.47755, 30.25799, +// 67.41342, 59.64963, 58.72687, 39.22496, +// 32.39772, 29.30833, 23.1491 +// , 16.92442, 6.38613, 3.49563, +// 74.37477, 52.07016, 46.10758, 39.10742, +// 32.02261, 27.05888, +// 20.54921, 13.17989, 8.4158 +// , 4.39974, +// 65.47447, 56.31305, 54.13371, 46.26955, +// 43.47755, 30.25799, // 20.71463, 16.89671, 10.39572, 7.81631}); -// auto expU= NDArrayFactory::create('c', {2,2,11,11}, {-0.177870, -0.149461, -0.196911, 0.036990, -0.338237, 0.548901, -// -0.074396, 0.497067, -0.083636, -0.111810, -0.466989, -0.010465, 0.434732, 0.337198, 0.305239, -0.292813, -// 0.041280, -0.517144, 0.121499, 0.464908, 0.003658, 0.135017, -0.446916, -0.098318, 0.073571, -0.200521, -// 0.186776, -0.353022, -0.435582, -0.225959, 0.052972, 0.032390, -0.583801, -0.402790, 0.562809, 0.102744, -// 0.066555, 0.206079, 0.115322, 0.217220, -0.062591, -0.273173, -0.569645, 0.005612, 0.092601, 0.350055, -// -0.608007, -0.367743, 0.064860, 0.112656, 0.091576, -0.144262, 0.554655, -0.042100, -0.092023, 0.026986, -// -0.395811, -0.245209, 0.572522, 0.429430, 0.099621, -0.159236, -0.086263, 0.268160, -0.391298, 0.050417, -// 0.150175, 0.045253, 0.464173, 0.138376, 0.265551, 0.049691, 0.528778, 0.116951, 0.384609, 0.144416, -// -0.453591, -0.519390, -0.150671, 0.072897, 0.102406, -0.154184, 0.450735, 0.174171, -0.519405, 0.147109, -// 0.333670, 0.178053, 0.360763, 0.226976, 0.069976, -0.046765, 0.448897, 0.511309, -0.361050, -0.191690, -// -0.304442, 0.270383, -0.124133, 0.417183, -0.083359, 0.137022, 0.004276, -0.462336, 0.051267, 0.020622, -// -0.566932, -0.051351, -0.417106, -0.292202, -0.021595, -0.315956, 0.396626, -0.604952, 0.155990, 0.258395, -// -0.125080, 0.115404, 0.234517, -0.357460, 0.271271, 0.063771, -0.087400, -0.024710, -0.179892, 0.584339, -// -0.413085, 0.510580, 0.334646, 0.044424, 0.224735, 0.134434, -0.147861, 0.291853, 0.487948, 0.238917, -// 0.433893, 0.435884, 0.056370, -0.051216, -0.450902, 0.062411, 0.080733, -0.365211, 0.031931, 0.493926, -// -0.239428, 0.038247, -0.180721, -0.118035, 0.042175, 0.377296, -0.516399, 0.324744, -0.756196, 0.160856, -// -0.152527, -0.046867, -0.092933, -0.044945, 0.137659, 0.246552, -0.071709, 0.032821, -0.529356, -0.029669, -// 0.200178, 0.188916, 0.428036, -0.496734, -0.164185, 0.629070, -0.131588, 0.073992, 0.066877, 0.208450, -// -0.156170, -0.253670, -0.000365, -0.121172, 0.067774, 0.618226, 0.230460, -0.118865, 0.579424, 0.324523, -// 0.038653, 0.310308, 0.570186, -0.217271, -0.110967, 0.196375, 0.167058, 0.264071, -0.130023, 0.254189, -// -0.459057, -0.301033, 0.069932, -0.033338, -0.070600, 0.685064, 0.130274, 0.074929, -0.206899, 0.574057, -// 0.327277, -0.131588, -0.018497, 0.312445, 0.314594, 0.480422, -0.293858, -0.273277, -0.006598, -0.134574, -// 0.403501, 0.140025, 0.380693, -0.257039, -0.067012, 0.248776, -0.361838, -0.270296, -0.225844, 0.320245, -// 0.055730, 0.454809, -0.212163, -0.063281, 0.563112, -0.200737, 0.537389, -0.210845, 0.109997, 0.166215, -// -0.243725, -0.347349, -0.274348, 0.263950, 0.437134, 0.265820, -0.127520, -0.033325, -0.137156, 0.518557, -// 0.246720, 0.389394, -0.600568, 0.062027, -0.047838, -0.338416, 0.032778, -0.141998, -0.338022, -0.381467, -// 0.210512, -0.314413, 0.256321, 0.001460, 0.238901, 0.139840, 0.633423, -0.182575, -0.461504, 0.290250, -// -0.025930, 0.336998, -0.211280, -0.662387, -0.207946, -0.003860, -0.147842, 0.157217, 0.123704, 0.345686, -// 0.337946, 0.138261, -0.178814, -0.109597, 0.087135, -0.509500, -0.300296, -0.262279, 0.377476, -0.366815, -// 0.091787, 0.247495, -0.193812, -0.179714, 0.238552, -0.162305, -0.029549, 0.785426, -0.157586, -0.084533, -// -0.357024, 0.317878, 0.217656, 0.125319, 0.648832, 0.344045, -0.001109, 0.457190, -0.072439, -0.106278, -// 0.228962, -0.136139, -0.528342, -0.020840, -0.108908, -0.231661, 0.396864, 0.234925, 0.180894, -0.179430, -// -0.587730, 0.178276, -0.008672, -0.386172, 0.033155, 0.319568, 0.101457, -0.272011, 0.126007, 0.175374, -// -0.081668, 0.112987, -0.296422, -0.713743, 0.269413, -0.082098, -0.338649, 0.131035, -0.518616, 0.022478, -// 0.177802, -0.042432, -0.606219, -0.343848, 0.014416, -0.141375, 0.748332, -0.165911, -0.049067, -0.241062, -// 0.436318, 0.173318, 0.058066, 0.193764, -0.000647, 0.265777, -0.027847, -0.096305, 0.711632, 0.066506, -// -0.223124, 0.219165, -0.038165, 0.427444, -0.296887, 0.139982, 0.298976, 0.294876, -0.001315, 0.419802, -// 0.475401, -0.156256, -0.289477, -0.438761, -0.116348, 0.108350, -0.369368, -0.219943, 0.433088, 0.187565, -// -0.217259, 0.147014, -0.538991, -0.065052, 0.310337, 0.491887, 0.254439, 0.075052, 0.071155, -0.084856, -// 0.402098, 0.096270, 0.093662, -0.475769, 0.256832, 0.161394, -0.390050, -0.513551, -0.184665, 0.211506, -// -0.112525, -0.493409, -0.258765, 0.262124, -0.272998, 0.269370, 0.266226, -0.367919, 0.192386, -0.006422, -// -0.466728, -0.481792, 0.090611, -0.156359, 0.178693, -0.371658, -0.214190, -0.469058, -0.006134, 0.081902, -// 0.536950, 0.064836, -0.334010, 0.523530, -0.182061, -0.206686, 0.002985, 0.054858, -0.038727, -0.075390, -// 0.543839, -0.442964, -0.190550, -0.298127, -0.065323, 0.131415, 0.329899, 0.122096, -0.507075, 0.523751, -// -0.167317, 0.198593, -0.069066, 0.402739, 0.328583, 0.314184, -0.268003, -0.148549, 0.118925, -0.508174, -// 0.128716, -0.405597, -0.157224, 0.271021, -0.384444, -0.174935, 0.343919, -0.076726, 0.607931, 0.383931, -// 0.198254, 0.133707, 0.321460, -0.232543, 0.099988, -0.321954, -0.366304, -0.137440, 0.232835, -0.290306, -// -0.260804, -0.347721, 0.182895, 0.382311, -0.332847, -0.192469, -0.438258, -0.017533, -0.192976, -0.702531, -// 0.124463, 0.039719, -0.221319, -0.224785, 0.096356, -0.302131, -0.462598, 0.194320}); - - -// auto expV= NDArrayFactory::create('c', {2,2,10,10}, {-0.050761, 0.370975, -0.061567, -0.125530, 0.024081, 0.275524, -0.800334, -// -0.025855, 0.348132, 0.036882, 0.034921, 0.307295, 0.629837, 0.014276, 0.265687, 0.188407, -0.035481, 0.082827, -// -0.490175, 0.391118, -0.180180, 0.169108, 0.206663, 0.623321, 0.260009, 0.081943, 0.004485, 0.136199, 0.060353, -// -0.641224, -0.181559, -0.041761, 0.578416, -0.161798, -0.573128, -0.187563, 0.012533, 0.368041, 0.314619, -// -0.079349, -0.527508, 0.216020, 0.004721, 0.188769, -0.242534, -0.442685, -0.121683, -0.565306, -0.202894, -// 0.095280, -0.181900, -0.170627, -0.201655, 0.620259, -0.257996, 0.277656, -0.009623, 0.266775, 0.081952, -// 0.539241, -0.452254, -0.136142, 0.177049, -0.144734, 0.494673, 0.101613, 0.280091, -0.186281, 0.548779, -// 0.235160, 0.054763, -0.571503, 0.298086, 0.035312, -0.195188, 0.474030, -0.175457, -0.497267, -0.101439, -// -0.170678, -0.060605, -0.557305, 0.073433, 0.057195, 0.352091, -0.486102, -0.483569, 0.252091, -0.121245, -// 0.068719, -0.638919, -0.078029, -0.236556, -0.351440, -0.024437, 0.319855, -0.007406, 0.319691, -0.402334, -// -0.197966, 0.058936, -0.360900, 0.233414, -0.251532, 0.105457, 0.048097, 0.029321, 0.002714, -0.845953, -// -0.136344, 0.378037, 0.277491, 0.278420, 0.037491, 0.432117, -0.586745, 0.104573, 0.316569, -0.039848, -// 0.239645, -0.320923, 0.555156, 0.145059, -0.546959, 0.267760, 0.298029, 0.177831, -0.191286, -0.032427, -// 0.197034, 0.081887, -0.113063, 0.711713, 0.020279, -0.362346, -0.145776, 0.173289, -0.500880, 0.181624, -// 0.084391, -0.278967, 0.212143, -0.413382, 0.012879, -0.216886, -0.625774, 0.066795, -0.421937, -0.291320, -// 0.011402, -0.416660, -0.134200, 0.043039, 0.554715, 0.126867, 0.147315, 0.474334, 0.094354, -0.156458, -// 0.450168, 0.447448, 0.261750, -0.161426, -0.064309, -0.592417, 0.210891, 0.104312, 0.176178, -0.237020, -// 0.455579, -0.358056, -0.307454, 0.033700, -0.486831, -0.303963, -0.284916, 0.241549, 0.510701, 0.206104, -// 0.062587, 0.248212, 0.132088, -0.122704, 0.026342, -0.011108, 0.066306, 0.763127, 0.009491, 0.038822, -// -0.562773, -0.320104, 0.477773, 0.354169, 0.293329, -0.304227, -0.001662, -0.213324, 0.365277, -0.198056, -// -0.383499, -0.017789, 0.324542, -0.642856, 0.238689, -0.360461, -0.060599, -0.257192, 0.342400, 0.180845, -// 0.272810, -0.452278, -0.409323, 0.077013, -0.082561, 0.334893, -0.103309, -0.198049, 0.480416, 0.470593, -// 0.029072, -0.300574, 0.532293, 0.250892, -0.355298, 0.079716, -0.319781, 0.259925, 0.277872, -0.251917, -// 0.346821, 0.161642, 0.205861, 0.107125, -0.594779, -0.226272, 0.610183, -0.065926, 0.170332, 0.312553, -// -0.108093, 0.368268, -0.183109, -0.192222, -0.544559, 0.136824, -0.412352, -0.398250, -0.257291, 0.019911, -// 0.288797, 0.013350, 0.349817, -0.108331, 0.180576, 0.652863, 0.319319, 0.020218, -0.324499, 0.290877, -// 0.338518, -0.301776, -0.440871, -0.281683, -0.158759, -0.080281, 0.418260, 0.189926, -0.064112, -0.390914, -// 0.485420, -0.464327, 0.211070, 0.044295, -0.032292, 0.043985, 0.147160, -0.702247, -0.198395, -0.352940, -// -0.237014, -0.438235, 0.073448, -0.418712, -0.280275, -0.091373, -0.194273, 0.347558, -0.421767, 0.283011, -// -0.351869, -0.210088, -0.034628, 0.448410, 0.149194, -0.488551, -0.068805, -0.117007, -0.390999, 0.377100, -// 0.423252, -0.041944, 0.455115, -0.537818, 0.266732, 0.218202, 0.047475, -0.383506, -0.158858, 0.450881, -// 0.072415, 0.355772, 0.002360, 0.138976, 0.541349, -0.295405, 0.463832, 0.400676, -0.168962, 0.259334, -// -0.047960, 0.272197, 0.582658, 0.198052, 0.127300, -0.320468, -0.104858, -0.229698, 0.046672, -0.474224, -// 0.370765, -0.246450, 0.212667, 0.024935, -0.344530, -0.238547, 0.185931, 0.269068, 0.487414, 0.421376, -// 0.442391, -0.284247, 0.304973, -0.365006, -0.159016, -0.129088, -0.126454, 0.600462, -0.461163, -0.243552, -// -0.049814, -0.381340, -0.054504, 0.436237, 0.126120, -0.359677, -0.409734, -0.179422, -0.414820, 0.371149, -// 0.078299, 0.503544, 0.322165, 0.148341, -0.495447, -0.084355, -0.174667, 0.016802, -0.066954, 0.318825, -// -0.480771, -0.060163, 0.144302, -0.041555, 0.459106, 0.029882, -0.565026, 0.282336, 0.528472, 0.044916, -// -0.286167, -0.101052, -0.181529, -0.419406, -0.032204, -0.732282, 0.106833, -0.288881, 0.171516, -0.096242, -// -0.331834, -0.493188, 0.393195, 0.358365, 0.049125, 0.123457, 0.438169, -0.105015, 0.092386, -0.130413, -0.476991}); +// auto expU= NDArrayFactory::create('c', {2,2,11,11}, {-0.177870, +// -0.149461, -0.196911, 0.036990, -0.338237, 0.548901, +// -0.074396, 0.497067, -0.083636, +// -0.111810, -0.466989, -0.010465, +// 0.434732, 0.337198, 0.305239, +// -0.292813, 0.041280, -0.517144, +// 0.121499, 0.464908, 0.003658, +// 0.135017, -0.446916, -0.098318, +// 0.073571, -0.200521, 0.186776, +// -0.353022, -0.435582, -0.225959, +// 0.052972, 0.032390, -0.583801, +// -0.402790, 0.562809, 0.102744, +// 0.066555, 0.206079, 0.115322, +// 0.217220, -0.062591, -0.273173, +// -0.569645, 0.005612, 0.092601, +// 0.350055, -0.608007, -0.367743, +// 0.064860, 0.112656, 0.091576, +// -0.144262, 0.554655, -0.042100, +// -0.092023, 0.026986, -0.395811, +// -0.245209, 0.572522, 0.429430, +// 0.099621, -0.159236, -0.086263, +// 0.268160, -0.391298, 0.050417, +// 0.150175, 0.045253, 0.464173, +// 0.138376, 0.265551, 0.049691, +// 0.528778, 0.116951, 0.384609, +// 0.144416, -0.453591, -0.519390, +// -0.150671, 0.072897, 0.102406, +// -0.154184, 0.450735, 0.174171, +// -0.519405, 0.147109, 0.333670, +// 0.178053, 0.360763, 0.226976, +// 0.069976, -0.046765, 0.448897, +// 0.511309, -0.361050, -0.191690, +// -0.304442, 0.270383, -0.124133, +// 0.417183, -0.083359, 0.137022, +// 0.004276, -0.462336, 0.051267, +// 0.020622, -0.566932, -0.051351, +// -0.417106, -0.292202, -0.021595, +// -0.315956, 0.396626, -0.604952, +// 0.155990, 0.258395, -0.125080, +// 0.115404, 0.234517, -0.357460, +// 0.271271, 0.063771, -0.087400, +// -0.024710, -0.179892, 0.584339, +// -0.413085, 0.510580, 0.334646, +// 0.044424, 0.224735, 0.134434, +// -0.147861, 0.291853, 0.487948, +// 0.238917, 0.433893, 0.435884, +// 0.056370, -0.051216, -0.450902, +// 0.062411, 0.080733, -0.365211, +// 0.031931, 0.493926, -0.239428, +// 0.038247, -0.180721, -0.118035, +// 0.042175, 0.377296, -0.516399, +// 0.324744, -0.756196, 0.160856, +// -0.152527, -0.046867, -0.092933, +// -0.044945, 0.137659, 0.246552, +// -0.071709, 0.032821, -0.529356, +// -0.029669, 0.200178, 0.188916, +// 0.428036, -0.496734, -0.164185, +// 0.629070, -0.131588, 0.073992, +// 0.066877, 0.208450, -0.156170, +// -0.253670, -0.000365, -0.121172, +// 0.067774, 0.618226, 0.230460, +// -0.118865, 0.579424, 0.324523, +// 0.038653, 0.310308, 0.570186, +// -0.217271, -0.110967, 0.196375, +// 0.167058, 0.264071, -0.130023, +// 0.254189, -0.459057, -0.301033, +// 0.069932, -0.033338, -0.070600, +// 0.685064, 0.130274, 0.074929, +// -0.206899, 0.574057, 0.327277, +// -0.131588, -0.018497, 0.312445, +// 0.314594, 0.480422, -0.293858, +// -0.273277, -0.006598, -0.134574, +// 0.403501, 0.140025, 0.380693, +// -0.257039, -0.067012, 0.248776, +// -0.361838, -0.270296, -0.225844, +// 0.320245, 0.055730, 0.454809, +// -0.212163, -0.063281, 0.563112, +// -0.200737, 0.537389, -0.210845, +// 0.109997, 0.166215, -0.243725, +// -0.347349, -0.274348, 0.263950, +// 0.437134, 0.265820, -0.127520, +// -0.033325, -0.137156, 0.518557, +// 0.246720, 0.389394, -0.600568, +// 0.062027, -0.047838, -0.338416, +// 0.032778, -0.141998, -0.338022, +// -0.381467, 0.210512, -0.314413, +// 0.256321, 0.001460, 0.238901, +// 0.139840, 0.633423, -0.182575, +// -0.461504, 0.290250, -0.025930, +// 0.336998, -0.211280, -0.662387, +// -0.207946, -0.003860, -0.147842, +// 0.157217, 0.123704, 0.345686, +// 0.337946, 0.138261, -0.178814, +// -0.109597, 0.087135, -0.509500, +// -0.300296, -0.262279, 0.377476, +// -0.366815, 0.091787, 0.247495, +// -0.193812, -0.179714, 0.238552, +// -0.162305, -0.029549, 0.785426, +// -0.157586, -0.084533, -0.357024, +// 0.317878, 0.217656, 0.125319, +// 0.648832, 0.344045, -0.001109, +// 0.457190, -0.072439, -0.106278, +// 0.228962, -0.136139, -0.528342, +// -0.020840, -0.108908, -0.231661, +// 0.396864, 0.234925, 0.180894, +// -0.179430, -0.587730, 0.178276, +// -0.008672, -0.386172, 0.033155, +// 0.319568, 0.101457, -0.272011, +// 0.126007, 0.175374, -0.081668, +// 0.112987, -0.296422, -0.713743, +// 0.269413, -0.082098, -0.338649, +// 0.131035, -0.518616, 0.022478, +// 0.177802, -0.042432, -0.606219, +// -0.343848, 0.014416, -0.141375, +// 0.748332, -0.165911, -0.049067, +// -0.241062, 0.436318, 0.173318, +// 0.058066, 0.193764, -0.000647, +// 0.265777, -0.027847, -0.096305, +// 0.711632, 0.066506, -0.223124, +// 0.219165, -0.038165, 0.427444, +// -0.296887, 0.139982, 0.298976, +// 0.294876, -0.001315, 0.419802, +// 0.475401, -0.156256, -0.289477, +// -0.438761, -0.116348, 0.108350, +// -0.369368, -0.219943, 0.433088, +// 0.187565, -0.217259, 0.147014, +// -0.538991, -0.065052, 0.310337, +// 0.491887, 0.254439, 0.075052, +// 0.071155, -0.084856, 0.402098, +// 0.096270, 0.093662, -0.475769, +// 0.256832, 0.161394, -0.390050, +// -0.513551, -0.184665, 0.211506, +// -0.112525, -0.493409, -0.258765, +// 0.262124, -0.272998, 0.269370, +// 0.266226, -0.367919, 0.192386, +// -0.006422, -0.466728, -0.481792, +// 0.090611, -0.156359, 0.178693, +// -0.371658, -0.214190, -0.469058, +// -0.006134, 0.081902, 0.536950, +// 0.064836, -0.334010, 0.523530, +// -0.182061, -0.206686, 0.002985, +// 0.054858, -0.038727, -0.075390, +// 0.543839, -0.442964, -0.190550, +// -0.298127, -0.065323, 0.131415, +// 0.329899, 0.122096, -0.507075, +// 0.523751, -0.167317, 0.198593, +// -0.069066, 0.402739, 0.328583, +// 0.314184, -0.268003, -0.148549, +// 0.118925, -0.508174, 0.128716, +// -0.405597, -0.157224, 0.271021, +// -0.384444, -0.174935, 0.343919, +// -0.076726, 0.607931, 0.383931, +// 0.198254, 0.133707, 0.321460, +// -0.232543, 0.099988, -0.321954, +// -0.366304, -0.137440, 0.232835, +// -0.290306, -0.260804, -0.347721, +// 0.182895, 0.382311, -0.332847, +// -0.192469, -0.438258, -0.017533, +// -0.192976, -0.702531, 0.124463, +// 0.039719, -0.221319, -0.224785, +// 0.096356, -0.302131, -0.462598, +// 0.194320}); + +// auto expV= NDArrayFactory::create('c', {2,2,10,10}, {-0.050761, +// 0.370975, -0.061567, -0.125530, 0.024081, 0.275524, -0.800334, +// -0.025855, 0.348132, 0.036882, +// 0.034921, 0.307295, 0.629837, +// 0.014276, 0.265687, 0.188407, +// -0.035481, 0.082827, -0.490175, +// 0.391118, -0.180180, 0.169108, +// 0.206663, 0.623321, 0.260009, +// 0.081943, 0.004485, 0.136199, +// 0.060353, -0.641224, -0.181559, +// -0.041761, 0.578416, -0.161798, +// -0.573128, -0.187563, 0.012533, +// 0.368041, 0.314619, -0.079349, +// -0.527508, 0.216020, 0.004721, +// 0.188769, -0.242534, -0.442685, +// -0.121683, -0.565306, -0.202894, +// 0.095280, -0.181900, -0.170627, +// -0.201655, 0.620259, -0.257996, +// 0.277656, -0.009623, 0.266775, +// 0.081952, 0.539241, -0.452254, +// -0.136142, 0.177049, -0.144734, +// 0.494673, 0.101613, 0.280091, +// -0.186281, 0.548779, 0.235160, +// 0.054763, -0.571503, 0.298086, +// 0.035312, -0.195188, 0.474030, +// -0.175457, -0.497267, -0.101439, +// -0.170678, -0.060605, -0.557305, +// 0.073433, 0.057195, 0.352091, +// -0.486102, -0.483569, 0.252091, +// -0.121245, 0.068719, -0.638919, +// -0.078029, -0.236556, -0.351440, +// -0.024437, 0.319855, -0.007406, +// 0.319691, -0.402334, -0.197966, +// 0.058936, -0.360900, 0.233414, +// -0.251532, 0.105457, 0.048097, +// 0.029321, 0.002714, -0.845953, +// -0.136344, 0.378037, 0.277491, +// 0.278420, 0.037491, 0.432117, +// -0.586745, 0.104573, 0.316569, +// -0.039848, 0.239645, -0.320923, +// 0.555156, 0.145059, -0.546959, +// 0.267760, 0.298029, 0.177831, +// -0.191286, -0.032427, 0.197034, +// 0.081887, -0.113063, 0.711713, +// 0.020279, -0.362346, -0.145776, +// 0.173289, -0.500880, 0.181624, +// 0.084391, -0.278967, 0.212143, +// -0.413382, 0.012879, -0.216886, +// -0.625774, 0.066795, -0.421937, +// -0.291320, 0.011402, -0.416660, +// -0.134200, 0.043039, 0.554715, +// 0.126867, 0.147315, 0.474334, +// 0.094354, -0.156458, 0.450168, +// 0.447448, 0.261750, -0.161426, +// -0.064309, -0.592417, 0.210891, +// 0.104312, 0.176178, -0.237020, +// 0.455579, -0.358056, -0.307454, +// 0.033700, -0.486831, -0.303963, +// -0.284916, 0.241549, 0.510701, +// 0.206104, 0.062587, 0.248212, +// 0.132088, -0.122704, 0.026342, +// -0.011108, 0.066306, 0.763127, +// 0.009491, 0.038822, -0.562773, +// -0.320104, 0.477773, 0.354169, +// 0.293329, -0.304227, -0.001662, +// -0.213324, 0.365277, -0.198056, +// -0.383499, -0.017789, 0.324542, +// -0.642856, 0.238689, -0.360461, +// -0.060599, -0.257192, 0.342400, +// 0.180845, 0.272810, -0.452278, +// -0.409323, 0.077013, -0.082561, +// 0.334893, -0.103309, -0.198049, +// 0.480416, 0.470593, 0.029072, +// -0.300574, 0.532293, 0.250892, +// -0.355298, 0.079716, -0.319781, +// 0.259925, 0.277872, -0.251917, +// 0.346821, 0.161642, 0.205861, +// 0.107125, -0.594779, -0.226272, +// 0.610183, -0.065926, 0.170332, +// 0.312553, -0.108093, 0.368268, +// -0.183109, -0.192222, -0.544559, +// 0.136824, -0.412352, -0.398250, +// -0.257291, 0.019911, 0.288797, +// 0.013350, 0.349817, -0.108331, +// 0.180576, 0.652863, 0.319319, +// 0.020218, -0.324499, 0.290877, +// 0.338518, -0.301776, -0.440871, +// -0.281683, -0.158759, -0.080281, +// 0.418260, 0.189926, -0.064112, +// -0.390914, 0.485420, -0.464327, +// 0.211070, 0.044295, -0.032292, +// 0.043985, 0.147160, -0.702247, +// -0.198395, -0.352940, -0.237014, +// -0.438235, 0.073448, -0.418712, +// -0.280275, -0.091373, -0.194273, +// 0.347558, -0.421767, 0.283011, +// -0.351869, -0.210088, -0.034628, +// 0.448410, 0.149194, -0.488551, +// -0.068805, -0.117007, -0.390999, +// 0.377100, 0.423252, -0.041944, +// 0.455115, -0.537818, 0.266732, +// 0.218202, 0.047475, -0.383506, +// -0.158858, 0.450881, 0.072415, +// 0.355772, 0.002360, 0.138976, +// 0.541349, -0.295405, 0.463832, +// 0.400676, -0.168962, 0.259334, +// -0.047960, 0.272197, 0.582658, +// 0.198052, 0.127300, -0.320468, +// -0.104858, -0.229698, 0.046672, +// -0.474224, 0.370765, -0.246450, +// 0.212667, 0.024935, -0.344530, +// -0.238547, 0.185931, 0.269068, +// 0.487414, 0.421376, 0.442391, +// -0.284247, 0.304973, -0.365006, +// -0.159016, -0.129088, -0.126454, +// 0.600462, -0.461163, -0.243552, +// -0.049814, -0.381340, -0.054504, +// 0.436237, 0.126120, -0.359677, +// -0.409734, -0.179422, -0.414820, +// 0.371149, 0.078299, 0.503544, +// 0.322165, 0.148341, -0.495447, +// -0.084355, -0.174667, 0.016802, +// -0.066954, 0.318825, -0.480771, +// -0.060163, 0.144302, -0.041555, +// 0.459106, 0.029882, -0.565026, +// 0.282336, 0.528472, 0.044916, +// -0.286167, -0.101052, -0.181529, +// -0.419406, -0.032204, -0.732282, +// 0.106833, -0.288881, 0.171516, +// -0.096242, -0.331834, -0.493188, +// 0.393195, 0.358365, 0.049125, +// 0.123457, 0.438169, -0.105015, +// 0.092386, -0.130413, -0.476991}); // sd::ops::svd op; // auto results = op.execute({&x}, {}, {1, 1, 7}); // ASSERT_EQ(ND4J_STATUS_OK, result.status()); -// auto *s = result.at(0); -// auto *u = result.at(1); -// auto *v = result.at(2); +// auto s = result.at(0); +// auto u = result.at(1); +// auto v = result.at(2); - // ASSERT_TRUE(expS.isSameShape(s)); - // ASSERT_TRUE(expU.isSameShape(u)); - // ASSERT_TRUE(expV.isSameShape(v)); +// ASSERT_TRUE(expS.isSameShape(s)); +// ASSERT_TRUE(expU.isSameShape(u)); +// ASSERT_TRUE(expV.isSameShape(v)); - // ASSERT_TRUE(expS.equalsTo(s)); +// ASSERT_TRUE(expS.equalsTo(s)); - // if(sd::Environment::getInstance().isCPU()) { - // ASSERT_TRUE(expU.equalsTo(u)); - // ASSERT_TRUE(expV.equalsTo(v)); - // } - // else { - // for(uint i = 0; i < expU.lengthOf(); ++i) - // ASSERT_NEAR(sd::math::nd4j_abs(expU.e(i)), sd::math::nd4j_abs(u->e(i)), 1e-5); - // for(uint i = 0; i < expV.lengthOf(); ++i) - // ASSERT_NEAR(sd::math::nd4j_abs(expV.e(i)), sd::math::nd4j_abs(v->e(i)), 1e-5); - // } +// if(sd::Environment::getInstance().isCPU()) { +// ASSERT_TRUE(expU.equalsTo(u)); +// ASSERT_TRUE(expV.equalsTo(v)); +// } +// else { +// for(uint i = 0; i < expU.lengthOf(); ++i) +// ASSERT_NEAR(sd::math::nd4j_abs(expU.e(i)), +// sd::math::nd4j_abs(u->e(i)), 1e-5); +// for(uint i = 0; i < expV.lengthOf(); ++i) +// ASSERT_NEAR(sd::math::nd4j_abs(expV.e(i)), +// sd::math::nd4j_abs(v->e(i)), 1e-5); +// } // // } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, svd_test9) { - - auto x= NDArrayFactory::create('c', {2,2,5,6}, {17 ,-11 ,20 ,-10 ,19 ,13 ,-18 ,6 ,-2 ,-6 ,-10 ,4 ,-6 ,-4 ,3 ,16 ,12 , - -15 ,8 ,-8 ,12 ,-1 ,20 ,19 ,-13 ,0 ,20 ,17 ,-8 ,16 ,-19 ,7 ,-16 ,-14 ,-5 ,7 ,7 ,-5 ,12 ,-15 ,7 ,8 , - 1 ,-8 ,-17 ,10 ,-11 ,8 ,-10 ,1 ,-6 ,10 ,15 ,19 ,-15 ,8 ,2 ,8 ,12 ,7 ,-5 ,1 ,8 ,4 ,-13 ,2 ,19 ,-2 ,-10 , - -8 ,11 ,1 ,20 ,-11 ,4 ,1 ,-17 ,-15 ,0 ,-9 ,-4 ,-1 ,-6 ,-9 ,-13 ,10 ,7 ,-2 ,15 ,-10 ,-1 ,11 ,-20 ,-2 , - -1 ,-18 ,12 ,16 ,8 ,-9 ,-20 ,-7 ,-20 ,3 ,-9 ,12 ,8 ,-19 ,-2 ,2 ,1 ,7 ,10 ,-18 ,13 ,6 ,14 ,0 ,19 ,8}); - - auto expS= NDArrayFactory::create('c', {2,2,5}, {50.46507, 35.75599, 28.12787, 12.45245, 9.08545, - 38.56035, 30.62846, 26.31646, 19.42605, 3.01162, - 38.56369, 29.18881, 19.54565, 10.89746, 2.017 , - 44.99108, 34.95059, 26.00453, 15.43898, 7.18752}); - - auto expU= NDArrayFactory::create('c', {2,2,5,5}, {-0.73644, -0.10751, 0.10081, -0.00325, 0.66025,0.26329, 0.3079 , 0.38582, 0.77696, 0.28872,0.03076, 0.03015, -0.9128 , 0.36387, 0.18039, - -0.61335, 0.10076, 0.01381, 0.40922, -0.66783,-0.10577, 0.93946, -0.0871 , -0.31058, 0.04677,0.52823, 0.31163, -0.78777, 0.02322, -0.05234, - -0.23942, -0.45801, -0.34248, 0.71286, 0.32778,0.26147, 0.60409, 0.39933, 0.46862, 0.43318,0.62118, -0.37993, 0.30992, 0.34537, -0.50444, - 0.45763, -0.42877, 0.08128, -0.3904 , 0.66912,-0.05428, 0.53632, 0.19774, -0.32198, 0.75276,-0.21986, -0.8214 , -0.00392, -0.1659 , 0.49944, - -0.79443, 0.1633 , -0.45374, -0.31666, -0.18989,-0.24459, 0.10463, -0.27652, 0.85595, 0.34657,0.50772, 0.00757, -0.82374, -0.18941, 0.16658, 0.49473, -0.39923, -0.20758, 0.74339, -0.01213, - -0.2024 , -0.80239, -0.35502, -0.3982 , -0.17492,0.68875, 0.1822 , -0.08046, -0.39238, -0.57619,0.34555, 0.12488, -0.50703, -0.29269, 0.72267,-0.34713, 0.3847 , -0.7532 , 0.22176, -0.33913}); - - auto expV= NDArrayFactory::create('c', {2,2,6,6}, {-4.15640000e-01, -5.30190000e-01, 5.29200000e-02, -7.15710000e-01,-1.10690000e-01, 1.37280000e-01,2.86620000e-01, 5.88200000e-02, 1.68760000e-01, -2.55000000e-03,-1.00090000e-01, 9.35890000e-01, - -4.88230000e-01, 4.84470000e-01, -1.09150000e-01, -1.46810000e-01,6.70320000e-01, 2.10040000e-01,1.00910000e-01, 4.35740000e-01, -6.90500000e-01, -3.61090000e-01,-4.38680000e-01, 1.83200000e-02, - -5.48440000e-01, -2.86950000e-01, -4.23900000e-01, 5.78540000e-01,-2.10060000e-01, 2.41550000e-01,-4.42450000e-01, 4.56640000e-01, 5.48020000e-01, 3.32100000e-02,-5.40210000e-01, -4.97000000e-02, - -6.36070000e-01, 5.57600000e-02, 3.28740000e-01, 3.81950000e-01,-4.21850000e-01, 4.00490000e-01,1.83740000e-01, -1.36190000e-01, -2.29380000e-01, -5.11090000e-01,-2.06580000e-01, 7.68890000e-01, - -4.81880000e-01, -6.31100000e-01, 3.40000000e-04, -1.35730000e-01,5.88210000e-01, 7.12900000e-02,2.25200000e-01, 4.30600000e-02, 9.08510000e-01, -3.08940000e-01,1.51570000e-01, 6.02100000e-02, - 1.97510000e-01, -7.26560000e-01, 1.05370000e-01, 1.10600000e-02,-5.79750000e-01, -2.92870000e-01,4.89620000e-01, -2.24300000e-01, 5.31200000e-02, 6.92040000e-01,2.72560000e-01, 3.92350000e-01, - -6.84450000e-01, -5.18030000e-01, 2.92000000e-02, -4.96740000e-01,-1.17970000e-01, -4.08100000e-02,4.25340000e-01, -1.65500000e-02, -2.82400000e-02, -5.60180000e-01,1.93050000e-01, -6.83340000e-01, - 8.08800000e-02, 4.38260000e-01, -2.48340000e-01, -6.36220000e-01,2.37500000e-02, 5.78250000e-01,-6.10000000e-04, 3.00110000e-01, 1.17290000e-01, -6.92400000e-02,-9.19220000e-01, -2.15420000e-01, - 5.41330000e-01, -6.61130000e-01, -2.86360000e-01, -2.13500000e-02,-3.19580000e-01, 2.92020000e-01,2.25920000e-01, -1.10170000e-01, 9.17020000e-01, -1.71540000e-01,3.39100000e-02, 2.55590000e-01, - -4.86810000e-01, -2.32390000e-01, -4.31500000e-01, 3.75290000e-01,4.98470000e-01, -3.65370000e-01,6.39700000e-02, -4.04150000e-01, -5.28310000e-01, 8.90000000e-02,-7.30460000e-01, -1.09390000e-01, - -4.94030000e-01, 1.55540000e-01, -3.46720000e-01, -7.58460000e-01,5.20000000e-04, 1.90420000e-01,2.55960000e-01, 3.17040000e-01, -3.47800000e-02, -3.01860000e-01,-3.57600000e-02, -8.60450000e-01, - 1.31650000e-01, 7.57150000e-01, -4.89030000e-01, 3.47710000e-01,-4.39400000e-02, 2.17750000e-01,-6.57270000e-01, 2.91000000e-01, 4.17280000e-01, 2.52880000e-01,-4.63400000e-01, -1.74620000e-01}); - - sd::ops::svd op; - auto result = op.evaluate({&x}, {}, {1, 1, 16}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto *s = result.at(0); - auto *u = result.at(1); - auto *v = result.at(2); - - ASSERT_TRUE(expS.isSameShape(s)); - ASSERT_TRUE(expU.isSameShape(u)); - ASSERT_TRUE(expV.isSameShape(v)); - - ASSERT_TRUE(expS.equalsTo(s)); - - if(sd::Environment::getInstance().isCPU()) { - ASSERT_TRUE(expU.equalsTo(u)); - ASSERT_TRUE(expV.equalsTo(v)); - } - else { - for(uint i = 0; i < expU.lengthOf(); ++i) - ASSERT_NEAR(sd::math::nd4j_abs(expU.e(i)), sd::math::nd4j_abs(u->e(i)), 1e-5); - for(uint i = 0; i < expV.lengthOf(); ++i) - ASSERT_NEAR(sd::math::nd4j_abs(expV.e(i)), sd::math::nd4j_abs(v->e(i)), 1e-5); - } - + auto x = NDArrayFactory::create( + 'c', {2, 2, 5, 6}, + {17, -11, 20, -10, 19, 13, -18, 6, -2, -6, -10, 4, -6, -4, 3, + 16, 12, -15, 8, -8, 12, -1, 20, 19, -13, 0, 20, 17, -8, 16, + -19, 7, -16, -14, -5, 7, 7, -5, 12, -15, 7, 8, 1, -8, -17, + 10, -11, 8, -10, 1, -6, 10, 15, 19, -15, 8, 2, 8, 12, 7, + -5, 1, 8, 4, -13, 2, 19, -2, -10, -8, 11, 1, 20, -11, 4, + 1, -17, -15, 0, -9, -4, -1, -6, -9, -13, 10, 7, -2, 15, -10, + -1, 11, -20, -2, -1, -18, 12, 16, 8, -9, -20, -7, -20, 3, -9, + 12, 8, -19, -2, 2, 1, 7, 10, -18, 13, 6, 14, 0, 19, 8}); + + auto expS = NDArrayFactory::create( + 'c', {2, 2, 5}, + {50.46507, 35.75599, 28.12787, 12.45245, 9.08545, 38.56035, 30.62846, + 26.31646, 19.42605, 3.01162, 38.56369, 29.18881, 19.54565, 10.89746, + 2.017, 44.99108, 34.95059, 26.00453, 15.43898, 7.18752}); + + auto expU = NDArrayFactory::create( + 'c', {2, 2, 5, 5}, + {-0.73644, -0.10751, 0.10081, -0.00325, 0.66025, 0.26329, 0.3079, + 0.38582, 0.77696, 0.28872, 0.03076, 0.03015, -0.9128, 0.36387, + 0.18039, -0.61335, 0.10076, 0.01381, 0.40922, -0.66783, -0.10577, + 0.93946, -0.0871, -0.31058, 0.04677, 0.52823, 0.31163, -0.78777, + 0.02322, -0.05234, -0.23942, -0.45801, -0.34248, 0.71286, 0.32778, + 0.26147, 0.60409, 0.39933, 0.46862, 0.43318, 0.62118, -0.37993, + 0.30992, 0.34537, -0.50444, 0.45763, -0.42877, 0.08128, -0.3904, + 0.66912, -0.05428, 0.53632, 0.19774, -0.32198, 0.75276, -0.21986, + -0.8214, -0.00392, -0.1659, 0.49944, -0.79443, 0.1633, -0.45374, + -0.31666, -0.18989, -0.24459, 0.10463, -0.27652, 0.85595, 0.34657, + 0.50772, 0.00757, -0.82374, -0.18941, 0.16658, 0.49473, -0.39923, + -0.20758, 0.74339, -0.01213, -0.2024, -0.80239, -0.35502, -0.3982, + -0.17492, 0.68875, 0.1822, -0.08046, -0.39238, -0.57619, 0.34555, + 0.12488, -0.50703, -0.29269, 0.72267, -0.34713, 0.3847, -0.7532, + 0.22176, -0.33913}); + + auto expV = NDArrayFactory::create( + 'c', {2, 2, 6, 6}, + {-4.15640000e-01, -5.30190000e-01, 5.29200000e-02, -7.15710000e-01, + -1.10690000e-01, 1.37280000e-01, 2.86620000e-01, 5.88200000e-02, + 1.68760000e-01, -2.55000000e-03, -1.00090000e-01, 9.35890000e-01, + -4.88230000e-01, 4.84470000e-01, -1.09150000e-01, -1.46810000e-01, + 6.70320000e-01, 2.10040000e-01, 1.00910000e-01, 4.35740000e-01, + -6.90500000e-01, -3.61090000e-01, -4.38680000e-01, 1.83200000e-02, + -5.48440000e-01, -2.86950000e-01, -4.23900000e-01, 5.78540000e-01, + -2.10060000e-01, 2.41550000e-01, -4.42450000e-01, 4.56640000e-01, + 5.48020000e-01, 3.32100000e-02, -5.40210000e-01, -4.97000000e-02, + -6.36070000e-01, 5.57600000e-02, 3.28740000e-01, 3.81950000e-01, + -4.21850000e-01, 4.00490000e-01, 1.83740000e-01, -1.36190000e-01, + -2.29380000e-01, -5.11090000e-01, -2.06580000e-01, 7.68890000e-01, + -4.81880000e-01, -6.31100000e-01, 3.40000000e-04, -1.35730000e-01, + 5.88210000e-01, 7.12900000e-02, 2.25200000e-01, 4.30600000e-02, + 9.08510000e-01, -3.08940000e-01, 1.51570000e-01, 6.02100000e-02, + 1.97510000e-01, -7.26560000e-01, 1.05370000e-01, 1.10600000e-02, + -5.79750000e-01, -2.92870000e-01, 4.89620000e-01, -2.24300000e-01, + 5.31200000e-02, 6.92040000e-01, 2.72560000e-01, 3.92350000e-01, + -6.84450000e-01, -5.18030000e-01, 2.92000000e-02, -4.96740000e-01, + -1.17970000e-01, -4.08100000e-02, 4.25340000e-01, -1.65500000e-02, + -2.82400000e-02, -5.60180000e-01, 1.93050000e-01, -6.83340000e-01, + 8.08800000e-02, 4.38260000e-01, -2.48340000e-01, -6.36220000e-01, + 2.37500000e-02, 5.78250000e-01, -6.10000000e-04, 3.00110000e-01, + 1.17290000e-01, -6.92400000e-02, -9.19220000e-01, -2.15420000e-01, + 5.41330000e-01, -6.61130000e-01, -2.86360000e-01, -2.13500000e-02, + -3.19580000e-01, 2.92020000e-01, 2.25920000e-01, -1.10170000e-01, + 9.17020000e-01, -1.71540000e-01, 3.39100000e-02, 2.55590000e-01, + -4.86810000e-01, -2.32390000e-01, -4.31500000e-01, 3.75290000e-01, + 4.98470000e-01, -3.65370000e-01, 6.39700000e-02, -4.04150000e-01, + -5.28310000e-01, 8.90000000e-02, -7.30460000e-01, -1.09390000e-01, + -4.94030000e-01, 1.55540000e-01, -3.46720000e-01, -7.58460000e-01, + 5.20000000e-04, 1.90420000e-01, 2.55960000e-01, 3.17040000e-01, + -3.47800000e-02, -3.01860000e-01, -3.57600000e-02, -8.60450000e-01, + 1.31650000e-01, 7.57150000e-01, -4.89030000e-01, 3.47710000e-01, + -4.39400000e-02, 2.17750000e-01, -6.57270000e-01, 2.91000000e-01, + 4.17280000e-01, 2.52880000e-01, -4.63400000e-01, -1.74620000e-01}); + + sd::ops::svd op; + auto result = op.evaluate({&x}, {}, {1, 1, 16}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto s = result.at(0); + auto u = result.at(1); + auto v = result.at(2); + + ASSERT_TRUE(expS.isSameShape(s)); + ASSERT_TRUE(expU.isSameShape(u)); + ASSERT_TRUE(expV.isSameShape(v)); + + ASSERT_TRUE(expS.equalsTo(s)); + + if (sd::Environment::getInstance().isCPU()) { + ASSERT_TRUE(expU.equalsTo(u)); + ASSERT_TRUE(expV.equalsTo(v)); + } else { + for (uint i = 0; i < expU.lengthOf(); ++i) + ASSERT_NEAR(sd::math::nd4j_abs(expU.e(i)), + sd::math::nd4j_abs(u.e(i)), 1e-5); + for (uint i = 0; i < expV.lengthOf(); ++i) + ASSERT_NEAR(sd::math::nd4j_abs(expV.e(i)), + sd::math::nd4j_abs(v.e(i)), 1e-5); + } } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, svd_test10) { - - auto x= NDArrayFactory::create('c', {2,2,5,6}, {17 ,-11 ,20 ,-10 ,19 ,13 ,-18 ,6 ,-2 ,-6 ,-10 ,4 ,-6 ,-4 ,3 ,16 ,12 , - -15 ,8 ,-8 ,12 ,-1 ,20 ,19 ,-13 ,0 ,20 ,17 ,-8 ,16 ,-19 ,7 ,-16 ,-14 ,-5 ,7 ,7 ,-5 ,12 ,-15 ,7 ,8 , - 1 ,-8 ,-17 ,10 ,-11 ,8 ,-10 ,1 ,-6 ,10 ,15 ,19 ,-15 ,8 ,2 ,8 ,12 ,7 ,-5 ,1 ,8 ,4 ,-13 ,2 ,19 ,-2 ,-10 , - -8 ,11 ,1 ,20 ,-11 ,4 ,1 ,-17 ,-15 ,0 ,-9 ,-4 ,-1 ,-6 ,-9 ,-13 ,10 ,7 ,-2 ,15 ,-10 ,-1 ,11 ,-20 ,-2 , - -1 ,-18 ,12 ,16 ,8 ,-9 ,-20 ,-7 ,-20 ,3 ,-9 ,12 ,8 ,-19 ,-2 ,2 ,1 ,7 ,10 ,-18 ,13 ,6 ,14 ,0 ,19 ,8}); - - auto expS= NDArrayFactory::create('c', {2,2,5}, {50.46507, 35.75599, 28.12787, 12.45245, 9.08545, - 38.56035, 30.62846, 26.31646, 19.42605, 3.01162, - 38.56369, 29.18881, 19.54565, 10.89746, 2.017 , - 44.99108, 34.95059, 26.00453, 15.43898, 7.18752}); - - auto expU= NDArrayFactory::create('c', {2,2,5,5}, {-0.73644, -0.10751, 0.10081, -0.00325, 0.66025,0.26329, 0.3079 , 0.38582, 0.77696, 0.28872,0.03076, 0.03015, -0.9128 , 0.36387, 0.18039,-0.61335, 0.10076, 0.01381, 0.40922, -0.66783, - -0.10577, 0.93946, -0.0871 , -0.31058, 0.04677,0.52823, 0.31163, -0.78777, 0.02322, -0.05234,-0.23942, -0.45801, -0.34248, 0.71286, 0.32778,0.26147, 0.60409, 0.39933, 0.46862, 0.43318, - 0.62118, -0.37993, 0.30992, 0.34537, -0.50444,0.45763, -0.42877, 0.08128, -0.3904 , 0.66912,-0.05428, 0.53632, 0.19774, -0.32198, 0.75276,-0.21986, -0.8214 , -0.00392, -0.1659 , 0.49944, - -0.79443, 0.1633 , -0.45374, -0.31666, -0.18989,-0.24459, 0.10463, -0.27652, 0.85595, 0.34657,0.50772, 0.00757, -0.82374, -0.18941, 0.16658,0.49473, -0.39923, -0.20758, 0.74339, -0.01213, - -0.2024 , -0.80239, -0.35502, -0.3982 , -0.17492,0.68875, 0.1822 , -0.08046, -0.39238, -0.57619,0.34555, 0.12488, -0.50703, -0.29269, 0.72267,-0.34713, 0.3847 , -0.7532 , 0.22176, -0.33913}); - - auto expV= NDArrayFactory::create('c', {2,2,6,5}, { -4.15640000e-01, -5.30190000e-01, 5.29200000e-02, -7.15710000e-01,-1.10690000e-01,2.86620000e-01, 5.88200000e-02, 1.68760000e-01, -2.55000000e-03,-1.00090000e-01, - -4.88230000e-01, 4.84470000e-01, -1.09150000e-01, -1.46810000e-01,6.70320000e-01,1.00910000e-01, 4.35740000e-01, -6.90500000e-01, -3.61090000e-01,-4.38680000e-01,-5.48440000e-01, -2.86950000e-01, -4.23900000e-01, 5.78540000e-01, - -2.10060000e-01,-4.42450000e-01, 4.56640000e-01, 5.48020000e-01, 3.32100000e-02,-5.40210000e-01,-6.36070000e-01, 5.57600000e-02, 3.28740000e-01, 3.81950000e-01,-4.21850000e-01, - 1.83740000e-01, -1.36190000e-01, -2.29380000e-01, -5.11090000e-01,-2.06580000e-01,-4.81880000e-01, -6.31100000e-01, 3.40000000e-04, -1.35730000e-01,5.88210000e-01,2.25200000e-01, 4.30600000e-02, 9.08510000e-01, -3.08940000e-01, - 1.51570000e-01,1.97510000e-01, -7.26560000e-01, 1.05370000e-01, 1.10600000e-02,-5.79750000e-01,4.89620000e-01, -2.24300000e-01, 5.31200000e-02, 6.92040000e-01,2.72560000e-01, - -6.84450000e-01, -5.18030000e-01, 2.92000000e-02, -4.96740000e-01,-1.17970000e-01,4.25340000e-01, -1.65500000e-02, -2.82400000e-02, -5.60180000e-01,1.93050000e-01,8.08800000e-02, 4.38260000e-01, -2.48340000e-01, -6.36220000e-01,2.37500000e-02,-6.10000000e-04, 3.00110000e-01, 1.17290000e-01, -6.92400000e-02,-9.19220000e-01, - 5.41330000e-01, -6.61130000e-01, -2.86360000e-01, -2.13500000e-02,-3.19580000e-01,2.25920000e-01, -1.10170000e-01, 9.17020000e-01, -1.71540000e-01,3.39100000e-02,-4.86810000e-01, -2.32390000e-01, -4.31500000e-01, 3.75290000e-01,4.98470000e-01,6.39700000e-02, -4.04150000e-01, -5.28310000e-01, 8.90000000e-02,-7.30460000e-01, - -4.94030000e-01, 1.55540000e-01, -3.46720000e-01, -7.58460000e-01,5.20000000e-04,2.55960000e-01, 3.17040000e-01, -3.47800000e-02, -3.01860000e-01,-3.57600000e-02,1.31650000e-01, 7.57150000e-01, -4.89030000e-01, 3.47710000e-01, - -4.39400000e-02,-6.57270000e-01, 2.91000000e-01, 4.17280000e-01, 2.52880000e-01,-4.63400000e-01}); - - sd::ops::svd op; - auto result = op.evaluate({&x}, {}, {0, 1, 16}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto *s = result.at(0); - auto *u = result.at(1); - auto *v = result.at(2); - - ASSERT_TRUE(expS.isSameShape(s)); - ASSERT_TRUE(expU.isSameShape(u)); - ASSERT_TRUE(expV.isSameShape(v)); - - ASSERT_TRUE(expS.equalsTo(s)); - - if(sd::Environment::getInstance().isCPU()) { - ASSERT_TRUE(expU.equalsTo(u)); - ASSERT_TRUE(expV.equalsTo(v)); - } - else { - for(uint i = 0; i < expU.lengthOf(); ++i) - ASSERT_NEAR(sd::math::nd4j_abs(expU.e(i)), sd::math::nd4j_abs(u->e(i)), 1e-5); - for(uint i = 0; i < expV.lengthOf(); ++i) - ASSERT_NEAR(sd::math::nd4j_abs(expV.e(i)), sd::math::nd4j_abs(v->e(i)), 1e-5); - } - + auto x = NDArrayFactory::create( + 'c', {2, 2, 5, 6}, + {17, -11, 20, -10, 19, 13, -18, 6, -2, -6, -10, 4, -6, -4, 3, + 16, 12, -15, 8, -8, 12, -1, 20, 19, -13, 0, 20, 17, -8, 16, + -19, 7, -16, -14, -5, 7, 7, -5, 12, -15, 7, 8, 1, -8, -17, + 10, -11, 8, -10, 1, -6, 10, 15, 19, -15, 8, 2, 8, 12, 7, + -5, 1, 8, 4, -13, 2, 19, -2, -10, -8, 11, 1, 20, -11, 4, + 1, -17, -15, 0, -9, -4, -1, -6, -9, -13, 10, 7, -2, 15, -10, + -1, 11, -20, -2, -1, -18, 12, 16, 8, -9, -20, -7, -20, 3, -9, + 12, 8, -19, -2, 2, 1, 7, 10, -18, 13, 6, 14, 0, 19, 8}); + + auto expS = NDArrayFactory::create( + 'c', {2, 2, 5}, + {50.46507, 35.75599, 28.12787, 12.45245, 9.08545, 38.56035, 30.62846, + 26.31646, 19.42605, 3.01162, 38.56369, 29.18881, 19.54565, 10.89746, + 2.017, 44.99108, 34.95059, 26.00453, 15.43898, 7.18752}); + + auto expU = NDArrayFactory::create( + 'c', {2, 2, 5, 5}, + {-0.73644, -0.10751, 0.10081, -0.00325, 0.66025, 0.26329, 0.3079, + 0.38582, 0.77696, 0.28872, 0.03076, 0.03015, -0.9128, 0.36387, + 0.18039, -0.61335, 0.10076, 0.01381, 0.40922, -0.66783, -0.10577, + 0.93946, -0.0871, -0.31058, 0.04677, 0.52823, 0.31163, -0.78777, + 0.02322, -0.05234, -0.23942, -0.45801, -0.34248, 0.71286, 0.32778, + 0.26147, 0.60409, 0.39933, 0.46862, 0.43318, 0.62118, -0.37993, + 0.30992, 0.34537, -0.50444, 0.45763, -0.42877, 0.08128, -0.3904, + 0.66912, -0.05428, 0.53632, 0.19774, -0.32198, 0.75276, -0.21986, + -0.8214, -0.00392, -0.1659, 0.49944, -0.79443, 0.1633, -0.45374, + -0.31666, -0.18989, -0.24459, 0.10463, -0.27652, 0.85595, 0.34657, + 0.50772, 0.00757, -0.82374, -0.18941, 0.16658, 0.49473, -0.39923, + -0.20758, 0.74339, -0.01213, -0.2024, -0.80239, -0.35502, -0.3982, + -0.17492, 0.68875, 0.1822, -0.08046, -0.39238, -0.57619, 0.34555, + 0.12488, -0.50703, -0.29269, 0.72267, -0.34713, 0.3847, -0.7532, + 0.22176, -0.33913}); + + auto expV = NDArrayFactory::create( + 'c', {2, 2, 6, 5}, + {-4.15640000e-01, -5.30190000e-01, 5.29200000e-02, -7.15710000e-01, + -1.10690000e-01, 2.86620000e-01, 5.88200000e-02, 1.68760000e-01, + -2.55000000e-03, -1.00090000e-01, -4.88230000e-01, 4.84470000e-01, + -1.09150000e-01, -1.46810000e-01, 6.70320000e-01, 1.00910000e-01, + 4.35740000e-01, -6.90500000e-01, -3.61090000e-01, -4.38680000e-01, + -5.48440000e-01, -2.86950000e-01, -4.23900000e-01, 5.78540000e-01, + -2.10060000e-01, -4.42450000e-01, 4.56640000e-01, 5.48020000e-01, + 3.32100000e-02, -5.40210000e-01, -6.36070000e-01, 5.57600000e-02, + 3.28740000e-01, 3.81950000e-01, -4.21850000e-01, 1.83740000e-01, + -1.36190000e-01, -2.29380000e-01, -5.11090000e-01, -2.06580000e-01, + -4.81880000e-01, -6.31100000e-01, 3.40000000e-04, -1.35730000e-01, + 5.88210000e-01, 2.25200000e-01, 4.30600000e-02, 9.08510000e-01, + -3.08940000e-01, 1.51570000e-01, 1.97510000e-01, -7.26560000e-01, + 1.05370000e-01, 1.10600000e-02, -5.79750000e-01, 4.89620000e-01, + -2.24300000e-01, 5.31200000e-02, 6.92040000e-01, 2.72560000e-01, + -6.84450000e-01, -5.18030000e-01, 2.92000000e-02, -4.96740000e-01, + -1.17970000e-01, 4.25340000e-01, -1.65500000e-02, -2.82400000e-02, + -5.60180000e-01, 1.93050000e-01, 8.08800000e-02, 4.38260000e-01, + -2.48340000e-01, -6.36220000e-01, 2.37500000e-02, -6.10000000e-04, + 3.00110000e-01, 1.17290000e-01, -6.92400000e-02, -9.19220000e-01, + 5.41330000e-01, -6.61130000e-01, -2.86360000e-01, -2.13500000e-02, + -3.19580000e-01, 2.25920000e-01, -1.10170000e-01, 9.17020000e-01, + -1.71540000e-01, 3.39100000e-02, -4.86810000e-01, -2.32390000e-01, + -4.31500000e-01, 3.75290000e-01, 4.98470000e-01, 6.39700000e-02, + -4.04150000e-01, -5.28310000e-01, 8.90000000e-02, -7.30460000e-01, + -4.94030000e-01, 1.55540000e-01, -3.46720000e-01, -7.58460000e-01, + 5.20000000e-04, 2.55960000e-01, 3.17040000e-01, -3.47800000e-02, + -3.01860000e-01, -3.57600000e-02, 1.31650000e-01, 7.57150000e-01, + -4.89030000e-01, 3.47710000e-01, -4.39400000e-02, -6.57270000e-01, + 2.91000000e-01, 4.17280000e-01, 2.52880000e-01, -4.63400000e-01}); + + sd::ops::svd op; + auto result = op.evaluate({&x}, {}, {0, 1, 16}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto s = result.at(0); + auto u = result.at(1); + auto v = result.at(2); + + ASSERT_TRUE(expS.isSameShape(s)); + ASSERT_TRUE(expU.isSameShape(u)); + ASSERT_TRUE(expV.isSameShape(v)); + + ASSERT_TRUE(expS.equalsTo(s)); + + if (sd::Environment::getInstance().isCPU()) { + ASSERT_TRUE(expU.equalsTo(u)); + ASSERT_TRUE(expV.equalsTo(v)); + } else { + for (uint i = 0; i < expU.lengthOf(); ++i) + ASSERT_NEAR(sd::math::nd4j_abs(expU.e(i)), + sd::math::nd4j_abs(u.e(i)), 1e-5); + for (uint i = 0; i < expV.lengthOf(); ++i) + ASSERT_NEAR(sd::math::nd4j_abs(expV.e(i)), + sd::math::nd4j_abs(v.e(i)), 1e-5); + } } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, svd_test11) { - - NDArray x('c', {2,2,3,3}, {0.7788, 0.8012, 0.7244, 0.2309, 0.7271, 0.1804, 0.5056, 0.8925, 0.5461, 0.9234, 0.0856, 0.7938, 0.6591, 0.5555, - 0.1596, 0.3087, 0.1548, 0.4695, 0.7788, 0.8012, 0.7244, 0.2309, 0.7271, 0.1804, 0.5056, 0.8925, -0.5461, 0.9234, - 0.0856, -0.7938, 0.6591, 0.5555, 0.1500, 0.3087, 0.1548, 0.4695}); - NDArray expS('c', {2,2,3}, {1.89671, 0.37095, 0.05525,1.51296, 0.52741, 0.17622, 1.69095, 0.90438, 0.24688,1.33551, 0.87475, 0.21571}); - NDArray expU('c', {2,2,3,3}, {6.9205e-01, 6.0147e-01, -3.9914e-01, 3.8423e-01, -7.7503e-01, -5.0170e-01, 6.1110e-01, -1.9384e-01, 7.6746e-01, - 7.8967e-01, 4.5442e-01, -4.1222e-01, 4.9381e-01, -8.6948e-01, -1.2540e-02, 3.6412e-01, 1.9366e-01, 9.1100e-01, - 7.1764e-01, 5.9844e-01, 3.5617e-01, 4.4477e-01, -3.1000e-04, -8.9564e-01, 5.3588e-01, -8.0116e-01, 2.6639e-01, - 8.7050e-01, -4.2088e-01, -2.5513e-01, 4.8622e-01, 6.5499e-01, 5.7843e-01, 7.6340e-02, 6.2757e-01, -7.7481e-01}); - NDArray expV('c', {2,2,3,3}, {0.49383, 0.51614, -0.69981, 0.72718, -0.68641, 0.00688, 0.4768 , 0.51228, 0.7143 , 0.77137, -0.17763, - -0.6111 , 0.26324, -0.7852 , 0.56051, 0.57939, 0.59322, 0.55892, 0.55149, 0.06737, 0.83146, 0.81413, - -0.26072, -0.51887, 0.18182, 0.96306, -0.19863, 0.85948, 0.2707 , -0.4336 , 0.26688, 0.48582, 0.83232, - -0.43596, 0.83108, -0.34531}); - - sd::ops::svd op; - auto result = op.evaluate({&x}, {}, {0, 1, 16}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto s = result.at(0); - auto u = result.at(1); - auto v = result.at(2); - - ASSERT_TRUE(expS.isSameShape(s)); - ASSERT_TRUE(expU.isSameShape(u)); - ASSERT_TRUE(expV.isSameShape(v)); - - ASSERT_TRUE(expS.equalsTo(s)); - - if(sd::Environment::getInstance().isCPU()) { - ASSERT_TRUE(expU.equalsTo(u)); - ASSERT_TRUE(expV.equalsTo(v)); - } - else { - for(uint i = 0; i < expU.lengthOf(); ++i) - ASSERT_NEAR(sd::math::nd4j_abs(expU.e(i)), sd::math::nd4j_abs(u->e(i)), 1e-5); - for(uint i = 0; i < expV.lengthOf(); ++i) - ASSERT_NEAR(sd::math::nd4j_abs(expV.e(i)), sd::math::nd4j_abs(v->e(i)), 1e-5); - } - + NDArray x('c', {2, 2, 3, 3}, + {0.7788, 0.8012, 0.7244, 0.2309, 0.7271, 0.1804, 0.5056, 0.8925, + 0.5461, 0.9234, 0.0856, 0.7938, 0.6591, 0.5555, 0.1596, 0.3087, + 0.1548, 0.4695, 0.7788, 0.8012, 0.7244, 0.2309, 0.7271, 0.1804, + 0.5056, 0.8925, -0.5461, 0.9234, 0.0856, -0.7938, 0.6591, 0.5555, + 0.1500, 0.3087, 0.1548, 0.4695}); + NDArray expS('c', {2, 2, 3}, + {1.89671, 0.37095, 0.05525, 1.51296, 0.52741, 0.17622, 1.69095, + 0.90438, 0.24688, 1.33551, 0.87475, 0.21571}); + NDArray expU('c', {2, 2, 3, 3}, + {6.9205e-01, 6.0147e-01, -3.9914e-01, 3.8423e-01, -7.7503e-01, + -5.0170e-01, 6.1110e-01, -1.9384e-01, 7.6746e-01, 7.8967e-01, + 4.5442e-01, -4.1222e-01, 4.9381e-01, -8.6948e-01, -1.2540e-02, + 3.6412e-01, 1.9366e-01, 9.1100e-01, 7.1764e-01, 5.9844e-01, + 3.5617e-01, 4.4477e-01, -3.1000e-04, -8.9564e-01, 5.3588e-01, + -8.0116e-01, 2.6639e-01, 8.7050e-01, -4.2088e-01, -2.5513e-01, + 4.8622e-01, 6.5499e-01, 5.7843e-01, 7.6340e-02, 6.2757e-01, + -7.7481e-01}); + NDArray expV('c', {2, 2, 3, 3}, + {0.49383, 0.51614, -0.69981, 0.72718, -0.68641, 0.00688, + 0.4768, 0.51228, 0.7143, 0.77137, -0.17763, -0.6111, + 0.26324, -0.7852, 0.56051, 0.57939, 0.59322, 0.55892, + 0.55149, 0.06737, 0.83146, 0.81413, -0.26072, -0.51887, + 0.18182, 0.96306, -0.19863, 0.85948, 0.2707, -0.4336, + 0.26688, 0.48582, 0.83232, -0.43596, 0.83108, -0.34531}); + + sd::ops::svd op; + auto result = op.evaluate({&x}, {}, {0, 1, 16}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto s = result.at(0); + auto u = result.at(1); + auto v = result.at(2); + + ASSERT_TRUE(expS.isSameShape(s)); + ASSERT_TRUE(expU.isSameShape(u)); + ASSERT_TRUE(expV.isSameShape(v)); + + ASSERT_TRUE(expS.equalsTo(s)); + + if (sd::Environment::getInstance().isCPU()) { + ASSERT_TRUE(expU.equalsTo(u)); + ASSERT_TRUE(expV.equalsTo(v)); + } else { + for (uint i = 0; i < expU.lengthOf(); ++i) + ASSERT_NEAR(sd::math::nd4j_abs(expU.e(i)), + sd::math::nd4j_abs(u.e(i)), 1e-5); + for (uint i = 0; i < expV.lengthOf(); ++i) + ASSERT_NEAR(sd::math::nd4j_abs(expV.e(i)), + sd::math::nd4j_abs(v.e(i)), 1e-5); + } } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, svd_test12) { + NDArray x( + 'c', {4, 3}, + {1.7787856, 0.80119777, 0.72437465, 0.23089433, 1.7271413, 0.18039072, + 0.50563407, 0.89252293, 1.5461209, 0.92336726, 0.085571885, 0.79378015}); + NDArray expS('c', {3}, {3.024703, 1.459483, 1.026371}); - NDArray x('c', {4,3}, {1.7787856,0.80119777,0.72437465,0.23089433,1.7271413,0.18039072,0.50563407,0.89252293,1.5461209,0.92336726,0.085571885,0.79378015}); - NDArray expS('c', {3}, {3.024703, 1.459483, 1.026371}); - - - sd::ops::svd op; - auto result = op.evaluate({&x}, {}, {1, 0, 16}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::svd op; + auto result = op.evaluate({&x}, {}, {1, 0, 16}); - auto s = result.at(0); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(expS.equalsTo(s)); - ASSERT_TRUE(expS.isSameShape(s)); + auto s = result.at(0); + ASSERT_TRUE(expS.equalsTo(s)); + ASSERT_TRUE(expS.isSameShape(s)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, elu_test1) { + auto x = NDArrayFactory::create( + 'c', {3, 3}, {0.1, .2, .3, -.4, -.5, -.6, .7, .8, .9}); + auto exp = NDArrayFactory::create( + 'c', {3, 3}, + {.1, .2, .3, 0.5 * -0.32968, 0.5 * -0.393469, 0.5 * -0.451188, .7, .8, + .9}); - auto x = NDArrayFactory::create('c', {3,3}, {0.1, .2, .3, -.4,-.5,-.6, .7, .8, .9}); - auto exp = NDArrayFactory::create('c', {3,3}, {.1, .2, .3, 0.5*-0.32968, 0.5*-0.393469, 0.5*-0.451188, .7, .8, .9}); + sd::ops::elu op; + auto result = op.evaluate({&x}, {0.5}, {}); - sd::ops::elu op; - auto result = op.evaluate({&x}, {0.5}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto s = result.at(0); - ASSERT_TRUE(exp.equalsTo(s)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto s = result.at(0); + ASSERT_TRUE(exp.equalsTo(s)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, elu_bp_test1) { + auto x = NDArrayFactory::create( + 'c', {3, 3}, {0.1, .2, .3, -.4, -.5, -.6, .7, .8, .9}); + auto eps = NDArrayFactory::create('c', {3, 3}); + eps.assign(2.); + auto exp = NDArrayFactory::create( + 'c', {3, 3}, + {2, 2, 2, 0.5 * 1.34064, 0.5 * 1.213061, 0.5 * 1.097623, 2, 2, 2}); - auto x = NDArrayFactory::create('c', {3, 3}, {0.1, .2, .3, -.4, -.5, -.6, .7, .8, .9}); - auto eps = NDArrayFactory::create('c', {3,3}); - eps.assign(2.); - auto exp = NDArrayFactory::create('c', {3, 3}, {2, 2, 2, 0.5*1.34064, 0.5*1.213061, 0.5*1.097623, 2, 2, 2}); - - sd::ops::elu_bp op; - auto result = op.evaluate({ &x, &eps }, {0.5}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::elu_bp op; + auto result = op.evaluate({&x, &eps}, {0.5}, {}); - auto s = result.at(0); - ASSERT_TRUE(exp.equalsTo(s)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto s = result.at(0); + ASSERT_TRUE(exp.equalsTo(s)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, lrelu_test1) { + auto x = NDArrayFactory::create('c', {3, 3}, + {1, 2, 3, -4, -5, -6, 7, 8, 9}); + auto exp = NDArrayFactory::create( + 'c', {3, 3}, {1, 2, 3, -0.8, -1., -1.2, 7, 8, 9}); - auto x = NDArrayFactory::create('c', {3,3}, {1, 2, 3, -4,-5,-6, 7, 8, 9}); - auto exp = NDArrayFactory::create('c', {3,3}, {1, 2, 3, -0.8, -1., -1.2, 7, 8, 9}); + sd::ops::lrelu op; + auto result = op.evaluate({&x}, {0.2}, {}); - sd::ops::lrelu op; - auto result = op.evaluate({&x}, {0.2}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto s = result.at(0); - ASSERT_TRUE(exp.equalsTo(s)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto s = result.at(0); + ASSERT_TRUE(exp.equalsTo(s)); } TEST_F(DeclarableOpsTests3, lrelu_bp_test1) { + auto x = NDArrayFactory::create('c', {3, 3}, + {1, 2, 3, -4, -5, -6, 7, 8, 9}); + auto eps = + NDArrayFactory::create('c', {3, 3}, {2, 2, 2, 2, 2, 2, 2, 2, 2}); + auto exp = NDArrayFactory::create('c', {3, 3}, + {2, 2, 2, 0.4, 0.4, 0.4, 2, 2, 2}); - auto x = NDArrayFactory::create('c', {3,3}, {1, 2, 3, -4,-5,-6, 7, 8, 9}); - auto eps = NDArrayFactory::create('c', {3,3}, {2,2,2,2,2,2,2, 2,2}); - auto exp = NDArrayFactory::create('c', {3,3}, {2, 2, 2, 0.4, 0.4, 0.4, 2, 2, 2}); - - sd::ops::lrelu_bp op; - auto result = op.evaluate({&x, &eps}, {0.2}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::lrelu_bp op; + auto result = op.evaluate({&x, &eps}, {0.2}, {}); - auto s = result.at(0); - ASSERT_TRUE(exp.equalsTo(s)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto s = result.at(0); + ASSERT_TRUE(exp.equalsTo(s)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, selu_test1) { + auto x = NDArrayFactory::create('c', {3, 3}, + {1, 2, 3, -4, -5, -6, 7, 8, 9}); + auto exp = NDArrayFactory::create( + 'c', {3, 3}, + {1.050701, 2.101402, 3.152103, -1.725899, -1.746253, -1.753742, 7.354907, + 8.405608, 9.456309}); - auto x = NDArrayFactory::create('c', {3,3}, {1, 2, 3, -4,-5,-6, 7, 8, 9}); - auto exp = NDArrayFactory::create('c', {3,3}, {1.050701, 2.101402, 3.152103, -1.725899, -1.746253, -1.753742, 7.354907, 8.405608, 9.456309}); + sd::ops::selu op; + auto result = op.evaluate({&x}, {}, {}); - sd::ops::selu op; - auto result = op.evaluate({&x}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto s = result.at(0); - ASSERT_TRUE(exp.equalsTo(s)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto s = result.at(0); + ASSERT_TRUE(exp.equalsTo(s)); } TEST_F(DeclarableOpsTests3, selu_test2) { + auto x = NDArrayFactory::create('c', {3, 3}, + {1, 2, 3, -4, -5, -6, 7, 8, 9}); + // auto expS = NDArrayFactory::create('c', {3}); + auto eps = + NDArrayFactory::create('c', {3, 3}, {2, 2, 2, 2, 2, 2, 2, 2, 2}); + auto exp = NDArrayFactory::create( + 'c', {3, 3}, + {2.101401, 2.101402, 2.101402, 0.064401, 0.023692, 0.008716, 2.101402, + 2.101402, 2.101402}); - auto x = NDArrayFactory::create('c', {3,3}, {1, 2, 3, -4,-5,-6, 7, 8, 9}); -// auto expS = NDArrayFactory::create('c', {3}); - auto eps = NDArrayFactory::create('c', {3,3}, {2,2,2,2,2,2,2, 2,2}); - auto exp = NDArrayFactory::create('c', {3,3}, {2.101401, 2.101402, 2.101402, 0.064401, 0.023692, 0.008716, 2.101402, 2.101402, 2.101402}); - - sd::ops::selu_bp op; - auto result = op.evaluate({&x, &eps}, {0.2}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::selu_bp op; + auto result = op.evaluate({&x, &eps}, {0.2}, {}); - auto s = result.at(0); -// auto u = result.at(1); -// auto v = result.at(2); -// s->printIndexedBuffer("SELU_BP"); - ASSERT_TRUE(exp.equalsTo(s)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto s = result.at(0); + // auto u = result.at(1); + // auto v = result.at(2); + // s->printIndexedBuffer("SELU_BP"); + ASSERT_TRUE(exp.equalsTo(s)); } TEST_F(DeclarableOpsTests3, EQScalarTests_1) { - Graph graph; + Graph graph; - auto x = NDArrayFactory::create(1.0f); - auto scalar = NDArrayFactory::create(1.0f); - - sd::ops::eq_scalar op; - auto res = op.verify({&x, &scalar}); - ASSERT_TRUE(res); + auto x = NDArrayFactory::create(1.0f); + auto scalar = NDArrayFactory::create(1.0f); + sd::ops::eq_scalar op; + auto res = op.verify({&x, &scalar}); + ASSERT_TRUE(res); } TEST_F(DeclarableOpsTests3, EQScalarTests_2) { - Graph graph; + Graph graph; - auto x = NDArrayFactory::create(2.0f); - auto scalar = NDArrayFactory::create(1.0f); + auto x = NDArrayFactory::create(2.0f); + auto scalar = NDArrayFactory::create(1.0f); - sd::ops::eq_scalar op; - auto res = op.verify({&x, &scalar}); - ASSERT_FALSE(res); + sd::ops::eq_scalar op; + auto res = op.verify({&x, &scalar}); + ASSERT_FALSE(res); } TEST_F(DeclarableOpsTests3, GTScalarTests_1) { - Graph graph; + Graph graph; - auto x = NDArrayFactory::create(1.0f); - auto scalar = NDArrayFactory::create(1.0f); + auto x = NDArrayFactory::create(1.0f); + auto scalar = NDArrayFactory::create(1.0f); - sd::ops::gt_scalar op; - auto res = op.verify({&x, &scalar}); - ASSERT_FALSE(res); + sd::ops::gt_scalar op; + auto res = op.verify({&x, &scalar}); + ASSERT_FALSE(res); } TEST_F(DeclarableOpsTests3, GTScalarTests_2) { - Graph graph; + Graph graph; - auto x = NDArrayFactory::create(2.0f); - auto scalar = NDArrayFactory::create(1.0f); + auto x = NDArrayFactory::create(2.0f); + auto scalar = NDArrayFactory::create(1.0f); - sd::ops::gt_scalar op; - auto res = op.verify({&x, &scalar}); - ASSERT_TRUE(res); + sd::ops::gt_scalar op; + auto res = op.verify({&x, &scalar}); + ASSERT_TRUE(res); } TEST_F(DeclarableOpsTests3, GTEScalarTests_1) { - Graph graph; + Graph graph; - auto x = NDArrayFactory::create(1.0f); - auto scalar = NDArrayFactory::create(1.0f); + auto x = NDArrayFactory::create(1.0f); + auto scalar = NDArrayFactory::create(1.0f); - sd::ops::gte_scalar op; - auto res = op.verify({&x, &scalar}); - ASSERT_TRUE(res); + sd::ops::gte_scalar op; + auto res = op.verify({&x, &scalar}); + ASSERT_TRUE(res); } TEST_F(DeclarableOpsTests3, GTEScalarTests_2) { - Graph graph; + Graph graph; - auto x = NDArrayFactory::create(2.0f); - auto scalar = NDArrayFactory::create(1.0f); + auto x = NDArrayFactory::create(2.0f); + auto scalar = NDArrayFactory::create(1.0f); - sd::ops::gte_scalar op; - auto res = op.verify({&x, &scalar}); - ASSERT_TRUE(res); + sd::ops::gte_scalar op; + auto res = op.verify({&x, &scalar}); + ASSERT_TRUE(res); } TEST_F(DeclarableOpsTests3, GTEScalarTests_3) { - Graph graph; + Graph graph; - auto x = NDArrayFactory::create(1.0f); - auto scalar = NDArrayFactory::create(2.0f); + auto x = NDArrayFactory::create(1.0f); + auto scalar = NDArrayFactory::create(2.0f); - sd::ops::gte_scalar op; - auto res = op.verify({&x, &scalar}); - ASSERT_FALSE(res); + sd::ops::gte_scalar op; + auto res = op.verify({&x, &scalar}); + ASSERT_FALSE(res); } TEST_F(DeclarableOpsTests3, LTEScalarTests_1) { - Graph graph; + Graph graph; - auto x = NDArrayFactory::create(1.0f); - auto scalar = NDArrayFactory::create(1.0f); + auto x = NDArrayFactory::create(1.0f); + auto scalar = NDArrayFactory::create(1.0f); - sd::ops::lte_scalar op; - auto res = op.verify({&x, &scalar}); - ASSERT_TRUE(res); + sd::ops::lte_scalar op; + auto res = op.verify({&x, &scalar}); + ASSERT_TRUE(res); } TEST_F(DeclarableOpsTests3, LTEScalarTests_2) { - Graph graph; + Graph graph; - auto x = NDArrayFactory::create(2.0f); - auto scalar = NDArrayFactory::create(1.0f); + auto x = NDArrayFactory::create(2.0f); + auto scalar = NDArrayFactory::create(1.0f); - sd::ops::lte_scalar op; - auto res = op.verify({&x, &scalar}); - ASSERT_FALSE(res); + sd::ops::lte_scalar op; + auto res = op.verify({&x, &scalar}); + ASSERT_FALSE(res); } TEST_F(DeclarableOpsTests3, LTEScalarTests_3) { - Graph graph; + Graph graph; - auto x = NDArrayFactory::create(1.0f); - auto scalar = NDArrayFactory::create(2.0f); + auto x = NDArrayFactory::create(1.0f); + auto scalar = NDArrayFactory::create(2.0f); - sd::ops::lte_scalar op; - auto res = op.verify({&x, &scalar}); - ASSERT_TRUE(res); + sd::ops::lte_scalar op; + auto res = op.verify({&x, &scalar}); + ASSERT_TRUE(res); } TEST_F(DeclarableOpsTests3, NEQScalarTests_1) { - Graph graph; - - auto x = NDArrayFactory::create(1.0f); - auto scalar = NDArrayFactory::create(1.0f); + Graph graph; - sd::ops::neq_scalar op; - auto res = op.verify({&x, &scalar}); - ASSERT_FALSE(res); + auto x = NDArrayFactory::create(1.0f); + auto scalar = NDArrayFactory::create(1.0f); + sd::ops::neq_scalar op; + auto res = op.verify({&x, &scalar}); + ASSERT_FALSE(res); } TEST_F(DeclarableOpsTests3, NEQScalarTests_2) { - Graph graph; + Graph graph; - auto x = NDArrayFactory::create(2.0f); - auto scalar = NDArrayFactory::create(1.0f); + auto x = NDArrayFactory::create(2.0f); + auto scalar = NDArrayFactory::create(1.0f); - sd::ops::neq_scalar op; - auto res = op.verify({&x, &scalar}); - ASSERT_TRUE(res); + sd::ops::neq_scalar op; + auto res = op.verify({&x, &scalar}); + ASSERT_TRUE(res); } TEST_F(DeclarableOpsTests3, NOOPTests_1) { - Graph graph; + Graph graph; - auto x = NDArrayFactory::create(2.0f); - auto scalar = NDArrayFactory::create(1.0f); + auto x = NDArrayFactory::create(2.0f); + auto scalar = NDArrayFactory::create(1.0f); - sd::ops::noop op; - auto res = op.evaluate({&x, &scalar}, {}, {}); - ASSERT_TRUE(res.status() == sd::Status::OK()); + sd::ops::noop op; + auto res = op.evaluate({&x, &scalar}, {}, {}); + ASSERT_TRUE(res.status() == sd::Status::OK()); } diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests4.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests4.cpp index 56e5e213a0bb..148a249e4612 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests4.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests4.cpp @@ -18,2402 +18,2595 @@ // @author raver119@gmail.com // -#include "testlayers.h" -#include -#include #include #include +#include +#include +#include "testlayers.h" using namespace sd; using namespace sd::graph; class DeclarableOpsTests4 : public testing::Test { -public: - - DeclarableOpsTests4() { - printf("\n"); - fflush(stdout); - - sd::ops::adjust_hue op0; - sd::ops::adjust_saturation op1; - } + public: + DeclarableOpsTests4() { + printf("\n"); + fflush(stdout); + + sd::ops::adjust_hue op0; + sd::ops::adjust_saturation op1; + } }; template class TypedDeclarableOpsTests4 : public testing::Test { -public: - - TypedDeclarableOpsTests4() { - printf("\n"); - fflush(stdout); - - sd::ops::adjust_hue op0; - sd::ops::adjust_saturation op1; - } + public: + TypedDeclarableOpsTests4() { + printf("\n"); + fflush(stdout); + + sd::ops::adjust_hue op0; + sd::ops::adjust_saturation op1; + } }; typedef ::testing::Types TestingTypes; -TYPED_TEST_CASE(TypedDeclarableOpsTests4, TestingTypes); +TYPED_TEST_SUITE(TypedDeclarableOpsTests4, TestingTypes); ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedDeclarableOpsTests4, avgpool2d_1) { - auto x = NDArrayFactory::create('c', {2, 4, 4, 2}); - auto exp = NDArrayFactory::create('c', {2, 2, 2, 2}, {6.f, 7.f, 10.f, 11.f, 22.f, 23.f, 26.f, 27.f, 38.f, 39.f, 42.f, 43.f, 54.f, 55.f, 58.f, 59.f}); - - x.linspace(1); - - sd::ops::avgpool2d op; - auto result = op.evaluate({&x}, {}, {2, 2, 2, 2, 0, 0, 1, 1, 1, 1, 1}); + auto x = NDArrayFactory::create('c', {2, 4, 4, 2}); + auto exp = NDArrayFactory::create( + 'c', {2, 2, 2, 2}, + {6.f, 7.f, 10.f, 11.f, 22.f, 23.f, 26.f, 27.f, 38.f, 39.f, 42.f, 43.f, + 54.f, 55.f, 58.f, 59.f}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + x.linspace(1); - auto z = result.at(0); + sd::ops::avgpool2d op; + auto result = op.evaluate({&x}, {}, {2, 2, 2, 2, 0, 0, 1, 1, 1, 1, 1}); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedDeclarableOpsTests4, avgpool2d_2) { - auto x = NDArrayFactory::create('c', {2, 4, 4, 2}); - auto exp = NDArrayFactory::create('c', {2, 2, 2, 2}, {6.f, 7.f, 10.f, 11.f, 22.f, 23.f, 26.f, 27.f, 38.f, 39.f, 42.f, 43.f, 54.f, 55.f, 58.f, 59.f}); - - x.linspace(1); - + auto x = NDArrayFactory::create('c', {2, 4, 4, 2}); + auto exp = NDArrayFactory::create( + 'c', {2, 2, 2, 2}, + {6.f, 7.f, 10.f, 11.f, 22.f, 23.f, 26.f, 27.f, 38.f, 39.f, 42.f, 43.f, + 54.f, 55.f, 58.f, 59.f}); - sd::ops::avgpool2d op; - auto result = op.evaluate({&x}, {}, {2, 2, 2, 2, 0, 0, 1, 1, 0, 1, 1}); + x.linspace(1); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::avgpool2d op; + auto result = op.evaluate({&x}, {}, {2, 2, 2, 2, 0, 0, 1, 1, 0, 1, 1}); - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedDeclarableOpsTests4, avgpool2d_3) { - auto x = NDArrayFactory::create('c', {2, 5, 5, 2}); - auto exp = NDArrayFactory::create('c', {2, 3, 3, 2}, {7.f, 8.f, 11.f, 12.f, 14.f, 15.f, 27.f, 28.f, 31.f, 32.f, 34.f, 35.f, 42.f, 43.f, 46.f, 47.f, 49.f, 50.f, 57.f, 58.f, 61.f, 62.f, 64.f, 65.f, 77.f, 78.f, 81.f, 82.f, 84.f, 85.f, 92.f, 93.f, 96.f, 97.f, 99.f, 100.f,}); - - x.linspace(1); - + auto x = NDArrayFactory::create('c', {2, 5, 5, 2}); + auto exp = NDArrayFactory::create( + 'c', {2, 3, 3, 2}, + { + 7.f, 8.f, 11.f, 12.f, 14.f, 15.f, 27.f, 28.f, 31.f, + 32.f, 34.f, 35.f, 42.f, 43.f, 46.f, 47.f, 49.f, 50.f, + 57.f, 58.f, 61.f, 62.f, 64.f, 65.f, 77.f, 78.f, 81.f, + 82.f, 84.f, 85.f, 92.f, 93.f, 96.f, 97.f, 99.f, 100.f, + }); - sd::ops::avgpool2d op; - auto result = op.evaluate({&x}, {}, {2, 2, 2, 2, 0, 0, 1, 1, 1, 0, 1}); + x.linspace(1); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::avgpool2d op; + auto result = op.evaluate({&x}, {}, {2, 2, 2, 2, 0, 0, 1, 1, 1, 0, 1}); - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedDeclarableOpsTests4, avgpool2d_4) { - auto x = NDArrayFactory::create('c', {2, 5, 5, 2}); - auto exp = NDArrayFactory::create('c', {2, 2, 2, 2}, {7.f, 8.f, 11.f, 12.f, 27.f, 28.f, 31.f, 32.f, 57.f, 58.f, 61.f, 62.f, 77.f, 78.f, 81.f, 82.f}); - - x.linspace(1); + auto x = NDArrayFactory::create('c', {2, 5, 5, 2}); + auto exp = NDArrayFactory::create( + 'c', {2, 2, 2, 2}, + {7.f, 8.f, 11.f, 12.f, 27.f, 28.f, 31.f, 32.f, 57.f, 58.f, 61.f, 62.f, + 77.f, 78.f, 81.f, 82.f}); - sd::ops::avgpool2d op; - auto result = op.evaluate({&x}, {}, {2, 2, 2, 2, 0, 0, 1, 1, 0, 1, 1}); + x.linspace(1); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::avgpool2d op; + auto result = op.evaluate({&x}, {}, {2, 2, 2, 2, 0, 0, 1, 1, 0, 1, 1}); - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedDeclarableOpsTests4, avgpool2d_5) { - auto x = NDArrayFactory::create('c', {2, 2, 5, 5}); - auto exp = NDArrayFactory::create('c', {2, 2, 3, 3}, {1.f, 2.5f, 4.5f, 8.5f, 10.f, 12.f, 18.5f, 20.f, 22.f, 26.f, 27.5f, 29.5f, 33.5f, 35.f, 37.f, 43.5f, 45.f, 47.f, 51.f, 52.5f, 54.5f, 58.5f, 60.f, 62.f, 68.5f, 70.f, 72.f, 76.f, 77.5f, 79.5f, 83.5f, 85.f, 87.f, 93.5f, 95.f, 97.f}); - - x.linspace(1); - + auto x = NDArrayFactory::create('c', {2, 2, 5, 5}); + auto exp = NDArrayFactory::create( + 'c', {2, 2, 3, 3}, + {1.f, 2.5f, 4.5f, 8.5f, 10.f, 12.f, 18.5f, 20.f, 22.f, + 26.f, 27.5f, 29.5f, 33.5f, 35.f, 37.f, 43.5f, 45.f, 47.f, + 51.f, 52.5f, 54.5f, 58.5f, 60.f, 62.f, 68.5f, 70.f, 72.f, + 76.f, 77.5f, 79.5f, 83.5f, 85.f, 87.f, 93.5f, 95.f, 97.f}); - sd::ops::avgpool2d op; - auto result = op.evaluate({&x}, {}, {2, 2, 2, 2, 1, 1, 1, 1, 0, 0, 0}); + x.linspace(1); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::avgpool2d op; + auto result = op.evaluate({&x}, {}, {2, 2, 2, 2, 1, 1, 1, 1, 0, 0, 0}); - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedDeclarableOpsTests4, avgpool2d_6) { - auto x = NDArrayFactory::create('c', {2, 2, 5, 5}); - auto exp = NDArrayFactory::create('c', {2, 2, 3, 3}, {0.25f, 1.25f, 2.25f, 4.25f, 10.f, 12.f, 9.25f, 20.f, 22.f, 6.5f, 13.75f, 14.75, 16.75f, 35.f, 37.f, 21.75f, 45.f, 47.f, 12.75f, 26.25f, 27.25f, 29.25f, 60.f, 62.f, 34.25f, 70.f, 72.f, 19.f, 38.75f, 39.75f, 41.75f, 85.f, 87.f, 46.75f, 95.f, 97.f}); - - x.linspace(1); - + auto x = NDArrayFactory::create('c', {2, 2, 5, 5}); + auto exp = NDArrayFactory::create( + 'c', {2, 2, 3, 3}, + {0.25f, 1.25f, 2.25f, 4.25f, 10.f, 12.f, 9.25f, 20.f, 22.f, + 6.5f, 13.75f, 14.75, 16.75f, 35.f, 37.f, 21.75f, 45.f, 47.f, + 12.75f, 26.25f, 27.25f, 29.25f, 60.f, 62.f, 34.25f, 70.f, 72.f, + 19.f, 38.75f, 39.75f, 41.75f, 85.f, 87.f, 46.75f, 95.f, 97.f}); - sd::ops::avgpool2d op; - auto result = op.evaluate({&x}, {}, {2, 2, 2, 2, 1, 1, 1, 1, 0, 1, 0}); + x.linspace(1); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::avgpool2d op; + auto result = op.evaluate({&x}, {}, {2, 2, 2, 2, 1, 1, 1, 1, 0, 1, 0}); - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedDeclarableOpsTests4, avgpool2d_7) { - auto x = NDArrayFactory::create('c', {2, 2, 5, 5}); - auto exp = NDArrayFactory::create('c', {2, 2, 3, 3}, {4.f, 6.f, 7.5f, 14.f, 16.f, 17.5f, 21.5f, 23.5f, 25.f, 29.f, 31.f, 32.5f, 39.f, 41.f, 42.5f, 46.5f, 48.5f, 50.f, 54.f, 56.f, 57.5f, 64.f, 66.f, 67.5f, 71.5f, 73.5f, 75.f, 79.f, 81.f, 82.5f, 89.f, 91.f, 92.5f, 96.5f, 98.5f, 100.f}); - - x.linspace(1); - + auto x = NDArrayFactory::create('c', {2, 2, 5, 5}); + auto exp = NDArrayFactory::create( + 'c', {2, 2, 3, 3}, + {4.f, 6.f, 7.5f, 14.f, 16.f, 17.5f, 21.5f, 23.5f, 25.f, + 29.f, 31.f, 32.5f, 39.f, 41.f, 42.5f, 46.5f, 48.5f, 50.f, + 54.f, 56.f, 57.5f, 64.f, 66.f, 67.5f, 71.5f, 73.5f, 75.f, + 79.f, 81.f, 82.5f, 89.f, 91.f, 92.5f, 96.5f, 98.5f, 100.f}); - sd::ops::avgpool2d op; - auto result = op.evaluate({&x}, {}, {2, 2, 2, 2, 0, 0, 1, 1, 1, 0, 0}); + x.linspace(1); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::avgpool2d op; + auto result = op.evaluate({&x}, {}, {2, 2, 2, 2, 0, 0, 1, 1, 1, 0, 0}); - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedDeclarableOpsTests4, avgpool2d_8) { - auto x = NDArrayFactory::create('c', {1, 1, 3, 3}); - auto exp = NDArrayFactory::create('c', {1, 1, 2, 2}, {3.f, 4.f, 6.f, 7.f}); - - x.linspace(1); + auto x = NDArrayFactory::create('c', {1, 1, 3, 3}); + auto exp = NDArrayFactory::create('c', {1, 1, 2, 2}, + {3.f, 4.f, 6.f, 7.f}); - sd::ops::avgpool2d op; - auto result = op.evaluate({&x}, {}, {2, 2, 1, 1, 0, 0, 1, 1, 0, 0, 0}); + x.linspace(1); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::avgpool2d op; + auto result = op.evaluate({&x}, {}, {2, 2, 1, 1, 0, 0, 1, 1, 0, 0, 0}); - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedDeclarableOpsTests4, avgpool2d_9) { - auto x = NDArrayFactory::create('c', {1, 1, 3, 3}); - auto exp = NDArrayFactory::create('c', {1, 1, 3, 3}, {3.f, 4.f, 4.5f, 6.f, 7.f, 7.5f, 7.5f, 8.5f, 9.f}); - - x.linspace(1); - + auto x = NDArrayFactory::create('c', {1, 1, 3, 3}); + auto exp = NDArrayFactory::create( + 'c', {1, 1, 3, 3}, {3.f, 4.f, 4.5f, 6.f, 7.f, 7.5f, 7.5f, 8.5f, 9.f}); - sd::ops::avgpool2d op; - auto result = op.evaluate({&x}, {}, {2, 2, 1, 1, 0, 0, 1, 1, 1, 0, 0}); + x.linspace(1); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::avgpool2d op; + auto result = op.evaluate({&x}, {}, {2, 2, 1, 1, 0, 0, 1, 1, 1, 0, 0}); - auto z = result.at(0); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - //z->printShapeInfo("z shape:"); - //z->printBuffer("z buffer:"); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + // z->printShapeInfo("z shape:"); + // z->printBuffer("z buffer:"); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedDeclarableOpsTests4, avgpool2d_10) { - - auto input = NDArrayFactory::create('c', {4, 10, 10, 3}, {9.37125111f, 2.20166993f, 2.91434479f, 5.43639755f, -2.10573769f, 4.08528662f, 5.86908436f, -4.46203756f, 2.21057916f, 5.35849190f, 0.01394637f, 4.40566349f, 7.07982206f, -0.09633455f, 2.42429352f, 3.97301817f, -1.89553940f, 1.99690318f, 6.33141708f, 0.55401880f, 1.70707977f, - 5.55204201f, -0.03513752f, 1.60011971f, 2.62700319f, -2.74582434f, 3.06697464f, 1.06277943f, -1.16075921f, -0.78095782f, 9.72352791f, -1.22686064f, 1.99644792f, 7.35571337f, 1.40607321f, 0.11390255f, 9.53334427f, 2.28303599f, -1.66728830f, 6.16678810f, -0.04532295f, -1.97708666f, 9.74906158f, 1.46223176f, -1.46734393f, 4.30761862f, - -1.23790228f, 1.24823606f, 6.13938427f, -3.83689475f, -1.19625473f, 7.91535568f, 6.05868721f, -3.22946382f, 8.81633949f, -0.19967777f, 0.66053957f, 2.30919123f, 0.74543846f, -0.39347672f, 11.11058044f, 0.53720862f, 1.52645731f, 5.70012379f, -1.15213466f, 1.16451406f, 7.00526333f, 1.57362783f, -2.44384766f, 5.54213285f, -1.98828590f, - -0.70483637f, 7.88281822f, -3.59875536f, 0.80745387f, 13.41578484f, -1.55507684f, -0.65855008f, 9.32583523f, -0.14544789f, 0.73436141f, 3.61176538f, -1.71268058f, -2.58490300f, 9.09280205f, -3.27405524f, -2.04569697f, 4.44761324f, -0.62955856f, -2.61917663f, 8.04890442f, 0.54579324f, 0.85929775f, 9.82259560f, -1.93825579f, 0.77703512f, - 4.67090321f, -4.79267597f, -2.38906908f, 9.31265545f, 0.96026313f, -1.14109385f, 11.54231834f, -0.01417295f, -0.39500344f, 8.49191666f, 0.55300158f, 2.79490185f, 6.92466164f, 1.72254205f, 2.82222271f, 8.83112717f, 2.95033407f, 2.18054962f, 6.73509789f, -2.22272944f, 0.51127720f, -1.04563558f, 2.15747333f, -2.30959272f, 9.55441570f, - 1.50396204f, 1.77370787f, 7.38146257f, -1.79076433f, 3.20961165f, 7.18864202f, 2.91217351f, 0.43018937f, 7.11078024f, -1.17386127f, -0.16817921f, 6.12327290f, -2.82205725f, 3.30696845f, 13.51291752f, -1.30856836f, -2.38332748f, 11.09487438f, -1.47190213f, -0.53050828f, 4.38285351f, -5.07309771f, 1.50714362f, 5.72274446f, -2.85825086f, - -0.89673209f, 3.73791552f, -0.67708802f, -4.13149452f, -0.00671843f, -0.26566532f, 0.32961160f, 7.14501762f, -1.41608179f, -4.96590328f, 12.26205540f, -0.65158135f, -0.88641000f, 6.95777559f, -0.79058206f, -0.10260171f, 7.87169170f, 1.35921454f, 1.11759663f, 5.46187401f, -2.57214499f, 2.48484039f, 4.04043484f, -2.07137156f, -1.42709637f, - 9.25487137f, -0.12605135f, -2.66949964f, 2.89412403f, 0.74451172f, -2.96250391f, 3.99258423f, 0.27084303f, 0.32213116f, 5.42332172f, -0.44414216f, 1.70881832f, 6.69346905f, 0.53058422f, -4.73146200f, 4.22051668f, 2.24834967f, 0.66996074f, 4.30173683f, 0.11849818f, -4.07520294f, 8.27318478f, -2.54398274f, -2.86705542f, 10.11775303f, - -0.99382895f, 0.65881538f, 7.93556786f, -1.27934420f, -1.69343162f, 9.68042564f, -1.02609646f, -1.18189347f, 5.75370646f, -1.67888868f, -4.48871994f, 4.79537392f, -0.79212248f, -0.19855022f, 6.15060997f, -0.01081491f, 3.64454579f, 10.82562447f, 1.58859253f, -2.65847278f, 8.60093212f, -1.59196103f, 0.07635692f, 11.76175690f, -1.17453325f, - 0.10122013f, 6.86458445f, -2.18891335f, -2.74004745f, 8.07066154f, 0.71818852f, -2.03035975f, 6.31053686f, 0.51509416f, 1.39789927f, 9.43515587f, 2.04256630f, 0.13985133f, 4.65010691f, 2.40911126f, -0.36255789f, -3.06867862f, -0.45225358f, -1.56778407f, 6.05917358f, -1.09891272f, 1.77184200f, 6.46248102f, 0.96042323f, -0.24346280f, - 4.63436460f, -4.69907761f, 1.25187206f, 11.46173859f, -2.21917558f, 1.28007793f, 6.92173195f, 2.11268163f, -3.47389889f, 5.08722782f, -3.03950930f, -4.17154264f, 11.30568314f, 0.80361372f, 2.53214502f, 7.18707085f, -4.49114513f, 2.85449266f, 10.14906883f, -0.31974933f, -0.84472644f, -0.52459574f, 0.12921631f, -1.81390119f, 2.76170087f, 1.03982210f, 2.91744232f, -0.29048753f, 5.87453508f, -1.53684759f, 1.85800636f, -0.91404629f, 1.28954852f, 5.11354685f, -2.47475505f, -1.33179152f, 2.58552408f, 1.37316465f, -3.32339454f, 1.54122913f, 3.24953628f, -0.29758382f, 2.82391763f, -1.51142192f, -1.22699404f, 6.75745535f, 0.65452754f, -3.29385471f, 2.06008053f, 2.53172946f, -4.23532820f, -1.53909743f, -0.07010663f, -1.42173731f, 7.29031610f, -0.18448229f, 4.59496164f, 6.73027277f, 0.73441899f, 0.14426160f, 4.14915276f, -2.97010231f, 6.05851364f, 4.95218086f, -2.39145470f, 2.40494704f, 2.10288811f, 0.53503096f, 1.44511235f, 6.66344261f, -3.05803776f, 7.21418667f, 3.30303526f, -0.24163735f, 3.47409391f, 3.64520788f, 2.15189481f, -3.11243272f, 3.62310791f, 0.37379482f, 0.40865007f, -0.83132005f, -4.78246069f, 2.07030797f, 6.51765442f, 3.16178989f, 5.06180477f, 3.78434467f, -0.96689719f, 0.35965276f, 5.89967585f, 1.40294051f, 1.11952639f, 10.59778214f, 0.26739889f, -1.61297631f, 6.24801159f, -0.93914318f, -0.57812452f, 9.92604542f, -0.73025000f, -3.38530874f, 2.45646000f, -2.47949195f, 0.51638460f, 10.65636063f, 1.97816694f, -3.00407791f, 2.66914415f, -0.81951088f, -0.23316640f, 2.40737987f, -2.70007610f, 1.51531935f, 4.08860207f, -0.27552786f, -1.31721711f, 7.11568260f, -3.33498216f, -4.02545023f, 7.22675610f, -0.81690705f, -2.52689576f, 1.04016697f, -0.79291463f, -0.34875512f, 10.00498390f, -4.24167728f, 1.46162593f, 11.82569408f, -1.70359993f, -0.30161047f, 16.44085884f, -0.82253462f, -0.09435523f, 6.13080597f, -0.20259480f, 0.68308711f, 6.15663004f, -6.61776876f, 0.33295766f, 2.55449438f, -0.17819691f, -1.14892209f, 5.56776142f, 1.99279118f, 1.33035934f, 4.45823956f, 3.34916544f, -2.59905386f, 6.16164446f, -2.03881931f, -2.45273542f, 12.46793365f, -2.22743297f, 2.83738565f, 8.48628139f, -1.39347959f, -1.30867767f, 11.08041477f, -4.00363779f, 2.09183025f, 11.30395889f, -2.20504737f, 1.37426853f, 8.98735619f, 1.04676604f, -0.72757077f, 8.28050232f, -6.70741081f, -0.65798020f, 5.68592072f, -0.60760021f, 0.35854483f, 6.26852131f, 1.94100165f, 1.32112014f, 0.80987954f, -1.74617672f, -0.25434083f, 7.16045523f, 1.58884013f, -2.64847064f, 13.14820385f, 1.21393633f, -2.47258949f, 9.41650105f, -0.79384226f, 2.48954105f, 10.95629311f, 0.47723705f, 4.02126694f, 8.02593136f, -2.20726371f, -1.18794477f, 1.50836647f, 0.93118095f, -1.73513174f, 8.85493565f, -2.99670315f, -0.79055870f, 2.39473820f, 2.05046916f, -2.38055134f, 11.82299423f, 0.15609655f, 0.68744308f, 5.66401434f, -0.69281673f, 2.09855556f, 7.74626589f, -0.34283102f, 1.00542057f, 9.95838642f, 0.80161905f, 2.33455157f, 9.80057335f, -0.93561798f, 2.56991577f, 8.29711342f, 0.94213426f, 0.44209945f, 11.70259857f, 0.92710167f, 2.60957146f, 0.24971688f, -0.86529571f, 3.78628922f, 6.80884457f, -0.68178189f, 2.21103406f, 3.18895817f, 0.60283208f, -2.92716241f, 6.72060776f, -1.06625068f, 2.56543374f, 9.97404480f, 3.58080721f, -0.94936347f, 10.16736984f, -1.38464379f, 1.18191063f, 6.66179037f, -3.56115270f, 0.32329530f, 10.90870762f, 2.20638227f, 0.19653285f, 7.34650040f, -3.63859272f, -1.03027737f, 5.98829985f, -3.66606474f, -3.89746714f, 8.63469028f, 1.22569811f, 1.63240814f, 3.74385309f, 0.58243257f, -0.56981975f, 3.69260955f, 1.00979900f, -1.44030499f, 8.57058144f, -1.10648811f, 1.20474911f, 5.43133020f, -2.14822555f, -0.07928789f, 11.25825310f, 0.19645604f, -5.49546146f, 10.41917038f, -0.68178523f, -2.99639869f, 6.50054455f, 0.46488351f, -5.42328453f, 9.09500027f, -2.82107449f, 0.05601966f, 15.34610748f, -0.06820253f, 3.86699796f, 10.73316956f, -3.04795432f, -0.14702171f, 5.64813185f, 1.44028485f, -2.47596145f, 0.07280898f, -3.03187990f, -1.35183525f, 9.35835648f, 2.72966957f, 1.88199532f, 10.36187744f, -0.22834805f, -3.26738238f, 6.92025137f, -2.34061313f, 4.77379704f, 5.28559113f, -2.96323752f, -1.76186585f, 5.94436455f, 0.38647744f, -5.73869514f, 6.76849556f, 1.40892124f, -1.19068217f, 5.37919092f, -6.65328646f, 3.62782669f, 12.34744644f, 2.44762444f, -4.19242620f, 6.14906216f, 0.08121119f, 0.61355996f, 2.69666457f, -1.88962626f, -0.55314136f, 1.84937525f, 1.56048691f, 1.17460012f, 3.75674725f, 1.06198275f, -5.74625874f, 5.41645575f, -1.28946674f, -1.51689398f, 4.32400894f, -0.05222082f, -4.83948946f, 1.80747867f, 1.63144708f, -2.73887825f, 1.63975775f, -2.02163982f, -0.16210437f, 2.93518686f, 1.14427686f, -2.83246303f, 4.79283667f, 2.69697428f, -3.12678456f, -1.19225168f, -2.37022972f, -3.09429741f, 1.94225383f, -1.13747168f, -2.55048585f, 5.40242243f, 1.12777328f, 3.43713188f, 3.62658787f, -2.16878843f, 0.30164462f, 2.97407579f, -0.07275413f, -1.31149673f, 4.70066261f, -2.01323795f, 4.85255766f, 4.59128904f, 1.68084168f, 1.60336494f, 6.58138466f, -1.04759812f, 2.69906545f, 3.55769277f, -0.74327278f, 2.65819693f, 5.39528131f, 2.11248922f, -1.06446671f, 5.24546766f, -2.43146014f, 4.58907509f, 0.06521678f, -2.24503994f, 2.45722699f, 6.94863081f, 0.35258654f, 2.83396196f, 9.92525196f, -1.12225175f, -0.34365177f, 7.19116688f, -4.39813757f, 0.46517885f, 13.22028065f, -2.57483673f, -6.37226963f, 7.58046293f, -2.74600363f, 0.42231262f, 8.04881668f, 0.17289802f, -0.53447008f, 16.55157471f, -5.63614368f, 0.39288223f, 3.37079263f, 1.26484549f, -0.12820500f, 8.46440125f, -4.39304399f, 2.97676420f, 0.65650189f, 0.83158541f, -1.11556435f, 6.32885838f, -0.36087769f, 2.80724382f, 9.90292645f, 1.15936041f, 0.20947981f, 6.91249275f, -2.67404819f, 2.93782163f, 6.65656614f, -2.30828357f, 2.98214006f, 6.80611229f, -4.93821478f, -7.66555262f, 7.59763002f, -0.54159302f, 3.87403512f, 12.42607784f, 2.59284401f, -0.23375344f, 8.95293331f, -0.71807784f, 0.61873478f, 8.66713524f, 1.24289191f, -2.37835455f, 2.08071637f, -0.88315344f, -3.41891551f, 6.85245323f, 1.73007369f, 1.02169311f, 7.69170332f, -2.85411978f, 2.69790673f, 8.12906551f, -1.19351399f, -2.26442742f, 12.26104450f, -0.75579089f, -1.73274946f, 10.68729019f, 2.20655656f, -0.90522075f, 12.42165184f, -1.67929137f, 2.44851565f, 9.31565762f, -0.06645700f, 1.52762020f, 6.18427515f, -1.68882596f, 3.70261097f, 3.02252960f, -3.44125366f, -1.31575799f, 2.84617424f, -0.96849400f, -4.52356243f, 9.95027161f, 0.19966406f, -0.78874779f, 8.18595028f, -4.08300209f, 1.75126517f, 0.96418417f, -4.04913044f, -0.95200396f, 12.03637886f, -0.03041124f, 0.41642749f, 8.88267422f, -3.24985337f, -2.24919462f, 7.32566118f, 0.16964148f, -2.74123430f, 7.05264473f, -3.30191112f, 0.17163286f, 4.81851053f, -1.64463484f, -0.85933101f, 7.29276276f, 2.34066939f, -2.14860010f, 3.46148157f, -0.01782012f, 1.51504040f, 4.79304934f, 1.85281146f, -1.70663762f, 6.93470192f, -4.15440845f, -1.25983095f, 10.52491760f, 0.42930329f, -1.85146868f, 11.70042324f, -0.41704914f, 3.83796859f, 9.21148491f, -2.79719448f, 0.79470479f, 6.26926661f, -5.85230207f, 3.95105338f, 7.84790897f, -1.38680744f, -1.78099084f, 11.95235348f, -2.99841452f, -1.34507811f, 6.15714645f, -1.07552516f, -2.81228638f, 1.66234732f, -4.55166149f, -1.92601109f, 8.64634514f, -0.48158705f, 3.31595659f, 7.67371941f, 2.56964207f, 0.12107098f, 4.56467867f, -0.93541539f, 1.39432955f, 11.99714088f, 1.05353570f, -2.13099813f, 3.67617917f, 3.45895386f, 1.37365830f, 8.74344158f, -4.17585802f, 1.43908918f, 6.28764772f, 3.97346330f, -0.69144285f, 9.07983303f, -0.41635889f, -0.14965028f, 8.85469818f, 1.11306190f, 2.59440994f, 5.38982344f, -1.07948279f, 1.37252975f, 10.26984596f, -0.09318046f, 2.73104119f, 12.45902252f, -1.55446684f, -2.76124811f, 12.19395065f, -0.51846564f, 1.02764034f, 11.42673588f, -0.95940983f, -0.04781032f, 8.78379822f, -4.88957930f, 0.32534006f, 11.97696400f, -3.35108662f, 1.95104563f, 4.46915388f, -2.32061648f, 3.45230985f, 8.29983711f, 2.81034684f, -2.35529327f, 6.07801294f, -0.98105043f, -0.05359888f, 2.52291036f, -0.01986909f, -2.35321999f, 10.51954269f, 2.11145401f, 3.53506470f, 7.29093266f, 0.03721160f, -1.13496494f, 7.43886709f, -5.84201956f, 2.50796294f, 12.14647675f, 2.77490377f, -2.18896222f, 6.05641937f, 5.32617044f, 1.04221284f, 10.79106712f, -2.95749092f, -2.75414610f, 11.30037117f, -3.40654182f, -2.24673963f, 7.49126101f, 0.70811015f, -6.18003702f, 13.83951187f, -1.01204085f, 1.36298490f, -1.04451632f, 2.42435336f, -0.02346706f, -0.85528886f, 1.04731262f, 0.22192979f, 4.15708160f, 0.34933877f, 0.04814529f, 2.24107265f, 0.49676740f, -1.47752666f, 0.45040059f, -0.70471478f, -1.19759345f, 0.21711677f, 0.88461423f, -2.76830935f, 5.52066898f, 1.97664857f, -1.75381601f, 3.45877838f, 1.52617192f, -1.61350942f, 0.85337949f, 1.97610760f, -3.40310287f, 3.40319014f, -3.38691044f, -0.71319139f, 1.65463758f, -0.60680127f, -1.80700517f, 8.02592373f, 2.59627104f, 2.65895891f, 5.93043184f, -4.48425817f, 3.92670918f, 4.19496679f, -2.28286791f, 6.41634607f, 5.72330523f, 1.16269672f, -0.28753027f, 2.46342492f, 0.36693189f, 0.26712441f, 6.37652683f, -2.50139046f, 2.43923736f, 5.56310415f, 0.98065847f, 1.04267502f, 4.16403675f, -0.04966142f, 4.40897894f, 3.72905660f, -3.46129870f, 3.59962773f, 1.34830284f, -1.76661730f, 0.47943926f, 5.29946661f, -1.12711561f, 1.26970029f, 15.17655945f, -1.50971997f, 5.81345224f, 8.48562050f, -4.36049604f, 2.48144460f, 8.23780441f, -3.46030426f, -0.84656560f, 5.94946814f, 1.12747943f, -2.65683913f, 8.69085693f, 1.31309867f, -2.79958344f, 8.76840591f, -1.56444156f, 1.62710834f, 2.41177034f, -0.72804940f, 5.70619011f, 4.67169666f, -0.86167198f, -1.83803177f, 2.96346045f, 2.82692933f, -2.81557131f, 7.11113358f, -1.90071094f, 2.54244423f, 11.19284058f, -0.06298946f, -1.71517313f, 12.98388577f, 0.84510714f, 3.00816894f, 2.57200313f, 0.03899818f, -1.49330592f, 9.60099125f, -3.59513044f, -1.30045319f, 7.09241819f, -0.65233821f, -2.33627677f, 8.81366920f, 0.84154201f, 1.03312039f, 9.85289097f, 0.19351870f, 1.78496623f, 7.34631205f, -2.16530800f, -0.65016162f, 2.46842360f, 0.24016285f, -1.24308395f, 4.78175163f, -0.97682536f, 2.20942235f, 6.68382788f, 3.76786447f, -1.44454038f, 6.26453733f, -3.23575711f, -2.30137897f, 9.53092670f, -5.55222607f, 3.25999236f, 9.37559509f, 1.86339056f, -0.23551451f, 10.23400211f, 3.93031883f, -0.52629089f, 7.85724449f, -2.91549587f, 4.46612740f, 5.66530371f, -2.70820427f, 4.81359577f, 10.31247330f, 1.92230141f, 2.53931546f, 0.74986327f, 1.70303428f, 0.48063779f, 5.31099129f, -0.78976244f, 3.75864220f, 4.23051405f, 2.34042454f, -7.98193836f, 9.83987141f, -1.46722627f, 3.54497814f, 10.36455154f, -4.51249075f, 0.77715248f, 7.78694630f, -4.59989023f, -2.49585629f, 9.90296268f, 1.38535416f, 1.17441154f, 10.10452843f, -0.98628229f, 0.60194463f, 9.12639141f, -3.90754628f, 2.88526392f, 7.24123430f, -0.15283313f, -0.75728363f, -1.15116858f, -2.53791571f, 0.77229571f, 6.44114161f, 0.02646767f, 4.95463037f, 7.21066380f, 1.79384065f, 0.73250306f, 8.04447937f, 0.32576546f, -0.79447043f, 10.12717724f, 2.33392906f, 1.30716443f, 12.36073112f, -0.36694977f, -1.20438910f, 7.03105593f, 0.59557682f, 0.69267452f, 10.18113136f, 2.49944925f, -0.42229167f, 8.83143330f, -1.18805945f, -2.87509322f, 4.53596449f, 4.09732771f, -3.39088297f, -1.02536607f, 0.82119560f, -3.47302604f, 9.29991817f, 0.21001509f, 4.97036457f, 9.50018406f, 1.04420102f, 1.96560478f, 10.74769592f, -6.22709799f, 3.11690164f, 5.06759691f, -1.23724771f, -3.05831861f, 8.12925529f, -1.93435478f, -1.10151744f, 9.32263088f, -0.04249470f, -5.98547363f, 10.49398136f, 0.26400441f, -0.78915191f, 13.28219604f, 2.99276900f, 0.74853164f, 2.49364305f, -3.43529654f, 4.05278301f, 2.13498688f, -2.35444307f, -0.79900265f, 4.66968822f, -0.31095147f, 3.60674143f, 12.37222099f, -0.07855003f, -3.30292702f, 12.15215874f, 0.60886210f, 2.87075138f, 7.75271845f, 0.38044083f, 3.34402204f, 6.40583277f, -0.87888050f, 0.67438459f, 6.91080809f, 1.98332930f, -0.08303714f, 8.08630371f, -0.16772588f, -2.74058914f, 7.17253590f, -2.69122696f, 1.48173678f, 8.99470139f, -1.43302310f, -0.88651133f, 2.66944790f, -0.29186964f, 2.00838661f, 5.09587479f, -0.76676071f, -2.88322186f, 8.31110573f, -0.14550979f, -1.37726915f, 10.28355122f, -1.60575438f, -0.04118848f, 9.97510815f, 0.14440438f, -3.24632120f, 9.00034523f, 4.14319563f, -1.31023729f, 7.16950464f, -0.70428526f, 2.01559544f, 7.26155043f, 2.40816474f, 2.09847403f, 7.31264496f, -0.75401551f, 2.13392544f, 7.03648758f, 1.04036045f, -1.15636516f, 1.09634531f, -0.06340861f, -0.58107805f, -0.65623116f, 1.18972754f, -0.80717683f, 1.40118241f, -0.61932516f, -3.60596156f, 1.59904599f, -2.23774099f, -1.13721037f, 3.89620137f, -0.09115922f, -7.51356888f, 2.36975193f, -1.42520905f, -2.34173775f, 3.33830214f, -2.74016523f, -3.04115510f, 6.00119495f, -1.36084354f, -2.45065260f, 4.56992292f, -3.02825928f, -3.74182844f, 5.11069250f, -0.91531068f, -2.31385994f, 1.83399653f, 3.39370203f, -3.60886002f}); - auto exp = NDArrayFactory::create('c', {4, 4, 4, 3}, {7.97172260f, 0.06878620f, 2.27749538f, 7.29276514f, -0.14074677f, 0.65480286f, 5.70313978f, -0.06546132f, 0.35443667f, 3.70382833f, -0.84020567f, 0.63826996f, 8.60301399f, -0.38236514f, 1.55177069f, 7.37542057f, -0.99374938f, -0.29971302f, 8.84352493f, -0.67121059f, 0.43132120f, 4.78175592f, -1.25070143f, -1.91523600f, 6.03855371f, -0.00292124f, -1.11214364f, 7.90158176f, -0.57949901f, -0.96735370f, 7.81192017f, -0.53255427f, -0.48009714f, 3.16953635f, 0.08353355f, -1.54299748f, 3.74821687f, 1.69396687f, 0.72724354f, 5.42915201f, -1.13686812f, -0.71793109f, 5.78376389f, -0.72239977f, -0.60055625f, 2.53636408f, 0.56777251f, -2.07892323f, 6.08064651f, 0.68620735f, 2.54017019f, 5.65828180f, -0.68255502f, 1.47283304f, 6.10842514f, -0.39655915f, 0.28380761f, 1.96707797f, -1.98206317f, 0.94027776f, 4.71811438f, 0.32104525f, -0.92409706f, 8.34588146f, -1.05581069f, -0.55217457f, 9.58440876f, -0.96549922f, 0.45820439f, 5.65453672f, -2.50953507f, -0.71441835f, 8.03059578f, -0.21281289f, 0.92125505f, 9.26900673f, -0.35963219f, -0.70039093f, 8.59924412f, -1.22358346f, 0.81318003f, 3.85920119f, -0.01305223f, -1.09234154f, 6.33158875f, 1.28094780f, -1.48926139f, 4.94969177f, -0.77126902f, -1.97033751f, 5.64381838f, -0.16285487f, -1.31277227f, 2.39893222f, -1.32902908f, -1.39609122f, 6.47572327f, -0.45267010f, 1.55727172f, 6.70965624f, -1.68735468f, -0.05672536f, 7.25092363f, -0.64613032f, 0.67050058f, 3.60789680f, -2.05948973f, 2.22687531f, 8.15202713f, -0.70148355f, 1.28314006f, 8.14842319f, -1.88807654f, -1.04808438f, 8.45500565f, -0.76425624f, 0.94542569f, 4.56179953f, -0.28786001f, -2.04502511f, 8.46278095f, -0.31019822f, 0.07339200f, 9.34214592f, -0.61948007f, 0.52481830f, 8.32515621f, -1.52418160f, 0.49678251f, 5.11082315f, -1.09908783f, -0.52969611f, 5.27806664f, 0.88632923f, 0.66754371f, 4.75839233f, 0.48928693f, -0.68036932f, 6.56925392f, -0.02949905f, -2.99189186f, 4.46320581f, -0.64534980f, -0.29516968f, 8.60809517f, -1.13120568f, 3.41720533f, 5.84243155f, -1.24109328f, 0.89566326f, 5.99578333f, -0.42496428f, 2.07076764f, 3.17812920f, -0.81566459f, -0.14363396f, 6.55184317f, 0.39633346f, -0.43852386f, 8.70214558f, -2.24613595f, 0.30708700f, 8.73882294f, -0.53545928f, 1.54409575f, 4.49452257f, -0.16509305f, 0.19028664f, 8.24897003f, 0.44750381f, 2.15448594f, 8.97640514f, -0.77728152f, 0.57272542f, 9.03467560f, 0.47173575f, -1.10807717f, 3.30056310f, -0.43268481f, -0.41470885f, 3.53798294f, -0.08546703f, -2.16840744f, 6.18733406f, -0.17871059f, -2.59837723f, 5.94218683f, -1.02990067f, -0.49760687f, 3.76938033f, 0.86383581f, -1.91504073f}); - - sd::ops::avgpool2d op; - auto result = op.evaluate({&input}, {3,3, 3,3, 0,0, 1,1,1, 0,1}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - - // z->printIndexedBuffer("z"); - // exp.printIndexedBuffer("e"); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - - + auto input = NDArrayFactory::create( + 'c', {4, 10, 10, 3}, + {9.37125111f, 2.20166993f, 2.91434479f, 5.43639755f, -2.10573769f, + 4.08528662f, 5.86908436f, -4.46203756f, 2.21057916f, 5.35849190f, + 0.01394637f, 4.40566349f, 7.07982206f, -0.09633455f, 2.42429352f, + 3.97301817f, -1.89553940f, 1.99690318f, 6.33141708f, 0.55401880f, + 1.70707977f, 5.55204201f, -0.03513752f, 1.60011971f, 2.62700319f, + -2.74582434f, 3.06697464f, 1.06277943f, -1.16075921f, -0.78095782f, + 9.72352791f, -1.22686064f, 1.99644792f, 7.35571337f, 1.40607321f, + 0.11390255f, 9.53334427f, 2.28303599f, -1.66728830f, 6.16678810f, + -0.04532295f, -1.97708666f, 9.74906158f, 1.46223176f, -1.46734393f, + 4.30761862f, -1.23790228f, 1.24823606f, 6.13938427f, -3.83689475f, + -1.19625473f, 7.91535568f, 6.05868721f, -3.22946382f, 8.81633949f, + -0.19967777f, 0.66053957f, 2.30919123f, 0.74543846f, -0.39347672f, + 11.11058044f, 0.53720862f, 1.52645731f, 5.70012379f, -1.15213466f, + 1.16451406f, 7.00526333f, 1.57362783f, -2.44384766f, 5.54213285f, + -1.98828590f, -0.70483637f, 7.88281822f, -3.59875536f, 0.80745387f, + 13.41578484f, -1.55507684f, -0.65855008f, 9.32583523f, -0.14544789f, + 0.73436141f, 3.61176538f, -1.71268058f, -2.58490300f, 9.09280205f, + -3.27405524f, -2.04569697f, 4.44761324f, -0.62955856f, -2.61917663f, + 8.04890442f, 0.54579324f, 0.85929775f, 9.82259560f, -1.93825579f, + 0.77703512f, 4.67090321f, -4.79267597f, -2.38906908f, 9.31265545f, + 0.96026313f, -1.14109385f, 11.54231834f, -0.01417295f, -0.39500344f, + 8.49191666f, 0.55300158f, 2.79490185f, 6.92466164f, 1.72254205f, + 2.82222271f, 8.83112717f, 2.95033407f, 2.18054962f, 6.73509789f, + -2.22272944f, 0.51127720f, -1.04563558f, 2.15747333f, -2.30959272f, + 9.55441570f, 1.50396204f, 1.77370787f, 7.38146257f, -1.79076433f, + 3.20961165f, 7.18864202f, 2.91217351f, 0.43018937f, 7.11078024f, + -1.17386127f, -0.16817921f, 6.12327290f, -2.82205725f, 3.30696845f, + 13.51291752f, -1.30856836f, -2.38332748f, 11.09487438f, -1.47190213f, + -0.53050828f, 4.38285351f, -5.07309771f, 1.50714362f, 5.72274446f, + -2.85825086f, -0.89673209f, 3.73791552f, -0.67708802f, -4.13149452f, + -0.00671843f, -0.26566532f, 0.32961160f, 7.14501762f, -1.41608179f, + -4.96590328f, 12.26205540f, -0.65158135f, -0.88641000f, 6.95777559f, + -0.79058206f, -0.10260171f, 7.87169170f, 1.35921454f, 1.11759663f, + 5.46187401f, -2.57214499f, 2.48484039f, 4.04043484f, -2.07137156f, + -1.42709637f, 9.25487137f, -0.12605135f, -2.66949964f, 2.89412403f, + 0.74451172f, -2.96250391f, 3.99258423f, 0.27084303f, 0.32213116f, + 5.42332172f, -0.44414216f, 1.70881832f, 6.69346905f, 0.53058422f, + -4.73146200f, 4.22051668f, 2.24834967f, 0.66996074f, 4.30173683f, + 0.11849818f, -4.07520294f, 8.27318478f, -2.54398274f, -2.86705542f, + 10.11775303f, -0.99382895f, 0.65881538f, 7.93556786f, -1.27934420f, + -1.69343162f, 9.68042564f, -1.02609646f, -1.18189347f, 5.75370646f, + -1.67888868f, -4.48871994f, 4.79537392f, -0.79212248f, -0.19855022f, + 6.15060997f, -0.01081491f, 3.64454579f, 10.82562447f, 1.58859253f, + -2.65847278f, 8.60093212f, -1.59196103f, 0.07635692f, 11.76175690f, + -1.17453325f, 0.10122013f, 6.86458445f, -2.18891335f, -2.74004745f, + 8.07066154f, 0.71818852f, -2.03035975f, 6.31053686f, 0.51509416f, + 1.39789927f, 9.43515587f, 2.04256630f, 0.13985133f, 4.65010691f, + 2.40911126f, -0.36255789f, -3.06867862f, -0.45225358f, -1.56778407f, + 6.05917358f, -1.09891272f, 1.77184200f, 6.46248102f, 0.96042323f, + -0.24346280f, 4.63436460f, -4.69907761f, 1.25187206f, 11.46173859f, + -2.21917558f, 1.28007793f, 6.92173195f, 2.11268163f, -3.47389889f, + 5.08722782f, -3.03950930f, -4.17154264f, 11.30568314f, 0.80361372f, + 2.53214502f, 7.18707085f, -4.49114513f, 2.85449266f, 10.14906883f, + -0.31974933f, -0.84472644f, -0.52459574f, 0.12921631f, -1.81390119f, + 2.76170087f, 1.03982210f, 2.91744232f, -0.29048753f, 5.87453508f, + -1.53684759f, 1.85800636f, -0.91404629f, 1.28954852f, 5.11354685f, + -2.47475505f, -1.33179152f, 2.58552408f, 1.37316465f, -3.32339454f, + 1.54122913f, 3.24953628f, -0.29758382f, 2.82391763f, -1.51142192f, + -1.22699404f, 6.75745535f, 0.65452754f, -3.29385471f, 2.06008053f, + 2.53172946f, -4.23532820f, -1.53909743f, -0.07010663f, -1.42173731f, + 7.29031610f, -0.18448229f, 4.59496164f, 6.73027277f, 0.73441899f, + 0.14426160f, 4.14915276f, -2.97010231f, 6.05851364f, 4.95218086f, + -2.39145470f, 2.40494704f, 2.10288811f, 0.53503096f, 1.44511235f, + 6.66344261f, -3.05803776f, 7.21418667f, 3.30303526f, -0.24163735f, + 3.47409391f, 3.64520788f, 2.15189481f, -3.11243272f, 3.62310791f, + 0.37379482f, 0.40865007f, -0.83132005f, -4.78246069f, 2.07030797f, + 6.51765442f, 3.16178989f, 5.06180477f, 3.78434467f, -0.96689719f, + 0.35965276f, 5.89967585f, 1.40294051f, 1.11952639f, 10.59778214f, + 0.26739889f, -1.61297631f, 6.24801159f, -0.93914318f, -0.57812452f, + 9.92604542f, -0.73025000f, -3.38530874f, 2.45646000f, -2.47949195f, + 0.51638460f, 10.65636063f, 1.97816694f, -3.00407791f, 2.66914415f, + -0.81951088f, -0.23316640f, 2.40737987f, -2.70007610f, 1.51531935f, + 4.08860207f, -0.27552786f, -1.31721711f, 7.11568260f, -3.33498216f, + -4.02545023f, 7.22675610f, -0.81690705f, -2.52689576f, 1.04016697f, + -0.79291463f, -0.34875512f, 10.00498390f, -4.24167728f, 1.46162593f, + 11.82569408f, -1.70359993f, -0.30161047f, 16.44085884f, -0.82253462f, + -0.09435523f, 6.13080597f, -0.20259480f, 0.68308711f, 6.15663004f, + -6.61776876f, 0.33295766f, 2.55449438f, -0.17819691f, -1.14892209f, + 5.56776142f, 1.99279118f, 1.33035934f, 4.45823956f, 3.34916544f, + -2.59905386f, 6.16164446f, -2.03881931f, -2.45273542f, 12.46793365f, + -2.22743297f, 2.83738565f, 8.48628139f, -1.39347959f, -1.30867767f, + 11.08041477f, -4.00363779f, 2.09183025f, 11.30395889f, -2.20504737f, + 1.37426853f, 8.98735619f, 1.04676604f, -0.72757077f, 8.28050232f, + -6.70741081f, -0.65798020f, 5.68592072f, -0.60760021f, 0.35854483f, + 6.26852131f, 1.94100165f, 1.32112014f, 0.80987954f, -1.74617672f, + -0.25434083f, 7.16045523f, 1.58884013f, -2.64847064f, 13.14820385f, + 1.21393633f, -2.47258949f, 9.41650105f, -0.79384226f, 2.48954105f, + 10.95629311f, 0.47723705f, 4.02126694f, 8.02593136f, -2.20726371f, + -1.18794477f, 1.50836647f, 0.93118095f, -1.73513174f, 8.85493565f, + -2.99670315f, -0.79055870f, 2.39473820f, 2.05046916f, -2.38055134f, + 11.82299423f, 0.15609655f, 0.68744308f, 5.66401434f, -0.69281673f, + 2.09855556f, 7.74626589f, -0.34283102f, 1.00542057f, 9.95838642f, + 0.80161905f, 2.33455157f, 9.80057335f, -0.93561798f, 2.56991577f, + 8.29711342f, 0.94213426f, 0.44209945f, 11.70259857f, 0.92710167f, + 2.60957146f, 0.24971688f, -0.86529571f, 3.78628922f, 6.80884457f, + -0.68178189f, 2.21103406f, 3.18895817f, 0.60283208f, -2.92716241f, + 6.72060776f, -1.06625068f, 2.56543374f, 9.97404480f, 3.58080721f, + -0.94936347f, 10.16736984f, -1.38464379f, 1.18191063f, 6.66179037f, + -3.56115270f, 0.32329530f, 10.90870762f, 2.20638227f, 0.19653285f, + 7.34650040f, -3.63859272f, -1.03027737f, 5.98829985f, -3.66606474f, + -3.89746714f, 8.63469028f, 1.22569811f, 1.63240814f, 3.74385309f, + 0.58243257f, -0.56981975f, 3.69260955f, 1.00979900f, -1.44030499f, + 8.57058144f, -1.10648811f, 1.20474911f, 5.43133020f, -2.14822555f, + -0.07928789f, 11.25825310f, 0.19645604f, -5.49546146f, 10.41917038f, + -0.68178523f, -2.99639869f, 6.50054455f, 0.46488351f, -5.42328453f, + 9.09500027f, -2.82107449f, 0.05601966f, 15.34610748f, -0.06820253f, + 3.86699796f, 10.73316956f, -3.04795432f, -0.14702171f, 5.64813185f, + 1.44028485f, -2.47596145f, 0.07280898f, -3.03187990f, -1.35183525f, + 9.35835648f, 2.72966957f, 1.88199532f, 10.36187744f, -0.22834805f, + -3.26738238f, 6.92025137f, -2.34061313f, 4.77379704f, 5.28559113f, + -2.96323752f, -1.76186585f, 5.94436455f, 0.38647744f, -5.73869514f, + 6.76849556f, 1.40892124f, -1.19068217f, 5.37919092f, -6.65328646f, + 3.62782669f, 12.34744644f, 2.44762444f, -4.19242620f, 6.14906216f, + 0.08121119f, 0.61355996f, 2.69666457f, -1.88962626f, -0.55314136f, + 1.84937525f, 1.56048691f, 1.17460012f, 3.75674725f, 1.06198275f, + -5.74625874f, 5.41645575f, -1.28946674f, -1.51689398f, 4.32400894f, + -0.05222082f, -4.83948946f, 1.80747867f, 1.63144708f, -2.73887825f, + 1.63975775f, -2.02163982f, -0.16210437f, 2.93518686f, 1.14427686f, + -2.83246303f, 4.79283667f, 2.69697428f, -3.12678456f, -1.19225168f, + -2.37022972f, -3.09429741f, 1.94225383f, -1.13747168f, -2.55048585f, + 5.40242243f, 1.12777328f, 3.43713188f, 3.62658787f, -2.16878843f, + 0.30164462f, 2.97407579f, -0.07275413f, -1.31149673f, 4.70066261f, + -2.01323795f, 4.85255766f, 4.59128904f, 1.68084168f, 1.60336494f, + 6.58138466f, -1.04759812f, 2.69906545f, 3.55769277f, -0.74327278f, + 2.65819693f, 5.39528131f, 2.11248922f, -1.06446671f, 5.24546766f, + -2.43146014f, 4.58907509f, 0.06521678f, -2.24503994f, 2.45722699f, + 6.94863081f, 0.35258654f, 2.83396196f, 9.92525196f, -1.12225175f, + -0.34365177f, 7.19116688f, -4.39813757f, 0.46517885f, 13.22028065f, + -2.57483673f, -6.37226963f, 7.58046293f, -2.74600363f, 0.42231262f, + 8.04881668f, 0.17289802f, -0.53447008f, 16.55157471f, -5.63614368f, + 0.39288223f, 3.37079263f, 1.26484549f, -0.12820500f, 8.46440125f, + -4.39304399f, 2.97676420f, 0.65650189f, 0.83158541f, -1.11556435f, + 6.32885838f, -0.36087769f, 2.80724382f, 9.90292645f, 1.15936041f, + 0.20947981f, 6.91249275f, -2.67404819f, 2.93782163f, 6.65656614f, + -2.30828357f, 2.98214006f, 6.80611229f, -4.93821478f, -7.66555262f, + 7.59763002f, -0.54159302f, 3.87403512f, 12.42607784f, 2.59284401f, + -0.23375344f, 8.95293331f, -0.71807784f, 0.61873478f, 8.66713524f, + 1.24289191f, -2.37835455f, 2.08071637f, -0.88315344f, -3.41891551f, + 6.85245323f, 1.73007369f, 1.02169311f, 7.69170332f, -2.85411978f, + 2.69790673f, 8.12906551f, -1.19351399f, -2.26442742f, 12.26104450f, + -0.75579089f, -1.73274946f, 10.68729019f, 2.20655656f, -0.90522075f, + 12.42165184f, -1.67929137f, 2.44851565f, 9.31565762f, -0.06645700f, + 1.52762020f, 6.18427515f, -1.68882596f, 3.70261097f, 3.02252960f, + -3.44125366f, -1.31575799f, 2.84617424f, -0.96849400f, -4.52356243f, + 9.95027161f, 0.19966406f, -0.78874779f, 8.18595028f, -4.08300209f, + 1.75126517f, 0.96418417f, -4.04913044f, -0.95200396f, 12.03637886f, + -0.03041124f, 0.41642749f, 8.88267422f, -3.24985337f, -2.24919462f, + 7.32566118f, 0.16964148f, -2.74123430f, 7.05264473f, -3.30191112f, + 0.17163286f, 4.81851053f, -1.64463484f, -0.85933101f, 7.29276276f, + 2.34066939f, -2.14860010f, 3.46148157f, -0.01782012f, 1.51504040f, + 4.79304934f, 1.85281146f, -1.70663762f, 6.93470192f, -4.15440845f, + -1.25983095f, 10.52491760f, 0.42930329f, -1.85146868f, 11.70042324f, + -0.41704914f, 3.83796859f, 9.21148491f, -2.79719448f, 0.79470479f, + 6.26926661f, -5.85230207f, 3.95105338f, 7.84790897f, -1.38680744f, + -1.78099084f, 11.95235348f, -2.99841452f, -1.34507811f, 6.15714645f, + -1.07552516f, -2.81228638f, 1.66234732f, -4.55166149f, -1.92601109f, + 8.64634514f, -0.48158705f, 3.31595659f, 7.67371941f, 2.56964207f, + 0.12107098f, 4.56467867f, -0.93541539f, 1.39432955f, 11.99714088f, + 1.05353570f, -2.13099813f, 3.67617917f, 3.45895386f, 1.37365830f, + 8.74344158f, -4.17585802f, 1.43908918f, 6.28764772f, 3.97346330f, + -0.69144285f, 9.07983303f, -0.41635889f, -0.14965028f, 8.85469818f, + 1.11306190f, 2.59440994f, 5.38982344f, -1.07948279f, 1.37252975f, + 10.26984596f, -0.09318046f, 2.73104119f, 12.45902252f, -1.55446684f, + -2.76124811f, 12.19395065f, -0.51846564f, 1.02764034f, 11.42673588f, + -0.95940983f, -0.04781032f, 8.78379822f, -4.88957930f, 0.32534006f, + 11.97696400f, -3.35108662f, 1.95104563f, 4.46915388f, -2.32061648f, + 3.45230985f, 8.29983711f, 2.81034684f, -2.35529327f, 6.07801294f, + -0.98105043f, -0.05359888f, 2.52291036f, -0.01986909f, -2.35321999f, + 10.51954269f, 2.11145401f, 3.53506470f, 7.29093266f, 0.03721160f, + -1.13496494f, 7.43886709f, -5.84201956f, 2.50796294f, 12.14647675f, + 2.77490377f, -2.18896222f, 6.05641937f, 5.32617044f, 1.04221284f, + 10.79106712f, -2.95749092f, -2.75414610f, 11.30037117f, -3.40654182f, + -2.24673963f, 7.49126101f, 0.70811015f, -6.18003702f, 13.83951187f, + -1.01204085f, 1.36298490f, -1.04451632f, 2.42435336f, -0.02346706f, + -0.85528886f, 1.04731262f, 0.22192979f, 4.15708160f, 0.34933877f, + 0.04814529f, 2.24107265f, 0.49676740f, -1.47752666f, 0.45040059f, + -0.70471478f, -1.19759345f, 0.21711677f, 0.88461423f, -2.76830935f, + 5.52066898f, 1.97664857f, -1.75381601f, 3.45877838f, 1.52617192f, + -1.61350942f, 0.85337949f, 1.97610760f, -3.40310287f, 3.40319014f, + -3.38691044f, -0.71319139f, 1.65463758f, -0.60680127f, -1.80700517f, + 8.02592373f, 2.59627104f, 2.65895891f, 5.93043184f, -4.48425817f, + 3.92670918f, 4.19496679f, -2.28286791f, 6.41634607f, 5.72330523f, + 1.16269672f, -0.28753027f, 2.46342492f, 0.36693189f, 0.26712441f, + 6.37652683f, -2.50139046f, 2.43923736f, 5.56310415f, 0.98065847f, + 1.04267502f, 4.16403675f, -0.04966142f, 4.40897894f, 3.72905660f, + -3.46129870f, 3.59962773f, 1.34830284f, -1.76661730f, 0.47943926f, + 5.29946661f, -1.12711561f, 1.26970029f, 15.17655945f, -1.50971997f, + 5.81345224f, 8.48562050f, -4.36049604f, 2.48144460f, 8.23780441f, + -3.46030426f, -0.84656560f, 5.94946814f, 1.12747943f, -2.65683913f, + 8.69085693f, 1.31309867f, -2.79958344f, 8.76840591f, -1.56444156f, + 1.62710834f, 2.41177034f, -0.72804940f, 5.70619011f, 4.67169666f, + -0.86167198f, -1.83803177f, 2.96346045f, 2.82692933f, -2.81557131f, + 7.11113358f, -1.90071094f, 2.54244423f, 11.19284058f, -0.06298946f, + -1.71517313f, 12.98388577f, 0.84510714f, 3.00816894f, 2.57200313f, + 0.03899818f, -1.49330592f, 9.60099125f, -3.59513044f, -1.30045319f, + 7.09241819f, -0.65233821f, -2.33627677f, 8.81366920f, 0.84154201f, + 1.03312039f, 9.85289097f, 0.19351870f, 1.78496623f, 7.34631205f, + -2.16530800f, -0.65016162f, 2.46842360f, 0.24016285f, -1.24308395f, + 4.78175163f, -0.97682536f, 2.20942235f, 6.68382788f, 3.76786447f, + -1.44454038f, 6.26453733f, -3.23575711f, -2.30137897f, 9.53092670f, + -5.55222607f, 3.25999236f, 9.37559509f, 1.86339056f, -0.23551451f, + 10.23400211f, 3.93031883f, -0.52629089f, 7.85724449f, -2.91549587f, + 4.46612740f, 5.66530371f, -2.70820427f, 4.81359577f, 10.31247330f, + 1.92230141f, 2.53931546f, 0.74986327f, 1.70303428f, 0.48063779f, + 5.31099129f, -0.78976244f, 3.75864220f, 4.23051405f, 2.34042454f, + -7.98193836f, 9.83987141f, -1.46722627f, 3.54497814f, 10.36455154f, + -4.51249075f, 0.77715248f, 7.78694630f, -4.59989023f, -2.49585629f, + 9.90296268f, 1.38535416f, 1.17441154f, 10.10452843f, -0.98628229f, + 0.60194463f, 9.12639141f, -3.90754628f, 2.88526392f, 7.24123430f, + -0.15283313f, -0.75728363f, -1.15116858f, -2.53791571f, 0.77229571f, + 6.44114161f, 0.02646767f, 4.95463037f, 7.21066380f, 1.79384065f, + 0.73250306f, 8.04447937f, 0.32576546f, -0.79447043f, 10.12717724f, + 2.33392906f, 1.30716443f, 12.36073112f, -0.36694977f, -1.20438910f, + 7.03105593f, 0.59557682f, 0.69267452f, 10.18113136f, 2.49944925f, + -0.42229167f, 8.83143330f, -1.18805945f, -2.87509322f, 4.53596449f, + 4.09732771f, -3.39088297f, -1.02536607f, 0.82119560f, -3.47302604f, + 9.29991817f, 0.21001509f, 4.97036457f, 9.50018406f, 1.04420102f, + 1.96560478f, 10.74769592f, -6.22709799f, 3.11690164f, 5.06759691f, + -1.23724771f, -3.05831861f, 8.12925529f, -1.93435478f, -1.10151744f, + 9.32263088f, -0.04249470f, -5.98547363f, 10.49398136f, 0.26400441f, + -0.78915191f, 13.28219604f, 2.99276900f, 0.74853164f, 2.49364305f, + -3.43529654f, 4.05278301f, 2.13498688f, -2.35444307f, -0.79900265f, + 4.66968822f, -0.31095147f, 3.60674143f, 12.37222099f, -0.07855003f, + -3.30292702f, 12.15215874f, 0.60886210f, 2.87075138f, 7.75271845f, + 0.38044083f, 3.34402204f, 6.40583277f, -0.87888050f, 0.67438459f, + 6.91080809f, 1.98332930f, -0.08303714f, 8.08630371f, -0.16772588f, + -2.74058914f, 7.17253590f, -2.69122696f, 1.48173678f, 8.99470139f, + -1.43302310f, -0.88651133f, 2.66944790f, -0.29186964f, 2.00838661f, + 5.09587479f, -0.76676071f, -2.88322186f, 8.31110573f, -0.14550979f, + -1.37726915f, 10.28355122f, -1.60575438f, -0.04118848f, 9.97510815f, + 0.14440438f, -3.24632120f, 9.00034523f, 4.14319563f, -1.31023729f, + 7.16950464f, -0.70428526f, 2.01559544f, 7.26155043f, 2.40816474f, + 2.09847403f, 7.31264496f, -0.75401551f, 2.13392544f, 7.03648758f, + 1.04036045f, -1.15636516f, 1.09634531f, -0.06340861f, -0.58107805f, + -0.65623116f, 1.18972754f, -0.80717683f, 1.40118241f, -0.61932516f, + -3.60596156f, 1.59904599f, -2.23774099f, -1.13721037f, 3.89620137f, + -0.09115922f, -7.51356888f, 2.36975193f, -1.42520905f, -2.34173775f, + 3.33830214f, -2.74016523f, -3.04115510f, 6.00119495f, -1.36084354f, + -2.45065260f, 4.56992292f, -3.02825928f, -3.74182844f, 5.11069250f, + -0.91531068f, -2.31385994f, 1.83399653f, 3.39370203f, -3.60886002f}); + auto exp = NDArrayFactory::create( + 'c', {4, 4, 4, 3}, + {7.97172260f, 0.06878620f, 2.27749538f, 7.29276514f, -0.14074677f, + 0.65480286f, 5.70313978f, -0.06546132f, 0.35443667f, 3.70382833f, + -0.84020567f, 0.63826996f, 8.60301399f, -0.38236514f, 1.55177069f, + 7.37542057f, -0.99374938f, -0.29971302f, 8.84352493f, -0.67121059f, + 0.43132120f, 4.78175592f, -1.25070143f, -1.91523600f, 6.03855371f, + -0.00292124f, -1.11214364f, 7.90158176f, -0.57949901f, -0.96735370f, + 7.81192017f, -0.53255427f, -0.48009714f, 3.16953635f, 0.08353355f, + -1.54299748f, 3.74821687f, 1.69396687f, 0.72724354f, 5.42915201f, + -1.13686812f, -0.71793109f, 5.78376389f, -0.72239977f, -0.60055625f, + 2.53636408f, 0.56777251f, -2.07892323f, 6.08064651f, 0.68620735f, + 2.54017019f, 5.65828180f, -0.68255502f, 1.47283304f, 6.10842514f, + -0.39655915f, 0.28380761f, 1.96707797f, -1.98206317f, 0.94027776f, + 4.71811438f, 0.32104525f, -0.92409706f, 8.34588146f, -1.05581069f, + -0.55217457f, 9.58440876f, -0.96549922f, 0.45820439f, 5.65453672f, + -2.50953507f, -0.71441835f, 8.03059578f, -0.21281289f, 0.92125505f, + 9.26900673f, -0.35963219f, -0.70039093f, 8.59924412f, -1.22358346f, + 0.81318003f, 3.85920119f, -0.01305223f, -1.09234154f, 6.33158875f, + 1.28094780f, -1.48926139f, 4.94969177f, -0.77126902f, -1.97033751f, + 5.64381838f, -0.16285487f, -1.31277227f, 2.39893222f, -1.32902908f, + -1.39609122f, 6.47572327f, -0.45267010f, 1.55727172f, 6.70965624f, + -1.68735468f, -0.05672536f, 7.25092363f, -0.64613032f, 0.67050058f, + 3.60789680f, -2.05948973f, 2.22687531f, 8.15202713f, -0.70148355f, + 1.28314006f, 8.14842319f, -1.88807654f, -1.04808438f, 8.45500565f, + -0.76425624f, 0.94542569f, 4.56179953f, -0.28786001f, -2.04502511f, + 8.46278095f, -0.31019822f, 0.07339200f, 9.34214592f, -0.61948007f, + 0.52481830f, 8.32515621f, -1.52418160f, 0.49678251f, 5.11082315f, + -1.09908783f, -0.52969611f, 5.27806664f, 0.88632923f, 0.66754371f, + 4.75839233f, 0.48928693f, -0.68036932f, 6.56925392f, -0.02949905f, + -2.99189186f, 4.46320581f, -0.64534980f, -0.29516968f, 8.60809517f, + -1.13120568f, 3.41720533f, 5.84243155f, -1.24109328f, 0.89566326f, + 5.99578333f, -0.42496428f, 2.07076764f, 3.17812920f, -0.81566459f, + -0.14363396f, 6.55184317f, 0.39633346f, -0.43852386f, 8.70214558f, + -2.24613595f, 0.30708700f, 8.73882294f, -0.53545928f, 1.54409575f, + 4.49452257f, -0.16509305f, 0.19028664f, 8.24897003f, 0.44750381f, + 2.15448594f, 8.97640514f, -0.77728152f, 0.57272542f, 9.03467560f, + 0.47173575f, -1.10807717f, 3.30056310f, -0.43268481f, -0.41470885f, + 3.53798294f, -0.08546703f, -2.16840744f, 6.18733406f, -0.17871059f, + -2.59837723f, 5.94218683f, -1.02990067f, -0.49760687f, 3.76938033f, + 0.86383581f, -1.91504073f}); + + sd::ops::avgpool2d op; + auto result = op.evaluate({&input}, {3, 3, 3, 3, 0, 0, 1, 1, 1, 0, 1}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + // z->printIndexedBuffer("z"); + // exp.printIndexedBuffer("e"); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, avgpool2d_11) { - int inOutH = 5;// 35; - int inOutW = 5;// 35; - int inOutC = 10;// 192; + int inOutH = 5; // 35; + int inOutW = 5; // 35; + int inOutC = 10; // 192; - auto x = NDArrayFactory::create('c', {1, inOutH, inOutW, inOutC}); - x.linspace(1.0); + auto x = NDArrayFactory::create('c', {1, inOutH, inOutW, inOutC}); + x.linspace(1.0); - sd::ops::avgpool2d op; - auto result = op.evaluate({&x}, {3,3, 1,1, 0,0, 1,1, 1, 0, 1}); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::avgpool2d op; + auto result = op.evaluate({&x}, {3, 3, 1, 1, 0, 0, 1, 1, 1, 0, 1}); + ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); + auto z = result.at(0); - int totalPadHeight = (inOutH - 1) * 1 + 3 - inOutH; - int padTop = totalPadHeight / 2; - int padBottom = totalPadHeight - totalPadHeight / 2; + int totalPadHeight = (inOutH - 1) * 1 + 3 - inOutH; + int padTop = totalPadHeight / 2; + int padBottom = totalPadHeight - totalPadHeight / 2; - int k = 3; + int k = 3; - auto m = NDArrayFactory::create('c', {1, inOutH, inOutW, inOutC}); - auto c = NDArrayFactory::create('c', {1, inOutH, inOutW, inOutC}); + auto m = NDArrayFactory::create('c', {1, inOutH, inOutW, inOutC}); + auto c = NDArrayFactory::create('c', {1, inOutH, inOutW, inOutC}); - for (int h = 0; h < inOutH; h++) { - for (int w = 0; w < inOutW; w++) { - int hFrom = h - padTop; - int wFrom = w - padBottom; + for (int h = 0; h < inOutH; h++) { + for (int w = 0; w < inOutW; w++) { + int hFrom = h - padTop; + int wFrom = w - padBottom; - int hTo = hFrom + k; - int wTo = wFrom + k; + int hTo = hFrom + k; + int wTo = wFrom + k; - hFrom = sd::math::nd4j_max(0, hFrom); - wFrom = sd::math::nd4j_max(0, wFrom); + hFrom = sd::math::nd4j_max(0, hFrom); + wFrom = sd::math::nd4j_max(0, wFrom); - hTo = sd::math::nd4j_min(inOutH, hTo); - wTo = sd::math::nd4j_min(inOutW, wTo); + hTo = sd::math::nd4j_min(inOutH, hTo); + wTo = sd::math::nd4j_min(inOutW, wTo); - int idxOut[4]; - int idxIn[4]; - for (int ch = 0; ch < inOutC; ch++) { - idxOut[1] = h; - idxOut[2] = w; - idxOut[3] = ch; - idxIn[3] = ch; + int idxOut[4]; + int idxIn[4]; + for (int ch = 0; ch < inOutC; ch++) { + idxOut[1] = h; + idxOut[2] = w; + idxOut[3] = ch; + idxIn[3] = ch; - for (int kh = hFrom; kh < hTo; kh++) { - for (int kw = wFrom; kw < wTo; kw++) { - idxIn[1] = kh; - idxIn[2] = kw; + for (int kh = hFrom; kh < hTo; kh++) { + for (int kw = wFrom; kw < wTo; kw++) { + idxIn[1] = kh; + idxIn[2] = kw; - auto inVal = x.e(0, kh, kw, ch); - m.p(0, h, w, ch, inVal + m.e(0, h, w, ch)); - c.p(0, h, w, ch, 1 + c.e(0, h, w, ch)); - } - } - } + auto inVal = x.e(0, kh, kw, ch); + m.p(0, h, w, ch, inVal + m.e(0, h, w, ch)); + c.p(0, h, w, ch, 1 + c.e(0, h, w, ch)); + } } + } } - m /= c; - - ASSERT_EQ(m, *z); - + } + m /= c; + ASSERT_EQ(m, z); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, avgpool2d_12) { - - int bS=4, iH=10,iW=10, iC=3, kH=3,kW=3, sH=3,sW=3, pH=0,pW=0, dH=1,dW=1; - int oH=4, oW=4; - int paddingMode = 1; // 1-SAME, 0-VALID - int dataFormat = 1; // 1-NHWC, 0-NDHW - - auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); - auto expected = NDArrayFactory::create('c', {bS, oH, oW, iC}, { 17.5, 18.5, 19.5, 25. , 26. , 27. , 34. , 35. , 36. , 41.5, 42.5, 43.5, 92.5, 93.5, 94.5, 100. , 101. , 102. , 109. , 110. , 111. , 116.5, 117.5, 118.5, - 182.5, 183.5, 184.5, 190. , 191. , 192. , 199. , 200. , 201. , 206.5, 207.5, 208.5, 257.5, 258.5, 259.5, 265. , 266. , 267. , 274. , 275. , 276. , 281.5, 282.5, 283.5, - 317.5, 318.5, 319.5, 325. , 326. , 327. , 334. , 335. , 336. , 341.5, 342.5, 343.5, 392.5, 393.5, 394.5, 400. , 401. , 402. , 409. , 410. , 411. , 416.5, 417.5, 418.5, - 482.5, 483.5, 484.5, 490. , 491. , 492. , 499. , 500. , 501. , 506.5, 507.5, 508.5, 557.5, 558.5, 559.5, 565. , 566. , 567. , 574. , 575. , 576. , 581.5, 582.5, 583.5, - 617.5, 618.5, 619.5, 625. , 626. , 627. , 634. , 635. , 636. , 641.5, 642.5, 643.5, 692.5, 693.5, 694.5, 700. , 701. , 702. , 709. , 710. , 711. , 716.5, 717.5, 718.5, - 782.5, 783.5, 784.5, 790. , 791. , 792. , 799. , 800. , 801. , 806.5, 807.5, 808.5, 857.5, 858.5, 859.5, 865. , 866. , 867. , 874. , 875. , 876. , 881.5, 882.5, 883.5, - 917.5, 918.5, 919.5, 925. , 926. , 927. , 934. , 935. , 936. , 941.5, 942.5, 943.5, 992.5, 993.5, 994.5,1000. , 1001. , 1002. ,1009. , 1010. , 1011. ,1016.5, 1017.5, 1018.5, - 1082.5, 1083.5, 1084.5,1090. , 1091. , 1092. ,1099. , 1100. , 1101. ,1106.5, 1107.5, 1108.5,1157.5, 1158.5, 1159.5,1165. , 1166. , 1167. ,1174. , 1175. , 1176. ,1181.5, 1182.5, 1183.5}); - input.linspace(1.); - - sd::ops::avgpool2d op; - auto results = op.evaluate({&input}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, 0, dataFormat}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - - //output->printIndexedBuffer("output"); - //expected.printIndexedBuffer("expected"); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - - + int bS = 4, iH = 10, iW = 10, iC = 3, kH = 3, kW = 3, sH = 3, sW = 3, pH = 0, + pW = 0, dH = 1, dW = 1; + int oH = 4, oW = 4; + int paddingMode = 1; // 1-SAME, 0-VALID + int dataFormat = 1; // 1-NHWC, 0-NDHW + + auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); + auto expected = NDArrayFactory::create( + 'c', {bS, oH, oW, iC}, + {17.5, 18.5, 19.5, 25., 26., 27., 34., 35., 36., + 41.5, 42.5, 43.5, 92.5, 93.5, 94.5, 100., 101., 102., + 109., 110., 111., 116.5, 117.5, 118.5, 182.5, 183.5, 184.5, + 190., 191., 192., 199., 200., 201., 206.5, 207.5, 208.5, + 257.5, 258.5, 259.5, 265., 266., 267., 274., 275., 276., + 281.5, 282.5, 283.5, 317.5, 318.5, 319.5, 325., 326., 327., + 334., 335., 336., 341.5, 342.5, 343.5, 392.5, 393.5, 394.5, + 400., 401., 402., 409., 410., 411., 416.5, 417.5, 418.5, + 482.5, 483.5, 484.5, 490., 491., 492., 499., 500., 501., + 506.5, 507.5, 508.5, 557.5, 558.5, 559.5, 565., 566., 567., + 574., 575., 576., 581.5, 582.5, 583.5, 617.5, 618.5, 619.5, + 625., 626., 627., 634., 635., 636., 641.5, 642.5, 643.5, + 692.5, 693.5, 694.5, 700., 701., 702., 709., 710., 711., + 716.5, 717.5, 718.5, 782.5, 783.5, 784.5, 790., 791., 792., + 799., 800., 801., 806.5, 807.5, 808.5, 857.5, 858.5, 859.5, + 865., 866., 867., 874., 875., 876., 881.5, 882.5, 883.5, + 917.5, 918.5, 919.5, 925., 926., 927., 934., 935., 936., + 941.5, 942.5, 943.5, 992.5, 993.5, 994.5, 1000., 1001., 1002., + 1009., 1010., 1011., 1016.5, 1017.5, 1018.5, 1082.5, 1083.5, 1084.5, + 1090., 1091., 1092., 1099., 1100., 1101., 1106.5, 1107.5, 1108.5, + 1157.5, 1158.5, 1159.5, 1165., 1166., 1167., 1174., 1175., 1176., + 1181.5, 1182.5, 1183.5}); + input.linspace(1.); + + sd::ops::avgpool2d op; + auto results = op.evaluate( + {&input}, {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, 0, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + // output->printIndexedBuffer("output"); + // expected.printIndexedBuffer("expected"); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, avgpool2d_13) { - - const int bS = 2; // batch size - const int iD = 1; // input depth (number of picture channels, for example rgb=3) - const int iH = 28; // picture height in pixels - const int iW = 28; // picture width in pixels - const int kH = 5; // kernel height in pixels - const int kW = 5; // kernel width in pixels - const int sH = 1; // stride step in horizontal direction - const int sW = 1; // stride step in vertical direction - const int pH = 0; // padding height - const int pW = 0; // padding width - const int dH = 2; // dilation height - const int dW = 2; // dilation width - const int oH = (iH - kH - (kH-1)*(dH-1) + 2*pH)/sH + 1; // output height - const int oW = (iW - kW - (kW-1)*(dW-1) + 2*pW)/sW + 1; // output width - - auto x = NDArrayFactory::create_('c', {bS,iD,iH,iW}); - auto exp = NDArrayFactory::create('c',{bS,iD,oH,oW}); - // auto z('c',{bS,iD,oH,oW}); - - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, x); - // variableSpace->putVariable(1, &z); - - auto block = new Context(1, variableSpace, false); - block->fillInputs({-1}); - std::vector* argI = block->getIArguments(); - *argI = {kH,kW, sH,sW, pH,pW, dW,dH, 0, 0, 0}; // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; - - sd::ops::avgpool2d pooling; - Nd4jStatus status = pooling.execute(block); - ASSERT_EQ(ND4J_STATUS_OK, status); - - auto result = variableSpace->getVariable(block->getNodeId())->getNDArray(); - ASSERT_TRUE(exp.isSameShape(result)); - - - delete variableSpace; - delete block; + const int bS = 2; // batch size + const int iD = + 1; // input depth (number of picture channels, for example rgb=3) + const int iH = 28; // picture height in pixels + const int iW = 28; // picture width in pixels + const int kH = 5; // kernel height in pixels + const int kW = 5; // kernel width in pixels + const int sH = 1; // stride step in horizontal direction + const int sW = 1; // stride step in vertical direction + const int pH = 0; // padding height + const int pW = 0; // padding width + const int dH = 2; // dilation height + const int dW = 2; // dilation width + const int oH = + (iH - kH - (kH - 1) * (dH - 1) + 2 * pH) / sH + 1; // output height + const int oW = + (iW - kW - (kW - 1) * (dW - 1) + 2 * pW) / sW + 1; // output width + + auto x = NDArrayFactory::create('c', {bS, iD, iH, iW}); + auto exp = NDArrayFactory::create('c', {bS, iD, oH, oW}); + // auto z('c',{bS,iD,oH,oW}); + + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + // variableSpace->putVariable(1, &z); + + auto block = new Context(1, variableSpace, false); + block->fillInputs({-1}); + block->appendI( + {kH, kW, sH, sW, pH, pW, dW, dH, 0, 0, + 0}); // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad + // Height/Width; 6,7 - dilation Height/Width; 8 - same mode; + + sd::ops::avgpool2d pooling; + Nd4jStatus status = pooling.execute(block); + ASSERT_EQ(ND4J_STATUS_OK, status); + + auto result = variableSpace->getVariable(block->getNodeId())->getNDArray(); + ASSERT_TRUE(exp.isSameShape(*result)); + + delete variableSpace; + delete block; } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, avgpool2d_14) { - const int bS = 2; - const int iD = 1; - const int iH = 28; - const int iW = 28; - const int kH = 5; - const int kW = 5; - const int sH = 1; - const int sW = 1; - const int pH = 0; - const int pW = 0; - const int dH = 1; - const int dW = 1; - const int oH = (iH - kH - (kH-1)*(dH-1) + 2*pH)/sH + 1; // output height - const int oW = (iW - kW - (kW-1)*(dW-1) + 2*pW)/sW + 1; // output width - - - auto x = NDArrayFactory::create_('c', {bS,iD,iH,iW}); - auto exp = NDArrayFactory::create('c',{bS,iD,oH,oW}); - // auto z('c',{bS,iD,oH,oW}); - - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, x); - // variableSpace->putVariable(1, &z); - - auto block = new Context(1, variableSpace, false); - block->fillInputs({-1}); - std::vector* argI = block->getIArguments(); - *argI = {kH,kW, sH,sW, pH,pW, dW,dH, 0, 0, 0}; // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; - - sd::ops::avgpool2d pooling; - Nd4jStatus status = pooling.execute(block); - ASSERT_EQ(ND4J_STATUS_OK, status); - - auto result = variableSpace->getVariable(block->getNodeId())->getNDArray(); - // result->printShapeInfo(); - ASSERT_TRUE(exp.isSameShape(result)); - - delete variableSpace; - delete block; + const int bS = 2; + const int iD = 1; + const int iH = 28; + const int iW = 28; + const int kH = 5; + const int kW = 5; + const int sH = 1; + const int sW = 1; + const int pH = 0; + const int pW = 0; + const int dH = 1; + const int dW = 1; + const int oH = + (iH - kH - (kH - 1) * (dH - 1) + 2 * pH) / sH + 1; // output height + const int oW = + (iW - kW - (kW - 1) * (dW - 1) + 2 * pW) / sW + 1; // output width + + auto x = NDArrayFactory::create('c', {bS, iD, iH, iW}); + auto exp = NDArrayFactory::create('c', {bS, iD, oH, oW}); + // auto z('c',{bS,iD,oH,oW}); + + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + // variableSpace->putVariable(1, &z); + + auto block = new Context(1, variableSpace, false); + block->fillInputs({-1}); + block->appendI( + {kH, kW, sH, sW, pH, pW, dW, dH, 0, 0, + 0}); // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad + // Height/Width; 6,7 - dilation Height/Width; 8 - same mode; + + sd::ops::avgpool2d pooling; + Nd4jStatus status = pooling.execute(block); + ASSERT_EQ(ND4J_STATUS_OK, status); + + auto result = variableSpace->getVariable(block->getNodeId())->getNDArray(); + // result->printShapeInfo(); + ASSERT_TRUE(exp.isSameShape(*result)); + + delete variableSpace; + delete block; } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, Avgpool2d_test15) { - const int bS = 2; - const int iD = 1; - const int iH = 28; - const int iW = 28; - const int kH = 5; - const int kW = 5; - const int sH = 1; - const int sW = 1; - const int pH = 0; - const int pW = 0; - const int dH = 1; - const int dW = 1; - const int oH = (int) sd::math::nd4j_ceil(iH * 1.f / sH); - const int oW = (int) sd::math::nd4j_ceil(iW * 1.f / sW); - - - auto x = NDArrayFactory::create_('c', {bS,iD,iH,iW}); - auto exp = NDArrayFactory::create('c',{bS,iD,oH,oW}); - // auto z('c',{bS,iD,oH,oW}); - - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, x); - // variableSpace->putVariable(1, &z); - - auto block = new Context(1, variableSpace, false); - block->fillInputs({-1}); - std::vector* argI = block->getIArguments(); - *argI = {kH,kW, sH,sW, pH,pW, dW,dH, 1, 0, 0}; // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; - - sd::ops::avgpool2d pooling; - Nd4jStatus status = pooling.execute(block); - ASSERT_EQ(ND4J_STATUS_OK, status); - - auto result = variableSpace->getVariable(block->getNodeId())->getNDArray(); - // result->printShapeInfo(); - ASSERT_TRUE(exp.isSameShape(result)); - - delete variableSpace; - delete block; + const int bS = 2; + const int iD = 1; + const int iH = 28; + const int iW = 28; + const int kH = 5; + const int kW = 5; + const int sH = 1; + const int sW = 1; + const int pH = 0; + const int pW = 0; + const int dH = 1; + const int dW = 1; + const int oH = (int)sd::math::nd4j_ceil(iH * 1.f / sH); + const int oW = (int)sd::math::nd4j_ceil(iW * 1.f / sW); + + auto x = NDArrayFactory::create('c', {bS, iD, iH, iW}); + auto exp = NDArrayFactory::create('c', {bS, iD, oH, oW}); + // auto z('c',{bS,iD,oH,oW}); + + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + // variableSpace->putVariable(1, &z); + + auto block = new Context(1, variableSpace, false); + block->fillInputs({-1}); + block->appendI( + {kH, kW, sH, sW, pH, pW, dW, dH, 1, 0, + 0}); // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad + // Height/Width; 6,7 - dilation Height/Width; 8 - same mode; + + sd::ops::avgpool2d pooling; + Nd4jStatus status = pooling.execute(block); + ASSERT_EQ(ND4J_STATUS_OK, status); + + auto result = variableSpace->getVariable(block->getNodeId())->getNDArray(); + // result->printShapeInfo(); + ASSERT_TRUE(exp.isSameShape(*result)); + + delete variableSpace; + delete block; } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, avgpool2d_16) { + int bS = 2, iH = 4, iW = 4, iC = 2, kH = 2, kW = 2, sH = 2, sW = 2, pH = 0, + pW = 0, dH = 1, dW = 1; + int oH = 2, oW = 2; + int paddingMode = 1; // 1-SAME, 0-VALID + int dataFormat = 1; // 1-NHWC, 0-NDHW - int bS=2, iH=4,iW=4, iC=2, kH=2,kW=2, sH=2,sW=2, pH=0,pW=0, dH=1,dW=1; - int oH=2,oW=2; - int paddingMode = 1; // 1-SAME, 0-VALID - int dataFormat = 1; // 1-NHWC, 0-NDHW + NDArray input('c', {bS, iH, iW, iC}, sd::DataType::FLOAT32); + NDArray output('f', {bS, oH, oW, iC}, sd::DataType::FLOAT32); + NDArray expected('c', {bS, oH, oW, iC}, + {6.f, 7.f, 10.f, 11.f, 22.f, 23.f, 26.f, 27.f, 38.f, 39.f, + 42.f, 43.f, 54.f, 55.f, 58.f, 59.f}, + sd::DataType::FLOAT32); - NDArray input('c', {bS, iH, iW, iC}, sd::DataType::FLOAT32); - NDArray output('f', {bS, oH, oW, iC}, sd::DataType::FLOAT32); - NDArray expected('c', {bS, oH, oW, iC}, {6.f, 7.f, 10.f, 11.f, 22.f, 23.f, 26.f, 27.f, 38.f, 39.f, 42.f, 43.f, 54.f, 55.f, 58.f, 59.f}, sd::DataType::FLOAT32); + input.linspace(1.); - input.linspace(1.); + sd::ops::avgpool2d op; + auto status = op.execute( + {&input}, {&output}, {}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, 0, dataFormat}, {}); - sd::ops::avgpool2d op; - auto status = op.execute({&input}, {&output}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, 0, dataFormat}, {}); + ASSERT_EQ(Status::OK(), status); - ASSERT_EQ(Status::OK(), status); + // output.printBuffer(); + // expected.printIndexedBuffer("expected"); - // output.printBuffer(); - //expected.printIndexedBuffer("expected"); - - ASSERT_TRUE(expected.equalsTo(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, biasadd_1) { - auto x = NDArrayFactory::create('c', {2, 3, 3, 2}); - auto bias = NDArrayFactory::create('c', {2}, {1, 2}); - auto exp = NDArrayFactory::create('c', {2, 3, 3, 2}, {1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f}); - - sd::ops::biasadd op; - auto result = op.evaluate({&x, &bias}, {}, {}, {}); + auto x = NDArrayFactory::create('c', {2, 3, 3, 2}); + auto bias = NDArrayFactory::create('c', {2}, {1, 2}); + auto exp = NDArrayFactory::create( + 'c', {2, 3, 3, 2}, + {1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, + 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, + 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::biasadd op; + auto result = op.evaluate({&x, &bias}, {}, {}, {}); - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests4, biasadd_2) { - auto x = NDArrayFactory::create('c', {2, 2, 3, 3}); - auto bias = NDArrayFactory::create('c', {2}, {1, 2}); - auto exp = NDArrayFactory::create('c', {2, 2, 3, 3}, {1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2}); - - sd::ops::biasadd op; - auto result = op.evaluate({&x, &bias}, {}, {}, {true}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto x = NDArrayFactory::create('c', {2, 2, 3, 3}); + auto bias = NDArrayFactory::create('c', {2}, {1, 2}); + auto exp = NDArrayFactory::create( + 'c', {2, 2, 3, 3}, + {1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2}); - auto z = result.at(0); + sd::ops::biasadd op; + auto result = op.evaluate({&x, &bias}, {}, {}, {true}); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests4, biasadd_3) { - auto x = NDArrayFactory::create('c', {2, 3}); - auto row = NDArrayFactory::create('c', {3}, {1, 2, 3}); - auto exp = NDArrayFactory::create('c', {2, 3}, {1, 2, 3, 1, 2, 3}); - - sd::ops::biasadd op; - auto result = op.evaluate({&x, &row}, {}, {}, {true}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto x = NDArrayFactory::create('c', {2, 3}); + auto row = NDArrayFactory::create('c', {3}, {1, 2, 3}); + auto exp = NDArrayFactory::create('c', {2, 3}, {1, 2, 3, 1, 2, 3}); - auto z = result.at(0); + sd::ops::biasadd op; + auto result = op.evaluate({&x, &row}, {}, {}, {true}); - ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, biasadd_bp_1) { + NDArray x('c', {2, 2, 2, 3}, {1., 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}, + sd::DataType::FLOAT32); + NDArray gradO('c', {2, 2, 2, 3}, sd::DataType::FLOAT32); + NDArray bias('c', {3}, {-1., -2, -3}, sd::DataType::FLOAT32); - NDArray x('c', {2,2,2,3}, {1.,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24}, sd::DataType::FLOAT32); - NDArray gradO('c', {2,2,2,3}, sd::DataType::FLOAT32); - NDArray bias('c', {3}, {-1., -2, -3}, sd::DataType::FLOAT32); + NDArray expGradB('c', {3}, {9.2, 10., 10.8}, sd::DataType::FLOAT32); - NDArray expGradB('c', {3}, {9.2, 10. , 10.8}, sd::DataType::FLOAT32); + gradO.linspace(0.1, 0.1); - gradO.linspace(0.1, 0.1); + sd::ops::biasadd_bp op; + auto result = op.evaluate({&x, &bias, &gradO}, {}, {}, {false}); // NHWC - sd::ops::biasadd_bp op; - auto result = op.evaluate({&x, &bias, &gradO}, {}, {}, {false}); // NHWC + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto gradI = result.at(0); - auto gradB = result.at(1); - - ASSERT_TRUE(gradI->isSameShape(gradO)); - ASSERT_TRUE(gradI->equalsTo(gradO)); - - ASSERT_TRUE(gradB->isSameShape(expGradB)); - ASSERT_TRUE(gradB->equalsTo(expGradB)); + auto gradI = result.at(0); + auto gradB = result.at(1); + ASSERT_TRUE(gradI.isSameShape(gradO)); + ASSERT_TRUE(gradI.equalsTo(gradO)); + ASSERT_TRUE(gradB.isSameShape(expGradB)); + ASSERT_TRUE(gradB.equalsTo(expGradB)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, biasadd_bp_2) { + NDArray x('c', {2, 3, 2, 2}, {1., 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}, + sd::DataType::FLOAT32); + NDArray gradO('c', {2, 3, 2, 2}, sd::DataType::FLOAT32); + NDArray bias('c', {3}, {-1., -2, -3}, sd::DataType::FLOAT32); - NDArray x('c', {2,3,2,2}, {1.,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24}, sd::DataType::FLOAT32); - NDArray gradO('c', {2,3,2,2}, sd::DataType::FLOAT32); - NDArray bias('c', {3}, {-1., -2, -3}, sd::DataType::FLOAT32); - - NDArray expGradB('c', {3}, {6.8, 10., 13.2}, sd::DataType::FLOAT32); - - gradO.linspace(0.1, 0.1); + NDArray expGradB('c', {3}, {6.8, 10., 13.2}, sd::DataType::FLOAT32); - sd::ops::biasadd_bp op; - auto result = op.evaluate({&x, &bias, &gradO}, {}, {}, {true}); // NCHW + gradO.linspace(0.1, 0.1); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::biasadd_bp op; + auto result = op.evaluate({&x, &bias, &gradO}, {}, {}, {true}); // NCHW - auto gradI = result.at(0); - auto gradB = result.at(1); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(gradI->isSameShape(gradO)); - ASSERT_TRUE(gradI->equalsTo(gradO)); - - ASSERT_TRUE(gradB->isSameShape(expGradB)); - ASSERT_TRUE(gradB->equalsTo(expGradB)); + auto gradI = result.at(0); + auto gradB = result.at(1); + ASSERT_TRUE(gradI.isSameShape(gradO)); + ASSERT_TRUE(gradI.equalsTo(gradO)); + ASSERT_TRUE(gradB.isSameShape(expGradB)); + ASSERT_TRUE(gradB.equalsTo(expGradB)); } TEST_F(DeclarableOpsTests4, biasadd_4) { - if (!Environment::getInstance().isExperimentalBuild()) - return; + if (!Environment::getInstance().isExperimentalBuild()) return; - auto x = NDArrayFactory::create('c', {2, 3}); - auto y = NDArrayFactory::create('c', {3}, {1.f, 2.f, 3.f}); - auto z = NDArrayFactory::create('c', {2, 3}); - auto exp = NDArrayFactory::create('c', {2, 3}, {1.f, 2.f, 3.f, 1.f, 2.f, 3.f}); + auto x = NDArrayFactory::create('c', {2, 3}); + auto y = NDArrayFactory::create('c', {3}, {1.f, 2.f, 3.f}); + auto z = NDArrayFactory::create('c', {2, 3}); + auto exp = NDArrayFactory::create('c', {2, 3}, + {1.f, 2.f, 3.f, 1.f, 2.f, 3.f}); - sd::ops::biasadd op; - auto status = op.execute({&x, &y}, {&z}, {}, {}, {true}); - ASSERT_EQ(Status::OK(), status); + sd::ops::biasadd op; + auto status = op.execute({&x, &y}, {&z}, {}, {}, {true}); + ASSERT_EQ(Status::OK(), status); - ASSERT_EQ(exp, z); + ASSERT_EQ(exp, z); } TEST_F(DeclarableOpsTests4, Test_Fill_1) { - auto x = NDArrayFactory::create('c', {1, 3}, {3, 2, 4}); - auto v = NDArrayFactory::create(2.); - auto exp = NDArrayFactory::create('c', {3, 2, 4}); - exp.assign(2.0f); - - sd::ops::fill op; - auto result = op.evaluate({&x, &v}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto x = NDArrayFactory::create('c', {1, 3}, {3, 2, 4}); + auto v = NDArrayFactory::create(2.); + auto exp = NDArrayFactory::create('c', {3, 2, 4}); + exp.assign(2.0f); - auto z = result.at(0); + sd::ops::fill op; + auto result = op.evaluate({&x, &v}); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests4, Test_FirasSparce_1) { - auto x = NDArrayFactory::create('c', {1, 81}); - auto exp = NDArrayFactory::create('c', {1, 2}, {0, 1}); - - x.p(51, 1); - x.p(52, 0); - x.p(60, 1); - x.p(61, 0); - sd::ops::firas_sparse op; - auto result = op.evaluate({&x}, {0, 1}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto x = NDArrayFactory::create('c', {1, 81}); + auto exp = NDArrayFactory::create('c', {1, 2}, {0, 1}); - auto z = result.at(0); -// z->printIndexedBuffer("FIRAS"); -// z->printShapeInfo("OUTSHAPE"); -// ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + x.p(51, 1); + x.p(52, 0); + x.p(60, 1); + x.p(61, 0); + sd::ops::firas_sparse op; + auto result = op.evaluate({&x}, {0, 1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + // z->printIndexedBuffer("FIRAS"); + // z->printShapeInfo("OUTSHAPE"); + // ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests4, Test_FlattenTests_1) { - auto x = NDArrayFactory::create('c', {3, 3, 3, 3}); - auto exp = NDArrayFactory::create('c', {81}); + auto x = NDArrayFactory::create('c', {3, 3, 3, 3}); + auto exp = NDArrayFactory::create('c', {81}); - x.linspace(1); - exp.linspace(1); - sd::ops::flatten op; - auto result = op.evaluate({&x}, {}, {'c'}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); -// z->printIndexedBuffer("Flatten1"); -// z->printShapeInfo("Flatten1 shape"); - ASSERT_TRUE(exp.equalsTo(z)); + x.linspace(1); + exp.linspace(1); + sd::ops::flatten op; + auto result = op.evaluate({&x}, {}, {'c'}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + // z->printIndexedBuffer("Flatten1"); + // z->printShapeInfo("Flatten1 shape"); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests4, Test_FlattenTests_2) { - auto x = NDArrayFactory::create('c', {3, 3, 3, 3}); - auto y = NDArrayFactory::create('c', {3, 3}); - auto exp = NDArrayFactory::create('c', {90}); - - x.linspace(1); - y.linspace(82); - exp.linspace(1); - sd::ops::flatten op; - auto result = op.evaluate({&x, &y}, {}, {'c'}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto x = NDArrayFactory::create('c', {3, 3, 3, 3}); + auto y = NDArrayFactory::create('c', {3, 3}); + auto exp = NDArrayFactory::create('c', {90}); - auto z = result.at(0); -// z->printIndexedBuffer("Flatten2"); -// z->printShapeInfo("Flatten2 shape"); - - ASSERT_TRUE(exp.equalsTo(z)); + x.linspace(1); + y.linspace(82); + exp.linspace(1); + sd::ops::flatten op; + auto result = op.evaluate({&x, &y}, {}, {'c'}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + // z->printIndexedBuffer("Flatten2"); + // z->printShapeInfo("Flatten2 shape"); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests4, Test_FlattenTests_3) { - NDArray x('c', {2,2}, {1, 2, 3, 4}, sd::DataType::INT32); - NDArray y('f', {2,2}, sd::DataType::INT32); - NDArray exp('c', {8}, {1, 2, 3, 4, 1, 2, 3, 4}, sd::DataType::INT32); - - y.assign(x); + NDArray x('c', {2, 2}, {1, 2, 3, 4}, sd::DataType::INT32); + NDArray y('f', {2, 2}, sd::DataType::INT32); + NDArray exp('c', {8}, {1, 2, 3, 4, 1, 2, 3, 4}, sd::DataType::INT32); - sd::ops::flatten op; - auto result = op.evaluate({&x, &y}, {}, {'c'}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + y.assign(x); - auto z = result.at(0); - - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::flatten op; + auto result = op.evaluate({&x, &y}, {}, {'c'}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests4, Test_FlattenTests_4) { - NDArray x('c', {2,2}, {1, 2, 3, 4}, sd::DataType::INT32); - NDArray y('f', {2,2}, sd::DataType::INT32); - NDArray exp('c', {8}, {1, 3, 2, 4, 1, 3, 2, 4}, sd::DataType::INT32); - - y.assign(x); - - sd::ops::flatten op; - auto result = op.evaluate({&x, &y}, {}, {'f'}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + NDArray x('c', {2, 2}, {1, 2, 3, 4}, sd::DataType::INT32); + NDArray y('f', {2, 2}, sd::DataType::INT32); + NDArray exp('c', {8}, {1, 3, 2, 4, 1, 3, 2, 4}, sd::DataType::INT32); - auto z = result.at(0); + y.assign(x); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::flatten op; + auto result = op.evaluate({&x, &y}, {}, {'f'}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests4, Test_FloorTests_1) { - auto x = NDArrayFactory::create('c', {3, 3}, {1.5, 2.3, 3.4, 4.3, 5.9, 6.1, 7.2, 8.9, 9.7}); - auto exp = NDArrayFactory::create('c', {3,3}); + auto x = NDArrayFactory::create( + 'c', {3, 3}, {1.5, 2.3, 3.4, 4.3, 5.9, 6.1, 7.2, 8.9, 9.7}); + auto exp = NDArrayFactory::create('c', {3, 3}); - exp.linspace(1); - sd::ops::Floor op; - auto result = op.evaluate({&x}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); -// z->printIndexedBuffer("Flatten1"); -// z->printShapeInfo("Flatten1 shape"); - ASSERT_TRUE(exp.equalsTo(z)); + exp.linspace(1); + sd::ops::Floor op; + auto result = op.evaluate({&x}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + // z->printIndexedBuffer("Flatten1"); + // z->printShapeInfo("Flatten1 shape"); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests4, Test_Split_1) { - auto x = NDArrayFactory::create('c', {5, 30}); - auto sizes = NDArrayFactory::create('c', {1, 3}, {4, 15, 11}); - - std::vector list0({0,0, 0,4}); - std::vector list1({0,0, 4,19}); - std::vector list2({0,0, 19,30}); + auto x = NDArrayFactory::create('c', {5, 30}); + auto sizes = NDArrayFactory::create('c', {1, 3}, {4, 15, 11}); - auto sub0 = x(list0, true); - auto sub1 = x(list1, true); - auto sub2 = x(list2, true); + std::vector list0({0, 0, 0, 4}); + std::vector list1({0, 0, 4, 19}); + std::vector list2({0, 0, 19, 30}); - sub0.assign(0.0); - sub1.assign(1.0); - sub2.assign(2.0); + auto sub0 = x(list0, true); + auto sub1 = x(list1, true); + auto sub2 = x(list2, true); + sub0.assign(0.0); + sub1.assign(1.0); + sub2.assign(2.0); - sd::ops::split_v op; - auto result = op.evaluate({&x, &sizes}, {}, {1}); + sd::ops::split_v op; + auto result = op.evaluate({&x, &sizes}, {}, {1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(3, result.size()); + ASSERT_EQ(3, result.size()); - auto z0 = result.at(0); - auto z1 = result.at(1); - auto z2 = result.at(2); - - ASSERT_TRUE(sub0.isSameShape(z0)); - ASSERT_TRUE(sub1.isSameShape(z1)); - ASSERT_TRUE(sub2.isSameShape(z2)); - - ASSERT_TRUE(sub0.equalsTo(z0)); - ASSERT_TRUE(sub1.equalsTo(z1)); - ASSERT_TRUE(sub2.equalsTo(z2)); + auto z0 = result.at(0); + auto z1 = result.at(1); + auto z2 = result.at(2); + ASSERT_TRUE(sub0.isSameShape(z0)); + ASSERT_TRUE(sub1.isSameShape(z1)); + ASSERT_TRUE(sub2.isSameShape(z2)); + ASSERT_TRUE(sub0.equalsTo(z0)); + ASSERT_TRUE(sub1.equalsTo(z1)); + ASSERT_TRUE(sub2.equalsTo(z2)); } // special test for TF mode, when axis goes first TEST_F(DeclarableOpsTests4, Test_Split_2) { - auto x = NDArrayFactory::create('c', {5, 12}); - auto axis = NDArrayFactory::create('c', {1, 1}, {1.f}); - - std::vector list0 = {0,0, 0,3}; - std::vector list1 = {0,0, 3,6}; - std::vector list2 = {0,0, 6,9}; - std::vector list3 = {0,0, 9,12}; - - auto sub0 = x(list0, true); - auto sub1 = x(list1, true); - auto sub2 = x(list2, true); - auto sub3 = x(list3, true); - - sub0.assign(0.0f); - sub1.assign(1.0f); - sub2.assign(2.0f); - sub3.assign(3.0f); + auto x = NDArrayFactory::create('c', {5, 12}); + auto axis = NDArrayFactory::create('c', {1, 1}, {1.f}); + std::vector list0 = {0, 0, 0, 3}; + std::vector list1 = {0, 0, 3, 6}; + std::vector list2 = {0, 0, 6, 9}; + std::vector list3 = {0, 0, 9, 12}; - sd::ops::split op; - auto result = op.evaluate({&axis, &x}, {}, {4}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto sub0 = x(list0, true); + auto sub1 = x(list1, true); + auto sub2 = x(list2, true); + auto sub3 = x(list3, true); - auto z0 = result.at(0); - auto z1 = result.at(1); - auto z2 = result.at(2); - auto z3 = result.at(3); + sub0.assign(0.0f); + sub1.assign(1.0f); + sub2.assign(2.0f); + sub3.assign(3.0f); - ASSERT_TRUE(sub0.isSameShape(z0)); - ASSERT_TRUE(sub1.isSameShape(z1)); - ASSERT_TRUE(sub2.isSameShape(z2)); - ASSERT_TRUE(sub3.isSameShape(z3)); + sd::ops::split op; + auto result = op.evaluate({&axis, &x}, {}, {4}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(sub0.equalsTo(z0)); - ASSERT_TRUE(sub1.equalsTo(z1)); - ASSERT_TRUE(sub2.equalsTo(z2)); - ASSERT_TRUE(sub3.equalsTo(z3)); + auto z0 = result.at(0); + auto z1 = result.at(1); + auto z2 = result.at(2); + auto z3 = result.at(3); + ASSERT_TRUE(sub0.isSameShape(z0)); + ASSERT_TRUE(sub1.isSameShape(z1)); + ASSERT_TRUE(sub2.isSameShape(z2)); + ASSERT_TRUE(sub3.isSameShape(z3)); + ASSERT_TRUE(sub0.equalsTo(z0)); + ASSERT_TRUE(sub1.equalsTo(z1)); + ASSERT_TRUE(sub2.equalsTo(z2)); + ASSERT_TRUE(sub3.equalsTo(z3)); } // special test for TF mode, when axis goes first TEST_F(DeclarableOpsTests4, Test_Split_3) { - auto x = NDArrayFactory::create('c', {6, 12}); - auto axis = NDArrayFactory::create('c', {1, 1}, {0.f}); + auto x = NDArrayFactory::create('c', {6, 12}); + auto axis = NDArrayFactory::create('c', {1, 1}, {0.f}); - std::vector list0 = {0,2, 0,0}; - std::vector list1 = {2,4, 0,0}; - std::vector list2 = {4,6, 0,0}; + std::vector list0 = {0, 2, 0, 0}; + std::vector list1 = {2, 4, 0, 0}; + std::vector list2 = {4, 6, 0, 0}; - auto sub0 = x(list0, true); - auto sub1 = x(list1, true); - auto sub2 = x(list2, true); + auto sub0 = x(list0, true); + auto sub1 = x(list1, true); + auto sub2 = x(list2, true); - sub0.assign(0.0f); - sub1.assign(1.0f); - sub2.assign(2.0f); + sub0.assign(0.0f); + sub1.assign(1.0f); + sub2.assign(2.0f); - sd::ops::split op; - auto result = op.evaluate({&axis, &x}, {}, {3}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::split op; + auto result = op.evaluate({&axis, &x}, {}, {3}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z0 = result.at(0); - auto z1 = result.at(1); - auto z2 = result.at(2); + auto z0 = result.at(0); + auto z1 = result.at(1); + auto z2 = result.at(2); - ASSERT_TRUE(sub0.isSameShape(z0)); - ASSERT_TRUE(sub1.isSameShape(z1)); - ASSERT_TRUE(sub2.isSameShape(z2)); + ASSERT_TRUE(sub0.isSameShape(z0)); + ASSERT_TRUE(sub1.isSameShape(z1)); + ASSERT_TRUE(sub2.isSameShape(z2)); - ASSERT_TRUE(sub0.equalsTo(z0)); - ASSERT_TRUE(sub1.equalsTo(z1)); - ASSERT_TRUE(sub2.equalsTo(z2)); + ASSERT_TRUE(sub0.equalsTo(z0)); + ASSERT_TRUE(sub1.equalsTo(z1)); + ASSERT_TRUE(sub2.equalsTo(z2)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, split_test4) { + auto input = NDArrayFactory::create( + 'c', {10}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f}); + auto axis = NDArrayFactory::create(-1); + auto exp1 = + NDArrayFactory::create('c', {5}, {1.f, 2.f, 3.f, 4.f, 5.f}); + auto exp2 = + NDArrayFactory::create('c', {5}, {6.f, 7.f, 8.f, 9.f, 10.f}); - auto input = NDArrayFactory::create('c', {10},{1.f,2.f,3.f,4.f,5.f,6.f,7.f,8.f,9.f,10.f}); - auto axis = NDArrayFactory::create(-1); - auto exp1 = NDArrayFactory::create('c', {5}, {1.f,2.f,3.f,4.f,5.f}); - auto exp2 = NDArrayFactory::create('c', {5}, {6.f,7.f,8.f,9.f,10.f}); + sd::ops::split op; + auto results = op.evaluate({&input, &axis}, {}, {2}, {}); - sd::ops::split op; - auto results = op.evaluate({&input, &axis}, {}, {2}, {}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto out1 = results.at(0); + auto out2 = results.at(1); - auto out1 = results.at(0); - auto out2 = results.at(1); - - ASSERT_TRUE(exp1.isSameShape(out1)); - ASSERT_TRUE(exp2.isSameShape(out2)); - ASSERT_TRUE(exp1.equalsTo(out1)); - ASSERT_TRUE(exp2.equalsTo(out2)); + ASSERT_TRUE(exp1.isSameShape(out1)); + ASSERT_TRUE(exp2.isSameShape(out2)); + ASSERT_TRUE(exp1.equalsTo(out1)); + ASSERT_TRUE(exp2.equalsTo(out2)); } - /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, split_test5) { + auto input = NDArrayFactory::create( + 'c', {3, 8}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, + 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f}); + auto exp1 = NDArrayFactory::create( + 'c', {3, 4}, + {1.f, 2.f, 3.f, 4.f, 9.f, 10.f, 11.f, 12.f, 17.f, 18.f, 19.f, 20.f}); + auto exp2 = NDArrayFactory::create( + 'c', {3, 4}, + {5.f, 6.f, 7.f, 8.f, 13.f, 14.f, 15.f, 16.f, 21.f, 22.f, 23.f, 24.f}); - auto input = NDArrayFactory::create('c', {3,8},{1.f,2.f,3.f,4.f,5.f,6.f,7.f,8.f,9.f,10.f,11.f,12.f,13.f,14.f,15.f,16.f,17.f,18.f,19.f,20.f,21.f,22.f,23.f,24.f}); - auto exp1 = NDArrayFactory::create('c', {3,4}, {1.f,2.f,3.f,4.f, 9.f,10.f,11.f,12.f, 17.f,18.f,19.f,20.f}); - auto exp2 = NDArrayFactory::create('c', {3,4}, {5.f,6.f,7.f,8.f, 13.f,14.f,15.f,16.f, 21.f,22.f,23.f,24.f}); - - sd::ops::split op; - auto results = op.evaluate({&input}, {}, {2,-1},{}); + sd::ops::split op; + auto results = op.evaluate({&input}, {}, {2, -1}, {}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto out1 = results.at(0); - auto out2 = results.at(1); + auto out1 = results.at(0); + auto out2 = results.at(1); - ASSERT_TRUE(exp1.isSameShape(out1)); - ASSERT_TRUE(exp2.isSameShape(out2)); - ASSERT_TRUE(exp1.equalsTo(out1)); - ASSERT_TRUE(exp2.equalsTo(out2)); + ASSERT_TRUE(exp1.isSameShape(out1)); + ASSERT_TRUE(exp2.isSameShape(out2)); + ASSERT_TRUE(exp1.equalsTo(out1)); + ASSERT_TRUE(exp2.equalsTo(out2)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, split_test6) { + NDArray input('c', {0, 4}, sd::DataType::FLOAT32); + std::vector expShape = {0, 1}; - NDArray input('c', {0,4}, sd::DataType::FLOAT32); - std::vector expShape = {0,1}; + const int numSplits = 4; + const int axis = 1; - const int numSplits = 4; - const int axis = 1; + sd::ops::split op; + auto results = op.evaluate({&input}, {}, {numSplits, axis}, {}); - sd::ops::split op; - auto results = op.evaluate({&input}, {}, {numSplits, axis}, {}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - for (int i = 0; i < numSplits; ++i) - ASSERT_TRUE(results.at(i)->isSameShape(expShape)); + for (int i = 0; i < numSplits; ++i) + ASSERT_TRUE(results.at(i).isSameShape(expShape)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, split_test7) { + NDArray input('c', {0, 4}, sd::DataType::FLOAT32); + std::vector expShape = {0, 4}; - NDArray input('c', {0,4}, sd::DataType::FLOAT32); - std::vector expShape = {0,4}; - - const int numSplits = 4; - const int axis = 0; + const int numSplits = 4; + const int axis = 0; - sd::ops::split op; - auto results = op.evaluate({&input}, {}, {numSplits, axis}, {}); + sd::ops::split op; + auto results = op.evaluate({&input}, {}, {numSplits, axis}, {}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - for (int i = 0; i < numSplits; ++i) - ASSERT_TRUE(results.at(i)->isSameShape(expShape)); + for (int i = 0; i < numSplits; ++i) + ASSERT_TRUE(results.at(i).isSameShape(expShape)); } - TEST_F(DeclarableOpsTests4, Test_Squeeze_args_1) { - auto x = NDArrayFactory::create('c', {2, 1, 1, 1, 2}, {1, 2, 3, 4}); - auto exp = NDArrayFactory::create('c', {2, 1, 2}, {1, 2, 3, 4}); - - sd::ops::squeeze op; - auto result = op.evaluate({&x}, {}, {1, 3}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); + auto x = NDArrayFactory::create('c', {2, 1, 1, 1, 2}, {1, 2, 3, 4}); + auto exp = NDArrayFactory::create('c', {2, 1, 2}, {1, 2, 3, 4}); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::squeeze op; + auto result = op.evaluate({&x}, {}, {1, 3}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests4, Test_Squeeze_args_2) { - auto x = NDArrayFactory::create('c', {2, 1, 1, 1, 2}, {1, 2, 3, 4}); - auto y = NDArrayFactory::create('c', {2}, {1.f, 3.f}); - auto exp = NDArrayFactory::create('c', {2, 1, 2}, {1, 2, 3, 4}); - - sd::ops::squeeze op; - auto result = op.evaluate({&x, &y}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); + auto x = NDArrayFactory::create('c', {2, 1, 1, 1, 2}, {1, 2, 3, 4}); + auto y = NDArrayFactory::create('c', {2}, {1.f, 3.f}); + auto exp = NDArrayFactory::create('c', {2, 1, 2}, {1, 2, 3, 4}); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::squeeze op; + auto result = op.evaluate({&x, &y}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(DeclarableOpsTests4, Test_Squeeze_args_3) { - auto x = NDArrayFactory::create('c', {2, 1, 1, 1, 2}, {1, 2, 3, 4}); - auto exp = NDArrayFactory::create('c', {2, 1, 2}, {1, 2, 3, 4}); - - sd::ops::squeeze op; - auto result = op.evaluate({&x}, {}, {-2, -3}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); + auto x = NDArrayFactory::create('c', {2, 1, 1, 1, 2}, {1, 2, 3, 4}); + auto exp = NDArrayFactory::create('c', {2, 1, 2}, {1, 2, 3, 4}); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::squeeze op; + auto result = op.evaluate({&x}, {}, {-2, -3}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests4, Test_SpaceToDepth_1) { - auto x = NDArrayFactory::create('c', {1, 2, 2, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); - auto exp = NDArrayFactory::create('c', {1, 1, 1, 12}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + auto x = NDArrayFactory::create( + 'c', {1, 2, 2, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + auto exp = NDArrayFactory::create( + 'c', {1, 1, 1, 12}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); - sd::ops::space_to_depth op; - auto result = op.evaluate({&x}, {}, {2, 1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::space_to_depth op; + auto result = op.evaluate({&x}, {}, {2, 1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests4, Test_SpaceToDepth_2) { - auto x = NDArrayFactory::create('c', {1, 3, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); - auto exp = NDArrayFactory::create('c', {1, 12, 1, 1}, {1, 5, 9, 2, 6, 10, 3, 7, 11, 4, 8, 12}); - - sd::ops::space_to_depth op; - auto result = op.evaluate({&x}, {}, {2, 0}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto x = NDArrayFactory::create( + 'c', {1, 3, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + auto exp = NDArrayFactory::create( + 'c', {1, 12, 1, 1}, {1, 5, 9, 2, 6, 10, 3, 7, 11, 4, 8, 12}); - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::space_to_depth op; + auto result = op.evaluate({&x}, {}, {2, 0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(DeclarableOpsTests4, Test_DepthToSpace_1) { - auto x = NDArrayFactory::create('c', {1, 1, 1, 12}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); - auto exp = NDArrayFactory::create('c', {1, 2, 2, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); - - sd::ops::depth_to_space op; - auto result = op.evaluate({&x}, {}, {2, 1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto x = NDArrayFactory::create( + 'c', {1, 1, 1, 12}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + auto exp = NDArrayFactory::create( + 'c', {1, 2, 2, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::depth_to_space op; + auto result = op.evaluate({&x}, {}, {2, 1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(DeclarableOpsTests4, Test_DepthToSpace_2) { - auto x = NDArrayFactory::create('c', {1, 12, 1, 1}, {1, 5, 9, 2, 6, 10, 3, 7, 11, 4, 8, 12}); - auto exp = NDArrayFactory::create('c', {1, 3, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + auto x = NDArrayFactory::create( + 'c', {1, 12, 1, 1}, {1, 5, 9, 2, 6, 10, 3, 7, 11, 4, 8, 12}); + auto exp = NDArrayFactory::create( + 'c', {1, 3, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); - sd::ops::depth_to_space op; - auto result = op.evaluate({&x}, {}, {2, 0}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::depth_to_space op; + auto result = op.evaluate({&x}, {}, {2, 0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests4, Test_DepthToSpace_3) { - auto x = NDArrayFactory::create('c', {4, 4, 16, 16}); - auto exp = NDArrayFactory::create('c', {4, 16, 64, 1}); - - sd::ops::depth_to_space op; - auto result = op.evaluate({&x}, {}, {4, 1}); - ASSERT_EQ(Status::OK(), result.status()); + auto x = NDArrayFactory::create('c', {4, 4, 16, 16}); + auto exp = NDArrayFactory::create('c', {4, 16, 64, 1}); - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); + sd::ops::depth_to_space op; + auto result = op.evaluate({&x}, {}, {4, 1}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); } - TEST_F(DeclarableOpsTests4, Test_Cross_1) { - auto a = NDArrayFactory::create('c', {3}, {1, 2, 3}); - auto b = NDArrayFactory::create('c', {3}, {6, 7, 8}); - auto exp = NDArrayFactory::create('c', {3}, {-5, 10, -5}); + auto a = NDArrayFactory::create('c', {3}, {1, 2, 3}); + auto b = NDArrayFactory::create('c', {3}, {6, 7, 8}); + auto exp = NDArrayFactory::create('c', {3}, {-5, 10, -5}); - sd::ops::cross op; - auto result = op.evaluate({&a, &b}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::cross op; + auto result = op.evaluate({&a, &b}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(DeclarableOpsTests4, Test_Cross_2) { - auto a = NDArrayFactory::create('c', {2, 3}, {1, 2, 3, 1, 2, 3}); - auto b = NDArrayFactory::create('c', {2, 3}, {6, 7, 8, 6, 7, 8}); - auto exp = NDArrayFactory::create('c', {2, 3}, {-5, 10, -5, -5, 10, -5}); - - sd::ops::cross op; - auto result = op.evaluate({&a, &b}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); + auto a = NDArrayFactory::create('c', {2, 3}, {1, 2, 3, 1, 2, 3}); + auto b = NDArrayFactory::create('c', {2, 3}, {6, 7, 8, 6, 7, 8}); + auto exp = + NDArrayFactory::create('c', {2, 3}, {-5, 10, -5, -5, 10, -5}); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::cross op; + auto result = op.evaluate({&a, &b}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(DeclarableOpsTests4, Test_Cross_3) { - auto a = NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); - auto b = NDArrayFactory::create('c', {3, 3}, {2, 3, 4, 7, 6, 5, 6, 3, 2}); - auto exp = NDArrayFactory::create('c', {3, 3}, { -1, 2, -1, -11, 22, -11, -11, 40, -27}); - - sd::ops::cross op; - auto result = op.evaluate({&a, &b}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); + auto a = + NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto b = + NDArrayFactory::create('c', {3, 3}, {2, 3, 4, 7, 6, 5, 6, 3, 2}); + auto exp = NDArrayFactory::create( + 'c', {3, 3}, {-1, 2, -1, -11, 22, -11, -11, 40, -27}); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::cross op; + auto result = op.evaluate({&a, &b}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests4, Test_Add_119) { - auto a = NDArrayFactory::create('c', {1, 4}, {1, 2, 3, 4}); - auto b = NDArrayFactory::create('c', {4}, {1, 2, 3, 4}); - auto exp = NDArrayFactory::create('c', {1, 4}, {2, 4, 6, 8}); + auto a = NDArrayFactory::create('c', {1, 4}, {1, 2, 3, 4}); + auto b = NDArrayFactory::create('c', {4}, {1, 2, 3, 4}); + auto exp = NDArrayFactory::create('c', {1, 4}, {2, 4, 6, 8}); - sd::ops::add op; - auto result = op.evaluate({&a, &b}); + sd::ops::add op; + auto result = op.evaluate({&a, &b}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); - - ASSERT_EQ(2, z->rankOf()); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + ASSERT_EQ(2, z.rankOf()); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests4, Test_TileToShape_1) { - auto x = NDArrayFactory::create('c', {2, 1, 3}); - auto exp = NDArrayFactory::create('c', {2, 4, 3}, {1.f, 2.f, 3.f,1.f, 2.f, 3.f,1.f, 2.f, 3.f,1.f, 2.f, 3.f, - 4.f, 5.f, 6.f,4.f, 5.f, 6.f,4.f, 5.f, 6.f,4.f, 5.f, 6.f}); - x.linspace(1.f); - - sd::ops::tile_to_shape op; - auto result = op.evaluate({&x},{}, {2, 4, 3}); + auto x = NDArrayFactory::create('c', {2, 1, 3}); + auto exp = NDArrayFactory::create( + 'c', {2, 4, 3}, + {1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, + 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f}); + x.linspace(1.f); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::tile_to_shape op; + auto result = op.evaluate({&x}, {}, {2, 4, 3}); - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests4, Test_StridedSlice_Alex_1) { - auto x = NDArrayFactory::create('c', {3, 4, 5}); - x.linspace(1); - auto exp = NDArrayFactory::create('c', {1,3,4,5}); - exp.linspace(1); - - sd::ops::strided_slice op; - auto result = op.evaluate({&x}, {}, {0,0,0,1,0, -999,0,0,0, -999,3,4,5, -999,1,1,1}); - - ASSERT_EQ(Status::OK(), result.status()); + auto x = NDArrayFactory::create('c', {3, 4, 5}); + x.linspace(1); + auto exp = NDArrayFactory::create('c', {1, 3, 4, 5}); + exp.linspace(1); - auto z = result.at(0); + sd::ops::strided_slice op; + auto result = op.evaluate( + {&x}, {}, {0, 0, 0, 1, 0, -999, 0, 0, 0, -999, 3, 4, 5, -999, 1, 1, 1}); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests4, Test_StridedSlice_Alex_2) { - auto x = NDArrayFactory::create('c', {3, 4, 5}); - auto begin = NDArrayFactory::create('c', {4}, {-999,0,0,0}); - auto end = NDArrayFactory::create('c', {4}, {-999,3,4,5}); - auto stride = NDArrayFactory::create('c', {4}, {-999,1,1,1}); - x.linspace(1); - auto exp = NDArrayFactory::create('c', {1,3,4,5}); - exp.linspace(1); - - sd::ops::strided_slice op; - auto result = op.evaluate({&x, &begin, &end, &stride}, {}, {0,0,0,1,0}); - - ASSERT_EQ(Status::OK(), result.status()); + auto x = NDArrayFactory::create('c', {3, 4, 5}); + auto begin = NDArrayFactory::create('c', {4}, {-999, 0, 0, 0}); + auto end = NDArrayFactory::create('c', {4}, {-999, 3, 4, 5}); + auto stride = NDArrayFactory::create('c', {4}, {-999, 1, 1, 1}); + x.linspace(1); + auto exp = NDArrayFactory::create('c', {1, 3, 4, 5}); + exp.linspace(1); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::strided_slice op; + auto result = op.evaluate({&x, &begin, &end, &stride}, {}, {0, 0, 0, 1, 0}); - auto z = result.at(0); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests4, Test_StridedSlice_Alex_3) { - int axis = 0; - auto x = NDArrayFactory::create('c', {1}, {10}); - auto begin = NDArrayFactory::create('c', {1}, {axis}); - auto end = NDArrayFactory::create('c', {1}, {axis}); - auto stride = NDArrayFactory::create('c', {1}, {1}); - //x.linspace(1); - //auto exp = NDArrayFactory::create('c', {1,3,4,5}); - //exp.linspace(1); + int axis = 0; + auto x = NDArrayFactory::create('c', {1}, {10}); + auto begin = NDArrayFactory::create('c', {1}, {axis}); + auto end = NDArrayFactory::create('c', {1}, {axis}); + auto stride = NDArrayFactory::create('c', {1}, {1}); + // x.linspace(1); + // auto exp = NDArrayFactory::create('c', {1,3,4,5}); + // exp.linspace(1); - sd::ops::strided_slice op; - auto result = op.evaluate({&x, &begin, &end, &stride}, {}, {1,0,0,0,0}); - - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - ASSERT_TRUE(z->isEmpty()); + sd::ops::strided_slice op; + auto result = op.evaluate({&x, &begin, &end, &stride}, {}, {1, 0, 0, 0, 0}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + ASSERT_TRUE(z.isEmpty()); } TEST_F(DeclarableOpsTests4, Test_StridedSlice_Alex_4) { - auto x = NDArrayFactory::create('c', {1,3}, {1, 2, 3}); - auto begin = NDArrayFactory::create('c', {2}, {0, 0}); - auto end = NDArrayFactory::create('c', {2}, {0,1}); - auto stride = NDArrayFactory::create('c', {2}, {1,1}); -// x.linspace(1); - auto exp = NDArrayFactory::create('c', {1}, {1}); - //exp.linspace(1); - - sd::ops::strided_slice op; - auto result = op.evaluate({&x, &begin, &end, &stride}, {}, {1,0,1,0,2}); + auto x = NDArrayFactory::create('c', {1, 3}, {1, 2, 3}); + auto begin = NDArrayFactory::create('c', {2}, {0, 0}); + auto end = NDArrayFactory::create('c', {2}, {0, 1}); + auto stride = NDArrayFactory::create('c', {2}, {1, 1}); + // x.linspace(1); + auto exp = NDArrayFactory::create('c', {1}, {1}); + // exp.linspace(1); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::strided_slice op; + auto result = op.evaluate({&x, &begin, &end, &stride}, {}, {1, 0, 1, 0, 2}); - auto z = result.at(0); - ASSERT_TRUE(z->lengthOf() == 1); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + ASSERT_TRUE(z.lengthOf() == 1); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, parallel_stack_test1) { + auto x1 = NDArrayFactory::create('c', {2, 2, 2}); + auto x2 = NDArrayFactory::create('c', {2, 2, 2}); + auto x3 = NDArrayFactory::create('c', {2, 2, 2}); + x1.linspace(1); + x2.linspace(9); + x3.linspace(17); - auto x1 = NDArrayFactory::create('c', {2,2,2}); - auto x2 = NDArrayFactory::create('c', {2,2,2}); - auto x3 = NDArrayFactory::create('c', {2,2,2}); - x1.linspace(1); - x2.linspace(9); - x3.linspace(17); - - auto expected = NDArrayFactory::create('c', {3,2,2,2}); - expected.linspace(1); - - sd::ops::parallel_stack op; - auto results = op.evaluate({&x1, &x2, &x3}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + auto expected = NDArrayFactory::create('c', {3, 2, 2, 2}); + expected.linspace(1); + sd::ops::parallel_stack op; + auto results = op.evaluate({&x1, &x2, &x3}); + auto output = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, parallel_stack_test2) { + auto x1 = NDArrayFactory::create('c', {1, 2}, {1, 2}); + auto x2 = NDArrayFactory::create('c', {1, 2}, {3, 4}); + auto x3 = NDArrayFactory::create('c', {1, 2}, {5, 6}); - auto x1 = NDArrayFactory::create('c', {1,2}, {1,2}); - auto x2 = NDArrayFactory::create('c', {1,2}, {3,4}); - auto x3 = NDArrayFactory::create('c', {1,2}, {5,6}); - - auto expected = NDArrayFactory::create('c', {3,1,2}, {1,2,3,4,5,6}); - - sd::ops::parallel_stack op; - auto results = op.evaluate({&x1, &x2, &x3}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - + auto expected = + NDArrayFactory::create('c', {3, 1, 2}, {1, 2, 3, 4, 5, 6}); + sd::ops::parallel_stack op; + auto results = op.evaluate({&x1, &x2, &x3}); + auto output = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, parallel_stack_test3) { + auto x1 = NDArrayFactory::create('c', {2, 1}, {1, 2}); + auto x2 = NDArrayFactory::create('c', {2, 1}, {3, 4}); + auto x3 = NDArrayFactory::create('c', {2, 1}, {5, 6}); - auto x1 = NDArrayFactory::create('c', {2,1}, {1,2}); - auto x2 = NDArrayFactory::create('c', {2,1}, {3,4}); - auto x3 = NDArrayFactory::create('c', {2,1}, {5,6}); - - auto expected = NDArrayFactory::create('c', {3,2,1}, {1,2,3,4,5,6}); - - sd::ops::parallel_stack op; - auto results = op.evaluate({&x1, &x2, &x3}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + auto expected = + NDArrayFactory::create('c', {3, 2, 1}, {1, 2, 3, 4, 5, 6}); + sd::ops::parallel_stack op; + auto results = op.evaluate({&x1, &x2, &x3}); + auto output = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } -\ + ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, parallel_stack_test4) { + auto x1 = NDArrayFactory::create('c', {2}, {1, 2}); + auto x2 = NDArrayFactory::create('c', {2}, {3, 4}); + auto x3 = NDArrayFactory::create('c', {2}, {5, 6}); - auto x1 = NDArrayFactory::create('c', {2}, {1,2}); - auto x2 = NDArrayFactory::create('c', {2}, {3,4}); - auto x3 = NDArrayFactory::create('c', {2}, {5,6}); - - auto expected = NDArrayFactory::create('c', {3,2}, {1,2,3,4,5,6}); - - sd::ops::parallel_stack op; - auto results = op.evaluate({&x1, &x2, &x3}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + auto expected = + NDArrayFactory::create('c', {3, 2}, {1, 2, 3, 4, 5, 6}); + sd::ops::parallel_stack op; + auto results = op.evaluate({&x1, &x2, &x3}); + auto output = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, parallel_stack_test5) { + auto x1 = NDArrayFactory::create('c', {1}, {1}); + auto x2 = NDArrayFactory::create('c', {1}, {3}); + auto x3 = NDArrayFactory::create('c', {1}, {5}); - auto x1 = NDArrayFactory::create('c', {1}, {1}); - auto x2 = NDArrayFactory::create('c', {1}, {3}); - auto x3 = NDArrayFactory::create('c', {1}, {5}); - - auto expected = NDArrayFactory::create('c', {3,1}, {1,3,5}); - - sd::ops::parallel_stack op; - auto results = op.evaluate({&x1, &x2, &x3}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + auto expected = NDArrayFactory::create('c', {3, 1}, {1, 3, 5}); + sd::ops::parallel_stack op; + auto results = op.evaluate({&x1, &x2, &x3}); + auto output = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, parallel_stack_test6) { + auto x1 = NDArrayFactory::create(1.); + auto x2 = NDArrayFactory::create(3.); + auto x3 = NDArrayFactory::create(5.); - auto x1 = NDArrayFactory::create(1.); - auto x2 = NDArrayFactory::create(3.); - auto x3 = NDArrayFactory::create(5.); - - auto expected = NDArrayFactory::create('c', {3}, {1,3,5}); - - sd::ops::parallel_stack op; - auto results = op.evaluate({&x1, &x2, &x3}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + auto expected = NDArrayFactory::create('c', {3}, {1, 3, 5}); + sd::ops::parallel_stack op; + auto results = op.evaluate({&x1, &x2, &x3}); + auto output = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, parallel_stack_test7) { + auto x1 = NDArrayFactory::create(1.); + auto expected = NDArrayFactory::create('c', {1}, {1.}); - auto x1 = NDArrayFactory::create(1.); - auto expected = NDArrayFactory::create('c', {1}, {1.}); - - sd::ops::parallel_stack op; - auto results = op.evaluate({&x1}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - + sd::ops::parallel_stack op; + auto results = op.evaluate({&x1}); + auto output = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, meshgrid_test1) { - - auto in0 = NDArrayFactory::create('c', {2}, {1, 2}); - auto in1 = NDArrayFactory::create('c', {3}, {10, 20, 30}); - auto in2 = NDArrayFactory::create('c', {4}, {100, 200, 300, 400}); - auto exp0 = NDArrayFactory::create('c', {2,3,4}, {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2}); - auto exp1 = NDArrayFactory::create('c', {2,3,4}, {10, 10, 10, 10, 20, 20, 20, 20, 30, 30, 30, 30, 10, 10, 10, 10, 20, 20, 20, 20, 30, 30, 30, 30}); - auto exp2 = NDArrayFactory::create('c', {2,3,4}, {100, 200, 300, 400, 100, 200, 300, 400, 100, 200, 300, 400, 100, 200, 300, 400, 100, 200, 300, 400, 100, 200, 300, 400}); - - sd::ops::meshgrid op; - auto results = op.evaluate({&in0, &in1, &in2}, {}, {0}); - auto out0 = results.at(0); - auto out1 = results.at(1); - auto out2 = results.at(2); - - // out0->printIndexedBuffer(); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp0.isSameShape(out0)); - ASSERT_TRUE(exp0.equalsTo(out0)); - ASSERT_TRUE(exp1.isSameShape(out1)); - ASSERT_TRUE(exp1.equalsTo(out1)); - ASSERT_TRUE(exp2.isSameShape(out2)); - ASSERT_TRUE(exp2.equalsTo(out2)); - - + auto in0 = NDArrayFactory::create('c', {2}, {1, 2}); + auto in1 = NDArrayFactory::create('c', {3}, {10, 20, 30}); + auto in2 = NDArrayFactory::create('c', {4}, {100, 200, 300, 400}); + auto exp0 = NDArrayFactory::create( + 'c', {2, 3, 4}, + {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2}); + auto exp1 = NDArrayFactory::create( + 'c', {2, 3, 4}, {10, 10, 10, 10, 20, 20, 20, 20, 30, 30, 30, 30, + 10, 10, 10, 10, 20, 20, 20, 20, 30, 30, 30, 30}); + auto exp2 = NDArrayFactory::create( + 'c', {2, 3, 4}, + {100, 200, 300, 400, 100, 200, 300, 400, 100, 200, 300, 400, + 100, 200, 300, 400, 100, 200, 300, 400, 100, 200, 300, 400}); + + sd::ops::meshgrid op; + auto results = op.evaluate({&in0, &in1, &in2}, {}, {0}); + auto out0 = results.at(0); + auto out1 = results.at(1); + auto out2 = results.at(2); + + // out0->printIndexedBuffer(); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp0.isSameShape(out0)); + ASSERT_TRUE(exp0.equalsTo(out0)); + ASSERT_TRUE(exp1.isSameShape(out1)); + ASSERT_TRUE(exp1.equalsTo(out1)); + ASSERT_TRUE(exp2.isSameShape(out2)); + ASSERT_TRUE(exp2.equalsTo(out2)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, meshgrid_test2) { - - auto in0 = NDArrayFactory::create('c', {2}, {1, 2}); - auto in1 = NDArrayFactory::create('c', {3}, {10, 20, 30}); - auto in2 = NDArrayFactory::create('c', {4}, {100, 200, 300, 400}); - auto exp0 = NDArrayFactory::create('c', {3,2,4}, {1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2, 2}); - auto exp1 = NDArrayFactory::create('c', {3,2,4}, {10, 10, 10, 10, 10, 10, 10, 10, 20, 20, 20, 20, 20, 20, 20, 20, 30, 30, 30, 30, 30, 30, 30, 30}); - auto exp2 = NDArrayFactory::create('c', {3,2,4}, {100, 200, 300, 400, 100, 200, 300, 400, 100, 200, 300, 400, 100, 200, 300, 400, 100, 200, 300, 400, 100, 200, 300, 400}); - - sd::ops::meshgrid op; - auto results = op.evaluate({&in0, &in1, &in2}); - auto out0 = results.at(0); - auto out1 = results.at(1); - auto out2 = results.at(2); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp0.isSameShape(out0)); - ASSERT_TRUE(exp0.equalsTo(out0)); - ASSERT_TRUE(exp1.isSameShape(out1)); - ASSERT_TRUE(exp1.equalsTo(out1)); - ASSERT_TRUE(exp2.isSameShape(out2)); - ASSERT_TRUE(exp2.equalsTo(out2)); - - + auto in0 = NDArrayFactory::create('c', {2}, {1, 2}); + auto in1 = NDArrayFactory::create('c', {3}, {10, 20, 30}); + auto in2 = NDArrayFactory::create('c', {4}, {100, 200, 300, 400}); + auto exp0 = NDArrayFactory::create( + 'c', {3, 2, 4}, + {1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2, 2}); + auto exp1 = NDArrayFactory::create( + 'c', {3, 2, 4}, {10, 10, 10, 10, 10, 10, 10, 10, 20, 20, 20, 20, + 20, 20, 20, 20, 30, 30, 30, 30, 30, 30, 30, 30}); + auto exp2 = NDArrayFactory::create( + 'c', {3, 2, 4}, + {100, 200, 300, 400, 100, 200, 300, 400, 100, 200, 300, 400, + 100, 200, 300, 400, 100, 200, 300, 400, 100, 200, 300, 400}); + + sd::ops::meshgrid op; + auto results = op.evaluate({&in0, &in1, &in2}); + auto out0 = results.at(0); + auto out1 = results.at(1); + auto out2 = results.at(2); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp0.isSameShape(out0)); + ASSERT_TRUE(exp0.equalsTo(out0)); + ASSERT_TRUE(exp1.isSameShape(out1)); + ASSERT_TRUE(exp1.equalsTo(out1)); + ASSERT_TRUE(exp2.isSameShape(out2)); + ASSERT_TRUE(exp2.equalsTo(out2)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, meshgrid_test3) { - - auto in0 = NDArrayFactory::create('c', {2}, {1, 2}); - auto in1 = NDArrayFactory::create('c', {1,3}, {10, 20, 30}); - auto in2 = NDArrayFactory::create('c', {2,2}, {100, 200, 300, 400}); - auto exp0 = NDArrayFactory::create('c', {3,2,4}, {1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2, 2}); - auto exp1 = NDArrayFactory::create('c', {3,2,4}, {10, 10, 10, 10, 10, 10, 10, 10, 20, 20, 20, 20, 20, 20, 20, 20, 30, 30, 30, 30, 30, 30, 30, 30}); - auto exp2 = NDArrayFactory::create('c', {3,2,4}, {100, 200, 300, 400, 100, 200, 300, 400, 100, 200, 300, 400, 100, 200, 300, 400, 100, 200, 300, 400, 100, 200, 300, 400}); - - sd::ops::meshgrid op; - auto results = op.evaluate({&in0, &in1, &in2}); - auto out0 = results.at(0); - auto out1 = results.at(1); - auto out2 = results.at(2); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp0.isSameShape(out0)); - ASSERT_TRUE(exp0.equalsTo(out0)); - ASSERT_TRUE(exp1.isSameShape(out1)); - ASSERT_TRUE(exp1.equalsTo(out1)); - ASSERT_TRUE(exp2.isSameShape(out2)); - ASSERT_TRUE(exp2.equalsTo(out2)); - - + auto in0 = NDArrayFactory::create('c', {2}, {1, 2}); + auto in1 = NDArrayFactory::create('c', {1, 3}, {10, 20, 30}); + auto in2 = NDArrayFactory::create('c', {2, 2}, {100, 200, 300, 400}); + auto exp0 = NDArrayFactory::create( + 'c', {3, 2, 4}, + {1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2, 2}); + auto exp1 = NDArrayFactory::create( + 'c', {3, 2, 4}, {10, 10, 10, 10, 10, 10, 10, 10, 20, 20, 20, 20, + 20, 20, 20, 20, 30, 30, 30, 30, 30, 30, 30, 30}); + auto exp2 = NDArrayFactory::create( + 'c', {3, 2, 4}, + {100, 200, 300, 400, 100, 200, 300, 400, 100, 200, 300, 400, + 100, 200, 300, 400, 100, 200, 300, 400, 100, 200, 300, 400}); + + sd::ops::meshgrid op; + auto results = op.evaluate({&in0, &in1, &in2}); + auto out0 = results.at(0); + auto out1 = results.at(1); + auto out2 = results.at(2); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp0.isSameShape(out0)); + ASSERT_TRUE(exp0.equalsTo(out0)); + ASSERT_TRUE(exp1.isSameShape(out1)); + ASSERT_TRUE(exp1.equalsTo(out1)); + ASSERT_TRUE(exp2.isSameShape(out2)); + ASSERT_TRUE(exp2.equalsTo(out2)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, meshgrid_test4) { - - auto in0 = NDArrayFactory::create('c', {1,2}, {1, 2}); - auto in1 = NDArrayFactory::create('c', {3,1}, {10, 20, 30}); - auto in2 = NDArrayFactory::create('c', {1,4,1}, {100, 200, 300, 400}); - auto exp0 = NDArrayFactory::create('c', {2,3,4}, {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2}); - auto exp1 = NDArrayFactory::create('c', {2,3,4}, {10, 10, 10, 10, 20, 20, 20, 20, 30, 30, 30, 30, 10, 10, 10, 10, 20, 20, 20, 20, 30, 30, 30, 30}); - auto exp2 = NDArrayFactory::create('c', {2,3,4}, {100, 200, 300, 400, 100, 200, 300, 400, 100, 200, 300, 400, 100, 200, 300, 400, 100, 200, 300, 400, 100, 200, 300, 400}); - - sd::ops::meshgrid op; - auto results = op.evaluate({&in0, &in1, &in2}, {}, {0}); - auto out0 = results.at(0); - auto out1 = results.at(1); - auto out2 = results.at(2); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp0.isSameShape(out0)); - ASSERT_TRUE(exp0.equalsTo(out0)); - ASSERT_TRUE(exp1.isSameShape(out1)); - ASSERT_TRUE(exp1.equalsTo(out1)); - ASSERT_TRUE(exp2.isSameShape(out2)); - ASSERT_TRUE(exp2.equalsTo(out2)); - - + auto in0 = NDArrayFactory::create('c', {1, 2}, {1, 2}); + auto in1 = NDArrayFactory::create('c', {3, 1}, {10, 20, 30}); + auto in2 = + NDArrayFactory::create('c', {1, 4, 1}, {100, 200, 300, 400}); + auto exp0 = NDArrayFactory::create( + 'c', {2, 3, 4}, + {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2}); + auto exp1 = NDArrayFactory::create( + 'c', {2, 3, 4}, {10, 10, 10, 10, 20, 20, 20, 20, 30, 30, 30, 30, + 10, 10, 10, 10, 20, 20, 20, 20, 30, 30, 30, 30}); + auto exp2 = NDArrayFactory::create( + 'c', {2, 3, 4}, + {100, 200, 300, 400, 100, 200, 300, 400, 100, 200, 300, 400, + 100, 200, 300, 400, 100, 200, 300, 400, 100, 200, 300, 400}); + + sd::ops::meshgrid op; + auto results = op.evaluate({&in0, &in1, &in2}, {}, {0}); + auto out0 = results.at(0); + auto out1 = results.at(1); + auto out2 = results.at(2); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp0.isSameShape(out0)); + ASSERT_TRUE(exp0.equalsTo(out0)); + ASSERT_TRUE(exp1.isSameShape(out1)); + ASSERT_TRUE(exp1.equalsTo(out1)); + ASSERT_TRUE(exp2.isSameShape(out2)); + ASSERT_TRUE(exp2.equalsTo(out2)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, meshgrid_test5) { - - auto in0 = NDArrayFactory::create(1); - auto in1 = NDArrayFactory::create(2); - auto in2 = NDArrayFactory::create(3); - auto exp0 = NDArrayFactory::create('c', {1,1,1}, {1}); - auto exp1 = NDArrayFactory::create('c', {1,1,1}, {2}); - auto exp2 = NDArrayFactory::create('c', {1,1,1}, {3}); - - sd::ops::meshgrid op; - auto results = op.evaluate({&in0, &in1, &in2}, {}, {0}); - auto out0 = results.at(0); - auto out1 = results.at(1); - auto out2 = results.at(2); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp0.isSameShape(out0)); - ASSERT_TRUE(exp0.equalsTo(out0)); - ASSERT_TRUE(exp1.isSameShape(out1)); - ASSERT_TRUE(exp1.equalsTo(out1)); - ASSERT_TRUE(exp2.isSameShape(out2)); - ASSERT_TRUE(exp2.equalsTo(out2)); - - + auto in0 = NDArrayFactory::create(1); + auto in1 = NDArrayFactory::create(2); + auto in2 = NDArrayFactory::create(3); + auto exp0 = NDArrayFactory::create('c', {1, 1, 1}, {1}); + auto exp1 = NDArrayFactory::create('c', {1, 1, 1}, {2}); + auto exp2 = NDArrayFactory::create('c', {1, 1, 1}, {3}); + + sd::ops::meshgrid op; + auto results = op.evaluate({&in0, &in1, &in2}, {}, {0}); + auto out0 = results.at(0); + auto out1 = results.at(1); + auto out2 = results.at(2); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp0.isSameShape(out0)); + ASSERT_TRUE(exp0.equalsTo(out0)); + ASSERT_TRUE(exp1.isSameShape(out1)); + ASSERT_TRUE(exp1.equalsTo(out1)); + ASSERT_TRUE(exp2.isSameShape(out2)); + ASSERT_TRUE(exp2.equalsTo(out2)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, meshgrid_test6) { - - auto in0 = NDArrayFactory::create('c', {2,2},{1,2,3,4}); - auto in1 = NDArrayFactory::create(5); - auto in2 = NDArrayFactory::create(6); - auto exp0 = NDArrayFactory::create('c', {4,1,1}, {1,2,3,4}); - auto exp1 = NDArrayFactory::create('c', {4,1,1}, {5,5,5,5}); - auto exp2 = NDArrayFactory::create('c', {4,1,1}, {6,6,6,6}); - - sd::ops::meshgrid op; - auto results = op.evaluate({&in0, &in1, &in2}, {}, {0}); - auto out0 = results.at(0); - auto out1 = results.at(1); - auto out2 = results.at(2); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp0.isSameShape(out0)); - ASSERT_TRUE(exp0.equalsTo(out0)); - ASSERT_TRUE(exp1.isSameShape(out1)); - ASSERT_TRUE(exp1.equalsTo(out1)); - ASSERT_TRUE(exp2.isSameShape(out2)); - ASSERT_TRUE(exp2.equalsTo(out2)); - - + auto in0 = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); + auto in1 = NDArrayFactory::create(5); + auto in2 = NDArrayFactory::create(6); + auto exp0 = NDArrayFactory::create('c', {4, 1, 1}, {1, 2, 3, 4}); + auto exp1 = NDArrayFactory::create('c', {4, 1, 1}, {5, 5, 5, 5}); + auto exp2 = NDArrayFactory::create('c', {4, 1, 1}, {6, 6, 6, 6}); + + sd::ops::meshgrid op; + auto results = op.evaluate({&in0, &in1, &in2}, {}, {0}); + auto out0 = results.at(0); + auto out1 = results.at(1); + auto out2 = results.at(2); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp0.isSameShape(out0)); + ASSERT_TRUE(exp0.equalsTo(out0)); + ASSERT_TRUE(exp1.isSameShape(out1)); + ASSERT_TRUE(exp1.equalsTo(out1)); + ASSERT_TRUE(exp2.isSameShape(out2)); + ASSERT_TRUE(exp2.equalsTo(out2)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, meshgrid_test7) { - - auto in0 = NDArrayFactory::create('c', {2,2},{1,2,3,4}); - auto in1 = NDArrayFactory::create(5); - auto in2 = NDArrayFactory::create(6); - auto exp0 = NDArrayFactory::create('c', {1,4,1}, {1,2,3,4}); - auto exp1 = NDArrayFactory::create('c', {1,4,1}, {5,5,5,5}); - auto exp2 = NDArrayFactory::create('c', {1,4,1}, {6,6,6,6}); - - sd::ops::meshgrid op; - auto results = op.evaluate({&in0, &in1, &in2}, {}, {1}); - auto out0 = results.at(0); - auto out1 = results.at(1); - auto out2 = results.at(2); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp0.isSameShape(out0)); - ASSERT_TRUE(exp0.equalsTo(out0)); - ASSERT_TRUE(exp1.isSameShape(out1)); - ASSERT_TRUE(exp1.equalsTo(out1)); - ASSERT_TRUE(exp2.isSameShape(out2)); - ASSERT_TRUE(exp2.equalsTo(out2)); - - + auto in0 = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); + auto in1 = NDArrayFactory::create(5); + auto in2 = NDArrayFactory::create(6); + auto exp0 = NDArrayFactory::create('c', {1, 4, 1}, {1, 2, 3, 4}); + auto exp1 = NDArrayFactory::create('c', {1, 4, 1}, {5, 5, 5, 5}); + auto exp2 = NDArrayFactory::create('c', {1, 4, 1}, {6, 6, 6, 6}); + + sd::ops::meshgrid op; + auto results = op.evaluate({&in0, &in1, &in2}, {}, {1}); + auto out0 = results.at(0); + auto out1 = results.at(1); + auto out2 = results.at(2); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp0.isSameShape(out0)); + ASSERT_TRUE(exp0.equalsTo(out0)); + ASSERT_TRUE(exp1.isSameShape(out1)); + ASSERT_TRUE(exp1.equalsTo(out1)); + ASSERT_TRUE(exp2.isSameShape(out2)); + ASSERT_TRUE(exp2.equalsTo(out2)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, meshgrid_test8) { + auto in0 = NDArrayFactory::create(5); + auto exp0 = NDArrayFactory::create('c', {1}, {5}); - auto in0 = NDArrayFactory::create(5); - auto exp0 = NDArrayFactory::create('c', {1}, {5}); - - sd::ops::meshgrid op; - auto results = op.evaluate({&in0}, {}, {0}); - auto out0 = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp0.isSameShape(out0)); - ASSERT_TRUE(exp0.equalsTo(out0)); - + sd::ops::meshgrid op; + auto results = op.evaluate({&in0}, {}, {0}); + auto out0 = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp0.isSameShape(out0)); + ASSERT_TRUE(exp0.equalsTo(out0)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, meshgrid_test9) { + auto in0 = NDArrayFactory::create(5); + auto exp0 = NDArrayFactory::create('c', {1}, {5}); - auto in0 = NDArrayFactory::create(5); - auto exp0 = NDArrayFactory::create('c', {1}, {5}); - - sd::ops::meshgrid op; - auto results = op.evaluate({&in0}, {}, {1}); - auto out0 = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp0.isSameShape(out0)); - ASSERT_TRUE(exp0.equalsTo(out0)); - + sd::ops::meshgrid op; + auto results = op.evaluate({&in0}, {}, {1}); + auto out0 = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp0.isSameShape(out0)); + ASSERT_TRUE(exp0.equalsTo(out0)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, WeightedCrossEntropyWithLogits_1) { - - - auto input = NDArrayFactory::create('c', {2, 3}, {11.f, 13.f, 4.f, 15.f, 6.f, 3.f}); - auto targets = NDArrayFactory::create('c', {2, 3}, {15.5f, 15.7f, 5.f , 15.f, 5.f, 6.f}); - auto weight = NDArrayFactory::create(0.7f); - auto expected = NDArrayFactory::create('c', {2, 3}, {-159.50006, -191.1, -16.009075, -210., -24.001238, -15.03887}); - -//Targets {15.5f, 15.7f, 5.f , 15.f, 5.f, 6.f}; -//---------- -//Inputs {11.f, 13.f, 4.f, 15.f, 6.f, 3.f}; -//---------- -//Weights [0.7] -//Result {-159.50006, -191.1, -16.009075, -210., -24.001238, -15.03887} - - sd::ops::weighted_cross_entropy_with_logits op; - auto results = op.evaluate({&targets, &input, &weight}); - auto output = results.at(0); - - // output->printIndexedBuffer(); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - - + auto input = NDArrayFactory::create( + 'c', {2, 3}, {11.f, 13.f, 4.f, 15.f, 6.f, 3.f}); + auto targets = NDArrayFactory::create( + 'c', {2, 3}, {15.5f, 15.7f, 5.f, 15.f, 5.f, 6.f}); + auto weight = NDArrayFactory::create(0.7f); + auto expected = NDArrayFactory::create( + 'c', {2, 3}, + {-159.50006, -191.1, -16.009075, -210., -24.001238, -15.03887}); + + // Targets {15.5f, 15.7f, 5.f , 15.f, 5.f, 6.f}; + //---------- + // Inputs {11.f, 13.f, 4.f, 15.f, 6.f, 3.f}; + //---------- + // Weights [0.7] + // Result {-159.50006, -191.1, -16.009075, -210., -24.001238, + // -15.03887} + + sd::ops::weighted_cross_entropy_with_logits op; + auto results = op.evaluate({&targets, &input, &weight}); + auto output = results.at(0); + + // output->printIndexedBuffer(); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, WeightedCrossEntropyWithLogits_2) { + auto input = NDArrayFactory::create( + 'c', {2, 3}, {11.f, 13.f, 4.f, 15.f, 6.f, 3.f}); + auto targets = NDArrayFactory::create( + 'c', {2, 3}, {15.5f, 15.7f, 5.f, 15.f, 5.f, 6.f}); + auto weights = NDArrayFactory::create({0.5f, 0.7f, 1.0f}); + auto expected = NDArrayFactory::create( + 'c', {2, 3}, + {-159.5001f, -191.1f, -15.98185f, -210.f, -24.001238f, -14.951412f}); + sd::ops::weighted_cross_entropy_with_logits op; + auto results = op.evaluate({&targets, &input, &weights}); + auto output = results.at(0); - auto input = NDArrayFactory::create('c', {2, 3}, {11.f, 13.f, 4.f, 15.f, 6.f, 3.f}); - auto targets = NDArrayFactory::create('c', {2, 3}, {15.5f, 15.7f, 5.f, 15.f, 5.f, 6.f}); - auto weights = NDArrayFactory::create({0.5f, 0.7f, 1.0f}) ; - auto expected = NDArrayFactory::create('c', {2, 3}, {-159.5001f, -191.1f, -15.98185f, -210.f, -24.001238f, -14.951412f}); - - sd::ops::weighted_cross_entropy_with_logits op; - auto results = op.evaluate({&targets, &input, &weights}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - - + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } - /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, lstm_test1) { - - const int time = 5; - const int batchSize = 3; - const int inSize = 3; - const int numProj = 3; - const int numUnits = 3; - - auto x = NDArrayFactory::create('c', {time, batchSize, inSize}); - auto h0 = NDArrayFactory::create('c', {batchSize, numProj}); - auto c0 = NDArrayFactory::create('c', {batchSize, numUnits}); - auto Wx = NDArrayFactory::create('c', {inSize, 4*numUnits}); - auto Wh = NDArrayFactory::create('c', {numProj, 4*numUnits}); - auto Wc = NDArrayFactory::create('c', {3*numUnits}); - auto Wp = NDArrayFactory::create('c', {numUnits, numProj}); - auto b = NDArrayFactory::create('c', {4*numUnits}); - - x.linspace(0.5, 0.5); - h0 = 1.; - c0 = 2.; - Wx = 0.003; - Wh = 0.006; - Wc = 0.; - Wp = 0.; - b = 0.5; - - auto expH = NDArrayFactory::create('c', {time, batchSize, numProj}, {0.57574,0.57574,0.57574,0.58006,0.58006,0.58006,0.58434,0.58434,0.58434, - 0.55114,0.55114,0.55114,0.55732,0.55732,0.55732,0.56338,0.56338,0.56338, - 0.53763,0.53763,0.53763,0.54534,0.54534,0.54534,0.55287,0.55287,0.55287, - 0.53626,0.53626,0.53626,0.54487,0.54487,0.54487,0.55327,0.55327,0.55327, - 0.54484,0.54484,0.54484,0.55379,0.55379,0.55379,0.5625 ,0.5625 ,0.5625}); - - auto expClast = NDArrayFactory::create('c', {1, batchSize, numProj}, {1.1589154,1.1589154,1.1589154,1.1892855,1.1892855,1.1892855,1.219861 ,1.219861 ,1.219861}); - - sd::ops::lstm op; - auto results = op.evaluate({&x, &h0, &c0, &Wx, &Wh, &Wc, &Wp, &b}, {0., 0., 0.}, {0, 0}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto *h = results.at(0); - auto *c = results.at(1); - auto cLast = (*c)({4,5,0,0,0,0},true); - - ASSERT_TRUE(expH.isSameShape(h)); - ASSERT_TRUE(expH.equalsTo(h)); - - ASSERT_TRUE(expClast.isSameShape(&cLast)); - ASSERT_TRUE(expClast.equalsTo(&cLast)); - - + const int time = 5; + const int batchSize = 3; + const int inSize = 3; + const int numProj = 3; + const int numUnits = 3; + + auto x = NDArrayFactory::create('c', {time, batchSize, inSize}); + auto h0 = NDArrayFactory::create('c', {batchSize, numProj}); + auto c0 = NDArrayFactory::create('c', {batchSize, numUnits}); + auto Wx = NDArrayFactory::create('c', {inSize, 4 * numUnits}); + auto Wh = NDArrayFactory::create('c', {numProj, 4 * numUnits}); + auto Wc = NDArrayFactory::create('c', {3 * numUnits}); + auto Wp = NDArrayFactory::create('c', {numUnits, numProj}); + auto b = NDArrayFactory::create('c', {4 * numUnits}); + + x.linspace(0.5, 0.5); + h0 = 1.; + c0 = 2.; + Wx = 0.003; + Wh = 0.006; + Wc = 0.; + Wp = 0.; + b = 0.5; + + auto expH = NDArrayFactory::create( + 'c', {time, batchSize, numProj}, + {0.57574, 0.57574, 0.57574, 0.58006, 0.58006, 0.58006, 0.58434, 0.58434, + 0.58434, 0.55114, 0.55114, 0.55114, 0.55732, 0.55732, 0.55732, 0.56338, + 0.56338, 0.56338, 0.53763, 0.53763, 0.53763, 0.54534, 0.54534, 0.54534, + 0.55287, 0.55287, 0.55287, 0.53626, 0.53626, 0.53626, 0.54487, 0.54487, + 0.54487, 0.55327, 0.55327, 0.55327, 0.54484, 0.54484, 0.54484, 0.55379, + 0.55379, 0.55379, 0.5625, 0.5625, 0.5625}); + + auto expClast = NDArrayFactory::create( + 'c', {1, batchSize, numProj}, + {1.1589154, 1.1589154, 1.1589154, 1.1892855, 1.1892855, 1.1892855, + 1.219861, 1.219861, 1.219861}); + + sd::ops::lstm op; + auto results = + op.evaluate({&x, &h0, &c0, &Wx, &Wh, &Wc, &Wp, &b}, {0., 0., 0.}, {0, 0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto h = results.at(0); + auto c = results.at(1); + auto cLast = c({4, 5, 0, 0, 0, 0}, true); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + + ASSERT_TRUE(expClast.isSameShape(&cLast)); + ASSERT_TRUE(expClast.equalsTo(&cLast)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, relu6_test1) { + auto input = NDArrayFactory::create('c', {2, 4}, + {-13., 10, -5, 0, 2, 7, 6, 12}); + auto expected = NDArrayFactory::create( + 'c', {2, 4}, {0., 6., 0., 0., 2., 6., 6., 6.}); - auto input = NDArrayFactory::create('c', {2,4}, {-13.,10,-5,0,2,7,6,12}); - auto expected = NDArrayFactory::create('c', {2,4}, {0., 6., 0., 0.,2., 6., 6., 6.}); + sd::ops::relu6 op; + auto results = op.evaluate({&input}, {0.}, {}); - sd::ops::relu6 op; - auto results = op.evaluate({&input}, {0.}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto output = results.at(0); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto output = results.at(0); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } - /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, relu6_bp_test1) { + auto input = NDArrayFactory::create('c', {2, 4}, + {-13., 10, -5, 0, 2, 7, 6, 5}); + auto gradO = NDArrayFactory::create( + 'c', {2, 4}, {-1., -2., 0., 4., 5., 6., 7., 8.}); - auto input = NDArrayFactory::create('c', {2,4}, {-13.,10, -5, 0, 2, 7, 6, 5}); - auto gradO = NDArrayFactory::create('c', {2,4}, {-1., -2., 0., 4., 5., 6., 7., 8.}); - - auto expected = NDArrayFactory::create('c', {2,4}, {0., 0., 0., 0., 5., 0., 0., 8.}); + auto expected = NDArrayFactory::create( + 'c', {2, 4}, {0., 0., 0., 0., 5., 0., 0., 8.}); - sd::ops::relu6_bp op; - auto results = op.evaluate({&input, &gradO}, {0.}); + sd::ops::relu6_bp op; + auto results = op.evaluate({&input, &gradO}, {0.}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto output = results.at(0); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto output = results.at(0); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedDeclarableOpsTests4, LrnTest_1) { + auto x = NDArrayFactory::create( + 'c', {2, 2, 2, 2}, + {5.5f, 0.f, 0.3f, 5.5f, 8.6f, 0.f, 0.f, 0.4f, 1.5f, 1.f, 1.3f, 1.5f, 2.6f, + 2.f, 3.f, 1.4f}); - auto x = NDArrayFactory::create('c', {2, 2, 2, 2}, { 5.5f, 0.f, 0.3f, 5.5f, - 8.6f, 0.f, 0.f, 0.4f, - 1.5f, 1.f, 1.3f, 1.5f, - 2.6f, 2.f, 3.f, 1.4f} - ); - - auto exp = NDArrayFactory::create('c', {2, 2, 2, 2}, { - 0.98386997f, 0.f, 0.05358852f, 0.9824562f, - 0.99330735f, 0.f, 0.f, 0.37139067f, - 0.72760683f, 0.4850712f, 0.5848977f, 0.67488194f, - 0.7581754f, 0.58321184f, 0.86747235f, 0.4048204f} - ); - - sd::ops::lrn op; - auto results = op.evaluate({&x}, {1.0, 1.0, 0.5}, {5}); - auto out = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(out)); - // out->printIndexedBuffer("LRN out"); - // exp.printIndexedBuffer("LRN exp"); - ASSERT_TRUE(exp.equalsTo(out)); + auto exp = NDArrayFactory::create( + 'c', {2, 2, 2, 2}, + {0.98386997f, 0.f, 0.05358852f, 0.9824562f, 0.99330735f, 0.f, 0.f, + 0.37139067f, 0.72760683f, 0.4850712f, 0.5848977f, 0.67488194f, + 0.7581754f, 0.58321184f, 0.86747235f, 0.4048204f}); + sd::ops::lrn op; + auto results = op.evaluate({&x}, {1.0, 1.0, 0.5}, {5}); + auto out = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(out)); + // out->printIndexedBuffer("LRN out"); + // exp.printIndexedBuffer("LRN exp"); + ASSERT_TRUE(exp.equalsTo(out)); } - //////////////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedDeclarableOpsTests4, LrnTest_2) { + auto x = NDArrayFactory::create( + 'c', {2, 2, 2, 2}, + {5.5f, 0.f, 0.3f, 5.5f, 8.6f, 0.f, 0.f, 0.4f, 1.5f, 1.f, 1.3f, 1.5f, 2.6f, + 2.f, 3.f, 1.4f}); - auto x = NDArrayFactory::create('c', {2, 2, 2, 2}, { 5.5f, 0.f, 0.3f, 5.5f, - 8.6f, 0.f, 0.f, 0.4f, - 1.5f, 1.f, 1.3f, 1.5f, - 2.6f, 2.f, 3.f, 1.4f}); - - auto exp = NDArrayFactory::create('c', {2, 2, 2, 2}, { - 0.98386997f, 0.f, 0.05358852f, 0.9824562f, - 0.99330735f, 0.f, 0.f, 0.37139067f, - 0.72760683f, 0.4850712f, 0.5848977f, 0.67488194f, - 0.7581754f, 0.58321184f, 0.86747235f, 0.4048204f}); - - sd::ops::lrn op; - auto results = op.evaluate({&x}, {1.0, 1.0, 0.5}, {2}); - auto out = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(out)); - // out->printIndexedBuffer("LRN out"); - // exp.printIndexedBuffer("LRN exp"); - ASSERT_TRUE(exp.equalsTo(out)); + auto exp = NDArrayFactory::create( + 'c', {2, 2, 2, 2}, + {0.98386997f, 0.f, 0.05358852f, 0.9824562f, 0.99330735f, 0.f, 0.f, + 0.37139067f, 0.72760683f, 0.4850712f, 0.5848977f, 0.67488194f, + 0.7581754f, 0.58321184f, 0.86747235f, 0.4048204f}); + sd::ops::lrn op; + auto results = op.evaluate({&x}, {1.0, 1.0, 0.5}, {2}); + auto out = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(out)); + // out->printIndexedBuffer("LRN out"); + // exp.printIndexedBuffer("LRN exp"); + ASSERT_TRUE(exp.equalsTo(out)); } //////////////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedDeclarableOpsTests4, LrnTest_3) { - - auto x = NDArrayFactory::create('c', {2, 2, 2, 4}, { - - 5.5f, 0.f, 0.3f, 5.5f, - 1.5f, 0.f, 1.3f, 6.5f, - 8.6f, 0.f, 0.f, 0.4f, - 2.5f, 1.f, 0.3f, 4.5f, - 1.5f, 1.f, 1.3f, 1.5f, - 3.5f, 0.f, 1.3f, 2.5f, - 2.6f, 2.f, 3.f, 1.4f, - 4.5f, 1.f, 0.3f, 0.5f} - ); - - auto exp = NDArrayFactory::create('c', {2, 2, 2, 4}, { - 0.9824562f, 0.f, 0.03822664f, 0.9824562f, - 0.67488194f, 0.f, 0.18924236f, 0.96960944f, - 0.99330735f, 0.f, 0.f, 0.37139067f, - 0.86567914f, 0.18702209f, 0.05610663f, 0.9520745f, - 0.6154575f, 0.34942827f, 0.45425674f, 0.6154575f, - 0.905509f, 0.f, 0.2824086f, 0.8361251f, - 0.57063663f, 0.41959068f, 0.629386f, 0.3504383f, - 0.9520745f, 0.21039814f, 0.06311944f, 0.3268602f } - ); - - sd::ops::lrn op; - auto results = op.evaluate({&x}, {1.0, 1.0, 0.5}, {2}); - auto out = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(out)); - // out->printIndexedBuffer("LRN out"); - // exp.printIndexedBuffer("LRN exp"); - ASSERT_TRUE(exp.equalsTo(out)); - - + auto x = NDArrayFactory::create( + 'c', {2, 2, 2, 4}, + { + + 5.5f, 0.f, 0.3f, 5.5f, 1.5f, 0.f, 1.3f, 6.5f, 8.6f, 0.f, 0.f, + 0.4f, 2.5f, 1.f, 0.3f, 4.5f, 1.5f, 1.f, 1.3f, 1.5f, 3.5f, 0.f, + 1.3f, 2.5f, 2.6f, 2.f, 3.f, 1.4f, 4.5f, 1.f, 0.3f, 0.5f}); + + auto exp = NDArrayFactory::create( + 'c', {2, 2, 2, 4}, + {0.9824562f, 0.f, 0.03822664f, 0.9824562f, 0.67488194f, + 0.f, 0.18924236f, 0.96960944f, 0.99330735f, 0.f, + 0.f, 0.37139067f, 0.86567914f, 0.18702209f, 0.05610663f, + 0.9520745f, 0.6154575f, 0.34942827f, 0.45425674f, 0.6154575f, + 0.905509f, 0.f, 0.2824086f, 0.8361251f, 0.57063663f, + 0.41959068f, 0.629386f, 0.3504383f, 0.9520745f, 0.21039814f, + 0.06311944f, 0.3268602f}); + + sd::ops::lrn op; + auto results = op.evaluate({&x}, {1.0, 1.0, 0.5}, {2}); + auto out = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(out)); + // out->printIndexedBuffer("LRN out"); + // exp.printIndexedBuffer("LRN exp"); + ASSERT_TRUE(exp.equalsTo(out)); } //////////////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedDeclarableOpsTests4, LrnTest_4) { - - auto x = NDArrayFactory::create('c', {2, 2, 2, 4}, { - - 5.5f, 0.f, 0.3f, 5.5f, - 1.5f, 0.f, 1.3f, 6.5f, - 8.6f, 0.f, 0.f, 0.4f, - 2.5f, 1.f, 0.3f, 4.5f, - 1.5f, 1.f, 1.3f, 1.5f, - 3.5f, 0.f, 1.3f, 2.5f, - 2.6f, 2.f, 3.f, 1.4f, - 4.5f, 1.f, 0.3f, 0.5f} - ); - - auto exp = NDArrayFactory::create('c', {2, 2, 2, 4}, { - 0.70082176f, 0.f, 0.03822664f, 0.70082176f, - 0.21835658f, 0.f, 0.18924236f, 0.9462118f, - 0.9922489f, 0.f, 0.f, 0.04615111f, - 0.46755522f, 0.18702209f, 0.05610663f, 0.8415994f, - 0.5241424f, 0.34942827f, 0.45425674f, 0.5241424f, - 0.76033086f, 0.f, 0.2824086f, 0.54309344f, - 0.54546785f, 0.41959068f, 0.629386f, 0.29371348f, - 0.94679165f, 0.21039814f, 0.06311944f, 0.10519907f} - ); - - sd::ops::lrn op; - auto results = op.evaluate({&x}, {1.0, 1.0, 0.5}, {5}); - auto out = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(out)); - // out->printIndexedBuffer("LRN out"); - // exp.printIndexedBuffer("LRN exp"); - ASSERT_TRUE(exp.equalsTo(out)); - - + auto x = NDArrayFactory::create( + 'c', {2, 2, 2, 4}, + { + + 5.5f, 0.f, 0.3f, 5.5f, 1.5f, 0.f, 1.3f, 6.5f, 8.6f, 0.f, 0.f, + 0.4f, 2.5f, 1.f, 0.3f, 4.5f, 1.5f, 1.f, 1.3f, 1.5f, 3.5f, 0.f, + 1.3f, 2.5f, 2.6f, 2.f, 3.f, 1.4f, 4.5f, 1.f, 0.3f, 0.5f}); + + auto exp = NDArrayFactory::create( + 'c', {2, 2, 2, 4}, + {0.70082176f, 0.f, 0.03822664f, 0.70082176f, 0.21835658f, + 0.f, 0.18924236f, 0.9462118f, 0.9922489f, 0.f, + 0.f, 0.04615111f, 0.46755522f, 0.18702209f, 0.05610663f, + 0.8415994f, 0.5241424f, 0.34942827f, 0.45425674f, 0.5241424f, + 0.76033086f, 0.f, 0.2824086f, 0.54309344f, 0.54546785f, + 0.41959068f, 0.629386f, 0.29371348f, 0.94679165f, 0.21039814f, + 0.06311944f, 0.10519907f}); + + sd::ops::lrn op; + auto results = op.evaluate({&x}, {1.0, 1.0, 0.5}, {5}); + auto out = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(out)); + // out->printIndexedBuffer("LRN out"); + // exp.printIndexedBuffer("LRN exp"); + ASSERT_TRUE(exp.equalsTo(out)); } //////////////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedDeclarableOpsTests4, LrnTest_5) { + auto x = NDArrayFactory::create( + 'c', {2, 2, 2, 4}, + { - auto x = NDArrayFactory::create('c', {2, 2, 2, 4}, { - - 5.5f, 0.f, 0.3f, 5.5f, - 1.5f, 0.f, 1.3f, 6.5f, - 8.6f, 0.f, 0.f, 0.4f, - 2.5f, 1.f, 0.3f, 4.5f, - 1.5f, 1.f, 1.3f, 1.5f, - 3.5f, 0.f, 1.3f, 2.5f, - 2.6f, 2.f, 3.f, 1.4f, - 4.5f, 1.f, 0.3f, 0.5f} - ); - - auto eps = NDArrayFactory::create('c', {2, 2, 2, 4}, { - 0.70082176f, 0.f, 0.03822664f, 0.70082176f, - 0.21835658f, 0.f, 0.18924236f, 0.9462118f, - - 0.9922489f, 0.f, 0.f, 0.04615111f, - 0.46755522f, 0.18702209f, 0.05610663f, 0.8415994f, + 5.5f, 0.f, 0.3f, 5.5f, 1.5f, 0.f, 1.3f, 6.5f, 8.6f, 0.f, 0.f, + 0.4f, 2.5f, 1.f, 0.3f, 4.5f, 1.5f, 1.f, 1.3f, 1.5f, 3.5f, 0.f, + 1.3f, 2.5f, 2.6f, 2.f, 3.f, 1.4f, 4.5f, 1.f, 0.3f, 0.5f}); + auto eps = NDArrayFactory::create( + 'c', {2, 2, 2, 4}, {0.70082176f, 0.f, 0.03822664f, 0.70082176f, + 0.21835658f, 0.f, 0.18924236f, 0.9462118f, - 0.5241424f, 0.34942827f, 0.45425674f, 0.5241424f, - 0.76033086f, 0.f, 0.2824086f, 0.54309344f, + 0.9922489f, 0.f, 0.f, 0.04615111f, + 0.46755522f, 0.18702209f, 0.05610663f, 0.8415994f, - 0.54546785f, 0.41959068f, 0.629386f, 0.29371348f, - 0.94679165f, 0.21039814f, 0.06311944f, 0.10519907f} - ); + 0.5241424f, 0.34942827f, 0.45425674f, 0.5241424f, + 0.76033086f, 0.f, 0.2824086f, 0.54309344f, - auto exp = NDArrayFactory::create('c', {2, 2, 2, 4}); + 0.54546785f, 0.41959068f, 0.629386f, 0.29371348f, + 0.94679165f, 0.21039814f, 0.06311944f, 0.10519907f}); - sd::ops::lrn_bp op; - auto results = op.evaluate({&x, &eps}, {1.0, 1.0, 0.5}, {5}, {}, {}, false); - auto out = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(out)); - // out->printIndexedBuffer("LRN out"); - // exp.printIndexedBuffer("LRN exp"); -// ASSERT_TRUE(exp.equalsTo(out)); + auto exp = NDArrayFactory::create('c', {2, 2, 2, 4}); + sd::ops::lrn_bp op; + auto results = op.evaluate({&x, &eps}, {1.0, 1.0, 0.5}, {5}, {}, {}, false); + auto out = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(out)); + // out->printIndexedBuffer("LRN out"); + // exp.printIndexedBuffer("LRN exp"); + // ASSERT_TRUE(exp.equalsTo(out)); } - ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, tri_test1) { + const int rows = 3; + const int cols = 5; - const int rows = 3; - const int cols = 5; - - auto expected = NDArrayFactory::create('c', {rows, cols}, {1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 1.f, 0.f, 0.f, 0.f, 1.f, 1.f, 1.f, 0.f, 0.f}); - - sd::ops::tri op; - auto results = op.evaluate({}, {}, {rows, cols}); - auto output = results.at(0); - - // output->printIndexedBuffer(); + auto expected = + NDArrayFactory::create('c', {rows, cols}, + {1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 1.f, 0.f, + 0.f, 0.f, 1.f, 1.f, 1.f, 0.f, 0.f}); - ASSERT_EQ(Status::OK(), results.status()); + sd::ops::tri op; + auto results = op.evaluate({}, {}, {rows, cols}); + auto output = results.at(0); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + // output->printIndexedBuffer(); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, tri_test2) { + const int rows = 3; + const int cols = 5; + const int diag = 2; - const int rows = 3; - const int cols = 5; - const int diag = 2; + auto expected = + NDArrayFactory::create('c', {rows, cols}, + {1.f, 1.f, 1.f, 0.f, 0.f, 1.f, 1.f, 1.f, + 1.f, 0.f, 1.f, 1.f, 1.f, 1.f, 1.f}); - auto expected = NDArrayFactory::create('c', {rows, cols}, {1.f, 1.f, 1.f, 0.f, 0.f, 1.f, 1.f, 1.f, 1.f, 0.f, 1.f, 1.f, 1.f, 1.f, 1.f}); - - sd::ops::tri op; - auto results = op.evaluate({}, {}, {rows, cols, diag}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + sd::ops::tri op; + auto results = op.evaluate({}, {}, {rows, cols, diag}); + auto output = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, tri_test3) { + const int rows = 3; + const int cols = 5; + const int diag = -1; - const int rows = 3; - const int cols = 5; - const int diag = -1; - - auto expected = NDArrayFactory::create('c', {rows, cols}, {0.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 1.f, 0.f, 0.f, 0.f}); - - sd::ops::tri op; - auto results = op.evaluate({}, {}, {rows, cols, diag}); - auto output = results.at(0); + auto expected = + NDArrayFactory::create('c', {rows, cols}, + {0.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f, 0.f, + 0.f, 0.f, 1.f, 1.f, 0.f, 0.f, 0.f}); - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + sd::ops::tri op; + auto results = op.evaluate({}, {}, {rows, cols, diag}); + auto output = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, tri_test4) { + const int rows = 3; + const int cols = 5; + const int diag = -2; - const int rows = 3; - const int cols = 5; - const int diag = -2; - - auto expected = NDArrayFactory::create('c', {rows, cols}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f, 0.f, 0.f, 0.f}); - - sd::ops::tri op; - auto results = op.evaluate({}, {}, {rows, cols, diag}); - auto output = results.at(0); + auto expected = + NDArrayFactory::create('c', {rows, cols}, + {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 1.f, 0.f, 0.f, 0.f, 0.f}); - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + sd::ops::tri op; + auto results = op.evaluate({}, {}, {rows, cols, diag}); + auto output = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, tri_test5) { + const int rows = 5; - const int rows = 5; - - auto expected = NDArrayFactory::create('c', {rows, rows}, {1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 1.f, 0.f, 0.f, 0.f, 1.f, 1.f, 1.f, 0.f, 0.f, 1.f, 1.f, 1.f, 1.f, 0.f, 1.f, 1.f, 1.f, 1.f, 1.f}); + auto expected = NDArrayFactory::create( + 'c', {rows, rows}, + {1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 1.f, 0.f, 0.f, 0.f, 1.f, 1.f, 1.f, + 0.f, 0.f, 1.f, 1.f, 1.f, 1.f, 0.f, 1.f, 1.f, 1.f, 1.f, 1.f}); - sd::ops::tri op; - auto results = op.evaluate({}, {}, {rows}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + sd::ops::tri op; + auto results = op.evaluate({}, {}, {rows}); + auto output = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, tri_test6) { + const int rows = 3; + const int cols = 5; + const int diag = -20; - const int rows = 3; - const int cols = 5; - const int diag = -20; - - auto expected = NDArrayFactory::create('c', {rows, cols}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}); - - sd::ops::tri op; - auto results = op.evaluate({}, {}, {rows, cols, diag}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); + auto expected = + NDArrayFactory::create('c', {rows, cols}, + {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + sd::ops::tri op; + auto results = op.evaluate({}, {}, {rows, cols, diag}); + auto output = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, tri_test7) { + const int rows = 3; + const int cols = 5; + const int diag = 20; - const int rows = 3; - const int cols = 5; - const int diag = 20; + auto expected = + NDArrayFactory::create('c', {rows, cols}, + {1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, + 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}); - auto expected = NDArrayFactory::create('c', {rows, cols}, {1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}); - - sd::ops::tri op; - auto results = op.evaluate({}, {}, {rows, cols, diag}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + sd::ops::tri op; + auto results = op.evaluate({}, {}, {rows, cols, diag}); + auto output = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, triu_test1) { + auto input = NDArrayFactory::create( + 'c', {4, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + auto expected = NDArrayFactory::create( + 'c', {4, 3}, {1, 2, 3, 0, 5, 6, 0, 0, 9, 0, 0, 0}); - auto input = NDArrayFactory::create('c', {4, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); - auto expected = NDArrayFactory::create('c', {4, 3}, {1, 2, 3, 0, 5, 6, 0, 0, 9, 0, 0, 0}); - - sd::ops::triu op; - auto results = op.evaluate({&input}, {}, {}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - + sd::ops::triu op; + auto results = op.evaluate({&input}, {}, {}); + auto output = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, triu_test2) { + auto input = NDArrayFactory::create( + 'c', {4, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + auto expected = NDArrayFactory::create( + 'c', {4, 3}, {1, 2, 3, 4, 5, 6, 0, 8, 9, 0, 0, 12}); - auto input = NDArrayFactory::create('c', {4, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); - auto expected = NDArrayFactory::create('c', {4, 3}, {1, 2, 3,4, 5, 6,0, 8, 9,0, 0, 12}); - - sd::ops::triu op; - auto results = op.evaluate({&input}, {}, {-1}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + sd::ops::triu op; + auto results = op.evaluate({&input}, {}, {-1}); + auto output = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, triu_test3) { + auto input = NDArrayFactory::create( + 'c', {2, 3, 2}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + auto expected = NDArrayFactory::create( + 'c', {2, 3, 2}, {1, 2, 3, 4, 0, 6, 7, 8, 9, 10, 0, 12}); - auto input = NDArrayFactory::create('c', {2, 3, 2}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); - auto expected = NDArrayFactory::create('c', {2, 3, 2}, {1, 2,3, 4,0, 6,7, 8,9,10,0,12}); - - sd::ops::triu op; - auto results = op.evaluate({&input}, {}, {-1}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + sd::ops::triu op; + auto results = op.evaluate({&input}, {}, {-1}); + auto output = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, triu_test4) { + auto input = NDArrayFactory::create( + 'c', {2, 3, 2}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + auto expected = NDArrayFactory::create( + 'c', {2, 3, 2}, {1, 2, 0, 4, 0, 0, 7, 8, 0, 10, 0, 0}); - auto input = NDArrayFactory::create('c', {2, 3, 2}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); - auto expected = NDArrayFactory::create('c', {2, 3, 2}, {1, 2,0, 4,0, 0,7, 8,0, 10,0, 0}); - - sd::ops::triu op; - auto results = op.evaluate({&input}, {}, {}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + sd::ops::triu op; + auto results = op.evaluate({&input}, {}, {}); + auto output = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, triu_test5) { + auto input = NDArrayFactory::create( + 'c', {2, 3, 2}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + auto expected = NDArrayFactory::create( + 'c', {2, 3, 2}, {0, 2, 0, 0, 0, 0, 0, 8, 0, 0, 0, 0}); - auto input = NDArrayFactory::create('c', {2, 3, 2}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); - auto expected = NDArrayFactory::create('c', {2, 3, 2}, {0, 2,0, 0,0, 0,0, 8,0, 0,0, 0}); - - sd::ops::triu op; - auto results = op.evaluate({&input}, {}, {1}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + sd::ops::triu op; + auto results = op.evaluate({&input}, {}, {1}); + auto output = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, triu_test6) { + auto input = NDArrayFactory::create( + 'c', {2, 3, 2}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + auto expected = NDArrayFactory::create( + 'c', {2, 3, 2}, {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}); - auto input = NDArrayFactory::create('c', {2, 3, 2}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); - auto expected = NDArrayFactory::create('c', {2, 3, 2}, {0, 0,0, 0,0, 0,0, 0,0, 0,0, 0}); - - sd::ops::triu op; - auto results = op.evaluate({&input}, {}, {10}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + sd::ops::triu op; + auto results = op.evaluate({&input}, {}, {10}); + auto output = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, triu_test7) { + auto input = NDArrayFactory::create( + 'c', {2, 3, 2}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + auto expected = NDArrayFactory::create( + 'c', {2, 3, 2}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); - auto input = NDArrayFactory::create('c', {2, 3, 2}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); - auto expected = NDArrayFactory::create('c', {2, 3, 2}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); - - sd::ops::triu op; - auto results = op.evaluate({&input}, {}, {-10}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + sd::ops::triu op; + auto results = op.evaluate({&input}, {}, {-10}); + auto output = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, triu_test8) { + auto input = NDArrayFactory::create('c', {6}, {1, 2, 3, 4, 5, 6}); + auto expected = NDArrayFactory::create( + 'c', {6, 6}, {1, 2, 3, 4, 5, 6, 0, 2, 3, 4, 5, 6, 0, 0, 3, 4, 5, 6, + 0, 0, 0, 4, 5, 6, 0, 0, 0, 0, 5, 6, 0, 0, 0, 0, 0, 6}); - auto input = NDArrayFactory::create('c', {6}, {1, 2, 3, 4, 5, 6}); - auto expected = NDArrayFactory::create('c', {6, 6}, {1, 2, 3, 4, 5, 6,0, 2, 3, 4, 5, 6,0, 0, 3, 4, 5, 6,0, 0, 0, 4, 5, 6,0, 0, 0, 0, 5, 6,0, 0, 0, 0, 0, 6}); - - sd::ops::triu op; - auto results = op.evaluate({&input}, {}, {}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + sd::ops::triu op; + auto results = op.evaluate({&input}, {}, {}); + auto output = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, triu_test9) { + auto input = NDArrayFactory::create('c', {6}, {1, 2, 3, 4, 5, 6}); + auto expected = NDArrayFactory::create( + 'c', {6, 6}, {1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, + 1, 2, 3, 4, 5, 6, 0, 2, 3, 4, 5, 6, 0, 0, 3, 4, 5, 6}); - auto input = NDArrayFactory::create('c', {6}, {1, 2, 3, 4, 5, 6}); - auto expected = NDArrayFactory::create('c', {6, 6}, {1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 0, 2, 3, 4, 5, 6, 0, 0, 3, 4, 5, 6}); - - sd::ops::triu op; - auto results = op.evaluate({&input}, {}, {-3}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + sd::ops::triu op; + auto results = op.evaluate({&input}, {}, {-3}); + auto output = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, triu_test10) { + auto input = NDArrayFactory::create('c', {6}, {1, 2, 3, 4, 5, 6}); + auto expected = NDArrayFactory::create( + 'c', {6, 6}, {0, 0, 0, 4, 5, 6, 0, 0, 0, 0, 5, 6, 0, 0, 0, 0, 0, 6, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}); - auto input = NDArrayFactory::create('c', {6}, {1, 2, 3, 4, 5, 6}); - auto expected = NDArrayFactory::create('c', {6, 6}, {0, 0, 0, 4, 5, 6, 0, 0, 0, 0, 5, 6, 0, 0, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}); - - sd::ops::triu op; - auto results = op.evaluate({&input}, {}, {3}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + sd::ops::triu op; + auto results = op.evaluate({&input}, {}, {3}); + auto output = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, triu_test11) { + auto input = NDArrayFactory::create('c', {6}, {1, 2, 3, 4, 5, 6}); + auto expected = NDArrayFactory::create( + 'c', {6, 6}, {1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, + 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6}); - auto input = NDArrayFactory::create('c', {6}, {1, 2, 3, 4, 5, 6}); - auto expected = NDArrayFactory::create('c', {6, 6}, {1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6}); - - sd::ops::triu op; - auto results = op.evaluate({&input}, {}, {-58}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + sd::ops::triu op; + auto results = op.evaluate({&input}, {}, {-58}); + auto output = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } - ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, triu_bp_test1) { + auto input = NDArrayFactory::create( + 'c', {2, 3, 2}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + auto gradO = NDArrayFactory::create('c', {2, 3, 2}); + gradO = 0.5; - auto input = NDArrayFactory::create('c', {2, 3, 2}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); - auto gradO = NDArrayFactory::create('c', {2, 3, 2}); - gradO = 0.5; - - auto expected = NDArrayFactory::create('c', {2, 3, 2}, {0.,0.5,0.,0. ,0.,0. ,0.,0.5,0.,0. ,0.,0.}); - - sd::ops::triu_bp op; - auto results = op.evaluate({&input, &gradO}, {}, {1}); - auto gradI = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); + auto expected = NDArrayFactory::create( + 'c', {2, 3, 2}, {0., 0.5, 0., 0., 0., 0., 0., 0.5, 0., 0., 0., 0.}); - ASSERT_TRUE(expected.isSameShape(gradI)); - ASSERT_TRUE(expected.equalsTo(gradI)); + sd::ops::triu_bp op; + auto results = op.evaluate({&input, &gradO}, {}, {1}); + auto gradI = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(gradI)); + ASSERT_TRUE(expected.equalsTo(gradI)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, triu_bp_test2) { + auto input = NDArrayFactory::create( + 'c', {2, 3, 2}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + auto gradO = NDArrayFactory::create('c', {2, 3, 2}); + gradO = 0.5; - auto input = NDArrayFactory::create('c', {2, 3, 2}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); - auto gradO = NDArrayFactory::create('c', {2, 3, 2}); - gradO = 0.5; + auto expected = NDArrayFactory::create( + 'c', {2, 3, 2}, {0.5, 0.5, 0., 0.5, 0., 0., 0.5, 0.5, 0., 0.5, 0., 0.}); - auto expected = NDArrayFactory::create('c', {2, 3, 2}, {0.5,0.5,0. ,0.5,0. ,0. ,0.5,0.5,0. ,0.5,0. ,0.}); - - sd::ops::triu_bp op; - auto results = op.evaluate({&input, &gradO}, {}, {}); - auto gradI = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expected.isSameShape(gradI)); - ASSERT_TRUE(expected.equalsTo(gradI)); + sd::ops::triu_bp op; + auto results = op.evaluate({&input, &gradO}, {}, {}); + auto gradI = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(gradI)); + ASSERT_TRUE(expected.equalsTo(gradI)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, triu_bp_test3) { + auto input = NDArrayFactory::create('c', {6}, {1, 2, 3, 4, 5, 6}); + auto gradO = NDArrayFactory::create('c', {6, 6}); + gradO = 0.5; - auto input = NDArrayFactory::create('c', {6}, {1, 2, 3, 4, 5, 6}); - auto gradO = NDArrayFactory::create('c', {6,6}); - gradO = 0.5; - - auto expected = NDArrayFactory::create('c', {6,6}, {0.5, 0.5, 0.5, 0.5, 0.5, 0.5,0.5, 0.5, 0.5, 0.5, 0.5, 0.5,0.5, 0.5, 0.5, 0.5, 0.5, 0.5,0. , 0.5, 0.5, 0.5, 0.5, 0.5,0. , 0. , 0.5, 0.5, 0.5, 0.5,0. , 0. , 0. , 0.5, 0.5, 0.5}); - - sd::ops::triu_bp op; - auto results = op.evaluate({&input, &gradO}, {}, {-2}); - auto gradI = results.at(0); + auto expected = NDArrayFactory::create( + 'c', {6, 6}, + {0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, + 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0., 0.5, 0.5, 0.5, 0.5, 0.5, + 0., 0., 0.5, 0.5, 0.5, 0.5, 0., 0., 0., 0.5, 0.5, 0.5}); - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expected.isSameShape(gradI)); - ASSERT_TRUE(expected.equalsTo(gradI)); + sd::ops::triu_bp op; + auto results = op.evaluate({&input, &gradO}, {}, {-2}); + auto gradI = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(gradI)); + ASSERT_TRUE(expected.equalsTo(gradI)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, triu_bp_test4) { + auto input = NDArrayFactory::create('c', {2, 3}, {1, 2, 3, 4, 5, 6}); + auto gradO = NDArrayFactory::create('c', {2, 3}); + gradO = 0.5; - auto input = NDArrayFactory::create('c', {2,3}, {1, 2, 3, 4, 5, 6}); - auto gradO = NDArrayFactory::create('c', {2,3}); - gradO = 0.5; - - auto expected = NDArrayFactory::create('c', {2,3}, {0., 0., 0., 0., 0., 0.}); - - sd::ops::triu_bp op; - auto results = op.evaluate({&input, &gradO}, {}, {10}); - auto gradI = results.at(0); + auto expected = + NDArrayFactory::create('c', {2, 3}, {0., 0., 0., 0., 0., 0.}); - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expected.isSameShape(gradI)); - ASSERT_TRUE(expected.equalsTo(gradI)); + sd::ops::triu_bp op; + auto results = op.evaluate({&input, &gradO}, {}, {10}); + auto gradI = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(gradI)); + ASSERT_TRUE(expected.equalsTo(gradI)); } - diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests5.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests5.cpp index c68392da1148..61dc37514c5c 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests5.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests5.cpp @@ -18,1647 +18,1721 @@ // @author raver119@gmail.com // - -#include "testlayers.h" -#include -#include #include #include +#include +#include +#include "testlayers.h" using namespace sd; using namespace sd::graph; class DeclarableOpsTests5 : public testing::Test { -public: - - DeclarableOpsTests5() { - printf("\n"); - fflush(stdout); - } + public: + DeclarableOpsTests5() { + printf("\n"); + fflush(stdout); + } }; - TEST_F(DeclarableOpsTests5, Test_PermuteEquality_1) { - auto x = NDArrayFactory::create('c', {1, 60}); - auto exp = NDArrayFactory::create('c', {3, 5, 4}, {1.0, 6.0, 11.0, 16.0, 2.0, 7.0, 12.0, 17.0, 3.0, 8.0, 13.0, 18.0, 4.0, 9.0, 14.0, 19.0, 5.0, 10.0, 15.0, 20.0, 21.0, 26.0, 31.0, 36.0, 22.0, 27.0, 32.0, 37.0, 23.0, 28.0, 33.0, 38.0, 24.0, 29.0, 34.0, 39.0, 25.0, 30.0, 35.0, 40.0, 41.0, 46.0, 51.0, 56.0, 42.0, 47.0, 52.0, 57.0, 43.0, 48.0, 53.0, 58.0, 44.0, 49.0, 54.0, 59.0, 45.0, 50.0, 55.0, 60.0}); - x.linspace(1); - x.reshapei('c', {3, 4, 5}); - - sd::ops::permute op; - auto result = op.evaluate({&x}, {}, {0, 2, 1}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); + auto x = NDArrayFactory::create('c', {1, 60}); + auto exp = NDArrayFactory::create( + 'c', {3, 5, 4}, + {1.0, 6.0, 11.0, 16.0, 2.0, 7.0, 12.0, 17.0, 3.0, 8.0, 13.0, 18.0, + 4.0, 9.0, 14.0, 19.0, 5.0, 10.0, 15.0, 20.0, 21.0, 26.0, 31.0, 36.0, + 22.0, 27.0, 32.0, 37.0, 23.0, 28.0, 33.0, 38.0, 24.0, 29.0, 34.0, 39.0, + 25.0, 30.0, 35.0, 40.0, 41.0, 46.0, 51.0, 56.0, 42.0, 47.0, 52.0, 57.0, + 43.0, 48.0, 53.0, 58.0, 44.0, 49.0, 54.0, 59.0, 45.0, 50.0, 55.0, 60.0}); + x.linspace(1); + x.reshapei('c', {3, 4, 5}); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::permute op; + auto result = op.evaluate({&x}, {}, {0, 2, 1}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests5, Test_PermuteEquality_0) { - auto x = NDArrayFactory::create('c', {1, 60}); - x.linspace(1); - auto exp = NDArrayFactory::create('c', {3, 4, 5}, {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, 32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0, 40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0, 48.0, 49.0, 50.0, 51.0, 52.0, 53.0, 54.0, 55.0, 56.0, 57.0, 58.0, 59.0, 60.0}); - x.reshapei('c', {3, 4, 5}); + auto x = NDArrayFactory::create('c', {1, 60}); + x.linspace(1); + auto exp = NDArrayFactory::create( + 'c', {3, 4, 5}, + {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, + 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, + 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, 32.0, 33.0, 34.0, 35.0, 36.0, + 37.0, 38.0, 39.0, 40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0, 48.0, + 49.0, 50.0, 51.0, 52.0, 53.0, 54.0, 55.0, 56.0, 57.0, 58.0, 59.0, 60.0}); + x.reshapei('c', {3, 4, 5}); -// x.printShapeInfo("{0, 1, 2} shape"); -// x.printBuffer("{0, 1, 2} data"); + // x.printShapeInfo("{0, 1, 2} shape"); + // x.printBuffer("{0, 1, 2} data"); - sd::ops::permute op; - auto result = op.evaluate({&x}, {}, {0, 1, 2}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::permute op; + auto result = op.evaluate({&x}, {}, {0, 1, 2}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(DeclarableOpsTests5, Test_PermuteEquality_2) { - auto x = NDArrayFactory::create('c', {1, 60}); - x.linspace(1); - auto exp = NDArrayFactory::create('c', {4, 3, 5}, {1.0, 2.0, 3.0, 4.0, 5.0, 21.0, 22.0, 23.0, 24.0, 25.0, 41.0, 42.0, 43.0, 44.0, 45.0, 6.0, 7.0, 8.0, 9.0, 10.0, 26.0, 27.0, 28.0, 29.0, 30.0, 46.0, 47.0, 48.0, 49.0, 50.0, 11.0, 12.0, 13.0, 14.0, 15.0, 31.0, 32.0, 33.0, 34.0, 35.0, 51.0, 52.0, 53.0, 54.0, 55.0, 16.0, 17.0, 18.0, 19.0, 20.0, 36.0, 37.0, 38.0, 39.0, 40.0, 56.0, 57.0, 58.0, 59.0, 60.0}); - x.reshapei('c', {3, 4, 5}); - -// x.printShapeInfo("{1, 0, 2} shape"); -// x.printBuffer("{1, 0, 2} data"); + auto x = NDArrayFactory::create('c', {1, 60}); + x.linspace(1); + auto exp = NDArrayFactory::create( + 'c', {4, 3, 5}, + {1.0, 2.0, 3.0, 4.0, 5.0, 21.0, 22.0, 23.0, 24.0, 25.0, 41.0, 42.0, + 43.0, 44.0, 45.0, 6.0, 7.0, 8.0, 9.0, 10.0, 26.0, 27.0, 28.0, 29.0, + 30.0, 46.0, 47.0, 48.0, 49.0, 50.0, 11.0, 12.0, 13.0, 14.0, 15.0, 31.0, + 32.0, 33.0, 34.0, 35.0, 51.0, 52.0, 53.0, 54.0, 55.0, 16.0, 17.0, 18.0, + 19.0, 20.0, 36.0, 37.0, 38.0, 39.0, 40.0, 56.0, 57.0, 58.0, 59.0, 60.0}); + x.reshapei('c', {3, 4, 5}); - sd::ops::permute op; - auto result = op.evaluate({&x}, {}, {1, 0, 2}); - ASSERT_EQ(Status::OK(), result.status()); + // x.printShapeInfo("{1, 0, 2} shape"); + // x.printBuffer("{1, 0, 2} data"); - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::permute op; + auto result = op.evaluate({&x}, {}, {1, 0, 2}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests5, Test_PermuteEquality_3) { - auto x = NDArrayFactory::create('c', {1, 60}); - x.linspace(1); - auto exp = NDArrayFactory::create('c', {4, 5, 3}, {1.0, 21.0, 41.0, 2.0, 22.0, 42.0, 3.0, 23.0, 43.0, 4.0, 24.0, 44.0, 5.0, 25.0, 45.0, 6.0, 26.0, 46.0, 7.0, 27.0, 47.0, 8.0, 28.0, 48.0, 9.0, 29.0, 49.0, 10.0, 30.0, 50.0, 11.0, 31.0, 51.0, 12.0, 32.0, 52.0, 13.0, 33.0, 53.0, 14.0, 34.0, 54.0, 15.0, 35.0, 55.0, 16.0, 36.0, 56.0, 17.0, 37.0, 57.0, 18.0, 38.0, 58.0, 19.0, 39.0, 59.0, 20.0, 40.0, 60.0}); - x.reshapei('c', {3, 4, 5}); - -// x.printShapeInfo("{1, 2, 0} shape"); -// x.printBuffer("{1, 2, 0} data"); + auto x = NDArrayFactory::create('c', {1, 60}); + x.linspace(1); + auto exp = NDArrayFactory::create( + 'c', {4, 5, 3}, + {1.0, 21.0, 41.0, 2.0, 22.0, 42.0, 3.0, 23.0, 43.0, 4.0, 24.0, 44.0, + 5.0, 25.0, 45.0, 6.0, 26.0, 46.0, 7.0, 27.0, 47.0, 8.0, 28.0, 48.0, + 9.0, 29.0, 49.0, 10.0, 30.0, 50.0, 11.0, 31.0, 51.0, 12.0, 32.0, 52.0, + 13.0, 33.0, 53.0, 14.0, 34.0, 54.0, 15.0, 35.0, 55.0, 16.0, 36.0, 56.0, + 17.0, 37.0, 57.0, 18.0, 38.0, 58.0, 19.0, 39.0, 59.0, 20.0, 40.0, 60.0}); + x.reshapei('c', {3, 4, 5}); - sd::ops::permute op; - auto result = op.evaluate({&x}, {}, {1, 2, 0}); - ASSERT_EQ(Status::OK(), result.status()); + // x.printShapeInfo("{1, 2, 0} shape"); + // x.printBuffer("{1, 2, 0} data"); - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::permute op; + auto result = op.evaluate({&x}, {}, {1, 2, 0}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests5, Test_PermuteEquality_4) { - auto x = NDArrayFactory::create('c', {1, 60}); - x.linspace(1); - auto exp = NDArrayFactory::create('c', {5, 3, 4}, {1.0, 6.0, 11.0, 16.0, 21.0, 26.0, 31.0, 36.0, 41.0, 46.0, 51.0, 56.0, 2.0, 7.0, 12.0, 17.0, 22.0, 27.0, 32.0, 37.0, 42.0, 47.0, 52.0, 57.0, 3.0, 8.0, 13.0, 18.0, 23.0, 28.0, 33.0, 38.0, 43.0, 48.0, 53.0, 58.0, 4.0, 9.0, 14.0, 19.0, 24.0, 29.0, 34.0, 39.0, 44.0, 49.0, 54.0, 59.0, 5.0, 10.0, 15.0, 20.0, 25.0, 30.0, 35.0, 40.0, 45.0, 50.0, 55.0, 60.0}); - x.reshapei('c', {3, 4, 5}); - -// x.printShapeInfo("{2, 0, 1} shape"); -// x.printBuffer("{2, 0, 1} data"); + auto x = NDArrayFactory::create('c', {1, 60}); + x.linspace(1); + auto exp = NDArrayFactory::create( + 'c', {5, 3, 4}, + {1.0, 6.0, 11.0, 16.0, 21.0, 26.0, 31.0, 36.0, 41.0, 46.0, 51.0, 56.0, + 2.0, 7.0, 12.0, 17.0, 22.0, 27.0, 32.0, 37.0, 42.0, 47.0, 52.0, 57.0, + 3.0, 8.0, 13.0, 18.0, 23.0, 28.0, 33.0, 38.0, 43.0, 48.0, 53.0, 58.0, + 4.0, 9.0, 14.0, 19.0, 24.0, 29.0, 34.0, 39.0, 44.0, 49.0, 54.0, 59.0, + 5.0, 10.0, 15.0, 20.0, 25.0, 30.0, 35.0, 40.0, 45.0, 50.0, 55.0, 60.0}); + x.reshapei('c', {3, 4, 5}); - sd::ops::permute op; - auto result = op.evaluate({&x}, {}, {2, 0, 1}); - ASSERT_EQ(Status::OK(), result.status()); + // x.printShapeInfo("{2, 0, 1} shape"); + // x.printBuffer("{2, 0, 1} data"); - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::permute op; + auto result = op.evaluate({&x}, {}, {2, 0, 1}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests5, Test_PermuteEquality_5) { - auto x = NDArrayFactory::create('c', {1, 60}); - x.linspace(1); - auto exp = NDArrayFactory::create('c', {5, 4, 3}, {1.0, 21.0, 41.0, 6.0, 26.0, 46.0, 11.0, 31.0, 51.0, 16.0, 36.0, 56.0, 2.0, 22.0, 42.0, 7.0, 27.0, 47.0, 12.0, 32.0, 52.0, 17.0, 37.0, 57.0, 3.0, 23.0, 43.0, 8.0, 28.0, 48.0, 13.0, 33.0, 53.0, 18.0, 38.0, 58.0, 4.0, 24.0, 44.0, 9.0, 29.0, 49.0, 14.0, 34.0, 54.0, 19.0, 39.0, 59.0, 5.0, 25.0, 45.0, 10.0, 30.0, 50.0, 15.0, 35.0, 55.0, 20.0, 40.0, 60.0}); - x.reshapei('c', {3, 4, 5}); - -// x.printShapeInfo("{2, 1, 0} shape"); -// x.printBuffer("{2, 1, 0} data"); + auto x = NDArrayFactory::create('c', {1, 60}); + x.linspace(1); + auto exp = NDArrayFactory::create( + 'c', {5, 4, 3}, + {1.0, 21.0, 41.0, 6.0, 26.0, 46.0, 11.0, 31.0, 51.0, 16.0, 36.0, 56.0, + 2.0, 22.0, 42.0, 7.0, 27.0, 47.0, 12.0, 32.0, 52.0, 17.0, 37.0, 57.0, + 3.0, 23.0, 43.0, 8.0, 28.0, 48.0, 13.0, 33.0, 53.0, 18.0, 38.0, 58.0, + 4.0, 24.0, 44.0, 9.0, 29.0, 49.0, 14.0, 34.0, 54.0, 19.0, 39.0, 59.0, + 5.0, 25.0, 45.0, 10.0, 30.0, 50.0, 15.0, 35.0, 55.0, 20.0, 40.0, 60.0}); + x.reshapei('c', {3, 4, 5}); - sd::ops::permute op; - auto result = op.evaluate({&x}, {}, {2, 1, 0}); - ASSERT_EQ(Status::OK(), result.status()); + // x.printShapeInfo("{2, 1, 0} shape"); + // x.printBuffer("{2, 1, 0} data"); - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::permute op; + auto result = op.evaluate({&x}, {}, {2, 1, 0}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests5, Test_TTS_bp_1) { - auto x = NDArrayFactory::create('c', {2, 1, 3}); - auto eps = NDArrayFactory::create('c', {2, 4, 3}); - - auto exp = NDArrayFactory::create('c', {2, 1, 3}, {22.f, 26.f, 30.f, 70.f, 74.f, 78.f}); + auto x = NDArrayFactory::create('c', {2, 1, 3}); + auto eps = NDArrayFactory::create('c', {2, 4, 3}); - eps.linspace(1.f); + auto exp = NDArrayFactory::create( + 'c', {2, 1, 3}, {22.f, 26.f, 30.f, 70.f, 74.f, 78.f}); - sd::ops::tile_to_shape_bp op; - auto result = op.evaluate({&x, &eps}, {}, {2, 4, 3}); + eps.linspace(1.f); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - // z->printShapeInfo("RES shape"); - // x.printShapeInfo("EXP shape"); - // z->printIndexedBuffer("RES output"); - ASSERT_TRUE(x.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::tile_to_shape_bp op; + auto result = op.evaluate({&x, &eps}, {}, {2, 4, 3}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + // z->printShapeInfo("RES shape"); + // x.printShapeInfo("EXP shape"); + // z->printIndexedBuffer("RES output"); + ASSERT_TRUE(x.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(DeclarableOpsTests5, Test_Rdiv_bp_1) { - auto x = NDArrayFactory::create('c', {3, 1}, {1, 2, 3}); - auto y = NDArrayFactory::create('c', {1, 4}, {1, 2, 3, 4}); - auto eps = NDArrayFactory::create('c', {3, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + auto x = NDArrayFactory::create('c', {3, 1}, {1, 2, 3}); + auto y = NDArrayFactory::create('c', {1, 4}, {1, 2, 3, 4}); + auto eps = NDArrayFactory::create( + 'c', {3, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + sd::ops::reversedivide op_ff; + auto result_ff = op_ff.evaluate({&x, &y}, {}, {}); + ASSERT_EQ(Status::OK(), result_ff.status()); - sd::ops::reversedivide op_ff; - auto result_ff = op_ff.evaluate({&x, &y}, {}, {}); - ASSERT_EQ(Status::OK(), result_ff.status()); + auto z_ff = result_ff.at(0); + ASSERT_TRUE(eps.isSameShape(z_ff)); - auto z_ff = result_ff.at(0); - ASSERT_TRUE(eps.isSameShape(z_ff)); + sd::ops::reversedivide_bp op_bp; + auto result_bp = op_bp.evaluate({&x, &y, &eps}, {}, {}); + ASSERT_EQ(Status::OK(), result_bp.status()); - sd::ops::reversedivide_bp op_bp; - auto result_bp = op_bp.evaluate({&x, &y, &eps}, {}, {}); - ASSERT_EQ(Status::OK(), result_bp.status()); - - auto z_bp = result_bp.at(0); - ASSERT_TRUE(x.isSameShape(z_bp)); + auto z_bp = result_bp.at(0); + ASSERT_TRUE(x.isSameShape(z_bp)); } - TEST_F(DeclarableOpsTests5, Test_Boolean_diff_1) { - auto x = NDArrayFactory::create('c', {1, 1}, {1.0f}); - auto y = NDArrayFactory::create(2.0f); - - sd::ops::less op; - auto result = op.evaluate({&x, &y}); - ASSERT_EQ(Status::OK(), result.status()); - ASSERT_EQ(result.at(0)->t(0), true); + auto x = NDArrayFactory::create('c', {1, 1}, {1.0f}); + auto y = NDArrayFactory::create(2.0f); + sd::ops::less op; + auto result = op.evaluate({&x, &y}); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_EQ(result.at(0).t(0), true); } TEST_F(DeclarableOpsTests5, Test_SetSeed_1) { - auto x = NDArrayFactory::create('c', {1, 1}, {120}); - auto y = NDArrayFactory::create(5); + auto x = NDArrayFactory::create('c', {1, 1}, {120}); + auto y = NDArrayFactory::create(5); - sd::ops::set_seed op; - auto result = op.evaluate({&x, &y}, {}, {120, 5}); + sd::ops::set_seed op; + auto result = op.evaluate({&x, &y}, {}, {120, 5}); - ASSERT_EQ(Status::OK(), result.status()); -// result->at(0)->printIndexedBuffer("RES SEED"); + ASSERT_EQ(Status::OK(), result.status()); + // result->at(0)->printIndexedBuffer("RES SEED"); - sd::ops::get_seed getOp; - auto getRes = getOp.evaluate({}); - ASSERT_EQ(Status::OK(), getRes.status()); -// getres.at(0)->printIndexedBuffer("Output RES GET SEED"); -// ASSERT_EQ(result.at(0)->t(0), true); + sd::ops::get_seed getOp; + auto getRes = getOp.evaluate({}); + ASSERT_EQ(Status::OK(), getRes.status()); + // getres.at(0)->printIndexedBuffer("Output RES GET SEED"); + // ASSERT_EQ(result.at(0)->t(0), true); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, scatterMul_test1) { - auto matrix = NDArrayFactory::create('c', {2, 2}, {1.f, 2.f, 3.f, 4.f}); - NDArray idc('c', {1}, std::vector({0LL}), sd::DataType::INT64); - auto updates = NDArrayFactory::create('c', {1, 2}, {10.f, 1.f}); - auto exp = NDArrayFactory::create('c', {2, 2}, {10.f, 2.f, 3.f, 4.f}); - - sd::ops::scatter_mul op; - auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto matrix = + NDArrayFactory::create('c', {2, 2}, {1.f, 2.f, 3.f, 4.f}); + NDArray idc('c', {1}, std::vector({0LL}), sd::DataType::INT64); + auto updates = NDArrayFactory::create('c', {1, 2}, {10.f, 1.f}); + auto exp = NDArrayFactory::create('c', {2, 2}, {10.f, 2.f, 3.f, 4.f}); - auto z = result.at(0); - - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::scatter_mul op; + auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, scatterDiv_test1) { - auto matrix = NDArrayFactory::create('c', {2, 2}, {1.f, 2.f, 3.f, 4.f}); - NDArray idc('c', {1}, std::vector({0LL}), sd::DataType::INT64); - auto updates = NDArrayFactory::create('c', {1, 2}, {10.f, 1.f}); - auto exp = NDArrayFactory::create('c', {2, 2}, {0.10f, 2.f, 3.f, 4.f}); - - sd::ops::scatter_div op; - auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); -// z->printIndexedBuffer("Scatter Div"); - ASSERT_TRUE(exp.equalsTo(z)); + auto matrix = + NDArrayFactory::create('c', {2, 2}, {1.f, 2.f, 3.f, 4.f}); + NDArray idc('c', {1}, std::vector({0LL}), sd::DataType::INT64); + auto updates = NDArrayFactory::create('c', {1, 2}, {10.f, 1.f}); + auto exp = NDArrayFactory::create('c', {2, 2}, {0.10f, 2.f, 3.f, 4.f}); + sd::ops::scatter_div op; + auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + // z->printIndexedBuffer("Scatter Div"); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, scatterSub_test1) { - auto matrix = NDArrayFactory::create('c', {2, 2}, {1.f, 2.f, 3.f, 4.f}); - NDArray idc('c', {1}, std::vector({0LL}), sd::DataType::INT64); - auto updates = NDArrayFactory::create('c', {1, 2}, {10.f, 1.f}); - auto exp = NDArrayFactory::create('c', {2, 2}, {-9.f, 1.f, 3.f, 4.f}); - - sd::ops::scatter_sub op; - auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); -// z->printIndexedBuffer("Scatter Sub"); - ASSERT_TRUE(exp.equalsTo(z)); + auto matrix = + NDArrayFactory::create('c', {2, 2}, {1.f, 2.f, 3.f, 4.f}); + NDArray idc('c', {1}, std::vector({0LL}), sd::DataType::INT64); + auto updates = NDArrayFactory::create('c', {1, 2}, {10.f, 1.f}); + auto exp = NDArrayFactory::create('c', {2, 2}, {-9.f, 1.f, 3.f, 4.f}); + sd::ops::scatter_sub op; + auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + // z->printIndexedBuffer("Scatter Sub"); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, hardsigmoid_test1) { - auto matrix = NDArrayFactory::create('c', {2, 2}, {1.f, 2.f, 3.f, 4.f}); - auto exp = NDArrayFactory::create('c', {2, 2}, {0.7f, 0.9f, 1.f, 1.f}); - - sd::ops::hardsigmoid op; - auto result = op.evaluate({&matrix}, {}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - ASSERT_TRUE(exp.equalsTo(z)); + auto matrix = + NDArrayFactory::create('c', {2, 2}, {1.f, 2.f, 3.f, 4.f}); + auto exp = NDArrayFactory::create('c', {2, 2}, {0.7f, 0.9f, 1.f, 1.f}); + sd::ops::hardsigmoid op; + auto result = op.evaluate({&matrix}, {}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, hardsigmoid_test2) { - auto matrix = NDArrayFactory::create('c', {2, 2}, {1.f, 2.f, 3.f, 4.f}); - auto eps = NDArrayFactory::create('c', {2, 2}, {1.f, 2.f, 3.f, 4.f}); - auto exp = NDArrayFactory::create('c', {2, 2}, {0.2f, 0.4f, 0.f, 0.f}); - - sd::ops::hardsigmoid_bp op; - auto result = op.evaluate({&matrix, &eps}, {}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - ASSERT_TRUE(exp.equalsTo(z)); + auto matrix = + NDArrayFactory::create('c', {2, 2}, {1.f, 2.f, 3.f, 4.f}); + auto eps = NDArrayFactory::create('c', {2, 2}, {1.f, 2.f, 3.f, 4.f}); + auto exp = NDArrayFactory::create('c', {2, 2}, {0.2f, 0.4f, 0.f, 0.f}); + sd::ops::hardsigmoid_bp op; + auto result = op.evaluate({&matrix, &eps}, {}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, hardtanh_test1) { - auto matrix = NDArrayFactory::create('c', {3, 3}, {-4, -3, -2, -1, 0, 1, 2, 3, 4}); - auto exp = NDArrayFactory::create('c', {3, 3}, {-1, -1, -1, -1, 0, 1, 1, 1, 1}); - - sd::ops::hardtanh op; - auto result = op.evaluate({&matrix}, {}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); -// z->printIndexedBuffer("Hardtanh 2x2"); - ASSERT_TRUE(exp.equalsTo(z)); + auto matrix = NDArrayFactory::create('c', {3, 3}, + {-4, -3, -2, -1, 0, 1, 2, 3, 4}); + auto exp = NDArrayFactory::create('c', {3, 3}, + {-1, -1, -1, -1, 0, 1, 1, 1, 1}); + sd::ops::hardtanh op; + auto result = op.evaluate({&matrix}, {}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + // z->printIndexedBuffer("Hardtanh 2x2"); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, hardtanh_test2) { - auto matrix = NDArrayFactory::create('c', {3, 3}, {-4, -3, -2, -1, 0, 1, 2, 3, 4}); - auto eps = NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); - auto exp = NDArrayFactory::create('c', {3, 3}, {0, 0, 0, 4, 5, 6, 0, 0, 0}); - - sd::ops::hardtanh_bp op; - auto result = op.evaluate({&matrix, &eps}, {}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); -// z->printIndexedBuffer("Hardtanh_bp 2x2"); - ASSERT_TRUE(exp.equalsTo(z)); + auto matrix = NDArrayFactory::create('c', {3, 3}, + {-4, -3, -2, -1, 0, 1, 2, 3, 4}); + auto eps = + NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto exp = + NDArrayFactory::create('c', {3, 3}, {0, 0, 0, 4, 5, 6, 0, 0, 0}); + sd::ops::hardtanh_bp op; + auto result = op.evaluate({&matrix, &eps}, {}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + // z->printIndexedBuffer("Hardtanh_bp 2x2"); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, histogram_test1) { - auto matrix = NDArrayFactory::create('c', {3, 3}, {-4, -3, -2, -1, 0, 1, 2, 3, 4}); - auto exp = NDArrayFactory::create('c', {3}, {3, 3, 3}); - - sd::ops::histogram op; - auto result = op.evaluate({&matrix}, {}, {3}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); -// z->printIndexedBuffer("Histogram3"); - ASSERT_TRUE(exp.equalsTo(z)); + auto matrix = NDArrayFactory::create('c', {3, 3}, + {-4, -3, -2, -1, 0, 1, 2, 3, 4}); + auto exp = NDArrayFactory::create('c', {3}, {3, 3, 3}); + sd::ops::histogram op; + auto result = op.evaluate({&matrix}, {}, {3}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + // z->printIndexedBuffer("Histogram3"); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, histogram_test2) { - auto matrix = NDArrayFactory::create('c', {3}, {1, 2, 1}); - auto exp = NDArrayFactory::create('c', {4}, {2, 0, 0, 1}); - - sd::ops::histogram op; - auto result = op.evaluate({&matrix}, {}, {4}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - ASSERT_TRUE(exp.equalsTo(z)); + auto matrix = NDArrayFactory::create('c', {3}, {1, 2, 1}); + auto exp = NDArrayFactory::create('c', {4}, {2, 0, 0, 1}); + sd::ops::histogram op; + auto result = op.evaluate({&matrix}, {}, {4}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, Identity_test1) { - auto matrix = NDArrayFactory::create('c', {3, 3}, {-4.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f, 4.f}); -// auto exp = NDArrayFactory::create('c', {3, 3}, {3, 3, 3}); - - sd::ops::identity op; - auto result = op.evaluate({&matrix}, {}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - ASSERT_TRUE(matrix.equalsTo(z)); + auto matrix = NDArrayFactory::create( + 'c', {3, 3}, {-4.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f, 4.f}); + // auto exp = NDArrayFactory::create('c', {3, 3}, {3, 3, 3}); + sd::ops::identity op; + auto result = op.evaluate({&matrix}, {}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(matrix.equalsTo(z)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, Identity_test2) { - auto matrix = NDArrayFactory::create('c', {3, 3}, {-4, -3, -2, -1, 0, 1, 2, 3, 4}); - auto eps = NDArrayFactory::create('c', {3, 3}, {1,2,3,4,5,6,7,8,9}); -// auto exp = NDArrayFactory::create('c', {3,3}); - sd::ops::identity_bp op; - auto result = op.evaluate({&matrix, &eps}, {}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - ASSERT_TRUE(z->equalsTo(eps)); - - + auto matrix = NDArrayFactory::create('c', {3, 3}, + {-4, -3, -2, -1, 0, 1, 2, 3, 4}); + auto eps = + NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + // auto exp = NDArrayFactory::create('c', {3,3}); + sd::ops::identity_bp op; + auto result = op.evaluate({&matrix, &eps}, {}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + ASSERT_TRUE(z.equalsTo(eps)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, Log1p_test1) { - auto matrix = NDArrayFactory::create('c', {3, 3}, {4, 3, 2, 1, 0, 1, 2, 3, 4}); - auto y = NDArrayFactory::create('c', {3,3}, {5,4,3,2,1,2,3,4,5}); - // auto eps = NDArrayFactory::create('c', {3, 3}, {1,2,3,4,5,6,7,8,9}); -// auto exp = NDArrayFactory::create('c', {3,3}); - sd::ops::Log1p op; - y.applyTransform(sd::transform::Log, y); - auto result = op.evaluate({&matrix}, {}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - ASSERT_TRUE(z->equalsTo(y)); - - + auto matrix = + NDArrayFactory::create('c', {3, 3}, {4, 3, 2, 1, 0, 1, 2, 3, 4}); + auto y = + NDArrayFactory::create('c', {3, 3}, {5, 4, 3, 2, 1, 2, 3, 4, 5}); + // auto eps = NDArrayFactory::create('c', {3, 3}, + // {1,2,3,4,5,6,7,8,9}); + // auto exp = NDArrayFactory::create('c', {3,3}); + sd::ops::Log1p op; + y.applyTransform(sd::transform::Log, y); + auto result = op.evaluate({&matrix}, {}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + ASSERT_TRUE(z.equalsTo(y)); } TEST_F(DeclarableOpsTests5, Test_SpaceToBatch_1) { + auto x = NDArrayFactory::create( + 'c', {1, 2, 2, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + auto exp = NDArrayFactory::create( + 'c', {4, 1, 1, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + auto paddings = NDArrayFactory::create('c', {2, 2}, {0, 0, 0, 0}); - auto x = NDArrayFactory::create('c', {1, 2, 2, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); - auto exp = NDArrayFactory::create('c', {4, 1, 1, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); - auto paddings = NDArrayFactory::create('c', {2, 2}, {0, 0, 0, 0}); - - sd::ops::space_to_batch op; - auto result = op.evaluate({&x, &paddings}, {}, {2}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::space_to_batch op; + auto result = op.evaluate({&x, &paddings}, {}, {2}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(DeclarableOpsTests5, Test_SpaceToBatch_2) { - auto x = NDArrayFactory::create('c', {1, 2, 2, 1}, {1, 2, 3, 4}); - auto exp = NDArrayFactory::create('c', {4, 1, 1, 1}, {1, 2, 3, 4}); - auto paddings = NDArrayFactory::create('c', {2, 2}, {0, 0, 0, 0}); - - sd::ops::space_to_batch op; - auto result = op.evaluate({&x, &paddings}, {}, {2}); - ASSERT_EQ(Status::OK(), result.status()); + auto x = NDArrayFactory::create('c', {1, 2, 2, 1}, {1, 2, 3, 4}); + auto exp = NDArrayFactory::create('c', {4, 1, 1, 1}, {1, 2, 3, 4}); + auto paddings = NDArrayFactory::create('c', {2, 2}, {0, 0, 0, 0}); - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::space_to_batch op; + auto result = op.evaluate({&x, &paddings}, {}, {2}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(DeclarableOpsTests5, Test_SpaceToBatch_3) { + auto x = NDArrayFactory::create( + 'c', {2, 2, 4, 1}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); + auto paddings = NDArrayFactory::create('c', {2, 2}, {0, 0, 2, 0}); + auto exp = NDArrayFactory::create( + 'c', {8, 1, 3, 1}, {0, 1, 3, 0, 9, 11, 0, 2, 4, 0, 10, 12, + 0, 5, 7, 0, 13, 15, 0, 6, 8, 0, 14, 16}); - auto x = NDArrayFactory::create('c', {2, 2, 4, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); - auto paddings = NDArrayFactory::create('c', {2, 2}, {0, 0, 2, 0}); - auto exp = NDArrayFactory::create('c', {8, 1, 3, 1}, {0, 1, 3, 0, 9, 11,0, 2, 4, 0, 10, 12,0, 5, 7, 0, 13, 15,0, 6, 8, 0, 14, 16}); - - sd::ops::space_to_batch op; - auto result = op.evaluate({&x, &paddings}, {}, {2}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - // z->printIndexedBuffer(); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::space_to_batch op; + auto result = op.evaluate({&x, &paddings}, {}, {2}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + // z->printIndexedBuffer(); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, Test_SpaceToBatch_4) { - - const int blockSize = 2; - NDArray x('c', {3, 3*blockSize - 1 - 2, 4*blockSize - 2 - 3, 2}, {147, 148, 219, 220, 149, 150, 11, 12, 83, 84, 13, 14, 155, 156, 227, 228, 157, 158, 171, 172, 243, 244, 173, 174, 35, 36, 107, 108, 37, 38, 179, 180, 251, 252, 181, 182, 195, 196, 267, 268, 197, 198, 59, 60, 131, 132, 61, 62, 203, 204, 275, 276, 205, 206}, sd::DataType::FLOAT32); - NDArray paddings = NDArrayFactory::create('c', {2, 2}, {1, 2, 2, 3}); - - NDArray exp('c', {3*blockSize*blockSize, 3, 4, 2}, {0,0, 0,0, 0,0, 0,0, 0,0, 11,12, 13,14, 0,0, 0,0, 0,0, 0,0, 0,0, 0,0, 0,0, - 0,0, 0,0, 0,0, 35,36, 37,38, 0,0, 0,0, 0,0, 0,0, 0,0, 0,0, 0,0, 0,0, 0,0, 0,0, 59,60, 61,62, 0,0, 0,0, 0,0, 0,0, 0,0, 0,0, - 0,0, 0,0, 0,0, 0,0, 83,84, 0,0, 0,0, 0,0, 0,0, 0,0, 0,0, 0,0, 0,0, 0,0, 0,0, 0,0, 107, 108, 0,0, 0,0, 0,0, 0,0, 0,0, 0,0, - 0,0, 0,0, 0,0, 0,0, 0,0, 131, 132, 0,0, 0,0, 0,0, 0,0, 0,0, 0,0, 0,0, 147, 148, 149, 150, 0,0, 0,0, 155, 156, 157, 158, - 0,0, 0,0, 0,0, 0,0, 0,0, 0,0, 171, 172, 173, 174, 0,0, 0,0, 179, 180, 181, 182, 0,0, 0,0, 0,0, 0,0, 0,0, 0,0, 195, 196, - 197, 198, 0,0, 0,0, 203, 204, 205, 206, 0,0, 0,0, 0,0, 0,0, 0,0, 0,0, 219, 220, 0,0, 0,0, 0,0, 227, 228, 0,0, 0,0, 0,0, - 0,0, 0,0, 0,0, 0,0, 243, 244, 0,0, 0,0, 0,0, 251, 252, 0,0, 0,0, 0,0, 0,0, 0,0, 0,0, 0,0, 267, 268, 0,0, 0,0, 0,0, 275, - 276, 0,0, 0,0, 0,0, 0,0, 0,0, 0,0}, sd::DataType::FLOAT32); - - sd::ops::space_to_batch op; - auto result = op.evaluate({&x, &paddings}, {}, {blockSize}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - // z->printIndexedBuffer(); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - - + const int blockSize = 2; + NDArray x( + 'c', {3, 3 * blockSize - 1 - 2, 4 * blockSize - 2 - 3, 2}, + {147, 148, 219, 220, 149, 150, 11, 12, 83, 84, 13, 14, 155, 156, + 227, 228, 157, 158, 171, 172, 243, 244, 173, 174, 35, 36, 107, 108, + 37, 38, 179, 180, 251, 252, 181, 182, 195, 196, 267, 268, 197, 198, + 59, 60, 131, 132, 61, 62, 203, 204, 275, 276, 205, 206}, + sd::DataType::FLOAT32); + NDArray paddings = NDArrayFactory::create('c', {2, 2}, {1, 2, 2, 3}); + + NDArray exp('c', {3 * blockSize * blockSize, 3, 4, 2}, + {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 11, 12, 13, 14, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 35, 36, 37, 38, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 59, 60, 61, 62, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 83, 84, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 107, 108, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 131, 132, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 147, 148, 149, 150, 0, 0, 0, 0, 155, 156, 157, 158, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 171, 172, 173, 174, 0, 0, + 0, 0, 179, 180, 181, 182, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 195, 196, 197, 198, 0, 0, 0, 0, 203, 204, 205, 206, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 219, 220, 0, 0, 0, 0, + 0, 0, 227, 228, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 243, 244, 0, 0, 0, 0, 0, 0, 251, 252, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 267, 268, 0, 0, 0, 0, + 0, 0, 275, 276, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + sd::DataType::FLOAT32); + + sd::ops::space_to_batch op; + auto result = op.evaluate({&x, &paddings}, {}, {blockSize}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + // z->printIndexedBuffer(); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests5, Test_BatchToSpace_1) { - auto x = NDArrayFactory::create('c', {4, 1, 1, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); - auto exp = NDArrayFactory::create('c', {1, 2, 2, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); - auto crops = NDArrayFactory::create('c', {2, 2}, {0, 0, 0, 0}); + auto x = NDArrayFactory::create( + 'c', {4, 1, 1, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + auto exp = NDArrayFactory::create( + 'c', {1, 2, 2, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + auto crops = NDArrayFactory::create('c', {2, 2}, {0, 0, 0, 0}); - sd::ops::batch_to_space op; - auto result = op.evaluate({&x, &crops}, {}, {2}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - // z->printIndexedBuffer(); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::batch_to_space op; + auto result = op.evaluate({&x, &crops}, {}, {2}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + // z->printIndexedBuffer(); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests5, Test_BatchToSpace_2) { - auto x = NDArrayFactory::create('c', {4, 1, 1, 1}, {1, 2, 3, 4}); - auto exp = NDArrayFactory::create('c', {1, 2, 2, 1}, {1, 2, 3, 4}); - auto crops = NDArrayFactory::create('c', {2, 2}, {0, 0, 0, 0}); - - sd::ops::batch_to_space op; - auto result = op.evaluate({&x, &crops}, {}, {2}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); + auto x = NDArrayFactory::create('c', {4, 1, 1, 1}, {1, 2, 3, 4}); + auto exp = NDArrayFactory::create('c', {1, 2, 2, 1}, {1, 2, 3, 4}); + auto crops = NDArrayFactory::create('c', {2, 2}, {0, 0, 0, 0}); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::batch_to_space op; + auto result = op.evaluate({&x, &crops}, {}, {2}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(DeclarableOpsTests5, Test_BatchToSpace_3) { - auto x = NDArrayFactory::create('c', {8, 1, 3, 1}, {0, 1, 3, 0, 9, 11, - 0, 2, 4, 0, 10, 12, - 0, 5, 7, 0, 13, 15, - 0, 6, 8, 0, 14, 16}); - auto exp = NDArrayFactory::create('c', {2, 2, 4, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); - auto crops = NDArrayFactory::create('c', {2, 2}, {0, 0, 2, 0}); - - sd::ops::batch_to_space op; - auto result = op.evaluate({&x, &crops}, {}, {2}); - ASSERT_EQ(Status::OK(), result.status()); + auto x = NDArrayFactory::create( + 'c', {8, 1, 3, 1}, {0, 1, 3, 0, 9, 11, 0, 2, 4, 0, 10, 12, + 0, 5, 7, 0, 13, 15, 0, 6, 8, 0, 14, 16}); + auto exp = NDArrayFactory::create( + 'c', {2, 2, 4, 1}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); + auto crops = NDArrayFactory::create('c', {2, 2}, {0, 0, 2, 0}); - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::batch_to_space op; + auto result = op.evaluate({&x, &crops}, {}, {2}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, Test_BatchToSpace_4) { + const int blockSize = 2; + NDArray x('c', {3 * blockSize * blockSize, 3, 4, 2}, sd::DataType::FLOAT32); + x.linspace(1, 1); + NDArray crops = NDArrayFactory::create('c', {2, 2}, {1, 2, 2, 3}); - const int blockSize = 2; - NDArray x('c', {3*blockSize*blockSize, 3, 4, 2}, sd::DataType::FLOAT32); - x.linspace(1, 1); - NDArray crops = NDArrayFactory::create('c', {2, 2}, {1, 2, 2, 3}); - - NDArray exp('c', {3, 3*blockSize - 1 - 2, 4*blockSize - 2 - 3, 2}, {147, 148, 219, 220, 149, 150, 11, 12, 83, 84, 13, 14, 155, 156, 227, 228, 157, 158, 171, 172, 243, 244, 173, 174, 35, 36, 107, 108, 37, 38, 179, 180, 251, 252, 181, 182, 195, 196, 267, 268, 197, 198, 59, 60, 131, 132, 61, 62, 203, 204, 275, 276, 205, 206}, sd::DataType::FLOAT32); + NDArray exp( + 'c', {3, 3 * blockSize - 1 - 2, 4 * blockSize - 2 - 3, 2}, + {147, 148, 219, 220, 149, 150, 11, 12, 83, 84, 13, 14, 155, 156, + 227, 228, 157, 158, 171, 172, 243, 244, 173, 174, 35, 36, 107, 108, + 37, 38, 179, 180, 251, 252, 181, 182, 195, 196, 267, 268, 197, 198, + 59, 60, 131, 132, 61, 62, 203, 204, 275, 276, 205, 206}, + sd::DataType::FLOAT32); - sd::ops::batch_to_space op; - auto result = op.evaluate({&x, &crops}, {}, {blockSize}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::batch_to_space op; + auto result = op.evaluate({&x, &crops}, {}, {blockSize}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, eye_test1) { + auto expected = NDArrayFactory::create( + 'c', {3, 3}, {1.f, 0.f, 0.f, 0.f, 1.f, 0.f, 0.f, 0.f, 1.f}); - auto expected = NDArrayFactory::create('c', {3, 3}, {1.f, 0.f, 0.f, 0.f, 1.f, 0.f, 0.f, 0.f, 1.f}); - - sd::ops::eye op; - auto results = op.evaluate({}, {}, {-99, 3}); - auto output = results.at(0); - // output->printIndexedBuffer(); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - + sd::ops::eye op; + auto results = op.evaluate({}, {}, {-99, 3}); + auto output = results.at(0); + // output->printIndexedBuffer(); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, eye_test2) { + auto expected = NDArrayFactory::create( + 'c', {3, 4}, + {1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f}); - auto expected = NDArrayFactory::create('c', {3, 4}, {1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f}); - - sd::ops::eye op; - auto results = op.evaluate({}, {}, {-99, 3, 4}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - + sd::ops::eye op; + auto results = op.evaluate({}, {}, {-99, 3, 4}); + auto output = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, eye_test3) { + auto expected = NDArrayFactory::create( + 'c', {2, 3, 4}, + {1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0}); - auto expected = NDArrayFactory::create('c', {2, 3, 4}, {1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0}); - - sd::ops::eye op; - auto results = op.evaluate({}, {9 /*int*/}, {-99, 3, 4, 2}); - auto output = results.at(0); - // output->printIndexedBuffer("Output eye"); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - + sd::ops::eye op; + auto results = op.evaluate({}, {9 /*int*/}, {-99, 3, 4, 2}); + auto output = results.at(0); + // output->printIndexedBuffer("Output eye"); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, eye_test4) { + auto expected = NDArrayFactory::create( + 'c', {2, 2, 3, 4}, + {1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0., + 0., 1., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0., 0., 1., 0., 0., + 0., 0., 1., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0.}); - auto expected = NDArrayFactory::create('c', {2, 2, 3, 4}, {1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0.}); - - sd::ops::eye op; - auto results = op.evaluate({}, {6/*double*/}, {-99, 3, 4, 2, 2}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - + sd::ops::eye op; + auto results = op.evaluate({}, {6 /*double*/}, {-99, 3, 4, 2, 2}); + auto output = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, eye_test5) { + sd::ops::eye op; + auto result = op.evaluate({}, {}, {3, 2}); - sd::ops::eye op; - auto result = op.evaluate({},{},{3, 2}); - - auto z = result.at(0); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - + auto z = result.at(0); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, gatherNd_test1) { + auto input = NDArrayFactory::create('c', {4, 3, 2}); + input.linspace(1); + auto indices = NDArrayFactory::create('c', {2, 2, 1}, {3, 2, 3, 2}); - auto input = NDArrayFactory::create('c', {4, 3, 2}); - input.linspace(1); - auto indices = NDArrayFactory::create('c', {2,2,1}, {3,2,3,2}); - - auto expected = NDArrayFactory::create('c', {2,2,3,2}, {19, 20, 21, 22, 23, 24, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 13, 14, 15, 16, 17, 18}); - - sd::ops::gather_nd op; - auto results = op.evaluate({&input, &indices}, {}, {}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + auto expected = NDArrayFactory::create( + 'c', {2, 2, 3, 2}, {19, 20, 21, 22, 23, 24, 13, 14, 15, 16, 17, 18, + 19, 20, 21, 22, 23, 24, 13, 14, 15, 16, 17, 18}); + sd::ops::gather_nd op; + auto results = op.evaluate({&input, &indices}, {}, {}); + auto output = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, gatherNd_test2) { + auto input = NDArrayFactory::create('c', {4, 3, 2}); + input.linspace(1); + auto indices = + NDArrayFactory::create('c', {2, 2, 2}, {3, 2, 1, 2, 0, 1, 0, 1}); - auto input = NDArrayFactory::create('c', {4, 3, 2}); - input.linspace(1); - auto indices = NDArrayFactory::create('c', {2,2,2}, {3,2,1,2, 0,1,0,1}); - - auto expected = NDArrayFactory::create('c', {2,2,2}, {23, 24, 11, 12, 3, 4, 3, 4}); - - sd::ops::gather_nd op; - auto results = op.evaluate({&input, &indices}, {}, {}, {true}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + auto expected = NDArrayFactory::create('c', {2, 2, 2}, + {23, 24, 11, 12, 3, 4, 3, 4}); + sd::ops::gather_nd op; + auto results = op.evaluate({&input, &indices}, {}, {}, {true}); + auto output = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, gatherNd_test3) { + auto input = NDArrayFactory::create('c', {4, 3, 2}); + input.linspace(1); + auto indices = NDArrayFactory::create('c', {3}, {3, 2, 1}); + auto expected = NDArrayFactory::create(24.); - auto input = NDArrayFactory::create('c', {4, 3, 2}); - input.linspace(1); - auto indices = NDArrayFactory::create('c', {3}, {3,2,1}); - auto expected = NDArrayFactory::create(24.); - - sd::ops::gather_nd op; - auto results = op.evaluate({&input, &indices}, {}, {}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - + sd::ops::gather_nd op; + auto results = op.evaluate({&input, &indices}, {}, {}); + auto output = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, gatherNd_test4) { + auto input = NDArrayFactory::create('c', {4, 3, 2}); + input.linspace(1); + auto indices = NDArrayFactory::create('c', {2, 3}, {3, 2, 1, 0, 2, 1}); + auto expected = NDArrayFactory::create('c', {2}, {24., 6}); - auto input = NDArrayFactory::create('c', {4, 3, 2}); - input.linspace(1); - auto indices = NDArrayFactory::create('c', {2,3}, {3,2,1,0,2,1}); - auto expected = NDArrayFactory::create('c',{2}, {24., 6}); - - sd::ops::gather_nd op; - auto results = op.evaluate({&input, &indices}, {}, {}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - + sd::ops::gather_nd op; + auto results = op.evaluate({&input, &indices}, {}, {}); + auto output = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, gatherNd_test5) { + auto input = NDArrayFactory::create('c', {4}, {1, 2, 3, 4}); + auto indices = NDArrayFactory::create('c', {5, 1}, {3, 2, 0, 1, 1}); + auto expected = NDArrayFactory::create('c', {5}, {4., 3, 1, 2, 2}); - auto input = NDArrayFactory::create('c', {4}, {1,2,3,4}); - auto indices = NDArrayFactory::create('c', {5,1}, {3,2,0,1,1}); - auto expected = NDArrayFactory::create('c',{5}, {4.,3,1,2,2}); - - sd::ops::gather_nd op; - auto results = op.evaluate({&input, &indices}, {}, {}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - + sd::ops::gather_nd op; + auto results = op.evaluate({&input, &indices}, {}, {}); + auto output = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, gatherNd_test6) { + auto input = NDArrayFactory::create('c', {4}, {1, 2, 3, 4}); + std::vector shape = {1}; + auto indices = NDArrayFactory::create('c', shape, {2}); + auto expected = NDArrayFactory::create(3.); - auto input = NDArrayFactory::create('c', {4}, {1,2,3,4}); - std::vector shape = {1}; - auto indices = NDArrayFactory::create('c', shape, {2}); - auto expected = NDArrayFactory::create(3.); - - sd::ops::gather_nd op; - auto results = op.evaluate({&input, &indices}, {}, {}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - + sd::ops::gather_nd op; + auto results = op.evaluate({&input, &indices}, {}, {}); + auto output = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, gatherNd_test7) { + auto input = NDArrayFactory::create('c', {4, 4}); + input.linspace(1); + auto indices = NDArrayFactory::create( + 'c', {3, 3, 2}, {0, 2, 1, 0, 1, 0, 1, 3, 1, 0, 2, 1, 0, 1, 0, 1, 3, 1}); + auto expected = NDArrayFactory::create('c', {3, 3}, + {3, 5, 5, 8, 5, 10, 2, 2, 14}); - auto input = NDArrayFactory::create('c', {4, 4}); - input.linspace(1); - auto indices = NDArrayFactory::create('c', {3,3,2}, {0,2,1, 0,1,0, 1,3,1, 0,2,1, 0,1,0, 1,3,1}); - auto expected = NDArrayFactory::create('c', {3,3}, {3,5,5,8,5,10,2,2,14}); - - sd::ops::gather_nd op; - auto results = op.evaluate({&input, &indices}, {}, {}, {true}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - + sd::ops::gather_nd op; + auto results = op.evaluate({&input, &indices}, {}, {}, {true}); + auto output = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, gatherNd_test8) { - auto x = NDArrayFactory::create('c', {2, 2}, {1., 2., 3., 4.}); - auto y = NDArrayFactory::create('c', {2, 2}, {0, 0, 1, 1}); - auto e = NDArrayFactory::create('c', {2}, {1., 4.}); - - sd::ops::gather_nd op; - auto result = op.evaluate({&x, &y}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); + auto x = NDArrayFactory::create('c', {2, 2}, {1., 2., 3., 4.}); + auto y = NDArrayFactory::create('c', {2, 2}, {0, 0, 1, 1}); + auto e = NDArrayFactory::create('c', {2}, {1., 4.}); - auto z = result.at(0); - - ASSERT_EQ(e, *z); + sd::ops::gather_nd op; + auto result = op.evaluate({&x, &y}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + ASSERT_EQ(e, z); } TEST_F(DeclarableOpsTests5, gatherNd_test9) { - auto x = NDArrayFactory::create('c', {2, 4, 2, 2}); - auto indices = NDArrayFactory::create('c', {3, 3}, {0,2,1, 0,1,0, 1,3,1}); - auto exp = NDArrayFactory::create('c', {3,2}, {11.f, 12.f, 5.f, 6.f, 31.f, 32.f}); - x.linspace(1); - - sd::ops::gather_nd op; - auto result = op.evaluate({&x, &indices}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); + auto x = NDArrayFactory::create('c', {2, 4, 2, 2}); + auto indices = + NDArrayFactory::create('c', {3, 3}, {0, 2, 1, 0, 1, 0, 1, 3, 1}); + auto exp = NDArrayFactory::create('c', {3, 2}, + {11.f, 12.f, 5.f, 6.f, 31.f, 32.f}); + x.linspace(1); - auto z = result.at(0); + sd::ops::gather_nd op; + auto result = op.evaluate({&x, &indices}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); - //z->printIndexedBuffer(); - //z->printShapeInfo("z shape"); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + // z->printIndexedBuffer(); + // z->printShapeInfo("z shape"); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, gatherNd_test10) { + auto input = NDArrayFactory::create('c', {4, 3, 2}); + auto indices = + NDArrayFactory::create('c', {2, 2, 2}, {30, 20, 1, 2, 0, 10, 0, 1}); - auto input = NDArrayFactory::create('c', {4, 3, 2}); - auto indices = NDArrayFactory::create('c', {2,2,2}, {30,20,1,2, 0,10,0,1}); - - auto output = NDArrayFactory::create('c', {2,2,2}); + auto output = NDArrayFactory::create('c', {2, 2, 2}); - sd::ops::gather_nd op; + sd::ops::gather_nd op; - ASSERT_ANY_THROW(op.execute({&input, &indices}, {&output}, {}, {}, {true})); + ASSERT_ANY_THROW(op.execute({&input, &indices}, {&output}, {}, {}, {true})); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, gatherNd_test11) { + auto input = NDArrayFactory::create('c', {4, 4}); + auto indices = NDArrayFactory::create( + 'c', {3, 3, 2}, + {0, 2, 1, 0, 10, 0, 1, 30, 1, 0, 20, 1, 0, 1, 0, 1, 30, 1}); + auto output = NDArrayFactory::create('c', {3, 3}); - auto input = NDArrayFactory::create('c', {4, 4}); - auto indices = NDArrayFactory::create('c', {3,3,2}, {0,2,1, 0,10,0, 1,30,1, 0,20,1, 0,1,0, 1,30,1}); - auto output = NDArrayFactory::create('c', {3,3}); + sd::ops::gather_nd op; - sd::ops::gather_nd op; - - ASSERT_ANY_THROW(op.execute({&input, &indices}, {&output}, {}, {}, {true})); + ASSERT_ANY_THROW(op.execute({&input, &indices}, {&output}, {}, {}, {true})); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, reverse_sequense_test1) { + auto input = NDArrayFactory::create('c', {3, 4, 5}); + input.linspace(1); + auto seqLengths = NDArrayFactory::create('c', {4}, {4, 4, 4, 4}); + auto exp = NDArrayFactory::create( + 'c', {3, 4, 5}, + {4, 3, 2, 1, 5, 9, 8, 7, 6, 10, 14, 13, 12, 11, 15, + 19, 18, 17, 16, 20, 24, 23, 22, 21, 25, 29, 28, 27, 26, 30, + 34, 33, 32, 31, 35, 39, 38, 37, 36, 40, 44, 43, 42, 41, 45, + 49, 48, 47, 46, 50, 54, 53, 52, 51, 55, 59, 58, 57, 56, 60}); - auto input = NDArrayFactory::create('c', {3, 4, 5}); - input.linspace(1); - auto seqLengths = NDArrayFactory::create('c', {4}, {4,4,4,4}); - auto exp = NDArrayFactory::create('c', {3, 4, 5}, {4, 3, 2, 1, 5, 9, 8, 7, 6, 10, 14, 13, 12, 11, 15, 19, 18, 17, 16, 20, 24, 23, 22, 21, 25, 29, 28, 27, 26, 30, 34, 33, 32, 31, 35, 39, 38, 37, 36, 40, 44, 43, 42, 41, 45, 49, 48, 47, 46, 50, 54, 53, 52, 51, 55, 59, 58, 57, 56, 60}); - - sd::ops::reverse_sequence op; - auto results = op.evaluate({&input, &seqLengths}, {}, {2, 1}); - ASSERT_EQ(Status::OK(), results.status()); - - auto output = results.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::reverse_sequence op; + auto results = op.evaluate({&input, &seqLengths}, {}, {2, 1}); + ASSERT_EQ(Status::OK(), results.status()); + auto output = results.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, reverse_sequense_test2) { + auto input = NDArrayFactory::create('c', {3, 4, 5}); + input.linspace(1); + auto seqLengths = NDArrayFactory::create('c', {4}, {0, 1, 2, 3}); + auto exp = NDArrayFactory::create( + 'c', {3, 4, 5}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12, 11, 13, 14, 15, + 18, 17, 16, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, + 32, 31, 33, 34, 35, 38, 37, 36, 39, 40, 41, 42, 43, 44, 45, + 46, 47, 48, 49, 50, 52, 51, 53, 54, 55, 58, 57, 56, 59, 60}); - auto input = NDArrayFactory::create('c', {3, 4, 5}); - input.linspace(1); - auto seqLengths = NDArrayFactory::create('c', {4}, {0,1,2,3}); - auto exp = NDArrayFactory::create('c', {3, 4, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12, 11, 13, 14, 15, 18, 17, 16, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 32, 31, 33, 34, 35, 38, 37, 36, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 52, 51, 53, 54, 55, 58, 57, 56, 59, 60}); - - sd::ops::reverse_sequence op; - auto results = op.evaluate({&input, &seqLengths}, {}, {2, 1}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - + sd::ops::reverse_sequence op; + auto results = op.evaluate({&input, &seqLengths}, {}, {2, 1}); + auto output = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, reverse_sequense_test3) { + auto input = NDArrayFactory::create('c', {3, 4, 5}); + input.linspace(1); + auto seqLengths = NDArrayFactory::create('c', {3}, {2, 3, 4}); + auto exp = NDArrayFactory::create( + 'c', {3, 4, 5}, + {2, 1, 3, 4, 5, 7, 6, 8, 9, 10, 12, 11, 13, 14, 15, + 17, 16, 18, 19, 20, 23, 22, 21, 24, 25, 28, 27, 26, 29, 30, + 33, 32, 31, 34, 35, 38, 37, 36, 39, 40, 44, 43, 42, 41, 45, + 49, 48, 47, 46, 50, 54, 53, 52, 51, 55, 59, 58, 57, 56, 60}); - auto input = NDArrayFactory::create('c', {3, 4, 5}); - input.linspace(1); - auto seqLengths = NDArrayFactory::create('c', {3}, {2,3,4}); - auto exp = NDArrayFactory::create('c', {3, 4, 5}, {2, 1, 3, 4, 5, 7, 6, 8, 9, 10, 12, 11, 13, 14, 15, 17, 16, 18, 19, 20, 23, 22, 21, 24, 25, 28, 27, 26, 29, 30, 33, 32, 31, 34, 35, 38, 37, 36, 39, 40, 44, 43, 42, 41, 45, 49, 48, 47, 46, 50, 54, 53, 52, 51, 55, 59, 58, 57, 56, 60}); - - sd::ops::reverse_sequence op; - auto results = op.evaluate({&input, &seqLengths}, {}, {2, 0}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - + sd::ops::reverse_sequence op; + auto results = op.evaluate({&input, &seqLengths}, {}, {2, 0}); + auto output = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, reverse_sequense_test4) { + auto input = NDArrayFactory::create('c', {3, 4, 5}); + input.linspace(1); + auto seqLengths = NDArrayFactory::create('c', {5}, {1, 2, 1, 2, 3}); + auto exp = NDArrayFactory::create( + 'c', {3, 4, 5}, + {1, 22, 3, 24, 45, 6, 27, 8, 29, 50, 11, 32, 13, 34, 55, + 16, 37, 18, 39, 60, 21, 2, 23, 4, 25, 26, 7, 28, 9, 30, + 31, 12, 33, 14, 35, 36, 17, 38, 19, 40, 41, 42, 43, 44, 5, + 46, 47, 48, 49, 10, 51, 52, 53, 54, 15, 56, 57, 58, 59, 20}); - auto input = NDArrayFactory::create('c', {3, 4, 5}); - input.linspace(1); - auto seqLengths = NDArrayFactory::create('c', {5}, {1, 2, 1, 2, 3}); - auto exp = NDArrayFactory::create('c', {3, 4, 5}, {1, 22, 3, 24, 45, 6, 27, 8, 29, 50, 11, 32, 13, 34, 55, 16, 37, 18, 39, 60, 21, 2, 23, 4, 25, 26, 7, 28, 9, 30, 31, 12, 33, 14, 35, 36, 17, 38, 19, 40, 41, 42, 43, 44, 5, 46, 47, 48, 49, 10, 51, 52, 53, 54, 15, 56, 57, 58, 59, 20}); - - sd::ops::reverse_sequence op; - auto results = op.evaluate({&input, &seqLengths}, {}, {0, 2}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - + sd::ops::reverse_sequence op; + auto results = op.evaluate({&input, &seqLengths}, {}, {0, 2}); + auto output = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, reverse_sequense_test5) { + auto input = NDArrayFactory::create('c', {3, 4, 5}); + input.linspace(1); + auto seqLengths = NDArrayFactory::create('c', {5}, {1, 2, 4, 2, 3}); + auto exp = NDArrayFactory::create( + 'c', {3, 4, 5}, + {1, 7, 18, 9, 15, 6, 2, 13, 4, 10, 11, 12, 8, 14, 5, + 16, 17, 3, 19, 20, 21, 27, 38, 29, 35, 26, 22, 33, 24, 30, + 31, 32, 28, 34, 25, 36, 37, 23, 39, 40, 41, 47, 58, 49, 55, + 46, 42, 53, 44, 50, 51, 52, 48, 54, 45, 56, 57, 43, 59, 60}); - auto input = NDArrayFactory::create('c', {3, 4, 5}); - input.linspace(1); - auto seqLengths = NDArrayFactory::create('c', {5}, {1, 2, 4, 2, 3}); - auto exp = NDArrayFactory::create('c', {3, 4, 5}, {1, 7, 18, 9, 15, 6, 2, 13, 4, 10, 11, 12, 8, 14, 5, 16, 17, 3, 19, 20, 21, 27, 38, 29, 35, 26, 22, 33, 24, 30, 31, 32, 28, 34, 25, 36, 37, 23, 39, 40, 41, 47, 58, 49, 55, 46, 42, 53, 44, 50, 51, 52, 48, 54, 45, 56, 57, 43, 59, 60}); - - sd::ops::reverse_sequence op; - auto results = op.evaluate({&input, &seqLengths}, {}, {1, 2}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - + sd::ops::reverse_sequence op; + auto results = op.evaluate({&input, &seqLengths}, {}, {1, 2}); + auto output = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, reverse_sequense_test6) { + auto input = NDArrayFactory::create('c', {3, 4, 5}); + input.linspace(1); + auto seqLengths = NDArrayFactory::create('c', {4}, {1, 2, 3, 2}); + auto exp = NDArrayFactory::create( + 'c', {3, 4, 5}, + {1, 2, 3, 4, 5, 26, 27, 28, 29, 30, 51, 52, 53, 54, 55, + 36, 37, 38, 39, 40, 21, 22, 23, 24, 25, 6, 7, 8, 9, 10, + 31, 32, 33, 34, 35, 16, 17, 18, 19, 20, 41, 42, 43, 44, 45, + 46, 47, 48, 49, 50, 11, 12, 13, 14, 15, 56, 57, 58, 59, 60}); - auto input = NDArrayFactory::create('c', {3, 4, 5}); - input.linspace(1); - auto seqLengths = NDArrayFactory::create('c', {4}, {1, 2, 3, 2}); - auto exp = NDArrayFactory::create('c', {3, 4, 5}, {1, 2, 3, 4, 5, 26, 27, 28, 29, 30, 51, 52, 53, 54, 55, 36, 37, 38, 39, 40, 21, 22, 23, 24, 25, 6, 7, 8, 9, 10, 31, 32, 33, 34, 35, 16, 17, 18, 19, 20, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 11, 12, 13, 14, 15, 56, 57, 58, 59, 60}); - - sd::ops::reverse_sequence op; - auto results = op.evaluate({&input, &seqLengths}, {}, {0, 1}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - + sd::ops::reverse_sequence op; + auto results = op.evaluate({&input, &seqLengths}, {}, {0, 1}); + auto output = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, reverse_sequense_test7) { + auto input = NDArrayFactory::create('c', {1, 5}); + input.linspace(1); + std::vector data = {3}; + auto seqLengths = NDArrayFactory::create('c', {1}, data); + auto exp = NDArrayFactory::create('c', {1, 5}, {3, 2, 1, 4, 5}); - auto input = NDArrayFactory::create('c', {1, 5}); - input.linspace(1); - std::vector data = {3}; - auto seqLengths = NDArrayFactory::create('c', {1}, data); - auto exp = NDArrayFactory::create('c', {1, 5}, {3, 2, 1, 4, 5}); - - sd::ops::reverse_sequence op; - auto results = op.evaluate({&input, &seqLengths}, {}, {1, 0}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - + sd::ops::reverse_sequence op; + auto results = op.evaluate({&input, &seqLengths}, {}, {1, 0}); + auto output = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, reverse_sequense_test8) { + auto input = NDArrayFactory::create('c', {1, 5}); + input.linspace(1); + std::vector data = {1, 0, 1, 0, 1}; + auto seqLengths = NDArrayFactory::create('c', {5}, data); + auto exp = NDArrayFactory::create('c', {1, 5}, {1, 2, 3, 4, 5}); - auto input = NDArrayFactory::create('c', {1, 5}); - input.linspace(1); - std::vector data = {1,0,1,0,1}; - auto seqLengths = NDArrayFactory::create('c', {5}, data); - auto exp = NDArrayFactory::create('c', {1, 5}, {1, 2, 3, 4, 5}); - - sd::ops::reverse_sequence op; - auto results = op.evaluate({&input, &seqLengths}, {}, {0, 1}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - + sd::ops::reverse_sequence op; + auto results = op.evaluate({&input, &seqLengths}, {}, {0, 1}); + auto output = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, reverse_sequense_test9) { + auto input = NDArrayFactory::create('c', {5, 1}); + input.linspace(1); + std::vector data = {1, 0, 1, 0, 1}; + auto seqLengths = NDArrayFactory::create('c', {5}, data); + auto exp = NDArrayFactory::create('c', {5, 1}, {1, 2, 3, 4, 5}); - auto input = NDArrayFactory::create('c', {5, 1}); - input.linspace(1); - std::vector data = {1,0,1,0,1}; - auto seqLengths = NDArrayFactory::create('c', {5}, data); - auto exp = NDArrayFactory::create('c', {5, 1}, {1, 2, 3, 4, 5}); - - sd::ops::reverse_sequence op; - auto results = op.evaluate({&input, &seqLengths}, {}, {1, 0}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - + sd::ops::reverse_sequence op; + auto results = op.evaluate({&input, &seqLengths}, {}, {1, 0}); + auto output = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, reverse_sequense_test10) { + auto input = NDArrayFactory::create('c', {5, 1}); + input.linspace(1); + std::vector data = {3}; + auto seqLengths = NDArrayFactory::create('c', {1}, data); + auto exp = NDArrayFactory::create('c', {5, 1}, {3, 2, 1, 4, 5}); - auto input = NDArrayFactory::create('c', {5, 1}); - input.linspace(1); - std::vector data = {3}; - auto seqLengths = NDArrayFactory::create('c', {1}, data); - auto exp = NDArrayFactory::create('c', {5, 1}, {3, 2, 1, 4, 5}); - - sd::ops::reverse_sequence op; - auto results = op.evaluate({&input, &seqLengths}, {}, {0, 1}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - + sd::ops::reverse_sequence op; + auto results = op.evaluate({&input, &seqLengths}, {}, {0, 1}); + auto output = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, reverse_sequense_test11) { + auto input = NDArrayFactory::create('c', {1, 1, 5, 1}); + input.linspace(1); + std::vector data = {1, 0, 1, 0, 1}; + auto seqLengths = NDArrayFactory::create('c', {5}, data); + auto exp = NDArrayFactory::create('c', {1, 1, 5, 1}, {1, 2, 3, 4, 5}); - auto input = NDArrayFactory::create('c', {1, 1, 5, 1}); - input.linspace(1); - std::vector data = {1, 0, 1, 0, 1}; - auto seqLengths = NDArrayFactory::create('c', {5}, data); - auto exp = NDArrayFactory::create('c', {1, 1, 5, 1}, {1, 2, 3, 4, 5}); - - sd::ops::reverse_sequence op; - auto results = op.evaluate({&input, &seqLengths}, {}, {1, 2}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - + sd::ops::reverse_sequence op; + auto results = op.evaluate({&input, &seqLengths}, {}, {1, 2}); + auto output = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, reverse_sequense_test12) { + auto input = NDArrayFactory::create('c', {1, 1, 5, 1}); + input.linspace(1); + std::vector data = {3}; + auto seqLengths = NDArrayFactory::create('c', {1}, data); + auto exp = NDArrayFactory::create('c', {1, 1, 5, 1}, {3, 2, 1, 4, 5}); - auto input = NDArrayFactory::create('c', {1, 1, 5, 1}); - input.linspace(1); - std::vector data = {3}; - auto seqLengths = NDArrayFactory::create('c', {1}, data); - auto exp = NDArrayFactory::create('c', {1, 1, 5, 1}, {3, 2, 1, 4, 5}); - - sd::ops::reverse_sequence op; - auto results = op.evaluate({&input, &seqLengths}, {}, {2, 0}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - + sd::ops::reverse_sequence op; + auto results = op.evaluate({&input, &seqLengths}, {}, {2, 0}); + auto output = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, reverse_sequense_test13) { + auto input = NDArrayFactory::create('c', {1, 1, 5, 1}); + input.linspace(1); + std::vector data = {1}; + auto seqLengths = NDArrayFactory::create('c', {1}, data); + auto exp = NDArrayFactory::create('c', {1, 1, 5, 1}, {1, 2, 3, 4, 5}); - auto input = NDArrayFactory::create('c', {1, 1, 5, 1}); - input.linspace(1); - std::vector data = {1}; - auto seqLengths = NDArrayFactory::create('c', {1}, data); - auto exp = NDArrayFactory::create('c', {1, 1, 5, 1}, {1, 2, 3, 4, 5}); - - sd::ops::reverse_sequence op; - auto results = op.evaluate({&input, &seqLengths}, {}, {3, 0}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - + sd::ops::reverse_sequence op; + auto results = op.evaluate({&input, &seqLengths}, {}, {3, 0}); + auto output = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, reverse_sequense_test14) { - auto input = NDArrayFactory::create('c', {8, 8, 3, 2}, {0.09753360, 0.76124972, 0.24693797, 0.13813169, 0.33144656, 0.08299957, 0.67197708, 0.80659380, 0.98274191, 0.63566073, 0.21592326, 0.54902743, 0.54555996, 0.23407607, 0.11372584, 0.49965927, 0.15210842, 0.53268608, 0.38700677, 0.68832738, 0.37292716, 0.94616004, 0.77735792, 0.60803430, 0.61523204, 0.64298760, 0.26848351, 0.75015615, 0.28683049, 0.70937606, 0.06478678, 0.68985848, 0.55216783, 0.55382648, 0.34652863, 0.17261296, 0.54193264, 0.05176904, 0.82555761, 0.71106697, 0.04416722, 0.07653656, 0.01034390, 0.99430482, 0.59944390, 0.17973880, 0.36437840, 0.86383673, 0.45025550, 0.97136977, 0.13565978, 0.71567448, 0.92094825, 0.93536442, 0.93630291, 0.67277404, 0.93899264, 0.52422773, 0.44892176, 0.03127759, 0.85910449, 0.18252879, 0.72830945, 0.96736828, 0.89831575, 0.83437150, 0.59050780, 0.36145925, 0.16483070, 0.44021176, 0.76018652, 0.44227383, 0.13052339, 0.18204235, 0.99743733, 0.26885190, 0.87726522, 0.16396056, 0.94943412, 0.40016700, 0.65267938, 0.71073267, 0.40094733, 0.91182634, 0.05391789, 0.49520416, 0.24963864, 0.34847086, 0.74088617, 0.36115701, 0.63074210, 0.97423085, 0.42216846, 0.06326975, 0.07858702, 0.20586622, 0.28752144, 0.38146961, 0.83518735, 0.08207577, 0.82083487, 0.81665728, 0.33309570, 0.67563176, 0.98343578, 0.95919930, 0.66994391, 0.89296165, 0.34755773, 0.63166554, 0.18849320, 0.34828456, 0.98477707, 0.75163124, 0.83306004, 0.14203056, 0.01497920, 0.85727447, 0.71194544, 0.85654019, 0.86160433, 0.79580411, 0.47710411, 0.09318029, 0.31369071, 0.64122249, 0.58399725, 0.26706597, 0.05655339, 0.91025211, 0.30330468, 0.33142930, 0.05668627, 0.02936449, 0.12613087, 0.09960114, 0.16218074, 0.15088139, 0.31239040, 0.55980062, 0.34804391, 0.34941538, 0.61370555, 0.07022964, 0.59757058, 0.31189846, 0.25215345, 0.52546591, 0.55744218, 0.59485650, 0.60553664, 0.07536713, 0.55971796, 0.38764845, 0.20737843, 0.37989120, 0.18361641, 0.48636240, 0.06052657, 0.04241913, 0.66710351, 0.07007925, 0.59371493, 0.74479056, 0.84699625, 0.51210368, 0.12489571, 0.23371067, 0.27274571, 0.83306066, 0.75830824, 0.25963478, 0.87137718, 0.24418835, 0.05032742, 0.52076188, 0.47762345, 0.89829370, 0.34417708, 0.84705151, 0.08203183, 0.10632956, 0.78431292, 0.86441722, 0.36487598, 0.09833603, 0.85863594, 0.11010505, 0.11659283, 0.42500288, 0.02747301, 0.12359903, 0.01753431, 0.41160932, 0.47245979, 0.08268172, 0.21580773, 0.75770279, 0.19736489, 0.44461885, 0.33341706, 0.22519571, 0.31528710, 0.14802902, 0.64171939, 0.52643769, 0.19261234, 0.98032835, 0.15401656, 0.85274458, 0.66408502, 0.23212704, 0.74630026, 0.05713613, 0.49025892, 0.48418810, 0.59541513, 0.09243053, 0.93919152, 0.95357019, 0.52377729, 0.65963871, 0.47934951, 0.49919534, 0.34369898, 0.78211256, 0.13908708, 0.95754117, 0.84107746, 0.09126213, 0.42979124, 0.10295325, 0.34631257, 0.69448345, 0.41720536, 0.15282440, 0.74329854, 0.45775009, 0.12786280, 0.39830299, 0.20386769, 0.59703523, 0.94077086, 0.42255597, 0.80453309, 0.79757204, 0.28653229, 0.60175909, 0.55859623, 0.34318230, 0.63002770, 0.36533324, 0.89689906, 0.73236186, 0.61491989, 0.83787947, 0.67939463, 0.72016694, 0.77499849, 0.72428343, 0.34571059, 0.23143007, 0.20099338, 0.85583142, 0.73174191, 0.54284092, 0.20264181, 0.53037061, 0.30493131, 0.82279766, 0.58542432, 0.72632070, 0.18394258, 0.00608118, 0.23808232, 0.17007573, 0.75245459, 0.84990616, 0.38827634, 0.33809538, 0.01080317, 0.27250145, 0.81769542, 0.15323253, 0.71668395, 0.99427044, 0.11355576, 0.50511923, 0.60248266, 0.36610154, 0.99123140, 0.10519719, 0.18754650, 0.43232584, 0.25247084, 0.47968157, 0.88649124, 0.33588961, 0.92338319, 0.18808573, 0.79433656, 0.12074559, 0.02325163, 0.10117917, 0.83559239, 0.67213900, 0.67265260, 0.11917707, 0.76574855, 0.43842117, 0.28530411, 0.79648090, 0.47939640, 0.73564612, 0.41465671, 0.10995635, 0.20271728, 0.00521771, 0.22952055, 0.78271870, 0.12833592, 0.88639055, 0.76398188, 0.49533508, 0.85447872, 0.15937568, 0.92947480, 0.62705964, 0.85960084, 0.13435660, 0.81845809, 0.60715133, 0.83030708, 0.83071910, 0.38883408, 0.92033237, 0.46066239, 0.48806761, 0.50688779, 0.00654483, 0.32076493, 0.42367646, 0.07381865, 0.22801110, 0.26669388, 0.99691302, 0.12113623, 0.34373057, 0.98977921, 0.96225332, 0.90143562, 0.19559914, 0.08978307, 0.09687492, 0.59820890, 0.75527947, 0.67683355, 0.21847023, 0.29395619, 0.50477953, 0.07112842, 0.54090558, 0.68230725, 0.49713828, 0.41958965, 0.68013847, 0.47691765, 0.63269259, 0.94304095, 0.54587271, 0.72447569, 0.28913523, 0.75766936, 0.52965692, 0.96854824, 0.15589071, 0.84128672, 0.16337522, 0.05771034, 0.21556356, 0.12094140, 0.29721207, 0.00811008, 0.66184926}); - auto lengths = NDArrayFactory::create('c', {8}, {7, 2, 3, 5, 2, 1, 6, 4}); - auto e = NDArrayFactory::create('c', {8, 8, 3, 2}, {0.54193264, 0.05176904, 0.82555761, 0.71106697, 0.04416722, 0.07653656, 0.06478678, 0.68985848, 0.55216783, 0.55382648, 0.34652863, 0.17261296, 0.61523204, 0.64298760, 0.26848351, 0.75015615, 0.28683049, 0.70937606, 0.38700677, 0.68832738, 0.37292716, 0.94616004, 0.77735792, 0.60803430, 0.54555996, 0.23407607, 0.11372584, 0.49965927, 0.15210842, 0.53268608, 0.67197708, 0.80659380, 0.98274191, 0.63566073, 0.21592326, 0.54902743, 0.09753360, 0.76124972, 0.24693797, 0.13813169, 0.33144656, 0.08299957, 0.01034390, 0.99430482, 0.59944390, 0.17973880, 0.36437840, 0.86383673, 0.93630291, 0.67277404, 0.93899264, 0.52422773, 0.44892176, 0.03127759, 0.45025550, 0.97136977, 0.13565978, 0.71567448, 0.92094825, 0.93536442, 0.85910449, 0.18252879, 0.72830945, 0.96736828, 0.89831575, 0.83437150, 0.59050780, 0.36145925, 0.16483070, 0.44021176, 0.76018652, 0.44227383, 0.13052339, 0.18204235, 0.99743733, 0.26885190, 0.87726522, 0.16396056, 0.94943412, 0.40016700, 0.65267938, 0.71073267, 0.40094733, 0.91182634, 0.05391789, 0.49520416, 0.24963864, 0.34847086, 0.74088617, 0.36115701, 0.63074210, 0.97423085, 0.42216846, 0.06326975, 0.07858702, 0.20586622, 0.34755773, 0.63166554, 0.18849320, 0.34828456, 0.98477707, 0.75163124, 0.33309570, 0.67563176, 0.98343578, 0.95919930, 0.66994391, 0.89296165, 0.28752144, 0.38146961, 0.83518735, 0.08207577, 0.82083487, 0.81665728, 0.83306004, 0.14203056, 0.01497920, 0.85727447, 0.71194544, 0.85654019, 0.86160433, 0.79580411, 0.47710411, 0.09318029, 0.31369071, 0.64122249, 0.58399725, 0.26706597, 0.05655339, 0.91025211, 0.30330468, 0.33142930, 0.05668627, 0.02936449, 0.12613087, 0.09960114, 0.16218074, 0.15088139, 0.31239040, 0.55980062, 0.34804391, 0.34941538, 0.61370555, 0.07022964, 0.27274571, 0.83306066, 0.75830824, 0.25963478, 0.87137718, 0.24418835, 0.59371493, 0.74479056, 0.84699625, 0.51210368, 0.12489571, 0.23371067, 0.18361641, 0.48636240, 0.06052657, 0.04241913, 0.66710351, 0.07007925, 0.60553664, 0.07536713, 0.55971796, 0.38764845, 0.20737843, 0.37989120, 0.59757058, 0.31189846, 0.25215345, 0.52546591, 0.55744218, 0.59485650, 0.05032742, 0.52076188, 0.47762345, 0.89829370, 0.34417708, 0.84705151, 0.08203183, 0.10632956, 0.78431292, 0.86441722, 0.36487598, 0.09833603, 0.85863594, 0.11010505, 0.11659283, 0.42500288, 0.02747301, 0.12359903, 0.19736489, 0.44461885, 0.33341706, 0.22519571, 0.31528710, 0.14802902, 0.01753431, 0.41160932, 0.47245979, 0.08268172, 0.21580773, 0.75770279, 0.64171939, 0.52643769, 0.19261234, 0.98032835, 0.15401656, 0.85274458, 0.66408502, 0.23212704, 0.74630026, 0.05713613, 0.49025892, 0.48418810, 0.59541513, 0.09243053, 0.93919152, 0.95357019, 0.52377729, 0.65963871, 0.47934951, 0.49919534, 0.34369898, 0.78211256, 0.13908708, 0.95754117, 0.84107746, 0.09126213, 0.42979124, 0.10295325, 0.34631257, 0.69448345, 0.41720536, 0.15282440, 0.74329854, 0.45775009, 0.12786280, 0.39830299, 0.20386769, 0.59703523, 0.94077086, 0.42255597, 0.80453309, 0.79757204, 0.28653229, 0.60175909, 0.55859623, 0.34318230, 0.63002770, 0.36533324, 0.89689906, 0.73236186, 0.61491989, 0.83787947, 0.67939463, 0.72016694, 0.77499849, 0.72428343, 0.34571059, 0.23143007, 0.20099338, 0.85583142, 0.73174191, 0.54284092, 0.20264181, 0.53037061, 0.30493131, 0.82279766, 0.58542432, 0.72632070, 0.18394258, 0.00608118, 0.23808232, 0.17007573, 0.75245459, 0.84990616, 0.38827634, 0.33809538, 0.01080317, 0.27250145, 0.81769542, 0.15323253, 0.71668395, 0.99427044, 0.11355576, 0.50511923, 0.22952055, 0.78271870, 0.12833592, 0.88639055, 0.76398188, 0.49533508, 0.47939640, 0.73564612, 0.41465671, 0.10995635, 0.20271728, 0.00521771, 0.67265260, 0.11917707, 0.76574855, 0.43842117, 0.28530411, 0.79648090, 0.79433656, 0.12074559, 0.02325163, 0.10117917, 0.83559239, 0.67213900, 0.25247084, 0.47968157, 0.88649124, 0.33588961, 0.92338319, 0.18808573, 0.60248266, 0.36610154, 0.99123140, 0.10519719, 0.18754650, 0.43232584, 0.85447872, 0.15937568, 0.92947480, 0.62705964, 0.85960084, 0.13435660, 0.81845809, 0.60715133, 0.83030708, 0.83071910, 0.38883408, 0.92033237, 0.59820890, 0.75527947, 0.67683355, 0.21847023, 0.29395619, 0.50477953, 0.98977921, 0.96225332, 0.90143562, 0.19559914, 0.08978307, 0.09687492, 0.07381865, 0.22801110, 0.26669388, 0.99691302, 0.12113623, 0.34373057, 0.46066239, 0.48806761, 0.50688779, 0.00654483, 0.32076493, 0.42367646, 0.07112842, 0.54090558, 0.68230725, 0.49713828, 0.41958965, 0.68013847, 0.47691765, 0.63269259, 0.94304095, 0.54587271, 0.72447569, 0.28913523, 0.75766936, 0.52965692, 0.96854824, 0.15589071, 0.84128672, 0.16337522, 0.05771034, 0.21556356, 0.12094140, 0.29721207, 0.00811008, 0.66184926}); + auto input = NDArrayFactory::create( + 'c', {8, 8, 3, 2}, + {0.09753360, 0.76124972, 0.24693797, 0.13813169, 0.33144656, 0.08299957, + 0.67197708, 0.80659380, 0.98274191, 0.63566073, 0.21592326, 0.54902743, + 0.54555996, 0.23407607, 0.11372584, 0.49965927, 0.15210842, 0.53268608, + 0.38700677, 0.68832738, 0.37292716, 0.94616004, 0.77735792, 0.60803430, + 0.61523204, 0.64298760, 0.26848351, 0.75015615, 0.28683049, 0.70937606, + 0.06478678, 0.68985848, 0.55216783, 0.55382648, 0.34652863, 0.17261296, + 0.54193264, 0.05176904, 0.82555761, 0.71106697, 0.04416722, 0.07653656, + 0.01034390, 0.99430482, 0.59944390, 0.17973880, 0.36437840, 0.86383673, + 0.45025550, 0.97136977, 0.13565978, 0.71567448, 0.92094825, 0.93536442, + 0.93630291, 0.67277404, 0.93899264, 0.52422773, 0.44892176, 0.03127759, + 0.85910449, 0.18252879, 0.72830945, 0.96736828, 0.89831575, 0.83437150, + 0.59050780, 0.36145925, 0.16483070, 0.44021176, 0.76018652, 0.44227383, + 0.13052339, 0.18204235, 0.99743733, 0.26885190, 0.87726522, 0.16396056, + 0.94943412, 0.40016700, 0.65267938, 0.71073267, 0.40094733, 0.91182634, + 0.05391789, 0.49520416, 0.24963864, 0.34847086, 0.74088617, 0.36115701, + 0.63074210, 0.97423085, 0.42216846, 0.06326975, 0.07858702, 0.20586622, + 0.28752144, 0.38146961, 0.83518735, 0.08207577, 0.82083487, 0.81665728, + 0.33309570, 0.67563176, 0.98343578, 0.95919930, 0.66994391, 0.89296165, + 0.34755773, 0.63166554, 0.18849320, 0.34828456, 0.98477707, 0.75163124, + 0.83306004, 0.14203056, 0.01497920, 0.85727447, 0.71194544, 0.85654019, + 0.86160433, 0.79580411, 0.47710411, 0.09318029, 0.31369071, 0.64122249, + 0.58399725, 0.26706597, 0.05655339, 0.91025211, 0.30330468, 0.33142930, + 0.05668627, 0.02936449, 0.12613087, 0.09960114, 0.16218074, 0.15088139, + 0.31239040, 0.55980062, 0.34804391, 0.34941538, 0.61370555, 0.07022964, + 0.59757058, 0.31189846, 0.25215345, 0.52546591, 0.55744218, 0.59485650, + 0.60553664, 0.07536713, 0.55971796, 0.38764845, 0.20737843, 0.37989120, + 0.18361641, 0.48636240, 0.06052657, 0.04241913, 0.66710351, 0.07007925, + 0.59371493, 0.74479056, 0.84699625, 0.51210368, 0.12489571, 0.23371067, + 0.27274571, 0.83306066, 0.75830824, 0.25963478, 0.87137718, 0.24418835, + 0.05032742, 0.52076188, 0.47762345, 0.89829370, 0.34417708, 0.84705151, + 0.08203183, 0.10632956, 0.78431292, 0.86441722, 0.36487598, 0.09833603, + 0.85863594, 0.11010505, 0.11659283, 0.42500288, 0.02747301, 0.12359903, + 0.01753431, 0.41160932, 0.47245979, 0.08268172, 0.21580773, 0.75770279, + 0.19736489, 0.44461885, 0.33341706, 0.22519571, 0.31528710, 0.14802902, + 0.64171939, 0.52643769, 0.19261234, 0.98032835, 0.15401656, 0.85274458, + 0.66408502, 0.23212704, 0.74630026, 0.05713613, 0.49025892, 0.48418810, + 0.59541513, 0.09243053, 0.93919152, 0.95357019, 0.52377729, 0.65963871, + 0.47934951, 0.49919534, 0.34369898, 0.78211256, 0.13908708, 0.95754117, + 0.84107746, 0.09126213, 0.42979124, 0.10295325, 0.34631257, 0.69448345, + 0.41720536, 0.15282440, 0.74329854, 0.45775009, 0.12786280, 0.39830299, + 0.20386769, 0.59703523, 0.94077086, 0.42255597, 0.80453309, 0.79757204, + 0.28653229, 0.60175909, 0.55859623, 0.34318230, 0.63002770, 0.36533324, + 0.89689906, 0.73236186, 0.61491989, 0.83787947, 0.67939463, 0.72016694, + 0.77499849, 0.72428343, 0.34571059, 0.23143007, 0.20099338, 0.85583142, + 0.73174191, 0.54284092, 0.20264181, 0.53037061, 0.30493131, 0.82279766, + 0.58542432, 0.72632070, 0.18394258, 0.00608118, 0.23808232, 0.17007573, + 0.75245459, 0.84990616, 0.38827634, 0.33809538, 0.01080317, 0.27250145, + 0.81769542, 0.15323253, 0.71668395, 0.99427044, 0.11355576, 0.50511923, + 0.60248266, 0.36610154, 0.99123140, 0.10519719, 0.18754650, 0.43232584, + 0.25247084, 0.47968157, 0.88649124, 0.33588961, 0.92338319, 0.18808573, + 0.79433656, 0.12074559, 0.02325163, 0.10117917, 0.83559239, 0.67213900, + 0.67265260, 0.11917707, 0.76574855, 0.43842117, 0.28530411, 0.79648090, + 0.47939640, 0.73564612, 0.41465671, 0.10995635, 0.20271728, 0.00521771, + 0.22952055, 0.78271870, 0.12833592, 0.88639055, 0.76398188, 0.49533508, + 0.85447872, 0.15937568, 0.92947480, 0.62705964, 0.85960084, 0.13435660, + 0.81845809, 0.60715133, 0.83030708, 0.83071910, 0.38883408, 0.92033237, + 0.46066239, 0.48806761, 0.50688779, 0.00654483, 0.32076493, 0.42367646, + 0.07381865, 0.22801110, 0.26669388, 0.99691302, 0.12113623, 0.34373057, + 0.98977921, 0.96225332, 0.90143562, 0.19559914, 0.08978307, 0.09687492, + 0.59820890, 0.75527947, 0.67683355, 0.21847023, 0.29395619, 0.50477953, + 0.07112842, 0.54090558, 0.68230725, 0.49713828, 0.41958965, 0.68013847, + 0.47691765, 0.63269259, 0.94304095, 0.54587271, 0.72447569, 0.28913523, + 0.75766936, 0.52965692, 0.96854824, 0.15589071, 0.84128672, 0.16337522, + 0.05771034, 0.21556356, 0.12094140, 0.29721207, 0.00811008, 0.66184926}); + auto lengths = + NDArrayFactory::create('c', {8}, {7, 2, 3, 5, 2, 1, 6, 4}); + auto e = NDArrayFactory::create( + 'c', {8, 8, 3, 2}, + {0.54193264, 0.05176904, 0.82555761, 0.71106697, 0.04416722, 0.07653656, + 0.06478678, 0.68985848, 0.55216783, 0.55382648, 0.34652863, 0.17261296, + 0.61523204, 0.64298760, 0.26848351, 0.75015615, 0.28683049, 0.70937606, + 0.38700677, 0.68832738, 0.37292716, 0.94616004, 0.77735792, 0.60803430, + 0.54555996, 0.23407607, 0.11372584, 0.49965927, 0.15210842, 0.53268608, + 0.67197708, 0.80659380, 0.98274191, 0.63566073, 0.21592326, 0.54902743, + 0.09753360, 0.76124972, 0.24693797, 0.13813169, 0.33144656, 0.08299957, + 0.01034390, 0.99430482, 0.59944390, 0.17973880, 0.36437840, 0.86383673, + 0.93630291, 0.67277404, 0.93899264, 0.52422773, 0.44892176, 0.03127759, + 0.45025550, 0.97136977, 0.13565978, 0.71567448, 0.92094825, 0.93536442, + 0.85910449, 0.18252879, 0.72830945, 0.96736828, 0.89831575, 0.83437150, + 0.59050780, 0.36145925, 0.16483070, 0.44021176, 0.76018652, 0.44227383, + 0.13052339, 0.18204235, 0.99743733, 0.26885190, 0.87726522, 0.16396056, + 0.94943412, 0.40016700, 0.65267938, 0.71073267, 0.40094733, 0.91182634, + 0.05391789, 0.49520416, 0.24963864, 0.34847086, 0.74088617, 0.36115701, + 0.63074210, 0.97423085, 0.42216846, 0.06326975, 0.07858702, 0.20586622, + 0.34755773, 0.63166554, 0.18849320, 0.34828456, 0.98477707, 0.75163124, + 0.33309570, 0.67563176, 0.98343578, 0.95919930, 0.66994391, 0.89296165, + 0.28752144, 0.38146961, 0.83518735, 0.08207577, 0.82083487, 0.81665728, + 0.83306004, 0.14203056, 0.01497920, 0.85727447, 0.71194544, 0.85654019, + 0.86160433, 0.79580411, 0.47710411, 0.09318029, 0.31369071, 0.64122249, + 0.58399725, 0.26706597, 0.05655339, 0.91025211, 0.30330468, 0.33142930, + 0.05668627, 0.02936449, 0.12613087, 0.09960114, 0.16218074, 0.15088139, + 0.31239040, 0.55980062, 0.34804391, 0.34941538, 0.61370555, 0.07022964, + 0.27274571, 0.83306066, 0.75830824, 0.25963478, 0.87137718, 0.24418835, + 0.59371493, 0.74479056, 0.84699625, 0.51210368, 0.12489571, 0.23371067, + 0.18361641, 0.48636240, 0.06052657, 0.04241913, 0.66710351, 0.07007925, + 0.60553664, 0.07536713, 0.55971796, 0.38764845, 0.20737843, 0.37989120, + 0.59757058, 0.31189846, 0.25215345, 0.52546591, 0.55744218, 0.59485650, + 0.05032742, 0.52076188, 0.47762345, 0.89829370, 0.34417708, 0.84705151, + 0.08203183, 0.10632956, 0.78431292, 0.86441722, 0.36487598, 0.09833603, + 0.85863594, 0.11010505, 0.11659283, 0.42500288, 0.02747301, 0.12359903, + 0.19736489, 0.44461885, 0.33341706, 0.22519571, 0.31528710, 0.14802902, + 0.01753431, 0.41160932, 0.47245979, 0.08268172, 0.21580773, 0.75770279, + 0.64171939, 0.52643769, 0.19261234, 0.98032835, 0.15401656, 0.85274458, + 0.66408502, 0.23212704, 0.74630026, 0.05713613, 0.49025892, 0.48418810, + 0.59541513, 0.09243053, 0.93919152, 0.95357019, 0.52377729, 0.65963871, + 0.47934951, 0.49919534, 0.34369898, 0.78211256, 0.13908708, 0.95754117, + 0.84107746, 0.09126213, 0.42979124, 0.10295325, 0.34631257, 0.69448345, + 0.41720536, 0.15282440, 0.74329854, 0.45775009, 0.12786280, 0.39830299, + 0.20386769, 0.59703523, 0.94077086, 0.42255597, 0.80453309, 0.79757204, + 0.28653229, 0.60175909, 0.55859623, 0.34318230, 0.63002770, 0.36533324, + 0.89689906, 0.73236186, 0.61491989, 0.83787947, 0.67939463, 0.72016694, + 0.77499849, 0.72428343, 0.34571059, 0.23143007, 0.20099338, 0.85583142, + 0.73174191, 0.54284092, 0.20264181, 0.53037061, 0.30493131, 0.82279766, + 0.58542432, 0.72632070, 0.18394258, 0.00608118, 0.23808232, 0.17007573, + 0.75245459, 0.84990616, 0.38827634, 0.33809538, 0.01080317, 0.27250145, + 0.81769542, 0.15323253, 0.71668395, 0.99427044, 0.11355576, 0.50511923, + 0.22952055, 0.78271870, 0.12833592, 0.88639055, 0.76398188, 0.49533508, + 0.47939640, 0.73564612, 0.41465671, 0.10995635, 0.20271728, 0.00521771, + 0.67265260, 0.11917707, 0.76574855, 0.43842117, 0.28530411, 0.79648090, + 0.79433656, 0.12074559, 0.02325163, 0.10117917, 0.83559239, 0.67213900, + 0.25247084, 0.47968157, 0.88649124, 0.33588961, 0.92338319, 0.18808573, + 0.60248266, 0.36610154, 0.99123140, 0.10519719, 0.18754650, 0.43232584, + 0.85447872, 0.15937568, 0.92947480, 0.62705964, 0.85960084, 0.13435660, + 0.81845809, 0.60715133, 0.83030708, 0.83071910, 0.38883408, 0.92033237, + 0.59820890, 0.75527947, 0.67683355, 0.21847023, 0.29395619, 0.50477953, + 0.98977921, 0.96225332, 0.90143562, 0.19559914, 0.08978307, 0.09687492, + 0.07381865, 0.22801110, 0.26669388, 0.99691302, 0.12113623, 0.34373057, + 0.46066239, 0.48806761, 0.50688779, 0.00654483, 0.32076493, 0.42367646, + 0.07112842, 0.54090558, 0.68230725, 0.49713828, 0.41958965, 0.68013847, + 0.47691765, 0.63269259, 0.94304095, 0.54587271, 0.72447569, 0.28913523, + 0.75766936, 0.52965692, 0.96854824, 0.15589071, 0.84128672, 0.16337522, + 0.05771034, 0.21556356, 0.12094140, 0.29721207, 0.00811008, 0.66184926}); + + sd::ops::reverse_sequence op; + auto results = op.evaluate({&input, &lengths}, {}, {1, 0}); + ASSERT_EQ(Status::OK(), results.status()); + + auto z = results.at(0); + + ASSERT_EQ(e, z); +} - sd::ops::reverse_sequence op; - auto results = op.evaluate({&input, &lengths}, {}, {1, 0}); - ASSERT_EQ(Status::OK(), results.status()); +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, Test_TopK_0) { + auto x = NDArrayFactory::create( + 'c', {2, 6}, + {1.0, 1.0, 1.0, 1.0, 11.0, 3.0, 1.0, 1.0, 1.0, 14.0, 5.0, 6.0}); + auto expV = NDArrayFactory::create('c', {2, 1}, {11.0, 14.0}); + auto expI = NDArrayFactory::create('c', {2, 1}, {4, 3}); - auto z = results.at(0); + sd::ops::top_k op; + auto result = op.evaluate({&x}, {}, {1, 0}); // without sorting - ASSERT_EQ(e, *z); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(2, result.size()); + auto v = result.at(0); + auto i = result.at(1); + /* + v->printShapeInfo("topK_0: shape v"); + expV.printShapeInfo("topK_0: shape expV"); -} + i->printShapeInfo("topK_0: shape I"); + expI.printShapeInfo("topK_0: shape expI"); -////////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests5, Test_TopK_0) { - auto x = NDArrayFactory::create('c', {2, 6}, {1.0, 1.0, 1.0, 1.0, 11.0, 3.0, 1.0, 1.0, 1.0, 14.0, 5.0, 6.0}); - auto expV = NDArrayFactory::create('c', {2, 1}, {11.0, 14.0}); - auto expI = NDArrayFactory::create('c', {2, 1}, {4, 3}); - - sd::ops::top_k op; - auto result = op.evaluate({&x}, {}, {1, 0}); // without sorting - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(2, result.size()); - - auto v = result.at(0); - auto i = result.at(1); -/* - v->printShapeInfo("topK_0: shape v"); - expV.printShapeInfo("topK_0: shape expV"); - - i->printShapeInfo("topK_0: shape I"); - expI.printShapeInfo("topK_0: shape expI"); - - v->printIndexedBuffer("topK_0: v"); - expV.printIndexedBuffer("topK_0: expV"); - i->printIndexedBuffer("topK_0: i"); - expI.printIndexedBuffer("topK_0: expI"); -*/ - - ASSERT_TRUE(expV.isSameShape(v)); - ASSERT_TRUE(expV.equalsTo(v)); - - ASSERT_TRUE(expI.isSameShape(i)); - ASSERT_TRUE(expI.equalsTo(i)); - // repeat res again - for (int cases = 0; cases < 100; ++cases) { - op.execute({&x}, std::vector{v, i}, {}, {1, 0}, {}); // without sorting - } + v->printIndexedBuffer("topK_0: v"); + expV.printIndexedBuffer("topK_0: expV"); + i->printIndexedBuffer("topK_0: i"); + expI.printIndexedBuffer("topK_0: expI"); + */ + ASSERT_TRUE(expV.isSameShape(v)); + ASSERT_TRUE(expV.equalsTo(v)); + + ASSERT_TRUE(expI.isSameShape(i)); + ASSERT_TRUE(expI.equalsTo(i)); + // repeat res again + for (int cases = 0; cases < 100; ++cases) { + op.execute({&x}, {&v, &i}, {}, {1, 0}, {}); // without sorting + } } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, Test_TopK_1) { - auto x = NDArrayFactory::create('c', {2, 3}, {1.0f, 11.0f, 3.0f, 14.0f, 5.0f, 6.0f}); - auto expV = NDArrayFactory::create('c', {2, 1}, {11.0f, 14.0f}); - auto expI = NDArrayFactory::create('c', {2, 1}, {1, 0}); - - sd::ops::top_k op; - auto result = op.evaluate({&x}, {}, {1, 0}); // without sorting + auto x = NDArrayFactory::create( + 'c', {2, 3}, {1.0f, 11.0f, 3.0f, 14.0f, 5.0f, 6.0f}); + auto expV = NDArrayFactory::create('c', {2, 1}, {11.0f, 14.0f}); + auto expI = NDArrayFactory::create('c', {2, 1}, {1, 0}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(2, result.size()); + sd::ops::top_k op; + auto result = op.evaluate({&x}, {}, {1, 0}); // without sorting - auto v = result.at(0); - auto i = result.at(1); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(2, result.size()); -// v->printShapeInfo("topK_1: shape v"); -// expV.printShapeInfo("topK_1: shape expV"); + auto v = result.at(0); + auto i = result.at(1); -// i->printShapeInfo("topK_1: shape I"); -// expI.printShapeInfo("topK_1: shape expI"); + // v->printShapeInfo("topK_1: shape v"); + // expV.printShapeInfo("topK_1: shape expV"); -// v->printIndexedBuffer("topK_1: v"); -// expV.printIndexedBuffer("topK_1: expV"); -// i->printIndexedBuffer("topK_1: i"); -// expI.printIndexedBuffer("topK_1: expI"); + // i->printShapeInfo("topK_1: shape I"); + // expI.printShapeInfo("topK_1: shape expI"); + // v->printIndexedBuffer("topK_1: v"); + // expV.printIndexedBuffer("topK_1: expV"); + // i->printIndexedBuffer("topK_1: i"); + // expI.printIndexedBuffer("topK_1: expI"); - ASSERT_TRUE(expV.isSameShape(v)); - ASSERT_TRUE(expV.equalsTo(v)); - - ASSERT_TRUE(expI.isSameShape(i)); - ASSERT_TRUE(expI.equalsTo(i)); - // repeat res again - for (int cases = 0; cases < 100; ++cases) { - op.execute({&x}, std::vector{v, i}, {}, {1, 0}, {}); // without sorting - } + ASSERT_TRUE(expV.isSameShape(v)); + ASSERT_TRUE(expV.equalsTo(v)); + ASSERT_TRUE(expI.isSameShape(i)); + ASSERT_TRUE(expI.equalsTo(i)); + // repeat res again + for (int cases = 0; cases < 100; ++cases) { + op.execute({&x}, {&v, &i}, {}, {1, 0}, {}); // without sorting + } } /////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, Test_TopK_2) { - auto x = NDArrayFactory::create('c', {2, 3, 4}, {11.0, 3.0, 14.0, 5.0, - 6.0, 9.0, 3.5, 7.0, - 21.0, 3.0, 14.0, 15.0, - 6.0, 9.0, 3.5, 7.0, - 11.0, 13.0, 14.0, 5.0, - 16.0, 9.0, 13.5, 7.0 - } - ); -// <<<14.>,<9.>>, <<21.>,<9.>>, <<14.>,<16.>>> - auto expV = NDArrayFactory::create('c', {2, 3, 1}, {14.0f, 9.0f, - 21.0f, - 9.0f, 14.0f, - 16.0f - } - ); - - auto expI = NDArrayFactory::create('c', {2, 3, 1 }, {2, 1, 0, 1, 2, 0}); - - sd::ops::top_k op; - auto result = op.evaluate({&x}, {}, {1, 1}); + auto x = NDArrayFactory::create( + 'c', {2, 3, 4}, + {11.0, 3.0, 14.0, 5.0, 6.0, 9.0, 3.5, 7.0, 21.0, 3.0, 14.0, 15.0, + 6.0, 9.0, 3.5, 7.0, 11.0, 13.0, 14.0, 5.0, 16.0, 9.0, 13.5, 7.0}); + // <<<14.>,<9.>>, <<21.>,<9.>>, <<14.>,<16.>>> + auto expV = NDArrayFactory::create( + 'c', {2, 3, 1}, {14.0f, 9.0f, 21.0f, 9.0f, 14.0f, 16.0f}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(2, result.size()); + auto expI = + NDArrayFactory::create('c', {2, 3, 1}, {2, 1, 0, 1, 2, 0}); - auto v = result.at(0); - auto i = result.at(1); + sd::ops::top_k op; + auto result = op.evaluate({&x}, {}, {1, 1}); -// v->printShapeInfo("shape v"); -// expV.printShapeInfo("shape expV"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(2, result.size()); -// i->printShapeInfo("shape I"); -// expI.printShapeInfo("shape expI"); + auto v = result.at(0); + auto i = result.at(1); -// v->printIndexedBuffer("v"); -// expV.printIndexedBuffer("expV"); -// i->printIndexedBuffer("i"); -// expI.printIndexedBuffer("expI"); + // v->printShapeInfo("shape v"); + // expV.printShapeInfo("shape expV"); - ASSERT_TRUE(expV.isSameShape(v)); - ASSERT_TRUE(expV.equalsTo(v)); + // i->printShapeInfo("shape I"); + // expI.printShapeInfo("shape expI"); - ASSERT_TRUE(expI.isSameShape(i)); - ASSERT_TRUE(expI.equalsTo(i)); + // v->printIndexedBuffer("v"); + // expV.printIndexedBuffer("expV"); + // i->printIndexedBuffer("i"); + // expI.printIndexedBuffer("expI"); + ASSERT_TRUE(expV.isSameShape(v)); + ASSERT_TRUE(expV.equalsTo(v)); + ASSERT_TRUE(expI.isSameShape(i)); + ASSERT_TRUE(expI.equalsTo(i)); } TEST_F(DeclarableOpsTests5, Test_TopK_3) { - auto x = NDArrayFactory::create('c', {2, 3, 4}, {11.0, 3.0, 14.0, 5.0, - 6.0, 9.0, 3.5, 7.0, - 21.0, 3.0, 14.0, 15.0, - 6.0, 9.0, 3.5, 7.0, - 11.0, 13.0, 14.0, 5.0, - 16.0, 9.0, 13.5, 7.0 - } - ); + auto x = NDArrayFactory::create( + 'c', {2, 3, 4}, + {11.0, 3.0, 14.0, 5.0, 6.0, 9.0, 3.5, 7.0, 21.0, 3.0, 14.0, 15.0, + 6.0, 9.0, 3.5, 7.0, 11.0, 13.0, 14.0, 5.0, 16.0, 9.0, 13.5, 7.0}); - auto expV = NDArrayFactory::create('c', {2, 3, 2}, {14.0f, 11.0f, - 9.0f, 7.0f, - 21.0f, 15.0f, - 9.0f, 7.0f, - 14.0f, 13.0f, - 16.0f, 13.5f - } - ); + auto expV = + NDArrayFactory::create('c', {2, 3, 2}, + {14.0f, 11.0f, 9.0f, 7.0f, 21.0f, 15.0f, + 9.0f, 7.0f, 14.0f, 13.0f, 16.0f, 13.5f}); - auto expI = NDArrayFactory::create('c', {2, 3, 2 }, {2, 0, 1, 3, 0, 3, 1, 3, 2, 1, 0, 2}); + auto expI = NDArrayFactory::create( + 'c', {2, 3, 2}, {2, 0, 1, 3, 0, 3, 1, 3, 2, 1, 0, 2}); - sd::ops::top_k op; - auto result = op.evaluate({&x}, {}, {2, 1}); + sd::ops::top_k op; + auto result = op.evaluate({&x}, {}, {2, 1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(2, result.size()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(2, result.size()); - auto v = result.at(0); - auto i = result.at(1); + auto v = result.at(0); + auto i = result.at(1); -// v->printShapeInfo("shape v"); -// expV.printShapeInfo("shape expV"); + // v->printShapeInfo("shape v"); + // expV.printShapeInfo("shape expV"); -// i->printShapeInfo("shape I"); -// expI.printShapeInfo("shape expI"); + // i->printShapeInfo("shape I"); + // expI.printShapeInfo("shape expI"); -// v->printIndexedBuffer("v"); -// expV.printIndexedBuffer("expV"); -// i->printIndexedBuffer("i"); -// expI.printIndexedBuffer("expI"); - - ASSERT_TRUE(expV.isSameShape(v)); - ASSERT_TRUE(expV.equalsTo(v)); - - ASSERT_TRUE(expI.isSameShape(i)); - ASSERT_TRUE(expI.equalsTo(i)); + // v->printIndexedBuffer("v"); + // expV.printIndexedBuffer("expV"); + // i->printIndexedBuffer("i"); + // expI.printIndexedBuffer("expI"); + ASSERT_TRUE(expV.isSameShape(v)); + ASSERT_TRUE(expV.equalsTo(v)); + ASSERT_TRUE(expI.isSameShape(i)); + ASSERT_TRUE(expI.equalsTo(i)); } TEST_F(DeclarableOpsTests5, Test_TopK_3_unsorted) { - auto x = NDArrayFactory::create('c', {2, 3, 4}, {11.0, 3.0, 14.0, 5.0, - 6.0, 9.0, 3.5, 7.0, - 21.0, 3.0, 14.0, 15.0, - 6.0, 9.0, 3.5, 7.0, - 11.0, 13.0, 14.0, 5.0, - 16.0, 9.0, 13.5, 7.0 - } - ); - - auto expV = NDArrayFactory::create('c', {2, 3, 2}, {11.0f, 14.0f, - 9.0f, 7.0f, - 21.0f, 15.0f, - 9.0f, 7.0f, - 13.0f, 14.0f, - 16.0f, 13.5f - } - ); - - auto expI = NDArrayFactory::create('c', {2, 3, 2 }, {0, 2, 1, 3, 0, 3, 1, 3, 1, 2, 0, 2}); + auto x = NDArrayFactory::create( + 'c', {2, 3, 4}, + {11.0, 3.0, 14.0, 5.0, 6.0, 9.0, 3.5, 7.0, 21.0, 3.0, 14.0, 15.0, + 6.0, 9.0, 3.5, 7.0, 11.0, 13.0, 14.0, 5.0, 16.0, 9.0, 13.5, 7.0}); - sd::ops::top_k op; - auto result = op.evaluate({&x}, {}, {2}, {false}); + auto expV = + NDArrayFactory::create('c', {2, 3, 2}, + {11.0f, 14.0f, 9.0f, 7.0f, 21.0f, 15.0f, + 9.0f, 7.0f, 13.0f, 14.0f, 16.0f, 13.5f}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(2, result.size()); + auto expI = NDArrayFactory::create( + 'c', {2, 3, 2}, {0, 2, 1, 3, 0, 3, 1, 3, 1, 2, 0, 2}); - auto v = result.at(0); - auto i = result.at(1); + sd::ops::top_k op; + auto result = op.evaluate({&x}, {}, {2}, {false}); - ASSERT_TRUE(expV.isSameShape(v)); - ASSERT_TRUE(expV.equalsTo(v)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(2, result.size()); - ASSERT_TRUE(expI.isSameShape(i)); - ASSERT_TRUE(expI.equalsTo(i)); + auto v = result.at(0); + auto i = result.at(1); + ASSERT_TRUE(expV.isSameShape(v)); + ASSERT_TRUE(expV.equalsTo(v)); + ASSERT_TRUE(expI.isSameShape(i)); + ASSERT_TRUE(expI.equalsTo(i)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, Test_TopK_4) { - auto x = NDArrayFactory::create('c', {2, 3}, {1.0f, 11.0f, 3.0f, 14.0f, 5.0f, 6.0f}); - auto expV = NDArrayFactory::create('c', {2, 2}, {11.0f, 3.0f, 14.0f, 6.0f}); - auto expI = NDArrayFactory::create('c', {2, 2}, {1, 2, 0, 2}); + auto x = NDArrayFactory::create( + 'c', {2, 3}, {1.0f, 11.0f, 3.0f, 14.0f, 5.0f, 6.0f}); + auto expV = + NDArrayFactory::create('c', {2, 2}, {11.0f, 3.0f, 14.0f, 6.0f}); + auto expI = NDArrayFactory::create('c', {2, 2}, {1, 2, 0, 2}); - sd::ops::top_k op; - auto result = op.evaluate({&x}, {}, {2, 1}); + sd::ops::top_k op; + auto result = op.evaluate({&x}, {}, {2, 1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(2, result.size()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(2, result.size()); - auto v = result.at(0); - auto i = result.at(1); - - ASSERT_TRUE(expV.isSameShape(v)); - ASSERT_TRUE(expV.equalsTo(v)); - - ASSERT_TRUE(expI.isSameShape(i)); - ASSERT_TRUE(expI.equalsTo(i)); + auto v = result.at(0); + auto i = result.at(1); + ASSERT_TRUE(expV.isSameShape(v)); + ASSERT_TRUE(expV.equalsTo(v)); + ASSERT_TRUE(expI.isSameShape(i)); + ASSERT_TRUE(expI.equalsTo(i)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, Test_TopK_5) { - auto x = NDArrayFactory::create('f', {2, 3}, {1.1, 5.2, 3.1, 14.2, 11.1, 6.2}); - auto expV = NDArrayFactory::create('f', {2, 2}, {11.1, 14.2, 3.1, 6.2}); - auto expI = NDArrayFactory::create('f', {2, 2}, {2, 1, 1, 2}); - - sd::ops::top_k op; - auto result = op.evaluate({&x}, {}, {2, 1}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(2, result.size()); + auto x = NDArrayFactory::create('f', {2, 3}, + {1.1, 5.2, 3.1, 14.2, 11.1, 6.2}); + auto expV = + NDArrayFactory::create('f', {2, 2}, {11.1, 14.2, 3.1, 6.2}); + auto expI = NDArrayFactory::create('f', {2, 2}, {2, 1, 1, 2}); - auto v = result.at(0); - auto i = result.at(1); + sd::ops::top_k op; + auto result = op.evaluate({&x}, {}, {2, 1}); - ASSERT_TRUE(expV.isSameShape(v)); - ASSERT_TRUE(expV.equalsTo(v)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(2, result.size()); - ASSERT_TRUE(expI.isSameShape(i)); - ASSERT_TRUE(expI.equalsTo(i)); + auto v = result.at(0); + auto i = result.at(1); + ASSERT_TRUE(expV.isSameShape(v)); + ASSERT_TRUE(expV.equalsTo(v)); + ASSERT_TRUE(expI.isSameShape(i)); + ASSERT_TRUE(expI.equalsTo(i)); } /////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, Test_Moments_1) { - auto x = NDArrayFactory::create('c', {2, 3, 4}, {11.0, 3.0, 14.0, 5.0, - 6.0, 9.0, 3.5, 7.0, - 21.0, 3.0, 14.0, 15.0, - 6.0, 9.0, 3.5, 7.0, - 11.0, 13.0, 14.0, 5.0, - 16.0, 9.0, 13.5, 7.0} - ); + auto x = NDArrayFactory::create( + 'c', {2, 3, 4}, + {11.0, 3.0, 14.0, 5.0, 6.0, 9.0, 3.5, 7.0, 21.0, 3.0, 14.0, 15.0, + 6.0, 9.0, 3.5, 7.0, 11.0, 13.0, 14.0, 5.0, 16.0, 9.0, 13.5, 7.0}); - auto y = NDArrayFactory::create('c', {3}, {0, 1, 2}); - //auto expV('f', {6}, {1, 0, 0, 0, 0, 0 }); + auto y = NDArrayFactory::create('c', {3}, {0, 1, 2}); + // auto expV('f', {6}, {1, 0, 0, 0, 0, 0 }); - float expMean = 9.395833f; - float expDeviation = 22.4579f; -//Mean 9.395833 -//Deviance 22.4579 + float expMean = 9.395833f; + float expDeviation = 22.4579f; + // Mean 9.395833 + // Deviance 22.4579 - float inf = 1.e-5f; + float inf = 1.e-5f; - sd::ops::moments op; - auto result = op.evaluate({&x, &y}, {}, {}); + sd::ops::moments op; + auto result = op.evaluate({&x, &y}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(2, result.size()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(2, result.size()); - auto v = result.at(0); - auto d = result.at(1); - -// v->printIndexedBuffer("Result is "); -// d->printIndexedBuffer("Result is "); - - ASSERT_TRUE(v->isScalar()); - ASSERT_NEAR(expMean, v->e(0), inf); - ASSERT_NEAR(expDeviation, d->e(0), inf); + auto v = result.at(0); + auto d = result.at(1); + // v->printIndexedBuffer("Result is "); + // d->printIndexedBuffer("Result is "); + ASSERT_TRUE(v.isScalar()); + ASSERT_NEAR(expMean, v.e(0), inf); + ASSERT_NEAR(expDeviation, d.e(0), inf); } TEST_F(DeclarableOpsTests5, Test_Moments_2) { - NDArray x('c', {2, 3, 4}, {11.0, 3.0, 14.0, 5.0, - 6.0, 9.0, 3.5, 7.0, - 21.0, 3.0, 14.0, 15.0, - 6.0, 9.0, 3.5, 7.0, - 11.0, 13.0, 14.0, 5.0, - 16.0, 9.0, 13.5, 7.0} - ); - - NDArray expV('c', {4}, {11.833333, 7.6666665, 10.416667, 7.6666665}); - NDArray expD('c', {4}, {28.472221, 12.888889, 23.951387, 11.555554}); - - sd::ops::moments op; - auto result = op.evaluate({&x}, {}, {0, 1}); + NDArray x('c', {2, 3, 4}, {11.0, 3.0, 14.0, 5.0, 6.0, 9.0, 3.5, 7.0, + 21.0, 3.0, 14.0, 15.0, 6.0, 9.0, 3.5, 7.0, + 11.0, 13.0, 14.0, 5.0, 16.0, 9.0, 13.5, 7.0}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(2, result.size()); + NDArray expV('c', {4}, {11.833333, 7.6666665, 10.416667, 7.6666665}); + NDArray expD('c', {4}, {28.472221, 12.888889, 23.951387, 11.555554}); - auto v = result.at(0); - auto d = result.at(1); + sd::ops::moments op; + auto result = op.evaluate({&x}, {}, {0, 1}); - ASSERT_TRUE(v->isVector()); - ASSERT_TRUE(d->isVector()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(2, result.size()); - ASSERT_TRUE(v->equalsTo(&expV)); - ASSERT_TRUE(d->equalsTo(&expD)); + auto v = result.at(0); + auto d = result.at(1); + ASSERT_TRUE(v.isVector()); + ASSERT_TRUE(d.isVector()); + ASSERT_TRUE(v.equalsTo(&expV)); + ASSERT_TRUE(d.equalsTo(&expD)); } TEST_F(DeclarableOpsTests5, Test_Moments_3) { - auto x = NDArrayFactory::create('c', {2, 3, 4}, {11.0, 3.0, 14.0, 5.0, - 6.0, 9.0, 3.5, 7.0, - 21.0, 3.0, 14.0, 15.0, - 6.0, 9.0, 3.5, 7.0, - 11.0, 13.0, 14.0, 5.0, - 16.0, 9.0, 13.5, 7.0} - ); + auto x = NDArrayFactory::create( + 'c', {2, 3, 4}, + {11.0, 3.0, 14.0, 5.0, 6.0, 9.0, 3.5, 7.0, 21.0, 3.0, 14.0, 15.0, + 6.0, 9.0, 3.5, 7.0, 11.0, 13.0, 14.0, 5.0, 16.0, 9.0, 13.5, 7.0}); - auto expV = NDArrayFactory::create('c', {3, 4}, { 8.5f, 6.f , 8.75f, 6.f, - 8.5f, 11.f, 8.75f, 6.f, - 18.5f, 6.f, 13.75f, 11.f}); - auto expD = NDArrayFactory::create('c', {3, 4}, { 6.25f, 9.f, 27.5625f, 1.f, - 6.25f, 4.f, 27.5625f, 1.f, - 6.25f, 9.f, 0.0625f, 16.f}); + auto expV = + NDArrayFactory::create('c', {3, 4}, + {8.5f, 6.f, 8.75f, 6.f, 8.5f, 11.f, 8.75f, + 6.f, 18.5f, 6.f, 13.75f, 11.f}); + auto expD = NDArrayFactory::create( + 'c', {3, 4}, + {6.25f, 9.f, 27.5625f, 1.f, 6.25f, 4.f, 27.5625f, 1.f, 6.25f, 9.f, + 0.0625f, 16.f}); - sd::ops::moments op; - auto result = op.evaluate({&x}, {}, {0}); + sd::ops::moments op; + auto result = op.evaluate({&x}, {}, {0}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(2, result.size()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(2, result.size()); - auto v = result.at(0); - auto d = result.at(1); - - ASSERT_TRUE(v->isMatrix()); - ASSERT_TRUE(d->isMatrix()); - - ASSERT_TRUE(v->equalsTo(&expV)); - ASSERT_TRUE(d->equalsTo(&expD)); + auto v = result.at(0); + auto d = result.at(1); + ASSERT_TRUE(v.isMatrix()); + ASSERT_TRUE(d.isMatrix()); + ASSERT_TRUE(v.equalsTo(&expV)); + ASSERT_TRUE(d.equalsTo(&expD)); } TEST_F(DeclarableOpsTests5, Test_Moments_4) { + auto x = NDArrayFactory::create( + 'f', {2, 3, 4}, {11.0f, 6.0f, 6.0f, 11.0f, 21.0f, 16.0f, 3.0f, 9.0f, + 9.0f, 13.0f, 3.0f, 9.0f, 14.0f, 3.5f, 3.5f, 14.0f, + 14.0f, 13.5f, 5.0f, 7.0f, 7.0f, 5.0f, 15.0f, 7.0f}); - auto x = NDArrayFactory::create('f', {2, 3, 4}, {11.0f, 6.0f, 6.0f, 11.0f, 21.0f, 16.0f, 3.0f, 9.0f, 9.0f, 13.0f, 3.0f, 9.0f, - 14.0f, 3.5f, 3.5f, 14.0f, 14.0f, 13.5f, 5.0f, 7.0f, 7.0f, 5.0f, 15.0f, 7.0f}); - - - auto expV = NDArrayFactory::create('c', {3, 4}, { 8.5f, 6.f , 8.75f, 6.f, 8.5f, 11.f, 8.75f, 6.f, 18.5f, 6.f, 13.75f, 11.f}); - auto expD = NDArrayFactory::create('c', {3, 4}, { 6.25f, 9.f, 27.5625f, 1.f, 6.25f, 4.f, 27.5625f, 1.f, 6.25f, 9.f, 0.0625f, 16.f}); - - sd::ops::moments op; - auto result = op.evaluate({&x}, {}, {0}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(2, result.size()); + auto expV = + NDArrayFactory::create('c', {3, 4}, + {8.5f, 6.f, 8.75f, 6.f, 8.5f, 11.f, 8.75f, + 6.f, 18.5f, 6.f, 13.75f, 11.f}); + auto expD = NDArrayFactory::create( + 'c', {3, 4}, + {6.25f, 9.f, 27.5625f, 1.f, 6.25f, 4.f, 27.5625f, 1.f, 6.25f, 9.f, + 0.0625f, 16.f}); - auto v = result.at(0); - auto d = result.at(1); + sd::ops::moments op; + auto result = op.evaluate({&x}, {}, {0}); - ASSERT_TRUE(v->isMatrix()); - ASSERT_TRUE(d->isMatrix()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(2, result.size()); - // v->printIndexedBuffer("v"); - // expV.printIndexedBuffer("expV"); + auto v = result.at(0); + auto d = result.at(1); - // d->printIndexedBuffer("d"); - // expD.printIndexedBuffer("expD"); + ASSERT_TRUE(v.isMatrix()); + ASSERT_TRUE(d.isMatrix()); - ASSERT_TRUE(v->equalsTo(&expV)); - ASSERT_TRUE(d->equalsTo(&expD)); + // v->printIndexedBuffer("v"); + // expV.printIndexedBuffer("expV"); + // d->printIndexedBuffer("d"); + // expD.printIndexedBuffer("expD"); + ASSERT_TRUE(v.equalsTo(&expV)); + ASSERT_TRUE(d.equalsTo(&expD)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, trace_test1) { - - auto input = NDArrayFactory::create('c', {3, 4, 5}); - input.linspace(1); - auto exp = NDArrayFactory::create('c', {3}, {40, 120, 200}); - NDArray matrix('c', {3, 3}, {1., 2., 3., 4., 5., 6., 7., 8., 9.}); - sd::ops::trace op; - auto results = op.evaluate({&input}, {}, {}); - auto output = results.at(0); - double traceM = matrix.getTrace(); - // nd4j_printf("Trace for matrix is %f\n", traceM); - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(output)); - // exp.printIndexedBuffer("EXP TRACE"); - // output->printIndexedBuffer("OUT TRACE"); - ASSERT_TRUE(exp.equalsTo(output)); + auto input = NDArrayFactory::create('c', {3, 4, 5}); + input.linspace(1); + auto exp = NDArrayFactory::create('c', {3}, {40, 120, 200}); + NDArray matrix('c', {3, 3}, {1., 2., 3., 4., 5., 6., 7., 8., 9.}); + sd::ops::trace op; + auto results = op.evaluate({&input}, {}, {}); + auto output = results.at(0); + double traceM = matrix.getTrace(); + // nd4j_printf("Trace for matrix is %f\n", traceM); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(output)); + // exp.printIndexedBuffer("EXP TRACE"); + // output->printIndexedBuffer("OUT TRACE"); + ASSERT_TRUE(exp.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, trace_test2) { + auto input = NDArrayFactory::create('c', {4, 5}); + input.linspace(1); + auto exp = NDArrayFactory::create(40.); - auto input = NDArrayFactory::create('c', {4, 5}); - input.linspace(1); - auto exp = NDArrayFactory::create(40.); + sd::ops::trace op; + auto results = op.evaluate({&input}, {}, {}); + auto output = results.at(0); - sd::ops::trace op; - auto results = op.evaluate({&input}, {}, {}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, trace_test3) { + auto input = NDArrayFactory::create('c', {1, 5}); + input.linspace(1); + auto exp = NDArrayFactory::create(1.); - auto input = NDArrayFactory::create('c', {1, 5}); - input.linspace(1); - auto exp = NDArrayFactory::create(1.); - - sd::ops::trace op; - auto results = op.evaluate({&input}, {}, {}); - auto output = results.at(0); + sd::ops::trace op; + auto results = op.evaluate({&input}, {}, {}); + auto output = results.at(0); - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, trace_test4) { + auto input = NDArrayFactory::create('c', {5, 1}); + input.linspace(1); + auto exp = NDArrayFactory::create(1.); - auto input = NDArrayFactory::create('c', {5, 1}); - input.linspace(1); - auto exp = NDArrayFactory::create(1.); - - sd::ops::trace op; - auto results = op.evaluate({&input}, {}, {}); - auto output = results.at(0); + sd::ops::trace op; + auto results = op.evaluate({&input}, {}, {}); + auto output = results.at(0); - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, trace_test5) { + auto input = NDArrayFactory::create('c', {3, 4, 5, 6}); + input.linspace(1); + auto exp = NDArrayFactory::create( + 'c', {3, 4}, + {75, 225, 375, 525, 675, 825, 975, 1125, 1275, 1425, 1575, 1725}); - auto input = NDArrayFactory::create('c', {3, 4, 5, 6}); - input.linspace(1); - auto exp = NDArrayFactory::create('c', {3, 4}, {75, 225, 375, 525, 675, 825, 975, 1125, 1275, 1425, 1575, 1725}); - - sd::ops::trace op; - auto results = op.evaluate({&input}); - auto output = results.at(0); + sd::ops::trace op; + auto results = op.evaluate({&input}); + auto output = results.at(0); - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, random_shuffle_test1) { - - auto input = NDArrayFactory::create('c', {2, 2, 2}); - input.linspace(1); + auto input = NDArrayFactory::create('c', {2, 2, 2}); + input.linspace(1); NDArray exp1 = input.dup(); NDArray exp2('c',{2,2,2}, {5,6,7,8, 1,2,3,4}, sd::DataType::DOUBLE); - sd::ops::random_shuffle op; - auto results = op.evaluate({&input}); - auto output = results.at(0); + sd::ops::random_shuffle op; + auto results = op.evaluate({&input}); + auto output = results.at(0); ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(output->equalsTo(exp1) || output->equalsTo(exp2)); + ASSERT_TRUE(output.equalsTo(exp1) || output.equalsTo(exp2)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, random_shuffle_test2) { - - auto input = NDArrayFactory::create('c', {1, 3, 2}); - input.linspace(1); + auto input = NDArrayFactory::create('c', {1, 3, 2}); + input.linspace(1); NDArray exp1 = input.dup(); - sd::ops::random_shuffle op; - auto results = op.evaluate({&input}); - auto output = results.at(0); + sd::ops::random_shuffle op; + auto results = op.evaluate({&input}); + auto output = results.at(0); - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(output->equalsTo(exp1)); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(output.equalsTo(exp1)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, random_shuffle_test3) { - - auto input = NDArrayFactory::create('c', {3, 2, 1}); - input.linspace(1); + auto input = NDArrayFactory::create('c', {3, 2, 1}); + input.linspace(1); NDArray exp1 = input.dup(); NDArray exp2('c',{3,2,1}, {1,2, 5,6, 3,4}, sd::DataType::DOUBLE); NDArray exp3('c',{3,2,1}, {3,4, 1,2, 5,6}, sd::DataType::DOUBLE); @@ -1676,10 +1750,8 @@ TEST_F(DeclarableOpsTests5, random_shuffle_test3) { ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, random_shuffle_test4) { - - auto input = NDArrayFactory::create('c', {3, 2, 1}); - input.linspace(1); - NDArray exp1 = input.dup(); + auto input = NDArrayFactory::create('c', {3, 2, 1}); + input.linspace(1);NDArray exp1 = input.dup(); NDArray exp2('c',{3,2,1}, {1,2, 5,6, 3,4}, sd::DataType::DOUBLE); NDArray exp3('c',{3,2,1}, {3,4, 1,2, 5,6}, sd::DataType::DOUBLE); NDArray exp4('c',{3,2,1}, {3,4, 5,6, 1,2}, sd::DataType::DOUBLE); @@ -1691,52 +1763,51 @@ TEST_F(DeclarableOpsTests5, random_shuffle_test4) { auto output = results.at(0); ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(output->equalsTo(exp1) || output->equalsTo(exp2) || output->equalsTo(exp3) - || output->equalsTo(exp4) || output->equalsTo(exp5) || output->equalsTo(exp6)); + ASSERT_TRUE(output.equalsTo(exp1) || output.equalsTo(exp2) || output.equalsTo(exp3) + || output.equalsTo(exp4) || output.equalsTo(exp5) || output.equalsTo(exp6)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, random_shuffle_test5) { - auto input = NDArrayFactory::create('c', {4}); - input.linspace(1); + auto input = NDArrayFactory::create('c', {4}); + input.linspace(1); sd::ops::random_shuffle op; auto results = op.evaluate({&input}, {}, {}, {}, {}, false); auto output = results.at(0); // output->printBuffer(); - ASSERT_EQ(Status::OK(), results.status()); - // ASSERT_TRUE(!output->equalsTo(input)); + ASSERT_EQ(Status::OK(), results.status()); + // ASSERT_TRUE(!output->equalsTo(input)); bool hasDublicates = false; - for(int i = 0; i < output->lengthOf() - 1; ++i) - for(int j = i+1; j < output->lengthOf(); ++j) - if(output->t(i) == output->t(j)) { + for(int i = 0; i < output.lengthOf() - 1; ++i) + for(int j = i+1; j < output.lengthOf(); ++j) + if(output.t(i) == output.t(j)) { hasDublicates = true; - i = output->lengthOf(); - break; - } + i = output.lengthOf(); + break;} ASSERT_TRUE(!hasDublicates); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, random_shuffle_test6) { - auto input = NDArrayFactory::create('c', {4,1,1}); - input.linspace(1); + auto input = NDArrayFactory::create('c', {4, 1, 1}); + input.linspace(1); - sd::ops::random_shuffle op; - auto results = op.evaluate({&input}, {}, {}, {}, {}, false); - auto output = results.at(0); + sd::ops::random_shuffle op; + auto results = op.evaluate({&input}, {}, {}, {}, {}, false); + auto output = results.at(0); ASSERT_EQ(Status::OK(), results.status()); // ASSERT_TRUE(!output->equalsTo(input)); bool hasDublicates = false; - for(int i = 0; i < output->lengthOf() - 1; ++i) - for(int j = i+1; j < output->lengthOf(); ++j) - if(output->t(i) == output->t(j)) { + for(int i = 0; i < output.lengthOf() - 1; ++i) + for(int j = i+1; j < output.lengthOf(); ++j) + if(output.t(i) == output.t(j)) { hasDublicates = true; - i = output->lengthOf(); + i = output.lengthOf(); break; } ASSERT_TRUE(!hasDublicates); @@ -1744,18 +1815,19 @@ TEST_F(DeclarableOpsTests5, random_shuffle_test6) { ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, random_shuffle_test7) { - auto input = NDArrayFactory::create('c', {16010}); - input.linspace(1); + auto input = NDArrayFactory::create('c', {16010}); + input.linspace(1); - sd::ops::random_shuffle op; - auto results = op.evaluate({&input}, {}, {}, {}, {}, false); - auto output = results.at(0); + + sd::ops::random_shuffle op; + auto results = op.evaluate({&input}, {}, {}, {}, {}, false); + auto output = results.at(0); // output->printBuffer(); ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(!output->equalsTo(input)); + ASSERT_TRUE(!output.equalsTo(input)); auto vec1 = input.getBufferAsVector(); - auto vec2 = output->getBufferAsVector(); + auto vec2 = output.getBufferAsVector(); std::sort(vec2.begin(), vec2.end()); ASSERT_TRUE(std::equal(vec1.begin(), vec1.end(), vec2.begin())); } @@ -1767,11 +1839,9 @@ TEST_F(DeclarableOpsTests5, random_shuffle_test8) { NDArray inCopy = input.dup(); sd::ops::random_shuffle op; - auto results = op.evaluate({&input}, {}, {}, {}, {}, false); - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(input.equalsTo(inCopy)); - -} + auto results = op.evaluate({&input}, {}, {}, {}, {}, false); ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(input.equalsTo(inCopy)); + } TEST_F(DeclarableOpsTests5, random_shuffle_test9) { @@ -1782,130 +1852,123 @@ TEST_F(DeclarableOpsTests5, random_shuffle_test9) { auto status = op.execute({&x}, {&z}); ASSERT_EQ(Status::OK(), status); - auto vec = z.getBufferAsVector(); - std::sort(vec.begin(), vec.end()); + auto vec = z.getBufferAsVector();std::sort(vec.begin(), vec.end()); ASSERT_EQ(std::vector({1, 2, 3, 4}), vec); } //////////////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, EmbeddingLookup_1) { - - auto x = NDArrayFactory::create('c', {3, 4, 2}, {10, 20, 11, 21, 12, 22, 13, 23, - 14, 24, 15, 25, 16, 26, 17, 27, - 18, 28, 19, 29, 20, 30, 21, 31}); - - auto y = NDArrayFactory::create({1, 1, 1, 0, 0, 0, 2, 2, 2}); - auto exp = NDArrayFactory::create('c', {9, 4, 2}, {14, 24, 15, 25, 16, 26, 17, 27, 14, 24, 15, 25, - 16, 26, 17, 27, 14, 24, 15, 25, 16, 26, 17, 27, - 10, 20, 11, 21, 12, 22, 13, 23, 10, 20, 11, 21, - 12, 22, 13, 23, 10, 20, 11, 21, 12, 22, 13, 23, - 18, 28, 19, 29, 20, 30, 21, 31, 18, 28, 19, 29, - 20, 30, 21, 31, 18, 28, 19, 29, 20, 30, 21, 31}); - - // y.printShapeInfo("y shape"); - // y.printIndexedBuffer("y buffer"); - - sd::ops::embedding_lookup op; - auto result = op.evaluate({&x, &y}, {}, {0}); - auto output = result.at(0); - // x.printShapeInfo("Input"); - output->printShapeInfo("Output"); - exp.printShapeInfo("Expected"); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(output)); - //output->printIndexedBuffer("Output"); - //exp.printIndexedBuffer("Expect"); - - ASSERT_TRUE(exp.equalsTo(output)); - - + auto x = NDArrayFactory::create( + 'c', {3, 4, 2}, {10, 20, 11, 21, 12, 22, 13, 23, 14, 24, 15, 25, + 16, 26, 17, 27, 18, 28, 19, 29, 20, 30, 21, 31}); + + auto y = NDArrayFactory::create({1, 1, 1, 0, 0, 0, 2, 2, 2}); + auto exp = NDArrayFactory::create( + 'c', {9, 4, 2}, + {14, 24, 15, 25, 16, 26, 17, 27, 14, 24, 15, 25, 16, 26, 17, 27, 14, 24, + 15, 25, 16, 26, 17, 27, 10, 20, 11, 21, 12, 22, 13, 23, 10, 20, 11, 21, + 12, 22, 13, 23, 10, 20, 11, 21, 12, 22, 13, 23, 18, 28, 19, 29, 20, 30, + 21, 31, 18, 28, 19, 29, 20, 30, 21, 31, 18, 28, 19, 29, 20, 30, 21, 31}); + + // y.printShapeInfo("y shape"); + // y.printIndexedBuffer("y buffer"); + + sd::ops::embedding_lookup op; + auto result = op.evaluate({&x, &y}, {}, {0}); + auto output = result.at(0); + // x.printShapeInfo("Input"); + // output->printShapeInfo("Output"); + // exp.printShapeInfo("Expected"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_TRUE(exp.isSameShape(output)); + // output->printIndexedBuffer("Output"); + // exp.printIndexedBuffer("Expect"); + + ASSERT_TRUE(exp.equalsTo(output)); } TEST_F(DeclarableOpsTests5, EmbeddingLookup_2) { - - auto x = NDArrayFactory::create('c', {3, 4, 2}, {10, 20, 30, 40, 50, 60, - 70, 80, 90, 10, 11, 12, - 13, 14, 15, 16, 17, 18, - 19, 20, 21, 22, 23, 24}); - //1, 0, 1, 0, 1, 0 - auto y = NDArrayFactory::create({1, 0, 1, 0, 1, 0}); - auto exp = NDArrayFactory::create('c', {6, 4, 2}, {90, 10, 11, 12, 13, 14, - 15, 16, 10, 20, 30, 40, - 50, 60, 70, 80, 90, 10, - 11, 12, 13, 14, 15, 16, - 10, 20, 30, 40, 50, 60, - 70, 80, 90, 10, 11, 12, - 13, 14, 15, 16, 10, 20, - 30, 40, 50, 60, 70, 80}); - - // y.printShapeInfo("y shape"); - // y.printIndexedBuffer("y buffer"); - - sd::ops::embedding_lookup op; - auto result = op.evaluate({&x, &y}, {}, {0}); - auto output = result.at(0); - // x.printShapeInfo("Input"); - // output->printShapeInfo("Output"); - // exp.printShapeInfo("Expected"); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(output)); - // output->printIndexedBuffer("Output"); - // exp.printIndexedBuffer("Expect"); - - ASSERT_TRUE(exp.equalsTo(output)); - - + auto x = NDArrayFactory::create( + 'c', {3, 4, 2}, {10, 20, 30, 40, 50, 60, 70, 80, 90, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}); + // 1, 0, 1, 0, 1, 0 + auto y = NDArrayFactory::create({1, 0, 1, 0, 1, 0}); + auto exp = NDArrayFactory::create( + 'c', {6, 4, 2}, + {90, 10, 11, 12, 13, 14, 15, 16, 10, 20, 30, 40, 50, 60, 70, 80, + 90, 10, 11, 12, 13, 14, 15, 16, 10, 20, 30, 40, 50, 60, 70, 80, + 90, 10, 11, 12, 13, 14, 15, 16, 10, 20, 30, 40, 50, 60, 70, 80}); + + // y.printShapeInfo("y shape"); + // y.printIndexedBuffer("y buffer"); + + sd::ops::embedding_lookup op; + auto result = op.evaluate({&x, &y}, {}, {0}); + auto output = result.at(0); + // x.printShapeInfo("Input"); + // output->printShapeInfo("Output"); + // exp.printShapeInfo("Expected"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_TRUE(exp.isSameShape(output)); + // output->printIndexedBuffer("Output"); + // exp.printIndexedBuffer("Expect"); + + ASSERT_TRUE(exp.equalsTo(output)); } TEST_F(DeclarableOpsTests5, EmbeddingLookup_3) { - - - auto y = NDArrayFactory::create('c', {3,2}, {5, 4, 4, 5, 3, 3}); - auto exp = NDArrayFactory::create('c', {6, 3, 3}, { - 6, 20, 11, 21, 12, 22, 13, 23, 14, - 5, 20, 11, 21, 12, 22, 13, 23, 14, - 5, 20, 11, 21, 12, 22, 13, 23, 14, - 6, 20, 11, 21, 12, 22, 13, 23, 14, - 4, 20, 11, 21, 12, 22, 13, 23, 14, - 4, 20, 11, 21, 12, 22, 13, 23, 14 }); - - // y.printShapeInfo("y shape"); - // y.printIndexedBuffer("y buffer"); - auto p1 = NDArrayFactory::create('c', {3,3}, {1, 20, 11, 21, 12, 22, 13, 23, 14}); - auto p2 = NDArrayFactory::create('c', {3,3}, {2, 20, 11, 21, 12, 22, 13, 23, 14}); - auto p3 = NDArrayFactory::create('c', {3,3}, {3, 20, 11, 21, 12, 22, 13, 23, 14}); - auto p4 = NDArrayFactory::create('c', {3,3}, {4, 20, 11, 21, 12, 22, 13, 23, 14}); - auto p5 = NDArrayFactory::create('c', {3,3}, {5, 20, 11, 21, 12, 22, 13, 23, 14}); - auto p6 = NDArrayFactory::create('c', {3,3}, {6, 20, 11, 21, 12, 22, 13, 23, 14}); - auto p7 = NDArrayFactory::create('c', {3,3}, {7, 20, 11, 21, 12, 22, 13, 23, 14}); - auto p8 = NDArrayFactory::create('c', {3,3}, {8, 20, 11, 21, 12, 22, 13, 23, 14}); - -// res = tf.nn.embedding_lookup((p1, p2, p3, p4, p5, p6, p7), ids, 'mod') - - sd::ops::embedding_lookup op; - auto result = op.evaluate({&p1, &p2, &p3, &p4, &p5, &p6, &p7, &p8, &y}, {}, {1}); - auto output = result.at(0); - // x.printShapeInfo("Input"); - // output->printIndexedBuffer("Output"); - // exp.printShapeInfo("Expected"); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(output)); - // output->printIndexedBuffer("Output"); - // exp.printIndexedBuffer("Expect"); - - ASSERT_TRUE(exp.equalsTo(output)); - - + auto y = NDArrayFactory::create('c', {3, 2}, {5, 4, 4, 5, 3, 3}); + auto exp = NDArrayFactory::create( + 'c', {6, 3, 3}, + {6, 20, 11, 21, 12, 22, 13, 23, 14, 5, 20, 11, 21, 12, 22, 13, 23, 14, + 5, 20, 11, 21, 12, 22, 13, 23, 14, 6, 20, 11, 21, 12, 22, 13, 23, 14, + 4, 20, 11, 21, 12, 22, 13, 23, 14, 4, 20, 11, 21, 12, 22, 13, 23, 14}); + + // y.printShapeInfo("y shape"); + // y.printIndexedBuffer("y buffer"); + auto p1 = NDArrayFactory::create('c', {3, 3}, + {1, 20, 11, 21, 12, 22, 13, 23, 14}); + auto p2 = NDArrayFactory::create('c', {3, 3}, + {2, 20, 11, 21, 12, 22, 13, 23, 14}); + auto p3 = NDArrayFactory::create('c', {3, 3}, + {3, 20, 11, 21, 12, 22, 13, 23, 14}); + auto p4 = NDArrayFactory::create('c', {3, 3}, + {4, 20, 11, 21, 12, 22, 13, 23, 14}); + auto p5 = NDArrayFactory::create('c', {3, 3}, + {5, 20, 11, 21, 12, 22, 13, 23, 14}); + auto p6 = NDArrayFactory::create('c', {3, 3}, + {6, 20, 11, 21, 12, 22, 13, 23, 14}); + auto p7 = NDArrayFactory::create('c', {3, 3}, + {7, 20, 11, 21, 12, 22, 13, 23, 14}); + auto p8 = NDArrayFactory::create('c', {3, 3}, + {8, 20, 11, 21, 12, 22, 13, 23, 14}); + + // res = tf.nn.embedding_lookup((p1, p2, p3, p4, p5, p6, p7), ids, 'mod') + + sd::ops::embedding_lookup op; + auto result = + op.evaluate({&p1, &p2, &p3, &p4, &p5, &p6, &p7, &p8, &y}, {}, {1}); + auto output = result.at(0); + // x.printShapeInfo("Input"); + // output->printIndexedBuffer("Output"); + // exp.printShapeInfo("Expected"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_TRUE(exp.isSameShape(output)); + // output->printIndexedBuffer("Output"); + // exp.printIndexedBuffer("Expect"); + + ASSERT_TRUE(exp.equalsTo(output)); } /* @Test public void testDynamicPartition(){ INDArray data = Nd4j.createFromArray(2, 1, 2, 0); INDArray partitions = Nd4j.createFromArray(0, 2, 1, 0); INDArray[] out = Nd4j.exec(DynamicCustomOp.builder("dynamic_partition") - .addOutputs(Nd4j.createUninitialized(DataType.INT, 2), Nd4j.createUninitialized(DataType.INT, 1), Nd4j.createUninitialized(DataType.INT, 1)) - .addIntegerArguments(3) //3 partitions - .addInputs(data, partitions).build()); + .addOutputs(Nd4j.createUninitialized(DataType.INT, 2), + Nd4j.createUninitialized(DataType.INT, 1), + Nd4j.createUninitialized(DataType.INT, 1)) .addIntegerArguments(3) //3 + partitions .addInputs(data, partitions).build()); INDArray exp0 = Nd4j.createFromArray(2, 0); INDArray exp1 = Nd4j.createFromArray(2); @@ -1916,1169 +1979,1172 @@ TEST_F(DeclarableOpsTests5, EmbeddingLookup_3) { assertEquals(exp2, out[2]); }*/ TEST_F(DeclarableOpsTests5, DynamicPartition_01) { + auto x = NDArrayFactory::create({2, 1, 2, 0}); - auto x = NDArrayFactory::create({2,1,2,0}); - - auto y = NDArrayFactory::create({0,2,1,0}); + auto y = NDArrayFactory::create({0, 2, 1, 0}); - int numPartition = 3; - std::vector exp( { NDArrayFactory::create('c', {2}, {2, 0}), - NDArrayFactory::create('c', {1}, {2}), - NDArrayFactory::create('c', {1}, {1})}); + int numPartition = 3; + std::vector exp({NDArrayFactory::create('c', {2}, {2, 0}), + NDArrayFactory::create('c', {1}, {2}), + NDArrayFactory::create('c', {1}, {1})}); - sd::ops::dynamic_partition op; - auto result = op.evaluate({&x, &y}, {}, {numPartition}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(result.size(), numPartition); // result has the same size as given param 4 - - for (int e = 0; e < result.size(); e++) { - auto output = result.at(e); - // output->printShapeInfo("Output shape> "); - // output->printIndexedBuffer("Output data> "); - ASSERT_TRUE(exp[e].isSameShape(output)); - ASSERT_TRUE(exp[e].equalsTo(output)); - } + sd::ops::dynamic_partition op; + auto result = op.evaluate({&x, &y}, {}, {numPartition}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(result.size(), + numPartition); // result has the same size as given param 4 + for (int e = 0; e < result.size(); e++) { + auto output = result.at(e); + // output->printShapeInfo("Output shape> "); + // output->printIndexedBuffer("Output data> "); + ASSERT_TRUE(exp[e].isSameShape(output)); + ASSERT_TRUE(exp[e].equalsTo(output)); + } } TEST_F(DeclarableOpsTests5, DynamicPartition_1) { - - auto x = NDArrayFactory::create('c', {3, 4, 2}, {10, 20, 11, 21, 12, 22, - 13, 23, 14, 24, 15, 25, 16, 26, 17, 27, - 18, 28, 19, 29, 20, 30, 21, 31}); - - auto y = NDArrayFactory::create('c', {3, 4, 2}, {0, 0, 0, 0, 0, 0, - 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, - 1, 1, 1, 1, 1, 1, 1, 1 - } - ); -/* auto y = NDArrayFactory::create('c', {3, 4}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, - 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, - 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f - } - ); -*/ - int numPartition = 3; - std::vector exp( { NDArrayFactory::create('c', {6}, {10, 20, 11, 21, 12, 22}), - NDArrayFactory::create('c', {8}, {18, 28, 19, 29, 20, 30, 21, 31}), - NDArrayFactory::create('c', {10}, {13, 23, 14, 24, 15, 25, 16, 26, 17, 27})}); - - sd::ops::dynamic_partition op; - auto result = op.evaluate({&x, &y}, {}, {numPartition}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(result.size(), numPartition); // result has the same size as given param 4 - - for (int e = 0; e < result.size(); e++) { - auto output = result.at(e); - // output->printShapeInfo("Output shape> "); - // output->printIndexedBuffer("Output data> "); - ASSERT_TRUE(exp[e].isSameShape(output)); - ASSERT_TRUE(exp[e].equalsTo(output)); - } - - + auto x = NDArrayFactory::create( + 'c', {3, 4, 2}, {10, 20, 11, 21, 12, 22, 13, 23, 14, 24, 15, 25, + 16, 26, 17, 27, 18, 28, 19, 29, 20, 30, 21, 31}); + + auto y = NDArrayFactory::create( + 'c', {3, 4, 2}, + {0, 0, 0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1}); + /* auto y = NDArrayFactory::create('c', {3, 4}, {0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, + 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f + } + ); + */ + int numPartition = 3; + std::vector exp( + {NDArrayFactory::create('c', {6}, {10, 20, 11, 21, 12, 22}), + NDArrayFactory::create('c', {8}, + {18, 28, 19, 29, 20, 30, 21, 31}), + NDArrayFactory::create( + 'c', {10}, {13, 23, 14, 24, 15, 25, 16, 26, 17, 27})}); + + sd::ops::dynamic_partition op; + auto result = op.evaluate({&x, &y}, {}, {numPartition}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(result.size(), + numPartition); // result has the same size as given param 4 + + for (int e = 0; e < result.size(); e++) { + auto output = result.at(e); + // output->printShapeInfo("Output shape> "); + // output->printIndexedBuffer("Output data> "); + ASSERT_TRUE(exp[e].isSameShape(output)); + ASSERT_TRUE(exp[e].equalsTo(output)); + } } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, DynamicPartition_2) { + auto x = NDArrayFactory::create( + 'c', {2, 4}, {0.1f, -1.f, 5.2f, 4.3f, -1.f, 7.4f, 0.0f, -2.2f}); + auto y = NDArrayFactory::create('c', {2, 4}, {1, 2, 1, 2, 1, 2, 3, 0}); - auto x = NDArrayFactory::create('c', {2, 4}, {0.1f, -1.f, 5.2f, 4.3f, -1.f, 7.4f, 0.0f, -2.2f}); - auto y = NDArrayFactory::create('c', {2, 4}, {1, 2, 1, 2, 1, 2, 3, 0}); - - std::vector exp( {NDArrayFactory::create('c', {1}, {-2.2}), - NDArrayFactory::create('c', {3}, {0.1, 5.2, -1.}), - NDArrayFactory::create('c', {3}, {-1., 4.3, 7.4}), - NDArrayFactory::create('c', {1}, {0.0})}); - - sd::ops::dynamic_partition op; - int numPartition = 4; - auto result = op.evaluate({&x, &y}, {}, {numPartition}); + std::vector exp( + {NDArrayFactory::create('c', {1}, {-2.2}), + NDArrayFactory::create('c', {3}, {0.1, 5.2, -1.}), + NDArrayFactory::create('c', {3}, {-1., 4.3, 7.4}), + NDArrayFactory::create('c', {1}, {0.0})}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(result.size(), numPartition); // result has the same size as given param 4 + sd::ops::dynamic_partition op; + int numPartition = 4; + auto result = op.evaluate({&x, &y}, {}, {numPartition}); - for (int e = 0; e < result.size(); e++) { - auto output = result.at(e); - - ASSERT_TRUE(exp[e].isSameShape(output)); - ASSERT_TRUE(exp[e].equalsTo(output)); - } + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(result.size(), + numPartition); // result has the same size as given param 4 + for (int e = 0; e < result.size(); e++) { + auto output = result.at(e); + ASSERT_TRUE(exp[e].isSameShape(output)); + ASSERT_TRUE(exp[e].equalsTo(output)); + } } - TEST_F(DeclarableOpsTests5, DynamicPartition_3) { - - auto x = NDArrayFactory::create('c', {2, 4}, {0.1f, -1.f, 5.2f, 4.3f, -1.f, 7.4f, 0.0f, -2.2f}); - auto y = NDArrayFactory::create('c', {2, 4}, {0, 1, 0, 2, 0, 2, 3, 0}); - - std::vector exp( {NDArrayFactory::create({0.1f, 5.2f, -1.f, -2.2f}), - NDArrayFactory::create('c', {1}, {-1.f}), - NDArrayFactory::create({4.3f, 7.4f}), - NDArrayFactory::create('c', {1}, {0.0f})}); - - sd::ops::dynamic_partition op; - int numPartition = 4; - auto result = op.evaluate({&x, &y}, {}, {numPartition}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(result.size(), numPartition); // result has the same size as given param 4 - - for (int e = 0; e < result.size(); e++) { - auto output = result.at(e); - if (output) - { - // output->printShapeInfo("Output shape> "); - // exp[e].printShapeInfo("Expected shape> "); - // output->printIndexedBuffer("Output data> "); - - ASSERT_TRUE(exp[e].isSameShape(output)); - ASSERT_TRUE(exp[e].equalsTo(output)); - } - else - { - ASSERT_TRUE(exp[e].lengthOf() == 0); - } + auto x = NDArrayFactory::create( + 'c', {2, 4}, {0.1f, -1.f, 5.2f, 4.3f, -1.f, 7.4f, 0.0f, -2.2f}); + auto y = + NDArrayFactory::create('c', {2, 4}, {0, 1, 0, 2, 0, 2, 3, 0}); + + std::vector exp( + {NDArrayFactory::create({0.1f, 5.2f, -1.f, -2.2f}), + NDArrayFactory::create('c', {1}, {-1.f}), + NDArrayFactory::create({4.3f, 7.4f}), + NDArrayFactory::create('c', {1}, {0.0f})}); + + sd::ops::dynamic_partition op; + int numPartition = 4; + auto result = op.evaluate({&x, &y}, {}, {numPartition}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(result.size(), + numPartition); // result has the same size as given param 4 + + for (int e = 0; e < result.size(); e++) { + auto output = result.at(e); + if (output.shapeInfo()) { + // output->printShapeInfo("Output shape> "); + // exp[e].printShapeInfo("Expected shape> "); + // output->printIndexedBuffer("Output data> "); + + ASSERT_TRUE(exp[e].isSameShape(output)); + ASSERT_TRUE(exp[e].equalsTo(output)); + } else { + ASSERT_TRUE(exp[e].lengthOf() == 0); } - - + } } TEST_F(DeclarableOpsTests5, DynamicStitch_empty_1) { - auto i0 = NDArrayFactory::create('c', {2}, {2, 3}); - auto i1 = NDArrayFactory::empty(); - auto i2 = NDArrayFactory::create('c', {2}, {0, 1}); - - auto d0 = NDArrayFactory::create('c', {2, 5}, {0.085571885,0.7937801,0.65908563,0.55552566,0.15962744,0.7787856,0.80119777,0.72437465,0.23089433,0.72714126}); - auto d1 = NDArrayFactory::empty(); - auto d2 = NDArrayFactory::create('c', {2, 5}, {0.94414854,0.5956861,0.8668989,0.3502196,0.5100082,0.061725974,0.6621324,0.034165382,0.32576954,0.51917326}); - - sd::ops::dynamic_stitch op; - auto result = op.evaluate({&i0, &i1, &i2, &d0, &d1, &d2}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); + auto i0 = NDArrayFactory::create('c', {2}, {2, 3}); + auto i1 = NDArrayFactory::empty(); + auto i2 = NDArrayFactory::create('c', {2}, {0, 1}); + auto d0 = NDArrayFactory::create( + 'c', {2, 5}, + {0.085571885, 0.7937801, 0.65908563, 0.55552566, 0.15962744, 0.7787856, + 0.80119777, 0.72437465, 0.23089433, 0.72714126}); + auto d1 = NDArrayFactory::empty(); + auto d2 = NDArrayFactory::create( + 'c', {2, 5}, + {0.94414854, 0.5956861, 0.8668989, 0.3502196, 0.5100082, 0.061725974, + 0.6621324, 0.034165382, 0.32576954, 0.51917326}); + sd::ops::dynamic_stitch op; + auto result = op.evaluate({&i0, &i1, &i2, &d0, &d1, &d2}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); } TEST_F(DeclarableOpsTests5, DynamicStitch_empty_2) { - auto i0 = NDArrayFactory::create('c', {2}, {2, 3}); - auto i1 = NDArrayFactory::create('c', {0}); - auto i2 = NDArrayFactory::create('c', {2}, {0, 1}); - - auto d0 = NDArrayFactory::create('c', {2, 5}, {0.085571885,0.7937801,0.65908563,0.55552566,0.15962744,0.7787856,0.80119777,0.72437465,0.23089433,0.72714126}); - auto d1 = NDArrayFactory::create('c', {0, 5}); - auto d2 = NDArrayFactory::create('c', {2, 5}, {0.94414854,0.5956861,0.8668989,0.3502196,0.5100082,0.061725974,0.6621324,0.034165382,0.32576954,0.51917326}); - - sd::ops::dynamic_stitch op; - auto result = op.evaluate({&i0, &i1, &i2, &d0, &d1, &d2}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); + auto i0 = NDArrayFactory::create('c', {2}, {2, 3}); + auto i1 = NDArrayFactory::create('c', {0}); + auto i2 = NDArrayFactory::create('c', {2}, {0, 1}); + auto d0 = NDArrayFactory::create( + 'c', {2, 5}, + {0.085571885, 0.7937801, 0.65908563, 0.55552566, 0.15962744, 0.7787856, + 0.80119777, 0.72437465, 0.23089433, 0.72714126}); + auto d1 = NDArrayFactory::create('c', {0, 5}); + auto d2 = NDArrayFactory::create( + 'c', {2, 5}, + {0.94414854, 0.5956861, 0.8668989, 0.3502196, 0.5100082, 0.061725974, + 0.6621324, 0.034165382, 0.32576954, 0.51917326}); + sd::ops::dynamic_stitch op; + auto result = op.evaluate({&i0, &i1, &i2, &d0, &d1, &d2}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, DynamicStitch_1) { + auto x1 = NDArrayFactory::create({1, 3, 5, 0}); + auto x2 = NDArrayFactory::create({2, 4}); + auto y2 = NDArrayFactory::create({-1., -1.}); + auto y1 = NDArrayFactory::create({0.1f, 5.2f, 4.3f, 7.4f}); - auto x1 = NDArrayFactory::create({1, 3, 5, 0}); - auto x2 = NDArrayFactory::create({2, 4}); - auto y2 = NDArrayFactory::create({-1., -1.}); - auto y1 = NDArrayFactory::create({0.1f, 5.2f, 4.3f, 7.4f}); - + auto exp = + NDArrayFactory::create({7.4f, 0.1f, -1.f, 5.2f, -1.f, 4.3f}); - auto exp = NDArrayFactory::create({7.4f, 0.1f, -1.f, 5.2f, -1.f, 4.3f}); + sd::ops::dynamic_stitch op; + auto result = op.evaluate({&x1, &x2, &y1, &y2}, {}, {}); - sd::ops::dynamic_stitch op; - auto result = op.evaluate({&x1, &x2, &y1, &y2}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto output = result.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, DynamicStitch_2) { + auto x1 = NDArrayFactory::create({1, 3}); + auto x2 = NDArrayFactory::create({5, 0, 2, 4}); + auto y1 = NDArrayFactory::create({-1.f, -1.f}); + auto y2 = NDArrayFactory::create({0.1f, 5.2f, 4.3f, 7.4f}); - auto x1 = NDArrayFactory::create({1, 3}); - auto x2 = NDArrayFactory::create({5, 0, 2, 4}); - auto y1 = NDArrayFactory::create({-1.f, -1.f}); - auto y2 = NDArrayFactory::create({0.1f, 5.2f, 4.3f, 7.4f}); - - - auto exp = NDArrayFactory::create({5.2f, -1.f, 4.3f, -1.f, 7.4f, 0.1f}); + auto exp = + NDArrayFactory::create({5.2f, -1.f, 4.3f, -1.f, 7.4f, 0.1f}); - sd::ops::dynamic_stitch op; - auto result = op.evaluate({&x1, &x2, &y1, &y2}, {}, {}); + sd::ops::dynamic_stitch op; + auto result = op.evaluate({&x1, &x2, &y1, &y2}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto output = result.at(0); - - // output->printShapeInfo("Output shape> "); - // exp.printShapeInfo("Expected shape> "); - // output->printIndexedBuffer("Output data> "); - // exp.printIndexedBuffer("Expected res>"); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + // output->printShapeInfo("Output shape> "); + // exp.printShapeInfo("Expected shape> "); + // output->printIndexedBuffer("Output data> "); + // exp.printIndexedBuffer("Expected res>"); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, fusedBatchNorm_test1) { - - auto x = NDArrayFactory::create('c', {2, 2, 3, 4}); - x.linspace(1); - auto scale = NDArrayFactory::create('c', {4}); - - scale = 0.5; - auto offset = NDArrayFactory::create('c', {4}); - offset = 2.; - auto expY = NDArrayFactory::create('c', {2, 2, 3, 4}, {1.20337462, 1.20337462, 1.20337462, 1.20337462, 1.34821558, 1.34821558, 1.34821558, 1.34821558, 1.49305654, 1.49305654, 1.49305654, 1.49305654, 1.63789749, 1.63789749, 1.63789749, 1.63789749, 1.78273857, 1.78273857, 1.78273857, 1.78273857, 1.92757952, 1.92757952, 1.92757952, 1.92757952, 2.0724206 , 2.0724206 , 2.0724206 , 2.0724206 , 2.21726155, 2.21726155, 2.21726155, 2.21726155, 2.36210251, 2.36210251, 2.36210251, 2.36210251, 2.50694346, 2.50694346, 2.50694346, 2.50694346, 2.65178442, 2.65178442, 2.65178442, 2.65178442, 2.79662538, 2.79662538, 2.79662538, 2.79662538}); - auto expBatchMean = NDArrayFactory::create('c', {4}, {23., 24., 25., 26.}); - auto expBatchVar = NDArrayFactory::create('c', {4}, {208.00001526, 208.00001526, 208.00001526, 208.00001526}); - - - sd::ops::fused_batch_norm op; - auto results = op.evaluate({&x, &scale, &offset}, {}, {0,1}); - auto y = results.at(0); - auto batchMean = results.at(1); - auto batchVar = results.at(2); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expY.isSameShape(y)); - ASSERT_TRUE(expBatchMean.isSameShape(batchMean)); - ASSERT_TRUE(expBatchVar.isSameShape(batchVar)); - - + auto x = NDArrayFactory::create('c', {2, 2, 3, 4}); + x.linspace(1); + auto scale = NDArrayFactory::create('c', {4}); + + scale = 0.5; + auto offset = NDArrayFactory::create('c', {4}); + offset = 2.; + auto expY = NDArrayFactory::create( + 'c', {2, 2, 3, 4}, + {1.20337462, 1.20337462, 1.20337462, 1.20337462, 1.34821558, 1.34821558, + 1.34821558, 1.34821558, 1.49305654, 1.49305654, 1.49305654, 1.49305654, + 1.63789749, 1.63789749, 1.63789749, 1.63789749, 1.78273857, 1.78273857, + 1.78273857, 1.78273857, 1.92757952, 1.92757952, 1.92757952, 1.92757952, + 2.0724206, 2.0724206, 2.0724206, 2.0724206, 2.21726155, 2.21726155, + 2.21726155, 2.21726155, 2.36210251, 2.36210251, 2.36210251, 2.36210251, + 2.50694346, 2.50694346, 2.50694346, 2.50694346, 2.65178442, 2.65178442, + 2.65178442, 2.65178442, 2.79662538, 2.79662538, 2.79662538, 2.79662538}); + auto expBatchMean = + NDArrayFactory::create('c', {4}, {23., 24., 25., 26.}); + auto expBatchVar = NDArrayFactory::create( + 'c', {4}, {208.00001526, 208.00001526, 208.00001526, 208.00001526}); + + sd::ops::fused_batch_norm op; + auto results = op.evaluate({&x, &scale, &offset}, {}, {0, 1}); + auto y = results.at(0); + auto batchMean = results.at(1); + auto batchVar = results.at(2); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expY.isSameShape(y)); + ASSERT_TRUE(expBatchMean.isSameShape(batchMean)); + ASSERT_TRUE(expBatchVar.isSameShape(batchVar)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, fusedBatchNorm_test2) { - - auto x = NDArrayFactory::create('c', {2, 2, 3, 4}); - x.linspace(1); - - auto scale = NDArrayFactory::create('c', {4}); - - scale = 0.5; - auto offset = NDArrayFactory::create('c', {4}); - offset = 2.; - auto expY = NDArrayFactory::create('c', {2, 2, 3, 4}, {1.20347691, 1.20347691, 1.20347691, 1.20347691, 1.34829926, 1.34829926, 1.34829926, 1.34829926, 1.49312162, 1.49312162, 1.49312162, 1.49312162, 1.6379441 , 1.6379441 , 1.6379441 , 1.6379441 , 1.78276646, 1.78276646, 1.78276646, 1.78276646, 1.92758882, 1.92758882, 1.92758882, 1.92758882, 2.0724113 , 2.0724113 , 2.0724113 , 2.0724113 , 2.21723366, 2.21723366, 2.21723366, 2.21723366, 2.36205602, 2.36205602, 2.36205602, 2.36205602, 2.50687838, 2.50687838, 2.50687838, 2.50687838, 2.65170074, 2.65170074, 2.65170074, 2.65170074, 2.79652309, 2.79652309, 2.79652309, 2.79652309}); - auto expBatchMean = NDArrayFactory::create('c', {4}, {23., 24., 25., 26.}); - auto expBatchVar = NDArrayFactory::create('c', {4}, {208.00001526, 208.00001526, 208.00001526, 208.00001526}); - - sd::ops::fused_batch_norm op; - auto results = op.evaluate({&x, &scale, &offset}, {0.05}, {0,1}); - auto y = results.at(0); - auto batchMean = results.at(1); - auto batchVar = results.at(2); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expY.isSameShape(y)); - ASSERT_TRUE(expBatchMean.isSameShape(batchMean)); - ASSERT_TRUE(expBatchVar.isSameShape(batchVar)); - - + auto x = NDArrayFactory::create('c', {2, 2, 3, 4}); + x.linspace(1); + + auto scale = NDArrayFactory::create('c', {4}); + + scale = 0.5; + auto offset = NDArrayFactory::create('c', {4}); + offset = 2.; + auto expY = NDArrayFactory::create( + 'c', {2, 2, 3, 4}, + {1.20347691, 1.20347691, 1.20347691, 1.20347691, 1.34829926, 1.34829926, + 1.34829926, 1.34829926, 1.49312162, 1.49312162, 1.49312162, 1.49312162, + 1.6379441, 1.6379441, 1.6379441, 1.6379441, 1.78276646, 1.78276646, + 1.78276646, 1.78276646, 1.92758882, 1.92758882, 1.92758882, 1.92758882, + 2.0724113, 2.0724113, 2.0724113, 2.0724113, 2.21723366, 2.21723366, + 2.21723366, 2.21723366, 2.36205602, 2.36205602, 2.36205602, 2.36205602, + 2.50687838, 2.50687838, 2.50687838, 2.50687838, 2.65170074, 2.65170074, + 2.65170074, 2.65170074, 2.79652309, 2.79652309, 2.79652309, 2.79652309}); + auto expBatchMean = + NDArrayFactory::create('c', {4}, {23., 24., 25., 26.}); + auto expBatchVar = NDArrayFactory::create( + 'c', {4}, {208.00001526, 208.00001526, 208.00001526, 208.00001526}); + + sd::ops::fused_batch_norm op; + auto results = op.evaluate({&x, &scale, &offset}, {0.05}, {0, 1}); + auto y = results.at(0); + auto batchMean = results.at(1); + auto batchVar = results.at(2); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expY.isSameShape(y)); + ASSERT_TRUE(expBatchMean.isSameShape(batchMean)); + ASSERT_TRUE(expBatchVar.isSameShape(batchVar)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, fusedBatchNorm_test3) { - - auto x = NDArrayFactory::create('c', {2, 4, 2, 3}); - x.linspace(1); - - auto scale = NDArrayFactory::create('c', {4}); - - scale = 0.5; - auto offset = NDArrayFactory::create('c', {4}); - offset = 2.; - auto expY = NDArrayFactory::create('c', {2, 4, 2, 3}, {1.20337462, 1.20337462, 1.20337462, 1.20337462, 1.34821558, 1.34821558, 1.34821558, 1.34821558, 1.49305654, 1.49305654, 1.49305654, 1.49305654, 1.63789749, 1.63789749, 1.63789749, 1.63789749, 1.78273857, 1.78273857, 1.78273857, 1.78273857, 1.92757952, 1.92757952, 1.92757952, 1.92757952, 2.0724206 , 2.0724206 , 2.0724206 , 2.0724206 , 2.21726155, 2.21726155, 2.21726155, 2.21726155, 2.36210251, 2.36210251, 2.36210251, 2.36210251, 2.50694346, 2.50694346, 2.50694346, 2.50694346, 2.65178442, 2.65178442, 2.65178442, 2.65178442, 2.79662538, 2.79662538, 2.79662538, 2.79662538}); - auto expBatchMean = NDArrayFactory::create('c', {4}, {23., 24., 25., 26.}); - auto expBatchVar = NDArrayFactory::create('c', {4}, {208.00001526, 208.00001526, 208.00001526, 208.00001526}); - - sd::ops::fused_batch_norm op; - auto results = op.evaluate({&x, &scale, &offset}, {}, {1,1}); - auto y = results.at(0); - auto batchMean = results.at(1); - auto batchVar = results.at(2); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expY.isSameShape(y)); - ASSERT_TRUE(expBatchMean.isSameShape(batchMean)); - ASSERT_TRUE(expBatchVar.isSameShape(batchVar)); - - + auto x = NDArrayFactory::create('c', {2, 4, 2, 3}); + x.linspace(1); + + auto scale = NDArrayFactory::create('c', {4}); + + scale = 0.5; + auto offset = NDArrayFactory::create('c', {4}); + offset = 2.; + auto expY = NDArrayFactory::create( + 'c', {2, 4, 2, 3}, + {1.20337462, 1.20337462, 1.20337462, 1.20337462, 1.34821558, 1.34821558, + 1.34821558, 1.34821558, 1.49305654, 1.49305654, 1.49305654, 1.49305654, + 1.63789749, 1.63789749, 1.63789749, 1.63789749, 1.78273857, 1.78273857, + 1.78273857, 1.78273857, 1.92757952, 1.92757952, 1.92757952, 1.92757952, + 2.0724206, 2.0724206, 2.0724206, 2.0724206, 2.21726155, 2.21726155, + 2.21726155, 2.21726155, 2.36210251, 2.36210251, 2.36210251, 2.36210251, + 2.50694346, 2.50694346, 2.50694346, 2.50694346, 2.65178442, 2.65178442, + 2.65178442, 2.65178442, 2.79662538, 2.79662538, 2.79662538, 2.79662538}); + auto expBatchMean = + NDArrayFactory::create('c', {4}, {23., 24., 25., 26.}); + auto expBatchVar = NDArrayFactory::create( + 'c', {4}, {208.00001526, 208.00001526, 208.00001526, 208.00001526}); + + sd::ops::fused_batch_norm op; + auto results = op.evaluate({&x, &scale, &offset}, {}, {1, 1}); + auto y = results.at(0); + auto batchMean = results.at(1); + auto batchVar = results.at(2); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expY.isSameShape(y)); + ASSERT_TRUE(expBatchMean.isSameShape(batchMean)); + ASSERT_TRUE(expBatchVar.isSameShape(batchVar)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, fusedBatchNorm_test4) { - - auto x = NDArrayFactory::create('c', {2, 2, 3, 4}); - x.linspace(1); - std::vector shape = {4}; - auto scale = NDArrayFactory::create('c', shape); - auto offset = NDArrayFactory::create('c', shape); - auto mean = NDArrayFactory::create('c', shape); - auto variance = NDArrayFactory::create('c', shape); - - scale = 0.5; - offset = 2.; - mean = 25.; - variance = 5.; - - auto expY = NDArrayFactory::create('c', {2, 2, 3, 4}, {-3.36602688, -3.14244223, -2.91885757, -2.6952734 , -2.47168875, -2.24810457, -2.02451992, -1.80093551, -1.57735109, -1.35376668, -1.13018227, -0.90659785, -0.68301344, -0.45942879, -0.23584437, -0.01225996, 0.21132445, 0.43490887, 0.65849328, 0.88207781, 1.10566223, 1.32924664, 1.55283117, 1.77641559, 2. , 2.22358441, 2.44716883, 2.67075348, 2.89433765, 3.11792231, 3.34150672, 3.56509113, 3.78867555, 4.01225996, 4.23584461, 4.45942879, 4.68301344, 4.90659809, 5.13018227, 5.35376644, 5.57735109, 5.80093575, 6.02451992, 6.24810457, 6.47168875, 6.6952734 , 6.91885757, 7.14244223}); - auto expBatchMean = NDArrayFactory::create('c', shape, {0., 0., 0., 0.}); - auto expBatchVar = NDArrayFactory::create('c', shape, {0., 0., 0., 0.}); - - - sd::ops::fused_batch_norm op; - auto results = op.evaluate({&x, &scale, &offset}, {}, {0,1}); - auto y = results.at(0); - auto batchMean = results.at(1); - auto batchVar = results.at(2); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expY.isSameShape(y)); - ASSERT_TRUE(expBatchMean.isSameShape(batchMean)); - ASSERT_TRUE(expBatchVar.isSameShape(batchVar)); - - + auto x = NDArrayFactory::create('c', {2, 2, 3, 4}); + x.linspace(1); + std::vector shape = {4}; + auto scale = NDArrayFactory::create('c', shape); + auto offset = NDArrayFactory::create('c', shape); + auto mean = NDArrayFactory::create('c', shape); + auto variance = NDArrayFactory::create('c', shape); + + scale = 0.5; + offset = 2.; + mean = 25.; + variance = 5.; + + auto expY = NDArrayFactory::create( + 'c', {2, 2, 3, 4}, + {-3.36602688, -3.14244223, -2.91885757, -2.6952734, -2.47168875, + -2.24810457, -2.02451992, -1.80093551, -1.57735109, -1.35376668, + -1.13018227, -0.90659785, -0.68301344, -0.45942879, -0.23584437, + -0.01225996, 0.21132445, 0.43490887, 0.65849328, 0.88207781, + 1.10566223, 1.32924664, 1.55283117, 1.77641559, 2., + 2.22358441, 2.44716883, 2.67075348, 2.89433765, 3.11792231, + 3.34150672, 3.56509113, 3.78867555, 4.01225996, 4.23584461, + 4.45942879, 4.68301344, 4.90659809, 5.13018227, 5.35376644, + 5.57735109, 5.80093575, 6.02451992, 6.24810457, 6.47168875, + 6.6952734, 6.91885757, 7.14244223}); + auto expBatchMean = + NDArrayFactory::create('c', shape, {0., 0., 0., 0.}); + auto expBatchVar = + NDArrayFactory::create('c', shape, {0., 0., 0., 0.}); + + sd::ops::fused_batch_norm op; + auto results = op.evaluate({&x, &scale, &offset}, {}, {0, 1}); + auto y = results.at(0); + auto batchMean = results.at(1); + auto batchVar = results.at(2); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expY.isSameShape(y)); + ASSERT_TRUE(expBatchMean.isSameShape(batchMean)); + ASSERT_TRUE(expBatchVar.isSameShape(batchVar)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, fusedBatchNorm_test5) { - - auto x = NDArrayFactory::create('c', {2, 2, 3, 4}); - x.linspace(1); - std::vector shape = {4}; - auto scale = NDArrayFactory::create('c', shape); - auto offset = NDArrayFactory::create('c', shape); - auto mean = NDArrayFactory::create('c', shape); - auto variance = NDArrayFactory::create('c', shape); - - scale = 0.5; - offset = 2.; - mean = 25.; - variance = 5.; - - auto expY = NDArrayFactory::create('c', {2, 2, 3, 4}, {-3.33992958e+00, -3.11743259e+00, -2.89493513e+00, -2.67243814e+00, -2.44994116e+00, -2.22744417e+00, -2.00494719e+00, -1.78244996e+00, -1.55995297e+00, -1.33745599e+00, -1.11495876e+00, -8.92461777e-01, -6.69964790e-01, -4.47467566e-01, -2.24970579e-01, -2.47359276e-03, 2.20023513e-01, 4.42520618e-01, 6.65017605e-01, 8.87514710e-01, 1.11001182e+00, 1.33250880e+00, 1.55500591e+00, 1.77750289e+00, 2.00000000e+00, 2.22249699e+00, 2.44499421e+00, 2.66749120e+00, 2.88998818e+00, 3.11248541e+00, 3.33498240e+00, 3.55747938e+00, 3.77997637e+00, 4.00247383e+00, 4.22497082e+00, 4.44746780e+00, 4.66996479e+00, 4.89246178e+00, 5.11495876e+00, 5.33745575e+00, 5.55995274e+00, 5.78244972e+00, 6.00494719e+00, 6.22744417e+00, 6.44994116e+00, 6.67243814e+00, 6.89493513e+00, 7.11743259e+00}); - auto expBatchMean = NDArrayFactory::create('c', shape, {0., 0., 0., 0.}); - auto expBatchVar = NDArrayFactory::create('c', shape, {0., 0., 0., 0.}); - - - sd::ops::fused_batch_norm op; - auto results = op.evaluate({&x, &scale, &offset}, {0.05}, {0,1}); - auto y = results.at(0); - auto batchMean = results.at(1); - auto batchVar = results.at(2); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expY.isSameShape(y)); - ASSERT_TRUE(expBatchMean.isSameShape(batchMean)); - ASSERT_TRUE(expBatchVar.isSameShape(batchVar)); - - + auto x = NDArrayFactory::create('c', {2, 2, 3, 4}); + x.linspace(1); + std::vector shape = {4}; + auto scale = NDArrayFactory::create('c', shape); + auto offset = NDArrayFactory::create('c', shape); + auto mean = NDArrayFactory::create('c', shape); + auto variance = NDArrayFactory::create('c', shape); + + scale = 0.5; + offset = 2.; + mean = 25.; + variance = 5.; + + auto expY = NDArrayFactory::create( + 'c', {2, 2, 3, 4}, + {-3.33992958e+00, -3.11743259e+00, -2.89493513e+00, -2.67243814e+00, + -2.44994116e+00, -2.22744417e+00, -2.00494719e+00, -1.78244996e+00, + -1.55995297e+00, -1.33745599e+00, -1.11495876e+00, -8.92461777e-01, + -6.69964790e-01, -4.47467566e-01, -2.24970579e-01, -2.47359276e-03, + 2.20023513e-01, 4.42520618e-01, 6.65017605e-01, 8.87514710e-01, + 1.11001182e+00, 1.33250880e+00, 1.55500591e+00, 1.77750289e+00, + 2.00000000e+00, 2.22249699e+00, 2.44499421e+00, 2.66749120e+00, + 2.88998818e+00, 3.11248541e+00, 3.33498240e+00, 3.55747938e+00, + 3.77997637e+00, 4.00247383e+00, 4.22497082e+00, 4.44746780e+00, + 4.66996479e+00, 4.89246178e+00, 5.11495876e+00, 5.33745575e+00, + 5.55995274e+00, 5.78244972e+00, 6.00494719e+00, 6.22744417e+00, + 6.44994116e+00, 6.67243814e+00, 6.89493513e+00, 7.11743259e+00}); + auto expBatchMean = + NDArrayFactory::create('c', shape, {0., 0., 0., 0.}); + auto expBatchVar = + NDArrayFactory::create('c', shape, {0., 0., 0., 0.}); + + sd::ops::fused_batch_norm op; + auto results = op.evaluate({&x, &scale, &offset}, {0.05}, {0, 1}); + auto y = results.at(0); + auto batchMean = results.at(1); + auto batchVar = results.at(2); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expY.isSameShape(y)); + ASSERT_TRUE(expBatchMean.isSameShape(batchMean)); + ASSERT_TRUE(expBatchVar.isSameShape(batchVar)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, confusion_matrix_test1) { + auto labels = NDArrayFactory::create('c', {1, 3}, {1, 2, 4}); + auto predictions = NDArrayFactory::create('c', {1, 3}, {2, 2, 4}); + auto expected = NDArrayFactory::create( + 'c', {5, 5}, {0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}); - auto labels = NDArrayFactory::create('c', {1, 3}, {1, 2, 4}); - auto predictions = NDArrayFactory::create('c', {1, 3}, {2, 2, 4}); - auto expected = NDArrayFactory::create('c', {5, 5}, {0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}); - - sd::ops::confusion_matrix op; - auto results = op.evaluate({&labels, &predictions}, {}, {}); - ASSERT_EQ(Status::OK(), results.status()); - - auto output = results.at(0); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + sd::ops::confusion_matrix op; + auto results = op.evaluate({&labels, &predictions}, {}, {}); + ASSERT_EQ(Status::OK(), results.status()); + auto output = results.at(0); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, confusion_matrix_test2) { + auto labels = NDArrayFactory::create('c', {1, 2}, {1, 2}); + auto predictions = NDArrayFactory::create('c', {1, 2}, {0, 2}); + auto expected = NDArrayFactory::create('c', {3, 3}, + {0, 0, 0, 1, 0, 0, 0, 0, 1}); - auto labels = NDArrayFactory::create('c', {1, 2}, {1, 2}); - auto predictions = NDArrayFactory::create('c', {1, 2}, {0, 2}); - auto expected = NDArrayFactory::create('c', {3, 3}, {0, 0, 0, 1, 0, 0, 0, 0, 1}); - - sd::ops::confusion_matrix op; - auto results = op.evaluate({&labels, &predictions}, {}, {3}); - ASSERT_EQ(Status::OK(), results.status()); - - auto output = results.at(0); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + sd::ops::confusion_matrix op; + auto results = op.evaluate({&labels, &predictions}, {}, {3}); + ASSERT_EQ(Status::OK(), results.status()); + auto output = results.at(0); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, confusion_matrix_test3) { + auto labels = NDArrayFactory::create('c', {1, 2}, {1, 2}); + auto predictions = NDArrayFactory::create('c', {1, 2}, {0, 2}); + auto weights = NDArrayFactory::create('c', {1, 2}, {100, 200}); + auto expected = NDArrayFactory::create( + 'c', {3, 3}, {0, 0, 0, 100, 0, 0, 0, 0, 200}); - auto labels = NDArrayFactory::create('c', {1, 2}, {1, 2}); - auto predictions = NDArrayFactory::create('c', {1, 2}, {0, 2}); - auto weights = NDArrayFactory::create('c', {1, 2}, {100, 200}); - auto expected = NDArrayFactory::create('c', {3, 3}, {0, 0, 0, 100, 0, 0, 0, 0, 200}); - - sd::ops::confusion_matrix op; - auto results = op.evaluate({&labels, &predictions, &weights}, {}, {3}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - + sd::ops::confusion_matrix op; + auto results = op.evaluate({&labels, &predictions, &weights}, {}, {3}); + auto output = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, confusion_matrix_test4) { + auto labels = NDArrayFactory::create('c', {1, 2}, {1, 2}); + auto predictions = NDArrayFactory::create('c', {1, 2}, {0, 2}); + auto weights = NDArrayFactory::create('c', {1, 2}, {100, 200}); + auto expected = NDArrayFactory::create( + 'c', {3, 3}, {0, 0, 0, 100, 0, 0, 0, 0, 200}); - auto labels = NDArrayFactory::create('c', {1, 2}, {1, 2}); - auto predictions = NDArrayFactory::create('c', {1, 2}, {0, 2}); - auto weights = NDArrayFactory::create('c', {1, 2}, {100, 200}); - auto expected = NDArrayFactory::create('c', {3, 3}, {0, 0, 0, 100, 0, 0, 0, 0, 200}); - - sd::ops::confusion_matrix op; - auto results = op.evaluate({&labels, &predictions, &weights}, {}, {3}, {}, {sd::DataType::DOUBLE}); - auto output = results.at(0); - - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - + sd::ops::confusion_matrix op; + auto results = op.evaluate({&labels, &predictions, &weights}, {}, {3}, {}, + {sd::DataType::DOUBLE}); + auto output = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } /////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, ZeroFraction_1) { + auto x = NDArrayFactory::create( + 'c', {3, 4, 2}, {0, 20, 30, 0, 50, 0, 70, 0, 90, 0, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 0, 21, 22, 23, 24}); - auto x = NDArrayFactory::create('c', {3, 4, 2}, {0, 20, 30, 0, 50, 0, - 70, 0, 90, 0, 11, 12, - 13, 14, 15, 16, 17, 18, - 19, 0, 21, 22, 23, 24}); - - sd::ops::zero_fraction op; - auto res = op.evaluate({&x}, {}, {}); - - ASSERT_EQ(Status::OK(), res.status()); - ASSERT_TRUE(res.at(0)->isScalar()); - ASSERT_EQ(res.at(0)->e(0), 0.25); - + sd::ops::zero_fraction op; + auto res = op.evaluate({&x}, {}, {}); + ASSERT_EQ(Status::OK(), res.status()); + ASSERT_TRUE(res.at(0).isScalar()); + ASSERT_EQ(res.at(0).e(0), 0.25); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, ZeroFraction_2) { + auto x = NDArrayFactory::create( + 'c', {2, 2, 2}, {5.5, 0., 0.3, 5.5, 8.6, 0., 0., 0.4}); - auto x = NDArrayFactory::create('c', {2, 2, 2}, {5.5, 0., 0.3, 5.5, 8.6, 0., 0., 0.4}); - - sd::ops::zero_fraction op; - auto res = op.evaluate({&x}, {}, {}); - - ASSERT_EQ(Status::OK(), res.status()); - ASSERT_TRUE(res.at(0)->isScalar()); - ASSERT_EQ(res.at(0)->e(0), 0.375); - + sd::ops::zero_fraction op; + auto res = op.evaluate({&x}, {}, {}); + ASSERT_EQ(Status::OK(), res.status()); + ASSERT_TRUE(res.at(0).isScalar()); + ASSERT_EQ(res.at(0).e(0), 0.375); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, ZeroFraction_3) { + auto x = NDArrayFactory::create( + 'f', {2, 2, 2}, {5.5, 0., 0.3, 5.5, 8.6, 0., 0., 0.4}); - auto x = NDArrayFactory::create('f', {2, 2, 2}, {5.5, 0., 0.3, 5.5, 8.6, 0., 0., 0.4}); - - sd::ops::zero_fraction op; - auto res = op.evaluate({&x}, {}, {}); - - ASSERT_EQ(Status::OK(), res.status()); - ASSERT_TRUE(res.at(0)->isScalar()); - ASSERT_EQ(res.at(0)->e(0), 0.375); - + sd::ops::zero_fraction op; + auto res = op.evaluate({&x}, {}, {}); + ASSERT_EQ(Status::OK(), res.status()); + ASSERT_TRUE(res.at(0).isScalar()); + ASSERT_EQ(res.at(0).e(0), 0.375); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, XWPlusB_1) { + auto x = NDArrayFactory::create('c', {2, 3}, + {1.f, 11.f, 3.f, 14.f, 5.f, 6.f}); + auto y = NDArrayFactory::create('c', {3, 2}, + {11.f, 3.f, 4.f, 5.f, 6.f, 2.f}); + auto b = NDArrayFactory::create({100.f, 200.f}); - auto x = NDArrayFactory::create('c', { 2,3 }, { 1.f, 11.f, 3.f, 14.f, 5.f, 6.f }); - auto y = NDArrayFactory::create('c', { 3,2 }, { 11.f, 3.f, 4.f, 5.f, 6.f, 2.f }); - auto b = NDArrayFactory::create({ 100.f, 200.f }); + auto exp = + NDArrayFactory::create('c', {2, 2}, {173.f, 264.f, 310.f, 279.f}); - auto exp = NDArrayFactory::create('c', { 2,2 }, { 173.f, 264.f, 310.f, 279.f }); + sd::ops::xw_plus_b op; + auto result = op.evaluate({&x, &y, &b}); - sd::ops::xw_plus_b op; - auto result = op.evaluate({ &x, &y, &b }); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); - auto output = result.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, XWPlusB_2) { + auto x = NDArrayFactory::create('c', {1, 2}, {1.f, 11.f}); + auto y = NDArrayFactory::create('c', {2, 3}, + {11.f, 3.f, 4.f, 5.f, 6.f, 2.f}); + auto b = NDArrayFactory::create({100.f, 200.f, 300.f}); - auto x = NDArrayFactory::create('c', { 1, 2 }, { 1.f, 11.f }); - auto y = NDArrayFactory::create('c', { 2, 3 }, { 11.f, 3.f, 4.f, 5.f, 6.f, 2.f }); - auto b = NDArrayFactory::create({ 100.f, 200.f, 300.f }); - - auto exp = NDArrayFactory::create('c', { 1, 3 }, { 166.f, 269.f, 326.f }); - - sd::ops::xw_plus_b op; - auto result = op.evaluate({ &x, &y, &b }, {}, {}); + auto exp = NDArrayFactory::create('c', {1, 3}, {166.f, 269.f, 326.f}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::xw_plus_b op; + auto result = op.evaluate({&x, &y, &b}, {}, {}); - auto output = result.at(0); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + auto output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, XWPlusB_3) { + auto x = NDArrayFactory::create('c', {1, 2}, {1.f, 11.f}); + auto y = NDArrayFactory::create('c', {2, 1}, {11.f, 3.f}); + auto b = NDArrayFactory::create('c', {1}, {200.f}); - auto x = NDArrayFactory::create('c', { 1, 2 }, { 1.f, 11.f }); - auto y = NDArrayFactory::create('c', { 2, 1 }, { 11.f, 3.f }); - auto b = NDArrayFactory::create('c', { 1 }, { 200.f }); + auto exp = NDArrayFactory::create('c', {1, 1}, {244.f}); - auto exp = NDArrayFactory::create('c', { 1,1 }, { 244.f }); + sd::ops::xw_plus_b op; + auto result = op.evaluate({&x, &y, &b}); - sd::ops::xw_plus_b op; - auto result = op.evaluate({ &x, &y, &b }); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto output = result.at(0); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + auto output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, XWPlusB_4) { + auto x = NDArrayFactory::create('f', {2, 3}, + {1.f, 11.f, 3.f, 14.f, 5.f, 6.f}); + auto y = NDArrayFactory::create('f', {3, 2}, + {11.f, 3.f, 4.f, 5.f, 6.f, 2.f}); + auto b = NDArrayFactory::create({100.f, 200.f}); - auto x = NDArrayFactory::create('f', { 2,3 }, { 1.f, 11.f, 3.f, 14.f, 5.f, 6.f }); - auto y = NDArrayFactory::create('f', { 3,2 }, { 11.f, 3.f, 4.f, 5.f, 6.f, 2.f }); - auto b = NDArrayFactory::create({ 100.f, 200.f }); - - auto exp = NDArrayFactory::create('f', { 2,2 }, { 140.f, 287.f, 233.f, 351.f }); + auto exp = + NDArrayFactory::create('f', {2, 2}, {140.f, 287.f, 233.f, 351.f}); - sd::ops::xw_plus_b op; - auto result = op.evaluate({ &x, &y, &b }); + sd::ops::xw_plus_b op; + auto result = op.evaluate({&x, &y, &b}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); + auto output = result.at(0); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, XWPlusB_5) { + auto x = NDArrayFactory::create('c', {2, 3}, + {1.f, 11.f, 3.f, 14.f, 5.f, 6.f}); + auto y = NDArrayFactory::create('c', {3, 2}, + {11.f, 3.f, 4.f, 5.f, 6.f, 2.f}); - auto x = NDArrayFactory::create('c', { 2,3 }, { 1.f, 11.f, 3.f, 14.f, 5.f, 6.f }); - auto y = NDArrayFactory::create('c', { 3,2 }, { 11.f, 3.f, 4.f, 5.f, 6.f, 2.f }); - - y = y.transpose(); + y = y.transpose(); - auto b = NDArrayFactory::create({ 100.f, 200.f }); + auto b = NDArrayFactory::create({100.f, 200.f}); - auto exp = NDArrayFactory::create('c', { 2,2 }, { 173.f, 264.f, 310.f, 279.f }); + auto exp = + NDArrayFactory::create('c', {2, 2}, {173.f, 264.f, 310.f, 279.f}); + sd::ops::xw_plus_b op; + auto result = op.evaluate({&x, &y, &b}, {}, {1}); - sd::ops::xw_plus_b op; - auto result = op.evaluate({ &x, &y, &b }, {}, { 1 }); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); - auto output = result.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, XWPlusB_6) { + auto x = NDArrayFactory::create('c', {3, 2}, + {1.f, 11.f, 3.f, 14.f, 5.f, 6.f}); + auto y = NDArrayFactory::create('c', {2, 1}, {11.f, 3.f}); - auto x = NDArrayFactory::create('c', { 3, 2 }, { 1.f, 11.f, 3.f, 14.f, 5.f, 6.f }); - auto y = NDArrayFactory::create('c', { 2, 1 }, { 11.f, 3.f }); - - auto b = NDArrayFactory::create('c', { 1 }, { 100.f }); + auto b = NDArrayFactory::create('c', {1}, {100.f}); - auto exp = NDArrayFactory::create('c', { 3, 1 }, { 144.f, 175.f, 173.f }); + auto exp = NDArrayFactory::create('c', {3, 1}, {144.f, 175.f, 173.f}); - sd::ops::xw_plus_b op; - auto result = op.evaluate({ &x, &y, &b }); + sd::ops::xw_plus_b op; + auto result = op.evaluate({&x, &y, &b}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); + auto output = result.at(0); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, XWPlusB_7) { + auto x = NDArrayFactory::create( + 'c', {3, 4}, + {1.f, 11.f, 3.f, 14.f, 5.f, 6.f, 1.f, 11.f, 3.f, 14.f, 5.f, 6.f}); + auto y = NDArrayFactory::create( + 'c', {4, 5}, {11.f, 3.f, 11.f, 3.f, 11.f, 3.f, 11.f, 3.f, 11.f, 3.f, + 11.f, 3.f, 11.f, 3.f, 11.f, 3.f, 3.f, 11.f, 3.f, 11.f}); - auto x = NDArrayFactory::create('c', { 3, 4 }, { 1.f, 11.f, 3.f, 14.f, 5.f, 6.f, 1.f, 11.f, 3.f, 14.f, 5.f, 6.f }); - auto y = NDArrayFactory::create('c', { 4, 5 }, { 11.f, 3.f, 11.f, 3.f, 11.f, 3.f, 11.f, 3.f, 11.f, 3.f, 11.f, 3.f, 11.f, 3.f, 11.f, 3.f, 3.f, 11.f, 3.f, 11.f }); + auto b = NDArrayFactory::create('c', {5}, + {100.f, 200.f, 300.f, 400.f, 500.f}); - auto b = NDArrayFactory::create('c', { 5 }, { 100.f, 200.f, 300.f, 400.f, 500.f }); + auto exp = NDArrayFactory::create( + 'c', {3, 5}, + {219.f, 375.f, 531.f, 575.f, 731.f, 217.f, 317.f, 505.f, 517.f, 705.f, + 248.f, 396.f, 496.f, 596.f, 696.f}); - auto exp = NDArrayFactory::create('c', { 3, 5 }, { 219.f, 375.f, 531.f, 575.f, 731.f, 217.f, 317.f, 505.f, 517.f, 705.f, 248.f, 396.f, 496.f, 596.f, 696.f }); + sd::ops::xw_plus_b op; + auto result = op.evaluate({&x, &y, &b}); - sd::ops::xw_plus_b op; - auto result = op.evaluate({ &x, &y, &b }); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); - auto output = result.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, StopGradient_1) { + auto x = NDArrayFactory::create('c', {2, 3}, + {1.f, 11.f, 3.f, 14.f, 5.f, 6.f}); - auto x = NDArrayFactory::create('c', {2,3}, { 1.f, 11.f, 3.f, 14.f, 5.f, 6.f}); - - sd::ops::stop_gradient op; - auto result = op.evaluate({&x}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::stop_gradient op; + auto result = op.evaluate({&x}); - auto output = result.at(0); - - // output->printShapeInfo("Output shape> "); - // x.printShapeInfo("Expected shape> "); - // output->printIndexedBuffer("Output data> "); - // x.printIndexedBuffer("Expected res>"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(x.isSameShape(output)); - ASSERT_TRUE(x.equalsTo(output)); + auto output = result.at(0); + // output->printShapeInfo("Output shape> "); + // x.printShapeInfo("Expected shape> "); + // output->printIndexedBuffer("Output data> "); + // x.printIndexedBuffer("Expected res>"); + ASSERT_TRUE(x.isSameShape(output)); + ASSERT_TRUE(x.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, StopGradient_2) { + auto x = NDArrayFactory::create('f', {2, 3}, + {1.f, 11.f, 3.f, 14.f, 5.f, 6.f}); - auto x = NDArrayFactory::create('f', {2,3}, { 1.f, 11.f, 3.f, 14.f, 5.f, 6.f}); - - sd::ops::stop_gradient op; - auto result = op.evaluate({&x}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::stop_gradient op; + auto result = op.evaluate({&x}); - auto output = result.at(0); - - // output->printShapeInfo("Output shape> "); - // x.printShapeInfo("Expected shape> "); - // output->printIndexedBuffer("Output data> "); - // x.printIndexedBuffer("Expected res>"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(x.isSameShape(output)); - ASSERT_TRUE(x.equalsTo(output)); + auto output = result.at(0); + // output->printShapeInfo("Output shape> "); + // x.printShapeInfo("Expected shape> "); + // output->printIndexedBuffer("Output data> "); + // x.printIndexedBuffer("Expected res>"); + ASSERT_TRUE(x.isSameShape(output)); + ASSERT_TRUE(x.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, log_softmax_test1) { + auto input = NDArrayFactory::create( + 'c', {3, 3, 3}, {-1, 1, -2, 2, -3, 3, -4, 4, -5, 5, -6, 6, -7, 7, + -8, 8, -9, 9, -10, 10, -11, 11, -12, 12, -13, 13, 14}); + auto expOutput = NDArrayFactory::create( + 'c', {3, 3, 3}, + {-2.16985e+00, -1.69846e-01, -3.16985e+00, -1.31507e+00, -6.31507e+00, + -3.15072e-01, -8.00046e+00, -4.58767e-04, -9.00046e+00, -1.31327e+00, + -1.23133e+01, -3.13266e-01, -1.40000e+01, -1.13743e-06, -1.50000e+01, + -1.31326e+00, -1.83133e+01, -3.13262e-01, -2.00000e+01, -2.81941e-09, + -2.10000e+01, -1.31326e+00, -2.43133e+01, -3.13262e-01, -2.73133e+01, + -1.31326e+00, -3.13262e-01}); - auto input = NDArrayFactory::create('c', {3, 3, 3}, {-1, 1, -2, 2, -3, 3, -4, 4, -5,5 ,-6,6, -7,7, -8,8, -9,9, -10,10, -11,11, -12,12, -13,13, 14}); - auto expOutput = NDArrayFactory::create('c', {3, 3, 3}, {-2.16985e+00,-1.69846e-01,-3.16985e+00, -1.31507e+00,-6.31507e+00,-3.15072e-01, -8.00046e+00,-4.58767e-04,-9.00046e+00, -1.31327e+00,-1.23133e+01,-3.13266e-01, -1.40000e+01,-1.13743e-06,-1.50000e+01, -1.31326e+00,-1.83133e+01,-3.13262e-01, -2.00000e+01,-2.81941e-09,-2.10000e+01, -1.31326e+00,-2.43133e+01,-3.13262e-01, -2.73133e+01,-1.31326e+00,-3.13262e-01}); - - sd::ops::log_softmax op; - auto results = op.evaluate({&input}); - auto z = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expOutput.isSameShape(z)); - ASSERT_TRUE(expOutput.equalsTo(z)); - + sd::ops::log_softmax op; + auto results = op.evaluate({&input}); + auto z = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expOutput.isSameShape(z)); + ASSERT_TRUE(expOutput.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, log_softmax_test2) { + auto input = NDArrayFactory::create( + 'c', {3, 3, 3}, {-1, 1, -2, 2, -3, 3, -4, 4, -5, 5, -6, 6, -7, 7, + -8, 8, -9, 9, -10, 10, -11, 11, -12, 12, -13, 13, 14}); + auto expOutput = NDArrayFactory::create( + 'c', {3, 3, 3}, + {-3.05095e+00, -3.04946e+00, -5.00705e+00, -5.09458e-02, -7.04946e+00, + -7.04851e-03, -6.05095e+00, -4.94556e-02, -8.00705e+00, -3.04859e+00, + -1.30000e+01, -3.04859e+00, -1.50486e+01, -2.37286e-06, -1.70486e+01, + -4.85876e-02, -1.60000e+01, -4.85874e-02, -2.10000e+01, -3.04859e+00, + -2.51269e+01, -7.96007e-10, -2.50486e+01, -2.12693e+00, -2.40000e+01, + -4.85874e-02, -1.26928e-01}); - auto input = NDArrayFactory::create('c', {3, 3, 3}, {-1, 1, -2, 2, -3, 3, -4, 4, -5,5 ,-6,6, -7,7, -8,8, -9,9, -10,10, -11,11, -12,12, -13,13, 14}); - auto expOutput = NDArrayFactory::create('c', {3, 3, 3}, {-3.05095e+00,-3.04946e+00,-5.00705e+00, -5.09458e-02,-7.04946e+00,-7.04851e-03, -6.05095e+00,-4.94556e-02,-8.00705e+00, -3.04859e+00,-1.30000e+01,-3.04859e+00, -1.50486e+01,-2.37286e-06,-1.70486e+01, -4.85876e-02,-1.60000e+01,-4.85874e-02, -2.10000e+01,-3.04859e+00,-2.51269e+01, -7.96007e-10,-2.50486e+01,-2.12693e+00, -2.40000e+01,-4.85874e-02,-1.26928e-01}); - - sd::ops::log_softmax op; - auto results = op.evaluate({&input}, {}, {1}); - auto z = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expOutput.isSameShape(z)); - ASSERT_TRUE(expOutput.equalsTo(z)); - + sd::ops::log_softmax op; + auto results = op.evaluate({&input}, {}, {1}); + auto z = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expOutput.isSameShape(z)); + ASSERT_TRUE(expOutput.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, log_softmax_test3) { + auto input = NDArrayFactory::create( + 'c', {3, 3, 3}, {-1, 1, -2, 2, -3, 3, -4, 4, -5, 5, -6, 6, -7, 7, + -8, 8, -9, 9, -10, 10, -11, 11, -12, 12, -13, 13, 14}); + auto expOutput = NDArrayFactory::create( + 'c', {3, 3, 3}, + {-2.16985e+00, -1.69846e-01, -3.16985e+00, -1.31507e+00, -6.31507e+00, + -3.15072e-01, -8.00046e+00, -4.58767e-04, -9.00046e+00, -1.31327e+00, + -1.23133e+01, -3.13266e-01, -1.40000e+01, -1.13743e-06, -1.50000e+01, + -1.31326e+00, -1.83133e+01, -3.13262e-01, -2.00000e+01, -2.81941e-09, + -2.10000e+01, -1.31326e+00, -2.43133e+01, -3.13262e-01, -2.73133e+01, + -1.31326e+00, -3.13262e-01}); - auto input = NDArrayFactory::create('c', {3, 3, 3}, {-1, 1, -2, 2, -3, 3, -4, 4, -5,5 ,-6,6, -7,7, -8,8, -9,9, -10,10, -11,11, -12,12, -13,13, 14}); - auto expOutput = NDArrayFactory::create('c', {3, 3, 3}, {-2.16985e+00,-1.69846e-01,-3.16985e+00, -1.31507e+00,-6.31507e+00,-3.15072e-01, -8.00046e+00,-4.58767e-04,-9.00046e+00, -1.31327e+00,-1.23133e+01,-3.13266e-01, -1.40000e+01,-1.13743e-06,-1.50000e+01, -1.31326e+00,-1.83133e+01,-3.13262e-01, -2.00000e+01,-2.81941e-09,-2.10000e+01, -1.31326e+00,-2.43133e+01,-3.13262e-01, -2.73133e+01,-1.31326e+00,-3.13262e-01}); - - sd::ops::log_softmax op; - auto results = op.evaluate({&input}, {}, {2}); - auto z = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expOutput.isSameShape(z)); - ASSERT_TRUE(expOutput.equalsTo(z)); - - + sd::ops::log_softmax op; + auto results = op.evaluate({&input}, {}, {2}); + auto z = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expOutput.isSameShape(z)); + ASSERT_TRUE(expOutput.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, log_softmax_test5) { + auto input = NDArrayFactory::create('c', {3, 3}, + {-1, 1, -2, 2, -3, 3, -4, 4, 5}); + auto expOutput = NDArrayFactory::create( + 'c', {3, 3}, + {-2.16985, -0.16985, -3.16985, -1.31507, -6.31507, -0.31507, -9.31335, + -1.31335, -0.31335}); - auto input = NDArrayFactory::create('c', {3, 3}, {-1, 1, -2, 2, -3, 3, -4, 4, 5}); - auto expOutput = NDArrayFactory::create('c', {3, 3}, {-2.16985, -0.16985, -3.16985, -1.31507, -6.31507, -0.31507, -9.31335, -1.31335, -0.31335}); - - sd::ops::log_softmax op; - auto results = op.evaluate({&input}); - auto z = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expOutput.isSameShape(z)); - ASSERT_TRUE(expOutput.equalsTo(z)); - + sd::ops::log_softmax op; + auto results = op.evaluate({&input}); + auto z = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expOutput.isSameShape(z)); + ASSERT_TRUE(expOutput.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, log_softmax_test6) { + auto input = NDArrayFactory::create('c', {3, 3}, + {-1, 1, -2, 2, -3, 3, -4, 4, 5}); + auto expOutput = NDArrayFactory::create( + 'c', {3, 3}, + {-3.05095, -3.04946, -7.12773, -0.05095, -7.04946, -2.12773, -6.05095, + -0.04946, -0.12773}); - auto input = NDArrayFactory::create('c', {3, 3}, {-1, 1, -2, 2, -3, 3, -4, 4, 5}); - auto expOutput = NDArrayFactory::create('c', {3, 3}, {-3.05095,-3.04946,-7.12773, -0.05095,-7.04946,-2.12773, -6.05095,-0.04946,-0.12773}); - - sd::ops::log_softmax op; - auto results = op.evaluate({&input}, {}, {0}); - auto z = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expOutput.isSameShape(z)); - ASSERT_TRUE(expOutput.equalsTo(z)); - + sd::ops::log_softmax op; + auto results = op.evaluate({&input}, {}, {0}); + auto z = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expOutput.isSameShape(z)); + ASSERT_TRUE(expOutput.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, log_softmax_test7) { + auto input = NDArrayFactory::create('c', {1, 5}, {-1, 1, -2, 2, 3}); + auto expOutput = NDArrayFactory::create( + 'c', {1, 5}, {-4.42414, -2.42414, -5.42414, -1.42414, -0.42414}); - auto input = NDArrayFactory::create('c', {1, 5}, {-1, 1, -2, 2, 3}); - auto expOutput = NDArrayFactory::create('c', {1, 5}, {-4.42414, -2.42414, -5.42414, -1.42414, -0.42414}); - - sd::ops::log_softmax op; - auto results = op.evaluate({&input}); - auto z = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expOutput.isSameShape(z)); - ASSERT_TRUE(expOutput.equalsTo(z)); - + sd::ops::log_softmax op; + auto results = op.evaluate({&input}); + auto z = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expOutput.isSameShape(z)); + ASSERT_TRUE(expOutput.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, log_softmax_test8) { + auto input = NDArrayFactory::create('c', {1, 5}, {-1, 1, -2, 2, 3}); + auto expOutput = NDArrayFactory::create('c', {1, 5}, {0, 0, 0, 0, 0}); - auto input = NDArrayFactory::create('c', {1, 5}, {-1, 1, -2, 2, 3}); - auto expOutput = NDArrayFactory::create('c', {1, 5}, {0, 0, 0, 0, 0}); - - sd::ops::log_softmax op; - auto results = op.evaluate({&input}, {}, {0}); - auto z = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expOutput.isSameShape(z)); - ASSERT_TRUE(expOutput.equalsTo(z)); - + sd::ops::log_softmax op; + auto results = op.evaluate({&input}, {}, {0}); + auto z = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expOutput.isSameShape(z)); + ASSERT_TRUE(expOutput.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, log_softmax_test9) { + auto input = NDArrayFactory::create('c', {5, 1}, {-1, 1, -2, 2, 3}); + auto expOutput = NDArrayFactory::create('c', {5, 1}, {0, 0, 0, 0, 0}); - auto input = NDArrayFactory::create('c', {5, 1}, {-1, 1, -2, 2, 3}); - auto expOutput = NDArrayFactory::create('c', {5, 1}, {0, 0, 0, 0, 0}); - - sd::ops::log_softmax op; - auto results = op.evaluate({&input}); - auto z = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expOutput.isSameShape(z)); - ASSERT_TRUE(expOutput.equalsTo(z)); - + sd::ops::log_softmax op; + auto results = op.evaluate({&input}); + auto z = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expOutput.isSameShape(z)); + ASSERT_TRUE(expOutput.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, log_softmax_test10) { + auto input = NDArrayFactory::create('c', {5, 1}, {-1, 1, -2, 2, 3}); + auto expOutput = NDArrayFactory::create( + 'c', {5, 1}, {-4.42414, -2.42414, -5.42414, -1.42414, -0.42414}); - auto input = NDArrayFactory::create('c', {5, 1}, {-1, 1, -2, 2, 3}); - auto expOutput = NDArrayFactory::create('c', {5, 1}, {-4.42414, -2.42414, -5.42414, -1.42414, -0.42414}); - - sd::ops::log_softmax op; - auto results = op.evaluate({&input}, {}, {0}); - auto z = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expOutput.isSameShape(z)); - ASSERT_TRUE(expOutput.equalsTo(z)); - + sd::ops::log_softmax op; + auto results = op.evaluate({&input}, {}, {0}); + auto z = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expOutput.isSameShape(z)); + ASSERT_TRUE(expOutput.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, log_softmax_test11) { + auto input = NDArrayFactory::create('c', {5}, {-1, 1, -2, 2, 3}); + auto expOutput = NDArrayFactory::create( + 'c', {5}, {-4.42414, -2.42414, -5.42414, -1.42414, -0.42414}); - auto input = NDArrayFactory::create('c', {5}, {-1, 1, -2, 2, 3}); - auto expOutput = NDArrayFactory::create('c', {5}, {-4.42414, -2.42414, -5.42414, -1.42414, -0.42414}); - - sd::ops::log_softmax op; - auto results = op.evaluate({&input}); - auto z = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expOutput.isSameShape(z)); - ASSERT_TRUE(expOutput.equalsTo(z)); - + sd::ops::log_softmax op; + auto results = op.evaluate({&input}); + auto z = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expOutput.isSameShape(z)); + ASSERT_TRUE(expOutput.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, log_softmax_test12) { + auto input = NDArrayFactory::create( + 'c', {1, 4}, {0.1869, -1.4918, -0.6497, -0.8864}); + auto expOutput = NDArrayFactory::create( + 'c', {1, 4}, {-0.6738, -2.3525, -1.5104, -1.7472}); - auto input = NDArrayFactory::create('c', {1, 4}, {0.1869, -1.4918, -0.6497, -0.8864}); - auto expOutput = NDArrayFactory::create('c', {1, 4}, {-0.6738, -2.3525, -1.5104, -1.7472}); - - for (int i = 0; i < 10; ++i) - { - sd::ops::log_softmax op; - auto results = op.evaluate({&input}); - auto z = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expOutput.isSameShape(z)); - ASSERT_TRUE(expOutput.equalsTo(z, 1e-4)); - + for (int i = 0; i < 10; ++i) { + sd::ops::log_softmax op; + auto results = op.evaluate({&input}); + auto z = results.at(0); - } + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expOutput.isSameShape(z)); + ASSERT_TRUE(expOutput.equalsTo(z, 1e-4)); + } } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, log_softmax_bp_test1) { + auto input = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); + auto epsilon = + NDArrayFactory::create('c', {2, 2}, {0.1, 0.2, 0.3, 0.4}); + auto exp = NDArrayFactory::create( + 'c', {2, 2}, {-0.07311, 0.02689, -0.07311, 0.02689}); - auto input = NDArrayFactory::create('c', {2, 2}, {1,2,3,4}); - auto epsilon = NDArrayFactory::create('c', {2, 2}, {0.1, 0.2, 0.3, 0.4}); - auto exp = NDArrayFactory::create('c', {2, 2}, {-0.07311,0.02689, -0.07311,0.02689}); - - sd::ops::log_softmax_bp op; - auto results = op.evaluate({&input, &epsilon}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - + sd::ops::log_softmax_bp op; + auto results = op.evaluate({&input, &epsilon}); + auto output = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, log_softmax_bp_test2) { + auto input = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); + auto epsilon = + NDArrayFactory::create('c', {2, 2}, {0.1, 0.2, 0.3, 0.4}); + auto exp = NDArrayFactory::create( + 'c', {2, 2}, {-0.17616, -0.17616, 0.02384, 0.02384}); - auto input = NDArrayFactory::create('c', {2, 2}, {1,2,3,4}); - auto epsilon = NDArrayFactory::create('c', {2, 2}, {0.1, 0.2, 0.3, 0.4}); - auto exp = NDArrayFactory::create('c', {2, 2}, {-0.17616, -0.17616, 0.02384, 0.02384}); - - sd::ops::log_softmax_bp op; - auto results = op.evaluate({&input, &epsilon}, {}, {0}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - + sd::ops::log_softmax_bp op; + auto results = op.evaluate({&input, &epsilon}, {}, {0}); + auto output = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, ELU_1) { + auto input = NDArrayFactory::create( + 'c', {2, 2, 2}, {-1., 2., 1.5, -1.4, 1., 2., 2., 1.}); + auto exp = NDArrayFactory::create( + 'c', {2, 2, 2}, {-0.63212055, 2., 1.5, -0.753403, 1., 2., 2., 1.}); + auto res = NDArrayFactory::create('c', {2, 2, 2}); - auto input = NDArrayFactory::create('c', {2, 2, 2}, { -1., 2. , 1.5, -1.4, 1., 2., 2., 1.}); - auto exp = NDArrayFactory::create('c', {2, 2, 2}, { -0.63212055, 2. , 1.5, -0.753403, 1., 2., 2., 1.}); - auto res = NDArrayFactory::create('c', {2, 2, 2}); - - input.applyScalar(sd::scalar::ELU, 1.f, res); + input.applyScalar(sd::scalar::ELU, 1.f, res); - ASSERT_TRUE(res.equalsTo(&exp)); + ASSERT_TRUE(res.equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, L2_Loss_1) { + auto input = NDArrayFactory::create( + 'c', {2, 2, 2}, {-1., 2., 1.5, -1.4, 1., 2., 2., 1.}); + double exp(9.605); - auto input = NDArrayFactory::create('c', {2, 2, 2}, { -1., 2. , 1.5, -1.4, 1., 2., 2., 1.}); - double exp(9.605); - - sd::ops::l2_loss op; - auto results = op.evaluate({&input}, {}, {}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(output->isScalar()); - - ASSERT_EQ(output->e(0), exp); + sd::ops::l2_loss op; + auto results = op.evaluate({&input}, {}, {}); + auto output = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(output.isScalar()); + ASSERT_EQ(output.e(0), exp); } TEST_F(DeclarableOpsTests5, L2_Loss_2) { - auto x = NDArrayFactory::create(0.7787855863571167); - auto e = NDArrayFactory::create(0.303254); - - sd::ops::l2_loss op; - auto results = op.evaluate({&x}, {}, {}); - ASSERT_EQ(Status::OK(), results.status()); - - auto z = results.at(0); + auto x = NDArrayFactory::create(0.7787855863571167); + auto e = NDArrayFactory::create(0.303254); - ASSERT_EQ(e, *z); + sd::ops::l2_loss op; + auto results = op.evaluate({&x}, {}, {}); + ASSERT_EQ(Status::OK(), results.status()); + auto z = results.at(0); + ASSERT_EQ(e, z); } TEST_F(DeclarableOpsTests5, L2_Loss_3) { - auto x = NDArrayFactory::create(0.7787855863571167); - auto e = NDArrayFactory::create(0.303254); - auto z = NDArrayFactory::create(0.0); + auto x = NDArrayFactory::create(0.7787855863571167); + auto e = NDArrayFactory::create(0.303254); + auto z = NDArrayFactory::create(0.0); - sd::ops::l2_loss op; - auto status = op.execute({&x}, {&z} , {}, {}, {}); - ASSERT_EQ(Status::OK(), status); + sd::ops::l2_loss op; + auto status = op.execute({&x}, {&z}, {}, {}, {}); + ASSERT_EQ(Status::OK(), status); - ASSERT_EQ(e, z); + ASSERT_EQ(e, z); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, LogPoissonLoss_1) { - auto weights = NDArrayFactory::create('c', {1, 1}, {1}); + auto weights = NDArrayFactory::create('c', {1, 1}, {1}); - auto input = NDArrayFactory::create('c', {2, 2, 2}, { -1., 2. , 1.5, -1.4, 1., 2., 2., 1.}); - auto targets = NDArrayFactory::create('c', {2, 2, 2}, {1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0}); + auto input = NDArrayFactory::create( + 'c', {2, 2, 2}, {-1., 2., 1.5, -1.4, 1., 2., 2., 1.}); + auto targets = NDArrayFactory::create( + 'c', {2, 2, 2}, {1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0}); - auto exp = NDArrayFactory::create('c', {2, 2, 2}, {1.3678794, 5.389056, 2.981689, 1.6465969, 1.7182817, 5.389056, 5.389056, 1.7182817}); - - sd::ops::log_poisson_loss op; - auto results = op.evaluate({&input, &weights, &targets}, {}, {0}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + auto exp = NDArrayFactory::create( + 'c', {2, 2, 2}, + {1.3678794, 5.389056, 2.981689, 1.6465969, 1.7182817, 5.389056, 5.389056, + 1.7182817}); + sd::ops::log_poisson_loss op; + auto results = op.evaluate({&input, &weights, &targets}, {}, {0}); + auto output = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, LogPoissonLoss_2) { + auto weights = NDArrayFactory::create('c', {1, 1}, {1}); - auto weights = NDArrayFactory::create('c', {1, 1}, {1}); - - auto input = NDArrayFactory::create('c', {2, 2, 2}, { -1., 2. , 1.5, -1.4, 1., 2., 2., 1.}); - auto targets = NDArrayFactory::create('c', {2, 2, 2}, {2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0}); - - auto exp = NDArrayFactory::create('c', {2, 2, 2}, {3.0196857, 4.0408626, 2.1334953, 3.6984034, 1.3700882, 4.0408626, 4.0408626, 1.3700882}); - - sd::ops::log_poisson_loss op; - auto results = op.evaluate({&input, &weights, &targets}, {}, {0, 1}); - auto output = results.at(0); + auto input = NDArrayFactory::create( + 'c', {2, 2, 2}, {-1., 2., 1.5, -1.4, 1., 2., 2., 1.}); + auto targets = NDArrayFactory::create( + 'c', {2, 2, 2}, {2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0}); - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + auto exp = NDArrayFactory::create( + 'c', {2, 2, 2}, + {3.0196857, 4.0408626, 2.1334953, 3.6984034, 1.3700882, 4.0408626, + 4.0408626, 1.3700882}); + sd::ops::log_poisson_loss op; + auto results = op.evaluate({&input, &weights, &targets}, {}, {0, 1}); + auto output = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, NormalizeMoments_1) { + auto means = NDArrayFactory::create( + 'c', {2, 3, 4}, {11., 3., 14., 5., 6., 9., 3.5, 7., 21., 3., 14., 15., + 6., 9., 3.5, 7., 11., 13., 14., 5., 16., 9., 13.5, 7.}); - auto means = NDArrayFactory::create('c', {2, 3, 4}, { 11., 3., 14., 5., - 6., 9., 3.5, 7., - 21., 3., 14., 15., - 6., 9., 3.5, 7., - 11., 13., 14., 5., - 16., 9., 13.5, 7.}); - - auto deviance = NDArrayFactory::create('c', {2, 3, 4}, { 21., 13., 24., 15., - 16., 19., 13.5, 17., - 31., 13., 24., 25., - 16., 19., 13.5, 17., - 21., 23., 24., 15., - 26., 19., 23.5, 17.}); - - auto counts = NDArrayFactory::create(2.0); - - auto expMeans = NDArrayFactory::create('c', {2, 3, 4}, { - 5.5, 1.5, 7., 2.5, - 3., 4.5, 1.75, 3.5, - 10.5, 1.5, 7., 7.5, - 3., 4.5, 1.75, 3.5, - 5.5, 6.5, 7., 2.5, - 8., 4.5, 6.75, 3.5}); - - auto expDeviance = NDArrayFactory::create('c', {2, 3, 4}, { - -19.75, 4.25, -37., 1.25, - -1., -10.75, 3.6875, -3.75, - -94.75, 4.25, -37., -43.75, - -1., -10.75, 3.6875, -3.75, - -19.75, -30.75, -37., 1.25, - -51., -10.75, -33.8125, -3.75}); - - sd::ops::normalize_moments op; - auto results = op.evaluate({&counts, &means, &deviance}, {0.0}, {}); + auto deviance = NDArrayFactory::create( + 'c', {2, 3, 4}, + {21., 13., 24., 15., 16., 19., 13.5, 17., 31., 13., 24., 25., + 16., 19., 13.5, 17., 21., 23., 24., 15., 26., 19., 23.5, 17.}); - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_EQ(results.size(), 2); + auto counts = NDArrayFactory::create(2.0); + + auto expMeans = NDArrayFactory::create( + 'c', {2, 3, 4}, + {5.5, 1.5, 7., 2.5, 3., 4.5, 1.75, 3.5, 10.5, 1.5, 7., 7.5, + 3., 4.5, 1.75, 3.5, 5.5, 6.5, 7., 2.5, 8., 4.5, 6.75, 3.5}); - auto outputMeans = results.at(0); - auto outputDeviance = results.at(1); + auto expDeviance = NDArrayFactory::create( + 'c', {2, 3, 4}, + {-19.75, 4.25, -37., 1.25, -1., -10.75, 3.6875, -3.75, + -94.75, 4.25, -37., -43.75, -1., -10.75, 3.6875, -3.75, + -19.75, -30.75, -37., 1.25, -51., -10.75, -33.8125, -3.75}); - ASSERT_TRUE(expMeans.isSameShape(outputMeans)); - ASSERT_TRUE(expMeans.equalsTo(outputMeans)); - ASSERT_TRUE(expMeans.isSameShape(outputDeviance)); - ASSERT_TRUE(expDeviance.equalsTo(outputDeviance)); + sd::ops::normalize_moments op; + auto results = op.evaluate({&counts, &means, &deviance}, {0.0}, {}); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_EQ(results.size(), 2); + auto outputMeans = results.at(0); + auto outputDeviance = results.at(1); + + ASSERT_TRUE(expMeans.isSameShape(outputMeans)); + ASSERT_TRUE(expMeans.equalsTo(outputMeans)); + ASSERT_TRUE(expMeans.isSameShape(outputDeviance)); + ASSERT_TRUE(expDeviance.equalsTo(outputDeviance)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, NormalizeMoments_2) { + auto means = NDArrayFactory::create( + 'c', {3, 2, 4}, {11., 3., 14., 5., 6., 9., 3.5, 7., 21., 3., 14., 15., + 6., 9., 3.5, 7., 11., 13., 14., 5., 16., 9., 13.5, 7.}); - auto means = NDArrayFactory::create('c', {3, 2, 4}, { 11., 3., 14., 5., - 6., 9., 3.5, 7., - 21., 3., 14., 15., - 6., 9., 3.5, 7., - 11., 13., 14., 5., - 16., 9., 13.5, 7.}); - - auto deviance = NDArrayFactory::create('c', {3, 2, 4}, { 21., 13., 24., 15., - 16., 19., 13.5, 17., - 31., 13., 24., 25., - 16., 19., 13.5, 17., - 21., 23., 24., 15., - 26., 19., 23.5, 17.}); - - auto counts = NDArrayFactory::create(12.0); - - auto expMeans = NDArrayFactory::create('c', {3, 2, 4}, { 0.9166667, 0.25, 1.1666667, 0.4166667, - 0.5, 0.75, 0.2916667, 0.5833334, - 1.75, 0.25, 1.1666667, 1.25, - 0.5, 0.75, 0.2916667, 0.5833334, - 0.9166667, 1.0833334, 1.1666667, 0.4166667, - 1.3333334, 0.75, 1.125, 0.5833334}); - - auto expDeviance = NDArrayFactory::create('c', {3, 2, 4}, { - 0.9097222, 1.0208334, 0.6388887, 1.0763888, - 1.0833334, 1.0208334, 1.0399306, 1.076389, - -0.4791665, 1.0208334, 0.6388887, 0.5208335, - 1.0833334, 1.0208334, 1.0399306, 1.076389, - 0.9097222, 0.7430556, 0.6388887, 1.0763888, - 0.38888884, 1.0208334, 0.6927084, 1.076389}); - - sd::ops::normalize_moments op; - auto results = op.evaluate({&counts, &means, &deviance}, {0.0}, {}); + auto deviance = NDArrayFactory::create( + 'c', {3, 2, 4}, + {21., 13., 24., 15., 16., 19., 13.5, 17., 31., 13., 24., 25., + 16., 19., 13.5, 17., 21., 23., 24., 15., 26., 19., 23.5, 17.}); - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_EQ(results.size(), 2); + auto counts = NDArrayFactory::create(12.0); - auto outputMeans = results.at(0); - auto outputDeviance = results.at(1); + auto expMeans = NDArrayFactory::create( + 'c', {3, 2, 4}, + {0.9166667, 0.25, 1.1666667, 0.4166667, 0.5, 0.75, + 0.2916667, 0.5833334, 1.75, 0.25, 1.1666667, 1.25, + 0.5, 0.75, 0.2916667, 0.5833334, 0.9166667, 1.0833334, + 1.1666667, 0.4166667, 1.3333334, 0.75, 1.125, 0.5833334}); - ASSERT_TRUE(expMeans.isSameShape(outputMeans)); - ASSERT_TRUE(expMeans.equalsTo(outputMeans)); - ASSERT_TRUE(expMeans.isSameShape(outputDeviance)); - ASSERT_TRUE(expDeviance.equalsTo(outputDeviance)); + auto expDeviance = NDArrayFactory::create( + 'c', {3, 2, 4}, + {0.9097222, 1.0208334, 0.6388887, 1.0763888, 1.0833334, 1.0208334, + 1.0399306, 1.076389, -0.4791665, 1.0208334, 0.6388887, 0.5208335, + 1.0833334, 1.0208334, 1.0399306, 1.076389, 0.9097222, 0.7430556, + 0.6388887, 1.0763888, 0.38888884, 1.0208334, 0.6927084, 1.076389}); + sd::ops::normalize_moments op; + auto results = op.evaluate({&counts, &means, &deviance}, {0.0}, {}); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_EQ(results.size(), 2); + + auto outputMeans = results.at(0); + auto outputDeviance = results.at(1); + + ASSERT_TRUE(expMeans.isSameShape(outputMeans)); + ASSERT_TRUE(expMeans.equalsTo(outputMeans)); + ASSERT_TRUE(expMeans.isSameShape(outputDeviance)); + ASSERT_TRUE(expDeviance.equalsTo(outputDeviance)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, NormalizeMoments_3) { - - auto means = NDArrayFactory::create('c', {3, 2, 4}, { 11., 3., 14., 5., - 6., 9., 3.5, 7., - 21., 3., 14., 15., - 6., 9., 3.5, 7., - 11., 13., 14., 5., - 16., 9., 13.5, 7.}); - - auto deviance = NDArrayFactory::create('c', {3, 2, 4}, { 21., 13., 24., 15., - 16., 19., 13.5, 17., - 31., 13., 24., 25., - 16., 19., 13.5, 17., - 21., 23., 24., 15., - 26., 19., 23.5, 17.}); - - auto counts = NDArrayFactory::create(12.0); - double shift = 10.0; - auto expMeans = NDArrayFactory::create('c', {3, 2, 4}, { 10.9166667, 10.25, 11.1666667, 10.4166667, - 10.5, 10.75, 10.2916667, 10.5833334, - 11.75, 10.25, 11.1666667, 11.25, - 10.5, 10.75, 10.2916667, 10.5833334, - 10.9166667, 11.0833334, 11.1666667, 10.4166667, - 11.3333334, 10.75, 11.125, 10.5833334}); - - auto expDeviance = NDArrayFactory::create('c', {3, 2, 4}, { - 0.9097222, 1.0208334, 0.6388887, 1.0763888, - 1.0833334, 1.0208334, 1.0399306, 1.076389, - -0.4791665, 1.0208334, 0.6388887, 0.5208335, - 1.0833334, 1.0208334, 1.0399306, 1.076389, - 0.9097222, 0.7430556, 0.6388887, 1.0763888, - 0.38888884, 1.0208334, 0.6927084, 1.076389}); - - sd::ops::normalize_moments op; - auto results = op.evaluate({&counts, &means, &deviance}, {shift}, {}); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_EQ(results.size(), 2); - - auto outputMeans = results.at(0); - auto outputDeviance = results.at(1); - - ASSERT_TRUE(expMeans.isSameShape(outputMeans)); - ASSERT_TRUE(expMeans.equalsTo(outputMeans)); - ASSERT_TRUE(expMeans.isSameShape(outputDeviance)); - ASSERT_TRUE(expDeviance.equalsTo(outputDeviance)); - - + auto means = NDArrayFactory::create( + 'c', {3, 2, 4}, {11., 3., 14., 5., 6., 9., 3.5, 7., 21., 3., 14., 15., + 6., 9., 3.5, 7., 11., 13., 14., 5., 16., 9., 13.5, 7.}); + + auto deviance = NDArrayFactory::create( + 'c', {3, 2, 4}, + {21., 13., 24., 15., 16., 19., 13.5, 17., 31., 13., 24., 25., + 16., 19., 13.5, 17., 21., 23., 24., 15., 26., 19., 23.5, 17.}); + + auto counts = NDArrayFactory::create(12.0); + double shift = 10.0; + auto expMeans = NDArrayFactory::create( + 'c', {3, 2, 4}, + {10.9166667, 10.25, 11.1666667, 10.4166667, 10.5, 10.75, + 10.2916667, 10.5833334, 11.75, 10.25, 11.1666667, 11.25, + 10.5, 10.75, 10.2916667, 10.5833334, 10.9166667, 11.0833334, + 11.1666667, 10.4166667, 11.3333334, 10.75, 11.125, 10.5833334}); + + auto expDeviance = NDArrayFactory::create( + 'c', {3, 2, 4}, + {0.9097222, 1.0208334, 0.6388887, 1.0763888, 1.0833334, 1.0208334, + 1.0399306, 1.076389, -0.4791665, 1.0208334, 0.6388887, 0.5208335, + 1.0833334, 1.0208334, 1.0399306, 1.076389, 0.9097222, 0.7430556, + 0.6388887, 1.0763888, 0.38888884, 1.0208334, 0.6927084, 1.076389}); + + sd::ops::normalize_moments op; + auto results = op.evaluate({&counts, &means, &deviance}, {shift}, {}); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_EQ(results.size(), 2); + + auto outputMeans = results.at(0); + auto outputDeviance = results.at(1); + + ASSERT_TRUE(expMeans.isSameShape(outputMeans)); + ASSERT_TRUE(expMeans.equalsTo(outputMeans)); + ASSERT_TRUE(expMeans.isSameShape(outputDeviance)); + ASSERT_TRUE(expDeviance.equalsTo(outputDeviance)); } - diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests6.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests6.cpp index ed9dbee6877d..cc482adfa202 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests6.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests6.cpp @@ -18,784 +18,773 @@ // Created by raver119 on 09.02.18. // - -#include "testlayers.h" -#include -#include #include #include +#include +#include +#include "testlayers.h" using namespace sd; using namespace sd::graph; class DeclarableOpsTests6 : public testing::Test { -public: - - DeclarableOpsTests6() { - printf("\n"); - fflush(stdout); - } + public: + DeclarableOpsTests6() { + printf("\n"); + fflush(stdout); + } }; - TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_1) { - auto matrix = NDArrayFactory::create('c', {5, 2}); - auto b = NDArrayFactory::create('c', {1}, {0.}); - auto e = NDArrayFactory::create('c', {1}, {1}); - auto s = NDArrayFactory::create('c', {1}, {1}); - - auto exp = NDArrayFactory::create('c', {2}, {1.0f, 2.0f}); + auto matrix = NDArrayFactory::create('c', {5, 2}); + auto b = NDArrayFactory::create('c', {1}, {0.}); + auto e = NDArrayFactory::create('c', {1}, {1}); + auto s = NDArrayFactory::create('c', {1}, {1}); - matrix.linspace(1); + auto exp = NDArrayFactory::create('c', {2}, {1.0f, 2.0f}); - sd::ops::strided_slice op; - auto result = op.evaluate({&matrix, &b, &e, &s}, {}, {0, 0, 0, 0, 1}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); + matrix.linspace(1); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::strided_slice op; + auto result = op.evaluate({&matrix, &b, &e, &s}, {}, {0, 0, 0, 0, 1}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_2) { - auto matrix = NDArrayFactory::create('c', {5, 2}); - auto b = NDArrayFactory::create('c', {1}, {0.0f}); - auto e = NDArrayFactory::create('c', {1}, {1.0f}); - auto s = NDArrayFactory::create('c', {1}, {1.0f}); - - auto exp = NDArrayFactory::create('c', {2}, {1.0f, 2.0f}); + auto matrix = NDArrayFactory::create('c', {5, 2}); + auto b = NDArrayFactory::create('c', {1}, {0.0f}); + auto e = NDArrayFactory::create('c', {1}, {1.0f}); + auto s = NDArrayFactory::create('c', {1}, {1.0f}); - matrix.linspace(1); + auto exp = NDArrayFactory::create('c', {2}, {1.0f, 2.0f}); - sd::ops::strided_slice op; - auto result = op.evaluate({&matrix, &b, &e, &s}, {}, {0, 0, 0, 0, 1}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); + matrix.linspace(1); - ASSERT_EQ(exp, *z); + sd::ops::strided_slice op; + auto result = op.evaluate({&matrix, &b, &e, &s}, {}, {0, 0, 0, 0, 1}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + ASSERT_EQ(exp, z); } TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_3) { - auto matrix = NDArrayFactory::create(10); - auto b = NDArrayFactory::create(0); - auto e = NDArrayFactory::create(0); - auto s = NDArrayFactory::create(1.0); + auto matrix = NDArrayFactory::create(10); + auto b = NDArrayFactory::create(0); + auto e = NDArrayFactory::create(0); + auto s = NDArrayFactory::create(1.0); - //auto exp = NDArrayFactory::create('c', {2}, {1.0f, 2.0f}); + // auto exp = NDArrayFactory::create('c', {2}, {1.0f, 2.0f}); - //matrix.linspace(1); - - sd::ops::strided_slice op; - auto result = op.evaluate({&matrix, &b, &e, &s}, {}, {0, 0, 0, 0, 1}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - //z->printShapeInfo("SS OS shape"); - ASSERT_TRUE(z->isEmpty()); - //ASSERT_EQ(exp, *z); + // matrix.linspace(1); + sd::ops::strided_slice op; + auto result = op.evaluate({&matrix, &b, &e, &s}, {}, {0, 0, 0, 0, 1}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + // z->printShapeInfo("SS OS shape"); + ASSERT_TRUE(z.isEmpty()); + // ASSERT_EQ(exp, *z); } TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_4) { - auto matrix = NDArrayFactory::create('c', {1}, {10}); - auto b = NDArrayFactory::create('c', {1}, {0.}); - auto e = NDArrayFactory::create('c', {1}, {0.}); - auto s = NDArrayFactory::create('c', {1}, {1.0}); - - auto exp = NDArrayFactory::create(10); + auto matrix = NDArrayFactory::create('c', {1}, {10}); + auto b = NDArrayFactory::create('c', {1}, {0.}); + auto e = NDArrayFactory::create('c', {1}, {0.}); + auto s = NDArrayFactory::create('c', {1}, {1.0}); - //matrix.linspace(1); + auto exp = NDArrayFactory::create(10); - sd::ops::strided_slice op; - auto result = op.evaluate({&matrix, &b, &e, &s}, {}, {0, 0, 0, 0, 1}); - ASSERT_EQ(Status::OK(), result.status()); + // matrix.linspace(1); - auto z = result.at(0); - - ASSERT_TRUE(z->equalsTo(exp)); - //ASSERT_EQ(exp, *z); + sd::ops::strided_slice op; + auto result = op.evaluate({&matrix, &b, &e, &s}, {}, {0, 0, 0, 0, 1}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + ASSERT_TRUE(z.equalsTo(exp)); + // ASSERT_EQ(exp, *z); } TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_04) { - int z = 0; - auto matrix = NDArrayFactory::create('c', {1}, {10}); - auto b = NDArrayFactory::create_('c', {1}, {1}); - auto e = NDArrayFactory::create_('c', {1}, {z}); - auto s = NDArrayFactory::create_('c', {1}, {1}); - sd::ops::ones_as opOnes; - //auto exp = NDArrayFactory::create('c', {2}, {1.0f, 2.0f}); - auto onesRes = opOnes.evaluate({&matrix}); - //matrix.linspace(1); - ASSERT_EQ(onesRes.status(), Status::OK()); - - auto ones = onesRes.at(0); - *ones *= 10; - auto onesD = new NDArray(ones->dup()); - - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, onesD); - variableSpace->putVariable(-2, b); - variableSpace->putVariable(-3, e); - variableSpace->putVariable(-4, s); - auto block = new Context(1, variableSpace, false); // not-in-place - block->fillInputs({-1}); - block->fillInputs({-2}); - block->fillInputs({-3}); - block->fillInputs({-4}); - block->getIArguments()->push_back(0); - block->getIArguments()->push_back(0); - block->getIArguments()->push_back(1); - block->getIArguments()->push_back(0); - block->getIArguments()->push_back(0); - auto inputShapes = new ShapeList({ones->shapeInfo(), b->shapeInfo(), e->shapeInfo(), s->shapeInfo()}); - sd::ops::strided_slice op; - auto result = op.calculateOutputShape(inputShapes, *block); //execute({ones, &b, &e, &s}, {}, {0, 1, 0, 0, 0}); - ASSERT_EQ(result->size(), 1); - ASSERT_TRUE(shape::isEmpty(result->at(0))); - //ASSERT_EQ(exp, *z); - delete block; - delete result; - delete variableSpace; - delete inputShapes; + int z = 0; + auto matrix = NDArrayFactory::create('c', {1}, {10}); + auto b = NDArrayFactory::create('c', {1}, {1}); + auto e = NDArrayFactory::create('c', {1}, {z}); + auto s = NDArrayFactory::create('c', {1}, {1}); + sd::ops::ones_as opOnes; + // auto exp = NDArrayFactory::create('c', {2}, {1.0f, 2.0f}); + auto onesRes = opOnes.evaluate({&matrix}); + // matrix.linspace(1); + ASSERT_EQ(onesRes.status(), Status::OK()); + + auto ones = onesRes.at(0); + ones *= 10; + auto onesD = ones.dup(); + + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, onesD); + variableSpace->putVariable(-2, b); + variableSpace->putVariable(-3, e); + variableSpace->putVariable(-4, s); + auto block = new Context(1, variableSpace, false); // not-in-place + block->fillInputs({-1}); + block->fillInputs({-2}); + block->fillInputs({-3}); + block->fillInputs({-4}); + + block->appendI(0); + block->appendI(0); + block->appendI(1); + block->appendI(0); + block->appendI(0); + + auto inputShapes = new ShapeList( + {ones.shapeInfo(), b.shapeInfo(), e.shapeInfo(), s.shapeInfo()}); + + sd::ops::strided_slice op; + + auto result = op.calculateOutputShape( + inputShapes, + *block); // execute({ones, &b, &e, &s}, {}, {0, 1, 0, 0, 0}); + ASSERT_EQ(result->size(), 1); + ASSERT_TRUE(shape::isEmpty(result->at(0))); + // ASSERT_EQ(exp, *z); + delete block; + delete result; + delete variableSpace; + delete inputShapes; } TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_5) { - auto matrix = NDArrayFactory::create('c', {3, 2, 2}); - auto b = NDArrayFactory::create('c', {1}, {2}); - auto e = NDArrayFactory::create('c', {1}, {3}); - auto s = NDArrayFactory::create('c', {1}, {1}); - - auto exp = NDArrayFactory::create('c', {2,2}, {0.0f, 0.0f, 0., 0.}); - - //matrix.linspace(1); + auto matrix = NDArrayFactory::create('c', {3, 2, 2}); + auto b = NDArrayFactory::create('c', {1}, {2}); + auto e = NDArrayFactory::create('c', {1}, {3}); + auto s = NDArrayFactory::create('c', {1}, {1}); - sd::ops::strided_slice op; - auto result = op.evaluate({&matrix, &b, &e, &s}, {}, {0, 0, 0, 0, 1}); - ASSERT_EQ(Status::OK(), result.status()); + auto exp = NDArrayFactory::create('c', {2, 2}, {0.0f, 0.0f, 0., 0.}); - auto z = result.at(0); - ASSERT_TRUE(exp.equalsTo(z)); + // matrix.linspace(1); + sd::ops::strided_slice op; + auto result = op.evaluate({&matrix, &b, &e, &s}, {}, {0, 0, 0, 0, 1}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_6) { - auto matrix = NDArrayFactory::create('c', {3, 2, 2}); - auto b = NDArrayFactory::create('c', {1}, {2}); - auto e = NDArrayFactory::create('c', {1}, {3}); - auto s = NDArrayFactory::create('c', {1}, {1}); - - auto exp = NDArrayFactory::create('c', {1,2,2}, {0.0f, 0.0f, 0., 0.}); - - //matrix.linspace(1); + auto matrix = NDArrayFactory::create('c', {3, 2, 2}); + auto b = NDArrayFactory::create('c', {1}, {2}); + auto e = NDArrayFactory::create('c', {1}, {3}); + auto s = NDArrayFactory::create('c', {1}, {1}); - sd::ops::strided_slice op; - auto result = op.evaluate({&matrix, &b, &e, &s}, {}, {0, 0, 0, 0, 2}); - ASSERT_EQ(Status::OK(), result.status()); + auto exp = + NDArrayFactory::create('c', {1, 2, 2}, {0.0f, 0.0f, 0., 0.}); - auto z = result.at(0); - ASSERT_TRUE(exp.equalsTo(z)); + // matrix.linspace(1); + sd::ops::strided_slice op; + auto result = op.evaluate({&matrix, &b, &e, &s}, {}, {0, 0, 0, 0, 2}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_7) { - int zero = 0; - auto matrix = NDArrayFactory::create('c', {5, 4}); - auto b = NDArrayFactory::create('c', {1}, {zero}); - auto e = NDArrayFactory::create('c', {1}, {zero}); - auto s = NDArrayFactory::create('c', {1}, {1}); - - //auto exp = NDArrayFactory::create('c', {1,2,2}, {0.0f, 0.0f, 0., 0.}); - - //matrix.linspace(1); + int zero = 0; + auto matrix = NDArrayFactory::create('c', {5, 4}); + auto b = NDArrayFactory::create('c', {1}, {zero}); + auto e = NDArrayFactory::create('c', {1}, {zero}); + auto s = NDArrayFactory::create('c', {1}, {1}); - sd::ops::strided_slice op; - auto result = op.evaluate({&matrix, &b, &e, &s}, {}, {1, 0, 0, 0, 0}); - ASSERT_EQ(Status::OK(), result.status()); + // auto exp = NDArrayFactory::create('c', {1,2,2}, {0.0f, 0.0f, 0., + // 0.}); - auto z = result.at(0); - //ASSERT_TRUE(exp.equalsTo(z)); + // matrix.linspace(1); + sd::ops::strided_slice op; + auto result = op.evaluate({&matrix, &b, &e, &s}, {}, {1, 0, 0, 0, 0}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + // ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests6, Test_StridedSlice_BP_1) { - int zero = 0; - auto matrix = NDArrayFactory::create('c', {5, 4}); -// auto b = NDArrayFactory::create('c', {1}, {zero}); -// auto e = NDArrayFactory::create('c', {1}, {zero}); -// auto s = NDArrayFactory::create('c', {1}, {1}); - - auto grad = NDArrayFactory::create('c', {5}); + int zero = 0; + auto matrix = NDArrayFactory::create('c', {5, 4}); + // auto b = NDArrayFactory::create('c', {1}, {zero}); + // auto e = NDArrayFactory::create('c', {1}, {zero}); + // auto s = NDArrayFactory::create('c', {1}, {1}); - matrix.linspace(1); - grad.linspace(1); + auto grad = NDArrayFactory::create('c', {5}); - sd::ops::strided_slice_bp op; - auto result = op.evaluate({&matrix, &grad}, {}, {1, 0, 1, 0, 2, 0, 0, 0, 1, 1, 1}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - //ASSERT_TRUE(exp.equalsTo(z)); + matrix.linspace(1); + grad.linspace(1); + sd::ops::strided_slice_bp op; + auto result = + op.evaluate({&matrix, &grad}, {}, {1, 0, 1, 0, 2, 0, 0, 0, 1, 1, 1}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + // ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests6, Test_StridedSlice_BP_2) { - int zero = 0; - auto matrix = NDArrayFactory::create('c', {1, 2}); -// auto b = NDArrayFactory::create('c', {1}, {zero}); -// auto e = NDArrayFactory::create('c', {1}, {zero}); -// auto s = NDArrayFactory::create('c', {1}, {1}); + int zero = 0; + auto matrix = NDArrayFactory::create('c', {1, 2}); + // auto b = NDArrayFactory::create('c', {1}, {zero}); + // auto e = NDArrayFactory::create('c', {1}, {zero}); + // auto s = NDArrayFactory::create('c', {1}, {1}); - auto grad = NDArrayFactory::create('c', {1}, {1.}); + auto grad = NDArrayFactory::create('c', {1}, {1.}); - matrix.linspace(1); - //grad.linspace(1); - - sd::ops::strided_slice_bp op; - auto result = op.evaluate({&matrix, &grad}, {}, {1, 0, 1, 0, 2, 0, 0, 0, 1, 1, 1}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - //ASSERT_TRUE(exp.equalsTo(z)); + matrix.linspace(1); + // grad.linspace(1); + sd::ops::strided_slice_bp op; + auto result = + op.evaluate({&matrix, &grad}, {}, {1, 0, 1, 0, 2, 0, 0, 0, 1, 1, 1}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + // ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests6, Test_StridedSlice_BP_3) { - int zero = 0; - auto matrix = NDArrayFactory::create('c', {4, 8192}); -// auto b = NDArrayFactory::create('c', {1}, {zero}); -// auto e = NDArrayFactory::create('c', {1}, {zero}); -// auto s = NDArrayFactory::create('c', {1}, {1}); - - auto grad = NDArrayFactory::create('c', {4, 256}); + int zero = 0; + auto matrix = NDArrayFactory::create('c', {4, 8192}); + // auto b = NDArrayFactory::create('c', {1}, {zero}); + // auto e = NDArrayFactory::create('c', {1}, {zero}); + // auto s = NDArrayFactory::create('c', {1}, {1}); - matrix.linspace(1); - grad.linspace(1); + auto grad = NDArrayFactory::create('c', {4, 256}); - sd::ops::strided_slice_bp op; - auto result = op.evaluate({&matrix, &grad}, {}, {1, 0, 1, 0, 0, 0, 0, 0, 256, 1, 1}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - //ASSERT_TRUE(exp.equalsTo(z)); + matrix.linspace(1); + grad.linspace(1); + sd::ops::strided_slice_bp op; + auto result = + op.evaluate({&matrix, &grad}, {}, {1, 0, 1, 0, 0, 0, 0, 0, 256, 1, 1}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + // ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests6, Test_Simple_Scalar_1) { - auto x = NDArrayFactory::create('c', {1, 1}, {2.0f}); - auto exp = NDArrayFactory::create('c', {1, 1}, {4.0f}); - - sd::ops::test_scalar op; - auto result = op.evaluate({&x}, {}, {}); + auto x = NDArrayFactory::create('c', {1, 1}, {2.0f}); + auto exp = NDArrayFactory::create('c', {1, 1}, {4.0f}); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::test_scalar op; + auto result = op.evaluate({&x}, {}, {}); - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests6, Test_Order_1) { - auto x = NDArrayFactory::create('f', {2, 3}); - auto exp = NDArrayFactory::create('c', {2, 3}); - x.linspace(1); - exp.linspace(1); - - sd::ops::order op; - auto result = op.evaluate({&x}, {}, {0}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - ASSERT_TRUE(exp.equalsTo(z)); - ASSERT_NE(x.ordering(), z->ordering()); + auto x = NDArrayFactory::create('f', {2, 3}); + auto exp = NDArrayFactory::create('c', {2, 3}); + x.linspace(1); + exp.linspace(1); + sd::ops::order op; + auto result = op.evaluate({&x}, {}, {0}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_NE(x.ordering(), z.ordering()); } TEST_F(DeclarableOpsTests6, cumSum_1) { - auto x = NDArrayFactory::create('c', {1, 4}, {1.f, 2.f, 3.f, 4.f}); - auto exp = NDArrayFactory::create('c', {1, 4}, {1.f, 3.f, 6.f, 10.f}); - - sd::ops::cumsum op; - auto result = op.evaluate({&x}, {}, {0, 0}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); + auto x = NDArrayFactory::create('c', {1, 4}, {1.f, 2.f, 3.f, 4.f}); + auto exp = NDArrayFactory::create('c', {1, 4}, {1.f, 3.f, 6.f, 10.f}); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::cumsum op; + auto result = op.evaluate({&x}, {}, {0, 0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests6, cumSum_2) { - auto x= NDArrayFactory::create('c', {2, 4}, {1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f}); - auto exp= NDArrayFactory::create('c', {2, 4}, {1.f, 3.f, 6.f, 10.f, 1.f, 3.f, 6.f, 10.f}); - - sd::ops::cumsum op; - auto result = op.evaluate({&x}, {}, {0, 0, 1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto x = NDArrayFactory::create( + 'c', {2, 4}, {1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f}); + auto exp = NDArrayFactory::create( + 'c', {2, 4}, {1.f, 3.f, 6.f, 10.f, 1.f, 3.f, 6.f, 10.f}); - auto z = result.at(0); - - // z->printIndexedBuffer("CumSum1"); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::cumsum op; + auto result = op.evaluate({&x}, {}, {0, 0, 1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + // z->printIndexedBuffer("CumSum1"); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests6, cumSum_3) { - auto x= NDArrayFactory::create('c', {2, 4}, {1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f}); - auto exp= NDArrayFactory::create('c', {2, 4}, {1.f, 2.f, 3.f, 4.f, 2.f, 4.f, 6.f, 8.f}); + auto x = NDArrayFactory::create( + 'c', {2, 4}, {1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f}); + auto exp = NDArrayFactory::create( + 'c', {2, 4}, {1.f, 2.f, 3.f, 4.f, 2.f, 4.f, 6.f, 8.f}); - sd::ops::cumsum op; - auto result = op.evaluate({&x}, {}, {0, 0, 0}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::cumsum op; + auto result = op.evaluate({&x}, {}, {0, 0, 0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests6, cumSum_4) { - auto x = NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); - auto exp = NDArrayFactory::create('c', {3, 3}, {12., 15., 18., 11., 13., 15., 7., 8., 9.}); - - sd::ops::cumsum op; - auto result = op.evaluate({&x}, {}, {0, 1, 0}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - // z->printBuffer(); + auto x = + NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto exp = NDArrayFactory::create( + 'c', {3, 3}, {12., 15., 18., 11., 13., 15., 7., 8., 9.}); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::cumsum op; + auto result = op.evaluate({&x}, {}, {0, 1, 0}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + // z->printBuffer(); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests6, cumSum_5) { - auto x = NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); - auto exp = NDArrayFactory::create('c', {3, 3}, {6.f, 5.f, 3.f, 15.f, 11.f, 6.f, 24.f, 17.f, 9.f,}); - - sd::ops::cumsum op; - auto result = op.evaluate({&x}, {}, {0, 1, 1}, {}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(exp.equalsTo(z)); - - + auto x = + NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto exp = NDArrayFactory::create('c', {3, 3}, + { + 6.f, + 5.f, + 3.f, + 15.f, + 11.f, + 6.f, + 24.f, + 17.f, + 9.f, + }); + + sd::ops::cumsum op; + auto result = op.evaluate({&x}, {}, {0, 1, 1}, {}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests6, cumSum_6) { - auto x = NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); - auto exp = NDArrayFactory::create('c', {3, 3}, {11.f, 13.f, 15.f, 7.f, 8.f, 9.f, 0.f, 0.f, 0.f}); - - sd::ops::cumsum op; - auto result = op.evaluate({&x}, {}, {1, 1, 0}, {}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); + auto x = + NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto exp = NDArrayFactory::create( + 'c', {3, 3}, {11.f, 13.f, 15.f, 7.f, 8.f, 9.f, 0.f, 0.f, 0.f}); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::cumsum op; + auto result = op.evaluate({&x}, {}, {1, 1, 0}, {}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests6, cumSum_7) { - auto x = NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); - auto exp = NDArrayFactory::create('c', {3, 3}, {5.f, 3.f, 0.f, 11.f, 6.f, 0.f, 17.f, 9.f, 0.f}); - - sd::ops::cumsum op; - auto result = op.evaluate({&x}, {}, {1, 1, 1}, {}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); + auto x = + NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto exp = NDArrayFactory::create( + 'c', {3, 3}, {5.f, 3.f, 0.f, 11.f, 6.f, 0.f, 17.f, 9.f, 0.f}); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::cumsum op; + auto result = op.evaluate({&x}, {}, {1, 1, 1}, {}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests6, cumSum_8) { - auto x = NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); - auto axis = NDArrayFactory::create('c', {1}, {1}); - auto exp = NDArrayFactory::create('c', {3, 3}, {5.f, 3.f, 0.f, 11.f, 6.f, 0.f, 17.f, 9.f, 0.f}); - - sd::ops::cumsum op; - auto result = op.evaluate({&x, &axis}, {}, {1, 1}, {}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); + auto x = + NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto axis = NDArrayFactory::create('c', {1}, {1}); + auto exp = NDArrayFactory::create( + 'c', {3, 3}, {5.f, 3.f, 0.f, 11.f, 6.f, 0.f, 17.f, 9.f, 0.f}); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::cumsum op; + auto result = op.evaluate({&x, &axis}, {}, {1, 1}, {}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, cumSum_9) { - - auto inputC = NDArrayFactory::create('c', {3, 5}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.}); - auto axis = NDArrayFactory::create(1); - - auto expFF = NDArrayFactory::create('c', {3, 5}, {1., 3., 6., 10., 15., 6., 13., 21., 30., 40., 11., 23., 36., 50., 65.}); - auto expTF = NDArrayFactory::create('c', {3, 5}, {0., 1., 3., 6., 10., 0., 6., 13., 21., 30., 0., 11., 23., 36., 50.}); - - auto expFT = NDArrayFactory::create('c', {3, 5}, {15, 14, 12, 9, 5,40, 34, 27, 19, 10,65, 54, 42, 29, 15}); //+++ - auto expTT = NDArrayFactory::create('c', {3, 5}, {14, 12, 9, 5, 0,34, 27, 19, 10, 0,54, 42, 29, 15, 0}); - - int exclusive, reverse; - - //************************************// - exclusive = 0; reverse = 0; - - sd::ops::cumsum op; - auto result = op.evaluate({&inputC, &axis}, {}, {exclusive, reverse}, {}); - ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); - ASSERT_TRUE(expFF.equalsTo(z)); - - - //************************************// - exclusive = 1; reverse = 0; - - result = op.evaluate({&inputC, &axis}, {}, {exclusive, reverse}); - ASSERT_EQ(Status::OK(), result.status()); - z = result.at(0); - ASSERT_TRUE(expTF.equalsTo(z)); - - - //************************************// - exclusive = 0; reverse = 1; - - result = op.evaluate({&inputC, &axis}, {}, {exclusive, reverse}); - ASSERT_EQ(Status::OK(), result.status()); - z = result.at(0); - ASSERT_TRUE(expFT.equalsTo(z)); - - - //************************************// - exclusive = 1; reverse = 1; - - result = op.evaluate({&inputC, &axis}, {}, {exclusive, reverse}); - ASSERT_EQ(Status::OK(), result.status()); - z = result.at(0); - ASSERT_TRUE(expTT.equalsTo(z)); - - + auto inputC = NDArrayFactory::create( + 'c', {3, 5}, + {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.}); + auto axis = NDArrayFactory::create(1); + + auto expFF = NDArrayFactory::create( + 'c', {3, 5}, + {1., 3., 6., 10., 15., 6., 13., 21., 30., 40., 11., 23., 36., 50., 65.}); + auto expTF = NDArrayFactory::create( + 'c', {3, 5}, + {0., 1., 3., 6., 10., 0., 6., 13., 21., 30., 0., 11., 23., 36., 50.}); + + auto expFT = NDArrayFactory::create( + 'c', {3, 5}, + {15, 14, 12, 9, 5, 40, 34, 27, 19, 10, 65, 54, 42, 29, 15}); //+++ + auto expTT = NDArrayFactory::create( + 'c', {3, 5}, {14, 12, 9, 5, 0, 34, 27, 19, 10, 0, 54, 42, 29, 15, 0}); + + int exclusive, reverse; + + //************************************// + exclusive = 0; + reverse = 0; + + sd::ops::cumsum op; + auto result = op.evaluate({&inputC, &axis}, {}, {exclusive, reverse}, {}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + ASSERT_TRUE(expFF.equalsTo(z)); + + //************************************// + exclusive = 1; + reverse = 0; + + result = op.evaluate({&inputC, &axis}, {}, {exclusive, reverse}); + ASSERT_EQ(Status::OK(), result.status()); + z = result.at(0); + ASSERT_TRUE(expTF.equalsTo(z)); + + //************************************// + exclusive = 0; + reverse = 1; + + result = op.evaluate({&inputC, &axis}, {}, {exclusive, reverse}); + ASSERT_EQ(Status::OK(), result.status()); + z = result.at(0); + ASSERT_TRUE(expFT.equalsTo(z)); + + //************************************// + exclusive = 1; + reverse = 1; + + result = op.evaluate({&inputC, &axis}, {}, {exclusive, reverse}); + ASSERT_EQ(Status::OK(), result.status()); + z = result.at(0); + ASSERT_TRUE(expTT.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, cumSum_10) { - auto x = NDArrayFactory::create('c', {4, 16, 16, 1}); - auto y = NDArrayFactory::create(-3); - - sd::ops::cumsum op; - auto result = op.evaluate({&x, &y}, {}, {1, 1}); - ASSERT_EQ(Status::OK(), result.status()); - + auto x = NDArrayFactory::create('c', {4, 16, 16, 1}); + auto y = NDArrayFactory::create(-3); + sd::ops::cumsum op; + auto result = op.evaluate({&x, &y}, {}, {1, 1}); + ASSERT_EQ(Status::OK(), result.status()); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, cumSum_11) { + NDArray x('c', {3, 3, 3}, sd::DataType::DOUBLE); + auto exp = NDArrayFactory::create( + 'c', {3, 3, 3}, + {12., 15., 18., 11., 13., 15., 7., 8., 9., 39., 42., 45., 29., 31., + 33., 16., 17., 18., 66., 69., 72., 47., 49., 51., 25., 26., 27.}); - NDArray x('c', {3, 3, 3}, sd::DataType::DOUBLE); - auto exp = NDArrayFactory::create('c', {3,3,3}, {12., 15., 18.,11., 13., 15.,7., 8., 9., 39., 42., 45.,29., 31., 33.,16., 17., 18., 66., 69., 72.,47., 49., 51.,25., 26., 27.}); - - x.linspace(1); - - sd::ops::cumsum op; - auto result = op.evaluate({&x}, {}, {0, 1, 1}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); + x.linspace(1); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::cumsum op; + auto result = op.evaluate({&x}, {}, {0, 1, 1}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, cumSum_12) { + NDArray x('c', {3, 3, 3}, sd::DataType::DOUBLE); + auto exp = NDArrayFactory::create( + 'c', {3, 3, 3}, + {1., 2., 3., 5., 7., 9., 12., 15., 18., 10., 11., 12., 23., 25., + 27., 39., 42., 45., 19., 20., 21., 41., 43., 45., 66., 69., 72.}); - NDArray x('c', {3, 3, 3}, sd::DataType::DOUBLE); - auto exp = NDArrayFactory::create('c', {3,3,3}, {1., 2., 3.,5., 7., 9.,12., 15., 18., 10., 11., 12.,23., 25., 27.,39., 42., 45., 19., 20., 21.,41., 43., 45., 66., 69., 72.}); - - x.linspace(1); - - sd::ops::cumsum op; - auto result = op.evaluate({&x}, {}, {0, 0, 1}); - ASSERT_EQ(Status::OK(), result.status()); + x.linspace(1); - auto z = result.at(0); - - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::cumsum op; + auto result = op.evaluate({&x}, {}, {0, 0, 1}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, cumSum_13) { + NDArray x('c', {3, 3, 3}, sd::DataType::DOUBLE); + auto exp = NDArrayFactory::create( + 'c', {3, 3, 3}, + {11., 13., 15., 7., 8., 9., 0., 0., 0., 29., 31., 33., 16., 17., + 18., 0., 0., 0., 47., 49., 51., 25., 26., 27., 0., 0., 0.}); - NDArray x('c', {3, 3, 3}, sd::DataType::DOUBLE); - auto exp = NDArrayFactory::create('c', {3,3,3}, {11., 13., 15.,7., 8., 9.,0., 0., 0., 29., 31., 33.,16., 17., 18.,0., 0., 0., 47., 49., 51.,25., 26., 27.,0., 0., 0.}); - - x.linspace(1); + x.linspace(1); - sd::ops::cumsum op; - auto result = op.evaluate({&x}, {}, {1, 1, 1}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::cumsum op; + auto result = op.evaluate({&x}, {}, {1, 1, 1}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, cumSum_14) { + NDArray x('c', {3, 3, 3}, sd::DataType::DOUBLE); + auto exp = NDArrayFactory::create( + 'c', {3, 3, 3}, + {29., 31., 33., 35., 37., 39., 41., 43., 45., 19., 20., 21., 22., 23., + 24., 25., 26., 27., 0., 0., 0., 0., 0., 0., 0., 0., 0.}); - NDArray x('c', {3, 3, 3}, sd::DataType::DOUBLE); - auto exp = NDArrayFactory::create('c', {3,3,3}, {29., 31., 33.,35., 37., 39.,41., 43., 45., 19., 20., 21.,22., 23., 24.,25., 26., 27., 0., 0., 0.,0., 0., 0.,0., 0., 0.}); + x.linspace(1); - x.linspace(1); - - sd::ops::cumsum op; - auto result = op.evaluate({&x}, {}, {1, 1, 0}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::cumsum op; + auto result = op.evaluate({&x}, {}, {1, 1, 0}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, cumSum_15) { + NDArray x('c', {3, 3, 3}, sd::DataType::DOUBLE); + auto exp = NDArrayFactory::create( + 'c', {3, 3, 3}, + {6., 5., 3., 15., 11., 6., 24., 17., 9., 33., 23., 12., 42., 29., + 15., 51., 35., 18., 60., 41., 21., 69., 47., 24., 78., 53., 27.}); - NDArray x('c', {3, 3, 3}, sd::DataType::DOUBLE); - auto exp = NDArrayFactory::create('c', {3,3,3}, {6., 5., 3.,15., 11., 6.,24., 17., 9., 33., 23., 12.,42., 29., 15.,51., 35., 18., 60., 41., 21.,69., 47., 24.,78., 53., 27.}); - - x.linspace(1); - - sd::ops::cumsum op; - auto result = op.evaluate({&x}, {}, {0, 1, 2}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); + x.linspace(1); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::cumsum op; + auto result = op.evaluate({&x}, {}, {0, 1, 2}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, cumSum_16) { + NDArray x('f', {3, 4}, sd::DataType::FLOAT32); - NDArray x('f', {3, 4}, sd::DataType::FLOAT32); - - sd::ops::cumsum op; - auto result = op.evaluate({&x}, {}, {0, 0, 1}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - // z->printShapeInfo(); - // x.printShapeInfo(); - - ASSERT_TRUE(z->ews() == 1); - ASSERT_TRUE(x.ews() == 1); + sd::ops::cumsum op; + auto result = op.evaluate({&x}, {}, {0, 0, 1}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + // z->printShapeInfo(); + // x.printShapeInfo(); + ASSERT_TRUE(z.ews() == 1); + ASSERT_TRUE(x.ews() == 1); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, cumSum_17) { + NDArray x('c', {2, 1500}, sd::DataType::FLOAT32); + NDArray x0 = x(0, {0}); + NDArray x1 = x(1, {0}); + x0.linspace(1); + x1.linspace(1); - NDArray x('c', {2, 1500}, sd::DataType::FLOAT32); - NDArray x0 = x(0, {0}); - NDArray x1 = x(1, {0}); - x0.linspace(1); - x1.linspace(1); - - NDArray exp('c', {2, 1500}, sd::DataType::FLOAT32); - NDArray exp0 = exp(0, {0}); - NDArray exp1 = exp(1, {0}); - - exp0.p(0, 1.); - exp1.p(0, 1.); + NDArray exp('c', {2, 1500}, sd::DataType::FLOAT32); + NDArray exp0 = exp(0, {0}); + NDArray exp1 = exp(1, {0}); - for (int i = 1; i < 1500; ++i) { - const auto prev = exp0.e(i-1); - exp0.p(i, prev + i + 1); - exp1.p(i, prev + i + 1); - } + exp0.p(0, 1.); + exp1.p(0, 1.); - sd::ops::cumsum op; - auto result = op.evaluate({&x}, {}, {0, 0, 1}); - ASSERT_EQ(Status::OK(), result.status()); + for (int i = 1; i < 1500; ++i) { + const auto prev = exp0.e(i - 1); + exp0.p(i, prev + i + 1); + exp1.p(i, prev + i + 1); + } - auto z = result.at(0); - - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::cumsum op; + auto result = op.evaluate({&x}, {}, {0, 0, 1}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, cumSum_18) { + NDArray x('c', {2, 1500}, sd::DataType::FLOAT32); + NDArray x0 = x(0, {0}); + NDArray x1 = x(1, {0}); + x0.linspace(1); + x1.linspace(1); - NDArray x('c', {2, 1500}, sd::DataType::FLOAT32); - NDArray x0 = x(0, {0}); - NDArray x1 = x(1, {0}); - x0.linspace(1); - x1.linspace(1); - - NDArray exp('c', {2, 1500}, sd::DataType::FLOAT32); - NDArray exp0 = exp(0, {0}); - NDArray exp1 = exp(1, {0}); + NDArray exp('c', {2, 1500}, sd::DataType::FLOAT32); + NDArray exp0 = exp(0, {0}); + NDArray exp1 = exp(1, {0}); - exp0.p(0, 0.); - exp1.p(0, 0.); + exp0.p(0, 0.); + exp1.p(0, 0.); - for (int i = 1; i < 1500; ++i) { - const auto prev = exp0.e(i-1); - exp0.p(i, prev + i); - exp1.p(i, prev + i); - } + for (int i = 1; i < 1500; ++i) { + const auto prev = exp0.e(i - 1); + exp0.p(i, prev + i); + exp1.p(i, prev + i); + } - sd::ops::cumsum op; - auto result = op.evaluate({&x}, {}, {1, 0, 1}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::cumsum op; + auto result = op.evaluate({&x}, {}, {1, 0, 1}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, cumSum_19) { + NDArray x('c', {2, 1500}, sd::DataType::FLOAT32); + NDArray x0 = x(0, {0}); + NDArray x1 = x(1, {0}); + x0.linspace(1); + x1.linspace(1); - NDArray x('c', {2, 1500}, sd::DataType::FLOAT32); - NDArray x0 = x(0, {0}); - NDArray x1 = x(1, {0}); - x0.linspace(1); - x1.linspace(1); + NDArray exp('c', {2, 1500}, sd::DataType::FLOAT32); + NDArray exp0 = exp(0, {0}); + NDArray exp1 = exp(1, {0}); - NDArray exp('c', {2, 1500}, sd::DataType::FLOAT32); - NDArray exp0 = exp(0, {0}); - NDArray exp1 = exp(1, {0}); + exp0.p(1499, 1500.f); + exp1.p(1499, 1500.f); - exp0.p(1499, 1500.f); - exp1.p(1499, 1500.f); + for (int i = 1498; i >= 0; --i) { + const auto prev = exp0.e(i + 1); + exp0.p(i, prev + i + 1); + exp1.p(i, prev + i + 1); + } - for (int i = 1498; i >= 0; --i) { - const auto prev = exp0.e(i + 1); - exp0.p(i, prev + i + 1); - exp1.p(i, prev + i + 1); - } - - sd::ops::cumsum op; - auto result = op.evaluate({&x}, {}, {0, 1, 1}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - // exp0.printBuffer(); - - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::cumsum op; + auto result = op.evaluate({&x}, {}, {0, 1, 1}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + // exp0.printBuffer(); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, cumSum_20) { + NDArray x('c', {2, 1500}, sd::DataType::FLOAT32); + NDArray x0 = x(0, {0}); + NDArray x1 = x(1, {0}); + x0.linspace(1); + x1.linspace(1); - NDArray x('c', {2, 1500}, sd::DataType::FLOAT32); - NDArray x0 = x(0, {0}); - NDArray x1 = x(1, {0}); - x0.linspace(1); - x1.linspace(1); - - NDArray exp('c', {2, 1500}, sd::DataType::FLOAT32); - NDArray exp0 = exp(0, {0}); - NDArray exp1 = exp(1, {0}); - - exp0.p(1499, 0.); - exp1.p(1499, 0.); - - for (int i = 1498; i >= 0; --i) { - const auto prev = exp0.e(i + 1); - exp0.p(i, prev + i + 2); - exp1.p(i, prev + i + 2); - } + NDArray exp('c', {2, 1500}, sd::DataType::FLOAT32); + NDArray exp0 = exp(0, {0}); + NDArray exp1 = exp(1, {0}); - sd::ops::cumsum op; - auto result = op.evaluate({&x}, {}, {1, 1, 1}); - ASSERT_EQ(Status::OK(), result.status()); + exp0.p(1499, 0.); + exp1.p(1499, 0.); - auto z = result.at(0); + for (int i = 1498; i >= 0; --i) { + const auto prev = exp0.e(i + 1); + exp0.p(i, prev + i + 2); + exp1.p(i, prev + i + 2); + } - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::cumsum op; + auto result = op.evaluate({&x}, {}, {1, 1, 1}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, TestMergeMaxIndex_1) { - - auto x = NDArrayFactory::create('c', {2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f}); - auto y = NDArrayFactory::create('c', {2, 2, 2}, {10.f, 2.f, 30.f, 4.f, 50.f, 6.f, 70.f, 8.f}); - auto z = NDArrayFactory::create('c', {2, 2, 2}, {1.f, 20.f, 3.f, 40.f, 5.f, 60.f, 7.f, 80.f}); - auto exp = NDArrayFactory::create('c', {2, 2, 2}, {1, 2, 1, 2, 1, 2, 1, 2}); - sd::ops::mergemaxindex op; - - auto res = op.evaluate({&x, &y, &z}, {}, {}, {}); + auto x = NDArrayFactory::create( + 'c', {2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f}); + auto y = NDArrayFactory::create( + 'c', {2, 2, 2}, {10.f, 2.f, 30.f, 4.f, 50.f, 6.f, 70.f, 8.f}); + auto z = NDArrayFactory::create( + 'c', {2, 2, 2}, {1.f, 20.f, 3.f, 40.f, 5.f, 60.f, 7.f, 80.f}); + auto exp = + NDArrayFactory::create('c', {2, 2, 2}, {1, 2, 1, 2, 1, 2, 1, 2}); + sd::ops::mergemaxindex op; + + auto res = op.evaluate({&x, &y, &z}, {}, {}, {}); ASSERT_EQ(ND4J_STATUS_OK, res.status()); - ASSERT_TRUE(res.at(0)->equalsTo(exp)); + ASSERT_TRUE(res.at(0).equalsTo(exp)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, TestMergeMaxIndex_2) { - - auto x = NDArrayFactory::create('c', {2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 60.f, 7.f, 8.f}); - auto y = NDArrayFactory::create('c', {2, 2, 2}, {10.f, 2.f, 30.f, 4.f, 50.f, 6.f, 70.f, 8.f}); - auto z = NDArrayFactory::create('c', {2, 2, 2}, {1.f, 20.f, 3.f, 40.f, 5.f, 6.f, 7.f, 80.f}); - auto exp = NDArrayFactory::create('c', {2, 2, 2}, {1, 2, 1, 2, 1, 0, 1, 2}); - sd::ops::mergemaxindex op; - - auto ress = op.evaluate({&x, &y, &z}, {}, {sd::DataType::INT64}); + auto x = NDArrayFactory::create( + 'c', {2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 60.f, 7.f, 8.f}); + auto y = NDArrayFactory::create( + 'c', {2, 2, 2}, {10.f, 2.f, 30.f, 4.f, 50.f, 6.f, 70.f, 8.f}); + auto z = NDArrayFactory::create( + 'c', {2, 2, 2}, {1.f, 20.f, 3.f, 40.f, 5.f, 6.f, 7.f, 80.f}); + auto exp = NDArrayFactory::create('c', {2, 2, 2}, + {1, 2, 1, 2, 1, 0, 1, 2}); + sd::ops::mergemaxindex op; + + auto ress = op.evaluate({&x, &y, &z}, {}, {sd::DataType::INT64}); ASSERT_EQ(ND4J_STATUS_OK, ress.status()); - ASSERT_TRUE(ress.at(0)->equalsTo(exp)); + ASSERT_TRUE(ress.at(0).equalsTo(exp)); } @@ -817,853 +806,793 @@ TEST_F(DeclarableOpsTests6, TestMergeMaxIndex_3) { //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, TestDropout_1) { + auto x = NDArrayFactory::create( + 'c', {2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f}); + auto shape = NDArrayFactory::create({2, 2}); + sd::ops::dropout op; - auto x = NDArrayFactory::create('c', {2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f}); - auto shape = NDArrayFactory::create({2, 2}); - sd::ops::dropout op; - - auto res = op.evaluate({&x, &shape}, {0.2f}, {113}); - - ASSERT_EQ(ND4J_STATUS_OK, res.status()); - //res.at(0)->printIndexedBuffer("Result is "); - //x.printIndexedBuffer("Input is"); - + auto res = op.evaluate({&x, &shape}, {0.2f}, {113}); + ASSERT_EQ(ND4J_STATUS_OK, res.status()); + // res.at(0).printIndexedBuffer("Result is "); + // x.printIndexedBuffer("Input is"); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, TestMod_1) { + auto x = NDArrayFactory::create( + 'c', {2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f}); + auto y = NDArrayFactory::create( + 'c', {2, 2, 2}, {10.f, 2.f, 30.f, 4.f, 50.f, 6.f, 70.f, 8.f}); + auto exp = + NDArrayFactory::create('c', {2, 2, 2}, {1, 0, 3, 0, 5, 0, 7, 0}); + sd::ops::mod op; - auto x = NDArrayFactory::create('c', {2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f}); - auto y = NDArrayFactory::create('c', {2, 2, 2}, {10.f, 2.f, 30.f, 4.f, 50.f, 6.f, 70.f, 8.f}); - auto exp = NDArrayFactory::create('c', {2, 2, 2}, {1, 0, 3, 0, 5, 0, 7, 0}); - sd::ops::mod op; - - auto res = op.evaluate({&x, &y}); - - ASSERT_EQ(ND4J_STATUS_OK, res.status()); -// res.at(0)->printIndexedBuffer("MOD Result is "); -// x.printIndexedBuffer("Input is"); - ASSERT_TRUE(res.at(0)->equalsTo(exp)); + auto res = op.evaluate({&x, &y}); + ASSERT_EQ(ND4J_STATUS_OK, res.status()); + // res.at(0).printIndexedBuffer("MOD Result is "); + // x.printIndexedBuffer("Input is"); + ASSERT_TRUE(res.at(0).equalsTo(exp)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, TestMod_BP_1) { + auto x = NDArrayFactory::create( + 'c', {2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f}); + auto y = NDArrayFactory::create( + 'c', {2, 2, 2}, {10.f, 2.f, 30.f, 4.f, 50.f, 6.f, 70.f, 8.f}); + auto eps = NDArrayFactory::create( + 'c', {2, 2, 2}, {10.f, 2.f, 30.f, 4.f, 50.f, 6.f, 70.f, 8.f}); + auto exp = NDArrayFactory::create('c', {2, 2, 2}); + sd::ops::mod_bp op; - auto x = NDArrayFactory::create('c', {2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f}); - auto y = NDArrayFactory::create('c', {2, 2, 2}, {10.f, 2.f, 30.f, 4.f, 50.f, 6.f, 70.f, 8.f}); - auto eps = NDArrayFactory::create('c', {2, 2, 2}, {10.f, 2.f, 30.f, 4.f, 50.f, 6.f, 70.f, 8.f}); - auto exp = NDArrayFactory::create('c', {2, 2, 2}); - sd::ops::mod_bp op; - - auto res = op.evaluate({&x, &y, &eps}); + auto res = op.evaluate({&x, &y, &eps}); - ASSERT_EQ(ND4J_STATUS_OK, res.status()); -// res.at(0)->printIndexedBuffer("MOD_BP Result is "); - - // x.printIndexedBuffer("Input is"); - ASSERT_TRUE(res.at(0)->equalsTo(exp)); + ASSERT_EQ(ND4J_STATUS_OK, res.status()); + // res.at(0).printIndexedBuffer("MOD_BP Result is "); + // x.printIndexedBuffer("Input is"); + ASSERT_TRUE(res.at(0).equalsTo(exp)); } /////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, TestRank_1) { + auto x = NDArrayFactory::create( + 'c', {2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f}); + auto y = NDArrayFactory::create( + 'c', {2, 2, 2}, {10.f, 2.f, 30.f, 4.f, 50.f, 6.f, 70.f, 8.f}); + auto eps = NDArrayFactory::create( + 'c', {2, 2, 2}, {10.f, 2.f, 30.f, 4.f, 50.f, 6.f, 70.f, 8.f}); + auto exp = NDArrayFactory::create(3); + sd::ops::rank op; - auto x = NDArrayFactory::create('c', {2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f}); - auto y = NDArrayFactory::create('c', {2, 2, 2}, {10.f, 2.f, 30.f, 4.f, 50.f, 6.f, 70.f, 8.f}); - auto eps = NDArrayFactory::create('c', {2, 2, 2}, {10.f, 2.f, 30.f, 4.f, 50.f, 6.f, 70.f, 8.f}); - auto exp = NDArrayFactory::create(3); - sd::ops::rank op; - - auto res = op.evaluate({&x}); - - ASSERT_EQ(ND4J_STATUS_OK, res.status()); + auto res = op.evaluate({&x}); - ASSERT_TRUE(res.at(0)->equalsTo(exp)); + ASSERT_EQ(ND4J_STATUS_OK, res.status()); + ASSERT_TRUE(res.at(0).equalsTo(exp)); } TEST_F(DeclarableOpsTests6, TestDropout_2) { -// auto x0 = NDArrayFactory::create('c', {10, 10}); -// auto x1 = NDArrayFactory::create('c', {10, 10}); - auto x = NDArrayFactory::create('c', {3, 3}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f}); - - sd::ops::dropout op; + // auto x0 = NDArrayFactory::create('c', {10, 10}); + // auto x1 = NDArrayFactory::create('c', {10, 10}); + auto x = NDArrayFactory::create( + 'c', {3, 3}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f}); - auto res = op.evaluate({&x}, {0.4f}, {113}); - - ASSERT_EQ(ND4J_STATUS_OK, res.status()); + sd::ops::dropout op; + auto res = op.evaluate({&x}, {0.4f}, {113}); + ASSERT_EQ(ND4J_STATUS_OK, res.status()); } TEST_F(DeclarableOpsTests6, TestDropout_3) { -// auto x0 = NDArrayFactory::create('c', {10, 10}); -// auto x1 = NDArrayFactory::create('c', {10, 10}); - auto x = NDArrayFactory::create('c', {2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f}); - auto shape = NDArrayFactory::create({1, 2}); + // auto x0 = NDArrayFactory::create('c', {10, 10}); + // auto x1 = NDArrayFactory::create('c', {10, 10}); + auto x = NDArrayFactory::create( + 'c', {2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f}); + auto shape = NDArrayFactory::create({1, 2}); - sd::ops::dropout op; - - auto res = op.evaluate({&x, &shape}, {0.4f}, {113}); - - ASSERT_EQ(ND4J_STATUS_OK, res.status()); + sd::ops::dropout op; + auto res = op.evaluate({&x, &shape}, {0.4f}, {113}); + ASSERT_EQ(ND4J_STATUS_OK, res.status()); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, MaxPoolWithArgmax_1) { + auto x = NDArrayFactory::create( + 'c', {2, 2, 2, 4}, {5.5, 0., 0.3, 5.5, 1.5, 0., 1.3, 6.5, 8.6, 0., 0., + 0.4, 2.5, 1., 0.3, 4.5, 1.5, 1., 1.3, 1.5, 3.5, 0., + 1.3, 2.5, 2.6, 2., 3., 1.4, 4.5, 1., 0.3, 0.5}); + auto expI = NDArrayFactory::create( + 'c', {2, 2, 2, 4}, + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}); - auto x = NDArrayFactory::create('c', {2, 2, 2, 4}, {5.5, 0., 0.3, 5.5,1.5, 0., 1.3, 6.5,8.6, 0., 0., 0.4,2.5, 1., 0.3, 4.5, - 1.5, 1., 1.3, 1.5, 3.5, 0., 1.3, 2.5, 2.6, 2., 3., 1.4, 4.5, 1., 0.3, 0.5}); - auto expI = NDArrayFactory::create('c', {2, 2, 2, 4}, {0, 1, 2, 3,4, 5, 6, 7,8, 9, 10, 11,12, 13, 14, 15, - 0, 1, 2, 3,4, 5, 6, 7,8, 9, 10, 11,12, 13, 14, 15}); - - sd::ops::max_pool_with_argmax op; + sd::ops::max_pool_with_argmax op; - auto res = op.evaluate({&x}, {}, {1,1,1,1,1,1,1,1,1}); - - ASSERT_EQ(ND4J_STATUS_OK, res.status()); - ASSERT_TRUE(expI.isSameShape(res.at(0))); - ASSERT_TRUE(expI.isSameShape(res.at(1))); - ASSERT_TRUE(x.equalsTo(res.at(0))); - ASSERT_TRUE(expI.equalsTo(res.at(1))); - //x.printIndexedBuffer("Input is"); - - ASSERT_TRUE(expI.equalsTo(res.at(1))); + auto res = op.evaluate({&x}, {}, {1, 1, 1, 1, 1, 1, 1, 1, 1}); + ASSERT_EQ(ND4J_STATUS_OK, res.status()); + ASSERT_TRUE(expI.isSameShape(res.at(0))); + ASSERT_TRUE(expI.isSameShape(res.at(1))); + ASSERT_TRUE(x.equalsTo(res.at(0))); + ASSERT_TRUE(expI.equalsTo(res.at(1))); + // x.printIndexedBuffer("Input is"); + ASSERT_TRUE(expI.equalsTo(res.at(1))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, SufficientStatistics_1) { -// auto x0 = NDArrayFactory::create('c', {10, 10}); -// auto x1 = NDArrayFactory::create('c', {10, 10}); - auto x = NDArrayFactory::create('c', {2, 2, 2, 4}, {5.5, 0., 0.3, 5.5,1.5, 0., 1.3, 6.5,8.6, 0., 0., 0.4,2.5, 1., 0.3, 4.5,1.5, 1., - 1.3, 1.5,3.5, 0., 1.3, 2.5,2.6, 2., 3., 1.4,4.5, 1., 0.3, 0.5}); -// ------------------------------------ - double count = 8.0; - auto sumExp = NDArrayFactory::create({30.2, 5., 7.8, 22.8}); - auto sqrExp = NDArrayFactory::create({154.22, 7., 14.34, 103.62}); - - auto axis = NDArrayFactory::create({0, 1, 2}); + // auto x0 = NDArrayFactory::create('c', {10, 10}); + // auto x1 = NDArrayFactory::create('c', {10, 10}); + auto x = NDArrayFactory::create( + 'c', {2, 2, 2, 4}, {5.5, 0., 0.3, 5.5, 1.5, 0., 1.3, 6.5, 8.6, 0., 0., + 0.4, 2.5, 1., 0.3, 4.5, 1.5, 1., 1.3, 1.5, 3.5, 0., + 1.3, 2.5, 2.6, 2., 3., 1.4, 4.5, 1., 0.3, 0.5}); + // ------------------------------------ + double count = 8.0; + auto sumExp = NDArrayFactory::create({30.2, 5., 7.8, 22.8}); + auto sqrExp = NDArrayFactory::create({154.22, 7., 14.34, 103.62}); - sd::ops::sufficient_statistics op; + auto axis = NDArrayFactory::create({0, 1, 2}); - auto res = op.evaluate({&x, &axis}); - - ASSERT_EQ(ND4J_STATUS_OK, res.status()); - ASSERT_EQ(res.at(0)->e(0), count); - ASSERT_TRUE(sumExp.equalsTo(res.at(1))); - ASSERT_TRUE(sqrExp.equalsTo(res.at(2))); + sd::ops::sufficient_statistics op; + auto res = op.evaluate({&x, &axis}); + ASSERT_EQ(ND4J_STATUS_OK, res.status()); + ASSERT_EQ(res.at(0).e(0), count); + ASSERT_TRUE(sumExp.equalsTo(res.at(1))); + ASSERT_TRUE(sqrExp.equalsTo(res.at(2))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, SufficientStatistics_2) { -// auto x0 = NDArrayFactory::create('c', {10, 10}); -// auto x1 = NDArrayFactory::create('c', {10, 10}); - auto x = NDArrayFactory::create('c', {2, 2, 2, 4}, {5.5, 0., 0.3, 5.5,1.5, 0., 1.3, 6.5,8.6, 0., 0., 0.4,2.5, 1., 0.3, 4.5, - 1.5, 1., 1.3, 1.5,3.5, 0., 1.3, 2.5,2.6, 2., 3., 1.4,4.5, 1., 0.3, 0.5}); -// ------------------------------------ - double count = 4.0; - auto sumExp = NDArrayFactory::create('c', {2, 4}, { - 18.2, 3., 4.6, 8.8, - 12., 2., 3.2, 14.} - ); - - auto sqrExp = NDArrayFactory::create('c', {2, 4}, { - 113.22, 5., 10.78, 34.62, - 41., 2., 3.56, 69.} - ); - - auto axis = NDArrayFactory::create({0, 1}); + // auto x0 = NDArrayFactory::create('c', {10, 10}); + // auto x1 = NDArrayFactory::create('c', {10, 10}); + auto x = NDArrayFactory::create( + 'c', {2, 2, 2, 4}, {5.5, 0., 0.3, 5.5, 1.5, 0., 1.3, 6.5, 8.6, 0., 0., + 0.4, 2.5, 1., 0.3, 4.5, 1.5, 1., 1.3, 1.5, 3.5, 0., + 1.3, 2.5, 2.6, 2., 3., 1.4, 4.5, 1., 0.3, 0.5}); + // ------------------------------------ + double count = 4.0; + auto sumExp = NDArrayFactory::create( + 'c', {2, 4}, {18.2, 3., 4.6, 8.8, 12., 2., 3.2, 14.}); - sd::ops::sufficient_statistics op; + auto sqrExp = NDArrayFactory::create( + 'c', {2, 4}, {113.22, 5., 10.78, 34.62, 41., 2., 3.56, 69.}); - auto res = op.evaluate({&x, &axis}); + auto axis = NDArrayFactory::create({0, 1}); - ASSERT_EQ(ND4J_STATUS_OK, res.status()); - ASSERT_EQ(res.at(0)->e(0), count); - ASSERT_TRUE(sumExp.equalsTo(res.at(1))); - ASSERT_TRUE(sqrExp.equalsTo(res.at(2))); + sd::ops::sufficient_statistics op; + auto res = op.evaluate({&x, &axis}); + ASSERT_EQ(ND4J_STATUS_OK, res.status()); + ASSERT_EQ(res.at(0).e(0), count); + ASSERT_TRUE(sumExp.equalsTo(res.at(1))); + ASSERT_TRUE(sqrExp.equalsTo(res.at(2))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, BinCount_1) { + auto x = + NDArrayFactory::create('c', {2, 2, 2}, {1, 2, 0, 1, 2, 2, 1, 2}); + // ------------------------------------ - auto x = NDArrayFactory::create('c', {2, 2, 2}, { - 1, 2, 0, 1, 2, 2, 1, 2} - ); -// ------------------------------------ - - NDArray exp('c', {3}, {1, 3, 4}, sd::DataType::INT32); + NDArray exp('c', {3}, {1, 3, 4}, sd::DataType::INT32); - sd::ops::bincount op; + sd::ops::bincount op; - auto res = op.evaluate({&x}); - - ASSERT_EQ(ND4J_STATUS_OK, res.status()); - ASSERT_TRUE(exp.equalsTo(res.at(0))); + auto res = op.evaluate({&x}); + ASSERT_EQ(ND4J_STATUS_OK, res.status()); + ASSERT_TRUE(exp.equalsTo(res.at(0))); } ///////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, BinCount_2) { + auto x = + NDArrayFactory::create('c', {2, 2, 2}, {1, 2, 0, 1, 2, 2, 1, 2}); - auto x = NDArrayFactory::create('c', {2, 2, 2}, { - 1, 2, 0, 1, 2, 2, 1, 2} - ); + auto weights = + NDArrayFactory::create('c', {2, 2, 2}, {2, 1, 3, 1, 5, 1, 1, 6}); - auto weights = NDArrayFactory::create('c', {2, 2, 2}, { - 2, 1, 3, 1, 5, 1, 1, 6} - ); + // ------------------------------------ -// ------------------------------------ + auto exp = NDArrayFactory::create({3., 4., 13.}); - auto exp = NDArrayFactory::create({3., 4., 13.}); + sd::ops::bincount op; - sd::ops::bincount op; + auto res = op.evaluate({&x, &weights}); - auto res = op.evaluate({&x, &weights}); - - ASSERT_EQ(ND4J_STATUS_OK, res.status()); - ASSERT_TRUE(exp.equalsTo(res.at(0))); + ASSERT_EQ(ND4J_STATUS_OK, res.status()); + ASSERT_TRUE(exp.equalsTo(res.at(0))); } ///////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, BinCount_3) { + auto x = + NDArrayFactory::create('c', {2, 2, 2}, {1, 2, 0, 1, 2, 2, 1, 2}); - auto x = NDArrayFactory::create('c', {2, 2, 2}, { - 1, 2, 0, 1, 2, 2, 1, 2} - ); - - auto weights = NDArrayFactory::create('c', {2, 2, 2}, { - 2, 1, 3, 1, 5, 1, 1, 6} - ); - -// ------------------------------------ + auto weights = + NDArrayFactory::create('c', {2, 2, 2}, {2, 1, 3, 1, 5, 1, 1, 6}); - auto exp = NDArrayFactory::create({3., 4.}); + // ------------------------------------ - sd::ops::bincount op; + auto exp = NDArrayFactory::create({3., 4.}); - auto res = op.evaluate({&x, &weights}, {}, {0, 2}); + sd::ops::bincount op; - ASSERT_EQ(ND4J_STATUS_OK, res.status()); - ASSERT_TRUE(exp.equalsTo(res.at(0))); + auto res = op.evaluate({&x, &weights}, {}, {0, 2}); + ASSERT_EQ(ND4J_STATUS_OK, res.status()); + ASSERT_TRUE(exp.equalsTo(res.at(0))); } ///////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, BinCount_4) { + auto x = + NDArrayFactory::create('c', {2, 2, 2}, {1, 2, 0, 1, 2, 2, 1, 2}); - auto x = NDArrayFactory::create('c', {2, 2, 2}, { - 1, 2, 0, 1, 2, 2, 1, 2} - ); - - auto weights = NDArrayFactory::create('c', {2, 2, 2}, { - 2, 1, 3, 1, 5, 1, 1, 6} - ); + auto weights = + NDArrayFactory::create('c', {2, 2, 2}, {2, 1, 3, 1, 5, 1, 1, 6}); -// ------------------------------------ + // ------------------------------------ - auto exp = NDArrayFactory::create({3., 4., 13., 0.0}); + auto exp = NDArrayFactory::create({3., 4., 13., 0.0}); - sd::ops::bincount op; + sd::ops::bincount op; - auto res = op.evaluate({&x, &weights}, {}, {4, 4}); + auto res = op.evaluate({&x, &weights}, {}, {4, 4}); - ASSERT_EQ(ND4J_STATUS_OK, res.status()); - ASSERT_TRUE(exp.equalsTo(res.at(0))); + ASSERT_EQ(ND4J_STATUS_OK, res.status()); + ASSERT_TRUE(exp.equalsTo(res.at(0))); } ///////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, BinCount_5) { + auto x = + NDArrayFactory::create('c', {2, 2, 2}, {1, 2, 0, 1, 2, 2, 1, 2}); - auto x = NDArrayFactory::create('c', {2, 2, 2}, { - 1, 2, 0, 1, 2, 2, 1, 2} - ); - - auto weights = NDArrayFactory::create('c', {2, 2, 2}, { - 2, 1, 3, 1, 5, 1, 1, 6} - ); - auto minV = NDArrayFactory::create(4); - auto maxV = NDArrayFactory::create(4); -// ------------------------------------ + auto weights = + NDArrayFactory::create('c', {2, 2, 2}, {2, 1, 3, 1, 5, 1, 1, 6}); + auto minV = NDArrayFactory::create(4); + auto maxV = NDArrayFactory::create(4); + // ------------------------------------ - auto exp = NDArrayFactory::create({3., 4., 13., 0.0}); + auto exp = NDArrayFactory::create({3., 4., 13., 0.0}); - sd::ops::bincount op; - - auto res = op.evaluate({&x, &weights, &minV, &maxV}); - ASSERT_EQ(ND4J_STATUS_OK, res.status()); - // res->at(0)->printBuffer("BC out"); - ASSERT_TRUE(exp.equalsTo(res.at(0))); + sd::ops::bincount op; + auto res = op.evaluate({&x, &weights, &minV, &maxV}); + ASSERT_EQ(ND4J_STATUS_OK, res.status()); + // res->at(0)->printBuffer("BC out"); + ASSERT_TRUE(exp.equalsTo(res.at(0))); } ///////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_1) { + auto x = NDArrayFactory::create({2, 2, 2}); - auto x = NDArrayFactory::create( {2, 2, 2} ); - - auto y = NDArrayFactory::create({ 2, 1, 2}); - - auto exp = NDArrayFactory::create({2, 2, 2}); + auto y = NDArrayFactory::create({2, 1, 2}); - sd::ops::broadcast_dynamic_shape op; + auto exp = NDArrayFactory::create({2, 2, 2}); - auto res = op.evaluate({&x, &y}); + sd::ops::broadcast_dynamic_shape op; - ASSERT_EQ(ND4J_STATUS_OK, res.status()); - ASSERT_TRUE(exp.equalsTo(res.at(0))); + auto res = op.evaluate({&x, &y}); + ASSERT_EQ(ND4J_STATUS_OK, res.status()); + ASSERT_TRUE(exp.equalsTo(res.at(0))); } ///////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_2) { + auto x = NDArrayFactory::create({2, 2}); - auto x = NDArrayFactory::create( {2, 2} ); - - auto y = NDArrayFactory::create({2, 1, 2}); + auto y = NDArrayFactory::create({2, 1, 2}); - auto exp = NDArrayFactory::create({2, 2, 2}); + auto exp = NDArrayFactory::create({2, 2, 2}); - sd::ops::broadcast_dynamic_shape op; - - auto res = op.evaluate({&x, &y}); - ASSERT_EQ(ND4J_STATUS_OK, res.status()); - ASSERT_TRUE(exp.equalsTo(res.at(0))); + sd::ops::broadcast_dynamic_shape op; + auto res = op.evaluate({&x, &y}); + ASSERT_EQ(ND4J_STATUS_OK, res.status()); + ASSERT_TRUE(exp.equalsTo(res.at(0))); } ///////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_3) { + auto x = NDArrayFactory::create({2, 2, 2}); - auto x = NDArrayFactory::create( {2, 2, 2} ); + auto y = NDArrayFactory::create({2, 1}); - auto y = NDArrayFactory::create({2, 1}); + auto exp = NDArrayFactory::create({2, 2, 2}); - auto exp = NDArrayFactory::create({2, 2, 2}); + sd::ops::broadcast_dynamic_shape op; - sd::ops::broadcast_dynamic_shape op; + auto res = op.evaluate({&x, &y}, {}, {}, {}); - auto res = op.evaluate({&x, &y}, {}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, res.status()); - ASSERT_TRUE(exp.equalsTo(res.at(0))); + ASSERT_EQ(ND4J_STATUS_OK, res.status()); + ASSERT_TRUE(exp.equalsTo(res.at(0))); } ///////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_SGO_4) { + auto x = NDArrayFactory::create({2, 1}); - auto x = NDArrayFactory::create( {2, 1} ); + auto y = NDArrayFactory::create('c', {1}, {4}); - auto y = NDArrayFactory::create('c', {1}, {4}); + auto exp = NDArrayFactory::create({2, 4}); - auto exp = NDArrayFactory::create({2, 4}); + sd::ops::broadcast_dynamic_shape op; - sd::ops::broadcast_dynamic_shape op; + auto res = op.evaluate({&x, &y}); - auto res = op.evaluate({&x, &y}); - - ASSERT_EQ(ND4J_STATUS_OK, res.status()); - //res->at(0)->printBuffer("Shape SGO 4"); - ASSERT_TRUE(exp.equalsTo(res.at(0))); + ASSERT_EQ(ND4J_STATUS_OK, res.status()); + // res->at(0)->printBuffer("Shape SGO 4"); + ASSERT_TRUE(exp.equalsTo(res.at(0))); } ///////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_SGO_6) { + auto x = NDArrayFactory::create({2, 1, 4}); - auto x = NDArrayFactory::create({2, 1, 4}); + auto y = NDArrayFactory::create({2, 2, 4}); - auto y = NDArrayFactory::create({2, 2, 4}); + auto exp = NDArrayFactory::create({2, 2, 4}); - auto exp = NDArrayFactory::create({2, 2, 4}); - - sd::ops::broadcast_dynamic_shape op; - auto res = op.evaluate({&x, &y}); - - ASSERT_EQ(ND4J_STATUS_OK, res.status()); - ASSERT_TRUE(exp.equalsTo(res.at(0))); + sd::ops::broadcast_dynamic_shape op; + auto res = op.evaluate({&x, &y}); + ASSERT_EQ(ND4J_STATUS_OK, res.status()); + ASSERT_TRUE(exp.equalsTo(res.at(0))); } ///////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_SGO_7) { + auto x = NDArrayFactory::create({1, 1, 3}); - auto x = NDArrayFactory::create({1, 1, 3}); - - auto y = NDArrayFactory::create({2, 4, 1}); + auto y = NDArrayFactory::create({2, 4, 1}); - auto exp = NDArrayFactory::create({2, 4, 3}); + auto exp = NDArrayFactory::create({2, 4, 3}); - sd::ops::broadcast_dynamic_shape op; - auto res = op.evaluate({&x, &y}); + sd::ops::broadcast_dynamic_shape op; + auto res = op.evaluate({&x, &y}); - ASSERT_EQ(ND4J_STATUS_OK, res.status()); - ASSERT_TRUE(exp.equalsTo(res.at(0))); + ASSERT_EQ(ND4J_STATUS_OK, res.status()); + ASSERT_TRUE(exp.equalsTo(res.at(0))); } ///////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_SGO_8) { + auto x = NDArrayFactory::create('c', {1}, {1}); - auto x = NDArrayFactory::create('c', {1}, {1}); - - auto y = NDArrayFactory::create('c', {1}, {4}); + auto y = NDArrayFactory::create('c', {1}, {4}); - auto z = NDArrayFactory::create('c', {1}); + auto z = NDArrayFactory::create('c', {1}); - auto exp = NDArrayFactory::create('c', {1}, {4}); + auto exp = NDArrayFactory::create('c', {1}, {4}); - sd::ops::broadcast_dynamic_shape op; - auto status = op.execute({&x, &y}, {&z}, {}, {}, {}); + sd::ops::broadcast_dynamic_shape op; + auto status = op.execute({&x, &y}, {&z}, {}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(exp.equalsTo(z)); } ///////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_SGO_9) { + auto x = NDArrayFactory::create('c', {2}, {2, 2}); - auto x = NDArrayFactory::create('c', {2}, {2,2}); - - auto y = NDArrayFactory::create('c', {1}, {1}); + auto y = NDArrayFactory::create('c', {1}, {1}); - auto z = NDArrayFactory::create('c', {2}); + auto z = NDArrayFactory::create('c', {2}); - auto exp = NDArrayFactory::create('c', {2}, {2,2}); + auto exp = NDArrayFactory::create('c', {2}, {2, 2}); - sd::ops::broadcast_dynamic_shape op; - auto status = op.execute({&x, &y}, {&z}, {}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, status); - // ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::broadcast_dynamic_shape op; + auto status = op.execute({&x, &y}, {&z}, {}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); + // ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, ClipByGlobalNorm_1) { + auto x = NDArrayFactory::create( + 'c', {2, 3, 3}, + {-3.0, 0.0, 0.0, 4.0, 0.0, 0.0, -3.0, 0.0, 0.0, 4.0, 0.0, 0.0, -3.0, 0.0, + 0.0, 4.0, 0.0, 0.0}); - auto x = NDArrayFactory::create('c', {2, 3, 3}, {-3.0, 0.0, 0.0, 4.0, 0.0, 0.0, - -3.0, 0.0, 0.0, 4.0, 0.0, 0.0, - -3.0, 0.0, 0.0, 4.0, 0.0, 0.0} - ); - - auto exp = NDArrayFactory::create('c', {2, 3, 3}, { - -0.2771281, 0., 0., - 0.36950415, 0., 0., - -0.2771281, 0., 0., - 0.36950415, 0., 0., - -0.2771281, 0., 0., - 0.36950415, 0., 0.} - ); -// 8.660254 -// auto expNorm(8.660254); - - sd::ops::clip_by_global_norm op; - auto result = op.evaluate({&x}, {0.8}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto exp = NDArrayFactory::create( + 'c', {2, 3, 3}, + {-0.2771281, 0., 0., 0.36950415, 0., 0., -0.2771281, 0., 0., 0.36950415, + 0., 0., -0.2771281, 0., 0., 0.36950415, 0., 0.}); + // 8.660254 + // auto expNorm(8.660254); - auto z = result.at(0); - auto norm = result.at(1); - //z->printIndexedBuffer("Output"); - //exp.printIndexedBuffer("Expected"); - //norm->printIndexedBuffer("Norm"); + sd::ops::clip_by_global_norm op; + auto result = op.evaluate({&x}, {0.8}, {}); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); -// ASSERT_TRUE(expNorm.equalsTo(norm)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + auto norm = result.at(1); + // z->printIndexedBuffer("Output"); + // exp.printIndexedBuffer("Expected"); + // norm->printIndexedBuffer("Norm"); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + // ASSERT_TRUE(expNorm.equalsTo(norm)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, ClipByGlobalNorm_2) { + auto x = NDArrayFactory::create( + 'c', {2, 3, 3}, + {-3.0, 0.0, 0.0, 4.0, 0.0, 0.0, -3.0, 0.0, 0.0, 4.0, 0.0, 0.0, -3.0, 0.0, + 0.0, 4.0, 0.0, 0.0}); - auto x = NDArrayFactory::create('c', {2, 3, 3}, {-3.0, 0.0, 0.0, 4.0, 0.0, 0.0, - -3.0, 0.0, 0.0, 4.0, 0.0, 0.0, - -3.0, 0.0, 0.0, 4.0, 0.0, 0.0} - ); - - auto a = NDArrayFactory::create('c', {2, 3, 3}, {-3.0, 0.0, 0.0, 4.0, 0.0, 0.0, - -3.0, 0.0, 0.0, 4.0, 0.0, 0.0, - -3.0, 0.0, 0.0, 4.0, 0.0, 0.0} - ); - - auto exp = NDArrayFactory::create('c', {2, 3, 3}, { - -0.44090813, 0., 0., - 0.5878775, 0., 0., - -0.44090813, 0., 0., - 0.5878775, 0., 0., - -0.44090813, 0., 0., - 0.5878775, 0., 0.} -//12.247449 + auto a = NDArrayFactory::create( + 'c', {2, 3, 3}, + {-3.0, 0.0, 0.0, 4.0, 0.0, 0.0, -3.0, 0.0, 0.0, 4.0, 0.0, 0.0, -3.0, 0.0, + 0.0, 4.0, 0.0, 0.0}); - ); + auto exp = NDArrayFactory::create( + 'c', {2, 3, 3}, + {-0.44090813, 0., 0., 0.5878775, 0., 0., -0.44090813, 0., 0., 0.5878775, + 0., 0., -0.44090813, 0., 0., 0.5878775, 0., 0.} // 12.247449 - sd::ops::clip_by_global_norm op; - auto result = op.evaluate({&x, &a}, {1.8}, {}); + ); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - auto y = result.at(1); + sd::ops::clip_by_global_norm op; + auto result = op.evaluate({&x, &a}, {1.8}, {}); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.isSameShape(y)); - ASSERT_TRUE(exp.equalsTo(z)); - ASSERT_TRUE(exp.equalsTo(y)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + auto y = result.at(1); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.isSameShape(y)); + ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(exp.equalsTo(y)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, ClipByGlobalNorm_3) { + auto x = NDArrayFactory::create( + 'c', {2, 3, 3}, + {-3.0, 0.0, 0.0, 4.0, 0.0, 0.0, -3.0, 0.0, 0.0, 4.0, 0.0, 0.0, -3.0, 0.0, + 0.0, 4.0, 0.0, 0.0}); + auto a = NDArrayFactory::create( + 'c', {2, 3, 3}, + {-3.0, 0.0, 0.0, 4.0, 0.0, 0.0, -3.0, 0.0, 0.0, 4.0, 0.0, 0.0, -3.0, 0.0, + 0.0, 4.0, 0.0, 0.0}); + auto exp = NDArrayFactory::create( + 'c', {2, 3, 3}, + {-0.19595918, 0., 0., 0.2612789, 0., 0., -0.19595918, 0., 0., 0.2612789, + 0., 0., -0.19595918, 0., 0., 0.2612789, 0., 0.}); + + sd::ops::clip_by_global_norm op; + auto result = op.evaluate({&x, &a}, {0.8}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + auto y = result.at(1); + // z->printIndexedBuffer("Output 1"); + // y->printIndexedBuffer("Output 2"); + // result.at(2)->printIndexedBuffer("Global norm is"); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.isSameShape(y)); + ASSERT_TRUE(result.at(2).isScalar()); + ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(exp.equalsTo(y)); +} - auto x = NDArrayFactory::create('c', {2, 3, 3}, {-3.0, 0.0, 0.0, 4.0, 0.0, 0.0, -3.0, 0.0, 0.0, 4.0, 0.0, 0.0, -3.0, 0.0, 0.0, 4.0, 0.0, 0.0}); - auto a = NDArrayFactory::create('c', {2, 3, 3}, {-3.0, 0.0, 0.0, 4.0, 0.0, 0.0, -3.0, 0.0, 0.0, 4.0, 0.0, 0.0, -3.0, 0.0, 0.0, 4.0, 0.0, 0.0}); - auto exp = NDArrayFactory::create('c', {2, 3, 3}, { - -0.19595918, 0., 0., - 0.2612789, 0., 0., - -0.19595918, 0., 0., - 0.2612789, 0., 0., - -0.19595918, 0., 0., - 0.2612789, 0., 0.} - ); - - sd::ops::clip_by_global_norm op; - auto result = op.evaluate({&x, &a}, {0.8}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests6, MatrixDeterminant_1) { + auto x = NDArrayFactory::create( + 'c', {2, 3, 3}, + {-3.0, 0.0, 0.0, 0.0, 4.0, 0.0, 0.0, 0.0, -3.0, 4.0, 0.0, 0.0, 0.0, -3.0, + 0.0, 0.0, 0.0, 4.0}); + auto exp = NDArrayFactory::create({36.0, -48.0}); - auto z = result.at(0); - auto y = result.at(1); - //z->printIndexedBuffer("Output 1"); - //y->printIndexedBuffer("Output 2"); - //result.at(2)->printIndexedBuffer("Global norm is"); + sd::ops::matrix_determinant op; + auto result = op.evaluate({&x}, {}, {}); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.isSameShape(y)); - ASSERT_TRUE(result.at(2)->isScalar()); - ASSERT_TRUE(exp.equalsTo(z)); - ASSERT_TRUE(exp.equalsTo(y)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + // z->printIndexedBuffer("Output "); + // exp.printIndexedBuffer("Expected "); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests6, MatrixDeterminant_1) { +TEST_F(DeclarableOpsTests6, MatrixDeterminant_2) { + auto x = NDArrayFactory::create( + 'c', {2, 2, 2}, {1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0}); + auto exp = NDArrayFactory::create({-2.0, -2.0}); - auto x = NDArrayFactory::create('c', {2, 3, 3}, {-3.0, 0.0, 0.0, 0.0, 4.0, 0.0, 0.0, 0.0, -3.0, 4.0, 0.0, 0.0, 0.0, -3.0, 0.0, 0.0, 0.0, 4.0}); - auto exp = NDArrayFactory::create({36.0, -48.0}); + sd::ops::matrix_determinant op; + auto result = op.evaluate({&x}, {}, {}); - sd::ops::matrix_determinant op; - auto result = op.evaluate({&x}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + // z->printIndexedBuffer("Output "); + // exp.printIndexedBuffer("Expected "); - auto z = result.at(0); - //z->printIndexedBuffer("Output "); - //exp.printIndexedBuffer("Expected "); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - - -} - -//////////////////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests6, MatrixDeterminant_2) { - - auto x = NDArrayFactory::create('c', {2, 2, 2}, {1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0}); - auto exp = NDArrayFactory::create({-2.0, -2.0}); - - sd::ops::matrix_determinant op; - auto result = op.evaluate({&x}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - //z->printIndexedBuffer("Output "); - //exp.printIndexedBuffer("Expected "); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - - -} + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); +} //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, MatrixDeterminant_3) { + auto x = NDArrayFactory::create( + 'c', {1, 3, 3}, {3.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 3.0}); + NDArray exp('c', {1}, std::vector{-54.0}); - auto x = NDArrayFactory::create('c', {1, 3, 3}, {3.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 3.0}); - NDArray exp('c', {1}, std::vector{-54.0}); + sd::ops::matrix_determinant op; + auto result = op.evaluate({&x}, {}, {}); - sd::ops::matrix_determinant op; - auto result = op.evaluate({&x}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - //z->printIndexedBuffer("Output "); - //exp.printIndexedBuffer("Expected "); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + // z->printIndexedBuffer("Output "); + // exp.printIndexedBuffer("Expected "); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, MatrixDeterminant_4) { + auto x = NDArrayFactory::create( + 'c', {1, 3, 3}, {12.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 13.0}); + auto exp = NDArrayFactory::create('c', {1}, {189.0}); - auto x = NDArrayFactory::create('c', {1, 3, 3}, {12.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 13.0}); - auto exp = NDArrayFactory::create('c', {1}, {189.0}); - - sd::ops::matrix_determinant op; - auto result = op.evaluate({&x}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - // z->printIndexedBuffer("Output "); - // exp.printIndexedBuffer("Expected "); - // z->printShapeInfo("Output shape"); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::matrix_determinant op; + auto result = op.evaluate({&x}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + // z->printIndexedBuffer("Output "); + // exp.printIndexedBuffer("Expected "); + // z->printShapeInfo("Output shape"); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, MatrixDeterminant_5) { + auto x = NDArrayFactory::create('c', {1, 4, 4}); + NDArray exp('c', {1}, std::vector{-16.0}); + x.linspace(1); + x.p(5, 4.0); + x.p(12, 12.0); - auto x = NDArrayFactory::create('c', {1, 4, 4}); - NDArray exp('c', {1}, std::vector{-16.0}); - x.linspace(1); - x.p(5, 4.0); - x.p(12, 12.0); - - sd::ops::matrix_determinant op; - auto result = op.evaluate({&x}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - //z->printIndexedBuffer("Output "); - //exp.printIndexedBuffer("Expected "); + sd::ops::matrix_determinant op; + auto result = op.evaluate({&x}, {}, {}); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + // z->printIndexedBuffer("Output "); + // exp.printIndexedBuffer("Expected "); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, MatrixDeterminant_6) { + auto x = NDArrayFactory::create('c', {4, 4}); + auto exp = NDArrayFactory::create(-16.0); + x.linspace(1); + x.p(5, 4.0); + x.p(12, 12.0); - auto x = NDArrayFactory::create('c', {4, 4}); - auto exp = NDArrayFactory::create(-16.0); - x.linspace(1); - x.p(5, 4.0); - x.p(12, 12.0); - - sd::ops::matrix_determinant op; - auto result = op.evaluate({&x}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - //z->printIndexedBuffer("Output "); - //z->printShapeInfo("Shape"); - //exp.printIndexedBuffer("Expected "); - ASSERT_TRUE(z->isScalar()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::matrix_determinant op; + auto result = op.evaluate({&x}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + // z->printIndexedBuffer("Output "); + // z->printShapeInfo("Shape"); + // exp.printIndexedBuffer("Expected "); + ASSERT_TRUE(z.isScalar()); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, LogMatrixDeterminant_1) { + auto x = NDArrayFactory::create( + 'c', {2, 3, 3}, + {-3.0, 0.0, 0.0, 0.0, 4.0, 0.0, 0.0, 0.0, -3.0, 4.0, 0.0, 0.0, 0.0, -3.0, + 0.0, 0.0, 0.0, 4.0}); + auto exp = + NDArrayFactory::create({3.58351893845611, 3.871201010907891}); - auto x = NDArrayFactory::create('c', {2, 3, 3}, {-3.0, 0.0, 0.0, 0.0, 4.0, 0.0, 0.0, 0.0, -3.0, 4.0, 0.0, 0.0, 0.0, -3.0, 0.0, 0.0, 0.0, 4.0}); - auto exp = NDArrayFactory::create({3.58351893845611, 3.871201010907891}); - - sd::ops::log_matrix_determinant op; - auto result = op.evaluate({&x}, {}, {}); + sd::ops::log_matrix_determinant op; + auto result = op.evaluate({&x}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, LogDet_1) { + auto x = NDArrayFactory::create( + 'c', {2, 3, 3}, + {4, 12, -16, 12, 37, -43, -16, -43, 98, 4, 1.2, -1.6, 1.2, 3.7, -4.3, + -1.6, -4.3, 9.8}); + auto exp = NDArrayFactory::create({3.5835189, 4.159008}); - auto x = NDArrayFactory::create('c', {2, 3, 3}, {4,12,-16,12,37,-43,-16,-43,98, 4,1.2,-1.6,1.2,3.7,-4.3,-1.6,-4.3,9.8}); - auto exp = NDArrayFactory::create({ 3.5835189, 4.159008}); - - sd::ops::logdet op; - auto result = op.evaluate({&x}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::logdet op; + auto result = op.evaluate({&x}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, LogDet_2) { + auto x = NDArrayFactory::create( + 'c', {1, 3, 3}, {4, 12, -16, 12, 37, -43, -16, -43, 98}); + auto exp = NDArrayFactory::create('c', {1}, {3.5835189}); - auto x = NDArrayFactory::create('c', {1, 3, 3}, {4,12,-16,12,37,-43,-16,-43,98}); - auto exp = NDArrayFactory::create('c', {1}, { 3.5835189}); - - sd::ops::logdet op; - auto result = op.evaluate({&x}, {}, {}); + sd::ops::logdet op; + auto result = op.evaluate({&x}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, LogDet_3) { + auto x = NDArrayFactory::create( + 'c', {3, 3}, {4, 12, -16, 12, 37, -43, -16, -43, 98}); + auto exp = NDArrayFactory::create(3.5835189); - auto x = NDArrayFactory::create('c', {3, 3}, {4,12,-16,12,37,-43,-16,-43,98}); - auto exp = NDArrayFactory::create( 3.5835189); - - sd::ops::logdet op; - auto result = op.evaluate({&x}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); + sd::ops::logdet op; + auto result = op.evaluate({&x}, {}, {}); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, MatrixInverse_1) { + auto x = NDArrayFactory::create( + 'c', {2, 5, 5}, + {2.f, 4.f, 60.f, 8.f, 10.f, 0.f, 1.f, 2.f, 3.f, 4.f, 0.f, 0.f, 2.f, + 4.f, 6.f, 0.f, 0.f, 0.f, 1.f, 2.f, 0.f, 0.f, 0.f, 0.f, 4.f, - auto x = NDArrayFactory::create('c', {2, 5, 5}, { - 2.f, 4.f, 60.f, 8.f, 10.f, - 0.f, 1.f, 2.f, 3.f, 4.f, - 0.f, 0.f, 2.f, 4.f, 6.f, - 0.f, 0.f, 0.f, 1.f, 2.f, - 0.f, 0.f, 0.f, 0.f, 4.f, - - 1.f, 0.f, 0.f, 0.f, 0.f, - 2.f, 1.f, 0.f, 0.f, 0.f, - 30.f, 2.f, 1.f, 0.f, 0.f, - 4.f, 3.f, 2.f, 1.f, 0.f, - 5.f, 4.f, 3.f, 2.f, 1.f - }); - - auto exp = NDArrayFactory::create('c', {2, 5, 5}, { - 0.5f, -2.0f, -13.0f, 54.0f, -6.75f, - 0.0f, 1.0f, -1.0f, 1.0f, 0.0f, - 0.f, 0.f, 0.5f, -2.0f, 0.25f, - 0.f, 0.f, 0.f, 1.0f, -0.5f, - 0.f, 0.f, 0.f, 0.f, 0.25f, - - 1.0f, 0.0f, 0.0f, 0.0f, 0.f, - -2.0f, 1.0f, 0.f, 0.f, 0.f, - -26.0f, -2.0f, 1.f, 0.f, 0.f, - 54.0f, 1.0f, -2.0f, 1.f, 0.f, - -27.0f, 0.0f, 1.0f, -2.0f, 1.f, - }); + 1.f, 0.f, 0.f, 0.f, 0.f, 2.f, 1.f, 0.f, 0.f, 0.f, 30.f, 2.f, 1.f, + 0.f, 0.f, 4.f, 3.f, 2.f, 1.f, 0.f, 5.f, 4.f, 3.f, 2.f, 1.f}); - sd::ops::matrix_inverse op; - auto result = op.evaluate({&x}); + auto exp = NDArrayFactory::create( + 'c', {2, 5, 5}, + { + 0.5f, -2.0f, -13.0f, 54.0f, -6.75f, 0.0f, 1.0f, -1.0f, 1.0f, + 0.0f, 0.f, 0.f, 0.5f, -2.0f, 0.25f, 0.f, 0.f, 0.f, + 1.0f, -0.5f, 0.f, 0.f, 0.f, 0.f, 0.25f, - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + 1.0f, 0.0f, 0.0f, 0.0f, 0.f, -2.0f, 1.0f, 0.f, 0.f, + 0.f, -26.0f, -2.0f, 1.f, 0.f, 0.f, 54.0f, 1.0f, -2.0f, + 1.f, 0.f, -27.0f, 0.0f, 1.0f, -2.0f, 1.f, + }); - auto z = result.at(0); + sd::ops::matrix_inverse op; + auto result = op.evaluate({&x}); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, MatrixInverse_010) { + auto x = NDArrayFactory::create( + 'c', {1, 5, 5}, + { + 1.f, 0.f, 0.f, 0.f, 0.f, 2.f, 1.f, 0.f, 0.f, 0.f, 30.f, 2.f, 1.f, + 0.f, 0.f, 4.f, 3.f, 2.f, 1.f, 0.f, 5.f, 4.f, 3.f, 2.f, 1.f, + }); + auto exp = NDArrayFactory::create( + 'c', {1, 5, 5}, + {1.0f, 0.0f, 0.0f, 0.0f, 0.f, -2.0f, 1.0f, 0.f, 0.f, + 0.f, -26.0f, -2.0f, 1.f, 0.f, 0.f, 54.0f, 1.0f, -2.0f, + 1.f, 0.f, -27.0f, 0.0f, 1.0f, -2.0f, 1.f}); - auto x = NDArrayFactory::create('c', {1, 5, 5}, {1.f, 0.f, 0.f, 0.f, 0.f, 2.f, 1.f, 0.f, 0.f, 0.f, 30.f, 2.f, 1.f, 0.f, 0.f, 4.f, 3.f, 2.f, 1.f, 0.f, 5.f, 4.f, 3.f, 2.f, 1.f, }); - auto exp = NDArrayFactory::create('c', {1, 5, 5}, {1.0f, 0.0f, 0.0f, 0.0f, 0.f, -2.0f, 1.0f, 0.f, 0.f, 0.f, -26.0f, -2.0f, 1.f, 0.f, 0.f, 54.0f, 1.0f, -2.0f, 1.f, 0.f, -27.0f, 0.0f, 1.0f, -2.0f, 1.f}); - - sd::ops::matrix_inverse op; - auto result = op.evaluate({&x}); + sd::ops::matrix_inverse op; + auto result = op.evaluate({&x}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, MatrixInverse_01) { + auto x = NDArrayFactory::create( + 'c', {1, 5, 5}, + {2.f, 4.f, 60.f, 8.f, 10.f, 0.f, 1.f, 2.f, 3.f, 4.f, 0.f, 0.f, 2.f, + 4.f, 6.f, 0.f, 0.f, 0.f, 1.f, 2.f, 0.f, 0.f, 0.f, 0.f, 4.f}); - auto x = NDArrayFactory::create('c', {1, 5, 5}, {2.f, 4.f, 60.f, 8.f, 10.f, 0.f, 1.f, 2.f, 3.f, 4.f, 0.f, 0.f, 2.f, 4.f, 6.f, 0.f, 0.f, 0.f, 1.f, 2.f, 0.f, 0.f, 0.f, 0.f, 4.f }); + auto exp = NDArrayFactory::create( + 'c', {1, 5, 5}, + {0.5f, -2.0f, -13.0f, 54.0f, -6.75f, 0.0f, 1.0f, -1.0f, 1.0f, + 0.0f, 0.f, 0.f, 0.5f, -2.0f, 0.25f, 0.f, 0.f, 0.f, + 1.0f, -0.5f, 0.f, 0.f, 0.f, 0.f, 0.25f}); + sd::ops::matrix_inverse op; + auto result = op.evaluate({&x}); - auto exp = NDArrayFactory::create('c', {1, 5, 5}, {0.5f, -2.0f, -13.0f, 54.0f, -6.75f, 0.0f, 1.0f, -1.0f, 1.0f, 0.0f, 0.f, 0.f, 0.5f, -2.0f, 0.25f, 0.f, 0.f, 0.f, 1.0f, -0.5f, 0.f, 0.f, 0.f, 0.f, 0.25f }); - sd::ops::matrix_inverse op; - auto result = op.evaluate({&x}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, MatrixInverse_02) { + auto x = NDArrayFactory::create( + 'c', {1, 5, 5}, + {1.f, 0.f, 0.f, 0.f, 0.f, 2.f, 1.f, 0.f, 0.f, 0.f, 30.f, 2.f, 1.f, + 0.f, 0.f, 4.f, 3.f, 2.f, 1.f, 0.f, 5.f, 4.f, 3.f, 2.f, 1.f}); + auto exp = NDArrayFactory::create( + 'c', {1, 5, 5}, + {1.0f, 0.0f, 0.0f, 0.0f, 0.f, -2.0f, 1.0f, 0.f, 0.f, + 0.f, -26.0f, -2.0f, 1.f, 0.f, 0.f, 54.0f, 1.0f, -2.0f, + 1.f, 0.f, -27.0f, 0.0f, 1.0f, -2.0f, 1.f}); - auto x = NDArrayFactory::create('c', {1, 5, 5}, {1.f, 0.f, 0.f, 0.f, 0.f, 2.f, 1.f, 0.f, 0.f, 0.f, 30.f, 2.f, 1.f, 0.f, 0.f, 4.f, 3.f, 2.f, 1.f, 0.f, 5.f, 4.f, 3.f, 2.f, 1.f }); - auto exp = NDArrayFactory::create('c', {1, 5, 5}, {1.0f, 0.0f, 0.0f, 0.0f, 0.f, -2.0f, 1.0f, 0.f, 0.f, 0.f, -26.0f, -2.0f, 1.f, 0.f, 0.f, 54.0f, 1.0f, -2.0f, 1.f, 0.f, -27.0f, 0.0f, 1.0f, -2.0f, 1.f }); + sd::ops::matrix_inverse op; + auto result = op.evaluate({&x}); - sd::ops::matrix_inverse op; - auto result = op.evaluate({&x}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// @@ -1714,1092 +1643,1265 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_2) { } */ TEST_F(DeclarableOpsTests6, MatrixInverse_03) { + auto x = NDArrayFactory::create( + 'c', {5, 5}, + { + 4.f, 0.f, 0.f, 0.f, 0.f, 4.f, 2.f, 0.f, 0.f, 0.f, 30.f, 2.f, 1.f, + 0.f, 0.f, 8.f, 6.f, 4.f, 2.f, 0.f, 15.f, 12.f, 9.f, 6.f, 3.f, + }); - auto x = NDArrayFactory::create('c', {5, 5}, { - 4.f, 0.f, 0.f, 0.f, 0.f, - 4.f, 2.f, 0.f, 0.f, 0.f, - 30.f, 2.f, 1.f, 0.f, 0.f, - 8.f, 6.f, 4.f, 2.f, 0.f, - 15.f, 12.f, 9.f, 6.f, 3.f, - }); + auto exp = NDArrayFactory::create( + 'c', {5, 5}, + {0.25f, 0.0f, 0.0f, 0.0f, 0.0f, -0.50f, 0.5f, 0.0f, 0.0f, + 0.0f, -6.50f, -1.0f, 1.0f, 0.0f, 0.0f, 13.50f, 0.5f, -2.0f, + 0.5f, 0.0f, -6.75f, 0.0f, 1.0f, -1.0f, 0.33333333f}); - auto exp = NDArrayFactory::create('c', {5, 5}, { - 0.25f, 0.0f, 0.0f, 0.0f, 0.0f, - -0.50f, 0.5f, 0.0f, 0.0f, 0.0f, - -6.50f, -1.0f, 1.0f, 0.0f, 0.0f, - 13.50f, 0.5f, -2.0f, 0.5f, 0.0f, - -6.75f, 0.0f, 1.0f, -1.0f, 0.33333333f - }); + sd::ops::matrix_inverse op; + auto result = op.evaluate({&x}); - sd::ops::matrix_inverse op; - auto result = op.evaluate({&x}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); -// z->printIndexedBuffer("Output "); -// exp.printIndexedBuffer("Expected "); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + // z->printIndexedBuffer("Output "); + // exp.printIndexedBuffer("Expected "); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, MatrixInverse_3) { + auto x = NDArrayFactory::create( + 'c', {5, 5}, + { + 4.f, 0.f, 0.f, 0.f, 0.f, 4.f, 2.f, 0.f, 0.f, 0.f, 30.f, 2.f, 1.f, + 0.f, 0.f, 8.f, 6.f, 4.f, 2.f, 0.f, 15.f, 12.f, 9.f, 6.f, 3.f, + }); - auto x = NDArrayFactory::create('c', {5, 5}, { - 4.f, 0.f, 0.f, 0.f, 0.f, - 4.f, 2.f, 0.f, 0.f, 0.f, - 30.f, 2.f, 1.f, 0.f, 0.f, - 8.f, 6.f, 4.f, 2.f, 0.f, - 15.f, 12.f, 9.f, 6.f, 3.f, - }); - - auto exp = NDArrayFactory::create('c', {5, 5}, { - 0.25f, 0.0f, 0.0f, 0.0f, 0.0f, - -0.50f, 0.5f, 0.0f, 0.0f, 0.0f, - -6.50f, -1.0f, 1.0f, 0.0f, 0.0f, - 13.50f, 0.5f, -2.0f, 0.5f, 0.0f, - -6.75f, 0.0f, 1.0f, -1.0f, 0.33333333f - }); - - sd::ops::matrix_inverse op; - auto result = op.evaluate({&x}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto exp = NDArrayFactory::create( + 'c', {5, 5}, + {0.25f, 0.0f, 0.0f, 0.0f, 0.0f, -0.50f, 0.5f, 0.0f, 0.0f, + 0.0f, -6.50f, -1.0f, 1.0f, 0.0f, 0.0f, 13.50f, 0.5f, -2.0f, + 0.5f, 0.0f, -6.75f, 0.0f, 1.0f, -1.0f, 0.33333333f}); - auto z = result.at(0); -// exp.printIndexedBuffer("Expected "); -// z->printIndexedBuffer("Output "); + sd::ops::matrix_inverse op; + auto result = op.evaluate({&x}); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + // exp.printIndexedBuffer("Expected "); + // z->printIndexedBuffer("Output "); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, MatrixInverse_4) { + auto x = NDArrayFactory::create( + 'c', {5, 5}, + {1.f, 2.f, 30.f, 4.f, 5.f, 0.f, 1.f, 2.f, 3.f, 4.f, 0.f, 0.f, 1.f, + 2.f, 3.f, 0.f, 0.f, 0.f, 1.f, 2.f, 0.f, 0.f, 0.f, 0.f, 1.f}); - auto x = NDArrayFactory::create('c', {5, 5}, { - 1.f, 2.f, 30.f, 4.f, 5.f, - 0.f, 1.f, 2.f, 3.f, 4.f, - 0.f, 0.f, 1.f, 2.f, 3.f, - 0.f, 0.f, 0.f, 1.f, 2.f, - 0.f, 0.f, 0.f, 0.f, 1.f - }); - - auto exp = NDArrayFactory::create('c', {5, 5}, { - 1.0f, -2.0f, -26.0f, 54.0f, -27.0f, - 0.0f, 1.0f, -2.0f, 1.0f, 0.0f, - 0.0f, 0.0f, 1.0f, -2.0f, 1.0f, - 0.0f, 0.0f, 0.0f, 1.0f, -2.0f, - 0.0f, 0.0f, 0.0f, 0.0f, 1.0f - }); - - sd::ops::matrix_inverse op; - auto result = op.evaluate({&x}); + auto exp = NDArrayFactory::create( + 'c', {5, 5}, {1.0f, -2.0f, -26.0f, 54.0f, -27.0f, 0.0f, 1.0f, -2.0f, 1.0f, + 0.0f, 0.0f, 0.0f, 1.0f, -2.0f, 1.0f, 0.0f, 0.0f, 0.0f, + 1.0f, -2.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::matrix_inverse op; + auto result = op.evaluate({&x}); - auto z = result.at(0); -// z->printIndexedBuffer("Output "); -// exp.printIndexedBuffer("Expected "); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + // z->printIndexedBuffer("Output "); + // exp.printIndexedBuffer("Expected "); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, MatrixInverse_04) { + auto x = NDArrayFactory::create( + 'c', {5, 5}, + {1.f, 2.f, 30.f, 4.f, 5.f, 0.f, 1.f, 2.f, 3.f, 4.f, 0.f, 0.f, 1.f, + 2.f, 3.f, 0.f, 0.f, 0.f, 1.f, 2.f, 0.f, 0.f, 0.f, 0.f, 1.f}); - auto x = NDArrayFactory::create('c', {5, 5}, { - 1.f, 2.f, 30.f, 4.f, 5.f, - 0.f, 1.f, 2.f, 3.f, 4.f, - 0.f, 0.f, 1.f, 2.f, 3.f, - 0.f, 0.f, 0.f, 1.f, 2.f, - 0.f, 0.f, 0.f, 0.f, 1.f - }); - - auto exp = NDArrayFactory::create('c', {5, 5}, { - 1.0f, -2.0f, -26.0f, 54.0f, -27.0f, - 0.0f, 1.0f, -2.0f, 1.0f, 0.0f, - 0.0f, 0.0f, 1.0f, -2.0f, 1.0f, - 0.0f, 0.0f, 0.0f, 1.0f, -2.0f, - 0.0f, 0.0f, 0.0f, 0.0f, 1.0f - }); - - sd::ops::matrix_inverse op; - auto result = op.evaluate({&x}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto exp = NDArrayFactory::create( + 'c', {5, 5}, {1.0f, -2.0f, -26.0f, 54.0f, -27.0f, 0.0f, 1.0f, -2.0f, 1.0f, + 0.0f, 0.0f, 0.0f, 1.0f, -2.0f, 1.0f, 0.0f, 0.0f, 0.0f, + 1.0f, -2.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f}); - auto z = result.at(0); -// z->printIndexedBuffer("Output "); -// exp.printIndexedBuffer("Expected "); + sd::ops::matrix_inverse op; + auto result = op.evaluate({&x}); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + // z->printIndexedBuffer("Output "); + // exp.printIndexedBuffer("Expected "); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, ReluLayer_1) { - auto x = NDArrayFactory::create('c', {3, 4}, {1.0, -2.0, 3.0, 4.0, 5.0, -6.0, 7.0, 8.0, 9.0, -10.0, 11.0, 12}); - auto w = NDArrayFactory::create('c', {4, 3}, {0.5, 0.1, 0.8, 0.5, 0.2, 0.5, 0.5, 0.25, 0.5, 0.1, 0.0, 0.25}); - auto b = NDArrayFactory::create({20.0, 30.0, 50.0}); - + auto x = NDArrayFactory::create( + 'c', {3, 4}, + {1.0, -2.0, 3.0, 4.0, 5.0, -6.0, 7.0, 8.0, 9.0, -10.0, 11.0, 12}); + auto w = NDArrayFactory::create( + 'c', {4, 3}, + {0.5, 0.1, 0.8, 0.5, 0.2, 0.5, 0.5, 0.25, 0.5, 0.1, 0.0, 0.25}); + auto b = NDArrayFactory::create({20.0, 30.0, 50.0}); + auto exp = NDArrayFactory::create( + 'c', {3, 3}, {21.4, 30.45, 52.3, 23.8, 31.05, 56.5, 26.2, 31.65, 60.7}); - auto exp = NDArrayFactory::create('c', {3, 3}, { - 21.4, 30.45, 52.3, - 23.8, 31.05, 56.5, - 26.2, 31.65, 60.7}); + sd::ops::relu_layer op; + auto result = op.evaluate({&x, &w, &b}); - sd::ops::relu_layer op; - auto result = op.evaluate({&x, &w, &b}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - // z->printShapeInfo("Output shape"); - // z->printIndexedBuffer("Output "); - // exp.printIndexedBuffer("Expected "); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + // z->printShapeInfo("Output shape"); + // z->printIndexedBuffer("Output "); + // exp.printIndexedBuffer("Expected "); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests6, Test_Reduce3_Edge) { - auto x = NDArrayFactory::create('c', {3, 4, 5}); - auto y = NDArrayFactory::create('c', {3, 4, 5}); - + auto x = NDArrayFactory::create('c', {3, 4, 5}); + auto y = NDArrayFactory::create('c', {3, 4, 5}); - std::vector dims = {0, 1}; - auto z = x.applyReduce3(reduce3::CosineSimilarity, y, dims); - ASSERT_TRUE(&z != nullptr); + std::vector dims = {0, 1}; + auto z = x.applyReduce3(reduce3::CosineSimilarity, y, dims); + ASSERT_TRUE(z.defined()); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, static_rnn_test1) { - - const int bS = 2; - const int inSize = 3; - const int numUnits = 4; - const int time = 5; - - auto x = NDArrayFactory::create('c', {time, bS, inSize}); - auto Wx = NDArrayFactory::create('c', {inSize, numUnits}); - auto Wh = NDArrayFactory::create('c', {numUnits, numUnits}); - auto b = NDArrayFactory::create('c', {2*numUnits}); - auto h0 = NDArrayFactory::create('c', {bS, numUnits}); - auto maxTimeStep = NDArrayFactory::create('c', {bS}, {time-1, time-3}); - - x.linspace(0.01, 0.01); - h0 = 0.2; - Wx = 0.3; - Wh = 0.4; - b = 0.25; - - auto expH = NDArrayFactory::create('c', {time, bS, numUnits}, {0.68474828, 0.68474828, 0.68474828, 0.68474828,0.69882484, 0.69882484, 0.69882484, 0.69882484, 0.9312333 , 0.9312333 , 0.9312333 , 0.9312333 , - 0.93751527, 0.93751527, 0.93751527, 0.93751527,0.97136768, 0.97136768, 0.97136768, 0.97136768,0., 0., 0., 0. , - 0.97732812, 0.97732812, 0.97732812, 0.97732812,0., 0., 0., 0. ,0., 0., 0., 0.,0., 0., 0., 0.}); - - auto expHFinal = NDArrayFactory::create('c', {bS, numUnits}, {0.97732812, 0.97732812, 0.97732812, 0.97732812, 0.93751527, 0.93751527, 0.93751527, 0.93751527}); - - sd::ops::static_rnn op; - auto results = op.evaluate({&x, &Wx, &Wh, &b, &h0, &maxTimeStep}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto h = results.at(0); - auto hFinal = results.at(1); - - ASSERT_TRUE(expH.isSameShape(h)); - ASSERT_TRUE(expH.equalsTo(h)); - ASSERT_TRUE(expHFinal.isSameShape(hFinal)); - ASSERT_TRUE(expHFinal.equalsTo(hFinal)); - - + const int bS = 2; + const int inSize = 3; + const int numUnits = 4; + const int time = 5; + + auto x = NDArrayFactory::create('c', {time, bS, inSize}); + auto Wx = NDArrayFactory::create('c', {inSize, numUnits}); + auto Wh = NDArrayFactory::create('c', {numUnits, numUnits}); + auto b = NDArrayFactory::create('c', {2 * numUnits}); + auto h0 = NDArrayFactory::create('c', {bS, numUnits}); + auto maxTimeStep = + NDArrayFactory::create('c', {bS}, {time - 1, time - 3}); + + x.linspace(0.01, 0.01); + h0 = 0.2; + Wx = 0.3; + Wh = 0.4; + b = 0.25; + + auto expH = NDArrayFactory::create( + 'c', {time, bS, numUnits}, + {0.68474828, 0.68474828, 0.68474828, 0.68474828, 0.69882484, 0.69882484, + 0.69882484, 0.69882484, 0.9312333, 0.9312333, 0.9312333, 0.9312333, + 0.93751527, 0.93751527, 0.93751527, 0.93751527, 0.97136768, 0.97136768, + 0.97136768, 0.97136768, 0., 0., 0., 0., + 0.97732812, 0.97732812, 0.97732812, 0.97732812, 0., 0., + 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0.}); + + auto expHFinal = NDArrayFactory::create( + 'c', {bS, numUnits}, + {0.97732812, 0.97732812, 0.97732812, 0.97732812, 0.93751527, 0.93751527, + 0.93751527, 0.93751527}); + + sd::ops::static_rnn op; + auto results = op.evaluate({&x, &Wx, &Wh, &b, &h0, &maxTimeStep}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto h = results.at(0); + auto hFinal = results.at(1); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + ASSERT_TRUE(expHFinal.isSameShape(hFinal)); + ASSERT_TRUE(expHFinal.equalsTo(hFinal)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, static_rnn_test2) { - - const int bS = 2; - const int inSize = 3; - const int numUnits = 4; - const int time = 5; - - auto x = NDArrayFactory::create('c', {time, bS, inSize}); - auto Wx = NDArrayFactory::create('c', {inSize, numUnits}); - auto Wh = NDArrayFactory::create('c', {numUnits, numUnits}); - auto b = NDArrayFactory::create('c', {2*numUnits}); - auto h0 = NDArrayFactory::create('c', {bS, numUnits}); - - x.linspace(0.01, 0.01); - h0 = 0.2; - Wx = 0.3; - Wh = 0.4; - b = 0.25; - - auto expH = NDArrayFactory::create('c', {time, bS, numUnits}, {0.68474828, 0.68474828, 0.68474828, 0.68474828,0.69882484, 0.69882484, 0.69882484, 0.69882484,0.9312333 , 0.9312333 , 0.9312333 , 0.9312333, - 0.93751527, 0.93751527, 0.93751527, 0.93751527,0.97136768, 0.97136768, 0.97136768, 0.97136768,0.97338548, 0.97338548, 0.97338548, 0.97338548, - 0.97732812, 0.97732812, 0.97732812, 0.97732812,0.97864398, 0.97864398, 0.97864398, 0.97864398,0.98000654, 0.98000654, 0.98000654, 0.98000654, - 0.98112648, 0.98112648, 0.98112648, 0.98112648}); - - auto expHFinal = NDArrayFactory::create('c', {bS, numUnits}, {0.98000654, 0.98000654, 0.98000654, 0.98000654,0.98112648, 0.98112648, 0.98112648, 0.98112648}); - - sd::ops::static_rnn op; - auto results = op.evaluate({&x, &Wx, &Wh, &b, &h0}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto h = results.at(0); - auto hFinal = results.at(1); - - ASSERT_TRUE(expH.isSameShape(h)); - ASSERT_TRUE(expH.equalsTo(h)); - ASSERT_TRUE(expHFinal.isSameShape(hFinal)); - ASSERT_TRUE(expHFinal.equalsTo(hFinal)); - - + const int bS = 2; + const int inSize = 3; + const int numUnits = 4; + const int time = 5; + + auto x = NDArrayFactory::create('c', {time, bS, inSize}); + auto Wx = NDArrayFactory::create('c', {inSize, numUnits}); + auto Wh = NDArrayFactory::create('c', {numUnits, numUnits}); + auto b = NDArrayFactory::create('c', {2 * numUnits}); + auto h0 = NDArrayFactory::create('c', {bS, numUnits}); + + x.linspace(0.01, 0.01); + h0 = 0.2; + Wx = 0.3; + Wh = 0.4; + b = 0.25; + + auto expH = NDArrayFactory::create( + 'c', {time, bS, numUnits}, + {0.68474828, 0.68474828, 0.68474828, 0.68474828, 0.69882484, 0.69882484, + 0.69882484, 0.69882484, 0.9312333, 0.9312333, 0.9312333, 0.9312333, + 0.93751527, 0.93751527, 0.93751527, 0.93751527, 0.97136768, 0.97136768, + 0.97136768, 0.97136768, 0.97338548, 0.97338548, 0.97338548, 0.97338548, + 0.97732812, 0.97732812, 0.97732812, 0.97732812, 0.97864398, 0.97864398, + 0.97864398, 0.97864398, 0.98000654, 0.98000654, 0.98000654, 0.98000654, + 0.98112648, 0.98112648, 0.98112648, 0.98112648}); + + auto expHFinal = NDArrayFactory::create( + 'c', {bS, numUnits}, + {0.98000654, 0.98000654, 0.98000654, 0.98000654, 0.98112648, 0.98112648, + 0.98112648, 0.98112648}); + + sd::ops::static_rnn op; + auto results = op.evaluate({&x, &Wx, &Wh, &b, &h0}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto h = results.at(0); + auto hFinal = results.at(1); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + ASSERT_TRUE(expHFinal.isSameShape(hFinal)); + ASSERT_TRUE(expHFinal.equalsTo(hFinal)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, static_rnn_test3) { - - const int bS = 2; - const int inSize = 3; - const int numUnits = 4; - const int time = 5; - - auto x = NDArrayFactory::create('c', {time, bS, inSize}); - auto Wx = NDArrayFactory::create('c', {inSize, numUnits}); - auto Wh = NDArrayFactory::create('c', {numUnits, numUnits}); - auto b = NDArrayFactory::create('c', {2*numUnits}); - auto h0 = NDArrayFactory::create('c', {bS, numUnits}); - auto maxTimeStep = NDArrayFactory::create('c', {bS}, {time-1, 0}); - - x.linspace(0.01, 0.01); - h0 = 0.2; - Wx = 0.3; - Wh = 0.4; - b = 0.25; - - auto expH = NDArrayFactory::create('c', {time, bS, numUnits}, {0.68474828, 0.68474828, 0.68474828, 0.68474828,0., 0., 0., 0., 0.9312333, 0.9312333, 0.9312333, 0.9312333, - 0., 0., 0., 0. , 0.97136768, 0.97136768, 0.97136768, 0.97136768,0., 0., 0., 0. , - 0.97732812, 0.97732812, 0.97732812, 0.97732812,0., 0., 0., 0., 0., 0., 0., 0.,0., 0., 0., 0.}); - - auto expHFinal = NDArrayFactory::create('c', {bS, numUnits}, {0.97732812, 0.97732812, 0.97732812, 0.97732812, 0.2 , 0.2 , 0.2 , 0.2}); - - sd::ops::static_rnn op; - auto results = op.evaluate({&x, &Wx, &Wh, &b, &h0, &maxTimeStep}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto h = results.at(0); - auto hFinal = results.at(1); - - ASSERT_TRUE(expH.isSameShape(h)); - ASSERT_TRUE(expH.equalsTo(h)); - ASSERT_TRUE(expHFinal.isSameShape(hFinal)); - ASSERT_TRUE(expHFinal.equalsTo(hFinal)); - - + const int bS = 2; + const int inSize = 3; + const int numUnits = 4; + const int time = 5; + + auto x = NDArrayFactory::create('c', {time, bS, inSize}); + auto Wx = NDArrayFactory::create('c', {inSize, numUnits}); + auto Wh = NDArrayFactory::create('c', {numUnits, numUnits}); + auto b = NDArrayFactory::create('c', {2 * numUnits}); + auto h0 = NDArrayFactory::create('c', {bS, numUnits}); + auto maxTimeStep = NDArrayFactory::create('c', {bS}, {time - 1, 0}); + + x.linspace(0.01, 0.01); + h0 = 0.2; + Wx = 0.3; + Wh = 0.4; + b = 0.25; + + auto expH = NDArrayFactory::create( + 'c', {time, bS, numUnits}, + {0.68474828, 0.68474828, 0.68474828, 0.68474828, 0., 0., 0., 0., + 0.9312333, 0.9312333, 0.9312333, 0.9312333, 0., 0., 0., 0., + 0.97136768, 0.97136768, 0.97136768, 0.97136768, 0., 0., 0., 0., + 0.97732812, 0.97732812, 0.97732812, 0.97732812, 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0.}); + + auto expHFinal = NDArrayFactory::create( + 'c', {bS, numUnits}, + {0.97732812, 0.97732812, 0.97732812, 0.97732812, 0.2, 0.2, 0.2, 0.2}); + + sd::ops::static_rnn op; + auto results = op.evaluate({&x, &Wx, &Wh, &b, &h0, &maxTimeStep}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto h = results.at(0); + auto hFinal = results.at(1); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + ASSERT_TRUE(expHFinal.isSameShape(hFinal)); + ASSERT_TRUE(expHFinal.equalsTo(hFinal)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, static_rnn_test4) { - - const int bS = 2; - const int inSize = 3; - const int numUnits = 4; - const int time = 5; - - auto x = NDArrayFactory::create('c', {time, bS, inSize}); - auto Wx = NDArrayFactory::create('c', {inSize, numUnits}); - auto Wh = NDArrayFactory::create('c', {numUnits, numUnits}); - auto b = NDArrayFactory::create('c', {2*numUnits}); - auto h0 = NDArrayFactory::create('c', {bS, numUnits}); - auto maxTimeStep = NDArrayFactory::create('c', {bS}, {time-1, time-3}); - - x.linspace(0.01, 0.01); - Wx = 0.3; - Wh = 0.4; - b = 0.25; - - auto expH = NDArrayFactory::create('c', {time, bS, numUnits}, {0.47615493, 0.47615493, 0.47615493, 0.47615493,0.49676344, 0.49676344, 0.49676344, 0.49676344, 0.87018664, 0.87018664, 0.87018664, 0.87018664, - 0.88400882, 0.88400882, 0.88400882, 0.88400882, 0.96529784, 0.96529784, 0.96529784, 0.96529784,0., 0., 0., 0. , - 0.97688859, 0.97688859, 0.97688859, 0.97688859,0., 0., 0., 0., 0., 0., 0., 0.,0., 0., 0., 0.}); - - auto expHFinal = NDArrayFactory::create('c', {bS, numUnits}, {0.97688859, 0.97688859, 0.97688859, 0.97688859, 0.88400882, 0.88400882, 0.88400882, 0.88400882}); - - sd::ops::static_rnn op; - auto results = op.evaluate({&x, &Wx, &Wh, &b, &maxTimeStep}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto h = results.at(0); - auto hFinal = results.at(1); - - ASSERT_TRUE(expH.isSameShape(h)); - ASSERT_TRUE(expH.equalsTo(h)); - ASSERT_TRUE(expHFinal.isSameShape(hFinal)); - ASSERT_TRUE(expHFinal.equalsTo(hFinal)); - - + const int bS = 2; + const int inSize = 3; + const int numUnits = 4; + const int time = 5; + + auto x = NDArrayFactory::create('c', {time, bS, inSize}); + auto Wx = NDArrayFactory::create('c', {inSize, numUnits}); + auto Wh = NDArrayFactory::create('c', {numUnits, numUnits}); + auto b = NDArrayFactory::create('c', {2 * numUnits}); + auto h0 = NDArrayFactory::create('c', {bS, numUnits}); + auto maxTimeStep = + NDArrayFactory::create('c', {bS}, {time - 1, time - 3}); + + x.linspace(0.01, 0.01); + Wx = 0.3; + Wh = 0.4; + b = 0.25; + + auto expH = NDArrayFactory::create( + 'c', {time, bS, numUnits}, + {0.47615493, 0.47615493, 0.47615493, 0.47615493, 0.49676344, 0.49676344, + 0.49676344, 0.49676344, 0.87018664, 0.87018664, 0.87018664, 0.87018664, + 0.88400882, 0.88400882, 0.88400882, 0.88400882, 0.96529784, 0.96529784, + 0.96529784, 0.96529784, 0., 0., 0., 0., + 0.97688859, 0.97688859, 0.97688859, 0.97688859, 0., 0., + 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0.}); + + auto expHFinal = NDArrayFactory::create( + 'c', {bS, numUnits}, + {0.97688859, 0.97688859, 0.97688859, 0.97688859, 0.88400882, 0.88400882, + 0.88400882, 0.88400882}); + + sd::ops::static_rnn op; + auto results = op.evaluate({&x, &Wx, &Wh, &b, &maxTimeStep}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto h = results.at(0); + auto hFinal = results.at(1); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + ASSERT_TRUE(expHFinal.isSameShape(hFinal)); + ASSERT_TRUE(expHFinal.equalsTo(hFinal)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, static_rnn_test5) { - - const int bS = 2; - const int inSize = 3; - const int numUnits = 4; - const int time = 5; - - auto x = NDArrayFactory::create('c', {time, bS, inSize}); - auto Wx = NDArrayFactory::create('c', {inSize, numUnits}); - auto Wh = NDArrayFactory::create('c', {numUnits, numUnits}); - auto b = NDArrayFactory::create('c', {2*numUnits}); - auto h0 = NDArrayFactory::create('c', {bS, numUnits}); - - x.linspace(0.01, 0.01); - Wx = 0.3; - Wh = 0.4; - b = 0.25; - - auto expH = NDArrayFactory::create('c', {time, bS, numUnits}, {0.47615493, 0.47615493, 0.47615493, 0.47615493,0.49676344, 0.49676344, 0.49676344, 0.49676344, 0.87018664, 0.87018664, 0.87018664, 0.87018664, - 0.88400882, 0.88400882, 0.88400882, 0.88400882, 0.96529784, 0.96529784, 0.96529784, 0.96529784,0.96849345, 0.96849345, 0.96849345, 0.96849345, - 0.97688859, 0.97688859, 0.97688859, 0.97688859,0.97831069, 0.97831069, 0.97831069, 0.97831069, 0.97997868, 0.97997868, 0.97997868, 0.97997868, - 0.98110653, 0.98110653, 0.98110653, 0.98110653}); - - auto expHFinal = NDArrayFactory::create('c', {bS, numUnits}, {0.97997868, 0.97997868, 0.97997868, 0.97997868, 0.98110653, 0.98110653, 0.98110653, 0.98110653}); - - sd::ops::static_rnn op; - auto results = op.evaluate({&x, &Wx, &Wh, &b}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto h = results.at(0); - auto hFinal = results.at(1); - - ASSERT_TRUE(expH.isSameShape(h)); - ASSERT_TRUE(expH.equalsTo(h)); - ASSERT_TRUE(expHFinal.isSameShape(hFinal)); - ASSERT_TRUE(expHFinal.equalsTo(hFinal)); - - + const int bS = 2; + const int inSize = 3; + const int numUnits = 4; + const int time = 5; + + auto x = NDArrayFactory::create('c', {time, bS, inSize}); + auto Wx = NDArrayFactory::create('c', {inSize, numUnits}); + auto Wh = NDArrayFactory::create('c', {numUnits, numUnits}); + auto b = NDArrayFactory::create('c', {2 * numUnits}); + auto h0 = NDArrayFactory::create('c', {bS, numUnits}); + + x.linspace(0.01, 0.01); + Wx = 0.3; + Wh = 0.4; + b = 0.25; + + auto expH = NDArrayFactory::create( + 'c', {time, bS, numUnits}, + {0.47615493, 0.47615493, 0.47615493, 0.47615493, 0.49676344, 0.49676344, + 0.49676344, 0.49676344, 0.87018664, 0.87018664, 0.87018664, 0.87018664, + 0.88400882, 0.88400882, 0.88400882, 0.88400882, 0.96529784, 0.96529784, + 0.96529784, 0.96529784, 0.96849345, 0.96849345, 0.96849345, 0.96849345, + 0.97688859, 0.97688859, 0.97688859, 0.97688859, 0.97831069, 0.97831069, + 0.97831069, 0.97831069, 0.97997868, 0.97997868, 0.97997868, 0.97997868, + 0.98110653, 0.98110653, 0.98110653, 0.98110653}); + + auto expHFinal = NDArrayFactory::create( + 'c', {bS, numUnits}, + {0.97997868, 0.97997868, 0.97997868, 0.97997868, 0.98110653, 0.98110653, + 0.98110653, 0.98110653}); + + sd::ops::static_rnn op; + auto results = op.evaluate({&x, &Wx, &Wh, &b}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto h = results.at(0); + auto hFinal = results.at(1); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + ASSERT_TRUE(expHFinal.isSameShape(hFinal)); + ASSERT_TRUE(expHFinal.equalsTo(hFinal)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, static_bidir_rnn_test1) { - - const int bS = 4; - const int inSize = 4; - const int numUnitsFW = 3; - const int numUnitsBW = 3; - const int time = 5; - - auto x = NDArrayFactory::create('c', {time, bS, inSize}); - auto WxFW = NDArrayFactory::create('c', {inSize, numUnitsFW}); - auto WhFW = NDArrayFactory::create('c', {numUnitsFW, numUnitsFW}); - auto bFW = NDArrayFactory::create('c', {2*numUnitsFW}); - - auto h0FW = NDArrayFactory::create('c', {bS, numUnitsFW}); - auto h0BW = NDArrayFactory::create('c', {bS, numUnitsBW}); - auto maxTimeStep = NDArrayFactory::create('c', {bS}, {time-1, time-3, time-4, 0}); - - x.linspace(0.01, 0.01); - h0FW = 0.2; - h0BW = 0.25; - WxFW = 0.3; - WhFW = 0.4; - bFW = 0.1; - - auto expH = NDArrayFactory::create('c', {time, bS, numUnitsFW+numUnitsBW}, {0.43819931, 0.43819931, 0.43819931, 0.86708881, 0.86708881,0.86708881,0.47615493, 0.47615493, 0.47615493, 0.78347842, 0.78347842,0.78347842, - 0.51241561, 0.51241561, 0.51241561, 0.55529176, 0.55529176,0.55529176,0., 0., 0., 0., 0.,0.,0.73880324, 0.73880324, 0.73880324, 0.90935605, 0.90935605, - 0.90935605, 0.77843476, 0.77843476, 0.77843476, 0.64692945, 0.64692945,0.64692945,0., 0., 0., 0., 0.,0.,0., 0., 0., 0., 0.,0., - 0.9052501, 0.9052501, 0.9052501, 0.9181592, 0.9181592, 0.9181592,0., 0., 0., 0., 0., 0.,0., 0., 0., 0., 0., 0.,0., 0., 0., 0., 0., 0., - 0.9555734, 0.9555734, 0.9555734, 0.8026439, 0.8026439, 0.8026439,0., 0., 0., 0., 0., 0.,0., 0., 0., 0., 0., 0.,0., 0., 0., 0., 0., 0., - 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.}); - - auto expHFWfinal = NDArrayFactory::create('c', {bS, numUnitsFW}, {0.9555734 , 0.9555734 , 0.9555734 , 0.77843476, 0.77843476, 0.77843476, 0.51241561, 0.51241561, 0.51241561, 0.2, 0.2, 0.2}); - auto expHBWfinal = NDArrayFactory::create('c', {bS, numUnitsBW}, {0.86708881, 0.86708881, 0.86708881, 0.78347842, 0.78347842, 0.78347842, 0.55529176, 0.55529176, 0.55529176, 0.25, 0.25, 0.25}); - - sd::ops::static_bidirectional_rnn op; - auto results = op.evaluate({&x, &WxFW,&WhFW,&bFW, &WxFW,&WhFW,&bFW, &h0FW, &h0BW, &maxTimeStep}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto h = results.at(0); - auto hFWfinal = results.at(1); - auto hBWfinal = results.at(2); - - ASSERT_TRUE(expH.isSameShape(h)); - ASSERT_TRUE(expH.equalsTo(h)); - ASSERT_TRUE(expHFWfinal.isSameShape(hFWfinal)); - ASSERT_TRUE(expHFWfinal.equalsTo(hFWfinal)); - ASSERT_TRUE(expHBWfinal.isSameShape(hBWfinal)); - ASSERT_TRUE(expHBWfinal.equalsTo(hBWfinal)); - - + const int bS = 4; + const int inSize = 4; + const int numUnitsFW = 3; + const int numUnitsBW = 3; + const int time = 5; + + auto x = NDArrayFactory::create('c', {time, bS, inSize}); + auto WxFW = NDArrayFactory::create('c', {inSize, numUnitsFW}); + auto WhFW = NDArrayFactory::create('c', {numUnitsFW, numUnitsFW}); + auto bFW = NDArrayFactory::create('c', {2 * numUnitsFW}); + + auto h0FW = NDArrayFactory::create('c', {bS, numUnitsFW}); + auto h0BW = NDArrayFactory::create('c', {bS, numUnitsBW}); + auto maxTimeStep = NDArrayFactory::create( + 'c', {bS}, {time - 1, time - 3, time - 4, 0}); + + x.linspace(0.01, 0.01); + h0FW = 0.2; + h0BW = 0.25; + WxFW = 0.3; + WhFW = 0.4; + bFW = 0.1; + + auto expH = NDArrayFactory::create( + 'c', {time, bS, numUnitsFW + numUnitsBW}, + {0.43819931, 0.43819931, 0.43819931, 0.86708881, 0.86708881, 0.86708881, + 0.47615493, 0.47615493, 0.47615493, 0.78347842, 0.78347842, 0.78347842, + 0.51241561, 0.51241561, 0.51241561, 0.55529176, 0.55529176, 0.55529176, + 0., 0., 0., 0., 0., 0., + 0.73880324, 0.73880324, 0.73880324, 0.90935605, 0.90935605, 0.90935605, + 0.77843476, 0.77843476, 0.77843476, 0.64692945, 0.64692945, 0.64692945, + 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., + 0.9052501, 0.9052501, 0.9052501, 0.9181592, 0.9181592, 0.9181592, + 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., + 0.9555734, 0.9555734, 0.9555734, 0.8026439, 0.8026439, 0.8026439, + 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0.}); + + auto expHFWfinal = NDArrayFactory::create( + 'c', {bS, numUnitsFW}, + {0.9555734, 0.9555734, 0.9555734, 0.77843476, 0.77843476, 0.77843476, + 0.51241561, 0.51241561, 0.51241561, 0.2, 0.2, 0.2}); + auto expHBWfinal = NDArrayFactory::create( + 'c', {bS, numUnitsBW}, + {0.86708881, 0.86708881, 0.86708881, 0.78347842, 0.78347842, 0.78347842, + 0.55529176, 0.55529176, 0.55529176, 0.25, 0.25, 0.25}); + + sd::ops::static_bidirectional_rnn op; + auto results = op.evaluate( + {&x, &WxFW, &WhFW, &bFW, &WxFW, &WhFW, &bFW, &h0FW, &h0BW, &maxTimeStep}, + {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto h = results.at(0); + auto hFWfinal = results.at(1); + auto hBWfinal = results.at(2); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + ASSERT_TRUE(expHFWfinal.isSameShape(hFWfinal)); + ASSERT_TRUE(expHFWfinal.equalsTo(hFWfinal)); + ASSERT_TRUE(expHBWfinal.isSameShape(hBWfinal)); + ASSERT_TRUE(expHBWfinal.equalsTo(hBWfinal)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, static_bidir_rnn_test2) { - - const int bS = 4; - const int inSize = 4; - const int numUnitsFW = 3; - const int numUnitsBW = 3; - const int time = 5; - - auto x = NDArrayFactory::create('c', {time, bS, inSize}); - auto WxFW = NDArrayFactory::create('c', {inSize, numUnitsFW}); - auto WhFW = NDArrayFactory::create('c', {numUnitsFW, numUnitsFW}); - auto bFW = NDArrayFactory::create('c', {2*numUnitsFW}); - - auto maxTimeStep = NDArrayFactory::create('c', {bS}, {time-1, time-3, time-4, 0}); - - x.linspace(0.01, 0.01); - WxFW = 0.3; - WhFW = 0.4; - bFW = 0.1; - - auto expH = NDArrayFactory::create('c', {time, bS, numUnitsFW+numUnitsBW}, {0.22602835, 0.22602835, 0.22602835, 0.86518273, 0.86518273,0.86518273,0.27105303, 0.27105303, 0.27105303, 0.66617761, 0.66617761,0.66617761, - 0.31492203, 0.31492203, 0.31492203, 0.31492203, 0.31492203,0.31492203,0. , 0. , 0. , 0. , 0. ,0. , - 0.60005558, 0.60005558, 0.60005558, 0.9029975 , 0.9029975 ,0.9029975 ,0.66138054, 0.66138054, 0.66138054, 0.43819931, 0.43819931,0.43819931, - 0. , 0. , 0. , 0. , 0. ,0. ,0. , 0. , 0. , 0. , 0. ,0. , - 0.87023975, 0.87023975, 0.87023975, 0.88852032, 0.88852032,0.88852032,0. , 0. , 0. , 0. , 0. ,0. , - 0. , 0. , 0. , 0. , 0. ,0. ,0. , 0. , 0. , 0. , 0. ,0. , - 0.95177305, 0.95177305, 0.95177305, 0.66737775, 0.66737775,0.66737775,0. , 0. , 0. , 0. , 0. ,0. , - 0. , 0. , 0. , 0. , 0. ,0. ,0. , 0. , 0. , 0. , 0. ,0. , - 0., 0., 0., 0., 0., 0.,0., 0., 0., 0., 0., 0.,0., 0., 0., 0., 0., 0.,0., 0., 0., 0., 0., 0.}); - - auto expHFWfinal = NDArrayFactory::create('c', {bS, numUnitsFW}, {0.95177305, 0.95177305, 0.95177305, 0.66138054, 0.66138054, 0.66138054, 0.31492203, 0.31492203, 0.31492203, 0. , 0. , 0.}); - auto expHBWfinal = NDArrayFactory::create('c', {bS, numUnitsBW}, {0.86518273, 0.86518273, 0.86518273, 0.66617761, 0.66617761, 0.66617761, 0.31492203, 0.31492203, 0.31492203, 0. , 0. , 0.}); - - sd::ops::static_bidirectional_rnn op; - auto results = op.evaluate({&x, &WxFW,&WhFW,&bFW, &WxFW,&WhFW,&bFW, &maxTimeStep}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto h = results.at(0); - auto hFWfinal = results.at(1); - auto hBWfinal = results.at(2); - - ASSERT_TRUE(expH.isSameShape(h)); - ASSERT_TRUE(expH.equalsTo(h)); - ASSERT_TRUE(expHFWfinal.isSameShape(hFWfinal)); - ASSERT_TRUE(expHFWfinal.equalsTo(hFWfinal)); - ASSERT_TRUE(expHBWfinal.isSameShape(hBWfinal)); - ASSERT_TRUE(expHBWfinal.equalsTo(hBWfinal)); - - + const int bS = 4; + const int inSize = 4; + const int numUnitsFW = 3; + const int numUnitsBW = 3; + const int time = 5; + + auto x = NDArrayFactory::create('c', {time, bS, inSize}); + auto WxFW = NDArrayFactory::create('c', {inSize, numUnitsFW}); + auto WhFW = NDArrayFactory::create('c', {numUnitsFW, numUnitsFW}); + auto bFW = NDArrayFactory::create('c', {2 * numUnitsFW}); + + auto maxTimeStep = NDArrayFactory::create( + 'c', {bS}, {time - 1, time - 3, time - 4, 0}); + + x.linspace(0.01, 0.01); + WxFW = 0.3; + WhFW = 0.4; + bFW = 0.1; + + auto expH = NDArrayFactory::create( + 'c', {time, bS, numUnitsFW + numUnitsBW}, + {0.22602835, 0.22602835, 0.22602835, 0.86518273, 0.86518273, 0.86518273, + 0.27105303, 0.27105303, 0.27105303, 0.66617761, 0.66617761, 0.66617761, + 0.31492203, 0.31492203, 0.31492203, 0.31492203, 0.31492203, 0.31492203, + 0., 0., 0., 0., 0., 0., + 0.60005558, 0.60005558, 0.60005558, 0.9029975, 0.9029975, 0.9029975, + 0.66138054, 0.66138054, 0.66138054, 0.43819931, 0.43819931, 0.43819931, + 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., + 0.87023975, 0.87023975, 0.87023975, 0.88852032, 0.88852032, 0.88852032, + 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., + 0.95177305, 0.95177305, 0.95177305, 0.66737775, 0.66737775, 0.66737775, + 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0.}); + + auto expHFWfinal = NDArrayFactory::create( + 'c', {bS, numUnitsFW}, + {0.95177305, 0.95177305, 0.95177305, 0.66138054, 0.66138054, 0.66138054, + 0.31492203, 0.31492203, 0.31492203, 0., 0., 0.}); + auto expHBWfinal = NDArrayFactory::create( + 'c', {bS, numUnitsBW}, + {0.86518273, 0.86518273, 0.86518273, 0.66617761, 0.66617761, 0.66617761, + 0.31492203, 0.31492203, 0.31492203, 0., 0., 0.}); + + sd::ops::static_bidirectional_rnn op; + auto results = op.evaluate( + {&x, &WxFW, &WhFW, &bFW, &WxFW, &WhFW, &bFW, &maxTimeStep}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto h = results.at(0); + auto hFWfinal = results.at(1); + auto hBWfinal = results.at(2); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + ASSERT_TRUE(expHFWfinal.isSameShape(hFWfinal)); + ASSERT_TRUE(expHFWfinal.equalsTo(hFWfinal)); + ASSERT_TRUE(expHBWfinal.isSameShape(hBWfinal)); + ASSERT_TRUE(expHBWfinal.equalsTo(hBWfinal)); } - /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, static_bidir_rnn_test3) { - - const int bS = 4; - const int inSize = 4; - const int numUnitsFW = 3; - const int numUnitsBW = 3; - const int time = 5; - - auto x = NDArrayFactory::create('c', {time, bS, inSize}); - auto WxFW = NDArrayFactory::create('c', {inSize, numUnitsFW}); - auto WhFW = NDArrayFactory::create('c', {numUnitsFW, numUnitsFW}); - auto bFW = NDArrayFactory::create('c', {2*numUnitsFW}); - - x.linspace(0.01, 0.01); - WxFW = 0.3; - WhFW = 0.4; - bFW = 0.1; - - auto expH = NDArrayFactory::create('c', {time, bS, numUnitsFW+numUnitsBW}, {0.22602835, 0.22602835, 0.22602835, 0.86841012, 0.86841012,0.86841012,0.27105303, 0.27105303, 0.27105303, 0.88207531, 0.88207531,0.88207531, - 0.31492203, 0.31492203, 0.31492203, 0.8941667 , 0.8941667 ,0.8941667 ,0.35748551, 0.35748551, 0.35748551, 0.90489713, 0.90489713, - 0.90489713, 0.60005558, 0.60005558, 0.60005558, 0.91381375, 0.91381375,0.91381375,0.66138054, 0.66138054, 0.66138054, 0.92253504, 0.92253504, - 0.92253504,0.71429879, 0.71429879, 0.71429879, 0.93027876, 0.93027876,0.93027876,0.75947891, 0.75947891, 0.75947891, 0.9371767 , 0.9371767 , - 0.9371767 , 0.87023975, 0.87023975, 0.87023975, 0.94014274, 0.94014274,0.94014274,0.89680574, 0.89680574, 0.89680574, 0.94648926, 0.94648926, - 0.94648926,0.91657261, 0.91657261, 0.91657261, 0.95204779, 0.95204779,0.95204779,0.93146896, 0.93146896, 0.93146896, 0.95694206, 0.95694206, - 0.95694206, 0.95177305, 0.95177305, 0.95177305, 0.93773086, 0.93773086,0.93773086,0.95874689, 0.95874689, 0.95874689, 0.94579176, 0.94579176, - 0.94579176,0.96416067, 0.96416067, 0.96416067, 0.95267886, 0.95267886,0.95267886,0.96851506, 0.96851506, 0.96851506, 0.95857985, 0.95857985, - 0.95857985, 0.97269956, 0.97269956, 0.97269956, 0.76075293, 0.76075293,0.76075293,0.97557464, 0.97557464, 0.97557464, 0.78024637, 0.78024637, - 0.78024637,0.97806922, 0.97806922, 0.97806922, 0.79833344, 0.79833344,0.79833344,0.98026195, 0.98026195, 0.98026195, 0.81508646, 0.81508646,0.81508646}); - - auto expHFWfinal = NDArrayFactory::create('c', {bS, numUnitsFW}, {0.97269956, 0.97269956, 0.97269956, 0.97557464, 0.97557464, 0.97557464, 0.97806922, 0.97806922, 0.97806922, 0.98026195, 0.98026195, 0.98026195}); - auto expHBWfinal = NDArrayFactory::create('c', {bS, numUnitsBW}, {0.86841012, 0.86841012, 0.86841012, 0.88207531, 0.88207531, 0.88207531, 0.8941667 , 0.8941667 , 0.8941667 , 0.90489713, 0.90489713, 0.90489713}); - - sd::ops::static_bidirectional_rnn op; - auto results = op.evaluate({&x, &WxFW,&WhFW,&bFW, &WxFW,&WhFW,&bFW}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto h = results.at(0); - auto hFWfinal = results.at(1); - auto hBWfinal = results.at(2); - - ASSERT_TRUE(expH.isSameShape(h)); - ASSERT_TRUE(expH.equalsTo(h)); - ASSERT_TRUE(expHFWfinal.isSameShape(hFWfinal)); - ASSERT_TRUE(expHFWfinal.equalsTo(hFWfinal)); - ASSERT_TRUE(expHBWfinal.isSameShape(hBWfinal)); - ASSERT_TRUE(expHBWfinal.equalsTo(hBWfinal)); - - + const int bS = 4; + const int inSize = 4; + const int numUnitsFW = 3; + const int numUnitsBW = 3; + const int time = 5; + + auto x = NDArrayFactory::create('c', {time, bS, inSize}); + auto WxFW = NDArrayFactory::create('c', {inSize, numUnitsFW}); + auto WhFW = NDArrayFactory::create('c', {numUnitsFW, numUnitsFW}); + auto bFW = NDArrayFactory::create('c', {2 * numUnitsFW}); + + x.linspace(0.01, 0.01); + WxFW = 0.3; + WhFW = 0.4; + bFW = 0.1; + + auto expH = NDArrayFactory::create( + 'c', {time, bS, numUnitsFW + numUnitsBW}, + {0.22602835, 0.22602835, 0.22602835, 0.86841012, 0.86841012, 0.86841012, + 0.27105303, 0.27105303, 0.27105303, 0.88207531, 0.88207531, 0.88207531, + 0.31492203, 0.31492203, 0.31492203, 0.8941667, 0.8941667, 0.8941667, + 0.35748551, 0.35748551, 0.35748551, 0.90489713, 0.90489713, 0.90489713, + 0.60005558, 0.60005558, 0.60005558, 0.91381375, 0.91381375, 0.91381375, + 0.66138054, 0.66138054, 0.66138054, 0.92253504, 0.92253504, 0.92253504, + 0.71429879, 0.71429879, 0.71429879, 0.93027876, 0.93027876, 0.93027876, + 0.75947891, 0.75947891, 0.75947891, 0.9371767, 0.9371767, 0.9371767, + 0.87023975, 0.87023975, 0.87023975, 0.94014274, 0.94014274, 0.94014274, + 0.89680574, 0.89680574, 0.89680574, 0.94648926, 0.94648926, 0.94648926, + 0.91657261, 0.91657261, 0.91657261, 0.95204779, 0.95204779, 0.95204779, + 0.93146896, 0.93146896, 0.93146896, 0.95694206, 0.95694206, 0.95694206, + 0.95177305, 0.95177305, 0.95177305, 0.93773086, 0.93773086, 0.93773086, + 0.95874689, 0.95874689, 0.95874689, 0.94579176, 0.94579176, 0.94579176, + 0.96416067, 0.96416067, 0.96416067, 0.95267886, 0.95267886, 0.95267886, + 0.96851506, 0.96851506, 0.96851506, 0.95857985, 0.95857985, 0.95857985, + 0.97269956, 0.97269956, 0.97269956, 0.76075293, 0.76075293, 0.76075293, + 0.97557464, 0.97557464, 0.97557464, 0.78024637, 0.78024637, 0.78024637, + 0.97806922, 0.97806922, 0.97806922, 0.79833344, 0.79833344, 0.79833344, + 0.98026195, 0.98026195, 0.98026195, 0.81508646, 0.81508646, 0.81508646}); + + auto expHFWfinal = NDArrayFactory::create( + 'c', {bS, numUnitsFW}, + {0.97269956, 0.97269956, 0.97269956, 0.97557464, 0.97557464, 0.97557464, + 0.97806922, 0.97806922, 0.97806922, 0.98026195, 0.98026195, 0.98026195}); + auto expHBWfinal = NDArrayFactory::create( + 'c', {bS, numUnitsBW}, + {0.86841012, 0.86841012, 0.86841012, 0.88207531, 0.88207531, 0.88207531, + 0.8941667, 0.8941667, 0.8941667, 0.90489713, 0.90489713, 0.90489713}); + + sd::ops::static_bidirectional_rnn op; + auto results = + op.evaluate({&x, &WxFW, &WhFW, &bFW, &WxFW, &WhFW, &bFW}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto h = results.at(0); + auto hFWfinal = results.at(1); + auto hBWfinal = results.at(2); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + ASSERT_TRUE(expHFWfinal.isSameShape(hFWfinal)); + ASSERT_TRUE(expHFWfinal.equalsTo(hFWfinal)); + ASSERT_TRUE(expHBWfinal.isSameShape(hBWfinal)); + ASSERT_TRUE(expHBWfinal.equalsTo(hBWfinal)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, dynamic_rnn_test1) { - - const int bS = 2; - const int inSize = 3; - const int numUnits = 4; - const int time = 5; - - auto x = NDArrayFactory::create('c', {time, bS, inSize}); - auto Wx = NDArrayFactory::create('c', {inSize, numUnits}); - auto Wh = NDArrayFactory::create('c', {numUnits, numUnits}); - auto b = NDArrayFactory::create('c', {2*numUnits}); - auto h0 = NDArrayFactory::create('c', {bS, numUnits}); - auto maxTimeStep = NDArrayFactory::create('c', {bS}, {time-1, time-3}); - - x.linspace(0.01, 0.01); - h0 = 0.2; - Wx = 0.3; - Wh = 0.4; - b = 0.25; - - auto expH = NDArrayFactory::create('c', {time, bS, numUnits}, {0.68474828, 0.68474828, 0.68474828, 0.68474828,0.69882484, 0.69882484, 0.69882484, 0.69882484,0.9312333 , 0.9312333 , 0.9312333 , 0.9312333 , - 0.93751527, 0.93751527, 0.93751527, 0.93751527,0.97136768, 0.97136768, 0.97136768, 0.97136768,0. , 0. , 0. , 0. , - 0.97732812, 0.97732812, 0.97732812, 0.97732812,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. }); - - auto expHFinal = NDArrayFactory::create('c', {bS, numUnits}, {0.97732812, 0.97732812, 0.97732812, 0.97732812, 0.93751527, 0.93751527, 0.93751527, 0.93751527}); - - sd::ops::dynamic_rnn op; - auto results = op.evaluate({&x, &Wx, &Wh, &b, &h0, &maxTimeStep}, {}, {1}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto h = results.at(0); - auto hFinal = results.at(1); - - ASSERT_TRUE(expH.isSameShape(h)); - ASSERT_TRUE(expH.equalsTo(h)); - ASSERT_TRUE(expHFinal.isSameShape(hFinal)); - ASSERT_TRUE(expHFinal.equalsTo(hFinal)); - - + const int bS = 2; + const int inSize = 3; + const int numUnits = 4; + const int time = 5; + + auto x = NDArrayFactory::create('c', {time, bS, inSize}); + auto Wx = NDArrayFactory::create('c', {inSize, numUnits}); + auto Wh = NDArrayFactory::create('c', {numUnits, numUnits}); + auto b = NDArrayFactory::create('c', {2 * numUnits}); + auto h0 = NDArrayFactory::create('c', {bS, numUnits}); + auto maxTimeStep = + NDArrayFactory::create('c', {bS}, {time - 1, time - 3}); + + x.linspace(0.01, 0.01); + h0 = 0.2; + Wx = 0.3; + Wh = 0.4; + b = 0.25; + + auto expH = NDArrayFactory::create( + 'c', {time, bS, numUnits}, + {0.68474828, 0.68474828, 0.68474828, 0.68474828, 0.69882484, 0.69882484, + 0.69882484, 0.69882484, 0.9312333, 0.9312333, 0.9312333, 0.9312333, + 0.93751527, 0.93751527, 0.93751527, 0.93751527, 0.97136768, 0.97136768, + 0.97136768, 0.97136768, 0., 0., 0., 0., + 0.97732812, 0.97732812, 0.97732812, 0.97732812, 0., 0., + 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0.}); + + auto expHFinal = NDArrayFactory::create( + 'c', {bS, numUnits}, + {0.97732812, 0.97732812, 0.97732812, 0.97732812, 0.93751527, 0.93751527, + 0.93751527, 0.93751527}); + + sd::ops::dynamic_rnn op; + auto results = op.evaluate({&x, &Wx, &Wh, &b, &h0, &maxTimeStep}, {}, {1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto h = results.at(0); + auto hFinal = results.at(1); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + ASSERT_TRUE(expHFinal.isSameShape(hFinal)); + ASSERT_TRUE(expHFinal.equalsTo(hFinal)); } - /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, dynamic_rnn_test2) { - - const int bS = 2; - const int inSize = 3; - const int numUnits = 4; - const int time = 5; - - auto x = NDArrayFactory::create('c', {bS, time, inSize}); - auto Wx = NDArrayFactory::create('c', {inSize, numUnits}); - auto Wh = NDArrayFactory::create('c', {numUnits, numUnits}); - auto b = NDArrayFactory::create('c', {2*numUnits}); - auto h0 = NDArrayFactory::create('c', {bS, numUnits}); - auto maxTimeStep = NDArrayFactory::create('c', {bS}, {time-1, time}); - - x.linspace(0.01, 0.01); - h0 = 0.2; - Wx = 0.3; - Wh = 0.4; - b = 0.25; - - auto expH = NDArrayFactory::create('c', {bS, time, numUnits}, {0.68474828, 0.68474828, 0.68474828, 0.68474828,0.92755601, 0.92755601, 0.92755601, 0.92755601,0.96778334, 0.96778334, 0.96778334, - 0.96778334,0.97309129, 0.97309129, 0.97309129, 0.97309129,0. , 0. , 0. , 0. , - 0.75001965, 0.75001965, 0.75001965, 0.75001965,0.95449491, 0.95449491, 0.95449491, 0.95449491,0.97732828, 0.97732828, 0.97732828, - 0.97732828,0.98000655, 0.98000655, 0.98000655, 0.98000655,0.98120782, 0.98120782, 0.98120782, 0.98120782}); - - auto expHFinal = NDArrayFactory::create('c', {bS, numUnits}, {0.97309129, 0.97309129, 0.97309129, 0.97309129, 0.98120782, 0.98120782, 0.98120782, 0.98120782}); - - sd::ops::dynamic_rnn op; - auto results = op.evaluate({&x, &Wx, &Wh, &b, &h0, &maxTimeStep}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto h = results.at(0); - auto hFinal = results.at(1); - - ASSERT_TRUE(expH.isSameShape(h)); - ASSERT_TRUE(expH.equalsTo(h)); - ASSERT_TRUE(expHFinal.isSameShape(hFinal)); - ASSERT_TRUE(expHFinal.equalsTo(hFinal)); - - + const int bS = 2; + const int inSize = 3; + const int numUnits = 4; + const int time = 5; + + auto x = NDArrayFactory::create('c', {bS, time, inSize}); + auto Wx = NDArrayFactory::create('c', {inSize, numUnits}); + auto Wh = NDArrayFactory::create('c', {numUnits, numUnits}); + auto b = NDArrayFactory::create('c', {2 * numUnits}); + auto h0 = NDArrayFactory::create('c', {bS, numUnits}); + auto maxTimeStep = NDArrayFactory::create('c', {bS}, {time - 1, time}); + + x.linspace(0.01, 0.01); + h0 = 0.2; + Wx = 0.3; + Wh = 0.4; + b = 0.25; + + auto expH = NDArrayFactory::create( + 'c', {bS, time, numUnits}, + {0.68474828, 0.68474828, 0.68474828, 0.68474828, 0.92755601, 0.92755601, + 0.92755601, 0.92755601, 0.96778334, 0.96778334, 0.96778334, 0.96778334, + 0.97309129, 0.97309129, 0.97309129, 0.97309129, 0., 0., + 0., 0., 0.75001965, 0.75001965, 0.75001965, 0.75001965, + 0.95449491, 0.95449491, 0.95449491, 0.95449491, 0.97732828, 0.97732828, + 0.97732828, 0.97732828, 0.98000655, 0.98000655, 0.98000655, 0.98000655, + 0.98120782, 0.98120782, 0.98120782, 0.98120782}); + + auto expHFinal = NDArrayFactory::create( + 'c', {bS, numUnits}, + {0.97309129, 0.97309129, 0.97309129, 0.97309129, 0.98120782, 0.98120782, + 0.98120782, 0.98120782}); + + sd::ops::dynamic_rnn op; + auto results = op.evaluate({&x, &Wx, &Wh, &b, &h0, &maxTimeStep}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto h = results.at(0); + auto hFinal = results.at(1); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + ASSERT_TRUE(expHFinal.isSameShape(hFinal)); + ASSERT_TRUE(expHFinal.equalsTo(hFinal)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, dynamic_rnn_test3) { - - const int bS = 2; - const int inSize = 3; - const int numUnits = 4; - const int time = 5; - - auto x = NDArrayFactory::create('c', {bS, time, inSize}); - auto Wx = NDArrayFactory::create('c', {inSize, numUnits}); - auto Wh = NDArrayFactory::create('c', {numUnits, numUnits}); - auto b = NDArrayFactory::create('c', {2*numUnits}); - auto h0 = NDArrayFactory::create('c', {bS, numUnits}); - - x.linspace(0.01, 0.01); - h0 = 0.2; - Wx = 0.3; - Wh = 0.4; - b = 0.25; - - auto expH = NDArrayFactory::create('c', {bS, time, numUnits}, {0.68474828, 0.68474828, 0.68474828, 0.68474828,0.92755601, 0.92755601, 0.92755601, 0.92755601,0.96778334, 0.96778334, 0.96778334, 0.96778334,0.97309129, - 0.97309129, 0.97309129, 0.97309129,0.97491207, 0.97491207, 0.97491207, 0.97491207,0.75001965, 0.75001965, 0.75001965, 0.75001965,0.95449491, 0.95449491, - 0.95449491, 0.95449491,0.97732828, 0.97732828, 0.97732828, 0.97732828,0.98000655, 0.98000655, 0.98000655, 0.98000655,0.98120782, 0.98120782, 0.98120782, 0.98120782}); - - auto expHFinal = NDArrayFactory::create('c', {bS, numUnits}, {0.97491207, 0.97491207, 0.97491207, 0.97491207, 0.98120782, 0.98120782, 0.98120782, 0.98120782}); - - sd::ops::dynamic_rnn op; - auto results = op.evaluate({&x, &Wx, &Wh, &b, &h0}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto h = results.at(0); - auto hFinal = results.at(1); - - ASSERT_TRUE(expH.isSameShape(h)); - ASSERT_TRUE(expH.equalsTo(h)); - ASSERT_TRUE(expHFinal.isSameShape(hFinal)); - ASSERT_TRUE(expHFinal.equalsTo(hFinal)); - - + const int bS = 2; + const int inSize = 3; + const int numUnits = 4; + const int time = 5; + + auto x = NDArrayFactory::create('c', {bS, time, inSize}); + auto Wx = NDArrayFactory::create('c', {inSize, numUnits}); + auto Wh = NDArrayFactory::create('c', {numUnits, numUnits}); + auto b = NDArrayFactory::create('c', {2 * numUnits}); + auto h0 = NDArrayFactory::create('c', {bS, numUnits}); + + x.linspace(0.01, 0.01); + h0 = 0.2; + Wx = 0.3; + Wh = 0.4; + b = 0.25; + + auto expH = NDArrayFactory::create( + 'c', {bS, time, numUnits}, + {0.68474828, 0.68474828, 0.68474828, 0.68474828, 0.92755601, 0.92755601, + 0.92755601, 0.92755601, 0.96778334, 0.96778334, 0.96778334, 0.96778334, + 0.97309129, 0.97309129, 0.97309129, 0.97309129, 0.97491207, 0.97491207, + 0.97491207, 0.97491207, 0.75001965, 0.75001965, 0.75001965, 0.75001965, + 0.95449491, 0.95449491, 0.95449491, 0.95449491, 0.97732828, 0.97732828, + 0.97732828, 0.97732828, 0.98000655, 0.98000655, 0.98000655, 0.98000655, + 0.98120782, 0.98120782, 0.98120782, 0.98120782}); + + auto expHFinal = NDArrayFactory::create( + 'c', {bS, numUnits}, + {0.97491207, 0.97491207, 0.97491207, 0.97491207, 0.98120782, 0.98120782, + 0.98120782, 0.98120782}); + + sd::ops::dynamic_rnn op; + auto results = op.evaluate({&x, &Wx, &Wh, &b, &h0}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto h = results.at(0); + auto hFinal = results.at(1); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + ASSERT_TRUE(expHFinal.isSameShape(hFinal)); + ASSERT_TRUE(expHFinal.equalsTo(hFinal)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, dynamic_rnn_test4) { - - const int bS = 2; - const int inSize = 3; - const int numUnits = 4; - const int time = 5; - - auto x = NDArrayFactory::create('c', {bS, time, inSize}); - auto Wx = NDArrayFactory::create('c', {inSize, numUnits}); - auto Wh = NDArrayFactory::create('c', {numUnits, numUnits}); - auto b = NDArrayFactory::create('c', {2*numUnits}); - auto maxTimeStep = NDArrayFactory::create('c', {bS}, {time-1, time-4}); - - x.linspace(0.01, 0.01); - Wx = 0.3; - Wh = 0.4; - b = 0.25; - - auto expH = NDArrayFactory::create('c', {bS, time, numUnits}, {0.47615493, 0.47615493, 0.47615493, 0.47615493,0.86347567, 0.86347567, 0.86347567, 0.86347567,0.96059545, 0.96059545, - 0.96059545, 0.96059545,0.9724738 , 0.9724738 , 0.9724738 , 0.9724738 ,0. , 0. , 0. , 0. , - 0.57368608, 0.57368608, 0.57368608, 0.57368608,0. , 0. , 0 , 0. ,0., 0. , 0, 0.,0., 0., 0. , 0. ,0. , 0. , 0., 0. }); - - auto expHFinal = NDArrayFactory::create('c', {bS, numUnits}, {0.9724738 , 0.9724738 , 0.9724738 , 0.9724738 ,0.57368608, 0.57368608, 0.57368608, 0.57368608}); - - sd::ops::dynamic_rnn op; - auto results = op.evaluate({&x, &Wx, &Wh, &b, &maxTimeStep}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto h = results.at(0); - auto hFinal = results.at(1); - - ASSERT_TRUE(expH.isSameShape(h)); - ASSERT_TRUE(expH.equalsTo(h)); - ASSERT_TRUE(expHFinal.isSameShape(hFinal)); - ASSERT_TRUE(expHFinal.equalsTo(hFinal)); - - + const int bS = 2; + const int inSize = 3; + const int numUnits = 4; + const int time = 5; + + auto x = NDArrayFactory::create('c', {bS, time, inSize}); + auto Wx = NDArrayFactory::create('c', {inSize, numUnits}); + auto Wh = NDArrayFactory::create('c', {numUnits, numUnits}); + auto b = NDArrayFactory::create('c', {2 * numUnits}); + auto maxTimeStep = + NDArrayFactory::create('c', {bS}, {time - 1, time - 4}); + + x.linspace(0.01, 0.01); + Wx = 0.3; + Wh = 0.4; + b = 0.25; + + auto expH = NDArrayFactory::create( + 'c', {bS, time, numUnits}, + {0.47615493, 0.47615493, 0.47615493, 0.47615493, 0.86347567, 0.86347567, + 0.86347567, 0.86347567, 0.96059545, 0.96059545, 0.96059545, 0.96059545, + 0.9724738, 0.9724738, 0.9724738, 0.9724738, 0., 0., + 0., 0., 0.57368608, 0.57368608, 0.57368608, 0.57368608, + 0., 0., 0, 0., 0., 0., + 0, 0., 0., 0., 0., 0., + 0., 0., 0., 0.}); + + auto expHFinal = NDArrayFactory::create( + 'c', {bS, numUnits}, + {0.9724738, 0.9724738, 0.9724738, 0.9724738, 0.57368608, 0.57368608, + 0.57368608, 0.57368608}); + + sd::ops::dynamic_rnn op; + auto results = op.evaluate({&x, &Wx, &Wh, &b, &maxTimeStep}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto h = results.at(0); + auto hFinal = results.at(1); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + ASSERT_TRUE(expHFinal.isSameShape(hFinal)); + ASSERT_TRUE(expHFinal.equalsTo(hFinal)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, dynamic_rnn_test5) { - - const int bS = 2; - const int inSize = 3; - const int numUnits = 4; - const int time = 5; - - auto x = NDArrayFactory::create('c', {bS, time, inSize}); - auto Wx = NDArrayFactory::create('c', {inSize, numUnits}); - auto Wh = NDArrayFactory::create('c', {numUnits, numUnits}); - auto b = NDArrayFactory::create('c', {2*numUnits}); - - x.linspace(0.01, 0.01); - Wx = 0.3; - Wh = 0.4; - b = 0.25; - - auto expH = NDArrayFactory::create('c', {bS, time, numUnits}, {0.47615493, 0.47615493, 0.47615493, 0.47615493,0.86347567, 0.86347567, 0.86347567, 0.86347567,0.96059545, 0.96059545, 0.96059545, 0.96059545, - 0.9724738 , 0.9724738 , 0.9724738 , 0.9724738 ,0.97486307, 0.97486307, 0.97486307, 0.97486307,0.57368608, 0.57368608, 0.57368608, 0.57368608, - 0.92135149, 0.92135149, 0.92135149, 0.92135149,0.97482354, 0.97482354, 0.97482354, 0.97482354,0.97984727, 0.97984727, 0.97984727, 0.97984727, - 0.98119833, 0.98119833, 0.98119833, 0.98119833}); - - auto expHFinal = NDArrayFactory::create('c', {bS, numUnits}, {0.97486307, 0.97486307, 0.97486307, 0.97486307,0.98119833, 0.98119833, 0.98119833, 0.98119833}); - - sd::ops::dynamic_rnn op; - auto results = op.evaluate({&x, &Wx, &Wh, &b}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto h = results.at(0); - auto hFinal = results.at(1); - - ASSERT_TRUE(expH.isSameShape(h)); - ASSERT_TRUE(expH.equalsTo(h)); - ASSERT_TRUE(expHFinal.isSameShape(hFinal)); - ASSERT_TRUE(expHFinal.equalsTo(hFinal)); - - + const int bS = 2; + const int inSize = 3; + const int numUnits = 4; + const int time = 5; + + auto x = NDArrayFactory::create('c', {bS, time, inSize}); + auto Wx = NDArrayFactory::create('c', {inSize, numUnits}); + auto Wh = NDArrayFactory::create('c', {numUnits, numUnits}); + auto b = NDArrayFactory::create('c', {2 * numUnits}); + + x.linspace(0.01, 0.01); + Wx = 0.3; + Wh = 0.4; + b = 0.25; + + auto expH = NDArrayFactory::create( + 'c', {bS, time, numUnits}, + {0.47615493, 0.47615493, 0.47615493, 0.47615493, 0.86347567, 0.86347567, + 0.86347567, 0.86347567, 0.96059545, 0.96059545, 0.96059545, 0.96059545, + 0.9724738, 0.9724738, 0.9724738, 0.9724738, 0.97486307, 0.97486307, + 0.97486307, 0.97486307, 0.57368608, 0.57368608, 0.57368608, 0.57368608, + 0.92135149, 0.92135149, 0.92135149, 0.92135149, 0.97482354, 0.97482354, + 0.97482354, 0.97482354, 0.97984727, 0.97984727, 0.97984727, 0.97984727, + 0.98119833, 0.98119833, 0.98119833, 0.98119833}); + + auto expHFinal = NDArrayFactory::create( + 'c', {bS, numUnits}, + {0.97486307, 0.97486307, 0.97486307, 0.97486307, 0.98119833, 0.98119833, + 0.98119833, 0.98119833}); + + sd::ops::dynamic_rnn op; + auto results = op.evaluate({&x, &Wx, &Wh, &b}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto h = results.at(0); + auto hFinal = results.at(1); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + ASSERT_TRUE(expHFinal.isSameShape(hFinal)); + ASSERT_TRUE(expHFinal.equalsTo(hFinal)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, dynamic_bidir_rnn_test1) { - - const int bS = 4; - const int inSize = 4; - const int numUnitsFW = 3; - const int numUnitsBW = 3; - const int time = 5; - - auto x = NDArrayFactory::create('c', {time, bS, inSize}); - auto WxFW = NDArrayFactory::create('c', {inSize, numUnitsFW}); - auto WhFW = NDArrayFactory::create('c', {numUnitsFW, numUnitsFW}); - auto bFW = NDArrayFactory::create('c', {2*numUnitsFW}); - - auto h0FW = NDArrayFactory::create('c', {bS, numUnitsFW}); - auto h0BW = NDArrayFactory::create('c', {bS, numUnitsBW}); - auto maxTimeStep = NDArrayFactory::create('c', {bS}, {time-1, time-3, time-4, 0}); - - x.linspace(0.01, 0.01); - h0FW = 0.2; - h0BW = 0.25; - WxFW = 0.3; - WhFW = 0.4; - bFW = 0.1; - - auto expHFW = NDArrayFactory::create('c', {time, bS, numUnitsFW}, {0.43819931, 0.43819931, 0.43819931,0.47615493, 0.47615493, 0.47615493,0.51241561, 0.51241561, 0.51241561,0. , 0. , 0. , - 0.73880324, 0.73880324, 0.73880324,0.77843476, 0.77843476, 0.77843476,0. , 0. , 0. ,0. , 0. , 0. , - 0.9052501 , 0.9052501 , 0.9052501 ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. , - 0.9555734 , 0.9555734 , 0.9555734 ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. , - 0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. }); - - auto expHBW = NDArrayFactory::create('c', {time, bS, numUnitsBW}, {0.86708881, 0.86708881, 0.86708881,0.78347842, 0.78347842, 0.78347842,0.55529176, 0.55529176, 0.55529176,0. , 0. , 0. , - 0.90935605, 0.90935605, 0.90935605,0.64692945, 0.64692945, 0.64692945,0. , 0. , 0. ,0. , 0. , 0. , - 0.9181592 , 0.9181592 , 0.9181592 ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. , - 0.8026439 , 0.8026439 , 0.8026439 ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. , - 0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. }); - - auto expHFWfinal = NDArrayFactory::create('c', {bS, numUnitsFW}, {0.9555734 , 0.9555734 , 0.9555734 , 0.77843476, 0.77843476, 0.77843476, 0.51241561, 0.51241561, 0.51241561, 0.2 , 0.2 , 0.2}); - auto expHBWfinal = NDArrayFactory::create('c', {bS, numUnitsBW}, {0.86708881, 0.86708881, 0.86708881, 0.78347842, 0.78347842, 0.78347842, 0.55529176, 0.55529176, 0.55529176, 0.25 , 0.25 , 0.25}); - - sd::ops::dynamic_bidirectional_rnn op; - auto results = op.evaluate({&x, &WxFW,&WhFW,&bFW, &WxFW,&WhFW,&bFW, &h0FW, &h0BW, &maxTimeStep}, {}, {1}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto hFW = results.at(0); - auto hBW = results.at(1); - auto hFWfinal = results.at(2); - auto hBWfinal = results.at(3); - - ASSERT_TRUE(expHFW.isSameShape(hFW)); - ASSERT_TRUE(expHFW.equalsTo(hFW)); - ASSERT_TRUE(expHBW.isSameShape(hBW)); - ASSERT_TRUE(expHBW.equalsTo(hBW)); - ASSERT_TRUE(expHFWfinal.isSameShape(hFWfinal)); - ASSERT_TRUE(expHFWfinal.equalsTo(hFWfinal)); - ASSERT_TRUE(expHBWfinal.isSameShape(hBWfinal)); - ASSERT_TRUE(expHBWfinal.equalsTo(hBWfinal)); - - + const int bS = 4; + const int inSize = 4; + const int numUnitsFW = 3; + const int numUnitsBW = 3; + const int time = 5; + + auto x = NDArrayFactory::create('c', {time, bS, inSize}); + auto WxFW = NDArrayFactory::create('c', {inSize, numUnitsFW}); + auto WhFW = NDArrayFactory::create('c', {numUnitsFW, numUnitsFW}); + auto bFW = NDArrayFactory::create('c', {2 * numUnitsFW}); + + auto h0FW = NDArrayFactory::create('c', {bS, numUnitsFW}); + auto h0BW = NDArrayFactory::create('c', {bS, numUnitsBW}); + auto maxTimeStep = + NDArrayFactory::create('c', {bS}, {time - 1, time - 3, time - 4, 0}); + + x.linspace(0.01, 0.01); + h0FW = 0.2; + h0BW = 0.25; + WxFW = 0.3; + WhFW = 0.4; + bFW = 0.1; + + auto expHFW = NDArrayFactory::create( + 'c', {time, bS, numUnitsFW}, + {0.43819931, 0.43819931, 0.43819931, 0.47615493, 0.47615493, 0.47615493, + 0.51241561, 0.51241561, 0.51241561, 0., 0., 0., + 0.73880324, 0.73880324, 0.73880324, 0.77843476, 0.77843476, 0.77843476, + 0., 0., 0., 0., 0., 0., + 0.9052501, 0.9052501, 0.9052501, 0., 0., 0., + 0., 0., 0., 0., 0., 0., + 0.9555734, 0.9555734, 0.9555734, 0., 0., 0., + 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0.}); + + auto expHBW = NDArrayFactory::create( + 'c', {time, bS, numUnitsBW}, + {0.86708881, 0.86708881, 0.86708881, 0.78347842, 0.78347842, 0.78347842, + 0.55529176, 0.55529176, 0.55529176, 0., 0., 0., + 0.90935605, 0.90935605, 0.90935605, 0.64692945, 0.64692945, 0.64692945, + 0., 0., 0., 0., 0., 0., + 0.9181592, 0.9181592, 0.9181592, 0., 0., 0., + 0., 0., 0., 0., 0., 0., + 0.8026439, 0.8026439, 0.8026439, 0., 0., 0., + 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0.}); + + auto expHFWfinal = NDArrayFactory::create( + 'c', {bS, numUnitsFW}, + {0.9555734, 0.9555734, 0.9555734, 0.77843476, 0.77843476, 0.77843476, + 0.51241561, 0.51241561, 0.51241561, 0.2, 0.2, 0.2}); + auto expHBWfinal = NDArrayFactory::create( + 'c', {bS, numUnitsBW}, + {0.86708881, 0.86708881, 0.86708881, 0.78347842, 0.78347842, 0.78347842, + 0.55529176, 0.55529176, 0.55529176, 0.25, 0.25, 0.25}); + + sd::ops::dynamic_bidirectional_rnn op; + auto results = op.evaluate( + {&x, &WxFW, &WhFW, &bFW, &WxFW, &WhFW, &bFW, &h0FW, &h0BW, &maxTimeStep}, + {}, {1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto hFW = results.at(0); + auto hBW = results.at(1); + auto hFWfinal = results.at(2); + auto hBWfinal = results.at(3); + + ASSERT_TRUE(expHFW.isSameShape(hFW)); + ASSERT_TRUE(expHFW.equalsTo(hFW)); + ASSERT_TRUE(expHBW.isSameShape(hBW)); + ASSERT_TRUE(expHBW.equalsTo(hBW)); + ASSERT_TRUE(expHFWfinal.isSameShape(hFWfinal)); + ASSERT_TRUE(expHFWfinal.equalsTo(hFWfinal)); + ASSERT_TRUE(expHBWfinal.isSameShape(hBWfinal)); + ASSERT_TRUE(expHBWfinal.equalsTo(hBWfinal)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, dynamic_bidir_rnn_test2) { - - const int bS = 4; - const int inSize = 4; - const int numUnitsFW = 3; - const int numUnitsBW = 3; - const int time = 5; - - auto x = NDArrayFactory::create('c', {bS, time, inSize}); - auto WxFW = NDArrayFactory::create('c', {inSize, numUnitsFW}); - auto WhFW = NDArrayFactory::create('c', {numUnitsFW, numUnitsFW}); - auto bFW = NDArrayFactory::create('c', {2*numUnitsFW}); - - auto h0FW = NDArrayFactory::create('c', {bS, numUnitsFW}); - auto h0BW = NDArrayFactory::create('c', {bS, numUnitsBW}); - auto maxTimeStep = NDArrayFactory::create('c', {bS}, {time-1, time-3, time-4, 0}); - - x.linspace(0.01, 0.01); - h0FW = 0.2; - h0BW = 0.25; - WxFW = 0.3; - WhFW = 0.4; - bFW = 0.1; - - auto expHFW = NDArrayFactory::create('c', {bS, time, numUnitsFW}, {0.43819931, 0.43819931, 0.43819931,0.66617761, 0.66617761, 0.66617761,0.80944357, 0.80944357, 0.80944357,0.87294706, 0.87294706, 0.87294706,0. , 0. , 0. , - 0.61067683, 0.61067683, 0.61067683,0.84851124, 0.84851124, 0.84851124,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. , - 0.73978305, 0.73978305, 0.73978305,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. , - 0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. }); - - auto expHBW = NDArrayFactory::create('c', {bS, time, numUnitsBW}, {0.84345207, 0.84345207, 0.84345207,0.83584708, 0.83584708, 0.83584708,0.77435951, 0.77435951, 0.77435951,0.58760492, 0.58760492, 0.58760492,0. , 0. , 0. , - 0.85615841, 0.85615841, 0.85615841,0.67397984, 0.67397984, 0.67397984,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. , - 0.76576202, 0.76576202, 0.76576202,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. , - 0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. }); - - auto expHFWfinal = NDArrayFactory::create('c', {bS, numUnitsFW}, {0.87294706, 0.87294706, 0.87294706,0.84851124, 0.84851124, 0.84851124,0.73978305, 0.73978305, 0.73978305,0.2 , 0.2 , 0.2}); - auto expHBWfinal = NDArrayFactory::create('c', {bS, numUnitsBW}, {0.84345207, 0.84345207, 0.84345207, 0.85615841, 0.85615841, 0.85615841, 0.76576202, 0.76576202, 0.76576202, 0.25 , 0.25 , 0.25}); - - sd::ops::dynamic_bidirectional_rnn op; - auto results = op.evaluate({&x, &WxFW,&WhFW,&bFW, &WxFW,&WhFW,&bFW, &h0FW, &h0BW, &maxTimeStep}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto hFW = results.at(0); - auto hBW = results.at(1); - auto hFWfinal = results.at(2); - auto hBWfinal = results.at(3); - - ASSERT_TRUE(expHFW.isSameShape(hFW)); - ASSERT_TRUE(expHFW.equalsTo(hFW)); - ASSERT_TRUE(expHBW.isSameShape(hBW)); - ASSERT_TRUE(expHBW.equalsTo(hBW)); - ASSERT_TRUE(expHFWfinal.isSameShape(hFWfinal)); - ASSERT_TRUE(expHFWfinal.equalsTo(hFWfinal)); - ASSERT_TRUE(expHBWfinal.isSameShape(hBWfinal)); - ASSERT_TRUE(expHBWfinal.equalsTo(hBWfinal)); - - + const int bS = 4; + const int inSize = 4; + const int numUnitsFW = 3; + const int numUnitsBW = 3; + const int time = 5; + + auto x = NDArrayFactory::create('c', {bS, time, inSize}); + auto WxFW = NDArrayFactory::create('c', {inSize, numUnitsFW}); + auto WhFW = NDArrayFactory::create('c', {numUnitsFW, numUnitsFW}); + auto bFW = NDArrayFactory::create('c', {2 * numUnitsFW}); + + auto h0FW = NDArrayFactory::create('c', {bS, numUnitsFW}); + auto h0BW = NDArrayFactory::create('c', {bS, numUnitsBW}); + auto maxTimeStep = + NDArrayFactory::create('c', {bS}, {time - 1, time - 3, time - 4, 0}); + + x.linspace(0.01, 0.01); + h0FW = 0.2; + h0BW = 0.25; + WxFW = 0.3; + WhFW = 0.4; + bFW = 0.1; + + auto expHFW = NDArrayFactory::create( + 'c', {bS, time, numUnitsFW}, + {0.43819931, 0.43819931, 0.43819931, 0.66617761, 0.66617761, 0.66617761, + 0.80944357, 0.80944357, 0.80944357, 0.87294706, 0.87294706, 0.87294706, + 0., 0., 0., 0.61067683, 0.61067683, 0.61067683, + 0.84851124, 0.84851124, 0.84851124, 0., 0., 0., + 0., 0., 0., 0., 0., 0., + 0.73978305, 0.73978305, 0.73978305, 0., 0., 0., + 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0.}); + + auto expHBW = NDArrayFactory::create( + 'c', {bS, time, numUnitsBW}, + {0.84345207, 0.84345207, 0.84345207, 0.83584708, 0.83584708, 0.83584708, + 0.77435951, 0.77435951, 0.77435951, 0.58760492, 0.58760492, 0.58760492, + 0., 0., 0., 0.85615841, 0.85615841, 0.85615841, + 0.67397984, 0.67397984, 0.67397984, 0., 0., 0., + 0., 0., 0., 0., 0., 0., + 0.76576202, 0.76576202, 0.76576202, 0., 0., 0., + 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0.}); + + auto expHFWfinal = NDArrayFactory::create( + 'c', {bS, numUnitsFW}, + {0.87294706, 0.87294706, 0.87294706, 0.84851124, 0.84851124, 0.84851124, + 0.73978305, 0.73978305, 0.73978305, 0.2, 0.2, 0.2}); + auto expHBWfinal = NDArrayFactory::create( + 'c', {bS, numUnitsBW}, + {0.84345207, 0.84345207, 0.84345207, 0.85615841, 0.85615841, 0.85615841, + 0.76576202, 0.76576202, 0.76576202, 0.25, 0.25, 0.25}); + + sd::ops::dynamic_bidirectional_rnn op; + auto results = op.evaluate( + {&x, &WxFW, &WhFW, &bFW, &WxFW, &WhFW, &bFW, &h0FW, &h0BW, &maxTimeStep}, + {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto hFW = results.at(0); + auto hBW = results.at(1); + auto hFWfinal = results.at(2); + auto hBWfinal = results.at(3); + + ASSERT_TRUE(expHFW.isSameShape(hFW)); + ASSERT_TRUE(expHFW.equalsTo(hFW)); + ASSERT_TRUE(expHBW.isSameShape(hBW)); + ASSERT_TRUE(expHBW.equalsTo(hBW)); + ASSERT_TRUE(expHFWfinal.isSameShape(hFWfinal)); + ASSERT_TRUE(expHFWfinal.equalsTo(hFWfinal)); + ASSERT_TRUE(expHBWfinal.isSameShape(hBWfinal)); + ASSERT_TRUE(expHBWfinal.equalsTo(hBWfinal)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, dynamic_bidir_rnn_test3) { - - const int bS = 4; - const int inSize = 4; - const int numUnitsFW = 3; - const int numUnitsBW = 3; - const int time = 5; - - auto x = NDArrayFactory::create('c', {bS, time, inSize}); - auto WxFW = NDArrayFactory::create('c', {inSize, numUnitsFW}); - auto WhFW = NDArrayFactory::create('c', {numUnitsFW, numUnitsFW}); - auto bFW = NDArrayFactory::create('c', {2*numUnitsFW}); - - auto maxTimeStep = NDArrayFactory::create('c', {bS}, {time-1, time-3, time-4, 0}); - - x.linspace(0.01, 0.01); - WxFW = 0.3; - WhFW = 0.4; - bFW = 0.1; - - auto expHFW = NDArrayFactory::create('c', {bS, time, numUnitsFW}, {0.22602835, 0.22602835, 0.22602835,0.49994591, 0.49994591, 0.49994591,0.72869307, 0.72869307, 0.72869307,0.84784327, 0.84784327, 0.84784327,0. , 0. , 0. , - 0.43819931, 0.43819931, 0.43819931,0.7793996 , 0.7793996 , 0.7793996 ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. , - 0.61067683, 0.61067683, 0.61067683,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. , - 0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. }); - - auto expHBW = NDArrayFactory::create('c', {bS, time, numUnitsBW}, {0.82273707, 0.82273707, 0.82273707,0.77935851, 0.77935851, 0.77935851,0.6381121 , 0.6381121 , 0.6381121 ,0.35748551, 0.35748551, 0.35748551,0. , 0. , 0. , - 0.77843476, 0.77843476, 0.77843476,0.47615493, 0.47615493, 0.47615493,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. , - 0.61067683, 0.61067683, 0.61067683,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. , - 0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. }); - - auto expHFWfinal = NDArrayFactory::create('c', {bS, numUnitsFW}, {0.84784327, 0.84784327, 0.84784327, 0.7793996 , 0.7793996 , 0.7793996 , 0.61067683, 0.61067683, 0.61067683, 0. , 0. , 0.}); - auto expHBWfinal = NDArrayFactory::create('c', {bS, numUnitsBW}, {0.82273707, 0.82273707, 0.82273707, 0.77843476, 0.77843476, 0.77843476, 0.61067683, 0.61067683, 0.61067683, 0. , 0. , 0.}); - - sd::ops::dynamic_bidirectional_rnn op; - auto results = op.evaluate({&x, &WxFW,&WhFW,&bFW, &WxFW,&WhFW,&bFW, &maxTimeStep}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto hFW = results.at(0); - auto hBW = results.at(1); - auto hFWfinal = results.at(2); - auto hBWfinal = results.at(3); - - ASSERT_TRUE(expHFW.isSameShape(hFW)); - ASSERT_TRUE(expHFW.equalsTo(hFW)); - ASSERT_TRUE(expHBW.isSameShape(hBW)); - ASSERT_TRUE(expHBW.equalsTo(hBW)); - ASSERT_TRUE(expHFWfinal.isSameShape(hFWfinal)); - ASSERT_TRUE(expHFWfinal.equalsTo(hFWfinal)); - ASSERT_TRUE(expHBWfinal.isSameShape(hBWfinal)); - ASSERT_TRUE(expHBWfinal.equalsTo(hBWfinal)); - - + const int bS = 4; + const int inSize = 4; + const int numUnitsFW = 3; + const int numUnitsBW = 3; + const int time = 5; + + auto x = NDArrayFactory::create('c', {bS, time, inSize}); + auto WxFW = NDArrayFactory::create('c', {inSize, numUnitsFW}); + auto WhFW = NDArrayFactory::create('c', {numUnitsFW, numUnitsFW}); + auto bFW = NDArrayFactory::create('c', {2 * numUnitsFW}); + + auto maxTimeStep = + NDArrayFactory::create('c', {bS}, {time - 1, time - 3, time - 4, 0}); + + x.linspace(0.01, 0.01); + WxFW = 0.3; + WhFW = 0.4; + bFW = 0.1; + + auto expHFW = NDArrayFactory::create( + 'c', {bS, time, numUnitsFW}, + {0.22602835, 0.22602835, 0.22602835, 0.49994591, 0.49994591, 0.49994591, + 0.72869307, 0.72869307, 0.72869307, 0.84784327, 0.84784327, 0.84784327, + 0., 0., 0., 0.43819931, 0.43819931, 0.43819931, + 0.7793996, 0.7793996, 0.7793996, 0., 0., 0., + 0., 0., 0., 0., 0., 0., + 0.61067683, 0.61067683, 0.61067683, 0., 0., 0., + 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0.}); + + auto expHBW = NDArrayFactory::create( + 'c', {bS, time, numUnitsBW}, + {0.82273707, 0.82273707, 0.82273707, 0.77935851, 0.77935851, 0.77935851, + 0.6381121, 0.6381121, 0.6381121, 0.35748551, 0.35748551, 0.35748551, + 0., 0., 0., 0.77843476, 0.77843476, 0.77843476, + 0.47615493, 0.47615493, 0.47615493, 0., 0., 0., + 0., 0., 0., 0., 0., 0., + 0.61067683, 0.61067683, 0.61067683, 0., 0., 0., + 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0.}); + + auto expHFWfinal = NDArrayFactory::create( + 'c', {bS, numUnitsFW}, + {0.84784327, 0.84784327, 0.84784327, 0.7793996, 0.7793996, 0.7793996, + 0.61067683, 0.61067683, 0.61067683, 0., 0., 0.}); + auto expHBWfinal = NDArrayFactory::create( + 'c', {bS, numUnitsBW}, + {0.82273707, 0.82273707, 0.82273707, 0.77843476, 0.77843476, 0.77843476, + 0.61067683, 0.61067683, 0.61067683, 0., 0., 0.}); + + sd::ops::dynamic_bidirectional_rnn op; + auto results = op.evaluate( + {&x, &WxFW, &WhFW, &bFW, &WxFW, &WhFW, &bFW, &maxTimeStep}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto hFW = results.at(0); + auto hBW = results.at(1); + auto hFWfinal = results.at(2); + auto hBWfinal = results.at(3); + + ASSERT_TRUE(expHFW.isSameShape(hFW)); + ASSERT_TRUE(expHFW.equalsTo(hFW)); + ASSERT_TRUE(expHBW.isSameShape(hBW)); + ASSERT_TRUE(expHBW.equalsTo(hBW)); + ASSERT_TRUE(expHFWfinal.isSameShape(hFWfinal)); + ASSERT_TRUE(expHFWfinal.equalsTo(hFWfinal)); + ASSERT_TRUE(expHBWfinal.isSameShape(hBWfinal)); + ASSERT_TRUE(expHBWfinal.equalsTo(hBWfinal)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, dynamic_bidir_rnn_test4) { - - const int bS = 4; - const int inSize = 4; - const int numUnitsFW = 3; - const int numUnitsBW = 3; - const int time = 5; - - auto x = NDArrayFactory::create('c', {bS, time, inSize}); - auto WxFW = NDArrayFactory::create('c', {inSize, numUnitsFW}); - auto WhFW = NDArrayFactory::create('c', {numUnitsFW, numUnitsFW}); - auto bFW = NDArrayFactory::create('c', {2*numUnitsFW}); - - auto h0FW = NDArrayFactory::create('c', {bS, numUnitsFW}); - auto h0BW = NDArrayFactory::create('c', {bS, numUnitsBW}); - - x.linspace(0.01, 0.01); - h0FW = 0.2; - h0BW = 0.25; - WxFW = 0.3; - WhFW = 0.4; - bFW = 0.1; - - auto expHFW = NDArrayFactory::create('c', {bS, time, numUnitsFW}, {0.43819931, 0.43819931, 0.43819931,0.66617761, 0.66617761, 0.66617761,0.80944357, 0.80944357, 0.80944357,0.87294706, 0.87294706, 0.87294706,0.89948899, 0.89948899, 0.89948899, - 0.61067683, 0.61067683, 0.61067683,0.84851124, 0.84851124, 0.84851124,0.91925737, 0.91925737, 0.91925737,0.93751395, 0.93751395, 0.93751395,0.94544483, 0.94544483, 0.94544483, - 0.73978305, 0.73978305, 0.73978305,0.92827068, 0.92827068, 0.92827068,0.95791111, 0.95791111, 0.95791111,0.96427356, 0.96427356, 0.96427356,0.96797541, 0.96797541, 0.96797541, - 0.83057887, 0.83057887, 0.83057887,0.96365083, 0.96365083, 0.96365083,0.97585698, 0.97585698, 0.97585698,0.97866981, 0.97866981, 0.97866981,0.9807326 , 0.9807326 , 0.9807326 }); - - auto expHBW = NDArrayFactory::create('c', {bS, time, numUnitsBW}, {0.85301722, 0.85301722, 0.85301722,0.86427295, 0.86427295, 0.86427295,0.8599919 , 0.8599919 , 0.8599919 ,0.80609463, 0.80609463, 0.80609463,0.61814662, 0.61814662, 0.61814662, - 0.91888753, 0.91888753, 0.91888753,0.92652672, 0.92652672, 0.92652672,0.92939674, 0.92939674, 0.92939674,0.90661931, 0.90661931, 0.90661931,0.74516764, 0.74516764, 0.74516764, - 0.95254269, 0.95254269, 0.95254269,0.95710717, 0.95710717, 0.95710717,0.96021584, 0.96021584, 0.96021584,0.95222547, 0.95222547, 0.95222547,0.83426363, 0.83426363, 0.83426363, - 0.97154357, 0.97154357, 0.97154357,0.97424915, 0.97424915, 0.97424915,0.97644817, 0.97644817, 0.97644817,0.97410547, 0.97410547, 0.97410547,0.89409962, 0.89409962, 0.89409962}); - - auto expHFWfinal = NDArrayFactory::create('c', {bS, numUnitsFW}, {0.89948899, 0.89948899, 0.89948899, 0.94544483, 0.94544483, 0.94544483, 0.96797541, 0.96797541, 0.96797541, 0.9807326 , 0.9807326 , 0.9807326 }); - auto expHBWfinal = NDArrayFactory::create('c', {bS, numUnitsBW}, {0.85301722, 0.85301722, 0.85301722, 0.91888753, 0.91888753, 0.91888753, 0.95254269, 0.95254269, 0.95254269, 0.97154357, 0.97154357, 0.97154357}); - - sd::ops::dynamic_bidirectional_rnn op; - auto results = op.evaluate({&x, &WxFW,&WhFW,&bFW, &WxFW,&WhFW,&bFW, &h0FW, &h0BW}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto hFW = results.at(0); - auto hBW = results.at(1); - auto hFWfinal = results.at(2); - auto hBWfinal = results.at(3); - - ASSERT_TRUE(expHFW.isSameShape(hFW)); - ASSERT_TRUE(expHFW.equalsTo(hFW)); - ASSERT_TRUE(expHBW.isSameShape(hBW)); - ASSERT_TRUE(expHBW.equalsTo(hBW)); - ASSERT_TRUE(expHFWfinal.isSameShape(hFWfinal)); - ASSERT_TRUE(expHFWfinal.equalsTo(hFWfinal)); - ASSERT_TRUE(expHBWfinal.isSameShape(hBWfinal)); - ASSERT_TRUE(expHBWfinal.equalsTo(hBWfinal)); - - + const int bS = 4; + const int inSize = 4; + const int numUnitsFW = 3; + const int numUnitsBW = 3; + const int time = 5; + + auto x = NDArrayFactory::create('c', {bS, time, inSize}); + auto WxFW = NDArrayFactory::create('c', {inSize, numUnitsFW}); + auto WhFW = NDArrayFactory::create('c', {numUnitsFW, numUnitsFW}); + auto bFW = NDArrayFactory::create('c', {2 * numUnitsFW}); + + auto h0FW = NDArrayFactory::create('c', {bS, numUnitsFW}); + auto h0BW = NDArrayFactory::create('c', {bS, numUnitsBW}); + + x.linspace(0.01, 0.01); + h0FW = 0.2; + h0BW = 0.25; + WxFW = 0.3; + WhFW = 0.4; + bFW = 0.1; + + auto expHFW = NDArrayFactory::create( + 'c', {bS, time, numUnitsFW}, + {0.43819931, 0.43819931, 0.43819931, 0.66617761, 0.66617761, 0.66617761, + 0.80944357, 0.80944357, 0.80944357, 0.87294706, 0.87294706, 0.87294706, + 0.89948899, 0.89948899, 0.89948899, 0.61067683, 0.61067683, 0.61067683, + 0.84851124, 0.84851124, 0.84851124, 0.91925737, 0.91925737, 0.91925737, + 0.93751395, 0.93751395, 0.93751395, 0.94544483, 0.94544483, 0.94544483, + 0.73978305, 0.73978305, 0.73978305, 0.92827068, 0.92827068, 0.92827068, + 0.95791111, 0.95791111, 0.95791111, 0.96427356, 0.96427356, 0.96427356, + 0.96797541, 0.96797541, 0.96797541, 0.83057887, 0.83057887, 0.83057887, + 0.96365083, 0.96365083, 0.96365083, 0.97585698, 0.97585698, 0.97585698, + 0.97866981, 0.97866981, 0.97866981, 0.9807326, 0.9807326, 0.9807326}); + + auto expHBW = NDArrayFactory::create( + 'c', {bS, time, numUnitsBW}, + {0.85301722, 0.85301722, 0.85301722, 0.86427295, 0.86427295, 0.86427295, + 0.8599919, 0.8599919, 0.8599919, 0.80609463, 0.80609463, 0.80609463, + 0.61814662, 0.61814662, 0.61814662, 0.91888753, 0.91888753, 0.91888753, + 0.92652672, 0.92652672, 0.92652672, 0.92939674, 0.92939674, 0.92939674, + 0.90661931, 0.90661931, 0.90661931, 0.74516764, 0.74516764, 0.74516764, + 0.95254269, 0.95254269, 0.95254269, 0.95710717, 0.95710717, 0.95710717, + 0.96021584, 0.96021584, 0.96021584, 0.95222547, 0.95222547, 0.95222547, + 0.83426363, 0.83426363, 0.83426363, 0.97154357, 0.97154357, 0.97154357, + 0.97424915, 0.97424915, 0.97424915, 0.97644817, 0.97644817, 0.97644817, + 0.97410547, 0.97410547, 0.97410547, 0.89409962, 0.89409962, 0.89409962}); + + auto expHFWfinal = NDArrayFactory::create( + 'c', {bS, numUnitsFW}, + {0.89948899, 0.89948899, 0.89948899, 0.94544483, 0.94544483, 0.94544483, + 0.96797541, 0.96797541, 0.96797541, 0.9807326, 0.9807326, 0.9807326}); + auto expHBWfinal = NDArrayFactory::create( + 'c', {bS, numUnitsBW}, + {0.85301722, 0.85301722, 0.85301722, 0.91888753, 0.91888753, 0.91888753, + 0.95254269, 0.95254269, 0.95254269, 0.97154357, 0.97154357, 0.97154357}); + + sd::ops::dynamic_bidirectional_rnn op; + auto results = op.evaluate( + {&x, &WxFW, &WhFW, &bFW, &WxFW, &WhFW, &bFW, &h0FW, &h0BW}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto hFW = results.at(0); + auto hBW = results.at(1); + auto hFWfinal = results.at(2); + auto hBWfinal = results.at(3); + + ASSERT_TRUE(expHFW.isSameShape(hFW)); + ASSERT_TRUE(expHFW.equalsTo(hFW)); + ASSERT_TRUE(expHBW.isSameShape(hBW)); + ASSERT_TRUE(expHBW.equalsTo(hBW)); + ASSERT_TRUE(expHFWfinal.isSameShape(hFWfinal)); + ASSERT_TRUE(expHFWfinal.equalsTo(hFWfinal)); + ASSERT_TRUE(expHBWfinal.isSameShape(hBWfinal)); + ASSERT_TRUE(expHBWfinal.equalsTo(hBWfinal)); } TEST_F(DeclarableOpsTests6, dynamic_bidir_rnn_test5) { - - const int bS = 4; - const int inSize = 4; - const int numUnitsFW = 3; - const int numUnitsBW = 3; - const int time = 5; - - auto x = NDArrayFactory::create('c', {bS, time, inSize}); - auto WxFW = NDArrayFactory::create('c', {inSize, numUnitsFW}); - auto WhFW = NDArrayFactory::create('c', {numUnitsFW, numUnitsFW}); - auto bFW = NDArrayFactory::create('c', {2*numUnitsFW}); - - x.linspace(0.01, 0.01); - WxFW = 0.3; - WhFW = 0.4; - bFW = 0.1; - - auto expHFW = NDArrayFactory::create('c', {bS, time, numUnitsFW}, {0.22602835, 0.22602835, 0.22602835,0.49994591, 0.49994591, 0.49994591,0.72869307, 0.72869307, 0.72869307,0.84784327, 0.84784327, 0.84784327,0.89357928, 0.89357928, 0.89357928, - 0.43819931, 0.43819931, 0.43819931,0.7793996 , 0.7793996 , 0.7793996 ,0.9053792 , 0.9053792 , 0.9053792 ,0.93546593, 0.93546593, 0.93546593,0.94518339, 0.94518339, 0.94518339, - 0.61067683, 0.61067683, 0.61067683,0.90347408, 0.90347408, 0.90347408,0.95538786, 0.95538786, 0.95538786,0.96406045, 0.96406045, 0.96406045,0.96795929, 0.96795929, 0.96795929, - 0.73978305, 0.73978305, 0.73978305,0.95499984, 0.95499984, 0.95499984,0.97535671, 0.97535671, 0.97535671,0.97864446, 0.97864446, 0.97864446,0.98073144, 0.98073144, 0.98073144}); - - auto expHBW = NDArrayFactory::create('c', {bS, time, numUnitsBW}, {0.84882345, 0.84882345, 0.84882345,0.85160683, 0.85160683, 0.85160683,0.81997657, 0.81997657, 0.81997657,0.69228829, 0.69228829, 0.69228829,0.39861399, 0.39861399, 0.39861399, - 0.91865453, 0.91865453, 0.91865453,0.92528094, 0.92528094, 0.92528094,0.92212167, 0.92212167, 0.92212167,0.86418213, 0.86418213, 0.86418213,0.57969286, 0.57969286, 0.57969286, - 0.95252666, 0.95252666, 0.95252666,0.95696305, 0.95696305, 0.95696305,0.95878749, 0.95878749, 0.95878749,0.93722463, 0.93722463, 0.93722463,0.71727031, 0.71727031, 0.71727031, - 0.97154234, 0.97154234, 0.97154234,0.97423089, 0.97423089, 0.97423089,0.976149 , 0.976149 , 0.976149 ,0.96878298, 0.96878298, 0.96878298,0.81508646, 0.81508646, 0.81508646}); - - auto expHFWfinal = NDArrayFactory::create('c', {bS, numUnitsFW}, {0.89357928, 0.89357928, 0.89357928, 0.94518339, 0.94518339, 0.94518339, 0.96795929, 0.96795929, 0.96795929, 0.98073144, 0.98073144, 0.98073144}); - auto expHBWfinal = NDArrayFactory::create('c', {bS, numUnitsBW}, {0.84882345, 0.84882345, 0.84882345, 0.91865453, 0.91865453, 0.91865453, 0.95252666, 0.95252666, 0.95252666, 0.97154234, 0.97154234, 0.97154234}); - - sd::ops::dynamic_bidirectional_rnn op; - auto results = op.evaluate({&x, &WxFW,&WhFW,&bFW, &WxFW,&WhFW,&bFW}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto hFW = results.at(0); - auto hBW = results.at(1); - auto hFWfinal = results.at(2); - auto hBWfinal = results.at(3); - - ASSERT_TRUE(expHFW.isSameShape(hFW)); - ASSERT_TRUE(expHFW.equalsTo(hFW)); - ASSERT_TRUE(expHBW.isSameShape(hBW)); - ASSERT_TRUE(expHBW.equalsTo(hBW)); - ASSERT_TRUE(expHFWfinal.isSameShape(hFWfinal)); - ASSERT_TRUE(expHFWfinal.equalsTo(hFWfinal)); - ASSERT_TRUE(expHBWfinal.isSameShape(hBWfinal)); - ASSERT_TRUE(expHBWfinal.equalsTo(hBWfinal)); - - + const int bS = 4; + const int inSize = 4; + const int numUnitsFW = 3; + const int numUnitsBW = 3; + const int time = 5; + + auto x = NDArrayFactory::create('c', {bS, time, inSize}); + auto WxFW = NDArrayFactory::create('c', {inSize, numUnitsFW}); + auto WhFW = NDArrayFactory::create('c', {numUnitsFW, numUnitsFW}); + auto bFW = NDArrayFactory::create('c', {2 * numUnitsFW}); + + x.linspace(0.01, 0.01); + WxFW = 0.3; + WhFW = 0.4; + bFW = 0.1; + + auto expHFW = NDArrayFactory::create( + 'c', {bS, time, numUnitsFW}, + {0.22602835, 0.22602835, 0.22602835, 0.49994591, 0.49994591, 0.49994591, + 0.72869307, 0.72869307, 0.72869307, 0.84784327, 0.84784327, 0.84784327, + 0.89357928, 0.89357928, 0.89357928, 0.43819931, 0.43819931, 0.43819931, + 0.7793996, 0.7793996, 0.7793996, 0.9053792, 0.9053792, 0.9053792, + 0.93546593, 0.93546593, 0.93546593, 0.94518339, 0.94518339, 0.94518339, + 0.61067683, 0.61067683, 0.61067683, 0.90347408, 0.90347408, 0.90347408, + 0.95538786, 0.95538786, 0.95538786, 0.96406045, 0.96406045, 0.96406045, + 0.96795929, 0.96795929, 0.96795929, 0.73978305, 0.73978305, 0.73978305, + 0.95499984, 0.95499984, 0.95499984, 0.97535671, 0.97535671, 0.97535671, + 0.97864446, 0.97864446, 0.97864446, 0.98073144, 0.98073144, 0.98073144}); + + auto expHBW = NDArrayFactory::create( + 'c', {bS, time, numUnitsBW}, + {0.84882345, 0.84882345, 0.84882345, 0.85160683, 0.85160683, 0.85160683, + 0.81997657, 0.81997657, 0.81997657, 0.69228829, 0.69228829, 0.69228829, + 0.39861399, 0.39861399, 0.39861399, 0.91865453, 0.91865453, 0.91865453, + 0.92528094, 0.92528094, 0.92528094, 0.92212167, 0.92212167, 0.92212167, + 0.86418213, 0.86418213, 0.86418213, 0.57969286, 0.57969286, 0.57969286, + 0.95252666, 0.95252666, 0.95252666, 0.95696305, 0.95696305, 0.95696305, + 0.95878749, 0.95878749, 0.95878749, 0.93722463, 0.93722463, 0.93722463, + 0.71727031, 0.71727031, 0.71727031, 0.97154234, 0.97154234, 0.97154234, + 0.97423089, 0.97423089, 0.97423089, 0.976149, 0.976149, 0.976149, + 0.96878298, 0.96878298, 0.96878298, 0.81508646, 0.81508646, 0.81508646}); + + auto expHFWfinal = NDArrayFactory::create( + 'c', {bS, numUnitsFW}, + {0.89357928, 0.89357928, 0.89357928, 0.94518339, 0.94518339, 0.94518339, + 0.96795929, 0.96795929, 0.96795929, 0.98073144, 0.98073144, 0.98073144}); + auto expHBWfinal = NDArrayFactory::create( + 'c', {bS, numUnitsBW}, + {0.84882345, 0.84882345, 0.84882345, 0.91865453, 0.91865453, 0.91865453, + 0.95252666, 0.95252666, 0.95252666, 0.97154234, 0.97154234, 0.97154234}); + + sd::ops::dynamic_bidirectional_rnn op; + auto results = + op.evaluate({&x, &WxFW, &WhFW, &bFW, &WxFW, &WhFW, &bFW}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto hFW = results.at(0); + auto hBW = results.at(1); + auto hFWfinal = results.at(2); + auto hBWfinal = results.at(3); + + ASSERT_TRUE(expHFW.isSameShape(hFW)); + ASSERT_TRUE(expHFW.equalsTo(hFW)); + ASSERT_TRUE(expHBW.isSameShape(hBW)); + ASSERT_TRUE(expHBW.equalsTo(hBW)); + ASSERT_TRUE(expHFWfinal.isSameShape(hFWfinal)); + ASSERT_TRUE(expHFWfinal.equalsTo(hFWfinal)); + ASSERT_TRUE(expHBWfinal.isSameShape(hBWfinal)); + ASSERT_TRUE(expHBWfinal.equalsTo(hBWfinal)); } - TEST_F(DeclarableOpsTests6, Test_Diag_119_1) { - auto x = NDArrayFactory::create('c', {3}, {0.15f, 0.25f, 0.35f}); - auto e = NDArrayFactory::create('c', {3, 3}, {0.15f, 0.0f, 0.0f, 0.0f, 0.25f, 0.0f, 0.0f, 0.0f, 0.35f}); - - sd::ops::diag op; - auto result = op.evaluate({&x}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_EQ(e, *result.at(0)); + auto x = NDArrayFactory::create('c', {3}, {0.15f, 0.25f, 0.35f}); + auto e = NDArrayFactory::create( + 'c', {3, 3}, {0.15f, 0.0f, 0.0f, 0.0f, 0.25f, 0.0f, 0.0f, 0.0f, 0.35f}); + sd::ops::diag op; + auto result = op.evaluate({&x}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_EQ(e, result.at(0)); } TEST_F(DeclarableOpsTests6, Test_Diag_119_2) { - auto x = NDArrayFactory::create('c', {1}, {0.15f}); - auto e = NDArrayFactory::create('c', {1, 1}, {0.15f}); - - sd::ops::diag op; - auto result = op.evaluate({&x}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_EQ(e, *result.at(0)); + auto x = NDArrayFactory::create('c', {1}, {0.15f}); + auto e = NDArrayFactory::create('c', {1, 1}, {0.15f}); + sd::ops::diag op; + auto result = op.evaluate({&x}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_EQ(e, result.at(0)); } TEST_F(DeclarableOpsTests6, Test_Diag_119_3) { - auto x = NDArrayFactory::create(0.15f); - auto e = NDArrayFactory::create('c', {1, 1}, {0.15f}); - - sd::ops::diag op; - auto result = op.evaluate({&x}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_EQ(e, *result.at(0)); + auto x = NDArrayFactory::create(0.15f); + auto e = NDArrayFactory::create('c', {1, 1}, {0.15f}); + sd::ops::diag op; + auto result = op.evaluate({&x}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_EQ(e, result.at(0)); } - - diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests7.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests7.cpp index d8478e471333..ffa56549245f 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests7.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests7.cpp @@ -18,6981 +18,7333 @@ // Created by raver119 on 09.02.18. // - -#include "testlayers.h" -#include -#include #include #include #include +#include +#include +#include "testlayers.h" using namespace sd; using namespace sd::graph; - class DeclarableOpsTests7 : public testing::Test { -public: - - DeclarableOpsTests7() { - printf("\n"); - fflush(stdout); - } + public: + DeclarableOpsTests7() { + printf("\n"); + fflush(stdout); + } }; template class TypedDeclarableOpsTests7 : public testing::Test { -public: - - TypedDeclarableOpsTests7() { - printf("\n"); - fflush(stdout); - } + public: + TypedDeclarableOpsTests7() { + printf("\n"); + fflush(stdout); + } }; typedef ::testing::Types TestingTypes; -TYPED_TEST_CASE(TypedDeclarableOpsTests7, TestingTypes); +TYPED_TEST_SUITE(TypedDeclarableOpsTests7, TestingTypes); TEST_F(DeclarableOpsTests7, Test_CHOOSE_SCALAR_LARGE) { - double inputData[150] = { - 0, 0.51, 0.68, 0.69, 0.86, 0.91, 0.96, 0.97, 0.97, 1.03, 1.13, 1.16, 1.16, 1.17, 1.19, 1.25, 1.25, 1.26, 1.27, 1.28, 1.29, 1.29, 1.29, 1.30, 1.31, 1.32, 1.33, 1.33, 1.35, 1.35, 1.36, 1.37, 1.38, 1.40, 1.41, 1.42, 1.43, 1.44, 1.44, 1.45, 1.45, 1.47, 1.47, 1.51, 1.51, 1.51, 1.52, 1.53, 1.56, 1.57, 1.58, 1.59, 1.61, 1.62, 1.63, 1.63, 1.64, 1.64, 1.66, 1.66, 1.67, 1.67, 1.70, 1.70, 1.70, 1.72, 1.72, 1.72, 1.72, 1.73, 1.74, 1.74, 1.76, 1.76, 1.77, 1.77, 1.80, 1.80, 1.81, 1.82, 1.83, 1.83, 1.84, 1.84, 1.84, 1.85, 1.85, 1.85, 1.86, 1.86, 1.87, 1.88, 1.89, 1.89, 1.89, 1.89, 1.89, 1.91, 1.91, 1.91, 1.92, 1.94, 1.95, 1.97, 1.98, 1.98, 1.98, 1.98, 1.98, 1.99, 2, 2, 2.01, 2.01, 2.02, 2.03, 2.03, 2.03, 2.04, 2.04, 2.05, 2.06, 2.07, 2.08, 2.08, 2.08, 2.08, 2.09, 2.09, 2.10, 2.10, 2.11, 2.11, 2.11, 2.12, 2.12, 2.13, 2.13, 2.14, 2.14, 2.14, 2.14, 2.15, 2.15, 2.16, 2.16, 2.16, 2.16, 2.16, 2.17 - }; - - auto x = NDArrayFactory::create(inputData,'c',{1,149}); - sd::ops::choose op; - //greater than test - auto result = op.evaluate({&x}, {0.0},{3}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(1); - - ASSERT_EQ(148,z->e(0)); - //ASSERT_TRUE(exp.isSameShape(z)); - - - + double inputData[150] = { + 0, 0.51, 0.68, 0.69, 0.86, 0.91, 0.96, 0.97, 0.97, 1.03, 1.13, 1.16, + 1.16, 1.17, 1.19, 1.25, 1.25, 1.26, 1.27, 1.28, 1.29, 1.29, 1.29, 1.30, + 1.31, 1.32, 1.33, 1.33, 1.35, 1.35, 1.36, 1.37, 1.38, 1.40, 1.41, 1.42, + 1.43, 1.44, 1.44, 1.45, 1.45, 1.47, 1.47, 1.51, 1.51, 1.51, 1.52, 1.53, + 1.56, 1.57, 1.58, 1.59, 1.61, 1.62, 1.63, 1.63, 1.64, 1.64, 1.66, 1.66, + 1.67, 1.67, 1.70, 1.70, 1.70, 1.72, 1.72, 1.72, 1.72, 1.73, 1.74, 1.74, + 1.76, 1.76, 1.77, 1.77, 1.80, 1.80, 1.81, 1.82, 1.83, 1.83, 1.84, 1.84, + 1.84, 1.85, 1.85, 1.85, 1.86, 1.86, 1.87, 1.88, 1.89, 1.89, 1.89, 1.89, + 1.89, 1.91, 1.91, 1.91, 1.92, 1.94, 1.95, 1.97, 1.98, 1.98, 1.98, 1.98, + 1.98, 1.99, 2, 2, 2.01, 2.01, 2.02, 2.03, 2.03, 2.03, 2.04, 2.04, + 2.05, 2.06, 2.07, 2.08, 2.08, 2.08, 2.08, 2.09, 2.09, 2.10, 2.10, 2.11, + 2.11, 2.11, 2.12, 2.12, 2.13, 2.13, 2.14, 2.14, 2.14, 2.14, 2.15, 2.15, + 2.16, 2.16, 2.16, 2.16, 2.16, 2.17}; + + auto x = NDArrayFactory::create(inputData, 'c', {1, 149}); + sd::ops::choose op; + // greater than test + auto result = op.evaluate({&x}, {0.0}, {3}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(1); + + ASSERT_EQ(148, z.e(0)); + // ASSERT_TRUE(exp.isSameShape(z)); } TEST_F(DeclarableOpsTests7, Test_CHOOSE_SCALAR_ZERO) { - std::vector data; - for(Nd4jLong i = 0; i < 4; i++) { - data.push_back(i); - } + std::vector data; + for (Nd4jLong i = 0; i < 4; i++) { + data.push_back(i); + } + auto x = NDArrayFactory::create('c', {1, 4}, data); + sd::ops::choose op; + // greater than test + auto result = op.evaluate({&x}, {0.0}, {3}); + ASSERT_EQ(Status::OK(), result.status()); - - auto x = NDArrayFactory::create('c',{1,4},data); - sd::ops::choose op; - //greater than test - auto result = op.evaluate({&x}, {0.0},{3}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(1); - auto array = *z; - ASSERT_EQ(3,array.e(0)); - //ASSERT_TRUE(exp.isSameShape(z)); - - - + auto z = result.at(1); + ASSERT_EQ(3, z.e(0)); + // ASSERT_TRUE(exp.isSameShape(z)); } - TEST_F(DeclarableOpsTests7, Test_CHOOSE_SCALAR) { - std::vector data; - for(Nd4jLong i = 0; i < 4; i++) { - data.push_back(i); - } - - + std::vector data; + for (Nd4jLong i = 0; i < 4; i++) { + data.push_back(i); + } - auto x = NDArrayFactory::create('c',{1,4},data); - auto scalar = NDArrayFactory::create('c',{1,1},{0.0}); - sd::ops::choose op; - //greater than test - auto result = op.evaluate({&x,&scalar}, {1.0},{3}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - ASSERT_EQ(3, z->lengthOf()); - //ASSERT_TRUE(exp.isSameShape(z)); - - + auto x = NDArrayFactory::create('c', {1, 4}, data); + auto scalar = NDArrayFactory::create('c', {1, 1}, {0.0}); + sd::ops::choose op; + // greater than test + auto result = op.evaluate({&x, &scalar}, {1.0}, {3}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + ASSERT_EQ(3, z.lengthOf()); + // ASSERT_TRUE(exp.isSameShape(z)); } - TEST_F(DeclarableOpsTests7, Test_CHOOSE_SCALAR_LEFT) { - std::vector data; - for(Nd4jLong i = 0; i < 4; i++) { - data.push_back(i); - } - - + std::vector data; + for (Nd4jLong i = 0; i < 4; i++) { + data.push_back(i); + } - auto x = NDArrayFactory::create('c',{1,4},data); - auto scalar = NDArrayFactory::create('c',{1,1},{0.0}); - sd::ops::choose op; - //greater than test - auto result = op.evaluate({&scalar,&x}, {1.0},{3}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - ASSERT_EQ(3,z->lengthOf()); - //ASSERT_TRUE(exp.isSameShape(z)); - - + auto x = NDArrayFactory::create('c', {1, 4}, data); + auto scalar = NDArrayFactory::create('c', {1, 1}, {0.0}); + sd::ops::choose op; + // greater than test + auto result = op.evaluate({&scalar, &x}, {1.0}, {3}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + ASSERT_EQ(3, z.lengthOf()); + // ASSERT_TRUE(exp.isSameShape(z)); } - TEST_F(DeclarableOpsTests7, Test_CHOOSE_ONLY_SCALAR) { - std::vector data; - for(Nd4jLong i = 0; i < 4; i++) { - data.push_back(i); - } - - + std::vector data; + for (Nd4jLong i = 0; i < 4; i++) { + data.push_back(i); + } - auto x = NDArrayFactory::create('c',{1,4},data); - sd::ops::choose op; - //greater than test - auto result = op.evaluate({&x}, {1.0},{3}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - ASSERT_EQ(2,z->lengthOf()); - //ASSERT_TRUE(exp.isSameShape(z)); - - + auto x = NDArrayFactory::create('c', {1, 4}, data); + sd::ops::choose op; + // greater than test + auto result = op.evaluate({&x}, {1.0}, {3}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + ASSERT_EQ(2, z.lengthOf()); + // ASSERT_TRUE(exp.isSameShape(z)); } - TEST_F(DeclarableOpsTests7, Test_CHOOSE_ONLY_SCALAR_GTE) { - std::vector data; - for(Nd4jLong i = 0; i < 4; i++) { - data.push_back(i); - } - - - - auto x = NDArrayFactory::create('c',{1,4},data); - sd::ops::choose op; - //greater than test - auto result = op.evaluate({&x}, {1.0},{5}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - ASSERT_EQ(3,z->lengthOf()); - //ASSERT_TRUE(exp.isSameShape(z)); + std::vector data; + for (Nd4jLong i = 0; i < 4; i++) { + data.push_back(i); + } - + auto x = NDArrayFactory::create('c', {1, 4}, data); + sd::ops::choose op; + // greater than test + auto result = op.evaluate({&x}, {1.0}, {5}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + ASSERT_EQ(3, z.lengthOf()); + // ASSERT_TRUE(exp.isSameShape(z)); } - TEST_F(DeclarableOpsTests7, TEST_WHERE) { - std::vector data; - std::vector mask; - std::vector put; - std::vector resultData; - std::vector assertion; - for(Nd4jLong i = 0; i < 4; i++) { - data.push_back(i); - if(i > 1) { - assertion.push_back(5.0); - mask.push_back(true); - } - else { - assertion.push_back(i); - mask.push_back(false); - } - - put.push_back(5.0); - resultData.push_back(0.0); + std::vector data; + std::vector mask; + std::vector put; + std::vector resultData; + std::vector assertion; + for (Nd4jLong i = 0; i < 4; i++) { + data.push_back(i); + if (i > 1) { + assertion.push_back(5.0); + mask.push_back(true); + } else { + assertion.push_back(i); + mask.push_back(false); } - - - - auto x = NDArrayFactory::create('c',{1,4},data); - auto maskArr = NDArrayFactory::create('c',{1,4},mask); - auto putArr = NDArrayFactory::create('c',{1,4},put); - auto resultArr = NDArrayFactory::create('c',{1,4},resultData); - sd::ops::where_np op; - //greater than test - // Nd4jStatus execute(std::initializer_list*> inputs, std::initializer_list*> outputs , std::initializer_list tArgs, std::initializer_list iArgs, bool isInplace = false); - - auto result = op.execute({&maskArr,&x,&putArr},{&resultArr}, {},{3}, {}, {}, false); - ASSERT_EQ(Status::OK(), result); - for(int i = 0; i < 4; i++) - ASSERT_EQ(assertion[i],resultArr.e(i)); - // auto z = result.at(0); - //ASSERT_EQ(4,z->lengthOf()); - //ASSERT_TRUE(exp.isSameShape(z)); - - + put.push_back(5.0); + resultData.push_back(0.0); + } + + auto x = NDArrayFactory::create('c', {1, 4}, data); + auto maskArr = NDArrayFactory::create('c', {1, 4}, mask); + auto putArr = NDArrayFactory::create('c', {1, 4}, put); + auto resultArr = NDArrayFactory::create('c', {1, 4}, resultData); + sd::ops::where_np op; + // greater than test + // Nd4jStatus execute(std::initializer_list*> inputs, + // std::initializer_list*> outputs , + // std::initializer_list tArgs, std::initializer_list + // iArgs, bool isInplace = false); + + auto result = + op.execute({&maskArr, &x, &putArr}, {&resultArr}, {}, {3}, {}, {}, false); + ASSERT_EQ(Status::OK(), result); + for (int i = 0; i < 4; i++) ASSERT_EQ(assertion[i], resultArr.e(i)); + // auto z = result.at(0); + // ASSERT_EQ(4,z->lengthOf()); + // ASSERT_TRUE(exp.isSameShape(z)); } TEST_F(DeclarableOpsTests7, TEST_WHERE_MASK) { - double x[300] = {1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0}; - double z[300] = {1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0}; - bool mask[300] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; - double put[200] = {0.99666107,0.9867112,0.97686064,0.9671082,0.95745337,0.9478948,0.9384318,0.92906314,0.9197881,0.91060543,0.9015147,0.8925147,0.8836044,0.8747831,0.86605,0.85740393,0.8488442,0.84037,0.83198035,0.8236745,0.8154515,0.8073106,0.79925096,0.79127187,0.7833724,0.77555174,0.76780915,0.7601439,0.75255525,0.7450422,0.7376043,0.73024046,0.72295034,0.715733,0.7085876,0.7015135,0.69451016,0.68757665,0.6807124,0.6739167,0.66718876,0.66052806,0.6539338,0.6474054,0.6409421,0.6345435,0.6282087,0.6219371,0.6157281,0.60958105,0.6034956,0.59747064,0.5915059,0.5856007,0.57975453,0.5739667,0.5682366,0.5625637,0.5569475,0.5513874,0.54588276,0.540433,0.53503764,0.5296962,0.52440816,0.51917285,0.5139898,0.5088585,0.50377846,0.4987491,0.4937699,0.48884052,0.48396033,0.47912875,0.47434545,0.4696099,0.46492168,0.46028027,0.45568514,0.4511359,0.44663212,0.4421733,0.43775895,0.43338865,0.42906195,0.42477852,0.4205379,0.41633952,0.41218308,0.40806815,0.40399432,0.3999611,0.3959682,0.39201516,0.38810158,0.384227,0.38039115,0.37659356,0.37283397,0.3691119,0.36542687,0.36177874,0.35816705,0.3545914,0.35105142,0.34754673,0.34407702,0.34064204,0.33724132,0.3338745,0.33054137,0.3272415,0.32397458,0.32074028,0.3175382,0.31436813,0.31122974,0.3081226,0.30504647,0.30200112,0.2989862,0.29600134,0.29304633,0.2901207,0.28722438,0.28435695,0.2815181,0.27870762,0.27592525,0.27317056,0.27044344,0.26774356,0.26507056,0.2624243,0.25980446,0.25721073,0.25464293,0.25210077,0.249584,0.24709237,0.24462552,0.24218333,0.23976555,0.23737194,0.23500215,0.23265606,0.23033342,0.22803394,0.22575743,0.2235036,0.22127232,0.21906327,0.21687631,0.21471114,0.21256764,0.21044552,0.20834461,0.20626466,0.20420544,0.20216681,0.20014854,0.19815037,0.19617215,0.19421372,0.19227484,0.19035533,0.18845497,0.18657354,0.18471093,0.18286693,0.18104129,0.17923392,0.17744459,0.17567308,0.1739193,0.17218304,0.17046405,0.16876228,0.16707748,0.16540948,0.16375816,0.16212334,0.16050482,0.15890247,0.15731607,0.15574552,0.15419069,0.15265137,0.15112738,0.14961864,0.14812498,0.14664622,0.1451822,0.14373279,0.14229788,0.14087726,0.13947085,0.13807845,0.13669999,0.13533528}; - double assertion[300] = {1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,9.966611049434810354e-01,9.867111603284486332e-01,9.768605487739230320e-01,9.671082786103732953e-01,9.574533680683808834e-01,9.478948451798039354e-01,9.384317476799283186e-01,9.290631229105962285e-01,9.197880277243004610e-01,9.106055283892373620e-01,9.015147004953073528e-01,8.925146288610534828e-01,8.836044074415293492e-01,8.747831392370875037e-01,8.660499362030764647e-01,8.574039191604412302e-01,8.488442177072155204e-01,8.403699701308978698e-01,8.319803233217017979e-01,8.236744326866727306e-01,8.154514620646623468e-01,8.073105836421510251e-01,7.992509778699116163e-01,7.912718333805045523e-01,7.833723469065965173e-01,7.755517232000953554e-01,7.678091749520912224e-01,7.601439227135980969e-01,7.525551948170853267e-01,7.450422272987937689e-01,7.376042638218265335e-01,7.302405556000080011e-01,7.229503613225031211e-01,7.157329470791886639e-01,7.085875862867698771e-01,7.015135596156351072e-01,6.945101549174396149e-01,6.875766671534137009e-01,6.807123983233853703e-01,6.739166573955123196e-01,6.671887602367149173e-01,6.605280295438040739e-01,6.539337947752965619e-01,6.474053920839111242e-01,6.409421642497381555e-01,6.345434606140767375e-01,6.282086370139332576e-01,6.219370557171712832e-01,6.157280853583116942e-01,6.095811008749726367e-01,6.034954834449430816e-01,5.974706204238864338e-01,5.915059052836644238e-01,5.856007375512777280e-01,5.797545227484157682e-01,5.739666723316099173e-01,5.682366036329845604e-01,5.625637398015992385e-01,5.569475097453767676e-01,5.513873480736106725e-01,5.458826950400470501e-01,5.404329964865340896e-01,5.350377037872348085e-01,5.296962737933965659e-01,5.244081687786711354e-01,5.191728563849821176e-01,5.139898095689314772e-01,5.088585065487419845e-01,5.037784307517284565e-01,4.987490707622945774e-01,4.937699202704479151e-01,4.888404780208293054e-01,4.839602477622509946e-01,4.791287381977387683e-01,4.743454629350723484e-01,4.696099404378203390e-01,4.649216939768630041e-01,4.602802515824001017e-01,4.556851459964368911e-01,4.511359146257447605e-01,4.466320994952920342e-01,4.421732472021388527e-01,4.377589088697927955e-01,4.333886401030203062e-01,4.290620009431086457e-01,4.247785558235752101e-01,4.205378735263185508e-01,4.163395271382073215e-01,4.121830940081024908e-01,4.080681557043087104e-01,4.039942979724505667e-01,3.999611106937689398e-01,3.959681878438343627e-01,3.920151274516718853e-01,3.881015315592946102e-01,3.842270061816405180e-01,3.803911612669100828e-01,3.765936106572991271e-01,3.728339720501240850e-01,3.691118669593352886e-01,3.654269206774144463e-01,3.617787622376523182e-01,3.581670243768036999e-01,3.545913434981138868e-01,3.510513596347161203e-01,3.475467164133922426e-01,3.440770610186974499e-01,3.406420441574410929e-01,3.372413200235238606e-01,3.338745462631242389e-01,3.305413839402346898e-01,3.272414975025391692e-01,3.239745547476344245e-01,3.207402267895853032e-01,3.175381880258169032e-01,3.143681161043347383e-01,3.112296918912743071e-01,3.081225994387726264e-01,3.050465259531625062e-01,3.020011617634821843e-01,2.989862002903017069e-01,2.960013380148582840e-01,2.930462744485015647e-01,2.901207121024425017e-01,2.872243564578055852e-01,2.843569159359789489e-01,2.815181018692606840e-01,2.787076284717992514e-01,2.759252128108221624e-01,2.731705747781537075e-01,2.704434370620155681e-01,2.677435251191103149e-01,2.650705671469821278e-01,2.624242940566549609e-01,2.598044394455423789e-01,2.572107395706292876e-01,2.546429333219200064e-01,2.521007621961529055e-01,2.495839702707757235e-01,2.470923041781825646e-01,2.446255130802063582e-01,2.421833486428674187e-01,2.397655650113727777e-01,2.373719187853666479e-01,2.350021689944260528e-01,2.326560770738031469e-01,2.303334068404078172e-01,2.280339244690317291e-01,2.257573984688081292e-01,2.235035996599082919e-01,2.212723011504689752e-01,2.190632783137518302e-01,2.168763087655291855e-01,2.147111723416972873e-01,2.125676510761114746e-01,2.104455291786438698e-01,2.083445930134591173e-01,2.062646310775079761e-01,2.042054339792348794e-01,2.021667944174980747e-01,2.001485071607009836e-01,1.981503690261307848e-01,1.961721788595043592e-01,1.942137375147174327e-01,1.922748478337968081e-01,1.903553146270518526e-01,1.884549446534251604e-01,1.865735466010380594e-01,1.847109310679319050e-01,1.828669105430000552e-01,1.810412993871116094e-01,1.792339138144224131e-01,1.774445718738737465e-01,1.756730934308744496e-01,1.739193001491673995e-01,1.721830154728755669e-01,1.704640646087285105e-01,1.687622745084652875e-01,1.670774738514141378e-01,1.654094930272448083e-01,1.637581641188943782e-01,1.621233208856623365e-01,1.605047987464754966e-01,1.589024347633189727e-01,1.573160676248336609e-01,1.557455376300762306e-01,1.541906866724424563e-01,1.526513582237501165e-01,1.511273973184814046e-01,1.496186505381822129e-01,1.481249659960175158e-01,1.466461933214808777e-01,1.451821836452561187e-01,1.437327895842310799e-01,1.422978652266598532e-01,1.408772661174743090e-01,1.394708492437411185e-01,1.380784730202649913e-01,1.366999972753347725e-01,1.353352832366127023e-01}; - Nd4jLong threeHundredShapePointer[8] = {2,1,300,1,1,0,1,99}; - Nd4jLong twoHundredShapePointer[8] = {2,1,200,1,1,0,1,99}; - sd::ops::where_np op; - ArrayOptions::setDataType(threeHundredShapePointer, sd::DataType::DOUBLE); - ArrayOptions::setDataType(twoHundredShapePointer, sd::DataType::DOUBLE); - - NDArray xArr(x,threeHundredShapePointer); - NDArray putArr(put,twoHundredShapePointer); - NDArray resultArr(z,threeHundredShapePointer); - - resultArr.assign(0.0); - ArrayOptions::setDataType(threeHundredShapePointer, sd::DataType::BOOL); - NDArray maskArr(mask,threeHundredShapePointer); - - ArrayOptions::setDataType(threeHundredShapePointer, sd::DataType::DOUBLE); - NDArray assertArr(assertion, threeHundredShapePointer); - Nd4jStatus result = op.execute({&maskArr, &xArr, &putArr},{&resultArr},{},{},{}); - ASSERT_EQ(Status::OK(),result); - ASSERT_TRUE(assertArr.isSameShape(resultArr)); - ASSERT_TRUE (assertArr.equalsTo(resultArr)); + double x[300] = { + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}; + double z[300] = { + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}; + bool mask[300] = { + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; + double put[200] = { + 0.99666107, 0.9867112, 0.97686064, 0.9671082, 0.95745337, 0.9478948, + 0.9384318, 0.92906314, 0.9197881, 0.91060543, 0.9015147, 0.8925147, + 0.8836044, 0.8747831, 0.86605, 0.85740393, 0.8488442, 0.84037, + 0.83198035, 0.8236745, 0.8154515, 0.8073106, 0.79925096, 0.79127187, + 0.7833724, 0.77555174, 0.76780915, 0.7601439, 0.75255525, 0.7450422, + 0.7376043, 0.73024046, 0.72295034, 0.715733, 0.7085876, 0.7015135, + 0.69451016, 0.68757665, 0.6807124, 0.6739167, 0.66718876, 0.66052806, + 0.6539338, 0.6474054, 0.6409421, 0.6345435, 0.6282087, 0.6219371, + 0.6157281, 0.60958105, 0.6034956, 0.59747064, 0.5915059, 0.5856007, + 0.57975453, 0.5739667, 0.5682366, 0.5625637, 0.5569475, 0.5513874, + 0.54588276, 0.540433, 0.53503764, 0.5296962, 0.52440816, 0.51917285, + 0.5139898, 0.5088585, 0.50377846, 0.4987491, 0.4937699, 0.48884052, + 0.48396033, 0.47912875, 0.47434545, 0.4696099, 0.46492168, 0.46028027, + 0.45568514, 0.4511359, 0.44663212, 0.4421733, 0.43775895, 0.43338865, + 0.42906195, 0.42477852, 0.4205379, 0.41633952, 0.41218308, 0.40806815, + 0.40399432, 0.3999611, 0.3959682, 0.39201516, 0.38810158, 0.384227, + 0.38039115, 0.37659356, 0.37283397, 0.3691119, 0.36542687, 0.36177874, + 0.35816705, 0.3545914, 0.35105142, 0.34754673, 0.34407702, 0.34064204, + 0.33724132, 0.3338745, 0.33054137, 0.3272415, 0.32397458, 0.32074028, + 0.3175382, 0.31436813, 0.31122974, 0.3081226, 0.30504647, 0.30200112, + 0.2989862, 0.29600134, 0.29304633, 0.2901207, 0.28722438, 0.28435695, + 0.2815181, 0.27870762, 0.27592525, 0.27317056, 0.27044344, 0.26774356, + 0.26507056, 0.2624243, 0.25980446, 0.25721073, 0.25464293, 0.25210077, + 0.249584, 0.24709237, 0.24462552, 0.24218333, 0.23976555, 0.23737194, + 0.23500215, 0.23265606, 0.23033342, 0.22803394, 0.22575743, 0.2235036, + 0.22127232, 0.21906327, 0.21687631, 0.21471114, 0.21256764, 0.21044552, + 0.20834461, 0.20626466, 0.20420544, 0.20216681, 0.20014854, 0.19815037, + 0.19617215, 0.19421372, 0.19227484, 0.19035533, 0.18845497, 0.18657354, + 0.18471093, 0.18286693, 0.18104129, 0.17923392, 0.17744459, 0.17567308, + 0.1739193, 0.17218304, 0.17046405, 0.16876228, 0.16707748, 0.16540948, + 0.16375816, 0.16212334, 0.16050482, 0.15890247, 0.15731607, 0.15574552, + 0.15419069, 0.15265137, 0.15112738, 0.14961864, 0.14812498, 0.14664622, + 0.1451822, 0.14373279, 0.14229788, 0.14087726, 0.13947085, 0.13807845, + 0.13669999, 0.13533528}; + double assertion[300] = {1.000000000000000000e+00, 1.000000000000000000e+00, + 1.000000000000000000e+00, 1.000000000000000000e+00, + 1.000000000000000000e+00, 1.000000000000000000e+00, + 1.000000000000000000e+00, 1.000000000000000000e+00, + 1.000000000000000000e+00, 1.000000000000000000e+00, + 1.000000000000000000e+00, 1.000000000000000000e+00, + 1.000000000000000000e+00, 1.000000000000000000e+00, + 1.000000000000000000e+00, 1.000000000000000000e+00, + 1.000000000000000000e+00, 1.000000000000000000e+00, + 1.000000000000000000e+00, 1.000000000000000000e+00, + 1.000000000000000000e+00, 1.000000000000000000e+00, + 1.000000000000000000e+00, 1.000000000000000000e+00, + 1.000000000000000000e+00, 1.000000000000000000e+00, + 1.000000000000000000e+00, 1.000000000000000000e+00, + 1.000000000000000000e+00, 1.000000000000000000e+00, + 1.000000000000000000e+00, 1.000000000000000000e+00, + 1.000000000000000000e+00, 1.000000000000000000e+00, + 1.000000000000000000e+00, 1.000000000000000000e+00, + 1.000000000000000000e+00, 1.000000000000000000e+00, + 1.000000000000000000e+00, 1.000000000000000000e+00, + 1.000000000000000000e+00, 1.000000000000000000e+00, + 1.000000000000000000e+00, 1.000000000000000000e+00, + 1.000000000000000000e+00, 1.000000000000000000e+00, + 1.000000000000000000e+00, 1.000000000000000000e+00, + 1.000000000000000000e+00, 1.000000000000000000e+00, + 1.000000000000000000e+00, 1.000000000000000000e+00, + 1.000000000000000000e+00, 1.000000000000000000e+00, + 1.000000000000000000e+00, 1.000000000000000000e+00, + 1.000000000000000000e+00, 1.000000000000000000e+00, + 1.000000000000000000e+00, 1.000000000000000000e+00, + 1.000000000000000000e+00, 1.000000000000000000e+00, + 1.000000000000000000e+00, 1.000000000000000000e+00, + 1.000000000000000000e+00, 1.000000000000000000e+00, + 1.000000000000000000e+00, 1.000000000000000000e+00, + 1.000000000000000000e+00, 1.000000000000000000e+00, + 1.000000000000000000e+00, 1.000000000000000000e+00, + 1.000000000000000000e+00, 1.000000000000000000e+00, + 1.000000000000000000e+00, 1.000000000000000000e+00, + 1.000000000000000000e+00, 1.000000000000000000e+00, + 1.000000000000000000e+00, 1.000000000000000000e+00, + 1.000000000000000000e+00, 1.000000000000000000e+00, + 1.000000000000000000e+00, 1.000000000000000000e+00, + 1.000000000000000000e+00, 1.000000000000000000e+00, + 1.000000000000000000e+00, 1.000000000000000000e+00, + 1.000000000000000000e+00, 1.000000000000000000e+00, + 1.000000000000000000e+00, 1.000000000000000000e+00, + 1.000000000000000000e+00, 1.000000000000000000e+00, + 1.000000000000000000e+00, 1.000000000000000000e+00, + 1.000000000000000000e+00, 1.000000000000000000e+00, + 1.000000000000000000e+00, 1.000000000000000000e+00, + 9.966611049434810354e-01, 9.867111603284486332e-01, + 9.768605487739230320e-01, 9.671082786103732953e-01, + 9.574533680683808834e-01, 9.478948451798039354e-01, + 9.384317476799283186e-01, 9.290631229105962285e-01, + 9.197880277243004610e-01, 9.106055283892373620e-01, + 9.015147004953073528e-01, 8.925146288610534828e-01, + 8.836044074415293492e-01, 8.747831392370875037e-01, + 8.660499362030764647e-01, 8.574039191604412302e-01, + 8.488442177072155204e-01, 8.403699701308978698e-01, + 8.319803233217017979e-01, 8.236744326866727306e-01, + 8.154514620646623468e-01, 8.073105836421510251e-01, + 7.992509778699116163e-01, 7.912718333805045523e-01, + 7.833723469065965173e-01, 7.755517232000953554e-01, + 7.678091749520912224e-01, 7.601439227135980969e-01, + 7.525551948170853267e-01, 7.450422272987937689e-01, + 7.376042638218265335e-01, 7.302405556000080011e-01, + 7.229503613225031211e-01, 7.157329470791886639e-01, + 7.085875862867698771e-01, 7.015135596156351072e-01, + 6.945101549174396149e-01, 6.875766671534137009e-01, + 6.807123983233853703e-01, 6.739166573955123196e-01, + 6.671887602367149173e-01, 6.605280295438040739e-01, + 6.539337947752965619e-01, 6.474053920839111242e-01, + 6.409421642497381555e-01, 6.345434606140767375e-01, + 6.282086370139332576e-01, 6.219370557171712832e-01, + 6.157280853583116942e-01, 6.095811008749726367e-01, + 6.034954834449430816e-01, 5.974706204238864338e-01, + 5.915059052836644238e-01, 5.856007375512777280e-01, + 5.797545227484157682e-01, 5.739666723316099173e-01, + 5.682366036329845604e-01, 5.625637398015992385e-01, + 5.569475097453767676e-01, 5.513873480736106725e-01, + 5.458826950400470501e-01, 5.404329964865340896e-01, + 5.350377037872348085e-01, 5.296962737933965659e-01, + 5.244081687786711354e-01, 5.191728563849821176e-01, + 5.139898095689314772e-01, 5.088585065487419845e-01, + 5.037784307517284565e-01, 4.987490707622945774e-01, + 4.937699202704479151e-01, 4.888404780208293054e-01, + 4.839602477622509946e-01, 4.791287381977387683e-01, + 4.743454629350723484e-01, 4.696099404378203390e-01, + 4.649216939768630041e-01, 4.602802515824001017e-01, + 4.556851459964368911e-01, 4.511359146257447605e-01, + 4.466320994952920342e-01, 4.421732472021388527e-01, + 4.377589088697927955e-01, 4.333886401030203062e-01, + 4.290620009431086457e-01, 4.247785558235752101e-01, + 4.205378735263185508e-01, 4.163395271382073215e-01, + 4.121830940081024908e-01, 4.080681557043087104e-01, + 4.039942979724505667e-01, 3.999611106937689398e-01, + 3.959681878438343627e-01, 3.920151274516718853e-01, + 3.881015315592946102e-01, 3.842270061816405180e-01, + 3.803911612669100828e-01, 3.765936106572991271e-01, + 3.728339720501240850e-01, 3.691118669593352886e-01, + 3.654269206774144463e-01, 3.617787622376523182e-01, + 3.581670243768036999e-01, 3.545913434981138868e-01, + 3.510513596347161203e-01, 3.475467164133922426e-01, + 3.440770610186974499e-01, 3.406420441574410929e-01, + 3.372413200235238606e-01, 3.338745462631242389e-01, + 3.305413839402346898e-01, 3.272414975025391692e-01, + 3.239745547476344245e-01, 3.207402267895853032e-01, + 3.175381880258169032e-01, 3.143681161043347383e-01, + 3.112296918912743071e-01, 3.081225994387726264e-01, + 3.050465259531625062e-01, 3.020011617634821843e-01, + 2.989862002903017069e-01, 2.960013380148582840e-01, + 2.930462744485015647e-01, 2.901207121024425017e-01, + 2.872243564578055852e-01, 2.843569159359789489e-01, + 2.815181018692606840e-01, 2.787076284717992514e-01, + 2.759252128108221624e-01, 2.731705747781537075e-01, + 2.704434370620155681e-01, 2.677435251191103149e-01, + 2.650705671469821278e-01, 2.624242940566549609e-01, + 2.598044394455423789e-01, 2.572107395706292876e-01, + 2.546429333219200064e-01, 2.521007621961529055e-01, + 2.495839702707757235e-01, 2.470923041781825646e-01, + 2.446255130802063582e-01, 2.421833486428674187e-01, + 2.397655650113727777e-01, 2.373719187853666479e-01, + 2.350021689944260528e-01, 2.326560770738031469e-01, + 2.303334068404078172e-01, 2.280339244690317291e-01, + 2.257573984688081292e-01, 2.235035996599082919e-01, + 2.212723011504689752e-01, 2.190632783137518302e-01, + 2.168763087655291855e-01, 2.147111723416972873e-01, + 2.125676510761114746e-01, 2.104455291786438698e-01, + 2.083445930134591173e-01, 2.062646310775079761e-01, + 2.042054339792348794e-01, 2.021667944174980747e-01, + 2.001485071607009836e-01, 1.981503690261307848e-01, + 1.961721788595043592e-01, 1.942137375147174327e-01, + 1.922748478337968081e-01, 1.903553146270518526e-01, + 1.884549446534251604e-01, 1.865735466010380594e-01, + 1.847109310679319050e-01, 1.828669105430000552e-01, + 1.810412993871116094e-01, 1.792339138144224131e-01, + 1.774445718738737465e-01, 1.756730934308744496e-01, + 1.739193001491673995e-01, 1.721830154728755669e-01, + 1.704640646087285105e-01, 1.687622745084652875e-01, + 1.670774738514141378e-01, 1.654094930272448083e-01, + 1.637581641188943782e-01, 1.621233208856623365e-01, + 1.605047987464754966e-01, 1.589024347633189727e-01, + 1.573160676248336609e-01, 1.557455376300762306e-01, + 1.541906866724424563e-01, 1.526513582237501165e-01, + 1.511273973184814046e-01, 1.496186505381822129e-01, + 1.481249659960175158e-01, 1.466461933214808777e-01, + 1.451821836452561187e-01, 1.437327895842310799e-01, + 1.422978652266598532e-01, 1.408772661174743090e-01, + 1.394708492437411185e-01, 1.380784730202649913e-01, + 1.366999972753347725e-01, 1.353352832366127023e-01}; + Nd4jLong threeHundredShapePointer[8] = {2, 1, 300, 1, 1, 0, 1, 99}; + Nd4jLong twoHundredShapePointer[8] = {2, 1, 200, 1, 1, 0, 1, 99}; + sd::ops::where_np op; + ArrayOptions::setDataType(threeHundredShapePointer, sd::DataType::DOUBLE); + ArrayOptions::setDataType(twoHundredShapePointer, sd::DataType::DOUBLE); + + NDArray xArr(x, threeHundredShapePointer); + NDArray putArr(put, twoHundredShapePointer); + NDArray resultArr(z, threeHundredShapePointer); + + resultArr.assign(0.0); + ArrayOptions::setDataType(threeHundredShapePointer, sd::DataType::BOOL); + NDArray maskArr(mask, threeHundredShapePointer); + + ArrayOptions::setDataType(threeHundredShapePointer, sd::DataType::DOUBLE); + NDArray assertArr(assertion, threeHundredShapePointer); + Nd4jStatus result = + op.execute({&maskArr, &xArr, &putArr}, {&resultArr}, {}, {}, {}); + ASSERT_EQ(Status::OK(), result); + ASSERT_TRUE(assertArr.isSameShape(resultArr)); + ASSERT_TRUE(assertArr.equalsTo(resultArr)); } TEST_F(DeclarableOpsTests7, TEST_WHERE_SCALAR) { - std::vector data; - std::vector mask; - std::vector put; - std::vector resultData; - std::vector assertion; - for(Nd4jLong i = 0; i < 4; i++) { - data.push_back(i); - if(i > 1) { - assertion.push_back(5.0); - mask.push_back(true); - } - else { - assertion.push_back(i); - mask.push_back(false); - } - - resultData.push_back(0.0); + std::vector data; + std::vector mask; + std::vector put; + std::vector resultData; + std::vector assertion; + for (Nd4jLong i = 0; i < 4; i++) { + data.push_back(i); + if (i > 1) { + assertion.push_back(5.0); + mask.push_back(true); + } else { + assertion.push_back(i); + mask.push_back(false); } + resultData.push_back(0.0); + } - put.push_back(5.0); - - - auto x = NDArrayFactory::create('c',{1,4},data); - auto maskArr = NDArrayFactory::create('c',{1,4},mask); - auto putArr = NDArrayFactory::create('c',{1,1},put); - auto resultArr = NDArrayFactory::create('c',{1,4},resultData); - sd::ops::where_np op; - //greater than test - // Nd4jStatus execute(std::initializer_list*> inputs, std::initializer_list*> outputs , std::initializer_list tArgs, std::initializer_list iArgs, bool isInplace = false); - - auto result = op.execute({&maskArr,&x,&putArr},{&resultArr}, {},{3}, {}, {}, false); - // ASSERT_EQ(Status::OK(), result.status()); - for(int i = 0; i < 4; i++) - ASSERT_EQ(assertion[i],resultArr.e(i)); - // auto z = result.at(0); - //ASSERT_EQ(4,z->lengthOf()); - //ASSERT_TRUE(exp.isSameShape(z)); + put.push_back(5.0); + auto x = NDArrayFactory::create('c', {1, 4}, data); + auto maskArr = NDArrayFactory::create('c', {1, 4}, mask); + auto putArr = NDArrayFactory::create('c', {1, 1}, put); + auto resultArr = NDArrayFactory::create('c', {1, 4}, resultData); + sd::ops::where_np op; + // greater than test + // Nd4jStatus execute(std::initializer_list*> inputs, + // std::initializer_list*> outputs , + // std::initializer_list tArgs, std::initializer_list + // iArgs, bool isInplace = false); + auto result = + op.execute({&maskArr, &x, &putArr}, {&resultArr}, {}, {3}, {}, {}, false); + // ASSERT_EQ(Status::OK(), result.status()); + for (int i = 0; i < 4; i++) ASSERT_EQ(assertion[i], resultArr.e(i)); + // auto z = result.at(0); + // ASSERT_EQ(4,z->lengthOf()); + // ASSERT_TRUE(exp.isSameShape(z)); } - //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestMatrixDiagPart_1) { - auto x = NDArrayFactory::create('c', {2, 4, 4}, {1., 0., 0., 0., 0., 2., 0., 0., 0., 0., 3., 0., 0., 0., 0., 4.,5., 0., 0., 0., 0., 6., 0., 0., 0., 0., 7., 0., 0., 0., 0., 8.}); - - auto z = NDArrayFactory::create('c', {2, 4}, {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0}); + auto x = NDArrayFactory::create( + 'c', {2, 4, 4}, + {1., 0., 0., 0., 0., 2., 0., 0., 0., 0., 3., 0., 0., 0., 0., 4., + 5., 0., 0., 0., 0., 6., 0., 0., 0., 0., 7., 0., 0., 0., 0., 8.}); - sd::ops::matrix_diag_part op; + auto z = NDArrayFactory::create( + 'c', {2, 4}, {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0}); - auto result = op.evaluate({&x}, {}, {}); + sd::ops::matrix_diag_part op; - ASSERT_EQ(result.status(), Status::OK()); - ASSERT_TRUE(z.equalsTo(result.at(0))); + auto result = op.evaluate({&x}, {}, {}); - + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_TRUE(z.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestMatrixDiagPart_2) { - auto x = NDArrayFactory::create('c', {2, 3, 4}, {1., 0., 0., 0., 0., 2., 0., 0., 0., 0., 3., 0.,5., 0., 0., 0., 0., 6., 0., 0., 0., 0., 7., 0.}); - - auto z = NDArrayFactory::create('c', {2, 3}, {1, 2, 3, 5, 6, 7}); + auto x = NDArrayFactory::create( + 'c', {2, 3, 4}, {1., 0., 0., 0., 0., 2., 0., 0., 0., 0., 3., 0., + 5., 0., 0., 0., 0., 6., 0., 0., 0., 0., 7., 0.}); - sd::ops::matrix_diag_part op; + auto z = NDArrayFactory::create('c', {2, 3}, {1, 2, 3, 5, 6, 7}); - auto result = op.evaluate({&x}, {}, {}); + sd::ops::matrix_diag_part op; - ASSERT_EQ(result.status(), Status::OK()); - ASSERT_TRUE(z.equalsTo(result.at(0))); + auto result = op.evaluate({&x}, {}, {}); - + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_TRUE(z.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestMatrixDiag_1) { - auto z = NDArrayFactory::create('c', {2, 4, 4}, {1., 0., 0., 0., 0., 2., 0., 0., 0., 0., 3., 0., 0., 0., 0., 4.,5., 0., 0., 0., 0., 6., 0., 0., 0., 0., 7., 0., 0., 0., 0., 8.}); + auto z = NDArrayFactory::create( + 'c', {2, 4, 4}, + {1., 0., 0., 0., 0., 2., 0., 0., 0., 0., 3., 0., 0., 0., 0., 4., + 5., 0., 0., 0., 0., 6., 0., 0., 0., 0., 7., 0., 0., 0., 0., 8.}); - auto x = NDArrayFactory::create('c', {2, 4}, {1, 2, 3, 4, 5, 6, 7, 8}); + auto x = + NDArrayFactory::create('c', {2, 4}, {1, 2, 3, 4, 5, 6, 7, 8}); - sd::ops::matrix_diag op; + sd::ops::matrix_diag op; - auto result = op.evaluate({&x}, {}, {}); + auto result = op.evaluate({&x}, {}, {}); - ASSERT_EQ(result.status(), Status::OK()); - ASSERT_TRUE(z.equalsTo(result.at(0))); - - + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_TRUE(z.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestMatrixDiag_2) { - auto z = NDArrayFactory::create('c', {2, 3, 3}, {1., 0., 0., 0., 2., 0., 0., 0., 3.,5., 0., 0., 0., 6., 0.,0., 0., 7.}); - auto x = NDArrayFactory::create('c', {2, 3}, {1, 2, 3, 5, 6, 7}); - - sd::ops::matrix_diag op; + auto z = NDArrayFactory::create( + 'c', {2, 3, 3}, + {1., 0., 0., 0., 2., 0., 0., 0., 3., 5., 0., 0., 0., 6., 0., 0., 0., 7.}); + auto x = NDArrayFactory::create('c', {2, 3}, {1, 2, 3, 5, 6, 7}); - auto result = op.evaluate({&x}, {}, {}); - ASSERT_EQ(result.status(), Status::OK()); - ASSERT_TRUE(z.equalsTo(result.at(0))); + sd::ops::matrix_diag op; - + auto result = op.evaluate({&x}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_TRUE(z.equalsTo(result.at(0))); } - //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestRandomCrop_1) { - auto x = NDArrayFactory::create('c', {2, 2, 4}, {1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9.,2.1, 2.1, 0.7, 0.1,3., 4.2, 2.2, 1. }); - auto shape = NDArrayFactory::create({1, 2, 3}); - sd::ops::random_crop op; + auto x = + NDArrayFactory::create('c', {2, 2, 4}, + {1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, + 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto shape = NDArrayFactory::create({1, 2, 3}); + sd::ops::random_crop op; - auto result = op.evaluate({&x, &shape}, {}, {}); - ASSERT_EQ(result.status(), Status::OK()); - //result.at(0)->printIndexedBuffer("Output"); -// ASSERT_TRUE(z.equalsTo(result.at(0))); - - + auto result = op.evaluate({&x, &shape}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + // result.at(0)->printIndexedBuffer("Output"); + // ASSERT_TRUE(z.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestRandomCrop_2) { - auto x = NDArrayFactory::create('c', {2, 2, 4}, {1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9.,2.1, 2.1, 0.7, 0.1,3., 4.2, 2.2, 1. }); - auto shape = NDArrayFactory::create({2, 2, 2}); - sd::ops::random_crop op; - - auto result = op.evaluate({&x, &shape}, {}, {}); - ASSERT_EQ(result.status(), Status::OK()); - //result.at(0)->printIndexedBuffer("Output"); -// ASSERT_TRUE(z.equalsTo(result.at(0))); + auto x = + NDArrayFactory::create('c', {2, 2, 4}, + {1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, + 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto shape = NDArrayFactory::create({2, 2, 2}); + sd::ops::random_crop op; - + auto result = op.evaluate({&x, &shape}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + // result.at(0)->printIndexedBuffer("Output"); + // ASSERT_TRUE(z.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Dynamic_Stitch_119) { - auto indices0 = NDArrayFactory::create('c', {2}, {1, 10}); - auto indices1 = NDArrayFactory::create('c', {2, 3}, {0, 7, 9, 5, 8, 3}); - auto indices2 = NDArrayFactory::create('c', {3, 1}, {6, 4, 2}); - auto data0 = NDArrayFactory::create('c', {2,5,4}, {1.f, 2.f, 3.f, 4.f,5.f, 6.f, 7.f, 8.f,9.f, 10.f, 11.f, 12.f,13.f, 14.f, 15.f, 16.f,17.f, 18.f, 19.f, 20.f,21.f, 22.f, 23.f, 24.f, - 25.f, 26.f, 27.f, 28.f,29.f, 30.f, 31.f, 32.f,33.f, 34.f, 35.f, 36.f,37.f, 38.f, 39.f, 40.f}); - - auto data1 = NDArrayFactory::create('c', {2,3,5,4},{1.f, 2.f, 3.f, 4.f,5.f, 6.f, 7.f, 8.f,9.f, 10.f, 11.f, 12.f,13.f, 14.f, 15.f, 16.f,17.f, 18.f, 19.f, 20.f,21.f, 22.f, 23.f, 24.f,25.f, 26.f, 27.f, 28.f, - 29.f, 30.f, 31.f, 32.f,33.f, 34.f, 35.f, 36.f,37.f, 38.f, 39.f, 40.f,41.f, 42.f, 43.f, 44.f,45.f, 46.f, 47.f, 48.f,49.f, 50.f, 51.f, 52.f,53.f, 54.f, 55.f, 56.f, - 57.f, 58.f, 59.f, 60.f,61.f, 62.f, 63.f, 64.f,65.f, 66.f, 67.f, 68.f,69.f, 70.f, 71.f, 72.f,73.f, 74.f, 75.f, 76.f,77.f, 78.f, 79.f, 80.f,81.f, 82.f, 83.f, 84.f, - 85.f, 86.f, 87.f, 88.f,89.f, 90.f, 91.f, 92.f,93.f, 94.f, 95.f, 96.f,97.f, 98.f, 99.f, 100.f,101.f, 102.f, 103.f, 104.f,105.f, 106.f, 107.f, 108.f,109.f, 110.f, 111.f, 112.f, - 113.f, 114.f, 115.f, 116.f,117.f, 118.f, 119.f, 120.f}); - - auto data2 = NDArrayFactory::create('c', {3,1,5,4}, {1.f, 2.f, 3.f, 4.f,5.f, 6.f, 7.f, 8.f,9.f, 10.f, 11.f, 12.f,13.f, 14.f, 15.f, 16.f,17.f, 18.f, 19.f, 20.f,21.f, 22.f, 23.f, 24.f, - 25.f, 26.f, 27.f, 28.f,29.f, 30.f, 31.f, 32.f,33.f, 34.f, 35.f, 36.f,37.f, 38.f, 39.f, 40.f,41.f, 42.f, 43.f, 44.f,45.f, 46.f, 47.f, 48.f, - 49.f, 50.f, 51.f, 52.f,53.f, 54.f, 55.f, 56.f,57.f, 58.f, 59.f, 60.f}); - - auto exp = NDArrayFactory::create('c', {11, 5, 4}, {1.f, 2.f, 3.f, 4.f,5.f, 6.f, 7.f, 8.f,9.f, 10.f, 11.f, 12.f,13.f, 14.f, 15.f, 16.f,17.f, 18.f, 19.f, 20.f,1.f, 2.f, 3.f, 4.f, - 5.f, 6.f, 7.f, 8.f,9.f, 10.f, 11.f, 12.f,13.f, 14.f, 15.f, 16.f,17.f, 18.f, 19.f, 20.f,41.f, 42.f, 43.f, 44.f,45.f, 46.f, 47.f, 48.f, - 49.f, 50.f, 51.f, 52.f,53.f, 54.f, 55.f, 56.f,57.f, 58.f, 59.f, 60.f,101.f, 102.f, 103.f, 104.f,105.f, 106.f, 107.f, 108.f,109.f, 110.f, 111.f, 112.f, - 113.f, 114.f, 115.f, 116.f,117.f, 118.f, 119.f, 120.f,21.f, 22.f, 23.f, 24.f,25.f, 26.f, 27.f, 28.f,29.f, 30.f, 31.f, 32.f,33.f, 34.f, 35.f, 36.f, - 37.f, 38.f, 39.f, 40.f,61.f, 62.f, 63.f, 64.f,65.f, 66.f, 67.f, 68.f,69.f, 70.f, 71.f, 72.f,73.f, 74.f, 75.f, 76.f,77.f, 78.f, 79.f, 80.f, - 1.f, 2.f, 3.f, 4.f,5.f, 6.f, 7.f, 8.f,9.f, 10.f, 11.f, 12.f,13.f, 14.f, 15.f, 16.f,17.f, 18.f, 19.f, 20.f,21.f, 22.f, 23.f, 24.f, - 25.f, 26.f, 27.f, 28.f,29.f, 30.f, 31.f, 32.f,33.f, 34.f, 35.f, 36.f,37.f, 38.f, 39.f, 40.f,81.f, 82.f, 83.f, 84.f,85.f, 86.f, 87.f, 88.f, - 89.f, 90.f, 91.f, 92.f,93.f, 94.f, 95.f, 96.f,97.f, 98.f, 99.f, 100.f,41.f, 42.f, 43.f, 44.f,45.f, 46.f, 47.f, 48.f,49.f, 50.f, 51.f, 52.f, - 53.f, 54.f, 55.f, 56.f,57.f, 58.f, 59.f, 60.f,21.f, 22.f, 23.f, 24.f,25.f, 26.f, 27.f, 28.f,29.f, 30.f, 31.f, 32.f,33.f, 34.f, 35.f, 36.f,37.f, 38.f, 39.f, 40.f}); - - sd::ops::dynamic_stitch op; - auto result = op.evaluate({&indices0, &indices1, &indices2, &data0, &data1, &data2}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); -// result.at(0)->printIndexedBuffer("Output"); -// exp.printIndexedBuffer("Expect"); -// result.at(0)->printShapeInfo("Output shape"); - ASSERT_TRUE(exp.isSameShape(result.at(0))); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - + auto indices0 = NDArrayFactory::create('c', {2}, {1, 10}); + auto indices1 = NDArrayFactory::create('c', {2, 3}, {0, 7, 9, 5, 8, 3}); + auto indices2 = NDArrayFactory::create('c', {3, 1}, {6, 4, 2}); + auto data0 = NDArrayFactory::create( + 'c', {2, 5, 4}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, + 11.f, 12.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, + 21.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, + 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 37.f, 38.f, 39.f, 40.f}); + + auto data1 = NDArrayFactory::create( + 'c', {2, 3, 5, 4}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, + 11.f, 12.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, + 21.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, + 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 37.f, 38.f, 39.f, 40.f, + 41.f, 42.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 49.f, 50.f, + 51.f, 52.f, 53.f, 54.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, + 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 67.f, 68.f, 69.f, 70.f, + 71.f, 72.f, 73.f, 74.f, 75.f, 76.f, 77.f, 78.f, 79.f, 80.f, + 81.f, 82.f, 83.f, 84.f, 85.f, 86.f, 87.f, 88.f, 89.f, 90.f, + 91.f, 92.f, 93.f, 94.f, 95.f, 96.f, 97.f, 98.f, 99.f, 100.f, + 101.f, 102.f, 103.f, 104.f, 105.f, 106.f, 107.f, 108.f, 109.f, 110.f, + 111.f, 112.f, 113.f, 114.f, 115.f, 116.f, 117.f, 118.f, 119.f, 120.f}); + + auto data2 = NDArrayFactory::create( + 'c', {3, 1, 5, 4}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, + 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, + 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, + 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, + 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f}); + + auto exp = NDArrayFactory::create( + 'c', {11, 5, 4}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, + 11.f, 12.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, + 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, + 11.f, 12.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, + 41.f, 42.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 49.f, 50.f, + 51.f, 52.f, 53.f, 54.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, + 101.f, 102.f, 103.f, 104.f, 105.f, 106.f, 107.f, 108.f, 109.f, 110.f, + 111.f, 112.f, 113.f, 114.f, 115.f, 116.f, 117.f, 118.f, 119.f, 120.f, + 21.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, + 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 37.f, 38.f, 39.f, 40.f, + 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 67.f, 68.f, 69.f, 70.f, + 71.f, 72.f, 73.f, 74.f, 75.f, 76.f, 77.f, 78.f, 79.f, 80.f, + 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, + 11.f, 12.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, + 21.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, + 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 37.f, 38.f, 39.f, 40.f, + 81.f, 82.f, 83.f, 84.f, 85.f, 86.f, 87.f, 88.f, 89.f, 90.f, + 91.f, 92.f, 93.f, 94.f, 95.f, 96.f, 97.f, 98.f, 99.f, 100.f, + 41.f, 42.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 49.f, 50.f, + 51.f, 52.f, 53.f, 54.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, + 21.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, + 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 37.f, 38.f, 39.f, 40.f}); + + sd::ops::dynamic_stitch op; + auto result = op.evaluate( + {&indices0, &indices1, &indices2, &data0, &data1, &data2}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + // result.at(0)->printIndexedBuffer("Output"); + // exp.printIndexedBuffer("Expect"); + // result.at(0)->printShapeInfo("Output shape"); + ASSERT_TRUE(exp.isSameShape(result.at(0))); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Dynamic_Stitch_Prof_1) { - auto indices0 = NDArrayFactory::create('c', {2}, {1, 10}); - auto indices1 = NDArrayFactory::create('c', {2, 3}, {0, 7, 9, 5, 8, 3}); - auto indices2 = NDArrayFactory::create('c', {3, 1}, {6, 4, 2}); - auto data0 = NDArrayFactory::create('c', {2,5,4}, {1.f, 2.f, 3.f, 4.f,5.f, 6.f, 7.f, 8.f,9.f, 10.f, 11.f, 12.f,13.f, 14.f, 15.f, 16.f,17.f, 18.f, 19.f, 20.f,21.f, 22.f, 23.f, 24.f, - 25.f, 26.f, 27.f, 28.f,29.f, 30.f, 31.f, 32.f,33.f, 34.f, 35.f, 36.f,37.f, 38.f, 39.f, 40.f}); - - auto data1 = NDArrayFactory::create('c', {2,3,5,4},{1.f, 2.f, 3.f, 4.f,5.f, 6.f, 7.f, 8.f,9.f, 10.f, 11.f, 12.f,13.f, 14.f, 15.f, 16.f,17.f, 18.f, 19.f, 20.f,21.f, 22.f, 23.f, 24.f,25.f, 26.f, 27.f, 28.f, - 29.f, 30.f, 31.f, 32.f,33.f, 34.f, 35.f, 36.f,37.f, 38.f, 39.f, 40.f,41.f, 42.f, 43.f, 44.f,45.f, 46.f, 47.f, 48.f,49.f, 50.f, 51.f, 52.f,53.f, 54.f, 55.f, 56.f, - 57.f, 58.f, 59.f, 60.f,61.f, 62.f, 63.f, 64.f,65.f, 66.f, 67.f, 68.f,69.f, 70.f, 71.f, 72.f,73.f, 74.f, 75.f, 76.f,77.f, 78.f, 79.f, 80.f,81.f, 82.f, 83.f, 84.f, - 85.f, 86.f, 87.f, 88.f,89.f, 90.f, 91.f, 92.f,93.f, 94.f, 95.f, 96.f,97.f, 98.f, 99.f, 100.f,101.f, 102.f, 103.f, 104.f,105.f, 106.f, 107.f, 108.f,109.f, 110.f, 111.f, 112.f, - 113.f, 114.f, 115.f, 116.f,117.f, 118.f, 119.f, 120.f}); - - auto data2 = NDArrayFactory::create('c', {3,1,5,4}, {1.f, 2.f, 3.f, 4.f,5.f, 6.f, 7.f, 8.f,9.f, 10.f, 11.f, 12.f,13.f, 14.f, 15.f, 16.f,17.f, 18.f, 19.f, 20.f,21.f, 22.f, 23.f, 24.f, - 25.f, 26.f, 27.f, 28.f,29.f, 30.f, 31.f, 32.f,33.f, 34.f, 35.f, 36.f,37.f, 38.f, 39.f, 40.f,41.f, 42.f, 43.f, 44.f,45.f, 46.f, 47.f, 48.f, - 49.f, 50.f, 51.f, 52.f,53.f, 54.f, 55.f, 56.f,57.f, 58.f, 59.f, 60.f}); - - auto exp = NDArrayFactory::create('c', {11, 5, 4}, {1.f, 2.f, 3.f, 4.f,5.f, 6.f, 7.f, 8.f,9.f, 10.f, 11.f, 12.f,13.f, 14.f, 15.f, 16.f,17.f, 18.f, 19.f, 20.f,1.f, 2.f, 3.f, 4.f, - 5.f, 6.f, 7.f, 8.f,9.f, 10.f, 11.f, 12.f,13.f, 14.f, 15.f, 16.f,17.f, 18.f, 19.f, 20.f,41.f, 42.f, 43.f, 44.f,45.f, 46.f, 47.f, 48.f, - 49.f, 50.f, 51.f, 52.f,53.f, 54.f, 55.f, 56.f,57.f, 58.f, 59.f, 60.f,101.f, 102.f, 103.f, 104.f,105.f, 106.f, 107.f, 108.f,109.f, 110.f, 111.f, 112.f, - 113.f, 114.f, 115.f, 116.f,117.f, 118.f, 119.f, 120.f,21.f, 22.f, 23.f, 24.f,25.f, 26.f, 27.f, 28.f,29.f, 30.f, 31.f, 32.f,33.f, 34.f, 35.f, 36.f, - 37.f, 38.f, 39.f, 40.f,61.f, 62.f, 63.f, 64.f,65.f, 66.f, 67.f, 68.f,69.f, 70.f, 71.f, 72.f,73.f, 74.f, 75.f, 76.f,77.f, 78.f, 79.f, 80.f, - 1.f, 2.f, 3.f, 4.f,5.f, 6.f, 7.f, 8.f,9.f, 10.f, 11.f, 12.f,13.f, 14.f, 15.f, 16.f,17.f, 18.f, 19.f, 20.f,21.f, 22.f, 23.f, 24.f, - 25.f, 26.f, 27.f, 28.f,29.f, 30.f, 31.f, 32.f,33.f, 34.f, 35.f, 36.f,37.f, 38.f, 39.f, 40.f,81.f, 82.f, 83.f, 84.f,85.f, 86.f, 87.f, 88.f, - 89.f, 90.f, 91.f, 92.f,93.f, 94.f, 95.f, 96.f,97.f, 98.f, 99.f, 100.f,41.f, 42.f, 43.f, 44.f,45.f, 46.f, 47.f, 48.f,49.f, 50.f, 51.f, 52.f, - 53.f, 54.f, 55.f, 56.f,57.f, 58.f, 59.f, 60.f,21.f, 22.f, 23.f, 24.f,25.f, 26.f, 27.f, 28.f,29.f, 30.f, 31.f, 32.f,33.f, 34.f, 35.f, 36.f,37.f, 38.f, 39.f, 40.f}); - - sd::ops::dynamic_stitch op; - auto result = op.evaluate({&indices0, &indices1, &indices2, &data0, &data1, &data2}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); -// result.at(0)->printIndexedBuffer("Output"); -// exp.printIndexedBuffer("Expect"); -// result.at(0)->printShapeInfo("Output shape"); - auto res = result.at(0); - ASSERT_TRUE(exp.isSameShape(result.at(0))); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - int numOfCases = 100; - auto timeStart = std::chrono::system_clock::now(); - - for (int i = 0; i < numOfCases; i++) { - op.execute({&indices0, &indices1, &indices2, &data0, &data1, &data2}, {res}, {}, {}, {}); - } - - auto timeEnd = std::chrono::system_clock::now(); - auto outerTime = std::chrono::duration_cast (timeEnd - timeStart).count(); - //nd4j_printf("dynamic_stitch: Process with %i iterations was load: %lld us.\n", numOfCases, outerTime / numOfCases); - - + auto indices0 = NDArrayFactory::create('c', {2}, {1, 10}); + auto indices1 = NDArrayFactory::create('c', {2, 3}, {0, 7, 9, 5, 8, 3}); + auto indices2 = NDArrayFactory::create('c', {3, 1}, {6, 4, 2}); + auto data0 = NDArrayFactory::create( + 'c', {2, 5, 4}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, + 11.f, 12.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, + 21.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, + 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 37.f, 38.f, 39.f, 40.f}); + + auto data1 = NDArrayFactory::create( + 'c', {2, 3, 5, 4}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, + 11.f, 12.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, + 21.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, + 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 37.f, 38.f, 39.f, 40.f, + 41.f, 42.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 49.f, 50.f, + 51.f, 52.f, 53.f, 54.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, + 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 67.f, 68.f, 69.f, 70.f, + 71.f, 72.f, 73.f, 74.f, 75.f, 76.f, 77.f, 78.f, 79.f, 80.f, + 81.f, 82.f, 83.f, 84.f, 85.f, 86.f, 87.f, 88.f, 89.f, 90.f, + 91.f, 92.f, 93.f, 94.f, 95.f, 96.f, 97.f, 98.f, 99.f, 100.f, + 101.f, 102.f, 103.f, 104.f, 105.f, 106.f, 107.f, 108.f, 109.f, 110.f, + 111.f, 112.f, 113.f, 114.f, 115.f, 116.f, 117.f, 118.f, 119.f, 120.f}); + + auto data2 = NDArrayFactory::create( + 'c', {3, 1, 5, 4}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, + 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, + 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, + 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, + 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f}); + + auto exp = NDArrayFactory::create( + 'c', {11, 5, 4}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, + 11.f, 12.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, + 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, + 11.f, 12.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, + 41.f, 42.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 49.f, 50.f, + 51.f, 52.f, 53.f, 54.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, + 101.f, 102.f, 103.f, 104.f, 105.f, 106.f, 107.f, 108.f, 109.f, 110.f, + 111.f, 112.f, 113.f, 114.f, 115.f, 116.f, 117.f, 118.f, 119.f, 120.f, + 21.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, + 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 37.f, 38.f, 39.f, 40.f, + 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 67.f, 68.f, 69.f, 70.f, + 71.f, 72.f, 73.f, 74.f, 75.f, 76.f, 77.f, 78.f, 79.f, 80.f, + 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, + 11.f, 12.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, + 21.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, + 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 37.f, 38.f, 39.f, 40.f, + 81.f, 82.f, 83.f, 84.f, 85.f, 86.f, 87.f, 88.f, 89.f, 90.f, + 91.f, 92.f, 93.f, 94.f, 95.f, 96.f, 97.f, 98.f, 99.f, 100.f, + 41.f, 42.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 49.f, 50.f, + 51.f, 52.f, 53.f, 54.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, + 21.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, + 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 37.f, 38.f, 39.f, 40.f}); + + sd::ops::dynamic_stitch op; + auto result = op.evaluate( + {&indices0, &indices1, &indices2, &data0, &data1, &data2}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + // result.at(0)->printIndexedBuffer("Output"); + // exp.printIndexedBuffer("Expect"); + // result.at(0)->printShapeInfo("Output shape"); + auto res = result.at(0); + ASSERT_TRUE(exp.isSameShape(result.at(0))); + ASSERT_TRUE(exp.equalsTo(result.at(0))); + int numOfCases = 100; + auto timeStart = std::chrono::system_clock::now(); + + for (int i = 0; i < numOfCases; i++) { + op.execute({&indices0, &indices1, &indices2, &data0, &data1, &data2}, + {&res}, {}, {}, {}); + } + + auto timeEnd = std::chrono::system_clock::now(); + auto outerTime = + std::chrono::duration_cast(timeEnd - timeStart) + .count(); + // nd4j_printf("dynamic_stitch: Process with %i iterations was load: %lld + // us.\n", numOfCases, outerTime / numOfCases); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Dynamic_Stitch_119_1) { - auto indices0 = NDArrayFactory::create('c', {2}, {1, 10}); - auto indices1 = NDArrayFactory::create('c', {2, 3}, {0, 7, 9, 5, 8, 3}); - auto indices2 = NDArrayFactory::create('c', {3, 1}, {6, 4, 2}); - - auto data0 = NDArrayFactory::create('c', {2,5,4}); - auto data1 = NDArrayFactory::create('c', {2,3,5,4}); - auto data2 = NDArrayFactory::create('c', {3,1,5,4}); - - auto exp = NDArrayFactory::create('c', {11, 5, 4}, { - 21, 22, 23, 24, - 25, 26, 27, 28, - 29, 30, 31, 32, - 33, 34, 35, 36, - 37, 38, 39, 40, - - 1, 2, 3, 4, - 5, 6, 7, 8, - 9, 10, 11, 12, - 13, 14, 15, 16, - 17, 18, 19, 20, - - 181, 182, 183, 184, - 185, 186, 187, 188, - 189, 190, 191, 192, - 193, 194, 195, 196, - 197, 198, 199, 200, - - 121, 122, 123, 124, - 125, 126, 127, 128, - 129, 130, 131, 132, - 133, 134, 135, 136, - 137, 138, 139, 140, - - 161, 162, 163, 164, - 165, 166, 167, 168, - 169, 170, 171, 172, - 173, 174, 175, 176, - 177, 178, 179, 180, - - 81, 82, 83, 84, - 85, 86, 87, 88, - 89, 90, 91, 92, - 93, 94, 95, 96, - 97, 98, 99, 100, - - 141, 142, 143, 144, - 145, 146, 147, 148, - 149, 150, 151, 152, - 153, 154, 155, 156, - 157, 158, 159, 160, - - 41, 42, 43, 44, - 45, 46, 47, 48, - 49, 50, 51, 52, - 53, 54, 55, 56, - 57, 58, 59, 60, - - 101, 102, 103, 104, - 105, 106, 107, 108, - 109, 110, 111, 112, - 113, 114, 115, 116, - 117, 118, 119, 120, - - 61, 62, 63, 64, - 65, 66, 67, 68, - 69, 70, 71, 72, - 73, 74, 75, 76, - 77, 78, 79, 80, - - 21, 22, 23, 24, - 25, 26, 27, 28, - 29, 30, 31, 32, - 33, 34, 35, 36, - 37, 38, 39, 40, - }); - data0.linspace(1); - data1.linspace(21); - data2.linspace(141); - sd::ops::dynamic_stitch op; - auto result = op.evaluate({&indices0, &indices1, &indices2, &data0, &data1, &data2}, {}, {}); - - ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); - - ASSERT_TRUE(z->isSameShape(exp)); - ASSERT_TRUE(z->equalsTo(exp)); - - + auto indices0 = NDArrayFactory::create('c', {2}, {1, 10}); + auto indices1 = NDArrayFactory::create('c', {2, 3}, {0, 7, 9, 5, 8, 3}); + auto indices2 = NDArrayFactory::create('c', {3, 1}, {6, 4, 2}); + + auto data0 = NDArrayFactory::create('c', {2, 5, 4}); + auto data1 = NDArrayFactory::create('c', {2, 3, 5, 4}); + auto data2 = NDArrayFactory::create('c', {3, 1, 5, 4}); + + auto exp = NDArrayFactory::create( + 'c', {11, 5, 4}, + { + 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, + 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, + + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, + 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, + + 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, + + 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, + 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, + + 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, + 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, + + 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, + 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, + + 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, + 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, + + 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, + 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, + + 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, + 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, + + 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, + 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, + + 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, + 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, + }); + data0.linspace(1); + data1.linspace(21); + data2.linspace(141); + sd::ops::dynamic_stitch op; + auto result = op.evaluate( + {&indices0, &indices1, &indices2, &data0, &data1, &data2}, {}, {}); + + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + + ASSERT_TRUE(z.isSameShape(exp)); + ASSERT_TRUE(z.equalsTo(exp)); } TEST_F(DeclarableOpsTests7, Test_Dynamic_Stitch_119_2) { - auto indices0 = NDArrayFactory::create('c', {2}, {1, 10}); - auto indices1 = NDArrayFactory::create('c', {2, 3}, {0, 7, 9, 5, 8, 3}); - auto indices2 = NDArrayFactory::create('c', {3, 1}, {6, 4, 2}); - - auto data0 = NDArrayFactory::create('c', {2,5,4}); - auto data1 = NDArrayFactory::create('c', {2,3,5,4}); - auto data2 = NDArrayFactory::create('c', {3,1,5,4}); - - auto exp = NDArrayFactory::create('c', {11, 5, 4}, { - 41, 42, 43, 44, - 45, 46, 47, 48, - 49, 50, 51, 52, - 53, 54, 55, 56, - 57, 58, 59, 60, - - 1, 2, 3, 4, - 5, 6, 7, 8, - 9, 10, 11, 12, - 13, 14, 15, 16, - 17, 18, 19, 20, - - 201, 202, 203, 204, - 205, 206, 207, 208, - 209, 210, 211, 212, - 213, 214, 215, 216, - 217, 218, 219, 220, - - 141, 142, 143, 144, - 145, 146, 147, 148, - 149, 150, 151, 152, - 153, 154, 155, 156, - 157, 158, 159, 160, - - 181, 182, 183, 184, - 185, 186, 187, 188, - 189, 190, 191, 192, - 193, 194, 195, 196, - 197, 198, 199, 200, - - 101, 102, 103, 104, - 105, 106, 107, 108, - 109, 110, 111, 112, - 113, 114, 115, 116, - 117, 118, 119, 120, - - 161, 162, 163, 164, - 165, 166, 167, 168, - 169, 170, 171, 172, - 173, 174, 175, 176, - 177, 178, 179, 180, - - 61, 62, 63, 64, - 65, 66, 67, 68, - 69, 70, 71, 72, - 73, 74, 75, 76, - 77, 78, 79, 80, - - 121, 122, 123, 124, - 125, 126, 127, 128, - 129, 130, 131, 132, - 133, 134, 135, 136, - 137, 138, 139, 140, - - 81, 82, 83, 84, - 85, 86, 87, 88, - 89, 90, 91, 92, - 93, 94, 95, 96, - 97, 98, 99, 100, - - 21, 22, 23, 24, - 25, 26, 27, 28, - 29, 30, 31, 32, - 33, 34, 35, 36, - 37, 38, 39, 40, - }); - data0.linspace(1); - data1.linspace(41); - data2.linspace(161); - sd::ops::dynamic_stitch op; - auto result = op.evaluate({&indices0, &indices1, &indices2, &data0, &data1, &data2}, {}, {}); - - ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); - - ASSERT_TRUE(z->isSameShape(exp)); - ASSERT_TRUE(z->equalsTo(exp)); - - -} + auto indices0 = NDArrayFactory::create('c', {2}, {1, 10}); + auto indices1 = NDArrayFactory::create('c', {2, 3}, {0, 7, 9, 5, 8, 3}); + auto indices2 = NDArrayFactory::create('c', {3, 1}, {6, 4, 2}); -TEST_F(DeclarableOpsTests7, Test_Dynamic_Partition_119) { - auto x = NDArrayFactory::create('c', {5, 4, 11}); - auto y = NDArrayFactory::create('c', {5, 4}, {0,1,2,3, 1,0,2,3, 2,3,1,0, 2,1,0,3, 0,1,2,3}); - auto e = NDArrayFactory::create('c', {5, 11}); - x.assign(1.f); - e.assign(1.f); - sd::ops::dynamic_partition op; - auto result = op.evaluate({&x, &y}, {}, {4}); - ASSERT_EQ(Status::OK(), result.status()); - ASSERT_EQ(4, result.size()); - auto z = result.at(0); + auto data0 = NDArrayFactory::create('c', {2, 5, 4}); + auto data1 = NDArrayFactory::create('c', {2, 3, 5, 4}); + auto data2 = NDArrayFactory::create('c', {3, 1, 5, 4}); - ASSERT_TRUE(e.isSameShape(z)); + auto exp = NDArrayFactory::create( + 'c', {11, 5, 4}, + { + 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, + 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, - -} + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, + 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, -TEST_F(DeclarableOpsTests7, Test_Dynamic_Partition_119_1) { - auto x = NDArrayFactory::create('c', {3, 4, 2}, {10, 20,11, 21,12, 22,13, 23,14, 24,15, 25,16, 26,17, 27,18, 28,19, 29,20, 30,21, 31}); - - auto y = NDArrayFactory::create('c', {3, 4}, {0,0,0,0, 2,2,2,2, 2,1,1,1}); - auto e = NDArrayFactory::create('c', {4, 2}, {10, 20, 11, 21, 12, 22, 13, 23}); - -// x.assign(1.f); -// e.assign(1.f); - sd::ops::dynamic_partition op; - auto result = op.evaluate({&x, &y}, {}, {3}); - ASSERT_EQ(Status::OK(), result.status()); - ASSERT_EQ(3, result.size()); - auto z = result.at(0); -// z->printShapeInfo("Output shape info"); -// result.at(1)->printShapeInfo("Shape2"); -// result.at(2)->printShapeInfo("Shape3"); -// result.at(3)->printShapeInfo("Shape4"); -// z->printIndexedBuffer("Output1"); -// result.at(1)->printIndexedBuffer("Output2"); -// result.at(2)->printIndexedBuffer("Output3"); -// result.at(3)->printIndexedBuffer("Output4"); - ASSERT_TRUE(e.isSameShape(z)); - - -} -TEST_F(DeclarableOpsTests7, Test_Dynamic_Partition_119_2) { - auto x = NDArrayFactory::create('c', {5, 4, 11}); - auto y = NDArrayFactory::create('c', {5, 4}, {0,1,2,3, 1,0,2,3, 2,3,1,0, 2,1,0,3, 0,1,2,3}); - auto e1 = NDArrayFactory::create('c', {5, 11}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, - 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, - 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, - 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, - 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187}); - auto e2 = NDArrayFactory::create('c', {5, 11}, { 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, - 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, - 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, - 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, - 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198}); - auto e3 = NDArrayFactory::create('c', {5, 11}, {23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, - 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, - 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, - 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, - 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209}); - auto e4 = NDArrayFactory::create('c', {5, 11}, { 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, - 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, - 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, - 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, - 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220}) ; - std::vector e({&e1, &e2, &e3, &e4}); - x.linspace(1.f); - //.assign(1.f); - sd::ops::dynamic_partition op; - auto result = op.evaluate({&x, &y}, {}, {4}); - ASSERT_EQ(Status::OK(), result.status()); - ASSERT_EQ(4, result.size()); - for (size_t i = 0; i < result.size(); i++) { - auto z = result.at(i); -// z->printShapeInfo("Output shape info"); -// z->printIndexedBuffer("Output1"); -// result.at(1)->printIndexedBuffer("Output2"); -// result.at(2)->printIndexedBuffer("Output3"); -// result.at(3)->printIndexedBuffer("Output4"); - ASSERT_TRUE(e[i]->isSameShape(z)); - ASSERT_TRUE(e[i]->equalsTo(z)); - } + 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, + 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, - -} + 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, + 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, + 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, -TEST_F(DeclarableOpsTests7, Test_SequenceMask_1) { - auto input = NDArrayFactory::create('c', {4, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); - auto exp = NDArrayFactory::create('c', {4, 4, 16}, {1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, - 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0,1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0,1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0,1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0,1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 }); + 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, + 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, - sd::ops::sequence_mask op; - auto result = op.evaluate({&input}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); + 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, + 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, - auto z = result.at(0); -// z->printIndexedBuffer("Output"); -// z->printShapeInfo("Shape"); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, + 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, - + 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, + 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, -} + 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, + 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, -TEST_F(DeclarableOpsTests7, Test_SequenceMask_2) { - auto input = NDArrayFactory::create('c', {2, 2, 2}, {10, 20, 30, 4, 0, 6, 7, 8}); - auto exp = NDArrayFactory::create('c', {2, 2, 2, 30}, {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}); + 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, + 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, + }); + data0.linspace(1); + data1.linspace(41); + data2.linspace(161); + sd::ops::dynamic_stitch op; + auto result = op.evaluate( + {&indices0, &indices1, &indices2, &data0, &data1, &data2}, {}, {}); - sd::ops::sequence_mask op; - auto result = op.evaluate({&input}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); - auto z = result.at(0); -// z->printBuffer("Output"); -// z->printShapeInfo("Shape"); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(z.isSameShape(exp)); + ASSERT_TRUE(z.equalsTo(exp)); +} - +TEST_F(DeclarableOpsTests7, Test_Dynamic_Partition_119) { + auto x = NDArrayFactory::create('c', {5, 4, 11}); + auto y = NDArrayFactory::create( + 'c', {5, 4}, + {0, 1, 2, 3, 1, 0, 2, 3, 2, 3, 1, 0, 2, 1, 0, 3, 0, 1, 2, 3}); + auto e = NDArrayFactory::create('c', {5, 11}); + x.assign(1.f); + e.assign(1.f); + sd::ops::dynamic_partition op; + auto result = op.evaluate({&x, &y}, {}, {4}); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_EQ(4, result.size()); + auto z = result.at(0); + + ASSERT_TRUE(e.isSameShape(z)); } -TEST_F(DeclarableOpsTests7, Test_SequenceMask_3) { - auto input = NDArrayFactory::create('c', {2, 2, 2}, {10, 20, 30, 4, 0, 6, 7, 8}); - auto exp = NDArrayFactory::create('c', {2, 2, 2, 30}, {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}); +TEST_F(DeclarableOpsTests7, Test_Dynamic_Partition_119_1) { + auto x = NDArrayFactory::create( + 'c', {3, 4, 2}, {10, 20, 11, 21, 12, 22, 13, 23, 14, 24, 15, 25, + 16, 26, 17, 27, 18, 28, 19, 29, 20, 30, 21, 31}); + + auto y = NDArrayFactory::create('c', {3, 4}, + {0, 0, 0, 0, 2, 2, 2, 2, 2, 1, 1, 1}); + auto e = NDArrayFactory::create('c', {4, 2}, + {10, 20, 11, 21, 12, 22, 13, 23}); + + // x.assign(1.f); + // e.assign(1.f); + sd::ops::dynamic_partition op; + auto result = op.evaluate({&x, &y}, {}, {3}); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_EQ(3, result.size()); + auto z = result.at(0); + // z->printShapeInfo("Output shape info"); + // result.at(1)->printShapeInfo("Shape2"); + // result.at(2)->printShapeInfo("Shape3"); + // result.at(3)->printShapeInfo("Shape4"); + // z->printIndexedBuffer("Output1"); + // result.at(1)->printIndexedBuffer("Output2"); + // result.at(2)->printIndexedBuffer("Output3"); + // result.at(3)->printIndexedBuffer("Output4"); + ASSERT_TRUE(e.isSameShape(z)); +} +TEST_F(DeclarableOpsTests7, Test_Dynamic_Partition_119_2) { + auto x = NDArrayFactory::create('c', {5, 4, 11}); + auto y = NDArrayFactory::create( + 'c', {5, 4}, + {0, 1, 2, 3, 1, 0, 2, 3, 2, 3, 1, 0, 2, 1, 0, 3, 0, 1, 2, 3}); + auto e1 = NDArrayFactory::create( + 'c', {5, 11}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 56, 57, 58, + 59, 60, 61, 62, 63, 64, 65, 66, 122, 123, 124, 125, 126, 127, + 128, 129, 130, 131, 132, 155, 156, 157, 158, 159, 160, 161, 162, 163, + 164, 165, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187}); + auto e2 = NDArrayFactory::create( + 'c', {5, 11}, + {12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 45, 46, 47, + 48, 49, 50, 51, 52, 53, 54, 55, 111, 112, 113, 114, 115, 116, + 117, 118, 119, 120, 121, 144, 145, 146, 147, 148, 149, 150, 151, 152, + 153, 154, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198}); + auto e3 = NDArrayFactory::create( + 'c', {5, 11}, + {23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 67, 68, 69, + 70, 71, 72, 73, 74, 75, 76, 77, 89, 90, 91, 92, 93, 94, + 95, 96, 97, 98, 99, 133, 134, 135, 136, 137, 138, 139, 140, 141, + 142, 143, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209}); + auto e4 = NDArrayFactory::create( + 'c', {5, 11}, + {34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 78, 79, 80, + 81, 82, 83, 84, 85, 86, 87, 88, 100, 101, 102, 103, 104, 105, + 106, 107, 108, 109, 110, 166, 167, 168, 169, 170, 171, 172, 173, 174, + 175, 176, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220}); + std::vector e({&e1, &e2, &e3, &e4}); + x.linspace(1.f); + //.assign(1.f); + sd::ops::dynamic_partition op; + auto result = op.evaluate({&x, &y}, {}, {4}); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_EQ(4, result.size()); + for (size_t i = 0; i < result.size(); i++) { + auto z = result.at(i); + // z->printShapeInfo("Output shape info"); + // z->printIndexedBuffer("Output1"); + // result.at(1)->printIndexedBuffer("Output2"); + // result.at(2)->printIndexedBuffer("Output3"); + // result.at(3)->printIndexedBuffer("Output4"); + ASSERT_TRUE(e[i]->isSameShape(z)); + ASSERT_TRUE(e[i]->equalsTo(z)); + } +} - sd::ops::sequence_mask op; - auto result = op.evaluate({&input}, {sd::DataType::INT32}); - ASSERT_EQ(Status::OK(), result.status()); +TEST_F(DeclarableOpsTests7, Test_SequenceMask_1) { + auto input = NDArrayFactory::create( + 'c', {4, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); + auto exp = NDArrayFactory::create( + 'c', {4, 4, 16}, + {1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}); + + sd::ops::sequence_mask op; + auto result = op.evaluate({&input}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + // z->printIndexedBuffer("Output"); + // z->printShapeInfo("Shape"); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); +} - auto z = result.at(0); -// z->printBuffer("Output"); -// z->printShapeInfo("Shape"); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); +TEST_F(DeclarableOpsTests7, Test_SequenceMask_2) { + auto input = + NDArrayFactory::create('c', {2, 2, 2}, {10, 20, 30, 4, 0, 6, 7, 8}); + auto exp = NDArrayFactory::create( + 'c', {2, 2, 2, 30}, + {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, + 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}); + + sd::ops::sequence_mask op; + auto result = op.evaluate({&input}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + // z->printBuffer("Output"); + // z->printShapeInfo("Shape"); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); +} - +TEST_F(DeclarableOpsTests7, Test_SequenceMask_3) { + auto input = + NDArrayFactory::create('c', {2, 2, 2}, {10, 20, 30, 4, 0, 6, 7, 8}); + auto exp = NDArrayFactory::create( + 'c', {2, 2, 2, 30}, + {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, + 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}); + + sd::ops::sequence_mask op; + auto result = op.evaluate({&input}, {sd::DataType::INT32}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + // z->printBuffer("Output"); + // z->printShapeInfo("Shape"); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests7, Test_SequenceMask_4) { - auto input = NDArrayFactory::create({1, 3, 2}); - auto maxLen = NDArrayFactory::create(5); - auto exp = NDArrayFactory::create('c', {3,5}, { - 1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 1.f, 1.f, 0.f, 0.f, 1.f, 1.f, 0.f, 0.f, 0.f - }); - - sd::ops::sequence_mask op; - auto result = op.evaluate({&input, &maxLen}, {sd::DataType::FLOAT32}); - ASSERT_EQ(Status::OK(), result.status()); + auto input = NDArrayFactory::create({1, 3, 2}); + auto maxLen = NDArrayFactory::create(5); + auto exp = + NDArrayFactory::create('c', {3, 5}, + {1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 1.f, 1.f, + 0.f, 0.f, 1.f, 1.f, 0.f, 0.f, 0.f}); - auto z = result.at(0); -// z->printBuffer("Output"); -// z->printShapeInfo("Shape"); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::sequence_mask op; + auto result = op.evaluate({&input, &maxLen}, {sd::DataType::FLOAT32}); + ASSERT_EQ(Status::OK(), result.status()); - + auto z = result.at(0); + // z->printBuffer("Output"); + // z->printShapeInfo("Shape"); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests7, Test_SequenceMask_5) { - auto input = NDArrayFactory::create({1, 3, 2}); - auto exp = NDArrayFactory::create('c', {3,5}, { - 1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 1.f, 1.f, 0.f, 0.f, 1.f, 1.f, 0.f, 0.f, 0.f - }); + auto input = NDArrayFactory::create({1, 3, 2}); + auto exp = + NDArrayFactory::create('c', {3, 5}, + {1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 1.f, 1.f, + 0.f, 0.f, 1.f, 1.f, 0.f, 0.f, 0.f}); - sd::ops::sequence_mask op; - auto result = op.evaluate({&input}, {5, (int)sd::DataType::FLOAT32}); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::sequence_mask op; + auto result = op.evaluate({&input}, {5, (int)sd::DataType::FLOAT32}); + ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); -// z->printBuffer("Output"); -// z->printShapeInfo("Shape"); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - - + auto z = result.at(0); + // z->printBuffer("Output"); + // z->printShapeInfo("Shape"); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentMax_1) { - auto x = NDArrayFactory::create({1.8, 2.5,4., 9., 2.1, 2.4,3.,9., 2.1, 2.1,0.7, 0.1, 3., 4.2, 2.2, 1.}); - auto idx = NDArrayFactory::create({0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); - auto exp = NDArrayFactory::create({2.5, 9, 3, 9, 4.2}); - - sd::ops::segment_max op; + auto x = + NDArrayFactory::create({1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, + 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create( + {0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); + auto exp = NDArrayFactory::create({2.5, 9, 3, 9, 4.2}); - auto result = op.evaluate({&x, &idx}, {}, {}); - ASSERT_EQ(result.status(), Status::OK()); -// result.at(0)->printBuffer("MaX1"); -// exp.printBuffer("ExP1"); - ASSERT_TRUE(exp.equalsTo(result.at(0))); + sd::ops::segment_max op; - + auto result = op.evaluate({&x, &idx}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + // result.at(0)->printBuffer("MaX1"); + // exp.printBuffer("ExP1"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentMax_01) { - auto x = NDArrayFactory::create({1.8, 2.5,4., 9., 2.1, 2.4,3.,9., 2.1, 2.1,0.7, 0.1, 3., 4.2, 2.2, 1., 10, 40, 30}); - auto idx = NDArrayFactory::create({0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4, 5,5, 5}); - auto exp = NDArrayFactory::create({2.5, 9, 3, 9, 4.2, 40}); - - sd::ops::segment_max op; + auto x = NDArrayFactory::create({1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., + 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1., + 10, 40, 30}); + auto idx = NDArrayFactory::create( + {0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4, 5, 5, 5}); + auto exp = NDArrayFactory::create({2.5, 9, 3, 9, 4.2, 40}); - auto result = op.evaluate({&x, &idx}, {}, {}); - ASSERT_EQ(result.status(), Status::OK()); -// result.at(0)->printBuffer("MaX01"); -// exp.printBuffer("ExP01"); - ASSERT_TRUE(exp.equalsTo(result.at(0))); + sd::ops::segment_max op; - + auto result = op.evaluate({&x, &idx}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + // result.at(0)->printBuffer("MaX01"); + // exp.printBuffer("ExP01"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentMaxBP_1) { - auto x = NDArrayFactory::create({1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); - auto idx = NDArrayFactory::create({0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); - auto exp = NDArrayFactory::create({0., 1., 0., 2., 0., 0., 3., 4., 0., 0.,0., 0., 0., 5., 0.,0.}); - auto eps = NDArrayFactory::create('c', {5}); - sd::ops::segment_max_bp op; - eps.linspace(1); - auto result = op.evaluate({&x, &idx, &eps}, {}, {}); - ASSERT_EQ(result.status(), Status::OK()); -// result.at(0)->printIndexedBuffer("OutputMaxBP"); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - - + auto x = + NDArrayFactory::create({1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, + 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create( + {0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); + auto exp = NDArrayFactory::create( + {0., 1., 0., 2., 0., 0., 3., 4., 0., 0., 0., 0., 0., 5., 0., 0.}); + auto eps = NDArrayFactory::create('c', {5}); + sd::ops::segment_max_bp op; + eps.linspace(1); + auto result = op.evaluate({&x, &idx, &eps}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + // result.at(0)->printIndexedBuffer("OutputMaxBP"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentMax_2) { - auto x = NDArrayFactory::create('c', {5, 4}, { 0, 1.8, 2.5, 4., - 1, 9., 2.1, 2.4, - 0, 3., 9., 2.1, - 2, 1, 2.1, 0.7, - 3, 4.2, 2.2, 1. }); - auto idx = NDArrayFactory::create({0, 0, 0, 1, 2}); - auto exp = NDArrayFactory::create('c', {3, 4}, {1, 9, 9, 4, - 2, 1, 2.1, 0.7, - 3, 4.2, 2.2, 1.}); - - //{ 2.1, 2.5, 4., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.} - sd::ops::segment_max op; + auto x = NDArrayFactory::create( + 'c', {5, 4}, {0, 1.8, 2.5, 4., 1, 9., 2.1, 2.4, 0, 3., + 9., 2.1, 2, 1, 2.1, 0.7, 3, 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create({0, 0, 0, 1, 2}); + auto exp = NDArrayFactory::create( + 'c', {3, 4}, {1, 9, 9, 4, 2, 1, 2.1, 0.7, 3, 4.2, 2.2, 1.}); - auto result = op.evaluate({&x, &idx}, {}, {}); - ASSERT_EQ(result.status(), Status::OK()); - ASSERT_EQ(result.size(), 1); - auto out = result.at(0); -// out->printIndexedBuffer("Output2Max"); -// exp.printIndexedBuffer("Expect2Max"); -// exp.printShapeInfo("Exp Shape"); - ASSERT_TRUE(exp.equalsTo(result.at(0))); + //{ 2.1, 2.5, 4., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.} + sd::ops::segment_max op; - + auto result = op.evaluate({&x, &idx}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_EQ(result.size(), 1); + auto out = result.at(0); + // out->printIndexedBuffer("Output2Max"); + // exp.printIndexedBuffer("Expect2Max"); + // exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentMaxBP_2) { - auto x = NDArrayFactory::create('c', {4, 4}, {1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, 2.1, 0.7, 0.1,3., 4.2, 2.2, 1. }); - auto idx = NDArrayFactory::create({0, 0, 1, 2}); - auto eps = NDArrayFactory::create('c', {3, 4}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.}); -// NDArray exp('c', {3, 4}, {2.1, 2.5, 4, 9,2.1, 2.1, 0.7, 0.1,3., 4.2, 2.2, 1.}); - auto exp = NDArrayFactory::create('c', {4, 4}, {0., 2., 3., 4., 1., 0., 0., 4., 5., 6., 7., 8., 9., 10., 11., 12.}); - - //{ 2.1, 2.5, 4., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.} + auto x = + NDArrayFactory::create('c', {4, 4}, + {1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, + 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create({0, 0, 1, 2}); + auto eps = NDArrayFactory::create( + 'c', {3, 4}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.}); + // NDArray exp('c', {3, 4}, {2.1, 2.5, 4, 9,2.1, 2.1, 0.7, + // 0.1,3., 4.2, 2.2, 1.}); + auto exp = NDArrayFactory::create( + 'c', {4, 4}, + {0., 2., 3., 4., 1., 0., 0., 4., 5., 6., 7., 8., 9., 10., 11., 12.}); - sd::ops::segment_max_bp op; + //{ 2.1, 2.5, 4., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.} - auto result = op.evaluate({&x, &idx, &eps}, {}, {}); - ASSERT_EQ(result.status(), Status::OK()); - ASSERT_EQ(result.size(), 2); - //exp.printIndexedBuffer("BP Max Expect"); - //result.at(0)->printIndexedBuffer("BP Max Output"); -// exp.printShapeInfo("Exp Shape"); - ASSERT_TRUE(exp.equalsTo(result.at(0))); + sd::ops::segment_max_bp op; - + auto result = op.evaluate({&x, &idx, &eps}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_EQ(result.size(), 2); + // exp.printIndexedBuffer("BP Max Expect"); + // result.at(0)->printIndexedBuffer("BP Max Output"); + // exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentMax_3) { - auto x = NDArrayFactory::create('c', {4, 4, 4}, { - 91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,51. , 42. , 67. , 24.,15.1, 56.4, 93. , 28., - 109.1, 82.1, 12.7, 113.1,114. , 14.2, 116.2, 11. ,31. , 22. , 87., 44. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. , - 91. , 82. , 37., 64. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. }); + auto x = NDArrayFactory::create( + 'c', {4, 4, 4}, + {91., 82., 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, + 112.7, 13.1, 14., 114.2, 16.2, 117., 51., 42., 67., 24., + 15.1, 56.4, 93., 28., 109.1, 82.1, 12.7, 113.1, 114., 14.2, + 116.2, 11., 31., 22., 87., 44., 55.1, 46.4, 73., 28., + 119.1, 12.1, 112.7, 13.1, 14., 114.2, 16.2, 117., 91., 82., + 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, 112.7, 13.1, + 14., 114.2, 16.2, 117.}); -// ---------------------------------------------------------------- + // ---------------------------------------------------------------- - auto idx = NDArrayFactory::create({0, 1, 1, 2}); - auto exp = NDArrayFactory::create('c', {3, 4, 4}, {91. , 82. , 37. , 64.,55.1, 46.4, 73. , 28.,119.1, 12.1,112.7, 13.1,14. ,114.2, 16.2,117.,51. , 42. , 87. , 44., - 55.1, 56.4, 93. , 28.,119.1, 82.1,112.7,113.1,114. ,114.2,116.2,117.,91. , 82. , 37. , 64.,55.1, 46.4, 73. , 28., 119.1, 12.1,112.7, 13.1,14. ,114.2, 16.2,117. }); + auto idx = NDArrayFactory::create({0, 1, 1, 2}); + auto exp = NDArrayFactory::create( + 'c', {3, 4, 4}, + {91., 82., 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, + 112.7, 13.1, 14., 114.2, 16.2, 117., 51., 42., 87., 44., + 55.1, 56.4, 93., 28., 119.1, 82.1, 112.7, 113.1, 114., 114.2, + 116.2, 117., 91., 82., 37., 64., 55.1, 46.4, 73., 28., + 119.1, 12.1, 112.7, 13.1, 14., 114.2, 16.2, 117.}); - //{ 2.1, 2.5, 4., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.} + //{ 2.1, 2.5, 4., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.} - sd::ops::segment_max op; + sd::ops::segment_max op; - auto result = op.evaluate({&x, &idx}, {}, {}); - ASSERT_EQ(result.status(), Status::OK()); -// result.at(0)->printIndexedBuffer("Output3Max"); -// result.at(0)->printShapeInfo("Out Shape 3 Max"); -// exp.printIndexedBuffer("Expect3Max"); -// exp.printShapeInfo("Exp Shape 3 Max"); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - - + auto result = op.evaluate({&x, &idx}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + // result.at(0)->printIndexedBuffer("Output3Max"); + // result.at(0)->printShapeInfo("Out Shape 3 Max"); + // exp.printIndexedBuffer("Expect3Max"); + // exp.printShapeInfo("Exp Shape 3 Max"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentMax_4) { - auto x = NDArrayFactory::create('c', {4, 4, 4}, {91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,51. , 42. , 67. , 24., - 15.1, 56.4, 93. , 28.,109.1, 82.1, 12.7, 113.1,114. , 14.2, 116.2, 11. ,31. , 22. , 87., 44. ,55.1, 46.4, 73., 28. , - 119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,91. , 82. , 37., 64. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. }); - - auto idx = NDArrayFactory::create({0, 1, 3, 7}); - auto exp = NDArrayFactory::create('c', {8, 4, 4}, { - 91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,51. , 42. , 67. , 24. ,15.1, 56.4, 93. , 28. , - 109.1, 82.1, 12.7, 113.1,114. , 14.2, 116.2, 11. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. , - 31. , 22. , 87. , 44. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. , - 0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. , - 0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. , - 119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. }); - - sd::ops::segment_max op; - - auto result = op.evaluate({&x, &idx}, {}, {}); - ASSERT_EQ(result.status(), Status::OK()); - //result.at(0)->printIndexedBuffer("Output"); - //result.at(0)->printShapeInfo("Out Shape"); - //exp.printIndexedBuffer("Expect"); - //exp.printShapeInfo("Exp Shape"); - ASSERT_TRUE(exp.isSameShape(result.at(0))); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - - + auto x = NDArrayFactory::create( + 'c', {4, 4, 4}, + {91., 82., 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, + 112.7, 13.1, 14., 114.2, 16.2, 117., 51., 42., 67., 24., + 15.1, 56.4, 93., 28., 109.1, 82.1, 12.7, 113.1, 114., 14.2, + 116.2, 11., 31., 22., 87., 44., 55.1, 46.4, 73., 28., + 119.1, 12.1, 112.7, 13.1, 14., 114.2, 16.2, 117., 91., 82., + 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, 112.7, 13.1, + 14., 114.2, 16.2, 117.}); + + auto idx = NDArrayFactory::create({0, 1, 3, 7}); + auto exp = NDArrayFactory::create( + 'c', {8, 4, 4}, + {91., 82., 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, 112.7, + 13.1, 14., 114.2, 16.2, 117., 51., 42., 67., 24., 15.1, 56.4, + 93., 28., 109.1, 82.1, 12.7, 113.1, 114., 14.2, 116.2, 11., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 31., 22., 87., 44., 55.1, 46.4, 73., + 28., 119.1, 12.1, 112.7, 13.1, 14., 114.2, 16.2, 117., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 91., 82., 37., 64., 55.1, 46.4, 73., 28., 119.1, + 12.1, 112.7, 13.1, 14., 114.2, 16.2, 117.}); + + sd::ops::segment_max op; + + auto result = op.evaluate({&x, &idx}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + // result.at(0)->printIndexedBuffer("Output"); + // result.at(0)->printShapeInfo("Out Shape"); + // exp.printIndexedBuffer("Expect"); + // exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.isSameShape(result.at(0))); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentMax_1) { - auto x = NDArrayFactory::create({1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); - auto idx = NDArrayFactory::create({4, 4, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 0, 0}); - auto exp = NDArrayFactory::create({2.2, 9., 3., 9., 4.2}); + auto x = + NDArrayFactory::create({1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, + 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create( + {4, 4, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 0, 0}); + auto exp = NDArrayFactory::create({2.2, 9., 3., 9., 4.2}); - sd::ops::unsorted_segment_max op; + sd::ops::unsorted_segment_max op; - auto result = op.evaluate({&x, &idx}, {}, {5}); - ASSERT_EQ(result.status(), Status::OK()); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - - + auto result = op.evaluate({&x, &idx}, {}, {5}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentMaxBP_1) { - auto x = NDArrayFactory::create({1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); - auto idx = NDArrayFactory::create({0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); - auto exp = NDArrayFactory::create({0., 1., 0., 2., 0., 0., 3., 4., 0., 0.,0., 0., 0., 5., 0.,0.}); - auto eps = NDArrayFactory::create('c', {5}); - sd::ops::segment_max_bp op; - eps.linspace(1); - auto result = op.evaluate({&x, &idx, &eps}, {}, {5}); - ASSERT_EQ(result.status(), Status::OK()); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - - + auto x = + NDArrayFactory::create({1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, + 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create( + {0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); + auto exp = NDArrayFactory::create( + {0., 1., 0., 2., 0., 0., 3., 4., 0., 0., 0., 0., 0., 5., 0., 0.}); + auto eps = NDArrayFactory::create('c', {5}); + sd::ops::segment_max_bp op; + eps.linspace(1); + auto result = op.evaluate({&x, &idx, &eps}, {}, {5}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentMaxBP_2) { - auto x = NDArrayFactory::create({3., 1.8, 2.5, 4., 9., 2.1, 2.4, 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); - auto idx = NDArrayFactory::create({2, 0, 0, 1, 1, 1, 1, 3, 3, 3, 4, 4, 4, 4, 4, 4}); - auto exp = NDArrayFactory::create({3., 0., 1., 0., 2., 0., 0., 4., 0., 0.,0., 0., 0., 5., 0.,0.}); - auto eps = NDArrayFactory::create('c', {5}); - sd::ops::segment_max_bp op; - eps.linspace(1); - auto result = op.evaluate({&x, &idx, &eps}, {}, {5}); - ASSERT_EQ(result.status(), Status::OK()); - //result.at(0)->printIndexedBuffer("Output"); - //exp.printIndexedBuffer("Expect"); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - - + auto x = + NDArrayFactory::create({3., 1.8, 2.5, 4., 9., 2.1, 2.4, 9., 2.1, + 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create( + {2, 0, 0, 1, 1, 1, 1, 3, 3, 3, 4, 4, 4, 4, 4, 4}); + auto exp = NDArrayFactory::create( + {3., 0., 1., 0., 2., 0., 0., 4., 0., 0., 0., 0., 0., 5., 0., 0.}); + auto eps = NDArrayFactory::create('c', {5}); + sd::ops::segment_max_bp op; + eps.linspace(1); + auto result = op.evaluate({&x, &idx, &eps}, {}, {5}); + ASSERT_EQ(result.status(), Status::OK()); + // result.at(0)->printIndexedBuffer("Output"); + // exp.printIndexedBuffer("Expect"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentMax_2) { - auto x = NDArrayFactory::create({1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); - auto idx = NDArrayFactory::create({4, 4, 1, 1, 1, 1, 3, 3, 3, 3, 4, 4, 4, 4, 0, 0}); - auto exp = NDArrayFactory::create({2.2, 9., -DataTypeUtils::max(), 9., 4.2}); + auto x = + NDArrayFactory::create({1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, + 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create( + {4, 4, 1, 1, 1, 1, 3, 3, 3, 3, 4, 4, 4, 4, 0, 0}); + auto exp = NDArrayFactory::create( + {2.2, 9., -DataTypeUtils::max(), 9., 4.2}); - sd::ops::unsorted_segment_max op; + sd::ops::unsorted_segment_max op; - auto result = op.evaluate({&x, &idx}, {}, {5}); - ASSERT_EQ(result.status(), Status::OK()); -// result.at(0)->printIndexedBuffer("OutputUnsortedMax"); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - - + auto result = op.evaluate({&x, &idx}, {}, {5}); + ASSERT_EQ(result.status(), Status::OK()); + // result.at(0)->printIndexedBuffer("OutputUnsortedMax"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentMax_3) { - auto x = NDArrayFactory::create('c', {4, 4}, {1.8, 2.5, 4., 9.,2.1, 2.4, 3., 9.,2.1, 2.1, 0.7, 0.1,3., 4.2, 2.2, 1. }); - auto idx = NDArrayFactory::create({0, 0, 1, 2}); - auto exp = NDArrayFactory::create('c', {3, 4}, {2.1, 2.5, 4, 9,2.1, 2.1, 0.7, 0.1,3., 4.2, 2.2, 1.}); - - //{ 2.1, 2.5, 4., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.} + auto x = + NDArrayFactory::create('c', {4, 4}, + {1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, + 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create({0, 0, 1, 2}); + auto exp = NDArrayFactory::create( + 'c', {3, 4}, {2.1, 2.5, 4, 9, 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); - sd::ops::unsorted_segment_max op; + //{ 2.1, 2.5, 4., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.} - auto result = op.evaluate({&x, &idx}, {}, {3}); - ASSERT_EQ(result.status(), Status::OK()); - ASSERT_EQ(result.size(), 1); - //exp.printIndexedBuffer("Expect"); - //result.at(0)->printIndexedBuffer("Output"); - ASSERT_TRUE(exp.equalsTo(result.at(0))); + sd::ops::unsorted_segment_max op; - + auto result = op.evaluate({&x, &idx}, {}, {3}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_EQ(result.size(), 1); + // exp.printIndexedBuffer("Expect"); + // result.at(0)->printIndexedBuffer("Output"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentMax_4) { - auto x = NDArrayFactory::create('c', {4, 4}, {1.8, 2.5, 4., 9., 2.1, 2.4, 3., 8., 2.1, 2.1, 11.7, 0.1, 3., 4.2, 2.2, 1. }); - auto idx = NDArrayFactory::create({0, 0, 0, 2}); - double principalMax = DataTypeUtils::max(); - auto exp = NDArrayFactory::create('c', {3, 4}, {2.1, 2.5, 11.7, 9, - -principalMax, -principalMax, -principalMax, -principalMax, - 3., 4.2, 2.2, 1.}); - - //{ 2.1, 2.5, 4., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.} + auto x = + NDArrayFactory::create('c', {4, 4}, + {1.8, 2.5, 4., 9., 2.1, 2.4, 3., 8., 2.1, + 2.1, 11.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create({0, 0, 0, 2}); + double principalMax = DataTypeUtils::max(); + auto exp = NDArrayFactory::create( + 'c', {3, 4}, + {2.1, 2.5, 11.7, 9, -principalMax, -principalMax, -principalMax, + -principalMax, 3., 4.2, 2.2, 1.}); - sd::ops::unsorted_segment_max op; + //{ 2.1, 2.5, 4., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.} - auto result = op.evaluate({&x, &idx}, {}, {3}); - ASSERT_EQ(result.status(), Status::OK()); - ASSERT_EQ(result.size(), 1); - //exp.printIndexedBuffer("Expect"); - //result.at(0)->printIndexedBuffer("Output"); - ASSERT_TRUE(exp.equalsTo(result.at(0))); + sd::ops::unsorted_segment_max op; - + auto result = op.evaluate({&x, &idx}, {}, {3}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_EQ(result.size(), 1); + // exp.printIndexedBuffer("Expect"); + // result.at(0)->printIndexedBuffer("Output"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentMin_1) { - auto x = NDArrayFactory::create({1.8, 2.5,4., 9., 2.1, 2.4, 3.,9., 2.1, 2.1,0.7, 0.1, 3., 4.2, 2.2, 1.}); - auto idx = NDArrayFactory::create({0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); - auto exp = NDArrayFactory::create({1.8, 2.1, 3., 2.1, 0.1}); + auto x = + NDArrayFactory::create({1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, + 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create( + {0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); + auto exp = NDArrayFactory::create({1.8, 2.1, 3., 2.1, 0.1}); - sd::ops::segment_min op; + sd::ops::segment_min op; - auto result = op.evaluate({&x, &idx}, {}, {}); - ASSERT_EQ(result.status(), Status::OK()); - auto out = result.at(0); + auto result = op.evaluate({&x, &idx}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + auto out = result.at(0); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - - + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentMin_01) { - auto x = NDArrayFactory::create({1.8, -2.5,4., -9., 2.1, 2.4,-3.,-9., 2.1, 2.1,0.7, 0.1, 3., -4.2, 2.2, 1.}); - auto idx = NDArrayFactory::create({0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); - auto exp = NDArrayFactory::create({-2.5, -9, -3., -9, -4.2}); - - sd::ops::segment_min op; + auto x = + NDArrayFactory::create({1.8, -2.5, 4., -9., 2.1, 2.4, -3., -9., + 2.1, 2.1, 0.7, 0.1, 3., -4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create( + {0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); + auto exp = NDArrayFactory::create({-2.5, -9, -3., -9, -4.2}); - auto result = op.evaluate({&x, &idx}, {}, {}); - ASSERT_EQ(result.status(), Status::OK()); - auto out = result.at(0); + sd::ops::segment_min op; - ASSERT_TRUE(exp.equalsTo(result.at(0))); + auto result = op.evaluate({&x, &idx}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + auto out = result.at(0); - + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentMin_02) { - auto x = NDArrayFactory::create({1.8f, -2.5f, 4.f, -9.f, 2.1f, 2.4f, -3.f, -9.f, 2.1f, 2.1f,0.7f, 0.1f, 3.f, -4.2f, 2.2f, 1.f}); - auto idx = NDArrayFactory::create({0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); - auto exp = NDArrayFactory::create({-2.5f, -9.f, -3.f, -9.f, -4.2f}); - - sd::ops::segment_min op; + auto x = NDArrayFactory::create({1.8f, -2.5f, 4.f, -9.f, 2.1f, 2.4f, + -3.f, -9.f, 2.1f, 2.1f, 0.7f, 0.1f, + 3.f, -4.2f, 2.2f, 1.f}); + auto idx = NDArrayFactory::create( + {0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); + auto exp = NDArrayFactory::create({-2.5f, -9.f, -3.f, -9.f, -4.2f}); - auto result = op.evaluate({&x, &idx}, {}, {}); - ASSERT_EQ(result.status(), Status::OK()); - auto out = result.at(0); + sd::ops::segment_min op; - ASSERT_TRUE(exp.equalsTo(result.at(0))); + auto result = op.evaluate({&x, &idx}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + auto out = result.at(0); - + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentMinBP_1) { - auto x = NDArrayFactory::create({1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); - auto idx = NDArrayFactory::create({0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); - auto exp = NDArrayFactory::create({ 1., 0., 0., 0., 2., 0., 3., 0., 4., 4., 0., 5., 0., 0., 0., 0.}); - auto eps = NDArrayFactory::create('c', {5}); - eps.linspace(1); - sd::ops::segment_min_bp op; + auto x = + NDArrayFactory::create({1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, + 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create( + {0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); + auto exp = NDArrayFactory::create( + {1., 0., 0., 0., 2., 0., 3., 0., 4., 4., 0., 5., 0., 0., 0., 0.}); + auto eps = NDArrayFactory::create('c', {5}); + eps.linspace(1); + sd::ops::segment_min_bp op; - auto result = op.evaluate({&x, &idx, &eps}, {}, {}); - ASSERT_EQ(result.status(), Status::OK()); + auto result = op.evaluate({&x, &idx, &eps}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - - + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentMinBP_1) { - auto x = NDArrayFactory::create({1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); - auto idx = NDArrayFactory::create({0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); - auto exp = NDArrayFactory::create({ 1., 0., 0., 0., 2., 0., 3., 0., 4., 4., 0., 5., 0., 0., 0., 0.}); - auto eps = NDArrayFactory::create('c', {5}); - eps.linspace(1); - sd::ops::unsorted_segment_min_bp op; - - auto result = op.evaluate({&x, &idx, &eps}, {}, {5}); - ASSERT_EQ(result.status(), Status::OK()); - //result.at(0)->printIndexedBuffer("Output1"); - //exp.printIndexedBuffer("Expecte"); + auto x = + NDArrayFactory::create({1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, + 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create( + {0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); + auto exp = NDArrayFactory::create( + {1., 0., 0., 0., 2., 0., 3., 0., 4., 4., 0., 5., 0., 0., 0., 0.}); + auto eps = NDArrayFactory::create('c', {5}); + eps.linspace(1); + sd::ops::unsorted_segment_min_bp op; - ASSERT_TRUE(exp.equalsTo(result.at(0))); + auto result = op.evaluate({&x, &idx, &eps}, {}, {5}); + ASSERT_EQ(result.status(), Status::OK()); + // result.at(0)->printIndexedBuffer("Output1"); + // exp.printIndexedBuffer("Expecte"); - + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentMinBP_2) { - auto x = NDArrayFactory::create({3., 1.8, 2.5, 4., 9., 2.1, 2.4, 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); - auto idx = NDArrayFactory::create({2, 0, 0, 1, 1, 1, 1, 3, 3, 3, 4, 4, 4, 4, 4, 4}); - auto exp = NDArrayFactory::create({3., 1., 0., 0., 0., 2., 0., 0., 4., 4., 0., 5., 0., 0., 0., 0.}); - auto eps = NDArrayFactory::create('c', {5}); - eps.linspace(1); - sd::ops::unsorted_segment_min_bp op; - - auto result = op.evaluate({&x, &idx, &eps}, {}, {5}); - ASSERT_EQ(result.status(), Status::OK()); - //result.at(0)->printIndexedBuffer("Output1"); - //exp.printIndexedBuffer("Expecte"); + auto x = + NDArrayFactory::create({3., 1.8, 2.5, 4., 9., 2.1, 2.4, 9., 2.1, + 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create( + {2, 0, 0, 1, 1, 1, 1, 3, 3, 3, 4, 4, 4, 4, 4, 4}); + auto exp = NDArrayFactory::create( + {3., 1., 0., 0., 0., 2., 0., 0., 4., 4., 0., 5., 0., 0., 0., 0.}); + auto eps = NDArrayFactory::create('c', {5}); + eps.linspace(1); + sd::ops::unsorted_segment_min_bp op; - ASSERT_TRUE(exp.equalsTo(result.at(0))); + auto result = op.evaluate({&x, &idx, &eps}, {}, {5}); + ASSERT_EQ(result.status(), Status::OK()); + // result.at(0)->printIndexedBuffer("Output1"); + // exp.printIndexedBuffer("Expecte"); - + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentMin_2) { - auto x = NDArrayFactory::create('c', {4, 4}, {1.8, 2.5, 4., 9.,2.1, 2.4, 3., 9.,2.1, 2.1, 0.7, 0.1,3., 4.2, 2.2, 1.}); - auto idx = NDArrayFactory::create({0, 0, 1, 2}); - auto exp = NDArrayFactory::create('c', {3, 4}, {1.8, 2.4, 3. , 9.,2.1, 2.1, 0.7, 0.1,3. , 4.2, 2.2, 1.}); + auto x = + NDArrayFactory::create('c', {4, 4}, + {1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, + 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create({0, 0, 1, 2}); + auto exp = NDArrayFactory::create( + 'c', {3, 4}, {1.8, 2.4, 3., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); - //{ 2.1, 2.5, 4., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.} + //{ 2.1, 2.5, 4., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.} - sd::ops::segment_min op; + sd::ops::segment_min op; - auto result = op.evaluate({&x, &idx}, {}, {}); - ASSERT_EQ(result.status(), Status::OK()); - ASSERT_EQ(result.size(), 1); -// exp.printIndexedBuffer("Expect"); -// exp.printShapeInfo("Exp Shape"); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - - + auto result = op.evaluate({&x, &idx}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_EQ(result.size(), 1); + // exp.printIndexedBuffer("Expect"); + // exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentMinBP_2) { - auto x = NDArrayFactory::create('c', {4, 4}, {1.8, 2.5, 4., 9.,2.1, 2.4, 3., 9.,2.1, 2.1, 0.7, 0.1,3., 4.2, 2.2, 1.}); - auto idx = NDArrayFactory::create({0, 0, 1, 2}); - auto eps = NDArrayFactory::create('c', {3, 4}, {1., 2., 3. , 4., 5., 6., 7., 8., 9., 10., 11., 12.}); - auto exp = NDArrayFactory::create('c', {4, 4}, {1., 0., 0., 4., 0., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.}); - - //{ 2.1, 2.5, 4., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.} + auto x = + NDArrayFactory::create('c', {4, 4}, + {1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, + 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create({0, 0, 1, 2}); + auto eps = NDArrayFactory::create( + 'c', {3, 4}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.}); + auto exp = NDArrayFactory::create( + 'c', {4, 4}, + {1., 0., 0., 4., 0., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.}); - sd::ops::segment_min_bp op; + //{ 2.1, 2.5, 4., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.} - auto result = op.evaluate({&x, &idx, &eps}, {}, {}); - ASSERT_EQ(result.status(), Status::OK()); - ASSERT_EQ(result.size(), 2); -// exp.printIndexedBuffer("Expect"); -// result.at(0)->printIndexedBuffer("Output"); - ASSERT_TRUE(exp.equalsTo(result.at(0))); + sd::ops::segment_min_bp op; - + auto result = op.evaluate({&x, &idx, &eps}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_EQ(result.size(), 2); + // exp.printIndexedBuffer("Expect"); + // result.at(0)->printIndexedBuffer("Output"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentMin_3) { - auto x = NDArrayFactory::create('c', {4, 4, 4}, { - 91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,51. , 42. , 67. , 24.,15.1, 56.4, 93. , 28.,109.1, 82.1, 12.7, 113.1, - 114. , 14.2, 116.2, 11. ,31. , 22. , 87., 44. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,91. , 82. , 37., 64. ,55.1, 46.4, 73., 28. , - 119.1, 12.1, 112.7, 13.1, 14. , 114.2, 16.2, 117. }); - -// ---------------------------------------------------------------- - - auto idx = NDArrayFactory::create({0, 1, 1, 2}); - auto exp = NDArrayFactory::create('c', {3, 4, 4}, {91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,31. , 22. , 67. , 24. , - 15.1, 46.4, 73. , 28. ,109.1, 12.1, 12.7, 13.1,14. , 14.2, 16.2, 11. ,91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. }); - - sd::ops::segment_min op; - - auto result = op.evaluate({&x, &idx}, {}, {}); - ASSERT_EQ(result.status(), Status::OK()); -// result.at(0)->printIndexedBuffer("Output"); -// result.at(0)->printShapeInfo("Out Shape"); -// exp.printIndexedBuffer("Expect"); -// exp.printShapeInfo("Exp Shape"); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - - + auto x = NDArrayFactory::create( + 'c', {4, 4, 4}, + {91., 82., 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, + 112.7, 13.1, 14., 114.2, 16.2, 117., 51., 42., 67., 24., + 15.1, 56.4, 93., 28., 109.1, 82.1, 12.7, 113.1, 114., 14.2, + 116.2, 11., 31., 22., 87., 44., 55.1, 46.4, 73., 28., + 119.1, 12.1, 112.7, 13.1, 14., 114.2, 16.2, 117., 91., 82., + 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, 112.7, 13.1, + 14., 114.2, 16.2, 117.}); + + // ---------------------------------------------------------------- + + auto idx = NDArrayFactory::create({0, 1, 1, 2}); + auto exp = NDArrayFactory::create( + 'c', {3, 4, 4}, + {91., 82., 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, + 112.7, 13.1, 14., 114.2, 16.2, 117., 31., 22., 67., 24., + 15.1, 46.4, 73., 28., 109.1, 12.1, 12.7, 13.1, 14., 14.2, + 16.2, 11., 91., 82., 37., 64., 55.1, 46.4, 73., 28., + 119.1, 12.1, 112.7, 13.1, 14., 114.2, 16.2, 117.}); + + sd::ops::segment_min op; + + auto result = op.evaluate({&x, &idx}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + // result.at(0)->printIndexedBuffer("Output"); + // result.at(0)->printShapeInfo("Out Shape"); + // exp.printIndexedBuffer("Expect"); + // exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentMin_4) { - auto x = NDArrayFactory::create('c', {4, 4, 4}, { - 91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,51. , 42. , 67. , 24.,15.1, 56.4, 93. , 28., - 109.1, 82.1, 12.7, 113.1,114. , 14.2, 116.2, 11. ,31. , 22. , 87., 44. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. , - 91. , 82. , 37., 64. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. }); - -// ---------------------------------------------------------------- - - auto idx = NDArrayFactory::create({0, 1, 3, 7}); - auto exp = NDArrayFactory::create('c', {8, 4, 4}, { - 91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,51. , 42. , 67. , 24. ,15.1, 56.4, 93. , 28. , - 109.1, 82.1, 12.7, 113.1,114. , 14.2, 116.2, 11. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. , - 31. , 22. , 87. , 44. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. , - 0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. , - 0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. }); - - sd::ops::segment_min op; - - auto result = op.evaluate({&x, &idx}, {}, {}); - ASSERT_EQ(result.status(), Status::OK()); - //result.at(0)->printIndexedBuffer("Output"); - //result.at(0)->printShapeInfo("Out Shape"); - //exp.printIndexedBuffer("Expect"); - //exp.printShapeInfo("Exp Shape"); - ASSERT_TRUE(exp.isSameShape(result.at(0))); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - - + auto x = NDArrayFactory::create( + 'c', {4, 4, 4}, + {91., 82., 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, + 112.7, 13.1, 14., 114.2, 16.2, 117., 51., 42., 67., 24., + 15.1, 56.4, 93., 28., 109.1, 82.1, 12.7, 113.1, 114., 14.2, + 116.2, 11., 31., 22., 87., 44., 55.1, 46.4, 73., 28., + 119.1, 12.1, 112.7, 13.1, 14., 114.2, 16.2, 117., 91., 82., + 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, 112.7, 13.1, + 14., 114.2, 16.2, 117.}); + + // ---------------------------------------------------------------- + + auto idx = NDArrayFactory::create({0, 1, 3, 7}); + auto exp = NDArrayFactory::create( + 'c', {8, 4, 4}, + {91., 82., 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, 112.7, + 13.1, 14., 114.2, 16.2, 117., 51., 42., 67., 24., 15.1, 56.4, + 93., 28., 109.1, 82.1, 12.7, 113.1, 114., 14.2, 116.2, 11., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 31., 22., 87., 44., 55.1, 46.4, 73., + 28., 119.1, 12.1, 112.7, 13.1, 14., 114.2, 16.2, 117., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 91., 82., 37., 64., 55.1, 46.4, 73., 28., 119.1, + 12.1, 112.7, 13.1, 14., 114.2, 16.2, 117.}); + + sd::ops::segment_min op; + + auto result = op.evaluate({&x, &idx}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + // result.at(0)->printIndexedBuffer("Output"); + // result.at(0)->printShapeInfo("Out Shape"); + // exp.printIndexedBuffer("Expect"); + // exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.isSameShape(result.at(0))); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentMin_1) { - auto x = NDArrayFactory::create({1.8, 2.5,4., 9., 2.1, 2.4,3.,9., 2.1, 2.1,0.7, 0.1, 3., 4.2, 2.2, 1.}); - auto idx = NDArrayFactory::create({0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); - auto exp = NDArrayFactory::create({1.8, 2.1, 3., 2.1, 0.1}); - - sd::ops::unsorted_segment_min op; + auto x = + NDArrayFactory::create({1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, + 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create( + {0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); + auto exp = NDArrayFactory::create({1.8, 2.1, 3., 2.1, 0.1}); - auto result = op.evaluate({&x, &idx}, {}, {5}); - ASSERT_EQ(result.status(), Status::OK()); - ASSERT_TRUE(exp.equalsTo(result.at(0))); + sd::ops::unsorted_segment_min op; - + auto result = op.evaluate({&x, &idx}, {}, {5}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentMin_01) { - auto x = NDArrayFactory::create({1.8, 2.5,4., 9., 2.1, 2.4,3.,9., 2.1, 2.1,0.7, 0.1, 3., 4.2, 2.2, 1.}); - auto idx = NDArrayFactory::create({0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); - auto exp = NDArrayFactory::create({1.8, 2.1, 3., 2.1, 0.1}); - - sd::ops::unsorted_segment_min op; + auto x = + NDArrayFactory::create({1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, + 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create( + {0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); + auto exp = NDArrayFactory::create({1.8, 2.1, 3., 2.1, 0.1}); - auto result = op.evaluate({&x, &idx}, {}, {5}); - ASSERT_EQ(result.status(), Status::OK()); - ASSERT_TRUE(exp.equalsTo(result.at(0))); + sd::ops::unsorted_segment_min op; - + auto result = op.evaluate({&x, &idx}, {}, {5}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentMin_2) { - auto x = NDArrayFactory::create('c', {4, 4}, {1.8, 2.5, 4., 9.,2.1, 2.4, 3., 9.,2.1, 2.1, 0.7, 0.1,3., 4.2, 2.2, 1.}); - auto idx = NDArrayFactory::create({0, 0, 1, 2}); - auto exp = NDArrayFactory::create('c', {3, 4}, {1.8, 2.4, 3. , 9.,2.1, 2.1, 0.7, 0.1,3. , 4.2, 2.2, 1.}); + auto x = + NDArrayFactory::create('c', {4, 4}, + {1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, + 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create({0, 0, 1, 2}); + auto exp = NDArrayFactory::create( + 'c', {3, 4}, {1.8, 2.4, 3., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); - //{ 2.1, 2.5, 4., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.} + //{ 2.1, 2.5, 4., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.} - sd::ops::unsorted_segment_min op; + sd::ops::unsorted_segment_min op; - auto result = op.evaluate({&x, &idx}, {}, {3}); - ASSERT_EQ(result.status(), Status::OK()); - ASSERT_EQ(result.size(), 1); -// exp.printIndexedBuffer("Expect"); -// exp.printShapeInfo("Exp Shape"); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - - + auto result = op.evaluate({&x, &idx}, {}, {3}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_EQ(result.size(), 1); + // exp.printIndexedBuffer("Expect"); + // exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentMin_3) { - auto x = NDArrayFactory::create('c', {4, 4, 4}, { 91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,51. , 42. , 67. , 24.,15.1, 56.4, 93. , 28.,109.1, 82.1, 12.7, 113.1, - 114. , 14.2, 116.2, 11. ,31. , 22. , 87., 44. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,91. , 82. , 37., 64. ,55.1, 46.4, 73., 28. , - 119.1, 12.1, 112.7, 13.1, 14. , 114.2, 16.2, 117. }); - -// ---------------------------------------------------------------- - - auto idx = NDArrayFactory::create({0, 1, 1, 2}); - auto exp = NDArrayFactory::create('c', {3, 4, 4}, {91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,31. , 22. , 67. , 24. , - 15.1, 46.4, 73. , 28. ,109.1, 12.1, 12.7, 13.1,14. , 14.2, 16.2, 11. ,91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. }); - - sd::ops::unsorted_segment_min op; - - auto result = op.evaluate({&x, &idx}, {}, {3}); - ASSERT_EQ(result.status(), Status::OK()); -// result.at(0)->printIndexedBuffer("Output"); -// result.at(0)->printShapeInfo("Out Shape"); -// exp.printIndexedBuffer("Expect"); -// exp.printShapeInfo("Exp Shape"); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - - + auto x = NDArrayFactory::create( + 'c', {4, 4, 4}, + {91., 82., 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, + 112.7, 13.1, 14., 114.2, 16.2, 117., 51., 42., 67., 24., + 15.1, 56.4, 93., 28., 109.1, 82.1, 12.7, 113.1, 114., 14.2, + 116.2, 11., 31., 22., 87., 44., 55.1, 46.4, 73., 28., + 119.1, 12.1, 112.7, 13.1, 14., 114.2, 16.2, 117., 91., 82., + 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, 112.7, 13.1, + 14., 114.2, 16.2, 117.}); + + // ---------------------------------------------------------------- + + auto idx = NDArrayFactory::create({0, 1, 1, 2}); + auto exp = NDArrayFactory::create( + 'c', {3, 4, 4}, + {91., 82., 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, + 112.7, 13.1, 14., 114.2, 16.2, 117., 31., 22., 67., 24., + 15.1, 46.4, 73., 28., 109.1, 12.1, 12.7, 13.1, 14., 14.2, + 16.2, 11., 91., 82., 37., 64., 55.1, 46.4, 73., 28., + 119.1, 12.1, 112.7, 13.1, 14., 114.2, 16.2, 117.}); + + sd::ops::unsorted_segment_min op; + + auto result = op.evaluate({&x, &idx}, {}, {3}); + ASSERT_EQ(result.status(), Status::OK()); + // result.at(0)->printIndexedBuffer("Output"); + // result.at(0)->printShapeInfo("Out Shape"); + // exp.printIndexedBuffer("Expect"); + // exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentMin_4) { - auto x = NDArrayFactory::create('c', {4, 4, 4}, {91., 82., 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, 112.7, 13.1, 14., 114.2, 16.2, 117., - 51., 42., 67., 24., 15.1, 56.4, 93., 28., 109.1, 82.1, 12.7, 113.1, 114., 14.2, 116.2, 11., - 31., 22., 87., 44., 55.1, 46.4, 73., 28., 119.1, 12.1, 112.7, 13.1, 14., 114.2, 16.2, 117., - 91., 82., 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, 112.7, 13.1, 14., 114.2, 16.2, 117. }); - -// ---------------------------------------------------------------- - - auto idx = NDArrayFactory::create({0, 1, 3, 7}); - double principalMax = DataTypeUtils::max(); - - auto exp = NDArrayFactory::create('c', {8, 4, 4}, { - 91., 82., 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, 112.7, 13.1, 14., 114.2, 16.2, 117., 51., - 42., 67., 24., 15.1, 56.4, 93., 28., 109.1, 82.1, 12.7, 113.1, 114., 14.2, 116.2, 11., - principalMax, principalMax, principalMax, principalMax, principalMax, principalMax, principalMax, - principalMax, principalMax, principalMax, principalMax, principalMax, principalMax, principalMax, - principalMax, principalMax, - 31., 22., 87., 44., 55.1, 46.4, 73., 28., 119.1, 12.1, 112.7, 13.1, 14., 114.2, 16.2, 117., - principalMax, principalMax, principalMax, principalMax, principalMax, principalMax, principalMax, - principalMax, principalMax, principalMax, principalMax, principalMax, principalMax, principalMax, - principalMax, principalMax, principalMax, principalMax, principalMax, principalMax, principalMax, - principalMax, principalMax, principalMax, principalMax, principalMax, principalMax, principalMax, - principalMax, principalMax, principalMax, principalMax, principalMax, principalMax, principalMax, - principalMax, principalMax, principalMax, principalMax, principalMax, principalMax, principalMax, - principalMax, principalMax, principalMax, principalMax, principalMax, principalMax, - 91., 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. }); - - sd::ops::unsorted_segment_min op; - - auto result = op.evaluate({&x, &idx}, {}, {8}); - ASSERT_EQ(result.status(), Status::OK()); - //result.at(0)->printIndexedBuffer("Output"); - //result.at(0)->printShapeInfo("Out Shape"); - // exp.printIndexedBuffer("Expect"); - //exp.printShapeInfo("Exp Shape"); - ASSERT_TRUE(exp.isSameShape(result.at(0))); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - - + auto x = NDArrayFactory::create( + 'c', {4, 4, 4}, + {91., 82., 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, + 112.7, 13.1, 14., 114.2, 16.2, 117., 51., 42., 67., 24., + 15.1, 56.4, 93., 28., 109.1, 82.1, 12.7, 113.1, 114., 14.2, + 116.2, 11., 31., 22., 87., 44., 55.1, 46.4, 73., 28., + 119.1, 12.1, 112.7, 13.1, 14., 114.2, 16.2, 117., 91., 82., + 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, 112.7, 13.1, + 14., 114.2, 16.2, 117.}); + + // ---------------------------------------------------------------- + + auto idx = NDArrayFactory::create({0, 1, 3, 7}); + double principalMax = DataTypeUtils::max(); + + auto exp = NDArrayFactory::create( + 'c', {8, 4, 4}, + {91., 82., 37., 64., 55.1, + 46.4, 73., 28., 119.1, 12.1, + 112.7, 13.1, 14., 114.2, 16.2, + 117., 51., 42., 67., 24., + 15.1, 56.4, 93., 28., 109.1, + 82.1, 12.7, 113.1, 114., 14.2, + 116.2, 11., principalMax, principalMax, principalMax, + principalMax, principalMax, principalMax, principalMax, principalMax, + principalMax, principalMax, principalMax, principalMax, principalMax, + principalMax, principalMax, principalMax, 31., 22., + 87., 44., 55.1, 46.4, 73., + 28., 119.1, 12.1, 112.7, 13.1, + 14., 114.2, 16.2, 117., principalMax, + principalMax, principalMax, principalMax, principalMax, principalMax, + principalMax, principalMax, principalMax, principalMax, principalMax, + principalMax, principalMax, principalMax, principalMax, principalMax, + principalMax, principalMax, principalMax, principalMax, principalMax, + principalMax, principalMax, principalMax, principalMax, principalMax, + principalMax, principalMax, principalMax, principalMax, principalMax, + principalMax, principalMax, principalMax, principalMax, principalMax, + principalMax, principalMax, principalMax, principalMax, principalMax, + principalMax, principalMax, principalMax, principalMax, principalMax, + principalMax, principalMax, 91., 82., 37., + 64., 55.1, 46.4, 73., 28., + 119.1, 12.1, 112.7, 13.1, 14., + 114.2, 16.2, 117.}); + + sd::ops::unsorted_segment_min op; + + auto result = op.evaluate({&x, &idx}, {}, {8}); + ASSERT_EQ(result.status(), Status::OK()); + // result.at(0)->printIndexedBuffer("Output"); + // result.at(0)->printShapeInfo("Out Shape"); + // exp.printIndexedBuffer("Expect"); + // exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.isSameShape(result.at(0))); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentMean_1) { - auto x = NDArrayFactory::create({1.8, 2.5,4., 9., 2.1, 2.4,3.,9., 2.1, 2.1,0.7, 0.1, 3., 4.2, 2.2, 1.}); - auto idx = NDArrayFactory::create({0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); - auto exp = NDArrayFactory::create({2.15, 4.375, 3., 4.4, 1.8666667}); + auto x = + NDArrayFactory::create({1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, + 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create( + {0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); + auto exp = NDArrayFactory::create({2.15, 4.375, 3., 4.4, 1.8666667}); - sd::ops::segment_mean op; + sd::ops::segment_mean op; - auto result = op.evaluate({&x, &idx}, {}, {}); - ASSERT_EQ(result.status(), Status::OK()); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - - + auto result = op.evaluate({&x, &idx}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } TEST_F(DeclarableOpsTests7, TestSegmentMean_2) { - auto x = NDArrayFactory::create('c', {4, 4}, {1.8, 2.5, 4., 9.,2.1, 2.4, 3., 9.,2.1, 2.1, 0.7, 0.1,3., 4.2, 2.2, 1.}); - auto idx = NDArrayFactory::create({0, 0, 1, 2}); - auto exp = NDArrayFactory::create('c', {3, 4}, { 1.95, 2.45, 3.5, 9., 2.1, 2.1, 0.7, 0.1, 3. , 4.2, 2.2, 1.}); - - sd::ops::segment_mean op; + auto x = + NDArrayFactory::create('c', {4, 4}, + {1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, + 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create({0, 0, 1, 2}); + auto exp = NDArrayFactory::create( + 'c', {3, 4}, {1.95, 2.45, 3.5, 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); - auto result = op.evaluate({&x, &idx}, {}, {}); - ASSERT_EQ(result.status(), Status::OK()); - ASSERT_EQ(result.size(), 1); -// exp.printIndexedBuffer("Expect"); -// result.at(0)->printIndexedBuffer("Output"); -// exp.printShapeInfo("Exp Shape"); - ASSERT_TRUE(exp.equalsTo(result.at(0))); + sd::ops::segment_mean op; - + auto result = op.evaluate({&x, &idx}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_EQ(result.size(), 1); + // exp.printIndexedBuffer("Expect"); + // result.at(0)->printIndexedBuffer("Output"); + // exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } TEST_F(DeclarableOpsTests7, TestSegmentMean_02) { - auto x = NDArrayFactory::create('c', {6, 3}, {1, 2, 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18.}); - auto idx = NDArrayFactory::create({0, 0, 1, 1, 2,2}); - auto exp = NDArrayFactory::create('c', {3, 3}, { 2.5, 3.5, 4.5, 8.5, 9.5, 10.5, 14.5, 15.5, 16.5}); - - sd::ops::segment_mean op; + auto x = + NDArrayFactory::create('c', {6, 3}, + {1, 2, 3., 4., 5., 6., 7., 8., 9., 10., + 11., 12., 13., 14., 15., 16., 17., 18.}); + auto idx = NDArrayFactory::create({0, 0, 1, 1, 2, 2}); + auto exp = NDArrayFactory::create( + 'c', {3, 3}, {2.5, 3.5, 4.5, 8.5, 9.5, 10.5, 14.5, 15.5, 16.5}); - auto result = op.evaluate({&x, &idx}, {}, {}); - ASSERT_EQ(result.status(), Status::OK()); - ASSERT_EQ(result.size(), 1); - ASSERT_TRUE(exp.equalsTo(result.at(0))); + sd::ops::segment_mean op; - + auto result = op.evaluate({&x, &idx}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_EQ(result.size(), 1); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } TEST_F(DeclarableOpsTests7, TestSegmentMean_021) { - auto x = NDArrayFactory::create('c', {6, 3});//, {1, 2, 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18.}); - auto idx = NDArrayFactory::create({0, 0, 1, 1, 2,2}); - auto exp = NDArrayFactory::create('c', {3, 3}, { 2.5f, 3.5f, 4.5f, 8.5f, 9.5f, 10.5f, 14.5f, 15.5f, 16.5f}); - - sd::ops::segment_mean op; - x.linspace(1.); - auto result = op.evaluate({&x, &idx}, {}, {}); - ASSERT_EQ(result.status(), Status::OK()); - ASSERT_EQ(result.size(), 1); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - - + auto x = NDArrayFactory::create( + 'c', {6, 3}); //, {1, + // 2, 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., + // 15., 16., 17., 18.}); + auto idx = NDArrayFactory::create({0, 0, 1, 1, 2, 2}); + auto exp = NDArrayFactory::create( + 'c', {3, 3}, {2.5f, 3.5f, 4.5f, 8.5f, 9.5f, 10.5f, 14.5f, 15.5f, 16.5f}); + + sd::ops::segment_mean op; + x.linspace(1.); + auto result = op.evaluate({&x, &idx}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_EQ(result.size(), 1); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } TEST_F(DeclarableOpsTests7, TestSegmentMean_022) { - auto x = NDArrayFactory::create('c', {6, 3});//, {1, 2, 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18.}); - auto idx = NDArrayFactory::create({0, 0, 1, 1, 2,2}); - auto z = NDArrayFactory::create('c', {3, 3}); //, { 2.5, 3.5, 4.5, 8.5, 9.5, 10.5, 14.5, 15.5, 16.5}); - auto exp = NDArrayFactory::create('c', {3, 3}, { 2.5f, 3.5f, 4.5f, 8.5f, 9.5f, 10.5f, 14.5f, 15.5f, 16.5f}); + auto x = NDArrayFactory::create( + 'c', {6, 3}); //, {1, + // 2, 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., + // 15., 16., 17., 18.}); + auto idx = NDArrayFactory::create({0, 0, 1, 1, 2, 2}); + auto z = NDArrayFactory::create( + 'c', + {3, + 3}); //, { 2.5, 3.5, 4.5, 8.5, 9.5, 10.5, 14.5, 15.5, 16.5}); + auto exp = NDArrayFactory::create( + 'c', {3, 3}, {2.5f, 3.5f, 4.5f, 8.5f, 9.5f, 10.5f, 14.5f, 15.5f, 16.5f}); - sd::ops::segment_mean op; - x.linspace(1.); - auto result = op.execute({&x, &idx}, {&z}); - ASSERT_EQ(result, Status::OK()); + sd::ops::segment_mean op; + x.linspace(1.); + auto result = op.execute({&x, &idx}, {&z}); + ASSERT_EQ(result, Status::OK()); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(exp.equalsTo(z)); -// + // } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentMeanBP_2) { - auto x = NDArrayFactory::create('c', {4, 4}, {1.8, 2.5, 4., 9.,2.1, 2.4, 3., 9.,2.1, 2.1, 0.7, 0.1,3., 4.2, 2.2, 1.}); - auto idx = NDArrayFactory::create({0, 0, 1, 2}); - auto eps = NDArrayFactory::create('c', {3, 4}); - auto exp = NDArrayFactory::create('c', {4, 4}, { 0.5, 1., 1.5, 2., 0.5, 1., 1.5, 2., 5., 6., 7., 8., 9., 10., 11., 12.}); - eps.linspace(1); - - sd::ops::segment_mean_bp op; + auto x = + NDArrayFactory::create('c', {4, 4}, + {1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, + 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create({0, 0, 1, 2}); + auto eps = NDArrayFactory::create('c', {3, 4}); + auto exp = NDArrayFactory::create( + 'c', {4, 4}, + {0.5, 1., 1.5, 2., 0.5, 1., 1.5, 2., 5., 6., 7., 8., 9., 10., 11., 12.}); + eps.linspace(1); - auto result = op.evaluate({&x, &idx, &eps}, {}, {}); - ASSERT_EQ(result.status(), Status::OK()); - ASSERT_EQ(result.size(), 2); - ASSERT_TRUE(exp.equalsTo(result.at(0))); + sd::ops::segment_mean_bp op; - + auto result = op.evaluate({&x, &idx, &eps}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_EQ(result.size(), 2); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentMean_3) { - auto x = NDArrayFactory::create('c', {4, 4, 4}, { - 91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,51. , 42. , 67. , 24.,15.1, 56.4, 93. , 28., - 109.1, 82.1, 12.7, 113.1,114. , 14.2, 116.2, 11. ,31. , 22. , 87., 44. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. , - 91. , 82. , 37., 64. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. }); - -// ---------------------------------------------------------------- - - auto idx = NDArrayFactory::create({0, 1, 1, 2}); - auto exp = NDArrayFactory::create('c', {3, 4, 4}, { - 91. , 82. , 37. , 64. ,55.1 , 46.4 , 73. , 28. ,119.1 , 12.1 , 112.7 , 13.1,14. , 114.2 , 16.2 , 117. , - 41. , 32. , 77. , 34. ,35.1 , 51.4 , 83. , 28. ,114.1 , 47.1 , 62.7, 63.1,64. , 64.2 , 66.2 , 64. , - 91. , 82. , 37. , 64. ,55.1 , 46.4 , 73. , 28. ,119.1 , 12.1 , 112.7 , 13.1,14. , 114.2 , 16.2 , 117. }); - - sd::ops::segment_mean op; - - auto result = op.evaluate({&x, &idx}, {}, {}); - ASSERT_EQ(result.status(), Status::OK()); -// result.at(0)->printIndexedBuffer("Output"); -// result.at(0)->printShapeInfo("Out Shape"); -// exp.printIndexedBuffer("Expect"); -// exp.printShapeInfo("Exp Shape"); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - - + auto x = NDArrayFactory::create( + 'c', {4, 4, 4}, + {91., 82., 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, + 112.7, 13.1, 14., 114.2, 16.2, 117., 51., 42., 67., 24., + 15.1, 56.4, 93., 28., 109.1, 82.1, 12.7, 113.1, 114., 14.2, + 116.2, 11., 31., 22., 87., 44., 55.1, 46.4, 73., 28., + 119.1, 12.1, 112.7, 13.1, 14., 114.2, 16.2, 117., 91., 82., + 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, 112.7, 13.1, + 14., 114.2, 16.2, 117.}); + + // ---------------------------------------------------------------- + + auto idx = NDArrayFactory::create({0, 1, 1, 2}); + auto exp = NDArrayFactory::create( + 'c', {3, 4, 4}, + {91., 82., 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, + 112.7, 13.1, 14., 114.2, 16.2, 117., 41., 32., 77., 34., + 35.1, 51.4, 83., 28., 114.1, 47.1, 62.7, 63.1, 64., 64.2, + 66.2, 64., 91., 82., 37., 64., 55.1, 46.4, 73., 28., + 119.1, 12.1, 112.7, 13.1, 14., 114.2, 16.2, 117.}); + + sd::ops::segment_mean op; + + auto result = op.evaluate({&x, &idx}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + // result.at(0)->printIndexedBuffer("Output"); + // result.at(0)->printShapeInfo("Out Shape"); + // exp.printIndexedBuffer("Expect"); + // exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentMean_4) { - auto x = NDArrayFactory::create('c', {4, 4, 4}, { - 91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,51. , 42. , 67. , 24.,15.1, 56.4, 93. , 28., - 109.1, 82.1, 12.7, 113.1,114. , 14.2, 116.2, 11. ,31. , 22. , 87., 44. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. , - 91. , 82. , 37., 64. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. }); - -// ---------------------------------------------------------------- - - auto idx = NDArrayFactory::create({0, 1, 3, 7}); - auto exp = NDArrayFactory::create('c', {8, 4, 4}, { - 91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,51. , 42. , 67. , 24. ,15.1, 56.4, 93. , 28. , - 109.1, 82.1, 12.7, 113.1,114. , 14.2, 116.2, 11. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. , - 31. , 22. , 87. , 44. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. , - 0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. , - 0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. , - 119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. }); - - sd::ops::segment_mean op; - - auto result = op.evaluate({&x, &idx}, {}, {}); - ASSERT_EQ(result.status(), Status::OK()); - //result.at(0)->printIndexedBuffer("Output"); - //result.at(0)->printShapeInfo("Out Shape"); - //exp.printIndexedBuffer("Expect"); - //exp.printShapeInfo("Exp Shape"); - ASSERT_TRUE(exp.isSameShape(result.at(0))); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - - + auto x = NDArrayFactory::create( + 'c', {4, 4, 4}, + {91., 82., 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, + 112.7, 13.1, 14., 114.2, 16.2, 117., 51., 42., 67., 24., + 15.1, 56.4, 93., 28., 109.1, 82.1, 12.7, 113.1, 114., 14.2, + 116.2, 11., 31., 22., 87., 44., 55.1, 46.4, 73., 28., + 119.1, 12.1, 112.7, 13.1, 14., 114.2, 16.2, 117., 91., 82., + 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, 112.7, 13.1, + 14., 114.2, 16.2, 117.}); + + // ---------------------------------------------------------------- + + auto idx = NDArrayFactory::create({0, 1, 3, 7}); + auto exp = NDArrayFactory::create( + 'c', {8, 4, 4}, + {91., 82., 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, 112.7, + 13.1, 14., 114.2, 16.2, 117., 51., 42., 67., 24., 15.1, 56.4, + 93., 28., 109.1, 82.1, 12.7, 113.1, 114., 14.2, 116.2, 11., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 31., 22., 87., 44., 55.1, 46.4, 73., + 28., 119.1, 12.1, 112.7, 13.1, 14., 114.2, 16.2, 117., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 91., 82., 37., 64., 55.1, 46.4, 73., 28., 119.1, + 12.1, 112.7, 13.1, 14., 114.2, 16.2, 117.}); + + sd::ops::segment_mean op; + + auto result = op.evaluate({&x, &idx}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + // result.at(0)->printIndexedBuffer("Output"); + // result.at(0)->printShapeInfo("Out Shape"); + // exp.printIndexedBuffer("Expect"); + // exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.isSameShape(result.at(0))); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentMean_1) { - auto x = NDArrayFactory::create({1.8, 2.5,4., 9., 2.1, 2.4,3.,9., 2.1, 2.1,0.7, 0.1, 3., 4.2, 2.2, 1.}); - auto idx = NDArrayFactory::create({0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); - auto exp = NDArrayFactory::create({2.15, 4.375, 3., 4.4, 1.8666667}); - - sd::ops::unsorted_segment_mean op; + auto x = + NDArrayFactory::create({1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, + 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create( + {0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); + auto exp = NDArrayFactory::create({2.15, 4.375, 3., 4.4, 1.8666667}); - auto result = op.evaluate({&x, &idx}, {}, {5}); - ASSERT_EQ(result.status(), Status::OK()); - ASSERT_TRUE(exp.equalsTo(result.at(0))); + sd::ops::unsorted_segment_mean op; - + auto result = op.evaluate({&x, &idx}, {}, {5}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentMeanBP_1) { - auto x = NDArrayFactory::create({1.8, 2.5,4., 9., 2.1, 2.4,3.,9., 2.1, 2.1,0.7, 0.1, 3., 4.2, 2.2, 1.}); - auto idx = NDArrayFactory::create({0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); - auto eps = NDArrayFactory::create({1., 2., 3., 4., 5.}); - auto exp = NDArrayFactory::create({1./2., 1./2., 2./4., 2./4., 2./4., 2./4, 3., 4./3., 4./3., 4./3., - 5./6., 5./6., 5./6., 5./6., 5./6., 5./6.}); - sd::ops::segment_mean_bp op; + auto x = + NDArrayFactory::create({1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, + 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create( + {0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); + auto eps = NDArrayFactory::create({1., 2., 3., 4., 5.}); + auto exp = NDArrayFactory::create( + {1. / 2., 1. / 2., 2. / 4., 2. / 4., 2. / 4., 2. / 4, 3., 4. / 3., + 4. / 3., 4. / 3., 5. / 6., 5. / 6., 5. / 6., 5. / 6., 5. / 6., 5. / 6.}); + sd::ops::segment_mean_bp op; - auto result = op.evaluate({&x, &idx, &eps}, {}, {}); - ASSERT_EQ(result.status(), Status::OK()); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - - + auto result = op.evaluate({&x, &idx, &eps}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentMeanBP_1) { - auto x = NDArrayFactory::create({1.8, 2.5,4., 9., 2.1, 2.4,3.,9., 2.1, 2.1,0.7, 0.1, 3., 4.2, 2.2, 1.}); - auto idx = NDArrayFactory::create({0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); - auto eps = NDArrayFactory::create({1., 2., 3., 4., 5.}); - auto exp = NDArrayFactory::create({1./2., 1./2., 2./4., 2./4., 2./4., 2./4, 3., 4./3., 4./3., 4./3., - 5./6., 5./6., 5./6., 5./6., 5./6., 5./6.}); - sd::ops::unsorted_segment_mean_bp op; - - auto result = op.evaluate({&x, &idx, &eps}, {}, {5}); - ASSERT_EQ(result.status(), Status::OK()); - ASSERT_TRUE(exp.equalsTo(result.at(0))); + auto x = + NDArrayFactory::create({1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, + 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create( + {0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); + auto eps = NDArrayFactory::create({1., 2., 3., 4., 5.}); + auto exp = NDArrayFactory::create( + {1. / 2., 1. / 2., 2. / 4., 2. / 4., 2. / 4., 2. / 4, 3., 4. / 3., + 4. / 3., 4. / 3., 5. / 6., 5. / 6., 5. / 6., 5. / 6., 5. / 6., 5. / 6.}); + sd::ops::unsorted_segment_mean_bp op; - + auto result = op.evaluate({&x, &idx, &eps}, {}, {5}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentMeanBP_2) { - auto x = NDArrayFactory::create({3.,1.8, 2.5,4., 9., 2.1, 2.4,9., 2.1, 2.1,0.7, 0.1, 3., 4.2, 2.2, 1.}); - auto idx = NDArrayFactory::create({2, 0, 0, 1, 1, 1, 1, 3, 3, 3, 4, 4, 4, 4, 4, 4}); - auto eps = NDArrayFactory::create({1., 2., 3., 4., 5.}); - auto exp = NDArrayFactory::create({3., 1./2., 1./2., 2./4., 2./4., 2./4., 2./4, 4./3., 4./3., 4./3., - 5./6., 5./6., 5./6., 5./6., 5./6., 5./6.}); - sd::ops::unsorted_segment_mean_bp op; - - auto result = op.evaluate({&x, &idx, &eps}, {}, {5}); - ASSERT_EQ(result.status(), Status::OK()); - ASSERT_TRUE(exp.equalsTo(result.at(0))); + auto x = + NDArrayFactory::create({3., 1.8, 2.5, 4., 9., 2.1, 2.4, 9., 2.1, + 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create( + {2, 0, 0, 1, 1, 1, 1, 3, 3, 3, 4, 4, 4, 4, 4, 4}); + auto eps = NDArrayFactory::create({1., 2., 3., 4., 5.}); + auto exp = NDArrayFactory::create( + {3., 1. / 2., 1. / 2., 2. / 4., 2. / 4., 2. / 4., 2. / 4, 4. / 3., + 4. / 3., 4. / 3., 5. / 6., 5. / 6., 5. / 6., 5. / 6., 5. / 6., 5. / 6.}); + sd::ops::unsorted_segment_mean_bp op; - + auto result = op.evaluate({&x, &idx, &eps}, {}, {5}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentMean_2) { - auto x = NDArrayFactory::create('c', {4, 4}, {1.8, 2.5, 4., 9.,2.1, 2.4, 3., 9.,2.1, 2.1, 0.7, 0.1,3., 4.2, 2.2, 1.}); - auto idx = NDArrayFactory::create({0, 0, 1, 2}); - auto exp = NDArrayFactory::create('c', {3, 4}, { 1.95, 2.45, 3.5, 9., 2.1, 2.1, 0.7, 0.1, 3. , 4.2, 2.2, 1.}); + auto x = + NDArrayFactory::create('c', {4, 4}, + {1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, + 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create({0, 0, 1, 2}); + auto exp = NDArrayFactory::create( + 'c', {3, 4}, {1.95, 2.45, 3.5, 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); - sd::ops::unsorted_segment_mean op; + sd::ops::unsorted_segment_mean op; - auto result = op.evaluate({&x, &idx}, {}, {3}); - ASSERT_EQ(result.status(), Status::OK()); - ASSERT_EQ(result.size(), 1); -// exp.printIndexedBuffer("Expect"); -// result.at(0)->printIndexedBuffer("Output"); -// exp.printShapeInfo("Exp Shape"); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - - + auto result = op.evaluate({&x, &idx}, {}, {3}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_EQ(result.size(), 1); + // exp.printIndexedBuffer("Expect"); + // result.at(0)->printIndexedBuffer("Output"); + // exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentMean_3) { - auto x = NDArrayFactory::create('c', {4, 4, 4}, { - 91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,51. , 42. , 67. , 24.,15.1, 56.4, 93. , 28., - 109.1, 82.1, 12.7, 113.1,114. , 14.2, 116.2, 11. ,31. , 22. , 87., 44. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. , - 91. , 82. , 37., 64. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. }); - -// ---------------------------------------------------------------- - - auto idx = NDArrayFactory::create({0, 1, 1, 2}); - auto exp = NDArrayFactory::create('c', {3, 4, 4}, { - 91. , 82. , 37. , 64. ,55.1 , 46.4 , 73. , 28. ,119.1 , 12.1 , 112.7 , 13.1,14. , 114.2 , 16.2 , 117. , - 41. , 32. , 77. , 34. ,35.1 , 51.4 , 83. , 28. ,114.1 , 47.1 , 62.7, 63.1,64. , 64.2 , 66.2 , 64. , - 91. , 82. , 37. , 64. ,55.1 , 46.4 , 73. , 28. ,119.1 , 12.1 , 112.7 , 13.1,14. , 114.2 , 16.2 , 117. }); - - sd::ops::unsorted_segment_mean op; - - auto result = op.evaluate({&x, &idx}, {}, {3}); - ASSERT_EQ(result.status(), Status::OK()); -// result.at(0)->printIndexedBuffer("Output"); -// result.at(0)->printShapeInfo("Out Shape"); -// exp.printIndexedBuffer("Expect"); -// exp.printShapeInfo("Exp Shape"); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - - + auto x = NDArrayFactory::create( + 'c', {4, 4, 4}, + {91., 82., 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, + 112.7, 13.1, 14., 114.2, 16.2, 117., 51., 42., 67., 24., + 15.1, 56.4, 93., 28., 109.1, 82.1, 12.7, 113.1, 114., 14.2, + 116.2, 11., 31., 22., 87., 44., 55.1, 46.4, 73., 28., + 119.1, 12.1, 112.7, 13.1, 14., 114.2, 16.2, 117., 91., 82., + 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, 112.7, 13.1, + 14., 114.2, 16.2, 117.}); + + // ---------------------------------------------------------------- + + auto idx = NDArrayFactory::create({0, 1, 1, 2}); + auto exp = NDArrayFactory::create( + 'c', {3, 4, 4}, + {91., 82., 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, + 112.7, 13.1, 14., 114.2, 16.2, 117., 41., 32., 77., 34., + 35.1, 51.4, 83., 28., 114.1, 47.1, 62.7, 63.1, 64., 64.2, + 66.2, 64., 91., 82., 37., 64., 55.1, 46.4, 73., 28., + 119.1, 12.1, 112.7, 13.1, 14., 114.2, 16.2, 117.}); + + sd::ops::unsorted_segment_mean op; + + auto result = op.evaluate({&x, &idx}, {}, {3}); + ASSERT_EQ(result.status(), Status::OK()); + // result.at(0)->printIndexedBuffer("Output"); + // result.at(0)->printShapeInfo("Out Shape"); + // exp.printIndexedBuffer("Expect"); + // exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentMean_4) { - auto x = NDArrayFactory::create('c', {4, 4, 4}, { - 91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,51. , 42. , 67. , 24.,15.1, 56.4, 93. , 28., - 109.1, 82.1, 12.7, 113.1,114. , 14.2, 116.2, 11. ,31. , 22. , 87., 44. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. , - 91. , 82. , 37., 64. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. }); - -// ---------------------------------------------------------------- - - auto idx = NDArrayFactory::create({0, 1, 3, 7}); - auto exp = NDArrayFactory::create('c', {8, 4, 4}, { - 91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,51. , 42. , 67. , 24. ,15.1, 56.4, 93. , 28. , - 109.1, 82.1, 12.7, 113.1,114. , 14.2, 116.2, 11. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. , - 31. , 22. , 87. , 44. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. , - 0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. , - 0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. , - 119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. }); - - sd::ops::unsorted_segment_mean op; - - auto result = op.evaluate({&x, &idx}, {}, {8}); - ASSERT_EQ(result.status(), Status::OK()); - //result.at(0)->printIndexedBuffer("Output"); - //result.at(0)->printShapeInfo("Out Shape"); - //exp.printIndexedBuffer("Expect"); - //exp.printShapeInfo("Exp Shape"); - ASSERT_TRUE(exp.isSameShape(result.at(0))); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - - + auto x = NDArrayFactory::create( + 'c', {4, 4, 4}, + {91., 82., 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, + 112.7, 13.1, 14., 114.2, 16.2, 117., 51., 42., 67., 24., + 15.1, 56.4, 93., 28., 109.1, 82.1, 12.7, 113.1, 114., 14.2, + 116.2, 11., 31., 22., 87., 44., 55.1, 46.4, 73., 28., + 119.1, 12.1, 112.7, 13.1, 14., 114.2, 16.2, 117., 91., 82., + 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, 112.7, 13.1, + 14., 114.2, 16.2, 117.}); + + // ---------------------------------------------------------------- + + auto idx = NDArrayFactory::create({0, 1, 3, 7}); + auto exp = NDArrayFactory::create( + 'c', {8, 4, 4}, + {91., 82., 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, 112.7, + 13.1, 14., 114.2, 16.2, 117., 51., 42., 67., 24., 15.1, 56.4, + 93., 28., 109.1, 82.1, 12.7, 113.1, 114., 14.2, 116.2, 11., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 31., 22., 87., 44., 55.1, 46.4, 73., + 28., 119.1, 12.1, 112.7, 13.1, 14., 114.2, 16.2, 117., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 91., 82., 37., 64., 55.1, 46.4, 73., 28., 119.1, + 12.1, 112.7, 13.1, 14., 114.2, 16.2, 117.}); + + sd::ops::unsorted_segment_mean op; + + auto result = op.evaluate({&x, &idx}, {}, {8}); + ASSERT_EQ(result.status(), Status::OK()); + // result.at(0)->printIndexedBuffer("Output"); + // result.at(0)->printShapeInfo("Out Shape"); + // exp.printIndexedBuffer("Expect"); + // exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.isSameShape(result.at(0))); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentSqrtN_1) { - auto x = NDArrayFactory::create({1.8, 2.5,4., 9., 2.1, 2.4,3.,9., 2.1, 2.1,0.7, 0.1, 3., 4.2, 2.2, 1.}); - auto idx = NDArrayFactory::create({0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); - auto exp = NDArrayFactory::create({3.0405593, 8.75, 3., 7.621024, 4.5723805}); - - sd::ops::unsorted_segment_sqrt_n op; + auto x = + NDArrayFactory::create({1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, + 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create( + {0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); + auto exp = NDArrayFactory::create( + {3.0405593, 8.75, 3., 7.621024, 4.5723805}); - auto result = op.evaluate({&x, &idx}, {}, {5}); - ASSERT_EQ(result.status(), Status::OK()); - ASSERT_TRUE(exp.equalsTo(result.at(0))); + sd::ops::unsorted_segment_sqrt_n op; - + auto result = op.evaluate({&x, &idx}, {}, {5}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentSqrtN_BP_1) { - auto x = NDArrayFactory::create({3.,1.8, 2.5,4., 9., 2.1, 2.4,9., 2.1, 2.1,0.7, 0.1, 3., 4.2, 2.2, 1.}); - auto idx = NDArrayFactory::create({2, 0, 0, 1, 1, 1, 1, 3, 3, 3, 4, 4, 4, 4, 4, 4}); - auto eps = NDArrayFactory::create({1., 2., 3., 4., 5.}); -// NDArray exp({3.0405593, 8.75, 3., 7.621024, 4.5723805}); - auto exp = NDArrayFactory::create({3., 0.707107, 0.707107, 1., 1., 1., 1., 2.309401, 2.309401, 2.309401, 2.041241, 2.041241, 2.041241, 2.041241, 2.041241, 2.041241}); - sd::ops::unsorted_segment_sqrt_n_bp op; - - auto result = op.evaluate({&x, &idx, &eps}, {}, {5}); - ASSERT_EQ(result.status(), Status::OK()); -// result.at(0)->printIndexedBuffer("Hello Out:"); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - - + auto x = + NDArrayFactory::create({3., 1.8, 2.5, 4., 9., 2.1, 2.4, 9., 2.1, + 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create( + {2, 0, 0, 1, 1, 1, 1, 3, 3, 3, 4, 4, 4, 4, 4, 4}); + auto eps = NDArrayFactory::create({1., 2., 3., 4., 5.}); + // NDArray + // exp({3.0405593, 8.75, 3., 7.621024, 4.5723805}); + auto exp = NDArrayFactory::create( + {3., 0.707107, 0.707107, 1., 1., 1., 1., 2.309401, 2.309401, 2.309401, + 2.041241, 2.041241, 2.041241, 2.041241, 2.041241, 2.041241}); + sd::ops::unsorted_segment_sqrt_n_bp op; + + auto result = op.evaluate({&x, &idx, &eps}, {}, {5}); + ASSERT_EQ(result.status(), Status::OK()); + // result.at(0)->printIndexedBuffer("Hello Out:"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentSqrtN_2) { - auto x = NDArrayFactory::create('c', {4, 4}, {1.8, 2.5, 4., 9.,2.1, 2.4, 3., 9.,2.1, 2.1, 0.7, 0.1,3., 4.2, 2.2, 1.}); - auto idx = NDArrayFactory::create({0, 0, 1, 2}); - auto exp = NDArrayFactory::create('c', {3, 4}, { 2.7577164, 3.4648232, 4.9497476, 12.727922, - 2.1, 2.1, 0.7, 0.1, - 3. , 4.2, 2.2, 1. - }); + auto x = + NDArrayFactory::create('c', {4, 4}, + {1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, + 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create({0, 0, 1, 2}); + auto exp = NDArrayFactory::create( + 'c', {3, 4}, + {2.7577164, 3.4648232, 4.9497476, 12.727922, 2.1, 2.1, 0.7, 0.1, 3., 4.2, + 2.2, 1.}); - sd::ops::unsorted_segment_sqrt_n op; + sd::ops::unsorted_segment_sqrt_n op; - auto result = op.evaluate({&x, &idx}, {}, {3}); - ASSERT_EQ(result.status(), Status::OK()); - ASSERT_EQ(result.size(), 1); -// exp.printIndexedBuffer("Expect"); -// result.at(0)->printIndexedBuffer("Output"); -// exp.printShapeInfo("Exp Shape"); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - - + auto result = op.evaluate({&x, &idx}, {}, {3}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_EQ(result.size(), 1); + // exp.printIndexedBuffer("Expect"); + // result.at(0)->printIndexedBuffer("Output"); + // exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentSqrtN_3) { - auto x = NDArrayFactory::create('c', {4, 4, 4}, { - 91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,51. , 42. , 67. , 24.,15.1, 56.4, 93. , 28., - 109.1, 82.1, 12.7, 113.1,114. , 14.2, 116.2, 11. ,31. , 22. , 87., 44. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. , - 91. , 82. , 37., 64. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. }); - -// ---------------------------------------------------------------- - - auto idx = NDArrayFactory::create({0, 1, 1, 2}); - auto exp = NDArrayFactory::create('c', {3, 4, 4}, { - 91. , 82. , 37. , 64. ,55.1 , 46.4 , 73. , 28. ,119.1 , 12.1 , 112.7 , 13.1,14. , 114.2 , 16.2 , 117. , - 57.982758, 45.254833, 108.89445, 48.083263, 49.638893, 72.69058, 117.37973, 39.59798, 161.36177, 66.60946, 88.67119, 89.23688, 90.50967, 90.79251, 93.62093, 90.50967, - 91. , 82. , 37. , 64. ,55.1 , 46.4 , 73. , 28. ,119.1 , 12.1 , 112.7 , 13.1,14. , 114.2 , 16.2 , 117. }); - - sd::ops::unsorted_segment_sqrt_n op; - - auto result = op.evaluate({&x, &idx}, {}, {3}); - ASSERT_EQ(result.status(), Status::OK()); -// result.at(0)->printIndexedBuffer("Output"); -// result.at(0)->printShapeInfo("Out Shape"); -// exp.printIndexedBuffer("Expect"); -// exp.printShapeInfo("Exp Shape"); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - - + auto x = NDArrayFactory::create( + 'c', {4, 4, 4}, + {91., 82., 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, + 112.7, 13.1, 14., 114.2, 16.2, 117., 51., 42., 67., 24., + 15.1, 56.4, 93., 28., 109.1, 82.1, 12.7, 113.1, 114., 14.2, + 116.2, 11., 31., 22., 87., 44., 55.1, 46.4, 73., 28., + 119.1, 12.1, 112.7, 13.1, 14., 114.2, 16.2, 117., 91., 82., + 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, 112.7, 13.1, + 14., 114.2, 16.2, 117.}); + + // ---------------------------------------------------------------- + + auto idx = NDArrayFactory::create({0, 1, 1, 2}); + auto exp = NDArrayFactory::create( + 'c', {3, 4, 4}, + {91., 82., 37., 64., 55.1, 46.4, + 73., 28., 119.1, 12.1, 112.7, 13.1, + 14., 114.2, 16.2, 117., 57.982758, 45.254833, + 108.89445, 48.083263, 49.638893, 72.69058, 117.37973, 39.59798, + 161.36177, 66.60946, 88.67119, 89.23688, 90.50967, 90.79251, + 93.62093, 90.50967, 91., 82., 37., 64., + 55.1, 46.4, 73., 28., 119.1, 12.1, + 112.7, 13.1, 14., 114.2, 16.2, 117.}); + + sd::ops::unsorted_segment_sqrt_n op; + + auto result = op.evaluate({&x, &idx}, {}, {3}); + ASSERT_EQ(result.status(), Status::OK()); + // result.at(0)->printIndexedBuffer("Output"); + // result.at(0)->printShapeInfo("Out Shape"); + // exp.printIndexedBuffer("Expect"); + // exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentSqrtN_4) { - auto x = NDArrayFactory::create('c', {4, 4, 4}, { - 91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,51. , 42. , 67. , 24.,15.1, 56.4, 93. , 28., - 109.1, 82.1, 12.7, 113.1,114. , 14.2, 116.2, 11. ,31. , 22. , 87., 44. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. , - 91. , 82. , 37., 64. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. }); - -// ---------------------------------------------------------------- - - auto idx = NDArrayFactory::create({0, 1, 3, 7}); - auto exp = NDArrayFactory::create('c', {8, 4, 4}, { - 91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,51. , 42. , 67. , 24. ,15.1, 56.4, 93. , 28. , - 109.1, 82.1, 12.7, 113.1,114. , 14.2, 116.2, 11. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. , - 31. , 22. , 87. , 44. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. , - 0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. , - 0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. , - 119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. }); - - sd::ops::unsorted_segment_sqrt_n op; - - auto result = op.evaluate({&x, &idx}, {}, {8}); - ASSERT_EQ(result.status(), Status::OK()); - //result.at(0)->printIndexedBuffer("Output"); - //result.at(0)->printShapeInfo("Out Shape"); - //exp.printIndexedBuffer("Expect"); - //exp.printShapeInfo("Exp Shape"); - ASSERT_TRUE(exp.isSameShape(result.at(0))); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - - + auto x = NDArrayFactory::create( + 'c', {4, 4, 4}, + {91., 82., 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, + 112.7, 13.1, 14., 114.2, 16.2, 117., 51., 42., 67., 24., + 15.1, 56.4, 93., 28., 109.1, 82.1, 12.7, 113.1, 114., 14.2, + 116.2, 11., 31., 22., 87., 44., 55.1, 46.4, 73., 28., + 119.1, 12.1, 112.7, 13.1, 14., 114.2, 16.2, 117., 91., 82., + 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, 112.7, 13.1, + 14., 114.2, 16.2, 117.}); + + // ---------------------------------------------------------------- + + auto idx = NDArrayFactory::create({0, 1, 3, 7}); + auto exp = NDArrayFactory::create( + 'c', {8, 4, 4}, + {91., 82., 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, 112.7, + 13.1, 14., 114.2, 16.2, 117., 51., 42., 67., 24., 15.1, 56.4, + 93., 28., 109.1, 82.1, 12.7, 113.1, 114., 14.2, 116.2, 11., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 31., 22., 87., 44., 55.1, 46.4, 73., + 28., 119.1, 12.1, 112.7, 13.1, 14., 114.2, 16.2, 117., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 91., 82., 37., 64., 55.1, 46.4, 73., 28., 119.1, + 12.1, 112.7, 13.1, 14., 114.2, 16.2, 117.}); + + sd::ops::unsorted_segment_sqrt_n op; + + auto result = op.evaluate({&x, &idx}, {}, {8}); + ASSERT_EQ(result.status(), Status::OK()); + // result.at(0)->printIndexedBuffer("Output"); + // result.at(0)->printShapeInfo("Out Shape"); + // exp.printIndexedBuffer("Expect"); + // exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.isSameShape(result.at(0))); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentSqrtN_5) { - auto x = NDArrayFactory::create({1.,2.,5.,7.,3.,1.,3.,4.}); - auto idx = NDArrayFactory::create({3, 1, 0, 0, 2, 0, 3, 2}); - //NDArray exp({1.7320508075688772, 1., 1.4142135623730951, 1.4142135623730951}); - auto exp = NDArrayFactory::create({7.5055537, 2., 4.9497476, 2.828427}); - sd::ops::unsorted_segment_sqrt_n op; - - auto result = op.evaluate({&x, &idx}, {}, {4}); - ASSERT_EQ(result.status(), Status::OK()); - // result.at(0)->printIndexedBuffer("Output"); - // exp.printIndexedBuffer("Expect"); - ASSERT_TRUE(exp.equalsTo(result.at(0))); + auto x = NDArrayFactory::create({1., 2., 5., 7., 3., 1., 3., 4.}); + auto idx = NDArrayFactory::create({3, 1, 0, 0, 2, 0, 3, 2}); + // NDArray exp({1.7320508075688772, 1., 1.4142135623730951, + // 1.4142135623730951}); + auto exp = + NDArrayFactory::create({7.5055537, 2., 4.9497476, 2.828427}); + sd::ops::unsorted_segment_sqrt_n op; + auto result = op.evaluate({&x, &idx}, {}, {4}); + ASSERT_EQ(result.status(), Status::OK()); + // result.at(0)->printIndexedBuffer("Output"); + // exp.printIndexedBuffer("Expect"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentSqrtN_6) { - auto x = NDArrayFactory::create({5,1,7,2,3,4,1,3}); - auto idx = NDArrayFactory::create({0,0,0,1,2,2,3,3}); - //NDArray exp({1.7320508075688772, 1., 1.4142135623730951, 1.4142135623730951}); -// auto exp = NDArrayFactory::create({7.5055537, 2., 4.9497476, 2.828427}); - sd::ops::unsorted_segment_sqrt_n op; - -try { + auto x = NDArrayFactory::create({5, 1, 7, 2, 3, 4, 1, 3}); + auto idx = NDArrayFactory::create({0, 0, 0, 1, 2, 2, 3, 3}); + // NDArray exp({1.7320508075688772, 1., 1.4142135623730951, + // 1.4142135623730951}); + // auto exp = + // NDArrayFactory::create({7.5055537, 2., 4.9497476, 2.828427}); + sd::ops::unsorted_segment_sqrt_n op; + + try { auto result = op.evaluate({&x, &idx}, {}, {1}); ASSERT_NE(result.status(), Status::OK()); -} -catch (std::exception& err) { - -} - // result.at(0)->printIndexedBuffer("Output"); - // exp.printIndexedBuffer("Expect"); - //ASSERT_TRUE(exp.equalsTo(result.at(0))); + } catch (std::exception& err) { + } + // result.at(0)->printIndexedBuffer("Output"); + // exp.printIndexedBuffer("Expect"); + // ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentSum_1) { - auto x = NDArrayFactory::create({1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1. }); - auto idx = NDArrayFactory::create({ 0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); - auto exp = NDArrayFactory::create({4.3, 17.5, 3., 13.2, 11.2}); - - sd::ops::segment_sum op; + auto x = + NDArrayFactory::create({1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, + 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create( + {0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); + auto exp = NDArrayFactory::create({4.3, 17.5, 3., 13.2, 11.2}); - auto result = op.evaluate({&x, &idx}, {}, {}); - ASSERT_EQ(result.status(), Status::OK()); + sd::ops::segment_sum op; - ASSERT_TRUE(exp.equalsTo(result.at(0))); + auto result = op.evaluate({&x, &idx}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); - + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentSumBP_1) { - auto x = NDArrayFactory::create({1.8, 2.5,4., 9., 2.1, 2.4,3.,9., 2.1, 2.1,0.7, 0.1, 3., 4.2, 2.2, 1. }); - auto idx = NDArrayFactory::create({0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); - auto eps = NDArrayFactory::create({1., 2., 3., 4., 5.}); - auto exp = NDArrayFactory::create({ 1., 1., 2., 2., 2., 2., 3., 4., 4., 4., 5., 5., 5., 5., 5., 5.}); - sd::ops::segment_sum_bp op; - - auto result = op.evaluate({&x, &idx, &eps}, {}, {}); - ASSERT_EQ(result.status(), Status::OK()); - ASSERT_TRUE(exp.equalsTo(result.at(0))); + auto x = + NDArrayFactory::create({1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, + 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create( + {0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); + auto eps = NDArrayFactory::create({1., 2., 3., 4., 5.}); + auto exp = NDArrayFactory::create( + {1., 1., 2., 2., 2., 2., 3., 4., 4., 4., 5., 5., 5., 5., 5., 5.}); + sd::ops::segment_sum_bp op; - + auto result = op.evaluate({&x, &idx, &eps}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } TEST_F(DeclarableOpsTests7, TestUnsortedSegmentSumBP_1) { - auto x = NDArrayFactory::create({1.8, 2.5,4., 9., 2.1, 2.4,3.,9., 2.1, 2.1,0.7, 0.1, 3., 4.2, 2.2, 1. }); - auto idx = NDArrayFactory::create({0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); - auto eps = NDArrayFactory::create({1, 2, 3, 4, 5}); - auto exp = NDArrayFactory::create({ 1., 1., 2., 2., 2., 2., 3., 4., 4., 4., 5., 5., 5., 5., 5., 5.}); - sd::ops::unsorted_segment_sum_bp op; + auto x = + NDArrayFactory::create({1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, + 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create( + {0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); + auto eps = NDArrayFactory::create({1, 2, 3, 4, 5}); + auto exp = NDArrayFactory::create( + {1., 1., 2., 2., 2., 2., 3., 4., 4., 4., 5., 5., 5., 5., 5., 5.}); + sd::ops::unsorted_segment_sum_bp op; - auto result = op.evaluate({&x, &idx, &eps}, {}, {5}); - ASSERT_EQ(result.status(), Status::OK()); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - - + auto result = op.evaluate({&x, &idx, &eps}, {}, {5}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } TEST_F(DeclarableOpsTests7, TestUnsortedSegmentSumBP_2) { - auto x = NDArrayFactory::create({3.,1.8, 2.5,4., 9., 2.1, 2.4,9., 2.1, 2.1,0.7, 0.1, 3., 4.2, 2.2, 1. }); - auto idx = NDArrayFactory::create({2, 0, 0, 1, 1, 1, 1, 3, 3, 3, 4, 4, 4, 4, 4, 4}); - auto eps = NDArrayFactory::create({1., 2., 3., 4., 5.}); - auto exp = NDArrayFactory::create({ 3., 1., 1., 2., 2., 2., 2., 4., 4., 4., 5., 5., 5., 5., 5., 5.}); - sd::ops::unsorted_segment_sum_bp op; - - auto result = op.evaluate({&x, &idx, &eps}, {}, {5}); - ASSERT_EQ(result.status(), Status::OK()); - ASSERT_TRUE(exp.equalsTo(result.at(0))); + auto x = + NDArrayFactory::create({3., 1.8, 2.5, 4., 9., 2.1, 2.4, 9., 2.1, + 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create( + {2, 0, 0, 1, 1, 1, 1, 3, 3, 3, 4, 4, 4, 4, 4, 4}); + auto eps = NDArrayFactory::create({1., 2., 3., 4., 5.}); + auto exp = NDArrayFactory::create( + {3., 1., 1., 2., 2., 2., 2., 4., 4., 4., 5., 5., 5., 5., 5., 5.}); + sd::ops::unsorted_segment_sum_bp op; - + auto result = op.evaluate({&x, &idx, &eps}, {}, {5}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentSum_2) { - auto x = NDArrayFactory::create('c', {4, 4}, {1.8, 2.5, 4., 9.,2.1, 2.4, 3., 9.,2.1, 2.1, 0.7, 0.1,3., 4.2, 2.2, 1. }); - auto idx = NDArrayFactory::create({0, 0, 1, 2}); - auto exp = NDArrayFactory::create('c', {3, 4}, {3.9 , 4.9, 7. , 18.,2.1 , 2.1, 0.7, 0.1,3. , 4.2, 2.2, 1.}); - - sd::ops::segment_sum op; + auto x = + NDArrayFactory::create('c', {4, 4}, + {1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, + 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create({0, 0, 1, 2}); + auto exp = NDArrayFactory::create( + 'c', {3, 4}, {3.9, 4.9, 7., 18., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); - auto result = op.evaluate({&x, &idx}, {}, {}); - ASSERT_EQ(result.status(), Status::OK()); - ASSERT_EQ(result.size(), 1); -// exp.printIndexedBuffer("Expect"); -// exp.printShapeInfo("Exp Shape"); - ASSERT_TRUE(exp.equalsTo(result.at(0))); + sd::ops::segment_sum op; - + auto result = op.evaluate({&x, &idx}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_EQ(result.size(), 1); + // exp.printIndexedBuffer("Expect"); + // exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentSumBP_2) { - auto x = NDArrayFactory::create('c', {4, 4}, {1.8, 2.5, 4., 9.,2.1, 2.4, 3., 9.,2.1, 2.1, 0.7, 0.1,3., 4.2, 2.2, 1. }); - auto idx = NDArrayFactory::create({0, 0, 1, 2}); - auto exp = NDArrayFactory::create('c', {4, 4}, {1. , 2., 3., 4., 1. , 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.}); - auto eps = NDArrayFactory::create('c', {3, 4}); - eps.linspace(1); + auto x = + NDArrayFactory::create('c', {4, 4}, + {1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, + 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create({0, 0, 1, 2}); + auto exp = NDArrayFactory::create( + 'c', {4, 4}, + {1., 2., 3., 4., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.}); + auto eps = NDArrayFactory::create('c', {3, 4}); + eps.linspace(1); - sd::ops::segment_sum_bp op; + sd::ops::segment_sum_bp op; - auto result = op.evaluate({&x, &idx, &eps}, {}, {}); - ASSERT_EQ(result.status(), Status::OK()); - ASSERT_EQ(result.size(), 2); -// exp.printIndexedBuffer("Expect"); -// exp.printShapeInfo("Exp Shape"); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - - + auto result = op.evaluate({&x, &idx, &eps}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_EQ(result.size(), 2); + // exp.printIndexedBuffer("Expect"); + // exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentSum_3) { - auto x = NDArrayFactory::create('c', {4, 4, 4}, { - 91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,51. , 42. , 67. , 24.,15.1, 56.4, 93. , 28., - 109.1, 82.1, 12.7, 113.1, 114. , 14.2, 116.2, 11. ,31. , 22. , 87., 44. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. , - 91. , 82. , 37., 64. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1, 14. , 114.2, 16.2, 117. }); - -// ---------------------------------------------------------------- - - auto idx = NDArrayFactory::create({0, 1, 1, 2}); - auto exp = NDArrayFactory::create('c', {3, 4, 4}, { - 91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,82. , 64. , 154. , 68. , - 70.2, 102.8, 166. , 56. ,228.2, 94.2, 125.4, 126.2 ,128. , 128.4, 132.4, 128. ,91. , 82. , 37. , 64. , - 55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. }); - - sd::ops::segment_sum op; - - auto result = op.evaluate({&x, &idx}, {}, {}); - ASSERT_EQ(result.status(), Status::OK()); -// result.at(0)->printIndexedBuffer("Output"); -// result.at(0)->printShapeInfo("Out Shape"); -// exp.printIndexedBuffer("Expect"); -// exp.printShapeInfo("Exp Shape"); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - - + auto x = NDArrayFactory::create( + 'c', {4, 4, 4}, + {91., 82., 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, + 112.7, 13.1, 14., 114.2, 16.2, 117., 51., 42., 67., 24., + 15.1, 56.4, 93., 28., 109.1, 82.1, 12.7, 113.1, 114., 14.2, + 116.2, 11., 31., 22., 87., 44., 55.1, 46.4, 73., 28., + 119.1, 12.1, 112.7, 13.1, 14., 114.2, 16.2, 117., 91., 82., + 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, 112.7, 13.1, + 14., 114.2, 16.2, 117.}); + + // ---------------------------------------------------------------- + + auto idx = NDArrayFactory::create({0, 1, 1, 2}); + auto exp = NDArrayFactory::create( + 'c', {3, 4, 4}, + {91., 82., 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, + 112.7, 13.1, 14., 114.2, 16.2, 117., 82., 64., 154., 68., + 70.2, 102.8, 166., 56., 228.2, 94.2, 125.4, 126.2, 128., 128.4, + 132.4, 128., 91., 82., 37., 64., 55.1, 46.4, 73., 28., + 119.1, 12.1, 112.7, 13.1, 14., 114.2, 16.2, 117.}); + + sd::ops::segment_sum op; + + auto result = op.evaluate({&x, &idx}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + // result.at(0)->printIndexedBuffer("Output"); + // result.at(0)->printShapeInfo("Out Shape"); + // exp.printIndexedBuffer("Expect"); + // exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentSum_4) { - auto x = NDArrayFactory::create('c', {4, 4, 4}, { - 91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,51. , 42. , 67. , 24.,15.1, 56.4, 93. , 28., - 109.1, 82.1, 12.7, 113.1,114. , 14.2, 116.2, 11. ,31. , 22. , 87., 44. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. , - 91. , 82. , 37., 64. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. }); - -// ---------------------------------------------------------------- - - auto idx = NDArrayFactory::create({0, 1, 3, 7}); - auto exp = NDArrayFactory::create('c', {8, 4, 4}, { - 91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,51. , 42. , 67. , 24. ,15.1, 56.4, 93. , 28. , - 109.1, 82.1, 12.7, 113.1,114. , 14.2, 116.2, 11. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. , - 31. , 22. , 87. , 44. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. , - 0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. , - 0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. , - 119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. }); - - sd::ops::segment_sum op; - - auto result = op.evaluate({&x, &idx}, {}, {}); - ASSERT_EQ(result.status(), Status::OK()); - //result.at(0)->printIndexedBuffer("Output"); - //result.at(0)->printShapeInfo("Out Shape"); - //exp.printIndexedBuffer("Expect"); - //exp.printShapeInfo("Exp Shape"); - ASSERT_TRUE(exp.isSameShape(result.at(0))); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - - + auto x = NDArrayFactory::create( + 'c', {4, 4, 4}, + {91., 82., 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, + 112.7, 13.1, 14., 114.2, 16.2, 117., 51., 42., 67., 24., + 15.1, 56.4, 93., 28., 109.1, 82.1, 12.7, 113.1, 114., 14.2, + 116.2, 11., 31., 22., 87., 44., 55.1, 46.4, 73., 28., + 119.1, 12.1, 112.7, 13.1, 14., 114.2, 16.2, 117., 91., 82., + 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, 112.7, 13.1, + 14., 114.2, 16.2, 117.}); + + // ---------------------------------------------------------------- + + auto idx = NDArrayFactory::create({0, 1, 3, 7}); + auto exp = NDArrayFactory::create( + 'c', {8, 4, 4}, + {91., 82., 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, 112.7, + 13.1, 14., 114.2, 16.2, 117., 51., 42., 67., 24., 15.1, 56.4, + 93., 28., 109.1, 82.1, 12.7, 113.1, 114., 14.2, 116.2, 11., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 31., 22., 87., 44., 55.1, 46.4, 73., + 28., 119.1, 12.1, 112.7, 13.1, 14., 114.2, 16.2, 117., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 91., 82., 37., 64., 55.1, 46.4, 73., 28., 119.1, + 12.1, 112.7, 13.1, 14., 114.2, 16.2, 117.}); + + sd::ops::segment_sum op; + + auto result = op.evaluate({&x, &idx}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + // result.at(0)->printIndexedBuffer("Output"); + // result.at(0)->printShapeInfo("Out Shape"); + // exp.printIndexedBuffer("Expect"); + // exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.isSameShape(result.at(0))); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentSum_1) { - auto x = NDArrayFactory::create({1.8, 2.5,4., 9., 2.1, 2.4,3.,9., 2.1, 2.1,0.7, 0.1, 3., 4.2, 2.2, 1. }); - auto idx = NDArrayFactory::create({0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); - auto exp = NDArrayFactory::create({4.3, 17.5, 3., 13.2, 11.2}); + auto x = + NDArrayFactory::create({1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, + 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create( + {0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); + auto exp = NDArrayFactory::create({4.3, 17.5, 3., 13.2, 11.2}); - sd::ops::unsorted_segment_sum op; + sd::ops::unsorted_segment_sum op; - auto result = op.evaluate({&x, &idx}, {}, {5}); - ASSERT_EQ(result.status(), Status::OK()); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - - + auto result = op.evaluate({&x, &idx}, {}, {5}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentSum_2) { - auto x = NDArrayFactory::create('c', {4, 4}, {1.8, 2.5, 4., 9.,2.1, 2.4, 3., 9.,2.1, 2.1, 0.7, 0.1,3., 4.2, 2.2, 1. }); - auto idx = NDArrayFactory::create({0, 0, 1, 2}); - auto exp = NDArrayFactory::create('c', {3, 4}, {3.9 , 4.9, 7. , 18.,2.1 , 2.1, 0.7, 0.1,3. , 4.2, 2.2, 1.}); - - sd::ops::unsorted_segment_sum op; + auto x = + NDArrayFactory::create('c', {4, 4}, + {1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, + 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create({0, 0, 1, 2}); + auto exp = NDArrayFactory::create( + 'c', {3, 4}, {3.9, 4.9, 7., 18., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); - auto result = op.evaluate({&x, &idx}, {}, {3}); - ASSERT_EQ(result.status(), Status::OK()); - ASSERT_EQ(result.size(), 1); - ASSERT_TRUE(exp.equalsTo(result.at(0))); + sd::ops::unsorted_segment_sum op; - + auto result = op.evaluate({&x, &idx}, {}, {3}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_EQ(result.size(), 1); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentSum_3) { - auto x = NDArrayFactory::create('c', {4, 4, 4}, { - 91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,51. , 42. , 67. , 24.,15.1, 56.4, 93. , 28., - 109.1, 82.1, 12.7, 113.1, 114. , 14.2, 116.2, 11. ,31. , 22. , 87., 44. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. , - 91. , 82. , 37., 64. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1, 14. , 114.2, 16.2, 117. }); - -// ---------------------------------------------------------------- - - auto idx = NDArrayFactory::create({0, 1, 1, 2}); - auto exp = NDArrayFactory::create('c', {3, 4, 4}, { - 91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,82. , 64. , 154. , 68. , - 70.2, 102.8, 166. , 56. ,228.2, 94.2, 125.4, 126.2 ,128. , 128.4, 132.4, 128. ,91. , 82. , 37. , 64. , - 55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. }); - - sd::ops::unsorted_segment_sum op; - - auto result = op.evaluate({&x, &idx}, {}, {3}); - ASSERT_EQ(result.status(), Status::OK()); -// result.at(0)->printIndexedBuffer("Output"); -// result.at(0)->printShapeInfo("Out Shape"); -// exp.printIndexedBuffer("Expect"); -// exp.printShapeInfo("Exp Shape"); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - - + auto x = NDArrayFactory::create( + 'c', {4, 4, 4}, + {91., 82., 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, + 112.7, 13.1, 14., 114.2, 16.2, 117., 51., 42., 67., 24., + 15.1, 56.4, 93., 28., 109.1, 82.1, 12.7, 113.1, 114., 14.2, + 116.2, 11., 31., 22., 87., 44., 55.1, 46.4, 73., 28., + 119.1, 12.1, 112.7, 13.1, 14., 114.2, 16.2, 117., 91., 82., + 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, 112.7, 13.1, + 14., 114.2, 16.2, 117.}); + + // ---------------------------------------------------------------- + + auto idx = NDArrayFactory::create({0, 1, 1, 2}); + auto exp = NDArrayFactory::create( + 'c', {3, 4, 4}, + {91., 82., 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, + 112.7, 13.1, 14., 114.2, 16.2, 117., 82., 64., 154., 68., + 70.2, 102.8, 166., 56., 228.2, 94.2, 125.4, 126.2, 128., 128.4, + 132.4, 128., 91., 82., 37., 64., 55.1, 46.4, 73., 28., + 119.1, 12.1, 112.7, 13.1, 14., 114.2, 16.2, 117.}); + + sd::ops::unsorted_segment_sum op; + + auto result = op.evaluate({&x, &idx}, {}, {3}); + ASSERT_EQ(result.status(), Status::OK()); + // result.at(0)->printIndexedBuffer("Output"); + // result.at(0)->printShapeInfo("Out Shape"); + // exp.printIndexedBuffer("Expect"); + // exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentSum_4) { - auto x = NDArrayFactory::create('c', {4, 4, 4}, {91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,51. , 42. , 67. , 24.,15.1, 56.4, 93. , 28., - 109.1, 82.1, 12.7, 113.1,114. , 14.2, 116.2, 11. ,31. , 22. , 87., 44. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. , - 91. , 82. , 37., 64. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. }); - -// ---------------------------------------------------------------- - - auto idx = NDArrayFactory::create({0, 1, 3, 7}); - auto exp = NDArrayFactory::create('c', {8, 4, 4}, { - 91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,51. , 42. , 67. , 24. ,15.1, 56.4, 93. , 28. , - 109.1, 82.1, 12.7, 113.1,114. , 14.2, 116.2, 11. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. , - 31. , 22. , 87. , 44. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. , - 0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. , - 0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. , - 119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. }); - - sd::ops::unsorted_segment_sum op; - - auto result = op.evaluate({&x, &idx}, {}, {8}); - ASSERT_EQ(result.status(), Status::OK()); -// result.at(0)->printIndexedBuffer("Output"); -// result.at(0)->printShapeInfo("Out Shape"); -// exp.printIndexedBuffer("Expect"); - //exp.printShapeInfo("Exp Shape"); - ASSERT_TRUE(exp.isSameShape(result.at(0))); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - - + auto x = NDArrayFactory::create( + 'c', {4, 4, 4}, + {91., 82., 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, + 112.7, 13.1, 14., 114.2, 16.2, 117., 51., 42., 67., 24., + 15.1, 56.4, 93., 28., 109.1, 82.1, 12.7, 113.1, 114., 14.2, + 116.2, 11., 31., 22., 87., 44., 55.1, 46.4, 73., 28., + 119.1, 12.1, 112.7, 13.1, 14., 114.2, 16.2, 117., 91., 82., + 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, 112.7, 13.1, + 14., 114.2, 16.2, 117.}); + + // ---------------------------------------------------------------- + + auto idx = NDArrayFactory::create({0, 1, 3, 7}); + auto exp = NDArrayFactory::create( + 'c', {8, 4, 4}, + {91., 82., 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, 112.7, + 13.1, 14., 114.2, 16.2, 117., 51., 42., 67., 24., 15.1, 56.4, + 93., 28., 109.1, 82.1, 12.7, 113.1, 114., 14.2, 116.2, 11., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 31., 22., 87., 44., 55.1, 46.4, 73., + 28., 119.1, 12.1, 112.7, 13.1, 14., 114.2, 16.2, 117., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 91., 82., 37., 64., 55.1, 46.4, 73., 28., 119.1, + 12.1, 112.7, 13.1, 14., 114.2, 16.2, 117.}); + + sd::ops::unsorted_segment_sum op; + + auto result = op.evaluate({&x, &idx}, {}, {8}); + ASSERT_EQ(result.status(), Status::OK()); + // result.at(0)->printIndexedBuffer("Output"); + // result.at(0)->printShapeInfo("Out Shape"); + // exp.printIndexedBuffer("Expect"); + // exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.isSameShape(result.at(0))); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentProd_1) { - auto x = NDArrayFactory::create({1.8, 2.5, 4., 9., 2.1, 2.4,3.,9., 2.1, 2.1,0.7, 0.1, 3., 4.2, 2.2, 1.}); - auto idx = NDArrayFactory::create({0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); - auto exp = NDArrayFactory::create({4.5, 181.44, 3., 39.69, 1.9404}); + auto x = + NDArrayFactory::create({1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, + 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create( + {0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); + auto exp = NDArrayFactory::create({4.5, 181.44, 3., 39.69, 1.9404}); - sd::ops::segment_prod op; + sd::ops::segment_prod op; - auto result = op.evaluate({&x, &idx}, {}, {}); - ASSERT_EQ(result.status(), Status::OK()); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - - + auto result = op.evaluate({&x, &idx}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentProdBP_1) { - auto x = NDArrayFactory::create({ 1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); - auto idx = NDArrayFactory::create({0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); - auto eps = NDArrayFactory::create({1., 2., 3., 4., 5.}); - auto exp = NDArrayFactory::create({2.5, 1.8, 90.72, 40.32, 172.8, 151.2, 3., 17.64, 75.6, 75.6, 13.86, 97.02, 3.234, 2.31, 4.41, 9.702}); - sd::ops::segment_prod_bp op; - - auto result = op.evaluate({&x, &idx, &eps}, {}, {}); - ASSERT_EQ(result.status(), Status::OK()); -// result.at(0)->printIndexedBuffer("ProdBP Output"); -// exp.printIndexedBuffer("ProdBP Expect"); + auto x = + NDArrayFactory::create({1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, + 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create( + {0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); + auto eps = NDArrayFactory::create({1., 2., 3., 4., 5.}); + auto exp = NDArrayFactory::create( + {2.5, 1.8, 90.72, 40.32, 172.8, 151.2, 3., 17.64, 75.6, 75.6, 13.86, + 97.02, 3.234, 2.31, 4.41, 9.702}); + sd::ops::segment_prod_bp op; - ASSERT_TRUE(exp.equalsTo(result.at(0))); + auto result = op.evaluate({&x, &idx, &eps}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + // result.at(0)->printIndexedBuffer("ProdBP Output"); + // exp.printIndexedBuffer("ProdBP Expect"); - + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentProdBP_1) { - auto x = NDArrayFactory::create({ 1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); - auto idx = NDArrayFactory::create({0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); - auto eps = NDArrayFactory::create({1., 2., 3., 4., 5.}); - auto exp = NDArrayFactory::create({2.5, 1.8, 90.72, 40.32, 172.8, 151.2, 3., 17.64, 75.6, 75.6, 13.86, 97.02, 3.234, 2.31, 4.41, 9.702}); - sd::ops::segment_prod_bp op; + auto x = + NDArrayFactory::create({1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, + 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create( + {0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); + auto eps = NDArrayFactory::create({1., 2., 3., 4., 5.}); + auto exp = NDArrayFactory::create( + {2.5, 1.8, 90.72, 40.32, 172.8, 151.2, 3., 17.64, 75.6, 75.6, 13.86, + 97.02, 3.234, 2.31, 4.41, 9.702}); + sd::ops::segment_prod_bp op; - auto result = op.evaluate({&x, &idx, &eps}, {}, {}); - ASSERT_EQ(result.status(), Status::OK()); - //result.at(0)->printIndexedBuffer("ProdBP Output"); - //exp.printIndexedBuffer("ProdBP Expect"); + auto result = op.evaluate({&x, &idx, &eps}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + // result.at(0)->printIndexedBuffer("ProdBP Output"); + // exp.printIndexedBuffer("ProdBP Expect"); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - - + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentProdBP_2) { - auto x = NDArrayFactory::create({ 3., 1.8, 2.5, 4., 9., 2.1, 2.4, 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); - auto idx = NDArrayFactory::create({2, 0, 0, 1, 1, 1, 1, 3, 3, 3, 4, 4, 4, 4, 4, 4}); - auto eps = NDArrayFactory::create({1., 2., 3., 4., 5.}); - auto exp = NDArrayFactory::create({3., 2.5, 1.8, 90.72, 40.32, 172.8, 151.2, 17.64, 75.6, 75.6, 13.86, 97.02, 3.234, 2.31, 4.41, 9.702}); - auto n = NDArrayFactory::create(5LL); - sd::ops::unsorted_segment_prod_bp op; - - auto result = op.evaluate({&x, &idx, &eps, &n}, {}, {5}); - ASSERT_EQ(result.status(), Status::OK()); - //result.at(0)->printIndexedBuffer("Unsorted ProdBP Output"); - //exp.printIndexedBuffer("Unsorted ProdBP Expect"); + auto x = + NDArrayFactory::create({3., 1.8, 2.5, 4., 9., 2.1, 2.4, 9., 2.1, + 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create( + {2, 0, 0, 1, 1, 1, 1, 3, 3, 3, 4, 4, 4, 4, 4, 4}); + auto eps = NDArrayFactory::create({1., 2., 3., 4., 5.}); + auto exp = NDArrayFactory::create({3., 2.5, 1.8, 90.72, 40.32, 172.8, + 151.2, 17.64, 75.6, 75.6, 13.86, + 97.02, 3.234, 2.31, 4.41, 9.702}); + auto n = NDArrayFactory::create(5LL); + sd::ops::unsorted_segment_prod_bp op; - ASSERT_TRUE(exp.equalsTo(result.at(0))); + auto result = op.evaluate({&x, &idx, &eps, &n}, {}, {5}); + ASSERT_EQ(result.status(), Status::OK()); + // result.at(0)->printIndexedBuffer("Unsorted ProdBP Output"); + // exp.printIndexedBuffer("Unsorted ProdBP Expect"); - + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentProd_2) { - auto x = NDArrayFactory::create('c', {4, 4}, { - 1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1. }); - auto idx = NDArrayFactory::create({0, 0, 1, 2}); - auto exp = NDArrayFactory::create('c', {3, 4}, { 3.78, 6. , 12. , 81., 2.1 , 2.1, 0.7 , 0.1, 3. , 4.2, 2.2 , 1.}); - - //{ 2.1, 2.5, 4., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.} + auto x = + NDArrayFactory::create('c', {4, 4}, + {1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, + 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create({0, 0, 1, 2}); + auto exp = NDArrayFactory::create( + 'c', {3, 4}, {3.78, 6., 12., 81., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); - sd::ops::segment_prod op; + //{ 2.1, 2.5, 4., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.} - auto result = op.evaluate({&x, &idx}, {}, {}); - ASSERT_EQ(result.status(), Status::OK()); - ASSERT_EQ(result.size(), 1); -// exp.printIndexedBuffer("Expect"); -// exp.printShapeInfo("Exp Shape"); - ASSERT_TRUE(exp.equalsTo(result.at(0))); + sd::ops::segment_prod op; - + auto result = op.evaluate({&x, &idx}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_EQ(result.size(), 1); + // exp.printIndexedBuffer("Expect"); + // exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentProdBP_2) { - auto x = NDArrayFactory::create('c', {4, 4}, {1.8, 2.5, 4., 9., - 2.1, 2.4, 3., 9., - 2.1, 2.1, 0.7, 0.1, - 3., 4.2, 2.2, 1. }); - auto idx = NDArrayFactory::create({0, 0, 1, 2}); - auto eps = NDArrayFactory::create('c', {3, 4}); - auto exp = NDArrayFactory::create('c', {4, 4}, {2.1, 4.8, 9., 36., 1.8, 5., 12., 36., 5., 6., 7., 8., 9., 10., 11., 12.}); - - //{ 2.1, 2.5, 4., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.} - eps.linspace(1); - sd::ops::segment_prod_bp op; - - auto result = op.evaluate({&x, &idx, &eps}, {}, {}); - ASSERT_EQ(result.status(), Status::OK()); - ASSERT_EQ(result.size(), 2); -// exp.printIndexedBuffer("Expect"); -// exp.printShapeInfo("Exp Shape"); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - - + auto x = + NDArrayFactory::create('c', {4, 4}, + {1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, + 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create({0, 0, 1, 2}); + auto eps = NDArrayFactory::create('c', {3, 4}); + auto exp = + NDArrayFactory::create('c', {4, 4}, + {2.1, 4.8, 9., 36., 1.8, 5., 12., 36., 5., + 6., 7., 8., 9., 10., 11., 12.}); + + //{ 2.1, 2.5, 4., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.} + eps.linspace(1); + sd::ops::segment_prod_bp op; + + auto result = op.evaluate({&x, &idx, &eps}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_EQ(result.size(), 2); + // exp.printIndexedBuffer("Expect"); + // exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentProd_3) { - auto x = NDArrayFactory::create('c', {4, 4, 4}, { - 91. , 82. , 37. , 64. , 55.1, 46.4, 73. , 28. , 119.1, 12.1, 112.7, 13.1, 14. , 114.2, 16.2, 117. , 51. , 42. , 67. , 24., - 15.1, 56.4, 93. , 28., 109.1, 82.1, 12.7, 113.1, 114. , 14.2, 116.2, 11. , 31. , 22. , 87., 44. , 55.1, 46.4, 73., 28. , - 119.1, 12.1, 112.7, 13.1, 14. , 114.2, 16.2, 117. , 91. , 82. , 37., 64. , 55.1, 46.4, 73., 28. , 119.1, 12.1, 112.7, 13.1, 14. , 114.2, 16.2, 117. }); - -// ---------------------------------------------------------------- - - auto idx = NDArrayFactory::create({0, 1, 1, 2}); - auto exp = NDArrayFactory::create('c', {3, 4, 4}, { - 91. , 82. , 37. , 64. , 55.1, 46.4, 73. , 28. , 119.1, 12.1, 112.7, 13.1, 14. , 114.2, 16.2, 117. , - 1581, 924, 5829, 1056,832.01001, 2616.9602, 6789, 784, 12993.810, 993.41003, 1431.2899, 1481.61, 1596, 1621.64, 1882.4401, 1287, - 91. , 82. , 37. , 64. , 55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1, 14. , 114.2, 16.2, 117. }); - - sd::ops::segment_prod op; - - auto result = op.evaluate({&x, &idx}, {}, {}); - ASSERT_EQ(result.status(), Status::OK()); -// result.at(0)->printIndexedBuffer("Output"); -// result.at(0)->printShapeInfo("Out Shape"); -// exp.printIndexedBuffer("Expect"); -// exp.printShapeInfo("Exp Shape"); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - - + auto x = NDArrayFactory::create( + 'c', {4, 4, 4}, + {91., 82., 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, + 112.7, 13.1, 14., 114.2, 16.2, 117., 51., 42., 67., 24., + 15.1, 56.4, 93., 28., 109.1, 82.1, 12.7, 113.1, 114., 14.2, + 116.2, 11., 31., 22., 87., 44., 55.1, 46.4, 73., 28., + 119.1, 12.1, 112.7, 13.1, 14., 114.2, 16.2, 117., 91., 82., + 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, 112.7, 13.1, + 14., 114.2, 16.2, 117.}); + + // ---------------------------------------------------------------- + + auto idx = NDArrayFactory::create({0, 1, 1, 2}); + auto exp = NDArrayFactory::create( + 'c', {3, 4, 4}, + {91., 82., 37., 64., 55.1, 46.4, + 73., 28., 119.1, 12.1, 112.7, 13.1, + 14., 114.2, 16.2, 117., 1581, 924, + 5829, 1056, 832.01001, 2616.9602, 6789, 784, + 12993.810, 993.41003, 1431.2899, 1481.61, 1596, 1621.64, + 1882.4401, 1287, 91., 82., 37., 64., + 55.1, 46.4, 73., 28., 119.1, 12.1, + 112.7, 13.1, 14., 114.2, 16.2, 117.}); + + sd::ops::segment_prod op; + + auto result = op.evaluate({&x, &idx}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + // result.at(0)->printIndexedBuffer("Output"); + // result.at(0)->printShapeInfo("Out Shape"); + // exp.printIndexedBuffer("Expect"); + // exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentProd_04) { - auto x = NDArrayFactory::create({1,2,3,4,5,6,7,8 }); - -// ---------------------------------------------------------------- + auto x = NDArrayFactory::create({1, 2, 3, 4, 5, 6, 7, 8}); - auto idx = NDArrayFactory::create({0,0,1,2,2,2,3,3}); - auto exp = NDArrayFactory::create({ 2, 3, 120, 56}); + // ---------------------------------------------------------------- - sd::ops::segment_prod op; + auto idx = NDArrayFactory::create({0, 0, 1, 2, 2, 2, 3, 3}); + auto exp = NDArrayFactory::create({2, 3, 120, 56}); - auto result = op.evaluate({&x, &idx}, {}, {}); - ASSERT_EQ(result.status(), Status::OK()); - ASSERT_TRUE(exp.equalsTo(result.at(0))); + sd::ops::segment_prod op; - + auto result = op.evaluate({&x, &idx}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentProd_05) { - auto x = NDArrayFactory::create({1,2,3,4,5,6,7,8 }); + auto x = NDArrayFactory::create({1, 2, 3, 4, 5, 6, 7, 8}); -// ---------------------------------------------------------------- + // ---------------------------------------------------------------- - auto idx = NDArrayFactory::create({0,0,1,2,2,2,3,3}); - auto exp = NDArrayFactory::create({ 2, 3, 120, 56}); + auto idx = NDArrayFactory::create({0, 0, 1, 2, 2, 2, 3, 3}); + auto exp = NDArrayFactory::create({2, 3, 120, 56}); - sd::ops::segment_prod op; + sd::ops::segment_prod op; - auto result = op.evaluate({&x, &idx}, {}, {}); - ASSERT_EQ(result.status(), Status::OK()); - auto res = result.at(0); -// res->printIndexedBuffer("Segment prod 05"); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - - + auto result = op.evaluate({&x, &idx}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + auto res = result.at(0); + // res->printIndexedBuffer("Segment prod 05"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentProd_05_1) { - auto x = NDArrayFactory::create({1,2,3,4,5,6,7,8 }); - -// ---------------------------------------------------------------- + auto x = NDArrayFactory::create({1, 2, 3, 4, 5, 6, 7, 8}); - auto idx = NDArrayFactory::create({0,0,1,2,2,2,3,3}); - auto exp = NDArrayFactory::create({ 2, 3, 120, 56}); + // ---------------------------------------------------------------- - sd::ops::segment_prod op; + auto idx = NDArrayFactory::create({0, 0, 1, 2, 2, 2, 3, 3}); + auto exp = NDArrayFactory::create({2, 3, 120, 56}); - auto result = op.evaluate({&x, &idx}, {}, {}); - ASSERT_EQ(result.status(), Status::OK()); - auto res = result.at(0); -// res->printIndexedBuffer("Segment prod 05_1"); - ASSERT_TRUE(exp.equalsTo(result.at(0))); + sd::ops::segment_prod op; - + auto result = op.evaluate({&x, &idx}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + auto res = result.at(0); + // res->printIndexedBuffer("Segment prod 05_1"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentProd_06) { - auto x = NDArrayFactory::create({'\x1','\x2','\x3','\x4','\x5','\x6','\x7','\x8' }); - -// ---------------------------------------------------------------- + auto x = NDArrayFactory::create( + {'\x1', '\x2', '\x3', '\x4', '\x5', '\x6', '\x7', '\x8'}); - auto idx = NDArrayFactory::create({0,0,1,2,2,2,3,3}); - auto exp = NDArrayFactory::create({ 2, 3, 120, 56}); - sd::ops::segment_prod op; + // ---------------------------------------------------------------- - auto result = op.evaluate({&x, &idx}, {}, {}); - ASSERT_EQ(result.status(), Status::OK()); - ASSERT_TRUE(exp.equalsTo(result.at(0))); + auto idx = NDArrayFactory::create({0, 0, 1, 2, 2, 2, 3, 3}); + auto exp = NDArrayFactory::create({2, 3, 120, 56}); + sd::ops::segment_prod op; - + auto result = op.evaluate({&x, &idx}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentProd_07) { - auto x = NDArrayFactory::create({'\x1','\x2','\x3','\x4','\x5','\x6','\x7','\x8' }); + auto x = NDArrayFactory::create( + {'\x1', '\x2', '\x3', '\x4', '\x5', '\x6', '\x7', '\x8'}); -// ---------------------------------------------------------------- + // ---------------------------------------------------------------- - auto idx = NDArrayFactory::create({0,0,1,2,2,2,3,3}); - auto exp = NDArrayFactory::create({ 2, 3, 120, 56}); - sd::ops::segment_prod op; + auto idx = NDArrayFactory::create({0, 0, 1, 2, 2, 2, 3, 3}); + auto exp = NDArrayFactory::create({2, 3, 120, 56}); + sd::ops::segment_prod op; - auto result = op.evaluate({&x, &idx}, {}, {}); - ASSERT_EQ(result.status(), Status::OK()); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - - + auto result = op.evaluate({&x, &idx}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentProd_08) { - auto x = NDArrayFactory::create({'\x1','\x2','\x3','\x4','\x5','\x6','\x7','\x8', '\x9', '\xA' }); - -// ---------------------------------------------------------------- + auto x = NDArrayFactory::create( + {'\x1', '\x2', '\x3', '\x4', '\x5', '\x6', '\x7', '\x8', '\x9', '\xA'}); - auto idx = NDArrayFactory::create({0,0,2,2,2,2,3,3,3,3}); - auto exp = NDArrayFactory::create({ 2, 1,360, 5040}); - sd::ops::segment_prod op; + // ---------------------------------------------------------------- - auto result = op.evaluate({&x, &idx}, {}, {}); - ASSERT_EQ(result.status(), Status::OK()); - ASSERT_TRUE(exp.equalsTo(result.at(0))); + auto idx = NDArrayFactory::create({0, 0, 2, 2, 2, 2, 3, 3, 3, 3}); + auto exp = NDArrayFactory::create({2, 1, 360, 5040}); + sd::ops::segment_prod op; - + auto result = op.evaluate({&x, &idx}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentProd_1) { - auto x = NDArrayFactory::create({1.8, 2.5,4., 9., 2.1, 2.4,3.,9., 2.1, 2.1,0.7, 0.1, 3., 4.2, 2.2, 1.}); - auto idx = NDArrayFactory::create({0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); - auto exp = NDArrayFactory::create({4.5, 181.44, 3., 39.69, 1.9404}); - - sd::ops::unsorted_segment_prod op; + auto x = + NDArrayFactory::create({1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, + 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create( + {0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); + auto exp = NDArrayFactory::create({4.5, 181.44, 3., 39.69, 1.9404}); - auto result = op.evaluate({&x, &idx}, {}, {5}); - ASSERT_EQ(result.status(), Status::OK()); - ASSERT_TRUE(exp.equalsTo(result.at(0))); + sd::ops::unsorted_segment_prod op; - + auto result = op.evaluate({&x, &idx}, {}, {5}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentProd_11) { - auto x = NDArrayFactory::create({3.,1.8, 2.5,4., 9., 2.1, 2.4,9., 2.1, 2.1,0.7, 0.1, 3., 4.2, 2.2, 1.}); - auto idx = NDArrayFactory::create({2, 0, 0, 1, 1, 1, 1, 3, 3, 3, 4, 4, 4, 4, 4, 4}); - auto exp = NDArrayFactory::create({4.5, 181.44, 3., 39.69, 1.9404}); + auto x = + NDArrayFactory::create({3., 1.8, 2.5, 4., 9., 2.1, 2.4, 9., 2.1, + 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create( + {2, 0, 0, 1, 1, 1, 1, 3, 3, 3, 4, 4, 4, 4, 4, 4}); + auto exp = NDArrayFactory::create({4.5, 181.44, 3., 39.69, 1.9404}); - sd::ops::unsorted_segment_prod op; + sd::ops::unsorted_segment_prod op; - auto result = op.evaluate({&x, &idx}, {}, {5}); - ASSERT_EQ(result.status(), Status::OK()); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - - + auto result = op.evaluate({&x, &idx}, {}, {5}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentProd_2) { - auto x = NDArrayFactory::create('c', {4, 4}, { - 1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1. }); - auto idx = NDArrayFactory::create({0, 0, 1, 2}); - auto exp = NDArrayFactory::create('c', {3, 4}, { 3.78, 6. , 12. , 81., 2.1 , 2.1, 0.7 , 0.1, 3. , 4.2, 2.2 , 1.}); - - //{ 2.1, 2.5, 4., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.} + auto x = + NDArrayFactory::create('c', {4, 4}, + {1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, + 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create({0, 0, 1, 2}); + auto exp = NDArrayFactory::create( + 'c', {3, 4}, {3.78, 6., 12., 81., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); - sd::ops::unsorted_segment_prod op; + //{ 2.1, 2.5, 4., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.} - auto result = op.evaluate({&x, &idx}, {}, {3}); - ASSERT_EQ(result.status(), Status::OK()); - ASSERT_EQ(result.size(), 1); -// exp.printIndexedBuffer("Expect"); -// exp.printShapeInfo("Exp Shape"); - ASSERT_TRUE(exp.equalsTo(result.at(0))); + sd::ops::unsorted_segment_prod op; - + auto result = op.evaluate({&x, &idx}, {}, {3}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_EQ(result.size(), 1); + // exp.printIndexedBuffer("Expect"); + // exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentProd_12) { - auto x = NDArrayFactory::create('c', {4, 4}, { - 3., 4.2, 2.2, 1., - 1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, 2.1, 0.7, 0.1 }); - auto idx = NDArrayFactory::create({2, 0, 0, 1}); - auto exp = NDArrayFactory::create('c', {3, 4}, { 3.78, 6. , 12. , 81., 2.1 , 2.1, 0.7 , 0.1, 3. , 4.2, 2.2 , 1.}); - - //{ 2.1, 2.5, 4., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.} + auto x = + NDArrayFactory::create('c', {4, 4}, + {3., 4.2, 2.2, 1., 1.8, 2.5, 4., 9., 2.1, + 2.4, 3., 9., 2.1, 2.1, 0.7, 0.1}); + auto idx = NDArrayFactory::create({2, 0, 0, 1}); + auto exp = NDArrayFactory::create( + 'c', {3, 4}, {3.78, 6., 12., 81., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); - sd::ops::unsorted_segment_prod op; + //{ 2.1, 2.5, 4., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.} - auto result = op.evaluate({&x, &idx}, {}, {3}); - ASSERT_EQ(result.status(), Status::OK()); - ASSERT_EQ(result.size(), 1); -// exp.printIndexedBuffer("Expect"); -// exp.printShapeInfo("Exp Shape"); - ASSERT_TRUE(exp.equalsTo(result.at(0))); + sd::ops::unsorted_segment_prod op; - + auto result = op.evaluate({&x, &idx}, {}, {3}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_EQ(result.size(), 1); + // exp.printIndexedBuffer("Expect"); + // exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentProd_08) { - auto x = NDArrayFactory::create({'\x1','\x2','\x3','\x4','\x5','\x6','\x7','\x8', '\x9', '\xA' }); + auto x = NDArrayFactory::create( + {'\x1', '\x2', '\x3', '\x4', '\x5', '\x6', '\x7', '\x8', '\x9', '\xA'}); -// ---------------------------------------------------------------- + // ---------------------------------------------------------------- - auto idx = NDArrayFactory::create({0,0,2,2,2,2,3,3,3,3}); - auto exp = NDArrayFactory::create({ 2, 1,360, 5040}); - sd::ops::unsorted_segment_prod op; + auto idx = NDArrayFactory::create({0, 0, 2, 2, 2, 2, 3, 3, 3, 3}); + auto exp = NDArrayFactory::create({2, 1, 360, 5040}); + sd::ops::unsorted_segment_prod op; - auto result = op.evaluate({&x, &idx}, {}, {4}); - ASSERT_EQ(result.status(), Status::OK()); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - - + auto result = op.evaluate({&x, &idx}, {}, {4}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentProd_3) { - auto x = NDArrayFactory::create('c', {4, 4, 4}, { - 91. , 82. , 37. , 64. , 55.1, 46.4, 73. , 28. , 119.1, 12.1, 112.7, 13.1, 14. , 114.2, 16.2, 117. , 51. , 42. , 67. , 24., - 15.1, 56.4, 93. , 28., 109.1, 82.1, 12.7, 113.1, 114. , 14.2, 116.2, 11. , 31. , 22. , 87., 44. , 55.1, 46.4, 73., 28. , - 119.1, 12.1, 112.7, 13.1, 14. , 114.2, 16.2, 117. , 91. , 82. , 37., 64. , 55.1, 46.4, 73., 28. , 119.1, 12.1, 112.7, 13.1, 14. , 114.2, 16.2, 117. }); - -// ---------------------------------------------------------------- - - auto idx = NDArrayFactory::create({0, 1, 1, 2}); - auto exp = NDArrayFactory::create('c', {3, 4, 4}, { - 91. , 82. , 37. , 64. , 55.1, 46.4, 73. , 28. , 119.1, 12.1, 112.7, 13.1, 14. , 114.2, 16.2, 117. , - 1581, 924, 5829, 1056,832.01001, 2616.9602, 6789, 784, 12993.810, 993.41003, 1431.2899, 1481.61, 1596.0000, 1621.6399, 1882.4401, 1287, - 91. , 82. , 37. , 64. , 55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1, 14. , 114.2, 16.2, 117. }); - - sd::ops::unsorted_segment_prod op; - - auto result = op.evaluate({&x, &idx}, {}, {3}); - ASSERT_EQ(result.status(), Status::OK()); -// result.at(0)->printIndexedBuffer("Output"); -// result.at(0)->printShapeInfo("Out Shape"); -// exp.printIndexedBuffer("Expect"); -// exp.printShapeInfo("Exp Shape"); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - - + auto x = NDArrayFactory::create( + 'c', {4, 4, 4}, + {91., 82., 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, + 112.7, 13.1, 14., 114.2, 16.2, 117., 51., 42., 67., 24., + 15.1, 56.4, 93., 28., 109.1, 82.1, 12.7, 113.1, 114., 14.2, + 116.2, 11., 31., 22., 87., 44., 55.1, 46.4, 73., 28., + 119.1, 12.1, 112.7, 13.1, 14., 114.2, 16.2, 117., 91., 82., + 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, 112.7, 13.1, + 14., 114.2, 16.2, 117.}); + + // ---------------------------------------------------------------- + + auto idx = NDArrayFactory::create({0, 1, 1, 2}); + auto exp = NDArrayFactory::create( + 'c', {3, 4, 4}, + {91., 82., 37., 64., 55.1, 46.4, + 73., 28., 119.1, 12.1, 112.7, 13.1, + 14., 114.2, 16.2, 117., 1581, 924, + 5829, 1056, 832.01001, 2616.9602, 6789, 784, + 12993.810, 993.41003, 1431.2899, 1481.61, 1596.0000, 1621.6399, + 1882.4401, 1287, 91., 82., 37., 64., + 55.1, 46.4, 73., 28., 119.1, 12.1, + 112.7, 13.1, 14., 114.2, 16.2, 117.}); + + sd::ops::unsorted_segment_prod op; + + auto result = op.evaluate({&x, &idx}, {}, {3}); + ASSERT_EQ(result.status(), Status::OK()); + // result.at(0)->printIndexedBuffer("Output"); + // result.at(0)->printShapeInfo("Out Shape"); + // exp.printIndexedBuffer("Expect"); + // exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentProd_4) { - auto x = NDArrayFactory::create('c', {4, 4, 4}, { - 91. , 82. , 37. , 64. , 55.1, 46.4, 73. , 28. , 119.1, 12.1, 112.7, 13.1, 14. , 114.2, 16.2, 117. , 51. , 42. , 67. , 24., - 15.1, 56.4, 93. , 28., 109.1, 82.1, 12.7, 113.1, 114. , 14.2, 116.2, 11. , 31. , 22. , 87., 44. , 55.1, 46.4, 73., 28. , - 119.1, 12.1, 112.7, 13.1, 14. , 114.2, 16.2, 117. , 91. , 82. , 37., 64. , 55.1, 46.4, 73., 28. , 119.1, 12.1, 112.7, 13.1, 14. , 114.2, 16.2, 117. }); - -// ---------------------------------------------------------------- + auto x = NDArrayFactory::create( + 'c', {4, 4, 4}, + {91., 82., 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, + 112.7, 13.1, 14., 114.2, 16.2, 117., 51., 42., 67., 24., + 15.1, 56.4, 93., 28., 109.1, 82.1, 12.7, 113.1, 114., 14.2, + 116.2, 11., 31., 22., 87., 44., 55.1, 46.4, 73., 28., + 119.1, 12.1, 112.7, 13.1, 14., 114.2, 16.2, 117., 91., 82., + 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, 112.7, 13.1, + 14., 114.2, 16.2, 117.}); - auto idx = NDArrayFactory::create({1, 1, 1, 2}); - auto exp = NDArrayFactory::create('c', {3, 4, 4}, { - 1., 1., 1., 1., 1., 1.,1.,1., 1.,1.,1.,1., 1.,1.,1.,1., + // ---------------------------------------------------------------- - 143871, 75768, 215673, 67584., 45843.75, 121426.96, 495597, 21952, - 1547562.8, 12020.262, 161306.38, 19409.092, 22344, 185191.27, 30495.531, 150579, + auto idx = NDArrayFactory::create({1, 1, 1, 2}); + auto exp = NDArrayFactory::create( + 'c', {3, 4, 4}, + {1., 1., 1., 1., 1., 1., + 1., 1., 1., 1., 1., 1., + 1., 1., 1., 1., - 91., 82., 37., 64, 55.1, 46.400002, 73, 28, 119.1, 12.1, 112.7, 13.1, 14, 114.2, 16.2, 117}); + 143871, 75768, 215673, 67584., 45843.75, 121426.96, + 495597, 21952, 1547562.8, 12020.262, 161306.38, 19409.092, + 22344, 185191.27, 30495.531, 150579, - sd::ops::unsorted_segment_prod op; + 91., 82., 37., 64, 55.1, 46.400002, + 73, 28, 119.1, 12.1, 112.7, 13.1, + 14, 114.2, 16.2, 117}); - auto result = op.evaluate({&x, &idx}, {}, {3}); - ASSERT_EQ(result.status(), Status::OK()); - //result.at(0)->printIndexedBuffer("Output"); -// result.at(0)->printShapeInfo("Out Shape"); - //exp.printIndexedBuffer("Expect"); -// exp.printShapeInfo("Exp Shape"); - ASSERT_TRUE(exp.equalsTo(result.at(0))); + sd::ops::unsorted_segment_prod op; - + auto result = op.evaluate({&x, &idx}, {}, {3}); + ASSERT_EQ(result.status(), Status::OK()); + // result.at(0)->printIndexedBuffer("Output"); + // result.at(0)->printShapeInfo("Out Shape"); + // exp.printIndexedBuffer("Expect"); + // exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentProd_5) { - auto x = NDArrayFactory::create('c', {8, 15}); - -// ---------------------------------------------------------------- + auto x = NDArrayFactory::create('c', {8, 15}); - auto idx = NDArrayFactory::create({3, 1, 2, 1, 2, 3, 2, 1}); - auto exp = NDArrayFactory::create('c', {4, 15}, { - 1., 1., 1., 1., 1., - 1., 1., 1., 1., 1., - 1., 1., 1., 1., 1., - 78016., 85493., 93312., 101479., 110000., - 118881., 128128., 137747., 147744., 158125., - 168896., 180063., 191632., 203609., 216000., - 172081., 182528., 193347., 204544., 216125., - 228096., 240463., 253232., 266409., 280000., - 294011., 308448., 323317., 338624., 354375., - 76., 154., 234., 316., 400., - 486., 574., 664., 756., 850., - 946., 1044., 1144., 1246., 1350.}); - x.linspace(1.); + // ---------------------------------------------------------------- - sd::ops::unsorted_segment_prod op; + auto idx = NDArrayFactory::create({3, 1, 2, 1, 2, 3, 2, 1}); + auto exp = NDArrayFactory::create( + 'c', {4, 15}, + {1., 1., 1., 1., 1., 1., 1., 1., + 1., 1., 1., 1., 1., 1., 1., 78016., + 85493., 93312., 101479., 110000., 118881., 128128., 137747., 147744., + 158125., 168896., 180063., 191632., 203609., 216000., 172081., 182528., + 193347., 204544., 216125., 228096., 240463., 253232., 266409., 280000., + 294011., 308448., 323317., 338624., 354375., 76., 154., 234., + 316., 400., 486., 574., 664., 756., 850., 946., + 1044., 1144., 1246., 1350.}); + x.linspace(1.); - auto result = op.evaluate({&x, &idx}, {}, {4}); - ASSERT_EQ(result.status(), Status::OK()); - //result.at(0)->printIndexedBuffer("Output"); -// result.at(0)->printShapeInfo("Out Shape"); - //exp.printIndexedBuffer("Expect"); -// exp.printShapeInfo("Exp Shape"); - ASSERT_TRUE(exp.equalsTo(result.at(0))); + sd::ops::unsorted_segment_prod op; - + auto result = op.evaluate({&x, &idx}, {}, {4}); + ASSERT_EQ(result.status(), Status::OK()); + // result.at(0)->printIndexedBuffer("Output"); + // result.at(0)->printShapeInfo("Out Shape"); + // exp.printIndexedBuffer("Expect"); + // exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentProdBP_4) { - auto x = NDArrayFactory::create('c', {8}, { - 5,1,7,2,3,4,1,3}); - auto gradO = NDArrayFactory::create('c', {4}, {1,2,3,4}); -// ---------------------------------------------------------------- - - auto idx = NDArrayFactory::create({0,0,0,1,2,2,3,3}); - auto exp = NDArrayFactory::create('c', {8}, { - 7.000000, 35.000000, 5.000000, 2.000000, 12.000000, 9.000000, 12.000000, 4.000000 - }); -// 1., 1., 1., 1., 1., 1.,1.,1., 1.,1.,1.,1., 1.,1.,1.,1., -// -// 143871, 75768, 215673, 67584., 45843.75, 121426.96, 495597, 21952, -// 1547562.8, 12020.262, 161306.38, 19409.092, 22344, 185191.27, 30495.531, 150579, -// -// 91., 82., 37., 64, 55.1, 46.400002, 73, 28, 119.1, 12.1, 112.7, 13.1, 14, 114.2, 16.2, 117}); - - sd::ops::unsorted_segment_prod_bp op; - - auto result = op.evaluate({&x, &idx, &gradO}, {}, {4}); - ASSERT_EQ(result.status(), Status::OK()); - //result.at(0)->printIndexedBuffer("Output"); - //result.at(0)->printShapeInfo("Out Shape"); - //exp.printIndexedBuffer("Expect"); -// exp.printShapeInfo("Exp Shape"); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - - + auto x = NDArrayFactory::create('c', {8}, {5, 1, 7, 2, 3, 4, 1, 3}); + auto gradO = NDArrayFactory::create('c', {4}, {1, 2, 3, 4}); + // ---------------------------------------------------------------- + + auto idx = NDArrayFactory::create({0, 0, 0, 1, 2, 2, 3, 3}); + auto exp = NDArrayFactory::create( + 'c', {8}, + {7.000000, 35.000000, 5.000000, 2.000000, 12.000000, 9.000000, 12.000000, + 4.000000}); + // 1., 1., 1., 1., 1., 1.,1.,1., 1.,1.,1.,1., 1.,1.,1.,1., + // + // 143871, 75768, 215673, 67584., 45843.75, 121426.96, 495597, + // 21952, 1547562.8, 12020.262, 161306.38, 19409.092, 22344, + // 185191.27, 30495.531, 150579, + // + // 91., 82., 37., 64, 55.1, 46.400002, 73, 28, 119.1, 12.1, + // 112.7, 13.1, 14, 114.2, 16.2, 117}); + + sd::ops::unsorted_segment_prod_bp op; + + auto result = op.evaluate({&x, &idx, &gradO}, {}, {4}); + ASSERT_EQ(result.status(), Status::OK()); + // result.at(0)->printIndexedBuffer("Output"); + // result.at(0)->printShapeInfo("Out Shape"); + // exp.printIndexedBuffer("Expect"); + // exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestExtractImagePatches_1) { - auto x = NDArrayFactory::create('c', {2,4, 4, 4}, { - 91., 82., 37., 64., 55., 46., 73., 28., 119., 12., 112., 13., 14., 114., 16., 117., - 51., 42., 67., 24., 15., 56., 93., 28., 109., 82., 12., 113., 114., 14., 116., 11., - 31., 22., 87., 44., 55., 46., 73., 28., 119., 12., 112., 13., 14., 114., 16., 117., - 91., 82., 37., 64., 55.1, 46.4, 73., 28., 119., 12., 112., 13., 14., 114., 16.2, 117., - 91., 82., 37., 64., 55., 46., 73., 28., 119., 12., 112., 13., 14., 114., 16., 117., - 51., 42., 67., 24., 15., 56., 93., 28., 109., 82., 12., 113., 114., 14., 116., 11., - 31., 22., 87., 44., 55., 46., 73., 28., 119., 12., 112., 13., 14., 114., 16., 117., - 91., 82., 37., 64., 55.1, 46.4, 73., 28., 119., 12., 112., 13., 140., 110., 160., 107.}); - -// ---------------------------------------------------------------- - - - auto exp = NDArrayFactory::create('c', {2, 4, 4, 4}, { - 91., 82., 37., 64., 55., 46., 73., 28., 119., 12., 112., 13., 14., 114., 16., 117., - 51., 42., 67., 24., 15., 56., 93., 28., 109., 82., 12., 113., 114., 14., 116., 11., - 31., 22., 87., 44., 55., 46., 73., 28., 119., 12., 112., 13., 14., 114., 16., 117., - 91., 82., 37., 64., 55.1, 46.4, 73., 28., 119., 12., 112., 13., 14., 114., 16.2, 117., - 91., 82., 37., 64., 55., 46., 73., 28., 119., 12., 112., 13., 14., 114., 16., 117., - 51., 42., 67., 24., 15., 56., 93., 28., 109., 82., 12., 113., 114., 14., 116., 11., - 31., 22., 87., 44., 55., 46., 73., 28., 119., 12., 112., 13., 14., 114., 16., 117., - 91., 82., 37., 64., 55.1, 46.4, 73., 28., 119., 12., 112., 13., 140., 110., 160., 107.}); - - sd::ops::extract_image_patches op; - - auto result = op.evaluate({&x}, {}, {1,1,1,1,1,1,0}); - ASSERT_EQ(result.status(), Status::OK()); -// result.at(0)->printIndexedBuffer("Output"); - //result.at(0)->printShapeInfo("Out Shape"); -// exp.printIndexedBuffer("Expect"); - //exp.printShapeInfo("Exp Shape"); - ASSERT_TRUE(exp.isSameShape(result.at(0))); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - - + auto x = NDArrayFactory::create( + 'c', {2, 4, 4, 4}, + {91., 82., 37., 64., 55., 46., 73., 28., 119., 12., 112., 13., + 14., 114., 16., 117., 51., 42., 67., 24., 15., 56., 93., 28., + 109., 82., 12., 113., 114., 14., 116., 11., 31., 22., 87., 44., + 55., 46., 73., 28., 119., 12., 112., 13., 14., 114., 16., 117., + 91., 82., 37., 64., 55.1, 46.4, 73., 28., 119., 12., 112., 13., + 14., 114., 16.2, 117., 91., 82., 37., 64., 55., 46., 73., 28., + 119., 12., 112., 13., 14., 114., 16., 117., 51., 42., 67., 24., + 15., 56., 93., 28., 109., 82., 12., 113., 114., 14., 116., 11., + 31., 22., 87., 44., 55., 46., 73., 28., 119., 12., 112., 13., + 14., 114., 16., 117., 91., 82., 37., 64., 55.1, 46.4, 73., 28., + 119., 12., 112., 13., 140., 110., 160., 107.}); + + // ---------------------------------------------------------------- + + auto exp = NDArrayFactory::create( + 'c', {2, 4, 4, 4}, + {91., 82., 37., 64., 55., 46., 73., 28., 119., 12., 112., 13., + 14., 114., 16., 117., 51., 42., 67., 24., 15., 56., 93., 28., + 109., 82., 12., 113., 114., 14., 116., 11., 31., 22., 87., 44., + 55., 46., 73., 28., 119., 12., 112., 13., 14., 114., 16., 117., + 91., 82., 37., 64., 55.1, 46.4, 73., 28., 119., 12., 112., 13., + 14., 114., 16.2, 117., 91., 82., 37., 64., 55., 46., 73., 28., + 119., 12., 112., 13., 14., 114., 16., 117., 51., 42., 67., 24., + 15., 56., 93., 28., 109., 82., 12., 113., 114., 14., 116., 11., + 31., 22., 87., 44., 55., 46., 73., 28., 119., 12., 112., 13., + 14., 114., 16., 117., 91., 82., 37., 64., 55.1, 46.4, 73., 28., + 119., 12., 112., 13., 140., 110., 160., 107.}); + + sd::ops::extract_image_patches op; + + auto result = op.evaluate({&x}, {}, {1, 1, 1, 1, 1, 1, 0}); + ASSERT_EQ(result.status(), Status::OK()); + // result.at(0)->printIndexedBuffer("Output"); + // result.at(0)->printShapeInfo("Out Shape"); + // exp.printIndexedBuffer("Expect"); + // exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.isSameShape(result.at(0))); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } - //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestExtractImagePatches_2) { - auto x = NDArrayFactory::create('c', {3, 3, 4, 3}, { - 11., 12., 13., 12., 13., 14., 15., 16., 17., 18., 19., 10., - 1., 2., 3., 2., 3., 4., 21., 22., 23., 22., 23., 24., - 5., 6., 7., 8., 9., 0., 35., 36., 37., 38., 39., 40., - 9., 8., 7., 6., 5., 4., 49., 48., 47., 46., 45., 44., - 3., 2., 1., 0., 1., 2., 53., 52., 51., 50., 51., 52., - 15., 16., 17., 18., 19., 10., 135., 136., 137., 138., 139., 140., - 211., 12., 13., 12., 213., 14., 15., 216., 17., 128., 19., 10., - 21., 2., 3., 2., 3., 24., 21., 22., 223., 22., 223., 24., - 25., 6., 7., 8., 9., 20., 35., 36., 327., 38., 239., 40.}); - -//Images shape is (3, 3, 4, 3) -//[1, 1, 1, 1] -//[1, 3, 2, 1] -auto exp = NDArrayFactory::create('c', {3, 1, 1, 12}, { - 11., 12., 13., 12., 13., 14., 1., 2., 3., 2., 3., 4., - 9., 8., 7., 6., 5., 4., 3., 2., 1., 0., 1., 2., - 211., 12., 13., 12., 213., 14., 21., 2., 3., 2., 3., 24. - }); -// ---------------------------------------------------------------- - sd::ops::extract_image_patches op; - - auto result = op.evaluate({&x}, {}, {2,2, 3,3, 1,1,0}); - ASSERT_EQ(result.status(), Status::OK()); - - ASSERT_TRUE(exp.isSameShape(result.at(0))); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - - + auto x = NDArrayFactory::create( + 'c', {3, 3, 4, 3}, + {11., 12., 13., 12., 13., 14., 15., 16., 17., 18., 19., 10., + 1., 2., 3., 2., 3., 4., 21., 22., 23., 22., 23., 24., + 5., 6., 7., 8., 9., 0., 35., 36., 37., 38., 39., 40., + 9., 8., 7., 6., 5., 4., 49., 48., 47., 46., 45., 44., + 3., 2., 1., 0., 1., 2., 53., 52., 51., 50., 51., 52., + 15., 16., 17., 18., 19., 10., 135., 136., 137., 138., 139., 140., + 211., 12., 13., 12., 213., 14., 15., 216., 17., 128., 19., 10., + 21., 2., 3., 2., 3., 24., 21., 22., 223., 22., 223., 24., + 25., 6., 7., 8., 9., 20., 35., 36., 327., 38., 239., 40.}); + + // Images shape is (3, 3, 4, 3) + //[1, 1, 1, 1] + //[1, 3, 2, 1] + auto exp = NDArrayFactory::create( + 'c', {3, 1, 1, 12}, + {11., 12., 13., 12., 13., 14., 1., 2., 3., 2., 3., 4., + 9., 8., 7., 6., 5., 4., 3., 2., 1., 0., 1., 2., + 211., 12., 13., 12., 213., 14., 21., 2., 3., 2., 3., 24.}); + // ---------------------------------------------------------------- + sd::ops::extract_image_patches op; + + auto result = op.evaluate({&x}, {}, {2, 2, 3, 3, 1, 1, 0}); + ASSERT_EQ(result.status(), Status::OK()); + + ASSERT_TRUE(exp.isSameShape(result.at(0))); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } - //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestExtractImagePatches_3) { - auto x = NDArrayFactory::create('c', {3, 3, 4, 3}, { - 11., 12., 13., 12., 13., 14., 15., 16., 17., 18., 19., 10., - 1., 2., 3., 2., 3., 4., 21., 22., 23., 22., 23., 24., - 5., 6., 7., 8., 9., 0., 35., 36., 37., 38., 39., 40., - 9., 8., 7., 6., 5., 4., 49., 48., 47., 46., 45., 44., - 3., 2., 1., 0., 1., 2., 53., 52., 51., 50., 51., 52., - 15., 16., 17., 18., 19., 10., 135., 136., 137., 138., 139., 140., - 211., 12., 13., 12., 213., 14., 15., 216., 17., 128., 19., 10., - 21., 2., 3., 2., 3., 24., 21., 22., 223., 22., 223., 24., - 25., 6., 7., 8., 9., 20., 35., 36., 327., 38., 239., 40.}); - -//Images shape is (3, 3, 4, 3) -//[1, 1, 1, 1] -//[1, 3, 2, 1] -auto exp = NDArrayFactory::create('c', {3, 1, 2, 6}, { - 11., 12., 13., 5., 6., 7., 15., 16., 17., 35., 36., 37., 9., 8., - 7., 15., 16., 17., 49., 48., 47., 135., 136., 137., 211., 12., 13., 25., - 6., 7., 15., 216., 17., 35., 36., 327. - }); -// ---------------------------------------------------------------- - sd::ops::extract_image_patches op; - - auto result = op.evaluate({&x}, {}, {2,1,3,2,2,2,0}); - ASSERT_EQ(result.status(), Status::OK()); - - ASSERT_TRUE(exp.isSameShape(result.at(0))); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - - + auto x = NDArrayFactory::create( + 'c', {3, 3, 4, 3}, + {11., 12., 13., 12., 13., 14., 15., 16., 17., 18., 19., 10., + 1., 2., 3., 2., 3., 4., 21., 22., 23., 22., 23., 24., + 5., 6., 7., 8., 9., 0., 35., 36., 37., 38., 39., 40., + 9., 8., 7., 6., 5., 4., 49., 48., 47., 46., 45., 44., + 3., 2., 1., 0., 1., 2., 53., 52., 51., 50., 51., 52., + 15., 16., 17., 18., 19., 10., 135., 136., 137., 138., 139., 140., + 211., 12., 13., 12., 213., 14., 15., 216., 17., 128., 19., 10., + 21., 2., 3., 2., 3., 24., 21., 22., 223., 22., 223., 24., + 25., 6., 7., 8., 9., 20., 35., 36., 327., 38., 239., 40.}); + + // Images shape is (3, 3, 4, 3) + //[1, 1, 1, 1] + //[1, 3, 2, 1] + auto exp = NDArrayFactory::create( + 'c', {3, 1, 2, 6}, + {11., 12., 13., 5., 6., 7., 15., 16., 17., 35., 36., 37., + 9., 8., 7., 15., 16., 17., 49., 48., 47., 135., 136., 137., + 211., 12., 13., 25., 6., 7., 15., 216., 17., 35., 36., 327.}); + // ---------------------------------------------------------------- + sd::ops::extract_image_patches op; + + auto result = op.evaluate({&x}, {}, {2, 1, 3, 2, 2, 2, 0}); + ASSERT_EQ(result.status(), Status::OK()); + + ASSERT_TRUE(exp.isSameShape(result.at(0))); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } - //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestExtractImagePatches_4) { - auto x = NDArrayFactory::create('c', {3, 3, 4, 3}, { - 11., 12., 13., 12., 13., 14., 15., 16., 17., 18., 19., 10., - 1., 2., 3., 2., 3., 4., 21., 22., 23., 22., 23., 24., - 5., 6., 7., 8., 9., 0., 35., 36., 37., 38., 39., 40., - 9., 8., 7., 6., 5., 4., 49., 48., 47., 46., 45., 44., - 3., 2., 1., 0., 1., 2., 53., 52., 51., 50., 51., 52., - 15., 16., 17., 18., 19., 10., 135., 136., 137., 138., 139., 140., - 211., 12., 13., 12., 213., 14., 15., 216., 17., 128., 19., 10., - 21., 2., 3., 2., 3., 24., 21., 22., 223., 22., 223., 24., - 25., 6., 7., 8., 9., 20., 35., 36., 327., 38., 239., 40.}); - -//Images shape is (3, 3, 4, 3) -//[1, 1, 1, 1] -//[1, 3, 2, 1] -auto exp = NDArrayFactory::create('c', {3, 3, 4, 3}, { - 11., 12., 13., 12., 13., 14., 15., 16., 17., 18., 19., 10., - 1., 2., 3., 2., 3., 4., 21., 22., 23., 22., 23., 24., - 5., 6., 7., 8., 9., 0., 35., 36., 37., 38., 39., 40., - 9., 8., 7., 6., 5., 4., 49., 48., 47., 46., 45., 44., - 3., 2., 1., 0., 1., 2., 53., 52., 51., 50., 51., 52., - 15., 16., 17., 18., 19., 10., 135., 136., 137., 138., 139., 140., - 211., 12., 13., 12., 213., 14., 15., 216., 17., 128., 19., 10., - 21., 2., 3., 2., 3., 24., 21., 22., 223., 22., 223., 24., - 25., 6., 7., 8., 9., 20., 35., 36., 327., 38., 239., 40.}); -// ---------------------------------------------------------------- - sd::ops::extract_image_patches op; - - auto result = op.evaluate({&x}, {}, {1,1,1,1,1,1,0}); - ASSERT_EQ(result.status(), Status::OK()); -// result.at(0)->printIndexedBuffer("Output"); - //result.at(0)->printShapeInfo("Out Shape"); -// exp.printIndexedBuffer("Expect"); - //exp.printShapeInfo("Exp Shape"); - ASSERT_TRUE(exp.isSameShape(result.at(0))); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - - + auto x = NDArrayFactory::create( + 'c', {3, 3, 4, 3}, + {11., 12., 13., 12., 13., 14., 15., 16., 17., 18., 19., 10., + 1., 2., 3., 2., 3., 4., 21., 22., 23., 22., 23., 24., + 5., 6., 7., 8., 9., 0., 35., 36., 37., 38., 39., 40., + 9., 8., 7., 6., 5., 4., 49., 48., 47., 46., 45., 44., + 3., 2., 1., 0., 1., 2., 53., 52., 51., 50., 51., 52., + 15., 16., 17., 18., 19., 10., 135., 136., 137., 138., 139., 140., + 211., 12., 13., 12., 213., 14., 15., 216., 17., 128., 19., 10., + 21., 2., 3., 2., 3., 24., 21., 22., 223., 22., 223., 24., + 25., 6., 7., 8., 9., 20., 35., 36., 327., 38., 239., 40.}); + + // Images shape is (3, 3, 4, 3) + //[1, 1, 1, 1] + //[1, 3, 2, 1] + auto exp = NDArrayFactory::create( + 'c', {3, 3, 4, 3}, + {11., 12., 13., 12., 13., 14., 15., 16., 17., 18., 19., 10., + 1., 2., 3., 2., 3., 4., 21., 22., 23., 22., 23., 24., + 5., 6., 7., 8., 9., 0., 35., 36., 37., 38., 39., 40., + 9., 8., 7., 6., 5., 4., 49., 48., 47., 46., 45., 44., + 3., 2., 1., 0., 1., 2., 53., 52., 51., 50., 51., 52., + 15., 16., 17., 18., 19., 10., 135., 136., 137., 138., 139., 140., + 211., 12., 13., 12., 213., 14., 15., 216., 17., 128., 19., 10., + 21., 2., 3., 2., 3., 24., 21., 22., 223., 22., 223., 24., + 25., 6., 7., 8., 9., 20., 35., 36., 327., 38., 239., 40.}); + // ---------------------------------------------------------------- + sd::ops::extract_image_patches op; + + auto result = op.evaluate({&x}, {}, {1, 1, 1, 1, 1, 1, 0}); + ASSERT_EQ(result.status(), Status::OK()); + // result.at(0)->printIndexedBuffer("Output"); + // result.at(0)->printShapeInfo("Out Shape"); + // exp.printIndexedBuffer("Expect"); + // exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.isSameShape(result.at(0))); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestExtractImagePatches_5) { - auto x = NDArrayFactory::create('c', {3, 3, 4, 3}, { - 11., 12., 13., 12., 13., 14., 15., 16., 17., 18., 19., 10., - 1., 2., 3., 2., 3., 4., 21., 22., 23., 22., 23., 24., - 5., 6., 7., 8., 9., 0., 35., 36., 37., 38., 39., 40., - 9., 8., 7., 6., 5., 4., 49., 48., 47., 46., 45., 44., - 3., 2., 1., 0., 1., 2., 53., 52., 51., 50., 51., 52., - 15., 16., 17., 18., 19., 10., 135., 136., 137., 138., 139., 140., -211., 12., 13., 12., 213., 14., 15., 216., 17., 128., 19., 10., - 21., 2., 3., 2., 3., 24., 21., 22., 223., 22., 223., 24., - 25., 6., 7., 8., 9., 20., 35., 36., 327., 38., 239., 40.}); - -//Images shape is (3, 3, 4, 3) -//[1, 1, 1, 1] -//[1, 3, 2, 1] -auto exp = NDArrayFactory::create('c', {3, 1, 1, 18}, { - 11., 12., 13., 15., 16., 17., 1., 2., 3., 21., 22., 23., 5., 6., 7., 35., 36., 37., - 9., 8., 7., 49., 48., 47., 3., 2., 1., 53., 52., 51., 15., 16., 17., 135., 136., 137., - 211., 12., 13., 15., 216., 17., 21., 2., 3., 21., 22., 223., 25., 6., 7., 35., 36., 327. - -//Patch shape is (3, 1, 2, 18) - - }); -// ---------------------------------------------------------------- - sd::ops::extract_image_patches op; - - auto result = op.evaluate({&x}, {}, {3,2,3,2,1,2,0}); - ASSERT_EQ(result.status(), Status::OK()); -// result.at(0)->printIndexedBuffer("Output"); - //result.at(0)->printShapeInfo("Out Shape"); -// exp.printIndexedBuffer("Expect"); - //exp.printShapeInfo("Exp Shape"); - ASSERT_TRUE(exp.isSameShape(result.at(0))); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - - + auto x = NDArrayFactory::create( + 'c', {3, 3, 4, 3}, + {11., 12., 13., 12., 13., 14., 15., 16., 17., 18., 19., 10., + 1., 2., 3., 2., 3., 4., 21., 22., 23., 22., 23., 24., + 5., 6., 7., 8., 9., 0., 35., 36., 37., 38., 39., 40., + 9., 8., 7., 6., 5., 4., 49., 48., 47., 46., 45., 44., + 3., 2., 1., 0., 1., 2., 53., 52., 51., 50., 51., 52., + 15., 16., 17., 18., 19., 10., 135., 136., 137., 138., 139., 140., + 211., 12., 13., 12., 213., 14., 15., 216., 17., 128., 19., 10., + 21., 2., 3., 2., 3., 24., 21., 22., 223., 22., 223., 24., + 25., 6., 7., 8., 9., 20., 35., 36., 327., 38., 239., 40.}); + + // Images shape is (3, 3, 4, 3) + //[1, 1, 1, 1] + //[1, 3, 2, 1] + auto exp = NDArrayFactory::create( + 'c', {3, 1, 1, 18}, + { + 11., 12., 13., 15., 16., 17., 1., 2., 3., 21., 22., + 23., 5., 6., 7., 35., 36., 37., 9., 8., 7., 49., + 48., 47., 3., 2., 1., 53., 52., 51., 15., 16., 17., + 135., 136., 137., 211., 12., 13., 15., 216., 17., 21., 2., + 3., 21., 22., 223., 25., 6., 7., 35., 36., 327. + + // Patch shape is (3, 1, 2, 18) + + }); + // ---------------------------------------------------------------- + sd::ops::extract_image_patches op; + + auto result = op.evaluate({&x}, {}, {3, 2, 3, 2, 1, 2, 0}); + ASSERT_EQ(result.status(), Status::OK()); + // result.at(0)->printIndexedBuffer("Output"); + // result.at(0)->printShapeInfo("Out Shape"); + // exp.printIndexedBuffer("Expect"); + // exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.isSameShape(result.at(0))); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestExtractImagePatches_6) { - auto x = NDArrayFactory::create('c', {2, 2, 4, 2}, { - 11.11, 11.12, 11.21, 11.22, 11.31, 11.32, 11.41, 11.42, 12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42, - 21.11, 21.12, 21.21, 21.22, 21.31, 21.32, 21.41, 21.42, 22.11, 22.12, 22.21, 22.22, 22.31, 22.32, 22.41, 22.42 -}); - -//Images shape is (3, 3, 4, 3) -//[1, 1, 1, 1] -//[1, 3, 2, 1] -auto exp = NDArrayFactory::create('c', {2, 1, 4, 4}, { - 11.11, 11.12, 12.11, 12.12, 11.21, 11.22, 12.21, 12.22, 11.31, 11.32, 12.31, 12.32, 11.41, 11.42, 12.41, 12.42, - 21.11, 21.12, 22.11, 22.12, 21.21, 21.22, 22.21, 22.22, 21.31, 21.32, 22.31, 22.32, 21.41, 21.42, 22.41, 22.42 - }); -// ---------------------------------------------------------------- - sd::ops::extract_image_patches op; - - auto result = op.evaluate({&x}, {}, {2,1, 1,1, 1,1,0}); - ASSERT_EQ(result.status(), Status::OK()); - ASSERT_TRUE(exp.isSameShape(result.at(0))); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - - + auto x = NDArrayFactory::create( + 'c', {2, 2, 4, 2}, + {11.11, 11.12, 11.21, 11.22, 11.31, 11.32, 11.41, 11.42, + 12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42, + 21.11, 21.12, 21.21, 21.22, 21.31, 21.32, 21.41, 21.42, + 22.11, 22.12, 22.21, 22.22, 22.31, 22.32, 22.41, 22.42}); + + // Images shape is (3, 3, 4, 3) + //[1, 1, 1, 1] + //[1, 3, 2, 1] + auto exp = NDArrayFactory::create( + 'c', {2, 1, 4, 4}, + {11.11, 11.12, 12.11, 12.12, 11.21, 11.22, 12.21, 12.22, + 11.31, 11.32, 12.31, 12.32, 11.41, 11.42, 12.41, 12.42, + 21.11, 21.12, 22.11, 22.12, 21.21, 21.22, 22.21, 22.22, + 21.31, 21.32, 22.31, 22.32, 21.41, 21.42, 22.41, 22.42}); + // ---------------------------------------------------------------- + sd::ops::extract_image_patches op; + + auto result = op.evaluate({&x}, {}, {2, 1, 1, 1, 1, 1, 0}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_TRUE(exp.isSameShape(result.at(0))); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestExtractImagePatches_SGO_7) { - auto x = NDArrayFactory::create('c', {1, 3, 3, 1}); - x.linspace(1); - -//Images shape is (1, 3, 3, 4) -//[1, 1, 1, 1] -//[1, 3, 2, 1] - auto exp = NDArrayFactory::create('c', {1, 3, 3, 4}, { - 1., 2., 4., 5., 2., 3., 5., 6., 3., 0., 6., 0., - 4., 5., 7., 8., 5., 6., 8., 9., 6., 0., 9., 0., 7., 8., 0., 0., 8., 9., 0., 0., 9., 0., 0., 0. }); -// ---------------------------------------------------------------- - sd::ops::extract_image_patches op; - - auto result = op.evaluate({&x}, {}, {2,2, 1,1, 1,1, 1}); // equiv TF ksizes=[1,2,2,1], strides=[1,1,1,1], rates=[1,1,1,1], padding="SAME" - ASSERT_EQ(result.status(), Status::OK()); - auto output = result.at(0); -// output->printBuffer("Output"); -// exp.printBuffer("Expect"); -// for (Nd4jLong e = 0; e < exp.lengthOf(); e++) -// if (exp.e(e) != output->e(e)) -// printf("%lld ", e); -// printf("\n"); - //result.at(1)->printBuffer("OUtput2"); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + auto x = NDArrayFactory::create('c', {1, 3, 3, 1}); + x.linspace(1); + + // Images shape is (1, 3, 3, 4) + //[1, 1, 1, 1] + //[1, 3, 2, 1] + auto exp = NDArrayFactory::create( + 'c', {1, 3, 3, 4}, + {1., 2., 4., 5., 2., 3., 5., 6., 3., 0., 6., 0., 4., 5., 7., 8., 5., 6., + 8., 9., 6., 0., 9., 0., 7., 8., 0., 0., 8., 9., 0., 0., 9., 0., 0., 0.}); + // ---------------------------------------------------------------- + sd::ops::extract_image_patches op; + + auto result = op.evaluate( + {&x}, {}, + {2, 2, 1, 1, 1, 1, 1}); // equiv TF ksizes=[1,2,2,1], strides=[1,1,1,1], + // rates=[1,1,1,1], padding="SAME" + ASSERT_EQ(result.status(), Status::OK()); + auto output = result.at(0); + // output->printBuffer("Output"); + // exp.printBuffer("Expect"); + // for (Nd4jLong e = 0; e < exp.lengthOf(); e++) + // if (exp.e(e) != output->e(e)) + // printf("%lld ", e); + // printf("\n"); + // result.at(1)->printBuffer("OUtput2"); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestExtractImagePatches_SGO_8) { - auto x = NDArrayFactory::create('c', {1, 3, 3, 2}); - x.linspace(1); - -//Images shape is (1, 3, 3, 4) -//[1, 1, 1, 1] -//[1, 3, 2, 1] - auto exp = NDArrayFactory::create('c', {1, 3, 3, 8}, { - 1, 2, 3, 4, 7, 8, 9, 10, 3, 4, 5, 6, 9, 10, 11, 12, 5, 6, 0, 0, 11, 12, 0, 0, - 7, 8, 9, 10, 13, 14, 15, 16, 9, 10, 11, 12, 15, 16, 17, 18, 11, 12, 0, 0, 17, 18, 0, 0, - 13, 14, 15, 16, 0, 0, 0, 0, 15, 16, 17, 18, 0, 0, 0, 0, 17, 18, 0, 0, 0, 0, 0, 0 }); -// ---------------------------------------------------------------- - sd::ops::extract_image_patches op; - - auto result = op.evaluate({&x}, {}, {2,2, 1,1, 1,1, 1}); // equiv TF ksizes=[1,2,2,1], strides=[1,1,1,1], rates=[1,1,1,1], padding="SAME" - ASSERT_EQ(result.status(), Status::OK()); - auto output = result.at(0); -// output->printBuffer("Output"); -// exp.printBuffer("Expect"); -// for (Nd4jLong e = 0; e < exp.lengthOf(); e++) -// if (exp.e(e) != output->e(e)) -// printf("%lld ", e); -// printf("\n"); - //result.at(1)->printBuffer("OUtput2"); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + auto x = NDArrayFactory::create('c', {1, 3, 3, 2}); + x.linspace(1); + + // Images shape is (1, 3, 3, 4) + //[1, 1, 1, 1] + //[1, 3, 2, 1] + auto exp = NDArrayFactory::create( + 'c', {1, 3, 3, 8}, + {1, 2, 3, 4, 7, 8, 9, 10, 3, 4, 5, 6, 9, 10, 11, 12, 5, 6, + 0, 0, 11, 12, 0, 0, 7, 8, 9, 10, 13, 14, 15, 16, 9, 10, 11, 12, + 15, 16, 17, 18, 11, 12, 0, 0, 17, 18, 0, 0, 13, 14, 15, 16, 0, 0, + 0, 0, 15, 16, 17, 18, 0, 0, 0, 0, 17, 18, 0, 0, 0, 0, 0, 0}); + // ---------------------------------------------------------------- + sd::ops::extract_image_patches op; + + auto result = op.evaluate( + {&x}, {}, + {2, 2, 1, 1, 1, 1, 1}); // equiv TF ksizes=[1,2,2,1], strides=[1,1,1,1], + // rates=[1,1,1,1], padding="SAME" + ASSERT_EQ(result.status(), Status::OK()); + auto output = result.at(0); + // output->printBuffer("Output"); + // exp.printBuffer("Expect"); + // for (Nd4jLong e = 0; e < exp.lengthOf(); e++) + // if (exp.e(e) != output->e(e)) + // printf("%lld ", e); + // printf("\n"); + // result.at(1)->printBuffer("OUtput2"); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestExtractImagePatches_SGO_9) { - auto x = NDArrayFactory::create('c', {1, 6, 6, 2}); - x.linspace(1); - -//Images shape is (1, 3, 3, 4) -//[1, 1, 1, 1] -//[1, 3, 2, 1] - auto exp = NDArrayFactory::create('c', {1, 6, 6, 18}, { - 0., 0., 0., 0., 0., 0., 0., 0., 1., 2., 3., 4., 0., 0., 13., 14., 15., 16., - 0., 0., 0., 0., 0., 0., 1., 2., 3., 4., 5., 6., 13., 14., 15., 16., 17., 18., - 0., 0., 0., 0., 0., 0., 3., 4., 5., 6., 7., 8., 15., 16., 17., 18., 19., 20., - 0., 0., 0., 0., 0., 0., 5., 6., 7., 8., 9., 10., 17., 18., 19., 20., 21., 22., - 0., 0., 0., 0., 0., 0., 7., 8., 9., 10., 11., 12., 19., 20., 21., 22., 23., 24., - 0., 0., 0., 0., 0., 0., 9., 10., 11., 12., 0., 0., 21., 22., 23., 24., 0., 0., - 0., 0., 1., 2., 3., 4., 0., 0., 13., 14., 15., 16., 0., 0., 25., 26., 27., 28., - 1., 2., 3., 4., 5., 6., 13., 14., 15., 16., 17., 18., 25., 26., 27., 28., 29., 30., - 3., 4., 5., 6., 7., 8., 15., 16., 17., 18., 19., 20., 27., 28., 29., 30., 31., 32., - 5., 6., 7., 8., 9., 10., 17., 18., 19., 20., 21., 22., 29., 30., 31., 32., 33., 34., - 7., 8., 9., 10., 11., 12., 19., 20., 21., 22., 23., 24., 31., 32., 33., 34., 35., 36., - 9., 10., 11., 12., 0., 0., 21., 22., 23., 24., 0., 0., 33., 34., 35., 36., 0., 0., - 0., 0., 13., 14., 15., 16., 0., 0., 25., 26., 27., 28., 0., 0., 37., 38., 39., 40., - 13., 14., 15., 16., 17., 18., 25., 26., 27., 28., 29., 30., 37., 38., 39., 40., 41., 42., - 15., 16., 17., 18., 19., 20., 27., 28., 29., 30., 31., 32., 39., 40., 41., 42., 43., 44., - 17., 18., 19., 20., 21., 22., 29., 30., 31., 32., 33., 34., 41., 42., 43., 44., 45., 46., - 19., 20., 21., 22., 23., 24., 31., 32., 33., 34., 35., 36., 43., 44., 45., 46., 47., 48., - 21., 22., 23., 24., 0., 0., 33., 34., 35., 36., 0., 0., 45., 46., 47., 48., 0., 0., - 0., 0., 25., 26., 27., 28., 0., 0., 37., 38., 39., 40., 0., 0., 49., 50., 51., 52., - 25., 26., 27., 28., 29., 30., 37., 38., 39., 40., 41., 42., 49., 50., 51., 52., 53., 54., - 27., 28., 29., 30., 31., 32., 39., 40., 41., 42., 43., 44., 51., 52., 53., 54., 55., 56., - 29., 30., 31., 32., 33., 34., 41., 42., 43., 44., 45., 46., 53., 54., 55., 56., 57., 58., - 31., 32., 33., 34., 35., 36., 43., 44., 45., 46., 47., 48., 55., 56., 57., 58., 59., 60., - 33., 34., 35., 36., 0., 0., 45., 46., 47., 48., 0., 0., 57., 58., 59., 60., 0., 0., - 0., 0., 37., 38., 39., 40., 0., 0., 49., 50., 51., 52., 0., 0., 61., 62., 63., 64., - 37., 38., 39., 40., 41., 42., 49., 50., 51., 52., 53., 54., 61., 62., 63., 64., 65., 66., - 39., 40., 41., 42., 43., 44., 51., 52., 53., 54., 55., 56., 63., 64., 65., 66., 67., 68., - 41., 42., 43., 44., 45., 46., 53., 54., 55., 56., 57., 58., 65., 66., 67., 68., 69., 70., - 43., 44., 45., 46., 47., 48., 55., 56., 57., 58., 59., 60., 67., 68., 69., 70., 71., 72., - 45., 46., 47., 48., 0., 0., 57., 58., 59., 60., 0., 0., 69., 70., 71., 72., 0., 0., - 0., 0., 49., 50., 51., 52., 0., 0., 61., 62., 63., 64., 0., 0., 0., 0., 0., 0., - 49., 50., 51., 52., 53., 54., 61., 62., 63., 64., 65., 66., 0., 0., 0., 0., 0., 0., - 51., 52., 53., 54., 55., 56., 63., 64., 65., 66., 67., 68., 0., 0., 0., 0., 0., 0., - 53., 54., 55., 56., 57., 58., 65., 66., 67., 68., 69., 70., 0., 0., 0., 0., 0., 0., - 55., 56., 57., 58., 59., 60., 67., 68., 69., 70., 71., 72., 0., 0., 0., 0., 0., 0., - 57., 58., 59., 60., 0., 0., 69., 70., 71., 72., 0., 0., 0., 0., 0., 0., 0., 0.}); -// ---------------------------------------------------------------- - sd::ops::extract_image_patches op; - - auto result = op.evaluate({&x}, {}, {3,3, 1,1, 1,1, 1}); // equiv TF ksizes=[1,2,2,1], strides=[1,1,1,1], rates=[1,1,1,1], padding="SAME" - ASSERT_EQ(result.status(), Status::OK()); - auto output = result.at(0); -// output->printBuffer("OutputSame"); -// exp.printBuffer("ExpectSame"); -// for (Nd4jLong e = 0; e < exp.lengthOf(); e++) -// if (exp.e(e) != output->e(e)) -// printf("%lld ", e); -// printf("\n"); - //result.at(1)->printBuffer("OUtput2"); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + auto x = NDArrayFactory::create('c', {1, 6, 6, 2}); + x.linspace(1); + + // Images shape is (1, 3, 3, 4) + //[1, 1, 1, 1] + //[1, 3, 2, 1] + auto exp = NDArrayFactory::create( + 'c', {1, 6, 6, 18}, + {0., 0., 0., 0., 0., 0., 0., 0., 1., 2., 3., 4., 0., 0., + 13., 14., 15., 16., 0., 0., 0., 0., 0., 0., 1., 2., 3., 4., + 5., 6., 13., 14., 15., 16., 17., 18., 0., 0., 0., 0., 0., 0., + 3., 4., 5., 6., 7., 8., 15., 16., 17., 18., 19., 20., 0., 0., + 0., 0., 0., 0., 5., 6., 7., 8., 9., 10., 17., 18., 19., 20., + 21., 22., 0., 0., 0., 0., 0., 0., 7., 8., 9., 10., 11., 12., + 19., 20., 21., 22., 23., 24., 0., 0., 0., 0., 0., 0., 9., 10., + 11., 12., 0., 0., 21., 22., 23., 24., 0., 0., 0., 0., 1., 2., + 3., 4., 0., 0., 13., 14., 15., 16., 0., 0., 25., 26., 27., 28., + 1., 2., 3., 4., 5., 6., 13., 14., 15., 16., 17., 18., 25., 26., + 27., 28., 29., 30., 3., 4., 5., 6., 7., 8., 15., 16., 17., 18., + 19., 20., 27., 28., 29., 30., 31., 32., 5., 6., 7., 8., 9., 10., + 17., 18., 19., 20., 21., 22., 29., 30., 31., 32., 33., 34., 7., 8., + 9., 10., 11., 12., 19., 20., 21., 22., 23., 24., 31., 32., 33., 34., + 35., 36., 9., 10., 11., 12., 0., 0., 21., 22., 23., 24., 0., 0., + 33., 34., 35., 36., 0., 0., 0., 0., 13., 14., 15., 16., 0., 0., + 25., 26., 27., 28., 0., 0., 37., 38., 39., 40., 13., 14., 15., 16., + 17., 18., 25., 26., 27., 28., 29., 30., 37., 38., 39., 40., 41., 42., + 15., 16., 17., 18., 19., 20., 27., 28., 29., 30., 31., 32., 39., 40., + 41., 42., 43., 44., 17., 18., 19., 20., 21., 22., 29., 30., 31., 32., + 33., 34., 41., 42., 43., 44., 45., 46., 19., 20., 21., 22., 23., 24., + 31., 32., 33., 34., 35., 36., 43., 44., 45., 46., 47., 48., 21., 22., + 23., 24., 0., 0., 33., 34., 35., 36., 0., 0., 45., 46., 47., 48., + 0., 0., 0., 0., 25., 26., 27., 28., 0., 0., 37., 38., 39., 40., + 0., 0., 49., 50., 51., 52., 25., 26., 27., 28., 29., 30., 37., 38., + 39., 40., 41., 42., 49., 50., 51., 52., 53., 54., 27., 28., 29., 30., + 31., 32., 39., 40., 41., 42., 43., 44., 51., 52., 53., 54., 55., 56., + 29., 30., 31., 32., 33., 34., 41., 42., 43., 44., 45., 46., 53., 54., + 55., 56., 57., 58., 31., 32., 33., 34., 35., 36., 43., 44., 45., 46., + 47., 48., 55., 56., 57., 58., 59., 60., 33., 34., 35., 36., 0., 0., + 45., 46., 47., 48., 0., 0., 57., 58., 59., 60., 0., 0., 0., 0., + 37., 38., 39., 40., 0., 0., 49., 50., 51., 52., 0., 0., 61., 62., + 63., 64., 37., 38., 39., 40., 41., 42., 49., 50., 51., 52., 53., 54., + 61., 62., 63., 64., 65., 66., 39., 40., 41., 42., 43., 44., 51., 52., + 53., 54., 55., 56., 63., 64., 65., 66., 67., 68., 41., 42., 43., 44., + 45., 46., 53., 54., 55., 56., 57., 58., 65., 66., 67., 68., 69., 70., + 43., 44., 45., 46., 47., 48., 55., 56., 57., 58., 59., 60., 67., 68., + 69., 70., 71., 72., 45., 46., 47., 48., 0., 0., 57., 58., 59., 60., + 0., 0., 69., 70., 71., 72., 0., 0., 0., 0., 49., 50., 51., 52., + 0., 0., 61., 62., 63., 64., 0., 0., 0., 0., 0., 0., 49., 50., + 51., 52., 53., 54., 61., 62., 63., 64., 65., 66., 0., 0., 0., 0., + 0., 0., 51., 52., 53., 54., 55., 56., 63., 64., 65., 66., 67., 68., + 0., 0., 0., 0., 0., 0., 53., 54., 55., 56., 57., 58., 65., 66., + 67., 68., 69., 70., 0., 0., 0., 0., 0., 0., 55., 56., 57., 58., + 59., 60., 67., 68., 69., 70., 71., 72., 0., 0., 0., 0., 0., 0., + 57., 58., 59., 60., 0., 0., 69., 70., 71., 72., 0., 0., 0., 0., + 0., 0., 0., 0.}); + // ---------------------------------------------------------------- + sd::ops::extract_image_patches op; + + auto result = op.evaluate( + {&x}, {}, + {3, 3, 1, 1, 1, 1, 1}); // equiv TF ksizes=[1,2,2,1], strides=[1,1,1,1], + // rates=[1,1,1,1], padding="SAME" + ASSERT_EQ(result.status(), Status::OK()); + auto output = result.at(0); + // output->printBuffer("OutputSame"); + // exp.printBuffer("ExpectSame"); + // for (Nd4jLong e = 0; e < exp.lengthOf(); e++) + // if (exp.e(e) != output->e(e)) + // printf("%lld ", e); + // printf("\n"); + // result.at(1)->printBuffer("OUtput2"); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestExtractImagePatches_SGO_9_1) { - auto x = NDArrayFactory::create('c', {1, 4, 4, 2}, {1, 116, 2, 116, 3, 116, 4, 116, - 5, 117, 6, 117, 7, 117, 8, 117, - 9, 118, 10, 118, 11, 118, 12, 118, - 13, 119, 14, 119, 15, 119, 16, 119}); - //x.linspace(1); - -//Images shape is (1, 3, 3, 4) -//[1, 1, 1, 1] -//[1, 3, 2, 1] - auto exp = NDArrayFactory::create('c', {1, 4, 4, 8}, { - 1, 116, 2, 116, 5, 117, 6, 117, 2, 116, 3, 116, 6, 117, 7, 117, 3, 116, - 4, 116, 7, 117, 8, 117, 4, 116, 0, 0, 8, 117, 0, 0, 5, 117, 6, 117, - 9, 118, 10, 118, 6, 117, 7, 117, 10, 118, 11, 118, 7, 117, 8, 117, 11, 118, -12, 118, 8, 117, 0, 0, 12, 118, 0, 0, 9, 118, 10, 118, 13, 119, 14, 119, -10, 118, 11, 118, 14, 119, 15, 119, 11, 118, 12, 118, 15, 119, 16, 119, 12, 118, - 0, 0, 16, 119, 0, 0, 13, 119, 14, 119, 0, 0, 0, 0, 14, 119, 15, 119, - 0, 0, 0, 0, 15, 119, 16, 119, 0, 0, 0, 0, 16, 119, 0, 0, 0, 0, - 0, 0 - - }); -// ---------------------------------------------------------------- - sd::ops::extract_image_patches op; - - auto result = op.evaluate({&x}, {}, {2,2, 1,1, 1,1, 1}); // equiv TF ksizes=[1,2,2,1], strides=[1,1,1,1], rates=[1,1,1,1], padding="SAME" - ASSERT_EQ(result.status(), Status::OK()); - auto output = result.at(0); -// output->printBuffer("OutputSame"); -// exp.printBuffer("ExpectSame"); -// for (Nd4jLong e = 0; e < exp.lengthOf(); e++) -// if (exp.e(e) != output->e(e)) -// printf("%lld ", e); -// printf("\n"); - //result.at(1)->printBuffer("OUtput2"); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + auto x = NDArrayFactory::create( + 'c', {1, 4, 4, 2}, + {1, 116, 2, 116, 3, 116, 4, 116, 5, 117, 6, 117, 7, 117, 8, 117, + 9, 118, 10, 118, 11, 118, 12, 118, 13, 119, 14, 119, 15, 119, 16, 119}); + // x.linspace(1); + + // Images shape is (1, 3, 3, 4) + //[1, 1, 1, 1] + //[1, 3, 2, 1] + auto exp = NDArrayFactory::create( + 'c', {1, 4, 4, 8}, + {1, 116, 2, 116, 5, 117, 6, 117, 2, 116, 3, 116, 6, + 117, 7, 117, 3, 116, 4, 116, 7, 117, 8, 117, 4, 116, + 0, 0, 8, 117, 0, 0, 5, 117, 6, 117, 9, 118, 10, + 118, 6, 117, 7, 117, 10, 118, 11, 118, 7, 117, 8, 117, + 11, 118, 12, 118, 8, 117, 0, 0, 12, 118, 0, 0, 9, + 118, 10, 118, 13, 119, 14, 119, 10, 118, 11, 118, 14, 119, + 15, 119, 11, 118, 12, 118, 15, 119, 16, 119, 12, 118, 0, + 0, 16, 119, 0, 0, 13, 119, 14, 119, 0, 0, 0, 0, + 14, 119, 15, 119, 0, 0, 0, 0, 15, 119, 16, 119, 0, + 0, 0, 0, 16, 119, 0, 0, 0, 0, 0, 0 + + }); + // ---------------------------------------------------------------- + sd::ops::extract_image_patches op; + + auto result = op.evaluate( + {&x}, {}, + {2, 2, 1, 1, 1, 1, 1}); // equiv TF ksizes=[1,2,2,1], strides=[1,1,1,1], + // rates=[1,1,1,1], padding="SAME" + ASSERT_EQ(result.status(), Status::OK()); + auto output = result.at(0); + // output->printBuffer("OutputSame"); + // exp.printBuffer("ExpectSame"); + // for (Nd4jLong e = 0; e < exp.lengthOf(); e++) + // if (exp.e(e) != output->e(e)) + // printf("%lld ", e); + // printf("\n"); + // result.at(1)->printBuffer("OUtput2"); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } // // //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestExtractImagePatches_SGO_10) { - auto x = NDArrayFactory::create('c', {1, 6, 6, 2}); - x.linspace(1); - -//Images shape is (1, 3, 3, 4) -//[1, 1, 1, 1] -//[1, 3, 2, 1] - auto exp = NDArrayFactory::create('c', {1, 4, 4, 18}, { - 1., 2., 3., 4., 5., 6., 13., 14., 15., 16., 17., 18., 25., 26., 27., 28., 29., 30., - 3., 4., 5., 6., 7., 8., 15., 16., 17., 18., 19., 20., 27., 28., 29., 30., 31., 32., - 5., 6., 7., 8., 9., 10., 17., 18., 19., 20., 21., 22., 29., 30., 31., 32., 33., 34., - 7., 8., 9., 10., 11., 12., 19., 20., 21., 22., 23., 24., 31., 32., 33., 34., 35., 36., - 13., 14., 15., 16., 17., 18., 25., 26., 27., 28., 29., 30., 37., 38., 39., 40., 41., 42., - 15., 16., 17., 18., 19., 20., 27., 28., 29., 30., 31., 32., 39., 40., 41., 42., 43., 44., - 17., 18., 19., 20., 21., 22., 29., 30., 31., 32., 33., 34., 41., 42., 43., 44., 45., 46., - 19., 20., 21., 22., 23., 24., 31., 32., 33., 34., 35., 36., 43., 44., 45., 46., 47., 48., - 25., 26., 27., 28., 29., 30., 37., 38., 39., 40., 41., 42., 49., 50., 51., 52., 53., 54., - 27., 28., 29., 30., 31., 32., 39., 40., 41., 42., 43., 44., 51., 52., 53., 54., 55., 56., - 29., 30., 31., 32., 33., 34., 41., 42., 43., 44., 45., 46., 53., 54., 55., 56., 57., 58., - 31., 32., 33., 34., 35., 36., 43., 44., 45., 46., 47., 48., 55., 56., 57., 58., 59., 60., - 37., 38., 39., 40., 41., 42., 49., 50., 51., 52., 53., 54., 61., 62., 63., 64., 65., 66., - 39., 40., 41., 42., 43., 44., 51., 52., 53., 54., 55., 56., 63., 64., 65., 66., 67., 68., - 41., 42., 43., 44., 45., 46., 53., 54., 55., 56., 57., 58., 65., 66., 67., 68., 69., 70., - 43., 44., 45., 46., 47., 48., 55., 56., 57., 58., 59., 60., 67., 68., 69., 70., 71., 72.}); -// ---------------------------------------------------------------- - sd::ops::extract_image_patches op; - //x.printIndexedBuffer("Images"); - //x.printBuffer("Images linear"); - auto result = op.evaluate({&x}, {}, {3,3, 1,1, 1,1, 0}); // equiv TF ksizes=[1,2,2,1], strides=[1,1,1,1], rates=[1,1,1,1], padding="VALID" - ASSERT_EQ(result.status(), Status::OK()); - auto output = result.at(0); -// output->printBuffer("OutputValid"); -// exp.printBuffer("ExpectValid"); -// for (Nd4jLong e = 0; e < exp.lengthOf(); e++) -// if (exp.e(e) != output->e(e)) -// printf("%lld ", e); -// printf("\n"); - //result.at(1)->printBuffer("OUtput2"); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + auto x = NDArrayFactory::create('c', {1, 6, 6, 2}); + x.linspace(1); + + // Images shape is (1, 3, 3, 4) + //[1, 1, 1, 1] + //[1, 3, 2, 1] + auto exp = NDArrayFactory::create( + 'c', {1, 4, 4, 18}, + {1., 2., 3., 4., 5., 6., 13., 14., 15., 16., 17., 18., 25., 26., + 27., 28., 29., 30., 3., 4., 5., 6., 7., 8., 15., 16., 17., 18., + 19., 20., 27., 28., 29., 30., 31., 32., 5., 6., 7., 8., 9., 10., + 17., 18., 19., 20., 21., 22., 29., 30., 31., 32., 33., 34., 7., 8., + 9., 10., 11., 12., 19., 20., 21., 22., 23., 24., 31., 32., 33., 34., + 35., 36., 13., 14., 15., 16., 17., 18., 25., 26., 27., 28., 29., 30., + 37., 38., 39., 40., 41., 42., 15., 16., 17., 18., 19., 20., 27., 28., + 29., 30., 31., 32., 39., 40., 41., 42., 43., 44., 17., 18., 19., 20., + 21., 22., 29., 30., 31., 32., 33., 34., 41., 42., 43., 44., 45., 46., + 19., 20., 21., 22., 23., 24., 31., 32., 33., 34., 35., 36., 43., 44., + 45., 46., 47., 48., 25., 26., 27., 28., 29., 30., 37., 38., 39., 40., + 41., 42., 49., 50., 51., 52., 53., 54., 27., 28., 29., 30., 31., 32., + 39., 40., 41., 42., 43., 44., 51., 52., 53., 54., 55., 56., 29., 30., + 31., 32., 33., 34., 41., 42., 43., 44., 45., 46., 53., 54., 55., 56., + 57., 58., 31., 32., 33., 34., 35., 36., 43., 44., 45., 46., 47., 48., + 55., 56., 57., 58., 59., 60., 37., 38., 39., 40., 41., 42., 49., 50., + 51., 52., 53., 54., 61., 62., 63., 64., 65., 66., 39., 40., 41., 42., + 43., 44., 51., 52., 53., 54., 55., 56., 63., 64., 65., 66., 67., 68., + 41., 42., 43., 44., 45., 46., 53., 54., 55., 56., 57., 58., 65., 66., + 67., 68., 69., 70., 43., 44., 45., 46., 47., 48., 55., 56., 57., 58., + 59., 60., 67., 68., 69., 70., 71., 72.}); + // ---------------------------------------------------------------- + sd::ops::extract_image_patches op; + // x.printIndexedBuffer("Images"); + // x.printBuffer("Images linear"); + auto result = op.evaluate( + {&x}, {}, + {3, 3, 1, 1, 1, 1, 0}); // equiv TF ksizes=[1,2,2,1], strides=[1,1,1,1], + // rates=[1,1,1,1], padding="VALID" + ASSERT_EQ(result.status(), Status::OK()); + auto output = result.at(0); + // output->printBuffer("OutputValid"); + // exp.printBuffer("ExpectValid"); + // for (Nd4jLong e = 0; e < exp.lengthOf(); e++) + // if (exp.e(e) != output->e(e)) + // printf("%lld ", e); + // printf("\n"); + // result.at(1)->printBuffer("OUtput2"); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } TEST_F(DeclarableOpsTests7, TestExtractImagePatches_SGO_010) { - auto x = NDArrayFactory::create('c', {1, 4, 4, 1}); - x.linspace(1); - -//Images shape is (1, 3, 3, 4) -//[1, 1, 1, 1] -//[1, 3, 2, 1] - auto exp = NDArrayFactory::create('c', {1, 3, 3, 4}, { - 1, 2, 5, 6, 2, 3, 6, 7, 3, 4, 7, 8, 5, 6, 9, 10, 6, 7, 10, 11, 7, 8, 11, 12, - 9, 10, 13, 14, 10, 11, 14, 15, 11, 12, 15, 16}); -// ---------------------------------------------------------------- - sd::ops::extract_image_patches op; - //x.printIndexedBuffer("Images"); - //x.printBuffer("Images linear"); - auto result = op.evaluate({&x}, {}, {2,2, 1,1, 1,1, 0}); // equiv TF ksizes=[1,2,2,1], strides=[1,1,1,1], rates=[1,1,1,1], padding="VALID" - ASSERT_EQ(result.status(), Status::OK()); - auto output = result.at(0); -// output->printBuffer("OutputValid"); -// exp.printBuffer("ExpectValid"); -// for (Nd4jLong e = 0; e < exp.lengthOf(); e++) -// if (exp.e(e) != output->e(e)) -// printf("%lld ", e); -// printf("\n"); - //result.at(1)->printBuffer("OUtput2"); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + auto x = NDArrayFactory::create('c', {1, 4, 4, 1}); + x.linspace(1); + + // Images shape is (1, 3, 3, 4) + //[1, 1, 1, 1] + //[1, 3, 2, 1] + auto exp = NDArrayFactory::create( + 'c', {1, 3, 3, 4}, + {1, 2, 5, 6, 2, 3, 6, 7, 3, 4, 7, 8, 5, 6, 9, 10, 6, 7, + 10, 11, 7, 8, 11, 12, 9, 10, 13, 14, 10, 11, 14, 15, 11, 12, 15, 16}); + // ---------------------------------------------------------------- + sd::ops::extract_image_patches op; + // x.printIndexedBuffer("Images"); + // x.printBuffer("Images linear"); + auto result = op.evaluate( + {&x}, {}, + {2, 2, 1, 1, 1, 1, 0}); // equiv TF ksizes=[1,2,2,1], strides=[1,1,1,1], + // rates=[1,1,1,1], padding="VALID" + ASSERT_EQ(result.status(), Status::OK()); + auto output = result.at(0); + // output->printBuffer("OutputValid"); + // exp.printBuffer("ExpectValid"); + // for (Nd4jLong e = 0; e < exp.lengthOf(); e++) + // if (exp.e(e) != output->e(e)) + // printf("%lld ", e); + // printf("\n"); + // result.at(1)->printBuffer("OUtput2"); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } TEST_F(DeclarableOpsTests7, TestExtractImagePatches_SGO_010_1) { - auto x = NDArrayFactory::create('c', {1, 4, 4, 1}); - x.linspace(1); - -//Images shape is (1, 3, 3, 4) -//[1, 1, 1, 1] -//[1, 3, 2, 1] - auto exp = NDArrayFactory::create('c', {1, 4, 4, 4}, { - 1, 2, 5, 6, 2, 3, 6, 7, 3, 4, 7, 8, 4, 0, 8, 0, 5, 6, 9, 10, 6, 7, 10, 11, - 7, 8, 11, 12, 8, 0, 12, 0, 9, 10, 13, 14, 10, 11, 14, 15, 11, 12, 15, 16, 12, 0, 16, 0, - 13, 14, 0, 0, 14, 15, 0, 0, 15, 16, 0, 0, 16, 0, 0, 0}); -// ---------------------------------------------------------------- - sd::ops::extract_image_patches op; - //x.printIndexedBuffer("Images"); - //x.printBuffer("Images linear"); - auto result = op.evaluate({&x}, {}, {2,2, 1,1, 1,1, 1}); // equiv TF ksizes=[1,2,2,1], strides=[1,1,1,1], rates=[1,1,1,1], padding="VALID" - ASSERT_EQ(result.status(), Status::OK()); - auto output = result.at(0); -// output->printBuffer("OutputSame"); -// exp.printBuffer("ExpectSame"); -// exp.printIndexedBuffer("Expect Same Formatted"); -// output->printIndexedBuffer("Output Same Formatted"); -// for (Nd4jLong e = 0; e < exp.lengthOf(); e++) -// if (exp.e(e) != output->e(e)) -// printf("%lld ", e); -// printf("\n"); - //result.at(1)->printBuffer("OUtput2"); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + auto x = NDArrayFactory::create('c', {1, 4, 4, 1}); + x.linspace(1); + + // Images shape is (1, 3, 3, 4) + //[1, 1, 1, 1] + //[1, 3, 2, 1] + auto exp = NDArrayFactory::create( + 'c', {1, 4, 4, 4}, + {1, 2, 5, 6, 2, 3, 6, 7, 3, 4, 7, 8, 4, 0, 8, 0, + 5, 6, 9, 10, 6, 7, 10, 11, 7, 8, 11, 12, 8, 0, 12, 0, + 9, 10, 13, 14, 10, 11, 14, 15, 11, 12, 15, 16, 12, 0, 16, 0, + 13, 14, 0, 0, 14, 15, 0, 0, 15, 16, 0, 0, 16, 0, 0, 0}); + // ---------------------------------------------------------------- + sd::ops::extract_image_patches op; + // x.printIndexedBuffer("Images"); + // x.printBuffer("Images linear"); + auto result = op.evaluate( + {&x}, {}, + {2, 2, 1, 1, 1, 1, 1}); // equiv TF ksizes=[1,2,2,1], strides=[1,1,1,1], + // rates=[1,1,1,1], padding="VALID" + ASSERT_EQ(result.status(), Status::OK()); + auto output = result.at(0); + // output->printBuffer("OutputSame"); + // exp.printBuffer("ExpectSame"); + // exp.printIndexedBuffer("Expect Same Formatted"); + // output->printIndexedBuffer("Output Same Formatted"); + // for (Nd4jLong e = 0; e < exp.lengthOf(); e++) + // if (exp.e(e) != output->e(e)) + // printf("%lld ", e); + // printf("\n"); + // result.at(1)->printBuffer("OUtput2"); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } TEST_F(DeclarableOpsTests7, TestExtractImagePatches_SGO_011) { - auto x = NDArrayFactory::create('c', {1, 4, 4, 1}); - x.linspace(1); - -//Images shape is (1, 3, 3, 4) -//[1, 1, 1, 1] -//[1, 3, 2, 1] - auto exp = NDArrayFactory::create('c', {1, 2, 2, 4}, { - 1, 3, 9, 11, 2, 4, 10, 12, 5, 7, 13, 15, 6, 8, 14, 16, - }); -// ---------------------------------------------------------------- - sd::ops::extract_image_patches op; - //x.printIndexedBuffer("Images"); - //x.printBuffer("Images linear"); - auto result = op.evaluate({&x}, {}, {2,2, 1,1, 2,2, 0}); // equiv TF ksizes=[1,2,2,1], strides=[1,1,1,1], rates=[1,1,1,1], padding="VALID" - ASSERT_EQ(result.status(), Status::OK()); - auto output = result.at(0); -// output->printBuffer("OutputValid"); -// exp.printBuffer("ExpectValid"); -// for (Nd4jLong e = 0; e < exp.lengthOf(); e++) -// if (exp.e(e) != output->e(e)) -// printf("%lld ", e); -// printf("\n"); - //result.at(1)->printBuffer("OUtput2"); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + auto x = NDArrayFactory::create('c', {1, 4, 4, 1}); + x.linspace(1); + + // Images shape is (1, 3, 3, 4) + //[1, 1, 1, 1] + //[1, 3, 2, 1] + auto exp = NDArrayFactory::create('c', {1, 2, 2, 4}, + { + 1, + 3, + 9, + 11, + 2, + 4, + 10, + 12, + 5, + 7, + 13, + 15, + 6, + 8, + 14, + 16, + }); + // ---------------------------------------------------------------- + sd::ops::extract_image_patches op; + // x.printIndexedBuffer("Images"); + // x.printBuffer("Images linear"); + auto result = op.evaluate( + {&x}, {}, + {2, 2, 1, 1, 2, 2, 0}); // equiv TF ksizes=[1,2,2,1], strides=[1,1,1,1], + // rates=[1,1,1,1], padding="VALID" + ASSERT_EQ(result.status(), Status::OK()); + auto output = result.at(0); + // output->printBuffer("OutputValid"); + // exp.printBuffer("ExpectValid"); + // for (Nd4jLong e = 0; e < exp.lengthOf(); e++) + // if (exp.e(e) != output->e(e)) + // printf("%lld ", e); + // printf("\n"); + // result.at(1)->printBuffer("OUtput2"); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestExtractImagePatches_SGO_11) { - auto x = NDArrayFactory::create('c', {1, 8, 8, 2}); - x.linspace(1); - -//Images shape is (1, 3, 3, 4) -//[1, 1, 1, 1] -//[1, 3, 2, 1] - auto exp = NDArrayFactory::create('c', {1, 4, 4, 8}, { - 1, 2, 3, 4, 17, 18, 19, 20, 5, 6, 7, 8, 21, 22, 23, 24, 9, 10, - 11, 12, 25, 26, 27, 28, 13, 14, 15, 16, 29, 30, 31, 32, 33, 34, 35, 36, - 49, 50, 51, 52, 37, 38, 39, 40, 53, 54, 55, 56, 41, 42, 43, 44, 57, 58, - 59, 60, 45, 46, 47, 48, 61, 62, 63, 64, 65, 66, 67, 68, 81, 82, 83, 84, - 69, 70, 71, 72, 85, 86, 87, 88, 73, 74, 75, 76, 89, 90, 91, 92, 77, 78, - 79, 80, 93, 94, 95, 96, 97, 98, 99, 100, 113, 114, 115, 116, 101, 102, 103, 104, - 117, 118, 119, 120, 105, 106, 107, 108, 121, 122, 123, 124, 109, 110, 111, 112, 125, 126, - 127, 128}); -// ---------------------------------------------------------------- - sd::ops::extract_image_patches op; - - auto result = op.evaluate({&x}, {}, {2,2, 2,2, 1,1, 1}); // equiv TF ksizes=[1,2,2,1], strides=[1,1,1,1], rates=[1,1,1,1], padding="SAME" - ASSERT_EQ(result.status(), Status::OK()); - auto output = result.at(0); -// output->printBuffer("Output"); -// exp.printBuffer("Expect"); -// for (Nd4jLong e = 0; e < exp.lengthOf(); e++) -// if (exp.e(e) != output->e(e)) -// printf("%lld ", e); -// printf("\n"); - //result.at(1)->printBuffer("OUtput2"); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + auto x = NDArrayFactory::create('c', {1, 8, 8, 2}); + x.linspace(1); + + // Images shape is (1, 3, 3, 4) + //[1, 1, 1, 1] + //[1, 3, 2, 1] + auto exp = NDArrayFactory::create( + 'c', {1, 4, 4, 8}, + {1, 2, 3, 4, 17, 18, 19, 20, 5, 6, 7, 8, 21, + 22, 23, 24, 9, 10, 11, 12, 25, 26, 27, 28, 13, 14, + 15, 16, 29, 30, 31, 32, 33, 34, 35, 36, 49, 50, 51, + 52, 37, 38, 39, 40, 53, 54, 55, 56, 41, 42, 43, 44, + 57, 58, 59, 60, 45, 46, 47, 48, 61, 62, 63, 64, 65, + 66, 67, 68, 81, 82, 83, 84, 69, 70, 71, 72, 85, 86, + 87, 88, 73, 74, 75, 76, 89, 90, 91, 92, 77, 78, 79, + 80, 93, 94, 95, 96, 97, 98, 99, 100, 113, 114, 115, 116, + 101, 102, 103, 104, 117, 118, 119, 120, 105, 106, 107, 108, 121, + 122, 123, 124, 109, 110, 111, 112, 125, 126, 127, 128}); + // ---------------------------------------------------------------- + sd::ops::extract_image_patches op; + + auto result = op.evaluate( + {&x}, {}, + {2, 2, 2, 2, 1, 1, 1}); // equiv TF ksizes=[1,2,2,1], strides=[1,1,1,1], + // rates=[1,1,1,1], padding="SAME" + ASSERT_EQ(result.status(), Status::OK()); + auto output = result.at(0); + // output->printBuffer("Output"); + // exp.printBuffer("Expect"); + // for (Nd4jLong e = 0; e < exp.lengthOf(); e++) + // if (exp.e(e) != output->e(e)) + // printf("%lld ", e); + // printf("\n"); + // result.at(1)->printBuffer("OUtput2"); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestExtractImagePatches_SGO_12) { - auto x = NDArrayFactory::create('c', {1, 8, 8, 2}); - x.linspace(1); - -//Images shape is (1, 3, 3, 4) -//[1, 1, 1, 1] -//[1, 3, 2, 1] - auto exp = NDArrayFactory::create('c', {1, 8, 8, 8}, { - 0, 0, 0, 0, 0, 0, 19, 20, 0, 0, 0, 0, 17, 18, 21, 22, 0, 0, - 0, 0, 19, 20, 23, 24, 0, 0, 0, 0, 21, 22, 25, 26, 0, 0, 0, 0, - 23, 24, 27, 28, 0, 0, 0, 0, 25, 26, 29, 30, 0, 0, 0, 0, 27, 28, - 31, 32, 0, 0, 0, 0, 29, 30, 0, 0, 0, 0, 3, 4, 0, 0, 35, 36, - 1, 2, 5, 6, 33, 34, 37, 38, 3, 4, 7, 8, 35, 36, 39, 40, 5, 6, - 9, 10, 37, 38, 41, 42, 7, 8, 11, 12, 39, 40, 43, 44, 9, 10, 13, 14, - 41, 42, 45, 46, 11, 12, 15, 16, 43, 44, 47, 48, 13, 14, 0, 0, 45, 46, - 0, 0, 0, 0, 19, 20, 0, 0, 51, 52, 17, 18, 21, 22, 49, 50, 53, 54, - 19, 20, 23, 24, 51, 52, 55, 56, 21, 22, 25, 26, 53, 54, 57, 58, 23, 24, - 27, 28, 55, 56, 59, 60, 25, 26, 29, 30, 57, 58, 61, 62, 27, 28, 31, 32, - 59, 60, 63, 64, 29, 30, 0, 0, 61, 62, 0, 0, 0, 0, 35, 36, 0, 0, - 67, 68, 33, 34, 37, 38, 65, 66, 69, 70, 35, 36, 39, 40, 67, 68, 71, 72, - 37, 38, 41, 42, 69, 70, 73, 74, 39, 40, 43, 44, 71, 72, 75, 76, 41, 42, - 45, 46, 73, 74, 77, 78, 43, 44, 47, 48, 75, 76, 79, 80, 45, 46, 0, 0, - 77, 78, 0, 0, 0, 0, 51, 52, 0, 0, 83, 84, 49, 50, 53, 54, 81, 82, - 85, 86, 51, 52, 55, 56, 83, 84, 87, 88, 53, 54, 57, 58, 85, 86, 89, 90, - 55, 56, 59, 60, 87, 88, 91, 92, 57, 58, 61, 62, 89, 90, 93, 94, 59, 60, - 63, 64, 91, 92, 95, 96, 61, 62, 0, 0, 93, 94, 0, 0, 0, 0, 67, 68, - 0, 0, 99, 100, 65, 66, 69, 70, 97, 98, 101, 102, 67, 68, 71, 72, 99, 100, - 103, 104, 69, 70, 73, 74, 101, 102, 105, 106, 71, 72, 75, 76, 103, 104, 107, 108, - 73, 74, 77, 78, 105, 106, 109, 110, 75, 76, 79, 80, 107, 108, 111, 112, 77, 78, - 0, 0, 109, 110, 0, 0, 0, 0, 83, 84, 0, 0, 115, 116, 81, 82, 85, 86, - 113, 114, 117, 118, 83, 84, 87, 88, 115, 116, 119, 120, 85, 86, 89, 90, 117, 118, - 121, 122, 87, 88, 91, 92, 119, 120, 123, 124, 89, 90, 93, 94, 121, 122, 125, 126, - 91, 92, 95, 96, 123, 124, 127, 128, 93, 94, 0, 0, 125, 126, 0, 0, 0, 0, - 99, 100, 0, 0, 0, 0, 97, 98, 101, 102, 0, 0, 0, 0, 99, 100, 103, 104, - 0, 0, 0, 0, 101, 102, 105, 106, 0, 0, 0, 0, 103, 104, 107, 108, 0, 0, - 0, 0, 105, 106, 109, 110, 0, 0, 0, 0, 107, 108, 111, 112, 0, 0, 0, 0, - 109, 110, 0, 0, 0, 0, 0, 0}); -// ---------------------------------------------------------------- - sd::ops::extract_image_patches op; - - auto result = op.evaluate({&x}, {}, {2,2, 1,1, 2,2, 1}); // equiv TF ksizes=[1,2,2,1], strides=[1,1,1,1], rates=[1,2,2,1], padding="SAME" - ASSERT_EQ(result.status(), Status::OK()); - auto output = result.at(0); - //output->printShapeInfo("Output shape"); -// output->printIndexedBuffer("Output"); -// exp.printBuffer("Expect"); -// for (Nd4jLong e = 0; e < exp.lengthOf(); e++) -// if (exp.e(e) != output->e(e)) -// printf("%lld ", e); -// printf("\n"); - //result.at(1)->printBuffer("OUtput2"); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + auto x = NDArrayFactory::create('c', {1, 8, 8, 2}); + x.linspace(1); + + // Images shape is (1, 3, 3, 4) + //[1, 1, 1, 1] + //[1, 3, 2, 1] + auto exp = NDArrayFactory::create( + 'c', {1, 8, 8, 8}, + {0, 0, 0, 0, 0, 0, 19, 20, 0, 0, 0, 0, 17, 18, + 21, 22, 0, 0, 0, 0, 19, 20, 23, 24, 0, 0, 0, 0, + 21, 22, 25, 26, 0, 0, 0, 0, 23, 24, 27, 28, 0, 0, + 0, 0, 25, 26, 29, 30, 0, 0, 0, 0, 27, 28, 31, 32, + 0, 0, 0, 0, 29, 30, 0, 0, 0, 0, 3, 4, 0, 0, + 35, 36, 1, 2, 5, 6, 33, 34, 37, 38, 3, 4, 7, 8, + 35, 36, 39, 40, 5, 6, 9, 10, 37, 38, 41, 42, 7, 8, + 11, 12, 39, 40, 43, 44, 9, 10, 13, 14, 41, 42, 45, 46, + 11, 12, 15, 16, 43, 44, 47, 48, 13, 14, 0, 0, 45, 46, + 0, 0, 0, 0, 19, 20, 0, 0, 51, 52, 17, 18, 21, 22, + 49, 50, 53, 54, 19, 20, 23, 24, 51, 52, 55, 56, 21, 22, + 25, 26, 53, 54, 57, 58, 23, 24, 27, 28, 55, 56, 59, 60, + 25, 26, 29, 30, 57, 58, 61, 62, 27, 28, 31, 32, 59, 60, + 63, 64, 29, 30, 0, 0, 61, 62, 0, 0, 0, 0, 35, 36, + 0, 0, 67, 68, 33, 34, 37, 38, 65, 66, 69, 70, 35, 36, + 39, 40, 67, 68, 71, 72, 37, 38, 41, 42, 69, 70, 73, 74, + 39, 40, 43, 44, 71, 72, 75, 76, 41, 42, 45, 46, 73, 74, + 77, 78, 43, 44, 47, 48, 75, 76, 79, 80, 45, 46, 0, 0, + 77, 78, 0, 0, 0, 0, 51, 52, 0, 0, 83, 84, 49, 50, + 53, 54, 81, 82, 85, 86, 51, 52, 55, 56, 83, 84, 87, 88, + 53, 54, 57, 58, 85, 86, 89, 90, 55, 56, 59, 60, 87, 88, + 91, 92, 57, 58, 61, 62, 89, 90, 93, 94, 59, 60, 63, 64, + 91, 92, 95, 96, 61, 62, 0, 0, 93, 94, 0, 0, 0, 0, + 67, 68, 0, 0, 99, 100, 65, 66, 69, 70, 97, 98, 101, 102, + 67, 68, 71, 72, 99, 100, 103, 104, 69, 70, 73, 74, 101, 102, + 105, 106, 71, 72, 75, 76, 103, 104, 107, 108, 73, 74, 77, 78, + 105, 106, 109, 110, 75, 76, 79, 80, 107, 108, 111, 112, 77, 78, + 0, 0, 109, 110, 0, 0, 0, 0, 83, 84, 0, 0, 115, 116, + 81, 82, 85, 86, 113, 114, 117, 118, 83, 84, 87, 88, 115, 116, + 119, 120, 85, 86, 89, 90, 117, 118, 121, 122, 87, 88, 91, 92, + 119, 120, 123, 124, 89, 90, 93, 94, 121, 122, 125, 126, 91, 92, + 95, 96, 123, 124, 127, 128, 93, 94, 0, 0, 125, 126, 0, 0, + 0, 0, 99, 100, 0, 0, 0, 0, 97, 98, 101, 102, 0, 0, + 0, 0, 99, 100, 103, 104, 0, 0, 0, 0, 101, 102, 105, 106, + 0, 0, 0, 0, 103, 104, 107, 108, 0, 0, 0, 0, 105, 106, + 109, 110, 0, 0, 0, 0, 107, 108, 111, 112, 0, 0, 0, 0, + 109, 110, 0, 0, 0, 0, 0, 0}); + // ---------------------------------------------------------------- + sd::ops::extract_image_patches op; + + auto result = op.evaluate( + {&x}, {}, + {2, 2, 1, 1, 2, 2, 1}); // equiv TF ksizes=[1,2,2,1], strides=[1,1,1,1], + // rates=[1,2,2,1], padding="SAME" + ASSERT_EQ(result.status(), Status::OK()); + auto output = result.at(0); + // output->printShapeInfo("Output shape"); + // output->printIndexedBuffer("Output"); + // exp.printBuffer("Expect"); + // for (Nd4jLong e = 0; e < exp.lengthOf(); e++) + // if (exp.e(e) != output->e(e)) + // printf("%lld ", e); + // printf("\n"); + // result.at(1)->printBuffer("OUtput2"); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestExtractImagePatches_SGO_13) { - auto x = NDArrayFactory::create('c', {1, 3, 3, 2}); - x.linspace(1); - - auto exp = NDArrayFactory::create('c', {1, 3, 3, 8}, { - 1., 2., 3., 4., 7., 8., 9., 10., 3., 4., 5., 6., 9., 10., 11., 12., 5., 6., - 0., 0., 11., 12., 0., 0., 7., 8., 9., 10., 13., 14., 15., 16., 9., 10., 11., 12., - 15., 16., 17., 18., 11., 12., 0., 0., 17., 18., 0., 0., 13., 14., 15., 16., 0., 0., - 0., 0., 15., 16., 17., 18., 0., 0., 0., 0., 17., 18., 0., 0., 0., 0., 0., 0. }); -// ---------------------------------------------------------------- - sd::ops::extract_image_patches op; + auto x = NDArrayFactory::create('c', {1, 3, 3, 2}); + x.linspace(1); - auto result = op.evaluate({&x}, {}, {2,2, 1,1, 1,1, 1}); // equiv TF ksizes=[1,2,2,1], strides=[1,1,1,1], rates=[1,1,1,1], padding="SAME" - ASSERT_EQ(result.status(), Status::OK()); - auto output = result.at(0); + auto exp = NDArrayFactory::create( + 'c', {1, 3, 3, 8}, + {1., 2., 3., 4., 7., 8., 9., 10., 3., 4., 5., 6., 9., 10., 11., + 12., 5., 6., 0., 0., 11., 12., 0., 0., 7., 8., 9., 10., 13., 14., + 15., 16., 9., 10., 11., 12., 15., 16., 17., 18., 11., 12., 0., 0., 17., + 18., 0., 0., 13., 14., 15., 16., 0., 0., 0., 0., 15., 16., 17., 18., + 0., 0., 0., 0., 17., 18., 0., 0., 0., 0., 0., 0.}); + // ---------------------------------------------------------------- + sd::ops::extract_image_patches op; - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + auto result = op.evaluate( + {&x}, {}, + {2, 2, 1, 1, 1, 1, 1}); // equiv TF ksizes=[1,2,2,1], strides=[1,1,1,1], + // rates=[1,1,1,1], padding="SAME" + ASSERT_EQ(result.status(), Status::OK()); + auto output = result.at(0); - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestRoll_1) { - auto x = NDArrayFactory::create('c', {2, 2, 4, 2}, { - 11.11, 11.12, 11.21, 11.22, 11.31, 11.32, 11.41, 11.42, 12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42, - 21.11, 21.12, 21.21, 21.22, 21.31, 21.32, 21.41, 21.42, 22.11, 22.12, 22.21, 22.22, 22.31, 22.32, 22.41, 22.42 -}); + auto x = NDArrayFactory::create( + 'c', {2, 2, 4, 2}, + {11.11, 11.12, 11.21, 11.22, 11.31, 11.32, 11.41, 11.42, + 12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42, + 21.11, 21.12, 21.21, 21.22, 21.31, 21.32, 21.41, 21.42, + 22.11, 22.12, 22.21, 22.22, 22.31, 22.32, 22.41, 22.42}); -auto exp = NDArrayFactory::create('c', {2, 2, 4, 2}, { - 22.21, 22.22, 22.31, 22.32, 22.41, 22.42, 11.11, 11.12, 11.21, 11.22, 11.31, 11.32, 11.41, 11.42, - 12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42, 21.11, 21.12, 21.21, 21.22, 21.31, 21.32, - 21.41, 21.42, 22.11, 22.12 - }); -// ---------------------------------------------------------------- - sd::ops::roll op; + auto exp = NDArrayFactory::create( + 'c', {2, 2, 4, 2}, + {22.21, 22.22, 22.31, 22.32, 22.41, 22.42, 11.11, 11.12, + 11.21, 11.22, 11.31, 11.32, 11.41, 11.42, 12.11, 12.12, + 12.21, 12.22, 12.31, 12.32, 12.41, 12.42, 21.11, 21.12, + 21.21, 21.22, 21.31, 21.32, 21.41, 21.42, 22.11, 22.12}); + // ---------------------------------------------------------------- + sd::ops::roll op; - auto result = op.evaluate({&x}, {}, {6}); - ASSERT_EQ(result.status(), Status::OK()); + auto result = op.evaluate({&x}, {}, {6}); + ASSERT_EQ(result.status(), Status::OK()); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - - + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestRoll_2) { - auto x = NDArrayFactory::create('c', {2, 2, 4, 2}, { - 11.11, 11.12, 11.21, 11.22, 11.31, 11.32, 11.41, 11.42, 12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42, - 21.11, 21.12, 21.21, 21.22, 21.31, 21.32, 21.41, 21.42, 22.11, 22.12, 22.21, 22.22, 22.31, 22.32, 22.41, 22.42 -}); - -auto exp = NDArrayFactory::create('c', {2, 2, 4, 2}, { - 12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42, 21.11, 21.12, 21.21, 21.22, 21.31, 21.32, 21.41, 21.42, - 22.11, 22.12, 22.21, 22.22, 22.31, 22.32, 22.41, 22.42, 11.11, 11.12, 11.21, 11.22, 11.31, 11.32, 11.41, 11.42 - }); -// ---------------------------------------------------------------- - sd::ops::roll op; + auto x = NDArrayFactory::create( + 'c', {2, 2, 4, 2}, + {11.11, 11.12, 11.21, 11.22, 11.31, 11.32, 11.41, 11.42, + 12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42, + 21.11, 21.12, 21.21, 21.22, 21.31, 21.32, 21.41, 21.42, + 22.11, 22.12, 22.21, 22.22, 22.31, 22.32, 22.41, 22.42}); - auto result = op.evaluate({&x}, {}, {-8}); - ASSERT_EQ(result.status(), Status::OK()); - ASSERT_TRUE(exp.equalsTo(result.at(0))); + auto exp = NDArrayFactory::create( + 'c', {2, 2, 4, 2}, + {12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42, + 21.11, 21.12, 21.21, 21.22, 21.31, 21.32, 21.41, 21.42, + 22.11, 22.12, 22.21, 22.22, 22.31, 22.32, 22.41, 22.42, + 11.11, 11.12, 11.21, 11.22, 11.31, 11.32, 11.41, 11.42}); + // ---------------------------------------------------------------- + sd::ops::roll op; - + auto result = op.evaluate({&x}, {}, {-8}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestRoll_3) { - auto x = NDArrayFactory::create('c', {2, 2, 4, 2}, { - 11.11, 11.12, 11.21, 11.22, 11.31, 11.32, 11.41, 11.42, 12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42, - 21.11, 21.12, 21.21, 21.22, 21.31, 21.32, 21.41, 21.42, 22.11, 22.12, 22.21, 22.22, 22.31, 22.32, 22.41, 22.42 -}); - -auto exp = NDArrayFactory::create('c', {2, 2, 4, 2}, { - 12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42, 21.11, 21.12, 21.21, 21.22, 21.31, 21.32, 21.41, 21.42, - 22.11, 22.12, 22.21, 22.22, 22.31, 22.32, 22.41, 22.42, 11.11, 11.12, 11.21, 11.22, 11.31, 11.32, 11.41, 11.42 - }); -// ---------------------------------------------------------------- - sd::ops::roll op; + auto x = NDArrayFactory::create( + 'c', {2, 2, 4, 2}, + {11.11, 11.12, 11.21, 11.22, 11.31, 11.32, 11.41, 11.42, + 12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42, + 21.11, 21.12, 21.21, 21.22, 21.31, 21.32, 21.41, 21.42, + 22.11, 22.12, 22.21, 22.22, 22.31, 22.32, 22.41, 22.42}); - auto result = op.evaluate({&x}, {}, {-40}); - ASSERT_EQ(result.status(), Status::OK()); + auto exp = NDArrayFactory::create( + 'c', {2, 2, 4, 2}, + {12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42, + 21.11, 21.12, 21.21, 21.22, 21.31, 21.32, 21.41, 21.42, + 22.11, 22.12, 22.21, 22.22, 22.31, 22.32, 22.41, 22.42, + 11.11, 11.12, 11.21, 11.22, 11.31, 11.32, 11.41, 11.42}); + // ---------------------------------------------------------------- + sd::ops::roll op; - ASSERT_TRUE(exp.equalsTo(result.at(0))); + auto result = op.evaluate({&x}, {}, {-40}); + ASSERT_EQ(result.status(), Status::OK()); - + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestRoll_4) { - auto x = NDArrayFactory::create('c', {2, 2, 4, 2}, { - 11.11, 11.12, 11.21, 11.22, 11.31, 11.32, 11.41, 11.42, 12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42, - 21.11, 21.12, 21.21, 21.22, 21.31, 21.32, 21.41, 21.42, 22.11, 22.12, 22.21, 22.22, 22.31, 22.32, 22.41, 22.42 -}); + auto x = NDArrayFactory::create( + 'c', {2, 2, 4, 2}, + {11.11, 11.12, 11.21, 11.22, 11.31, 11.32, 11.41, 11.42, + 12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42, + 21.11, 21.12, 21.21, 21.22, 21.31, 21.32, 21.41, 21.42, + 22.11, 22.12, 22.21, 22.22, 22.31, 22.32, 22.41, 22.42}); -auto exp = NDArrayFactory::create('c', {2, 2, 4, 2}, { - 22.21, 22.22, 22.31, 22.32, 22.41, 22.42, 11.11, 11.12, 11.21, 11.22, 11.31, 11.32, 11.41, 11.42, - 12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42, 21.11, 21.12, 21.21, 21.22, 21.31, 21.32, - 21.41, 21.42, 22.11, 22.12 - }); -// ---------------------------------------------------------------- - sd::ops::roll op; + auto exp = NDArrayFactory::create( + 'c', {2, 2, 4, 2}, + {22.21, 22.22, 22.31, 22.32, 22.41, 22.42, 11.11, 11.12, + 11.21, 11.22, 11.31, 11.32, 11.41, 11.42, 12.11, 12.12, + 12.21, 12.22, 12.31, 12.32, 12.41, 12.42, 21.11, 21.12, + 21.21, 21.22, 21.31, 21.32, 21.41, 21.42, 22.11, 22.12}); + // ---------------------------------------------------------------- + sd::ops::roll op; - auto result = op.evaluate({&x}, {}, {38}); - ASSERT_EQ(result.status(), Status::OK()); - //result.at(0)->printIndexedBuffer("Output 4"); - //exp.printIndexedBuffer("Expect 4"); + auto result = op.evaluate({&x}, {}, {38}); + ASSERT_EQ(result.status(), Status::OK()); + // result.at(0)->printIndexedBuffer("Output 4"); + // exp.printIndexedBuffer("Expect 4"); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - - + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestRoll_4_inplace) { - auto x = NDArrayFactory::create('c', {2, 2, 4, 2}, { - 11.11, 11.12, 11.21, 11.22, 11.31, 11.32, 11.41, 11.42, 12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42, - 21.11, 21.12, 21.21, 21.22, 21.31, 21.32, 21.41, 21.42, 22.11, 22.12, 22.21, 22.22, 22.31, 22.32, 22.41, 22.42 -}); + auto x = NDArrayFactory::create( + 'c', {2, 2, 4, 2}, + {11.11, 11.12, 11.21, 11.22, 11.31, 11.32, 11.41, 11.42, + 12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42, + 21.11, 21.12, 21.21, 21.22, 21.31, 21.32, 21.41, 21.42, + 22.11, 22.12, 22.21, 22.22, 22.31, 22.32, 22.41, 22.42}); -auto exp = NDArrayFactory::create('c', {2, 2, 4, 2}, { - 22.21, 22.22, 22.31, 22.32, 22.41, 22.42, 11.11, 11.12, 11.21, 11.22, 11.31, 11.32, 11.41, 11.42, - 12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42, 21.11, 21.12, 21.21, 21.22, 21.31, 21.32, - 21.41, 21.42, 22.11, 22.12 - }); -// ---------------------------------------------------------------- - sd::ops::roll op; - NDArray* y = nullptr; - auto result = op.execute({&x}, {y}, {}, {38}, {}, {}, true); - ASSERT_EQ(result, Status::OK()); - //x.printIndexedBuffer("Output 4 inplace"); - //exp.printIndexedBuffer("Expect 4 inplace"); + auto exp = NDArrayFactory::create( + 'c', {2, 2, 4, 2}, + {22.21, 22.22, 22.31, 22.32, 22.41, 22.42, 11.11, 11.12, + 11.21, 11.22, 11.31, 11.32, 11.41, 11.42, 12.11, 12.12, + 12.21, 12.22, 12.31, 12.32, 12.41, 12.42, 21.11, 21.12, + 21.21, 21.22, 21.31, 21.32, 21.41, 21.42, 22.11, 22.12}); + // ---------------------------------------------------------------- + sd::ops::roll op; + auto result = op.evaluate({&x}, {}, {38}, {}, {}, true); + ASSERT_EQ(result.status(), Status::OK()); + // x.printIndexedBuffer("Output 4 inplace"); + // exp.printIndexedBuffer("Expect 4 inplace"); - ASSERT_TRUE(exp.equalsTo(&x)); + ASSERT_TRUE(exp.equalsTo(&x)); -// + // } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestRoll_5) { - auto x = NDArrayFactory::create('c', {3, 4}, { - 0., 1., 2., 3., 4, 5., 6., 7., 8., 9., 10., 11. -}); - -auto exp = NDArrayFactory::create('c', {3, 4}, { - 2., 3., 0., 1., 6., 7., 4., 5., 10., 11., 8., 9. -// 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3 -}); -// ---------------------------------------------------------------- - sd::ops::roll op; + auto x = NDArrayFactory::create( + 'c', {3, 4}, {0., 1., 2., 3., 4, 5., 6., 7., 8., 9., 10., 11.}); - auto result = op.evaluate({&x}, {}, {2, 1}); - ASSERT_EQ(result.status(), Status::OK()); + auto exp = NDArrayFactory::create( + 'c', {3, 4}, + { + 2., 3., 0., 1., 6., 7., 4., 5., 10., 11., 8., 9. + // 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3 + }); + // ---------------------------------------------------------------- + sd::ops::roll op; - //result.at(0)->printIndexedBuffer("Output"); - //exp.printIndexedBuffer("Expect"); + auto result = op.evaluate({&x}, {}, {2, 1}); + ASSERT_EQ(result.status(), Status::OK()); - ASSERT_TRUE(exp.equalsTo(result.at(0))); + // result.at(0)->printIndexedBuffer("Output"); + // exp.printIndexedBuffer("Expect"); - + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestRoll_6) { - auto x = NDArrayFactory::create('c', {2, 3, 2}, { - 0., 1., 2., 3., 4, 5., 6., 7., 8., 9., 10., 11. -}); - -auto exp = NDArrayFactory::create('c', {2, 3, 2}, { - 1., 0., 3., 2., 5., 4., 7., 6., 9., 8., 11., 10. -}); -// ---------------------------------------------------------------- - sd::ops::roll op; + auto x = NDArrayFactory::create( + 'c', {2, 3, 2}, {0., 1., 2., 3., 4, 5., 6., 7., 8., 9., 10., 11.}); - auto result = op.evaluate({&x}, {}, {1, 2}); - ASSERT_EQ(result.status(), Status::OK()); + auto exp = NDArrayFactory::create( + 'c', {2, 3, 2}, {1., 0., 3., 2., 5., 4., 7., 6., 9., 8., 11., 10.}); + // ---------------------------------------------------------------- + sd::ops::roll op; - //result.at(0)->printIndexedBuffer("Output"); - //exp.printIndexedBuffer("Expect"); + auto result = op.evaluate({&x}, {}, {1, 2}); + ASSERT_EQ(result.status(), Status::OK()); - ASSERT_TRUE(exp.equalsTo(result.at(0))); + // result.at(0)->printIndexedBuffer("Output"); + // exp.printIndexedBuffer("Expect"); - + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestRoll_7) { - auto x = NDArrayFactory::create('c', {2, 3, 2}, { - 0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11. -}); + auto x = NDArrayFactory::create( + 'c', {2, 3, 2}, {0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11.}); -auto exp = NDArrayFactory::create('c', {2, 3, 2}, { - 11., 10., 7., 6., 9., 8., 5., 4., 1., 0., 3., 2. -}); -// ---------------------------------------------------------------- - sd::ops::roll op; + auto exp = NDArrayFactory::create( + 'c', {2, 3, 2}, {11., 10., 7., 6., 9., 8., 5., 4., 1., 0., 3., 2.}); + // ---------------------------------------------------------------- + sd::ops::roll op; - auto result = op.evaluate({&x}, {}, {1, 2, 1, 0}); - ASSERT_EQ(result.status(), Status::OK()); + auto result = op.evaluate({&x}, {}, {1, 2, 1, 0}); + ASSERT_EQ(result.status(), Status::OK()); - //result.at(0)->printIndexedBuffer("Output"); - //exp.printIndexedBuffer("Expect"); + // result.at(0)->printIndexedBuffer("Output"); + // exp.printIndexedBuffer("Expect"); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - - + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestRoll_8) { - auto x = NDArrayFactory::create('c', {2, 3, 2}, { - 0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11. -}); + auto x = NDArrayFactory::create( + 'c', {2, 3, 2}, {0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11.}); -auto exp = NDArrayFactory::create('c', {2, 3, 2}, { - 11., 10., 7., 6., 9., 8., 5., 4., 1., 0., 3., 2. -}); -// ---------------------------------------------------------------- - sd::ops::roll op; - NDArray* y = nullptr; - auto result = op.execute({&x}, {y}, {}, {1, 2, 1, 0}, {}, {}, true); - ASSERT_EQ(result, Status::OK()); + auto exp = NDArrayFactory::create( + 'c', {2, 3, 2}, {11., 10., 7., 6., 9., 8., 5., 4., 1., 0., 3., 2.}); + // ---------------------------------------------------------------- + sd::ops::roll op; + NDArray* y = nullptr; + auto result = op.evaluate({&x}, {}, {1, 2, 1, 0}, {}, {}, true); + ASSERT_EQ(result.status(), Status::OK()); - //x.printIndexedBuffer("Output"); - //exp.printIndexedBuffer("Expect"); + // x.printIndexedBuffer("Output"); + // exp.printIndexedBuffer("Expect"); - ASSERT_TRUE(exp.equalsTo(&x)); + ASSERT_TRUE(exp.equalsTo(&x)); -// + // } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestRoll_9) { - auto x = NDArrayFactory::create('c', {2, 3, 3}, { - 0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17. -}); + auto x = + NDArrayFactory::create('c', {2, 3, 3}, + {0., 1., 2., 3., 4., 5., 6., 7., 8., 9., + 10., 11., 12., 13., 14., 15., 16., 17.}); -auto exp = NDArrayFactory::create('c', {2, 3, 3}, { - 6., 7., 8., 0., 1., 2., 3., 4., 5., 15., 16., 17., 9., 10., 11., 12., 13., 14. -}); -// ---------------------------------------------------------------- - sd::ops::roll op; - NDArray* y = nullptr; - auto result = op.execute({&x}, {y}, {}, {1, 1}, {}, {}, true); - ASSERT_EQ(result, Status::OK()); + auto exp = + NDArrayFactory::create('c', {2, 3, 3}, + {6., 7., 8., 0., 1., 2., 3., 4., 5., 15., + 16., 17., 9., 10., 11., 12., 13., 14.}); + // ---------------------------------------------------------------- + sd::ops::roll op; + auto result = op.evaluate({&x}, {}, {1, 1}, {}, {}, true); + ASSERT_EQ(result.status(), Status::OK()); - ASSERT_TRUE(exp.equalsTo(&x)); + ASSERT_TRUE(exp.equalsTo(&x)); -// + // } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestRoll_10) { - auto x = NDArrayFactory::create('c', {2, 3, 4}, { - 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24. - }); - - auto exp = NDArrayFactory::create('c', {2, 3, 4}, { - 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24. - }); -// ---------------------------------------------------------------- - sd::ops::roll op; - auto result = op.evaluate({&x}, {}, {3, 1}); - ASSERT_EQ(result.status(), Status::OK()); - auto out = result.at(0); + auto x = NDArrayFactory::create( + 'c', {2, 3, 4}, + {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., + 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24.}); -// out->printIndexedBuffer("Output"); - //exp.printIndexedBuffer("Expect"); + auto exp = NDArrayFactory::create( + 'c', {2, 3, 4}, + {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., + 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24.}); + // ---------------------------------------------------------------- + sd::ops::roll op; + auto result = op.evaluate({&x}, {}, {3, 1}); + ASSERT_EQ(result.status(), Status::OK()); + auto out = result.at(0); - ASSERT_TRUE(exp.equalsTo(out)); + // out->printIndexedBuffer("Output"); + // exp.printIndexedBuffer("Expect"); - + ASSERT_TRUE(exp.equalsTo(out)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestRoll_11) { - auto x = NDArrayFactory::create('c', {2, 3, 4}, { - 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24. - }); - auto shift = NDArrayFactory::create({1,2}); - auto axis = NDArrayFactory::create({0, 1}); - auto exp = NDArrayFactory::create('c', {2, 3, 4}, { - 17., 18., 19., 20., 21., 22., 23., 24., 13., 14., 15., 16., 5., 6., 7, 8, 9, 10, 11, 12, 1, 2, 3, 4 - }); -// ---------------------------------------------------------------- - sd::ops::roll op; - NDArray* y = nullptr; - auto result = op.evaluate({&x, &shift, &axis}); - ASSERT_EQ(result.status(), Status::OK()); - auto out = result.at(0); - -// out->printIndexedBuffer("Output"); - //exp.printIndexedBuffer("Expect"); - - ASSERT_TRUE(exp.equalsTo(out)); - - + auto x = NDArrayFactory::create( + 'c', {2, 3, 4}, + {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., + 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24.}); + auto shift = NDArrayFactory::create({1, 2}); + auto axis = NDArrayFactory::create({0, 1}); + auto exp = NDArrayFactory::create( + 'c', {2, 3, 4}, + {17., 18., 19., 20., 21., 22., 23., 24., 13., 14., 15., 16., + 5., 6., 7, 8, 9, 10, 11, 12, 1, 2, 3, 4}); + // ---------------------------------------------------------------- + sd::ops::roll op; + NDArray* y = nullptr; + auto result = op.evaluate({&x, &shift, &axis}); + ASSERT_EQ(result.status(), Status::OK()); + auto out = result.at(0); + + // out->printIndexedBuffer("Output"); + // exp.printIndexedBuffer("Expect"); + + ASSERT_TRUE(exp.equalsTo(out)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestRoll_12) { - auto x = NDArrayFactory::create('c', {2, 3, 4}, { - 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24. - }); - auto shift = NDArrayFactory::create({1,1,1}); - auto axis = NDArrayFactory::create({0, 1, 2}); - - auto exp = NDArrayFactory::create('c', {2, 3, 4}, { - 24, 21, 22, 23, 16, 13, 14, 15, 20, 17, 18, 19, 12, 9, 10, 11, 4, 1, 2, 3, 8, 5, 6, 7 - }); -// ---------------------------------------------------------------- - sd::ops::roll op; - NDArray* y = nullptr; - auto result = op.evaluate({&x, &shift, &axis}); - ASSERT_EQ(result.status(), Status::OK()); - auto out = result.at(0); + auto x = NDArrayFactory::create( + 'c', {2, 3, 4}, + {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., + 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24.}); + auto shift = NDArrayFactory::create({1, 1, 1}); + auto axis = NDArrayFactory::create({0, 1, 2}); - ASSERT_TRUE(exp.equalsTo(out)); + auto exp = NDArrayFactory::create( + 'c', {2, 3, 4}, {24, 21, 22, 23, 16, 13, 14, 15, 20, 17, 18, 19, + 12, 9, 10, 11, 4, 1, 2, 3, 8, 5, 6, 7}); + // ---------------------------------------------------------------- + sd::ops::roll op; + NDArray* y = nullptr; + auto result = op.evaluate({&x, &shift, &axis}); + ASSERT_EQ(result.status(), Status::OK()); + auto out = result.at(0); - + ASSERT_TRUE(exp.equalsTo(out)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestRoll_13) { - auto x = NDArrayFactory::create('c', {2, 3, 4}, { - 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24. - }); - auto shift = NDArrayFactory::create(3); - auto axis = NDArrayFactory::create(2); + auto x = NDArrayFactory::create( + 'c', {2, 3, 4}, + {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., + 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24.}); + auto shift = NDArrayFactory::create(3); + auto axis = NDArrayFactory::create(2); - auto exp = NDArrayFactory::create('c', {2, 3, 4}, { - 2,3,4,1,6,7,8,5,10,11,12,9,14, 15, 16, 13, 18, 19, 20, 17, 22, 23, 24, 21 - }); -// ---------------------------------------------------------------- - sd::ops::roll op; - NDArray* y = nullptr; - auto result = op.evaluate({&x}, {}, {3,2}); - ASSERT_EQ(result.status(), Status::OK()); - auto out = result.at(0); + auto exp = NDArrayFactory::create( + 'c', {2, 3, 4}, {2, 3, 4, 1, 6, 7, 8, 5, 10, 11, 12, 9, + 14, 15, 16, 13, 18, 19, 20, 17, 22, 23, 24, 21}); + // ---------------------------------------------------------------- + sd::ops::roll op; + NDArray* y = nullptr; + auto result = op.evaluate({&x}, {}, {3, 2}); + ASSERT_EQ(result.status(), Status::OK()); + auto out = result.at(0); - ASSERT_TRUE(exp.equalsTo(out)); - - + ASSERT_TRUE(exp.equalsTo(out)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestRoll_14) { - auto x = NDArrayFactory::create('c', {2, 3, 4}, { - 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24. - }); - auto shift = NDArrayFactory::create({1,1,1}); - auto axis = NDArrayFactory::create({0, 1, 2}); - - auto exp = NDArrayFactory::create('c', {2, 3, 4}, { - 24, 21, 22, 23, 16, 13, 14, 15, 20, 17, 18, 19, 12, 9, 10, 11, 4, 1, 2, 3, 8, 5, 6, 7 - }); -// ---------------------------------------------------------------- - sd::ops::roll op; + auto x = NDArrayFactory::create( + 'c', {2, 3, 4}, + {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., + 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24.}); + auto shift = NDArrayFactory::create({1, 1, 1}); + auto axis = NDArrayFactory::create({0, 1, 2}); - auto result = op.evaluate({&x, &shift, &axis}); - ASSERT_EQ(result.status(), Status::OK()); - auto out = result.at(0); -// out->printIndexedBuffer("Output"); - //exp.printIndexedBuffer("Expect"); + auto exp = NDArrayFactory::create( + 'c', {2, 3, 4}, {24, 21, 22, 23, 16, 13, 14, 15, 20, 17, 18, 19, + 12, 9, 10, 11, 4, 1, 2, 3, 8, 5, 6, 7}); + // ---------------------------------------------------------------- + sd::ops::roll op; - ASSERT_TRUE(exp.equalsTo(out)); + auto result = op.evaluate({&x, &shift, &axis}); + ASSERT_EQ(result.status(), Status::OK()); + auto out = result.at(0); + // out->printIndexedBuffer("Output"); + // exp.printIndexedBuffer("Expect"); - + ASSERT_TRUE(exp.equalsTo(out)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestRoll_15) { - auto x = NDArrayFactory::create({0.7788f, 0.8012f, 0.7244f, 0.2309f }); - auto shift = NDArrayFactory::create(2); - auto axis = NDArrayFactory::create(0); + auto x = NDArrayFactory::create({0.7788f, 0.8012f, 0.7244f, 0.2309f}); + auto shift = NDArrayFactory::create(2); + auto axis = NDArrayFactory::create(0); - auto exp = NDArrayFactory::create({0.7244f, 0.2309f, 0.7788f, 0.8012f }); -// ---------------------------------------------------------------- - sd::ops::roll op; + auto exp = + NDArrayFactory::create({0.7244f, 0.2309f, 0.7788f, 0.8012f}); + // ---------------------------------------------------------------- + sd::ops::roll op; - auto result = op.evaluate({&x, &shift, &axis}); - ASSERT_EQ(result.status(), Status::OK()); - auto out = result.at(0); -// out->printIndexedBuffer("Output 15"); -// exp.printIndexedBuffer("Expect 15"); + auto result = op.evaluate({&x, &shift, &axis}); + ASSERT_EQ(result.status(), Status::OK()); + auto out = result.at(0); + // out->printIndexedBuffer("Output 15"); + // exp.printIndexedBuffer("Expect 15"); - ASSERT_TRUE(exp.equalsTo(out)); - - + ASSERT_TRUE(exp.equalsTo(out)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, percentile_test1) { + const int dim0 = 5, dim1 = 5, dim2 = 4; - const int dim0=5, dim1=5, dim2=4; - - auto input = NDArrayFactory::create('c', {dim0, dim1, dim2}, {6., 7., 83., 81., 84., 86., 87., 85., 88., 5., 8., 78., 79., 77., 80., 10., 16., 18., 19., 17., 20., 22., - 23., 21., 24., 26., 27., 25., 28., 30., 31., 29., 32., 38., 11., 9., 12., 14., 15., 13., 39., 37., 40., 42., 43., - 41., 44., 46., 47., 45., 48., 50., 51., 49., 52., 54., 55., 53., 56., 58., 59., 57., 60., 98., 99., 97.,100., 62., - 63., 61., 64., 66., 67., 65., 68., 70., 71., 69., 72., 74., 75., 73., 76., 2., 3., 1., 4., 94., 95., 93., 96., - 82., 90., 91., 89., 92., 34., 35., 33., 36.}); - auto expected = NDArrayFactory::create(50.); - - sd::ops::percentile op; + auto input = NDArrayFactory::create( + 'c', {dim0, dim1, dim2}, + {6., 7., 83., 81., 84., 86., 87., 85., 88., 5., 8., 78., 79., + 77., 80., 10., 16., 18., 19., 17., 20., 22., 23., 21., 24., 26., + 27., 25., 28., 30., 31., 29., 32., 38., 11., 9., 12., 14., 15., + 13., 39., 37., 40., 42., 43., 41., 44., 46., 47., 45., 48., 50., + 51., 49., 52., 54., 55., 53., 56., 58., 59., 57., 60., 98., 99., + 97., 100., 62., 63., 61., 64., 66., 67., 65., 68., 70., 71., 69., + 72., 74., 75., 73., 76., 2., 3., 1., 4., 94., 95., 93., 96., + 82., 90., 91., 89., 92., 34., 35., 33., 36.}); + auto expected = NDArrayFactory::create(50.); - auto result = op.evaluate({&input}, {50.}, {}); - auto output = result.at(0); + sd::ops::percentile op; - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + auto result = op.evaluate({&input}, {50.}, {}); + auto output = result.at(0); - + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } - //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, percentile_test2) { + const int dim0 = 5, dim1 = 5, dim2 = 4; - const int dim0=5, dim1=5, dim2=4; - - auto input = NDArrayFactory::create('c', {dim0, dim1, dim2}, {6., 7., 83., 81., 84., 86., 87., 85., 88., 5., 8., 78., 79., 77., 80., 10., 16., 18., 19., 17., 20., 22., - 23., 21., 24., 26., 27., 25., 28., 30., 31., 29., 32., 38., 11., 9., 12., 14., 15., 13., 39., 37., 40., 42., 43., - 41., 44., 46., 47., 45., 48., 50., 51., 49., 52., 54., 55., 53., 56., 58., 59., 57., 60., 98., 99., 97.,100., 62., - 63., 61., 64., 66., 67., 65., 68., 70., 71., 69., 72., 74., 75., 73., 76., 2., 3., 1., 4., 94., 95., 93., 96., - 82., 90., 91., 89., 92., 34., 35., 33., 36.}); - auto expected = NDArrayFactory::create('c', {1,1,1}, {11.}); - - sd::ops::percentile op; - //q, interpolation, keepDims - auto result = op.evaluate({&input}, {10, 2, 1}, {}); - auto output = result.at(0); + auto input = NDArrayFactory::create( + 'c', {dim0, dim1, dim2}, + {6., 7., 83., 81., 84., 86., 87., 85., 88., 5., 8., 78., 79., + 77., 80., 10., 16., 18., 19., 17., 20., 22., 23., 21., 24., 26., + 27., 25., 28., 30., 31., 29., 32., 38., 11., 9., 12., 14., 15., + 13., 39., 37., 40., 42., 43., 41., 44., 46., 47., 45., 48., 50., + 51., 49., 52., 54., 55., 53., 56., 58., 59., 57., 60., 98., 99., + 97., 100., 62., 63., 61., 64., 66., 67., 65., 68., 70., 71., 69., + 72., 74., 75., 73., 76., 2., 3., 1., 4., 94., 95., 93., 96., + 82., 90., 91., 89., 92., 34., 35., 33., 36.}); + auto expected = NDArrayFactory::create('c', {1, 1, 1}, {11.}); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + sd::ops::percentile op; + // q, interpolation, keepDims + auto result = op.evaluate({&input}, {10, 2, 1}, {}); + auto output = result.at(0); - + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } - //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, percentile_test3) { + const int dim0 = 5, dim1 = 5, dim2 = 4; - const int dim0=5, dim1=5, dim2=4; - - auto input = NDArrayFactory::create('c', {dim0, dim1, dim2}, {6., 7., 83., 81., 84., 86., 87., 85., 88., 5., 8., 78., 79., 77., 80., 10., 16., 18., 19., 17., 20., 22., - 23., 21., 24., 26., 27., 25., 28., 30., 31., 29., 32., 38., 11., 9., 12., 14., 15., 13., 39., 37., 40., 42., 43., - 41., 44., 46., 47., 45., 48., 50., 51., 49., 52., 54., 55., 53., 56., 58., 59., 57., 60., 98., 99., 97.,100., 62., - 63., 61., 64., 66., 67., 65., 68., 70., 71., 69., 72., 74., 75., 73., 76., 2., 3., 1., 4., 94., 95., 93., 96., - 82., 90., 91., 89., 92., 34., 35., 33., 36.}); - auto expected = NDArrayFactory::create('c', {1,1,1}, {10.}); - - sd::ops::percentile op; - //q, interpolation, keepDims - auto result = op.evaluate({&input}, {10, 0, 1}, {}); - auto output = result.at(0); + auto input = NDArrayFactory::create( + 'c', {dim0, dim1, dim2}, + {6., 7., 83., 81., 84., 86., 87., 85., 88., 5., 8., 78., 79., + 77., 80., 10., 16., 18., 19., 17., 20., 22., 23., 21., 24., 26., + 27., 25., 28., 30., 31., 29., 32., 38., 11., 9., 12., 14., 15., + 13., 39., 37., 40., 42., 43., 41., 44., 46., 47., 45., 48., 50., + 51., 49., 52., 54., 55., 53., 56., 58., 59., 57., 60., 98., 99., + 97., 100., 62., 63., 61., 64., 66., 67., 65., 68., 70., 71., 69., + 72., 74., 75., 73., 76., 2., 3., 1., 4., 94., 95., 93., 96., + 82., 90., 91., 89., 92., 34., 35., 33., 36.}); + auto expected = NDArrayFactory::create('c', {1, 1, 1}, {10.}); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + sd::ops::percentile op; + // q, interpolation, keepDims + auto result = op.evaluate({&input}, {10, 0, 1}, {}); + auto output = result.at(0); - + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } - //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, percentile_test4) { + const int dim0 = 5, dim1 = 5, dim2 = 4; - const int dim0=5, dim1=5, dim2=4; - - auto input = NDArrayFactory::create('c', {dim0, dim1, dim2}, {6., 7., 83., 81., 84., 86., 87., 85., 88., 5., 8., 78., 79., 77., 80., 10., 16., 18., 19., 17., 20., 22., - 23., 21., 24., 26., 27., 25., 28., 30., 31., 29., 32., 38., 11., 9., 12., 14., 15., 13., 39., 37., 40., 42., 43., - 41., 44., 46., 47., 45., 48., 50., 51., 49., 52., 54., 55., 53., 56., 58., 59., 57., 60., 98., 99., 97.,100., 62., - 63., 61., 64., 66., 67., 65., 68., 70., 71., 69., 72., 74., 75., 73., 76., 2., 3., 1., 4., 94., 95., 93., 96., - 82., 90., 91., 89., 92., 34., 35., 33., 36.}); - auto expected = NDArrayFactory::create('c', {1,1,1}, {11.}); - - sd::ops::percentile op; - //q, interpolation, keepDims - auto result = op.evaluate({&input}, {10, 1, 1}, {}); - auto output = result.at(0); + auto input = NDArrayFactory::create( + 'c', {dim0, dim1, dim2}, + {6., 7., 83., 81., 84., 86., 87., 85., 88., 5., 8., 78., 79., + 77., 80., 10., 16., 18., 19., 17., 20., 22., 23., 21., 24., 26., + 27., 25., 28., 30., 31., 29., 32., 38., 11., 9., 12., 14., 15., + 13., 39., 37., 40., 42., 43., 41., 44., 46., 47., 45., 48., 50., + 51., 49., 52., 54., 55., 53., 56., 58., 59., 57., 60., 98., 99., + 97., 100., 62., 63., 61., 64., 66., 67., 65., 68., 70., 71., 69., + 72., 74., 75., 73., 76., 2., 3., 1., 4., 94., 95., 93., 96., + 82., 90., 91., 89., 92., 34., 35., 33., 36.}); + auto expected = NDArrayFactory::create('c', {1, 1, 1}, {11.}); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + sd::ops::percentile op; + // q, interpolation, keepDims + auto result = op.evaluate({&input}, {10, 1, 1}, {}); + auto output = result.at(0); - + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, percentile_test5) { + const int dim0 = 5, dim1 = 5, dim2 = 4; - const int dim0=5, dim1=5, dim2=4; + auto input = NDArrayFactory::create( + 'c', {dim0, dim1, dim2}, + {6., 7., 83., 81., 84., 86., 87., 85., 88., 5., 8., 78., 79., + 77., 80., 10., 16., 18., 19., 17., 20., 22., 23., 21., 24., 26., + 27., 25., 28., 30., 31., 29., 32., 38., 11., 9., 12., 14., 15., + 13., 39., 37., 40., 42., 43., 41., 44., 46., 47., 45., 48., 50., + 51., 49., 52., 54., 55., 53., 56., 58., 59., 57., 60., 98., 99., + 97., 100., 62., 63., 61., 64., 66., 67., 65., 68., 70., 71., 69., + 72., 74., 75., 73., 76., 2., 3., 1., 4., 94., 95., 93., 96., + 82., 90., 91., 89., 92., 34., 35., 33., 36.}); - auto input = NDArrayFactory::create('c', {dim0, dim1, dim2}, {6., 7., 83., 81., 84., 86., 87., 85., 88., 5., 8., 78., 79., 77., 80., 10., 16., 18., 19., 17., 20., 22., - 23., 21., 24., 26., 27., 25., 28., 30., 31., 29., 32., 38., 11., 9., 12., 14., 15., 13., 39., 37., 40., 42., 43., - 41., 44., 46., 47., 45., 48., 50., 51., 49., 52., 54., 55., 53., 56., 58., 59., 57., 60., 98., 99., 97.,100., 62., - 63., 61., 64., 66., 67., 65., 68., 70., 71., 69., 72., 74., 75., 73., 76., 2., 3., 1., 4., 94., 95., 93., 96., - 82., 90., 91., 89., 92., 34., 35., 33., 36.}); + auto expected = + NDArrayFactory::create('c', {1, 1, 4}, {12., 7., 11., 10.}); - auto expected = NDArrayFactory::create('c', {1,1,4}, {12., 7., 11., 10.}); + sd::ops::percentile op; + // q, interpolation, keepDims + auto result = op.evaluate({&input}, {10, 0, 1}, {0, 1}); + auto output = result.at(0); - sd::ops::percentile op; - //q, interpolation, keepDims - auto result = op.evaluate({&input}, {10, 0, 1}, {0,1}); - auto output = result.at(0); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - - + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, percentile_test6) { + const int dim0 = 5, dim1 = 5, dim2 = 4; - const int dim0=5, dim1=5, dim2=4; - - auto input = NDArrayFactory::create('c', {dim0, dim1, dim2}, {6., 7., 83., 81., 84., 86., 87., 85., 88., 5., 8., 78., 79., 77., 80., 10., 16., 18., 19., 17., 20., 22., - 23., 21., 24., 26., 27., 25., 28., 30., 31., 29., 32., 38., 11., 9., 12., 14., 15., 13., 39., 37., 40., 42., 43., - 41., 44., 46., 47., 45., 48., 50., 51., 49., 52., 54., 55., 53., 56., 58., 59., 57., 60., 98., 99., 97.,100., 62., - 63., 61., 64., 66., 67., 65., 68., 70., 71., 69., 72., 74., 75., 73., 76., 2., 3., 1., 4., 94., 95., 93., 96., - 82., 90., 91., 89., 92., 34., 35., 33., 36.}); + auto input = NDArrayFactory::create( + 'c', {dim0, dim1, dim2}, + {6., 7., 83., 81., 84., 86., 87., 85., 88., 5., 8., 78., 79., + 77., 80., 10., 16., 18., 19., 17., 20., 22., 23., 21., 24., 26., + 27., 25., 28., 30., 31., 29., 32., 38., 11., 9., 12., 14., 15., + 13., 39., 37., 40., 42., 43., 41., 44., 46., 47., 45., 48., 50., + 51., 49., 52., 54., 55., 53., 56., 58., 59., 57., 60., 98., 99., + 97., 100., 62., 63., 61., 64., 66., 67., 65., 68., 70., 71., 69., + 72., 74., 75., 73., 76., 2., 3., 1., 4., 94., 95., 93., 96., + 82., 90., 91., 89., 92., 34., 35., 33., 36.}); - auto expected = NDArrayFactory::create('c', {1,1,4}, {16., 14., 15., 13.}); + auto expected = + NDArrayFactory::create('c', {1, 1, 4}, {16., 14., 15., 13.}); - sd::ops::percentile op; - //q, interpolation, keepDims - auto result = op.evaluate({&input}, {10, 1, 1}, {0,1}); - auto output = result.at(0); + sd::ops::percentile op; + // q, interpolation, keepDims + auto result = op.evaluate({&input}, {10, 1, 1}, {0, 1}); + auto output = result.at(0); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - - + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, percentile_test7) { + const int dim0 = 5, dim1 = 5, dim2 = 4; - const int dim0=5, dim1=5, dim2=4; - - auto input = NDArrayFactory::create('c', {dim0, dim1, dim2}, {6., 7., 83., 81., 84., 86., 87., 85., 88., 5., 8., 78., 79., 77., 80., 10., 16., 18., 19., 17., 20., 22., - 23., 21., 24., 26., 27., 25., 28., 30., 31., 29., 32., 38., 11., 9., 12., 14., 15., 13., 39., 37., 40., 42., 43., - 41., 44., 46., 47., 45., 48., 50., 51., 49., 52., 54., 55., 53., 56., 58., 59., 57., 60., 98., 99., 97.,100., 62., - 63., 61., 64., 66., 67., 65., 68., 70., 71., 69., 72., 74., 75., 73., 76., 2., 3., 1., 4., 94., 95., 93., 96., - 82., 90., 91., 89., 92., 34., 35., 33., 36.}); - - auto expected = NDArrayFactory::create('c', {1,1,4}, {12., 7., 11., 10.}); + auto input = NDArrayFactory::create( + 'c', {dim0, dim1, dim2}, + {6., 7., 83., 81., 84., 86., 87., 85., 88., 5., 8., 78., 79., + 77., 80., 10., 16., 18., 19., 17., 20., 22., 23., 21., 24., 26., + 27., 25., 28., 30., 31., 29., 32., 38., 11., 9., 12., 14., 15., + 13., 39., 37., 40., 42., 43., 41., 44., 46., 47., 45., 48., 50., + 51., 49., 52., 54., 55., 53., 56., 58., 59., 57., 60., 98., 99., + 97., 100., 62., 63., 61., 64., 66., 67., 65., 68., 70., 71., 69., + 72., 74., 75., 73., 76., 2., 3., 1., 4., 94., 95., 93., 96., + 82., 90., 91., 89., 92., 34., 35., 33., 36.}); - sd::ops::percentile op; - //q, interpolation, keepDims - auto result = op.evaluate({&input}, {10, 2, 1}, {0,1}); - auto output = result.at(0); + auto expected = + NDArrayFactory::create('c', {1, 1, 4}, {12., 7., 11., 10.}); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + sd::ops::percentile op; + // q, interpolation, keepDims + auto result = op.evaluate({&input}, {10, 2, 1}, {0, 1}); + auto output = result.at(0); - + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, percentile_test8) { + const int dim0 = 5, dim1 = 5, dim2 = 4; - const int dim0=5, dim1=5, dim2=4; + auto input = NDArrayFactory::create( + 'c', {dim0, dim1, dim2}, + {6., 7., 83., 81., 84., 86., 87., 85., 88., 5., 8., 78., 79., + 77., 80., 10., 16., 18., 19., 17., 20., 22., 23., 21., 24., 26., + 27., 25., 28., 30., 31., 29., 32., 38., 11., 9., 12., 14., 15., + 13., 39., 37., 40., 42., 43., 41., 44., 46., 47., 45., 48., 50., + 51., 49., 52., 54., 55., 53., 56., 58., 59., 57., 60., 98., 99., + 97., 100., 62., 63., 61., 64., 66., 67., 65., 68., 70., 71., 69., + 72., 74., 75., 73., 76., 2., 3., 1., 4., 94., 95., 93., 96., + 82., 90., 91., 89., 92., 34., 35., 33., 36.}); - auto input = NDArrayFactory::create('c', {dim0, dim1, dim2}, {6., 7., 83., 81., 84., 86., 87., 85., 88., 5., 8., 78., 79., 77., 80., 10., 16., 18., 19., 17., 20., 22., - 23., 21., 24., 26., 27., 25., 28., 30., 31., 29., 32., 38., 11., 9., 12., 14., 15., 13., 39., 37., 40., 42., 43., - 41., 44., 46., 47., 45., 48., 50., 51., 49., 52., 54., 55., 53., 56., 58., 59., 57., 60., 98., 99., 97.,100., 62., - 63., 61., 64., 66., 67., 65., 68., 70., 71., 69., 72., 74., 75., 73., 76., 2., 3., 1., 4., 94., 95., 93., 96., - 82., 90., 91., 89., 92., 34., 35., 33., 36.}); + auto expected = NDArrayFactory::create('c', {4}, {12., 7., 11., 10.}); - auto expected = NDArrayFactory::create('c', {4}, {12., 7., 11., 10.}); + sd::ops::percentile op; + // q, interpolation, keepDims + auto result = op.evaluate({&input}, {10, 2, 0}, {0, 1}); + auto output = result.at(0); - sd::ops::percentile op; - //q, interpolation, keepDims - auto result = op.evaluate({&input}, {10, 2, 0}, {0,1}); - auto output = result.at(0); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - - + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, percentile_test9) { + const int dim0 = 100; - const int dim0=100; - - auto input = NDArrayFactory::create('c', {dim0}, {6., 7., 83., 81., 84., 86., 87., 85., 88., 5., 8., 78., 79., 77., 80., 10., 16., 18., 19., 17., 20., 22., - 23., 21., 24., 26., 27., 25., 28., 30., 31., 29., 32., 38., 11., 9., 12., 14., 15., 13., 39., 37., 40., 42., 43., - 41., 44., 46., 47., 45., 48., 50., 51., 49., 52., 54., 55., 53., 56., 58., 59., 57., 60., 98., 99., 97.,100., 62., - 63., 61., 64., 66., 67., 65., 68., 70., 71., 69., 72., 74., 75., 73., 76., 2., 3., 1., 4., 94., 95., 93., 96., - 82., 90., 91., 89., 92., 34., 35., 33., 36.}); + auto input = NDArrayFactory::create( + 'c', {dim0}, + {6., 7., 83., 81., 84., 86., 87., 85., 88., 5., 8., 78., 79., + 77., 80., 10., 16., 18., 19., 17., 20., 22., 23., 21., 24., 26., + 27., 25., 28., 30., 31., 29., 32., 38., 11., 9., 12., 14., 15., + 13., 39., 37., 40., 42., 43., 41., 44., 46., 47., 45., 48., 50., + 51., 49., 52., 54., 55., 53., 56., 58., 59., 57., 60., 98., 99., + 97., 100., 62., 63., 61., 64., 66., 67., 65., 68., 70., 71., 69., + 72., 74., 75., 73., 76., 2., 3., 1., 4., 94., 95., 93., 96., + 82., 90., 91., 89., 92., 34., 35., 33., 36.}); - auto expected = NDArrayFactory::create(11.); + auto expected = NDArrayFactory::create(11.); - sd::ops::percentile op; - //q, interpolation, keepDims - auto result = op.evaluate({&input}, {10, 2, 0}, {0}); - auto output = result.at(0); + sd::ops::percentile op; + // q, interpolation, keepDims + auto result = op.evaluate({&input}, {10, 2, 0}, {0}); + auto output = result.at(0); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - - + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, percentile_test10) { + const int dim0 = 100; - const int dim0=100; - - auto input = NDArrayFactory::create('c', {dim0}, {6., 7., 83., 81., 84., 86., 87., 85., 88., 5., 8., 78., 79., 77., 80., 10., 16., 18., 19., 17., 20., 22., - 23., 21., 24., 26., 27., 25., 28., 30., 31., 29., 32., 38., 11., 9., 12., 14., 15., 13., 39., 37., 40., 42., 43., - 41., 44., 46., 47., 45., 48., 50., 51., 49., 52., 54., 55., 53., 56., 58., 59., 57., 60., 98., 99., 97.,100., 62., - 63., 61., 64., 66., 67., 65., 68., 70., 71., 69., 72., 74., 75., 73., 76., 2., 3., 1., 4., 94., 95., 93., 96., - 82., 90., 91., 89., 92., 34., 35., 33., 36.}); + auto input = NDArrayFactory::create( + 'c', {dim0}, + {6., 7., 83., 81., 84., 86., 87., 85., 88., 5., 8., 78., 79., + 77., 80., 10., 16., 18., 19., 17., 20., 22., 23., 21., 24., 26., + 27., 25., 28., 30., 31., 29., 32., 38., 11., 9., 12., 14., 15., + 13., 39., 37., 40., 42., 43., 41., 44., 46., 47., 45., 48., 50., + 51., 49., 52., 54., 55., 53., 56., 58., 59., 57., 60., 98., 99., + 97., 100., 62., 63., 61., 64., 66., 67., 65., 68., 70., 71., 69., + 72., 74., 75., 73., 76., 2., 3., 1., 4., 94., 95., 93., 96., + 82., 90., 91., 89., 92., 34., 35., 33., 36.}); - auto expected = NDArrayFactory::create('c', {1}, {11.}); + auto expected = NDArrayFactory::create('c', {1}, {11.}); - sd::ops::percentile op; - //q, interpolation, keepDims - auto result = op.evaluate({&input}, {10, 2, 1}, {0}); - auto output = result.at(0); + sd::ops::percentile op; + // q, interpolation, keepDims + auto result = op.evaluate({&input}, {10, 2, 1}, {0}); + auto output = result.at(0); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - - + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, percentile_test11) { + const int dim0 = 1; - const int dim0=1; - - auto input = NDArrayFactory::create('c', {dim0}, {100.}); - - auto expected = NDArrayFactory::create('c', {1}, {100.}); + auto input = NDArrayFactory::create('c', {dim0}, {100.}); - sd::ops::percentile op; - //q, interpolation, keepDims - auto result = op.evaluate({&input}, {10, 2, 1}, {0}); - auto output = result.at(0); + auto expected = NDArrayFactory::create('c', {1}, {100.}); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + sd::ops::percentile op; + // q, interpolation, keepDims + auto result = op.evaluate({&input}, {10, 2, 1}, {0}); + auto output = result.at(0); - + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, percentile_test12) { + const int dim0 = 1; - const int dim0=1; + auto input = NDArrayFactory::create('c', {dim0}, {100.}); - auto input = NDArrayFactory::create('c', {dim0}, {100.}); + auto expected = NDArrayFactory::create(100.); - auto expected = NDArrayFactory::create(100.); + sd::ops::percentile op; + // q, interpolation, keepDims + auto result = op.evaluate({&input}, {10, 2, 0}, {}); + auto output = result.at(0); - sd::ops::percentile op; - //q, interpolation, keepDims - auto result = op.evaluate({&input}, {10, 2, 0}, {}); - auto output = result.at(0); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - - + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, transpose_test3) { + auto input = + NDArrayFactory::create('c', {5, 3}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, + 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f}); + auto exp = + NDArrayFactory::create('c', {3, 5}, + {1.f, 4.f, 7.f, 10.f, 13.f, 2.f, 5.f, 8.f, + 11.f, 14.f, 3.f, 6.f, 9.f, 12.f, 15.f}); - auto input = NDArrayFactory::create('c', {5, 3}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f}); - auto exp = NDArrayFactory::create('c', {3, 5}, {1.f, 4.f, 7.f, 10.f, 13.f, 2.f, 5.f, 8.f, 11.f, 14.f, 3.f, 6.f, 9.f, 12.f, 15.f}); - - sd::ops::transpose op; - auto result = op.evaluate({&input}, {}, {}); - auto output = result.at(0); + sd::ops::transpose op; + auto result = op.evaluate({&input}, {}, {}); + auto output = result.at(0); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, rationaltanh_test1) { + auto input = + NDArrayFactory::create('c', {8}, {0, 1, 2, 3, 4, 5, 6, 7}); + NDArray exp = + NDArrayFactory::create({0.000000, 0.998222, 1.516093, 1.658054, + 1.695077, 1.706884, 1.711427, 1.713446}); - auto input = NDArrayFactory::create('c', {8}, {0, 1, 2, 3, 4, 5, 6, 7}); - NDArray exp = NDArrayFactory::create({0.000000, 0.998222, 1.516093, 1.658054, 1.695077, 1.706884, 1.711427, 1.713446}); - - sd::ops::rationaltanh op; - auto result = op.evaluate({&input}, {}, {}); - auto output = result.at(0); -// output->printIndexedBuffer("Output rationaltanh"); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + sd::ops::rationaltanh op; + auto result = op.evaluate({&input}, {}, {}); + auto output = result.at(0); + // output->printIndexedBuffer("Output rationaltanh"); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, rationaltanh_test2) { + auto input = + NDArrayFactory::create('c', {2, 2, 2}, {0, 1, 2, 3, 4, 5, 6, 7}); + NDArray exp = + NDArrayFactory::create('c', {2, 2, 2}, + {0.000000, 0.998222, 1.516093, 1.658054, + 1.695077, 1.706884, 1.711427, 1.713446}); - auto input = NDArrayFactory::create('c', {2,2,2}, {0, 1, 2, 3, 4, 5, 6, 7}); - NDArray exp = NDArrayFactory::create('c', {2,2,2}, {0.000000, 0.998222, 1.516093, 1.658054, 1.695077, 1.706884, 1.711427, 1.713446}); - - sd::ops::rationaltanh op; - auto result = op.evaluate({&input}, {}, {}); - auto output = result.at(0); -// output->printIndexedBuffer("Output rationaltanh"); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + sd::ops::rationaltanh op; + auto result = op.evaluate({&input}, {}, {}); + auto output = result.at(0); + // output->printIndexedBuffer("Output rationaltanh"); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, rationaltanh_test3) { + auto input = + NDArrayFactory::create('c', {2, 2, 2}, {0, 1, 2, 3, 4, 5, 6, 7}); + auto eps = + NDArrayFactory::create('c', {2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8}); + NDArray exp = + NDArrayFactory::create('c', {2, 2, 2}, + {1.143933, 1.605747, 0.795557, 0.261710, + 0.095832, 0.041218, 0.020221, 0.010971}); - auto input = NDArrayFactory::create('c', {2,2,2}, {0, 1, 2, 3, 4, 5, 6, 7}); - auto eps = NDArrayFactory::create('c', {2,2,2}, {1, 2, 3, 4, 5, 6, 7, 8}); - NDArray exp = NDArrayFactory::create('c', {2,2,2}, {1.143933, 1.605747, 0.795557, 0.261710, 0.095832, 0.041218, 0.020221, 0.010971}); - - sd::ops::rationaltanh_bp op; - auto result = op.evaluate({&input, &eps}, {}, {}); - auto output = result.at(0); -// output->printBuffer("Output rationaltanh BP"); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + sd::ops::rationaltanh_bp op; + auto result = op.evaluate({&input, &eps}, {}, {}); + auto output = result.at(0); + // output->printBuffer("Output rationaltanh BP"); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, rectifiedtanh_test1) { + auto input = + NDArrayFactory::create('c', {2, 2, 2}, {0, 1, 2, 3, 4, 5, 6, 7}); + NDArray exp = + NDArrayFactory::create('c', {2, 2, 2}, + {0.000000, 0.761594, 0.964028, 0.995055, + 0.999329, 0.999909, 0.999988, 0.999998}); - auto input = NDArrayFactory::create('c', {2,2,2}, {0, 1, 2, 3, 4, 5, 6, 7}); - NDArray exp = NDArrayFactory::create('c', {2,2,2}, {0.000000, 0.761594, 0.964028, 0.995055, 0.999329, 0.999909, 0.999988, 0.999998}); - - sd::ops::rectifiedtanh op; - auto result = op.evaluate({&input}, {}, {}); - auto output = result.at(0); -// output->printIndexedBuffer("Output rectifiedtanh"); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + sd::ops::rectifiedtanh op; + auto result = op.evaluate({&input}, {}, {}); + auto output = result.at(0); + // output->printIndexedBuffer("Output rectifiedtanh"); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, rectifiedtanh_test2) { - - auto input = NDArrayFactory::create('c', {2,2,2}, {0, 1, 2, 3, 4, 5, 6, 7}); - auto eps = NDArrayFactory::create('c', {2,2,2}, {1, 2, 3, 4, 5, 6, 7, 8}); - NDArray exp = NDArrayFactory::create('c', {2,2,2}, {0.000000, 0.839949, 0.211952, 0.039464, 0.006705, 0.001089, 0.000172, 0.000027}); - - sd::ops::rectifiedtanh_bp op; - auto result = op.evaluate({&input, &eps}, {}, {}); - auto output = result.at(0); -// output->printBuffer("Output rectifiedtanh BP"); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + auto input = + NDArrayFactory::create('c', {2, 2, 2}, {0, 1, 2, 3, 4, 5, 6, 7}); + auto eps = + NDArrayFactory::create('c', {2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8}); + NDArray exp = + NDArrayFactory::create('c', {2, 2, 2}, + {0.000000, 0.839949, 0.211952, 0.039464, + 0.006705, 0.001089, 0.000172, 0.000027}); + + sd::ops::rectifiedtanh_bp op; + auto result = op.evaluate({&input, &eps}, {}, {}); + auto output = result.at(0); + // output->printBuffer("Output rectifiedtanh BP"); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } TEST_F(DeclarableOpsTests7, RealDiv_1) { + NDArray x = NDArrayFactory::create('c', {1, 2, 1}, {2.f, 4.f}); + NDArray y = NDArrayFactory::create('c', {1, 2}, {1.f, 2.f}); + NDArray e = + NDArrayFactory::create('c', {1, 2, 2}, {2.f, 1.f, 4.f, 2.f}); - NDArray x = NDArrayFactory::create('c', {1, 2, 1}, {2.f, 4.f}); - NDArray y = NDArrayFactory::create('c', {1, 2}, {1.f,2.f}); - NDArray e = NDArrayFactory::create('c', {1, 2, 2}, {2.f, 1.f, 4.f, 2.f}); + sd::ops::realdiv op; + auto result = op.evaluate({&x, &y}, {}, {}); - sd::ops::realdiv op; - auto result = op.evaluate({&x, &y}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); -// z->printIndexedBuffer("OUtput RealDiv"); - ASSERT_TRUE(e.isSameShape(z)); - ASSERT_TRUE(e.equalsTo(*z)); - - + auto z = result.at(0); + // z->printIndexedBuffer("OUtput RealDiv"); + ASSERT_TRUE(e.isSameShape(z)); + ASSERT_TRUE(e.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, RealDiv_BP_1) { + NDArray x = NDArrayFactory::create('c', {1, 2, 1}, {2.f, 4.f}); + NDArray y = NDArrayFactory::create('c', {1, 2}, {1.f, 2.f}); + NDArray e0 = NDArrayFactory::create('c', {1, 2, 1}, {2.f, 5.f}); + NDArray e1 = NDArrayFactory::create('c', {1, 2}, {-14.f, -5.f}); + NDArray eps = + NDArrayFactory::create('c', {1, 2, 2}, {1.f, 2.f, 3.f, 4.f}); - NDArray x = NDArrayFactory::create('c', {1, 2, 1}, {2.f, 4.f}); - NDArray y = NDArrayFactory::create('c', {1, 2}, {1.f, 2.f}); - NDArray e0 = NDArrayFactory::create('c', {1, 2, 1}, {2.f, 5.f}); - NDArray e1 = NDArrayFactory::create('c', {1, 2}, {-14.f, -5.f}); - NDArray eps = NDArrayFactory::create('c', {1, 2, 2}, {1.f, 2.f, 3.f, 4.f}); - - sd::ops::realdiv_bp op; - auto result = op.evaluate({&x, &y, &eps}, {}, {}); + sd::ops::realdiv_bp op; + auto result = op.evaluate({&x, &y, &eps}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); + ASSERT_EQ(Status::OK(), result.status()); - auto z0 = result.at(0); - auto z1 = result.at(1); -// z0->printShapeInfo("OUtput RealDiv BP0 shape"); -// z1->printShapeInfo("OUtput RealDiv BP1 shape"); -// z0->printIndexedBuffer("OUtput RealDiv BP0"); -// z1->printIndexedBuffer("OUtput RealDiv BP1"); -// ASSERT_TRUE(e.isSameShape(z)); - ASSERT_TRUE(e0.equalsTo(z0)); - ASSERT_TRUE(e1.equalsTo(z1)); - - + auto z0 = result.at(0); + auto z1 = result.at(1); + // z0->printShapeInfo("OUtput RealDiv BP0 shape"); + // z1->printShapeInfo("OUtput RealDiv BP1 shape"); + // z0->printIndexedBuffer("OUtput RealDiv BP0"); + // z1->printIndexedBuffer("OUtput RealDiv BP1"); + // ASSERT_TRUE(e.isSameShape(z)); + ASSERT_TRUE(e0.equalsTo(z0)); + ASSERT_TRUE(e1.equalsTo(z1)); } //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, ShapesOf_1) { + NDArray x = NDArrayFactory::create('c', {1, 2, 1}, {2.f, 4.f}); + // NDArray y = NDArrayFactory::create('c', {1, 2}, {1,2}); + NDArray e = NDArrayFactory::create({1, 2, 1}); - NDArray x = NDArrayFactory::create('c', {1, 2, 1}, {2.f, 4.f}); -// NDArray y = NDArrayFactory::create('c', {1, 2}, {1,2}); - NDArray e = NDArrayFactory::create({1, 2, 1}); - - sd::ops::shapes_of op; - auto result = op.evaluate({&x}, {}, {}); + sd::ops::shapes_of op; + auto result = op.evaluate({&x}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); + ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); -// z->printIndexedBuffer("OUtput RealDiv"); -// ASSERT_TRUE(e.isSameShape(z)); - ASSERT_TRUE(e.equalsTo(*z)); - - + auto z = result.at(0); + // z->printIndexedBuffer("OUtput RealDiv"); + // ASSERT_TRUE(e.isSameShape(z)); + ASSERT_TRUE(e.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, ShapesOf_2) { + NDArray x = NDArrayFactory::create('c', {1, 2, 1}, {2.f, 4.f}); + NDArray y = NDArrayFactory::create('c', {1, 2}, {1.f, 2.f}); + NDArray e0 = NDArrayFactory::create({1, 2, 1}); + NDArray e1 = NDArrayFactory::create({1, 2}); - NDArray x = NDArrayFactory::create('c', {1, 2, 1}, {2.f, 4.f}); - NDArray y = NDArrayFactory::create('c', {1, 2}, {1.f, 2.f}); - NDArray e0 = NDArrayFactory::create({1, 2, 1}); - NDArray e1 = NDArrayFactory::create({1, 2}); - - sd::ops::shapes_of op; - auto result = op.evaluate({&x, &y}, {}, {}); - - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::shapes_of op; + auto result = op.evaluate({&x, &y}, {}, {}); - auto z0 = result.at(0); - auto z1 = result.at(1); -// z0->printIndexedBuffer("OUtput shapes2"); -// z1->printIndexedBuffer("OUtput shapes2"); -// ASSERT_TRUE(e.isSameShape(z)); - ASSERT_TRUE(e0.equalsTo(z0)); - ASSERT_TRUE(e1.equalsTo(z1)); + ASSERT_EQ(Status::OK(), result.status()); - + auto z0 = result.at(0); + auto z1 = result.at(1); + // z0->printIndexedBuffer("OUtput shapes2"); + // z1->printIndexedBuffer("OUtput shapes2"); + // ASSERT_TRUE(e.isSameShape(z)); + ASSERT_TRUE(e0.equalsTo(z0)); + ASSERT_TRUE(e1.equalsTo(z1)); } TEST_F(DeclarableOpsTests7, Size_1) { + NDArray x = NDArrayFactory::create('c', {1, 2, 1}, {2.f, 4.f}); + NDArray y = NDArrayFactory::create( + 'c', {5, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 7.f, 9.f, 10.f, 10.f, 11.f}); + NDArray e = NDArrayFactory::create(2); - NDArray x = NDArrayFactory::create('c', {1, 2, 1}, {2.f, 4.f}); - NDArray y = NDArrayFactory::create('c', {5, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 7.f, 9.f, 10.f, 10.f, 11.f}); - NDArray e = NDArrayFactory::create(2); + sd::ops::size op; + auto result = op.evaluate({&x}, {}, {}); - sd::ops::size op; - auto result = op.evaluate({&x}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); -// z->printIndexedBuffer("OUtput SIZE"); -/// ASSERT_TRUE(e.isSameShape(z)); - ASSERT_TRUE(e.equalsTo(*z)); - - + auto z = result.at(0); + // z->printIndexedBuffer("OUtput SIZE"); + /// ASSERT_TRUE(e.isSameShape(z)); + ASSERT_TRUE(e.equalsTo(z)); } TEST_F(DeclarableOpsTests7, Size_2) { + NDArray x = NDArrayFactory::create('c', {1, 2, 1}, {2, 4}); + NDArray y = NDArrayFactory::create('c', {5, 2}, + {1, 2, 3, 4, 5, 7, 9, 10, 10, 11}); + NDArray e = NDArrayFactory::create(10); - NDArray x = NDArrayFactory::create('c', {1, 2, 1}, {2, 4}); - NDArray y = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11}); - NDArray e = NDArrayFactory::create(10); - - sd::ops::size op; - auto result = op.evaluate({&y}, {}, {}); + sd::ops::size op; + auto result = op.evaluate({&y}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); + ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); -// z->printIndexedBuffer("OUtput SIZE"); -/// ASSERT_TRUE(e.isSameShape(z)); - ASSERT_TRUE(e.equalsTo(*z)); - - + auto z = result.at(0); + // z->printIndexedBuffer("OUtput SIZE"); + /// ASSERT_TRUE(e.isSameShape(z)); + ASSERT_TRUE(e.equalsTo(z)); } TEST_F(DeclarableOpsTests7, Softplus_1) { + NDArray x = NDArrayFactory::create('c', {5, 2}, + {1, 2, 3, 4, 5, 7, 9, 10, 10, 11}); + NDArray e = NDArrayFactory::create( + 'c', {5, 2}, + {1.3132616, 2.126928, 3.0485873, 4.01815, 5.0067153, 7.0009117, 9.000123, + 10.000046, 10.000046, 11.000016}); - NDArray x = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11}); - NDArray e = NDArrayFactory::create('c', {5, 2}, {1.3132616, 2.126928, 3.0485873, 4.01815, 5.0067153, 7.0009117, 9.000123, 10.000046, 10.000046, 11.000016}); - - sd::ops::softplus op; - auto result = op.evaluate({&x}, {}, {}); - - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::softplus op; + auto result = op.evaluate({&x}, {}, {}); - auto z = result.at(0); -// z->printIndexedBuffer("OUtput Softplus"); -/// ASSERT_TRUE(e.isSameShape(z)); - ASSERT_TRUE(e.equalsTo(*z)); + ASSERT_EQ(Status::OK(), result.status()); - + auto z = result.at(0); + // z->printIndexedBuffer("OUtput Softplus"); + /// ASSERT_TRUE(e.isSameShape(z)); + ASSERT_TRUE(e.equalsTo(z)); } TEST_F(DeclarableOpsTests7, Softplus_BP_1) { - - NDArray x = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11}); -// NDArray e = NDArrayFactory::create('c', {5, 2}, {1.3132616, 2.126928, 3.0485873, 4.01815, 5.0067153, 7.0009117, 9.000123, 10.000046, 10.000046, 11.000016}); - NDArray eps = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,6,7,8, 9, 10}); - sd::ops::softplus ffOP; - sd::ops::softplus_bp bpOp; - const OpArgsHolder argsHolderFF({&x}, {}, {}); - const OpArgsHolder argsHolderBP({&x, &eps}, {}, {}); - - bool gradOK = GradCheck::checkGrad(ffOP, bpOp, argsHolderFF, argsHolderBP); - - ASSERT_TRUE(gradOK); -// -// auto z = result.at(0); -// z->printIndexedBuffer("OUtput Softplus"); -///// ASSERT_TRUE(e.isSameShape(z)); -// ASSERT_TRUE(e.equalsTo(*z)); -// -// + NDArray x = NDArrayFactory::create('c', {5, 2}, + {1, 2, 3, 4, 5, 7, 9, 10, 10, 11}); + // NDArray e = NDArrayFactory::create('c', {5, 2}, + // {1.3132616, 2.126928, 3.0485873, 4.01815, 5.0067153, 7.0009117, 9.000123, + // 10.000046, 10.000046, 11.000016}); + NDArray eps = NDArrayFactory::create('c', {5, 2}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}); + sd::ops::softplus ffOP; + sd::ops::softplus_bp bpOp; + const OpArgsHolder argsHolderFF({&x}, {}, {}); + const OpArgsHolder argsHolderBP({&x, &eps}, {}, {}); + + bool gradOK = GradCheck::checkGrad(ffOP, bpOp, argsHolderFF, argsHolderBP); + + ASSERT_TRUE(gradOK); + // + // auto z = result.at(0); + // z->printIndexedBuffer("OUtput Softplus"); + ///// ASSERT_TRUE(e.isSameShape(z)); + // ASSERT_TRUE(e.equalsTo(*z)); + // + // } TEST_F(DeclarableOpsTests7, Softsign_1) { + NDArray x = NDArrayFactory::create('c', {5, 2}, + {1, 2, 3, 4, 5, 7, 9, 10, 10, 11}); + NDArray e = NDArrayFactory::create( + 'c', {5, 2}, + {0.5, 0.6666667, 0.75, 0.8, 0.8333333, 0.875, 0.9, 0.90909094, 0.90909094, + 0.9166667}); - NDArray x = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11}); - NDArray e = NDArrayFactory::create('c', {5, 2}, {0.5, 0.6666667, 0.75, 0.8, 0.8333333, 0.875, 0.9, 0.90909094, 0.90909094, 0.9166667}); - - sd::ops::softsign op; - auto result = op.evaluate({&x}, {}, {}); + sd::ops::softsign op; + auto result = op.evaluate({&x}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); + ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); -// z->printIndexedBuffer("OUtput Softsign"); -/// ASSERT_TRUE(e.isSameShape(z)); - ASSERT_TRUE(e.equalsTo(*z)); - - + auto z = result.at(0); + // z->printIndexedBuffer("OUtput Softsign"); + /// ASSERT_TRUE(e.isSameShape(z)); + ASSERT_TRUE(e.equalsTo(z)); } TEST_F(DeclarableOpsTests7, Softsign_BP_1) { + NDArray x = NDArrayFactory::create('c', {5, 2}, + {1, 2, 3, 4, 5, 7, 9, 10, 10, 11}); + // NDArray e = NDArrayFactory::create('c', {5, 2}, + // {1.3132616f, 2.126928f, 3.0485873f, 4.01815f, 5.0067153f, 7.0009117f, 9.000123f, + // 10.000046f, 10.000046f, 11.000016f}); + NDArray eps = NDArrayFactory::create('c', {5, 2}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}); + sd::ops::softsign ffOP; + sd::ops::softsign_bp bpOp; + const OpArgsHolder argsHolderFF({&x}, {}, {}); + const OpArgsHolder argsHolderBP({&x, &eps}, {}, {}); - NDArray x = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11}); -// NDArray e = NDArrayFactory::create('c', {5, 2}, {1.3132616f, 2.126928f, 3.0485873f, 4.01815f, 5.0067153f, 7.0009117f, 9.000123f, 10.000046f, 10.000046f, 11.000016f}); - NDArray eps = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,6,7,8, 9, 10}); - sd::ops::softsign ffOP; - sd::ops::softsign_bp bpOp; - const OpArgsHolder argsHolderFF({&x}, {}, {}); - const OpArgsHolder argsHolderBP({&x, &eps}, {}, {}); - - bool gradOK = GradCheck::checkGrad(ffOP, bpOp, argsHolderFF, argsHolderBP); + bool gradOK = GradCheck::checkGrad(ffOP, bpOp, argsHolderFF, argsHolderBP); - ASSERT_TRUE(gradOK); + ASSERT_TRUE(gradOK); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, fill_test2) { + auto x = NDArrayFactory::create('c', {1, 2}, {2, 2}); + auto v = NDArrayFactory::create(42.); + auto exp = + NDArrayFactory::create('c', {2, 2}, {42.f, 42.f, 42.f, 42.f}); - auto x = NDArrayFactory::create('c', {1,2}, {2, 2}); - auto v = NDArrayFactory::create(42.); - auto exp = NDArrayFactory::create('c', {2, 2},{42.f, 42.f, 42.f, 42.f}); - - sd::ops::fill op; - auto result = op.evaluate({&x, &v}, {}, {}); + sd::ops::fill op; + auto result = op.evaluate({&x, &v}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - - + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, fill_test3) { + auto x = NDArrayFactory::create('c', {2}, {2, 2}); + auto v = NDArrayFactory::create(42.); + auto exp = + NDArrayFactory::create('c', {2, 2}, {42.f, 42.f, 42.f, 42.f}); - auto x = NDArrayFactory::create('c', {2}, {2, 2}); - auto v = NDArrayFactory::create(42.); - auto exp = NDArrayFactory::create('c', {2, 2}, {42.f, 42.f, 42.f, 42.f}); - - sd::ops::fill op; - auto result = op.evaluate({&x, &v}, {}, {}); - auto output = result.at(0); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::fill op; + auto result = op.evaluate({&x, &v}, {}, {}); + auto output = result.at(0); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, ToggleBits_test1) { + auto x = NDArrayFactory::create('c', {2}, {2, 2}); + auto exp = NDArrayFactory::create('c', {2}, {-3, -3}); - auto x = NDArrayFactory::create('c', {2}, {2, 2}); - auto exp = NDArrayFactory::create('c', {2}, {-3, -3}); - - sd::ops::toggle_bits op; - auto result = op.evaluate({&x}); - auto output = result.at(0); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); -// output->printIndexedBuffer("Toggled"); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::toggle_bits op; + auto result = op.evaluate({&x}); + auto output = result.at(0); - + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + // output->printIndexedBuffer("Toggled"); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, ToggleBits_test2) { + auto x = NDArrayFactory::create('c', {2}, {2, 2}); + auto y = NDArrayFactory::create('c', {2}, {1, 1}); + auto exp0 = NDArrayFactory::create('c', {2}, {-3, -3}); + auto exp1 = NDArrayFactory::create('c', {2}, {-2, -2}); - auto x = NDArrayFactory::create('c', {2}, {2, 2}); - auto y = NDArrayFactory::create('c', {2}, {1, 1}); - auto exp0 = NDArrayFactory::create('c', {2}, {-3, -3}); - auto exp1 = NDArrayFactory::create('c', {2}, {-2, -2}); + sd::ops::toggle_bits op; + auto result = op.evaluate({&x, &y}); + auto output = result.at(0); + auto z = result.at(1); - sd::ops::toggle_bits op; - auto result = op.evaluate({&x, &y}); - auto output = result.at(0); - auto z = result.at(1); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); -// output->printIndexedBuffer("Toggled"); - ASSERT_TRUE(exp0.isSameShape(output)); - ASSERT_TRUE(exp0.equalsTo(output)); - ASSERT_TRUE(exp1.isSameShape(z)); - ASSERT_TRUE(exp1.equalsTo(z)); - - + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + // output->printIndexedBuffer("Toggled"); + ASSERT_TRUE(exp0.isSameShape(output)); + ASSERT_TRUE(exp0.equalsTo(output)); + ASSERT_TRUE(exp1.isSameShape(z)); + ASSERT_TRUE(exp1.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Truncatediv_test1) { - NDArray x = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11}); - NDArray y = NDArrayFactory::create('c', {5, 2}, {2,2,2,2,2,2,2,2, 2, 2}); - NDArray exp = NDArrayFactory::create('c', {5, 2}, {0.5, 1., 1.5, 2., 2.5, 3.5, 4.5, 5., 5., 5.5}); + NDArray x = NDArrayFactory::create('c', {5, 2}, + {1, 2, 3, 4, 5, 7, 9, 10, 10, 11}); + NDArray y = NDArrayFactory::create('c', {5, 2}, + {2, 2, 2, 2, 2, 2, 2, 2, 2, 2}); + NDArray exp = NDArrayFactory::create( + 'c', {5, 2}, {0.5, 1., 1.5, 2., 2.5, 3.5, 4.5, 5., 5., 5.5}); - sd::ops::truncatediv op; - auto result = op.evaluate({&x, &y}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); -// output->printIndexedBuffer("Toggled"); - ASSERT_TRUE(exp.isSameShape(output)); - - + sd::ops::truncatediv op; + auto result = op.evaluate({&x, &y}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + // output->printIndexedBuffer("Toggled"); + ASSERT_TRUE(exp.isSameShape(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Truncatediv_test2) { - NDArray x = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11}); - NDArray y = NDArrayFactory::create('c', {1, 2}, {2,2}); - NDArray exp = NDArrayFactory::create('c', {5, 2}, {0.5, 1., 1.5, 2., 2.5, 3.5, 4.5, 5., 5., 5.5}); - - sd::ops::truncatediv op; - auto result = op.evaluate({&x, &y}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); -// output->printIndexedBuffer("Toggled"); - ASSERT_TRUE(exp.isSameShape(output)); + NDArray x = NDArrayFactory::create('c', {5, 2}, + {1, 2, 3, 4, 5, 7, 9, 10, 10, 11}); + NDArray y = NDArrayFactory::create('c', {1, 2}, {2, 2}); + NDArray exp = NDArrayFactory::create( + 'c', {5, 2}, {0.5, 1., 1.5, 2., 2.5, 3.5, 4.5, 5., 5., 5.5}); - + sd::ops::truncatediv op; + auto result = op.evaluate({&x, &y}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + // output->printIndexedBuffer("Toggled"); + ASSERT_TRUE(exp.isSameShape(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TypesConversion_test1) { - NDArray x = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11}); - NDArray expI = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11}); - NDArray expL = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11}); - NDArray expF = NDArrayFactory::create('c', {5, 2}, {1.f,2.f,3.f,4.f,5.f,7.f,9.f,10.f, 10.f, 11.f}); - NDArray expF16 = NDArrayFactory::create('c', {5, 2}, {1.f,2.f,3.f,4.f,5.f,7.f,9.f,10.f, 10.f, 11.f}); - - sd::ops::to_int32 op32; - sd::ops::to_int64 op64; - auto result32 = op32.evaluate({&x}, {}, {}); - auto result64 = op64.evaluate({&x}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, result32.status()); - ASSERT_EQ(ND4J_STATUS_OK, result64.status()); - auto out1 = result32.at(0); -// out1->printIndexedBuffer("OUT_I"); - auto out2 = result64.at(0); -// out2->printIndexedBuffer("OUT_L"); - -// output->printIndexedBuffer("Toggled"); - ASSERT_TRUE(expI.equalsTo(out1)); - ASSERT_TRUE(expL.equalsTo(out2)); - + NDArray x = NDArrayFactory::create('c', {5, 2}, + {1, 2, 3, 4, 5, 7, 9, 10, 10, 11}); + NDArray expI = NDArrayFactory::create('c', {5, 2}, + {1, 2, 3, 4, 5, 7, 9, 10, 10, 11}); + NDArray expL = NDArrayFactory::create( + 'c', {5, 2}, {1, 2, 3, 4, 5, 7, 9, 10, 10, 11}); + NDArray expF = NDArrayFactory::create( + 'c', {5, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 7.f, 9.f, 10.f, 10.f, 11.f}); + NDArray expF16 = NDArrayFactory::create( + 'c', {5, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 7.f, 9.f, 10.f, 10.f, 11.f}); + + sd::ops::to_int32 op32; + sd::ops::to_int64 op64; + auto result32 = op32.evaluate({&x}, {}, {}); + auto result64 = op64.evaluate({&x}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result32.status()); + ASSERT_EQ(ND4J_STATUS_OK, result64.status()); + auto out1 = result32.at(0); + // out1->printIndexedBuffer("OUT_I"); + auto out2 = result64.at(0); + // out2->printIndexedBuffer("OUT_L"); + + // output->printIndexedBuffer("Toggled"); + ASSERT_TRUE(expI.equalsTo(out1)); + ASSERT_TRUE(expL.equalsTo(out2)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TypesConversion_test2) { - NDArray x = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11}); - NDArray expF = NDArrayFactory::create('c', {5, 2}, {1.f,2.f,3.f,4.f,5.f,7.f,9.f,10.f, 10.f, 11.f}); - NDArray expH = NDArrayFactory::create('c', {5, 2}, {1.f,2.f,3.f,4.f,5.f,7.f,9.f,10.f, 10.f, 11.f}); + NDArray x = NDArrayFactory::create('c', {5, 2}, + {1, 2, 3, 4, 5, 7, 9, 10, 10, 11}); + NDArray expF = NDArrayFactory::create( + 'c', {5, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 7.f, 9.f, 10.f, 10.f, 11.f}); + NDArray expH = NDArrayFactory::create( + 'c', {5, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 7.f, 9.f, 10.f, 10.f, 11.f}); - sd::ops::to_float32 op32; - sd::ops::to_float16 op16; - auto result32 = op32.evaluate({&x}, {}, {}); - auto result16 = op16.evaluate({&x}, {}, {}); + sd::ops::to_float32 op32; + sd::ops::to_float16 op16; + auto result32 = op32.evaluate({&x}, {}, {}); + auto result16 = op16.evaluate({&x}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result32.status()); - ASSERT_EQ(ND4J_STATUS_OK, result16.status()); - auto out1 = result32.at(0); -// out1->printIndexedBuffer("OUT_F"); - auto out2 = result16.at(0); -// out2->printIndexedBuffer("OUT_H"); - -// output->printIndexedBuffer("Toggled"); - ASSERT_TRUE(expF.equalsTo(out1)); - ASSERT_TRUE(expH.equalsTo(out2)); + ASSERT_EQ(ND4J_STATUS_OK, result32.status()); + ASSERT_EQ(ND4J_STATUS_OK, result16.status()); + auto out1 = result32.at(0); + // out1->printIndexedBuffer("OUT_F"); + auto out2 = result16.at(0); + // out2->printIndexedBuffer("OUT_H"); + // output->printIndexedBuffer("Toggled"); + ASSERT_TRUE(expF.equalsTo(out1)); + ASSERT_TRUE(expH.equalsTo(out2)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TypesConversion_test3) { - NDArray x = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11}); - NDArray exp32 = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11}); - NDArray exp64 = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11}); + NDArray x = NDArrayFactory::create( + 'c', {5, 2}, {1, 2, 3, 4, 5, 7, 9, 10, 10, 11}); + NDArray exp32 = NDArrayFactory::create( + 'c', {5, 2}, {1, 2, 3, 4, 5, 7, 9, 10, 10, 11}); + NDArray exp64 = NDArrayFactory::create( + 'c', {5, 2}, {1, 2, 3, 4, 5, 7, 9, 10, 10, 11}); - sd::ops::to_uint32 op32; - sd::ops::to_uint64 op64; - auto result32 = op32.evaluate({&x}, {}, {}); - auto result64 = op64.evaluate({&x}, {}, {}); + sd::ops::to_uint32 op32; + sd::ops::to_uint64 op64; + auto result32 = op32.evaluate({&x}, {}, {}); + auto result64 = op64.evaluate({&x}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result32.status()); - ASSERT_EQ(ND4J_STATUS_OK, result64.status()); - auto out1 = result32.at(0); -// out1->printIndexedBuffer("OUT_U32"); - auto out2 = result64.at(0); -// out2->printIndexedBuffer("OUT_U64"); + ASSERT_EQ(ND4J_STATUS_OK, result32.status()); + ASSERT_EQ(ND4J_STATUS_OK, result64.status()); + auto out1 = result32.at(0); + // out1->printIndexedBuffer("OUT_U32"); + auto out2 = result64.at(0); + // out2->printIndexedBuffer("OUT_U64"); -// output->printIndexedBuffer("Toggled"); - ASSERT_TRUE(exp32.equalsTo(out1)); - ASSERT_TRUE(exp64.equalsTo(out2)); + // output->printIndexedBuffer("Toggled"); + ASSERT_TRUE(exp32.equalsTo(out1)); + ASSERT_TRUE(exp64.equalsTo(out2)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TypesConversion_test4) { - NDArray x = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11}); - NDArray exp32 = NDArrayFactory::create('c', {5, 2}, {1.f,2.f,3.f,4.f,5.f,7.f,9.f,10.f, 10.f, 11.f}); - NDArray exp64 = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11}); - - sd::ops::to_float32 op32; - sd::ops::to_double op64; - auto result32 = op32.evaluate({&x}, {}, {}); - auto result64 = op64.evaluate({&x}, {}, {}); + NDArray x = NDArrayFactory::create( + 'c', {5, 2}, {1, 2, 3, 4, 5, 7, 9, 10, 10, 11}); + NDArray exp32 = NDArrayFactory::create( + 'c', {5, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 7.f, 9.f, 10.f, 10.f, 11.f}); + NDArray exp64 = NDArrayFactory::create( + 'c', {5, 2}, {1, 2, 3, 4, 5, 7, 9, 10, 10, 11}); - ASSERT_EQ(ND4J_STATUS_OK, result32.status()); - ASSERT_EQ(ND4J_STATUS_OK, result64.status()); - auto out1 = result32.at(0); - auto out2 = result64.at(0); + sd::ops::to_float32 op32; + sd::ops::to_double op64; + auto result32 = op32.evaluate({&x}, {}, {}); + auto result64 = op64.evaluate({&x}, {}, {}); - ASSERT_TRUE(exp32.equalsTo(out1)); - ASSERT_TRUE(exp64.equalsTo(out2)); + ASSERT_EQ(ND4J_STATUS_OK, result32.status()); + ASSERT_EQ(ND4J_STATUS_OK, result64.status()); + auto out1 = result32.at(0); + auto out2 = result64.at(0); + ASSERT_TRUE(exp32.equalsTo(out1)); + ASSERT_TRUE(exp64.equalsTo(out2)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, mirrorPad_test1) { + auto input = + NDArrayFactory::create('c', {2, 3}, {1., 2., 3., 4., 5., 6.}); + auto paddings = NDArrayFactory::create('c', {2, 2}, {1, 1, 2, 2}); - auto input = NDArrayFactory::create('c', {2, 3}, {1., 2., 3., 4., 5., 6.}); - auto paddings = NDArrayFactory::create('c', {2, 2}, {1, 1, 2, 2}); - - auto exp = NDArrayFactory::create('c', {4, 7}, {2, 1, 1, 2, 3, 3, 2, 2, 1, 1, 2, 3, 3, 2, 5, 4, 4, 5, 6, 6, 5, 5, 4, 4, 5, 6, 6, 5}); - - sd::ops::mirror_pad op; - auto result = op.evaluate({&input, &paddings}, {}, {1}); - auto output = result.at(0); + auto exp = NDArrayFactory::create( + 'c', {4, 7}, {2, 1, 1, 2, 3, 3, 2, 2, 1, 1, 2, 3, 3, 2, + 5, 4, 4, 5, 6, 6, 5, 5, 4, 4, 5, 6, 6, 5}); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::mirror_pad op; + auto result = op.evaluate({&input, &paddings}, {}, {1}); + auto output = result.at(0); - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, mirrorPad_test2) { + auto input = + NDArrayFactory::create('c', {2, 3}, {1., 2., 3., 4., 5., 6.}); + auto paddings = NDArrayFactory::create('c', {2, 2}, {1, 1, 2, 2}); - auto input = NDArrayFactory::create('c', {2, 3}, {1., 2., 3., 4., 5., 6.}); - auto paddings = NDArrayFactory::create('c', {2, 2}, {1, 1, 2, 2}); + auto exp = NDArrayFactory::create( + 'c', {4, 7}, {6, 5, 4, 5, 6, 5, 4, 3, 2, 1, 2, 3, 2, 1, + 6, 5, 4, 5, 6, 5, 4, 3, 2, 1, 2, 3, 2, 1}); - auto exp = NDArrayFactory::create('c', {4, 7}, {6, 5, 4, 5, 6, 5, 4, 3, 2, 1, 2, 3, 2, 1, 6, 5, 4, 5, 6, 5, 4, 3, 2, 1, 2, 3, 2, 1}); + sd::ops::mirror_pad op; + auto result = op.evaluate({&input, &paddings}, {}, {0}); + auto output = result.at(0); - sd::ops::mirror_pad op; - auto result = op.evaluate({&input, &paddings}, {}, {0}); - auto output = result.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, mirrorPad_test3) { + auto input = NDArrayFactory::create('c', {3}, {1., 2., 3.}); + auto paddings = NDArrayFactory::create('c', {1, 2}, {2, 2}); - auto input = NDArrayFactory::create('c', {3}, {1., 2., 3.}); - auto paddings = NDArrayFactory::create('c', {1,2}, {2, 2}); - - auto exp = NDArrayFactory::create('c', {7}, {2, 1, 1, 2, 3, 3, 2}); - - sd::ops::mirror_pad op; - auto result = op.evaluate({&input, &paddings}, {}, {1}); - auto output = result.at(0); + auto exp = NDArrayFactory::create('c', {7}, {2, 1, 1, 2, 3, 3, 2}); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::mirror_pad op; + auto result = op.evaluate({&input, &paddings}, {}, {1}); + auto output = result.at(0); - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, mirrorPad_test4) { + auto input = NDArrayFactory::create('c', {3}, {1., 2., 3.}); + auto paddings = NDArrayFactory::create('c', {2}, {2, 3}); - auto input = NDArrayFactory::create('c', {3}, {1., 2., 3.}); - auto paddings = NDArrayFactory::create('c', {2}, {2, 3}); + auto exp = NDArrayFactory::create('c', {8}, {2, 1, 1, 2, 3, 3, 2, 1}); - auto exp = NDArrayFactory::create('c', {8}, {2, 1, 1, 2, 3, 3, 2, 1}); + sd::ops::mirror_pad op; + auto result = op.evaluate({&input, &paddings}, {}, {1}); + auto output = result.at(0); - sd::ops::mirror_pad op; - auto result = op.evaluate({&input, &paddings}, {}, {1}); - auto output = result.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, mirrorPad_test5) { + auto input = NDArrayFactory::create('c', {3}, {1., 2., 3.}); + auto paddings = NDArrayFactory::create('c', {2}, {2, 2}); - auto input = NDArrayFactory::create('c', {3}, {1., 2., 3.}); - auto paddings = NDArrayFactory::create('c', {2}, {2, 2}); - - auto exp = NDArrayFactory::create('c', {7}, {3, 2, 1, 2, 3, 2, 1}); - - sd::ops::mirror_pad op; - auto result = op.evaluate({&input, &paddings}, {}, {0}); - auto output = result.at(0); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + auto exp = NDArrayFactory::create('c', {7}, {3, 2, 1, 2, 3, 2, 1}); - + sd::ops::mirror_pad op; + auto result = op.evaluate({&input, &paddings}, {}, {0}); + auto output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, mirrorPad_test6) { + auto input = NDArrayFactory::create(1.); + auto paddings = NDArrayFactory::create('c', {1, 2, 1, 1}, {1, 1}); - auto input = NDArrayFactory::create(1.); - auto paddings = NDArrayFactory::create('c', {1,2,1,1}, {1, 1}); + auto exp = NDArrayFactory::create('c', {3}, {1, 1, 1}); - auto exp = NDArrayFactory::create('c', {3}, {1,1,1}); + sd::ops::mirror_pad op; + auto result = op.evaluate({&input, &paddings}, {}, {1}); + auto output = result.at(0); - sd::ops::mirror_pad op; - auto result = op.evaluate({&input, &paddings}, {}, {1}); - auto output = result.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, mirrorPad_test7) { + auto input = NDArrayFactory::create(1.); + auto paddings = NDArrayFactory::create('c', {2}, {1, 1}); - auto input = NDArrayFactory::create(1.); - auto paddings = NDArrayFactory::create('c', {2}, {1, 1}); - - auto exp = NDArrayFactory::create('c', {3}, {1,1,1}); - - sd::ops::mirror_pad op; - auto result = op.evaluate({&input, &paddings}, {}, {1}); - auto output = result.at(0); + auto exp = NDArrayFactory::create('c', {3}, {1, 1, 1}); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::mirror_pad op; + auto result = op.evaluate({&input, &paddings}, {}, {1}); + auto output = result.at(0); - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, mirrorPad_test8) { + auto input = NDArrayFactory::create('c', {1, 3}, {1., 2., 3.}); + auto paddings = NDArrayFactory::create('c', {2, 2}, {1, 1, 3, 3}); - auto input = NDArrayFactory::create('c', {1,3}, {1., 2., 3.}); - auto paddings = NDArrayFactory::create('c', {2, 2}, {1, 1, 3, 3}); + auto exp = NDArrayFactory::create( + 'c', {3, 9}, {3, 2, 1, 1, 2, 3, 3, 2, 1, 3, 2, 1, 1, 2, + 3, 3, 2, 1, 3, 2, 1, 1, 2, 3, 3, 2, 1}); - auto exp = NDArrayFactory::create('c', {3,9}, {3, 2, 1, 1, 2, 3, 3, 2, 1, 3, 2, 1, 1, 2, 3, 3, 2, 1, 3, 2, 1, 1, 2, 3, 3, 2, 1}); + sd::ops::mirror_pad op; + auto result = op.evaluate({&input, &paddings}, {}, {1}); + ASSERT_EQ(result.status(), Status::OK()); - sd::ops::mirror_pad op; - auto result = op.evaluate({&input, &paddings}, {}, {1}); - ASSERT_EQ(result.status(), Status::OK()); - - auto output = result.at(0); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + auto output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, mirrorPad_test9) { + auto input = + NDArrayFactory::create('c', {2, 3}, {1., 2., 3., 4., 5., 6.}); + auto paddings = NDArrayFactory::create('c', {2, 2}, {2, 2, 3, 3}); - auto input = NDArrayFactory::create('c', {2, 3}, {1., 2., 3., 4., 5., 6.}); - auto paddings = NDArrayFactory::create('c', {2, 2}, {2, 2, 3, 3}); - - auto exp = NDArrayFactory::create('c', {6, 9}, {6, 5, 4, 4, 5, 6, 6, 5, 4, 3, 2, 1, 1, 2, 3, 3, 2, 1, 3, 2, 1, 1, 2, 3, 3, 2, 1, 6, 5, 4, 4, 5, 6, 6, 5, 4, 6, 5, 4, 4, 5, 6, 6, 5, 4, 3, 2, 1, 1, 2, 3, 3, 2, 1}); - - sd::ops::mirror_pad op; - auto result = op.evaluate({&input, &paddings}, {}, {1}); - auto output = result.at(0); + auto exp = NDArrayFactory::create( + 'c', {6, 9}, {6, 5, 4, 4, 5, 6, 6, 5, 4, 3, 2, 1, 1, 2, 3, 3, 2, 1, + 3, 2, 1, 1, 2, 3, 3, 2, 1, 6, 5, 4, 4, 5, 6, 6, 5, 4, + 6, 5, 4, 4, 5, 6, 6, 5, 4, 3, 2, 1, 1, 2, 3, 3, 2, 1}); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::mirror_pad op; + auto result = op.evaluate({&input, &paddings}, {}, {1}); + auto output = result.at(0); - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, mirrorPad_test10) { + auto input = NDArrayFactory::create('c', {1, 3}, {1., 2., 3.}); + auto paddings = NDArrayFactory::create('c', {2, 2}, {0, 0, 0, 0}); - auto input = NDArrayFactory::create('c', {1,3}, {1., 2., 3.}); - auto paddings = NDArrayFactory::create('c', {2, 2}, {0, 0, 0, 0}); + auto exp = NDArrayFactory::create('c', {1, 3}, {1., 2., 3.}); - auto exp = NDArrayFactory::create('c', {1,3}, {1., 2., 3.}); + sd::ops::mirror_pad op; + auto result = op.evaluate({&input, &paddings}, {}, {1}); + auto output = result.at(0); - sd::ops::mirror_pad op; - auto result = op.evaluate({&input, &paddings}, {}, {1}); - auto output = result.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, mirrorPad_test11) { + auto input = NDArrayFactory::create('c', {1, 3}, {1., 2., 3.}); + auto paddings = NDArrayFactory::create('c', {2, 2}, {0, 0, 0, 0}); - auto input = NDArrayFactory::create('c', {1,3}, {1., 2., 3.}); - auto paddings = NDArrayFactory::create('c', {2, 2}, {0, 0, 0, 0}); - - auto exp = NDArrayFactory::create('c', {1,3}, {1., 2., 3.}); - - sd::ops::mirror_pad op; - auto result = op.evaluate({&input, &paddings}, {}, {0}); - auto output = result.at(0); + auto exp = NDArrayFactory::create('c', {1, 3}, {1., 2., 3.}); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::mirror_pad op; + auto result = op.evaluate({&input, &paddings}, {}, {0}); + auto output = result.at(0); - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, mirrorPad_test12) { + auto input = NDArrayFactory::create('c', {3}, {1., 2., 3.}); + auto paddings = NDArrayFactory::create('c', {2, 1}, {0, 0}); - auto input = NDArrayFactory::create('c', {3}, {1., 2., 3.}); - auto paddings = NDArrayFactory::create('c', {2,1}, {0, 0}); + auto exp = NDArrayFactory::create('c', {3}, {1., 2., 3.}); - auto exp = NDArrayFactory::create('c', {3}, {1., 2., 3.}); + sd::ops::mirror_pad op; + auto result = op.evaluate({&input, &paddings}, {}, {0}); + auto output = result.at(0); - sd::ops::mirror_pad op; - auto result = op.evaluate({&input, &paddings}, {}, {0}); - auto output = result.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, mirrorPad_test13) { + auto input = + NDArrayFactory::create('c', {2, 3}, {1., 2., 3., 4., 5., 6.}); + auto paddings = NDArrayFactory::create('c', {2, 2}, {0, 0, 0, 0}); - auto input = NDArrayFactory::create('c', {2, 3}, {1., 2., 3., 4., 5., 6.}); - auto paddings = NDArrayFactory::create('c', {2, 2}, {0, 0, 0, 0}); - - auto exp = NDArrayFactory::create('c', {2, 3}, {1., 2., 3., 4., 5., 6.}); - - sd::ops::mirror_pad op; - auto result = op.evaluate({&input, &paddings}, {}, {0}); - auto output = result.at(0); + auto exp = + NDArrayFactory::create('c', {2, 3}, {1., 2., 3., 4., 5., 6.}); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::mirror_pad op; + auto result = op.evaluate({&input, &paddings}, {}, {0}); + auto output = result.at(0); - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, mirrorPad_test14) { + auto input = + NDArrayFactory::create('c', {2, 3}, {1., 2., 3., 4., 5., 6.}); + auto paddings = + NDArrayFactory::create('c', {2, 2}, {1LL, 0LL, 0LL, 1LL}); - auto input = NDArrayFactory::create('c', {2, 3}, {1., 2., 3., 4., 5., 6.}); - auto paddings = NDArrayFactory::create('c', {2, 2}, {1LL, 0LL, 0LL, 1LL}); + auto exp = NDArrayFactory::create( + 'c', {3, 4}, {4, 5, 6, 5, 1, 2, 3, 2, 4, 5, 6, 5}); - auto exp = NDArrayFactory::create('c', {3, 4}, {4, 5, 6, 5, 1, 2, 3, 2, 4, 5, 6, 5}); + sd::ops::mirror_pad op; + auto result = op.evaluate({&input, &paddings}, {}, {0}); + auto output = result.at(0); - sd::ops::mirror_pad op; - auto result = op.evaluate({&input, &paddings}, {}, {0}); - auto output = result.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, mirrorPad_test15) { + auto input = + NDArrayFactory::create('c', {2, 3}, {1., 2., 3., 4., 5., 6.}); + auto paddings = NDArrayFactory::create('c', {2, 2}, {1, 1, 0, 0}); - auto input = NDArrayFactory::create('c', {2, 3}, {1., 2., 3., 4., 5., 6.}); - auto paddings = NDArrayFactory::create('c', {2, 2}, {1, 1, 0, 0}); - - auto exp = NDArrayFactory::create('c', {4, 3}, {1, 2, 3, 1, 2, 3, 4, 5, 6, 4, 5, 6}); - - sd::ops::mirror_pad op; - auto result = op.evaluate({&input, &paddings}, {}, {1}); - auto output = result.at(0); + auto exp = NDArrayFactory::create( + 'c', {4, 3}, {1, 2, 3, 1, 2, 3, 4, 5, 6, 4, 5, 6}); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::mirror_pad op; + auto result = op.evaluate({&input, &paddings}, {}, {1}); + auto output = result.at(0); - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, mirrorPad_test16) { - - auto input = NDArrayFactory::create('c', {4,3,2}); - auto paddings = NDArrayFactory::create('c', {3,2}, {3,3,2,2,1,1}); - - auto exp = NDArrayFactory::create('c', {10,7,4}, {24., 23., 24., 23.,22., 21., 22., 21.,20., 19., 20., 19.,22., 21., 22., 21.,24., 23., 24., 23.,22., 21., 22., 21.,20., 19., 20., 19.,18., 17., 18., 17.,16., 15., 16., 15.,14., 13., 14., 13.,16., 15., 16., 15.,18., 17., 18., 17.,16., 15., 16., 15.,14., 13., 14., 13., - 12., 11., 12., 11.,10., 9., 10., 9., 8., 7., 8., 7.,10., 9., 10., 9.,12., 11., 12., 11.,10., 9., 10., 9., 8., 7., 8., 7., 6., 5., 6., 5., 4., 3., 4., 3., 2., 1., 2., 1., 4., 3., 4., 3., 6., 5., 6., 5., 4., 3., 4., 3., 2., 1., 2., 1., - 12., 11., 12., 11.,10., 9., 10., 9., 8., 7., 8., 7.,10., 9., 10., 9.,12., 11., 12., 11.,10., 9., 10., 9., 8., 7., 8., 7.,18., 17., 18., 17.,16., 15., 16., 15.,14., 13., 14., 13.,16., 15., 16., 15.,18., 17., 18., 17.,16., 15., 16., 15.,14., 13., 14., 13., - 24., 23., 24., 23.,22., 21., 22., 21.,20., 19., 20., 19.,22., 21., 22., 21.,24., 23., 24., 23.,22., 21., 22., 21.,20., 19., 20., 19.,18., 17., 18., 17.,16., 15., 16., 15.,14., 13., 14., 13.,16., 15., 16., 15.,18., 17., 18., 17.,16., 15., 16., 15.,14., 13., 14., 13., - 12., 11., 12., 11.,10., 9., 10., 9., 8., 7., 8., 7.,10., 9., 10., 9.,12., 11., 12., 11.,10., 9., 10., 9., 8., 7., 8., 7., 6., 5., 6., 5., 4., 3., 4., 3., 2., 1., 2., 1., 4., 3., 4., 3., 6., 5., 6., 5., 4., 3., 4., 3., 2., 1., 2., 1.}); - input.linspace(1.); - - sd::ops::mirror_pad op; - auto result = op.evaluate({&input, &paddings}, {}, {0}); - ASSERT_EQ(result.status(), Status::OK()); - auto output = result.at(0); - //output->printBuffer("VVV"); - //exp.printBuffer("EXP"); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + auto input = NDArrayFactory::create('c', {4, 3, 2}); + auto paddings = NDArrayFactory::create('c', {3, 2}, {3, 3, 2, 2, 1, 1}); + + auto exp = NDArrayFactory::create( + 'c', {10, 7, 4}, + {24., 23., 24., 23., 22., 21., 22., 21., 20., 19., 20., 19., 22., 21., + 22., 21., 24., 23., 24., 23., 22., 21., 22., 21., 20., 19., 20., 19., + 18., 17., 18., 17., 16., 15., 16., 15., 14., 13., 14., 13., 16., 15., + 16., 15., 18., 17., 18., 17., 16., 15., 16., 15., 14., 13., 14., 13., + 12., 11., 12., 11., 10., 9., 10., 9., 8., 7., 8., 7., 10., 9., + 10., 9., 12., 11., 12., 11., 10., 9., 10., 9., 8., 7., 8., 7., + 6., 5., 6., 5., 4., 3., 4., 3., 2., 1., 2., 1., 4., 3., + 4., 3., 6., 5., 6., 5., 4., 3., 4., 3., 2., 1., 2., 1., + 12., 11., 12., 11., 10., 9., 10., 9., 8., 7., 8., 7., 10., 9., + 10., 9., 12., 11., 12., 11., 10., 9., 10., 9., 8., 7., 8., 7., + 18., 17., 18., 17., 16., 15., 16., 15., 14., 13., 14., 13., 16., 15., + 16., 15., 18., 17., 18., 17., 16., 15., 16., 15., 14., 13., 14., 13., + 24., 23., 24., 23., 22., 21., 22., 21., 20., 19., 20., 19., 22., 21., + 22., 21., 24., 23., 24., 23., 22., 21., 22., 21., 20., 19., 20., 19., + 18., 17., 18., 17., 16., 15., 16., 15., 14., 13., 14., 13., 16., 15., + 16., 15., 18., 17., 18., 17., 16., 15., 16., 15., 14., 13., 14., 13., + 12., 11., 12., 11., 10., 9., 10., 9., 8., 7., 8., 7., 10., 9., + 10., 9., 12., 11., 12., 11., 10., 9., 10., 9., 8., 7., 8., 7., + 6., 5., 6., 5., 4., 3., 4., 3., 2., 1., 2., 1., 4., 3., + 4., 3., 6., 5., 6., 5., 4., 3., 4., 3., 2., 1., 2., 1.}); + input.linspace(1.); + + sd::ops::mirror_pad op; + auto result = op.evaluate({&input, &paddings}, {}, {0}); + ASSERT_EQ(result.status(), Status::OK()); + auto output = result.at(0); + // output->printBuffer("VVV"); + // exp.printBuffer("EXP"); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Sum_1) { + auto input = NDArrayFactory::create( + 'c', {3, 5}, + {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.}); + auto exp = NDArrayFactory::create(120.f); + //************************************// - auto input = NDArrayFactory::create('c', {3, 5}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.}); - auto exp = NDArrayFactory::create(120.f); - //************************************// - - sd::ops::reduce_sum op; - auto result = op.evaluate({&input}, {}, {}); + sd::ops::reduce_sum op; + auto result = op.evaluate({&input}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); - //z->printIndexedBuffer("Result is "); - ASSERT_TRUE(exp.equalsTo(z)); - + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + // z->printIndexedBuffer("Result is "); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Sum_2) { + auto input = NDArrayFactory::create( + 'c', {3, 5}, + {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.}); + auto exp = NDArrayFactory::create({15.f, 40.f, 65.f}); + //************************************// - auto input = NDArrayFactory::create('c', {3, 5}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.}); - auto exp = NDArrayFactory::create({15.f, 40.f, 65.f}); - //************************************// + sd::ops::reduce_sum op; + auto result = op.evaluate({&input}, {}, {1}); - sd::ops::reduce_sum op; - auto result = op.evaluate({&input}, {}, {1}); - - ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); -// z->printIndexedBuffer("Result is "); - ASSERT_TRUE(exp.equalsTo(z)); - + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + // z->printIndexedBuffer("Result is "); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Prod_1) { + auto input = NDArrayFactory::create( + 'c', {3, 5}, + {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.}); + auto exp = NDArrayFactory::create(1307674368000.f); + //************************************// - auto input = NDArrayFactory::create('c', {3, 5}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.}); - auto exp = NDArrayFactory::create(1307674368000.f); - //************************************// - - sd::ops::reduce_prod op; - auto result = op.evaluate({&input}, {}, {}); + sd::ops::reduce_prod op; + auto result = op.evaluate({&input}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); - //z->printIndexedBuffer("Result is "); - ASSERT_TRUE(exp.equalsTo(z)); - + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + // z->printIndexedBuffer("Result is "); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Prod_2) { + auto input = NDArrayFactory::create( + 'c', {3, 5}, + {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.}); + auto exp = NDArrayFactory::create({120.f, 30240.f, 360360.f}); + //************************************// - auto input = NDArrayFactory::create('c', {3, 5}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.}); - auto exp = NDArrayFactory::create({120.f, 30240.f, 360360.f}); - //************************************// + sd::ops::reduce_prod op; + auto result = op.evaluate({&input}, {}, {1}); - sd::ops::reduce_prod op; - auto result = op.evaluate({&input}, {}, {1}); - - ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); -// z->printIndexedBuffer("Result is "); - ASSERT_TRUE(exp.equalsTo(z)); - + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + // z->printIndexedBuffer("Result is "); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Sum_01) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {4}, {66.f, 72.f, 78.f, 84.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2,3,4}); - auto exp = NDArrayFactory::create('c', {4}, {66.f, 72.f, 78.f, 84.f}); - x.linspace(1); - - sd::ops::reduce_sum op; - auto result = op.evaluate({&x}, {}, {0,1}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::reduce_sum op; + auto result = op.evaluate({&x}, {}, {0, 1}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Sum_02) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = + NDArrayFactory::create('c', {1, 1, 4}, {66.f, 72.f, 78.f, 84.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2,3,4}); - auto exp = NDArrayFactory::create('c', {1,1,4}, {66.f, 72.f, 78.f, 84.f}); - x.linspace(1); - - sd::ops::reduce_sum op; - auto result = op.evaluate({&x}, {1.}, {0, 1}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); + sd::ops::reduce_sum op; + auto result = op.evaluate({&x}, {1.}, {0, 1}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Sum_3) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {3}, {68.f, 100.f, 132.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2,3,4}); - auto exp = NDArrayFactory::create('c', {3}, {68.f, 100.f, 132.f}); - x.linspace(1); - - sd::ops::reduce_sum op; - auto result = op.evaluate({&x}, {}, {0, 2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); + sd::ops::reduce_sum op; + auto result = op.evaluate({&x}, {}, {0, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Sum_4) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = + NDArrayFactory::create('c', {1, 3, 1}, {68.f, 100.f, 132.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2,3,4}); - auto exp = NDArrayFactory::create('c', {1,3,1}, {68.f, 100.f, 132.f}); - x.linspace(1); - - sd::ops::reduce_sum op; - auto result = op.evaluate({&x}, {1.}, {0, 2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); + sd::ops::reduce_sum op; + auto result = op.evaluate({&x}, {1.}, {0, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Sum_5) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create(300.f); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2,3,4}); - auto exp = NDArrayFactory::create(300.f); - x.linspace(1); - - sd::ops::reduce_sum op; - auto result = op.evaluate({&x}, {}, {}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); + sd::ops::reduce_sum op; + auto result = op.evaluate({&x}, {}, {}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Sum_6) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create(300.f); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2,3,4}); - auto exp = NDArrayFactory::create(300.f); - x.linspace(1); - - sd::ops::reduce_sum op; - auto result = op.evaluate({&x}, {}, {0,1,2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); + sd::ops::reduce_sum op; + auto result = op.evaluate({&x}, {}, {0, 1, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Sum_7) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {1, 1, 1}, {300.f}); + x.linspace(1); + // x.printIndexedBuffer("Input with shape (2, 3, 4) is"); + sd::ops::reduce_sum op; + auto result = op.evaluate({&x}, {1.}, {0, 1, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - auto x = NDArrayFactory::create('c', {2,3,4}); - auto exp = NDArrayFactory::create('c', {1,1,1}, {300.f}); - x.linspace(1); -// x.printIndexedBuffer("Input with shape (2, 3, 4) is"); - sd::ops::reduce_sum op; - auto result = op.evaluate({&x}, {1.}, {0,1,2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Prod_01) { + auto x = NDArrayFactory::create('c', {2, 3, 2}); + auto exp = NDArrayFactory::create('c', {2}, {10395.f, 46080.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2,3,2}); - auto exp = NDArrayFactory::create('c', {2}, {10395.f, 46080.f}); - x.linspace(1); - - sd::ops::reduce_prod op; - auto result = op.evaluate({&x}, {}, {0,1}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::reduce_prod op; + auto result = op.evaluate({&x}, {}, {0, 1}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Prod_02) { + auto x = NDArrayFactory::create('c', {2, 3, 2}); + auto exp = NDArrayFactory::create('c', {1, 1, 2}, {10395.f, 46080.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2,3,2}); - auto exp = NDArrayFactory::create('c', {1,1,2}, {10395.f, 46080.f}); - x.linspace(1); - - sd::ops::reduce_prod op; - auto result = op.evaluate({&x}, {1.}, {0, 1}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); + sd::ops::reduce_prod op; + auto result = op.evaluate({&x}, {1.}, {0, 1}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Prod_3) { + auto x = NDArrayFactory::create('c', {2, 3, 2}); + auto exp = NDArrayFactory::create('c', {3}, {112.f, 1080.f, 3960.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2,3,2}); - auto exp = NDArrayFactory::create('c', {3}, {112.f, 1080.f, 3960.f}); - x.linspace(1); - - sd::ops::reduce_prod op; - auto result = op.evaluate({&x}, {}, {0, 2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); + sd::ops::reduce_prod op; + auto result = op.evaluate({&x}, {}, {0, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Prod_4) { + auto x = NDArrayFactory::create('c', {2, 3, 2}); + auto exp = + NDArrayFactory::create('c', {1, 3, 1}, {112.f, 1080.f, 3960.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2,3,2}); - auto exp = NDArrayFactory::create('c', {1,3,1}, {112.f, 1080.f, 3960.f}); - x.linspace(1); - - sd::ops::reduce_prod op; - auto result = op.evaluate({&x}, {1.}, {0, 2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); + sd::ops::reduce_prod op; + auto result = op.evaluate({&x}, {1.}, {0, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Prod_5) { + auto x = NDArrayFactory::create('c', {2, 3, 2}); + auto exp = NDArrayFactory::create(479001600.f); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2,3,2}); - auto exp = NDArrayFactory::create(479001600.f); - x.linspace(1); - - sd::ops::reduce_prod op; - auto result = op.evaluate({&x}, {}, {}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); + sd::ops::reduce_prod op; + auto result = op.evaluate({&x}, {}, {}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Prod_6) { + auto x = NDArrayFactory::create('c', {2, 3, 2}); + auto exp = NDArrayFactory::create(479001600.f); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2,3,2}); - auto exp = NDArrayFactory::create(479001600.f); - x.linspace(1); - - sd::ops::reduce_prod op; - auto result = op.evaluate({&x}, {}, {0,1,2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); + sd::ops::reduce_prod op; + auto result = op.evaluate({&x}, {}, {0, 1, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Prod_7) { + auto x = NDArrayFactory::create('c', {2, 3, 2}); + auto exp = NDArrayFactory::create('c', {1, 1, 1}, {479001600.f}); + x.linspace(1); + // x.printIndexedBuffer("Input with shape (2, 3, 4) is"); + sd::ops::reduce_prod op; + auto result = op.evaluate({&x}, {1.}, {0, 1, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - auto x = NDArrayFactory::create('c', {2,3,2}); - auto exp = NDArrayFactory::create('c', {1, 1, 1}, {479001600.f}); - x.linspace(1); -// x.printIndexedBuffer("Input with shape (2, 3, 4) is"); - sd::ops::reduce_prod op; - auto result = op.evaluate({&x}, {1.}, {0,1,2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } TYPED_TEST(TypedDeclarableOpsTests7, Test_Pnorm_Once_Again) { - auto input = NDArrayFactory::create('c', {1, 1, 5, 5}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f, 25.0f}); - auto exp = NDArrayFactory::create('c', {1, 1, 5, 5}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f, 25.0f}); - - sd::ops::pnormpool2d op; - auto result = op.evaluate({&input}, {}, {1,1, 1,1, 0,0, 1,1,1, 3, 0}); - ASSERT_EQ(Status::OK(), result.status()); + auto input = NDArrayFactory::create( + 'c', {1, 1, 5, 5}, + {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, + 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, + 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f, 25.0f}); + auto exp = NDArrayFactory::create( + 'c', {1, 1, 5, 5}, + {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, + 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, + 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f, 25.0f}); - ASSERT_EQ(exp, *result.at(0)); + sd::ops::pnormpool2d op; + auto result = op.evaluate({&input}, {}, {1, 1, 1, 1, 0, 0, 1, 1, 1, 3, 0}); + ASSERT_EQ(Status::OK(), result.status()); - + ASSERT_EQ(exp, result.at(0)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Min_1) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {4}, {1.f, 2.f, 3.f, 4.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {4}, {1.f, 2.f, 3.f, 4.f}); - x.linspace(1); - - sd::ops::reduce_min op; - auto result = op.evaluate({&x}, {}, {0, 1}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::reduce_min op; + auto result = op.evaluate({&x}, {}, {0, 1}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Min_2) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = + NDArrayFactory::create('c', {1, 1, 4}, {1.f, 2.f, 3.f, 4.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {1,1,4}, {1.f, 2.f, 3.f, 4.f}); - x.linspace(1); + sd::ops::reduce_min op; + auto result = op.evaluate({&x}, {1.}, {0, 1}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - sd::ops::reduce_min op; - auto result = op.evaluate({&x}, {1.}, {0, 1}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Min_3) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {3}, {1.f, 5.f, 9.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {3}, {1.f, 5.f, 9.f}); - x.linspace(1); - - sd::ops::reduce_min op; - auto result = op.evaluate({&x}, {}, {0, 2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::reduce_min op; + auto result = op.evaluate({&x}, {}, {0, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Min_4) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {1, 3, 1}, {1.f, 5.f, 9.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {1,3,1}, {1.f, 5.f, 9.f}); - x.linspace(1); + sd::ops::reduce_min op; + auto result = op.evaluate({&x}, {1.}, {0, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - sd::ops::reduce_min op; - auto result = op.evaluate({&x}, {1.}, {0, 2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Min_5) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create(1.f); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create(1.f); - x.linspace(1); - - sd::ops::reduce_min op; - auto result = op.evaluate({&x}, {}, {}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::reduce_min op; + auto result = op.evaluate({&x}, {}, {}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Min_6) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create(1.f); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create(1.f); - x.linspace(1); + sd::ops::reduce_min op; + auto result = op.evaluate({&x}, {}, {0, 1, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - sd::ops::reduce_min op; - auto result = op.evaluate({&x}, {}, {0,1,2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Min_7) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {1, 1, 1}, {1.f}); + x.linspace(1); + // x.printIndexedBuffer("Input with shape (2, 3, 4) is"); + sd::ops::reduce_min op; + auto result = op.evaluate({&x}, {1.}, {0, 1, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {1, 1, 1}, {1.f}); - x.linspace(1); -// x.printIndexedBuffer("Input with shape (2, 3, 4) is"); - sd::ops::reduce_min op; - auto result = op.evaluate({&x}, {1.}, {0,1,2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Max_1) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {4}, {21.f, 22.f, 23.f, 24.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {4}, {21.f, 22.f, 23.f, 24.f}); - x.linspace(1); + sd::ops::reduce_max op; + auto result = op.evaluate({&x}, {}, {0, 1}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + // output->printShapeInfo("Output shape"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - sd::ops::reduce_max op; - auto result = op.evaluate({&x}, {}, {0,1}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); -// output->printShapeInfo("Output shape"); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Max_2) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = + NDArrayFactory::create('c', {1, 1, 4}, {21.f, 22.f, 23.f, 24.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2,3,4}); - auto exp = NDArrayFactory::create('c', {1,1,4}, {21.f, 22.f, 23.f, 24.f}); - x.linspace(1); - - sd::ops::reduce_max op; - auto result = op.evaluate({&x}, {1.}, {0, 1}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::reduce_max op; + auto result = op.evaluate({&x}, {1.}, {0, 1}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Max_3) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {3}, {16.f, 20.f, 24.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {3}, {16.f, 20.f, 24.f}); - x.linspace(1); + sd::ops::reduce_max op; + auto result = op.evaluate({&x}, {}, {0, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - sd::ops::reduce_max op; - auto result = op.evaluate({&x}, {}, {0, 2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Max_4) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {1, 3, 1}, {16.f, 20.f, 24.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {1,3,1}, {16.f, 20.f, 24.f}); - x.linspace(1); - - sd::ops::reduce_max op; - auto result = op.evaluate({&x}, {1.}, {0, 2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::reduce_max op; + auto result = op.evaluate({&x}, {1.}, {0, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Max_5) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create(24.f); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create(24.f); - x.linspace(1); + sd::ops::reduce_max op; + auto result = op.evaluate({&x}, {}, {}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - sd::ops::reduce_max op; - auto result = op.evaluate({&x}, {}, {}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Max_6) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create(24.f); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create(24.f); - x.linspace(1); - - sd::ops::reduce_max op; - auto result = op.evaluate({&x}, {}, {0,1,2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::reduce_max op; + auto result = op.evaluate({&x}, {}, {0, 1, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Max_7) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {1, 1, 1}, {24.f}); + x.linspace(1); + // x.printIndexedBuffer("Input with shape (2, 3, 4) is"); + sd::ops::reduce_max op; + auto result = op.evaluate({&x}, {1.}, {0, 1, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {1, 1, 1}, {24.f}); - x.linspace(1); -// x.printIndexedBuffer("Input with shape (2, 3, 4) is"); - sd::ops::reduce_max op; - auto result = op.evaluate({&x}, {1.}, {0,1,2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Norm1_1) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {4}, {66.f, 72.f, 78.f, 84.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {4}, {66.f, 72.f, 78.f, 84.f}); - x.linspace(1); - - sd::ops::reduce_norm1 op; - auto result = op.evaluate({&x}, {}, {0,1}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::reduce_norm1 op; + auto result = op.evaluate({&x}, {}, {0, 1}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Norm1_2) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = + NDArrayFactory::create('c', {1, 1, 4}, {66.f, 72.f, 78.f, 84.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2,3,4}); - auto exp = NDArrayFactory::create('c', {1,1,4}, {66.f, 72.f, 78.f, 84.f}); - x.linspace(1); + sd::ops::reduce_norm1 op; + auto result = op.evaluate({&x}, {1.}, {0, 1}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - sd::ops::reduce_norm1 op; - auto result = op.evaluate({&x}, {1.}, {0, 1}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Norm1_3) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {3}, {68.f, 100.f, 132.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {3}, {68.f, 100.f, 132.f}); - x.linspace(1); - - sd::ops::reduce_norm1 op; - auto result = op.evaluate({&x}, {}, {0, 2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::reduce_norm1 op; + auto result = op.evaluate({&x}, {}, {0, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Norm1_4) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = + NDArrayFactory::create('c', {1, 3, 1}, {68.f, 100.f, 132.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {1,3,1}, {68.f, 100.f, 132.f}); - x.linspace(1); + sd::ops::reduce_norm1 op; + auto result = op.evaluate({&x}, {1.}, {0, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - sd::ops::reduce_norm1 op; - auto result = op.evaluate({&x}, {1.}, {0, 2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Norm1_5) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create(300.f); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create(300.f); - x.linspace(1); - - sd::ops::reduce_norm1 op; - auto result = op.evaluate({&x}, {}, {}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::reduce_norm1 op; + auto result = op.evaluate({&x}, {}, {}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Norm1_6) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create(300.f); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create(300.f); - x.linspace(1); + sd::ops::reduce_norm1 op; + auto result = op.evaluate({&x}, {}, {0, 1, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - sd::ops::reduce_norm1 op; - auto result = op.evaluate({&x}, {}, {0,1,2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Norm1_7) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {1, 1, 1}, {300.f}); + x.linspace(1); + // x.printIndexedBuffer("Input with shape (2, 3, 4) is"); + sd::ops::reduce_norm1 op; + auto result = op.evaluate({&x}, {1.}, {0, 1, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {1, 1, 1}, {300.f}); - x.linspace(1); -// x.printIndexedBuffer("Input with shape (2, 3, 4) is"); - sd::ops::reduce_norm1 op; - auto result = op.evaluate({&x}, {1.}, {0,1,2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } TEST_F(DeclarableOpsTests7, Test_Reduce_Norm2_1) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create( + 'c', {4}, {31.7175f, 33.823071f, 35.97221f, 38.15757f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {4}, {31.7175f, 33.823071f, 35.97221f, 38.15757f}); - x.linspace(1); + sd::ops::reduce_norm2 op; + auto result = op.evaluate({&x}, {}, {0, 1}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - sd::ops::reduce_norm2 op; - auto result = op.evaluate({&x}, {}, {0,1}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Norm2_2) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create( + 'c', {1, 1, 4}, {31.7175f, 33.823071f, 35.97221f, 38.15757f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2,3,4}); - auto exp = NDArrayFactory::create('c', {1,1,4}, {31.7175f, 33.823071f, 35.97221f, 38.15757f}); - x.linspace(1); - - sd::ops::reduce_norm2 op; - auto result = op.evaluate({&x}, {1.}, {0, 1}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::reduce_norm2 op; + auto result = op.evaluate({&x}, {1.}, {0, 1}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Norm2_3) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create( + 'c', {3}, {29.597298f, 39.344631f, 49.759422f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {3}, {29.597298f, 39.344631f, 49.759422f}); - x.linspace(1); + sd::ops::reduce_norm2 op; + auto result = op.evaluate({&x}, {}, {0, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - sd::ops::reduce_norm2 op; - auto result = op.evaluate({&x}, {}, {0, 2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Norm2_4) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create( + 'c', {1, 3, 1}, {29.597298f, 39.344631f, 49.759422f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {1,3,1}, {29.597298f, 39.344631f, 49.759422f}); - x.linspace(1); - - sd::ops::reduce_norm2 op; - auto result = op.evaluate({&x}, {1.}, {0, 2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::reduce_norm2 op; + auto result = op.evaluate({&x}, {1.}, {0, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Norm2_5) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create(70.f); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create(70.f); - x.linspace(1); + sd::ops::reduce_norm2 op; + auto result = op.evaluate({&x}, {}, {}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - sd::ops::reduce_norm2 op; - auto result = op.evaluate({&x}, {}, {}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Norm2_6) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create(70.f); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create(70.f); - x.linspace(1); - - sd::ops::reduce_norm2 op; - auto result = op.evaluate({&x}, {}, {0,1,2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::reduce_norm2 op; + auto result = op.evaluate({&x}, {}, {0, 1, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Norm2_7) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {1, 1, 1}, {70.f}); + x.linspace(1); + // x.printIndexedBuffer("Input with shape (2, 3, 4) is"); + sd::ops::reduce_norm2 op; + auto result = op.evaluate({&x}, {1.}, {0, 1, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {1, 1, 1}, {70.f}); - x.linspace(1); -// x.printIndexedBuffer("Input with shape (2, 3, 4) is"); - sd::ops::reduce_norm2 op; - auto result = op.evaluate({&x}, {1.}, {0,1,2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_NormMax_1) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {4}, {21.f, 22.f, 23.f, 24.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {4}, {21.f, 22.f, 23.f, 24.f}); - x.linspace(1); - - sd::ops::reduce_norm_max op; - auto result = op.evaluate({&x}, {}, {0,1}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::reduce_norm_max op; + auto result = op.evaluate({&x}, {}, {0, 1}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_NormMax_2) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = + NDArrayFactory::create('c', {1, 1, 4}, {21.f, 22.f, 23.f, 24.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {1,1,4}, {21.f, 22.f, 23.f, 24.f}); - x.linspace(1); + sd::ops::reduce_norm_max op; + auto result = op.evaluate({&x}, {1.f}, {0, 1}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - sd::ops::reduce_norm_max op; - auto result = op.evaluate({&x}, {1.f}, {0,1}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_NormMax_3) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {3}, {16.f, 20.f, 24.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {3}, {16.f, 20.f, 24.f}); - x.linspace(1); - - sd::ops::reduce_norm_max op; - auto result = op.evaluate({&x}, {}, {0,2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::reduce_norm_max op; + auto result = op.evaluate({&x}, {}, {0, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_NormMax_4) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {1, 3, 1}, {16.f, 20.f, 24.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {1, 3, 1}, {16.f, 20.f, 24.f}); - x.linspace(1); + sd::ops::reduce_norm_max op; + auto result = op.evaluate({&x}, {1.f}, {0, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - sd::ops::reduce_norm_max op; - auto result = op.evaluate({&x}, {1.f}, {0,2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_NormMax_5) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create(24.f); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create(24.f); - x.linspace(1); - - sd::ops::reduce_norm_max op; - auto result = op.evaluate({&x}, {}, {}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::reduce_norm_max op; + auto result = op.evaluate({&x}, {}, {}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_NormMax_6) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create(24.f); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create(24.f); - x.linspace(1); + sd::ops::reduce_norm_max op; + auto result = op.evaluate({&x}, {}, {0, 1, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - sd::ops::reduce_norm_max op; - auto result = op.evaluate({&x}, {}, {0, 1, 2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_NormMax_7) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {1, 1, 1}, {24.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {1, 1, 1}, {24.f}); - x.linspace(1); - - sd::ops::reduce_norm_max op; - auto result = op.evaluate({&x}, {1.f}, {}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::reduce_norm_max op; + auto result = op.evaluate({&x}, {1.f}, {}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_SquaredNorm_1) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {4}, + {1006.f, 1144.f, 1294.f, 1456.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {4}, {1006.f, 1144.f, 1294.f, 1456.f}); - x.linspace(1); + sd::ops::reduce_sqnorm op; + auto result = op.evaluate({&x}, {}, {0, 1}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - sd::ops::reduce_sqnorm op; - auto result = op.evaluate({&x}, {}, {0,1}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_SquaredNorm_2) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {1, 1, 4}, + {1006.f, 1144.f, 1294.f, 1456.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {1,1,4}, {1006.f, 1144.f, 1294.f, 1456.f}); - x.linspace(1); - - sd::ops::reduce_sqnorm op; - auto result = op.evaluate({&x}, {1.f}, {0,1}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::reduce_sqnorm op; + auto result = op.evaluate({&x}, {1.f}, {0, 1}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_SquaredNorm_3) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {3}, {876.f, 1548.f, 2476.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {3}, {876.f, 1548.f, 2476.f}); - x.linspace(1); + sd::ops::reduce_sqnorm op; + auto result = op.evaluate({&x}, {}, {0, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - sd::ops::reduce_sqnorm op; - auto result = op.evaluate({&x}, {}, {0,2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_SquaredNorm_4) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = + NDArrayFactory::create('c', {1, 3, 1}, {876.f, 1548.f, 2476.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {1, 3, 1}, {876.f, 1548.f, 2476.f}); - x.linspace(1); - - sd::ops::reduce_sqnorm op; - auto result = op.evaluate({&x}, {1.f}, {0,2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::reduce_sqnorm op; + auto result = op.evaluate({&x}, {1.f}, {0, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_SquaredNorm_5) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create(4900.f); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create(4900.f); - x.linspace(1); + sd::ops::reduce_sqnorm op; + auto result = op.evaluate({&x}, {}, {}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - sd::ops::reduce_sqnorm op; - auto result = op.evaluate({&x}, {}, {}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_SquaredNorm_6) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create(4900.f); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create(4900.f); - x.linspace(1); - - sd::ops::reduce_sqnorm op; - auto result = op.evaluate({&x}, {}, {0, 1, 2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::reduce_sqnorm op; + auto result = op.evaluate({&x}, {}, {0, 1, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_SquaredNorm_7) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {1, 1, 1}, {4900.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {1, 1, 1}, {4900.f}); - x.linspace(1); + sd::ops::reduce_sqnorm op; + auto result = op.evaluate({&x}, {1.f}, {}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - sd::ops::reduce_sqnorm op; - auto result = op.evaluate({&x}, {1.f}, {}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Sum_BP_1) { + auto input = NDArrayFactory::create( + 'c', {3, 4}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.}); + auto eps = NDArrayFactory::create(0.5f); + auto exp = NDArrayFactory::create( + 'c', {3, 4}, + {0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f}); + //************************************// - auto input = NDArrayFactory::create('c', {3, 4}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.}); - auto eps = NDArrayFactory::create(0.5f); - auto exp = NDArrayFactory::create('c', {3, 4}, {0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f,0.5f}); - //************************************// + sd::ops::reduce_sum_bp op; + auto result = op.evaluate({&input, &eps}, {}, {}); - sd::ops::reduce_sum_bp op; - auto result = op.evaluate({&input, &eps}, {}, {}); - - ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); -// z->printIndexedBuffer("Result is "); -// z->printShapeInfo(); - ASSERT_TRUE(exp.equalsTo(z)); - + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + // z->printIndexedBuffer("Result is "); + // z->printShapeInfo(); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Sum_BP_2) { + auto input = NDArrayFactory::create( + 'c', {3, 4}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.}); + auto eps = NDArrayFactory::create('c', {1, 1}, {0.5f}); + auto exp = NDArrayFactory::create( + 'c', {3, 4}, + {0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f}); + //************************************// - auto input = NDArrayFactory::create('c', {3, 4}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.}); - auto eps = NDArrayFactory::create('c', {1, 1}, {0.5f}); - auto exp = NDArrayFactory::create('c', {3, 4}, {0.5f, 0.5f, 0.5f, 0.5f, - 0.5f, 0.5f, 0.5f, 0.5f, - 0.5f, 0.5f, 0.5f,0.5f}); - //************************************// - - sd::ops::reduce_sum_bp op; - auto result = op.evaluate({&input, &eps}, {1.f}, {}); + sd::ops::reduce_sum_bp op; + auto result = op.evaluate({&input, &eps}, {1.f}, {}); - ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); -// z->printIndexedBuffer("Result is "); -// z->printShapeInfo(); - ASSERT_TRUE(exp.equalsTo(z)); - + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + // z->printIndexedBuffer("Result is "); + // z->printShapeInfo(); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Sum_BP_3) { + auto input = NDArrayFactory::create( + 'c', {3, 4}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.}); + auto eps = NDArrayFactory::create('c', {4}, {1.f, 2.f, 3.f, 4.f}); + auto exp = NDArrayFactory::create( + 'c', {3, 4}, + {1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f}); + //************************************// - auto input = NDArrayFactory::create('c', {3, 4}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.}); - auto eps = NDArrayFactory::create('c', {4}, {1.f, 2.f, 3.f, 4.f}); - auto exp = NDArrayFactory::create('c', {3, 4}, {1.f, 2.f, 3.f, 4.f, - 1.f, 2.f, 3.f, 4.f, - 1.f, 2.f, 3.f, 4.f}); - //************************************// + sd::ops::reduce_sum_bp op; + auto result = op.evaluate({&input, &eps}, {}, {0}); - sd::ops::reduce_sum_bp op; - auto result = op.evaluate({&input, &eps}, {}, {0}); - - ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); -// z->printIndexedBuffer("Result is "); -// z->printShapeInfo(); - ASSERT_TRUE(exp.equalsTo(z)); - + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + // z->printIndexedBuffer("Result is "); + // z->printShapeInfo(); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Sum_BP_4) { + auto input = NDArrayFactory::create( + 'c', {3, 4}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.}); + auto eps = NDArrayFactory::create('c', {1, 4}, {1.f, 2.f, 3.f, 4.f}); + auto exp = NDArrayFactory::create( + 'c', {3, 4}, + {1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f}); + //************************************// - auto input = NDArrayFactory::create('c', {3, 4}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.}); - auto eps = NDArrayFactory::create('c', {1, 4}, {1.f, 2.f, 3.f, 4.f}); - auto exp = NDArrayFactory::create('c', {3, 4}, {1.f, 2.f, 3.f, 4.f, - 1.f, 2.f, 3.f, 4.f, - 1.f, 2.f, 3.f, 4.f}); - //************************************// - - sd::ops::reduce_sum_bp op; - auto result = op.evaluate({&input, &eps}, {1.f}, {0}); + sd::ops::reduce_sum_bp op; + auto result = op.evaluate({&input, &eps}, {1.f}, {0}); - ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); -// z->printIndexedBuffer("Result is "); -// z->printShapeInfo(); - ASSERT_TRUE(exp.equalsTo(z)); - + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + // z->printIndexedBuffer("Result is "); + // z->printShapeInfo(); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Prod_BP_1) { - - auto input = NDArrayFactory::create('c', {3, 5}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f}); - auto eps = NDArrayFactory::create(1307674368000.f); - //************************************// -// auto exp = NDArrayFactory::create('c', {3, 4}, {0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f,0.5f}); - //************************************// - auto exp = NDArrayFactory::create('c', {3, 5}, {1710012166826558903812096.f, 855006083413279451906048.f, 570004067618451974258688.f, - 427503041706639725953024.f, 342002454982589992140800.f, 285002033809225987129344.f, - 244287457550765131825152.f, 213751520853319862976512.f, 190001355872817324752896.f, - 171001227491294996070400.f, 155455648254341989531648.f, 142501016904612993564672.f, - 131539399526781282156544.f, 122143728775382565912576.f, 114000815325130245799936.f}); - - sd::ops::reduce_prod_bp op; - auto result = op.evaluate({&input, &eps}, {}, {}); - - ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); -// z->printIndexedBuffer("Result is "); -// z->printShapeInfo(); - ASSERT_TRUE(exp.equalsTo(z)); - + auto input = + NDArrayFactory::create('c', {3, 5}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, + 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f}); + auto eps = NDArrayFactory::create(1307674368000.f); + //************************************// + // auto exp = NDArrayFactory::create('c', {3, 4}, {0.5f, 0.5f, + // 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f,0.5f}); + //************************************// + auto exp = NDArrayFactory::create( + 'c', {3, 5}, + {1710012166826558903812096.f, 855006083413279451906048.f, + 570004067618451974258688.f, 427503041706639725953024.f, + 342002454982589992140800.f, 285002033809225987129344.f, + 244287457550765131825152.f, 213751520853319862976512.f, + 190001355872817324752896.f, 171001227491294996070400.f, + 155455648254341989531648.f, 142501016904612993564672.f, + 131539399526781282156544.f, 122143728775382565912576.f, + 114000815325130245799936.f}); + + sd::ops::reduce_prod_bp op; + auto result = op.evaluate({&input, &eps}, {}, {}); + + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + // z->printIndexedBuffer("Result is "); + // z->printShapeInfo(); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Prod_BP_2) { - - auto input = NDArrayFactory::create('c', {3, 4}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f}); - auto eps = NDArrayFactory::create(0.5f); - //************************************// -// auto exp = NDArrayFactory::create('c', {3, 4}, {0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f,0.5f}); - //************************************// - auto exp = NDArrayFactory::create('c', {3, 4}); - - sd::ops::reduce_prod_bp op; - sd::ops::reduce_prod op_exp; - auto res = op_exp.evaluate({&input}); - auto result = op.evaluate({&input, &eps}, {}, {}); - exp.assign(res.at(0)->e(0)); - exp /= input; - exp *= eps.e(0); - ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); - //z->printIndexedBuffer("Result is "); - //exp.printIndexedBuffer("Expected"); -// z->printShapeInfo(); - ASSERT_TRUE(exp.equalsTo(z)); - - + auto input = NDArrayFactory::create( + 'c', {3, 4}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f}); + auto eps = NDArrayFactory::create(0.5f); + //************************************// + // auto exp = NDArrayFactory::create('c', {3, 4}, {0.5f, 0.5f, + // 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f,0.5f}); + //************************************// + auto exp = NDArrayFactory::create('c', {3, 4}); + + sd::ops::reduce_prod_bp op; + sd::ops::reduce_prod op_exp; + auto res = op_exp.evaluate({&input}); + auto result = op.evaluate({&input, &eps}, {}, {}); + exp.assign(res.at(0).e(0)); + exp /= input; + exp *= eps.e(0); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + // z->printIndexedBuffer("Result is "); + // exp.printIndexedBuffer("Expected"); + // z->printShapeInfo(); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Prod_BP_3) { + auto input = NDArrayFactory::create( + 'c', {3, 4}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f}); + auto eps = NDArrayFactory::create('c', {1, 4}, {1.f, 2.f, 3.f, 4.f}); + //************************************// + auto exp = + NDArrayFactory::create('c', {3, 4}, + {45.f, 120.f, 231.f, 384.f, 9.f, 40.f, + 99.f, 192.f, 5.f, 24.f, 63.f, 128.f}); - auto input = NDArrayFactory::create('c', {3, 4}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f}); - auto eps = NDArrayFactory::create('c', {1, 4}, {1.f, 2.f, 3.f, 4.f}); - //************************************// - auto exp = NDArrayFactory::create('c', {3, 4}, {45.f, 120.f, 231.f, 384.f, 9.f, 40.f, 99.f, 192.f, 5.f, 24.f, 63.f, 128.f}); + sd::ops::reduce_prod_bp op; + // sd::ops::reduce_prod op_exp; + auto result = op.evaluate({&input, &eps}, {1.f}, {0}); - sd::ops::reduce_prod_bp op; - //sd::ops::reduce_prod op_exp; - auto result = op.evaluate({&input, &eps}, {1.f}, {0}); - - ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); -// z->printIndexedBuffer("Result is "); -// exp.printIndexedBuffer("Expected"); -// z->printShapeInfo(); - ASSERT_TRUE(exp.equalsTo(z)); - + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + // z->printIndexedBuffer("Result is "); + // exp.printIndexedBuffer("Expected"); + // z->printShapeInfo(); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Prod_BP_03) { - int ax = 0; - auto input = NDArrayFactory::create('c', {3, 4}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f}); - auto eps = NDArrayFactory::create('c', {1, 4}, {1.f, 2.f, 3.f, 4.f}); - //************************************// - auto exp = NDArrayFactory::create('c', {3, 4}, {45.f, 120.f, 231.f, 384.f, 9.f, 40.f, 99.f, 192.f, 5.f, 24.f, 63.f, 128.f}); - auto axis = NDArrayFactory::create('c', {1}, {ax}); - sd::ops::reduce_prod_bp op; - //sd::ops::reduce_prod op_exp; - auto result = op.evaluate({&input, &eps, &axis}, {}, {}, {true}); - - ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); -// z->printIndexedBuffer("Result is "); -// exp.printIndexedBuffer("Expected"); -// z->printShapeInfo(); - ASSERT_TRUE(exp.equalsTo(z)); - + int ax = 0; + auto input = NDArrayFactory::create( + 'c', {3, 4}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f}); + auto eps = NDArrayFactory::create('c', {1, 4}, {1.f, 2.f, 3.f, 4.f}); + //************************************// + auto exp = + NDArrayFactory::create('c', {3, 4}, + {45.f, 120.f, 231.f, 384.f, 9.f, 40.f, + 99.f, 192.f, 5.f, 24.f, 63.f, 128.f}); + auto axis = NDArrayFactory::create('c', {1}, {ax}); + sd::ops::reduce_prod_bp op; + // sd::ops::reduce_prod op_exp; + auto result = op.evaluate({&input, &eps, &axis}, {}, {}, {true}); + + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + // z->printIndexedBuffer("Result is "); + // exp.printIndexedBuffer("Expected"); + // z->printShapeInfo(); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Prod_BP_4) { + auto input = NDArrayFactory::create( + 'c', {3, 4}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f}); + auto eps = NDArrayFactory::create('c', {4}, {1.f, 2.f, 3.f, 4.f}); + //************************************// + auto exp = + NDArrayFactory::create('c', {3, 4}, + {45.f, 120.f, 231.f, 384.f, 9.f, 40.f, + 99.f, 192.f, 5.f, 24.f, 63.f, 128.f}); - auto input = NDArrayFactory::create('c', {3, 4}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f}); - auto eps = NDArrayFactory::create('c', {4}, {1.f, 2.f, 3.f, 4.f}); - //************************************// - auto exp = NDArrayFactory::create('c', {3, 4}, {45.f, 120.f, 231.f, 384.f, 9.f, 40.f, 99.f, 192.f, 5.f, 24.f, 63.f, 128.f}); + sd::ops::reduce_prod_bp op; + sd::ops::reduce_prod op_exp; + // auto res = op_exp.execute({&input}, {}, {}); + auto result = op.evaluate({&input, &eps}, {0.f}, {0}); - sd::ops::reduce_prod_bp op; - sd::ops::reduce_prod op_exp; -// auto res = op_exp.execute({&input}, {}, {}); - auto result = op.evaluate({&input, &eps}, {0.f}, {0}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + // z->printIndexedBuffer("Result is "); + // exp.printIndexedBuffer("Expected"); + // z->printShapeInfo(); + ASSERT_TRUE(exp.equalsTo(z)); - ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); -// z->printIndexedBuffer("Result is "); -// exp.printIndexedBuffer("Expected"); -// z->printShapeInfo(); - ASSERT_TRUE(exp.equalsTo(z)); - -// + // } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Prod_BP_5) { + auto input = NDArrayFactory::create( + 'c', {3, 4}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f}); + auto eps = NDArrayFactory::create('c', {3}, {1.f, 2.f, 3.f}); + //************************************// + auto exp = + NDArrayFactory::create('c', {3, 4}, + {24.f, 12.f, 8.f, 6.f, 672.f, 560.f, 480.f, + 420.f, 3960.f, 3564.f, 3240.f, 2970.f}); - auto input = NDArrayFactory::create('c', {3, 4}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f}); - auto eps = NDArrayFactory::create('c', {3}, {1.f, 2.f, 3.f}); - //************************************// - auto exp = NDArrayFactory::create('c', {3, 4}, {24.f, 12.f, 8.f, 6.f, 672.f, 560.f, 480.f, 420.f, 3960.f, 3564.f, 3240.f, 2970.f}); + sd::ops::reduce_prod_bp op; + sd::ops::reduce_prod op_exp; + // auto res = op_exp.execute({&input}, {}, {}); + auto result = op.evaluate({&input, &eps}, {0.f}, {1}); - sd::ops::reduce_prod_bp op; - sd::ops::reduce_prod op_exp; -// auto res = op_exp.execute({&input}, {}, {}); - auto result = op.evaluate({&input, &eps}, {0.f}, {1}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + // z->printIndexedBuffer("Result is "); + // exp.printIndexedBuffer("Expected"); + // z->printShapeInfo(); + ASSERT_TRUE(exp.equalsTo(z)); - ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); -// z->printIndexedBuffer("Result is "); -// exp.printIndexedBuffer("Expected"); -// z->printShapeInfo(); - ASSERT_TRUE(exp.equalsTo(z)); - -// + // } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Min_BP_1) { - - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto eps = NDArrayFactory::create('c', {4}, {1.f, 2.f, 3.f, 4.f}); - auto exp = NDArrayFactory::create('c', {2, 3, 4}); - exp.p(0, eps.e(0)); - exp.p(1, eps.e(1)); - exp.p(2, eps.e(2)); - exp.p(3, eps.e(3)); - x.linspace(1); -// x.printIndexedBuffer("Input is"); -// exp.printIndexedBuffer("Expected "); - sd::ops::reduce_min_bp op; - auto result = op.evaluate({&x, &eps}, {}, {0, 1}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto eps = NDArrayFactory::create('c', {4}, {1.f, 2.f, 3.f, 4.f}); + auto exp = NDArrayFactory::create('c', {2, 3, 4}); + exp.p(0, eps.e(0)); + exp.p(1, eps.e(1)); + exp.p(2, eps.e(2)); + exp.p(3, eps.e(3)); + x.linspace(1); + // x.printIndexedBuffer("Input is"); + // exp.printIndexedBuffer("Expected "); + sd::ops::reduce_min_bp op; + auto result = op.evaluate({&x, &eps}, {}, {0, 1}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Min_BP_2) { - - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto eps = NDArrayFactory::create('c', {1, 1, 4}, {1.f, 2.f, 3.f, 4.f}); - auto exp = NDArrayFactory::create('c', {2, 3, 4}); - exp.p(0, eps.e(0)); - exp.p(1, eps.e(1)); - exp.p(2, eps.e(2)); - exp.p(3, eps.e(3)); - x.linspace(1); -// x.printIndexedBuffer("Input is"); -// exp.printIndexedBuffer("Expected "); - sd::ops::reduce_min_bp op; - auto result = op.evaluate({&x, &eps}, {1.f}, {0, 1}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto eps = + NDArrayFactory::create('c', {1, 1, 4}, {1.f, 2.f, 3.f, 4.f}); + auto exp = NDArrayFactory::create('c', {2, 3, 4}); + exp.p(0, eps.e(0)); + exp.p(1, eps.e(1)); + exp.p(2, eps.e(2)); + exp.p(3, eps.e(3)); + x.linspace(1); + // x.printIndexedBuffer("Input is"); + // exp.printIndexedBuffer("Expected "); + sd::ops::reduce_min_bp op; + auto result = op.evaluate({&x, &eps}, {1.f}, {0, 1}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Min_BP_02) { - - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto eps = NDArrayFactory::create('c', {1, 1, 4}, {1.f, 2.f, 3.f, 4.f}); - auto exp = NDArrayFactory::create('c', {2, 3, 4}); - exp.p(0, eps.e(0)); - exp.p(1, eps.e(1)); - exp.p(2, eps.e(2)); - exp.p(3, eps.e(3)); - auto axes = NDArrayFactory::create({0,1}); - x.linspace(1); -// x.printIndexedBuffer("Input is"); -// exp.printIndexedBuffer("Expected "); - sd::ops::reduce_min_bp op; - auto result = op.evaluate({&x, &eps, &axes}, {}, {}, {true}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto eps = + NDArrayFactory::create('c', {1, 1, 4}, {1.f, 2.f, 3.f, 4.f}); + auto exp = NDArrayFactory::create('c', {2, 3, 4}); + exp.p(0, eps.e(0)); + exp.p(1, eps.e(1)); + exp.p(2, eps.e(2)); + exp.p(3, eps.e(3)); + auto axes = NDArrayFactory::create({0, 1}); + x.linspace(1); + // x.printIndexedBuffer("Input is"); + // exp.printIndexedBuffer("Expected "); + sd::ops::reduce_min_bp op; + auto result = op.evaluate({&x, &eps, &axes}, {}, {}, {true}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Min_BP_3) { + auto x = NDArrayFactory::create('c', {3, 4}); + auto eps = NDArrayFactory::create('c', {1, 1}, {0.5f}); + auto exp = NDArrayFactory::create('c', {3, 4}); + x.linspace(1); + x.p(2, 2, -1.f); + exp.p(2, 2, 0.5f); + // x.printIndexedBuffer("Input is"); + // exp.printIndexedBuffer("Expected "); + sd::ops::reduce_min_bp op; + auto result = op.evaluate({&x, &eps}, {1.f}, {}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto x = NDArrayFactory::create('c', {3, 4}); - auto eps = NDArrayFactory::create('c', {1, 1}, {0.5f}); - auto exp = NDArrayFactory::create('c', {3, 4}); - x.linspace(1); - x.p(2,2, -1.f); - exp.p(2,2, 0.5f); - //x.printIndexedBuffer("Input is"); - // exp.printIndexedBuffer("Expected "); - sd::ops::reduce_min_bp op; - auto result = op.evaluate({&x, &eps}, {1.f}, {}); - auto output = result.at(0); - // output->printIndexedBuffer("Result is"); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Min_BP_4) { + auto x = NDArrayFactory::create('c', {3, 4}); + auto eps = NDArrayFactory::create(0.5f); + auto exp = NDArrayFactory::create('c', {3, 4}); + x.linspace(1); + x.p(2, 2, -1.f); + exp.p(2, 2, 0.5f); + // x.printIndexedBuffer("Input is"); + // exp.printIndexedBuffer("Expected "); + sd::ops::reduce_min_bp op; + auto result = op.evaluate({&x, &eps}, {}, {}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto x = NDArrayFactory::create('c', {3, 4}); - auto eps = NDArrayFactory::create(0.5f); - auto exp = NDArrayFactory::create('c', {3, 4}); - x.linspace(1); - x.p(2,2, -1.f); - exp.p(2,2, 0.5f); -// x.printIndexedBuffer("Input is"); -// exp.printIndexedBuffer("Expected "); - sd::ops::reduce_min_bp op; - auto result = op.evaluate({&x, &eps}, {}, {}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Min_BP_5) { - - auto x = NDArrayFactory::create('c', {4, 4}); - auto eps = NDArrayFactory::create('c', {4}, {1.f, 2.f, 3.f, 4.f}); - auto exp = NDArrayFactory::create('c', {4, 4}); - x.linspace(1); - x.p(0,0, -1.f); - x.p(1,1, -2.f); - x.p(2,2, -3.f); - x.p(3,3, -4.f); - exp.p(0,0, 1.f); - exp.p(1,1, 2.f); - exp.p(2,2, 3.f); - exp.p(3,3, 4.f); -// exp(2,2) = 0.5f; -// x.printIndexedBuffer("Input is"); -// exp.printIndexedBuffer("Expected "); - sd::ops::reduce_min_bp op; - auto result = op.evaluate({&x, &eps}, {}, {0}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + auto x = NDArrayFactory::create('c', {4, 4}); + auto eps = NDArrayFactory::create('c', {4}, {1.f, 2.f, 3.f, 4.f}); + auto exp = NDArrayFactory::create('c', {4, 4}); + x.linspace(1); + x.p(0, 0, -1.f); + x.p(1, 1, -2.f); + x.p(2, 2, -3.f); + x.p(3, 3, -4.f); + exp.p(0, 0, 1.f); + exp.p(1, 1, 2.f); + exp.p(2, 2, 3.f); + exp.p(3, 3, 4.f); + // exp(2,2) = 0.5f; + // x.printIndexedBuffer("Input is"); + // exp.printIndexedBuffer("Expected "); + sd::ops::reduce_min_bp op; + auto result = op.evaluate({&x, &eps}, {}, {0}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Min_BP_6) { - - auto x = NDArrayFactory::create('c', {4, 4}); - auto eps = NDArrayFactory::create('c', {1,4}, {1.f, 2.f, 3.f, 4.f}); - auto exp = NDArrayFactory::create('c', {4, 4}); - x.linspace(1); - x.p(0,0, -1.f); - x.p(1,1, -2.f); - x.p(2,2, -3.f); - x.p(3,3, -4.f); - exp.p(0,0, 1.f); - exp.p(1,1, 2.f); - exp.p(2,2, 3.f); - exp.p(3,3, 4.f); -// exp(2,2) = 0.5f; -// x.printIndexedBuffer("Input is"); -// exp.printIndexedBuffer("Expected "); - sd::ops::reduce_min_bp op; - auto result = op.evaluate({&x, &eps}, {1.f}, {0}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + auto x = NDArrayFactory::create('c', {4, 4}); + auto eps = NDArrayFactory::create('c', {1, 4}, {1.f, 2.f, 3.f, 4.f}); + auto exp = NDArrayFactory::create('c', {4, 4}); + x.linspace(1); + x.p(0, 0, -1.f); + x.p(1, 1, -2.f); + x.p(2, 2, -3.f); + x.p(3, 3, -4.f); + exp.p(0, 0, 1.f); + exp.p(1, 1, 2.f); + exp.p(2, 2, 3.f); + exp.p(3, 3, 4.f); + // exp(2,2) = 0.5f; + // x.printIndexedBuffer("Input is"); + // exp.printIndexedBuffer("Expected "); + sd::ops::reduce_min_bp op; + auto result = op.evaluate({&x, &eps}, {1.f}, {0}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Max_BP_1) { - - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto eps = NDArrayFactory::create('c', {4}, {21.f, 22.f, 23.f, 24.f}); - auto exp = NDArrayFactory::create('c', {2, 3, 4}); - exp.p(20, eps.e(0)); - exp.p(21, eps.e(1)); - exp.p(22, eps.e(2)); - exp.p(23, eps.e(3)); - x.linspace(1); - // x.printIndexedBuffer("Input is"); - // exp.printIndexedBuffer("Expected "); - sd::ops::reduce_max_bp op; - auto result = op.evaluate({&x, &eps}, {}, {0, 1}); - auto output = result.at(0); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto eps = NDArrayFactory::create('c', {4}, {21.f, 22.f, 23.f, 24.f}); + auto exp = NDArrayFactory::create('c', {2, 3, 4}); + exp.p(20, eps.e(0)); + exp.p(21, eps.e(1)); + exp.p(22, eps.e(2)); + exp.p(23, eps.e(3)); + x.linspace(1); + // x.printIndexedBuffer("Input is"); + // exp.printIndexedBuffer("Expected "); + sd::ops::reduce_max_bp op; + auto result = op.evaluate({&x, &eps}, {}, {0, 1}); + auto output = result.at(0); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Max_BP_2) { - - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto eps = NDArrayFactory::create('c', {1, 1, 4}, {21.f, 22.f, 23.f, 24.f}); - auto exp = NDArrayFactory::create('c', {2, 3, 4}); - exp.p(20, eps.e(0)); - exp.p(21, eps.e(1)); - exp.p(22, eps.e(2)); - exp.p(23, eps.e(3)); - x.linspace(1); -// x.printIndexedBuffer("Input is"); -// exp.printIndexedBuffer("Expected "); - sd::ops::reduce_max_bp op; - auto result = op.evaluate({&x, &eps}, {1.f}, {0, 1}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto eps = + NDArrayFactory::create('c', {1, 1, 4}, {21.f, 22.f, 23.f, 24.f}); + auto exp = NDArrayFactory::create('c', {2, 3, 4}); + exp.p(20, eps.e(0)); + exp.p(21, eps.e(1)); + exp.p(22, eps.e(2)); + exp.p(23, eps.e(3)); + x.linspace(1); + // x.printIndexedBuffer("Input is"); + // exp.printIndexedBuffer("Expected "); + sd::ops::reduce_max_bp op; + auto result = op.evaluate({&x, &eps}, {1.f}, {0, 1}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Max_BP_02) { - - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto eps = NDArrayFactory::create('c', {1, 1, 4}, {21.f, 22.f, 23.f, 24.f}); - auto exp = NDArrayFactory::create('c', {2, 3, 4}); - exp.p(20, eps.e(0)); - exp.p(21, eps.e(1)); - exp.p(22, eps.e(2)); - exp.p(23, eps.e(3)); - auto axes = NDArrayFactory::create({0, 1}); - x.linspace(1); -// x.printIndexedBuffer("Input is"); -// exp.printIndexedBuffer("Expected "); - sd::ops::reduce_max_bp op; - auto result = op.evaluate({&x, &eps, &axes}, {}, {}, {true}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto eps = + NDArrayFactory::create('c', {1, 1, 4}, {21.f, 22.f, 23.f, 24.f}); + auto exp = NDArrayFactory::create('c', {2, 3, 4}); + exp.p(20, eps.e(0)); + exp.p(21, eps.e(1)); + exp.p(22, eps.e(2)); + exp.p(23, eps.e(3)); + auto axes = NDArrayFactory::create({0, 1}); + x.linspace(1); + // x.printIndexedBuffer("Input is"); + // exp.printIndexedBuffer("Expected "); + sd::ops::reduce_max_bp op; + auto result = op.evaluate({&x, &eps, &axes}, {}, {}, {true}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Max_BP_3) { - - auto x = NDArrayFactory::create('c', {4, 4}); - auto eps = NDArrayFactory::create('c', {4}, {1.f, 2.f, 3.f, 4.f}); - auto exp = NDArrayFactory::create('c', {4, 4}); - x.linspace(1); - x.p(0,0, 21.f); - x.p(1,1, 22.f); - x.p(2,2, 23.f); - x.p(3,3, 24.f); - exp.p(0,0, 1.f); - exp.p(1,1, 2.f); - exp.p(2,2, 3.f); - exp.p(3,3, 4.f); -// x.printIndexedBuffer("Input is"); -// exp.printIndexedBuffer("Expected "); - sd::ops::reduce_max_bp op; - auto result = op.evaluate({&x, &eps}, {}, {0}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + auto x = NDArrayFactory::create('c', {4, 4}); + auto eps = NDArrayFactory::create('c', {4}, {1.f, 2.f, 3.f, 4.f}); + auto exp = NDArrayFactory::create('c', {4, 4}); + x.linspace(1); + x.p(0, 0, 21.f); + x.p(1, 1, 22.f); + x.p(2, 2, 23.f); + x.p(3, 3, 24.f); + exp.p(0, 0, 1.f); + exp.p(1, 1, 2.f); + exp.p(2, 2, 3.f); + exp.p(3, 3, 4.f); + // x.printIndexedBuffer("Input is"); + // exp.printIndexedBuffer("Expected "); + sd::ops::reduce_max_bp op; + auto result = op.evaluate({&x, &eps}, {}, {0}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Max_BP_4) { - - auto x = NDArrayFactory::create('c', {4, 4}); - auto eps = NDArrayFactory::create('c', {1,4}, {1.f, 2.f, 3.f, 4.f}); - auto exp = NDArrayFactory::create('c', {4, 4}); - x.linspace(1); - x.p(0,0, 21.f); - x.p(1,1, 22.f); - x.p(2,2, 23.f); - x.p(3,3, 24.f); - exp.p(0,0, 1.f); - exp.p(1,1, 2.f); - exp.p(2,2, 3.f); - exp.p(3,3, 4.f); - -// x.printIndexedBuffer("Input is"); -// exp.printIndexedBuffer("Expected "); - sd::ops::reduce_max_bp op; - auto result = op.evaluate({&x, &eps}, {1.f}, {0}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + auto x = NDArrayFactory::create('c', {4, 4}); + auto eps = NDArrayFactory::create('c', {1, 4}, {1.f, 2.f, 3.f, 4.f}); + auto exp = NDArrayFactory::create('c', {4, 4}); + x.linspace(1); + x.p(0, 0, 21.f); + x.p(1, 1, 22.f); + x.p(2, 2, 23.f); + x.p(3, 3, 24.f); + exp.p(0, 0, 1.f); + exp.p(1, 1, 2.f); + exp.p(2, 2, 3.f); + exp.p(3, 3, 4.f); + + // x.printIndexedBuffer("Input is"); + // exp.printIndexedBuffer("Expected "); + sd::ops::reduce_max_bp op; + auto result = op.evaluate({&x, &eps}, {1.f}, {0}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Norm1_BP_1) { - - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto eps = NDArrayFactory::create(5.f); - x.linspace(1); - auto exp = NDArrayFactory::create('c', {2, 3, 4}); - x.p(12, -2.f); - x.p(20, -3.f); - exp.assign(5.f); - exp.p(12, -exp.e(12)); - exp.p(20, -exp.e(20)); - sd::ops::reduce_norm1_bp op; - auto result = op.evaluate({&x, &eps}, {}, {}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto eps = NDArrayFactory::create(5.f); + x.linspace(1); + auto exp = NDArrayFactory::create('c', {2, 3, 4}); + x.p(12, -2.f); + x.p(20, -3.f); + exp.assign(5.f); + exp.p(12, -exp.e(12)); + exp.p(20, -exp.e(20)); + sd::ops::reduce_norm1_bp op; + auto result = op.evaluate({&x, &eps}, {}, {}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Norm1_BP_2) { - - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto eps = NDArrayFactory::create({1.f, 2.f, 3.f, 4.f}); - x.linspace(1); - auto exp = NDArrayFactory::create('c', {2, 3, 4}, {1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f,1.f, 2.f, 3.f, 4.f,1.f, 2.f, 3.f, 4.f}); - sd::ops::reduce_norm1_bp op; - auto result = op.evaluate({&x, &eps}, {}, {0,1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); - // output->printIndexedBuffer("Result is"); - // exp.printIndexedBuffer("Expect is"); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto eps = NDArrayFactory::create({1.f, 2.f, 3.f, 4.f}); + x.linspace(1); + auto exp = NDArrayFactory::create( + 'c', {2, 3, 4}, + {1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f, + 1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f}); + sd::ops::reduce_norm1_bp op; + auto result = op.evaluate({&x, &eps}, {}, {0, 1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + // exp.printIndexedBuffer("Expect is"); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Norm1_BP_02) { - - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto eps = NDArrayFactory::create({1.f, 2.f, 3.f, 4.f}); - x.linspace(1); - auto exp = NDArrayFactory::create('c', {2, 3, 4}, {1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f,1.f, 2.f, 3.f, 4.f,1.f, 2.f, 3.f, 4.f}); - auto axes = NDArrayFactory::create({0,1}); - sd::ops::reduce_norm1_bp op; - auto result = op.evaluate({&x, &eps, &axes}, {}, {}, {false}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto eps = NDArrayFactory::create({1.f, 2.f, 3.f, 4.f}); + x.linspace(1); + auto exp = NDArrayFactory::create( + 'c', {2, 3, 4}, + {1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f, + 1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f}); + auto axes = NDArrayFactory::create({0, 1}); + sd::ops::reduce_norm1_bp op; + auto result = op.evaluate({&x, &eps, &axes}, {}, {}, {false}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Norm1_BP_3) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto eps = + NDArrayFactory::create('c', {1, 1, 4}, {1.f, 2.f, 3.f, 4.f}); + x.linspace(1); + auto exp = NDArrayFactory::create( + 'c', {2, 3, 4}, + {1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f, + 1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f}); + sd::ops::reduce_norm1_bp op; + auto result = op.evaluate({&x, &eps}, {1.f}, {0, 1}); + auto output = result.at(0); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto eps = NDArrayFactory::create('c', {1, 1, 4}, {1.f, 2.f, 3.f, 4.f}); - x.linspace(1); - auto exp = NDArrayFactory::create('c', {2, 3, 4}, {1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f,1.f, 2.f, 3.f, 4.f,1.f, 2.f, 3.f, 4.f}); - sd::ops::reduce_norm1_bp op; - auto result = op.evaluate({&x, &eps}, {1.f}, {0,1}); - auto output = result.at(0); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Norm2_BP_1) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto eps = NDArrayFactory::create( + 'c', {4}, {31.7175f, 33.823071f, 35.97221f, 38.15757f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto eps = NDArrayFactory::create('c', {4}, {31.7175f, 33.823071f, 35.97221f, 38.15757f}); - x.linspace(1); - - sd::ops::reduce_norm2_bp op; - auto result = op.evaluate({&x, &eps}, {}, {0,1}); - auto output = result.at(0); - // output->printIndexedBuffer("Result is"); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::reduce_norm2_bp op; + auto result = op.evaluate({&x, &eps}, {}, {0, 1}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(x.isSameShape(output)); - ASSERT_TRUE(x.equalsTo(output)); - - + ASSERT_TRUE(x.isSameShape(output)); + ASSERT_TRUE(x.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Norm2_BP_2) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto eps = NDArrayFactory::create( + 'c', {1, 1, 4}, {31.7175f, 33.823071f, 35.97221f, 38.15757f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto eps = NDArrayFactory::create('c', {1, 1, 4}, {31.7175f, 33.823071f, 35.97221f, 38.15757f}); - x.linspace(1); - - sd::ops::reduce_norm2_bp op; - auto result = op.evaluate({&x, &eps}, {1.f}, {0,1}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::reduce_norm2_bp op; + auto result = op.evaluate({&x, &eps}, {1.f}, {0, 1}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(x.isSameShape(output)); - ASSERT_TRUE(x.equalsTo(output)); - - + ASSERT_TRUE(x.isSameShape(output)); + ASSERT_TRUE(x.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Norm2_BP_02) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto eps = NDArrayFactory::create( + 'c', {1, 1, 4}, {31.7175f, 33.823071f, 35.97221f, 38.15757f}); + auto axes = NDArrayFactory::create({0, 1}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto eps = NDArrayFactory::create('c', {1, 1, 4}, {31.7175f, 33.823071f, 35.97221f, 38.15757f}); - auto axes = NDArrayFactory::create({0, 1}); - x.linspace(1); - - sd::ops::reduce_norm2_bp op; - auto result = op.evaluate({&x, &eps, &axes}, {}, {}, {true}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::reduce_norm2_bp op; + auto result = op.evaluate({&x, &eps, &axes}, {}, {}, {true}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(x.isSameShape(output)); - ASSERT_TRUE(x.equalsTo(output)); - - + ASSERT_TRUE(x.isSameShape(output)); + ASSERT_TRUE(x.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Norm2_BP_3) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto eps = NDArrayFactory::create( + 'c', {3}, {29.597298f, 39.344631f, 49.759422f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto eps = NDArrayFactory::create('c', {3}, {29.597298f, 39.344631f, 49.759422f}); - x.linspace(1); - - sd::ops::reduce_norm2_bp op; - auto result = op.evaluate({&x, &eps}, {}, {0, 2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); + sd::ops::reduce_norm2_bp op; + auto result = op.evaluate({&x, &eps}, {}, {0, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(x.isSameShape(output)); - ASSERT_TRUE(x.equalsTo(output)); - - + ASSERT_TRUE(x.isSameShape(output)); + ASSERT_TRUE(x.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Norm2_BP_4) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto eps = NDArrayFactory::create( + 'c', {1, 3, 1}, {29.597298f, 39.344631f, 49.759422f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto eps = NDArrayFactory::create('c', {1,3,1}, {29.597298f, 39.344631f, 49.759422f}); - x.linspace(1); - - sd::ops::reduce_norm2_bp op; - auto result = op.evaluate({&x, &eps}, {1.f}, {0, 2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); + sd::ops::reduce_norm2_bp op; + auto result = op.evaluate({&x, &eps}, {1.f}, {0, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(x.isSameShape(output)); - ASSERT_TRUE(x.equalsTo(output)); - - + ASSERT_TRUE(x.isSameShape(output)); + ASSERT_TRUE(x.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_SquaredNorm_BP_1) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto eps = NDArrayFactory::create('c', {4}, {1.f, 2.f, 3.f, 4.f}); + auto exp = NDArrayFactory::create( + 'c', {2, 3, 4}, {2.f, 8.f, 18.f, 32.f, 10.f, 24.f, 42.f, 64.f, + 18.f, 40.f, 66.f, 96.f, 26.f, 56.f, 90.f, 128.f, + 34.f, 72.f, 114.f, 160.f, 42.f, 88.f, 138.f, 192.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto eps = NDArrayFactory::create('c', {4}, {1.f, 2.f, 3.f, 4.f}); - auto exp = NDArrayFactory::create('c', {2, 3, 4}, { 2.f, 8.f, 18.f, 32.f, - 10.f, 24.f, 42.f, 64.f, - 18.f, 40.f, 66.f, 96.f, - 26.f, 56.f, 90.f, 128.f, - 34.f, 72.f, 114.f, 160.f, - 42.f, 88.f, 138.f, 192.f}); - x.linspace(1); - - sd::ops::reduce_sqnorm_bp op; - auto result = op.evaluate({&x, &eps}, {}, {0,1}); + sd::ops::reduce_sqnorm_bp op; + auto result = op.evaluate({&x, &eps}, {}, {0, 1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_SquaredNorm_BP_01) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto eps = NDArrayFactory::create('c', {4}, {1.f, 2.f, 3.f, 4.f}); + auto exp = NDArrayFactory::create( + 'c', {2, 3, 4}, {2.f, 8.f, 18.f, 32.f, 10.f, 24.f, 42.f, 64.f, + 18.f, 40.f, 66.f, 96.f, 26.f, 56.f, 90.f, 128.f, + 34.f, 72.f, 114.f, 160.f, 42.f, 88.f, 138.f, 192.f}); + auto axes = NDArrayFactory::create({0, 1}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto eps = NDArrayFactory::create('c', {4}, {1.f, 2.f, 3.f, 4.f}); - auto exp = NDArrayFactory::create('c', {2, 3, 4}, { 2.f, 8.f, 18.f, 32.f, - 10.f, 24.f, 42.f, 64.f, - 18.f, 40.f, 66.f, 96.f, - 26.f, 56.f, 90.f, 128.f, - 34.f, 72.f, 114.f, 160.f, - 42.f, 88.f, 138.f, 192.f}); - auto axes = NDArrayFactory::create({0, 1}); - x.linspace(1); - - sd::ops::reduce_sqnorm_bp op; - auto result = op.evaluate({&x, &eps, &axes}, {}, {}, {false}); + sd::ops::reduce_sqnorm_bp op; + auto result = op.evaluate({&x, &eps, &axes}, {}, {}, {false}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_NormMax_BP_1) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto eps = NDArrayFactory::create('c', {4}, {1.f, 2.f, 3.f, 4.f}); + auto exp = NDArrayFactory::create('c', {2, 3, 4}); + x.linspace(1); + exp.p(20, 1.f); + exp.p(21, 2.f); + exp.p(22, 3.f); + exp.p(23, 4.f); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto eps = NDArrayFactory::create('c', {4}, {1.f, 2.f, 3.f, 4.f}); - auto exp = NDArrayFactory::create('c', {2, 3, 4}); - x.linspace(1); - exp.p(20, 1.f); - exp.p(21, 2.f); - exp.p(22, 3.f); - exp.p(23, 4.f); - - sd::ops::reduce_norm_max_bp op; - auto result = op.evaluate({&x, &eps}, {}, {0,1}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::reduce_norm_max_bp op; + auto result = op.evaluate({&x, &eps}, {}, {0, 1}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_NormMax_BP_2) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto eps = + NDArrayFactory::create('c', {1, 1, 4}, {1.f, 2.f, 3.f, 4.f}); + auto exp = NDArrayFactory::create('c', {2, 3, 4}); + x.linspace(1); + exp.p(20, 1.f); + exp.p(21, 2.f); + exp.p(22, 3.f); + exp.p(23, 4.f); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto eps = NDArrayFactory::create('c', {1,1,4}, {1.f, 2.f, 3.f, 4.f}); - auto exp = NDArrayFactory::create('c', {2, 3, 4}); - x.linspace(1); - exp.p(20, 1.f); - exp.p(21, 2.f); - exp.p(22, 3.f); - exp.p(23, 4.f); - - sd::ops::reduce_norm_max_bp op; - auto result = op.evaluate({&x, &eps}, {1.f}, {0,1}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::reduce_norm_max_bp op; + auto result = op.evaluate({&x, &eps}, {1.f}, {0, 1}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_NormMax_BP_02) { - - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto eps = NDArrayFactory::create('c', {1,1,4}, {1.f, 2.f, 3.f, 4.f}); - auto exp = NDArrayFactory::create('c', {2, 3, 4}); - auto axes = NDArrayFactory::create({0,1}); - x.linspace(1); - exp.p(20, 1.f); - exp.p(21, 2.f); - exp.p(22, 3.f); - exp.p(23, 4.f); - - sd::ops::reduce_norm_max_bp op; - auto result = op.evaluate({&x, &eps, &axes}, {}, {}, {true}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto eps = + NDArrayFactory::create('c', {1, 1, 4}, {1.f, 2.f, 3.f, 4.f}); + auto exp = NDArrayFactory::create('c', {2, 3, 4}); + auto axes = NDArrayFactory::create({0, 1}); + x.linspace(1); + exp.p(20, 1.f); + exp.p(21, 2.f); + exp.p(22, 3.f); + exp.p(23, 4.f); + + sd::ops::reduce_norm_max_bp op; + auto result = op.evaluate({&x, &eps, &axes}, {}, {}, {true}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_NormMax_BP_3) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto eps = NDArrayFactory::create('c', {3}, {1.f, 2.f, 3.f}); + auto exp = NDArrayFactory::create('c', {2, 3, 4}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto eps = NDArrayFactory::create('c', {3}, {1.f, 2.f, 3.f}); - auto exp = NDArrayFactory::create('c', {2, 3, 4}); - x.linspace(1); - - exp.p(15, 1.f); - exp.p(19, 2.f); - exp.p(23, 3.f); + exp.p(15, 1.f); + exp.p(19, 2.f); + exp.p(23, 3.f); - sd::ops::reduce_norm_max_bp op; - auto result = op.evaluate({&x, &eps}, {}, {0,2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::reduce_norm_max_bp op; + auto result = op.evaluate({&x, &eps}, {}, {0, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_NormMax_BP_4) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto eps = NDArrayFactory::create('c', {1, 3, 1}, {1.f, 2.f, 3.f}); + auto exp = NDArrayFactory::create('c', {2, 3, 4}); + x.linspace(1); + exp.p(15, 1.f); + exp.p(19, 2.f); + exp.p(23, 3.f); + sd::ops::reduce_norm_max_bp op; + auto result = op.evaluate({&x, &eps}, {1.f}, {0, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto eps = NDArrayFactory::create('c', {1, 3, 1}, {1.f, 2.f, 3.f}); - auto exp = NDArrayFactory::create('c', {2, 3, 4}); - x.linspace(1); - exp.p(15, 1.f); - exp.p(19, 2.f); - exp.p(23, 3.f); - sd::ops::reduce_norm_max_bp op; - auto result = op.evaluate({&x, &eps}, {1.f}, {0,2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_NormMax_BP_5) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto eps = NDArrayFactory::create(1.f); + auto exp = NDArrayFactory::create('c', {2, 3, 4}); + x.linspace(1); + exp.p(23, 1.f); + sd::ops::reduce_norm_max_bp op; + auto result = op.evaluate({&x, &eps}, {}, {}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto eps = NDArrayFactory::create(1.f); - auto exp = NDArrayFactory::create('c', {2, 3, 4}); - x.linspace(1); - exp.p(23, 1.f); - sd::ops::reduce_norm_max_bp op; - auto result = op.evaluate({&x, &eps}, {}, {}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_NormMax_BP_6) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto eps = NDArrayFactory::create(1.f); + auto exp = NDArrayFactory::create('c', {2, 3, 4}); + x.linspace(1); + exp.p(23, 1.f); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto eps = NDArrayFactory::create(1.f); - auto exp = NDArrayFactory::create('c', {2, 3, 4}); - x.linspace(1); - exp.p(23, 1.f); - - sd::ops::reduce_norm_max_bp op; - auto result = op.evaluate({&x, &eps}, {}, {0, 1, 2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); + sd::ops::reduce_norm_max_bp op; + auto result = op.evaluate({&x, &eps}, {}, {0, 1, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_NormMax_BP_7) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto eps = NDArrayFactory::create('c', {1, 1, 1}, {1.f}); + auto exp = NDArrayFactory::create('c', {2, 3, 4}); + x.linspace(1); + exp.p(23, 1.f); + sd::ops::reduce_norm_max_bp op; + auto result = op.evaluate({&x, &eps}, {1.f}, {}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto eps = NDArrayFactory::create('c', {1, 1, 1}, {1.f}); - auto exp = NDArrayFactory::create('c', {2, 3, 4}); - x.linspace(1); - exp.p(23, 1.f); - sd::ops::reduce_norm_max_bp op; - auto result = op.evaluate({&x, &eps}, {1.f}, {}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Dot_BP_1) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto y = NDArrayFactory::create('c', {2, 3, 4}); + NDArray* z; // = NDArrayFactory::create('c', {4}); + auto eps = NDArrayFactory::create(1.f); + // auto exp = NDArrayFactory::create('c', {2, 3, 4}); + x.linspace(1); + y.linspace(2); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto y = NDArrayFactory::create('c', {2, 3, 4}); - NDArray* z; // = NDArrayFactory::create('c', {4}); - auto eps = NDArrayFactory::create(1.f); -// auto exp = NDArrayFactory::create('c', {2, 3, 4}); - x.linspace(1); - y.linspace(2); - + sd::ops::reduce_dot_bp op; + auto result = op.evaluate({&x, &y, &eps}, {}, {}); + auto output = result.at(0); + auto outputX = result.at(1); + // tput->printIndexedBuffer("Result is"); - sd::ops::reduce_dot_bp op; - auto result = op.evaluate({&x, &y, &eps}, {}, {}); - auto output = result.at(0); - auto outputX = result.at(1); - //tput->printIndexedBuffer("Result is"); + // ASSERT_EQ(ND4J_STATUS_OK, result.status()); -// ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_TRUE(x.equalsTo(outputX)); + ASSERT_TRUE(y.equalsTo(output)); - ASSERT_TRUE(x.equalsTo(outputX)); - ASSERT_TRUE(y.equalsTo(output)); - - -// delete z; + // delete z; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Dot_BP_2) { - - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto y = NDArrayFactory::create('c', {2, 3, 4}); -// auto z; // = NDArrayFactory::create('c', {4}); - auto eps = NDArrayFactory::create('c', {2, 4}); - auto expX = NDArrayFactory::create('c', {2, 3, 4}, {2.f, 4.f, 6.f, 8.f, 2.f, 4.f, 6.f, 8.f, 2.f, 4.f, 6.f, 8.f, - 10.f, 12.f, 14.f, 16.f, 10.f, 12.f, 14.f, 16.f, 10.f, 12.f, 14.f, 16.f - }); - auto expY = NDArrayFactory::create('c', {2, 3, 4}, {1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f, - 5.f, 6.f, 7.f, 8.f, 5.f, 6.f, 7.f, 8.f, 5.f, 6.f, 7.f, 8.f - }); - x.assign(1.f); - eps.linspace(1); - y.assign(2.f); - sd::ops::reduce_dot_bp op; - auto result = op.evaluate({&x, &y, &eps}, {}, {1}); - ASSERT_EQ(result.status(), ND4J_STATUS_OK); - ASSERT_EQ(result.size(), 2); - auto outputX = result.at(0); - auto outputY = result.at(1); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - ASSERT_TRUE(expX.equalsTo(outputX)); - ASSERT_TRUE(expY.equalsTo(outputY)); - - -// delete z; + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto y = NDArrayFactory::create('c', {2, 3, 4}); + // auto z; // = NDArrayFactory::create('c', {4}); + auto eps = NDArrayFactory::create('c', {2, 4}); + auto expX = NDArrayFactory::create( + 'c', {2, 3, 4}, + {2.f, 4.f, 6.f, 8.f, 2.f, 4.f, 6.f, 8.f, 2.f, 4.f, 6.f, 8.f, + 10.f, 12.f, 14.f, 16.f, 10.f, 12.f, 14.f, 16.f, 10.f, 12.f, 14.f, 16.f}); + auto expY = NDArrayFactory::create( + 'c', {2, 3, 4}, + {1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f, + 5.f, 6.f, 7.f, 8.f, 5.f, 6.f, 7.f, 8.f, 5.f, 6.f, 7.f, 8.f}); + x.assign(1.f); + eps.linspace(1); + y.assign(2.f); + sd::ops::reduce_dot_bp op; + auto result = op.evaluate({&x, &y, &eps}, {}, {1}); + ASSERT_EQ(result.status(), ND4J_STATUS_OK); + ASSERT_EQ(result.size(), 2); + auto outputX = result.at(0); + auto outputY = result.at(1); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(expX.equalsTo(outputX)); + ASSERT_TRUE(expY.equalsTo(outputY)); + + // delete z; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Dot_BP_02) { - - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto y = NDArrayFactory::create('c', {2, 3, 4}); -// auto z; // = NDArrayFactory::create('c', {4}); - auto eps = NDArrayFactory::create('c', {2, 4}); - auto expX = NDArrayFactory::create('c', {2, 3, 4}, {2.f, 4.f, 6.f, 8.f, 2.f, 4.f, 6.f, 8.f, 2.f, 4.f, 6.f, 8.f, - 10.f, 12.f, 14.f, 16.f, 10.f, 12.f, 14.f, 16.f, 10.f, 12.f, 14.f, 16.f - }); - auto expY = NDArrayFactory::create('c', {2, 3, 4}, {1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f, - 5.f, 6.f, 7.f, 8.f, 5.f, 6.f, 7.f, 8.f, 5.f, 6.f, 7.f, 8.f - }); - auto axis = NDArrayFactory::create('c', {1}, {1}); - x.assign(1.f); - eps.linspace(1); - y.assign(2.f); - sd::ops::reduce_dot_bp op; - auto result = op.evaluate({&x, &y, &eps, &axis}, {}, {}, {false}); - ASSERT_EQ(result.status(), ND4J_STATUS_OK); - ASSERT_EQ(result.size(), 2); - auto outputX = result.at(0); - auto outputY = result.at(1); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - ASSERT_TRUE(expX.equalsTo(outputX)); - ASSERT_TRUE(expY.equalsTo(outputY)); - - -// delete z; + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto y = NDArrayFactory::create('c', {2, 3, 4}); + // auto z; // = NDArrayFactory::create('c', {4}); + auto eps = NDArrayFactory::create('c', {2, 4}); + auto expX = NDArrayFactory::create( + 'c', {2, 3, 4}, + {2.f, 4.f, 6.f, 8.f, 2.f, 4.f, 6.f, 8.f, 2.f, 4.f, 6.f, 8.f, + 10.f, 12.f, 14.f, 16.f, 10.f, 12.f, 14.f, 16.f, 10.f, 12.f, 14.f, 16.f}); + auto expY = NDArrayFactory::create( + 'c', {2, 3, 4}, + {1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f, + 5.f, 6.f, 7.f, 8.f, 5.f, 6.f, 7.f, 8.f, 5.f, 6.f, 7.f, 8.f}); + auto axis = NDArrayFactory::create('c', {1}, {1}); + x.assign(1.f); + eps.linspace(1); + y.assign(2.f); + sd::ops::reduce_dot_bp op; + auto result = op.evaluate({&x, &y, &eps, &axis}, {}, {}, {false}); + ASSERT_EQ(result.status(), ND4J_STATUS_OK); + ASSERT_EQ(result.size(), 2); + auto outputX = result.at(0); + auto outputY = result.at(1); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(expX.equalsTo(outputX)); + ASSERT_TRUE(expY.equalsTo(outputY)); + + // delete z; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Dot_BP_3) { - - auto x = NDArrayFactory::create('c', {3, 4}); - auto y = NDArrayFactory::create('c', {3, 4}); - auto eps = NDArrayFactory::create('c', {3}); - auto expX = NDArrayFactory::create('c', {3, 4}, {2.f, 2.f, 2.f, 2.f, 4.f, 4.f, 4.f, 4.f, 6.f, 6.f, 6.f, 6.f}); - auto expY = NDArrayFactory::create('c', {3, 4}, {1.f, 2.f, 3.f, 4.f, 10.f, 12.f, 14.f, 16.f, 27.f, 30.f, 33.f, 36.f}); - x.linspace(1); - eps.linspace(1); - y.assign(2.f); - - sd::ops::reduce_dot_bp op; - auto result = op.evaluate({&x,&y, &eps}, {}, {1}); - auto outputX = result.at(0); - auto outputY = result.at(1); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(expX.equalsTo(outputX)); - ASSERT_TRUE(expY.equalsTo(outputY)); - - + auto x = NDArrayFactory::create('c', {3, 4}); + auto y = NDArrayFactory::create('c', {3, 4}); + auto eps = NDArrayFactory::create('c', {3}); + auto expX = NDArrayFactory::create( + 'c', {3, 4}, + {2.f, 2.f, 2.f, 2.f, 4.f, 4.f, 4.f, 4.f, 6.f, 6.f, 6.f, 6.f}); + auto expY = NDArrayFactory::create( + 'c', {3, 4}, + {1.f, 2.f, 3.f, 4.f, 10.f, 12.f, 14.f, 16.f, 27.f, 30.f, 33.f, 36.f}); + x.linspace(1); + eps.linspace(1); + y.assign(2.f); + + sd::ops::reduce_dot_bp op; + auto result = op.evaluate({&x, &y, &eps}, {}, {1}); + auto outputX = result.at(0); + auto outputY = result.at(1); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_TRUE(expX.equalsTo(outputX)); + ASSERT_TRUE(expY.equalsTo(outputY)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, cumsum_bp_1) { + auto x = NDArrayFactory::create('c', {3, 4}); + auto eps = NDArrayFactory::create('c', {3, 4}); + auto exp = NDArrayFactory::create( + 'c', {3, 4}, + {12.f, 11.f, 10.f, 9.f, 8.f, 7.f, 6.f, 5.f, 4.f, 3.f, 2.f, 1.f}); + x.linspace(1); + eps.assign(1.f); - auto x = NDArrayFactory::create('c', {3, 4}); - auto eps = NDArrayFactory::create('c', {3, 4}); - auto exp = NDArrayFactory::create('c', {3, 4}, {12.f, 11.f, 10.f, 9.f, 8.f, 7.f, - 6.f, 5.f, 4.f, 3.f, 2.f, 1.f}); - x.linspace(1); - eps.assign(1.f); - - sd::ops::cumsum_bp op; - auto result = op.evaluate({&x, &eps}, {}, {0,0}); - auto output = result.at(0); + sd::ops::cumsum_bp op; + auto result = op.evaluate({&x, &eps}, {}, {0, 0}); + auto output = result.at(0); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, cumsum_bp_2) { - auto x = NDArrayFactory::create('c', {3, 4}); - auto eps = NDArrayFactory::create('c', {3, 4}); - auto exp = NDArrayFactory::create('c', {3, 4}, { 11.f, 10.f, 9.f, 8.f, 7.f, 6.f, - 5.f, 4.f, 3.f, 2.f, 1.f, 0.f}); - x.linspace(1); - eps.assign(1.f); - + auto x = NDArrayFactory::create('c', {3, 4}); + auto eps = NDArrayFactory::create('c', {3, 4}); + auto exp = NDArrayFactory::create( + 'c', {3, 4}, + {11.f, 10.f, 9.f, 8.f, 7.f, 6.f, 5.f, 4.f, 3.f, 2.f, 1.f, 0.f}); + x.linspace(1); + eps.assign(1.f); - sd::ops::cumsum_bp op; - auto result = op.evaluate({&x, &eps}, {}, {1,0}); - auto output = result.at(0); + sd::ops::cumsum_bp op; + auto result = op.evaluate({&x, &eps}, {}, {1, 0}); + auto output = result.at(0); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, cumsum_bp_3) { - auto x = NDArrayFactory::create('c', {3, 4}); - auto eps = NDArrayFactory::create('c', {3, 4}); - auto exp = NDArrayFactory::create('c', {3, 4}); - - x.linspace(1); - exp.linspace(0); - eps.assign(1.f); + auto x = NDArrayFactory::create('c', {3, 4}); + auto eps = NDArrayFactory::create('c', {3, 4}); + auto exp = NDArrayFactory::create('c', {3, 4}); - sd::ops::cumsum_bp op; - auto result = op.evaluate({&x, &eps}, {}, {1,1}); - auto output = result.at(0); + x.linspace(1); + exp.linspace(0); + eps.assign(1.f); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.equalsTo(output)); - - + sd::ops::cumsum_bp op; + auto result = op.evaluate({&x, &eps}, {}, {1, 1}); + auto output = result.at(0); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_TRUE(exp.equalsTo(output)); } - diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests8.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests8.cpp index 0ca5e210ab92..b51ff96745f7 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests8.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests8.cpp @@ -18,3506 +18,3456 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 10.06.2018 // - -#include "testlayers.h" -#include #include #include +#include + +#include "testlayers.h" // #include using namespace sd; - class DeclarableOpsTests8 : public testing::Test { -public: - - DeclarableOpsTests8() { - printf("\n"); - fflush(stdout); - } + public: + DeclarableOpsTests8() { + printf("\n"); + fflush(stdout); + } }; template class TypedDeclarableOpsTests8 : public testing::Test { -public: - - TypedDeclarableOpsTests8() { - printf("\n"); - fflush(stdout); - } + public: + TypedDeclarableOpsTests8() { + printf("\n"); + fflush(stdout); + } }; typedef ::testing::Types TestingTypes; -TYPED_TEST_CASE(TypedDeclarableOpsTests8, TestingTypes); +TYPED_TEST_SUITE(TypedDeclarableOpsTests8, TestingTypes); //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, reduceVariance_test1) { + auto x = NDArrayFactory::create( + 'c', {2, 3, 4}, + {27.f, 34.f, 5.f, 4.f, 54.f, 6.f, 65.f, 8.f, 37.f, 45.f, 8.f, 67.f, + 96.f, 10.f, 65.f, 41.f, 33.f, 85.f, 92.f, 24.f, 25.f, 55.f, 49.f, 76.f}); + auto exp = NDArrayFactory::create( + 'c', {4}, {602.2222f, 727.13885f, 993.5555f, 755.8889f}); - auto x = NDArrayFactory::create('c', {2,3,4}, {27.f,34.f,5.f,4.f,54.f,6.f,65.f,8.f,37.f,45.f,8.f,67.f,96.f,10.f,65.f,41.f,33.f,85.f,92.f,24.f,25.f,55.f,49.f,76.f}); - auto exp = NDArrayFactory::create('c', {4}, {602.2222f, 727.13885f, 993.5555f, 755.8889f}); - - sd::ops::reduce_variance op; - auto result = op.evaluate({&x}, {}, {0,1}); - auto output = result.at(0); - - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::reduce_variance op; + auto result = op.evaluate({&x}, {}, {0, 1}); + auto output = result.at(0); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, reduceVariance_test2) { + auto x = NDArrayFactory::create( + 'c', {2, 3, 4}, + {27.f, 34.f, 5.f, 4.f, 54.f, 6.f, 65.f, 8.f, 37.f, 45.f, 8.f, 67.f, + 96.f, 10.f, 65.f, 41.f, 33.f, 85.f, 92.f, 24.f, 25.f, 55.f, 49.f, 76.f}); + auto exp = NDArrayFactory::create( + 'c', {1, 1, 4}, {602.2222f, 727.13885f, 993.5555f, 755.8889f}); - auto x = NDArrayFactory::create('c', {2,3,4}, {27.f,34.f,5.f,4.f,54.f,6.f,65.f,8.f,37.f,45.f,8.f,67.f,96.f,10.f,65.f,41.f,33.f,85.f,92.f,24.f,25.f,55.f,49.f,76.f}); - auto exp = NDArrayFactory::create('c', {1,1,4}, {602.2222f, 727.13885f, 993.5555f, 755.8889f}); - - sd::ops::reduce_variance op; - auto result = op.evaluate({&x}, {1.}, {0,1}); - auto output = result.at(0); - - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::reduce_variance op; + auto result = op.evaluate({&x}, {1.}, {0, 1}); + auto output = result.at(0); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, reduceVariance_test3) { + auto x = NDArrayFactory::create( + 'c', {2, 3, 4}, + {27.f, 34.f, 5.f, 4.f, 54.f, 6.f, 65.f, 8.f, 37.f, 45.f, 8.f, 67.f, + 96.f, 10.f, 65.f, 41.f, 33.f, 85.f, 92.f, 24.f, 25.f, 55.f, 49.f, 76.f}); + auto exp = NDArrayFactory::create('c', {3}, + {900.9375f, 969.8594f, 424.1875f}); - auto x = NDArrayFactory::create('c', {2,3,4}, {27.f,34.f,5.f,4.f,54.f,6.f,65.f,8.f,37.f,45.f,8.f,67.f,96.f,10.f,65.f,41.f,33.f,85.f,92.f,24.f,25.f,55.f,49.f,76.f}); - auto exp = NDArrayFactory::create('c', {3}, {900.9375f, 969.8594f, 424.1875f}); - - sd::ops::reduce_variance op; - auto result = op.evaluate({&x}, {}, {0,2}); - auto output = result.at(0); - - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::reduce_variance op; + auto result = op.evaluate({&x}, {}, {0, 2}); + auto output = result.at(0); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, reduceVariance_test4) { + auto x = NDArrayFactory::create( + 'c', {2, 3, 4}, + {27.f, 34.f, 5.f, 4.f, 54.f, 6.f, 65.f, 8.f, 37.f, 45.f, 8.f, 67.f, + 96.f, 10.f, 65.f, 41.f, 33.f, 85.f, 92.f, 24.f, 25.f, 55.f, 49.f, 76.f}); + auto exp = NDArrayFactory::create('c', {1, 3, 1}, + {900.9375f, 969.8594f, 424.1875f}); - auto x = NDArrayFactory::create('c', {2,3,4}, {27.f,34.f,5.f,4.f,54.f,6.f,65.f,8.f,37.f,45.f,8.f,67.f,96.f,10.f,65.f,41.f,33.f,85.f,92.f,24.f,25.f,55.f,49.f,76.f}); - auto exp = NDArrayFactory::create('c', {1,3,1}, {900.9375f, 969.8594f, 424.1875f}); - - sd::ops::reduce_variance op; - auto result = op.evaluate({&x}, {1.}, {0,2}); - auto output = result.at(0); - - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::reduce_variance op; + auto result = op.evaluate({&x}, {1.}, {0, 2}); + auto output = result.at(0); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, reduceVariance_test5) { + auto x = NDArrayFactory::create( + 'c', {2, 3, 4}, + {27.f, 34.f, 5.f, 4.f, 54.f, 6.f, 65.f, 8.f, 37.f, 45.f, 8.f, 67.f, + 96.f, 10.f, 65.f, 41.f, 33.f, 85.f, 92.f, 24.f, 25.f, 55.f, 49.f, 76.f}); + auto exp = NDArrayFactory::create(788.6927f); - auto x = NDArrayFactory::create('c', {2,3,4}, {27.f,34.f,5.f,4.f,54.f,6.f,65.f,8.f,37.f,45.f,8.f,67.f,96.f,10.f,65.f,41.f,33.f,85.f,92.f,24.f,25.f,55.f,49.f,76.f}); - auto exp = NDArrayFactory::create(788.6927f); - - sd::ops::reduce_variance op; - auto result = op.evaluate({&x}, {}, {}); - auto output = result.at(0); - - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::reduce_variance op; + auto result = op.evaluate({&x}, {}, {}); + auto output = result.at(0); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, reduceVariance_test6) { + auto x = NDArrayFactory::create( + 'c', {2, 3, 4}, + {27.f, 34.f, 5.f, 4.f, 54.f, 6.f, 65.f, 8.f, 37.f, 45.f, 8.f, 67.f, + 96.f, 10.f, 65.f, 41.f, 33.f, 85.f, 92.f, 24.f, 25.f, 55.f, 49.f, 76.}); + auto exp = NDArrayFactory::create(788.6927f); - auto x = NDArrayFactory::create('c', {2,3,4}, {27.f,34.f,5.f,4.f,54.f,6.f,65.f,8.f,37.f,45.f,8.f,67.f,96.f,10.f,65.f,41.f,33.f,85.f,92.f,24.f,25.f,55.f,49.f,76.}); - auto exp = NDArrayFactory::create(788.6927f); - - sd::ops::reduce_variance op; - auto result = op.evaluate({&x}, {}, {0,1,2}); - auto output = result.at(0); - - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::reduce_variance op; + auto result = op.evaluate({&x}, {}, {0, 1, 2}); + auto output = result.at(0); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, reduceVariance_test7) { + auto x = NDArrayFactory::create( + 'c', {2, 3, 4}, + {27.f, 34.f, 5.f, 4.f, 54.f, 6.f, 65.f, 8.f, 37.f, 45.f, 8.f, 67.f, + 96.f, 10.f, 65.f, 41.f, 33.f, 85.f, 92.f, 24.f, 25.f, 55.f, 49.f, 76.}); + auto exp = NDArrayFactory::create('c', {1, 1, 1}, {788.6927f}); - auto x = NDArrayFactory::create('c', {2,3,4}, {27.f,34.f,5.f,4.f,54.f,6.f,65.f,8.f,37.f,45.f,8.f,67.f,96.f,10.f,65.f,41.f,33.f,85.f,92.f,24.f,25.f,55.f,49.f,76.}); - auto exp = NDArrayFactory::create('c', {1,1,1}, {788.6927f}); - - sd::ops::reduce_variance op; - auto result = op.evaluate({&x}, {1.}, {0,1,2}); - auto output = result.at(0); - - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::reduce_variance op; + auto result = op.evaluate({&x}, {1.}, {0, 1, 2}); + auto output = result.at(0); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, reduceVariance_test8) { + auto x = NDArrayFactory::create( + 'c', {2, 3, 4}, + {27.f, 34.f, 5.f, 4.f, 54.f, 6.f, 65.f, 8.f, 37.f, 45.f, 8.f, 67.f, + 96.f, 10.f, 65.f, 41.f, 33.f, 85.f, 92.f, 24.f, 25.f, 55.f, 49.f, 76.}); + auto exp = NDArrayFactory::create('c', {1, 1, 1}, {788.6927f}); + auto axes = NDArrayFactory::create({0, 1, 2}); + sd::ops::reduce_variance op; + auto result = op.evaluate({&x, &axes}, {}, {}, {true}); + auto output = result.at(0); - auto x = NDArrayFactory::create('c', {2,3,4}, {27.f,34.f,5.f,4.f,54.f,6.f,65.f,8.f,37.f,45.f,8.f,67.f,96.f,10.f,65.f,41.f,33.f,85.f,92.f,24.f,25.f,55.f,49.f,76.}); - auto exp = NDArrayFactory::create('c', {1,1,1}, {788.6927f}); - auto axes = NDArrayFactory::create({0, 1, 2}); - sd::ops::reduce_variance op; - auto result = op.evaluate({&x, &axes}, {}, {}, {true}); - auto output = result.at(0); - - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, reduceStDev_test1) { + auto x = NDArrayFactory::create( + 'c', {2, 3, 4}, + {27.f, 34.f, 5.f, 4.f, 54.f, 6.f, 65.f, 8.f, 37.f, 45.f, 8.f, 67.f, + 96.f, 10.f, 65.f, 41.f, 33.f, 85.f, 92.f, 24.f, 25.f, 55.f, 49.f, 76.}); + auto exp = NDArrayFactory::create( + 'c', {4}, {24.54022f, 26.96551f, 31.52072f, 27.49343f}); - auto x = NDArrayFactory::create('c', {2,3,4}, {27.f,34.f,5.f,4.f,54.f,6.f,65.f,8.f,37.f,45.f,8.f,67.f,96.f,10.f,65.f,41.f,33.f,85.f,92.f,24.f,25.f,55.f,49.f,76.}); - auto exp = NDArrayFactory::create('c', {4}, {24.54022f, 26.96551f, 31.52072f, 27.49343f}); - - sd::ops::reduce_stdev op; - auto result = op.evaluate({&x}, {}, {0,1}); - auto output = result.at(0); - - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::reduce_stdev op; + auto result = op.evaluate({&x}, {}, {0, 1}); + auto output = result.at(0); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, reduceStDev_test2) { + auto x = NDArrayFactory::create( + 'c', {2, 3, 4}, + {27.f, 34.f, 5.f, 4.f, 54.f, 6.f, 65.f, 8.f, 37.f, 45.f, 8.f, 67.f, + 96.f, 10.f, 65.f, 41.f, 33.f, 85.f, 92.f, 24.f, 25.f, 55.f, 49.f, 76.}); + auto exp = NDArrayFactory::create( + 'c', {1, 1, 4}, {24.54022f, 26.96551f, 31.52072f, 27.49343f}); - auto x = NDArrayFactory::create('c', {2,3,4}, {27.f,34.f,5.f,4.f,54.f,6.f,65.f,8.f,37.f,45.f,8.f,67.f,96.f,10.f,65.f,41.f,33.f,85.f,92.f,24.f,25.f,55.f,49.f,76.}); - auto exp = NDArrayFactory::create('c', {1,1,4}, {24.54022f, 26.96551f, 31.52072f, 27.49343f}); - - sd::ops::reduce_stdev op; - auto result = op.evaluate({&x}, {1.}, {0,1}); - auto output = result.at(0); - - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::reduce_stdev op; + auto result = op.evaluate({&x}, {1.}, {0, 1}); + auto output = result.at(0); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, reduceStDev_test3) { + auto x = NDArrayFactory::create( + 'c', {2, 3, 4}, + {27.f, 34.f, 5.f, 4.f, 54.f, 6.f, 65.f, 8.f, 37.f, 45.f, 8.f, 67.f, + 96.f, 10.f, 65.f, 41.f, 33.f, 85.f, 92.f, 24.f, 25.f, 55.f, 49.f, 76.}); + auto exp = NDArrayFactory::create('c', {3}, + {30.01562f, 31.14257f, 20.59581f}); - auto x = NDArrayFactory::create('c', {2,3,4}, {27.f,34.f,5.f,4.f,54.f,6.f,65.f,8.f,37.f,45.f,8.f,67.f,96.f,10.f,65.f,41.f,33.f,85.f,92.f,24.f,25.f,55.f,49.f,76.}); - auto exp = NDArrayFactory::create('c', {3}, {30.01562f, 31.14257f, 20.59581f}); - - sd::ops::reduce_stdev op; - auto result = op.evaluate({&x}, {}, {0,2}); - auto output = result.at(0); - - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::reduce_stdev op; + auto result = op.evaluate({&x}, {}, {0, 2}); + auto output = result.at(0); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, reduceStDev_test4) { + auto x = NDArrayFactory::create( + 'c', {2, 3, 4}, + {27.f, 34.f, 5.f, 4.f, 54.f, 6.f, 65.f, 8.f, 37.f, 45.f, 8.f, 67.f, + 96.f, 10.f, 65.f, 41.f, 33.f, 85.f, 92.f, 24.f, 25.f, 55.f, 49.f, 76.}); + auto exp = NDArrayFactory::create('c', {1, 3, 1}, + {30.01562f, 31.14257f, 20.59581f}); - auto x = NDArrayFactory::create('c', {2,3,4}, {27.f,34.f,5.f,4.f,54.f,6.f,65.f,8.f,37.f,45.f,8.f,67.f,96.f,10.f,65.f,41.f,33.f,85.f,92.f,24.f,25.f,55.f,49.f,76.}); - auto exp = NDArrayFactory::create('c', {1,3,1}, {30.01562f, 31.14257f, 20.59581f}); - - sd::ops::reduce_stdev op; - auto result = op.evaluate({&x}, {1.}, {0,2}); - auto output = result.at(0); - - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::reduce_stdev op; + auto result = op.evaluate({&x}, {1.}, {0, 2}); + auto output = result.at(0); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, reduceStDev_test5) { + auto x = NDArrayFactory::create( + 'c', {2, 3, 4}, + {27.f, 34.f, 5.f, 4.f, 54.f, 6.f, 65.f, 8.f, 37.f, 45.f, 8.f, 67.f, + 96.f, 10.f, 65.f, 41.f, 33.f, 85.f, 92.f, 24.f, 25.f, 55.f, 49.f, 76.}); + auto exp = NDArrayFactory::create(28.08367f); - auto x = NDArrayFactory::create('c', {2,3,4}, {27.f,34.f,5.f,4.f,54.f,6.f,65.f,8.f,37.f,45.f,8.f,67.f,96.f,10.f,65.f,41.f,33.f,85.f,92.f,24.f,25.f,55.f,49.f,76.}); - auto exp = NDArrayFactory::create(28.08367f); - - sd::ops::reduce_stdev op; - auto result = op.evaluate({&x}, {}, {}); - auto output = result.at(0); - - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::reduce_stdev op; + auto result = op.evaluate({&x}, {}, {}); + auto output = result.at(0); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, reduceStDev_test6) { + auto x = NDArrayFactory::create( + 'c', {2, 3, 4}, + {27.f, 34.f, 5.f, 4.f, 54.f, 6.f, 65.f, 8.f, 37.f, 45.f, 8.f, 67.f, + 96.f, 10.f, 65.f, 41.f, 33.f, 85.f, 92.f, 24.f, 25.f, 55.f, 49.f, 76.}); + auto exp = NDArrayFactory::create(28.08367f); - auto x = NDArrayFactory::create('c', {2,3,4}, {27.f,34.f,5.f,4.f,54.f,6.f,65.f,8.f,37.f,45.f,8.f,67.f,96.f,10.f,65.f,41.f,33.f,85.f,92.f,24.f,25.f,55.f,49.f,76.}); - auto exp = NDArrayFactory::create(28.08367f); - - sd::ops::reduce_stdev op; - auto result = op.evaluate({&x}, {}, {0,1,2}); - auto output = result.at(0); - - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::reduce_stdev op; + auto result = op.evaluate({&x}, {}, {0, 1, 2}); + auto output = result.at(0); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, reduceStDev_test7) { + auto x = NDArrayFactory::create( + 'c', {2, 3, 4}, + {27.f, 34.f, 5.f, 4.f, 54.f, 6.f, 65.f, 8.f, 37.f, 45.f, 8.f, 67.f, + 96.f, 10.f, 65.f, 41.f, 33.f, 85.f, 92.f, 24.f, 25.f, 55.f, 49.f, 76.}); + auto exp = NDArrayFactory::create('c', {1, 1, 1}, {28.08367f}); - auto x = NDArrayFactory::create('c', {2,3,4}, {27.f,34.f,5.f,4.f,54.f,6.f,65.f,8.f,37.f,45.f,8.f,67.f,96.f,10.f,65.f,41.f,33.f,85.f,92.f,24.f,25.f,55.f,49.f,76.}); - auto exp = NDArrayFactory::create('c', {1,1,1}, {28.08367f}); - - sd::ops::reduce_stdev op; - auto result = op.evaluate({&x}, {1.f}, {0,1,2}); - auto output = result.at(0); - - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::reduce_stdev op; + auto result = op.evaluate({&x}, {1.f}, {0, 1, 2}); + auto output = result.at(0); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, reduceStDev_test8) { + auto x = NDArrayFactory::create( + 'c', {2, 3, 4}, + {27.f, 34.f, 5.f, 4.f, 54.f, 6.f, 65.f, 8.f, 37.f, 45.f, 8.f, 67.f, + 96.f, 10.f, 65.f, 41.f, 33.f, 85.f, 92.f, 24.f, 25.f, 55.f, 49.f, 76.}); + auto exp = NDArrayFactory::create( + 'c', {4}, {26.88246f, 29.53924f, 34.52921f, 30.11755f}); - auto x = NDArrayFactory::create('c', {2,3,4}, {27.f,34.f,5.f,4.f,54.f,6.f,65.f,8.f,37.f,45.f,8.f,67.f,96.f,10.f,65.f,41.f,33.f,85.f,92.f,24.f,25.f,55.f,49.f,76.}); - auto exp = NDArrayFactory::create('c', {4}, {26.88246f, 29.53924f, 34.52921f, 30.11755f}); - - sd::ops::reduce_stdev op; - auto result = op.evaluate({&x}, {0.f,1.f}, {0,1}); - auto output = result.at(0); - - ASSERT_EQ(Status::OK(), result.status()); - // output->printBuffer("Reduced STDDEV"); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - + sd::ops::reduce_stdev op; + auto result = op.evaluate({&x}, {0.f, 1.f}, {0, 1}); + auto output = result.at(0); + ASSERT_EQ(Status::OK(), result.status()); + // output->printBuffer("Reduced STDDEV"); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, reduceStDev_test08) { + auto x = NDArrayFactory::create( + 'c', {2, 3, 4}, + {27.f, 34.f, 5.f, 4.f, 54.f, 6.f, 65.f, 8.f, 37.f, 45.f, 8.f, 67.f, + 96.f, 10.f, 65.f, 41.f, 33.f, 85.f, 92.f, 24.f, 25.f, 55.f, 49.f, 76.}); + auto exp = NDArrayFactory::create( + 'c', {4}, {26.88246f, 29.53924f, 34.52921f, 30.11755f}); + auto axes = NDArrayFactory::create({0, 1}); + sd::ops::reduce_stdev op; + auto result = op.evaluate({&x, &axes}, {}, {}, {false, true}); + auto output = result.at(0); - auto x = NDArrayFactory::create('c', {2,3,4}, {27.f,34.f,5.f,4.f,54.f,6.f,65.f,8.f,37.f,45.f,8.f,67.f,96.f,10.f,65.f,41.f,33.f,85.f,92.f,24.f,25.f,55.f,49.f,76.}); - auto exp = NDArrayFactory::create('c', {4}, {26.88246f, 29.53924f, 34.52921f, 30.11755f}); - auto axes = NDArrayFactory::create({0,1}); - sd::ops::reduce_stdev op; - auto result = op.evaluate({&x, &axes}, {}, {}, {false, true}); - auto output = result.at(0); - - ASSERT_EQ(Status::OK(), result.status()); - // output->printBuffer("Reduced STDDEV08"); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_EQ(Status::OK(), result.status()); + // output->printBuffer("Reduced STDDEV08"); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, reduceVarianceBP_test1) { - - auto x = NDArrayFactory::create('c', {3,4}); - auto gradO1 = NDArrayFactory::create('c', {1,1}, {0.5f}); - auto gradO2 = NDArrayFactory::create(0.5f); - auto exp12 = NDArrayFactory::create('c', {3,4}, {-0.5f, -0.4090909f, -0.3181818f, -0.22727273f, -0.13636364f, -0.045454547f, 0.045454547f, 0.13636364f, 0.22727273f, 0.3181818f, 0.4090909f, 0.5f}); - auto exp34 = NDArrayFactory::create('c', {3,4}, {-0.45833334f, -0.375f, -0.29166666f, -0.20833333f, -0.125f, -0.041666668f, 0.041666668f, 0.125f, 0.20833333f, 0.29166666f, 0.375f, 0.45833334f}); - - x.linspace(1); - - sd::ops::reduce_variance_bp op; - - auto result = op.evaluate({&x, &gradO2}, {0,1}, {}); - ASSERT_EQ(Status::OK(), result.status()); - auto output = result.at(0); - ASSERT_TRUE(exp12.isSameShape(output)); - ASSERT_TRUE(exp12.equalsTo(output)); - - - result = op.evaluate({&x, &gradO1}, {1,1}, {}); - ASSERT_EQ(Status::OK(), result.status()); - output = result.at(0); - ASSERT_TRUE(exp12.isSameShape(output)); - ASSERT_TRUE(exp12.equalsTo(output)); - - - result = op.evaluate({&x, &gradO2}, {0,0}, {}); - ASSERT_EQ(Status::OK(), result.status()); - output = result.at(0); - ASSERT_TRUE(exp34.isSameShape(output)); - ASSERT_TRUE(exp34.equalsTo(output)); - - - result = op.evaluate({&x, &gradO1}, {1,0}, {}); - ASSERT_EQ(Status::OK(), result.status()); - output = result.at(0); - ASSERT_TRUE(exp34.isSameShape(output)); - ASSERT_TRUE(exp34.equalsTo(output)); - - + auto x = NDArrayFactory::create('c', {3, 4}); + auto gradO1 = NDArrayFactory::create('c', {1, 1}, {0.5f}); + auto gradO2 = NDArrayFactory::create(0.5f); + auto exp12 = NDArrayFactory::create( + 'c', {3, 4}, + {-0.5f, -0.4090909f, -0.3181818f, -0.22727273f, -0.13636364f, + -0.045454547f, 0.045454547f, 0.13636364f, 0.22727273f, 0.3181818f, + 0.4090909f, 0.5f}); + auto exp34 = NDArrayFactory::create( + 'c', {3, 4}, + {-0.45833334f, -0.375f, -0.29166666f, -0.20833333f, -0.125f, + -0.041666668f, 0.041666668f, 0.125f, 0.20833333f, 0.29166666f, 0.375f, + 0.45833334f}); + + x.linspace(1); + + sd::ops::reduce_variance_bp op; + + auto result = op.evaluate({&x, &gradO2}, {0, 1}, {}); + ASSERT_EQ(Status::OK(), result.status()); + auto output = result.at(0); + ASSERT_TRUE(exp12.isSameShape(output)); + ASSERT_TRUE(exp12.equalsTo(output)); + + result = op.evaluate({&x, &gradO1}, {1, 1}, {}); + ASSERT_EQ(Status::OK(), result.status()); + output = result.at(0); + ASSERT_TRUE(exp12.isSameShape(output)); + ASSERT_TRUE(exp12.equalsTo(output)); + + result = op.evaluate({&x, &gradO2}, {0, 0}, {}); + ASSERT_EQ(Status::OK(), result.status()); + output = result.at(0); + ASSERT_TRUE(exp34.isSameShape(output)); + ASSERT_TRUE(exp34.equalsTo(output)); + + result = op.evaluate({&x, &gradO1}, {1, 0}, {}); + ASSERT_EQ(Status::OK(), result.status()); + output = result.at(0); + ASSERT_TRUE(exp34.isSameShape(output)); + ASSERT_TRUE(exp34.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, reduceVarianceBP_test2) { - - auto x = NDArrayFactory::create('c', {3,4}); - auto gradO1 = NDArrayFactory::create('c', {1,4}, {1.f,2.f,3.f,4.f}); - auto gradO2 = NDArrayFactory::create('c', {4}, {1.,2.,3.,4.}); - auto exp12 = NDArrayFactory::create('c', {3,4}, {-2.666667f, -5.333333f, -8.000000f, -10.666667f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 2.666667f, 5.333333f, 8.000000f, 10.666667f}); - auto exp34 = NDArrayFactory::create('c', {3,4}, {-4.000000f, -8.000000f, -12.000000f, -16.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 4.000000f, 8.000000f, 12.000000f, 16.000000f}); - - x.linspace(1); - - sd::ops::reduce_variance_bp op; - - auto result = op.evaluate({&x, &gradO2}, {0,0}, {0}); - ASSERT_EQ(Status::OK(), result.status()); - auto output = result.at(0); - ASSERT_TRUE(exp12.isSameShape(output)); - ASSERT_TRUE(exp12.equalsTo(output)); - - - result = op.evaluate({&x, &gradO1}, {1,0}, {0}); - ASSERT_EQ(Status::OK(), result.status()); - output = result.at(0); - ASSERT_TRUE(exp12.isSameShape(output)); - ASSERT_TRUE(exp12.equalsTo(output)); - - - result = op.evaluate({&x, &gradO2}, {0,1}, {0}); - ASSERT_EQ(Status::OK(), result.status()); - output = result.at(0); - ASSERT_TRUE(exp34.isSameShape(output)); - ASSERT_TRUE(exp34.equalsTo(output)); - - - result = op.evaluate({&x, &gradO1}, {1,1}, {0}); - ASSERT_EQ(Status::OK(), result.status()); - output = result.at(0); - ASSERT_TRUE(exp34.isSameShape(output)); - ASSERT_TRUE(exp34.equalsTo(output)); - - + auto x = NDArrayFactory::create('c', {3, 4}); + auto gradO1 = + NDArrayFactory::create('c', {1, 4}, {1.f, 2.f, 3.f, 4.f}); + auto gradO2 = NDArrayFactory::create('c', {4}, {1., 2., 3., 4.}); + auto exp12 = NDArrayFactory::create( + 'c', {3, 4}, + {-2.666667f, -5.333333f, -8.000000f, -10.666667f, 0.000000f, 0.000000f, + 0.000000f, 0.000000f, 2.666667f, 5.333333f, 8.000000f, 10.666667f}); + auto exp34 = NDArrayFactory::create( + 'c', {3, 4}, + {-4.000000f, -8.000000f, -12.000000f, -16.000000f, 0.000000f, 0.000000f, + 0.000000f, 0.000000f, 4.000000f, 8.000000f, 12.000000f, 16.000000f}); + + x.linspace(1); + + sd::ops::reduce_variance_bp op; + + auto result = op.evaluate({&x, &gradO2}, {0, 0}, {0}); + ASSERT_EQ(Status::OK(), result.status()); + auto output = result.at(0); + ASSERT_TRUE(exp12.isSameShape(output)); + ASSERT_TRUE(exp12.equalsTo(output)); + + result = op.evaluate({&x, &gradO1}, {1, 0}, {0}); + ASSERT_EQ(Status::OK(), result.status()); + output = result.at(0); + ASSERT_TRUE(exp12.isSameShape(output)); + ASSERT_TRUE(exp12.equalsTo(output)); + + result = op.evaluate({&x, &gradO2}, {0, 1}, {0}); + ASSERT_EQ(Status::OK(), result.status()); + output = result.at(0); + ASSERT_TRUE(exp34.isSameShape(output)); + ASSERT_TRUE(exp34.equalsTo(output)); + + result = op.evaluate({&x, &gradO1}, {1, 1}, {0}); + ASSERT_EQ(Status::OK(), result.status()); + output = result.at(0); + ASSERT_TRUE(exp34.isSameShape(output)); + ASSERT_TRUE(exp34.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, reduceVarianceBP_test02) { - - auto x = NDArrayFactory::create('c', {3,4}); - auto gradO1 = NDArrayFactory::create('c', {1,4}, {1.f,2.f,3.f,4.f}); - auto gradO2 = NDArrayFactory::create('c', {4}, {1.f,2.f,3.f,4.f}); - auto exp12 = NDArrayFactory::create('c', {3,4}, {-2.666667f, -5.333333f, -8.000000f, -10.666667f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 2.666667f, 5.333333f, 8.000000f, 10.666667f}); - auto exp34 = NDArrayFactory::create('c', {3,4}, {-4.000000f, -8.000000f, -12.000000f, -16.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 4.000000f, 8.000000f, 12.000000f, 16.000000f}); - auto axes = NDArrayFactory::create({(int)0,}); - x.linspace(1); - - sd::ops::reduce_variance_bp op; - - auto result = op.evaluate({&x, &gradO2, &axes}, {}, {}, {false, false}); - ASSERT_EQ(Status::OK(), result.status()); - auto output = result.at(0); - ASSERT_TRUE(exp12.isSameShape(output)); - ASSERT_TRUE(exp12.equalsTo(output)); - - - result = op.evaluate({&x, &gradO1, &axes}, {}, {}, {true, false}); - ASSERT_EQ(Status::OK(), result.status()); - output = result.at(0); - ASSERT_TRUE(exp12.isSameShape(output)); - ASSERT_TRUE(exp12.equalsTo(output)); - - - result = op.evaluate({&x, &gradO2, &axes}, {}, {}, {false, true}); - ASSERT_EQ(Status::OK(), result.status()); - output = result.at(0); - ASSERT_TRUE(exp34.isSameShape(output)); - ASSERT_TRUE(exp34.equalsTo(output)); - - - result = op.evaluate({&x, &gradO1, &axes}, {}, {}, {true, true}); - ASSERT_EQ(Status::OK(), result.status()); - output = result.at(0); - ASSERT_TRUE(exp34.isSameShape(output)); - ASSERT_TRUE(exp34.equalsTo(output)); - - + auto x = NDArrayFactory::create('c', {3, 4}); + auto gradO1 = + NDArrayFactory::create('c', {1, 4}, {1.f, 2.f, 3.f, 4.f}); + auto gradO2 = NDArrayFactory::create('c', {4}, {1.f, 2.f, 3.f, 4.f}); + auto exp12 = NDArrayFactory::create( + 'c', {3, 4}, + {-2.666667f, -5.333333f, -8.000000f, -10.666667f, 0.000000f, 0.000000f, + 0.000000f, 0.000000f, 2.666667f, 5.333333f, 8.000000f, 10.666667f}); + auto exp34 = NDArrayFactory::create( + 'c', {3, 4}, + {-4.000000f, -8.000000f, -12.000000f, -16.000000f, 0.000000f, 0.000000f, + 0.000000f, 0.000000f, 4.000000f, 8.000000f, 12.000000f, 16.000000f}); + auto axes = NDArrayFactory::create(0); + x.linspace(1); + + sd::ops::reduce_variance_bp op; + + auto result = op.evaluate({&x, &gradO2, &axes}, {}, {}, {false, false}); + ASSERT_EQ(Status::OK(), result.status()); + auto output = result.at(0); + ASSERT_TRUE(exp12.isSameShape(output)); + ASSERT_TRUE(exp12.equalsTo(output)); + + result = op.evaluate({&x, &gradO1, &axes}, {}, {}, {true, false}); + ASSERT_EQ(Status::OK(), result.status()); + output = result.at(0); + ASSERT_TRUE(exp12.isSameShape(output)); + ASSERT_TRUE(exp12.equalsTo(output)); + + result = op.evaluate({&x, &gradO2, &axes}, {}, {}, {false, true}); + ASSERT_EQ(Status::OK(), result.status()); + output = result.at(0); + ASSERT_TRUE(exp34.isSameShape(output)); + ASSERT_TRUE(exp34.equalsTo(output)); + + result = op.evaluate({&x, &gradO1, &axes}, {}, {}, {true, true}); + ASSERT_EQ(Status::OK(), result.status()); + output = result.at(0); + ASSERT_TRUE(exp34.isSameShape(output)); + ASSERT_TRUE(exp34.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, reduceVarianceBP_test3) { - - auto x = NDArrayFactory::create('c', {3, 4}); - auto gradO1 = NDArrayFactory::create('c', {3, 1}, {1.f, 2.f, 3.f}); - auto gradO2 = NDArrayFactory::create('c', {3}, {1.f, 2.f, 3.f}); - auto exp12 = NDArrayFactory::create('c', {3, 4}, - {-0.750000f, -0.250000f, 0.250000f, 0.750000f, -1.500000f, -0.500000f, - 0.500000f, 1.500000f, -2.250000f, -0.750000f, 0.750000f, 2.250000f}); - auto exp34 = NDArrayFactory::create('c', {3, 4}, - {-1.000000f, -0.333333f, 0.333333f, 1.000000f, -2.000000f, -0.666667f, - 0.666667f, 2.000000f, -3.000000f, -1.000000f, 1.000000f, 3.000000f}); - - x.linspace(1); - - sd::ops::reduce_variance_bp op; - - auto result = op.evaluate({&x, &gradO2}, {0, 0}, {1}); - ASSERT_EQ(Status::OK(), result.status()); - auto output = result.at(0); - ASSERT_TRUE(exp12.isSameShape(output)); - ASSERT_TRUE(exp12.equalsTo(output)); - - - result = op.evaluate({&x, &gradO1}, {1, 0}, {1}); - ASSERT_EQ(Status::OK(), result.status()); - output = result.at(0); - ASSERT_TRUE(exp12.isSameShape(output)); - ASSERT_TRUE(exp12.equalsTo(output)); - - - result = op.evaluate({&x, &gradO2}, {0, 1}, {1}); - ASSERT_EQ(Status::OK(), result.status()); - output = result.at(0); - ASSERT_TRUE(exp34.isSameShape(output)); - ASSERT_TRUE(exp34.equalsTo(output)); - - - result = op.evaluate({&x, &gradO1}, {1, 1}, {1}); - ASSERT_EQ(Status::OK(), result.status()); - output = result.at(0); - ASSERT_TRUE(exp34.isSameShape(output)); - ASSERT_TRUE(exp34.equalsTo(output)); - + auto x = NDArrayFactory::create('c', {3, 4}); + auto gradO1 = NDArrayFactory::create('c', {3, 1}, {1.f, 2.f, 3.f}); + auto gradO2 = NDArrayFactory::create('c', {3}, {1.f, 2.f, 3.f}); + auto exp12 = NDArrayFactory::create( + 'c', {3, 4}, + {-0.750000f, -0.250000f, 0.250000f, 0.750000f, -1.500000f, -0.500000f, + 0.500000f, 1.500000f, -2.250000f, -0.750000f, 0.750000f, 2.250000f}); + auto exp34 = NDArrayFactory::create( + 'c', {3, 4}, + {-1.000000f, -0.333333f, 0.333333f, 1.000000f, -2.000000f, -0.666667f, + 0.666667f, 2.000000f, -3.000000f, -1.000000f, 1.000000f, 3.000000f}); + + x.linspace(1); + + sd::ops::reduce_variance_bp op; + + auto result = op.evaluate({&x, &gradO2}, {0, 0}, {1}); + ASSERT_EQ(Status::OK(), result.status()); + auto output = result.at(0); + ASSERT_TRUE(exp12.isSameShape(output)); + ASSERT_TRUE(exp12.equalsTo(output)); + + result = op.evaluate({&x, &gradO1}, {1, 0}, {1}); + ASSERT_EQ(Status::OK(), result.status()); + output = result.at(0); + ASSERT_TRUE(exp12.isSameShape(output)); + ASSERT_TRUE(exp12.equalsTo(output)); + + result = op.evaluate({&x, &gradO2}, {0, 1}, {1}); + ASSERT_EQ(Status::OK(), result.status()); + output = result.at(0); + ASSERT_TRUE(exp34.isSameShape(output)); + ASSERT_TRUE(exp34.equalsTo(output)); + + result = op.evaluate({&x, &gradO1}, {1, 1}, {1}); + ASSERT_EQ(Status::OK(), result.status()); + output = result.at(0); + ASSERT_TRUE(exp34.isSameShape(output)); + ASSERT_TRUE(exp34.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, reduceStDevBP_test1) { - - auto x = NDArrayFactory::create('c', {3,4}); - auto gradO1 = NDArrayFactory::create('c', {1,1}, {0.5f}); - auto gradO2 = NDArrayFactory::create(0.5f); - auto exp12 = NDArrayFactory::create('c', {3,4}, {-0.069337524f, -0.056730703f, -0.04412388f, -0.031517055f, -0.018910235f, -0.0063034114f, 0.0063034114f, 0.018910235f, 0.031517055f, 0.04412388f, 0.056730703f, 0.069337524f}); - auto exp34 = NDArrayFactory::create('c', {3,4}, {-0.06638563f, -0.05431551f, -0.0422454f, -0.030175284f, -0.01810517f, -0.006035057f, 0.006035057f, 0.01810517f, 0.030175284f, 0.0422454f, 0.05431551f, 0.06638563f}); - - x.linspace(1); - - sd::ops::reduce_stdev_bp op; - - auto result = op.evaluate({&x, &gradO2}, {0,1}, {}); - ASSERT_EQ(Status::OK(), result.status()); - auto output = result.at(0); - // output->printIndexedBuffer(); - ASSERT_TRUE(exp12.isSameShape(output)); - ASSERT_TRUE(exp12.equalsTo(output)); - - - result = op.evaluate({&x, &gradO1}, {1,1}, {}); - ASSERT_EQ(Status::OK(), result.status()); - output = result.at(0); - ASSERT_TRUE(exp12.isSameShape(output)); - ASSERT_TRUE(exp12.equalsTo(output)); - - - result = op.evaluate({&x, &gradO2}, {0,0}, {}); - ASSERT_EQ(Status::OK(), result.status()); - output = result.at(0); - ASSERT_TRUE(exp34.isSameShape(output)); - ASSERT_TRUE(exp34.equalsTo(output)); - - - result = op.evaluate({&x, &gradO1}, {1,0}, {}); - ASSERT_EQ(Status::OK(), result.status()); - output = result.at(0); - ASSERT_TRUE(exp34.isSameShape(output)); - ASSERT_TRUE(exp34.equalsTo(output)); - + auto x = NDArrayFactory::create('c', {3, 4}); + auto gradO1 = NDArrayFactory::create('c', {1, 1}, {0.5f}); + auto gradO2 = NDArrayFactory::create(0.5f); + auto exp12 = NDArrayFactory::create( + 'c', {3, 4}, + {-0.069337524f, -0.056730703f, -0.04412388f, -0.031517055f, -0.018910235f, + -0.0063034114f, 0.0063034114f, 0.018910235f, 0.031517055f, 0.04412388f, + 0.056730703f, 0.069337524f}); + auto exp34 = NDArrayFactory::create( + 'c', {3, 4}, + {-0.06638563f, -0.05431551f, -0.0422454f, -0.030175284f, -0.01810517f, + -0.006035057f, 0.006035057f, 0.01810517f, 0.030175284f, 0.0422454f, + 0.05431551f, 0.06638563f}); + + x.linspace(1); + + sd::ops::reduce_stdev_bp op; + + auto result = op.evaluate({&x, &gradO2}, {0, 1}, {}); + ASSERT_EQ(Status::OK(), result.status()); + auto output = result.at(0); + // output->printIndexedBuffer(); + ASSERT_TRUE(exp12.isSameShape(output)); + ASSERT_TRUE(exp12.equalsTo(output)); + + result = op.evaluate({&x, &gradO1}, {1, 1}, {}); + ASSERT_EQ(Status::OK(), result.status()); + output = result.at(0); + ASSERT_TRUE(exp12.isSameShape(output)); + ASSERT_TRUE(exp12.equalsTo(output)); + + result = op.evaluate({&x, &gradO2}, {0, 0}, {}); + ASSERT_EQ(Status::OK(), result.status()); + output = result.at(0); + ASSERT_TRUE(exp34.isSameShape(output)); + ASSERT_TRUE(exp34.equalsTo(output)); + + result = op.evaluate({&x, &gradO1}, {1, 0}, {}); + ASSERT_EQ(Status::OK(), result.status()); + output = result.at(0); + ASSERT_TRUE(exp34.isSameShape(output)); + ASSERT_TRUE(exp34.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, reduceStDevBP_test2) { - - auto x = NDArrayFactory::create('c', {3,4}); - auto gradO1 = NDArrayFactory::create('c', {1,4}, {1.f,2.f,3.f,4.f}); - auto gradO2 = NDArrayFactory::create('c', {4}, {1.f,2.f,3.f,4.f}); - auto exp12 = NDArrayFactory::create('c', {3,4}, {-0.4082483f, -0.8164966f, -1.2247449f, -1.6329932f, 0.0, 0.0, 0.0, 0.0, 0.4082483f, 0.8164966f, 1.2247449f, 1.6329932f}); - auto exp34 = NDArrayFactory::create('c', {3,4}, {-0.5f, -1.0f, -1.5f, -2.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.5f, 1.0f, 1.5f, 2.0f}); - - x.linspace(1); - - sd::ops::reduce_stdev_bp op; - - auto result = op.evaluate({&x, &gradO2}, {0,0}, {0}); - ASSERT_EQ(Status::OK(), result.status()); - auto output = result.at(0); - ASSERT_TRUE(exp12.isSameShape(output)); - ASSERT_TRUE(exp12.equalsTo(output)); - - - result = op.evaluate({&x, &gradO1}, {1,0}, {0}); - ASSERT_EQ(Status::OK(), result.status()); - output = result.at(0); - ASSERT_TRUE(exp12.isSameShape(output)); - ASSERT_TRUE(exp12.equalsTo(output)); - - - result = op.evaluate({&x, &gradO2}, {0,1}, {0}); - ASSERT_EQ(Status::OK(), result.status()); - output = result.at(0); - ASSERT_TRUE(exp34.isSameShape(output)); - ASSERT_TRUE(exp34.equalsTo(output)); - - - result = op.evaluate({&x, &gradO1}, {1,1}, {0}); - ASSERT_EQ(Status::OK(), result.status()); - output = result.at(0); - ASSERT_TRUE(exp34.isSameShape(output)); - ASSERT_TRUE(exp34.equalsTo(output)); - + auto x = NDArrayFactory::create('c', {3, 4}); + auto gradO1 = + NDArrayFactory::create('c', {1, 4}, {1.f, 2.f, 3.f, 4.f}); + auto gradO2 = NDArrayFactory::create('c', {4}, {1.f, 2.f, 3.f, 4.f}); + auto exp12 = NDArrayFactory::create( + 'c', {3, 4}, + {-0.4082483f, -0.8164966f, -1.2247449f, -1.6329932f, 0.0, 0.0, 0.0, 0.0, + 0.4082483f, 0.8164966f, 1.2247449f, 1.6329932f}); + auto exp34 = + NDArrayFactory::create('c', {3, 4}, + {-0.5f, -1.0f, -1.5f, -2.0f, 0.0f, 0.0f, + 0.0f, 0.0f, 0.5f, 1.0f, 1.5f, 2.0f}); + + x.linspace(1); + + sd::ops::reduce_stdev_bp op; + + auto result = op.evaluate({&x, &gradO2}, {0, 0}, {0}); + ASSERT_EQ(Status::OK(), result.status()); + auto output = result.at(0); + ASSERT_TRUE(exp12.isSameShape(output)); + ASSERT_TRUE(exp12.equalsTo(output)); + + result = op.evaluate({&x, &gradO1}, {1, 0}, {0}); + ASSERT_EQ(Status::OK(), result.status()); + output = result.at(0); + ASSERT_TRUE(exp12.isSameShape(output)); + ASSERT_TRUE(exp12.equalsTo(output)); + + result = op.evaluate({&x, &gradO2}, {0, 1}, {0}); + ASSERT_EQ(Status::OK(), result.status()); + output = result.at(0); + ASSERT_TRUE(exp34.isSameShape(output)); + ASSERT_TRUE(exp34.equalsTo(output)); + + result = op.evaluate({&x, &gradO1}, {1, 1}, {0}); + ASSERT_EQ(Status::OK(), result.status()); + output = result.at(0); + ASSERT_TRUE(exp34.isSameShape(output)); + ASSERT_TRUE(exp34.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, reduceStDevBP_test02) { - - int ax = 0; - auto x = NDArrayFactory::create('c', {3,4}); - auto gradO1 = NDArrayFactory::create('c', {1,4}, {1.f,2.f,3.f,4.f}); - auto gradO2 = NDArrayFactory::create('c', {4}, {1.f,2.f,3.f,4.f}); - auto exp12 = NDArrayFactory::create('c', {3,4}, {-0.4082483f, -0.8164966f, -1.2247449f, -1.6329932f, 0.0, 0.0, 0.0, 0.0, 0.4082483f, 0.8164966f, 1.2247449f, 1.6329932f}); - auto exp34 = NDArrayFactory::create('c', {3,4}, {-0.5f, -1.0f, -1.5f, -2.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.5f, 1.0f, 1.5f, 2.0f}); - auto axis = NDArrayFactory::create('c', {1}, {ax}); - x.linspace(1); - - sd::ops::reduce_stdev_bp op; - - auto result = op.evaluate({&x, &gradO2, &axis}, {}, {}, {false, false}); - ASSERT_EQ(Status::OK(), result.status()); - auto output = result.at(0); - ASSERT_TRUE(exp12.isSameShape(output)); - ASSERT_TRUE(exp12.equalsTo(output)); - - - result = op.evaluate({&x, &gradO1, &axis}, {}, {}, {true, false}); - ASSERT_EQ(Status::OK(), result.status()); - output = result.at(0); - ASSERT_TRUE(exp12.isSameShape(output)); - ASSERT_TRUE(exp12.equalsTo(output)); - - - result = op.evaluate({&x, &gradO2, &axis}, {}, {}, {false, true}); - ASSERT_EQ(Status::OK(), result.status()); - output = result.at(0); - ASSERT_TRUE(exp34.isSameShape(output)); - ASSERT_TRUE(exp34.equalsTo(output)); - - - result = op.evaluate({&x, &gradO1, &axis}, {}, {}, {true, true}); - ASSERT_EQ(Status::OK(), result.status()); - output = result.at(0); - ASSERT_TRUE(exp34.isSameShape(output)); - ASSERT_TRUE(exp34.equalsTo(output)); - + int ax = 0; + auto x = NDArrayFactory::create('c', {3, 4}); + auto gradO1 = + NDArrayFactory::create('c', {1, 4}, {1.f, 2.f, 3.f, 4.f}); + auto gradO2 = NDArrayFactory::create('c', {4}, {1.f, 2.f, 3.f, 4.f}); + auto exp12 = NDArrayFactory::create( + 'c', {3, 4}, + {-0.4082483f, -0.8164966f, -1.2247449f, -1.6329932f, 0.0, 0.0, 0.0, 0.0, + 0.4082483f, 0.8164966f, 1.2247449f, 1.6329932f}); + auto exp34 = + NDArrayFactory::create('c', {3, 4}, + {-0.5f, -1.0f, -1.5f, -2.0f, 0.0f, 0.0f, + 0.0f, 0.0f, 0.5f, 1.0f, 1.5f, 2.0f}); + auto axis = NDArrayFactory::create('c', {1}, {ax}); + x.linspace(1); + + sd::ops::reduce_stdev_bp op; + + auto result = op.evaluate({&x, &gradO2, &axis}, {}, {}, {false, false}); + ASSERT_EQ(Status::OK(), result.status()); + auto output = result.at(0); + ASSERT_TRUE(exp12.isSameShape(output)); + ASSERT_TRUE(exp12.equalsTo(output)); + + result = op.evaluate({&x, &gradO1, &axis}, {}, {}, {true, false}); + ASSERT_EQ(Status::OK(), result.status()); + output = result.at(0); + ASSERT_TRUE(exp12.isSameShape(output)); + ASSERT_TRUE(exp12.equalsTo(output)); + + result = op.evaluate({&x, &gradO2, &axis}, {}, {}, {false, true}); + ASSERT_EQ(Status::OK(), result.status()); + output = result.at(0); + ASSERT_TRUE(exp34.isSameShape(output)); + ASSERT_TRUE(exp34.equalsTo(output)); + + result = op.evaluate({&x, &gradO1, &axis}, {}, {}, {true, true}); + ASSERT_EQ(Status::OK(), result.status()); + output = result.at(0); + ASSERT_TRUE(exp34.isSameShape(output)); + ASSERT_TRUE(exp34.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, reduceStDevBP_test3) { - - auto x = NDArrayFactory::create('c', {3,4}); - auto gradO1 = NDArrayFactory::create('c', {3,1}, {1.f,2.f,3.f}); - auto gradO2 = NDArrayFactory::create('c', {3}, {1.f,2.f,3.f}); - auto exp12 = NDArrayFactory::create('c', {3,4}, {-0.3354102f, -0.1118034f, 0.1118034f, 0.3354102f, -0.6708204f, -0.2236068f, 0.2236068f, 0.6708204f, -1.0062306f, -0.3354102f, 0.3354102f, 1.0062306f}); - auto exp34 = NDArrayFactory::create('c', {3,4}, {-0.38729835f, -0.12909944f, 0.12909944f, 0.38729835f, -0.7745967f, -0.2581989f, 0.2581989f, 0.7745967f, -1.161895f, -0.38729835f, 0.38729835f, 1.161895f}); - - x.linspace(1); - - sd::ops::reduce_stdev_bp op; - - auto result = op.evaluate({&x, &gradO2}, {0,0}, {1}); - ASSERT_EQ(Status::OK(), result.status()); - auto output = result.at(0); - ASSERT_TRUE(exp12.isSameShape(output)); - ASSERT_TRUE(exp12.equalsTo(output)); - - - result = op.evaluate({&x, &gradO1}, {1,0}, {1}); - ASSERT_EQ(Status::OK(), result.status()); - output = result.at(0); - ASSERT_TRUE(exp12.isSameShape(output)); - ASSERT_TRUE(exp12.equalsTo(output)); - - - result = op.evaluate({&x, &gradO2}, {0,1}, {1}); - ASSERT_EQ(Status::OK(), result.status()); - output = result.at(0); - ASSERT_TRUE(exp34.isSameShape(output)); - ASSERT_TRUE(exp34.equalsTo(output)); - - - result = op.evaluate({&x, &gradO1}, {1,1}, {1}); - ASSERT_EQ(Status::OK(), result.status()); - output = result.at(0); - ASSERT_TRUE(exp34.isSameShape(output)); - ASSERT_TRUE(exp34.equalsTo(output)); - - + auto x = NDArrayFactory::create('c', {3, 4}); + auto gradO1 = NDArrayFactory::create('c', {3, 1}, {1.f, 2.f, 3.f}); + auto gradO2 = NDArrayFactory::create('c', {3}, {1.f, 2.f, 3.f}); + auto exp12 = NDArrayFactory::create( + 'c', {3, 4}, + {-0.3354102f, -0.1118034f, 0.1118034f, 0.3354102f, -0.6708204f, + -0.2236068f, 0.2236068f, 0.6708204f, -1.0062306f, -0.3354102f, + 0.3354102f, 1.0062306f}); + auto exp34 = NDArrayFactory::create( + 'c', {3, 4}, + {-0.38729835f, -0.12909944f, 0.12909944f, 0.38729835f, -0.7745967f, + -0.2581989f, 0.2581989f, 0.7745967f, -1.161895f, -0.38729835f, + 0.38729835f, 1.161895f}); + + x.linspace(1); + + sd::ops::reduce_stdev_bp op; + + auto result = op.evaluate({&x, &gradO2}, {0, 0}, {1}); + ASSERT_EQ(Status::OK(), result.status()); + auto output = result.at(0); + ASSERT_TRUE(exp12.isSameShape(output)); + ASSERT_TRUE(exp12.equalsTo(output)); + + result = op.evaluate({&x, &gradO1}, {1, 0}, {1}); + ASSERT_EQ(Status::OK(), result.status()); + output = result.at(0); + ASSERT_TRUE(exp12.isSameShape(output)); + ASSERT_TRUE(exp12.equalsTo(output)); + + result = op.evaluate({&x, &gradO2}, {0, 1}, {1}); + ASSERT_EQ(Status::OK(), result.status()); + output = result.at(0); + ASSERT_TRUE(exp34.isSameShape(output)); + ASSERT_TRUE(exp34.equalsTo(output)); + + result = op.evaluate({&x, &gradO1}, {1, 1}, {1}); + ASSERT_EQ(Status::OK(), result.status()); + output = result.at(0); + ASSERT_TRUE(exp34.isSameShape(output)); + ASSERT_TRUE(exp34.equalsTo(output)); } - //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Sum_1) { + auto input = NDArrayFactory::create( + 'c', {3, 5}, + {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.}); + auto exp = NDArrayFactory::create(120.f); + //************************************// - auto input = NDArrayFactory::create('c', {3, 5}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.}); - auto exp = NDArrayFactory::create(120.f); - //************************************// - - sd::ops::reduce_sum op; - auto result = op.evaluate({&input}, {}, {}); - - ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); - //z->printIndexedBuffer("Result is "); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::reduce_sum op; + auto result = op.evaluate({&input}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + // z->printIndexedBuffer("Result is "); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Sum_2) { + auto input = NDArrayFactory::create( + 'c', {3, 5}, + {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.}); + auto exp = NDArrayFactory::create({15.f, 40.f, 65.f}); + //************************************// - auto input = NDArrayFactory::create('c', {3, 5}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.}); - auto exp = NDArrayFactory::create({15.f, 40.f, 65.f}); - //************************************// - - sd::ops::reduce_sum op; - auto result = op.evaluate({&input}, {}, {1}); - - ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); -// z->printIndexedBuffer("Result is "); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::reduce_sum op; + auto result = op.evaluate({&input}, {}, {1}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + // z->printIndexedBuffer("Result is "); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Sum_03) { + auto input = NDArrayFactory::create( + 'c', {3, 5}, + {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.}); + auto exp = NDArrayFactory::create({15.f, 40.f, 65.f}); + auto axis = NDArrayFactory::create('c', {1}, {1}); + //************************************// - auto input = NDArrayFactory::create('c', {3, 5}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.}); - auto exp = NDArrayFactory::create({15.f, 40.f, 65.f}); - auto axis = NDArrayFactory::create('c', {1}, {1}); - //************************************// - - sd::ops::reduce_sum op; - auto result = op.evaluate({&input, &axis}, {}, {}, {false}); - - ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); - // z->printIndexedBuffer("Result is "); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::reduce_sum op; + auto result = op.evaluate({&input, &axis}, {}, {}, {false}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + // z->printIndexedBuffer("Result is "); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Prod_1) { + auto input = NDArrayFactory::create( + 'c', {3, 5}, + {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.}); + auto exp = NDArrayFactory::create(1307674368000.f); + //************************************// - auto input = NDArrayFactory::create('c', {3, 5}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.}); - auto exp = NDArrayFactory::create(1307674368000.f); - //************************************// - - sd::ops::reduce_prod op; - auto result = op.evaluate({&input}, {}, {}); - - ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); - //z->printIndexedBuffer("Result is "); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::reduce_prod op; + auto result = op.evaluate({&input}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + // z->printIndexedBuffer("Result is "); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Prod_2) { + auto input = NDArrayFactory::create( + 'c', {3, 5}, + {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.}); + auto exp = NDArrayFactory::create({120.f, 30240.f, 360360.f}); + //************************************// - auto input = NDArrayFactory::create('c', {3, 5}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.}); - auto exp = NDArrayFactory::create({120.f, 30240.f, 360360.f}); - //************************************// - - sd::ops::reduce_prod op; - auto result = op.evaluate({&input}, {}, {1}); - - ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); -// z->printIndexedBuffer("Result is "); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::reduce_prod op; + auto result = op.evaluate({&input}, {}, {1}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + // z->printIndexedBuffer("Result is "); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Sum_01) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {4}, {66.f, 72.f, 78.f, 84.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2,3,4}); - auto exp = NDArrayFactory::create('c', {4}, {66.f, 72.f, 78.f, 84.f}); - x.linspace(1); - - sd::ops::reduce_sum op; - auto result = op.evaluate({&x}, {}, {0,1}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - + sd::ops::reduce_sum op; + auto result = op.evaluate({&x}, {}, {0, 1}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Sum_02) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = + NDArrayFactory::create('c', {1, 1, 4}, {66.f, 72.f, 78.f, 84.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2,3,4}); - auto exp = NDArrayFactory::create('c', {1,1,4}, {66.f, 72.f, 78.f, 84.f}); - x.linspace(1); - - sd::ops::reduce_sum op; - auto result = op.evaluate({&x}, {1.}, {0, 1}); - auto output = result.at(0); - // output->printIndexedBuffer("Result is"); - - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::reduce_sum op; + auto result = op.evaluate({&x}, {1.}, {0, 1}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Sum_3) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {3}, {68.f, 100.f, 132.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2,3,4}); - auto exp = NDArrayFactory::create('c', {3}, {68.f, 100.f, 132.f}); - x.linspace(1); - - sd::ops::reduce_sum op; - auto result = op.evaluate({&x}, {}, {0, 2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::reduce_sum op; + auto result = op.evaluate({&x}, {}, {0, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Sum_4) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = + NDArrayFactory::create('c', {1, 3, 1}, {68.f, 100.f, 132.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2,3,4}); - auto exp = NDArrayFactory::create('c', {1,3,1}, {68.f, 100.f, 132.f}); - x.linspace(1); - - sd::ops::reduce_sum op; - auto result = op.evaluate({&x}, {1.}, {0, 2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::reduce_sum op; + auto result = op.evaluate({&x}, {1.}, {0, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Sum_5) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create(300.f); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2,3,4}); - auto exp = NDArrayFactory::create(300.f); - x.linspace(1); - - sd::ops::reduce_sum op; - auto result = op.evaluate({&x}, {}, {}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::reduce_sum op; + auto result = op.evaluate({&x}, {}, {}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Sum_6) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create(300.f); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2,3,4}); - auto exp = NDArrayFactory::create(300.f); - x.linspace(1); - - sd::ops::reduce_sum op; - auto result = op.evaluate({&x}, {}, {0,1,2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::reduce_sum op; + auto result = op.evaluate({&x}, {}, {0, 1, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Sum_7) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {1, 1, 1}, {300.f}); + x.linspace(1); + // x.printIndexedBuffer("Input with shape (2, 3, 4) is"); + sd::ops::reduce_sum op; + auto result = op.evaluate({&x}, {1.}, {0, 1, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - auto x = NDArrayFactory::create('c', {2,3,4}); - auto exp = NDArrayFactory::create('c', {1,1,1}, {300.f}); - x.linspace(1); -// x.printIndexedBuffer("Input with shape (2, 3, 4) is"); - sd::ops::reduce_sum op; - auto result = op.evaluate({&x}, {1.}, {0,1,2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Prod_01) { + auto x = NDArrayFactory::create('c', {2, 3, 2}); + auto exp = NDArrayFactory::create('c', {2}, {10395.f, 46080.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2,3,2}); - auto exp = NDArrayFactory::create('c', {2}, {10395.f, 46080.f}); - x.linspace(1); - - sd::ops::reduce_prod op; - auto result = op.evaluate({&x}, {}, {0,1}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - + sd::ops::reduce_prod op; + auto result = op.evaluate({&x}, {}, {0, 1}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Prod_02) { + auto x = NDArrayFactory::create('c', {2, 3, 2}); + auto exp = NDArrayFactory::create('c', {1, 1, 2}, {10395.f, 46080.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2,3,2}); - auto exp = NDArrayFactory::create('c', {1,1,2}, {10395.f, 46080.f}); - x.linspace(1); - - sd::ops::reduce_prod op; - auto result = op.evaluate({&x}, {1.}, {0, 1}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::reduce_prod op; + auto result = op.evaluate({&x}, {1.}, {0, 1}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Prod_3) { + auto x = NDArrayFactory::create('c', {2, 3, 2}); + auto exp = NDArrayFactory::create('c', {3}, {112.f, 1080.f, 3960.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2,3,2}); - auto exp = NDArrayFactory::create('c', {3}, {112.f, 1080.f, 3960.f}); - x.linspace(1); - - sd::ops::reduce_prod op; - auto result = op.evaluate({&x}, {}, {0, 2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::reduce_prod op; + auto result = op.evaluate({&x}, {}, {0, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Prod_4) { + auto x = NDArrayFactory::create('c', {2, 3, 2}); + auto exp = + NDArrayFactory::create('c', {1, 3, 1}, {112.f, 1080.f, 3960.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2,3,2}); - auto exp = NDArrayFactory::create('c', {1,3,1}, {112.f, 1080.f, 3960.f}); - x.linspace(1); - - sd::ops::reduce_prod op; - auto result = op.evaluate({&x}, {1.}, {0, 2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::reduce_prod op; + auto result = op.evaluate({&x}, {1.}, {0, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Prod_04) { + auto x = NDArrayFactory::create('c', {2, 3, 2}); + auto exp = + NDArrayFactory::create('c', {1, 3, 1}, {112.f, 1080.f, 3960.f}); + auto axes = NDArrayFactory::create({0, 2}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2,3,2}); - auto exp = NDArrayFactory::create('c', {1,3,1}, {112.f, 1080.f, 3960.f}); - auto axes = NDArrayFactory::create({0, 2}); - x.linspace(1); - - sd::ops::reduce_prod op; - auto result = op.evaluate({&x, &axes}, {}, {}, {true}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::reduce_prod op; + auto result = op.evaluate({&x, &axes}, {}, {}, {true}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Prod_5) { + auto x = NDArrayFactory::create('c', {2, 3, 2}); + auto exp = NDArrayFactory::create(479001600.f); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2,3,2}); - auto exp = NDArrayFactory::create(479001600.f); - x.linspace(1); - - sd::ops::reduce_prod op; - auto result = op.evaluate({&x}, {}, {}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::reduce_prod op; + auto result = op.evaluate({&x}, {}, {}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Prod_6) { + auto x = NDArrayFactory::create('c', {2, 3, 2}); + auto exp = NDArrayFactory::create(479001600.f); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2,3,2}); - auto exp = NDArrayFactory::create(479001600.f); - x.linspace(1); - - sd::ops::reduce_prod op; - auto result = op.evaluate({&x}, {}, {0,1,2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::reduce_prod op; + auto result = op.evaluate({&x}, {}, {0, 1, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Prod_7) { + auto x = NDArrayFactory::create('c', {2, 3, 2}); + auto exp = NDArrayFactory::create('c', {1, 1, 1}, {479001600.f}); + x.linspace(1); + // x.printIndexedBuffer("Input with shape (2, 3, 4) is"); + sd::ops::reduce_prod op; + auto result = op.evaluate({&x}, {1.}, {0, 1, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - auto x = NDArrayFactory::create('c', {2,3,2}); - auto exp = NDArrayFactory::create('c', {1, 1, 1}, {479001600.f}); - x.linspace(1); -// x.printIndexedBuffer("Input with shape (2, 3, 4) is"); - sd::ops::reduce_prod op; - auto result = op.evaluate({&x}, {1.}, {0,1,2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Min_1) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {4}, {1.f, 2.f, 3.f, 4.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {4}, {1.f, 2.f, 3.f, 4.f}); - x.linspace(1); - - sd::ops::reduce_min op; - auto result = op.evaluate({&x}, {}, {0, 1}); - auto output = result.at(0); - // output->printIndexedBuffer("Result is"); - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - + sd::ops::reduce_min op; + auto result = op.evaluate({&x}, {}, {0, 1}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Min_2) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = + NDArrayFactory::create('c', {1, 1, 4}, {1.f, 2.f, 3.f, 4.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {1,1,4}, {1.f, 2.f, 3.f, 4.f}); - x.linspace(1); - - sd::ops::reduce_min op; - auto result = op.evaluate({&x}, {1.}, {0, 1}); - auto output = result.at(0); - // output->printIndexedBuffer("Result is"); - - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::reduce_min op; + auto result = op.evaluate({&x}, {1.}, {0, 1}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Min_3) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {3}, {1.f, 5.f, 9.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {3}, {1.f, 5.f, 9.f}); - x.linspace(1); - - sd::ops::reduce_min op; - auto result = op.evaluate({&x}, {}, {0, 2}); - auto output = result.at(0); - // output->printIndexedBuffer("Result is"); - - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::reduce_min op; + auto result = op.evaluate({&x}, {}, {0, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Min_4) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {1, 3, 1}, {1.f, 5.f, 9.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {1,3,1}, {1.f, 5.f, 9.f}); - x.linspace(1); - - sd::ops::reduce_min op; - auto result = op.evaluate({&x}, {1.}, {0, 2}); - auto output = result.at(0); - // output->printIndexedBuffer("Result is"); - - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::reduce_min op; + auto result = op.evaluate({&x}, {1.}, {0, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Min_04) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {1, 3, 1}, {1.f, 5.f, 9.f}); + auto axes = NDArrayFactory::create({0, 2}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {1,3,1}, {1.f, 5.f, 9.f}); - auto axes = NDArrayFactory::create({0, 2}); - x.linspace(1); - - sd::ops::reduce_min op; - auto result = op.evaluate({&x, &axes}, {}, {}, {true}); - auto output = result.at(0); - // output->printIndexedBuffer("Result is"); - - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::reduce_min op; + auto result = op.evaluate({&x, &axes}, {}, {}, {true}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Min_5) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create(1.f); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create(1.f); - x.linspace(1); - - sd::ops::reduce_min op; - auto result = op.evaluate({&x}, {}, {}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::reduce_min op; + auto result = op.evaluate({&x}, {}, {}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Min_6) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create(1.f); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create(1.f); - x.linspace(1); - - sd::ops::reduce_min op; - auto result = op.evaluate({&x}, {}, {0,1,2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::reduce_min op; + auto result = op.evaluate({&x}, {}, {0, 1, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Min_7) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {1, 1, 1}, {1.f}); + x.linspace(1); + // x.printIndexedBuffer("Input with shape (2, 3, 4) is"); + sd::ops::reduce_min op; + auto result = op.evaluate({&x}, {1.}, {0, 1, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {1, 1, 1}, {1.f}); - x.linspace(1); - // x.printIndexedBuffer("Input with shape (2, 3, 4) is"); - sd::ops::reduce_min op; - auto result = op.evaluate({&x}, {1.}, {0,1,2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Max_1) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {4}, {21.f, 22.f, 23.f, 24.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {4}, {21.f, 22.f, 23.f, 24.f}); - x.linspace(1); - - sd::ops::reduce_max op; - auto result = op.evaluate({&x}, {}, {0,1}); - auto output = result.at(0); - // output->printIndexedBuffer("Result is"); - // output->printShapeInfo("Output shape"); - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - + sd::ops::reduce_max op; + auto result = op.evaluate({&x}, {}, {0, 1}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + // output->printShapeInfo("Output shape"); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Max_2) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = + NDArrayFactory::create('c', {1, 1, 4}, {21.f, 22.f, 23.f, 24.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2,3,4}); - auto exp = NDArrayFactory::create('c', {1,1,4}, {21.f, 22.f, 23.f, 24.f}); - x.linspace(1); - - sd::ops::reduce_max op; - auto result = op.evaluate({&x}, {1.}, {0, 1}); - auto output = result.at(0); - // output->printIndexedBuffer("Result is"); - - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::reduce_max op; + auto result = op.evaluate({&x}, {1.}, {0, 1}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Max_3) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {3}, {16.f, 20.f, 24.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {3}, {16.f, 20.f, 24.f}); - x.linspace(1); - - sd::ops::reduce_max op; - auto result = op.evaluate({&x}, {}, {0, 2}); - auto output = result.at(0); - // output->printIndexedBuffer("Result is"); - - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::reduce_max op; + auto result = op.evaluate({&x}, {}, {0, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Max_4) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {1, 3, 1}, {16.f, 20.f, 24.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {1,3,1}, {16.f, 20.f, 24.f}); - x.linspace(1); - - sd::ops::reduce_max op; - auto result = op.evaluate({&x}, {1.}, {0, 2}); - auto output = result.at(0); - // output->printIndexedBuffer("Result is"); - - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::reduce_max op; + auto result = op.evaluate({&x}, {1.}, {0, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Max_04) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {1, 3, 1}, {16.f, 20.f, 24.f}); + auto axes = NDArrayFactory::create({0, 2}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {1,3,1}, {16.f, 20.f, 24.f}); - auto axes = NDArrayFactory::create({0, 2}); - x.linspace(1); - - sd::ops::reduce_max op; - auto result = op.evaluate({&x, &axes}, {}, {}, {true}); - auto output = result.at(0); - // output->printIndexedBuffer("Result is"); - - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::reduce_max op; + auto result = op.evaluate({&x, &axes}, {}, {}, {true}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Max_5) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create(24.f); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create(24.f); - x.linspace(1); - - sd::ops::reduce_max op; - auto result = op.evaluate({&x}, {}, {}); - auto output = result.at(0); - // output->printIndexedBuffer("Result is"); - - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::reduce_max op; + auto result = op.evaluate({&x}, {}, {}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Max_6) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create(24.f); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create(24.f); - x.linspace(1); - - sd::ops::reduce_max op; - auto result = op.evaluate({&x}, {}, {0,1,2}); - auto output = result.at(0); - // output->printIndexedBuffer("Result is"); - - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::reduce_max op; + auto result = op.evaluate({&x}, {}, {0, 1, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Max_7) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {1, 1, 1}, {24.f}); + x.linspace(1); + // x.printIndexedBuffer("Input with shape (2, 3, 4) is"); + sd::ops::reduce_max op; + auto result = op.evaluate({&x}, {1.}, {0, 1, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {1, 1, 1}, {24.f}); - x.linspace(1); -// x.printIndexedBuffer("Input with shape (2, 3, 4) is"); - sd::ops::reduce_max op; - auto result = op.evaluate({&x}, {1.}, {0,1,2}); - auto output = result.at(0); - // output->printIndexedBuffer("Result is"); - - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } TEST_F(DeclarableOpsTests8, Test_Reduce_Norm1_1) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {4}, {66.f, 72.f, 78.f, 84.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {4}, {66.f, 72.f, 78.f, 84.f}); - x.linspace(1); - - sd::ops::reduce_norm1 op; - auto result = op.evaluate({&x}, {}, {0,1}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - + sd::ops::reduce_norm1 op; + auto result = op.evaluate({&x}, {}, {0, 1}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Norm1_2) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = + NDArrayFactory::create('c', {1, 1, 4}, {66.f, 72.f, 78.f, 84.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2,3,4}); - auto exp = NDArrayFactory::create('c', {1,1,4}, {66.f, 72.f, 78.f, 84.f}); - x.linspace(1); - - sd::ops::reduce_norm1 op; - auto result = op.evaluate({&x}, {1.}, {0, 1}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::reduce_norm1 op; + auto result = op.evaluate({&x}, {1.}, {0, 1}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Norm1_3) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {3}, {68.f, 100.f, 132.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {3}, {68.f, 100.f, 132.f}); - x.linspace(1); - - sd::ops::reduce_norm1 op; - auto result = op.evaluate({&x}, {}, {0, 2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::reduce_norm1 op; + auto result = op.evaluate({&x}, {}, {0, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Norm1_4) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = + NDArrayFactory::create('c', {1, 3, 1}, {68.f, 100.f, 132.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {1,3,1}, {68.f, 100.f, 132.f}); - x.linspace(1); - - sd::ops::reduce_norm1 op; - auto result = op.evaluate({&x}, {1.}, {0, 2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::reduce_norm1 op; + auto result = op.evaluate({&x}, {1.}, {0, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Norm1_04) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = + NDArrayFactory::create('c', {1, 3, 1}, {68.f, 100.f, 132.f}); + auto axes = NDArrayFactory::create({0, 2}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {1,3,1}, {68.f, 100.f, 132.f}); - auto axes = NDArrayFactory::create({0, 2}); - x.linspace(1); - - sd::ops::reduce_norm1 op; - auto result = op.evaluate({&x, &axes}, {}, {}, {true}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::reduce_norm1 op; + auto result = op.evaluate({&x, &axes}, {}, {}, {true}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Norm1_5) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create(300.f); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create(300.f); - x.linspace(1); - - sd::ops::reduce_norm1 op; - auto result = op.evaluate({&x}, {}, {}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::reduce_norm1 op; + auto result = op.evaluate({&x}, {}, {}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Norm1_6) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create(300.f); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create(300.f); - x.linspace(1); - - sd::ops::reduce_norm1 op; - auto result = op.evaluate({&x}, {}, {0,1,2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::reduce_norm1 op; + auto result = op.evaluate({&x}, {}, {0, 1, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Norm1_7) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {1, 1, 1}, {300.f}); + x.linspace(1); + // x.printIndexedBuffer("Input with shape (2, 3, 4) is"); + sd::ops::reduce_norm1 op; + auto result = op.evaluate({&x}, {1.}, {0, 1, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {1, 1, 1}, {300.f}); - x.linspace(1); -// x.printIndexedBuffer("Input with shape (2, 3, 4) is"); - sd::ops::reduce_norm1 op; - auto result = op.evaluate({&x}, {1.}, {0,1,2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } TEST_F(DeclarableOpsTests8, Test_Reduce_Norm2_1) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create( + 'c', {4}, {31.7175f, 33.823071f, 35.97221f, 38.15757f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {4}, {31.7175f, 33.823071f, 35.97221f, 38.15757f}); - x.linspace(1); - - sd::ops::reduce_norm2 op; - auto result = op.evaluate({&x}, {}, {0,1}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - + sd::ops::reduce_norm2 op; + auto result = op.evaluate({&x}, {}, {0, 1}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Norm2_2) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create( + 'c', {1, 1, 4}, {31.7175f, 33.823071f, 35.97221f, 38.15757f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2,3,4}); - auto exp = NDArrayFactory::create('c', {1,1,4}, {31.7175f, 33.823071f, 35.97221f, 38.15757f}); - x.linspace(1); - - sd::ops::reduce_norm2 op; - auto result = op.evaluate({&x}, {1.}, {0, 1}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::reduce_norm2 op; + auto result = op.evaluate({&x}, {1.}, {0, 1}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Norm2_3) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create( + 'c', {3}, {29.597298f, 39.344631f, 49.759422f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {3}, {29.597298f, 39.344631f, 49.759422f}); - x.linspace(1); - - sd::ops::reduce_norm2 op; - auto result = op.evaluate({&x}, {}, {0, 2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::reduce_norm2 op; + auto result = op.evaluate({&x}, {}, {0, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests8, Test_Reduce_Norm2_4) { - - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {1,3,1}, {29.597298f, 39.344631f, 49.759422f}); - x.linspace(1); - - sd::ops::reduce_norm2 op; - auto result = op.evaluate({&x}, {1.}, {0, 2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - - ASSERT_EQ(Status::OK(), result.status()); +TEST_F(DeclarableOpsTests8, Test_Reduce_Norm2_4) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create( + 'c', {1, 3, 1}, {29.597298f, 39.344631f, 49.759422f}); + x.linspace(1); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::reduce_norm2 op; + auto result = op.evaluate({&x}, {1.}, {0, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Norm2_04) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create( + 'c', {1, 3, 1}, {29.597298f, 39.344631f, 49.759422f}); + auto axes = NDArrayFactory::create({0, 2}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {1,3,1}, {29.597298f, 39.344631f, 49.759422f}); - auto axes = NDArrayFactory::create({0,2}); - x.linspace(1); - - sd::ops::reduce_norm2 op; - auto result = op.evaluate({&x, &axes}, {}, {}, {true}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::reduce_norm2 op; + auto result = op.evaluate({&x, &axes}, {}, {}, {true}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Norm2_5) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create(70.f); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create(70.f); - x.linspace(1); - - sd::ops::reduce_norm2 op; - auto result = op.evaluate({&x}, {}, {}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::reduce_norm2 op; + auto result = op.evaluate({&x}, {}, {}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Norm2_6) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create(70.f); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create(70.f); - x.linspace(1); - - sd::ops::reduce_norm2 op; - auto result = op.evaluate({&x}, {}, {0,1,2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::reduce_norm2 op; + auto result = op.evaluate({&x}, {}, {0, 1, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Norm2_7) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {1, 1, 1}, {70.f}); + x.linspace(1); + // x.printIndexedBuffer("Input with shape (2, 3, 4) is"); + sd::ops::reduce_norm2 op; + auto result = op.evaluate({&x}, {1.}, {0, 1, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {1, 1, 1}, {70.f}); - x.linspace(1); -// x.printIndexedBuffer("Input with shape (2, 3, 4) is"); - sd::ops::reduce_norm2 op; - auto result = op.evaluate({&x}, {1.}, {0,1,2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_NormMax_1) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {4}, {21.f, 22.f, 23.f, 24.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {4}, {21.f, 22.f, 23.f, 24.f}); - x.linspace(1); - - sd::ops::reduce_norm_max op; - auto result = op.evaluate({&x}, {}, {0,1}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - + sd::ops::reduce_norm_max op; + auto result = op.evaluate({&x}, {}, {0, 1}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_NormMax_2) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = + NDArrayFactory::create('c', {1, 1, 4}, {21.f, 22.f, 23.f, 24.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {1,1,4}, {21.f, 22.f, 23.f, 24.f}); - x.linspace(1); - - sd::ops::reduce_norm_max op; - auto result = op.evaluate({&x}, {1.f}, {0,1}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - + sd::ops::reduce_norm_max op; + auto result = op.evaluate({&x}, {1.f}, {0, 1}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_NormMax_3) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {3}, {16.f, 20.f, 24.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {3}, {16.f, 20.f, 24.f}); - x.linspace(1); - - sd::ops::reduce_norm_max op; - auto result = op.evaluate({&x}, {}, {0,2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - + sd::ops::reduce_norm_max op; + auto result = op.evaluate({&x}, {}, {0, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_NormMax_4) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {1, 3, 1}, {16.f, 20.f, 24.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {1, 3, 1}, {16.f, 20.f, 24.f}); - x.linspace(1); - - sd::ops::reduce_norm_max op; - auto result = op.evaluate({&x}, {1.f}, {0,2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - + sd::ops::reduce_norm_max op; + auto result = op.evaluate({&x}, {1.f}, {0, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_NormMax_04) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {1, 3, 1}, {16.f, 20.f, 24.f}); + auto axes = NDArrayFactory::create({0, 2}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {1, 3, 1}, {16.f, 20.f, 24.f}); - auto axes = NDArrayFactory::create({0,2}); - x.linspace(1); - - sd::ops::reduce_norm_max op; - auto result = op.evaluate({&x, &axes}, {}, {}, {true}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - + sd::ops::reduce_norm_max op; + auto result = op.evaluate({&x, &axes}, {}, {}, {true}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_NormMax_5) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create(24.f); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create(24.f); - x.linspace(1); - - sd::ops::reduce_norm_max op; - auto result = op.evaluate({&x}, {}, {}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::reduce_norm_max op; + auto result = op.evaluate({&x}, {}, {}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_NormMax_6) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create(24.f); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create(24.f); - x.linspace(1); - - sd::ops::reduce_norm_max op; - auto result = op.evaluate({&x}, {}, {0, 1, 2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::reduce_norm_max op; + auto result = op.evaluate({&x}, {}, {0, 1, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_NormMax_7) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {1, 1, 1}, {24.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {1, 1, 1}, {24.f}); - x.linspace(1); - - sd::ops::reduce_norm_max op; - auto result = op.evaluate({&x}, {1.f}, {}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::reduce_norm_max op; + auto result = op.evaluate({&x}, {1.f}, {}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_SquaredNorm_1) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {4}, + {1006.f, 1144.f, 1294.f, 1456.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {4}, {1006.f, 1144.f, 1294.f, 1456.f}); - x.linspace(1); - - sd::ops::reduce_sqnorm op; - auto result = op.evaluate({&x}, {}, {0,1}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - + sd::ops::reduce_sqnorm op; + auto result = op.evaluate({&x}, {}, {0, 1}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_SquaredNorm_2) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {1, 1, 4}, + {1006.f, 1144.f, 1294.f, 1456.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {1,1,4}, {1006.f, 1144.f, 1294.f, 1456.f}); - x.linspace(1); - - sd::ops::reduce_sqnorm op; - auto result = op.evaluate({&x}, {1.f}, {0,1}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - + sd::ops::reduce_sqnorm op; + auto result = op.evaluate({&x}, {1.f}, {0, 1}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_SquaredNorm_3) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {3}, {876.f, 1548.f, 2476.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {3}, {876.f, 1548.f, 2476.f}); - x.linspace(1); - - sd::ops::reduce_sqnorm op; - auto result = op.evaluate({&x}, {}, {0,2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - + sd::ops::reduce_sqnorm op; + auto result = op.evaluate({&x}, {}, {0, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_SquaredNorm_4) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = + NDArrayFactory::create('c', {1, 3, 1}, {876.f, 1548.f, 2476.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {1, 3, 1}, {876.f, 1548.f, 2476.f}); - x.linspace(1); - - sd::ops::reduce_sqnorm op; - auto result = op.evaluate({&x}, {1.f}, {0,2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - + sd::ops::reduce_sqnorm op; + auto result = op.evaluate({&x}, {1.f}, {0, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_SquaredNorm_04) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = + NDArrayFactory::create('c', {1, 3, 1}, {876.f, 1548.f, 2476.f}); + auto axes = NDArrayFactory::create({0, 2}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {1, 3, 1}, {876.f, 1548.f, 2476.f}); - auto axes = NDArrayFactory::create({0, 2}); - x.linspace(1); - - sd::ops::reduce_sqnorm op; - auto result = op.evaluate({&x, &axes}, {}, {}, {true}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - + sd::ops::reduce_sqnorm op; + auto result = op.evaluate({&x, &axes}, {}, {}, {true}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_SquaredNorm_5) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create(4900.f); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create(4900.f); - x.linspace(1); - - sd::ops::reduce_sqnorm op; - auto result = op.evaluate({&x}, {}, {}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::reduce_sqnorm op; + auto result = op.evaluate({&x}, {}, {}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_SquaredNorm_6) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create(4900.f); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create(4900.f); - x.linspace(1); - - sd::ops::reduce_sqnorm op; - auto result = op.evaluate({&x}, {}, {0, 1, 2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::reduce_sqnorm op; + auto result = op.evaluate({&x}, {}, {0, 1, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_SquaredNorm_7) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {1, 1, 1}, {4900.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {1, 1, 1}, {4900.f}); - x.linspace(1); - - sd::ops::reduce_sqnorm op; - auto result = op.evaluate({&x}, {1.f}, {}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::reduce_sqnorm op; + auto result = op.evaluate({&x}, {1.f}, {}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Sum_BP_1) { + auto input = NDArrayFactory::create( + 'c', {3, 4}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.}); + auto eps = NDArrayFactory::create(0.5f); + auto exp = NDArrayFactory::create( + 'c', {3, 4}, + {0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f}); + //************************************// - auto input = NDArrayFactory::create('c', {3, 4}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.}); - auto eps = NDArrayFactory::create(0.5f); - auto exp = NDArrayFactory::create('c', {3, 4}, {0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f,0.5f}); - //************************************// - - sd::ops::reduce_sum_bp op; - auto result = op.evaluate({&input, &eps}, {}, {}); - - ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); -// z->printIndexedBuffer("Result is "); -// z->printShapeInfo(); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::reduce_sum_bp op; + auto result = op.evaluate({&input, &eps}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + // z->printIndexedBuffer("Result is "); + // z->printShapeInfo(); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Sum_BP_2) { + auto input = NDArrayFactory::create( + 'c', {3, 4}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.}); + auto eps = NDArrayFactory::create('c', {1, 1}, {0.5f}); + auto exp = NDArrayFactory::create( + 'c', {3, 4}, + {0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f}); + //************************************// - auto input = NDArrayFactory::create('c', {3, 4}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.}); - auto eps = NDArrayFactory::create('c', {1, 1}, {0.5f}); - auto exp = NDArrayFactory::create('c', {3, 4}, {0.5f, 0.5f, 0.5f, 0.5f, - 0.5f, 0.5f, 0.5f, 0.5f, - 0.5f, 0.5f, 0.5f,0.5f}); - //************************************// - - sd::ops::reduce_sum_bp op; - auto result = op.evaluate({&input, &eps}, {1.f}, {}); - - ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); -// z->printIndexedBuffer("Result is "); -// z->printShapeInfo(); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::reduce_sum_bp op; + auto result = op.evaluate({&input, &eps}, {1.f}, {}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + // z->printIndexedBuffer("Result is "); + // z->printShapeInfo(); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Sum_BP_3) { + auto input = NDArrayFactory::create( + 'c', {3, 4}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.}); + auto eps = NDArrayFactory::create('c', {4}, {1.f, 2.f, 3.f, 4.f}); + auto exp = NDArrayFactory::create( + 'c', {3, 4}, + {1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f}); + //************************************// - auto input = NDArrayFactory::create('c', {3, 4}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.}); - auto eps = NDArrayFactory::create('c', {4}, {1.f, 2.f, 3.f, 4.f}); - auto exp = NDArrayFactory::create('c', {3, 4}, {1.f, 2.f, 3.f, 4.f, - 1.f, 2.f, 3.f, 4.f, - 1.f, 2.f, 3.f, 4.f}); - //************************************// - - sd::ops::reduce_sum_bp op; - auto result = op.evaluate({&input, &eps}, {}, {0}); - - ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); -// z->printIndexedBuffer("Result is "); -// z->printShapeInfo(); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::reduce_sum_bp op; + auto result = op.evaluate({&input, &eps}, {}, {0}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + // z->printIndexedBuffer("Result is "); + // z->printShapeInfo(); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Sum_BP_4) { + auto input = NDArrayFactory::create( + 'c', {3, 4}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.}); + auto eps = NDArrayFactory::create('c', {1, 4}, {1.f, 2.f, 3.f, 4.f}); + auto exp = NDArrayFactory::create( + 'c', {3, 4}, + {1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f}); + //************************************// - auto input = NDArrayFactory::create('c', {3, 4}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.}); - auto eps = NDArrayFactory::create('c', {1, 4}, {1.f, 2.f, 3.f, 4.f}); - auto exp = NDArrayFactory::create('c', {3, 4}, {1.f, 2.f, 3.f, 4.f, - 1.f, 2.f, 3.f, 4.f, - 1.f, 2.f, 3.f, 4.f}); - //************************************// - - sd::ops::reduce_sum_bp op; - auto result = op.evaluate({&input, &eps}, {1.f}, {0}); - - ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); -// z->printIndexedBuffer("Result is "); -// z->printShapeInfo(); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::reduce_sum_bp op; + auto result = op.evaluate({&input, &eps}, {1.f}, {0}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + // z->printIndexedBuffer("Result is "); + // z->printShapeInfo(); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Sum_BP_04) { + int ax = 0; + auto input = NDArrayFactory::create( + 'c', {3, 4}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.}); + auto eps = NDArrayFactory::create('c', {1, 4}, {1.f, 2.f, 3.f, 4.f}); + auto exp = NDArrayFactory::create( + 'c', {3, 4}, + {1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f}); + auto axis = NDArrayFactory::create('c', {1}, {ax}); + //************************************// - int ax = 0; - auto input = NDArrayFactory::create('c', {3, 4}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.}); - auto eps = NDArrayFactory::create('c', {1, 4}, {1.f, 2.f, 3.f, 4.f}); - auto exp = NDArrayFactory::create('c', {3, 4}, {1.f, 2.f, 3.f, 4.f, - 1.f, 2.f, 3.f, 4.f, - 1.f, 2.f, 3.f, 4.f}); - auto axis = NDArrayFactory::create('c', {1}, {ax}); - //************************************// - - sd::ops::reduce_sum_bp op; - auto result = op.evaluate({&input, &eps, &axis}, {}, {}, {true}); - - ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); -// z->printIndexedBuffer("Result is "); -// z->printShapeInfo(); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::reduce_sum_bp op; + auto result = op.evaluate({&input, &eps, &axis}, {}, {}, {true}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + // z->printIndexedBuffer("Result is "); + // z->printShapeInfo(); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Prod_BP_1) { - - auto input = NDArrayFactory::create('c', {3, 5}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f}); - auto eps = NDArrayFactory::create(1307674368000.f); - //************************************// -// auto exp = NDArrayFactory::create('c', {3, 4}, {0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f,0.5f}); - //************************************// - auto exp = NDArrayFactory::create('c', {3, 5}, {1710012166826558903812096.f, 855006083413279451906048.f, 570004067618451974258688.f, - 427503041706639725953024.f, 342002454982589992140800.f, 285002033809225987129344.f, - 244287457550765131825152.f, 213751520853319862976512.f, 190001355872817324752896.f, - 171001227491294996070400.f, 155455648254341989531648.f, 142501016904612993564672.f, - 131539399526781282156544.f, 122143728775382565912576.f, 114000815325130245799936.f}); - - sd::ops::reduce_prod_bp op; - auto result = op.evaluate({&input, &eps}, {}, {}); - - ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); -// z->printIndexedBuffer("Result is "); -// z->printShapeInfo(); - ASSERT_TRUE(exp.equalsTo(z)); - + auto input = + NDArrayFactory::create('c', {3, 5}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, + 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f}); + auto eps = NDArrayFactory::create(1307674368000.f); + //************************************// + // auto exp = NDArrayFactory::create('c', {3, 4}, {0.5f, 0.5f, + // 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f,0.5f}); + //************************************// + auto exp = NDArrayFactory::create( + 'c', {3, 5}, + {1710012166826558903812096.f, 855006083413279451906048.f, + 570004067618451974258688.f, 427503041706639725953024.f, + 342002454982589992140800.f, 285002033809225987129344.f, + 244287457550765131825152.f, 213751520853319862976512.f, + 190001355872817324752896.f, 171001227491294996070400.f, + 155455648254341989531648.f, 142501016904612993564672.f, + 131539399526781282156544.f, 122143728775382565912576.f, + 114000815325130245799936.f}); + + sd::ops::reduce_prod_bp op; + auto result = op.evaluate({&input, &eps}, {}, {}); + + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + // z->printIndexedBuffer("Result is "); + // z->printShapeInfo(); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, reduceMean_test1) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {4}, {11.f, 12.f, 13.f, 14.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2,3,4}); - auto exp = NDArrayFactory::create('c', {4}, {11.f, 12.f, 13.f, 14.f}); - x.linspace(1); - - - sd::ops::reduce_mean op; - auto result = op.evaluate({&x}, {}, {0,1}); - auto output = result.at(0); - - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::reduce_mean op; + auto result = op.evaluate({&x}, {}, {0, 1}); + auto output = result.at(0); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, reduceMean_test2) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = + NDArrayFactory::create('c', {1, 1, 4}, {11.f, 12.f, 13.f, 14.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2,3,4}); - auto exp = NDArrayFactory::create('c', {1,1,4}, {11.f, 12.f, 13.f, 14.f}); - x.linspace(1); - - - sd::ops::reduce_mean op; - auto result = op.evaluate({&x}, {1.}, {0,1}); - auto output = result.at(0); - - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::reduce_mean op; + auto result = op.evaluate({&x}, {1.}, {0, 1}); + auto output = result.at(0); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, reduceMean_test3) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {3}, {8.5f, 12.5f, 16.5f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2,3,4}); - auto exp = NDArrayFactory::create('c', {3}, {8.5f, 12.5f, 16.5f}); - x.linspace(1); - - - sd::ops::reduce_mean op; - auto result = op.evaluate({&x}, {}, {0,2}); - auto output = result.at(0); - - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::reduce_mean op; + auto result = op.evaluate({&x}, {}, {0, 2}); + auto output = result.at(0); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, reduceMean_test4) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = + NDArrayFactory::create('c', {1, 3, 1}, {8.5f, 12.5f, 16.5f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2,3,4}); - auto exp = NDArrayFactory::create('c', {1,3,1}, {8.5f, 12.5f, 16.5f}); - x.linspace(1); - - - sd::ops::reduce_mean op; - auto result = op.evaluate({&x}, {1.f}, {0,2}); - auto output = result.at(0); - - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::reduce_mean op; + auto result = op.evaluate({&x}, {1.f}, {0, 2}); + auto output = result.at(0); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, reduceMean_test5) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create(12.5f); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2,3,4}); - auto exp = NDArrayFactory::create(12.5f); - x.linspace(1); - - - sd::ops::reduce_mean op; - auto result = op.evaluate({&x}, {}, {}); - auto output = result.at(0); - - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::reduce_mean op; + auto result = op.evaluate({&x}, {}, {}); + auto output = result.at(0); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, reduceMean_test6) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create(12.5f); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2,3,4}); - auto exp = NDArrayFactory::create(12.5f); - x.linspace(1); - - sd::ops::reduce_mean op; - auto result = op.evaluate({&x}, {}, {0,1,2}); - auto output = result.at(0); - - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::reduce_mean op; + auto result = op.evaluate({&x}, {}, {0, 1, 2}); + auto output = result.at(0); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, reduceMean_test7) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {1, 1, 1}, {12.5f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2,3,4}); - auto exp = NDArrayFactory::create('c', {1,1,1}, {12.5f}); - x.linspace(1); - - sd::ops::reduce_mean op; - auto result = op.evaluate({&x}, {1.}, {0,1,2}); - auto output = result.at(0); - - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::reduce_mean op; + auto result = op.evaluate({&x}, {1.}, {0, 1, 2}); + auto output = result.at(0); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, reduceMean_test8) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {1, 1, 1}, {12.5f}); + auto axes = NDArrayFactory::create({0, 1, 2}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2,3,4}); - auto exp = NDArrayFactory::create('c', {1,1,1}, {12.5f}); - auto axes = NDArrayFactory::create({0, 1, 2}); - x.linspace(1); - - sd::ops::reduce_mean op; - auto result = op.evaluate({&x, &axes}, {}, {}, {true}); - auto output = result.at(0); - - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::reduce_mean op; + auto result = op.evaluate({&x, &axes}, {}, {}, {true}); + auto output = result.at(0); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, reduceMeanBP_test1) { + auto x = NDArrayFactory::create('c', {3, 4}); + auto gradO1 = NDArrayFactory::create(0.5f); + auto gradO2 = NDArrayFactory::create('c', {1, 1}, {0.5f}); + auto exp = NDArrayFactory::create( + 'c', {3, 4}, + {1. / 24, 1. / 24, 1. / 24, 1. / 24, 1. / 24, 1. / 24, 1. / 24, 1. / 24, + 1. / 24, 1. / 24, 1. / 24, 1. / 24}); - auto x = NDArrayFactory::create('c', {3,4}); - auto gradO1 = NDArrayFactory::create(0.5f); - auto gradO2 = NDArrayFactory::create('c', {1,1}, {0.5f}); - auto exp = NDArrayFactory::create('c', {3,4}, {1./24, 1./24, 1./24, 1./24, 1./24, 1./24, 1./24, 1./24, 1./24, 1./24, 1./24, 1./24}); + x.linspace(1); - x.linspace(1); + sd::ops::reduce_mean_bp op; - sd::ops::reduce_mean_bp op; + auto result = op.evaluate({&x, &gradO1}, {0}, {}); + ASSERT_EQ(Status::OK(), result.status()); + auto output = result.at(0); - auto result = op.evaluate({&x, &gradO1}, {0}, {}); - ASSERT_EQ(Status::OK(), result.status()); - auto output = result.at(0); + // output->printShapeInfo("o"); - // output->printShapeInfo("o"); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - - result = op.evaluate({&x, &gradO2}, {1}, {}); - ASSERT_EQ(Status::OK(), result.status()); - output = result.at(0); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + result = op.evaluate({&x, &gradO2}, {1}, {}); + ASSERT_EQ(Status::OK(), result.status()); + output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } - //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, reduceMeanBP_test2) { + auto x = NDArrayFactory::create('c', {3, 4}); + auto gradO1 = NDArrayFactory::create('c', {4}, {1.f, 2.f, 3.f, 4.f}); + auto gradO2 = + NDArrayFactory::create('c', {1, 4}, {1.f, 2.f, 3.f, 4.f}); + auto exp = NDArrayFactory::create( + 'c', {3, 4}, + {1.f / 3.f, 2.f / 3.f, 1.f, 4.f / 3.f, 1.f / 3.f, 2.f / 3.f, 1.f, + 4.f / 3.f, 1.f / 3.f, 2.f / 3.f, 1.f, 4.f / 3.f}); - auto x = NDArrayFactory::create('c', {3,4}); - auto gradO1 = NDArrayFactory::create('c', {4}, {1.f, 2.f, 3.f, 4.f}); - auto gradO2 = NDArrayFactory::create('c', {1,4}, {1.f, 2.f, 3.f, 4.f}); - auto exp = NDArrayFactory::create('c', {3,4}, {1.f/3.f, 2.f/3.f, 1.f, 4.f/3.f, 1.f/3.f, 2.f/3.f, 1.f, 4.f/3.f, 1.f/3.f, 2.f/3.f, 1.f, 4.f/3.f}); + x.linspace(1); - x.linspace(1); + sd::ops::reduce_mean_bp op; - sd::ops::reduce_mean_bp op; - - auto result = op.evaluate({&x, &gradO1}, {0}, {0}); - ASSERT_EQ(Status::OK(), result.status()); - auto output = result.at(0); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - - result = op.evaluate({&x, &gradO2}, {1}, {0}); - ASSERT_EQ(Status::OK(), result.status()); - output = result.at(0); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + auto result = op.evaluate({&x, &gradO1}, {0}, {0}); + ASSERT_EQ(Status::OK(), result.status()); + auto output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + result = op.evaluate({&x, &gradO2}, {1}, {0}); + ASSERT_EQ(Status::OK(), result.status()); + output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, reduceMeanBP_test02) { - - int ax = 0; - auto x = NDArrayFactory::create('c', {3,4}); - auto gradO1 = NDArrayFactory::create('c', {4}, {1.f, 2.f, 3.f, 4.f}); - auto gradO2 = NDArrayFactory::create('c', {1,4}, {1.f, 2.f, 3.f, 4.f}); - auto exp = NDArrayFactory::create('c', {3,4}, {1.f/3.f, 2.f/3.f, 1.f, 4.f/3.f, 1.f/3.f, 2.f/3.f, 1.f, 4.f/3.f, 1.f/3.f, 2.f/3.f, 1.f, 4.f/3.f}); - auto axis = NDArrayFactory::create('c', {1}, {ax}); - x.linspace(1); - - sd::ops::reduce_mean_bp op; - - auto result = op.evaluate({&x, &gradO1, &axis}, {}, {}, {false}); - ASSERT_EQ(Status::OK(), result.status()); - auto output = result.at(0); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - - result = op.evaluate({&x, &gradO2, &axis}, {}, {}, {true}); - ASSERT_EQ(Status::OK(), result.status()); - output = result.at(0); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - + int ax = 0; + auto x = NDArrayFactory::create('c', {3, 4}); + auto gradO1 = NDArrayFactory::create('c', {4}, {1.f, 2.f, 3.f, 4.f}); + auto gradO2 = + NDArrayFactory::create('c', {1, 4}, {1.f, 2.f, 3.f, 4.f}); + auto exp = NDArrayFactory::create( + 'c', {3, 4}, + {1.f / 3.f, 2.f / 3.f, 1.f, 4.f / 3.f, 1.f / 3.f, 2.f / 3.f, 1.f, + 4.f / 3.f, 1.f / 3.f, 2.f / 3.f, 1.f, 4.f / 3.f}); + auto axis = NDArrayFactory::create('c', {1}, {ax}); + x.linspace(1); + + sd::ops::reduce_mean_bp op; + + auto result = op.evaluate({&x, &gradO1, &axis}, {}, {}, {false}); + ASSERT_EQ(Status::OK(), result.status()); + auto output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + result = op.evaluate({&x, &gradO2, &axis}, {}, {}, {true}); + ASSERT_EQ(Status::OK(), result.status()); + output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, reduceMeanBP_test3) { + auto x = NDArrayFactory::create('c', {3, 4}); + auto gradO1 = NDArrayFactory::create('c', {3}, {1.f, 2.f, 3.f}); + auto gradO2 = NDArrayFactory::create('c', {3, 1}, {1.f, 2.f, 3.f}); + auto exp = + NDArrayFactory::create('c', {3, 4}, + {0.25f, 0.25f, 0.25f, 0.25f, 0.5f, 0.5f, + 0.5f, 0.5f, 0.75f, 0.75f, 0.75f, 0.75f}); - auto x = NDArrayFactory::create('c', {3,4}); - auto gradO1 = NDArrayFactory::create('c', {3}, {1.f, 2.f, 3.f}); - auto gradO2 = NDArrayFactory::create('c', {3,1}, {1.f, 2.f, 3.f}); - auto exp = NDArrayFactory::create('c', {3,4}, {0.25f, 0.25f, 0.25f, 0.25f, 0.5f, 0.5f, 0.5f, 0.5f, 0.75f, 0.75f, 0.75f, 0.75f}); - - x.linspace(1); - - sd::ops::reduce_mean_bp op; - - auto result = op.evaluate({&x, &gradO1}, {0}, {1}); - ASSERT_EQ(Status::OK(), result.status()); - auto output = result.at(0); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + x.linspace(1); + sd::ops::reduce_mean_bp op; - result = op.evaluate({&x, &gradO2}, {1}, {1}); - ASSERT_EQ(Status::OK(), result.status()); - output = result.at(0); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + auto result = op.evaluate({&x, &gradO1}, {0}, {1}); + ASSERT_EQ(Status::OK(), result.status()); + auto output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + result = op.evaluate({&x, &gradO2}, {1}, {1}); + ASSERT_EQ(Status::OK(), result.status()); + output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, reduceStDevBP_test4) { + auto x = NDArrayFactory::create('c', {3}, {2.f, 3.f, 4.f}); + auto gradO = NDArrayFactory::create(0.5f); + auto exp = NDArrayFactory::create('c', {3}, {-0.25f, 0.f, 0.25f}); - auto x = NDArrayFactory::create('c', {3}, {2.f, 3.f, 4.f}); - auto gradO = NDArrayFactory::create(0.5f); - auto exp = NDArrayFactory::create('c', {3}, {-0.25f, 0.f, 0.25f}); - - sd::ops::reduce_stdev_bp op; - - auto result = op.evaluate({&x, &gradO}, {0,1}, {}); - ASSERT_EQ(Status::OK(), result.status()); - auto output = result.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::reduce_stdev_bp op; + auto result = op.evaluate({&x, &gradO}, {0, 1}, {}); + ASSERT_EQ(Status::OK(), result.status()); + auto output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, softmax_cross_entropy_loss_with_logits_test1) { + auto labels = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0}); + auto logits = NDArrayFactory::create('c', {2, 3, 4}); + auto expected = NDArrayFactory::create( + 'c', {2, 3}, {2.78507, 1.34254, 4.12761, 2.88507, 2.78507, 2.88507}); - auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,0,0,1,0,1,0,1,1,1,0,1,0,1,0,0,1,1,0,1,0}); - auto logits = NDArrayFactory::create('c', {2,3,4}); - auto expected = NDArrayFactory::create('c', {2,3}, {2.78507, 1.34254, 4.12761, 2.88507, 2.78507, 2.88507}); - - logits.linspace(0.1, 0.1); + logits.linspace(0.1, 0.1); - sd::ops::softmax_cross_entropy_loss_with_logits op; - auto results = op.evaluate({&logits, &labels}, {}, {}); + sd::ops::softmax_cross_entropy_loss_with_logits op; + auto results = op.evaluate({&logits, &labels}, {}, {}); - ASSERT_EQ(Status::OK(), results.status()); - - auto *output = results.at(0); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + ASSERT_EQ(Status::OK(), results.status()); + auto output = results.at(0); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, softmax_cross_entropy_loss_with_logits_test2) { + auto labels = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0}); + auto logits = NDArrayFactory::create('c', {2, 3, 4}); + auto expected = NDArrayFactory::create( + 'c', {3, 4}, + {0.26328, 1.46328, 1.72656, 0., 0.26328, 0., 1.46328, 0.26328, 1.72656, + 0., 1.72656, 1.46328}); - auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,0,0,1,0,1,0,1,1,1,0,1,0,1,0,0,1,1,0,1,0}); - auto logits = NDArrayFactory::create('c', {2,3,4}); - auto expected = NDArrayFactory::create('c', {3,4}, {0.26328, 1.46328, 1.72656, 0. , 0.26328, 0. , 1.46328, 0.26328, 1.72656, 0. , 1.72656, 1.46328}); - - logits.linspace(0.1, 0.1); - - sd::ops::softmax_cross_entropy_loss_with_logits op; - auto results = op.evaluate({&logits, &labels}, {}, {0}); + logits.linspace(0.1, 0.1); - ASSERT_EQ(Status::OK(), results.status()); + sd::ops::softmax_cross_entropy_loss_with_logits op; + auto results = op.evaluate({&logits, &labels}, {}, {0}); - auto *output = results.at(0); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + ASSERT_EQ(Status::OK(), results.status()); + auto output = results.at(0); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, softmax_cross_entropy_loss_with_logits_test3) { + auto labels = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0}); + auto logits = NDArrayFactory::create('c', {2, 3, 4}); + auto expected = NDArrayFactory::create( + 'c', {2, 4}, + {0.75125, 1.55125, 3.45375, 0.75125, 3.45375, 0., 2.3025, 1.15125}); - auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,0,0,1,0,1,0,1,1,1,0,1,0,1,0,0,1,1,0,1,0}); - auto logits = NDArrayFactory::create('c', {2,3,4}); - auto expected = NDArrayFactory::create('c', {2,4}, {0.75125, 1.55125, 3.45375, 0.75125, 3.45375, 0. , 2.3025 , 1.15125}); - - logits.linspace(0.1, 0.1); - - sd::ops::softmax_cross_entropy_loss_with_logits op; - auto results = op.evaluate({&logits, &labels}, {}, {1}); - - ASSERT_EQ(Status::OK(), results.status()); + logits.linspace(0.1, 0.1); - auto *output = results.at(0); + sd::ops::softmax_cross_entropy_loss_with_logits op; + auto results = op.evaluate({&logits, &labels}, {}, {1}); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + ASSERT_EQ(Status::OK(), results.status()); + auto output = results.at(0); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, softmax_cross_entropy_loss_with_logits_test4) { + auto labels = NDArrayFactory::create('c', {2, 3}, {0, 1, 1, 0, 0, 1}); + auto logits = NDArrayFactory::create('c', {2, 3}); + auto expected = NDArrayFactory::create('c', {2}, {2.10389, 1.00194}); - auto labels = NDArrayFactory::create('c', {2,3},{0,1,1,0,0,1}); - auto logits = NDArrayFactory::create('c', {2,3}); - auto expected = NDArrayFactory::create('c', {2}, {2.10389, 1.00194}); + logits.linspace(0.1, 0.1); - logits.linspace(0.1, 0.1); + sd::ops::softmax_cross_entropy_loss_with_logits op; + auto results = op.evaluate({&logits, &labels}, {}, {}); - sd::ops::softmax_cross_entropy_loss_with_logits op; - auto results = op.evaluate({&logits, &labels}, {}, {}); - - ASSERT_EQ(Status::OK(), results.status()); - - auto output = results.at(0); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + ASSERT_EQ(Status::OK(), results.status()); + auto output = results.at(0); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, softmax_cross_entropy_loss_with_logits_test5) { + auto labels = NDArrayFactory::create('c', {2, 3}, {0, 1, 1, 0, 0, 1}); + auto logits = NDArrayFactory::create('c', {2, 3}); + auto expected = + NDArrayFactory::create('c', {3}, {0., 0.85436, 1.40871}); - auto labels = NDArrayFactory::create('c', {2,3},{0,1,1,0,0,1}); - auto logits = NDArrayFactory::create('c', {2,3}); - auto expected = NDArrayFactory::create('c', {3}, {0., 0.85436, 1.40871}); - - logits.linspace(0.1, 0.1); + logits.linspace(0.1, 0.1); - sd::ops::softmax_cross_entropy_loss_with_logits op; - auto results = op.evaluate({&logits, &labels}, {}, {0}); + sd::ops::softmax_cross_entropy_loss_with_logits op; + auto results = op.evaluate({&logits, &labels}, {}, {0}); - ASSERT_EQ(Status::OK(), results.status()); - - auto output = results.at(0); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + ASSERT_EQ(Status::OK(), results.status()); + auto output = results.at(0); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, softmax_cross_entropy_loss_with_logits_test6) { + auto labels = NDArrayFactory::create('c', {2, 1}, {0, 1}); + auto logits = NDArrayFactory::create('c', {2, 1}); + auto expected = NDArrayFactory::create('c', {1}, {0.6444}); - auto labels = NDArrayFactory::create('c', {2,1}, {0,1}); - auto logits = NDArrayFactory::create('c', {2,1}); - auto expected = NDArrayFactory::create('c', {1}, {0.6444}); - - logits.linspace(0.1, 0.1); - - sd::ops::softmax_cross_entropy_loss_with_logits op; - auto results = op.evaluate({&logits, &labels}, {}, {0}); + logits.linspace(0.1, 0.1); - ASSERT_EQ(Status::OK(), results.status()); + sd::ops::softmax_cross_entropy_loss_with_logits op; + auto results = op.evaluate({&logits, &labels}, {}, {0}); - auto output = results.at(0); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + ASSERT_EQ(Status::OK(), results.status()); + auto output = results.at(0); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, softmax_cross_entropy_loss_with_logits_test7) { + auto labels = NDArrayFactory::create('c', {2, 1}, {0, 1}); + auto logits = NDArrayFactory::create('c', {2, 1}); + auto expected = NDArrayFactory::create('c', {2}, {0., 0.}); - auto labels = NDArrayFactory::create('c', {2,1}, {0,1}); - auto logits = NDArrayFactory::create('c', {2,1}); - auto expected = NDArrayFactory::create('c', {2}, {0., 0.}); - - logits.linspace(0.1, 0.1); - - sd::ops::softmax_cross_entropy_loss_with_logits op; - auto results = op.evaluate({&logits, &labels}, {}, {1}); - - ASSERT_EQ(Status::OK(), results.status()); + logits.linspace(0.1, 0.1); - auto output = results.at(0); + sd::ops::softmax_cross_entropy_loss_with_logits op; + auto results = op.evaluate({&logits, &labels}, {}, {1}); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + ASSERT_EQ(Status::OK(), results.status()); + auto output = results.at(0); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, softmax_cross_entropy_loss_with_logits_test8) { + auto labels = NDArrayFactory::create('c', {2}, {0, 1}); + auto logits = NDArrayFactory::create('c', {2}); + auto expected = NDArrayFactory::create(0.6444); - auto labels = NDArrayFactory::create('c', {2}, {0,1}); - auto logits = NDArrayFactory::create('c', {2}); - auto expected = NDArrayFactory::create(0.6444); + logits.linspace(0.1, 0.1); - logits.linspace(0.1, 0.1); + sd::ops::softmax_cross_entropy_loss_with_logits op; + auto results = op.evaluate({&logits, &labels}, {}, {}); - sd::ops::softmax_cross_entropy_loss_with_logits op; - auto results = op.evaluate({&logits, &labels}, {}, {}); - - ASSERT_EQ(Status::OK(), results.status()); - - auto *output = results.at(0); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + ASSERT_EQ(Status::OK(), results.status()); + auto output = results.at(0); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, softmax_cross_entropy_loss_with_logits_test9) { + auto labels = NDArrayFactory::create('c', {1}, {0.}); + auto logits = NDArrayFactory::create('c', {1}, {0.2}); + auto expected = NDArrayFactory::create(0.); - auto labels = NDArrayFactory::create('c', {1}, {0.}); - auto logits = NDArrayFactory::create('c', {1}, {0.2}); - auto expected = NDArrayFactory::create(0.); - - sd::ops::softmax_cross_entropy_loss_with_logits op; - auto results = op.evaluate({&logits, &labels}, {}, {}); + sd::ops::softmax_cross_entropy_loss_with_logits op; + auto results = op.evaluate({&logits, &labels}, {}, {}); - ASSERT_EQ(Status::OK(), results.status()); - - auto output = results.at(0); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + ASSERT_EQ(Status::OK(), results.status()); + auto output = results.at(0); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, softmax_cross_entropy_loss_with_logits_test10) { + auto labels = NDArrayFactory::create('c', {1, 2}, {0, 1}); + auto logits = NDArrayFactory::create('c', {1, 2}); + auto expected = NDArrayFactory::create('c', {2}, {0., 0.}); - auto labels = NDArrayFactory::create('c', {1,2}, {0,1}); - auto logits = NDArrayFactory::create('c', {1,2}); - auto expected = NDArrayFactory::create('c', {2}, {0., 0.}); - - logits.linspace(0.1, 0.1); - - sd::ops::softmax_cross_entropy_loss_with_logits op; - auto results = op.evaluate({&logits, &labels}, {}, {0}); + logits.linspace(0.1, 0.1); - ASSERT_EQ(Status::OK(), results.status()); + sd::ops::softmax_cross_entropy_loss_with_logits op; + auto results = op.evaluate({&logits, &labels}, {}, {0}); - auto output = results.at(0); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + ASSERT_EQ(Status::OK(), results.status()); + auto output = results.at(0); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, reduceMeanBP_test4) { + auto x = NDArrayFactory::create( + 'c', {3, 4}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.}); + auto gradO1 = NDArrayFactory::create('c', {4}, {1., 2., 3., 4.}); + auto gradO2 = NDArrayFactory::create('c', {1, 4}, {1., 2., 3., 4.}); + auto exp = NDArrayFactory::create( + 'c', {3, 4}, + {0.333333, 0.666667, 1.000000, 1.333333, 0.333333, 0.666667, 1.000000, + 1.333333, 0.333333, 0.666667, 1.000000, 1.333333}); - auto x = NDArrayFactory::create('c', {3, 4}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12. }); - auto gradO1 = NDArrayFactory::create('c', {4}, {1., 2., 3., 4.}); - auto gradO2 = NDArrayFactory::create('c', {1, 4}, {1., 2., 3., 4.}); - auto exp = NDArrayFactory::create('c', {3,4}, {0.333333, 0.666667, 1.000000, 1.333333, 0.333333, 0.666667, 1.000000, 1.333333, 0.333333, 0.666667, 1.000000, 1.333333}); - - sd::ops::reduce_mean_bp op; - - auto result = op.evaluate({&x, &gradO1}, {0}, {0}); - ASSERT_EQ(Status::OK(), result.status()); - auto output = result.at(0); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - + sd::ops::reduce_mean_bp op; - result = op.evaluate({&x, &gradO2}, {1}, {0}); - ASSERT_EQ(Status::OK(), result.status()); - output = result.at(0); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + auto result = op.evaluate({&x, &gradO1}, {0}, {0}); + ASSERT_EQ(Status::OK(), result.status()); + auto output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + result = op.evaluate({&x, &gradO2}, {1}, {0}); + ASSERT_EQ(Status::OK(), result.status()); + output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } - //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, reduceMeanBP_test5) { + auto x = NDArrayFactory::create( + 'c', {3, 4}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.}); + auto gradO1 = NDArrayFactory::create('c', {3}, {1., 2., 3.}); + auto gradO2 = NDArrayFactory::create('c', {3, 1}, {1., 2., 3.}); + auto exp = NDArrayFactory::create( + 'c', {3, 4}, + {0.2500, 0.2500, 0.2500, 0.2500, 0.5000, 0.5000, 0.5000, 0.5000, 0.7500, + 0.7500, 0.7500, 0.7500}); - auto x = NDArrayFactory::create('c', {3, 4}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12. }); - auto gradO1 = NDArrayFactory::create('c', {3}, {1., 2., 3.}); - auto gradO2 = NDArrayFactory::create('c', {3, 1}, {1., 2., 3.}); - auto exp = NDArrayFactory::create('c', {3,4}, {0.2500,0.2500,0.2500,0.2500, 0.5000,0.5000,0.5000,0.5000, 0.7500,0.7500,0.7500,0.7500}); - - sd::ops::reduce_mean_bp op; - - auto result = op.evaluate({&x, &gradO1}, {0}, {1}); - ASSERT_EQ(Status::OK(), result.status()); - auto output = result.at(0); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - + sd::ops::reduce_mean_bp op; - result = op.evaluate({&x, &gradO2}, {1}, {1}); - ASSERT_EQ(Status::OK(), result.status()); - output = result.at(0); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + auto result = op.evaluate({&x, &gradO1}, {0}, {1}); + ASSERT_EQ(Status::OK(), result.status()); + auto output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + result = op.evaluate({&x, &gradO2}, {1}, {1}); + ASSERT_EQ(Status::OK(), result.status()); + output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } - //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, reduceStDevBP_test5) { + auto x = NDArrayFactory::create( + 'c', {3, 4}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.}); + auto gradO1 = NDArrayFactory::create('c', {4}, {1., 2., 3., 4.}); + auto gradO2 = NDArrayFactory::create('c', {1, 4}, {1., 2., 3., 4.}); + auto exp = NDArrayFactory::create( + 'c', {3, 4}, + {-0.408248, -0.816497, -1.224745, -1.632993, 0.000000, 0.000000, 0.000000, + 0.000000, 0.408248, 0.816497, 1.224745, 1.632993}); - auto x = NDArrayFactory::create('c', {3, 4}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12. }); - auto gradO1 = NDArrayFactory::create('c', {4}, {1., 2., 3., 4.}); - auto gradO2 = NDArrayFactory::create('c', {1, 4}, {1., 2., 3., 4.}); - auto exp = NDArrayFactory::create('c', {3,4}, {-0.408248, -0.816497, -1.224745, -1.632993, 0.000000, 0.000000, 0.000000, 0.000000, 0.408248, 0.816497, 1.224745, 1.632993}); - - sd::ops::reduce_stdev_bp op; - - auto result = op.evaluate({&x, &gradO1}, {0}, {0}); - ASSERT_EQ(Status::OK(), result.status()); - auto output = result.at(0); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - + sd::ops::reduce_stdev_bp op; - result = op.evaluate({&x, &gradO2}, {1}, {0}); - ASSERT_EQ(Status::OK(), result.status()); - output = result.at(0); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + auto result = op.evaluate({&x, &gradO1}, {0}, {0}); + ASSERT_EQ(Status::OK(), result.status()); + auto output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + result = op.evaluate({&x, &gradO2}, {1}, {0}); + ASSERT_EQ(Status::OK(), result.status()); + output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, zeros_as_test1) { + auto x = NDArrayFactory::create(10.f); + auto y = NDArrayFactory::create(100.f); + auto exp = NDArrayFactory::create(0.f); - auto x = NDArrayFactory::create(10.f); - auto y = NDArrayFactory::create(100.f); - auto exp = NDArrayFactory::create(0.f); + sd::ops::zeros_as op; - sd::ops::zeros_as op; - - Nd4jStatus status = op.execute({&x}, {&y}, {}, {}, {}); - ASSERT_EQ(Status::OK(), status); - - ASSERT_TRUE(y.isSameShape(exp)); - ASSERT_TRUE(y.equalsTo(exp)); + Nd4jStatus status = op.execute({&x}, {&y}, {}, {}, {}); + ASSERT_EQ(Status::OK(), status); + ASSERT_TRUE(y.isSameShape(exp)); + ASSERT_TRUE(y.equalsTo(exp)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, zeros_as_test2) { + auto x = NDArrayFactory::create(10.f); + // auto y = NDArrayFactory::create(100.f); + auto exp = NDArrayFactory::create(0.f); - auto x = NDArrayFactory::create(10.f); - //auto y = NDArrayFactory::create(100.f); - auto exp = NDArrayFactory::create(0.f); - - sd::ops::zeros_as op; - - auto result = op.evaluate({&x}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); - auto y = result.at(0); + sd::ops::zeros_as op; - ASSERT_TRUE(y->isSameShape(exp)); - ASSERT_TRUE(y->equalsTo(exp)); + auto result = op.evaluate({&x}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + auto y = result.at(0); + ASSERT_TRUE(y.isSameShape(exp)); + ASSERT_TRUE(y.equalsTo(exp)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, ones_as_test1) { + auto x = NDArrayFactory::create(10.); + auto y = NDArrayFactory::create(100.); + auto exp = NDArrayFactory::create(1.); - auto x = NDArrayFactory::create(10.); - auto y = NDArrayFactory::create(100.); - auto exp = NDArrayFactory::create(1.); + sd::ops::ones_as op; - sd::ops::ones_as op; - - Nd4jStatus status = op.execute({&x}, {&y}); - ASSERT_EQ(Status::OK(), status); - - ASSERT_TRUE(y.isSameShape(exp)); - ASSERT_TRUE(y.equalsTo(exp)); + Nd4jStatus status = op.execute({&x}, {&y}); + ASSERT_EQ(Status::OK(), status); + ASSERT_TRUE(y.isSameShape(exp)); + ASSERT_TRUE(y.equalsTo(exp)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, ones_as_test2) { + auto x = NDArrayFactory::create(10.); + // auto y = NDArrayFactory::create(100.); + auto exp = NDArrayFactory::create(1.); - auto x = NDArrayFactory::create(10.); - //auto y = NDArrayFactory::create(100.); - auto exp = NDArrayFactory::create(1.); - - sd::ops::ones_as op; - - auto results = op.evaluate({&x}); - ASSERT_EQ(Status::OK(), results.status()); - auto y = results.at(0); - ASSERT_TRUE(y->isSameShape(exp)); - ASSERT_TRUE(y->equalsTo(exp)); - + sd::ops::ones_as op; + auto results = op.evaluate({&x}); + ASSERT_EQ(Status::OK(), results.status()); + auto y = results.at(0); + ASSERT_TRUE(y.isSameShape(exp)); + ASSERT_TRUE(y.equalsTo(exp)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, ones_as_test3) { + auto x = NDArrayFactory::create(10.); + // auto y = NDArrayFactory::create(100.); + auto exp = NDArrayFactory::create(1.); - auto x = NDArrayFactory::create(10.); - //auto y = NDArrayFactory::create(100.); - auto exp = NDArrayFactory::create(1.); - - sd::ops::ones_as op; - - auto results = op.evaluate({&x}, {}, {}, {}, {sd::DataType::INT32}); - ASSERT_EQ(Status::OK(), results.status()); - auto y = results.at(0); - - ASSERT_TRUE(y->isSameShape(exp)); - ASSERT_TRUE(y->equalsTo(exp)); + sd::ops::ones_as op; + auto results = op.evaluate({&x}, {}, {}, {}, {sd::DataType::INT32}); + ASSERT_EQ(Status::OK(), results.status()); + auto y = results.at(0); + ASSERT_TRUE(y.isSameShape(exp)); + ASSERT_TRUE(y.equalsTo(exp)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, NormalizeMoments_SGO_1) { - - auto data = NDArrayFactory::create('c', {10, 10}); - data.linspace(1); - - auto means = data.reduceAlongDimension(reduce::Sum, {0}); - auto deviance = NDArrayFactory::create('c', {10}, {825., 825. , 825., 825., 825., 825., 825., 825., 825., 825. }); // data.varianceAlongDimension(variance::SummaryStatsVariance, false, {0}); // = NDArrayFactory::create('c', {10, 10}); - - auto counts = NDArrayFactory::create(10.0); - -// auto expMeans = NDArrayFactory::create('c', {10, 10}); - -// auto expDeviance = NDArrayFactory::create('c', {10, 10}); - auto squared = NDArrayFactory::create('c', {10, 10}); - data.applyTransform(transform::Square, squared); - auto ssSquared = squared.reduceAlongDimension(reduce::Sum, {0}); -// ssSquared->printBuffer("Sum squared"); -// squared.printBuffer("Squared"); - sd::ops::normalize_moments op; - auto results = op.evaluate({&counts, &means, &ssSquared}, {0.0}, {0}); - means /= counts; -// sd::ops::normalize_moments op; -// auto results = op.evaluate({&counts, means, deviance}, {0.0}, {}); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_EQ(results.size(), 2); - - auto outputMeans = results.at(0); - auto outputDeviance = results.at(1); - -// outputMeans->printIndexedBuffer("Means"); -// outputDeviance->printIndexedBuffer("Variance"); -// deviance.printIndexedBuffer("Expected"); -// means->printIndexedBuffer("Expected means"); - ASSERT_TRUE(means.isSameShape(outputMeans)); - ASSERT_TRUE(means.equalsTo(outputMeans)); - ASSERT_TRUE(deviance.isSameShape(outputDeviance)); - ASSERT_TRUE(deviance.equalsTo(outputDeviance)); - //delete deviance; -// ASSERT_TRUE(expMeans.isSameShape(outputMeans)); -// ASSERT_TRUE(expMeans.equalsTo(outputMeans)); -// ASSERT_TRUE(expMeans.isSameShape(outputDeviance)); -// ASSERT_TRUE(expDeviance.equalsTo(outputDeviance)); - - + auto data = NDArrayFactory::create('c', {10, 10}); + data.linspace(1); + + auto means = data.reduceAlongDimension(reduce::Sum, {0}); + auto deviance = NDArrayFactory::create( + 'c', {10}, + {825., 825., 825., 825., 825., 825., 825., 825., 825., + 825.}); // data.varianceAlongDimension(variance::SummaryStatsVariance, + // false, {0}); // = NDArrayFactory::create('c', {10, + // 10}); + + auto counts = NDArrayFactory::create(10.0); + + // auto expMeans = NDArrayFactory::create('c', {10, 10}); + + // auto expDeviance = NDArrayFactory::create('c', {10, 10}); + auto squared = NDArrayFactory::create('c', {10, 10}); + data.applyTransform(transform::Square, squared); + auto ssSquared = squared.reduceAlongDimension(reduce::Sum, {0}); + // ssSquared->printBuffer("Sum squared"); + // squared.printBuffer("Squared"); + sd::ops::normalize_moments op; + auto results = op.evaluate({&counts, &means, &ssSquared}, {0.0}, {0}); + means /= counts; + // sd::ops::normalize_moments op; + // auto results = op.evaluate({&counts, means, deviance}, {0.0}, {}); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_EQ(results.size(), 2); + + auto outputMeans = results.at(0); + auto outputDeviance = results.at(1); + + // outputMeans->printIndexedBuffer("Means"); + // outputDeviance->printIndexedBuffer("Variance"); + // deviance.printIndexedBuffer("Expected"); + // means->printIndexedBuffer("Expected means"); + ASSERT_TRUE(means.isSameShape(outputMeans)); + ASSERT_TRUE(means.equalsTo(outputMeans)); + ASSERT_TRUE(deviance.isSameShape(outputDeviance)); + ASSERT_TRUE(deviance.equalsTo(outputDeviance)); + // delete deviance; + // ASSERT_TRUE(expMeans.isSameShape(outputMeans)); + // ASSERT_TRUE(expMeans.equalsTo(outputMeans)); + // ASSERT_TRUE(expMeans.isSameShape(outputDeviance)); + // ASSERT_TRUE(expDeviance.equalsTo(outputDeviance)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Moments_1) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto expMeans = + NDArrayFactory::create('c', {4}, {11.f, 12.f, 13.f, 14.f}); + auto expVariance = NDArrayFactory::create( + 'c', {4}, {46.666668f, 46.666668f, 46.66666f, 46.666668f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto expMeans = NDArrayFactory::create('c', {4}, {11.f, 12.f, 13.f, 14.f}); - auto expVariance = NDArrayFactory::create('c', {4}, {46.666668f, 46.666668f, 46.66666f, 46.666668f}); - x.linspace(1); - - sd::ops::moments op; - auto result = op.evaluate({&x}, {}, {0, 1}); - - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::moments op; + auto result = op.evaluate({&x}, {}, {0, 1}); - auto outputMeans = result.at(0); - auto outputVariance = result.at(1); + ASSERT_EQ(Status::OK(), result.status()); -// outputMeans->printIndexedBuffer("Means"); -// outputVariance->printIndexedBuffer("Variance"); -// outputMeans->printShapeInfo("Result shape"); - - -// ASSERT_TRUE(exp.isSameShape(output)); -// ASSERT_TRUE(exp.equalsTo(output)); - ASSERT_TRUE(expMeans.isSameShape(outputMeans)); - ASSERT_TRUE(expMeans.equalsTo(outputMeans)); - ASSERT_TRUE(expVariance.isSameShape(outputVariance)); - ASSERT_TRUE(expVariance.equalsTo(outputVariance)); + auto outputMeans = result.at(0); + auto outputVariance = result.at(1); + // outputMeans->printIndexedBuffer("Means"); + // outputVariance->printIndexedBuffer("Variance"); + // outputMeans->printShapeInfo("Result shape"); + // ASSERT_TRUE(exp.isSameShape(output)); + // ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_TRUE(expMeans.isSameShape(outputMeans)); + ASSERT_TRUE(expMeans.equalsTo(outputMeans)); + ASSERT_TRUE(expVariance.isSameShape(outputVariance)); + ASSERT_TRUE(expVariance.equalsTo(outputVariance)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Moments_2) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto expMeans = + NDArrayFactory::create('c', {1, 1, 4}, {11.f, 12.f, 13.f, 14.f}); + auto expVariance = NDArrayFactory::create( + 'c', {1, 1, 4}, {46.666668f, 46.666668f, 46.66666f, 46.666668f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto expMeans = NDArrayFactory::create('c', {1,1,4}, {11.f, 12.f, 13.f, 14.f}); - auto expVariance = NDArrayFactory::create('c', {1,1,4}, {46.666668f, 46.666668f, 46.66666f, 46.666668f}); - x.linspace(1); - - sd::ops::moments op; - auto result = op.evaluate({&x}, {1.}, {0, 1}); - ASSERT_EQ(Status::OK(), result.status()); - - auto outputMeans = result.at(0); - auto outputVariance = result.at(1); + sd::ops::moments op; + auto result = op.evaluate({&x}, {1.}, {0, 1}); + ASSERT_EQ(Status::OK(), result.status()); -// outputMeans->printIndexedBuffer("Means"); -// outputVariance->printIndexedBuffer("Variance"); -// outputMeans->printShapeInfo("Result shape"); - -// ASSERT_TRUE(exp.isSameShape(output)); -// ASSERT_TRUE(exp.equalsTo(output)); - ASSERT_TRUE(expMeans.isSameShape(outputMeans)); - ASSERT_TRUE(expMeans.equalsTo(outputMeans)); - ASSERT_TRUE(expVariance.isSameShape(outputVariance)); - ASSERT_TRUE(expVariance.equalsTo(outputVariance)); + auto outputMeans = result.at(0); + auto outputVariance = result.at(1); + // outputMeans->printIndexedBuffer("Means"); + // outputVariance->printIndexedBuffer("Variance"); + // outputMeans->printShapeInfo("Result shape"); + // ASSERT_TRUE(exp.isSameShape(output)); + // ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_TRUE(expMeans.isSameShape(outputMeans)); + ASSERT_TRUE(expMeans.equalsTo(outputMeans)); + ASSERT_TRUE(expVariance.isSameShape(outputVariance)); + ASSERT_TRUE(expVariance.equalsTo(outputVariance)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Moments_3) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto expMeans = + NDArrayFactory::create('c', {3}, {8.5f, 12.5f, 16.5f}); + auto expVariance = + NDArrayFactory::create('c', {3}, {37.25f, 37.25f, 37.25f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto expMeans = NDArrayFactory::create('c', {3}, {8.5f, 12.5f, 16.5f}); - auto expVariance = NDArrayFactory::create('c', {3}, {37.25f, 37.25f, 37.25f}); - x.linspace(1); - - sd::ops::moments op; - auto result = op.evaluate({&x}, {}, {0, 2}); - ASSERT_EQ(Status::OK(), result.status()); - - auto outputMeans = result.at(0); - auto outputVariance = result.at(1); - -// outputMeans->printIndexedBuffer("Means"); -// outputVariance->printIndexedBuffer("Variance"); -// outputMeans->printShapeInfo("Result shape"); + sd::ops::moments op; + auto result = op.evaluate({&x}, {}, {0, 2}); + ASSERT_EQ(Status::OK(), result.status()); -// ASSERT_TRUE(exp.isSameShape(output)); -// ASSERT_TRUE(exp.equalsTo(output)); - ASSERT_TRUE(expMeans.isSameShape(outputMeans)); - ASSERT_TRUE(expMeans.equalsTo(outputMeans)); - ASSERT_TRUE(expVariance.isSameShape(outputVariance)); - ASSERT_TRUE(expVariance.equalsTo(outputVariance)); + auto outputMeans = result.at(0); + auto outputVariance = result.at(1); + // outputMeans->printIndexedBuffer("Means"); + // outputVariance->printIndexedBuffer("Variance"); + // outputMeans->printShapeInfo("Result shape"); + // ASSERT_TRUE(exp.isSameShape(output)); + // ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_TRUE(expMeans.isSameShape(outputMeans)); + ASSERT_TRUE(expMeans.equalsTo(outputMeans)); + ASSERT_TRUE(expVariance.isSameShape(outputVariance)); + ASSERT_TRUE(expVariance.equalsTo(outputVariance)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Moments_4) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto expMeans = + NDArrayFactory::create('c', {1, 3, 1}, {8.5f, 12.5f, 16.5f}); + auto expVariance = + NDArrayFactory::create('c', {1, 3, 1}, {37.25f, 37.25f, 37.25f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto expMeans = NDArrayFactory::create('c', {1,3,1}, {8.5f, 12.5f, 16.5f}); - auto expVariance = NDArrayFactory::create('c', {1,3,1}, {37.25f, 37.25f, 37.25f}); - x.linspace(1); + sd::ops::moments op; + auto result = op.evaluate({&x}, {1.}, {0, 2}); + ASSERT_EQ(Status::OK(), result.status()); - sd::ops::moments op; - auto result = op.evaluate({&x}, {1.}, {0, 2}); - ASSERT_EQ(Status::OK(), result.status()); - - auto outputMeans = result.at(0); - auto outputVariance = result.at(1); - -// outputMeans->printIndexedBuffer("Means"); -// outputVariance->printIndexedBuffer("Variance"); -// outputMeans->printShapeInfo("Result shape"); - -// ASSERT_TRUE(exp.isSameShape(output)); -// ASSERT_TRUE(exp.equalsTo(output)); - ASSERT_TRUE(expMeans.isSameShape(outputMeans)); - ASSERT_TRUE(expMeans.equalsTo(outputMeans)); - ASSERT_TRUE(expVariance.isSameShape(outputVariance)); - ASSERT_TRUE(expVariance.equalsTo(outputVariance)); + auto outputMeans = result.at(0); + auto outputVariance = result.at(1); + // outputMeans->printIndexedBuffer("Means"); + // outputVariance->printIndexedBuffer("Variance"); + // outputMeans->printShapeInfo("Result shape"); + // ASSERT_TRUE(exp.isSameShape(output)); + // ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_TRUE(expMeans.isSameShape(outputMeans)); + ASSERT_TRUE(expMeans.equalsTo(outputMeans)); + ASSERT_TRUE(expVariance.isSameShape(outputVariance)); + ASSERT_TRUE(expVariance.equalsTo(outputVariance)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Moments_6) { - auto expMeans = NDArrayFactory::create(12.5f); - auto expVariance = NDArrayFactory::create(47.916668f); - - auto x = NDArrayFactory::create('c', {2, 3, 4}); - x.linspace(1); + auto expMeans = NDArrayFactory::create(12.5f); + auto expVariance = NDArrayFactory::create(47.916668f); - sd::ops::moments op; - auto result = op.evaluate({&x}, {}, {0,1,2}); - ASSERT_EQ(Status::OK(), result.status()); + auto x = NDArrayFactory::create('c', {2, 3, 4}); + x.linspace(1); - auto outputMeans = result.at(0); - auto outputVariance = result.at(1); + sd::ops::moments op; + auto result = op.evaluate({&x}, {}, {0, 1, 2}); + ASSERT_EQ(Status::OK(), result.status()); -// outputMeans->printIndexedBuffer("Means"); -// outputVariance->printIndexedBuffer("Variance"); -// outputMeans->printShapeInfo("Result shape"); - - ASSERT_TRUE(expMeans.isSameShape(outputMeans)); - ASSERT_TRUE(expMeans.equalsTo(outputMeans)); - ASSERT_TRUE(expVariance.isSameShape(outputVariance)); - ASSERT_TRUE(expVariance.equalsTo(outputVariance)); + auto outputMeans = result.at(0); + auto outputVariance = result.at(1); + // outputMeans->printIndexedBuffer("Means"); + // outputVariance->printIndexedBuffer("Variance"); + // outputMeans->printShapeInfo("Result shape"); + ASSERT_TRUE(expMeans.isSameShape(outputMeans)); + ASSERT_TRUE(expMeans.equalsTo(outputMeans)); + ASSERT_TRUE(expVariance.isSameShape(outputVariance)); + ASSERT_TRUE(expVariance.equalsTo(outputVariance)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Moments_7) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - - auto expMeans = NDArrayFactory::create('c', {1,1,1}, {12.5f}); - auto expVariance = NDArrayFactory::create('c', {1,1,1}, {47.916668f}); - - x.linspace(1); - // x.printIndexedBuffer("Input with shape (2, 3, 4) is"); - sd::ops::moments op; - auto result = op.evaluate({&x}, {1.}, {0,1,2}); - ASSERT_EQ(Status::OK(), result.status()); - - auto outputMeans = result.at(0); - auto outputVariance = result.at(1); + auto expMeans = NDArrayFactory::create('c', {1, 1, 1}, {12.5f}); + auto expVariance = + NDArrayFactory::create('c', {1, 1, 1}, {47.916668f}); -// outputMeans->printIndexedBuffer("Means"); -// outputVariance->printIndexedBuffer("Variance"); -// outputMeans->printShapeInfo("Result shape"); - ASSERT_TRUE(expMeans.isSameShape(outputMeans)); - ASSERT_TRUE(expMeans.equalsTo(outputMeans)); - ASSERT_TRUE(expVariance.isSameShape(outputVariance)); - ASSERT_TRUE(expVariance.equalsTo(outputVariance)); + x.linspace(1); + // x.printIndexedBuffer("Input with shape (2, 3, 4) is"); + sd::ops::moments op; + auto result = op.evaluate({&x}, {1.}, {0, 1, 2}); + ASSERT_EQ(Status::OK(), result.status()); + auto outputMeans = result.at(0); + auto outputVariance = result.at(1); + // outputMeans->printIndexedBuffer("Means"); + // outputVariance->printIndexedBuffer("Variance"); + // outputMeans->printShapeInfo("Result shape"); + ASSERT_TRUE(expMeans.isSameShape(outputMeans)); + ASSERT_TRUE(expMeans.equalsTo(outputMeans)); + ASSERT_TRUE(expVariance.isSameShape(outputVariance)); + ASSERT_TRUE(expVariance.equalsTo(outputVariance)); } //////////////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedDeclarableOpsTests8, LrnTest_01) { + auto x = NDArrayFactory::create( + 'c', {1, 1, 2, 5}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f}); - auto x = NDArrayFactory::create('c', {1, 1, 2, 5}, { 1.f, 2.f, 3.f, 4.f, 5.f, - 6.f, 7.f, 8.f, 9.f, 10.f} - ); - - auto exp = NDArrayFactory::create('c', {1, 1, 2, 5}, {0.2581989f, 0.3592106f, 0.40089184f, 0.53935987f, 0.70014f, 0.4898979f, 0.46056613f, 0.43971977f, 0.5240003f, 0.6375767f}// 0.72760683, 0.4850712, 0.5848977, 0.67488194, -// 0.7581754, 0.58321184, 0.86747235, 0.4048204} - ); - - sd::ops::lrn op; - auto results = op.evaluate({&x}, {1.0, 1.0, 0.5}, {2}); - auto out = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - //ASSERT_TRUE(exp.isSameShape(out)); - //out->printBuffer("LRN out"); - //exp.printBuffer("LRN exp"); - ASSERT_TRUE(exp.equalsTo(out)); + auto exp = NDArrayFactory::create( + 'c', {1, 1, 2, 5}, + {0.2581989f, 0.3592106f, 0.40089184f, 0.53935987f, 0.70014f, 0.4898979f, + 0.46056613f, 0.43971977f, 0.5240003f, 0.6375767f} + // 0.72760683, 0.4850712, 0.5848977, 0.67488194, + // 0.7581754, 0.58321184, 0.86747235, 0.4048204} + ); + sd::ops::lrn op; + auto results = op.evaluate({&x}, {1.0, 1.0, 0.5}, {2}); + auto out = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + // ASSERT_TRUE(exp.isSameShape(out)); + // out->printBuffer("LRN out"); + // exp.printBuffer("LRN exp"); + ASSERT_TRUE(exp.equalsTo(out)); } //////////////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedDeclarableOpsTests8, LrnTest_02) { + auto x = NDArrayFactory::create('c', {1, 1, 1, 6}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); - auto x = NDArrayFactory::create('c', {1, 1, 1, 6}, { 1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); - - auto exp = NDArrayFactory::create('c', {1, 1, 1, 6}, { - 0.2581989f, 0.3592106f, 0.40089184f, 0.4193139f, 0.5360563f, 0.67936623f} - ); - - sd::ops::lrn op; - auto results = op.evaluate({&x}, {1.0, 1.0, 0.5}, {2}); - auto out = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - //ASSERT_TRUE(exp.isSameShape(out)); - //out->printIndexedBuffer("LRN out"); -// exp.printIndexedBuffer("LRN exp"); - ASSERT_TRUE(exp.equalsTo(out)); + auto exp = + NDArrayFactory::create('c', {1, 1, 1, 6}, + {0.2581989f, 0.3592106f, 0.40089184f, + 0.4193139f, 0.5360563f, 0.67936623f}); + sd::ops::lrn op; + auto results = op.evaluate({&x}, {1.0, 1.0, 0.5}, {2}); + auto out = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + // ASSERT_TRUE(exp.isSameShape(out)); + // out->printIndexedBuffer("LRN out"); + // exp.printIndexedBuffer("LRN exp"); + ASSERT_TRUE(exp.equalsTo(out)); } //////////////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedDeclarableOpsTests8, LrnTest_03) { + auto x = NDArrayFactory::create( + 'c', {1, 1, 1, 10}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f}); + auto exp = NDArrayFactory::create( + 'c', {1, 1, 1, 10}, + {0.10425719f, 0.16843036f, 0.2095291f, 0.23652494f, 0.25449327f, + 0.3053919f, 0.35675305f, 0.4098524f, 0.46662825f, 0.52999896f}); - auto x = NDArrayFactory::create('c', {1, 1, 1, 10}, { 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f}); - auto exp = NDArrayFactory::create('c', {1, 1, 1, 10}, {0.10425719f, 0.16843036f, 0.2095291f, 0.23652494f, 0.25449327f, 0.3053919f, 0.35675305f, 0.4098524f, 0.46662825f, 0.52999896f}); - - sd::ops::lrn op; - auto results = op.evaluate({&x}, {1.0, 1.0, 0.5}, {5}); - auto out = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(out)); - // out->printIndexedBuffer("LRN out"); -// exp.printIndexedBuffer("LRN exp"); - ASSERT_TRUE(exp.equalsTo(out)); - + sd::ops::lrn op; + auto results = op.evaluate({&x}, {1.0, 1.0, 0.5}, {5}); + auto out = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(out)); + // out->printIndexedBuffer("LRN out"); + // exp.printIndexedBuffer("LRN exp"); + ASSERT_TRUE(exp.equalsTo(out)); } //////////////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedDeclarableOpsTests8, LrnTest_1) { + auto x = NDArrayFactory::create( + 'c', {2, 2, 2, 2}, + {5.5f, 0.f, 0.3f, 5.5f, 8.6f, 0.f, 0.f, 0.4f, 1.5f, 1.f, 1.3f, 1.5f, 2.6f, + 2.f, 3.f, 1.4f}); - auto x = NDArrayFactory::create('c', {2, 2, 2, 2}, { 5.5f, 0.f, 0.3f, 5.5f, - 8.6f, 0.f, 0.f, 0.4f, - 1.5f, 1.f, 1.3f, 1.5f, - 2.6f, 2.f, 3.f, 1.4f} - ); - - auto exp = NDArrayFactory::create('c', {2, 2, 2, 2}, { - 0.98386997f, 0.f, 0.05358852f, 0.9824562f, - 0.99330735f, 0.f, 0.f, 0.37139067f, - 0.72760683f, 0.4850712f, 0.5848977f, 0.67488194f, - 0.7581754f, 0.58321184f, 0.86747235f, 0.4048204f} - ); - - sd::ops::lrn op; - auto results = op.evaluate({&x}, {1.0, 1.0, 0.5}, {2}); - auto out = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(out)); -// out->printIndexedBuffer("LRN out"); -// exp.printIndexedBuffer("LRN exp"); - ASSERT_TRUE(exp.equalsTo(out)); + auto exp = NDArrayFactory::create( + 'c', {2, 2, 2, 2}, + {0.98386997f, 0.f, 0.05358852f, 0.9824562f, 0.99330735f, 0.f, 0.f, + 0.37139067f, 0.72760683f, 0.4850712f, 0.5848977f, 0.67488194f, + 0.7581754f, 0.58321184f, 0.86747235f, 0.4048204f}); + sd::ops::lrn op; + auto results = op.evaluate({&x}, {1.0, 1.0, 0.5}, {2}); + auto out = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(out)); + // out->printIndexedBuffer("LRN out"); + // exp.printIndexedBuffer("LRN exp"); + ASSERT_TRUE(exp.equalsTo(out)); } //////////////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedDeclarableOpsTests8, LrnTest_2) { - - auto x = NDArrayFactory::create('c', {3, 3, 5, 5}); - x.linspace(1); - - auto exp = NDArrayFactory::create('c', {3, 3, 5, 5}, { - 0.2581989f, 0.3592106f, 0.40089184f, 0.53935987f, 0.70014f, - 0.4898979f, 0.46056613f, 0.43971977f, 0.5240002f, 0.6375767f, - 0.5274096f, 0.47771242f, 0.4443308f, 0.5163977f, 0.61701745f, - 0.5424508f, 0.48452914f, 0.44570294f, 0.5123918f, 0.6068971f, - 0.5505386f, 0.4881662f, 0.4462865f, 0.5099462f, 0.60088515f, - - 0.5555859f, 0.49042296f, 0.44658744f, 0.5083028f, 0.59690416f, - 0.55903524f, 0.4919585f, 0.44676256f, 0.5071239f, 0.59407425f, - 0.5615412f, 0.49307042f, 0.44687328f, 0.50623745f, 0.5919596f, - 0.56344414f, 0.49391258f, 0.4469477f, 0.5055468f, 0.59031945f, - 0.56493837f, 0.49457246f, 0.4470002f, 0.5049936f, 0.5890103f, - - 0.56614274f, 0.49510333f, 0.44703856f, 0.50454074f, 0.5879411f, - 0.567134f, 0.49553978f, 0.4470674f, 0.504163f, 0.5870515f, - 0.5679643f, 0.4959048f, 0.44708967f, 0.5038433f, 0.5862998f, - 0.56866974f, 0.4962146f, 0.44710726f, 0.5035692f, 0.58565617f, - 0.56927663f, 0.49648085f, 0.4471213f, 0.5033315f, 0.5850988f, - - - 0.56980413f, 0.49671215f, 0.44713274f, 0.50312346f, 0.58461165f, - 0.57026696f, 0.49691492f, 0.4471422f, 0.50293994f, 0.58418214f, - 0.5706764f, 0.49709415f, 0.44715008f, 0.5027767f, 0.5838005f, - 0.571041f, 0.4972537f, 0.44715673f, 0.50263065f, 0.58345926f, - 0.57136786f, 0.49739665f, 0.44716236f, 0.5024992f, 0.58315235f, - - 0.5716625f, 0.49752548f, 0.4471672f, 0.5023803f, 0.5828747f, - 0.5719295f, 0.49764213f, 0.44717142f, 0.5022721f, 0.5826225f, - 0.57217246f, 0.49774826f, 0.44717506f, 0.5021734f, 0.58239233f, - 0.5723947f, 0.4978453f, 0.44717824f, 0.5020829f, 0.58218133f, - 0.57259864f, 0.49793428f, 0.44718108f, 0.5019997f, 0.5819874f, - - 0.5727864f, 0.49801624f, 0.44718358f, 0.5019227f, 0.5818083f, - 0.57296f, 0.49809194f, 0.44718578f, 0.5018515f, 0.5816426f, - 0.5731208f, 0.49816203f, 0.44718775f, 0.5017854f, 0.58148885f, - 0.57327026f, 0.49822718f, 0.4471895f, 0.5017239f, 0.5813457f, - 0.57340944f, 0.49828786f, 0.44719115f, 0.5016664f, 0.581212f, - - - 0.57353944f, 0.4983446f, 0.44719255f, 0.50161266f, 0.58108705f, - 0.5736612f, 0.49839762f, 0.4471939f, 0.50156236f, 0.5809699f, - 0.5737754f, 0.4984474f, 0.44719502f, 0.501515f, 0.58085984f, - 0.5738828f, 0.49849418f, 0.4471962f, 0.50147045f, 0.5807564f, - 0.5739839f, 0.49853817f, 0.44719717f, 0.5014284f, 0.5806588f, - - 0.5740793f, 0.49857965f, 0.4471981f, 0.5013887f, 0.5805666f, - 0.5741694f, 0.49861887f, 0.44719887f, 0.50135124f, 0.58047944f, - 0.57425463f, 0.49865603f, 0.44719967f, 0.5013157f, 0.5803969f, - 0.5743354f, 0.4986912f, 0.44720036f, 0.5012819f, 0.5803186f, - 0.57441217f, 0.49872455f, 0.44720104f, 0.5012499f, 0.58024424f, - - 0.57448506f, 0.4987563f, 0.44720164f, 0.5012194f, 0.58017343f, - 0.57455444f, 0.4987865f, 0.4472022f, 0.5011904f, 0.5801061f, - 0.57462054f, 0.49881527f, 0.44720277f, 0.5011627f, 0.5800419f, - 0.57468355f, 0.49884263f, 0.44720328f, 0.50113624f, 0.5799805f, - 0.57474375f, 0.49886885f, 0.44720373f, 0.50111103f, 0.5799219f } - ); -// - sd::ops::lrn op; - auto results = op.evaluate({&x}, {1.0, 1.0, 0.5}, {2}); - auto out = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); -// ASSERT_TRUE(exp.isSameShape(out)); -// out->printIndexedBuffer("LRN out"); -// exp.printIndexedBuffer("LRN exp"); - ASSERT_TRUE(exp.equalsTo(out)); - - + auto x = NDArrayFactory::create('c', {3, 3, 5, 5}); + x.linspace(1); + + auto exp = NDArrayFactory::create( + 'c', {3, 3, 5, 5}, + {0.2581989f, 0.3592106f, 0.40089184f, 0.53935987f, 0.70014f, + 0.4898979f, 0.46056613f, 0.43971977f, 0.5240002f, 0.6375767f, + 0.5274096f, 0.47771242f, 0.4443308f, 0.5163977f, 0.61701745f, + 0.5424508f, 0.48452914f, 0.44570294f, 0.5123918f, 0.6068971f, + 0.5505386f, 0.4881662f, 0.4462865f, 0.5099462f, 0.60088515f, + + 0.5555859f, 0.49042296f, 0.44658744f, 0.5083028f, 0.59690416f, + 0.55903524f, 0.4919585f, 0.44676256f, 0.5071239f, 0.59407425f, + 0.5615412f, 0.49307042f, 0.44687328f, 0.50623745f, 0.5919596f, + 0.56344414f, 0.49391258f, 0.4469477f, 0.5055468f, 0.59031945f, + 0.56493837f, 0.49457246f, 0.4470002f, 0.5049936f, 0.5890103f, + + 0.56614274f, 0.49510333f, 0.44703856f, 0.50454074f, 0.5879411f, + 0.567134f, 0.49553978f, 0.4470674f, 0.504163f, 0.5870515f, + 0.5679643f, 0.4959048f, 0.44708967f, 0.5038433f, 0.5862998f, + 0.56866974f, 0.4962146f, 0.44710726f, 0.5035692f, 0.58565617f, + 0.56927663f, 0.49648085f, 0.4471213f, 0.5033315f, 0.5850988f, + + 0.56980413f, 0.49671215f, 0.44713274f, 0.50312346f, 0.58461165f, + 0.57026696f, 0.49691492f, 0.4471422f, 0.50293994f, 0.58418214f, + 0.5706764f, 0.49709415f, 0.44715008f, 0.5027767f, 0.5838005f, + 0.571041f, 0.4972537f, 0.44715673f, 0.50263065f, 0.58345926f, + 0.57136786f, 0.49739665f, 0.44716236f, 0.5024992f, 0.58315235f, + + 0.5716625f, 0.49752548f, 0.4471672f, 0.5023803f, 0.5828747f, + 0.5719295f, 0.49764213f, 0.44717142f, 0.5022721f, 0.5826225f, + 0.57217246f, 0.49774826f, 0.44717506f, 0.5021734f, 0.58239233f, + 0.5723947f, 0.4978453f, 0.44717824f, 0.5020829f, 0.58218133f, + 0.57259864f, 0.49793428f, 0.44718108f, 0.5019997f, 0.5819874f, + + 0.5727864f, 0.49801624f, 0.44718358f, 0.5019227f, 0.5818083f, + 0.57296f, 0.49809194f, 0.44718578f, 0.5018515f, 0.5816426f, + 0.5731208f, 0.49816203f, 0.44718775f, 0.5017854f, 0.58148885f, + 0.57327026f, 0.49822718f, 0.4471895f, 0.5017239f, 0.5813457f, + 0.57340944f, 0.49828786f, 0.44719115f, 0.5016664f, 0.581212f, + + 0.57353944f, 0.4983446f, 0.44719255f, 0.50161266f, 0.58108705f, + 0.5736612f, 0.49839762f, 0.4471939f, 0.50156236f, 0.5809699f, + 0.5737754f, 0.4984474f, 0.44719502f, 0.501515f, 0.58085984f, + 0.5738828f, 0.49849418f, 0.4471962f, 0.50147045f, 0.5807564f, + 0.5739839f, 0.49853817f, 0.44719717f, 0.5014284f, 0.5806588f, + + 0.5740793f, 0.49857965f, 0.4471981f, 0.5013887f, 0.5805666f, + 0.5741694f, 0.49861887f, 0.44719887f, 0.50135124f, 0.58047944f, + 0.57425463f, 0.49865603f, 0.44719967f, 0.5013157f, 0.5803969f, + 0.5743354f, 0.4986912f, 0.44720036f, 0.5012819f, 0.5803186f, + 0.57441217f, 0.49872455f, 0.44720104f, 0.5012499f, 0.58024424f, + + 0.57448506f, 0.4987563f, 0.44720164f, 0.5012194f, 0.58017343f, + 0.57455444f, 0.4987865f, 0.4472022f, 0.5011904f, 0.5801061f, + 0.57462054f, 0.49881527f, 0.44720277f, 0.5011627f, 0.5800419f, + 0.57468355f, 0.49884263f, 0.44720328f, 0.50113624f, 0.5799805f, + 0.57474375f, 0.49886885f, 0.44720373f, 0.50111103f, 0.5799219f}); + // + sd::ops::lrn op; + auto results = op.evaluate({&x}, {1.0, 1.0, 0.5}, {2}); + auto out = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + // ASSERT_TRUE(exp.isSameShape(out)); + // out->printIndexedBuffer("LRN out"); + // exp.printIndexedBuffer("LRN exp"); + ASSERT_TRUE(exp.equalsTo(out)); } //////////////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedDeclarableOpsTests8, LrnTest_3) { - - auto x = NDArrayFactory::create('c', {3, 3, 5, 5}); - x.linspace(1); - - auto exp = NDArrayFactory::create('c', {3, 3, 5, 5}, { - 0.2581989f, 0.3592106f, 0.40089184f, 0.53935987f, 0.70014f, - 0.4898979f, 0.46056613f, 0.43971977f, 0.5240002f, 0.6375767f, - 0.5274096f, 0.47771242f, 0.4443308f, 0.5163977f, 0.61701745f, - 0.5424508f, 0.48452914f, 0.44570294f, 0.5123918f, 0.6068971f, - 0.5505386f, 0.4881662f, 0.4462865f, 0.5099462f, 0.60088515f, - - 0.5555859f, 0.49042296f, 0.44658744f, 0.5083028f, 0.59690416f, - 0.55903524f, 0.4919585f, 0.44676256f, 0.5071239f, 0.59407425f, - 0.5615412f, 0.49307042f, 0.44687328f, 0.50623745f, 0.5919596f, - 0.56344414f, 0.49391258f, 0.4469477f, 0.5055468f, 0.59031945f, - 0.56493837f, 0.49457246f, 0.4470002f, 0.5049936f, 0.5890103f, - - 0.56614274f, 0.49510333f, 0.44703856f, 0.50454074f, 0.5879411f, - 0.567134f, 0.49553978f, 0.4470674f, 0.504163f, 0.5870515f, - 0.5679643f, 0.4959048f, 0.44708967f, 0.5038433f, 0.5862998f, - 0.56866974f, 0.4962146f, 0.44710726f, 0.5035692f, 0.58565617f, - 0.56927663f, 0.49648085f, 0.4471213f, 0.5033315f, 0.5850988f, - - - 0.56980413f, 0.49671215f, 0.44713274f, 0.50312346f, 0.58461165f, - 0.57026696f, 0.49691492f, 0.4471422f, 0.50293994f, 0.58418214f, - 0.5706764f, 0.49709415f, 0.44715008f, 0.5027767f, 0.5838005f, - 0.571041f, 0.4972537f, 0.44715673f, 0.50263065f, 0.58345926f, - 0.57136786f, 0.49739665f, 0.44716236f, 0.5024992f, 0.58315235f, - - 0.5716625f, 0.49752548f, 0.4471672f, 0.5023803f, 0.5828747f, - 0.5719295f, 0.49764213f, 0.44717142f, 0.5022721f, 0.5826225f, - 0.57217246f, 0.49774826f, 0.44717506f, 0.5021734f, 0.58239233f, - 0.5723947f, 0.4978453f, 0.44717824f, 0.5020829f, 0.58218133f, - 0.57259864f, 0.49793428f, 0.44718108f, 0.5019997f, 0.5819874f, - - 0.5727864f, 0.49801624f, 0.44718358f, 0.5019227f, 0.5818083f, - 0.57296f, 0.49809194f, 0.44718578f, 0.5018515f, 0.5816426f, - 0.5731208f, 0.49816203f, 0.44718775f, 0.5017854f, 0.58148885f, - 0.57327026f, 0.49822718f, 0.4471895f, 0.5017239f, 0.5813457f, - 0.57340944f, 0.49828786f, 0.44719115f, 0.5016664f, 0.581212f, - - - 0.57353944f, 0.4983446f, 0.44719255f, 0.50161266f, 0.58108705f, - 0.5736612f, 0.49839762f, 0.4471939f, 0.50156236f, 0.5809699f, - 0.5737754f, 0.4984474f, 0.44719502f, 0.501515f, 0.58085984f, - 0.5738828f, 0.49849418f, 0.4471962f, 0.50147045f, 0.5807564f, - 0.5739839f, 0.49853817f, 0.44719717f, 0.5014284f, 0.5806588f, - - 0.5740793f, 0.49857965f, 0.4471981f, 0.5013887f, 0.5805666f, - 0.5741694f, 0.49861887f, 0.44719887f, 0.50135124f, 0.58047944f, - 0.57425463f, 0.49865603f, 0.44719967f, 0.5013157f, 0.5803969f, - 0.5743354f, 0.4986912f, 0.44720036f, 0.5012819f, 0.5803186f, - 0.57441217f, 0.49872455f, 0.44720104f, 0.5012499f, 0.58024424f, - - 0.57448506f, 0.4987563f, 0.44720164f, 0.5012194f, 0.58017343f, - 0.57455444f, 0.4987865f, 0.4472022f, 0.5011904f, 0.5801061f, - 0.57462054f, 0.49881527f, 0.44720277f, 0.5011627f, 0.5800419f, - 0.57468355f, 0.49884263f, 0.44720328f, 0.50113624f, 0.5799805f, - 0.57474375f, 0.49886885f, 0.44720373f, 0.50111103f, 0.5799219f } - ); -// - sd::ops::lrn op; - auto results = op.evaluate({&x}, {1.0, 1.0, 0.5}, {2}); - auto out = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); -// ASSERT_TRUE(exp.isSameShape(out)); -// out->printIndexedBuffer("LRN out"); -// exp.printIndexedBuffer("LRN exp"); - ASSERT_TRUE(exp.equalsTo(out)); - - + auto x = NDArrayFactory::create('c', {3, 3, 5, 5}); + x.linspace(1); + + auto exp = NDArrayFactory::create( + 'c', {3, 3, 5, 5}, + {0.2581989f, 0.3592106f, 0.40089184f, 0.53935987f, 0.70014f, + 0.4898979f, 0.46056613f, 0.43971977f, 0.5240002f, 0.6375767f, + 0.5274096f, 0.47771242f, 0.4443308f, 0.5163977f, 0.61701745f, + 0.5424508f, 0.48452914f, 0.44570294f, 0.5123918f, 0.6068971f, + 0.5505386f, 0.4881662f, 0.4462865f, 0.5099462f, 0.60088515f, + + 0.5555859f, 0.49042296f, 0.44658744f, 0.5083028f, 0.59690416f, + 0.55903524f, 0.4919585f, 0.44676256f, 0.5071239f, 0.59407425f, + 0.5615412f, 0.49307042f, 0.44687328f, 0.50623745f, 0.5919596f, + 0.56344414f, 0.49391258f, 0.4469477f, 0.5055468f, 0.59031945f, + 0.56493837f, 0.49457246f, 0.4470002f, 0.5049936f, 0.5890103f, + + 0.56614274f, 0.49510333f, 0.44703856f, 0.50454074f, 0.5879411f, + 0.567134f, 0.49553978f, 0.4470674f, 0.504163f, 0.5870515f, + 0.5679643f, 0.4959048f, 0.44708967f, 0.5038433f, 0.5862998f, + 0.56866974f, 0.4962146f, 0.44710726f, 0.5035692f, 0.58565617f, + 0.56927663f, 0.49648085f, 0.4471213f, 0.5033315f, 0.5850988f, + + 0.56980413f, 0.49671215f, 0.44713274f, 0.50312346f, 0.58461165f, + 0.57026696f, 0.49691492f, 0.4471422f, 0.50293994f, 0.58418214f, + 0.5706764f, 0.49709415f, 0.44715008f, 0.5027767f, 0.5838005f, + 0.571041f, 0.4972537f, 0.44715673f, 0.50263065f, 0.58345926f, + 0.57136786f, 0.49739665f, 0.44716236f, 0.5024992f, 0.58315235f, + + 0.5716625f, 0.49752548f, 0.4471672f, 0.5023803f, 0.5828747f, + 0.5719295f, 0.49764213f, 0.44717142f, 0.5022721f, 0.5826225f, + 0.57217246f, 0.49774826f, 0.44717506f, 0.5021734f, 0.58239233f, + 0.5723947f, 0.4978453f, 0.44717824f, 0.5020829f, 0.58218133f, + 0.57259864f, 0.49793428f, 0.44718108f, 0.5019997f, 0.5819874f, + + 0.5727864f, 0.49801624f, 0.44718358f, 0.5019227f, 0.5818083f, + 0.57296f, 0.49809194f, 0.44718578f, 0.5018515f, 0.5816426f, + 0.5731208f, 0.49816203f, 0.44718775f, 0.5017854f, 0.58148885f, + 0.57327026f, 0.49822718f, 0.4471895f, 0.5017239f, 0.5813457f, + 0.57340944f, 0.49828786f, 0.44719115f, 0.5016664f, 0.581212f, + + 0.57353944f, 0.4983446f, 0.44719255f, 0.50161266f, 0.58108705f, + 0.5736612f, 0.49839762f, 0.4471939f, 0.50156236f, 0.5809699f, + 0.5737754f, 0.4984474f, 0.44719502f, 0.501515f, 0.58085984f, + 0.5738828f, 0.49849418f, 0.4471962f, 0.50147045f, 0.5807564f, + 0.5739839f, 0.49853817f, 0.44719717f, 0.5014284f, 0.5806588f, + + 0.5740793f, 0.49857965f, 0.4471981f, 0.5013887f, 0.5805666f, + 0.5741694f, 0.49861887f, 0.44719887f, 0.50135124f, 0.58047944f, + 0.57425463f, 0.49865603f, 0.44719967f, 0.5013157f, 0.5803969f, + 0.5743354f, 0.4986912f, 0.44720036f, 0.5012819f, 0.5803186f, + 0.57441217f, 0.49872455f, 0.44720104f, 0.5012499f, 0.58024424f, + + 0.57448506f, 0.4987563f, 0.44720164f, 0.5012194f, 0.58017343f, + 0.57455444f, 0.4987865f, 0.4472022f, 0.5011904f, 0.5801061f, + 0.57462054f, 0.49881527f, 0.44720277f, 0.5011627f, 0.5800419f, + 0.57468355f, 0.49884263f, 0.44720328f, 0.50113624f, 0.5799805f, + 0.57474375f, 0.49886885f, 0.44720373f, 0.50111103f, 0.5799219f}); + // + sd::ops::lrn op; + auto results = op.evaluate({&x}, {1.0, 1.0, 0.5}, {2}); + auto out = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + // ASSERT_TRUE(exp.isSameShape(out)); + // out->printIndexedBuffer("LRN out"); + // exp.printIndexedBuffer("LRN exp"); + ASSERT_TRUE(exp.equalsTo(out)); } //////////////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedDeclarableOpsTests8, LrnTest_4) { + // auto x = NDArrayFactory::create('c', {8, 32, 64, 64}); + auto x = NDArrayFactory::create('c', {2, 8, 16, 16}); + x.linspace(1); - // auto x = NDArrayFactory::create('c', {8, 32, 64, 64}); - auto x = NDArrayFactory::create('c', {2, 8, 16, 16}); - x.linspace(1); - - sd::ops::lrn op; - auto results = op.evaluate({&x}, {1.0, 1.0, 0.5}, {2}); - auto out = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); -// ASSERT_TRUE(exp.isSameShape(out)); -// out->printIndexedBuffer("LRN out"); -// exp.printIndexedBuffer("LRN exp"); -// ASSERT_TRUE(exp.equalsTo(out)); - + sd::ops::lrn op; + auto results = op.evaluate({&x}, {1.0, 1.0, 0.5}, {2}); + auto out = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + // ASSERT_TRUE(exp.isSameShape(out)); + // out->printIndexedBuffer("LRN out"); + // exp.printIndexedBuffer("LRN exp"); + // ASSERT_TRUE(exp.equalsTo(out)); } //////////////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedDeclarableOpsTests8, LrnTest_4_119) { - int iterations = 1000; - // auto x = NDArrayFactory::create('c', {8, 32, 64, 64}); - // auto z = NDArrayFactory::create('c', {8, 32, 64, 64}); - auto x = NDArrayFactory::create('c', {2, 8, 16, 16}); - auto z = NDArrayFactory::create('c', {2, 8, 16, 16}); - x.linspace(1); - - sd::ops::lrn op; + int iterations = 1000; + // auto x = NDArrayFactory::create('c', {8, 32, 64, 64}); + // auto z = NDArrayFactory::create('c', {8, 32, 64, 64}); + auto x = NDArrayFactory::create('c', {2, 8, 16, 16}); + auto z = NDArrayFactory::create('c', {2, 8, 16, 16}); + x.linspace(1); - op.execute({&x}, {&z}, {1.0, 1.0, 0.5}, {2}); + sd::ops::lrn op; - auto timeStart = std::chrono::system_clock::now(); + op.execute({&x}, {&z}, {1.0, 1.0, 0.5}, {2}); - for (int e = 0; e < iterations; e++) - op.execute({&x}, {&z}, {1.0, 1.0, 0.5}, {2}); + auto timeStart = std::chrono::system_clock::now(); - auto timeEnd = std::chrono::system_clock::now(); - auto spanTime = std::chrono::duration_cast ((timeEnd - timeStart) / iterations).count(); - auto ttlTime = std::chrono::duration_cast ((timeEnd - timeStart)).count(); + for (int e = 0; e < iterations; e++) + op.execute({&x}, {&z}, {1.0, 1.0, 0.5}, {2}); + auto timeEnd = std::chrono::system_clock::now(); + auto spanTime = std::chrono::duration_cast( + (timeEnd - timeStart) / iterations) + .count(); + auto ttlTime = std::chrono::duration_cast( + (timeEnd - timeStart)) + .count(); -// ASSERT_TRUE(exp.isSameShape(out)); -// ASSERT_TRUE(exp.equalsTo(out)); + // ASSERT_TRUE(exp.isSameShape(out)); + // ASSERT_TRUE(exp.equalsTo(out)); } //////////////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedDeclarableOpsTests8, LrnTest_BP_01) { - - auto x = NDArrayFactory::create( 'c', {1, 1, 1, 10}); - x.linspace(1); - auto eps = NDArrayFactory::create('c', {1,1,1,10}); - eps.linspace(1); -// -// auto exp = NDArrayFactory::create('c', {3,3,5,5}, { -// 0.238337, 0.309664, 0.334077, 0.376534, 0.342926, 0.370734, 0.362017, 0.354182, 0.379140, 0.376275, 0.380027, 0.368347, 0.356401, 0.378316, 0.381315, 0.382465, 0.370592, 0.357055, 0.377670, 0.382950, 0.383445, 0.371718, 0.357332, 0.377217, 0.383677, 0.383933, 0.372391, 0.357475, 0.376891, 0.384062, 0.384212, 0.372837, 0.357557, 0.376646, 0.384290, 0.384385, 0.373153, 0.357610, 0.376457, 0.384436, 0.384500, 0.373389, 0.357645, 0.376306, 0.384536, 0.384581, 0.373572, 0.357670, 0.376184, 0.384606, 0.384639, 0.373718, 0.357688, 0.376082, 0.384658, 0.384683, 0.373837, 0.357702, 0.375996, 0.384698, 0.384717, 0.373935, 0.357712, 0.375923, 0.384728, 0.384743, 0.374019, 0.357721, 0.375860, 0.384752, 0.384764, 0.374090, 0.357727, 0.375804, 0.384771, 0.384781, 0.374152, 0.357733, 0.375756, 0.384787, 0.384795, 0.374205, 0.357737, 0.375713, 0.384800, 0.384807, 0.374253, 0.357741, 0.375674, 0.384811, 0.384817, 0.374295, 0.357744, 0.375640, 0.384820, 0.384825, 0.374333, 0.357747, 0.375609, 0.384828, 0.384832, 0.374366, 0.357749, 0.375581, 0.384835, 0.384839, 0.374397, 0.357751, 0.375555, 0.384841, 0.384844, 0.374425, 0.357753, 0.375531, 0.384846, 0.384849, 0.374450, 0.357754, 0.375510, 0.384850, 0.384853, 0.374473, 0.357756, 0.375490, 0.384854, 0.384856, 0.374494, 0.357757, 0.375471, 0.384858, 0.384860, 0.374514, 0.357758, 0.375454, 0.384861, 0.384863, 0.374532, 0.357759, 0.375438, 0.384864, 0.384865, 0.374549, 0.357760, 0.375423, 0.384866, 0.384868, 0.374565, 0.357760, 0.375410, 0.384868, 0.384870, 0.374579, 0.357761, 0.375397, 0.384870, 0.384872, 0.374593, 0.357762, 0.375384, 0.384872, 0.384873, 0.374606, 0.357762, 0.375373, 0.384874, 0.384875, 0.374618, 0.357763, 0.375362, 0.384875, 0.384876, 0.374629, 0.357763, 0.375352, 0.384877, 0.384878, 0.374640, 0.357764, 0.375342, 0.384878, 0.384879, 0.374650, 0.357764, 0.375333, 0.384879, 0.384880, 0.374660, 0.357764, 0.375325, 0.384880, 0.384881, 0.374669, 0.357765, 0.375316, 0.384881, 0.384882, 0.374677, 0.357765, 0.375309, 0.384882, 0.384883, 0.374685, 0.357765, 0.375301, 0.384883, 0.384884, 0.374693, 0.357765, 0.375294, 0.384884, 0.384884, 0.374700, 0.357766, 0.375287, 0.384885, 0.384885, 0.374707, 0.357766, 0.375281, 0.384885, 0.384886, 0.374714, 0.357766, 0.375275, 0.384886} -// ); -/// - sd::ops::lrn_bp op; - auto results = op.evaluate({&x, &eps}, {1.0, 1.0, 0.5}, {5}); - auto out = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); -// ASSERT_TRUE(exp.isSameShape(out)); - //out->printBuffer("LRN BP out"); - //exp.printBuffer("LRN BP exp"); - //ASSERT_TRUE(exp.equalsTo(out)); - - + auto x = NDArrayFactory::create('c', {1, 1, 1, 10}); + x.linspace(1); + auto eps = NDArrayFactory::create('c', {1, 1, 1, 10}); + eps.linspace(1); + // + // auto exp = NDArrayFactory::create('c', {3,3,5,5}, { + // 0.238337, 0.309664, 0.334077, 0.376534, 0.342926, 0.370734, + // 0.362017, 0.354182, 0.379140, 0.376275, 0.380027, 0.368347, + // 0.356401, 0.378316, 0.381315, 0.382465, 0.370592, 0.357055, + // 0.377670, 0.382950, 0.383445, 0.371718, 0.357332, 0.377217, + // 0.383677, 0.383933, 0.372391, 0.357475, 0.376891, 0.384062, + // 0.384212, 0.372837, 0.357557, 0.376646, 0.384290, 0.384385, + // 0.373153, 0.357610, 0.376457, 0.384436, 0.384500, 0.373389, + // 0.357645, 0.376306, 0.384536, 0.384581, 0.373572, 0.357670, + // 0.376184, 0.384606, 0.384639, 0.373718, 0.357688, 0.376082, + // 0.384658, 0.384683, 0.373837, 0.357702, 0.375996, 0.384698, + // 0.384717, 0.373935, 0.357712, 0.375923, 0.384728, 0.384743, + // 0.374019, 0.357721, 0.375860, 0.384752, 0.384764, 0.374090, + // 0.357727, 0.375804, 0.384771, 0.384781, 0.374152, 0.357733, + // 0.375756, 0.384787, 0.384795, 0.374205, 0.357737, 0.375713, + // 0.384800, 0.384807, 0.374253, 0.357741, 0.375674, 0.384811, + // 0.384817, 0.374295, 0.357744, 0.375640, 0.384820, 0.384825, + // 0.374333, 0.357747, 0.375609, 0.384828, 0.384832, 0.374366, + // 0.357749, 0.375581, 0.384835, 0.384839, 0.374397, 0.357751, + // 0.375555, 0.384841, 0.384844, 0.374425, 0.357753, 0.375531, + // 0.384846, 0.384849, 0.374450, 0.357754, 0.375510, 0.384850, + // 0.384853, 0.374473, 0.357756, 0.375490, 0.384854, 0.384856, + // 0.374494, 0.357757, 0.375471, 0.384858, 0.384860, 0.374514, + // 0.357758, 0.375454, 0.384861, 0.384863, 0.374532, 0.357759, + // 0.375438, 0.384864, 0.384865, 0.374549, 0.357760, 0.375423, + // 0.384866, 0.384868, 0.374565, 0.357760, 0.375410, 0.384868, + // 0.384870, 0.374579, 0.357761, 0.375397, 0.384870, 0.384872, + // 0.374593, 0.357762, 0.375384, 0.384872, 0.384873, 0.374606, + // 0.357762, 0.375373, 0.384874, 0.384875, 0.374618, 0.357763, + // 0.375362, 0.384875, 0.384876, 0.374629, 0.357763, 0.375352, + // 0.384877, 0.384878, 0.374640, 0.357764, 0.375342, 0.384878, + // 0.384879, 0.374650, 0.357764, 0.375333, 0.384879, 0.384880, + // 0.374660, 0.357764, 0.375325, 0.384880, 0.384881, 0.374669, + // 0.357765, 0.375316, 0.384881, 0.384882, 0.374677, 0.357765, + // 0.375309, 0.384882, 0.384883, 0.374685, 0.357765, 0.375301, + // 0.384883, 0.384884, 0.374693, 0.357765, 0.375294, 0.384884, + // 0.384884, 0.374700, 0.357766, 0.375287, 0.384885, 0.384885, + // 0.374707, 0.357766, 0.375281, 0.384885, 0.384886, 0.374714, + // 0.357766, 0.375275, 0.384886} + // ); + /// + sd::ops::lrn_bp op; + auto results = op.evaluate({&x, &eps}, {1.0, 1.0, 0.5}, {5}); + auto out = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + // ASSERT_TRUE(exp.isSameShape(out)); + // out->printBuffer("LRN BP out"); + // exp.printBuffer("LRN BP exp"); + // ASSERT_TRUE(exp.equalsTo(out)); } //////////////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedDeclarableOpsTests8, LrnTest_BP_02) { - - auto x = NDArrayFactory::create( 'c', {1, 1, 1, 10}); - x.linspace(1); - auto eps = NDArrayFactory::create('c', {1,1,1,10}); - eps.linspace(1); -// -// auto exp = NDArrayFactory::create('c', {3,3,5,5}, { -// 0.238337, 0.309664, 0.334077, 0.376534, 0.342926, 0.370734, 0.362017, 0.354182, 0.379140, 0.376275, 0.380027, 0.368347, 0.356401, 0.378316, 0.381315, 0.382465, 0.370592, 0.357055, 0.377670, 0.382950, 0.383445, 0.371718, 0.357332, 0.377217, 0.383677, 0.383933, 0.372391, 0.357475, 0.376891, 0.384062, 0.384212, 0.372837, 0.357557, 0.376646, 0.384290, 0.384385, 0.373153, 0.357610, 0.376457, 0.384436, 0.384500, 0.373389, 0.357645, 0.376306, 0.384536, 0.384581, 0.373572, 0.357670, 0.376184, 0.384606, 0.384639, 0.373718, 0.357688, 0.376082, 0.384658, 0.384683, 0.373837, 0.357702, 0.375996, 0.384698, 0.384717, 0.373935, 0.357712, 0.375923, 0.384728, 0.384743, 0.374019, 0.357721, 0.375860, 0.384752, 0.384764, 0.374090, 0.357727, 0.375804, 0.384771, 0.384781, 0.374152, 0.357733, 0.375756, 0.384787, 0.384795, 0.374205, 0.357737, 0.375713, 0.384800, 0.384807, 0.374253, 0.357741, 0.375674, 0.384811, 0.384817, 0.374295, 0.357744, 0.375640, 0.384820, 0.384825, 0.374333, 0.357747, 0.375609, 0.384828, 0.384832, 0.374366, 0.357749, 0.375581, 0.384835, 0.384839, 0.374397, 0.357751, 0.375555, 0.384841, 0.384844, 0.374425, 0.357753, 0.375531, 0.384846, 0.384849, 0.374450, 0.357754, 0.375510, 0.384850, 0.384853, 0.374473, 0.357756, 0.375490, 0.384854, 0.384856, 0.374494, 0.357757, 0.375471, 0.384858, 0.384860, 0.374514, 0.357758, 0.375454, 0.384861, 0.384863, 0.374532, 0.357759, 0.375438, 0.384864, 0.384865, 0.374549, 0.357760, 0.375423, 0.384866, 0.384868, 0.374565, 0.357760, 0.375410, 0.384868, 0.384870, 0.374579, 0.357761, 0.375397, 0.384870, 0.384872, 0.374593, 0.357762, 0.375384, 0.384872, 0.384873, 0.374606, 0.357762, 0.375373, 0.384874, 0.384875, 0.374618, 0.357763, 0.375362, 0.384875, 0.384876, 0.374629, 0.357763, 0.375352, 0.384877, 0.384878, 0.374640, 0.357764, 0.375342, 0.384878, 0.384879, 0.374650, 0.357764, 0.375333, 0.384879, 0.384880, 0.374660, 0.357764, 0.375325, 0.384880, 0.384881, 0.374669, 0.357765, 0.375316, 0.384881, 0.384882, 0.374677, 0.357765, 0.375309, 0.384882, 0.384883, 0.374685, 0.357765, 0.375301, 0.384883, 0.384884, 0.374693, 0.357765, 0.375294, 0.384884, 0.384884, 0.374700, 0.357766, 0.375287, 0.384885, 0.384885, 0.374707, 0.357766, 0.375281, 0.384885, 0.384886, 0.374714, 0.357766, 0.375275, 0.384886} -// ); -/// - sd::ops::lrn opFF; - sd::ops::lrn_bp opBP; - - const OpArgsHolder argsHolderFF({&x}, {1., 1., 0.5}, {5}); - const OpArgsHolder argsHolderBP({&x, &eps}, {1., 1., 0.5}, {5}); - - bool gradOK = true; //GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); - //auto results = op.execute({&x, &eps}, {1.0, 1.0, 0.5}, {5}, {}, false, sd::DataType::DOUBLE); - //auto out = results.at(0); - - //ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(gradOK); - //out->printBuffer("LRN BP out"); - //exp.printBuffer("LRN BP exp"); - //ASSERT_TRUE(exp.equalsTo(out)); - - // + auto x = NDArrayFactory::create('c', {1, 1, 1, 10}); + x.linspace(1); + auto eps = NDArrayFactory::create('c', {1, 1, 1, 10}); + eps.linspace(1); + // + // auto exp = NDArrayFactory::create('c', {3,3,5,5}, { + // 0.238337, 0.309664, 0.334077, 0.376534, 0.342926, 0.370734, + // 0.362017, 0.354182, 0.379140, 0.376275, 0.380027, 0.368347, + // 0.356401, 0.378316, 0.381315, 0.382465, 0.370592, 0.357055, + // 0.377670, 0.382950, 0.383445, 0.371718, 0.357332, 0.377217, + // 0.383677, 0.383933, 0.372391, 0.357475, 0.376891, 0.384062, + // 0.384212, 0.372837, 0.357557, 0.376646, 0.384290, 0.384385, + // 0.373153, 0.357610, 0.376457, 0.384436, 0.384500, 0.373389, + // 0.357645, 0.376306, 0.384536, 0.384581, 0.373572, 0.357670, + // 0.376184, 0.384606, 0.384639, 0.373718, 0.357688, 0.376082, + // 0.384658, 0.384683, 0.373837, 0.357702, 0.375996, 0.384698, + // 0.384717, 0.373935, 0.357712, 0.375923, 0.384728, 0.384743, + // 0.374019, 0.357721, 0.375860, 0.384752, 0.384764, 0.374090, + // 0.357727, 0.375804, 0.384771, 0.384781, 0.374152, 0.357733, + // 0.375756, 0.384787, 0.384795, 0.374205, 0.357737, 0.375713, + // 0.384800, 0.384807, 0.374253, 0.357741, 0.375674, 0.384811, + // 0.384817, 0.374295, 0.357744, 0.375640, 0.384820, 0.384825, + // 0.374333, 0.357747, 0.375609, 0.384828, 0.384832, 0.374366, + // 0.357749, 0.375581, 0.384835, 0.384839, 0.374397, 0.357751, + // 0.375555, 0.384841, 0.384844, 0.374425, 0.357753, 0.375531, + // 0.384846, 0.384849, 0.374450, 0.357754, 0.375510, 0.384850, + // 0.384853, 0.374473, 0.357756, 0.375490, 0.384854, 0.384856, + // 0.374494, 0.357757, 0.375471, 0.384858, 0.384860, 0.374514, + // 0.357758, 0.375454, 0.384861, 0.384863, 0.374532, 0.357759, + // 0.375438, 0.384864, 0.384865, 0.374549, 0.357760, 0.375423, + // 0.384866, 0.384868, 0.374565, 0.357760, 0.375410, 0.384868, + // 0.384870, 0.374579, 0.357761, 0.375397, 0.384870, 0.384872, + // 0.374593, 0.357762, 0.375384, 0.384872, 0.384873, 0.374606, + // 0.357762, 0.375373, 0.384874, 0.384875, 0.374618, 0.357763, + // 0.375362, 0.384875, 0.384876, 0.374629, 0.357763, 0.375352, + // 0.384877, 0.384878, 0.374640, 0.357764, 0.375342, 0.384878, + // 0.384879, 0.374650, 0.357764, 0.375333, 0.384879, 0.384880, + // 0.374660, 0.357764, 0.375325, 0.384880, 0.384881, 0.374669, + // 0.357765, 0.375316, 0.384881, 0.384882, 0.374677, 0.357765, + // 0.375309, 0.384882, 0.384883, 0.374685, 0.357765, 0.375301, + // 0.384883, 0.384884, 0.374693, 0.357765, 0.375294, 0.384884, + // 0.384884, 0.374700, 0.357766, 0.375287, 0.384885, 0.384885, + // 0.374707, 0.357766, 0.375281, 0.384885, 0.384886, 0.374714, + // 0.357766, 0.375275, 0.384886} + // ); + /// + sd::ops::lrn opFF; + sd::ops::lrn_bp opBP; + + const OpArgsHolder argsHolderFF({&x}, {1., 1., 0.5}, {5}); + const OpArgsHolder argsHolderBP({&x, &eps}, {1., 1., 0.5}, {5}); + + bool gradOK = + true; // GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); + // auto results = op.execute({&x, &eps}, {1.0, 1.0, 0.5}, {5}, {}, false, + // sd::DataType::DOUBLE); auto out = results.at(0); + + // ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(gradOK); + // out->printBuffer("LRN BP out"); + // exp.printBuffer("LRN BP exp"); + // ASSERT_TRUE(exp.equalsTo(out)); + + // } //////////////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedDeclarableOpsTests8, LrnTest_BP_1) { - - auto x = NDArrayFactory::create( 'c', {3, 3, 5, 5}); - x.linspace(1); - auto eps = NDArrayFactory::create('c', {3,3,5,5}); - eps.linspace(1); -// -auto exp = NDArrayFactory::create('c', {3,3,5,5}, { - 0.238337f, 0.309664f, 0.334077f, 0.376534f, 0.342926f, 0.370734f, 0.362017f, 0.354182f, 0.379140f, 0.376275f, 0.380027f, 0.368347f, 0.356401f, 0.378316f, 0.381315f, 0.382465f, 0.370592f, 0.357055f, 0.377670f, 0.382950f, 0.383445f, 0.371718f, 0.357332f, 0.377217f, 0.383677f, 0.383933f, 0.372391f, 0.357475f, 0.376891f, 0.384062f, 0.384212f, 0.372837f, 0.357557f, 0.376646f, 0.384290f, 0.384385f, 0.373153f, 0.357610f, 0.376457f, 0.384436f, 0.384500f, 0.373389f, 0.357645f, 0.376306f, 0.384536f, 0.384581f, 0.373572f, 0.357670f, 0.376184f, 0.384606f, 0.384639f, 0.373718f, 0.357688f, 0.376082f, 0.384658f, 0.384683f, 0.373837f, 0.357702f, 0.375996f, 0.384698f, 0.384717f, 0.373935f, 0.357712f, 0.375923f, 0.384728f, 0.384743f, 0.374019f, 0.357721f, 0.375860f, 0.384752f, 0.384764f, 0.374090f, 0.357727f, 0.375804f, 0.384771f, 0.384781f, 0.374152f, 0.357733f, 0.375756f, 0.384787f, 0.384795f, 0.374205f, 0.357737f, 0.375713f, 0.384800f, 0.384807f, 0.374253f, 0.357741f, 0.375674f, 0.384811f, 0.384817f, 0.374295f, 0.357744f, 0.375640f, 0.384820f, 0.384825f, 0.374333f, 0.357747f, 0.375609f, 0.384828f, 0.384832f, 0.374366f, 0.357749f, 0.375581f, 0.384835f, 0.384839f, 0.374397f, 0.357751f, 0.375555f, 0.384841f, 0.384844f, 0.374425f, 0.357753f, 0.375531f, 0.384846f, 0.384849f, 0.374450f, 0.357754f, 0.375510f, 0.384850f, 0.384853f, 0.374473f, 0.357756f, 0.375490f, 0.384854f, 0.384856f, 0.374494f, 0.357757f, 0.375471f, 0.384858f, 0.384860f, 0.374514f, 0.357758f, 0.375454f, 0.384861f, 0.384863f, 0.374532f, 0.357759f, 0.375438f, 0.384864f, 0.384865f, 0.374549f, 0.357760f, 0.375423f, 0.384866f, 0.384868f, 0.374565f, 0.357760f, 0.375410f, 0.384868f, 0.384870f, 0.374579f, 0.357761f, 0.375397f, 0.384870f, 0.384872f, 0.374593f, 0.357762f, 0.375384f, 0.384872f, 0.384873f, 0.374606f, 0.357762f, 0.375373f, 0.384874f, 0.384875f, 0.374618f, 0.357763f, 0.375362f, 0.384875f, 0.384876f, 0.374629f, 0.357763f, 0.375352f, 0.384877f, 0.384878f, 0.374640f, 0.357764f, 0.375342f, 0.384878f, 0.384879f, 0.374650f, 0.357764f, 0.375333f, 0.384879f, 0.384880f, 0.374660f, 0.357764f, 0.375325f, 0.384880f, 0.384881f, 0.374669f, 0.357765f, 0.375316f, 0.384881f, 0.384882f, 0.374677f, 0.357765f, 0.375309f, 0.384882f, 0.384883f, 0.374685f, 0.357765f, 0.375301f, 0.384883f, 0.384884f, 0.374693f, 0.357765f, 0.375294f, 0.384884f, 0.384884f, 0.374700f, 0.357766f, 0.375287f, 0.384885f, 0.384885f, 0.374707f, 0.357766f, 0.375281f, 0.384885f, 0.384886f, 0.374714f, 0.357766f, 0.375275f, 0.384886f} - ); -/// - sd::ops::lrn_bp op; - auto results = op.evaluate({&x, &eps}, {1.0, 1.0, 0.5}, {2}, {}, {}, false); - auto out = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); -// ASSERT_TRUE(exp.isSameShape(out)); - // out->printBuffer("LRN BP out"); - // exp.printBuffer("LRN BP exp"); - //ASSERT_TRUE(exp.equalsTo(out)); - - + auto x = NDArrayFactory::create('c', {3, 3, 5, 5}); + x.linspace(1); + auto eps = NDArrayFactory::create('c', {3, 3, 5, 5}); + eps.linspace(1); + // + auto exp = NDArrayFactory::create( + 'c', {3, 3, 5, 5}, + {0.238337f, 0.309664f, 0.334077f, 0.376534f, 0.342926f, 0.370734f, + 0.362017f, 0.354182f, 0.379140f, 0.376275f, 0.380027f, 0.368347f, + 0.356401f, 0.378316f, 0.381315f, 0.382465f, 0.370592f, 0.357055f, + 0.377670f, 0.382950f, 0.383445f, 0.371718f, 0.357332f, 0.377217f, + 0.383677f, 0.383933f, 0.372391f, 0.357475f, 0.376891f, 0.384062f, + 0.384212f, 0.372837f, 0.357557f, 0.376646f, 0.384290f, 0.384385f, + 0.373153f, 0.357610f, 0.376457f, 0.384436f, 0.384500f, 0.373389f, + 0.357645f, 0.376306f, 0.384536f, 0.384581f, 0.373572f, 0.357670f, + 0.376184f, 0.384606f, 0.384639f, 0.373718f, 0.357688f, 0.376082f, + 0.384658f, 0.384683f, 0.373837f, 0.357702f, 0.375996f, 0.384698f, + 0.384717f, 0.373935f, 0.357712f, 0.375923f, 0.384728f, 0.384743f, + 0.374019f, 0.357721f, 0.375860f, 0.384752f, 0.384764f, 0.374090f, + 0.357727f, 0.375804f, 0.384771f, 0.384781f, 0.374152f, 0.357733f, + 0.375756f, 0.384787f, 0.384795f, 0.374205f, 0.357737f, 0.375713f, + 0.384800f, 0.384807f, 0.374253f, 0.357741f, 0.375674f, 0.384811f, + 0.384817f, 0.374295f, 0.357744f, 0.375640f, 0.384820f, 0.384825f, + 0.374333f, 0.357747f, 0.375609f, 0.384828f, 0.384832f, 0.374366f, + 0.357749f, 0.375581f, 0.384835f, 0.384839f, 0.374397f, 0.357751f, + 0.375555f, 0.384841f, 0.384844f, 0.374425f, 0.357753f, 0.375531f, + 0.384846f, 0.384849f, 0.374450f, 0.357754f, 0.375510f, 0.384850f, + 0.384853f, 0.374473f, 0.357756f, 0.375490f, 0.384854f, 0.384856f, + 0.374494f, 0.357757f, 0.375471f, 0.384858f, 0.384860f, 0.374514f, + 0.357758f, 0.375454f, 0.384861f, 0.384863f, 0.374532f, 0.357759f, + 0.375438f, 0.384864f, 0.384865f, 0.374549f, 0.357760f, 0.375423f, + 0.384866f, 0.384868f, 0.374565f, 0.357760f, 0.375410f, 0.384868f, + 0.384870f, 0.374579f, 0.357761f, 0.375397f, 0.384870f, 0.384872f, + 0.374593f, 0.357762f, 0.375384f, 0.384872f, 0.384873f, 0.374606f, + 0.357762f, 0.375373f, 0.384874f, 0.384875f, 0.374618f, 0.357763f, + 0.375362f, 0.384875f, 0.384876f, 0.374629f, 0.357763f, 0.375352f, + 0.384877f, 0.384878f, 0.374640f, 0.357764f, 0.375342f, 0.384878f, + 0.384879f, 0.374650f, 0.357764f, 0.375333f, 0.384879f, 0.384880f, + 0.374660f, 0.357764f, 0.375325f, 0.384880f, 0.384881f, 0.374669f, + 0.357765f, 0.375316f, 0.384881f, 0.384882f, 0.374677f, 0.357765f, + 0.375309f, 0.384882f, 0.384883f, 0.374685f, 0.357765f, 0.375301f, + 0.384883f, 0.384884f, 0.374693f, 0.357765f, 0.375294f, 0.384884f, + 0.384884f, 0.374700f, 0.357766f, 0.375287f, 0.384885f, 0.384885f, + 0.374707f, 0.357766f, 0.375281f, 0.384885f, 0.384886f, 0.374714f, + 0.357766f, 0.375275f, 0.384886f}); + /// + sd::ops::lrn_bp op; + auto results = op.evaluate({&x, &eps}, {1.0, 1.0, 0.5}, {2}, {}, {}, false); + auto out = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + // ASSERT_TRUE(exp.isSameShape(out)); + // out->printBuffer("LRN BP out"); + // exp.printBuffer("LRN BP exp"); + // ASSERT_TRUE(exp.equalsTo(out)); } //////////////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedDeclarableOpsTests8, LrnTest_BP_2) { - - auto x = NDArrayFactory::create( 'c', {3, 3, 5, 5}); - x.linspace(1); - - auto eps = NDArrayFactory::create('c', {3, 3, 5, 5}, { 0.2581989f, 0.3592106f, 0.40089184f, 0.53935987f, 0.70014f, - 0.4898979f, 0.46056613f, 0.43971977f, 0.5240002f, 0.6375767f, - 0.5274096f, 0.47771242f, 0.4443308f, 0.5163977f, 0.61701745f, - 0.5424508f, 0.48452914f, 0.44570294f, 0.5123918f, 0.6068971f, - 0.5505386f, 0.4881662f, 0.4462865f, 0.5099462f, 0.60088515f, - - 0.5555859f, 0.49042296f, 0.44658744f, 0.5083028f, 0.59690416f, - 0.55903524f, 0.4919585f, 0.44676256f, 0.5071239f, 0.59407425f, - 0.5615412f, 0.49307042f, 0.44687328f, 0.50623745f, 0.5919596f, - 0.56344414f, 0.49391258f, 0.4469477f, 0.5055468f, 0.59031945f, - 0.56493837f, 0.49457246f, 0.4470002f, 0.5049936f, 0.5890103f, - - 0.56614274f, 0.49510333f, 0.44703856f, 0.50454074f, 0.5879411f, - 0.567134f, 0.49553978f, 0.4470674f, 0.504163f, 0.5870515f, - 0.5679643f, 0.4959048f, 0.44708967f, 0.5038433f, 0.5862998f, - 0.56866974f, 0.4962146f, 0.44710726f, 0.5035692f, 0.58565617f, - 0.56927663f, 0.49648085f, 0.4471213f, 0.5033315f, 0.5850988f, - - - 0.56980413f, 0.49671215f, 0.44713274f, 0.50312346f, 0.58461165f, - 0.57026696f, 0.49691492f, 0.4471422f, 0.50293994f, 0.58418214f, - 0.5706764f, 0.49709415f, 0.44715008f, 0.5027767f, 0.5838005f, - 0.571041f, 0.4972537f, 0.44715673f, 0.50263065f, 0.58345926f, - 0.57136786f, 0.49739665f, 0.44716236f, 0.5024992f, 0.58315235f, - - 0.5716625f, 0.49752548f, 0.4471672f, 0.5023803f, 0.5828747f, - 0.5719295f, 0.49764213f, 0.44717142f, 0.5022721f, 0.5826225f, - 0.57217246f, 0.49774826f, 0.44717506f, 0.5021734f, 0.58239233f, - 0.5723947f, 0.4978453f, 0.44717824f, 0.5020829f, 0.58218133f, - 0.57259864f, 0.49793428f, 0.44718108f, 0.5019997f, 0.5819874f, - - 0.5727864f, 0.49801624f, 0.44718358f, 0.5019227f, 0.5818083f, - 0.57296f, 0.49809194f, 0.44718578f, 0.5018515f, 0.5816426f, - 0.5731208f, 0.49816203f, 0.44718775f, 0.5017854f, 0.58148885f, - 0.57327026f, 0.49822718f, 0.4471895f, 0.5017239f, 0.5813457f, - 0.57340944f, 0.49828786f, 0.44719115f, 0.5016664f, 0.581212f, - - - 0.57353944f, 0.4983446f, 0.44719255f, 0.50161266f, 0.58108705f, - 0.5736612f, 0.49839762f, 0.4471939f, 0.50156236f, 0.5809699f, - 0.5737754f, 0.4984474f, 0.44719502f, 0.501515f, 0.58085984f, - 0.5738828f, 0.49849418f, 0.4471962f, 0.50147045f, 0.5807564f, - 0.5739839f, 0.49853817f, 0.44719717f, 0.5014284f, 0.5806588f, - - 0.5740793f, 0.49857965f, 0.4471981f, 0.5013887f, 0.5805666f, - 0.5741694f, 0.49861887f, 0.44719887f, 0.50135124f, 0.58047944f, - 0.57425463f, 0.49865603f, 0.44719967f, 0.5013157f, 0.5803969f, - 0.5743354f, 0.4986912f, 0.44720036f, 0.5012819f, 0.5803186f, - 0.57441217f, 0.49872455f, 0.44720104f, 0.5012499f, 0.58024424f, - - 0.57448506f, 0.4987563f, 0.44720164f, 0.5012194f, 0.58017343f, - 0.57455444f, 0.4987865f, 0.4472022f, 0.5011904f, 0.5801061f, - 0.57462054f, 0.49881527f, 0.44720277f, 0.5011627f, 0.5800419f, - 0.57468355f, 0.49884263f, 0.44720328f, 0.50113624f, 0.5799805f, - 0.57474375f, 0.49886885f, 0.44720373f, 0.50111103f, 0.5799219f }); -// - auto exp = NDArrayFactory::create('c', {3,3,5,5}, { - 0.061538f, 0.055617f, 0.044643f, 0.050772f, 0.048019f, 0.030270f, 0.023819f, 0.019468f, 0.022074f, 0.023990f, 0.018221f, 0.014664f, 0.012182f, 0.013954f, 0.015685f, 0.012967f, 0.010563f, 0.008841f, 0.010185f, 0.011621f, 0.010052f, 0.008248f, 0.006934f, 0.008015f, 0.009222f, 0.008204f, 0.006764f, 0.005702f, 0.006606f, 0.007642f, 0.006929f, 0.005732f, 0.004841f, 0.005618f, 0.006523f, 0.005996f, 0.004973f, 0.004205f, 0.004887f, 0.005689f, 0.005284f, 0.004391f, 0.003717f, 0.004324f, 0.005044f, 0.004723f, 0.003931f, 0.003331f, 0.003877f, 0.004531f, 0.004270f, 0.003558f, 0.003017f, 0.003514f, 0.004112f, 0.003896f, 0.003250f, 0.002757f, 0.003213f, 0.003764f, 0.003582f, 0.002991f, 0.002539f, 0.002959f, 0.003470f, 0.003315f, 0.002770f, 0.002352f, 0.002743f, 0.003219f, 0.003085f, 0.002580f, 0.002191f, 0.002556f, 0.003002f, 0.002885f, 0.002414f, 0.002051f, 0.002393f, 0.002812f, 0.002709f, 0.002268f, 0.001927f, 0.002250f, 0.002645f, 0.002553f, 0.002138f, 0.001818f, 0.002122f, 0.002496f, 0.002415f, 0.002023f, 0.001720f, 0.002009f, 0.002363f, 0.002290f, 0.001920f, 0.001632f, 0.001906f, 0.002244f, 0.002178f, 0.001826f, 0.001553f, 0.001814f, 0.002136f, 0.002076f, 0.001741f, 0.001481f, 0.001731f, 0.002038f, 0.001984f, 0.001664f, 0.001416f, 0.001654f, 0.001949f, 0.001899f, 0.001593f, 0.001356f, 0.001584f, 0.001867f, 0.001821f, 0.001528f, 0.001301f, 0.001520f, 0.001792f, 0.001750f, 0.001469f, 0.001250f, 0.001461f, 0.001722f, 0.001683f, 0.001413f, 0.001203f, 0.001406f, 0.001658f, 0.001622f, 0.001362f, 0.001159f, 0.001355f, 0.001599f, 0.001565f, 0.001314f, 0.001119f, 0.001308f, 0.001543f, 0.001512f, 0.001270f, 0.001081f, 0.001264f, 0.001491f, 0.001462f, 0.001228f, 0.001046f, 0.001223f, 0.001443f, 0.001415f, 0.001189f, 0.001013f, 0.001184f, 0.001397f, 0.001372f, 0.001153f, 0.000982f, 0.001148f, 0.001355f, 0.001331f, 0.001118f, 0.000952f, 0.001114f, 0.001315f, 0.001292f, 0.001086f, 0.000925f, 0.001082f, 0.001277f, 0.001255f, 0.001055f, 0.000899f, 0.001051f, 0.001241f, 0.001221f, 0.001026f, 0.000874f, 0.001023f, 0.001208f, 0.001188f, 0.000999f, 0.000851f, 0.000996f, 0.001176f, 0.001157f, 0.000973f, 0.000829f, 0.000970f, 0.001145f, 0.001128f, 0.000949f, 0.000808f, 0.000945f, 0.001117f, 0.001100f, 0.000925f, 0.000788f, 0.000922f, 0.001089f, 0.001073f, 0.000903f, 0.000769f, 0.000900f, 0.001063f, 0.001048f, 0.000882f, 0.000751f, 0.000879f, 0.001038f, 0.001024f, 0.000861f, 0.000734f, 0.000859f, 0.001015f, 0.001001f, 0.000842f, 0.000717f, 0.000840f, 0.000992f} - // 0.009859f, 0.013075f, 0.013874f, 0.017893f, 0.022344f, 0.014551f, 0.012859f, 0.011511f, 0.013311f, 0.015834f, 0.012025f, 0.010047f, 0.008601f, 0.009920f, 0.011885f, 0.009505f, 0.007636f, 0.006299f, 0.007413f, 0.009095f, 0.007446f, 0.005743f, 0.004540f, 0.005533f, 0.007033f, 0.005821f, 0.004282f, 0.003209f, 0.004123f, 0.005491f, 0.004577f, 0.003198f, 0.002247f, 0.003097f, 0.004355f, 0.003652f, 0.002412f, 0.001565f, 0.002357f, 0.003517f, 0.002965f, 0.001844f, 0.001084f, 0.001821f, 0.002893f, 0.002451f, 0.001430f, 0.000741f, 0.001428f, 0.002422f, -0.111434f, -0.105946f, -0.100351f, -0.091868f, -0.083323f, -0.078775f, -0.076222f, -0.073291f, -0.067635f, -0.061692f, -0.058943f, -0.057832f, -0.056263f, -0.052198f, -0.047768f, -0.046002f, -0.045655f, -0.044839f, -0.041748f, -0.038271f, -0.037084f, -0.037161f, -0.036786f, -0.034331f, -0.031495f, 0.000077f, -0.000673f, -0.001181f, -0.000667f, 0.000079f, -0.000089f, -0.000802f, -0.001285f, -0.000793f, -0.000079f, -0.000228f, -0.000908f, -0.001368f, -0.000896f, -0.000212f, -0.000345f, -0.000996f, -0.001434f, -0.000981f, -0.000325f, -0.000444f, -0.001067f, -0.001487f, -0.001051f, -0.000421f, 0.000697f, 0.000188f, -0.000152f, 0.000210f, 0.000731f, 0.000650f, 0.000165f, -0.000161f, 0.000185f, 0.000683f, 0.000610f, 0.000145f, -0.000168f, 0.000164f, 0.000641f, 0.000574f, 0.000128f, -0.000172f, 0.000146f, 0.000604f, 0.000542f, 0.000113f, -0.000175f, 0.000131f, 0.000571f, -0.009490f, -0.010070f, -0.010409f, -0.009734f, -0.008834f, -0.008785f, -0.009351f, -0.009687f, -0.009054f, -0.008207f, -0.008167f, -0.008718f, -0.009050f, -0.008455f, -0.007654f, -0.007622f, -0.008159f, -0.008485f, -0.007924f, -0.007164f, -0.007138f, -0.007661f, -0.007981f, -0.007450f, -0.006728f, -0.000901f, -0.001327f, -0.001614f, -0.001310f, -0.000869f, -0.000913f, -0.001328f, -0.001607f, -0.001310f, -0.000882f, -0.000922f, -0.001326f, -0.001598f, -0.001309f, -0.000892f, -0.000930f, -0.001323f, -0.001588f, -0.001306f, -0.000900f, -0.000936f, -0.001319f, -0.001577f, -0.001302f, -0.000906f, 0.000339f, 0.000038f, -0.000164f, 0.000048f, 0.000355f, 0.000328f, 0.000035f, -0.000162f, 0.000045f, 0.000343f, 0.000318f, 0.000033f, -0.000159f, 0.000041f, 0.000332f, 0.000308f, 0.000030f, -0.000157f, 0.000039f, 0.000322f, 0.000299f, 0.000028f, -0.000155f, 0.000036f, 0.000312f, -0.004085f, -0.004479f, -0.004733f, -0.004396f, -0.003925f, -0.003925f, -0.004309f, -0.004558f, -0.004232f, -0.003775f, -0.003776f, -0.004151f, -0.004395f, -0.004079f, -0.003636f, -0.003637f, -0.004004f, -0.004242f, -0.003936f, -0.003505f, -0.003507f, -0.003866f, -0.004100f, -0.003802f, -0.003383f} - ); - - sd::ops::lrn_bp op; - auto results = op.evaluate({&x, &eps}, {1.0, 1.0, 0.5}, {2}, {}, {}, false); - auto out = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(out)); - //out->printBuffer("LRN BP out"); -// exp.printIndexedBuffer("LRN exp"); - // ASSERT_TRUE(exp.equalsTo(out)); - - + auto x = NDArrayFactory::create('c', {3, 3, 5, 5}); + x.linspace(1); + + auto eps = NDArrayFactory::create( + 'c', {3, 3, 5, 5}, + {0.2581989f, 0.3592106f, 0.40089184f, 0.53935987f, 0.70014f, + 0.4898979f, 0.46056613f, 0.43971977f, 0.5240002f, 0.6375767f, + 0.5274096f, 0.47771242f, 0.4443308f, 0.5163977f, 0.61701745f, + 0.5424508f, 0.48452914f, 0.44570294f, 0.5123918f, 0.6068971f, + 0.5505386f, 0.4881662f, 0.4462865f, 0.5099462f, 0.60088515f, + + 0.5555859f, 0.49042296f, 0.44658744f, 0.5083028f, 0.59690416f, + 0.55903524f, 0.4919585f, 0.44676256f, 0.5071239f, 0.59407425f, + 0.5615412f, 0.49307042f, 0.44687328f, 0.50623745f, 0.5919596f, + 0.56344414f, 0.49391258f, 0.4469477f, 0.5055468f, 0.59031945f, + 0.56493837f, 0.49457246f, 0.4470002f, 0.5049936f, 0.5890103f, + + 0.56614274f, 0.49510333f, 0.44703856f, 0.50454074f, 0.5879411f, + 0.567134f, 0.49553978f, 0.4470674f, 0.504163f, 0.5870515f, + 0.5679643f, 0.4959048f, 0.44708967f, 0.5038433f, 0.5862998f, + 0.56866974f, 0.4962146f, 0.44710726f, 0.5035692f, 0.58565617f, + 0.56927663f, 0.49648085f, 0.4471213f, 0.5033315f, 0.5850988f, + + 0.56980413f, 0.49671215f, 0.44713274f, 0.50312346f, 0.58461165f, + 0.57026696f, 0.49691492f, 0.4471422f, 0.50293994f, 0.58418214f, + 0.5706764f, 0.49709415f, 0.44715008f, 0.5027767f, 0.5838005f, + 0.571041f, 0.4972537f, 0.44715673f, 0.50263065f, 0.58345926f, + 0.57136786f, 0.49739665f, 0.44716236f, 0.5024992f, 0.58315235f, + + 0.5716625f, 0.49752548f, 0.4471672f, 0.5023803f, 0.5828747f, + 0.5719295f, 0.49764213f, 0.44717142f, 0.5022721f, 0.5826225f, + 0.57217246f, 0.49774826f, 0.44717506f, 0.5021734f, 0.58239233f, + 0.5723947f, 0.4978453f, 0.44717824f, 0.5020829f, 0.58218133f, + 0.57259864f, 0.49793428f, 0.44718108f, 0.5019997f, 0.5819874f, + + 0.5727864f, 0.49801624f, 0.44718358f, 0.5019227f, 0.5818083f, + 0.57296f, 0.49809194f, 0.44718578f, 0.5018515f, 0.5816426f, + 0.5731208f, 0.49816203f, 0.44718775f, 0.5017854f, 0.58148885f, + 0.57327026f, 0.49822718f, 0.4471895f, 0.5017239f, 0.5813457f, + 0.57340944f, 0.49828786f, 0.44719115f, 0.5016664f, 0.581212f, + + 0.57353944f, 0.4983446f, 0.44719255f, 0.50161266f, 0.58108705f, + 0.5736612f, 0.49839762f, 0.4471939f, 0.50156236f, 0.5809699f, + 0.5737754f, 0.4984474f, 0.44719502f, 0.501515f, 0.58085984f, + 0.5738828f, 0.49849418f, 0.4471962f, 0.50147045f, 0.5807564f, + 0.5739839f, 0.49853817f, 0.44719717f, 0.5014284f, 0.5806588f, + + 0.5740793f, 0.49857965f, 0.4471981f, 0.5013887f, 0.5805666f, + 0.5741694f, 0.49861887f, 0.44719887f, 0.50135124f, 0.58047944f, + 0.57425463f, 0.49865603f, 0.44719967f, 0.5013157f, 0.5803969f, + 0.5743354f, 0.4986912f, 0.44720036f, 0.5012819f, 0.5803186f, + 0.57441217f, 0.49872455f, 0.44720104f, 0.5012499f, 0.58024424f, + + 0.57448506f, 0.4987563f, 0.44720164f, 0.5012194f, 0.58017343f, + 0.57455444f, 0.4987865f, 0.4472022f, 0.5011904f, 0.5801061f, + 0.57462054f, 0.49881527f, 0.44720277f, 0.5011627f, 0.5800419f, + 0.57468355f, 0.49884263f, 0.44720328f, 0.50113624f, 0.5799805f, + 0.57474375f, 0.49886885f, 0.44720373f, 0.50111103f, 0.5799219f}); + // + auto exp = NDArrayFactory::create( + 'c', {3, 3, 5, 5}, + {0.061538f, 0.055617f, 0.044643f, 0.050772f, 0.048019f, 0.030270f, + 0.023819f, 0.019468f, 0.022074f, 0.023990f, 0.018221f, 0.014664f, + 0.012182f, 0.013954f, 0.015685f, 0.012967f, 0.010563f, 0.008841f, + 0.010185f, 0.011621f, 0.010052f, 0.008248f, 0.006934f, 0.008015f, + 0.009222f, 0.008204f, 0.006764f, 0.005702f, 0.006606f, 0.007642f, + 0.006929f, 0.005732f, 0.004841f, 0.005618f, 0.006523f, 0.005996f, + 0.004973f, 0.004205f, 0.004887f, 0.005689f, 0.005284f, 0.004391f, + 0.003717f, 0.004324f, 0.005044f, 0.004723f, 0.003931f, 0.003331f, + 0.003877f, 0.004531f, 0.004270f, 0.003558f, 0.003017f, 0.003514f, + 0.004112f, 0.003896f, 0.003250f, 0.002757f, 0.003213f, 0.003764f, + 0.003582f, 0.002991f, 0.002539f, 0.002959f, 0.003470f, 0.003315f, + 0.002770f, 0.002352f, 0.002743f, 0.003219f, 0.003085f, 0.002580f, + 0.002191f, 0.002556f, 0.003002f, 0.002885f, 0.002414f, 0.002051f, + 0.002393f, 0.002812f, 0.002709f, 0.002268f, 0.001927f, 0.002250f, + 0.002645f, 0.002553f, 0.002138f, 0.001818f, 0.002122f, 0.002496f, + 0.002415f, 0.002023f, 0.001720f, 0.002009f, 0.002363f, 0.002290f, + 0.001920f, 0.001632f, 0.001906f, 0.002244f, 0.002178f, 0.001826f, + 0.001553f, 0.001814f, 0.002136f, 0.002076f, 0.001741f, 0.001481f, + 0.001731f, 0.002038f, 0.001984f, 0.001664f, 0.001416f, 0.001654f, + 0.001949f, 0.001899f, 0.001593f, 0.001356f, 0.001584f, 0.001867f, + 0.001821f, 0.001528f, 0.001301f, 0.001520f, 0.001792f, 0.001750f, + 0.001469f, 0.001250f, 0.001461f, 0.001722f, 0.001683f, 0.001413f, + 0.001203f, 0.001406f, 0.001658f, 0.001622f, 0.001362f, 0.001159f, + 0.001355f, 0.001599f, 0.001565f, 0.001314f, 0.001119f, 0.001308f, + 0.001543f, 0.001512f, 0.001270f, 0.001081f, 0.001264f, 0.001491f, + 0.001462f, 0.001228f, 0.001046f, 0.001223f, 0.001443f, 0.001415f, + 0.001189f, 0.001013f, 0.001184f, 0.001397f, 0.001372f, 0.001153f, + 0.000982f, 0.001148f, 0.001355f, 0.001331f, 0.001118f, 0.000952f, + 0.001114f, 0.001315f, 0.001292f, 0.001086f, 0.000925f, 0.001082f, + 0.001277f, 0.001255f, 0.001055f, 0.000899f, 0.001051f, 0.001241f, + 0.001221f, 0.001026f, 0.000874f, 0.001023f, 0.001208f, 0.001188f, + 0.000999f, 0.000851f, 0.000996f, 0.001176f, 0.001157f, 0.000973f, + 0.000829f, 0.000970f, 0.001145f, 0.001128f, 0.000949f, 0.000808f, + 0.000945f, 0.001117f, 0.001100f, 0.000925f, 0.000788f, 0.000922f, + 0.001089f, 0.001073f, 0.000903f, 0.000769f, 0.000900f, 0.001063f, + 0.001048f, 0.000882f, 0.000751f, 0.000879f, 0.001038f, 0.001024f, + 0.000861f, 0.000734f, 0.000859f, 0.001015f, 0.001001f, 0.000842f, + 0.000717f, 0.000840f, 0.000992f} + // 0.009859f, 0.013075f, 0.013874f, 0.017893f, 0.022344f, 0.014551f, + // 0.012859f, 0.011511f, 0.013311f, 0.015834f, 0.012025f, 0.010047f, + // 0.008601f, 0.009920f, 0.011885f, 0.009505f, 0.007636f, 0.006299f, + // 0.007413f, 0.009095f, 0.007446f, 0.005743f, 0.004540f, 0.005533f, + // 0.007033f, 0.005821f, 0.004282f, 0.003209f, 0.004123f, 0.005491f, + // 0.004577f, 0.003198f, 0.002247f, 0.003097f, 0.004355f, 0.003652f, + // 0.002412f, 0.001565f, 0.002357f, 0.003517f, 0.002965f, 0.001844f, + // 0.001084f, 0.001821f, 0.002893f, 0.002451f, 0.001430f, 0.000741f, + // 0.001428f, 0.002422f, -0.111434f, -0.105946f, -0.100351f, + // -0.091868f, -0.083323f, -0.078775f, -0.076222f, -0.073291f, + // -0.067635f, -0.061692f, -0.058943f, -0.057832f, -0.056263f, + // -0.052198f, -0.047768f, -0.046002f, -0.045655f, -0.044839f, + // -0.041748f, -0.038271f, -0.037084f, -0.037161f, -0.036786f, + // -0.034331f, -0.031495f, 0.000077f, -0.000673f, -0.001181f, + // -0.000667f, 0.000079f, -0.000089f, -0.000802f, -0.001285f, + // -0.000793f, -0.000079f, -0.000228f, -0.000908f, -0.001368f, + // -0.000896f, -0.000212f, -0.000345f, -0.000996f, -0.001434f, + // -0.000981f, -0.000325f, -0.000444f, -0.001067f, -0.001487f, + // -0.001051f, -0.000421f, 0.000697f, 0.000188f, -0.000152f, 0.000210f, + // 0.000731f, 0.000650f, 0.000165f, -0.000161f, 0.000185f, 0.000683f, + // 0.000610f, 0.000145f, -0.000168f, 0.000164f, 0.000641f, 0.000574f, + // 0.000128f, -0.000172f, 0.000146f, 0.000604f, 0.000542f, 0.000113f, + // -0.000175f, 0.000131f, 0.000571f, -0.009490f, -0.010070f, + // -0.010409f, -0.009734f, -0.008834f, -0.008785f, -0.009351f, + // -0.009687f, -0.009054f, -0.008207f, -0.008167f, -0.008718f, + // -0.009050f, -0.008455f, -0.007654f, -0.007622f, -0.008159f, + // -0.008485f, -0.007924f, -0.007164f, -0.007138f, -0.007661f, + // -0.007981f, -0.007450f, -0.006728f, -0.000901f, -0.001327f, + // -0.001614f, -0.001310f, -0.000869f, -0.000913f, -0.001328f, + // -0.001607f, -0.001310f, -0.000882f, -0.000922f, -0.001326f, + // -0.001598f, -0.001309f, -0.000892f, -0.000930f, -0.001323f, + // -0.001588f, -0.001306f, -0.000900f, -0.000936f, -0.001319f, + // -0.001577f, -0.001302f, -0.000906f, 0.000339f, 0.000038f, + // -0.000164f, 0.000048f, 0.000355f, 0.000328f, 0.000035f, -0.000162f, + // 0.000045f, 0.000343f, 0.000318f, 0.000033f, -0.000159f, 0.000041f, + // 0.000332f, 0.000308f, 0.000030f, -0.000157f, 0.000039f, 0.000322f, + // 0.000299f, 0.000028f, -0.000155f, 0.000036f, 0.000312f, -0.004085f, + // -0.004479f, -0.004733f, -0.004396f, -0.003925f, -0.003925f, + // -0.004309f, -0.004558f, -0.004232f, -0.003775f, -0.003776f, + // -0.004151f, -0.004395f, -0.004079f, -0.003636f, -0.003637f, + // -0.004004f, -0.004242f, -0.003936f, -0.003505f, -0.003507f, + // -0.003866f, -0.004100f, -0.003802f, -0.003383f} + ); + + sd::ops::lrn_bp op; + auto results = op.evaluate({&x, &eps}, {1.0, 1.0, 0.5}, {2}, {}, {}, false); + auto out = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(out)); + // out->printBuffer("LRN BP out"); + // exp.printIndexedBuffer("LRN exp"); + // ASSERT_TRUE(exp.equalsTo(out)); } - - diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests9.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests9.cpp index f2bd393e41f1..5bd273187c7e 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests9.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests9.cpp @@ -18,83 +18,78 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 22.06.2018 // - -#include "testlayers.h" -#include #include -#include #include #include +#include +#include +#include "testlayers.h" using namespace sd; - class DeclarableOpsTests9 : public testing::Test { -public: - - DeclarableOpsTests9() { - printf("\n"); - fflush(stdout); - } + public: + DeclarableOpsTests9() { + printf("\n"); + fflush(stdout); + } }; //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, reduceStDevBP_test3) { + auto x = NDArrayFactory::create('c', {3, 4}); + auto gradO1 = NDArrayFactory::create('c', {3, 1}, {1., 2., 3.}); + auto gradO2 = NDArrayFactory::create('c', {3}, {1., 2., 3.}); + auto exp = NDArrayFactory::create( + 'c', {3, 4}, + {-0.335410, -0.111803, 0.111803, 0.335410, -0.670820, -0.223607, 0.223607, + 0.670820, -1.006231, -0.335410, 0.335410, 1.006231}); - auto x = NDArrayFactory::create('c', {3,4}); - auto gradO1 = NDArrayFactory::create('c', {3,1}, {1.,2.,3.}); - auto gradO2 = NDArrayFactory::create('c', {3}, {1.,2.,3.}); - auto exp = NDArrayFactory::create('c', {3,4}, {-0.335410, -0.111803, 0.111803, 0.335410, -0.670820, -0.223607, 0.223607, 0.670820, -1.006231, -0.335410, 0.335410, 1.006231}); - - x.linspace(1); - - sd::ops::reduce_stdev_bp op; - - auto result = op.evaluate({&x, &gradO2}, {0,0}, {1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); - // output->printIndexedBuffer(); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + x.linspace(1); + sd::ops::reduce_stdev_bp op; - result = op.evaluate({&x, &gradO1}, {1,0}, {1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - output = result.at(0); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - + auto result = op.evaluate({&x, &gradO2}, {0, 0}, {1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + // output->printIndexedBuffer(); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + result = op.evaluate({&x, &gradO1}, {1, 0}, {1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, reduceStDevBP_test03) { - - auto x = NDArrayFactory::create('c', {3,4}); - auto gradO1 = NDArrayFactory::create('c', {3,1}, {1.,2.,3.}); - auto gradO2 = NDArrayFactory::create('c', {3}, {1.,2.,3.}); - auto exp = NDArrayFactory::create('c', {3,4}, {-0.335410, -0.111803, 0.111803, 0.335410, -0.670820, -0.223607, 0.223607, 0.670820, -1.006231, -0.335410, 0.335410, 1.006231}); - auto axis = NDArrayFactory::create('c', {1}, {1}); - x.linspace(1); - - sd::ops::reduce_stdev_bp op; - - auto result = op.evaluate({&x, &gradO2, &axis}, {}, {}, {false, false}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); - // output->printIndexedBuffer(); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - - result = op.evaluate({&x, &gradO1}, {1,0}, {1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - output = result.at(0); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + auto x = NDArrayFactory::create('c', {3, 4}); + auto gradO1 = NDArrayFactory::create('c', {3, 1}, {1., 2., 3.}); + auto gradO2 = NDArrayFactory::create('c', {3}, {1., 2., 3.}); + auto exp = NDArrayFactory::create( + 'c', {3, 4}, + {-0.335410, -0.111803, 0.111803, 0.335410, -0.670820, -0.223607, 0.223607, + 0.670820, -1.006231, -0.335410, 0.335410, 1.006231}); + auto axis = NDArrayFactory::create('c', {1}, {1}); + x.linspace(1); + + sd::ops::reduce_stdev_bp op; + + auto result = op.evaluate({&x, &gradO2, &axis}, {}, {}, {false, false}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + // output->printIndexedBuffer(); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + result = op.evaluate({&x, &gradO1}, {1, 0}, {1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } /* @@ -110,13 +105,16 @@ TEST_F(DeclarableOpsTests9, exponentialDistributionInv_test1) { double extraParams[] = {lambda}; Nd4jLong *buffer = new Nd4jLong[N]; - auto rng = (sd::random::RandomBuffer *) initRandom(nullptr, 123, N, (Nd4jPointer) buffer); - if (rng == nullptr) - throw std::runtime_error("DeclarableOpsTests9.exponentialDistributionInv_test1: RNG initialization failed !"); + auto rng = (sd::random::RandomBuffer *) initRandom(nullptr, 123, N, +(Nd4jPointer) buffer); if (rng == nullptr) throw +std::runtime_error("DeclarableOpsTests9.exponentialDistributionInv_test1: RNG +initialization failed !"); - functions::random::RandomFunction::template execTransform>(rng, x.getBuffer(), x.shapeInfo(), extraParams); - const double actualMean = x.meanNumber().e(0); - const double actualStd = x.varianceNumber(variance::SummaryStatsStandardDeviation, true).e(0); + functions::random::RandomFunction::template +execTransform>(rng, x.getBuffer(), +x.shapeInfo(), extraParams); const double actualMean = +x.meanNumber().e(0); const double actualStd = +x.varianceNumber(variance::SummaryStatsStandardDeviation, true).e(0); ASSERT_NEAR(mean, actualMean, 0.01); ASSERT_NEAR(std, actualStd, 0.01); @@ -141,14 +139,18 @@ TEST_F(DeclarableOpsTests9, exponentialDistributionInv_test2) { Nd4jLong *buffer = new Nd4jLong[N]; - auto rng = (sd::random::RandomBuffer *) initRandom(nullptr, 123, N, (Nd4jPointer) buffer); - if (rng == nullptr) - throw std::runtime_error("DeclarableOpsTests9.exponentialDistributionInv_test2: RNG initialization failed !"); + auto rng = (sd::random::RandomBuffer *) initRandom(nullptr, 123, N, +(Nd4jPointer) buffer); if (rng == nullptr) throw +std::runtime_error("DeclarableOpsTests9.exponentialDistributionInv_test2: RNG +initialization failed !"); - functions::random::RandomFunction::template execTransform>(rng, y.getBuffer(), y.shapeInfo(), x.getBuffer(), x.shapeInfo(), extraParams); + functions::random::RandomFunction::template +execTransform>(rng, y.getBuffer(), +y.shapeInfo(), x.getBuffer(), x.shapeInfo(), extraParams); const double actualMean = x.meanNumber().e(0); - const double actualStd = x.varianceNumber(variance::SummaryStatsStandardDeviation, true).e(0); + const double actualStd = +x.varianceNumber(variance::SummaryStatsStandardDeviation, true).e(0); ASSERT_NEAR(mean, actualMean, 0.01); ASSERT_NEAR(std, actualStd, 0.01); @@ -170,13 +172,16 @@ TEST_F(DeclarableOpsTests9, exponentialDistribution_test1) { double extraParams[] = {lambda}; Nd4jLong *buffer = new Nd4jLong[N]; - auto rng = (sd::random::RandomBuffer *) initRandom(nullptr, 123, N, (Nd4jPointer) buffer); - if (rng == nullptr) - throw std::runtime_error("DeclarableOpsTests9.exponentialDistribution_test1: RNG initialization failed !"); + auto rng = (sd::random::RandomBuffer *) initRandom(nullptr, 123, N, +(Nd4jPointer) buffer); if (rng == nullptr) throw +std::runtime_error("DeclarableOpsTests9.exponentialDistribution_test1: RNG +initialization failed !"); - functions::random::RandomFunction::template execTransform>(rng, x.getBuffer(), x.shapeInfo(), extraParams); - const double actualMean = x.meanNumber().e(0); - const double actualStd = x.varianceNumber(variance::SummaryStatsStandardDeviation, true).e(0); + functions::random::RandomFunction::template +execTransform>(rng, x.getBuffer(), +x.shapeInfo(), extraParams); const double actualMean = +x.meanNumber().e(0); const double actualStd = +x.varianceNumber(variance::SummaryStatsStandardDeviation, true).e(0); ASSERT_NEAR(mean, actualMean, 0.01); ASSERT_NEAR(std, actualStd, 0.01); @@ -203,16 +208,20 @@ TEST_F(DeclarableOpsTests9, exponentialDistribution_test2) { Nd4jLong *buffer = new Nd4jLong[N]; // Nd4jPointer extra[2]; #ifndef __CUDABLAS__ - sd::random::RandomBuffer* rng = (sd::random::RandomBuffer *) initRandom(nullptr, 123, N, (Nd4jPointer) buffer); - if (rng == nullptr) - throw std::runtime_error("DeclarableOpsTests9.exponentialDistribution_test2: RNG initialization failed !"); + sd::random::RandomBuffer* rng = (sd::random::RandomBuffer *) +initRandom(nullptr, 123, N, (Nd4jPointer) buffer); if (rng == nullptr) throw +std::runtime_error("DeclarableOpsTests9.exponentialDistribution_test2: RNG +initialization failed !"); - functions::random::RandomFunction::template execTransform>(rng, y.getBuffer(), y.shapeInfo(), x.getBuffer(), x.shapeInfo(), extraParams); + functions::random::RandomFunction::template +execTransform>(rng, y.getBuffer(), +y.shapeInfo(), x.getBuffer(), x.shapeInfo(), extraParams); destroyRandom((Nd4jPointer) rng); #endif const double actualMean = x.meanNumber().e(0); - const double actualStd = x.varianceNumber(variance::SummaryStatsStandardDeviation, true).e(0); + const double actualStd = +x.varianceNumber(variance::SummaryStatsStandardDeviation, true).e(0); ASSERT_NEAR(mean, actualMean, 0.01); ASSERT_NEAR(std, actualStd, 0.01); @@ -224,1998 +233,2072 @@ TEST_F(DeclarableOpsTests9, exponentialDistribution_test2) { */ TEST_F(DeclarableOpsTests9, ScalarOpTest_MixedOrders_1) { - auto x = NDArrayFactory::create('f', {2, 2}, {1.0, 3.0, 2.0, 4.0}); - auto e = NDArrayFactory::create('c', {2, 2}, {2.0, 3.0, 4.0, 5.0}); - auto z = NDArrayFactory::create('c', {2, 2}, {0.0, 0.0, 0.0, 0.0}); + auto x = NDArrayFactory::create('f', {2, 2}, {1.0, 3.0, 2.0, 4.0}); + auto e = NDArrayFactory::create('c', {2, 2}, {2.0, 3.0, 4.0, 5.0}); + auto z = NDArrayFactory::create('c', {2, 2}, {0.0, 0.0, 0.0, 0.0}); - x.applyScalar(scalar::Add, 1.0, z); + x.applyScalar(scalar::Add, 1.0, z); - ASSERT_EQ(e, z); + ASSERT_EQ(e, z); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, concat_test1) { - - auto x0 = NDArrayFactory::create('c', {2,3,4}); - auto x1 = NDArrayFactory::create('c', {2,2,4}); - auto x2 = NDArrayFactory::create('c', {2,1,4}); - auto exp = NDArrayFactory::create('c', {2,6,4}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 1.f, 2.f, 3.f, 4.f, - 13.f, 14.f, 15.f, 16.f,17.f, 18.f, 19.f, 20.f,21.f, 22.f, 23.f, 24.f, 9.f, 10.f, 11.f, 12.f,13.f, 14.f, 15.f, 16.f, 5.f, 6.f, 7.f, 8.}); - - x0.linspace(1); - x1.linspace(1); - x2.linspace(1); - - sd::ops::concat op; - - auto result = op.evaluate({&x0, &x1, &x2}, {}, {1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); + auto x0 = NDArrayFactory::create('c', {2, 3, 4}); + auto x1 = NDArrayFactory::create('c', {2, 2, 4}); + auto x2 = NDArrayFactory::create('c', {2, 1, 4}); + auto exp = NDArrayFactory::create( + 'c', {2, 6, 4}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, + 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 1.f, 2.f, 3.f, 4.f, + 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, + 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 16.f, 5.f, 6.f, 7.f, 8.}); + + x0.linspace(1); + x1.linspace(1); + x2.linspace(1); + + sd::ops::concat op; + + auto result = op.evaluate({&x0, &x1, &x2}, {}, {1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); // output->printCurrentBuffer(false); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, concat_test2) { + auto x0 = NDArrayFactory::create('c', {1, 3, 1}); + auto x1 = NDArrayFactory::create('c', {1, 2, 1}); + auto x2 = NDArrayFactory::create('c', {1, 1, 1}); + auto exp = NDArrayFactory::create('c', {1, 6, 1}, + {1.f, 2.f, 3.f, 1.f, 2.f, 1.f}); - auto x0 = NDArrayFactory::create('c', {1,3,1}); - auto x1 = NDArrayFactory::create('c', {1,2,1}); - auto x2 = NDArrayFactory::create('c', {1,1,1}); - auto exp = NDArrayFactory::create('c', {1,6,1}, {1.f, 2.f, 3.f, 1.f, 2.f, 1.f}); - - x0.linspace(1); - x1.linspace(1); - x2.linspace(1); - - sd::ops::concat op; - - auto result = op.evaluate({&x0, &x1, &x2}, {}, {1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); + x0.linspace(1); + x1.linspace(1); + x2.linspace(1); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::concat op; + auto result = op.evaluate({&x0, &x1, &x2}, {}, {1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, concat_test3) { + auto x0 = NDArrayFactory::create('c', {3}); + auto x1 = NDArrayFactory::create('c', {2}); + auto x2 = NDArrayFactory::create('c', {1}); + auto exp = + NDArrayFactory::create('c', {6}, {1.f, 2.f, 3.f, 1.f, 2.f, 1.f}); - auto x0 = NDArrayFactory::create('c', {3}); - auto x1 = NDArrayFactory::create('c', {2}); - auto x2 = NDArrayFactory::create('c', {1}); - auto exp = NDArrayFactory::create('c', {6}, {1.f, 2.f, 3.f, 1.f, 2.f, 1.f}); + x0.linspace(1); + x1.linspace(1); + x2.linspace(1); - x0.linspace(1); - x1.linspace(1); - x2.linspace(1); + sd::ops::concat op; - sd::ops::concat op; + auto result = op.evaluate({&x0, &x1, &x2}, {}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); - auto result = op.evaluate({&x0, &x1, &x2}, {}, {0}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, concat_test4) { + auto x0 = NDArrayFactory::create('c', {1, 1, 1}, {1.f}); + auto x1 = NDArrayFactory::create('c', {1, 1, 1}, {2.f}); + auto x2 = NDArrayFactory::create('c', {1, 1, 1}, {3.f}); + auto exp = NDArrayFactory::create('c', {1, 3, 1}, {1.f, 2.f, 3.f}); - auto x0 = NDArrayFactory::create('c', {1,1,1}, {1.f}); - auto x1 = NDArrayFactory::create('c', {1,1,1}, {2.f}); - auto x2 = NDArrayFactory::create('c', {1,1,1}, {3.f}); - auto exp = NDArrayFactory::create('c', {1,3,1}, {1.f, 2.f, 3.f}); - - sd::ops::concat op; - - auto result = op.evaluate({&x0, &x1, &x2}, {}, {1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::concat op; + auto result = op.evaluate({&x0, &x1, &x2}, {}, {1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, concat_test5) { + auto x0 = NDArrayFactory::create(1.f); + auto x1 = NDArrayFactory::create('c', {1}, {2.f}); + auto x2 = NDArrayFactory::create(3.f); + auto exp = NDArrayFactory::create('c', {3}, {1.f, 2.f, 3.f}); - auto x0 = NDArrayFactory::create(1.f); - auto x1 = NDArrayFactory::create('c', {1}, {2.f}); - auto x2 = NDArrayFactory::create(3.f); - auto exp = NDArrayFactory::create('c', {3}, {1.f, 2.f, 3.f}); - - sd::ops::concat op; - - auto result = op.evaluate({&x0, &x1, &x2}, {}, {0}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::concat op; + auto result = op.evaluate({&x0, &x1, &x2}, {}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, concat_test6) { + auto x0 = NDArrayFactory::create(1.f); + auto x1 = NDArrayFactory::create('c', {2}, {2.f, 20.f}); + auto x2 = NDArrayFactory::create(3.f); + auto exp = NDArrayFactory::create('c', {4}, {1.f, 2.f, 20.f, 3.f}); - auto x0 = NDArrayFactory::create(1.f); - auto x1 = NDArrayFactory::create('c', {2}, {2.f, 20.f}); - auto x2 = NDArrayFactory::create(3.f); - auto exp = NDArrayFactory::create('c', {4}, {1.f, 2.f, 20.f, 3.f}); - - sd::ops::concat op; - - auto result = op.evaluate({&x0, &x1, &x2}, {}, {0}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::concat op; + auto result = op.evaluate({&x0, &x1, &x2}, {}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, concat_test7) { + auto x0 = NDArrayFactory::create(1.f); + auto x1 = NDArrayFactory::create(2.f); + auto x2 = NDArrayFactory::create(3.f); + auto exp = NDArrayFactory::create('c', {3}, {1.f, 2.f, 3.f}); - auto x0 = NDArrayFactory::create(1.f); - auto x1 = NDArrayFactory::create(2.f); - auto x2 = NDArrayFactory::create(3.f); - auto exp = NDArrayFactory::create('c', {3}, {1.f, 2.f, 3.f}); - - sd::ops::concat op; - - auto result = op.evaluate({&x0, &x1, &x2}, {}, {0}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::concat op; + auto result = op.evaluate({&x0, &x1, &x2}, {}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, concat_test8) { + auto x0 = NDArrayFactory::create(1.f); + auto exp = NDArrayFactory::create('c', {1}, {1.f}); - auto x0 = NDArrayFactory::create(1.f); - auto exp = NDArrayFactory::create('c', {1}, {1.f}); - - sd::ops::concat op; - - auto result = op.evaluate({&x0}, {}, {0}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::concat op; + auto result = op.evaluate({&x0}, {}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, concat_test9) { + auto x0 = NDArrayFactory::create('c', {1}, {1.f}); + auto exp = NDArrayFactory::create('c', {1}, {1.f}); - auto x0 = NDArrayFactory::create('c', {1}, {1.f}); - auto exp = NDArrayFactory::create('c', {1}, {1.f}); - - sd::ops::concat op; - - auto result = op.evaluate({&x0}, {}, {0}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::concat op; + auto result = op.evaluate({&x0}, {}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, concat_test10) { + auto x0 = NDArrayFactory::create('c', {2, 3, 4}); + auto x1 = NDArrayFactory::create('f', {2, 2, 4}); + auto x2 = NDArrayFactory::create('c', {2, 1, 4}); + auto exp = NDArrayFactory::create( + 'c', {2, 6, 4}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, + 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 1.f, 2.f, 3.f, 4.f, + 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, + 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 16.f, 5.f, 6.f, 7.f, 8.f}); - auto x0 = NDArrayFactory::create('c', {2,3,4}); - auto x1 = NDArrayFactory::create('f', {2,2,4}); - auto x2 = NDArrayFactory::create('c', {2,1,4}); - auto exp = NDArrayFactory::create('c', {2,6,4}, { 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 1.f, 2.f, 3.f, 4.f, - 13.f, 14.f, 15.f, 16.f,17.f, 18.f, 19.f, 20.f,21.f, 22.f, 23.f, 24.f, 9.f, 10.f, 11.f, 12.f,13.f, 14.f, 15.f, 16.f, 5.f, 6.f, 7.f, 8.f}); - - x0.linspace(1); - x1.linspace(1); - x2.linspace(1); - - sd::ops::concat op; - - auto result = op.evaluate({&x0, &x1, &x2}, {}, {1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); + x0.linspace(1); + x1.linspace(1); + x2.linspace(1); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::concat op; + auto result = op.evaluate({&x0, &x1, &x2}, {}, {1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, concat_test11) { + auto x0 = NDArrayFactory::create('c', {2, 3, 4}); + auto x1 = NDArrayFactory::create('f', {2, 2, 4}); + auto x2 = NDArrayFactory::create('f', {2, 1, 4}); + auto exp = NDArrayFactory::create( + 'c', {2, 6, 4}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, + 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 1.f, 2.f, 3.f, 4.f, + 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, + 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 16.f, 5.f, 6.f, 7.f, 8.f}); - auto x0 = NDArrayFactory::create('c', {2,3,4}); - auto x1 = NDArrayFactory::create('f', {2,2,4}); - auto x2 = NDArrayFactory::create('f', {2,1,4}); - auto exp = NDArrayFactory::create('c', {2,6,4}, { 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 1.f, 2.f, 3.f, 4.f, - 13.f, 14.f, 15.f, 16.f,17.f, 18.f, 19.f, 20.f,21.f, 22.f, 23.f, 24.f, 9.f, 10.f, 11.f, 12.f,13.f, 14.f, 15.f, 16.f, 5.f, 6.f, 7.f, 8.f}); + x0.linspace(1); + x1.linspace(1); + x2.linspace(1); - x0.linspace(1); - x1.linspace(1); - x2.linspace(1); - - sd::ops::concat op; - - auto result = op.evaluate({&x0, &x1, &x2}, {}, {1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::concat op; + auto result = op.evaluate({&x0, &x1, &x2}, {}, {1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, concat_test12) { + auto x0 = NDArrayFactory::create('c', {2, 3, 4}); + auto x1 = NDArrayFactory::create('f', {2, 2, 4}); + auto x2 = NDArrayFactory::create('f', {2, 1, 4}); + auto exp = NDArrayFactory::create( + 'c', {2, 6, 4}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, + 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 1.f, 2.f, 3.f, 4.f, + 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, + 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 16.f, 5.f, 6.f, 7.f, 8.f}); - auto x0 = NDArrayFactory::create('c', {2,3,4}); - auto x1 = NDArrayFactory::create('f', {2,2,4}); - auto x2 = NDArrayFactory::create('f', {2,1,4}); - auto exp = NDArrayFactory::create('c', {2,6,4}, { 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 1.f, 2.f, 3.f, 4.f, - 13.f, 14.f, 15.f, 16.f,17.f, 18.f, 19.f, 20.f,21.f, 22.f, 23.f, 24.f, 9.f, 10.f, 11.f, 12.f,13.f, 14.f, 15.f, 16.f, 5.f, 6.f, 7.f, 8.f}); + x0.linspace(1); + x1.linspace(1); + x2.linspace(1); - x0.linspace(1); - x1.linspace(1); - x2.linspace(1); - - sd::ops::concat op; - - auto result = op.evaluate({&x0, &x1, &x2}, {}, {1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::concat op; + auto result = op.evaluate({&x0, &x1, &x2}, {}, {1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, concat_test13) { + auto x0 = NDArrayFactory::create('f', {2, 3, 4}); + auto x1 = NDArrayFactory::create('f', {2, 2, 4}); + auto x2 = NDArrayFactory::create('f', {2, 1, 4}); + auto exp = NDArrayFactory::create( + 'f', {2, 6, 4}, + {1.f, 13.f, 5.f, 17.f, 9.f, 21.f, 1.f, 9.f, 5.f, 13.f, 1.f, 5.f, + 2.f, 14.f, 6.f, 18.f, 10.f, 22.f, 2.f, 10.f, 6.f, 14.f, 2.f, 6.f, + 3.f, 15.f, 7.f, 19.f, 11.f, 23.f, 3.f, 11.f, 7.f, 15.f, 3.f, 7.f, + 4.f, 16.f, 8.f, 20.f, 12.f, 24.f, 4.f, 12.f, 8.f, 16.f, 4.f, 8.f}); - auto x0 = NDArrayFactory::create('f', {2,3,4}); - auto x1 = NDArrayFactory::create('f', {2,2,4}); - auto x2 = NDArrayFactory::create('f', {2,1,4}); - auto exp = NDArrayFactory::create('f', {2,6,4}, { 1.f, 13.f, 5.f, 17.f, 9.f, 21.f, 1.f, 9.f, 5.f, 13.f, 1.f, 5.f, 2.f, 14.f, 6.f, 18.f,10.f, 22.f, 2.f, 10.f, 6.f, 14.f, 2.f, 6.f, - 3.f, 15.f, 7.f, 19.f,11.f, 23.f, 3.f, 11.f, 7.f, 15.f, 3.f, 7.f, 4.f, 16.f, 8.f, 20.f,12.f, 24.f, 4.f, 12.f, 8.f, 16.f, 4.f, 8.f}); - - x0.linspace(1); - x1.linspace(1); - x2.linspace(1); + x0.linspace(1); + x1.linspace(1); + x2.linspace(1); - sd::ops::concat op; - - auto result = op.evaluate({&x0, &x1, &x2}, {}, {1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); - - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::concat op; + auto result = op.evaluate({&x0, &x1, &x2}, {}, {1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } TEST_F(DeclarableOpsTests9, concat_test14) { + NDArray x0('c', {1, 40, 60}, sd::DataType::FLOAT32); + NDArray x1('c', {1, 40, 60}, sd::DataType::FLOAT32); - NDArray x0('c', {1, 40, 60}, sd::DataType::FLOAT32); - NDArray x1('c', {1, 40, 60}, sd::DataType::FLOAT32); - - x0 = 1.; - x1 = 2.; + x0 = 1.; + x1 = 2.; - sd::ops::concat op; - auto result = op.evaluate({&x0, &x1}, {}, {0}, {}); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::concat op; + auto result = op.evaluate({&x0, &x1}, {}, {0}, {}); + ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); - - Nd4jLong numOfTads= ShapeUtils::getNumOfSubArrs(z->shapeInfo(), {0}); - ASSERT_TRUE(2 == numOfTads); - - for (int e = 0; e < numOfTads; ++e) { - NDArray tad = (*z)(e, {0}); - auto mean = tad.meanNumber().e(0); - ASSERT_NEAR((e+1)*1., mean, 1e-5); - } + auto z = result.at(0); + Nd4jLong numOfTads = ShapeUtils::getNumOfSubArrs(z.shapeInfo(), {0}); + ASSERT_TRUE(2 == numOfTads); + for (int e = 0; e < numOfTads; ++e) { + NDArray tad = z(e, {0}); + auto mean = tad.meanNumber().e(0); + ASSERT_NEAR((e + 1) * 1., mean, 1e-5); + } } TEST_F(DeclarableOpsTests9, concat_test15) { - auto x = NDArrayFactory::create('c', {2}, {1, 0}); - auto y = NDArrayFactory::create (3.0f); - auto exp = NDArrayFactory::create('c', {3}, {1, 0, 3}); - - sd::ops::concat op; - auto result = op.evaluate({&x, &y}, {}, {0}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); + auto x = NDArrayFactory::create('c', {2}, {1, 0}); + auto y = NDArrayFactory::create(3.0f); + auto exp = NDArrayFactory::create('c', {3}, {1, 0, 3}); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::concat op; + auto result = op.evaluate({&x, &y}, {}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, concat_test16) { + auto x = NDArrayFactory::create('c', {0, 2, 3}); + auto y = NDArrayFactory::create('c', {0, 2, 3}); + auto exp = NDArrayFactory::create('c', {0, 2, 3}); - auto x = NDArrayFactory::create('c', {0,2,3}); - auto y = NDArrayFactory::create('c', {0,2,3}); - auto exp = NDArrayFactory::create('c', {0,2,3}); + sd::ops::concat op; + auto result = op.evaluate({&x, &y}, {}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - sd::ops::concat op; - auto result = op.evaluate({&x, &y}, {}, {0}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); + auto z = result.at(0); - ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.isSameShape(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, concat_test17) { + NDArray x0('c', {1, 55, 40}, sd::DataType::FLOAT32); + NDArray x1('c', {1, 55, 40}, sd::DataType::FLOAT32); - NDArray x0('c', {1, 55, 40}, sd::DataType::FLOAT32); - NDArray x1('c', {1, 55, 40}, sd::DataType::FLOAT32); - - x0 = 1.; - x1 = 2.; + x0 = 1.; + x1 = 2.; - sd::ops::concat op; - auto result = op.evaluate({&x0, &x1}, {}, {0}, {}); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::concat op; + auto result = op.evaluate({&x0, &x1}, {}, {0}, {}); + ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); - // z->printShapeInfo(); - // z->printIndexedBuffer(); + auto z = result.at(0); + // z->printShapeInfo(); + // z->printIndexedBuffer(); - Nd4jLong numOfTads= ShapeUtils::getNumOfSubArrs(z->shapeInfo(), {0}); - ASSERT_TRUE(2 == numOfTads); + Nd4jLong numOfTads = ShapeUtils::getNumOfSubArrs(z.shapeInfo(), {0}); + ASSERT_TRUE(2 == numOfTads); - for (int e = 0; e < numOfTads; ++e) { - NDArray tad = (*z)(e, {0}); - auto mean = tad.meanNumber().e(0); - ASSERT_NEAR((e+1)*1., mean, 1e-5); - } + for (int e = 0; e < numOfTads; ++e) { + NDArray tad = z(e, {0}); + auto mean = tad.meanNumber().e(0); + ASSERT_NEAR((e + 1) * 1., mean, 1e-5); + } } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, concat_test18) { - Context context(1); - Nd4jLong axis = 0; + Context context(1); + Nd4jLong axis = 0; - // we crate bunch of arrays, filled with specific values - for (int e = 0; e < 2000; e++) { - auto array = NDArrayFactory::create_('c', {1, 300}); - array->assign(e); - context.setInputArray(e, array, true); - } + // we crate bunch of arrays, filled with specific values + for (int e = 0; e < 2000; e++) { + auto array = NDArrayFactory::create('c', {1, 300}); + array.assign(e); + context.setInputArray(e, array); + } - auto z = NDArrayFactory::create('c', {2000, 300}); - context.setOutputArray(0, &z, false); - context.setIArguments(&axis, 1); + auto z = NDArrayFactory::create('c', {2000, 300}); + context.setOutputArray(0, z); + context.setIArguments(&axis, 1); - sd::ops::concat op; - op.execute(&context); + sd::ops::concat op; + op.execute(&context); - for (int e = 0; e < 2000; e++) { - auto exp = NDArrayFactory::create('c', {300}); - exp.assign(e); - auto row = z(e, {0}); - ASSERT_EQ(exp, row); - } + for (int e = 0; e < 2000; e++) { + auto exp = NDArrayFactory::create('c', {300}); + exp.assign(e); + auto row = z(e, {0}); + ASSERT_EQ(exp, row); + } } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, concat_test19) { + Context context(1); + Nd4jLong axis = 0; - Context context(1); - Nd4jLong axis = 0; + // we crate bunch of arrays, filled with specific values + for (int e = 0; e < 10; e++) { + auto array = NDArrayFactory::create('c', {1, 5, 20}); + array.assign(e); + context.setInputArray(e, array); + } - // we crate bunch of arrays, filled with specific values - for (int e = 0; e < 10; e++) { - auto array = NDArrayFactory::create_('c', {1, 5, 20}); - array->assign(e); - context.setInputArray(e, array, true); - } + auto z = NDArrayFactory::create('c', {10, 5, 20}); + context.setOutputArray(0, z); + context.setIArguments(&axis, 1); - auto z = NDArrayFactory::create('c', {10, 5, 20}); - context.setOutputArray(0, &z, false); - context.setIArguments(&axis, 1); + sd::ops::concat op; + op.execute(&context); - sd::ops::concat op; - op.execute(&context); - - for (int e = 0; e < 10; e++) - ASSERT_NEAR((float) e, z(e, {0}).meanNumber().e(0), 1e-5f); + for (int e = 0; e < 10; e++) + ASSERT_NEAR((float)e, z(e, {0}).meanNumber().e(0), 1e-5f); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, concat_test20) { - auto x0 = NDArrayFactory::create('c', {1, 100, 150}); - auto x1 = NDArrayFactory::create('c', {1, 100, 150}); - auto x2 = NDArrayFactory::create('c', {1, 100, 150}); - auto x3 = NDArrayFactory::create('c', {1, 100, 150}); - - x0.assign(1.0); - x1.assign(2.0); - x2.assign(3.0); - x3.assign(4.0); - - sd::ops::concat op; - auto result = op.evaluate({&x0, &x1, &x2, &x3}, {}, {0}, {}); - ASSERT_EQ(Status::OK(), result.status()); + auto x0 = NDArrayFactory::create('c', {1, 100, 150}); + auto x1 = NDArrayFactory::create('c', {1, 100, 150}); + auto x2 = NDArrayFactory::create('c', {1, 100, 150}); + auto x3 = NDArrayFactory::create('c', {1, 100, 150}); - auto z = result.at(0); + x0.assign(1.0); + x1.assign(2.0); + x2.assign(3.0); + x3.assign(4.0); - Nd4jLong numOfTads= ShapeUtils::getNumOfSubArrs(z->shapeInfo(), {0}); - ASSERT_TRUE(4 == numOfTads); + sd::ops::concat op; + auto result = op.evaluate({&x0, &x1, &x2, &x3}, {}, {0}, {}); + ASSERT_EQ(Status::OK(), result.status()); - for (int e = 0; e < numOfTads; e++) { - NDArray tad = (*z)(e, {0}); - auto mean = tad.meanNumber().e(0); - ASSERT_NEAR((float) e+1, mean, 1e-5); - } + auto z = result.at(0); + Nd4jLong numOfTads = ShapeUtils::getNumOfSubArrs(z.shapeInfo(), {0}); + ASSERT_TRUE(4 == numOfTads); + for (int e = 0; e < numOfTads; e++) { + NDArray tad = z(e, {0}); + auto mean = tad.meanNumber().e(0); + ASSERT_NEAR((float)e + 1, mean, 1e-5); + } } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, concat_test21) { + NDArray x0('c', {1, 4, 5}, sd::DataType::FLOAT32); + NDArray x1('c', {2, 4, 5}, sd::DataType::FLOAT32); + NDArray z('f', {3, 4, 5}, sd::DataType::FLOAT32); - NDArray x0('c', {1,4,5}, sd::DataType::FLOAT32); - NDArray x1('c', {2,4,5}, sd::DataType::FLOAT32); - NDArray z('f', {3,4,5}, sd::DataType::FLOAT32); + x0 = 0.; + x1 = 1.; - x0 = 0.; - x1 = 1.; - - sd::ops::concat op; - auto status = op.execute({&x0, &x1}, {&z}, {}, {0}, {}); - ASSERT_EQ(ND4J_STATUS_OK, status); + sd::ops::concat op; + auto status = op.execute({&x0, &x1}, {&z}, {}, {0}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, concat_test22) { + NDArray x0('c', {1, 6}, {1, 2, 3, 4, 5, 6}, sd::DataType::FLOAT32); + NDArray x1('c', {1, 6}, {7, 8, 9, 10, 11, 12}, sd::DataType::FLOAT32); + NDArray output('f', {2, 6}, sd::DataType::FLOAT32); + NDArray exp('c', {2, 6}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, sd::DataType::FLOAT32); - NDArray x0('c', {1,6}, {1,2,3,4,5,6}, sd::DataType::FLOAT32); - NDArray x1('c', {1,6}, {7,8,9,10,11,12}, sd::DataType::FLOAT32); - NDArray output('f', {2,6}, sd::DataType::FLOAT32); - NDArray exp('c', {2,6}, {1,2,3,4,5,6,7,8,9,10,11,12}, sd::DataType::FLOAT32); - - sd::ops::concat op; + sd::ops::concat op; - auto status = op.execute({&x0, &x1}, {&output}, {}, {0}, {}); - ASSERT_EQ(ND4J_STATUS_OK, status); + auto status = op.execute({&x0, &x1}, {&output}, {}, {0}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, concat_test23) { + NDArray x0('c', {1, 4}, {1, 2, 3, 4},sd::DataType::FLOAT32); + NDArray x1('c', {1, 4}, {5, 6, 7, 8},sd::DataType::FLOAT32); + NDArray output('c', {2, 4}, sd::DataType::FLOAT32); + NDArray exp('c', {2, 4}, {1, 2, 3, 4, 5, 6, 7, 8}, sd::DataType::FLOAT32); - NDArray x0('c', {1,4}, {1,2,3,4},sd::DataType::FLOAT32); - NDArray x1('c', {1,4}, {5,6,7,8},sd::DataType::FLOAT32); - NDArray output('c', {2,4}, sd::DataType::FLOAT32); - NDArray exp('c', {2,4}, {1,2,3,4,5,6,7,8}, sd::DataType::FLOAT32); - - sd::ops::concat op; + sd::ops::concat op; - auto status = op.execute({&x0, &x1}, {&output}, {}, {0}, {}); - ASSERT_EQ(ND4J_STATUS_OK, status); + auto status = op.execute({&x0, &x1}, {&output}, {}, {0}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, concat_test24) { - auto x = NDArrayFactory::create('c', {2, 1}, {1, 1}); - auto y = NDArrayFactory::create('c', {2, 1}, {0, 0}); - auto e = NDArrayFactory::create('c', {2, 2}, {1, 0, 1, 0}); - auto z = NDArrayFactory::create('c', {2, 2}); + auto x = NDArrayFactory::create('c', {2, 1}, {1, 1}); + auto y = NDArrayFactory::create('c', {2, 1}, {0, 0}); + auto e = NDArrayFactory::create('c', {2, 2}, {1, 0, 1, 0}); + auto z = NDArrayFactory::create('c', {2, 2}); - sd::ops::concat op; - auto status = op.execute({&x, &y}, {&z}, {}, {1}, {}); - ASSERT_EQ(Status::OK(), status); + sd::ops::concat op; + auto status = op.execute({&x, &y}, {&z}, {}, {1}, {}); + ASSERT_EQ(Status::OK(), status); - ASSERT_EQ(e, z); + ASSERT_EQ(e, z); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, concat_test25) { + auto x0 = NDArrayFactory::create('c', {1, 4}, {1, 2, 3, 4}); + auto x1 = NDArrayFactory::create('c', {1, 4}, {5, 6, 7, 8}); + auto axis = NDArrayFactory::create('c', {1}, {0.}); + auto exp = + NDArrayFactory::create('c', {2, 4}, {1, 2, 3, 4, 5, 6, 7, 8}); - auto x0 = NDArrayFactory::create('c', {1,4}, {1,2,3,4}); - auto x1 = NDArrayFactory::create('c', {1,4}, {5,6,7,8}); - auto axis = NDArrayFactory::create('c', {1}, {0.}); - auto exp = NDArrayFactory::create('c', {2,4}, {1,2,3,4,5,6,7,8}); - - sd::ops::concat op; - - auto result = op.evaluate({&x0, &x1, &axis}, {}, {}, {true}); + sd::ops::concat op; - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + auto result = op.evaluate({&x0, &x1, &axis}, {}, {}, {true}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, concat_test26) { + NDArray x0('f', {1, 2, 3}, sd::DataType::INT32); + NDArray x1('f', {1, 2, 3}, sd::DataType::INT32); + NDArray x2('f', {1, 2, 3}, sd::DataType::INT32); - NDArray x0('f', {1, 2, 3}, sd::DataType::INT32); - NDArray x1('f', {1, 2, 3}, sd::DataType::INT32); - NDArray x2('f', {1, 2, 3}, sd::DataType::INT32); + NDArray exp('f', {3, 2, 3}, + {0, 6, 12, 3, 9, 15, 1, 7, 13, 4, 10, 16, 2, 8, 14, 5, 11, 17}, + sd::DataType::INT32); - NDArray exp('f', {3, 2, 3}, {0, 6, 12, 3, 9, 15, 1, 7, 13, 4, 10, 16, 2, 8, 14, 5, 11, 17}, sd::DataType::INT32); + x0.linspace(0); + x1.linspace(6); + x2.linspace(12); - x0.linspace(0); - x1.linspace(6); - x2.linspace(12); + sd::ops::concat op; - sd::ops::concat op; - - auto result = op.evaluate({&x0, &x1, &x2}, {}, {0}, {}); + auto result = op.evaluate({&x0, &x1, &x2}, {}, {0}, {}); ASSERT_EQ(ND4J_STATUS_OK, result.status()); auto output = result.at(0); // output->printLinearBuffer(); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, concat_test27) { + auto x1 = NDArrayFactory::create('c', {0, 1}); + auto x2 = NDArrayFactory::create('c', {0, 1}); + auto x3 = NDArrayFactory::create('c', {0, 1}); + auto x4 = NDArrayFactory::create('c', {0, 1}); - auto x1 = NDArrayFactory::create('c', {0,1}); - auto x2 = NDArrayFactory::create('c', {0,1}); - auto x3 = NDArrayFactory::create('c', {0,1}); - auto x4 = NDArrayFactory::create('c', {0,1}); - - std::vector expShape = {0, 4}; + std::vector expShape = {0, 4}; - sd::ops::concat op; - auto result = op.evaluate({&x1, &x2, &x3, &x4}, {}, {1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::concat op; + auto result = op.evaluate({&x1, &x2, &x3, &x4}, {}, {1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_TRUE(z->isSameShape(expShape)); + ASSERT_TRUE(z.isSameShape(expShape)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, tile_bp_test1) { + auto input = + NDArrayFactory::create('c', {2, 3}, {1., 2., 3., 4., 5., 6.}); + auto gradO = NDArrayFactory::create('c', {4, 9}); + auto gradIExp = NDArrayFactory::create( + 'c', {2, 3}, {0.78, 0.84, 0.9, 1.32, 1.38, 1.44}); - auto input = NDArrayFactory::create('c', {2, 3}, {1.,2.,3.,4.,5.,6.}); - auto gradO = NDArrayFactory::create('c', {4, 9}); - auto gradIExp = NDArrayFactory::create('c', {2, 3}, {0.78, 0.84, 0.9,1.32, 1.38, 1.44}); - - gradO.linspace(0.01, 0.01); - - sd::ops::tile_bp op; - auto results = op.evaluate({&input, &gradO}, {}, {2, 3}); - auto gradI = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(gradIExp.isSameShape(gradI)); - ASSERT_TRUE(gradIExp.equalsTo(gradI)); + gradO.linspace(0.01, 0.01); + sd::ops::tile_bp op; + auto results = op.evaluate({&input, &gradO}, {}, {2, 3}); + auto gradI = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(gradIExp.isSameShape(gradI)); + ASSERT_TRUE(gradIExp.equalsTo(gradI)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, tile_bp_test2) { + auto input = + NDArrayFactory::create('c', {2, 3}, {1., 2., 3., 4., 5., 6.}); + auto gradO = NDArrayFactory::create('c', {2, 9}); + auto gradIExp = NDArrayFactory::create( + 'c', {2, 3}, {0.12, 0.15, 0.18, 0.39, 0.42, 0.45}); - auto input = NDArrayFactory::create('c', {2, 3}, {1.,2.,3.,4.,5.,6.}); - auto gradO = NDArrayFactory::create('c', {2, 9}); - auto gradIExp = NDArrayFactory::create('c', {2, 3}, {0.12, 0.15, 0.18, 0.39, 0.42, 0.45}); - - gradO.linspace(0.01, 0.01); - - sd::ops::tile_bp op; - auto results = op.evaluate({&input, &gradO}, {}, {1, 3}); - auto gradI = results.at(0); - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(gradIExp.isSameShape(gradI)); - ASSERT_TRUE(gradIExp.equalsTo(gradI)); - + gradO.linspace(0.01, 0.01); + sd::ops::tile_bp op; + auto results = op.evaluate({&input, &gradO}, {}, {1, 3}); + auto gradI = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(gradIExp.isSameShape(gradI)); + ASSERT_TRUE(gradIExp.equalsTo(gradI)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, tile_bp_test3) { + auto input = + NDArrayFactory::create('c', {2, 3}, {1., 2., 3., 4., 5., 6.}); + auto gradO = NDArrayFactory::create('c', {2, 3}); + auto gradIExp = NDArrayFactory::create( + 'c', {2, 3}, {0.01, 0.02, 0.03, 0.04, 0.05, 0.06}); - auto input = NDArrayFactory::create('c', {2, 3}, {1.,2.,3.,4.,5.,6.}); - auto gradO = NDArrayFactory::create('c', {2, 3}); - auto gradIExp = NDArrayFactory::create('c', {2, 3}, {0.01, 0.02, 0.03,0.04, 0.05, 0.06}); - - gradO.linspace(0.01, 0.01); - - sd::ops::tile_bp op; - auto results = op.evaluate({&input, &gradO}, {}, {1, 1}); - auto gradI = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(gradIExp.isSameShape(gradI)); - ASSERT_TRUE(gradIExp.equalsTo(gradI)); + gradO.linspace(0.01, 0.01); + sd::ops::tile_bp op; + auto results = op.evaluate({&input, &gradO}, {}, {1, 1}); + auto gradI = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(gradIExp.isSameShape(gradI)); + ASSERT_TRUE(gradIExp.equalsTo(gradI)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, tile_bp_test4) { + auto input = + NDArrayFactory::create('c', {6}, {1., 2., 3., 4., 5., 6.}); + auto gradO = NDArrayFactory::create('c', {12}); + auto gradIExp = NDArrayFactory::create( + 'c', {6}, {0.08, 0.1, 0.12, 0.14, 0.16, 0.18}); - auto input = NDArrayFactory::create('c', {6}, {1.,2.,3.,4.,5.,6.}); - auto gradO = NDArrayFactory::create('c', {12}); - auto gradIExp = NDArrayFactory::create('c', {6}, {0.08, 0.1 , 0.12, 0.14, 0.16, 0.18}); - - gradO.linspace(0.01, 0.01); - - sd::ops::tile_bp op; - auto results = op.evaluate({&input, &gradO}, {}, {2}); - auto gradI = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(gradIExp.isSameShape(gradI)); - ASSERT_TRUE(gradIExp.equalsTo(gradI)); + gradO.linspace(0.01, 0.01); + sd::ops::tile_bp op; + auto results = op.evaluate({&input, &gradO}, {}, {2}); + auto gradI = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(gradIExp.isSameShape(gradI)); + ASSERT_TRUE(gradIExp.equalsTo(gradI)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, tile_bp_test5) { + auto input = NDArrayFactory::create('c', {1}, {1.}); + auto gradO = NDArrayFactory::create('c', {1}); + auto gradIExp = NDArrayFactory::create('c', {1}, {0.01}); - auto input = NDArrayFactory::create('c', {1}, {1.}); - auto gradO = NDArrayFactory::create('c', {1}); - auto gradIExp = NDArrayFactory::create('c', {1}, {0.01}); - - gradO.linspace(0.01, 0.01); - - sd::ops::tile_bp op; - auto results = op.evaluate({&input, &gradO}, {}, {1}); - auto gradI = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(gradIExp.isSameShape(gradI)); - ASSERT_TRUE(gradIExp.equalsTo(gradI)); + gradO.linspace(0.01, 0.01); + sd::ops::tile_bp op; + auto results = op.evaluate({&input, &gradO}, {}, {1}); + auto gradI = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(gradIExp.isSameShape(gradI)); + ASSERT_TRUE(gradIExp.equalsTo(gradI)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, tile_bp_test6) { + auto input = + NDArrayFactory::create('c', {2, 1, 3}, {1., 2., 3., 4., 5., 6.}); + auto gradO = NDArrayFactory::create('c', {2, 3, 6}); + auto gradIExp = NDArrayFactory::create( + 'c', {2, 1, 3}, {0.51, 0.57, 0.63, 1.59, 1.65, 1.71}); - auto input = NDArrayFactory::create('c', {2, 1, 3}, {1.,2.,3.,4.,5.,6.}); - auto gradO = NDArrayFactory::create('c', {2, 3, 6}); - auto gradIExp = NDArrayFactory::create('c', {2, 1, 3}, {0.51, 0.57, 0.63, 1.59, 1.65, 1.71}); - - gradO.linspace(0.01, 0.01); - - sd::ops::tile_bp op; - auto results = op.evaluate({&input, &gradO}, {}, {1, 3, 2}); - auto gradI = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(gradIExp.isSameShape(gradI)); - ASSERT_TRUE(gradIExp.equalsTo(gradI)); + gradO.linspace(0.01, 0.01); + sd::ops::tile_bp op; + auto results = op.evaluate({&input, &gradO}, {}, {1, 3, 2}); + auto gradI = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(gradIExp.isSameShape(gradI)); + ASSERT_TRUE(gradIExp.equalsTo(gradI)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, tile_bp_test7) { + auto input = + NDArrayFactory::create('c', {2, 1, 3}, {1., 2., 3., 4., 5., 6.}); + auto reps = NDArrayFactory::create('c', {1, 3}, {1, 3, 2}); + auto gradO = NDArrayFactory::create('c', {2, 3, 6}); + auto gradIExp = NDArrayFactory::create( + 'c', {2, 1, 3}, {0.51, 0.57, 0.63, 1.59, 1.65, 1.71}); - auto input = NDArrayFactory::create('c', {2, 1, 3}, {1.,2.,3.,4.,5.,6.}); - auto reps = NDArrayFactory::create('c', {1, 3}, {1, 3, 2}); - auto gradO = NDArrayFactory::create('c', {2, 3, 6}); - auto gradIExp = NDArrayFactory::create('c', {2, 1, 3}, {0.51, 0.57, 0.63, 1.59, 1.65, 1.71}); - - gradO.linspace(0.01, 0.01); - - sd::ops::tile_bp op; - auto results = op.evaluate({&input, &reps, &gradO}, {}, {}); - auto gradI = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(gradIExp.isSameShape(gradI)); - ASSERT_TRUE(gradIExp.equalsTo(gradI)); + gradO.linspace(0.01, 0.01); + sd::ops::tile_bp op; + auto results = op.evaluate({&input, &reps, &gradO}, {}, {}); + auto gradI = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(gradIExp.isSameShape(gradI)); + ASSERT_TRUE(gradIExp.equalsTo(gradI)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, tile_test1) { + auto input = + NDArrayFactory::create('c', {1, 6}, {1., 2., 3., 4., 5., 6.}); + auto reps = NDArrayFactory::create('c', {1, 2}, {2, 1}); + auto expOut = NDArrayFactory::create( + 'c', + { + 2, + 6, + }, + {1., 2., 3., 4., 5., 6., 1., 2., 3., 4., 5., 6.}); - auto input = NDArrayFactory::create('c', {1, 6}, {1.,2.,3.,4.,5.,6.}); - auto reps = NDArrayFactory::create('c', {1, 2}, {2, 1}); - auto expOut = NDArrayFactory::create('c', {2, 6,}, {1.,2.,3.,4.,5.,6., 1.,2.,3.,4.,5.,6.}); - - sd::ops::tile op; - auto results = op.evaluate({&input, &reps}, {}, {}); - auto out = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expOut.isSameShape(out)); - ASSERT_TRUE(expOut.equalsTo(out)); - + sd::ops::tile op; + auto results = op.evaluate({&input, &reps}, {}, {}); + auto out = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expOut.isSameShape(out)); + ASSERT_TRUE(expOut.equalsTo(out)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, TestDropout_BP_1) { + NDArray x('c', {2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f}); + NDArray errs('c', {2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f}); + NDArray shape('c', {2}, {2, 2}); + sd::ops::dropout_bp op; - NDArray x('c', {2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f}); - NDArray errs('c', {2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f}); - NDArray shape('c', {2}, {2, 2}); - sd::ops::dropout_bp op; - - auto ress = op.evaluate({&x, &errs, &shape}, {0.2f}, {113}); - - ASSERT_EQ(ND4J_STATUS_OK, ress.status()); - //ress.at(0)->printIndexedBuffer("Result is "); - //x.printIndexedBuffer("Input is"); - ASSERT_FALSE(ress.at(0)->equalsTo(errs)); + auto ress = op.evaluate({&x, &errs, &shape}, {0.2f}, {113}); + ASSERT_EQ(ND4J_STATUS_OK, ress.status()); + // ress.at(0)->printIndexedBuffer("Result is "); + // x.printIndexedBuffer("Input is"); + ASSERT_FALSE(ress.at(0).equalsTo(errs)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, TestDropout_1) { - - NDArray x('c', {10, 10}, sd::DataType::FLOAT32); -// NDArray errs('c', {2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f}); - //NDArray shape({2.f, 2.f}); - sd::ops::dropout op; - x.linspace(1); - auto ress = op.evaluate({&x}, {0.2f}, {113}); - - ASSERT_EQ(ND4J_STATUS_OK, ress.status()); - NDArray* res = ress.at(0); //->printIndexedBuffer("Result is "); - //x.printIndexedBuffer("Input is"); - //res->printIndexedBuffer("Result for Dropout_1"); - auto countZero = res->reduceNumber(reduce::CountZero); - ASSERT_NEAR(countZero.e(0), 80, 5); - auto ress2 = op.evaluate({&x}, {0.2f}, {113}); - - ASSERT_EQ(ND4J_STATUS_OK, ress2.status()); - NDArray* res2 = ress2.at(0); - - countZero = res->reduceNumber(reduce::CountZero); - ASSERT_NEAR(countZero.e(0), 80, 5); - //res2->printIndexedBuffer("Result for Dropout_2"); - ASSERT_TRUE(res->equalsTo(res2)); - //res->printIndexedBuffer("FF dropout"); - //res2->printIndexedBuffer("BP dropout"); - - - + NDArray x('c', {10, 10}, sd::DataType::FLOAT32); + // NDArray errs('c', {2, 2, 2}, + // {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f}); + // NDArray shape({2.f, 2.f}); + sd::ops::dropout op; + x.linspace(1); + auto ress = op.evaluate({&x}, {0.2f}, {113}); + + ASSERT_EQ(ND4J_STATUS_OK, ress.status()); + auto res = ress.at(0); //->printIndexedBuffer("Result is "); + // x.printIndexedBuffer("Input is"); + // res->printIndexedBuffer("Result for Dropout_1"); + auto countZero = res.reduceNumber(reduce::CountZero); + ASSERT_NEAR(countZero.e(0), 80, 5); + auto ress2 = op.evaluate({&x}, {0.2f}, {113}); + + ASSERT_EQ(ND4J_STATUS_OK, ress2.status()); + auto res2 = ress2.at(0); + + countZero = res.reduceNumber(reduce::CountZero); + ASSERT_NEAR(countZero.e(0), 80, 5); + // res2->printIndexedBuffer("Result for Dropout_2"); + ASSERT_TRUE(res.equalsTo(res2)); + // res->printIndexedBuffer("FF dropout"); + // res2->printIndexedBuffer("BP dropout"); } TEST_F(DeclarableOpsTests9, Test_DropoutInverted_01) { - NDArray x0('c', {10, 10}, sd::DataType::FLOAT32); - NDArray x1('c', {10, 10}, sd::DataType::FLOAT32); - - x0.linspace(1); - x1.linspace(1); -/* - float prob[] = {0.5f}; - Nd4jLong* _bufferA = new Nd4jLong[100000]; - long _seed = 119L; - auto _rngA = (sd::random::RandomBuffer *) initRandom(nullptr, _seed, 100000, (Nd4jPointer) _bufferA); - - x0. applyTransform(random::DropOutInverted, &x0, prob); -// x1.template applyRandom>(_rngB, nullptr, &x1, prob); -// x0.printIndexedBuffer("01Result1"); - int count = 0; - for (int e = 0; e < x0.lengthOf(); e++) - if (x0.e(e) != 0.f) - count++; -// nd4j_printf("\nX0 count %i\n", count); -// ASSERT_TRUE(x0.equalsTo(&x1)); - - // this check is required to ensure we're calling wrong signature -// ASSERT_FALSE(x0.equalsTo(nexp0)); -// ASSERT_FALSE(x0.equalsTo(nexp1)); -// ASSERT_FALSE(x0.equalsTo(nexp2)); - destroyRandom(_rngA); - delete [] _bufferA; -*/ - sd::ops::dropout op; - - auto ress = op.evaluate({&x1}, {0.5f}, {119}); - - ASSERT_EQ(ND4J_STATUS_OK, ress.status()); - //ress.at(0)->printIndexedBuffer("01Dropout result is "); - auto count = ress.at(0)->reduceNumber(reduce::CountNonZero); -// nd4j_printf("\n01Dropout count %i\n\n", count); - - sd::ops::dropout_bp op2; - //NDArray exp('c', {10,10}, {4.f, 0.f, 12.f, 0.f, 20.f, 24.f, 0.f, 32.f, 0.f, 0.f, 0.f, 0.f, 52.f, 56.f, 60.f, 0.f, 0.f, 0.f, 0.f, 0.f, 84.f, 88.f, 0.f, 0.f, 0.f, 0.f, 108.f, 0.f, 0.f, 120.f, 0.f, 0.f, 132.f, 0.f, 0.f, 0.f, 0.f, 0.f, 156.f, 0.f, 164.f, 168.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 200.f, 204.f, 0.f, 0.f, 0.f, 220.f, 0.f, 0.f, 232.f, 236.f, 240.f, 0.f, 248.f, 0.f, 0.f, 260.f, 0.f, 0.f, 0.f, 276.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 316.f, 0.f, 324.f, 0.f, 0.f, 336.f, 0.f, 0.f, 0.f, 0.f, 356.f, 0.f, 0.f, 368.f, 0.f, 0.f, 0.f, 384.f, 388.f, 0.f, 0.f, 400.f}); - //02Dropout result is [4.000000, 0.000000, 12.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 36.000000, 0.000000, 0.000000, 0.000000, 0.000000, 56.000000, 60.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 88.000000, 0.000000, 96.000000, 0.000000, 0.000000, 108.000000, 0.000000, 0.000000, 120.000000, 0.000000, 128.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 156.000000, 0.000000, 164.000000, 0.000000, 0.000000, 0.000000, 0.000000, 184.000000, 0.000000, 0.000000, 0.000000, 200.000000, 0.000000, 0.000000, 0.000000, 216.000000, 0.000000, 0.000000, 0.000000, 232.000000, 0.000000, 240.000000, 0.000000, 248.000000, 0.000000, 0.000000, 260.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 308.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 348.000000, 0.000000, 356.000000, 0.000000, 0.000000, 0.000000, 0.000000, 376.000000, 0.000000, 384.000000, 0.000000, 0.000000, 0.000000, 400.000000] - - auto ressX = op2.evaluate({&x1, &x1}, {0.5f}, {119}); // , false, sd::DataType::FLOAT32); // skipped due given by default - //x0.printIndexedBuffer("X0"); - //x1.printIndexedBuffer("X1"); - ASSERT_EQ(ND4J_STATUS_OK, ressX.status()); - auto ressY = op2.evaluate({&x1, &x0}, {0.5f}, {119}); - ASSERT_EQ(ND4J_STATUS_OK, ressY.status()); - //ressY->at(0)->printIndexedBuffer("BP"); - //ress.at(0)->printIndexedBuffer("FF"); - bool ret = true; - for (int e = 0; e < ress.at(0)->lengthOf(); e++) { - if (ress.at(0)->e(e) == 0.f) - if (ressX.at(0)->e(e) != ress.at(0)->e(e)) { - ret = false; - break; - } - } - ASSERT_TRUE(ret); - // ASSERT_FALSE(ressX->at(0)->equalsTo(ressY->at(0))); - //ressX->at(0)->printIndexedBuffer("02Dropout result is "); -/* float countZero = ressX->at(0)->template reduceNumber>(); - ASSERT_NEAR(countZero, 50.f, 5.f); - countZero = ress.at(0)->template reduceNumber>(); - ASSERT_NEAR(countZero, 50.f, 5.f); - countZero = ressY->at(0)->template reduceNumber>(); - ASSERT_NEAR(countZero, 50.f, 5.f); - */ -// ASSERT_TRUE(exp.equalsTo(ressX->at(0))); - - + NDArray x0('c', {10, 10}, sd::DataType::FLOAT32); + NDArray x1('c', {10, 10}, sd::DataType::FLOAT32); + + x0.linspace(1); + x1.linspace(1); + /* + float prob[] = {0.5f}; + Nd4jLong* _bufferA = new Nd4jLong[100000]; + long _seed = 119L; + auto _rngA = (sd::random::RandomBuffer *) initRandom(nullptr, _seed, + 100000, (Nd4jPointer) _bufferA); + + x0. applyTransform(random::DropOutInverted, &x0, prob); + // x1.template applyRandom>(_rngB, + nullptr, &x1, prob); + // x0.printIndexedBuffer("01Result1"); + int count = 0; + for (int e = 0; e < x0.lengthOf(); e++) + if (x0.e(e) != 0.f) + count++; + // nd4j_printf("\nX0 count %i\n", count); + // ASSERT_TRUE(x0.equalsTo(&x1)); + + // this check is required to ensure we're calling wrong signature + // ASSERT_FALSE(x0.equalsTo(nexp0)); + // ASSERT_FALSE(x0.equalsTo(nexp1)); + // ASSERT_FALSE(x0.equalsTo(nexp2)); + destroyRandom(_rngA); + delete [] _bufferA; + */ + sd::ops::dropout op; + + auto ress = op.evaluate({&x1}, {0.5f}, {119}); + + ASSERT_EQ(ND4J_STATUS_OK, ress.status()); + // ress.at(0)->printIndexedBuffer("01Dropout result is "); + auto count = ress.at(0).reduceNumber(reduce::CountNonZero); + // nd4j_printf("\n01Dropout count %i\n\n", count); + + sd::ops::dropout_bp op2; + // NDArray exp('c', {10,10}, {4.f, 0.f, 12.f, 0.f, 20.f, 24.f, + // 0.f, 32.f, 0.f, 0.f, 0.f, 0.f, 52.f, 56.f, 60.f, 0.f, 0.f, 0.f, 0.f, + // 0.f, 84.f, 88.f, 0.f, 0.f, 0.f, 0.f, 108.f, 0.f, 0.f, 120.f, 0.f, 0.f, + // 132.f, 0.f, 0.f, 0.f, 0.f, 0.f, 156.f, 0.f, 164.f, 168.f, 0.f, 0.f, 0.f, + // 0.f, 0.f, 0.f, 0.f, 200.f, 204.f, 0.f, 0.f, 0.f, 220.f, 0.f, 0.f, 232.f, + // 236.f, 240.f, 0.f, 248.f, 0.f, 0.f, 260.f, 0.f, 0.f, 0.f, 276.f, 0.f, 0.f, + // 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 316.f, 0.f, 324.f, 0.f, 0.f, 336.f, 0.f, + // 0.f, 0.f, 0.f, 356.f, 0.f, 0.f, 368.f, 0.f, 0.f, 0.f, 384.f, 388.f, 0.f, + // 0.f, 400.f}); 02Dropout result is [4.000000, 0.000000, 12.000000, + // 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 36.000000, 0.000000, + // 0.000000, 0.000000, 0.000000, 56.000000, 60.000000, 0.000000, 0.000000, + // 0.000000, 0.000000, 0.000000, 0.000000, 88.000000, 0.000000, 96.000000, + // 0.000000, 0.000000, 108.000000, 0.000000, 0.000000, 120.000000, 0.000000, + // 128.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, + // 156.000000, 0.000000, 164.000000, 0.000000, 0.000000, 0.000000, 0.000000, + // 184.000000, 0.000000, 0.000000, 0.000000, 200.000000, 0.000000, 0.000000, + // 0.000000, 216.000000, 0.000000, 0.000000, 0.000000, 232.000000, 0.000000, + // 240.000000, 0.000000, 248.000000, 0.000000, 0.000000, 260.000000, 0.000000, + // 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, + // 0.000000, 0.000000, 0.000000, 308.000000, 0.000000, 0.000000, 0.000000, + // 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 348.000000, + // 0.000000, 356.000000, 0.000000, 0.000000, 0.000000, 0.000000, 376.000000, + // 0.000000, 384.000000, 0.000000, 0.000000, 0.000000, 400.000000] + + auto ressX = op2.evaluate({&x1, &x1}, {0.5f}, + {119}); // , false, sd::DataType::FLOAT32); // + // skipped due given by default + // x0.printIndexedBuffer("X0"); + // x1.printIndexedBuffer("X1"); + ASSERT_EQ(ND4J_STATUS_OK, ressX.status()); + auto ressY = op2.evaluate({&x1, &x0}, {0.5f}, {119}); + ASSERT_EQ(ND4J_STATUS_OK, ressY.status()); + // ressY->at(0)->printIndexedBuffer("BP"); + // ress.at(0)->printIndexedBuffer("FF"); + bool ret = true; + for (int e = 0; e < ress.at(0).lengthOf(); e++) { + if (ress.at(0).e(e) == 0.f) + if (ressX.at(0).e(e) != ress.at(0).e(e)) { + ret = false; + break; + } + } + ASSERT_TRUE(ret); + // ASSERT_FALSE(ressX->at(0)->equalsTo(ressY->at(0))); + // ressX->at(0)->printIndexedBuffer("02Dropout result is "); + /* float countZero = ressX->at(0)->template + reduceNumber>(); + ASSERT_NEAR(countZero, 50.f, 5.f); + countZero = ress.at(0)->template + reduceNumber>(); + ASSERT_NEAR(countZero, 50.f, 5.f); + countZero = ressY->at(0)->template + reduceNumber>(); + ASSERT_NEAR(countZero, 50.f, 5.f); + */ + // ASSERT_TRUE(exp.equalsTo(ressX->at(0))); } TEST_F(DeclarableOpsTests9, Test_Dropout_BP_2) { - NDArray x('c', {10, 10}, sd::DataType::FLOAT32); + NDArray x('c', {10, 10}, sd::DataType::FLOAT32); - x.linspace(1); + x.linspace(1); - sd::ops::dropout op; + sd::ops::dropout op; - auto ress = op.evaluate({&x}, {0.5f}, {119}); + auto ress = op.evaluate({&x}, {0.5f}, {119}); - ASSERT_EQ(ND4J_STATUS_OK, ress.status()); -// ress.at(0)->printIndexedBuffer("01Dropout result is "); + ASSERT_EQ(ND4J_STATUS_OK, ress.status()); + // ress.at(0)->printIndexedBuffer("01Dropout result is "); - sd::ops::dropout_bp op2; + sd::ops::dropout_bp op2; - auto ressX = op2.evaluate({&x, &x}, {0.5f}, {119}); + auto ressX = op2.evaluate({&x, &x}, {0.5f}, {119}); - ASSERT_EQ(ND4J_STATUS_OK, ressX.status()); - auto ressY = op2.evaluate({&x, &x}, {0.5f}, {119}); - ASSERT_EQ(ND4J_STATUS_OK, ressY.status()); + ASSERT_EQ(ND4J_STATUS_OK, ressX.status()); + auto ressY = op2.evaluate({&x, &x}, {0.5f}, {119}); + ASSERT_EQ(ND4J_STATUS_OK, ressY.status()); - //ress.at(0)->printIndexedBuffer("FF Dropout result is "); - //ressY->at(0)->printIndexedBuffer("BP Dropout result is "); - - - auto countZero = ress.at(0)->reduceNumber(reduce::CountZero); - ASSERT_NEAR(countZero.e(0), 50.f, 10.f); - countZero = ressX.at(0)->reduceNumber(reduce::CountZero); - //nd4j_printf("X zero count is %f\n", countZero); - ASSERT_NEAR(countZero.e(0), 50.f, 10.f); - countZero = ressY.at(0)->reduceNumber(reduce::CountZero); - //nd4j_printf("Y zero count is %f\n", countZero); - ASSERT_NEAR(countZero.e(0), 50.f, 10.f); -// ASSERT_TRUE(exp.equalsTo(ressX->at(0))); - ASSERT_TRUE(ressX.at(0)->equalsTo(ressY.at(0))); + // ress.at(0)->printIndexedBuffer("FF Dropout result is "); + // ressY->at(0)->printIndexedBuffer("BP Dropout result is "); + auto countZero = ress.at(0).reduceNumber(reduce::CountZero); + ASSERT_NEAR(countZero.e(0), 50.f, 10.f); + countZero = ressX.at(0).reduceNumber(reduce::CountZero); + // nd4j_printf("X zero count is %f\n", countZero); + ASSERT_NEAR(countZero.e(0), 50.f, 10.f); + countZero = ressY.at(0).reduceNumber(reduce::CountZero); + // nd4j_printf("Y zero count is %f\n", countZero); + ASSERT_NEAR(countZero.e(0), 50.f, 10.f); + // ASSERT_TRUE(exp.equalsTo(ressX->at(0))); + ASSERT_TRUE(ressX.at(0).equalsTo(ressY.at(0))); } - //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, Test_AlphaDropout_BP_1) { - NDArray x('c', {10, 10}, sd::DataType::FLOAT32); - NDArray eps('c', {10, 10}, sd::DataType::FLOAT32); + NDArray x('c', {10, 10}, sd::DataType::FLOAT32); + NDArray eps('c', {10, 10}, sd::DataType::FLOAT32); - x.linspace(1); - eps.linspace(1); + x.linspace(1); + eps.linspace(1); - sd::ops::alpha_dropout_bp op; + sd::ops::alpha_dropout_bp op; - auto ress = op.evaluate({&x, &eps}, {0.5f, 0.5f, 1.5f, 1.6f}, {119}); + auto ress = op.evaluate({&x, &eps}, {0.5f, 0.5f, 1.5f, 1.6f}, {119}); - ASSERT_EQ(ND4J_STATUS_OK, ress.status()); - NDArray* res = ress.at(0); + ASSERT_EQ(ND4J_STATUS_OK, ress.status()); + auto res = ress.at(0); - auto ress2 = op.evaluate({&x, &eps}, {0.5f, 0.5f, 1.5f, 1.6f}, {119}); - - ASSERT_EQ(ND4J_STATUS_OK, ress2.status()); - NDArray* res2 = ress2.at(0); - //res->printIndexedBuffer("Result1AlphaBP1"); - //res2->printIndexedBuffer("Result1AlphaBP2"); - ASSERT_TRUE(res2->equalsTo(res)); + auto ress2 = op.evaluate({&x, &eps}, {0.5f, 0.5f, 1.5f, 1.6f}, {119}); + ASSERT_EQ(ND4J_STATUS_OK, ress2.status()); + auto res2 = ress2.at(0); + // res->printIndexedBuffer("Result1AlphaBP1"); + // res2->printIndexedBuffer("Result1AlphaBP2"); + ASSERT_TRUE(res2.equalsTo(res)); } TEST_F(DeclarableOpsTests9, test_range_int_1) { - auto x0 = NDArrayFactory::create(0); - auto x1 = NDArrayFactory::create(2); - auto x2 = NDArrayFactory::create(1); - - sd::ops::range op; - auto result = op.evaluate({&x0, &x1, &x2}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); + auto x0 = NDArrayFactory::create(0); + auto x1 = NDArrayFactory::create(2); + auto x2 = NDArrayFactory::create(1); - auto z = result.at(0); + sd::ops::range op; + auto result = op.evaluate({&x0, &x1, &x2}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); } TEST_F(DeclarableOpsTests9, test_range_empty_1) { - auto x0 = NDArrayFactory::create(0); - auto x1 = NDArrayFactory::create(0); - auto x2 = NDArrayFactory::create(1); - - sd::ops::range op; - auto result = op.evaluate({&x0, &x1, &x2}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); + auto x0 = NDArrayFactory::create(0); + auto x1 = NDArrayFactory::create(0); + auto x2 = NDArrayFactory::create(1); - auto z = result.at(0); + sd::ops::range op; + auto result = op.evaluate({&x0, &x1, &x2}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(z->isEmpty()); + auto z = result.at(0); + ASSERT_TRUE(z.isEmpty()); } - TEST_F(DeclarableOpsTests9, test_broadcast_bool_1) { - auto x = NDArrayFactory::create('c', {1, 3, 2, 4, 4}); - auto y = NDArrayFactory::create('c', {1, 2, 4, 4}); - auto z = NDArrayFactory::create('c', {1, 3, 2, 4, 4}); + auto x = NDArrayFactory::create('c', {1, 3, 2, 4, 4}); + auto y = NDArrayFactory::create('c', {1, 2, 4, 4}); + auto z = NDArrayFactory::create('c', {1, 3, 2, 4, 4}); - std::vector dims = {0, 2, 3, 4}; - x.applyBroadcast(broadcast::LessThan, dims, y, z); + std::vector dims = {0, 2, 3, 4}; + x.applyBroadcast(broadcast::LessThan, dims, y, z); } TEST_F(DeclarableOpsTests9, test_broadcast_bool_2) { - auto orig = NDArrayFactory::create('c', {1, 7, 4, 4}); - std::vector list = {0,0, 0,2, 0,0, 0,0}; - auto x = NDArrayFactory::create('c', {1, 3, 2, 4, 4}); - - auto y = orig(list, true); + auto orig = NDArrayFactory::create('c', {1, 7, 4, 4}); + std::vector list = {0, 0, 0, 2, 0, 0, 0, 0}; + auto x = NDArrayFactory::create('c', {1, 3, 2, 4, 4}); - auto z = NDArrayFactory::create('c', {1, 3, 2, 4, 4}); + auto y = orig(list, true); - std::vector dims = {0, 2, 3, 4}; - x.applyBroadcast(broadcast::LessThan, dims, y, z); + auto z = NDArrayFactory::create('c', {1, 3, 2, 4, 4}); + std::vector dims = {0, 2, 3, 4}; + x.applyBroadcast(broadcast::LessThan, dims, y, z); } TEST_F(DeclarableOpsTests9, test_unstack_1) { - auto x = NDArrayFactory::create('c', {5, 5}); - x.linspace(1.0); - - sd::ops::unstack op; - auto result = op.evaluate({&x}, {}, {0}); - ASSERT_EQ(Status::OK(), result.status()); - ASSERT_EQ(5, result.size()); + auto x = NDArrayFactory::create('c', {5, 5}); + x.linspace(1.0); + sd::ops::unstack op; + auto result = op.evaluate({&x}, {}, {0}); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_EQ(5, result.size()); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, test_unstack_SGO_1) { - auto x = NDArrayFactory::create({1, 2, 3, 4, 5}); - x.linspace(1.0); - auto z1 = NDArrayFactory::create(1); - auto z2 = NDArrayFactory::create(2); - auto z3 = NDArrayFactory::create(3); - auto z4 = NDArrayFactory::create(4); - auto z5 = NDArrayFactory::create(5); - std::vector z({&z1, &z2, &z3, &z4, &z5}); - sd::ops::unstack op; - auto result = op.evaluate({&x}, {}, {0}); - ASSERT_EQ(Status::OK(), result.status()); - ASSERT_EQ(5, result.size()); - for (size_t i = 0; i < result.size(); i++) { - ASSERT_TRUE(result.at(i)->isSameShape(z[i])); - ASSERT_TRUE(result.at(i)->equalsTo(z[i])); - } - + auto x = NDArrayFactory::create({1, 2, 3, 4, 5}); + x.linspace(1.0); + auto z1 = NDArrayFactory::create(1); + auto z2 = NDArrayFactory::create(2); + auto z3 = NDArrayFactory::create(3); + auto z4 = NDArrayFactory::create(4); + auto z5 = NDArrayFactory::create(5); + std::vector z({&z1, &z2, &z3, &z4, &z5}); + sd::ops::unstack op; + auto result = op.evaluate({&x}, {}, {0}); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_EQ(5, result.size()); + for (size_t i = 0; i < result.size(); i++) { + ASSERT_TRUE(result.at(i).isSameShape(z[i])); + ASSERT_TRUE(result.at(i).equalsTo(z[i])); + } } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, cumprod_1) { - - auto inputC = NDArrayFactory::create('c', {3, 5}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.}); - auto axis = NDArrayFactory::create(1); - - auto expFF = NDArrayFactory::create('c', {3, 5}, {1., 2., 6., 24., 120., 6., 42., 336., 3024., 30240.,11., 132.,1716., 24024.,360360.}); - auto expTF = NDArrayFactory::create('c', {3, 5}, {1, 1, 2, 6, 24,1, 6, 42, 336, 3024,1, 11, 132, 1716, 24024}); - - auto expFT = NDArrayFactory::create('c', {3, 5}, {120, 120, 60, 20, 5,30240, 5040, 720, 90, 10,360360, 32760, 2730, 210, 15}); //+++ - auto expTT = NDArrayFactory::create('c', {3, 5}, {120, 60, 20, 5, 1,5040, 720, 90, 10, 1,32760, 2730, 210, 15, 1}); - - int exclusive, reverse; - - //************************************// - exclusive = 0; reverse = 0; - - sd::ops::cumprod op; - auto result = op.evaluate({&inputC, &axis}, {}, {exclusive, reverse}); - ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); - ASSERT_TRUE(expFF.equalsTo(z)); - - - //************************************// - exclusive = 1; reverse = 0; - - result = op.evaluate({&inputC, &axis}, {}, {exclusive, reverse}); - ASSERT_EQ(Status::OK(), result.status()); - z = result.at(0); - ASSERT_TRUE(expTF.equalsTo(z)); - - - //************************************// - exclusive = 0; reverse = 1; - - result = op.evaluate({&inputC, &axis}, {}, {exclusive, reverse}); - ASSERT_EQ(Status::OK(), result.status()); - z = result.at(0); - ASSERT_TRUE(expFT.equalsTo(z)); - - - //************************************// - exclusive = 1; reverse = 1; - - result = op.evaluate({&inputC, &axis}, {}, {exclusive, reverse}); - ASSERT_EQ(Status::OK(), result.status()); - z = result.at(0); - ASSERT_TRUE(expTT.equalsTo(z)); - - + auto inputC = NDArrayFactory::create( + 'c', {3, 5}, + {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.}); + auto axis = NDArrayFactory::create(1); + + auto expFF = NDArrayFactory::create( + 'c', {3, 5}, + {1., 2., 6., 24., 120., 6., 42., 336., 3024., 30240., 11., 132., 1716., + 24024., 360360.}); + auto expTF = NDArrayFactory::create( + 'c', {3, 5}, + {1, 1, 2, 6, 24, 1, 6, 42, 336, 3024, 1, 11, 132, 1716, 24024}); + + auto expFT = + NDArrayFactory::create('c', {3, 5}, + {120, 120, 60, 20, 5, 30240, 5040, 720, 90, + 10, 360360, 32760, 2730, 210, 15}); //+++ + auto expTT = NDArrayFactory::create( + 'c', {3, 5}, + {120, 60, 20, 5, 1, 5040, 720, 90, 10, 1, 32760, 2730, 210, 15, 1}); + + int exclusive, reverse; + + //************************************// + exclusive = 0; + reverse = 0; + + sd::ops::cumprod op; + auto result = op.evaluate({&inputC, &axis}, {}, {exclusive, reverse}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + ASSERT_TRUE(expFF.equalsTo(z)); + + //************************************// + exclusive = 1; + reverse = 0; + + result = op.evaluate({&inputC, &axis}, {}, {exclusive, reverse}); + ASSERT_EQ(Status::OK(), result.status()); + z = result.at(0); + ASSERT_TRUE(expTF.equalsTo(z)); + + //************************************// + exclusive = 0; + reverse = 1; + + result = op.evaluate({&inputC, &axis}, {}, {exclusive, reverse}); + ASSERT_EQ(Status::OK(), result.status()); + z = result.at(0); + ASSERT_TRUE(expFT.equalsTo(z)); + + //************************************// + exclusive = 1; + reverse = 1; + + result = op.evaluate({&inputC, &axis}, {}, {exclusive, reverse}); + ASSERT_EQ(Status::OK(), result.status()); + z = result.at(0); + ASSERT_TRUE(expTT.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, cumprod_2) { + NDArray x('c', {2, 1500}, sd::DataType::FLOAT32); + NDArray x0 = x(0, {0}); + NDArray x1 = x(1, {0}); + x0.linspace(1, 0.1); + x1.linspace(1, 0.1); - NDArray x('c', {2, 1500}, sd::DataType::FLOAT32); - NDArray x0 = x(0, {0}); - NDArray x1 = x(1, {0}); - x0.linspace(1, 0.1); - x1.linspace(1, 0.1); - - NDArray exp('c', {2, 1500}, sd::DataType::FLOAT32); - NDArray exp0 = exp(0, {0}); - NDArray exp1 = exp(1, {0}); + NDArray exp('c', {2, 1500}, sd::DataType::FLOAT32); + NDArray exp0 = exp(0, {0}); + NDArray exp1 = exp(1, {0}); - exp0.p(0, 1.f); - exp1.p(0, 1.f); + exp0.p(0, 1.f); + exp1.p(0, 1.f); - for (int i = 1; i < 1500; ++i) { - const auto prev = exp0.e(i-1); - exp0.p(i, prev * x0.e(i)); - exp1.p(i, prev * x1.e(i)); - } + for (int i = 1; i < 1500; ++i) { + const auto prev = exp0.e(i - 1); + exp0.p(i, prev * x0.e(i)); + exp1.p(i, prev * x1.e(i)); + } - sd::ops::cumprod op; - auto result = op.evaluate({&x}, {}, {0, 0, 1}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::cumprod op; + auto result = op.evaluate({&x}, {}, {0, 0, 1}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, cumprod_bp_check_1) { + auto x = NDArrayFactory::create('c', {4, 4}); + auto gradO = NDArrayFactory::create('c', {4, 4}); - auto x = NDArrayFactory::create('c', {4, 4}); - auto gradO = NDArrayFactory::create('c', {4, 4}); - - x.linspace(1); + x.linspace(1); - const OpArgsHolder argsHolderFF({&x}, {}, {0, 0}); - const OpArgsHolder argsHolderBP({&x, &gradO}, {}, {0, 0}); + const OpArgsHolder argsHolderFF({&x}, {}, {0, 0}); + const OpArgsHolder argsHolderBP({&x, &gradO}, {}, {0, 0}); - sd::ops::cumprod opFF; - sd::ops::cumprod_bp opBP; + sd::ops::cumprod opFF; + sd::ops::cumprod_bp opBP; - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {1, 1}, {1, 1},GradCheck::MEAN); + const bool isGradCorrect = GradCheck::checkGrad( + opFF, opBP, argsHolderFF, argsHolderBP, {1, 1}, {1, 1}, GradCheck::MEAN); - ASSERT_TRUE(isGradCorrect); + ASSERT_TRUE(isGradCorrect); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, cumprod_bp_check_2) { + auto x = NDArrayFactory::create('c', {4, 4}); + auto gradO = NDArrayFactory::create('c', {4, 4}); - auto x = NDArrayFactory::create('c', {4, 4}); - auto gradO = NDArrayFactory::create('c', {4, 4}); - - x.linspace(1); + x.linspace(1); - const OpArgsHolder argsHolderFF({&x}, {}, {1, 1}); - const OpArgsHolder argsHolderBP({&x, &gradO}, {}, {1, 1}); + const OpArgsHolder argsHolderFF({&x}, {}, {1, 1}); + const OpArgsHolder argsHolderBP({&x, &gradO}, {}, {1, 1}); - sd::ops::cumprod opFF; - sd::ops::cumprod_bp opBP; + sd::ops::cumprod opFF; + sd::ops::cumprod_bp opBP; - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {1, 1}, {1, 1},GradCheck::MEAN); + const bool isGradCorrect = GradCheck::checkGrad( + opFF, opBP, argsHolderFF, argsHolderBP, {1, 1}, {1, 1}, GradCheck::MEAN); - ASSERT_TRUE(isGradCorrect); + ASSERT_TRUE(isGradCorrect); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, cumprod_bp_check_3) { + auto x = NDArrayFactory::create('c', {4, 4}); + auto gradO = NDArrayFactory::create('c', {4, 4}); - auto x = NDArrayFactory::create('c', {4, 4}); - auto gradO = NDArrayFactory::create('c', {4, 4}); - - x.linspace(1); + x.linspace(1); - const OpArgsHolder argsHolderFF({&x}, {}, {1, 0}); - const OpArgsHolder argsHolderBP({&x, &gradO}, {}, {1, 0}); + const OpArgsHolder argsHolderFF({&x}, {}, {1, 0}); + const OpArgsHolder argsHolderBP({&x, &gradO}, {}, {1, 0}); - sd::ops::cumprod opFF; - sd::ops::cumprod_bp opBP; + sd::ops::cumprod opFF; + sd::ops::cumprod_bp opBP; - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {1, 1}, {1, 1},GradCheck::MEAN); + const bool isGradCorrect = GradCheck::checkGrad( + opFF, opBP, argsHolderFF, argsHolderBP, {1, 1}, {1, 1}, GradCheck::MEAN); - ASSERT_TRUE(isGradCorrect); + ASSERT_TRUE(isGradCorrect); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, cumprod_bp_check_4) { + auto x = NDArrayFactory::create('c', {4, 4}); + auto gradO = NDArrayFactory::create('c', {4, 4}); - auto x = NDArrayFactory::create('c', {4, 4}); - auto gradO = NDArrayFactory::create('c', {4, 4}); + x.linspace(1); - x.linspace(1); + const OpArgsHolder argsHolderFF({&x}, {}, {0, 1}); + const OpArgsHolder argsHolderBP({&x, &gradO}, {}, {0, 1}); - const OpArgsHolder argsHolderFF({&x}, {}, {0, 1}); - const OpArgsHolder argsHolderBP({&x, &gradO}, {}, {0, 1}); + sd::ops::cumprod opFF; + sd::ops::cumprod_bp opBP; - sd::ops::cumprod opFF; - sd::ops::cumprod_bp opBP; + const bool isGradCorrect = GradCheck::checkGrad( + opFF, opBP, argsHolderFF, argsHolderBP, {1, 1}, {1, 1}, GradCheck::MEAN); - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {1, 1}, {1, 1},GradCheck::MEAN); - - ASSERT_TRUE(isGradCorrect); + ASSERT_TRUE(isGradCorrect); } - ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, cumsum_bp_check_2) { + auto x = NDArrayFactory::create('c', {4, 4}); + auto gradO = NDArrayFactory::create('c', {4, 4}); - auto x = NDArrayFactory::create('c', {4, 4}); - auto gradO = NDArrayFactory::create('c', {4, 4}); + x.linspace(1); - x.linspace(1); + const OpArgsHolder argsHolderFF({&x}, {}, {1, 1}); + const OpArgsHolder argsHolderBP({&x, &gradO}, {}, {1, 1}); - const OpArgsHolder argsHolderFF({&x}, {}, {1, 1}); - const OpArgsHolder argsHolderBP({&x, &gradO}, {}, {1, 1}); + sd::ops::cumsum opFF; + sd::ops::cumsum_bp opBP; - sd::ops::cumsum opFF; - sd::ops::cumsum_bp opBP; + const bool isGradCorrect = GradCheck::checkGrad( + opFF, opBP, argsHolderFF, argsHolderBP, {1, 1}, {1, 1}, GradCheck::MEAN); - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {1, 1}, {1, 1},GradCheck::MEAN); - - ASSERT_TRUE(isGradCorrect); + ASSERT_TRUE(isGradCorrect); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, cumprod_test1) { + auto inputC = NDArrayFactory::create( + 'c', {3, 5}, + {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.}); + auto axis = NDArrayFactory::create(1.); - auto inputC = NDArrayFactory::create('c', {3, 5}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.}); - auto axis = NDArrayFactory::create(1.); - - auto expFF = NDArrayFactory::create('c', {3, 5}, {1., 2., 6., 24., 120., 6., 42., 336., 3024., 30240.,11., 132.,1716., 24024.,360360.}); - auto expTF = NDArrayFactory::create('c', {3, 5}, {1, 1, 2, 6, 24,1, 6, 42, 336, 3024,1, 11, 132, 1716, 24024}); + auto expFF = NDArrayFactory::create( + 'c', {3, 5}, + {1., 2., 6., 24., 120., 6., 42., 336., 3024., 30240., 11., 132., 1716., + 24024., 360360.}); + auto expTF = NDArrayFactory::create( + 'c', {3, 5}, + {1, 1, 2, 6, 24, 1, 6, 42, 336, 3024, 1, 11, 132, 1716, 24024}); - auto expFT = NDArrayFactory::create('c', {3, 5}, {120, 120, 60, 20, 5,30240, 5040, 720, 90, 10,360360, 32760, 2730, 210, 15}); //+++ - auto expTT = NDArrayFactory::create('c', {3, 5}, {120, 60, 20, 5, 1,5040, 720, 90, 10, 1,32760, 2730, 210, 15, 1}); - auto gradO = NDArrayFactory::create('c', {3, 5}); + auto expFT = + NDArrayFactory::create('c', {3, 5}, + {120, 120, 60, 20, 5, 30240, 5040, 720, 90, + 10, 360360, 32760, 2730, 210, 15}); //+++ + auto expTT = NDArrayFactory::create( + 'c', {3, 5}, + {120, 60, 20, 5, 1, 5040, 720, 90, 10, 1, 32760, 2730, 210, 15, 1}); + auto gradO = NDArrayFactory::create('c', {3, 5}); - int exclusive, reverse; + int exclusive, reverse; - //************************************// - exclusive = 0; reverse = 0; + //************************************// + exclusive = 0; + reverse = 0; - const OpArgsHolder argsHolderFF({&inputC, &axis}, {}, {exclusive, reverse}); - const OpArgsHolder argsHolderBP({&inputC, &axis, &gradO}, {}, {exclusive, reverse}); + const OpArgsHolder argsHolderFF({&inputC, &axis}, {}, {exclusive, reverse}); + const OpArgsHolder argsHolderBP({&inputC, &axis, &gradO}, {}, + {exclusive, reverse}); - sd::ops::cumprod opFF; - sd::ops::cumprod_bp opBP; + sd::ops::cumprod opFF; + sd::ops::cumprod_bp opBP; - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {1, 1}, {1, 1},GradCheck::MEAN); + const bool isGradCorrect = GradCheck::checkGrad( + opFF, opBP, argsHolderFF, argsHolderBP, {1, 1}, {1, 1}, GradCheck::MEAN); - ASSERT_TRUE(isGradCorrect); + ASSERT_TRUE(isGradCorrect); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, cumprod_test2) { + auto inputC = NDArrayFactory::create('c', {2, 2}); + auto axis = NDArrayFactory::create(1.); - auto inputC = NDArrayFactory::create('c', {2, 2}); - auto axis = NDArrayFactory::create(1.); - - auto gradO = NDArrayFactory::create('c', {2, 2}); + auto gradO = NDArrayFactory::create('c', {2, 2}); - int exclusive, reverse; + int exclusive, reverse; - //************************************// - exclusive = 0; reverse = 0; - inputC.linspace(1); - const OpArgsHolder argsHolderFF({&inputC, &axis}, {}, {exclusive, reverse}); - const OpArgsHolder argsHolderBP({&inputC, &axis, &gradO}, {}, {exclusive, reverse}); + //************************************// + exclusive = 0; + reverse = 0; + inputC.linspace(1); + const OpArgsHolder argsHolderFF({&inputC, &axis}, {}, {exclusive, reverse}); + const OpArgsHolder argsHolderBP({&inputC, &axis, &gradO}, {}, + {exclusive, reverse}); - sd::ops::cumprod opFF; - sd::ops::cumprod_bp opBP; + sd::ops::cumprod opFF; + sd::ops::cumprod_bp opBP; - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {1, 1, 1, 1}, {1, 1},GradCheck::MEAN); + const bool isGradCorrect = + GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {1, 1, 1, 1}, + {1, 1}, GradCheck::MEAN); - ASSERT_TRUE(isGradCorrect); + ASSERT_TRUE(isGradCorrect); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, prelu_test1) { + auto x = NDArrayFactory::create( + 'c', {2, 3, 4}, {-12.f, -11.f, -10.f, -9.f, -8.f, -7.f, -6.f, -5.f, + -4.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f, + 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); + auto alpha = + NDArrayFactory::create('c', {3, 4}, + {-0.6f, -0.5f, -0.4f, -0.3f, -0.2f, -0.1f, + 0.f, 0.1f, 0.2f, 0.3f, 0.4f, 0.5f}); + auto exp = NDArrayFactory::create( + 'c', {2, 3, 4}, {7.2f, 5.5f, 4.f, 2.7f, 1.6f, 0.7f, 0.f, -0.5f, + -0.8f, -0.9f, -0.8f, -0.5f, 0.f, 1.f, 2.f, 3.f, + 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); - auto x = NDArrayFactory::create('c', {2, 3, 4}, {-12.f, -11.f, -10.f, -9.f, -8.f, -7.f, -6.f, -5.f, -4.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); - auto alpha = NDArrayFactory::create('c', {3, 4}, {-0.6f, -0.5f, -0.4f, -0.3f, -0.2f, -0.1f, 0.f, 0.1f, 0.2f, 0.3f, 0.4f, 0.5f}); - auto exp = NDArrayFactory::create('c', {2, 3, 4}, {7.2f, 5.5f, 4.f, 2.7f, 1.6f, 0.7f, 0.f, -0.5f,-0.8f, -0.9f, -0.8f, -0.5f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); - - sd::ops::prelu op; - - auto result = op.evaluate({&x, &alpha}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::prelu op; + auto result = op.evaluate({&x, &alpha}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, prelu_test2) { + auto x = NDArrayFactory::create( + 'c', {2, 3, 4}, {-12.f, -11.f, -10.f, -9.f, -8.f, -7.f, -6.f, -5.f, + -4.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f, + 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); + auto alpha = NDArrayFactory::create('c', {3}, {-0.6f, 2.f, 4.f}); + auto exp = NDArrayFactory::create( + 'c', {2, 3, 4}, {7.2f, 6.6f, 6.f, 5.4f, -16.f, -14.f, -12.f, -10.f, + -16.f, -12.f, -8.f, -4.f, 0.f, 1.f, 2.f, 3.f, + 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); - auto x = NDArrayFactory::create('c', {2, 3, 4}, {-12.f, -11.f, -10.f, -9.f, -8.f, -7.f, -6.f, -5.f, -4.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); - auto alpha = NDArrayFactory::create('c', {3}, {-0.6f, 2.f, 4.f}); - auto exp = NDArrayFactory::create('c', {2, 3, 4}, {7.2f, 6.6f, 6.f, 5.4f, -16.f, -14.f, -12.f, -10.f, -16.f, -12.f, -8.f, -4.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); - - sd::ops::prelu op; - auto result = op.evaluate({&x, &alpha}, {}, {0}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - + sd::ops::prelu op; + auto result = op.evaluate({&x, &alpha}, {}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, prelu_test3) { + auto x = NDArrayFactory::create( + 'c', {2, 3, 4}, {-12.f, -11.f, -10.f, -9.f, -8.f, -7.f, -6.f, -5.f, + -4.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f, + 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); + auto alpha = NDArrayFactory::create('c', {3, 1}, {-0.6f, 2.f, 4.f}); + auto exp = NDArrayFactory::create( + 'c', {2, 3, 4}, {7.2f, 6.6f, 6.f, 5.4f, -16.f, -14.f, -12.f, -10.f, + -16.f, -12.f, -8.f, -4.f, 0.f, 1.f, 2.f, 3.f, + 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); - auto x = NDArrayFactory::create('c', {2, 3, 4}, {-12.f, -11.f, -10.f, -9.f, -8.f, -7.f, -6.f, -5.f, -4.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); - auto alpha = NDArrayFactory::create('c', {3,1}, {-0.6f, 2.f, 4.f}); - auto exp = NDArrayFactory::create('c', {2, 3, 4}, {7.2f, 6.6f, 6.f, 5.4f, -16.f, -14.f, -12.f, -10.f, -16.f, -12.f, -8.f, -4.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); - - sd::ops::prelu op; - auto result = op.evaluate({&x, &alpha}, {}, {0}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - + sd::ops::prelu op; + auto result = op.evaluate({&x, &alpha}, {}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, prelu_test4) { + auto x = NDArrayFactory::create( + 'c', {2, 3, 4}, {-12.f, -11.f, -10.f, -9.f, -8.f, -7.f, -6.f, -5.f, + -4.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f, + 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); + auto alpha = NDArrayFactory::create('c', {1, 3}, {-0.6f, 2.f, 4.f}); + auto exp = NDArrayFactory::create( + 'c', {2, 3, 4}, {7.2f, 6.6f, 6.f, 5.4f, -16.f, -14.f, -12.f, -10.f, + -16.f, -12.f, -8.f, -4.f, 0.f, 1.f, 2.f, 3.f, + 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); - auto x = NDArrayFactory::create('c', {2, 3, 4}, {-12.f, -11.f, -10.f, -9.f, -8.f, -7.f, -6.f, -5.f, -4.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); - auto alpha = NDArrayFactory::create('c', {1, 3}, {-0.6f, 2.f, 4.f}); - auto exp = NDArrayFactory::create('c', {2, 3, 4}, {7.2f, 6.6f, 6.f, 5.4f, -16.f, -14.f, -12.f, -10.f, -16.f, -12.f, -8.f, -4.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); - - sd::ops::prelu op; - auto result = op.evaluate({&x, &alpha}, {}, {0}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - + sd::ops::prelu op; + auto result = op.evaluate({&x, &alpha}, {}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, prelu_test5) { + auto x = NDArrayFactory::create( + 'c', {2, 3, 4}, {-12.f, -11.f, -10.f, -9.f, -8.f, -7.f, -6.f, -5.f, + -4.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f, + 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); + auto alpha = + NDArrayFactory::create('c', {4}, {-0.6f, 2.f, 4.f, -1.f}); + auto exp = NDArrayFactory::create( + 'c', {2, 3, 4}, {7.2f, -22.f, -40.f, 9.f, 4.8f, -14.f, -24.f, 5.f, + 2.4f, -6.f, -8.f, 1.f, 0.f, 1.f, 2.f, 3.f, + 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); - auto x = NDArrayFactory::create('c', {2, 3, 4}, {-12.f, -11.f, -10.f, -9.f, -8.f, -7.f, -6.f, -5.f, -4.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); - auto alpha = NDArrayFactory::create('c', {4}, {-0.6f, 2.f, 4.f, -1.f}); - auto exp = NDArrayFactory::create('c', {2, 3, 4}, {7.2f, -22.f, -40.f, 9.f, 4.8f, -14.f, -24.f, 5.f, 2.4f, -6.f, -8.f, 1.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); - - sd::ops::prelu op; - auto result = op.evaluate({&x, &alpha}, {}, {1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - + sd::ops::prelu op; + auto result = op.evaluate({&x, &alpha}, {}, {1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, prelu_test6) { + auto x = NDArrayFactory::create( + 'c', {2, 3, 4}, {-12.f, -11.f, -10.f, -9.f, -8.f, -7.f, -6.f, -5.f, + -4.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f, + 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); + auto alpha = NDArrayFactory::create('c', {1, 1, 1}, {-2.}); + auto exp = NDArrayFactory::create( + 'c', {2, 3, 4}, + {24.f, 22.f, 20.f, 18.f, 16.f, 14.f, 12.f, 10.f, 8.f, 6.f, 4.f, 2.f, + 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); - auto x = NDArrayFactory::create('c', {2, 3, 4}, {-12.f, -11.f, -10.f, -9.f, -8.f, -7.f, -6.f, -5.f, -4.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); - auto alpha = NDArrayFactory::create('c', {1,1,1}, {-2.}); - auto exp = NDArrayFactory::create('c', {2, 3, 4}, {24.f, 22.f, 20.f, 18.f, 16.f, 14.f, 12.f, 10.f, 8.f, 6.f, 4.f, 2.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); - - sd::ops::prelu op; - auto result = op.evaluate({&x, &alpha}, {}, {1,0}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - + sd::ops::prelu op; + auto result = op.evaluate({&x, &alpha}, {}, {1, 0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } - //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, prelu_test7) { + auto x = NDArrayFactory::create( + 'c', {2, 3, 4}, {-12.f, -11.f, -10.f, -9.f, -8.f, -7.f, -6.f, -5.f, + -4.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f, + 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); + auto alpha = NDArrayFactory::create(-2.f); + auto exp = NDArrayFactory::create( + 'c', {2, 3, 4}, + {24.f, 22.f, 20.f, 18.f, 16.f, 14.f, 12.f, 10.f, 8.f, 6.f, 4.f, 2.f, + 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); - auto x = NDArrayFactory::create('c', {2, 3, 4}, {-12.f, -11.f, -10.f, -9.f, -8.f, -7.f, -6.f, -5.f, -4.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); - auto alpha = NDArrayFactory::create(-2.f); - auto exp = NDArrayFactory::create('c', {2, 3, 4}, {24.f, 22.f, 20.f, 18.f, 16.f, 14.f, 12.f, 10.f, 8.f, 6.f, 4.f, 2.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); - - sd::ops::prelu op; - auto result = op.evaluate({&x, &alpha}, {}, {1,0}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - + sd::ops::prelu op; + auto result = op.evaluate({&x, &alpha}, {}, {1, 0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, prelu_test8) { + auto x = NDArrayFactory::create( + 'c', {2, 3, 4}, {-12.f, -11.f, -10.f, -9.f, -8.f, -7.f, -6.f, -5.f, + -4.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f, + 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); + auto alpha = NDArrayFactory::create(-2.f); + auto exp = NDArrayFactory::create( + 'c', {2, 3, 4}, + {24.f, 22.f, 20.f, 18.f, 16.f, 14.f, 12.f, 10.f, 8.f, 6.f, 4.f, 2.f, + 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); - auto x = NDArrayFactory::create('c', {2, 3, 4}, {-12.f, -11.f, -10.f, -9.f, -8.f, -7.f, -6.f, -5.f, -4.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); - auto alpha = NDArrayFactory::create(-2.f); - auto exp = NDArrayFactory::create('c', {2, 3, 4}, {24.f, 22.f, 20.f, 18.f, 16.f, 14.f, 12.f, 10.f, 8.f, 6.f, 4.f, 2.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); - - sd::ops::prelu op; - auto result = op.evaluate({&x, &alpha}, {}, {1,0,1,0,1,0}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - + sd::ops::prelu op; + auto result = op.evaluate({&x, &alpha}, {}, {1, 0, 1, 0, 1, 0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, prelu_test9) { + auto x = NDArrayFactory::create( + 'c', {2, 4}, {-4.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f}); + auto alpha = NDArrayFactory::create(-2.f); + auto exp = NDArrayFactory::create( + 'c', {2, 4}, {8.f, 6.f, 4.f, 2.f, 0.f, 1.f, 2.f, 3.f}); - auto x = NDArrayFactory::create('c', {2, 4}, {-4.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f}); - auto alpha = NDArrayFactory::create(-2.f); - auto exp = NDArrayFactory::create('c', {2, 4}, {8.f, 6.f, 4.f, 2.f,0.f, 1.f, 2.f, 3.f}); - - sd::ops::prelu op; - auto result = op.evaluate({&x, &alpha}, {}, {0}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - + sd::ops::prelu op; + auto result = op.evaluate({&x, &alpha}, {}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, prelu_test10) { + auto x = NDArrayFactory::create( + 'c', {2, 4}, {-4.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f}); + auto alpha = NDArrayFactory::create(-2.f); + auto exp = NDArrayFactory::create( + 'c', {2, 4}, {8.f, 6.f, 4.f, 2.f, 0.f, 1.f, 2.f, 3.f}); - auto x = NDArrayFactory::create('c', {2, 4}, {-4.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f}); - auto alpha = NDArrayFactory::create(-2.f); - auto exp = NDArrayFactory::create('c', {2, 4}, {8.f, 6.f, 4.f, 2.f,0.f, 1.f, 2.f, 3.f}); - - sd::ops::prelu op; - auto result = op.evaluate({&x, &alpha}, {}, {1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - + sd::ops::prelu op; + auto result = op.evaluate({&x, &alpha}, {}, {1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, prelu_test11) { - - auto x = NDArrayFactory::create('c', {2, 3, 4, 5}); - x.linspace(-50.); - auto alpha = NDArrayFactory::create('c', {4}, {0.f, -0.5f, 0.5f, -1.f}); - auto exp = NDArrayFactory::create('c', {2, 3, 4, 5}, {0.f, 0.f, 0.f, 0.f, 0.f, 22.5f, 22.f, 21.5f, 21.f, 20.5f, -20.f, -19.5f, -19.f, -18.5f, -18.f, 35.f, 34.f, 33.f, - 32.f, 31.f, 0.f, 0.f, 0.f, 0.f, 0.f, 12.5f, 12.f, 11.5f, 11.f, 10.5f, -10.f, -9.5f, -9.f, -8.5f, -8.f, 15.f, - 14.f, 13.f, 12.f, 11.f, 0.f, 0.f, 0.f, 0.f, 0.f, 2.5f, 2.f, 1.5f, 1.f, 0.5f, 0.f, 1.f, 2.f, 3.f, 4.f, - 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, 21.f, 22.f, 23.f, - 24.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, - 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 61.f, - 62.f, 63.f, 64.f, 65.f, 66.f, 67.f, 68.f, 69.f}); - - sd::ops::prelu op; - auto result = op.evaluate({&x, &alpha}, {}, {1,3}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + auto x = NDArrayFactory::create('c', {2, 3, 4, 5}); + x.linspace(-50.); + auto alpha = + NDArrayFactory::create('c', {4}, {0.f, -0.5f, 0.5f, -1.f}); + auto exp = NDArrayFactory::create( + 'c', {2, 3, 4, 5}, + {0.f, 0.f, 0.f, 0.f, 0.f, 22.5f, 22.f, 21.5f, 21.f, 20.5f, + -20.f, -19.5f, -19.f, -18.5f, -18.f, 35.f, 34.f, 33.f, 32.f, 31.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 12.5f, 12.f, 11.5f, 11.f, 10.5f, + -10.f, -9.5f, -9.f, -8.5f, -8.f, 15.f, 14.f, 13.f, 12.f, 11.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 2.5f, 2.f, 1.5f, 1.f, 0.5f, + 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, + 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f, + 20.f, 21.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 28.f, 29.f, + 30.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 37.f, 38.f, 39.f, + 40.f, 41.f, 42.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 49.f, + 50.f, 51.f, 52.f, 53.f, 54.f, 55.f, 56.f, 57.f, 58.f, 59.f, + 60.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 67.f, 68.f, 69.f}); + + sd::ops::prelu op; + auto result = op.evaluate({&x, &alpha}, {}, {1, 3}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, prelu_test12) { - - auto x = NDArrayFactory::create('c', {2, 3, 4, 5}); - x.linspace(-50.); - auto alpha = NDArrayFactory::create('c', {3,5}, {-0.7f, -0.6f, -0.5f, -0.4f, -0.3f, -0.2f, -0.1f, 0.f, 0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f}); - auto exp = NDArrayFactory::create('c', {2, 3, 4, 5}, {35.f, 29.4f, 24.f, 18.8f, 13.8f, 31.5f, 26.4f, 21.5f, 16.8f, 12.3f, 28.f, 23.4f, 19.f, 14.8f, 10.8f, 24.5f, 20.4f, 16.5f, 12.8f, - 9.3f, 6.f, 2.9f, 0.f, -2.7f, -5.2f, 5.f, 2.4f, 0.f, -2.2f, -4.2f, 4.f, 1.9f, 0.f, -1.7f, -3.2f, 3.f, 1.4f, 0.f, -1.2f, - -2.2f, -3.f, -3.6f, -4.f, -4.2f, -4.2f, -1.5f, -1.6f, -1.5f, -1.2f, -0.7f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, - 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, - 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 49.f, 50.f, 51.f, 52.f, - 53.f, 54.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 67.f, 68.f, 69.f}); - - sd::ops::prelu op; - auto result = op.evaluate({&x, &alpha}, {}, {-1, 2}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + auto x = NDArrayFactory::create('c', {2, 3, 4, 5}); + x.linspace(-50.); + auto alpha = NDArrayFactory::create( + 'c', {3, 5}, + {-0.7f, -0.6f, -0.5f, -0.4f, -0.3f, -0.2f, -0.1f, 0.f, 0.1f, 0.2f, 0.3f, + 0.4f, 0.5f, 0.6f, 0.7f}); + auto exp = NDArrayFactory::create( + 'c', {2, 3, 4, 5}, + {35.f, 29.4f, 24.f, 18.8f, 13.8f, 31.5f, 26.4f, 21.5f, 16.8f, 12.3f, + 28.f, 23.4f, 19.f, 14.8f, 10.8f, 24.5f, 20.4f, 16.5f, 12.8f, 9.3f, + 6.f, 2.9f, 0.f, -2.7f, -5.2f, 5.f, 2.4f, 0.f, -2.2f, -4.2f, + 4.f, 1.9f, 0.f, -1.7f, -3.2f, 3.f, 1.4f, 0.f, -1.2f, -2.2f, + -3.f, -3.6f, -4.f, -4.2f, -4.2f, -1.5f, -1.6f, -1.5f, -1.2f, -0.7f, + 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, + 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f, + 20.f, 21.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 28.f, 29.f, + 30.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 37.f, 38.f, 39.f, + 40.f, 41.f, 42.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 49.f, + 50.f, 51.f, 52.f, 53.f, 54.f, 55.f, 56.f, 57.f, 58.f, 59.f, + 60.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 67.f, 68.f, 69.f}); + + sd::ops::prelu op; + auto result = op.evaluate({&x, &alpha}, {}, {-1, 2}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, prelu_test13) { - - auto x = NDArrayFactory::create('c', {2, 3, 4, 5}); - x.linspace(-50.); - auto alpha = NDArrayFactory::create('c', {5,3}, {-0.7f, -0.6f, -0.5f, -0.4f, -0.3f, -0.2f, -0.1f, 0.f, 0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f}); - auto exp = NDArrayFactory::create('c', {2, 3, 4, 5}, {35.f, 29.4f, 24.f, 18.8f, 13.8f, 31.5f, 26.4f, 21.5f, 16.8f, 12.3f, 28.f, 23.4f, 19.f, 14.8f, 10.8f, 24.5f, 20.4f, 16.5f, 12.8f, - 9.3f, 6.f, 2.9f, 0.f, -2.7f, -5.2f, 5.f, 2.4f, 0.f, -2.2f, -4.2f, 4.f, 1.9f, 0.f, -1.7f, -3.2f, 3.f, 1.4f, 0.f, -1.2f, - -2.2f, -3.f, -3.6f, -4.f, -4.2f, -4.2f, -1.5f, -1.6f, -1.5f, -1.2f, -0.7f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, - 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, - 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 49.f, 50.f, 51.f, 52.f, - 53.f, 54.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 67.f, 68.f, 69.f}); - - sd::ops::prelu op; - auto result = op.evaluate({&x, &alpha}, {}, {-1, 2}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + auto x = NDArrayFactory::create('c', {2, 3, 4, 5}); + x.linspace(-50.); + auto alpha = NDArrayFactory::create( + 'c', {5, 3}, + {-0.7f, -0.6f, -0.5f, -0.4f, -0.3f, -0.2f, -0.1f, 0.f, 0.1f, 0.2f, 0.3f, + 0.4f, 0.5f, 0.6f, 0.7f}); + auto exp = NDArrayFactory::create( + 'c', {2, 3, 4, 5}, + {35.f, 29.4f, 24.f, 18.8f, 13.8f, 31.5f, 26.4f, 21.5f, 16.8f, 12.3f, + 28.f, 23.4f, 19.f, 14.8f, 10.8f, 24.5f, 20.4f, 16.5f, 12.8f, 9.3f, + 6.f, 2.9f, 0.f, -2.7f, -5.2f, 5.f, 2.4f, 0.f, -2.2f, -4.2f, + 4.f, 1.9f, 0.f, -1.7f, -3.2f, 3.f, 1.4f, 0.f, -1.2f, -2.2f, + -3.f, -3.6f, -4.f, -4.2f, -4.2f, -1.5f, -1.6f, -1.5f, -1.2f, -0.7f, + 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, + 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f, + 20.f, 21.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 28.f, 29.f, + 30.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 37.f, 38.f, 39.f, + 40.f, 41.f, 42.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 49.f, + 50.f, 51.f, 52.f, 53.f, 54.f, 55.f, 56.f, 57.f, 58.f, 59.f, + 60.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 67.f, 68.f, 69.f}); + + sd::ops::prelu op; + auto result = op.evaluate({&x, &alpha}, {}, {-1, 2}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, prelu_test14) { - - auto x = NDArrayFactory::create('c', {2, 3, 4, 5}); - x.linspace(-50.); - auto alpha = NDArrayFactory::create('c', {2,10}, {-0.7f, -0.6f, -0.5f, -0.4f, -0.3f, -0.2f, -0.1f, 0.f, 0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f, 0.8f, 0.9f, 1.f, 1.1f, 1.2f}); - auto exp = NDArrayFactory::create('c', {2, 3, 4, 5}, {35.f, 29.4f, 24.f, 18.8f, 13.8f, 9.f, 4.4f, 0.f, -4.2f, -8.2f, -12.f, -15.6f, -19.f, -22.2f, -25.2f, -28.f, -30.6f, - -33.f,-35.2f, -37.2f, 21.f, 17.4f, 14.f, 10.8f, 7.8f, 5.f, 2.4f, 0.f, -2.2f, -4.2f, -6.f, -7.6f, -9.f, -10.2f, - -11.2f, -12.f, -12.6f, -13.f, -13.2f, -13.2f, 7.f, 5.4f, 4.f, 2.8f, 1.8f, 1.f, 0.4f, 0.f, -0.2f, -0.2f, 0.f, - 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, - 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, - 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, - 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 67.f, 68.f, 69.f}); - - sd::ops::prelu op; - auto result = op.evaluate({&x, &alpha}, {}, {-2}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + auto x = NDArrayFactory::create('c', {2, 3, 4, 5}); + x.linspace(-50.); + auto alpha = NDArrayFactory::create( + 'c', {2, 10}, + {-0.7f, -0.6f, -0.5f, -0.4f, -0.3f, -0.2f, -0.1f, 0.f, 0.1f, 0.2f, + 0.3f, 0.4f, 0.5f, 0.6f, 0.7f, 0.8f, 0.9f, 1.f, 1.1f, 1.2f}); + auto exp = NDArrayFactory::create( + 'c', {2, 3, 4, 5}, + {35.f, 29.4f, 24.f, 18.8f, 13.8f, 9.f, 4.4f, 0.f, -4.2f, + -8.2f, -12.f, -15.6f, -19.f, -22.2f, -25.2f, -28.f, -30.6f, -33.f, + -35.2f, -37.2f, 21.f, 17.4f, 14.f, 10.8f, 7.8f, 5.f, 2.4f, + 0.f, -2.2f, -4.2f, -6.f, -7.6f, -9.f, -10.2f, -11.2f, -12.f, + -12.6f, -13.f, -13.2f, -13.2f, 7.f, 5.4f, 4.f, 2.8f, 1.8f, + 1.f, 0.4f, 0.f, -0.2f, -0.2f, 0.f, 1.f, 2.f, 3.f, + 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, + 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, 21.f, + 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, + 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 37.f, 38.f, 39.f, + 40.f, 41.f, 42.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, + 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 55.f, 56.f, 57.f, + 58.f, 59.f, 60.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, + 67.f, 68.f, 69.f}); + + sd::ops::prelu op; + auto result = op.evaluate({&x, &alpha}, {}, {-2}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, thresholdedrelu_test1) { + const float theta = 2.f; + auto x = NDArrayFactory::create( + 'c', {2, 3, 4}, {-12.f, -11.f, -10.f, -9.f, -8.f, -7.f, -6.f, -5.f, + -4.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f, + 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); + auto exp = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); - const float theta = 2.f; - auto x = NDArrayFactory::create('c', {2, 3, 4}, {-12.f, -11.f, -10.f, -9.f, -8.f, -7.f, -6.f, -5.f, -4.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); - auto exp = NDArrayFactory::create('c', {2, 3, 4}, {0.f, 0.f, 0.f, 0.f,0.f, 0.f, 0.f, 0.f,0.f, 0.f, 0.f, 0.f,0.f, 0.f, 0.f, 3.f,4.f, 5.f, 6.f, 7.f,8.f, 9.f,10.f,11.f}); - - sd::ops::thresholdedrelu op; - - auto result = op.evaluate({&x}, {theta}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::thresholdedrelu op; + auto result = op.evaluate({&x}, {theta}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, compare_and_bitpack_test1) { + auto x = NDArrayFactory::create( + 'c', {2, 3, 4}, {-12.f, -11.f, -10.f, -9.f, -8.f, -7.f, -6.f, -5.f, + -4.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f, + 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); + auto threshold = NDArrayFactory::create(2.0); + auto exp = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1}); - auto x = NDArrayFactory::create('c', {2, 3, 4}, {-12.f, -11.f, -10.f, -9.f, -8.f, -7.f, -6.f, -5.f, -4.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); - auto threshold = NDArrayFactory::create(2.0); - auto exp = NDArrayFactory::create('c', {2, 3, 4}, {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1}); - - sd::ops::compare_and_bitpack op; - - auto result = op.evaluate({&x, &threshold}, {}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); -// output->printIndexedBuffer("Packed to uint8"); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::compare_and_bitpack op; + auto result = op.evaluate({&x, &threshold}, {}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + // output->printIndexedBuffer("Packed to uint8"); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, thresholdedrelu_test2) { + const float theta = -2.f; + auto x = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0.f, -4.f, -10.f, -8.f, 0.f, -9.f, -8.f, 5.f, 6.f, 6.f, 9.f, 6.f, + -8.f, 5.f, 10.f, -2.f, 3.f, -7.f, 4.f, -8.f, -4.f, -9.f, -9.f, 3.f}); + auto exp = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 5.f, 6.f, 6.f, 9.f, 6.f, + 0.f, 5.f, 10.f, 0.f, 3.f, 0.f, 4.f, 0.f, 0.f, 0.f, 0.f, 3.f}); - const float theta = -2.f; - auto x = NDArrayFactory::create('c', {2, 3, 4}, {0.f,-4.f, -10.f, -8.f, 0.f, -9.f, -8.f, 5.f, 6.f, 6.f, 9.f, 6.f, -8.f, 5.f, 10.f, -2.f, 3.f, -7.f, 4.f, -8.f, -4.f, -9.f, -9.f, 3.f}); - auto exp = NDArrayFactory::create('c', {2, 3, 4}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 5.f, 6.f, 6.f, 9.f, 6.f, 0.f, 5.f, 10.f, 0.f, 3.f, 0.f, 4.f, 0.f, 0.f, 0.f, 0.f, 3.f}); - - sd::ops::thresholdedrelu op; - - auto result = op.evaluate({&x}, {theta}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::thresholdedrelu op; + auto result = op.evaluate({&x}, {theta}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, prelu_bp_test1) { + auto x = NDArrayFactory::create( + 'c', {2, 3, 4}, + {-12., -11., -10., -9., -8., -7., -6., -5., -4., -3., -2., -1., + 0.5, 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11.}); + auto alpha = NDArrayFactory::create( + 'c', {3, 4}, + {-0.6, -0.5, -0.4, -0.3, -0.2, -0.1, 0.5, 0.1, 0.2, 0.3, 0.4, 0.5}); + auto dLdO = NDArrayFactory::create('c', {2, 3, 4}); - auto x = NDArrayFactory::create('c', {2, 3, 4}, {-12., -11., -10., -9., -8., -7., -6., -5., -4., -3., -2., -1., 0.5, 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11.}); - auto alpha = NDArrayFactory::create('c', {3, 4}, {-0.6, -0.5, -0.4, -0.3, -0.2, -0.1, 0.5, 0.1, 0.2, 0.3, 0.4, 0.5}); - auto dLdO = NDArrayFactory::create('c', {2, 3, 4}); - - const OpArgsHolder argsHolderFF({&x, &alpha}, {}, {}); - const OpArgsHolder argsHolderBP({&x, &alpha, &dLdO}, {}, {}); + const OpArgsHolder argsHolderFF({&x, &alpha}, {}, {}); + const OpArgsHolder argsHolderBP({&x, &alpha, &dLdO}, {}, {}); - sd::ops::prelu opFF; - sd::ops::prelu_bp opBP; + sd::ops::prelu opFF; + sd::ops::prelu_bp opBP; - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); + const bool isGradCorrect = + GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); - ASSERT_TRUE(isGradCorrect); + ASSERT_TRUE(isGradCorrect); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, prelu_bp_test2) { + auto x = NDArrayFactory::create( + 'c', {2, 3, 4}, + {-12., -11., -10., -9., -8., -7., -6., -5., -4., -3., -2., -1., + 0.5, 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11.}); + auto alpha = NDArrayFactory::create('c', {4}, {-0.6, 2., 4., -1.}); + auto dLdO = NDArrayFactory::create('c', {2, 3, 4}); - auto x = NDArrayFactory::create('c', {2, 3, 4}, {-12., -11., -10., -9., -8., -7., -6., -5., -4., -3., -2., -1., 0.5, 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11.}); - auto alpha = NDArrayFactory::create('c', {4}, {-0.6, 2., 4., -1.}); - auto dLdO = NDArrayFactory::create('c', {2, 3, 4}); - - const OpArgsHolder argsHolderFF({&x, &alpha}, {}, {1}); - const OpArgsHolder argsHolderBP({&x, &alpha, &dLdO}, {}, {1}); + const OpArgsHolder argsHolderFF({&x, &alpha}, {}, {1}); + const OpArgsHolder argsHolderBP({&x, &alpha, &dLdO}, {}, {1}); - sd::ops::prelu opFF; - sd::ops::prelu_bp opBP; + sd::ops::prelu opFF; + sd::ops::prelu_bp opBP; - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); + const bool isGradCorrect = + GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); - ASSERT_TRUE(isGradCorrect); + ASSERT_TRUE(isGradCorrect); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, prelu_bp_test3) { + auto x = NDArrayFactory::create('c', {2, 3, 2, 5}); + x.linspace(-30.); + x.p(30, 0.5); // avoid zero, since it is points of discontinuity for prelu + auto alpha = + NDArrayFactory::create('c', {5, 3}, + {-0.7, -0.6, -0.5, -0.4, -0.3, -0.2, -0.1, + 0.5, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7}); + auto dLdO = NDArrayFactory::create('c', {2, 3, 2, 5}); - auto x = NDArrayFactory::create('c', {2, 3, 2, 5}); - x.linspace(-30.); - x.p(30, 0.5); // avoid zero, since it is points of discontinuity for prelu - auto alpha = NDArrayFactory::create('c', {5,3}, {-0.7, -0.6, -0.5, -0.4, -0.3, -0.2, -0.1, 0.5, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7}); - auto dLdO = NDArrayFactory::create('c', {2, 3, 2, 5}); + const OpArgsHolder argsHolderFF({&x, &alpha}, {}, {-1, 2}); + const OpArgsHolder argsHolderBP({&x, &alpha, &dLdO}, {}, {-1, 2}); - const OpArgsHolder argsHolderFF({&x, &alpha}, {}, {-1, 2}); - const OpArgsHolder argsHolderBP({&x, &alpha, &dLdO}, {}, {-1, 2}); + sd::ops::prelu opFF; + sd::ops::prelu_bp opBP; - sd::ops::prelu opFF; - sd::ops::prelu_bp opBP; + const bool isGradCorrect = + GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); - - ASSERT_TRUE(isGradCorrect); + ASSERT_TRUE(isGradCorrect); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, prelu_bp_test4) { + auto x = NDArrayFactory::create('c', {2, 3, 4, 5}); + x.linspace(-50.); + x.p(50, 0.5); // avoid zero, since it is points of discontinuity for prele + auto alpha = NDArrayFactory::create( + 'c', {2, 10}, {-0.7, -0.6, -0.5, -0.4, -0.3, -0.2, -0.1, 0.25, 0.1, 0.2, + 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1., 1.1, 1.2}); + auto dLdO = NDArrayFactory::create('c', {2, 3, 4, 5}); - auto x = NDArrayFactory::create('c', {2, 3, 4, 5}); - x.linspace(-50.); - x.p(50, 0.5); // avoid zero, since it is points of discontinuity for prele - auto alpha = NDArrayFactory::create('c', {2,10}, {-0.7, -0.6, -0.5, -0.4, -0.3, -0.2, -0.1, 0.25, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1., 1.1, 1.2}); - auto dLdO = NDArrayFactory::create('c', {2, 3, 4, 5}); - - const OpArgsHolder argsHolderFF({&x, &alpha}, {}, {-2}); - const OpArgsHolder argsHolderBP({&x, &alpha, &dLdO}, {}, {-2}); + const OpArgsHolder argsHolderFF({&x, &alpha}, {}, {-2}); + const OpArgsHolder argsHolderBP({&x, &alpha, &dLdO}, {}, {-2}); - sd::ops::prelu opFF; - sd::ops::prelu_bp opBP; + sd::ops::prelu opFF; + sd::ops::prelu_bp opBP; - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); + const bool isGradCorrect = + GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); - ASSERT_TRUE(isGradCorrect); + ASSERT_TRUE(isGradCorrect); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, thresholdedrelu_bp_test1) { + const double theta = 0.15; - const double theta = 0.15; - - auto x = NDArrayFactory::create('c', {2, 3, 4}, {1.2, 1.1, 1., 0.9, 0.8, -0.7, -0.6,-0.5,-0.4,-0.3,-0.2,-0.1, 0., 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, -0.9, -1.0, -1.1}); - auto dLdO = NDArrayFactory::create('c', {2, 3, 4}); + auto x = NDArrayFactory::create( + 'c', {2, 3, 4}, + {1.2, 1.1, 1., 0.9, 0.8, -0.7, -0.6, -0.5, -0.4, -0.3, -0.2, -0.1, + 0., 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, -0.9, -1.0, -1.1}); + auto dLdO = NDArrayFactory::create('c', {2, 3, 4}); - const OpArgsHolder argsHolderFF({&x}, {theta}, {}); - const OpArgsHolder argsHolderBP({&x, &dLdO}, {theta}, {}); + const OpArgsHolder argsHolderFF({&x}, {theta}, {}); + const OpArgsHolder argsHolderBP({&x, &dLdO}, {theta}, {}); - sd::ops::thresholdedrelu opFF; - sd::ops::thresholdedrelu_bp opBP; + sd::ops::thresholdedrelu opFF; + sd::ops::thresholdedrelu_bp opBP; - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); + const bool isGradCorrect = + GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); - ASSERT_TRUE(isGradCorrect); + ASSERT_TRUE(isGradCorrect); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, multiply_test1) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto y = NDArrayFactory::create('c', {4}); + auto exp = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0.1f, 0.4f, 0.9f, 1.6f, 0.5f, 1.2f, 2.1f, 3.2f, 0.9f, 2.f, 3.3f, 4.8f, + 1.3f, 2.8f, 4.5f, 6.4f, 1.7f, 3.6f, 5.7f, 8.f, 2.1f, 4.4f, 6.9f, 9.6f}); + x.linspace(1.f); + y.linspace(0.1f, 0.1f); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto y = NDArrayFactory::create('c', {4}); - auto exp = NDArrayFactory::create('c', {2, 3, 4}, {0.1f, 0.4f, 0.9f, 1.6f, 0.5f, 1.2f, 2.1f, 3.2f, 0.9f, 2.f, 3.3f, 4.8f, 1.3f, 2.8f, 4.5f, 6.4f, 1.7f, 3.6f, 5.7f, 8.f, 2.1f, 4.4f, 6.9f, 9.6f}); - x.linspace(1.f); - y.linspace(0.1f, 0.1f); - - sd::ops::multiply op; - auto result = op.evaluate({&x, &y}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - + sd::ops::multiply op; + auto result = op.evaluate({&x, &y}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, multiply_test2) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto y = NDArrayFactory::create(0.1); + auto exp = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f, 0.8f, 0.9f, 1.f, 1.1f, 1.2f, + 1.3f, 1.4f, 1.5f, 1.6f, 1.7f, 1.8f, 1.9f, 2.f, 2.1f, 2.2f, 2.3f, 2.4f}); + x.linspace(1.f); + // y.linspace(0.1f, 0.1f); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto y = NDArrayFactory::create(0.1); - auto exp = NDArrayFactory::create('c', {2, 3, 4}, {0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f, 0.8f, 0.9f, 1.f, 1.1f, 1.2f, 1.3f, 1.4f, 1.5f, 1.6f, 1.7f, 1.8f, 1.9f, 2.f, 2.1f, 2.2f, 2.3f, 2.4f}); - x.linspace(1.f); - // y.linspace(0.1f, 0.1f); - - sd::ops::multiply op; - auto result = op.evaluate({&y, &x}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - + sd::ops::multiply op; + auto result = op.evaluate({&y, &x}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, multiply_test3) { + auto x = NDArrayFactory::create('c', {2, 1, 4}); + auto y = NDArrayFactory::create('c', {3, 1}); + auto exp = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0.1f, 0.2f, 0.3f, 0.4f, 0.2f, 0.4f, 0.6f, 0.8f, 0.3f, 0.6f, 0.9f, 1.2f, + 0.5f, 0.6f, 0.7f, 0.8f, 1.f, 1.2f, 1.4f, 1.6f, 1.5f, 1.8f, 2.1f, 2.4f}); + x.linspace(1.f); + y.linspace(0.1f, 0.1f); - auto x = NDArrayFactory::create('c', {2, 1, 4}); - auto y = NDArrayFactory::create('c', {3,1}); - auto exp = NDArrayFactory::create('c', {2, 3, 4}, {0.1f, 0.2f, 0.3f, 0.4f, 0.2f, 0.4f, 0.6f, 0.8f, 0.3f, 0.6f, 0.9f, 1.2f, 0.5f, 0.6f, 0.7f, 0.8f, 1.f, 1.2f, 1.4f, 1.6f, 1.5f, 1.8f, 2.1f, 2.4f}); - x.linspace(1.f); - y.linspace(0.1f, 0.1f); - - sd::ops::multiply op; - auto result = op.evaluate({&x, &y}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - + sd::ops::multiply op; + auto result = op.evaluate({&x, &y}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, multiply_test4) { + auto x = NDArrayFactory::create('c', {1, 1}); + auto y = NDArrayFactory::create(0.1f); + auto exp = NDArrayFactory::create('c', {1, 1}, {0.1f}); + x.linspace(1.f); - auto x = NDArrayFactory::create('c', {1, 1}); - auto y = NDArrayFactory::create(0.1f); - auto exp = NDArrayFactory::create('c', {1, 1}, {0.1f}); - x.linspace(1.f); - - sd::ops::multiply op; - auto result = op.evaluate({&x, &y}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - + sd::ops::multiply op; + auto result = op.evaluate({&x, &y}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, multiply_test5) { + auto x = NDArrayFactory::create(1.f); + auto y = NDArrayFactory::create(0.1f); + auto exp = NDArrayFactory::create(0.1f); - auto x = NDArrayFactory::create(1.f); - auto y = NDArrayFactory::create(0.1f); - auto exp = NDArrayFactory::create(0.1f); - - sd::ops::multiply op; - auto result = op.evaluate({&x, &y}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - + sd::ops::multiply op; + auto result = op.evaluate({&x, &y}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, multiply_bp_test1) { + auto x = NDArrayFactory::create('c', {1, 1}, {100.}); + auto y = NDArrayFactory::create(0.1); + auto dLdz = NDArrayFactory::create('c', {1, 1}); - auto x = NDArrayFactory::create('c', {1, 1}, {100.}); - auto y = NDArrayFactory::create(0.1); - auto dLdz = NDArrayFactory::create('c', {1, 1}); - - const OpArgsHolder argsHolderFF({&x, &y}, {}, {}); - const OpArgsHolder argsHolderBP({&x, &y, &dLdz}, {}, {}); - - sd::ops::multiply opFF; - sd::ops::multiply_bp opBP; - auto resFF = opFF.evaluate({&x, &y}, {}, {}); - auto resBP = opBP.evaluate({&x, &y, &dLdz}, {}, {}); -// resFF->at(0)->printIndexedBuffer("Multiply 1x1"); -// resBP->at(0)->printIndexedBuffer("Multiply BP 1x1 x"); -// resBP->at(1)->printIndexedBuffer("Multyply BP 1x1 y");*/ - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); - ASSERT_TRUE(isGradCorrect); + const OpArgsHolder argsHolderFF({&x, &y}, {}, {}); + const OpArgsHolder argsHolderBP({&x, &y, &dLdz}, {}, {}); + + sd::ops::multiply opFF; + sd::ops::multiply_bp opBP; + auto resFF = opFF.evaluate({&x, &y}, {}, {}); + auto resBP = opBP.evaluate({&x, &y, &dLdz}, {}, {}); + // resFF->at(0)->printIndexedBuffer("Multiply 1x1"); + // resBP->at(0)->printIndexedBuffer("Multiply BP 1x1 x"); + // resBP->at(1)->printIndexedBuffer("Multyply BP 1x1 y");*/ + const bool isGradCorrect = + GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); + ASSERT_TRUE(isGradCorrect); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, multiply_bp_test2) { + auto x = NDArrayFactory::create('c', {2, 2}, {1., 2., 3., 4.}); + auto y = NDArrayFactory::create(0.1); + auto dLdz = NDArrayFactory::create('c', {2, 2}); - auto x = NDArrayFactory::create('c', {2, 2}, {1.,2.,3.,4.}); - auto y = NDArrayFactory::create(0.1); - auto dLdz = NDArrayFactory::create('c', {2, 2}); - - const OpArgsHolder argsHolderFF({&x, &y}, {}, {}); - const OpArgsHolder argsHolderBP({&x, &y, &dLdz}, {}, {}); + const OpArgsHolder argsHolderFF({&x, &y}, {}, {}); + const OpArgsHolder argsHolderBP({&x, &y, &dLdz}, {}, {}); - sd::ops::multiply opFF; - sd::ops::multiply_bp opBP; + sd::ops::multiply opFF; + sd::ops::multiply_bp opBP; - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); + const bool isGradCorrect = + GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); - ASSERT_TRUE(isGradCorrect); + ASSERT_TRUE(isGradCorrect); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, multiply_bp_test3) { + auto y = NDArrayFactory::create('c', {2, 2}, {1., 2., 3., 4.}); + auto x = NDArrayFactory::create(0.1); + auto dLdz = NDArrayFactory::create('c', {2, 2}); - auto y = NDArrayFactory::create('c', {2, 2}, {1.,2.,3.,4.}); - auto x = NDArrayFactory::create(0.1); - auto dLdz = NDArrayFactory::create('c', {2, 2}); + const OpArgsHolder argsHolderFF({&x, &y}, {}, {}); + const OpArgsHolder argsHolderBP({&x, &y, &dLdz}, {}, {}); - const OpArgsHolder argsHolderFF({&x, &y}, {}, {}); - const OpArgsHolder argsHolderBP({&x, &y, &dLdz}, {}, {}); + sd::ops::multiply opFF; + sd::ops::multiply_bp opBP; - sd::ops::multiply opFF; - sd::ops::multiply_bp opBP; + const bool isGradCorrect = + GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); - - ASSERT_TRUE(isGradCorrect); + ASSERT_TRUE(isGradCorrect); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, multiply_bp_test4) { + auto x = NDArrayFactory::create('c', {2, 2}, {1., 2., 3., 4.}); + auto y = NDArrayFactory::create('c', {2, 2}, {0.1, 0.2, 0.3, 0.4}); + auto dLdz = NDArrayFactory::create('c', {2, 2}); - auto x = NDArrayFactory::create('c', {2, 2}, {1.,2.,3.,4.}); - auto y = NDArrayFactory::create('c', {2, 2}, {0.1,0.2,0.3,0.4}); - auto dLdz = NDArrayFactory::create('c', {2, 2}); - - const OpArgsHolder argsHolderFF({&x, &y}, {}, {}); - const OpArgsHolder argsHolderBP({&x, &y, &dLdz}, {}, {}); + const OpArgsHolder argsHolderFF({&x, &y}, {}, {}); + const OpArgsHolder argsHolderBP({&x, &y, &dLdz}, {}, {}); - sd::ops::multiply opFF; - sd::ops::multiply_bp opBP; + sd::ops::multiply opFF; + sd::ops::multiply_bp opBP; - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); + const bool isGradCorrect = + GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); - ASSERT_TRUE(isGradCorrect); + ASSERT_TRUE(isGradCorrect); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, multiply_bp_test5) { + auto x = NDArrayFactory::create('c', {2, 2}, {1., 2., 3., 4.}); + auto y = NDArrayFactory::create('c', {2}, {0.1, 0.2}); + auto dLdz = NDArrayFactory::create('c', {2, 2}); - auto x = NDArrayFactory::create('c', {2, 2}, {1.,2.,3.,4.}); - auto y = NDArrayFactory::create('c', {2}, {0.1,0.2}); - auto dLdz = NDArrayFactory::create('c', {2, 2}); - - const OpArgsHolder argsHolderFF({&x, &y}, {}, {}); - const OpArgsHolder argsHolderBP({&x, &y, &dLdz}, {}, {}); + const OpArgsHolder argsHolderFF({&x, &y}, {}, {}); + const OpArgsHolder argsHolderBP({&x, &y, &dLdz}, {}, {}); - sd::ops::multiply opFF; - sd::ops::multiply_bp opBP; + sd::ops::multiply opFF; + sd::ops::multiply_bp opBP; - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); + const bool isGradCorrect = + GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); - ASSERT_TRUE(isGradCorrect); + ASSERT_TRUE(isGradCorrect); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, multiply_bp_test6) { + auto y = NDArrayFactory::create('c', {2, 2}, {1., 2., 3., 4.}); + auto x = NDArrayFactory::create('c', {2}, {0.1, 0.2}); + auto dLdz = NDArrayFactory::create('c', {2, 2}); - auto y = NDArrayFactory::create('c', {2, 2}, {1.,2.,3.,4.}); - auto x = NDArrayFactory::create('c', {2}, {0.1,0.2}); - auto dLdz = NDArrayFactory::create('c', {2, 2}); + const OpArgsHolder argsHolderFF({&x, &y}, {}, {}); + const OpArgsHolder argsHolderBP({&x, &y, &dLdz}, {}, {}); - const OpArgsHolder argsHolderFF({&x, &y}, {}, {}); - const OpArgsHolder argsHolderBP({&x, &y, &dLdz}, {}, {}); + sd::ops::multiply opFF; + sd::ops::multiply_bp opBP; - sd::ops::multiply opFF; - sd::ops::multiply_bp opBP; + const bool isGradCorrect = + GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); - - ASSERT_TRUE(isGradCorrect); + ASSERT_TRUE(isGradCorrect); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, multiply_bp_test7) { + auto y = + NDArrayFactory::create('c', {2, 3}, {1., 2., 3., 4., 5., 6.}); + auto x = NDArrayFactory::create('c', {2, 1}, {0.1, 0.2}); + auto dLdz = NDArrayFactory::create('c', {2, 3}); - auto y = NDArrayFactory::create('c', {2, 3}, {1.,2.,3.,4.,5.,6.}); - auto x = NDArrayFactory::create('c', {2, 1}, {0.1,0.2}); - auto dLdz = NDArrayFactory::create('c', {2, 3}); + const OpArgsHolder argsHolderFF({&x, &y}, {}, {}); + const OpArgsHolder argsHolderBP({&x, &y, &dLdz}, {}, {}); - const OpArgsHolder argsHolderFF({&x, &y}, {}, {}); - const OpArgsHolder argsHolderBP({&x, &y, &dLdz}, {}, {}); + sd::ops::multiply opFF; + sd::ops::multiply_bp opBP; - sd::ops::multiply opFF; - sd::ops::multiply_bp opBP; + const bool isGradCorrect = + GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); - - ASSERT_TRUE(isGradCorrect); + ASSERT_TRUE(isGradCorrect); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, multiply_bp_test8) { + auto y = NDArrayFactory::create('c', {2, 1, 4}); + auto x = NDArrayFactory::create('c', {1, 3, 4}); + auto dLdz = NDArrayFactory::create('c', {2, 3, 4}); + x.linspace(1., 0.5); + y.linspace(0.1, 0.05); - auto y = NDArrayFactory::create('c', {2, 1, 4}); - auto x = NDArrayFactory::create('c', {1, 3, 4}); - auto dLdz = NDArrayFactory::create('c', {2, 3, 4}); - x.linspace(1., 0.5); - y.linspace(0.1, 0.05); - - const OpArgsHolder argsHolderFF({&x, &y}, {}, {}); - const OpArgsHolder argsHolderBP({&x, &y, &dLdz}, {}, {}); + const OpArgsHolder argsHolderFF({&x, &y}, {}, {}); + const OpArgsHolder argsHolderBP({&x, &y, &dLdz}, {}, {}); - sd::ops::multiply opFF; - sd::ops::multiply_bp opBP; + sd::ops::multiply opFF; + sd::ops::multiply_bp opBP; - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); + const bool isGradCorrect = + GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); - ASSERT_TRUE(isGradCorrect); + ASSERT_TRUE(isGradCorrect); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, Floormod_BP_Test_2) { + auto y = NDArrayFactory::create('c', {10, 10}); + auto x = NDArrayFactory::create('c', {10, 10}); + auto dLdz = NDArrayFactory::create('c', {10, 10}); + // auto eps = NDArrayFactory::create('c', {10, 10}); + x.linspace(4); // 2., 2.0); + y.linspace(3); + dLdz.linspace(1); + // const OpArgsHolder argsHolderFF({&x, &y}, {}, {}); + // const OpArgsHolder argsHolderBP({&x, &y, &dLdz}, {}, {}); - auto y = NDArrayFactory::create('c', {10, 10}); - auto x = NDArrayFactory::create('c', {10, 10}); - auto dLdz = NDArrayFactory::create('c', {10, 10}); - //auto eps = NDArrayFactory::create('c', {10, 10}); - x.linspace(4); //2., 2.0); - y.linspace(3); - dLdz.linspace(1); -// const OpArgsHolder argsHolderFF({&x, &y}, {}, {}); -// const OpArgsHolder argsHolderBP({&x, &y, &dLdz}, {}, {}); - -// sd::ops::floormod opFF; -// auto resFF = opFF.execute({&x, &y}, {}, {}); -// resFF->at(0)->printIndexedBuffer("FF floormod"); -// delete resFF; - sd::ops::floormod_bp opBP; - auto resBP = opBP.evaluate({&x, &y, &dLdz}, {}, {}); - ASSERT_TRUE(resBP.status() == ND4J_STATUS_OK); - -// resBP->at(0)->printIndexedBuffer("BP floormod /dx"); -// resBP->at(1)->printIndexedBuffer("BP floormod /dy"); - ASSERT_TRUE(dLdz.equalsTo(resBP.at(0))); - ASSERT_TRUE(dLdz.equalsTo(resBP.at(1))); - -// const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); + // sd::ops::floormod opFF; + // auto resFF = opFF.execute({&x, &y}, {}, {}); + // resFF->at(0)->printIndexedBuffer("FF floormod"); + // delete resFF; + sd::ops::floormod_bp opBP; + auto resBP = opBP.evaluate({&x, &y, &dLdz}, {}, {}); + ASSERT_TRUE(resBP.status() == ND4J_STATUS_OK); -// ASSERT_TRUE(isGradCorrect); + // resBP->at(0)->printIndexedBuffer("BP floormod /dx"); + // resBP->at(1)->printIndexedBuffer("BP floormod /dy"); + ASSERT_TRUE(dLdz.equalsTo(resBP.at(0))); + ASSERT_TRUE(dLdz.equalsTo(resBP.at(1))); + + // const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, + // argsHolderFF, argsHolderBP); + + // ASSERT_TRUE(isGradCorrect); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, Dynamic_Partition_BP_1) { - - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto y = NDArrayFactory::create('c', {2, 3}, {0, 1, 2, 1, 0, 2}); - auto dLdzX = NDArrayFactory::create('c', {2, 4}); - auto dLdzY = NDArrayFactory::create('c', {2, 4}); - auto dLdzZ = NDArrayFactory::create('c', {2, 4}); - auto exp = NDArrayFactory::create('c', {2,3,4}, {1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 2, 2, 2, 2, 1, 1, 1, 1, 3, 3, 3, 3}); - x.linspace(1); -// dLdzX.linspace(1); -// dLdzY.linspace(2); -// dLdzZ.linspace(3); - dLdzX.assign(1); - dLdzY.assign(2); - dLdzZ.assign(3); - - sd::ops::dynamic_partition op1; - auto res1 = op1.evaluate({&x, &y}, {}, {3}); - - sd::ops::dynamic_partition_bp op2; - auto res2 = op2.evaluate({&x, &y, &dLdzX, &dLdzY, &dLdzZ}, {}, {3}); - ASSERT_TRUE(res2.status() == ND4J_STATUS_OK); - ASSERT_TRUE(res2.size() == 2); -// printf("How many: %ul\n", res2->size()); -// res2->at(0)->printBuffer("Ouputput0"); -// res2->at(1)->printBuffer("Ouputput1"); - ASSERT_TRUE(res2.at(0)->equalsTo(exp)); - + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto y = NDArrayFactory::create('c', {2, 3}, {0, 1, 2, 1, 0, 2}); + auto dLdzX = NDArrayFactory::create('c', {2, 4}); + auto dLdzY = NDArrayFactory::create('c', {2, 4}); + auto dLdzZ = NDArrayFactory::create('c', {2, 4}); + auto exp = NDArrayFactory::create( + 'c', {2, 3, 4}, + {1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 2, 2, 2, 2, 1, 1, 1, 1, 3, 3, 3, 3}); + x.linspace(1); + // dLdzX.linspace(1); + // dLdzY.linspace(2); + // dLdzZ.linspace(3); + dLdzX.assign(1); + dLdzY.assign(2); + dLdzZ.assign(3); + + sd::ops::dynamic_partition op1; + auto res1 = op1.evaluate({&x, &y}, {}, {3}); + + sd::ops::dynamic_partition_bp op2; + auto res2 = op2.evaluate({&x, &y, &dLdzX, &dLdzY, &dLdzZ}, {}, {3}); + ASSERT_TRUE(res2.status() == ND4J_STATUS_OK); + ASSERT_TRUE(res2.size() == 2); + // printf("How many: %ul\n", res2->size()); + // res2->at(0)->printBuffer("Ouputput0"); + // res2->at(1)->printBuffer("Ouputput1"); + ASSERT_TRUE(res2.at(0).equalsTo(exp)); } ////////////////////////////////////////////////////////////////////// -//TEST_F(DeclarableOpsTests9, Dynamic_Partition_BP_2) { +// TEST_F(DeclarableOpsTests9, Dynamic_Partition_BP_2) { // // auto x = NDArrayFactory::create('c', {2, 3, 4}); // auto y = NDArrayFactory::create('c', {2, 3}, {0, 1, 2, 1, 0, 2}); @@ -2228,41 +2311,41 @@ TEST_F(DeclarableOpsTests9, Dynamic_Partition_BP_1) { // dLdzZ.linspace(1); // // const OpArgsHolder argsHolderFF({&x, &y}, {}, {3}); -// const OpArgsHolder argsHolderBP({&x, &y, &dLdzX, &dLdzY, &dLdzZ}, {}, {3}); +// const OpArgsHolder argsHolderBP({&x, &y, &dLdzX, &dLdzY, &dLdzZ}, {}, +// {3}); // // sd::ops::dynamic_partition opFF; // sd::ops::dynamic_partition_bp opBP; // -// const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); +// const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, +// argsHolderBP); // // ASSERT_TRUE(isGradCorrect); //} //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, Floormod_BP_Test_4) { + auto x = NDArrayFactory::create('c', {2, 1, 3}, + {2.0, 6.0, -3.0, 2.0, 6.0, -3.0}); + auto y = NDArrayFactory::create('c', {1, 3}, {-3.0, 2.0, -2.0}); + auto exp = NDArrayFactory::create('c', {1, 3}, {-1., 0., -1.}); + auto eps = NDArrayFactory::create('c', {2, 1, 3}); + eps.assign(1.f); + sd::ops::floormod_bp op; - auto x = NDArrayFactory::create('c', {2, 1, 3}, {2.0, 6.0, -3.0, 2.0, 6.0, -3.0}); - auto y = NDArrayFactory::create('c', {1, 3}, {-3.0, 2.0, -2.0}); - auto exp = NDArrayFactory::create('c', {1, 3}, {-1., 0., -1.}); - auto eps = NDArrayFactory::create('c', {2, 1, 3}); - eps.assign(1.f); - sd::ops::floormod_bp op; - - auto result = op.evaluate({&x, &y, &eps}, {}, {}); - - ASSERT_TRUE(result.size() == 2); - auto gradX = result.at(0); - auto gradY = result.at(1); + auto result = op.evaluate({&x, &y, &eps}); -// gradX->printIndexedBuffer("gradX"); -// gradY->printIndexedBuffer("gradY"); - ASSERT_TRUE(exp.isSameShape(gradY)); + ASSERT_TRUE(result.size() == 2); + auto gradX = result.at(0); + auto gradY = result.at(1); - ASSERT_TRUE(exp.equalsTo(gradY)); + // gradX->printIndexedBuffer("gradX"); + // gradY->printIndexedBuffer("gradY"); + ASSERT_TRUE(exp.isSameShape(gradY)); + ASSERT_TRUE(exp.equalsTo(gradY)); } - /* //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, gru_cell_bp_test1) { @@ -2313,12 +2396,15 @@ TEST_F(DeclarableOpsTests9, gru_cell_bp_test1) { - const OpArgsHolder argsHolderBP({&x, &hi, &W, &Wc, &b, &bc, &dLdr, &dLdu, &dLdc, &dLdh}, {}, {}); + const OpArgsHolder argsHolderBP({&x, &hi, &W, &Wc, &b, &bc, &dLdr, &dLdu, +&dLdc, &dLdh}, {}, {}); sd::ops::gruCell opFF; sd::ops::gruCell_bp opBP; - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {1, 1, 1, 1 , 1, 1}, {0., 1.}, sd::GradCheck::LossFunc::SUM, true); + const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, +argsHolderBP, {1, 1, 1, 1 , 1, 1}, {0., 1.}, sd::GradCheck::LossFunc::SUM, +true); ASSERT_TRUE(isGradCorrect); } @@ -2326,50 +2412,56 @@ TEST_F(DeclarableOpsTests9, gru_cell_bp_test1) { //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, Cholesky_Test_1) { + NDArray x = NDArrayFactory::create( + 'c', {3, 3}, {4, 12, -16, 12, 37, -43, -16, -43, 98}); + NDArray exp = NDArrayFactory::create( + 'c', {3, 3}, {2., 0., 0., 6., 1., 0., -8., 5., 3.}); - NDArray x = NDArrayFactory::create('c', {3, 3}, {4,12,-16, 12 ,37,-43, -16, -43, 98}); - NDArray exp = NDArrayFactory::create('c', {3,3}, {2., 0., 0., 6., 1., 0., -8., 5., 3.}); - - sd::ops::cholesky op; - - auto result = op.evaluate({&x}, {}, {}); - ASSERT_EQ(result.status(), ND4J_STATUS_OK); - auto res = result.at(0); -// res->printIndexedBuffer("Output for Cholesky1"); - ASSERT_TRUE(exp.equalsTo(res)); + sd::ops::cholesky op; + auto result = op.evaluate({&x}, {}, {}); + ASSERT_EQ(result.status(), ND4J_STATUS_OK); + auto res = result.at(0); + // res->printIndexedBuffer("Output for Cholesky1"); + ASSERT_TRUE(exp.equalsTo(res)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, Cholesky_Test_2) { + NDArray x = NDArrayFactory::create( + 'c', {2, 3, 3}, + {4, 12, -16, 12, 37, -43, -16, -43, 98, 1, 1, 1, 1, 2, 2, 1, 2., 6}); + NDArray exp = NDArrayFactory::create( + 'c', {2, 3, 3}, + {2., 0., 0., 6., 1., 0., -8., 5., 3., 1., 0., 0., 1., 1., 0, 1., 1., 2.}); - NDArray x = NDArrayFactory::create('c', {2, 3, 3}, {4, 12,-16, 12 ,37,-43, -16, -43, 98, 1, 1, 1, 1, 2, 2, 1, 2., 6}); - NDArray exp = NDArrayFactory::create('c', {2, 3, 3}, {2., 0., 0., 6., 1., 0., -8., 5., 3., 1., 0., 0., 1., 1., 0,1., 1., 2.}); - - sd::ops::cholesky op; - - auto result = op.evaluate({&x}, {}, {}); - ASSERT_EQ(result.status(), ND4J_STATUS_OK); - auto res = result.at(0); -// res->printIndexedBuffer("Output for Cholesky 2"); - ASSERT_TRUE(exp.equalsTo(res)); + sd::ops::cholesky op; + auto result = op.evaluate({&x}, {}, {}); + ASSERT_EQ(result.status(), ND4J_STATUS_OK); + auto res = result.at(0); + // res->printIndexedBuffer("Output for Cholesky 2"); + ASSERT_TRUE(exp.equalsTo(res)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, Cholesky_Test_3) { + NDArray x = NDArrayFactory::create( + 'c', {2, 3, 3}, + {4.f, 12.f, -16.f, 12.f, 37.f, -43.f, -16.f, -43.f, 98.f, 1.f, 1.f, 1.f, + 1.f, 2.f, 2.f, 1.f, 2.f, 6.f}); + NDArray exp = NDArrayFactory::create( + 'c', {2, 3, 3}, + {2.f, 0.f, 0.f, 6.f, 1.f, 0.f, -8.f, 5.f, 3.f, 1.f, 0.f, 0.f, 1.f, 1.f, + 0.f, 1.f, 1.f, 2.f}); - NDArray x = NDArrayFactory::create('c', {2, 3, 3}, {4.f, 12.f, -16.f, 12.f, 37.f, -43.f, -16.f, -43.f, 98.f, 1.f, 1.f, 1.f, 1.f, 2.f, 2.f, 1.f, 2.f, 6.f}); - NDArray exp = NDArrayFactory::create('c', {2, 3, 3}, {2.f, 0.f, 0.f, 6.f, 1.f, 0.f, -8.f, 5.f, 3.f, 1.f, 0.f, 0.f, 1.f, 1.f, 0.f, 1.f, 1.f, 2.f}); - - sd::ops::cholesky op; - - auto result = op.evaluate({&x}, {}, {}); - ASSERT_EQ(result.status(), ND4J_STATUS_OK); - auto res = result.at(0); - // res->printIndexedBuffer("Output for Cholesky 3"); - ASSERT_TRUE(exp.equalsTo(res, 1e-4)); + sd::ops::cholesky op; + auto result = op.evaluate({&x}, {}, {}); + ASSERT_EQ(result.status(), ND4J_STATUS_OK); + auto res = result.at(0); + // res->printIndexedBuffer("Output for Cholesky 3"); + ASSERT_TRUE(exp.equalsTo(res, 1e-4)); } //////////////////////////////////////////////////////////////////// @@ -2394,12 +2486,14 @@ TEST_F(DeclarableOpsTests9, Cholesky_Test_3) { // b = 0.5; // const OpArgsHolder argsHolderFF({&x, &h0, &Wx, &Wh, &b}, {}, {}); -// const OpArgsHolder argsHolderBP({&x, &h0, &Wx, &Wh, &b, &dLdh}, {}, {}); +// const OpArgsHolder argsHolderBP({&x, &h0, &Wx, &Wh, &b, &dLdh}, +// {}, {}); // sd::ops::gru opFF; // sd::ops::gru_bp opBP; -// const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); +// const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, +// argsHolderBP); // ASSERT_TRUE(isGradCorrect); // } diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTestsCuda1.cu b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTestsCuda1.cu index 4f69da61b233..60d12a1745fa 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTestsCuda1.cu +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTestsCuda1.cu @@ -14,49 +14,57 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - // // @author raver119@gmail.com // -#include "testlayers.h" -#include #include -#include #include +#include +#include + #include +#include "testlayers.h" using namespace sd; - class DeclarableOpsTestsCuda1 : public testing::Test { -public: - - DeclarableOpsTestsCuda1() { - printf("\n"); - fflush(stdout); - } + public: + DeclarableOpsTestsCuda1() { + printf("\n"); + fflush(stdout); + } }; - TEST_F(DeclarableOpsTestsCuda1, Test_CHOOSE_SCALAR_LARGE) { - double inputData[150] = { - 0, 0.51, 0.68, 0.69, 0.86, 0.91, 0.96, 0.97, 0.97, 1.03, 1.13, 1.16, 1.16, 1.17, 1.19, 1.25, 1.25, 1.26, 1.27, 1.28, 1.29, 1.29, 1.29, 1.30, 1.31, 1.32, 1.33, 1.33, 1.35, 1.35, 1.36, 1.37, 1.38, 1.40, 1.41, 1.42, 1.43, 1.44, 1.44, 1.45, 1.45, 1.47, 1.47, 1.51, 1.51, 1.51, 1.52, 1.53, 1.56, 1.57, 1.58, 1.59, 1.61, 1.62, 1.63, 1.63, 1.64, 1.64, 1.66, 1.66, 1.67, 1.67, 1.70, 1.70, 1.70, 1.72, 1.72, 1.72, 1.72, 1.73, 1.74, 1.74, 1.76, 1.76, 1.77, 1.77, 1.80, 1.80, 1.81, 1.82, 1.83, 1.83, 1.84, 1.84, 1.84, 1.85, 1.85, 1.85, 1.86, 1.86, 1.87, 1.88, 1.89, 1.89, 1.89, 1.89, 1.89, 1.91, 1.91, 1.91, 1.92, 1.94, 1.95, 1.97, 1.98, 1.98, 1.98, 1.98, 1.98, 1.99, 2, 2, 2.01, 2.01, 2.02, 2.03, 2.03, 2.03, 2.04, 2.04, 2.05, 2.06, 2.07, 2.08, 2.08, 2.08, 2.08, 2.09, 2.09, 2.10, 2.10, 2.11, 2.11, 2.11, 2.12, 2.12, 2.13, 2.13, 2.14, 2.14, 2.14, 2.14, 2.15, 2.15, 2.16, 2.16, 2.16, 2.16, 2.16, 2.17 - }; - - auto precursor = NDArrayFactory::create(inputData,'c',{1,149}); - NDArray x(nullptr, precursor.specialBuffer(), precursor.shapeInfo()); - - sd::ops::choose op; - //greater than test - auto result = op.evaluate({&x}, {0.0},{3}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(1); - - ASSERT_EQ(148,z->e(0)); - //ASSERT_TRUE(exp.isSameShape(z)); + double inputData[150] = { + 0, 0.51, 0.68, 0.69, 0.86, 0.91, 0.96, 0.97, 0.97, 1.03, 1.13, 1.16, + 1.16, 1.17, 1.19, 1.25, 1.25, 1.26, 1.27, 1.28, 1.29, 1.29, 1.29, 1.30, + 1.31, 1.32, 1.33, 1.33, 1.35, 1.35, 1.36, 1.37, 1.38, 1.40, 1.41, 1.42, + 1.43, 1.44, 1.44, 1.45, 1.45, 1.47, 1.47, 1.51, 1.51, 1.51, 1.52, 1.53, + 1.56, 1.57, 1.58, 1.59, 1.61, 1.62, 1.63, 1.63, 1.64, 1.64, 1.66, 1.66, + 1.67, 1.67, 1.70, 1.70, 1.70, 1.72, 1.72, 1.72, 1.72, 1.73, 1.74, 1.74, + 1.76, 1.76, 1.77, 1.77, 1.80, 1.80, 1.81, 1.82, 1.83, 1.83, 1.84, 1.84, + 1.84, 1.85, 1.85, 1.85, 1.86, 1.86, 1.87, 1.88, 1.89, 1.89, 1.89, 1.89, + 1.89, 1.91, 1.91, 1.91, 1.92, 1.94, 1.95, 1.97, 1.98, 1.98, 1.98, 1.98, + 1.98, 1.99, 2, 2, 2.01, 2.01, 2.02, 2.03, 2.03, 2.03, 2.04, 2.04, + 2.05, 2.06, 2.07, 2.08, 2.08, 2.08, 2.08, 2.09, 2.09, 2.10, 2.10, 2.11, + 2.11, 2.11, 2.12, 2.12, 2.13, 2.13, 2.14, 2.14, 2.14, 2.14, 2.15, 2.15, + 2.16, 2.16, 2.16, 2.16, 2.16, 2.17}; + + auto precursor = NDArrayFactory::create(inputData, 'c', {1, 149}); + NDArray x(nullptr, precursor.specialBuffer(), precursor.shapeInfo()); + + sd::ops::choose op; + // greater than test + auto result = op.evaluate({&x}, {0.0}, {3}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(1); + + ASSERT_EQ(148, z.e(0)); + // ASSERT_TRUE(exp.isSameShape(z)); } /* @@ -69,8 +77,8 @@ TEST_F(DeclarableOpsTestsCuda1, Test_Reverse_TAD_1) { auto timeStart = std::chrono::system_clock::now(); auto status = op.execute({&x}, {&z}, {}, {1}, {}); auto timeEnd = std::chrono::system_clock::now(); - auto outerTime = std::chrono::duration_cast (timeEnd - timeStart).count(); - nd4j_printf("exec time: %lld us\n", outerTime); + auto outerTime = std::chrono::duration_cast +(timeEnd - timeStart).count(); nd4j_printf("exec time: %lld us\n", outerTime); ASSERT_EQ(Status::OK(), status); } */ \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/EmptyTests.cpp b/libnd4j/tests_cpu/layers_tests/EmptyTests.cpp index c142fb9aa8f4..a39ac3516c96 100644 --- a/libnd4j/tests_cpu/layers_tests/EmptyTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/EmptyTests.cpp @@ -18,237 +18,225 @@ // Created by raver on 6/18/2018. // -#include "testlayers.h" -#include #include +#include + +#include "testlayers.h" // #include using namespace sd; - class EmptyTests : public testing::Test { -public: - - EmptyTests() { - printf("\n"); - fflush(stdout); - } + public: + EmptyTests() { + printf("\n"); + fflush(stdout); + } }; TEST_F(EmptyTests, Test_Create_Empty_1) { - auto empty = NDArrayFactory::empty_(); - ASSERT_TRUE(empty->isEmpty()); + auto empty = NDArrayFactory::empty_(); + ASSERT_TRUE(empty->isEmpty()); - ASSERT_EQ(0, empty->lengthOf()); - ASSERT_TRUE(empty->buffer() == nullptr); + ASSERT_EQ(0, empty->lengthOf()); + ASSERT_TRUE(empty->buffer() == nullptr); - ASSERT_TRUE(shape::isEmpty(empty->shapeInfo())); + ASSERT_TRUE(shape::isEmpty(empty->shapeInfo())); - delete empty; + delete empty; } TEST_F(EmptyTests, Test_Create_Empty_2) { - auto empty = NDArrayFactory::empty(); - ASSERT_TRUE(empty.isEmpty()); + auto empty = NDArrayFactory::empty(); + ASSERT_TRUE(empty.isEmpty()); - ASSERT_EQ(0, empty.lengthOf()); - ASSERT_TRUE(empty.buffer() == nullptr); + ASSERT_EQ(0, empty.lengthOf()); + ASSERT_TRUE(empty.buffer() == nullptr); - ASSERT_TRUE(shape::isEmpty(empty.shapeInfo())); - ASSERT_TRUE(empty.isEmpty()); + ASSERT_TRUE(shape::isEmpty(empty.shapeInfo())); + ASSERT_TRUE(empty.isEmpty()); } TEST_F(EmptyTests, Test_Concat_1) { -// auto empty = NDArrayFactory::empty_(); - auto empty = new NDArray('c', {0}, sd::DataType::FLOAT32);//NDArrayFactory::create_('c', {(Nd4jLong)0}}; - auto vector = NDArrayFactory::create_('c', {1}, {1.0f}); + // auto empty = NDArrayFactory::empty_(); + auto empty = NDArray( + 'c', {0}, sd::DataType::FLOAT32); // NDArrayFactory::create_('c', + // {(Nd4jLong)0}}; + auto vector = NDArrayFactory::create('c', {1}, {1.0f}); - ASSERT_TRUE(empty->isEmpty()); + ASSERT_TRUE(empty.isEmpty()); - sd::ops::concat op; - auto result = op.evaluate({empty, vector}, {}, {0}); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::concat op; + auto result = op.evaluate({&empty, &vector}, {}, {0}); + ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); + auto z = result.at(0); -// z->printShapeInfo("z shape"); -// z->printIndexedBuffer("z buffr"); + // z->printShapeInfo("z shape"); + // z->printIndexedBuffer("z buffr"); - ASSERT_EQ(*vector, *z); - - delete empty; - delete vector; + ASSERT_EQ(vector, z); } - TEST_F(EmptyTests, Test_Concat_2) { - auto empty = new NDArray('c', {0}, sd::DataType::FLOAT32); //NDArrayFactory::empty_(); - auto scalar1 = NDArrayFactory::create_('c', {1}, {1.0f}); - auto scalar2 = NDArrayFactory::create_('c', {1}, {2.0f}); - auto exp = NDArrayFactory::create('c', {2}, {1.f, 2.f}); + auto empty = NDArray( + 'c', {0}, sd::DataType::FLOAT32); // NDArrayFactory::empty_(); + auto scalar1 = NDArrayFactory::create('c', {1}, {1.0f}); + auto scalar2 = NDArrayFactory::create('c', {1}, {2.0f}); + auto exp = NDArrayFactory::create('c', {2}, {1.f, 2.f}); - ASSERT_TRUE(empty->isEmpty()); + ASSERT_TRUE(empty.isEmpty()); - sd::ops::concat op; - auto result = op.evaluate({empty, scalar1, scalar2}, {}, {0}); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::concat op; + auto result = op.evaluate({&empty, &scalar1, &scalar2}, {}, {0}); + ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); + auto z = result.at(0); -// z->printShapeInfo("z shape"); -// z->printIndexedBuffer("z buffr"); - - ASSERT_EQ(exp, *z); - - delete empty; - delete scalar1; - delete scalar2; + ASSERT_EQ(exp, z); } TEST_F(EmptyTests, Test_Concat_3) { - auto empty = NDArrayFactory::empty(); //NDArrayFactory::empty_(); - auto scalar1 = NDArrayFactory::create(1.0f); - auto scalar2 = NDArrayFactory::create(2.0f); - auto exp = NDArrayFactory::create('c', {2}, {1.f, 2.f}); + auto empty = + NDArrayFactory::empty(); // NDArrayFactory::empty_(); + auto scalar1 = NDArrayFactory::create(1.0f); + auto scalar2 = NDArrayFactory::create(2.0f); + auto exp = NDArrayFactory::create('c', {2}, {1.f, 2.f}); - ASSERT_TRUE(empty.isEmpty()); + ASSERT_TRUE(empty.isEmpty()); - sd::ops::concat op; - auto result = op.evaluate({&empty, &scalar1, &scalar2}, {}, {0}); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::concat op; + auto result = op.evaluate({&empty, &scalar1, &scalar2}, {}, {0}); + ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); - - ASSERT_EQ(exp, *z); + auto z = result.at(0); + ASSERT_EQ(exp, z); } TEST_F(EmptyTests, Test_Concat_4) { - auto empty = NDArrayFactory::empty(); //NDArrayFactory::empty_(); - auto scalar1 = NDArrayFactory::create(1.0f); - auto scalar2 = NDArrayFactory::create(2.0f); - auto exp = NDArrayFactory::create('c', {2}, {1.f, 2.f}); + auto empty = + NDArrayFactory::empty(); // NDArrayFactory::empty_(); + auto scalar1 = NDArrayFactory::create(1.0f); + auto scalar2 = NDArrayFactory::create(2.0f); + auto exp = NDArrayFactory::create('c', {2}, {1.f, 2.f}); - ASSERT_TRUE(empty.isEmpty()); + ASSERT_TRUE(empty.isEmpty()); - sd::ops::concat op; - auto result = op.evaluate({&scalar1, &empty, &scalar2}, {}, {0}); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::concat op; + auto result = op.evaluate({&scalar1, &empty, &scalar2}, {}, {0}); + ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_EQ(exp, *z); + ASSERT_EQ(exp, z); } TEST_F(EmptyTests, Test_dup_1) { - auto empty = NDArrayFactory::empty(); - auto dup = new NDArray(empty.dup()); - - ASSERT_TRUE(dup->isEmpty()); - ASSERT_EQ(empty, *dup); + auto empty = NDArrayFactory::empty(); + auto dup = empty.dup(); - delete dup; + ASSERT_TRUE(dup.isEmpty()); + ASSERT_EQ(empty, dup); } TEST_F(EmptyTests, test_empty_scatter_1) { - auto x = NDArrayFactory::create('c', {5}); - auto indices = NDArrayFactory::create('c', {0}); - auto updates = NDArrayFactory::create('c', {0}); + auto x = NDArrayFactory::create('c', {5}); + auto indices = NDArrayFactory::create('c', {0}); + auto updates = NDArrayFactory::create('c', {0}); - x.linspace(1.0f); + x.linspace(1.0f); - sd::ops::scatter_upd op; - auto result = op.evaluate({&x, &indices, &updates}, {}, {}, {true}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - ASSERT_EQ(x, *z); + sd::ops::scatter_upd op; + auto result = op.evaluate({&x, &indices, &updates}, {}, {}, {true}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + ASSERT_EQ(x, z); } TEST_F(EmptyTests, test_empty_scatter_2) { - NDArray x ('c', {5}, sd::DataType::FLOAT32); - NDArray z ('c', {5}, sd::DataType::FLOAT32); - auto indices = NDArrayFactory::create('c', {0}); - auto updates = NDArrayFactory::create('c', {0}); + NDArray x('c', {5}, sd::DataType::FLOAT32); + NDArray z('c', {5}, sd::DataType::FLOAT32); + auto indices = NDArrayFactory::create('c', {0}); + auto updates = NDArrayFactory::create('c', {0}); - x.linspace(1.0f); + x.linspace(1.0f); - sd::ops::scatter_upd op; - auto status = op.execute({&x, &indices, &updates}, {&z}, {}, {}, {true}); + sd::ops::scatter_upd op; + auto status = op.execute({&x, &indices, &updates}, {&z}, {}, {}, {true}); - ASSERT_EQ(Status::OK(), status); + ASSERT_EQ(Status::OK(), status); - ASSERT_EQ(x, z); + ASSERT_EQ(x, z); } TEST_F(EmptyTests, test_shaped_empty_1) { - auto empty = NDArrayFactory::create('c', {2, 0, 3}); - std::vector shape = {2, 0, 3}; - - ASSERT_EQ(sd::DataType::FLOAT32, empty.dataType()); - ASSERT_EQ(0, empty.lengthOf()); - ASSERT_TRUE(empty.isEmpty()); - ASSERT_EQ(shape, empty.getShapeAsVector()); - ASSERT_EQ(3, empty.rankOf()); + auto empty = NDArrayFactory::create('c', {2, 0, 3}); + std::vector shape = {2, 0, 3}; + + ASSERT_EQ(sd::DataType::FLOAT32, empty.dataType()); + ASSERT_EQ(0, empty.lengthOf()); + ASSERT_TRUE(empty.isEmpty()); + ASSERT_EQ(shape, empty.getShapeAsVector()); + ASSERT_EQ(3, empty.rankOf()); } TEST_F(EmptyTests, test_shaped_empty_2) { - auto empty = NDArrayFactory::create('c', {0, 3}); - std::vector shape = {0, 3}; - - ASSERT_EQ(sd::DataType::FLOAT32, empty.dataType()); - ASSERT_EQ(0, empty.lengthOf()); - ASSERT_TRUE(empty.isEmpty()); - ASSERT_EQ(shape, empty.getShapeAsVector()); - ASSERT_EQ(2, empty.rankOf()); + auto empty = NDArrayFactory::create('c', {0, 3}); + std::vector shape = {0, 3}; + + ASSERT_EQ(sd::DataType::FLOAT32, empty.dataType()); + ASSERT_EQ(0, empty.lengthOf()); + ASSERT_TRUE(empty.isEmpty()); + ASSERT_EQ(shape, empty.getShapeAsVector()); + ASSERT_EQ(2, empty.rankOf()); } TEST_F(EmptyTests, test_shaped_empty_3) { - auto empty = NDArrayFactory::create('c', {0}); - std::vector shape = {0}; - - ASSERT_EQ(sd::DataType::FLOAT32, empty.dataType()); - ASSERT_EQ(0, empty.lengthOf()); - ASSERT_TRUE(empty.isEmpty()); - ASSERT_EQ(shape, empty.getShapeAsVector()); - ASSERT_EQ(1, empty.rankOf()); + auto empty = NDArrayFactory::create('c', {0}); + std::vector shape = {0}; + + ASSERT_EQ(sd::DataType::FLOAT32, empty.dataType()); + ASSERT_EQ(0, empty.lengthOf()); + ASSERT_TRUE(empty.isEmpty()); + ASSERT_EQ(shape, empty.getShapeAsVector()); + ASSERT_EQ(1, empty.rankOf()); } TEST_F(EmptyTests, test_shaped_empty_4) { - const auto shape = ConstantShapeHelper::getInstance().vectorShapeInfo(0, sd::DataType::FLOAT32); - NDArray array(shape, true, sd::LaunchContext::defaultContext()); - std::vector shapeOf({0}); - - ASSERT_TRUE(array.isEmpty()); - ASSERT_EQ(1, array.rankOf()); - ASSERT_EQ(shapeOf, array.getShapeAsVector()); + const auto shape = ConstantShapeHelper::getInstance().vectorShapeInfo( + 0, sd::DataType::FLOAT32); + NDArray array(shape, true, sd::LaunchContext::defaultContext()); + std::vector shapeOf({0}); + + ASSERT_TRUE(array.isEmpty()); + ASSERT_EQ(1, array.rankOf()); + ASSERT_EQ(shapeOf, array.getShapeAsVector()); } - TEST_F(EmptyTests, test_empty_matmul_1) { - auto x = NDArrayFactory::create('c', {0, 1}); - auto y = NDArrayFactory::create('c', {1, 0}); - auto e = NDArrayFactory::create('c', {0, 0}); - - sd::ops::matmul op; - auto result = op.evaluate({&x, &y}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); + auto x = NDArrayFactory::create('c', {0, 1}); + auto y = NDArrayFactory::create('c', {1, 0}); + auto e = NDArrayFactory::create('c', {0, 0}); - auto z = result.at(0); - ASSERT_EQ(e, *z); + sd::ops::matmul op; + auto result = op.evaluate({&x, &y}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + ASSERT_EQ(e, z); } TEST_F(EmptyTests, test_empty_matmul_2) { - auto x = NDArrayFactory::create('c', {1, 0, 4}); - auto y = NDArrayFactory::create('c', {1, 4, 0}); - auto e = NDArrayFactory::create('c', {1, 0, 0}); + auto x = NDArrayFactory::create('c', {1, 0, 4}); + auto y = NDArrayFactory::create('c', {1, 4, 0}); + auto e = NDArrayFactory::create('c', {1, 0, 0}); - sd::ops::matmul op; - auto result = op.evaluate({&x, &y}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::matmul op; + auto result = op.evaluate({&x, &y}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); - ASSERT_EQ(e, *z); + auto z = result.at(0); + ASSERT_EQ(e, z); } diff --git a/libnd4j/tests_cpu/layers_tests/ExecutionLayerTests.cpp b/libnd4j/tests_cpu/layers_tests/ExecutionLayerTests.cpp new file mode 100644 index 000000000000..ac2026f00fe7 --- /dev/null +++ b/libnd4j/tests_cpu/layers_tests/ExecutionLayerTests.cpp @@ -0,0 +1,87 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include +#include +#include +#include + +#include + +#include "testlayers.h" + +using namespace sd; +using namespace sd::graph; + +class ExecutionLayerTests : public testing::Test { + public: + ExecutionLayerTests() { + /// + } +}; + +TEST_F(ExecutionLayerTests, test_reassign_1) { + ExecutionLayer layer; + OpSequence sequence1, sequence2; + + Node a(sd::ops::add(), "add"); + Node m(sd::ops::multiply(), "mul"); + Node d(sd::ops::divide(), "div"); + + Context ctx1(1); + Context ctx2(2); + Context ctx3(3); + + sequence1.append(a, ctx1); + sequence2.append(m, ctx2); + sequence2.append(d, ctx3); + + layer.append(sequence1); + layer.append(sequence2); + + auto seq = layer[0]; + ASSERT_EQ(1, seq.length()); + + seq = layer[1]; + ASSERT_EQ(2, seq.length()); +} + +TEST_F(ExecutionLayerTests, test_purge_1) { + ExecutionLayer layer; + OpSequence sequence1, sequence2; + + Node a(sd::ops::add(), "add"); + Node m(sd::ops::multiply(), "mul"); + + Context ctx1(1); + Context ctx2(2); + + sequence1.append(a, ctx1); + sequence1.append(m, ctx2); + + layer.append(sequence1); + layer.append(sequence2); + + ASSERT_EQ(2, layer.width()); + + layer.purgeEmptySequences(); + + ASSERT_EQ(1, layer.width()); +} \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/ExtraArgumentsTests.cpp b/libnd4j/tests_cpu/layers_tests/ExtraArgumentsTests.cpp index aa4a72f70a2b..a166e06ac458 100644 --- a/libnd4j/tests_cpu/layers_tests/ExtraArgumentsTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/ExtraArgumentsTests.cpp @@ -18,49 +18,45 @@ // @author raver119@gmail.com // -#include "testlayers.h" #include + #include +#include "testlayers.h" + using namespace sd; class ExtraArgumentsTests : public testing::Test { -public: - - ExtraArgumentsTests() { - printf("\n"); - fflush(stdout); - } + public: + ExtraArgumentsTests() { + } }; TEST_F(ExtraArgumentsTests, Basic_Test_1) { - if (!Environment::getInstance().isCPU()) - return; + if (!Environment::getInstance().isCPU()) return; - ExtraArguments args({1.0, 2.0, 3.0}); + ExtraArguments args({1.0, 2.0, 3.0}); - float ef[] = {1.f, 2.f, 3.f}; - double ed[] = {1., 2., 3.}; + float ef[] = {1.f, 2.f, 3.f}; + double ed[] = {1., 2., 3.}; - auto ptrFloat = reinterpret_cast(args.argumentsAsT()); - auto ptrDouble = reinterpret_cast(args.argumentsAsT()); - ASSERT_TRUE(ptrFloat != nullptr); - ASSERT_TRUE(ptrDouble != nullptr); + auto ptrFloat = reinterpret_cast(args.argumentsAsT()); + auto ptrDouble = reinterpret_cast(args.argumentsAsT()); + ASSERT_TRUE(ptrFloat != nullptr); + ASSERT_TRUE(ptrDouble != nullptr); - for (int e = 0; e < 3; e++) { - ASSERT_NEAR(ef[e], ptrFloat[e], 1e-5f); - } + for (int e = 0; e < 3; e++) { + ASSERT_NEAR(ef[e], ptrFloat[e], 1e-5f); + } - for (int e = 0; e < 3; e++) { - ASSERT_NEAR(ed[e], ptrDouble[e], 1e-5); - } + for (int e = 0; e < 3; e++) { + ASSERT_NEAR(ed[e], ptrDouble[e], 1e-5); + } } - TEST_F(ExtraArgumentsTests, Basic_Test_2) { - ExtraArguments args; + ExtraArguments args; - auto ptrInt = args.argumentsAsT(); - ASSERT_TRUE(ptrInt == nullptr); + auto ptrInt = args.argumentsAsT(); + ASSERT_TRUE(ptrInt == nullptr); } - diff --git a/libnd4j/tests_cpu/layers_tests/FlatBuffersTests.cpp b/libnd4j/tests_cpu/layers_tests/FlatBuffersTests.cpp index 437edb52548b..6bcb70b33343 100644 --- a/libnd4j/tests_cpu/layers_tests/FlatBuffersTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/FlatBuffersTests.cpp @@ -18,70 +18,69 @@ // @author raver119@gmail.com // -#include "testlayers.h" #include -#include +#include +#include #include +#include #include -#include -#include -#include #include +#include "testlayers.h" + using namespace sd; using namespace sd::graph; class FlatBuffersTest : public testing::Test { -public: - int alpha = 0; - - Nd4jLong *cShape = new Nd4jLong[8]{2, 2, 2, 2, 1, 8192, 1, 99}; - Nd4jLong *fShape = new Nd4jLong[8]{2, 2, 2, 1, 2, 8192, 1, 102}; - - FlatBuffersTest() { - Environment::getInstance().setDebug(false); - Environment::getInstance().setVerbose(false); - Environment::getInstance().setProfiling(false); - } - - ~FlatBuffersTest() { - Environment::getInstance().setDebug(false); - Environment::getInstance().setVerbose(false); - Environment::getInstance().setProfiling(false); - - delete[] cShape; - delete[] fShape; - } + public: + int alpha = 0; + + Nd4jLong *cShape = new Nd4jLong[8]{2, 2, 2, 2, 1, 8192, 1, 99}; + Nd4jLong *fShape = new Nd4jLong[8]{2, 2, 2, 1, 2, 8192, 1, 102}; + + FlatBuffersTest() { + Environment::getInstance().setDebug(false); + Environment::getInstance().setVerbose(false); + Environment::getInstance().setProfiling(false); + } + + ~FlatBuffersTest() { + Environment::getInstance().setDebug(false); + Environment::getInstance().setVerbose(false); + Environment::getInstance().setProfiling(false); + + delete[] cShape; + delete[] fShape; + } }; /** * Simple test that creates Node & reads it */ TEST_F(FlatBuffersTest, BasicTest1) { - flatbuffers::FlatBufferBuilder builder(1024); + flatbuffers::FlatBufferBuilder builder(1024); - auto name = builder.CreateString("wow"); + auto name = builder.CreateString("wow"); - auto node = CreateFlatNode(builder, -1, name, OpType_TRANSFORM_SAME, transform::Ones, {0}); + auto node = CreateFlatNode(builder, -1, name, OpType_TRANSFORM_SAME, + transform::Ones, {0}); - builder.Finish(node); - - // now we have our buffer with data - uint8_t *buf = builder.GetBufferPointer(); - int size = builder.GetSize(); - ASSERT_TRUE(size > 0); + builder.Finish(node); + // now we have our buffer with data + uint8_t *buf = builder.GetBufferPointer(); + int size = builder.GetSize(); + ASSERT_TRUE(size > 0); + auto restored = GetFlatNode(buf); - auto restored = GetFlatNode(buf); + auto gA = new Node(restored); + auto gB = new Node(restored); - auto gA = new Node(restored); - auto gB = new Node(restored); + ASSERT_TRUE(gA->equals(gB)); - ASSERT_TRUE(gA->equals(gB)); - - delete gA; - delete gB; + delete gA; + delete gB; } /* @@ -94,10 +93,11 @@ TEST_F(FlatBuffersTest, FlatGraphTest1) { auto fShape = builder.CreateVector(array->getShapeInfoAsFlatVector()); auto fBuffer = builder.CreateVector(array->asByteVector()); - auto fArray = CreateFlatArray(builder, fShape, fBuffer, sd::graph::DType::DType_FLOAT); - auto fVid = CreateIntPair(builder, -1); + auto fArray = CreateFlatArray(builder, fShape, fBuffer, +sd::graph::DType::DType_FLOAT); auto fVid = CreateIntPair(builder, -1); - auto fVar = CreateFlatVariable(builder, fVid, 0, sd::graph::DType::DType_FLOAT, 0, fArray); + auto fVar = CreateFlatVariable(builder, fVid, 0, +sd::graph::DType::DType_FLOAT, 0, fArray); std::vector outputs1, outputs2, inputs1, inputs2; outputs1.push_back(2); @@ -116,8 +116,9 @@ TEST_F(FlatBuffersTest, FlatGraphTest1) { auto name1 = builder.CreateString("wow1"); auto name2 = builder.CreateString("wow2"); - auto node1 = CreateFlatNode(builder, 1, name1, OpType_TRANSFORM_SAME, transform::Abs, 0, in1, 0, vec1); - auto node2 = CreateFlatNode(builder, 2, name2, OpType_TRANSFORM_STRICT, transform::Cosine, 0, in2, 0, vec2); + auto node1 = CreateFlatNode(builder, 1, name1, OpType_TRANSFORM_SAME, +transform::Abs, 0, in1, 0, vec1); auto node2 = CreateFlatNode(builder, 2, name2, +OpType_TRANSFORM_STRICT, transform::Cosine, 0, in2, 0, vec2); std::vector> variables_vector; variables_vector.push_back(fVar); @@ -169,9 +170,10 @@ TEST_F(FlatBuffersTest, FlatGraphTest1) { ASSERT_EQ(1, graph.rootNodes()); - auto vs = graph.getVariableSpace(); + auto vs = graph.variableSpace(); - ASSERT_EQ(OutputMode_IMPLICIT, graph.getExecutorConfiguration()->_outputMode); + ASSERT_EQ(OutputMode_IMPLICIT, +graph.getExecutorConfiguration()->_outputMode); ASSERT_EQ(3, vs->totalEntries()); ASSERT_EQ(1, vs->externalEntries()); @@ -184,14 +186,16 @@ TEST_F(FlatBuffersTest, FlatGraphTest1) { sd::graph::GraphExecutioner::execute(&graph); - auto resultWrapper = sd::graph::GraphExecutioner::executeFlatBuffer((Nd4jPointer) buf); + auto resultWrapper = +sd::graph::GraphExecutioner::executeFlatBuffer((Nd4jPointer) buf); auto flatResults = GetFlatResult(resultWrapper->pointer()); ASSERT_EQ(1, flatResults->variables()->size()); ASSERT_TRUE(flatResults->variables()->Get(0)->name() != nullptr); ASSERT_TRUE(flatResults->variables()->Get(0)->name()->c_str() != nullptr); - //nd4j_printf("VARNAME: %s\n", flatResults->variables()->Get(0)->name()->c_str()); + //nd4j_printf("VARNAME: %s\n", +flatResults->variables()->Get(0)->name()->c_str()); auto var0 = new Variable(flatResults->variables()->Get(0)); //auto var1 = new Variable(flatResults->variables()->Get(1)); @@ -207,23 +211,23 @@ TEST_F(FlatBuffersTest, FlatGraphTest1) { } */ TEST_F(FlatBuffersTest, ExecutionTest1) { - auto gA = new Node(OpType_TRANSFORM_SAME); + auto gA = new Node(OpType_TRANSFORM_SAME); - auto c = new float[4] {-1, -2, -3, -4}; - auto array = new NDArray(c, cShape); + auto c = new float[4]{-1, -2, -3, -4}; + auto array = new NDArray(c, cShape); - auto e = new float[4] {1, 2, 3, 4}; - auto exp = new NDArray(e, cShape); + auto e = new float[4]{1, 2, 3, 4}; + auto exp = new NDArray(e, cShape); - //gA->execute(array, nullptr, array); + // gA->execute(array, nullptr, array); - //ASSERT_TRUE(exp->equalsTo(array)); + // ASSERT_TRUE(exp->equalsTo(array)); - delete gA; - delete[] c; - delete array; - delete[] e; - delete exp; + delete gA; + delete[] c; + delete array; + delete[] e; + delete exp; } /* @@ -265,7 +269,8 @@ TEST_F(FlatBuffersTest, ExplicitOutputTest1) { auto name1 = builder.CreateString("wow1"); - auto node1 = CreateFlatNode(builder, 1, name1, OpType_TRANSFORM, 0, in1, 0, sd::graph::DType::FLOAT); + auto node1 = CreateFlatNode(builder, 1, name1, OpType_TRANSFORM, 0, in1, 0, +sd::graph::DType::FLOAT); std::vector> variables_vector; variables_vector.push_back(fXVar); @@ -290,7 +295,8 @@ TEST_F(FlatBuffersTest, ExplicitOutputTest1) { auto flatGraph = graphBuilder.Finish(); builder.Finish(flatGraph); - auto restoredGraph = new Graph(GetFlatGraph(builder.GetBufferPointer())); + auto restoredGraph = new +Graph(GetFlatGraph(builder.GetBufferPointer())); GraphExecutioner::execute(restoredGraph); @@ -299,9 +305,12 @@ TEST_F(FlatBuffersTest, ExplicitOutputTest1) { // IMPLICIT is default ASSERT_EQ(1, results->size()); - //ASSERT_NEAR(-2.0, results->at(0)->getNDArray()->reduceNumber>(), 1e-5); - //ASSERT_NEAR(-1.0, results->at(1)->getNDArray()->reduceNumber>(), 1e-5); - ASSERT_NEAR(-3.0, results->at(0)->getNDArray()->reduceNumber>(), 1e-5); + //ASSERT_NEAR(-2.0, +results->at(0)->getNDArray()->reduceNumber>(), 1e-5); + //ASSERT_NEAR(-1.0, +results->at(1)->getNDArray()->reduceNumber>(), 1e-5); + ASSERT_NEAR(-3.0, +results->at(0)->getNDArray()->reduceNumber>(), 1e-5); //ASSERT_EQ(-1, results->at(0)->id()); //ASSERT_EQ(-2, results->at(1)->id()); @@ -324,7 +333,8 @@ TEST_F(FlatBuffersTest, ReadFile1) { ASSERT_EQ(1, restoredGraph->rootNodes()); ASSERT_EQ(2, restoredGraph->totalNodes()); - auto ones = restoredGraph->getVariableSpace()->getVariable(-1)->getNDArray(); + auto ones = +restoredGraph->getVariableSpace()->getVariable(-1)->getNDArray(); ASSERT_EQ(4, ones->lengthOf()); ASSERT_NEAR(4.0f, ones->template reduceNumber>(), 1e-5); @@ -332,9 +342,9 @@ TEST_F(FlatBuffersTest, ReadFile1) { Nd4jStatus status = GraphExecutioner::execute(restoredGraph); ASSERT_EQ(ND4J_STATUS_OK, status); - auto result = restoredGraph->getVariableSpace()->getVariable(2)->getNDArray(); - ASSERT_EQ(1, result->lengthOf()); - ASSERT_EQ(8, result->e(0)); + auto result = +restoredGraph->getVariableSpace()->getVariable(2)->getNDArray(); ASSERT_EQ(1, +result->lengthOf()); ASSERT_EQ(8, result->e(0)); delete[] data; delete restoredGraph; @@ -342,7 +352,8 @@ TEST_F(FlatBuffersTest, ReadFile1) { TEST_F(FlatBuffersTest, ReadFile2) { uint8_t* data = sd::graph::readFlatBuffers("./resources/adam_sum.fb"); - Nd4jPointer result = GraphExecutioner::executeFlatBuffer((Nd4jPointer) data); + Nd4jPointer result = +GraphExecutioner::executeFlatBuffer((Nd4jPointer) data); ResultSet arrays(GetFlatResult(result)); @@ -355,12 +366,13 @@ TEST_F(FlatBuffersTest, ReadFile2) { } TEST_F(FlatBuffersTest, ReadFile3) { - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/adam_sum.fb"); + auto graph = +GraphExecutioner::importFromFlatBuffers("./resources/adam_sum.fb"); Nd4jStatus status = GraphExecutioner::execute(graph); ASSERT_EQ(ND4J_STATUS_OK, status); - auto z = graph->getVariableSpace()->getVariable(2)->getNDArray(); + auto z = graph->variableSpace()->getVariable(2)->getNDArray(); ASSERT_EQ(1, z->lengthOf()); ASSERT_EQ(8, z->e(0)); @@ -370,14 +382,15 @@ TEST_F(FlatBuffersTest, ReadFile3) { TEST_F(FlatBuffersTest, ReadInception1) { - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/inception.fb"); + auto graph = +GraphExecutioner::importFromFlatBuffers("./resources/inception.fb"); Nd4jStatus status = GraphExecutioner::execute(graph); ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(graph->getVariableSpace()->hasVariable(227)); + ASSERT_TRUE(graph->variableSpace()->hasVariable(227)); - auto lastNode = graph->getVariableSpace()->getVariable(227)->getNDArray(); + auto lastNode = graph->variableSpace()->getVariable(227)->getNDArray(); lastNode->printShapeInfo("Result shape"); @@ -396,7 +409,8 @@ TEST_F(FlatBuffersTest, ReadInception1) { TEST_F(FlatBuffersTest, ReadLoops_3argsWhile_1) { // TF graph: // https://gist.github.com/raver119/b86ef727e9a094aab386e2b35e878966 - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/three_args_while.fb"); + auto graph = +GraphExecutioner::importFromFlatBuffers("./resources/three_args_while.fb"); ASSERT_TRUE(graph != nullptr); @@ -404,12 +418,12 @@ TEST_F(FlatBuffersTest, ReadLoops_3argsWhile_1) { auto expPhi('c', {2, 2}); - ASSERT_TRUE(graph->getVariableSpace()->hasVariable(-1)); - ASSERT_TRUE(graph->getVariableSpace()->hasVariable(-2)); + ASSERT_TRUE(graph->variableSpace()->hasVariable(-1)); + ASSERT_TRUE(graph->variableSpace()->hasVariable(-2)); - auto phi = graph->getVariableSpace()->getVariable(-2)->getNDArray(); - auto constA = graph->getVariableSpace()->getVariable(-5)->getNDArray(); - auto lessY = graph->getVariableSpace()->getVariable(-6)->getNDArray(); + auto phi = graph->variableSpace()->getVariable(-2)->getNDArray(); + auto constA = graph->variableSpace()->getVariable(-5)->getNDArray(); + auto lessY = graph->variableSpace()->getVariable(-6)->getNDArray(); //constA->printBuffer("constA"); //lessY->printBuffer("lessY"); @@ -422,8 +436,8 @@ TEST_F(FlatBuffersTest, ReadLoops_3argsWhile_1) { // now, we expect some values - auto x = graph->getVariableSpace()->getVariable(20); - auto y = graph->getVariableSpace()->getVariable(21); + auto x = graph->variableSpace()->getVariable(20); + auto y = graph->variableSpace()->getVariable(21); ASSERT_NEAR(110.0f, x->getNDArray()->meanNumber(), 1e-5); ASSERT_NEAR(33.0f, y->getNDArray()->meanNumber(), 1e-5); @@ -435,7 +449,8 @@ TEST_F(FlatBuffersTest, ReadLoops_3argsWhile_1) { TEST_F(FlatBuffersTest, ReadTensorArrayLoop_1) { auto exp('c', {5, 2}, {3., 6., 9., 12., 15., 18., 21., 24., 27., 30.}); - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/tensor_array_loop.fb"); + auto graph = +GraphExecutioner::importFromFlatBuffers("./resources/tensor_array_loop.fb"); ASSERT_TRUE(graph != nullptr); @@ -445,7 +460,7 @@ TEST_F(FlatBuffersTest, ReadTensorArrayLoop_1) { ASSERT_EQ(ND4J_STATUS_OK, status); - auto variableSpace = graph->getVariableSpace(); + auto variableSpace = graph->variableSpace(); ASSERT_TRUE(variableSpace->hasVariable(23,0)); @@ -468,7 +483,8 @@ TEST_F(FlatBuffersTest, ReadLoops_NestedWhile_1) { // https://gist.github.com/raver119/2aa49daf7ec09ed4ddddbc6262f213a0 sd::ops::assign op1; - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/nested_while.fb"); + auto graph = +GraphExecutioner::importFromFlatBuffers("./resources/nested_while.fb"); ASSERT_TRUE(graph != nullptr); @@ -476,9 +492,9 @@ TEST_F(FlatBuffersTest, ReadLoops_NestedWhile_1) { ASSERT_EQ(ND4J_STATUS_OK, status); - auto x = graph->getVariableSpace()->getVariable(28); - auto y = graph->getVariableSpace()->getVariable(29); - auto z = graph->getVariableSpace()->getVariable(11, 2); + auto x = graph->variableSpace()->getVariable(28); + auto y = graph->variableSpace()->getVariable(29); + auto z = graph->variableSpace()->getVariable(11, 2); ASSERT_NEAR(110.0f, x->getNDArray()->meanNumber(), 1e-5); ASSERT_NEAR(33.0f, y->getNDArray()->meanNumber(), 1e-5); @@ -492,11 +508,14 @@ TEST_F(FlatBuffersTest, ReadLoops_NestedWhile_1) { /* TEST_F(FlatBuffersTest, ReadTensorArray_1) { - // TF graph: https://gist.github.com/raver119/3265923eed48feecc465d17ec842b6e2 + // TF graph: +https://gist.github.com/raver119/3265923eed48feecc465d17ec842b6e2 - auto exp('c', {3, 2}, {1.000000, 1.000000, 2.000000, 2.000000, 3.000000, 3.000000}); + auto exp('c', {3, 2}, +{1.000000, 1.000000, 2.000000, 2.000000, 3.000000, 3.000000}); - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/tensor_array.fb"); + auto graph = +GraphExecutioner::importFromFlatBuffers("./resources/tensor_array.fb"); ASSERT_TRUE(graph != nullptr); @@ -504,9 +523,9 @@ TEST_F(FlatBuffersTest, ReadTensorArray_1) { ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(graph->getVariableSpace()->hasVariable(14)); + ASSERT_TRUE(graph->variableSpace()->hasVariable(14)); - auto z = graph->getVariableSpace()->getVariable(14)->getNDArray(); + auto z = graph->variableSpace()->getVariable(14)->getNDArray(); ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); @@ -517,8 +536,9 @@ TEST_F(FlatBuffersTest, ReadTensorArray_1) { */ /* TEST_F(FlatBuffersTest, ReadStridedSlice_1) { - // TF graph: https://gist.github.com/raver119/fc3bf2d31c91e465c635b24020fd798d - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/tensor_slice.fb"); + // TF graph: +https://gist.github.com/raver119/fc3bf2d31c91e465c635b24020fd798d auto graph = +GraphExecutioner::importFromFlatBuffers("./resources/tensor_slice.fb"); ASSERT_TRUE(graph != nullptr); @@ -526,9 +546,9 @@ TEST_F(FlatBuffersTest, ReadStridedSlice_1) { ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(graph->getVariableSpace()->hasVariable(7)); + ASSERT_TRUE(graph->variableSpace()->hasVariable(7)); - auto z = graph->getVariableSpace()->getVariable(7)->getNDArray(); + auto z = graph->variableSpace()->getVariable(7)->getNDArray(); ASSERT_NEAR(73.5f, z->e(0), 1e-5); @@ -542,11 +562,12 @@ TEST_F(FlatBuffersTest, ReduceDim_1) { exp.assign(3.0); - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/reduce_dim_false.fb"); + auto graph = +GraphExecutioner::importFromFlatBuffers("./resources/reduce_dim_false.fb"); graph->printOut(); - auto variableSpace = graph->getVariableSpace(); + auto variableSpace = graph->variableSpace(); ASSERT_TRUE(variableSpace->hasVariable(1)); @@ -575,11 +596,12 @@ TEST_F(FlatBuffersTest, ReduceDim_2) { exp.assign(3.0); - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/reduce_dim_true.fb"); + auto graph = +GraphExecutioner::importFromFlatBuffers("./resources/reduce_dim_true.fb"); graph->printOut(); - auto variableSpace = graph->getVariableSpace(); + auto variableSpace = graph->variableSpace(); ASSERT_TRUE(variableSpace->hasVariable(1)); @@ -605,160 +627,445 @@ TEST_F(FlatBuffersTest, ReduceDim_2) { #ifdef GRAPH_FILES_OK TEST_F(FlatBuffersTest, Ae_00) { - sd::ops::rank op1; + sd::ops::rank op1; - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/ae_00.fb"); + auto graph = GraphExecutioner::importFromFlatBuffers("./resources/ae_00.fb"); - auto exp = NDArrayFactory::create('c', {5, 4}, {0.32454616f, -0.06604697f, 0.22593613f, 0.43166467f, -0.18320604f, 0.00102305f, -0.06963076f, 0.25266643f, 0.07568010f, -0.03009197f, 0.07805517f, 0.33180334f, -0.06220427f, 0.07249600f, -0.06726961f, -0.22998397f, -0.06343779f, 0.07384885f, -0.06891008f, -0.23745790f}); + auto exp = NDArrayFactory::create( + 'c', {5, 4}, + {0.32454616f, -0.06604697f, 0.22593613f, 0.43166467f, -0.18320604f, + 0.00102305f, -0.06963076f, 0.25266643f, 0.07568010f, -0.03009197f, + 0.07805517f, 0.33180334f, -0.06220427f, 0.07249600f, -0.06726961f, + -0.22998397f, -0.06343779f, 0.07384885f, -0.06891008f, -0.23745790f}); -// graph->printOut(); + // graph->printOut(); - ASSERT_EQ(OutputMode_VARIABLE_SPACE, graph->getExecutorConfiguration()->_outputMode); + ASSERT_EQ(OutputMode_VARIABLE_SPACE, + graph->getExecutorConfiguration()->_outputMode); - auto result = GraphExecutioner::execute(graph); - ASSERT_EQ(ND4J_STATUS_OK, result); + auto result = GraphExecutioner::execute(graph); + ASSERT_EQ(ND4J_STATUS_OK, result); - ASSERT_TRUE(graph->getVariableSpace()->hasVariable(18)); + ASSERT_TRUE(graph->variableSpace()->hasVariable(18)); - auto z = graph->getVariableSpace()->getVariable(18)->getNDArray(); + auto z = graph->variableSpace()->getVariable(18)->getNDArray(); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); - delete graph; + delete graph; } TEST_F(FlatBuffersTest, expand_dims) { - sd::ops::rank op1; + sd::ops::rank op1; - auto exp = NDArrayFactory::create('c', {3, 1, 4}, {-0.95938617f, -1.20301781f, 1.22260064f, 0.50172403f, 0.59972949f, 0.78568028f, 0.31609724f, 1.51674747f, 0.68013491f, -0.05227458f, 0.25903158f, 1.13243439f}); + auto exp = NDArrayFactory::create( + 'c', {3, 1, 4}, + {-0.95938617f, -1.20301781f, 1.22260064f, 0.50172403f, 0.59972949f, + 0.78568028f, 0.31609724f, 1.51674747f, 0.68013491f, -0.05227458f, + 0.25903158f, 1.13243439f}); - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/expand_dim.fb"); + auto graph = + GraphExecutioner::importFromFlatBuffers("./resources/expand_dim.fb"); -// graph->printOut(); + // graph->printOut(); - auto result = GraphExecutioner::execute(graph); - ASSERT_EQ(ND4J_STATUS_OK, result); - ASSERT_TRUE(graph->getVariableSpace()->hasVariable(5)); + auto result = GraphExecutioner::execute(graph); + ASSERT_EQ(ND4J_STATUS_OK, result); + ASSERT_TRUE(graph->variableSpace()->hasVariable(5)); - auto z = graph->getVariableSpace()->getVariable(5)->getNDArray(); + auto z = graph->variableSpace()->getVariable(5)->getNDArray(); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); - delete graph; + delete graph; } TEST_F(FlatBuffersTest, transpose) { - sd::ops::rank op1; + sd::ops::rank op1; - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/transpose.fb"); + auto graph = + GraphExecutioner::importFromFlatBuffers("./resources/transpose.fb"); - //graph->printOut(); + // graph->printOut(); - auto result = GraphExecutioner::execute(graph); - ASSERT_EQ(ND4J_STATUS_OK, result); + auto result = GraphExecutioner::execute(graph); + ASSERT_EQ(ND4J_STATUS_OK, result); - delete graph; + delete graph; } TEST_F(FlatBuffersTest, Test_Stitches) { - sd::ops::realdiv op0; + sd::ops::realdiv op0; - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/partition_stitch_misc.fb"); - //graph->printOut(); + auto graph = GraphExecutioner::importFromFlatBuffers( + "./resources/partition_stitch_misc.fb"); + // graph->printOut(); + auto result = GraphExecutioner::execute(graph); + ASSERT_EQ(ND4J_STATUS_OK, result); - auto result = GraphExecutioner::execute(graph); - ASSERT_EQ(ND4J_STATUS_OK, result); - - delete graph; + delete graph; } TEST_F(FlatBuffersTest, Test_GruDynamicMnist) { - sd::Environment::getInstance().setDebug(false); - sd::Environment::getInstance().setVerbose(false); + sd::Environment::getInstance().setDebug(false); + sd::Environment::getInstance().setVerbose(false); - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/gru_dynamic_mnist.fb"); - //graph->printOut(); + auto graph = GraphExecutioner::importFromFlatBuffers( + "./resources/gru_dynamic_mnist.fb"); + // graph->printOut(); - auto timeStart = std::chrono::system_clock::now(); - auto result = GraphExecutioner::execute(graph); - ASSERT_EQ(ND4J_STATUS_OK, result); + auto timeStart = std::chrono::system_clock::now(); + auto result = GraphExecutioner::execute(graph); + ASSERT_EQ(ND4J_STATUS_OK, result); - auto timeEnd = std::chrono::system_clock::now(); + auto timeEnd = std::chrono::system_clock::now(); - auto outerTime = std::chrono::duration_cast (timeEnd - timeStart).count(); + auto outerTime = + std::chrono::duration_cast(timeEnd - timeStart) + .count(); - // nd4j_printf("GRU time 1 time %lld us\n", outerTime); + // nd4j_printf("GRU time 1 time %lld us\n", outerTime); - delete graph; + delete graph; } TEST_F(FlatBuffersTest, Test_Non2D_2) { - sd::Environment::getInstance().setDebug(false); - sd::Environment::getInstance().setVerbose(false); - sd::ops::realdiv op0; + sd::Environment::getInstance().setDebug(false); + sd::Environment::getInstance().setVerbose(false); + sd::ops::realdiv op0; - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/non2d_2.fb"); - //graph->printOut(); + auto graph = + GraphExecutioner::importFromFlatBuffers("./resources/non2d_2.fb"); + // graph->printOut(); - auto result = GraphExecutioner::execute(graph); - ASSERT_EQ(ND4J_STATUS_OK, result); + auto result = GraphExecutioner::execute(graph); + ASSERT_EQ(ND4J_STATUS_OK, result); - delete graph; + delete graph; } - TEST_F(FlatBuffersTest, Test_TensorDotMisc) { - Environment::getInstance().setVerbose(false); - Environment::getInstance().setDebug(false); - - auto e = NDArrayFactory::create('c', {1, 3, 16, 20}, {4.f, 6.f, 6.f, 5.f, 6.f, 4.f, 2.f, 3.f, 5.f, 5.f, 1.f, 4.f, 6.f, 3.f, 2.f, 1.f, 5.f, 4.f, 4.f, 4.f, 4.f, 4.f, 3.f, 4.f, 2.f, 3.f, 3.f, 5.f, 3.f, 6.f, 5.f, 4.f, 4.f, 3.f, 6.f, 1.f, 2.f, 4.f, 2.f, 6.f, 4.f, 2.f, 3.f, 2.f, 3.f, 1.f, 2.f, 4.f, 3.f, 5.f, 3.f, 3.f, 5.f, 2.f, 6.f, 3.f, 4.f, 4.f, 4.f, 4.f, 6.f, 4.f, 5.f, 2.f, 5.f, 5.f, 5.f, 5.f, 2.f, 4.f, 4.f, 4.f, 5.f, 4.f, 3.f, 6.f, 3.f, 4.f, 5.f, 2.f, 5.f, 4.f, 4.f, 5.f, 4.f, 3.f, 4.f, 5.f, 5.f, 3.f, 5.f, 6.f, 6.f, 3.f, 4.f, 5.f, 7.f, 6.f, 5.f, 2.f, 4.f, 5.f, 5.f, 4.f, 5.f, 4.f, 4.f, 6.f, 3.f, 4.f, 5.f, 4.f, 6.f, 2.f, 3.f, 4.f, 3.f, 3.f, 2.f, 2.f, 3.f, 4.f, 7.f, 3.f, 5.f, 4.f, 5.f, 4.f, 4.f, 4.f, 4.f, 6.f, 2.f, 3.f, 2.f, 5.f, 5.f, 4.f, 5.f, 2.f, 2.f, 1.f, 6.f, 2.f, 2.f, 3.f, 4.f, 5.f, 5.f, 3.f, 6.f, 6.f, 4.f, 3.f, 3.f, 3.f, 3.f, 3.f, 4.f, 5.f, 4.f, 4.f, 3.f, 5.f, 2.f, 3.f, 4.f, 5.f, 3.f, 4.f, 5.f, 5.f, 8.f, 4.f, 5.f, 3.f, 3.f, 4.f, 4.f, 5.f, 4.f, 5.f, 3.f, 3.f, 7.f, 2.f, 3.f, 2.f, 6.f, 6.f, 4.f, 4.f, 3.f, 5.f, 6.f, 2.f, 4.f, 3.f, 3.f, 4.f, 5.f, 3.f, 3.f, 6.f, 5.f, 3.f, 2.f, 5.f, 4.f, 4.f, 3.f, 5.f, 5.f, 6.f, 7.f, 3.f, 4.f, 3.f, 5.f, 6.f, 7.f, 5.f, 6.f, 5.f, 7.f, 4.f, 6.f, 5.f, 5.f, 6.f, 4.f, 2.f, 5.f, 4.f, 3.f, 4.f, 1.f, 5.f, 5.f, 3.f, 2.f, 2.f, 6.f, 5.f, 5.f, 2.f, 5.f, 2.f, 4.f, 4.f, 5.f, 5.f, 4.f, 3.f, 7.f, 4.f, 5.f, 3.f, 3.f, 3.f, 2.f, 3.f, 2.f, 3.f, 3.f, 4.f, 4.f, 2.f, 4.f, 5.f, 3.f, 4.f, 5.f, 3.f, 7.f, 2.f, 1.f, 3.f, 2.f, 3.f, 2.f, 3.f, 3.f, 4.f, 3.f, 4.f, 2.f, 4.f, 4.f, 4.f, 5.f, 3.f, 5.f, 3.f, 6.f, 6.f, 5.f, 3.f, 5.f, 3.f, 4.f, 3.f, 5.f, 3.f, 5.f, 6.f, 5.f, 3.f, 4.f, 5.f, 5.f, 3.f, 3.f, 3.f, 4.f, 6.f, 4.f, 3.f, 7.f, 4.f, 4.f, 6.f, 7.f, 5.f, 5.f, 3.f, 1.f, 2.f, 5.f, 5.f, 2.f, 5.f, 7.f, 5.f, 3.f, 1.f, 4.f, 6.f, 5.f, 7.f, 5.f, 6.f, 5.f, 6.f, 4.f, 3.f, 3.f, 4.f, 3.f, 4.f, 4.f, 4.f, 4.f, 3.f, 5.f, 2.f, 4.f, 5.f, 2.f, 5.f, 5.f, 4.f, 5.f, 4.f, 5.f, 2.f, 3.f, 5.f, 3.f, 6.f, 3.f, 4.f, 5.f, 3.f, 6.f, 5.f, 5.f, 6.f, 4.f, 6.f, 7.f, 4.f, 5.f, 3.f, 5.f, 4.f, 4.f, 4.f, 2.f, 2.f, 5.f, 3.f, 5.f, 3.f, 4.f, 6.f, 3.f, 5.f, 5.f, 3.f, 5.f, 4.f, 4.f, 4.f, 5.f, 2.f, 3.f, 5.f, 4.f, 2.f, 4.f, 5.f, 4.f, 2.f, 3.f, 4.f, 4.f, 5.f, 5.f, 1.f, 4.f, 4.f, 4.f, 3.f, 4.f, 5.f, 5.f, 8.f, 4.f, 4.f, 4.f, 3.f, 6.f, 2.f, 3.f, 4.f, 4.f, 4.f, 3.f, 2.f, 3.f, 4.f, 8.f, 3.f, 5.f, 5.f, 5.f, 3.f, 3.f, 4.f, 5.f, 7.f, 3.f, 3.f, 3.f, 6.f, 6.f, 5.f, 5.f, 3.f, 4.f, 3.f, 8.f, 3.f, 4.f, 2.f, 3.f, 4.f, 4.f, 3.f, 5.f, 5.f, 3.f, 2.f, 3.f, 3.f, 3.f, 4.f, 4.f, 4.f, 6.f, 6.f, 5.f, 6.f, 4.f, 5.f, 4.f, 6.f, 4.f, 5.f, 5.f, 4.f, 7.f, 3.f, 5.f, 5.f, 3.f, 5.f, 5.f, 6.f, 4.f, 5.f, 4.f, 2.f, 7.f, 2.f, 3.f, 1.f, 4.f, 5.f, 5.f, 4.f, 4.f, 5.f, 7.f, 2.f, 3.f, 3.f, 4.f, 4.f, 5.f, 3.f, 3.f, 6.f, 6.f, 3.f, 2.f, 4.f, 3.f, 3.f, 3.f, 3.f, 4.f, 4.f, 5.f, 1.f, 2.f, 3.f, 3.f, 4.f, 5.f, 4.f, 5.f, 4.f, 5.f, 6.f, 6.f, 6.f, 6.f, 7.f, 4.f, 3.f, 4.f, 5.f, 4.f, 4.f, 2.f, 5.f, 6.f, 4.f, 2.f, 2.f, 6.f, 5.f, 5.f, 1.f, 4.f, 2.f, 3.f, 4.f, 5.f, 5.f, 4.f, 5.f, 9.f, 4.f, 6.f, 4.f, 5.f, 5.f, 3.f, 4.f, 5.f, 5.f, 5.f, 4.f, 3.f, 1.f, 3.f, 4.f, 3.f, 4.f, 4.f, 3.f, 6.f, 2.f, 3.f, 3.f, 2.f, 3.f, 3.f, 4.f, 5.f, 6.f, 5.f, 5.f, 3.f, 4.f, 5.f, 5.f, 4.f, 3.f, 4.f, 3.f, 6.f, 7.f, 6.f, 4.f, 6.f, 4.f, 3.f, 3.f, 4.f, 3.f, 5.f, 5.f, 4.f, 2.f, 3.f, 4.f, 5.f, 3.f, 4.f, 2.f, 4.f, 5.f, 3.f, 3.f, 7.f, 4.f, 2.f, 5.f, 6.f, 5.f, 5.f, 3.f, 1.f, 2.f, 4.f, 4.f, 1.f, 3.f, 6.f, 3.f, 3.f, 1.f, 4.f, 4.f, 4.f, 5.f, 3.f, 4.f, 3.f, 4.f, 2.f, 3.f, 3.f, 4.f, 3.f, 4.f, 3.f, 3.f, 4.f, 2.f, 5.f, 1.f, 3.f, 4.f, 2.f, 6.f, 4.f, 3.f, 4.f, 3.f, 3.f, 1.f, 2.f, 5.f, 2.f, 6.f, 4.f, 5.f, 6.f, 3.f, 6.f, 4.f, 4.f, 5.f, 3.f, 5.f, 6.f, 3.f, 4.f, 2.f, 4.f, 5.f, 5.f, 5.f, 2.f, 3.f, 4.f, 3.f, 5.f, 3.f, 3.f, 9.f, 6.f, 7.f, 7.f, 4.f, 4.f, 3.f, 3.f, 4.f, 4.f, 3.f, 4.f, 6.f, 5.f, 3.f, 5.f, 5.f, 5.f, 2.f, 4.f, 6.f, 7.f, 7.f, 5.f, 3.f, 4.f, 5.f, 4.f, 4.f, 5.f, 5.f, 5.f, 8.f, 4.f, 4.f, 4.f, 3.f, 5.f, 3.f, 3.f, 4.f, 4.f, 5.f, 3.f, 3.f, 2.f, 3.f, 6.f, 2.f, 5.f, 4.f, 4.f, 3.f, 3.f, 3.f, 5.f, 7.f, 2.f, 3.f, 2.f, 5.f, 5.f, 4.f, 4.f, 2.f, 2.f, 1.f, 6.f, 1.f, 2.f, 2.f, 3.f, 5.f, 4.f, 3.f, 5.f, 5.f, 3.f, 2.f, 2.f, 2.f, 2.f, 4.f, 3.f, 4.f, 4.f, 4.f, 4.f, 5.f, 2.f, 4.f, 4.f, 5.f, 2.f, 4.f, 4.f, 5.f, 9.f, 4.f, 5.f, 4.f, 3.f, 5.f, 5.f, 6.f, 4.f, 4.f, 3.f, 3.f, 6.f, 2.f, 3.f, 2.f, 5.f, 6.f, 4.f, 4.f, 3.f, 5.f, 6.f, 4.f, 5.f, 5.f, 6.f, 7.f, 4.f, 2.f, 3.f, 5.f, 4.f, 4.f, 3.f, 5.f, 5.f, 4.f, 3.f, 4.f, 5.f, 4.f, 6.f, 3.f, 4.f, 4.f, 5.f, 6.f, 6.f, 4.f, 6.f, 6.f, 6.f, 5.f, 6.f, 6.f, 7.f, 7.f, 4.f, 3.f, 4.f, 4.f, 4.f, 5.f, 2.f, 5.f, 7.f, 5.f, 2.f, 1.f, 5.f, 5.f, 4.f, 1.f, 4.f, 1.f, 3.f, 3.f, 5.f, 4.f, 4.f, 3.f, 7.f, 3.f, 6.f, 3.f, 3.f, 4.f, 1.f, 3.f, 2.f, 3.f, 3.f, 4.f, 3.f, 1.f, 3.f, 4.f, 2.f, 4.f, 4.f, 2.f, 6.f, 1.f, 2.f, 2.f, 2.f, 3.f, 2.f, 3.f, 3.f, 4.f, 4.f, 4.f, 2.f, 4.f, 4.f, 4.f, 5.f, 5.f, 5.f, 4.f, 8.f, 5.f, 5.f, 3.f, 5.f, 3.f, 3.f, 2.f, 4.f, 3.f, 5.f, 6.f, 5.f, 3.f, 4.f, 5.f, 5.f, 3.f, 4.f, 3.f, 4.f, 8.f, 6.f, 5.f, 9.f, 6.f}); - - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/tensor_dot_misc.fb"); -// graph->printOut(); - - auto result = GraphExecutioner::execute(graph); - ASSERT_EQ(Status::OK(), result); - - ASSERT_TRUE(graph->getVariableSpace()->hasVariable(77)); - - auto z = graph->getVariableSpace()->getVariable(77,0)->getNDArray(); - - ASSERT_EQ(e, *z); - - delete graph; + Environment::getInstance().setVerbose(false); + Environment::getInstance().setDebug(false); + + auto e = NDArrayFactory::create( + 'c', {1, 3, 16, 20}, + {4.f, 6.f, 6.f, 5.f, 6.f, 4.f, 2.f, 3.f, 5.f, 5.f, 1.f, 4.f, 6.f, 3.f, + 2.f, 1.f, 5.f, 4.f, 4.f, 4.f, 4.f, 4.f, 3.f, 4.f, 2.f, 3.f, 3.f, 5.f, + 3.f, 6.f, 5.f, 4.f, 4.f, 3.f, 6.f, 1.f, 2.f, 4.f, 2.f, 6.f, 4.f, 2.f, + 3.f, 2.f, 3.f, 1.f, 2.f, 4.f, 3.f, 5.f, 3.f, 3.f, 5.f, 2.f, 6.f, 3.f, + 4.f, 4.f, 4.f, 4.f, 6.f, 4.f, 5.f, 2.f, 5.f, 5.f, 5.f, 5.f, 2.f, 4.f, + 4.f, 4.f, 5.f, 4.f, 3.f, 6.f, 3.f, 4.f, 5.f, 2.f, 5.f, 4.f, 4.f, 5.f, + 4.f, 3.f, 4.f, 5.f, 5.f, 3.f, 5.f, 6.f, 6.f, 3.f, 4.f, 5.f, 7.f, 6.f, + 5.f, 2.f, 4.f, 5.f, 5.f, 4.f, 5.f, 4.f, 4.f, 6.f, 3.f, 4.f, 5.f, 4.f, + 6.f, 2.f, 3.f, 4.f, 3.f, 3.f, 2.f, 2.f, 3.f, 4.f, 7.f, 3.f, 5.f, 4.f, + 5.f, 4.f, 4.f, 4.f, 4.f, 6.f, 2.f, 3.f, 2.f, 5.f, 5.f, 4.f, 5.f, 2.f, + 2.f, 1.f, 6.f, 2.f, 2.f, 3.f, 4.f, 5.f, 5.f, 3.f, 6.f, 6.f, 4.f, 3.f, + 3.f, 3.f, 3.f, 3.f, 4.f, 5.f, 4.f, 4.f, 3.f, 5.f, 2.f, 3.f, 4.f, 5.f, + 3.f, 4.f, 5.f, 5.f, 8.f, 4.f, 5.f, 3.f, 3.f, 4.f, 4.f, 5.f, 4.f, 5.f, + 3.f, 3.f, 7.f, 2.f, 3.f, 2.f, 6.f, 6.f, 4.f, 4.f, 3.f, 5.f, 6.f, 2.f, + 4.f, 3.f, 3.f, 4.f, 5.f, 3.f, 3.f, 6.f, 5.f, 3.f, 2.f, 5.f, 4.f, 4.f, + 3.f, 5.f, 5.f, 6.f, 7.f, 3.f, 4.f, 3.f, 5.f, 6.f, 7.f, 5.f, 6.f, 5.f, + 7.f, 4.f, 6.f, 5.f, 5.f, 6.f, 4.f, 2.f, 5.f, 4.f, 3.f, 4.f, 1.f, 5.f, + 5.f, 3.f, 2.f, 2.f, 6.f, 5.f, 5.f, 2.f, 5.f, 2.f, 4.f, 4.f, 5.f, 5.f, + 4.f, 3.f, 7.f, 4.f, 5.f, 3.f, 3.f, 3.f, 2.f, 3.f, 2.f, 3.f, 3.f, 4.f, + 4.f, 2.f, 4.f, 5.f, 3.f, 4.f, 5.f, 3.f, 7.f, 2.f, 1.f, 3.f, 2.f, 3.f, + 2.f, 3.f, 3.f, 4.f, 3.f, 4.f, 2.f, 4.f, 4.f, 4.f, 5.f, 3.f, 5.f, 3.f, + 6.f, 6.f, 5.f, 3.f, 5.f, 3.f, 4.f, 3.f, 5.f, 3.f, 5.f, 6.f, 5.f, 3.f, + 4.f, 5.f, 5.f, 3.f, 3.f, 3.f, 4.f, 6.f, 4.f, 3.f, 7.f, 4.f, 4.f, 6.f, + 7.f, 5.f, 5.f, 3.f, 1.f, 2.f, 5.f, 5.f, 2.f, 5.f, 7.f, 5.f, 3.f, 1.f, + 4.f, 6.f, 5.f, 7.f, 5.f, 6.f, 5.f, 6.f, 4.f, 3.f, 3.f, 4.f, 3.f, 4.f, + 4.f, 4.f, 4.f, 3.f, 5.f, 2.f, 4.f, 5.f, 2.f, 5.f, 5.f, 4.f, 5.f, 4.f, + 5.f, 2.f, 3.f, 5.f, 3.f, 6.f, 3.f, 4.f, 5.f, 3.f, 6.f, 5.f, 5.f, 6.f, + 4.f, 6.f, 7.f, 4.f, 5.f, 3.f, 5.f, 4.f, 4.f, 4.f, 2.f, 2.f, 5.f, 3.f, + 5.f, 3.f, 4.f, 6.f, 3.f, 5.f, 5.f, 3.f, 5.f, 4.f, 4.f, 4.f, 5.f, 2.f, + 3.f, 5.f, 4.f, 2.f, 4.f, 5.f, 4.f, 2.f, 3.f, 4.f, 4.f, 5.f, 5.f, 1.f, + 4.f, 4.f, 4.f, 3.f, 4.f, 5.f, 5.f, 8.f, 4.f, 4.f, 4.f, 3.f, 6.f, 2.f, + 3.f, 4.f, 4.f, 4.f, 3.f, 2.f, 3.f, 4.f, 8.f, 3.f, 5.f, 5.f, 5.f, 3.f, + 3.f, 4.f, 5.f, 7.f, 3.f, 3.f, 3.f, 6.f, 6.f, 5.f, 5.f, 3.f, 4.f, 3.f, + 8.f, 3.f, 4.f, 2.f, 3.f, 4.f, 4.f, 3.f, 5.f, 5.f, 3.f, 2.f, 3.f, 3.f, + 3.f, 4.f, 4.f, 4.f, 6.f, 6.f, 5.f, 6.f, 4.f, 5.f, 4.f, 6.f, 4.f, 5.f, + 5.f, 4.f, 7.f, 3.f, 5.f, 5.f, 3.f, 5.f, 5.f, 6.f, 4.f, 5.f, 4.f, 2.f, + 7.f, 2.f, 3.f, 1.f, 4.f, 5.f, 5.f, 4.f, 4.f, 5.f, 7.f, 2.f, 3.f, 3.f, + 4.f, 4.f, 5.f, 3.f, 3.f, 6.f, 6.f, 3.f, 2.f, 4.f, 3.f, 3.f, 3.f, 3.f, + 4.f, 4.f, 5.f, 1.f, 2.f, 3.f, 3.f, 4.f, 5.f, 4.f, 5.f, 4.f, 5.f, 6.f, + 6.f, 6.f, 6.f, 7.f, 4.f, 3.f, 4.f, 5.f, 4.f, 4.f, 2.f, 5.f, 6.f, 4.f, + 2.f, 2.f, 6.f, 5.f, 5.f, 1.f, 4.f, 2.f, 3.f, 4.f, 5.f, 5.f, 4.f, 5.f, + 9.f, 4.f, 6.f, 4.f, 5.f, 5.f, 3.f, 4.f, 5.f, 5.f, 5.f, 4.f, 3.f, 1.f, + 3.f, 4.f, 3.f, 4.f, 4.f, 3.f, 6.f, 2.f, 3.f, 3.f, 2.f, 3.f, 3.f, 4.f, + 5.f, 6.f, 5.f, 5.f, 3.f, 4.f, 5.f, 5.f, 4.f, 3.f, 4.f, 3.f, 6.f, 7.f, + 6.f, 4.f, 6.f, 4.f, 3.f, 3.f, 4.f, 3.f, 5.f, 5.f, 4.f, 2.f, 3.f, 4.f, + 5.f, 3.f, 4.f, 2.f, 4.f, 5.f, 3.f, 3.f, 7.f, 4.f, 2.f, 5.f, 6.f, 5.f, + 5.f, 3.f, 1.f, 2.f, 4.f, 4.f, 1.f, 3.f, 6.f, 3.f, 3.f, 1.f, 4.f, 4.f, + 4.f, 5.f, 3.f, 4.f, 3.f, 4.f, 2.f, 3.f, 3.f, 4.f, 3.f, 4.f, 3.f, 3.f, + 4.f, 2.f, 5.f, 1.f, 3.f, 4.f, 2.f, 6.f, 4.f, 3.f, 4.f, 3.f, 3.f, 1.f, + 2.f, 5.f, 2.f, 6.f, 4.f, 5.f, 6.f, 3.f, 6.f, 4.f, 4.f, 5.f, 3.f, 5.f, + 6.f, 3.f, 4.f, 2.f, 4.f, 5.f, 5.f, 5.f, 2.f, 3.f, 4.f, 3.f, 5.f, 3.f, + 3.f, 9.f, 6.f, 7.f, 7.f, 4.f, 4.f, 3.f, 3.f, 4.f, 4.f, 3.f, 4.f, 6.f, + 5.f, 3.f, 5.f, 5.f, 5.f, 2.f, 4.f, 6.f, 7.f, 7.f, 5.f, 3.f, 4.f, 5.f, + 4.f, 4.f, 5.f, 5.f, 5.f, 8.f, 4.f, 4.f, 4.f, 3.f, 5.f, 3.f, 3.f, 4.f, + 4.f, 5.f, 3.f, 3.f, 2.f, 3.f, 6.f, 2.f, 5.f, 4.f, 4.f, 3.f, 3.f, 3.f, + 5.f, 7.f, 2.f, 3.f, 2.f, 5.f, 5.f, 4.f, 4.f, 2.f, 2.f, 1.f, 6.f, 1.f, + 2.f, 2.f, 3.f, 5.f, 4.f, 3.f, 5.f, 5.f, 3.f, 2.f, 2.f, 2.f, 2.f, 4.f, + 3.f, 4.f, 4.f, 4.f, 4.f, 5.f, 2.f, 4.f, 4.f, 5.f, 2.f, 4.f, 4.f, 5.f, + 9.f, 4.f, 5.f, 4.f, 3.f, 5.f, 5.f, 6.f, 4.f, 4.f, 3.f, 3.f, 6.f, 2.f, + 3.f, 2.f, 5.f, 6.f, 4.f, 4.f, 3.f, 5.f, 6.f, 4.f, 5.f, 5.f, 6.f, 7.f, + 4.f, 2.f, 3.f, 5.f, 4.f, 4.f, 3.f, 5.f, 5.f, 4.f, 3.f, 4.f, 5.f, 4.f, + 6.f, 3.f, 4.f, 4.f, 5.f, 6.f, 6.f, 4.f, 6.f, 6.f, 6.f, 5.f, 6.f, 6.f, + 7.f, 7.f, 4.f, 3.f, 4.f, 4.f, 4.f, 5.f, 2.f, 5.f, 7.f, 5.f, 2.f, 1.f, + 5.f, 5.f, 4.f, 1.f, 4.f, 1.f, 3.f, 3.f, 5.f, 4.f, 4.f, 3.f, 7.f, 3.f, + 6.f, 3.f, 3.f, 4.f, 1.f, 3.f, 2.f, 3.f, 3.f, 4.f, 3.f, 1.f, 3.f, 4.f, + 2.f, 4.f, 4.f, 2.f, 6.f, 1.f, 2.f, 2.f, 2.f, 3.f, 2.f, 3.f, 3.f, 4.f, + 4.f, 4.f, 2.f, 4.f, 4.f, 4.f, 5.f, 5.f, 5.f, 4.f, 8.f, 5.f, 5.f, 3.f, + 5.f, 3.f, 3.f, 2.f, 4.f, 3.f, 5.f, 6.f, 5.f, 3.f, 4.f, 5.f, 5.f, 3.f, + 4.f, 3.f, 4.f, 8.f, 6.f, 5.f, 9.f, 6.f}); + + auto graph = + GraphExecutioner::importFromFlatBuffers("./resources/tensor_dot_misc.fb"); + // graph->printOut(); + + auto result = GraphExecutioner::execute(graph); + ASSERT_EQ(Status::OK(), result); + + ASSERT_TRUE(graph->variableSpace()->hasVariable(77)); + + auto z = graph->variableSpace()->getVariable(77, 0)->getNDArray(); + + ASSERT_EQ(e, *z); + + delete graph; } - TEST_F(FlatBuffersTest, Test_MNIST_00_1) { - auto e = NDArrayFactory::create('c', {100, 10}, {0.00066107f, 0.00002358f, 0.00031518f, 0.00238039f, 0.00027216f, 0.00030300f, 0.00004659f, 0.98962247f, 0.00050380f, 0.00587174f, 0.05895791f, 0.00323104f, 0.52636790f, 0.12912551f, 0.00003951f, 0.03615341f, 0.22013727f, 0.00007333f, 0.02566659f, 0.00024759f, 0.00192367f, 0.90509874f, 0.01985082f, 0.02080356f, 0.00260053f, 0.00497826f, 0.01107823f, 0.00872595f, 0.01559795f, 0.00934229f, 0.98202229f, 0.00000150f, 0.00137381f, 0.00082931f, 0.00001806f, 0.00384426f, 0.00758274f, 0.00305049f, 0.00052152f, 0.00075617f, 0.01094264f, 0.00044708f, 0.03576852f, 0.00711267f, 0.65963465f, 0.00734364f, 0.02747800f, 0.06494589f, 0.02966754f, 0.15665947f, 0.00035806f, 0.95196360f, 0.00622721f, 0.01610696f, 0.00084180f, 0.00139947f, 0.00127350f, 0.00577912f, 0.00980321f, 0.00624705f, 0.00167418f, 0.00125611f, 0.00109477f, 0.04061511f, 0.57403159f, 0.08173440f, 0.00423709f, 0.10187119f, 0.07103974f, 0.12244581f, 0.00073566f, 0.00624759f, 0.00559816f, 0.01215601f, 0.08299568f, 0.06209232f, 0.01742392f, 0.01341172f, 0.02181461f, 0.77752429f, 0.08474547f, 0.00957346f, 0.29235491f, 0.00243696f, 0.06653537f, 0.03792902f, 0.43910959f, 0.00344940f, 0.02626713f, 0.03759870f, 0.00143713f, 0.00011047f, 0.00018820f, 0.00047970f, 0.02127167f, 0.00308758f, 0.00093357f, 0.17067374f, 0.00545499f, 0.79636300f, 0.95257199f, 0.00002157f, 0.00647615f, 0.01024892f, 0.00005942f, 0.01910058f, 0.00044579f, 0.00008416f, 0.01097712f, 0.00001441f, 0.16705236f, 0.01782482f, 0.17580827f, 0.06262068f, 0.03860324f, 0.01763505f, 0.32766294f, 0.00555595f, 0.17227779f, 0.01495883f, 0.00180449f, 0.00010494f, 0.00075124f, 0.00161161f, 0.08859238f, 0.00364861f, 0.00162414f, 0.06005199f, 0.00805061f, 0.83375996f, 0.97355360f, 0.00000305f, 0.00144336f, 0.00051544f, 0.00010043f, 0.00714774f, 0.00021183f, 0.00042562f, 0.01294680f, 0.00365222f, 0.00026871f, 0.95752406f, 0.00408361f, 0.02153200f, 0.00015639f, 0.00153930f, 0.00323335f, 0.00178700f, 0.00516464f, 0.00471107f, 0.07408376f, 0.00468759f, 0.02638813f, 0.33325842f, 0.01172767f, 0.36993489f, 0.01118315f, 0.01460529f, 0.14850292f, 0.00562817f, 0.00551083f, 0.00015134f, 0.01184739f, 0.00643833f, 0.11686873f, 0.00163741f, 0.00582776f, 0.11497385f, 0.02010887f, 0.71663547f, 0.00154932f, 0.00001290f, 0.00023825f, 0.01393047f, 0.00012438f, 0.00033184f, 0.00010033f, 0.98197538f, 0.00022847f, 0.00150876f, 0.00597587f, 0.00819661f, 0.03041674f, 0.43121871f, 0.00986523f, 0.13834484f, 0.29576671f, 0.01305170f, 0.03919542f, 0.02796829f, 0.00139392f, 0.00031466f, 0.00229704f, 0.00647669f, 0.86193180f, 0.01064646f, 0.00494287f, 0.00901443f, 0.00526376f, 0.09771839f, 0.00184158f, 0.00040986f, 0.00008309f, 0.01634205f, 0.01102151f, 0.01133229f, 0.00011603f, 0.30489817f, 0.00813993f, 0.64581543f, 0.00132390f, 0.00009014f, 0.00471620f, 0.00419161f, 0.01024686f, 0.02504917f, 0.94500881f, 0.00010234f, 0.00620976f, 0.00306121f, 0.00971363f, 0.05415262f, 0.05265132f, 0.01217585f, 0.16251956f, 0.00188165f, 0.61800343f, 0.04541704f, 0.01950107f, 0.02398386f, 0.05354780f, 0.00129718f, 0.00762409f, 0.06902183f, 0.01746517f, 0.71758413f, 0.04491642f, 0.00194128f, 0.07204670f, 0.01455537f, 0.00356139f, 0.00223315f, 0.01881612f, 0.01844147f, 0.65686893f, 0.01172961f, 0.01321550f, 0.06555344f, 0.00993031f, 0.19965005f, 0.99641657f, 0.00000005f, 0.00027076f, 0.00000523f, 0.00001288f, 0.00173779f, 0.00140848f, 0.00001787f, 0.00012701f, 0.00000342f, 0.00364264f, 0.00040242f, 0.00199880f, 0.01658181f, 0.00522031f, 0.00494563f, 0.00134627f, 0.87392259f, 0.00277323f, 0.08916643f, 0.00200165f, 0.00006030f, 0.00265544f, 0.00137030f, 0.85328883f, 0.00988892f, 0.00416652f, 0.00394441f, 0.00617034f, 0.11645336f, 0.97291315f, 0.00000182f, 0.00194084f, 0.01498440f, 0.00001028f, 0.00389095f, 0.00023297f, 0.00044887f, 0.00528154f, 0.00029516f, 0.00188889f, 0.79829764f, 0.01104437f, 0.04222726f, 0.00522182f, 0.04550264f, 0.03192228f, 0.01099020f, 0.04107348f, 0.01183154f, 0.00058263f, 0.00048307f, 0.00013920f, 0.96885711f, 0.00005209f, 0.01755359f, 0.00061751f, 0.00787173f, 0.00087605f, 0.00296709f, 0.00342248f, 0.68736714f, 0.01477064f, 0.11038199f, 0.00979373f, 0.03290173f, 0.02064420f, 0.03154078f, 0.03068676f, 0.05849051f, 0.00054699f, 0.00028973f, 0.00066918f, 0.79915440f, 0.00078404f, 0.18881910f, 0.00078736f, 0.00024780f, 0.00598373f, 0.00271761f, 0.37178108f, 0.00029151f, 0.11573081f, 0.00016159f, 0.08614764f, 0.05626433f, 0.33961067f, 0.00184490f, 0.01931754f, 0.00884999f, 0.00103338f, 0.00105793f, 0.01583840f, 0.01417849f, 0.00086645f, 0.00075313f, 0.00009471f, 0.92975640f, 0.00786521f, 0.02855594f, 0.00831110f, 0.00041050f, 0.95547730f, 0.01004958f, 0.00024040f, 0.00674337f, 0.01100292f, 0.00229303f, 0.00543977f, 0.00003204f, 0.00073861f, 0.00003656f, 0.00233217f, 0.00864751f, 0.00044351f, 0.00055325f, 0.00046273f, 0.97456056f, 0.00097461f, 0.01125053f, 0.00035382f, 0.94428235f, 0.00286066f, 0.01286138f, 0.00111129f, 0.00731637f, 0.00518610f, 0.00538214f, 0.01197775f, 0.00866815f, 0.06013579f, 0.03228600f, 0.20441757f, 0.54548728f, 0.00006484f, 0.02362618f, 0.05482962f, 0.00106437f, 0.07713205f, 0.00095635f, 0.00029120f, 0.94839782f, 0.00271641f, 0.02038633f, 0.00010249f, 0.00270848f, 0.00299053f, 0.00069419f, 0.01599395f, 0.00571855f, 0.00580072f, 0.81594771f, 0.03097420f, 0.03646614f, 0.00565077f, 0.01715674f, 0.02362122f, 0.01730293f, 0.02312471f, 0.02395495f, 0.00083797f, 0.00032276f, 0.00475549f, 0.00577861f, 0.00193654f, 0.00201117f, 0.00095864f, 0.89032167f, 0.00238766f, 0.09068950f, 0.00007685f, 0.00309113f, 0.00165920f, 0.00566203f, 0.79406202f, 0.00106585f, 0.00073159f, 0.02779965f, 0.01331810f, 0.15253356f, 0.01362522f, 0.17258310f, 0.57671696f, 0.04606603f, 0.02204953f, 0.00909986f, 0.04971812f, 0.00135137f, 0.09417208f, 0.01461779f, 0.00351132f, 0.01659229f, 0.02209206f, 0.77456558f, 0.00303461f, 0.07932901f, 0.06269170f, 0.01151956f, 0.01363456f, 0.01302921f, 0.04056359f, 0.00052574f, 0.00214679f, 0.41835260f, 0.00373941f, 0.47472891f, 0.00819933f, 0.00047488f, 0.04602791f, 0.00524084f, 0.00085833f, 0.19585223f, 0.03986045f, 0.44138056f, 0.01866945f, 0.11297230f, 0.03688592f, 0.03147812f, 0.04306961f, 0.07897298f, 0.00580970f, 0.00654101f, 0.80165571f, 0.01388136f, 0.04366852f, 0.00407737f, 0.07712067f, 0.01289223f, 0.01437380f, 0.01997955f, 0.00013239f, 0.00000585f, 0.00003676f, 0.00288744f, 0.76327205f, 0.00911173f, 0.00025323f, 0.00345270f, 0.00977252f, 0.21107534f, 0.00238540f, 0.00011487f, 0.01707160f, 0.00274678f, 0.85196322f, 0.00066304f, 0.01279381f, 0.02112481f, 0.00446795f, 0.08666852f, 0.01046857f, 0.00011744f, 0.00377885f, 0.00806424f, 0.00110093f, 0.01087467f, 0.96216726f, 0.00024677f, 0.00213707f, 0.00104427f, 0.00835356f, 0.00037980f, 0.00540865f, 0.91882282f, 0.00084274f, 0.03935680f, 0.00700863f, 0.00609934f, 0.00307425f, 0.01065346f, 0.09310398f, 0.00066428f, 0.00076882f, 0.02210450f, 0.04447530f, 0.77650899f, 0.00945148f, 0.00689890f, 0.00886871f, 0.03715509f, 0.07214937f, 0.00624633f, 0.01399398f, 0.29444799f, 0.03825752f, 0.36904955f, 0.02109544f, 0.01373637f, 0.14653027f, 0.02449317f, 0.01878268f, 0.01089148f, 0.36442387f, 0.01426089f, 0.02649262f, 0.00308395f, 0.51123023f, 0.00987128f, 0.02856500f, 0.01239803f, 0.65732223f, 0.00001665f, 0.00257388f, 0.02261361f, 0.00056261f, 0.08028404f, 0.00753943f, 0.00092872f, 0.22300763f, 0.00515121f, 0.00238470f, 0.00001802f, 0.00303019f, 0.00282769f, 0.93392336f, 0.00829813f, 0.00937593f, 0.00232166f, 0.00606702f, 0.03175319f, 0.00192149f, 0.89188498f, 0.01474108f, 0.03585867f, 0.00123343f, 0.00441551f, 0.00399710f, 0.00857630f, 0.01781271f, 0.01955875f, 0.00221238f, 0.00005268f, 0.00038176f, 0.00141851f, 0.07513693f, 0.00153898f, 0.00254140f, 0.04116146f, 0.00216117f, 0.87339473f, 0.17824675f, 0.04543359f, 0.01501061f, 0.03382575f, 0.09682461f, 0.29989448f, 0.02655865f, 0.16809541f, 0.09566309f, 0.04044705f, 0.00052125f, 0.00006512f, 0.00041621f, 0.03254773f, 0.00120942f, 0.00177929f, 0.00091721f, 0.95285058f, 0.00068729f, 0.00900588f, 0.04185560f, 0.00125587f, 0.33473280f, 0.00119652f, 0.00552071f, 0.03358750f, 0.04974457f, 0.00243473f, 0.41644078f, 0.11323092f, 0.00945223f, 0.00509389f, 0.04602458f, 0.02943204f, 0.23871920f, 0.06141117f, 0.05274383f, 0.03511769f, 0.09954999f, 0.42245534f, 0.00686926f, 0.01075546f, 0.49830484f, 0.37111449f, 0.00928881f, 0.00910977f, 0.00822666f, 0.00448587f, 0.04094843f, 0.04089646f, 0.00190534f, 0.00074783f, 0.02465805f, 0.02045769f, 0.02690129f, 0.00249506f, 0.00202899f, 0.84847659f, 0.01121813f, 0.06111111f, 0.00527403f, 0.00617689f, 0.00719898f, 0.17549324f, 0.25461593f, 0.15036304f, 0.04163047f, 0.01647436f, 0.08906800f, 0.25370511f, 0.10200825f, 0.03916828f, 0.22575049f, 0.08762794f, 0.06703069f, 0.01087492f, 0.27197123f, 0.15926389f, 0.02289790f, 0.01340644f, 0.00233572f, 0.00071111f, 0.01389953f, 0.00187034f, 0.89338356f, 0.00067592f, 0.00535080f, 0.02598928f, 0.01003115f, 0.04575264f, 0.00010197f, 0.00006095f, 0.00021980f, 0.99164659f, 0.00011408f, 0.00474983f, 0.00004892f, 0.00012496f, 0.00257160f, 0.00036128f, 0.91125363f, 0.00012225f, 0.02511939f, 0.00156989f, 0.00002669f, 0.03335980f, 0.01791442f, 0.00531134f, 0.00345027f, 0.00187230f, 0.00210833f, 0.00001888f, 0.00016036f, 0.00394190f, 0.00016232f, 0.00026980f, 0.00012382f, 0.99098623f, 0.00036967f, 0.00185874f, 0.99578768f, 0.00000018f, 0.00162244f, 0.00012927f, 0.00000136f, 0.00158810f, 0.00016544f, 0.00000476f, 0.00069853f, 0.00000226f, 0.19834445f, 0.00044551f, 0.40857196f, 0.34896207f, 0.00023418f, 0.00828141f, 0.02426279f, 0.00148875f, 0.00938030f, 0.00002860f, 0.00201644f, 0.06109568f, 0.01542680f, 0.05984236f, 0.00112191f, 0.00419699f, 0.00110061f, 0.28937989f, 0.13231210f, 0.43350723f, 0.00055382f, 0.92216444f, 0.00396460f, 0.01456171f, 0.00061405f, 0.00972675f, 0.00677260f, 0.00454273f, 0.02471014f, 0.01238921f, 0.00027888f, 0.02572848f, 0.00290584f, 0.00748292f, 0.08441166f, 0.00232722f, 0.00188305f, 0.81133318f, 0.01191756f, 0.05173124f, 0.00315098f, 0.00499059f, 0.00158580f, 0.92859417f, 0.00035086f, 0.04807130f, 0.00101955f, 0.00034313f, 0.01119398f, 0.00069962f, 0.00112821f, 0.00214349f, 0.03968662f, 0.00325992f, 0.00253143f, 0.00199443f, 0.00964058f, 0.90529889f, 0.00384289f, 0.03047365f, 0.00174196f, 0.06674320f, 0.00283191f, 0.09274873f, 0.01944309f, 0.03424436f, 0.00694406f, 0.07912937f, 0.15087396f, 0.54529935f, 0.00007096f, 0.00001000f, 0.00001498f, 0.00007066f, 0.00002792f, 0.00005677f, 0.00000490f, 0.99606401f, 0.00030978f, 0.00337013f, 0.00286575f, 0.00011636f, 0.00064778f, 0.00992065f, 0.04501861f, 0.03149971f, 0.00287679f, 0.37334359f, 0.00214695f, 0.53156382f, 0.00600238f, 0.00003215f, 0.02112119f, 0.00084685f, 0.00497269f, 0.00753993f, 0.95174772f, 0.00150877f, 0.00212018f, 0.00410815f, 0.00006566f, 0.00001179f, 0.99827027f, 0.00028396f, 0.00004237f, 0.00000550f, 0.00091406f, 0.00003423f, 0.00036640f, 0.00000567f, 0.00079063f, 0.00006855f, 0.00051338f, 0.00590454f, 0.00732460f, 0.00195139f, 0.00034534f, 0.90222436f, 0.00163695f, 0.07924022f, 0.00362202f, 0.01493629f, 0.01135249f, 0.00781013f, 0.05138498f, 0.22704794f, 0.00442778f, 0.00350683f, 0.59828150f, 0.07762999f, 0.00016529f, 0.00001219f, 0.00006521f, 0.00446292f, 0.94456083f, 0.00407963f, 0.00102245f, 0.00057420f, 0.00344479f, 0.04161252f, 0.00000981f, 0.00030270f, 0.00017082f, 0.00029943f, 0.00010159f, 0.00003605f, 0.00001875f, 0.99310946f, 0.00063157f, 0.00531995f, 0.01100852f, 0.00021492f, 0.00049603f, 0.59714299f, 0.00454595f, 0.33691072f, 0.03074775f, 0.00427598f, 0.00512297f, 0.00953417f, 0.00064403f, 0.00001687f, 0.00822414f, 0.00012918f, 0.02522905f, 0.00046274f, 0.95950085f, 0.00174588f, 0.00070707f, 0.00334025f, 0.00014754f, 0.96842438f, 0.00752080f, 0.00713038f, 0.00074491f, 0.00107368f, 0.00245372f, 0.00181830f, 0.00883226f, 0.00185409f, 0.00210863f, 0.00017522f, 0.00039881f, 0.98836052f, 0.00003650f, 0.00535216f, 0.00001887f, 0.00069545f, 0.00265663f, 0.00019714f, 0.00028919f, 0.00026057f, 0.00356666f, 0.00034738f, 0.00413719f, 0.00133701f, 0.98608136f, 0.00009625f, 0.00153734f, 0.00234698f, 0.01427079f, 0.04020482f, 0.04733688f, 0.03817881f, 0.16299380f, 0.04943828f, 0.03522370f, 0.05902825f, 0.23904003f, 0.31428465f, 0.00029359f, 0.00005619f, 0.00007707f, 0.98437482f, 0.00000957f, 0.00828004f, 0.00002787f, 0.00510217f, 0.00087425f, 0.00090444f, 0.00011413f, 0.83918202f, 0.01017746f, 0.03100164f, 0.00308035f, 0.01615586f, 0.02608237f, 0.00337026f, 0.05493741f, 0.01589854f, 0.00053240f, 0.00144792f, 0.00108170f, 0.00027300f, 0.86477506f, 0.00072790f, 0.01062538f, 0.00428096f, 0.00233054f, 0.11392505f, 0.00411633f, 0.33660546f, 0.01735369f, 0.18114267f, 0.03090077f, 0.11699959f, 0.03416851f, 0.06780743f, 0.07481573f, 0.13608985f, 0.00073468f, 0.20941530f, 0.01012138f, 0.17237675f, 0.01661461f, 0.02184150f, 0.03694551f, 0.30870155f, 0.04255475f, 0.18069389f, 0.06343270f, 0.00037455f, 0.06623310f, 0.00041474f, 0.00209181f, 0.04566626f, 0.81232506f, 0.00054500f, 0.00807252f, 0.00084416f, 0.00008067f, 0.00003926f, 0.00225794f, 0.00115743f, 0.01925980f, 0.00010427f, 0.00062067f, 0.02234522f, 0.00210706f, 0.95202768f}); - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/mnist_00.fb"); - //graph->printOut(); - - auto result = GraphExecutioner::execute(graph); - ASSERT_EQ(Status::OK(), result); - - ASSERT_TRUE(graph->getVariableSpace()->hasVariable(6)); - - auto z = graph->getVariableSpace()->getVariable(6,0)->getNDArray(); - - ASSERT_EQ(e, *z); - - delete graph; + auto e = NDArrayFactory::create( + 'c', {100, 10}, + {0.00066107f, 0.00002358f, 0.00031518f, 0.00238039f, 0.00027216f, + 0.00030300f, 0.00004659f, 0.98962247f, 0.00050380f, 0.00587174f, + 0.05895791f, 0.00323104f, 0.52636790f, 0.12912551f, 0.00003951f, + 0.03615341f, 0.22013727f, 0.00007333f, 0.02566659f, 0.00024759f, + 0.00192367f, 0.90509874f, 0.01985082f, 0.02080356f, 0.00260053f, + 0.00497826f, 0.01107823f, 0.00872595f, 0.01559795f, 0.00934229f, + 0.98202229f, 0.00000150f, 0.00137381f, 0.00082931f, 0.00001806f, + 0.00384426f, 0.00758274f, 0.00305049f, 0.00052152f, 0.00075617f, + 0.01094264f, 0.00044708f, 0.03576852f, 0.00711267f, 0.65963465f, + 0.00734364f, 0.02747800f, 0.06494589f, 0.02966754f, 0.15665947f, + 0.00035806f, 0.95196360f, 0.00622721f, 0.01610696f, 0.00084180f, + 0.00139947f, 0.00127350f, 0.00577912f, 0.00980321f, 0.00624705f, + 0.00167418f, 0.00125611f, 0.00109477f, 0.04061511f, 0.57403159f, + 0.08173440f, 0.00423709f, 0.10187119f, 0.07103974f, 0.12244581f, + 0.00073566f, 0.00624759f, 0.00559816f, 0.01215601f, 0.08299568f, + 0.06209232f, 0.01742392f, 0.01341172f, 0.02181461f, 0.77752429f, + 0.08474547f, 0.00957346f, 0.29235491f, 0.00243696f, 0.06653537f, + 0.03792902f, 0.43910959f, 0.00344940f, 0.02626713f, 0.03759870f, + 0.00143713f, 0.00011047f, 0.00018820f, 0.00047970f, 0.02127167f, + 0.00308758f, 0.00093357f, 0.17067374f, 0.00545499f, 0.79636300f, + 0.95257199f, 0.00002157f, 0.00647615f, 0.01024892f, 0.00005942f, + 0.01910058f, 0.00044579f, 0.00008416f, 0.01097712f, 0.00001441f, + 0.16705236f, 0.01782482f, 0.17580827f, 0.06262068f, 0.03860324f, + 0.01763505f, 0.32766294f, 0.00555595f, 0.17227779f, 0.01495883f, + 0.00180449f, 0.00010494f, 0.00075124f, 0.00161161f, 0.08859238f, + 0.00364861f, 0.00162414f, 0.06005199f, 0.00805061f, 0.83375996f, + 0.97355360f, 0.00000305f, 0.00144336f, 0.00051544f, 0.00010043f, + 0.00714774f, 0.00021183f, 0.00042562f, 0.01294680f, 0.00365222f, + 0.00026871f, 0.95752406f, 0.00408361f, 0.02153200f, 0.00015639f, + 0.00153930f, 0.00323335f, 0.00178700f, 0.00516464f, 0.00471107f, + 0.07408376f, 0.00468759f, 0.02638813f, 0.33325842f, 0.01172767f, + 0.36993489f, 0.01118315f, 0.01460529f, 0.14850292f, 0.00562817f, + 0.00551083f, 0.00015134f, 0.01184739f, 0.00643833f, 0.11686873f, + 0.00163741f, 0.00582776f, 0.11497385f, 0.02010887f, 0.71663547f, + 0.00154932f, 0.00001290f, 0.00023825f, 0.01393047f, 0.00012438f, + 0.00033184f, 0.00010033f, 0.98197538f, 0.00022847f, 0.00150876f, + 0.00597587f, 0.00819661f, 0.03041674f, 0.43121871f, 0.00986523f, + 0.13834484f, 0.29576671f, 0.01305170f, 0.03919542f, 0.02796829f, + 0.00139392f, 0.00031466f, 0.00229704f, 0.00647669f, 0.86193180f, + 0.01064646f, 0.00494287f, 0.00901443f, 0.00526376f, 0.09771839f, + 0.00184158f, 0.00040986f, 0.00008309f, 0.01634205f, 0.01102151f, + 0.01133229f, 0.00011603f, 0.30489817f, 0.00813993f, 0.64581543f, + 0.00132390f, 0.00009014f, 0.00471620f, 0.00419161f, 0.01024686f, + 0.02504917f, 0.94500881f, 0.00010234f, 0.00620976f, 0.00306121f, + 0.00971363f, 0.05415262f, 0.05265132f, 0.01217585f, 0.16251956f, + 0.00188165f, 0.61800343f, 0.04541704f, 0.01950107f, 0.02398386f, + 0.05354780f, 0.00129718f, 0.00762409f, 0.06902183f, 0.01746517f, + 0.71758413f, 0.04491642f, 0.00194128f, 0.07204670f, 0.01455537f, + 0.00356139f, 0.00223315f, 0.01881612f, 0.01844147f, 0.65686893f, + 0.01172961f, 0.01321550f, 0.06555344f, 0.00993031f, 0.19965005f, + 0.99641657f, 0.00000005f, 0.00027076f, 0.00000523f, 0.00001288f, + 0.00173779f, 0.00140848f, 0.00001787f, 0.00012701f, 0.00000342f, + 0.00364264f, 0.00040242f, 0.00199880f, 0.01658181f, 0.00522031f, + 0.00494563f, 0.00134627f, 0.87392259f, 0.00277323f, 0.08916643f, + 0.00200165f, 0.00006030f, 0.00265544f, 0.00137030f, 0.85328883f, + 0.00988892f, 0.00416652f, 0.00394441f, 0.00617034f, 0.11645336f, + 0.97291315f, 0.00000182f, 0.00194084f, 0.01498440f, 0.00001028f, + 0.00389095f, 0.00023297f, 0.00044887f, 0.00528154f, 0.00029516f, + 0.00188889f, 0.79829764f, 0.01104437f, 0.04222726f, 0.00522182f, + 0.04550264f, 0.03192228f, 0.01099020f, 0.04107348f, 0.01183154f, + 0.00058263f, 0.00048307f, 0.00013920f, 0.96885711f, 0.00005209f, + 0.01755359f, 0.00061751f, 0.00787173f, 0.00087605f, 0.00296709f, + 0.00342248f, 0.68736714f, 0.01477064f, 0.11038199f, 0.00979373f, + 0.03290173f, 0.02064420f, 0.03154078f, 0.03068676f, 0.05849051f, + 0.00054699f, 0.00028973f, 0.00066918f, 0.79915440f, 0.00078404f, + 0.18881910f, 0.00078736f, 0.00024780f, 0.00598373f, 0.00271761f, + 0.37178108f, 0.00029151f, 0.11573081f, 0.00016159f, 0.08614764f, + 0.05626433f, 0.33961067f, 0.00184490f, 0.01931754f, 0.00884999f, + 0.00103338f, 0.00105793f, 0.01583840f, 0.01417849f, 0.00086645f, + 0.00075313f, 0.00009471f, 0.92975640f, 0.00786521f, 0.02855594f, + 0.00831110f, 0.00041050f, 0.95547730f, 0.01004958f, 0.00024040f, + 0.00674337f, 0.01100292f, 0.00229303f, 0.00543977f, 0.00003204f, + 0.00073861f, 0.00003656f, 0.00233217f, 0.00864751f, 0.00044351f, + 0.00055325f, 0.00046273f, 0.97456056f, 0.00097461f, 0.01125053f, + 0.00035382f, 0.94428235f, 0.00286066f, 0.01286138f, 0.00111129f, + 0.00731637f, 0.00518610f, 0.00538214f, 0.01197775f, 0.00866815f, + 0.06013579f, 0.03228600f, 0.20441757f, 0.54548728f, 0.00006484f, + 0.02362618f, 0.05482962f, 0.00106437f, 0.07713205f, 0.00095635f, + 0.00029120f, 0.94839782f, 0.00271641f, 0.02038633f, 0.00010249f, + 0.00270848f, 0.00299053f, 0.00069419f, 0.01599395f, 0.00571855f, + 0.00580072f, 0.81594771f, 0.03097420f, 0.03646614f, 0.00565077f, + 0.01715674f, 0.02362122f, 0.01730293f, 0.02312471f, 0.02395495f, + 0.00083797f, 0.00032276f, 0.00475549f, 0.00577861f, 0.00193654f, + 0.00201117f, 0.00095864f, 0.89032167f, 0.00238766f, 0.09068950f, + 0.00007685f, 0.00309113f, 0.00165920f, 0.00566203f, 0.79406202f, + 0.00106585f, 0.00073159f, 0.02779965f, 0.01331810f, 0.15253356f, + 0.01362522f, 0.17258310f, 0.57671696f, 0.04606603f, 0.02204953f, + 0.00909986f, 0.04971812f, 0.00135137f, 0.09417208f, 0.01461779f, + 0.00351132f, 0.01659229f, 0.02209206f, 0.77456558f, 0.00303461f, + 0.07932901f, 0.06269170f, 0.01151956f, 0.01363456f, 0.01302921f, + 0.04056359f, 0.00052574f, 0.00214679f, 0.41835260f, 0.00373941f, + 0.47472891f, 0.00819933f, 0.00047488f, 0.04602791f, 0.00524084f, + 0.00085833f, 0.19585223f, 0.03986045f, 0.44138056f, 0.01866945f, + 0.11297230f, 0.03688592f, 0.03147812f, 0.04306961f, 0.07897298f, + 0.00580970f, 0.00654101f, 0.80165571f, 0.01388136f, 0.04366852f, + 0.00407737f, 0.07712067f, 0.01289223f, 0.01437380f, 0.01997955f, + 0.00013239f, 0.00000585f, 0.00003676f, 0.00288744f, 0.76327205f, + 0.00911173f, 0.00025323f, 0.00345270f, 0.00977252f, 0.21107534f, + 0.00238540f, 0.00011487f, 0.01707160f, 0.00274678f, 0.85196322f, + 0.00066304f, 0.01279381f, 0.02112481f, 0.00446795f, 0.08666852f, + 0.01046857f, 0.00011744f, 0.00377885f, 0.00806424f, 0.00110093f, + 0.01087467f, 0.96216726f, 0.00024677f, 0.00213707f, 0.00104427f, + 0.00835356f, 0.00037980f, 0.00540865f, 0.91882282f, 0.00084274f, + 0.03935680f, 0.00700863f, 0.00609934f, 0.00307425f, 0.01065346f, + 0.09310398f, 0.00066428f, 0.00076882f, 0.02210450f, 0.04447530f, + 0.77650899f, 0.00945148f, 0.00689890f, 0.00886871f, 0.03715509f, + 0.07214937f, 0.00624633f, 0.01399398f, 0.29444799f, 0.03825752f, + 0.36904955f, 0.02109544f, 0.01373637f, 0.14653027f, 0.02449317f, + 0.01878268f, 0.01089148f, 0.36442387f, 0.01426089f, 0.02649262f, + 0.00308395f, 0.51123023f, 0.00987128f, 0.02856500f, 0.01239803f, + 0.65732223f, 0.00001665f, 0.00257388f, 0.02261361f, 0.00056261f, + 0.08028404f, 0.00753943f, 0.00092872f, 0.22300763f, 0.00515121f, + 0.00238470f, 0.00001802f, 0.00303019f, 0.00282769f, 0.93392336f, + 0.00829813f, 0.00937593f, 0.00232166f, 0.00606702f, 0.03175319f, + 0.00192149f, 0.89188498f, 0.01474108f, 0.03585867f, 0.00123343f, + 0.00441551f, 0.00399710f, 0.00857630f, 0.01781271f, 0.01955875f, + 0.00221238f, 0.00005268f, 0.00038176f, 0.00141851f, 0.07513693f, + 0.00153898f, 0.00254140f, 0.04116146f, 0.00216117f, 0.87339473f, + 0.17824675f, 0.04543359f, 0.01501061f, 0.03382575f, 0.09682461f, + 0.29989448f, 0.02655865f, 0.16809541f, 0.09566309f, 0.04044705f, + 0.00052125f, 0.00006512f, 0.00041621f, 0.03254773f, 0.00120942f, + 0.00177929f, 0.00091721f, 0.95285058f, 0.00068729f, 0.00900588f, + 0.04185560f, 0.00125587f, 0.33473280f, 0.00119652f, 0.00552071f, + 0.03358750f, 0.04974457f, 0.00243473f, 0.41644078f, 0.11323092f, + 0.00945223f, 0.00509389f, 0.04602458f, 0.02943204f, 0.23871920f, + 0.06141117f, 0.05274383f, 0.03511769f, 0.09954999f, 0.42245534f, + 0.00686926f, 0.01075546f, 0.49830484f, 0.37111449f, 0.00928881f, + 0.00910977f, 0.00822666f, 0.00448587f, 0.04094843f, 0.04089646f, + 0.00190534f, 0.00074783f, 0.02465805f, 0.02045769f, 0.02690129f, + 0.00249506f, 0.00202899f, 0.84847659f, 0.01121813f, 0.06111111f, + 0.00527403f, 0.00617689f, 0.00719898f, 0.17549324f, 0.25461593f, + 0.15036304f, 0.04163047f, 0.01647436f, 0.08906800f, 0.25370511f, + 0.10200825f, 0.03916828f, 0.22575049f, 0.08762794f, 0.06703069f, + 0.01087492f, 0.27197123f, 0.15926389f, 0.02289790f, 0.01340644f, + 0.00233572f, 0.00071111f, 0.01389953f, 0.00187034f, 0.89338356f, + 0.00067592f, 0.00535080f, 0.02598928f, 0.01003115f, 0.04575264f, + 0.00010197f, 0.00006095f, 0.00021980f, 0.99164659f, 0.00011408f, + 0.00474983f, 0.00004892f, 0.00012496f, 0.00257160f, 0.00036128f, + 0.91125363f, 0.00012225f, 0.02511939f, 0.00156989f, 0.00002669f, + 0.03335980f, 0.01791442f, 0.00531134f, 0.00345027f, 0.00187230f, + 0.00210833f, 0.00001888f, 0.00016036f, 0.00394190f, 0.00016232f, + 0.00026980f, 0.00012382f, 0.99098623f, 0.00036967f, 0.00185874f, + 0.99578768f, 0.00000018f, 0.00162244f, 0.00012927f, 0.00000136f, + 0.00158810f, 0.00016544f, 0.00000476f, 0.00069853f, 0.00000226f, + 0.19834445f, 0.00044551f, 0.40857196f, 0.34896207f, 0.00023418f, + 0.00828141f, 0.02426279f, 0.00148875f, 0.00938030f, 0.00002860f, + 0.00201644f, 0.06109568f, 0.01542680f, 0.05984236f, 0.00112191f, + 0.00419699f, 0.00110061f, 0.28937989f, 0.13231210f, 0.43350723f, + 0.00055382f, 0.92216444f, 0.00396460f, 0.01456171f, 0.00061405f, + 0.00972675f, 0.00677260f, 0.00454273f, 0.02471014f, 0.01238921f, + 0.00027888f, 0.02572848f, 0.00290584f, 0.00748292f, 0.08441166f, + 0.00232722f, 0.00188305f, 0.81133318f, 0.01191756f, 0.05173124f, + 0.00315098f, 0.00499059f, 0.00158580f, 0.92859417f, 0.00035086f, + 0.04807130f, 0.00101955f, 0.00034313f, 0.01119398f, 0.00069962f, + 0.00112821f, 0.00214349f, 0.03968662f, 0.00325992f, 0.00253143f, + 0.00199443f, 0.00964058f, 0.90529889f, 0.00384289f, 0.03047365f, + 0.00174196f, 0.06674320f, 0.00283191f, 0.09274873f, 0.01944309f, + 0.03424436f, 0.00694406f, 0.07912937f, 0.15087396f, 0.54529935f, + 0.00007096f, 0.00001000f, 0.00001498f, 0.00007066f, 0.00002792f, + 0.00005677f, 0.00000490f, 0.99606401f, 0.00030978f, 0.00337013f, + 0.00286575f, 0.00011636f, 0.00064778f, 0.00992065f, 0.04501861f, + 0.03149971f, 0.00287679f, 0.37334359f, 0.00214695f, 0.53156382f, + 0.00600238f, 0.00003215f, 0.02112119f, 0.00084685f, 0.00497269f, + 0.00753993f, 0.95174772f, 0.00150877f, 0.00212018f, 0.00410815f, + 0.00006566f, 0.00001179f, 0.99827027f, 0.00028396f, 0.00004237f, + 0.00000550f, 0.00091406f, 0.00003423f, 0.00036640f, 0.00000567f, + 0.00079063f, 0.00006855f, 0.00051338f, 0.00590454f, 0.00732460f, + 0.00195139f, 0.00034534f, 0.90222436f, 0.00163695f, 0.07924022f, + 0.00362202f, 0.01493629f, 0.01135249f, 0.00781013f, 0.05138498f, + 0.22704794f, 0.00442778f, 0.00350683f, 0.59828150f, 0.07762999f, + 0.00016529f, 0.00001219f, 0.00006521f, 0.00446292f, 0.94456083f, + 0.00407963f, 0.00102245f, 0.00057420f, 0.00344479f, 0.04161252f, + 0.00000981f, 0.00030270f, 0.00017082f, 0.00029943f, 0.00010159f, + 0.00003605f, 0.00001875f, 0.99310946f, 0.00063157f, 0.00531995f, + 0.01100852f, 0.00021492f, 0.00049603f, 0.59714299f, 0.00454595f, + 0.33691072f, 0.03074775f, 0.00427598f, 0.00512297f, 0.00953417f, + 0.00064403f, 0.00001687f, 0.00822414f, 0.00012918f, 0.02522905f, + 0.00046274f, 0.95950085f, 0.00174588f, 0.00070707f, 0.00334025f, + 0.00014754f, 0.96842438f, 0.00752080f, 0.00713038f, 0.00074491f, + 0.00107368f, 0.00245372f, 0.00181830f, 0.00883226f, 0.00185409f, + 0.00210863f, 0.00017522f, 0.00039881f, 0.98836052f, 0.00003650f, + 0.00535216f, 0.00001887f, 0.00069545f, 0.00265663f, 0.00019714f, + 0.00028919f, 0.00026057f, 0.00356666f, 0.00034738f, 0.00413719f, + 0.00133701f, 0.98608136f, 0.00009625f, 0.00153734f, 0.00234698f, + 0.01427079f, 0.04020482f, 0.04733688f, 0.03817881f, 0.16299380f, + 0.04943828f, 0.03522370f, 0.05902825f, 0.23904003f, 0.31428465f, + 0.00029359f, 0.00005619f, 0.00007707f, 0.98437482f, 0.00000957f, + 0.00828004f, 0.00002787f, 0.00510217f, 0.00087425f, 0.00090444f, + 0.00011413f, 0.83918202f, 0.01017746f, 0.03100164f, 0.00308035f, + 0.01615586f, 0.02608237f, 0.00337026f, 0.05493741f, 0.01589854f, + 0.00053240f, 0.00144792f, 0.00108170f, 0.00027300f, 0.86477506f, + 0.00072790f, 0.01062538f, 0.00428096f, 0.00233054f, 0.11392505f, + 0.00411633f, 0.33660546f, 0.01735369f, 0.18114267f, 0.03090077f, + 0.11699959f, 0.03416851f, 0.06780743f, 0.07481573f, 0.13608985f, + 0.00073468f, 0.20941530f, 0.01012138f, 0.17237675f, 0.01661461f, + 0.02184150f, 0.03694551f, 0.30870155f, 0.04255475f, 0.18069389f, + 0.06343270f, 0.00037455f, 0.06623310f, 0.00041474f, 0.00209181f, + 0.04566626f, 0.81232506f, 0.00054500f, 0.00807252f, 0.00084416f, + 0.00008067f, 0.00003926f, 0.00225794f, 0.00115743f, 0.01925980f, + 0.00010427f, 0.00062067f, 0.02234522f, 0.00210706f, 0.95202768f}); + auto graph = + GraphExecutioner::importFromFlatBuffers("./resources/mnist_00.fb"); + // graph->printOut(); + + auto result = GraphExecutioner::execute(graph); + ASSERT_EQ(Status::OK(), result); + + ASSERT_TRUE(graph->variableSpace()->hasVariable(6)); + + auto z = graph->variableSpace()->getVariable(6, 0)->getNDArray(); + + ASSERT_EQ(e, *z); + + delete graph; } - - TEST_F(FlatBuffersTest, Test_MNIST_1) { - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/mnist.fb"); - //graph->printOut(); + auto graph = GraphExecutioner::importFromFlatBuffers("./resources/mnist.fb"); + // graph->printOut(); - auto result = GraphExecutioner::execute(graph); - ASSERT_EQ(Status::OK(), result); + auto result = GraphExecutioner::execute(graph); + ASSERT_EQ(Status::OK(), result); - delete graph; + delete graph; } /* @@ -766,18 +1073,20 @@ TEST_F(FlatBuffersTest, Test_MNIST_1) { TEST_F(FlatBuffersTest, nhwc_conv_0) { sd::ops::rank op1; - auto exp('c', {4, 2}, {2.958640f, 0.602521f, 7.571267f, 1.496686f, -2.292647f, -1.791460f, 13.055838f, 4.278642f}); + auto exp('c', {4, 2}, {2.958640f, 0.602521f, 7.571267f, 1.496686f, +-2.292647f, -1.791460f, 13.055838f, 4.278642f}); - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/conv_0.fb"); + auto graph = +GraphExecutioner::importFromFlatBuffers("./resources/conv_0.fb"); graph->printOut(); auto result = GraphExecutioner::execute(graph); ASSERT_EQ(ND4J_STATUS_OK, result); - ASSERT_TRUE(graph->getVariableSpace()->hasVariable(11)); + ASSERT_TRUE(graph->variableSpace()->hasVariable(11)); - auto z = graph->getVariableSpace()->getVariable(11)->getNDArray(); + auto z = graph->variableSpace()->getVariable(11)->getNDArray(); //z->printShapeInfo("z buffr"); //z->printIndexedBuffer("z shape"); @@ -795,12 +1104,12 @@ TEST_F(FlatBuffersTest, nhwc_conv_0) { */ - /* TEST_F(FlatBuffersTest, ReadLoops_SimpleWhile_1) { // TF graph: // https://gist.github.com/raver119/2aa49daf7ec09ed4ddddbc6262f213a0 - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/simple_while.fb"); + auto graph = +GraphExecutioner::importFromFlatBuffers("./resources/simple_while.fb"); ASSERT_TRUE(graph != nullptr); diff --git a/libnd4j/tests_cpu/layers_tests/FlatUtilsTests.cpp b/libnd4j/tests_cpu/layers_tests/FlatUtilsTests.cpp index f31a1c7ec47b..400ca6cb84a5 100644 --- a/libnd4j/tests_cpu/layers_tests/FlatUtilsTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/FlatUtilsTests.cpp @@ -20,81 +20,70 @@ #include #include -#include "testlayers.h" -#include #include +#include + +#include "testlayers.h" using namespace sd; class FlatUtilsTests : public testing::Test { -public: - + public: }; TEST_F(FlatUtilsTests, flat_float_serde_1) { - auto array = NDArrayFactory::create('c', {4}, {1.f, 2.f, 3.f, 4.f}); - - flatbuffers::FlatBufferBuilder builder(1024); - auto flatArray = FlatUtils::toFlatArray(builder, array); - builder.Finish(flatArray); + auto array = NDArrayFactory::create('c', {4}, {1.f, 2.f, 3.f, 4.f}); + flatbuffers::FlatBufferBuilder builder(1024); + auto flatArray = FlatUtils::toFlatArray(builder, array); + builder.Finish(flatArray); - auto pfArray = GetFlatArray(builder.GetBufferPointer()); + auto pfArray = GetFlatArray(builder.GetBufferPointer()); - auto restored = FlatUtils::fromFlatArray(pfArray); + auto restored = FlatUtils::fromFlatArray(pfArray); - ASSERT_EQ(array, *restored); - - delete restored; + ASSERT_EQ(array, restored); } TEST_F(FlatUtilsTests, flat_int_serde_1) { - auto array = NDArrayFactory::create('c', {4}, {1, 2, 3, 4}); - - flatbuffers::FlatBufferBuilder builder(1024); - auto flatArray = FlatUtils::toFlatArray(builder, array); - builder.Finish(flatArray); + auto array = NDArrayFactory::create('c', {4}, {1, 2, 3, 4}); + flatbuffers::FlatBufferBuilder builder(1024); + auto flatArray = FlatUtils::toFlatArray(builder, array); + builder.Finish(flatArray); - auto pfArray = GetFlatArray(builder.GetBufferPointer()); + auto pfArray = GetFlatArray(builder.GetBufferPointer()); - auto restored = FlatUtils::fromFlatArray(pfArray); + auto restored = FlatUtils::fromFlatArray(pfArray); - ASSERT_EQ(array, *restored); - - delete restored; + ASSERT_EQ(array, restored); } TEST_F(FlatUtilsTests, flat_bool_serde_1) { - auto array = NDArrayFactory::create('c', {4}, {true, false, true, false}); - - flatbuffers::FlatBufferBuilder builder(1024); - auto flatArray = FlatUtils::toFlatArray(builder, array); - builder.Finish(flatArray); + auto array = + NDArrayFactory::create('c', {4}, {true, false, true, false}); + flatbuffers::FlatBufferBuilder builder(1024); + auto flatArray = FlatUtils::toFlatArray(builder, array); + builder.Finish(flatArray); - auto pfArray = GetFlatArray(builder.GetBufferPointer()); + auto pfArray = GetFlatArray(builder.GetBufferPointer()); - auto restored = FlatUtils::fromFlatArray(pfArray); + auto restored = FlatUtils::fromFlatArray(pfArray); - ASSERT_EQ(array, *restored); - - delete restored; + ASSERT_EQ(array, restored); } TEST_F(FlatUtilsTests, flat_string_serde_1) { - auto array = NDArrayFactory::string( {3}, {"alpha", "beta", "gamma"}); - - flatbuffers::FlatBufferBuilder builder(1024); - auto flatArray = FlatUtils::toFlatArray(builder, array); - builder.Finish(flatArray); - + auto array = NDArrayFactory::string({3}, {"alpha", "beta", "gamma"}); - auto pfArray = GetFlatArray(builder.GetBufferPointer()); + flatbuffers::FlatBufferBuilder builder(1024); + auto flatArray = FlatUtils::toFlatArray(builder, array); + builder.Finish(flatArray); - auto restored = FlatUtils::fromFlatArray(pfArray); + auto pfArray = GetFlatArray(builder.GetBufferPointer()); - ASSERT_EQ(array, *restored); + auto restored = FlatUtils::fromFlatArray(pfArray); - delete restored; + ASSERT_EQ(array, restored); } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp b/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp new file mode 100644 index 000000000000..94b4d0dbc570 --- /dev/null +++ b/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp @@ -0,0 +1,1133 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include +#include +#include +#include + +#include + +#include "testlayers.h" + +using namespace sd; +using namespace sd::graph; + +class GraphAnalysisTests : public testing::Test { + public: + GraphAnalysisTests() { + /// + } +}; + +TEST_F(GraphAnalysisTests, optimizedGraph_1) { + + // A*B + C + Graph graph; + + graph.addVariable("A", NDArray('c', {3}, sd::DataType::INT32)); + graph.addVariable("B", NDArray('c', {3}, sd::DataType::INT32)); + graph.addVariable("C", NDArray('c', {3}, sd::DataType::INT32)); + + Node a(sd::ops::multiply(), "multiply"); + Node b(sd::ops::add(), "add"); + + graph.addNode(a, {"A", "B"}); + graph.addNode(b, {"multiply", "C"}); + + // we just check that nodes were really added + ASSERT_EQ(2, graph.size()); + + const auto& optimized = graph.optimizedGraph(); + + // we expect that OptimizedGraph has exactly 1 layer + ASSERT_EQ(1, optimized.numOfLayers()); + + auto layer = optimized.layer(0); + + // we expect layer has exactly 1 OpSequence + ASSERT_EQ(1, layer.width()); + auto sequence = layer[0]; + + // we expect that OpSequence has exactly 2 ops + ASSERT_EQ(2, sequence.length()); + + ASSERT_EQ(4, sequence.at(0).protoContext().nodeId()); + ASSERT_EQ(5, sequence.at(1).protoContext().nodeId()); +} + +TEST_F(GraphAnalysisTests, optimizedGraph_2) { + + // 0 = A*B, 1_0 = 0+C, 1_1 = 0-D + Graph graph; + + graph.addVariable("A", NDArray('c', {3}, {1, 1, 1}, sd::DataType::INT32)); + graph.addVariable("B", NDArray('c', {3}, {2, 2, 2}, sd::DataType::INT32)); + graph.addVariable("C", NDArray('c', {3}, {3, 3, 3}, sd::DataType::INT32)); + graph.addVariable("D", NDArray('c', {3}, {4, 4, 4}, sd::DataType::INT32)); + + + Node a(sd::ops::multiply(), "multiply"); + Node b(sd::ops::add(), "add"); + Node c(sd::ops::subtract(), "subtract"); + + graph.addNode(a, {"A", "B"}); + graph.addNode(b, {"multiply", "C"}); + graph.addNode(c, {"multiply", "D"}); + + // we just check that nodes were really added + ASSERT_EQ(3, graph.size()); + + const auto& optimized = graph.optimizedGraph(); + + // graph size must stay the same + ASSERT_EQ(3, graph.size()); + + // we expect that OptimizedGraph has exactly 2 layers + ASSERT_EQ(2, optimized.numOfLayers()); + + // checking first layer first + auto layer0 = optimized.layer(0); + + // we expect layer has exactly 1 OpSequence + ASSERT_EQ(1, layer0.width()); + + // we expect that OpSequence has exactly 1 node + ASSERT_EQ(1, layer0[0].length()); + + ASSERT_EQ(5, layer0[0].at(0).protoContext().nodeId()); + + // checking second layer now + auto layer1 = optimized.layer(1); + + // we expect layer has exactly 2 OpSequences + ASSERT_EQ(2, layer1.width()); + + ASSERT_EQ(1, layer1[0].length()); + ASSERT_EQ(6, layer1[0].at(0).protoContext().nodeId()); + + ASSERT_EQ(1, layer1[1].length()); + ASSERT_EQ(7, layer1[1].at(0).protoContext().nodeId()); +} + +TEST_F(GraphAnalysisTests, optimizedGraph_3) { + + // 0 = A*B+C, 1_0 = 0-D, 1_1 = 0+D, 2 = 1_0*1_1 + Graph graph; + + graph.addVariable("A", NDArray('c', {3}, {1, 1, 1}, sd::DataType::INT32)); + graph.addVariable("B", NDArray('c', {3}, {1, 1, 1}, sd::DataType::INT32)); + graph.addVariable("C", NDArray('c', {3}, {1, 1, 1}, sd::DataType::INT32)); + graph.addVariable("D", NDArray('c', {3}, {1, 1, 1}, sd::DataType::INT32)); + + Node a(sd::ops::multiply(), "a"); + Node b(sd::ops::add(), "b"); + Node c(sd::ops::subtract(), "c"); + Node d(sd::ops::add(), "d"); + Node e(sd::ops::multiply(), "e"); + + graph.addNode(a, {"A", "B"}); + graph.addNode(b, {"a", "C"}); + + graph.addNode(c, {"b", "D"}); + graph.addNode(d, {"b", "D"}); + + graph.addNode(e, {"c", "d"}); + + // we just check that nodes were really added + ASSERT_EQ(5, graph.size()); + + const auto& optimized = graph.optimizedGraph(); + + // we expect that OptimizedGraph has exactly 3 layer + ASSERT_EQ(3, optimized.numOfLayers()); + + // checking first layer first + auto layer0 = optimized.layer(0); + + // we expect layer has exactly 1 OpSequence + ASSERT_EQ(1, layer0.width()); + // auto sequence = layer0[0]; + + // we expect that OpSequence has exactly 2 ops + ASSERT_EQ(2, layer0[0].length()); + + ASSERT_EQ(5, layer0[0].at(0).protoContext().nodeId()); + ASSERT_EQ(6, layer0[0].at(1).protoContext().nodeId()); + + // checking second layer now + const auto& layer1 = optimized.layer(1); + + // we expect layer has exactly 2 OpSequences + ASSERT_EQ(2, layer1.width()); + + // sequence = layer1[0]; + + ASSERT_EQ(1, layer1[0].length()); + ASSERT_EQ(7, layer1[0].at(0).protoContext().nodeId()); + + // sequence = layer1[1]; + + ASSERT_EQ(1, layer1[1].length()); + ASSERT_EQ(8, layer1[1].at(0).protoContext().nodeId()); + + // checking last layer + auto layer2 = optimized.layer(2); + + // we expect layer has exactly 1 OpSequence + ASSERT_EQ(1, layer2.width()); + // sequence = layer2[0]; + + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, layer2[0].length()); + ASSERT_EQ(9, layer2[0].at(0).protoContext().nodeId()); +} + +TEST_F(GraphAnalysisTests, optimizedGraph_4) { + Graph graph; + + graph.addVariable("A", NDArray('c', {3}, {1, 1, 1}, sd::DataType::INT32)); + graph.addVariable("B", NDArray('c', {3}, {1, 1, 1}, sd::DataType::INT32)); + graph.addVariable("C", NDArray('c', {3}, {1, 1, 1}, sd::DataType::INT32)); + graph.addVariable("D", NDArray('c', {3}, {1, 1, 1}, sd::DataType::INT32)); + graph.addVariable("E", NDArray('c', {3}, {1, 1, 1}, sd::DataType::INT32)); + graph.addVariable("F", NDArray('c', {3}, {1, 1, 1}, sd::DataType::INT32)); + + Node a1(sd::ops::multiply(), "a1"); + Node a2(sd::ops::add(), "a2"); + + Node b1(sd::ops::subtract(), "b1"); + Node b2(sd::ops::add(), "b2"); + Node b3(sd::ops::multiply(), "b3"); + + Node d1(sd::ops::multiply(), "d1"); + Node d2(sd::ops::add(), "d2"); + + Node e(sd::ops::subtract(), "e"); + + graph.addNode(a1, {"A", "B"}); + graph.addNode(a2, {"C", "D"}); + + graph.addNode(b1, {"a1", "E"}); + graph.addNode(b2, {"a1", "a2"}); + graph.addNode(b3, {"a2", "F"}); + + graph.addNode(d1, {"b1", "b2"}); + graph.addNode(d2, {"b3", "b2"}); + + graph.addNode(e, {"d1", "d2"}); + + // we just check that nodes were really added + ASSERT_EQ(8, graph.size()); + + const auto& optimized = graph.optimizedGraph(); + + // we expect that OptimizedGraph has exactly 4 layer + ASSERT_EQ(4, optimized.numOfLayers()); + + // checking first layer first + auto layer0 = optimized.layer(0); + + // we expect layer has exactly 2 OpSequence + ASSERT_EQ(2, layer0.width()); + + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, layer0[0].length()); + ASSERT_EQ(7, layer0[0].at(0).protoContext().nodeId()); + + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, layer0[1].length()); + ASSERT_EQ(8, layer0[1].at(0).protoContext().nodeId()); + + // checking second layer now + auto layer1 = optimized.layer(1); + + // we expect layer has exactly 3 OpSequences + ASSERT_EQ(3, layer1.width()); + + ASSERT_EQ(1, layer1[0].length()); + ASSERT_EQ(9, layer1[0].at(0).protoContext().nodeId()); + + ASSERT_EQ(1, layer1[1].length()); + ASSERT_EQ(10, layer1[1].at(0).protoContext().nodeId()); + + ASSERT_EQ(1, layer1[2].length()); + ASSERT_EQ(11, layer1[2].at(0).protoContext().nodeId()); + + auto layer2 = optimized.layer(2); + + // we expect layer has exactly 2 OpSequence + ASSERT_EQ(2, layer2.width()); + + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, layer2[0].length()); + ASSERT_EQ(12, layer2[0].at(0).protoContext().nodeId()); + + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, layer2[1].length()); + ASSERT_EQ(13, layer2[1].at(0).protoContext().nodeId()); + + // checking last layer + auto layer3 = optimized.layer(3); + + // we expect layer has exactly 1 OpSequence + ASSERT_EQ(1, layer3.width()); + + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, layer3[0].length()); + ASSERT_EQ(14, layer3[0].at(0).protoContext().nodeId()); +} + +TEST_F(GraphAnalysisTests, optimizedGraph_5) { + + Graph graph; + + graph.addVariable("A", NDArray('c', {3}, {1, 1, 1}, sd::DataType::INT32)); + graph.addVariable("B", NDArray('c', {3}, {1, 1, 1}, sd::DataType::INT32)); + graph.addVariable("C", NDArray('c', {3}, {1, 1, 1}, sd::DataType::INT32)); + graph.addVariable("D", NDArray('c', {3}, {1, 1, 1}, sd::DataType::INT32)); + + Node a(sd::ops::multiply(), "a"); + Node b(sd::ops::add(), "b"); + Node c(sd::ops::subtract(), "c"); + Node d(sd::ops::add(), "d"); + Node e(sd::ops::multiply(), "e"); + Node f(sd::ops::multiply(), "f"); + + Node g(sd::ops::multiply(), "g"); + Node h(sd::ops::multiply(), "h"); + + graph.addNode(a, {"A", "B"}); + graph.addNode(b, {"C", "D"}); + + graph.addNode(c, {"a", "b"}); + graph.addNode(d, {"a", "b"}); + + graph.addNode(e, {"c", "d"}); + graph.addNode(f, {"c", "d"}); + + graph.addNode(g, {"c", "e"}); + graph.addNode(h, {"d", "f"}); + + // we just check that nodes were really added + ASSERT_EQ(8, graph.size()); + + const auto& optimized = graph.optimizedGraph(); + + // we expect that OptimizedGraph has exactly 3 layer + ASSERT_EQ(4, optimized.numOfLayers()); + + // checking first layer first + auto layer0 = optimized.layer(0); + + // we expect layer has exactly 2 OpSequence + ASSERT_EQ(2, layer0.width()); + // auto sequence = layer0[0]; + + // we expect that OpSequence has exactly 2 ops + ASSERT_EQ(1, layer0[0].length()); + + ASSERT_EQ(5, layer0[0].at(0).protoContext().nodeId()); + + // sequence = layer0[1]; + + // we expect that OpSequence has exactly 2 ops + ASSERT_EQ(1, layer0[1].length()); + ASSERT_EQ(6, layer0[1].at(0).protoContext().nodeId()); + + // checking second layer now + auto layer1 = optimized.layer(1); + + // we expect layer has exactly 2 OpSequences + ASSERT_EQ(2, layer1.width()); + + // sequence = layer1[0]; + + ASSERT_EQ(1, layer1[0].length()); + ASSERT_EQ(7, layer1[0].at(0).protoContext().nodeId()); + + // sequence = layer1[1]; + + ASSERT_EQ(1, layer1[1].length()); + ASSERT_EQ(8, layer1[1].at(0).protoContext().nodeId()); + + // checking before last layer + auto layer2 = optimized.layer(2); + + // we expect layer has exactly 2 OpSequence + ASSERT_EQ(2, layer2.width()); + // sequence = layer2[0]; + + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, layer2[0].length()); + ASSERT_EQ(9, layer2[0].at(0).protoContext().nodeId()); + // sequence = layer2[1]; + + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, layer2[1].length()); + ASSERT_EQ(10, layer2[1].at(0).protoContext().nodeId()); + + // checking last layer + auto layer3 = optimized.layer(3); + + // we expect layer has exactly 2 OpSequence + ASSERT_EQ(2, layer3.width()); + + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, layer3[0].length()); + ASSERT_EQ(11, layer3[0].at(0).protoContext().nodeId()); + + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, layer3[1].length()); + ASSERT_EQ(12, layer3[1].at(0).protoContext().nodeId()); +} + +TEST_F(GraphAnalysisTests, optimizedGraph_6) { + Graph graph; + + graph.addVariable("A", NDArray('c', {3}, {1, 1, 1}, sd::DataType::INT32)); + graph.addVariable("B", NDArray('c', {3}, {1, 1, 1}, sd::DataType::INT32)); + graph.addVariable("C", NDArray('c', {3}, {1, 1, 1}, sd::DataType::INT32)); + graph.addVariable("D", NDArray('c', {3}, {1, 1, 1}, sd::DataType::INT32)); + graph.addVariable("E", NDArray('c', {3}, {1, 1, 1}, sd::DataType::INT32)); + graph.addVariable("F", NDArray('c', {3}, {1, 1, 1}, sd::DataType::INT32)); + + Node a(sd::ops::multiply(), "a"); + Node b1(sd::ops::add(), "b1"); + Node b2(sd::ops::subtract(), "b2"); + + Node c1(sd::ops::add(), "c1"); + Node c2(sd::ops::multiply(), "c2"); + Node c3(sd::ops::subtract(), "c3"); + + Node d1(sd::ops::multiply(), "d1"); + Node d2(sd::ops::multiply(), "d2"); + + Node e(sd::ops::add(), "e"); + + graph.addNode(a, {"A", "B"}); + + graph.addNode(b1, {"a", "C"}); + graph.addNode(b2, {"a", "D"}); + + graph.addNode(c1, {"b1", "E"}); + graph.addNode(c2, {"b1", "b2"}); + graph.addNode(c3, {"b2", "F"}); + + graph.addNode(d1, {"c1", "c2"}); + graph.addNode(d2, {"c2", "c3"}); + + graph.addNode(e, {"d1", "d2"}); + + // we just check that nodes were really added + ASSERT_EQ(9, graph.size()); + + const auto& optimized = graph.optimizedGraph(); + + // we expect that OptimizedGraph has exactly 3 layer + ASSERT_EQ(5, optimized.numOfLayers()); + + // checking first layer first + auto layer0 = optimized.layer(0); + + // we expect layer has exactly 1 OpSequence + ASSERT_EQ(1, layer0.width()); + + // auto sequence = layer0[0]; + // we expect that OpSequence has exactly 2 ops + ASSERT_EQ(1, layer0[0].length()); + ASSERT_EQ(7, layer0[0].at(0).protoContext().nodeId()); + + // checking second layer now + auto layer1 = optimized.layer(1); + + // we expect layer has exactly 2 OpSequences + ASSERT_EQ(2, layer1.width()); + + // sequence = layer1[0]; + ASSERT_EQ(1, layer1[0].length()); + ASSERT_EQ(8, layer1[0].at(0).protoContext().nodeId()); + + // sequence = layer1[1]; + ASSERT_EQ(1, layer1[1].length()); + ASSERT_EQ(9, layer1[1].at(0).protoContext().nodeId()); + + // checking midle layer + auto layer2 = optimized.layer(2); + + // we expect layer has exactly 2 OpSequence + ASSERT_EQ(3, layer2.width()); + + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, layer2[0].length()); + ASSERT_EQ(10, layer2[0].at(0).protoContext().nodeId()); + + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, layer2[1].length()); + ASSERT_EQ(11, layer2[1].at(0).protoContext().nodeId()); + + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, layer2[2].length()); + ASSERT_EQ(12, layer2[2].at(0).protoContext().nodeId()); + + // checking before last layer + auto layer3 = optimized.layer(3); + + // we expect layer has exactly 2 OpSequence + ASSERT_EQ(2, layer3.width()); + + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, layer3[0].length()); + ASSERT_EQ(13, layer3[0].at(0).protoContext().nodeId()); + + // sequence = layer3[1]; + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, layer3[1].length()); + ASSERT_EQ(14, layer3[1].at(0).protoContext().nodeId()); + + // checking last layer + auto layer4 = optimized.layer(4); + + // we expect layer has exactly 2 OpSequence + ASSERT_EQ(1, layer4.width()); + + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, layer4[0].length()); + ASSERT_EQ(15, layer4[0].at(0).protoContext().nodeId()); +} + +TEST_F(GraphAnalysisTests, optimizedGraph_7) { + Graph graph; + + graph.addVariable("A", NDArray('c', {3}, {1, 1, 1}, sd::DataType::INT32)); + graph.addVariable("B", NDArray('c', {3}, {1, 1, 1}, sd::DataType::INT32)); + graph.addVariable("C", NDArray('c', {3}, {1, 1, 1}, sd::DataType::INT32)); + + Node a(sd::ops::multiply(), "a"); + Node b(sd::ops::add(), "b"); + Node c(sd::ops::subtract(), "c"); + Node d(sd::ops::add(), "d"); + Node e(sd::ops::multiply(), "e"); + + graph.addNode(a, {"A", "B"}); + + graph.addNode(b, {"a", "C"}); + + graph.addNode(c, {"a", "b"}); + + graph.addNode(d, {"b", "c"}); + + graph.addNode(e, {"b", "c", "d"}); + + // we just check that nodes were really added + ASSERT_EQ(5, graph.size()); + + const auto& optimized = graph.optimizedGraph(); + // graph.printOut(); + // we expect that OptimizedGraph has exactly 3 layer + ASSERT_EQ(1, optimized.numOfLayers()); + + auto layer = optimized.layer(0); + + ASSERT_EQ(1, layer.width()); + + auto seq = layer.at(0); + ASSERT_EQ(5, seq.length()); + + // this Graph doesn't allow any variance here. Order must be exactly the same as below + ASSERT_EQ(std::string("a"), seq[0].node().name()); + ASSERT_EQ(std::string("b"), seq[1].node().name()); + ASSERT_EQ(std::string("c"), seq[2].node().name()); + ASSERT_EQ(std::string("d"), seq[3].node().name()); + ASSERT_EQ(std::string("e"), seq[4].node().name()); +} + +TEST_F(GraphAnalysisTests, optimizedGraph_8) { + Graph graph; + + graph.addVariable("A", NDArray('c', {3}, {1, 1, 1}, sd::DataType::INT32)); + graph.addVariable("B", NDArray('c', {3}, {1, 1, 1}, sd::DataType::INT32)); + graph.addVariable("C", NDArray('c', {3}, {1, 1, 1}, sd::DataType::INT32)); + graph.addVariable("D", NDArray('c', {3}, {1, 1, 1}, sd::DataType::INT32)); + graph.addVariable("E", NDArray('c', {3}, {1, 1, 1}, sd::DataType::INT32)); + graph.addVariable("F", NDArray('c', {3}, {1, 1, 1}, sd::DataType::INT32)); + + Node a1(sd::ops::multiply(), "a1"); + Node a2(sd::ops::add(), "a2"); + Node a3(sd::ops::add(), "a3"); + + Node b1(sd::ops::subtract(), "b1"); + Node b2(sd::ops::add(), "b2"); + Node b3(sd::ops::multiply(), "b3"); + + graph.addNode(a1, {"A", "B"}); + graph.addNode(a2, {"C", "D"}); + graph.addNode(a3, {"E", "F"}); + + graph.addNode(b1, {"a1", "a2"}); + graph.addNode(b2, {"a1", "a2", "a3"}); + graph.addNode(b3, {"a2", "a3"}); + + // we just check that nodes were really added + ASSERT_EQ(6, graph.size()); + + const auto& optimized = graph.optimizedGraph(); + + // we expect that OptimizedGraph has exactly 2 layer + ASSERT_EQ(2, optimized.numOfLayers()); + + // checking first layer first + auto layer0 = optimized.layer(0); + + // we expect layer has exactly 3 OpSequence + ASSERT_EQ(3, layer0.width()); + + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, layer0[0].length()); + ASSERT_EQ(7, layer0[0].at(0).protoContext().nodeId()); + + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, layer0[1].length()); + ASSERT_EQ(8, layer0[1].at(0).protoContext().nodeId()); + + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, layer0[2].length()); + ASSERT_EQ(9, layer0[2].at(0).protoContext().nodeId()); + + // checking second layer now + auto layer1 = optimized.layer(1); + + // we expect layer has exactly 3 OpSequences + ASSERT_EQ(3, layer1.width()); + + ASSERT_EQ(1, layer1[0].length()); + ASSERT_EQ(10, layer1[0].at(0).protoContext().nodeId()); + + ASSERT_EQ(1, layer1[1].length()); + ASSERT_EQ(11, layer1[1].at(0).protoContext().nodeId()); + + ASSERT_EQ(1, layer1[2].length()); + ASSERT_EQ(12, layer1[2].at(0).protoContext().nodeId()); +} + +TEST_F(GraphAnalysisTests, optimizedGraph_9) { + // start graph + + Graph graph; + + graph.addVariable("A", NDArray('c', {3}, {1, 1, 1}, sd::DataType::INT32)); + graph.addVariable("B", NDArray('c', {3}, {2, 2, 2}, sd::DataType::INT32)); + graph.addVariable("C", NDArray('c', {3}, {3, 3, 3}, sd::DataType::INT32)); + graph.addVariable("D", NDArray('c', {3}, {4, 4, 4}, sd::DataType::INT32)); + + Node a(sd::ops::multiply(), "a"); + + Node b1(sd::ops::add(), "b1"); + Node b2(sd::ops::multiply(), "b2"); + Node b3(sd::ops::subtract(), "b3"); + Node b4(sd::ops::Pow(), "b4"); + + Node c1(sd::ops::Pow(), "c1"); + Node c2(sd::ops::subtract(), "c2"); + Node c3(sd::ops::multiply(), "c3"); + Node c4(sd::ops::add(), "c4"); + + Node c5(sd::ops::Pow(), "c5"); + Node c6(sd::ops::subtract(), "c6"); + Node c7(sd::ops::multiply(), "c7"); + Node c8(sd::ops::add(), "c8"); + + graph.addNode(a, {"A", "B"}); + + graph.addNode(b1, {"a", "C"}); + graph.addNode(b2, {"a", "C"}); + graph.addNode(b3, {"a", "C"}); + graph.addNode(b4, {"a", "C"}); + + graph.addNode(c1, {"b1", "D"}); + graph.addNode(c2, {"b2", "D"}); + graph.addNode(c3, {"b3", "D"}); + graph.addNode(c4, {"b4", "D"}); + + graph.addNode(c5, {"b1", "D"}); + graph.addNode(c6, {"b2", "D"}); + graph.addNode(c7, {"b3", "D"}); + graph.addNode(c8, {"b4", "D"}); + + // we just check that nodes were really added + ASSERT_EQ(13, graph.size()); + + const auto& optimized = graph.optimizedGraph(); + + // we expect that OptimizedGraph has exactly 1 layer + ASSERT_EQ(3, optimized.numOfLayers()); + + auto layer = optimized.layer(0); + // we expect layer has exactly 1 OpSequence + ASSERT_EQ(1, layer.width()); + + // we expect that OpSequence has exactly 2 ops + ASSERT_EQ(1, layer[0].length()); + ASSERT_EQ(5, layer[0].at(0).protoContext().nodeId()); + + auto layer1 = optimized.layer(1); + // we expect layer has exactly 4 OpSequence + ASSERT_EQ(4, layer1.width()); + + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, layer1[0].length()); + ASSERT_EQ(6, layer1[0].at(0).protoContext().nodeId()); + + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, layer1[1].length()); + ASSERT_EQ(7, layer1[1].at(0).protoContext().nodeId()); + + // sequence = layer1[2]; + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, layer1[2].length()); + ASSERT_EQ(8, layer1[2].at(0).protoContext().nodeId()); + + // sequence = layer1[3]; + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, layer1[3].length()); + ASSERT_EQ(9, layer1[3].at(0).protoContext().nodeId()); + + auto layer2 = optimized.layer(2); + // we expect layer has exactly 4 OpSequence + ASSERT_EQ(8, layer2.width()); + + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, layer2[0].length()); + ASSERT_EQ(10, layer2[0].at(0).protoContext().nodeId()); + + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, layer2[1].length()); + ASSERT_EQ(11, layer2[1].at(0).protoContext().nodeId()); + + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, layer2[2].length()); + ASSERT_EQ(12, layer2[2].at(0).protoContext().nodeId()); + + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, layer2[3].length()); + ASSERT_EQ(13, layer2[3].at(0).protoContext().nodeId()); + + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, layer2[4].length()); + ASSERT_EQ(14, layer2[4].at(0).protoContext().nodeId()); + + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, layer2[5].length()); + ASSERT_EQ(15, layer2[5].at(0).protoContext().nodeId()); + + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, layer2[6].length()); + ASSERT_EQ(16, layer2[6].at(0).protoContext().nodeId()); + + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, layer2[7].length()); + ASSERT_EQ(17, layer2[7].at(0).protoContext().nodeId()); +} + +TEST_F(GraphAnalysisTests, optimizedGraph_10) { + Graph graph; + + graph.addVariable("A", NDArray('c', {3}, {1, 1, 1}, sd::DataType::INT32)); + graph.addVariable("B", NDArray('c', {3}, {2, 2, 2}, sd::DataType::INT32)); + graph.addVariable("C", NDArray('c', {3}, {3, 3, 3}, sd::DataType::INT32)); + graph.addVariable("D", NDArray('c', {3}, {3, 3, 3}, sd::DataType::INT32)); + + Node a(sd::ops::multiply(), "a"); + Node b(sd::ops::add(), "b"); + Node c(sd::ops::multiply(), "c"); + Node d(sd::ops::subtract(), "d"); + + graph.addNode(a, {"A", "B"}); + graph.addNode(b, {"a", "C"}); + graph.addNode(c, {"a", "D"}); + graph.addNode(d, {"a", "b", "c"}); + + // we just check that nodes were really added + ASSERT_EQ(4, graph.size()); + + const auto& optimized = graph.optimizedGraph(); + + // we expect that OptimizedGraph has exactly 1 layer + ASSERT_EQ(3, optimized.numOfLayers()); + + auto layer = optimized.layer(0); + + // we expect layer has exactly 1 OpSequence + ASSERT_EQ(1, layer.width()); + auto sequence = layer[0]; + + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, sequence.length()); + ASSERT_EQ(5, sequence.at(0).protoContext().nodeId()); + + auto layer1 = optimized.layer(1); + + // we expect layer has exactly 2 OpSequence + ASSERT_EQ(2, layer1.width()); + sequence = layer1[0]; + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, sequence.length()); + ASSERT_EQ(6, sequence.at(0).protoContext().nodeId()); + sequence = layer1[1]; + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, sequence.length()); + ASSERT_EQ(7, sequence.at(0).protoContext().nodeId()); + + auto layer2 = optimized.layer(2); + // we expect layer has exactly 1 OpSequence + ASSERT_EQ(1, layer2.width()); + sequence = layer2[0]; + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, sequence.length()); + ASSERT_EQ(8, sequence.at(0).protoContext().nodeId()); +} + +TEST_F(GraphAnalysisTests, optimizedGraph_11) { + Graph graph; + + graph.addVariable("A", NDArray('c', {3}, {1, 1, 1}, sd::DataType::INT32)); + graph.addVariable("B", NDArray('c', {3}, {2, 2, 2}, sd::DataType::INT32)); + graph.addVariable("C", NDArray('c', {3}, {3, 3, 3}, sd::DataType::INT32)); + graph.addVariable("D", NDArray('c', {3}, {3, 3, 3}, sd::DataType::INT32)); + + Node a(sd::ops::multiply(), "a"); + Node b(sd::ops::add(), "b"); + Node c(sd::ops::multiply(), "c"); + Node d(sd::ops::subtract(), "d"); + + graph.addNode(a, {"A", "B"}); + graph.addNode(b, {"A", "C"}); + graph.addNode(c, {"B", "D"}); + graph.addNode(d, {"C", "D"}); + + // we just check that nodes were really added + ASSERT_EQ(4, graph.size()); + + const auto& optimized = graph.optimizedGraph(); + + // we expect that OptimizedGraph has exactly 1 layer + ASSERT_EQ(1, optimized.numOfLayers()); + + auto layer = optimized.layer(0); + + // we expect layer has exactly 1 OpSequence + ASSERT_EQ(4, layer.width()); + auto sequence = layer[0]; + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, sequence.length()); + ASSERT_EQ(5, sequence.at(0).protoContext().nodeId()); + sequence = layer[1]; + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, sequence.length()); + ASSERT_EQ(6, sequence.at(0).protoContext().nodeId()); + sequence = layer[2]; + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, sequence.length()); + ASSERT_EQ(7, sequence.at(0).protoContext().nodeId()); + sequence = layer[3]; + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, sequence.length()); + ASSERT_EQ(8, sequence.at(0).protoContext().nodeId()); +} + +TEST_F(GraphAnalysisTests, optimizedGraph_12) { + Graph graph; + + graph.addVariable("start", NDArrayFactory::create(0)); + graph.addVariable("step", NDArrayFactory::create(1)); + + graph.addVariable("const_1", NDArrayFactory::create(0)); + graph.addVariable("const_2", NDArrayFactory::create(2)); + + // generating "stop" argument for Range op + graph.addNode(Node(sd::ops::add(), "add"), {"const_1", "const_2"}); + + // generating axis, should be equal to {0} + graph.addNode(Node(sd::ops::range(), "range_1"), {"start", "add", "step"}); + + graph.addNode(Node(sd::ops::range(), "range_2"), {"range_1", "add", "step"}); + + auto &optimized = graph.optimizedGraph(); + + // graph.printOut(); + + // we expect exactly 1 layer + ASSERT_EQ(1, optimized.layers()); + auto layer = optimized.layer(0); + + // we expect exactly 1 OpSequence wihtin this layer + ASSERT_EQ(1, layer.width()); + auto seq = layer.at(0); + + // this Graph doesn't allow any variance here. Order must be exactly the same as below + ASSERT_EQ(std::string("add"), seq[0].node().name()); + ASSERT_EQ(std::string("range_1"), seq[1].node().name()); + ASSERT_EQ(std::string("range_2"), seq[2].node().name()); + + std::pair exp; + + exp = {seq[1].node().id(), 0}; + ASSERT_EQ(exp, seq[0].node().outputs()[0]); + + exp = {seq[2].node().id(), 0}; + ASSERT_EQ(exp, seq[1].node().outputs()[0]); + + ASSERT_EQ(0, seq[2].node().outputs().size()); +} + +TEST_F(GraphAnalysisTests, optimizedGraph_13) { + Graph graph; + + graph.addPlaceholder("input", sd::DataType::FLOAT32, {-1, 3, 244, 244}); + graph.addVariable("weights_1", NDArrayFactory::create(0.1f)); + graph.addVariable("weights_2", NDArrayFactory::create(0.1f)); + graph.addVariable("weights_3", NDArrayFactory::create(0.1f)); + graph.addVariable("weights_4", NDArrayFactory::create(0.1f)); + + graph.addVariable("axis", NDArrayFactory::create(1)); + + graph.addNode(Node(sd::ops::tanh(), "conv_1"), {"input", "weights_1"}); + graph.addNode(Node(sd::ops::tanh(), "pooling_1"), {"conv_1"}); + + // branch 1 + graph.addNode(Node(sd::ops::tanh(), "conv_2"), {"pooling_1", "weights_2"}); + graph.addNode(Node(sd::ops::tanh(), "pooling_2"), {"conv_2"}); + + // branch 2 + graph.addNode(Node(sd::ops::tanh(), "conv_3"), {"pooling_1", "weights_3"}); + graph.addNode(Node(sd::ops::tanh(), "pooling_3"), {"conv_3"}); + + // branch 3 + graph.addNode(Node(sd::ops::tanh(), "conv_4"), {"pooling_1", "weights_4"}); + graph.addNode(Node(sd::ops::tanh(), "pooling_4"), {"conv_4"}); + + // merge branch + graph.addNode(Node(sd::ops::concat(), "concat"), {"pooling_2", "pooling_3", "pooling_4", "axis"}); + + auto &optimized = graph.optimizedGraph(); + + // we expect exactly 3 layers + ASSERT_EQ(3, optimized.layers()); + auto layer = optimized.layer(0); + + // layer 0 must have exactly 1 sequence of 2 ops: conv_1 and pooling_1 + ASSERT_EQ(1, layer.width()); + auto seq = layer[0]; + + ASSERT_EQ(2, seq.length()); + ASSERT_EQ(std::string("conv_1"), seq[0].node().name()); + ASSERT_EQ(std::string("pooling_1"), seq[1].node().name()); + + layer = optimized.layer(1); + ASSERT_EQ(3, layer.width()); + + seq = layer[0]; + ASSERT_EQ(2, seq.length()); + ASSERT_EQ(std::string("conv_2"), seq[0].node().name()); + ASSERT_EQ(std::string("pooling_2"), seq[1].node().name()); + + seq = layer[1]; + ASSERT_EQ(2, seq.length()); + ASSERT_EQ(std::string("conv_3"), seq[0].node().name()); + ASSERT_EQ(std::string("pooling_3"), seq[1].node().name()); + + seq = layer[2]; + ASSERT_EQ(2, seq.length()); + ASSERT_EQ(std::string("conv_4"), seq[0].node().name()); + ASSERT_EQ(std::string("pooling_4"), seq[1].node().name()); + + layer = optimized.layer(2); + ASSERT_EQ(1, layer.width()); + + seq = layer[0]; + ASSERT_EQ(1, seq.length()); + ASSERT_EQ(std::string("concat"), seq[0].node().name()); +} + +TEST_F(GraphAnalysisTests, optimizedGraph_cond1) { + + auto graph = Graph::fromFlatBuffers("resources/cond_true.fb"); + const auto& optimized = graph.optimizedGraph(); + graph.printOut(); + + // we expect exactly 3 layers + ASSERT_EQ(3, optimized.layers()); + + // layer 0 + auto layer = optimized.layer(0); + ASSERT_EQ(1, layer.width()); + auto seq = layer[0]; + ASSERT_EQ(2, seq.length()); + ASSERT_EQ(std::string("in_0/read"), seq[0].node().name()); + ASSERT_EQ(std::string("cond/Switch"), seq[1].node().name()); + + auto swtch = seq[1].node(); + ASSERT_EQ(2, swtch.outputs().size()); + + // layer 1 + layer = optimized.layer(1); + ASSERT_EQ(2, layer.width()); + seq = layer[0]; + ASSERT_EQ(2, seq.length()); + ASSERT_EQ(std::string("cond/switch_t"), seq[0].node().name()); + ASSERT_EQ(std::string("cond/LinSpace"), seq[1].node().name()); + seq = layer[1]; + ASSERT_EQ(1, seq.length()); + ASSERT_EQ(std::string("cond/switch_f"), seq[0].node().name()); + + std::pair exp; + + // checking True condition first + exp = {layer[0][0].node().id(), 1}; + ASSERT_EQ(exp, swtch.outputs()[0]); + + // checking False condition next + exp = {layer[1][0].node().id(), 0}; + ASSERT_EQ(exp, swtch.outputs()[1]); + + // layer 2 + layer = optimized.layer(2); + ASSERT_EQ(1, layer.width()); + seq = layer[0]; + ASSERT_EQ(1, seq.length()); + ASSERT_EQ(std::string("cond/Merge"), seq[0].node().name()); + + auto res = graph.execute({}, {"cond/Merge"}); + ASSERT_EQ(1, res.size()); + auto arr = res["cond/Merge"]; + + ASSERT_EQ(NDArrayFactory::create({1.f, 2.f, 3.f, 4.f, 5.f}), arr); +} + +TEST_F(GraphAnalysisTests, optimizedGraph_cond2) { + + auto graph = Graph::fromFlatBuffers("resources/cond_false.fb"); + const auto& optimized = graph.optimizedGraph(); + graph.printOut(); + + ASSERT_EQ(3, optimized.layers()); + + // layer 0 + auto layer = optimized.layer(0); + ASSERT_EQ(1, layer.width()); + auto seq = layer[0]; + ASSERT_EQ(2, seq.length()); + ASSERT_EQ(std::string("in_0/read"), seq[0].node().name()); + ASSERT_EQ(std::string("cond/Switch"), seq[1].node().name()); + + // layer 1 + layer = optimized.layer(1); + ASSERT_EQ(2, layer.width()); + seq = layer[0]; + ASSERT_EQ(2, seq.length()); + ASSERT_EQ(std::string("cond/switch_t"), seq[0].node().name()); + ASSERT_EQ(std::string("cond/LinSpace"), seq[1].node().name()); + seq = layer[1]; + ASSERT_EQ(1, seq.length()); + ASSERT_EQ(std::string("cond/switch_f"), seq[0].node().name()); + + // layer 2 + layer = optimized.layer(2); + ASSERT_EQ(1, layer.width()); + seq = layer[0]; + ASSERT_EQ(1, seq.length()); + ASSERT_EQ(std::string("cond/Merge"), seq[0].node().name()); + + auto res = graph.execute({}, {"cond/Merge"}); + ASSERT_EQ(1, res.size()); + ASSERT_EQ(NDArrayFactory::create({1.f, 1.f, 1.f, 1.f, 1.f}), res["cond/Merge"]); +} + +TEST_F(GraphAnalysisTests, optimizedGraph_while1) { + auto graph = Graph::fromFlatBuffers("resources/while_iter1.fb"); + const auto& optimized = graph.optimizedGraph(); + + // this Graph must have exactly 1 Layer and 1 OpSequence, since all it has is While loop + ASSERT_EQ(1, optimized.layers()); + ASSERT_EQ(1, optimized.layer(0).width()); + + graph.printOut(); + + auto results = graph.execute({}, {"while/Exit", "while/Exit_1"}); + ASSERT_EQ(2, results.size()); + + ASSERT_EQ(NDArrayFactory::create(1.f), results["while/Exit"]); + ASSERT_EQ(NDArrayFactory::create(1.f), results["while/Exit_1"]); +} + +TEST_F(GraphAnalysisTests, optimizedGraph_while2) { + auto graph = Graph::fromFlatBuffers("resources/while_iter3.fb"); + const auto& optimized = graph.optimizedGraph(); + + // this Graph must have exactly 1 Layer and 1 OpSequence, since all it has is While loop + ASSERT_EQ(1, optimized.layers()); + ASSERT_EQ(1, optimized.layer(0).width()); + + graph.printOut(); + + auto results = graph.execute({}, {"while/Exit", "while/Exit_1"}); + ASSERT_EQ(2, results.size()); + + ASSERT_EQ(NDArrayFactory::create(3.f), results["while/Exit"]); + ASSERT_EQ(NDArrayFactory::create(3.f), results["while/Exit_1"]); +} + +TEST_F(GraphAnalysisTests, optimizedGraph_nested_while_1) { + auto input_0 = NDArrayFactory::create('c', {2, 2}, {1.f, 2.f, 3.f, 4.f}); + auto input_1 = NDArrayFactory::create('c', {3, 3}, {1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}); + + auto graph = Graph::fromFlatBuffers("resources/simplewhile_nested.fb"); + const auto& optimized = graph.optimizedGraph(); + + ASSERT_EQ(2, graph.placeholders().size()); + + graph.printOut(); + + auto results = graph.execute({{"input_0", input_0}, {"input_1", input_1}}, {"output"}); + + ASSERT_EQ(NDArrayFactory::create('c', {2, 2}, {13.f, 14.f, 15.f, 16.f}), results["output"]); +} + +TEST_F(GraphAnalysisTests, optimizedGraph_simpleif_0_alt) { + auto input_0 = NDArrayFactory::create('c', {2, 2}, {1.f, 2.f, 3.f, 4.f}); + auto input_1 = NDArrayFactory::create(11.f); + + auto graph = Graph::fromFlatBuffers("resources/simpleif_0_alt.fb"); + + auto results = graph.execute({{"input_0", input_0}, {"input_1", input_1}}, {"output"}); + + ASSERT_EQ(NDArrayFactory::create('c', {2, 2}, {3.f, 4.f, 5.f, 6.f}), results["output"]); +} + +TEST_F(GraphAnalysisTests, optimizedGraph_simplewhile_1) { + auto input_0 = NDArrayFactory::create('c', {2, 2}, {1.f, 2.f, 3.f, 4.f}); + auto input_1 = NDArrayFactory::create(21.f); + + auto graph = Graph::fromFlatBuffers("resources/simplewhile_1.fb"); + + auto results = graph.execute({{"input_0", input_0}, {"input_1", input_1}}, {"output"}); + + ASSERT_EQ(NDArrayFactory::create('c', {2, 2}, {24.f, 25.f, 26.f, 27.f}), results["output"]); +} \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/GraphExecutionerTests.cpp b/libnd4j/tests_cpu/layers_tests/GraphExecutionerTests.cpp deleted file mode 100644 index 7a2856dc0c0a..000000000000 --- a/libnd4j/tests_cpu/layers_tests/GraphExecutionerTests.cpp +++ /dev/null @@ -1,103 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -// -// Created by raver119 on 29.11.17. -// - - -#include "testlayers.h" -#include -#include -#include -#include -#include -#include -#include - -using namespace sd; -using namespace sd::graph; - -class GraphExecutionerTests : public testing::Test { -public: - -}; - -#ifdef GRAPH_TESTS_OK -TEST_F(GraphExecutionerTests, Test_Implicit_Output_1) { - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/tensor_slice.fb"); - graph->buildGraph(); - - auto outputs = graph->fetchOutputs(); - - ASSERT_EQ(1, outputs->size()); - - auto var0 = outputs->at(0); - - ASSERT_EQ(7, var0->id()); - ASSERT_EQ(0, var0->index()); - - delete outputs; - delete graph; -} - - -TEST_F(GraphExecutionerTests, Test_Implicit_Output_2) { - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/reduce_dim_false.fb"); - graph->buildGraph(); - - auto outputs = graph->fetchOutputs(); - - ASSERT_EQ(1, outputs->size()); - - auto var0 = outputs->at(0); - - ASSERT_EQ(3, var0->id()); - ASSERT_EQ(0, var0->index()); - - delete outputs; - delete graph; -} - - -TEST_F(GraphExecutionerTests, Test_Implicit_Output_3) { - auto exp = NDArrayFactory::create('c', {3}, {3, 3, 3}); - - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/reduce_dim_false.fb"); - auto status = GraphExecutioner::execute(graph); - - ASSERT_EQ(ND4J_STATUS_OK, status); - - auto outputs = graph->fetchOutputs(); - - ASSERT_EQ(1, outputs->size()); - - auto var0 = outputs->at(0); - - ASSERT_EQ(3, var0->id()); - ASSERT_EQ(0, var0->index()); - - auto array = var0->getNDArray(); - - ASSERT_TRUE(array != nullptr); - - ASSERT_TRUE(exp.isSameShape(array)); - ASSERT_TRUE(exp.equalsTo(array)); - - delete outputs; - delete graph; -} -#endif diff --git a/libnd4j/tests_cpu/layers_tests/GraphExecutorTests.cpp b/libnd4j/tests_cpu/layers_tests/GraphExecutorTests.cpp new file mode 100644 index 000000000000..038fab85ca40 --- /dev/null +++ b/libnd4j/tests_cpu/layers_tests/GraphExecutorTests.cpp @@ -0,0 +1,105 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// Created by raver119 on 29.11.17. +// + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "testlayers.h" + +using namespace sd; +using namespace sd::graph; + +class GraphExecutorTests : public testing::Test { + public: +}; + +TEST_F(GraphExecutorTests, test_basic_exec_1) { + GraphMemoryManager memoryManager; + Graph graph; + OptimizedGraph optimizedGraph; + OpSequence sequence; + + optimizedGraph.append(sequence); + ASSERT_EQ(1, optimizedGraph.layers()); + + VariableProxy proxy(&graph.variableSpace()); + GraphExecutor executor; + executor.execute(optimizedGraph, proxy); +} + +TEST_F(GraphExecutorTests, test_basic_exec_2) { + GraphMemoryManager mgr; + Graph graph(nullptr, mgr); + + auto A = NDArrayFactory::create('c', {3}, {1, 1, 1}); + auto B = NDArrayFactory::create('c', {3}, {2, 2, 2}); + auto C = NDArrayFactory::create('c', {3}, {3, 3, 3}); + + auto exp = NDArrayFactory::create('c', {3}, {5, 5, 5}); + + graph.addVariable("A", A); + graph.addVariable("B", B); + graph.addVariable("C", C); + + Node m(sd::ops::multiply(), "mul"); + Node a(sd::ops::add(), "add"); + + graph.addNode(m, {"A", "B"}); + graph.addNode(a, {"mul", "C"}); + + OptimizedGraph optimizedGraph; + OpSequence sequence; + + ASSERT_EQ(2, m.contextPrototype().inputs().size()); + ASSERT_EQ(2, a.contextPrototype().inputs().size()); + + sequence.append(m, m.contextPrototype()); + sequence.append(a, a.contextPrototype()); + + optimizedGraph.append(sequence); + + ASSERT_EQ(2, sequence.length()); + ASSERT_EQ(1, optimizedGraph.layers()); + + VariableProxy proxy(&graph.variableSpace()); + GraphExecutor executor; + executor.execute(optimizedGraph, proxy); + + // checking results by ID + ASSERT_TRUE(proxy.hasVariable(m.id())); + ASSERT_TRUE(proxy.hasVariable(a.id())); + + // checking results by name + ASSERT_TRUE(proxy.hasVariable("mul")); + ASSERT_TRUE(proxy.hasVariable("add")); + + // checking if result is valid + auto result = proxy.getVariable(a.id())->getNDArray(); + ASSERT_EQ(exp, *result); +} \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/GraphHolderTests.cpp b/libnd4j/tests_cpu/layers_tests/GraphHolderTests.cpp index a50091840cda..5938f30ac72b 100644 --- a/libnd4j/tests_cpu/layers_tests/GraphHolderTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/GraphHolderTests.cpp @@ -18,69 +18,26 @@ // Created by raver119 on 11.12.17. // -#include "testlayers.h" #include +#include "testlayers.h" + using namespace sd; using namespace sd::ops; using namespace sd::graph; class GraphHolderTests : public testing::Test { -public: - + public: }; TEST_F(GraphHolderTests, SimpleTests_1) { - Graph graph; - Nd4jLong graphId = 119; - GraphHolder::getInstance().registerGraph(graphId, &graph); - - ASSERT_TRUE(GraphHolder::getInstance().hasGraph(graphId)); - - GraphHolder::getInstance().forgetGraph(graphId); - - ASSERT_FALSE(GraphHolder::getInstance().hasGraph(graphId)); -} - - - -TEST_F(GraphHolderTests, SimpleTests_2) { - auto graph = new Graph; - Nd4jLong graphId = 117; - GraphHolder::getInstance().registerGraph(graphId, graph); - - ASSERT_TRUE(GraphHolder::getInstance().hasGraph(graphId)); - - auto graph2 = GraphHolder::getInstance().cloneGraph(graphId); - - ASSERT_TRUE(graph != graph2); - ASSERT_TRUE(graph2 != nullptr); - - GraphHolder::getInstance().forgetGraph(graphId); - - ASSERT_FALSE(GraphHolder::getInstance().hasGraph(graphId)); - - delete graph; - delete graph2; -} - - -TEST_F(GraphHolderTests, SimpleTests_3) { - auto graph = new Graph; - Nd4jLong graphId = 117; - GraphHolder::getInstance().registerGraph(graphId, graph); - - ASSERT_TRUE(GraphHolder::getInstance().hasGraph(graphId)); - - auto graph2 = GraphHolder::getInstance().cloneGraph(graphId); - - ASSERT_TRUE(graph != graph2); - ASSERT_TRUE(graph2 != nullptr); - - GraphHolder::getInstance().dropGraph(graphId); + Graph graph; + Nd4jLong graphId = 119; + GraphHolder::getInstance().registerGraph(graphId, graph); - ASSERT_FALSE(GraphHolder::getInstance().hasGraph(graphId)); + ASSERT_TRUE(GraphHolder::getInstance().hasGraph(graphId)); + GraphHolder::getInstance().forgetGraph(graphId); - delete graph2; + ASSERT_FALSE(GraphHolder::getInstance().hasGraph(graphId)); } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/GraphRandomGeneratorTests.cpp b/libnd4j/tests_cpu/layers_tests/GraphRandomGeneratorTests.cpp index 8fe46cd2f791..b93721b1e1d4 100644 --- a/libnd4j/tests_cpu/layers_tests/GraphRandomGeneratorTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/GraphRandomGeneratorTests.cpp @@ -14,251 +14,247 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -#include "testlayers.h" -#include #include #include +#include + #include +#include "testlayers.h" + using namespace sd; using namespace sd::graph; class GraphRandomGeneratorTests : public testing::Test { -public: - + public: }; TEST_F(GraphRandomGeneratorTests, Reproducibility_Test_1) { - sd::graph::RandomGenerator g0(119); - sd::graph::RandomGenerator g1(119); + sd::graph::RandomGenerator g0(119); + sd::graph::RandomGenerator g1(119); - auto i0 = g0.relativeT(15, 0, DataTypeUtils::max()); - auto i1 = g1.relativeT(15, 0, DataTypeUtils::max()); + auto i0 = g0.relativeT(15, 0, DataTypeUtils::max()); + auto i1 = g1.relativeT(15, 0, DataTypeUtils::max()); - ASSERT_EQ(i0, i1); + ASSERT_EQ(i0, i1); } TEST_F(GraphRandomGeneratorTests, Reproducibility_Test_2) { - sd::graph::RandomGenerator g0(119); - sd::graph::RandomGenerator g1(117); + sd::graph::RandomGenerator g0(119); + sd::graph::RandomGenerator g1(117); - auto i0 = g0.relativeT(15, 0, DataTypeUtils::max()); - auto i1 = g1.relativeT(15, 0, DataTypeUtils::max()); + auto i0 = g0.relativeT(15, 0, DataTypeUtils::max()); + auto i1 = g1.relativeT(15, 0, DataTypeUtils::max()); - ASSERT_NE(i0, i1); + ASSERT_NE(i0, i1); } TEST_F(GraphRandomGeneratorTests, Reproducibility_Test_3) { - sd::graph::RandomGenerator g0(119, 5); - sd::graph::RandomGenerator g1(119, 10); + sd::graph::RandomGenerator g0(119, 5); + sd::graph::RandomGenerator g1(119, 10); - auto i0 = g0.relativeT(15, 0, DataTypeUtils::max()); - auto i1 = g1.relativeT(15, 0, DataTypeUtils::max()); + auto i0 = g0.relativeT(15, 0, DataTypeUtils::max()); + auto i1 = g1.relativeT(15, 0, DataTypeUtils::max()); - ASSERT_NE(i0, i1); + ASSERT_NE(i0, i1); } TEST_F(GraphRandomGeneratorTests, Reproducibility_Test_4) { - sd::graph::RandomGenerator g0(119, 5); - sd::graph::RandomGenerator g1(117, 5); + sd::graph::RandomGenerator g0(119, 5); + sd::graph::RandomGenerator g1(117, 5); - auto i0 = g0.relativeT(15, 0, DataTypeUtils::max()); - auto i1 = g1.relativeT(15, 0, DataTypeUtils::max()); + auto i0 = g0.relativeT(15, 0, DataTypeUtils::max()); + auto i1 = g1.relativeT(15, 0, DataTypeUtils::max()); - ASSERT_NE(i0, i1); + ASSERT_NE(i0, i1); } TEST_F(GraphRandomGeneratorTests, Sequential_Test_1) { - sd::graph::RandomGenerator g0(119, 5); - sd::graph::RandomGenerator g1(119, 5); + sd::graph::RandomGenerator g0(119, 5); + sd::graph::RandomGenerator g1(119, 5); - auto v0 = g0.relativeT(15, 0, DataTypeUtils::max()); - auto v1 = g1.relativeT(15, 0, DataTypeUtils::max()); - g0.rewindH(200); - auto r0 = g0.relativeT(15, 0, DataTypeUtils::max()); - auto r1 = g1.relativeT(15, 0, DataTypeUtils::max()); + auto v0 = g0.relativeT(15, 0, DataTypeUtils::max()); + auto v1 = g1.relativeT(15, 0, DataTypeUtils::max()); + g0.rewindH(200); + auto r0 = g0.relativeT(15, 0, DataTypeUtils::max()); + auto r1 = g1.relativeT(15, 0, DataTypeUtils::max()); - // values after rewind aren't equal - ASSERT_NE(r0, v0); + // values after rewind aren't equal + ASSERT_NE(r0, v0); - // two generators must give the same output - ASSERT_EQ(v0, v1); + // two generators must give the same output + ASSERT_EQ(v0, v1); - // but not after one of them was rewinded - ASSERT_NE(r1, r0); + // but not after one of them was rewinded + ASSERT_NE(r1, r0); } TEST_F(GraphRandomGeneratorTests, Sequential_Test_2) { - sd::graph::RandomGenerator g0(119, 5); - sd::graph::RandomGenerator g1(119, 5); + sd::graph::RandomGenerator g0(119, 5); + sd::graph::RandomGenerator g1(119, 5); - auto v0 = g0.relativeT(15, 0, DataTypeUtils::max()); - auto v1 = g1.relativeT(15, 0, DataTypeUtils::max()); - g0.rewindH(200); - g1.rewindH(199); - auto r0 = g0.relativeT(15, 0, DataTypeUtils::max()); - auto r1 = g1.relativeT(15, 0, DataTypeUtils::max()); + auto v0 = g0.relativeT(15, 0, DataTypeUtils::max()); + auto v1 = g1.relativeT(15, 0, DataTypeUtils::max()); + g0.rewindH(200); + g1.rewindH(199); + auto r0 = g0.relativeT(15, 0, DataTypeUtils::max()); + auto r1 = g1.relativeT(15, 0, DataTypeUtils::max()); - // values after rewind aren't equal - ASSERT_NE(r0, v0); + // values after rewind aren't equal + ASSERT_NE(r0, v0); - // two generators must give the same output - ASSERT_EQ(v0, v1); + // two generators must give the same output + ASSERT_EQ(v0, v1); - // but not after they was rewinded with different number of elements - ASSERT_NE(r1, r0); + // but not after they was rewinded with different number of elements + ASSERT_NE(r1, r0); } TEST_F(GraphRandomGeneratorTests, Sequential_Test_3) { - sd::graph::RandomGenerator g0(119, 5); - sd::graph::RandomGenerator g1(119, 5); + sd::graph::RandomGenerator g0(119, 5); + sd::graph::RandomGenerator g1(119, 5); - auto v0 = g0.relativeT(15, 0, DataTypeUtils::max()); - auto v1 = g1.relativeT(15, 0, DataTypeUtils::max()); - g0.rewindH(200); - g1.rewindH(200); - auto r0 = g0.relativeT(15, 0, DataTypeUtils::max()); - auto r1 = g1.relativeT(15, 0, DataTypeUtils::max()); + auto v0 = g0.relativeT(15, 0, DataTypeUtils::max()); + auto v1 = g1.relativeT(15, 0, DataTypeUtils::max()); + g0.rewindH(200); + g1.rewindH(200); + auto r0 = g0.relativeT(15, 0, DataTypeUtils::max()); + auto r1 = g1.relativeT(15, 0, DataTypeUtils::max()); - // values after rewind aren't equal - ASSERT_NE(r0, v0); + // values after rewind aren't equal + ASSERT_NE(r0, v0); - // two generators must give the same output - ASSERT_EQ(v0, v1); + // two generators must give the same output + ASSERT_EQ(v0, v1); - // and here output must be equal as well - ASSERT_EQ(r1, r0); + // and here output must be equal as well + ASSERT_EQ(r1, r0); } TEST_F(GraphRandomGeneratorTests, Sequential_Test_4) { - sd::graph::RandomGenerator g0(119, 5); - sd::graph::RandomGenerator g1(119, 5); - - auto v0 = g0.relativeT(15, 0, DataTypeUtils::max()); - auto v1 = g1.relativeT(15, 0, DataTypeUtils::max()); - g0.rewindH(200); - g1.rewindH(200); - auto r0 = g0.relativeT(15, 0, DataTypeUtils::max()); - auto r1 = g1.relativeT(15, 0, DataTypeUtils::max()); - g0.rewindH(200); - g1.rewindH(200); - auto z0 = g0.relativeT(15, 0, DataTypeUtils::max()); - auto z1 = g1.relativeT(15, 0, DataTypeUtils::max()); - g0.rewindH(201); - g1.rewindH(199); - auto y0 = g0.relativeT(15, 0, DataTypeUtils::max()); - auto y1 = g1.relativeT(15, 0, DataTypeUtils::max()); - - // values after rewind aren't equal - ASSERT_NE(r0, v0); - - // two generators must give the same output - ASSERT_EQ(v0, v1); - - // and here output must be equal as well - ASSERT_EQ(r0, r1); - - ASSERT_EQ(z0, z1); - - ASSERT_NE(r0, z0); - ASSERT_NE(r1, z1); - - ASSERT_NE(y0, z0); - ASSERT_NE(y1, z1); + sd::graph::RandomGenerator g0(119, 5); + sd::graph::RandomGenerator g1(119, 5); + + auto v0 = g0.relativeT(15, 0, DataTypeUtils::max()); + auto v1 = g1.relativeT(15, 0, DataTypeUtils::max()); + g0.rewindH(200); + g1.rewindH(200); + auto r0 = g0.relativeT(15, 0, DataTypeUtils::max()); + auto r1 = g1.relativeT(15, 0, DataTypeUtils::max()); + g0.rewindH(200); + g1.rewindH(200); + auto z0 = g0.relativeT(15, 0, DataTypeUtils::max()); + auto z1 = g1.relativeT(15, 0, DataTypeUtils::max()); + g0.rewindH(201); + g1.rewindH(199); + auto y0 = g0.relativeT(15, 0, DataTypeUtils::max()); + auto y1 = g1.relativeT(15, 0, DataTypeUtils::max()); + + // values after rewind aren't equal + ASSERT_NE(r0, v0); + + // two generators must give the same output + ASSERT_EQ(v0, v1); + + // and here output must be equal as well + ASSERT_EQ(r0, r1); + + ASSERT_EQ(z0, z1); + + ASSERT_NE(r0, z0); + ASSERT_NE(r1, z1); + + ASSERT_NE(y0, z0); + ASSERT_NE(y1, z1); } - //#ifndef __clang__ TEST_F(GraphRandomGeneratorTests, Long_Test_1) { - sd::graph::RandomGenerator g0(119, 5); - sd::graph::RandomGenerator g1(119, 5); + sd::graph::RandomGenerator g0(119, 5); + sd::graph::RandomGenerator g1(119, 5); - std::array z0, z1, z2, z3; + std::array z0, z1, z2, z3; - for (int e = 0; e < z0.size(); e++) { - z0[e] = g0.relativeT(e); - z1[e] = g1.relativeT(e); - } + for (int e = 0; e < z0.size(); e++) { + z0[e] = g0.relativeT(e); + z1[e] = g1.relativeT(e); + } - g0.rewindH(z0.size()); - g1.rewindH(z0.size()); + g0.rewindH(z0.size()); + g1.rewindH(z0.size()); - for (int e = 0; e < z0.size(); e++) { - z2[e] = g0.relativeT(e); - z3[e] = g1.relativeT(e); - } + for (int e = 0; e < z0.size(); e++) { + z2[e] = g0.relativeT(e); + z3[e] = g1.relativeT(e); + } - // these sequences should be equal - ASSERT_EQ(z0, z1); - ASSERT_EQ(z2, z3); + // these sequences should be equal + ASSERT_EQ(z0, z1); + ASSERT_EQ(z2, z3); - // these sequences should be different due to rewind - ASSERT_NE(z0, z3); + // these sequences should be different due to rewind + ASSERT_NE(z0, z3); - // we'll be counting values > MAX_INT here - int maxes = 0; + // we'll be counting values > MAX_INT here + int maxes = 0; - for (int e = 0; e < z0.size(); e++) { - auto v = z0[e]; + for (int e = 0; e < z0.size(); e++) { + auto v = z0[e]; - // we don't want any negatives here - ASSERT_TRUE(v > 0); + // we don't want any negatives here + ASSERT_TRUE(v > 0); - if (v > DataTypeUtils::max()) - maxes++; - } + if (v > DataTypeUtils::max()) maxes++; + } - // and now we're ensuring there ARE values above MAX_INT - ASSERT_NE(0, maxes); + // and now we're ensuring there ARE values above MAX_INT + ASSERT_NE(0, maxes); } - TEST_F(GraphRandomGeneratorTests, FloatingPoint_Test_1) { - sd::graph::RandomGenerator g0(119, 5); - sd::graph::RandomGenerator g1(119, 5); - - std::array z0, z1, z2, z3; + sd::graph::RandomGenerator g0(119, 5); + sd::graph::RandomGenerator g1(119, 5); - for (int e = 0; e < z0.size(); e++) { - z0[e] = g0.relativeT(e, -1.0, 1.0); - z1[e] = g1.relativeT(e, -1.0, 1.0); - } + std::array z0, z1, z2, z3; - g0.rewindH(z0.size()); - g1.rewindH(z0.size()); + for (int e = 0; e < z0.size(); e++) { + z0[e] = g0.relativeT(e, -1.0, 1.0); + z1[e] = g1.relativeT(e, -1.0, 1.0); + } - for (int e = 0; e < z0.size(); e++) { - z2[e] = g0.relativeT(e, -1.0, 1.0); - z3[e] = g1.relativeT(e, -1.0, 1.0); - } + g0.rewindH(z0.size()); + g1.rewindH(z0.size()); - // these sequences should be equal - ASSERT_EQ(z0, z1); - ASSERT_EQ(z2, z3); + for (int e = 0; e < z0.size(); e++) { + z2[e] = g0.relativeT(e, -1.0, 1.0); + z3[e] = g1.relativeT(e, -1.0, 1.0); + } - // these sequences should be different due to rewind - ASSERT_NE(z0, z3); + // these sequences should be equal + ASSERT_EQ(z0, z1); + ASSERT_EQ(z2, z3); - // we'll count negatives as well - int negs = 0; + // these sequences should be different due to rewind + ASSERT_NE(z0, z3); - // make sure every value stays within distribution borders - for (int e = 0; e < z0.size(); e++) { - auto v = z0[e]; - if (!(v >= -1.0 && v <= 1.0)) { - nd4j_printf("Failed at idx [%i]: %f\n", e, (float) v); - ASSERT_TRUE(v >= -1.0 && v <= 1.0); - } + // we'll count negatives as well + int negs = 0; - if (v < 0.0) - negs++; + // make sure every value stays within distribution borders + for (int e = 0; e < z0.size(); e++) { + auto v = z0[e]; + if (!(v >= -1.0 && v <= 1.0)) { + nd4j_printf("Failed at idx [%i]: %f\n", e, (float)v); + ASSERT_TRUE(v >= -1.0 && v <= 1.0); } - // there should be negatives - ASSERT_TRUE(negs > 0); + if (v < 0.0) negs++; + } - // and positives - ASSERT_NE(z0.size(), negs); -} + // there should be negatives + ASSERT_TRUE(negs > 0); + // and positives + ASSERT_NE(z0.size(), negs); +} diff --git a/libnd4j/tests_cpu/layers_tests/GraphStateTests.cpp b/libnd4j/tests_cpu/layers_tests/GraphStateTests.cpp deleted file mode 100644 index 16c1ed623d62..000000000000 --- a/libnd4j/tests_cpu/layers_tests/GraphStateTests.cpp +++ /dev/null @@ -1,349 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -// -// @author raver119@gmail.com -// - -#include "testlayers.h" -#include -#include -#include -#include -#include - -using namespace sd; -using namespace sd::graph; - -class GraphStateTests : public testing::Test { -public: - GraphStateTests() { - Environment::getInstance().setDebug(false); - Environment::getInstance().setVerbose(false); - }; - - ~GraphStateTests() { - Environment::getInstance().setDebug(false); - Environment::getInstance().setVerbose(false); - } -}; - -/* - * PLAN: - * Create GraphState - * Register Scope - * Add few Ops to it - * Call conditional, that refers to scopes - * Check results - */ - -TEST_F(GraphStateTests, Basic_Tests_1) { - auto state = (GraphState *) getGraphState(117L); - ASSERT_EQ(117L, state->id()); - - // this call will create scope internally - state->registerScope(119); - - sd::ops::add opA; - sd::ops::LegacyTransformSameOp opB(transform::Neg); // simdOps::Neg - - ArgumentsList argsA; - ArgumentsList argsB; - - state->attachOpToScope(119, 1, &opA, argsA); - state->attachOpToScope(119, 2, &opB, argsB); - - auto scope = state->getScope(119); - ASSERT_TRUE(scope != nullptr); - ASSERT_EQ(2, scope->size()); - - deleteGraphState(state); -} - -// just separate case for doubles wrapper in NativeOps, nothing else -TEST_F(GraphStateTests, Basic_Tests_2) { - auto state = (GraphState *) getGraphState(117L); - ASSERT_EQ(117L, state->id()); - - // this call will create scope internally - state->registerScope(119); - - sd::ops::add opA; - sd::ops::LegacyTransformSameOp opB(transform::Neg); // simdOps::Neg - - ArgumentsList argsA; - ArgumentsList argsB; - - state->attachOpToScope(119, 1, &opA, argsA); - state->attachOpToScope(119, 2, &opB, argsB); - - auto scope = state->getScope(119); - ASSERT_TRUE(scope != nullptr); - ASSERT_EQ(2, scope->size()); - - deleteGraphState(state); -} - -/* -TEST_F(GraphStateTests, Stateful_Execution_1) { - auto state = getGraphState(117L); - - Nd4jLong scopes[] = {22, 33}; - //auto status = execCustomOpWithScope(nullptr, state, 10, scopes, 2, nullptr, nullptr, 0, nullptr, nullptr, 0); - auto status = execCustomOpWithScope(nullptr, state, 10, scopes, 2, nullptr, nullptr, 0, nullptr, nullptr, 0); - - ASSERT_EQ(Status::THROW(), status); - - deleteGraphState(state); -} - -TEST_F(GraphStateTests, Stateful_Execution_2) { - auto state = (GraphState *) getGraphState(117L); - - state->registerScope(22); - state->registerScope(33); - - Nd4jLong scopes[] = {22, 33}; - auto status = execCustomOpWithScope(nullptr, state, 10, scopes, 2, nullptr, nullptr, 0, nullptr, nullptr, 0); - // it's no-op: just LogicScope - ASSERT_EQ(Status::OK(), status); - - deleteGraphState(state); -} - -// This test checks WHILE loop -TEST_F(GraphStateTests, Stateful_Execution_3) { - auto var0 = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); - auto var1 = NDArrayFactory::create(11.0f); - auto var2 = NDArrayFactory::create(2.0f); - - auto res0 = NDArrayFactory::create('c', {2, 2}); - auto res1 = NDArrayFactory::create(0.0f); - auto res2 = NDArrayFactory::create(0.0f); - - // registering our GraphState holder - auto state = (GraphState *) getGraphState(117L); - - // we're prepping pointers to input/output buffers - Nd4jPointer ptrBuffers[] = {(Nd4jPointer) var0.buffer(), (Nd4jPointer) var1.buffer(), (Nd4jPointer)var2.buffer()}; - Nd4jPointer ptrShapes[] = {(Nd4jPointer) var0.shapeInfo(), (Nd4jPointer) var1.shapeInfo(), (Nd4jPointer)var2.shapeInfo()}; - - Nd4jPointer outBuffers[] = {(Nd4jPointer) res0.buffer(), (Nd4jPointer) res1.buffer(), (Nd4jPointer) res2.buffer()}; - Nd4jPointer outShapes[] = {(Nd4jPointer) res0.shapeInfo(), (Nd4jPointer) res1.shapeInfo(), (Nd4jPointer) res2.shapeInfo()}; - - // conditional scope - state->registerScope(22); - - sd::ops::LegacyReduceSameOp op1(reduce::Sum); - sd::ops::lt_scalar op2; - - // while sum(var0) < var1 - // this op takes sum - ArgumentsList args1({{0, 0}}); - - // this op compares result of sum to input variable 0:1 - ArgumentsList args2({{1, 0}, {0, 1}}); - - state->attachOpToScope(22, 1, &op1, args1); - state->attachOpToScope(22, 2, &op2, args2); - - // body scope - state->registerScope(33); - - // var0 + var1 + var1 - // this op is var0 + var1 - ArgumentsList args3({{0, 0}, {0, 2}}); - - // this op is result of previous op + 1 - ArgumentsList args4({{3, 0}, {0, 2}}); - - sd::ops::add op3; - sd::ops::add op4; - - state->attachOpToScope(33, 3, &op3, args3); - state->attachOpToScope(33, 4, &op4, args4); - - // Now we define RETURN, which returns 1 modified variable, and 2 unmodified variables - ArgumentsList args5({{4, 0}, {0, 1}, {0, 2}}); - - // so, at the end of body, initial variables will be updated - state->defineReturn(33, 5, args5); - - Nd4jLong scopes[] = {22, 33}; - - // we're executing while loop - auto status = execCustomOpWithScope(nullptr, state, 0, scopes, 2, ptrBuffers, ptrShapes, 3, outBuffers, outShapes, 3); - ASSERT_EQ(Status::OK(), status); - - // now we check provided result array - float sum = res0.reduceNumber(reduce::Sum).e(0); - - // Expected result is {1, 2, 3, 4} + {2} elementwise + {2} elementwise, which gives { 5, 6, 7, 8}, and sum should be 26 - ASSERT_NEAR(26.0f, sum, 1e-5); - - // nd4j_printf("0 ------------------\n",""); - - deleteGraphState(state); - - // nd4j_printf("1 ------------------\n",""); -} - -// This test checks CONDITIONAL execution for FALSE -TEST_F(GraphStateTests, Stateful_Execution_4) { - auto var0 = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); - auto var1 = NDArrayFactory::create(5.0f); - - auto res0 = NDArrayFactory::create('c', {2, 2}); - auto res1 = NDArrayFactory::create(0.0f); - - auto exp = NDArrayFactory::create('c', {2, 2}, {-4, -3, -2, -1}); - - - // registering our GraphState holder - auto state = (GraphState *) getGraphState(117L); - - // we're prepping pointers to input/output buffers - Nd4jPointer ptrBuffers[] = {(Nd4jPointer) var0.buffer(), (Nd4jPointer) var1.buffer()}; - Nd4jPointer ptrShapes[] = {(Nd4jPointer) var0.shapeInfo(), (Nd4jPointer) var1.shapeInfo()}; - - Nd4jPointer outBuffers[] = {(Nd4jPointer) res0.buffer(), (Nd4jPointer) res1.buffer()}; - Nd4jPointer outShapes[] = {(Nd4jPointer) res0.shapeInfo(), (Nd4jPointer) res1.shapeInfo()}; - - // conditional scope - state->registerScope(22); - - sd::ops::LegacyReduceSameOp op1(reduce::Sum); - sd::ops::lt_scalar op2; - - // if sum(var0) < var1 - // this op takes sum - ArgumentsList args1({{0, 0}}); - - // this op compares result of sum to input variable 0:1 - ArgumentsList args2({{1, 0}, {0, 1}}); - - state->attachOpToScope(22, 1, &op1, args1); - state->attachOpToScope(22, 2, &op2, args2); - - // false scope - state->registerScope(33); - - ArgumentsList args3({{0, 0}, {0, 1}}); - sd::ops::subtract op3; - state->attachOpToScope(33, 3, &op3, args3); - - // return for false scope - ArgumentsList args10({{3, 0}, {0, 1}}); - state->defineReturn(33, 10, args10); - - // true scope - state->registerScope(44); - - ArgumentsList args4({{0, 0}, {0, 1}}); - sd::ops::add op4; - state->attachOpToScope(44, 4, &op4, args4); - - // return for false scope - ArgumentsList args20({{4, 0}, {0, 1}}); - state->defineReturn(44, 20, args20); - - - Nd4jLong scopes[] = {22, 33, 44}; - - // we're executing conditional op - auto status = execCustomOpWithScope(nullptr, state, 20, scopes, 3, ptrBuffers, ptrShapes, 2, outBuffers, outShapes, 2); - ASSERT_EQ(Status::OK(), status); - - ASSERT_TRUE(exp.isSameShape(&res0)); - ASSERT_TRUE(exp.equalsTo(&res0)); - - - deleteGraphState(state); -} - - -// This test checks CONDITIONAL execution for TRUE -TEST_F(GraphStateTests, Stateful_Execution_5) { - auto var0 = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); - auto var1 = NDArrayFactory::create(5.0f); - - auto res0 = NDArrayFactory::create('c', {2, 2}); - auto res1 = NDArrayFactory::create(0.0f); - - auto exp = NDArrayFactory::create('c', {2, 2}, {6, 7, 8, 9}); - - - // registering our GraphState holder - auto state = (GraphState *) getGraphState(117L); - - // we're prepping pointers to input/output buffers - Nd4jPointer ptrBuffers[] = {(Nd4jPointer) var0.buffer(), (Nd4jPointer) var1.buffer()}; - Nd4jPointer ptrShapes[] = {(Nd4jPointer) var0.shapeInfo(), (Nd4jPointer) var1.shapeInfo()}; - - Nd4jPointer outBuffers[] = {(Nd4jPointer) res0.buffer(), (Nd4jPointer) res1.buffer()}; - Nd4jPointer outShapes[] = {(Nd4jPointer) res0.shapeInfo(), (Nd4jPointer) res1.shapeInfo()}; - - // conditional scope - state->registerScope(22); - - sd::ops::LegacyReduceSameOp op1(reduce::Sum); - sd::ops::gt_scalar op2; - - // if sum(var0) < var1 - // this op takes sum - ArgumentsList args1({{0, 0}}); - - // this op compares result of sum to input variable 0:1 - ArgumentsList args2({{1, 0}, {0, 1}}); - - state->attachOpToScope(22, 1, &op1, args1); - state->attachOpToScope(22, 2, &op2, args2); - - // false scope - state->registerScope(33); - - ArgumentsList args3({{0, 0}, {0, 1}}); - sd::ops::subtract op3; - state->attachOpToScope(33, 3, &op3, args3); - - // return for false scope - ArgumentsList args10({{3, 0}, {0, 1}}); - state->defineReturn(33, 10, args10); - - // true scope - state->registerScope(44); - - ArgumentsList args4({{0, 0}, {0, 1}}); - sd::ops::add op4; - state->attachOpToScope(44, 4, &op4, args4); - - // return for false scope - ArgumentsList args20({{4, 0}, {0, 1}}); - state->defineReturn(44, 20, args20); - - - Nd4jLong scopes[] = {22, 33, 44}; - - // we're executing conditional op - auto status = execCustomOpWithScope(nullptr, state, 20, scopes, 3, ptrBuffers, ptrShapes, 2, outBuffers, outShapes, 2); - ASSERT_EQ(Status::OK(), status); - - ASSERT_TRUE(exp.isSameShape(&res0)); - ASSERT_TRUE(exp.equalsTo(&res0)); - - deleteGraphState(state); -} -*/ \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/GraphTests.cpp b/libnd4j/tests_cpu/layers_tests/GraphTests.cpp index 6d21b00f20e7..93365f60a65c 100644 --- a/libnd4j/tests_cpu/layers_tests/GraphTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/GraphTests.cpp @@ -18,1595 +18,43 @@ // @author raver119@gmail.com // -#include "testlayers.h" -#include -#include -#include -#include -#include -#include -#include -#include -#include - -using namespace sd; -using namespace sd::graph; - -class GraphTests : public testing::Test { -public: - /* - int cShape[] = {2, 2, 2, 2, 1, 0, 1, 99}; - int fShape[] = {2, 2, 2, 1, 2, 0, 1, 102}; - */ - GraphTests() { - //Environment::getInstance().setDebug(true); - //Environment::getInstance().setVerbose(true); - } -}; - -TEST_F(GraphTests, SingleInput1) { - auto graph = new Graph(); - - auto x = NDArrayFactory::create_('c', {5, 5}); - x->assign(-2.0f); - - graph->getVariableSpace()->putVariable(-1, x); - - auto nodeA = new Node(OpType_TRANSFORM_SAME, transform::Abs, 1, {-1}, {2}); - auto nodeB = new Node(OpType_TRANSFORM_STRICT, transform::Cosine, 2, {1}, {3}); - auto nodeC = new Node(OpType_TRANSFORM_SAME, transform::Abs, 3, {2}, {}); - - graph->addNode(nodeA); - graph->addNode(nodeB); - graph->addNode(nodeC); - - ASSERT_EQ(1, graph->rootNodes()); - ASSERT_EQ(3, graph->totalNodes()); - - GraphExecutioner::execute(graph); - - ASSERT_TRUE(graph->getVariableSpace()->hasVariable(3)); - - auto node3 = graph->getVariableSpace()->getVariable(3)->getNDArray(); - - ASSERT_NEAR(0.4161468, node3->reduceNumber(reduce::Mean).e(0), 1e-5); - - delete graph; -} - -TEST_F(GraphTests, DoubleInput1) { - auto graph = new Graph(); - - auto x = NDArrayFactory::create_('c', {5, 5}); - x->assign(-2.0); - - auto y = NDArrayFactory::create_('c', {5, 5}); - y->assign(-1.0); - - auto z = NDArrayFactory::create_('c', {5, 5}); - - graph->getVariableSpace()->putVariable(-1, x); - graph->getVariableSpace()->putVariable(-2, y); - graph->getVariableSpace()->putVariable(-3, z); - - auto nodeA = new Node(OpType_TRANSFORM_SAME, transform::Abs, 1, {-1}, {3}); - auto nodeB = new Node(OpType_TRANSFORM_SAME, transform::Abs, 2, {-2}, {3}); - auto nodeC = new Node(OpType_PAIRWISE, pairwise::Add, 3, {1, 2}, {-3}); - - graph->addNode(nodeA); - graph->addNode(nodeB); - graph->addNode(nodeC); - - ASSERT_EQ(2, graph->rootNodes()); - ASSERT_EQ(3, graph->totalNodes()); - - GraphExecutioner::execute(graph); - - ASSERT_NEAR(3.0, z->reduceNumber(reduce::Mean).e(0), 1e-5); - - delete graph; -} - -TEST_F(GraphTests, SingleInput3) { - auto graph = new Graph(); - - auto x = NDArrayFactory::create_('c', {5, 5}); - x->assign(-2.0); - - auto v0 = NDArrayFactory::create_('c', {5, 5}); - auto v1 = NDArrayFactory::create_('c', {5, 5}); - - graph->getVariableSpace()->putVariable(-1, x); - graph->getVariableSpace()->putVariable(-2, v0); - graph->getVariableSpace()->putVariable(-3, v1); - - auto nodeA = new Node(OpType_TRANSFORM_SAME, transform::Abs, 1, {-1}, {2, 3}); - auto nodeB = new Node(OpType_TRANSFORM_FLOAT, transform::Sqrt, 2, {1}, {-2}); - auto nodeC = new Node(OpType_TRANSFORM_SAME, transform::Ones, 3, {1}, {-3}); - - graph->addNode(nodeA); - graph->addNode(nodeB); - graph->addNode(nodeC); - - ASSERT_EQ(1, graph->rootNodes()); - ASSERT_EQ(3, graph->totalNodes()); - - GraphExecutioner::execute(graph); - - ASSERT_NEAR(1.4142135, v0->reduceNumber(reduce::Mean).e(0), 1e-5); - ASSERT_NEAR(1.0, v1->reduceNumber(reduce::Mean).e(0), 1e-5); - - delete graph; -} - -TEST_F(GraphTests, SingleInput4) { - auto graph = new Graph(); - - auto x = NDArrayFactory::create_('c', {5, 5}); - x->assign(-2.0); - - auto v0 = NDArrayFactory::create_('c', {5, 5}); - auto v1 = NDArrayFactory::create_('c', {5, 5}); - - graph->getVariableSpace()->putVariable(-1, x); - graph->getVariableSpace()->putVariable(-2, v0); - graph->getVariableSpace()->putVariable(-3, v1); - - auto nodeA = new Node(OpType_TRANSFORM_SAME, transform::Abs, 1, {-1}, {2}); - auto nodeB = new Node(OpType_TRANSFORM_FLOAT, transform::Sqrt, 2, {1}, {3}); - auto nodeC = new Node(OpType_TRANSFORM_SAME, transform::Neg, 3, {2}, {4, 5}); - - auto nodeS = new Node(OpType_TRANSFORM_SAME, transform::Ones, 4, {3}, {-2}); - auto nodeE = new Node(OpType_TRANSFORM_SAME, transform::Identity, 5, {3}, {-3}); - - graph->addNode(nodeA); - graph->addNode(nodeB); - graph->addNode(nodeC); - graph->addNode(nodeS); - graph->addNode(nodeE); - - ASSERT_EQ(1, graph->rootNodes()); - ASSERT_EQ(5, graph->totalNodes()); - - GraphExecutioner::execute(graph); - - ASSERT_NEAR(1.0, v0->reduceNumber(reduce::Mean).e(0), 1e-5); - ASSERT_NEAR(-1.4142135, v1->reduceNumber(reduce::Mean).e(0), 1e-5); - - delete graph; -} - - -TEST_F(GraphTests, DoubleInput2) { - auto graph = new Graph(); - - auto x = NDArrayFactory::create_('c', {5, 5}); - x->assign(-2.0); - - auto y = NDArrayFactory::create_('c', {5, 5}); - y->assign(-1.0); - - auto z0 = NDArrayFactory::create_('c', {5, 5}); - auto z1 = NDArrayFactory::create_('c', {5, 5}); - - graph->getVariableSpace()->putVariable(-1, x); - graph->getVariableSpace()->putVariable(-2, y); - graph->getVariableSpace()->putVariable(-3, z0); - graph->getVariableSpace()->putVariable(-4, z1); - - - auto nodeA = new Node(OpType_TRANSFORM_SAME, transform::Abs, 1, {-1}, {2}); - auto nodeB = new Node(OpType_TRANSFORM_FLOAT, transform::Sqrt, 2, {1}, {3}); - auto nodeC = new Node(OpType_TRANSFORM_SAME, transform::Neg, 3, {2}, {-3}); - - auto nodeT = new Node(OpType_TRANSFORM_SAME, transform::Abs, 11, {-2}, {12}); - auto nodeU = new Node(OpType_TRANSFORM_FLOAT, transform::Sqrt, 12, {11}, {13}); - auto nodeV = new Node(OpType_TRANSFORM_SAME, transform::Neg, 13, {12}, {-4}); - - graph->addNode(nodeA); - graph->addNode(nodeB); - graph->addNode(nodeC); - graph->addNode(nodeT); - graph->addNode(nodeU); - graph->addNode(nodeV); - - ASSERT_EQ(2, graph->rootNodes()); - ASSERT_EQ(6, graph->totalNodes()); - - GraphExecutioner::execute(graph); - - ASSERT_NEAR(-1.4142135, z0->reduceNumber(reduce::Mean).e(0), 1e-5); - ASSERT_NEAR(-1.0, z1->reduceNumber(reduce::Mean).e(0), 1e-5); - - delete graph; -} - - -TEST_F(GraphTests, DoubleInput3) { - auto graph = new Graph(); - - auto x = NDArrayFactory::create_('c', {5, 5}); - x->assign(-2.0); - - auto y = NDArrayFactory::create_('c', {5, 5}); - y->assign(-1.0); - - auto z0 = NDArrayFactory::create_('c', {5, 5}); - auto z1 = NDArrayFactory::create_('c', {5, 5}); - - - auto w = NDArrayFactory::create_('c', {5, 5}); - - graph->getVariableSpace()->putVariable(-1, x); - graph->getVariableSpace()->putVariable(-2, y); - graph->getVariableSpace()->putVariable(-3, z0); - graph->getVariableSpace()->putVariable(-4, z1); - graph->getVariableSpace()->putVariable(-5, w); - - - auto nodeA = new Node(OpType_TRANSFORM_SAME, transform::Abs, 1, {-1}, {2}); - auto nodeB = new Node(OpType_TRANSFORM_FLOAT, transform::Sqrt, 2, {1}, {3}); - auto nodeC = new Node(OpType_TRANSFORM_SAME, transform::Neg, 3, {2}, {-3, 21}); - - auto nodeT = new Node(OpType_TRANSFORM_SAME, transform::Abs, 11, {-2}, {12}); - auto nodeU = new Node(OpType_TRANSFORM_FLOAT, transform::Sqrt, 12, {11}, {13}); - auto nodeV = new Node(OpType_TRANSFORM_SAME, transform::Neg, 13, {12}, {-4, 21}); - - auto nodeW = new Node(OpType_PAIRWISE, pairwise::Add, 21, {3, 13}, {22}); - auto nodeZ = new Node(OpType_TRANSFORM_SAME, transform::Abs, 22, {21}, {-5}); - - graph->addNode(nodeA); - graph->addNode(nodeB); - graph->addNode(nodeC); - graph->addNode(nodeT); - graph->addNode(nodeU); - graph->addNode(nodeV); - graph->addNode(nodeW); - graph->addNode(nodeZ); - - ASSERT_EQ(2, graph->rootNodes()); - ASSERT_EQ(8, graph->totalNodes()); - - GraphExecutioner::execute(graph); - - ASSERT_NEAR(-1.4142135, z0->reduceNumber(reduce::Mean).e(0), 1e-5); - ASSERT_NEAR(-1.0, z1->reduceNumber(reduce::Mean).e(0), 1e-5); - - ASSERT_NEAR(2.4142135, w->reduceNumber(reduce::Mean).e(0), 1e-5); - - delete graph; -} - - -TEST_F(GraphTests, QuadInput1) { - auto graph = new Graph(); - - auto x0 = NDArrayFactory::create_('c', {5, 5}); - x0->assign(0.0); - - auto x1 = NDArrayFactory::create_('c', {5, 5}); - x1->assign(-1.0); - - auto x2 = NDArrayFactory::create_('c', {5, 5}); - x2->assign(-2.0); - - auto x3 = NDArrayFactory::create_('c', {5, 5}); - x3->assign(-3.0); - - auto z = NDArrayFactory::create_('c', {5, 5}); - z->assign(119.0); - - graph->getVariableSpace()->putVariable(-1, x0); - graph->getVariableSpace()->putVariable(-2, x1); - graph->getVariableSpace()->putVariable(-3, x2); - graph->getVariableSpace()->putVariable(-4, x3); - graph->getVariableSpace()->putVariable(-5, z); - - auto nodeA = new Node(OpType_TRANSFORM_SAME, transform::Abs, 1, {-1}, {11}); - auto nodeB = new Node(OpType_TRANSFORM_SAME, transform::Abs, 2, {-2}, {11}); - auto nodeC = new Node(OpType_TRANSFORM_SAME, transform::Abs, 3, {-3}, {21}); - auto nodeD = new Node(OpType_TRANSFORM_SAME, transform::Abs, 4, {-4}, {21}); - - auto nodeP1 = new Node(OpType_PAIRWISE, pairwise::Add, 11, {1, 2}, {31}); - auto nodeP2 = new Node(OpType_PAIRWISE, pairwise::Add, 21, {3, 4}, {31}); - - auto nodeZ = new Node(OpType_PAIRWISE, pairwise::Add, 31, {11, 21}, {-5}); - - graph->addNode(nodeA); - graph->addNode(nodeB); - graph->addNode(nodeC); - graph->addNode(nodeD); - graph->addNode(nodeP1); - graph->addNode(nodeP2); - graph->addNode(nodeZ); - - ASSERT_EQ(4, graph->rootNodes()); - ASSERT_EQ(7, graph->totalNodes()); - - GraphExecutioner::execute(graph); - - ASSERT_NEAR(6.0, z->reduceNumber(reduce::Mean).e(0), 1e-5); - - delete graph; -} - -TEST_F(GraphTests, InternalBranching1) { - auto graph = new Graph(); - - auto x = NDArrayFactory::create_('c', {5, 5}); - x->assign(0.0); - - auto z = NDArrayFactory::create_('c', {5, 5}); - - graph->getVariableSpace()->putVariable(-1, x); - graph->getVariableSpace()->putVariable(-2, z); - - // 1.0 - auto nodeA = new Node(OpType_TRANSFORM_SAME, transform::Ones, 1, {-1}, {11, 21}); - - // -1 - auto nodeK = new Node(OpType_TRANSFORM_SAME, transform::Neg, 11, {1}, {12}); - - // 2.0 - auto nodeL = new Node(OpType_TRANSFORM_SAME, transform::OneMinus, 12, {11}, {31}); - - // -1 - auto nodeR = new Node(OpType_TRANSFORM_SAME, transform::Neg, 21, {1}, {22}); - - // 1 - auto nodeS = new Node(OpType_TRANSFORM_SAME, transform::Neg, 22, {21}, {31}); - - // 1.0 - auto nodeZ = new Node(OpType_PAIRWISE, pairwise::Add, 31, {12, 22}, {-2}); - - graph->addNode(nodeA); - graph->addNode(nodeK); - graph->addNode(nodeL); - graph->addNode(nodeR); - graph->addNode(nodeS); - graph->addNode(nodeZ); - - ASSERT_EQ(1, graph->rootNodes()); - ASSERT_EQ(6, graph->totalNodes()); - - GraphExecutioner::execute(graph); - - ASSERT_EQ(3, nodeZ->getLayer()); - - ASSERT_NEAR(3.0, z->reduceNumber(reduce::Mean).e(0), 1e-5); - - delete graph; -} - - -TEST_F(GraphTests, ReductionsTest1) { - auto graph = new Graph(); - - auto x = NDArrayFactory::create_('c', {5, 5}); - for (int r = 0; r < x->rows(); r++) { - for (int c = 0; c < x->columns(); c++) { - x->p(r, c, -c); - } - } - - auto z = NDArrayFactory::create_('c', {5}); - - graph->getVariableSpace()->putVariable(-1, x); - graph->getVariableSpace()->putVariable(-2, z); - -// sd::graph::Node::Node(OpType opType, int opNum, int id, std::initializer_list input, std::initializer_list output, std::initializer_list dimensions, float scalar, std::initializer_list tArgs, std::initializer_list iArgs) { - - auto nodeA = new Node(OpType_REDUCE_FLOAT, reduce::Mean, 1, {-1}, {2}, {1}, {}); - auto nodeB = new Node(OpType_TRANSFORM_SAME, transform::Abs, 2, {1}, {-2}); - - graph->addNode(nodeA); - graph->addNode(nodeB); - - ASSERT_EQ(1, graph->rootNodes()); - ASSERT_EQ(2, graph->totalNodes()); - - GraphExecutioner::execute(graph); - - ASSERT_NEAR(2.0, z->reduceNumber(reduce::Mean).e(0), 1e-5); - - delete graph; -} - - -TEST_F(GraphTests, IndexReductionsTest1) { - auto graph = new Graph(); - - auto x = NDArrayFactory::create_('c', {5, 5}); - for (int r = 0; r < x->rows(); r++) { - for (int c = 0; c < x->columns(); c++) { - x->p(r, c, -c); - } - } - - auto z = NDArrayFactory::create_('c', {5, 1}); - auto axis = NDArrayFactory::create_('c', {1}, {1}); - graph->getVariableSpace()->putVariable(-1, x); - graph->getVariableSpace()->putVariable(-2, z); - //graph->getVariableSpace()->putVariable(-3, axis); - - - auto nodeA = new Node(OpType_INDEX_REDUCE, indexreduce::IndexMin, 1, {-1}, {2}, {1}); - auto nodeB = new Node(OpType_TRANSFORM_SAME, transform::Abs, 2, {1}, {-2}); - - graph->addNode(nodeA); - graph->addNode(nodeB); - - ASSERT_EQ(1, graph->rootNodes()); - ASSERT_EQ(2, graph->totalNodes()); - - GraphExecutioner::execute(graph); - - ASSERT_NEAR(4.0, z->reduceNumber(reduce::Mean).e(0), 1e-5); - - delete graph; - delete axis; -} - -#if 0 -TEST_F(GraphTests, AutoOutput1) { - auto graph = new Graph(); - auto x = NDArrayFactory::create_('c', {5, 5}); - x->assign(-2.0); - - graph->getVariableSpace()->putVariable(-1, x); - - auto nodeA = new Node(OpType_TRANSFORM_FLOAT, 0, 1, {-1}, {2}); - auto nodeB = new Node(OpType_TRANSFORM_FLOAT, 35, 2, {1}, {}); - - graph->addNode(nodeA); - graph->addNode(nodeB); - - ASSERT_EQ(1, graph->rootNodes()); - ASSERT_EQ(2, graph->totalNodes()); - - graph->buildGraph(); - - ASSERT_TRUE(graph->getVariableSpace()->getVariable(2) != nullptr); - - GraphExecutioner::execute(graph); - - auto outputs = graph->fetchOutputs(); - - ASSERT_EQ(1, outputs->size()); - - ASSERT_TRUE(outputs->at(0) != nullptr); - - ASSERT_NEAR(-1.0, outputs->at(0)->getNDArray()->reduceNumber(reduce::Mean).e(0), 1e-5); - - delete outputs; - delete graph; -} - - -TEST_F(GraphTests, AutoOutput2) { - auto graph = new Graph(); - auto x = NDArrayFactory::create_('c', {5, 5}); - x->assign(-2.0); - - graph->getVariableSpace()->putVariable(-1, x); - - auto nodeA = new Node(OpType_TRANSFORM_SAME, 0, 1, {-1}, {2, 3, -1}); - auto nodeB = new Node(OpType_TRANSFORM_SAME, 35, 2, {1}, {}); - auto nodeC = new Node(OpType_TRANSFORM_SAME, 6, 3, {1}, {}); - - graph->addNode(nodeA); - graph->addNode(nodeB); - graph->addNode(nodeC); - - ASSERT_EQ(1, graph->rootNodes()); - ASSERT_EQ(3, graph->totalNodes()); - - graph->buildGraph(); - - ASSERT_TRUE(graph->getVariableSpace()->getVariable(-1) != nullptr); - ASSERT_TRUE(graph->getVariableSpace()->getVariable(2) != nullptr); - ASSERT_TRUE(graph->getVariableSpace()->getVariable(3) != nullptr); - - GraphExecutioner::execute(graph); - - auto outputs = graph->fetchOutputs(); - - ASSERT_EQ(2, outputs->size()); - - ASSERT_TRUE(outputs->at(0) != nullptr); - - ASSERT_NEAR(-1.0, outputs->at(0)->getNDArray()->reduceNumber(reduce::Mean).e(0), 1e-5); - ASSERT_NEAR(-2.0, outputs->at(1)->getNDArray()->reduceNumber(reduce::Mean).e(0), 1e-5); - - delete graph; - delete outputs; -} -#endif - -TEST_F(GraphTests, BroadcastTest1) { - auto graph = new Graph(); - auto x = NDArrayFactory::create_('c', {5, 5}); - x->assign(0.f); - - auto y = NDArrayFactory::create_('c', {1, 5}); - for (int e = 0; e < y->columns(); e++) { - y->p(e, (float)e+1); - } - - auto z = NDArrayFactory::create_('c', {5, 5}); - - graph->getVariableSpace()->putVariable(-1, x); - graph->getVariableSpace()->putVariable(-2, y); - graph->getVariableSpace()->putVariable(-3, z); - - auto nodeA = new Node(OpType_BROADCAST, broadcast::Subtract, 1, {-1, -2}, {2}, {1}); - auto nodeB = new Node(OpType_TRANSFORM_SAME, transform::Neg, 2, {1}, {-3}); - - graph->addNode(nodeA); - graph->addNode(nodeB); - - GraphExecutioner::execute(graph); - - ASSERT_NEAR(3.0, z->reduceNumber(reduce::Mean).e(0), 1e-5); - - delete graph; -} - - -TEST_F(GraphTests, ScalarTest1) { - auto graph = new Graph(); - - auto x = NDArrayFactory::create_('c', {5, 5}); - x->assign(-2.0); - - auto z = NDArrayFactory::create_('c', {5, 5}); - - graph->getVariableSpace()->putVariable(-1, x); - graph->getVariableSpace()->putVariable(-2, z); - - auto nodeA = new Node(OpType_TRANSFORM_SAME, transform::Abs, 1, {-1}, {2}); - auto nodeB = new Node(OpType_TRANSFORM_FLOAT, transform::Sqrt, 2, {1}, {3}); - auto nodeE = new Node(OpType_SCALAR, scalar::Add, 3, {2}, {-2}, {}, 1.3f); - - graph->addNode(nodeA); - graph->addNode(nodeB); - graph->addNode(nodeE); - - ASSERT_EQ(1, graph->rootNodes()); - ASSERT_EQ(3, graph->totalNodes()); - - GraphExecutioner::execute(graph); - - ASSERT_NEAR(2.714213, z->reduceNumber(reduce::Mean).e(0), 1e-5); - - delete graph; -} - -TEST_F(GraphTests, SymbolicLookupTest1) { - auto graph = new Graph(); - - auto x = NDArrayFactory::create_('c', {5, 5}); - x->assign(-2.0); - - auto z = NDArrayFactory::create_('c', {5, 5}); - - auto vX = new Variable(x); - auto vZ = new Variable(z); - - std::string a("alpha"); - std::string o("omega"); - - vX->setName(&a); - vZ->setName(&o); - - graph->getVariableSpace()->putVariable(-1, vX); - graph->getVariableSpace()->putVariable(-2, vZ); - - auto nodeA = new Node(OpType_TRANSFORM_SAME, transform::Abs, 1, {-1}, {2}); - auto nodeB = new Node(OpType_TRANSFORM_FLOAT, transform::Sqrt, 2, {1}, {-2}); - - std::string p("phi"); - std::string t("theta"); - - nodeA->setName(&p); - nodeB->setName(&t); - - graph->addNode(nodeA); - graph->addNode(nodeB); - - - auto rX = graph->getVariableSpace()->getVariable(&a); - auto rZ = graph->getVariableSpace()->getVariable(&o); - - std::string om("omicron"); - - ASSERT_TRUE(rX->getNDArray() == vX->getNDArray()); - ASSERT_TRUE(rZ->getNDArray() == vZ->getNDArray()); - ASSERT_FALSE(graph->getVariableSpace()->hasVariable(&om)); - - - ASSERT_TRUE(graph->getVariableSpace()->hasVariable(1)); - ASSERT_TRUE(graph->getVariableSpace()->hasVariable(2)); - - GraphExecutioner::execute(graph); - - ASSERT_TRUE(graph->getVariableSpace()->hasVariable(&p)); - ASSERT_TRUE(graph->getVariableSpace()->hasVariable(&t)); - - ASSERT_NEAR(1.4142135, z->reduceNumber(reduce::Mean).e(0), 1e-5); - - delete graph; -} - -TEST_F(GraphTests, OutputValidation1) { - auto graph = new Graph(); - - graph->getExecutorConfiguration()->_outputMode = OutputMode_EXPLICIT; - - auto x = NDArrayFactory::create_('c', {5, 5}); - x->assign(-2.0); - - auto z = NDArrayFactory::create_('c', {5, 5}); - - auto vX = new Variable(x); - auto vZ = new Variable(z); - - std::string a("alpha"); - std::string o("omega"); - - vX->setName(&a); - vZ->setName(&o); - - graph->getVariableSpace()->putVariable(-1, vX); - graph->getVariableSpace()->putVariable(-2, vZ); - - auto nodeA = new Node(OpType_TRANSFORM_SAME, transform::Abs, 1, {-1}, {2}); - auto nodeB = new Node(OpType_TRANSFORM_FLOAT, transform::Sqrt, 2, {1}, {-2}); - - graph->addNode(nodeA); - graph->addNode(nodeB); - - - auto outputs = graph->fetchOutputs(); - - ASSERT_EQ(0, outputs->size()); - - delete graph; - delete outputs; -} - -TEST_F(GraphTests, OutputValidation2) { - auto graph = new Graph(); - - graph->getExecutorConfiguration()->_outputMode = OutputMode_EXPLICIT; - - auto x = NDArrayFactory::create_('c', {5, 5}); - x->assign(-2.0); - - auto z = NDArrayFactory::create_('c', {5, 5}); - - auto vX = new Variable(x); - auto vZ = new Variable(z); - - std::string a("alpha"); - std::string o("omega"); - - vX->setName(&a); - vZ->setName(&o); - - graph->getVariableSpace()->putVariable(-1, vX); - graph->getVariableSpace()->putVariable(-2, vZ); - - auto nodeA = new Node(OpType_TRANSFORM_SAME, transform::Abs, 1, {-1}, {2}); - auto nodeB = new Node(OpType_TRANSFORM_FLOAT, transform::Sqrt, 2, {1}, {-2}); - - graph->addNode(nodeA); - graph->addNode(nodeB); - - graph->addOutput(-2); - - GraphExecutioner::execute(graph); - - auto outputs = graph->fetchOutputs(); - - ASSERT_EQ(1, outputs->size()); - - ASSERT_NEAR(1.4142135, outputs->at(0)->getNDArray()->reduceNumber(reduce::Mean).e(0), 1e-5); - - delete graph; - delete outputs; -} - -TEST_F(GraphTests, OutputValidation3) { - auto graph = new Graph(); - - graph->getExecutorConfiguration()->_outputMode = OutputMode_IMPLICIT; - - auto x = NDArrayFactory::create_('c', {5, 5}); - x->assign(-2.0); - - auto z = NDArrayFactory::create_('c', {5, 5}); - - auto vX = new Variable(x); - auto vZ = new Variable(z); - - std::string a("alpha"); - std::string o("omega"); - - vX->setName(&a); - vZ->setName(&o); - - graph->getVariableSpace()->putVariable(-1, vX); - graph->getVariableSpace()->putVariable(-2, vZ); - - auto nodeA = new Node(OpType_TRANSFORM_SAME, transform::Abs, 1, {-1}, {2}); - auto nodeB = new Node(OpType_TRANSFORM_FLOAT, transform::Sqrt, 2, {1}, {}); - - graph->addNode(nodeA); - graph->addNode(nodeB); - - GraphExecutioner::execute(graph); - - auto outputs = graph->fetchOutputs(); - - ASSERT_EQ(1, outputs->size()); - - ASSERT_NEAR(1.4142135, outputs->at(0)->getNDArray()->reduceNumber(reduce::Mean).e(0), 1e-5); - - delete graph; - delete outputs; -} - -TEST_F(GraphTests, OutputValidation4) { - auto graph = new Graph(); - - graph->getExecutorConfiguration()->_outputMode = OutputMode_EXPLICIT_AND_IMPLICIT; - - auto x = NDArrayFactory::create_('c', {5, 5}); - x->assign(-2.0); - - auto z = NDArrayFactory::create_('c', {5, 5}); - - auto vX = new Variable(x); - auto vZ = new Variable(z); - - std::string a("alpha"); - std::string o("omega"); - - vX->setName(&a); - vZ->setName(&o); - - graph->getVariableSpace()->putVariable(-1, vX); - graph->getVariableSpace()->putVariable(-2, vZ); - - auto nodeA = new Node(OpType_TRANSFORM_SAME, transform::Abs, 1, {-1}, {2}); - auto nodeB = new Node(OpType_TRANSFORM_FLOAT, transform::Sqrt, 2, {1}, {-2}); - - graph->addOutput(-1); - - // not a typo. we want this value only once - graph->addOutput(-1); - - graph->addNode(nodeA); - graph->addNode(nodeB); - - GraphExecutioner::execute(graph); - - auto outputs = graph->fetchOutputs(); - - ASSERT_EQ(2, outputs->size()); - - ASSERT_NEAR(1.4142135, outputs->at(1)->getNDArray()->reduceNumber(reduce::Mean).e(0), 1e-5); - - delete graph; - delete outputs; -} - - -TEST_F(GraphTests, OutputValidation5) { - auto graph = new Graph(); - - graph->getExecutorConfiguration()->_outputMode = OutputMode_VARIABLE_SPACE; - - auto x = NDArrayFactory::create_('c', {5, 5}); - x->assign(-2.0); - - auto z = NDArrayFactory::create_('c', {5, 5}); - - auto vX = new Variable(x); - auto vZ = new Variable(z); - - std::string a("alpha"); - std::string o("omega"); - - vX->setName(&a); - vZ->setName(&o); - - graph->getVariableSpace()->putVariable(-1, vX); - graph->getVariableSpace()->putVariable(-2, vZ); - - auto nodeA = new Node(OpType_TRANSFORM_SAME, transform::Abs, 1, {-1}, {2}); - auto nodeB = new Node(OpType_TRANSFORM_SAME, transform::Sqrt, 2, {1}, {-2}); - - graph->addOutput(-1); - - graph->addNode(nodeA); - graph->addNode(nodeB); - - GraphExecutioner::execute(graph); - - auto outputs = graph->fetchOutputs(); - - ASSERT_EQ(4, outputs->size()); - - delete graph; - delete outputs; -} - -TEST_F(GraphTests, OutputValidation6) { - auto graph = new Graph(); - - graph->getExecutorConfiguration()->_outputMode = OutputMode_VARIABLE_SPACE; - - auto x = NDArrayFactory::create_('c', {5, 5}); - x->assign(-2.0); - - auto z = NDArrayFactory::create_('c', {5, 5}); - - auto vX = new Variable(x); - auto vZ = new Variable(z); - - std::string a("alpha"); - std::string o("omega"); - - vX->setName(&a); - vZ->setName(&o); - - graph->getVariableSpace()->putVariable(-1, vX); - graph->getVariableSpace()->putVariable(-2, vZ); - - auto nodeA = new Node(OpType_TRANSFORM_SAME, transform::Abs, 1, {-1}, {2}); - auto nodeB = new Node(OpType_TRANSFORM_FLOAT, transform::Sqrt, 2, {1}, {}); - - //graph->addOutput(-1); - - graph->addNode(nodeA); - graph->addNode(nodeB); - - GraphExecutioner::execute(graph); - - auto outputs = graph->fetchOutputs(); - -// nd4j_printf("Returned variables: \n", ""); -// for (int e = 0; e < outputs->size(); e++) { -// printf("%i, ", outputs->at(e)->id()); -// } -// printf("\n"); - - ASSERT_EQ(4, outputs->size()); - - //ASSERT_NEAR(1.4142135, graph->fetchOutputs()->at(1)->getNDArray()->reduceNumber(reduce::Mean).e(0), 1e-5); - delete graph; - delete outputs; -} - -TEST_F(GraphTests, TestMultiOutput1) { - sd::ops::testop2i2o op1; - auto graph = new Graph(); - - auto x = NDArrayFactory::create_('c', {5, 5}); - x->assign(-2.0); - - auto y = NDArrayFactory::create_('c', {5, 5}); - y->assign(-3.0); - - graph->getVariableSpace()->putVariable(-1, x); - graph->getVariableSpace()->putVariable(-2, y); - - - // Abs - auto nodeA0 = new Node(OpType_TRANSFORM_SAME, transform::Abs, 1, {-1}, {11}); - nodeA0->markInplace(false); - auto nodeB0 = new Node(OpType_TRANSFORM_SAME, transform::Abs, 2, {-2}, {11}); - nodeB0->markInplace(false); - - auto op = sd::ops::OpRegistrator::getInstance().getOperation("testop2i2o"); - - // this op will add 1.0 to first input, and 2.0 for second input - auto nodeT = new Node(op, 11, {1, 2}, {21, 31}, {}, 0.0f); - nodeT->setName("TestOp2i2o"); - nodeT->markInplace(false); - - - // this op will subtract this value from 1.0 - auto nodeX = new Node(OpType_TRANSFORM_SAME, transform::OneMinus, 21); - nodeX->markInplace(false); - nodeX->pickInput(11, 0); - - // this op will subtract this value from 1.0 - auto nodeY = new Node(OpType_TRANSFORM_SAME, transform::OneMinus, 31); - nodeY->markInplace(false); - nodeY->pickInput(11, 1); - - graph->addNode(nodeA0); - graph->addNode(nodeB0); - graph->addNode(nodeT); - graph->addNode(nodeX); - graph->addNode(nodeY); - - std::pair pair0(11,0); - std::pair pair1(11,1); - - ASSERT_TRUE(graph->getVariableSpace()->hasVariable(pair0)); - ASSERT_TRUE(graph->getVariableSpace()->hasVariable(pair1)); - - Nd4jStatus status = GraphExecutioner::execute(graph); - - ASSERT_EQ(ND4J_STATUS_OK, status); - - ASSERT_NEAR(-2.0f, graph->getVariableSpace()->getVariable(21)->getNDArray()->meanNumber().e(0), 1e-5); - ASSERT_NEAR(-4.0f, graph->getVariableSpace()->getVariable(31)->getNDArray()->meanNumber().e(0), 1e-5); - - delete graph; -} - -TEST_F(GraphTests, TestDivergentNode1) { - auto op = sd::ops::OpRegistrator::getInstance().getOperation("Switch"); - auto nodeY = new Node(op, 1); - - ASSERT_TRUE(nodeY->isDivergencePoint()); - ASSERT_TRUE(nodeY->isActive()); - - delete nodeY; -} - - -TEST_F(GraphTests, MemoryEstimationTest1) { - Graph graph; - - auto x = NDArrayFactory::create_('c', {5, 5}); - x->assign(-2.0); - - graph.getVariableSpace()->putVariable(-1, x); - - auto nodeA0 = new Node(OpType_TRANSFORM_SAME, transform::Abs, 1, {-1}, {2}); - auto nodeA1 = new Node(OpType_TRANSFORM_SAME, transform::Abs, 2, {1}, {}); - nodeA1->markInplace(false); - - graph.addNode(nodeA0); - graph.addNode(nodeA1); - - ASSERT_EQ(2, graph.totalNodes()); - ASSERT_EQ(1, graph.rootNodes()); - - auto memReq = graph.estimateRequiredMemory(); - - ASSERT_EQ(25 * x->sizeOfT(), memReq); -} - -TEST_F(GraphTests, MemoryEstimationTest2) { - Graph graph; - - auto x = NDArrayFactory::create_('c', {5, 5}); - x->assign(-2.0); - - graph.getVariableSpace()->putVariable(-1, x); - - auto nodeA0 = new Node(OpType_TRANSFORM_SAME, transform::Abs, 1, {-1}, {2}); - auto nodeA1 = new Node(OpType_TRANSFORM_SAME, transform::Abs, 2, {1}, {}); - //nodeA1->markInplace(false); - - graph.addNode(nodeA0); - graph.addNode(nodeA1); - - ASSERT_EQ(2, graph.totalNodes()); - ASSERT_EQ(1, graph.rootNodes()); - - auto memReq = graph.estimateRequiredMemory(); - - ASSERT_EQ(0, memReq); -} - -TEST_F(GraphTests, MemoryEstimationTest3) { - Graph graph; - - auto x = NDArrayFactory::create_('c', {5, 5}); - x->assign(-2.0); - - graph.getVariableSpace()->putVariable(-1, x); - - auto nodeA0 = new Node(OpType_TRANSFORM_SAME, transform::Abs, 1, {-1}, {2}); - auto nodeA1 = new Node(OpType_TRANSFORM_SAME, transform::Abs, 2, {1}, {3}); - auto nodeA2 = new Node(OpType_REDUCE_FLOAT, reduce::Mean, 3, {2}, {}, {}); - nodeA1->markInplace(false); - - graph.addNode(nodeA0); - graph.addNode(nodeA1); - graph.addNode(nodeA2); - - ASSERT_EQ(3, graph.totalNodes()); - ASSERT_EQ(1, graph.rootNodes()); - - auto memReq = graph.estimateRequiredMemory(); - - ASSERT_EQ(26 * x->sizeOfT(), memReq); -} - -TEST_F(GraphTests, MemoryEstimationTest4) { - Graph graph; - - auto x = NDArrayFactory::create_('c', {5, 5}); - x->assign(-2.0); - - graph.getVariableSpace()->putVariable(-1, x); - - auto nodeA0 = new Node(OpType_TRANSFORM_SAME, transform::Abs, 1, {-1}, {2}); - auto nodeA1 = new Node(OpType_TRANSFORM_SAME, transform::Abs, 2, {1}, {3}); - auto nodeA2 = new Node(OpType_REDUCE_FLOAT, reduce::Mean, 3, {2}, {}, {1}); - nodeA1->markInplace(false); - - graph.addNode(nodeA0); - graph.addNode(nodeA1); - graph.addNode(nodeA2); - - ASSERT_EQ(3, graph.totalNodes()); - ASSERT_EQ(1, graph.rootNodes()); - - auto memReq = graph.estimateRequiredMemory(); - - ASSERT_EQ(30 * x->sizeOfT(), memReq); -} - -TEST_F(GraphTests, MemoryEstimationTest5) { - Graph graph; - - auto x = NDArrayFactory::create_('c', {5, 5}); - x->assign(-2.0); - - graph.getVariableSpace()->putVariable(-1, x); - - sd::ops::testcustom op; - - auto nodeA0 = new Node(OpType_TRANSFORM_SAME, transform::Abs, 1, {-1}, {2}); - auto nodeA1 = new Node(OpType_TRANSFORM_SAME, transform::Abs, 2, {1}, {3}); - auto nodeA2 = new Node(&op, 3, {2}, {}, {}); - nodeA1->markInplace(false); - - - graph.addNode(nodeA0); - graph.addNode(nodeA1); - graph.addNode(nodeA2); - - graph.buildGraph(); - - ASSERT_EQ(3, graph.totalNodes()); - ASSERT_EQ(1, graph.rootNodes()); - - auto memReq = graph.estimateRequiredMemory(); - - ASSERT_EQ((25 + 100) * x->sizeOfT(), memReq); -} - -TEST_F(GraphTests, TestGraphInGraph_1) { - // this one is external graph - Graph graphA; - - // and this ons is embedded - Graph graphB; - - auto x = NDArrayFactory::create_('c', {5, 5}); - x->assign(-5.0); - - auto modifier = NDArrayFactory::create_('c', {5, 5}); - modifier->assign(3.0); - - graphA.getVariableSpace()->putVariable(-1, x); - graphB.getVariableSpace()->putVariable(-2, modifier); - - // this is placeholder variable - graphB.getVariableSpace()->putVariable(-1, new Variable(true)); - - // abs, result is 5 - auto nodeA0 = new Node(OpType_TRANSFORM_SAME, transform::Abs, 1, {-1}, {2}); - // 1-, result -4 - auto nodeA1 = new Node(OpType_TRANSFORM_SAME, transform::OneMinus, 2, {1}, {3}); - - // graph should return 12: abs(3.0 x -4) - auto nodeA2 = new Node(OpType_GRAPH, -1, 3, {2}, {4}); - - // 1 - 12 = -11 - auto nodeA3 = new Node(OpType_TRANSFORM_SAME, transform::OneMinus, 4, {3}, {}); - - nodeA2->setGraph(&graphB); - - graphA.addNode(nodeA0); - graphA.addNode(nodeA1); - graphA.addNode(nodeA2); - graphA.addNode(nodeA3); - - // this is going to be PWT - auto nodeB0 = new Node(OpType_PAIRWISE, pairwise::Multiply, 1, {-1, -2}, {2}); - auto nodeB1 = new Node(OpType_TRANSFORM_SAME, transform::Abs, 2, {1}, {}); - - graphB.addNode(nodeB0); - graphB.addNode(nodeB1); - - graphB.buildGraph(); - graphA.buildGraph(); - - ASSERT_EQ(0, nodeA0->getLayer()); - ASSERT_EQ(1, nodeA1->getLayer()); - ASSERT_EQ(2, nodeA2->getLayer()); - ASSERT_EQ(3, nodeA3->getLayer()); - - ASSERT_EQ(0, nodeB0->getLayer()); - ASSERT_EQ(1, nodeB1->getLayer()); - - Nd4jStatus status = GraphExecutioner::execute(&graphA); - ASSERT_EQ(ND4J_STATUS_OK, status); - - float m = graphA.getVariableSpace()->getVariable(4)->getNDArray()->meanNumber().e(0); - - //nd4j_printf("OpResult: %f\n", m); - - ASSERT_NEAR(-11.0, m, 1e-5); -} - -// test for symbolic lookup -TEST_F(GraphTests, TestGraphInGraph_2) { - // this one is external graph - Graph graphA; - - // and this ons is embedded - Graph graphB; - - auto x = NDArrayFactory::create_('c', {5, 5}); - x->assign(-5.0); - - auto modifier = NDArrayFactory::create_('c', {5, 5}); - modifier->assign(3.0); - - std::string nameA1("_nodeA1"); - - graphA.getVariableSpace()->putVariable(-1, x); - graphB.getVariableSpace()->putVariable(-2, modifier); - - // this is placeholder variable - auto placeHolder = new Variable(true); - placeHolder->setName(&nameA1); - graphB.getVariableSpace()->putVariable(-1, placeHolder); - - // abs, result is 5 - auto nodeA0 = new Node(OpType_TRANSFORM_SAME, transform::Abs, 1, {-1}, {2}); - // 1-, result -4 - auto nodeA1 = new Node(OpType_TRANSFORM_SAME, transform::OneMinus, 2, {1}, {3}); - nodeA1->setName(nameA1); - - // graph should return 12: abs(3.0 x -4) - auto nodeA2 = new Node(OpType_GRAPH, -1, 3, {2}, {4}); - - // 1 - 12 = -11 - auto nodeA3 = new Node(OpType_TRANSFORM_SAME, transform::OneMinus, 4, {3}, {}); - - nodeA2->setGraph(&graphB); - - graphA.addNode(nodeA0); - graphA.addNode(nodeA1); - graphA.addNode(nodeA2); - graphA.addNode(nodeA3); - - // this is going to be PWT - auto nodeB0 = new Node(OpType_PAIRWISE, pairwise::Multiply, 1, {-1, -2}, {2}); - auto nodeB1 = new Node(OpType_TRANSFORM_SAME, transform::Abs, 2, {1}, {}); - - graphB.addNode(nodeB0); - graphB.addNode(nodeB1); - - graphB.buildGraph(); - graphA.buildGraph(); - - ASSERT_EQ(0, nodeA0->getLayer()); - ASSERT_EQ(1, nodeA1->getLayer()); - ASSERT_EQ(2, nodeA2->getLayer()); - ASSERT_EQ(3, nodeA3->getLayer()); - - ASSERT_EQ(0, nodeB0->getLayer()); - ASSERT_EQ(1, nodeB1->getLayer()); - - Nd4jStatus status = GraphExecutioner::execute(&graphA); - ASSERT_EQ(ND4J_STATUS_OK, status); - - float m = graphA.getVariableSpace()->getVariable(4)->getNDArray()->meanNumber().e(0); - - //nd4j_printf("OpResult: %f\n", m); - - ASSERT_NEAR(-11.0, m, 1e-5); -} - -#if 0 -TEST_F(GraphTests, Test_Clone_1) { - auto exp = NDArrayFactory::create('c', {3}); - exp.assign(3.0); - - - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/reduce_dim_false.fb"); - auto variableSpace = graph->getVariableSpace(); - //graph->buildGraph(); - - auto clone = graph->clone(); - - Nd4jStatus statusOriginal = GraphExecutioner::execute(graph); - - ASSERT_EQ(ND4J_STATUS_OK, statusOriginal); - ASSERT_TRUE(variableSpace->hasVariable(3)); - - Nd4jStatus statusClone = GraphExecutioner::execute(clone); - - ASSERT_EQ(ND4J_STATUS_OK, statusClone); - - ASSERT_TRUE(variableSpace->hasVariable(3)); - - auto z0 = variableSpace->getVariable(3)->getNDArray(); - auto z1 = clone->getVariableSpace()->getVariable(3)->getNDArray(); - - ASSERT_TRUE(exp.isSameShape(z0)); - ASSERT_TRUE(exp.equalsTo(z0)); - - ASSERT_TRUE(exp.isSameShape(z1)); - ASSERT_TRUE(exp.equalsTo(z1)); - - delete graph; - delete clone; -} - - - - -TEST_F(GraphTests, Test_Clone_2) { - auto exp = NDArrayFactory::create('c', {3}); - exp.assign(3.0); - - - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/reduce_dim_false.fb"); - auto variableSpace = graph->getVariableSpace(); - graph->buildGraph(); - - auto clone = graph->clone(); - - Nd4jStatus statusOriginal = GraphExecutioner::execute(graph); - - ASSERT_EQ(ND4J_STATUS_OK, statusOriginal); - ASSERT_TRUE(variableSpace->hasVariable(3)); - - Nd4jStatus statusClone = GraphExecutioner::execute(clone); - - ASSERT_EQ(ND4J_STATUS_OK, statusClone); - - ASSERT_TRUE(variableSpace->hasVariable(3)); - - auto z0 = variableSpace->getVariable(3)->getNDArray(); - auto z1 = clone->getVariableSpace()->getVariable(3)->getNDArray(); - - ASSERT_TRUE(exp.isSameShape(z0)); - ASSERT_TRUE(exp.equalsTo(z0)); - - ASSERT_TRUE(exp.isSameShape(z1)); - ASSERT_TRUE(exp.equalsTo(z1)); - - delete graph; - delete clone; -} - -TEST_F(GraphTests, Test_Dtype_Conversion_1) { - /*auto expD = NDArrayFactory::create('c', {3}, {3.0, 3.0, 3.0}); - auto expF = NDArrayFactory::create('c', {3}, {3.0, 3.0, 3.0}); - - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/reduce_dim_false.fb"); - graph->buildGraph(); - - - auto gd = graph->template asT(); - auto gf = gd->template asT(); - - // checking float graph - Nd4jStatus statusF = GraphExecutioner::execute(gf); - ASSERT_EQ(ND4J_STATUS_OK, statusF); - - ASSERT_TRUE(gf->getVariableSpace()->hasVariable(3)); - - ASSERT_TRUE(gf->getVariableSpace()->hasVariable(3)); - auto z1 = gf->getVariableSpace()->getVariable(3)->getNDArray(); - - ASSERT_TRUE(expF.isSameShape(z1)); - ASSERT_TRUE(expF.equalsTo(z1)); - - - // checking double graph - Nd4jStatus statusD = GraphExecutioner::execute(gd); - ASSERT_EQ(ND4J_STATUS_OK, statusD); - - ASSERT_TRUE(gd->getVariableSpace()->hasVariable(3)); - auto z2 = gd->getVariableSpace()->getVariable(3)->getNDArray(); - - ASSERT_TRUE(expD.isSameShape(z2)); - ASSERT_TRUE(expD.equalsTo(z2)); - - - delete graph; - delete gd; - delete gf; - */ -} - -TEST_F(GraphTests, Test_Dtype_Conversion_2) { - /* - NDArray expF('c', {5, 4}, {0.32454616f, -0.06604697f, 0.22593613f, 0.43166467f, -0.18320604f, 0.00102305f, -0.06963076f, 0.25266643f, 0.07568010f, -0.03009197f, 0.07805517f, 0.33180334f, -0.06220427f, 0.07249600f, -0.06726961f, -0.22998397f, -0.06343779f, 0.07384885f, -0.06891008f, -0.23745790f}); - NDArray expD('c', {5, 4}, {0.32454616f, -0.06604697f, 0.22593613f, 0.43166467f, -0.18320604f, 0.00102305f, -0.06963076f, 0.25266643f, 0.07568010f, -0.03009197f, 0.07805517f, 0.33180334f, -0.06220427f, 0.07249600f, -0.06726961f, -0.22998397f, -0.06343779f, 0.07384885f, -0.06891008f, -0.23745790f}); - - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/ae_00.fb"); - graph->buildGraph(); - - - auto gd = graph->template asT(); - auto gf = gd->template asT(); - - // checking float - auto resultF = GraphExecutioner::execute(gf); - ASSERT_EQ(ND4J_STATUS_OK, resultF); - ASSERT_TRUE(gf->getVariableSpace()->hasVariable(18)); - auto zF = gf->getVariableSpace()->getVariable(18)->getNDArray(); - - ASSERT_TRUE(expF.isSameShape(zF)); - ASSERT_TRUE(expF.equalsTo(zF)); - - - // checking double - auto resultD = GraphExecutioner::execute(gd); - ASSERT_EQ(ND4J_STATUS_OK, resultD); - ASSERT_TRUE(gd->getVariableSpace()->hasVariable(18)); - auto zD = gd->getVariableSpace()->getVariable(18)->getNDArray(); - - ASSERT_TRUE(expD.isSameShape(zD)); - ASSERT_TRUE(expD.equalsTo(zD)); - - delete graph; - delete gd; - delete gf; - */ -} - -TEST_F(GraphTests, Test_Hash_Function_1) { - /* - auto graph0 = GraphExecutioner::importFromFlatBuffers("./resources/ae_00.fb"); - auto graph1 = GraphExecutioner::importFromFlatBuffers("./resources/ae_00.fb"); - auto graph2 = GraphExecutioner::importFromFlatBuffers("./resources/conv_0.fb"); - - ASSERT_EQ(graph0->hashCode(), graph1->hashCode()); - ASSERT_NE(0L, graph1->hashCode()); - ASSERT_NE(graph0->hashCode(), graph2->hashCode()); - - auto graph0D = graph0->template asT(); - auto graph1D = graph1->template asT(); - - ASSERT_NE(graph0->hashCode(), graph0D->hashCode()); - ASSERT_EQ(graph0D->hashCode(), graph1D->hashCode()); - - delete graph0; - delete graph1; - delete graph2; - delete graph0D; - delete graph1D; - */ -} - -TEST_F(GraphTests, OpListTest_1) { - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/ae_00.fb"); ; - - ASSERT_TRUE(graph != nullptr); - std::vector ops = graph->getOperations(); - - ASSERT_TRUE(ops.size() == 11); - GraphUtils::filterOperations(ops); - ASSERT_TRUE(ops.size() == 7); - - std::string exp(" -g \"-DSD_OPS_LIST='-DOP_rank=true -DOP_range=true -DOP_subtract=true -DOP_permute=true -DOP_matmul=true -DOP_biasadd=true -DOP_TRANSFORM{15}=true '\""); - std::string out = GraphUtils::makeCommandLine(ops); -// nd4j_printf("EXP: >%s<\n", exp.c_str()); -// nd4j_printf("OUT: >%s<\n", out.c_str()); - ASSERT_EQ(exp, out); - - delete graph; -} - -//////////////////////////////////////////////////////////////////////////////// -TEST_F(GraphTests, OpListTest_2) { - auto graph0 = GraphExecutioner::importFromFlatBuffers("./resources/ae_00.fb"); - auto graph1 = GraphExecutioner::importFromFlatBuffers("./resources/tensor_slice.fb"); - - ASSERT_TRUE(graph0 != nullptr); - ASSERT_TRUE(graph1 != nullptr); - - std::vector ops = graph0->getOperations(); - std::vector ops1 = graph1->getOperations(); - std::copy ( ops1.begin(), ops1.end(), std::back_inserter(ops)); - - ASSERT_TRUE(ops.size() == 13); - - GraphUtils::filterOperations(ops); - - std::string exp = " -g \"-DSD_OPS_LIST='-DOP_rank=true -DOP_range=true -DOP_subtract=true -DOP_permute=true -DOP_matmul=true -DOP_biasadd=true -DOP_TRANSFORM{15}=true -DOP_strided_slice=true -DOP_ACCUMULATION{1}=true '\""; - - ASSERT_TRUE(ops.size() == 9); - ASSERT_EQ(exp, GraphUtils::makeCommandLine(ops)); - - delete graph0; - delete graph1; -} - -//////////////////////////////////////////////////////////////////////////////// -TEST_F(GraphTests, OpListTest_3) { - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/ae_00.fb"); ; - - ASSERT_TRUE(graph != nullptr); - std::vector ops = graph->getOperations(); - std::vector ops2(ops); - std::copy(ops.begin(), ops.end(), std::back_inserter(ops2)); - - ASSERT_TRUE(ops.size() == 11); - ASSERT_TRUE(ops2.size() == 2 * ops.size()); - - GraphUtils::filterOperations(ops2); - GraphUtils::filterOperations(ops); - ASSERT_TRUE(ops.size() == ops2.size()); - ASSERT_TRUE(ops.size() == 7); - ASSERT_TRUE(GraphUtils::makeCommandLine(ops) == GraphUtils::makeCommandLine(ops2)); - - delete graph; -} - -//////////////////////////////////////////////////////////////////////////////// -TEST_F(GraphTests, OpListTest_4) { - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/conv_0.fb"); ; - - ASSERT_TRUE(graph != nullptr); - std::vector ops = graph->getOperations(); - std::vector ops2(ops); - std::copy(ops.begin(), ops.end(), std::back_inserter(ops2)); - - // nd4j_printf("Total ops before %i\n", ops.size()); - ASSERT_TRUE(ops.size() == 6); - ASSERT_TRUE(ops2.size() == 2 * ops.size()); - - GraphUtils::filterOperations(ops2); - GraphUtils::filterOperations(ops); - ASSERT_TRUE(ops.size() == ops2.size()); - ASSERT_TRUE(ops.size() == 5); - ASSERT_TRUE(GraphUtils::makeCommandLine(ops) == GraphUtils::makeCommandLine(ops2)); - - delete graph; -} - - -TEST_F(GraphTests, Test_Inplace_Execution_1) { - auto exp = NDArrayFactory::create('c', {5, 4}, {0.32454616f, -0.06604697f, 0.22593613f, 0.43166467f, -0.18320604f, 0.00102305f, -0.06963076f, 0.25266643f, 0.07568010f, -0.03009197f, 0.07805517f, 0.33180334f, -0.06220427f, 0.07249600f, -0.06726961f, -0.22998397f, -0.06343779f, 0.07384885f, -0.06891008f, -0.23745790f}); - - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/ae_00.fb"); - // graph->printOut(); - graph->tagInplaceNodes(); - - ASSERT_FALSE(graph->nodeById(8)->isInplace()); - ASSERT_TRUE(graph->nodeById(9)->isInplace()); - ASSERT_TRUE(graph->nodeById(10)->isInplace()); - ASSERT_FALSE(graph->nodeById(11)->isInplace()); - ASSERT_FALSE(graph->nodeById(12)->isInplace()); - ASSERT_TRUE(graph->nodeById(17)->isInplace()); - ASSERT_TRUE(graph->nodeById(18)->isInplace()); - - auto status = GraphExecutioner::execute(graph, graph->getVariableSpace()); - ASSERT_EQ(Status::OK(), status); - - auto z = graph->getVariableSpace()->getVariable(18)->getNDArray(); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - - auto z_17 = graph->getVariableSpace()->getVariable(17)->getNDArray(); - ASSERT_TRUE(z_17 == z); - - delete graph; -} - -TEST_F(GraphTests, Test_Inplace_Execution_2) { - Graph graphA; - - auto x = NDArrayFactory::create_('c', {5, 5}); - x->assign(-5.0); - - graphA.getVariableSpace()->putVariable(-1, x); - - // abs, result is 5 - auto nodeA0 = new Node(OpType_TRANSFORM_SAME, 0, 1, {-1}, {}); - // 1-, result -4 - auto nodeA1 = new Node(OpType_TRANSFORM_SAME, 35, 2, {1}, {}); - - // graph should return 4: abs(-4) - auto nodeA2 = new Node(OpType_TRANSFORM_SAME, 0, 3, {2}, {}); - - // graph should return 1 - 4 = -3 - auto nodeA21 = new Node(OpType_TRANSFORM_SAME, 35, 5, {3}, {}); - - // 1 - -4 = 3 - auto nodeA3 = new Node(OpType_TRANSFORM_SAME, 35, 4, {2}, {}); - - // same abs = 3 - auto nodeA31 = new Node(OpType_TRANSFORM_SAME, 35, 6, {4}, {}); - - graphA.addNode(nodeA0); - graphA.addNode(nodeA1); - graphA.addNode(nodeA2); - graphA.addNode(nodeA3); - graphA.addNode(nodeA21); - graphA.addNode(nodeA31); - - graphA.buildGraph(); - graphA.tagInplaceNodes(); - - // nodes have 1 output - ASSERT_TRUE(graphA.nodeById(1)->isInplace()); - ASSERT_TRUE(graphA.nodeById(2)->isInplace()); - - // this 2 nodes share same input: node 2, so they can't be inplace - ASSERT_FALSE(graphA.nodeById(3)->isInplace()); - ASSERT_FALSE(graphA.nodeById(4)->isInplace()); - - // these 2 ops are standalone, so they can be run inplace - ASSERT_TRUE(graphA.nodeById(5)->isInplace()); - ASSERT_TRUE(graphA.nodeById(6)->isInplace()); -} -#endif - -TEST_F(GraphTests, Test_Inplace_Outputs_1) { - auto x = NDArrayFactory::create('c', {2, 3}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); - auto exp = NDArrayFactory::create('c', {6}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); - auto z = NDArrayFactory::create('c', {2, 3}); - - sd::ops::test_output_reshape op; - auto result = op.execute({&x}, {&z}, {}, {}, {}); - ASSERT_EQ(Status::OK(), result); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); -} +#include +#include +#include +#include +#include +#include +#include +#include -TEST_F(GraphTests, Test_Inplace_Outputs_2) { -#ifndef __APPLE_OS__ - // we dont want testing this on apple. due to try/catch +#include - auto x = NDArrayFactory::create('c', {2, 3}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); - auto exp = NDArrayFactory::create('c', {6}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); - auto z = NDArrayFactory::create('c', {3, 3}); +#include "testlayers.h" - bool failed = false; - sd::ops::test_output_reshape op; - try { - op.execute({&x}, {&z}, {}, {}, {}); +using namespace sd; +using namespace sd::graph; - } catch (const std::runtime_error& e) { - failed = true; - } - - - ASSERT_TRUE(failed); -#endif -} +class GraphTests : public testing::Test { + public: + /* + int cShape[] = {2, 2, 2, 2, 1, 0, 1, 99}; + int fShape[] = {2, 2, 2, 1, 2, 0, 1, 102}; + */ + GraphTests() { + // Environment::getInstance().setDebug(true); + // Environment::getInstance().setVerbose(true); + } +}; /* TEST_F(GraphTests, Test_Minifier_1) { // run preprocessor to produce single header // if all ok - return value is 0, if error - non-zero value will be returned - std::string input("../include/ops/ops.h"); //declarable/CustomOperations.h"); + std::string input("../include/ops/ops.h"); +//declarable/CustomOperations.h"); - ASSERT_EQ(0, GraphUtils::runPreprocessor(input.c_str(), "libnd4j_mini.hpp")); + ASSERT_EQ(0, GraphUtils::runPreprocessor(input.c_str(), +"libnd4j_mini.hpp")); // remove file from filesystem #ifdef __linux__ ASSERT_EQ(0, unlink("libnd4j_mini.hpp")); @@ -1615,24 +63,23 @@ TEST_F(GraphTests, Test_Minifier_1) { */ TEST_F(GraphTests, Test_Minifier_2) { - - // run preprocessor to produce single header - // if all ok - return value is 0, if error - non-zero value will be returned - ASSERT_EQ(0, GraphUtils::runPreprocessor("../include/ops/specials.h", "libnd4j_mini2.hpp")); - // remove file from filesystem + // run preprocessor to produce single header + // if all ok - return value is 0, if error - non-zero value will be returned + ASSERT_EQ(0, GraphUtils::runPreprocessor("../include/ops/specials.h", + "libnd4j_mini2.hpp")); + // remove file from filesystem #ifdef __linux__ - ASSERT_EQ(0, unlink("libnd4j_mini2.hpp")); + ASSERT_EQ(0, unlink("libnd4j_mini2.hpp")); #endif } TEST_F(GraphTests, Test_Minifier_3) { - - // run preprocessor to produce single header - // if all ok - return value is 0, if error - non-zero value will be returned + // run preprocessor to produce single header + // if all ok - return value is 0, if error - non-zero value will be returned #ifdef __linux__ - ASSERT_EQ(0x100, GraphUtils::runPreprocessor("/include/ops/ops.h", "libnd4j_mini3.hpp")); + ASSERT_EQ(0x100, GraphUtils::runPreprocessor("/include/ops/ops.h", + "libnd4j_mini3.hpp")); #endif - // remove file from filesystem - //ASSERT_EQ(0, unlink("libnd4j_mini3.hpp")); - + // remove file from filesystem + // ASSERT_EQ(0, unlink("libnd4j_mini3.hpp")); } diff --git a/libnd4j/tests_cpu/layers_tests/GraphTests2.cpp b/libnd4j/tests_cpu/layers_tests/GraphTests2.cpp new file mode 100644 index 000000000000..f343058e70c6 --- /dev/null +++ b/libnd4j/tests_cpu/layers_tests/GraphTests2.cpp @@ -0,0 +1,188 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "testlayers.h" + +using namespace sd; +using namespace sd::graph; + +class GraphTests2 : public testing::Test { + public: + GraphTests2() { + // + } +}; + +TEST_F(GraphTests2, test_placeholder_1) { + Graph graph; + + graph.addPlaceholder("input", DataType::BFLOAT16, {4, 12, 48}); + + ASSERT_TRUE(graph.variableSpace().hasVariable("input")); + + auto variable = graph.variableSpace().getVariable("input"); + + ASSERT_NE(nullptr, variable); + ASSERT_TRUE(variable->isPlaceholder()); + ASSERT_EQ(DataType::BFLOAT16, variable->dataType()); + ASSERT_EQ(std::vector({4, 12, 48}), variable->shape()); + + auto placeholders = graph.placeholders(); + ASSERT_EQ(1, placeholders.size()); + ASSERT_EQ(placeholders[0], variable); +} + +TEST_F(GraphTests2, test_execution_1) { + Graph graph; + + // A + graph.addVariable("A", NDArrayFactory::create('c', {3}, {1, 1, 1})); + + // B + graph.addVariable("B", NDArrayFactory::create('c', {3}, {2, 2, 2})); + + // C + graph.addVariable("C", NDArrayFactory::create('c', {3}, {3, 3, 3})); + + Node b(sd::ops::add(), "add_node"); + + graph.addNode(Node(sd::ops::multiply(), "multiply_node"), {"A", "B"}); + graph.addNode(b, {"multiply_node", "C"}); + + auto result = graph.execute({}, {"add_node"}); + ASSERT_EQ(1, result.size()); + ASSERT_EQ(1, result.count("add_node")); +} + +TEST_F(GraphTests2, test_placeholder_resolution_1) { + Graph graph; + + graph.addPlaceholder("input", DataType::FLOAT32); + + Node node(sd::ops::tanh(), "tanh_node"); + graph.addNode(node, {"input"}); + + // this test must throw an exception, because input isn't resolved yet + ASSERT_ANY_THROW(graph.execute()); +} + +TEST_F(GraphTests2, test_placeholder_resolution_2) { + Graph graph; + + graph.addPlaceholder("input", DataType::FLOAT32); + + graph.addNode(Node(sd::ops::rationaltanh(), "tanh_node"), {"input"}); + + auto result = + graph.execute({{"input", NDArrayFactory::create(0.5f)}}, {"tanh_node"}); + + // TODO: add result validation here +} + +TEST_F(GraphTests2, test_placeholder_resolution_3) { + Graph graph; + + graph.addPlaceholder("input", DataType::FLOAT32); + + graph.addNode(Node(sd::ops::tanh(), "tanh_node"), {"input"}); + + ASSERT_THROW( + graph.execute({{"input", NDArrayFactory::create(5)}}, {"tanh_node"}), + sd::datatype_exception); +} + +TEST_F(GraphTests2, test_placeholder_resolution_4) { + Graph graph; + + graph.addPlaceholder("input", DataType::FLOAT32, {3, 4, 5}); + + Node a(sd::ops::tanh(), "tanh_node"); + graph.addNode(a, {"input"}); + + ASSERT_THROW(graph.execute({{"input", NDArrayFactory::create(0.5f)}}, + {"tanh_node"}), + sd::shape_mismatch_exception); +} + +TEST_F(GraphTests2, test_output_resolution_1) { + Graph graph; + + graph.addPlaceholder("input", DataType::FLOAT32); + + Node node(sd::ops::tanh(), "tanh_node"); + graph.addNode(node, {"input"}); + + // since we're requesting output of non-existent node - we expect exception + ASSERT_THROW( + graph.execute({{"input", NDArrayFactory::create(0.5f)}}, {"pow_node"}), + graph::unresolved_output_exception); +} + +TEST_F(GraphTests2, test_input_resolution_1) { + Graph graph; + + graph.addPlaceholder("input", DataType::FLOAT32); + + Node a(sd::ops::tanh(), "tanh_node"); + graph.addNode(a, {"input"}); + + // since we're trying to resolve non-existent placeholder - we expect + // exception + ASSERT_THROW( + graph.execute({{"array", NDArrayFactory::create(0.5f)}}, {"tanh_node"}), + graph::unresolved_input_exception); +} + +TEST_F(GraphTests2, test_double_name_1) { + Graph graph; + + graph.addPlaceholder("input", DataType::FLOAT32); + + graph.addNode(Node(sd::ops::tanh(), "tanh_node"), {"input"}); + graph.addNode(Node(sd::ops::add(), "add_node"), {"tanh_node"}); + ASSERT_ANY_THROW( + graph.addNode(Node(sd::ops::add(), "add_node"), {"tanh_node"})); +} + +TEST_F(GraphTests2, test_self_reference) { + Graph graph; + + graph.addPlaceholder("input", DataType::FLOAT32); + + graph.addNode(Node(sd::ops::tanh(), "tanh_node"), {"input"}); + ASSERT_ANY_THROW( + graph.addNode(Node(sd::ops::add(), "add_node"), {"add_node"})); +} \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/HashUtilsTests.cpp b/libnd4j/tests_cpu/layers_tests/HashUtilsTests.cpp index 431a4bc14ebc..559007aa5eab 100644 --- a/libnd4j/tests_cpu/layers_tests/HashUtilsTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/HashUtilsTests.cpp @@ -18,26 +18,22 @@ // Created by raver119 on 02.09.17. // -#include "testlayers.h" #include -class HashUtilsTests : public testing::Test { - -}; +#include "testlayers.h" +class HashUtilsTests : public testing::Test {}; TEST_F(HashUtilsTests, TestEquality1) { - std::string str("Conv2D"); + std::string str("Conv2D"); - Nd4jLong hash1 = sd::ops::HashHelper::getInstance().getLongHash(str); - ASSERT_EQ(-1637140380760460323L, hash1); + Nd4jLong hash1 = sd::ops::HashHelper::getInstance().getLongHash(str); + ASSERT_EQ(-1637140380760460323L, hash1); } - - TEST_F(HashUtilsTests, TestEquality2) { - std::string str("switch"); + std::string str("switch"); - Nd4jLong hash1 = sd::ops::HashHelper::getInstance().getLongHash(str); - ASSERT_EQ(-1988317239813741487L, hash1); + Nd4jLong hash1 = sd::ops::HashHelper::getInstance().getLongHash(str); + ASSERT_EQ(-1988317239813741487L, hash1); } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/HelpersTests1.cpp b/libnd4j/tests_cpu/layers_tests/HelpersTests1.cpp index fae8c49183d5..8bdefd4f121e 100644 --- a/libnd4j/tests_cpu/layers_tests/HelpersTests1.cpp +++ b/libnd4j/tests_cpu/layers_tests/HelpersTests1.cpp @@ -36,12 +36,12 @@ using namespace sd; class HelpersTests1 : public testing::Test { -public: + public: - HelpersTests1() { + HelpersTests1() { - std::cout<('c', {4}, {14,17,3,1}); - auto tail = NDArrayFactory::create('c', {3}); - auto expTail = NDArrayFactory::create('c', {3}, {0.468984, 0.0827618, 0.0275873}); - const double normXExpected = -22.2486; - const double coeffExpected = 1.62925; + auto x = NDArrayFactory::create('c', {4}, {14,17,3,1}); + auto tail = NDArrayFactory::create('c', {3}); + auto expTail = NDArrayFactory::create('c', {3}, {0.468984, 0.0827618, 0.0275873}); + const double normXExpected = -22.2486; + const double coeffExpected = 1.62925; - double normX, coeff; - ops::helpers::Householder::evalHHmatrixData(x, tail, coeff, normX); + double normX, coeff; + ops::helpers::Householder::evalHHmatrixData(x, tail, coeff, normX); - ASSERT_NEAR(normX, normXExpected, 1e-5); - ASSERT_NEAR(coeff, coeffExpected, 1e-5); - ASSERT_TRUE(tail.isSameShapeStrict(expTail)); - ASSERT_TRUE(tail.equalsTo(&expTail)); + ASSERT_NEAR(normX, normXExpected, 1e-5); + ASSERT_NEAR(coeff, coeffExpected, 1e-5); + ASSERT_TRUE(tail.isSameShapeStrict(expTail)); + ASSERT_TRUE(tail.equalsTo(&expTail)); } ///////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, Householder_mulLeft_test1) { - auto x = NDArrayFactory::create('c', {4,4}, {12 ,19 ,14 ,3 ,10 ,4 ,17 ,19 ,19 ,18 ,5 ,3 ,6 ,4 ,2 ,16}); - auto tail = NDArrayFactory::create('c', {1,3}, {0.5,0.5,0.5}); - auto exp = NDArrayFactory::create('c', {4,4}, {9.05,15.8,11.4, 0.8, 8.525, 2.4,15.7,17.9, 17.525,16.4, 3.7, 1.9, 4.525, 2.4, 0.7,14.9}); + auto x = NDArrayFactory::create('c', {4,4}, {12 ,19 ,14 ,3 ,10 ,4 ,17 ,19 ,19 ,18 ,5 ,3 ,6 ,4 ,2 ,16}); + auto tail = NDArrayFactory::create('c', {1,3}, {0.5,0.5,0.5}); + auto exp = NDArrayFactory::create('c', {4,4}, {9.05,15.8,11.4, 0.8, 8.525, 2.4,15.7,17.9, 17.525,16.4, 3.7, 1.9, 4.525, 2.4, 0.7,14.9}); - ops::helpers::Householder::mulLeft(x, tail, 0.1); + ops::helpers::Householder::mulLeft(x, tail, 0.1); - ASSERT_TRUE(x.isSameShapeStrict(exp)); - ASSERT_TRUE(x.equalsTo(&exp)); + ASSERT_TRUE(x.isSameShapeStrict(exp)); + ASSERT_TRUE(x.equalsTo(&exp)); } ///////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, Householder_mulLeft_test2) { - auto x = NDArrayFactory::create('c', {4,4}, {12 ,19 ,14 ,3 ,10 ,4 ,17 ,19 ,19 ,18 ,5 ,3 ,6 ,4 ,2 ,16}); - auto tail = NDArrayFactory::create('c', {3,1}, {0.5,0.5,0.5}); - auto exp = NDArrayFactory::create('c', {4,4}, {9.05,15.8,11.4, 0.8, 8.525, 2.4,15.7,17.9, 17.525,16.4, 3.7, 1.9, 4.525, 2.4, 0.7,14.9}); + auto x = NDArrayFactory::create('c', {4,4}, {12 ,19 ,14 ,3 ,10 ,4 ,17 ,19 ,19 ,18 ,5 ,3 ,6 ,4 ,2 ,16}); + auto tail = NDArrayFactory::create('c', {3,1}, {0.5,0.5,0.5}); + auto exp = NDArrayFactory::create('c', {4,4}, {9.05,15.8,11.4, 0.8, 8.525, 2.4,15.7,17.9, 17.525,16.4, 3.7, 1.9, 4.525, 2.4, 0.7,14.9}); - ops::helpers::Householder::mulLeft(x, tail, 0.1); + ops::helpers::Householder::mulLeft(x, tail, 0.1); - ASSERT_TRUE(x.isSameShapeStrict(exp)); - ASSERT_TRUE(x.equalsTo(&exp)); + ASSERT_TRUE(x.isSameShapeStrict(exp)); + ASSERT_TRUE(x.equalsTo(&exp)); } ///////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, Householder_mulRight_test1) { - auto x = NDArrayFactory::create('c', {4,4}, {12 ,19 ,14 ,3 ,10 ,4 ,17 ,19 ,19 ,18 ,5 ,3 ,6 ,4 ,2 ,16}); - auto tail = NDArrayFactory::create('c', {1,3}, {0.5,0.5,0.5}); - auto exp = NDArrayFactory::create('c', {4,4}, {9,17.5,12.5, 1.5, 7, 2.5,15.5, 17.5, 15.8,16.4, 3.4, 1.4, 4.3,3.15,1.15,15.15}); + auto x = NDArrayFactory::create('c', {4,4}, {12 ,19 ,14 ,3 ,10 ,4 ,17 ,19 ,19 ,18 ,5 ,3 ,6 ,4 ,2 ,16}); + auto tail = NDArrayFactory::create('c', {1,3}, {0.5,0.5,0.5}); + auto exp = NDArrayFactory::create('c', {4,4}, {9,17.5,12.5, 1.5, 7, 2.5,15.5, 17.5, 15.8,16.4, 3.4, 1.4, 4.3,3.15,1.15,15.15}); - ops::helpers::Householder::mulRight(x, tail, 0.1); + ops::helpers::Householder::mulRight(x, tail, 0.1); - ASSERT_TRUE(x.isSameShapeStrict(exp)); - ASSERT_TRUE(x.equalsTo(&exp)); + ASSERT_TRUE(x.isSameShapeStrict(exp)); + ASSERT_TRUE(x.equalsTo(&exp)); } ///////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, BiDiagonalizeUp_test1) { - auto matrix = NDArrayFactory::create('c', {4,4}, {9,13,3,6,13,11,7,6,3,7,4,7,6,6,7,10}); - auto hhMatrixExp = NDArrayFactory::create('c', {4,4}, {1.524000, 1.75682,0.233741,0.289458, 0.496646, 1.5655, 1.02929,0.971124, 0.114611,-0.451039, 1.06367,0, 0.229221,-0.272237,0.938237,0}); - auto hhBidiagExp = NDArrayFactory::create('c', {4,4}, {-17.1756, 24.3869, 0, 0, 0,-8.61985,-3.89823, 0, 0, 0, 4.03047,4.13018, 0, 0, 0,1.21666}); + auto matrix = NDArrayFactory::create('c', {4,4}, {9,13,3,6,13,11,7,6,3,7,4,7,6,6,7,10}); + auto hhMatrixExp = NDArrayFactory::create('c', {4,4}, {1.524000, 1.75682,0.233741,0.289458, 0.496646, 1.5655, 1.02929,0.971124, 0.114611,-0.451039, 1.06367,0, 0.229221,-0.272237,0.938237,0}); + auto hhBidiagExp = NDArrayFactory::create('c', {4,4}, {-17.1756, 24.3869, 0, 0, 0,-8.61985,-3.89823, 0, 0, 0, 4.03047,4.13018, 0, 0, 0,1.21666}); - ops::helpers::BiDiagonalUp object(matrix); - // object._HHmatrix.printBuffer(); + ops::helpers::BiDiagonalUp object(matrix); + // object._HHmatrix.printBuffer(); - ASSERT_TRUE(hhMatrixExp.isSameShapeStrict(object._HHmatrix)); - ASSERT_TRUE(hhMatrixExp.equalsTo(&object._HHmatrix)); - ASSERT_TRUE(hhBidiagExp.isSameShapeStrict(object._HHbidiag)); - ASSERT_TRUE(hhBidiagExp.equalsTo(&object._HHbidiag)); + ASSERT_TRUE(hhMatrixExp.isSameShapeStrict(object._HHmatrix)); + ASSERT_TRUE(hhMatrixExp.equalsTo(&object._HHmatrix)); + ASSERT_TRUE(hhBidiagExp.isSameShapeStrict(object._HHbidiag)); + ASSERT_TRUE(hhBidiagExp.equalsTo(&object._HHbidiag)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, BiDiagonalizeUp_test2) { - auto matrix = NDArrayFactory::create('c', {5,4}, {9,-13,3,6, 13,11,7,-6, 3,7,4,7, -6,6,7,10, 2,17,9,12}); - auto hhMatrixExp = NDArrayFactory::create('c', {5,4}, {1.52048, 1.37012, 0.636326, -0.23412, 0.494454, 1.66025, 1.66979,-0.444696, 0.114105,0.130601, 1.58392, 0, -0.22821, 0.215638,0.0524781, 1.99303, 0.0760699,0.375605, 0.509835,0.0591568}); - auto hhBidiagExp = NDArrayFactory::create('c', {4,4}, {-17.2916,7.03123, 0, 0, 0, 16.145,-22.9275, 0, 0, 0, -9.9264,-11.5516, 0, 0, 0,-12.8554}); + auto matrix = NDArrayFactory::create('c', {5,4}, {9,-13,3,6, 13,11,7,-6, 3,7,4,7, -6,6,7,10, 2,17,9,12}); + auto hhMatrixExp = NDArrayFactory::create('c', {5,4}, {1.52048, 1.37012, 0.636326, -0.23412, 0.494454, 1.66025, 1.66979,-0.444696, 0.114105,0.130601, 1.58392, 0, -0.22821, 0.215638,0.0524781, 1.99303, 0.0760699,0.375605, 0.509835,0.0591568}); + auto hhBidiagExp = NDArrayFactory::create('c', {4,4}, {-17.2916,7.03123, 0, 0, 0, 16.145,-22.9275, 0, 0, 0, -9.9264,-11.5516, 0, 0, 0,-12.8554}); - ops::helpers::BiDiagonalUp object(matrix); + ops::helpers::BiDiagonalUp object(matrix); - ASSERT_TRUE(hhMatrixExp.isSameShapeStrict(object._HHmatrix)); - ASSERT_TRUE(hhMatrixExp.equalsTo(&object._HHmatrix)); - ASSERT_TRUE(hhBidiagExp.isSameShapeStrict(object._HHbidiag)); - ASSERT_TRUE(hhBidiagExp.equalsTo(&object._HHbidiag)); + ASSERT_TRUE(hhMatrixExp.isSameShapeStrict(object._HHmatrix)); + ASSERT_TRUE(hhMatrixExp.equalsTo(&object._HHmatrix)); + ASSERT_TRUE(hhBidiagExp.isSameShapeStrict(object._HHbidiag)); + ASSERT_TRUE(hhBidiagExp.equalsTo(&object._HHbidiag)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, BiDiagonalizeUp_test3) { - auto matrix = NDArrayFactory::create('c', {6,4}, {9,-13,3,6, 13,11,7,-6, 3,7,4,7, -6,6,7,10, 2,17,9,12, 0,-15,10,2}); - auto hhMatrixExp = NDArrayFactory::create('c', {6,4}, {1.52048, 1.37012, 0.636326, -0.23412, 0.494454, 1.65232, 1.59666,-0.502606, 0.114105, 0.129651, 1.35075, 0, -0.22821, 0.214071, 0.103749, 1.61136, 0.0760699, 0.372875, 0.389936, 0.2398, 0,0.0935171,-0.563777, 0.428587}); - auto hhBidiagExp = NDArrayFactory::create('c', {4,4}, {-17.2916,7.03123, 0, 0, 0,16.3413,-20.7828, 0, 0, 0,-18.4892,4.13261, 0, 0, 0,-21.323}); + auto matrix = NDArrayFactory::create('c', {6,4}, {9,-13,3,6, 13,11,7,-6, 3,7,4,7, -6,6,7,10, 2,17,9,12, 0,-15,10,2}); + auto hhMatrixExp = NDArrayFactory::create('c', {6,4}, {1.52048, 1.37012, 0.636326, -0.23412, 0.494454, 1.65232, 1.59666,-0.502606, 0.114105, 0.129651, 1.35075, 0, -0.22821, 0.214071, 0.103749, 1.61136, 0.0760699, 0.372875, 0.389936, 0.2398, 0,0.0935171,-0.563777, 0.428587}); + auto hhBidiagExp = NDArrayFactory::create('c', {4,4}, {-17.2916,7.03123, 0, 0, 0,16.3413,-20.7828, 0, 0, 0,-18.4892,4.13261, 0, 0, 0,-21.323}); - ops::helpers::BiDiagonalUp object(matrix); - // object._HHmatrix.printBuffer(); + ops::helpers::BiDiagonalUp object(matrix); + // object._HHmatrix.printBuffer(); - ASSERT_TRUE(hhMatrixExp.isSameShapeStrict(object._HHmatrix)); - ASSERT_TRUE(hhMatrixExp.equalsTo(&object._HHmatrix)); - ASSERT_TRUE(hhBidiagExp.isSameShapeStrict(object._HHbidiag)); - ASSERT_TRUE(hhBidiagExp.equalsTo(&object._HHbidiag)); + ASSERT_TRUE(hhMatrixExp.isSameShapeStrict(object._HHmatrix)); + ASSERT_TRUE(hhMatrixExp.equalsTo(&object._HHmatrix)); + ASSERT_TRUE(hhBidiagExp.isSameShapeStrict(object._HHbidiag)); + ASSERT_TRUE(hhBidiagExp.equalsTo(&object._HHbidiag)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, HHsequence_test1) { - auto matrix = NDArrayFactory::create('c', {5,4}, {9,-13,3,6, 13,11,7,-6, 3,7,4,7, -6,6,7,10, 2,17,9,12}); - auto vectorsUseqExp = NDArrayFactory::create('c', {5,4}, {1.52048, 1.37012, 0.636326, -0.23412, 0.494454, 1.66025, 1.66979,-0.444696, 0.114105,0.130601, 1.58392, 0, -0.22821,0.215638,0.0524781, 1.99303, 0.0760699,0.375605, 0.509835,0.0591568}); - auto vectorsVseqExp = NDArrayFactory::create('c', {5,4}, {1.52048, 1.37012, 0.636326, -0.23412, 0.494454, 1.66025, 1.66979,-0.444696, 0.114105,0.130601, 1.58392, 0, -0.22821,0.215638,0.0524781, 1.99303, 0.0760699,0.375605, 0.509835,0.0591568}); - auto coeffsUseqExp = NDArrayFactory::create('c', {4,1}, {1.52048,1.66025,1.58392,1.99303}); - auto coeffsVseqExp = NDArrayFactory::create('c', {3,1}, {1.37012,1.66979,0}); + auto matrix = NDArrayFactory::create('c', {5,4}, {9,-13,3,6, 13,11,7,-6, 3,7,4,7, -6,6,7,10, 2,17,9,12}); + auto vectorsUseqExp = NDArrayFactory::create('c', {5,4}, {1.52048, 1.37012, 0.636326, -0.23412, 0.494454, 1.66025, 1.66979,-0.444696, 0.114105,0.130601, 1.58392, 0, -0.22821,0.215638,0.0524781, 1.99303, 0.0760699,0.375605, 0.509835,0.0591568}); + auto vectorsVseqExp = NDArrayFactory::create('c', {5,4}, {1.52048, 1.37012, 0.636326, -0.23412, 0.494454, 1.66025, 1.66979,-0.444696, 0.114105,0.130601, 1.58392, 0, -0.22821,0.215638,0.0524781, 1.99303, 0.0760699,0.375605, 0.509835,0.0591568}); + auto coeffsUseqExp = NDArrayFactory::create('c', {4,1}, {1.52048,1.66025,1.58392,1.99303}); + auto coeffsVseqExp = NDArrayFactory::create('c', {3,1}, {1.37012,1.66979,0}); - ops::helpers::BiDiagonalUp object(matrix); - ops::helpers::HHsequence uSeq = object.makeHHsequence('u'); - ops::helpers::HHsequence vSeq = object.makeHHsequence('v'); + ops::helpers::BiDiagonalUp object(matrix); + ops::helpers::HHsequence uSeq = object.makeHHsequence('u'); + ops::helpers::HHsequence vSeq = object.makeHHsequence('v'); - ASSERT_TRUE(uSeq._vectors.isSameShapeStrict(vectorsUseqExp)); - ASSERT_TRUE(vSeq._vectors.isSameShapeStrict(vectorsVseqExp)); - ASSERT_TRUE(uSeq._vectors.equalsTo(&vectorsUseqExp)); - ASSERT_TRUE(vSeq._vectors.equalsTo(&vectorsVseqExp)); + ASSERT_TRUE(uSeq._vectors.isSameShapeStrict(vectorsUseqExp)); + ASSERT_TRUE(vSeq._vectors.isSameShapeStrict(vectorsVseqExp)); + ASSERT_TRUE(uSeq._vectors.equalsTo(&vectorsUseqExp)); + ASSERT_TRUE(vSeq._vectors.equalsTo(&vectorsVseqExp)); - ASSERT_TRUE(vSeq._diagSize == uSeq._diagSize - 1); - ASSERT_TRUE(vSeq._shift == 1); - ASSERT_TRUE(uSeq._shift == 0); + ASSERT_TRUE(vSeq._diagSize == uSeq._diagSize - 1); + ASSERT_TRUE(vSeq._shift == 1); + ASSERT_TRUE(uSeq._shift == 0); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, HHsequence_test2) { - auto matrix = NDArrayFactory::create('c', {6,4}, {9,-13,3,6, 13,11,7,-6, 3,7,4,7, -6,6,7,10, 2,17,9,12 ,0,-15,10,2}); - auto vectorsUseqExp = NDArrayFactory::create('c', {6,4}, {1.52048, 1.37012, 0.636326, -0.23412, 0.494454, 1.65232, 1.59666,-0.502606, 0.114105, 0.129651, 1.35075, 0, -0.22821, 0.214071, 0.103749, 1.61136, 0.0760699, 0.372875, 0.389936, 0.2398, 0,0.0935171,-0.563777, 0.428587}); - auto vectorsVseqExp = NDArrayFactory::create('c', {6,4}, {1.52048, 1.37012, 0.636326, -0.23412, 0.494454, 1.65232, 1.59666,-0.502606, 0.114105, 0.129651, 1.35075, 0, -0.22821, 0.214071, 0.103749, 1.61136, 0.0760699, 0.372875, 0.389936, 0.2398, 0,0.0935171,-0.563777, 0.428587}); - auto coeffsUseqExp = NDArrayFactory::create('c', {4,1}, {1.52048,1.65232,1.35075,1.61136}); - auto coeffsVseqExp = NDArrayFactory::create('c', {3,1}, {1.37012,1.59666,0}); + auto matrix = NDArrayFactory::create('c', {6,4}, {9,-13,3,6, 13,11,7,-6, 3,7,4,7, -6,6,7,10, 2,17,9,12 ,0,-15,10,2}); + auto vectorsUseqExp = NDArrayFactory::create('c', {6,4}, {1.52048, 1.37012, 0.636326, -0.23412, 0.494454, 1.65232, 1.59666,-0.502606, 0.114105, 0.129651, 1.35075, 0, -0.22821, 0.214071, 0.103749, 1.61136, 0.0760699, 0.372875, 0.389936, 0.2398, 0,0.0935171,-0.563777, 0.428587}); + auto vectorsVseqExp = NDArrayFactory::create('c', {6,4}, {1.52048, 1.37012, 0.636326, -0.23412, 0.494454, 1.65232, 1.59666,-0.502606, 0.114105, 0.129651, 1.35075, 0, -0.22821, 0.214071, 0.103749, 1.61136, 0.0760699, 0.372875, 0.389936, 0.2398, 0,0.0935171,-0.563777, 0.428587}); + auto coeffsUseqExp = NDArrayFactory::create('c', {4,1}, {1.52048,1.65232,1.35075,1.61136}); + auto coeffsVseqExp = NDArrayFactory::create('c', {3,1}, {1.37012,1.59666,0}); - ops::helpers::BiDiagonalUp object(matrix); - ops::helpers::HHsequence uSeq = object.makeHHsequence('u'); - ops::helpers::HHsequence vSeq = object.makeHHsequence('v'); + ops::helpers::BiDiagonalUp object(matrix); + ops::helpers::HHsequence uSeq = object.makeHHsequence('u'); + ops::helpers::HHsequence vSeq = object.makeHHsequence('v'); - ASSERT_TRUE(uSeq._vectors.isSameShapeStrict(vectorsUseqExp)); - ASSERT_TRUE(vSeq._vectors.isSameShapeStrict(vectorsVseqExp)); - ASSERT_TRUE(uSeq._vectors.equalsTo(&vectorsUseqExp)); - ASSERT_TRUE(vSeq._vectors.equalsTo(&vectorsVseqExp)); + ASSERT_TRUE(uSeq._vectors.isSameShapeStrict(vectorsUseqExp)); + ASSERT_TRUE(vSeq._vectors.isSameShapeStrict(vectorsVseqExp)); + ASSERT_TRUE(uSeq._vectors.equalsTo(&vectorsUseqExp)); + ASSERT_TRUE(vSeq._vectors.equalsTo(&vectorsVseqExp)); - ASSERT_TRUE(vSeq._diagSize == uSeq._diagSize - 1); - ASSERT_TRUE(vSeq._shift == 1); - ASSERT_TRUE(uSeq._shift == 0); + ASSERT_TRUE(vSeq._diagSize == uSeq._diagSize - 1); + ASSERT_TRUE(vSeq._shift == 1); + ASSERT_TRUE(uSeq._shift == 0); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, HHsequence_test3) { - auto matrix = NDArrayFactory::create('c', {4,4}, {9,13,3,6, 13,11,7,6, 3,7,4,7, 6,6,7,10}); - auto vectorsUseqExp = NDArrayFactory::create('c', {4,4}, {1.524, 1.75682,0.233741,0.289458, 0.496646, 1.5655, 1.02929,0.971124, 0.114611,-0.451039, 1.06367, 0, 0.229221,-0.272237,0.938237, 0}); - auto vectorsVseqExp = NDArrayFactory::create('c', {4,4}, {1.524, 1.75682,0.233741,0.289458, 0.496646, 1.5655, 1.02929,0.971124, 0.114611,-0.451039, 1.06367, 0, 0.229221,-0.272237,0.938237, 0}); - auto coeffsUseqExp = NDArrayFactory::create('c', {4,1}, { 1.524, 1.5655,1.06367,0}); - auto coeffsVseqExp = NDArrayFactory::create('c', {3,1}, {1.75682,1.02929, 0}); + auto matrix = NDArrayFactory::create('c', {4,4}, {9,13,3,6, 13,11,7,6, 3,7,4,7, 6,6,7,10}); + auto vectorsUseqExp = NDArrayFactory::create('c', {4,4}, {1.524, 1.75682,0.233741,0.289458, 0.496646, 1.5655, 1.02929,0.971124, 0.114611,-0.451039, 1.06367, 0, 0.229221,-0.272237,0.938237, 0}); + auto vectorsVseqExp = NDArrayFactory::create('c', {4,4}, {1.524, 1.75682,0.233741,0.289458, 0.496646, 1.5655, 1.02929,0.971124, 0.114611,-0.451039, 1.06367, 0, 0.229221,-0.272237,0.938237, 0}); + auto coeffsUseqExp = NDArrayFactory::create('c', {4,1}, { 1.524, 1.5655,1.06367,0}); + auto coeffsVseqExp = NDArrayFactory::create('c', {3,1}, {1.75682,1.02929, 0}); - ops::helpers::BiDiagonalUp object(matrix); - ops::helpers::HHsequence uSeq = object.makeHHsequence('u'); - ops::helpers::HHsequence vSeq = object.makeHHsequence('v'); + ops::helpers::BiDiagonalUp object(matrix); + ops::helpers::HHsequence uSeq = object.makeHHsequence('u'); + ops::helpers::HHsequence vSeq = object.makeHHsequence('v'); - ASSERT_TRUE(uSeq._vectors.isSameShapeStrict(vectorsUseqExp)); - ASSERT_TRUE(vSeq._vectors.isSameShapeStrict(vectorsVseqExp)); - ASSERT_TRUE(uSeq._vectors.equalsTo(&vectorsUseqExp)); - ASSERT_TRUE(vSeq._vectors.equalsTo(&vectorsVseqExp)); + ASSERT_TRUE(uSeq._vectors.isSameShapeStrict(vectorsUseqExp)); + ASSERT_TRUE(vSeq._vectors.isSameShapeStrict(vectorsVseqExp)); + ASSERT_TRUE(uSeq._vectors.equalsTo(&vectorsUseqExp)); + ASSERT_TRUE(vSeq._vectors.equalsTo(&vectorsVseqExp)); - ASSERT_TRUE(vSeq._diagSize == uSeq._diagSize - 1); - ASSERT_TRUE(vSeq._shift == 1); - ASSERT_TRUE(uSeq._shift == 0); + ASSERT_TRUE(vSeq._diagSize == uSeq._diagSize - 1); + ASSERT_TRUE(vSeq._shift == 1); + ASSERT_TRUE(uSeq._shift == 0); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, HHsequence_test4) { - auto matrix = NDArrayFactory::create('c', {4,4}, {9,13,3,6, 13,11,7,6, 3,7,4,7, 6,6,7,10}); - auto exp = NDArrayFactory::create('c', {4,4}, {2.49369, 2.62176, 5.88386, 7.69905, -16.0588,-18.7319,-9.15007,-12.6164, 4.7247, 3.46252, 1.02038, -1.4533, 2.9279,-2.29178, 1.90139,-0.66187}); + auto matrix = NDArrayFactory::create('c', {4,4}, {9,13,3,6, 13,11,7,6, 3,7,4,7, 6,6,7,10}); + auto exp = NDArrayFactory::create('c', {4,4}, {2.49369, 2.62176, 5.88386, 7.69905, -16.0588,-18.7319,-9.15007,-12.6164, 4.7247, 3.46252, 1.02038, -1.4533, 2.9279,-2.29178, 1.90139,-0.66187}); - ops::helpers::BiDiagonalUp object(matrix); - ops::helpers::HHsequence uSeq = object.makeHHsequence('u'); - uSeq.mulLeft(matrix); + ops::helpers::BiDiagonalUp object(matrix); + ops::helpers::HHsequence uSeq = object.makeHHsequence('u'); + uSeq.mulLeft(matrix); - ASSERT_TRUE(matrix.equalsTo(&exp)); + ASSERT_TRUE(matrix.equalsTo(&exp)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, HHsequence_test5) { - auto matrix = NDArrayFactory::create('c', {5,4}, {9,-13,3,6, 13,11,7,-6, 3,7,4,7, -6,6,7,10, 2,17,9,12}); - auto exp = NDArrayFactory::create('c', {5,4}, {4.52891, 8.09473,-2.73704,-13.0302, -11.0752, 7.41549,-3.75125,0.815252, -7.76818,-15.9102,-9.90869,-11.8677, 1.63942,-17.0312,-9.05102,-4.49088, -9.63311,0.540226,-1.52764, 5.79111}); + auto matrix = NDArrayFactory::create('c', {5,4}, {9,-13,3,6, 13,11,7,-6, 3,7,4,7, -6,6,7,10, 2,17,9,12}); + auto exp = NDArrayFactory::create('c', {5,4}, {4.52891, 8.09473,-2.73704,-13.0302, -11.0752, 7.41549,-3.75125,0.815252, -7.76818,-15.9102,-9.90869,-11.8677, 1.63942,-17.0312,-9.05102,-4.49088, -9.63311,0.540226,-1.52764, 5.79111}); - ops::helpers::BiDiagonalUp object(matrix); - ops::helpers::HHsequence uSeq = object.makeHHsequence('u'); - uSeq.mulLeft(matrix); + ops::helpers::BiDiagonalUp object(matrix); + ops::helpers::HHsequence uSeq = object.makeHHsequence('u'); + uSeq.mulLeft(matrix); - ASSERT_TRUE(matrix.equalsTo(&exp)); + ASSERT_TRUE(matrix.equalsTo(&exp)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, HHsequence_test6) { - auto matrix = NDArrayFactory::create('c', {5,4}, {9,-13,3,6, 13,11,7,-6, 3,7,4,7, -6,6,7,10, 2,17,9,12}); - auto matrix2 = NDArrayFactory::create('c',{6,4}, {9,-1,3,9, 10,11,-7,-5, 3,2,4,7, -1,6,7,19, 2,17,9,15, 2,17,-9,15}); - auto exp = NDArrayFactory::create('c', {6,4}, {9,-1,3,9, -4.43019,-15.1713, -3.2854,-7.65743, -9.39162,-7.03599, 8.03827, 9.48453, -2.97785, -16.424, 5.35265,-20.1171, -0.0436177, -13.118,-8.37287,-17.3012, -1.14074, 4.18282,-10.0914,-5.69014}); + auto matrix = NDArrayFactory::create('c', {5,4}, {9,-13,3,6, 13,11,7,-6, 3,7,4,7, -6,6,7,10, 2,17,9,12}); + auto matrix2 = NDArrayFactory::create('c',{6,4}, {9,-1,3,9, 10,11,-7,-5, 3,2,4,7, -1,6,7,19, 2,17,9,15, 2,17,-9,15}); + auto exp = NDArrayFactory::create('c', {6,4}, {9,-1,3,9, -4.43019,-15.1713, -3.2854,-7.65743, -9.39162,-7.03599, 8.03827, 9.48453, -2.97785, -16.424, 5.35265,-20.1171, -0.0436177, -13.118,-8.37287,-17.3012, -1.14074, 4.18282,-10.0914,-5.69014}); - ops::helpers::BiDiagonalUp object(matrix); - ops::helpers::HHsequence uSeq = object.makeHHsequence('u'); - uSeq.mulLeft(matrix2); + ops::helpers::BiDiagonalUp object(matrix); + ops::helpers::HHsequence uSeq = object.makeHHsequence('u'); + uSeq.mulLeft(matrix2); - ASSERT_TRUE(matrix2.equalsTo(&exp)); + ASSERT_TRUE(matrix2.equalsTo(&exp)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, HHsequence_test7) { - auto matrix = NDArrayFactory::create('c', {4,4}, {9,13,3,6, 13,11,7,6, 3,7,4,7, 6,6,7,10}); - auto exp = NDArrayFactory::create('c', {4,4}, {9,13,3,6,-5.90424,-2.30926,-0.447417, 3.05712, -10.504,-9.31339, -8.85493,-10.8886, -8.29494,-10.6737, -5.94895,-7.55591}); + auto matrix = NDArrayFactory::create('c', {4,4}, {9,13,3,6, 13,11,7,6, 3,7,4,7, 6,6,7,10}); + auto exp = NDArrayFactory::create('c', {4,4}, {9,13,3,6,-5.90424,-2.30926,-0.447417, 3.05712, -10.504,-9.31339, -8.85493,-10.8886, -8.29494,-10.6737, -5.94895,-7.55591}); - ops::helpers::BiDiagonalUp object(matrix); - ops::helpers::HHsequence vSeq = object.makeHHsequence('v'); - vSeq.mulLeft(matrix); + ops::helpers::BiDiagonalUp object(matrix); + ops::helpers::HHsequence vSeq = object.makeHHsequence('v'); + vSeq.mulLeft(matrix); - ASSERT_TRUE(matrix.equalsTo(&exp)); + ASSERT_TRUE(matrix.equalsTo(&exp)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, HHsequence_test8) { - auto matrix = NDArrayFactory::create('c', {5,4}, {9,-13,3,6, 13,11,7,-6, 3,7,4,7, -6,6,7,10, 2,17,9,12}); - auto exp = NDArrayFactory::create('c', {5,4}, {9, -13, 3, 6, 13, 11, 7, -6, -6.90831,-5.01113, 0.381677,0.440128, -0.80107,0.961605,-0.308019,-1.96153, -0.795985, 18.6538, 12.0731, 16.9988}); + auto matrix = NDArrayFactory::create('c', {5,4}, {9,-13,3,6, 13,11,7,-6, 3,7,4,7, -6,6,7,10, 2,17,9,12}); + auto exp = NDArrayFactory::create('c', {5,4}, {9, -13, 3, 6, 13, 11, 7, -6, -6.90831,-5.01113, 0.381677,0.440128, -0.80107,0.961605,-0.308019,-1.96153, -0.795985, 18.6538, 12.0731, 16.9988}); - ops::helpers::BiDiagonalUp object(matrix); - ops::helpers::HHsequence vSeq = object.makeHHsequence('v'); - vSeq.mulLeft(matrix); + ops::helpers::BiDiagonalUp object(matrix); + ops::helpers::HHsequence vSeq = object.makeHHsequence('v'); + vSeq.mulLeft(matrix); - ASSERT_TRUE(matrix.equalsTo(&exp)); + ASSERT_TRUE(matrix.equalsTo(&exp)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, HHsequence_test9) { - auto matrix = NDArrayFactory::create('c', {6,4}, {9,-13,3,6, 13,11,7,-6, 3,7,4,7, -6,6,7,10, 2,17,9,12 ,0,-15,10,2}); - auto exp = NDArrayFactory::create('c', {6,4}, {9, -13, 3, 6, 13, 11, 7, -6, 3, 7, 4, 7, 3.77597, 18.6226,-0.674868, 4.61365, 5.02738,-14.1486, -2.22877,-8.98245, -0.683766, 1.73722, 14.9859, 12.0843}); + auto matrix = NDArrayFactory::create('c', {6,4}, {9,-13,3,6, 13,11,7,-6, 3,7,4,7, -6,6,7,10, 2,17,9,12 ,0,-15,10,2}); + auto exp = NDArrayFactory::create('c', {6,4}, {9, -13, 3, 6, 13, 11, 7, -6, 3, 7, 4, 7, 3.77597, 18.6226,-0.674868, 4.61365, 5.02738,-14.1486, -2.22877,-8.98245, -0.683766, 1.73722, 14.9859, 12.0843}); - ops::helpers::BiDiagonalUp object(matrix); - ops::helpers::HHsequence vSeq = object.makeHHsequence('v'); - vSeq.mulLeft(matrix); + ops::helpers::BiDiagonalUp object(matrix); + ops::helpers::HHsequence vSeq = object.makeHHsequence('v'); + vSeq.mulLeft(matrix); - ASSERT_TRUE(matrix.equalsTo(&exp)); + ASSERT_TRUE(matrix.equalsTo(&exp)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, HHsequence_test10) { - auto matrix = NDArrayFactory::create('c', {4,4}, {9,13,3,6, 13,11,7,6, 3,7,4,7, 6,6,7,10}); - auto matrix2 = NDArrayFactory::create('c',{6,4}, {9,-1,3,9, 10,11,-7,-5, 3,2,4,7, -1,6,7,19, 2,17,9,15, 2,17,-9,15}); - auto exp = NDArrayFactory::create('c', {6,4}, {9, -1, 3, 9, 10, 11, -7, -5, 3, 2, 4, 7, 2.58863, 11.0295,-4.17483,-0.641012, -1.21892,-16.3151, 6.12049, -20.0239, -0.901799,-15.0389,-12.4944, -20.2394}); + auto matrix = NDArrayFactory::create('c', {4,4}, {9,13,3,6, 13,11,7,6, 3,7,4,7, 6,6,7,10}); + auto matrix2 = NDArrayFactory::create('c',{6,4}, {9,-1,3,9, 10,11,-7,-5, 3,2,4,7, -1,6,7,19, 2,17,9,15, 2,17,-9,15}); + auto exp = NDArrayFactory::create('c', {6,4}, {9, -1, 3, 9, 10, 11, -7, -5, 3, 2, 4, 7, 2.58863, 11.0295,-4.17483,-0.641012, -1.21892,-16.3151, 6.12049, -20.0239, -0.901799,-15.0389,-12.4944, -20.2394}); - ops::helpers::BiDiagonalUp object(matrix); - ops::helpers::HHsequence vSeq = object.makeHHsequence('v'); - vSeq.mulLeft(matrix2); + ops::helpers::BiDiagonalUp object(matrix); + ops::helpers::HHsequence vSeq = object.makeHHsequence('v'); + vSeq.mulLeft(matrix2); - ASSERT_TRUE(matrix2.equalsTo(&exp)); + ASSERT_TRUE(matrix2.equalsTo(&exp)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, HHsequence_test11) { - auto matrix = NDArrayFactory::create('c', {5,4}, {9,-13,3,6, 13,11,7,-6, 3,7,4,7, -6,6,7,10, 2,17,9,12}); - auto matrix2 = NDArrayFactory::create('c',{6,4}, {9,-1,3,9, 10,11,-7,-5, 3,2,4,7, -1,6,7,19, 2,17,9,15, 2,17,-9,15}); - auto exp = NDArrayFactory::create('c', {6,4}, {9, -1, 3, 9, 10, 11, -7, -5, 3, 2, 4, 7, 1.14934, 4.40257, 8.70127,-1.18824, 1.5132,0.220419,-11.6285,-11.7549, 2.32148, 24.3838,0.256531, 25.9116}); + auto matrix = NDArrayFactory::create('c', {5,4}, {9,-13,3,6, 13,11,7,-6, 3,7,4,7, -6,6,7,10, 2,17,9,12}); + auto matrix2 = NDArrayFactory::create('c',{6,4}, {9,-1,3,9, 10,11,-7,-5, 3,2,4,7, -1,6,7,19, 2,17,9,15, 2,17,-9,15}); + auto exp = NDArrayFactory::create('c', {6,4}, {9, -1, 3, 9, 10, 11, -7, -5, 3, 2, 4, 7, 1.14934, 4.40257, 8.70127,-1.18824, 1.5132,0.220419,-11.6285,-11.7549, 2.32148, 24.3838,0.256531, 25.9116}); - ops::helpers::BiDiagonalUp object(matrix); - ops::helpers::HHsequence vSeq = object.makeHHsequence('v'); - vSeq.mulLeft(matrix2); + ops::helpers::BiDiagonalUp object(matrix); + ops::helpers::HHsequence vSeq = object.makeHHsequence('v'); + vSeq.mulLeft(matrix2); - ASSERT_TRUE(matrix2.equalsTo(&exp)); + ASSERT_TRUE(matrix2.equalsTo(&exp)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, HHsequence_test12) { - auto matrix = NDArrayFactory::create('c', {5,3}, {9,-13,3, 13,11,7, 3,7,4, -6,6,7, 2,17,9}); - auto matrix2 = NDArrayFactory::create('c',{6,4}, {9,-1,3,9, 10,11,-7,-5, 3,2,4,7, -1,6,7,19, 2,17,9,15, 2,17,-9,15}); - auto exp = NDArrayFactory::create('c', {6,4}, {9, -1, 3, 9, 10, 11, -7, -5, 3, 2, 4, 7, -1, 6, 7, 19, -2.62252,-22.2914, 4.76743,-19.6689, -1.05943,-9.00514,-11.8013,-7.94571}); + auto matrix = NDArrayFactory::create('c', {5,3}, {9,-13,3, 13,11,7, 3,7,4, -6,6,7, 2,17,9}); + auto matrix2 = NDArrayFactory::create('c',{6,4}, {9,-1,3,9, 10,11,-7,-5, 3,2,4,7, -1,6,7,19, 2,17,9,15, 2,17,-9,15}); + auto exp = NDArrayFactory::create('c', {6,4}, {9, -1, 3, 9, 10, 11, -7, -5, 3, 2, 4, 7, -1, 6, 7, 19, -2.62252,-22.2914, 4.76743,-19.6689, -1.05943,-9.00514,-11.8013,-7.94571}); - ops::helpers::BiDiagonalUp object(matrix); - ops::helpers::HHsequence vSeq = object.makeHHsequence('v'); - vSeq.mulLeft(matrix2); + ops::helpers::BiDiagonalUp object(matrix); + ops::helpers::HHsequence vSeq = object.makeHHsequence('v'); + vSeq.mulLeft(matrix2); - ASSERT_TRUE(matrix2.equalsTo(&exp)); + ASSERT_TRUE(matrix2.equalsTo(&exp)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, HHsequence_test13) { - auto matrix = NDArrayFactory::create('c', {5,3}, {9,-13,3, 13,11,7, 3,7,4, -6,6,7, 2,17,9}); - auto matrix2 = NDArrayFactory::create('c',{6,4}, {9,-1,3,9, 10,11,-7,-5, 3,2,4,7, -1,6,7,19, 2,17,9,15, 2,17,-9,15}); - auto exp = NDArrayFactory::create('c', {6,4}, {9 , -1 , 3 , 9, -4.65167, 3.44652, 7.83593, 22.6899, -9.48514, -21.902, 5.66559,-13.0533, -0.343184, 15.2895, 7.2888, 14.0489, 0.289638,-1.87752, 3.944,-1.49707, -2.48845, 3.18285,-10.6685,0.406502}); + auto matrix = NDArrayFactory::create('c', {5,3}, {9,-13,3, 13,11,7, 3,7,4, -6,6,7, 2,17,9}); + auto matrix2 = NDArrayFactory::create('c',{6,4}, {9,-1,3,9, 10,11,-7,-5, 3,2,4,7, -1,6,7,19, 2,17,9,15, 2,17,-9,15}); + auto exp = NDArrayFactory::create('c', {6,4}, {9 , -1 , 3 , 9, -4.65167, 3.44652, 7.83593, 22.6899, -9.48514, -21.902, 5.66559,-13.0533, -0.343184, 15.2895, 7.2888, 14.0489, 0.289638,-1.87752, 3.944,-1.49707, -2.48845, 3.18285,-10.6685,0.406502}); - ops::helpers::BiDiagonalUp object(matrix); - ops::helpers::HHsequence uSeq = object.makeHHsequence('u'); - uSeq.mulLeft(matrix2); + ops::helpers::BiDiagonalUp object(matrix); + ops::helpers::HHsequence uSeq = object.makeHHsequence('u'); + uSeq.mulLeft(matrix2); - ASSERT_TRUE(matrix2.equalsTo(&exp)); + ASSERT_TRUE(matrix2.equalsTo(&exp)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, HHsequence_test14) { - auto matrix = NDArrayFactory::create('c', {5,3}, {9,-13,3, 13,11,7, 3,7,4, -6,6,7, 2,17,9}); - auto matrix2 = NDArrayFactory::create('c',{5,5}, {9,-1,3,9,10, 11,-7,-5,3, 2, 4,7,-1,6,7, 19,2,17,9,15, 2,17,-9,15,2}); - auto exp = NDArrayFactory::create('c', {5,5}, {1.78958, 8.06962,-6.13687, 4.36267, 1.06472, -14.9578, -8.1522, 1.30442,-18.3343,-13.2578, 13.5536, 5.50764, 15.7859, 7.60831, 11.7871, -1.3626,-0.634986, 7.60934, -2.1841, 5.62694, -13.0577, 15.1554, -7.6511, 3.76365,-5.87368}); + auto matrix = NDArrayFactory::create('c', {5,3}, {9,-13,3, 13,11,7, 3,7,4, -6,6,7, 2,17,9}); + auto matrix2 = NDArrayFactory::create('c',{5,5}, {9,-1,3,9,10, 11,-7,-5,3, 2, 4,7,-1,6,7, 19,2,17,9,15, 2,17,-9,15,2}); + auto exp = NDArrayFactory::create('c', {5,5}, {1.78958, 8.06962,-6.13687, 4.36267, 1.06472, -14.9578, -8.1522, 1.30442,-18.3343,-13.2578, 13.5536, 5.50764, 15.7859, 7.60831, 11.7871, -1.3626,-0.634986, 7.60934, -2.1841, 5.62694, -13.0577, 15.1554, -7.6511, 3.76365,-5.87368}); - ops::helpers::BiDiagonalUp object(matrix); - ops::helpers::HHsequence uSeq = object.makeHHsequence('u'); - uSeq.mulLeft(matrix2); + ops::helpers::BiDiagonalUp object(matrix); + ops::helpers::HHsequence uSeq = object.makeHHsequence('u'); + uSeq.mulLeft(matrix2); - ASSERT_TRUE(matrix2.equalsTo(&exp)); + ASSERT_TRUE(matrix2.equalsTo(&exp)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, HHsequence_test15) { - auto matrix = NDArrayFactory::create('c', {5,3}, {9,-13,3, 13,11,7, 3,7,4, -6,6,7, 2,17,9}); - auto matrix2 = NDArrayFactory::create('c',{5,5}, {9,-1,3,9,10, 11,-7,-5,3, 2, 4,7,-1,6,7, 19,2,17,9,15, 2,17,-9,15,2}); - auto exp = NDArrayFactory::create('c', {5,5}, {9, -1, 3, 9, 10, 11, -7, -5, 3, 2, 4, 7, -1, 6, 7, -9.26566,-16.4298, 1.64125,-17.3243,-7.70257, -16.7077, 4.80216,-19.1652,-2.42279,-13.0258}); + auto matrix = NDArrayFactory::create('c', {5,3}, {9,-13,3, 13,11,7, 3,7,4, -6,6,7, 2,17,9}); + auto matrix2 = NDArrayFactory::create('c',{5,5}, {9,-1,3,9,10, 11,-7,-5,3, 2, 4,7,-1,6,7, 19,2,17,9,15, 2,17,-9,15,2}); + auto exp = NDArrayFactory::create('c', {5,5}, {9, -1, 3, 9, 10, 11, -7, -5, 3, 2, 4, 7, -1, 6, 7, -9.26566,-16.4298, 1.64125,-17.3243,-7.70257, -16.7077, 4.80216,-19.1652,-2.42279,-13.0258}); - ops::helpers::BiDiagonalUp object(matrix); - ops::helpers::HHsequence vSeq = object.makeHHsequence('v'); - vSeq.mulLeft(matrix2); + ops::helpers::BiDiagonalUp object(matrix); + ops::helpers::HHsequence vSeq = object.makeHHsequence('v'); + vSeq.mulLeft(matrix2); - ASSERT_TRUE(matrix2.equalsTo(&exp)); + ASSERT_TRUE(matrix2.equalsTo(&exp)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, HHsequence_test16) { - auto matrix = NDArrayFactory::create('c', {5,5}, {9,-1,3,9,10, 11,-7,-5,3, 2, 4,7,-1,6,7, 19,2,17,9,15, 2,17,-9,15,2}); - auto matrix2 = NDArrayFactory::create('c', {10,10}); - matrix2 = 100.; - auto exp = NDArrayFactory::create('c',{5,5}, {-0.372742, 0.295145, 0.325359, 0.790947, 0.20615, -0.455573,-0.824221,-0.239444, 0.216163,-0.0951492, -0.165663, 0.285319, -0.18501, 0.130431, -0.916465, -0.7869, 0.245393, 0.116952,-0.541267, 0.117997, -0.0828315, 0.303191,-0.888202, 0.133021, 0.3076}); + auto matrix = NDArrayFactory::create('c', {5,5}, {9,-1,3,9,10, 11,-7,-5,3, 2, 4,7,-1,6,7, 19,2,17,9,15, 2,17,-9,15,2}); + auto matrix2 = NDArrayFactory::create('c', {10,10}); + matrix2 = 100.; + auto exp = NDArrayFactory::create('c',{5,5}, {-0.372742, 0.295145, 0.325359, 0.790947, 0.20615, -0.455573,-0.824221,-0.239444, 0.216163,-0.0951492, -0.165663, 0.285319, -0.18501, 0.130431, -0.916465, -0.7869, 0.245393, 0.116952,-0.541267, 0.117997, -0.0828315, 0.303191,-0.888202, 0.133021, 0.3076}); - ops::helpers::BiDiagonalUp object(matrix); - ops::helpers::HHsequence uSeq = object.makeHHsequence('u'); - uSeq.applyTo(matrix2); + ops::helpers::BiDiagonalUp object(matrix); + ops::helpers::HHsequence uSeq = object.makeHHsequence('u'); + uSeq.applyTo(matrix2); - ASSERT_TRUE(matrix2.equalsTo(&exp)); + ASSERT_TRUE(matrix2.equalsTo(&exp)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, HHsequence_test17) { - auto matrix = NDArrayFactory::create('c', {5,5}, {9,-1,3,9,10, 11,-7,-5,3, 2, 4,7,-1,6,7, 19,2,17,9,15, 2,17,-9,15,2}); - auto matrix2 = NDArrayFactory::create('c', {10,10}); - matrix2 = 100.; - auto exp = NDArrayFactory::create('c',{5,5}, {1, 0, 0, 0, 0, 0,-0.022902, 0.986163, 0.0411914, 0.158935, 0, -0.44659, 0.021539, 0.797676,-0.404731, 0,-0.554556, 0.103511, -0.600701, -0.56649, 0,-0.701784,-0.127684,-0.0342758, 0.700015}); + auto matrix = NDArrayFactory::create('c', {5,5}, {9,-1,3,9,10, 11,-7,-5,3, 2, 4,7,-1,6,7, 19,2,17,9,15, 2,17,-9,15,2}); + auto matrix2 = NDArrayFactory::create('c', {10,10}); + matrix2 = 100.; + auto exp = NDArrayFactory::create('c',{5,5}, {1, 0, 0, 0, 0, 0,-0.022902, 0.986163, 0.0411914, 0.158935, 0, -0.44659, 0.021539, 0.797676,-0.404731, 0,-0.554556, 0.103511, -0.600701, -0.56649, 0,-0.701784,-0.127684,-0.0342758, 0.700015}); - ops::helpers::BiDiagonalUp object(matrix); - ops::helpers::HHsequence vSeq = object.makeHHsequence('v'); - vSeq.applyTo(matrix2); + ops::helpers::BiDiagonalUp object(matrix); + ops::helpers::HHsequence vSeq = object.makeHHsequence('v'); + vSeq.applyTo(matrix2); - ASSERT_TRUE(matrix2.equalsTo(&exp)); + ASSERT_TRUE(matrix2.equalsTo(&exp)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, HHsequence_test18) { - auto matrix = NDArrayFactory::create('c',{6,4}, {9,-1,3,9, 10,11,-7,-5, 3,2,4,7, -1,6,7,19, 2,17,9,15, 2,17,-9,15}); - auto matrix2 = NDArrayFactory::create('c', {10,10}); - matrix2 = 100.; - auto exp = NDArrayFactory::create('c',{6,6}, {-0.637993, 0.190621,-0.524821,-0.312287, 0.407189, 0.133659, -0.708881, 0.0450803, 0.47462, 0.232701,-0.204602,-0.417348, -0.212664,-0.0405892,-0.297123,0.0240276,-0.821557, 0.435099, 0.0708881, -0.432466, -0.49252,-0.145004,-0.199312,-0.710367, -0.141776, -0.56468,-0.180549, 0.706094, 0.274317, 0.233707, -0.141776, -0.673865, 0.368567,-0.572848,0.0490246, 0.243733}); + auto matrix = NDArrayFactory::create('c',{6,4}, {9,-1,3,9, 10,11,-7,-5, 3,2,4,7, -1,6,7,19, 2,17,9,15, 2,17,-9,15}); + auto matrix2 = NDArrayFactory::create('c', {10,10}); + matrix2 = 100.; + auto exp = NDArrayFactory::create('c',{6,6}, {-0.637993, 0.190621,-0.524821,-0.312287, 0.407189, 0.133659, -0.708881, 0.0450803, 0.47462, 0.232701,-0.204602,-0.417348, -0.212664,-0.0405892,-0.297123,0.0240276,-0.821557, 0.435099, 0.0708881, -0.432466, -0.49252,-0.145004,-0.199312,-0.710367, -0.141776, -0.56468,-0.180549, 0.706094, 0.274317, 0.233707, -0.141776, -0.673865, 0.368567,-0.572848,0.0490246, 0.243733}); - ops::helpers::BiDiagonalUp object(matrix); - ops::helpers::HHsequence uSeq = object.makeHHsequence('u'); - uSeq.applyTo(matrix2); + ops::helpers::BiDiagonalUp object(matrix); + ops::helpers::HHsequence uSeq = object.makeHHsequence('u'); + uSeq.applyTo(matrix2); - ASSERT_TRUE(matrix2.equalsTo(&exp)); + ASSERT_TRUE(matrix2.equalsTo(&exp)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, HHsequence_test19) { - auto matrix = NDArrayFactory::create('c',{6,4}, {9,-1,3,9, 10,11,-7,-5, 3,2,4,7, -1,6,7,19, 2,17,9,15, 2,17,-9,15}); - auto matrix2 = NDArrayFactory::create('c', {10,10}); - matrix2 = 100.; - auto exp = NDArrayFactory::create('c',{4,4}, {1, 0, 0, 0, 0,-0.859586, 0.28601, -0.42345, 0, 0.19328,-0.585133,-0.787567, 0,-0.473027,-0.758826, 0.447693}); + auto matrix = NDArrayFactory::create('c',{6,4}, {9,-1,3,9, 10,11,-7,-5, 3,2,4,7, -1,6,7,19, 2,17,9,15, 2,17,-9,15}); + auto matrix2 = NDArrayFactory::create('c', {10,10}); + matrix2 = 100.; + auto exp = NDArrayFactory::create('c',{4,4}, {1, 0, 0, 0, 0,-0.859586, 0.28601, -0.42345, 0, 0.19328,-0.585133,-0.787567, 0,-0.473027,-0.758826, 0.447693}); - ops::helpers::BiDiagonalUp object(matrix); - ops::helpers::HHsequence vSeq = object.makeHHsequence('v'); - vSeq.applyTo(matrix2); + ops::helpers::BiDiagonalUp object(matrix); + ops::helpers::HHsequence vSeq = object.makeHHsequence('v'); + vSeq.applyTo(matrix2); - ASSERT_TRUE(matrix2.equalsTo(&exp)); + ASSERT_TRUE(matrix2.equalsTo(&exp)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, HHcolPivQR_1) { - auto matrix1 = NDArrayFactory::create('c', {5,6}, {12 ,20 ,19 ,-18 ,-6 ,3 ,6 ,2 ,-7 ,-7 ,14 ,8 ,18 ,-17 ,18 ,-14 ,-15 ,1 ,2 ,2 ,-3 ,-18 ,8 ,-17 ,-19 ,12 ,18 ,6 ,-2 ,-17}); + auto matrix1 = NDArrayFactory::create('c', {5,6}, {12 ,20 ,19 ,-18 ,-6 ,3 ,6 ,2 ,-7 ,-7 ,14 ,8 ,18 ,-17 ,18 ,-14 ,-15 ,1 ,2 ,2 ,-3 ,-18 ,8 ,-17 ,-19 ,12 ,18 ,6 ,-2 ,-17}); - auto expQR = NDArrayFactory::create('c', {5,6}, {-32.6649659, -4.9594419, -8.2657365, 7.2248659, 16.5927006, 11.7251002, -0.1354883, -29.0586293, 10.9775804, -14.6886248, 4.1884104, 20.7115773, 0.3483986, 0.3236753, 25.5376258, 1.6432380, 9.6395914, -9.0237996, -0.0580664, 0.0798999, -0.0799029, 19.5280665, -4.9773587, 16.0968604, 0.3483986, -0.6667832, 0.0252425, 0.0159188, 10.6978354, -4.6919842}); - auto expCoeffs = NDArrayFactory::create('c', {1,5}, {1.58166, 1.28555, 1.98605, 1.99949, 0}); - auto expPermut = NDArrayFactory::create('c', {6,6}, {0,1,0,0,0,0, 0,0,1,0,0,0, 1,0,0,0,0,0, 0,0,0,0,0,1, 0,0,0,0,1,0, 0,0,0,1,0,0}); + auto expQR = NDArrayFactory::create('c', {5,6}, {-32.6649659, -4.9594419, -8.2657365, 7.2248659, 16.5927006, 11.7251002, -0.1354883, -29.0586293, 10.9775804, -14.6886248, 4.1884104, 20.7115773, 0.3483986, 0.3236753, 25.5376258, 1.6432380, 9.6395914, -9.0237996, -0.0580664, 0.0798999, -0.0799029, 19.5280665, -4.9773587, 16.0968604, 0.3483986, -0.6667832, 0.0252425, 0.0159188, 10.6978354, -4.6919842}); + auto expCoeffs = NDArrayFactory::create('c', {1,5}, {1.58166, 1.28555, 1.98605, 1.99949, 0}); + auto expPermut = NDArrayFactory::create('c', {6,6}, {0,1,0,0,0,0, 0,0,1,0,0,0, 1,0,0,0,0,0, 0,0,0,0,0,1, 0,0,0,0,1,0, 0,0,0,1,0,0}); - ops::helpers::HHcolPivQR qr(matrix1); + ops::helpers::HHcolPivQR qr(matrix1); - ASSERT_TRUE(expQR.equalsTo(&qr._qr)); - ASSERT_TRUE(expCoeffs.equalsTo(&qr._coeffs)); - ASSERT_TRUE(expPermut.equalsTo(&qr._permut)); + ASSERT_TRUE(expQR.equalsTo(&qr._qr)); + ASSERT_TRUE(expCoeffs.equalsTo(&qr._coeffs)); + ASSERT_TRUE(expPermut.equalsTo(&qr._permut)); - ASSERT_TRUE(expQR.isSameShapeStrict(qr._qr)); - ASSERT_TRUE(expCoeffs.isSameShapeStrict(qr._coeffs)); - ASSERT_TRUE(expPermut.isSameShapeStrict(qr._permut)); + ASSERT_TRUE(expQR.isSameShapeStrict(qr._qr)); + ASSERT_TRUE(expCoeffs.isSameShapeStrict(qr._coeffs)); + ASSERT_TRUE(expPermut.isSameShapeStrict(qr._permut)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, HHcolPivQR_2) { - auto matrix1 = NDArrayFactory::create('c', {6,6}, {-10 ,-16 ,-20 ,13 ,20 ,-10 ,-9 ,-1 ,-7 ,-20 ,-4 ,20 ,-11 ,19 ,-5 ,-18 ,12 ,-19 ,18 ,-18 ,17 ,-10 ,-19 ,14 ,-2 ,-7 ,-17 ,-14 ,-4 ,-16 ,18 ,-6 ,-18 ,1 ,-15 ,-12}); + auto matrix1 = NDArrayFactory::create('c', {6,6}, {-10 ,-16 ,-20 ,13 ,20 ,-10 ,-9 ,-1 ,-7 ,-20 ,-4 ,20 ,-11 ,19 ,-5 ,-18 ,12 ,-19 ,18 ,-18 ,17 ,-10 ,-19 ,14 ,-2 ,-7 ,-17 ,-14 ,-4 ,-16 ,18 ,-6 ,-18 ,1 ,-15 ,-12}); - auto expQR = NDArrayFactory::create('c', {6,6}, {38.1707, -3.03898, 5.16103, 23.0805, -7.57126, -13.885, -0.41519, 34.3623, 3.77403, 2.62327, -8.17784, 9.10312, 0.394431, 0.509952,-30.2179, -6.78341, 12.8421, 28.5491, -0.290633, 0.111912,0.450367, 28.1139, 15.5195, 2.60562, 0.332152, 0.405161,0.308163,0.0468127, 22.294,-2.94931, 0.249114,0.0627956,0.657873, 0.76767,-0.752594,-7.46986}); - auto expCoeffs = NDArrayFactory::create('c', {1,6}, {1.26198, 1.38824, 1.15567, 1.25667, 1.27682, 0}); - auto expPermut = NDArrayFactory::create('c', {6,6}, {0,0,1,0,0,0, 0,0,0,0,1,0, 0,0,0,1,0,0, 0,1,0,0,0,0, 0,0,0,0,0,1, 1,0,0,0,0,0}); + auto expQR = NDArrayFactory::create('c', {6,6}, {38.1707, -3.03898, 5.16103, 23.0805, -7.57126, -13.885, -0.41519, 34.3623, 3.77403, 2.62327, -8.17784, 9.10312, 0.394431, 0.509952,-30.2179, -6.78341, 12.8421, 28.5491, -0.290633, 0.111912,0.450367, 28.1139, 15.5195, 2.60562, 0.332152, 0.405161,0.308163,0.0468127, 22.294,-2.94931, 0.249114,0.0627956,0.657873, 0.76767,-0.752594,-7.46986}); + auto expCoeffs = NDArrayFactory::create('c', {1,6}, {1.26198, 1.38824, 1.15567, 1.25667, 1.27682, 0}); + auto expPermut = NDArrayFactory::create('c', {6,6}, {0,0,1,0,0,0, 0,0,0,0,1,0, 0,0,0,1,0,0, 0,1,0,0,0,0, 0,0,0,0,0,1, 1,0,0,0,0,0}); - ops::helpers::HHcolPivQR qr(matrix1); + ops::helpers::HHcolPivQR qr(matrix1); - ASSERT_TRUE(expQR.equalsTo(&qr._qr)); - ASSERT_TRUE(expCoeffs.equalsTo(&qr._coeffs)); - ASSERT_TRUE(expPermut.equalsTo(&qr._permut)); + ASSERT_TRUE(expQR.equalsTo(&qr._qr)); + ASSERT_TRUE(expCoeffs.equalsTo(&qr._coeffs)); + ASSERT_TRUE(expPermut.equalsTo(&qr._permut)); - ASSERT_TRUE(expQR.isSameShapeStrict(qr._qr)); - ASSERT_TRUE(expCoeffs.isSameShapeStrict(qr._coeffs)); - ASSERT_TRUE(expPermut.isSameShapeStrict(qr._permut)); + ASSERT_TRUE(expQR.isSameShapeStrict(qr._qr)); + ASSERT_TRUE(expCoeffs.isSameShapeStrict(qr._coeffs)); + ASSERT_TRUE(expPermut.isSameShapeStrict(qr._permut)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, HHcolPivQR_3) { - NDArray matrix1('c', {6,5}, {12 ,20 ,19 ,-18 ,-6 ,3 ,6 ,2 ,-7 ,-7 ,14 ,8 ,18 ,-17 ,18 ,-14 ,-15 ,1 ,2 ,2 ,-3 ,-18 ,8 ,-17 ,-19 ,12 ,18 ,6 ,-2 ,-17}); + NDArray matrix1('c', {6,5}, {12 ,20 ,19 ,-18 ,-6 ,3 ,6 ,2 ,-7 ,-7 ,14 ,8 ,18 ,-17 ,18 ,-14 ,-15 ,1 ,2 ,2 ,-3 ,-18 ,8 ,-17 ,-19 ,12 ,18 ,6 ,-2 ,-17}); - auto expQR = NDArrayFactory::create('c', {6,5}, {-37.054 , 0.323852 , 8.04231 , -22.9395 ,-13.089, 0.105164, 32.6021, 6.42277, -0.262898,-1.58766, 0.140218, -0.485058, 29.2073, -9.92301,-23.7111, -0.262909,-0.00866538, 0.103467, 8.55831,-1.86455, -0.315491, 0.539207, 0.40754,-0.0374124,-7.10401, 0.315491, 0.385363,-0.216459, -0.340008,0.628595}); - auto expCoeffs = NDArrayFactory::create('c', {1,5}, {1.53975, 1.19431, 1.63446, 1.7905, 1.43356}); - auto expPermut = NDArrayFactory::create('c', {5,5}, {0,0,0,1,0, 1,0,0,0,0, 0,0,0,0,1, 0,0,1,0,0, 0,1,0,0,0}); + auto expQR = NDArrayFactory::create('c', {6,5}, {-37.054 , 0.323852 , 8.04231 , -22.9395 ,-13.089, 0.105164, 32.6021, 6.42277, -0.262898,-1.58766, 0.140218, -0.485058, 29.2073, -9.92301,-23.7111, -0.262909,-0.00866538, 0.103467, 8.55831,-1.86455, -0.315491, 0.539207, 0.40754,-0.0374124,-7.10401, 0.315491, 0.385363,-0.216459, -0.340008,0.628595}); + auto expCoeffs = NDArrayFactory::create('c', {1,5}, {1.53975, 1.19431, 1.63446, 1.7905, 1.43356}); + auto expPermut = NDArrayFactory::create('c', {5,5}, {0,0,0,1,0, 1,0,0,0,0, 0,0,0,0,1, 0,0,1,0,0, 0,1,0,0,0}); - ops::helpers::HHcolPivQR qr(matrix1); + ops::helpers::HHcolPivQR qr(matrix1); - ASSERT_TRUE(expQR.equalsTo(&qr._qr)); - ASSERT_TRUE(expCoeffs.equalsTo(&qr._coeffs)); - ASSERT_TRUE(expPermut.equalsTo(&qr._permut)); + ASSERT_TRUE(expQR.equalsTo(&qr._qr)); + ASSERT_TRUE(expCoeffs.equalsTo(&qr._coeffs)); + ASSERT_TRUE(expPermut.equalsTo(&qr._permut)); - ASSERT_TRUE(expQR.isSameShapeStrict(qr._qr)); - ASSERT_TRUE(expCoeffs.isSameShapeStrict(qr._coeffs)); - ASSERT_TRUE(expPermut.isSameShapeStrict(qr._permut)); + ASSERT_TRUE(expQR.isSameShapeStrict(qr._qr)); + ASSERT_TRUE(expCoeffs.isSameShapeStrict(qr._coeffs)); + ASSERT_TRUE(expPermut.isSameShapeStrict(qr._permut)); } @@ -543,593 +543,593 @@ TEST_F(HelpersTests1, HHcolPivQR_3) { /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, JacobiSVD_test1) { - auto matrix3 = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); - auto left = NDArrayFactory::create('c', {2,2}); - auto right = NDArrayFactory::create('c', {2,2}); + auto matrix3 = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); + auto left = NDArrayFactory::create('c', {2,2}); + auto right = NDArrayFactory::create('c', {2,2}); - auto expLeft = NDArrayFactory::create('c', {2,2}, {0.972022, 0.23489, -0.23489, 0.972022}); - auto expRight = NDArrayFactory::create('c', {2,2}, {0.827657, 0.561234, -0.561234, 0.827657}); + auto expLeft = NDArrayFactory::create('c', {2,2}, {0.972022, 0.23489, -0.23489, 0.972022}); + auto expRight = NDArrayFactory::create('c', {2,2}, {0.827657, 0.561234, -0.561234, 0.827657}); - ops::helpers::JacobiSVD::svd2x2(matrix3, 1, 3, left, right); + ops::helpers::JacobiSVD::svd2x2(matrix3, 1, 3, left, right); - ASSERT_TRUE(expLeft.equalsTo(&left)); - ASSERT_TRUE(expRight.equalsTo(&right)); + ASSERT_TRUE(expLeft.equalsTo(&left)); + ASSERT_TRUE(expRight.equalsTo(&right)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, JacobiSVD_test2) { - auto matrix3 = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); - auto matrix4 = NDArrayFactory::create('c', {5,5}, {12 ,20 ,19 ,-18 ,-6 ,3 ,6 ,2 ,-7 ,-7 ,14 ,8 ,18 ,-17 ,18 ,-14 ,-15 ,1 ,2 ,2 ,-3 ,-18 ,8 ,-17 ,-19}); - auto matrix5 = NDArrayFactory::create('c', {5,5}, {3 ,-8 ,5 ,7 ,-8 ,4 ,-19 ,-12 ,-4 ,-5 ,-11 ,19 ,-2 ,-7 ,1 ,16 ,-5 ,10 ,19 ,-19 ,0 ,-20 ,0 ,-8 ,-13}); + auto matrix3 = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); + auto matrix4 = NDArrayFactory::create('c', {5,5}, {12 ,20 ,19 ,-18 ,-6 ,3 ,6 ,2 ,-7 ,-7 ,14 ,8 ,18 ,-17 ,18 ,-14 ,-15 ,1 ,2 ,2 ,-3 ,-18 ,8 ,-17 ,-19}); + auto matrix5 = NDArrayFactory::create('c', {5,5}, {3 ,-8 ,5 ,7 ,-8 ,4 ,-19 ,-12 ,-4 ,-5 ,-11 ,19 ,-2 ,-7 ,1 ,16 ,-5 ,10 ,19 ,-19 ,0 ,-20 ,0 ,-8 ,-13}); - auto exp3 = NDArrayFactory::create('c', {5,5}, {-18, 1, 19, -7, 1, -0.609208,19.6977, 8.63044,-11.9811,-4.67059, -2, -11, 8, 2, -6, 3.55371, 0,-12.5903, 7.51356, -5.5844, 16, 15, -3, 7, 0}); - auto exp4 = NDArrayFactory::create('c', {5,5}, {12, -10.9657,19,24.5714, -6, 3, -2.6399, 2,8.83351, -7, 14,-0.406138,18,18.7839, 18, -14, 12.8949, 1,-7.9197, 2, -3, 23.353, 8, 8.2243,-19}); - auto exp5 = NDArrayFactory::create('c', {5,5}, {3 ,-8 ,5 ,7 ,-8 ,4 ,-19 ,-12 ,-4 ,-5 ,-11 ,19 ,-2 ,-7 ,1 ,16 ,-5 ,10 ,19 ,-19 ,0 ,-20 ,0 ,-8 ,-13}); + auto exp3 = NDArrayFactory::create('c', {5,5}, {-18, 1, 19, -7, 1, -0.609208,19.6977, 8.63044,-11.9811,-4.67059, -2, -11, 8, 2, -6, 3.55371, 0,-12.5903, 7.51356, -5.5844, 16, 15, -3, 7, 0}); + auto exp4 = NDArrayFactory::create('c', {5,5}, {12, -10.9657,19,24.5714, -6, 3, -2.6399, 2,8.83351, -7, 14,-0.406138,18,18.7839, 18, -14, 12.8949, 1,-7.9197, 2, -3, 23.353, 8, 8.2243,-19}); + auto exp5 = NDArrayFactory::create('c', {5,5}, {3 ,-8 ,5 ,7 ,-8 ,4 ,-19 ,-12 ,-4 ,-5 ,-11 ,19 ,-2 ,-7 ,1 ,16 ,-5 ,10 ,19 ,-19 ,0 ,-20 ,0 ,-8 ,-13}); - ops::helpers::JacobiSVD jac(matrix3, true, true, true); - jac._m = matrix3; - jac._u = matrix4; - jac._v = matrix5; + ops::helpers::JacobiSVD jac(matrix3, true, true, true); + jac._m = matrix3; + jac._u = matrix4; + jac._v = matrix5; - double maxElem; - bool result = jac.isBlock2x2NotDiag(matrix3, 1, 3, maxElem); + double maxElem; + bool result = jac.isBlock2x2NotDiag(matrix3, 1, 3, maxElem); - // ASSERT_NEAR(maxElem, 19.69772, 1e-5); - ASSERT_TRUE(exp3.equalsTo(&matrix3)); - ASSERT_TRUE(exp4.equalsTo(&jac._u)); - ASSERT_TRUE(exp5.equalsTo(&jac._v)); + // ASSERT_NEAR(maxElem, 19.69772, 1e-5); + ASSERT_TRUE(exp3.equalsTo(&matrix3)); + ASSERT_TRUE(exp4.equalsTo(&jac._u)); + ASSERT_TRUE(exp5.equalsTo(&jac._v)); - ASSERT_TRUE(result); + ASSERT_TRUE(result); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, JacobiSVD_test3) { - auto matrix = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); - auto rotation = NDArrayFactory::create('c', {2,2}, {0.2, math::nd4j_sqrt(0.6), -math::nd4j_sqrt(0.6), 0.2}); + auto matrix = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); + auto rotation = NDArrayFactory::create('c', {2,2}, {0.2, math::nd4j_sqrt(0.6), -math::nd4j_sqrt(0.6), 0.2}); - auto expected = NDArrayFactory::create('c', {5,5}, {-18, 1, 19, -7, 1, -1.14919,-12.1206,3.59677, 4.34919,-4.24758, -1.94919, 11.7427,11.6698,-10.4444,-2.74919, -3, -8, 8, -2, 7, 16, 15, -3, 7, 0}); + auto expected = NDArrayFactory::create('c', {5,5}, {-18, 1, 19, -7, 1, -1.14919,-12.1206,3.59677, 4.34919,-4.24758, -1.94919, 11.7427,11.6698,-10.4444,-2.74919, -3, -8, 8, -2, 7, 16, 15, -3, 7, 0}); - ops::helpers::JacobiSVD::mulRotationOnLeft(1, 2, matrix, rotation); + ops::helpers::JacobiSVD::mulRotationOnLeft(1, 2, matrix, rotation); - ASSERT_TRUE(expected.equalsTo(&matrix)); + ASSERT_TRUE(expected.equalsTo(&matrix)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, JacobiSVD_test4) { - auto matrix = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); - auto rotation = NDArrayFactory::create('c', {2,2}, {0.2, math::nd4j_sqrt(0.6), -math::nd4j_sqrt(0.6), 0.2}); + auto matrix = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); + auto rotation = NDArrayFactory::create('c', {2,2}, {0.2, math::nd4j_sqrt(0.6), -math::nd4j_sqrt(0.6), 0.2}); - auto expected = NDArrayFactory::create('c', {5,5}, {-18, 1, 19, -7, 1, 1.94919, 4.92056,-8.79677,1.25081, 5.04758, 1.14919,-16.1427,-8.46976,11.2444,0.349193, -3, -8, 8, -2, 7, 16, 15, -3, 7, 0}); + auto expected = NDArrayFactory::create('c', {5,5}, {-18, 1, 19, -7, 1, 1.94919, 4.92056,-8.79677,1.25081, 5.04758, 1.14919,-16.1427,-8.46976,11.2444,0.349193, -3, -8, 8, -2, 7, 16, 15, -3, 7, 0}); - ops::helpers::JacobiSVD::mulRotationOnLeft(2, 1, matrix, rotation); + ops::helpers::JacobiSVD::mulRotationOnLeft(2, 1, matrix, rotation); - ASSERT_TRUE(expected.equalsTo(&matrix)); + ASSERT_TRUE(expected.equalsTo(&matrix)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, JacobiSVD_test5) { - auto matrix = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); - auto rotation = NDArrayFactory::create('c', {2,2}, {0.2, math::nd4j_sqrt(0.6), -math::nd4j_sqrt(0.6), 0.2}); + auto matrix = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); + auto rotation = NDArrayFactory::create('c', {2,2}, {0.2, math::nd4j_sqrt(0.6), -math::nd4j_sqrt(0.6), 0.2}); - auto expected = NDArrayFactory::create('c', {5,5}, {-18, 1, 19, -7, 1, 2, -18, -13, 14, 2, 1.14919,6.32056,-4.59677,-1.14919, 3.44758, -3, -8, 8, -2, 7, 16, 15, -3, 7, 0}); + auto expected = NDArrayFactory::create('c', {5,5}, {-18, 1, 19, -7, 1, 2, -18, -13, 14, 2, 1.14919,6.32056,-4.59677,-1.14919, 3.44758, -3, -8, 8, -2, 7, 16, 15, -3, 7, 0}); - ops::helpers::JacobiSVD::mulRotationOnLeft(2, 2, matrix, rotation); + ops::helpers::JacobiSVD::mulRotationOnLeft(2, 2, matrix, rotation); - ASSERT_TRUE(expected.equalsTo(&matrix)); + ASSERT_TRUE(expected.equalsTo(&matrix)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, JacobiSVD_test6) { - auto matrix = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); - auto rotation = NDArrayFactory::create('c', {2,2}, {0.2, math::nd4j_sqrt(0.6), -math::nd4j_sqrt(0.6), 0.2}); + auto matrix = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); + auto rotation = NDArrayFactory::create('c', {2,2}, {0.2, math::nd4j_sqrt(0.6), -math::nd4j_sqrt(0.6), 0.2}); - auto expected = NDArrayFactory::create('c', {5,5}, {-18,-14.5173, 4.5746,-7, 1, 2, 6.46976,-16.5427,14, 2, -2,-8.39677,-6.92056, 2,-6, -3,-7.79677,-4.59677,-2, 7, 16, 5.32379, 11.019, 7, 0}); + auto expected = NDArrayFactory::create('c', {5,5}, {-18,-14.5173, 4.5746,-7, 1, 2, 6.46976,-16.5427,14, 2, -2,-8.39677,-6.92056, 2,-6, -3,-7.79677,-4.59677,-2, 7, 16, 5.32379, 11.019, 7, 0}); - ops::helpers::JacobiSVD::mulRotationOnRight(1, 2, matrix, rotation); + ops::helpers::JacobiSVD::mulRotationOnRight(1, 2, matrix, rotation); - ASSERT_TRUE(expected.equalsTo(&matrix)); + ASSERT_TRUE(expected.equalsTo(&matrix)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, JacobiSVD_test7) { - auto matrix = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); - auto rotation = NDArrayFactory::create('c', {2,2}, {0.2, math::nd4j_sqrt(0.6), -math::nd4j_sqrt(0.6), 0.2}); + auto matrix = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); + auto rotation = NDArrayFactory::create('c', {2,2}, {0.2, math::nd4j_sqrt(0.6), -math::nd4j_sqrt(0.6), 0.2}); - auto expected = NDArrayFactory::create('c', {5,5}, {-18, 14.9173, 3.0254,-7, 1, 2,-13.6698,11.3427,14, 2, -2, 3.99677,10.1206, 2,-6, -3, 4.59677,7.79677,-2, 7, 16, 0.67621,-12.219, 7, 0}); + auto expected = NDArrayFactory::create('c', {5,5}, {-18, 14.9173, 3.0254,-7, 1, 2,-13.6698,11.3427,14, 2, -2, 3.99677,10.1206, 2,-6, -3, 4.59677,7.79677,-2, 7, 16, 0.67621,-12.219, 7, 0}); - ops::helpers::JacobiSVD::mulRotationOnRight(2, 1, matrix, rotation); + ops::helpers::JacobiSVD::mulRotationOnRight(2, 1, matrix, rotation); - ASSERT_TRUE(expected.equalsTo(&matrix)); + ASSERT_TRUE(expected.equalsTo(&matrix)); } ////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, JacobiSVD_test8) { - auto matrix = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); - auto rotation = NDArrayFactory::create('c', {2,2}, {0.2, math::nd4j_sqrt(0.6), -math::nd4j_sqrt(0.6), 0.2}); + auto matrix = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); + auto rotation = NDArrayFactory::create('c', {2,2}, {0.2, math::nd4j_sqrt(0.6), -math::nd4j_sqrt(0.6), 0.2}); - auto expected = NDArrayFactory::create('c', {5,5}, {-18, 1, 18.5173,-7, 1, 2,-18,-12.6698,14, 2, -2,-11, 7.79677, 2,-6, -3, -8, 7.79677,-2, 7, 16, 15,-2.92379, 7, 0}); + auto expected = NDArrayFactory::create('c', {5,5}, {-18, 1, 18.5173,-7, 1, 2,-18,-12.6698,14, 2, -2,-11, 7.79677, 2,-6, -3, -8, 7.79677,-2, 7, 16, 15,-2.92379, 7, 0}); - ops::helpers::JacobiSVD::mulRotationOnRight(2, 2, matrix, rotation); + ops::helpers::JacobiSVD::mulRotationOnRight(2, 2, matrix, rotation); - ASSERT_TRUE(expected.equalsTo(&matrix)); + ASSERT_TRUE(expected.equalsTo(&matrix)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, JacobiSVD_test9) { - auto matrix = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); + auto matrix = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); - auto expS = NDArrayFactory::create('c', {5,1}, {35.7975, 29.1924, 11.1935, 9.2846, 6.77071}); - auto expU = NDArrayFactory::create('c', {5,5}, {0.744855,0.0686476, 0.079663,0.0889877, 0.65285, -0.386297,-0.760021,0.00624688, 0.156774, 0.498522, 0.186491,-0.322427, 0.773083,-0.468826,-0.209299, 0.246053,-0.215594, 0.240942, 0.821793,-0.399475, -0.447933, 0.516928, 0.581295, 0.269001, 0.349106}); - auto expV = NDArrayFactory::create('c', {5,5}, {-0.627363, 0.23317, 0.501211, 0.160272, -0.524545, -0.0849394, 0.917171,-0.155876,-0.0124053, 0.356555, 0.66983, 0.182569, 0.696897, 0.179807,0.000864568, -0.387647, -0.264316, 0.416597, 0.0941014, 0.772955, 0.0160818,-0.0351459,-0.255484, 0.965905, 0.0161524}); + auto expS = NDArrayFactory::create('c', {5,1}, {35.7975, 29.1924, 11.1935, 9.2846, 6.77071}); + auto expU = NDArrayFactory::create('c', {5,5}, {0.744855,0.0686476, 0.079663,0.0889877, 0.65285, -0.386297,-0.760021,0.00624688, 0.156774, 0.498522, 0.186491,-0.322427, 0.773083,-0.468826,-0.209299, 0.246053,-0.215594, 0.240942, 0.821793,-0.399475, -0.447933, 0.516928, 0.581295, 0.269001, 0.349106}); + auto expV = NDArrayFactory::create('c', {5,5}, {-0.627363, 0.23317, 0.501211, 0.160272, -0.524545, -0.0849394, 0.917171,-0.155876,-0.0124053, 0.356555, 0.66983, 0.182569, 0.696897, 0.179807,0.000864568, -0.387647, -0.264316, 0.416597, 0.0941014, 0.772955, 0.0160818,-0.0351459,-0.255484, 0.965905, 0.0161524}); - ops::helpers::JacobiSVD jac(matrix, true, true, true); + ops::helpers::JacobiSVD jac(matrix, true, true, true); - ASSERT_TRUE(expS.equalsTo(&jac._s)); - ASSERT_TRUE(expU.equalsTo(&jac._u)); - ASSERT_TRUE(expV.equalsTo(&jac._v)); + ASSERT_TRUE(expS.equalsTo(&jac._s)); + ASSERT_TRUE(expU.equalsTo(&jac._u)); + ASSERT_TRUE(expV.equalsTo(&jac._v)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, JacobiSVD_test10) { - auto matrix = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); + auto matrix = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); - auto expS = NDArrayFactory::create('c', {5,1}, {35.7975, 29.1924, 11.1935, 9.2846, 6.77071}); - auto expU = NDArrayFactory::create('c', {5,5}, {0.744855,0.0686476, 0.079663,0.0889877, 0.65285, -0.386297,-0.760021,0.00624688, 0.156774, 0.498522, 0.186491,-0.322427, 0.773083,-0.468826,-0.209299, 0.246053,-0.215594, 0.240942, 0.821793,-0.399475, -0.447933, 0.516928, 0.581295, 0.269001, 0.349106}); - auto expV = NDArrayFactory::create('c', {5,5}, {-0.627363, 0.23317, 0.501211, 0.160272, -0.524545, -0.0849394, 0.917171,-0.155876,-0.0124053, 0.356555, 0.66983, 0.182569, 0.696897, 0.179807,0.000864568, -0.387647, -0.264316, 0.416597, 0.0941014, 0.772955, 0.0160818,-0.0351459,-0.255484, 0.965905, 0.0161524}); + auto expS = NDArrayFactory::create('c', {5,1}, {35.7975, 29.1924, 11.1935, 9.2846, 6.77071}); + auto expU = NDArrayFactory::create('c', {5,5}, {0.744855,0.0686476, 0.079663,0.0889877, 0.65285, -0.386297,-0.760021,0.00624688, 0.156774, 0.498522, 0.186491,-0.322427, 0.773083,-0.468826,-0.209299, 0.246053,-0.215594, 0.240942, 0.821793,-0.399475, -0.447933, 0.516928, 0.581295, 0.269001, 0.349106}); + auto expV = NDArrayFactory::create('c', {5,5}, {-0.627363, 0.23317, 0.501211, 0.160272, -0.524545, -0.0849394, 0.917171,-0.155876,-0.0124053, 0.356555, 0.66983, 0.182569, 0.696897, 0.179807,0.000864568, -0.387647, -0.264316, 0.416597, 0.0941014, 0.772955, 0.0160818,-0.0351459,-0.255484, 0.965905, 0.0161524}); - ops::helpers::JacobiSVD jac(matrix, true, true, false); + ops::helpers::JacobiSVD jac(matrix, true, true, false); - ASSERT_TRUE(expS.equalsTo(&jac._s)); - ASSERT_TRUE(expU.equalsTo(&jac._u)); - ASSERT_TRUE(expV.equalsTo(&jac._v)); + ASSERT_TRUE(expS.equalsTo(&jac._s)); + ASSERT_TRUE(expU.equalsTo(&jac._u)); + ASSERT_TRUE(expV.equalsTo(&jac._v)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, JacobiSVD_test11) { - auto matrix = NDArrayFactory::create('c', {6,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0, 3, -11, 2, 12, 10}); + auto matrix = NDArrayFactory::create('c', {6,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0, 3, -11, 2, 12, 10}); - auto expS = NDArrayFactory::create('c', {5,1}, {36.27, 32.1997, 15.9624, 10.6407, 6.9747}); - auto expU = NDArrayFactory::create('c', {6,5}, {0.720125,-0.149734, 0.227784,-0.0288531, 0.595353, -0.509487,-0.567298, -0.237169,-0.0469077, 0.38648, 0.120912, -0.32916,-0.0202265, 0.921633, -0.153994, 0.180033,-0.294831, 0.357867, -0.194106, -0.646595, -0.354033, 0.521937, 0.556566, 0.305582, 0.211013, -0.222425,-0.433662, 0.673515, -0.128465, 0.099309}); - auto expV = NDArrayFactory::create('c', {5,5}, {-0.581609, 0.315327,0.333158, 0.34476, -0.576582, 0.117364, 0.889461,0.175174,-0.166603, 0.369651, 0.643246,-0.0899117,0.613288, 0.442462,-0.0790943, -0.480818, -0.264384,0.395122, 0.223126, 0.702145, -0.0548207, -0.177325,0.571031,-0.779632, -0.1779}); + auto expS = NDArrayFactory::create('c', {5,1}, {36.27, 32.1997, 15.9624, 10.6407, 6.9747}); + auto expU = NDArrayFactory::create('c', {6,5}, {0.720125,-0.149734, 0.227784,-0.0288531, 0.595353, -0.509487,-0.567298, -0.237169,-0.0469077, 0.38648, 0.120912, -0.32916,-0.0202265, 0.921633, -0.153994, 0.180033,-0.294831, 0.357867, -0.194106, -0.646595, -0.354033, 0.521937, 0.556566, 0.305582, 0.211013, -0.222425,-0.433662, 0.673515, -0.128465, 0.099309}); + auto expV = NDArrayFactory::create('c', {5,5}, {-0.581609, 0.315327,0.333158, 0.34476, -0.576582, 0.117364, 0.889461,0.175174,-0.166603, 0.369651, 0.643246,-0.0899117,0.613288, 0.442462,-0.0790943, -0.480818, -0.264384,0.395122, 0.223126, 0.702145, -0.0548207, -0.177325,0.571031,-0.779632, -0.1779}); - ops::helpers::JacobiSVD jac(matrix, true, true, false); + ops::helpers::JacobiSVD jac(matrix, true, true, false); - ASSERT_TRUE(expS.equalsTo(&jac._s)); - ASSERT_TRUE(expU.equalsTo(&jac._u)); - ASSERT_TRUE(expV.equalsTo(&jac._v)); + ASSERT_TRUE(expS.equalsTo(&jac._s)); + ASSERT_TRUE(expU.equalsTo(&jac._u)); + ASSERT_TRUE(expV.equalsTo(&jac._v)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, JacobiSVD_test12) { - auto matrix = NDArrayFactory::create('c', {6,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0, 3, -11, 2, 12, 10}); + auto matrix = NDArrayFactory::create('c', {6,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0, 3, -11, 2, 12, 10}); - auto expS = NDArrayFactory::create('c', {5,1}, {36.27, 32.1997, 15.9624, 10.6407, 6.9747}); - auto expU = NDArrayFactory::create('c', {6,6}, {0.720125,-0.149734, 0.227784,-0.0288531, 0.595353,-0.227676, -0.509487,-0.567298, -0.237169,-0.0469077, 0.38648,-0.459108, 0.120912, -0.32916,-0.0202265, 0.921633,-0.153994,0.0591992, 0.180033,-0.294831, 0.357867, -0.194106,-0.646595,-0.544823, -0.354033, 0.521937, 0.556566, 0.305582, 0.211013,-0.393155, -0.222425,-0.433662, 0.673515, -0.128465, 0.099309, 0.531485}); - auto expV = NDArrayFactory::create('c', {5,5}, {-0.581609, 0.315327,0.333158, 0.34476, -0.576582, 0.117364, 0.889461,0.175174,-0.166603, 0.369651, 0.643246,-0.0899117,0.613288, 0.442462,-0.0790943, -0.480818, -0.264384,0.395122, 0.223126, 0.702145, -0.0548207, -0.177325,0.571031,-0.779632, -0.1779}); + auto expS = NDArrayFactory::create('c', {5,1}, {36.27, 32.1997, 15.9624, 10.6407, 6.9747}); + auto expU = NDArrayFactory::create('c', {6,6}, {0.720125,-0.149734, 0.227784,-0.0288531, 0.595353,-0.227676, -0.509487,-0.567298, -0.237169,-0.0469077, 0.38648,-0.459108, 0.120912, -0.32916,-0.0202265, 0.921633,-0.153994,0.0591992, 0.180033,-0.294831, 0.357867, -0.194106,-0.646595,-0.544823, -0.354033, 0.521937, 0.556566, 0.305582, 0.211013,-0.393155, -0.222425,-0.433662, 0.673515, -0.128465, 0.099309, 0.531485}); + auto expV = NDArrayFactory::create('c', {5,5}, {-0.581609, 0.315327,0.333158, 0.34476, -0.576582, 0.117364, 0.889461,0.175174,-0.166603, 0.369651, 0.643246,-0.0899117,0.613288, 0.442462,-0.0790943, -0.480818, -0.264384,0.395122, 0.223126, 0.702145, -0.0548207, -0.177325,0.571031,-0.779632, -0.1779}); - ops::helpers::JacobiSVD jac(matrix, true, true, true); + ops::helpers::JacobiSVD jac(matrix, true, true, true); - ASSERT_TRUE(expS.equalsTo(&jac._s)); - ASSERT_TRUE(expU.equalsTo(&jac._u)); - ASSERT_TRUE(expV.equalsTo(&jac._v)); + ASSERT_TRUE(expS.equalsTo(&jac._s)); + ASSERT_TRUE(expU.equalsTo(&jac._u)); + ASSERT_TRUE(expV.equalsTo(&jac._v)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, JacobiSVD_test13) { - auto matrix = NDArrayFactory::create('c', {5,6}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0, 3, -11, 2, 12, 10}); + auto matrix = NDArrayFactory::create('c', {5,6}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0, 3, -11, 2, 12, 10}); - auto expS = NDArrayFactory::create('c', {5,1}, {40.499, 23.5079, 17.8139, 14.4484, 7.07957}); - auto expU = NDArrayFactory::create('c', {5,5}, {0.592324,-0.121832,-0.484064,-0.624878,-0.0975619, 0.651331, 0.367418, 0.117429, 0.370792, 0.538048, -0.272693,-0.138725, 0.249336,-0.540587, 0.742962, 0.263619,-0.903996, 0.179714, 0.276206, 0.0686237, -0.284717,-0.117079,-0.810818, 0.321741, 0.379848}); - auto expV = NDArrayFactory::create('c', {6,6}, {-0.619634,-0.158345, 0.462262,-0.021009,-0.299779, 0.53571, -0.183441,-0.504296,-0.150804,-0.251078,-0.563079,-0.556052, 0.724925,-0.404744, 0.154104,-0.177039,-0.262604, 0.431988, 0.0335645,-0.501546, 0.221702, 0.797602, 0.186339,-0.165176, -0.0675636,0.0663677,-0.728788, 0.414614,-0.390566, 0.368038, -0.226262, -0.54849,-0.399426,-0.311613, 0.580387, 0.233392}); + auto expS = NDArrayFactory::create('c', {5,1}, {40.499, 23.5079, 17.8139, 14.4484, 7.07957}); + auto expU = NDArrayFactory::create('c', {5,5}, {0.592324,-0.121832,-0.484064,-0.624878,-0.0975619, 0.651331, 0.367418, 0.117429, 0.370792, 0.538048, -0.272693,-0.138725, 0.249336,-0.540587, 0.742962, 0.263619,-0.903996, 0.179714, 0.276206, 0.0686237, -0.284717,-0.117079,-0.810818, 0.321741, 0.379848}); + auto expV = NDArrayFactory::create('c', {6,6}, {-0.619634,-0.158345, 0.462262,-0.021009,-0.299779, 0.53571, -0.183441,-0.504296,-0.150804,-0.251078,-0.563079,-0.556052, 0.724925,-0.404744, 0.154104,-0.177039,-0.262604, 0.431988, 0.0335645,-0.501546, 0.221702, 0.797602, 0.186339,-0.165176, -0.0675636,0.0663677,-0.728788, 0.414614,-0.390566, 0.368038, -0.226262, -0.54849,-0.399426,-0.311613, 0.580387, 0.233392}); - ops::helpers::JacobiSVD jac(matrix, true, true, true); + ops::helpers::JacobiSVD jac(matrix, true, true, true); - ASSERT_TRUE(expS.equalsTo(&jac._s)); - ASSERT_TRUE(expU.equalsTo(&jac._u)); - ASSERT_TRUE(expV.equalsTo(&jac._v)); + ASSERT_TRUE(expS.equalsTo(&jac._s)); + ASSERT_TRUE(expU.equalsTo(&jac._u)); + ASSERT_TRUE(expV.equalsTo(&jac._v)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, JacobiSVD_test14) { - auto matrix = NDArrayFactory::create('c', {5,6}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0, 3, -11, 2, 12, 10}); + auto matrix = NDArrayFactory::create('c', {5,6}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0, 3, -11, 2, 12, 10}); - auto expS = NDArrayFactory::create('c', {5,1}, {40.499, 23.5079, 17.8139, 14.4484, 7.07957}); - auto expU = NDArrayFactory::create('c', {5,5}, {0.592324,-0.121832,-0.484064,-0.624878,-0.0975619, 0.651331, 0.367418, 0.117429, 0.370792, 0.538048, -0.272693,-0.138725, 0.249336,-0.540587, 0.742962, 0.263619,-0.903996, 0.179714, 0.276206, 0.0686237, -0.284717,-0.117079,-0.810818, 0.321741, 0.379848}); - auto expV = NDArrayFactory::create('c', {6,5}, {-0.619634,-0.158345, 0.462262,-0.021009,-0.299779, -0.183441,-0.504296,-0.150804,-0.251078,-0.563079, 0.724925,-0.404744, 0.154104,-0.177039,-0.262604, 0.0335645,-0.501546, 0.221702, 0.797602, 0.186339, -0.0675636,0.0663677,-0.728788, 0.414614,-0.390566, -0.226262, -0.54849,-0.399426,-0.311613, 0.580387}); + auto expS = NDArrayFactory::create('c', {5,1}, {40.499, 23.5079, 17.8139, 14.4484, 7.07957}); + auto expU = NDArrayFactory::create('c', {5,5}, {0.592324,-0.121832,-0.484064,-0.624878,-0.0975619, 0.651331, 0.367418, 0.117429, 0.370792, 0.538048, -0.272693,-0.138725, 0.249336,-0.540587, 0.742962, 0.263619,-0.903996, 0.179714, 0.276206, 0.0686237, -0.284717,-0.117079,-0.810818, 0.321741, 0.379848}); + auto expV = NDArrayFactory::create('c', {6,5}, {-0.619634,-0.158345, 0.462262,-0.021009,-0.299779, -0.183441,-0.504296,-0.150804,-0.251078,-0.563079, 0.724925,-0.404744, 0.154104,-0.177039,-0.262604, 0.0335645,-0.501546, 0.221702, 0.797602, 0.186339, -0.0675636,0.0663677,-0.728788, 0.414614,-0.390566, -0.226262, -0.54849,-0.399426,-0.311613, 0.580387}); - ops::helpers::JacobiSVD jac(matrix, true, true, false); + ops::helpers::JacobiSVD jac(matrix, true, true, false); - ASSERT_TRUE(expS.equalsTo(&jac._s)); - ASSERT_TRUE(expU.equalsTo(&jac._u)); - ASSERT_TRUE(expV.equalsTo(&jac._v)); + ASSERT_TRUE(expS.equalsTo(&jac._s)); + ASSERT_TRUE(expU.equalsTo(&jac._u)); + ASSERT_TRUE(expV.equalsTo(&jac._v)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, JacobiSVD_test15) { - auto matrix = NDArrayFactory::create('c', {5,6}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0, 3, -11, 2, 12, 10}); + auto matrix = NDArrayFactory::create('c', {5,6}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0, 3, -11, 2, 12, 10}); - auto expS = NDArrayFactory::create('c', {5,1}, {40.499, 23.5079, 17.8139, 14.4484, 7.07957}); - auto expU = NDArrayFactory::create('c', {5,5}, {0.592324,-0.121832,-0.484064,-0.624878,-0.0975619, 0.651331, 0.367418, 0.117429, 0.370792, 0.538048, -0.272693,-0.138725, 0.249336,-0.540587, 0.742962, 0.263619,-0.903996, 0.179714, 0.276206, 0.0686237, -0.284717,-0.117079,-0.810818, 0.321741, 0.379848}); - auto expV = NDArrayFactory::create('c', {6,5}, {-0.619634,-0.158345, 0.462262,-0.021009,-0.299779, -0.183441,-0.504296,-0.150804,-0.251078,-0.563079, 0.724925,-0.404744, 0.154104,-0.177039,-0.262604, 0.0335645,-0.501546, 0.221702, 0.797602, 0.186339, -0.0675636,0.0663677,-0.728788, 0.414614,-0.390566, -0.226262, -0.54849,-0.399426,-0.311613, 0.580387}); + auto expS = NDArrayFactory::create('c', {5,1}, {40.499, 23.5079, 17.8139, 14.4484, 7.07957}); + auto expU = NDArrayFactory::create('c', {5,5}, {0.592324,-0.121832,-0.484064,-0.624878,-0.0975619, 0.651331, 0.367418, 0.117429, 0.370792, 0.538048, -0.272693,-0.138725, 0.249336,-0.540587, 0.742962, 0.263619,-0.903996, 0.179714, 0.276206, 0.0686237, -0.284717,-0.117079,-0.810818, 0.321741, 0.379848}); + auto expV = NDArrayFactory::create('c', {6,5}, {-0.619634,-0.158345, 0.462262,-0.021009,-0.299779, -0.183441,-0.504296,-0.150804,-0.251078,-0.563079, 0.724925,-0.404744, 0.154104,-0.177039,-0.262604, 0.0335645,-0.501546, 0.221702, 0.797602, 0.186339, -0.0675636,0.0663677,-0.728788, 0.414614,-0.390566, -0.226262, -0.54849,-0.399426,-0.311613, 0.580387}); - ops::helpers::JacobiSVD jac(matrix, false, false, false); + ops::helpers::JacobiSVD jac(matrix, false, false, false); - ASSERT_TRUE(expS.equalsTo(&jac._s)); + ASSERT_TRUE(expS.equalsTo(&jac._s)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, JacobiSVD_test16) { - NDArray rotation('c', {2,2}, sd::DataType::DOUBLE); + NDArray rotation('c', {2,2}, sd::DataType::DOUBLE); - NDArray exp1('c', {2,2}, {1,0,0,1 }, sd::DataType::DOUBLE); - NDArray exp2('c', {2,2}, {0,1,-1,0}, sd::DataType::DOUBLE); - NDArray exp3('c', {2,2}, {-1,0,0,-1}, sd::DataType::DOUBLE); - NDArray exp4('c', {2,2}, {0.983282, 0.182089, -0.182089, 0.983282}, sd::DataType::DOUBLE); - NDArray exp5('c', {2,2}, {0.249041, 0.968493, -0.968493, 0.249041}, sd::DataType::DOUBLE); + NDArray exp1('c', {2,2}, {1,0,0,1 }, sd::DataType::DOUBLE); + NDArray exp2('c', {2,2}, {0,1,-1,0}, sd::DataType::DOUBLE); + NDArray exp3('c', {2,2}, {-1,0,0,-1}, sd::DataType::DOUBLE); + NDArray exp4('c', {2,2}, {0.983282, 0.182089, -0.182089, 0.983282}, sd::DataType::DOUBLE); + NDArray exp5('c', {2,2}, {0.249041, 0.968493, -0.968493, 0.249041}, sd::DataType::DOUBLE); - ops::helpers::JacobiSVD::createJacobiRotationGivens(0, 0, rotation); - ASSERT_TRUE(rotation.equalsTo(exp1)); - ASSERT_TRUE(rotation.isSameShapeStrict(exp1)); + ops::helpers::JacobiSVD::createJacobiRotationGivens(0, 0, rotation); + ASSERT_TRUE(rotation.equalsTo(exp1)); + ASSERT_TRUE(rotation.isSameShapeStrict(exp1)); - ops::helpers::JacobiSVD::createJacobiRotationGivens(0, -0.5, rotation); - ASSERT_TRUE(rotation.equalsTo(exp2)); - ASSERT_TRUE(rotation.isSameShapeStrict(exp2)); + ops::helpers::JacobiSVD::createJacobiRotationGivens(0, -0.5, rotation); + ASSERT_TRUE(rotation.equalsTo(exp2)); + ASSERT_TRUE(rotation.isSameShapeStrict(exp2)); - ops::helpers::JacobiSVD::createJacobiRotationGivens(-0.5, 0, rotation); - ASSERT_TRUE(rotation.equalsTo(exp3)); - ASSERT_TRUE(rotation.isSameShapeStrict(exp3)); + ops::helpers::JacobiSVD::createJacobiRotationGivens(-0.5, 0, rotation); + ASSERT_TRUE(rotation.equalsTo(exp3)); + ASSERT_TRUE(rotation.isSameShapeStrict(exp3)); - ops::helpers::JacobiSVD::createJacobiRotationGivens(2.7, -0.5, rotation); - ASSERT_TRUE(rotation.equalsTo(exp4)); - ASSERT_TRUE(rotation.isSameShapeStrict(exp4)); + ops::helpers::JacobiSVD::createJacobiRotationGivens(2.7, -0.5, rotation); + ASSERT_TRUE(rotation.equalsTo(exp4)); + ASSERT_TRUE(rotation.isSameShapeStrict(exp4)); - ops::helpers::JacobiSVD::createJacobiRotationGivens(2.7, -10.5, rotation); - ASSERT_TRUE(rotation.equalsTo(exp5)); - ASSERT_TRUE(rotation.isSameShapeStrict(exp5)); + ops::helpers::JacobiSVD::createJacobiRotationGivens(2.7, -10.5, rotation); + ASSERT_TRUE(rotation.equalsTo(exp5)); + ASSERT_TRUE(rotation.isSameShapeStrict(exp5)); } TEST_F(HelpersTests1, test_binary_search_1) { - std::array array = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; + std::array array = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; - auto idx = sd::ops::helpers::binarySearch(array.data(), 2, 10); - ASSERT_EQ(2, idx); + auto idx = sd::ops::helpers::binarySearch(array.data(), 2, 10); + ASSERT_EQ(2, idx); } TEST_F(HelpersTests1, test_binary_search_2) { - std::array array = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; + std::array array = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; - auto idx = sd::ops::helpers::binarySearch(array.data(), 18, 10); - ASSERT_EQ(-1, idx); + auto idx = sd::ops::helpers::binarySearch(array.data(), 18, 10); + ASSERT_EQ(-1, idx); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, SVD_test1) { - auto matrix = NDArrayFactory::create('c', {5,5}, {-17 ,14 ,9 ,-12 ,-12 ,5 ,-4 ,-19 ,-7 ,-12 ,15 ,16 ,17 ,-6 ,8 ,-10 ,14 ,-15 ,6 ,-10 ,-14 ,12 ,-1 ,-16 ,3}); - auto matrix2 = NDArrayFactory::create('c', {5,5}, {18 ,3 ,2 ,7 ,-11 ,7 ,7 ,10 ,-13 ,-8 ,13 ,20 ,-4 ,-16 ,-9 ,-17 ,-5 ,-7 ,-19 ,-8 ,-9 ,9 ,6 ,14 ,-11}); - auto expM = NDArrayFactory::create('c', {5,5}, {-17,14,9,-12,-12, 5,-4, -19, -7,-12, 15,16,17.0294, -6, 8, -10,14, -15, 6,-10, -14,12, 0,-16, 0}); - auto expU = NDArrayFactory::create('c', {5,5}, {18,3, 2,7,-11, 7, 7.75131,10,-12.5665, -8, 13, 20.905,-4,-14.7979, -9, -17,-3.87565,-7,-19.2608, -8, -9, 9, 6, 14,-11}); + auto matrix = NDArrayFactory::create('c', {5,5}, {-17 ,14 ,9 ,-12 ,-12 ,5 ,-4 ,-19 ,-7 ,-12 ,15 ,16 ,17 ,-6 ,8 ,-10 ,14 ,-15 ,6 ,-10 ,-14 ,12 ,-1 ,-16 ,3}); + auto matrix2 = NDArrayFactory::create('c', {5,5}, {18 ,3 ,2 ,7 ,-11 ,7 ,7 ,10 ,-13 ,-8 ,13 ,20 ,-4 ,-16 ,-9 ,-17 ,-5 ,-7 ,-19 ,-8 ,-9 ,9 ,6 ,14 ,-11}); + auto expM = NDArrayFactory::create('c', {5,5}, {-17,14,9,-12,-12, 5,-4, -19, -7,-12, 15,16,17.0294, -6, 8, -10,14, -15, 6,-10, -14,12, 0,-16, 0}); + auto expU = NDArrayFactory::create('c', {5,5}, {18,3, 2,7,-11, 7, 7.75131,10,-12.5665, -8, 13, 20.905,-4,-14.7979, -9, -17,-3.87565,-7,-19.2608, -8, -9, 9, 6, 14,-11}); - ops::helpers::SVD svd(matrix, 4, true, true, true, 't'); - svd._m = matrix; - svd._u = matrix2; - svd.deflation1(1,1,2,2); + ops::helpers::SVD svd(matrix, 4, true, true, true, 't'); + svd._m = matrix; + svd._u = matrix2; + svd.deflation1(1,1,2,2); - ASSERT_TRUE(expM.equalsTo(&svd._m)); - ASSERT_TRUE(expU.equalsTo(&svd._u)); + ASSERT_TRUE(expM.equalsTo(&svd._m)); + ASSERT_TRUE(expU.equalsTo(&svd._u)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, SVD_test2) { - auto matrix= NDArrayFactory::create('c', {5,5}, {-17 ,14 ,9 ,-12 ,-12 ,5 ,-4 ,-19 ,-7 ,-12 ,15 ,16 ,17 ,-6 ,8 ,-10 ,14 ,-15 ,6 ,-10 ,-14 ,12 ,-1 ,-16 ,3}); - auto matrix2 = NDArrayFactory::create('c', {5,5}, {18 ,3 ,2 ,7 ,-11 ,7 ,7 ,10 ,-13 ,-8 ,13 ,20 ,-4 ,-16 ,-9 ,-17 ,-5 ,-7 ,-19 ,-8 ,-9 ,9 ,6 ,14 ,-11}); - auto expM = NDArrayFactory::create('c', {5,5}, {22.6716,14, 9,-12,-12, 5,-4,-19, -7,-12, 0,16, 0, -6, 8, -10,14,-15, 6,-10, -14,12, -1,-16, 3}); - auto expU = NDArrayFactory::create('c', {5,5}, {-12.1738, 3, -13.4089, 7,-11, 1.36735, 7, -12.1297,-13, -8, -12.3944,20, -5.60173,-16, -9, -17,-5,-7,-19, -8, -9, 9, 6, 14,-11}); + auto matrix= NDArrayFactory::create('c', {5,5}, {-17 ,14 ,9 ,-12 ,-12 ,5 ,-4 ,-19 ,-7 ,-12 ,15 ,16 ,17 ,-6 ,8 ,-10 ,14 ,-15 ,6 ,-10 ,-14 ,12 ,-1 ,-16 ,3}); + auto matrix2 = NDArrayFactory::create('c', {5,5}, {18 ,3 ,2 ,7 ,-11 ,7 ,7 ,10 ,-13 ,-8 ,13 ,20 ,-4 ,-16 ,-9 ,-17 ,-5 ,-7 ,-19 ,-8 ,-9 ,9 ,6 ,14 ,-11}); + auto expM = NDArrayFactory::create('c', {5,5}, {22.6716,14, 9,-12,-12, 5,-4,-19, -7,-12, 0,16, 0, -6, 8, -10,14,-15, 6,-10, -14,12, -1,-16, 3}); + auto expU = NDArrayFactory::create('c', {5,5}, {-12.1738, 3, -13.4089, 7,-11, 1.36735, 7, -12.1297,-13, -8, -12.3944,20, -5.60173,-16, -9, -17,-5,-7,-19, -8, -9, 9, 6, 14,-11}); - ops::helpers::SVD svd(matrix, 4, true, true, true); - svd._m = matrix; - svd._u = matrix2; - svd.deflation1(0,0,2,2); + ops::helpers::SVD svd(matrix, 4, true, true, true); + svd._m = matrix; + svd._u = matrix2; + svd.deflation1(0,0,2,2); - ASSERT_TRUE(expM.equalsTo(&svd._m)); - ASSERT_TRUE(expU.equalsTo(&svd._u)); + ASSERT_TRUE(expM.equalsTo(&svd._m)); + ASSERT_TRUE(expU.equalsTo(&svd._u)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, SVD_test3) { - auto matrix= NDArrayFactory::create('c', {5,5}, {-17 ,14 ,9 ,-12 ,-12 ,5 ,-4 ,-19 ,-7 ,-12 ,15 ,16 ,17 ,-6 ,8 ,-10 ,14 ,-15 ,6 ,-10 ,-14 ,12 ,-1 ,-16 ,3}); - auto matrix2 = NDArrayFactory::create('c', {2,6}, {18 ,3 ,2 ,7 ,-11 ,7 ,7 ,10 ,-13 ,-8 ,13 ,20}); - auto expM = NDArrayFactory::create('c', {5,5}, {-17,14,9,-12,-12, 5,-4, -19, -7,-12, 15,16,17.0294, -6, 8, -10,14, -15, 6,-10, -14,12, 0,-16, 0}); - auto expU = NDArrayFactory::create('c', {2,6}, {18, 2.58377, 2, 7.16409,-11, 7, 7 ,10.4525 ,-13, -7.39897 ,13 ,20}); + auto matrix= NDArrayFactory::create('c', {5,5}, {-17 ,14 ,9 ,-12 ,-12 ,5 ,-4 ,-19 ,-7 ,-12 ,15 ,16 ,17 ,-6 ,8 ,-10 ,14 ,-15 ,6 ,-10 ,-14 ,12 ,-1 ,-16 ,3}); + auto matrix2 = NDArrayFactory::create('c', {2,6}, {18 ,3 ,2 ,7 ,-11 ,7 ,7 ,10 ,-13 ,-8 ,13 ,20}); + auto expM = NDArrayFactory::create('c', {5,5}, {-17,14,9,-12,-12, 5,-4, -19, -7,-12, 15,16,17.0294, -6, 8, -10,14, -15, 6,-10, -14,12, 0,-16, 0}); + auto expU = NDArrayFactory::create('c', {2,6}, {18, 2.58377, 2, 7.16409,-11, 7, 7 ,10.4525 ,-13, -7.39897 ,13 ,20}); - ops::helpers::SVD svd(matrix, 4, false, true, true, 't'); - svd._m = matrix; - svd._u = matrix2; - svd.deflation1(1,1,2,2); + ops::helpers::SVD svd(matrix, 4, false, true, true, 't'); + svd._m = matrix; + svd._u = matrix2; + svd.deflation1(1,1,2,2); - ASSERT_TRUE(expM.equalsTo(&svd._m)); - ASSERT_TRUE(expU.equalsTo(&svd._u)); + ASSERT_TRUE(expM.equalsTo(&svd._m)); + ASSERT_TRUE(expU.equalsTo(&svd._u)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, SVD_test4) { - auto matrix1 = NDArrayFactory::create('c', {6,5}, {12 ,20 ,19 ,-18 ,-6 ,3 ,6 ,2 ,-7 ,-7 ,14 ,8 ,18 ,-17 ,18 ,-14 ,-15 ,1 ,2 ,2 ,-3 ,-18 ,8 ,-17 ,-19 ,12 ,18 ,6 ,-2 ,-17}); - auto matrix2 = NDArrayFactory::create('c', {6,6}, {-10 ,-16 ,-20 ,13 ,20 ,-10 ,-9 ,-1 ,-7 ,-20 ,-4 ,20 ,-11 ,19 ,-5 ,-18 ,12 ,-19 ,18 ,-18 ,17 ,-10 ,-19 ,14 ,-2 ,-7 ,-17 ,-14 ,-4 ,-16 ,18 ,-6 ,-18 ,1 ,-15 ,-12}); - auto matrix3 = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); - auto expM = NDArrayFactory::create('c', {6,5}, {12, 20, 19,-18, -6, 3, 6, 2, -7, -7, 14, 8, 18,-17, 18, -14,-15,8.06226, 2, 2, -3,-18, 0,-17, 2, 12, 18, 6, -2,-17}); - auto expU = NDArrayFactory::create('c', {6,6}, {-10,-16, -20, 13, 20,-10, -9, -1,-20.7138,4.46525, -4, 20, -11, 19,-18.4812,2.72876, 12,-19, 18,-18, 17, -10,-19, 14, -2, -7, -17, -14, -4,-16, 18, -6, -18, 1,-15,-12}); - auto expV = NDArrayFactory::create('c', {5,5}, {-18, 1, 19, -7, 1, 2,-18, -13, 14, 2, -2,-11,2.97683,-7.69015,-6, -3, -8, 8, -2, 7, 16, 15, -3, 7, 0}); + auto matrix1 = NDArrayFactory::create('c', {6,5}, {12 ,20 ,19 ,-18 ,-6 ,3 ,6 ,2 ,-7 ,-7 ,14 ,8 ,18 ,-17 ,18 ,-14 ,-15 ,1 ,2 ,2 ,-3 ,-18 ,8 ,-17 ,-19 ,12 ,18 ,6 ,-2 ,-17}); + auto matrix2 = NDArrayFactory::create('c', {6,6}, {-10 ,-16 ,-20 ,13 ,20 ,-10 ,-9 ,-1 ,-7 ,-20 ,-4 ,20 ,-11 ,19 ,-5 ,-18 ,12 ,-19 ,18 ,-18 ,17 ,-10 ,-19 ,14 ,-2 ,-7 ,-17 ,-14 ,-4 ,-16 ,18 ,-6 ,-18 ,1 ,-15 ,-12}); + auto matrix3 = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); + auto expM = NDArrayFactory::create('c', {6,5}, {12, 20, 19,-18, -6, 3, 6, 2, -7, -7, 14, 8, 18,-17, 18, -14,-15,8.06226, 2, 2, -3,-18, 0,-17, 2, 12, 18, 6, -2,-17}); + auto expU = NDArrayFactory::create('c', {6,6}, {-10,-16, -20, 13, 20,-10, -9, -1,-20.7138,4.46525, -4, 20, -11, 19,-18.4812,2.72876, 12,-19, 18,-18, 17, -10,-19, 14, -2, -7, -17, -14, -4,-16, 18, -6, -18, 1,-15,-12}); + auto expV = NDArrayFactory::create('c', {5,5}, {-18, 1, 19, -7, 1, 2,-18, -13, 14, 2, -2,-11,2.97683,-7.69015,-6, -3, -8, 8, -2, 7, 16, 15, -3, 7, 0}); - ops::helpers::SVD svd(matrix3, 4, true, true, true, 't'); - svd._m = matrix1; - svd._u = matrix2; - svd._v = matrix3; - svd.deflation2(1, 2, 2, 1, 1, 2, 1); + ops::helpers::SVD svd(matrix3, 4, true, true, true, 't'); + svd._m = matrix1; + svd._u = matrix2; + svd._v = matrix3; + svd.deflation2(1, 2, 2, 1, 1, 2, 1); - ASSERT_TRUE(expM.equalsTo(&svd._m)); - ASSERT_TRUE(expU.equalsTo(&svd._u)); - ASSERT_TRUE(expV.equalsTo(&svd._v)); + ASSERT_TRUE(expM.equalsTo(&svd._m)); + ASSERT_TRUE(expU.equalsTo(&svd._u)); + ASSERT_TRUE(expV.equalsTo(&svd._v)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, SVD_test5) { - auto matrix1 = NDArrayFactory::create('c', {6,5}, {12 ,20 ,19 ,-18 ,-6 ,3 ,6 ,2 ,-7 ,-7 ,14 ,8 ,18 ,-17 ,18 ,-14 ,-15 ,1 ,2 ,2 ,-3 ,-18 ,8 ,-17 ,-19 ,12 ,18 ,6 ,-2 ,-17}); - auto matrix2 = NDArrayFactory::create('c', {6,6}, {-10 ,-16 ,-20 ,13 ,20 ,-10 ,-9 ,-1 ,-7 ,-20 ,-4 ,20 ,-11 ,19 ,-5 ,-18 ,12 ,-19 ,18 ,-18 ,17 ,-10 ,-19 ,14 ,-2 ,-7 ,-17 ,-14 ,-4 ,-16 ,18 ,-6 ,-18 ,1 ,-15 ,-12}); - auto matrix3 = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); - auto expM = NDArrayFactory::create('c', {6,5}, {18.4391, 20, 19,-18, -6, 3, 6, 2, -7, -7, 0, 8,18.4391,-17, 18, -14,-15, 1, 2, 2, -3,-18, 8,-17,-19, 12, 18, 6, -2,-17}); - auto expU = NDArrayFactory::create('c', {6,6}, {-10,-16,-20,13, 20,-10, -9,-15.8359, -7,-12.2566, -4, 20, -11,-1.30158, -5,-26.1401, 12,-19, 18,-19.3068, 17, 7.15871,-19, 14, -2, -7,-17, -14, -4,-16, 18, -6,-18, 1,-15,-12}); - auto expV = NDArrayFactory::create('c', {5,5}, {-18, 1, 19, -7, 1, 2,-1.08465,-13,22.7777, 2, -2,-5.64019, 8,9.65341,-6, -3, -8, 8, -2, 7, 16, 15, -3, 7, 0}); + auto matrix1 = NDArrayFactory::create('c', {6,5}, {12 ,20 ,19 ,-18 ,-6 ,3 ,6 ,2 ,-7 ,-7 ,14 ,8 ,18 ,-17 ,18 ,-14 ,-15 ,1 ,2 ,2 ,-3 ,-18 ,8 ,-17 ,-19 ,12 ,18 ,6 ,-2 ,-17}); + auto matrix2 = NDArrayFactory::create('c', {6,6}, {-10 ,-16 ,-20 ,13 ,20 ,-10 ,-9 ,-1 ,-7 ,-20 ,-4 ,20 ,-11 ,19 ,-5 ,-18 ,12 ,-19 ,18 ,-18 ,17 ,-10 ,-19 ,14 ,-2 ,-7 ,-17 ,-14 ,-4 ,-16 ,18 ,-6 ,-18 ,1 ,-15 ,-12}); + auto matrix3 = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); + auto expM = NDArrayFactory::create('c', {6,5}, {18.4391, 20, 19,-18, -6, 3, 6, 2, -7, -7, 0, 8,18.4391,-17, 18, -14,-15, 1, 2, 2, -3,-18, 8,-17,-19, 12, 18, 6, -2,-17}); + auto expU = NDArrayFactory::create('c', {6,6}, {-10,-16,-20,13, 20,-10, -9,-15.8359, -7,-12.2566, -4, 20, -11,-1.30158, -5,-26.1401, 12,-19, 18,-19.3068, 17, 7.15871,-19, 14, -2, -7,-17, -14, -4,-16, 18, -6,-18, 1,-15,-12}); + auto expV = NDArrayFactory::create('c', {5,5}, {-18, 1, 19, -7, 1, 2,-1.08465,-13,22.7777, 2, -2,-5.64019, 8,9.65341,-6, -3, -8, 8, -2, 7, 16, 15, -3, 7, 0}); - ops::helpers::SVD svd(matrix3, 4, true, true, true, 't'); - svd._m = matrix1; - svd._u = matrix2; - svd._v = matrix3; - svd.deflation2(1, 0, 1, 1, 0, 2, 2); + ops::helpers::SVD svd(matrix3, 4, true, true, true, 't'); + svd._m = matrix1; + svd._u = matrix2; + svd._v = matrix3; + svd.deflation2(1, 0, 1, 1, 0, 2, 2); - ASSERT_TRUE(expM.equalsTo(&svd._m)); - ASSERT_TRUE(expU.equalsTo(&svd._u)); - ASSERT_TRUE(expV.equalsTo(&svd._v)); + ASSERT_TRUE(expM.equalsTo(&svd._m)); + ASSERT_TRUE(expU.equalsTo(&svd._u)); + ASSERT_TRUE(expV.equalsTo(&svd._v)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, SVD_test6) { - auto matrix1 = NDArrayFactory::create('c', {6,5}, {12 ,20 ,19 ,-18 ,-6 ,3 ,6 ,2 ,-7 ,-7 ,14 ,8 ,18 ,-17 ,18 ,-14 ,-15 ,1 ,2 ,2 ,-3 ,-18 ,8 ,-17 ,-19 ,12 ,18 ,6 ,-2 ,-17}); - auto matrix2 = NDArrayFactory::create('c', {2,6}, {-10 ,-16 ,-20 ,13 ,20 ,-10 ,-9 ,-1 ,-7 ,-20 ,-4 ,20}); - auto matrix3 = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); - auto expM = NDArrayFactory::create('c', {6,5}, {18.4391, 20, 19,-18, -6, 3, 6, 2, -7, -7, 0, 8,18.4391,-17, 18, -14,-15, 1, 2, 2, -3,-18, 8,-17,-19, 12, 18, 6, -2,-17}); - auto expU = NDArrayFactory::create('c', {2,6}, {-10, -0.542326,-20, 20.6084,20,-10, -9, -15.8359, -7,-12.2566,-4, 20}); - auto expV = NDArrayFactory::create('c', {5,5}, {-18, 1, 19, -7, 1, 2,-1.08465,-13,22.7777, 2, -2,-5.64019, 8,9.65341,-6, -3, -8, 8, -2, 7, 16, 15, -3, 7, 0}); + auto matrix1 = NDArrayFactory::create('c', {6,5}, {12 ,20 ,19 ,-18 ,-6 ,3 ,6 ,2 ,-7 ,-7 ,14 ,8 ,18 ,-17 ,18 ,-14 ,-15 ,1 ,2 ,2 ,-3 ,-18 ,8 ,-17 ,-19 ,12 ,18 ,6 ,-2 ,-17}); + auto matrix2 = NDArrayFactory::create('c', {2,6}, {-10 ,-16 ,-20 ,13 ,20 ,-10 ,-9 ,-1 ,-7 ,-20 ,-4 ,20}); + auto matrix3 = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); + auto expM = NDArrayFactory::create('c', {6,5}, {18.4391, 20, 19,-18, -6, 3, 6, 2, -7, -7, 0, 8,18.4391,-17, 18, -14,-15, 1, 2, 2, -3,-18, 8,-17,-19, 12, 18, 6, -2,-17}); + auto expU = NDArrayFactory::create('c', {2,6}, {-10, -0.542326,-20, 20.6084,20,-10, -9, -15.8359, -7,-12.2566,-4, 20}); + auto expV = NDArrayFactory::create('c', {5,5}, {-18, 1, 19, -7, 1, 2,-1.08465,-13,22.7777, 2, -2,-5.64019, 8,9.65341,-6, -3, -8, 8, -2, 7, 16, 15, -3, 7, 0}); - ops::helpers::SVD svd(matrix3, 4, false, true, true, 't'); - svd._m = matrix1; - svd._u = matrix2; - svd._v = matrix3; - svd.deflation2(1, 0, 1, 1, 0, 2, 2); + ops::helpers::SVD svd(matrix3, 4, false, true, true, 't'); + svd._m = matrix1; + svd._u = matrix2; + svd._v = matrix3; + svd.deflation2(1, 0, 1, 1, 0, 2, 2); - ASSERT_TRUE(expM.equalsTo(&svd._m)); - ASSERT_TRUE(expU.equalsTo(&svd._u)); - ASSERT_TRUE(expV.equalsTo(&svd._v)); + ASSERT_TRUE(expM.equalsTo(&svd._m)); + ASSERT_TRUE(expU.equalsTo(&svd._u)); + ASSERT_TRUE(expV.equalsTo(&svd._v)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, SVD_test7) { - auto matrix1 = NDArrayFactory::create('c', {6,5}, {12 ,20 ,19 ,-18 ,-6 ,3 ,6 ,2 ,-7 ,-7 ,14 ,8 ,18 ,-17 ,18 ,-14 ,-15 ,1 ,2 ,2 ,-3 ,-18 ,8 ,-17 ,-19 ,12 ,18 ,6 ,-2 ,-17}); - auto matrix2 = NDArrayFactory::create('c', {6,6}, {-10 ,-16 ,-20 ,13 ,20 ,-10 ,-9 ,-1 ,-7 ,-20 ,-4 ,20 ,-11 ,19 ,-5 ,-18 ,12 ,-19 ,18 ,-18 ,17 ,-10 ,-19 ,14 ,-2 ,-7 ,-17 ,-14 ,-4 ,-16 ,18 ,-6 ,-18 ,1 ,-15 ,-12}); - auto matrix3 = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); + auto matrix1 = NDArrayFactory::create('c', {6,5}, {12 ,20 ,19 ,-18 ,-6 ,3 ,6 ,2 ,-7 ,-7 ,14 ,8 ,18 ,-17 ,18 ,-14 ,-15 ,1 ,2 ,2 ,-3 ,-18 ,8 ,-17 ,-19 ,12 ,18 ,6 ,-2 ,-17}); + auto matrix2 = NDArrayFactory::create('c', {6,6}, {-10 ,-16 ,-20 ,13 ,20 ,-10 ,-9 ,-1 ,-7 ,-20 ,-4 ,20 ,-11 ,19 ,-5 ,-18 ,12 ,-19 ,18 ,-18 ,17 ,-10 ,-19 ,14 ,-2 ,-7 ,-17 ,-14 ,-4 ,-16 ,18 ,-6 ,-18 ,1 ,-15 ,-12}); + auto matrix3 = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); - auto expM = NDArrayFactory::create('c', {6,5}, {12, 20, 19,-18, -6, 3, 6, 2, -7, -7, 14, 8,19.6977,-17, 18, -14,-15, 1, 2, 2, -3,-18, 0,-17, 0, 12, 18, 6, -2,-17}); - auto expU = NDArrayFactory::create('c', {6,6}, {-10, -16,-20, 13, 20,-10, -9,-9.03658, -7,-17.8701, -4, 20, -11, 10.0519, -5,-24.1652, 12,-19, 18, -20.51, 17,-1.82762,-19, 14, -2,-12.0826,-17,-9.95039, -4,-16, 18, -6,-18, 1,-15,-12}); - auto expV = NDArrayFactory::create('c', {5,5}, {-18, 1, 19,-7, 1, 2,-18,-13,14, 2, -2,-11, 8, 2,-6, -3, -8, 8,-2, 7, 16, 15, -3, 7, 0}); + auto expM = NDArrayFactory::create('c', {6,5}, {12, 20, 19,-18, -6, 3, 6, 2, -7, -7, 14, 8,19.6977,-17, 18, -14,-15, 1, 2, 2, -3,-18, 0,-17, 0, 12, 18, 6, -2,-17}); + auto expU = NDArrayFactory::create('c', {6,6}, {-10, -16,-20, 13, 20,-10, -9,-9.03658, -7,-17.8701, -4, 20, -11, 10.0519, -5,-24.1652, 12,-19, 18, -20.51, 17,-1.82762,-19, 14, -2,-12.0826,-17,-9.95039, -4,-16, 18, -6,-18, 1,-15,-12}); + auto expV = NDArrayFactory::create('c', {5,5}, {-18, 1, 19,-7, 1, 2,-18,-13,14, 2, -2,-11, 8, 2,-6, -3, -8, 8,-2, 7, 16, 15, -3, 7, 0}); - ops::helpers::SVD svd(matrix3, 4, true, true, true, 't'); - svd._m = matrix1; - svd._u = matrix2; - svd._v = matrix3; - svd.deflation(1, 3, 1, 1, 2, 1); + ops::helpers::SVD svd(matrix3, 4, true, true, true, 't'); + svd._m = matrix1; + svd._u = matrix2; + svd._v = matrix3; + svd.deflation(1, 3, 1, 1, 2, 1); - ASSERT_TRUE(expM.equalsTo(&svd._m)); - ASSERT_TRUE(expU.equalsTo(&svd._u)); - ASSERT_TRUE(expV.equalsTo(&svd._v)); + ASSERT_TRUE(expM.equalsTo(&svd._m)); + ASSERT_TRUE(expU.equalsTo(&svd._u)); + ASSERT_TRUE(expV.equalsTo(&svd._v)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, SVD_test8) { - auto matrix1 = NDArrayFactory::create('c', {6,5}, {12 ,20 ,19 ,-18 ,-6 ,3 ,6 ,2 ,-7 ,-7 ,14 ,8 ,18 ,-17 ,18 ,-14 ,-15 ,1 ,2 ,2 ,-3 ,-18 ,8 ,-17 ,-19 ,12 ,18 ,6 ,-2 ,-17}); - auto matrix2 = NDArrayFactory::create('c', {6,6}, {-10 ,-16 ,-20 ,13 ,20 ,-10 ,-9 ,-1 ,-7 ,-20 ,-4 ,20 ,-11 ,19 ,-5 ,-18 ,12 ,-19 ,18 ,-18 ,17 ,-10 ,-19 ,14 ,-2 ,-7 ,-17 ,-14 ,-4 ,-16 ,18 ,-6 ,-18 ,1 ,-15 ,-12}); - auto matrix3 = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); + auto matrix1 = NDArrayFactory::create('c', {6,5}, {12 ,20 ,19 ,-18 ,-6 ,3 ,6 ,2 ,-7 ,-7 ,14 ,8 ,18 ,-17 ,18 ,-14 ,-15 ,1 ,2 ,2 ,-3 ,-18 ,8 ,-17 ,-19 ,12 ,18 ,6 ,-2 ,-17}); + auto matrix2 = NDArrayFactory::create('c', {6,6}, {-10 ,-16 ,-20 ,13 ,20 ,-10 ,-9 ,-1 ,-7 ,-20 ,-4 ,20 ,-11 ,19 ,-5 ,-18 ,12 ,-19 ,18 ,-18 ,17 ,-10 ,-19 ,14 ,-2 ,-7 ,-17 ,-14 ,-4 ,-16 ,18 ,-6 ,-18 ,1 ,-15 ,-12}); + auto matrix3 = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); - auto expM = NDArrayFactory::create('c', {6,5}, {12, 20,19,-18, -6, 3, 6, 2, -7, -7, 14,-15, 2,-17, 18, -14, 8, 1, 18, 2, -3,-18, 8,-17,-19, 12, 18, 6, -2,-17}); - auto expU = NDArrayFactory::create('c', {6,6}, {-10,-20,-16, 13, 20,-10, -9, -7, -1,-20, -4, 20, -11, -5, 19,-18, 12,-19, 18, 17,-18,-10,-19, 14, -2, -7,-17,-14, -4,-16, 18, -6,-18, 1,-15,-12}); - auto expV = NDArrayFactory::create('c', {5,5}, {-18, 1, 19,-7, 1, 2,-18,-13, 2,14, -2,-11, 8,-6, 2, -3, -8, 8, 7,-2, 16, 15, -3, 7, 0}); + auto expM = NDArrayFactory::create('c', {6,5}, {12, 20,19,-18, -6, 3, 6, 2, -7, -7, 14,-15, 2,-17, 18, -14, 8, 1, 18, 2, -3,-18, 8,-17,-19, 12, 18, 6, -2,-17}); + auto expU = NDArrayFactory::create('c', {6,6}, {-10,-20,-16, 13, 20,-10, -9, -7, -1,-20, -4, 20, -11, -5, 19,-18, 12,-19, 18, 17,-18,-10,-19, 14, -2, -7,-17,-14, -4,-16, 18, -6,-18, 1,-15,-12}); + auto expV = NDArrayFactory::create('c', {5,5}, {-18, 1, 19,-7, 1, 2,-18,-13, 2,14, -2,-11, 8,-6, 2, -3, -8, 8, 7,-2, 16, 15, -3, 7, 0}); - ops::helpers::SVD svd(matrix3, 4, true, true, true, 't'); - svd._m = matrix1; - svd._u = matrix2; - svd._v = matrix3; - svd.deflation(0, 2, 2, 1, 2, 1); + ops::helpers::SVD svd(matrix3, 4, true, true, true, 't'); + svd._m = matrix1; + svd._u = matrix2; + svd._v = matrix3; + svd.deflation(0, 2, 2, 1, 2, 1); - ASSERT_TRUE(expM.equalsTo(&svd._m)); - ASSERT_TRUE(expU.equalsTo(&svd._u)); - ASSERT_TRUE(expV.equalsTo(&svd._v)); + ASSERT_TRUE(expM.equalsTo(&svd._m)); + ASSERT_TRUE(expU.equalsTo(&svd._u)); + ASSERT_TRUE(expV.equalsTo(&svd._v)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, SVD_test9) { - auto col0 = NDArrayFactory::create('c', {10,1}, {12 ,20 ,19 ,-18 ,-6 ,3 ,6 ,2 ,-7 ,14}); - auto diag = NDArrayFactory::create('c', {10,1}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2}); - auto permut = NDArrayFactory::create('c', {1,10}, {8 ,1 ,4 ,0, 5 ,2 ,9 ,3 ,7 ,6}); - auto matrix3 = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); + auto col0 = NDArrayFactory::create('c', {10,1}, {12 ,20 ,19 ,-18 ,-6 ,3 ,6 ,2 ,-7 ,14}); + auto diag = NDArrayFactory::create('c', {10,1}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2}); + auto permut = NDArrayFactory::create('c', {1,10}, {8 ,1 ,4 ,0, 5 ,2 ,9 ,3 ,7 ,6}); + auto matrix3 = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); - auto expSingVals = NDArrayFactory::create('c', {10,1}, {-2, 15.304323, 11.2, -1, 1.73489, -12, -15.3043, -12.862, 5.6, 41.4039}); - auto expShifts = NDArrayFactory::create('c', {10,1}, {1, 19, 19, 1, 2, -18, -18, -13, 2, 2}); - auto expMus = NDArrayFactory::create('c', {10,1}, {-3, -3.695677, -7.8, -2, -0.265108, 6, 2.69568, 0.138048, 3.6, 39.4039}); + auto expSingVals = NDArrayFactory::create('c', {10,1}, {-2, 15.304323, 11.2, -1, 1.73489, -12, -15.3043, -12.862, 5.6, 41.4039}); + auto expShifts = NDArrayFactory::create('c', {10,1}, {1, 19, 19, 1, 2, -18, -18, -13, 2, 2}); + auto expMus = NDArrayFactory::create('c', {10,1}, {-3, -3.695677, -7.8, -2, -0.265108, 6, 2.69568, 0.138048, 3.6, 39.4039}); - auto singVals = NDArrayFactory::create('c', {10,1}); - auto shifts = NDArrayFactory::create('c', {10,1}); - auto mus = NDArrayFactory::create('c', {10,1}); + auto singVals = NDArrayFactory::create('c', {10,1}); + auto shifts = NDArrayFactory::create('c', {10,1}); + auto mus = NDArrayFactory::create('c', {10,1}); - ops::helpers::SVD svd(matrix3, 4, true, true, true, 't'); - svd.calcSingVals(col0, diag, permut, singVals, shifts, mus); + ops::helpers::SVD svd(matrix3, 4, true, true, true, 't'); + svd.calcSingVals(col0, diag, permut, singVals, shifts, mus); - ASSERT_TRUE(expSingVals.equalsTo(&singVals)); - ASSERT_TRUE(expShifts.equalsTo(&shifts)); - ASSERT_TRUE(expMus.equalsTo(&mus)); + ASSERT_TRUE(expSingVals.equalsTo(&singVals)); + ASSERT_TRUE(expShifts.equalsTo(&shifts)); + ASSERT_TRUE(expMus.equalsTo(&mus)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, SVD_test10) { - auto singVals = NDArrayFactory::create('c', {4,1}, {1 ,1 ,1 ,1}); - auto col0 = NDArrayFactory::create('c', {4,1}, {1 ,1 ,1 ,1}); - auto diag = NDArrayFactory::create('c', {4,1}, {5 ,7 ,-13 ,14}); - auto permut = NDArrayFactory::create('c', {1,4}, {0 ,2 ,3 ,1 }); - auto mus = NDArrayFactory::create('c', {4,1}, {4,1,4,6}); - auto shifts = NDArrayFactory::create('c', {4,1}, {4,2,5,6}); - auto matrix3 = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); + auto singVals = NDArrayFactory::create('c', {4,1}, {1 ,1 ,1 ,1}); + auto col0 = NDArrayFactory::create('c', {4,1}, {1 ,1 ,1 ,1}); + auto diag = NDArrayFactory::create('c', {4,1}, {5 ,7 ,-13 ,14}); + auto permut = NDArrayFactory::create('c', {1,4}, {0 ,2 ,3 ,1 }); + auto mus = NDArrayFactory::create('c', {4,1}, {4,1,4,6}); + auto shifts = NDArrayFactory::create('c', {4,1}, {4,2,5,6}); + auto matrix3 = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); - auto expZhat = NDArrayFactory::create('c', {4,1}, {0, 0.278208, 72.501953, 0}); + auto expZhat = NDArrayFactory::create('c', {4,1}, {0, 0.278208, 72.501953, 0}); - auto zhat = NDArrayFactory::create('c', {4,1}); + auto zhat = NDArrayFactory::create('c', {4,1}); - ops::helpers::SVD svd(matrix3, 4, true, true, true, 't'); - svd.perturb(col0, diag, permut, singVals, shifts, mus, zhat); + ops::helpers::SVD svd(matrix3, 4, true, true, true, 't'); + svd.perturb(col0, diag, permut, singVals, shifts, mus, zhat); - ASSERT_NEAR(expZhat.e(1), zhat.e(1), EPS); - ASSERT_NEAR(expZhat.e(2), zhat.e(2), EPS); + ASSERT_NEAR(expZhat.e(1), zhat.e(1), EPS); + ASSERT_NEAR(expZhat.e(2), zhat.e(2), EPS); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, SVD_test11) { - auto singVals = NDArrayFactory::create('c', {4,1}, {1 ,1 ,1 ,1}); - auto zhat = NDArrayFactory::create('c', {4,1}, {2 ,1 ,2 ,1}); - auto diag = NDArrayFactory::create('c', {4,1}, {5 ,7 ,-13 ,14}); - auto permut = NDArrayFactory::create('c', {1,4}, {0 ,2 ,3 ,1 }); - auto mus = NDArrayFactory::create('c', {4,1}, {4,1,4,6}); - auto shifts = NDArrayFactory::create('c', {4,1}, {4,2,5,6}); - auto matrix3 = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); + auto singVals = NDArrayFactory::create('c', {4,1}, {1 ,1 ,1 ,1}); + auto zhat = NDArrayFactory::create('c', {4,1}, {2 ,1 ,2 ,1}); + auto diag = NDArrayFactory::create('c', {4,1}, {5 ,7 ,-13 ,14}); + auto permut = NDArrayFactory::create('c', {1,4}, {0 ,2 ,3 ,1 }); + auto mus = NDArrayFactory::create('c', {4,1}, {4,1,4,6}); + auto shifts = NDArrayFactory::create('c', {4,1}, {4,2,5,6}); + auto matrix3 = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); - auto expU = NDArrayFactory::create('c', {5,5}, {-0.662161, 0.980399,-0.791469,-0.748434, 0, -0.744931, 0.183825,-0.593602,-0.392928, 0, 0.0472972, 0.061275,0.0719517, 0.104781, 0, 0.0662161,0.0356509, 0.126635, 0.523904, 0, 0, 0, 0, 0, 1}); - auto expV = NDArrayFactory::create('c', {4,4}, {-0.745259,-0.965209, -0.899497, -0.892319, -0.652102, 0.21114, -0.39353, -0.156156, -0.0768918,-0.130705,-0.0885868,-0.0773343, 0.115929,0.0818966, 0.167906, 0.416415}); - auto U = NDArrayFactory::create('c', {5,5}); - auto V = NDArrayFactory::create('c', {4,4}); + auto expU = NDArrayFactory::create('c', {5,5}, {-0.662161, 0.980399,-0.791469,-0.748434, 0, -0.744931, 0.183825,-0.593602,-0.392928, 0, 0.0472972, 0.061275,0.0719517, 0.104781, 0, 0.0662161,0.0356509, 0.126635, 0.523904, 0, 0, 0, 0, 0, 1}); + auto expV = NDArrayFactory::create('c', {4,4}, {-0.745259,-0.965209, -0.899497, -0.892319, -0.652102, 0.21114, -0.39353, -0.156156, -0.0768918,-0.130705,-0.0885868,-0.0773343, 0.115929,0.0818966, 0.167906, 0.416415}); + auto U = NDArrayFactory::create('c', {5,5}); + auto V = NDArrayFactory::create('c', {4,4}); - ops::helpers::SVD svd(matrix3, 4, true, true, true, 't'); - svd.calcSingVecs(zhat, diag,permut, singVals, shifts, mus, U, V); + ops::helpers::SVD svd(matrix3, 4, true, true, true, 't'); + svd.calcSingVecs(zhat, diag,permut, singVals, shifts, mus, U, V); - ASSERT_TRUE(expU.equalsTo(&U)); - ASSERT_TRUE(expV.equalsTo(&V)); + ASSERT_TRUE(expU.equalsTo(&U)); + ASSERT_TRUE(expV.equalsTo(&V)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, SVD_test12) { - auto matrix1 = NDArrayFactory::create('c', {6,5}, {-2 ,-3 ,2 ,1 ,0 ,0 ,-4 ,5 ,-2 ,-3 ,-4 ,0 ,5 ,-1 ,-5 ,-3 ,-5 ,3 ,3 ,3 ,-5 ,5 ,-5 ,0 ,2 ,-2 ,-3 ,-4 ,-5 ,-3}); - auto matrix2 = NDArrayFactory::create('c', {6,6}, {-10 ,-16 ,-20 ,13 ,20 ,-10 ,-9 ,-1 ,-7 ,-20 ,-4 ,20 ,-11 ,19 ,-5 ,-18 ,12 ,-19 ,18 ,-18 ,17 ,-10 ,-19 ,14 ,-2 ,-7 ,-17 ,-14 ,-4 ,-16 ,18 ,-6 ,-18 ,1 ,-15 ,-12}); - auto matrix3 = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); - auto matrix4 = NDArrayFactory::create('c', {5,5}, {3 ,-8 ,5 ,7 ,-8 ,4 ,-19 ,-12 ,-4 ,-5 ,-11 ,19 ,-2 ,-7 ,1 ,16 ,-5 ,10 ,19 ,-19 ,0 ,-20 ,0 ,-8 ,-13}); + auto matrix1 = NDArrayFactory::create('c', {6,5}, {-2 ,-3 ,2 ,1 ,0 ,0 ,-4 ,5 ,-2 ,-3 ,-4 ,0 ,5 ,-1 ,-5 ,-3 ,-5 ,3 ,3 ,3 ,-5 ,5 ,-5 ,0 ,2 ,-2 ,-3 ,-4 ,-5 ,-3}); + auto matrix2 = NDArrayFactory::create('c', {6,6}, {-10 ,-16 ,-20 ,13 ,20 ,-10 ,-9 ,-1 ,-7 ,-20 ,-4 ,20 ,-11 ,19 ,-5 ,-18 ,12 ,-19 ,18 ,-18 ,17 ,-10 ,-19 ,14 ,-2 ,-7 ,-17 ,-14 ,-4 ,-16 ,18 ,-6 ,-18 ,1 ,-15 ,-12}); + auto matrix3 = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); + auto matrix4 = NDArrayFactory::create('c', {5,5}, {3 ,-8 ,5 ,7 ,-8 ,4 ,-19 ,-12 ,-4 ,-5 ,-11 ,19 ,-2 ,-7 ,1 ,16 ,-5 ,10 ,19 ,-19 ,0 ,-20 ,0 ,-8 ,-13}); - auto expSingVals = NDArrayFactory::create('c', {4,1}, {8.43282, 5, 2.3, 1.10167}); - auto expU = NDArrayFactory::create('c', {5,5}, {0.401972,0, 0.206791, 0.891995,0, 0,1, 0, 0,0, 0.816018,0,-0.522818,-0.246529,0, -0.415371,0,-0.826982, 0.378904,0, 0,0, 0, 0,1}); - auto expV = NDArrayFactory::create('c', {4,4}, {-0.951851,0,-0.133555,-0.275939, 0,1, 0, 0, 0.290301,0,-0.681937,-0.671333, -0.098513,0,-0.719114, 0.687873}); + auto expSingVals = NDArrayFactory::create('c', {4,1}, {8.43282, 5, 2.3, 1.10167}); + auto expU = NDArrayFactory::create('c', {5,5}, {0.401972,0, 0.206791, 0.891995,0, 0,1, 0, 0,0, 0.816018,0,-0.522818,-0.246529,0, -0.415371,0,-0.826982, 0.378904,0, 0,0, 0, 0,1}); + auto expV = NDArrayFactory::create('c', {4,4}, {-0.951851,0,-0.133555,-0.275939, 0,1, 0, 0, 0.290301,0,-0.681937,-0.671333, -0.098513,0,-0.719114, 0.687873}); - ops::helpers::SVD svd(matrix4, 4, true, true, true, 't'); - svd._m = matrix1; - svd._u = matrix2; - svd._v = matrix3; - NDArray U, singVals, V; - svd.calcBlockSVD(1, 4, U, singVals, V); + ops::helpers::SVD svd(matrix4, 4, true, true, true, 't'); + svd._m = matrix1; + svd._u = matrix2; + svd._v = matrix3; + NDArray U, singVals, V; + svd.calcBlockSVD(1, 4, U, singVals, V); - ASSERT_TRUE(expSingVals.equalsTo(&singVals)); - ASSERT_TRUE(expU.equalsTo(&U)); - ASSERT_TRUE(expV.equalsTo(&V)); + ASSERT_TRUE(expSingVals.equalsTo(&singVals)); + ASSERT_TRUE(expU.equalsTo(&U)); + ASSERT_TRUE(expV.equalsTo(&V)); - ASSERT_TRUE(expSingVals.isSameShapeStrict(singVals)); - ASSERT_TRUE(expU.isSameShapeStrict(U)); - ASSERT_TRUE(expV.isSameShapeStrict(V)); + ASSERT_TRUE(expSingVals.isSameShapeStrict(singVals)); + ASSERT_TRUE(expU.isSameShapeStrict(U)); + ASSERT_TRUE(expV.isSameShapeStrict(V)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, SVD_test16) { - auto matrix1 = NDArrayFactory::create('c', {6,5}, {-2 ,-3 ,2 ,1 ,0 ,0 ,-4 ,5 ,-2 ,-3 ,-4 ,0 ,5 ,-1 ,-5 ,-3 ,-5 ,3 ,3 ,3 ,-5 ,5 ,-5 ,0 ,2 ,-2 ,-3 ,-4 ,-5 ,-3}); - auto matrix2 = NDArrayFactory::create('c', {6,6}, {-10 ,-16 ,-20 ,13 ,20 ,-10 ,-9 ,-1 ,-7 ,-20 ,-4 ,20 ,-11 ,19 ,-5 ,-18 ,12 ,-19 ,18 ,-18 ,17 ,-10 ,-19 ,14 ,-2 ,-7 ,-17 ,-14 ,-4 ,-16 ,18 ,-6 ,-18 ,1 ,-15 ,-12}); - auto matrix3 = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); - auto matrix4 = NDArrayFactory::create('c', {5,5}, {3 ,-8 ,5 ,7 ,-8 ,4 ,-19 ,-12 ,-4 ,-5 ,-11 ,19 ,-2 ,-7 ,1 ,16 ,-5 ,10 ,19 ,-19 ,0 ,-20 ,0 ,-8 ,-13}); + auto matrix1 = NDArrayFactory::create('c', {6,5}, {-2 ,-3 ,2 ,1 ,0 ,0 ,-4 ,5 ,-2 ,-3 ,-4 ,0 ,5 ,-1 ,-5 ,-3 ,-5 ,3 ,3 ,3 ,-5 ,5 ,-5 ,0 ,2 ,-2 ,-3 ,-4 ,-5 ,-3}); + auto matrix2 = NDArrayFactory::create('c', {6,6}, {-10 ,-16 ,-20 ,13 ,20 ,-10 ,-9 ,-1 ,-7 ,-20 ,-4 ,20 ,-11 ,19 ,-5 ,-18 ,12 ,-19 ,18 ,-18 ,17 ,-10 ,-19 ,14 ,-2 ,-7 ,-17 ,-14 ,-4 ,-16 ,18 ,-6 ,-18 ,1 ,-15 ,-12}); + auto matrix3 = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); + auto matrix4 = NDArrayFactory::create('c', {5,5}, {3 ,-8 ,5 ,7 ,-8 ,4 ,-19 ,-12 ,-4 ,-5 ,-11 ,19 ,-2 ,-7 ,1 ,16 ,-5 ,10 ,19 ,-19 ,0 ,-20 ,0 ,-8 ,-13}); - auto expM = NDArrayFactory::create('c', {6,5}, {-2, -3, 2, 1, 0, 0,7.07022, 0, 0, 0, -4, 0,5.09585, 0, 0, -3, 0, 0,3.32256, 0, -5, 0, 0, 0,1.00244, -2, -3, -4, -5, 0}); - auto expU = NDArrayFactory::create('c', {6,6}, {-5.58884,-2.18397,-11.0944, 3.30292, 0,-10, 8.19094, 5.05917, 16.9641,-4.53112, 0, 20, 6.55878, 3.76734, 15.9255,-3.76399, 0,-19, 1.36021, 23.3551,-8.01165, -1.5816, 0, 14, -15.6318,-2.85386, 8.83051, 2.74286, 1,-16, 18, -6, -18, 1,-15,-12}); - auto expV = NDArrayFactory::create('c', {5,5}, {-18, 1, 19, -7, 1, 2, 14.5866, 3.90133, 1.06593, 9.99376, -2, 9.97311, 2.44445, 6.85159, 2.37014, -3, 0.56907,-8.93313,-5.31596, 3.10096, 16,-10.6859, 1.70708,-7.24295,-10.6975}); + auto expM = NDArrayFactory::create('c', {6,5}, {-2, -3, 2, 1, 0, 0,7.07022, 0, 0, 0, -4, 0,5.09585, 0, 0, -3, 0, 0,3.32256, 0, -5, 0, 0, 0,1.00244, -2, -3, -4, -5, 0}); + auto expU = NDArrayFactory::create('c', {6,6}, {-5.58884,-2.18397,-11.0944, 3.30292, 0,-10, 8.19094, 5.05917, 16.9641,-4.53112, 0, 20, 6.55878, 3.76734, 15.9255,-3.76399, 0,-19, 1.36021, 23.3551,-8.01165, -1.5816, 0, 14, -15.6318,-2.85386, 8.83051, 2.74286, 1,-16, 18, -6, -18, 1,-15,-12}); + auto expV = NDArrayFactory::create('c', {5,5}, {-18, 1, 19, -7, 1, 2, 14.5866, 3.90133, 1.06593, 9.99376, -2, 9.97311, 2.44445, 6.85159, 2.37014, -3, 0.56907,-8.93313,-5.31596, 3.10096, 16,-10.6859, 1.70708,-7.24295,-10.6975}); - ops::helpers::SVD svd(matrix4, 4, true, true, true, 't'); - svd._m = matrix1; - svd._u = matrix2; - svd._v = matrix3; + ops::helpers::SVD svd(matrix4, 4, true, true, true, 't'); + svd._m = matrix1; + svd._u = matrix2; + svd._v = matrix3; - svd.DivideAndConquer(0, 3, 1, 1, 1); - // svd._m.printIndexedBuffer(); - ASSERT_TRUE(expM.isSameShapeStrict(svd._m)); - ASSERT_TRUE(expU.isSameShapeStrict(svd._u)); - ASSERT_TRUE(expV.isSameShapeStrict(svd._v)); + svd.DivideAndConquer(0, 3, 1, 1, 1); + // svd._m.printIndexedBuffer(); + ASSERT_TRUE(expM.isSameShapeStrict(svd._m)); + ASSERT_TRUE(expU.isSameShapeStrict(svd._u)); + ASSERT_TRUE(expV.isSameShapeStrict(svd._v)); - ASSERT_TRUE(expM.equalsTo(&svd._m)); - ASSERT_TRUE(expU.equalsTo(&svd._u)); - ASSERT_TRUE(expV.equalsTo(&svd._v)); + ASSERT_TRUE(expM.equalsTo(&svd._m)); + ASSERT_TRUE(expU.equalsTo(&svd._u)); + ASSERT_TRUE(expV.equalsTo(&svd._v)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, SVD_test17) { - auto matrix1 = NDArrayFactory::create('c', {6,5}, {-2 ,-3 ,2 ,1 ,0 ,0 ,-4 ,5 ,-2 ,-3 ,-4 ,0 ,5 ,-1 ,-5 ,-3 ,-5 ,3 ,3 ,3 ,-5 ,5 ,-5 ,0 ,2 ,-2 ,-3 ,-4 ,-5 ,-3}); - auto matrix2 = NDArrayFactory::create('c', {6,6}, {-10 ,-16 ,-20 ,13 ,20 ,-10 ,-9 ,-1 ,-7 ,-20 ,-4 ,20 ,-11 ,19 ,-5 ,-18 ,12 ,-19 ,18 ,-18 ,17 ,-10 ,-19 ,14 ,-2 ,-7 ,-17 ,-14 ,-4 ,-16 ,18 ,-6 ,-18 ,1 ,-15 ,-12}); - auto matrix3 = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); - auto matrix4 = NDArrayFactory::create('c', {5,5}, {3 ,-8 ,5 ,7 ,-8 ,4 ,-19 ,-12 ,-4 ,-5 ,-11 ,19 ,-2 ,-7 ,1 ,16 ,-5 ,10 ,19 ,-19 ,0 ,-20 ,0 ,-8 ,-13}); + auto matrix1 = NDArrayFactory::create('c', {6,5}, {-2 ,-3 ,2 ,1 ,0 ,0 ,-4 ,5 ,-2 ,-3 ,-4 ,0 ,5 ,-1 ,-5 ,-3 ,-5 ,3 ,3 ,3 ,-5 ,5 ,-5 ,0 ,2 ,-2 ,-3 ,-4 ,-5 ,-3}); + auto matrix2 = NDArrayFactory::create('c', {6,6}, {-10 ,-16 ,-20 ,13 ,20 ,-10 ,-9 ,-1 ,-7 ,-20 ,-4 ,20 ,-11 ,19 ,-5 ,-18 ,12 ,-19 ,18 ,-18 ,17 ,-10 ,-19 ,14 ,-2 ,-7 ,-17 ,-14 ,-4 ,-16 ,18 ,-6 ,-18 ,1 ,-15 ,-12}); + auto matrix3 = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); + auto matrix4 = NDArrayFactory::create('c', {5,5}, {3 ,-8 ,5 ,7 ,-8 ,4 ,-19 ,-12 ,-4 ,-5 ,-11 ,19 ,-2 ,-7 ,1 ,16 ,-5 ,10 ,19 ,-19 ,0 ,-20 ,0 ,-8 ,-13}); - auto expM = NDArrayFactory::create('c', {6,5}, {-2, -3, 2, 1, 0, 0,12.1676, 0, 0, 0, -4, 0,7.49514, 0, 0, -3, 0, 0,5.00951, 0, -5, 0, 0, 0, 1.63594, -2, 0, 0, 0, 0}); - auto expU = NDArrayFactory::create('c', {6,6}, {0.295543,-0.238695, 0.262095,-0.231772, -0.85631,-10, 0.519708,0.0571492,-0.368706,-0.727615, 0.247527, 20, 0.313717,-0.561567,-0.602941, 0.469567,-0.0468295,-19, 0.474589,-0.372165, 0.656962, 0.124776, 0.434845, 14, -0.564717,-0.697061,0.0150082, -0.4252, 0.119081,-16, 18, -6, -18, 1, -15,-12}); - auto expV = NDArrayFactory::create('c', {5,5}, {-18, 1, 19, -7, 1, 2,-0.0366659, 0.977361,-0.0316106,0.205967, -2, -0.670795, -0.151697, -0.503288,0.523185, -3, 0.740124,-0.0841435, -0.486714,0.456339, 16, 0.0300945, -0.121135, 0.71331,0.689645}); + auto expM = NDArrayFactory::create('c', {6,5}, {-2, -3, 2, 1, 0, 0,12.1676, 0, 0, 0, -4, 0,7.49514, 0, 0, -3, 0, 0,5.00951, 0, -5, 0, 0, 0, 1.63594, -2, 0, 0, 0, 0}); + auto expU = NDArrayFactory::create('c', {6,6}, {0.295543,-0.238695, 0.262095,-0.231772, -0.85631,-10, 0.519708,0.0571492,-0.368706,-0.727615, 0.247527, 20, 0.313717,-0.561567,-0.602941, 0.469567,-0.0468295,-19, 0.474589,-0.372165, 0.656962, 0.124776, 0.434845, 14, -0.564717,-0.697061,0.0150082, -0.4252, 0.119081,-16, 18, -6, -18, 1, -15,-12}); + auto expV = NDArrayFactory::create('c', {5,5}, {-18, 1, 19, -7, 1, 2,-0.0366659, 0.977361,-0.0316106,0.205967, -2, -0.670795, -0.151697, -0.503288,0.523185, -3, 0.740124,-0.0841435, -0.486714,0.456339, 16, 0.0300945, -0.121135, 0.71331,0.689645}); - ops::helpers::SVD svd(matrix4, 10, true, true, true, 't'); - svd._m = matrix1; - svd._u = matrix2; - svd._v = matrix3; + ops::helpers::SVD svd(matrix4, 10, true, true, true, 't'); + svd._m = matrix1; + svd._u = matrix2; + svd._v = matrix3; - svd.DivideAndConquer(0, 3, 1, 1, 1); + svd.DivideAndConquer(0, 3, 1, 1, 1); - ASSERT_TRUE(expM.equalsTo(&svd._m)); - ASSERT_TRUE(expU.equalsTo(&svd._u)); - ASSERT_TRUE(expV.equalsTo(&svd._v)); + ASSERT_TRUE(expM.equalsTo(&svd._m)); + ASSERT_TRUE(expU.equalsTo(&svd._u)); + ASSERT_TRUE(expV.equalsTo(&svd._v)); - ASSERT_TRUE(expM.isSameShapeStrict(svd._m)); - ASSERT_TRUE(expU.isSameShapeStrict(svd._u)); - ASSERT_TRUE(expV.isSameShapeStrict(svd._v)); + ASSERT_TRUE(expM.isSameShapeStrict(svd._m)); + ASSERT_TRUE(expU.isSameShapeStrict(svd._u)); + ASSERT_TRUE(expV.isSameShapeStrict(svd._v)); } // /////////////////////////////////////////////////////////////////// @@ -1324,430 +1324,430 @@ TEST_F(HelpersTests1, SVD_test17) { /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, rnnCell_test1) { - const int bS = 2; - const int inSize = 4; - const int numUnits = 4; + const int bS = 2; + const int inSize = 4; + const int numUnits = 4; - NDArray xt('c', {bS, inSize}, sd::DataType::DOUBLE); - NDArray ht_1('c', {bS, numUnits}, sd::DataType::DOUBLE); - NDArray Wx('c', {inSize, numUnits}, sd::DataType::DOUBLE); - NDArray Wh('c', {numUnits, numUnits}, sd::DataType::DOUBLE); - NDArray b ('c', {2*numUnits}, {0.0,0.0,0.0,0.0, 0.1,0.2,0.3,0.4}); - NDArray ht('c', {bS, numUnits}, sd::DataType::DOUBLE); + NDArray xt('c', {bS, inSize}, sd::DataType::DOUBLE); + NDArray ht_1('c', {bS, numUnits}, sd::DataType::DOUBLE); + NDArray Wx('c', {inSize, numUnits}, sd::DataType::DOUBLE); + NDArray Wh('c', {numUnits, numUnits}, sd::DataType::DOUBLE); + NDArray b ('c', {2*numUnits}, {0.0,0.0,0.0,0.0, 0.1,0.2,0.3,0.4}); + NDArray ht('c', {bS, numUnits}, sd::DataType::DOUBLE); - xt.assign(0.1); - ht_1.assign(0.2); - Wx.assign(0.3); - Wh.assign(0.4); + xt.assign(0.1); + ht_1.assign(0.2); + Wx.assign(0.3); + Wh.assign(0.4); - NDArray expHt('c', {bS, numUnits}, {0.492988, 0.56489956, 0.6291452 , 0.6858091,0.492988, 0.56489956, 0.6291452 , 0.6858091}); + NDArray expHt('c', {bS, numUnits}, {0.492988, 0.56489956, 0.6291452 , 0.6858091,0.492988, 0.56489956, 0.6291452 , 0.6858091}); - ops::helpers::rnnCell(sd::LaunchContext ::defaultContext(), &xt, &Wx, &Wh, &b, &ht_1, &ht); + ops::helpers::rnnCell(sd::LaunchContext ::defaultContext(), &xt, &Wx, &Wh, &b, &ht_1, &ht); - ASSERT_TRUE(expHt.isSameShape(ht)); - ASSERT_TRUE(expHt.equalsTo(ht)); + ASSERT_TRUE(expHt.isSameShape(ht)); + ASSERT_TRUE(expHt.equalsTo(ht)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, rnnCell_test2) { - const int bS = 2; - const int inSize = 10; - const int numUnits = 4; + const int bS = 2; + const int inSize = 10; + const int numUnits = 4; - auto xt = NDArrayFactory::create('c', {bS, inSize}); - auto ht_1 = NDArrayFactory::create('c', {bS, numUnits}); - auto Wx = NDArrayFactory::create('c', {inSize, numUnits}); - auto Wh = NDArrayFactory::create('c', {numUnits, numUnits}); - auto b = NDArrayFactory::create('c', {2*numUnits}, {0.0,0.0,0.0,0.0, 0.1,0.2,0.3,0.4}); + auto xt = NDArrayFactory::create('c', {bS, inSize}); + auto ht_1 = NDArrayFactory::create('c', {bS, numUnits}); + auto Wx = NDArrayFactory::create('c', {inSize, numUnits}); + auto Wh = NDArrayFactory::create('c', {numUnits, numUnits}); + auto b = NDArrayFactory::create('c', {2*numUnits}, {0.0,0.0,0.0,0.0, 0.1,0.2,0.3,0.4}); - auto ht = NDArrayFactory::create('c', {bS, numUnits}); + auto ht = NDArrayFactory::create('c', {bS, numUnits}); - xt.assign(0.1); - ht_1.assign(0.2); - Wx.assign(0.3); - Wh.assign(0.4); + xt.assign(0.1); + ht_1.assign(0.2); + Wx.assign(0.3); + Wh.assign(0.4); - auto expHt = NDArrayFactory::create('c', {bS, numUnits}, {0.6169093,0.67506987,0.72589741,0.76986654,0.6169093,0.67506987,0.72589741,0.76986654}); + auto expHt = NDArrayFactory::create('c', {bS, numUnits}, {0.6169093,0.67506987,0.72589741,0.76986654,0.6169093,0.67506987,0.72589741,0.76986654}); - ops::helpers::rnnCell(sd::LaunchContext ::defaultContext(), &xt, &Wx, &Wh, &b, &ht_1, &ht); + ops::helpers::rnnCell(sd::LaunchContext ::defaultContext(), &xt, &Wx, &Wh, &b, &ht_1, &ht); - ASSERT_TRUE(expHt.isSameShape(ht)); - ASSERT_TRUE(expHt.equalsTo(ht)); + ASSERT_TRUE(expHt.isSameShape(ht)); + ASSERT_TRUE(expHt.equalsTo(ht)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, rnnCell_test3) { - const int bS = 2; - const int inSize = 10; - const int numUnits = 4; + const int bS = 2; + const int inSize = 10; + const int numUnits = 4; - auto xt = NDArrayFactory::create('c', {bS, inSize}); - auto ht_1 = NDArrayFactory::create('c', {bS, numUnits}); - auto Wx = NDArrayFactory::create('c', {inSize, numUnits}); - auto Wh = NDArrayFactory::create('c', {numUnits, numUnits}); - auto b = NDArrayFactory::create('c', {2*numUnits}, {0.01,0.02,0.03,0.04, 0.05,0.06,0.07,0.08}); + auto xt = NDArrayFactory::create('c', {bS, inSize}); + auto ht_1 = NDArrayFactory::create('c', {bS, numUnits}); + auto Wx = NDArrayFactory::create('c', {inSize, numUnits}); + auto Wh = NDArrayFactory::create('c', {numUnits, numUnits}); + auto b = NDArrayFactory::create('c', {2*numUnits}, {0.01,0.02,0.03,0.04, 0.05,0.06,0.07,0.08}); - auto ht = NDArrayFactory::create('c', {bS, numUnits}); + auto ht = NDArrayFactory::create('c', {bS, numUnits}); - xt.assign(0.1); - ht_1.assign(0.2); - Wx.assign(0.3); - Wh.assign(0.4); + xt.assign(0.1); + ht_1.assign(0.2); + Wx.assign(0.3); + Wh.assign(0.4); - auto expHt = NDArrayFactory::create('c', {bS, numUnits}, {0.5915195, 0.6043678, 0.6169093, 0.6291452,0.5915195, 0.6043678, 0.6169093, 0.6291452}); + auto expHt = NDArrayFactory::create('c', {bS, numUnits}, {0.5915195, 0.6043678, 0.6169093, 0.6291452,0.5915195, 0.6043678, 0.6169093, 0.6291452}); - ops::helpers::rnnCell(sd::LaunchContext ::defaultContext(), &xt, &Wx, &Wh, &b, &ht_1, &ht); + ops::helpers::rnnCell(sd::LaunchContext ::defaultContext(), &xt, &Wx, &Wh, &b, &ht_1, &ht); - ASSERT_TRUE(expHt.isSameShape(ht)); - ASSERT_TRUE(expHt.equalsTo(ht)); + ASSERT_TRUE(expHt.isSameShape(ht)); + ASSERT_TRUE(expHt.equalsTo(ht)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, rnnCell_test4) { - const int bS = 2; - const int inSize = 3; - const int numUnits = 4; + const int bS = 2; + const int inSize = 3; + const int numUnits = 4; - auto xt = NDArrayFactory::create('c', {bS, inSize}); - auto ht_1 = NDArrayFactory::create('c', {bS, numUnits}); - auto Wx = NDArrayFactory::create('c', {inSize, numUnits}); - auto Wh = NDArrayFactory::create('c', {numUnits, numUnits}); - auto b = NDArrayFactory::create('c', {2*numUnits}); + auto xt = NDArrayFactory::create('c', {bS, inSize}); + auto ht_1 = NDArrayFactory::create('c', {bS, numUnits}); + auto Wx = NDArrayFactory::create('c', {inSize, numUnits}); + auto Wh = NDArrayFactory::create('c', {numUnits, numUnits}); + auto b = NDArrayFactory::create('c', {2*numUnits}); - auto ht = NDArrayFactory::create('c', {bS, numUnits}); + auto ht = NDArrayFactory::create('c', {bS, numUnits}); - xt.linspace(0.01, 0.01); - ht_1 = 0.2; - Wx = 0.3; - Wh = 0.4; - b = 0.25; + xt.linspace(0.01, 0.01); + ht_1 = 0.2; + Wx = 0.3; + Wh = 0.4; + b = 0.25; - auto expHt = NDArrayFactory::create('c', {bS, numUnits}, {0.68474828, 0.68474828, 0.68474828, 0.68474828,0.69882484, 0.69882484, 0.69882484, 0.69882484}); + auto expHt = NDArrayFactory::create('c', {bS, numUnits}, {0.68474828, 0.68474828, 0.68474828, 0.68474828,0.69882484, 0.69882484, 0.69882484, 0.69882484}); - ops::helpers::rnnCell(sd::LaunchContext ::defaultContext(), &xt, &Wx, &Wh, &b, &ht_1, &ht); + ops::helpers::rnnCell(sd::LaunchContext ::defaultContext(), &xt, &Wx, &Wh, &b, &ht_1, &ht); - ASSERT_TRUE(expHt.isSameShape(ht)); - ASSERT_TRUE(expHt.equalsTo(ht)); + ASSERT_TRUE(expHt.isSameShape(ht)); + ASSERT_TRUE(expHt.equalsTo(ht)); } #endif //////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, mmulHelper_test_1) { - auto x = NDArrayFactory::create('c', {3,3}, {10,11,12,13,14,15,16,17,18}); - auto y = NDArrayFactory::create('c', {3,3}, {1,2,3,4,5,6,7,8,9}); - auto expected = NDArrayFactory::create('c', {3,3}, {138.,171.,204. ,174.,216.,258. ,210.,261.,312.}); + auto x = NDArrayFactory::create('c', {3,3}, {10,11,12,13,14,15,16,17,18}); + auto y = NDArrayFactory::create('c', {3,3}, {1,2,3,4,5,6,7,8,9}); + auto expected = NDArrayFactory::create('c', {3,3}, {138.,171.,204. ,174.,216.,258. ,210.,261.,312.}); - auto result = MmulHelper::mmul(&x, &y, nullptr, 1., 0.); + auto result = MmulHelper::mmul(&x, &y, nullptr, 1., 0.); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); - delete result; + delete result; } //////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, mmulHelper_test_2) { - auto x = NDArrayFactory::create('c', {3,3}, {10,11,12,13,14,15,16,17,18}); - auto y = NDArrayFactory::create('c', {3,3}, {1,2,3,4,5,6,7,8,9}); - auto expected = NDArrayFactory::create('c', {3,3}, {138.,171.,204. ,174.,216.,258. ,210.,261.,312.}); - auto result = NDArrayFactory::create('c', {3,3}); + auto x = NDArrayFactory::create('c', {3,3}, {10,11,12,13,14,15,16,17,18}); + auto y = NDArrayFactory::create('c', {3,3}, {1,2,3,4,5,6,7,8,9}); + auto expected = NDArrayFactory::create('c', {3,3}, {138.,171.,204. ,174.,216.,258. ,210.,261.,312.}); + auto result = NDArrayFactory::create('c', {3,3}); - MmulHelper::mmul(&x, &y, &result, 1., 0.); + MmulHelper::mmul(&x, &y, &result, 1., 0.); - ASSERT_TRUE(expected.isSameShape(&result)); - ASSERT_TRUE(expected.equalsTo(&result)); + ASSERT_TRUE(expected.isSameShape(&result)); + ASSERT_TRUE(expected.equalsTo(&result)); } //////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, mmulHelper_test_3) { - auto x = NDArrayFactory::create('c', {3,4}); x.linspace(1); - auto y = NDArrayFactory::create('c', {4,5}); y.linspace(1); - auto expected = NDArrayFactory::create('c', {3,5}, {110.,120.,130.,140.,150.,246.,272.,298.,324.,350.,382.,424.,466.,508.,550.}); + auto x = NDArrayFactory::create('c', {3,4}); x.linspace(1); + auto y = NDArrayFactory::create('c', {4,5}); y.linspace(1); + auto expected = NDArrayFactory::create('c', {3,5}, {110.,120.,130.,140.,150.,246.,272.,298.,324.,350.,382.,424.,466.,508.,550.}); - auto result = MmulHelper::mmul(&x, &y, nullptr, 1., 0.); + auto result = MmulHelper::mmul(&x, &y, nullptr, 1., 0.); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); - delete result; + delete result; } //////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, mmulHelper_test_4) { - auto x = NDArrayFactory::create('c', {3,4}); x.linspace(1); - auto y = NDArrayFactory::create('c', {4,5}); y.linspace(1); - auto expected = NDArrayFactory::create('c', {3,5}, {110.,120.,130.,140.,150.,246.,272.,298.,324.,350.,382.,424.,466.,508.,550.}); - auto result = NDArrayFactory::create('c', {3,5}); + auto x = NDArrayFactory::create('c', {3,4}); x.linspace(1); + auto y = NDArrayFactory::create('c', {4,5}); y.linspace(1); + auto expected = NDArrayFactory::create('c', {3,5}, {110.,120.,130.,140.,150.,246.,272.,298.,324.,350.,382.,424.,466.,508.,550.}); + auto result = NDArrayFactory::create('c', {3,5}); - MmulHelper::mmul(&x, &y, &result, 1., 0.); + MmulHelper::mmul(&x, &y, &result, 1., 0.); - ASSERT_TRUE(expected.isSameShape(&result)); - ASSERT_TRUE(expected.equalsTo(&result)); + ASSERT_TRUE(expected.isSameShape(&result)); + ASSERT_TRUE(expected.equalsTo(&result)); } //////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, mmulHelper_test_5) { - auto x = NDArrayFactory::create('c', {4,3}); x.linspace(1); - auto y = NDArrayFactory::create('c', {3,5}); y.linspace(1); - auto expected = NDArrayFactory::create('c', {4,5}, {46., 52., 58., 64., 70.,100.,115.,130.,145.,160.,154.,178.,202.,226.,250.,208.,241.,274.,307.,340.}); + auto x = NDArrayFactory::create('c', {4,3}); x.linspace(1); + auto y = NDArrayFactory::create('c', {3,5}); y.linspace(1); + auto expected = NDArrayFactory::create('c', {4,5}, {46., 52., 58., 64., 70.,100.,115.,130.,145.,160.,154.,178.,202.,226.,250.,208.,241.,274.,307.,340.}); - auto result = MmulHelper::mmul(&x, &y, nullptr, 1., 0.); + auto result = MmulHelper::mmul(&x, &y, nullptr, 1., 0.); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); - delete result; + delete result; } //////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, mmulHelper_test_6) { - auto x = NDArrayFactory::create('c', {4,3}); x.linspace(1); - auto y = NDArrayFactory::create('c', {3,5}); y.linspace(1); - auto expected = NDArrayFactory::create('c', {4,5}, {46., 52., 58., 64., 70.,100.,115.,130.,145.,160.,154.,178.,202.,226.,250.,208.,241.,274.,307.,340.}); - auto result = NDArrayFactory::create('c', {4,5}); + auto x = NDArrayFactory::create('c', {4,3}); x.linspace(1); + auto y = NDArrayFactory::create('c', {3,5}); y.linspace(1); + auto expected = NDArrayFactory::create('c', {4,5}, {46., 52., 58., 64., 70.,100.,115.,130.,145.,160.,154.,178.,202.,226.,250.,208.,241.,274.,307.,340.}); + auto result = NDArrayFactory::create('c', {4,5}); - MmulHelper::mmul(&x, &y, &result, 1., 0.); + MmulHelper::mmul(&x, &y, &result, 1., 0.); - ASSERT_TRUE(expected.isSameShape(&result)); - ASSERT_TRUE(expected.equalsTo(&result)); + ASSERT_TRUE(expected.isSameShape(&result)); + ASSERT_TRUE(expected.equalsTo(&result)); } //////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, mmulHelper_test_7) { - auto x = NDArrayFactory::create('c', {4, 1}, {1, 2, 3, 4}); - auto y = NDArrayFactory::create('c', {1, 4}, {1, 2, 3, 4}); - auto exp = NDArrayFactory::create('c', {4, 4}, {1,2, 3, 4,2,4, 6, 8,3,6, 9,12,4,8,12,16}); - auto result = NDArrayFactory::create('c', {4,4}); + auto x = NDArrayFactory::create('c', {4, 1}, {1, 2, 3, 4}); + auto y = NDArrayFactory::create('c', {1, 4}, {1, 2, 3, 4}); + auto exp = NDArrayFactory::create('c', {4, 4}, {1,2, 3, 4,2,4, 6, 8,3,6, 9,12,4,8,12,16}); + auto result = NDArrayFactory::create('c', {4,4}); - MmulHelper::mmul(&x, &y, &result, 1., 0.); + MmulHelper::mmul(&x, &y, &result, 1., 0.); - ASSERT_TRUE(exp.isSameShape(&result)); - ASSERT_TRUE(exp.equalsTo(&result)); + ASSERT_TRUE(exp.isSameShape(&result)); + ASSERT_TRUE(exp.equalsTo(&result)); } //////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, tensordot_test_1) { - auto a = NDArrayFactory::create('c', {2, 3, 4}); - auto b = NDArrayFactory::create('c', {2, 5, 3}); + auto a = NDArrayFactory::create('c', {2, 3, 4}); + auto b = NDArrayFactory::create('c', {2, 5, 3}); - auto c = MmulHelper::tensorDot(&a, &b, {1}, {2}); + auto c = MmulHelper::tensorDot(&a, &b, {1}, {2}); - ASSERT_TRUE(c->isSameShape({2,4,2,5})); - delete c; + ASSERT_TRUE(c->isSameShape({2,4,2,5})); + delete c; } //////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, tensordot_test_2) { - auto a = NDArrayFactory::create('c', {7, 3, 4, 6}); - auto b = NDArrayFactory::create('c', {2, 5, 3, 8, 4}); + auto a = NDArrayFactory::create('c', {7, 3, 4, 6}); + auto b = NDArrayFactory::create('c', {2, 5, 3, 8, 4}); - auto c = MmulHelper::tensorDot(&a, &b, {2,1}, {4,2}); + auto c = MmulHelper::tensorDot(&a, &b, {2,1}, {4,2}); - ASSERT_TRUE(c->isSameShape({7,6,2,5,8})); - delete c; + ASSERT_TRUE(c->isSameShape({7,6,2,5,8})); + delete c; } //////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, tensordot_test_3) { - auto a = NDArrayFactory::create('c', {7, 3, 4, 6}); - auto b = NDArrayFactory::create('c', {2, 5, 3, 8, 4}); - auto c = NDArrayFactory::create('f', {7,6,2,8,5}); + auto a = NDArrayFactory::create('c', {7, 3, 4, 6}); + auto b = NDArrayFactory::create('c', {2, 5, 3, 8, 4}); + auto c = NDArrayFactory::create('f', {7,6,2,8,5}); - MmulHelper::tensorDot(&a, &b, &c, {2,1}, {4,2}, {0,1,2,4,3}); + MmulHelper::tensorDot(&a, &b, &c, {2,1}, {4,2}, {0,1,2,4,3}); - ASSERT_TRUE(c.isSameShape({7,6,2,8,5})); + ASSERT_TRUE(c.isSameShape({7,6,2,8,5})); } //////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, tensordot_test_4) { - auto a = NDArrayFactory::create('c', {7, 3, 4, 3}); - auto b = NDArrayFactory::create('c', {2, 5, 3, 2, 4}); - auto c = NDArrayFactory::create('f', {7,3,2,2,5}); - auto expected = NDArrayFactory::create('c', {7,3,2,2,5}, { 754.5, 2014.5, 3274.5, 4534.5 , 5794.5, 964.5, 2224.5, 3484.5, 4744.5, 6004.5, 7054.5, 8314.5, 9574.5, 10834.5, 12094.5, 7264.5, 8524.5, 9784.5, 11044.5, 12304.5, 786. , 2118. , 3450. , 4782. , 6114. , 1008. , 2340. , 3672. , 5004. , 6336. , - 7446. , 8778. , 10110. , 11442. , 12774. , 7668. , 9000. , 10332. , 11664. , 12996. , 817.5, 2221.5, 3625.5, 5029.5, 6433.5, 1051.5, 2455.5, 3859.5, 5263.5, 6667.5, 7837.5, 9241.5, 10645.5, 12049.5, 13453.5, 8071.5, 9475.5, 10879.5, 12283.5, 13687.5, - 1888.5, 5740.5, 9592.5, 13444.5, 17296.5, 2530.5, 6382.5, 10234.5, 14086.5, 17938.5,21148.5, 25000.5, 28852.5, 32704.5, 36556.5,21790.5, 25642.5, 29494.5, 33346.5, 37198.5, 1920. , 5844. , 9768. , 13692. , 17616. , 2574. , 6498. , 10422. , 14346. , 18270. , - 21540. , 25464. , 29388. , 33312. , 37236. ,22194. , 26118. , 30042. , 33966. , 37890. , 1951.5, 5947.5, 9943.5, 13939.5, 17935.5, 2617.5, 6613.5, 10609.5, 14605.5, 18601.5,21931.5, 25927.5, 29923.5, 33919.5, 37915.5,22597.5, 26593.5, 30589.5, 34585.5, 38581.5, - 3022.5, 9466.5, 15910.5, 22354.5, 28798.5, 4096.5, 10540.5, 16984.5, 23428.5, 29872.5,35242.5, 41686.5, 48130.5, 54574.5, 61018.5,36316.5, 42760.5, 49204.5, 55648.5, 62092.5, 3054. , 9570. , 16086. , 22602. , 29118. , 4140. , 10656. , 17172. , 23688. , 30204. , - 35634. , 42150. , 48666. , 55182. , 61698. ,36720. , 43236. , 49752. , 56268. , 62784. , 3085.5, 9673.5, 16261.5, 22849.5, 29437.5, 4183.5, 10771.5, 17359.5, 23947.5, 30535.5,36025.5, 42613.5, 49201.5, 55789.5, 62377.5,37123.5, 43711.5, 50299.5, 56887.5, 63475.5, - 4156.5, 13192.5, 22228.5, 31264.5, 40300.5, 5662.5, 14698.5, 23734.5, 32770.5, 41806.5,49336.5, 58372.5, 67408.5, 76444.5, 85480.5,50842.5, 59878.5, 68914.5, 77950.5, 86986.5, 4188. , 13296. , 22404. , 31512. , 40620. , 5706. , 14814. , 23922. , 33030. , 42138. , - 49728. , 58836. , 67944. , 77052. , 86160. ,51246. , 60354. , 69462. , 78570. , 87678. , 4219.5, 13399.5, 22579.5, 31759.5, 40939.5, 5749.5, 14929.5, 24109.5, 33289.5, 42469.5,50119.5, 59299.5, 68479.5, 77659.5, 86839.5,51649.5, 60829.5, 70009.5, 79189.5, 88369.5, - 5290.5, 16918.5, 28546.5, 40174.5, 51802.5, 7228.5, 18856.5, 30484.5, 42112.5, 53740.5,63430.5, 75058.5, 86686.5, 98314.5,109942.5,65368.5, 76996.5, 88624.5,100252.5,111880.5, 5322. , 17022. , 28722. , 40422. , 52122. , 7272. , 18972. , 30672. , 42372. , 54072. , - 63822. , 75522. , 87222. , 98922. ,110622. ,65772. , 77472. , 89172. ,100872. ,112572. , 5353.5, 17125.5, 28897.5, 40669.5, 52441.5, 7315.5, 19087.5, 30859.5, 42631.5, 54403.5,64213.5, 75985.5, 87757.5, 99529.5,111301.5,66175.5, 77947.5, 89719.5,101491.5,113263.5, - 6424.5, 20644.5, 34864.5, 49084.5, 63304.5, 8794.5, 23014.5, 37234.5, 51454.5, 65674.5,77524.5, 91744.5,105964.5,120184.5,134404.5,79894.5, 94114.5,108334.5,122554.5,136774.5, 6456. , 20748. , 35040. , 49332. , 63624. , 8838. , 23130. , 37422. , 51714. , 66006. , - 77916. , 92208. ,106500. ,120792. ,135084. ,80298. , 94590. ,108882. ,123174. ,137466. , 6487.5, 20851.5, 35215.5, 49579.5, 63943.5, 8881.5, 23245.5, 37609.5, 51973.5, 66337.5,78307.5, 92671.5,107035.5,121399.5,135763.5,80701.5, 95065.5,109429.5,123793.5,138157.5, - 7558.5, 24370.5, 41182.5, 57994.5, 74806.5,10360.5, 27172.5, 43984.5, 60796.5, 77608.5,91618.5,108430.5,125242.5,142054.5,158866.5,94420.5,111232.5,128044.5,144856.5,161668.5, 7590. , 24474. , 41358. , 58242. , 75126. ,10404. , 27288. , 44172. , 61056. , 77940. , - 92010. ,108894. ,125778. ,142662. ,159546. ,94824. ,111708. ,128592. ,145476. ,162360. , 7621.5, 24577.5, 41533.5, 58489.5, 75445.5,10447.5, 27403.5, 44359.5, 61315.5, 78271.5,92401.5,109357.5,126313.5,143269.5,160225.5,95227.5,112183.5,129139.5,146095.5,163051.5}); - - a.linspace(0.5, 0.5); - b.linspace(0.5, 0.5); - - MmulHelper::tensorDot(&a, &b, &c, {2,1}, {4,2}, {0,1,2,4,3}); - - ASSERT_TRUE(c.isSameShape(expected)); - ASSERT_TRUE(c.equalsTo(expected)); + auto a = NDArrayFactory::create('c', {7, 3, 4, 3}); + auto b = NDArrayFactory::create('c', {2, 5, 3, 2, 4}); + auto c = NDArrayFactory::create('f', {7,3,2,2,5}); + auto expected = NDArrayFactory::create('c', {7,3,2,2,5}, { 754.5, 2014.5, 3274.5, 4534.5 , 5794.5, 964.5, 2224.5, 3484.5, 4744.5, 6004.5, 7054.5, 8314.5, 9574.5, 10834.5, 12094.5, 7264.5, 8524.5, 9784.5, 11044.5, 12304.5, 786. , 2118. , 3450. , 4782. , 6114. , 1008. , 2340. , 3672. , 5004. , 6336. , + 7446. , 8778. , 10110. , 11442. , 12774. , 7668. , 9000. , 10332. , 11664. , 12996. , 817.5, 2221.5, 3625.5, 5029.5, 6433.5, 1051.5, 2455.5, 3859.5, 5263.5, 6667.5, 7837.5, 9241.5, 10645.5, 12049.5, 13453.5, 8071.5, 9475.5, 10879.5, 12283.5, 13687.5, + 1888.5, 5740.5, 9592.5, 13444.5, 17296.5, 2530.5, 6382.5, 10234.5, 14086.5, 17938.5,21148.5, 25000.5, 28852.5, 32704.5, 36556.5,21790.5, 25642.5, 29494.5, 33346.5, 37198.5, 1920. , 5844. , 9768. , 13692. , 17616. , 2574. , 6498. , 10422. , 14346. , 18270. , + 21540. , 25464. , 29388. , 33312. , 37236. ,22194. , 26118. , 30042. , 33966. , 37890. , 1951.5, 5947.5, 9943.5, 13939.5, 17935.5, 2617.5, 6613.5, 10609.5, 14605.5, 18601.5,21931.5, 25927.5, 29923.5, 33919.5, 37915.5,22597.5, 26593.5, 30589.5, 34585.5, 38581.5, + 3022.5, 9466.5, 15910.5, 22354.5, 28798.5, 4096.5, 10540.5, 16984.5, 23428.5, 29872.5,35242.5, 41686.5, 48130.5, 54574.5, 61018.5,36316.5, 42760.5, 49204.5, 55648.5, 62092.5, 3054. , 9570. , 16086. , 22602. , 29118. , 4140. , 10656. , 17172. , 23688. , 30204. , + 35634. , 42150. , 48666. , 55182. , 61698. ,36720. , 43236. , 49752. , 56268. , 62784. , 3085.5, 9673.5, 16261.5, 22849.5, 29437.5, 4183.5, 10771.5, 17359.5, 23947.5, 30535.5,36025.5, 42613.5, 49201.5, 55789.5, 62377.5,37123.5, 43711.5, 50299.5, 56887.5, 63475.5, + 4156.5, 13192.5, 22228.5, 31264.5, 40300.5, 5662.5, 14698.5, 23734.5, 32770.5, 41806.5,49336.5, 58372.5, 67408.5, 76444.5, 85480.5,50842.5, 59878.5, 68914.5, 77950.5, 86986.5, 4188. , 13296. , 22404. , 31512. , 40620. , 5706. , 14814. , 23922. , 33030. , 42138. , + 49728. , 58836. , 67944. , 77052. , 86160. ,51246. , 60354. , 69462. , 78570. , 87678. , 4219.5, 13399.5, 22579.5, 31759.5, 40939.5, 5749.5, 14929.5, 24109.5, 33289.5, 42469.5,50119.5, 59299.5, 68479.5, 77659.5, 86839.5,51649.5, 60829.5, 70009.5, 79189.5, 88369.5, + 5290.5, 16918.5, 28546.5, 40174.5, 51802.5, 7228.5, 18856.5, 30484.5, 42112.5, 53740.5,63430.5, 75058.5, 86686.5, 98314.5,109942.5,65368.5, 76996.5, 88624.5,100252.5,111880.5, 5322. , 17022. , 28722. , 40422. , 52122. , 7272. , 18972. , 30672. , 42372. , 54072. , + 63822. , 75522. , 87222. , 98922. ,110622. ,65772. , 77472. , 89172. ,100872. ,112572. , 5353.5, 17125.5, 28897.5, 40669.5, 52441.5, 7315.5, 19087.5, 30859.5, 42631.5, 54403.5,64213.5, 75985.5, 87757.5, 99529.5,111301.5,66175.5, 77947.5, 89719.5,101491.5,113263.5, + 6424.5, 20644.5, 34864.5, 49084.5, 63304.5, 8794.5, 23014.5, 37234.5, 51454.5, 65674.5,77524.5, 91744.5,105964.5,120184.5,134404.5,79894.5, 94114.5,108334.5,122554.5,136774.5, 6456. , 20748. , 35040. , 49332. , 63624. , 8838. , 23130. , 37422. , 51714. , 66006. , + 77916. , 92208. ,106500. ,120792. ,135084. ,80298. , 94590. ,108882. ,123174. ,137466. , 6487.5, 20851.5, 35215.5, 49579.5, 63943.5, 8881.5, 23245.5, 37609.5, 51973.5, 66337.5,78307.5, 92671.5,107035.5,121399.5,135763.5,80701.5, 95065.5,109429.5,123793.5,138157.5, + 7558.5, 24370.5, 41182.5, 57994.5, 74806.5,10360.5, 27172.5, 43984.5, 60796.5, 77608.5,91618.5,108430.5,125242.5,142054.5,158866.5,94420.5,111232.5,128044.5,144856.5,161668.5, 7590. , 24474. , 41358. , 58242. , 75126. ,10404. , 27288. , 44172. , 61056. , 77940. , + 92010. ,108894. ,125778. ,142662. ,159546. ,94824. ,111708. ,128592. ,145476. ,162360. , 7621.5, 24577.5, 41533.5, 58489.5, 75445.5,10447.5, 27403.5, 44359.5, 61315.5, 78271.5,92401.5,109357.5,126313.5,143269.5,160225.5,95227.5,112183.5,129139.5,146095.5,163051.5}); + + a.linspace(0.5, 0.5); + b.linspace(0.5, 0.5); + + MmulHelper::tensorDot(&a, &b, &c, {2,1}, {4,2}, {0,1,2,4,3}); + + ASSERT_TRUE(c.isSameShape(expected)); + ASSERT_TRUE(c.equalsTo(expected)); } //////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, tensordot_test_5) { - auto a = NDArrayFactory::create('c', {2, 3}); - auto b = NDArrayFactory::create('c', {3, 4}); - auto c = NDArrayFactory::create('f', {2, 4}); - auto expected = NDArrayFactory::create('c', {2, 4}, {9.5,11.,12.5 ,14.,20.75 ,24.5,28.25,32.}); + auto a = NDArrayFactory::create('c', {2, 3}); + auto b = NDArrayFactory::create('c', {3, 4}); + auto c = NDArrayFactory::create('f', {2, 4}); + auto expected = NDArrayFactory::create('c', {2, 4}, {9.5,11.,12.5 ,14.,20.75 ,24.5,28.25,32.}); - a.linspace(0.5, 0.5); - b.linspace(0.5, 0.5); + a.linspace(0.5, 0.5); + b.linspace(0.5, 0.5); - MmulHelper::tensorDot(&a, &b, &c, {1}, {0}); - // c.printIndexedBuffer(); + MmulHelper::tensorDot(&a, &b, &c, {1}, {0}); + // c.printIndexedBuffer(); - ASSERT_TRUE(c.isSameShape(expected)); - ASSERT_TRUE(c.equalsTo(expected)); + ASSERT_TRUE(c.isSameShape(expected)); + ASSERT_TRUE(c.equalsTo(expected)); } //////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, tensordot_test_6) { - int bS=2, iH=3,iW=2, iC=2,mC=2, kH=2,kW=2; - int oC=iC*mC; - int oH=3,oW=2; + int bS=2, iH=3,iW=2, iC=2,mC=2, kH=2,kW=2; + int oC=iC*mC; + int oH=3,oW=2; - auto a = NDArrayFactory::create('c', {bS, iC, kH, kW, oH, oW}); - auto b = NDArrayFactory::create('c', {kH, kW, iC, mC}); - auto c = NDArrayFactory::create('c', {bS, oH, oW, iC*mC}); - auto expected = NDArrayFactory::create('c', {bS, oH, oW, iC*mC}, {100.,110.,336.,370.,107.,118.,345.,380.,114.,126.,354.,390.,121.,134.,363.,400.,128.,142.,372.,410.,135.,150.,381.,420., - 436.,494.,768.,850.,443.,502.,777.,860.,450.,510.,786.,870.,457.,518.,795.,880.,464.,526.,804.,890.,471.,534.,813.,900.}); + auto a = NDArrayFactory::create('c', {bS, iC, kH, kW, oH, oW}); + auto b = NDArrayFactory::create('c', {kH, kW, iC, mC}); + auto c = NDArrayFactory::create('c', {bS, oH, oW, iC*mC}); + auto expected = NDArrayFactory::create('c', {bS, oH, oW, iC*mC}, {100.,110.,336.,370.,107.,118.,345.,380.,114.,126.,354.,390.,121.,134.,363.,400.,128.,142.,372.,410.,135.,150.,381.,420., + 436.,494.,768.,850.,443.,502.,777.,860.,450.,510.,786.,870.,457.,518.,795.,880.,464.,526.,804.,890.,471.,534.,813.,900.}); - a.linspace(0.5, 0.5); - b.linspace(0.5, 0.5); + a.linspace(0.5, 0.5); + b.linspace(0.5, 0.5); - auto cR = c.reshape(a.ordering(), {bS, oH, oW, iC, mC}); + auto cR = c.reshape(a.ordering(), {bS, oH, oW, iC, mC}); - // [iC, bS*oH*oW, kW*kH] x [iC, kH*kW, mC] = [iC, bS*oH*oW, mC] - MmulHelper::tensorDot(&a, &b, &cR, {{1,0,4,5,2,3}, {iC,bS*oH*oW,kW*kH}}, {{2,0,1,3},{iC,kH*kW,mC}}, {{3,0,1,2,4},{iC, bS*oH*oW, mC}}); - // c.printBuffer(); + // [iC, bS*oH*oW, kW*kH] x [iC, kH*kW, mC] = [iC, bS*oH*oW, mC] + MmulHelper::tensorDot(&a, &b, &cR, {{1,0,4,5,2,3}, {iC,bS*oH*oW,kW*kH}}, {{2,0,1,3},{iC,kH*kW,mC}}, {{3,0,1,2,4},{iC, bS*oH*oW, mC}}); + // c.printBuffer(); - ASSERT_TRUE(c.isSameShape(expected)); - ASSERT_TRUE(c.equalsTo(expected)); + ASSERT_TRUE(c.isSameShape(expected)); + ASSERT_TRUE(c.equalsTo(expected)); } //////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, mmmulHelperAgain) { - auto x = NDArrayFactory::create('c', {128, 156}); - auto y = NDArrayFactory::create('c', {156, 256}); - auto z = NDArrayFactory::create('c', {128, 256}); - auto e = NDArrayFactory::create('c', {128, 256}); + auto x = NDArrayFactory::create('c', {128, 156}); + auto y = NDArrayFactory::create('c', {156, 256}); + auto z = NDArrayFactory::create('c', {128, 256}); + auto e = NDArrayFactory::create('c', {128, 256}); - x.assign(1.0f); - y.assign(1.0f); - e.assign(156.0f); + x.assign(1.0f); + y.assign(1.0f); + e.assign(156.0f); - MmulHelper::mmul(&x, &y, &z); + MmulHelper::mmul(&x, &y, &z); - ASSERT_TRUE(e.isSameShape(z)); - ASSERT_TRUE(e.equalsTo(z)); + ASSERT_TRUE(e.isSameShape(z)); + ASSERT_TRUE(e.equalsTo(z)); } //////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, OpArgsHolder_test1) { - auto x1 = NDArrayFactory::create('c', {1, 1}); - auto x2 = NDArrayFactory::create('c', {2, 2}); - auto x3 = NDArrayFactory::create('c', {3, 3}); + auto x1 = NDArrayFactory::create('c', {1, 1}); + auto x2 = NDArrayFactory::create('c', {2, 2}); + auto x3 = NDArrayFactory::create('c', {3, 3}); - OpArgsHolder holder1({&x1}); - OpArgsHolder holder2({&x1,&x2,&x3}, {4.f, 5.f}, {6}); + OpArgsHolder holder1({&x1}); + OpArgsHolder holder2({&x1,&x2,&x3}, {4.f, 5.f}, {6}); - ASSERT_TRUE(holder1.getNumInArrs() == 1); - ASSERT_TRUE(holder1.getNumTArgs() == 0); - ASSERT_TRUE(holder1.getNumIArgs() == 0); + ASSERT_TRUE(holder1.getNumInArrs() == 1); + ASSERT_TRUE(holder1.getNumTArgs() == 0); + ASSERT_TRUE(holder1.getNumIArgs() == 0); - ASSERT_TRUE(holder2.getNumInArrs() == 3); - ASSERT_TRUE(holder2.getNumTArgs() == 2); - ASSERT_TRUE(holder2.getNumIArgs() == 1); + ASSERT_TRUE(holder2.getNumInArrs() == 3); + ASSERT_TRUE(holder2.getNumTArgs() == 2); + ASSERT_TRUE(holder2.getNumIArgs() == 1); - const std::vector& isArrAlloc1 = holder1.getAllocInfo(); - ASSERT_TRUE(isArrAlloc1.size() == 0); + const std::vector& isArrAlloc1 = holder1.getAllocInfo(); + ASSERT_TRUE(isArrAlloc1.size() == 0); - const std::vector& isArrAlloc2 = holder2.getAllocInfo(); - ASSERT_TRUE(isArrAlloc2.size() == 0); + const std::vector& isArrAlloc2 = holder2.getAllocInfo(); + ASSERT_TRUE(isArrAlloc2.size() == 0); } //////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, OpArgsHolder_test2) { - auto x1 = NDArrayFactory::create('c', {1, 1}); - auto x2 = NDArrayFactory::create('c', {2, 2}); - auto x3 = NDArrayFactory::create('c', {3, 3}); - auto grad = NDArrayFactory::create('c', {2, 3}); + auto x1 = NDArrayFactory::create('c', {1, 1}); + auto x2 = NDArrayFactory::create('c', {2, 2}); + auto x3 = NDArrayFactory::create('c', {3, 3}); + auto grad = NDArrayFactory::create('c', {2, 3}); - OpArgsHolder holderFF({&x1,&x2,&x3}, {4.f, 5.f}, {6}); - OpArgsHolder holderBP1 = holderFF.createArgsHolderForBP({&grad}); - OpArgsHolder holderBP2 = holderFF.createArgsHolderForBP({&grad}, true); + OpArgsHolder holderFF({&x1,&x2,&x3}, {4.f, 5.f}, {6}); + OpArgsHolder holderBP1 = holderFF.createArgsHolderForBP({&grad}); + OpArgsHolder holderBP2 = holderFF.createArgsHolderForBP({&grad}, true); - ASSERT_TRUE(holderBP1.getNumInArrs() == 4); - ASSERT_TRUE(holderBP1.getNumTArgs() == 2); - ASSERT_TRUE(holderBP1.getNumIArgs() == 1); - ASSERT_TRUE(holderBP2.getNumInArrs() == 4); - ASSERT_TRUE(holderBP2.getNumTArgs() == 2); - ASSERT_TRUE(holderBP2.getNumIArgs() == 1); + ASSERT_TRUE(holderBP1.getNumInArrs() == 4); + ASSERT_TRUE(holderBP1.getNumTArgs() == 2); + ASSERT_TRUE(holderBP1.getNumIArgs() == 1); + ASSERT_TRUE(holderBP2.getNumInArrs() == 4); + ASSERT_TRUE(holderBP2.getNumTArgs() == 2); + ASSERT_TRUE(holderBP2.getNumIArgs() == 1); - const std::vector& isArrAllocBP1 = holderBP1.getAllocInfo(); - ASSERT_TRUE(isArrAllocBP1.size() == 0); + const std::vector& isArrAllocBP1 = holderBP1.getAllocInfo(); + ASSERT_TRUE(isArrAllocBP1.size() == 0); - const std::vector& isArrAllocBP2 = holderBP2.getAllocInfo(); - for(int i = 0; i < holderFF.getNumInArrs(); ++i) { - ASSERT_TRUE(static_cast(isArrAllocBP2[i]) == true); - } + const std::vector& isArrAllocBP2 = holderBP2.getAllocInfo(); + for(int i = 0; i < holderFF.getNumInArrs(); ++i) { + ASSERT_TRUE(static_cast(isArrAllocBP2[i]) == true); + } - ASSERT_TRUE(static_cast(isArrAllocBP2[holderFF.getNumInArrs()+1]) == false); + ASSERT_TRUE(static_cast(isArrAllocBP2[holderFF.getNumInArrs()+1]) == false); } ////////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, OpArgsHolder_test3) { - auto input = NDArrayFactory::create('c', {2, 3}, {1.,2.,3.,4.,5.,6.}); - auto gradO = NDArrayFactory::create('c', {4, 9}); - auto exp = NDArrayFactory::create('c', {4, 9}, {1, 2, 3, 1, 2, 3, 1, 2, 3,4, 5, 6, 4, 5, 6, 4, 5, 6,1, 2, 3, 1, 2, 3, 1, 2, 3,4, 5, 6, 4, 5, 6, 4, 5, 6}); - auto gradIExp = NDArrayFactory::create('c', {2, 3}, {0.78, 0.84, 0.9,1.32, 1.38, 1.44}); + auto input = NDArrayFactory::create('c', {2, 3}, {1.,2.,3.,4.,5.,6.}); + auto gradO = NDArrayFactory::create('c', {4, 9}); + auto exp = NDArrayFactory::create('c', {4, 9}, {1, 2, 3, 1, 2, 3, 1, 2, 3,4, 5, 6, 4, 5, 6, 4, 5, 6,1, 2, 3, 1, 2, 3, 1, 2, 3,4, 5, 6, 4, 5, 6, 4, 5, 6}); + auto gradIExp = NDArrayFactory::create('c', {2, 3}, {0.78, 0.84, 0.9,1.32, 1.38, 1.44}); - gradO.linspace(0.01, 0.01); + gradO.linspace(0.01, 0.01); - OpArgsHolder holderFF({&input}, {}, {2, 3}); - sd::ops::tile opFF; // the kind of op doesn't matter, we simply check here whether op.execute() works with OpArgsHolder correctly - auto results = opFF.execute(holderFF); - auto tiled = results.at(0); - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(tiled)); - ASSERT_TRUE(exp.equalsTo(tiled)); + OpArgsHolder holderFF({&input}, {}, {2, 3}); + sd::ops::tile opFF; // the kind of op doesn't matter, we simply check here whether op.execute() works with OpArgsHolder correctly + auto results = opFF.execute(holderFF); + auto tiled = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(tiled)); + ASSERT_TRUE(exp.equalsTo(tiled)); - OpArgsHolder holderBP = holderFF.createArgsHolderForBP({&gradO}, true); - sd::ops::tile_bp opBP; - results = opBP.execute(holderBP); - auto gradI = results.at(0); - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(gradIExp.isSameShape(gradI)); - ASSERT_TRUE(gradIExp.equalsTo(gradI)); + OpArgsHolder holderBP = holderFF.createArgsHolderForBP({&gradO}, true); + sd::ops::tile_bp opBP; + results = opBP.execute(holderBP); + auto gradI = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(gradIExp.isSameShape(gradI)); + ASSERT_TRUE(gradIExp.equalsTo(gradI)); } @@ -1755,590 +1755,591 @@ TEST_F(HelpersTests1, OpArgsHolder_test3) { ////////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, checkGrad_test1) { - auto x = NDArrayFactory::create('c', {2, 3}, {0.1, 0.2, 0.3, 0.4, 0.5 ,0.6}); - auto gradO = NDArrayFactory::create('c', {2, 3}); + auto x = NDArrayFactory::create('c', {2, 3}, {0.1, 0.2, 0.3, 0.4, 0.5 ,0.6}); + auto gradO = NDArrayFactory::create('c', {2, 3}); - const OpArgsHolder argsHolderFF({&x}, {}, {}); - const OpArgsHolder argsHolderBP({&x, &gradO}, {}, {}); + const OpArgsHolder argsHolderFF({&x}, {}, {}); + const OpArgsHolder argsHolderBP({&x, &gradO}, {}, {}); - sd::ops::sigmoid opFF; - sd::ops::sigmoid_bp opBP; + sd::ops::sigmoid opFF; + sd::ops::sigmoid_bp opBP; - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); + const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); - ASSERT_TRUE(isGradCorrect); + ASSERT_TRUE(isGradCorrect); } ////////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, checkGrad_test2) { - auto x = NDArrayFactory::create('c', {1, 1, 3, 3}); - auto weights = NDArrayFactory::create('c', {2, 1, 2, 2}); - auto gradO = NDArrayFactory::create('c', {1, 2, 3, 3}); + auto x = NDArrayFactory::create('c', {1, 1, 3, 3}); + auto weights = NDArrayFactory::create('c', {2, 1, 2, 2}); + auto gradO = NDArrayFactory::create('c', {1, 2, 3, 3}); - x.linspace(1); - weights.linspace(0.1, 0.1); - weights.permutei({2,3,1,0}); + x.linspace(1); + weights.linspace(0.1, 0.1); + weights.permutei({2,3,1,0}); - const OpArgsHolder argsHolderFF({&x, &weights}, {}, {2, 2, 1, 1, 0, 0, 1, 1, 1}); - const OpArgsHolder argsHolderBP({&x, &weights, &gradO}, {}, {2, 2, 1, 1, 0, 0, 1, 1, 1}); + const OpArgsHolder argsHolderFF({&x, &weights}, {}, {2, 2, 1, 1, 0, 0, 1, 1, 1}); + const OpArgsHolder argsHolderBP({&x, &weights, &gradO}, {}, {2, 2, 1, 1, 0, 0, 1, 1, 1}); - sd::ops::conv2d opFF; - sd::ops::conv2d_bp opBP; + sd::ops::conv2d opFF; + sd::ops::conv2d_bp opBP; - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); + const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); - ASSERT_TRUE(isGradCorrect); + ASSERT_TRUE(isGradCorrect); } ////////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, checkGrad_test3) { - auto x = NDArrayFactory::create('c', {1, 1, 3, 3}); - auto weights = NDArrayFactory::create('c', {2, 1, 2, 2}); - auto bias = NDArrayFactory::create('c', {2, 1}); - auto gradO = NDArrayFactory::create('c', {1, 2, 3, 3}); + auto x = NDArrayFactory::create('c', {1, 1, 3, 3}); + auto weights = NDArrayFactory::create('c', {2, 1, 2, 2}); + auto bias = NDArrayFactory::create('c', {2, 1}); + auto gradO = NDArrayFactory::create('c', {1, 2, 3, 3}); - x.linspace(1); - weights.linspace(0.1, 0.1); - bias = 0.5; - weights.permutei({2,3,1,0}); + x.linspace(1); + weights.linspace(0.1, 0.1); + bias = 0.5; + weights.permutei({2,3,1,0}); - const OpArgsHolder argsHolderFF({&x, &weights, &bias}, {}, {2, 2, 1, 1, 0, 0, 1, 1, 1}); - const OpArgsHolder argsHolderBP({&x, &weights, &bias, &gradO}, {}, {2, 2, 1, 1, 0, 0, 1, 1, 1}); + const OpArgsHolder argsHolderFF({&x, &weights, &bias}, {}, {2, 2, 1, 1, 0, 0, 1, 1, 1}); + const OpArgsHolder argsHolderBP({&x, &weights, &bias, &gradO}, {}, {2, 2, 1, 1, 0, 0, 1, 1, 1}); - sd::ops::conv2d opFF; - sd::ops::conv2d_bp opBP; + sd::ops::conv2d opFF; + sd::ops::conv2d_bp opBP; - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); + const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); - ASSERT_TRUE(isGradCorrect); + ASSERT_TRUE(isGradCorrect); } ////////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, checkGrad_test4) { - auto x = NDArrayFactory::create('c', {1, 1, 3, 3}); - auto weights = NDArrayFactory::create('c', {2, 1, 2, 2}); - auto bias = NDArrayFactory::create('c', {2, 1}); - auto gradO = NDArrayFactory::create('c', {1, 2, 3, 3}); + auto x = NDArrayFactory::create('c', {1, 1, 3, 3}); + auto weights = NDArrayFactory::create('c', {2, 1, 2, 2}); + auto bias = NDArrayFactory::create('c', {2, 1}); + auto gradO = NDArrayFactory::create('c', {1, 2, 3, 3}); - x.linspace(1); - weights.linspace(0.1, 0.1); - bias = 0.5; - weights.permutei({2,3,1,0}); + x.linspace(1); + weights.linspace(0.1, 0.1); + bias = 0.5; + weights.permutei({2,3,1,0}); - const OpArgsHolder argsHolderFF({&x, &weights, &bias}, {}, {2, 2, 1, 1, 0, 0, 1, 1, 1}); - const OpArgsHolder argsHolderBP({&x, &weights, &bias, &gradO}, {}, {2, 2, 1, 1, 0, 0, 1, 1, 1}); + const OpArgsHolder argsHolderFF({&x, &weights, &bias}, {}, {2, 2, 1, 1, 0, 0, 1, 1, 1}); + const OpArgsHolder argsHolderBP({&x, &weights, &bias, &gradO}, {}, {2, 2, 1, 1, 0, 0, 1, 1, 1}); - sd::ops::conv2d opFF; - sd::ops::conv2d_bp opBP; + sd::ops::conv2d opFF; + sd::ops::conv2d_bp opBP; - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {1, 0, 1}); + const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {1, 0, 1}); - ASSERT_TRUE(isGradCorrect); + ASSERT_TRUE(isGradCorrect); } ////////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, checkGrad_test5) { - auto x = NDArrayFactory::create('c', {1, 1, 3, 3}); - auto weights = NDArrayFactory::create('c', {2, 1, 2, 2}); - auto bias = NDArrayFactory::create('c', {2, 1}); - auto gradO = NDArrayFactory::create('c', {1, 2, 3, 3}); + auto x = NDArrayFactory::create('c', {1, 1, 3, 3}); + auto weights = NDArrayFactory::create('c', {2, 1, 2, 2}); + auto bias = NDArrayFactory::create('c', {2, 1}); + auto gradO = NDArrayFactory::create('c', {1, 2, 3, 3}); - x.linspace(1); - weights.linspace(0.1, 0.1); - bias = 0.5; - weights.permutei({2,3,1,0}); + x.linspace(1); + weights.linspace(0.1, 0.1); + bias = 0.5; + weights.permutei({2,3,1,0}); - const OpArgsHolder argsHolderFF({&x, &weights, &bias}, {}, {2, 2, 1, 1, 0, 0, 1, 1, 1}); - const OpArgsHolder argsHolderBP({&x, &weights, &bias, &gradO}, {}, {2, 2, 1, 1, 0, 0, 1, 1, 1}); + const OpArgsHolder argsHolderFF({&x, &weights, &bias}, {}, {2, 2, 1, 1, 0, 0, 1, 1, 1}); + const OpArgsHolder argsHolderBP({&x, &weights, &bias, &gradO}, {}, {2, 2, 1, 1, 0, 0, 1, 1, 1}); - sd::ops::conv2d opFF; - sd::ops::conv2d_bp opBP; + sd::ops::conv2d opFF; + sd::ops::conv2d_bp opBP; - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {1, 1, 1}, {0.5, 1}); + const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {1, 1, 1}, {0.5, 1}); - ASSERT_TRUE(isGradCorrect); + ASSERT_TRUE(isGradCorrect); } ////////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, checkGrad_test6) { - auto x = NDArrayFactory::create('c', {1, 1, 3, 3}); - auto weights = NDArrayFactory::create('c', {2, 1, 2, 2}); - auto bias = NDArrayFactory::create('c', {2, 1}); - auto gradO = NDArrayFactory::create('c', {1, 2, 3, 3}); + auto x = NDArrayFactory::create('c', {1, 1, 3, 3}); + auto weights = NDArrayFactory::create('c', {2, 1, 2, 2}); + auto bias = NDArrayFactory::create('c', {2, 1}); + auto gradO = NDArrayFactory::create('c', {1, 2, 3, 3}); - x.linspace(1); - weights.linspace(0.1, 0.1); - bias = 0.5; - weights.permutei({2,3,1,0}); + x.linspace(1); + weights.linspace(0.1, 0.1); + bias = 0.5; + weights.permutei({2,3,1,0}); - const OpArgsHolder argsHolderFF({&x, &weights, &bias}, {}, {2, 2, 1, 1, 0, 0, 1, 1, 1}); - const OpArgsHolder argsHolderBP({&x, &weights, &bias, &gradO}, {}, {2, 2, 1, 1, 0, 0, 1, 1, 1}); + const OpArgsHolder argsHolderFF({&x, &weights, &bias}, {}, {2, 2, 1, 1, 0, 0, 1, 1, 1}); + const OpArgsHolder argsHolderBP({&x, &weights, &bias, &gradO}, {}, {2, 2, 1, 1, 0, 0, 1, 1, 1}); - sd::ops::conv2d opFF; - sd::ops::conv2d_bp opBP; + sd::ops::conv2d opFF; + sd::ops::conv2d_bp opBP; - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {1, 0, 1}, {0.5, 1}, GradCheck::MEAN); + const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {1, 0, 1}, {0.5, 1}, GradCheck::MEAN); - ASSERT_TRUE(isGradCorrect); + ASSERT_TRUE(isGradCorrect); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, softMaxForVector_test1) { - auto input = NDArrayFactory::create('c', {1,5}, {1,2,3,4,5}); - auto output = NDArrayFactory::create('c', {1,5}); - auto expOutput = NDArrayFactory::create('c', {1,5}); - expOutput = 1; + auto input = NDArrayFactory::create('c', {1,5}, {1,2,3,4,5}); + auto output = NDArrayFactory::create('c', {1,5}); + auto expOutput = NDArrayFactory::create('c', {1,5}); + expOutput = 1; - ops::helpers::softmax(sd::LaunchContext ::defaultContext(), input, output, 0); + ops::helpers::softmax(sd::LaunchContext ::defaultContext(), input, output, 0); - ASSERT_TRUE(output.equalsTo(&expOutput)); + ASSERT_TRUE(output.equalsTo(&expOutput)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, softMaxForVector_test2) { - auto input = NDArrayFactory::create('c', {5,1}, {1,2,3,4,5}); - auto output = NDArrayFactory::create('c', {5,1}); - auto expOutput = NDArrayFactory::create('c', {5,1}, {0.01165623, 0.03168492, 0.08612854, 0.23412166, 0.63640865}); + auto input = NDArrayFactory::create('c', {5,1}, {1,2,3,4,5}); + auto output = NDArrayFactory::create('c', {5,1}); + auto expOutput = NDArrayFactory::create('c', {5,1}, {0.01165623, 0.03168492, 0.08612854, 0.23412166, 0.63640865}); - ops::helpers::softmax(sd::LaunchContext ::defaultContext(), input, output, 0); + ops::helpers::softmax(sd::LaunchContext ::defaultContext(), input, output, 0); - ASSERT_TRUE(output.equalsTo(&expOutput)); + ASSERT_TRUE(output.equalsTo(&expOutput)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, softMaxForVector_test3) { - auto input= NDArrayFactory::create('c', {5}, {1,2,3,4,5}); - auto output = NDArrayFactory::create('c', {5}); - auto expOutput = NDArrayFactory::create('c', {5}, {0.01165623, 0.03168492, 0.08612854, 0.23412166, 0.63640865}); + auto input= NDArrayFactory::create('c', {5}, {1,2,3,4,5}); + auto output = NDArrayFactory::create('c', {5}); + auto expOutput = NDArrayFactory::create('c', {5}, {0.01165623, 0.03168492, 0.08612854, 0.23412166, 0.63640865}); - ops::helpers::softmax(sd::LaunchContext ::defaultContext(), input, output, 0); + ops::helpers::softmax(sd::LaunchContext ::defaultContext(), input, output, 0); - ASSERT_TRUE(output.equalsTo(&expOutput)); + ASSERT_TRUE(output.equalsTo(&expOutput)); } ////////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, softMaxForVector_test4) { - NDArray input('c', {1500}, sd::DataType::DOUBLE); - NDArray output('c', {1500}, sd::DataType::DOUBLE); - NDArray expOutput('c', {1500}, {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.00001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, -0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001,0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, -0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002,0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, -0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003,0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000004, 0.000004, 0.000004, 0.000004, 0.000004, 0.000004, 0.000004, 0.000004, 0.000004, 0.000004, 0.000004, 0.000004, 0.000004, 0.000004, 0.000004, -0.000004, 0.000004, 0.000004, 0.000004, 0.000004, 0.000004, 0.000004, 0.000004, 0.000004, 0.000004, 0.000005, 0.000005, 0.000005, 0.000005, 0.000005, 0.000005, 0.000005, 0.000005, 0.000005, 0.000005, 0.000005, 0.000005, 0.000005, 0.000005, 0.000005, 0.000005, 0.000005, 0.000005,0.000005, 0.000005, 0.000006, 0.000006, 0.000006, 0.000006, 0.000006, 0.000006, 0.000006, 0.000006, 0.000006, 0.000006, 0.000006, 0.000006, 0.000006, 0.000006, 0.000006, 0.000006, 0.000006, 0.000007, 0.000007, 0.000007, 0.000007, 0.000007, 0.000007, 0.000007, 0.000007, 0.000007, -0.000007, 0.000007, 0.000007, 0.000007, 0.000007, 0.000008, 0.000008, 0.000008, 0.000008, 0.000008, 0.000008, 0.000008, 0.000008, 0.000008, 0.000008, 0.000008, 0.000008, 0.000008, 0.000009, 0.000009, 0.000009, 0.000009, 0.000009, 0.000009, 0.000009, 0.000009, 0.000009, 0.000009,0.000009, 0.000010, 0.000010, 0.000010, 0.000010, 0.000010, 0.000010, 0.000010, 0.000010, 0.000010, 0.000010, 0.000011, 0.000011, 0.000011, 0.000011, 0.000011, 0.000011, 0.000011, 0.000011, 0.000011, 0.000012, 0.000012, 0.000012, 0.000012, 0.000012, 0.000012, 0.000012, 0.000012, -0.000012, 0.000013, 0.000013, 0.000013, 0.000013, 0.000013, 0.000013, 0.000013, 0.000014, 0.000014, 0.000014, 0.000014, 0.000014, 0.000014, 0.000014, 0.000015, 0.000015, 0.000015, 0.000015, 0.000015, 0.000015, 0.000015, 0.000016, 0.000016, 0.000016, 0.000016, 0.000016, 0.000016,0.000017, 0.000017, 0.000017, 0.000017, 0.000017, 0.000017, 0.000018, 0.000018, 0.000018, 0.000018, 0.000018, 0.000018, 0.000019, 0.000019, 0.000019, 0.000019, 0.000019, 0.000020, 0.000020, 0.000020, 0.000020, 0.000020, 0.000021, 0.000021, 0.000021, 0.000021, 0.000021, 0.000022, -0.000022, 0.000022, 0.000022, 0.000023, 0.000023, 0.000023, 0.000023, 0.000023, 0.000024, 0.000024, 0.000024, 0.000024, 0.000025, 0.000025, 0.000025, 0.000025, 0.000026, 0.000026, 0.000026, 0.000026, 0.000027, 0.000027, 0.000027, 0.000028, 0.000028, 0.000028, 0.000028, 0.000029,0.000029, 0.000029, 0.000030, 0.000030, 0.000030, 0.000030, 0.000031, 0.000031, 0.000031, 0.000032, 0.000032, 0.000032, 0.000033, 0.000033, 0.000033, 0.000034, 0.000034, 0.000034, 0.000035, 0.000035, 0.000035, 0.000036, 0.000036, 0.000036, 0.000037, 0.000037, 0.000038, 0.000038, -0.000038, 0.000039, 0.000039, 0.000039, 0.000040, 0.000040, 0.000041, 0.000041, 0.000041, 0.000042, 0.000042, 0.000043, 0.000043, 0.000044, 0.000044, 0.000044, 0.000045, 0.000045, 0.000046, 0.000046, 0.000047, 0.000047, 0.000048, 0.000048, 0.000049, 0.000049, 0.000050, 0.000050,0.000051, 0.000051, 0.000052, 0.000052, 0.000053, 0.000053, 0.000054, 0.000054, 0.000055, 0.000055, 0.000056, 0.000057, 0.000057, 0.000058, 0.000058, 0.000059, 0.000059, 0.000060, 0.000061, 0.000061, 0.000062, 0.000063, 0.000063, 0.000064, 0.000064, 0.000065, 0.000066, 0.000066, -0.000067, 0.000068, 0.000068, 0.000069, 0.000070, 0.000070, 0.000071, 0.000072, 0.000073, 0.000073, 0.000074, 0.000075, 0.000076, 0.000076, 0.000077, 0.000078, 0.000079, 0.000079, 0.000080, 0.000081, 0.000082, 0.000083, 0.000084, 0.000084, 0.000085, 0.000086, 0.000087, 0.000088,0.000089, 0.000090, 0.000090, 0.000091, 0.000092, 0.000093, 0.000094, 0.000095, 0.000096, 0.000097, 0.000098, 0.000099, 0.000100, 0.000101, 0.000102, 0.000103, 0.000104, 0.000105, 0.000106, 0.000107, 0.000108, 0.000109, 0.000111, 0.000112, 0.000113, 0.000114, 0.000115, 0.000116, -0.000117, 0.000119, 0.000120, 0.000121, 0.000122, 0.000123, 0.000125, 0.000126, 0.000127, 0.000128, 0.000130, 0.000131, 0.000132, 0.000134, 0.000135, 0.000136, 0.000138, 0.000139, 0.000141, 0.000142, 0.000143, 0.000145, 0.000146, 0.000148, 0.000149, 0.000151, 0.000152, 0.000154,0.000155, 0.000157, 0.000158, 0.000160, 0.000162, 0.000163, 0.000165, 0.000167, 0.000168, 0.000170, 0.000172, 0.000173, 0.000175, 0.000177, 0.000179, 0.000180, 0.000182, 0.000184, 0.000186, 0.000188, 0.000190, 0.000192, 0.000194, 0.000195, 0.000197, 0.000199, 0.000201, 0.000203, -0.000205, 0.000208, 0.000210, 0.000212, 0.000214, 0.000216, 0.000218, 0.000220, 0.000223, 0.000225, 0.000227, 0.000229, 0.000232, 0.000234, 0.000236, 0.000239, 0.000241, 0.000244, 0.000246, 0.000248, 0.000251, 0.000253, 0.000256, 0.000259, 0.000261, 0.000264, 0.000266, 0.000269,0.000272, 0.000275, 0.000277, 0.000280, 0.000283, 0.000286, 0.000289, 0.000292, 0.000295, 0.000297, 0.000300, 0.000303, 0.000307, 0.000310, 0.000313, 0.000316, 0.000319, 0.000322, 0.000325, 0.000329, 0.000332, 0.000335, 0.000339, 0.000342, 0.000346, 0.000349, 0.000353, 0.000356, -0.000360, 0.000363, 0.000367, 0.000371, 0.000374, 0.000378, 0.000382, 0.000386, 0.000390, 0.000394, 0.000398, 0.000402, 0.000406, 0.000410, 0.000414, 0.000418, 0.000422, 0.000426, 0.000431, 0.000435, 0.000439, 0.000444, 0.000448, 0.000453, 0.000457, 0.000462, 0.000467, 0.000471,0.000476, 0.000481, 0.000486, 0.000490, 0.000495, 0.000500, 0.000505, 0.000510, 0.000516, 0.000521, 0.000526, 0.000531, 0.000537, 0.000542, 0.000547, 0.000553, 0.000559, 0.000564, 0.000570, 0.000576, 0.000581, 0.000587, 0.000593, 0.000599, 0.000605, 0.000611, 0.000617, 0.000623, -0.000630, 0.000636, 0.000642, 0.000649, 0.000655, 0.000662, 0.000669, 0.000675, 0.000682, 0.000689, 0.000696, 0.000703, 0.000710, 0.000717, 0.000724, 0.000732, 0.000739, 0.000746, 0.000754, 0.000762, 0.000769, 0.000777, 0.000785, 0.000793, 0.000801, 0.000809, 0.000817, 0.000825,0.000833, 0.000842, 0.000850, 0.000859, 0.000867, 0.000876, 0.000885, 0.000894, 0.000903, 0.000912, 0.000921, 0.000930, 0.000939, 0.000949, 0.000958, 0.000968, 0.000978, 0.000988, 0.000998, 0.001008, 0.001018, 0.001028, 0.001038, 0.001049, 0.001059, 0.001070, 0.001081, 0.001092, -0.001103, 0.001114, 0.001125, 0.001136, 0.001148, 0.001159, 0.001171, 0.001182, 0.001194, 0.001206, 0.001218, 0.001231, 0.001243, 0.001256, 0.001268, 0.001281, 0.001294, 0.001307, 0.001320, 0.001333, 0.001347, 0.001360, 0.001374, 0.001388, 0.001402, 0.001416, 0.001430, 0.001444,0.001459, 0.001473, 0.001488, 0.001503, 0.001518, 0.001534, 0.001549, 0.001565, 0.001580, 0.001596, 0.001612, 0.001628, 0.001645, 0.001661, 0.001678, 0.001695, 0.001712, 0.001729, 0.001746, 0.001764, 0.001782, 0.001800, 0.001818, 0.001836, 0.001854, 0.001873, 0.001892, 0.001911, -0.001930, 0.001950, 0.001969, 0.001989, 0.002009, 0.002029, 0.002049, 0.002070, 0.002091, 0.002112, 0.002133, 0.002155, 0.002176, 0.002198, 0.002220, 0.002242, 0.002265, 0.002288, 0.002311, 0.002334, 0.002357, 0.002381, 0.002405, 0.002429, 0.002454, 0.002478, 0.002503, 0.002528,0.002554, 0.002579, 0.002605, 0.002632, 0.002658, 0.002685, 0.002712, 0.002739, 0.002767, 0.002794, 0.002822, 0.002851, 0.002879, 0.002908, 0.002938, 0.002967, 0.002997, 0.003027, 0.003057, 0.003088, 0.003119, 0.003151, 0.003182, 0.003214, 0.003247, 0.003279, 0.003312, 0.003345, -0.003379, 0.003413, 0.003447, 0.003482, 0.003517, 0.003552, 0.003588, 0.003624, 0.003660, 0.003697, 0.003734, 0.003772, 0.003810, 0.003848, 0.003887, 0.003926, 0.003965, 0.004005, 0.004045, 0.004086, 0.004127, 0.004169, 0.004211, 0.004253, 0.004296, 0.004339, 0.004382, 0.004426,0.004471, 0.004516, 0.004561, 0.004607, 0.004653, 0.004700, 0.004747, 0.004795, 0.004843, 0.004892, 0.004941, 0.004991, 0.005041, 0.005092, 0.005143, 0.005194, 0.005247, 0.005299, 0.005353, 0.005406, 0.005461, 0.005516, 0.005571, 0.005627, 0.005684, 0.005741, 0.005798, 0.005857, -0.005916, 0.005975, 0.006035, 0.006096, 0.006157, 0.006219, 0.006281, 0.006345, 0.006408, 0.006473, 0.006538, 0.006603, 0.006670, 0.006737, 0.006805, 0.006873, 0.006942, 0.007012, 0.007082, 0.007153, 0.007225, 0.007298, 0.007371, 0.007445, 0.007520, 0.007596, 0.007672, 0.007749,0.007827, 0.007906, 0.007985, 0.008065, 0.008147, 0.008228, 0.008311, 0.008395, 0.008479, 0.008564, 0.008650, 0.008737, 0.008825, 0.008914, 0.009003, 0.009094, 0.009185, 0.009277, 0.009371, 0.009465, 0.009560, 0.009656, 0.009753, 0.009851, 0.009950}, sd::DataType::DOUBLE); - input.linspace(0.01, 0.01); - - ops::helpers::softmax(sd::LaunchContext ::defaultContext(), input, output, 0); - - ASSERT_TRUE(output.equalsTo(&expOutput)); + NDArray input('c', {1500}, sd::DataType::DOUBLE); + NDArray output('c', {1500}, sd::DataType::DOUBLE); + NDArray expOutput('c', {1500}, {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.00001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, + 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001,0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, + 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002,0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, + 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003,0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000004, 0.000004, 0.000004, 0.000004, 0.000004, 0.000004, 0.000004, 0.000004, 0.000004, 0.000004, 0.000004, 0.000004, 0.000004, 0.000004, 0.000004, + 0.000004, 0.000004, 0.000004, 0.000004, 0.000004, 0.000004, 0.000004, 0.000004, 0.000004, 0.000004, 0.000005, 0.000005, 0.000005, 0.000005, 0.000005, 0.000005, 0.000005, 0.000005, 0.000005, 0.000005, 0.000005, 0.000005, 0.000005, 0.000005, 0.000005, 0.000005, 0.000005, 0.000005,0.000005, 0.000005, 0.000006, 0.000006, 0.000006, 0.000006, 0.000006, 0.000006, 0.000006, 0.000006, 0.000006, 0.000006, 0.000006, 0.000006, 0.000006, 0.000006, 0.000006, 0.000006, 0.000006, 0.000007, 0.000007, 0.000007, 0.000007, 0.000007, 0.000007, 0.000007, 0.000007, 0.000007, + 0.000007, 0.000007, 0.000007, 0.000007, 0.000007, 0.000008, 0.000008, 0.000008, 0.000008, 0.000008, 0.000008, 0.000008, 0.000008, 0.000008, 0.000008, 0.000008, 0.000008, 0.000008, 0.000009, 0.000009, 0.000009, 0.000009, 0.000009, 0.000009, 0.000009, 0.000009, 0.000009, 0.000009,0.000009, 0.000010, 0.000010, 0.000010, 0.000010, 0.000010, 0.000010, 0.000010, 0.000010, 0.000010, 0.000010, 0.000011, 0.000011, 0.000011, 0.000011, 0.000011, 0.000011, 0.000011, 0.000011, 0.000011, 0.000012, 0.000012, 0.000012, 0.000012, 0.000012, 0.000012, 0.000012, 0.000012, + 0.000012, 0.000013, 0.000013, 0.000013, 0.000013, 0.000013, 0.000013, 0.000013, 0.000014, 0.000014, 0.000014, 0.000014, 0.000014, 0.000014, 0.000014, 0.000015, 0.000015, 0.000015, 0.000015, 0.000015, 0.000015, 0.000015, 0.000016, 0.000016, 0.000016, 0.000016, 0.000016, 0.000016,0.000017, 0.000017, 0.000017, 0.000017, 0.000017, 0.000017, 0.000018, 0.000018, 0.000018, 0.000018, 0.000018, 0.000018, 0.000019, 0.000019, 0.000019, 0.000019, 0.000019, 0.000020, 0.000020, 0.000020, 0.000020, 0.000020, 0.000021, 0.000021, 0.000021, 0.000021, 0.000021, 0.000022, + 0.000022, 0.000022, 0.000022, 0.000023, 0.000023, 0.000023, 0.000023, 0.000023, 0.000024, 0.000024, 0.000024, 0.000024, 0.000025, 0.000025, 0.000025, 0.000025, 0.000026, 0.000026, 0.000026, 0.000026, 0.000027, 0.000027, 0.000027, 0.000028, 0.000028, 0.000028, 0.000028, 0.000029,0.000029, 0.000029, 0.000030, 0.000030, 0.000030, 0.000030, 0.000031, 0.000031, 0.000031, 0.000032, 0.000032, 0.000032, 0.000033, 0.000033, 0.000033, 0.000034, 0.000034, 0.000034, 0.000035, 0.000035, 0.000035, 0.000036, 0.000036, 0.000036, 0.000037, 0.000037, 0.000038, 0.000038, + 0.000038, 0.000039, 0.000039, 0.000039, 0.000040, 0.000040, 0.000041, 0.000041, 0.000041, 0.000042, 0.000042, 0.000043, 0.000043, 0.000044, 0.000044, 0.000044, 0.000045, 0.000045, 0.000046, 0.000046, 0.000047, 0.000047, 0.000048, 0.000048, 0.000049, 0.000049, 0.000050, 0.000050,0.000051, 0.000051, 0.000052, 0.000052, 0.000053, 0.000053, 0.000054, 0.000054, 0.000055, 0.000055, 0.000056, 0.000057, 0.000057, 0.000058, 0.000058, 0.000059, 0.000059, 0.000060, 0.000061, 0.000061, 0.000062, 0.000063, 0.000063, 0.000064, 0.000064, 0.000065, 0.000066, 0.000066, + 0.000067, 0.000068, 0.000068, 0.000069, 0.000070, 0.000070, 0.000071, 0.000072, 0.000073, 0.000073, 0.000074, 0.000075, 0.000076, 0.000076, 0.000077, 0.000078, 0.000079, 0.000079, 0.000080, 0.000081, 0.000082, 0.000083, 0.000084, 0.000084, 0.000085, 0.000086, 0.000087, 0.000088,0.000089, 0.000090, 0.000090, 0.000091, 0.000092, 0.000093, 0.000094, 0.000095, 0.000096, 0.000097, 0.000098, 0.000099, 0.000100, 0.000101, 0.000102, 0.000103, 0.000104, 0.000105, 0.000106, 0.000107, 0.000108, 0.000109, 0.000111, 0.000112, 0.000113, 0.000114, 0.000115, 0.000116, + 0.000117, 0.000119, 0.000120, 0.000121, 0.000122, 0.000123, 0.000125, 0.000126, 0.000127, 0.000128, 0.000130, 0.000131, 0.000132, 0.000134, 0.000135, 0.000136, 0.000138, 0.000139, 0.000141, 0.000142, 0.000143, 0.000145, 0.000146, 0.000148, 0.000149, 0.000151, 0.000152, 0.000154,0.000155, 0.000157, 0.000158, 0.000160, 0.000162, 0.000163, 0.000165, 0.000167, 0.000168, 0.000170, 0.000172, 0.000173, 0.000175, 0.000177, 0.000179, 0.000180, 0.000182, 0.000184, 0.000186, 0.000188, 0.000190, 0.000192, 0.000194, 0.000195, 0.000197, 0.000199, 0.000201, 0.000203, + 0.000205, 0.000208, 0.000210, 0.000212, 0.000214, 0.000216, 0.000218, 0.000220, 0.000223, 0.000225, 0.000227, 0.000229, 0.000232, 0.000234, 0.000236, 0.000239, 0.000241, 0.000244, 0.000246, 0.000248, 0.000251, 0.000253, 0.000256, 0.000259, 0.000261, 0.000264, 0.000266, 0.000269,0.000272, 0.000275, 0.000277, 0.000280, 0.000283, 0.000286, 0.000289, 0.000292, 0.000295, 0.000297, 0.000300, 0.000303, 0.000307, 0.000310, 0.000313, 0.000316, 0.000319, 0.000322, 0.000325, 0.000329, 0.000332, 0.000335, 0.000339, 0.000342, 0.000346, 0.000349, 0.000353, 0.000356, + 0.000360, 0.000363, 0.000367, 0.000371, 0.000374, 0.000378, 0.000382, 0.000386, 0.000390, 0.000394, 0.000398, 0.000402, 0.000406, 0.000410, 0.000414, 0.000418, 0.000422, 0.000426, 0.000431, 0.000435, 0.000439, 0.000444, 0.000448, 0.000453, 0.000457, 0.000462, 0.000467, 0.000471,0.000476, 0.000481, 0.000486, 0.000490, 0.000495, 0.000500, 0.000505, 0.000510, 0.000516, 0.000521, 0.000526, 0.000531, 0.000537, 0.000542, 0.000547, 0.000553, 0.000559, 0.000564, 0.000570, 0.000576, 0.000581, 0.000587, 0.000593, 0.000599, 0.000605, 0.000611, 0.000617, 0.000623, + 0.000630, 0.000636, 0.000642, 0.000649, 0.000655, 0.000662, 0.000669, 0.000675, 0.000682, 0.000689, 0.000696, 0.000703, 0.000710, 0.000717, 0.000724, 0.000732, 0.000739, 0.000746, 0.000754, 0.000762, 0.000769, 0.000777, 0.000785, 0.000793, 0.000801, 0.000809, 0.000817, 0.000825,0.000833, 0.000842, 0.000850, 0.000859, 0.000867, 0.000876, 0.000885, 0.000894, 0.000903, 0.000912, 0.000921, 0.000930, 0.000939, 0.000949, 0.000958, 0.000968, 0.000978, 0.000988, 0.000998, 0.001008, 0.001018, 0.001028, 0.001038, 0.001049, 0.001059, 0.001070, 0.001081, 0.001092, + 0.001103, 0.001114, 0.001125, 0.001136, 0.001148, 0.001159, 0.001171, 0.001182, 0.001194, 0.001206, 0.001218, 0.001231, 0.001243, 0.001256, 0.001268, 0.001281, 0.001294, 0.001307, 0.001320, 0.001333, 0.001347, 0.001360, 0.001374, 0.001388, 0.001402, 0.001416, 0.001430, 0.001444,0.001459, 0.001473, 0.001488, 0.001503, 0.001518, 0.001534, 0.001549, 0.001565, 0.001580, 0.001596, 0.001612, 0.001628, 0.001645, 0.001661, 0.001678, 0.001695, 0.001712, 0.001729, 0.001746, 0.001764, 0.001782, 0.001800, 0.001818, 0.001836, 0.001854, 0.001873, 0.001892, 0.001911, + 0.001930, 0.001950, 0.001969, 0.001989, 0.002009, 0.002029, 0.002049, 0.002070, 0.002091, 0.002112, 0.002133, 0.002155, 0.002176, 0.002198, 0.002220, 0.002242, 0.002265, 0.002288, 0.002311, 0.002334, 0.002357, 0.002381, 0.002405, 0.002429, 0.002454, 0.002478, 0.002503, 0.002528,0.002554, 0.002579, 0.002605, 0.002632, 0.002658, 0.002685, 0.002712, 0.002739, 0.002767, 0.002794, 0.002822, 0.002851, 0.002879, 0.002908, 0.002938, 0.002967, 0.002997, 0.003027, 0.003057, 0.003088, 0.003119, 0.003151, 0.003182, 0.003214, 0.003247, 0.003279, 0.003312, 0.003345, + 0.003379, 0.003413, 0.003447, 0.003482, 0.003517, 0.003552, 0.003588, 0.003624, 0.003660, 0.003697, 0.003734, 0.003772, 0.003810, 0.003848, 0.003887, 0.003926, 0.003965, 0.004005, 0.004045, 0.004086, 0.004127, 0.004169, 0.004211, 0.004253, 0.004296, 0.004339, 0.004382, 0.004426,0.004471, 0.004516, 0.004561, 0.004607, 0.004653, 0.004700, 0.004747, 0.004795, 0.004843, 0.004892, 0.004941, 0.004991, 0.005041, 0.005092, 0.005143, 0.005194, 0.005247, 0.005299, 0.005353, 0.005406, 0.005461, 0.005516, 0.005571, 0.005627, 0.005684, 0.005741, 0.005798, 0.005857, + 0.005916, 0.005975, 0.006035, 0.006096, 0.006157, 0.006219, 0.006281, 0.006345, 0.006408, 0.006473, 0.006538, 0.006603, 0.006670, 0.006737, 0.006805, 0.006873, 0.006942, 0.007012, 0.007082, 0.007153, 0.007225, 0.007298, 0.007371, 0.007445, 0.007520, 0.007596, 0.007672, 0.007749,0.007827, 0.007906, 0.007985, 0.008065, 0.008147, 0.008228, 0.008311, 0.008395, 0.008479, 0.008564, 0.008650, 0.008737, 0.008825, 0.008914, 0.009003, 0.009094, 0.009185, 0.009277, 0.009371, 0.009465, 0.009560, 0.009656, 0.009753, 0.009851, 0.009950}, sd::DataType::DOUBLE); + input.linspace(0.01, 0.01); + + ops::helpers::softmax(sd::LaunchContext ::defaultContext(), input, output, 0); + + ASSERT_TRUE(output.equalsTo(&expOutput)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, logSoftMaxForVector_test1) { - auto input = NDArrayFactory::create('c', {1,5}, {1,2,3,4,5}); - auto output = NDArrayFactory::create('c', {1,5}); - auto expOutput = NDArrayFactory::create('c', {1,5}); - expOutput = 0; + auto input = NDArrayFactory::create('c', {1,5}, {1,2,3,4,5}); + auto output = NDArrayFactory::create('c', {1,5}); + auto expOutput = NDArrayFactory::create('c', {1,5}); + expOutput = 0; - ops::helpers::logSoftmax(sd::LaunchContext ::defaultContext(), input, output, 0); + ops::helpers::logSoftmax(sd::LaunchContext ::defaultContext(), input, output, 0); - ASSERT_TRUE(output.equalsTo(&expOutput)); + ASSERT_TRUE(output.equalsTo(&expOutput)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, logSoftMaxForVector_test2) { - auto input= NDArrayFactory::create('c', {5,1}, {1,2,3,4,5}); - auto output = NDArrayFactory::create('c', {5,1}); - auto expOutput = NDArrayFactory::create('c', {5,1}, {-4.4519144, -3.4519144, -2.4519144, -1.4519144, -0.4519144}); + auto input= NDArrayFactory::create('c', {5,1}, {1,2,3,4,5}); + auto output = NDArrayFactory::create('c', {5,1}); + auto expOutput = NDArrayFactory::create('c', {5,1}, {-4.4519144, -3.4519144, -2.4519144, -1.4519144, -0.4519144}); - ops::helpers::logSoftmax(sd::LaunchContext ::defaultContext(), input, output, 0); + ops::helpers::logSoftmax(sd::LaunchContext ::defaultContext(), input, output, 0); - ASSERT_TRUE(output.equalsTo(&expOutput)); + ASSERT_TRUE(output.equalsTo(&expOutput)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, logSoftMaxForVector_test3) { - auto input= NDArrayFactory::create('c', {5}, {1,2,3,4,5}); - auto output = NDArrayFactory::create('c', {5}); - auto expOutput = NDArrayFactory::create('c', {5}, {-4.4519144, -3.4519144, -2.4519144, -1.4519144, -0.4519144}); + auto input= NDArrayFactory::create('c', {5}, {1,2,3,4,5}); + auto output = NDArrayFactory::create('c', {5}); + auto expOutput = NDArrayFactory::create('c', {5}, {-4.4519144, -3.4519144, -2.4519144, -1.4519144, -0.4519144}); - ops::helpers::logSoftmax(sd::LaunchContext ::defaultContext(), input, output, 0); + ops::helpers::logSoftmax(sd::LaunchContext ::defaultContext(), input, output, 0); - ASSERT_TRUE(output.equalsTo(&expOutput)); + ASSERT_TRUE(output.equalsTo(&expOutput)); } ////////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, logSoftMaxForVector_test4) { - NDArray input('c', {1500}, sd::DataType::DOUBLE); - NDArray output('c', {1500}, sd::DataType::DOUBLE); - NDArray expOutput('c', {1500}, {-8.154773, -8.153772, -8.152773, -8.151772, -8.150773, -8.149773, -8.148773, -8.147773, -8.146772, -8.145773, -8.144773, -8.143773, -8.142773, -8.141773, -8.140773, -8.139772, -8.138773, -8.137773, -8.136773, -8.135773, -8.134773, -8.133773, -8.132772, -8.131773, -8.130773, -8.129773, -8.128773, -8.127772, -8.126773, -8.125772, -8.124773, -8.123773, -8.122773, -8.121773, -8.120772, -8.119773, -8.118773, -8.117773, -8.116773, -8.115773, -8.114773, -8.113772, -8.112773, -8.111773, -8.110773, -8.109773, -8.108773, -8.107773, -8.106772, -8.105773, -8.104773, -8.103773, -8.102773, -8.101772, -8.100773, -8.099772, -8.098773, -8.097773, -8.096773, -8.095773, -8.094772, -8.093773, -8.092772, -8.091773, -8.090773, -8.089773, -8.088773, -8.087772, -8.086773, -8.085773, -8.084773, -8.083773, -8.082773, -8.081773, -8.080772, -8.079773, -8.078773, -8.077773, -8.076773, -8.075773, -8.074773, -8.073772, -8.072773, -8.071773, -8.070773, -8.069773, -8.068772, -8.067773, -8.066772, -8.065773, -8.064773, -8.063773, -8.062773, -8.061772, -8.060773, -8.059772, -8.058773, -8.057773, -8.056773, -8.055773, -8.054772, --8.053773, -8.052773, -8.051773, -8.050773, -8.049773, -8.048773, -8.047772, -8.046773, -8.045773, -8.044773, -8.043773, -8.042773, -8.041773, -8.040772, -8.039773, -8.038773, -8.037773, -8.036773, -8.035772, -8.034773, -8.033772, -8.032773, -8.031773, -8.030773, -8.029773, -8.028772, -8.027773, -8.026772, -8.025773, -8.024773, -8.023773, -8.022773, -8.021772, -8.020773, -8.019773, -8.018773, -8.017773, -8.016773, -8.015773, -8.014772, -8.013773, -8.012773, -8.011773, -8.010773, -8.009773, -8.008773, -8.007772, -8.006773, -8.005773, -8.004773, -8.003773, -8.002772, -8.001773, -8.000772, -7.999773, -7.998773, -7.997773, -7.996773, -7.995773, -7.994773, -7.993773, -7.992773, -7.991773, -7.990773, -7.989773, -7.988773, -7.987773, -7.986773, -7.985773, -7.984773, -7.983773, -7.982773, -7.981773, -7.980773, -7.979773, -7.978773, -7.977773, -7.976773, -7.975773, -7.974773, -7.973773, -7.972773, -7.971773, -7.970773, -7.969773, -7.968773, -7.967773, -7.966773, -7.965773, -7.964773, -7.963773, -7.962773, -7.961773, -7.960773, -7.959773, -7.958773, -7.957773, -7.956773, -7.955773, -7.954773, -7.953773, -7.952773, --7.951773, -7.950773, -7.949773, -7.948773, -7.947773, -7.946773, -7.945773, -7.944773, -7.943773, -7.942773, -7.941773, -7.940773, -7.939773, -7.938773, -7.937773, -7.936773, -7.935773, -7.934773, -7.933773, -7.932773, -7.931773, -7.930773, -7.929773, -7.928773, -7.927773, -7.926773, -7.925773, -7.924773, -7.923773, -7.922773, -7.921773, -7.920773, -7.919773, -7.918773, -7.917773, -7.916773, -7.915773, -7.914773, -7.913773, -7.912773, -7.911773, -7.910773, -7.909773, -7.908773, -7.907773, -7.906773, -7.905773, -7.904773, -7.903773, -7.902773, -7.901773, -7.900773, -7.899773, -7.898773, -7.897773, -7.896773, -7.895773, -7.894773, -7.893773, -7.892773, -7.891773, -7.890773, -7.889773, -7.888773, -7.887773, -7.886773, -7.885773, -7.884773, -7.883773, -7.882773, -7.881773, -7.880773, -7.879773, -7.878773, -7.877773, -7.876773, -7.875773, -7.874773, -7.873773, -7.872773, -7.871773, -7.870773, -7.869773, -7.868773, -7.867773, -7.866773, -7.865773, -7.864773, -7.863773, -7.862773, -7.861773, -7.860773, -7.859773, -7.858773, -7.857773, -7.856773, -7.855773, -7.854773, -7.853773, -7.852773, -7.851773, -7.850773, -7.849773, --7.848773, -7.847773, -7.846773, -7.845773, -7.844773, -7.843773, -7.842773, -7.841773, -7.840773, -7.839773, -7.838773, -7.837773, -7.836773, -7.835773, -7.834773, -7.833773, -7.832773, -7.831773, -7.830773, -7.829773, -7.828773, -7.827773, -7.826773, -7.825773, -7.824773, -7.823773, -7.822773, -7.821773, -7.820773, -7.819773, -7.818773, -7.817773, -7.816773, -7.815773, -7.814773, -7.813773, -7.812773, -7.811773, -7.810773, -7.809773, -7.808773, -7.807773, -7.806773, -7.805773, -7.804773, -7.803773, -7.802773, -7.801773, -7.800773, -7.799773, -7.798773, -7.797773, -7.796773, -7.795773, -7.794773, -7.793773, -7.792773, -7.791773, -7.790773, -7.789773, -7.788773, -7.787773, -7.786773, -7.785773, -7.784773, -7.783773, -7.782773, -7.781773, -7.780773, -7.779773, -7.778773, -7.777773, -7.776773, -7.775773, -7.774773, -7.773773, -7.772773, -7.771773, -7.770773, -7.769773, -7.768773, -7.767773, -7.766773, -7.765773, -7.764773, -7.763773, -7.762773, -7.761773, -7.760773, -7.759773, -7.758773, -7.757773, -7.756773, -7.755773, -7.754773, -7.753773, -7.752773, -7.751773, -7.750773, -7.749773, -7.748773, -7.747773, -7.746773, --7.745773, -7.744773, -7.743773, -7.742773, -7.741773, -7.740773, -7.739773, -7.738773, -7.737773, -7.736773, -7.735773, -7.734773, -7.733773, -7.732773, -7.731773, -7.730773, -7.729773, -7.728773, -7.727773, -7.726773, -7.725773, -7.724773, -7.723773, -7.722773, -7.721773, -7.720773, -7.719773, -7.718773, -7.717773, -7.716773, -7.715773, -7.714773, -7.713773, -7.712773, -7.711773, -7.710773, -7.709773, -7.708773, -7.707773, -7.706773, -7.705773, -7.704773, -7.703773, -7.702773, -7.701773, -7.700773, -7.699773, -7.698773, -7.697773, -7.696773, -7.695773, -7.694773, -7.693773, -7.692773, -7.691773, -7.690773, -7.689773, -7.688773, -7.687773, -7.686773, -7.685773, -7.684773, -7.683773, -7.682773, -7.681773, -7.680773, -7.679773, -7.678773, -7.677773, -7.676773, -7.675773, -7.674773, -7.673773, -7.672773, -7.671773, -7.670773, -7.669773, -7.668773, -7.667773, -7.666773, -7.665773, -7.664773, -7.663773, -7.662773, -7.661773, -7.660773, -7.659773, -7.658773, -7.657773, -7.656773, -7.655773, -7.654773, -7.653773, -7.652773, -7.651773, -7.650773, -7.649773, -7.648773, -7.647773, -7.646773, -7.645773, -7.644773, -7.643773, --7.642773, -7.641773, -7.640773, -7.639773, -7.638773, -7.637773, -7.636773, -7.635773, -7.634773, -7.633773, -7.632773, -7.631773, -7.630773, -7.629773, -7.628773, -7.627773, -7.626773, -7.625773, -7.624773, -7.623773, -7.622773, -7.621773, -7.620773, -7.619773, -7.618773, -7.617773, -7.616773, -7.615773, -7.614773, -7.613773, -7.612773, -7.611773, -7.610773, -7.609773, -7.608773, -7.607773, -7.606773, -7.605773, -7.604773, -7.603773, -7.602773, -7.601773, -7.600773, -7.599773, -7.598773, -7.597773, -7.596773, -7.595773, -7.594773, -7.593773, -7.592773, -7.591773, -7.590773, -7.589773, -7.588773, -7.587773, -7.586773, -7.585773, -7.584773, -7.583773, -7.582773, -7.581773, -7.580773, -7.579773, -7.578773, -7.577773, -7.576773, -7.575773, -7.574773, -7.573773, -7.572773, -7.571773, -7.570773, -7.569773, -7.568773, -7.567773, -7.566773, -7.565773, -7.564773, -7.563773, -7.562773, -7.561773, -7.560773, -7.559773, -7.558773, -7.557773, -7.556773, -7.555773, -7.554773, -7.553773, -7.552773, -7.551773, -7.550773, -7.549773, -7.548773, -7.547773, -7.546773, -7.545773, -7.544773, -7.543773, -7.542773, -7.541773, -7.540773, --7.539773, -7.538773, -7.537773, -7.536773, -7.535773, -7.534773, -7.533773, -7.532773, -7.531773, -7.530773, -7.529773, -7.528773, -7.527773, -7.526773, -7.525773, -7.524773, -7.523773, -7.522773, -7.521773, -7.520773, -7.519773, -7.518773, -7.517773, -7.516773, -7.515773, -7.514773, -7.513773, -7.512773, -7.511773, -7.510773, -7.509773, -7.508773, -7.507773, -7.506773, -7.505773, -7.504773, -7.503773, -7.502773, -7.501773, -7.500773, -7.499773, -7.498773, -7.497773, -7.496773, -7.495773, -7.494773, -7.493773, -7.492773, -7.491773, -7.490773, -7.489773, -7.488773, -7.487773, -7.486773, -7.485773, -7.484773, -7.483773, -7.482773, -7.481773, -7.480773, -7.479773, -7.478773, -7.477773, -7.476773, -7.475773, -7.474773, -7.473773, -7.472773, -7.471773, -7.470773, -7.469773, -7.468773, -7.467773, -7.466773, -7.465773, -7.464773, -7.463773, -7.462773, -7.461773, -7.460773, -7.459773, -7.458773, -7.457773, -7.456773, -7.455773, -7.454773, -7.453773, -7.452773, -7.451773, -7.450773, -7.449773, -7.448773, -7.447773, -7.446773, -7.445773, -7.444773, -7.443773, -7.442773, -7.441773, -7.440773, -7.439773, -7.438773, -7.437773, --7.436773, -7.435773, -7.434773, -7.433773, -7.432773, -7.431773, -7.430773, -7.429773, -7.428773, -7.427773, -7.426773, -7.425773, -7.424773, -7.423773, -7.422773, -7.421773, -7.420773, -7.419773, -7.418773, -7.417773, -7.416773, -7.415773, -7.414773, -7.413773, -7.412773, -7.411773, -7.410773, -7.409773, -7.408773, -7.407773, -7.406773, -7.405773, -7.404773, -7.403773, -7.402773, -7.401773, -7.400773, -7.399773, -7.398773, -7.397773, -7.396773, -7.395773, -7.394773, -7.393773, -7.392773, -7.391773, -7.390773, -7.389773, -7.388773, -7.387773, -7.386773, -7.385773, -7.384773, -7.383773, -7.382773, -7.381773, -7.380773, -7.379773, -7.378773, -7.377773, -7.376773, -7.375773, -7.374773, -7.373773, -7.372773, -7.371773, -7.370773, -7.369773, -7.368773, -7.367773, -7.366773, -7.365773, -7.364773, -7.363773, -7.362773, -7.361773, -7.360773, -7.359773, -7.358773, -7.357773, -7.356773, -7.355773, -7.354773, -7.353773, -7.352773, -7.351773, -7.350773, -7.349773, -7.348773, -7.347773, -7.346773, -7.345773, -7.344773, -7.343773, -7.342773, -7.341773, -7.340773, -7.339773, -7.338773, -7.337773, -7.336773, -7.335773, -7.334773, --7.333773, -7.332773, -7.331773, -7.330773, -7.329773, -7.328773, -7.327773, -7.326773, -7.325773, -7.324773, -7.323773, -7.322773, -7.321773, -7.320773, -7.319773, -7.318773, -7.317773, -7.316773, -7.315773, -7.314773, -7.313773, -7.312773, -7.311773, -7.310773, -7.309773, -7.308773, -7.307773, -7.306773, -7.305773, -7.304773, -7.303773, -7.302773, -7.301773, -7.300773, -7.299773, -7.298773, -7.297773, -7.296773, -7.295773, -7.294773, -7.293773, -7.292773, -7.291773, -7.290773, -7.289773, -7.288773, -7.287773, -7.286773, -7.285773, -7.284773, -7.283773, -7.282773, -7.281773, -7.280773, -7.279773, -7.278773, -7.277773, -7.276773, -7.275773, -7.274773, -7.273773, -7.272773, -7.271773, -7.270773, -7.269773, -7.268773, -7.267773, -7.266773, -7.265773, -7.264773, -7.263773, -7.262773, -7.261773, -7.260773, -7.259773, -7.258773, -7.257773, -7.256773, -7.255773, -7.254773, -7.253773, -7.252773, -7.251773, -7.250773, -7.249773, -7.248773, -7.247773, -7.246773, -7.245773, -7.244773, -7.243773, -7.242773, -7.241773, -7.240773, -7.239773, -7.238773, -7.237773, -7.236773, -7.235773, -7.234773, -7.233773, -7.232773, -7.231773, --7.230773, -7.229773, -7.228773, -7.227773, -7.226773, -7.225773, -7.224773, -7.223773, -7.222773, -7.221773, -7.220773, -7.219773, -7.218773, -7.217773, -7.216773, -7.215773, -7.214773, -7.213773, -7.212773, -7.211773, -7.210773, -7.209773, -7.208773, -7.207773, -7.206773, -7.205773, -7.204773, -7.203773, -7.202773, -7.201773, -7.200773, -7.199773, -7.198773, -7.197773, -7.196773, -7.195773, -7.194773, -7.193773, -7.192773, -7.191773, -7.190773, -7.189773, -7.188773, -7.187773, -7.186773, -7.185773, -7.184773, -7.183773, -7.182773, -7.181773, -7.180773, -7.179773, -7.178773, -7.177773, -7.176773, -7.175773, -7.174773, -7.173773, -7.172773, -7.171773, -7.170773, -7.169773, -7.168773, -7.167773, -7.166773, -7.165773, -7.164773, -7.163773, -7.162773, -7.161773, -7.160773, -7.159773, -7.158773, -7.157773, -7.156773, -7.155773, -7.154773, -7.153773, -7.152773, -7.151773, -7.150773, -7.149773, -7.148773, -7.147773, -7.146773, -7.145773, -7.144773, -7.143773, -7.142773, -7.141773, -7.140773, -7.139773, -7.138773, -7.137773, -7.136773, -7.135773, -7.134773, -7.133773, -7.132773, -7.131773, -7.130773, -7.129773, -7.128773, --7.127773, -7.126773, -7.125773, -7.124773, -7.123773, -7.122773, -7.121773, -7.120773, -7.119773, -7.118773, -7.117773, -7.116773, -7.115773, -7.114773, -7.113773, -7.112773, -7.111773, -7.110773, -7.109773, -7.108773, -7.107773, -7.106773, -7.105773, -7.104773, -7.103773, -7.102773, -7.101773, -7.100773, -7.099773, -7.098773, -7.097773, -7.096773, -7.095773, -7.094773, -7.093773, -7.092773, -7.091773, -7.090773, -7.089773, -7.088773, -7.087773, -7.086773, -7.085773, -7.084773, -7.083773, -7.082773, -7.081773, -7.080773, -7.079773, -7.078773, -7.077773, -7.076773, -7.075773, -7.074773, -7.073773, -7.072773, -7.071773, -7.070773, -7.069773, -7.068773, -7.067773, -7.066773, -7.065773, -7.064773, -7.063773, -7.062773, -7.061773, -7.060773, -7.059773, -7.058773, -7.057773, -7.056773, -7.055773, -7.054773, -7.053773, -7.052773, -7.051773, -7.050773, -7.049773, -7.048773, -7.047773, -7.046773, -7.045773, -7.044773, -7.043773, -7.042773, -7.041773, -7.040773, -7.039773, -7.038773, -7.037773, -7.036773, -7.035773, -7.034773, -7.033773, -7.032773, -7.031773, -7.030773, -7.029773, -7.028773, -7.027773, -7.026773, -7.025773, --7.024773, -7.023773, -7.022773, -7.021773, -7.020773, -7.019773, -7.018773, -7.017773, -7.016773, -7.015773, -7.014773, -7.013773, -7.012773, -7.011773, -7.010773, -7.009773, -7.008773, -7.007773, -7.006773, -7.005773, -7.004773, -7.003773, -7.002773, -7.001773, -7.000773, -6.999773, -6.998773, -6.997773, -6.996773, -6.995773, -6.994773, -6.993773, -6.992773, -6.991773, -6.990773, -6.989773, -6.988773, -6.987773, -6.986773, -6.985773, -6.984773, -6.983773, -6.982773, -6.981773, -6.980773, -6.979773, -6.978773, -6.977773, -6.976773, -6.975773, -6.974773, -6.973773, -6.972773, -6.971773, -6.970773, -6.969773, -6.968773, -6.967773, -6.966773, -6.965773, -6.964773, -6.963773, -6.962773, -6.961773, -6.960773, -6.959773, -6.958773, -6.957773, -6.956773, -6.955773, -6.954773, -6.953773, -6.952773, -6.951773, -6.950773, -6.949773, -6.948773, -6.947773, -6.946773, -6.945773, -6.944773, -6.943773, -6.942773, -6.941773, -6.940773, -6.939773, -6.938773, -6.937773, -6.936773, -6.935773, -6.934773, -6.933773, -6.932773, -6.931773, -6.930773, -6.929773, -6.928773, -6.927773, -6.926773, -6.925773, -6.924773, -6.923773, -6.922773, --6.921773, -6.920773, -6.919773, -6.918773, -6.917773, -6.916773, -6.915773, -6.914773, -6.913773, -6.912773, -6.911773, -6.910773, -6.909773, -6.908773, -6.907773, -6.906773, -6.905773, -6.904773, -6.903773, -6.902773, -6.901773, -6.900773, -6.899773, -6.898773, -6.897773, -6.896773, -6.895773, -6.894773, -6.893773, -6.892773, -6.891773, -6.890773, -6.889773, -6.888773, -6.887773, -6.886773, -6.885773, -6.884773, -6.883773, -6.882773, -6.881773, -6.880773, -6.879773, -6.878773, -6.877773, -6.876773, -6.875773, -6.874773, -6.873773, -6.872773, -6.871773, -6.870773, -6.869773, -6.868773, -6.867773, -6.866773, -6.865773, -6.864773, -6.863773, -6.862773, -6.861773, -6.860773, -6.859773, -6.858773, -6.857773, -6.856773, -6.855773, -6.854773, -6.853773, -6.852773, -6.851773, -6.850773, -6.849773, -6.848773, -6.847773, -6.846773, -6.845773, -6.844773, -6.843773, -6.842773, -6.841773, -6.840773, -6.839773, -6.838773, -6.837773, -6.836773, -6.835773, -6.834773, -6.833773, -6.832773, -6.831773, -6.830773, -6.829773, -6.828773, -6.827773, -6.826773, -6.825773, -6.824773, -6.823773, -6.822773, -6.821773, -6.820773, -6.819773, --6.818773, -6.817773, -6.816773, -6.815773, -6.814773, -6.813773, -6.812773, -6.811773, -6.810773, -6.809773, -6.808773, -6.807773, -6.806773, -6.805773, -6.804773, -6.803773, -6.802773, -6.801773, -6.800773, -6.799773, -6.798773, -6.797773, -6.796773, -6.795773, -6.794773, -6.793773, -6.792773, -6.791773, -6.790773, -6.789773, -6.788773, -6.787773, -6.786773, -6.785773, -6.784773, -6.783773, -6.782773, -6.781773, -6.780773, -6.779773, -6.778773, -6.777773, -6.776773, -6.775773, -6.774773, -6.773773, -6.772773, -6.771773, -6.770773, -6.769773, -6.768773, -6.767773, -6.766773, -6.765773, -6.764773, -6.763773, -6.762773, -6.761773, -6.760773, -6.759773, -6.758773, -6.757773, -6.756773, -6.755773, -6.754773, -6.753773, -6.752773, -6.751773, -6.750773, -6.749773, -6.748773, -6.747773, -6.746773, -6.745773, -6.744773, -6.743773, -6.742773, -6.741773, -6.740773, -6.739773, -6.738773, -6.737773, -6.736773, -6.735773, -6.734773, -6.733773, -6.732773, -6.731773, -6.730773, -6.729773, -6.728773, -6.727773, -6.726773, -6.725773, -6.724773, -6.723773, -6.722773, -6.721773, -6.720773, -6.719773, -6.718773, -6.717773, -6.716773, -6.715773, --6.714773, -6.713773, -6.712773, -6.711773, -6.710773, -6.709773, -6.708773, -6.707773, -6.706773, -6.705773, -6.704773, -6.703773, -6.702773, -6.701773, -6.700773, -6.699773, -6.698773, -6.697773, -6.696773, -6.695773, -6.694773, -6.693773, -6.692773, -6.691773, -6.690773, -6.689773, -6.688773, -6.687773, -6.686773, -6.685773, -6.684773, -6.683773, -6.682773, -6.681773, -6.680773, -6.679773, -6.678773, -6.677773, -6.676773, -6.675773, -6.674773, -6.673773, -6.672773, -6.671773, -6.670773, -6.669773, -6.668773, -6.667773, -6.666773, -6.665773, -6.664773, -6.663773, -6.662773, -6.661773, -6.660773, -6.659773, -6.658773, -6.657773, -6.656773, -6.655773}, sd::DataType::DOUBLE); - input.linspace(0.01, 0.001); + NDArray input('c', {1500}, sd::DataType::DOUBLE); + NDArray output('c', {1500}, sd::DataType::DOUBLE); + NDArray expOutput('c', {1500}, {-8.154773, -8.153772, -8.152773, -8.151772, -8.150773, -8.149773, -8.148773, -8.147773, -8.146772, -8.145773, -8.144773, -8.143773, -8.142773, -8.141773, -8.140773, -8.139772, -8.138773, -8.137773, -8.136773, -8.135773, -8.134773, -8.133773, -8.132772, -8.131773, -8.130773, -8.129773, -8.128773, -8.127772, -8.126773, -8.125772, -8.124773, -8.123773, -8.122773, -8.121773, -8.120772, -8.119773, -8.118773, -8.117773, -8.116773, -8.115773, -8.114773, -8.113772, -8.112773, -8.111773, -8.110773, -8.109773, -8.108773, -8.107773, -8.106772, -8.105773, -8.104773, -8.103773, -8.102773, -8.101772, -8.100773, -8.099772, -8.098773, -8.097773, -8.096773, -8.095773, -8.094772, -8.093773, -8.092772, -8.091773, -8.090773, -8.089773, -8.088773, -8.087772, -8.086773, -8.085773, -8.084773, -8.083773, -8.082773, -8.081773, -8.080772, -8.079773, -8.078773, -8.077773, -8.076773, -8.075773, -8.074773, -8.073772, -8.072773, -8.071773, -8.070773, -8.069773, -8.068772, -8.067773, -8.066772, -8.065773, -8.064773, -8.063773, -8.062773, -8.061772, -8.060773, -8.059772, -8.058773, -8.057773, -8.056773, -8.055773, -8.054772, + -8.053773, -8.052773, -8.051773, -8.050773, -8.049773, -8.048773, -8.047772, -8.046773, -8.045773, -8.044773, -8.043773, -8.042773, -8.041773, -8.040772, -8.039773, -8.038773, -8.037773, -8.036773, -8.035772, -8.034773, -8.033772, -8.032773, -8.031773, -8.030773, -8.029773, -8.028772, -8.027773, -8.026772, -8.025773, -8.024773, -8.023773, -8.022773, -8.021772, -8.020773, -8.019773, -8.018773, -8.017773, -8.016773, -8.015773, -8.014772, -8.013773, -8.012773, -8.011773, -8.010773, -8.009773, -8.008773, -8.007772, -8.006773, -8.005773, -8.004773, -8.003773, -8.002772, -8.001773, -8.000772, -7.999773, -7.998773, -7.997773, -7.996773, -7.995773, -7.994773, -7.993773, -7.992773, -7.991773, -7.990773, -7.989773, -7.988773, -7.987773, -7.986773, -7.985773, -7.984773, -7.983773, -7.982773, -7.981773, -7.980773, -7.979773, -7.978773, -7.977773, -7.976773, -7.975773, -7.974773, -7.973773, -7.972773, -7.971773, -7.970773, -7.969773, -7.968773, -7.967773, -7.966773, -7.965773, -7.964773, -7.963773, -7.962773, -7.961773, -7.960773, -7.959773, -7.958773, -7.957773, -7.956773, -7.955773, -7.954773, -7.953773, -7.952773, + -7.951773, -7.950773, -7.949773, -7.948773, -7.947773, -7.946773, -7.945773, -7.944773, -7.943773, -7.942773, -7.941773, -7.940773, -7.939773, -7.938773, -7.937773, -7.936773, -7.935773, -7.934773, -7.933773, -7.932773, -7.931773, -7.930773, -7.929773, -7.928773, -7.927773, -7.926773, -7.925773, -7.924773, -7.923773, -7.922773, -7.921773, -7.920773, -7.919773, -7.918773, -7.917773, -7.916773, -7.915773, -7.914773, -7.913773, -7.912773, -7.911773, -7.910773, -7.909773, -7.908773, -7.907773, -7.906773, -7.905773, -7.904773, -7.903773, -7.902773, -7.901773, -7.900773, -7.899773, -7.898773, -7.897773, -7.896773, -7.895773, -7.894773, -7.893773, -7.892773, -7.891773, -7.890773, -7.889773, -7.888773, -7.887773, -7.886773, -7.885773, -7.884773, -7.883773, -7.882773, -7.881773, -7.880773, -7.879773, -7.878773, -7.877773, -7.876773, -7.875773, -7.874773, -7.873773, -7.872773, -7.871773, -7.870773, -7.869773, -7.868773, -7.867773, -7.866773, -7.865773, -7.864773, -7.863773, -7.862773, -7.861773, -7.860773, -7.859773, -7.858773, -7.857773, -7.856773, -7.855773, -7.854773, -7.853773, -7.852773, -7.851773, -7.850773, -7.849773, + -7.848773, -7.847773, -7.846773, -7.845773, -7.844773, -7.843773, -7.842773, -7.841773, -7.840773, -7.839773, -7.838773, -7.837773, -7.836773, -7.835773, -7.834773, -7.833773, -7.832773, -7.831773, -7.830773, -7.829773, -7.828773, -7.827773, -7.826773, -7.825773, -7.824773, -7.823773, -7.822773, -7.821773, -7.820773, -7.819773, -7.818773, -7.817773, -7.816773, -7.815773, -7.814773, -7.813773, -7.812773, -7.811773, -7.810773, -7.809773, -7.808773, -7.807773, -7.806773, -7.805773, -7.804773, -7.803773, -7.802773, -7.801773, -7.800773, -7.799773, -7.798773, -7.797773, -7.796773, -7.795773, -7.794773, -7.793773, -7.792773, -7.791773, -7.790773, -7.789773, -7.788773, -7.787773, -7.786773, -7.785773, -7.784773, -7.783773, -7.782773, -7.781773, -7.780773, -7.779773, -7.778773, -7.777773, -7.776773, -7.775773, -7.774773, -7.773773, -7.772773, -7.771773, -7.770773, -7.769773, -7.768773, -7.767773, -7.766773, -7.765773, -7.764773, -7.763773, -7.762773, -7.761773, -7.760773, -7.759773, -7.758773, -7.757773, -7.756773, -7.755773, -7.754773, -7.753773, -7.752773, -7.751773, -7.750773, -7.749773, -7.748773, -7.747773, -7.746773, + -7.745773, -7.744773, -7.743773, -7.742773, -7.741773, -7.740773, -7.739773, -7.738773, -7.737773, -7.736773, -7.735773, -7.734773, -7.733773, -7.732773, -7.731773, -7.730773, -7.729773, -7.728773, -7.727773, -7.726773, -7.725773, -7.724773, -7.723773, -7.722773, -7.721773, -7.720773, -7.719773, -7.718773, -7.717773, -7.716773, -7.715773, -7.714773, -7.713773, -7.712773, -7.711773, -7.710773, -7.709773, -7.708773, -7.707773, -7.706773, -7.705773, -7.704773, -7.703773, -7.702773, -7.701773, -7.700773, -7.699773, -7.698773, -7.697773, -7.696773, -7.695773, -7.694773, -7.693773, -7.692773, -7.691773, -7.690773, -7.689773, -7.688773, -7.687773, -7.686773, -7.685773, -7.684773, -7.683773, -7.682773, -7.681773, -7.680773, -7.679773, -7.678773, -7.677773, -7.676773, -7.675773, -7.674773, -7.673773, -7.672773, -7.671773, -7.670773, -7.669773, -7.668773, -7.667773, -7.666773, -7.665773, -7.664773, -7.663773, -7.662773, -7.661773, -7.660773, -7.659773, -7.658773, -7.657773, -7.656773, -7.655773, -7.654773, -7.653773, -7.652773, -7.651773, -7.650773, -7.649773, -7.648773, -7.647773, -7.646773, -7.645773, -7.644773, -7.643773, + -7.642773, -7.641773, -7.640773, -7.639773, -7.638773, -7.637773, -7.636773, -7.635773, -7.634773, -7.633773, -7.632773, -7.631773, -7.630773, -7.629773, -7.628773, -7.627773, -7.626773, -7.625773, -7.624773, -7.623773, -7.622773, -7.621773, -7.620773, -7.619773, -7.618773, -7.617773, -7.616773, -7.615773, -7.614773, -7.613773, -7.612773, -7.611773, -7.610773, -7.609773, -7.608773, -7.607773, -7.606773, -7.605773, -7.604773, -7.603773, -7.602773, -7.601773, -7.600773, -7.599773, -7.598773, -7.597773, -7.596773, -7.595773, -7.594773, -7.593773, -7.592773, -7.591773, -7.590773, -7.589773, -7.588773, -7.587773, -7.586773, -7.585773, -7.584773, -7.583773, -7.582773, -7.581773, -7.580773, -7.579773, -7.578773, -7.577773, -7.576773, -7.575773, -7.574773, -7.573773, -7.572773, -7.571773, -7.570773, -7.569773, -7.568773, -7.567773, -7.566773, -7.565773, -7.564773, -7.563773, -7.562773, -7.561773, -7.560773, -7.559773, -7.558773, -7.557773, -7.556773, -7.555773, -7.554773, -7.553773, -7.552773, -7.551773, -7.550773, -7.549773, -7.548773, -7.547773, -7.546773, -7.545773, -7.544773, -7.543773, -7.542773, -7.541773, -7.540773, + -7.539773, -7.538773, -7.537773, -7.536773, -7.535773, -7.534773, -7.533773, -7.532773, -7.531773, -7.530773, -7.529773, -7.528773, -7.527773, -7.526773, -7.525773, -7.524773, -7.523773, -7.522773, -7.521773, -7.520773, -7.519773, -7.518773, -7.517773, -7.516773, -7.515773, -7.514773, -7.513773, -7.512773, -7.511773, -7.510773, -7.509773, -7.508773, -7.507773, -7.506773, -7.505773, -7.504773, -7.503773, -7.502773, -7.501773, -7.500773, -7.499773, -7.498773, -7.497773, -7.496773, -7.495773, -7.494773, -7.493773, -7.492773, -7.491773, -7.490773, -7.489773, -7.488773, -7.487773, -7.486773, -7.485773, -7.484773, -7.483773, -7.482773, -7.481773, -7.480773, -7.479773, -7.478773, -7.477773, -7.476773, -7.475773, -7.474773, -7.473773, -7.472773, -7.471773, -7.470773, -7.469773, -7.468773, -7.467773, -7.466773, -7.465773, -7.464773, -7.463773, -7.462773, -7.461773, -7.460773, -7.459773, -7.458773, -7.457773, -7.456773, -7.455773, -7.454773, -7.453773, -7.452773, -7.451773, -7.450773, -7.449773, -7.448773, -7.447773, -7.446773, -7.445773, -7.444773, -7.443773, -7.442773, -7.441773, -7.440773, -7.439773, -7.438773, -7.437773, + -7.436773, -7.435773, -7.434773, -7.433773, -7.432773, -7.431773, -7.430773, -7.429773, -7.428773, -7.427773, -7.426773, -7.425773, -7.424773, -7.423773, -7.422773, -7.421773, -7.420773, -7.419773, -7.418773, -7.417773, -7.416773, -7.415773, -7.414773, -7.413773, -7.412773, -7.411773, -7.410773, -7.409773, -7.408773, -7.407773, -7.406773, -7.405773, -7.404773, -7.403773, -7.402773, -7.401773, -7.400773, -7.399773, -7.398773, -7.397773, -7.396773, -7.395773, -7.394773, -7.393773, -7.392773, -7.391773, -7.390773, -7.389773, -7.388773, -7.387773, -7.386773, -7.385773, -7.384773, -7.383773, -7.382773, -7.381773, -7.380773, -7.379773, -7.378773, -7.377773, -7.376773, -7.375773, -7.374773, -7.373773, -7.372773, -7.371773, -7.370773, -7.369773, -7.368773, -7.367773, -7.366773, -7.365773, -7.364773, -7.363773, -7.362773, -7.361773, -7.360773, -7.359773, -7.358773, -7.357773, -7.356773, -7.355773, -7.354773, -7.353773, -7.352773, -7.351773, -7.350773, -7.349773, -7.348773, -7.347773, -7.346773, -7.345773, -7.344773, -7.343773, -7.342773, -7.341773, -7.340773, -7.339773, -7.338773, -7.337773, -7.336773, -7.335773, -7.334773, + -7.333773, -7.332773, -7.331773, -7.330773, -7.329773, -7.328773, -7.327773, -7.326773, -7.325773, -7.324773, -7.323773, -7.322773, -7.321773, -7.320773, -7.319773, -7.318773, -7.317773, -7.316773, -7.315773, -7.314773, -7.313773, -7.312773, -7.311773, -7.310773, -7.309773, -7.308773, -7.307773, -7.306773, -7.305773, -7.304773, -7.303773, -7.302773, -7.301773, -7.300773, -7.299773, -7.298773, -7.297773, -7.296773, -7.295773, -7.294773, -7.293773, -7.292773, -7.291773, -7.290773, -7.289773, -7.288773, -7.287773, -7.286773, -7.285773, -7.284773, -7.283773, -7.282773, -7.281773, -7.280773, -7.279773, -7.278773, -7.277773, -7.276773, -7.275773, -7.274773, -7.273773, -7.272773, -7.271773, -7.270773, -7.269773, -7.268773, -7.267773, -7.266773, -7.265773, -7.264773, -7.263773, -7.262773, -7.261773, -7.260773, -7.259773, -7.258773, -7.257773, -7.256773, -7.255773, -7.254773, -7.253773, -7.252773, -7.251773, -7.250773, -7.249773, -7.248773, -7.247773, -7.246773, -7.245773, -7.244773, -7.243773, -7.242773, -7.241773, -7.240773, -7.239773, -7.238773, -7.237773, -7.236773, -7.235773, -7.234773, -7.233773, -7.232773, -7.231773, + -7.230773, -7.229773, -7.228773, -7.227773, -7.226773, -7.225773, -7.224773, -7.223773, -7.222773, -7.221773, -7.220773, -7.219773, -7.218773, -7.217773, -7.216773, -7.215773, -7.214773, -7.213773, -7.212773, -7.211773, -7.210773, -7.209773, -7.208773, -7.207773, -7.206773, -7.205773, -7.204773, -7.203773, -7.202773, -7.201773, -7.200773, -7.199773, -7.198773, -7.197773, -7.196773, -7.195773, -7.194773, -7.193773, -7.192773, -7.191773, -7.190773, -7.189773, -7.188773, -7.187773, -7.186773, -7.185773, -7.184773, -7.183773, -7.182773, -7.181773, -7.180773, -7.179773, -7.178773, -7.177773, -7.176773, -7.175773, -7.174773, -7.173773, -7.172773, -7.171773, -7.170773, -7.169773, -7.168773, -7.167773, -7.166773, -7.165773, -7.164773, -7.163773, -7.162773, -7.161773, -7.160773, -7.159773, -7.158773, -7.157773, -7.156773, -7.155773, -7.154773, -7.153773, -7.152773, -7.151773, -7.150773, -7.149773, -7.148773, -7.147773, -7.146773, -7.145773, -7.144773, -7.143773, -7.142773, -7.141773, -7.140773, -7.139773, -7.138773, -7.137773, -7.136773, -7.135773, -7.134773, -7.133773, -7.132773, -7.131773, -7.130773, -7.129773, -7.128773, + -7.127773, -7.126773, -7.125773, -7.124773, -7.123773, -7.122773, -7.121773, -7.120773, -7.119773, -7.118773, -7.117773, -7.116773, -7.115773, -7.114773, -7.113773, -7.112773, -7.111773, -7.110773, -7.109773, -7.108773, -7.107773, -7.106773, -7.105773, -7.104773, -7.103773, -7.102773, -7.101773, -7.100773, -7.099773, -7.098773, -7.097773, -7.096773, -7.095773, -7.094773, -7.093773, -7.092773, -7.091773, -7.090773, -7.089773, -7.088773, -7.087773, -7.086773, -7.085773, -7.084773, -7.083773, -7.082773, -7.081773, -7.080773, -7.079773, -7.078773, -7.077773, -7.076773, -7.075773, -7.074773, -7.073773, -7.072773, -7.071773, -7.070773, -7.069773, -7.068773, -7.067773, -7.066773, -7.065773, -7.064773, -7.063773, -7.062773, -7.061773, -7.060773, -7.059773, -7.058773, -7.057773, -7.056773, -7.055773, -7.054773, -7.053773, -7.052773, -7.051773, -7.050773, -7.049773, -7.048773, -7.047773, -7.046773, -7.045773, -7.044773, -7.043773, -7.042773, -7.041773, -7.040773, -7.039773, -7.038773, -7.037773, -7.036773, -7.035773, -7.034773, -7.033773, -7.032773, -7.031773, -7.030773, -7.029773, -7.028773, -7.027773, -7.026773, -7.025773, + -7.024773, -7.023773, -7.022773, -7.021773, -7.020773, -7.019773, -7.018773, -7.017773, -7.016773, -7.015773, -7.014773, -7.013773, -7.012773, -7.011773, -7.010773, -7.009773, -7.008773, -7.007773, -7.006773, -7.005773, -7.004773, -7.003773, -7.002773, -7.001773, -7.000773, -6.999773, -6.998773, -6.997773, -6.996773, -6.995773, -6.994773, -6.993773, -6.992773, -6.991773, -6.990773, -6.989773, -6.988773, -6.987773, -6.986773, -6.985773, -6.984773, -6.983773, -6.982773, -6.981773, -6.980773, -6.979773, -6.978773, -6.977773, -6.976773, -6.975773, -6.974773, -6.973773, -6.972773, -6.971773, -6.970773, -6.969773, -6.968773, -6.967773, -6.966773, -6.965773, -6.964773, -6.963773, -6.962773, -6.961773, -6.960773, -6.959773, -6.958773, -6.957773, -6.956773, -6.955773, -6.954773, -6.953773, -6.952773, -6.951773, -6.950773, -6.949773, -6.948773, -6.947773, -6.946773, -6.945773, -6.944773, -6.943773, -6.942773, -6.941773, -6.940773, -6.939773, -6.938773, -6.937773, -6.936773, -6.935773, -6.934773, -6.933773, -6.932773, -6.931773, -6.930773, -6.929773, -6.928773, -6.927773, -6.926773, -6.925773, -6.924773, -6.923773, -6.922773, + -6.921773, -6.920773, -6.919773, -6.918773, -6.917773, -6.916773, -6.915773, -6.914773, -6.913773, -6.912773, -6.911773, -6.910773, -6.909773, -6.908773, -6.907773, -6.906773, -6.905773, -6.904773, -6.903773, -6.902773, -6.901773, -6.900773, -6.899773, -6.898773, -6.897773, -6.896773, -6.895773, -6.894773, -6.893773, -6.892773, -6.891773, -6.890773, -6.889773, -6.888773, -6.887773, -6.886773, -6.885773, -6.884773, -6.883773, -6.882773, -6.881773, -6.880773, -6.879773, -6.878773, -6.877773, -6.876773, -6.875773, -6.874773, -6.873773, -6.872773, -6.871773, -6.870773, -6.869773, -6.868773, -6.867773, -6.866773, -6.865773, -6.864773, -6.863773, -6.862773, -6.861773, -6.860773, -6.859773, -6.858773, -6.857773, -6.856773, -6.855773, -6.854773, -6.853773, -6.852773, -6.851773, -6.850773, -6.849773, -6.848773, -6.847773, -6.846773, -6.845773, -6.844773, -6.843773, -6.842773, -6.841773, -6.840773, -6.839773, -6.838773, -6.837773, -6.836773, -6.835773, -6.834773, -6.833773, -6.832773, -6.831773, -6.830773, -6.829773, -6.828773, -6.827773, -6.826773, -6.825773, -6.824773, -6.823773, -6.822773, -6.821773, -6.820773, -6.819773, + -6.818773, -6.817773, -6.816773, -6.815773, -6.814773, -6.813773, -6.812773, -6.811773, -6.810773, -6.809773, -6.808773, -6.807773, -6.806773, -6.805773, -6.804773, -6.803773, -6.802773, -6.801773, -6.800773, -6.799773, -6.798773, -6.797773, -6.796773, -6.795773, -6.794773, -6.793773, -6.792773, -6.791773, -6.790773, -6.789773, -6.788773, -6.787773, -6.786773, -6.785773, -6.784773, -6.783773, -6.782773, -6.781773, -6.780773, -6.779773, -6.778773, -6.777773, -6.776773, -6.775773, -6.774773, -6.773773, -6.772773, -6.771773, -6.770773, -6.769773, -6.768773, -6.767773, -6.766773, -6.765773, -6.764773, -6.763773, -6.762773, -6.761773, -6.760773, -6.759773, -6.758773, -6.757773, -6.756773, -6.755773, -6.754773, -6.753773, -6.752773, -6.751773, -6.750773, -6.749773, -6.748773, -6.747773, -6.746773, -6.745773, -6.744773, -6.743773, -6.742773, -6.741773, -6.740773, -6.739773, -6.738773, -6.737773, -6.736773, -6.735773, -6.734773, -6.733773, -6.732773, -6.731773, -6.730773, -6.729773, -6.728773, -6.727773, -6.726773, -6.725773, -6.724773, -6.723773, -6.722773, -6.721773, -6.720773, -6.719773, -6.718773, -6.717773, -6.716773, -6.715773, + -6.714773, -6.713773, -6.712773, -6.711773, -6.710773, -6.709773, -6.708773, -6.707773, -6.706773, -6.705773, -6.704773, -6.703773, -6.702773, -6.701773, -6.700773, -6.699773, -6.698773, -6.697773, -6.696773, -6.695773, -6.694773, -6.693773, -6.692773, -6.691773, -6.690773, -6.689773, -6.688773, -6.687773, -6.686773, -6.685773, -6.684773, -6.683773, -6.682773, -6.681773, -6.680773, -6.679773, -6.678773, -6.677773, -6.676773, -6.675773, -6.674773, -6.673773, -6.672773, -6.671773, -6.670773, -6.669773, -6.668773, -6.667773, -6.666773, -6.665773, -6.664773, -6.663773, -6.662773, -6.661773, -6.660773, -6.659773, -6.658773, -6.657773, -6.656773, -6.655773}, sd::DataType::DOUBLE); + input.linspace(0.01, 0.001); - ops::helpers::logSoftmax(sd::LaunchContext ::defaultContext(), input, output, 0); + ops::helpers::logSoftmax(sd::LaunchContext ::defaultContext(), input, output, 0); - ASSERT_TRUE(output.equalsTo(&expOutput)); + ASSERT_TRUE(output.equalsTo(&expOutput)); } ////////////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, mmulMxV_1) { - const Nd4jLong M = 3; - const Nd4jLong N = 4; + const Nd4jLong M = 3; + const Nd4jLong N = 4; - NDArray a('f', {M,N}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); - NDArray temp('f', {M,N,5}, {16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::DOUBLE); - NDArray x = temp(6, {0,2}); - NDArray y('f', {M}, sd::DataType::DOUBLE); + NDArray a('f', {M,N}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); + NDArray temp('f', {M,N,5}, {16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::DOUBLE); + NDArray x = temp(6, {0,2}); + NDArray y('f', {M}, sd::DataType::DOUBLE); - NDArray exp('f', {M}, {5.5, 5.1, 4.7}, sd::DataType::DOUBLE); + NDArray exp('f', {M}, {5.5, 5.1, 4.7}, sd::DataType::DOUBLE); - sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); - ASSERT_TRUE(y.equalsTo(&exp)); + sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); + ASSERT_TRUE(y.equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, mmulMxV_2) { - const Nd4jLong M = 3; - const Nd4jLong N = 4; + const Nd4jLong M = 3; + const Nd4jLong N = 4; - NDArray a('f', {N,M}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); - a.permutei({1,0}); - NDArray temp('f', {M,N,5}, {16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::DOUBLE); - NDArray x = temp(6, {0,2}); - NDArray y('f', {M}, sd::DataType::DOUBLE); + NDArray a('f', {N,M}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); + a.permutei({1,0}); + NDArray temp('f', {M,N,5}, {16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::DOUBLE); + NDArray x = temp(6, {0,2}); + NDArray y('f', {M}, sd::DataType::DOUBLE); - NDArray exp('f', {M}, {5.1, 3.3, 1.5}, sd::DataType::DOUBLE); + NDArray exp('f', {M}, {5.1, 3.3, 1.5}, sd::DataType::DOUBLE); - sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); - ASSERT_TRUE(y.equalsTo(&exp)); + sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); + ASSERT_TRUE(y.equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, mmulMxV_3) { - const Nd4jLong M = 3; - const Nd4jLong N = 4; + const Nd4jLong M = 3; + const Nd4jLong N = 4; - NDArray a('f', {N,M}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); - a.permutei({1,0}); - NDArray temp('f', {N,M,5}, {16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::DOUBLE); - NDArray x = temp(4, {1,2}); - NDArray y('f', {M}, sd::DataType::DOUBLE); + NDArray a('f', {N,M}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); + a.permutei({1,0}); + NDArray temp('f', {N,M,5}, {16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::DOUBLE); + NDArray x = temp(4, {1,2}); + NDArray y('f', {M}, sd::DataType::DOUBLE); - NDArray exp('f', {M}, {6.2, 4.5, 1.7}, sd::DataType::DOUBLE); + NDArray exp('f', {M}, {6.2, 4.5, 1.7}, sd::DataType::DOUBLE); - sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); - ASSERT_TRUE(y.equalsTo(&exp)); + sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); + ASSERT_TRUE(y.equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, mmulMxV_4) { - const Nd4jLong M = 3; - const Nd4jLong N = 4; + const Nd4jLong M = 3; + const Nd4jLong N = 4; - NDArray a('f', {N,M}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); - a.permutei({1,0}); - NDArray temp('f', {5,M,N}, {16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::DOUBLE); - NDArray x = temp(3, {0,1}); - NDArray y('f', {M}, sd::DataType::DOUBLE); + NDArray a('f', {N,M}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); + a.permutei({1,0}); + NDArray temp('f', {5,M,N}, {16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::DOUBLE); + NDArray x = temp(3, {0,1}); + NDArray y('f', {M}, sd::DataType::DOUBLE); - NDArray exp('f', {M}, {1.5, 1.8, 1.5}, sd::DataType::DOUBLE); + NDArray exp('f', {M}, {1.5, 1.8, 1.5}, sd::DataType::DOUBLE); - sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); - ASSERT_TRUE(y.equalsTo(&exp)); + sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); + ASSERT_TRUE(y.equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, mmulMxV_5) { - const Nd4jLong M = 3; - const Nd4jLong N = 4; + const Nd4jLong M = 3; + const Nd4jLong N = 4; - NDArray a('c', {N,M}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); - a.permutei({1,0}); - NDArray temp('f', {5,M,N}, {16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::DOUBLE); - NDArray x = temp(2, {0,1}); - NDArray y('f', {M}, sd::DataType::DOUBLE); + NDArray a('c', {N,M}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); + a.permutei({1,0}); + NDArray temp('f', {5,M,N}, {16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::DOUBLE); + NDArray x = temp(2, {0,1}); + NDArray y('f', {M}, sd::DataType::DOUBLE); - NDArray exp('f', {M}, {-0.3, 0.3, 0.9}, sd::DataType::DOUBLE); + NDArray exp('f', {M}, {-0.3, 0.3, 0.9}, sd::DataType::DOUBLE); - sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); - ASSERT_TRUE(y.equalsTo(&exp)); + sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); + ASSERT_TRUE(y.equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, mmulMxV_6) { - const Nd4jLong M = 3; - const Nd4jLong N = 4; + const Nd4jLong M = 3; + const Nd4jLong N = 4; - NDArray a('c', {N,M}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); - a.permutei({1,0}); - NDArray temp('c', {5,N,M}, {16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::DOUBLE); - NDArray x = temp(13, {0,2}); - NDArray y('f', {M}, sd::DataType::DOUBLE); + NDArray a('c', {N,M}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); + a.permutei({1,0}); + NDArray temp('c', {5,N,M}, {16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::DOUBLE); + NDArray x = temp(13, {0,2}); + NDArray y('f', {M}, sd::DataType::DOUBLE); - NDArray exp('f', {M}, {-12.1, -10.9, -9.7}, sd::DataType::DOUBLE); + NDArray exp('f', {M}, {-12.1, -10.9, -9.7}, sd::DataType::DOUBLE); - sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); - ASSERT_TRUE(y.equalsTo(&exp)); + sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); + ASSERT_TRUE(y.equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, mmulMxV_7) { - const Nd4jLong M = 3; - const Nd4jLong N = 4; + const Nd4jLong M = 3; + const Nd4jLong N = 4; - NDArray a('c', {N,M}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); - a.permutei({1,0}); - NDArray temp('c', {5,N,M}, {16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::DOUBLE); - NDArray x = temp(10, {0,2}); - NDArray y('c', {M}, sd::DataType::DOUBLE); + NDArray a('c', {N,M}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); + a.permutei({1,0}); + NDArray temp('c', {5,N,M}, {16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::DOUBLE); + NDArray x = temp(10, {0,2}); + NDArray y('c', {M}, sd::DataType::DOUBLE); - NDArray exp('c', {M}, {3.3, 3.3, 3.3}, sd::DataType::DOUBLE); + NDArray exp('c', {M}, {3.3, 3.3, 3.3}, sd::DataType::DOUBLE); - sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); - ASSERT_TRUE(y.equalsTo(&exp)); + sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); + ASSERT_TRUE(y.equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, softmaxDerivative_1) { - NDArray input('c', {3,3}, {-1, 1, -2, 2, -3, 3, -4, 4, 5.}, sd::DataType::DOUBLE); - NDArray expOutput('c', {3,3}, {0.04508, 0.04514, 0.0008 , 0.0472 , 0.00087, 0.10492, 0.00235, 0.04592, 0.10553}, sd::DataType::DOUBLE); - NDArray output('c', {3,3}, sd::DataType::DOUBLE); + NDArray input('c', {3,3}, {-1, 1, -2, 2, -3, 3, -4, 4, 5.}, sd::DataType::DOUBLE); + NDArray expOutput('c', {3,3}, {0.04508, 0.04514, 0.0008 , 0.0472 , 0.00087, 0.10492, 0.00235, 0.04592, 0.10553}, sd::DataType::DOUBLE); + NDArray output('c', {3,3}, sd::DataType::DOUBLE); - // input.applyTransform(sd::transform::SoftMaxDerivative, &output); + // input.applyTransform(sd::transform::SoftMaxDerivative, &output); - sd::ops::helpers::softmaxDerivative(input.getContext(), input, output, 0); - ASSERT_TRUE(expOutput.isSameShape(output)); - ASSERT_TRUE(expOutput.equalsTo(output)); + sd::ops::helpers::softmaxDerivative(input.getContext(), input, output, 0); + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, softmaxDerivative_2) { - NDArray input('c', {3,3,3}, {-1, 1, -2, 2, -3, 3, -4, 4, -5,5 ,-6,6, -7,7, -8,8, -9,9, -10,10, -11,11, -12,12, -13,13, 14.}, sd::DataType::DOUBLE); - NDArray expOutput('c', {3,3,3}, {4.50755e-02, 4.51394e-02, 6.64586e-03,4.72027e-02, 8.67128e-04, 6.97440e-03,2.35008e-03, 4.59243e-02, 3.32995e-04, - 4.51766e-02, 2.26032e-06, 4.51767e-02,2.91394e-07, 2.37285e-06, 3.94360e-08,4.51769e-02, 1.12535e-07, 4.51767e-02, - 7.58256e-10, 4.51767e-02, 1.22325e-11,7.96007e-10, 1.32293e-11, 1.04994e-01,3.77513e-11, 4.51767e-02, 1.04994e-01}, sd::DataType::DOUBLE); - NDArray output('c', {3,3,3}, sd::DataType::DOUBLE); + NDArray input('c', {3,3,3}, {-1, 1, -2, 2, -3, 3, -4, 4, -5,5 ,-6,6, -7,7, -8,8, -9,9, -10,10, -11,11, -12,12, -13,13, 14.}, sd::DataType::DOUBLE); + NDArray expOutput('c', {3,3,3}, {4.50755e-02, 4.51394e-02, 6.64586e-03,4.72027e-02, 8.67128e-04, 6.97440e-03,2.35008e-03, 4.59243e-02, 3.32995e-04, + 4.51766e-02, 2.26032e-06, 4.51767e-02,2.91394e-07, 2.37285e-06, 3.94360e-08,4.51769e-02, 1.12535e-07, 4.51767e-02, + 7.58256e-10, 4.51767e-02, 1.22325e-11,7.96007e-10, 1.32293e-11, 1.04994e-01,3.77513e-11, 4.51767e-02, 1.04994e-01}, sd::DataType::DOUBLE); + NDArray output('c', {3,3,3}, sd::DataType::DOUBLE); - // input.applyTransform(sd::transform::SoftMaxDerivative, &output); + // input.applyTransform(sd::transform::SoftMaxDerivative, &output); - sd::ops::helpers::softmaxDerivative(input.getContext(), input, output, 1); - ASSERT_TRUE(expOutput.isSameShape(output)); - ASSERT_TRUE(expOutput.equalsTo(output)); + sd::ops::helpers::softmaxDerivative(input.getContext(), input, output, 1); + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, softmaxDerivative_3) { - NDArray input('c', {5}, {-1., 1, -2, 2, 3}, sd::DataType::DOUBLE); - NDArray expOutput('c', {5}, {0.01184, 0.08071, 0.00439, 0.18277, 0.22618}, sd::DataType::DOUBLE); - NDArray output('c', {5}, sd::DataType::DOUBLE); + NDArray input('c', {5}, {-1., 1, -2, 2, 3}, sd::DataType::DOUBLE); + NDArray expOutput('c', {5}, {0.01184, 0.08071, 0.00439, 0.18277, 0.22618}, sd::DataType::DOUBLE); + NDArray output('c', {5}, sd::DataType::DOUBLE); - // input.applyTransform(sd::transform::SoftMaxDerivative, &output); + // input.applyTransform(sd::transform::SoftMaxDerivative, &output); - sd::ops::helpers::softmaxDerivative(input.getContext(), input, output, 0); - ASSERT_TRUE(expOutput.isSameShape(output)); - ASSERT_TRUE(expOutput.equalsTo(output)); + sd::ops::helpers::softmaxDerivative(input.getContext(), input, output, 0); + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, lstmLayerCell_1) { - const int bS = 2; - const int nIn = 10; - const int nOut = 4; + const int bS = 2; + const int nIn = 10; + const int nOut = 4; - const float dataFormat = 0; // is ignored in cell op - const float cellClip = 5; // clipping value - const float gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates - const float gateAlpha = 0; // alpha value for activation for gates, not required for sigmoid - const float gateBeta = 0; // beta value for activation for gates, not required for sigmoid - const float cellAct = 0; // tanh activation for cell state - const float cellAlpha = 0; // alpha value for cell state activation, not required for tanh - const float cellBeta = 0; // beta value for cell state activation, not required for tanh - const float outAct = 0; // tanh activation for output - const float outAlpha = 0; // alpha value for output activation, not required for tanh - const float outBeta = 0; // beta value for output activation, not required for tanh + const float dataFormat = 0; // is ignored in cell op + const float cellClip = 5; // clipping value + const float gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates + const float gateAlpha = 0; // alpha value for activation for gates, not required for sigmoid + const float gateBeta = 0; // beta value for activation for gates, not required for sigmoid + const float cellAct = 0; // tanh activation for cell state + const float cellAlpha = 0; // alpha value for cell state activation, not required for tanh + const float cellBeta = 0; // beta value for cell state activation, not required for tanh + const float outAct = 0; // tanh activation for output + const float outAlpha = 0; // alpha value for output activation, not required for tanh + const float outBeta = 0; // beta value for output activation, not required for tanh - NDArray x ('c', {bS, nIn}, sd::DataType::FLOAT32); - NDArray Wx('c', {nIn, 4*nOut}, sd::DataType::FLOAT32); - NDArray Wr('c', {nOut, 4*nOut}, sd::DataType::FLOAT32); - NDArray b ('c', {4*nOut}, sd::DataType::FLOAT32); - NDArray hI('c', {bS, nOut}, sd::DataType::FLOAT32); - NDArray cI('c', {bS, nOut}, sd::DataType::FLOAT32); - NDArray Wp('c', {3*nOut}, sd::DataType::FLOAT32); + NDArray x ('c', {bS, nIn}, sd::DataType::FLOAT32); + NDArray Wx('c', {nIn, 4*nOut}, sd::DataType::FLOAT32); + NDArray Wr('c', {nOut, 4*nOut}, sd::DataType::FLOAT32); + NDArray b ('c', {4*nOut}, sd::DataType::FLOAT32); + NDArray hI('c', {bS, nOut}, sd::DataType::FLOAT32); + NDArray cI('c', {bS, nOut}, sd::DataType::FLOAT32); + NDArray Wp('c', {3*nOut}, sd::DataType::FLOAT32); - NDArray h('c', {bS, nOut}, sd::DataType::FLOAT32); - NDArray c('c', {bS, nOut}, sd::DataType::FLOAT32); + NDArray h('c', {bS, nOut}, sd::DataType::FLOAT32); + NDArray c('c', {bS, nOut}, sd::DataType::FLOAT32); - NDArray expH('c', {bS, nOut}, {0.999288, 0.999288, 0.999288, 0.999288, 0.999288, 0.999288, 0.999288, 0.999288}, sd::DataType::FLOAT32); - NDArray expC('c', {bS, nOut}, {3.999778, 3.999778, 3.999778, 3.999778, 3.999778, 3.999778, 3.999778, 3.999778}, sd::DataType::FLOAT32); + NDArray expH('c', {bS, nOut}, {0.999288, 0.999288, 0.999288, 0.999288, 0.999288, 0.999288, 0.999288, 0.999288}, sd::DataType::FLOAT32); + NDArray expC('c', {bS, nOut}, {3.999778, 3.999778, 3.999778, 3.999778, 3.999778, 3.999778, 3.999778, 3.999778}, sd::DataType::FLOAT32); - std::vector params = {dataFormat, 0, cellClip, gateAct, gateAlpha, gateBeta, cellAct, cellAlpha, cellBeta, outAct, outAlpha, outBeta}; + std::vector params = {dataFormat, 0, cellClip, gateAct, gateAlpha, gateBeta, cellAct, cellAlpha, cellBeta, outAct, outAlpha, outBeta}; - x = 1.; - hI = 2.; - cI = 3.; - Wx = 0.5; - Wr = 0.4; - Wp = 0.3; - b = 0.7; + x = 1.; + hI = 2.; + cI = 3.; + Wx = 0.5; + Wr = 0.4; + Wp = 0.3; + b = 0.7; - sd::ops::helpers::lstmLayerCell(&x, &Wx, &Wr, &b, &hI, &cI, &Wp, params, &h, &c); + sd::ops::helpers::lstmLayerCell(&x, &Wx, &Wr, &b, &hI, &cI, &Wp, params, &h, &c); - ASSERT_TRUE(expH.isSameShape(h)); - ASSERT_TRUE(expH.equalsTo(h)); - ASSERT_TRUE(expC.isSameShape(c)); - ASSERT_TRUE(expC.equalsTo(c)); + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + ASSERT_TRUE(expC.isSameShape(c)); + ASSERT_TRUE(expC.equalsTo(c)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, lstmLayerCell_2) { - const int bS = 2; - const int nIn = 10; - const int nOut = 4; + const int bS = 2; + const int nIn = 10; + const int nOut = 4; - const float dataFormat = 0; // is ignored in cell op - const float cellClip = 3; // clipping value - const float gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates - const float gateAlpha = 0; // alpha value for activation for gates, not required for sigmoid - const float gateBeta = 0; // beta value for activation for gates, not required for sigmoid - const float cellAct = 0; // tanh activation for cell state - const float cellAlpha = 0; // alpha value for cell state activation, not required for tanh - const float cellBeta = 0; // beta value for cell state activation, not required for tanh - const float outAct = 0; // tanh activation for output - const float outAlpha = 0; // alpha value for output activation, not required for tanh - const float outBeta = 0; // beta value for output activation, not required for tanh + const float dataFormat = 0; // is ignored in cell op + const float cellClip = 3; // clipping value + const float gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates + const float gateAlpha = 0; // alpha value for activation for gates, not required for sigmoid + const float gateBeta = 0; // beta value for activation for gates, not required for sigmoid + const float cellAct = 0; // tanh activation for cell state + const float cellAlpha = 0; // alpha value for cell state activation, not required for tanh + const float cellBeta = 0; // beta value for cell state activation, not required for tanh + const float outAct = 0; // tanh activation for output + const float outAlpha = 0; // alpha value for output activation, not required for tanh + const float outBeta = 0; // beta value for output activation, not required for tanh - NDArray x ('c', {bS, nIn}, sd::DataType::FLOAT32); - NDArray Wx('c', {nIn, 4*nOut}, sd::DataType::FLOAT32); - NDArray Wr('c', {nOut, 4*nOut}, sd::DataType::FLOAT32); - NDArray b ('c', {4*nOut}, sd::DataType::FLOAT32); - NDArray hI('c', {bS, nOut}, sd::DataType::FLOAT32); - NDArray cI('c', {bS, nOut}, sd::DataType::FLOAT32); - NDArray Wp('c', {3*nOut}, sd::DataType::FLOAT32); + NDArray x ('c', {bS, nIn}, sd::DataType::FLOAT32); + NDArray Wx('c', {nIn, 4*nOut}, sd::DataType::FLOAT32); + NDArray Wr('c', {nOut, 4*nOut}, sd::DataType::FLOAT32); + NDArray b ('c', {4*nOut}, sd::DataType::FLOAT32); + NDArray hI('c', {bS, nOut}, sd::DataType::FLOAT32); + NDArray cI('c', {bS, nOut}, sd::DataType::FLOAT32); + NDArray Wp('c', {3*nOut}, sd::DataType::FLOAT32); - NDArray h('c', {bS, nOut}, sd::DataType::FLOAT32); - NDArray c('c', {bS, nOut}, sd::DataType::FLOAT32); + NDArray h('c', {bS, nOut}, sd::DataType::FLOAT32); + NDArray c('c', {bS, nOut}, sd::DataType::FLOAT32); - NDArray expH('c', {bS, nOut}, {0.995, 0.995, 0.995, 0.995, 0.995, 0.995, 0.995, 0.995}, sd::DataType::FLOAT32); - NDArray expC('c', {bS, nOut}, {3., 3., 3., 3., 3., 3., 3., 3.}, sd::DataType::FLOAT32); + NDArray expH('c', {bS, nOut}, {0.995, 0.995, 0.995, 0.995, 0.995, 0.995, 0.995, 0.995}, sd::DataType::FLOAT32); + NDArray expC('c', {bS, nOut}, {3., 3., 3., 3., 3., 3., 3., 3.}, sd::DataType::FLOAT32); - std::vector params = {dataFormat, 0, cellClip, gateAct, gateAlpha, gateBeta, cellAct, cellAlpha, cellBeta, outAct, outAlpha, outBeta}; + std::vector params = {dataFormat, 0, cellClip, gateAct, gateAlpha, gateBeta, cellAct, cellAlpha, cellBeta, outAct, outAlpha, outBeta}; - x = 1.; - hI = 2.; - cI = 3.; - Wx = 0.5; - Wr = 0.4; - Wp = 0.3; - b = 0.7; + x = 1.; + hI = 2.; + cI = 3.; + Wx = 0.5; + Wr = 0.4; + Wp = 0.3; + b = 0.7; - sd::ops::helpers::lstmLayerCell(&x, &Wx, &Wr, &b, &hI, &cI, &Wp, params, &h, &c); + sd::ops::helpers::lstmLayerCell(&x, &Wx, &Wr, &b, &hI, &cI, &Wp, params, &h, &c); - ASSERT_TRUE(expH.isSameShape(h)); - ASSERT_TRUE(expH.equalsTo(h)); - ASSERT_TRUE(expC.isSameShape(c)); - ASSERT_TRUE(expC.equalsTo(c)); + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + ASSERT_TRUE(expC.isSameShape(c)); + ASSERT_TRUE(expC.equalsTo(c)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, lstmLayerCell_3) { - const int nIn = 10; - const int nOut = 4; - - const float dataFormat = 0; // is ignored in cell op - const float cellClip = 5; // clipping value - const float gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates - const float gateAlpha = 0; // alpha value for activation for gates, not required for sigmoid - const float gateBeta = 0; // beta value for activation for gates, not required for sigmoid - const float cellAct = 0; // tanh activation for cell state - const float cellAlpha = 0; // alpha value for cell state activation, not required for tanh - const float cellBeta = 0; // beta value for cell state activation, not required for tanh - const float outAct = 0; // tanh activation for output - const float outAlpha = 0; // alpha value for output activation, not required for tanh - const float outBeta = 0; // beta value for output activation, not required for tanh - - NDArray x ('c', {nIn}, sd::DataType::FLOAT32); - NDArray Wx('c', {nIn, 4*nOut}, sd::DataType::FLOAT32); - NDArray Wr('c', {nOut, 4*nOut}, sd::DataType::FLOAT32); - NDArray b ('c', {4*nOut}, sd::DataType::FLOAT32); - NDArray hI('c', {nOut}, sd::DataType::FLOAT32); - NDArray cI('c', {nOut}, sd::DataType::FLOAT32); - NDArray Wp('c', {3*nOut}, sd::DataType::FLOAT32); - - NDArray h('c', {nOut}, sd::DataType::FLOAT32); - NDArray c('c', {nOut}, sd::DataType::FLOAT32); - - NDArray expH('c', {nOut}, {0.999288, 0.999288, 0.999288, 0.999288}, sd::DataType::FLOAT32); - NDArray expC('c', {nOut}, {3.999778, 3.999778, 3.999778, 3.999778}, sd::DataType::FLOAT32); - - std::vector params = {dataFormat, 0, cellClip, gateAct, gateAlpha, gateBeta, cellAct, cellAlpha, cellBeta, outAct, outAlpha, outBeta}; - - x = 1.; - hI = 2.; - cI = 3.; - Wx = 0.5; - Wr = 0.4; - Wp = 0.3; - b = 0.7; - - sd::ops::helpers::lstmLayerCell(&x, &Wx, &Wr, &b, &hI, &cI, &Wp, params, &h, &c); - - ASSERT_TRUE(expH.isSameShape(h)); - ASSERT_TRUE(expH.equalsTo(h)); - ASSERT_TRUE(expC.isSameShape(c)); - ASSERT_TRUE(expC.equalsTo(c)); + const int nIn = 10; + const int nOut = 4; + + const float dataFormat = 0; // is ignored in cell op + const float cellClip = 5; // clipping value + const float gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates + const float gateAlpha = 0; // alpha value for activation for gates, not required for sigmoid + const float gateBeta = 0; // beta value for activation for gates, not required for sigmoid + const float cellAct = 0; // tanh activation for cell state + const float cellAlpha = 0; // alpha value for cell state activation, not required for tanh + const float cellBeta = 0; // beta value for cell state activation, not required for tanh + const float outAct = 0; // tanh activation for output + const float outAlpha = 0; // alpha value for output activation, not required for tanh + const float outBeta = 0; // beta value for output activation, not required for tanh + + NDArray x('c', {nIn}, sd::DataType::FLOAT32); + NDArray Wx('c', {nIn, 4 * nOut}, sd::DataType::FLOAT32); + NDArray Wr('c', {nOut, 4 * nOut}, sd::DataType::FLOAT32); + NDArray b('c', {4 * nOut}, sd::DataType::FLOAT32); + NDArray hI('c', {nOut}, sd::DataType::FLOAT32); + NDArray cI('c', {nOut}, sd::DataType::FLOAT32); + NDArray Wp('c', {3 * nOut}, sd::DataType::FLOAT32); + + NDArray h('c', {nOut}, sd::DataType::FLOAT32); + NDArray c('c', {nOut}, sd::DataType::FLOAT32); + + NDArray expH('c', {nOut}, {0.999288, 0.999288, 0.999288, 0.999288}, sd::DataType::FLOAT32); + NDArray expC('c', {nOut}, {3.999778, 3.999778, 3.999778, 3.999778}, sd::DataType::FLOAT32); + + std::vector params = + {dataFormat, 0, cellClip, gateAct, gateAlpha, gateBeta, cellAct, cellAlpha, cellBeta, outAct, outAlpha, outBeta}; + + x = 1.; + hI = 2.; + cI = 3.; + Wx = 0.5; + Wr = 0.4; + Wp = 0.3; + b = 0.7; + + sd::ops::helpers::lstmLayerCell(&x, &Wx, &Wr, &b, &hI, &cI, &Wp, params, &h, &c); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + ASSERT_TRUE(expC.isSameShape(c)); + ASSERT_TRUE(expC.equalsTo(c)); } diff --git a/libnd4j/tests_cpu/layers_tests/IndexingTests.cpp b/libnd4j/tests_cpu/layers_tests/IndexingTests.cpp index dbe7ccd0a79d..04158a8f6bbf 100644 --- a/libnd4j/tests_cpu/layers_tests/IndexingTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/IndexingTests.cpp @@ -18,440 +18,422 @@ // Created by raver119 on 31.10.2017. // -#include "testlayers.h" -#include #include #include +#include + +#include "testlayers.h" using namespace sd; using namespace sd::graph; class IndexingTests : public testing::Test { -public: - + public: }; TEST_F(IndexingTests, StridedSlice_1) { - auto x = NDArrayFactory::create('c', {3, 3, 3}); - auto exp = NDArrayFactory::create('c', {1, 1, 3}); - exp.p(0, 25.f); - exp.p(1, 26.f); - exp.p(2, 27.f); - - x.linspace(1); - auto begin = NDArrayFactory::create({2,2, 0}); - auto end = NDArrayFactory::create({3,3,3}); - auto strides = NDArrayFactory::create({1,1,1}); - - - sd::ops::strided_slice op; - - auto result = op.evaluate({&x, &begin, &end, &strides}, {}, {0,0,0,0,0}); //, 2,2,0, 3,3,3, 1,1,1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - - + auto x = NDArrayFactory::create('c', {3, 3, 3}); + auto exp = NDArrayFactory::create('c', {1, 1, 3}); + exp.p(0, 25.f); + exp.p(1, 26.f); + exp.p(2, 27.f); + + x.linspace(1); + auto begin = NDArrayFactory::create({2, 2, 0}); + auto end = NDArrayFactory::create({3, 3, 3}); + auto strides = NDArrayFactory::create({1, 1, 1}); + + sd::ops::strided_slice op; + + auto result = op.evaluate({&x, &begin, &end, &strides}, {}, + {0, 0, 0, 0, 0}); //, 2,2,0, 3,3,3, 1,1,1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(IndexingTests, StridedSlice_2) { - auto x = NDArrayFactory::create('c', {5, 5, 5}); - auto exp = NDArrayFactory::create('c', {2, 3, 3}, {86.f, 87.f, 88.f, 91.f, 92.f, 93.f, 96.f, 97.f, 98.f, 111.f, 112.f, 113.f, 116.f, 117.f, 118.f, 121.f, 122.f, 123.f}); + auto x = NDArrayFactory::create('c', {5, 5, 5}); + auto exp = NDArrayFactory::create( + 'c', {2, 3, 3}, + {86.f, 87.f, 88.f, 91.f, 92.f, 93.f, 96.f, 97.f, 98.f, 111.f, 112.f, + 113.f, 116.f, 117.f, 118.f, 121.f, 122.f, 123.f}); - x.linspace(1); + x.linspace(1); - sd::ops::strided_slice op; + sd::ops::strided_slice op; - auto result = op.evaluate({&x}, {}, {0,0,0,0,0, 3,2,0, 5,5,3, 1,1,1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); + auto result = + op.evaluate({&x}, {}, {0, 0, 0, 0, 0, 3, 2, 0, 5, 5, 3, 1, 1, 1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); - + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(IndexingTests, StridedSlice_3) { - auto x = NDArrayFactory::create('c', {5, 5, 5}); - auto exp = NDArrayFactory::create('c', {2, 3, 2}, {86.f, 88.f, 91.f, 93.f, 96.f, 98.f, 111.f, 113.f, 116.f, 118.f, 121.f, 123.f}); + auto x = NDArrayFactory::create('c', {5, 5, 5}); + auto exp = + NDArrayFactory::create('c', {2, 3, 2}, + {86.f, 88.f, 91.f, 93.f, 96.f, 98.f, 111.f, + 113.f, 116.f, 118.f, 121.f, 123.f}); - x.linspace(1); + x.linspace(1); - sd::ops::strided_slice op; - - auto result = op.evaluate({&x}, {}, {0,0,0,0,0, 3,2,0, 5,5,3, 1,1,2}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::strided_slice op; - auto z = result.at(0); + auto result = + op.evaluate({&x}, {}, {0, 0, 0, 0, 0, 3, 2, 0, 5, 5, 3, 1, 1, 2}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); - + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(IndexingTests, SimpleSlice_1) { + auto input = NDArrayFactory::create( + 'c', {3, 2, 3}, {1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6}); - auto input = NDArrayFactory::create('c', {3, 2, 3}, {1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6}); + auto exp = NDArrayFactory::create('c', {1, 1, 3}); + exp.p(0, 3.0f); + exp.p(1, 3.0f); + exp.p(2, 3.0f); - auto exp = NDArrayFactory::create('c', {1, 1, 3}); - exp.p(0, 3.0f); - exp.p(1, 3.0f); - exp.p(2, 3.0f); + sd::ops::slice op; - sd::ops::slice op; + auto result = op.evaluate({&input}, {}, {1, 0, 0, 1, 1, 3}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto result = op.evaluate({&input}, {}, {1,0,0, 1,1,3}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); + auto z = result.at(0); - ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - - + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(IndexingTests, SimpleSlice_2) { - auto input = NDArrayFactory::create('c', {3, 2, 3}, {1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6}); - - auto exp = NDArrayFactory::create('c', {1, 2, 3}); - exp.p(0, 3.0f); - exp.p(1, 3.0f); - exp.p(2, 3.0f); - exp.p(3, 4.0f); - exp.p(4, 4.0f); - exp.p(5, 4.0f); + auto input = NDArrayFactory::create( + 'c', {3, 2, 3}, {1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6}); - sd::ops::slice op; + auto exp = NDArrayFactory::create('c', {1, 2, 3}); + exp.p(0, 3.0f); + exp.p(1, 3.0f); + exp.p(2, 3.0f); + exp.p(3, 4.0f); + exp.p(4, 4.0f); + exp.p(5, 4.0f); - auto result = op.evaluate({&input}, {}, {1,0,0, 1,2,3}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::slice op; - auto z = result.at(0); + auto result = op.evaluate({&input}, {}, {1, 0, 0, 1, 2, 3}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(z)); + auto z = result.at(0); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(exp.isSameShape(z)); - + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(IndexingTests, SimpleSlice_3) { - auto input = NDArrayFactory::create('c', {3, 2, 3}, {1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6}); + auto input = NDArrayFactory::create( + 'c', {3, 2, 3}, {1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6}); - auto exp = NDArrayFactory::create('c', {2, 1, 3}); - exp.p(0, 3.0f); - exp.p(1, 3.0f); - exp.p(2, 3.0f); - exp.p(3, 5.0f); - exp.p(4, 5.0f); - exp.p(5, 5.0f); + auto exp = NDArrayFactory::create('c', {2, 1, 3}); + exp.p(0, 3.0f); + exp.p(1, 3.0f); + exp.p(2, 3.0f); + exp.p(3, 5.0f); + exp.p(4, 5.0f); + exp.p(5, 5.0f); - sd::ops::slice op; + sd::ops::slice op; - auto result = op.evaluate({&input}, {}, {1,0,0, 2,1,3}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); + auto result = op.evaluate({&input}, {}, {1, 0, 0, 2, 1, 3}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(z)); + auto z = result.at(0); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(exp.isSameShape(z)); - + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(IndexingTests, SimpleSlice_4) { - auto input = NDArrayFactory::create('c', {3, 2, 3}, {1.0, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6}); - auto start = NDArrayFactory::create('c', {3}, {1.0, 0.0, 0.0}); - auto stop = NDArrayFactory::create('c', {3}, {2.0, 1.0, 3.0}); - auto exp = NDArrayFactory::create('c', {2, 1, 3}, {3.0, 3.0, 3.0, 5.0, 5.0, 5.0}); - - sd::ops::slice op; + auto input = NDArrayFactory::create( + 'c', {3, 2, 3}, {1.0, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6}); + auto start = NDArrayFactory::create('c', {3}, {1.0, 0.0, 0.0}); + auto stop = NDArrayFactory::create('c', {3}, {2.0, 1.0, 3.0}); + auto exp = NDArrayFactory::create('c', {2, 1, 3}, + {3.0, 3.0, 3.0, 5.0, 5.0, 5.0}); - auto result = op.evaluate({&input, &start, &stop}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::slice op; - auto z = result.at(0); + auto result = op.evaluate({&input, &start, &stop}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); - + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(IndexingTests, MaskedSlice_0) { - auto matrix = NDArrayFactory::create('c', {3, 5}); - auto tads = matrix.allTensorsAlongDimension({1}); - for (int e = 0; e < tads.size(); e++) { - tads.at(e)->assign((float) (e+1)); - } - - auto exp = NDArrayFactory::create('c', {1, 5}); - exp.assign(2.0f); + auto matrix = NDArrayFactory::create('c', {3, 5}); + auto tads = matrix.allTensorsAlongDimension({1}); + for (int e = 0; e < tads.size(); e++) { + tads.at(e).assign((float)(e + 1)); + } - sd::ops::strided_slice op; - auto result = op.evaluate({&matrix}, {}, {0,0,0,0,0, 1, 2, 1}); + auto exp = NDArrayFactory::create('c', {1, 5}); + exp.assign(2.0f); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::strided_slice op; + auto result = op.evaluate({&matrix}, {}, {0, 0, 0, 0, 0, 1, 2, 1}); - auto z = result.at(0); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - // z->printShapeInfo("z"); + auto z = result.at(0); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + // z->printShapeInfo("z"); - + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(IndexingTests, MaskedSlice_00) { - auto matrix = NDArrayFactory::create('c', {3, 5}); - auto tads = matrix.allTensorsAlongDimension({1}); - for (int e = 0; e < tads.size(); e++) { - tads.at(e)->assign((float) (e+1)); - } - - auto exp = NDArrayFactory::create('c', {1, 2}, {2, 2}); + auto matrix = NDArrayFactory::create('c', {3, 5}); + auto tads = matrix.allTensorsAlongDimension({1}); + for (int e = 0; e < tads.size(); e++) { + tads.at(e).assign((float)(e + 1)); + } + auto exp = NDArrayFactory::create('c', {1, 2}, {2, 2}); - sd::ops::strided_slice op; - auto result = op.evaluate({&matrix}, {}, {0,0,0,0,0, 1, 1, 2, 3, 1, 1}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::strided_slice op; + auto result = op.evaluate({&matrix}, {}, {0, 0, 0, 0, 0, 1, 1, 2, 3, 1, 1}); - auto z = result.at(0); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); - + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(IndexingTests, MaskedSlice_1) { - auto matrix = NDArrayFactory::create('c', {3, 5}); - auto tads = matrix.allTensorsAlongDimension({1}); - for (int e = 0; e < tads.size(); e++) { - tads.at(e)->assign((float) (e+1)); - } + auto matrix = NDArrayFactory::create('c', {3, 5}); + auto tads = matrix.allTensorsAlongDimension({1}); + for (int e = 0; e < tads.size(); e++) { + tads.at(e).assign((float)(e + 1)); + } - auto exp = NDArrayFactory::create('c', {5}); - exp.assign(2.0f); + auto exp = NDArrayFactory::create('c', {5}); + exp.assign(2.0f); - sd::ops::strided_slice op; - auto result = op.evaluate({&matrix}, {}, {0,0,0,0,1, 1, 2, 1}); + sd::ops::strided_slice op; + auto result = op.evaluate({&matrix}, {}, {0, 0, 0, 0, 1, 1, 2, 1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - // z->printShapeInfo("z"); + auto z = result.at(0); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + // z->printShapeInfo("z"); - + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(IndexingTests, MaskedSlice_2) { - - auto matrix = NDArrayFactory::create('c', {3, 3, 3}, {1.f, 1.2f, 1.3f, 2.f, 2.2f, 2.3f, 3.f, 3.2f, 3.3f, 4.f, 4.2f, 4.3f, 5.f, 5.2f, 5.3f, 6.f, 6.2f, 6.3f, 7.f, 7.2f, 7.3f, 8.f, 8.2f, 8.3f, 9.f, 9.2f, 9.3f}); - auto exp = NDArrayFactory::create('c', {3, 3}, {4.000000f, 4.200000f, 4.300000f, 5.000000f, 5.200000f, 5.300000f, 6.000000f, 6.200000f, 6.300000f}); - - // output = tf.strided_slice(a, [1, 0, 0], [3, 3, 3], shrink_axis_mask=5) - sd::ops::strided_slice op; - auto result = op.evaluate({&matrix}, {}, {0,0,0,0,1, 1, 0, 0, 3, 3, 3, 1, 1, 1}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - - + auto matrix = NDArrayFactory::create( + 'c', {3, 3, 3}, {1.f, 1.2f, 1.3f, 2.f, 2.2f, 2.3f, 3.f, 3.2f, 3.3f, + 4.f, 4.2f, 4.3f, 5.f, 5.2f, 5.3f, 6.f, 6.2f, 6.3f, + 7.f, 7.2f, 7.3f, 8.f, 8.2f, 8.3f, 9.f, 9.2f, 9.3f}); + auto exp = NDArrayFactory::create( + 'c', {3, 3}, + {4.000000f, 4.200000f, 4.300000f, 5.000000f, 5.200000f, 5.300000f, + 6.000000f, 6.200000f, 6.300000f}); + + // output = tf.strided_slice(a, [1, 0, 0], [3, 3, 3], shrink_axis_mask=5) + sd::ops::strided_slice op; + auto result = + op.evaluate({&matrix}, {}, {0, 0, 0, 0, 1, 1, 0, 0, 3, 3, 3, 1, 1, 1}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(IndexingTests, MaskedSlice_3) { + auto matrix = NDArrayFactory::create( + 'c', {3, 3, 3}, {1.f, 1.2f, 1.3f, 2.f, 2.2f, 2.3f, 3.f, 3.2f, 3.3f, + 4.f, 4.2f, 4.3f, 5.f, 5.2f, 5.3f, 6.f, 6.2f, 6.3f, + 7.f, 7.2f, 7.3f, 8.f, 8.2f, 8.3f, 9.f, 9.2f, 9.3f}); + auto exp = NDArrayFactory::create('c', {2, 3}, + {4.f, 4.2f, 4.3f, 7.f, 7.2f, 7.3f}); - auto matrix = NDArrayFactory::create('c', {3, 3, 3}, {1.f, 1.2f, 1.3f, 2.f, 2.2f, 2.3f, 3.f, 3.2f, 3.3f, 4.f, 4.2f, 4.3f, 5.f, 5.2f, 5.3f, 6.f, 6.2f, 6.3f, 7.f, 7.2f, 7.3f, 8.f, 8.2f, 8.3f, 9.f, 9.2f, 9.3f}); - auto exp = NDArrayFactory::create('c', {2, 3}, { 4.f, 4.2f, 4.3f, 7.f, 7.2f, 7.3f}); - - // output = tf.strided_slice(a, [1, 0, 0], [3, 3, 3], shrink_axis_mask=5) - sd::ops::strided_slice op; - auto result = op.evaluate({&matrix}, {}, {0,0,0,0,2, 1, 0, 0, 3, 3, 3, 1, 1, 1}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + // output = tf.strided_slice(a, [1, 0, 0], [3, 3, 3], shrink_axis_mask=5) + sd::ops::strided_slice op; + auto result = + op.evaluate({&matrix}, {}, {0, 0, 0, 0, 2, 1, 0, 0, 3, 3, 3, 1, 1, 1}); - auto z = result.at(0); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); - + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(IndexingTests, MaskedSlice_4) { + auto matrix = NDArrayFactory::create( + 'c', {3, 3, 3}, {1.f, 1.2f, 1.3f, 2.f, 2.2f, 2.3f, 3.f, 3.2f, 3.3f, + 4.f, 4.2f, 4.3f, 5.f, 5.2f, 5.3f, 6.f, 6.2f, 6.3f, + 7.f, 7.2f, 7.3f, 8.f, 8.2f, 8.3f, 9.f, 9.2f, 9.3f}); + auto exp = NDArrayFactory::create('c', {3}, {4.f, 4.2f, 4.3f}); - auto matrix = NDArrayFactory::create('c', {3, 3, 3}, {1.f, 1.2f, 1.3f, 2.f, 2.2f, 2.3f, 3.f, 3.2f, 3.3f, 4.f, 4.2f, 4.3f, 5.f, 5.2f, 5.3f, 6.f, 6.2f, 6.3f, 7.f, 7.2f, 7.3f, 8.f, 8.2f, 8.3f, 9.f, 9.2f, 9.3f}); - auto exp = NDArrayFactory::create('c', {3}, { 4.f, 4.2f, 4.3f}); - - // output = tf.strided_slice(a, [1, 0, 0], [3, 3, 3], shrink_axis_mask=5) - sd::ops::strided_slice op; - auto result = op.evaluate({&matrix}, {}, {0,0,0,0, 3, 1, 0, 0, 3, 3, 3, 1, 1, 1}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); + // output = tf.strided_slice(a, [1, 0, 0], [3, 3, 3], shrink_axis_mask=5) + sd::ops::strided_slice op; + auto result = + op.evaluate({&matrix}, {}, {0, 0, 0, 0, 3, 1, 0, 0, 3, 3, 3, 1, 1, 1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); - + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(IndexingTests, Live_Slice_1) { - auto matrix = NDArrayFactory::create('c', {3, 3, 3}, {1.f, 1.2f, 1.3f, 2.f, 2.2f, 2.3f, 3.f, 3.2f, 3.3f, 4.f, 4.2f, 4.3f, 5.f, 5.2f, 5.3f, 6.f, 6.2f, 6.3f, 7.f, 7.2f, 7.3f, 8.f, 8.2f, 8.3f, 9.f, 9.2f, 9.3f}); - auto exp = NDArrayFactory::create('c', {3}, { 4.f, 4.2f, 4.3f}); - - auto begin = NDArrayFactory::create('c', {3}, {1.0f, 0.0f, 0.0f}); - auto end = NDArrayFactory::create('c', {3}, {3.0f, 3.0f, 3.0f}); - auto stride = NDArrayFactory::create('c', {3}, {1.0f, 1.0f, 1.0f}); + auto matrix = NDArrayFactory::create( + 'c', {3, 3, 3}, {1.f, 1.2f, 1.3f, 2.f, 2.2f, 2.3f, 3.f, 3.2f, 3.3f, + 4.f, 4.2f, 4.3f, 5.f, 5.2f, 5.3f, 6.f, 6.2f, 6.3f, + 7.f, 7.2f, 7.3f, 8.f, 8.2f, 8.3f, 9.f, 9.2f, 9.3f}); + auto exp = NDArrayFactory::create('c', {3}, {4.f, 4.2f, 4.3f}); - // output = tf.strided_slice(a, [1, 0, 0], [3, 3, 3], shrink_axis_mask=5) - sd::ops::strided_slice op; - auto result = op.evaluate({&matrix, &begin, &end, &stride}, {}, {0,0,0,0,3}); + auto begin = NDArrayFactory::create('c', {3}, {1.0f, 0.0f, 0.0f}); + auto end = NDArrayFactory::create('c', {3}, {3.0f, 3.0f, 3.0f}); + auto stride = NDArrayFactory::create('c', {3}, {1.0f, 1.0f, 1.0f}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + // output = tf.strided_slice(a, [1, 0, 0], [3, 3, 3], shrink_axis_mask=5) + sd::ops::strided_slice op; + auto result = + op.evaluate({&matrix, &begin, &end, &stride}, {}, {0, 0, 0, 0, 3}); - auto z = result.at(0); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - // z->printShapeInfo("z shape"); + auto z = result.at(0); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + // z->printShapeInfo("z shape"); - + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(IndexingTests, Test_StridedSlice_1) { - auto x = NDArrayFactory::create('c', {1, 2}, {5.f, 2.f}); - auto a = NDArrayFactory::create('c', {1}, {0.f}); - auto b = NDArrayFactory::create('c', {1}, {1.f}); - auto c = NDArrayFactory::create('c', {1}, {1.f}); - auto exp = NDArrayFactory::create({5.0f, 2}); + auto x = NDArrayFactory::create('c', {1, 2}, {5.f, 2.f}); + auto a = NDArrayFactory::create('c', {1}, {0.f}); + auto b = NDArrayFactory::create('c', {1}, {1.f}); + auto c = NDArrayFactory::create('c', {1}, {1.f}); + auto exp = NDArrayFactory::create({5.0f, 2}); - sd::ops::strided_slice op; - auto result = op.evaluate({&x, &a, &b, &c}, {}, {0, 0, 0, 0, 1}); + sd::ops::strided_slice op; + auto result = op.evaluate({&x, &a, &b, &c}, {}, {0, 0, 0, 0, 1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); - + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(IndexingTests, Test_StridedSlice_2) { - auto x = NDArrayFactory::create('c', {2, 3}, {1, 2, 3, 4, 5, 6}); - auto a = NDArrayFactory::create('c', {2}, {1, 1}); - auto b = NDArrayFactory::create('c', {2}, {2, 2}); - auto c = NDArrayFactory::create('c', {2}, {1, 1}); - auto exp = NDArrayFactory::create('c', {1}, {5.0}); - - sd::ops::strided_slice op; - auto result = op.evaluate({&x, &a, &b, &c}, {}, {0, 0, 0, 0, 1}); + auto x = NDArrayFactory::create('c', {2, 3}, {1, 2, 3, 4, 5, 6}); + auto a = NDArrayFactory::create('c', {2}, {1, 1}); + auto b = NDArrayFactory::create('c', {2}, {2, 2}); + auto c = NDArrayFactory::create('c', {2}, {1, 1}); + auto exp = NDArrayFactory::create('c', {1}, {5.0}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::strided_slice op; + auto result = op.evaluate({&x, &a, &b, &c}, {}, {0, 0, 0, 0, 1}); - auto z = result.at(0); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - // z->printIndexedBuffer("Z"); + auto z = result.at(0); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + // z->printIndexedBuffer("Z"); - + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(IndexingTests, Test_StridedSlice_3) { - auto x = NDArrayFactory::create('c', {2, 3}, {1, 2, 3, 4, 5, 6}); - auto a = NDArrayFactory::create('c', {2}, {1, 2}); - auto b = NDArrayFactory::create('c', {2}, {2, 3}); - auto c = NDArrayFactory::create('c', {2}, {1, 1}); - auto exp = NDArrayFactory::create('c', {1}, {6.0}); - - sd::ops::strided_slice op; - auto result = op.evaluate({&x, &a, &b, &c}, {}, {0, 0, 0, 0, 1}); + auto x = NDArrayFactory::create('c', {2, 3}, {1, 2, 3, 4, 5, 6}); + auto a = NDArrayFactory::create('c', {2}, {1, 2}); + auto b = NDArrayFactory::create('c', {2}, {2, 3}); + auto c = NDArrayFactory::create('c', {2}, {1, 1}); + auto exp = NDArrayFactory::create('c', {1}, {6.0}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::strided_slice op; + auto result = op.evaluate({&x, &a, &b, &c}, {}, {0, 0, 0, 0, 1}); - auto z = result.at(0); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); - + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(IndexingTests, Test_StridedSlice_4) { - auto x = NDArrayFactory::create('c', {1, 2}, {5, 2}); - auto a = NDArrayFactory::create('c', {1}, {0.}); - auto b = NDArrayFactory::create('c', {1}, {1}); - auto c = NDArrayFactory::create('c', {1}, {1}); - auto exp = NDArrayFactory::create({5.0f, 2}); - - sd::ops::strided_slice op; - auto result = op.evaluate({&x, &a, &b, &c}, {}, {0, 0, 0, 0, 1}); -// auto result = op.execute({&x, &a, &b, &c}, {}, {0, 0, 0, 0, 1, 0, 1, 1}); + auto x = NDArrayFactory::create('c', {1, 2}, {5, 2}); + auto a = NDArrayFactory::create('c', {1}, {0.}); + auto b = NDArrayFactory::create('c', {1}, {1}); + auto c = NDArrayFactory::create('c', {1}, {1}); + auto exp = NDArrayFactory::create({5.0f, 2}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::strided_slice op; + auto result = op.evaluate({&x, &a, &b, &c}, {}, {0, 0, 0, 0, 1}); + // auto result = op.execute({&x, &a, &b, &c}, {}, {0, 0, 0, 0, 1, 0, 1, + // 1}); - auto z = result.at(0); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - //z->printIndexedBuffer("Z"); + auto z = result.at(0); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + // z->printIndexedBuffer("Z"); - + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(IndexingTests, Test_Subarray_Strided_1) { - auto x = NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); - auto exp = NDArrayFactory::create('c', {3, 2}, {1, 3, 4, 6, 7, 9}); - auto sub = x({0,0,0, 0,3,2}, true, true); + auto x = + NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto exp = NDArrayFactory::create('c', {3, 2}, {1, 3, 4, 6, 7, 9}); + auto sub = x({0, 0, 0, 0, 3, 2}, true, true); - ASSERT_TRUE(exp.isSameShape(sub)); - ASSERT_TRUE(exp.equalsTo(sub)); + ASSERT_TRUE(exp.isSameShape(sub)); + ASSERT_TRUE(exp.equalsTo(sub)); } - /* TEST_F(IndexingTests, MaskedSlice_5) { - auto matrix('c', {3, 3, 3}, {1.f, 1.2f, 1.3f, 2.f, 2.2f, 2.3f, 3.f, 3.2f, 3.3f, 4.f, 4.2f, 4.3f, 5.f, 5.2f, 5.3f, 6.f, 6.2f, 6.3f, 7.f, 7.2f, 7.3f, 8.f, 8.2f, 8.3f, 9.f, 9.2f, 9.3f}); + auto matrix('c', {3, 3, 3}, +{1.f, 1.2f, 1.3f, 2.f, 2.2f, 2.3f, 3.f, 3.2f, 3.3f, 4.f, 4.2f, 4.3f, 5.f, 5.2f, 5.3f, +6.f, 6.2f, 6.3f, 7.f, 7.2f, 7.3f, 8.f, 8.2f, 8.3f, 9.f, 9.2f, 9.3f}); auto exp('c', {2, 3}, { 4.f, 4.2f, 4.3f, 7.f, 7.2f, 7.3f}); // output = tf.strided_slice(a, [1, 0, 0], [3, 3, 3], shrink_axis_mask=5) @@ -464,7 +446,5 @@ TEST_F(IndexingTests, MaskedSlice_5) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - - } */ \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/JavaInteropCudaTests.cu b/libnd4j/tests_cpu/layers_tests/JavaInteropCudaTests.cu index aa2c13eb5eed..160f315576c0 100644 --- a/libnd4j/tests_cpu/layers_tests/JavaInteropCudaTests.cu +++ b/libnd4j/tests_cpu/layers_tests/JavaInteropCudaTests.cu @@ -18,70 +18,84 @@ // @author raver119@gmail.com // -#include "testlayers.h" -#include #include -#include -#include #include #include +#include +#include + +#include + +#include "testlayers.h" using namespace sd; using namespace sd::ops; class JavaInteropCudaTests : public testing::Test { -public: - + public: }; TEST_F(JavaInteropCudaTests, test_DeclarableOp_execution_1) { - auto x = NDArrayFactory::create('c', {3, 5}); - auto y = NDArrayFactory::create('c', {5}, {1.f, 1.f, 1.f, 1.f, 1.f}); - auto e = NDArrayFactory::create('c', {3, 5}); - x.assign(1.f); - e.assign(2.f); - - sd::ops::add op; - Context context(1); - - context.setCudaContext(LaunchContext::defaultContext()->getCudaStream(), LaunchContext::defaultContext()->getReductionPointer(), LaunchContext::defaultContext()->getAllocationPointer()); - context.setInputArray(0, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo()); - context.setInputArray(1, y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo()); - - context.setOutputArray(0, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo()); - - PointersManager pm(LaunchContext::defaultContext(), "test_DeclarableOp_execution_1"); - execCustomOp2(nullptr, op.getOpHash(), &context); - - pm.synchronize(); - - ASSERT_EQ(e, x); + auto x = NDArrayFactory::create('c', {3, 5}); + auto y = NDArrayFactory::create('c', {5}, {1.f, 1.f, 1.f, 1.f, 1.f}); + auto e = NDArrayFactory::create('c', {3, 5}); + x.assign(1.f); + e.assign(2.f); + + sd::ops::add op; + Context context(1); + + context.setCudaContext( + LaunchContext::defaultContext()->getCudaStream(), + LaunchContext::defaultContext()->getReductionPointer(), + LaunchContext::defaultContext()->getAllocationPointer()); + context.setInputArray(0, x.buffer(), x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo()); + context.setInputArray(1, y.buffer(), y.shapeInfo(), y.specialBuffer(), + y.specialShapeInfo()); + + context.setOutputArray(0, x.buffer(), x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo()); + + PointersManager pm(LaunchContext::defaultContext(), + "test_DeclarableOp_execution_1"); + execCustomOp2(nullptr, op.getOpHash(), &context); + + pm.synchronize(); + + ASSERT_EQ(e, x); } TEST_F(JavaInteropCudaTests, test_DeclarableOp_execution_2) { - NDArray x('c', {3, 1, 2}, sd::DataType::FLOAT32); - NDArray y('c', {2, 2}, sd::DataType::FLOAT32); - NDArray z('c', {3, 2, 2}, sd::DataType::BOOL); - NDArray e('c', {3, 2, 2}, sd::DataType::BOOL); + NDArray x('c', {3, 1, 2}, sd::DataType::FLOAT32); + NDArray y('c', {2, 2}, sd::DataType::FLOAT32); + NDArray z('c', {3, 2, 2}, sd::DataType::BOOL); + NDArray e('c', {3, 2, 2}, sd::DataType::BOOL); - x.assign(1.f); - y.assign(2.f); - e.assign(false); + x.assign(1.f); + y.assign(2.f); + e.assign(false); - sd::ops::equals op; - Context context(1); + sd::ops::equals op; + Context context(1); - context.setCudaContext(LaunchContext::defaultContext()->getCudaStream(), LaunchContext::defaultContext()->getReductionPointer(), LaunchContext::defaultContext()->getAllocationPointer()); - context.setInputArray(0, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo()); - context.setInputArray(1, y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo()); + context.setCudaContext( + LaunchContext::defaultContext()->getCudaStream(), + LaunchContext::defaultContext()->getReductionPointer(), + LaunchContext::defaultContext()->getAllocationPointer()); + context.setInputArray(0, x.buffer(), x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo()); + context.setInputArray(1, y.buffer(), y.shapeInfo(), y.specialBuffer(), + y.specialShapeInfo()); - context.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo()); + context.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo()); - PointersManager pm(LaunchContext::defaultContext(), "test_DeclarableOp_execution_2"); - execCustomOp2(nullptr, op.getOpHash(), &context); + PointersManager pm(LaunchContext::defaultContext(), + "test_DeclarableOp_execution_2"); + execCustomOp2(nullptr, op.getOpHash(), &context); - pm.synchronize(); + pm.synchronize(); - ASSERT_EQ(e, z); + ASSERT_EQ(e, z); } - diff --git a/libnd4j/tests_cpu/layers_tests/JavaInteropTests.cpp b/libnd4j/tests_cpu/layers_tests/JavaInteropTests.cpp index 23080161af1a..c7ed836d4bc0 100644 --- a/libnd4j/tests_cpu/layers_tests/JavaInteropTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/JavaInteropTests.cpp @@ -18,450 +18,524 @@ // @author raver119@gmail.com // -#include #include +#include +#include +#include #include #include -#include -#include -#include "testlayers.h" + #include +#include "testlayers.h" + using namespace sd; using namespace sd::ops; class JavaInteropTests : public testing::Test { -public: - + public: }; - TEST_F(JavaInteropTests, TestShapeExposure1) { - auto input = NDArrayFactory::create('c', {1, 2, 5, 4}); - auto weights = NDArrayFactory::create('c', {2, 2, 2, 3}); - auto exp = NDArrayFactory::create('c', {1, 3, 5, 4}); + auto input = NDArrayFactory::create('c', {1, 2, 5, 4}); + auto weights = NDArrayFactory::create('c', {2, 2, 2, 3}); + auto exp = NDArrayFactory::create('c', {1, 3, 5, 4}); - sd::ops::conv2d op; + sd::ops::conv2d op; - std::vector tArgs({}); - std::vector iArgs({2, 2, 1, 1, 0, 0, 1, 1, 1}); + std::vector tArgs({}); + std::vector iArgs({2, 2, 1, 1, 0, 0, 1, 1, 1}); - Nd4jPointer ptrs[] = {(Nd4jPointer) input.shapeInfo(), (Nd4jPointer) weights.shapeInfo()}; + Nd4jPointer ptrs[] = {(Nd4jPointer)input.shapeInfo(), + (Nd4jPointer)weights.shapeInfo()}; - auto shapeList = calculateOutputShapes(nullptr, op.getOpHash(), ptrs, 2, tArgs.data(), tArgs.size(), iArgs.data(), iArgs.size()); + auto shapeList = + calculateOutputShapes(nullptr, op.getOpHash(), ptrs, 2, tArgs.data(), + tArgs.size(), iArgs.data(), iArgs.size()); - ASSERT_EQ(1, shapeList->size()); + ASSERT_EQ(1, shapeList->size()); - ASSERT_EQ(exp.rankOf(), shape::rank((Nd4jLong *)shapeList->at(0))); - ASSERT_EQ(exp.sizeAt(0), shape::shapeOf((Nd4jLong *)shapeList->at(0))[0]); - ASSERT_EQ(exp.sizeAt(1), shape::shapeOf((Nd4jLong *)shapeList->at(0))[1]); - ASSERT_EQ(exp.sizeAt(2), shape::shapeOf((Nd4jLong *)shapeList->at(0))[2]); - ASSERT_EQ(exp.sizeAt(3), shape::shapeOf((Nd4jLong *)shapeList->at(0))[3]); + ASSERT_EQ(exp.rankOf(), shape::rank((Nd4jLong *)shapeList->at(0))); + ASSERT_EQ(exp.sizeAt(0), shape::shapeOf((Nd4jLong *)shapeList->at(0))[0]); + ASSERT_EQ(exp.sizeAt(1), shape::shapeOf((Nd4jLong *)shapeList->at(0))[1]); + ASSERT_EQ(exp.sizeAt(2), shape::shapeOf((Nd4jLong *)shapeList->at(0))[2]); + ASSERT_EQ(exp.sizeAt(3), shape::shapeOf((Nd4jLong *)shapeList->at(0))[3]); - //int *ptr = (int *) shapeList[0]; - //delete[] ptr; - //delete shapeList; + // int *ptr = (int *) shapeList[0]; + // delete[] ptr; + // delete shapeList; - deleteShapeList((Nd4jPointer) shapeList); + deleteShapeList((Nd4jPointer)shapeList); } - TEST_F(JavaInteropTests, TestShapeExposure2) { - auto input = NDArrayFactory::create('c', {1, 2, 5, 4}); - auto exp = NDArrayFactory::create('c', {4}, {1, 2, 5, 4}); + auto input = NDArrayFactory::create('c', {1, 2, 5, 4}); + auto exp = NDArrayFactory::create('c', {4}, {1, 2, 5, 4}); - sd::ops::shape_of op; + sd::ops::shape_of op; - std::vector tArgs({}); - std::vector iArgs({}); + std::vector tArgs({}); + std::vector iArgs({}); + Nd4jPointer ptrs[] = {(Nd4jPointer)input.shapeInfo()}; - Nd4jPointer ptrs[] = {(Nd4jPointer) input.shapeInfo()}; + auto shapeList = + calculateOutputShapes(nullptr, op.getOpHash(), ptrs, 1, tArgs.data(), + tArgs.size(), iArgs.data(), iArgs.size()); - auto shapeList = calculateOutputShapes(nullptr, op.getOpHash(), ptrs, 1, tArgs.data(), tArgs.size(), iArgs.data(), iArgs.size()); + ASSERT_EQ(1, shapeList->size()); - ASSERT_EQ(1, shapeList->size()); + ASSERT_EQ(exp.rankOf(), shape::rank((Nd4jLong *)shapeList->at(0))); + ASSERT_EQ(exp.sizeAt(0), shape::shapeOf((Nd4jLong *)shapeList->at(0))[0]); - ASSERT_EQ(exp.rankOf(), shape::rank((Nd4jLong *)shapeList->at(0))); - ASSERT_EQ(exp.sizeAt(0), shape::shapeOf((Nd4jLong *)shapeList->at(0))[0]); - - deleteShapeList((Nd4jPointer) shapeList); + deleteShapeList((Nd4jPointer)shapeList); } TEST_F(JavaInteropTests, TestShapeExposure3) { - auto x = NDArrayFactory::create('c', {5, 30}); - auto sizes = NDArrayFactory::create('c', {3}, {4, 15, 11}); + auto x = NDArrayFactory::create('c', {5, 30}); + auto sizes = NDArrayFactory::create('c', {3}, {4, 15, 11}); - std::vector list0 = {0,0, 0,4}; - std::vector list1 = {0,0, 4,19}; - std::vector list2 = {0,0, 19,30}; + std::vector list0 = {0, 0, 0, 4}; + std::vector list1 = {0, 0, 4, 19}; + std::vector list2 = {0, 0, 19, 30}; - auto sub0 = x(list0, true); - auto sub1 = x(list1, true); - auto sub2 = x(list2, true); + auto sub0 = x(list0, true); + auto sub1 = x(list1, true); + auto sub2 = x(list2, true); - sub0.assign(0.0f); - sub1.assign(1.0f); - sub2.assign(2.0f); + sub0.assign(0.0f); + sub1.assign(1.0f); + sub2.assign(2.0f); - Nd4jPointer inputBuffers[] = {x.buffer(), sizes.buffer(), x.specialBuffer(), sizes.specialBuffer()}; - Nd4jPointer inputShapes[] = {(Nd4jPointer)x.shapeInfo(), (Nd4jPointer)sizes.shapeInfo(), (Nd4jPointer)x.specialShapeInfo(), (Nd4jPointer)sizes.specialShapeInfo()}; + Nd4jPointer inputBuffers[] = {x.buffer(), sizes.buffer(), x.specialBuffer(), + sizes.specialBuffer()}; + Nd4jPointer inputShapes[] = { + (Nd4jPointer)x.shapeInfo(), (Nd4jPointer)sizes.shapeInfo(), + (Nd4jPointer)x.specialShapeInfo(), (Nd4jPointer)sizes.specialShapeInfo()}; - sd::ops::split_v op; + sd::ops::split_v op; - Nd4jLong iArgs[] = {1}; - auto hash = op.getOpHash(); + Nd4jLong iArgs[] = {1}; + auto hash = op.getOpHash(); - auto shapeList = calculateOutputShapes2(nullptr, hash, inputBuffers, inputShapes, 2, nullptr, 0, iArgs, 1, nullptr, 0, nullptr, 0); + auto shapeList = + calculateOutputShapes2(nullptr, hash, inputBuffers, inputShapes, 2, + nullptr, 0, iArgs, 1, nullptr, 0, nullptr, 0); - ASSERT_EQ(3, shapeList->size()); + ASSERT_EQ(3, shapeList->size()); - ASSERT_TRUE(shape::equalsSoft(sub0.shapeInfo(), shapeList->at(0))); - ASSERT_TRUE(shape::equalsSoft(sub1.shapeInfo(), shapeList->at(1))); - ASSERT_TRUE(shape::equalsSoft(sub2.shapeInfo(), shapeList->at(2))); + ASSERT_TRUE(shape::equalsSoft(sub0.shapeInfo(), shapeList->at(0))); + ASSERT_TRUE(shape::equalsSoft(sub1.shapeInfo(), shapeList->at(1))); + ASSERT_TRUE(shape::equalsSoft(sub2.shapeInfo(), shapeList->at(2))); - deleteShapeList((Nd4jPointer) shapeList); + deleteShapeList((Nd4jPointer)shapeList); } TEST_F(JavaInteropTests, Test_Squeeze_1) { - auto x = NDArrayFactory::create('c', {1, 6}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); - auto z = NDArrayFactory::create('c', {6}); - auto e = NDArrayFactory::create('c', {6}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); - - sd::ops::squeeze op; - - Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) x.buffer(), x.specialBuffer()}; - Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) x.shapeInfo(), (Nd4jPointer)x.specialShapeInfo()}; - - Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) z.buffer(), z.specialBuffer()}; - Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) z.shapeInfo(), (Nd4jPointer)z.specialShapeInfo()}; - auto status = execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false); - ASSERT_EQ(Status::OK(), status); - - ASSERT_EQ(e, z); + auto x = NDArrayFactory::create('c', {1, 6}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); + auto z = NDArrayFactory::create('c', {6}); + auto e = + NDArrayFactory::create('c', {6}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); + + sd::ops::squeeze op; + + Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer)x.buffer(), x.specialBuffer()}; + Nd4jPointer ptrsInShapes[] = {(Nd4jPointer)x.shapeInfo(), + (Nd4jPointer)x.specialShapeInfo()}; + + Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer)z.buffer(), z.specialBuffer()}; + Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer)z.shapeInfo(), + (Nd4jPointer)z.specialShapeInfo()}; + auto status = execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, + ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, + nullptr, 0, nullptr, 0, nullptr, 0, false); + ASSERT_EQ(Status::OK(), status); + + ASSERT_EQ(e, z); } TEST_F(JavaInteropTests, Test_RDiv_1) { - auto x = NDArrayFactory::create('c', {3}, {2, 2, 2}); - auto y = NDArrayFactory::create('c', {3}, {4, 6, 8}); - auto z = NDArrayFactory::create('c', {3}); - auto e = NDArrayFactory::create('c', {3}, {2, 3, 4}); - - NDArray::prepareSpecialUse({&z}, {&x, &y}); - - sd::ops::reversedivide op; - - Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) x.buffer(), (Nd4jPointer) y.buffer(), x.specialBuffer(), y.specialBuffer()}; - Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) x.shapeInfo(), (Nd4jPointer) y.shapeInfo(), (Nd4jPointer)x.specialShapeInfo(), (Nd4jPointer)y.specialShapeInfo()}; - - - Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) z.buffer(), (Nd4jPointer)z.specialBuffer()}; - Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) z.shapeInfo(), (Nd4jPointer)z.specialShapeInfo()}; - auto status = execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false); - - NDArray::registerSpecialUse({&z}, {&x, &y}); - ASSERT_EQ(Status::OK(), status); - - ASSERT_EQ(e, z); + auto x = NDArrayFactory::create('c', {3}, {2, 2, 2}); + auto y = NDArrayFactory::create('c', {3}, {4, 6, 8}); + auto z = NDArrayFactory::create('c', {3}); + auto e = NDArrayFactory::create('c', {3}, {2, 3, 4}); + + NDArray::prepareSpecialUse({&z}, {&x, &y}); + + sd::ops::reversedivide op; + + Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer)x.buffer(), + (Nd4jPointer)y.buffer(), x.specialBuffer(), + y.specialBuffer()}; + Nd4jPointer ptrsInShapes[] = { + (Nd4jPointer)x.shapeInfo(), (Nd4jPointer)y.shapeInfo(), + (Nd4jPointer)x.specialShapeInfo(), (Nd4jPointer)y.specialShapeInfo()}; + + Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer)z.buffer(), + (Nd4jPointer)z.specialBuffer()}; + Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer)z.shapeInfo(), + (Nd4jPointer)z.specialShapeInfo()}; + auto status = execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, + ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, + nullptr, 0, nullptr, 0, nullptr, 0, false); + + NDArray::registerSpecialUse({&z}, {&x, &y}); + ASSERT_EQ(Status::OK(), status); + + ASSERT_EQ(e, z); } TEST_F(JavaInteropTests, TestSconv2d_1) { - auto input = NDArrayFactory::create('c', {3, 3, 8, 8}); - auto weightsD = NDArrayFactory::create('c', {1, 3, 1, 1}); - auto weightsP = NDArrayFactory::create('c', {2, 3, 1, 1}); - auto bias = NDArrayFactory::create('c', {2}); - auto output = NDArrayFactory::create('c', {3, 2, 8, 8}); - output.assign(0.0); - - input.linspace(1); - weightsD.linspace(1); - weightsP.linspace(1); - bias.linspace(1); - weightsD.permutei({2,3,1,0}); - weightsP.permutei({2,3,1,0}); - - auto expOutput = NDArrayFactory::create('c', {3, 2, 8, 8}); - - sd::ops::sconv2d op; - - NDArray::prepareSpecialUse({&output}, {&input, &weightsD, &weightsP, &bias}); - - Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) input.buffer(), (Nd4jPointer) weightsD.buffer(), (Nd4jPointer) weightsP.buffer(), (Nd4jPointer) bias.buffer(), (Nd4jPointer) input.specialBuffer(), (Nd4jPointer) weightsD.specialBuffer(), (Nd4jPointer) weightsP.specialBuffer(), (Nd4jPointer) bias.specialBuffer()}; - Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) input.shapeInfo(), (Nd4jPointer) weightsD.shapeInfo(), (Nd4jPointer) weightsP.shapeInfo(), (Nd4jPointer) bias.shapeInfo(), (Nd4jPointer) input.specialShapeInfo(), (Nd4jPointer) weightsD.specialShapeInfo(), (Nd4jPointer) weightsP.specialShapeInfo(), (Nd4jPointer) bias.specialShapeInfo()}; - - - Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) output.buffer(), (Nd4jPointer) output.specialBuffer()}; - Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) output.shapeInfo(), (Nd4jPointer) output.specialShapeInfo()}; - - Nd4jLong exp[] = {1, 1, 1, 1, 0, 0, 1, 1, 0, 0}; - - execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 4, ptrsOutBuffers, ptrsOutShapes, 1, - nullptr, 0, exp, 9, nullptr, 0, false); - - //output.printBuffer("output"); - NDArray::registerSpecialUse({&output}, {&input, &weightsD, &weightsP, &bias}); - - ASSERT_NEAR(1423, output.e(0), 1e-5); - //nd4j_printf("Iter %i passed...\n", e); + auto input = NDArrayFactory::create('c', {3, 3, 8, 8}); + auto weightsD = NDArrayFactory::create('c', {1, 3, 1, 1}); + auto weightsP = NDArrayFactory::create('c', {2, 3, 1, 1}); + auto bias = NDArrayFactory::create('c', {2}); + auto output = NDArrayFactory::create('c', {3, 2, 8, 8}); + output.assign(0.0); + + input.linspace(1); + weightsD.linspace(1); + weightsP.linspace(1); + bias.linspace(1); + weightsD.permutei({2, 3, 1, 0}); + weightsP.permutei({2, 3, 1, 0}); + + auto expOutput = NDArrayFactory::create('c', {3, 2, 8, 8}); + + sd::ops::sconv2d op; + + NDArray::prepareSpecialUse({&output}, {&input, &weightsD, &weightsP, &bias}); + + Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer)input.buffer(), + (Nd4jPointer)weightsD.buffer(), + (Nd4jPointer)weightsP.buffer(), + (Nd4jPointer)bias.buffer(), + (Nd4jPointer)input.specialBuffer(), + (Nd4jPointer)weightsD.specialBuffer(), + (Nd4jPointer)weightsP.specialBuffer(), + (Nd4jPointer)bias.specialBuffer()}; + Nd4jPointer ptrsInShapes[] = {(Nd4jPointer)input.shapeInfo(), + (Nd4jPointer)weightsD.shapeInfo(), + (Nd4jPointer)weightsP.shapeInfo(), + (Nd4jPointer)bias.shapeInfo(), + (Nd4jPointer)input.specialShapeInfo(), + (Nd4jPointer)weightsD.specialShapeInfo(), + (Nd4jPointer)weightsP.specialShapeInfo(), + (Nd4jPointer)bias.specialShapeInfo()}; + + Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer)output.buffer(), + (Nd4jPointer)output.specialBuffer()}; + Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer)output.shapeInfo(), + (Nd4jPointer)output.specialShapeInfo()}; + + Nd4jLong exp[] = {1, 1, 1, 1, 0, 0, 1, 1, 0, 0}; + + execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 4, + ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, exp, 9, nullptr, 0, + false); + + // output.printBuffer("output"); + NDArray::registerSpecialUse({&output}, {&input, &weightsD, &weightsP, &bias}); + + ASSERT_NEAR(1423, output.e(0), 1e-5); + // nd4j_printf("Iter %i passed...\n", e); } TEST_F(JavaInteropTests, TestSconv2d_2) { - auto input = NDArrayFactory::create('c', {3, 3, 8, 8}); - auto weightsD = NDArrayFactory::create('c', {1, 3, 1, 1}); - auto output = NDArrayFactory::create('c', {3, 3, 8, 8}); - output.assign(0.0); + auto input = NDArrayFactory::create('c', {3, 3, 8, 8}); + auto weightsD = NDArrayFactory::create('c', {1, 3, 1, 1}); + auto output = NDArrayFactory::create('c', {3, 3, 8, 8}); + output.assign(0.0); - input.linspace(1); - weightsD.linspace(1); - weightsD.permutei({2,3,1,0}); + input.linspace(1); + weightsD.linspace(1); + weightsD.permutei({2, 3, 1, 0}); - auto expOutput = NDArrayFactory::create('c', {3, 3, 8, 8}); + auto expOutput = NDArrayFactory::create('c', {3, 3, 8, 8}); - sd::ops::sconv2d op; + sd::ops::sconv2d op; - NDArray::prepareSpecialUse({&output}, {&input, &weightsD}); + NDArray::prepareSpecialUse({&output}, {&input, &weightsD}); - Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) input.buffer(), (Nd4jPointer) weightsD.buffer(), input.specialBuffer(), weightsD.specialBuffer()}; - Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) input.shapeInfo(), (Nd4jPointer) weightsD.shapeInfo(), (Nd4jPointer)input.specialShapeInfo(), (Nd4jPointer)weightsD.specialShapeInfo()}; + Nd4jPointer ptrsInBuffer[] = { + (Nd4jPointer)input.buffer(), (Nd4jPointer)weightsD.buffer(), + input.specialBuffer(), weightsD.specialBuffer()}; + Nd4jPointer ptrsInShapes[] = {(Nd4jPointer)input.shapeInfo(), + (Nd4jPointer)weightsD.shapeInfo(), + (Nd4jPointer)input.specialShapeInfo(), + (Nd4jPointer)weightsD.specialShapeInfo()}; + Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer)output.buffer(), + output.specialBuffer()}; + Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer)output.shapeInfo(), + (Nd4jPointer)output.specialShapeInfo()}; - Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) output.buffer(), output.specialBuffer()}; - Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) output.shapeInfo(), (Nd4jPointer)output.specialShapeInfo()}; + Nd4jLong exp[] = {1, 1, 1, 1, 0, 0, 1, 1, 0}; - Nd4jLong exp[] = {1, 1, 1, 1, 0, 0, 1, 1, 0}; + execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, + ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, exp, 9, nullptr, 0, + false); - execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, exp, 9, nullptr, 0, false); + NDArray::registerSpecialUse({&output}, {&input, &weightsD}); - NDArray::registerSpecialUse({&output}, {&input, &weightsD}); - - ASSERT_NEAR(1, output.e(0), 1e-5); + ASSERT_NEAR(1, output.e(0), 1e-5); } - TEST_F(JavaInteropTests, TestMaxPooling2d_1) { - auto input = NDArrayFactory::create('c', {1, 2, 4, 5}); - auto output = NDArrayFactory::create('c', {1, 2, 4, 5}); - input.linspace(1); + auto input = NDArrayFactory::create('c', {1, 2, 4, 5}); + auto output = NDArrayFactory::create('c', {1, 2, 4, 5}); + input.linspace(1); - NDArray::prepareSpecialUse({&output}, {&input}); + NDArray::prepareSpecialUse({&output}, {&input}); - Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) input.buffer(), input.specialBuffer()}; - Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) input.shapeInfo(), (Nd4jPointer)input.specialShapeInfo()}; + Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer)input.buffer(), + input.specialBuffer()}; + Nd4jPointer ptrsInShapes[] = {(Nd4jPointer)input.shapeInfo(), + (Nd4jPointer)input.specialShapeInfo()}; - Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) output.buffer(), output.specialBuffer()}; - Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) output.shapeInfo(), (Nd4jPointer)output.specialShapeInfo()}; + Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer)output.buffer(), + output.specialBuffer()}; + Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer)output.shapeInfo(), + (Nd4jPointer)output.specialShapeInfo()}; - std::vector iArgs({2, 2, 1, 1, 0, 0, 1, 1, 1}); + std::vector iArgs({2, 2, 1, 1, 0, 0, 1, 1, 1}); - sd::ops::maxpool2d op; + sd::ops::maxpool2d op; - Nd4jStatus status = execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, iArgs.data(), 9, nullptr, 0, false); - - NDArray::registerSpecialUse({&output}, {&input}); - ASSERT_EQ(ND4J_STATUS_OK, status); + Nd4jStatus status = execCustomOp( + nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, + ptrsOutShapes, 1, nullptr, 0, iArgs.data(), 9, nullptr, 0, false); + NDArray::registerSpecialUse({&output}, {&input}); + ASSERT_EQ(ND4J_STATUS_OK, status); } TEST_F(JavaInteropTests, TestCol2Im_1) { - /* - o.d.n.l.c.ConvolutionLayer - eps shape: [6, 1, 2, 2, 2, 4, 5, 160, 4, 2, 1, 40, 8, 0, -1, 99] - o.d.n.l.c.ConvolutionLayer - epsNext shape: [4, 1, 2, 4, 5, 20, 20, 5, 1, 0, 1, 99] - o.d.n.l.c.ConvolutionLayer - Strides: [1, 1] - o.d.n.l.c.ConvolutionLayer - Padding: [0, 0] - o.d.n.l.c.ConvolutionLayer - Input: [4,5] - o.d.n.l.c.ConvolutionLayer - Dilation: [1, 1] - */ - auto input = NDArrayFactory::create('c', {1, 2, 2, 2, 4, 5}); - auto output = NDArrayFactory::create('c', {1, 2, 4, 5}); - input.linspace(1); + /* + o.d.n.l.c.ConvolutionLayer - eps shape: [6, 1, 2, 2, 2, 4, 5, 160, 4, 2, + 1, 40, 8, 0, -1, 99] o.d.n.l.c.ConvolutionLayer - epsNext shape: [4, 1, 2, + 4, 5, 20, 20, 5, 1, 0, 1, 99] o.d.n.l.c.ConvolutionLayer - Strides: [1, 1] + o.d.n.l.c.ConvolutionLayer - Padding: [0, 0] + o.d.n.l.c.ConvolutionLayer - Input: [4,5] + o.d.n.l.c.ConvolutionLayer - Dilation: [1, 1] + */ + auto input = NDArrayFactory::create('c', {1, 2, 2, 2, 4, 5}); + auto output = NDArrayFactory::create('c', {1, 2, 4, 5}); + input.linspace(1); - NDArray::prepareSpecialUse({&output}, {&input}); + NDArray::prepareSpecialUse({&output}, {&input}); - Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) input.buffer(), input.specialBuffer()}; - Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) input.shapeInfo(), (Nd4jPointer)input.specialShapeInfo()}; + Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer)input.buffer(), + input.specialBuffer()}; + Nd4jPointer ptrsInShapes[] = {(Nd4jPointer)input.shapeInfo(), + (Nd4jPointer)input.specialShapeInfo()}; - Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) output.buffer(), output.specialBuffer()}; - Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) output.shapeInfo(), (Nd4jPointer)output.specialShapeInfo()}; + Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer)output.buffer(), + output.specialBuffer()}; + Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer)output.shapeInfo(), + (Nd4jPointer)output.specialShapeInfo()}; - sd::ops::col2im op; + sd::ops::col2im op; - Nd4jLong exp[] = {1, 1, 1, 1, 4, 5, 1, 1, 1}; + Nd4jLong exp[] = {1, 1, 1, 1, 4, 5, 1, 1, 1}; - auto hash = op.getOpHash(); + auto hash = op.getOpHash(); - execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, exp, 9, nullptr, 0, false); + execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, + ptrsOutShapes, 1, nullptr, 0, exp, 9, nullptr, 0, false); - NDArray::registerSpecialUse({&output}, {&input}); + NDArray::registerSpecialUse({&output}, {&input}); - ASSERT_TRUE(output.meanNumber().e(0) > 0.0f); + ASSERT_TRUE(output.meanNumber().e(0) > 0.0f); } TEST_F(JavaInteropTests, TestPNorm_1) { - /* - o.d.n.l.c.s.SubsamplingLayer - input: [4, 1, 3, 4, 4, 16, 16, 4, 1, 0, 1, 99] - o.d.n.l.c.s.SubsamplingLayer - output: [4, 1, 3, 3, 3, 27, 9, 3, 1, 0, 1, 99] - o.d.n.l.c.s.SubsamplingLayer - Kernel: [2, 2] - o.d.n.l.c.s.SubsamplingLayer - Strides: [1, 1] - o.d.n.l.c.s.SubsamplingLayer - Pad: [0, 0] - o.d.n.l.c.s.SubsamplingLayer - Dilation: [1, 1] - o.d.n.l.c.s.SubsamplingLayer - Same: false - o.d.n.l.c.s.SubsamplingLayer - pnorm: 2 - */ - auto input = NDArrayFactory::create('c', {1, 3, 4, 4}); - auto output = NDArrayFactory::create('c', {1, 3, 3, 3}); - input.linspace(1); - - NDArray::prepareSpecialUse({&output}, {&input}); - - sd::ops::pnormpool2d op; - - Nd4jLong exp[] = {2, 2, 1, 1, 0, 0, 1, 1, 0, 2, 0, 0}; - - Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) input.buffer(), input.specialBuffer()}; - Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) input.shapeInfo(), (Nd4jPointer)input.specialShapeInfo()}; - - Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) output.buffer(), output.specialBuffer()}; - Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) output.shapeInfo(), (Nd4jPointer)output.specialShapeInfo()}; - - - execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, exp, 11, nullptr, 0, false); - - NDArray::registerSpecialUse({&output}, {&input}); - - ASSERT_TRUE(output.meanNumber().e(0) > 0.0); + /* + o.d.n.l.c.s.SubsamplingLayer - input: [4, 1, 3, 4, 4, 16, 16, 4, 1, 0, 1, + 99] o.d.n.l.c.s.SubsamplingLayer - output: [4, 1, 3, 3, 3, 27, 9, 3, 1, 0, + 1, 99] o.d.n.l.c.s.SubsamplingLayer - Kernel: [2, 2] + o.d.n.l.c.s.SubsamplingLayer - Strides: [1, 1] + o.d.n.l.c.s.SubsamplingLayer - Pad: [0, 0] + o.d.n.l.c.s.SubsamplingLayer - Dilation: [1, 1] + o.d.n.l.c.s.SubsamplingLayer - Same: false + o.d.n.l.c.s.SubsamplingLayer - pnorm: 2 + */ + auto input = NDArrayFactory::create('c', {1, 3, 4, 4}); + auto output = NDArrayFactory::create('c', {1, 3, 3, 3}); + input.linspace(1); + + NDArray::prepareSpecialUse({&output}, {&input}); + + sd::ops::pnormpool2d op; + + Nd4jLong exp[] = {2, 2, 1, 1, 0, 0, 1, 1, 0, 2, 0, 0}; + + Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer)input.buffer(), + input.specialBuffer()}; + Nd4jPointer ptrsInShapes[] = {(Nd4jPointer)input.shapeInfo(), + (Nd4jPointer)input.specialShapeInfo()}; + + Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer)output.buffer(), + output.specialBuffer()}; + Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer)output.shapeInfo(), + (Nd4jPointer)output.specialShapeInfo()}; + + execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, + ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, exp, 11, nullptr, + 0, false); + + NDArray::registerSpecialUse({&output}, {&input}); + + ASSERT_TRUE(output.meanNumber().e(0) > 0.0); } - TEST_F(JavaInteropTests, TestInplace_1) { - auto input = NDArrayFactory::create('c', {10, 10}); - //auto exp('c', {10, 10}); - input.linspace(1); - - NDArray::prepareSpecialUse({}, {&input}); + auto input = NDArrayFactory::create('c', {10, 10}); + // auto exp('c', {10, 10}); + input.linspace(1); - sd::ops::clipbyvalue op; + NDArray::prepareSpecialUse({}, {&input}); - double extras[] = {-1.0f, 1.0f}; + sd::ops::clipbyvalue op; - Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) input.buffer(), input.specialBuffer()}; - Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) input.shapeInfo(), (Nd4jPointer)input.specialShapeInfo()}; + double extras[] = {-1.0f, 1.0f}; + Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer)input.buffer(), + input.specialBuffer()}; + Nd4jPointer ptrsInShapes[] = {(Nd4jPointer)input.shapeInfo(), + (Nd4jPointer)input.specialShapeInfo()}; - Nd4jStatus result = execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, nullptr, nullptr, 0, extras, 2, nullptr, 0, nullptr, 0, true); + Nd4jStatus result = execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, + ptrsInShapes, 1, nullptr, nullptr, 0, extras, + 2, nullptr, 0, nullptr, 0, true); - NDArray::registerSpecialUse({}, {&input}); + NDArray::registerSpecialUse({}, {&input}); - ASSERT_EQ(ND4J_STATUS_OK, result); + ASSERT_EQ(ND4J_STATUS_OK, result); - ASSERT_NEAR(1.0, input.meanNumber().e(0), 1e-5); + ASSERT_NEAR(1.0, input.meanNumber().e(0), 1e-5); } TEST_F(JavaInteropTests, Test_Synonyms_1) { - auto op = OpRegistrator::getInstance().getOperation("RDiv"); - auto opRef = OpRegistrator::getInstance().getOperation("reversedivide"); - std::string nameExp("reversedivide"); + auto op = OpRegistrator::getInstance().getOperation("RDiv"); + auto opRef = OpRegistrator::getInstance().getOperation("reversedivide"); + std::string nameExp("reversedivide"); - ASSERT_TRUE(op != nullptr); - ASSERT_TRUE(opRef != nullptr); + ASSERT_TRUE(op != nullptr); + ASSERT_TRUE(opRef != nullptr); - std::string name = *(op->getOpName()); - std::string nameRef = *(opRef->getOpName()); + std::string name = op->getOpName(); + std::string nameRef = opRef->getOpName(); - ASSERT_EQ(nameExp, nameRef); - ASSERT_EQ(nameRef, name); + ASSERT_EQ(nameExp, nameRef); + ASSERT_EQ(nameRef, name); } TEST_F(JavaInteropTests, Test_Synonyms_2) { - auto op = OpRegistrator::getInstance().getOperation("RDiv"); - auto opRef = OpRegistrator::getInstance().getOperation("reversedivide"); - std::string nameExp("reversedivide"); + auto op = OpRegistrator::getInstance().getOperation("RDiv"); + auto opRef = OpRegistrator::getInstance().getOperation("reversedivide"); + std::string nameExp("reversedivide"); - ASSERT_TRUE(op != nullptr); - ASSERT_TRUE(opRef != nullptr); + ASSERT_TRUE(op != nullptr); + ASSERT_TRUE(opRef != nullptr); - std::string name = *(op->getOpName()); - std::string nameRef = *(opRef->getOpName()); + std::string name = op->getOpName(); + std::string nameRef = opRef->getOpName(); - ASSERT_EQ(nameExp, nameRef); - ASSERT_EQ(nameRef, name); + ASSERT_EQ(nameExp, nameRef); + ASSERT_EQ(nameRef, name); } TEST_F(JavaInteropTests, Test_Synonyms_3) { - auto op = OpRegistrator::getInstance().getOperation("RDiv"); - auto opRef = OpRegistrator::getInstance().getOperation("reversedivide"); - std::string nameExp("reversedivide"); + auto op = OpRegistrator::getInstance().getOperation("RDiv"); + auto opRef = OpRegistrator::getInstance().getOperation("reversedivide"); + std::string nameExp("reversedivide"); - ASSERT_TRUE(op != nullptr); - ASSERT_TRUE(opRef != nullptr); + ASSERT_TRUE(op != nullptr); + ASSERT_TRUE(opRef != nullptr); - std::string name = *(op->getOpName()); - std::string nameRef = *(opRef->getOpName()); + std::string name = op->getOpName(); + std::string nameRef = opRef->getOpName(); - ASSERT_EQ(nameExp, nameRef); - ASSERT_EQ(nameRef, name); + ASSERT_EQ(nameExp, nameRef); + ASSERT_EQ(nameRef, name); } TEST_F(JavaInteropTests, Test_FastPath_Validation_1) { - auto x = NDArrayFactory::create('c', {4}, {1, 2, 3, 4}); - auto z = NDArrayFactory::create('c', {4}, {1, 2, 3, 4}); - - Context ctx(1); - ctx.setInputArray(0, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo()); - ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo()); - - sd::ops::softmax op; - auto status = op.execute(&ctx); - ASSERT_NE(Status::OK(), status); + auto x = NDArrayFactory::create('c', {4}, {1, 2, 3, 4}); + auto z = NDArrayFactory::create('c', {4}, {1, 2, 3, 4}); + + Context ctx(1); + ctx.setInputArray(0, x.buffer(), x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo()); + ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo()); + + sd::ops::softmax op; + auto status = op.execute(&ctx); + ASSERT_NE(Status::OK(), status); } TEST_F(JavaInteropTests, Test_FastPath_Validation_2) { - auto x = NDArrayFactory::create('c', {4}, {1.f, 2.f, 3.f, 4.f}); - auto z = NDArrayFactory::create('c', {4}, {1, 2, 3, 4}); - - Context ctx(1); - ctx.setInputArray(0, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo()); - ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo()); - - sd::ops::softmax op; - auto status = op.execute(&ctx); - ASSERT_NE(Status::OK(), status); + auto x = NDArrayFactory::create('c', {4}, {1.f, 2.f, 3.f, 4.f}); + auto z = NDArrayFactory::create('c', {4}, {1, 2, 3, 4}); + + Context ctx(1); + ctx.setInputArray(0, x.buffer(), x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo()); + ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo()); + + sd::ops::softmax op; + auto status = op.execute(&ctx); + ASSERT_NE(Status::OK(), status); } TEST_F(JavaInteropTests, Test_FastPath_Validation_3) { - auto x = NDArrayFactory::create('c', {3, 5}, { 0.7788f, 0.8012f, 0.7244f, 0.2309f, 0.7271f, - 0.1804f, 0.5056f, 0.8925f, 0.5461f, 0.9234f, - 0.0856f, 0.7938f, 0.6591f, 0.5555f, 0.1596f}); - - auto min = NDArrayFactory::create({ -0.2283f, -0.0719f, -0.0154f, -0.5162f, -0.3567f}); - auto max = NDArrayFactory::create({ 0.9441f, 0.5957f, 0.8669f, 0.3502f, 0.5100f}); - - auto z = NDArrayFactory::create('c', {3, 5}); - - Context ctx(1); - ctx.setInputArray(0, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo()); - ctx.setInputArray(1, min.buffer(), min.shapeInfo(), min.specialBuffer(), min.specialShapeInfo()); - ctx.setInputArray(2, max.buffer(), max.shapeInfo(), max.specialBuffer(), max.specialShapeInfo()); - ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo()); - - sd::ops::fake_quant_with_min_max_vars_per_channel op; - ASSERT_ANY_THROW(op.execute(&ctx)); + auto x = NDArrayFactory::create( + 'c', {3, 5}, + {0.7788f, 0.8012f, 0.7244f, 0.2309f, 0.7271f, 0.1804f, 0.5056f, 0.8925f, + 0.5461f, 0.9234f, 0.0856f, 0.7938f, 0.6591f, 0.5555f, 0.1596f}); + + auto min = NDArrayFactory::create( + {-0.2283f, -0.0719f, -0.0154f, -0.5162f, -0.3567f}); + auto max = NDArrayFactory::create( + {0.9441f, 0.5957f, 0.8669f, 0.3502f, 0.5100f}); + + auto z = NDArrayFactory::create('c', {3, 5}); + + Context ctx(1); + ctx.setInputArray(0, x.buffer(), x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo()); + ctx.setInputArray(1, min.buffer(), min.shapeInfo(), min.specialBuffer(), + min.specialShapeInfo()); + ctx.setInputArray(2, max.buffer(), max.shapeInfo(), max.specialBuffer(), + max.specialShapeInfo()); + ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo()); + + sd::ops::fake_quant_with_min_max_vars_per_channel op; + ASSERT_ANY_THROW(op.execute(&ctx)); } TEST_F(JavaInteropTests, Test_empty_cast_1) { - auto x = NDArrayFactory::create('c', {1, 0, 2}); - auto z = NDArrayFactory::create('c', {1, 0, 2}); - auto e = NDArrayFactory::create('c', {1, 0, 2}); - - Nd4jLong iArgs[] = {10}; - - Context ctx(1); - ctx.setInputArray(0, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo()); - ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo()); - ctx.setIArguments(iArgs, 1); - - sd::ops::cast op; - auto result = op.execute(&ctx); - ASSERT_EQ(Status::OK(), result); - ASSERT_EQ(e, z); + auto x = NDArrayFactory::create('c', {1, 0, 2}); + auto z = NDArrayFactory::create('c', {1, 0, 2}); + auto e = NDArrayFactory::create('c', {1, 0, 2}); + + Nd4jLong iArgs[] = {10}; + + Context ctx(1); + ctx.setInputArray(0, x.buffer(), x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo()); + ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo()); + ctx.setIArguments(iArgs, 1); + + sd::ops::cast op; + auto result = op.execute(&ctx); + ASSERT_EQ(Status::OK(), result); + ASSERT_EQ(e, z); } /* @@ -483,12 +557,16 @@ TEST_F(JavaInteropTests, test_avgpooling_edge_1) { Nd4jLong exp[] = {3,3, 1,1, 0,0, 1,1, 1, 0, 1}; Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) x.buffer(), x.specialBuffer()}; - Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) x.shapeInfo(), x.specialShapeInfo()}; + Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) x.shapeInfo(), +x.specialShapeInfo()}; - Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) z.buffer(), z.specialBuffer()}; - Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) z.shapeInfo(), z.special()}; + Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) z.buffer(), +z.specialBuffer()}; Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) z.shapeInfo(), +z.special()}; - auto result = execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, exp, 11, nullptr, 0, false); + auto result = execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, +ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, exp, 11, nullptr, +0, false); NDArray::registerSpecialUse({&z}, {&x}); @@ -549,8 +627,8 @@ TEST_F(JavaInteropTests, test_avgpooling_edge_1) { auto _z = z.e(e); auto eq = sd::math::nd4j_eq(_m, _z, 1e-5); if (!eq) { - nd4j_printf("Difference at element e [%i]: <%f> vs <%f>\n", e, _m, _z); - cnt++; + nd4j_printf("Difference at element e [%i]: <%f> vs <%f>\n", e, _m, +_z); cnt++; } } @@ -559,7 +637,8 @@ TEST_F(JavaInteropTests, test_avgpooling_edge_1) { TEST_F(JavaInteropTests, Test_GraphReuse_1) { - uint8_t* data = sd::graph::readFlatBuffers("./resources/reduce_dim_false.fb"); + uint8_t* data = +sd::graph::readFlatBuffers("./resources/reduce_dim_false.fb"); registerGraph(nullptr, 119, (Nd4jPointer) data); @@ -581,8 +660,9 @@ TEST_F(JavaInteropTests, Test_GraphReuse_2) { auto exp1 = NDArrayFactory::create('c', {3}, {6, 6, 6}); auto exp2 = NDArrayFactory::create('c', {3}, {9, 9, 9}); - // we load graph from file, because we're not in java here, and dont have buffer ready - uint8_t* data = sd::graph::readFlatBuffers("./resources/reduce_dim_false.fb"); + // we load graph from file, because we're not in java here, and dont have +buffer ready uint8_t* data = +sd::graph::readFlatBuffers("./resources/reduce_dim_false.fb"); // we ensure that there's no such a graph stored earlier ASSERT_FALSE(GraphHolder::getInstance().hasGraph(119)); @@ -605,10 +685,9 @@ TEST_F(JavaInteropTests, Test_GraphReuse_2) { Nd4jPointer inputs_0[] = {(Nd4jPointer) input_0.buffer()}; Nd4jPointer shapes_0[] = {(Nd4jPointer) input_0.shapeInfo()}; - // now we're executing stored graph and providing replacement for input variable - auto res_0 = executeStoredGraph(nullptr, 119, inputs_0, shapes_0, idx, 1); - ASSERT_EQ(ND4J_STATUS_OK, res_0->status()); - ASSERT_EQ(1, res_0->size()); + // now we're executing stored graph and providing replacement for input +variable auto res_0 = executeStoredGraph(nullptr, 119, inputs_0, shapes_0, idx, +1); ASSERT_EQ(ND4J_STATUS_OK, res_0->status()); ASSERT_EQ(1, res_0->size()); auto z0 = res_0->at(0)->getNDArray(); ASSERT_TRUE(exp0.isSameShape(z0)); @@ -658,197 +737,247 @@ TEST_F(JavaInteropTests, Test_GraphReuse_2) { */ TEST_F(JavaInteropTests, Test_Greater_1) { - auto x = NDArrayFactory::create('c', {2, 2}, {1, 2, 1, 2}); - auto y = NDArrayFactory::create('c', {2, 2}, {1, 2, 0, 0}); -// auto o = NDArrayFactory::create('c', {2, 2}, {3, 3, 3, 3}); - auto o = NDArrayFactory::create('c', {2, 2}, {true, true, true, true}); + auto x = NDArrayFactory::create('c', {2, 2}, {1, 2, 1, 2}); + auto y = NDArrayFactory::create('c', {2, 2}, {1, 2, 0, 0}); + // auto o = NDArrayFactory::create('c', {2, 2}, {3, 3, 3, 3}); + auto o = NDArrayFactory::create('c', {2, 2}, {true, true, true, true}); - auto exp = NDArrayFactory::create('c', {2, 2}, {false, false, true, true}); + auto exp = + NDArrayFactory::create('c', {2, 2}, {false, false, true, true}); - NDArray::prepareSpecialUse({&o}, {&x, &y}); + NDArray::prepareSpecialUse({&o}, {&x, &y}); - sd::ops::greater op; + sd::ops::greater op; - Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) x.buffer(), (Nd4jPointer) y.buffer(), x.specialBuffer(), y.specialBuffer()}; - Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) x.shapeInfo(), (Nd4jPointer) y.shapeInfo(), (Nd4jPointer)x.specialShapeInfo(), (Nd4jPointer)y.specialShapeInfo()}; + Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer)x.buffer(), + (Nd4jPointer)y.buffer(), x.specialBuffer(), + y.specialBuffer()}; + Nd4jPointer ptrsInShapes[] = { + (Nd4jPointer)x.shapeInfo(), (Nd4jPointer)y.shapeInfo(), + (Nd4jPointer)x.specialShapeInfo(), (Nd4jPointer)y.specialShapeInfo()}; - Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) o.buffer(), o.specialBuffer()}; - Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) o.shapeInfo(), (Nd4jPointer)o.specialShapeInfo()}; + Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer)o.buffer(), o.specialBuffer()}; + Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer)o.shapeInfo(), + (Nd4jPointer)o.specialShapeInfo()}; - execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false); + execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, + ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, + nullptr, 0, false); - NDArray::registerSpecialUse({&o}, {&x, &y}); - ASSERT_TRUE(exp.equalsTo(&o)); + NDArray::registerSpecialUse({&o}, {&x, &y}); + ASSERT_TRUE(exp.equalsTo(&o)); } - TEST_F(JavaInteropTests, Test_Greater_2) { - auto x = NDArrayFactory::create('c', {2, 2}, {1.f, 2.f, 1.f, 2.f}); - auto y = NDArrayFactory::create('c', {2, 2}, {1.f, 2.f, 0.f, 0.f}); - auto o = NDArrayFactory::create('c', {2, 2}, {true, true, true, true}); + auto x = NDArrayFactory::create('c', {2, 2}, {1.f, 2.f, 1.f, 2.f}); + auto y = NDArrayFactory::create('c', {2, 2}, {1.f, 2.f, 0.f, 0.f}); + auto o = NDArrayFactory::create('c', {2, 2}, {true, true, true, true}); - auto exp = NDArrayFactory::create('c', {2, 2}, {false, false, true, true}); + auto exp = + NDArrayFactory::create('c', {2, 2}, {false, false, true, true}); - sd::ops::greater op; + sd::ops::greater op; - NDArray::prepareSpecialUse({&o}, {&x, &y}); + NDArray::prepareSpecialUse({&o}, {&x, &y}); - Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) x.buffer(), (Nd4jPointer) y.buffer(), x.specialBuffer(), y.specialBuffer()}; - Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) x.shapeInfo(), (Nd4jPointer) y.shapeInfo(), (Nd4jPointer)x.specialShapeInfo(), (Nd4jPointer)y.specialShapeInfo()}; + Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer)x.buffer(), + (Nd4jPointer)y.buffer(), x.specialBuffer(), + y.specialBuffer()}; + Nd4jPointer ptrsInShapes[] = { + (Nd4jPointer)x.shapeInfo(), (Nd4jPointer)y.shapeInfo(), + (Nd4jPointer)x.specialShapeInfo(), (Nd4jPointer)y.specialShapeInfo()}; - Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) o.buffer(), o.specialBuffer()}; - Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) o.shapeInfo(), (Nd4jPointer)o.specialShapeInfo()}; + Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer)o.buffer(), o.specialBuffer()}; + Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer)o.shapeInfo(), + (Nd4jPointer)o.specialShapeInfo()}; - execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false); + execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, + ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, + nullptr, 0, false); - NDArray::registerSpecialUse({&o}, {&x, &y}); + NDArray::registerSpecialUse({&o}, {&x, &y}); - ASSERT_TRUE(exp.equalsTo(&o)); + ASSERT_TRUE(exp.equalsTo(&o)); } TEST_F(JavaInteropTests, Test_Boolean_Op_1) { + sd::ops::is_non_decreasing op; - sd::ops::is_non_decreasing op; - - auto x = NDArrayFactory::create('c', {5}, {1.f, 2.f, 3.f, 4.f, 5.f}); - auto o = NDArrayFactory::create(false); - auto exp = NDArrayFactory::create(1); + auto x = NDArrayFactory::create('c', {5}, {1.f, 2.f, 3.f, 4.f, 5.f}); + auto o = NDArrayFactory::create(false); + auto exp = NDArrayFactory::create(1); - NDArray::prepareSpecialUse({&o}, {&x}); + NDArray::prepareSpecialUse({&o}, {&x}); - Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) x.buffer(), x.specialBuffer()}; - Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) x.shapeInfo(), (Nd4jPointer)x.specialShapeInfo()}; + Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer)x.buffer(), x.specialBuffer()}; + Nd4jPointer ptrsInShapes[] = {(Nd4jPointer)x.shapeInfo(), + (Nd4jPointer)x.specialShapeInfo()}; - Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) o.buffer(), o.specialBuffer()}; - Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) o.shapeInfo(), (Nd4jPointer)o.specialShapeInfo()}; + Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer)o.buffer(), o.specialBuffer()}; + Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer)o.shapeInfo(), + (Nd4jPointer)o.specialShapeInfo()}; - auto hash = op.getOpHash(); - auto status = execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false); + auto hash = op.getOpHash(); + auto status = + execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, + ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false); - NDArray::registerSpecialUse({&o}, {&x}); - ASSERT_EQ(Status::OK(), status); + NDArray::registerSpecialUse({&o}, {&x}); + ASSERT_EQ(Status::OK(), status); - ASSERT_TRUE(exp.equalsTo(&o)); + ASSERT_TRUE(exp.equalsTo(&o)); } - TEST_F(JavaInteropTests, Test_Inplace_Outputs_1) { - auto x = NDArrayFactory::create('c', {2, 3}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); - auto exp = NDArrayFactory::create('c', {2, 3}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); - auto z = NDArrayFactory::create('c', {2, 3}); + auto x = NDArrayFactory::create('c', {2, 3}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); + auto exp = NDArrayFactory::create('c', {2, 3}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); + auto z = NDArrayFactory::create('c', {2, 3}); - sd::ops::test_output_reshape op; + sd::ops::test_output_reshape op; - NDArray::prepareSpecialUse({&z}, {&x}); + NDArray::prepareSpecialUse({&z}, {&x}); - Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) x.buffer(), x.specialBuffer()}; - Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) x.shapeInfo(), (Nd4jPointer)x.specialShapeInfo()}; + Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer)x.buffer(), x.specialBuffer()}; + Nd4jPointer ptrsInShapes[] = {(Nd4jPointer)x.shapeInfo(), + (Nd4jPointer)x.specialShapeInfo()}; - Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) z.buffer(), z.specialBuffer()}; - Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) z.shapeInfo(), (Nd4jPointer)z.specialShapeInfo()}; + Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer)z.buffer(), z.specialBuffer()}; + Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer)z.shapeInfo(), + (Nd4jPointer)z.specialShapeInfo()}; - auto hash = op.getOpHash(); - auto status = execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false); + auto hash = op.getOpHash(); + auto status = + execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, + ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false); - NDArray::registerSpecialUse({&z}, {&x}); - ASSERT_EQ(Status::OK(), status); + NDArray::registerSpecialUse({&z}, {&x}); + ASSERT_EQ(Status::OK(), status); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(JavaInteropTests, Test_Inplace_Outputs_2) { - auto x = NDArrayFactory::create('c', {2, 3}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); - auto y = NDArrayFactory::create(2.0f); - auto z = NDArrayFactory::create('f', {2, 3}); - auto e = NDArrayFactory::create('c', {2, 3}, {3.f, 4.f, 5.f, 6.f, 7.f, 8.f}); - - - sd::ops::add op; - - NDArray::prepareSpecialUse({&z}, {&x, &y}); - - Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) x.buffer(), (Nd4jPointer) y.buffer(), x.specialBuffer(), y.specialBuffer()}; - Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) x.shapeInfo(), (Nd4jPointer) y.shapeInfo(), (Nd4jPointer)x.specialShapeInfo(), (Nd4jPointer)y.specialShapeInfo()}; - - Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) z.buffer(), z.specialBuffer()}; - Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) z.shapeInfo(), (Nd4jPointer)z.specialShapeInfo()}; - - auto hash = op.getOpHash(); - auto status = execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false); - - NDArray::prepareSpecialUse({&z}, {&x, &y}); - ASSERT_EQ(Status::OK(), status); - - ASSERT_TRUE(e.isSameShape(z)); - ASSERT_TRUE(e.equalsTo(z)); - ASSERT_FALSE(e.ordering() == z.ordering()); + auto x = NDArrayFactory::create('c', {2, 3}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); + auto y = NDArrayFactory::create(2.0f); + auto z = NDArrayFactory::create('f', {2, 3}); + auto e = NDArrayFactory::create('c', {2, 3}, + {3.f, 4.f, 5.f, 6.f, 7.f, 8.f}); + + sd::ops::add op; + + NDArray::prepareSpecialUse({&z}, {&x, &y}); + + Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer)x.buffer(), + (Nd4jPointer)y.buffer(), x.specialBuffer(), + y.specialBuffer()}; + Nd4jPointer ptrsInShapes[] = { + (Nd4jPointer)x.shapeInfo(), (Nd4jPointer)y.shapeInfo(), + (Nd4jPointer)x.specialShapeInfo(), (Nd4jPointer)y.specialShapeInfo()}; + + Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer)z.buffer(), z.specialBuffer()}; + Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer)z.shapeInfo(), + (Nd4jPointer)z.specialShapeInfo()}; + + auto hash = op.getOpHash(); + auto status = + execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, + ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false); + + NDArray::prepareSpecialUse({&z}, {&x, &y}); + ASSERT_EQ(Status::OK(), status); + + ASSERT_TRUE(e.isSameShape(z)); + ASSERT_TRUE(e.equalsTo(z)); + ASSERT_FALSE(e.ordering() == z.ordering()); } TEST_F(JavaInteropTests, Test_Inplace_Outputs_3) { - auto input = NDArrayFactory::create('c', {2, 3, 4}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24}); - auto indices = NDArrayFactory::create('c', {1, 6}, {0,1, 2,2, 1,2}); - auto output = NDArrayFactory::create('f', {2, 1, 6, 4}); - auto e = NDArrayFactory::create('c', {2, 1, 6, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9,10,11,12, 9,10,11,12, 5, 6, 7, 8, 9,10,11,12, 13,14,15,16, 17,18,19,20, 21,22,23,24, 21,22,23,24, 17,18,19,20, 21,22,23,24}); - - sd::ops::gather op; - - NDArray::prepareSpecialUse({&output}, {&input, &indices}); - - Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) input.buffer(), (Nd4jPointer) indices.buffer(), input.specialBuffer(), indices.specialBuffer()}; - Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) input.shapeInfo(), (Nd4jPointer) indices.shapeInfo(), (Nd4jPointer)input.specialShapeInfo(), (Nd4jPointer)input.specialShapeInfo()}; - - Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) output.buffer(), output.specialBuffer()}; - Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) output.shapeInfo(), (Nd4jPointer)output.specialShapeInfo()}; - - Nd4jLong iArgs[] = {1}; - - auto hash = op.getOpHash(); - auto status = execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, iArgs, 1, nullptr, 0, false); - - NDArray::registerSpecialUse({&output}, {&input, &indices}); - ASSERT_EQ(Status::OK(), status); - - ASSERT_TRUE(e.isSameShape(output)); - ASSERT_TRUE(e.equalsTo(output)); - ASSERT_FALSE(e.ordering() == output.ordering()); + auto input = NDArrayFactory::create( + 'c', {2, 3, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}); + auto indices = + NDArrayFactory::create('c', {1, 6}, {0, 1, 2, 2, 1, 2}); + auto output = NDArrayFactory::create('f', {2, 1, 6, 4}); + auto e = NDArrayFactory::create( + 'c', {2, 1, 6, 4}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 9, 10, 11, 12, + 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, + 21, 22, 23, 24, 21, 22, 23, 24, 17, 18, 19, 20, 21, 22, 23, 24}); + + sd::ops::gather op; + + NDArray::prepareSpecialUse({&output}, {&input, &indices}); + + Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer)input.buffer(), + (Nd4jPointer)indices.buffer(), + input.specialBuffer(), indices.specialBuffer()}; + Nd4jPointer ptrsInShapes[] = {(Nd4jPointer)input.shapeInfo(), + (Nd4jPointer)indices.shapeInfo(), + (Nd4jPointer)input.specialShapeInfo(), + (Nd4jPointer)input.specialShapeInfo()}; + + Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer)output.buffer(), + output.specialBuffer()}; + Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer)output.shapeInfo(), + (Nd4jPointer)output.specialShapeInfo()}; + + Nd4jLong iArgs[] = {1}; + + auto hash = op.getOpHash(); + auto status = + execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, + ptrsOutShapes, 1, nullptr, 0, iArgs, 1, nullptr, 0, false); + + NDArray::registerSpecialUse({&output}, {&input, &indices}); + ASSERT_EQ(Status::OK(), status); + + ASSERT_TRUE(e.isSameShape(output)); + ASSERT_TRUE(e.equalsTo(output)); + ASSERT_FALSE(e.ordering() == output.ordering()); } TEST_F(JavaInteropTests, Test_Reduce3_EdgeCase) { - auto x = NDArrayFactory::create('c', {3, 4, 5}); - auto y = NDArrayFactory::create('c', {3, 4, 5}); - auto z = NDArrayFactory::create('c', {5}); - - auto dims = NDArrayFactory::create('c', {2}, {0, 1}); - dims.syncToHost(); - - sd::LaunchContext* context = sd::LaunchContext::defaultContext(); - - Nd4jPointer* extraPointers = nullptr; - #ifdef __CUDABLAS__ - extraPointers = new Nd4jPointer[6] {nullptr, context->getCudaStream(), context->getScalarPointer(), nullptr, context->getCudaSpecialStream(), context->getReductionPointer()}; - #endif - - auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(x.shapeInfo(), {0,1}); - auto packY = sd::ConstantTadHelper::getInstance().tadForDimensions(y.shapeInfo(), {0,1}); - - NDArray::prepareSpecialUse({&z}, {&x, &y, &dims}); - OpaqueDataBuffer xBuf(x.dataBuffer()); - OpaqueDataBuffer yBuf(y.dataBuffer()); - OpaqueDataBuffer zBuf(z.dataBuffer()); - OpaqueDataBuffer dimBuf(dims.dataBuffer()); - - execReduce3Tad(extraPointers, 2, &xBuf, x.shapeInfo(), x.specialShapeInfo(), - nullptr, - &yBuf, y.shapeInfo(), y.specialShapeInfo(), - &zBuf, z.shapeInfo(), z.specialShapeInfo(), - &dimBuf, dims.shapeInfo(), dims.specialShapeInfo(), packX.platformShapeInfo(), - packX.platformOffsets(), packY.platformShapeInfo(), packY.platformOffsets()); - - NDArray::registerSpecialUse({&z}, {&x, &y, &dims}); - - delete []extraPointers; + auto x = NDArrayFactory::create('c', {3, 4, 5}); + auto y = NDArrayFactory::create('c', {3, 4, 5}); + auto z = NDArrayFactory::create('c', {5}); + + auto dims = NDArrayFactory::create('c', {2}, {0, 1}); + dims.syncToHost(); + + sd::LaunchContext *context = sd::LaunchContext::defaultContext(); + + Nd4jPointer *extraPointers = nullptr; +#ifdef __CUDABLAS__ + extraPointers = new Nd4jPointer[6]{ + nullptr, context->getCudaStream(), context->getScalarPointer(), + nullptr, context->getCudaSpecialStream(), context->getReductionPointer()}; +#endif + + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions( + x.shapeInfo(), {0, 1}); + auto packY = sd::ConstantTadHelper::getInstance().tadForDimensions( + y.shapeInfo(), {0, 1}); + + NDArray::prepareSpecialUse({&z}, {&x, &y, &dims}); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer yBuf(y.dataBuffer()); + OpaqueDataBuffer zBuf(z.dataBuffer()); + OpaqueDataBuffer dimBuf(dims.dataBuffer()); + + execReduce3Tad(extraPointers, 2, &xBuf, x.shapeInfo(), x.specialShapeInfo(), + nullptr, &yBuf, y.shapeInfo(), y.specialShapeInfo(), &zBuf, + z.shapeInfo(), z.specialShapeInfo(), &dimBuf, dims.shapeInfo(), + dims.specialShapeInfo(), packX.platformShapeInfo(), + packX.platformOffsets(), packY.platformShapeInfo(), + packY.platformOffsets()); + + NDArray::registerSpecialUse({&z}, {&x, &y, &dims}); + + delete[] extraPointers; } /* @@ -868,487 +997,1128 @@ TEST_F(JavaInteropTests, Test_SimpleIf_Output) { */ TEST_F(JavaInteropTests, Test_AveragePooling_FF_TF_double) { - - auto input = NDArrayFactory::create('c', {4, 10, 10, 3}, {9.37125111, 2.20166993, 2.91434479, 5.43639755, -2.10573769, 4.08528662, 5.86908436, -4.46203756, 2.21057916, 5.35849190, 0.01394637, 4.40566349, 7.07982206, -0.09633455, 2.42429352, 3.97301817, -1.89553940, 1.99690318, 6.33141708, 0.55401880, 1.70707977, 5.55204201, -0.03513752, 1.60011971, 2.62700319, -2.74582434, 3.06697464, 1.06277943, -1.16075921, -0.78095782, 9.72352791, -1.22686064, 1.99644792, 7.35571337, 1.40607321, 0.11390255, 9.53334427, 2.28303599, -1.66728830, 6.16678810, -0.04532295, -1.97708666, 9.74906158, 1.46223176, -1.46734393, 4.30761862, -1.23790228, 1.24823606, 6.13938427, -3.83689475, -1.19625473, 7.91535568, 6.05868721, -3.22946382, 8.81633949, -0.19967777, 0.66053957, 2.30919123, 0.74543846, -0.39347672, 11.11058044, 0.53720862, 1.52645731, 5.70012379, -1.15213466, 1.16451406, 7.00526333, 1.57362783, -2.44384766, 5.54213285, -1.98828590, -0.70483637, 7.88281822, -3.59875536, 0.80745387, 13.41578484, -1.55507684, -0.65855008, 9.32583523, -0.14544789, 0.73436141, 3.61176538, -1.71268058, -2.58490300, 9.09280205, -3.27405524, -2.04569697, 4.44761324, -0.62955856, -2.61917663, 8.04890442, 0.54579324, 0.85929775, 9.82259560, -1.93825579, 0.77703512, 4.67090321, -4.79267597, -2.38906908, 9.31265545, 0.96026313, -1.14109385, 11.54231834, -0.01417295, -0.39500344, 8.49191666, 0.55300158, 2.79490185, 6.92466164, 1.72254205, 2.82222271, 8.83112717, 2.95033407, 2.18054962, 6.73509789, -2.22272944, 0.51127720, -1.04563558, 2.15747333, -2.30959272, 9.55441570, 1.50396204, 1.77370787, 7.38146257, -1.79076433, 3.20961165, 7.18864202, 2.91217351, 0.43018937, 7.11078024, -1.17386127, -0.16817921, 6.12327290, -2.82205725, 3.30696845, 13.51291752, -1.30856836, -2.38332748, 11.09487438, -1.47190213, -0.53050828, 4.38285351, -5.07309771, 1.50714362, 5.72274446, -2.85825086, -0.89673209, 3.73791552, -0.67708802, -4.13149452, -0.00671843, -0.26566532, 0.32961160, 7.14501762, -1.41608179, -4.96590328, 12.26205540, -0.65158135, -0.88641000, 6.95777559, -0.79058206, -0.10260171, 7.87169170, 1.35921454, 1.11759663, 5.46187401, -2.57214499, 2.48484039, 4.04043484, -2.07137156, -1.42709637, 9.25487137, -0.12605135, -2.66949964, 2.89412403, 0.74451172, -2.96250391, 3.99258423, 0.27084303, 0.32213116, 5.42332172, -0.44414216, 1.70881832, 6.69346905, 0.53058422, -4.73146200, 4.22051668, 2.24834967, 0.66996074, 4.30173683, 0.11849818, -4.07520294, 8.27318478, -2.54398274, -2.86705542, 10.11775303, -0.99382895, 0.65881538, 7.93556786, -1.27934420, -1.69343162, 9.68042564, -1.02609646, -1.18189347, 5.75370646, -1.67888868, -4.48871994, 4.79537392, -0.79212248, -0.19855022, 6.15060997, -0.01081491, 3.64454579, 10.82562447, 1.58859253, -2.65847278, 8.60093212, -1.59196103, 0.07635692, 11.76175690, -1.17453325, 0.10122013, 6.86458445, -2.18891335, -2.74004745, 8.07066154, 0.71818852, -2.03035975, 6.31053686, 0.51509416, 1.39789927, 9.43515587, 2.04256630, 0.13985133, 4.65010691, 2.40911126, -0.36255789, -3.06867862, -0.45225358, -1.56778407, 6.05917358, -1.09891272, 1.77184200, 6.46248102, 0.96042323, -0.24346280, 4.63436460, -4.69907761, 1.25187206, 11.46173859, -2.21917558, 1.28007793, 6.92173195, 2.11268163, -3.47389889, 5.08722782, -3.03950930, -4.17154264, 11.30568314, 0.80361372, 2.53214502, 7.18707085, -4.49114513, 2.85449266, 10.14906883, -0.31974933, -0.84472644, -0.52459574, 0.12921631, -1.81390119, 2.76170087, 1.03982210, 2.91744232, -0.29048753, 5.87453508, -1.53684759, 1.85800636, -0.91404629, 1.28954852, 5.11354685, -2.47475505, -1.33179152, 2.58552408, 1.37316465, -3.32339454, 1.54122913, 3.24953628, -0.29758382, 2.82391763, -1.51142192, -1.22699404, 6.75745535, 0.65452754, -3.29385471, 2.06008053, 2.53172946, -4.23532820, -1.53909743, -0.07010663, -1.42173731, 7.29031610, -0.18448229, 4.59496164, 6.73027277, 0.73441899, 0.14426160, 4.14915276, -2.97010231, 6.05851364, 4.95218086, -2.39145470, 2.40494704, 2.10288811, 0.53503096, 1.44511235, 6.66344261, -3.05803776, 7.21418667, 3.30303526, -0.24163735, 3.47409391, 3.64520788, 2.15189481, -3.11243272, 3.62310791, 0.37379482, 0.40865007, -0.83132005, -4.78246069, 2.07030797, 6.51765442, 3.16178989, 5.06180477, 3.78434467, -0.96689719, 0.35965276, 5.89967585, 1.40294051, 1.11952639, 10.59778214, 0.26739889, -1.61297631, 6.24801159, -0.93914318, -0.57812452, 9.92604542, -0.73025000, -3.38530874, 2.45646000, -2.47949195, 0.51638460, 10.65636063, 1.97816694, -3.00407791, 2.66914415, -0.81951088, -0.23316640, 2.40737987, -2.70007610, 1.51531935, 4.08860207, -0.27552786, -1.31721711, 7.11568260, -3.33498216, -4.02545023, 7.22675610, -0.81690705, -2.52689576, 1.04016697, -0.79291463, -0.34875512, 10.00498390, -4.24167728, 1.46162593, 11.82569408, -1.70359993, -0.30161047, 16.44085884, -0.82253462, -0.09435523, 6.13080597, -0.20259480, 0.68308711, 6.15663004, -6.61776876, 0.33295766, 2.55449438, -0.17819691, -1.14892209, 5.56776142, 1.99279118, 1.33035934, 4.45823956, 3.34916544, -2.59905386, 6.16164446, -2.03881931, -2.45273542, 12.46793365, -2.22743297, 2.83738565, 8.48628139, -1.39347959, -1.30867767, 11.08041477, -4.00363779, 2.09183025, 11.30395889, -2.20504737, 1.37426853, 8.98735619, 1.04676604, -0.72757077, 8.28050232, -6.70741081, -0.65798020, 5.68592072, -0.60760021, 0.35854483, 6.26852131, 1.94100165, 1.32112014, 0.80987954, -1.74617672, -0.25434083, 7.16045523, 1.58884013, -2.64847064, 13.14820385, 1.21393633, -2.47258949, 9.41650105, -0.79384226, 2.48954105, 10.95629311, 0.47723705, 4.02126694, 8.02593136, -2.20726371, -1.18794477, 1.50836647, 0.93118095, -1.73513174, 8.85493565, -2.99670315, -0.79055870, 2.39473820, 2.05046916, -2.38055134, 11.82299423, 0.15609655, 0.68744308, 5.66401434, -0.69281673, 2.09855556, 7.74626589, -0.34283102, 1.00542057, 9.95838642, 0.80161905, 2.33455157, 9.80057335, -0.93561798, 2.56991577, 8.29711342, 0.94213426, 0.44209945, 11.70259857, 0.92710167, 2.60957146, 0.24971688, -0.86529571, 3.78628922, 6.80884457, -0.68178189, 2.21103406, 3.18895817, 0.60283208, -2.92716241, 6.72060776, -1.06625068, 2.56543374, 9.97404480, 3.58080721, -0.94936347, 10.16736984, -1.38464379, 1.18191063, 6.66179037, -3.56115270, 0.32329530, 10.90870762, 2.20638227, 0.19653285, 7.34650040, -3.63859272, -1.03027737, 5.98829985, -3.66606474, -3.89746714, 8.63469028, 1.22569811, 1.63240814, 3.74385309, 0.58243257, -0.56981975, 3.69260955, 1.00979900, -1.44030499, 8.57058144, -1.10648811, 1.20474911, 5.43133020, -2.14822555, -0.07928789, 11.25825310, 0.19645604, -5.49546146, 10.41917038, -0.68178523, -2.99639869, 6.50054455, 0.46488351, -5.42328453, 9.09500027, -2.82107449, 0.05601966, 15.34610748, -0.06820253, 3.86699796, 10.73316956, -3.04795432, -0.14702171, 5.64813185, 1.44028485, -2.47596145, 0.07280898, -3.03187990, -1.35183525, 9.35835648, 2.72966957, 1.88199532, 10.36187744, -0.22834805, -3.26738238, 6.92025137, -2.34061313, 4.77379704, 5.28559113, -2.96323752, -1.76186585, 5.94436455, 0.38647744, -5.73869514, 6.76849556, 1.40892124, -1.19068217, 5.37919092, -6.65328646, 3.62782669, 12.34744644, 2.44762444, -4.19242620, 6.14906216, 0.08121119, 0.61355996, 2.69666457, -1.88962626, -0.55314136, 1.84937525, 1.56048691, 1.17460012, 3.75674725, 1.06198275, -5.74625874, 5.41645575, -1.28946674, -1.51689398, 4.32400894, -0.05222082, -4.83948946, 1.80747867, 1.63144708, -2.73887825, 1.63975775, -2.02163982, -0.16210437, 2.93518686, 1.14427686, -2.83246303, 4.79283667, 2.69697428, -3.12678456, -1.19225168, -2.37022972, -3.09429741, 1.94225383, -1.13747168, -2.55048585, 5.40242243, 1.12777328, 3.43713188, 3.62658787, -2.16878843, 0.30164462, 2.97407579, -0.07275413, -1.31149673, 4.70066261, -2.01323795, 4.85255766, 4.59128904, 1.68084168, 1.60336494, 6.58138466, -1.04759812, 2.69906545, 3.55769277, -0.74327278, 2.65819693, 5.39528131, 2.11248922, -1.06446671, 5.24546766, -2.43146014, 4.58907509, 0.06521678, -2.24503994, 2.45722699, 6.94863081, 0.35258654, 2.83396196, 9.92525196, -1.12225175, -0.34365177, 7.19116688, -4.39813757, 0.46517885, 13.22028065, -2.57483673, -6.37226963, 7.58046293, -2.74600363, 0.42231262, 8.04881668, 0.17289802, -0.53447008, 16.55157471, -5.63614368, 0.39288223, 3.37079263, 1.26484549, -0.12820500, 8.46440125, -4.39304399, 2.97676420, 0.65650189, 0.83158541, -1.11556435, 6.32885838, -0.36087769, 2.80724382, 9.90292645, 1.15936041, 0.20947981, 6.91249275, -2.67404819, 2.93782163, 6.65656614, -2.30828357, 2.98214006, 6.80611229, -4.93821478, -7.66555262, 7.59763002, -0.54159302, 3.87403512, 12.42607784, 2.59284401, -0.23375344, 8.95293331, -0.71807784, 0.61873478, 8.66713524, 1.24289191, -2.37835455, 2.08071637, -0.88315344, -3.41891551, 6.85245323, 1.73007369, 1.02169311, 7.69170332, -2.85411978, 2.69790673, 8.12906551, -1.19351399, -2.26442742, 12.26104450, -0.75579089, -1.73274946, 10.68729019, 2.20655656, -0.90522075, 12.42165184, -1.67929137, 2.44851565, 9.31565762, -0.06645700, 1.52762020, 6.18427515, -1.68882596, 3.70261097, 3.02252960, -3.44125366, -1.31575799, 2.84617424, -0.96849400, -4.52356243, 9.95027161, 0.19966406, -0.78874779, 8.18595028, -4.08300209, 1.75126517, 0.96418417, -4.04913044, -0.95200396, 12.03637886, -0.03041124, 0.41642749, 8.88267422, -3.24985337, -2.24919462, 7.32566118, 0.16964148, -2.74123430, 7.05264473, -3.30191112, 0.17163286, 4.81851053, -1.64463484, -0.85933101, 7.29276276, 2.34066939, -2.14860010, 3.46148157, -0.01782012, 1.51504040, 4.79304934, 1.85281146, -1.70663762, 6.93470192, -4.15440845, -1.25983095, 10.52491760, 0.42930329, -1.85146868, 11.70042324, -0.41704914, 3.83796859, 9.21148491, -2.79719448, 0.79470479, 6.26926661, -5.85230207, 3.95105338, 7.84790897, -1.38680744, -1.78099084, 11.95235348, -2.99841452, -1.34507811, 6.15714645, -1.07552516, -2.81228638, 1.66234732, -4.55166149, -1.92601109, 8.64634514, -0.48158705, 3.31595659, 7.67371941, 2.56964207, 0.12107098, 4.56467867, -0.93541539, 1.39432955, 11.99714088, 1.05353570, -2.13099813, 3.67617917, 3.45895386, 1.37365830, 8.74344158, -4.17585802, 1.43908918, 6.28764772, 3.97346330, -0.69144285, 9.07983303, -0.41635889, -0.14965028, 8.85469818, 1.11306190, 2.59440994, 5.38982344, -1.07948279, 1.37252975, 10.26984596, -0.09318046, 2.73104119, 12.45902252, -1.55446684, -2.76124811, 12.19395065, -0.51846564, 1.02764034, 11.42673588, -0.95940983, -0.04781032, 8.78379822, -4.88957930, 0.32534006, 11.97696400, -3.35108662, 1.95104563, 4.46915388, -2.32061648, 3.45230985, 8.29983711, 2.81034684, -2.35529327, 6.07801294, -0.98105043, -0.05359888, 2.52291036, -0.01986909, -2.35321999, 10.51954269, 2.11145401, 3.53506470, 7.29093266, 0.03721160, -1.13496494, 7.43886709, -5.84201956, 2.50796294, 12.14647675, 2.77490377, -2.18896222, 6.05641937, 5.32617044, 1.04221284, 10.79106712, -2.95749092, -2.75414610, 11.30037117, -3.40654182, -2.24673963, 7.49126101, 0.70811015, -6.18003702, 13.83951187, -1.01204085, 1.36298490, -1.04451632, 2.42435336, -0.02346706, -0.85528886, 1.04731262, 0.22192979, 4.15708160, 0.34933877, 0.04814529, 2.24107265, 0.49676740, -1.47752666, 0.45040059, -0.70471478, -1.19759345, 0.21711677, 0.88461423, -2.76830935, 5.52066898, 1.97664857, -1.75381601, 3.45877838, 1.52617192, -1.61350942, 0.85337949, 1.97610760, -3.40310287, 3.40319014, -3.38691044, -0.71319139, 1.65463758, -0.60680127, -1.80700517, 8.02592373, 2.59627104, 2.65895891, 5.93043184, -4.48425817, 3.92670918, 4.19496679, -2.28286791, 6.41634607, 5.72330523, 1.16269672, -0.28753027, 2.46342492, 0.36693189, 0.26712441, 6.37652683, -2.50139046, 2.43923736, 5.56310415, 0.98065847, 1.04267502, 4.16403675, -0.04966142, 4.40897894, 3.72905660, -3.46129870, 3.59962773, 1.34830284, -1.76661730, 0.47943926, 5.29946661, -1.12711561, 1.26970029, 15.17655945, -1.50971997, 5.81345224, 8.48562050, -4.36049604, 2.48144460, 8.23780441, -3.46030426, -0.84656560, 5.94946814, 1.12747943, -2.65683913, 8.69085693, 1.31309867, -2.79958344, 8.76840591, -1.56444156, 1.62710834, 2.41177034, -0.72804940, 5.70619011, 4.67169666, -0.86167198, -1.83803177, 2.96346045, 2.82692933, -2.81557131, 7.11113358, -1.90071094, 2.54244423, 11.19284058, -0.06298946, -1.71517313, 12.98388577, 0.84510714, 3.00816894, 2.57200313, 0.03899818, -1.49330592, 9.60099125, -3.59513044, -1.30045319, 7.09241819, -0.65233821, -2.33627677, 8.81366920, 0.84154201, 1.03312039, 9.85289097, 0.19351870, 1.78496623, 7.34631205, -2.16530800, -0.65016162, 2.46842360, 0.24016285, -1.24308395, 4.78175163, -0.97682536, 2.20942235, 6.68382788, 3.76786447, -1.44454038, 6.26453733, -3.23575711, -2.30137897, 9.53092670, -5.55222607, 3.25999236, 9.37559509, 1.86339056, -0.23551451, 10.23400211, 3.93031883, -0.52629089, 7.85724449, -2.91549587, 4.46612740, 5.66530371, -2.70820427, 4.81359577, 10.31247330, 1.92230141, 2.53931546, 0.74986327, 1.70303428, 0.48063779, 5.31099129, -0.78976244, 3.75864220, 4.23051405, 2.34042454, -7.98193836, 9.83987141, -1.46722627, 3.54497814, 10.36455154, -4.51249075, 0.77715248, 7.78694630, -4.59989023, -2.49585629, 9.90296268, 1.38535416, 1.17441154, 10.10452843, -0.98628229, 0.60194463, 9.12639141, -3.90754628, 2.88526392, 7.24123430, -0.15283313, -0.75728363, -1.15116858, -2.53791571, 0.77229571, 6.44114161, 0.02646767, 4.95463037, 7.21066380, 1.79384065, 0.73250306, 8.04447937, 0.32576546, -0.79447043, 10.12717724, 2.33392906, 1.30716443, 12.36073112, -0.36694977, -1.20438910, 7.03105593, 0.59557682, 0.69267452, 10.18113136, 2.49944925, -0.42229167, 8.83143330, -1.18805945, -2.87509322, 4.53596449, 4.09732771, -3.39088297, -1.02536607, 0.82119560, -3.47302604, 9.29991817, 0.21001509, 4.97036457, 9.50018406, 1.04420102, 1.96560478, 10.74769592, -6.22709799, 3.11690164, 5.06759691, -1.23724771, -3.05831861, 8.12925529, -1.93435478, -1.10151744, 9.32263088, -0.04249470, -5.98547363, 10.49398136, 0.26400441, -0.78915191, 13.28219604, 2.99276900, 0.74853164, 2.49364305, -3.43529654, 4.05278301, 2.13498688, -2.35444307, -0.79900265, 4.66968822, -0.31095147, 3.60674143, 12.37222099, -0.07855003, -3.30292702, 12.15215874, 0.60886210, 2.87075138, 7.75271845, 0.38044083, 3.34402204, 6.40583277, -0.87888050, 0.67438459, 6.91080809, 1.98332930, -0.08303714, 8.08630371, -0.16772588, -2.74058914, 7.17253590, -2.69122696, 1.48173678, 8.99470139, -1.43302310, -0.88651133, 2.66944790, -0.29186964, 2.00838661, 5.09587479, -0.76676071, -2.88322186, 8.31110573, -0.14550979, -1.37726915, 10.28355122, -1.60575438, -0.04118848, 9.97510815, 0.14440438, -3.24632120, 9.00034523, 4.14319563, -1.31023729, 7.16950464, -0.70428526, 2.01559544, 7.26155043, 2.40816474, 2.09847403, 7.31264496, -0.75401551, 2.13392544, 7.03648758, 1.04036045, -1.15636516, 1.09634531, -0.06340861, -0.58107805, -0.65623116, 1.18972754, -0.80717683, 1.40118241, -0.61932516, -3.60596156, 1.59904599, -2.23774099, -1.13721037, 3.89620137, -0.09115922, -7.51356888, 2.36975193, -1.42520905, -2.34173775, 3.33830214, -2.74016523, -3.04115510, 6.00119495, -1.36084354, -2.45065260, 4.56992292, -3.02825928,-3.74182844,5.11069250,-0.91531068,-2.31385994,1.83399653,3.39370203,-3.60886002}); - auto z = NDArrayFactory::create('c', {4, 4, 4, 3}); - auto exp = NDArrayFactory::create('c', {4, 4, 4, 3}, {7.97172260, 0.06878620, 2.27749538, 7.29276514, -0.14074677, 0.65480286, 5.70313978, -0.06546132, 0.35443667, 3.70382833, -0.84020567, 0.63826996, 8.60301399, -0.38236514, 1.55177069, 7.37542057, -0.99374938, -0.29971302, 8.84352493, -0.67121059, 0.43132120, 4.78175592, -1.25070143, -1.91523600, 6.03855371, -0.00292124, -1.11214364, 7.90158176, -0.57949901, -0.96735370, 7.81192017, -0.53255427, -0.48009714, 3.16953635, 0.08353355, -1.54299748, 3.74821687, 1.69396687, 0.72724354, 5.42915201, -1.13686812, -0.71793109, 5.78376389, -0.72239977, -0.60055625, 2.53636408, 0.56777251, -2.07892323, 6.08064651, 0.68620735, 2.54017019, 5.65828180, -0.68255502, 1.47283304, 6.10842514, -0.39655915, 0.28380761, 1.96707797, -1.98206317, 0.94027776, 4.71811438, 0.32104525, -0.92409706, 8.34588146, -1.05581069, -0.55217457, 9.58440876, -0.96549922, 0.45820439, 5.65453672, -2.50953507, -0.71441835, 8.03059578, -0.21281289, 0.92125505, 9.26900673, -0.35963219, -0.70039093, 8.59924412, -1.22358346, 0.81318003, 3.85920119, -0.01305223, -1.09234154, 6.33158875, 1.28094780, -1.48926139, 4.94969177, -0.77126902, -1.97033751, 5.64381838, -0.16285487, -1.31277227, 2.39893222, -1.32902908, -1.39609122, 6.47572327, -0.45267010, 1.55727172, 6.70965624, -1.68735468, -0.05672536, 7.25092363, -0.64613032, 0.67050058, 3.60789680, -2.05948973, 2.22687531, 8.15202713, -0.70148355, 1.28314006, 8.14842319, -1.88807654, -1.04808438, 8.45500565, -0.76425624, 0.94542569, 4.56179953, -0.28786001, -2.04502511, 8.46278095, -0.31019822, 0.07339200, 9.34214592, -0.61948007, 0.52481830, 8.32515621, -1.52418160, 0.49678251, 5.11082315, -1.09908783, -0.52969611, 5.27806664, 0.88632923, 0.66754371, 4.75839233, 0.48928693, -0.68036932, 6.56925392, -0.02949905, -2.99189186, 4.46320581, -0.64534980, -0.29516968, 8.60809517, -1.13120568, 3.41720533, 5.84243155, -1.24109328, 0.89566326, 5.99578333, -0.42496428, 2.07076764, 3.17812920, -0.81566459, -0.14363396, 6.55184317, 0.39633346, -0.43852386, 8.70214558, -2.24613595, 0.30708700, 8.73882294, -0.53545928, 1.54409575, 4.49452257, -0.16509305, 0.19028664, 8.24897003, 0.44750381, 2.15448594, 8.97640514, -0.77728152, 0.57272542, 9.03467560, 0.47173575, -1.10807717, 3.30056310, -0.43268481, -0.41470885, 3.53798294, -0.08546703, -2.16840744, 6.18733406, -0.17871059, -2.59837723, 5.94218683, -1.02990067, -0.49760687, 3.76938033, 0.86383581, -1.91504073}); - - sd::ops::avgpool2d op; - - NDArray::prepareSpecialUse({&z}, {&input}); - - Nd4jPointer ptrsInBuffer[] = {reinterpret_cast(input.buffer()), input.specialBuffer()}; - Nd4jPointer ptrsInShapes[] = {(Nd4jPointer)input.shapeInfo(), (Nd4jPointer)input.specialShapeInfo()}; - - Nd4jPointer ptrsOutBuffers[] = {reinterpret_cast(z.buffer()), z.specialBuffer()}; - Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer)z.shapeInfo(), (Nd4jPointer)z.specialShapeInfo()}; - - Nd4jLong iArgs[] = {3,3, 3,3, 0,0, 1,1,1, 0,1}; - - auto hash = op.getOpHash(); - auto status = execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, iArgs, 11, nullptr, 0, false); - - NDArray::registerSpecialUse({&z}, {&input}); - ASSERT_EQ(Status::OK(), status); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto input = NDArrayFactory::create( + 'c', {4, 10, 10, 3}, + {9.37125111, 2.20166993, 2.91434479, 5.43639755, -2.10573769, + 4.08528662, 5.86908436, -4.46203756, 2.21057916, 5.35849190, + 0.01394637, 4.40566349, 7.07982206, -0.09633455, 2.42429352, + 3.97301817, -1.89553940, 1.99690318, 6.33141708, 0.55401880, + 1.70707977, 5.55204201, -0.03513752, 1.60011971, 2.62700319, + -2.74582434, 3.06697464, 1.06277943, -1.16075921, -0.78095782, + 9.72352791, -1.22686064, 1.99644792, 7.35571337, 1.40607321, + 0.11390255, 9.53334427, 2.28303599, -1.66728830, 6.16678810, + -0.04532295, -1.97708666, 9.74906158, 1.46223176, -1.46734393, + 4.30761862, -1.23790228, 1.24823606, 6.13938427, -3.83689475, + -1.19625473, 7.91535568, 6.05868721, -3.22946382, 8.81633949, + -0.19967777, 0.66053957, 2.30919123, 0.74543846, -0.39347672, + 11.11058044, 0.53720862, 1.52645731, 5.70012379, -1.15213466, + 1.16451406, 7.00526333, 1.57362783, -2.44384766, 5.54213285, + -1.98828590, -0.70483637, 7.88281822, -3.59875536, 0.80745387, + 13.41578484, -1.55507684, -0.65855008, 9.32583523, -0.14544789, + 0.73436141, 3.61176538, -1.71268058, -2.58490300, 9.09280205, + -3.27405524, -2.04569697, 4.44761324, -0.62955856, -2.61917663, + 8.04890442, 0.54579324, 0.85929775, 9.82259560, -1.93825579, + 0.77703512, 4.67090321, -4.79267597, -2.38906908, 9.31265545, + 0.96026313, -1.14109385, 11.54231834, -0.01417295, -0.39500344, + 8.49191666, 0.55300158, 2.79490185, 6.92466164, 1.72254205, + 2.82222271, 8.83112717, 2.95033407, 2.18054962, 6.73509789, + -2.22272944, 0.51127720, -1.04563558, 2.15747333, -2.30959272, + 9.55441570, 1.50396204, 1.77370787, 7.38146257, -1.79076433, + 3.20961165, 7.18864202, 2.91217351, 0.43018937, 7.11078024, + -1.17386127, -0.16817921, 6.12327290, -2.82205725, 3.30696845, + 13.51291752, -1.30856836, -2.38332748, 11.09487438, -1.47190213, + -0.53050828, 4.38285351, -5.07309771, 1.50714362, 5.72274446, + -2.85825086, -0.89673209, 3.73791552, -0.67708802, -4.13149452, + -0.00671843, -0.26566532, 0.32961160, 7.14501762, -1.41608179, + -4.96590328, 12.26205540, -0.65158135, -0.88641000, 6.95777559, + -0.79058206, -0.10260171, 7.87169170, 1.35921454, 1.11759663, + 5.46187401, -2.57214499, 2.48484039, 4.04043484, -2.07137156, + -1.42709637, 9.25487137, -0.12605135, -2.66949964, 2.89412403, + 0.74451172, -2.96250391, 3.99258423, 0.27084303, 0.32213116, + 5.42332172, -0.44414216, 1.70881832, 6.69346905, 0.53058422, + -4.73146200, 4.22051668, 2.24834967, 0.66996074, 4.30173683, + 0.11849818, -4.07520294, 8.27318478, -2.54398274, -2.86705542, + 10.11775303, -0.99382895, 0.65881538, 7.93556786, -1.27934420, + -1.69343162, 9.68042564, -1.02609646, -1.18189347, 5.75370646, + -1.67888868, -4.48871994, 4.79537392, -0.79212248, -0.19855022, + 6.15060997, -0.01081491, 3.64454579, 10.82562447, 1.58859253, + -2.65847278, 8.60093212, -1.59196103, 0.07635692, 11.76175690, + -1.17453325, 0.10122013, 6.86458445, -2.18891335, -2.74004745, + 8.07066154, 0.71818852, -2.03035975, 6.31053686, 0.51509416, + 1.39789927, 9.43515587, 2.04256630, 0.13985133, 4.65010691, + 2.40911126, -0.36255789, -3.06867862, -0.45225358, -1.56778407, + 6.05917358, -1.09891272, 1.77184200, 6.46248102, 0.96042323, + -0.24346280, 4.63436460, -4.69907761, 1.25187206, 11.46173859, + -2.21917558, 1.28007793, 6.92173195, 2.11268163, -3.47389889, + 5.08722782, -3.03950930, -4.17154264, 11.30568314, 0.80361372, + 2.53214502, 7.18707085, -4.49114513, 2.85449266, 10.14906883, + -0.31974933, -0.84472644, -0.52459574, 0.12921631, -1.81390119, + 2.76170087, 1.03982210, 2.91744232, -0.29048753, 5.87453508, + -1.53684759, 1.85800636, -0.91404629, 1.28954852, 5.11354685, + -2.47475505, -1.33179152, 2.58552408, 1.37316465, -3.32339454, + 1.54122913, 3.24953628, -0.29758382, 2.82391763, -1.51142192, + -1.22699404, 6.75745535, 0.65452754, -3.29385471, 2.06008053, + 2.53172946, -4.23532820, -1.53909743, -0.07010663, -1.42173731, + 7.29031610, -0.18448229, 4.59496164, 6.73027277, 0.73441899, + 0.14426160, 4.14915276, -2.97010231, 6.05851364, 4.95218086, + -2.39145470, 2.40494704, 2.10288811, 0.53503096, 1.44511235, + 6.66344261, -3.05803776, 7.21418667, 3.30303526, -0.24163735, + 3.47409391, 3.64520788, 2.15189481, -3.11243272, 3.62310791, + 0.37379482, 0.40865007, -0.83132005, -4.78246069, 2.07030797, + 6.51765442, 3.16178989, 5.06180477, 3.78434467, -0.96689719, + 0.35965276, 5.89967585, 1.40294051, 1.11952639, 10.59778214, + 0.26739889, -1.61297631, 6.24801159, -0.93914318, -0.57812452, + 9.92604542, -0.73025000, -3.38530874, 2.45646000, -2.47949195, + 0.51638460, 10.65636063, 1.97816694, -3.00407791, 2.66914415, + -0.81951088, -0.23316640, 2.40737987, -2.70007610, 1.51531935, + 4.08860207, -0.27552786, -1.31721711, 7.11568260, -3.33498216, + -4.02545023, 7.22675610, -0.81690705, -2.52689576, 1.04016697, + -0.79291463, -0.34875512, 10.00498390, -4.24167728, 1.46162593, + 11.82569408, -1.70359993, -0.30161047, 16.44085884, -0.82253462, + -0.09435523, 6.13080597, -0.20259480, 0.68308711, 6.15663004, + -6.61776876, 0.33295766, 2.55449438, -0.17819691, -1.14892209, + 5.56776142, 1.99279118, 1.33035934, 4.45823956, 3.34916544, + -2.59905386, 6.16164446, -2.03881931, -2.45273542, 12.46793365, + -2.22743297, 2.83738565, 8.48628139, -1.39347959, -1.30867767, + 11.08041477, -4.00363779, 2.09183025, 11.30395889, -2.20504737, + 1.37426853, 8.98735619, 1.04676604, -0.72757077, 8.28050232, + -6.70741081, -0.65798020, 5.68592072, -0.60760021, 0.35854483, + 6.26852131, 1.94100165, 1.32112014, 0.80987954, -1.74617672, + -0.25434083, 7.16045523, 1.58884013, -2.64847064, 13.14820385, + 1.21393633, -2.47258949, 9.41650105, -0.79384226, 2.48954105, + 10.95629311, 0.47723705, 4.02126694, 8.02593136, -2.20726371, + -1.18794477, 1.50836647, 0.93118095, -1.73513174, 8.85493565, + -2.99670315, -0.79055870, 2.39473820, 2.05046916, -2.38055134, + 11.82299423, 0.15609655, 0.68744308, 5.66401434, -0.69281673, + 2.09855556, 7.74626589, -0.34283102, 1.00542057, 9.95838642, + 0.80161905, 2.33455157, 9.80057335, -0.93561798, 2.56991577, + 8.29711342, 0.94213426, 0.44209945, 11.70259857, 0.92710167, + 2.60957146, 0.24971688, -0.86529571, 3.78628922, 6.80884457, + -0.68178189, 2.21103406, 3.18895817, 0.60283208, -2.92716241, + 6.72060776, -1.06625068, 2.56543374, 9.97404480, 3.58080721, + -0.94936347, 10.16736984, -1.38464379, 1.18191063, 6.66179037, + -3.56115270, 0.32329530, 10.90870762, 2.20638227, 0.19653285, + 7.34650040, -3.63859272, -1.03027737, 5.98829985, -3.66606474, + -3.89746714, 8.63469028, 1.22569811, 1.63240814, 3.74385309, + 0.58243257, -0.56981975, 3.69260955, 1.00979900, -1.44030499, + 8.57058144, -1.10648811, 1.20474911, 5.43133020, -2.14822555, + -0.07928789, 11.25825310, 0.19645604, -5.49546146, 10.41917038, + -0.68178523, -2.99639869, 6.50054455, 0.46488351, -5.42328453, + 9.09500027, -2.82107449, 0.05601966, 15.34610748, -0.06820253, + 3.86699796, 10.73316956, -3.04795432, -0.14702171, 5.64813185, + 1.44028485, -2.47596145, 0.07280898, -3.03187990, -1.35183525, + 9.35835648, 2.72966957, 1.88199532, 10.36187744, -0.22834805, + -3.26738238, 6.92025137, -2.34061313, 4.77379704, 5.28559113, + -2.96323752, -1.76186585, 5.94436455, 0.38647744, -5.73869514, + 6.76849556, 1.40892124, -1.19068217, 5.37919092, -6.65328646, + 3.62782669, 12.34744644, 2.44762444, -4.19242620, 6.14906216, + 0.08121119, 0.61355996, 2.69666457, -1.88962626, -0.55314136, + 1.84937525, 1.56048691, 1.17460012, 3.75674725, 1.06198275, + -5.74625874, 5.41645575, -1.28946674, -1.51689398, 4.32400894, + -0.05222082, -4.83948946, 1.80747867, 1.63144708, -2.73887825, + 1.63975775, -2.02163982, -0.16210437, 2.93518686, 1.14427686, + -2.83246303, 4.79283667, 2.69697428, -3.12678456, -1.19225168, + -2.37022972, -3.09429741, 1.94225383, -1.13747168, -2.55048585, + 5.40242243, 1.12777328, 3.43713188, 3.62658787, -2.16878843, + 0.30164462, 2.97407579, -0.07275413, -1.31149673, 4.70066261, + -2.01323795, 4.85255766, 4.59128904, 1.68084168, 1.60336494, + 6.58138466, -1.04759812, 2.69906545, 3.55769277, -0.74327278, + 2.65819693, 5.39528131, 2.11248922, -1.06446671, 5.24546766, + -2.43146014, 4.58907509, 0.06521678, -2.24503994, 2.45722699, + 6.94863081, 0.35258654, 2.83396196, 9.92525196, -1.12225175, + -0.34365177, 7.19116688, -4.39813757, 0.46517885, 13.22028065, + -2.57483673, -6.37226963, 7.58046293, -2.74600363, 0.42231262, + 8.04881668, 0.17289802, -0.53447008, 16.55157471, -5.63614368, + 0.39288223, 3.37079263, 1.26484549, -0.12820500, 8.46440125, + -4.39304399, 2.97676420, 0.65650189, 0.83158541, -1.11556435, + 6.32885838, -0.36087769, 2.80724382, 9.90292645, 1.15936041, + 0.20947981, 6.91249275, -2.67404819, 2.93782163, 6.65656614, + -2.30828357, 2.98214006, 6.80611229, -4.93821478, -7.66555262, + 7.59763002, -0.54159302, 3.87403512, 12.42607784, 2.59284401, + -0.23375344, 8.95293331, -0.71807784, 0.61873478, 8.66713524, + 1.24289191, -2.37835455, 2.08071637, -0.88315344, -3.41891551, + 6.85245323, 1.73007369, 1.02169311, 7.69170332, -2.85411978, + 2.69790673, 8.12906551, -1.19351399, -2.26442742, 12.26104450, + -0.75579089, -1.73274946, 10.68729019, 2.20655656, -0.90522075, + 12.42165184, -1.67929137, 2.44851565, 9.31565762, -0.06645700, + 1.52762020, 6.18427515, -1.68882596, 3.70261097, 3.02252960, + -3.44125366, -1.31575799, 2.84617424, -0.96849400, -4.52356243, + 9.95027161, 0.19966406, -0.78874779, 8.18595028, -4.08300209, + 1.75126517, 0.96418417, -4.04913044, -0.95200396, 12.03637886, + -0.03041124, 0.41642749, 8.88267422, -3.24985337, -2.24919462, + 7.32566118, 0.16964148, -2.74123430, 7.05264473, -3.30191112, + 0.17163286, 4.81851053, -1.64463484, -0.85933101, 7.29276276, + 2.34066939, -2.14860010, 3.46148157, -0.01782012, 1.51504040, + 4.79304934, 1.85281146, -1.70663762, 6.93470192, -4.15440845, + -1.25983095, 10.52491760, 0.42930329, -1.85146868, 11.70042324, + -0.41704914, 3.83796859, 9.21148491, -2.79719448, 0.79470479, + 6.26926661, -5.85230207, 3.95105338, 7.84790897, -1.38680744, + -1.78099084, 11.95235348, -2.99841452, -1.34507811, 6.15714645, + -1.07552516, -2.81228638, 1.66234732, -4.55166149, -1.92601109, + 8.64634514, -0.48158705, 3.31595659, 7.67371941, 2.56964207, + 0.12107098, 4.56467867, -0.93541539, 1.39432955, 11.99714088, + 1.05353570, -2.13099813, 3.67617917, 3.45895386, 1.37365830, + 8.74344158, -4.17585802, 1.43908918, 6.28764772, 3.97346330, + -0.69144285, 9.07983303, -0.41635889, -0.14965028, 8.85469818, + 1.11306190, 2.59440994, 5.38982344, -1.07948279, 1.37252975, + 10.26984596, -0.09318046, 2.73104119, 12.45902252, -1.55446684, + -2.76124811, 12.19395065, -0.51846564, 1.02764034, 11.42673588, + -0.95940983, -0.04781032, 8.78379822, -4.88957930, 0.32534006, + 11.97696400, -3.35108662, 1.95104563, 4.46915388, -2.32061648, + 3.45230985, 8.29983711, 2.81034684, -2.35529327, 6.07801294, + -0.98105043, -0.05359888, 2.52291036, -0.01986909, -2.35321999, + 10.51954269, 2.11145401, 3.53506470, 7.29093266, 0.03721160, + -1.13496494, 7.43886709, -5.84201956, 2.50796294, 12.14647675, + 2.77490377, -2.18896222, 6.05641937, 5.32617044, 1.04221284, + 10.79106712, -2.95749092, -2.75414610, 11.30037117, -3.40654182, + -2.24673963, 7.49126101, 0.70811015, -6.18003702, 13.83951187, + -1.01204085, 1.36298490, -1.04451632, 2.42435336, -0.02346706, + -0.85528886, 1.04731262, 0.22192979, 4.15708160, 0.34933877, + 0.04814529, 2.24107265, 0.49676740, -1.47752666, 0.45040059, + -0.70471478, -1.19759345, 0.21711677, 0.88461423, -2.76830935, + 5.52066898, 1.97664857, -1.75381601, 3.45877838, 1.52617192, + -1.61350942, 0.85337949, 1.97610760, -3.40310287, 3.40319014, + -3.38691044, -0.71319139, 1.65463758, -0.60680127, -1.80700517, + 8.02592373, 2.59627104, 2.65895891, 5.93043184, -4.48425817, + 3.92670918, 4.19496679, -2.28286791, 6.41634607, 5.72330523, + 1.16269672, -0.28753027, 2.46342492, 0.36693189, 0.26712441, + 6.37652683, -2.50139046, 2.43923736, 5.56310415, 0.98065847, + 1.04267502, 4.16403675, -0.04966142, 4.40897894, 3.72905660, + -3.46129870, 3.59962773, 1.34830284, -1.76661730, 0.47943926, + 5.29946661, -1.12711561, 1.26970029, 15.17655945, -1.50971997, + 5.81345224, 8.48562050, -4.36049604, 2.48144460, 8.23780441, + -3.46030426, -0.84656560, 5.94946814, 1.12747943, -2.65683913, + 8.69085693, 1.31309867, -2.79958344, 8.76840591, -1.56444156, + 1.62710834, 2.41177034, -0.72804940, 5.70619011, 4.67169666, + -0.86167198, -1.83803177, 2.96346045, 2.82692933, -2.81557131, + 7.11113358, -1.90071094, 2.54244423, 11.19284058, -0.06298946, + -1.71517313, 12.98388577, 0.84510714, 3.00816894, 2.57200313, + 0.03899818, -1.49330592, 9.60099125, -3.59513044, -1.30045319, + 7.09241819, -0.65233821, -2.33627677, 8.81366920, 0.84154201, + 1.03312039, 9.85289097, 0.19351870, 1.78496623, 7.34631205, + -2.16530800, -0.65016162, 2.46842360, 0.24016285, -1.24308395, + 4.78175163, -0.97682536, 2.20942235, 6.68382788, 3.76786447, + -1.44454038, 6.26453733, -3.23575711, -2.30137897, 9.53092670, + -5.55222607, 3.25999236, 9.37559509, 1.86339056, -0.23551451, + 10.23400211, 3.93031883, -0.52629089, 7.85724449, -2.91549587, + 4.46612740, 5.66530371, -2.70820427, 4.81359577, 10.31247330, + 1.92230141, 2.53931546, 0.74986327, 1.70303428, 0.48063779, + 5.31099129, -0.78976244, 3.75864220, 4.23051405, 2.34042454, + -7.98193836, 9.83987141, -1.46722627, 3.54497814, 10.36455154, + -4.51249075, 0.77715248, 7.78694630, -4.59989023, -2.49585629, + 9.90296268, 1.38535416, 1.17441154, 10.10452843, -0.98628229, + 0.60194463, 9.12639141, -3.90754628, 2.88526392, 7.24123430, + -0.15283313, -0.75728363, -1.15116858, -2.53791571, 0.77229571, + 6.44114161, 0.02646767, 4.95463037, 7.21066380, 1.79384065, + 0.73250306, 8.04447937, 0.32576546, -0.79447043, 10.12717724, + 2.33392906, 1.30716443, 12.36073112, -0.36694977, -1.20438910, + 7.03105593, 0.59557682, 0.69267452, 10.18113136, 2.49944925, + -0.42229167, 8.83143330, -1.18805945, -2.87509322, 4.53596449, + 4.09732771, -3.39088297, -1.02536607, 0.82119560, -3.47302604, + 9.29991817, 0.21001509, 4.97036457, 9.50018406, 1.04420102, + 1.96560478, 10.74769592, -6.22709799, 3.11690164, 5.06759691, + -1.23724771, -3.05831861, 8.12925529, -1.93435478, -1.10151744, + 9.32263088, -0.04249470, -5.98547363, 10.49398136, 0.26400441, + -0.78915191, 13.28219604, 2.99276900, 0.74853164, 2.49364305, + -3.43529654, 4.05278301, 2.13498688, -2.35444307, -0.79900265, + 4.66968822, -0.31095147, 3.60674143, 12.37222099, -0.07855003, + -3.30292702, 12.15215874, 0.60886210, 2.87075138, 7.75271845, + 0.38044083, 3.34402204, 6.40583277, -0.87888050, 0.67438459, + 6.91080809, 1.98332930, -0.08303714, 8.08630371, -0.16772588, + -2.74058914, 7.17253590, -2.69122696, 1.48173678, 8.99470139, + -1.43302310, -0.88651133, 2.66944790, -0.29186964, 2.00838661, + 5.09587479, -0.76676071, -2.88322186, 8.31110573, -0.14550979, + -1.37726915, 10.28355122, -1.60575438, -0.04118848, 9.97510815, + 0.14440438, -3.24632120, 9.00034523, 4.14319563, -1.31023729, + 7.16950464, -0.70428526, 2.01559544, 7.26155043, 2.40816474, + 2.09847403, 7.31264496, -0.75401551, 2.13392544, 7.03648758, + 1.04036045, -1.15636516, 1.09634531, -0.06340861, -0.58107805, + -0.65623116, 1.18972754, -0.80717683, 1.40118241, -0.61932516, + -3.60596156, 1.59904599, -2.23774099, -1.13721037, 3.89620137, + -0.09115922, -7.51356888, 2.36975193, -1.42520905, -2.34173775, + 3.33830214, -2.74016523, -3.04115510, 6.00119495, -1.36084354, + -2.45065260, 4.56992292, -3.02825928, -3.74182844, 5.11069250, + -0.91531068, -2.31385994, 1.83399653, 3.39370203, -3.60886002}); + auto z = NDArrayFactory::create('c', {4, 4, 4, 3}); + auto exp = NDArrayFactory::create( + 'c', {4, 4, 4, 3}, + {7.97172260, 0.06878620, 2.27749538, 7.29276514, -0.14074677, + 0.65480286, 5.70313978, -0.06546132, 0.35443667, 3.70382833, + -0.84020567, 0.63826996, 8.60301399, -0.38236514, 1.55177069, + 7.37542057, -0.99374938, -0.29971302, 8.84352493, -0.67121059, + 0.43132120, 4.78175592, -1.25070143, -1.91523600, 6.03855371, + -0.00292124, -1.11214364, 7.90158176, -0.57949901, -0.96735370, + 7.81192017, -0.53255427, -0.48009714, 3.16953635, 0.08353355, + -1.54299748, 3.74821687, 1.69396687, 0.72724354, 5.42915201, + -1.13686812, -0.71793109, 5.78376389, -0.72239977, -0.60055625, + 2.53636408, 0.56777251, -2.07892323, 6.08064651, 0.68620735, + 2.54017019, 5.65828180, -0.68255502, 1.47283304, 6.10842514, + -0.39655915, 0.28380761, 1.96707797, -1.98206317, 0.94027776, + 4.71811438, 0.32104525, -0.92409706, 8.34588146, -1.05581069, + -0.55217457, 9.58440876, -0.96549922, 0.45820439, 5.65453672, + -2.50953507, -0.71441835, 8.03059578, -0.21281289, 0.92125505, + 9.26900673, -0.35963219, -0.70039093, 8.59924412, -1.22358346, + 0.81318003, 3.85920119, -0.01305223, -1.09234154, 6.33158875, + 1.28094780, -1.48926139, 4.94969177, -0.77126902, -1.97033751, + 5.64381838, -0.16285487, -1.31277227, 2.39893222, -1.32902908, + -1.39609122, 6.47572327, -0.45267010, 1.55727172, 6.70965624, + -1.68735468, -0.05672536, 7.25092363, -0.64613032, 0.67050058, + 3.60789680, -2.05948973, 2.22687531, 8.15202713, -0.70148355, + 1.28314006, 8.14842319, -1.88807654, -1.04808438, 8.45500565, + -0.76425624, 0.94542569, 4.56179953, -0.28786001, -2.04502511, + 8.46278095, -0.31019822, 0.07339200, 9.34214592, -0.61948007, + 0.52481830, 8.32515621, -1.52418160, 0.49678251, 5.11082315, + -1.09908783, -0.52969611, 5.27806664, 0.88632923, 0.66754371, + 4.75839233, 0.48928693, -0.68036932, 6.56925392, -0.02949905, + -2.99189186, 4.46320581, -0.64534980, -0.29516968, 8.60809517, + -1.13120568, 3.41720533, 5.84243155, -1.24109328, 0.89566326, + 5.99578333, -0.42496428, 2.07076764, 3.17812920, -0.81566459, + -0.14363396, 6.55184317, 0.39633346, -0.43852386, 8.70214558, + -2.24613595, 0.30708700, 8.73882294, -0.53545928, 1.54409575, + 4.49452257, -0.16509305, 0.19028664, 8.24897003, 0.44750381, + 2.15448594, 8.97640514, -0.77728152, 0.57272542, 9.03467560, + 0.47173575, -1.10807717, 3.30056310, -0.43268481, -0.41470885, + 3.53798294, -0.08546703, -2.16840744, 6.18733406, -0.17871059, + -2.59837723, 5.94218683, -1.02990067, -0.49760687, 3.76938033, + 0.86383581, -1.91504073}); + + sd::ops::avgpool2d op; + + NDArray::prepareSpecialUse({&z}, {&input}); + + Nd4jPointer ptrsInBuffer[] = {reinterpret_cast(input.buffer()), + input.specialBuffer()}; + Nd4jPointer ptrsInShapes[] = {(Nd4jPointer)input.shapeInfo(), + (Nd4jPointer)input.specialShapeInfo()}; + + Nd4jPointer ptrsOutBuffers[] = {reinterpret_cast(z.buffer()), + z.specialBuffer()}; + Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer)z.shapeInfo(), + (Nd4jPointer)z.specialShapeInfo()}; + + Nd4jLong iArgs[] = {3, 3, 3, 3, 0, 0, 1, 1, 1, 0, 1}; + + auto hash = op.getOpHash(); + auto status = + execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, + ptrsOutShapes, 1, nullptr, 0, iArgs, 11, nullptr, 0, false); + + NDArray::registerSpecialUse({&z}, {&input}); + ASSERT_EQ(Status::OK(), status); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(JavaInteropTests, Test_MaxPool2D_float_1) { - auto input = NDArrayFactory::create('c', {1, 1, 4, 5}); - auto z = NDArrayFactory::create('c', {1, 1, 4, 5}); + auto input = NDArrayFactory::create('c', {1, 1, 4, 5}); + auto z = NDArrayFactory::create('c', {1, 1, 4, 5}); - input.linspace(1.0); + input.linspace(1.0); - NDArray::prepareSpecialUse({&z}, {&input}); + NDArray::prepareSpecialUse({&z}, {&input}); - Nd4jPointer ptrsInBuffer[] = {reinterpret_cast(input.buffer()), input.specialBuffer()}; - Nd4jPointer ptrsInShapes[] = {(Nd4jPointer)input.shapeInfo(), (Nd4jPointer)input.specialShapeInfo()}; + Nd4jPointer ptrsInBuffer[] = {reinterpret_cast(input.buffer()), + input.specialBuffer()}; + Nd4jPointer ptrsInShapes[] = {(Nd4jPointer)input.shapeInfo(), + (Nd4jPointer)input.specialShapeInfo()}; - Nd4jPointer ptrsOutBuffers[] = {reinterpret_cast(z.buffer()), z.specialBuffer()}; - Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer)z.shapeInfo(), (Nd4jPointer)z.specialShapeInfo()}; + Nd4jPointer ptrsOutBuffers[] = {reinterpret_cast(z.buffer()), + z.specialBuffer()}; + Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer)z.shapeInfo(), + (Nd4jPointer)z.specialShapeInfo()}; - Nd4jLong iArgs[] = {2,2, 1,1, 1,1, 2,2,1, 0,0}; + Nd4jLong iArgs[] = {2, 2, 1, 1, 1, 1, 2, 2, 1, 0, 0}; - sd::ops::maxpool2d op; + sd::ops::maxpool2d op; - auto hash = op.getOpHash(); - auto status = execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, iArgs, 11, nullptr, 0, false); + auto hash = op.getOpHash(); + auto status = + execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, + ptrsOutShapes, 1, nullptr, 0, iArgs, 11, nullptr, 0, false); - NDArray::registerSpecialUse({&z}, {&input}); - ASSERT_EQ(Status::OK(), status); + NDArray::registerSpecialUse({&z}, {&input}); + ASSERT_EQ(Status::OK(), status); } TEST_F(JavaInteropTests, Test_Unstack_1) { - auto x = NDArrayFactory::create('c', {5, 5}); - x.linspace(1.0); - auto z0 = NDArrayFactory::create('c',{5}); - auto z1 = NDArrayFactory::create('c',{5}); - auto z2 = NDArrayFactory::create('c',{5}); - auto z3 = NDArrayFactory::create('c',{5}); - auto z4 = NDArrayFactory::create('c',{5}); - - NDArray::prepareSpecialUse({&z0, &z1, &z2, &z3, &z4}, {&x}); - - Nd4jPointer ptrsInBuffer[] = {reinterpret_cast(x.buffer()), x.specialBuffer()}; - Nd4jPointer ptrsInShapes[] = {(Nd4jPointer)x.shapeInfo(), (Nd4jPointer)x.specialShapeInfo()}; - - Nd4jPointer ptrsOutBuffers[] = {z0.buffer(), z1.buffer(), z2.buffer(), z3.buffer(), z4.buffer(), z0.specialBuffer(), z1.specialBuffer(), z2.specialBuffer(), z3.specialBuffer(), z4.specialBuffer()}; - Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer)z0.shapeInfo(), (Nd4jPointer)z1.shapeInfo(), (Nd4jPointer)z2.shapeInfo(), - (Nd4jPointer)z3.shapeInfo(), (Nd4jPointer)z4.shapeInfo(), (Nd4jPointer)z0.specialShapeInfo(), - (Nd4jPointer)z1.specialShapeInfo(), (Nd4jPointer)z2.specialShapeInfo(), - (Nd4jPointer)z3.specialShapeInfo(), (Nd4jPointer)z4.specialShapeInfo()}; - - Nd4jLong iArgs[] = {0}; - - sd::ops::unstack op; - - auto hash = op.getOpHash(); - auto status = execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 5, nullptr, 0, iArgs, 1, nullptr, 0, false); - - NDArray::registerSpecialUse({&z0, &z1, &z2, &z3, &z4}, {&x}); - ASSERT_EQ(Status::OK(), status); + auto x = NDArrayFactory::create('c', {5, 5}); + x.linspace(1.0); + auto z0 = NDArrayFactory::create('c', {5}); + auto z1 = NDArrayFactory::create('c', {5}); + auto z2 = NDArrayFactory::create('c', {5}); + auto z3 = NDArrayFactory::create('c', {5}); + auto z4 = NDArrayFactory::create('c', {5}); + + NDArray::prepareSpecialUse({&z0, &z1, &z2, &z3, &z4}, {&x}); + + Nd4jPointer ptrsInBuffer[] = {reinterpret_cast(x.buffer()), + x.specialBuffer()}; + Nd4jPointer ptrsInShapes[] = {(Nd4jPointer)x.shapeInfo(), + (Nd4jPointer)x.specialShapeInfo()}; + + Nd4jPointer ptrsOutBuffers[] = {z0.buffer(), z1.buffer(), + z2.buffer(), z3.buffer(), + z4.buffer(), z0.specialBuffer(), + z1.specialBuffer(), z2.specialBuffer(), + z3.specialBuffer(), z4.specialBuffer()}; + Nd4jPointer ptrsOutShapes[] = { + (Nd4jPointer)z0.shapeInfo(), (Nd4jPointer)z1.shapeInfo(), + (Nd4jPointer)z2.shapeInfo(), (Nd4jPointer)z3.shapeInfo(), + (Nd4jPointer)z4.shapeInfo(), (Nd4jPointer)z0.specialShapeInfo(), + (Nd4jPointer)z1.specialShapeInfo(), (Nd4jPointer)z2.specialShapeInfo(), + (Nd4jPointer)z3.specialShapeInfo(), (Nd4jPointer)z4.specialShapeInfo()}; + + Nd4jLong iArgs[] = {0}; + + sd::ops::unstack op; + + auto hash = op.getOpHash(); + auto status = + execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, + ptrsOutShapes, 5, nullptr, 0, iArgs, 1, nullptr, 0, false); + + NDArray::registerSpecialUse({&z0, &z1, &z2, &z3, &z4}, {&x}); + ASSERT_EQ(Status::OK(), status); } TEST_F(JavaInteropTests, Test_AveragePooling_FF_TF_float) { - - auto input = NDArrayFactory::create('c', {4, 10, 10, 3}, {9.37125111f, 2.20166993f,2.91434479f,5.43639755f,-2.10573769f, 4.08528662f,5.86908436f,-4.46203756f,2.21057916f,5.35849190f,0.01394637f, 4.40566349f, 7.07982206f, -0.09633455f, 2.42429352f, 3.97301817f, -1.89553940f, 1.99690318f, 6.33141708f, 0.55401880f, 1.70707977f, 5.55204201f, -0.03513752f, 1.60011971f, 2.62700319f, -2.74582434f, 3.06697464f, 1.06277943f, -1.16075921f, -0.78095782f, 9.72352791f, -1.22686064f, 1.99644792f, 7.35571337f, 1.40607321f, 0.11390255f, 9.53334427f, 2.28303599f, -1.66728830f, 6.16678810f, -0.04532295f, -1.97708666f, 9.74906158f, 1.46223176f, -1.46734393f, 4.30761862f, -1.23790228f, 1.24823606f, 6.13938427f, -3.83689475f, -1.19625473f, 7.91535568f, 6.05868721f, -3.22946382f, 8.81633949f, -0.19967777f, 0.66053957f, 2.30919123f, 0.74543846f, -0.39347672f, 11.11058044f, 0.53720862f, 1.52645731f, 5.70012379f, -1.15213466f, 1.16451406f, 7.00526333f, 1.57362783f, -2.44384766f, 5.54213285f, -1.98828590f, -0.70483637f, 7.88281822f, -3.59875536f, 0.80745387f, 13.41578484f, -1.55507684f, -0.65855008f, 9.32583523f, -0.14544789f, 0.73436141f, 3.61176538f, -1.71268058f, -2.58490300f, 9.09280205f, -3.27405524f, -2.04569697f, 4.44761324f, -0.62955856f, -2.61917663f, 8.04890442f, 0.54579324f, 0.85929775f, 9.82259560f, -1.93825579f, 0.77703512f, 4.67090321f, -4.79267597f, -2.38906908f, 9.31265545f, 0.96026313f, -1.14109385f, 11.54231834f, -0.01417295f, -0.39500344f, 8.49191666f, 0.55300158f, 2.79490185f, 6.92466164f, 1.72254205f, 2.82222271f, 8.83112717f, 2.95033407f, 2.18054962f, 6.73509789f, -2.22272944f, 0.51127720f, -1.04563558f, 2.15747333f, -2.30959272f, 9.55441570f, 1.50396204f, 1.77370787f, 7.38146257f, -1.79076433f, 3.20961165f, 7.18864202f, 2.91217351f, 0.43018937f, 7.11078024f, -1.17386127f, -0.16817921f, 6.12327290f, -2.82205725f, 3.30696845f, 13.51291752f, -1.30856836f, -2.38332748f, 11.09487438f, -1.47190213f, -0.53050828f, 4.38285351f, -5.07309771f, 1.50714362f, 5.72274446f, -2.85825086f, -0.89673209f, 3.73791552f, -0.67708802f, -4.13149452f, -0.00671843f, -0.26566532f, 0.32961160f, 7.14501762f, -1.41608179f, -4.96590328f, 12.26205540f, -0.65158135f, -0.88641000f, 6.95777559f, -0.79058206f, -0.10260171f, 7.87169170f, 1.35921454f, 1.11759663f, 5.46187401f, -2.57214499f, 2.48484039f, 4.04043484f, -2.07137156f, -1.42709637f, 9.25487137f, -0.12605135f, -2.66949964f, 2.89412403f, 0.74451172f, -2.96250391f, 3.99258423f, 0.27084303f, 0.32213116f, 5.42332172f, -0.44414216f, 1.70881832f, 6.69346905f, 0.53058422f, -4.73146200f, 4.22051668f, 2.24834967f, 0.66996074f, 4.30173683f, 0.11849818f, -4.07520294f, 8.27318478f, -2.54398274f, -2.86705542f, 10.11775303f, -0.99382895f, 0.65881538f, 7.93556786f, -1.27934420f, -1.69343162f, 9.68042564f, -1.02609646f, -1.18189347f, 5.75370646f, -1.67888868f, -4.48871994f, 4.79537392f, -0.79212248f, -0.19855022f, 6.15060997f, -0.01081491f, 3.64454579f, 10.82562447f, 1.58859253f, -2.65847278f, 8.60093212f, -1.59196103f, 0.07635692f, 11.76175690f, -1.17453325f, 0.10122013f, 6.86458445f, -2.18891335f, -2.74004745f, 8.07066154f, 0.71818852f, -2.03035975f, 6.31053686f, 0.51509416f, 1.39789927f, 9.43515587f, 2.04256630f, 0.13985133f, 4.65010691f, 2.40911126f, -0.36255789f, -3.06867862f, -0.45225358f, -1.56778407f, 6.05917358f, -1.09891272f, 1.77184200f, 6.46248102f, 0.96042323f, -0.24346280f, 4.63436460f, -4.69907761f, 1.25187206f, 11.46173859f, -2.21917558f, 1.28007793f, 6.92173195f, 2.11268163f, -3.47389889f, 5.08722782f, -3.03950930f, -4.17154264f, 11.30568314f, 0.80361372f, 2.53214502f, 7.18707085f, -4.49114513f, 2.85449266f, 10.14906883f, -0.31974933f, -0.84472644f, -0.52459574f, 0.12921631f, -1.81390119f, 2.76170087f, 1.03982210f, 2.91744232f, -0.29048753f, 5.87453508f, -1.53684759f, 1.85800636f, -0.91404629f, 1.28954852f, 5.11354685f, -2.47475505f, -1.33179152f, 2.58552408f, 1.37316465f, -3.32339454f, 1.54122913f, 3.24953628f, -0.29758382f, 2.82391763f, -1.51142192f, -1.22699404f, 6.75745535f, 0.65452754f, -3.29385471f, 2.06008053f, 2.53172946f, -4.23532820f, -1.53909743f, -0.07010663f, -1.42173731f, 7.29031610f, -0.18448229f, 4.59496164f, 6.73027277f, 0.73441899f, 0.14426160f, 4.14915276f, -2.97010231f, 6.05851364f, 4.95218086f, -2.39145470f, 2.40494704f, 2.10288811f, 0.53503096f, 1.44511235f, 6.66344261f, -3.05803776f, 7.21418667f, 3.30303526f, -0.24163735f, 3.47409391f, 3.64520788f, 2.15189481f, -3.11243272f, 3.62310791f, 0.37379482f, 0.40865007f, -0.83132005f, -4.78246069f, 2.07030797f, 6.51765442f, 3.16178989f, 5.06180477f, 3.78434467f, -0.96689719f, 0.35965276f, 5.89967585f, 1.40294051f, 1.11952639f, 10.59778214f, 0.26739889f, -1.61297631f, 6.24801159f, -0.93914318f, -0.57812452f, 9.92604542f, -0.73025000f, -3.38530874f, 2.45646000f, -2.47949195f, 0.51638460f, 10.65636063f, 1.97816694f, -3.00407791f, 2.66914415f, -0.81951088f, -0.23316640f, 2.40737987f, -2.70007610f, 1.51531935f, 4.08860207f, -0.27552786f, -1.31721711f, 7.11568260f, -3.33498216f, -4.02545023f, 7.22675610f, -0.81690705f, -2.52689576f, 1.04016697f, -0.79291463f, -0.34875512f, 10.00498390f, -4.24167728f, 1.46162593f, 11.82569408f, -1.70359993f, -0.30161047f, 16.44085884f, -0.82253462f, -0.09435523f, 6.13080597f, -0.20259480f, 0.68308711f, 6.15663004f, -6.61776876f, 0.33295766f, 2.55449438f, -0.17819691f, -1.14892209f, 5.56776142f, 1.99279118f, 1.33035934f, 4.45823956f, 3.34916544f, -2.59905386f, 6.16164446f, -2.03881931f, -2.45273542f, 12.46793365f, -2.22743297f, 2.83738565f, 8.48628139f, -1.39347959f, -1.30867767f, 11.08041477f, -4.00363779f, 2.09183025f, 11.30395889f, -2.20504737f, 1.37426853f, 8.98735619f, 1.04676604f, -0.72757077f, 8.28050232f, -6.70741081f, -0.65798020f, 5.68592072f, -0.60760021f, 0.35854483f, 6.26852131f, 1.94100165f, 1.32112014f, 0.80987954f, -1.74617672f, -0.25434083f, 7.16045523f, 1.58884013f, -2.64847064f, 13.14820385f, 1.21393633f, -2.47258949f, 9.41650105f, -0.79384226f, 2.48954105f, 10.95629311f, 0.47723705f, 4.02126694f, 8.02593136f, -2.20726371f, -1.18794477f, 1.50836647f, 0.93118095f, -1.73513174f, 8.85493565f, -2.99670315f, -0.79055870f, 2.39473820f, 2.05046916f, -2.38055134f, 11.82299423f, 0.15609655f, 0.68744308f, 5.66401434f, -0.69281673f, 2.09855556f, 7.74626589f, -0.34283102f, 1.00542057f, 9.95838642f, 0.80161905f, 2.33455157f, 9.80057335f, -0.93561798f, 2.56991577f, 8.29711342f, 0.94213426f, 0.44209945f, 11.70259857f, 0.92710167f, 2.60957146f, 0.24971688f, -0.86529571f, 3.78628922f, 6.80884457f, -0.68178189f, 2.21103406f, 3.18895817f, 0.60283208f, -2.92716241f, 6.72060776f, -1.06625068f, 2.56543374f, 9.97404480f, 3.58080721f, -0.94936347f, 10.16736984f, -1.38464379f, 1.18191063f, 6.66179037f, -3.56115270f, 0.32329530f, 10.90870762f, 2.20638227f, 0.19653285f, 7.34650040f, -3.63859272f, -1.03027737f, 5.98829985f, -3.66606474f, -3.89746714f, 8.63469028f, 1.22569811f, 1.63240814f, 3.74385309f, 0.58243257f, -0.56981975f, 3.69260955f, 1.00979900f, -1.44030499f, 8.57058144f, -1.10648811f, 1.20474911f, 5.43133020f, -2.14822555f, -0.07928789f, 11.25825310f, 0.19645604f, -5.49546146f, 10.41917038f, -0.68178523f, -2.99639869f, 6.50054455f, 0.46488351f, -5.42328453f, 9.09500027f, -2.82107449f, 0.05601966f, 15.34610748f, -0.06820253f, 3.86699796f, 10.73316956f, -3.04795432f, -0.14702171f, 5.64813185f, 1.44028485f, -2.47596145f, 0.07280898f, -3.03187990f, -1.35183525f, 9.35835648f, 2.72966957f, 1.88199532f, 10.36187744f, -0.22834805f, -3.26738238f, 6.92025137f, -2.34061313f, 4.77379704f, 5.28559113f, -2.96323752f, -1.76186585f, 5.94436455f, 0.38647744f, -5.73869514f, 6.76849556f, 1.40892124f, -1.19068217f, 5.37919092f, -6.65328646f, 3.62782669f, 12.34744644f, 2.44762444f, -4.19242620f, 6.14906216f, 0.08121119f, 0.61355996f, 2.69666457f, -1.88962626f, -0.55314136f, 1.84937525f, 1.56048691f, 1.17460012f, 3.75674725f, 1.06198275f, -5.74625874f, 5.41645575f, -1.28946674f, -1.51689398f, 4.32400894f, -0.05222082f, -4.83948946f, 1.80747867f, 1.63144708f, -2.73887825f, 1.63975775f, -2.02163982f, -0.16210437f, 2.93518686f, 1.14427686f, -2.83246303f, 4.79283667f, 2.69697428f, -3.12678456f, -1.19225168f, -2.37022972f, -3.09429741f, 1.94225383f, -1.13747168f, -2.55048585f, 5.40242243f, 1.12777328f, 3.43713188f, 3.62658787f, -2.16878843f, 0.30164462f, 2.97407579f, -0.07275413f, -1.31149673f, 4.70066261f, -2.01323795f, 4.85255766f, 4.59128904f, 1.68084168f, 1.60336494f, 6.58138466f, -1.04759812f, 2.69906545f, 3.55769277f, -0.74327278f, 2.65819693f, 5.39528131f, 2.11248922f, -1.06446671f, 5.24546766f, -2.43146014f, 4.58907509f, 0.06521678f, -2.24503994f, 2.45722699f, 6.94863081f, 0.35258654f, 2.83396196f, 9.92525196f, -1.12225175f, -0.34365177f, 7.19116688f, -4.39813757f, 0.46517885f, 13.22028065f, -2.57483673f, -6.37226963f, 7.58046293f, -2.74600363f, 0.42231262f, 8.04881668f, 0.17289802f, -0.53447008f, 16.55157471f, -5.63614368f, 0.39288223f, 3.37079263f, 1.26484549f, -0.12820500f, 8.46440125f, -4.39304399f, 2.97676420f, 0.65650189f, 0.83158541f, -1.11556435f, 6.32885838f, -0.36087769f, 2.80724382f, 9.90292645f, 1.15936041f, 0.20947981f, 6.91249275f, -2.67404819f, 2.93782163f, 6.65656614f, -2.30828357f, 2.98214006f, 6.80611229f, -4.93821478f, -7.66555262f, 7.59763002f, -0.54159302f, 3.87403512f, 12.42607784f, 2.59284401f, -0.23375344f, 8.95293331f, -0.71807784f, 0.61873478f, 8.66713524f, 1.24289191f, -2.37835455f, 2.08071637f, -0.88315344f, -3.41891551f, 6.85245323f, 1.73007369f, 1.02169311f, 7.69170332f, -2.85411978f, 2.69790673f, 8.12906551f, -1.19351399f, -2.26442742f, 12.26104450f, -0.75579089f, -1.73274946f, 10.68729019f, 2.20655656f, -0.90522075f, 12.42165184f, -1.67929137f, 2.44851565f, 9.31565762f, -0.06645700f, 1.52762020f, 6.18427515f, -1.68882596f, 3.70261097f, 3.02252960f, -3.44125366f, -1.31575799f, 2.84617424f, -0.96849400f, -4.52356243f, 9.95027161f, 0.19966406f, -0.78874779f, 8.18595028f, -4.08300209f, 1.75126517f, 0.96418417f, -4.04913044f, -0.95200396f, 12.03637886f, -0.03041124f, 0.41642749f, 8.88267422f, -3.24985337f, -2.24919462f, 7.32566118f, 0.16964148f, -2.74123430f, 7.05264473f, -3.30191112f, 0.17163286f, 4.81851053f, -1.64463484f, -0.85933101f, 7.29276276f, 2.34066939f, -2.14860010f, 3.46148157f, -0.01782012f, 1.51504040f, 4.79304934f, 1.85281146f, -1.70663762f, 6.93470192f, -4.15440845f, -1.25983095f, 10.52491760f, 0.42930329f, -1.85146868f, 11.70042324f, -0.41704914f, 3.83796859f, 9.21148491f, -2.79719448f, 0.79470479f, 6.26926661f, -5.85230207f, 3.95105338f, 7.84790897f, -1.38680744f, -1.78099084f, 11.95235348f, -2.99841452f, -1.34507811f, 6.15714645f, -1.07552516f, -2.81228638f, 1.66234732f, -4.55166149f, -1.92601109f, 8.64634514f, -0.48158705f, 3.31595659f, 7.67371941f, 2.56964207f, 0.12107098f, 4.56467867f, -0.93541539f, 1.39432955f, 11.99714088f, 1.05353570f, -2.13099813f, 3.67617917f, 3.45895386f, 1.37365830f, 8.74344158f, -4.17585802f, 1.43908918f, 6.28764772f, 3.97346330f, -0.69144285f, 9.07983303f, -0.41635889f, -0.14965028f, 8.85469818f, 1.11306190f, 2.59440994f, 5.38982344f, -1.07948279f, 1.37252975f, 10.26984596f, -0.09318046f, 2.73104119f, 12.45902252f, -1.55446684f, -2.76124811f, 12.19395065f, -0.51846564f, 1.02764034f, 11.42673588f, -0.95940983f, -0.04781032f, 8.78379822f, -4.88957930f, 0.32534006f, 11.97696400f, -3.35108662f, 1.95104563f, 4.46915388f, -2.32061648f, 3.45230985f, 8.29983711f, 2.81034684f, -2.35529327f, 6.07801294f, -0.98105043f, -0.05359888f, 2.52291036f, -0.01986909f, -2.35321999f, 10.51954269f, 2.11145401f, 3.53506470f, 7.29093266f, 0.03721160f, -1.13496494f, 7.43886709f, -5.84201956f, 2.50796294f, 12.14647675f, 2.77490377f, -2.18896222f, 6.05641937f, 5.32617044f, 1.04221284f, 10.79106712f, -2.95749092f, -2.75414610f, 11.30037117f, -3.40654182f, -2.24673963f, 7.49126101f, 0.70811015f, -6.18003702f, 13.83951187f, -1.01204085f, 1.36298490f, -1.04451632f, 2.42435336f, -0.02346706f, -0.85528886f, 1.04731262f, 0.22192979f, 4.15708160f, 0.34933877f, 0.04814529f, 2.24107265f, 0.49676740f, -1.47752666f, 0.45040059f, -0.70471478f, -1.19759345f, 0.21711677f, 0.88461423f, -2.76830935f, 5.52066898f, 1.97664857f, -1.75381601f, 3.45877838f, 1.52617192f, -1.61350942f, 0.85337949f, 1.97610760f, -3.40310287f, 3.40319014f, -3.38691044f, -0.71319139f, 1.65463758f, -0.60680127f, -1.80700517f, 8.02592373f, 2.59627104f, 2.65895891f, 5.93043184f, -4.48425817f, 3.92670918f, 4.19496679f, -2.28286791f, 6.41634607f, 5.72330523f, 1.16269672f, -0.28753027f, 2.46342492f, 0.36693189f, 0.26712441f, 6.37652683f, -2.50139046f, 2.43923736f, 5.56310415f, 0.98065847f, 1.04267502f, 4.16403675f, -0.04966142f, 4.40897894f, 3.72905660f, -3.46129870f, 3.59962773f, 1.34830284f, -1.76661730f, 0.47943926f, 5.29946661f, -1.12711561f, 1.26970029f, 15.17655945f, -1.50971997f, 5.81345224f, 8.48562050f, -4.36049604f, 2.48144460f, 8.23780441f, -3.46030426f, -0.84656560f, 5.94946814f, 1.12747943f, -2.65683913f, 8.69085693f, 1.31309867f, -2.79958344f, 8.76840591f, -1.56444156f, 1.62710834f, 2.41177034f, -0.72804940f, 5.70619011f, 4.67169666f, -0.86167198f, -1.83803177f, 2.96346045f, 2.82692933f, -2.81557131f, 7.11113358f, -1.90071094f, 2.54244423f, 11.19284058f, -0.06298946f, -1.71517313f, 12.98388577f, 0.84510714f, 3.00816894f, 2.57200313f, 0.03899818f, -1.49330592f, 9.60099125f, -3.59513044f, -1.30045319f, 7.09241819f, -0.65233821f, -2.33627677f, 8.81366920f, 0.84154201f, 1.03312039f, 9.85289097f, 0.19351870f, 1.78496623f, 7.34631205f, -2.16530800f, -0.65016162f, 2.46842360f, 0.24016285f, -1.24308395f, 4.78175163f, -0.97682536f, 2.20942235f, 6.68382788f, 3.76786447f, -1.44454038f, 6.26453733f, -3.23575711f, -2.30137897f, 9.53092670f, -5.55222607f, 3.25999236f, 9.37559509f, 1.86339056f, -0.23551451f, 10.23400211f, 3.93031883f, -0.52629089f, 7.85724449f, -2.91549587f, 4.46612740f, 5.66530371f, -2.70820427f, 4.81359577f, 10.31247330f, 1.92230141f, 2.53931546f, 0.74986327f, 1.70303428f, 0.48063779f, 5.31099129f, -0.78976244f, 3.75864220f, 4.23051405f, 2.34042454f, -7.98193836f, 9.83987141f, -1.46722627f, 3.54497814f, 10.36455154f, -4.51249075f, 0.77715248f, 7.78694630f, -4.59989023f, -2.49585629f, 9.90296268f, 1.38535416f, 1.17441154f, 10.10452843f, -0.98628229f, 0.60194463f, 9.12639141f, -3.90754628f, 2.88526392f, 7.24123430f, -0.15283313f, -0.75728363f, -1.15116858f, -2.53791571f, 0.77229571f, 6.44114161f, 0.02646767f, 4.95463037f, 7.21066380f, 1.79384065f, 0.73250306f, 8.04447937f, 0.32576546f, -0.79447043f, 10.12717724f, 2.33392906f, 1.30716443f, 12.36073112f, -0.36694977f, -1.20438910f, 7.03105593f, 0.59557682f, 0.69267452f, 10.18113136f, 2.49944925f, -0.42229167f, 8.83143330f, -1.18805945f, -2.87509322f, 4.53596449f, 4.09732771f, -3.39088297f, -1.02536607f, 0.82119560f, -3.47302604f, 9.29991817f, 0.21001509f, 4.97036457f, 9.50018406f, 1.04420102f, 1.96560478f, 10.74769592f, -6.22709799f, 3.11690164f, 5.06759691f, -1.23724771f, -3.05831861f, 8.12925529f, -1.93435478f, -1.10151744f, 9.32263088f, -0.04249470f, -5.98547363f, 10.49398136f, 0.26400441f, -0.78915191f, 13.28219604f, 2.99276900f, 0.74853164f, 2.49364305f, -3.43529654f, 4.05278301f, 2.13498688f, -2.35444307f, -0.79900265f, 4.66968822f, -0.31095147f, 3.60674143f, 12.37222099f, -0.07855003f, -3.30292702f, 12.15215874f, 0.60886210f, 2.87075138f, 7.75271845f, 0.38044083f, 3.34402204f, 6.40583277f, -0.87888050f, 0.67438459f, 6.91080809f, 1.98332930f, -0.08303714f, 8.08630371f, -0.16772588f, -2.74058914f, 7.17253590f, -2.69122696f, 1.48173678f, 8.99470139f, -1.43302310f, -0.88651133f, 2.66944790f, -0.29186964f, 2.00838661f, 5.09587479f, -0.76676071f, -2.88322186f, 8.31110573f, -0.14550979f, -1.37726915f, 10.28355122f, -1.60575438f, -0.04118848f, 9.97510815f, 0.14440438f, -3.24632120f, 9.00034523f, 4.14319563f, -1.31023729f, 7.16950464f, -0.70428526f, 2.01559544f, 7.26155043f, 2.40816474f, 2.09847403f, 7.31264496f, -0.75401551f, 2.13392544f, 7.03648758f, 1.04036045f, -1.15636516f, 1.09634531f, -0.06340861f, -0.58107805f, -0.65623116f, 1.18972754f, -0.80717683f, 1.40118241f, -0.61932516f, -3.60596156f, 1.59904599f, -2.23774099f, -1.13721037f, 3.89620137f, -0.09115922f, -7.51356888f, 2.36975193f, -1.42520905f, -2.34173775f, 3.33830214f, -2.74016523f, -3.04115510f, 6.00119495f, -1.36084354f, -2.45065260f, 4.56992292f, -3.02825928f, -3.74182844f, 5.11069250f, -0.91531068f, -2.31385994f, 1.83399653f, 3.39370203f, -3.60886002f}); - auto z = NDArrayFactory::create('c', {4, 4, 4, 3}); - auto exp = NDArrayFactory::create('c', {4, 4, 4, 3}, {7.97172260f, 0.06878620f, 2.27749538f, 7.29276514f, -0.14074677f, 0.65480286f, 5.70313978f, -0.06546132f, 0.35443667f, 3.70382833f, -0.84020567f, 0.63826996f, 8.60301399f, -0.38236514f, 1.55177069f, 7.37542057f, -0.99374938f, -0.29971302f, 8.84352493f, -0.67121059f, 0.43132120f, 4.78175592f, -1.25070143f, -1.91523600f, 6.03855371f, -0.00292124f, -1.11214364f, 7.90158176f, -0.57949901f, -0.96735370f, 7.81192017f, -0.53255427f, -0.48009714f, 3.16953635f, 0.08353355f, -1.54299748f, 3.74821687f, 1.69396687f, 0.72724354f, 5.42915201f, -1.13686812f, -0.71793109f, 5.78376389f, -0.72239977f, -0.60055625f, 2.53636408f, 0.56777251f, -2.07892323f, 6.08064651f, 0.68620735f, 2.54017019f, 5.65828180f, -0.68255502f, 1.47283304f, 6.10842514f, -0.39655915f, 0.28380761f, 1.96707797f, -1.98206317f, 0.94027776f, 4.71811438f, 0.32104525f, -0.92409706f, 8.34588146f, -1.05581069f, -0.55217457f, 9.58440876f, -0.96549922f, 0.45820439f, 5.65453672f, -2.50953507f, -0.71441835f, 8.03059578f, -0.21281289f, 0.92125505f, 9.26900673f, -0.35963219f, -0.70039093f, 8.59924412f, -1.22358346f, 0.81318003f, 3.85920119f, -0.01305223f, -1.09234154f, 6.33158875f, 1.28094780f, -1.48926139f, 4.94969177f, -0.77126902f, -1.97033751f, 5.64381838f, -0.16285487f, -1.31277227f, 2.39893222f, -1.32902908f, -1.39609122f, 6.47572327f, -0.45267010f, 1.55727172f, 6.70965624f, -1.68735468f, -0.05672536f, 7.25092363f, -0.64613032f, 0.67050058f, 3.60789680f, -2.05948973f, 2.22687531f, 8.15202713f, -0.70148355f, 1.28314006f, 8.14842319f, -1.88807654f, -1.04808438f, 8.45500565f, -0.76425624f, 0.94542569f, 4.56179953f, -0.28786001f, -2.04502511f, 8.46278095f, -0.31019822f, 0.07339200f, 9.34214592f, -0.61948007f, 0.52481830f, 8.32515621f, -1.52418160f, 0.49678251f, 5.11082315f, -1.09908783f, -0.52969611f, 5.27806664f, 0.88632923f, 0.66754371f, 4.75839233f, 0.48928693f, -0.68036932f, 6.56925392f, -0.02949905f, -2.99189186f, 4.46320581f, -0.64534980f, -0.29516968f, 8.60809517f, -1.13120568f, 3.41720533f, 5.84243155f, -1.24109328f, 0.89566326f, 5.99578333f, -0.42496428f, 2.07076764f, 3.17812920f, -0.81566459f, -0.14363396f, 6.55184317f, 0.39633346f, -0.43852386f, 8.70214558f, -2.24613595f, 0.30708700f, 8.73882294f, -0.53545928f, 1.54409575f, 4.49452257f, -0.16509305f, 0.19028664f, 8.24897003f, 0.44750381f, 2.15448594f, 8.97640514f, -0.77728152f, 0.57272542f, 9.03467560f, 0.47173575f, -1.10807717f, 3.30056310f, -0.43268481f, -0.41470885f, 3.53798294f, -0.08546703f, -2.16840744f, 6.18733406f, -0.17871059f, -2.59837723f, 5.94218683f, -1.02990067f, -0.49760687f, 3.76938033f, 0.86383581f, -1.91504073f}); - - sd::ops::avgpool2d op; - - NDArray::prepareSpecialUse({&z}, {&input}); - - Nd4jPointer ptrsInBuffer[] = {reinterpret_cast(input.buffer()), input.specialBuffer()}; - Nd4jPointer ptrsInShapes[] = {(Nd4jPointer)input.shapeInfo(), (Nd4jPointer)input.specialShapeInfo()}; - - Nd4jPointer ptrsOutBuffers[] = {reinterpret_cast(z.buffer()), z.specialBuffer()}; - Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer)z.shapeInfo(), (Nd4jPointer)z.specialShapeInfo()}; - Nd4jLong iArgs[] = {3,3, 3,3, 0,0, 1,1,1, 0,1}; - - auto hash = op.getOpHash(); - auto status = execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, iArgs, 11, nullptr, 0, false); - - NDArray::registerSpecialUse({&z}, {&input}); - ASSERT_EQ(Status::OK(), status); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto input = NDArrayFactory::create( + 'c', {4, 10, 10, 3}, + {9.37125111f, 2.20166993f, 2.91434479f, 5.43639755f, -2.10573769f, + 4.08528662f, 5.86908436f, -4.46203756f, 2.21057916f, 5.35849190f, + 0.01394637f, 4.40566349f, 7.07982206f, -0.09633455f, 2.42429352f, + 3.97301817f, -1.89553940f, 1.99690318f, 6.33141708f, 0.55401880f, + 1.70707977f, 5.55204201f, -0.03513752f, 1.60011971f, 2.62700319f, + -2.74582434f, 3.06697464f, 1.06277943f, -1.16075921f, -0.78095782f, + 9.72352791f, -1.22686064f, 1.99644792f, 7.35571337f, 1.40607321f, + 0.11390255f, 9.53334427f, 2.28303599f, -1.66728830f, 6.16678810f, + -0.04532295f, -1.97708666f, 9.74906158f, 1.46223176f, -1.46734393f, + 4.30761862f, -1.23790228f, 1.24823606f, 6.13938427f, -3.83689475f, + -1.19625473f, 7.91535568f, 6.05868721f, -3.22946382f, 8.81633949f, + -0.19967777f, 0.66053957f, 2.30919123f, 0.74543846f, -0.39347672f, + 11.11058044f, 0.53720862f, 1.52645731f, 5.70012379f, -1.15213466f, + 1.16451406f, 7.00526333f, 1.57362783f, -2.44384766f, 5.54213285f, + -1.98828590f, -0.70483637f, 7.88281822f, -3.59875536f, 0.80745387f, + 13.41578484f, -1.55507684f, -0.65855008f, 9.32583523f, -0.14544789f, + 0.73436141f, 3.61176538f, -1.71268058f, -2.58490300f, 9.09280205f, + -3.27405524f, -2.04569697f, 4.44761324f, -0.62955856f, -2.61917663f, + 8.04890442f, 0.54579324f, 0.85929775f, 9.82259560f, -1.93825579f, + 0.77703512f, 4.67090321f, -4.79267597f, -2.38906908f, 9.31265545f, + 0.96026313f, -1.14109385f, 11.54231834f, -0.01417295f, -0.39500344f, + 8.49191666f, 0.55300158f, 2.79490185f, 6.92466164f, 1.72254205f, + 2.82222271f, 8.83112717f, 2.95033407f, 2.18054962f, 6.73509789f, + -2.22272944f, 0.51127720f, -1.04563558f, 2.15747333f, -2.30959272f, + 9.55441570f, 1.50396204f, 1.77370787f, 7.38146257f, -1.79076433f, + 3.20961165f, 7.18864202f, 2.91217351f, 0.43018937f, 7.11078024f, + -1.17386127f, -0.16817921f, 6.12327290f, -2.82205725f, 3.30696845f, + 13.51291752f, -1.30856836f, -2.38332748f, 11.09487438f, -1.47190213f, + -0.53050828f, 4.38285351f, -5.07309771f, 1.50714362f, 5.72274446f, + -2.85825086f, -0.89673209f, 3.73791552f, -0.67708802f, -4.13149452f, + -0.00671843f, -0.26566532f, 0.32961160f, 7.14501762f, -1.41608179f, + -4.96590328f, 12.26205540f, -0.65158135f, -0.88641000f, 6.95777559f, + -0.79058206f, -0.10260171f, 7.87169170f, 1.35921454f, 1.11759663f, + 5.46187401f, -2.57214499f, 2.48484039f, 4.04043484f, -2.07137156f, + -1.42709637f, 9.25487137f, -0.12605135f, -2.66949964f, 2.89412403f, + 0.74451172f, -2.96250391f, 3.99258423f, 0.27084303f, 0.32213116f, + 5.42332172f, -0.44414216f, 1.70881832f, 6.69346905f, 0.53058422f, + -4.73146200f, 4.22051668f, 2.24834967f, 0.66996074f, 4.30173683f, + 0.11849818f, -4.07520294f, 8.27318478f, -2.54398274f, -2.86705542f, + 10.11775303f, -0.99382895f, 0.65881538f, 7.93556786f, -1.27934420f, + -1.69343162f, 9.68042564f, -1.02609646f, -1.18189347f, 5.75370646f, + -1.67888868f, -4.48871994f, 4.79537392f, -0.79212248f, -0.19855022f, + 6.15060997f, -0.01081491f, 3.64454579f, 10.82562447f, 1.58859253f, + -2.65847278f, 8.60093212f, -1.59196103f, 0.07635692f, 11.76175690f, + -1.17453325f, 0.10122013f, 6.86458445f, -2.18891335f, -2.74004745f, + 8.07066154f, 0.71818852f, -2.03035975f, 6.31053686f, 0.51509416f, + 1.39789927f, 9.43515587f, 2.04256630f, 0.13985133f, 4.65010691f, + 2.40911126f, -0.36255789f, -3.06867862f, -0.45225358f, -1.56778407f, + 6.05917358f, -1.09891272f, 1.77184200f, 6.46248102f, 0.96042323f, + -0.24346280f, 4.63436460f, -4.69907761f, 1.25187206f, 11.46173859f, + -2.21917558f, 1.28007793f, 6.92173195f, 2.11268163f, -3.47389889f, + 5.08722782f, -3.03950930f, -4.17154264f, 11.30568314f, 0.80361372f, + 2.53214502f, 7.18707085f, -4.49114513f, 2.85449266f, 10.14906883f, + -0.31974933f, -0.84472644f, -0.52459574f, 0.12921631f, -1.81390119f, + 2.76170087f, 1.03982210f, 2.91744232f, -0.29048753f, 5.87453508f, + -1.53684759f, 1.85800636f, -0.91404629f, 1.28954852f, 5.11354685f, + -2.47475505f, -1.33179152f, 2.58552408f, 1.37316465f, -3.32339454f, + 1.54122913f, 3.24953628f, -0.29758382f, 2.82391763f, -1.51142192f, + -1.22699404f, 6.75745535f, 0.65452754f, -3.29385471f, 2.06008053f, + 2.53172946f, -4.23532820f, -1.53909743f, -0.07010663f, -1.42173731f, + 7.29031610f, -0.18448229f, 4.59496164f, 6.73027277f, 0.73441899f, + 0.14426160f, 4.14915276f, -2.97010231f, 6.05851364f, 4.95218086f, + -2.39145470f, 2.40494704f, 2.10288811f, 0.53503096f, 1.44511235f, + 6.66344261f, -3.05803776f, 7.21418667f, 3.30303526f, -0.24163735f, + 3.47409391f, 3.64520788f, 2.15189481f, -3.11243272f, 3.62310791f, + 0.37379482f, 0.40865007f, -0.83132005f, -4.78246069f, 2.07030797f, + 6.51765442f, 3.16178989f, 5.06180477f, 3.78434467f, -0.96689719f, + 0.35965276f, 5.89967585f, 1.40294051f, 1.11952639f, 10.59778214f, + 0.26739889f, -1.61297631f, 6.24801159f, -0.93914318f, -0.57812452f, + 9.92604542f, -0.73025000f, -3.38530874f, 2.45646000f, -2.47949195f, + 0.51638460f, 10.65636063f, 1.97816694f, -3.00407791f, 2.66914415f, + -0.81951088f, -0.23316640f, 2.40737987f, -2.70007610f, 1.51531935f, + 4.08860207f, -0.27552786f, -1.31721711f, 7.11568260f, -3.33498216f, + -4.02545023f, 7.22675610f, -0.81690705f, -2.52689576f, 1.04016697f, + -0.79291463f, -0.34875512f, 10.00498390f, -4.24167728f, 1.46162593f, + 11.82569408f, -1.70359993f, -0.30161047f, 16.44085884f, -0.82253462f, + -0.09435523f, 6.13080597f, -0.20259480f, 0.68308711f, 6.15663004f, + -6.61776876f, 0.33295766f, 2.55449438f, -0.17819691f, -1.14892209f, + 5.56776142f, 1.99279118f, 1.33035934f, 4.45823956f, 3.34916544f, + -2.59905386f, 6.16164446f, -2.03881931f, -2.45273542f, 12.46793365f, + -2.22743297f, 2.83738565f, 8.48628139f, -1.39347959f, -1.30867767f, + 11.08041477f, -4.00363779f, 2.09183025f, 11.30395889f, -2.20504737f, + 1.37426853f, 8.98735619f, 1.04676604f, -0.72757077f, 8.28050232f, + -6.70741081f, -0.65798020f, 5.68592072f, -0.60760021f, 0.35854483f, + 6.26852131f, 1.94100165f, 1.32112014f, 0.80987954f, -1.74617672f, + -0.25434083f, 7.16045523f, 1.58884013f, -2.64847064f, 13.14820385f, + 1.21393633f, -2.47258949f, 9.41650105f, -0.79384226f, 2.48954105f, + 10.95629311f, 0.47723705f, 4.02126694f, 8.02593136f, -2.20726371f, + -1.18794477f, 1.50836647f, 0.93118095f, -1.73513174f, 8.85493565f, + -2.99670315f, -0.79055870f, 2.39473820f, 2.05046916f, -2.38055134f, + 11.82299423f, 0.15609655f, 0.68744308f, 5.66401434f, -0.69281673f, + 2.09855556f, 7.74626589f, -0.34283102f, 1.00542057f, 9.95838642f, + 0.80161905f, 2.33455157f, 9.80057335f, -0.93561798f, 2.56991577f, + 8.29711342f, 0.94213426f, 0.44209945f, 11.70259857f, 0.92710167f, + 2.60957146f, 0.24971688f, -0.86529571f, 3.78628922f, 6.80884457f, + -0.68178189f, 2.21103406f, 3.18895817f, 0.60283208f, -2.92716241f, + 6.72060776f, -1.06625068f, 2.56543374f, 9.97404480f, 3.58080721f, + -0.94936347f, 10.16736984f, -1.38464379f, 1.18191063f, 6.66179037f, + -3.56115270f, 0.32329530f, 10.90870762f, 2.20638227f, 0.19653285f, + 7.34650040f, -3.63859272f, -1.03027737f, 5.98829985f, -3.66606474f, + -3.89746714f, 8.63469028f, 1.22569811f, 1.63240814f, 3.74385309f, + 0.58243257f, -0.56981975f, 3.69260955f, 1.00979900f, -1.44030499f, + 8.57058144f, -1.10648811f, 1.20474911f, 5.43133020f, -2.14822555f, + -0.07928789f, 11.25825310f, 0.19645604f, -5.49546146f, 10.41917038f, + -0.68178523f, -2.99639869f, 6.50054455f, 0.46488351f, -5.42328453f, + 9.09500027f, -2.82107449f, 0.05601966f, 15.34610748f, -0.06820253f, + 3.86699796f, 10.73316956f, -3.04795432f, -0.14702171f, 5.64813185f, + 1.44028485f, -2.47596145f, 0.07280898f, -3.03187990f, -1.35183525f, + 9.35835648f, 2.72966957f, 1.88199532f, 10.36187744f, -0.22834805f, + -3.26738238f, 6.92025137f, -2.34061313f, 4.77379704f, 5.28559113f, + -2.96323752f, -1.76186585f, 5.94436455f, 0.38647744f, -5.73869514f, + 6.76849556f, 1.40892124f, -1.19068217f, 5.37919092f, -6.65328646f, + 3.62782669f, 12.34744644f, 2.44762444f, -4.19242620f, 6.14906216f, + 0.08121119f, 0.61355996f, 2.69666457f, -1.88962626f, -0.55314136f, + 1.84937525f, 1.56048691f, 1.17460012f, 3.75674725f, 1.06198275f, + -5.74625874f, 5.41645575f, -1.28946674f, -1.51689398f, 4.32400894f, + -0.05222082f, -4.83948946f, 1.80747867f, 1.63144708f, -2.73887825f, + 1.63975775f, -2.02163982f, -0.16210437f, 2.93518686f, 1.14427686f, + -2.83246303f, 4.79283667f, 2.69697428f, -3.12678456f, -1.19225168f, + -2.37022972f, -3.09429741f, 1.94225383f, -1.13747168f, -2.55048585f, + 5.40242243f, 1.12777328f, 3.43713188f, 3.62658787f, -2.16878843f, + 0.30164462f, 2.97407579f, -0.07275413f, -1.31149673f, 4.70066261f, + -2.01323795f, 4.85255766f, 4.59128904f, 1.68084168f, 1.60336494f, + 6.58138466f, -1.04759812f, 2.69906545f, 3.55769277f, -0.74327278f, + 2.65819693f, 5.39528131f, 2.11248922f, -1.06446671f, 5.24546766f, + -2.43146014f, 4.58907509f, 0.06521678f, -2.24503994f, 2.45722699f, + 6.94863081f, 0.35258654f, 2.83396196f, 9.92525196f, -1.12225175f, + -0.34365177f, 7.19116688f, -4.39813757f, 0.46517885f, 13.22028065f, + -2.57483673f, -6.37226963f, 7.58046293f, -2.74600363f, 0.42231262f, + 8.04881668f, 0.17289802f, -0.53447008f, 16.55157471f, -5.63614368f, + 0.39288223f, 3.37079263f, 1.26484549f, -0.12820500f, 8.46440125f, + -4.39304399f, 2.97676420f, 0.65650189f, 0.83158541f, -1.11556435f, + 6.32885838f, -0.36087769f, 2.80724382f, 9.90292645f, 1.15936041f, + 0.20947981f, 6.91249275f, -2.67404819f, 2.93782163f, 6.65656614f, + -2.30828357f, 2.98214006f, 6.80611229f, -4.93821478f, -7.66555262f, + 7.59763002f, -0.54159302f, 3.87403512f, 12.42607784f, 2.59284401f, + -0.23375344f, 8.95293331f, -0.71807784f, 0.61873478f, 8.66713524f, + 1.24289191f, -2.37835455f, 2.08071637f, -0.88315344f, -3.41891551f, + 6.85245323f, 1.73007369f, 1.02169311f, 7.69170332f, -2.85411978f, + 2.69790673f, 8.12906551f, -1.19351399f, -2.26442742f, 12.26104450f, + -0.75579089f, -1.73274946f, 10.68729019f, 2.20655656f, -0.90522075f, + 12.42165184f, -1.67929137f, 2.44851565f, 9.31565762f, -0.06645700f, + 1.52762020f, 6.18427515f, -1.68882596f, 3.70261097f, 3.02252960f, + -3.44125366f, -1.31575799f, 2.84617424f, -0.96849400f, -4.52356243f, + 9.95027161f, 0.19966406f, -0.78874779f, 8.18595028f, -4.08300209f, + 1.75126517f, 0.96418417f, -4.04913044f, -0.95200396f, 12.03637886f, + -0.03041124f, 0.41642749f, 8.88267422f, -3.24985337f, -2.24919462f, + 7.32566118f, 0.16964148f, -2.74123430f, 7.05264473f, -3.30191112f, + 0.17163286f, 4.81851053f, -1.64463484f, -0.85933101f, 7.29276276f, + 2.34066939f, -2.14860010f, 3.46148157f, -0.01782012f, 1.51504040f, + 4.79304934f, 1.85281146f, -1.70663762f, 6.93470192f, -4.15440845f, + -1.25983095f, 10.52491760f, 0.42930329f, -1.85146868f, 11.70042324f, + -0.41704914f, 3.83796859f, 9.21148491f, -2.79719448f, 0.79470479f, + 6.26926661f, -5.85230207f, 3.95105338f, 7.84790897f, -1.38680744f, + -1.78099084f, 11.95235348f, -2.99841452f, -1.34507811f, 6.15714645f, + -1.07552516f, -2.81228638f, 1.66234732f, -4.55166149f, -1.92601109f, + 8.64634514f, -0.48158705f, 3.31595659f, 7.67371941f, 2.56964207f, + 0.12107098f, 4.56467867f, -0.93541539f, 1.39432955f, 11.99714088f, + 1.05353570f, -2.13099813f, 3.67617917f, 3.45895386f, 1.37365830f, + 8.74344158f, -4.17585802f, 1.43908918f, 6.28764772f, 3.97346330f, + -0.69144285f, 9.07983303f, -0.41635889f, -0.14965028f, 8.85469818f, + 1.11306190f, 2.59440994f, 5.38982344f, -1.07948279f, 1.37252975f, + 10.26984596f, -0.09318046f, 2.73104119f, 12.45902252f, -1.55446684f, + -2.76124811f, 12.19395065f, -0.51846564f, 1.02764034f, 11.42673588f, + -0.95940983f, -0.04781032f, 8.78379822f, -4.88957930f, 0.32534006f, + 11.97696400f, -3.35108662f, 1.95104563f, 4.46915388f, -2.32061648f, + 3.45230985f, 8.29983711f, 2.81034684f, -2.35529327f, 6.07801294f, + -0.98105043f, -0.05359888f, 2.52291036f, -0.01986909f, -2.35321999f, + 10.51954269f, 2.11145401f, 3.53506470f, 7.29093266f, 0.03721160f, + -1.13496494f, 7.43886709f, -5.84201956f, 2.50796294f, 12.14647675f, + 2.77490377f, -2.18896222f, 6.05641937f, 5.32617044f, 1.04221284f, + 10.79106712f, -2.95749092f, -2.75414610f, 11.30037117f, -3.40654182f, + -2.24673963f, 7.49126101f, 0.70811015f, -6.18003702f, 13.83951187f, + -1.01204085f, 1.36298490f, -1.04451632f, 2.42435336f, -0.02346706f, + -0.85528886f, 1.04731262f, 0.22192979f, 4.15708160f, 0.34933877f, + 0.04814529f, 2.24107265f, 0.49676740f, -1.47752666f, 0.45040059f, + -0.70471478f, -1.19759345f, 0.21711677f, 0.88461423f, -2.76830935f, + 5.52066898f, 1.97664857f, -1.75381601f, 3.45877838f, 1.52617192f, + -1.61350942f, 0.85337949f, 1.97610760f, -3.40310287f, 3.40319014f, + -3.38691044f, -0.71319139f, 1.65463758f, -0.60680127f, -1.80700517f, + 8.02592373f, 2.59627104f, 2.65895891f, 5.93043184f, -4.48425817f, + 3.92670918f, 4.19496679f, -2.28286791f, 6.41634607f, 5.72330523f, + 1.16269672f, -0.28753027f, 2.46342492f, 0.36693189f, 0.26712441f, + 6.37652683f, -2.50139046f, 2.43923736f, 5.56310415f, 0.98065847f, + 1.04267502f, 4.16403675f, -0.04966142f, 4.40897894f, 3.72905660f, + -3.46129870f, 3.59962773f, 1.34830284f, -1.76661730f, 0.47943926f, + 5.29946661f, -1.12711561f, 1.26970029f, 15.17655945f, -1.50971997f, + 5.81345224f, 8.48562050f, -4.36049604f, 2.48144460f, 8.23780441f, + -3.46030426f, -0.84656560f, 5.94946814f, 1.12747943f, -2.65683913f, + 8.69085693f, 1.31309867f, -2.79958344f, 8.76840591f, -1.56444156f, + 1.62710834f, 2.41177034f, -0.72804940f, 5.70619011f, 4.67169666f, + -0.86167198f, -1.83803177f, 2.96346045f, 2.82692933f, -2.81557131f, + 7.11113358f, -1.90071094f, 2.54244423f, 11.19284058f, -0.06298946f, + -1.71517313f, 12.98388577f, 0.84510714f, 3.00816894f, 2.57200313f, + 0.03899818f, -1.49330592f, 9.60099125f, -3.59513044f, -1.30045319f, + 7.09241819f, -0.65233821f, -2.33627677f, 8.81366920f, 0.84154201f, + 1.03312039f, 9.85289097f, 0.19351870f, 1.78496623f, 7.34631205f, + -2.16530800f, -0.65016162f, 2.46842360f, 0.24016285f, -1.24308395f, + 4.78175163f, -0.97682536f, 2.20942235f, 6.68382788f, 3.76786447f, + -1.44454038f, 6.26453733f, -3.23575711f, -2.30137897f, 9.53092670f, + -5.55222607f, 3.25999236f, 9.37559509f, 1.86339056f, -0.23551451f, + 10.23400211f, 3.93031883f, -0.52629089f, 7.85724449f, -2.91549587f, + 4.46612740f, 5.66530371f, -2.70820427f, 4.81359577f, 10.31247330f, + 1.92230141f, 2.53931546f, 0.74986327f, 1.70303428f, 0.48063779f, + 5.31099129f, -0.78976244f, 3.75864220f, 4.23051405f, 2.34042454f, + -7.98193836f, 9.83987141f, -1.46722627f, 3.54497814f, 10.36455154f, + -4.51249075f, 0.77715248f, 7.78694630f, -4.59989023f, -2.49585629f, + 9.90296268f, 1.38535416f, 1.17441154f, 10.10452843f, -0.98628229f, + 0.60194463f, 9.12639141f, -3.90754628f, 2.88526392f, 7.24123430f, + -0.15283313f, -0.75728363f, -1.15116858f, -2.53791571f, 0.77229571f, + 6.44114161f, 0.02646767f, 4.95463037f, 7.21066380f, 1.79384065f, + 0.73250306f, 8.04447937f, 0.32576546f, -0.79447043f, 10.12717724f, + 2.33392906f, 1.30716443f, 12.36073112f, -0.36694977f, -1.20438910f, + 7.03105593f, 0.59557682f, 0.69267452f, 10.18113136f, 2.49944925f, + -0.42229167f, 8.83143330f, -1.18805945f, -2.87509322f, 4.53596449f, + 4.09732771f, -3.39088297f, -1.02536607f, 0.82119560f, -3.47302604f, + 9.29991817f, 0.21001509f, 4.97036457f, 9.50018406f, 1.04420102f, + 1.96560478f, 10.74769592f, -6.22709799f, 3.11690164f, 5.06759691f, + -1.23724771f, -3.05831861f, 8.12925529f, -1.93435478f, -1.10151744f, + 9.32263088f, -0.04249470f, -5.98547363f, 10.49398136f, 0.26400441f, + -0.78915191f, 13.28219604f, 2.99276900f, 0.74853164f, 2.49364305f, + -3.43529654f, 4.05278301f, 2.13498688f, -2.35444307f, -0.79900265f, + 4.66968822f, -0.31095147f, 3.60674143f, 12.37222099f, -0.07855003f, + -3.30292702f, 12.15215874f, 0.60886210f, 2.87075138f, 7.75271845f, + 0.38044083f, 3.34402204f, 6.40583277f, -0.87888050f, 0.67438459f, + 6.91080809f, 1.98332930f, -0.08303714f, 8.08630371f, -0.16772588f, + -2.74058914f, 7.17253590f, -2.69122696f, 1.48173678f, 8.99470139f, + -1.43302310f, -0.88651133f, 2.66944790f, -0.29186964f, 2.00838661f, + 5.09587479f, -0.76676071f, -2.88322186f, 8.31110573f, -0.14550979f, + -1.37726915f, 10.28355122f, -1.60575438f, -0.04118848f, 9.97510815f, + 0.14440438f, -3.24632120f, 9.00034523f, 4.14319563f, -1.31023729f, + 7.16950464f, -0.70428526f, 2.01559544f, 7.26155043f, 2.40816474f, + 2.09847403f, 7.31264496f, -0.75401551f, 2.13392544f, 7.03648758f, + 1.04036045f, -1.15636516f, 1.09634531f, -0.06340861f, -0.58107805f, + -0.65623116f, 1.18972754f, -0.80717683f, 1.40118241f, -0.61932516f, + -3.60596156f, 1.59904599f, -2.23774099f, -1.13721037f, 3.89620137f, + -0.09115922f, -7.51356888f, 2.36975193f, -1.42520905f, -2.34173775f, + 3.33830214f, -2.74016523f, -3.04115510f, 6.00119495f, -1.36084354f, + -2.45065260f, 4.56992292f, -3.02825928f, -3.74182844f, 5.11069250f, + -0.91531068f, -2.31385994f, 1.83399653f, 3.39370203f, -3.60886002f}); + auto z = NDArrayFactory::create('c', {4, 4, 4, 3}); + auto exp = NDArrayFactory::create( + 'c', {4, 4, 4, 3}, + {7.97172260f, 0.06878620f, 2.27749538f, 7.29276514f, -0.14074677f, + 0.65480286f, 5.70313978f, -0.06546132f, 0.35443667f, 3.70382833f, + -0.84020567f, 0.63826996f, 8.60301399f, -0.38236514f, 1.55177069f, + 7.37542057f, -0.99374938f, -0.29971302f, 8.84352493f, -0.67121059f, + 0.43132120f, 4.78175592f, -1.25070143f, -1.91523600f, 6.03855371f, + -0.00292124f, -1.11214364f, 7.90158176f, -0.57949901f, -0.96735370f, + 7.81192017f, -0.53255427f, -0.48009714f, 3.16953635f, 0.08353355f, + -1.54299748f, 3.74821687f, 1.69396687f, 0.72724354f, 5.42915201f, + -1.13686812f, -0.71793109f, 5.78376389f, -0.72239977f, -0.60055625f, + 2.53636408f, 0.56777251f, -2.07892323f, 6.08064651f, 0.68620735f, + 2.54017019f, 5.65828180f, -0.68255502f, 1.47283304f, 6.10842514f, + -0.39655915f, 0.28380761f, 1.96707797f, -1.98206317f, 0.94027776f, + 4.71811438f, 0.32104525f, -0.92409706f, 8.34588146f, -1.05581069f, + -0.55217457f, 9.58440876f, -0.96549922f, 0.45820439f, 5.65453672f, + -2.50953507f, -0.71441835f, 8.03059578f, -0.21281289f, 0.92125505f, + 9.26900673f, -0.35963219f, -0.70039093f, 8.59924412f, -1.22358346f, + 0.81318003f, 3.85920119f, -0.01305223f, -1.09234154f, 6.33158875f, + 1.28094780f, -1.48926139f, 4.94969177f, -0.77126902f, -1.97033751f, + 5.64381838f, -0.16285487f, -1.31277227f, 2.39893222f, -1.32902908f, + -1.39609122f, 6.47572327f, -0.45267010f, 1.55727172f, 6.70965624f, + -1.68735468f, -0.05672536f, 7.25092363f, -0.64613032f, 0.67050058f, + 3.60789680f, -2.05948973f, 2.22687531f, 8.15202713f, -0.70148355f, + 1.28314006f, 8.14842319f, -1.88807654f, -1.04808438f, 8.45500565f, + -0.76425624f, 0.94542569f, 4.56179953f, -0.28786001f, -2.04502511f, + 8.46278095f, -0.31019822f, 0.07339200f, 9.34214592f, -0.61948007f, + 0.52481830f, 8.32515621f, -1.52418160f, 0.49678251f, 5.11082315f, + -1.09908783f, -0.52969611f, 5.27806664f, 0.88632923f, 0.66754371f, + 4.75839233f, 0.48928693f, -0.68036932f, 6.56925392f, -0.02949905f, + -2.99189186f, 4.46320581f, -0.64534980f, -0.29516968f, 8.60809517f, + -1.13120568f, 3.41720533f, 5.84243155f, -1.24109328f, 0.89566326f, + 5.99578333f, -0.42496428f, 2.07076764f, 3.17812920f, -0.81566459f, + -0.14363396f, 6.55184317f, 0.39633346f, -0.43852386f, 8.70214558f, + -2.24613595f, 0.30708700f, 8.73882294f, -0.53545928f, 1.54409575f, + 4.49452257f, -0.16509305f, 0.19028664f, 8.24897003f, 0.44750381f, + 2.15448594f, 8.97640514f, -0.77728152f, 0.57272542f, 9.03467560f, + 0.47173575f, -1.10807717f, 3.30056310f, -0.43268481f, -0.41470885f, + 3.53798294f, -0.08546703f, -2.16840744f, 6.18733406f, -0.17871059f, + -2.59837723f, 5.94218683f, -1.02990067f, -0.49760687f, 3.76938033f, + 0.86383581f, -1.91504073f}); + + sd::ops::avgpool2d op; + + NDArray::prepareSpecialUse({&z}, {&input}); + + Nd4jPointer ptrsInBuffer[] = {reinterpret_cast(input.buffer()), + input.specialBuffer()}; + Nd4jPointer ptrsInShapes[] = {(Nd4jPointer)input.shapeInfo(), + (Nd4jPointer)input.specialShapeInfo()}; + + Nd4jPointer ptrsOutBuffers[] = {reinterpret_cast(z.buffer()), + z.specialBuffer()}; + Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer)z.shapeInfo(), + (Nd4jPointer)z.specialShapeInfo()}; + Nd4jLong iArgs[] = {3, 3, 3, 3, 0, 0, 1, 1, 1, 0, 1}; + + auto hash = op.getOpHash(); + auto status = + execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, + ptrsOutShapes, 1, nullptr, 0, iArgs, 11, nullptr, 0, false); + + NDArray::registerSpecialUse({&z}, {&input}); + ASSERT_EQ(Status::OK(), status); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(JavaInteropTests, Test_Mixed_Add_1) { - if (!Environment::getInstance().isExperimentalBuild()) - return; + if (!Environment::getInstance().isExperimentalBuild()) return; - auto arrayX = NDArrayFactory::create({1, 2, 3, 4}); - auto arrayY = NDArrayFactory::create({1, 2, 3, 4}); - auto arrayZ = NDArrayFactory::create({0, 0, 0, 0}); - auto arrayE = NDArrayFactory::create({2, 4, 6, 8}); + auto arrayX = NDArrayFactory::create({1, 2, 3, 4}); + auto arrayY = NDArrayFactory::create({1, 2, 3, 4}); + auto arrayZ = NDArrayFactory::create({0, 0, 0, 0}); + auto arrayE = NDArrayFactory::create({2, 4, 6, 8}); - NDArray::prepareSpecialUse({&arrayZ}, {&arrayX, &arrayY}); + NDArray::prepareSpecialUse({&arrayZ}, {&arrayX, &arrayY}); - OpaqueDataBuffer xBuf(arrayX.dataBuffer()); - OpaqueDataBuffer yBuf(arrayY.dataBuffer()); - OpaqueDataBuffer zBuf(arrayZ.dataBuffer()); + OpaqueDataBuffer xBuf(arrayX.dataBuffer()); + OpaqueDataBuffer yBuf(arrayY.dataBuffer()); + OpaqueDataBuffer zBuf(arrayZ.dataBuffer()); - execPairwiseTransform(nullptr, pairwise::Add, - &xBuf, arrayX.shapeInfo(), arrayX.specialShapeInfo(), - &yBuf, arrayY.shapeInfo(), arrayY.specialShapeInfo(), - &zBuf, arrayZ.shapeInfo(), arrayZ.specialShapeInfo(), - nullptr); + execPairwiseTransform(nullptr, pairwise::Add, &xBuf, arrayX.shapeInfo(), + arrayX.specialShapeInfo(), &yBuf, arrayY.shapeInfo(), + arrayY.specialShapeInfo(), &zBuf, arrayZ.shapeInfo(), + arrayZ.specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({&arrayZ}, {&arrayX, &arrayY}); + NDArray::registerSpecialUse({&arrayZ}, {&arrayX, &arrayY}); - ASSERT_EQ(arrayE, arrayZ); + ASSERT_EQ(arrayE, arrayZ); } TEST_F(JavaInteropTests, Test_Add_1) { - auto x = NDArrayFactory::create('c', {5}, {1, 1, 1, 1, 1}); - auto y = NDArrayFactory::create('c', {5}, {1, 1, 1, 1, 1}); - auto e = NDArrayFactory::create('c', {5}, {2, 2, 2, 2, 2}); + auto x = NDArrayFactory::create('c', {5}, {1, 1, 1, 1, 1}); + auto y = NDArrayFactory::create('c', {5}, {1, 1, 1, 1, 1}); + auto e = NDArrayFactory::create('c', {5}, {2, 2, 2, 2, 2}); - NDArray::prepareSpecialUse({&x}, {&x, &y}); + NDArray::prepareSpecialUse({&x}, {&x, &y}); - sd::ops::add op; + sd::ops::add op; - Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) x.buffer(), y.buffer(), x.specialBuffer(), y.specialBuffer()}; - Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) x.shapeInfo(), (Nd4jPointer)y.shapeInfo(), (Nd4jPointer)x.specialShapeInfo(), (Nd4jPointer)y.specialShapeInfo()}; + Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer)x.buffer(), y.buffer(), + x.specialBuffer(), y.specialBuffer()}; + Nd4jPointer ptrsInShapes[] = { + (Nd4jPointer)x.shapeInfo(), (Nd4jPointer)y.shapeInfo(), + (Nd4jPointer)x.specialShapeInfo(), (Nd4jPointer)y.specialShapeInfo()}; - Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) x.buffer(), x.specialBuffer()}; - Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) x.shapeInfo(), (Nd4jPointer)x.specialShapeInfo()}; + Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer)x.buffer(), x.specialBuffer()}; + Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer)x.shapeInfo(), + (Nd4jPointer)x.specialShapeInfo()}; - execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false); + execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, + ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, + nullptr, 0, false); - NDArray::registerSpecialUse({&x}, {&x, &y}); + NDArray::registerSpecialUse({&x}, {&x, &y}); - ASSERT_EQ(e, x); + ASSERT_EQ(e, x); } TEST_F(JavaInteropTests, zeta_test10) { + auto x = NDArrayFactory::create( + 'c', {3, 4}, + {1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 1.01, 1.11, 1.12}); + auto q = NDArrayFactory::create( + 'c', {3, 4}, + {0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.01, 0.11, 0.12}); + auto z = NDArrayFactory::create('c', {3, 4}); - auto x = NDArrayFactory::create('c', {3, 4}, {1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 1.01, 1.11, 1.12}); - auto q = NDArrayFactory::create('c', {3, 4}, {0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.01, 0.11, 0.12}); - auto z = NDArrayFactory::create('c', {3, 4}); - - auto e = NDArrayFactory::create('c', {3, 4}, {23.014574, 12.184081, 8.275731, 6.1532226, 4.776538, 3.7945523, 3.0541048, 2.4765317, 2.0163891, 205.27448, 21.090889, 19.477398}); + auto e = NDArrayFactory::create( + 'c', {3, 4}, + {23.014574, 12.184081, 8.275731, 6.1532226, 4.776538, 3.7945523, + 3.0541048, 2.4765317, 2.0163891, 205.27448, 21.090889, 19.477398}); - sd::ops::zeta op; + sd::ops::zeta op; - NDArray::prepareSpecialUse({&z}, {&x, &q}); + NDArray::prepareSpecialUse({&z}, {&x, &q}); - Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) x.buffer(), q.buffer(), x.specialBuffer(), q.specialBuffer()}; - Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) x.shapeInfo(), (Nd4jPointer)q.shapeInfo(), (Nd4jPointer)x.specialShapeInfo(), (Nd4jPointer)q.specialShapeInfo()}; + Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer)x.buffer(), q.buffer(), + x.specialBuffer(), q.specialBuffer()}; + Nd4jPointer ptrsInShapes[] = { + (Nd4jPointer)x.shapeInfo(), (Nd4jPointer)q.shapeInfo(), + (Nd4jPointer)x.specialShapeInfo(), (Nd4jPointer)q.specialShapeInfo()}; - Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) z.buffer(), z.specialBuffer()}; - Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) z.shapeInfo(), (Nd4jPointer)z.specialShapeInfo()}; + Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer)z.buffer(), z.specialBuffer()}; + Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer)z.shapeInfo(), + (Nd4jPointer)z.specialShapeInfo()}; - execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false); + execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, + ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, + nullptr, 0, false); - NDArray::registerSpecialUse({&z}, {&x, &q}); + NDArray::registerSpecialUse({&z}, {&x, &q}); - ASSERT_EQ(e, z); + ASSERT_EQ(e, z); } TEST_F(JavaInteropTests, Test_IAMax_1) { - auto arrayX = NDArrayFactory::create({-0.24f, -0.26f, -0.07f, -0.01f}); - auto arrayZ = arrayX.indexReduceNumber(indexreduce::IndexAbsoluteMax, nullptr); - auto exp = NDArrayFactory::create(1); + auto arrayX = NDArrayFactory::create({-0.24f, -0.26f, -0.07f, -0.01f}); + auto arrayZ = + arrayX.indexReduceNumber(indexreduce::IndexAbsoluteMax, nullptr); + auto exp = NDArrayFactory::create(1); - ASSERT_EQ(exp, arrayZ); + ASSERT_EQ(exp, arrayZ); } TEST_F(JavaInteropTests, Test_Boolean_Broadcastables_1) { - auto arrayX = NDArrayFactory::create('c', {10, 10}); - auto arrayY = NDArrayFactory::create('c', {10, 10}); - - Nd4jPointer ptrsInBuffer[] = {reinterpret_cast(arrayX.buffer()), reinterpret_cast(arrayY.buffer()), arrayX.specialBuffer(), arrayY.specialBuffer()}; - Nd4jPointer ptrsInShapes[] = {(Nd4jPointer)arrayX.shapeInfo(), (Nd4jPointer)arrayY.shapeInfo(), (Nd4jPointer)arrayX.specialShapeInfo(), (Nd4jPointer)arrayY.specialShapeInfo()}; - - NDArray::prepareSpecialUse({}, {&arrayX, &arrayY}); - sd::ops::greater_equal op; - auto shapeList = calculateOutputShapes2(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, nullptr, 0, nullptr, 0, nullptr, 0, nullptr, 0); - NDArray::registerSpecialUse({}, {&arrayX, &arrayY}); - delete shapeList; + auto arrayX = NDArrayFactory::create('c', {10, 10}); + auto arrayY = NDArrayFactory::create('c', {10, 10}); + + Nd4jPointer ptrsInBuffer[] = {reinterpret_cast(arrayX.buffer()), + reinterpret_cast(arrayY.buffer()), + arrayX.specialBuffer(), arrayY.specialBuffer()}; + Nd4jPointer ptrsInShapes[] = {(Nd4jPointer)arrayX.shapeInfo(), + (Nd4jPointer)arrayY.shapeInfo(), + (Nd4jPointer)arrayX.specialShapeInfo(), + (Nd4jPointer)arrayY.specialShapeInfo()}; + + NDArray::prepareSpecialUse({}, {&arrayX, &arrayY}); + sd::ops::greater_equal op; + auto shapeList = calculateOutputShapes2(nullptr, op.getOpHash(), ptrsInBuffer, + ptrsInShapes, 2, nullptr, 0, nullptr, + 0, nullptr, 0, nullptr, 0); + NDArray::registerSpecialUse({}, {&arrayX, &arrayY}); + delete shapeList; } TEST_F(JavaInteropTests, Test_L2_Loss_3) { - auto x = NDArrayFactory::create(0.7787855863571167); - auto e = NDArrayFactory::create(0.303254); - auto z = NDArrayFactory::create(0.0); + auto x = NDArrayFactory::create(0.7787855863571167); + auto e = NDArrayFactory::create(0.303254); + auto z = NDArrayFactory::create(0.0); - NDArray::prepareSpecialUse({&z}, {&x}); + NDArray::prepareSpecialUse({&z}, {&x}); - Nd4jPointer ptrsInBuffer[] = {reinterpret_cast(x.buffer()), x.specialBuffer()}; - Nd4jPointer ptrsInShapes[] = {(Nd4jPointer)x.shapeInfo(), (Nd4jPointer)x.specialShapeInfo()}; + Nd4jPointer ptrsInBuffer[] = {reinterpret_cast(x.buffer()), + x.specialBuffer()}; + Nd4jPointer ptrsInShapes[] = {(Nd4jPointer)x.shapeInfo(), + (Nd4jPointer)x.specialShapeInfo()}; - Nd4jPointer ptrsOutBuffer[] = {reinterpret_cast(z.buffer()), (Nd4jPointer)z.specialBuffer()}; - Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer)z.shapeInfo(), (Nd4jPointer)z.specialShapeInfo()}; + Nd4jPointer ptrsOutBuffer[] = {reinterpret_cast(z.buffer()), + (Nd4jPointer)z.specialBuffer()}; + Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer)z.shapeInfo(), + (Nd4jPointer)z.specialShapeInfo()}; - sd::ops::l2_loss op; - auto status = execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffer, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false); - ASSERT_EQ(Status::OK(), status); + sd::ops::l2_loss op; + auto status = execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, + ptrsInShapes, 1, ptrsOutBuffer, ptrsOutShapes, 1, + nullptr, 0, nullptr, 0, nullptr, 0, false); + ASSERT_EQ(Status::OK(), status); - NDArray::registerSpecialUse({&z}, {&x}); + NDArray::registerSpecialUse({&z}, {&x}); - ASSERT_EQ(e, z); + ASSERT_EQ(e, z); } TEST_F(JavaInteropTests, Test_Fastpath_3) { - auto array0 = NDArrayFactory::create('c', {3, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); - auto array1 = NDArrayFactory::create('c', {3, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); - auto z = NDArrayFactory::create('c', {3, 2}); + auto array0 = NDArrayFactory::create('c', {3, 2}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); + auto array1 = NDArrayFactory::create('c', {3, 2}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); + auto z = NDArrayFactory::create('c', {3, 2}); - auto exp = NDArrayFactory::create('c', {3, 2}, {2.f, 4.f, 6.f, 8.f, 10.f, 12.f}); - Context ctx(1); + auto exp = NDArrayFactory::create('c', {3, 2}, + {2.f, 4.f, 6.f, 8.f, 10.f, 12.f}); + Context ctx(1); - NDArray::prepareSpecialUse({&z}, {&array0, &array1}); + NDArray::prepareSpecialUse({&z}, {&array0, &array1}); - ctx.setInputArray(0, array0.buffer(), array0.shapeInfo(), array0.specialBuffer(), array0.specialShapeInfo()); - ctx.setInputArray(1, array1.buffer(), array1.shapeInfo(), array1.specialBuffer(), array1.specialShapeInfo()); - ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo()); + ctx.setInputArray(0, array0.buffer(), array0.shapeInfo(), + array0.specialBuffer(), array0.specialShapeInfo()); + ctx.setInputArray(1, array1.buffer(), array1.shapeInfo(), + array1.specialBuffer(), array1.specialShapeInfo()); + ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo()); - ASSERT_EQ(2, ctx.width()); + ASSERT_EQ(2, ctx.width()); - sd::ops::add op; - execCustomOp2(nullptr, op.getOpHash(), &ctx); + sd::ops::add op; + execCustomOp2(nullptr, op.getOpHash(), &ctx); - NDArray::registerSpecialUse({&z}, {&array0, &array1}); + NDArray::registerSpecialUse({&z}, {&array0, &array1}); - ASSERT_EQ(exp, z); + ASSERT_EQ(exp, z); } TEST_F(JavaInteropTests, Test_Fastpath_4) { + auto exp = NDArrayFactory::create( + 'c', {3, 5}, {1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1}); + auto z = NDArrayFactory::create('c', {3, 5}); + Nd4jLong iArgs[] = {3, 5, 2}; - auto exp = NDArrayFactory::create('c', {3, 5}, {1,1,1,0,0, 1,1,1,1,0, 1,1,1,1,1}); - auto z = NDArrayFactory::create('c', {3, 5}); - Nd4jLong iArgs[] = {3, 5, 2}; - - - NDArray::prepareSpecialUse({&z}, {}); + NDArray::prepareSpecialUse({&z}, {}); - Context ctx(1); + Context ctx(1); - ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo()); - ctx.setIArguments(iArgs, 3); + ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo()); + ctx.setIArguments(iArgs, 3); - sd::ops::tri op; - execCustomOp2(nullptr, op.getOpHash(), &ctx); + sd::ops::tri op; + execCustomOp2(nullptr, op.getOpHash(), &ctx); - NDArray::registerSpecialUse({&z}, {}); + NDArray::registerSpecialUse({&z}, {}); - ASSERT_EQ(exp, z); + ASSERT_EQ(exp, z); } TEST_F(JavaInteropTests, Test_Fastpath_5) { - auto a = NDArrayFactory::create('c', {3, 3}); - auto b = NDArrayFactory::create('c', {3, 3}); - auto c = NDArrayFactory::create('c', {3, 3}); - a.linspace(1.0); - b.linspace(1.0); + auto a = NDArrayFactory::create('c', {3, 3}); + auto b = NDArrayFactory::create('c', {3, 3}); + auto c = NDArrayFactory::create('c', {3, 3}); + a.linspace(1.0); + b.linspace(1.0); - NDArray::prepareSpecialUse({&c}, {&b, &c}); + NDArray::prepareSpecialUse({&c}, {&b, &c}); - Context ctx(1); + Context ctx(1); - ctx.setInputArray(0, a.buffer(), a.shapeInfo(), a.specialBuffer(), a.specialShapeInfo()); - ctx.setInputArray(1, b.buffer(), b.shapeInfo(), b.specialBuffer(), b.specialShapeInfo()); - ctx.setOutputArray(0, c.buffer(), c.shapeInfo(), c.specialBuffer(), c.specialShapeInfo()); + ctx.setInputArray(0, a.buffer(), a.shapeInfo(), a.specialBuffer(), + a.specialShapeInfo()); + ctx.setInputArray(1, b.buffer(), b.shapeInfo(), b.specialBuffer(), + b.specialShapeInfo()); + ctx.setOutputArray(0, c.buffer(), c.shapeInfo(), c.specialBuffer(), + c.specialShapeInfo()); - sd::ops::matmul op; - auto status = execCustomOp2(nullptr, op.getOpHash(), &ctx); + sd::ops::matmul op; + auto status = execCustomOp2(nullptr, op.getOpHash(), &ctx); - NDArray::registerSpecialUse({&c}, {&b, &c}); + NDArray::registerSpecialUse({&c}, {&b, &c}); - ASSERT_EQ(Status::OK(), status); + ASSERT_EQ(Status::OK(), status); } TEST_F(JavaInteropTests, Test_Fastpath_6) { - auto a = NDArrayFactory::create('c', {2, 3}); - auto b = NDArrayFactory::create('c', {3, 4}); - auto gI = NDArrayFactory::create('c', {2, 4}); + auto a = NDArrayFactory::create('c', {2, 3}); + auto b = NDArrayFactory::create('c', {3, 4}); + auto gI = NDArrayFactory::create('c', {2, 4}); - auto gA = NDArrayFactory::create('c', {2, 3}); - auto gB = NDArrayFactory::create('c', {3, 4}); - a.linspace(1.0); - b.linspace(1.0); - gI.linspace(1.0); + auto gA = NDArrayFactory::create('c', {2, 3}); + auto gB = NDArrayFactory::create('c', {3, 4}); + a.linspace(1.0); + b.linspace(1.0); + gI.linspace(1.0); - NDArray::prepareSpecialUse({&gA, &gB}, {&a, &b, &gI}); + NDArray::prepareSpecialUse({&gA, &gB}, {&a, &b, &gI}); - Context ctx(1); - Nd4jLong iArgs[] = {0L, 0L, 0L}; + Context ctx(1); + Nd4jLong iArgs[] = {0L, 0L, 0L}; - ctx.setInputArray(0, a.buffer(), a.shapeInfo(), a.specialBuffer(), a.specialShapeInfo()); - ctx.setInputArray(1, b.buffer(), b.shapeInfo(), b.specialBuffer(), b.specialShapeInfo()); - ctx.setInputArray(2, gI.buffer(), gI.shapeInfo(), gI.specialBuffer(), gI.specialShapeInfo()); + ctx.setInputArray(0, a.buffer(), a.shapeInfo(), a.specialBuffer(), + a.specialShapeInfo()); + ctx.setInputArray(1, b.buffer(), b.shapeInfo(), b.specialBuffer(), + b.specialShapeInfo()); + ctx.setInputArray(2, gI.buffer(), gI.shapeInfo(), gI.specialBuffer(), + gI.specialShapeInfo()); - ctx.setOutputArray(0, gA.buffer(), gA.shapeInfo(), gA.specialBuffer(), gA.specialShapeInfo()); - ctx.setOutputArray(1, gB.buffer(), gB.shapeInfo(), gB.specialBuffer(), gB.specialShapeInfo()); + ctx.setOutputArray(0, gA.buffer(), gA.shapeInfo(), gA.specialBuffer(), + gA.specialShapeInfo()); + ctx.setOutputArray(1, gB.buffer(), gB.shapeInfo(), gB.specialBuffer(), + gB.specialShapeInfo()); - ctx.setIArguments(iArgs, 3); + ctx.setIArguments(iArgs, 3); - sd::ops::matmul_bp op; - auto status = execCustomOp2(nullptr, op.getOpHash(), &ctx); + sd::ops::matmul_bp op; + auto status = execCustomOp2(nullptr, op.getOpHash(), &ctx); - NDArray::registerSpecialUse({&gA, &gB}, {&a, &b, &gI}); + NDArray::registerSpecialUse({&gA, &gB}, {&a, &b, &gI}); - ASSERT_EQ(Status::OK(), status); + ASSERT_EQ(Status::OK(), status); } TEST_F(JavaInteropTests, Test_Fastpath_7) { - auto a = NDArrayFactory::create('c', {2}, {1.f, 2.f}); - auto b = NDArrayFactory::create(3.f); - auto z = NDArrayFactory::create('c', {3}); - auto e = NDArrayFactory::create('c', {3}, {1.f, 2.f, 3.f}); + auto a = NDArrayFactory::create('c', {2}, {1.f, 2.f}); + auto b = NDArrayFactory::create(3.f); + auto z = NDArrayFactory::create('c', {3}); + auto e = NDArrayFactory::create('c', {3}, {1.f, 2.f, 3.f}); - NDArray::prepareSpecialUse({&z}, {&a, &b}); + NDArray::prepareSpecialUse({&z}, {&a, &b}); - Context ctx(1); - Nd4jLong iArgs[] = {0L, 0L, 0L}; + Context ctx(1); + Nd4jLong iArgs[] = {0L, 0L, 0L}; - ctx.setIArguments(iArgs, 1); + ctx.setIArguments(iArgs, 1); - sd::ops::concat op; + sd::ops::concat op; - ctx.setInputArray(0, a.buffer(), a.shapeInfo(), a.specialBuffer(), a.specialShapeInfo()); - ctx.setInputArray(1, b.buffer(), b.shapeInfo(), b.specialBuffer(), b.specialShapeInfo()); + ctx.setInputArray(0, a.buffer(), a.shapeInfo(), a.specialBuffer(), + a.specialShapeInfo()); + ctx.setInputArray(1, b.buffer(), b.shapeInfo(), b.specialBuffer(), + b.specialShapeInfo()); - ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo()); + ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo()); - auto status = execCustomOp2(nullptr, op.getOpHash(), &ctx); + auto status = execCustomOp2(nullptr, op.getOpHash(), &ctx); - NDArray::registerSpecialUse({&z}, {&a, &b}); - ASSERT_EQ(Status::OK(), status); + NDArray::registerSpecialUse({&z}, {&a, &b}); + ASSERT_EQ(Status::OK(), status); - ASSERT_EQ(e, z); + ASSERT_EQ(e, z); } TEST_F(JavaInteropTests, test_bfloat16_rng) { - if (!Environment::getInstance().isCPU()) - return; + if (!Environment::getInstance().isCPU()) return; - auto z = NDArrayFactory::create('c', {10}); - RandomGenerator rng(119, 323841120L); - bfloat16 args[2] = {(bfloat16) 0.0f, (bfloat16) 1.0f}; - OpaqueDataBuffer zBuf(z.dataBuffer()); - execRandom(nullptr, sd::random::Ops::UniformDistribution, &rng, &zBuf, z.shapeInfo(), z.specialShapeInfo(), args); + auto z = NDArrayFactory::create('c', {10}); + RandomGenerator rng(119, 323841120L); + bfloat16 args[2] = {(bfloat16)0.0f, (bfloat16)1.0f}; + OpaqueDataBuffer zBuf(z.dataBuffer()); + execRandom(nullptr, sd::random::Ops::UniformDistribution, &rng, &zBuf, + z.shapeInfo(), z.specialShapeInfo(), args); - //z.printIndexedBuffer("z"); - ASSERT_TRUE(z.sumNumber().e(0) > 0); + // z.printIndexedBuffer("z"); + ASSERT_TRUE(z.sumNumber().e(0) > 0); } TEST_F(JavaInteropTests, test_ismax_view) { - auto original = NDArrayFactory::create('c', {2, 3, 40}); - auto v = original.subarray({NDIndex::all(), NDIndex::all(), NDIndex::interval(0, 40, 2)}); - v.assign(1.0); + auto original = NDArrayFactory::create('c', {2, 3, 40}); + auto v = original.subarray( + {NDIndex::all(), NDIndex::all(), NDIndex::interval(0, 40, 2)}); + v.assign(1.0); - auto e = v.like(); - auto t = e(0, {2}); - t.assign(1.0); + auto e = v.like(); + auto t = e(0, {2}); + t.assign(1.0); - auto z = v.ulike(); + auto z = v.ulike(); + Nd4jLong iArgs[] = {2L, 0L}; + Context ctx(1); + ctx.setInputArray(0, v.buffer(), v.shapeInfo(), v.specialBuffer(), + v.specialShapeInfo()); + ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo()); + ctx.setIArguments(iArgs, 1); - Nd4jLong iArgs[] = {2L, 0L}; - Context ctx(1); - ctx.setInputArray(0, v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo()); - ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo()); - ctx.setIArguments(iArgs, 1); + sd::ops::ismax op; + op.execute(&ctx); - sd::ops::ismax op; - op.execute(&ctx); - - ASSERT_EQ(e, z); + ASSERT_EQ(e, z); } TEST_F(JavaInteropTests, test_size_dtype_1) { - auto x = NDArrayFactory::create('c', {3}, {1.f, 1.f, 1.f}); - auto z = NDArrayFactory::create(0.0f); - auto e = NDArrayFactory::create(3.0f); + auto x = NDArrayFactory::create('c', {3}, {1.f, 1.f, 1.f}); + auto z = NDArrayFactory::create(0.0f); + auto e = NDArrayFactory::create(3.0f); - Context ctx(1); - ctx.setInputArray(0, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo()); - ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo()); + Context ctx(1); + ctx.setInputArray(0, x.buffer(), x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo()); + ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo()); - sd::ops::size op; - auto status = op.execute(&ctx); - ASSERT_EQ(Status::OK(), status); + sd::ops::size op; + auto status = op.execute(&ctx); + ASSERT_EQ(Status::OK(), status); - ASSERT_EQ(e, z); + ASSERT_EQ(e, z); } TEST_F(JavaInteropTests, test_expandable_array_op_1) { - auto x = NDArrayFactory::string( {2}, {"first string", "second"}); - auto d = NDArrayFactory::string(" ", sd::DataType::UTF8); + auto x = NDArrayFactory::string({2}, {"first string", "second"}); + auto d = NDArrayFactory::string(" ", sd::DataType::UTF8); - auto z0 = NDArrayFactory::create('c', {6}); - auto z1 = NDArrayFactory::string( {3}, {"", "", ""}); + auto z0 = NDArrayFactory::create('c', {6}); + auto z1 = NDArrayFactory::string({3}, {"", "", ""}); - auto exp0 = NDArrayFactory::create({0,0, 0,1, 1,0}); - auto exp1 = NDArrayFactory::string( {3}, {"first", "string", "second"}); + auto exp0 = NDArrayFactory::create({0, 0, 0, 1, 1, 0}); + auto exp1 = NDArrayFactory::string({3}, {"first", "string", "second"}); - InteropDataBuffer iz0(z0.dataBuffer()); - InteropDataBuffer iz1(z1.dataBuffer()); + InteropDataBuffer iz0(z0.dataBuffer()); + InteropDataBuffer iz1(z1.dataBuffer()); - Context ctx(1); - ctx.setInputArray(0, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo()); - ctx.setInputArray(1, d.buffer(), d.shapeInfo(), d.specialBuffer(), d.specialShapeInfo()); - ctx.setOutputArray(0, &iz0, z0.shapeInfo(), z0.specialShapeInfo()); - ctx.setOutputArray(1, &iz1, z1.shapeInfo(), z1.specialShapeInfo()); + Context ctx(1); + ctx.setInputArray(0, x.buffer(), x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo()); + ctx.setInputArray(1, d.buffer(), d.shapeInfo(), d.specialBuffer(), + d.specialShapeInfo()); + ctx.setOutputArray(0, &iz0, z0.shapeInfo(), z0.specialShapeInfo()); + ctx.setOutputArray(1, &iz1, z1.shapeInfo(), z1.specialShapeInfo()); - sd::ops::compat_string_split op; - auto status = op.execute(&ctx); - ASSERT_EQ(Status::OK(), status); + sd::ops::compat_string_split op; + auto status = op.execute(&ctx); + ASSERT_EQ(Status::OK(), status); - ASSERT_EQ(exp0, z0); - ASSERT_EQ(exp1, z1); + ASSERT_EQ(exp0, z0); + ASSERT_EQ(exp1, z1); } TEST_F(JavaInteropTests, test_workspace_backed_arrays_1) { - if (!Environment::getInstance().isCPU()) - return; + if (!Environment::getInstance().isCPU()) return; - auto x = NDArrayFactory::create('c', {4, 3, 4, 4}); - auto y = NDArrayFactory::create('c', {4, 3, 3, 3}); - auto z = NDArrayFactory::create('c', {4, 3, 4, 4}); + auto x = NDArrayFactory::create('c', {4, 3, 4, 4}); + auto y = NDArrayFactory::create('c', {4, 3, 3, 3}); + auto z = NDArrayFactory::create('c', {4, 3, 4, 4}); - double buffer[2048]; + double buffer[2048]; - InteropDataBuffer ix(0, DataType::DOUBLE, false); - InteropDataBuffer iy(0, DataType::DOUBLE, false); - InteropDataBuffer iz(0, DataType::DOUBLE, false); + InteropDataBuffer ix(0, DataType::DOUBLE, false); + InteropDataBuffer iy(0, DataType::DOUBLE, false); + InteropDataBuffer iz(0, DataType::DOUBLE, false); - // we're imitating workspace-managed array here - ix.setPrimary(buffer + 64, x.lengthOf()); - iy.setPrimary(buffer + 64 + x.lengthOf(), y.lengthOf()); - iz.setPrimary(buffer + 64 + x.lengthOf() + y.lengthOf(), z.lengthOf()); + // we're imitating workspace-managed array here + ix.setPrimary(buffer + 64, x.lengthOf()); + iy.setPrimary(buffer + 64 + x.lengthOf(), y.lengthOf()); + iz.setPrimary(buffer + 64 + x.lengthOf() + y.lengthOf(), z.lengthOf()); - Context ctx(1); - ctx.setInputArray(0, &ix, x.shapeInfo(), x.specialShapeInfo()); - ctx.setInputArray(1, &iy, y.shapeInfo(), y.specialShapeInfo()); - ctx.setOutputArray(0, &iz, z.shapeInfo(), z.specialShapeInfo()); + Context ctx(1); + ctx.setInputArray(0, &ix, x.shapeInfo(), x.specialShapeInfo()); + ctx.setInputArray(1, &iy, y.shapeInfo(), y.specialShapeInfo()); + ctx.setOutputArray(0, &iz, z.shapeInfo(), z.specialShapeInfo()); - ctx.setIArguments({2, 2, 1, 1, 0, 0, 1, 1, 0, 0, 0}); + ctx.setIArguments({2, 2, 1, 1, 0, 0, 1, 1, 0, 0, 0}); - sd::ops::maxpool2d_bp op; - auto status = op.execute(&ctx); - ASSERT_EQ(Status::OK(), status); + sd::ops::maxpool2d_bp op; + auto status = op.execute(&ctx); + ASSERT_EQ(Status::OK(), status); } TEST_F(JavaInteropTests, test_linspace_shape_1) { - if (!Environment::getInstance().isCPU()) - return; - - sd::ops::lin_space op; - double tArgs[2] = {1.0, 10.0}; - Nd4jLong iArgs = 10L; - int dArg = (int) sd::DataType::FLOAT32; - auto result = ::calculateOutputShapes2(nullptr, op.getOpHash(), nullptr, nullptr, 0, tArgs, 2, &iArgs, 1, nullptr, 0, &dArg, 1); - - ASSERT_EQ(1, result->size()); - delete result; + if (!Environment::getInstance().isCPU()) return; + + sd::ops::lin_space op; + double tArgs[2] = {1.0, 10.0}; + Nd4jLong iArgs = 10L; + int dArg = (int)sd::DataType::FLOAT32; + auto result = + ::calculateOutputShapes2(nullptr, op.getOpHash(), nullptr, nullptr, 0, + tArgs, 2, &iArgs, 1, nullptr, 0, &dArg, 1); + + ASSERT_EQ(1, result->size()); + delete result; } /* @@ -1403,8 +2173,20 @@ TEST_F(JavaInteropTests, Test_Results_Conversion_1) { } */ // TEST_F(JavaInteropTests, Test_NLP_Aggregations_1) { -// std::array syn0 = {-0.022756476f, 0.0126427775f, 0.011029151f, -0.013542821f, -0.012327666f, -0.0032439455f, -0.008405109f, -0.016651405f, 0.0015980572f, -0.007442479f, 0.019937921f, -0.016222188f, -0.016541665f, 0.013372547f, 0.006625724f, 0.0058958204f, -0.01281835f, -6.2343775E-4f, 0.0019826533f, 0.010253737f, -0.010291531f, 0.0019767822f, 0.018071089f, -0.0117441565f, 0.023176769f, 0.0032820583f, 0.0061427564f, -0.01696018f, 0.0054971874f, 0.0043818625f, 0.019323621f, 0.0036080598f, 0.024376748f, -0.0024499625f, 0.019496754f, 0.010563821f, -2.0503551E-4f, -0.0146056535f, 0.009949291f, 0.017604528f, -0.0050302492f, -0.022060446f, 0.016468976f, -0.0034482107f, 0.010270384f, -0.0063356445f, -0.019934833f, -0.02325993f, 0.016109904f, -0.0031106502f, -0.0020592287f, 0.024031803f, 0.005184144f, -0.024887865f, 0.02100272f, 3.395051E-4f, 0.018432347f, 5.673498E-4f, -0.020073576f, 0.010949242f}; -// std::array syn1; +// std::array syn0 = {-0.022756476f, 0.0126427775f, 0.011029151f, +// -0.013542821f, -0.012327666f, -0.0032439455f, -0.008405109f, +// -0.016651405f, 0.0015980572f, -0.007442479f, 0.019937921f, -0.016222188f, +// -0.016541665f, 0.013372547f, 0.006625724f, 0.0058958204f, -0.01281835f, +// -6.2343775E-4f, 0.0019826533f, 0.010253737f, -0.010291531f, +// 0.0019767822f, 0.018071089f, -0.0117441565f, 0.023176769f, 0.0032820583f, +// 0.0061427564f, -0.01696018f, 0.0054971874f, 0.0043818625f, 0.019323621f, +// 0.0036080598f, 0.024376748f, -0.0024499625f, 0.019496754f, 0.010563821f, +// -2.0503551E-4f, -0.0146056535f, 0.009949291f, 0.017604528f, +// -0.0050302492f, -0.022060446f, 0.016468976f, -0.0034482107f, +// 0.010270384f, -0.0063356445f, -0.019934833f, -0.02325993f, 0.016109904f, +// -0.0031106502f, -0.0020592287f, 0.024031803f, 0.005184144f, +// -0.024887865f, 0.02100272f, 3.395051E-4f, 0.018432347f, 5.673498E-4f, +// -0.020073576f, 0.010949242f}; std::array syn1; // std::array exp; // for (int e = 0; e < syn1.size(); e++) @@ -1412,8 +2194,8 @@ TEST_F(JavaInteropTests, Test_Results_Conversion_1) { // for (int e = 0; e < exp.size(); e++) { // auto f = static_cast(e); -// auto tmp = sd::math::nd4j_exp((f / 100000.0 * 2.0 - 1.0) * 6.0); -// exp[e] = static_cast(tmp / (tmp + 1.0)); +// auto tmp = sd::math::nd4j_exp((f / 100000.0 * 2.0 +// - 1.0) * 6.0); exp[e] = static_cast(tmp / (tmp + 1.0)); // } // auto maxTypes = 5; @@ -1432,9 +2214,9 @@ TEST_F(JavaInteropTests, Test_Results_Conversion_1) { // int indexPos = maxTypes * batchLimit; // int intArraysPos = indexPos + (maxIndexArguments * batchLimit); -// int realPos = (intArraysPos + (maxIntArrays * maxIntArraySize * batchLimit)); -// int argsPos = (realPos + ((maxRealArguments * batchLimit))) / 2; -// int shapesPos = argsPos + (maxArgs * batchLimit); +// int realPos = (intArraysPos + (maxIntArrays * maxIntArraySize * +// batchLimit)); int argsPos = (realPos + ((maxRealArguments * batchLimit))) +// / 2; int shapesPos = argsPos + (maxArgs * batchLimit); // std::vector intArray0({0, 0, 0, 0, 0}); // std::vector intArray1({1, 0, 0, 0, 0}); @@ -1442,7 +2224,8 @@ TEST_F(JavaInteropTests, Test_Results_Conversion_1) { // std::vector indexingArgs0({1, 20, 5, 0, 100000, 3, 0, 0, 0}); // std::vector indexingArgs1({0, 20, 5, 0, 100000, 3, 1, 0, 0}); -// std::vector realArgs0({0.024964055335354007f, 3.0768702268737162E18f}); +// std::vector +// realArgs0({0.024964055335354007f, 3.0768702268737162E18f}); // int argSize = 6; // int shapesSize = 0; @@ -1493,6 +2276,7 @@ TEST_F(JavaInteropTests, Test_Results_Conversion_1) { // ptrptr[idx+1] = reinterpret_cast(syn1.data()); // ptrptr[idx+2] = reinterpret_cast(exp.data()); - -// execAggregateBatchFloat(nullptr, numAggregates, opNum, maxArgs, maxShapes, maxIntArrays, maxIntArraySize, maxIndexArguments, maxRealArguments, pointer.data()); +// execAggregateBatchFloat(nullptr, numAggregates, opNum, maxArgs, +// maxShapes, maxIntArrays, maxIntArraySize, maxIndexArguments, +// maxRealArguments, pointer.data()); // } diff --git a/libnd4j/tests_cpu/layers_tests/LambdaTests.cu b/libnd4j/tests_cpu/layers_tests/LambdaTests.cu index a114f71798ad..f210fa8cf198 100644 --- a/libnd4j/tests_cpu/layers_tests/LambdaTests.cu +++ b/libnd4j/tests_cpu/layers_tests/LambdaTests.cu @@ -18,202 +18,192 @@ // @author raver119@gmail.com // -#include "testlayers.h" #include -#include #include #include +#include + +#include "testlayers.h" + using namespace sd; class LambdaTests : public testing::Test { -public: - - LambdaTests() { - printf("\n"); - fflush(stdout); - } + public: + LambdaTests() { + printf("\n"); + fflush(stdout); + } }; template -__global__ void runLambda(double *input, double *output, Nd4jLong length, Lambda lambda) { - auto tid = blockIdx.x * blockDim.x + threadIdx.x; - for (Nd4jLong e = tid; e < length; e += gridDim.x * blockDim.x) { - output[e] = lambda(input[e]); - } +__global__ void runLambda(double *input, double *output, Nd4jLong length, + Lambda lambda) { + auto tid = blockIdx.x * blockDim.x + threadIdx.x; + for (Nd4jLong e = tid; e < length; e += gridDim.x * blockDim.x) { + output[e] = lambda(input[e]); + } } -void launcher(cudaStream_t *stream, double *input, double *output, Nd4jLong length) { - //auto f = [] __host__ __device__ (double x) -> double { - // return x + 1.; - //}; - auto f = LAMBDA_D(x) { - return x+1.; - }; +void launcher(cudaStream_t *stream, double *input, double *output, + Nd4jLong length) { + // auto f = [] __host__ __device__ (double x) -> double { + // return x + 1.; + //}; + auto f = LAMBDA_D(x) { return x + 1.; }; - - runLambda<<<128, 128, 128, *stream>>>(input, output, length, f); + runLambda<<<128, 128, 128, *stream>>>(input, output, length, f); } - TEST_F(LambdaTests, test_basic_1) { - auto x = NDArrayFactory::create('c', {5}); - auto e = NDArrayFactory::create('c', {5}, {1., 1., 1., 1., 1.}); + auto x = NDArrayFactory::create('c', {5}); + auto e = NDArrayFactory::create('c', {5}, {1., 1., 1., 1., 1.}); + // x.applyLambda(f, nullptr); + launcher(LaunchContext::defaultContext()->getCudaStream(), + (double *)x.specialBuffer(), (double *)x.specialBuffer(), + x.lengthOf()); + auto res = + cudaStreamSynchronize(*LaunchContext::defaultContext()->getCudaStream()); + ASSERT_EQ(0, res); - - //x.applyLambda(f, nullptr); - launcher(LaunchContext::defaultContext()->getCudaStream(), (double *)x.specialBuffer(), (double *)x.specialBuffer(), x.lengthOf()); - auto res = cudaStreamSynchronize(*LaunchContext::defaultContext()->getCudaStream()); - ASSERT_EQ(0, res); - - ASSERT_EQ(e, x); + ASSERT_EQ(e, x); } void test(NDArray &x) { - auto f = LAMBDA_D(x) { - return x+1.; - }; + auto f = LAMBDA_D(x) { return x + 1.; }; - x.applyLambda(f, x); + x.applyLambda(f, x); } template void test2(NDArray &x) { - auto f = LAMBDA_T(x) { - return x+1.; - }; + auto f = LAMBDA_T(x) { return x + 1.; }; - x.applyLambda(f, x); + x.applyLambda(f, x); } void testPairwise(NDArray &x, NDArray &y) { - auto f = LAMBDA_DD(x, y) { - return x + y +1.; - }; + auto f = LAMBDA_DD(x, y) { return x + y + 1.; }; - x.applyPairwiseLambda(y, f, x); + x.applyPairwiseLambda(y, f, x); } void testTriplewise(NDArray &i, NDArray &j, NDArray &k) { - auto f = LAMBDA_DDD(i, j, k) { - return i + j + k + 2.; - }; + auto f = LAMBDA_DDD(i, j, k) { return i + j + k + 2.; }; - i.applyTriplewiseLambda(j, k, f, i); + i.applyTriplewiseLambda(j, k, f, i); } void testIndexed(NDArray &x) { - auto f = ILAMBDA_D(x) { - return _idx + 1.; - }; + auto f = ILAMBDA_D(x) { return _idx + 1.; }; - x.applyIndexedLambda(f, x); + x.applyIndexedLambda(f, x); } void testIndexedPairwise(NDArray &x, NDArray &y) { - auto f = ILAMBDA_DD(x, y) { - return _idx + x + y +1.; - }; + auto f = ILAMBDA_DD(x, y) { return _idx + x + y + 1.; }; - x.applyIndexedPairwiseLambda(y, f, x); + x.applyIndexedPairwiseLambda(y, f, x); } TEST_F(LambdaTests, test_basic_2) { - auto x = NDArrayFactory::create('c', {5}); - auto e = NDArrayFactory::create('c', {5}, {1., 1., 1., 1., 1.}); + auto x = NDArrayFactory::create('c', {5}); + auto e = NDArrayFactory::create('c', {5}, {1., 1., 1., 1., 1.}); - test(x); + test(x); - ASSERT_EQ(e, x); + ASSERT_EQ(e, x); } TEST_F(LambdaTests, test_basic_3) { - auto x = NDArrayFactory::create('c', {5}); - auto e = NDArrayFactory::create('c', {5}, {1.f, 1.f, 1.f, 1.f, 1.f}); + auto x = NDArrayFactory::create('c', {5}); + auto e = NDArrayFactory::create('c', {5}, {1.f, 1.f, 1.f, 1.f, 1.f}); - test(x); + test(x); - ASSERT_EQ(e, x); + ASSERT_EQ(e, x); } TEST_F(LambdaTests, test_basic_4) { - auto x = NDArrayFactory::create('c', {5}); - auto e = NDArrayFactory::create('c', {5}, {1.f, 1.f, 1.f, 1.f, 1.f}); + auto x = NDArrayFactory::create('c', {5}); + auto e = NDArrayFactory::create('c', {5}, {1.f, 1.f, 1.f, 1.f, 1.f}); - test2(x); + test2(x); - ASSERT_EQ(e, x); + ASSERT_EQ(e, x); } TEST_F(LambdaTests, test_basic_5) { - auto x = NDArrayFactory::create('c', {5}, {1., 1., 1., 1., 1.}); - auto y = NDArrayFactory::create('c', {5}, {2., 2., 2., 2., 2.}); - auto e = NDArrayFactory::create('c', {5}, {4., 4., 4., 4., 4.}); + auto x = NDArrayFactory::create('c', {5}, {1., 1., 1., 1., 1.}); + auto y = NDArrayFactory::create('c', {5}, {2., 2., 2., 2., 2.}); + auto e = NDArrayFactory::create('c', {5}, {4., 4., 4., 4., 4.}); - testPairwise(x, y); + testPairwise(x, y); - ASSERT_EQ(e, x); + ASSERT_EQ(e, x); } TEST_F(LambdaTests, test_basic_6) { - auto x = NDArrayFactory::create('c', {5}); - auto e = NDArrayFactory::create('c', {5}, {1., 2., 3., 4., 5.}); + auto x = NDArrayFactory::create('c', {5}); + auto e = NDArrayFactory::create('c', {5}, {1., 2., 3., 4., 5.}); - testIndexed(x); + testIndexed(x); - ASSERT_EQ(e, x); + ASSERT_EQ(e, x); } TEST_F(LambdaTests, test_basic_7) { - auto w = NDArrayFactory::create('c', {5}, {0., 0., 0., 0., 0.}); - auto x = NDArrayFactory::create('c', {5}, {1., 1., 1., 1., 1.}); - auto y = NDArrayFactory::create('c', {5}, {2., 2., 2., 2., 2.}); - auto e = NDArrayFactory::create('c', {5}, {5., 5., 5., 5., 5.}); + auto w = NDArrayFactory::create('c', {5}, {0., 0., 0., 0., 0.}); + auto x = NDArrayFactory::create('c', {5}, {1., 1., 1., 1., 1.}); + auto y = NDArrayFactory::create('c', {5}, {2., 2., 2., 2., 2.}); + auto e = NDArrayFactory::create('c', {5}, {5., 5., 5., 5., 5.}); - testTriplewise(w, x, y); + testTriplewise(w, x, y); - ASSERT_EQ(e, w); + ASSERT_EQ(e, w); } TEST_F(LambdaTests, test_basic_8) { - auto x = NDArrayFactory::create('c', {5}, {1., 1., 1., 1., 1.}); - auto y = NDArrayFactory::create('c', {5}, {2., 2., 2., 2., 2.}); - auto e = NDArrayFactory::create('c', {5}, {4., 5., 6., 7., 8.}); + auto x = NDArrayFactory::create('c', {5}, {1., 1., 1., 1., 1.}); + auto y = NDArrayFactory::create('c', {5}, {2., 2., 2., 2., 2.}); + auto e = NDArrayFactory::create('c', {5}, {4., 5., 6., 7., 8.}); - testIndexedPairwise(x, y); + testIndexedPairwise(x, y); - ASSERT_EQ(e, x); + ASSERT_EQ(e, x); } - template void testPairwiseMy(NDArray &x, NDArray &y, NDArray &z) { + auto f = LAMBDA_TT(x, y) { + return sd::math::nd4j_max(x, (T)0.f) - x * y + + sd::math::nd4j_log( + (T)1.f + sd::math::nd4j_exp(-sd::math::nd4j_abs(x))); + }; - auto f = LAMBDA_TT(x, y){ - return sd::math::nd4j_max(x, (T)0.f) - - x * y - + sd::math::nd4j_log((T)1.f - + sd::math::nd4j_exp(-sd::math::nd4j_abs(x))); - }; - - x.applyPairwiseLambda(y, f, z); + x.applyPairwiseLambda(y, f, z); } /////////////////////////////////////////////////////////////////// TEST_F(LambdaTests, test_basic_9) { - - NDArray labels('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); - NDArray logits('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray output('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray expected('c', {2,3,4}, {0.744397, 0.598139, 0.554355, 0.913015, 0.474077, 1.037488, 0.403186, 1.171101, 0.341154, 1.313262, 0.287335, 1.463282, 0.241008, 1.620417, 0.201413, 1.783901, 0.167786, 1.952978, 2.039387, 0.126928, 0.115520, 2.305083, 0.095545, 2.486836}); - - logits.linspace(0.1, 0.1); - - NDArray::prepareSpecialUse({&output}, {&logits, &labels}); - testPairwiseMy(logits, labels, output); - NDArray::registerSpecialUse({&output}, {&logits, &labels}); - - // output.printBuffer(nullptr, -1, true); - ASSERT_TRUE(expected.equalsTo(output)); + NDArray labels('c', {2, 3, 4}, {0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, + 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0}); + NDArray logits('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray output('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray expected( + 'c', {2, 3, 4}, + {0.744397, 0.598139, 0.554355, 0.913015, 0.474077, 1.037488, + 0.403186, 1.171101, 0.341154, 1.313262, 0.287335, 1.463282, + 0.241008, 1.620417, 0.201413, 1.783901, 0.167786, 1.952978, + 2.039387, 0.126928, 0.115520, 2.305083, 0.095545, 2.486836}); + + logits.linspace(0.1, 0.1); + + NDArray::prepareSpecialUse({&output}, {&logits, &labels}); + testPairwiseMy(logits, labels, output); + NDArray::registerSpecialUse({&output}, {&logits, &labels}); + + // output.printBuffer(nullptr, -1, true); + ASSERT_TRUE(expected.equalsTo(output)); } diff --git a/libnd4j/tests_cpu/layers_tests/LaunchContextCudaTests.cu b/libnd4j/tests_cpu/layers_tests/LaunchContextCudaTests.cu index e16df80e680c..30003277de28 100644 --- a/libnd4j/tests_cpu/layers_tests/LaunchContextCudaTests.cu +++ b/libnd4j/tests_cpu/layers_tests/LaunchContextCudaTests.cu @@ -18,108 +18,111 @@ // @author raver119@gmail.com // -#include "testlayers.h" #include +#include +#include #include +#include #include -#include +#include +#include #include -#include -#include #include -#include -#include -#include -#include +#include +#include +#include + #include -#include + +#include "testlayers.h" using namespace sd; using namespace sd::ops; class LaunchContextCudaTests : public testing::Test { - // + // }; - void acquireContext(int threadId, int &deviceId) { - deviceId = AffinityManager::currentDeviceId(); + deviceId = AffinityManager::currentDeviceId(); - nd4j_printf("Creating thread: [%i]; assigned deviceId: [%i];\n", threadId, deviceId); + nd4j_printf("Creating thread: [%i]; assigned deviceId: [%i];\n", threadId, + deviceId); - auto lc = LaunchContext::defaultContext(); - nd4j_printf("LC: [%p]\n", lc); + auto lc = LaunchContext::defaultContext(); + nd4j_printf("LC: [%p]\n", lc); - nd4j_printf("reductionPtr: [%p]; stream: [%p];\n", lc->getReductionPointer(), lc->getCudaStream()); + nd4j_printf("reductionPtr: [%p]; stream: [%p];\n", lc->getReductionPointer(), + lc->getCudaStream()); } TEST_F(LaunchContextCudaTests, basic_test_1) { - int deviceA, deviceB; - std::thread threadA(acquireContext, 0, std::ref(deviceA)); - std::thread threadB(acquireContext, 1, std::ref(deviceB)); + int deviceA, deviceB; + std::thread threadA(acquireContext, 0, std::ref(deviceA)); + std::thread threadB(acquireContext, 1, std::ref(deviceB)); - threadA.join(); - threadB.join(); - nd4j_printf("All threads joined\n",""); + threadA.join(); + threadB.join(); + nd4j_printf("All threads joined\n", ""); - if (AffinityManager::numberOfDevices() > 1) - ASSERT_NE(deviceA, deviceB); + if (AffinityManager::numberOfDevices() > 1) ASSERT_NE(deviceA, deviceB); } -void fillArray(int tid, std::vector &arrays) { - auto array = NDArrayFactory::create_('c', {3, 10}); - nd4j_printf("Array created on device [%i]\n", AffinityManager::currentDeviceId()); - array->assign(tid); - arrays[tid] = array; +void fillArray(int tid, std::vector &arrays) { + auto array = NDArrayFactory::create_('c', {3, 10}); + nd4j_printf("Array created on device [%i]\n", + AffinityManager::currentDeviceId()); + array->assign(tid); + arrays[tid] = array; } TEST_F(LaunchContextCudaTests, basic_test_2) { - std::vector arrays(2); + std::vector arrays(2); - std::thread threadA(fillArray, 0, std::ref(arrays)); - std::thread threadB(fillArray, 1, std::ref(arrays)); + std::thread threadA(fillArray, 0, std::ref(arrays)); + std::thread threadB(fillArray, 1, std::ref(arrays)); - threadA.join(); - threadB.join(); + threadA.join(); + threadB.join(); - for (int e = 0; e < 2; e++) { - auto array = arrays[e]; - ASSERT_EQ(e, array->e(0)); + for (int e = 0; e < 2; e++) { + auto array = arrays[e]; + ASSERT_EQ(e, array->e(0)); - delete array; - } + delete array; + } } void initAffinity(int tid, std::vector &aff) { - auto affinity = AffinityManager::currentDeviceId(); - aff[tid] = affinity; - nd4j_printf("Thread [%i] affined with device [%i]\n", tid, affinity); + auto affinity = AffinityManager::currentDeviceId(); + aff[tid] = affinity; + nd4j_printf("Thread [%i] affined with device [%i]\n", tid, affinity); } TEST_F(LaunchContextCudaTests, basic_test_3) { - auto totalThreads = AffinityManager::numberOfDevices() * 4; - nd4j_printf("Total threads: %i\n", totalThreads); - std::vector affinities(totalThreads); + auto totalThreads = AffinityManager::numberOfDevices() * 4; + nd4j_printf("Total threads: %i\n", totalThreads); + std::vector affinities(totalThreads); - for (int e = 0; e < totalThreads; e++) { - std::thread thread(initAffinity, e, std::ref(affinities)); + for (int e = 0; e < totalThreads; e++) { + std::thread thread(initAffinity, e, std::ref(affinities)); - thread.join(); - } + thread.join(); + } - std::vector hits(AffinityManager::numberOfDevices()); - std::fill(hits.begin(), hits.end(), 0); + std::vector hits(AffinityManager::numberOfDevices()); + std::fill(hits.begin(), hits.end(), 0); - // we need to make sure all threads were attached to "valid" devices - for (int e = 0; e < totalThreads; e++) { - auto aff = affinities[e]; - ASSERT_TRUE(aff >= 0 && aff < AffinityManager::numberOfDevices()); + // we need to make sure all threads were attached to "valid" devices + for (int e = 0; e < totalThreads; e++) { + auto aff = affinities[e]; + ASSERT_TRUE(aff >= 0 && aff < AffinityManager::numberOfDevices()); - hits[aff]++; - } + hits[aff]++; + } - // now we check if all devices got some threads - for (int e = 0; e < AffinityManager::numberOfDevices(); e++) { - ASSERT_GT(hits[e], 0); - } + // now we check if all devices got some threads + for (int e = 0; e < AffinityManager::numberOfDevices(); e++) { + ASSERT_GT(hits[e], 0); + } } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/LegacyOpsCudaTests.cu b/libnd4j/tests_cpu/layers_tests/LegacyOpsCudaTests.cu index 922d94afdbdc..6075ef1784d2 100644 --- a/libnd4j/tests_cpu/layers_tests/LegacyOpsCudaTests.cu +++ b/libnd4j/tests_cpu/layers_tests/LegacyOpsCudaTests.cu @@ -18,45 +18,51 @@ // @author raver119@gmail.com // -#include "testlayers.h" #include +#include #include +#include #include -#include +#include +#include #include -#include -#include #include -#include -#include -#include -#include +#include +#include +#include + +#include "testlayers.h" using namespace sd; using namespace sd::ops; -class LegacyOpsCudaTests : public testing::Test { - -}; - +class LegacyOpsCudaTests : public testing::Test {}; TEST_F(LegacyOpsCudaTests, test_sortTad_1) { - auto x = NDArrayFactory::create('c', {3, 5}, {1.f, 3.f, 0.f, 2.f, 4.f, - 6.f, 5.f, 9.f, 7.f, 8.f, - 10.f, 11.f, 14.f, 12.f, 13.f}); + auto x = + NDArrayFactory::create('c', {3, 5}, + {1.f, 3.f, 0.f, 2.f, 4.f, 6.f, 5.f, 9.f, + 7.f, 8.f, 10.f, 11.f, 14.f, 12.f, 13.f}); - auto e = NDArrayFactory::create('c', {3, 5}, {0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 13.f, 14.f}); + auto e = + NDArrayFactory::create('c', {3, 5}, + {0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, + 8.f, 9.f, 10.f, 11.f, 12.f, 13.f, 14.f}); - int axis = 1; - auto packX = ConstantTadHelper::getInstance().tadForDimensions(x.shapeInfo(), axis); + int axis = 1; + auto packX = + ConstantTadHelper::getInstance().tadForDimensions(x.shapeInfo(), axis); - Nd4jPointer extras[2] = {nullptr, LaunchContext::defaultContext()->getCudaStream()}; + Nd4jPointer extras[2] = {nullptr, + LaunchContext::defaultContext()->getCudaStream()}; - x.syncToDevice(); - sortTad(extras, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), &axis, 1, packX.platformShapeInfo(), packX.platformOffsets(), false); - x.tickWriteDevice(); + x.syncToDevice(); + sortTad(extras, x.buffer(), x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), &axis, 1, packX.platformShapeInfo(), + packX.platformOffsets(), false); + x.tickWriteDevice(); - ASSERT_EQ(e, x); + ASSERT_EQ(e, x); } TEST_F(LegacyOpsCudaTests, test_sort_1) { diff --git a/libnd4j/tests_cpu/layers_tests/LegacyOpsTests.cpp b/libnd4j/tests_cpu/layers_tests/LegacyOpsTests.cpp index fe9c5a7a076a..899a5ec48142 100644 --- a/libnd4j/tests_cpu/layers_tests/LegacyOpsTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/LegacyOpsTests.cpp @@ -18,751 +18,788 @@ // Created by raver119 on 16.10.2017. // -#include "testlayers.h" #include +#include #include +#include #include -#include +#include +#include #include -#include -#include #include -#include -#include -#include -#include +#include +#include +#include + +#include "testlayers.h" using namespace sd; using namespace sd::ops; -class LegacyOpsTests : public testing::Test { - -}; - +class LegacyOpsTests : public testing::Test {}; TEST_F(LegacyOpsTests, TransformTests_1) { - auto x = NDArrayFactory::create('c', {5, 5}); - x.assign(1.0); - auto z = NDArrayFactory::create('c', {5,5}); - auto exp = NDArrayFactory::create('c', {5, 5}); - exp.assign(-1.0); + auto x = NDArrayFactory::create('c', {5, 5}); + x.assign(1.0); + auto z = NDArrayFactory::create('c', {5, 5}); + auto exp = NDArrayFactory::create('c', {5, 5}); + exp.assign(-1.0); - sd::ops::LegacyTransformSameOp op(transform::Neg); // Neg - auto status = op.execute({&x}, {&z}, {}, {}, {}); - ASSERT_EQ(status, ND4J_STATUS_OK); - //z.printIndexedBuffer("Output NEG"); - ASSERT_TRUE(z.equalsTo(&exp)); + sd::ops::LegacyTransformSameOp op(transform::Neg); // Neg + auto status = op.execute({&x}, {&z}, {}, {}, {}); + ASSERT_EQ(status, ND4J_STATUS_OK); + // z.printIndexedBuffer("Output NEG"); + ASSERT_TRUE(z.equalsTo(&exp)); } TEST_F(LegacyOpsTests, TransformTests_2) { - auto x = NDArrayFactory::create('c', {5, 5}); - x.assign(1.0); + auto x = NDArrayFactory::create('c', {5, 5}); + x.assign(1.0); - auto exp = NDArrayFactory::create('c', {5, 5}); - exp.assign(-1.0); + auto exp = NDArrayFactory::create('c', {5, 5}); + exp.assign(-1.0); - sd::ops::LegacyTransformSameOp op(transform::Neg); // Neg - auto result = op.evaluate({&x}, {}, {}); + sd::ops::LegacyTransformSameOp op(transform::Neg); // Neg + auto result = op.evaluate({&x}, {}, {}); - ASSERT_EQ(1, result.size()); + ASSERT_EQ(1, result.size()); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_TRUE(exp.equalsTo(z)); - - + ASSERT_TRUE(exp.equalsTo(z)); } -TEST_F(LegacyOpsTests, Reciprocal_1) { - auto x = NDArrayFactory::create('c', {5, 5}); - x.assign(2.0f); - - auto ethalon = NDArrayFactory::create('c', {5, 5}); - ethalon.assign(0.5f); +TEST_F(LegacyOpsTests, Reciprocal_1) { + auto x = NDArrayFactory::create('c', {5, 5}); + x.assign(2.0f); - sd::ops::LegacyTransformSameOp op(transform::Reciprocal); // Reciprocal - Nd4jStatus status = op.execute({&x}, {&x}, {}, {}, {}); + auto ethalon = NDArrayFactory::create('c', {5, 5}); + ethalon.assign(0.5f); - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(ethalon.equalsTo(&x)); + sd::ops::LegacyTransformSameOp op(transform::Reciprocal); // Reciprocal + Nd4jStatus status = op.execute({&x}, {&x}, {}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(ethalon.equalsTo(&x)); } -TEST_F(LegacyOpsTests, PWT_Tests_1) { - auto x = NDArrayFactory::create('c', {5, 5}); - x.assign(2.0); +TEST_F(LegacyOpsTests, PWT_Tests_1) { + auto x = NDArrayFactory::create('c', {5, 5}); + x.assign(2.0); - auto y = NDArrayFactory::create('c', {5, 5}); - y.assign(3.0); + auto y = NDArrayFactory::create('c', {5, 5}); + y.assign(3.0); - auto exp = NDArrayFactory::create('c', {5, 5}); - exp.assign(6.0); + auto exp = NDArrayFactory::create('c', {5, 5}); + exp.assign(6.0); - sd::ops::LegacyPairwiseTransformOp op(pairwise::Multiply); // Multiply - Nd4jStatus status = op.execute({&x, &y}, {&x}, {}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, status); - - ASSERT_TRUE(exp.equalsTo(&x)); + sd::ops::LegacyPairwiseTransformOp op(pairwise::Multiply); // Multiply + Nd4jStatus status = op.execute({&x, &y}, {&x}, {}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(exp.equalsTo(&x)); } -TEST_F(LegacyOpsTests, PWT_Tests_2) { - auto x = NDArrayFactory::create('c', {5, 5}); - x.assign(2.0); - - auto y = NDArrayFactory::create('c', {5, 5}); - y.assign(3.0); +TEST_F(LegacyOpsTests, PWT_Tests_2) { + auto x = NDArrayFactory::create('c', {5, 5}); + x.assign(2.0); - auto exp = NDArrayFactory::create('c', {5, 5}); - exp.assign(6.0); + auto y = NDArrayFactory::create('c', {5, 5}); + y.assign(3.0); - sd::ops::LegacyPairwiseTransformOp op(pairwise::Multiply); // Multiply - auto result = op.evaluate({&x, &y}, {}, {}); + auto exp = NDArrayFactory::create('c', {5, 5}); + exp.assign(6.0); - auto z = result.at(0); + sd::ops::LegacyPairwiseTransformOp op(pairwise::Multiply); // Multiply + auto result = op.evaluate({&x, &y}, {}, {}); - //z->printBuffer("Z"); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); - + // z->printBuffer("Z"); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(LegacyOpsTests, Scalar_Test_1) { - auto x = NDArrayFactory::create('c', {5, 5}); - x.assign(2.0); + auto x = NDArrayFactory::create('c', {5, 5}); + x.assign(2.0); - auto exp = NDArrayFactory::create('c', {5, 5}); - exp.assign(7.0); + auto exp = NDArrayFactory::create('c', {5, 5}); + exp.assign(7.0); - sd::ops::LegacyScalarOp op(scalar::Add); - op.execute({&x}, {&x}, {5.0}, {}, {}); // + sd::ops::LegacyScalarOp op(scalar::Add); + op.execute({&x}, {&x}, {5.0}, {}, {}); // - ASSERT_TRUE(exp.equalsTo(&x)); + ASSERT_TRUE(exp.equalsTo(&x)); } TEST_F(LegacyOpsTests, Scalar_Test_2) { - auto x = NDArrayFactory::create('c', {5, 5}); - x.assign(2.0); + auto x = NDArrayFactory::create('c', {5, 5}); + x.assign(2.0); - auto exp = NDArrayFactory::create('c', {5, 5}); - exp.assign(7.0); + auto exp = NDArrayFactory::create('c', {5, 5}); + exp.assign(7.0); - auto y = NDArrayFactory::create(5.0f); + auto y = NDArrayFactory::create(5.0f); - sd::ops::LegacyScalarOp op(scalar::Add, y); - auto result = op.evaluate({&x}, {}, {}); + sd::ops::LegacyScalarOp op(scalar::Add, y); + auto result = op.evaluate({&x}, {}, {}); - auto z = result.at(0); - ASSERT_TRUE(exp.equalsTo(z)); - - + auto z = result.at(0); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(LegacyOpsTests, ReduceTests_1) { - auto x = NDArrayFactory::create('c', {5, 5}); - x.assign(1.0); - int opNum = reduce::Sum; - sd::ops::LegacyReduceSameOp op(opNum); - - auto result = op.evaluate({&x}, {}, {}); + auto x = NDArrayFactory::create('c', {5, 5}); + x.assign(1.0); + int opNum = reduce::Sum; + sd::ops::LegacyReduceSameOp op(opNum); - ASSERT_EQ(1, result.size()); + auto result = op.evaluate({&x}, {}, {}); - auto z = result.at(0); - // z->printBuffer("ReduceTest1"); - ASSERT_TRUE(z->isScalar()); - ASSERT_NEAR(x.sumNumber().e(0), z->e(0), 1e-5f); + ASSERT_EQ(1, result.size()); - + auto z = result.at(0); + // z->printBuffer("ReduceTest1"); + ASSERT_TRUE(z.isScalar()); + ASSERT_NEAR(x.sumNumber().e(0), z.e(0), 1e-5f); } - TEST_F(LegacyOpsTests, ReduceTests_2) { - auto x = NDArrayFactory::create('c', {5, 5}); - x.assign(1.0); - - sd::ops::LegacyReduceSameOp op(reduce::Sum); - auto axis = NDArrayFactory::create('c', {1}, {1}); - auto result = op.evaluate({&x, &axis}, {}, {}); + auto x = NDArrayFactory::create('c', {5, 5}); + x.assign(1.0); - ASSERT_EQ(1, result.size()); + sd::ops::LegacyReduceSameOp op(reduce::Sum); + auto axis = NDArrayFactory::create('c', {1}, {1}); + auto result = op.evaluate({&x, &axis}, {}, {}); - auto z = result.at(0); + ASSERT_EQ(1, result.size()); - auto exp = x.reduceAlongDimension(reduce::Sum, {1}); + auto z = result.at(0); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto exp = x.reduceAlongDimension(reduce::Sum, {1}); - + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(LegacyOpsTests, ReduceTests_3) { - auto x = NDArrayFactory::create('c', {3, 5}); - x.linspace(1); - auto indices = NDArrayFactory::create('c', {1,1}, {1}); - - - sd::ops::LegacyReduceSameOp op(reduce::Sum); - auto result = op.evaluate({&x, &indices}, {}, {}); - auto z = result.at(0); - auto exp = x.reduceAlongDimension(reduce::Sum,{1}); + auto x = NDArrayFactory::create('c', {3, 5}); + x.linspace(1); + auto indices = NDArrayFactory::create('c', {1, 1}, {1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::LegacyReduceSameOp op(reduce::Sum); + auto result = op.evaluate({&x, &indices}, {}, {}); + auto z = result.at(0); + auto exp = x.reduceAlongDimension(reduce::Sum, {1}); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(LegacyOpsTests, ReduceTests_4) { - auto x = NDArrayFactory::create('c', {2, 3, 5}); - x.linspace(1); - auto indices = NDArrayFactory::create('c', {1, 1}, {1}); - - - sd::ops::LegacyReduceSameOp op(reduce::Sum); - auto result = op.evaluate({&x, &indices}, {}, {}, {true}); - auto z = result.at(0); - auto exp = x.reduceAlongDimension(reduce::Sum, {1}, true); - // indices.printShapeInfo("Indices shape"); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - // z->printIndexedBuffer("Output reduce 4"); - // exp.printIndexedBuffer("Expected reduce 4"); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - - + auto x = NDArrayFactory::create('c', {2, 3, 5}); + x.linspace(1); + auto indices = NDArrayFactory::create('c', {1, 1}, {1}); + + sd::ops::LegacyReduceSameOp op(reduce::Sum); + auto result = op.evaluate({&x, &indices}, {}, {}, {true}); + auto z = result.at(0); + auto exp = x.reduceAlongDimension(reduce::Sum, {1}, true); + // indices.printShapeInfo("Indices shape"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + // z->printIndexedBuffer("Output reduce 4"); + // exp.printIndexedBuffer("Expected reduce 4"); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(LegacyOpsTests, ReduceTests_5) { - auto x = NDArrayFactory::create('c', {5, 5}); - x.assign(1.0); - int opNum = reduce::Mean; - sd::ops::LegacyReduceFloatOp op(opNum); - - auto result = op.evaluate({&x}); + auto x = NDArrayFactory::create('c', {5, 5}); + x.assign(1.0); + int opNum = reduce::Mean; + sd::ops::LegacyReduceFloatOp op(opNum); - ASSERT_EQ(1, result.size()); + auto result = op.evaluate({&x}); - auto z = result.at(0); - // z->printBuffer("ReduceTest1"); - ASSERT_TRUE(z->isScalar()); - ASSERT_NEAR(x.meanNumber().e(0), z->e(0), 1e-5f); + ASSERT_EQ(1, result.size()); - + auto z = result.at(0); + // z->printBuffer("ReduceTest1"); + ASSERT_TRUE(z.isScalar()); + ASSERT_NEAR(x.meanNumber().e(0), z.e(0), 1e-5f); } - TEST_F(LegacyOpsTests, ReduceTests_6) { - auto x = NDArrayFactory::create('c', {5, 5}); - x.assign(1.0); - auto axis = NDArrayFactory::create('c', {1}, {1}); - sd::ops::LegacyReduceFloatOp op(reduce::Mean); - - auto result = op.evaluate({&x, &axis}, {}, {}); + auto x = NDArrayFactory::create('c', {5, 5}); + x.assign(1.0); + auto axis = NDArrayFactory::create('c', {1}, {1}); + sd::ops::LegacyReduceFloatOp op(reduce::Mean); - ASSERT_EQ(1, result.size()); + auto result = op.evaluate({&x, &axis}, {}, {}); - auto z = result.at(0); + ASSERT_EQ(1, result.size()); - auto exp = x.reduceAlongDimension(reduce::Mean, {1}); + auto z = result.at(0); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto exp = x.reduceAlongDimension(reduce::Mean, {1}); - + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(LegacyOpsTests, ReduceTests_7) { - auto x = NDArrayFactory::create('c', {3, 5}); - x.linspace(1); - auto indices = NDArrayFactory::create('c', {1,1}, {1}); - - - sd::ops::LegacyReduceFloatOp op(reduce::Mean); - auto result = op.evaluate({&x, &indices}, {}, {}); - auto z = result.at(0); - auto exp = x.reduceAlongDimension(reduce::Mean,{1}); + auto x = NDArrayFactory::create('c', {3, 5}); + x.linspace(1); + auto indices = NDArrayFactory::create('c', {1, 1}, {1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::LegacyReduceFloatOp op(reduce::Mean); + auto result = op.evaluate({&x, &indices}, {}, {}); + auto z = result.at(0); + auto exp = x.reduceAlongDimension(reduce::Mean, {1}); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(LegacyOpsTests, ReduceTests_8) { - auto x = NDArrayFactory::create('c', {2, 3, 5}); - x.linspace(1); - auto indices = NDArrayFactory::create('c', {1}, {1}); - - - sd::ops::LegacyReduceFloatOp op(reduce::Mean); - auto result = op.evaluate({&x, &indices}, {}, {}, {true}); - auto z = result.at(0); - auto exp = x.reduceAlongDimension(reduce::Mean, {1}, true); + auto x = NDArrayFactory::create('c', {2, 3, 5}); + x.linspace(1); + auto indices = NDArrayFactory::create('c', {1}, {1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - // z->printIndexedBuffer("Reduce8 output"); - // z->printShapeInfo("Reduce8 shape"); - // exp.printShapeInfo("Reduce8 expected shape"); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::LegacyReduceFloatOp op(reduce::Mean); + auto result = op.evaluate({&x, &indices}, {}, {}, {true}); + auto z = result.at(0); + auto exp = x.reduceAlongDimension(reduce::Mean, {1}, true); - + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + // z->printIndexedBuffer("Reduce8 output"); + // z->printShapeInfo("Reduce8 shape"); + // exp.printShapeInfo("Reduce8 expected shape"); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(LegacyOpsTests, IndexReduceTests_1) { - auto x = NDArrayFactory::create('c', {5, 5}); - x.linspace(1); - - sd::ops::LegacyIndexReduceOp op(indexreduce::IndexMax); + auto x = NDArrayFactory::create('c', {5, 5}); + x.linspace(1); - auto result = op.evaluate({&x}, {}, {}); + sd::ops::LegacyIndexReduceOp op(indexreduce::IndexMax); - ASSERT_EQ(1, result.size()); + auto result = op.evaluate({&x}, {}, {}); - auto z = result.at(0); + ASSERT_EQ(1, result.size()); - ASSERT_TRUE(z->isScalar()); - ASSERT_EQ(24, z->e(0)); + auto z = result.at(0); - + ASSERT_TRUE(z.isScalar()); + ASSERT_EQ(24, z.e(0)); } - TEST_F(LegacyOpsTests, IndexReduceTests_2) { - auto x = NDArrayFactory::create('c', {5, 5}); - auto indices = NDArrayFactory::create('c', {1}, {1}); - x.linspace(1); - auto exp = NDArrayFactory::create({4,4,4,4,4}); - sd::ops::LegacyIndexReduceOp op(indexreduce::IndexMax); + auto x = NDArrayFactory::create('c', {5, 5}); + auto indices = NDArrayFactory::create('c', {1}, {1}); + x.linspace(1); + auto exp = NDArrayFactory::create({4, 4, 4, 4, 4}); + sd::ops::LegacyIndexReduceOp op(indexreduce::IndexMax); - auto result = op.evaluate({&x, &indices}, {}, {}); + auto result = op.evaluate({&x, &indices}, {}, {}); - ASSERT_EQ(1, result.size()); + ASSERT_EQ(1, result.size()); - auto z = result.at(0); - // z->printIndexedBuffer("Hello indexreduce2"); - ASSERT_TRUE(exp.equalsTo(z)); - //ASSERT_EQ(4, z->e(0)); - //ASSERT_EQ(4, z->e(1)); - //ASSERT_EQ(4, z->e(2)); - //ASSERT_EQ(4, z->e(3)); - //ASSERT_EQ(4, z->e(4)); - - + auto z = result.at(0); + // z->printIndexedBuffer("Hello indexreduce2"); + ASSERT_TRUE(exp.equalsTo(z)); + // ASSERT_EQ(4, z->e(0)); + // ASSERT_EQ(4, z->e(1)); + // ASSERT_EQ(4, z->e(2)); + // ASSERT_EQ(4, z->e(3)); + // ASSERT_EQ(4, z->e(4)); } TEST_F(LegacyOpsTests, BroadcastingTests_1) { - auto x = NDArrayFactory::create('c', {5, 5}); - x.assign(0.0f); + auto x = NDArrayFactory::create('c', {5, 5}); + x.assign(0.0f); - auto row = NDArrayFactory::create('c', {1, 5}); - row.linspace(1); - auto axis = NDArrayFactory::create('c', {1}, {1}); - sd::ops::LegacyBroadcastOp op(broadcast::Add); - Nd4jStatus status = op.execute({&x, &row, &axis}, {&x}, {}, {}, {}); + auto row = NDArrayFactory::create('c', {1, 5}); + row.linspace(1); + auto axis = NDArrayFactory::create('c', {1}, {1}); + sd::ops::LegacyBroadcastOp op(broadcast::Add); + Nd4jStatus status = op.execute({&x, &row, &axis}, {&x}, {}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_EQ(ND4J_STATUS_OK, status); - auto list = x.allTensorsAlongDimension({1}); - // x.printIndexedBuffer("Output broadcast"); - // list->at(0)->printIndexedBuffer("Column 0:"); - for (int e = 0; e < list.size(); e++) - ASSERT_TRUE(row.equalsTo(list.at(e))); + auto list = x.allTensorsAlongDimension({1}); + // x.printIndexedBuffer("Output broadcast"); + // list->at(0)->printIndexedBuffer("Column 0:"); + for (int e = 0; e < list.size(); e++) ASSERT_TRUE(row.equalsTo(list.at(e))); } TEST_F(LegacyOpsTests, BroadcastingTests_2) { - auto x = NDArrayFactory::create('c', {5}, {1, 1, 1, 1, 1}); - auto y = NDArrayFactory::create('c', {10, 5}); - auto e = NDArrayFactory::create('c', {10, 5}); - y.assign(3.0); - e.assign(4.0); + auto x = NDArrayFactory::create('c', {5}, {1, 1, 1, 1, 1}); + auto y = NDArrayFactory::create('c', {10, 5}); + auto e = NDArrayFactory::create('c', {10, 5}); + y.assign(3.0); + e.assign(4.0); - int axis = 1; + int axis = 1; - // shape::printShapeInfoLinear("tad shape", tad.tadOnlyShapeInfo); - auto packY = sd::ConstantTadHelper::getInstance().tadForDimensions(y.shapeInfo(), {axis}); + // shape::printShapeInfoLinear("tad shape", tad.tadOnlyShapeInfo); + auto packY = sd::ConstantTadHelper::getInstance().tadForDimensions( + y.shapeInfo(), axis); - NDArray::prepareSpecialUse({&y}, {&x}); + NDArray::prepareSpecialUse({&y}, {&x}); - NativeOpExecutioner::execInverseBroadcast(LaunchContext::defaultContext(), broadcast::Add, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), &axis, 1, packY.platformShapeInfo(), packY.platformOffsets(), packY.platformShapeInfo(), packY.platformOffsets()); + NativeOpExecutioner::execInverseBroadcast( + LaunchContext::defaultContext(), broadcast::Add, x.buffer(), + x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), y.buffer(), + y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), y.buffer(), + y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), &axis, 1, + packY.platformShapeInfo(), packY.platformOffsets(), + packY.platformShapeInfo(), packY.platformOffsets()); - NDArray::registerSpecialUse({&y}, {&x}); + NDArray::registerSpecialUse({&y}, {&x}); - ASSERT_EQ(e, y); + ASSERT_EQ(e, y); } TEST_F(LegacyOpsTests, PowDerivative_1) { - auto x = NDArrayFactory::create('c', {5, 5}); - auto exp = NDArrayFactory::create('c', {5, 5}); - x.assign(3.f); - exp.assign(6.f); + auto x = NDArrayFactory::create('c', {5, 5}); + auto exp = NDArrayFactory::create('c', {5, 5}); + x.assign(3.f); + exp.assign(6.f); - float p = 2.0f; + float p = 2.0f; - x.applyScalar(scalar::PowDerivative, p, x); + x.applyScalar(scalar::PowDerivative, p, x); - ASSERT_TRUE(exp.equalsTo(&x)); + ASSERT_TRUE(exp.equalsTo(&x)); } #ifndef __CUDABLAS__ TEST_F(LegacyOpsTests, reduce3_1) { - - Nd4jLong yShape[2] = {4,4}; - Nd4jLong xShape[1] = {4}; - float y[16] ={1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16}; - float x[4] = {1,2,3,4}; - int dimension[1] = {1}; - int dimensionLength = 1; - int opNum = 1; - float extraVals[1] = {0}; - float result[4] = {0.0,0.0,0.0,0.0}; - - std::vector dim = {1}; - - auto shapeBuffer = sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', 2, yShape); - auto xShapeBuffer = sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', 1, xShape); - - //int *tadShapeBuffer = shape::computeResultShape(shapeBuffer,dimension,dimensionLength); - auto tadShapeBuffer = sd::ShapeUtils::evalReduceShapeInfo('c', dim, shapeBuffer, false, true, nullptr); - functions::reduce3::Reduce3::exec(opNum, x, xShapeBuffer, extraVals, y, shapeBuffer, result, tadShapeBuffer, dimension, dimensionLength, 0, 4); - - float distancesAssertion[4] = {0.0,8.0,16.0,24.0}; - for(int i = 0; i < 4; i++) - ASSERT_NEAR(distancesAssertion[i],result[i], 1e-5); - - delete[] shapeBuffer; - delete[] xShapeBuffer; + Nd4jLong yShape[2] = {4, 4}; + Nd4jLong xShape[1] = {4}; + float y[16] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; + float x[4] = {1, 2, 3, 4}; + int dimension[1] = {1}; + int dimensionLength = 1; + int opNum = 1; + float extraVals[1] = {0}; + float result[4] = {0.0, 0.0, 0.0, 0.0}; + + std::vector dim = {1}; + + auto shapeBuffer = + sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', 2, yShape); + auto xShapeBuffer = + sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', 1, xShape); + + // int *tadShapeBuffer = + // shape::computeResultShape(shapeBuffer,dimension,dimensionLength); + auto tadShapeBuffer = sd::ShapeUtils::evalReduceShapeInfo( + 'c', dim, shapeBuffer, false, true, nullptr); + functions::reduce3::Reduce3::exec( + opNum, x, xShapeBuffer, extraVals, y, shapeBuffer, result, tadShapeBuffer, + dimension, dimensionLength, 0, 4); + + float distancesAssertion[4] = {0.0, 8.0, 16.0, 24.0}; + for (int i = 0; i < 4; i++) + ASSERT_NEAR(distancesAssertion[i], result[i], 1e-5); + + delete[] shapeBuffer; + delete[] xShapeBuffer; } #endif - TEST_F(LegacyOpsTests, Reduce3_2) { - auto x = NDArrayFactory::create('c', {5, 5}); - auto y = NDArrayFactory::create('c', {5}); - auto z = NDArrayFactory::create('c', {5}); - - auto dim = NDArrayFactory::create('c', {1}, {1}); - dim.syncToHost(); - - sd::LaunchContext* context = sd::LaunchContext::defaultContext(); - - Nd4jPointer* extraPointers = nullptr; - #ifdef __CUDABLAS__ - extraPointers = new Nd4jPointer[7] {nullptr, context->getCudaStream(), context->getScalarPointer(), nullptr, context->getCudaSpecialStream(), context->getReductionPointer(), context->getAllocationPointer()}; - #endif + auto x = NDArrayFactory::create('c', {5, 5}); + auto y = NDArrayFactory::create('c', {5}); + auto z = NDArrayFactory::create('c', {5}); + + auto dim = NDArrayFactory::create('c', {1}, {1}); + dim.syncToHost(); + + sd::LaunchContext* context = sd::LaunchContext::defaultContext(); + + Nd4jPointer* extraPointers = nullptr; +#ifdef __CUDABLAS__ + extraPointers = new Nd4jPointer[7]{nullptr, + context->getCudaStream(), + context->getScalarPointer(), + nullptr, + context->getCudaSpecialStream(), + context->getReductionPointer(), + context->getAllocationPointer()}; +#endif - auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(x.shapeInfo(), {1}); - auto packY = sd::ConstantTadHelper::getInstance().tadForDimensions(y.shapeInfo(), {1}); + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions( + x.shapeInfo(), 1); + auto packY = sd::ConstantTadHelper::getInstance().tadForDimensions( + y.shapeInfo(), 1); - NDArray::prepareSpecialUse({&z}, {&x, &y, &dim}); - OpaqueDataBuffer xBuf(x.dataBuffer()); - OpaqueDataBuffer yBuf(y.dataBuffer()); - OpaqueDataBuffer zBuf(z.dataBuffer()); - OpaqueDataBuffer dimBuf(dim.dataBuffer()); + NDArray::prepareSpecialUse({&z}, {&x, &y, &dim}); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer yBuf(y.dataBuffer()); + OpaqueDataBuffer zBuf(z.dataBuffer()); + OpaqueDataBuffer dimBuf(dim.dataBuffer()); - execReduce3Tad(extraPointers, reduce3::CosineSimilarity, - &xBuf, x.shapeInfo(), x.specialShapeInfo(), - nullptr, &yBuf, y.shapeInfo(), y.specialShapeInfo(), - &zBuf, z.shapeInfo(), z.specialShapeInfo(), - &dimBuf, dim.shapeInfo(), dim.specialShapeInfo(), - packX.platformShapeInfo(), packX.platformOffsets(), packY.platformShapeInfo(), packY.platformOffsets()); + execReduce3Tad(extraPointers, reduce3::CosineSimilarity, &xBuf, x.shapeInfo(), + x.specialShapeInfo(), nullptr, &yBuf, y.shapeInfo(), + y.specialShapeInfo(), &zBuf, z.shapeInfo(), + z.specialShapeInfo(), &dimBuf, dim.shapeInfo(), + dim.specialShapeInfo(), packX.platformShapeInfo(), + packX.platformOffsets(), packY.platformShapeInfo(), + packY.platformOffsets()); - NDArray::registerSpecialUse({&z}, {&x, &y, &dim}); + NDArray::registerSpecialUse({&z}, {&x, &y, &dim}); - delete []extraPointers; + delete[] extraPointers; } TEST_F(LegacyOpsTests, Reduce3_3) { - auto x = NDArrayFactory::create('c', {3, 5}, {-0.84443557262, -0.06822254508, 0.74266910552, 0.61765557527, -0.77555125951, - -0.99536740779, -0.0257304441183, -0.6512106060, -0.345789492130, -1.25485503673, - 0.62955373525, -0.31357592344, 1.03362500667, -0.59279078245, 1.1914824247}); - - auto y = NDArrayFactory::create('c', {5}, {-0.99536740779, -0.0257304441183, -0.6512106060, -0.345789492130, -1.25485503673}); - auto e = NDArrayFactory::create('c', {3}, {0.577452, 0.0, 1.80182}); - auto z = NDArrayFactory::create('c', {3}); - - auto dim = NDArrayFactory::create('c', {1}, {1}); - dim.syncToHost(); - - sd::LaunchContext* context = sd::LaunchContext::defaultContext(); - - Nd4jPointer* extraPointers = nullptr; - #ifdef __CUDABLAS__ - extraPointers = new Nd4jPointer[7] {nullptr, context->getCudaStream(), context->getScalarPointer(), nullptr, context->getCudaSpecialStream(), context->getReductionPointer(), context->getAllocationPointer()}; - #endif - - auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(x.shapeInfo(), {1}); - auto packY = sd::ConstantTadHelper::getInstance().tadForDimensions(y.shapeInfo(), {1}); - - NDArray::prepareSpecialUse({&z}, {&x, &y, &dim}); - OpaqueDataBuffer xBuf(x.dataBuffer()); - OpaqueDataBuffer yBuf(y.dataBuffer()); - OpaqueDataBuffer zBuf(z.dataBuffer()); - OpaqueDataBuffer dimBuf(dim.dataBuffer()); + auto x = NDArrayFactory::create( + 'c', {3, 5}, + {-0.84443557262, -0.06822254508, 0.74266910552, 0.61765557527, + -0.77555125951, -0.99536740779, -0.0257304441183, -0.6512106060, + -0.345789492130, -1.25485503673, 0.62955373525, -0.31357592344, + 1.03362500667, -0.59279078245, 1.1914824247}); + + auto y = NDArrayFactory::create( + 'c', {5}, + {-0.99536740779, -0.0257304441183, -0.6512106060, -0.345789492130, + -1.25485503673}); + auto e = NDArrayFactory::create('c', {3}, {0.577452, 0.0, 1.80182}); + auto z = NDArrayFactory::create('c', {3}); + + auto dim = NDArrayFactory::create('c', {1}, {1}); + dim.syncToHost(); + + sd::LaunchContext* context = sd::LaunchContext::defaultContext(); + + Nd4jPointer* extraPointers = nullptr; +#ifdef __CUDABLAS__ + extraPointers = new Nd4jPointer[7]{nullptr, + context->getCudaStream(), + context->getScalarPointer(), + nullptr, + context->getCudaSpecialStream(), + context->getReductionPointer(), + context->getAllocationPointer()}; +#endif - execReduce3Tad(extraPointers, reduce3::CosineDistance, - &xBuf, x.shapeInfo(), x.specialShapeInfo(), - nullptr, - &yBuf, y.shapeInfo(), y.specialShapeInfo(), - &zBuf, z.shapeInfo(), z.specialShapeInfo(), - &dimBuf, dim.shapeInfo(), dim.specialShapeInfo(), - packX.platformShapeInfo(), packX.platformOffsets(), packY.platformShapeInfo(), packY.platformOffsets()); - ASSERT_EQ(e, z); - NDArray::registerSpecialUse({&z}, {&x, &y, &dim}); - delete []extraPointers; + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions( + x.shapeInfo(), 1); + auto packY = sd::ConstantTadHelper::getInstance().tadForDimensions( + y.shapeInfo(), 1); + + NDArray::prepareSpecialUse({&z}, {&x, &y, &dim}); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer yBuf(y.dataBuffer()); + OpaqueDataBuffer zBuf(z.dataBuffer()); + OpaqueDataBuffer dimBuf(dim.dataBuffer()); + + execReduce3Tad(extraPointers, reduce3::CosineDistance, &xBuf, x.shapeInfo(), + x.specialShapeInfo(), nullptr, &yBuf, y.shapeInfo(), + y.specialShapeInfo(), &zBuf, z.shapeInfo(), + z.specialShapeInfo(), &dimBuf, dim.shapeInfo(), + dim.specialShapeInfo(), packX.platformShapeInfo(), + packX.platformOffsets(), packY.platformShapeInfo(), + packY.platformOffsets()); + ASSERT_EQ(e, z); + NDArray::registerSpecialUse({&z}, {&x, &y, &dim}); + delete[] extraPointers; } TEST_F(LegacyOpsTests, Reduce3_4) { - auto x = NDArrayFactory::create('c', {3, 5}, {-0.84443557262, -0.06822254508, 0.74266910552, 0.61765557527, -0.77555125951, - -0.99536740779, -0.0257304441183, -0.6512106060, -0.345789492130, -1.25485503673, - 0.62955373525, -0.31357592344, 1.03362500667, -0.59279078245, 1.1914824247}); - - auto y = NDArrayFactory::create('c', {1, 5}, {-0.99536740779, -0.0257304441183, -0.6512106060, -0.345789492130, -1.25485503673}); - auto e = NDArrayFactory::create('c', {1, 3}, {0.577452, 0.0, 1.80182}); - auto z = NDArrayFactory::create('c', {1, 3}); - - auto dim = NDArrayFactory::create('c', {1}, {1}); - dim.syncToHost(); - - sd::LaunchContext* context = sd::LaunchContext::defaultContext(); - - Nd4jPointer* extraPointers = nullptr; - #ifdef __CUDABLAS__ - extraPointers = new Nd4jPointer[7] {nullptr, context->getCudaStream(), context->getScalarPointer(), nullptr, context->getCudaSpecialStream(), context->getReductionPointer(), context->getAllocationPointer()}; - #endif + auto x = NDArrayFactory::create( + 'c', {3, 5}, + {-0.84443557262, -0.06822254508, 0.74266910552, 0.61765557527, + -0.77555125951, -0.99536740779, -0.0257304441183, -0.6512106060, + -0.345789492130, -1.25485503673, 0.62955373525, -0.31357592344, + 1.03362500667, -0.59279078245, 1.1914824247}); + + auto y = NDArrayFactory::create( + 'c', {1, 5}, + {-0.99536740779, -0.0257304441183, -0.6512106060, -0.345789492130, + -1.25485503673}); + auto e = + NDArrayFactory::create('c', {1, 3}, {0.577452, 0.0, 1.80182}); + auto z = NDArrayFactory::create('c', {1, 3}); + + auto dim = NDArrayFactory::create('c', {1}, {1}); + dim.syncToHost(); + + sd::LaunchContext* context = sd::LaunchContext::defaultContext(); + + Nd4jPointer* extraPointers = nullptr; +#ifdef __CUDABLAS__ + extraPointers = new Nd4jPointer[7]{nullptr, + context->getCudaStream(), + context->getScalarPointer(), + nullptr, + context->getCudaSpecialStream(), + context->getReductionPointer(), + context->getAllocationPointer()}; +#endif - auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(x.shapeInfo(), {1}); - auto packY = sd::ConstantTadHelper::getInstance().tadForDimensions(y.shapeInfo(), {1}); + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions( + x.shapeInfo(), 1); + auto packY = sd::ConstantTadHelper::getInstance().tadForDimensions( + y.shapeInfo(), 1); - NDArray::prepareSpecialUse({&z}, {&x, &y, &dim}); - OpaqueDataBuffer xBuf(x.dataBuffer()); - OpaqueDataBuffer yBuf(y.dataBuffer()); - OpaqueDataBuffer zBuf(z.dataBuffer()); - OpaqueDataBuffer dimBuf(dim.dataBuffer()); + NDArray::prepareSpecialUse({&z}, {&x, &y, &dim}); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer yBuf(y.dataBuffer()); + OpaqueDataBuffer zBuf(z.dataBuffer()); + OpaqueDataBuffer dimBuf(dim.dataBuffer()); - execReduce3Tad(extraPointers, reduce3::CosineDistance, - &xBuf, x.shapeInfo(), x.specialShapeInfo(), - nullptr, - &yBuf, y.shapeInfo(), y.specialShapeInfo(), - &zBuf, z.shapeInfo(), z.specialShapeInfo(), - &dimBuf, dim.shapeInfo(), dim.specialShapeInfo(), - packX.platformShapeInfo(), packX.platformOffsets(), packY.platformShapeInfo(), packY.platformOffsets()); + execReduce3Tad(extraPointers, reduce3::CosineDistance, &xBuf, x.shapeInfo(), + x.specialShapeInfo(), nullptr, &yBuf, y.shapeInfo(), + y.specialShapeInfo(), &zBuf, z.shapeInfo(), + z.specialShapeInfo(), &dimBuf, dim.shapeInfo(), + dim.specialShapeInfo(), packX.platformShapeInfo(), + packX.platformOffsets(), packY.platformShapeInfo(), + packY.platformOffsets()); - // z.printIndexedBuffer("z"); - NDArray::registerSpecialUse({&z}, {&x, &y, &dim}); - ASSERT_EQ(e, z); - delete []extraPointers; + // z.printIndexedBuffer("z"); + NDArray::registerSpecialUse({&z}, {&x, &y, &dim}); + ASSERT_EQ(e, z); + delete[] extraPointers; } TEST_F(LegacyOpsTests, Reduce3_5) { - auto x = NDArrayFactory::create('c', {3, 5}, {-0.84443557262, -0.06822254508, 0.74266910552, 0.61765557527, -0.77555125951, - -0.99536740779, -0.0257304441183, -0.6512106060, -0.345789492130, -1.25485503673, - 0.62955373525, -0.31357592344, 1.03362500667, -0.59279078245, 1.1914824247}); - - auto y = NDArrayFactory::create('c', {1, 5}, {-0.99536740779, -0.0257304441183, -0.6512106060, -0.345789492130, -1.25485503673}); - auto e = NDArrayFactory::create('c', {1, 3}, {0.577452, 0.0, 1.80182}); - auto z = NDArrayFactory::create('c', {1, 3}); - - auto dim = NDArrayFactory::create('c', {1}, {1}); - dim.syncToHost(); - - sd::LaunchContext* context = sd::LaunchContext::defaultContext(); - - Nd4jPointer* extraPointers = nullptr; - #ifdef __CUDABLAS__ - extraPointers = new Nd4jPointer[7] {nullptr, context->getCudaStream(), context->getScalarPointer(), nullptr, context->getCudaSpecialStream(), context->getReductionPointer(), context->getAllocationPointer()}; - #endif + auto x = NDArrayFactory::create( + 'c', {3, 5}, + {-0.84443557262, -0.06822254508, 0.74266910552, 0.61765557527, + -0.77555125951, -0.99536740779, -0.0257304441183, -0.6512106060, + -0.345789492130, -1.25485503673, 0.62955373525, -0.31357592344, + 1.03362500667, -0.59279078245, 1.1914824247}); + + auto y = NDArrayFactory::create( + 'c', {1, 5}, + {-0.99536740779, -0.0257304441183, -0.6512106060, -0.345789492130, + -1.25485503673}); + auto e = + NDArrayFactory::create('c', {1, 3}, {0.577452, 0.0, 1.80182}); + auto z = NDArrayFactory::create('c', {1, 3}); + + auto dim = NDArrayFactory::create('c', {1}, {1}); + dim.syncToHost(); + + sd::LaunchContext* context = sd::LaunchContext::defaultContext(); + + Nd4jPointer* extraPointers = nullptr; +#ifdef __CUDABLAS__ + extraPointers = new Nd4jPointer[7]{nullptr, + context->getCudaStream(), + context->getScalarPointer(), + nullptr, + context->getCudaSpecialStream(), + context->getReductionPointer(), + context->getAllocationPointer()}; +#endif - auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(x.shapeInfo(), {1}); - auto packY = sd::ConstantTadHelper::getInstance().tadForDimensions(y.shapeInfo(), {1}); + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions( + x.shapeInfo(), 1); + auto packY = sd::ConstantTadHelper::getInstance().tadForDimensions( + y.shapeInfo(), 1); - NDArray::prepareSpecialUse({&z}, {&x, &y, &dim}); + NDArray::prepareSpecialUse({&z}, {&x, &y, &dim}); - OpaqueDataBuffer xBuf(x.dataBuffer()); - OpaqueDataBuffer yBuf(y.dataBuffer()); - OpaqueDataBuffer zBuf(z.dataBuffer()); - OpaqueDataBuffer dimBuf(dim.dataBuffer()); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer yBuf(y.dataBuffer()); + OpaqueDataBuffer zBuf(z.dataBuffer()); + OpaqueDataBuffer dimBuf(dim.dataBuffer()); - execReduce3Tad(extraPointers, reduce3::CosineDistance, - &xBuf, x.shapeInfo(), x.specialShapeInfo(), - nullptr, - &yBuf, y.shapeInfo(), y.specialShapeInfo(), &zBuf, z.shapeInfo(), z.specialShapeInfo(), - &dimBuf, dim.shapeInfo(), dim.specialShapeInfo(), - packX.platformShapeInfo(), packX.platformOffsets(), packY.platformShapeInfo(), packY.platformOffsets()); + execReduce3Tad(extraPointers, reduce3::CosineDistance, &xBuf, x.shapeInfo(), + x.specialShapeInfo(), nullptr, &yBuf, y.shapeInfo(), + y.specialShapeInfo(), &zBuf, z.shapeInfo(), + z.specialShapeInfo(), &dimBuf, dim.shapeInfo(), + dim.specialShapeInfo(), packX.platformShapeInfo(), + packX.platformOffsets(), packY.platformShapeInfo(), + packY.platformOffsets()); - NDArray::registerSpecialUse({&z}, {&x, &y, &dim}); - ASSERT_EQ(e, z); - delete []extraPointers; + NDArray::registerSpecialUse({&z}, {&x, &y, &dim}); + ASSERT_EQ(e, z); + delete[] extraPointers; } TEST_F(LegacyOpsTests, test_Reduce3_All_1) { - auto x = NDArrayFactory::create('c', {1000, 100}); - auto y = NDArrayFactory::create('c', {1, 100}); - auto z = NDArrayFactory::create('c', {1000, 1}); - auto dim = NDArrayFactory::create('c', {1}, {-1}); - - auto tadPackX = sd::ConstantTadHelper::getInstance().tadForDimensions(x.shapeInfo(), -1); - auto tadPackY = sd::ConstantTadHelper::getInstance().tadForDimensions(y.shapeInfo(), -1); - - sd::LaunchContext* context = sd::LaunchContext::defaultContext(); - - Nd4jPointer* extraPointers = nullptr; - #ifdef __CUDABLAS__ - extraPointers = new Nd4jPointer[7] {nullptr, context->getCudaStream(), context->getScalarPointer(), nullptr, context->getCudaSpecialStream(), context->getReductionPointer(), context->getAllocationPointer()}; - #endif + auto x = NDArrayFactory::create('c', {1000, 100}); + auto y = NDArrayFactory::create('c', {1, 100}); + auto z = NDArrayFactory::create('c', {1000, 1}); + auto dim = NDArrayFactory::create('c', {1}, {-1}); + + auto tadPackX = + sd::ConstantTadHelper::getInstance().tadForDimensions(x.shapeInfo(), -1); + auto tadPackY = + sd::ConstantTadHelper::getInstance().tadForDimensions(y.shapeInfo(), -1); + + sd::LaunchContext* context = sd::LaunchContext::defaultContext(); + + Nd4jPointer* extraPointers = nullptr; +#ifdef __CUDABLAS__ + extraPointers = new Nd4jPointer[7]{nullptr, + context->getCudaStream(), + context->getScalarPointer(), + nullptr, + context->getCudaSpecialStream(), + context->getReductionPointer(), + context->getAllocationPointer()}; +#endif - NDArray::prepareSpecialUse({&z}, {&x, &y}); + NDArray::prepareSpecialUse({&z}, {&x, &y}); - OpaqueDataBuffer xBuf(x.dataBuffer()); - OpaqueDataBuffer yBuf(y.dataBuffer()); - OpaqueDataBuffer zBuf(z.dataBuffer()); - OpaqueDataBuffer dimBuf(dim.dataBuffer()); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer yBuf(y.dataBuffer()); + OpaqueDataBuffer zBuf(z.dataBuffer()); + OpaqueDataBuffer dimBuf(dim.dataBuffer()); - execReduce3All(extraPointers, reduce3::EuclideanDistance, &xBuf, x.shapeInfo(), x.specialShapeInfo(), - nullptr, &yBuf, y.shapeInfo(), y.specialShapeInfo(), - &zBuf, z.shapeInfo(), z.specialShapeInfo(), - &dimBuf, dim.shapeInfo(), dim.specialShapeInfo(), - tadPackX.platformShapeInfo(), tadPackX.platformOffsets(), - tadPackY.platformShapeInfo(), tadPackY.platformOffsets()); + execReduce3All(extraPointers, reduce3::EuclideanDistance, &xBuf, + x.shapeInfo(), x.specialShapeInfo(), nullptr, &yBuf, + y.shapeInfo(), y.specialShapeInfo(), &zBuf, z.shapeInfo(), + z.specialShapeInfo(), &dimBuf, dim.shapeInfo(), + dim.specialShapeInfo(), tadPackX.platformShapeInfo(), + tadPackX.platformOffsets(), tadPackY.platformShapeInfo(), + tadPackY.platformOffsets()); - NDArray::registerSpecialUse({&z}, {&x, &y}); + NDArray::registerSpecialUse({&z}, {&x, &y}); - delete []extraPointers; + delete[] extraPointers; } - TEST_F(LegacyOpsTests, test_inverse_broadcast_1) { - auto x = NDArrayFactory::create('c', {4}, {2.0f, 2.0f, 2.0f, 2.0f}); - auto y = NDArrayFactory::create('c', {3, 4}); - auto e = NDArrayFactory::create('c', {3, 4}); - e.assign(2.0f); + auto x = NDArrayFactory::create('c', {4}, {2.0f, 2.0f, 2.0f, 2.0f}); + auto y = NDArrayFactory::create('c', {3, 4}); + auto e = NDArrayFactory::create('c', {3, 4}); + e.assign(2.0f); - auto tadPackY = sd::ConstantTadHelper::getInstance().tadForDimensions(y.shapeInfo(), 1); + auto tadPackY = + sd::ConstantTadHelper::getInstance().tadForDimensions(y.shapeInfo(), 1); - y.tickWriteDevice(); + y.tickWriteDevice(); - NativeOpExecutioner::execInverseBroadcast(LaunchContext::defaultContext(), broadcast::Add, - x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), - y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), - nullptr, 0, - tadPackY.platformShapeInfo(), tadPackY.platformOffsets(), - tadPackY.platformShapeInfo(), tadPackY.platformOffsets()); + NativeOpExecutioner::execInverseBroadcast( + LaunchContext::defaultContext(), broadcast::Add, x.buffer(), + x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), y.buffer(), + y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), y.buffer(), + y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), nullptr, 0, + tadPackY.platformShapeInfo(), tadPackY.platformOffsets(), + tadPackY.platformShapeInfo(), tadPackY.platformOffsets()); - ASSERT_EQ(e, y); + ASSERT_EQ(e, y); } TEST_F(LegacyOpsTests, test_inverse_broadcast_2) { - auto x = NDArrayFactory::create('c', {4}, {2.0f, 2.0f, 2.0f, 2.0f}); - auto y = NDArrayFactory::create('c', {3, 4}); - auto z = NDArrayFactory::create('c', {3, 4}); - auto e = NDArrayFactory::create('c', {3, 4}); - e.assign(false); + auto x = NDArrayFactory::create('c', {4}, {2.0f, 2.0f, 2.0f, 2.0f}); + auto y = NDArrayFactory::create('c', {3, 4}); + auto z = NDArrayFactory::create('c', {3, 4}); + auto e = NDArrayFactory::create('c', {3, 4}); + e.assign(false); - auto row = y(1, {0}); - row.assign(2.0f); + auto row = y(1, {0}); + row.assign(2.0f); - auto erow = e(1, {0}); - erow.assign(true); + auto erow = e(1, {0}); + erow.assign(true); - auto tadPackY = sd::ConstantTadHelper::getInstance().tadForDimensions(y.shapeInfo(), 1); + auto tadPackY = + sd::ConstantTadHelper::getInstance().tadForDimensions(y.shapeInfo(), 1); - z.tickWriteDevice(); + z.tickWriteDevice(); - NativeOpExecutioner::execInverseBroadcastBool(LaunchContext::defaultContext(), broadcast::BoolOps::EqualTo, - x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), - z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - nullptr, - nullptr, 0, - tadPackY.platformShapeInfo(), tadPackY.platformOffsets(), - tadPackY.platformShapeInfo(), tadPackY.platformOffsets()); + NativeOpExecutioner::execInverseBroadcastBool( + LaunchContext::defaultContext(), broadcast::BoolOps::EqualTo, x.buffer(), + x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), y.buffer(), + y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), z.buffer(), + z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr, + 0, tadPackY.platformShapeInfo(), tadPackY.platformOffsets(), + tadPackY.platformShapeInfo(), tadPackY.platformOffsets()); - ASSERT_EQ(e, z); + ASSERT_EQ(e, z); } TEST_F(LegacyOpsTests, test_legacy_reduce_empty_1) { - auto x = NDArrayFactory::create('c', {2, 0, 3}); - auto z = NDArrayFactory::create('c', {2, 3}); - auto e = NDArrayFactory::create('c', {2, 3}); + auto x = NDArrayFactory::create('c', {2, 0, 3}); + auto z = NDArrayFactory::create('c', {2, 3}); + auto e = NDArrayFactory::create('c', {2, 3}); - int dim = 1; + int dim = 1; - NativeOpExecutioner::execReduceSame(LaunchContext::defaultContext(), reduce::SameOps::Sum, - x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, - z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - &dim, 1, x.platformShapeInfo(), nullptr); + NativeOpExecutioner::execReduceSame( + LaunchContext::defaultContext(), reduce::SameOps::Sum, x.buffer(), + x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), nullptr, + z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), &dim, + 1, x.platformShapeInfo(), nullptr); - ASSERT_EQ(e, z); + ASSERT_EQ(e, z); } TEST_F(LegacyOpsTests, test_legacy_reduce_empty_2) { - auto x = NDArrayFactory::create('c', {2, 0, 3}); - auto z = NDArrayFactory::create('c', {2, 3}); - auto e = NDArrayFactory::create('c', {2, 3}); - e.assign(std::numeric_limits::infinity()); + auto x = NDArrayFactory::create('c', {2, 0, 3}); + auto z = NDArrayFactory::create('c', {2, 3}); + auto e = NDArrayFactory::create('c', {2, 3}); + e.assign(std::numeric_limits::infinity()); - int dim = 1; + int dim = 1; - NativeOpExecutioner::execReduceSame(LaunchContext::defaultContext(), reduce::SameOps::Min, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), nullptr, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), &dim, 1, x.platformShapeInfo(), nullptr); + NativeOpExecutioner::execReduceSame( + LaunchContext::defaultContext(), reduce::SameOps::Min, x.buffer(), + x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), nullptr, + z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), &dim, + 1, x.platformShapeInfo(), nullptr); - ASSERT_EQ(e, z); + ASSERT_EQ(e, z); } TEST_F(LegacyOpsTests, test_legacy_reduce_empty_3) { - auto x = NDArrayFactory::create('c', {2, 0, 3}); - auto z = NDArrayFactory::create('c', {2, 3}); - auto e = NDArrayFactory::create('c', {2, 3}); - e.assign(-std::numeric_limits::infinity()); + auto x = NDArrayFactory::create('c', {2, 0, 3}); + auto z = NDArrayFactory::create('c', {2, 3}); + auto e = NDArrayFactory::create('c', {2, 3}); + e.assign(-std::numeric_limits::infinity()); - int dim = 1; + int dim = 1; - NativeOpExecutioner::execReduceSame(LaunchContext::defaultContext(), reduce::SameOps::Max, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), nullptr, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), &dim, 1, x.platformShapeInfo(), nullptr); + NativeOpExecutioner::execReduceSame( + LaunchContext::defaultContext(), reduce::SameOps::Max, x.buffer(), + x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), nullptr, + z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), &dim, + 1, x.platformShapeInfo(), nullptr); - ASSERT_EQ(e, z); + ASSERT_EQ(e, z); } TEST_F(LegacyOpsTests, test_legacy_reduce_empty_4) { - if (!Environment::getInstance().isCPU()) - return; - int a = 0; - - auto x = NDArrayFactory::create('c', {1, 0, 2}); - auto d = NDArrayFactory::create('c', {1}, {a}); - auto z = NDArrayFactory::create('c', {0, 2}); - auto e = NDArrayFactory::create('c', {0, 2}); - - InteropDataBuffer xdb(x.dataBuffer()); - InteropDataBuffer ddb(d.dataBuffer()); - InteropDataBuffer zdb(z.dataBuffer()); + if (!Environment::getInstance().isCPU()) return; + int a = 0; + auto x = NDArrayFactory::create('c', {1, 0, 2}); + auto d = NDArrayFactory::create('c', {1}, {a}); + auto z = NDArrayFactory::create('c', {0, 2}); + auto e = NDArrayFactory::create('c', {0, 2}); - ::execReduceSame2(nullptr, reduce::SameOps::Sum, - &xdb, x.shapeInfo(), x.specialShapeInfo(), - nullptr, - &zdb, z.shapeInfo(), z.specialShapeInfo(), - &ddb, d.shapeInfo(), d.specialShapeInfo()); + InteropDataBuffer xdb(x.dataBuffer()); + InteropDataBuffer ddb(d.dataBuffer()); + InteropDataBuffer zdb(z.dataBuffer()); + ::execReduceSame2(nullptr, reduce::SameOps::Sum, &xdb, x.shapeInfo(), + x.specialShapeInfo(), nullptr, &zdb, z.shapeInfo(), + z.specialShapeInfo(), &ddb, d.shapeInfo(), + d.specialShapeInfo()); } TEST_F(LegacyOpsTests, test_legacy_transform_float_1) { - auto x = NDArrayFactory::create('c', {1, 0, 4}); + auto x = NDArrayFactory::create('c', {1, 0, 4}); - NativeOpExecutioner::execTransformFloat(LaunchContext::defaultContext(), transform::FloatOps::RSqrt, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), nullptr, nullptr, nullptr); + NativeOpExecutioner::execTransformFloat( + LaunchContext::defaultContext(), transform::FloatOps::RSqrt, x.buffer(), + x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), x.buffer(), + x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), nullptr, nullptr, + nullptr); } diff --git a/libnd4j/tests_cpu/layers_tests/ListOperationsTests.cpp b/libnd4j/tests_cpu/layers_tests/ListOperationsTests.cpp index 04e4a70e8cf3..541a5bcd2df0 100644 --- a/libnd4j/tests_cpu/layers_tests/ListOperationsTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/ListOperationsTests.cpp @@ -18,99 +18,88 @@ // @author raver119@gmail.com // -#include "testlayers.h" #include -#include #include +#include "testlayers.h" + using namespace sd; using namespace sd::ops; -class ListOperationsTests : public testing::Test { - -}; +class ListOperationsTests : public testing::Test {}; TEST_F(ListOperationsTests, BasicTest_Write_1) { - NDArrayList list(5); - auto x = NDArrayFactory::create('c', {128}); - x.linspace(1); - - sd::ops::write_list op; + NDArrayList list(5); + auto x = NDArrayFactory::create('c', {128}); + x.linspace(1); - auto result = op.execute(&list, {&x}, {}, {1}); + sd::ops::write_list op; - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto result = op.execute(list, {&x}, {}, {1}); - ASSERT_EQ(1, list.elements()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto result2 = op.execute(&list, {&x}, {}, {2}); + ASSERT_EQ(1, list.elements()); - ASSERT_EQ(2, list.elements()); + auto result2 = op.execute(list, {&x}, {}, {2}); - - + ASSERT_EQ(2, list.elements()); } TEST_F(ListOperationsTests, BasicTest_Stack_1) { - NDArrayList list(10); - auto exp = NDArrayFactory::create('c', {10, 100}); - auto tads = exp.allTensorsAlongDimension({1}); - for (int e = 0; e < 10; e++) { - auto row = NDArrayFactory::create_('c', {100}); - row->assign((double) e); - list.write(e, row); - tads.at(e)->assign(row); - } - - sd::ops::stack_list op; + NDArrayList list(10); + auto exp = NDArrayFactory::create('c', {10, 100}); + auto tads = exp.allTensorsAlongDimension({1}); + for (int e = 0; e < 10; e++) { + auto row = NDArrayFactory::create('c', {100}); + row.assign((double)e); + list.write(e, row); + tads.at(e).assign(row); + } - auto result = op.execute(&list, {}, {}, {1}); + sd::ops::stack_list op; - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto result = op.execute(list, {}, {}, {1}); - auto z = result.at(0); - // z->printShapeInfo(); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + // z->printShapeInfo(); - + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(ListOperationsTests, BasicTest_UnStackList_1) { - NDArrayList list(0, true); - auto x = NDArrayFactory::create('c', {10, 100}); - auto tads = x.allTensorsAlongDimension({1}); - for (int e = 0; e < 10; e++) { - auto row = NDArrayFactory::create_('c', {100}); - row->assign((double) e); - //list.write(e, row); - tads.at(e)->assign(row); - delete row; - } - - sd::ops::unstack_list op; - - auto result = op.execute(&list, {&x}, {}, {0}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(list.elements(), 10); - -// auto z = result.at(0); -// z->printShapeInfo("The first of"); -// ASSERT_TRUE(exp.isSameShape(z)); -// ASSERT_TRUE(exp.equalsTo(z)); - for (int e = 0; e < 10; e++) { - auto row = list.read(e); - ASSERT_TRUE(row->equalsTo(tads.at(e))); - //list.write(e, row); - delete row; - } - - + NDArrayList list(0, true); + auto x = NDArrayFactory::create('c', {10, 100}); + auto tads = x.allTensorsAlongDimension({1}); + for (int e = 0; e < 10; e++) { + auto row = NDArrayFactory::create('c', {100}); + row.assign((double)e); + // list.write(e, row); + tads.at(e).assign(row); + } + + sd::ops::unstack_list op; + + auto result = op.execute(list, {&x}, {}, {0}); + + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_EQ(10, list.elements()); + + // auto z = result.at(0); + // z->printShapeInfo("The first of"); + // ASSERT_TRUE(exp.isSameShape(z)); + // ASSERT_TRUE(exp.equalsTo(z)); + for (int e = 0; e < 10; e++) { + auto row = list.read(e); + ASSERT_TRUE(row.equalsTo(tads.at(e))); + // list.write(e, row); + } } -//TEST_F(ListOperationsTests, BasicTest_UnStackList_2) { +// TEST_F(ListOperationsTests, BasicTest_UnStackList_2) { //// NDArrayList list(0, true); // auto x = NDArrayFactory::create('c', {10, 100}); // auto tads = x.allTensorsAlongDimension({1}); @@ -139,523 +128,241 @@ TEST_F(ListOperationsTests, BasicTest_UnStackList_1) { // //list.write(e, row); // } // -// +// // delete tads; //} TEST_F(ListOperationsTests, BasicTest_Read_1) { - NDArrayList list(10); - auto exp = NDArrayFactory::create('c', {1, 100}); - exp.assign(4.0f); - - for (int e = 0; e < 10; e++) { - auto row = NDArrayFactory::create_('c', {1, 100}); - row->assign((double) e); - list.write(e, new NDArray(row->dup())); + NDArrayList list(10); + auto exp = NDArrayFactory::create('c', {1, 100}); + exp.assign(4.0f); - delete row; - } + for (int e = 0; e < 10; e++) { + auto row = NDArrayFactory::create_('c', {1, 100}); + row->assign((double)e); + list.write(e, row->dup()); - sd::ops::read_list op; + delete row; + } - auto result = op.execute(&list, {}, {}, {4}); + sd::ops::read_list op; - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto result = op.execute(list, {}, {}, {4}); - auto z = result.at(0); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); - + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(ListOperationsTests, BasicTest_Pick_1) { - NDArrayList list(10); - auto exp = NDArrayFactory::create('c', {4, 100}); - - for (int e = 0; e < 10; e++) { - auto row = NDArrayFactory::create_('c', {100}); - row->assign((double) e); - list.write(e, new NDArray(row->dup())); + NDArrayList list(10); + auto exp = NDArrayFactory::create('c', {4, 100}); - delete row; - } + for (int e = 0; e < 10; e++) { + auto row = NDArrayFactory::create_('c', {100}); + row->assign((double)e); + list.write(e, row->dup()); - auto tads = exp.allTensorsAlongDimension({1}); - tads.at(0)->assign(1.0f); - tads.at(1)->assign(1.0f); - tads.at(2)->assign(3.0f); - tads.at(3)->assign(3.0f); + delete row; + } + auto tads = exp.allTensorsAlongDimension({1}); + tads.at(0).assign(1.0f); + tads.at(1).assign(1.0f); + tads.at(2).assign(3.0f); + tads.at(3).assign(3.0f); - sd::ops::pick_list op; - auto result = op.execute(&list, {}, {}, {1, 1, 3, 3}); + sd::ops::pick_list op; + auto result = op.execute(list, {}, {}, {1, 1, 3, 3}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - - + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(ListOperationsTests, BasicTest_Size_1) { - NDArrayList list(10); - auto exp = NDArrayFactory::create(10); - for (int e = 0; e < 10; e++) { - auto row = NDArrayFactory::create_('c', {100}); - row->assign((double) e); - list.write(e, new NDArray(row->dup())); - - delete row; - } + NDArrayList list(10); + auto exp = NDArrayFactory::create(10); + for (int e = 0; e < 10; e++) { + auto row = NDArrayFactory::create_('c', {100}); + row->assign((double)e); + list.write(e, row->dup()); - sd::ops::size_list op; + delete row; + } - auto result = op.execute(&list, {}, {}, {1}); + sd::ops::size_list op; - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto result = op.execute(list, {}, {}, {1}); - auto z = result.at(0); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); - + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(ListOperationsTests, BasicTest_Create_1) { - auto matrix = NDArrayFactory::create('c', {3, 2}); - matrix.linspace(1); - - sd::ops::create_list op; + auto matrix = NDArrayFactory::create('c', {3, 2}); + matrix.linspace(1); - auto result = op.execute(nullptr, {&matrix}, {}, {1, 1}); + sd::ops::create_list op; - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto result = op.execute(NDArrayList(), {&matrix}, {}, {1, 1}); - // we return flow as well - ASSERT_EQ(1, result.size()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - + // we return flow as well + ASSERT_EQ(1, result.size()); } TEST_F(ListOperationsTests, BasicTest_Split_1) { - NDArrayList list(0, true); + NDArrayList list(0, true); - auto exp0 = NDArrayFactory::create('c', {2, 5}); - auto exp1 = NDArrayFactory::create('c', {3, 5}); - auto exp2 = NDArrayFactory::create('c', {5, 5}); + auto exp0 = NDArrayFactory::create('c', {2, 5}); + auto exp1 = NDArrayFactory::create('c', {3, 5}); + auto exp2 = NDArrayFactory::create('c', {5, 5}); - auto matrix = NDArrayFactory::create('c', {10, 5}); + auto matrix = NDArrayFactory::create('c', {10, 5}); - auto lengths = NDArrayFactory::create('c', {3}); - lengths.p(0, 2); - lengths.p(1, 3); - lengths.p(2, 5); + auto lengths = NDArrayFactory::create('c', {3}); + lengths.p(0, 2); + lengths.p(1, 3); + lengths.p(2, 5); - auto tads = matrix.allTensorsAlongDimension({1}); + auto tads = matrix.allTensorsAlongDimension({1}); - auto tads0 = exp0.allTensorsAlongDimension({1}); - auto tads1 = exp1.allTensorsAlongDimension({1}); - auto tads2 = exp2.allTensorsAlongDimension({1}); + auto tads0 = exp0.allTensorsAlongDimension({1}); + auto tads1 = exp1.allTensorsAlongDimension({1}); + auto tads2 = exp2.allTensorsAlongDimension({1}); - int cnt0 = 0; - int cnt1 = 0; - int cnt2 = 0; - for (int e = 0; e < 10; e++) { - auto row = NDArrayFactory::create_('c', {5}); - row->assign((double) e); - tads.at(e)->assign(row); + int cnt0 = 0; + int cnt1 = 0; + int cnt2 = 0; + for (int e = 0; e < 10; e++) { + auto row = NDArrayFactory::create('c', {5}); + row.assign((double)e); + tads.at(e).assign(row); - if (e < 2) - tads0.at(cnt0++)->assign(row); - else if (e < 5) - tads1.at(cnt1++)->assign(row); - else - tads2.at(cnt2++)->assign(row); + if (e < 2) + tads0.at(cnt0++).assign(row); + else if (e < 5) + tads1.at(cnt1++).assign(row); + else + tads2.at(cnt2++).assign(row); + } - delete row; - } + sd::ops::split_list op; + auto result = op.execute(list, {&matrix, &lengths}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); - sd::ops::split_list op; - auto result = op.execute(&list, {&matrix, &lengths}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); + ASSERT_EQ(3, list.height()); - ASSERT_EQ(3, list.height()); + ASSERT_TRUE(exp0.isSameShape(list.readRaw(0))); + ASSERT_TRUE(exp0.equalsTo(list.readRaw(0))); - ASSERT_TRUE(exp0.isSameShape(list.readRaw(0))); - ASSERT_TRUE(exp0.equalsTo(list.readRaw(0))); + ASSERT_TRUE(exp1.isSameShape(list.readRaw(1))); + ASSERT_TRUE(exp1.equalsTo(list.readRaw(1))); - ASSERT_TRUE(exp1.isSameShape(list.readRaw(1))); - ASSERT_TRUE(exp1.equalsTo(list.readRaw(1))); - - ASSERT_TRUE(exp2.isSameShape(list.readRaw(2))); - ASSERT_TRUE(exp2.equalsTo(list.readRaw(2))); - - + ASSERT_TRUE(exp2.isSameShape(list.readRaw(2))); + ASSERT_TRUE(exp2.equalsTo(list.readRaw(2))); } TEST_F(ListOperationsTests, BasicTest_Scatter_1) { - NDArrayList list(0, true); - auto s = NDArrayFactory::create(0.0); + NDArrayList list(0, true); + auto s = NDArrayFactory::create(0.0); - auto matrix = NDArrayFactory::create('c', {10, 5}); - auto tads = matrix.allTensorsAlongDimension({1}); - for (int e = 0; e < 10; e++) { - auto row = NDArrayFactory::create_('c', {1, 5}); - row->assign((double) e); - tads.at(e)->assign(row); + auto matrix = NDArrayFactory::create('c', {10, 5}); + auto tads = matrix.allTensorsAlongDimension({1}); + for (int e = 0; e < 10; e++) { + auto row = NDArrayFactory::create_('c', {1, 5}); + row->assign((double)e); + tads.at(e).assign(row); - delete row; - } - auto indices = NDArrayFactory::create('c', {1, 10}); - for (int e = 0; e < matrix.rows(); e++) - indices.p(e, 9 - e); + delete row; + } + auto indices = NDArrayFactory::create('c', {1, 10}); + for (int e = 0; e < matrix.rows(); e++) indices.p(e, 9 - e); - sd::ops::scatter_list op; - auto result = op.execute(&list, {&indices, &matrix, &s}, {}, {}); + sd::ops::scatter_list op; + auto result = op.execute(list, {&indices, &matrix, &s}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - for (int e = 0; e < 10; e++) { - auto row = tads.at(9 - e); - auto chunk = list.readRaw(e); + for (int e = 0; e < 10; e++) { + auto row = tads.at(9 - e); + auto chunk = list.readRaw(e); - ASSERT_TRUE(chunk->isSameShape(row)); + ASSERT_TRUE(chunk.isSameShape(row)); - ASSERT_TRUE(chunk->equalsTo(row)); - } - + ASSERT_TRUE(chunk.equalsTo(row)); + } } TEST_F(ListOperationsTests, BasicTest_Clone_1) { - auto list = new NDArrayList(0, true); - - VariableSpace variableSpace; - auto var = new Variable(nullptr, nullptr, -1, 0); - var->setNDArrayList(list); + NDArrayList list(0, true); - variableSpace.putVariable(-1, var); - variableSpace.trackList(list); + VariableSpace variableSpace; + auto var = std::make_shared(list, "", -1); + variableSpace.putVariable(-1, var); - Context block(1, &variableSpace); - block.pickInput(-1); + Context block(1, &variableSpace); + block.pickInput(-1); - sd::ops::clone_list op; + sd::ops::clone_list op; - ASSERT_TRUE(list == block.variable(0)->getNDArrayList()); + // ASSERT_TRUE(list == block.variable(0)->getNDArrayList().get()); - auto result = op.execute(&block); + auto result = op.execute(&block); - ASSERT_EQ(ND4J_STATUS_OK, result); + ASSERT_EQ(ND4J_STATUS_OK, result); - auto resVar = variableSpace.getVariable(1); + auto resVar = variableSpace.getVariable(1); - auto resList = resVar->getNDArrayList(); + auto resList = resVar->getNDArrayList().get(); - ASSERT_TRUE( resList != nullptr); + ASSERT_TRUE(resList != nullptr); - ASSERT_TRUE(list->equals(*resList)); + ASSERT_TRUE(list.equals(*resList)); } TEST_F(ListOperationsTests, BasicTest_Gather_1) { - NDArrayList list(0, true); - for (int e = 0; e < 10; e++) { - auto row = NDArrayFactory::create_('c', {3}); - row->assign((double) e); - list.write(e, new NDArray(row->dup())); - - delete row; - } - - auto exp = NDArrayFactory::create('c', {10, 3}); - auto tads = exp.allTensorsAlongDimension({1}); - for (int e = 0; e < 10; e++) { - auto tad = tads.at(9 - e); - tad->assign(e); - } - - auto indices = NDArrayFactory::create('c', {1, 10}); - indices.linspace(9, -1); - - sd::ops::gather_list op; - auto result = op.execute(&list, {&indices}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(1, result.size()); - - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - - //exp.printIndexedBuffer("e"); - //z->printIndexedBuffer("z"); - - ASSERT_TRUE(exp.equalsTo(z)); - - -} - -TEST_F(ListOperationsTests, GraphTests_Sequential_1) { - Graph graph; - - auto matrix = NDArrayFactory::create_('c', {3, 3}); - auto tads = matrix->allTensorsAlongDimension({1}); - for (int e = 0; e < tads.size(); e++) { - tads.at(e)->assign((float) (e+1)); - } - - - auto exp = NDArrayFactory::create('c', {3, 3}); - auto tadsExp = exp.allTensorsAlongDimension({1}); - tadsExp.at(0)->assign(0.f); - tadsExp.at(1)->assign(-1.f); - tadsExp.at(2)->assign(-2.f); - - auto indices = NDArrayFactory::valueOf({3}, 1, 'c'); - //indices->linspace(0); - - - auto variableSpace = graph.getVariableSpace(); - variableSpace->putVariable(-1, matrix); - variableSpace->putVariable(-2, indices); - - - auto nodeA = new Node(OpType_TRANSFORM_SAME, 0, 1, {-1}); - - // creating list - sd::ops::create_list opB; - auto nodeB = new Node(&opB, 2, {1},{},{}, 0.0f, {}, {0, 1}); - //nodeB->setCustomOp(&opB); - - // filling list with matrix - sd::ops::split_list opC; - auto nodeC = new Node(&opC, 3, {2, 1, -2}); - //nodeC->setCustomOp(&opC); - - // reading chunks from List. We're adding op number 3 in inputs, to ensure graph will execute this node after split - sd::ops::read_list opD; - auto nodeD0 = new Node(&opD, 5, {2, 3}, {},{}, 0.0f, {}, {0}); - auto nodeD1 = new Node(&opD, 6, {2, 3}, {},{}, 0.0f, {}, {1}); - auto nodeD2 = new Node(&opD, 7, {2, 3}, {},{}, 0.0f, {}, {2}); - //nodeD0->setCustomOp(&opD); - //nodeD1->setCustomOp(&opD); - //nodeD2->setCustomOp(&opD); - - // using OneMinus on each chunk separately - auto nodeE0 = new Node(OpType_TRANSFORM_SAME, sd::transform::OneMinus, 10, {5}); - auto nodeE1 = new Node(OpType_TRANSFORM_SAME, sd::transform::OneMinus, 11, {6}); - auto nodeE2 = new Node(OpType_TRANSFORM_SAME, sd::transform::OneMinus, 12, {7}); - - // writing chunks back to the List - sd::ops::write_list opF; - auto nodeF0 = new Node(&opF, 15, {2, 10}, {},{}, 0.0f, {}, {0}); - auto nodeF1 = new Node(&opF, 16, {2, 11}, {},{}, 0.0f, {}, {1}); - auto nodeF2 = new Node(&opF, 17, {2, 12}, {},{}, 0.0f, {}, {2}); - -// nodeF0->setCustomOp(&opF); -// nodeF1->setCustomOp(&opF); -// nodeF2->setCustomOp(&opF); - - // now we're stacking chunks back to matrix state - sd::ops::stack_list opG; - auto nodeG = new Node(&opG, 20, {2, 15, 16, 17}); - //auto nodeG = new Node(OpType_CUSTOM, 0, 20, {2}); - -// nodeG->setCustomOp(&opG); - - - graph.addNode(nodeA); - graph.addNode(nodeB); - graph.addNode(nodeC); - graph.addNode(nodeD0); - graph.addNode(nodeD1); - graph.addNode(nodeD2); - graph.addNode(nodeE0); - graph.addNode(nodeE1); - graph.addNode(nodeE2); - - graph.addNode(nodeF0); - graph.addNode(nodeF1); - graph.addNode(nodeF2); - - graph.addNode(nodeG); - - // let's also validate structural integrity - graph.buildGraph(); - - ASSERT_EQ(0, nodeA->getLayer()); - ASSERT_EQ(1, nodeB->getLayer()); - ASSERT_EQ(2, nodeC->getLayer()); - - ASSERT_EQ(3, nodeD0->getLayer()); - ASSERT_EQ(3, nodeD1->getLayer()); - ASSERT_EQ(3, nodeD2->getLayer()); - - ASSERT_EQ(4, nodeE0->getLayer()); - ASSERT_EQ(4, nodeE1->getLayer()); - ASSERT_EQ(4, nodeE2->getLayer()); - - ASSERT_EQ(5, nodeF0->getLayer()); - ASSERT_EQ(5, nodeF1->getLayer()); - ASSERT_EQ(5, nodeF2->getLayer()); - - ASSERT_EQ(6, nodeG->getLayer()); - - auto result = GraphExecutioner::execute(&graph); - ASSERT_EQ(ND4J_STATUS_OK, result); - - ASSERT_TRUE(variableSpace->hasVariable(2)); - auto list = variableSpace->getVariable(2)->getNDArrayList(); - - ASSERT_TRUE(list != nullptr); - - ASSERT_EQ(3, list->height()); - ASSERT_EQ(3, list->elements()); - - - ASSERT_TRUE(variableSpace->hasVariable(20)); - - auto stack = variableSpace->getVariable(20)->getNDArray(); - - ASSERT_TRUE(stack != nullptr); - - ASSERT_TRUE(exp.isSameShape(stack)); - ASSERT_TRUE(exp.equalsTo(stack)); -} - - -TEST_F(ListOperationsTests, GraphTests_Sequential_2) { - Graph graph; - - auto scalar = NDArrayFactory::create_(0.0f); - auto matrix = NDArrayFactory::create_('c', {3, 3}); - auto tads = matrix->allTensorsAlongDimension({1}); - for (int e = 0; e < tads.size(); e++) { - tads.at(e)->assign((float) (e+1)); - } - - auto exp = NDArrayFactory::create('c', {3, 3}); - auto tadsExp = exp.allTensorsAlongDimension({1}); - tadsExp.at(0)->assign(0.f); - tadsExp.at(1)->assign(-1.f); - tadsExp.at(2)->assign(-2.f); - - //auto indices = NDArray::valueOf({1, 3}, 1.0f, 'c'); - auto indices = NDArrayFactory::create_('c', {1, 3}); - indices->linspace(0); - - - auto variableSpace = graph.getVariableSpace(); - variableSpace->putVariable(-1, matrix); - variableSpace->putVariable(-2, indices); - variableSpace->putVariable(-3, scalar); - - - auto nodeA = new Node(OpType_TRANSFORM_SAME, 0, 1, {-1}); - - // creating list - sd::ops::create_list opB; - auto nodeB = new Node(&opB, 2, {1},{},{}, 0.0f, {}, {0, 1}); -// nodeB->setCustomOp(&opB); - - // filling list with matrix - sd::ops::scatter_list opC; - auto nodeC = new Node(&opC, 3, {2, -2, 1, -3}); - - //nodeC->setCustomOp(&opC); - - sd::ops::read_list opD; - auto nodeD0 = new Node(&opD, 5, {2, 3}, {},{}, 0.0f, {}, {0}); - auto nodeD1 = new Node(&opD, 6, {2, 3, 15}, {},{}, 0.0f, {}, {1}); - auto nodeD2 = new Node(&opD, 7, {2, 3, 16}, {},{}, 0.0f, {}, {2}); - -// nodeD0->setCustomOp(&opD); -// nodeD1->setCustomOp(&opD); -// nodeD2->setCustomOp(&opD); - - - // using OneMinus on each chunk separately - auto nodeE0 = new Node(OpType_TRANSFORM_SAME, sd::transform::OneMinus, 10, {5}); - auto nodeE1 = new Node(OpType_TRANSFORM_SAME, sd::transform::OneMinus, 11, {6}); - auto nodeE2 = new Node(OpType_TRANSFORM_SAME, sd::transform::OneMinus, 12, {7}); - - // writing chunks back to the List - sd::ops::write_list opF; - auto nodeF0 = new Node(&opF, 15, {2, 10}, {},{}, 0.0f, {}, {0}); - auto nodeF1 = new Node(&opF, 16, {2, 11}, {},{}, 0.0f, {}, {1}); - auto nodeF2 = new Node(&opF, 17, {2, 12}, {},{}, 0.0f, {}, {2}); - -// nodeF0->setCustomOp(&opF); -// nodeF1->setCustomOp(&opF); -// nodeF2->setCustomOp(&opF); - - // now we're gathering chunks back to matrix state - sd::ops::pick_list opG; - auto nodeG = new Node(&opG, 20, {2, -2, 15, 16, 17}); - //auto nodeG = new Node(OpType_CUSTOM, 0, 20, {2}); - - //nodeG->setCustomOp(&opG); - - graph.addNode(nodeA); - graph.addNode(nodeB); - graph.addNode(nodeC); - graph.addNode(nodeD0); - graph.addNode(nodeD1); - graph.addNode(nodeD2); - graph.addNode(nodeE0); - graph.addNode(nodeE1); - graph.addNode(nodeE2); - - graph.addNode(nodeF0); - graph.addNode(nodeF1); - graph.addNode(nodeF2); - - graph.addNode(nodeG); - - // let's also validate structural integrity - graph.buildGraph(); - - ASSERT_EQ(0, nodeA->getLayer()); - ASSERT_EQ(1, nodeB->getLayer()); - ASSERT_EQ(2, nodeC->getLayer()); - - ASSERT_EQ(3, nodeD0->getLayer()); - ASSERT_EQ(4, nodeE0->getLayer()); - ASSERT_EQ(5, nodeF0->getLayer()); - - ASSERT_EQ(6, nodeD1->getLayer()); - ASSERT_EQ(7, nodeE1->getLayer()); - ASSERT_EQ(8, nodeF1->getLayer()); - - ASSERT_EQ(9, nodeD2->getLayer()); - ASSERT_EQ(10, nodeE2->getLayer()); - ASSERT_EQ(11, nodeF2->getLayer()); - - ASSERT_EQ(12, nodeG->getLayer()); - + NDArrayList list(0, true); + for (int e = 0; e < 10; e++) { + auto row = NDArrayFactory::create('c', {3}); + row.assign((double)e); + list.write(e, row.dup()); + } - auto result = GraphExecutioner::execute(&graph); - ASSERT_EQ(ND4J_STATUS_OK, result); + auto exp = NDArrayFactory::create('c', {10, 3}); + auto tads = exp.allTensorsAlongDimension({1}); + for (int e = 0; e < 10; e++) { + auto tad = tads.at(9 - e); + tad.assign(e); + } - ASSERT_TRUE(variableSpace->hasVariable(2)); - auto list = variableSpace->getVariable(2)->getNDArrayList(); + auto indices = NDArrayFactory::create('c', {1, 10}); + indices.linspace(9, -1); - ASSERT_TRUE(list != nullptr); + sd::ops::gather_list op; + auto result = op.execute(list, {&indices}, {}, {}); - ASSERT_EQ(3, list->height()); - ASSERT_EQ(3, list->elements()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(1, result.size()); - ASSERT_TRUE(variableSpace->hasVariable(20)); + auto z = result.at(0); - auto stack = variableSpace->getVariable(20)->getNDArray(); + ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(stack != nullptr); + // exp.printIndexedBuffer("e"); + // z->printIndexedBuffer("z"); - ASSERT_TRUE(exp.isSameShape(stack)); - ASSERT_TRUE(exp.equalsTo(stack)); + ASSERT_TRUE(exp.equalsTo(z)); } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/LoopCoordsHelperTests.cpp b/libnd4j/tests_cpu/layers_tests/LoopCoordsHelperTests.cpp index 976e89550201..b3cf4833292c 100644 --- a/libnd4j/tests_cpu/layers_tests/LoopCoordsHelperTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/LoopCoordsHelperTests.cpp @@ -14,210 +14,202 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - // - // @author Abdelrauf - // +// +// @author Abdelrauf +// -#include "testlayers.h" #include + #include + +#include "testlayers.h" using namespace sd; class LoopCoordsHelper : public testing::Test { -public: - + public: }; - -template -FORCEINLINE -typename std::enable_if<(Rank - 1 == rankIndex), bool>::type +template +FORCEINLINE typename std::enable_if<(Rank - 1 == rankIndex), bool>::type eq_strides(CoordsState& cbs, const Nd4jLong* strides) { - return STRIDE(cbs, rankIndex) == strides[rankIndex]; + return STRIDE(cbs, rankIndex) == strides[rankIndex]; } -template -FORCEINLINE -typename std::enable_if<(Rank - 1 != rankIndex), bool>::type +template +FORCEINLINE typename std::enable_if<(Rank - 1 != rankIndex), bool>::type eq_strides(CoordsState& cbs, const Nd4jLong* strides) { - return STRIDE(cbs, rankIndex) == strides[rankIndex] && eq_strides(cbs, strides); + return STRIDE(cbs, rankIndex) == strides[rankIndex] && + eq_strides(cbs, strides); } -template -FORCEINLINE -typename std::enable_if<(Rank - 1 == rankIndex), bool>::type -eq_zip_strides(ZipCoordsState& cbs, const Nd4jLong* strides1, const Nd4jLong* strides2) { - return ZIP_STRIDE1(cbs, rankIndex) == strides1[rankIndex] && ZIP_STRIDE2(cbs, rankIndex) == strides2[rankIndex]; +template +FORCEINLINE typename std::enable_if<(Rank - 1 == rankIndex), bool>::type +eq_zip_strides(ZipCoordsState& cbs, const Nd4jLong* strides1, + const Nd4jLong* strides2) { + return ZIP_STRIDE1(cbs, rankIndex) == strides1[rankIndex] && + ZIP_STRIDE2(cbs, rankIndex) == strides2[rankIndex]; } -template -FORCEINLINE -typename std::enable_if<(Rank - 1 != rankIndex), bool>::type -eq_zip_strides(ZipCoordsState& cbs, const Nd4jLong* strides1, const Nd4jLong* strides2) { - return ZIP_STRIDE1(cbs, rankIndex) == strides1[rankIndex] && ZIP_STRIDE2(cbs, rankIndex) == strides2[rankIndex] - && eq_zip_strides(cbs, strides1, strides2); +template +FORCEINLINE typename std::enable_if<(Rank - 1 != rankIndex), bool>::type +eq_zip_strides(ZipCoordsState& cbs, const Nd4jLong* strides1, + const Nd4jLong* strides2) { + return ZIP_STRIDE1(cbs, rankIndex) == strides1[rankIndex] && + ZIP_STRIDE2(cbs, rankIndex) == strides2[rankIndex] && + eq_zip_strides(cbs, strides1, strides2); } - - - TEST_F(LoopCoordsHelper, Init_Tests) { + constexpr size_t test_Index = 131; + constexpr size_t Rank = 5; - constexpr size_t test_Index = 131; - constexpr size_t Rank = 5; + Nd4jLong shape[Rank] = {3, 5, 7, 8, 9}; + Nd4jLong multiply_st[] = {2, 3, 3, 5, 6, 7, 9, 3}; + Nd4jLong strides_c[Rank]; + Nd4jLong strides_f[Rank]; - Nd4jLong shape[Rank] = { 3, 5 ,7, 8, 9}; - Nd4jLong multiply_st[] = { 2,3,3,5,6,7,9,3 }; - Nd4jLong strides_c[Rank] ; - Nd4jLong strides_f[Rank]; + Nd4jLong coords[Rank]; + Nd4jLong coords_f[Rank]; - Nd4jLong coords[Rank]; - Nd4jLong coords_f[Rank]; + strides_f[0] = multiply_st[0] * shape[0]; + strides_c[Rank - 1] = multiply_st[Rank - 1] * shape[Rank - 1]; - strides_f[0] = multiply_st[0] * shape[0]; - strides_c[Rank-1] = multiply_st[Rank-1] * shape[Rank-1]; + for (int i = 1; i < Rank; i++) { + strides_f[i] = strides_f[i - 1] * multiply_st[i] * shape[i]; + } - for (int i = 1; i < Rank; i++) { - strides_f[i] = strides_f[i - 1] * multiply_st[i] * shape[i]; - } + for (int i = Rank - 2; i >= 0; i--) { + strides_c[i] = strides_c[i + 1] * multiply_st[i] * shape[i]; + } - for (int i = Rank-2; i >=0; i--) { - strides_c[i] = strides_c[i+1] * multiply_st[i] * shape[i]; - } + // init our base coords + index2coords_C(test_Index, Rank, shape, coords); + index2coords_F(test_Index, Rank, shape, coords_f); - //init our base coords - index2coords_C(test_Index, Rank, shape, coords); - index2coords_F(test_Index, Rank, shape, coords_f); + size_t offset_calc = offset_from_coords(strides_c, coords, Rank); + size_t offset_calc_f = offset_from_coords(strides_f, coords_f, Rank); + CoordsState cts; + CoordsState cts_f; - size_t offset_calc = offset_from_coords(strides_c, coords, Rank); - size_t offset_calc_f = offset_from_coords(strides_f, coords_f, Rank); - - CoordsState cts; - CoordsState cts_f; + ZipCoordsState zcts; + ZipCoordsState zcts_f; - ZipCoordsState zcts; - ZipCoordsState zcts_f; + size_t offset = init_coords(cts, test_Index, shape, strides_c); + size_t offset_f = + init_coords(cts_f, test_Index, shape, strides_f); - size_t offset = init_coords(cts, test_Index, shape, strides_c); - size_t offset_f = init_coords(cts_f, test_Index, shape, strides_f); - - zip_size_t zoffset = init_coords(zcts, test_Index, shape, strides_c, strides_c); - zip_size_t zoffset_f = init_coords(zcts_f, test_Index, shape, strides_f, strides_f); - - ASSERT_TRUE(eq_coords(cts, coords)); - ASSERT_TRUE(eq_coords(cts_f, coords_f)); + zip_size_t zoffset = + init_coords(zcts, test_Index, shape, strides_c, strides_c); + zip_size_t zoffset_f = init_coords(zcts_f, test_Index, shape, + strides_f, strides_f); - ASSERT_TRUE(eq_zip_coords(zcts, coords)); - ASSERT_TRUE(eq_zip_coords(zcts_f, coords_f)); + ASSERT_TRUE(eq_coords(cts, coords)); + ASSERT_TRUE(eq_coords(cts_f, coords_f)); - ASSERT_TRUE(eq_strides(cts,strides_c)); - ASSERT_TRUE(eq_strides(cts_f,strides_f)); + ASSERT_TRUE(eq_zip_coords(zcts, coords)); + ASSERT_TRUE(eq_zip_coords(zcts_f, coords_f)); - ASSERT_TRUE(eq_zip_strides(zcts, strides_c, strides_c)); - ASSERT_TRUE(eq_zip_strides(zcts_f, strides_f, strides_f)); + ASSERT_TRUE(eq_strides(cts, strides_c)); + ASSERT_TRUE(eq_strides(cts_f, strides_f)); + ASSERT_TRUE(eq_zip_strides(zcts, strides_c, strides_c)); + ASSERT_TRUE(eq_zip_strides(zcts_f, strides_f, strides_f)); - ASSERT_EQ(offset , offset_calc); - ASSERT_EQ(zoffset.first , offset_calc); - ASSERT_EQ(zoffset.second , offset_calc); - ASSERT_EQ(offset_f , offset_calc_f); - ASSERT_EQ(zoffset_f.first , offset_calc_f); - ASSERT_EQ(zoffset_f.second , offset_calc_f); + ASSERT_EQ(offset, offset_calc); + ASSERT_EQ(zoffset.first, offset_calc); + ASSERT_EQ(zoffset.second, offset_calc); + ASSERT_EQ(offset_f, offset_calc_f); + ASSERT_EQ(zoffset_f.first, offset_calc_f); + ASSERT_EQ(zoffset_f.second, offset_calc_f); } - -TEST_F(LoopCoordsHelper, Increment_Use_Tests) { +TEST_F(LoopCoordsHelper, Increment_Use_Tests) { + constexpr size_t Rank = 4; - constexpr size_t Rank = 4; - - Nd4jLong shape[Rank] = { 3, 5 ,7, 8 }; - Nd4jLong multiply_st[] = { 2,3,3,5,6,7,9,3 }; - Nd4jLong strides_c[Rank]; - Nd4jLong strides_f[Rank]; - - Nd4jLong coords[Rank] = {}; - Nd4jLong coords_f[Rank] = {}; - Nd4jLong coords2[Rank] = {}; - Nd4jLong coords2_f[Rank] = {}; - Nd4jLong zcoords2[Rank] = {}; - Nd4jLong zcoords2_f[Rank] = {}; - - strides_f[0] = multiply_st[0] * shape[0]; - strides_c[Rank - 1] = multiply_st[Rank - 1] * shape[Rank - 1]; - - for (int i = 1; i < Rank; i++) { - strides_f[i] = strides_f[i - 1] * multiply_st[i] * shape[i]; - } - - for (int i = Rank - 2; i >= 0; i--) { - strides_c[i] = strides_c[i + 1] * multiply_st[i] * shape[i]; - } - - int total = 1; - for (int i = 0; i < Rank; i++) { - total *= shape[i]; - } - - CoordsState cts; - CoordsState cts_f; - - ZipCoordsState zcts; - ZipCoordsState zcts_f; - - size_t offset = init_coords(cts, 0, shape, strides_c); - size_t offset_f = init_coords(cts_f, 0, shape, strides_f); + Nd4jLong shape[Rank] = {3, 5, 7, 8}; + Nd4jLong multiply_st[] = {2, 3, 3, 5, 6, 7, 9, 3}; + Nd4jLong strides_c[Rank]; + Nd4jLong strides_f[Rank]; - zip_size_t zoffset = init_coords(zcts, 0, shape, strides_c, strides_c); - zip_size_t zoffset_f = init_coords(zcts_f, 0, shape, strides_f, strides_f); + Nd4jLong coords[Rank] = {}; + Nd4jLong coords_f[Rank] = {}; + Nd4jLong coords2[Rank] = {}; + Nd4jLong coords2_f[Rank] = {}; + Nd4jLong zcoords2[Rank] = {}; + Nd4jLong zcoords2_f[Rank] = {}; - size_t offset2 = 0; - size_t offset2_f = 0; - zip_size_t zoffset2 = {}; - zip_size_t zoffset2_f = {}; + strides_f[0] = multiply_st[0] * shape[0]; + strides_c[Rank - 1] = multiply_st[Rank - 1] * shape[Rank - 1]; - for (int j = 0; j < total; j++) { + for (int i = 1; i < Rank; i++) { + strides_f[i] = strides_f[i - 1] * multiply_st[i] * shape[i]; + } + for (int i = Rank - 2; i >= 0; i--) { + strides_c[i] = strides_c[i + 1] * multiply_st[i] * shape[i]; + } - index2coords_C(j, Rank, shape, coords); - index2coords_F(j, Rank, shape, coords_f); + int total = 1; + for (int i = 0; i < Rank; i++) { + total *= shape[i]; + } - size_t offset_calc = offset_from_coords(strides_c, coords, Rank); - size_t offset_calc_f = offset_from_coords(strides_f, coords_f, Rank); + CoordsState cts; + CoordsState cts_f; + ZipCoordsState zcts; + ZipCoordsState zcts_f; - ASSERT_TRUE(eq_coords(cts, coords)); - ASSERT_TRUE(eq_coords(cts_f, coords_f)); + size_t offset = init_coords(cts, 0, shape, strides_c); + size_t offset_f = init_coords(cts_f, 0, shape, strides_f); - ASSERT_TRUE(eq_zip_coords(zcts, coords)); - ASSERT_TRUE(eq_zip_coords(zcts_f, coords_f)); + zip_size_t zoffset = init_coords(zcts, 0, shape, strides_c, strides_c); + zip_size_t zoffset_f = + init_coords(zcts_f, 0, shape, strides_f, strides_f); - ASSERT_EQ(offset, offset_calc); - ASSERT_EQ(zoffset.first, offset_calc); - ASSERT_EQ(zoffset.second, offset_calc); - ASSERT_EQ(offset_f, offset_calc_f); - ASSERT_EQ(zoffset_f.first, offset_calc_f); - ASSERT_EQ(zoffset_f.second, offset_calc_f); + size_t offset2 = 0; + size_t offset2_f = 0; + zip_size_t zoffset2 = {}; + zip_size_t zoffset2_f = {}; + for (int j = 0; j < total; j++) { + index2coords_C(j, Rank, shape, coords); + index2coords_F(j, Rank, shape, coords_f); - ASSERT_EQ(offset2, offset_calc); - ASSERT_EQ(zoffset2.first, offset_calc); - ASSERT_EQ(zoffset2.second, offset_calc); - ASSERT_EQ(offset2_f, offset_calc_f); - ASSERT_EQ(zoffset2_f.first, offset_calc_f); - ASSERT_EQ(zoffset2_f.second, offset_calc_f); + size_t offset_calc = offset_from_coords(strides_c, coords, Rank); + size_t offset_calc_f = offset_from_coords(strides_f, coords_f, Rank); - offset = inc_coords(cts, offset); - offset_f = inc_coords(cts_f, offset_f); - zoffset = inc_coords(zcts, zoffset); - zoffset_f = inc_coords(zcts_f, zoffset_f); + ASSERT_TRUE(eq_coords(cts, coords)); + ASSERT_TRUE(eq_coords(cts_f, coords_f)); - offset2 = inc_coords(shape,strides_c, coords2, offset2, Rank); - offset2_f = inc_coords(shape, strides_f, coords2_f, offset2_f, Rank); - zoffset2 = inc_coords(shape, strides_c, strides_c, zcoords2, zoffset2, Rank); - zoffset2_f = inc_coords(shape, strides_f, strides_f, zcoords2_f, zoffset2_f, Rank); + ASSERT_TRUE(eq_zip_coords(zcts, coords)); + ASSERT_TRUE(eq_zip_coords(zcts_f, coords_f)); - } - + ASSERT_EQ(offset, offset_calc); + ASSERT_EQ(zoffset.first, offset_calc); + ASSERT_EQ(zoffset.second, offset_calc); + ASSERT_EQ(offset_f, offset_calc_f); + ASSERT_EQ(zoffset_f.first, offset_calc_f); + ASSERT_EQ(zoffset_f.second, offset_calc_f); + + ASSERT_EQ(offset2, offset_calc); + ASSERT_EQ(zoffset2.first, offset_calc); + ASSERT_EQ(zoffset2.second, offset_calc); + ASSERT_EQ(offset2_f, offset_calc_f); + ASSERT_EQ(zoffset2_f.first, offset_calc_f); + ASSERT_EQ(zoffset2_f.second, offset_calc_f); + + offset = inc_coords(cts, offset); + offset_f = inc_coords(cts_f, offset_f); + zoffset = inc_coords(zcts, zoffset); + zoffset_f = inc_coords(zcts_f, zoffset_f); + + offset2 = inc_coords(shape, strides_c, coords2, offset2, Rank); + offset2_f = inc_coords(shape, strides_f, coords2_f, offset2_f, Rank); + zoffset2 = + inc_coords(shape, strides_c, strides_c, zcoords2, zoffset2, Rank); + zoffset2_f = inc_coords(shape, strides_f, strides_f, zcoords2_f, + zoffset2_f, Rank); + } } - diff --git a/libnd4j/tests_cpu/layers_tests/ManagedDataBufferTests.cpp b/libnd4j/tests_cpu/layers_tests/ManagedDataBufferTests.cpp new file mode 100644 index 000000000000..e381702db92b --- /dev/null +++ b/libnd4j/tests_cpu/layers_tests/ManagedDataBufferTests.cpp @@ -0,0 +1,61 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include +#include +#include +#include + +#include + +#include "testlayers.h" + +using namespace sd; +using namespace sd::graph; + +class ManagedDataBufferTests : public testing::Test { + public: + ManagedDataBufferTests() { + /// + } +}; + +TEST_F(ManagedDataBufferTests, basic_constructor_test_1) { + GraphMemoryManager mgr; + auto mdb = std::make_shared(mgr, 0, DataType::FLOAT32, + memory::MemoryZone::HOT); + + NDArray array(mdb, 'c', {0}); +} + +TEST_F(ManagedDataBufferTests, basic_constructor_test_2) { + auto exp = NDArrayFactory::create('c', {5}, {1.f, 1.f, 1.f, 1.f, 1.f}); + + GraphMemoryManager mgr; + auto mdb = std::make_shared(mgr, 20, DataType::FLOAT32, + memory::MemoryZone::HOT); + + ASSERT_NE(nullptr, mdb->platform()); + + NDArray array(mdb, 'c', {5}); + array.assign(1.0f); + + ASSERT_EQ(exp, array); +} \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/MemoryUtilsTests.cpp b/libnd4j/tests_cpu/layers_tests/MemoryUtilsTests.cpp index 4bfe40405012..bee1e37355cf 100644 --- a/libnd4j/tests_cpu/layers_tests/MemoryUtilsTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/MemoryUtilsTests.cpp @@ -20,27 +20,24 @@ #include #include + #include "testlayers.h" using namespace sd::memory; class MemoryUtilsTests : public testing::Test { -public: - + public: }; TEST_F(MemoryUtilsTests, BasicRetrieve_1) { - MemoryReport reportA; - MemoryReport reportB; + MemoryReport reportA; + MemoryReport reportB; #ifdef _WIN32 - if (1 > 0) - return; + if (1 > 0) return; #endif + MemoryUtils::retrieveMemoryStatistics(reportA); - MemoryUtils::retrieveMemoryStatistics(reportA); - - - ASSERT_NE(reportA, reportB); + ASSERT_NE(reportA, reportB); } diff --git a/libnd4j/tests_cpu/layers_tests/MklDnnTests.cpp b/libnd4j/tests_cpu/layers_tests/MklDnnTests.cpp index c91c1c5c7811..8c95507df314 100644 --- a/libnd4j/tests_cpu/layers_tests/MklDnnTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/MklDnnTests.cpp @@ -14,96 +14,99 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - // // @author raver119@gmail.com // #ifdef HAVE_MKLDNN -#include "testlayers.h" -#include -#include -#include #include #include +#include +#include +#include + +#include "testlayers.h" using namespace sd; class MklDnnTests : public testing::Test { -public: - + public: }; -static void printer(std::initializer_list helpers) { - - for (auto v:helpers) { - nd4j_printf("Initialized [%s]\n", v->name().c_str()); - } +static void printer( + std::initializer_list helpers) { + for (auto v : helpers) { + nd4j_printf("Initialized [%s]\n", v->name().c_str()); + } } - TEST_F(MklDnnTests, helpers_includer) { - // we need this block, to make sure all helpers are still available within binary, and not optimized out by linker - sd::ops::platforms::PLATFORM_conv2d_ENGINE_CPU conv2d; - sd::ops::platforms::PLATFORM_conv2d_bp_ENGINE_CPU conv2d_bp; + // we need this block, to make sure all helpers are still available within + // binary, and not optimized out by linker + sd::ops::platforms::PLATFORM_conv2d_ENGINE_CPU conv2d; + sd::ops::platforms::PLATFORM_conv2d_bp_ENGINE_CPU conv2d_bp; - sd::ops::platforms::PLATFORM_conv2d_ENGINE_CPU conv3d; - sd::ops::platforms::PLATFORM_conv2d_bp_ENGINE_CPU conv3d_bp; + sd::ops::platforms::PLATFORM_conv2d_ENGINE_CPU conv3d; + sd::ops::platforms::PLATFORM_conv2d_bp_ENGINE_CPU conv3d_bp; - sd::ops::platforms::PLATFORM_avgpool2d_ENGINE_CPU avgpool2d; - sd::ops::platforms::PLATFORM_avgpool2d_bp_ENGINE_CPU avgpool2d_bp; + sd::ops::platforms::PLATFORM_avgpool2d_ENGINE_CPU avgpool2d; + sd::ops::platforms::PLATFORM_avgpool2d_bp_ENGINE_CPU avgpool2d_bp; - sd::ops::platforms::PLATFORM_maxpool2d_ENGINE_CPU maxpool2d; - sd::ops::platforms::PLATFORM_maxpool2d_bp_ENGINE_CPU maxpool2d_bp; + sd::ops::platforms::PLATFORM_maxpool2d_ENGINE_CPU maxpool2d; + sd::ops::platforms::PLATFORM_maxpool2d_bp_ENGINE_CPU maxpool2d_bp; - sd::ops::platforms::PLATFORM_avgpool3dnew_ENGINE_CPU avgpool3d; - sd::ops::platforms::PLATFORM_avgpool3dnew_bp_ENGINE_CPU avgpool3d_bp; + sd::ops::platforms::PLATFORM_avgpool3dnew_ENGINE_CPU avgpool3d; + sd::ops::platforms::PLATFORM_avgpool3dnew_bp_ENGINE_CPU avgpool3d_bp; - sd::ops::platforms::PLATFORM_maxpool3dnew_ENGINE_CPU maxpool3d; - sd::ops::platforms::PLATFORM_maxpool3dnew_bp_ENGINE_CPU maxpool3d_bp; + sd::ops::platforms::PLATFORM_maxpool3dnew_ENGINE_CPU maxpool3d; + sd::ops::platforms::PLATFORM_maxpool3dnew_bp_ENGINE_CPU maxpool3d_bp; - sd::ops::platforms::PLATFORM_lrn_ENGINE_CPU lrn; + sd::ops::platforms::PLATFORM_lrn_ENGINE_CPU lrn; - sd::ops::platforms::PLATFORM_batchnorm_ENGINE_CPU batchnorm; + sd::ops::platforms::PLATFORM_batchnorm_ENGINE_CPU batchnorm; - sd::ops::platforms::PLATFORM_matmul_ENGINE_CPU matmul; + sd::ops::platforms::PLATFORM_matmul_ENGINE_CPU matmul; - sd::ops::platforms::PLATFORM_softmax_ENGINE_CPU softmax; + sd::ops::platforms::PLATFORM_softmax_ENGINE_CPU softmax; - sd::ops::platforms::PLATFORM_softmax_bp_ENGINE_CPU softmax_bp; + sd::ops::platforms::PLATFORM_softmax_bp_ENGINE_CPU softmax_bp; - sd::ops::platforms::PLATFORM_tanh_ENGINE_CPU tanh; + sd::ops::platforms::PLATFORM_tanh_ENGINE_CPU tanh; - sd::ops::platforms::PLATFORM_tanh_ENGINE_CPU tanh_bp; - - sd::ops::platforms::PLATFORM_xw_plus_b_ENGINE_CPU xw_plus_b; + sd::ops::platforms::PLATFORM_tanh_ENGINE_CPU tanh_bp; - sd::ops::platforms::PLATFORM_xw_plus_b_bp_ENGINE_CPU xw_plus_b_bp; + sd::ops::platforms::PLATFORM_xw_plus_b_ENGINE_CPU xw_plus_b; + sd::ops::platforms::PLATFORM_xw_plus_b_bp_ENGINE_CPU xw_plus_b_bp; - printer({&conv2d, &conv2d_bp, &conv3d, &conv3d_bp, &avgpool2d, &avgpool2d_bp, &maxpool2d, &maxpool2d_bp, &avgpool3d, &avgpool3d_bp, &maxpool3d, &maxpool3d_bp, &lrn, &batchnorm, &matmul, &softmax, &softmax_bp, &tanh, &tanh_bp, &xw_plus_b, &xw_plus_b_bp }); + printer({&conv2d, &conv2d_bp, &conv3d, &conv3d_bp, + &avgpool2d, &avgpool2d_bp, &maxpool2d, &maxpool2d_bp, + &avgpool3d, &avgpool3d_bp, &maxpool3d, &maxpool3d_bp, + &lrn, &batchnorm, &matmul, &softmax, + &softmax_bp, &tanh, &tanh_bp, &xw_plus_b, + &xw_plus_b_bp}); } TEST_F(MklDnnTests, test_tanh_1) { - auto x = NDArrayFactory::create(1.0f); - auto z = NDArrayFactory::create(0.0f); + auto x = NDArrayFactory::create(1.0f); + auto z = NDArrayFactory::create(0.0f); - sd::ops::tanh op; - auto status = op.execute({&x}, {&z}); + sd::ops::tanh op; + auto status = op.execute({&x}, {&z}); - ASSERT_EQ(Status::OK(), status); + ASSERT_EQ(Status::OK(), status); } TEST_F(MklDnnTests, test_tanh_2) { - auto x = NDArrayFactory::create('c', {1}, {1.0f}); - auto z = NDArrayFactory::create('c', {1}, {0.0f}); + auto x = NDArrayFactory::create('c', {1}, {1.0f}); + auto z = NDArrayFactory::create('c', {1}, {0.0f}); - sd::ops::tanh op; - auto status = op.execute({&x}, {&z}); + sd::ops::tanh op; + auto status = op.execute({&x}, {&z}); - ASSERT_EQ(Status::OK(), status); + ASSERT_EQ(Status::OK(), status); } #endif \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/MmapTests.cpp b/libnd4j/tests_cpu/layers_tests/MmapTests.cpp index 7200dc034cb6..e0419e5b7dc7 100644 --- a/libnd4j/tests_cpu/layers_tests/MmapTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/MmapTests.cpp @@ -18,38 +18,38 @@ // Created by raver on 5/13/2018. // -#include "testlayers.h" -#include #include #include +#include + #include +#include "testlayers.h" + using namespace sd; using namespace sd::graph; class MmapTests : public testing::Test { -public: - + public: }; TEST_F(MmapTests, Test_Basic_Mmap_1) { - // FIXME: we must adopt this for CUDA as well - if (!Environment::getInstance().isCPU()) - return; + // FIXME: we must adopt this for CUDA as well + if (!Environment::getInstance().isCPU()) return; - // just 10GB - Nd4jLong size = 100000L; + // just 10GB + Nd4jLong size = 100000L; - std::ofstream ofs("file", std::ios::binary | std::ios::out); - ofs.seekp(size + 1024L); - ofs.write("", 1); - ofs.close(); + std::ofstream ofs("file", std::ios::binary | std::ios::out); + ofs.seekp(size + 1024L); + ofs.write("", 1); + ofs.close(); - auto result = mmapFile(nullptr, "file", size); + auto result = mmapFile(nullptr, "file", size); - ASSERT_FALSE(result == nullptr); + ASSERT_FALSE(result == nullptr); - munmapFile(nullptr, result, size); + munmapFile(nullptr, result, size); - remove("file"); + remove("file"); } diff --git a/libnd4j/tests_cpu/layers_tests/MultiDataTypeTests.cpp b/libnd4j/tests_cpu/layers_tests/MultiDataTypeTests.cpp index 79f2ffa1e66d..05571283ff8b 100644 --- a/libnd4j/tests_cpu/layers_tests/MultiDataTypeTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/MultiDataTypeTests.cpp @@ -18,1752 +18,1760 @@ // @author raver119@gmail.com // -#include "testlayers.h" #include #include #include -#include #include +#include +#include "testlayers.h" using namespace sd; class MultiDataTypeTests : public testing::Test { -public: - + public: }; //////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, DataTypeUtils_Test_1) { - auto dtype = DataTypeUtils::pickPairwiseResultType(sd::INT32, sd::FLOAT32); + auto dtype = DataTypeUtils::pickPairwiseResultType(sd::INT32, sd::FLOAT32); - ASSERT_EQ(sd::FLOAT32, dtype); + ASSERT_EQ(sd::FLOAT32, dtype); } //////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, DataTypeUtils_Test_2) { - auto dtype = DataTypeUtils::pickPairwiseResultType(sd::INT32, sd::DOUBLE); - ASSERT_EQ(sd::DOUBLE, dtype); + auto dtype = DataTypeUtils::pickPairwiseResultType(sd::INT32, sd::DOUBLE); + ASSERT_EQ(sd::DOUBLE, dtype); - ASSERT_EQ(sd::DOUBLE, DataTypeUtils::pickPairwiseResultType(sd::DOUBLE, sd::INT32)); + ASSERT_EQ(sd::DOUBLE, + DataTypeUtils::pickPairwiseResultType(sd::DOUBLE, sd::INT32)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, DataTypeUtils_Test_3) { - auto dtype = DataTypeUtils::pickPairwiseResultType(sd::FLOAT32, sd::DOUBLE); - ASSERT_EQ(sd::FLOAT32, dtype); + auto dtype = DataTypeUtils::pickPairwiseResultType(sd::FLOAT32, sd::DOUBLE); + ASSERT_EQ(sd::FLOAT32, dtype); } //////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, Basic_Test_1) { - if (!Environment::getInstance().isExperimentalBuild()) - return; + if (!Environment::getInstance().isExperimentalBuild()) return; - auto x = NDArrayFactory::create('c', {2, 3}, {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f}); - auto y = NDArrayFactory::create('c', {2, 3}, {0.0, 1.0, 2.0, 3.0, 4.0, 5.0}); - auto e = NDArrayFactory::create('c', {2, 3}, {0.0f, 2.0f, 4.0f, 6.0f, 8.0f, 10.0f}); + auto x = NDArrayFactory::create('c', {2, 3}, + {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f}); + auto y = NDArrayFactory::create('c', {2, 3}, + {0.0, 1.0, 2.0, 3.0, 4.0, 5.0}); + auto e = NDArrayFactory::create('c', {2, 3}, + {0.0f, 2.0f, 4.0f, 6.0f, 8.0f, 10.0f}); - auto z = x + y; + auto z = x + y; - ASSERT_EQ(e, z); + ASSERT_EQ(e, z); } //////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, Basic_Test_2) { - if (!Environment::getInstance().isExperimentalBuild()) - return; + if (!Environment::getInstance().isExperimentalBuild()) return; - auto x = NDArrayFactory::create('c', {2, 3}, {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f}); - auto y = NDArrayFactory::create(2.0); - auto e = NDArrayFactory::create('c', {2, 3}, {0.0f, 2.0f, 4.0f, 6.0f, 8.0f, 10.0f}); + auto x = NDArrayFactory::create('c', {2, 3}, + {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f}); + auto y = NDArrayFactory::create(2.0); + auto e = NDArrayFactory::create('c', {2, 3}, + {0.0f, 2.0f, 4.0f, 6.0f, 8.0f, 10.0f}); - auto z = x * y; + auto z = x * y; - ASSERT_EQ(e, z); + ASSERT_EQ(e, z); } //////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, Basic_Test_3) { - if (!Environment::getInstance().isExperimentalBuild()) - return; + if (!Environment::getInstance().isExperimentalBuild()) return; - auto x = NDArrayFactory::create('c', {2, 3}, {0, 1, 2, 3, 4, 5}); - auto y = NDArrayFactory::create(2.0); - auto e = NDArrayFactory::create('c', {2, 3}, {0.0f, 2.0f, 4.0f, 6.0f, 8.0f, 10.0f}); + auto x = NDArrayFactory::create('c', {2, 3}, {0, 1, 2, 3, 4, 5}); + auto y = NDArrayFactory::create(2.0); + auto e = NDArrayFactory::create( + 'c', {2, 3}, {0.0f, 2.0f, 4.0f, 6.0f, 8.0f, 10.0f}); - auto z = x * y; + auto z = x * y; - ASSERT_EQ(e, z); + ASSERT_EQ(e, z); } //////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, Basic_Test_4) { - if (!Environment::getInstance().isExperimentalBuild()) - return; + if (!Environment::getInstance().isExperimentalBuild()) return; - auto x = NDArrayFactory::create('c', {2, 3}, {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f}); - auto y = NDArrayFactory::create(2.0); - auto e = NDArrayFactory::create('c', {2, 3}, {0.0f, 2.0f, 4.0f, 6.0f, 8.0f, 10.0f}); + auto x = NDArrayFactory::create('c', {2, 3}, + {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f}); + auto y = NDArrayFactory::create(2.0); + auto e = NDArrayFactory::create( + 'c', {2, 3}, {0.0f, 2.0f, 4.0f, 6.0f, 8.0f, 10.0f}); - auto z = x * y; + auto z = x * y; - ASSERT_EQ(e, z); + ASSERT_EQ(e, z); } //////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, Basic_Test_5) { - if (!Environment::getInstance().isExperimentalBuild()) - return; + if (!Environment::getInstance().isExperimentalBuild()) return; - auto x = NDArrayFactory::create('c', {2, 3}, {0, 1, 2, 3, 4, 5}); - auto y = NDArrayFactory::create(2); - auto e = NDArrayFactory::create('c', {2, 3}, {0, 2, 4, 6, 8, 10}); + auto x = NDArrayFactory::create('c', {2, 3}, {0, 1, 2, 3, 4, 5}); + auto y = NDArrayFactory::create(2); + auto e = NDArrayFactory::create('c', {2, 3}, {0, 2, 4, 6, 8, 10}); - auto z = x * y; + auto z = x * y; - ASSERT_EQ(e, z); + ASSERT_EQ(e, z); } TEST_F(MultiDataTypeTests, Basic_Test_7) { - if (!Environment::getInstance().isExperimentalBuild()) - return; + if (!Environment::getInstance().isExperimentalBuild()) return; - auto x = NDArrayFactory::create('c', {2, 3}, {0, 1, 2, 3, 4, 5}); - auto y = NDArrayFactory::create('c', {2, 3}, {0.f, 1.f, 2.f, 3.f, 4.f, 5.f}); - auto e = NDArrayFactory::create('c', {2, 3}, {0.f, 2.f, 4.f, 6.f, 8.f, 10.f}); + auto x = NDArrayFactory::create('c', {2, 3}, {0, 1, 2, 3, 4, 5}); + auto y = NDArrayFactory::create('c', {2, 3}, + {0.f, 1.f, 2.f, 3.f, 4.f, 5.f}); + auto e = NDArrayFactory::create('c', {2, 3}, + {0.f, 2.f, 4.f, 6.f, 8.f, 10.f}); - sd::ops::add op; - auto result = op.evaluate({&x, &y}); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::add op; + auto result = op.evaluate({&x, &y}); + ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_EQ(e, *z); + ASSERT_EQ(e, z); } //////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, Basic_Test_6) { - if (!Environment::getInstance().isExperimentalBuild()) - return; + if (!Environment::getInstance().isExperimentalBuild()) return; - auto x = NDArrayFactory::create('c', {2, 3}, {0, 1, 2, 3, 4, 5}); - auto y = NDArrayFactory::create(2); - auto e = NDArrayFactory::create('c', {2, 3}, {0, 2, 4, 6, 8, 10}); + auto x = NDArrayFactory::create('c', {2, 3}, {0, 1, 2, 3, 4, 5}); + auto y = NDArrayFactory::create(2); + auto e = NDArrayFactory::create('c', {2, 3}, {0, 2, 4, 6, 8, 10}); - auto z = x * y; + auto z = x * y; - ASSERT_EQ(e, z); + ASSERT_EQ(e, z); } //////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_assign_number_test1) { - NDArray x('c', {2, 3}, {0, 1, 2, 3, 4, 5}, sd::DataType::UINT8); - NDArray exp('c', {2, 3}, {10, 10, 10, 10, 10, 10}, sd::DataType::UINT8); + NDArray x('c', {2, 3}, {0, 1, 2, 3, 4, 5}, sd::DataType::UINT8); + NDArray exp('c', {2, 3}, {10, 10, 10, 10, 10, 10}, sd::DataType::UINT8); - const double number = 10.8; - x = number; + const double number = 10.8; + x = number; - ASSERT_EQ(x,exp); + ASSERT_EQ(x, exp); } //////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_assign_number_test2) { - NDArray x('c', {2, 3}, {0, 1, 2, 3, 4, 5}, sd::DataType::INT64); - NDArray exp('c', {2, 3}, {1, 1, 1, 1, 1, 1}, sd::DataType::INT64); + NDArray x('c', {2, 3}, {0, 1, 2, 3, 4, 5}, sd::DataType::INT64); + NDArray exp('c', {2, 3}, {1, 1, 1, 1, 1, 1}, sd::DataType::INT64); - const bool number = 1000; - x = number; + const bool number = 1000; + x = number; - ASSERT_EQ(x,exp); + ASSERT_EQ(x, exp); } //////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_assign_number_test3) { - NDArray x('c', {2, 3}, {0, 1, 0, 1, 0, 1}, sd::DataType::BOOL); - NDArray exp('c', {2, 3}, {1, 1, 1, 1, 1, 1}, sd::DataType::BOOL); + NDArray x('c', {2, 3}, {0, 1, 0, 1, 0, 1}, sd::DataType::BOOL); + NDArray exp('c', {2, 3}, {1, 1, 1, 1, 1, 1}, sd::DataType::BOOL); - const int number = 1000; - x = number; + const int number = 1000; + x = number; - ASSERT_EQ(x,exp); + ASSERT_EQ(x, exp); } //////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_repeat_test1) { - NDArray x('c', {2, 2}, {0.5, 1.5, 2.5, 3.5}, sd::DataType::HALF); - NDArray y('c', {2, 4}, sd::DataType::HALF); - NDArray exp('c', {2, 4}, {0.5, 0.5, 1.5, 1.5, 2.5, 2.5, 3.5, 3.5}, sd::DataType::HALF); + NDArray x('c', {2, 2}, {0.5, 1.5, 2.5, 3.5}, sd::DataType::HALF); + NDArray y('c', {2, 4}, sd::DataType::HALF); + NDArray exp('c', {2, 4}, {0.5, 0.5, 1.5, 1.5, 2.5, 2.5, 3.5, 3.5}, + sd::DataType::HALF); - x.repeat(1, {2}, y); + x.repeat(1, {2}, y); - ASSERT_EQ(y, exp); + ASSERT_EQ(y, exp); } //////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_bufferAsT_test1) { - NDArray x('f', {2}, {1.5, 3.5}, sd::DataType::FLOAT32); - NDArray y('c', {}, std::vector{1.5}, sd::DataType::FLOAT32); + NDArray x('f', {2}, {1.5, 3.5}, sd::DataType::FLOAT32); + NDArray y('c', {}, std::vector{1.5}, sd::DataType::FLOAT32); - const int* buffX = x.bufferAsT(); - const int* buffY = y.bufferAsT(); + const int* buffX = x.bufferAsT(); + const int* buffY = y.bufferAsT(); - ASSERT_EQ(*buffX, *buffY); + ASSERT_EQ(*buffX, *buffY); } //////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_assign_test1) { - NDArray x('c', {2,2}, {0, 1, 2, 3}, sd::DataType::UINT8); - NDArray exp('c', {2,2}, {10, 10, 20, 20}, sd::DataType::UINT8); + NDArray x('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::UINT8); + NDArray exp('c', {2, 2}, {10, 10, 20, 20}, sd::DataType::UINT8); - NDArray scalar1('c', {}, std::vector{10.5}, sd::DataType::FLOAT32); - NDArray scalar2('c', {}, std::vector{20.8}, sd::DataType::DOUBLE); + NDArray scalar1('c', {}, std::vector{10.5}, sd::DataType::FLOAT32); + NDArray scalar2('c', {}, std::vector{20.8}, sd::DataType::DOUBLE); - x(0,{0}).assign(scalar1); - x(1,{0}).assign(scalar2); + x(0, {0}).assign(scalar1); + x(1, {0}).assign(scalar2); - ASSERT_EQ(x, exp); + ASSERT_EQ(x, exp); - x.assign(exp); + x.assign(exp); - ASSERT_EQ(x, exp); + ASSERT_EQ(x, exp); } //////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_reduceAlongDimension_test1) { - NDArray x('f', {2,2}, {0, 1.5, 2.5, 3.5}, sd::DataType::HALF); - NDArray exp1('c', {}, std::vector{3}, sd::DataType::INT64); - NDArray exp2('c', {1,1}, std::vector{1}, sd::DataType::INT64); - NDArray exp3('c', {2}, std::vector{1,2}, sd::DataType::INT64); + NDArray x('f', {2, 2}, {0, 1.5, 2.5, 3.5}, sd::DataType::HALF); + NDArray exp1('c', {}, std::vector{3}, sd::DataType::INT64); + NDArray exp2('c', {1, 1}, std::vector{1}, sd::DataType::INT64); + NDArray exp3('c', {2}, std::vector{1, 2}, sd::DataType::INT64); - auto scalar1 = x.reduceAlongDimension(sd::reduce::CountNonZero, {}/*whole range*/); - ASSERT_EQ(scalar1, exp1); + auto scalar1 = + x.reduceAlongDimension(sd::reduce::CountNonZero, {} /*whole range*/); + ASSERT_EQ(scalar1, exp1); - auto scalar2 = x.reduceAlongDimension(sd::reduce::CountZero, {}/*whole range*/, true); - ASSERT_EQ(scalar2, exp2); + auto scalar2 = + x.reduceAlongDimension(sd::reduce::CountZero, {} /*whole range*/, true); + ASSERT_EQ(scalar2, exp2); - auto scalar3 = x.reduceAlongDimension(sd::reduce::CountNonZero, {1}); - ASSERT_EQ(scalar3, exp3); + auto scalar3 = x.reduceAlongDimension(sd::reduce::CountNonZero, {1}); + ASSERT_EQ(scalar3, exp3); } //////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_reduceAlongDimension_test2) { - NDArray x('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::INT32); - NDArray exp1('c', {}, std::vector{1.5}, sd::DataType::FLOAT32); - NDArray exp2('c', {2}, {0.5,2.5}, sd::DataType::FLOAT32); + NDArray x('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::INT32); + NDArray exp1('c', {}, std::vector{1.5}, sd::DataType::FLOAT32); + NDArray exp2('c', {2}, {0.5, 2.5}, sd::DataType::FLOAT32); - auto scalar1 = x.reduceAlongDimension(sd::reduce::Mean, {}/*whole range*/); - // scalar1->printShapeInfo(); - // scalar1->printIndexedBuffer(); - ASSERT_EQ(scalar1, exp1); + auto scalar1 = x.reduceAlongDimension(sd::reduce::Mean, {} /*whole range*/); + // scalar1->printShapeInfo(); + // scalar1->printIndexedBuffer(); + ASSERT_EQ(scalar1, exp1); - auto scalar2 = x.reduceAlongDimension(sd::reduce::Mean, {1}); - ASSERT_EQ(scalar2, exp2); + auto scalar2 = x.reduceAlongDimension(sd::reduce::Mean, {1}); + ASSERT_EQ(scalar2, exp2); } //////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_reduceAlongDimension_test3) { - NDArray x('c', {2, 2}, {0.5, 1.5, 2.5, 3.5}, sd::DataType::HALF); - NDArray exp1('c', {}, std::vector{8.}, sd::DataType::HALF); - NDArray exp2('c', {2}, {2.,6.}, sd::DataType::HALF); + NDArray x('c', {2, 2}, {0.5, 1.5, 2.5, 3.5}, sd::DataType::HALF); + NDArray exp1('c', {}, std::vector{8.}, sd::DataType::HALF); + NDArray exp2('c', {2}, {2., 6.}, sd::DataType::HALF); - auto scalar1 = x.reduceAlongDimension(sd::reduce::Sum, {}/*whole range*/); - ASSERT_EQ(scalar1, exp1); + auto scalar1 = x.reduceAlongDimension(sd::reduce::Sum, {} /*whole range*/); + ASSERT_EQ(scalar1, exp1); - auto scalar2 = x.reduceAlongDimension(sd::reduce::Sum, {1}); - ASSERT_EQ(scalar2, exp2); + auto scalar2 = x.reduceAlongDimension(sd::reduce::Sum, {1}); + ASSERT_EQ(scalar2, exp2); } //////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_reduceAlongDimension_test4) { - NDArray x('c', {2, 2}, {10.5, 1.5, -2.5, -3.5}, sd::DataType::HALF); - NDArray exp1('c', {}, std::vector{1}, sd::DataType::BOOL); - NDArray exp2('c', {2}, std::vector{1, 0}, sd::DataType::BOOL); + NDArray x('c', {2, 2}, {10.5, 1.5, -2.5, -3.5}, sd::DataType::HALF); + NDArray exp1('c', {}, std::vector{1}, sd::DataType::BOOL); + NDArray exp2('c', {2}, std::vector{1, 0}, sd::DataType::BOOL); - auto scalar1 = x.reduceAlongDimension(sd::reduce::IsPositive, {}/*whole range*/); - ASSERT_EQ(scalar1, exp1); + auto scalar1 = + x.reduceAlongDimension(sd::reduce::IsPositive, {} /*whole range*/); + ASSERT_EQ(scalar1, exp1); - auto scalar2 = x.reduceAlongDimension(sd::reduce::IsPositive, {1}); - ASSERT_EQ(scalar2, exp2); + auto scalar2 = x.reduceAlongDimension(sd::reduce::IsPositive, {1}); + ASSERT_EQ(scalar2, exp2); } //////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_varianceNumber_test1) { - NDArray x('f', {2, 2}, {0, 1, 2, 3}, sd::DataType::INT64); - NDArray exp1('c', {}, std::vector{1.666666667}, sd::DataType::FLOAT32); - NDArray exp2('c', {}, std::vector{1.118033989}, sd::DataType::FLOAT32); + NDArray x('f', {2, 2}, {0, 1, 2, 3}, sd::DataType::INT64); + NDArray exp1('c', {}, std::vector{1.666666667}, + sd::DataType::FLOAT32); + NDArray exp2('c', {}, std::vector{1.118033989}, + sd::DataType::FLOAT32); - auto scalar1 = x.varianceNumber(variance::SummaryStatsVariance); - ASSERT_EQ(scalar1, exp1); + auto scalar1 = x.varianceNumber(variance::SummaryStatsVariance); + ASSERT_EQ(scalar1, exp1); - auto scalar2 = x.varianceNumber(variance::SummaryStatsStandardDeviation, false); - ASSERT_EQ(scalar2, exp2); + auto scalar2 = + x.varianceNumber(variance::SummaryStatsStandardDeviation, false); + ASSERT_EQ(scalar2, exp2); } //////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_operatorPlus_test1) { - if (!Environment::getInstance().isExperimentalBuild()) - return; + if (!Environment::getInstance().isExperimentalBuild()) return; - NDArray x1('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::INT64); - NDArray x2('c', {2, 2}, {-1, -2, -1, -2}, sd::DataType::FLOAT32); - NDArray x3('c', {2}, {-1, -2}, sd::DataType::FLOAT32); + NDArray x1('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::INT64); + NDArray x2('c', {2, 2}, {-1, -2, -1, -2}, sd::DataType::FLOAT32); + NDArray x3('c', {2}, {-1, -2}, sd::DataType::FLOAT32); - NDArray exp('c', {2, 2}, {-1, -1, 1, 1}, sd::DataType::FLOAT32); + NDArray exp('c', {2, 2}, {-1, -1, 1, 1}, sd::DataType::FLOAT32); - ASSERT_EQ(x1+x2, exp); - ASSERT_EQ(x1+x3, exp); + ASSERT_EQ(x1 + x2, exp); + ASSERT_EQ(x1 + x3, exp); } //////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_operatorPlus_test2) { - if (!Environment::getInstance().isExperimentalBuild()) - return; + if (!Environment::getInstance().isExperimentalBuild()) return; - NDArray x1('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::INT64); - NDArray x2('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::FLOAT32); - NDArray x3('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::HALF); - const double val1 = -2; - const int val2 = -2; - NDArray exp1('c', {2,2}, {-2, -1, 0, 1}, sd::DataType::DOUBLE); - NDArray exp2('c', {2,2}, {-2, -1, 0, 1}, sd::DataType::FLOAT32); - NDArray exp3('c', {2,2}, {-2, -1, 0, 1}, sd::DataType::HALF); + NDArray x1('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::INT64); + NDArray x2('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::FLOAT32); + NDArray x3('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::HALF); + const double val1 = -2; + const int val2 = -2; + NDArray exp1('c', {2, 2}, {-2, -1, 0, 1}, sd::DataType::DOUBLE); + NDArray exp2('c', {2, 2}, {-2, -1, 0, 1}, sd::DataType::FLOAT32); + NDArray exp3('c', {2, 2}, {-2, -1, 0, 1}, sd::DataType::HALF); - ASSERT_EQ(x1+val1, exp1); - ASSERT_EQ(val1+x1, exp1); + ASSERT_EQ(x1 + val1, exp1); + ASSERT_EQ(val1 + x1, exp1); - ASSERT_EQ(x2+val2, exp2); - ASSERT_EQ(val2+x2, exp2); + ASSERT_EQ(x2 + val2, exp2); + ASSERT_EQ(val2 + x2, exp2); - ASSERT_EQ(x3+val1, exp3); - ASSERT_EQ(val1+x3, exp3); + ASSERT_EQ(x3 + val1, exp3); + ASSERT_EQ(val1 + x3, exp3); } //////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_operatorMinus_test1) { - if (!Environment::getInstance().isExperimentalBuild()) - return; + if (!Environment::getInstance().isExperimentalBuild()) return; - NDArray x1('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::INT64); - NDArray x2('c', {2, 2}, {-1, -2, -1, -2}, sd::DataType::HALF); - NDArray x3('c', {2}, {-1, -2}, sd::DataType::HALF); + NDArray x1('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::INT64); + NDArray x2('c', {2, 2}, {-1, -2, -1, -2}, sd::DataType::HALF); + NDArray x3('c', {2}, {-1, -2}, sd::DataType::HALF); - NDArray exp('c', {2, 2}, {1, 3, 3, 5}, sd::DataType::HALF); + NDArray exp('c', {2, 2}, {1, 3, 3, 5}, sd::DataType::HALF); - ASSERT_EQ(x1-x2, exp); - ASSERT_EQ(x1-x3, exp); + ASSERT_EQ(x1 - x2, exp); + ASSERT_EQ(x1 - x3, exp); } //////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_operatorMinus_test2) { - if (!Environment::getInstance().isExperimentalBuild()) - return; + if (!Environment::getInstance().isExperimentalBuild()) return; - NDArray x1('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::INT64); - NDArray x2('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::FLOAT32); - NDArray x3('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::HALF); - const double val1 = 2; - const int val2 = 2; - NDArray exp1('c', {2,2}, {-2, -1, 0, 1}, sd::DataType::DOUBLE); - NDArray exp2('c', {2,2}, {2, 1, 0, -1}, sd::DataType::DOUBLE); - NDArray exp3('c', {2,2}, {-2, -1, 0, 1}, sd::DataType::FLOAT32); - NDArray exp4('c', {2,2}, {-2, -1, 0, 1}, sd::DataType::HALF); - NDArray exp5('c', {2,2}, {2, 1, 0, -1}, sd::DataType::FLOAT32); - NDArray exp6('c', {2,2}, {2, 1, 0, -1}, sd::DataType::HALF); + NDArray x1('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::INT64); + NDArray x2('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::FLOAT32); + NDArray x3('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::HALF); + const double val1 = 2; + const int val2 = 2; + NDArray exp1('c', {2, 2}, {-2, -1, 0, 1}, sd::DataType::DOUBLE); + NDArray exp2('c', {2, 2}, {2, 1, 0, -1}, sd::DataType::DOUBLE); + NDArray exp3('c', {2, 2}, {-2, -1, 0, 1}, sd::DataType::FLOAT32); + NDArray exp4('c', {2, 2}, {-2, -1, 0, 1}, sd::DataType::HALF); + NDArray exp5('c', {2, 2}, {2, 1, 0, -1}, sd::DataType::FLOAT32); + NDArray exp6('c', {2, 2}, {2, 1, 0, -1}, sd::DataType::HALF); - ASSERT_EQ(x1-val1, exp1); - ASSERT_EQ(val1-x1, exp2); + ASSERT_EQ(x1 - val1, exp1); + ASSERT_EQ(val1 - x1, exp2); - ASSERT_EQ(x2-val2, exp3); - ASSERT_EQ(val2-x2, exp5); + ASSERT_EQ(x2 - val2, exp3); + ASSERT_EQ(val2 - x2, exp5); - ASSERT_EQ(x3-val1, exp4); - ASSERT_EQ(val1-x3, exp6); + ASSERT_EQ(x3 - val1, exp4); + ASSERT_EQ(val1 - x3, exp6); } -//////////////////////////////////////////////////////////////////////////////// multiply +//////////////////////////////////////////////////////////////////////////////// +/// multiply TEST_F(MultiDataTypeTests, ndarray_operatorMultiply_test1) { - if (!Environment::getInstance().isExperimentalBuild()) - return; + if (!Environment::getInstance().isExperimentalBuild()) return; - NDArray x1('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::INT64); - NDArray x2('c', {2, 2}, {-1, -2, -1, -2}, sd::DataType::DOUBLE); - NDArray x3('c', {2}, {-1, -2}, sd::DataType::DOUBLE); + NDArray x1('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::INT64); + NDArray x2('c', {2, 2}, {-1, -2, -1, -2}, sd::DataType::DOUBLE); + NDArray x3('c', {2}, {-1, -2}, sd::DataType::DOUBLE); - NDArray exp('c', {2, 2}, {0, -2, -2, -6}, sd::DataType::DOUBLE); + NDArray exp('c', {2, 2}, {0, -2, -2, -6}, sd::DataType::DOUBLE); - ASSERT_EQ(x1*x2, exp); - ASSERT_EQ(x1*x3, exp); + ASSERT_EQ(x1 * x2, exp); + ASSERT_EQ(x1 * x3, exp); } //////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_operatorMultiply_test2) { - if (!Environment::getInstance().isExperimentalBuild()) - return; + if (!Environment::getInstance().isExperimentalBuild()) return; - NDArray x1('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::INT64); - NDArray x2('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::FLOAT32); - NDArray x3('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::HALF); - const double val1 = -2; - const int val2 = -2; - NDArray exp1('c', {2,2}, {0, -2, -4, -6}, sd::DataType::DOUBLE); - NDArray exp2('c', {2,2}, {0, -2, -4, -6}, sd::DataType::FLOAT32); - NDArray exp3('c', {2,2}, {0, -2, -4, -6}, sd::DataType::HALF); + NDArray x1('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::INT64); + NDArray x2('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::FLOAT32); + NDArray x3('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::HALF); + const double val1 = -2; + const int val2 = -2; + NDArray exp1('c', {2, 2}, {0, -2, -4, -6}, sd::DataType::DOUBLE); + NDArray exp2('c', {2, 2}, {0, -2, -4, -6}, sd::DataType::FLOAT32); + NDArray exp3('c', {2, 2}, {0, -2, -4, -6}, sd::DataType::HALF); - ASSERT_EQ(x1*val1, exp1); - ASSERT_EQ(val1*x1, exp1); + ASSERT_EQ(x1 * val1, exp1); + ASSERT_EQ(val1 * x1, exp1); - ASSERT_EQ(x2*val2, exp2); - ASSERT_EQ(val2*x2, exp2); + ASSERT_EQ(x2 * val2, exp2); + ASSERT_EQ(val2 * x2, exp2); - ASSERT_EQ(x3*val1, exp3); - ASSERT_EQ(val1*x3, exp3); + ASSERT_EQ(x3 * val1, exp3); + ASSERT_EQ(val1 * x3, exp3); } - -//////////////////////////////////////////////////////////////////////////////// multiply +//////////////////////////////////////////////////////////////////////////////// +/// multiply TEST_F(MultiDataTypeTests, ndarray_operatorDivide_test1) { - if (!Environment::getInstance().isExperimentalBuild()) - return; + if (!Environment::getInstance().isExperimentalBuild()) return; - NDArray x1('c', {2, 2}, {4, 1, 2, 3}, sd::DataType::HALF); - NDArray x2('c', {2, 2}, {-1, -2, -1, -9}, sd::DataType::DOUBLE); - NDArray x3('c', {2}, {-1, -2}, sd::DataType::FLOAT32); + NDArray x1('c', {2, 2}, {4, 1, 2, 3}, sd::DataType::HALF); + NDArray x2('c', {2, 2}, {-1, -2, -1, -9}, sd::DataType::DOUBLE); + NDArray x3('c', {2}, {-1, -2}, sd::DataType::FLOAT32); - NDArray exp1('c', {2, 2}, {-4, -0.5, -2, -0.3333333}, sd::DataType::HALF); - NDArray exp2('c', {2, 2}, {-0.25, -2, -0.5, -0.666667}, sd::DataType::HALF); + NDArray exp1('c', {2, 2}, {-4, -0.5, -2, -0.3333333}, sd::DataType::HALF); + NDArray exp2('c', {2, 2}, {-0.25, -2, -0.5, -0.666667}, sd::DataType::HALF); - ASSERT_EQ(x1/x2, exp1); - ASSERT_EQ(x3/x1, exp2); + ASSERT_EQ(x1 / x2, exp1); + ASSERT_EQ(x3 / x1, exp2); } //////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_operatorDivide_test2) { - if (!Environment::getInstance().isExperimentalBuild()) - return; + if (!Environment::getInstance().isExperimentalBuild()) return; - NDArray x1('c', {2, 2}, {1, 2, 3, 4}, sd::DataType::INT64); - NDArray x2('c', {2, 2}, {1, 2, 3, 4}, sd::DataType::FLOAT32); - NDArray x3('c', {2, 2}, {1, 2, 3, 4}, sd::DataType::HALF); - const double val1 = 2; - const int val2 = -2; - NDArray exp1('c', {2,2}, {0.5, 1, 1.5, 2}, sd::DataType::DOUBLE); - NDArray exp2('c', {2,2}, {2, 1, 0.666667, 0.5}, sd::DataType::DOUBLE); - NDArray exp3('c', {2,2}, {0, -1, -1, -2}, sd::DataType::INT64); - NDArray exp4('c', {2,2}, {-2, -1, 0., 0.}, sd::DataType::INT64); - NDArray exp5('c', {2,2}, {-0.5, -1, -1.5, -2}, sd::DataType::FLOAT32); - NDArray exp6('c', {2,2}, {-2, -1, -0.666667, -0.5}, sd::DataType::FLOAT32); - NDArray exp7('c', {2,2}, {0.5, 1, 1.5, 2}, sd::DataType::HALF); - NDArray exp8('c', {2,2}, {2, 1, 0.666667, 0.5}, sd::DataType::HALF); + NDArray x1('c', {2, 2}, {1, 2, 3, 4}, sd::DataType::INT64); + NDArray x2('c', {2, 2}, {1, 2, 3, 4}, sd::DataType::FLOAT32); + NDArray x3('c', {2, 2}, {1, 2, 3, 4}, sd::DataType::HALF); + const double val1 = 2; + const int val2 = -2; + NDArray exp1('c', {2, 2}, {0.5, 1, 1.5, 2}, sd::DataType::DOUBLE); + NDArray exp2('c', {2, 2}, {2, 1, 0.666667, 0.5}, sd::DataType::DOUBLE); + NDArray exp3('c', {2, 2}, {0, -1, -1, -2}, sd::DataType::INT64); + NDArray exp4('c', {2, 2}, {-2, -1, 0., 0.}, sd::DataType::INT64); + NDArray exp5('c', {2, 2}, {-0.5, -1, -1.5, -2}, sd::DataType::FLOAT32); + NDArray exp6('c', {2, 2}, {-2, -1, -0.666667, -0.5}, sd::DataType::FLOAT32); + NDArray exp7('c', {2, 2}, {0.5, 1, 1.5, 2}, sd::DataType::HALF); + NDArray exp8('c', {2, 2}, {2, 1, 0.666667, 0.5}, sd::DataType::HALF); - ASSERT_EQ(x1/val1, exp1); - ASSERT_EQ(val1/x1, exp2); + ASSERT_EQ(x1 / val1, exp1); + ASSERT_EQ(val1 / x1, exp2); - ASSERT_EQ(x1/val2, exp3); - ASSERT_EQ(val2/x1, exp4); + ASSERT_EQ(x1 / val2, exp3); + ASSERT_EQ(val2 / x1, exp4); - ASSERT_EQ(x2/val2, exp5); - ASSERT_EQ(val2/x2, exp6); + ASSERT_EQ(x2 / val2, exp5); + ASSERT_EQ(val2 / x2, exp6); - ASSERT_EQ(x3/val1, exp7); - ASSERT_EQ(val1/x3, exp8); + ASSERT_EQ(x3 / val1, exp7); + ASSERT_EQ(val1 / x3, exp8); } //////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_operatorPlusEqual_test1) { - if (!Environment::getInstance().isExperimentalBuild()) - return; + if (!Environment::getInstance().isExperimentalBuild()) return; - NDArray scalar1('c', {0}, std::vector{4}, sd::DataType::INT32); - NDArray scalar2('c', {0}, std::vector{1.5}, sd::DataType::HALF); + NDArray scalar1('c', {0}, std::vector{4}, sd::DataType::INT32); + NDArray scalar2('c', {0}, std::vector{1.5}, sd::DataType::HALF); - NDArray x1('c', {2,3}, {1.5, 2.5, 3.5, 4.5, 5.5, 6.5}, sd::DataType::FLOAT32); - NDArray x2('c', {3,2}, {10, 20, 30, 40, 50, 60}, sd::DataType::INT64); - NDArray x3('c', {2,2}, {0, 1, 2, 3}, sd::DataType::INT64); - NDArray x4('c', {2}, {0.4, 0.5}, sd::DataType::HALF); - NDArray x5('c', {2,2}, {0, 1, 2, 3}, sd::DataType::HALF); - NDArray x6('c', {2}, {0.4, 0.5}, sd::DataType::FLOAT32); + NDArray x1('c', {2, 3}, {1.5, 2.5, 3.5, 4.5, 5.5, 6.5}, + sd::DataType::FLOAT32); + NDArray x2('c', {3, 2}, {10, 20, 30, 40, 50, 60}, sd::DataType::INT64); + NDArray x3('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::INT64); + NDArray x4('c', {2}, {0.4, 0.5}, sd::DataType::HALF); + NDArray x5('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::HALF); + NDArray x6('c', {2}, {0.4, 0.5}, sd::DataType::FLOAT32); - NDArray exp1('c', {0}, std::vector{5}, sd::DataType::INT32); - NDArray exp2('c', {0}, std::vector{6.5}, sd::DataType::HALF); - NDArray exp3('c', {3,2}, {11, 22, 33, 44, 55, 66}, sd::DataType::INT64); - NDArray exp4('c', {2,3}, {12.5, 24.5, 36.5, 48.5, 60.5, 72.5}, sd::DataType::FLOAT32); - NDArray exp5('c', {2,2}, {0.4, 1.5, 2.4, 3.5}, sd::DataType::HALF); + NDArray exp1('c', {0}, std::vector{5}, sd::DataType::INT32); + NDArray exp2('c', {0}, std::vector{6.5}, sd::DataType::HALF); + NDArray exp3('c', {3, 2}, {11, 22, 33, 44, 55, 66}, sd::DataType::INT64); + NDArray exp4('c', {2, 3}, {12.5, 24.5, 36.5, 48.5, 60.5, 72.5}, + sd::DataType::FLOAT32); + NDArray exp5('c', {2, 2}, {0.4, 1.5, 2.4, 3.5}, sd::DataType::HALF); - scalar1 += scalar2; - ASSERT_EQ(scalar1, exp1); + scalar1 += scalar2; + ASSERT_EQ(scalar1, exp1); - scalar2 += scalar1; - ASSERT_EQ(scalar2, exp2); + scalar2 += scalar1; + ASSERT_EQ(scalar2, exp2); - x2 += x1; - ASSERT_EQ(x2, exp3); + x2 += x1; + ASSERT_EQ(x2, exp3); - x1 += x2; - ASSERT_EQ(x1, exp4); + x1 += x2; + ASSERT_EQ(x1, exp4); - x4 += x3; - ASSERT_EQ(x4, exp5); + x4 += x3; + ASSERT_EQ(x4, exp5); - x6 += x5; - ASSERT_EQ(x6, exp5); + x6 += x5; + ASSERT_EQ(x6, exp5); } //////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_operatorPlusEqual_test2) { - if (!Environment::getInstance().isExperimentalBuild()) - return; + if (!Environment::getInstance().isExperimentalBuild()) return; - NDArray x1('c', {2,2}, {0, 1, 2, 3}, sd::DataType::FLOAT32); - NDArray x2('c', {2,2}, {0, 1, 2, 3}, sd::DataType::INT32); + NDArray x1('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::FLOAT32); + NDArray x2('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::INT32); - const Nd4jLong val1 = 1; - const float16 val2 = 1.5; - const double val3 = 2.2; + const Nd4jLong val1 = 1; + const float16 val2 = 1.5; + const double val3 = 2.2; - NDArray exp1('c', {2,2}, {1, 2, 3, 4}, sd::DataType::FLOAT32); - NDArray exp2('c', {2,2}, {1, 2, 3, 4}, sd::DataType::INT32); - NDArray exp3('c', {2,2}, {2.5, 3.5, 4.5, 5.5}, sd::DataType::FLOAT32); - NDArray exp4('c', {2,2}, {2, 3, 4.5, 5}, sd::DataType::INT32); - NDArray exp5('c', {2,2}, {4.7, 5.7, 6.7, 7.7}, sd::DataType::FLOAT32); - NDArray exp6('c', {2,2}, {4, 5, 6, 7}, sd::DataType::INT32); + NDArray exp1('c', {2, 2}, {1, 2, 3, 4}, sd::DataType::FLOAT32); + NDArray exp2('c', {2, 2}, {1, 2, 3, 4}, sd::DataType::INT32); + NDArray exp3('c', {2, 2}, {2.5, 3.5, 4.5, 5.5}, sd::DataType::FLOAT32); + NDArray exp4('c', {2, 2}, {2, 3, 4.5, 5}, sd::DataType::INT32); + NDArray exp5('c', {2, 2}, {4.7, 5.7, 6.7, 7.7}, sd::DataType::FLOAT32); + NDArray exp6('c', {2, 2}, {4, 5, 6, 7}, sd::DataType::INT32); - x1 += val1; - ASSERT_EQ(x1, exp1); + x1 += val1; + ASSERT_EQ(x1, exp1); - x2 += val1; - ASSERT_EQ(x2, exp2); + x2 += val1; + ASSERT_EQ(x2, exp2); - x1 += val2; - ASSERT_EQ(x1, exp3); + x1 += val2; + ASSERT_EQ(x1, exp3); - x2 += val2; - ASSERT_EQ(x2, exp4); + x2 += val2; + ASSERT_EQ(x2, exp4); - x1 += val3; - ASSERT_EQ(x1, exp5); + x1 += val3; + ASSERT_EQ(x1, exp5); - x2 += val3; - ASSERT_EQ(x2, exp6); + x2 += val3; + ASSERT_EQ(x2, exp6); } //////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_operatorMinusEqual_test1) { - if (!Environment::getInstance().isExperimentalBuild()) - return; + if (!Environment::getInstance().isExperimentalBuild()) return; - NDArray scalar1('c', {0}, std::vector{4}, sd::DataType::INT32); - NDArray scalar2('c', {0}, std::vector{1.5}, sd::DataType::HALF); + NDArray scalar1('c', {0}, std::vector{4}, sd::DataType::INT32); + NDArray scalar2('c', {0}, std::vector{1.5}, sd::DataType::HALF); - NDArray x1('c', {2,3}, {1.5, 2.5, 3.5, 4.5, 5.5, 6.5}, sd::DataType::FLOAT32); - NDArray x2('c', {3,2}, {10, 20, 30, 40, 50, 60}, sd::DataType::INT64); - NDArray x3('c', {2,2}, {0, 1, 2, 3}, sd::DataType::INT64); - NDArray x4('c', {2}, {0.4, 0.5}, sd::DataType::HALF); - NDArray x5('c', {2,2}, {0, 1, 2, 3}, sd::DataType::HALF); - NDArray x6('c', {2}, {0.4, 0.5}, sd::DataType::FLOAT32); + NDArray x1('c', {2, 3}, {1.5, 2.5, 3.5, 4.5, 5.5, 6.5}, + sd::DataType::FLOAT32); + NDArray x2('c', {3, 2}, {10, 20, 30, 40, 50, 60}, sd::DataType::INT64); + NDArray x3('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::INT64); + NDArray x4('c', {2}, {0.4, 0.5}, sd::DataType::HALF); + NDArray x5('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::HALF); + NDArray x6('c', {2}, {0.4, 0.5}, sd::DataType::FLOAT32); - NDArray exp1('c', {0}, std::vector{2}, sd::DataType::INT32); - NDArray exp2('c', {0}, std::vector{-0.5}, sd::DataType::HALF); - NDArray exp3('c', {3,2}, {8, 17, 26, 35, 44, 53}, sd::DataType::INT64); - NDArray exp4('c', {2,3}, {-6.5, -14.5, -22.5, -30.5, -38.5, -46.5}, sd::DataType::FLOAT32); - NDArray exp5('c', {2,2}, {0.4, -0.5, -1.6, -2.5}, sd::DataType::HALF); + NDArray exp1('c', {0}, std::vector{2}, sd::DataType::INT32); + NDArray exp2('c', {0}, std::vector{-0.5}, sd::DataType::HALF); + NDArray exp3('c', {3, 2}, {8, 17, 26, 35, 44, 53}, sd::DataType::INT64); + NDArray exp4('c', {2, 3}, {-6.5, -14.5, -22.5, -30.5, -38.5, -46.5}, + sd::DataType::FLOAT32); + NDArray exp5('c', {2, 2}, {0.4, -0.5, -1.6, -2.5}, sd::DataType::HALF); - scalar1 -= scalar2; - ASSERT_EQ(scalar1, exp1); + scalar1 -= scalar2; + ASSERT_EQ(scalar1, exp1); - scalar2 -= scalar1; - ASSERT_EQ(scalar2, exp2); + scalar2 -= scalar1; + ASSERT_EQ(scalar2, exp2); - x2 -= x1; - ASSERT_EQ(x2, exp3); + x2 -= x1; + ASSERT_EQ(x2, exp3); - x1 -= x2; - ASSERT_EQ(x1, exp4); + x1 -= x2; + ASSERT_EQ(x1, exp4); - x4 -= x3; - ASSERT_EQ(x4, exp5); + x4 -= x3; + ASSERT_EQ(x4, exp5); - x6 -= x5; - ASSERT_EQ(x6, exp5); + x6 -= x5; + ASSERT_EQ(x6, exp5); } //////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_operatorMinusEqual_test2) { - if (!Environment::getInstance().isExperimentalBuild()) - return; + if (!Environment::getInstance().isExperimentalBuild()) return; - NDArray x1('c', {2,2}, {0, 1, 2, 3}, sd::DataType::FLOAT32); - NDArray x2('c', {2,2}, {0, 1, 2, 3}, sd::DataType::INT32); + NDArray x1('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::FLOAT32); + NDArray x2('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::INT32); - const Nd4jLong val1 = 1; - const float16 val2 = 1.5; - const double val3 = 2.2; + const Nd4jLong val1 = 1; + const float16 val2 = 1.5; + const double val3 = 2.2; - NDArray exp1('c', {2,2}, {-1, 0, 1, 2}, sd::DataType::FLOAT32); - NDArray exp2('c', {2,2}, {-1, 0, 1, 2}, sd::DataType::INT32); - NDArray exp3('c', {2,2}, {-2.5, -1.5, -0.5, 0.5}, sd::DataType::FLOAT32); - NDArray exp4('c', {2,2}, {-2., -1., 0., 0.}, sd::DataType::INT32); - NDArray exp5('c', {2,2}, {-4.7, -3.7, -2.7, -1.7}, sd::DataType::FLOAT32); - NDArray exp6('c', {2,2}, {-4, -3, -2, -2}, sd::DataType::INT32); + NDArray exp1('c', {2, 2}, {-1, 0, 1, 2}, sd::DataType::FLOAT32); + NDArray exp2('c', {2, 2}, {-1, 0, 1, 2}, sd::DataType::INT32); + NDArray exp3('c', {2, 2}, {-2.5, -1.5, -0.5, 0.5}, sd::DataType::FLOAT32); + NDArray exp4('c', {2, 2}, {-2., -1., 0., 0.}, sd::DataType::INT32); + NDArray exp5('c', {2, 2}, {-4.7, -3.7, -2.7, -1.7}, sd::DataType::FLOAT32); + NDArray exp6('c', {2, 2}, {-4, -3, -2, -2}, sd::DataType::INT32); - x1 -= val1; - ASSERT_EQ(x1, exp1); + x1 -= val1; + ASSERT_EQ(x1, exp1); - x2 -= val1; - ASSERT_EQ(x2, exp2); + x2 -= val1; + ASSERT_EQ(x2, exp2); - x1 -= val2; - ASSERT_EQ(x1, exp3); + x1 -= val2; + ASSERT_EQ(x1, exp3); - x2 -= val2; - ASSERT_EQ(x2, exp4); + x2 -= val2; + ASSERT_EQ(x2, exp4); - x1 -= val3; - ASSERT_EQ(x1, exp5); + x1 -= val3; + ASSERT_EQ(x1, exp5); - x2 -= val3; - ASSERT_EQ(x2, exp6); + x2 -= val3; + ASSERT_EQ(x2, exp6); } //////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_operatorMultiplyEqual_test1) { - if (!Environment::getInstance().isExperimentalBuild()) - return; + if (!Environment::getInstance().isExperimentalBuild()) return; - NDArray scalar1('c', {0}, std::vector{3}, sd::DataType::INT32); - NDArray scalar2('c', {0}, std::vector{2.5}, sd::DataType::HALF); + NDArray scalar1('c', {0}, std::vector{3}, sd::DataType::INT32); + NDArray scalar2('c', {0}, std::vector{2.5}, sd::DataType::HALF); - NDArray x1('c', {2,3}, {1.5, 2.5, 3.5, 4.5, 5.5, 6.5}, sd::DataType::FLOAT32); - NDArray x2('c', {3,2}, {1, 2, 3, 4, 5, 6}, sd::DataType::INT64); - NDArray x3('c', {2,2}, {0, 1, 2, 3}, sd::DataType::INT64); - NDArray x4('c', {2}, {0.4, 0.5}, sd::DataType::HALF); - NDArray x5('c', {2,2}, {0, 1, 2, 3}, sd::DataType::HALF); - NDArray x6('c', {2}, {0.4, 0.5}, sd::DataType::FLOAT32); + NDArray x1('c', {2, 3}, {1.5, 2.5, 3.5, 4.5, 5.5, 6.5}, + sd::DataType::FLOAT32); + NDArray x2('c', {3, 2}, {1, 2, 3, 4, 5, 6}, sd::DataType::INT64); + NDArray x3('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::INT64); + NDArray x4('c', {2}, {0.4, 0.5}, sd::DataType::HALF); + NDArray x5('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::HALF); + NDArray x6('c', {2}, {0.4, 0.5}, sd::DataType::FLOAT32); - NDArray exp1('c', {0}, std::vector{7}, sd::DataType::INT32); - NDArray exp2('c', {0}, std::vector{17.5}, sd::DataType::HALF); - NDArray exp3('c', {3,2}, {1, 5, 10, 18, 27, 39}, sd::DataType::INT64); - NDArray exp4('c', {2,3}, {1.5, 12.5, 35, 81, 148.5, 253.5}, sd::DataType::FLOAT32); - NDArray exp5('c', {2,2}, {0., 0.5, 0.8, 1.5}, sd::DataType::HALF); + NDArray exp1('c', {0}, std::vector{7}, sd::DataType::INT32); + NDArray exp2('c', {0}, std::vector{17.5}, sd::DataType::HALF); + NDArray exp3('c', {3, 2}, {1, 5, 10, 18, 27, 39}, sd::DataType::INT64); + NDArray exp4('c', {2, 3}, {1.5, 12.5, 35, 81, 148.5, 253.5}, + sd::DataType::FLOAT32); + NDArray exp5('c', {2, 2}, {0., 0.5, 0.8, 1.5}, sd::DataType::HALF); - scalar1 *= scalar2; - ASSERT_EQ(scalar1, exp1); + scalar1 *= scalar2; + ASSERT_EQ(scalar1, exp1); - scalar2 *= scalar1; - ASSERT_EQ(scalar2, exp2); + scalar2 *= scalar1; + ASSERT_EQ(scalar2, exp2); - x2 *= x1; - ASSERT_EQ(x2, exp3); + x2 *= x1; + ASSERT_EQ(x2, exp3); - x1 *= x2; - ASSERT_EQ(x1, exp4); + x1 *= x2; + ASSERT_EQ(x1, exp4); - x4 *= x3; - ASSERT_EQ(x4, exp5); + x4 *= x3; + ASSERT_EQ(x4, exp5); - x6 *= x5; - ASSERT_EQ(x6, exp5); + x6 *= x5; + ASSERT_EQ(x6, exp5); } //////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_operatorMultiplyEqual_test2) { - if (!Environment::getInstance().isExperimentalBuild()) - return; + if (!Environment::getInstance().isExperimentalBuild()) return; - NDArray x1('c', {2,2}, {0, 1, 2, 3}, sd::DataType::FLOAT32); - NDArray x2('c', {2,2}, {0, 1, 2, 3}, sd::DataType::INT32); + NDArray x1('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::FLOAT32); + NDArray x2('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::INT32); - const Nd4jLong val1 = 1; - const float16 val2 = 1.5; - const double val3 = 2.2; + const Nd4jLong val1 = 1; + const float16 val2 = 1.5; + const double val3 = 2.2; - NDArray exp1('c', {2,2}, {0, 1, 2, 3}, sd::DataType::FLOAT32); - NDArray exp2('c', {2,2}, {0, 1, 2, 3}, sd::DataType::INT32); - NDArray exp3('c', {2,2}, {0, 1.5, 3, 4.5}, sd::DataType::FLOAT32); - NDArray exp4('c', {2,2}, {0, 1, 3, 4}, sd::DataType::INT32); - NDArray exp5('c', {2,2}, {0, 3.3, 6.6, 9.9}, sd::DataType::FLOAT32); - NDArray exp6('c', {2,2}, {0, 2, 6, 8}, sd::DataType::INT32); + NDArray exp1('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::FLOAT32); + NDArray exp2('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::INT32); + NDArray exp3('c', {2, 2}, {0, 1.5, 3, 4.5}, sd::DataType::FLOAT32); + NDArray exp4('c', {2, 2}, {0, 1, 3, 4}, sd::DataType::INT32); + NDArray exp5('c', {2, 2}, {0, 3.3, 6.6, 9.9}, sd::DataType::FLOAT32); + NDArray exp6('c', {2, 2}, {0, 2, 6, 8}, sd::DataType::INT32); - x1 *= val1; - ASSERT_EQ(x1, exp1); + x1 *= val1; + ASSERT_EQ(x1, exp1); - x2 *= val1; - ASSERT_EQ(x2, exp2); + x2 *= val1; + ASSERT_EQ(x2, exp2); - x1 *= val2; - ASSERT_EQ(x1, exp3); + x1 *= val2; + ASSERT_EQ(x1, exp3); - x2 *= val2; - ASSERT_EQ(x2, exp4); + x2 *= val2; + ASSERT_EQ(x2, exp4); - x1 *= val3; - ASSERT_EQ(x1, exp5); + x1 *= val3; + ASSERT_EQ(x1, exp5); - x2 *= val3; - ASSERT_EQ(x2, exp6); + x2 *= val3; + ASSERT_EQ(x2, exp6); } //////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_operatorDivideEqual_test1) { - if (!Environment::getInstance().isExperimentalBuild()) - return; + if (!Environment::getInstance().isExperimentalBuild()) return; - NDArray scalar1('c', {0}, std::vector{3}, sd::DataType::INT32); - NDArray scalar2('c', {0}, std::vector{2.5}, sd::DataType::HALF); + NDArray scalar1('c', {0}, std::vector{3}, sd::DataType::INT32); + NDArray scalar2('c', {0}, std::vector{2.5}, sd::DataType::HALF); - NDArray x1('c', {2,3}, {1.5, 2.5, 3.5, 4.5, 5.5, 6.5}, sd::DataType::FLOAT32); - NDArray x2('c', {3,2}, {10, 20, 30, 40, 50, 60}, sd::DataType::INT64); - NDArray x3('c', {2,2}, {1, 2, 3, 4}, sd::DataType::INT64); - NDArray x4('c', {2}, {0.4, 0.5}, sd::DataType::HALF); - NDArray x5('c', {2,2}, {1, 2, 3, 4}, sd::DataType::HALF); - NDArray x6('c', {2}, {0.4, 0.5}, sd::DataType::FLOAT32); + NDArray x1('c', {2, 3}, {1.5, 2.5, 3.5, 4.5, 5.5, 6.5}, + sd::DataType::FLOAT32); + NDArray x2('c', {3, 2}, {10, 20, 30, 40, 50, 60}, sd::DataType::INT64); + NDArray x3('c', {2, 2}, {1, 2, 3, 4}, sd::DataType::INT64); + NDArray x4('c', {2}, {0.4, 0.5}, sd::DataType::HALF); + NDArray x5('c', {2, 2}, {1, 2, 3, 4}, sd::DataType::HALF); + NDArray x6('c', {2}, {0.4, 0.5}, sd::DataType::FLOAT32); - NDArray exp1('c', {0}, std::vector{1}, sd::DataType::INT32); - NDArray exp2('c', {0}, std::vector{2.5}, sd::DataType::HALF); - NDArray exp3('c', {3,2}, {6, 8, 8, 8, 9, 9}, sd::DataType::INT64); - NDArray exp4('c', {2,3}, {0.25, 0.3125, 0.4375, 0.5625, 0.611111111, 0.722222222}, sd::DataType::FLOAT32); - NDArray exp5('c', {2,2}, {0.4, 0.25, 0.1333333, 0.125}, sd::DataType::HALF); + NDArray exp1('c', {0}, std::vector{1}, sd::DataType::INT32); + NDArray exp2('c', {0}, std::vector{2.5}, sd::DataType::HALF); + NDArray exp3('c', {3, 2}, {6, 8, 8, 8, 9, 9}, sd::DataType::INT64); + NDArray exp4('c', {2, 3}, + {0.25, 0.3125, 0.4375, 0.5625, 0.611111111, 0.722222222}, + sd::DataType::FLOAT32); + NDArray exp5('c', {2, 2}, {0.4, 0.25, 0.1333333, 0.125}, sd::DataType::HALF); - scalar1 /= scalar2; - ASSERT_EQ(scalar1, exp1); + scalar1 /= scalar2; + ASSERT_EQ(scalar1, exp1); - scalar2 /= scalar1; - ASSERT_EQ(scalar2, exp2); + scalar2 /= scalar1; + ASSERT_EQ(scalar2, exp2); - x2 /= x1; - ASSERT_EQ(x2, exp3); + x2 /= x1; + ASSERT_EQ(x2, exp3); - x1 /= x2; - ASSERT_EQ(x1, exp4); + x1 /= x2; + ASSERT_EQ(x1, exp4); - x4 /= x3; - ASSERT_EQ(x4, exp5); + x4 /= x3; + ASSERT_EQ(x4, exp5); - x6 /= x5; - ASSERT_EQ(x6, exp5); + x6 /= x5; + ASSERT_EQ(x6, exp5); } //////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_operatorDivideEqual_test2) { - if (!Environment::getInstance().isExperimentalBuild()) - return; + if (!Environment::getInstance().isExperimentalBuild()) return; - NDArray x1('c', {2,2}, {0, 2, 4, 6}, sd::DataType::FLOAT32); - NDArray x2('c', {2,2}, {0, 2, 4, 6}, sd::DataType::INT32); + NDArray x1('c', {2, 2}, {0, 2, 4, 6}, sd::DataType::FLOAT32); + NDArray x2('c', {2, 2}, {0, 2, 4, 6}, sd::DataType::INT32); - const Nd4jLong val1 = 1; - const float16 val2 = 2.; - const double val3 = 2.2; + const Nd4jLong val1 = 1; + const float16 val2 = 2.; + const double val3 = 2.2; - NDArray exp1('c', {2,2}, {0, 2, 4, 6}, sd::DataType::FLOAT32); - NDArray exp2('c', {2,2}, {0, 2, 4, 6}, sd::DataType::INT32); - NDArray exp3('c', {2,2}, {0, 1, 2, 3}, sd::DataType::FLOAT32); - NDArray exp4('c', {2,2}, {0, 1, 2, 3}, sd::DataType::INT32); - NDArray exp5('c', {2,2}, {0, 0.45454545, 0.909090909, 1.363636364}, sd::DataType::FLOAT32); - NDArray exp6('c', {2,2}, {0, 0, 0, 1}, sd::DataType::INT32); + NDArray exp1('c', {2, 2}, {0, 2, 4, 6}, sd::DataType::FLOAT32); + NDArray exp2('c', {2, 2}, {0, 2, 4, 6}, sd::DataType::INT32); + NDArray exp3('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::FLOAT32); + NDArray exp4('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::INT32); + NDArray exp5('c', {2, 2}, {0, 0.45454545, 0.909090909, 1.363636364}, + sd::DataType::FLOAT32); + NDArray exp6('c', {2, 2}, {0, 0, 0, 1}, sd::DataType::INT32); - x1 /= val1; - ASSERT_EQ(x1, exp1); + x1 /= val1; + ASSERT_EQ(x1, exp1); - x2 /= val1; - ASSERT_EQ(x2, exp2); + x2 /= val1; + ASSERT_EQ(x2, exp2); - x1 /= val2; - ASSERT_EQ(x1, exp3); + x1 /= val2; + ASSERT_EQ(x1, exp3); - x2 /= val2; - ASSERT_EQ(x2, exp4); + x2 /= val2; + ASSERT_EQ(x2, exp4); - x1 /= val3; - ASSERT_EQ(x1, exp5); + x1 /= val3; + ASSERT_EQ(x1, exp5); - x2 /= val3; - ASSERT_EQ(x2, exp6); + x2 /= val3; + ASSERT_EQ(x2, exp6); } //////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_reduceNumberFloat_test1) { - if (!Environment::getInstance().isExperimentalBuild()) - return; + if (!Environment::getInstance().isExperimentalBuild()) return; - NDArray x1('c', {2,2}, {0, 1, 2, 3}, sd::DataType::INT64); - NDArray x2('c', {2,2}, {0.5, 1.5, 2.5, 3.5}, sd::DataType::HALF); - NDArray x3('c', {2,2}, {0.5, 1.5, 2.5, 3.5}, sd::DataType::DOUBLE); - NDArray x4('c', {2,2}, {0, 1, 0, 1}, sd::DataType::BOOL); + NDArray x1('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::INT64); + NDArray x2('c', {2, 2}, {0.5, 1.5, 2.5, 3.5}, sd::DataType::HALF); + NDArray x3('c', {2, 2}, {0.5, 1.5, 2.5, 3.5}, sd::DataType::DOUBLE); + NDArray x4('c', {2, 2}, {0, 1, 0, 1}, sd::DataType::BOOL); - NDArray exp1('c', {0}, std::vector{1.5}, sd::DataType::FLOAT32); - NDArray exp2('c', {0}, std::vector{2}, sd::DataType::HALF); - NDArray exp3('c', {0}, std::vector{2}, sd::DataType::DOUBLE); - NDArray exp4('c', {0}, std::vector{0.25},sd::DataType::FLOAT32); + NDArray exp1('c', {0}, std::vector{1.5}, sd::DataType::FLOAT32); + NDArray exp2('c', {0}, std::vector{2}, sd::DataType::HALF); + NDArray exp3('c', {0}, std::vector{2}, sd::DataType::DOUBLE); + NDArray exp4('c', {0}, std::vector{0.25}, sd::DataType::FLOAT32); + NDArray scalar = x1.reduceNumber(reduce::Mean); + ASSERT_EQ(scalar, exp1); + x1.reduceNumber(reduce::Mean, scalar); + ASSERT_EQ(scalar, exp1); - NDArray scalar = x1.reduceNumber(reduce::Mean); - ASSERT_EQ(scalar, exp1); - x1.reduceNumber(reduce::Mean, scalar); - ASSERT_EQ(scalar, exp1); + scalar = x2.reduceNumber(reduce::Mean); + ASSERT_EQ(scalar, exp2); + x2.reduceNumber(reduce::Mean, scalar); + ASSERT_EQ(scalar, exp2); - scalar = x2.reduceNumber(reduce::Mean); - ASSERT_EQ(scalar, exp2); - x2.reduceNumber(reduce::Mean, scalar); - ASSERT_EQ(scalar, exp2); + scalar = x3.reduceNumber(reduce::Mean); + ASSERT_EQ(scalar, exp3); + x3.reduceNumber(reduce::Mean, scalar); + ASSERT_EQ(scalar, exp3); - scalar = x3.reduceNumber(reduce::Mean); - ASSERT_EQ(scalar, exp3); - x3.reduceNumber(reduce::Mean,scalar); - ASSERT_EQ(scalar, exp3); - - scalar = x4.reduceNumber(reduce::Mean); - ASSERT_EQ(scalar, exp4); - x4.reduceNumber(reduce::Mean, scalar); - ASSERT_EQ(scalar, exp4); + scalar = x4.reduceNumber(reduce::Mean); + ASSERT_EQ(scalar, exp4); + x4.reduceNumber(reduce::Mean, scalar); + ASSERT_EQ(scalar, exp4); } //////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_reduceNumberSame_test1) { - if (!Environment::getInstance().isExperimentalBuild()) - return; - - NDArray x1('c', {2,2}, {0, 1, 2, 3}, sd::DataType::INT64); - NDArray x2('c', {2,2}, {0.5, 1.5, 2.5, 3.5}, sd::DataType::HALF); - NDArray x3('c', {2,2}, {0.5, 1.5, 2.5, 3.5}, sd::DataType::DOUBLE); - NDArray x4('c', {2,2}, {0, 1, 0, 1}, sd::DataType::BOOL); + if (!Environment::getInstance().isExperimentalBuild()) return; - NDArray exp1('c', {0}, std::vector{6}, sd::DataType::INT64); - NDArray exp2('c', {0}, std::vector{8}, sd::DataType::HALF); - NDArray exp3('c', {0}, std::vector{8}, sd::DataType::DOUBLE); - NDArray exp4('c', {0}, std::vector{1}, sd::DataType::BOOL); + NDArray x1('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::INT64); + NDArray x2('c', {2, 2}, {0.5, 1.5, 2.5, 3.5}, sd::DataType::HALF); + NDArray x3('c', {2, 2}, {0.5, 1.5, 2.5, 3.5}, sd::DataType::DOUBLE); + NDArray x4('c', {2, 2}, {0, 1, 0, 1}, sd::DataType::BOOL); + NDArray exp1('c', {0}, std::vector{6}, sd::DataType::INT64); + NDArray exp2('c', {0}, std::vector{8}, sd::DataType::HALF); + NDArray exp3('c', {0}, std::vector{8}, sd::DataType::DOUBLE); + NDArray exp4('c', {0}, std::vector{1}, sd::DataType::BOOL); - NDArray scalar = x1.reduceNumber(reduce::Sum); - ASSERT_EQ(scalar, exp1); - x1.reduceNumber(reduce::Sum, scalar); - ASSERT_EQ(scalar, exp1); + NDArray scalar = x1.reduceNumber(reduce::Sum); + ASSERT_EQ(scalar, exp1); + x1.reduceNumber(reduce::Sum, scalar); + ASSERT_EQ(scalar, exp1); - scalar = x2.reduceNumber(reduce::Sum); - ASSERT_EQ(scalar, exp2); - x2.reduceNumber(reduce::Sum, scalar); - ASSERT_EQ(scalar, exp2); + scalar = x2.reduceNumber(reduce::Sum); + ASSERT_EQ(scalar, exp2); + x2.reduceNumber(reduce::Sum, scalar); + ASSERT_EQ(scalar, exp2); - scalar = x3.reduceNumber(reduce::Sum); - ASSERT_EQ(scalar, exp3); - x3.reduceNumber(reduce::Sum, scalar); - ASSERT_EQ(scalar, exp3); + scalar = x3.reduceNumber(reduce::Sum); + ASSERT_EQ(scalar, exp3); + x3.reduceNumber(reduce::Sum, scalar); + ASSERT_EQ(scalar, exp3); - scalar = x4.reduceNumber(reduce::Sum); - ASSERT_EQ(scalar, exp4); - x4.reduceNumber(reduce::Sum, scalar); - ASSERT_EQ(scalar, exp4); + scalar = x4.reduceNumber(reduce::Sum); + ASSERT_EQ(scalar, exp4); + x4.reduceNumber(reduce::Sum, scalar); + ASSERT_EQ(scalar, exp4); } //////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_reduceNumberBool_test1) { - if (!Environment::getInstance().isExperimentalBuild()) - return; + if (!Environment::getInstance().isExperimentalBuild()) return; - NDArray x1('c', {2,2}, {0, -1, 2, -3}, sd::DataType::INT64); - NDArray x2('c', {2,2}, {0.5, -1.5, 2.5, -3.5}, sd::DataType::HALF); - NDArray x3('c', {2,2}, {0.5, 1.5, 2.5, 3.5}, sd::DataType::DOUBLE); - NDArray x4('c', {2,2}, {-2, -1, 0, 1}, sd::DataType::BOOL); + NDArray x1('c', {2, 2}, {0, -1, 2, -3}, sd::DataType::INT64); + NDArray x2('c', {2, 2}, {0.5, -1.5, 2.5, -3.5}, sd::DataType::HALF); + NDArray x3('c', {2, 2}, {0.5, 1.5, 2.5, 3.5}, sd::DataType::DOUBLE); + NDArray x4('c', {2, 2}, {-2, -1, 0, 1}, sd::DataType::BOOL); - NDArray exp1('c', {0}, std::vector{1}, sd::DataType::BOOL); + NDArray exp1('c', {0}, std::vector{1}, sd::DataType::BOOL); - NDArray scalar = x1.reduceNumber(reduce::IsFinite); - ASSERT_EQ(scalar, exp1); - x1.reduceNumber(reduce::IsFinite, scalar); - ASSERT_EQ(scalar, exp1); + NDArray scalar = x1.reduceNumber(reduce::IsFinite); + ASSERT_EQ(scalar, exp1); + x1.reduceNumber(reduce::IsFinite, scalar); + ASSERT_EQ(scalar, exp1); - scalar = x2.reduceNumber(reduce::IsFinite); - ASSERT_EQ(scalar, exp1); - x2.reduceNumber(reduce::IsFinite, scalar); - ASSERT_EQ(scalar, exp1); + scalar = x2.reduceNumber(reduce::IsFinite); + ASSERT_EQ(scalar, exp1); + x2.reduceNumber(reduce::IsFinite, scalar); + ASSERT_EQ(scalar, exp1); - scalar = x3.reduceNumber(reduce::IsFinite); - ASSERT_EQ(scalar, exp1); - x3.reduceNumber(reduce::IsFinite, scalar); - ASSERT_EQ(scalar, exp1); + scalar = x3.reduceNumber(reduce::IsFinite); + ASSERT_EQ(scalar, exp1); + x3.reduceNumber(reduce::IsFinite, scalar); + ASSERT_EQ(scalar, exp1); - scalar = x4.reduceNumber(reduce::IsFinite); - ASSERT_EQ(scalar, exp1); - x4.reduceNumber(reduce::IsFinite, scalar); - ASSERT_EQ(scalar, exp1); + scalar = x4.reduceNumber(reduce::IsFinite); + ASSERT_EQ(scalar, exp1); + x4.reduceNumber(reduce::IsFinite, scalar); + ASSERT_EQ(scalar, exp1); } ////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_reduceNumberLong_test1) { - if (!Environment::getInstance().isExperimentalBuild()) - return; + if (!Environment::getInstance().isExperimentalBuild()) return; - NDArray x1('c', {2,2}, {0, 1, 2, 3}, sd::DataType::INT64); - NDArray x2('c', {2,2}, {0.5, 1.5, 2.5, 3.5}, sd::DataType::HALF); - NDArray x3('c', {2,2}, {0.5, -1.5, 0, 3.5}, sd::DataType::DOUBLE); - NDArray x4('c', {2,2}, {0, 1, 0, 1}, sd::DataType::BOOL); + NDArray x1('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::INT64); + NDArray x2('c', {2, 2}, {0.5, 1.5, 2.5, 3.5}, sd::DataType::HALF); + NDArray x3('c', {2, 2}, {0.5, -1.5, 0, 3.5}, sd::DataType::DOUBLE); + NDArray x4('c', {2, 2}, {0, 1, 0, 1}, sd::DataType::BOOL); - NDArray exp1('c', {0}, std::vector{3}, sd::DataType::INT64); - NDArray exp2('c', {0}, std::vector{4}, sd::DataType::INT64); - NDArray exp3('c', {0}, std::vector{3}, sd::DataType::INT64); - NDArray exp4('c', {0}, std::vector{2}, sd::DataType::INT64); + NDArray exp1('c', {0}, std::vector{3}, sd::DataType::INT64); + NDArray exp2('c', {0}, std::vector{4}, sd::DataType::INT64); + NDArray exp3('c', {0}, std::vector{3}, sd::DataType::INT64); + NDArray exp4('c', {0}, std::vector{2}, sd::DataType::INT64); - NDArray scalar = x1.reduceNumber(reduce::CountNonZero); - ASSERT_EQ(scalar, exp1); - x1.reduceNumber(reduce::CountNonZero, scalar); - ASSERT_EQ(scalar, exp1); + NDArray scalar = x1.reduceNumber(reduce::CountNonZero); + ASSERT_EQ(scalar, exp1); + x1.reduceNumber(reduce::CountNonZero, scalar); + ASSERT_EQ(scalar, exp1); - scalar = x2.reduceNumber(reduce::CountNonZero); - ASSERT_EQ(scalar, exp2); - x2.reduceNumber(reduce::CountNonZero, scalar); - ASSERT_EQ(scalar, exp2); + scalar = x2.reduceNumber(reduce::CountNonZero); + ASSERT_EQ(scalar, exp2); + x2.reduceNumber(reduce::CountNonZero, scalar); + ASSERT_EQ(scalar, exp2); - scalar = x3.reduceNumber(reduce::CountNonZero); - ASSERT_EQ(scalar, exp3); - x3.reduceNumber(reduce::CountNonZero, scalar); - ASSERT_EQ(scalar, exp3); + scalar = x3.reduceNumber(reduce::CountNonZero); + ASSERT_EQ(scalar, exp3); + x3.reduceNumber(reduce::CountNonZero, scalar); + ASSERT_EQ(scalar, exp3); - scalar = x4.reduceNumber(reduce::CountNonZero); - ASSERT_EQ(scalar, exp4); - x4.reduceNumber(reduce::CountNonZero, scalar); - ASSERT_EQ(scalar, exp4); + scalar = x4.reduceNumber(reduce::CountNonZero); + ASSERT_EQ(scalar, exp4); + x4.reduceNumber(reduce::CountNonZero, scalar); + ASSERT_EQ(scalar, exp4); } ////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_indexReduceNumber_test1) { - if (!Environment::getInstance().isExperimentalBuild()) - return; + if (!Environment::getInstance().isExperimentalBuild()) return; - NDArray x1('c', {2,2}, {0, 1, 2, 3}, sd::DataType::INT32); - NDArray x2('c', {2,2}, {0.5, 1.5, -4.5, 3.5}, sd::DataType::HALF); - NDArray x3('c', {2,2}, {0, -1, 0, 1}, sd::DataType::BOOL); + NDArray x1('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::INT32); + NDArray x2('c', {2, 2}, {0.5, 1.5, -4.5, 3.5}, sd::DataType::HALF); + NDArray x3('c', {2, 2}, {0, -1, 0, 1}, sd::DataType::BOOL); - NDArray exp1('c', {0}, std::vector{3}, sd::DataType::INT64); - NDArray exp2('c', {0}, std::vector{2}, sd::DataType::INT64); - NDArray exp3('c', {0}, std::vector{1}, sd::DataType::INT64); + NDArray exp1('c', {0}, std::vector{3}, sd::DataType::INT64); + NDArray exp2('c', {0}, std::vector{2}, sd::DataType::INT64); + NDArray exp3('c', {0}, std::vector{1}, sd::DataType::INT64); - NDArray scalar = x1.indexReduceNumber(sd::indexreduce::IndexAbsoluteMax); - ASSERT_EQ(scalar, exp1); + NDArray scalar = x1.indexReduceNumber(sd::indexreduce::IndexAbsoluteMax); + ASSERT_EQ(scalar, exp1); - scalar = x2.indexReduceNumber(sd::indexreduce::IndexAbsoluteMax); - ASSERT_EQ(scalar, exp2); + scalar = x2.indexReduceNumber(sd::indexreduce::IndexAbsoluteMax); + ASSERT_EQ(scalar, exp2); - scalar = x3.indexReduceNumber(sd::indexreduce::IndexAbsoluteMax); - ASSERT_EQ(scalar, exp3); + scalar = x3.indexReduceNumber(sd::indexreduce::IndexAbsoluteMax); + ASSERT_EQ(scalar, exp3); } ////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_applyTransformFloat_test1) { - if (!Environment::getInstance().isExperimentalBuild()) - return; + if (!Environment::getInstance().isExperimentalBuild()) return; - NDArray x1('c', {2,2}, {0, 4, 9, 16}, sd::DataType::INT64); - NDArray x2('c', {2,2}, {0, 2.25, 6.25, 12.25}, sd::DataType::HALF); - NDArray x3('c', {2,2}, {0, 2.25, 6.25, 12.25}, sd::DataType::DOUBLE); - NDArray x4('c', {2,2}, {0, 1, 0, 1}, sd::DataType::BOOL); + NDArray x1('c', {2, 2}, {0, 4, 9, 16}, sd::DataType::INT64); + NDArray x2('c', {2, 2}, {0, 2.25, 6.25, 12.25}, sd::DataType::HALF); + NDArray x3('c', {2, 2}, {0, 2.25, 6.25, 12.25}, sd::DataType::DOUBLE); + NDArray x4('c', {2, 2}, {0, 1, 0, 1}, sd::DataType::BOOL); - NDArray exp1('c', {2,2}, {0, 2, 3, 4}, sd::DataType::FLOAT32); - NDArray exp2('c', {2,2}, {0, 1.5, 2.5, 3.5}, sd::DataType::DOUBLE); - NDArray exp3('c', {2,2}, {0, 1.5, 2.5, 3.5}, sd::DataType::HALF); - NDArray exp4('c', {2,2}, {0, 1, 0, 1}, sd::DataType::HALF); + NDArray exp1('c', {2, 2}, {0, 2, 3, 4}, sd::DataType::FLOAT32); + NDArray exp2('c', {2, 2}, {0, 1.5, 2.5, 3.5}, sd::DataType::DOUBLE); + NDArray exp3('c', {2, 2}, {0, 1.5, 2.5, 3.5}, sd::DataType::HALF); + NDArray exp4('c', {2, 2}, {0, 1, 0, 1}, sd::DataType::HALF); - NDArray result1('c', {2,2}, sd::DataType::FLOAT32); - NDArray result2('c', {2,2}, sd::DataType::DOUBLE); - NDArray result3('c', {2,2}, sd::DataType::HALF); + NDArray result1('c', {2, 2}, sd::DataType::FLOAT32); + NDArray result2('c', {2, 2}, sd::DataType::DOUBLE); + NDArray result3('c', {2, 2}, sd::DataType::HALF); - x1.applyTransform(sd::transform::Sqrt, result1); - ASSERT_EQ(result1, exp1); + x1.applyTransform(sd::transform::Sqrt, result1); + ASSERT_EQ(result1, exp1); - x2.applyTransform(sd::transform::Sqrt, result2); - ASSERT_EQ(result2, exp2); + x2.applyTransform(sd::transform::Sqrt, result2); + ASSERT_EQ(result2, exp2); - x3.applyTransform(sd::transform::Sqrt, result3); - ASSERT_EQ(result3, exp3); + x3.applyTransform(sd::transform::Sqrt, result3); + ASSERT_EQ(result3, exp3); - x4.applyTransform(sd::transform::Sqrt, result3); - ASSERT_EQ(result3, exp4); + x4.applyTransform(sd::transform::Sqrt, result3); + ASSERT_EQ(result3, exp4); - x2.applyTransform(sd::transform::Sqrt, x2); - ASSERT_EQ(x2, exp3); + x2.applyTransform(sd::transform::Sqrt, x2); + ASSERT_EQ(x2, exp3); - x3.applyTransform(sd::transform::Sqrt, x3); - ASSERT_EQ(x3, exp2); + x3.applyTransform(sd::transform::Sqrt, x3); + ASSERT_EQ(x3, exp2); } ////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_applyTransformSame_test1) { - if (!Environment::getInstance().isExperimentalBuild()) - return; + if (!Environment::getInstance().isExperimentalBuild()) return; - NDArray x1('c', {2,2}, {0, 1, 2, 3}, sd::DataType::INT64); - NDArray x2('c', {2,2}, {0, 1.5, 2.5, 3.5}, sd::DataType::HALF); - NDArray x3('c', {2,2}, {0, 1.5, 2.5, 3.5}, sd::DataType::DOUBLE); - NDArray x4('c', {2,2}, {0, 1, 0, 1}, sd::DataType::BOOL); - NDArray x5('c', {2,3}, {0, 1.5, 2.5, 3.5, 4.5, 5.5}, sd::DataType::DOUBLE); + NDArray x1('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::INT64); + NDArray x2('c', {2, 2}, {0, 1.5, 2.5, 3.5}, sd::DataType::HALF); + NDArray x3('c', {2, 2}, {0, 1.5, 2.5, 3.5}, sd::DataType::DOUBLE); + NDArray x4('c', {2, 2}, {0, 1, 0, 1}, sd::DataType::BOOL); + NDArray x5('c', {2, 3}, {0, 1.5, 2.5, 3.5, 4.5, 5.5}, sd::DataType::DOUBLE); - NDArray exp1('c', {2,2}, {0, 1, 4, 9}, sd::DataType::INT64); - NDArray exp2('c', {2,2}, {0, 2.25, 6.25, 12.25}, sd::DataType::HALF); - NDArray exp3('c', {2,2}, {0, 2.25, 6.25, 12.25}, sd::DataType::DOUBLE); - NDArray exp4('c', {2,2}, {0, 1, 0, 1}, sd::DataType::BOOL); - NDArray exp5('c', {3,2}, {0, 2.25, 6.25, 12.25, 20.25, 30.25}, sd::DataType::DOUBLE); + NDArray exp1('c', {2, 2}, {0, 1, 4, 9}, sd::DataType::INT64); + NDArray exp2('c', {2, 2}, {0, 2.25, 6.25, 12.25}, sd::DataType::HALF); + NDArray exp3('c', {2, 2}, {0, 2.25, 6.25, 12.25}, sd::DataType::DOUBLE); + NDArray exp4('c', {2, 2}, {0, 1, 0, 1}, sd::DataType::BOOL); + NDArray exp5('c', {3, 2}, {0, 2.25, 6.25, 12.25, 20.25, 30.25}, + sd::DataType::DOUBLE); - NDArray result1('c', {2,2}, sd::DataType::INT64); - NDArray result2('c', {2,2}, sd::DataType::HALF); - NDArray result3('c', {2,2}, sd::DataType::DOUBLE); - NDArray result4('c', {2,2}, sd::DataType::BOOL); - NDArray result5('c', {3,2}, sd::DataType::DOUBLE); + NDArray result1('c', {2, 2}, sd::DataType::INT64); + NDArray result2('c', {2, 2}, sd::DataType::HALF); + NDArray result3('c', {2, 2}, sd::DataType::DOUBLE); + NDArray result4('c', {2, 2}, sd::DataType::BOOL); + NDArray result5('c', {3, 2}, sd::DataType::DOUBLE); - x1.applyTransform(sd::transform::Square, result1); - ASSERT_EQ(result1, exp1); + x1.applyTransform(sd::transform::Square, result1); + ASSERT_EQ(result1, exp1); - x2.applyTransform(sd::transform::Square, result2); - ASSERT_EQ(result2, exp2); + x2.applyTransform(sd::transform::Square, result2); + ASSERT_EQ(result2, exp2); - x3.applyTransform(sd::transform::Square, result3); - ASSERT_EQ(result3, exp3); + x3.applyTransform(sd::transform::Square, result3); + ASSERT_EQ(result3, exp3); - x4.applyTransform(sd::transform::Square, result4); - ASSERT_EQ(result4, exp4); + x4.applyTransform(sd::transform::Square, result4); + ASSERT_EQ(result4, exp4); - x2.applyTransform(sd::transform::Square, x2); - ASSERT_EQ(x2, exp2); + x2.applyTransform(sd::transform::Square, x2); + ASSERT_EQ(x2, exp2); - x3.applyTransform(sd::transform::Square, x3); - ASSERT_EQ(x3, exp3); + x3.applyTransform(sd::transform::Square, x3); + ASSERT_EQ(x3, exp3); - x5.applyTransform(sd::transform::Square, result5); - ASSERT_EQ(result5, exp5); + x5.applyTransform(sd::transform::Square, result5); + ASSERT_EQ(result5, exp5); } ////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_applyTransformBool_test1) { - if (!Environment::getInstance().isExperimentalBuild()) - return; + if (!Environment::getInstance().isExperimentalBuild()) return; - NDArray x1('c', {2,2}, {0, 1, 2, 3}, sd::DataType::INT64); - NDArray x2('c', {2,2}, {0, 1.5, 2.5, 3.5}, sd::DataType::HALF); - NDArray x3('c', {2,2}, {0, 1.5, 2.5, 3.5}, sd::DataType::DOUBLE); - NDArray x4('c', {2,2}, {0, 1, 0, 1}, sd::DataType::BOOL); - NDArray x5('c', {2,3}, {0, 1.5, 2.5, 3.5, 4.5, 5.5}, sd::DataType::DOUBLE); + NDArray x1('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::INT64); + NDArray x2('c', {2, 2}, {0, 1.5, 2.5, 3.5}, sd::DataType::HALF); + NDArray x3('c', {2, 2}, {0, 1.5, 2.5, 3.5}, sd::DataType::DOUBLE); + NDArray x4('c', {2, 2}, {0, 1, 0, 1}, sd::DataType::BOOL); + NDArray x5('c', {2, 3}, {0, 1.5, 2.5, 3.5, 4.5, 5.5}, sd::DataType::DOUBLE); - NDArray exp1('c', {2,2}, {0, 0, 0, 1}, sd::DataType::BOOL); - NDArray exp2('c', {2,2}, {0, 1, 0, 0}, sd::DataType::BOOL); - NDArray exp3('c', {3,2}, {0, 0, 0, 0, 0, 1}, sd::DataType::BOOL); + NDArray exp1('c', {2, 2}, {0, 0, 0, 1}, sd::DataType::BOOL); + NDArray exp2('c', {2, 2}, {0, 1, 0, 0}, sd::DataType::BOOL); + NDArray exp3('c', {3, 2}, {0, 0, 0, 0, 0, 1}, sd::DataType::BOOL); - NDArray result1('c', {2,2}, sd::DataType::BOOL); - NDArray result2('c', {3,2}, sd::DataType::BOOL); + NDArray result1('c', {2, 2}, sd::DataType::BOOL); + NDArray result2('c', {3, 2}, sd::DataType::BOOL); - /* - x1.applyTransform(sd::transform::IsMax, result1); - ASSERT_EQ(result1, exp1); + /* + x1.applyTransform(sd::transform::IsMax, result1); + ASSERT_EQ(result1, exp1); - x2.applyTransform(sd::transform::IsMax, result1); - ASSERT_EQ(result1, exp1); + x2.applyTransform(sd::transform::IsMax, result1); + ASSERT_EQ(result1, exp1); - x3.applyTransform(sd::transform::IsMax, result1); - ASSERT_EQ(result1, exp1); + x3.applyTransform(sd::transform::IsMax, result1); + ASSERT_EQ(result1, exp1); - x4.applyTransform(sd::transform::IsMax, result1); - ASSERT_EQ(result1, exp2); + x4.applyTransform(sd::transform::IsMax, result1); + ASSERT_EQ(result1, exp2); - x5.applyTransform(sd::transform::IsMax, result2); - ASSERT_EQ(result2, exp3); - */ + x5.applyTransform(sd::transform::IsMax, result2); + ASSERT_EQ(result2, exp3); + */ } ////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_applyTransformStrict_test1) { - if (!Environment::getInstance().isExperimentalBuild()) - return; + if (!Environment::getInstance().isExperimentalBuild()) return; - NDArray x1('c', {2,2}, {0, 1, 2, 3}, sd::DataType::HALF); - NDArray x2('c', {2,2}, {0, 1, 2, 3}, sd::DataType::FLOAT32); - NDArray x3('c', {2,2}, {0, 1, 2, 3}, sd::DataType::DOUBLE); - NDArray x4('c', {2,3}, {0, 1, 2, 3, 4, 5}, sd::DataType::DOUBLE); + NDArray x1('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::HALF); + NDArray x2('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::FLOAT32); + NDArray x3('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::DOUBLE); + NDArray x4('c', {2, 3}, {0, 1, 2, 3, 4, 5}, sd::DataType::DOUBLE); - NDArray exp1('c', {2,2}, {0, 3, 12, 27}, sd::DataType::HALF); - NDArray exp2('c', {2,2}, {0, 3, 12, 27}, sd::DataType::FLOAT32); - NDArray exp3('c', {2,2}, {0, 3, 12, 27}, sd::DataType::DOUBLE); - NDArray exp4('c', {3,2}, {0, 3, 12, 27, 48, 75}, sd::DataType::DOUBLE); - NDArray exp5('c', {2,3}, {0, 3, 12, 27, 48, 75}, sd::DataType::DOUBLE); + NDArray exp1('c', {2, 2}, {0, 3, 12, 27}, sd::DataType::HALF); + NDArray exp2('c', {2, 2}, {0, 3, 12, 27}, sd::DataType::FLOAT32); + NDArray exp3('c', {2, 2}, {0, 3, 12, 27}, sd::DataType::DOUBLE); + NDArray exp4('c', {3, 2}, {0, 3, 12, 27, 48, 75}, sd::DataType::DOUBLE); + NDArray exp5('c', {2, 3}, {0, 3, 12, 27, 48, 75}, sd::DataType::DOUBLE); - NDArray result1('c', {2,2}, sd::DataType::HALF); - NDArray result2('c', {2,2}, sd::DataType::FLOAT32); - NDArray result3('c', {2,2}, sd::DataType::DOUBLE); - NDArray result4('c', {3,2}, sd::DataType::DOUBLE); + NDArray result1('c', {2, 2}, sd::DataType::HALF); + NDArray result2('c', {2, 2}, sd::DataType::FLOAT32); + NDArray result3('c', {2, 2}, sd::DataType::DOUBLE); + NDArray result4('c', {3, 2}, sd::DataType::DOUBLE); - x1.applyTransform(sd::transform::CubeDerivative, result1); - ASSERT_EQ(result1, exp1); + x1.applyTransform(sd::transform::CubeDerivative, result1); + ASSERT_EQ(result1, exp1); - x2.applyTransform(sd::transform::CubeDerivative, result2); - ASSERT_EQ(result2, exp2); + x2.applyTransform(sd::transform::CubeDerivative, result2); + ASSERT_EQ(result2, exp2); - x3.applyTransform(sd::transform::CubeDerivative, result3); - ASSERT_EQ(result3, exp3); + x3.applyTransform(sd::transform::CubeDerivative, result3); + ASSERT_EQ(result3, exp3); - x4.applyTransform(sd::transform::CubeDerivative, result4); - ASSERT_EQ(result4, exp4); + x4.applyTransform(sd::transform::CubeDerivative, result4); + ASSERT_EQ(result4, exp4); - x1.applyTransform(sd::transform::CubeDerivative, x1); - ASSERT_EQ(x1, exp1); + x1.applyTransform(sd::transform::CubeDerivative, x1); + ASSERT_EQ(x1, exp1); - x2.applyTransform(sd::transform::CubeDerivative, x2); - ASSERT_EQ(x2, exp2); + x2.applyTransform(sd::transform::CubeDerivative, x2); + ASSERT_EQ(x2, exp2); - x3.applyTransform(sd::transform::CubeDerivative, x3); - ASSERT_EQ(x3, exp3); + x3.applyTransform(sd::transform::CubeDerivative, x3); + ASSERT_EQ(x3, exp3); - x4.applyTransform(sd::transform::CubeDerivative, x4); - ASSERT_EQ(x4, exp5); + x4.applyTransform(sd::transform::CubeDerivative, x4); + ASSERT_EQ(x4, exp5); } ////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_applyPairwiseTransform_test1) { - if (!Environment::getInstance().isExperimentalBuild()) - return; + if (!Environment::getInstance().isExperimentalBuild()) return; - NDArray x1('c', {2,3}, {0, 1, 2, 3, 4, 5}, sd::DataType::INT32); - NDArray x2('c', {2,3}, {0, 1, 2, 3, 4, 5}, sd::DataType::FLOAT32); - NDArray x3('c', {2,3}, {0, 1, 0, 1, 0, 0}, sd::DataType::BOOL); - NDArray x4('c', {3,2}, {0.5, 1.5, 2.5, 3.5, 4.5, 0}, sd::DataType::DOUBLE); - NDArray x5('c', {3,2}, sd::DataType::INT32); - NDArray x6('c', {2,3}, sd::DataType::DOUBLE); + NDArray x1('c', {2, 3}, {0, 1, 2, 3, 4, 5}, sd::DataType::INT32); + NDArray x2('c', {2, 3}, {0, 1, 2, 3, 4, 5}, sd::DataType::FLOAT32); + NDArray x3('c', {2, 3}, {0, 1, 0, 1, 0, 0}, sd::DataType::BOOL); + NDArray x4('c', {3, 2}, {0.5, 1.5, 2.5, 3.5, 4.5, 0}, sd::DataType::DOUBLE); + NDArray x5('c', {3, 2}, sd::DataType::INT32); + NDArray x6('c', {2, 3}, sd::DataType::DOUBLE); - NDArray exp1('c', {2,3}, {0, 2, 4, 6, 8, 5}, sd::DataType::INT32); - NDArray exp2('c', {2,3}, {0.5, 2.5, 4.5, 6.5, 8.5, 5.}, sd::DataType::FLOAT32); - NDArray exp3('c', {2,3}, {1, 1, 1, 1, 1, 0}, sd::DataType::BOOL); - NDArray exp4('c', {2,3}, {0.5, 2.5, 4.5, 6.5, 8.5, 5.}, sd::DataType::DOUBLE); - NDArray exp5('c', {3,2}, {0, 2, 4, 6, 8, 5}, sd::DataType::INT32); + NDArray exp1('c', {2, 3}, {0, 2, 4, 6, 8, 5}, sd::DataType::INT32); + NDArray exp2('c', {2, 3}, {0.5, 2.5, 4.5, 6.5, 8.5, 5.}, + sd::DataType::FLOAT32); + NDArray exp3('c', {2, 3}, {1, 1, 1, 1, 1, 0}, sd::DataType::BOOL); + NDArray exp4('c', {2, 3}, {0.5, 2.5, 4.5, 6.5, 8.5, 5.}, + sd::DataType::DOUBLE); + NDArray exp5('c', {3, 2}, {0, 2, 4, 6, 8, 5}, sd::DataType::INT32); - x1.applyPairwiseTransform(sd::pairwise::Add, x4, x5); - ASSERT_EQ(x5, exp5); + x1.applyPairwiseTransform(sd::pairwise::Add, x4, x5); + ASSERT_EQ(x5, exp5); - x1.applyPairwiseTransform(sd::pairwise::Add, x4, x6); - ASSERT_EQ(x6, exp4); + x1.applyPairwiseTransform(sd::pairwise::Add, x4, x6); + ASSERT_EQ(x6, exp4); - x1.applyPairwiseTransform(sd::pairwise::Add, x4); - ASSERT_EQ(x1, exp1); + x1.applyPairwiseTransform(sd::pairwise::Add, x4); + ASSERT_EQ(x1, exp1); - x2.applyPairwiseTransform(sd::pairwise::Add, x4); - ASSERT_EQ(x2, exp2); + x2.applyPairwiseTransform(sd::pairwise::Add, x4); + ASSERT_EQ(x2, exp2); - x3.applyPairwiseTransform(sd::pairwise::Add, x4); - ASSERT_EQ(x3, exp3); + x3.applyPairwiseTransform(sd::pairwise::Add, x4); + ASSERT_EQ(x3, exp3); } ////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_applyPairwiseTransform_test2) { - if (!Environment::getInstance().isExperimentalBuild()) - return; + if (!Environment::getInstance().isExperimentalBuild()) return; - NDArray x1('c', {2,3}, {1, 1, 2, 3, 4, 5}, sd::DataType::INT32); - NDArray x2('c', {3,2}, {1, 0, 2, 0, 4, 0}, sd::DataType::INT32); - NDArray x3('c', {3,2}, {0.5, 1.5, 2.5, 3, 4.5, 0}, sd::DataType::DOUBLE); - NDArray x4('c', {2,3}, {0.5, 1., 2.5, 3, 4., 0}, sd::DataType::DOUBLE); - NDArray x5('c', {3,2}, {0, 1, 0, 1, 0, 1}, sd::DataType::BOOL); - NDArray x6('c', {2,3}, {1, 1, 1, 0, 1, 0}, sd::DataType::BOOL); + NDArray x1('c', {2, 3}, {1, 1, 2, 3, 4, 5}, sd::DataType::INT32); + NDArray x2('c', {3, 2}, {1, 0, 2, 0, 4, 0}, sd::DataType::INT32); + NDArray x3('c', {3, 2}, {0.5, 1.5, 2.5, 3, 4.5, 0}, sd::DataType::DOUBLE); + NDArray x4('c', {2, 3}, {0.5, 1., 2.5, 3, 4., 0}, sd::DataType::DOUBLE); + NDArray x5('c', {3, 2}, {0, 1, 0, 1, 0, 1}, sd::DataType::BOOL); + NDArray x6('c', {2, 3}, {1, 1, 1, 0, 1, 0}, sd::DataType::BOOL); - NDArray x7('c', {3,2}, sd::DataType::BOOL); - NDArray x8('c', {2,3}, sd::DataType::BOOL); + NDArray x7('c', {3, 2}, sd::DataType::BOOL); + NDArray x8('c', {2, 3}, sd::DataType::BOOL); - NDArray exp1('c', {3,2}, {1, 0, 1, 0, 1, 0}, sd::DataType::BOOL); - NDArray exp2('c', {2,3}, {1, 0, 1, 1, 0, 1}, sd::DataType::BOOL); - NDArray exp3('c', {2,3}, {0, 1, 0, 0, 0, 0}, sd::DataType::BOOL); + NDArray exp1('c', {3, 2}, {1, 0, 1, 0, 1, 0}, sd::DataType::BOOL); + NDArray exp2('c', {2, 3}, {1, 0, 1, 1, 0, 1}, sd::DataType::BOOL); + NDArray exp3('c', {2, 3}, {0, 1, 0, 0, 0, 0}, sd::DataType::BOOL); - x1.applyPairwiseTransform(sd::pairwise::EqualTo, x2, x7); - ASSERT_EQ(x7, exp1); + x1.applyPairwiseTransform(sd::pairwise::EqualTo, x2, x7); + ASSERT_EQ(x7, exp1); - x3.applyPairwiseTransform(sd::pairwise::EqualTo, x4, x8); - ASSERT_EQ(x8, exp2); + x3.applyPairwiseTransform(sd::pairwise::EqualTo, x4, x8); + ASSERT_EQ(x8, exp2); - x5.applyPairwiseTransform(sd::pairwise::EqualTo, x6, x8); - ASSERT_EQ(x8, exp3); + x5.applyPairwiseTransform(sd::pairwise::EqualTo, x6, x8); + ASSERT_EQ(x8, exp3); } ////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_applyBroadcast_test1) { - if (!Environment::getInstance().isExperimentalBuild()) - return; + if (!Environment::getInstance().isExperimentalBuild()) return; - NDArray x1('c', {2,3}, {10, 20, 30, 40, 50, 60}, sd::DataType::INT32); - NDArray x2('c', {2}, {1, 2}, sd::DataType::INT64); - NDArray x3('c', {2,3}, sd::DataType::INT32); - NDArray x4('c', {2}, {1, 2}, sd::DataType::FLOAT32); - NDArray x5('c', {2,3}, sd::DataType::FLOAT32); - NDArray x6('c', {2}, {1, 1}, sd::DataType::BOOL); + NDArray x1('c', {2, 3}, {10, 20, 30, 40, 50, 60}, sd::DataType::INT32); + NDArray x2('c', {2}, {1, 2}, sd::DataType::INT64); + NDArray x3('c', {2, 3}, sd::DataType::INT32); + NDArray x4('c', {2}, {1, 2}, sd::DataType::FLOAT32); + NDArray x5('c', {2, 3}, sd::DataType::FLOAT32); + NDArray x6('c', {2}, {1, 1}, sd::DataType::BOOL); - NDArray exp1('c', {2,3}, {11, 21, 31, 42, 52, 62}, sd::DataType::INT32); - NDArray exp2('c', {2,3}, {11, 21, 31, 42, 52, 62}, sd::DataType::FLOAT32); - NDArray exp3('c', {2,3}, {11, 21, 31, 41, 51, 61}, sd::DataType::INT32); + NDArray exp1('c', {2, 3}, {11, 21, 31, 42, 52, 62}, sd::DataType::INT32); + NDArray exp2('c', {2, 3}, {11, 21, 31, 42, 52, 62}, sd::DataType::FLOAT32); + NDArray exp3('c', {2, 3}, {11, 21, 31, 41, 51, 61}, sd::DataType::INT32); - x1.applyBroadcast(sd::broadcast::Add, {0}, x2, x3); - ASSERT_EQ(x3, exp1); + x1.applyBroadcast(sd::broadcast::Add, {0}, x2, x3); + ASSERT_EQ(x3, exp1); - x1.applyBroadcast(sd::broadcast::Add, {0}, x4, x5); - ASSERT_EQ(x5, exp2); + x1.applyBroadcast(sd::broadcast::Add, {0}, x4, x5); + ASSERT_EQ(x5, exp2); - x1.applyBroadcast(sd::broadcast::Add, {0}, x6, x3); - ASSERT_EQ(x3, exp3); + x1.applyBroadcast(sd::broadcast::Add, {0}, x6, x3); + ASSERT_EQ(x3, exp3); } ////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_applyBroadcast_test2) { + NDArray x1('c', {2, 3}, {10, 20, 30, 40, 50, 60}, sd::DataType::INT32); + NDArray x2('c', {2}, {10, 60}, sd::DataType::INT32); + NDArray x3('c', {2, 3}, sd::DataType::BOOL); - NDArray x1('c', {2,3}, {10, 20, 30, 40, 50, 60}, sd::DataType::INT32); - NDArray x2('c', {2}, {10, 60}, sd::DataType::INT32); - NDArray x3('c', {2,3}, sd::DataType::BOOL); - - NDArray x4('c', {2,3}, {0, 0, 0, 0, 0, 1}, sd::DataType::BOOL); - NDArray x5('c', {2}, {0, 1}, sd::DataType::BOOL); + NDArray x4('c', {2, 3}, {0, 0, 0, 0, 0, 1}, sd::DataType::BOOL); + NDArray x5('c', {2}, {0, 1}, sd::DataType::BOOL); - NDArray exp1('c', {2,3}, {1, 0, 0, 0, 0, 1}, sd::DataType::BOOL); - NDArray exp2('c', {2,3}, {1, 1, 1, 0, 0, 1}, sd::DataType::BOOL); + NDArray exp1('c', {2, 3}, {1, 0, 0, 0, 0, 1}, sd::DataType::BOOL); + NDArray exp2('c', {2, 3}, {1, 1, 1, 0, 0, 1}, sd::DataType::BOOL); - x1.applyBroadcast(sd::broadcast::EqualTo, {0}, x2, x3); - ASSERT_EQ(x3, exp1); + x1.applyBroadcast(sd::broadcast::EqualTo, {0}, x2, x3); + ASSERT_EQ(x3, exp1); - x4.applyBroadcast(sd::broadcast::EqualTo, {0}, x5, x3); - ASSERT_EQ(x3, exp2); + x4.applyBroadcast(sd::broadcast::EqualTo, {0}, x5, x3); + ASSERT_EQ(x3, exp2); } ////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_applyTrueBroadcast_test1) { - if (!Environment::getInstance().isExperimentalBuild()) - return; + if (!Environment::getInstance().isExperimentalBuild()) return; - NDArray x1('c', {2,2}, {10, 20, 30, 40}, sd::DataType::INT32); - NDArray x2('c', {2}, {1, 2}, sd::DataType::HALF); - NDArray x3('c', {2,2}, sd::DataType::HALF); + NDArray x1('c', {2, 2}, {10, 20, 30, 40}, sd::DataType::INT32); + NDArray x2('c', {2}, {1, 2}, sd::DataType::HALF); + NDArray x3('c', {2, 2}, sd::DataType::HALF); - NDArray x4('c', {2}, {1, 2}, sd::DataType::INT64); - NDArray x5('c', {2,2}, sd::DataType::INT32); + NDArray x4('c', {2}, {1, 2}, sd::DataType::INT64); + NDArray x5('c', {2, 2}, sd::DataType::INT32); - NDArray x6('c', {2,2}, {0, 1, 0, 1}, sd::DataType::BOOL); - NDArray x7('c', {2}, {1, 2}, sd::DataType::INT64); - NDArray x8('c', {2,2}, sd::DataType::BOOL); + NDArray x6('c', {2, 2}, {0, 1, 0, 1}, sd::DataType::BOOL); + NDArray x7('c', {2}, {1, 2}, sd::DataType::INT64); + NDArray x8('c', {2, 2}, sd::DataType::BOOL); - NDArray x13('c', {0}, std::vector{3}, sd::DataType::INT64); - NDArray x14('c', {0}, std::vector{1.5}, sd::DataType::DOUBLE); - NDArray x15(sd::DataType::DOUBLE); - NDArray x16('c', {2,2}, sd::DataType::DOUBLE); + NDArray x13('c', {0}, std::vector{3}, sd::DataType::INT64); + NDArray x14('c', {0}, std::vector{1.5}, sd::DataType::DOUBLE); + NDArray x15(sd::DataType::DOUBLE); + NDArray x16('c', {2, 2}, sd::DataType::DOUBLE); - NDArray exp1('c', {2,2}, {11, 22, 31, 42}, sd::DataType::HALF); - NDArray exp2('c', {2,2}, {11, 22, 31, 42}, sd::DataType::INT32); - NDArray exp3('c', {2,2}, {1, 1, 1, 1}, sd::DataType::BOOL); - NDArray exp4('c', {0}, std::vector{4.5}, sd::DataType::DOUBLE); - NDArray exp5('c', {2,2}, {11.5, 21.5, 31.5, 41.5}, sd::DataType::DOUBLE); + NDArray exp1('c', {2, 2}, {11, 22, 31, 42}, sd::DataType::HALF); + NDArray exp2('c', {2, 2}, {11, 22, 31, 42}, sd::DataType::INT32); + NDArray exp3('c', {2, 2}, {1, 1, 1, 1}, sd::DataType::BOOL); + NDArray exp4('c', {0}, std::vector{4.5}, sd::DataType::DOUBLE); + NDArray exp5('c', {2, 2}, {11.5, 21.5, 31.5, 41.5}, sd::DataType::DOUBLE); - x1.applyTrueBroadcast(sd::BroadcastOpsTuple::Add(), x2, x3); - ASSERT_EQ(x3, exp1); + x1.applyTrueBroadcast(sd::BroadcastOpsTuple::Add(), x2, x3); + ASSERT_EQ(x3, exp1); - x1.applyTrueBroadcast(sd::BroadcastOpsTuple::Add(), x4, x5); - ASSERT_EQ(x5, exp2); + x1.applyTrueBroadcast(sd::BroadcastOpsTuple::Add(), x4, x5); + ASSERT_EQ(x5, exp2); - x6.applyTrueBroadcast(sd::BroadcastOpsTuple::Add(), x7, x8); - ASSERT_EQ(x8, exp3); + x6.applyTrueBroadcast(sd::BroadcastOpsTuple::Add(), x7, x8); + ASSERT_EQ(x8, exp3); - auto x9 = x1.applyTrueBroadcast(sd::BroadcastOpsTuple::Add(), x2); - ASSERT_EQ(x9, exp1); + auto x9 = x1.applyTrueBroadcast(sd::BroadcastOpsTuple::Add(), x2); + ASSERT_EQ(x9, exp1); - auto x10 = x1.applyTrueBroadcast(sd::BroadcastOpsTuple::Add(), x4); - ASSERT_EQ(x10, exp2); + auto x10 = x1.applyTrueBroadcast(sd::BroadcastOpsTuple::Add(), x4); + ASSERT_EQ(x10, exp2); - auto x11 = x6.applyTrueBroadcast(sd::BroadcastOpsTuple::Add(), x7); - ASSERT_EQ(x11, exp3); + auto x11 = x6.applyTrueBroadcast(sd::BroadcastOpsTuple::Add(), x7); + ASSERT_EQ(x11, exp3); - auto x12 = x1.applyTrueBroadcast(sd::BroadcastOpsTuple::Add(), x2); - ASSERT_EQ(x12, exp1); + auto x12 = x1.applyTrueBroadcast(sd::BroadcastOpsTuple::Add(), x2); + ASSERT_EQ(x12, exp1); - x13.applyTrueBroadcast(sd::BroadcastOpsTuple::Add(), x14, x15); - ASSERT_EQ(x15, exp4); + x13.applyTrueBroadcast(sd::BroadcastOpsTuple::Add(), x14, x15); + ASSERT_EQ(x15, exp4); - x1.applyTrueBroadcast(sd::BroadcastOpsTuple::Add(), x14, x16); - ASSERT_EQ(x16, exp5); - - x14.applyTrueBroadcast(sd::BroadcastOpsTuple::Add(), x1, x16); - ASSERT_EQ(x16, exp5); + x1.applyTrueBroadcast(sd::BroadcastOpsTuple::Add(), x14, x16); + ASSERT_EQ(x16, exp5); + x14.applyTrueBroadcast(sd::BroadcastOpsTuple::Add(), x1, x16); + ASSERT_EQ(x16, exp5); } ////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_applyTrueBroadcast_test2) { - if (!Environment::getInstance().isExperimentalBuild()) - return; - - NDArray x1('c', {2,2}, {10, 20, 30, 40}, sd::DataType::HALF); - NDArray x2('c', {2}, {10, 40}, sd::DataType::HALF); - NDArray x3('c', {2,2}, sd::DataType::BOOL); - NDArray x4('c', {0}, std::vector{10}, sd::DataType::HALF); - NDArray x5('c', {0}, std::vector{20}, sd::DataType::HALF); - NDArray x6(sd::DataType::BOOL); - - NDArray exp1('c', {2,2}, {1, 0, 0, 1}, sd::DataType::BOOL); - NDArray exp2('c', {2,2}, {1, 0, 0, 0}, sd::DataType::BOOL); - NDArray exp3('c', {0}, std::vector{0}, sd::DataType::BOOL); - - x1.applyTrueBroadcast(BroadcastBoolOpsTuple(sd::scalar::EqualTo, sd::pairwise::EqualTo, sd::broadcast::EqualTo), x2, x3); - ASSERT_EQ(x3, exp1); - - x1.applyTrueBroadcast(BroadcastBoolOpsTuple(sd::scalar::EqualTo, sd::pairwise::EqualTo, sd::broadcast::EqualTo), x4, x3); - ASSERT_EQ(x3, exp2); - - x4.applyTrueBroadcast(BroadcastBoolOpsTuple(sd::scalar::EqualTo, sd::pairwise::EqualTo, sd::broadcast::EqualTo), x1, x3); - ASSERT_EQ(x3, exp2); - - x5.applyTrueBroadcast(BroadcastBoolOpsTuple(sd::scalar::EqualTo, sd::pairwise::EqualTo, sd::broadcast::EqualTo), x4, x6); - ASSERT_EQ(x6, exp3); + if (!Environment::getInstance().isExperimentalBuild()) return; + + NDArray x1('c', {2, 2}, {10, 20, 30, 40}, sd::DataType::HALF); + NDArray x2('c', {2}, {10, 40}, sd::DataType::HALF); + NDArray x3('c', {2, 2}, sd::DataType::BOOL); + NDArray x4('c', {0}, std::vector{10}, sd::DataType::HALF); + NDArray x5('c', {0}, std::vector{20}, sd::DataType::HALF); + NDArray x6(sd::DataType::BOOL); + + NDArray exp1('c', {2, 2}, {1, 0, 0, 1}, sd::DataType::BOOL); + NDArray exp2('c', {2, 2}, {1, 0, 0, 0}, sd::DataType::BOOL); + NDArray exp3('c', {0}, std::vector{0}, sd::DataType::BOOL); + + x1.applyTrueBroadcast( + BroadcastBoolOpsTuple(sd::scalar::EqualTo, sd::pairwise::EqualTo, + sd::broadcast::EqualTo), + x2, x3); + ASSERT_EQ(x3, exp1); + + x1.applyTrueBroadcast( + BroadcastBoolOpsTuple(sd::scalar::EqualTo, sd::pairwise::EqualTo, + sd::broadcast::EqualTo), + x4, x3); + ASSERT_EQ(x3, exp2); + + x4.applyTrueBroadcast( + BroadcastBoolOpsTuple(sd::scalar::EqualTo, sd::pairwise::EqualTo, + sd::broadcast::EqualTo), + x1, x3); + ASSERT_EQ(x3, exp2); + + x5.applyTrueBroadcast( + BroadcastBoolOpsTuple(sd::scalar::EqualTo, sd::pairwise::EqualTo, + sd::broadcast::EqualTo), + x4, x6); + ASSERT_EQ(x6, exp3); } ////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_applyScalar_test1) { - if (!Environment::getInstance().isExperimentalBuild()) - return; + if (!Environment::getInstance().isExperimentalBuild()) return; - NDArray x1('c', {2,2}, {0, 1, 2, 3}, sd::DataType::INT64); - NDArray x2('c', {2,2}, {0, 1.5, 2.5, 3.5}, sd::DataType::FLOAT32); - NDArray x3('c', {2,2}, sd::DataType::DOUBLE); - NDArray x4('c', {2,2}, {0, 1, 0, 1}, sd::DataType::BOOL); + NDArray x1('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::INT64); + NDArray x2('c', {2, 2}, {0, 1.5, 2.5, 3.5}, sd::DataType::FLOAT32); + NDArray x3('c', {2, 2}, sd::DataType::DOUBLE); + NDArray x4('c', {2, 2}, {0, 1, 0, 1}, sd::DataType::BOOL); - NDArray exp1('c', {2,2}, {1, 2, 3, 4}, sd::DataType::INT64); - NDArray exp2('c', {2,2}, {1.5, 2.5, 3.5, 4.5}, sd::DataType::DOUBLE); - NDArray exp3('c', {2,2}, {0.1, 1.6, 2.6, 3.6}, sd::DataType::FLOAT32); - NDArray exp4('c', {2,2}, {1.1, 2.1, 1.1, 2.1}, sd::DataType::DOUBLE); - NDArray exp5('c', {2,2}, {1, 1, 1, 1}, sd::DataType::BOOL); + NDArray exp1('c', {2, 2}, {1, 2, 3, 4}, sd::DataType::INT64); + NDArray exp2('c', {2, 2}, {1.5, 2.5, 3.5, 4.5}, sd::DataType::DOUBLE); + NDArray exp3('c', {2, 2}, {0.1, 1.6, 2.6, 3.6}, sd::DataType::FLOAT32); + NDArray exp4('c', {2, 2}, {1.1, 2.1, 1.1, 2.1}, sd::DataType::DOUBLE); + NDArray exp5('c', {2, 2}, {1, 1, 1, 1}, sd::DataType::BOOL); - x1.applyScalar(sd::scalar::Add, 1, x1); - ASSERT_EQ(x1, exp1); + x1.applyScalar(sd::scalar::Add, 1, x1); + ASSERT_EQ(x1, exp1); - x1.applyScalar(sd::scalar::Add, 0.5, x3); - ASSERT_EQ(x3, exp2); + x1.applyScalar(sd::scalar::Add, 0.5, x3); + ASSERT_EQ(x3, exp2); - x2.applyScalar(sd::scalar::Add, 0.1, x2); - ASSERT_EQ(x2, exp3); + x2.applyScalar(sd::scalar::Add, 0.1, x2); + ASSERT_EQ(x2, exp3); - x4.applyScalar(sd::scalar::Add, 1.1, x3); - ASSERT_EQ(x3, exp4); + x4.applyScalar(sd::scalar::Add, 1.1, x3); + ASSERT_EQ(x3, exp4); - x4.applyScalar(sd::scalar::Add, 1, x4); - ASSERT_EQ(x4, exp5); + x4.applyScalar(sd::scalar::Add, 1, x4); + ASSERT_EQ(x4, exp5); } ////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_applyScalar_test2) { + NDArray x1('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::INT64); + NDArray x2('c', {2, 2}, {0, 1.5, 2.5, 3.5}, sd::DataType::FLOAT32); + NDArray x3('c', {2, 2}, {0, 1, 1, 0}, sd::DataType::BOOL); + NDArray x4('c', {2, 2}, sd::DataType::BOOL); - NDArray x1('c', {2,2}, {0, 1, 2, 3}, sd::DataType::INT64); - NDArray x2('c', {2,2}, {0, 1.5, 2.5, 3.5}, sd::DataType::FLOAT32); - NDArray x3('c', {2,2}, {0, 1, 1, 0}, sd::DataType::BOOL); - NDArray x4('c', {2,2}, sd::DataType::BOOL); - - - NDArray exp1('c', {2,2}, {0, 1, 0, 0}, sd::DataType::BOOL); - NDArray exp2('c', {2,2}, {0, 1, 1, 0}, sd::DataType::BOOL); - - x1.applyScalar(sd::scalar::EqualTo, 1, x4); - ASSERT_EQ(x4, exp1); + NDArray exp1('c', {2, 2}, {0, 1, 0, 0}, sd::DataType::BOOL); + NDArray exp2('c', {2, 2}, {0, 1, 1, 0}, sd::DataType::BOOL); - x2.applyScalar(sd::scalar::EqualTo, 1.5, x4); - ASSERT_EQ(x4, exp1); + x1.applyScalar(sd::scalar::EqualTo, 1, x4); + ASSERT_EQ(x4, exp1); - x3.applyScalar(sd::scalar::EqualTo, true, x4); - ASSERT_EQ(x4, exp2); + x2.applyScalar(sd::scalar::EqualTo, 1.5, x4); + ASSERT_EQ(x4, exp1); + x3.applyScalar(sd::scalar::EqualTo, true, x4); + ASSERT_EQ(x4, exp2); } #ifndef __CUDABLAS__ ////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_applyLambda_test1) { + NDArray x1('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::DOUBLE); + NDArray x2('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::INT64); + NDArray x3('c', {2, 2}, {0, 1.5, 2.5, 3.5}, sd::DataType::FLOAT32); + NDArray x4('c', {2, 2}, sd::DataType::DOUBLE); + NDArray x5('c', {2, 2}, {0, 1.5, 2.5, 3.5}, sd::DataType::FLOAT32); + NDArray x6('c', {2, 2}, {0, -1, -1, 0.1}, sd::DataType::BOOL); + NDArray x7('c', {2, 2}, sd::DataType::BOOL); - NDArray x1('c', {2,2}, {0, 1, 2, 3}, sd::DataType::DOUBLE); - NDArray x2('c', {2,2}, {0, 1, 2, 3}, sd::DataType::INT64); - NDArray x3('c', {2,2}, {0, 1.5, 2.5, 3.5}, sd::DataType::FLOAT32); - NDArray x4('c', {2,2}, sd::DataType::DOUBLE); - NDArray x5('c', {2,2}, {0, 1.5, 2.5, 3.5}, sd::DataType::FLOAT32); - NDArray x6('c', {2,2}, {0, -1, -1, 0.1}, sd::DataType::BOOL); - NDArray x7('c', {2,2}, sd::DataType::BOOL); + const float item1 = 0.1; + const double item2 = 0.1; + auto func1 = [=](float elem) { return elem + item1; }; + auto func2 = [=](int elem) { return elem + item1; }; + auto func3 = [=](int elem) { return elem + item2; }; + auto func4 = [=](double elem) { return elem + item1; }; + auto func5 = [=](float elem) { return elem - (int)1; }; - const float item1 = 0.1; - const double item2 = 0.1; - auto func1 = [=](float elem) { return elem + item1; }; - auto func2 = [=](int elem) { return elem + item1; }; - auto func3 = [=](int elem) { return elem + item2; }; - auto func4 = [=](double elem) { return elem + item1; }; - auto func5 = [=](float elem) { return elem - (int)1; }; - - NDArray exp1('c', {2,2}, {0.1, 1.1, 2.1, 3.1}, sd::DataType::DOUBLE); - NDArray exp2('c', {2,2}, {0, 1, 2, 3}, sd::DataType::INT64); - NDArray exp3('c', {2,2}, {0.1, 1.1, 2.1, 3.1}, sd::DataType::FLOAT32); - NDArray exp4('c', {2,2}, {0.1, 1.6, 2.6, 3.6}, sd::DataType::FLOAT32); - NDArray exp5('c', {2,2}, {1, 0, 0, 0}, sd::DataType::BOOL); - - x1.applyLambda(func1, x4); - ASSERT_EQ(x4, exp1); + NDArray exp1('c', {2, 2}, {0.1, 1.1, 2.1, 3.1}, sd::DataType::DOUBLE); + NDArray exp2('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::INT64); + NDArray exp3('c', {2, 2}, {0.1, 1.1, 2.1, 3.1}, sd::DataType::FLOAT32); + NDArray exp4('c', {2, 2}, {0.1, 1.6, 2.6, 3.6}, sd::DataType::FLOAT32); + NDArray exp5('c', {2, 2}, {1, 0, 0, 0}, sd::DataType::BOOL); - x2.applyLambda(func1, x2); - ASSERT_EQ(x2, exp2); + x1.applyLambda(func1, x4); + ASSERT_EQ(x4, exp1); - x2.applyLambda(func2, x2); - ASSERT_EQ(x2, exp2); + x2.applyLambda(func1, x2); + ASSERT_EQ(x2, exp2); - x3.applyLambda(func3, x3); - ASSERT_EQ(x3, exp3); + x2.applyLambda(func2, x2); + ASSERT_EQ(x2, exp2); - x5.applyLambda(func4, x5); - // x5.printBuffer(); - ASSERT_EQ(x5, exp4); + x3.applyLambda(func3, x3); + ASSERT_EQ(x3, exp3); - x6.applyLambda(func5, x7); - ASSERT_EQ(x7, exp5); + x5.applyLambda(func4, x5); + // x5.printBuffer(); + ASSERT_EQ(x5, exp4); + + x6.applyLambda(func5, x7); + ASSERT_EQ(x7, exp5); } ////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_applyIndexedLambda_test1) { + NDArray x1('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::DOUBLE); + NDArray x2('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::INT64); + NDArray x3('c', {2, 2}, {0, 1.5, 2.5, 3.5}, sd::DataType::FLOAT32); + NDArray x4('c', {2, 2}, sd::DataType::DOUBLE); + NDArray x5('c', {2, 2}, {0, 1.5, 2.5, 3.5}, sd::DataType::FLOAT32); + NDArray x6('c', {2, 2}, {1, -1, -1, 0.1}, sd::DataType::BOOL); + NDArray x7('c', {2, 2}, sd::DataType::BOOL); - NDArray x1('c', {2,2}, {0, 1, 2, 3}, sd::DataType::DOUBLE); - NDArray x2('c', {2,2}, {0, 1, 2, 3}, sd::DataType::INT64); - NDArray x3('c', {2,2}, {0, 1.5, 2.5, 3.5}, sd::DataType::FLOAT32); - NDArray x4('c', {2,2}, sd::DataType::DOUBLE); - NDArray x5('c', {2,2}, {0, 1.5, 2.5, 3.5}, sd::DataType::FLOAT32); - NDArray x6('c', {2,2}, {1, -1, -1, 0.1}, sd::DataType::BOOL); - NDArray x7('c', {2,2}, sd::DataType::BOOL); + const float item1 = 0.1; + const double item2 = 0.1; + auto func1 = [=](Nd4jLong idx, float elem) { return idx + elem + item1; }; + auto func2 = [=](Nd4jLong idx, int elem) { return idx + elem + item1; }; + auto func3 = [=](Nd4jLong idx, int elem) { return idx + elem + item2; }; + auto func4 = [=](Nd4jLong idx, double elem) { return idx + elem + item1; }; + auto func5 = [=](Nd4jLong idx, float elem) { return idx + elem - (int)1; }; - const float item1 = 0.1; - const double item2 = 0.1; - auto func1 = [=](Nd4jLong idx, float elem) { return idx + elem + item1; }; - auto func2 = [=](Nd4jLong idx, int elem) { return idx + elem + item1; }; - auto func3 = [=](Nd4jLong idx, int elem) { return idx + elem + item2; }; - auto func4 = [=](Nd4jLong idx, double elem) { return idx + elem + item1; }; - auto func5 = [=](Nd4jLong idx, float elem) { return idx + elem - (int)1; }; - - NDArray exp1('c', {2,2}, {0.1, 2.1, 4.1, 6.1}, sd::DataType::DOUBLE); - NDArray exp2('c', {2,2}, {0, 2, 4, 6}, sd::DataType::INT64); - NDArray exp3('c', {2,2}, {0.1, 2.1, 4.1, 6.1}, sd::DataType::FLOAT32); - NDArray exp4('c', {2,2}, {0.1, 2.6, 4.6, 6.6}, sd::DataType::FLOAT32); - NDArray exp5('c', {2,2}, {0, 1, 1, 1}, sd::DataType::BOOL); - NDArray exp6('c', {2,2}, {0, 3, 6, 9}, sd::DataType::INT64); - - x1.applyIndexedLambda(func1, x4); - ASSERT_EQ(x4, exp1); + NDArray exp1('c', {2, 2}, {0.1, 2.1, 4.1, 6.1}, sd::DataType::DOUBLE); + NDArray exp2('c', {2, 2}, {0, 2, 4, 6}, sd::DataType::INT64); + NDArray exp3('c', {2, 2}, {0.1, 2.1, 4.1, 6.1}, sd::DataType::FLOAT32); + NDArray exp4('c', {2, 2}, {0.1, 2.6, 4.6, 6.6}, sd::DataType::FLOAT32); + NDArray exp5('c', {2, 2}, {0, 1, 1, 1}, sd::DataType::BOOL); + NDArray exp6('c', {2, 2}, {0, 3, 6, 9}, sd::DataType::INT64); - x2.applyIndexedLambda(func1, x2); - ASSERT_EQ(x2, exp2); + x1.applyIndexedLambda(func1, x4); + ASSERT_EQ(x4, exp1); - x2.applyIndexedLambda(func2, x2); - ASSERT_EQ(x2, exp6); + x2.applyIndexedLambda(func1, x2); + ASSERT_EQ(x2, exp2); - x3.applyIndexedLambda(func3, x3); - ASSERT_EQ(x3, exp3); + x2.applyIndexedLambda(func2, x2); + ASSERT_EQ(x2, exp6); + + x3.applyIndexedLambda(func3, x3); + ASSERT_EQ(x3, exp3); - x5.applyIndexedLambda(func4, x5); - ASSERT_EQ(x5, exp4); + x5.applyIndexedLambda(func4, x5); + ASSERT_EQ(x5, exp4); - x6.applyIndexedLambda(func5, x7); - ASSERT_EQ(x7, exp5); + x6.applyIndexedLambda(func5, x7); + ASSERT_EQ(x7, exp5); } ////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_applyPairwiseLambda_test1) { + NDArray x1('c', {2, 2}, {0., 1, 2, 3}, sd::DataType::DOUBLE); + NDArray x2('c', {2, 2}, {0., 1, 2, 3}, sd::DataType::INT64); + NDArray x3('c', {2, 2}, {0., 1.5, 2.5, 3.5}, sd::DataType::FLOAT32); + NDArray x4('c', {2, 2}, sd::DataType::DOUBLE); + NDArray x5('c', {2, 2}, {0, 1.5, 2.5, 3.5}, sd::DataType::FLOAT32); + NDArray x6('c', {2, 2}, {0.1, -1, -1, 0.1}, sd::DataType::BOOL); + NDArray x7('c', {2, 2}, sd::DataType::BOOL); + NDArray other1('c', {2, 2}, {0.1, 0.1, 0.1, 0.1}, sd::DataType::FLOAT32); + NDArray other2('c', {2, 2}, {0.1, 0.1, 0.1, 0.1}, sd::DataType::DOUBLE); + NDArray other3('c', {2, 2}, {0., -1, -2, -3}, sd::DataType::INT64); + NDArray other4('c', {2, 2}, {1, 0, 0.1, 0}, sd::DataType::BOOL); - NDArray x1('c', {2,2}, {0., 1, 2, 3}, sd::DataType::DOUBLE); - NDArray x2('c', {2,2}, {0., 1, 2, 3}, sd::DataType::INT64); - NDArray x3('c', {2,2}, {0., 1.5, 2.5, 3.5}, sd::DataType::FLOAT32); - NDArray x4('c', {2,2}, sd::DataType::DOUBLE); - NDArray x5('c', {2,2}, {0, 1.5, 2.5, 3.5}, sd::DataType::FLOAT32); - NDArray x6('c', {2,2}, {0.1, -1, -1, 0.1}, sd::DataType::BOOL); - NDArray x7('c', {2,2}, sd::DataType::BOOL); - NDArray other1('c', {2,2}, {0.1, 0.1, 0.1, 0.1}, sd::DataType::FLOAT32); - NDArray other2('c', {2,2}, {0.1, 0.1, 0.1, 0.1}, sd::DataType::DOUBLE); - NDArray other3('c', {2,2}, {0., -1, -2, -3}, sd::DataType::INT64); - NDArray other4('c', {2,2}, {1, 0, 0.1, 0}, sd::DataType::BOOL); - - auto func1 = [](float elem1, float elem2) { return elem1 + elem2; }; - auto func2 = [](int elem1, float elem2) { return elem1 + elem2; }; - auto func3 = [](int elem1, double elem2) { return elem1 + elem2; }; - auto func4 = [](double elem1, float elem2) { return elem1 + elem2; }; - auto func5 = [](float elem1, int elem2) { return elem1 - elem2; }; - - NDArray exp1('c', {2,2}, {0.1, 1.1, 2.1, 3.1}, sd::DataType::DOUBLE); - NDArray exp2('c', {2,2}, {0., 0, 0, 0}, sd::DataType::INT64); - NDArray exp3('c', {2,2}, {0.1, 1.1, 2.1, 3.1}, sd::DataType::FLOAT32); - NDArray exp4('c', {2,2}, {0.1, 1.6, 2.6, 3.6}, sd::DataType::FLOAT32); - NDArray exp5('c', {2,2}, {0., 1, 0, 1}, sd::DataType::BOOL); - - x1.applyPairwiseLambda(other2, func1, x4); - ASSERT_EQ(x4, exp1); + auto func1 = [](float elem1, float elem2) { return elem1 + elem2; }; + auto func2 = [](int elem1, float elem2) { return elem1 + elem2; }; + auto func3 = [](int elem1, double elem2) { return elem1 + elem2; }; + auto func4 = [](double elem1, float elem2) { return elem1 + elem2; }; + auto func5 = [](float elem1, int elem2) { return elem1 - elem2; }; - x2.applyPairwiseLambda(other3, func1, x2); - ASSERT_EQ(x2, exp2); + NDArray exp1('c', {2, 2}, {0.1, 1.1, 2.1, 3.1}, sd::DataType::DOUBLE); + NDArray exp2('c', {2, 2}, {0., 0, 0, 0}, sd::DataType::INT64); + NDArray exp3('c', {2, 2}, {0.1, 1.1, 2.1, 3.1}, sd::DataType::FLOAT32); + NDArray exp4('c', {2, 2}, {0.1, 1.6, 2.6, 3.6}, sd::DataType::FLOAT32); + NDArray exp5('c', {2, 2}, {0., 1, 0, 1}, sd::DataType::BOOL); - x2.applyPairwiseLambda(other3, func2, x2); - ASSERT_EQ(x2, other3); + x1.applyPairwiseLambda(other2, func1, x4); + ASSERT_EQ(x4, exp1); - x3.applyPairwiseLambda(other1, func3, x3); - ASSERT_EQ(x3, exp3); + x2.applyPairwiseLambda(other3, func1, x2); + ASSERT_EQ(x2, exp2); + + x2.applyPairwiseLambda(other3, func2, x2); + ASSERT_EQ(x2, other3); - x5.applyPairwiseLambda(other1, func4, x5); - ASSERT_EQ(x5, exp4); + x3.applyPairwiseLambda(other1, func3, x3); + ASSERT_EQ(x3, exp3); - x6.applyPairwiseLambda(other4, func5, x7); - ASSERT_EQ(x7, exp5); + x5.applyPairwiseLambda(other1, func4, x5); + ASSERT_EQ(x5, exp4); + + x6.applyPairwiseLambda(other4, func5, x7); + ASSERT_EQ(x7, exp5); } ////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_applyIndexedPairwiseLambda_test1) { - - NDArray x1('c', {2,2}, {0., 1, 2, 3}, sd::DataType::DOUBLE); - NDArray x2('c', {2,2}, {0., 1, 2, 3}, sd::DataType::INT64); - NDArray x3('c', {2,2}, {0., 1.5, 2.5, 3.5}, sd::DataType::FLOAT32); - NDArray x4('c', {2,2}, sd::DataType::DOUBLE); - NDArray x5('c', {2,2}, {0, 1.5, 2.5, 3.5}, sd::DataType::FLOAT32); - NDArray x6('c', {2,2}, {0.1, -1, -1, 0.1}, sd::DataType::BOOL); - NDArray x7('c', {2,2}, sd::DataType::BOOL); - NDArray other1('c', {2,2}, {0.1, 0.1, 0.1, 0.1}, sd::DataType::FLOAT32); - NDArray other2('c', {2,2}, {0.1, 0.1, 0.1, 0.1}, sd::DataType::DOUBLE); - NDArray other3('c', {2,2}, {0., -1, -2, -3}, sd::DataType::INT64); - NDArray other4('c', {2,2}, {1, 0, 0.1, 0}, sd::DataType::BOOL); - - auto func1 = [](Nd4jLong idx, float elem1, float elem2) { return elem1 + elem2 + idx; }; - auto func2 = [](Nd4jLong idx, int elem1, float elem2) { return elem1 + elem2 + idx; }; - auto func3 = [](Nd4jLong idx, int elem1, double elem2) { return elem1 + elem2 + idx; }; - auto func4 = [](Nd4jLong idx, double elem1, float elem2) { return elem1 + elem2 + idx; }; - auto func5 = [](Nd4jLong idx, float elem1, int elem2) { return elem1 - elem2 + idx; }; - - NDArray exp1('c', {2,2}, {0.1, 2.1, 4.1, 6.1}, sd::DataType::DOUBLE); - NDArray exp2('c', {2,2}, {0., 1, 2, 3}, sd::DataType::INT64); - NDArray exp3('c', {2,2}, {0.1, 2.1, 4.1, 6.1}, sd::DataType::FLOAT32); - NDArray exp4('c', {2,2}, {0.1, 2.6, 4.6, 6.6}, sd::DataType::FLOAT32); - NDArray exp5('c', {2,2}, {0., 1, 1, 1}, sd::DataType::BOOL); - - x1.applyIndexedPairwiseLambda(other2, func1, x4); - ASSERT_EQ(x4, exp1); - - x2.applyIndexedPairwiseLambda(other3, func1, x2); - ASSERT_EQ(x2, exp2); - - x2.applyIndexedPairwiseLambda(other3, func2, x2); - ASSERT_EQ(x2, exp2); - - x3.applyIndexedPairwiseLambda(other1, func3, x3); - ASSERT_EQ(x3, exp3); - - x5.applyIndexedPairwiseLambda(other1, func4, x5); - ASSERT_EQ(x5, exp4); - - x6.applyIndexedPairwiseLambda(other4, func5, x7); - ASSERT_EQ(x7, exp5); + NDArray x1('c', {2, 2}, {0., 1, 2, 3}, sd::DataType::DOUBLE); + NDArray x2('c', {2, 2}, {0., 1, 2, 3}, sd::DataType::INT64); + NDArray x3('c', {2, 2}, {0., 1.5, 2.5, 3.5}, sd::DataType::FLOAT32); + NDArray x4('c', {2, 2}, sd::DataType::DOUBLE); + NDArray x5('c', {2, 2}, {0, 1.5, 2.5, 3.5}, sd::DataType::FLOAT32); + NDArray x6('c', {2, 2}, {0.1, -1, -1, 0.1}, sd::DataType::BOOL); + NDArray x7('c', {2, 2}, sd::DataType::BOOL); + NDArray other1('c', {2, 2}, {0.1, 0.1, 0.1, 0.1}, sd::DataType::FLOAT32); + NDArray other2('c', {2, 2}, {0.1, 0.1, 0.1, 0.1}, sd::DataType::DOUBLE); + NDArray other3('c', {2, 2}, {0., -1, -2, -3}, sd::DataType::INT64); + NDArray other4('c', {2, 2}, {1, 0, 0.1, 0}, sd::DataType::BOOL); + + auto func1 = [](Nd4jLong idx, float elem1, float elem2) { + return elem1 + elem2 + idx; + }; + auto func2 = [](Nd4jLong idx, int elem1, float elem2) { + return elem1 + elem2 + idx; + }; + auto func3 = [](Nd4jLong idx, int elem1, double elem2) { + return elem1 + elem2 + idx; + }; + auto func4 = [](Nd4jLong idx, double elem1, float elem2) { + return elem1 + elem2 + idx; + }; + auto func5 = [](Nd4jLong idx, float elem1, int elem2) { + return elem1 - elem2 + idx; + }; + + NDArray exp1('c', {2, 2}, {0.1, 2.1, 4.1, 6.1}, sd::DataType::DOUBLE); + NDArray exp2('c', {2, 2}, {0., 1, 2, 3}, sd::DataType::INT64); + NDArray exp3('c', {2, 2}, {0.1, 2.1, 4.1, 6.1}, sd::DataType::FLOAT32); + NDArray exp4('c', {2, 2}, {0.1, 2.6, 4.6, 6.6}, sd::DataType::FLOAT32); + NDArray exp5('c', {2, 2}, {0., 1, 1, 1}, sd::DataType::BOOL); + + x1.applyIndexedPairwiseLambda(other2, func1, x4); + ASSERT_EQ(x4, exp1); + + x2.applyIndexedPairwiseLambda(other3, func1, x2); + ASSERT_EQ(x2, exp2); + + x2.applyIndexedPairwiseLambda(other3, func2, x2); + ASSERT_EQ(x2, exp2); + + x3.applyIndexedPairwiseLambda(other1, func3, x3); + ASSERT_EQ(x3, exp3); + + x5.applyIndexedPairwiseLambda(other1, func4, x5); + ASSERT_EQ(x5, exp4); + + x6.applyIndexedPairwiseLambda(other4, func5, x7); + ASSERT_EQ(x7, exp5); } ////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_applyTriplewiseLambda_test1) { + NDArray x1('c', {2, 2}, {0., 1, 2, 3}, sd::DataType::DOUBLE); + NDArray x2('c', {2, 2}, {0., -1, -2, -3}, sd::DataType::DOUBLE); + NDArray x3('c', {2, 2}, {0, -1.5, -2.5, -3.5}, sd::DataType::DOUBLE); + NDArray x4('c', {2, 2}, sd::DataType::DOUBLE); - NDArray x1('c', {2,2}, {0., 1, 2, 3}, sd::DataType::DOUBLE); - NDArray x2('c', {2,2}, {0., -1, -2, -3}, sd::DataType::DOUBLE); - NDArray x3('c', {2,2}, {0, -1.5, -2.5, -3.5}, sd::DataType::DOUBLE); - NDArray x4('c', {2,2}, sd::DataType::DOUBLE); - - NDArray x5('c', {2,2}, {0., 1, 2, 3}, sd::DataType::INT32); - NDArray x6('c', {2,2}, {0., -1, -2, -3}, sd::DataType::INT32); - NDArray x7('c', {2,2}, {0., 10, 20, 30}, sd::DataType::INT32); + NDArray x5('c', {2, 2}, {0., 1, 2, 3}, sd::DataType::INT32); + NDArray x6('c', {2, 2}, {0., -1, -2, -3}, sd::DataType::INT32); + NDArray x7('c', {2, 2}, {0., 10, 20, 30}, sd::DataType::INT32); - NDArray x8('c', {2,2}, {0., 1, 0, 1}, sd::DataType::BOOL); - NDArray x9('c', {2,2}, {1., 1, 0, 1}, sd::DataType::BOOL); - NDArray x10('c', {2,2}, {0., 0, 0, 0}, sd::DataType::BOOL); + NDArray x8('c', {2, 2}, {0., 1, 0, 1}, sd::DataType::BOOL); + NDArray x9('c', {2, 2}, {1., 1, 0, 1}, sd::DataType::BOOL); + NDArray x10('c', {2, 2}, {0., 0, 0, 0}, sd::DataType::BOOL); - auto func1 = [](double elem1, float elem2, int elem3) { return elem1 + elem2 + elem3; }; - auto func2 = [](float elem1, float elem2, float elem3) { return elem1 + elem2 + elem3; }; - auto func3 = [](int elem1, int elem2, int elem3) { return elem1 + elem2 + elem3; }; - auto func4 = [](bool elem1, bool elem2, bool elem3) { return elem1 + elem2 + elem3; }; + auto func1 = [](double elem1, float elem2, int elem3) { + return elem1 + elem2 + elem3; + }; + auto func2 = [](float elem1, float elem2, float elem3) { + return elem1 + elem2 + elem3; + }; + auto func3 = [](int elem1, int elem2, int elem3) { + return elem1 + elem2 + elem3; + }; + auto func4 = [](bool elem1, bool elem2, bool elem3) { + return elem1 + elem2 + elem3; + }; - NDArray exp('c', {2,2}, {1., 1, 0, 1}, sd::DataType::BOOL); + NDArray exp('c', {2, 2}, {1., 1, 0, 1}, sd::DataType::BOOL); - x1.applyTriplewiseLambda(x2, x3, func1, x4); - ASSERT_EQ(x4, x2); + x1.applyTriplewiseLambda(x2, x3, func1, x4); + ASSERT_EQ(x4, x2); - x1.applyTriplewiseLambda(x2, x3, func2, x1); - ASSERT_EQ(x1, x3); + x1.applyTriplewiseLambda(x2, x3, func2, x1); + ASSERT_EQ(x1, x3); - x5.applyTriplewiseLambda(x6, x7, func3, x5); - ASSERT_EQ(x5, x7); + x5.applyTriplewiseLambda(x6, x7, func3, x5); + ASSERT_EQ(x5, x7); - x8.applyTriplewiseLambda(x9, x10, func4, x8); - ASSERT_EQ(x8, exp); + x8.applyTriplewiseLambda(x9, x10, func4, x8); + ASSERT_EQ(x8, exp); } #endif ////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_applyIndexReduce_test1) { + NDArray x1('c', {2, 3}, {0, 1, 2, 3, 4, 5}, sd::DataType::DOUBLE); + NDArray exp1('c', {}, std::vector{5}, sd::DataType::INT64); + NDArray exp2('c', {2}, {2, 2}, sd::DataType::INT64); + NDArray exp3('c', {3}, {1, 1, 1}, sd::DataType::INT64); - NDArray x1('c', {2,3}, {0, 1, 2, 3, 4, 5}, sd::DataType::DOUBLE); - NDArray exp1('c', {}, std::vector{5}, sd::DataType::INT64); - NDArray exp2('c', {2}, {2,2}, sd::DataType::INT64); - NDArray exp3('c', {3}, {1,1,1}, sd::DataType::INT64); - - NDArray scalar = x1.applyIndexReduce(sd::indexreduce::IndexMax, {0,1}); - ASSERT_EQ(scalar, exp1); + NDArray scalar = x1.applyIndexReduce(sd::indexreduce::IndexMax, {0, 1}); + ASSERT_EQ(scalar, exp1); - NDArray vec1 = x1.applyIndexReduce(sd::indexreduce::IndexMax, {1}); - ASSERT_EQ(vec1, exp2); + NDArray vec1 = x1.applyIndexReduce(sd::indexreduce::IndexMax, {1}); + ASSERT_EQ(vec1, exp2); - NDArray vec2 = x1.applyIndexReduce(sd::indexreduce::IndexMax, {0}); - ASSERT_EQ(vec2, exp3); + NDArray vec2 = x1.applyIndexReduce(sd::indexreduce::IndexMax, {0}); + ASSERT_EQ(vec2, exp3); } ////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_applyIndexReduce_test2) { + NDArray x1('c', {2, 3}, {0, 1, 2, 3, 4, 5}, sd::DataType::DOUBLE); + NDArray scalar('c', {}, std::vector{5}, sd::DataType::INT64); + NDArray vec1('c', {2}, {2, 2}, sd::DataType::INT64); + NDArray vec2('c', {3}, {1, 1, 1}, sd::DataType::INT64); + NDArray exp1('c', {}, std::vector{5}, sd::DataType::INT64); + NDArray exp2('c', {2}, {2, 2}, sd::DataType::INT64); + NDArray exp3('c', {3}, {1, 1, 1}, sd::DataType::INT64); - NDArray x1('c', {2,3}, {0, 1, 2, 3, 4, 5}, sd::DataType::DOUBLE); - NDArray scalar('c', {}, std::vector{5}, sd::DataType::INT64); - NDArray vec1('c', {2}, {2,2}, sd::DataType::INT64); - NDArray vec2('c', {3}, {1,1,1}, sd::DataType::INT64); - NDArray exp1('c', {}, std::vector{5}, sd::DataType::INT64); - NDArray exp2('c', {2}, {2,2}, sd::DataType::INT64); - NDArray exp3('c', {3}, {1,1,1}, sd::DataType::INT64); + x1.applyIndexReduce(sd::indexreduce::IndexMax, scalar, {0, 1}); + ASSERT_EQ(scalar, exp1); - x1.applyIndexReduce(sd::indexreduce::IndexMax, scalar, {0,1}); - ASSERT_EQ(scalar, exp1); + x1.applyIndexReduce(sd::indexreduce::IndexMax, vec1, {1}); + ASSERT_EQ(vec1, exp2); - x1.applyIndexReduce(sd::indexreduce::IndexMax, vec1, {1}); - ASSERT_EQ(vec1, exp2); - - x1.applyIndexReduce(sd::indexreduce::IndexMax, vec2, {0}); - ASSERT_EQ(vec2, exp3); + x1.applyIndexReduce(sd::indexreduce::IndexMax, vec2, {0}); + ASSERT_EQ(vec2, exp3); } ////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, applyReduce3_test1) { + NDArray x1('c', {2, 2}, {1, 2, 3, 4}, sd::DataType::INT32); + NDArray x2('c', {2, 2}, {-1, -2, -3, -4}, sd::DataType::INT32); + NDArray x3('c', {2, 2}, {1.5, 1.5, 1.5, 1.5}, sd::DataType::DOUBLE); + NDArray x4('c', {2, 2}, {1, 2, 3, 4}, sd::DataType::DOUBLE); + NDArray exp1('c', {}, std::vector{-30}, sd::DataType::FLOAT32); + NDArray exp2('c', {}, std::vector{15}, sd::DataType::DOUBLE); - NDArray x1('c', {2,2}, {1,2,3,4}, sd::DataType::INT32); - NDArray x2('c', {2,2}, {-1,-2,-3,-4}, sd::DataType::INT32); - NDArray x3('c', {2,2}, {1.5,1.5,1.5,1.5}, sd::DataType::DOUBLE); - NDArray x4('c', {2,2}, {1,2,3,4}, sd::DataType::DOUBLE); - NDArray exp1('c', {}, std::vector{-30}, sd::DataType::FLOAT32); - NDArray exp2('c', {}, std::vector{15}, sd::DataType::DOUBLE); - - auto result = x1.applyReduce3(reduce3::Dot, x2); - ASSERT_EQ(result, exp1); + auto result = x1.applyReduce3(reduce3::Dot, x2); + ASSERT_EQ(result, exp1); - result = x3.applyReduce3(reduce3::Dot, x4); - ASSERT_EQ(result, exp2); + result = x3.applyReduce3(reduce3::Dot, x4); + ASSERT_EQ(result, exp2); } ////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, applyReduce3_test2) { + NDArray x1('c', {2, 2}, {1, 2, 3, 4}, sd::DataType::INT32); + NDArray x2('c', {2, 2}, {-1, -2, -3, -4}, sd::DataType::INT32); + NDArray x3('c', {2, 2}, {1.5, 1.5, 1.5, 1.5}, sd::DataType::DOUBLE); + NDArray x4('c', {2, 2}, {1, 2, 3, 4}, sd::DataType::DOUBLE); + NDArray x5('c', {2, 3}, {1, 2, 3, 4, 5, 6}, sd::DataType::INT32); + NDArray x6('c', {2, 3}, {-6, -5, -4, -3, -2, -1}, sd::DataType::INT32); + NDArray x7('c', {2, 3}, {1.5, 1.5, 1.5, 1.5, 1.5, 1.5}, sd::DataType::DOUBLE); + NDArray x8('c', {2, 3}, {1, 2, 3, 4, 5, 6}, sd::DataType::DOUBLE); - NDArray x1('c', {2,2}, {1,2,3,4}, sd::DataType::INT32); - NDArray x2('c', {2,2}, {-1,-2,-3,-4}, sd::DataType::INT32); - NDArray x3('c', {2,2}, {1.5,1.5,1.5,1.5}, sd::DataType::DOUBLE); - NDArray x4('c', {2,2}, {1,2,3,4}, sd::DataType::DOUBLE); - NDArray x5('c', {2,3}, {1,2,3,4,5,6}, sd::DataType::INT32); - NDArray x6('c', {2,3}, {-6,-5,-4,-3,-2,-1}, sd::DataType::INT32); - NDArray x7('c', {2,3}, {1.5,1.5,1.5,1.5,1.5,1.5}, sd::DataType::DOUBLE); - NDArray x8('c', {2,3}, {1,2,3,4,5,6}, sd::DataType::DOUBLE); + NDArray exp1('c', {}, std::vector{-30}, sd::DataType::FLOAT32); + NDArray exp2('c', {}, std::vector{15}, sd::DataType::DOUBLE); + NDArray exp3('c', {3}, {-18, -20, -18}, sd::DataType::FLOAT32); + NDArray exp4('c', {2}, {-28, -28}, sd::DataType::FLOAT32); + NDArray exp5('c', {3}, {7.5, 10.5, 13.5}, sd::DataType::DOUBLE); + NDArray exp6('c', {2}, {9, 22.5}, sd::DataType::DOUBLE); - NDArray exp1('c', {}, std::vector{-30}, sd::DataType::FLOAT32); - NDArray exp2('c', {}, std::vector{15}, sd::DataType::DOUBLE); - NDArray exp3('c', {3}, {-18,-20,-18}, sd::DataType::FLOAT32); - NDArray exp4('c', {2}, {-28,-28}, sd::DataType::FLOAT32); - NDArray exp5('c', {3}, {7.5,10.5,13.5}, sd::DataType::DOUBLE); - NDArray exp6('c', {2}, {9,22.5}, sd::DataType::DOUBLE); + auto result = x1.applyReduce3(reduce3::Dot, x2, {0, 1}); + ASSERT_EQ(result, exp1); - auto result = x1.applyReduce3(reduce3::Dot, x2, {0,1}); - ASSERT_EQ(result, exp1); + result = x3.applyReduce3(reduce3::Dot, x4, {0, 1}); + ASSERT_EQ(result, exp2); - result = x3.applyReduce3(reduce3::Dot, x4, {0,1}); - ASSERT_EQ(result, exp2); + result = x5.applyReduce3(reduce3::Dot, x6, std::vector({0})); + ASSERT_EQ(result, exp3); - result = x5.applyReduce3(reduce3::Dot, x6, std::vector({0})); - ASSERT_EQ(result, exp3); + result = x5.applyReduce3(reduce3::Dot, x6, std::vector({1})); + ASSERT_EQ(result, exp4); - result = x5.applyReduce3(reduce3::Dot, x6, std::vector({1})); - ASSERT_EQ(result, exp4); + result = x8.applyReduce3(reduce3::Dot, x7, std::vector({0})); + ASSERT_EQ(result, exp5); - result = x8.applyReduce3(reduce3::Dot, x7, std::vector({0})); - ASSERT_EQ(result, exp5); - - result = x8.applyReduce3(reduce3::Dot, x7, std::vector({1})); - ASSERT_EQ(result, exp6); + result = x8.applyReduce3(reduce3::Dot, x7, std::vector({1})); + ASSERT_EQ(result, exp6); } ////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, applyAllReduce3_test1) { + NDArray x1('c', {2, 2}, {1, 2, 3, 4}, sd::DataType::INT32); + NDArray x2('c', {2, 3}, {-1, 1, -1, 1, -1, 1}, sd::DataType::INT32); + NDArray x3('c', {2, 3}, {1.5, 1.5, 1.5, 1.5, 1.5, 1.5}, sd::DataType::DOUBLE); + NDArray x4('c', {2, 2}, {1, 2, 3, 4}, sd::DataType::DOUBLE); + NDArray exp1('c', {2, 3}, {2, -2, 2, 2, -2, 2}, sd::DataType::FLOAT32); + NDArray exp2('c', {2, 3}, {6, 6, 6, 9, 9, 9}, sd::DataType::DOUBLE); - NDArray x1('c', {2,2}, {1,2,3,4}, sd::DataType::INT32); - NDArray x2('c', {2,3}, {-1,1,-1,1,-1,1}, sd::DataType::INT32); - NDArray x3('c', {2,3}, {1.5,1.5,1.5,1.5,1.5,1.5}, sd::DataType::DOUBLE); - NDArray x4('c', {2,2}, {1,2,3,4}, sd::DataType::DOUBLE); - NDArray exp1('c', {2,3}, {2,-2,2,2,-2,2}, sd::DataType::FLOAT32); - NDArray exp2('c', {2,3}, {6,6,6,9,9,9}, sd::DataType::DOUBLE); - - auto result = x1.applyAllReduce3(reduce3::Dot, x2, {0}); - ASSERT_EQ(result, exp1); + auto result = x1.applyAllReduce3(reduce3::Dot, x2, {0}); + ASSERT_EQ(result, exp1); - result = x4.applyAllReduce3(reduce3::Dot, x3, {0}); - ASSERT_EQ(result, exp2); + result = x4.applyAllReduce3(reduce3::Dot, x3, {0}); + ASSERT_EQ(result, exp2); } ////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, RowCol_test1) { - if (!Environment::getInstance().isExperimentalBuild()) - return; + if (!Environment::getInstance().isExperimentalBuild()) return; - NDArray x1('c', {2,3}, {1,2,3,4,5,6}, sd::DataType::INT32); - NDArray x2('c', {2}, {0.5,0.6}, sd::DataType::FLOAT32); - NDArray x3('c', {3}, {1.5,1.6,1.7}, sd::DataType::FLOAT32); - NDArray x4('c', {2,3}, {1,2,3,4,5,6}, sd::DataType::DOUBLE); - NDArray x5('c', {2,3}, {1,2,3,4,5,6}, sd::DataType::INT32); + NDArray x1('c', {2, 3}, {1, 2, 3, 4, 5, 6}, sd::DataType::INT32); + NDArray x2('c', {2}, {0.5, 0.6}, sd::DataType::FLOAT32); + NDArray x3('c', {3}, {1.5, 1.6, 1.7}, sd::DataType::FLOAT32); + NDArray x4('c', {2, 3}, {1, 2, 3, 4, 5, 6}, sd::DataType::DOUBLE); + NDArray x5('c', {2, 3}, {1, 2, 3, 4, 5, 6}, sd::DataType::INT32); - NDArray exp1('c', {2,3}, {2,3,4,5,6,7}, sd::DataType::INT32); - NDArray exp2('c', {2,3}, {0,1,2,3,4,5}, sd::DataType::INT32); - NDArray exp3('c', {2,3}, {1.5,2.5,3.5,4.6,5.6,6.6}, sd::DataType::DOUBLE); - NDArray exp4('c', {2,3}, {0,1,1,2,3,3}, sd::DataType::INT32); + NDArray exp1('c', {2, 3}, {2, 3, 4, 5, 6, 7}, sd::DataType::INT32); + NDArray exp2('c', {2, 3}, {0, 1, 2, 3, 4, 5}, sd::DataType::INT32); + NDArray exp3('c', {2, 3}, {1.5, 2.5, 3.5, 4.6, 5.6, 6.6}, + sd::DataType::DOUBLE); + NDArray exp4('c', {2, 3}, {0, 1, 1, 2, 3, 3}, sd::DataType::INT32); - x1.addiRowVector(x3); - ASSERT_EQ(x1, exp1); + x1.addiRowVector(x3); + ASSERT_EQ(x1, exp1); - x1.addiColumnVector(x2); - ASSERT_EQ(x1, exp1); + x1.addiColumnVector(x2); + ASSERT_EQ(x1, exp1); - x4.addiColumnVector(x2); - ASSERT_EQ(x4, exp3); + x4.addiColumnVector(x2); + ASSERT_EQ(x4, exp3); - x5.muliColumnVector(x2); - ASSERT_EQ(x5, exp4); + x5.muliColumnVector(x2); + ASSERT_EQ(x5, exp4); } ////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, RowCol_test2) { - if (!Environment::getInstance().isExperimentalBuild()) - return; - - NDArray x1('c', {2,3}, {1,2,3,4,5,6}, sd::DataType::INT32); - NDArray x2('c', {2}, {0.5,0.6}, sd::DataType::FLOAT32); - NDArray x3('c', {3}, {1.5,1.6,1.7}, sd::DataType::FLOAT32); - NDArray x4('c', {2,3}, sd::DataType::FLOAT32); - NDArray x5('c', {3}, {1,2,3}, sd::DataType::INT64); - NDArray x6('c', {2,3}, sd::DataType::INT32); - NDArray x7('c', {3}, {1.5,1.6,1.7}, sd::DataType::DOUBLE); - NDArray x8('c', {2,3}, {1,2,3,4,5,6}, sd::DataType::FLOAT32); - NDArray x9('c', {3}, {1,2,3}, sd::DataType::DOUBLE); - NDArray x10('c', {2,3}, sd::DataType::DOUBLE); - - NDArray exp1('c', {2,3}, {2.5,3.6,4.7,5.5,6.6,7.7}, sd::DataType::FLOAT32); - NDArray exp2('c', {2,3}, {2, 4, 6, 5, 7, 9}, sd::DataType::INT32); - NDArray exp3('c', {2,3}, {-0.5,0.4,1.3,2.5,3.4,4.3}, sd::DataType::FLOAT32); - NDArray exp4('c', {2,3}, {1,4,9,4,10,18}, sd::DataType::DOUBLE); - NDArray exp5('c', {2,3}, {1,1,1,4,2.5,2}, sd::DataType::DOUBLE); - NDArray exp6('c', {2,3}, {1.5,2.5,3.5,4.6,5.6,6.6}, sd::DataType::FLOAT32); - - x1.addRowVector(x3, x4); - ASSERT_EQ(x4, exp1); + if (!Environment::getInstance().isExperimentalBuild()) return; + + NDArray x1('c', {2, 3}, {1, 2, 3, 4, 5, 6}, sd::DataType::INT32); + NDArray x2('c', {2}, {0.5, 0.6}, sd::DataType::FLOAT32); + NDArray x3('c', {3}, {1.5, 1.6, 1.7}, sd::DataType::FLOAT32); + NDArray x4('c', {2, 3}, sd::DataType::FLOAT32); + NDArray x5('c', {3}, {1, 2, 3}, sd::DataType::INT64); + NDArray x6('c', {2, 3}, sd::DataType::INT32); + NDArray x7('c', {3}, {1.5, 1.6, 1.7}, sd::DataType::DOUBLE); + NDArray x8('c', {2, 3}, {1, 2, 3, 4, 5, 6}, sd::DataType::FLOAT32); + NDArray x9('c', {3}, {1, 2, 3}, sd::DataType::DOUBLE); + NDArray x10('c', {2, 3}, sd::DataType::DOUBLE); - x1.addRowVector(x5, x6); - ASSERT_EQ(x6, exp2); + NDArray exp1('c', {2, 3}, {2.5, 3.6, 4.7, 5.5, 6.6, 7.7}, + sd::DataType::FLOAT32); + NDArray exp2('c', {2, 3}, {2, 4, 6, 5, 7, 9}, sd::DataType::INT32); + NDArray exp3('c', {2, 3}, {-0.5, 0.4, 1.3, 2.5, 3.4, 4.3}, + sd::DataType::FLOAT32); + NDArray exp4('c', {2, 3}, {1, 4, 9, 4, 10, 18}, sd::DataType::DOUBLE); + NDArray exp5('c', {2, 3}, {1, 1, 1, 4, 2.5, 2}, sd::DataType::DOUBLE); + NDArray exp6('c', {2, 3}, {1.5, 2.5, 3.5, 4.6, 5.6, 6.6}, + sd::DataType::FLOAT32); - x8.subRowVector(x7, x4); - ASSERT_EQ(x4, exp3); + x1.addRowVector(x3, x4); + ASSERT_EQ(x4, exp1); - x1.mulRowVector(x9, x10); - ASSERT_EQ(x10, exp4); + x1.addRowVector(x5, x6); + ASSERT_EQ(x6, exp2); - x1.divRowVector(x9, x10); - ASSERT_EQ(x10, exp5); + x8.subRowVector(x7, x4); + ASSERT_EQ(x4, exp3); - x1.addColumnVector(x2, x4); - ASSERT_EQ(x4, exp6); + x1.mulRowVector(x9, x10); + ASSERT_EQ(x10, exp4); + + x1.divRowVector(x9, x10); + ASSERT_EQ(x10, exp5); + + x1.addColumnVector(x2, x4); + ASSERT_EQ(x4, exp6); } ////////////////////////////////////////////////////////////////////// @@ -1805,178 +1813,167 @@ TEST_F(MultiDataTypeTests, tile_test1) { ////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, asT_test1) { + NDArray x1('c', {2}, {1.5, 2.5}, sd::DataType::FLOAT32); - NDArray x1('c', {2}, {1.5, 2.5}, sd::DataType::FLOAT32); - - NDArray exp1('c', {2}, {1, 2}, sd::DataType::INT32); - NDArray exp2('c', {2}, {1.5, 2.5}, sd::DataType::DOUBLE); + NDArray exp1('c', {2}, {1, 2}, sd::DataType::INT32); + NDArray exp2('c', {2}, {1.5, 2.5}, sd::DataType::DOUBLE); - auto result = new NDArray(x1.asT()); - ASSERT_EQ(*result, exp1); - delete result; + auto result = new NDArray(x1.asT()); + ASSERT_EQ(*result, exp1); + delete result; - result = new NDArray(x1.asT()); - ASSERT_EQ(*result, exp2); - delete result; + result = new NDArray(x1.asT()); + ASSERT_EQ(*result, exp2); + delete result; - result = new NDArray(x1.asT(sd::DataType::INT32)); - ASSERT_EQ(*result, exp1); - delete result; + result = new NDArray(x1.asT(sd::DataType::INT32)); + ASSERT_EQ(*result, exp1); + delete result; - result = new NDArray(x1.asT(sd::DataType::DOUBLE)); - ASSERT_EQ(*result, exp2); - delete result; + result = new NDArray(x1.asT(sd::DataType::DOUBLE)); + ASSERT_EQ(*result, exp2); + delete result; } ////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, assign_test2) { + NDArray x1('c', {2, 3}, {1.5, 2.5, 3.5, 4.5, 5.5, 6.5}, + sd::DataType::FLOAT32); + NDArray x2('c', {3, 2}, sd::DataType::INT32); + NDArray x3('c', {3, 2}, sd::DataType::DOUBLE); + NDArray x4('c', {3, 2}, sd::DataType::BOOL); + NDArray x5('c', {2, 3}, {1.5, 2.5, 0, 4.5, 5.5, 6.5}, sd::DataType::FLOAT32); - NDArray x1('c', {2,3}, {1.5,2.5,3.5,4.5,5.5,6.5}, sd::DataType::FLOAT32); - NDArray x2('c', {3,2}, sd::DataType::INT32); - NDArray x3('c', {3,2}, sd::DataType::DOUBLE); - NDArray x4('c', {3,2}, sd::DataType::BOOL); - NDArray x5('c', {2,3}, {1.5,2.5,0,4.5,5.5,6.5}, sd::DataType::FLOAT32); + NDArray exp1('c', {3, 2}, {1, 2, 3, 4, 5, 6}, sd::DataType::INT32); + NDArray exp2('c', {3, 2}, {1.5, 2.5, 3.5, 4.5, 5.5, 6.5}, + sd::DataType::DOUBLE); + NDArray exp3('c', {3, 2}, {1, 1, 0, 1, 1, 1}, sd::DataType::BOOL); - NDArray exp1('c', {3,2}, {1, 2,3,4,5,6}, sd::DataType::INT32); - NDArray exp2('c', {3,2}, {1.5,2.5,3.5,4.5,5.5,6.5}, sd::DataType::DOUBLE); - NDArray exp3('c', {3,2}, {1,1,0,1,1,1}, sd::DataType::BOOL); + x2.assign(x1); + ASSERT_EQ(x2, exp1); - x2.assign(x1); - ASSERT_EQ(x2, exp1); + x3.assign(x1); + ASSERT_EQ(x3, exp2); - x3.assign(x1); - ASSERT_EQ(x3, exp2); - - x4.assign(x5); - ASSERT_EQ(x4, exp3); + x4.assign(x5); + ASSERT_EQ(x4, exp3); } TEST_F(MultiDataTypeTests, Test_Cast_1) { - auto first = NDArrayFactory::create('c', {10}); - auto asBool = NDArrayFactory::create('c', {10}); - auto _not = NDArrayFactory::create('c', {10}); - auto asFloat = NDArrayFactory::create('c', {10}); - auto exp = NDArrayFactory::create('c', {10}); - exp.assign(0.0f); + auto first = NDArrayFactory::create('c', {10}); + auto asBool = NDArrayFactory::create('c', {10}); + auto _not = NDArrayFactory::create('c', {10}); + auto asFloat = NDArrayFactory::create('c', {10}); + auto exp = NDArrayFactory::create('c', {10}); + exp.assign(0.0f); - asBool.assign(first); + asBool.assign(first); - // asBool.printIndexedBuffer("asBool"); - asBool.applyScalar(scalar::Not, false, _not); + // asBool.printIndexedBuffer("asBool"); + asBool.applyScalar(scalar::Not, false, _not); - // _not.printIndexedBuffer("_not"); + // _not.printIndexedBuffer("_not"); - asFloat.assign(_not); + asFloat.assign(_not); - // asFloat.printIndexedBuffer("asFloat"); - ASSERT_EQ(exp, asFloat); + // asFloat.printIndexedBuffer("asFloat"); + ASSERT_EQ(exp, asFloat); } TEST_F(MultiDataTypeTests, Test_Cast_2) { - auto first = NDArrayFactory::create('c', {10}); - auto asBool = NDArrayFactory::create('c', {10}); - auto _not = NDArrayFactory::create('c', {10}); - auto asFloat = NDArrayFactory::create('c', {10}); - auto exp = NDArrayFactory::create('c', {10}); - exp.assign(1.0f); + auto first = NDArrayFactory::create('c', {10}); + auto asBool = NDArrayFactory::create('c', {10}); + auto _not = NDArrayFactory::create('c', {10}); + auto asFloat = NDArrayFactory::create('c', {10}); + auto exp = NDArrayFactory::create('c', {10}); + exp.assign(1.0f); - asBool.assign(first); + asBool.assign(first); - // asBool.printIndexedBuffer("asBool"); - asBool.applyTransform(transform::Not, _not); + // asBool.printIndexedBuffer("asBool"); + asBool.applyTransform(transform::Not, _not); - // _not.printIndexedBuffer("_not"); + // _not.printIndexedBuffer("_not"); - asFloat.assign(_not); + asFloat.assign(_not); - // asFloat.printIndexedBuffer("asFloat"); - ASSERT_EQ(exp, asFloat); + // asFloat.printIndexedBuffer("asFloat"); + ASSERT_EQ(exp, asFloat); } ////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, divide_bool_test1) { - - NDArray x1('c', {2,3}, {1.5,0,3.5,0,5.5,6.5}, sd::DataType::FLOAT32); - NDArray x2('c', {3,2}, {1,1,0,1,0,1}, sd::DataType::BOOL); - NDArray x3('c', {2,3}, sd::DataType::FLOAT32); - NDArray x4('c', {2}, sd::DataType::BOOL); - - try { - NDArray x3 = x1 / x2; - } - catch (std::exception& message) { - // printf("%s\n", message.what()); - ASSERT_TRUE(1); - } - - try { - x1 /= x2; - } - catch (std::exception& message) { - // printf("%s\n", message.what()); - ASSERT_TRUE(1); - } - - try { - NDArray x3 = 150. / x2; - } - catch (std::exception& message) { - // printf("%s\n", message.what()); - ASSERT_TRUE(1); - } - - try { - x1.divRowVector(x4, x3); - } - catch (std::exception& message) { - // printf("%s\n", message.what()); - ASSERT_TRUE(1); - } - - try { - x1.applyBroadcast(sd::broadcast::FloorDiv, {1}, x4, x3); - } - catch (std::exception& message) { - // printf("%s\n", message.what()); - ASSERT_TRUE(1); - } - - try { - x1.applyTrueBroadcast(BROADCAST(FloorMod), x2, x3); - } - catch (std::exception& message) { - // printf("%s\n", message.what()); - ASSERT_TRUE(1); - } + NDArray x1('c', {2, 3}, {1.5, 0, 3.5, 0, 5.5, 6.5}, sd::DataType::FLOAT32); + NDArray x2('c', {3, 2}, {1, 1, 0, 1, 0, 1}, sd::DataType::BOOL); + NDArray x3('c', {2, 3}, sd::DataType::FLOAT32); + NDArray x4('c', {2}, sd::DataType::BOOL); + + try { + NDArray x3 = x1 / x2; + } catch (std::exception& message) { + // printf("%s\n", message.what()); + ASSERT_TRUE(1); + } + + try { + x1 /= x2; + } catch (std::exception& message) { + // printf("%s\n", message.what()); + ASSERT_TRUE(1); + } + + try { + NDArray x3 = 150. / x2; + } catch (std::exception& message) { + // printf("%s\n", message.what()); + ASSERT_TRUE(1); + } + + try { + x1.divRowVector(x4, x3); + } catch (std::exception& message) { + // printf("%s\n", message.what()); + ASSERT_TRUE(1); + } + + try { + x1.applyBroadcast(sd::broadcast::FloorDiv, {1}, x4, x3); + } catch (std::exception& message) { + // printf("%s\n", message.what()); + ASSERT_TRUE(1); + } + + try { + x1.applyTrueBroadcast(BROADCAST(FloorMod), x2, x3); + } catch (std::exception& message) { + // printf("%s\n", message.what()); + ASSERT_TRUE(1); + } } - ////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, aaa) { + NDArray z('c', {2, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, sd::DataType::DOUBLE); + z.permutei({1, 0}); - NDArray z('c', {2,5}, {1,2,3,4,5,6,7,8,9,10}, sd::DataType::DOUBLE); - z.permutei({1,0}); - - sd::graph::RandomGenerator gen(119,5); - ExtraArguments extras({1.5, 2.5}); - - NativeOpExecutioner::execRandom(LaunchContext::defaultContext(), sd::random::UniformDistribution, - &gen, - z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - extras.argumentsAsT()); - // z.printIndexedBuffer(); + sd::graph::RandomGenerator gen(119, 5); + ExtraArguments extras({1.5, 2.5}); + NativeOpExecutioner::execRandom( + LaunchContext::defaultContext(), sd::random::UniformDistribution, &gen, + z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), + extras.argumentsAsT()); + // z.printIndexedBuffer(); } ////////////////////////////////////////////////////////////////////// -TEST_F(MultiDataTypeTests, assign_2) -{ - NDArray x('c', {4}, {1.5,2.5,3.5,4.5}, sd::DataType::FLOAT32); - NDArray y('c', {4}, sd::DataType::INT32); - NDArray expected('c', {4}, {1,2,3,4}, sd::DataType::INT32); +TEST_F(MultiDataTypeTests, assign_2) { + NDArray x('c', {4}, {1.5, 2.5, 3.5, 4.5}, sd::DataType::FLOAT32); + NDArray y('c', {4}, sd::DataType::INT32); + NDArray expected('c', {4}, {1, 2, 3, 4}, sd::DataType::INT32); - y.assign(x); - // y.printBuffer(); + y.assign(x); + // y.printBuffer(); - ASSERT_TRUE(expected.equalsTo(&y)); + ASSERT_TRUE(expected.equalsTo(&y)); } diff --git a/libnd4j/tests_cpu/layers_tests/MultiDeviceTests.cpp b/libnd4j/tests_cpu/layers_tests/MultiDeviceTests.cpp index 1c12f2d7258f..306fc1525b56 100644 --- a/libnd4j/tests_cpu/layers_tests/MultiDeviceTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/MultiDeviceTests.cpp @@ -18,51 +18,52 @@ // @author raver119@gmail.com // -#include "testlayers.h" #include -#include #include #include -#include +#include #include +#include + #include +#include "testlayers.h" using namespace sd; class MultiDeviceTests : public testing::Test { -public: - + public: }; -void createArrays(int limit, std::vector &arrays) { - auto deviceId = AffinityManager::currentDeviceId(); - auto numDevices = AffinityManager::numberOfDevices(); +void createArrays(int limit, std::vector &arrays) { + auto deviceId = AffinityManager::currentDeviceId(); + auto numDevices = AffinityManager::numberOfDevices(); - for (int e = 0; e < limit; e++) { - auto value = deviceId * limit + e; - arrays[value] = NDArrayFactory::create_('c', {10}); - arrays[value]->assign(value); - //nd4j_printf("device_%i; value: [%i]; mean: [%f]\n", deviceId, value, arrays[value]->meanNumber().e(0)); - } + for (int e = 0; e < limit; e++) { + auto value = deviceId * limit + e; + arrays[value] = NDArrayFactory::create_('c', {10}); + arrays[value]->assign(value); + // nd4j_printf("device_%i; value: [%i]; mean: [%f]\n", deviceId, value, + // arrays[value]->meanNumber().e(0)); + } } TEST_F(MultiDeviceTests, test_multi_device_migration_1) { - auto deviceId = AffinityManager::currentDeviceId(); - auto numDevices = AffinityManager::numberOfDevices(); - auto numArrays = 10; - std::vector arrays(numDevices * numArrays); + auto deviceId = AffinityManager::currentDeviceId(); + auto numDevices = AffinityManager::numberOfDevices(); + auto numArrays = 10; + std::vector arrays(numDevices * numArrays); - // filling list of arrays on multiple threads - for (int e = 0; e < numDevices; e++) { - std::thread t1(createArrays, numArrays, std::ref(arrays)); + // filling list of arrays on multiple threads + for (int e = 0; e < numDevices; e++) { + std::thread t1(createArrays, numArrays, std::ref(arrays)); - t1.join(); - } + t1.join(); + } - // at this moment all arrays are build, so we can test migration - for (int e = 0; e < arrays.size(); e++) { - ASSERT_NEAR((float) e, arrays[e]->meanNumber().e(0), 1e-5f); - delete arrays[e]; - } + // at this moment all arrays are build, so we can test migration + for (int e = 0; e < arrays.size(); e++) { + ASSERT_NEAR((float)e, arrays[e]->meanNumber().e(0), 1e-5f); + delete arrays[e]; + } } diff --git a/libnd4j/tests_cpu/layers_tests/NDArrayConstructorsTests.cu b/libnd4j/tests_cpu/layers_tests/NDArrayConstructorsTests.cu index 24ac087d1673..6428169853be 100644 --- a/libnd4j/tests_cpu/layers_tests/NDArrayConstructorsTests.cu +++ b/libnd4j/tests_cpu/layers_tests/NDArrayConstructorsTests.cu @@ -18,189 +18,186 @@ // @author raver119@gmail.com // -#include "testlayers.h" #include #include +#include +#include #include #include #include #include -#include -#include #include +#include -#include +#include "testlayers.h" using namespace sd; using namespace sd::graph; class NDArrayConstructorsTests : public testing::Test { -public: - + public: }; TEST_F(NDArrayConstructorsTests, test_constructor_1) { - auto x = NDArrayFactory::empty_(); + auto x = NDArrayFactory::empty_(); - ASSERT_TRUE(x->buffer() == nullptr); - ASSERT_TRUE(x->specialBuffer() == nullptr); + ASSERT_TRUE(x->buffer() == nullptr); + ASSERT_TRUE(x->specialBuffer() == nullptr); - ASSERT_FALSE(x->shapeInfo() == nullptr); - ASSERT_FALSE(x->specialShapeInfo() == nullptr); + ASSERT_FALSE(x->shapeInfo() == nullptr); + ASSERT_FALSE(x->specialShapeInfo() == nullptr); - ASSERT_TRUE(x->isActualOnDeviceSide()); - ASSERT_TRUE(x->isActualOnHostSide()); + ASSERT_TRUE(x->isActualOnDeviceSide()); + ASSERT_TRUE(x->isActualOnHostSide()); - delete x; + delete x; } TEST_F(NDArrayConstructorsTests, test_constructor_2) { - auto x = NDArrayFactory::vector(5, 1.0f); + auto x = NDArrayFactory::vector(5, 1.0f); + ASSERT_FALSE(x->buffer() == nullptr); + ASSERT_FALSE(x->specialBuffer() == nullptr); - ASSERT_FALSE(x->buffer() == nullptr); - ASSERT_FALSE(x->specialBuffer() == nullptr); + ASSERT_FALSE(x->shapeInfo() == nullptr); + ASSERT_FALSE(x->specialShapeInfo() == nullptr); - ASSERT_FALSE(x->shapeInfo() == nullptr); - ASSERT_FALSE(x->specialShapeInfo() == nullptr); + ASSERT_TRUE(x->isActualOnDeviceSide()); + ASSERT_FALSE(x->isActualOnHostSide()); - ASSERT_TRUE(x->isActualOnDeviceSide()); - ASSERT_FALSE(x->isActualOnHostSide()); - - delete x; + delete x; } TEST_F(NDArrayConstructorsTests, test_constructor_3) { - auto x = NDArrayFactory::create('c',{5, 5}); + auto x = NDArrayFactory::create('c', {5, 5}); - ASSERT_TRUE(x.buffer() == nullptr); - ASSERT_FALSE(x.specialBuffer() == nullptr); + ASSERT_TRUE(x.buffer() == nullptr); + ASSERT_FALSE(x.specialBuffer() == nullptr); - ASSERT_FALSE(x.shapeInfo() == nullptr); - ASSERT_FALSE(x.specialShapeInfo() == nullptr); + ASSERT_FALSE(x.shapeInfo() == nullptr); + ASSERT_FALSE(x.specialShapeInfo() == nullptr); - ASSERT_TRUE(x.isActualOnDeviceSide()); - ASSERT_FALSE(x.isActualOnHostSide()); + ASSERT_TRUE(x.isActualOnDeviceSide()); + ASSERT_FALSE(x.isActualOnHostSide()); } TEST_F(NDArrayConstructorsTests, test_constructor_4) { - auto x = NDArrayFactory::create(sd::DataType::FLOAT32, 1.0f); + auto x = NDArrayFactory::create(sd::DataType::FLOAT32, 1.0f); - ASSERT_FALSE(x.buffer() == nullptr); - ASSERT_FALSE(x.specialBuffer() == nullptr); + ASSERT_FALSE(x.buffer() == nullptr); + ASSERT_FALSE(x.specialBuffer() == nullptr); - ASSERT_FALSE(x.shapeInfo() == nullptr); - ASSERT_FALSE(x.specialShapeInfo() == nullptr); + ASSERT_FALSE(x.shapeInfo() == nullptr); + ASSERT_FALSE(x.specialShapeInfo() == nullptr); - ASSERT_TRUE(x.isActualOnDeviceSide()); - ASSERT_TRUE(x.isActualOnHostSide()); + ASSERT_TRUE(x.isActualOnDeviceSide()); + ASSERT_TRUE(x.isActualOnHostSide()); } TEST_F(NDArrayConstructorsTests, test_constructor_5) { - auto x = NDArrayFactory::create('c',{2, 2}, {1, 2, 3, 4}); + auto x = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); - ASSERT_TRUE(x.buffer() == nullptr); - ASSERT_FALSE(x.specialBuffer() == nullptr); + ASSERT_TRUE(x.buffer() == nullptr); + ASSERT_FALSE(x.specialBuffer() == nullptr); - ASSERT_FALSE(x.shapeInfo() == nullptr); - ASSERT_FALSE(x.specialShapeInfo() == nullptr); + ASSERT_FALSE(x.shapeInfo() == nullptr); + ASSERT_FALSE(x.specialShapeInfo() == nullptr); - ASSERT_TRUE(x.isActualOnDeviceSide()); - ASSERT_FALSE(x.isActualOnHostSide()); + ASSERT_TRUE(x.isActualOnDeviceSide()); + ASSERT_FALSE(x.isActualOnHostSide()); } TEST_F(NDArrayConstructorsTests, test_constructor_6) { - auto x = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); - NDArray y(x); + auto x = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); + NDArray y(x); - ASSERT_TRUE(y.buffer() == nullptr); - ASSERT_FALSE(y.specialBuffer() == nullptr); + ASSERT_TRUE(y.buffer() == nullptr); + ASSERT_FALSE(y.specialBuffer() == nullptr); - ASSERT_FALSE(y.shapeInfo() == nullptr); - ASSERT_FALSE(y.specialShapeInfo() == nullptr); + ASSERT_FALSE(y.shapeInfo() == nullptr); + ASSERT_FALSE(y.specialShapeInfo() == nullptr); - ASSERT_TRUE(y.isActualOnDeviceSide()); - ASSERT_FALSE(y.isActualOnHostSide()); + ASSERT_TRUE(y.isActualOnDeviceSide()); + ASSERT_FALSE(y.isActualOnHostSide()); } TEST_F(NDArrayConstructorsTests, test_constructor_7) { - auto x = NDArrayFactory::create(1.0f); + auto x = NDArrayFactory::create(1.0f); - ASSERT_FALSE(x.buffer() == nullptr); - ASSERT_FALSE(x.specialBuffer() == nullptr); + ASSERT_FALSE(x.buffer() == nullptr); + ASSERT_FALSE(x.specialBuffer() == nullptr); - ASSERT_FALSE(x.shapeInfo() == nullptr); - ASSERT_FALSE(x.specialShapeInfo() == nullptr); + ASSERT_FALSE(x.shapeInfo() == nullptr); + ASSERT_FALSE(x.specialShapeInfo() == nullptr); - ASSERT_TRUE(x.isActualOnDeviceSide()); - ASSERT_TRUE(x.isActualOnHostSide()); + ASSERT_TRUE(x.isActualOnDeviceSide()); + ASSERT_TRUE(x.isActualOnHostSide()); } TEST_F(NDArrayConstructorsTests, test_constructor_8) { - auto x = NDArrayFactory::create_('c',{2, 2}, {1, 2, 3, 4}); + auto x = NDArrayFactory::create_('c', {2, 2}, {1, 2, 3, 4}); - ASSERT_TRUE(x->buffer() == nullptr); - ASSERT_FALSE(x->specialBuffer() == nullptr); + ASSERT_TRUE(x->buffer() == nullptr); + ASSERT_FALSE(x->specialBuffer() == nullptr); - ASSERT_FALSE(x->shapeInfo() == nullptr); - ASSERT_FALSE(x->specialShapeInfo() == nullptr); + ASSERT_FALSE(x->shapeInfo() == nullptr); + ASSERT_FALSE(x->specialShapeInfo() == nullptr); - ASSERT_TRUE(x->isActualOnDeviceSide()); - ASSERT_FALSE(x->isActualOnHostSide()); + ASSERT_TRUE(x->isActualOnDeviceSide()); + ASSERT_FALSE(x->isActualOnHostSide()); - delete x; + delete x; } TEST_F(NDArrayConstructorsTests, test_constructor_9) { - auto x = NDArrayFactory::create_('c',{2, 2}); + auto x = NDArrayFactory::create_('c', {2, 2}); - ASSERT_TRUE(x->buffer() == nullptr); - ASSERT_FALSE(x->specialBuffer() == nullptr); + ASSERT_TRUE(x->buffer() == nullptr); + ASSERT_FALSE(x->specialBuffer() == nullptr); - ASSERT_FALSE(x->shapeInfo() == nullptr); - ASSERT_FALSE(x->specialShapeInfo() == nullptr); + ASSERT_FALSE(x->shapeInfo() == nullptr); + ASSERT_FALSE(x->specialShapeInfo() == nullptr); - ASSERT_TRUE(x->isActualOnDeviceSide()); - ASSERT_FALSE(x->isActualOnHostSide()); + ASSERT_TRUE(x->isActualOnDeviceSide()); + ASSERT_FALSE(x->isActualOnHostSide()); - delete x; + delete x; } TEST_F(NDArrayConstructorsTests, test_linspace_1) { - auto x = NDArrayFactory::linspace(1.0f, 10.0f, 20); + auto x = NDArrayFactory::linspace(1.0f, 10.0f, 20); - ASSERT_FALSE(x->buffer() == nullptr); - ASSERT_FALSE(x->specialBuffer() == nullptr); + ASSERT_FALSE(x->buffer() == nullptr); + ASSERT_FALSE(x->specialBuffer() == nullptr); - ASSERT_FALSE(x->shapeInfo() == nullptr); - ASSERT_FALSE(x->specialShapeInfo() == nullptr); + ASSERT_FALSE(x->shapeInfo() == nullptr); + ASSERT_FALSE(x->specialShapeInfo() == nullptr); - ASSERT_TRUE(x->isActualOnDeviceSide()); - ASSERT_TRUE(x->isActualOnHostSide()); + ASSERT_TRUE(x->isActualOnDeviceSide()); + ASSERT_TRUE(x->isActualOnHostSide()); - delete x; + delete x; } TEST_F(NDArrayConstructorsTests, test_constructor_10) { - - NDArray scalar1(sd::DataType::DOUBLE); // scalar1 = 0 - NDArray scalar2('c', {}, std::vector{0}); - - ASSERT_TRUE(scalar1.isActualOnDeviceSide()); - ASSERT_TRUE(!scalar1.isActualOnHostSide()); - ASSERT_TRUE(scalar2.isActualOnDeviceSide()); - ASSERT_TRUE(scalar2.isActualOnHostSide()); - - ASSERT_TRUE(scalar2.equalsTo(scalar1)); - - ASSERT_TRUE(scalar1.isActualOnDeviceSide()); - ASSERT_TRUE(!scalar1.isActualOnHostSide()); - ASSERT_TRUE(scalar2.isActualOnDeviceSide()); - ASSERT_TRUE(scalar2.isActualOnHostSide()); - - ASSERT_TRUE(scalar1.buffer() == nullptr); - ASSERT_TRUE(scalar1.specialBuffer() != nullptr); - ASSERT_TRUE(scalar1.shapeInfo() != nullptr); - ASSERT_TRUE(scalar1.specialShapeInfo() != nullptr); - ASSERT_TRUE(scalar1.lengthOf() == 1); + NDArray scalar1(sd::DataType::DOUBLE); // scalar1 = 0 + NDArray scalar2('c', {}, std::vector{0}); + + ASSERT_TRUE(scalar1.isActualOnDeviceSide()); + ASSERT_TRUE(!scalar1.isActualOnHostSide()); + ASSERT_TRUE(scalar2.isActualOnDeviceSide()); + ASSERT_TRUE(scalar2.isActualOnHostSide()); + + ASSERT_TRUE(scalar2.equalsTo(scalar1)); + + ASSERT_TRUE(scalar1.isActualOnDeviceSide()); + ASSERT_TRUE(!scalar1.isActualOnHostSide()); + ASSERT_TRUE(scalar2.isActualOnDeviceSide()); + ASSERT_TRUE(scalar2.isActualOnHostSide()); + + ASSERT_TRUE(scalar1.buffer() == nullptr); + ASSERT_TRUE(scalar1.specialBuffer() != nullptr); + ASSERT_TRUE(scalar1.shapeInfo() != nullptr); + ASSERT_TRUE(scalar1.specialShapeInfo() != nullptr); + ASSERT_TRUE(scalar1.lengthOf() == 1); } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/NDArrayCudaBasicsTests.cu b/libnd4j/tests_cpu/layers_tests/NDArrayCudaBasicsTests.cu index 01510dc916f9..663b7936622f 100644 --- a/libnd4j/tests_cpu/layers_tests/NDArrayCudaBasicsTests.cu +++ b/libnd4j/tests_cpu/layers_tests/NDArrayCudaBasicsTests.cu @@ -14,2185 +14,2398 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - // - // @author raver119@gmail.com - // +// +// @author raver119@gmail.com +// -#include "testlayers.h" #include #include +#include +#include #include #include #include #include -#include -#include #include #include +#include -#include +#include "testlayers.h" using namespace sd; using namespace sd::graph; class NDArrayCudaBasicsTests : public testing::Test { -public: - + public: }; ////////////////////////////////////////////////////////////////////////// -static cudaError_t allocateDeviceMem(LaunchContext& lc, std::vector& devicePtrs, const std::vector>& hostData) { - - if(devicePtrs.size() != hostData.size()) - throw std::invalid_argument("prepareDataForCuda: two input sts::vectors should same sizes !"); - - cudaError_t cudaResult; - - void* reductionPointer; - cudaResult = cudaMalloc(reinterpret_cast(&reductionPointer), 1024*1024); if(cudaResult != 0) return cudaResult; - int* allocationPointer; - cudaResult = cudaMalloc(reinterpret_cast(&allocationPointer), 1024*1024); if(cudaResult != 0) return cudaResult; - - lc.setReductionPointer(reductionPointer); - lc.setAllocationPointer(allocationPointer); - cudaStream_t stream = *lc.getCudaStream(); - - for(int i = 0; i < devicePtrs.size(); ++i) { - - cudaResult = cudaMalloc(reinterpret_cast(&devicePtrs[i]), hostData[i].second); if(cudaResult != 0) return cudaResult; - cudaMemcpyAsync(devicePtrs[i], hostData[i].first, hostData[i].second, cudaMemcpyHostToDevice, stream); - } - return cudaResult; +static cudaError_t allocateDeviceMem( + LaunchContext& lc, std::vector& devicePtrs, + const std::vector>& hostData) { + if (devicePtrs.size() != hostData.size()) + throw std::invalid_argument( + "prepareDataForCuda: two input sts::vectors should same sizes !"); + + cudaError_t cudaResult; + + void* reductionPointer; + cudaResult = + cudaMalloc(reinterpret_cast(&reductionPointer), 1024 * 1024); + if (cudaResult != 0) return cudaResult; + int* allocationPointer; + cudaResult = + cudaMalloc(reinterpret_cast(&allocationPointer), 1024 * 1024); + if (cudaResult != 0) return cudaResult; + + lc.setReductionPointer(reductionPointer); + lc.setAllocationPointer(allocationPointer); + cudaStream_t stream = *lc.getCudaStream(); + + for (int i = 0; i < devicePtrs.size(); ++i) { + cudaResult = cudaMalloc(reinterpret_cast(&devicePtrs[i]), + hostData[i].second); + if (cudaResult != 0) return cudaResult; + cudaMemcpyAsync(devicePtrs[i], hostData[i].first, hostData[i].second, + cudaMemcpyHostToDevice, stream); + } + return cudaResult; } TEST_F(NDArrayCudaBasicsTests, Test_Registration_1) { - auto x = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); - auto y = NDArrayFactory::create('c', {5}, {5, 4, 3, 2, 1}); + auto x = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); + auto y = NDArrayFactory::create('c', {5}, {5, 4, 3, 2, 1}); - ASSERT_TRUE(x.isActualOnDeviceSide()); - ASSERT_FALSE(x.isActualOnHostSide()); + ASSERT_TRUE(x.isActualOnDeviceSide()); + ASSERT_FALSE(x.isActualOnHostSide()); } TEST_F(NDArrayCudaBasicsTests, Test_Registration_2) { - auto x = NDArrayFactory::create('c', {5}); - auto y = NDArrayFactory::create('c', {5}); + auto x = NDArrayFactory::create('c', {5}); + auto y = NDArrayFactory::create('c', {5}); - ASSERT_TRUE(x.isActualOnDeviceSide()); - ASSERT_FALSE(x.isActualOnHostSide()); + ASSERT_TRUE(x.isActualOnDeviceSide()); + ASSERT_FALSE(x.isActualOnHostSide()); } TEST_F(NDArrayCudaBasicsTests, Test_Registration_3) { - auto x = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); - auto y = NDArrayFactory::create('c', {5}, {5, 4, 3, 2, 1}); + auto x = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); + auto y = NDArrayFactory::create('c', {5}, {5, 4, 3, 2, 1}); - ASSERT_TRUE(x.isActualOnDeviceSide()); - ASSERT_FALSE(x.isActualOnHostSide()); + ASSERT_TRUE(x.isActualOnDeviceSide()); + ASSERT_FALSE(x.isActualOnHostSide()); - NDArray::registerSpecialUse({&x}, {&y}); + NDArray::registerSpecialUse({&x}, {&y}); - ASSERT_TRUE(x.isActualOnDeviceSide()); - ASSERT_FALSE(x.isActualOnHostSide()); + ASSERT_TRUE(x.isActualOnDeviceSide()); + ASSERT_FALSE(x.isActualOnHostSide()); - ASSERT_TRUE(y.isActualOnDeviceSide()); - ASSERT_FALSE(y.isActualOnHostSide()); + ASSERT_TRUE(y.isActualOnDeviceSide()); + ASSERT_FALSE(y.isActualOnHostSide()); } TEST_F(NDArrayCudaBasicsTests, Test_Registration_01) { - auto x = NDArrayFactory::create_('c', {5}, {1, 2, 3, 4, 5}); - auto y = NDArrayFactory::create_('c', {5}, {5, 4, 3, 2, 1}); + auto x = NDArrayFactory::create_('c', {5}, {1, 2, 3, 4, 5}); + auto y = NDArrayFactory::create_('c', {5}, {5, 4, 3, 2, 1}); - ASSERT_TRUE(x->isActualOnDeviceSide()); - ASSERT_FALSE(x->isActualOnHostSide()); - delete x; - delete y; + ASSERT_TRUE(x->isActualOnDeviceSide()); + ASSERT_FALSE(x->isActualOnHostSide()); + delete x; + delete y; } TEST_F(NDArrayCudaBasicsTests, Test_Registration_02) { - auto x = NDArrayFactory::create_('c', {5}); - auto y = NDArrayFactory::create_('c', {5}); + auto x = NDArrayFactory::create_('c', {5}); + auto y = NDArrayFactory::create_('c', {5}); - ASSERT_TRUE(x->isActualOnDeviceSide()); - ASSERT_FALSE(x->isActualOnHostSide()); - delete x; - delete y; + ASSERT_TRUE(x->isActualOnDeviceSide()); + ASSERT_FALSE(x->isActualOnHostSide()); + delete x; + delete y; } TEST_F(NDArrayCudaBasicsTests, Test_Registration_03) { - auto x = NDArrayFactory::create_('c', {5}, {1, 2, 3, 4, 5}); - auto y = NDArrayFactory::create_('c', {5}, {5, 4, 3, 2, 1}); + auto x = NDArrayFactory::create_('c', {5}, {1, 2, 3, 4, 5}); + auto y = NDArrayFactory::create_('c', {5}, {5, 4, 3, 2, 1}); - ASSERT_TRUE(x->isActualOnDeviceSide()); - ASSERT_FALSE(x->isActualOnHostSide()); + ASSERT_TRUE(x->isActualOnDeviceSide()); + ASSERT_FALSE(x->isActualOnHostSide()); - NDArray::registerSpecialUse({y}, {x}); - x->applyTransform(transform::Neg, *y); - //ASSERT_TRUE(x->isActualOnDeviceSide()); - //ASSERT_FALSE(x->isActualOnHostSide()); + NDArray::registerSpecialUse({y}, {x}); + x->applyTransform(transform::Neg, *y); + // ASSERT_TRUE(x->isActualOnDeviceSide()); + // ASSERT_FALSE(x->isActualOnHostSide()); - //ASSERT_TRUE(y->isActualOnDeviceSide()); - //ASSERT_TRUE(y->isActualOnHostSide()); - //y->syncToHost(); - // y->printBuffer("Negatives"); - delete x; - delete y; + // ASSERT_TRUE(y->isActualOnDeviceSide()); + // ASSERT_TRUE(y->isActualOnHostSide()); + // y->syncToHost(); + // y->printBuffer("Negatives"); + delete x; + delete y; } TEST_F(NDArrayCudaBasicsTests, Test_Cosine_1) { - auto x = NDArrayFactory::create_('c', {5}, {1, 2, 3, 4, 5}); - auto y = NDArrayFactory::create_('c', {5}, {5, 4, 3, 2, 1}); + auto x = NDArrayFactory::create_('c', {5}, {1, 2, 3, 4, 5}); + auto y = NDArrayFactory::create_('c', {5}, {5, 4, 3, 2, 1}); - ASSERT_TRUE(x->isActualOnDeviceSide()); - ASSERT_FALSE(x->isActualOnHostSide()); + ASSERT_TRUE(x->isActualOnDeviceSide()); + ASSERT_FALSE(x->isActualOnHostSide()); - NDArray::registerSpecialUse({y}, {x}); - x->applyTransform(transform::Cosine, *y); - //ASSERT_TRUE(x->isActualOnDeviceSide()); - //ASSERT_FALSE(x->isActualOnHostSide()); + NDArray::registerSpecialUse({y}, {x}); + x->applyTransform(transform::Cosine, *y); + // ASSERT_TRUE(x->isActualOnDeviceSide()); + // ASSERT_FALSE(x->isActualOnHostSide()); - //ASSERT_TRUE(y->isActualOnDeviceSide()); - //ASSERT_TRUE(y->isActualOnHostSide()); - //y->syncToHost(); - delete x; - delete y; + // ASSERT_TRUE(y->isActualOnDeviceSide()); + // ASSERT_TRUE(y->isActualOnHostSide()); + // y->syncToHost(); + delete x; + delete y; } ////////////////////////////////////////////////////////////////////////// TEST_F(NDArrayCudaBasicsTests, TestAdd_1) { - // allocating host-side arrays - auto x = NDArrayFactory::create('c', { 5 }, { 1, 2, 3, 4, 5}); - auto y = NDArrayFactory::create('c', { 5 }, { 1, 2, 3, 4, 5}); - auto z = NDArrayFactory::create('c', { 5 }, {10, 10, 10, 10, 10}); - - auto exp = NDArrayFactory::create('c', { 5 }, { 2, 4, 6, 8, 10 }); - - // making raw buffers - //Nd4jPointer devBufferPtrX, devBufferPtrZ, devShapePtrX; - //cudaError_t res = cudaMalloc(reinterpret_cast(&devBufferPtrX), x.lengthOf() * x.sizeOfT()); - //ASSERT_EQ(0, res); - //res = cudaMalloc(reinterpret_cast(&devBufferPtrZ), x.lengthOf() * x.sizeOfT()); - //ASSERT_EQ(0, res); - //res = cudaMalloc(reinterpret_cast(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); - //ASSERT_EQ(0, res); - - Nd4jPointer nativeStream = (Nd4jPointer)malloc(sizeof(cudaStream_t)); - CHECK_ALLOC(nativeStream, "Failed to allocate memory for new CUDA stream", sizeof(cudaStream_t)); - cudaError_t dZ = cudaStreamCreate(reinterpret_cast(&nativeStream)); - auto stream = reinterpret_cast(&nativeStream); - - //cudaMemcpyAsync(devBufferPtrX, x.buffer(), x.lengthOf() * x.sizeOfT(), cudaMemcpyHostToDevice, *stream); - //cudaMemcpyAsync(devShapePtrX, x.shapeInfo(), shape::shapeInfoByteLength(x.shapeInfo()), cudaMemcpyHostToDevice, *stream); - - LaunchContext lc(stream, nullptr, nullptr); - NativeOpExecutioner::execPairwiseTransform(&lc, pairwise::Add, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr); - z.tickWriteDevice(); - auto res = cudaStreamSynchronize(*stream); - ASSERT_EQ(0, res); - - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + // allocating host-side arrays + auto x = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); + auto y = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); + auto z = NDArrayFactory::create('c', {5}, {10, 10, 10, 10, 10}); + + auto exp = NDArrayFactory::create('c', {5}, {2, 4, 6, 8, 10}); + + // making raw buffers + // Nd4jPointer devBufferPtrX, devBufferPtrZ, devShapePtrX; + // cudaError_t res = cudaMalloc(reinterpret_cast(&devBufferPtrX), + // x.lengthOf() * x.sizeOfT()); ASSERT_EQ(0, res); res = + // cudaMalloc(reinterpret_cast(&devBufferPtrZ), x.lengthOf() * + // x.sizeOfT()); ASSERT_EQ(0, res); res = cudaMalloc(reinterpret_cast(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); + // ASSERT_EQ(0, res); + + Nd4jPointer nativeStream = (Nd4jPointer)malloc(sizeof(cudaStream_t)); + CHECK_ALLOC(nativeStream, "Failed to allocate memory for new CUDA stream", + sizeof(cudaStream_t)); + cudaError_t dZ = + cudaStreamCreate(reinterpret_cast(&nativeStream)); + auto stream = reinterpret_cast(&nativeStream); + + // cudaMemcpyAsync(devBufferPtrX, x.buffer(), x.lengthOf() * x.sizeOfT(), + // cudaMemcpyHostToDevice, *stream); cudaMemcpyAsync(devShapePtrX, + // x.shapeInfo(), shape::shapeInfoByteLength(x.shapeInfo()), + // cudaMemcpyHostToDevice, *stream); + + LaunchContext lc(stream, nullptr, nullptr); + NativeOpExecutioner::execPairwiseTransform( + &lc, pairwise::Add, x.buffer(), x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), y.buffer(), y.shapeInfo(), y.specialBuffer(), + y.specialShapeInfo(), z.buffer(), z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo(), nullptr); + z.tickWriteDevice(); + auto res = cudaStreamSynchronize(*stream); + ASSERT_EQ(0, res); + + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); } ////////////////////////////////////////////////////////////////////////// TEST_F(NDArrayCudaBasicsTests, TestAdd_2) { - // allocating host-side arrays - NDArray x('c', { 5 }, { 1, 2, 3, 4, 5}); - NDArray y('c', { 5 }, { 1, 2, 3, 4, 5}); - NDArray z('c', { 5 }, sd::DataType::DOUBLE); + // allocating host-side arrays + NDArray x('c', {5}, {1, 2, 3, 4, 5}); + NDArray y('c', {5}, {1, 2, 3, 4, 5}); + NDArray z('c', {5}, sd::DataType::DOUBLE); - NDArray exp('c', { 5 }, { 2, 4, 6, 8, 10 }); + NDArray exp('c', {5}, {2, 4, 6, 8, 10}); - Nd4jPointer nativeStream = (Nd4jPointer)malloc(sizeof(cudaStream_t)); - CHECK_ALLOC(nativeStream, "Failed to allocate memory for new CUDA stream", sizeof(cudaStream_t)); - cudaError_t dZ = cudaStreamCreate(reinterpret_cast(&nativeStream)); - auto stream = reinterpret_cast(&nativeStream); + Nd4jPointer nativeStream = (Nd4jPointer)malloc(sizeof(cudaStream_t)); + CHECK_ALLOC(nativeStream, "Failed to allocate memory for new CUDA stream", + sizeof(cudaStream_t)); + cudaError_t dZ = + cudaStreamCreate(reinterpret_cast(&nativeStream)); + auto stream = reinterpret_cast(&nativeStream); - LaunchContext lc(stream, *stream, nullptr, nullptr); - NativeOpExecutioner::execPairwiseTransform(&lc, pairwise::Add, nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), nullptr, y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr); - auto res = cudaStreamSynchronize(*stream); - ASSERT_EQ(0, res); + LaunchContext lc(stream, *stream, nullptr, nullptr); + NativeOpExecutioner::execPairwiseTransform( + &lc, pairwise::Add, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, y.shapeInfo(), y.specialBuffer(), + y.specialShapeInfo(), nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo(), nullptr); + auto res = cudaStreamSynchronize(*stream); + ASSERT_EQ(0, res); - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); } ////////////////////////////////////////////////////////////////////////// TEST_F(NDArrayCudaBasicsTests, TestAdd_3) { - // allocating host-side arrays - auto x = NDArrayFactory::create('c', { 5 }, { 1, 2, 3, 4, 5}); - auto y = NDArrayFactory::create('c', { 5 }, { 1, 2, 3, 4, 5}); - auto z = NDArrayFactory::create('c', { 5 }, {10, 10, 10, 10, 10}); - - auto exp = NDArrayFactory::create('c', { 5 }, { 2, 4, 6, 8, 10 }); - - // making raw buffers - //Nd4jPointer devBufferPtrX, devBufferPtrZ, devShapePtrX; - //cudaError_t res = cudaMalloc(reinterpret_cast(&devBufferPtrX), x.lengthOf() * x.sizeOfT()); - //ASSERT_EQ(0, res); - //res = cudaMalloc(reinterpret_cast(&devBufferPtrZ), x.lengthOf() * x.sizeOfT()); - //ASSERT_EQ(0, res); - //res = cudaMalloc(reinterpret_cast(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); - //ASSERT_EQ(0, res); - - Nd4jPointer nativeStream = (Nd4jPointer)malloc(sizeof(cudaStream_t)); - CHECK_ALLOC(nativeStream, "Failed to allocate memory for new CUDA stream", sizeof(cudaStream_t)); - cudaError_t dZ = cudaStreamCreate(reinterpret_cast(&nativeStream)); - auto stream = reinterpret_cast(&nativeStream); - - //cudaMemcpyAsync(devBufferPtrX, x.buffer(), x.lengthOf() * x.sizeOfT(), cudaMemcpyHostToDevice, *stream); - //cudaMemcpyAsync(devShapePtrX, x.shapeInfo(), shape::shapeInfoByteLength(x.shapeInfo()), cudaMemcpyHostToDevice, *stream); - - LaunchContext lc(stream, *stream, nullptr, nullptr); - NativeOpExecutioner::execPairwiseTransform(&lc, pairwise::Add, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr); - z.tickWriteDevice(); - auto res = cudaStreamSynchronize(*stream); - ASSERT_EQ(0, res); - //double* localBuffer = ; - z.syncToHost(); - cudaMemcpy(z.buffer(), z.specialBuffer(), z.lengthOf() * z.sizeOfT(), cudaMemcpyDeviceToHost); - res = cudaStreamSynchronize(*stream); - z.tickWriteHost(); - ASSERT_EQ(0, res); - - // - // cudaFree(devBufferPtrX); - //cudaFree(devBufferPtrZ); - //cudaFree(devShapePtrX); - - for (int e = 0; e < z.lengthOf(); e++) { - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - } + // allocating host-side arrays + auto x = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); + auto y = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); + auto z = NDArrayFactory::create('c', {5}, {10, 10, 10, 10, 10}); + + auto exp = NDArrayFactory::create('c', {5}, {2, 4, 6, 8, 10}); + + // making raw buffers + // Nd4jPointer devBufferPtrX, devBufferPtrZ, devShapePtrX; + // cudaError_t res = cudaMalloc(reinterpret_cast(&devBufferPtrX), + // x.lengthOf() * x.sizeOfT()); ASSERT_EQ(0, res); res = + // cudaMalloc(reinterpret_cast(&devBufferPtrZ), x.lengthOf() * + // x.sizeOfT()); ASSERT_EQ(0, res); res = cudaMalloc(reinterpret_cast(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); + // ASSERT_EQ(0, res); + + Nd4jPointer nativeStream = (Nd4jPointer)malloc(sizeof(cudaStream_t)); + CHECK_ALLOC(nativeStream, "Failed to allocate memory for new CUDA stream", + sizeof(cudaStream_t)); + cudaError_t dZ = + cudaStreamCreate(reinterpret_cast(&nativeStream)); + auto stream = reinterpret_cast(&nativeStream); + + // cudaMemcpyAsync(devBufferPtrX, x.buffer(), x.lengthOf() * x.sizeOfT(), + // cudaMemcpyHostToDevice, *stream); cudaMemcpyAsync(devShapePtrX, + // x.shapeInfo(), shape::shapeInfoByteLength(x.shapeInfo()), + // cudaMemcpyHostToDevice, *stream); + + LaunchContext lc(stream, *stream, nullptr, nullptr); + NativeOpExecutioner::execPairwiseTransform( + &lc, pairwise::Add, x.buffer(), x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), y.buffer(), y.shapeInfo(), y.specialBuffer(), + y.specialShapeInfo(), z.buffer(), z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo(), nullptr); + z.tickWriteDevice(); + auto res = cudaStreamSynchronize(*stream); + ASSERT_EQ(0, res); + // double* localBuffer = ; + z.syncToHost(); + cudaMemcpy(z.buffer(), z.specialBuffer(), z.lengthOf() * z.sizeOfT(), + cudaMemcpyDeviceToHost); + res = cudaStreamSynchronize(*stream); + z.tickWriteHost(); + ASSERT_EQ(0, res); + + // + // cudaFree(devBufferPtrX); + // cudaFree(devBufferPtrZ); + // cudaFree(devShapePtrX); + + for (int e = 0; e < z.lengthOf(); e++) { + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + } } ////////////////////////////////////////////////////////////////////////// TEST_F(NDArrayCudaBasicsTests, TestAdd_4) { - // allocating host-side arrays - auto x = NDArrayFactory::create('c', { 5 }, { 1, 2, 3, 4, 5}); - auto y = NDArrayFactory::create('c', { 5 }, { 1, 2, 3, 4, 5}); - auto z = NDArrayFactory::create('c', { 5 }); - - auto exp = NDArrayFactory::create('c', { 5 }, { 2, 4, 6, 8, 10 }); - - // making raw buffers - //Nd4jPointer devBufferPtrX, devBufferPtrZ, devShapePtrX; - //cudaError_t res = cudaMalloc(reinterpret_cast(&devBufferPtrX), x.lengthOf() * x.sizeOfT()); - //ASSERT_EQ(0, res); - //res = cudaMalloc(reinterpret_cast(&devBufferPtrZ), x.lengthOf() * x.sizeOfT()); - //ASSERT_EQ(0, res); - //res = cudaMalloc(reinterpret_cast(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); - //ASSERT_EQ(0, res); - x.applyPairwiseTransform(pairwise::Add, y, z); - - // - // cudaFree(devBufferPtrX); - //cudaFree(devBufferPtrZ); - //cudaFree(devShapePtrX); - - for (int e = 0; e < z.lengthOf(); e++) { - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - } + // allocating host-side arrays + auto x = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); + auto y = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); + auto z = NDArrayFactory::create('c', {5}); + + auto exp = NDArrayFactory::create('c', {5}, {2, 4, 6, 8, 10}); + + // making raw buffers + // Nd4jPointer devBufferPtrX, devBufferPtrZ, devShapePtrX; + // cudaError_t res = cudaMalloc(reinterpret_cast(&devBufferPtrX), + // x.lengthOf() * x.sizeOfT()); ASSERT_EQ(0, res); res = + // cudaMalloc(reinterpret_cast(&devBufferPtrZ), x.lengthOf() * + // x.sizeOfT()); ASSERT_EQ(0, res); res = cudaMalloc(reinterpret_cast(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); + // ASSERT_EQ(0, res); + x.applyPairwiseTransform(pairwise::Add, y, z); + + // + // cudaFree(devBufferPtrX); + // cudaFree(devBufferPtrZ); + // cudaFree(devShapePtrX); + + for (int e = 0; e < z.lengthOf(); e++) { + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + } } ////////////////////////////////////////////////////////////////////////// TEST_F(NDArrayCudaBasicsTests, TestAdd_5) { - // allocating host-side arrays - auto x = NDArrayFactory::create('c', { 5 }, { 1, 2, 3, 4, 5}); - auto y = NDArrayFactory::create('c', { 5 }, { 1, 2, 3, 4, 5}); - //auto z = NDArrayFactory::create('c', { 5 }); - - auto exp = NDArrayFactory::create('c', { 5 }, { 2, 4, 6, 8, 10 }); - - // making raw buffers - //Nd4jPointer devBufferPtrX, devBufferPtrZ, devShapePtrX; - //cudaError_t res = cudaMalloc(reinterpret_cast(&devBufferPtrX), x.lengthOf() * x.sizeOfT()); - //ASSERT_EQ(0, res); - //res = cudaMalloc(reinterpret_cast(&devBufferPtrZ), x.lengthOf() * x.sizeOfT()); - //ASSERT_EQ(0, res); - //res = cudaMalloc(reinterpret_cast(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); - //ASSERT_EQ(0, res); - x += y; - //x.applyPairwiseTransform(pairwise::Add, &y, &z, nullptr); - x.syncToHost(); - //y.printBuffer("3Y = "); - //z.printBuffer("3Result out"); - - // - // cudaFree(devBufferPtrX); - //cudaFree(devBufferPtrZ); - //cudaFree(devShapePtrX); - - for (int e = 0; e < x.lengthOf(); e++) { - ASSERT_NEAR(exp.e(e), x.e(e), 1e-5); - } + // allocating host-side arrays + auto x = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); + auto y = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); + // auto z = NDArrayFactory::create('c', { 5 }); + + auto exp = NDArrayFactory::create('c', {5}, {2, 4, 6, 8, 10}); + + // making raw buffers + // Nd4jPointer devBufferPtrX, devBufferPtrZ, devShapePtrX; + // cudaError_t res = cudaMalloc(reinterpret_cast(&devBufferPtrX), + // x.lengthOf() * x.sizeOfT()); ASSERT_EQ(0, res); res = + // cudaMalloc(reinterpret_cast(&devBufferPtrZ), x.lengthOf() * + // x.sizeOfT()); ASSERT_EQ(0, res); res = cudaMalloc(reinterpret_cast(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); + // ASSERT_EQ(0, res); + x += y; + // x.applyPairwiseTransform(pairwise::Add, &y, &z, nullptr); + x.syncToHost(); + // y.printBuffer("3Y = "); + // z.printBuffer("3Result out"); + + // + // cudaFree(devBufferPtrX); + // cudaFree(devBufferPtrZ); + // cudaFree(devShapePtrX); + + for (int e = 0; e < x.lengthOf(); e++) { + ASSERT_NEAR(exp.e(e), x.e(e), 1e-5); + } } ////////////////////////////////////////////////////////////////////////// TEST_F(NDArrayCudaBasicsTests, TestAdd_6) { - // allocating host-side arrays - auto x = NDArrayFactory::create('c', { 5 }, { 1, 2, 3, 4, 5}); - auto y = NDArrayFactory::create(2); //.'c', { 5 }, { 1, 2, 3, 4, 5}); - //auto z = NDArrayFactory::create('c', { 5 }); - - auto exp = NDArrayFactory::create('c', { 5 }, { 3, 4, 5, 6, 7 }); - - // making raw buffers - //Nd4jPointer devBufferPtrX, devBufferPtrZ, devShapePtrX; - //cudaError_t res = cudaMalloc(reinterpret_cast(&devBufferPtrX), x.lengthOf() * x.sizeOfT()); - //ASSERT_EQ(0, res); - //res = cudaMalloc(reinterpret_cast(&devBufferPtrZ), x.lengthOf() * x.sizeOfT()); - //ASSERT_EQ(0, res); - //res = cudaMalloc(reinterpret_cast(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); - //ASSERT_EQ(0, res); - x += y; - //x.applyPairwiseTransform(pairwise::Add, &y, &z, nullptr); - x.syncToHost(); - - // cudaFree(devBufferPtrX); - //cudaFree(devBufferPtrZ); - //cudaFree(devShapePtrX); - - for (int e = 0; e < x.lengthOf(); e++) { - ASSERT_NEAR(exp.e(e), x.e(e), 1e-5); - } + // allocating host-side arrays + auto x = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); + auto y = NDArrayFactory::create(2); //.'c', { 5 }, { 1, 2, 3, 4, 5}); + // auto z = NDArrayFactory::create('c', { 5 }); + + auto exp = NDArrayFactory::create('c', {5}, {3, 4, 5, 6, 7}); + + // making raw buffers + // Nd4jPointer devBufferPtrX, devBufferPtrZ, devShapePtrX; + // cudaError_t res = cudaMalloc(reinterpret_cast(&devBufferPtrX), + // x.lengthOf() * x.sizeOfT()); ASSERT_EQ(0, res); res = + // cudaMalloc(reinterpret_cast(&devBufferPtrZ), x.lengthOf() * + // x.sizeOfT()); ASSERT_EQ(0, res); res = cudaMalloc(reinterpret_cast(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); + // ASSERT_EQ(0, res); + x += y; + // x.applyPairwiseTransform(pairwise::Add, &y, &z, nullptr); + x.syncToHost(); + + // cudaFree(devBufferPtrX); + // cudaFree(devBufferPtrZ); + // cudaFree(devShapePtrX); + + for (int e = 0; e < x.lengthOf(); e++) { + ASSERT_NEAR(exp.e(e), x.e(e), 1e-5); + } } ////////////////////////////////////////////////////////////////////////// TEST_F(NDArrayCudaBasicsTests, TestAdd_7) { - // allocating host-side arrays - auto x = NDArrayFactory::create('c', { 5 }, { 1, 2, 3, 4, 5}); - //auto y = NDArrayFactory::create(2); //.'c', { 5 }, { 1, 2, 3, 4, 5}); - //auto z = NDArrayFactory::create('c', { 5 }); - - auto exp = NDArrayFactory::create('c', { 5 }, { 3, 4, 5, 6, 7 }); - - // making raw buffers - //Nd4jPointer devBufferPtrX, devBufferPtrZ, devShapePtrX; - //cudaError_t res = cudaMalloc(reinterpret_cast(&devBufferPtrX), x.lengthOf() * x.sizeOfT()); - //ASSERT_EQ(0, res); - //res = cudaMalloc(reinterpret_cast(&devBufferPtrZ), x.lengthOf() * x.sizeOfT()); - //ASSERT_EQ(0, res); - //res = cudaMalloc(reinterpret_cast(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); - //ASSERT_EQ(0, res); - x += 2.; - //x.applyPairwiseTransform(pairwise::Add, &y, &z, nullptr); - x.syncToHost(); - - // cudaFree(devBufferPtrX); - //cudaFree(devBufferPtrZ); - //cudaFree(devShapePtrX); - - for (int e = 0; e < x.lengthOf(); e++) { - ASSERT_NEAR(exp.e(e), x.e(e), 1e-5); - } + // allocating host-side arrays + auto x = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); + // auto y = NDArrayFactory::create(2); //.'c', { 5 }, { 1, 2, 3, 4, + // 5}); auto z = NDArrayFactory::create('c', { 5 }); + + auto exp = NDArrayFactory::create('c', {5}, {3, 4, 5, 6, 7}); + + // making raw buffers + // Nd4jPointer devBufferPtrX, devBufferPtrZ, devShapePtrX; + // cudaError_t res = cudaMalloc(reinterpret_cast(&devBufferPtrX), + // x.lengthOf() * x.sizeOfT()); ASSERT_EQ(0, res); res = + // cudaMalloc(reinterpret_cast(&devBufferPtrZ), x.lengthOf() * + // x.sizeOfT()); ASSERT_EQ(0, res); res = cudaMalloc(reinterpret_cast(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); + // ASSERT_EQ(0, res); + x += 2.; + // x.applyPairwiseTransform(pairwise::Add, &y, &z, nullptr); + x.syncToHost(); + + // cudaFree(devBufferPtrX); + // cudaFree(devBufferPtrZ); + // cudaFree(devShapePtrX); + + for (int e = 0; e < x.lengthOf(); e++) { + ASSERT_NEAR(exp.e(e), x.e(e), 1e-5); + } } ////////////////////////////////////////////////////////////////////////// TEST_F(NDArrayCudaBasicsTests, TestMultiply_1) { - // allocating host-side arrays - auto x = NDArrayFactory::create('c', { 5 }, { 1, 2, 3, 4, 5}); - auto y = NDArrayFactory::create('c', { 5 }, { 1, 2, 3, 4, 5}); - auto z = NDArrayFactory::create('c', { 5 }); - - auto exp = NDArrayFactory::create('c', { 5 }, { 1, 4, 9, 16, 25 }); - - // making raw buffers - //Nd4jPointer devBufferPtrX, devBufferPtrZ, devShapePtrX; - //cudaError_t res = cudaMalloc(reinterpret_cast(&devBufferPtrX), x.lengthOf() * x.sizeOfT()); - //ASSERT_EQ(0, res); - //res = cudaMalloc(reinterpret_cast(&devBufferPtrZ), x.lengthOf() * x.sizeOfT()); - //ASSERT_EQ(0, res); - //res = cudaMalloc(reinterpret_cast(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); - //ASSERT_EQ(0, res); - x.applyPairwiseTransform(pairwise::Multiply, y, z); - // x.printBuffer("3X = "); - // y.printBuffer("3Y = "); - // z.printBuffer("3Result out"); - - // - // cudaFree(devBufferPtrX); - //cudaFree(devBufferPtrZ); - //cudaFree(devShapePtrX); - - for (int e = 0; e < z.lengthOf(); e++) { - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - } + // allocating host-side arrays + auto x = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); + auto y = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); + auto z = NDArrayFactory::create('c', {5}); + + auto exp = NDArrayFactory::create('c', {5}, {1, 4, 9, 16, 25}); + + // making raw buffers + // Nd4jPointer devBufferPtrX, devBufferPtrZ, devShapePtrX; + // cudaError_t res = cudaMalloc(reinterpret_cast(&devBufferPtrX), + // x.lengthOf() * x.sizeOfT()); ASSERT_EQ(0, res); res = + // cudaMalloc(reinterpret_cast(&devBufferPtrZ), x.lengthOf() * + // x.sizeOfT()); ASSERT_EQ(0, res); res = cudaMalloc(reinterpret_cast(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); + // ASSERT_EQ(0, res); + x.applyPairwiseTransform(pairwise::Multiply, y, z); + // x.printBuffer("3X = "); + // y.printBuffer("3Y = "); + // z.printBuffer("3Result out"); + + // + // cudaFree(devBufferPtrX); + // cudaFree(devBufferPtrZ); + // cudaFree(devShapePtrX); + + for (int e = 0; e < z.lengthOf(); e++) { + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + } } ////////////////////////////////////////////////////////////////////////// TEST_F(NDArrayCudaBasicsTests, TestMultiply_2) { - // allocating host-side arrays - auto x = NDArrayFactory::create('c', { 5 }, { 1, 2, 3, 4, 5}); - auto y = NDArrayFactory::create('c', { 5 }, { 1, 2, 3, 4, 5}); - NDArray z('c', { 5 }, sd::DataType::DOUBLE); - - auto exp = NDArrayFactory::create('c', { 5 }, { 1, 4, 9, 16, 25 }); - - // making raw buffers - //Nd4jPointer devBufferPtrX, devBufferPtrZ, devShapePtrX; - //cudaError_t res = cudaMalloc(reinterpret_cast(&devBufferPtrX), x.lengthOf() * x.sizeOfT()); - //ASSERT_EQ(0, res); - //res = cudaMalloc(reinterpret_cast(&devBufferPtrZ), x.lengthOf() * x.sizeOfT()); - //ASSERT_EQ(0, res); - //res = cudaMalloc(reinterpret_cast(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); - //ASSERT_EQ(0, res); - x.applyPairwiseTransform(pairwise::Multiply, y, z); - - // - // cudaFree(devBufferPtrX); - //cudaFree(devBufferPtrZ); - //cudaFree(devShapePtrX); - - for (int e = 0; e < z.lengthOf(); e++) { - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - } + // allocating host-side arrays + auto x = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); + auto y = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); + NDArray z('c', {5}, sd::DataType::DOUBLE); + + auto exp = NDArrayFactory::create('c', {5}, {1, 4, 9, 16, 25}); + + // making raw buffers + // Nd4jPointer devBufferPtrX, devBufferPtrZ, devShapePtrX; + // cudaError_t res = cudaMalloc(reinterpret_cast(&devBufferPtrX), + // x.lengthOf() * x.sizeOfT()); ASSERT_EQ(0, res); res = + // cudaMalloc(reinterpret_cast(&devBufferPtrZ), x.lengthOf() * + // x.sizeOfT()); ASSERT_EQ(0, res); res = cudaMalloc(reinterpret_cast(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); + // ASSERT_EQ(0, res); + x.applyPairwiseTransform(pairwise::Multiply, y, z); + + // + // cudaFree(devBufferPtrX); + // cudaFree(devBufferPtrZ); + // cudaFree(devShapePtrX); + + for (int e = 0; e < z.lengthOf(); e++) { + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + } } ////////////////////////////////////////////////////////////////////////// TEST_F(NDArrayCudaBasicsTests, TestMultiply_3) { - // allocating host-side arrays - NDArray x('c', { 5 }, { 1, 2, 3, 4, 5}, sd::DataType::DOUBLE); - NDArray y('c', { 5 }, { 1., 2., 3., 4., 5.}, sd::DataType::DOUBLE); - auto z = NDArrayFactory::create('c', { 5 }); - - auto exp = NDArrayFactory::create('c', { 5 }, { 1, 4, 9, 16, 25 }); - - // making raw buffers - //Nd4jPointer devBufferPtrX, devBufferPtrZ, devShapePtrX; - //cudaError_t res = cudaMalloc(reinterpret_cast(&devBufferPtrX), x.lengthOf() * x.sizeOfT()); - //ASSERT_EQ(0, res); - //res = cudaMalloc(reinterpret_cast(&devBufferPtrZ), x.lengthOf() * x.sizeOfT()); - //ASSERT_EQ(0, res); - //res = cudaMalloc(reinterpret_cast(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); - //ASSERT_EQ(0, res); - x.applyPairwiseTransform(pairwise::Multiply, y, z); - //x.printBuffer("23X = "); - //y.printBuffer("23Y = "); - // z.printBuffer("23Result out"); - - // - // cudaFree(devBufferPtrX); - //cudaFree(devBufferPtrZ); - //cudaFree(devShapePtrX); - - for (int e = 0; e < z.lengthOf(); e++) { - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - } + // allocating host-side arrays + NDArray x('c', {5}, {1, 2, 3, 4, 5}, sd::DataType::DOUBLE); + NDArray y('c', {5}, {1., 2., 3., 4., 5.}, sd::DataType::DOUBLE); + auto z = NDArrayFactory::create('c', {5}); + + auto exp = NDArrayFactory::create('c', {5}, {1, 4, 9, 16, 25}); + + // making raw buffers + // Nd4jPointer devBufferPtrX, devBufferPtrZ, devShapePtrX; + // cudaError_t res = cudaMalloc(reinterpret_cast(&devBufferPtrX), + // x.lengthOf() * x.sizeOfT()); ASSERT_EQ(0, res); res = + // cudaMalloc(reinterpret_cast(&devBufferPtrZ), x.lengthOf() * + // x.sizeOfT()); ASSERT_EQ(0, res); res = cudaMalloc(reinterpret_cast(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); + // ASSERT_EQ(0, res); + x.applyPairwiseTransform(pairwise::Multiply, y, z); + // x.printBuffer("23X = "); + // y.printBuffer("23Y = "); + // z.printBuffer("23Result out"); + + // + // cudaFree(devBufferPtrX); + // cudaFree(devBufferPtrZ); + // cudaFree(devShapePtrX); + + for (int e = 0; e < z.lengthOf(); e++) { + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + } } ////////////////////////////////////////////////////////////////////////// TEST_F(NDArrayCudaBasicsTests, TestMultiply_4) { - // allocating host-side arrays - NDArray x('c', { 5 }, { 1, 2, 3, 4, 5}, sd::DataType::DOUBLE); - NDArray y('c', { 5 }, { 1., 2., 3., 4., 5.}, sd::DataType::DOUBLE); - //auto z = NDArrayFactory::create('c', { 5 }); - - auto exp = NDArrayFactory::create('c', { 5 }, { 1, 4, 9, 16, 25 }); - - // making raw buffers - //Nd4jPointer devBufferPtrX, devBufferPtrZ, devShapePtrX; - //cudaError_t res = cudaMalloc(reinterpret_cast(&devBufferPtrX), x.lengthOf() * x.sizeOfT()); - //ASSERT_EQ(0, res); - //res = cudaMalloc(reinterpret_cast(&devBufferPtrZ), x.lengthOf() * x.sizeOfT()); - //ASSERT_EQ(0, res); - //res = cudaMalloc(reinterpret_cast(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); - //ASSERT_EQ(0, res); - //x.applyPairwiseTransform(pairwise::Multiply, &y, &z, nullptr); - //x.printBuffer("23X = "); - //y.printBuffer("23Y = "); - x *= y; - //x.tickWriteDevice(); - // x.printBuffer("33Result out"); - - // - // cudaFree(devBufferPtrX); - //cudaFree(devBufferPtrZ); - //cudaFree(devShapePtrX); - - for (int e = 0; e < x.lengthOf(); e++) { - ASSERT_NEAR(exp.e(e), x.e(e), 1e-5); - } + // allocating host-side arrays + NDArray x('c', {5}, {1, 2, 3, 4, 5}, sd::DataType::DOUBLE); + NDArray y('c', {5}, {1., 2., 3., 4., 5.}, sd::DataType::DOUBLE); + // auto z = NDArrayFactory::create('c', { 5 }); + + auto exp = NDArrayFactory::create('c', {5}, {1, 4, 9, 16, 25}); + + // making raw buffers + // Nd4jPointer devBufferPtrX, devBufferPtrZ, devShapePtrX; + // cudaError_t res = cudaMalloc(reinterpret_cast(&devBufferPtrX), + // x.lengthOf() * x.sizeOfT()); ASSERT_EQ(0, res); res = + // cudaMalloc(reinterpret_cast(&devBufferPtrZ), x.lengthOf() * + // x.sizeOfT()); ASSERT_EQ(0, res); res = cudaMalloc(reinterpret_cast(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); + // ASSERT_EQ(0, res); x.applyPairwiseTransform(pairwise::Multiply, &y, &z, + // nullptr); x.printBuffer("23X = "); y.printBuffer("23Y = "); + x *= y; + // x.tickWriteDevice(); + // x.printBuffer("33Result out"); + + // + // cudaFree(devBufferPtrX); + // cudaFree(devBufferPtrZ); + // cudaFree(devShapePtrX); + + for (int e = 0; e < x.lengthOf(); e++) { + ASSERT_NEAR(exp.e(e), x.e(e), 1e-5); + } } ////////////////////////////////////////////////////////////////////////// TEST_F(NDArrayCudaBasicsTests, TestPrimitiveNeg_01) { - // allocating host-side arrays - auto x = NDArrayFactory::create('c', { 5 }, { 1, 2, 3, 4, 5}); - auto y = NDArrayFactory::create('c', { 5 }, { 1, 2, 3, 4, 5}); - auto exp = NDArrayFactory::create('c', { 5 }, { -1, -2, -3, -4, -5 }); + // allocating host-side arrays + auto x = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); + auto y = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); + auto exp = NDArrayFactory::create('c', {5}, {-1, -2, -3, -4, -5}); - auto stream = x.getContext()->getCudaStream();//reinterpret_cast(&nativeStream); + auto stream = + x.getContext()->getCudaStream(); // reinterpret_cast(&nativeStream); - NativeOpExecutioner::execTransformSame(x.getContext(), transform::Neg, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), nullptr, nullptr, nullptr); - auto res = cudaStreamSynchronize(*stream); - ASSERT_EQ(0, res); - y.tickWriteDevice(); + NativeOpExecutioner::execTransformSame( + x.getContext(), transform::Neg, x.buffer(), x.shapeInfo(), + x.specialBuffer(), x.specialShapeInfo(), y.buffer(), y.shapeInfo(), + y.specialBuffer(), y.specialShapeInfo(), nullptr, nullptr, nullptr); + auto res = cudaStreamSynchronize(*stream); + ASSERT_EQ(0, res); + y.tickWriteDevice(); - // x.printBuffer("X = "); - // y.printBuffer("Y = "); + // x.printBuffer("X = "); + // y.printBuffer("Y = "); - for (int e = 0; e < y.lengthOf(); e++) { - ASSERT_NEAR(exp.e(e), y.e(e), 1e-5); - } + for (int e = 0; e < y.lengthOf(); e++) { + ASSERT_NEAR(exp.e(e), y.e(e), 1e-5); + } } TEST_F(NDArrayCudaBasicsTests, Test_PrimitiveNeg_2) { - auto x = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); - auto y = NDArrayFactory::create('c', {5}); - - ASSERT_TRUE(x.isActualOnDeviceSide()); - ASSERT_FALSE(x.isActualOnHostSide()); - - x.applyTransform(transform::Neg, y); - //ASSERT_TRUE(x->isActualOnDeviceSide()); - //ASSERT_FALSE(x->isActualOnHostSide()); - - //ASSERT_TRUE(y->isActualOnDeviceSide()); - //ASSERT_TRUE(y->isActualOnHostSide()); - //auto res = cudaStreamSynchronize(*y.getContext()->getCudaStream()); - //ASSERT_EQ(0, res); - // y.printBuffer("Negatives2"); - //delete x; - //delete y; -} - -TEST_F(NDArrayCudaBasicsTests, Test_PrimitiveSqrt_1) { // strict - auto x = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); - auto y = NDArrayFactory::create('c', {5}); - auto exp = NDArrayFactory::create({1.000000, 1.414214, 1.732051, 2.000000, 2.236068}); - ASSERT_TRUE(x.isActualOnDeviceSide()); - ASSERT_FALSE(x.isActualOnHostSide()); - - x.applyTransform(transform::Sqrt, y); - //ASSERT_TRUE(x->isActualOnDeviceSide()); - //ASSERT_FALSE(x->isActualOnHostSide()); - - //ASSERT_TRUE(y->isActualOnDeviceSide()); - //ASSERT_TRUE(y->isActualOnHostSide()); - //auto res = cudaStreamSynchronize(*y.getContext()->getCudaStream()); - //ASSERT_EQ(0, res); - ASSERT_TRUE(y.equalsTo(exp)); - //y.printBuffer("SQRT output"); - //delete x; - //delete y; -} - -TEST_F(NDArrayCudaBasicsTests, Test_PrimitiveAssign_1) { // strict - auto x = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); - auto y = NDArrayFactory::create('c', {5}); - //auto exp = NDArrayFactory::create({1.000000, 1.414214, 1.732051, 2.000000, 2.236068}); - //ASSERT_TRUE(x.isActualOnDeviceSide()); - //ASSERT_TRUE(x.isActualOnHostSide()); - - x.applyTransform(transform::Assign, y); - //ASSERT_TRUE(x->isActualOnDeviceSide()); - //ASSERT_FALSE(x->isActualOnHostSide()); - - //ASSERT_TRUE(y->isActualOnDeviceSide()); - //ASSERT_TRUE(y->isActualOnHostSide()); - //auto res = cudaStreamSynchronize(*y.getContext()->getCudaStream()); - //ASSERT_EQ(0, res); - - // printf("Assigned to another array\n"); - // y.printBuffer("OUput"); - ASSERT_TRUE(y.equalsTo(x)); - //y.syncToHost(); - //y.printBuffer("IsMax output"); - //delete x; - //delete y; -} - -TEST_F(NDArrayCudaBasicsTests, Test_PrimitiveCosine_1) { // strict - auto x = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); - auto y = NDArrayFactory::create('c', {5}); - auto exp = NDArrayFactory::create('c', {5}, {0.540302, -0.416147, -0.989992, -0.653644, 0.283662}); - - ASSERT_TRUE(x.isActualOnDeviceSide()); - ASSERT_FALSE(x.isActualOnHostSide()); - - x.applyTransform(transform::Cosine, y); - //ASSERT_TRUE(x->isActualOnDeviceSide()); - //ASSERT_FALSE(x->isActualOnHostSide()); - - //ASSERT_TRUE(y->isActualOnDeviceSide()); - //ASSERT_TRUE(y->isActualOnHostSide()); - //auto res = cudaStreamSynchronize(*y.getContext()->getCudaStream()); - //ASSERT_EQ(0, res); - ASSERT_TRUE(exp.isSameShape(y)); - ASSERT_TRUE(exp.dataType() == y.dataType()); - //y.printBuffer("Cosine2"); - //delete x; - //delete y; + auto x = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); + auto y = NDArrayFactory::create('c', {5}); + + ASSERT_TRUE(x.isActualOnDeviceSide()); + ASSERT_FALSE(x.isActualOnHostSide()); + + x.applyTransform(transform::Neg, y); + // ASSERT_TRUE(x->isActualOnDeviceSide()); + // ASSERT_FALSE(x->isActualOnHostSide()); + + // ASSERT_TRUE(y->isActualOnDeviceSide()); + // ASSERT_TRUE(y->isActualOnHostSide()); + // auto res = cudaStreamSynchronize(*y.getContext()->getCudaStream()); + // ASSERT_EQ(0, res); + // y.printBuffer("Negatives2"); + // delete x; + // delete y; +} + +TEST_F(NDArrayCudaBasicsTests, Test_PrimitiveSqrt_1) { // strict + auto x = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); + auto y = NDArrayFactory::create('c', {5}); + auto exp = NDArrayFactory::create( + {1.000000, 1.414214, 1.732051, 2.000000, 2.236068}); + ASSERT_TRUE(x.isActualOnDeviceSide()); + ASSERT_FALSE(x.isActualOnHostSide()); + + x.applyTransform(transform::Sqrt, y); + // ASSERT_TRUE(x->isActualOnDeviceSide()); + // ASSERT_FALSE(x->isActualOnHostSide()); + + // ASSERT_TRUE(y->isActualOnDeviceSide()); + // ASSERT_TRUE(y->isActualOnHostSide()); + // auto res = cudaStreamSynchronize(*y.getContext()->getCudaStream()); + // ASSERT_EQ(0, res); + ASSERT_TRUE(y.equalsTo(exp)); + // y.printBuffer("SQRT output"); + // delete x; + // delete y; +} + +TEST_F(NDArrayCudaBasicsTests, Test_PrimitiveAssign_1) { // strict + auto x = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); + auto y = NDArrayFactory::create('c', {5}); + // auto exp = + // NDArrayFactory::create({1.000000, 1.414214, 1.732051, 2.000000, 2.236068}); + // ASSERT_TRUE(x.isActualOnDeviceSide()); + // ASSERT_TRUE(x.isActualOnHostSide()); + + x.applyTransform(transform::Assign, y); + // ASSERT_TRUE(x->isActualOnDeviceSide()); + // ASSERT_FALSE(x->isActualOnHostSide()); + + // ASSERT_TRUE(y->isActualOnDeviceSide()); + // ASSERT_TRUE(y->isActualOnHostSide()); + // auto res = cudaStreamSynchronize(*y.getContext()->getCudaStream()); + // ASSERT_EQ(0, res); + + // printf("Assigned to another array\n"); + // y.printBuffer("OUput"); + ASSERT_TRUE(y.equalsTo(x)); + // y.syncToHost(); + // y.printBuffer("IsMax output"); + // delete x; + // delete y; +} + +TEST_F(NDArrayCudaBasicsTests, Test_PrimitiveCosine_1) { // strict + auto x = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); + auto y = NDArrayFactory::create('c', {5}); + auto exp = NDArrayFactory::create( + 'c', {5}, {0.540302, -0.416147, -0.989992, -0.653644, 0.283662}); + + ASSERT_TRUE(x.isActualOnDeviceSide()); + ASSERT_FALSE(x.isActualOnHostSide()); + + x.applyTransform(transform::Cosine, y); + // ASSERT_TRUE(x->isActualOnDeviceSide()); + // ASSERT_FALSE(x->isActualOnHostSide()); + + // ASSERT_TRUE(y->isActualOnDeviceSide()); + // ASSERT_TRUE(y->isActualOnHostSide()); + // auto res = cudaStreamSynchronize(*y.getContext()->getCudaStream()); + // ASSERT_EQ(0, res); + ASSERT_TRUE(exp.isSameShape(y)); + ASSERT_TRUE(exp.dataType() == y.dataType()); + // y.printBuffer("Cosine2"); + // delete x; + // delete y; } TEST_F(NDArrayCudaBasicsTests, Test_PrimitiveCosine_2) { - auto x = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); - auto y = NDArrayFactory::create('c', {5}); - auto exp = NDArrayFactory::create('c', {5}, {0.540302, -0.416147, -0.989992, -0.653644, 0.283662}); - - ASSERT_TRUE(x.isActualOnDeviceSide()); - ASSERT_FALSE(x.isActualOnHostSide()); - x.applyTransform(transform::Cosine, y); - //ASSERT_TRUE(x->isActualOnDeviceSide()); - //ASSERT_FALSE(x->isActualOnHostSide()); - - //ASSERT_TRUE(y->isActualOnDeviceSide()); - //ASSERT_TRUE(y->isActualOnHostSide()); - //auto res = cudaStreamSynchronize(*y.getContext()->getCudaStream()); - //ASSERT_EQ(0, res); - //exp.syncToHost(); - //y.printBuffer("PrimitiveCosine2"); - //exp.printBuffer("Primitive Cosine exp"); - ASSERT_TRUE(exp.isSameShape(y)); - ASSERT_TRUE(exp.dataType() == y.dataType()); - //for (int e = 0; e < y.lengthOf(); e++) { - // ASSERT_NEAR(exp.e(e), y.e(e), 1e-5); - //} - - ASSERT_TRUE(exp.equalsTo(y)); - //delete x; - //delete y; + auto x = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); + auto y = NDArrayFactory::create('c', {5}); + auto exp = NDArrayFactory::create( + 'c', {5}, {0.540302, -0.416147, -0.989992, -0.653644, 0.283662}); + + ASSERT_TRUE(x.isActualOnDeviceSide()); + ASSERT_FALSE(x.isActualOnHostSide()); + x.applyTransform(transform::Cosine, y); + // ASSERT_TRUE(x->isActualOnDeviceSide()); + // ASSERT_FALSE(x->isActualOnHostSide()); + + // ASSERT_TRUE(y->isActualOnDeviceSide()); + // ASSERT_TRUE(y->isActualOnHostSide()); + // auto res = cudaStreamSynchronize(*y.getContext()->getCudaStream()); + // ASSERT_EQ(0, res); + // exp.syncToHost(); + // y.printBuffer("PrimitiveCosine2"); + // exp.printBuffer("Primitive Cosine exp"); + ASSERT_TRUE(exp.isSameShape(y)); + ASSERT_TRUE(exp.dataType() == y.dataType()); + // for (int e = 0; e < y.lengthOf(); e++) { + // ASSERT_NEAR(exp.e(e), y.e(e), 1e-5); + //} + + ASSERT_TRUE(exp.equalsTo(y)); + // delete x; + // delete y; } TEST_F(NDArrayCudaBasicsTests, Test_PrimitiveCosine_3) { - auto x = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); - auto y = NDArrayFactory::create('c', {5}); - auto exp = NDArrayFactory::create({0.540302, -0.416147, -0.989992, -0.653644, 0.283662}); - - ASSERT_TRUE(x.isActualOnDeviceSide()); - ASSERT_FALSE(x.isActualOnHostSide()); - x.applyTransform(transform::Cosine, y); - //ASSERT_TRUE(x->isActualOnDeviceSide()); - //ASSERT_FALSE(x->isActualOnHostSide()); - - //ASSERT_TRUE(y->isActualOnDeviceSide()); - //ASSERT_TRUE(y->isActualOnHostSide()); - //auto res = cudaStreamSynchronize(*y.getContext()->getCudaStream()); - //ASSERT_EQ(0, res); - //exp.syncToHost(); -// y.printBuffer("PrimitiveCosine3"); -// exp.printBuffer("Primitive Cosine3 exp"); -// y.printShapeInfo("Y shape"); -// exp.printShapeInfo("Exp Shape"); - ASSERT_TRUE(exp.isSameShape(y)); -// -// for (int e = 0; e < y.lengthOf(); e++) { -// printf("%lf == %lf\n", exp.e(e), y.e(e)); -//// ASSERT_NEAR(exp.e(e), y.e(e), 1e-5); -// } - - ASSERT_TRUE(exp.equalsTo(y)); - //delete x; - //delete y; + auto x = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); + auto y = NDArrayFactory::create('c', {5}); + auto exp = NDArrayFactory::create( + {0.540302, -0.416147, -0.989992, -0.653644, 0.283662}); + + ASSERT_TRUE(x.isActualOnDeviceSide()); + ASSERT_FALSE(x.isActualOnHostSide()); + x.applyTransform(transform::Cosine, y); + // ASSERT_TRUE(x->isActualOnDeviceSide()); + // ASSERT_FALSE(x->isActualOnHostSide()); + + // ASSERT_TRUE(y->isActualOnDeviceSide()); + // ASSERT_TRUE(y->isActualOnHostSide()); + // auto res = cudaStreamSynchronize(*y.getContext()->getCudaStream()); + // ASSERT_EQ(0, res); + // exp.syncToHost(); + // y.printBuffer("PrimitiveCosine3"); + // exp.printBuffer("Primitive Cosine3 exp"); + // y.printShapeInfo("Y shape"); + // exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.isSameShape(y)); + // + // for (int e = 0; e < y.lengthOf(); e++) { + // printf("%lf == %lf\n", exp.e(e), y.e(e)); + //// ASSERT_NEAR(exp.e(e), y.e(e), 1e-5); + // } + + ASSERT_TRUE(exp.equalsTo(y)); + // delete x; + // delete y; } TEST_F(NDArrayCudaBasicsTests, TestRawBroadcast_2) { - - //if (!Environment::getInstance().isExperimentalBuild()) - // return; - - NDArray x = NDArrayFactory::create('c', {2,3,4}); - NDArray y('c', {2,4}, {10,20,30,40,50,60,70,80}, sd::DataType::DOUBLE); - NDArray z('c', {2,3,4}, {100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100}, sd::DataType::DOUBLE); -// NDArray exp('c', {2,3,4}, {10., 21., 32., 43., 14., 25., 36., 47., 18., 29., 40., 51., 62., 73., 84., 95., 66., 77., 88., 99., 70., 81., 92., 103}, sd::DataType::DOUBLE); - NDArray exp('c', {2,3,4}, {10., 40., 90., 160., 50., 120., 210., 320., 90., 200., 330., 480., 650., 840., 1050., 1280., 850., 1080., 1330., 1600., 1050., 1320., 1610., 1920.}, sd::DataType::DOUBLE); - x.linspace(1); x.syncToDevice(); - - std::vector dimensions = {0,2}; - - // evaluate xTad data - shape::TAD xTad; - xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); - xTad.createTadOnlyShapeInfo(); - xTad.createOffsets(); - - // prepare input arrays for prepareDataForCuda function - std::vector> hostData; - hostData.emplace_back(dimensions.data(), dimensions.size() * sizeof(int)); // 0 -- dimensions - hostData.emplace_back(xTad.tadOnlyShapeInfo, shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo)); // 1 -- xTadShapeInfo - hostData.emplace_back(xTad.tadOffsets, xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets - std::vector devicePtrs(hostData.size(), nullptr); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - - // allocate required amount of global device memory and copy host data to it - cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); ASSERT_EQ(0, cudaResult); - - // call cuda kernel which calculates result - NativeOpExecutioner::execBroadcast(&lc, sd::broadcast::Multiply, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - (int*)devicePtrs[0], dimensions.size(), - (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2], - nullptr, nullptr); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // free allocated global device memory - for(int i = 0; i < devicePtrs.size(); ++i) - cudaFree(devicePtrs[i]); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + // if (!Environment::getInstance().isExperimentalBuild()) + // return; + + NDArray x = NDArrayFactory::create('c', {2, 3, 4}); + NDArray y('c', {2, 4}, {10, 20, 30, 40, 50, 60, 70, 80}, + sd::DataType::DOUBLE); + NDArray z('c', {2, 3, 4}, + {100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, + 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100}, + sd::DataType::DOUBLE); + // NDArray exp('c', {2,3,4}, + // {10., 21., 32., 43., 14., 25., 36., 47., 18., 29., 40., 51., 62., 73., 84., + // 95., 66., 77., 88., 99., 70., 81., 92., 103}, sd::DataType::DOUBLE); + NDArray exp('c', {2, 3, 4}, + {10., 40., 90., 160., 50., 120., 210., 320., + 90., 200., 330., 480., 650., 840., 1050., 1280., + 850., 1080., 1330., 1600., 1050., 1320., 1610., 1920.}, + sd::DataType::DOUBLE); + x.linspace(1); + x.syncToDevice(); + + std::vector dimensions = {0, 2}; + + // evaluate xTad data + shape::TAD xTad; + xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); + xTad.createTadOnlyShapeInfo(); + xTad.createOffsets(); + + // prepare input arrays for prepareDataForCuda function + std::vector> hostData; + hostData.emplace_back(dimensions.data(), + dimensions.size() * sizeof(int)); // 0 -- dimensions + hostData.emplace_back( + xTad.tadOnlyShapeInfo, + shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo)); // 1 -- xTadShapeInfo + hostData.emplace_back(xTad.tadOffsets, + xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets + std::vector devicePtrs(hostData.size(), nullptr); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // allocate required amount of global device memory and copy host data to it + cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); + ASSERT_EQ(0, cudaResult); + + // call cuda kernel which calculates result + NativeOpExecutioner::execBroadcast( + &lc, sd::broadcast::Multiply, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, y.shapeInfo(), y.specialBuffer(), + y.specialShapeInfo(), nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo(), (int*)devicePtrs[0], dimensions.size(), + (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2], nullptr, nullptr); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // free allocated global device memory + for (int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } TEST_F(NDArrayCudaBasicsTests, TestRawBroadcast_3) { - - //if (!Environment::getInstance().isExperimentalBuild()) - // return; - - NDArray x('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray y('c', {2,4}, {10,20,30,40,50,60,70,80}, sd::DataType::DOUBLE); - NDArray z('c', {2,3,4}, {100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100}, sd::DataType::DOUBLE); -// NDArray exp('c', {2,3,4}, {10., 21., 32., 43., 14., 25., 36., 47., 18., 29., 40., 51., 62., 73., 84., 95., 66., 77., 88., 99., 70., 81., 92., 103}, sd::DataType::DOUBLE); - NDArray exp('c', {2,3,4}, {10., 40., 90., 160., 50., 120., 210., 320., 90., 200., 330., 480., 650., 840., 1050., 1280., 850., 1080., 1330., 1600., 1050., 1320., 1610., 1920.}, sd::DataType::DOUBLE); - x.linspace(1); x.syncToDevice(); - - std::vector dimensions = {0,2}; - - // evaluate xTad data - shape::TAD xTad; - xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); - xTad.createTadOnlyShapeInfo(); - xTad.createOffsets(); - - // prepare input arrays for prepareDataForCuda function - std::vector> hostData; - hostData.emplace_back(dimensions.data(), dimensions.size() * sizeof(int)); // 0 -- dimensions - hostData.emplace_back(xTad.tadOnlyShapeInfo, shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo)); // 1 -- xTadShapeInfo - hostData.emplace_back(xTad.tadOffsets, xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets - std::vector devicePtrs(hostData.size(), nullptr); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - //cudaStream_t stream; - //cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext* pLc = x.getContext();//(&stream); - cudaStream_t* stream = pLc->getCudaStream(); - // allocate required amount of global device memory and copy host data to it -// cudaResult = allocateDeviceMem(*pLc, devicePtrs, hostData); ASSERT_EQ(0, cudaResult); - for(int i = 0; i < devicePtrs.size(); ++i) { - - cudaResult = cudaMalloc(reinterpret_cast(&devicePtrs[i]), hostData[i].second); ASSERT_EQ(0, cudaResult); - cudaMemcpyAsync(devicePtrs[i], hostData[i].first, hostData[i].second, cudaMemcpyHostToDevice, *stream); - } - - NDArray::registerSpecialUse({&z}, {&x, &y}); - // call cuda kernel which calculates result - NativeOpExecutioner::execBroadcast(pLc, sd::broadcast::Multiply, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - (int*)devicePtrs[0], dimensions.size(), - (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2], - nullptr, nullptr); - - //cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - //z.syncToHost(); - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // free allocated global device memory - for(int i = 0; i < devicePtrs.size(); ++i) - cudaFree(devicePtrs[i]); - ASSERT_TRUE(exp.equalsTo(z)); - // delete cuda stream - //cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + // if (!Environment::getInstance().isExperimentalBuild()) + // return; + + NDArray x('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray y('c', {2, 4}, {10, 20, 30, 40, 50, 60, 70, 80}, + sd::DataType::DOUBLE); + NDArray z('c', {2, 3, 4}, + {100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, + 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100}, + sd::DataType::DOUBLE); + // NDArray exp('c', {2,3,4}, + // {10., 21., 32., 43., 14., 25., 36., 47., 18., 29., 40., 51., 62., 73., 84., + // 95., 66., 77., 88., 99., 70., 81., 92., 103}, sd::DataType::DOUBLE); + NDArray exp('c', {2, 3, 4}, + {10., 40., 90., 160., 50., 120., 210., 320., + 90., 200., 330., 480., 650., 840., 1050., 1280., + 850., 1080., 1330., 1600., 1050., 1320., 1610., 1920.}, + sd::DataType::DOUBLE); + x.linspace(1); + x.syncToDevice(); + + std::vector dimensions = {0, 2}; + + // evaluate xTad data + shape::TAD xTad; + xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); + xTad.createTadOnlyShapeInfo(); + xTad.createOffsets(); + + // prepare input arrays for prepareDataForCuda function + std::vector> hostData; + hostData.emplace_back(dimensions.data(), + dimensions.size() * sizeof(int)); // 0 -- dimensions + hostData.emplace_back( + xTad.tadOnlyShapeInfo, + shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo)); // 1 -- xTadShapeInfo + hostData.emplace_back(xTad.tadOffsets, + xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets + std::vector devicePtrs(hostData.size(), nullptr); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + // cudaStream_t stream; + // cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); + LaunchContext* pLc = x.getContext(); //(&stream); + cudaStream_t* stream = pLc->getCudaStream(); + // allocate required amount of global device memory and copy host data to it + // cudaResult = allocateDeviceMem(*pLc, devicePtrs, hostData); + // ASSERT_EQ(0, cudaResult); + for (int i = 0; i < devicePtrs.size(); ++i) { + cudaResult = cudaMalloc(reinterpret_cast(&devicePtrs[i]), + hostData[i].second); + ASSERT_EQ(0, cudaResult); + cudaMemcpyAsync(devicePtrs[i], hostData[i].first, hostData[i].second, + cudaMemcpyHostToDevice, *stream); + } + + NDArray::registerSpecialUse({&z}, {&x, &y}); + // call cuda kernel which calculates result + NativeOpExecutioner::execBroadcast( + pLc, sd::broadcast::Multiply, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, y.shapeInfo(), y.specialBuffer(), + y.specialShapeInfo(), nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo(), (int*)devicePtrs[0], dimensions.size(), + (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2], nullptr, nullptr); + + // cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); + // z.syncToHost(); + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // free allocated global device memory + for (int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); + ASSERT_TRUE(exp.equalsTo(z)); + // delete cuda stream + // cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); } - TEST_F(NDArrayCudaBasicsTests, TestBroadcastMultiply_1) { - // allocating host-side arrays - NDArray x('c', { 2, 3 }, { 1, 2, 3, 4, 5, 6}, sd::DataType::DOUBLE); - NDArray y = NDArrayFactory::create(3.); //'c', { 3 }, { 2., 3., 4.}, sd::DataType::DOUBLE); - //auto z = NDArrayFactory::create('c', { 5 }); - - auto exp = NDArrayFactory::create('c', { 2, 3 }, { 3, 6, 9, 12, 15, 18 }); - - // making raw buffers - //Nd4jPointer devBufferPtrX, devBufferPtrZ, devShapePtrX; - //cudaError_t res = cudaMalloc(reinterpret_cast(&devBufferPtrX), x.lengthOf() * x.sizeOfT()); - //ASSERT_EQ(0, res); - //res = cudaMalloc(reinterpret_cast(&devBufferPtrZ), x.lengthOf() * x.sizeOfT()); - //ASSERT_EQ(0, res); - //res = cudaMalloc(reinterpret_cast(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); - //ASSERT_EQ(0, res); - //x.applyPairwiseTransform(pairwise::Multiply, &y, &z, nullptr); - x *= y; - //x.syncToHost(); - - // - // cudaFree(devBufferPtrX); - //cudaFree(devBufferPtrZ); - //cudaFree(devShapePtrX); - ASSERT_TRUE(exp.equalsTo(x)); -// for (int e = 0; e < x.lengthOf(); e++) { -// ASSERT_NEAR(exp.e(e), x.e(e), 1e-5); -// } + // allocating host-side arrays + NDArray x('c', {2, 3}, {1, 2, 3, 4, 5, 6}, sd::DataType::DOUBLE); + NDArray y = NDArrayFactory::create( + 3.); //'c', { 3 }, { 2., 3., 4.}, sd::DataType::DOUBLE); + // auto z = NDArrayFactory::create('c', { 5 }); + + auto exp = NDArrayFactory::create('c', {2, 3}, {3, 6, 9, 12, 15, 18}); + + // making raw buffers + // Nd4jPointer devBufferPtrX, devBufferPtrZ, devShapePtrX; + // cudaError_t res = cudaMalloc(reinterpret_cast(&devBufferPtrX), + // x.lengthOf() * x.sizeOfT()); ASSERT_EQ(0, res); res = + // cudaMalloc(reinterpret_cast(&devBufferPtrZ), x.lengthOf() * + // x.sizeOfT()); ASSERT_EQ(0, res); res = cudaMalloc(reinterpret_cast(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); + // ASSERT_EQ(0, res); x.applyPairwiseTransform(pairwise::Multiply, &y, &z, + // nullptr); + x *= y; + // x.syncToHost(); + + // + // cudaFree(devBufferPtrX); + // cudaFree(devBufferPtrZ); + // cudaFree(devShapePtrX); + ASSERT_TRUE(exp.equalsTo(x)); + // for (int e = 0; e < x.lengthOf(); e++) { + // ASSERT_NEAR(exp.e(e), x.e(e), 1e-5); + // } } TEST_F(NDArrayCudaBasicsTests, TestBroadcastMultiply_01) { - // allocating host-side arrays - NDArray x('c', { 2, 3 }, { 1, 2, 3, 4, 5, 6}, sd::DataType::DOUBLE); - NDArray y = NDArrayFactory::create(3.); //'c', { 3 }, { 2., 3., 4.}, sd::DataType::DOUBLE); - auto z = NDArrayFactory::create('c', { 2, 3 }); - - auto exp = NDArrayFactory::create('c', { 2, 3 }, { 3, 6, 9, 12, 15, 18 }); - - // making raw buffers - //Nd4jPointer devBufferPtrX, devBufferPtrZ, devShapePtrX; - //cudaError_t res = cudaMalloc(reinterpret_cast(&devBufferPtrX), x.lengthOf() * x.sizeOfT()); - //ASSERT_EQ(0, res); - //res = cudaMalloc(reinterpret_cast(&devBufferPtrZ), x.lengthOf() * x.sizeOfT()); - //ASSERT_EQ(0, res); - //res = cudaMalloc(reinterpret_cast(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); - //ASSERT_EQ(0, res); - //x.applyPairwiseTransform(pairwise::Multiply, &y, &z, nullptr); - //x.printBuffer("23X = "); - //y.printBuffer("23Y = "); - x.applyTrueBroadcast(BroadcastOpsTuple::Multiply(), y, z);// *= y; - // z.printBuffer("53Result out"); - - // - // cudaFree(devBufferPtrX); - //cudaFree(devBufferPtrZ); - //cudaFree(devShapePtrX); - ASSERT_TRUE(exp.equalsTo(z)); - -// for (int e = 0; e < x.lengthOf(); e++) { -// ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); -// } + // allocating host-side arrays + NDArray x('c', {2, 3}, {1, 2, 3, 4, 5, 6}, sd::DataType::DOUBLE); + NDArray y = NDArrayFactory::create( + 3.); //'c', { 3 }, { 2., 3., 4.}, sd::DataType::DOUBLE); + auto z = NDArrayFactory::create('c', {2, 3}); + + auto exp = NDArrayFactory::create('c', {2, 3}, {3, 6, 9, 12, 15, 18}); + + // making raw buffers + // Nd4jPointer devBufferPtrX, devBufferPtrZ, devShapePtrX; + // cudaError_t res = cudaMalloc(reinterpret_cast(&devBufferPtrX), + // x.lengthOf() * x.sizeOfT()); ASSERT_EQ(0, res); res = + // cudaMalloc(reinterpret_cast(&devBufferPtrZ), x.lengthOf() * + // x.sizeOfT()); ASSERT_EQ(0, res); res = cudaMalloc(reinterpret_cast(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); + // ASSERT_EQ(0, res); x.applyPairwiseTransform(pairwise::Multiply, &y, &z, + // nullptr); x.printBuffer("23X = "); y.printBuffer("23Y = "); + x.applyTrueBroadcast(BroadcastOpsTuple::Multiply(), y, z); // *= y; + // z.printBuffer("53Result out"); + + // + // cudaFree(devBufferPtrX); + // cudaFree(devBufferPtrZ); + // cudaFree(devShapePtrX); + ASSERT_TRUE(exp.equalsTo(z)); + + // for (int e = 0; e < x.lengthOf(); e++) { + // ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + // } } TEST_F(NDArrayCudaBasicsTests, TestBroadcastMultiply_02) { - // allocating host-side arrays - auto x = NDArrayFactory::create('c', { 2, 3 }, { 1, 2, 3, 4, 5, 6}); //, sd::DataType::DOUBLE); - auto y = NDArrayFactory::create('c', {2,3}, {3, 3, 3, 3, 3, 3}); //'c', { 3 }, { 2., 3., 4.}, sd::DataType::DOUBLE); - auto z = NDArrayFactory::create('c', { 2, 3 }); - - auto exp = NDArrayFactory::create('c', { 2, 3 }, { 3, 6, 9, 12, 15, 18 }); - //if (x.isActualOnHostSide() && !x.isActualOnDeviceSide()) - // making raw buffers - //Nd4jPointer devBufferPtrX, devBufferPtrZ, devShapePtrX; - //cudaError_t res = cudaMalloc(reinterpret_cast(&devBufferPtrX), x.lengthOf() * x.sizeOfT()); - //ASSERT_EQ(0, res); - //res = cudaMalloc(reinterpret_cast(&devBufferPtrZ), x.lengthOf() * x.sizeOfT()); - //ASSERT_EQ(0, res); - //res = cudaMalloc(reinterpret_cast(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); - //ASSERT_EQ(0, res); - //x.applyPairwiseTransform(pairwise::Multiply, &y, &z, nullptr); - //x.printBuffer("23X = "); - //y.printBuffer("23Y = "); - x.applyTrueBroadcast(BroadcastOpsTuple::Multiply(), y, z);// *= y; - - // z.printBuffer("52Result out"); - - // - // cudaFree(devBufferPtrX); - //cudaFree(devBufferPtrZ); - //cudaFree(devShapePtrX); - ASSERT_TRUE(exp.equalsTo(z)); - -// for (int e = 0; e < x.lengthOf(); e++) { -// ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); -// } + // allocating host-side arrays + auto x = NDArrayFactory::create( + 'c', {2, 3}, {1, 2, 3, 4, 5, 6}); //, sd::DataType::DOUBLE); + auto y = NDArrayFactory::create( + 'c', {2, 3}, + {3, 3, 3, 3, 3, 3}); //'c', { 3 }, { 2., 3., 4.}, sd::DataType::DOUBLE); + auto z = NDArrayFactory::create('c', {2, 3}); + + auto exp = NDArrayFactory::create('c', {2, 3}, {3, 6, 9, 12, 15, 18}); + // if (x.isActualOnHostSide() && !x.isActualOnDeviceSide()) + // making raw buffers + // Nd4jPointer devBufferPtrX, devBufferPtrZ, devShapePtrX; + // cudaError_t res = cudaMalloc(reinterpret_cast(&devBufferPtrX), + // x.lengthOf() * x.sizeOfT()); ASSERT_EQ(0, res); res = + // cudaMalloc(reinterpret_cast(&devBufferPtrZ), x.lengthOf() * + // x.sizeOfT()); ASSERT_EQ(0, res); res = cudaMalloc(reinterpret_cast(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); + // ASSERT_EQ(0, res); x.applyPairwiseTransform(pairwise::Multiply, &y, &z, + // nullptr); x.printBuffer("23X = "); y.printBuffer("23Y = "); + x.applyTrueBroadcast(BroadcastOpsTuple::Multiply(), y, z); // *= y; + + // z.printBuffer("52Result out"); + + // + // cudaFree(devBufferPtrX); + // cudaFree(devBufferPtrZ); + // cudaFree(devShapePtrX); + ASSERT_TRUE(exp.equalsTo(z)); + + // for (int e = 0; e < x.lengthOf(); e++) { + // ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + // } } TEST_F(NDArrayCudaBasicsTests, TestBroadcastMultiply_002) { - // allocating host-side arrays - auto x = NDArrayFactory::create('c', { 2, 3 }, { 1, 2, 3, 4, 5, 6}); //, sd::DataType::DOUBLE); - auto y = NDArrayFactory::create('c', {2, 3}, {2., 3., 3., 3., 3., 3.}); //'c', { 3 }, { 2., 3., 4.}, sd::DataType::DOUBLE); - auto z = NDArrayFactory::create('c', { 2, 3 }); - - auto exp = NDArrayFactory::create('c', { 2, 3 }, { 2, 6, 9, 12, 15, 18 }); - //if (x.isActualOnHostSide() && !x.isActualOnDeviceSide()) - // making raw buffers - //Nd4jPointer devBufferPtrX, devBufferPtrZ, devShapePtrX; - //cudaError_t res = cudaMalloc(reinterpret_cast(&devBufferPtrX), x.lengthOf() * x.sizeOfT()); - //ASSERT_EQ(0, res); - //res = cudaMalloc(reinterpret_cast(&devBufferPtrZ), x.lengthOf() * x.sizeOfT()); - //ASSERT_EQ(0, res); - //res = cudaMalloc(reinterpret_cast(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); - //ASSERT_EQ(0, res); - //x.applyPairwiseTransform(pairwise::Multiply, &y, &z, nullptr); - //x.printBuffer("23X = "); - //y.printBuffer("23Y = "); - x.applyPairwiseTransform(pairwise::Multiply, y, z);// *= y; - - // z.printBuffer("51Result out"); - - // - // cudaFree(devBufferPtrX); - //cudaFree(devBufferPtrZ); - //cudaFree(devShapePtrX); - ASSERT_TRUE(exp.equalsTo(z)); - -// for (int e = 0; e < x.lengthOf(); e++) { -// ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); -// } + // allocating host-side arrays + auto x = NDArrayFactory::create( + 'c', {2, 3}, {1, 2, 3, 4, 5, 6}); //, sd::DataType::DOUBLE); + auto y = NDArrayFactory::create( + 'c', {2, 3}, + {2., 3., 3., 3., 3., + 3.}); //'c', { 3 }, { 2., 3., 4.}, sd::DataType::DOUBLE); + auto z = NDArrayFactory::create('c', {2, 3}); + + auto exp = NDArrayFactory::create('c', {2, 3}, {2, 6, 9, 12, 15, 18}); + // if (x.isActualOnHostSide() && !x.isActualOnDeviceSide()) + // making raw buffers + // Nd4jPointer devBufferPtrX, devBufferPtrZ, devShapePtrX; + // cudaError_t res = cudaMalloc(reinterpret_cast(&devBufferPtrX), + // x.lengthOf() * x.sizeOfT()); ASSERT_EQ(0, res); res = + // cudaMalloc(reinterpret_cast(&devBufferPtrZ), x.lengthOf() * + // x.sizeOfT()); ASSERT_EQ(0, res); res = cudaMalloc(reinterpret_cast(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); + // ASSERT_EQ(0, res); x.applyPairwiseTransform(pairwise::Multiply, &y, &z, + // nullptr); x.printBuffer("23X = "); y.printBuffer("23Y = "); + x.applyPairwiseTransform(pairwise::Multiply, y, z); // *= y; + + // z.printBuffer("51Result out"); + + // + // cudaFree(devBufferPtrX); + // cudaFree(devBufferPtrZ); + // cudaFree(devShapePtrX); + ASSERT_TRUE(exp.equalsTo(z)); + + // for (int e = 0; e < x.lengthOf(); e++) { + // ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + // } } //////////////////////////////////////////////////////////////////////////// TEST_F(NDArrayCudaBasicsTests, TestBroadcastRaw_1) { - - //if (!Environment::getInstance().isExperimentalBuild()) - // return; - - NDArray x('c', {2,3,4}, {100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100}, sd::DataType::INT32); - NDArray y('c', {3}, {10, 20, 30}, sd::DataType::INT64); - NDArray z('c', {2,3,4}, {100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100}, sd::DataType::INT32); - NDArray exp('c', {2,3,4}, {10, 11, 12, 13,24, 25, 26, 27,38, 39, 40, 41,22, 23, 24, 25,36, 37, 38, 39,50, 51, 52, 53}, sd::DataType::INT32); - //real output [10, 11, 12, 13, 4, 5, 6, 7, 28, 29, 30, 31, 22, 23, 24, 25, 16, 17, 18, 19, 40, 41, 42, 43] - x.linspace(0); x.syncToDevice(); - - std::vector dimensions = {1}; - - // evaluate xTad data - shape::TAD xTad; - xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); - xTad.createTadOnlyShapeInfo(); - xTad.createOffsets(); - - // prepare input arrays for prepareDataForCuda function - std::vector> hostData; - hostData.emplace_back(dimensions.data(), dimensions.size() * sizeof(Nd4jLong)); // 0 -- dimensions - hostData.emplace_back(xTad.tadOnlyShapeInfo, shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo)); // 1 -- xTadShapeInfo - hostData.emplace_back(xTad.tadOffsets, xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets - std::vector devicePtrs(hostData.size(), nullptr); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t* stream = x.getContext()->getCudaStream(); - LaunchContext* pLc = x.getContext(); - - // allocate required amount of global device memory and copy host data to it - //cudaResult = allocateDeviceMem(*pLc, devicePtrs, hostData); ASSERT_EQ(0, cudaResult); - for(size_t i = 0; i < devicePtrs.size(); ++i) { - cudaResult = cudaMalloc(&devicePtrs[i], hostData[i].second); //if(cudaResult != 0) return cudaResult; - ASSERT_EQ(cudaResult, 0); - cudaMemcpy(devicePtrs[i], hostData[i].first, hostData[i].second, cudaMemcpyHostToDevice); - } - - // call cuda kernel which calculates result - NativeOpExecutioner::execBroadcast(pLc, sd::broadcast::Add, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - (int*)devicePtrs[0], dimensions.size(), - (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2], - nullptr, nullptr); - - cudaResult = cudaStreamSynchronize(*stream); ASSERT_EQ(0, cudaResult); - - // x.printIndexedBuffer(" X"); - // y.printIndexedBuffer("+Y"); - // z.printBuffer("ADD broadcasted output"); - // verify results - // for (int e = 0; e < z.lengthOf(); e++) - // ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // free allocated global device memory - for(int i = 0; i < devicePtrs.size(); ++i) - cudaFree(devicePtrs[i]); - - // delete cuda stream - //cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + // if (!Environment::getInstance().isExperimentalBuild()) + // return; + + NDArray x('c', {2, 3, 4}, + {100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, + 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100}, + sd::DataType::INT32); + NDArray y('c', {3}, {10, 20, 30}, sd::DataType::INT64); + NDArray z('c', {2, 3, 4}, + {100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, + 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100}, + sd::DataType::INT32); + NDArray exp('c', {2, 3, 4}, {10, 11, 12, 13, 24, 25, 26, 27, 38, 39, 40, 41, + 22, 23, 24, 25, 36, 37, 38, 39, 50, 51, 52, 53}, + sd::DataType::INT32); + // real output [10, 11, 12, 13, 4, 5, 6, 7, 28, 29, 30, 31, 22, 23, 24, 25, + // 16, 17, 18, 19, 40, 41, 42, 43] + x.linspace(0); + x.syncToDevice(); + + std::vector dimensions = {1}; + + // evaluate xTad data + shape::TAD xTad; + xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); + xTad.createTadOnlyShapeInfo(); + xTad.createOffsets(); + + // prepare input arrays for prepareDataForCuda function + std::vector> hostData; + hostData.emplace_back( + dimensions.data(), + dimensions.size() * sizeof(Nd4jLong)); // 0 -- dimensions + hostData.emplace_back( + xTad.tadOnlyShapeInfo, + shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo)); // 1 -- xTadShapeInfo + hostData.emplace_back(xTad.tadOffsets, + xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets + std::vector devicePtrs(hostData.size(), nullptr); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t* stream = x.getContext()->getCudaStream(); + LaunchContext* pLc = x.getContext(); + + // allocate required amount of global device memory and copy host data to it + // cudaResult = allocateDeviceMem(*pLc, devicePtrs, hostData); + // ASSERT_EQ(0, cudaResult); + for (size_t i = 0; i < devicePtrs.size(); ++i) { + cudaResult = cudaMalloc( + &devicePtrs[i], + hostData[i].second); // if(cudaResult != 0) return cudaResult; + ASSERT_EQ(cudaResult, 0); + cudaMemcpy(devicePtrs[i], hostData[i].first, hostData[i].second, + cudaMemcpyHostToDevice); + } + + // call cuda kernel which calculates result + NativeOpExecutioner::execBroadcast( + pLc, sd::broadcast::Add, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, y.shapeInfo(), y.specialBuffer(), + y.specialShapeInfo(), nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo(), (int*)devicePtrs[0], dimensions.size(), + (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2], nullptr, nullptr); + + cudaResult = cudaStreamSynchronize(*stream); + ASSERT_EQ(0, cudaResult); + + // x.printIndexedBuffer(" X"); + // y.printIndexedBuffer("+Y"); + // z.printBuffer("ADD broadcasted output"); + // verify results + // for (int e = 0; e < z.lengthOf(); e++) + // ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // free allocated global device memory + for (int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); + + // delete cuda stream + // cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); } TEST_F(NDArrayCudaBasicsTests, TestBroadcastMultiply) { - // allocating host-side arrays - NDArray x('c', { 2, 3 }, { 1, 2, 3, 4, 5, 6}, sd::DataType::DOUBLE); - NDArray y('c', { 3 }, { 2., 3., 4.}, sd::DataType::DOUBLE); - //auto z = NDArrayFactory::create('c', { 5 }); - - auto exp = NDArrayFactory::create('c', { 2, 3 }, { 2, 6, 12, 8, 15, 24 }); - - // making raw buffers - //Nd4jPointer devBufferPtrX, devBufferPtrZ, devShapePtrX; - //cudaError_t res = cudaMalloc(reinterpret_cast(&devBufferPtrX), x.lengthOf() * x.sizeOfT()); - //ASSERT_EQ(0, res); - //res = cudaMalloc(reinterpret_cast(&devBufferPtrZ), x.lengthOf() * x.sizeOfT()); - //ASSERT_EQ(0, res); - //res = cudaMalloc(reinterpret_cast(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); - //ASSERT_EQ(0, res); - //x.applyPairwiseTransform(pairwise::Multiply, &y, &z, nullptr); - //x.printBuffer("23X = "); - //y.printBuffer("23Y = "); - x *= y; - - // - // cudaFree(devBufferPtrX); - //cudaFree(devBufferPtrZ); - //cudaFree(devShapePtrX); - - //for (int e = 0; e < x.lengthOf(); e++) { - // ASSERT_NEAR(exp.e(e), x.e(e), 1e-5); - //} + // allocating host-side arrays + NDArray x('c', {2, 3}, {1, 2, 3, 4, 5, 6}, sd::DataType::DOUBLE); + NDArray y('c', {3}, {2., 3., 4.}, sd::DataType::DOUBLE); + // auto z = NDArrayFactory::create('c', { 5 }); + + auto exp = NDArrayFactory::create('c', {2, 3}, {2, 6, 12, 8, 15, 24}); + + // making raw buffers + // Nd4jPointer devBufferPtrX, devBufferPtrZ, devShapePtrX; + // cudaError_t res = cudaMalloc(reinterpret_cast(&devBufferPtrX), + // x.lengthOf() * x.sizeOfT()); ASSERT_EQ(0, res); res = + // cudaMalloc(reinterpret_cast(&devBufferPtrZ), x.lengthOf() * + // x.sizeOfT()); ASSERT_EQ(0, res); res = cudaMalloc(reinterpret_cast(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); + // ASSERT_EQ(0, res); x.applyPairwiseTransform(pairwise::Multiply, &y, &z, + // nullptr); x.printBuffer("23X = "); y.printBuffer("23Y = "); + x *= y; + + // + // cudaFree(devBufferPtrX); + // cudaFree(devBufferPtrZ); + // cudaFree(devShapePtrX); + + // for (int e = 0; e < x.lengthOf(); e++) { + // ASSERT_NEAR(exp.e(e), x.e(e), 1e-5); + //} } - TEST_F(NDArrayCudaBasicsTests, TestBroadcastMultiply_2) { - // allocating host-side arrays - NDArray x('c', { 2, 3 }, { 1, 2, 3, 4, 5, 6}, sd::DataType::DOUBLE); - NDArray y('c', { 3 }, { 2., 3., 4.}, sd::DataType::DOUBLE); - //auto z = NDArrayFactory::create('c', { 5 }); - - auto exp = NDArrayFactory::create('c', { 2, 3 }, { 11,12, 13,14, 15, 16 }); - auto expZ = NDArrayFactory::create('c', { 2, 3 }, { 2, 6, 12, 8, 15, 24 }); - - // making raw buffers - //Nd4jPointer devBufferPtrX, devBufferPtrZ, devShapePtrX; - //cudaError_t res = cudaMalloc(reinterpret_cast(&devBufferPtrX), x.lengthOf() * x.sizeOfT()); - //ASSERT_EQ(0, res); - //res = cudaMalloc(reinterpret_cast(&devBufferPtrZ), x.lengthOf() * x.sizeOfT()); - //ASSERT_EQ(0, res); - //res = cudaMalloc(reinterpret_cast(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); - //ASSERT_EQ(0, res); - //x.applyPairwiseTransform(pairwise::Multiply, &y, &z, nullptr); - //x.printBuffer("23X = "); - //y.printBuffer("23Y = "); - //void NDArray::applyTrueBroadcast(sd::BroadcastOpsTuple op, const NDArray* other, NDArray* target, const bool checkTargetShape, ExtraArguments *extraArgs) - x.applyTrueBroadcast(BroadcastOpsTuple::Multiply(), y, exp); - - // - // cudaFree(devBufferPtrX); - //cudaFree(devBufferPtrZ); - //cudaFree(devShapePtrX); - - //for (int e = 0; e < x.lengthOf(); e++) { - // ASSERT_NEAR(exp.e(e), x.e(e), 1e-5); - //} - ASSERT_TRUE(exp.equalsTo(expZ)); - + // allocating host-side arrays + NDArray x('c', {2, 3}, {1, 2, 3, 4, 5, 6}, sd::DataType::DOUBLE); + NDArray y('c', {3}, {2., 3., 4.}, sd::DataType::DOUBLE); + // auto z = NDArrayFactory::create('c', { 5 }); + + auto exp = + NDArrayFactory::create('c', {2, 3}, {11, 12, 13, 14, 15, 16}); + auto expZ = + NDArrayFactory::create('c', {2, 3}, {2, 6, 12, 8, 15, 24}); + + // making raw buffers + // Nd4jPointer devBufferPtrX, devBufferPtrZ, devShapePtrX; + // cudaError_t res = cudaMalloc(reinterpret_cast(&devBufferPtrX), + // x.lengthOf() * x.sizeOfT()); ASSERT_EQ(0, res); res = + // cudaMalloc(reinterpret_cast(&devBufferPtrZ), x.lengthOf() * + // x.sizeOfT()); ASSERT_EQ(0, res); res = cudaMalloc(reinterpret_cast(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); + // ASSERT_EQ(0, res); x.applyPairwiseTransform(pairwise::Multiply, &y, &z, + // nullptr); x.printBuffer("23X = "); y.printBuffer("23Y = "); void + // NDArray::applyTrueBroadcast(sd::BroadcastOpsTuple op, const NDArray* other, + // NDArray* target, const bool checkTargetShape, ExtraArguments *extraArgs) + x.applyTrueBroadcast(BroadcastOpsTuple::Multiply(), y, exp); + + // + // cudaFree(devBufferPtrX); + // cudaFree(devBufferPtrZ); + // cudaFree(devShapePtrX); + + // for (int e = 0; e < x.lengthOf(); e++) { + // ASSERT_NEAR(exp.e(e), x.e(e), 1e-5); + //} + ASSERT_TRUE(exp.equalsTo(expZ)); } - ////////////////////////////////////////////////////////////////////////// TEST_F(NDArrayCudaBasicsTests, TestReduceSum_1) { - // allocating host-side arrays - auto x = NDArrayFactory::create('c', { 5 }, { 1, 2, 3, 4, 5}); - auto y = NDArrayFactory::create(15); - auto exp = NDArrayFactory::create(15); + // allocating host-side arrays + auto x = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); + auto y = NDArrayFactory::create(15); + auto exp = NDArrayFactory::create(15); - auto stream = x.getContext()->getCudaStream();//reinterpret_cast(&nativeStream); + auto stream = + x.getContext()->getCudaStream(); // reinterpret_cast(&nativeStream); - NativeOpExecutioner::execReduceSameScalar(x.getContext(), reduce::Sum, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), nullptr, y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo()); - auto res = cudaStreamSynchronize(*stream); - ASSERT_EQ(0, res); - y.syncToHost(); + NativeOpExecutioner::execReduceSameScalar( + x.getContext(), reduce::Sum, x.buffer(), x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, y.buffer(), y.shapeInfo(), + y.specialBuffer(), y.specialShapeInfo()); + auto res = cudaStreamSynchronize(*stream); + ASSERT_EQ(0, res); + y.syncToHost(); - ASSERT_NEAR(y.e(0), 15, 1e-5); + ASSERT_NEAR(y.e(0), 15, 1e-5); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayCudaBasicsTests, TestDup1) { + NDArray array('c', {2, 3}, {1, 2, 3, 4, 5, 6}); + auto arrC = array.dup('c'); + auto arrF = array.dup('f'); + // arrC->printBuffer("arrC"); - NDArray array('c', {2,3}, {1,2,3,4,5,6}); - auto arrC = array.dup('c'); - auto arrF = array.dup('f'); - // arrC->printBuffer("arrC"); + // arrF->printBuffer("arrF"); + // arrC->printShapeInfo("C shape"); + // arrF->printShapeInfo("F shape"); - // arrF->printBuffer("arrF"); - //arrC->printShapeInfo("C shape"); - //arrF->printShapeInfo("F shape"); + ASSERT_TRUE(array.equalsTo(arrF)); + ASSERT_TRUE(array.equalsTo(arrC)); - ASSERT_TRUE(array.equalsTo(arrF)); - ASSERT_TRUE(array.equalsTo(arrC)); - - ASSERT_TRUE(arrF.equalsTo(arrC)); + ASSERT_TRUE(arrF.equalsTo(arrC)); } ////////////////////////////////////////////////////////////////////////// TEST_F(NDArrayCudaBasicsTests, equalsTo_1) { + NDArray x('c', {2, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, sd::DataType::DOUBLE); + NDArray y('c', {2, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, sd::DataType::DOUBLE); - NDArray x('c', {2,5}, {1,2,3,4,5,6,7,8,9,10}, sd::DataType::DOUBLE); - NDArray y('c', {2,5}, {1,2,3,4,5,6,7,8,9,10}, sd::DataType::DOUBLE); - - ASSERT_TRUE(x.equalsTo(y)); + ASSERT_TRUE(x.equalsTo(y)); - x.permutei({1,0}); - y.permutei({1,0}); + x.permutei({1, 0}); + y.permutei({1, 0}); - ASSERT_TRUE(x.equalsTo(y)); + ASSERT_TRUE(x.equalsTo(y)); } ////////////////////////////////////////////////////////////////////////// TEST_F(NDArrayCudaBasicsTests, equalsTo_2) { + NDArray x('c', {2, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 10, 10}, + sd::DataType::DOUBLE); + NDArray y('c', {2, 5}, {1, 2, 5, 4, 5, 6, 7, 8, 9, 10}, sd::DataType::DOUBLE); - NDArray x('c', {2,5}, {1,2,3,4,5,6,7,8,10,10}, sd::DataType::DOUBLE); - NDArray y('c', {2,5}, {1,2,5,4,5,6,7,8,9,10}, sd::DataType::DOUBLE); - - ASSERT_FALSE(x.equalsTo(y)); + ASSERT_FALSE(x.equalsTo(y)); - x.permutei({1,0}); - y.permutei({1,0}); + x.permutei({1, 0}); + y.permutei({1, 0}); - ASSERT_FALSE(x.equalsTo(y)); + ASSERT_FALSE(x.equalsTo(y)); } ////////////////////////////////////////////////////////////////////////// TEST_F(NDArrayCudaBasicsTests, equalsTo_3) { + NDArray x('c', {2, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, sd::DataType::DOUBLE); + NDArray y('c', {2, 5}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f}, + sd::DataType::FLOAT32); - NDArray x('c', {2,5}, {1,2,3,4,5,6,7,8,9,10}, sd::DataType::DOUBLE); - NDArray y('c', {2,5}, {1.f,2.f,3.f,4.f,5.f,6.f,7.f,8.f,9.f,10.f}, sd::DataType::FLOAT32); - - ASSERT_FALSE(x.equalsTo(y)); + ASSERT_FALSE(x.equalsTo(y)); - x.permutei({1,0}); - y.permutei({1,0}); + x.permutei({1, 0}); + y.permutei({1, 0}); - ASSERT_FALSE(x.equalsTo(y)); + ASSERT_FALSE(x.equalsTo(y)); } //////////////////////////////////////////////////////////////////////////// TEST_F(NDArrayCudaBasicsTests, applyReduce3_1) { + NDArray x('c', {2, 3, 4}, {-10, -9, -8, -7, -6, -5, -4, -3, -2, -1, 0, 1, + 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13}, + sd::DataType::INT32); + NDArray x2('c', {2, 3, 4}, {-10, -9, -8, -7, -6, -5, -4, -3, -2, -1, 0, 1, + 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13}, + sd::DataType::INT32); + NDArray y('c', {2, 3, 4}, {-2, 3, -4, 5, -2, 3, -4, 5, -2, 3, -4, 5, + -2, 3, -4, 5, -2, 3, -4, 5, -2, 3, -4, 5}, + sd::DataType::INT32); + NDArray k('c', {2, 3}, {-2, 3, -4, 5, -2, 3}, sd::DataType::INT32); + NDArray k2('c', {3, 2}, {-2, 3, -4, 5, -2, 3}, sd::DataType::INT32); - NDArray x('c', {2,3,4}, {-10,-9,-8,-7,-6,-5,-4,-3,-2,-1,0,1,2,3,4,5,6,7,8,9,10,11,12,13}, sd::DataType::INT32); - NDArray x2('c', {2,3,4}, {-10,-9,-8,-7,-6,-5,-4,-3,-2,-1,0,1,2,3,4,5,6,7,8,9,10,11,12,13}, sd::DataType::INT32); - NDArray y('c', {2,3,4}, {-2,3,-4,5,-2,3,-4,5,-2,3,-4,5,-2,3,-4,5,-2,3,-4,5,-2,3,-4,5}, sd::DataType::INT32); - NDArray k('c', {2,3}, {-2,3,-4,5,-2,3}, sd::DataType::INT32); - NDArray k2('c', {3,2}, {-2,3,-4,5,-2,3}, sd::DataType::INT32); + NDArray exp1('c', {3}, {4.f, 20.f, 36.f}, sd::DataType::FLOAT32); + NDArray exp2('c', {2, 3}, {-10.f, -2.f, 6.f, 14.f, 22.f, 30.f}, + sd::DataType::FLOAT32); + NDArray exp3('c', {4}, {38.f, 41.f, 44.f, 47.f}, sd::DataType::FLOAT32); + NDArray exp4('c', {4}, {114.f, 117.f, 120.f, 123.f}, sd::DataType::FLOAT32); - NDArray exp1('c', {3}, {4.f, 20.f, 36.f}, sd::DataType::FLOAT32); - NDArray exp2('c', {2,3}, {-10.f, -2.f, 6.f,14.f, 22.f, 30.f}, sd::DataType::FLOAT32); - NDArray exp3('c', {4}, {38.f, 41.f, 44.f, 47.f}, sd::DataType::FLOAT32); - NDArray exp4('c', {4}, {114.f, 117.f, 120.f, 123.f}, sd::DataType::FLOAT32); + NDArray z = x.applyReduce3(sd::reduce3::Dot, y, {0, 2}); + ASSERT_TRUE(z.equalsTo(&exp1)); + z = x.applyReduce3(sd::reduce3::Dot, k, {0, 1}); + ASSERT_TRUE(z.equalsTo(&exp3)); - NDArray z = x.applyReduce3(sd::reduce3::Dot, y, {0,2}); - ASSERT_TRUE(z.equalsTo(&exp1)); + x.permutei({0, 2, 1}); + y.permutei({0, 2, 1}); - z = x.applyReduce3(sd::reduce3::Dot, k, {0,1}); - ASSERT_TRUE(z.equalsTo(&exp3)); + z = y.applyReduce3(sd::reduce3::Dot, x, {1}); + ASSERT_TRUE(z.equalsTo(&exp2)); - x.permutei({0,2,1}); - y.permutei({0,2,1}); + x2.permutei({1, 0, 2}); - z = y.applyReduce3(sd::reduce3::Dot, x, {1}); - ASSERT_TRUE(z.equalsTo(&exp2)); - - x2.permutei({1,0,2}); - - z = x2.applyReduce3(sd::reduce3::Dot, k2, {0,1}); - ASSERT_TRUE(z.equalsTo(&exp4)); + z = x2.applyReduce3(sd::reduce3::Dot, k2, {0, 1}); + ASSERT_TRUE(z.equalsTo(&exp4)); } //////////////////////////////////////////////////////////////////////////// TEST_F(NDArrayCudaBasicsTests, applyReduce3_2) { + NDArray x('c', {2, 3, 4}, {-10, -9, -8.5, -7, -6, -5, -4, -3, -2, -1, 0, 1, + 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13}, + sd::DataType::DOUBLE); + NDArray x2('c', {2, 3, 4}, {-10, -9, -8, -7, -6, -5, -4, -3, -2, -1, 0.5, 1, + 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13}, + sd::DataType::DOUBLE); + NDArray y('c', {2, 3, 4}, {-2, 3, -4, 5, -2, 3, -4, 5, -2, 3, -4, 5, + -2.5, 3, -4, 5, -2, 3, -4, 5, -2, 3, -4, 5}, + sd::DataType::DOUBLE); + NDArray k('c', {2, 3}, {-2, 3, -4, 5.5, -2, 3}, sd::DataType::DOUBLE); + NDArray k2('c', {3, 2}, {-2, 3, -4, 5, -2, 3.5}, sd::DataType::DOUBLE); - NDArray x('c', {2,3,4}, {-10,-9,-8.5,-7,-6,-5,-4,-3,-2,-1,0,1,2,3,4,5,6,7,8,9,10,11,12,13}, sd::DataType::DOUBLE); - NDArray x2('c', {2,3,4}, {-10,-9,-8,-7,-6,-5,-4,-3,-2,-1,0.5,1,2,3,4,5,6,7,8,9,10,11,12,13}, sd::DataType::DOUBLE); - NDArray y('c', {2,3,4}, {-2,3,-4,5,-2,3,-4,5,-2,3,-4,5,-2.5,3,-4,5,-2,3,-4,5,-2,3,-4,5}, sd::DataType::DOUBLE); - NDArray k('c', {2,3}, {-2,3,-4,5.5,-2,3}, sd::DataType::DOUBLE); - NDArray k2('c', {3,2}, {-2,3,-4,5,-2,3.5}, sd::DataType::DOUBLE); - - NDArray exp1('c', {3}, {5., 20., 36.}, sd::DataType::DOUBLE); - NDArray exp2('c', {2,3}, {-8., -2., 6., 13., 22., 30.}, sd::DataType::DOUBLE); - NDArray exp3('c', {4}, {39., 42.5, 47., 49.5}, sd::DataType::DOUBLE); - NDArray exp4('c', {4}, {119., 122.5, 125., 129.5}, sd::DataType::DOUBLE); + NDArray exp1('c', {3}, {5., 20., 36.}, sd::DataType::DOUBLE); + NDArray exp2('c', {2, 3}, {-8., -2., 6., 13., 22., 30.}, + sd::DataType::DOUBLE); + NDArray exp3('c', {4}, {39., 42.5, 47., 49.5}, sd::DataType::DOUBLE); + NDArray exp4('c', {4}, {119., 122.5, 125., 129.5}, sd::DataType::DOUBLE); - NDArray z = x.applyReduce3(sd::reduce3::Dot, y, {0,2}); - ASSERT_TRUE(z.equalsTo(&exp1)); + NDArray z = x.applyReduce3(sd::reduce3::Dot, y, {0, 2}); + ASSERT_TRUE(z.equalsTo(&exp1)); - z = x.applyReduce3(sd::reduce3::Dot, k, {0,1}); - ASSERT_TRUE(z.equalsTo(&exp3)); + z = x.applyReduce3(sd::reduce3::Dot, k, {0, 1}); + ASSERT_TRUE(z.equalsTo(&exp3)); - x.permutei({0,2,1}); - y.permutei({0,2,1}); + x.permutei({0, 2, 1}); + y.permutei({0, 2, 1}); - z = y.applyReduce3(sd::reduce3::Dot, x, {1}); - ASSERT_TRUE(z.equalsTo(&exp2)); + z = y.applyReduce3(sd::reduce3::Dot, x, {1}); + ASSERT_TRUE(z.equalsTo(&exp2)); - x2.permutei({1,0,2}); + x2.permutei({1, 0, 2}); - z = x2.applyReduce3(sd::reduce3::Dot, k2, {0,1}); - ASSERT_TRUE(z.equalsTo(&exp4)); + z = x2.applyReduce3(sd::reduce3::Dot, k2, {0, 1}); + ASSERT_TRUE(z.equalsTo(&exp4)); } //////////////////////////////////////////////////////////////////////////// TEST_F(NDArrayCudaBasicsTests, applyReduce3_3) { + NDArray x1('c', {2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8}, sd::DataType::INT32); + NDArray x2('c', {2, 2, 2}, {-1, -2, -3, -4, -5, -6, -7, -8}, + sd::DataType::INT32); + NDArray x3('c', {3, 2}, {1.5, 1.5, 1.5, 1.5, 1.5, 1.5}, sd::DataType::DOUBLE); + NDArray x4('c', {3, 2}, {1, 2, 3, 4, 5, 6}, sd::DataType::DOUBLE); - NDArray x1('c', {2,2,2}, {1,2,3,4,5,6,7,8}, sd::DataType::INT32); - NDArray x2('c', {2,2,2}, {-1,-2,-3,-4,-5,-6,-7,-8}, sd::DataType::INT32); - NDArray x3('c', {3,2}, {1.5,1.5,1.5,1.5,1.5,1.5}, sd::DataType::DOUBLE); - NDArray x4('c', {3,2}, {1,2,3,4,5,6}, sd::DataType::DOUBLE); - - NDArray exp1('c', {}, std::vector{-204}, sd::DataType::FLOAT32); - NDArray exp2('c', {}, std::vector{31.5}, sd::DataType::DOUBLE); - + NDArray exp1('c', {}, std::vector{-204}, sd::DataType::FLOAT32); + NDArray exp2('c', {}, std::vector{31.5}, sd::DataType::DOUBLE); - auto z = x1.applyReduce3(reduce3::Dot, x2); - ASSERT_TRUE(z.equalsTo(&exp1)); + auto z = x1.applyReduce3(reduce3::Dot, x2); + ASSERT_TRUE(z.equalsTo(&exp1)); - z = x3.applyReduce3(reduce3::Dot, x4); - ASSERT_TRUE(z.equalsTo(&exp2)); + z = x3.applyReduce3(reduce3::Dot, x4); + ASSERT_TRUE(z.equalsTo(&exp2)); - x1.permutei({2,1,0}); - x2.permutei({2,1,0}); - x3.permutei({1,0}); - x4.permutei({1,0}); + x1.permutei({2, 1, 0}); + x2.permutei({2, 1, 0}); + x3.permutei({1, 0}); + x4.permutei({1, 0}); - z = x1.applyReduce3(reduce3::Dot, x2); - ASSERT_TRUE(z.equalsTo(&exp1)); + z = x1.applyReduce3(reduce3::Dot, x2); + ASSERT_TRUE(z.equalsTo(&exp1)); - z = x3.applyReduce3(reduce3::Dot, x4); - ASSERT_TRUE(z.equalsTo(&exp2)); + z = x3.applyReduce3(reduce3::Dot, x4); + ASSERT_TRUE(z.equalsTo(&exp2)); } //////////////////////////////////////////////////////////////////////////// TEST_F(NDArrayCudaBasicsTests, applyAllReduce3_1) { - - NDArray x1('c', {2,3,2}, {1,2,3,4,5,6,7,8,-1,-2,-3,-4,}, sd::DataType::INT32); - NDArray x2('c', {2,2,2}, {-1,-2,-3,-4,-5,-6,-7,-8}, sd::DataType::INT32); - NDArray x3('c', {3,2}, {1.5,1.5,1.5,1.5,1.5,1.5}, sd::DataType::DOUBLE); - NDArray x4('c', {3,2}, {1,2,3,4,5,6}, sd::DataType::DOUBLE); - - NDArray exp1('c', {3,2}, {-88.f, -124.f, 6.f, -2.f, 22.f, 14.f}, sd::DataType::FLOAT32); - NDArray exp2('c', {6,4}, {-36.f, -44.f, -52.f, -60.f,-42.f, -52.f, -62.f, -72.f, 2.f, 0.f, -2.f, - -4.f, 6.f, 4.f, 2.f, 0.f, 10.f, 8.f, 6.f, 4.f, 14.f, 12.f, 10.f, 8.f}, - sd::DataType::FLOAT32); - NDArray exp3('c', {1,1}, std::vector{31.5}, sd::DataType::DOUBLE); - NDArray exp4('c', {3,3}, {4.5, 10.5, 16.5,4.5, 10.5, 16.5,4.5, 10.5, 16.5}, sd::DataType::DOUBLE); - - auto z = x1.applyAllReduce3(reduce3::Dot, x2, {0,2}); - ASSERT_TRUE(z.equalsTo(&exp1)); - - z = x1.applyAllReduce3(reduce3::Dot, x2, {0}); - ASSERT_TRUE(z.equalsTo(&exp2)); - - z = x3.applyAllReduce3(reduce3::Dot, x4, {0,1}); - ASSERT_TRUE(z.equalsTo(&exp3)); - - z = x3.applyAllReduce3(reduce3::Dot, x4, {1}); - ASSERT_TRUE(z.equalsTo(&exp4)); - - x1.permutei({2,1,0}); - x2.permutei({2,1,0}); - x3.permutei({1,0}); - x4.permutei({1,0}); - - z = x1.applyAllReduce3(reduce3::Dot, x2, {0,2}); - ASSERT_TRUE(z.equalsTo(&exp1)); - - z = x3.applyAllReduce3(reduce3::Dot, x4, {0}); - ASSERT_TRUE(z.equalsTo(&exp4)); + NDArray x1('c', {2, 3, 2}, + { + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + -1, + -2, + -3, + -4, + }, + sd::DataType::INT32); + NDArray x2('c', {2, 2, 2}, {-1, -2, -3, -4, -5, -6, -7, -8}, + sd::DataType::INT32); + NDArray x3('c', {3, 2}, {1.5, 1.5, 1.5, 1.5, 1.5, 1.5}, sd::DataType::DOUBLE); + NDArray x4('c', {3, 2}, {1, 2, 3, 4, 5, 6}, sd::DataType::DOUBLE); + + NDArray exp1('c', {3, 2}, {-88.f, -124.f, 6.f, -2.f, 22.f, 14.f}, + sd::DataType::FLOAT32); + NDArray exp2('c', {6, 4}, + {-36.f, -44.f, -52.f, -60.f, -42.f, -52.f, -62.f, -72.f, + 2.f, 0.f, -2.f, -4.f, 6.f, 4.f, 2.f, 0.f, + 10.f, 8.f, 6.f, 4.f, 14.f, 12.f, 10.f, 8.f}, + sd::DataType::FLOAT32); + NDArray exp3('c', {1, 1}, std::vector{31.5}, sd::DataType::DOUBLE); + NDArray exp4('c', {3, 3}, {4.5, 10.5, 16.5, 4.5, 10.5, 16.5, 4.5, 10.5, 16.5}, + sd::DataType::DOUBLE); + + auto z = x1.applyAllReduce3(reduce3::Dot, x2, {0, 2}); + ASSERT_TRUE(z.equalsTo(&exp1)); + + z = x1.applyAllReduce3(reduce3::Dot, x2, {0}); + ASSERT_TRUE(z.equalsTo(&exp2)); + + z = x3.applyAllReduce3(reduce3::Dot, x4, {0, 1}); + ASSERT_TRUE(z.equalsTo(&exp3)); + + z = x3.applyAllReduce3(reduce3::Dot, x4, {1}); + ASSERT_TRUE(z.equalsTo(&exp4)); + + x1.permutei({2, 1, 0}); + x2.permutei({2, 1, 0}); + x3.permutei({1, 0}); + x4.permutei({1, 0}); + + z = x1.applyAllReduce3(reduce3::Dot, x2, {0, 2}); + ASSERT_TRUE(z.equalsTo(&exp1)); + + z = x3.applyAllReduce3(reduce3::Dot, x4, {0}); + ASSERT_TRUE(z.equalsTo(&exp4)); } ////////////////////////////////////////////////////////////////////////////// TEST_F(NDArrayCudaBasicsTests, applyIndexReduce_test1) { + NDArray x('c', {2, 3}, {0, 10, 1, 2, 2.5, -4}, sd::DataType::DOUBLE); - NDArray x('c', {2,3}, {0, 10, 1, 2, 2.5,-4}, sd::DataType::DOUBLE); + NDArray scalar('c', {}, std::vector{100}, sd::DataType::INT64); + NDArray vec1('c', {2}, {100, 100}, sd::DataType::INT64); + NDArray vec2('c', {3}, {100, 100, 100}, sd::DataType::INT64); - NDArray scalar('c', {}, std::vector{100}, sd::DataType::INT64); - NDArray vec1('c', {2}, {100,100}, sd::DataType::INT64); - NDArray vec2('c', {3}, {100,100,100}, sd::DataType::INT64); + NDArray exp1('c', {}, std::vector{1}, sd::DataType::INT64); + NDArray exp2('c', {2}, {1, 1}, sd::DataType::INT64); + NDArray exp3('c', {3}, {1, 0, 0}, sd::DataType::INT64); - NDArray exp1('c', {}, std::vector{1}, sd::DataType::INT64); - NDArray exp2('c', {2}, {1,1}, sd::DataType::INT64); - NDArray exp3('c', {3}, {1,0,0}, sd::DataType::INT64); + NDArray exp4('c', {}, std::vector{2}, sd::DataType::INT64); + NDArray exp5('c', {2}, {1, 1}, sd::DataType::INT64); + NDArray exp6('c', {3}, {1, 0, 0}, sd::DataType::INT64); - NDArray exp4('c', {}, std::vector{2}, sd::DataType::INT64); - NDArray exp5('c', {2}, {1,1}, sd::DataType::INT64); - NDArray exp6('c', {3}, {1,0,0}, sd::DataType::INT64); + x.applyIndexReduce(sd::indexreduce::IndexMax, scalar, {0, 1}); + ASSERT_TRUE(scalar.equalsTo(&exp1)); - x.applyIndexReduce(sd::indexreduce::IndexMax, scalar, {0,1}); - ASSERT_TRUE(scalar.equalsTo(&exp1)); + x.applyIndexReduce(sd::indexreduce::IndexMax, vec1, {1}); + ASSERT_TRUE(vec1.equalsTo(&exp2)); - x.applyIndexReduce(sd::indexreduce::IndexMax, vec1, {1}); - ASSERT_TRUE(vec1.equalsTo(&exp2)); + x.applyIndexReduce(sd::indexreduce::IndexMax, vec2, {0}); + ASSERT_TRUE(vec2.equalsTo(&exp3)); - x.applyIndexReduce(sd::indexreduce::IndexMax, vec2, {0}); - ASSERT_TRUE(vec2.equalsTo(&exp3)); + x.permutei({1, 0}); - x.permutei({1,0}); + x.applyIndexReduce(sd::indexreduce::IndexMax, scalar, {0, 1}); + ASSERT_TRUE(scalar.equalsTo(&exp4)); - x.applyIndexReduce(sd::indexreduce::IndexMax, scalar, {0,1}); - ASSERT_TRUE(scalar.equalsTo(&exp4)); + x.applyIndexReduce(sd::indexreduce::IndexMax, vec1, {0}); + ASSERT_TRUE(vec1.equalsTo(&exp5)); - x.applyIndexReduce(sd::indexreduce::IndexMax, vec1, {0}); - ASSERT_TRUE(vec1.equalsTo(&exp5)); - - x.applyIndexReduce(sd::indexreduce::IndexMax, vec2, {1}); - ASSERT_TRUE(vec2.equalsTo(&exp6)); + x.applyIndexReduce(sd::indexreduce::IndexMax, vec2, {1}); + ASSERT_TRUE(vec2.equalsTo(&exp6)); } - ////////////////////////////////////////////////////////////////////////////// TEST_F(NDArrayCudaBasicsTests, applyIndexReduce_test2) { + NDArray x('c', {2, 3}, {0, 10, 1, 2, 2.5, -4}, sd::DataType::DOUBLE); - NDArray x('c', {2,3}, {0, 10, 1, 2, 2.5,-4}, sd::DataType::DOUBLE); - - NDArray exp1('c', {}, std::vector{1}, sd::DataType::INT64); - NDArray exp2('c', {2}, {1,1}, sd::DataType::INT64); - NDArray exp3('c', {3}, {1,0,0}, sd::DataType::INT64); + NDArray exp1('c', {}, std::vector{1}, sd::DataType::INT64); + NDArray exp2('c', {2}, {1, 1}, sd::DataType::INT64); + NDArray exp3('c', {3}, {1, 0, 0}, sd::DataType::INT64); - NDArray exp4('c', {}, std::vector{2}, sd::DataType::INT64); - NDArray exp5('c', {2}, {1,1}, sd::DataType::INT64); - NDArray exp6('c', {3}, {1,0,0}, sd::DataType::INT64); + NDArray exp4('c', {}, std::vector{2}, sd::DataType::INT64); + NDArray exp5('c', {2}, {1, 1}, sd::DataType::INT64); + NDArray exp6('c', {3}, {1, 0, 0}, sd::DataType::INT64); - auto z = x.applyIndexReduce(sd::indexreduce::IndexMax, {0,1}); - ASSERT_TRUE(z.equalsTo(&exp1)); + auto z = x.applyIndexReduce(sd::indexreduce::IndexMax, {0, 1}); + ASSERT_TRUE(z.equalsTo(&exp1)); - z = x.applyIndexReduce(sd::indexreduce::IndexMax, {1}); - ASSERT_TRUE(z.equalsTo(&exp2)); + z = x.applyIndexReduce(sd::indexreduce::IndexMax, {1}); + ASSERT_TRUE(z.equalsTo(&exp2)); - z = x.applyIndexReduce(sd::indexreduce::IndexMax, {0}); - ASSERT_TRUE(z.equalsTo(&exp3)); + z = x.applyIndexReduce(sd::indexreduce::IndexMax, {0}); + ASSERT_TRUE(z.equalsTo(&exp3)); - x.permutei({1,0}); + x.permutei({1, 0}); - z = x.applyIndexReduce(sd::indexreduce::IndexMax, {0,1}); - ASSERT_TRUE(z.equalsTo(&exp4)); + z = x.applyIndexReduce(sd::indexreduce::IndexMax, {0, 1}); + ASSERT_TRUE(z.equalsTo(&exp4)); - z = x.applyIndexReduce(sd::indexreduce::IndexMax, {0}); - ASSERT_TRUE(z.equalsTo(&exp5)); + z = x.applyIndexReduce(sd::indexreduce::IndexMax, {0}); + ASSERT_TRUE(z.equalsTo(&exp5)); - z = x.applyIndexReduce(sd::indexreduce::IndexMax, {1}); - ASSERT_TRUE(z.equalsTo(&exp6)); + z = x.applyIndexReduce(sd::indexreduce::IndexMax, {1}); + ASSERT_TRUE(z.equalsTo(&exp6)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(NDArrayCudaBasicsTests, reduceAlongDimension_float_test1) { - - NDArray x('c', {2,3,2}, {1,2,3,4,5,6,7,8,-1,-2,-3,-4,}, sd::DataType::INT32); - - NDArray z1('c', {}, std::vector{100}, sd::DataType::DOUBLE); - NDArray z2('c', {2,2}, {100,100,100,100}, sd::DataType::FLOAT32); - NDArray z3('c', {3}, {100,100,100}, sd::DataType::DOUBLE); - NDArray z4('c', {3,2}, {100,100,100,100,100,100}, sd::DataType::FLOAT32); - NDArray z5('c', {2}, {100,100}, sd::DataType::FLOAT32); - - NDArray exp1('c', {}, std::vector{2.166667}, sd::DataType::DOUBLE); - NDArray exp2('c', {2,2}, {3.f,4.f,1.f,0.666667f}, sd::DataType::FLOAT32); - NDArray exp3('c', {3}, {4.5,1,1}, sd::DataType::DOUBLE); - NDArray exp4('c', {3,2}, {4,5,1,1,1,1}, sd::DataType::FLOAT32); - NDArray exp5('c', {2}, {3.5f,0.833333f}, sd::DataType::FLOAT32); - - x.reduceAlongDimension(sd::reduce::Mean, z1, {0,1,2}); - ASSERT_TRUE(z1.equalsTo(&exp1)); - - x.reduceAlongDimension(sd::reduce::Mean, z2, {1}); - ASSERT_TRUE(z2.equalsTo(&exp2)); - - x.reduceAlongDimension(sd::reduce::Mean, z3, {0,2}); - ASSERT_TRUE(z3.equalsTo(&exp3)); - - x.permutei({1,0,2}); // 3x2x2 - - x.reduceAlongDimension(sd::reduce::Mean, z1, {0,1,2}); - ASSERT_TRUE(z1.equalsTo(&exp1)); - - x.reduceAlongDimension(sd::reduce::Mean, z4, {1}); - ASSERT_TRUE(z4.equalsTo(&exp4)); - - x.reduceAlongDimension(sd::reduce::Mean, z5, {0,2}); - ASSERT_TRUE(z5.equalsTo(&exp5)); + NDArray x('c', {2, 3, 2}, + { + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + -1, + -2, + -3, + -4, + }, + sd::DataType::INT32); + + NDArray z1('c', {}, std::vector{100}, sd::DataType::DOUBLE); + NDArray z2('c', {2, 2}, {100, 100, 100, 100}, sd::DataType::FLOAT32); + NDArray z3('c', {3}, {100, 100, 100}, sd::DataType::DOUBLE); + NDArray z4('c', {3, 2}, {100, 100, 100, 100, 100, 100}, + sd::DataType::FLOAT32); + NDArray z5('c', {2}, {100, 100}, sd::DataType::FLOAT32); + + NDArray exp1('c', {}, std::vector{2.166667}, sd::DataType::DOUBLE); + NDArray exp2('c', {2, 2}, {3.f, 4.f, 1.f, 0.666667f}, sd::DataType::FLOAT32); + NDArray exp3('c', {3}, {4.5, 1, 1}, sd::DataType::DOUBLE); + NDArray exp4('c', {3, 2}, {4, 5, 1, 1, 1, 1}, sd::DataType::FLOAT32); + NDArray exp5('c', {2}, {3.5f, 0.833333f}, sd::DataType::FLOAT32); + + x.reduceAlongDimension(sd::reduce::Mean, z1, {0, 1, 2}); + ASSERT_TRUE(z1.equalsTo(&exp1)); + + x.reduceAlongDimension(sd::reduce::Mean, z2, {1}); + ASSERT_TRUE(z2.equalsTo(&exp2)); + + x.reduceAlongDimension(sd::reduce::Mean, z3, {0, 2}); + ASSERT_TRUE(z3.equalsTo(&exp3)); + + x.permutei({1, 0, 2}); // 3x2x2 + + x.reduceAlongDimension(sd::reduce::Mean, z1, {0, 1, 2}); + ASSERT_TRUE(z1.equalsTo(&exp1)); + + x.reduceAlongDimension(sd::reduce::Mean, z4, {1}); + ASSERT_TRUE(z4.equalsTo(&exp4)); + + x.reduceAlongDimension(sd::reduce::Mean, z5, {0, 2}); + ASSERT_TRUE(z5.equalsTo(&exp5)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(NDArrayCudaBasicsTests, reduceAlongDimension_float_test2) { - - NDArray x('c', {2,3,2}, {1,2,3,4,5,6,7,8,-1,-2,-3,-4,}, sd::DataType::DOUBLE); - - NDArray exp1('c', {}, std::vector{2.166667}, sd::DataType::DOUBLE); - NDArray exp2('c', {2,2}, {3,4,1,0.666667}, sd::DataType::DOUBLE); - NDArray exp3('c', {3}, {4.5,1,1}, sd::DataType::DOUBLE); - NDArray exp4('c', {3,2}, {4,5,1,1,1,1}, sd::DataType::DOUBLE); - NDArray exp5('c', {2}, {3.5,0.833333}, sd::DataType::DOUBLE); - - NDArray z1 = x.reduceAlongDimension(sd::reduce::Mean, {0,1,2}); - ASSERT_TRUE(z1.equalsTo(&exp1)); - - NDArray z2 = x.reduceAlongDimension(sd::reduce::Mean, {1}); - ASSERT_TRUE(z2.equalsTo(&exp2)); - - NDArray z3 = x.reduceAlongDimension(sd::reduce::Mean, {0,2}); - ASSERT_TRUE(z3.equalsTo(&exp3)); - - x.permutei({1,0,2}); // 3x2x2 - - NDArray z4 = x.reduceAlongDimension(sd::reduce::Mean, {0,1,2}); - ASSERT_TRUE(z4.equalsTo(&exp1)); - - NDArray z5 = x.reduceAlongDimension(sd::reduce::Mean, {1}); - ASSERT_TRUE(z5.equalsTo(&exp4)); - - NDArray z6 = x.reduceAlongDimension(sd::reduce::Mean, {0,2}); - ASSERT_TRUE(z6.equalsTo(&exp5)); + NDArray x('c', {2, 3, 2}, + { + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + -1, + -2, + -3, + -4, + }, + sd::DataType::DOUBLE); + + NDArray exp1('c', {}, std::vector{2.166667}, sd::DataType::DOUBLE); + NDArray exp2('c', {2, 2}, {3, 4, 1, 0.666667}, sd::DataType::DOUBLE); + NDArray exp3('c', {3}, {4.5, 1, 1}, sd::DataType::DOUBLE); + NDArray exp4('c', {3, 2}, {4, 5, 1, 1, 1, 1}, sd::DataType::DOUBLE); + NDArray exp5('c', {2}, {3.5, 0.833333}, sd::DataType::DOUBLE); + + NDArray z1 = x.reduceAlongDimension(sd::reduce::Mean, {0, 1, 2}); + ASSERT_TRUE(z1.equalsTo(&exp1)); + + NDArray z2 = x.reduceAlongDimension(sd::reduce::Mean, {1}); + ASSERT_TRUE(z2.equalsTo(&exp2)); + + NDArray z3 = x.reduceAlongDimension(sd::reduce::Mean, {0, 2}); + ASSERT_TRUE(z3.equalsTo(&exp3)); + + x.permutei({1, 0, 2}); // 3x2x2 + + NDArray z4 = x.reduceAlongDimension(sd::reduce::Mean, {0, 1, 2}); + ASSERT_TRUE(z4.equalsTo(&exp1)); + + NDArray z5 = x.reduceAlongDimension(sd::reduce::Mean, {1}); + ASSERT_TRUE(z5.equalsTo(&exp4)); + + NDArray z6 = x.reduceAlongDimension(sd::reduce::Mean, {0, 2}); + ASSERT_TRUE(z6.equalsTo(&exp5)); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayCudaBasicsTests, EqualityTest1) { - auto arrayA = NDArrayFactory::create_('f', {3, 5}); - auto arrayB = NDArrayFactory::create_('f', {3, 5}); - auto arrayC = NDArrayFactory::create_('f', {3, 5}); + auto arrayA = NDArrayFactory::create_('f', {3, 5}); + auto arrayB = NDArrayFactory::create_('f', {3, 5}); + auto arrayC = NDArrayFactory::create_('f', {3, 5}); - auto arrayD = NDArrayFactory::create_('f', {2, 4}); - auto arrayE = NDArrayFactory::create_('f', {1, 15}); + auto arrayD = NDArrayFactory::create_('f', {2, 4}); + auto arrayE = NDArrayFactory::create_('f', {1, 15}); - for (int i = 0; i < arrayA->rows(); i++) { - for (int k = 0; k < arrayA->columns(); k++) { - arrayA->p(i, k, (float) i); - } + for (int i = 0; i < arrayA->rows(); i++) { + for (int k = 0; k < arrayA->columns(); k++) { + arrayA->p(i, k, (float)i); } + } - for (int i = 0; i < arrayB->rows(); i++) { - for (int k = 0; k < arrayB->columns(); k++) { - arrayB->p(i, k, (float) i); - } + for (int i = 0; i < arrayB->rows(); i++) { + for (int k = 0; k < arrayB->columns(); k++) { + arrayB->p(i, k, (float)i); } + } - for (int i = 0; i < arrayC->rows(); i++) { - for (int k = 0; k < arrayC->columns(); k++) { - arrayC->p(i, k, (float) i+1); - } + for (int i = 0; i < arrayC->rows(); i++) { + for (int k = 0; k < arrayC->columns(); k++) { + arrayC->p(i, k, (float)i + 1); } + } - ASSERT_TRUE(arrayA->equalsTo(arrayB, 1e-5)); + ASSERT_TRUE(arrayA->equalsTo(arrayB, 1e-5)); - ASSERT_FALSE(arrayC->equalsTo(arrayB, 1e-5)); + ASSERT_FALSE(arrayC->equalsTo(arrayB, 1e-5)); - ASSERT_FALSE(arrayD->equalsTo(arrayB, 1e-5)); + ASSERT_FALSE(arrayD->equalsTo(arrayB, 1e-5)); - ASSERT_FALSE(arrayE->equalsTo(arrayB, 1e-5)); + ASSERT_FALSE(arrayE->equalsTo(arrayB, 1e-5)); - delete arrayA; - delete arrayB; - delete arrayC; - delete arrayD; - delete arrayE; + delete arrayA; + delete arrayB; + delete arrayC; + delete arrayD; + delete arrayE; } //////////////////////////////////////////////////////////////////////////////// TEST_F(NDArrayCudaBasicsTests, reduceAlongDimension_same_test1) { + NDArray x('c', {2, 3, 2}, + {1.5f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.5f, 8.f, -1.f, -2.f, -3.5f, -4.f}, + sd::DataType::FLOAT32); - NDArray x('c', {2,3,2}, {1.5f,2.f,3.f,4.f,5.f,6.f,7.5f,8.f,-1.f,-2.f,-3.5f,-4.f}, sd::DataType::FLOAT32); - - NDArray z1('c', {}, std::vector{100}, sd::DataType::FLOAT32); - NDArray z2('c', {2,2}, {100,100,100,100}, sd::DataType::FLOAT32); - NDArray z3('c', {3}, {100,100,100}, sd::DataType::FLOAT32); - NDArray z4('c', {3,2}, {100,100,100,100,100,100}, sd::DataType::FLOAT32); - NDArray z5('c', {2}, {100,100}, sd::DataType::FLOAT32); + NDArray z1('c', {}, std::vector{100}, sd::DataType::FLOAT32); + NDArray z2('c', {2, 2}, {100, 100, 100, 100}, sd::DataType::FLOAT32); + NDArray z3('c', {3}, {100, 100, 100}, sd::DataType::FLOAT32); + NDArray z4('c', {3, 2}, {100, 100, 100, 100, 100, 100}, + sd::DataType::FLOAT32); + NDArray z5('c', {2}, {100, 100}, sd::DataType::FLOAT32); - NDArray exp1('c', {}, std::vector{26.5f}, sd::DataType::FLOAT32); - NDArray exp2('c', {2,2}, {9.5f,12.f,3.f,2.f}, sd::DataType::FLOAT32); - NDArray exp3('c', {3}, {19.f,4.f,3.5f}, sd::DataType::FLOAT32); - NDArray exp4('c', {3,2}, {9.f,10.f,2.f,2.f,1.5f,2.f}, sd::DataType::FLOAT32); - NDArray exp5('c', {2}, {21.5f,5.f}, sd::DataType::FLOAT32); + NDArray exp1('c', {}, std::vector{26.5f}, sd::DataType::FLOAT32); + NDArray exp2('c', {2, 2}, {9.5f, 12.f, 3.f, 2.f}, sd::DataType::FLOAT32); + NDArray exp3('c', {3}, {19.f, 4.f, 3.5f}, sd::DataType::FLOAT32); + NDArray exp4('c', {3, 2}, {9.f, 10.f, 2.f, 2.f, 1.5f, 2.f}, + sd::DataType::FLOAT32); + NDArray exp5('c', {2}, {21.5f, 5.f}, sd::DataType::FLOAT32); - x.reduceAlongDimension(sd::reduce::Sum, z1, {0,1,2}); - ASSERT_TRUE(z1.equalsTo(&exp1)); + x.reduceAlongDimension(sd::reduce::Sum, z1, {0, 1, 2}); + ASSERT_TRUE(z1.equalsTo(&exp1)); - x.reduceAlongDimension(sd::reduce::Sum, z2, {1}); - ASSERT_TRUE(z2.equalsTo(&exp2)); + x.reduceAlongDimension(sd::reduce::Sum, z2, {1}); + ASSERT_TRUE(z2.equalsTo(&exp2)); - x.reduceAlongDimension(sd::reduce::Sum, z3, {0,2}); - ASSERT_TRUE(z3.equalsTo(&exp3)); + x.reduceAlongDimension(sd::reduce::Sum, z3, {0, 2}); + ASSERT_TRUE(z3.equalsTo(&exp3)); - x.permutei({1,0,2}); // 3x2x2 + x.permutei({1, 0, 2}); // 3x2x2 - x.reduceAlongDimension(sd::reduce::Sum, z1, {0,1,2}); - ASSERT_TRUE(z1.equalsTo(&exp1)); + x.reduceAlongDimension(sd::reduce::Sum, z1, {0, 1, 2}); + ASSERT_TRUE(z1.equalsTo(&exp1)); - x.reduceAlongDimension(sd::reduce::Sum, z4, {1}); - ASSERT_TRUE(z4.equalsTo(&exp4)); + x.reduceAlongDimension(sd::reduce::Sum, z4, {1}); + ASSERT_TRUE(z4.equalsTo(&exp4)); - x.reduceAlongDimension(sd::reduce::Sum, z5, {0,2}); - ASSERT_TRUE(z5.equalsTo(&exp5)); + x.reduceAlongDimension(sd::reduce::Sum, z5, {0, 2}); + ASSERT_TRUE(z5.equalsTo(&exp5)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(NDArrayCudaBasicsTests, reduceAlongDimension_same_test2) { - - NDArray x('c', {2,3,2}, {1.5,2,3,4,5,6,7.5,8,-1,-2,-3.5,-4,}, sd::DataType::INT64); - - NDArray exp1('c', {}, std::vector{26}, sd::DataType::INT64); - NDArray exp2('c', {2,2}, {9,12,3,2}, sd::DataType::INT64); - NDArray exp3('c', {3}, {18,4,4}, sd::DataType::INT64); - NDArray exp4('c', {3,2}, {8,10,2,2,2,2}, sd::DataType::INT64); - NDArray exp5('c', {2}, {21,5}, sd::DataType::INT64); - - NDArray z1 = x.reduceAlongDimension(sd::reduce::Sum, {0,1,2}); - ASSERT_TRUE(z1.equalsTo(&exp1)); - - NDArray z2 = x.reduceAlongDimension(sd::reduce::Sum, {1}); - ASSERT_TRUE(z2.equalsTo(&exp2)); - - NDArray z3 = x.reduceAlongDimension(sd::reduce::Sum, {0,2}); - ASSERT_TRUE(z3.equalsTo(&exp3)); - - x.permutei({1,0,2}); // 3x2x2 - - NDArray z4 = x.reduceAlongDimension(sd::reduce::Sum, {0,1,2}); - ASSERT_TRUE(z4.equalsTo(&exp1)); - - NDArray z5 = x.reduceAlongDimension(sd::reduce::Sum, {1}); - ASSERT_TRUE(z5.equalsTo(&exp4)); - - NDArray z6 = x.reduceAlongDimension(sd::reduce::Sum, {0,2}); - ASSERT_TRUE(z6.equalsTo(&exp5)); + NDArray x('c', {2, 3, 2}, + { + 1.5, + 2, + 3, + 4, + 5, + 6, + 7.5, + 8, + -1, + -2, + -3.5, + -4, + }, + sd::DataType::INT64); + + NDArray exp1('c', {}, std::vector{26}, sd::DataType::INT64); + NDArray exp2('c', {2, 2}, {9, 12, 3, 2}, sd::DataType::INT64); + NDArray exp3('c', {3}, {18, 4, 4}, sd::DataType::INT64); + NDArray exp4('c', {3, 2}, {8, 10, 2, 2, 2, 2}, sd::DataType::INT64); + NDArray exp5('c', {2}, {21, 5}, sd::DataType::INT64); + + NDArray z1 = x.reduceAlongDimension(sd::reduce::Sum, {0, 1, 2}); + ASSERT_TRUE(z1.equalsTo(&exp1)); + + NDArray z2 = x.reduceAlongDimension(sd::reduce::Sum, {1}); + ASSERT_TRUE(z2.equalsTo(&exp2)); + + NDArray z3 = x.reduceAlongDimension(sd::reduce::Sum, {0, 2}); + ASSERT_TRUE(z3.equalsTo(&exp3)); + + x.permutei({1, 0, 2}); // 3x2x2 + + NDArray z4 = x.reduceAlongDimension(sd::reduce::Sum, {0, 1, 2}); + ASSERT_TRUE(z4.equalsTo(&exp1)); + + NDArray z5 = x.reduceAlongDimension(sd::reduce::Sum, {1}); + ASSERT_TRUE(z5.equalsTo(&exp4)); + + NDArray z6 = x.reduceAlongDimension(sd::reduce::Sum, {0, 2}); + ASSERT_TRUE(z6.equalsTo(&exp5)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(NDArrayCudaBasicsTests, reduceAlongDimension_bool_test1) { + NDArray x('c', {2, 3, 2}, {0.5, 2, 3, -4, 5, 6, -7.5, 8, -1, -0.5, -3.5, 4}, + sd::DataType::DOUBLE); - NDArray x('c', {2,3,2}, {0.5,2,3,-4,5,6,-7.5,8,-1,-0.5,-3.5,4}, sd::DataType::DOUBLE); + NDArray z1('c', {}, std::vector{true}, sd::DataType::BOOL); + NDArray z2('c', {2, 2}, {true, true, true, true}, sd::DataType::BOOL); + NDArray z3('c', {3}, {true, true, true}, sd::DataType::BOOL); + NDArray z4('c', {3, 2}, {true, true, true, true, true, true}, + sd::DataType::BOOL); + NDArray z5('c', {2}, {true, true}, sd::DataType::BOOL); - NDArray z1('c', {}, std::vector{true}, sd::DataType::BOOL); - NDArray z2('c', {2,2}, {true,true,true,true}, sd::DataType::BOOL); - NDArray z3('c', {3}, {true,true,true}, sd::DataType::BOOL); - NDArray z4('c', {3,2}, {true,true,true,true,true,true}, sd::DataType::BOOL); - NDArray z5('c', {2}, {true,true}, sd::DataType::BOOL); + NDArray exp1('c', {}, std::vector{true}, sd::DataType::BOOL); + NDArray exp2('c', {2, 2}, {true, true, false, true}, sd::DataType::BOOL); + NDArray exp3('c', {3}, {true, true, true}, sd::DataType::BOOL); + NDArray exp4('c', {3, 2}, {true, true, true, false, true, true}, + sd::DataType::BOOL); + NDArray exp5('c', {2}, {true, true}, sd::DataType::BOOL); - NDArray exp1('c', {}, std::vector{true}, sd::DataType::BOOL); - NDArray exp2('c', {2,2}, {true,true,false,true}, sd::DataType::BOOL); - NDArray exp3('c', {3}, {true,true,true}, sd::DataType::BOOL); - NDArray exp4('c', {3,2}, {true,true,true,false,true,true}, sd::DataType::BOOL); - NDArray exp5('c', {2}, {true,true}, sd::DataType::BOOL); + x.reduceAlongDimension(sd::reduce::IsPositive, z1, {0, 1, 2}); + ASSERT_TRUE(z1.equalsTo(&exp1)); - x.reduceAlongDimension(sd::reduce::IsPositive, z1, {0,1,2}); - ASSERT_TRUE(z1.equalsTo(&exp1)); + x.reduceAlongDimension(sd::reduce::IsPositive, z2, {1}); + ASSERT_TRUE(z2.equalsTo(&exp2)); - x.reduceAlongDimension(sd::reduce::IsPositive, z2, {1}); - ASSERT_TRUE(z2.equalsTo(&exp2)); + x.reduceAlongDimension(sd::reduce::IsPositive, z3, {0, 2}); + ASSERT_TRUE(z3.equalsTo(&exp3)); - x.reduceAlongDimension(sd::reduce::IsPositive, z3, {0,2}); - ASSERT_TRUE(z3.equalsTo(&exp3)); + x.permutei({1, 0, 2}); // 3x2x2 - x.permutei({1,0,2}); // 3x2x2 + x.reduceAlongDimension(sd::reduce::IsPositive, z1, {0, 1, 2}); + ASSERT_TRUE(z1.equalsTo(&exp1)); - x.reduceAlongDimension(sd::reduce::IsPositive, z1, {0,1,2}); - ASSERT_TRUE(z1.equalsTo(&exp1)); + x.reduceAlongDimension(sd::reduce::IsPositive, z4, {1}); + ASSERT_TRUE(z4.equalsTo(&exp4)); - x.reduceAlongDimension(sd::reduce::IsPositive, z4, {1}); - ASSERT_TRUE(z4.equalsTo(&exp4)); - - x.reduceAlongDimension(sd::reduce::IsPositive, z5, {0,2}); - ASSERT_TRUE(z5.equalsTo(&exp5)); + x.reduceAlongDimension(sd::reduce::IsPositive, z5, {0, 2}); + ASSERT_TRUE(z5.equalsTo(&exp5)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(NDArrayCudaBasicsTests, reduceAlongDimension_bool_test2) { + NDArray x('c', {2, 3, 2}, {0.5, 2, 3, -4, 5, 6, -7.5, 8, -1, -0.5, -3.5, 4}, + sd::DataType::INT32); - NDArray x('c', {2,3,2}, {0.5,2,3,-4,5,6,-7.5,8,-1,-0.5,-3.5,4}, sd::DataType::INT32); - - NDArray exp1('c', {}, std::vector{1}, sd::DataType::BOOL); - NDArray exp2('c', {2,2}, {1,1,0,1}, sd::DataType::BOOL); - NDArray exp3('c', {3}, {1,1,1}, sd::DataType::BOOL); - NDArray exp4('c', {3,2}, {0,1,1,0,1,1}, sd::DataType::BOOL); - NDArray exp5('c', {2}, {1,1}, sd::DataType::BOOL); + NDArray exp1('c', {}, std::vector{1}, sd::DataType::BOOL); + NDArray exp2('c', {2, 2}, {1, 1, 0, 1}, sd::DataType::BOOL); + NDArray exp3('c', {3}, {1, 1, 1}, sd::DataType::BOOL); + NDArray exp4('c', {3, 2}, {0, 1, 1, 0, 1, 1}, sd::DataType::BOOL); + NDArray exp5('c', {2}, {1, 1}, sd::DataType::BOOL); - NDArray z1 = x.reduceAlongDimension(sd::reduce::IsPositive, {0,1,2}); - ASSERT_TRUE(z1.equalsTo(&exp1)); + NDArray z1 = x.reduceAlongDimension(sd::reduce::IsPositive, {0, 1, 2}); + ASSERT_TRUE(z1.equalsTo(&exp1)); - NDArray z2 = x.reduceAlongDimension(sd::reduce::IsPositive, {1}); - ASSERT_TRUE(z2.equalsTo(&exp2)); + NDArray z2 = x.reduceAlongDimension(sd::reduce::IsPositive, {1}); + ASSERT_TRUE(z2.equalsTo(&exp2)); - NDArray z3 = x.reduceAlongDimension(sd::reduce::IsPositive, {0,2}); - ASSERT_TRUE(z3.equalsTo(&exp3)); + NDArray z3 = x.reduceAlongDimension(sd::reduce::IsPositive, {0, 2}); + ASSERT_TRUE(z3.equalsTo(&exp3)); - x.permutei({1,0,2}); // 3x2x2 + x.permutei({1, 0, 2}); // 3x2x2 - NDArray z4 = x.reduceAlongDimension(sd::reduce::IsPositive, {0,1,2}); - ASSERT_TRUE(z4.equalsTo(&exp1)); + NDArray z4 = x.reduceAlongDimension(sd::reduce::IsPositive, {0, 1, 2}); + ASSERT_TRUE(z4.equalsTo(&exp1)); - NDArray z5 = x.reduceAlongDimension(sd::reduce::IsPositive, {1}); - ASSERT_TRUE(z5.equalsTo(&exp4)); + NDArray z5 = x.reduceAlongDimension(sd::reduce::IsPositive, {1}); + ASSERT_TRUE(z5.equalsTo(&exp4)); - NDArray z6 = x.reduceAlongDimension(sd::reduce::IsPositive, {0,2}); - ASSERT_TRUE(z6.equalsTo(&exp5)); + NDArray z6 = x.reduceAlongDimension(sd::reduce::IsPositive, {0, 2}); + ASSERT_TRUE(z6.equalsTo(&exp5)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(NDArrayCudaBasicsTests, reduceAlongDimension_long_test1) { + NDArray x( + 'c', {2, 3, 2}, + {0.5f, 2.f, 3.f, -0.f, 5.f, 6.f, -7.5f, 0.f, -1.f, -0.5f, -3.5f, 4.f}, + sd::DataType::FLOAT32); - NDArray x('c', {2,3,2}, {0.5f,2.f,3.f,-0.f,5.f,6.f,-7.5f,0.f,-1.f,-0.5f,-3.5f,4.f}, sd::DataType::FLOAT32); + NDArray z1('c', {}, std::vector{100}, sd::DataType::INT64); + NDArray z2('c', {2, 2}, {100, 100, 100, 100}, sd::DataType::INT64); + NDArray z3('c', {3}, {100, 100, 100}, sd::DataType::INT64); + NDArray z4('c', {3, 2}, {100, 100, 100, 100, 100, 100}, sd::DataType::INT64); + NDArray z5('c', {2}, {100, 100}, sd::DataType::INT64); - NDArray z1('c', {}, std::vector{100}, sd::DataType::INT64); - NDArray z2('c', {2,2}, {100,100,100,100}, sd::DataType::INT64); - NDArray z3('c', {3}, {100,100,100}, sd::DataType::INT64); - NDArray z4('c', {3,2}, {100,100,100,100,100,100}, sd::DataType::INT64); - NDArray z5('c', {2}, {100,100}, sd::DataType::INT64); + NDArray exp1('c', {}, std::vector{2}, sd::DataType::INT64); + NDArray exp2('c', {2, 2}, {0, 1, 0, 1}, sd::DataType::INT64); + NDArray exp3('c', {3}, {1, 1, 0}, sd::DataType::INT64); + NDArray exp4('c', {3, 2}, {0, 1, 0, 1, 0, 0}, sd::DataType::INT64); + NDArray exp5('c', {2}, {1, 1}, sd::DataType::INT64); - NDArray exp1('c', {}, std::vector{2}, sd::DataType::INT64); - NDArray exp2('c', {2,2}, {0,1,0,1}, sd::DataType::INT64); - NDArray exp3('c', {3}, {1,1,0}, sd::DataType::INT64); - NDArray exp4('c', {3,2}, {0,1,0,1,0,0}, sd::DataType::INT64); - NDArray exp5('c', {2}, {1,1}, sd::DataType::INT64); + x.reduceAlongDimension(sd::reduce::CountZero, z1, {0, 1, 2}); + ASSERT_TRUE(z1.equalsTo(&exp1)); - x.reduceAlongDimension(sd::reduce::CountZero, z1, {0,1,2}); - ASSERT_TRUE(z1.equalsTo(&exp1)); + x.reduceAlongDimension(sd::reduce::CountZero, z2, {1}); + ASSERT_TRUE(z2.equalsTo(&exp2)); - x.reduceAlongDimension(sd::reduce::CountZero, z2, {1}); - ASSERT_TRUE(z2.equalsTo(&exp2)); + x.reduceAlongDimension(sd::reduce::CountZero, z3, {0, 2}); + ASSERT_TRUE(z3.equalsTo(&exp3)); - x.reduceAlongDimension(sd::reduce::CountZero, z3, {0,2}); - ASSERT_TRUE(z3.equalsTo(&exp3)); + x.permutei({1, 0, 2}); // 3x2x2 - x.permutei({1,0,2}); // 3x2x2 + x.reduceAlongDimension(sd::reduce::CountZero, z1, {0, 1, 2}); + ASSERT_TRUE(z1.equalsTo(&exp1)); - x.reduceAlongDimension(sd::reduce::CountZero, z1, {0,1,2}); - ASSERT_TRUE(z1.equalsTo(&exp1)); + x.reduceAlongDimension(sd::reduce::CountZero, z4, {1}); + ASSERT_TRUE(z4.equalsTo(&exp4)); - x.reduceAlongDimension(sd::reduce::CountZero, z4, {1}); - ASSERT_TRUE(z4.equalsTo(&exp4)); - - x.reduceAlongDimension(sd::reduce::CountZero, z5, {0,2}); - ASSERT_TRUE(z5.equalsTo(&exp5)); + x.reduceAlongDimension(sd::reduce::CountZero, z5, {0, 2}); + ASSERT_TRUE(z5.equalsTo(&exp5)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(NDArrayCudaBasicsTests, reduceAlongDimension_long_test2) { + NDArray x('c', {2, 3, 2}, {0.5, 2, 3, -0, 5, 6, -7.5, 0, -1, -0.5, -3.5, 4}, + sd::DataType::INT32); - NDArray x('c', {2,3,2}, {0.5,2,3,-0,5,6,-7.5,0,-1,-0.5,-3.5,4}, sd::DataType::INT32); - - NDArray exp1('c', {}, std::vector{4}, sd::DataType::INT64); - NDArray exp2('c', {2,2}, {1,1,0,2}, sd::DataType::INT64); - NDArray exp3('c', {3}, {2,2,0}, sd::DataType::INT64); - NDArray exp4('c', {3,2}, {1,1,0,2,0,0}, sd::DataType::INT64); - NDArray exp5('c', {2}, {2,2}, sd::DataType::INT64); + NDArray exp1('c', {}, std::vector{4}, sd::DataType::INT64); + NDArray exp2('c', {2, 2}, {1, 1, 0, 2}, sd::DataType::INT64); + NDArray exp3('c', {3}, {2, 2, 0}, sd::DataType::INT64); + NDArray exp4('c', {3, 2}, {1, 1, 0, 2, 0, 0}, sd::DataType::INT64); + NDArray exp5('c', {2}, {2, 2}, sd::DataType::INT64); - NDArray z1 = x.reduceAlongDimension(sd::reduce::CountZero, {0,1,2}); - ASSERT_TRUE(z1.equalsTo(&exp1)); + NDArray z1 = x.reduceAlongDimension(sd::reduce::CountZero, {0, 1, 2}); + ASSERT_TRUE(z1.equalsTo(&exp1)); - NDArray z2 = x.reduceAlongDimension(sd::reduce::CountZero, {1}); - ASSERT_TRUE(z2.equalsTo(&exp2)); + NDArray z2 = x.reduceAlongDimension(sd::reduce::CountZero, {1}); + ASSERT_TRUE(z2.equalsTo(&exp2)); - NDArray z3 = x.reduceAlongDimension(sd::reduce::CountZero, {0,2}); - ASSERT_TRUE(z3.equalsTo(&exp3)); + NDArray z3 = x.reduceAlongDimension(sd::reduce::CountZero, {0, 2}); + ASSERT_TRUE(z3.equalsTo(&exp3)); - x.permutei({1,0,2}); // 3x2x2 + x.permutei({1, 0, 2}); // 3x2x2 - NDArray z4 = x.reduceAlongDimension(sd::reduce::CountZero, {0,1,2}); - ASSERT_TRUE(z4.equalsTo(&exp1)); + NDArray z4 = x.reduceAlongDimension(sd::reduce::CountZero, {0, 1, 2}); + ASSERT_TRUE(z4.equalsTo(&exp1)); - NDArray z5 = x.reduceAlongDimension(sd::reduce::CountZero, {1}); - ASSERT_TRUE(z5.equalsTo(&exp4)); + NDArray z5 = x.reduceAlongDimension(sd::reduce::CountZero, {1}); + ASSERT_TRUE(z5.equalsTo(&exp4)); - NDArray z6 = x.reduceAlongDimension(sd::reduce::CountZero, {0,2}); - ASSERT_TRUE(z6.equalsTo(&exp5)); + NDArray z6 = x.reduceAlongDimension(sd::reduce::CountZero, {0, 2}); + ASSERT_TRUE(z6.equalsTo(&exp5)); } TEST_F(NDArrayCudaBasicsTests, BroadcastOpsTest1) { + auto x = NDArrayFactory::create('c', {5, 5}); + auto z = NDArrayFactory::create('c', {5, 5}); + auto row = NDArrayFactory::linspace(1.0f, 5.0f, 5); + NDArray expRow('c', + { + 1, + 5, + }, + {1, 2, 3, 4, 5}, sd::DataType::FLOAT32); + NDArray exp('c', {5, 5}, {1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, + 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5}, + sd::DataType::FLOAT32); - auto x = NDArrayFactory::create('c', {5, 5}); - auto z = NDArrayFactory::create('c', {5, 5}); - auto row = NDArrayFactory::linspace(1.0f, 5.0f, 5); - NDArray expRow('c', {1, 5,}, {1,2,3,4,5}, sd::DataType::FLOAT32); - NDArray exp('c', {5,5}, {1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5}, sd::DataType::FLOAT32); - - ASSERT_TRUE(row->equalsTo(&expRow)); + ASSERT_TRUE(row->equalsTo(&expRow)); - x.applyBroadcast(broadcast::Add, {1}, *row, z); - x += *row; + x.applyBroadcast(broadcast::Add, {1}, *row, z); + x += *row; - ASSERT_TRUE(x.equalsTo(z)); - //ASSERT_TRUE(z.equalsTo(&exp)); + ASSERT_TRUE(x.equalsTo(z)); + // ASSERT_TRUE(z.equalsTo(&exp)); - delete row; + delete row; } TEST_F(NDArrayCudaBasicsTests, BroadcastOpsTest2) { - - auto x = NDArrayFactory::create('c', {5, 5}); - //auto z = NDArrayFactory::create('c', {5, 5}); - auto row = NDArrayFactory::linspace(1.0f, 5.0f, 5); - NDArray expRow('c', {1, 5,}, {1,2,3,4,5}, sd::DataType::FLOAT32); - NDArray exp('c', {5,5}, {1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5}, sd::DataType::FLOAT32); - - ASSERT_TRUE(row->equalsTo(&expRow)); - x.applyBroadcast(broadcast::Add, {1}, *row, x); - ASSERT_TRUE(x.equalsTo(&exp)); + auto x = NDArrayFactory::create('c', {5, 5}); + // auto z = NDArrayFactory::create('c', {5, 5}); + auto row = NDArrayFactory::linspace(1.0f, 5.0f, 5); + NDArray expRow('c', + { + 1, + 5, + }, + {1, 2, 3, 4, 5}, sd::DataType::FLOAT32); + NDArray exp('c', {5, 5}, {1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, + 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5}, + sd::DataType::FLOAT32); + + ASSERT_TRUE(row->equalsTo(&expRow)); + x.applyBroadcast(broadcast::Add, {1}, *row, x); + ASSERT_TRUE(x.equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayCudaBasicsTests, TestBroadcast_1) { + NDArray exp('c', {2, 3, 2, 2}, + {1., 1., 1., 1., 2., 2., 2., 2., 3., 3., 3., 3., + 1., 1., 1., 1., 2., 2., 2., 2., 3., 3., 3., 3.}, + sd::DataType::DOUBLE); - NDArray exp('c', {2, 3, 2, 2}, {1., 1., 1., 1., 2., 2., 2., 2., 3., 3., 3., 3., 1., 1., 1., 1., 2., 2., 2., 2., 3., 3., 3., 3.}, sd::DataType::DOUBLE); + auto input = NDArrayFactory::create('c', {2, 3, 2, 2}); + auto bias = NDArrayFactory::create('c', {1, 3}); - auto input = NDArrayFactory::create('c',{ 2, 3, 2, 2}); - auto bias = NDArrayFactory::create('c', {1, 3}); - - bias.linspace(1); - input.applyBroadcast(broadcast::Add, {1}, bias, input); - ASSERT_TRUE(exp.equalsTo(&input)); + bias.linspace(1); + input.applyBroadcast(broadcast::Add, {1}, bias, input); + ASSERT_TRUE(exp.equalsTo(&input)); } TEST_F(NDArrayCudaBasicsTests, TestFloat16_1) { - auto x = NDArrayFactory::create({1,2,3,4,5,7,8,9}); - auto y = NDArrayFactory::create({1,2,3,4,5,7,8,9}); - ASSERT_TRUE(x.equalsTo(&y)); + auto x = NDArrayFactory::create({1, 2, 3, 4, 5, 7, 8, 9}); + auto y = NDArrayFactory::create({1, 2, 3, 4, 5, 7, 8, 9}); + ASSERT_TRUE(x.equalsTo(&y)); } TEST_F(NDArrayCudaBasicsTests, TestFloat16_2) { - auto x = NDArrayFactory::create('c', {9}, {1,2,3,4,5,6,7,8,9}); - auto y = NDArrayFactory::create('c', {9}, {1,2,3,4,5,6,7,8,9}); - ASSERT_TRUE(x.equalsTo(y)); - //for (int e = 0; e < x.lengthOf(); e++) - // ASSERT_NEAR(x.e(e), y.e(e), 1.e-5f); + auto x = + NDArrayFactory::create('c', {9}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto y = + NDArrayFactory::create('c', {9}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + ASSERT_TRUE(x.equalsTo(y)); + // for (int e = 0; e < x.lengthOf(); e++) + // ASSERT_NEAR(x.e(e), y.e(e), 1.e-5f); } TEST_F(NDArrayCudaBasicsTests, TestFloat16_3) { - auto x = NDArrayFactory::create({1,2,3,4,5,7,8,9}); - auto y = NDArrayFactory::create({1,2,3,4,5,7,8,9}); - ASSERT_TRUE(x.equalsTo(&y)); + auto x = NDArrayFactory::create({1, 2, 3, 4, 5, 7, 8, 9}); + auto y = NDArrayFactory::create({1, 2, 3, 4, 5, 7, 8, 9}); + ASSERT_TRUE(x.equalsTo(&y)); } TEST_F(NDArrayCudaBasicsTests, TestFloat_4) { - auto x = NDArrayFactory::create({1,2,3,4,5,7,8,9}); - auto y = NDArrayFactory::create({2,4,5,5,6,7,8,9}); - ASSERT_FALSE(x.equalsTo(&y)); + auto x = NDArrayFactory::create({1, 2, 3, 4, 5, 7, 8, 9}); + auto y = NDArrayFactory::create({2, 4, 5, 5, 6, 7, 8, 9}); + ASSERT_FALSE(x.equalsTo(&y)); } TEST_F(NDArrayCudaBasicsTests, TestFloat_5) { - auto x = NDArrayFactory::create('c', {3,3}, {1,2,3,4,5,6,7,8,9}); - auto y = NDArrayFactory::create('c', {3,3}, {2,4,5,5,6,7,8,9, 10}); - ASSERT_FALSE(x.equalsTo(&y)); + auto x = + NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto y = + NDArrayFactory::create('c', {3, 3}, {2, 4, 5, 5, 6, 7, 8, 9, 10}); + ASSERT_FALSE(x.equalsTo(&y)); } TEST_F(NDArrayCudaBasicsTests, TestFloat_6) { - auto x = NDArrayFactory::create('f', {3,3}, {1,2,3,4,5,6,7,8,9}); - auto y = NDArrayFactory::create('f', {3,3}, {2,4,5,5,6,7,8,9,10}); - ASSERT_FALSE(x.equalsTo(&y)); + auto x = + NDArrayFactory::create('f', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto y = + NDArrayFactory::create('f', {3, 3}, {2, 4, 5, 5, 6, 7, 8, 9, 10}); + ASSERT_FALSE(x.equalsTo(&y)); } ////////////////////////////////////////////////////////////////////// -TEST_F(NDArrayCudaBasicsTests, Operator_Plus_Test_05) -{ - auto x = NDArrayFactory::create('c', {8, 8, 8}); - auto y = NDArrayFactory::create('c', {1, 8, 8}); - auto expected = NDArrayFactory::create('c', {8, 8, 8}); - NDArray res2 = NDArrayFactory::create(expected.ordering(), expected.getShapeAsVector()); - x = 1.; - y = 2.; - expected = 3.; - res2 = 0.f; +TEST_F(NDArrayCudaBasicsTests, Operator_Plus_Test_05) { + auto x = NDArrayFactory::create('c', {8, 8, 8}); + auto y = NDArrayFactory::create('c', {1, 8, 8}); + auto expected = NDArrayFactory::create('c', {8, 8, 8}); + NDArray res2 = NDArrayFactory::create(expected.ordering(), + expected.getShapeAsVector()); + x = 1.; + y = 2.; + expected = 3.; + res2 = 0.f; - x.applyTrueBroadcast(BroadcastOpsTuple::Add(), y, res2);// *= y; + x.applyTrueBroadcast(BroadcastOpsTuple::Add(), y, res2); // *= y; - ASSERT_TRUE(expected.isSameShape(&res2)); - ASSERT_TRUE(expected.equalsTo(&res2)); + ASSERT_TRUE(expected.isSameShape(&res2)); + ASSERT_TRUE(expected.equalsTo(&res2)); } ////////////////////////////////////////////////////////////////////// -TEST_F(NDArrayCudaBasicsTests, Operator_Plus_Test_5) -{ - auto x = NDArrayFactory::create('c', {8, 8, 8}); - auto y = NDArrayFactory::create('c', {8, 1, 8}); - auto expected = NDArrayFactory::create('c', {8, 8, 8}); - NDArray res2(expected); - x = 1.; - y = 2.; - expected = 3.; - //x.printBuffer("X="); - //y.printBuffer("Y="); - //expected.printBuffer("EXPECTED"); - auto result = x + y; - //result.printBuffer("1 + 2 ="); - //res2.assign(x + y); - - //x.applyTrueBroadcast(BroadcastOpsTuple::Add(), &y, &res2); - //res2.printBuffer("Z="); - //x.applyTrueBroadcast(BroadcastOpsTuple::Add(), &y, &res2);// *= y; -// x += y; - //x.printBuffer("OutputX"); - //res2.syncToHost(); - //res2.printBuffer("OUputZ"); - //x.printIndexedBuffer("OUtputX"); - ASSERT_TRUE(expected.isSameShape(&result)); - ASSERT_TRUE(expected.equalsTo(&result)); +TEST_F(NDArrayCudaBasicsTests, Operator_Plus_Test_5) { + auto x = NDArrayFactory::create('c', {8, 8, 8}); + auto y = NDArrayFactory::create('c', {8, 1, 8}); + auto expected = NDArrayFactory::create('c', {8, 8, 8}); + NDArray res2(expected); + x = 1.; + y = 2.; + expected = 3.; + // x.printBuffer("X="); + // y.printBuffer("Y="); + // expected.printBuffer("EXPECTED"); + auto result = x + y; + // result.printBuffer("1 + 2 ="); + // res2.assign(x + y); + + // x.applyTrueBroadcast(BroadcastOpsTuple::Add(), &y, &res2); + // res2.printBuffer("Z="); + // x.applyTrueBroadcast(BroadcastOpsTuple::Add(), &y, &res2);// *= y; + // x += y; + // x.printBuffer("OutputX"); + // res2.syncToHost(); + // res2.printBuffer("OUputZ"); + // x.printIndexedBuffer("OUtputX"); + ASSERT_TRUE(expected.isSameShape(&result)); + ASSERT_TRUE(expected.equalsTo(&result)); } ////////////////////////////////////////////////////////////////////// -TEST_F(NDArrayCudaBasicsTests, Operator_Plus_Test_51) -{ - auto x = NDArrayFactory::create('c', {8, 8, 8}); - auto y = NDArrayFactory::create('c', {8, 8}); - auto expected = NDArrayFactory::create('c', {8, 8, 8}); - NDArray res2(expected); - x = 1.; - y = 2.; - expected = 3.; - //x.printBuffer("X="); - //y.printBuffer("Y="); - //expected.printBuffer("EXPECTED"); - auto result = x + y; - //result.printBuffer("1 + 2 ="); - //res2.assign(x + y); - - //x.applyTrueBroadcast(BroadcastOpsTuple::Add(), &y, &res2); - //res2.printBuffer("Z="); - //x.applyTrueBroadcast(BroadcastOpsTuple::Add(), &y, &res2);// *= y; -// x += y; - //x.printBuffer("OutputX"); - //res2.syncToHost(); - //res2.printBuffer("OUputZ"); - //x.printIndexedBuffer("OUtputX"); - ASSERT_TRUE(expected.isSameShape(&result)); - ASSERT_TRUE(expected.equalsTo(&result)); -} - -TEST_F(NDArrayCudaBasicsTests, Tile_Test_2_1) -{ - auto x = NDArrayFactory::create('c', {2, 1, 2}); - x = 10.; - auto y = x.tile({1,2,1}); - auto exp = NDArrayFactory::create('c', {2, 2, 2}); - exp = 10.; - - // y.printShapeInfo("Output SHAPE"); - // y.printBuffer("Output TILE"); - // exp.printBuffer("Expect TILE"); - ASSERT_TRUE(exp.equalsTo(y)); -} - -TEST_F(NDArrayCudaBasicsTests, Tile_Test_2_2) -{ - auto x = NDArrayFactory::create('f', {2, 1, 2}); - x = 10.; - auto y = x.tile({1,2,1}); - auto exp = NDArrayFactory::create('f', {2, 2, 2}); - exp = 10.; - ASSERT_TRUE(exp.equalsTo(y)); -} - -TEST_F(NDArrayCudaBasicsTests, Tile_Test_2_3) -{ - auto x = NDArrayFactory::create('f', {2, 1, 2}); - x = 10.; - x.p(1,0,1, 20); - x.syncToDevice(); - auto y = x.tile({1,2,1}); - auto exp = NDArrayFactory::create('f', {2, 2, 2}); - exp = 10.; - exp.p(1,0,1, 20.); - exp.p(1, 1, 1, 20.); - exp.syncToDevice(); - ASSERT_TRUE(exp.equalsTo(y)); +TEST_F(NDArrayCudaBasicsTests, Operator_Plus_Test_51) { + auto x = NDArrayFactory::create('c', {8, 8, 8}); + auto y = NDArrayFactory::create('c', {8, 8}); + auto expected = NDArrayFactory::create('c', {8, 8, 8}); + NDArray res2(expected); + x = 1.; + y = 2.; + expected = 3.; + // x.printBuffer("X="); + // y.printBuffer("Y="); + // expected.printBuffer("EXPECTED"); + auto result = x + y; + // result.printBuffer("1 + 2 ="); + // res2.assign(x + y); + + // x.applyTrueBroadcast(BroadcastOpsTuple::Add(), &y, &res2); + // res2.printBuffer("Z="); + // x.applyTrueBroadcast(BroadcastOpsTuple::Add(), &y, &res2);// *= y; + // x += y; + // x.printBuffer("OutputX"); + // res2.syncToHost(); + // res2.printBuffer("OUputZ"); + // x.printIndexedBuffer("OUtputX"); + ASSERT_TRUE(expected.isSameShape(&result)); + ASSERT_TRUE(expected.equalsTo(&result)); +} + +TEST_F(NDArrayCudaBasicsTests, Tile_Test_2_1) { + auto x = NDArrayFactory::create('c', {2, 1, 2}); + x = 10.; + auto y = x.tile({1, 2, 1}); + auto exp = NDArrayFactory::create('c', {2, 2, 2}); + exp = 10.; + + // y.printShapeInfo("Output SHAPE"); + // y.printBuffer("Output TILE"); + // exp.printBuffer("Expect TILE"); + ASSERT_TRUE(exp.equalsTo(y)); +} + +TEST_F(NDArrayCudaBasicsTests, Tile_Test_2_2) { + auto x = NDArrayFactory::create('f', {2, 1, 2}); + x = 10.; + auto y = x.tile({1, 2, 1}); + auto exp = NDArrayFactory::create('f', {2, 2, 2}); + exp = 10.; + ASSERT_TRUE(exp.equalsTo(y)); +} + +TEST_F(NDArrayCudaBasicsTests, Tile_Test_2_3) { + auto x = NDArrayFactory::create('f', {2, 1, 2}); + x = 10.; + x.p(1, 0, 1, 20); + x.syncToDevice(); + auto y = x.tile({1, 2, 1}); + auto exp = NDArrayFactory::create('f', {2, 2, 2}); + exp = 10.; + exp.p(1, 0, 1, 20.); + exp.p(1, 1, 1, 20.); + exp.syncToDevice(); + ASSERT_TRUE(exp.equalsTo(y)); } ////////////////////////////////////////////////////////////////////// -TEST_F(NDArrayCudaBasicsTests, Operator_Plus_Test_2) -{ - double expBuff[] = {2., 3, 3., 4., 4., 5, 5., 6., 6., 7, 7., 8.}; - NDArray a('c', {4,4}, {1,2,3,4,5,6,7,8,9,2,3,2,1,0,4,7}, sd::DataType::FLOAT32); - auto x = NDArrayFactory::create('c', {3, 2, 1}); - auto y = NDArrayFactory::create('c', {1, 2}); - auto expected = NDArrayFactory::create(expBuff, 'c', {3, 2, 2}); +TEST_F(NDArrayCudaBasicsTests, Operator_Plus_Test_2) { + double expBuff[] = {2., 3, 3., 4., 4., 5, 5., 6., 6., 7, 7., 8.}; + NDArray a('c', {4, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 2, 3, 2, 1, 0, 4, 7}, + sd::DataType::FLOAT32); + auto x = NDArrayFactory::create('c', {3, 2, 1}); + auto y = NDArrayFactory::create('c', {1, 2}); + auto expected = NDArrayFactory::create(expBuff, 'c', {3, 2, 2}); - x.linspace(1); - y.linspace(1); - auto result = x + y; + x.linspace(1); + y.linspace(1); + auto result = x + y; - ASSERT_TRUE(expected.isSameShape(&result)); - ASSERT_TRUE(expected.equalsTo(&result)); + ASSERT_TRUE(expected.isSameShape(&result)); + ASSERT_TRUE(expected.equalsTo(&result)); } ////////////////////////////////////////////////////////////////////// -TEST_F(NDArrayCudaBasicsTests, assign_2) -{ - NDArray x('c', {4}, {1.5f,2.5f,3.5f,4.5f}, sd::DataType::FLOAT32); - NDArray y('c', {4}, sd::DataType::INT32); - NDArray expected('c', {4}, {1,2,3,4}, sd::DataType::INT32); +TEST_F(NDArrayCudaBasicsTests, assign_2) { + NDArray x('c', {4}, {1.5f, 2.5f, 3.5f, 4.5f}, sd::DataType::FLOAT32); + NDArray y('c', {4}, sd::DataType::INT32); + NDArray expected('c', {4}, {1, 2, 3, 4}, sd::DataType::INT32); - y.assign(x); - // y.printBuffer("ASSIGN VECTOR"); + y.assign(x); + // y.printBuffer("ASSIGN VECTOR"); - ASSERT_TRUE(expected.equalsTo(&y)); + ASSERT_TRUE(expected.equalsTo(&y)); } ////////////////////////////////////////////////////////////////////// -TEST_F(NDArrayCudaBasicsTests, subarray_1) -{ - NDArray x('c', {2,3,4}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24}, sd::DataType::FLOAT32); - NDArray y('f', {2,3,4}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24}, sd::DataType::FLOAT32); - - Nd4jLong shapeExpX0[] = {1, 2, 12, 8192, 1, 99}; - float buffExpX0[] = {1.f, 13.f}; - Nd4jLong shapeExpX1[] = {1, 2, 12, 8192, 1, 99}; - float buffExpX1[] = {2.f, 14.f}; - Nd4jLong shapeExpX2[] = {3, 2, 1, 1, 12, 4, 1, 8192, 1, 99}; - float buffExpX2[] = {1.f, 13.f}; - Nd4jLong shapeExpX3[] = {2, 2, 4, 12, 1, 8192, 1, 99}; - float buffExpX3[] = {9.f, 10.f, 11.f, 12.f, 21.f, 22.f, 23.f, 24.f}; - Nd4jLong shapeExpX4[] = {3, 2, 1, 4, 12, 4, 1, 8192, 1, 99}; - float buffExpX4[] = {9.f, 10.f, 11.f, 12.f, 21.f, 22.f, 23.f, 24.f}; - Nd4jLong shapeExpX5[] = {2, 2, 3, 12, 4, 8192, 1, 99}; - float buffExpX5[] = {4.f, 8.f, 12.f, 16.f, 20.f, 24.f}; - - Nd4jLong shapeExpY0[] = {1, 2, 1, 8192, 1, 99}; - float buffExpY0[] = {1.f, 2.f}; - Nd4jLong shapeExpY1[] = {1, 2, 1, 8192, 1, 99}; - float buffExpY1[] = {7.f, 8.f}; - Nd4jLong shapeExpY2[] = {3, 2, 1, 1, 1, 2, 6, 8192, 1, 102}; - float buffExpY2[] = {1.f, 2.f}; - Nd4jLong shapeExpY3[] = {2, 2, 4, 1, 6, 8192, 1, 99}; - float buffExpY3[] = {5.f, 11.f, 17.f, 23.f, 6.f, 12.f, 18.f, 24.f}; - Nd4jLong shapeExpY4[] = {3, 2, 1, 4, 1, 2, 6, 8192, 1, 102}; - float buffExpY4[] = {5.f, 11.f, 17.f, 23.f, 6.f, 12.f, 18.f, 24.f}; - Nd4jLong shapeExpY5[] = {2, 2, 3, 1, 2, 8192, 1, 99}; - float buffExpY5[] = {19.f, 21.f, 23.f, 20.f, 22.f, 24.f}; - - - NDArray x0 = x(0, {1,2}); - NDArray xExp(buffExpX0, shapeExpX0); - - ASSERT_TRUE(xExp.isSameShape(x0)); - ASSERT_TRUE(xExp.equalsTo(x0)); -// for(int i = 0; i < shape::shapeInfoLength(x0.rankOf()); ++i) -// ASSERT_TRUE(x0.shapeInfo()[i] == shapeExpX0[i]); -// for(int i = 0; i < x0.lengthOf(); ++i) -// ASSERT_TRUE(x0.e(i) == buffExpX0[i]); - - NDArray x1 = x(1, {1,2}); - NDArray x1Exp(buffExpX1, shapeExpX1); - ASSERT_TRUE(x1Exp.isSameShape(x1)); - ASSERT_TRUE(x1Exp.equalsTo(x1)); - -// for(int i = 0; i < shape::shapeInfoLength(x1.rankOf()); ++i) -// ASSERT_TRUE(x1.shapeInfo()[i] == shapeExpX1[i]); -// for(int i = 0; i < x1.lengthOf(); ++i) -// ASSERT_TRUE(x1.e(i) == buffExpX1[i]); - - NDArray x2 = x(0, {1,2}, true); - NDArray x2Exp(buffExpX2, shapeExpX2); - ASSERT_TRUE(x2Exp.isSameShape(x2)); -// x2.printBuffer("X2"); -// x2Exp.printBuffer("X2 EXPECT"); - ASSERT_TRUE(x2Exp.equalsTo(x2)); -// for(int i = 0; i < shape::shapeInfoLength(x2.rankOf()); ++i) -// ASSERT_TRUE(x2.shapeInfo()[i] == shapeExpX2[i]); -// for(int i = 0; i < x2.lengthOf(); ++i) -// ASSERT_TRUE(x2.e(i) == buffExpX2[i]); - - NDArray x3 = x(2, {1}); - NDArray x3Exp(buffExpX3, shapeExpX3); - ASSERT_TRUE(x3Exp.isSameShape(x3)); - ASSERT_TRUE(x3Exp.equalsTo(x3)); -// for(int i = 0; i < shape::shapeInfoLength(x3.rankOf()); ++i) -// ASSERT_TRUE(x3.shapeInfo()[i] == shapeExpX3[i]); -// for(int i = 0; i < x3.lengthOf(); ++i) -// ASSERT_TRUE(x3.e(i) == buffExpX3[i]); - - NDArray x4 = x(2, {1}, true); - NDArray x4Exp(buffExpX4, shapeExpX4); - ASSERT_TRUE(x4Exp.isSameShape(x4)); - ASSERT_TRUE(x4Exp.equalsTo(x4)); -// for(int i = 0; i < shape::shapeInfoLength(x4.rankOf()); ++i) -// ASSERT_TRUE(x4.shapeInfo()[i] == shapeExpX4[i]); -// for(int i = 0; i < x4.lengthOf(); ++i) -// ASSERT_TRUE(x4.e(i) == buffExpX4[i]); - - NDArray x5 = x(3, {2}); - NDArray x5Exp(buffExpX5, shapeExpX5); - ASSERT_TRUE(x5Exp.isSameShape(x5)); - ASSERT_TRUE(x5Exp.equalsTo(x5)); - -// for(int i = 0; i < shape::shapeInfoLength(x5.rankOf()); ++i) -// ASSERT_TRUE(x5.shapeInfo()[i] == shapeExpX5[i]); -// for(int i = 0; i < x5.lengthOf(); ++i) -// ASSERT_TRUE(x5.e(i) == buffExpX5[i]); - - // ******************* // - NDArray y0 = y(0, {1,2}); - NDArray y0Exp(buffExpY0, shapeExpY0); - ASSERT_TRUE(y0Exp.isSameShape(y0)); - ASSERT_TRUE(y0Exp.equalsTo(y0)); -// for(int i = 0; i < shape::shapeInfoLength(y0.rankOf()); ++i) -// ASSERT_TRUE(y0.shapeInfo()[i] == shapeExpY0[i]); -// for(int i = 0; i < y0.lengthOf(); ++i) -// ASSERT_TRUE(y0.e(i) == buffExpY0[i]); - - NDArray y1 = y(1, {1,2}); - NDArray y1Exp(buffExpY1, shapeExpY1); - ASSERT_TRUE(y1Exp.isSameShape(y1)); - ASSERT_TRUE(y1Exp.equalsTo(y1)); -// for(int i = 0; i < shape::shapeInfoLength(y1.rankOf()); ++i) -// ASSERT_TRUE(y1.shapeInfo()[i] == shapeExpY1[i]); -// for(int i = 0; i < y1.lengthOf(); ++i) -// ASSERT_TRUE(y1.e(i) == buffExpY1[i]); - - NDArray y2 = y(0, {1,2}, true); - NDArray y2Exp(buffExpY2, shapeExpY2); - ASSERT_TRUE(y2Exp.isSameShape(y2)); - ASSERT_TRUE(y2Exp.equalsTo(y2)); -// for(int i = 0; i < shape::shapeInfoLength(y2.rankOf()); ++i) -// ASSERT_TRUE(y2.shapeInfo()[i] == shapeExpY2[i]); -// for(int i = 0; i < y2.lengthOf(); ++i) -// ASSERT_TRUE(y2.e(i) == buffExpY2[i]); - - NDArray y3 = y(2, {1}); - NDArray y3Exp(buffExpY3, shapeExpY3); - ASSERT_TRUE(y3Exp.isSameShape(y3)); - ASSERT_TRUE(y3Exp.equalsTo(y3)); -// for(int i = 0; i < shape::shapeInfoLength(y3.rankOf()); ++i) -// ASSERT_TRUE(y3.shapeInfo()[i] == shapeExpY3[i]); -// for(int i = 0; i < y3.lengthOf(); ++i) -// ASSERT_TRUE(y3.e(i) == buffExpY3[i]); - - NDArray y4 = y(2, {1}, true); - NDArray y4Exp = NDArrayFactory::create('f', {2,1,4}, {5, 6, 11, 12, 17, 18, 23, 24}); - ASSERT_TRUE(y4Exp.isSameShape(y4)); - ASSERT_TRUE(y4Exp.equalsTo(y4)); -// for(int i = 0; i < shape::shapeInfoLength(y4.rankOf()); ++i) -// ASSERT_TRUE(y4.shapeInfo()[i] == shapeExpY4[i]); -// for(int i = 0; i < y4.lengthOf(); ++i) -// ASSERT_TRUE(y4.e(i) == buffExpY4[i]); - - NDArray y5 = y(3, {2}); - NDArray y5Exp(buffExpY5, shapeExpY5); - ASSERT_TRUE(y5Exp.isSameShape(y5)); - ASSERT_TRUE(y5Exp.equalsTo(y5)); -// for(int i = 0; i < shape::shapeInfoLength(y5.rankOf()); ++i) -// ASSERT_TRUE(y5.shapeInfo()[i] == shapeExpY5[i]); -// for(int i = 0; i < y5.lengthOf(); ++i) -// ASSERT_TRUE(y5.e(i) == buffExpY5[i]); +TEST_F(NDArrayCudaBasicsTests, subarray_1) { + NDArray x('c', {2, 3, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}, + sd::DataType::FLOAT32); + NDArray y('f', {2, 3, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}, + sd::DataType::FLOAT32); + Nd4jLong shapeExpX0[] = {1, 2, 12, 8192, 1, 99}; + float buffExpX0[] = {1.f, 13.f}; + Nd4jLong shapeExpX1[] = {1, 2, 12, 8192, 1, 99}; + float buffExpX1[] = {2.f, 14.f}; + Nd4jLong shapeExpX2[] = {3, 2, 1, 1, 12, 4, 1, 8192, 1, 99}; + float buffExpX2[] = {1.f, 13.f}; + Nd4jLong shapeExpX3[] = {2, 2, 4, 12, 1, 8192, 1, 99}; + float buffExpX3[] = {9.f, 10.f, 11.f, 12.f, 21.f, 22.f, 23.f, 24.f}; + Nd4jLong shapeExpX4[] = {3, 2, 1, 4, 12, 4, 1, 8192, 1, 99}; + float buffExpX4[] = {9.f, 10.f, 11.f, 12.f, 21.f, 22.f, 23.f, 24.f}; + Nd4jLong shapeExpX5[] = {2, 2, 3, 12, 4, 8192, 1, 99}; + float buffExpX5[] = {4.f, 8.f, 12.f, 16.f, 20.f, 24.f}; + + Nd4jLong shapeExpY0[] = {1, 2, 1, 8192, 1, 99}; + float buffExpY0[] = {1.f, 2.f}; + Nd4jLong shapeExpY1[] = {1, 2, 1, 8192, 1, 99}; + float buffExpY1[] = {7.f, 8.f}; + Nd4jLong shapeExpY2[] = {3, 2, 1, 1, 1, 2, 6, 8192, 1, 102}; + float buffExpY2[] = {1.f, 2.f}; + Nd4jLong shapeExpY3[] = {2, 2, 4, 1, 6, 8192, 1, 99}; + float buffExpY3[] = {5.f, 11.f, 17.f, 23.f, 6.f, 12.f, 18.f, 24.f}; + Nd4jLong shapeExpY4[] = {3, 2, 1, 4, 1, 2, 6, 8192, 1, 102}; + float buffExpY4[] = {5.f, 11.f, 17.f, 23.f, 6.f, 12.f, 18.f, 24.f}; + Nd4jLong shapeExpY5[] = {2, 2, 3, 1, 2, 8192, 1, 99}; + float buffExpY5[] = {19.f, 21.f, 23.f, 20.f, 22.f, 24.f}; + + NDArray x0 = x(0, {1, 2}); + NDArray xExp(buffExpX0, shapeExpX0); + + ASSERT_TRUE(xExp.isSameShape(x0)); + ASSERT_TRUE(xExp.equalsTo(x0)); + // for(int i = 0; i < shape::shapeInfoLength(x0.rankOf()); ++i) + // ASSERT_TRUE(x0.shapeInfo()[i] == shapeExpX0[i]); + // for(int i = 0; i < x0.lengthOf(); ++i) + // ASSERT_TRUE(x0.e(i) == buffExpX0[i]); + + NDArray x1 = x(1, {1, 2}); + NDArray x1Exp(buffExpX1, shapeExpX1); + ASSERT_TRUE(x1Exp.isSameShape(x1)); + ASSERT_TRUE(x1Exp.equalsTo(x1)); + + // for(int i = 0; i < shape::shapeInfoLength(x1.rankOf()); ++i) + // ASSERT_TRUE(x1.shapeInfo()[i] == shapeExpX1[i]); + // for(int i = 0; i < x1.lengthOf(); ++i) + // ASSERT_TRUE(x1.e(i) == buffExpX1[i]); + + NDArray x2 = x(0, {1, 2}, true); + NDArray x2Exp(buffExpX2, shapeExpX2); + ASSERT_TRUE(x2Exp.isSameShape(x2)); + // x2.printBuffer("X2"); + // x2Exp.printBuffer("X2 EXPECT"); + ASSERT_TRUE(x2Exp.equalsTo(x2)); + // for(int i = 0; i < shape::shapeInfoLength(x2.rankOf()); ++i) + // ASSERT_TRUE(x2.shapeInfo()[i] == shapeExpX2[i]); + // for(int i = 0; i < x2.lengthOf(); ++i) + // ASSERT_TRUE(x2.e(i) == buffExpX2[i]); + + NDArray x3 = x(2, {1}); + NDArray x3Exp(buffExpX3, shapeExpX3); + ASSERT_TRUE(x3Exp.isSameShape(x3)); + ASSERT_TRUE(x3Exp.equalsTo(x3)); + // for(int i = 0; i < shape::shapeInfoLength(x3.rankOf()); ++i) + // ASSERT_TRUE(x3.shapeInfo()[i] == shapeExpX3[i]); + // for(int i = 0; i < x3.lengthOf(); ++i) + // ASSERT_TRUE(x3.e(i) == buffExpX3[i]); + + NDArray x4 = x(2, {1}, true); + NDArray x4Exp(buffExpX4, shapeExpX4); + ASSERT_TRUE(x4Exp.isSameShape(x4)); + ASSERT_TRUE(x4Exp.equalsTo(x4)); + // for(int i = 0; i < shape::shapeInfoLength(x4.rankOf()); ++i) + // ASSERT_TRUE(x4.shapeInfo()[i] == shapeExpX4[i]); + // for(int i = 0; i < x4.lengthOf(); ++i) + // ASSERT_TRUE(x4.e(i) == buffExpX4[i]); + + NDArray x5 = x(3, {2}); + NDArray x5Exp(buffExpX5, shapeExpX5); + ASSERT_TRUE(x5Exp.isSameShape(x5)); + ASSERT_TRUE(x5Exp.equalsTo(x5)); + + // for(int i = 0; i < shape::shapeInfoLength(x5.rankOf()); ++i) + // ASSERT_TRUE(x5.shapeInfo()[i] == shapeExpX5[i]); + // for(int i = 0; i < x5.lengthOf(); ++i) + // ASSERT_TRUE(x5.e(i) == buffExpX5[i]); + + // ******************* // + NDArray y0 = y(0, {1, 2}); + NDArray y0Exp(buffExpY0, shapeExpY0); + ASSERT_TRUE(y0Exp.isSameShape(y0)); + ASSERT_TRUE(y0Exp.equalsTo(y0)); + // for(int i = 0; i < shape::shapeInfoLength(y0.rankOf()); ++i) + // ASSERT_TRUE(y0.shapeInfo()[i] == shapeExpY0[i]); + // for(int i = 0; i < y0.lengthOf(); ++i) + // ASSERT_TRUE(y0.e(i) == buffExpY0[i]); + + NDArray y1 = y(1, {1, 2}); + NDArray y1Exp(buffExpY1, shapeExpY1); + ASSERT_TRUE(y1Exp.isSameShape(y1)); + ASSERT_TRUE(y1Exp.equalsTo(y1)); + // for(int i = 0; i < shape::shapeInfoLength(y1.rankOf()); ++i) + // ASSERT_TRUE(y1.shapeInfo()[i] == shapeExpY1[i]); + // for(int i = 0; i < y1.lengthOf(); ++i) + // ASSERT_TRUE(y1.e(i) == buffExpY1[i]); + + NDArray y2 = y(0, {1, 2}, true); + NDArray y2Exp(buffExpY2, shapeExpY2); + ASSERT_TRUE(y2Exp.isSameShape(y2)); + ASSERT_TRUE(y2Exp.equalsTo(y2)); + // for(int i = 0; i < shape::shapeInfoLength(y2.rankOf()); ++i) + // ASSERT_TRUE(y2.shapeInfo()[i] == shapeExpY2[i]); + // for(int i = 0; i < y2.lengthOf(); ++i) + // ASSERT_TRUE(y2.e(i) == buffExpY2[i]); + + NDArray y3 = y(2, {1}); + NDArray y3Exp(buffExpY3, shapeExpY3); + ASSERT_TRUE(y3Exp.isSameShape(y3)); + ASSERT_TRUE(y3Exp.equalsTo(y3)); + // for(int i = 0; i < shape::shapeInfoLength(y3.rankOf()); ++i) + // ASSERT_TRUE(y3.shapeInfo()[i] == shapeExpY3[i]); + // for(int i = 0; i < y3.lengthOf(); ++i) + // ASSERT_TRUE(y3.e(i) == buffExpY3[i]); + + NDArray y4 = y(2, {1}, true); + NDArray y4Exp = NDArrayFactory::create('f', {2, 1, 4}, + {5, 6, 11, 12, 17, 18, 23, 24}); + ASSERT_TRUE(y4Exp.isSameShape(y4)); + ASSERT_TRUE(y4Exp.equalsTo(y4)); + // for(int i = 0; i < shape::shapeInfoLength(y4.rankOf()); ++i) + // ASSERT_TRUE(y4.shapeInfo()[i] == shapeExpY4[i]); + // for(int i = 0; i < y4.lengthOf(); ++i) + // ASSERT_TRUE(y4.e(i) == buffExpY4[i]); + + NDArray y5 = y(3, {2}); + NDArray y5Exp(buffExpY5, shapeExpY5); + ASSERT_TRUE(y5Exp.isSameShape(y5)); + ASSERT_TRUE(y5Exp.equalsTo(y5)); + // for(int i = 0; i < shape::shapeInfoLength(y5.rankOf()); ++i) + // ASSERT_TRUE(y5.shapeInfo()[i] == shapeExpY5[i]); + // for(int i = 0; i < y5.lengthOf(); ++i) + // ASSERT_TRUE(y5.e(i) == buffExpY5[i]); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayCudaBasicsTests, Test_diagonal_1) { - - auto x = NDArrayFactory::create('c', {2, 3}, {1, 2, 3, 4, 5, 6}); - auto exp = NDArrayFactory::create('c', {2, 1}, {1, 5}); - - auto diag = x.diagonal('c'); - //diag.syncToDevice(); - for (Nd4jLong e = 0; e < exp.lengthOf(); ++e) { - printf("VAL[%ld] = %f\n", e, diag.e(e)); //, exp.e(e), 1.e-5); - } - - for (Nd4jLong e = 0; e < exp.lengthOf(); ++e) { - ASSERT_NEAR(diag.e(e), exp.e(e), 1.e-5); - } - double eps(1.e-5); - NDArray tmp(sd::DataType::FLOAT32, x.getContext()); // scalar = 0 - - ExtraArguments extras({eps}); - NativeOpExecutioner::execReduce3Scalar(diag.getContext(), reduce3::EqualsWithEps, diag.buffer(), - diag.shapeInfo(), diag.specialBuffer(), diag.specialShapeInfo(), extras.argumentsAsT(sd::DataType::FLOAT32), - exp.buffer(), exp.shapeInfo(), exp.specialBuffer(), exp.specialShapeInfo(), - tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo()); - cudaStream_t* stream = x.getContext()->getCudaStream(); - auto res = cudaStreamSynchronize(*stream); - // tmp.printBuffer("Compare result is (expected 0)"); - ASSERT_TRUE(exp.isSameShape(diag)); - ASSERT_TRUE(exp.equalsTo(diag)); + auto x = NDArrayFactory::create('c', {2, 3}, {1, 2, 3, 4, 5, 6}); + auto exp = NDArrayFactory::create('c', {2, 1}, {1, 5}); + + auto diag = x.diagonal('c'); + // diag.syncToDevice(); + for (Nd4jLong e = 0; e < exp.lengthOf(); ++e) { + printf("VAL[%ld] = %f\n", e, + diag.e(e)); //, exp.e(e), 1.e-5); + } + + for (Nd4jLong e = 0; e < exp.lengthOf(); ++e) { + ASSERT_NEAR(diag.e(e), exp.e(e), 1.e-5); + } + double eps(1.e-5); + NDArray tmp(sd::DataType::FLOAT32, x.getContext()); // scalar = 0 + + ExtraArguments extras({eps}); + NativeOpExecutioner::execReduce3Scalar( + diag.getContext(), reduce3::EqualsWithEps, diag.buffer(), + diag.shapeInfo(), diag.specialBuffer(), diag.specialShapeInfo(), + extras.argumentsAsT(sd::DataType::FLOAT32), exp.buffer(), exp.shapeInfo(), + exp.specialBuffer(), exp.specialShapeInfo(), tmp.buffer(), + tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo()); + cudaStream_t* stream = x.getContext()->getCudaStream(); + auto res = cudaStreamSynchronize(*stream); + // tmp.printBuffer("Compare result is (expected 0)"); + ASSERT_TRUE(exp.isSameShape(diag)); + ASSERT_TRUE(exp.equalsTo(diag)); } TEST_F(NDArrayCudaBasicsTests, Test_PermuteEquality_02) { - auto x = NDArrayFactory::linspace(1.f, 60.f, 60); //('c', {1, 60}); - //x.linspace(1); - auto exp = NDArrayFactory::create('c', {3, 4, 5}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f, 25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f, 31.0f, 32.0f, 33.0f, 34.0f, 35.0f, 36.0f, 37.0f, 38.0f, 39.0f, 40.0f, 41.0f, 42.0f, 43.0f, 44.0f, 45.0f, 46.0f, 47.0f, 48.0f, 49.0f, 50.0f, 51.0f, 52.0f, 53.0f, 54.0f, 55.0f, 56.0f, 57.0f, 58.0f, 59.0f, 60.0}); - x->reshapei('c', {3, 4, 5}); + auto x = NDArrayFactory::linspace(1.f, 60.f, 60); //('c', {1, 60}); + // x.linspace(1); + auto exp = NDArrayFactory::create( + 'c', {3, 4, 5}, + {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, + 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, + 21.0f, 22.0f, 23.0f, 24.0f, 25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f, + 31.0f, 32.0f, 33.0f, 34.0f, 35.0f, 36.0f, 37.0f, 38.0f, 39.0f, 40.0f, + 41.0f, 42.0f, 43.0f, 44.0f, 45.0f, 46.0f, 47.0f, 48.0f, 49.0f, 50.0f, + 51.0f, 52.0f, 53.0f, 54.0f, 55.0f, 56.0f, 57.0f, 58.0f, 59.0f, 60.0}); + x->reshapei('c', {3, 4, 5}); - x->permutei({0, 1, 2}); - x->streamline(); + x->permutei({0, 1, 2}); + x->streamline(); -// x.printShapeInfo("{0, 1, 2} shape"); -// x.printBuffer("{0, 1, 2} data"); + // x.printShapeInfo("{0, 1, 2} shape"); + // x.printBuffer("{0, 1, 2} data"); - ASSERT_TRUE(exp.isSameShape(x)); - ASSERT_TRUE(exp.equalsTo(x)); - delete x; + ASSERT_TRUE(exp.isSameShape(x)); + ASSERT_TRUE(exp.equalsTo(x)); + delete x; } TEST_F(NDArrayCudaBasicsTests, Test_PermuteEquality_0) { - auto x = NDArrayFactory::create('c', {1, 60}); - x.linspace(1); - auto exp = NDArrayFactory::create('c', {3, 4, 5}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f, 25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f, 31.0f, 32.0f, 33.0f, 34.0f, 35.0f, 36.0f, 37.0f, 38.0f, 39.0f, 40.0f, 41.0f, 42.0f, 43.0f, 44.0f, 45.0f, 46.0f, 47.0f, 48.0f, 49.0f, 50.0f, 51.0f, 52.0f, 53.0f, 54.0f, 55.0f, 56.0f, 57.0f, 58.0f, 59.0f, 60.0}); - x.reshapei('c', {3, 4, 5}); - - x.permutei({0, 1, 2}); - x.streamline(); - -// x.printShapeInfo("{0, 1, 2} shape"); -// x.printBuffer("{0, 1, 2} data"); - - ASSERT_TRUE(exp.isSameShape(&x)); - ASSERT_TRUE(exp.equalsTo(&x)); + auto x = NDArrayFactory::create('c', {1, 60}); + x.linspace(1); + auto exp = NDArrayFactory::create( + 'c', {3, 4, 5}, + {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, + 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, + 21.0f, 22.0f, 23.0f, 24.0f, 25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f, + 31.0f, 32.0f, 33.0f, 34.0f, 35.0f, 36.0f, 37.0f, 38.0f, 39.0f, 40.0f, + 41.0f, 42.0f, 43.0f, 44.0f, 45.0f, 46.0f, 47.0f, 48.0f, 49.0f, 50.0f, + 51.0f, 52.0f, 53.0f, 54.0f, 55.0f, 56.0f, 57.0f, 58.0f, 59.0f, 60.0}); + x.reshapei('c', {3, 4, 5}); + + x.permutei({0, 1, 2}); + x.streamline(); + + // x.printShapeInfo("{0, 1, 2} shape"); + // x.printBuffer("{0, 1, 2} data"); + + ASSERT_TRUE(exp.isSameShape(&x)); + ASSERT_TRUE(exp.equalsTo(&x)); } TEST_F(NDArrayCudaBasicsTests, Test_PermuteEquality_1) { - auto x = NDArrayFactory::create('c', {1, 60}); - x.linspace(1); - auto exp = NDArrayFactory::create('c', {3, 4, 5}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f, 25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f, 31.0f, 32.0f, 33.0f, 34.0f, 35.0f, 36.0f, 37.0f, 38.0f, 39.0f, 40.0f, 41.0f, 42.0f, 43.0f, 44.0f, 45.0f, 46.0f, 47.0f, 48.0f, 49.0f, 50.0f, 51.0f, 52.0f, 53.0f, 54.0f, 55.0f, 56.0f, 57.0f, 58.0f, 59.0f, 60.0}); - x.reshapei('c', {3, 4, 5}); - - x.permutei({0, 1, 2}); - x.streamline(); - -// x.printShapeInfo("{0, 1, 2} shape"); -// x.printBuffer("{0, 1, 2} data"); - - ASSERT_TRUE(exp.isSameShape(&x)); - ASSERT_TRUE(exp.equalsTo(&x)); + auto x = NDArrayFactory::create('c', {1, 60}); + x.linspace(1); + auto exp = NDArrayFactory::create( + 'c', {3, 4, 5}, + {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, + 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, + 21.0f, 22.0f, 23.0f, 24.0f, 25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f, + 31.0f, 32.0f, 33.0f, 34.0f, 35.0f, 36.0f, 37.0f, 38.0f, 39.0f, 40.0f, + 41.0f, 42.0f, 43.0f, 44.0f, 45.0f, 46.0f, 47.0f, 48.0f, 49.0f, 50.0f, + 51.0f, 52.0f, 53.0f, 54.0f, 55.0f, 56.0f, 57.0f, 58.0f, 59.0f, 60.0}); + x.reshapei('c', {3, 4, 5}); + + x.permutei({0, 1, 2}); + x.streamline(); + + // x.printShapeInfo("{0, 1, 2} shape"); + // x.printBuffer("{0, 1, 2} data"); + + ASSERT_TRUE(exp.isSameShape(&x)); + ASSERT_TRUE(exp.equalsTo(&x)); } TEST_F(NDArrayCudaBasicsTests, Test_PermuteEquality_2) { - //auto x = NDArrayFactory::create('c', {1, 60}); - auto xx = NDArrayFactory::linspace(1.f, 60.f, 60); //('c', {1, 60}); -// auto x = *xx; - //x.linspace(1); -// auto exp = NDArrayFactory::create('c', {3, 4, 5}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f, 25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f, 31.0f, 32.0f, 33.0f, 34.0f, 35.0f, 36.0f, 37.0f, 38.0f, 39.0f, 40.0f, 41.0f, 42.0f, 43.0f, 44.0f, 45.0f, 46.0f, 47.0f, 48.0f, 49.0f, 50.0f, 51.0f, 52.0f, 53.0f, 54.0f, 55.0f, 56.0f, 57.0f, 58.0f, 59.0f, 60.0}); -// x.reshapei('c', {3, 4, 5}); - -// x.permutei({0, 1, 2}); -// x.streamline(); - -// x.printShapeInfo("{0, 1, 2} shape"); -// x.printBuffer("{0, 1, 2} data"); - -// ASSERT_TRUE(exp.isSameShape(&x)); -// ASSERT_TRUE(exp.equalsTo(&x)); - delete xx; + // auto x = NDArrayFactory::create('c', {1, 60}); + auto xx = NDArrayFactory::linspace(1.f, 60.f, 60); //('c', {1, 60}); + // auto x = *xx; + // x.linspace(1); + // auto exp = NDArrayFactory::create('c', {3, 4, 5}, + // {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, + // 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, + // 24.0f, 25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f, 31.0f, 32.0f, 33.0f, 34.0f, + // 35.0f, 36.0f, 37.0f, 38.0f, 39.0f, 40.0f, 41.0f, 42.0f, 43.0f, 44.0f, 45.0f, + // 46.0f, 47.0f, 48.0f, 49.0f, 50.0f, 51.0f, 52.0f, 53.0f, 54.0f, 55.0f, 56.0f, + // 57.0f, 58.0f, 59.0f, 60.0}); x.reshapei('c', {3, 4, 5}); + + // x.permutei({0, 1, 2}); + // x.streamline(); + + // x.printShapeInfo("{0, 1, 2} shape"); + // x.printBuffer("{0, 1, 2} data"); + + // ASSERT_TRUE(exp.isSameShape(&x)); + // ASSERT_TRUE(exp.equalsTo(&x)); + delete xx; } TEST_F(NDArrayCudaBasicsTests, Test_PermuteEquality_3) { - auto x = NDArrayFactory::create('c', {1, 60}); - //x.linspace(1); - for (int l = 0; l < x.lengthOf(); l++) - x.p(l, float(l + 1.f)); - auto exp = NDArrayFactory::create('c', {3, 4, 5}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f, 25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f, 31.0f, 32.0f, 33.0f, 34.0f, 35.0f, 36.0f, 37.0f, 38.0f, 39.0f, 40.0f, 41.0f, 42.0f, 43.0f, 44.0f, 45.0f, 46.0f, 47.0f, 48.0f, 49.0f, 50.0f, 51.0f, 52.0f, 53.0f, 54.0f, 55.0f, 56.0f, 57.0f, 58.0f, 59.0f, 60.0}); - x.reshapei('c', {3, 4, 5}); + auto x = NDArrayFactory::create('c', {1, 60}); + // x.linspace(1); + for (int l = 0; l < x.lengthOf(); l++) x.p(l, float(l + 1.f)); + auto exp = NDArrayFactory::create( + 'c', {3, 4, 5}, + {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, + 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, + 21.0f, 22.0f, 23.0f, 24.0f, 25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f, + 31.0f, 32.0f, 33.0f, 34.0f, 35.0f, 36.0f, 37.0f, 38.0f, 39.0f, 40.0f, + 41.0f, 42.0f, 43.0f, 44.0f, 45.0f, 46.0f, 47.0f, 48.0f, 49.0f, 50.0f, + 51.0f, 52.0f, 53.0f, 54.0f, 55.0f, 56.0f, 57.0f, 58.0f, 59.0f, 60.0}); + x.reshapei('c', {3, 4, 5}); - x.permutei({0, 1, 2}); - x.streamline(); + x.permutei({0, 1, 2}); + x.streamline(); -// x.printShapeInfo("{0, 1, 2} shape"); -// x.printBuffer("{0, 1, 2} data"); + // x.printShapeInfo("{0, 1, 2} shape"); + // x.printBuffer("{0, 1, 2} data"); - ASSERT_TRUE(exp.isSameShape(&x)); - ASSERT_TRUE(exp.equalsTo(&x)); + ASSERT_TRUE(exp.isSameShape(&x)); + ASSERT_TRUE(exp.equalsTo(&x)); } TEST_F(NDArrayCudaBasicsTests, Test_Empty_1) { - auto x = NDArrayFactory::empty(); - ASSERT_TRUE(x.isActualOnHostSide()); - ASSERT_TRUE(x.isEmpty()); + auto x = NDArrayFactory::empty(); + ASSERT_TRUE(x.isActualOnHostSide()); + ASSERT_TRUE(x.isEmpty()); } TEST_F(NDArrayCudaBasicsTests, Test_Empty_2) { - auto x = NDArrayFactory::empty_(); + auto x = NDArrayFactory::empty_(); - ASSERT_TRUE(x->isEmpty()); - delete x; + ASSERT_TRUE(x->isEmpty()); + delete x; } TEST_F(NDArrayCudaBasicsTests, Test_Empty_3) { - auto x = NDArrayFactory::empty(sd::DataType::FLOAT32); + auto x = NDArrayFactory::empty(sd::DataType::FLOAT32); - ASSERT_TRUE(x.isEmpty()); + ASSERT_TRUE(x.isEmpty()); } TEST_F(NDArrayCudaBasicsTests, Test_Empty_4) { - auto x = NDArrayFactory::empty_(sd::DataType::FLOAT32); + auto x = NDArrayFactory::empty_(sd::DataType::FLOAT32); - ASSERT_TRUE(x->isEmpty()); - delete x; + ASSERT_TRUE(x->isEmpty()); + delete x; } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/NDArrayListTests.cpp b/libnd4j/tests_cpu/layers_tests/NDArrayListTests.cpp index 2de3e4651377..7110561e5ef8 100644 --- a/libnd4j/tests_cpu/layers_tests/NDArrayListTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/NDArrayListTests.cpp @@ -20,54 +20,50 @@ #include #include + #include "testlayers.h" using namespace sd; class NDArrayListTests : public testing::Test { -public: - + public: }; - TEST_F(NDArrayListTests, BasicTests_1) { - NDArrayList list(false); + NDArrayList list(false); - auto x = NDArrayFactory::create('c', {1, 10}); - auto y = NDArrayFactory::create('c', {1, 10}); + auto x = NDArrayFactory::create('c', {1, 10}); + auto y = NDArrayFactory::create('c', {1, 10}); - ASSERT_EQ(ND4J_STATUS_OK, list.write(1, new NDArray(x.dup()))); + ASSERT_EQ(ND4J_STATUS_OK, list.write(1, x.dup())); - //ASSERT_EQ(ND4J_STATUS_DOUBLE_WRITE, list.write(1, &y)); + // ASSERT_EQ(ND4J_STATUS_DOUBLE_WRITE, list.write(1, &y)); } TEST_F(NDArrayListTests, BasicTests_2) { - NDArrayList list(false); + NDArrayList list(false); - auto x = NDArrayFactory::create('c', {1, 10}); - auto y = NDArrayFactory::create('c', {1, 7}); + auto x = NDArrayFactory::create('c', {1, 10}); + auto y = NDArrayFactory::create('c', {1, 7}); - ASSERT_EQ(ND4J_STATUS_OK, list.write(1, new NDArray(x.dup()))); + ASSERT_EQ(ND4J_STATUS_OK, list.write(1, x.dup())); - ASSERT_EQ(ND4J_STATUS_BAD_INPUT, list.write(0, &y)); + ASSERT_EQ(ND4J_STATUS_BAD_INPUT, list.write(0, y)); } - TEST_F(NDArrayListTests, Test_Stack_UnStack_1) { - auto input = NDArrayFactory::create('c', {10, 10}); - input.linspace(1); - - NDArrayList list(false); + auto input = NDArrayFactory::create('c', {10, 10}); + input.linspace(1); - list.unstack(&input, 0); + NDArrayList list(false); - ASSERT_EQ(10, list.elements()); + list.unstack(input, 0); - auto array = list.stack(); + ASSERT_EQ(10, list.elements()); - ASSERT_TRUE(input.isSameShape(array)); + auto array = list.stack(); - ASSERT_TRUE(input.equalsTo(array)); + ASSERT_TRUE(input.isSameShape(array)); - delete array; + ASSERT_TRUE(input.equalsTo(array)); } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/NDArrayTests.cpp b/libnd4j/tests_cpu/layers_tests/NDArrayTests.cpp index 8150976e1880..7d7a74d227c8 100644 --- a/libnd4j/tests_cpu/layers_tests/NDArrayTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/NDArrayTests.cpp @@ -18,1413 +18,1485 @@ // Created by raver119 on 04.08.17. // -#include "testlayers.h" -#include #include #include +#include + +#include "testlayers.h" + using namespace sd; ////////////////////////////////////////////////////////////////////// class NDArrayTest : public testing::Test { -public: - int alpha = 0; + public: + int alpha = 0; - Nd4jLong *cShape = new Nd4jLong[8]{2, 2, 2, 2, 1, 8192, 1, 99}; - Nd4jLong *fShape = new Nd4jLong[8]{2, 2, 2, 1, 2, 8192, 1, 102}; + Nd4jLong *cShape = new Nd4jLong[8]{2, 2, 2, 2, 1, 8192, 1, 99}; + Nd4jLong *fShape = new Nd4jLong[8]{2, 2, 2, 1, 2, 8192, 1, 102}; - float arr1[6] = {1,2,3,4,5,6}; - Nd4jLong shape1[8] = {2,2,3,3,1,8192,1,99}; - float arr2[48] = {1,2,3,1,2,3,4,5,6,4,5,6,1,2,3,1,2,3,4,5,6,4,5,6,1,2,3,1,2,3,4,5,6,4,5,6,1,2,3,1,2,3,4,5,6,4,5,6}; - Nd4jLong shape2[10] = {3,2,4,6,24,6,1,8192,1,99}; - const std::vector tileShape1 = {2,2,2}; + float arr1[6] = {1, 2, 3, 4, 5, 6}; + Nd4jLong shape1[8] = {2, 2, 3, 3, 1, 8192, 1, 99}; + float arr2[48] = {1, 2, 3, 1, 2, 3, 4, 5, 6, 4, 5, 6, 1, 2, 3, 1, + 2, 3, 4, 5, 6, 4, 5, 6, 1, 2, 3, 1, 2, 3, 4, 5, + 6, 4, 5, 6, 1, 2, 3, 1, 2, 3, 4, 5, 6, 4, 5, 6}; + Nd4jLong shape2[10] = {3, 2, 4, 6, 24, 6, 1, 8192, 1, 99}; + const std::vector tileShape1 = {2, 2, 2}; - - ~NDArrayTest() { - delete[] cShape; - delete[] fShape; - } + ~NDArrayTest() { + delete[] cShape; + delete[] fShape; + } }; - ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestDup1) { + NDArray array(arr1, shape1); - NDArray array(arr1, shape1); + auto arrC = new NDArray(array.dup('c')); + auto arrF = new NDArray(array.dup('f')); - auto arrC = new NDArray(array.dup('c')); - auto arrF = new NDArray(array.dup('f')); + ASSERT_TRUE(array.equalsTo(arrF)); + ASSERT_TRUE(array.equalsTo(arrC)); - ASSERT_TRUE(array.equalsTo(arrF)); - ASSERT_TRUE(array.equalsTo(arrC)); + ASSERT_TRUE(arrF->equalsTo(arrC)); - ASSERT_TRUE(arrF->equalsTo(arrC)); - - delete arrC; - delete arrF; + delete arrC; + delete arrF; } - ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, AssignScalar1) { - auto array = NDArrayFactory::create_('c', {1, 10}); + auto array = NDArrayFactory::create_('c', {1, 10}); - array->assign(2.0f); + array->assign(2.0f); - for (int i = 0; i < array->lengthOf(); i++) { - ASSERT_EQ(2.0f, array->e(i)); - } + for (int i = 0; i < array->lengthOf(); i++) { + ASSERT_EQ(2.0f, array->e(i)); + } - delete array; + delete array; } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, NDArrayOrder1) { - // original part - auto c = new float[4] {1, 2, 3, 4}; + // original part + auto c = new float[4]{1, 2, 3, 4}; - // expected part - auto f = new float[4] {1, 3, 2, 4}; + // expected part + auto f = new float[4]{1, 3, 2, 4}; - auto arrayC = new NDArray(c, cShape); - auto arrayF = new NDArray(arrayC->dup('f')); - auto arrayC2 = new NDArray(arrayF->dup('c')); + auto arrayC = new NDArray(c, cShape); + auto arrayF = new NDArray(arrayC->dup('f')); + auto arrayC2 = new NDArray(arrayF->dup('c')); - arrayF->syncToHost(); - arrayC2->syncToHost(); + arrayF->syncToHost(); + arrayC2->syncToHost();ASSERT_EQ('c', arrayC->ordering()); + ASSERT_EQ('f', arrayF->ordering()); + ASSERT_EQ('c', arrayC2->ordering()); - ASSERT_EQ('c', arrayC->ordering()); - ASSERT_EQ('f', arrayF->ordering()); - ASSERT_EQ('c', arrayC2->ordering()); + for (int i = 0; i < 4; i++) { + ASSERT_NEAR(f[i], arrayF->bufferAsT()[i], 1e-5f); + } - for (int i = 0; i < 4; i++) { - ASSERT_NEAR(f[i], arrayF->bufferAsT()[i], 1e-5f); - } + for (int i = 0; i < 8; i++) { + ASSERT_EQ(fShape[i], arrayF->shapeInfo()[i]); + } - for (int i = 0; i < 8; i++) { - ASSERT_EQ(fShape[i], arrayF->shapeInfo()[i]); - } + for (int i = 0; i < 4; i++) { + ASSERT_NEAR(c[i], arrayC2->bufferAsT()[i], 1e-5f); + } - for (int i = 0; i < 4; i++) { - ASSERT_NEAR(c[i], arrayC2->bufferAsT()[i], 1e-5f); - } + for (int i = 0; i < 8; i++) { + ASSERT_EQ(cShape[i], arrayC2->shapeInfo()[i]); + } - for (int i = 0; i < 8; i++) { - ASSERT_EQ(cShape[i], arrayC2->shapeInfo()[i]); - } - - - delete[] c; - delete[] f; - delete arrayC; - delete arrayF; - delete arrayC2; + delete[] c; + delete[] f; + delete arrayC; + delete arrayF; + delete arrayC2; } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestGetScalar1) { - auto c = new float[4] {1, 2, 3, 4}; - auto cShape = new Nd4jLong[8]{2, 2, 2, 2, 1, 8192, 1, 99}; - - auto arrayC = new NDArray(c, cShape); + auto c = new float[4]{1, 2, 3, 4}; + auto cShape = new Nd4jLong[8]{2, 2, 2, 2, 1, 8192, 1, 99}; - ASSERT_NEAR(3.0f, arrayC->e(1, 0), 1e-5f); - ASSERT_NEAR(4.0f, arrayC->e(1, 1), 1e-5f); + auto arrayC = new NDArray(c, cShape); - auto arrayF = new NDArray(arrayC->dup('f')); + ASSERT_NEAR(3.0f, arrayC->e(1, 0), 1e-5f); + ASSERT_NEAR(4.0f, arrayC->e(1, 1), 1e-5f); - ASSERT_NEAR(3.0f, arrayF->e(1, 0), 1e-5f); - ASSERT_NEAR(4.0f, arrayF->e(1, 1), 1e-5f); + auto arrayF = new NDArray(arrayC->dup('f')); + ASSERT_NEAR(3.0f, arrayF->e(1, 0), 1e-5f); + ASSERT_NEAR(4.0f, arrayF->e(1, 1), 1e-5f); - arrayF->p(1, 0, 7.0f); - ASSERT_NEAR(7.0f, arrayF->e(1, 0), 1e-5f); + arrayF->p(1, 0, 7.0f); + ASSERT_NEAR(7.0f, arrayF->e(1, 0), 1e-5f); + arrayC->p(1, 1, 9.0f); + ASSERT_NEAR(9.0f, arrayC->e(1, 1), 1e-5f); - arrayC->p(1, 1, 9.0f); - ASSERT_NEAR(9.0f, arrayC->e(1, 1), 1e-5f); + delete[] c; + delete[] cShape; - delete[] c; - delete[] cShape; - - delete arrayC; - delete arrayF; + delete arrayC; + delete arrayF; } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, EqualityTest1) { - auto arrayA = NDArrayFactory::create_('f', {3, 5}); - auto arrayB = NDArrayFactory::create_('f', {3, 5}); - auto arrayC = NDArrayFactory::create_('f', {3, 5}); + auto arrayA = NDArrayFactory::create_('f', {3, 5}); + auto arrayB = NDArrayFactory::create_('f', {3, 5}); + auto arrayC = NDArrayFactory::create_('f', {3, 5}); - auto arrayD = NDArrayFactory::create_('f', {2, 4}); - auto arrayE = NDArrayFactory::create_('f', {1, 15}); + auto arrayD = NDArrayFactory::create_('f', {2, 4}); + auto arrayE = NDArrayFactory::create_('f', {1, 15}); - for (int i = 0; i < arrayA->rows(); i++) { - for (int k = 0; k < arrayA->columns(); k++) { - arrayA->p(i, k, (float) i); - } + for (int i = 0; i < arrayA->rows(); i++) { + for (int k = 0; k < arrayA->columns(); k++) { + arrayA->p(i, k, (float)i); } + } - for (int i = 0; i < arrayB->rows(); i++) { - for (int k = 0; k < arrayB->columns(); k++) { - arrayB->p(i, k, (float) i); - } + for (int i = 0; i < arrayB->rows(); i++) { + for (int k = 0; k < arrayB->columns(); k++) { + arrayB->p(i, k, (float)i); } + } - for (int i = 0; i < arrayC->rows(); i++) { - for (int k = 0; k < arrayC->columns(); k++) { - arrayC->p(i, k, (float) i+1); - } + for (int i = 0; i < arrayC->rows(); i++) { + for (int k = 0; k < arrayC->columns(); k++) { + arrayC->p(i, k, (float)i + 1); } + } - //nd4j_printf("A B\n",""); - ASSERT_TRUE(arrayA->equalsTo(arrayB, 1e-5)); + // nd4j_printf("A B\n",""); + ASSERT_TRUE(arrayA->equalsTo(arrayB, 1e-5)); - //nd4j_printf("C B\n",""); - ASSERT_FALSE(arrayC->equalsTo(arrayB, 1e-5)); + // nd4j_printf("C B\n",""); + ASSERT_FALSE(arrayC->equalsTo(arrayB, 1e-5)); - //nd4j_printf("D B\n",""); - ASSERT_FALSE(arrayD->equalsTo(arrayB, 1e-5)); + // nd4j_printf("D B\n",""); + ASSERT_FALSE(arrayD->equalsTo(arrayB, 1e-5)); - //nd4j_printf("E B\n",""); - ASSERT_FALSE(arrayE->equalsTo(arrayB, 1e-5)); + // nd4j_printf("E B\n",""); + ASSERT_FALSE(arrayE->equalsTo(arrayB, 1e-5)); - delete arrayA; - delete arrayB; - delete arrayC; - delete arrayD; - delete arrayE; + delete arrayA; + delete arrayB; + delete arrayC; + delete arrayD; + delete arrayE; } TEST_F(NDArrayTest, TestTad1) { - auto array = NDArrayFactory::create_('c', {3, 3}); + auto array = NDArrayFactory::create_('c', {3, 3}); - auto row2 = (*array)(1, {0}); + auto row2 = (*array)(1, {0}); - ASSERT_TRUE(row2.isView()); - ASSERT_EQ(3, row2.lengthOf()); + ASSERT_TRUE(row2.isView()); + ASSERT_EQ(3, row2.lengthOf()); - row2.assign(1.0); + row2.assign(1.0); - ASSERT_NEAR(3.0f, array->sumNumber().e(0), 1e-5); - delete array; + ASSERT_NEAR(3.0f, array->sumNumber().e(0), 1e-5); + delete array; } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestTad2) { - auto array = NDArrayFactory::create_('c', {3, 3}); + auto array = NDArrayFactory::create_('c', {3, 3}); - ASSERT_EQ(3, array->tensorsAlongDimension({1})); + ASSERT_EQ(3, array->tensorsAlongDimension({1})); - delete array; + delete array; } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestTad3) { - auto array = NDArrayFactory::create_('c', {4, 3}); + auto array = NDArrayFactory::create_('c', {4, 3}); - auto row2 = (*array)(1, {0}); + auto row2 = (*array)(1, {0}); - ASSERT_TRUE(row2.isView()); - ASSERT_EQ(3, row2.lengthOf()); - delete array; + ASSERT_TRUE(row2.isView()); + ASSERT_EQ(3, row2.lengthOf()); + delete array; } - TEST_F(NDArrayTest, TestPermuteReshape1) { + NDArray array('c', {2, 2, 5, 5}, sd::DataType::FLOAT32); + int pShape[] = {4, 2, 5, 5, 2, 25, 5, 1, 50, 8192, 0, 99}; + int rShape[] = {3, 2, 25, 2, 25, 1, 50, 8192, 0, 99}; - NDArray array('c', {2, 2, 5, 5}, sd::DataType::FLOAT32); - int pShape[] = {4, 2, 5, 5, 2, 25, 5, 1, 50, 8192, 0, 99}; - int rShape[] = {3, 2, 25, 2, 25, 1, 50, 8192, 0, 99}; - - array.permutei({1, 2, 3, 0}); + array.permutei({1, 2, 3, 0}); - for (int e = 0; e < shape::shapeInfoLength(array.shapeInfo()); e++) - ASSERT_EQ(pShape[e], array.shapeInfo()[e]); + for (int e = 0; e < shape::shapeInfoLength(array.shapeInfo()); e++) + ASSERT_EQ(pShape[e], array.shapeInfo()[e]); - array.reshapei('c', {2, 25, 2}); + array.reshapei('c', {2, 25, 2}); - for (int e = 0; e < shape::shapeInfoLength(array.shapeInfo()); e++) - ASSERT_EQ(rShape[e], array.shapeInfo()[e]); + for (int e = 0; e < shape::shapeInfoLength(array.shapeInfo()); e++) + ASSERT_EQ(rShape[e], array.shapeInfo()[e]); } - TEST_F(NDArrayTest, TestPermuteReshape2) { - auto array = NDArrayFactory::create('c', {2, 2, 5, 5, 6, 6}); - int pShape[] = {6, 2, 2, 6, 6, 5, 5, 900, 1800, 6, 1, 180, 36, 8192, 0, 99}; - int rShape[] = {3, 2, 72, 25, 1800, 25, 1, 8192, 1, 99}; - + auto array = NDArrayFactory::create('c', {2, 2, 5, 5, 6, 6}); + int pShape[] = {6, 2, 2, 6, 6, 5, 5, 900, 1800, 6, 1, 180, 36, 8192, 0, 99}; + int rShape[] = {3, 2, 72, 25, 1800, 25, 1, 8192, 1, 99}; - // array.printShapeInfo("before"); + // array.printShapeInfo("before"); - array.permutei({1, 0, 4, 5, 2, 3}); + array.permutei({1, 0, 4, 5, 2, 3}); - // array.printShapeInfo("after "); + // array.printShapeInfo("after "); - auto aShape = array.shapeInfo(); + auto aShape = array.shapeInfo(); - for (int e = 0; e < shape::shapeInfoLength(array.shapeInfo()); e++) - ASSERT_EQ(pShape[e], aShape[e]); + for (int e = 0; e < shape::shapeInfoLength(array.shapeInfo()); e++) + ASSERT_EQ(pShape[e], aShape[e]); - array.reshapei('c', {2, 72, 25}); + array.reshapei('c', {2, 72, 25}); - for (int e = 0; e < shape::shapeInfoLength(array.shapeInfo()); e++) - ASSERT_EQ(rShape[e], array.shapeInfo()[e]); + for (int e = 0; e < shape::shapeInfoLength(array.shapeInfo()); e++) + ASSERT_EQ(rShape[e], array.shapeInfo()[e]); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestRepeat1) { + auto eBuffer = new float[8]{1.0, 2.0, 1.0, 2.0, 3.0, 4.0, 3.0, 4.0}; + auto eShape = new Nd4jLong[8]{2, 4, 2, 2, 1, 8192, 1, 99}; + NDArray array('c', {2, 2}, sd::DataType::FLOAT32); + auto exp = new NDArray(eBuffer, eShape); + for (int e = 0; e < array.lengthOf(); e++) array.p(e, e + 1); - auto eBuffer = new float[8] {1.0,2.0,1.0,2.0,3.0,4.0,3.0,4.0}; - auto eShape = new Nd4jLong[8]{2, 4, 2, 2, 1, 8192, 1, 99}; - NDArray array('c', {2, 2}, sd::DataType::FLOAT32); - auto exp = new NDArray(eBuffer, eShape); - for (int e = 0; e < array.lengthOf(); e++) - array.p(e, e + 1); + // array.printBuffer(); - // array.printBuffer(); + auto rep = array.repeat(0, {2}); - auto rep = array.repeat(0, {2}); + ASSERT_EQ(4, rep.sizeAt(0)); + ASSERT_EQ(2, rep.sizeAt(1)); - ASSERT_EQ(4, rep.sizeAt(0)); - ASSERT_EQ(2, rep.sizeAt(1)); + ASSERT_TRUE(exp->equalsTo(rep)); - ASSERT_TRUE(exp->equalsTo(rep)); - - delete[] eBuffer; - delete[] eShape; - delete exp; + delete[] eBuffer; + delete[] eShape; + delete exp; } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestRepeat2) { - auto eBuffer = new float[8] {1.0,2.0,1.0,2.0,3.0,4.0,3.0,4.0}; - auto eShape = new Nd4jLong[8]{2, 4, 2, 2, 1, 8192, 1, 99}; - auto array = NDArrayFactory::create_('c', {2, 2}); - auto exp = new NDArray(eBuffer, eShape); - for (int e = 0; e < array->lengthOf(); e++) - array->p(e, e + 1); + auto eBuffer = new float[8]{1.0, 2.0, 1.0, 2.0, 3.0, 4.0, 3.0, 4.0}; + auto eShape = new Nd4jLong[8]{2, 4, 2, 2, 1, 8192, 1, 99}; + auto array = NDArrayFactory::create_('c', {2, 2}); + auto exp = new NDArray(eBuffer, eShape); + for (int e = 0; e < array->lengthOf(); e++) array->p(e, e + 1); - //array->printBuffer(); + // array->printBuffer(); - auto rep = new NDArray(exp->dup()); - rep->assign(0.); - array->repeat(0, {2}, *rep); - //rep->printIndexedBuffer("Repeated"); + auto rep = new NDArray(exp->dup()); + rep->assign(0.); + array->repeat(0, {2}, *rep); + // rep->printIndexedBuffer("Repeated"); - ASSERT_EQ(4, rep->sizeAt(0)); - ASSERT_EQ(2, rep->sizeAt(1)); + ASSERT_EQ(4, rep->sizeAt(0)); + ASSERT_EQ(2, rep->sizeAt(1)); - //rep->printBuffer(); + // rep->printBuffer(); - ASSERT_TRUE(exp->equalsTo(rep)); + ASSERT_TRUE(exp->equalsTo(rep)); - delete[] eBuffer; - delete[] eShape; - delete array; - delete exp; - delete rep; + delete[] eBuffer; + delete[] eShape; + delete array; + delete exp; + delete rep; } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestIndexedPut1) { - auto array = NDArrayFactory::create_('f', {3, 3}); + auto array = NDArrayFactory::create_('f', {3, 3}); - array->p(4, 1.0f); - ASSERT_EQ(1.0f, array->e(4)); - //array->printBuffer(); + array->p(4, 1.0f); + ASSERT_EQ(1.0f, array->e(4)); + // array->printBuffer(); - delete array; + delete array; } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestSum1) { - // Nd4jLong *cShape = new Nd4jLong[8]{2, 2, 2, 2, 1, 8192, 1, 99}; - float *c = new float[4] {1, 2, 3, 4}; + // Nd4jLong *cShape = new Nd4jLong[8]{2, 2, 2, 2, 1, 8192, 1, 99}; + float *c = new float[4]{1, 2, 3, 4}; - auto array = new NDArray(c, cShape); + auto array = new NDArray(c, cShape); - ASSERT_EQ(10.0f, array->sumNumber().e(0)); - ASSERT_EQ(2.5f, array->meanNumber().e(0)); + ASSERT_EQ(10.0f, array->sumNumber().e(0)); + ASSERT_EQ(2.5f, array->meanNumber().e(0)); - delete[] c; - delete array; + delete[] c; + delete array; } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestAddiRowVector) { - float *c = new float[4] {1, 2, 3, 4}; - float *e = new float[4] {2, 3, 4, 5}; + float *c = new float[4]{1, 2, 3, 4}; + float *e = new float[4]{2, 3, 4, 5}; - auto array = new NDArray(c, cShape); - auto row = NDArrayFactory::create_('c', {1, 2}); - auto exp = new NDArray(e, cShape); - row->assign(1.0f); + auto array = new NDArray(c, cShape); + auto row = NDArrayFactory::create_('c', {1, 2}); + auto exp = new NDArray(e, cShape); + row->assign(1.0f); - array->addiRowVector(*row); + array->addiRowVector(*row); - ASSERT_TRUE(exp->equalsTo(array)); + ASSERT_TRUE(exp->equalsTo(array)); - delete[] c; - delete[] e; + delete[] c; + delete[] e; - delete array; - delete row; - delete exp; + delete array; + delete row; + delete exp; } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestAddiColumnVector) { - float arr1[] = {1, 2, 3, 4}; - float arr2[] = {5, 6}; - float arr3[] = {6, 7, 9, 10}; - Nd4jLong shape1[] = {2,2,2,2,1,8192,1,99}; - Nd4jLong shape2[] = {2,2,1,1,1,8192,1,99}; - NDArray matrix(arr1, shape1); - NDArray column(arr2, shape2); - NDArray exp(arr3, shape1); + float arr1[] = {1, 2, 3, 4}; + float arr2[] = {5, 6}; + float arr3[] = {6, 7, 9, 10}; + Nd4jLong shape1[] = {2, 2, 2, 2, 1, 8192, 1, 99}; + Nd4jLong shape2[] = {2, 2, 1, 1, 1, 8192, 1, 99}; + NDArray matrix(arr1, shape1); + NDArray column(arr2, shape2); + NDArray exp(arr3, shape1); - matrix.addiColumnVector(column); - ASSERT_TRUE(exp.isSameShapeStrict(matrix)); - ASSERT_TRUE(exp.equalsTo(&matrix)); + matrix.addiColumnVector(column); + ASSERT_TRUE(exp.isSameShapeStrict(matrix)); + ASSERT_TRUE(exp.equalsTo(&matrix)); } - ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestMuliColumnVector) { - float arr1[] = {1, 2, 3, 4}; - float arr2[] = {5, 6}; - float arr3[] = {5, 10, 18, 24}; - Nd4jLong shape1[] = {2,2,2,2,1,8192,1,99}; - Nd4jLong shape2[] = {2,2,1,1,1,8192,1,99}; - NDArray matrix(arr1, shape1); - NDArray column(arr2, shape2); - NDArray exp(arr3, shape1); + float arr1[] = {1, 2, 3, 4}; + float arr2[] = {5, 6}; + float arr3[] = {5, 10, 18, 24}; + Nd4jLong shape1[] = {2, 2, 2, 2, 1, 8192, 1, 99}; + Nd4jLong shape2[] = {2, 2, 1, 1, 1, 8192, 1, 99}; + NDArray matrix(arr1, shape1); + NDArray column(arr2, shape2); + NDArray exp(arr3, shape1); - matrix.muliColumnVector(column); + matrix.muliColumnVector(column); - ASSERT_TRUE(exp.isSameShapeStrict(matrix)); - ASSERT_TRUE(exp.equalsTo(&matrix)); + ASSERT_TRUE(exp.isSameShapeStrict(matrix)); + ASSERT_TRUE(exp.equalsTo(&matrix)); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, Test3D_1) { - auto arrayC = NDArrayFactory::create_('c', {2, 5, 10}); - auto arrayF = NDArrayFactory::create_('f', {2, 5, 10}); + auto arrayC = NDArrayFactory::create_('c', {2, 5, 10}); + auto arrayF = NDArrayFactory::create_('f', {2, 5, 10}); - ASSERT_EQ(100, arrayC->lengthOf()); - ASSERT_EQ(100, arrayF->lengthOf()); + ASSERT_EQ(100, arrayC->lengthOf()); + ASSERT_EQ(100, arrayF->lengthOf()); - ASSERT_EQ('c', arrayC->ordering()); - ASSERT_EQ('f', arrayF->ordering()); + ASSERT_EQ('c', arrayC->ordering()); + ASSERT_EQ('f', arrayF->ordering()); - delete arrayC; - delete arrayF; + delete arrayC; + delete arrayF; } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestTranspose1) { - auto arrayC = NDArrayFactory::create_('c', {2, 5, 10}); + auto arrayC = NDArrayFactory::create_('c', {2, 5, 10}); - auto expC = new Nd4jLong[10] {3, 2, 5, 10, 50, 10, 1, 16384, 1, 99}; - auto expT = new Nd4jLong[10] {3, 10, 5, 2, 1, 10, 50, 16384, 1, 102}; + auto expC = new Nd4jLong[10]{3, 2, 5, 10, 50, 10, 1, 16384, 1, 99}; + auto expT = new Nd4jLong[10]{3, 10, 5, 2, 1, 10, 50, 16384, 1, 102}; - auto arrayT = arrayC->transpose(); + auto arrayT = arrayC->transpose(); - for (int e = 0; e < arrayC->rankOf(); e++) { - ASSERT_EQ(shape::shapeOf(expC)[e], arrayC->sizeAt(e)); - ASSERT_EQ(shape::shapeOf(expT)[e], arrayT.sizeAt(e)); - } + for (int e = 0; e < arrayC->rankOf(); e++) { + ASSERT_EQ(shape::shapeOf(expC)[e], arrayC->sizeAt(e)); + ASSERT_EQ(shape::shapeOf(expT)[e], arrayT.sizeAt(e)); + } - delete arrayC; - delete[] expC; - delete[] expT; + delete arrayC; + delete[] expC; + delete[] expT; } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestTranspose2) { - auto arrayC = NDArrayFactory::create_('c', {2, 5, 10}); - - auto expC = new Nd4jLong[10] {3, 2, 5, 10, 50, 10, 1, 16384, 1, 99}; - auto expT = new Nd4jLong[10] {3, 10, 5, 2, 1, 10, 50, 16384, 1, 102}; + auto arrayC = NDArrayFactory::create_('c', {2, 5, 10}); - arrayC->transposei(); + auto expC = new Nd4jLong[10]{3, 2, 5, 10, 50, 10, 1, 16384, 1, 99}; + auto expT = new Nd4jLong[10]{3, 10, 5, 2, 1, 10, 50, 16384, 1, 102}; + arrayC->transposei(); - for (int e = 0; e < arrayC->rankOf(); e++) { - ASSERT_EQ(shape::shapeOf(expT)[e], arrayC->sizeAt(e)); - } + for (int e = 0; e < arrayC->rankOf(); e++) { + ASSERT_EQ(shape::shapeOf(expT)[e], arrayC->sizeAt(e)); + } - delete arrayC; - delete[] expC; - delete[] expT; + delete arrayC; + delete[] expC; + delete[] expT; } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestSumAlongDimension1) { + NDArray array('c', {2, 2}, {1, 2, 3, 4}, sd::DataType::FLOAT32); - NDArray array('c', {2,2}, {1,2,3,4}, sd::DataType::FLOAT32); + auto res = array.reduceAlongDimension(reduce::Sum, {0}); - auto res = array.reduceAlongDimension(reduce::Sum, {0}); + ASSERT_EQ(2, res.lengthOf()); - ASSERT_EQ(2, res.lengthOf()); - - ASSERT_EQ(4.0f, res.e(0)); - ASSERT_EQ(6.0f, res.e(1)); + ASSERT_EQ(4.0f, res.e(0)); + ASSERT_EQ(6.0f, res.e(1)); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestSumAlongDimension2) { - float *c = new float[4] {1, 2, 3, 4}; - auto array = new NDArray(c, cShape); + float *c = new float[4]{1, 2, 3, 4}; + auto array = new NDArray(c, cShape); - auto res = array->reduceAlongDimension(reduce::Sum, {1}); + auto res = array->reduceAlongDimension(reduce::Sum, {1}); - ASSERT_EQ(2, res.lengthOf()); + ASSERT_EQ(2, res.lengthOf()); - ASSERT_EQ(3.0f, res.e(0)); - ASSERT_EQ(7.0f, res.e(1)); + ASSERT_EQ(3.0f, res.e(0)); + ASSERT_EQ(7.0f, res.e(1)); - delete[] c; - delete array; + delete[] c; + delete array; } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestReduceAlongDimension1) { - float *c = new float[4] {1, 2, 3, 4}; - auto array = new NDArray(c, cShape); + float *c = new float[4]{1, 2, 3, 4}; + auto array = new NDArray(c, cShape); - auto res = array->reduceAlongDimension(reduce::Sum, {1}); + auto res = array->reduceAlongDimension(reduce::Sum, {1}); - ASSERT_EQ(2, res.lengthOf()); + ASSERT_EQ(2, res.lengthOf()); - ASSERT_EQ(3.0f, res.e(0)); - ASSERT_EQ(7.0f, res.e(1)); + ASSERT_EQ(3.0f, res.e(0)); + ASSERT_EQ(7.0f, res.e(1)); - delete[] c; - delete array; + delete[] c; + delete array; } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestTransform1) { - float *c = new float[4] {-1, -2, -3, -4}; - auto array = new NDArray(c, cShape); + float *c = new float[4]{-1, -2, -3, -4}; + auto array = new NDArray(c, cShape); - float *e = new float[4] {1, 2, 3, 4}; - auto exp = new NDArray(e, cShape); + float *e = new float[4]{1, 2, 3, 4}; + auto exp = new NDArray(e, cShape); - array->applyTransform(transform::Abs, *array); + array->applyTransform(transform::Abs, *array); - ASSERT_TRUE(exp->equalsTo(array)); + ASSERT_TRUE(exp->equalsTo(array)); - delete[] c; - delete array; - delete[] e; - delete exp; + delete[] c; + delete array; + delete[] e; + delete exp; } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestReduceScalar1) { - float *c = new float[4] {-1, -2, -3, -4}; - auto array = new NDArray(c, cShape); + float *c = new float[4]{-1, -2, -3, -4}; + auto array = new NDArray(c, cShape); - ASSERT_EQ(-4, array->reduceNumber(reduce::Min, nullptr).e(0)); + ASSERT_EQ(-4, array->reduceNumber(reduce::Min, nullptr).e(0)); - delete[] c; - delete array; + delete[] c; + delete array; } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestReduceScalar2) { - float *c = new float[4] {-1, -2, -3, -4}; - auto array = new NDArray(c, cShape); + float *c = new float[4]{-1, -2, -3, -4}; + auto array = new NDArray(c, cShape); - ASSERT_EQ(-10, array->reduceNumber(reduce::Sum, nullptr).e(0)); + ASSERT_EQ(-10, array->reduceNumber(reduce::Sum, nullptr).e(0)); - delete[] c; - delete array; + delete[] c; + delete array; } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestReduceScalar3) { - auto array = new NDArray(arr1, shape1); + auto array = new NDArray(arr1, shape1); - ASSERT_EQ(21, array->reduceNumber(reduce::Sum, nullptr).e(0)); + ASSERT_EQ(21, array->reduceNumber(reduce::Sum, nullptr).e(0)); - delete array; + delete array; } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestApplyTransform1) { - float *c = new float[4] {-1, -2, -3, -4}; - auto array = new NDArray(c, cShape); - - float *e = new float[4] {1, 2, 3, 4}; - auto exp = new NDArray(e, cShape); + float *c = new float[4]{-1, -2, -3, -4}; + auto array = new NDArray(c, cShape); - array->applyTransform(transform::Abs, *array); + float *e = new float[4]{1, 2, 3, 4}; + auto exp = new NDArray(e, cShape); + array->applyTransform(transform::Abs, *array); - ASSERT_TRUE(exp->equalsTo(array)); + ASSERT_TRUE(exp->equalsTo(array)); - delete[] c; - delete array; + delete[] c; + delete array; - delete[] e; - delete exp; + delete[] e; + delete exp; } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestVectors1) { - float *c = new float[4]{-1, -2, -3, -4}; - auto array = new NDArray(c, cShape); + float *c = new float[4]{-1, -2, -3, -4}; + auto array = new NDArray(c, cShape); + auto vecShape = array->getShapeInfoAsVector(); + auto vecBuffer = array->getBufferAsVector(); - auto vecShape = array->getShapeInfoAsVector(); - auto vecBuffer = array->getBufferAsVector(); + ASSERT_EQ(8, vecShape.size()); + ASSERT_EQ(4, vecBuffer.size()); - ASSERT_EQ(8, vecShape.size()); - ASSERT_EQ(4, vecBuffer.size()); - - for (int e = 0; e < vecBuffer.size(); e++) { - ASSERT_NEAR(c[e], vecBuffer[e], 1e-5); - } + for (int e = 0; e < vecBuffer.size(); e++) { + ASSERT_NEAR(c[e], vecBuffer[e], 1e-5); + } - for (int e = 0; e < vecShape.size(); e++) { - ASSERT_EQ(cShape[e], vecShape[e]); - } + for (int e = 0; e < vecShape.size(); e++) { + ASSERT_EQ(cShape[e], vecShape[e]); + } - delete[] c; - delete array; + delete[] c; + delete array; } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestChecks1) { - auto array = NDArrayFactory::create('c', {1, 5}); + auto array = NDArrayFactory::create('c', {1, 5}); - ASSERT_FALSE(array.isMatrix()); - ASSERT_FALSE(array.isScalar()); - ASSERT_TRUE(array.isVector()); - ASSERT_FALSE(array.isColumnVector()); - ASSERT_TRUE(array.isRowVector()); + ASSERT_FALSE(array.isMatrix()); + ASSERT_FALSE(array.isScalar()); + ASSERT_TRUE(array.isVector()); + ASSERT_FALSE(array.isColumnVector()); + ASSERT_TRUE(array.isRowVector()); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestChecks2) { - auto array = NDArrayFactory::create('c', {5, 5}); + auto array = NDArrayFactory::create('c', {5, 5}); - ASSERT_TRUE(array.isMatrix()); - ASSERT_FALSE(array.isScalar()); - ASSERT_FALSE(array.isVector()); - ASSERT_FALSE(array.isColumnVector()); - ASSERT_FALSE(array.isRowVector()); + ASSERT_TRUE(array.isMatrix()); + ASSERT_FALSE(array.isScalar()); + ASSERT_FALSE(array.isVector()); + ASSERT_FALSE(array.isColumnVector()); + ASSERT_FALSE(array.isRowVector()); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestChecks3) { - auto array = NDArrayFactory::create('c', {5, 1}); + auto array = NDArrayFactory::create('c', {5, 1}); - ASSERT_FALSE(array.isMatrix()); - ASSERT_FALSE(array.isScalar()); - ASSERT_TRUE(array.isVector()); - ASSERT_TRUE(array.isColumnVector()); - ASSERT_FALSE(array.isRowVector()); + ASSERT_FALSE(array.isMatrix()); + ASSERT_FALSE(array.isScalar()); + ASSERT_TRUE(array.isVector()); + ASSERT_TRUE(array.isColumnVector()); + ASSERT_FALSE(array.isRowVector()); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestChecks4) { - auto array = NDArrayFactory::create('c', {1, 1}); + auto array = NDArrayFactory::create('c', {1, 1}); - ASSERT_FALSE(array.isMatrix()); - ASSERT_FALSE(array.isVector()); - ASSERT_FALSE(array.isColumnVector()); - ASSERT_FALSE(array.isRowVector()); - ASSERT_TRUE(array.isScalar()); + ASSERT_FALSE(array.isMatrix()); + ASSERT_FALSE(array.isVector()); + ASSERT_FALSE(array.isColumnVector()); + ASSERT_FALSE(array.isRowVector()); + ASSERT_TRUE(array.isScalar()); } TEST_F(NDArrayTest, TestReductionAny1) { - auto array = NDArrayFactory::create('c', {2, 2}); - array.p(0, 1.0f); - array.p(1, 1.0f); - array.p(2, 0.0f); - array.p(3, 0.0f); - array.syncToDevice(); - auto result0 = array.reduceAlongDimension(reduce::Any, {0}); + auto array = NDArrayFactory::create('c', {2, 2}); + array.p(0, 1.0f); + array.p(1, 1.0f); + array.p(2, 0.0f); + array.p(3, 0.0f); + array.syncToDevice(); + auto result0 = array.reduceAlongDimension(reduce::Any, {0}); - ASSERT_EQ(2, result0.lengthOf()); + ASSERT_EQ(2, result0.lengthOf()); - ASSERT_NEAR(1.0f, result0.e(0), 1e-5f); - ASSERT_NEAR(1.0f, result0.e(1), 1e-5f); + ASSERT_NEAR(1.0f, result0.e(0), 1e-5f); + ASSERT_NEAR(1.0f, result0.e(1), 1e-5f); - auto result1 = array.reduceAlongDimension(reduce::Any, {1}); + auto result1 = array.reduceAlongDimension(reduce::Any, {1}); - ASSERT_EQ(2, result1.lengthOf()); + ASSERT_EQ(2, result1.lengthOf()); - ASSERT_NEAR(1.0f, result1.e(0), 1e-5f); - ASSERT_NEAR(0.0f, result1.e(1), 1e-5f); + ASSERT_NEAR(1.0f, result1.e(0), 1e-5f); + ASSERT_NEAR(0.0f, result1.e(1), 1e-5f); } TEST_F(NDArrayTest, TestReductionAll1) { - auto array = NDArrayFactory::create('c', {2, 2}); - array.p(0, 1.0f); - array.p(1, 1.0f); - array.p(2, 0.0f); - array.p(3, 0.0f); + auto array = NDArrayFactory::create('c', {2, 2}); + array.p(0, 1.0f); + array.p(1, 1.0f); + array.p(2, 0.0f); + array.p(3, 0.0f); - auto result0 = array.reduceAlongDimension(reduce::All, {0}); - auto result1 = array.reduceAlongDimension(reduce::All, {1}); + auto result0 = array.reduceAlongDimension(reduce::All, {0}); + auto result1 = array.reduceAlongDimension(reduce::All, {1}); - ASSERT_EQ(2, result0.lengthOf()); - ASSERT_EQ(2, result1.lengthOf()); + ASSERT_EQ(2, result0.lengthOf()); + ASSERT_EQ(2, result1.lengthOf()); - ASSERT_FALSE(result0.e(0)); - ASSERT_FALSE(result0.e(1)); + ASSERT_FALSE(result0.e(0)); + ASSERT_FALSE(result0.e(1)); - ASSERT_TRUE(result1.e(0)); - ASSERT_FALSE(result1.e(1)); + ASSERT_TRUE(result1.e(0)); + ASSERT_FALSE(result1.e(1)); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestChecks5) { - auto array = NDArrayFactory::create('c', {5, 5, 5}); + auto array = NDArrayFactory::create('c', {5, 5, 5}); - ASSERT_FALSE(array.isMatrix()); - ASSERT_FALSE(array.isVector()); - ASSERT_FALSE(array.isColumnVector()); - ASSERT_FALSE(array.isRowVector()); - ASSERT_FALSE(array.isScalar()); + ASSERT_FALSE(array.isMatrix()); + ASSERT_FALSE(array.isVector()); + ASSERT_FALSE(array.isColumnVector()); + ASSERT_FALSE(array.isRowVector()); + ASSERT_FALSE(array.isScalar()); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestTile1) { + // float arr1[6] = {1,2,3,4,5,6}; + // Nd4jLong shape1[8] = {2,2,3,3,1,8192,1,99}; + // float arr2[48] = + // {1,2,3,1,2,3,4,5,6,4,5,6,1,2,3,1,2,3,4,5,6,4,5,6,1,2,3,1,2,3,4,5,6,4,5,6,1,2,3,1,2,3,4,5,6,4,5,6}; + // Nd4jLong shape2[10] = {3,2,4,6,24,6,1,8192,1,99}; - // float arr1[6] = {1,2,3,4,5,6}; - // Nd4jLong shape1[8] = {2,2,3,3,1,8192,1,99}; - // float arr2[48] = {1,2,3,1,2,3,4,5,6,4,5,6,1,2,3,1,2,3,4,5,6,4,5,6,1,2,3,1,2,3,4,5,6,4,5,6,1,2,3,1,2,3,4,5,6,4,5,6}; - // Nd4jLong shape2[10] = {3,2,4,6,24,6,1,8192,1,99}; - - NDArray array1(arr1,shape1); // {2,3} - NDArray array2(arr2,shape2); // {2,4,6} - auto expA = new NDArray(array1.dup('c')); + NDArray array1(arr1, shape1); // {2,3} + NDArray array2(arr2, shape2); // {2,4,6} + auto expA = new NDArray(array1.dup('c')); - auto tiled = array1.tile(tileShape1); + auto tiled = array1.tile(tileShape1); - // array2.printShapeInfo("Expct shape"); - // tiled.printShapeInfo("Tiled shape"); - // tiled.printBuffer(); + // array2.printShapeInfo("Expct shape"); + // tiled.printShapeInfo("Tiled shape"); + // tiled.printBuffer(); - ASSERT_TRUE(tiled.isSameShape(&array2)); - ASSERT_TRUE(tiled.equalsTo(&array2)); + ASSERT_TRUE(tiled.isSameShape(&array2)); + ASSERT_TRUE(tiled.equalsTo(&array2)); - ASSERT_TRUE(expA->isSameShape(&array1)); - ASSERT_TRUE(expA->equalsTo(&array1)); + ASSERT_TRUE(expA->isSameShape(&array1)); + ASSERT_TRUE(expA->equalsTo(&array1)); - // delete tiled; - delete expA; + // delete tiled; + delete expA; } TEST_F(NDArrayTest, TestTile2) { + NDArray array1(arr1, shape1); + NDArray array2(arr2, shape2); - NDArray array1(arr1,shape1); - NDArray array2(arr2,shape2); + auto tiled = array1.tile(tileShape1); - auto tiled = array1.tile(tileShape1); - - ASSERT_TRUE(tiled.isSameShape(&array2)); - ASSERT_TRUE(tiled.equalsTo(&array2)); - // delete tiled; + ASSERT_TRUE(tiled.isSameShape(&array2)); + ASSERT_TRUE(tiled.equalsTo(&array2)); + // delete tiled; } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestTile3) { + NDArray array1(arr1, shape1); + NDArray array2(arr2, shape2); - NDArray array1(arr1,shape1); - NDArray array2(arr2,shape2); - - array1.tilei(tileShape1); + array1.tilei(tileShape1); - ASSERT_TRUE(array1.isSameShapeStrict(array2)); - ASSERT_TRUE(array1.equalsTo(&array2)); + ASSERT_TRUE(array1.isSameShapeStrict(array2)); + ASSERT_TRUE(array1.equalsTo(&array2)); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestTile4) { + float xBuff[] = {1, 2, 3, 4, 5, 6}; + float expBuff[] = {1.f, 2.f, 1.f, 2.f, 3.f, 4.f, + 3.f, 4.f, 5.f, 6.f, 5.f, 6.f}; - float xBuff[] = {1,2,3,4,5,6}; - float expBuff[] = {1.f,2.f, 1.f,2.f, 3.f,4.f, 3.f,4.f, 5.f,6.f, 5.f,6.f}; + auto x = NDArrayFactory::create(xBuff, 'c', {3, 1, 2}); + auto exp = NDArrayFactory::create(expBuff, 'c', {3, 2, 2}); - auto x = NDArrayFactory::create(xBuff, 'c', {3,1,2}); - auto exp = NDArrayFactory::create(expBuff, 'c', {3,2,2}); + auto result = x.tile({2, 1}); - auto result = x.tile({2,1}); - - ASSERT_TRUE(result.isSameShapeStrict(exp)); - ASSERT_TRUE(result.equalsTo(&exp)); + ASSERT_TRUE(result.isSameShapeStrict(exp)); + ASSERT_TRUE(result.equalsTo(&exp)); } //////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestTile5) { + float xBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; + float expBuff[] = {1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f, + 5.f, 6.f, 7.f, 8.f, 5.f, 6.f, 7.f, 8.f, + 9.f, 10.f, 11.f, 12.f, 9.f, 10.f, 11.f, 12.f}; - float xBuff[] = {1,2,3,4,5,6,7,8,9,10,11,12}; - float expBuff[] = {1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 5.f, 6.f, 7.f, 8.f, 9.f,10.f, 11.f,12.f, 9.f,10.f, 11.f,12.f}; - - auto x = NDArrayFactory::create(xBuff, 'c', {3,2,2}); - auto exp = NDArrayFactory::create(expBuff, 'c', {3,4,2}); + auto x = NDArrayFactory::create(xBuff, 'c', {3, 2, 2}); + auto exp = NDArrayFactory::create(expBuff, 'c', {3, 4, 2}); - auto result = x.tile({2,1}); + auto result = x.tile({2, 1}); - ASSERT_TRUE(result.isSameShapeStrict(exp)); - ASSERT_TRUE(result.equalsTo(&exp)); + ASSERT_TRUE(result.isSameShapeStrict(exp)); + ASSERT_TRUE(result.equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////// -TEST_F(NDArrayTest, TestTile6) -{ - double expBuff[] = {10.,11., 10.,11., 10.,11., 10.,11., 12.,13., 12.,13., 12.,13., 12.,13., 14.,15., 14.,15., 14.,15., 14.,15.}; +TEST_F(NDArrayTest, TestTile6) { + double expBuff[] = {10., 11., 10., 11., 10., 11., 10., 11., + 12., 13., 12., 13., 12., 13., 12., 13., + 14., 15., 14., 15., 14., 15., 14., 15.}; - auto x = NDArrayFactory::create('c', {3, 1, 2}); - auto expected = NDArrayFactory::create(expBuff, 'c', {3, 4, 2}); + auto x = NDArrayFactory::create('c', {3, 1, 2}); + auto expected = NDArrayFactory::create(expBuff, 'c', {3, 4, 2}); - x.linspace(10); + x.linspace(10); - auto result = x.tile({1,4,1}); + auto result = x.tile({1, 4, 1}); - ASSERT_TRUE(expected.isSameShape(&result)); - ASSERT_TRUE(expected.equalsTo(&result)); + ASSERT_TRUE(expected.isSameShape(&result)); + ASSERT_TRUE(expected.equalsTo(&result)); } - ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestMmulHelper1) { - auto xBuffer = new float[3]{1.f, 2.f, 3.f}; - auto xShape = new Nd4jLong[8] {2, 1, 3, 1, 1, 8192, 1, 99}; - auto x = new NDArray(xBuffer, xShape); + auto xBuffer = new float[3]{1.f, 2.f, 3.f}; + auto xShape = new Nd4jLong[8]{2, 1, 3, 1, 1, 8192, 1, 99}; + auto x = new NDArray(xBuffer, xShape); - auto yBuffer = new float[3]{2.f, 4.f, 6.f}; - auto yShape = new Nd4jLong[8] {2, 1, 3, 1, 1, 8192, 1, 99}; - auto y = new NDArray(yBuffer, yShape); + auto yBuffer = new float[3]{2.f, 4.f, 6.f}; + auto yShape = new Nd4jLong[8]{2, 1, 3, 1, 1, 8192, 1, 99}; + auto y = new NDArray(yBuffer, yShape); - auto z = MmulHelper::mmul(x, y); + auto z = MmulHelper::mmul(x, y); - ASSERT_EQ(1, z->lengthOf()); - ASSERT_NEAR(28, z->e(0), 1e-5); + ASSERT_EQ(1, z->lengthOf()); + ASSERT_NEAR(28, z->e(0), 1e-5); - delete z; - delete[] xBuffer; - delete[] xShape; - delete[] yBuffer; - delete[] yShape; - delete y; - delete x; + delete z; + delete[] xBuffer; + delete[] xShape; + delete[] yBuffer; + delete[] yShape; + delete y; + delete x; } - TEST_F(NDArrayTest, TestPermuteReshapeMmul1) { - auto x = NDArrayFactory::create('c', {6, 3}); - auto y = NDArrayFactory::create('c', {3, 6}); + auto x = NDArrayFactory::create('c', {6, 3}); + auto y = NDArrayFactory::create('c', {3, 6}); - Nd4jLong _expS[] = {2, 3, 3, 1, 3, 8192, 1, 102}; - float _expB[] = {231.0f, 252.0f, 273.0f, 537.0f, 594.0f, 651.0f, 843.0f, 936.0f, 1029.0f}; - NDArray exp(_expB, _expS); + Nd4jLong _expS[] = {2, 3, 3, 1, 3, 8192, 1, 102}; + float _expB[] = {231.0f, 252.0f, 273.0f, 537.0f, 594.0f, + 651.0f, 843.0f, 936.0f, 1029.0f}; + NDArray exp(_expB, _expS); - for (int e = 0; e < x.lengthOf(); e++) - x.p(e, e+1); + for (int e = 0; e < x.lengthOf(); e++) x.p(e, e + 1); - for (int e = 0; e < y.lengthOf(); e++) - y.p(e, e+1); + for (int e = 0; e < y.lengthOf(); e++) y.p(e, e + 1); - x.permutei({1, 0}); - y.permutei({1, 0}); + x.permutei({1, 0}); + y.permutei({1, 0}); - auto z = MmulHelper::mmul(&x, &y); + auto z = MmulHelper::mmul(&x, &y); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); - delete z; + delete z; } TEST_F(NDArrayTest, TestPermuteReshapeMmul2) { - auto x = NDArrayFactory::create('c', {6, 3}); - auto y = NDArrayFactory::create('c', {3, 6}); + auto x = NDArrayFactory::create('c', {6, 3}); + auto y = NDArrayFactory::create('c', {3, 6}); - Nd4jLong _expS[] = {2, 3, 3, 1, 3, 8192, 1, 102}; - float _expB[] = {231.0f, 252.0f, 273.0f, 537.0f, 594.0f, 651.0f, 843.0f, 936.0f, 1029.0f}; - NDArray exp(_expB, _expS); + Nd4jLong _expS[] = {2, 3, 3, 1, 3, 8192, 1, 102}; + float _expB[] = {231.0f, 252.0f, 273.0f, 537.0f, 594.0f, + 651.0f, 843.0f, 936.0f, 1029.0f}; + NDArray exp(_expB, _expS); - for (int e = 0; e < x.lengthOf(); e++) - x.p(e, e+1); + for (int e = 0; e < x.lengthOf(); e++) x.p(e, e + 1); - for (int e = 0; e < y.lengthOf(); e++) - y.p(e, e+1); + for (int e = 0; e < y.lengthOf(); e++) y.p(e, e + 1); - auto x_ = new NDArray(x.dup('f')); - auto y_ = new NDArray(y.dup('f')); + auto x_ = new NDArray(x.dup('f')); + auto y_ = new NDArray(y.dup('f')); - x_->permutei({1, 0}); - y_->permutei({1, 0}); + x_->permutei({1, 0}); + y_->permutei({1, 0}); - auto z = MmulHelper::mmul(x_, y_); + auto z = MmulHelper::mmul(x_, y_); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); - delete z; - delete x_; - delete y_; + delete z; + delete x_; + delete y_; } - TEST_F(NDArrayTest, TestPermuteReshapeMmul3) { - auto x = NDArrayFactory::create('c', {2, 2, 2, 3, 2, 2}); - auto y = NDArrayFactory::create('c', {2, 3, 2 ,2}); + auto x = NDArrayFactory::create('c', {2, 2, 2, 3, 2, 2}); + auto y = NDArrayFactory::create('c', {2, 3, 2, 2}); - Nd4jLong _expS[] = {2, 8, 2, 1, 8, 8192, 1, 102}; - float _expB[] = {1624.0f, 1858.0f, 2092.0f, 2326.0f, 5368.0f, 5602.0f, 5836.0f, 6070.0f, 4504.0f, 5170.0f, 5836.0f, 6502.0f, 15160.0f, 15826.0f, 16492.0f, 17158.0f}; - NDArray exp(_expB, _expS); + Nd4jLong _expS[] = {2, 8, 2, 1, 8, 8192, 1, 102}; + float _expB[] = {1624.0f, 1858.0f, 2092.0f, 2326.0f, 5368.0f, 5602.0f, + 5836.0f, 6070.0f, 4504.0f, 5170.0f, 5836.0f, 6502.0f, + 15160.0f, 15826.0f, 16492.0f, 17158.0f}; + NDArray exp(_expB, _expS); - for (int e = 0; e < x.lengthOf(); e++) - x.p(e, e+1); + for (int e = 0; e < x.lengthOf(); e++) x.p(e, e + 1); - for (int e = 0; e < y.lengthOf(); e++) - y.p(e, e+1); + for (int e = 0; e < y.lengthOf(); e++) y.p(e, e + 1); - x.permutei({0, 3, 4, 5, 1, 2}); - y.permutei({3, 2, 1, 0}); + x.permutei({0, 3, 4, 5, 1, 2}); + y.permutei({3, 2, 1, 0}); - x.reshapei('c', {2 * 2 * 2, 3 * 2 * 2}); - y.reshapei('c', {2 * 2 * 3, 2}); + x.reshapei('c', {2 * 2 * 2, 3 * 2 * 2}); + y.reshapei('c', {2 * 2 * 3, 2}); - auto z = MmulHelper::mmul(&x, &y); + auto z = MmulHelper::mmul(&x, &y); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); - delete z; + delete z; } TEST_F(NDArrayTest, TestPermuteReshapeMmul4) { - auto x = NDArrayFactory::create('c', {2, 2, 2, 3, 2, 2}); - auto y = NDArrayFactory::create('c', {2, 3, 2 ,2}); + auto x = NDArrayFactory::create('c', {2, 2, 2, 3, 2, 2}); + auto y = NDArrayFactory::create('c', {2, 3, 2, 2}); - Nd4jLong _expS[] = {2, 8, 2, 1, 8, 8192, 1, 102}; - float _expB[] = {1624.0f, 1858.0f, 2092.0f, 2326.0f, 5368.0f, 5602.0f, 5836.0f, 6070.0f, 4504.0f, 5170.0f, 5836.0f, 6502.0f, 15160.0f, 15826.0f, 16492.0f, 17158.0f}; - NDArray exp(_expB, _expS); + Nd4jLong _expS[] = {2, 8, 2, 1, 8, 8192, 1, 102}; + float _expB[] = {1624.0f, 1858.0f, 2092.0f, 2326.0f, 5368.0f, 5602.0f, + 5836.0f, 6070.0f, 4504.0f, 5170.0f, 5836.0f, 6502.0f, + 15160.0f, 15826.0f, 16492.0f, 17158.0f}; + NDArray exp(_expB, _expS); - for (int e = 0; e < x.lengthOf(); e++) - x.p(e, e+1); + for (int e = 0; e < x.lengthOf(); e++) x.p(e, e + 1); - for (int e = 0; e < y.lengthOf(); e++) - y.p(e, e+1); + for (int e = 0; e < y.lengthOf(); e++) y.p(e, e + 1); - auto y_ = new NDArray(y.dup('f')); + auto y_ = new NDArray(y.dup('f')); - x.permutei({0, 3, 4, 5, 1, 2}); - y_->permutei({3, 2, 1, 0}); + x.permutei({0, 3, 4, 5, 1, 2}); + y_->permutei({3, 2, 1, 0}); - x.reshapei('c', {2 * 2 * 2, 3 * 2 * 2}); - y_->reshapei('c', {2 * 2 * 3, 2}); + x.reshapei('c', {2 * 2 * 2, 3 * 2 * 2}); + y_->reshapei('c', {2 * 2 * 3, 2}); - auto z = MmulHelper::mmul(&x, y_); + auto z = MmulHelper::mmul(&x, y_); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); - delete z; - delete y_; + delete z; + delete y_; } - ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestMmulHelper2) { - auto xBuffer = new float[15]{1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f}; - Nd4jLong xShape[8] = {2, 5, 3, 3, 1, 8192, 1, 99}; - auto x = new NDArray(xBuffer, xShape, sd::LaunchContext ::defaultContext(), true); - + auto xBuffer = new float[15]{1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, + 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f}; + Nd4jLong xShape[8] = {2, 5, 3, 3, 1, 8192, 1, 99}; + auto x = + new NDArray(xBuffer, xShape, sd::LaunchContext ::defaultContext(), true); - auto yBuffer = new float[3]{2.f, 4.f, 6.f}; - Nd4jLong yShape[8] = {2, 3, 1, 1, 1, 8192, 1, 99}; - auto y = new NDArray(yBuffer, yShape, sd::LaunchContext ::defaultContext(), true); + auto yBuffer = new float[3]{2.f, 4.f, 6.f}; + Nd4jLong yShape[8] = {2, 3, 1, 1, 1, 8192, 1, 99}; + auto y = + new NDArray(yBuffer, yShape, sd::LaunchContext ::defaultContext(), true); - auto z = NDArrayFactory::create_('f', {5, 1}); + auto z = NDArrayFactory::create_('f', {5, 1}); - auto expBuffer = new float[5]{28.00f, 64.00f, 100.00f, 136.00f, 172.00f}; - auto exp = new NDArray(expBuffer, z->shapeInfo(), sd::LaunchContext ::defaultContext(), true); + auto expBuffer = new float[5]{28.00f, 64.00f, 100.00f, 136.00f, 172.00f}; + auto exp = new NDArray(expBuffer, z->shapeInfo(), + sd::LaunchContext ::defaultContext(), true); - //sd::blas::GEMV::op('f', x->rows(), x->columns(), 1.0f, x->buffer(), y->rows(), y->buffer(), 1, 0.0, z->buffer(), 1); + // sd::blas::GEMV::op('f', x->rows(), x->columns(), 1.0f, x->buffer(), + // y->rows(), y->buffer(), 1, 0.0, z->buffer(), 1); - MmulHelper::mmul(x, y, z); + MmulHelper::mmul(x, y, z); - //z->printBuffer(); + // z->printBuffer(); - ASSERT_TRUE(z->equalsTo(exp)); + ASSERT_TRUE(z->equalsTo(exp)); - delete x; - delete y; - delete z; - delete exp; + delete x; + delete y; + delete z; + delete exp; } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestMmulHelper3) { - auto xBuffer = new float[15]{1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f}; - auto xShape = new Nd4jLong[8] {2, 5, 3, 1, 5, 8192, 1, 102}; - auto x = new NDArray(xBuffer, xShape); + auto xBuffer = new float[15]{1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, + 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f}; + auto xShape = new Nd4jLong[8]{2, 5, 3, 1, 5, 8192, 1, 102}; + auto x = new NDArray(xBuffer, xShape); - auto yBuffer = new float[3]{2.f, 4.f, 6.f}; - auto yShape = new Nd4jLong[8] {2, 3, 1, 1, 1, 8192, 1, 99}; - auto y = new NDArray(yBuffer, yShape); + auto yBuffer = new float[3]{2.f, 4.f, 6.f}; + auto yShape = new Nd4jLong[8]{2, 3, 1, 1, 1, 8192, 1, 99}; + auto y = new NDArray(yBuffer, yShape); - auto z = NDArrayFactory::create_('f', {5, 1}); + auto z = NDArrayFactory::create_('f', {5, 1}); - auto expBuffer = new float[5]{92.00f, 104.00f, 116.00f, 128.00f, 140.00f}; - auto exp = new NDArray(expBuffer, z->shapeInfo()); + auto expBuffer = new float[5]{92.00f, 104.00f, 116.00f, 128.00f, 140.00f}; + auto exp = new NDArray(expBuffer, z->shapeInfo()); - //sd::blas::GEMV::op('f', x->rows(), x->columns(), 1.0f, x->buffer(), y->rows(), y->buffer(), 1, 0.0, z->buffer(), 1); + // sd::blas::GEMV::op('f', x->rows(), x->columns(), 1.0f, x->buffer(), + // y->rows(), y->buffer(), 1, 0.0, z->buffer(), 1); - MmulHelper::mmul(x, y, z); + MmulHelper::mmul(x, y, z); - //z->printBuffer(); + // z->printBuffer(); - ASSERT_TRUE(z->equalsTo(exp)); + ASSERT_TRUE(z->equalsTo(exp)); - delete[] expBuffer; - delete[] xBuffer; - delete[] yBuffer; - delete[] xShape; - delete[] yShape; + delete[] expBuffer; + delete[] xBuffer; + delete[] yBuffer; + delete[] xShape; + delete[] yShape; - delete x; - delete y; - delete z; - delete exp; + delete x; + delete y; + delete z; + delete exp; } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestMmulHelper4) { - auto xBuffer = new float[6]{1, 2, 3, 4, 5, 6}; - auto xShape = new Nd4jLong[8] {2, 3, 2, 2, 1, 8192, 1, 99}; - auto x = new NDArray(xBuffer, xShape); + auto xBuffer = new float[6]{1, 2, 3, 4, 5, 6}; + auto xShape = new Nd4jLong[8]{2, 3, 2, 2, 1, 8192, 1, 99}; + auto x = new NDArray(xBuffer, xShape); - auto yBuffer = new float[6]{7, 8, 9, 0, 1, 2}; - auto yShape = new Nd4jLong[8] {2, 2, 3, 3, 1, 8192, 1, 99}; - auto y = new NDArray(yBuffer, yShape); + auto yBuffer = new float[6]{7, 8, 9, 0, 1, 2}; + auto yShape = new Nd4jLong[8]{2, 2, 3, 3, 1, 8192, 1, 99}; + auto y = new NDArray(yBuffer, yShape); - auto z = NDArrayFactory::create_('f', {3, 3}); + auto z = NDArrayFactory::create_('f', {3, 3}); - auto expBuffer = new float[9]{7.0f, 21.0f, 35.0f, 10.0f, 28.0f, 46.0f, 13.0f, 35.0f, 57.0f}; - auto exp = new NDArray(expBuffer, z->shapeInfo()); + auto expBuffer = new float[9]{7.0f, 21.0f, 35.0f, 10.0f, 28.0f, + 46.0f, 13.0f, 35.0f, 57.0f}; + auto exp = new NDArray(expBuffer, z->shapeInfo()); - MmulHelper::mmul(x, y, z); - ASSERT_TRUE(z->equalsTo(exp)); + MmulHelper::mmul(x, y, z); + ASSERT_TRUE(z->equalsTo(exp)); - delete[] expBuffer; - delete[] xBuffer; - delete[] yBuffer; - delete[] xShape; - delete[] yShape; + delete[] expBuffer; + delete[] xBuffer; + delete[] yBuffer; + delete[] xShape; + delete[] yShape; - delete x; - delete y; - delete z; - delete exp; + delete x; + delete y; + delete z; + delete exp; } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestMmulHelper5) { - auto xBuffer = new float[6]{1, 2, 3, 4, 5, 6}; - auto xShape = new Nd4jLong[8] {2, 3, 2, 1, 3, 8192, 1, 102}; - auto x = new NDArray(xBuffer, xShape); + auto xBuffer = new float[6]{1, 2, 3, 4, 5, 6}; + auto xShape = new Nd4jLong[8]{2, 3, 2, 1, 3, 8192, 1, 102}; + auto x = new NDArray(xBuffer, xShape); - auto yBuffer = new float[6]{7, 8, 9, 0, 1, 2}; - auto yShape = new Nd4jLong[8] {2, 2, 3, 3, 1, 8192, 1, 99}; - auto y = new NDArray(yBuffer, yShape); + auto yBuffer = new float[6]{7, 8, 9, 0, 1, 2}; + auto yShape = new Nd4jLong[8]{2, 2, 3, 3, 1, 8192, 1, 99}; + auto y = new NDArray(yBuffer, yShape); - auto z = NDArrayFactory::create_('f', {3, 3}); + auto z = NDArrayFactory::create_('f', {3, 3}); - auto expBuffer = new float[9]{7.0f, 14.0f, 21.0f, 12.0f, 21.0f, 30.0f, 17.0f, 28.0f, 39.0f}; - auto exp = new NDArray(expBuffer, z->shapeInfo()); + auto expBuffer = new float[9]{7.0f, 14.0f, 21.0f, 12.0f, 21.0f, + 30.0f, 17.0f, 28.0f, 39.0f}; + auto exp = new NDArray(expBuffer, z->shapeInfo()); - MmulHelper::mmul(x, y, z); - ASSERT_TRUE(z->equalsTo(exp)); + MmulHelper::mmul(x, y, z); + ASSERT_TRUE(z->equalsTo(exp)); - delete[] expBuffer; - delete[] xBuffer; - delete[] yBuffer; - delete[] xShape; - delete[] yShape; + delete[] expBuffer; + delete[] xBuffer; + delete[] yBuffer; + delete[] xShape; + delete[] yShape; - delete x; - delete y; - delete z; - delete exp; + delete x; + delete y; + delete z; + delete exp; } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestMmulHelper6) { - auto xBuffer = new float[6]{1, 2, 3, 4, 5, 6}; - auto xShape = new Nd4jLong[8] {2, 3, 2, 1, 3, 8192, 1, 102}; - auto x = new NDArray(xBuffer, xShape); + auto xBuffer = new float[6]{1, 2, 3, 4, 5, 6}; + auto xShape = new Nd4jLong[8]{2, 3, 2, 1, 3, 8192, 1, 102}; + auto x = new NDArray(xBuffer, xShape); - auto yBuffer = new float[6]{7, 8, 9, 0, 1, 2}; - auto yShape = new Nd4jLong[8] {2, 2, 3, 1, 2, 8192, 1, 102}; - auto y = new NDArray(yBuffer, yShape); + auto yBuffer = new float[6]{7, 8, 9, 0, 1, 2}; + auto yShape = new Nd4jLong[8]{2, 2, 3, 1, 2, 8192, 1, 102}; + auto y = new NDArray(yBuffer, yShape); - auto z = NDArrayFactory::create_('f', {3, 3}); + auto z = NDArrayFactory::create_('f', {3, 3}); - auto expBuffer = new float[9]{39.0f, 54.0f, 69.0f, 9.0f, 18.0f, 27.0f, 9.0f, 12.0f, 15.0f}; - auto exp = new NDArray(expBuffer, z->shapeInfo()); + auto expBuffer = + new float[9]{39.0f, 54.0f, 69.0f, 9.0f, 18.0f, 27.0f, 9.0f, 12.0f, 15.0f}; + auto exp = new NDArray(expBuffer, z->shapeInfo()); - MmulHelper::mmul(x, y, z); - ASSERT_TRUE(z->equalsTo(exp)); + MmulHelper::mmul(x, y, z); + ASSERT_TRUE(z->equalsTo(exp)); + delete[] expBuffer; + delete[] xBuffer; + delete[] yBuffer; + delete[] xShape; + delete[] yShape; - delete[] expBuffer; - delete[] xBuffer; - delete[] yBuffer; - delete[] xShape; - delete[] yShape; - - delete x; - delete y; - delete z; - delete exp; + delete x; + delete y; + delete z; + delete exp; } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestMmulHelper7) { - auto xBuffer = new float[15]{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; - auto xShape = new Nd4jLong[8] {2, 5, 3, 1, 5, 8192, 1, 102}; - auto x = new NDArray(xBuffer, xShape); + auto xBuffer = + new float[15]{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + auto xShape = new Nd4jLong[8]{2, 5, 3, 1, 5, 8192, 1, 102}; + auto x = new NDArray(xBuffer, xShape); - auto yBuffer = new float[5]{2, 4, 6, 8, 10}; - auto yShape = new Nd4jLong[8] {2, 1, 5, 1, 1, 8192, 1, 99}; - auto y = new NDArray(yBuffer, yShape); + auto yBuffer = new float[5]{2, 4, 6, 8, 10}; + auto yShape = new Nd4jLong[8]{2, 1, 5, 1, 1, 8192, 1, 99}; + auto y = new NDArray(yBuffer, yShape); - auto z = NDArrayFactory::create_('f', {1, 3}); + auto z = NDArrayFactory::create_('f', {1, 3}); - auto expBuffer = new float[9]{110.00f, 260.00f, 410.00f}; - auto exp = new NDArray(expBuffer, z->shapeInfo()); + auto expBuffer = new float[9]{110.00f, 260.00f, 410.00f}; + auto exp = new NDArray(expBuffer, z->shapeInfo()); - MmulHelper::mmul(y, x, z); + MmulHelper::mmul(y, x, z); - //z->printBuffer(); - ASSERT_TRUE(z->equalsTo(exp)); + // z->printBuffer(); + ASSERT_TRUE(z->equalsTo(exp)); - delete[] expBuffer; - delete[] xBuffer; - delete[] yBuffer; - delete[] xShape; - delete[] yShape; + delete[] expBuffer; + delete[] xBuffer; + delete[] yBuffer; + delete[] xShape; + delete[] yShape; - delete x; - delete y; - delete z; - delete exp; + delete x; + delete y; + delete z; + delete exp; } - TEST_F(NDArrayTest, TestMmulHelper_ND_1) { - Nd4jLong _expS[] = {3, 2, 3, 3, 9, 3, 1, 8192, 1, 99}; - float _expB[] = {70.f, 80.f, 90.f, 158.f, 184.f, 210.f, 246.f, 288.f, 330.f, 1030.f, 1088.f, 1146.f, 1310.f, 1384.f, 1458.f, 1590.f, 1680.f, 1770.f}; + Nd4jLong _expS[] = {3, 2, 3, 3, 9, 3, 1, 8192, 1, 99}; + float _expB[] = {70.f, 80.f, 90.f, 158.f, 184.f, 210.f, + 246.f, 288.f, 330.f, 1030.f, 1088.f, 1146.f, + 1310.f, 1384.f, 1458.f, 1590.f, 1680.f, 1770.f}; - auto a = NDArrayFactory::create('c', {2, 3, 4}); - for (int e = 0; e < a.lengthOf(); e++) - a.p(e, e+1); + auto a = NDArrayFactory::create('c', {2, 3, 4}); + for (int e = 0; e < a.lengthOf(); e++) a.p(e, e + 1); - auto b = NDArrayFactory::create('c', {2, 4, 3}); - for (int e = 0; e < b.lengthOf(); e++) - b.p(e, e+1); + auto b = NDArrayFactory::create('c', {2, 4, 3}); + for (int e = 0; e < b.lengthOf(); e++) b.p(e, e + 1); - NDArray exp(_expB, _expS); - auto c = MmulHelper::mmul(&a, &b); + NDArray exp(_expB, _expS); + auto c = MmulHelper::mmul(&a, &b); - ASSERT_TRUE(exp.isSameShape(c)); - ASSERT_TRUE(exp.equalsTo(c)); + ASSERT_TRUE(exp.isSameShape(c)); + ASSERT_TRUE(exp.equalsTo(c)); - delete c; + delete c; } - TEST_F(NDArrayTest, TestMmulHelper_ND_2) { - Nd4jLong _expS[] = {3, 2, 72, 2, 144, 2, 1, 8192, 1, 99}; - float _expB[] = { - 1.07250000e+04f, 1.10500000e+04f, 2.63500000e+04f, 2.73000000e+04f, 4.19750000e+04f, 4.35500000e+04f, - 5.76000000e+04f, 5.98000000e+04f, 7.32250000e+04f, 7.60500000e+04f, 8.88500000e+04f, 9.23000000e+04f, - 1.04475000e+05f, 1.08550000e+05f, 1.20100000e+05f, 1.24800000e+05f, 1.35725000e+05f, 1.41050000e+05f, - 1.51350000e+05f, 1.57300000e+05f, 1.66975000e+05f, 1.73550000e+05f, 1.82600000e+05f, 1.89800000e+05f, - 1.98225000e+05f, 2.06050000e+05f, 2.13850000e+05f, 2.22300000e+05f, 2.29475000e+05f, 2.38550000e+05f, - 2.45100000e+05f, 2.54800000e+05f, 2.60725000e+05f, 2.71050000e+05f, 2.76350000e+05f, 2.87300000e+05f, - 2.91975000e+05f, 3.03550000e+05f, 3.07600000e+05f, 3.19800000e+05f, 3.23225000e+05f, 3.36050000e+05f, - 3.38850000e+05f, 3.52300000e+05f, 3.54475000e+05f, 3.68550000e+05f, 3.70100000e+05f, 3.84800000e+05f, - 3.85725000e+05f, 4.01050000e+05f, 4.01350000e+05f, 4.17300000e+05f, 4.16975000e+05f, 4.33550000e+05f, - 4.32600000e+05f, 4.49800000e+05f, 4.48225000e+05f, 4.66050000e+05f, 4.63850000e+05f, 4.82300000e+05f, - 4.79475000e+05f, 4.98550000e+05f, 4.95100000e+05f, 5.14800000e+05f, 5.10725000e+05f, 5.31050000e+05f, - 5.26350000e+05f, 5.47300000e+05f, 5.41975000e+05f, 5.63550000e+05f, 5.57600000e+05f, 5.79800000e+05f, - 5.73225000e+05f, 5.96050000e+05f, 5.88850000e+05f, 6.12300000e+05f, 6.04475000e+05f, 6.28550000e+05f, - 6.20100000e+05f, 6.44800000e+05f, 6.35725000e+05f, 6.61050000e+05f, 6.51350000e+05f, 6.77300000e+05f, - 6.66975000e+05f, 6.93550000e+05f, 6.82600000e+05f, 7.09800000e+05f, 6.98225000e+05f, 7.26050000e+05f, - 7.13850000e+05f, 7.42300000e+05f, 7.29475000e+05f, 7.58550000e+05f, 7.45100000e+05f, 7.74800000e+05f, - 7.60725000e+05f, 7.91050000e+05f, 7.76350000e+05f, 8.07300000e+05f, 7.91975000e+05f, 8.23550000e+05f, - 8.07600000e+05f, 8.39800000e+05f, 8.23225000e+05f, 8.56050000e+05f, 8.38850000e+05f, 8.72300000e+05f, - 8.54475000e+05f, 8.88550000e+05f, 8.70100000e+05f, 9.04800000e+05f, 8.85725000e+05f, 9.21050000e+05f, - 9.01350000e+05f, 9.37300000e+05f, 9.16975000e+05f, 9.53550000e+05f, 9.32600000e+05f, 9.69800000e+05f, - 9.48225000e+05f, 9.86050000e+05f, 9.63850000e+05f, 1.00230000e+06f, 9.79475000e+05f, 1.01855000e+06f, - 9.95100000e+05f, 1.03480000e+06f, 1.01072500e+06f, 1.05105000e+06f, 1.02635000e+06f, 1.06730000e+06f, - 1.04197500e+06f, 1.08355000e+06f, 1.05760000e+06f, 1.09980000e+06f, 1.07322500e+06f, 1.11605000e+06f, - 1.08885000e+06f, 1.13230000e+06f, 1.10447500e+06f, 1.14855000e+06f, 1.12010000e+06f, 1.16480000e+06f, - 1.13572500e+06f, 1.18105000e+06f, 1.15135000e+06f, 1.19730000e+06f, 1.16697500e+06f, 1.21355000e+06f, - 3.54260000e+06f, 3.58980000e+06f, 3.58947500e+06f, 3.63730000e+06f, 3.63635000e+06f, 3.68480000e+06f, - 3.68322500e+06f, 3.73230000e+06f, 3.73010000e+06f, 3.77980000e+06f, 3.77697500e+06f, 3.82730000e+06f, - 3.82385000e+06f, 3.87480000e+06f, 3.87072500e+06f, 3.92230000e+06f, 3.91760000e+06f, 3.96980000e+06f, - 3.96447500e+06f, 4.01730000e+06f, 4.01135000e+06f, 4.06480000e+06f, 4.05822500e+06f, 4.11230000e+06f, - 4.10510000e+06f, 4.15980000e+06f, 4.15197500e+06f, 4.20730000e+06f, 4.19885000e+06f, 4.25480000e+06f, - 4.24572500e+06f, 4.30230000e+06f, 4.29260000e+06f, 4.34980000e+06f, 4.33947500e+06f, 4.39730000e+06f, - 4.38635000e+06f, 4.44480000e+06f, 4.43322500e+06f, 4.49230000e+06f, 4.48010000e+06f, 4.53980000e+06f, - 4.52697500e+06f, 4.58730000e+06f, 4.57385000e+06f, 4.63480000e+06f, 4.62072500e+06f, 4.68230000e+06f, - 4.66760000e+06f, 4.72980000e+06f, 4.71447500e+06f, 4.77730000e+06f, 4.76135000e+06f, 4.82480000e+06f, - 4.80822500e+06f, 4.87230000e+06f, 4.85510000e+06f, 4.91980000e+06f, 4.90197500e+06f, 4.96730000e+06f, - 4.94885000e+06f, 5.01480000e+06f, 4.99572500e+06f, 5.06230000e+06f, 5.04260000e+06f, 5.10980000e+06f, - 5.08947500e+06f, 5.15730000e+06f, 5.13635000e+06f, 5.20480000e+06f, 5.18322500e+06f, 5.25230000e+06f, - 5.23010000e+06f, 5.29980000e+06f, 5.27697500e+06f, 5.34730000e+06f, 5.32385000e+06f, 5.39480000e+06f, - 5.37072500e+06f, 5.44230000e+06f, 5.41760000e+06f, 5.48980000e+06f, 5.46447500e+06f, 5.53730000e+06f, - 5.51135000e+06f, 5.58480000e+06f, 5.55822500e+06f, 5.63230000e+06f, 5.60510000e+06f, 5.67980000e+06f, - 5.65197500e+06f, 5.72730000e+06f, 5.69885000e+06f, 5.77480000e+06f, 5.74572500e+06f, 5.82230000e+06f, - 5.79260000e+06f, 5.86980000e+06f, 5.83947500e+06f, 5.91730000e+06f, 5.88635000e+06f, 5.96480000e+06f, - 5.93322500e+06f, 6.01230000e+06f, 5.98010000e+06f, 6.05980000e+06f, 6.02697500e+06f, 6.10730000e+06f, - 6.07385000e+06f, 6.15480000e+06f, 6.12072500e+06f, 6.20230000e+06f, 6.16760000e+06f, 6.24980000e+06f, - 6.21447500e+06f, 6.29730000e+06f, 6.26135000e+06f, 6.34480000e+06f, 6.30822500e+06f, 6.39230000e+06f, - 6.35510000e+06f, 6.43980000e+06f, 6.40197500e+06f, 6.48730000e+06f, 6.44885000e+06f, 6.53480000e+06f, - 6.49572500e+06f, 6.58230000e+06f, 6.54260000e+06f, 6.62980000e+06f, 6.58947500e+06f, 6.67730000e+06f, - 6.63635000e+06f, 6.72480000e+06f, 6.68322500e+06f, 6.77230000e+06f, 6.73010000e+06f, 6.81980000e+06f, - 6.77697500e+06f, 6.86730000e+06f, 6.82385000e+06f, 6.91480000e+06f, 6.87072500e+06f, 6.96230000e+06f, - 6.91760000e+06f, 7.00980000e+06f, 6.96447500e+06f, 7.05730000e+06f, 7.01135000e+06f, 7.10480000e+06f, - 1.17619750e+07f, 1.18560500e+07f, 1.18401000e+07f, 1.19348000e+07f, 1.19182250e+07f, 1.20135500e+07f, - 1.19963500e+07f, 1.20923000e+07f, 1.20744750e+07f, 1.21710500e+07f, 1.21526000e+07f, 1.22498000e+07f, 1.22307250e+07f, 1.23285500e+07f, 1.23088500e+07f, 1.24073000e+07f, 1.23869750e+07f, 1.24860500e+07f, 1.24651000e+07f, 1.25648000e+07f, 1.25432250e+07f, 1.26435500e+07f, 1.26213500e+07f, 1.27223000e+07f, 1.26994750e+07f, 1.28010500e+07f, 1.27776000e+07f, 1.28798000e+07f, 1.28557250e+07f, 1.29585500e+07f, 1.29338500e+07f, 1.30373000e+07f, 1.30119750e+07f, 1.31160500e+07f, 1.30901000e+07f, 1.31948000e+07f, 1.31682250e+07f, 1.32735500e+07f, 1.32463500e+07f, 1.33523000e+07f, 1.33244750e+07f, 1.34310500e+07f, 1.34026000e+07f, 1.35098000e+07f, 1.34807250e+07f, 1.35885500e+07f, 1.35588500e+07f, 1.36673000e+07f, 1.36369750e+07f, 1.37460500e+07f, 1.37151000e+07f, 1.38248000e+07f, 1.37932250e+07f, 1.39035500e+07f, 1.38713500e+07f, 1.39823000e+07f, 1.39494750e+07f, 1.40610500e+07f, 1.40276000e+07f, 1.41398000e+07f, 1.41057250e+07f, 1.42185500e+07f, 1.41838500e+07f, 1.42973000e+07f, 1.42619750e+07f, 1.43760500e+07f, 1.43401000e+07f, 1.44548000e+07f, 1.44182250e+07f, 1.45335500e+07f, 1.44963500e+07f, 1.46123000e+07f, 1.45744750e+07f, 1.46910500e+07f, 1.46526000e+07f, 1.47698000e+07f, 1.47307250e+07f, 1.48485500e+07f, 1.48088500e+07f, 1.49273000e+07f, 1.48869750e+07f, 1.50060500e+07f, 1.49651000e+07f, 1.50848000e+07f, 1.50432250e+07f, 1.51635500e+07f, 1.51213500e+07f, 1.52423000e+07f, 1.51994750e+07f, 1.53210500e+07f, 1.52776000e+07f, 1.53998000e+07f, 1.53557250e+07f, 1.54785500e+07f, 1.54338500e+07f, 1.55573000e+07f, 1.55119750e+07f, 1.56360500e+07f, 1.55901000e+07f, 1.57148000e+07f, 1.56682250e+07f, 1.57935500e+07f, 1.57463500e+07f, 1.58723000e+07f, 1.58244750e+07f, 1.59510500e+07f, 1.59026000e+07f, 1.60298000e+07f, 1.59807250e+07f, 1.61085500e+07f, 1.60588500e+07f, 1.61873000e+07f, 1.61369750e+07f, 1.62660500e+07f, 1.62151000e+07f, 1.63448000e+07f, 1.62932250e+07f, 1.64235500e+07f, 1.63713500e+07f, 1.65023000e+07f, 1.64494750e+07f, 1.65810500e+07f, 1.65276000e+07f, 1.66598000e+07f, 1.66057250e+07f, 1.67385500e+07f, 1.66838500e+07f, 1.68173000e+07f, 1.67619750e+07f, 1.68960500e+07f, 1.68401000e+07f, 1.69748000e+07f, 1.69182250e+07f, 1.70535500e+07f, 1.69963500e+07f, 1.71323000e+07f, 1.70744750e+07f, 1.72110500e+07f, 1.71526000e+07f, 1.72898000e+07f, 1.72307250e+07f, 1.73685500e+07f, 1.73088500e+07f, 1.74473000e+07f, 1.73869750e+07f, 1.75260500e+07f, 1.74651000e+07f, 1.76048000e+07f, 1.75432250e+07f, 1.76835500e+07f, 2.46688500e+07f, 2.48098000e+07f, 2.47782250e+07f, 2.49198000e+07f, 2.48876000e+07f, 2.50298000e+07f, 2.49969750e+07f, 2.51398000e+07f, 2.51063500e+07f, 2.52498000e+07f, 2.52157250e+07f, 2.53598000e+07f, 2.53251000e+07f, 2.54698000e+07f, 2.54344750e+07f, 2.55798000e+07f, 2.55438500e+07f, 2.56898000e+07f, 2.56532250e+07f, 2.57998000e+07f, 2.57626000e+07f, 2.59098000e+07f, 2.58719750e+07f, 2.60198000e+07f, 2.59813500e+07f, 2.61298000e+07f, 2.60907250e+07f, 2.62398000e+07f, 2.62001000e+07f, 2.63498000e+07f, 2.63094750e+07f, 2.64598000e+07f, 2.64188500e+07f, 2.65698000e+07f, 2.65282250e+07f, 2.66798000e+07f, 2.66376000e+07f, 2.67898000e+07f, 2.67469750e+07f, 2.68998000e+07f, 2.68563500e+07f, 2.70098000e+07f, 2.69657250e+07f, 2.71198000e+07f, 2.70751000e+07f, 2.72298000e+07f, 2.71844750e+07f, 2.73398000e+07f, 2.72938500e+07f, 2.74498000e+07f, 2.74032250e+07f, 2.75598000e+07f, 2.75126000e+07f, 2.76698000e+07f, 2.76219750e+07f, 2.77798000e+07f, 2.77313500e+07f, 2.78898000e+07f, 2.78407250e+07f, 2.79998000e+07f, 2.79501000e+07f, 2.81098000e+07f, 2.80594750e+07f, 2.82198000e+07f, 2.81688500e+07f, 2.83298000e+07f, 2.82782250e+07f, 2.84398000e+07f, 2.83876000e+07f, 2.85498000e+07f, 2.84969750e+07f, 2.86598000e+07f, 2.86063500e+07f, 2.87698000e+07f, 2.87157250e+07f, 2.88798000e+07f, 2.88251000e+07f, 2.89898000e+07f, 2.89344750e+07f, 2.90998000e+07f, 2.90438500e+07f, 2.92098000e+07f, 2.91532250e+07f, 2.93198000e+07f, 2.92626000e+07f, 2.94298000e+07f, 2.93719750e+07f, 2.95398000e+07f, 2.94813500e+07f, 2.96498000e+07f, 2.95907250e+07f, 2.97598000e+07f, 2.97001000e+07f, 2.98698000e+07f, 2.98094750e+07f, 2.99798000e+07f, 2.99188500e+07f, 3.00898000e+07f, 3.00282250e+07f, 3.01998000e+07f, 3.01376000e+07f, 3.03098000e+07f, 3.02469750e+07f, 3.04198000e+07f, 3.03563500e+07f, 3.05298000e+07f, 3.04657250e+07f, 3.06398000e+07f, 3.05751000e+07f, 3.07498000e+07f, 3.06844750e+07f, 3.08598000e+07f, 3.07938500e+07f, 3.09698000e+07f, 3.09032250e+07f, 3.10798000e+07f, 3.10126000e+07f, 3.11898000e+07f, 3.11219750e+07f, 3.12998000e+07f, 3.12313500e+07f, 3.14098000e+07f, 3.13407250e+07f, 3.15198000e+07f, 3.14501000e+07f, 3.16298000e+07f, 3.15594750e+07f, 3.17398000e+07f, 3.16688500e+07f, 3.18498000e+07f, 3.17782250e+07f, 3.19598000e+07f, 3.18876000e+07f, 3.20698000e+07f, 3.19969750e+07f, 3.21798000e+07f, 3.21063500e+07f, 3.22898000e+07f, 3.22157250e+07f, 3.23998000e+07f, 3.23251000e+07f, 3.25098000e+07f, 3.24344750e+07f, 3.26198000e+07f, 3.25438500e+07f, 3.27298000e+07f, 3.26532250e+07f, 3.28398000e+07f, 3.27626000e+07f, 3.29498000e+07}; - - auto a = NDArrayFactory::create('c', {2, 72, 25}); - for (int e = 0; e < a.lengthOf(); e++) - a.p(e, e+1); - - auto b = NDArrayFactory::create('c', {2, 25, 2}); - for (int e = 0; e < b.lengthOf(); e++) - b.p(e, e+1); - - NDArray exp(_expB, _expS); - - auto c = MmulHelper::mmul(&a, &b); - - ASSERT_TRUE(exp.isSameShape(c)); - ASSERT_TRUE(exp.equalsTo(c, 1e1)); - - delete c; + Nd4jLong _expS[] = {3, 2, 72, 2, 144, 2, 1, 8192, 1, 99}; + float _expB[] = { + 1.07250000e+04f, 1.10500000e+04f, 2.63500000e+04f, 2.73000000e+04f, + 4.19750000e+04f, 4.35500000e+04f, 5.76000000e+04f, 5.98000000e+04f, + 7.32250000e+04f, 7.60500000e+04f, 8.88500000e+04f, 9.23000000e+04f, + 1.04475000e+05f, 1.08550000e+05f, 1.20100000e+05f, 1.24800000e+05f, + 1.35725000e+05f, 1.41050000e+05f, 1.51350000e+05f, 1.57300000e+05f, + 1.66975000e+05f, 1.73550000e+05f, 1.82600000e+05f, 1.89800000e+05f, + 1.98225000e+05f, 2.06050000e+05f, 2.13850000e+05f, 2.22300000e+05f, + 2.29475000e+05f, 2.38550000e+05f, 2.45100000e+05f, 2.54800000e+05f, + 2.60725000e+05f, 2.71050000e+05f, 2.76350000e+05f, 2.87300000e+05f, + 2.91975000e+05f, 3.03550000e+05f, 3.07600000e+05f, 3.19800000e+05f, + 3.23225000e+05f, 3.36050000e+05f, 3.38850000e+05f, 3.52300000e+05f, + 3.54475000e+05f, 3.68550000e+05f, 3.70100000e+05f, 3.84800000e+05f, + 3.85725000e+05f, 4.01050000e+05f, 4.01350000e+05f, 4.17300000e+05f, + 4.16975000e+05f, 4.33550000e+05f, 4.32600000e+05f, 4.49800000e+05f, + 4.48225000e+05f, 4.66050000e+05f, 4.63850000e+05f, 4.82300000e+05f, + 4.79475000e+05f, 4.98550000e+05f, 4.95100000e+05f, 5.14800000e+05f, + 5.10725000e+05f, 5.31050000e+05f, 5.26350000e+05f, 5.47300000e+05f, + 5.41975000e+05f, 5.63550000e+05f, 5.57600000e+05f, 5.79800000e+05f, + 5.73225000e+05f, 5.96050000e+05f, 5.88850000e+05f, 6.12300000e+05f, + 6.04475000e+05f, 6.28550000e+05f, 6.20100000e+05f, 6.44800000e+05f, + 6.35725000e+05f, 6.61050000e+05f, 6.51350000e+05f, 6.77300000e+05f, + 6.66975000e+05f, 6.93550000e+05f, 6.82600000e+05f, 7.09800000e+05f, + 6.98225000e+05f, 7.26050000e+05f, 7.13850000e+05f, 7.42300000e+05f, + 7.29475000e+05f, 7.58550000e+05f, 7.45100000e+05f, 7.74800000e+05f, + 7.60725000e+05f, 7.91050000e+05f, 7.76350000e+05f, 8.07300000e+05f, + 7.91975000e+05f, 8.23550000e+05f, 8.07600000e+05f, 8.39800000e+05f, + 8.23225000e+05f, 8.56050000e+05f, 8.38850000e+05f, 8.72300000e+05f, + 8.54475000e+05f, 8.88550000e+05f, 8.70100000e+05f, 9.04800000e+05f, + 8.85725000e+05f, 9.21050000e+05f, 9.01350000e+05f, 9.37300000e+05f, + 9.16975000e+05f, 9.53550000e+05f, 9.32600000e+05f, 9.69800000e+05f, + 9.48225000e+05f, 9.86050000e+05f, 9.63850000e+05f, 1.00230000e+06f, + 9.79475000e+05f, 1.01855000e+06f, 9.95100000e+05f, 1.03480000e+06f, + 1.01072500e+06f, 1.05105000e+06f, 1.02635000e+06f, 1.06730000e+06f, + 1.04197500e+06f, 1.08355000e+06f, 1.05760000e+06f, 1.09980000e+06f, + 1.07322500e+06f, 1.11605000e+06f, 1.08885000e+06f, 1.13230000e+06f, + 1.10447500e+06f, 1.14855000e+06f, 1.12010000e+06f, 1.16480000e+06f, + 1.13572500e+06f, 1.18105000e+06f, 1.15135000e+06f, 1.19730000e+06f, + 1.16697500e+06f, 1.21355000e+06f, 3.54260000e+06f, 3.58980000e+06f, + 3.58947500e+06f, 3.63730000e+06f, 3.63635000e+06f, 3.68480000e+06f, + 3.68322500e+06f, 3.73230000e+06f, 3.73010000e+06f, 3.77980000e+06f, + 3.77697500e+06f, 3.82730000e+06f, 3.82385000e+06f, 3.87480000e+06f, + 3.87072500e+06f, 3.92230000e+06f, 3.91760000e+06f, 3.96980000e+06f, + 3.96447500e+06f, 4.01730000e+06f, 4.01135000e+06f, 4.06480000e+06f, + 4.05822500e+06f, 4.11230000e+06f, 4.10510000e+06f, 4.15980000e+06f, + 4.15197500e+06f, 4.20730000e+06f, 4.19885000e+06f, 4.25480000e+06f, + 4.24572500e+06f, 4.30230000e+06f, 4.29260000e+06f, 4.34980000e+06f, + 4.33947500e+06f, 4.39730000e+06f, 4.38635000e+06f, 4.44480000e+06f, + 4.43322500e+06f, 4.49230000e+06f, 4.48010000e+06f, 4.53980000e+06f, + 4.52697500e+06f, 4.58730000e+06f, 4.57385000e+06f, 4.63480000e+06f, + 4.62072500e+06f, 4.68230000e+06f, 4.66760000e+06f, 4.72980000e+06f, + 4.71447500e+06f, 4.77730000e+06f, 4.76135000e+06f, 4.82480000e+06f, + 4.80822500e+06f, 4.87230000e+06f, 4.85510000e+06f, 4.91980000e+06f, + 4.90197500e+06f, 4.96730000e+06f, 4.94885000e+06f, 5.01480000e+06f, + 4.99572500e+06f, 5.06230000e+06f, 5.04260000e+06f, 5.10980000e+06f, + 5.08947500e+06f, 5.15730000e+06f, 5.13635000e+06f, 5.20480000e+06f, + 5.18322500e+06f, 5.25230000e+06f, 5.23010000e+06f, 5.29980000e+06f, + 5.27697500e+06f, 5.34730000e+06f, 5.32385000e+06f, 5.39480000e+06f, + 5.37072500e+06f, 5.44230000e+06f, 5.41760000e+06f, 5.48980000e+06f, + 5.46447500e+06f, 5.53730000e+06f, 5.51135000e+06f, 5.58480000e+06f, + 5.55822500e+06f, 5.63230000e+06f, 5.60510000e+06f, 5.67980000e+06f, + 5.65197500e+06f, 5.72730000e+06f, 5.69885000e+06f, 5.77480000e+06f, + 5.74572500e+06f, 5.82230000e+06f, 5.79260000e+06f, 5.86980000e+06f, + 5.83947500e+06f, 5.91730000e+06f, 5.88635000e+06f, 5.96480000e+06f, + 5.93322500e+06f, 6.01230000e+06f, 5.98010000e+06f, 6.05980000e+06f, + 6.02697500e+06f, 6.10730000e+06f, 6.07385000e+06f, 6.15480000e+06f, + 6.12072500e+06f, 6.20230000e+06f, 6.16760000e+06f, 6.24980000e+06f, + 6.21447500e+06f, 6.29730000e+06f, 6.26135000e+06f, 6.34480000e+06f, + 6.30822500e+06f, 6.39230000e+06f, 6.35510000e+06f, 6.43980000e+06f, + 6.40197500e+06f, 6.48730000e+06f, 6.44885000e+06f, 6.53480000e+06f, + 6.49572500e+06f, 6.58230000e+06f, 6.54260000e+06f, 6.62980000e+06f, + 6.58947500e+06f, 6.67730000e+06f, 6.63635000e+06f, 6.72480000e+06f, + 6.68322500e+06f, 6.77230000e+06f, 6.73010000e+06f, 6.81980000e+06f, + 6.77697500e+06f, 6.86730000e+06f, 6.82385000e+06f, 6.91480000e+06f, + 6.87072500e+06f, 6.96230000e+06f, 6.91760000e+06f, 7.00980000e+06f, + 6.96447500e+06f, 7.05730000e+06f, 7.01135000e+06f, 7.10480000e+06f, + 1.17619750e+07f, 1.18560500e+07f, 1.18401000e+07f, 1.19348000e+07f, + 1.19182250e+07f, 1.20135500e+07f, 1.19963500e+07f, 1.20923000e+07f, + 1.20744750e+07f, 1.21710500e+07f, 1.21526000e+07f, 1.22498000e+07f, + 1.22307250e+07f, 1.23285500e+07f, 1.23088500e+07f, 1.24073000e+07f, + 1.23869750e+07f, 1.24860500e+07f, 1.24651000e+07f, 1.25648000e+07f, + 1.25432250e+07f, 1.26435500e+07f, 1.26213500e+07f, 1.27223000e+07f, + 1.26994750e+07f, 1.28010500e+07f, 1.27776000e+07f, 1.28798000e+07f, + 1.28557250e+07f, 1.29585500e+07f, 1.29338500e+07f, 1.30373000e+07f, + 1.30119750e+07f, 1.31160500e+07f, 1.30901000e+07f, 1.31948000e+07f, + 1.31682250e+07f, 1.32735500e+07f, 1.32463500e+07f, 1.33523000e+07f, + 1.33244750e+07f, 1.34310500e+07f, 1.34026000e+07f, 1.35098000e+07f, + 1.34807250e+07f, 1.35885500e+07f, 1.35588500e+07f, 1.36673000e+07f, + 1.36369750e+07f, 1.37460500e+07f, 1.37151000e+07f, 1.38248000e+07f, + 1.37932250e+07f, 1.39035500e+07f, 1.38713500e+07f, 1.39823000e+07f, + 1.39494750e+07f, 1.40610500e+07f, 1.40276000e+07f, 1.41398000e+07f, + 1.41057250e+07f, 1.42185500e+07f, 1.41838500e+07f, 1.42973000e+07f, + 1.42619750e+07f, 1.43760500e+07f, 1.43401000e+07f, 1.44548000e+07f, + 1.44182250e+07f, 1.45335500e+07f, 1.44963500e+07f, 1.46123000e+07f, + 1.45744750e+07f, 1.46910500e+07f, 1.46526000e+07f, 1.47698000e+07f, + 1.47307250e+07f, 1.48485500e+07f, 1.48088500e+07f, 1.49273000e+07f, + 1.48869750e+07f, 1.50060500e+07f, 1.49651000e+07f, 1.50848000e+07f, + 1.50432250e+07f, 1.51635500e+07f, 1.51213500e+07f, 1.52423000e+07f, + 1.51994750e+07f, 1.53210500e+07f, 1.52776000e+07f, 1.53998000e+07f, + 1.53557250e+07f, 1.54785500e+07f, 1.54338500e+07f, 1.55573000e+07f, + 1.55119750e+07f, 1.56360500e+07f, 1.55901000e+07f, 1.57148000e+07f, + 1.56682250e+07f, 1.57935500e+07f, 1.57463500e+07f, 1.58723000e+07f, + 1.58244750e+07f, 1.59510500e+07f, 1.59026000e+07f, 1.60298000e+07f, + 1.59807250e+07f, 1.61085500e+07f, 1.60588500e+07f, 1.61873000e+07f, + 1.61369750e+07f, 1.62660500e+07f, 1.62151000e+07f, 1.63448000e+07f, + 1.62932250e+07f, 1.64235500e+07f, 1.63713500e+07f, 1.65023000e+07f, + 1.64494750e+07f, 1.65810500e+07f, 1.65276000e+07f, 1.66598000e+07f, + 1.66057250e+07f, 1.67385500e+07f, 1.66838500e+07f, 1.68173000e+07f, + 1.67619750e+07f, 1.68960500e+07f, 1.68401000e+07f, 1.69748000e+07f, + 1.69182250e+07f, 1.70535500e+07f, 1.69963500e+07f, 1.71323000e+07f, + 1.70744750e+07f, 1.72110500e+07f, 1.71526000e+07f, 1.72898000e+07f, + 1.72307250e+07f, 1.73685500e+07f, 1.73088500e+07f, 1.74473000e+07f, + 1.73869750e+07f, 1.75260500e+07f, 1.74651000e+07f, 1.76048000e+07f, + 1.75432250e+07f, 1.76835500e+07f, 2.46688500e+07f, 2.48098000e+07f, + 2.47782250e+07f, 2.49198000e+07f, 2.48876000e+07f, 2.50298000e+07f, + 2.49969750e+07f, 2.51398000e+07f, 2.51063500e+07f, 2.52498000e+07f, + 2.52157250e+07f, 2.53598000e+07f, 2.53251000e+07f, 2.54698000e+07f, + 2.54344750e+07f, 2.55798000e+07f, 2.55438500e+07f, 2.56898000e+07f, + 2.56532250e+07f, 2.57998000e+07f, 2.57626000e+07f, 2.59098000e+07f, + 2.58719750e+07f, 2.60198000e+07f, 2.59813500e+07f, 2.61298000e+07f, + 2.60907250e+07f, 2.62398000e+07f, 2.62001000e+07f, 2.63498000e+07f, + 2.63094750e+07f, 2.64598000e+07f, 2.64188500e+07f, 2.65698000e+07f, + 2.65282250e+07f, 2.66798000e+07f, 2.66376000e+07f, 2.67898000e+07f, + 2.67469750e+07f, 2.68998000e+07f, 2.68563500e+07f, 2.70098000e+07f, + 2.69657250e+07f, 2.71198000e+07f, 2.70751000e+07f, 2.72298000e+07f, + 2.71844750e+07f, 2.73398000e+07f, 2.72938500e+07f, 2.74498000e+07f, + 2.74032250e+07f, 2.75598000e+07f, 2.75126000e+07f, 2.76698000e+07f, + 2.76219750e+07f, 2.77798000e+07f, 2.77313500e+07f, 2.78898000e+07f, + 2.78407250e+07f, 2.79998000e+07f, 2.79501000e+07f, 2.81098000e+07f, + 2.80594750e+07f, 2.82198000e+07f, 2.81688500e+07f, 2.83298000e+07f, + 2.82782250e+07f, 2.84398000e+07f, 2.83876000e+07f, 2.85498000e+07f, + 2.84969750e+07f, 2.86598000e+07f, 2.86063500e+07f, 2.87698000e+07f, + 2.87157250e+07f, 2.88798000e+07f, 2.88251000e+07f, 2.89898000e+07f, + 2.89344750e+07f, 2.90998000e+07f, 2.90438500e+07f, 2.92098000e+07f, + 2.91532250e+07f, 2.93198000e+07f, 2.92626000e+07f, 2.94298000e+07f, + 2.93719750e+07f, 2.95398000e+07f, 2.94813500e+07f, 2.96498000e+07f, + 2.95907250e+07f, 2.97598000e+07f, 2.97001000e+07f, 2.98698000e+07f, + 2.98094750e+07f, 2.99798000e+07f, 2.99188500e+07f, 3.00898000e+07f, + 3.00282250e+07f, 3.01998000e+07f, 3.01376000e+07f, 3.03098000e+07f, + 3.02469750e+07f, 3.04198000e+07f, 3.03563500e+07f, 3.05298000e+07f, + 3.04657250e+07f, 3.06398000e+07f, 3.05751000e+07f, 3.07498000e+07f, + 3.06844750e+07f, 3.08598000e+07f, 3.07938500e+07f, 3.09698000e+07f, + 3.09032250e+07f, 3.10798000e+07f, 3.10126000e+07f, 3.11898000e+07f, + 3.11219750e+07f, 3.12998000e+07f, 3.12313500e+07f, 3.14098000e+07f, + 3.13407250e+07f, 3.15198000e+07f, 3.14501000e+07f, 3.16298000e+07f, + 3.15594750e+07f, 3.17398000e+07f, 3.16688500e+07f, 3.18498000e+07f, + 3.17782250e+07f, 3.19598000e+07f, 3.18876000e+07f, 3.20698000e+07f, + 3.19969750e+07f, 3.21798000e+07f, 3.21063500e+07f, 3.22898000e+07f, + 3.22157250e+07f, 3.23998000e+07f, 3.23251000e+07f, 3.25098000e+07f, + 3.24344750e+07f, 3.26198000e+07f, 3.25438500e+07f, 3.27298000e+07f, + 3.26532250e+07f, 3.28398000e+07f, 3.27626000e+07f, 3.29498000e+07}; + + auto a = NDArrayFactory::create('c', {2, 72, 25}); + for (int e = 0; e < a.lengthOf(); e++) a.p(e, e + 1); + + auto b = NDArrayFactory::create('c', {2, 25, 2}); + for (int e = 0; e < b.lengthOf(); e++) b.p(e, e + 1); + + NDArray exp(_expB, _expS); + + auto c = MmulHelper::mmul(&a, &b); + + ASSERT_TRUE(exp.isSameShape(c)); + ASSERT_TRUE(exp.equalsTo(c, 1e1)); + + delete c; } - TEST_F(NDArrayTest, TestNegSize1) { - auto array = NDArrayFactory::create('c', {2, 5, 7}); + auto array = NDArrayFactory::create('c', {2, 5, 7}); - ASSERT_EQ(7, array.sizeAt(-1)); - ASSERT_EQ(5, array.sizeAt(-2)); - ASSERT_EQ(2, array.sizeAt(-3)); + ASSERT_EQ(7, array.sizeAt(-1)); + ASSERT_EQ(5, array.sizeAt(-2)); + ASSERT_EQ(2, array.sizeAt(-3)); } ////////////////////////////////////////////////////////////////////// // not-in-place TEST_F(NDArrayTest, Permute1) { + Nd4jLong shape1[] = {3, 5, 10, 15, 150, 15, 1, 8192, 1, 99}; + Nd4jLong shape2[] = {3, 15, 5, 10, 1, 150, 15, 8192, 0, 99}; + const std::initializer_list perm = {2, 0, 1}; - Nd4jLong shape1[] = {3, 5, 10, 15, 150, 15, 1, 8192, 1, 99}; - Nd4jLong shape2[] = {3, 15, 5, 10, 1, 150, 15, 8192, 0, 99}; - const std::initializer_list perm = {2, 0, 1}; + NDArray arr1(shape1, true); + NDArray arr2(shape2, true); - NDArray arr1(shape1,true); - NDArray arr2(shape2,true); - - auto result = arr1.permute(perm); - ASSERT_TRUE(result.isSameShapeStrict(arr2)); + auto result = arr1.permute(perm); + ASSERT_TRUE(result.isSameShapeStrict(arr2)); } ////////////////////////////////////////////////////////////////////// // in-place TEST_F(NDArrayTest, Permute2) { + Nd4jLong shape1[] = {3, 5, 10, 15, 150, 15, 1, 8192, 1, 99}; + Nd4jLong shape2[] = {3, 15, 5, 10, 1, 150, 15, 8192, 0, 99}; + const std::initializer_list perm = {2, 0, 1}; - Nd4jLong shape1[] = {3, 5, 10, 15, 150, 15, 1, 8192, 1, 99}; - Nd4jLong shape2[] = {3, 15, 5, 10, 1, 150, 15, 8192, 0, 99}; - const std::initializer_list perm = {2, 0, 1}; - - NDArray arr1(shape1,true); - NDArray arr2(shape2,true); + NDArray arr1(shape1, true); + NDArray arr2(shape2, true); - ASSERT_TRUE(arr1.permutei(perm)); - ASSERT_TRUE(arr1.isSameShapeStrict(arr2)); + ASSERT_TRUE(arr1.permutei(perm)); + ASSERT_TRUE(arr1.isSameShapeStrict(arr2)); } TEST_F(NDArrayTest, RSubScalarTest1) { - auto array = NDArrayFactory::create('c', {1, 4}); - array.assign(2.0); + auto array = NDArrayFactory::create('c', {1, 4}); + array.assign(2.0); - auto result = NDArrayFactory::create('c', {1, 4}); + auto result = NDArrayFactory::create('c', {1, 4}); - array.applyScalar(scalar::ReverseSubtract, 1.0, result); + array.applyScalar(scalar::ReverseSubtract, 1.0, result); - ASSERT_NEAR(-1.0, result.meanNumber().e(0), 1e-5); + ASSERT_NEAR(-1.0, result.meanNumber().e(0), 1e-5); } TEST_F(NDArrayTest, BroadcastOpsTest1) { + auto x = NDArrayFactory::create('c', {5, 5}); + auto row = NDArrayFactory::linspace(1.0f, 5.0f, 5); + float *brow = new float[5]{1, 2, 3, 4, 5}; + auto bshape = new Nd4jLong[8]{2, 1, 5, 1, 1, 8192, 1, 99}; + float *ebuf = new float[25]{1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, + 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5}; + auto eshape = new Nd4jLong[8]{2, 5, 5, 5, 1, 8192, 1, 99}; + NDArray expRow(brow, bshape); + NDArray exp(ebuf, eshape); - auto x = NDArrayFactory::create('c', {5, 5}); - auto row = NDArrayFactory::linspace(1.0f, 5.0f, 5); - float *brow = new float[5]{1,2,3,4,5}; - auto bshape = new Nd4jLong[8]{2, 1, 5, 1, 1, 8192, 1, 99}; - float *ebuf = new float[25] {1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5}; - auto eshape = new Nd4jLong[8] {2, 5, 5, 5, 1, 8192, 1, 99}; - NDArray expRow(brow, bshape); - NDArray exp(ebuf, eshape); - - ASSERT_TRUE(row->equalsTo(&expRow)); - + ASSERT_TRUE(row->equalsTo(&expRow)); - x.applyBroadcast(broadcast::Add, {1}, *row, x); + x.applyBroadcast(broadcast::Add, {1}, *row, x); - //x.printBuffer("Result"); + // x.printBuffer("Result"); - ASSERT_TRUE(x.equalsTo(&exp)); + ASSERT_TRUE(x.equalsTo(&exp)); - delete[] brow; - delete[] bshape; - delete[] ebuf; - delete[] eshape; - delete row; + delete[] brow; + delete[] bshape; + delete[] ebuf; + delete[] eshape; + delete row; } TEST_F(NDArrayTest, TestIndexedPut2) { - auto x = NDArrayFactory::create('f', {2, 2}); - //x.printShapeInfo("x shape"); - x.p(1, 1.0f); + auto x = NDArrayFactory::create('f', {2, 2}); + // x.printShapeInfo("x shape"); + x.p(1, 1.0f); - //x.printBuffer("after"); - ASSERT_NEAR(reinterpret_cast(x.buffer())[2], 1.0, 1e-5); + // x.printBuffer("after"); + ASSERT_NEAR(reinterpret_cast(x.buffer())[2], 1.0, 1e-5); } TEST_F(NDArrayTest, TestIndexedPut3) { - auto x = NDArrayFactory::create('c', {2, 2}); - x.p(1, 1.0f); + auto x = NDArrayFactory::create('c', {2, 2}); + x.p(1, 1.0f); - //x.printBuffer("after"); - ASSERT_NEAR(reinterpret_cast(x.buffer())[1], 1.0, 1e-5); + // x.printBuffer("after"); + ASSERT_NEAR(reinterpret_cast(x.buffer())[1], 1.0, 1e-5); } TEST_F(NDArrayTest, TestIndexedPut4) { - auto x = NDArrayFactory::create('f', {2, 2}); - x.p(0, 1, 1.0f); + auto x = NDArrayFactory::create('f', {2, 2}); + x.p(0, 1, 1.0f); - //x.printBuffer("after"); - ASSERT_NEAR(reinterpret_cast(x.buffer())[2], 1.0, 1e-5); + // x.printBuffer("after"); + ASSERT_NEAR(reinterpret_cast(x.buffer())[2], 1.0, 1e-5); } - TEST_F(NDArrayTest, TestIndexedPut5) { - auto x = NDArrayFactory::create('c', {2, 2}); - x.p(0, 1, 1.0f); + auto x = NDArrayFactory::create('c', {2, 2}); + x.p(0, 1, 1.0f); - //x.printBuffer("after"); - ASSERT_NEAR(x.bufferAsT()[1], 1.0, 1e-5); + // x.printBuffer("after"); + ASSERT_NEAR(x.bufferAsT()[1], 1.0, 1e-5); } TEST_F(NDArrayTest, TestAllTensors1) { - auto matrix = NDArrayFactory::create('c', {3, 5}); + auto matrix = NDArrayFactory::create('c', {3, 5}); - ResultSet rows = matrix.allTensorsAlongDimension({1}); + ResultSet rows = matrix.allTensorsAlongDimension({1}); - ASSERT_EQ(3, rows.size()); + ASSERT_EQ(3, rows.size()); } - TEST_F(NDArrayTest, TestIndexing1) { - auto matrix = NDArrayFactory::create('c', {5, 5}); - for (int e = 0; e < matrix.lengthOf(); e++) - matrix.p(e, (float) e); + auto matrix = NDArrayFactory::create('c', {5, 5}); + for (int e = 0; e < matrix.lengthOf(); e++) matrix.p(e, (float)e); - auto sub = matrix({2,4, 0,0}, true); + auto sub = matrix({2, 4, 0, 0}, true); - ASSERT_EQ(2, sub.rows()); - ASSERT_EQ(5, sub.columns()); + ASSERT_EQ(2, sub.rows()); + ASSERT_EQ(5, sub.columns()); - ASSERT_NEAR(10, sub.e(0), 1e-5); + ASSERT_NEAR(10, sub.e(0), 1e-5); } - TEST_F(NDArrayTest, TestIndexing2) { - auto matrix = NDArrayFactory::create('c', {2, 5, 4, 4}); - matrix.linspace(0); + auto matrix = NDArrayFactory::create('c', {2, 5, 4, 4}); + matrix.linspace(0); - auto sub = matrix({0,0, 2,4, 0,0, 0,0}, true); + auto sub = matrix({0, 0, 2, 4, 0, 0, 0, 0}, true); - ASSERT_EQ(2, sub.sizeAt(0)); - ASSERT_EQ(2, sub.sizeAt(1)); - ASSERT_EQ(4, sub.sizeAt(2)); - ASSERT_EQ(4, sub.sizeAt(3)); + ASSERT_EQ(2, sub.sizeAt(0)); + ASSERT_EQ(2, sub.sizeAt(1)); + ASSERT_EQ(4, sub.sizeAt(2)); + ASSERT_EQ(4, sub.sizeAt(3)); - ASSERT_EQ(64, sub.lengthOf()); - ASSERT_NEAR(32, sub.e(0), 1e-5); - ASSERT_NEAR(112, sub.e(32), 1e-5); + ASSERT_EQ(64, sub.lengthOf()); + ASSERT_NEAR(32, sub.e(0), 1e-5); + ASSERT_NEAR(112, sub.e(32), 1e-5); } TEST_F(NDArrayTest, TestIndexing3) { - auto matrix = NDArrayFactory::create('c', {5, 5}); - matrix.linspace(0); + auto matrix = NDArrayFactory::create('c', {5, 5}); + matrix.linspace(0); - auto sub = matrix({2,4, 0,0}); + auto sub = matrix({2, 4, 0, 0}); - ASSERT_EQ(2, sub.rows()); - ASSERT_EQ(5, sub.columns()); + ASSERT_EQ(2, sub.rows()); + ASSERT_EQ(5, sub.columns()); - ASSERT_NEAR(10, sub.e(0), 1e-5); + ASSERT_NEAR(10, sub.e(0), 1e-5); } - TEST_F(NDArrayTest, TestIndexing4) { - auto matrix = NDArrayFactory::create('c', {2, 5, 4, 4}); - matrix.linspace(0); + auto matrix = NDArrayFactory::create('c', {2, 5, 4, 4}); + matrix.linspace(0); - auto sub = matrix({0,0, 2,4, 0,0, 0,0}); + auto sub = matrix({0, 0, 2, 4, 0, 0, 0, 0}); - ASSERT_EQ(2, sub.sizeAt(0)); - ASSERT_EQ(2, sub.sizeAt(1)); - ASSERT_EQ(4, sub.sizeAt(2)); - ASSERT_EQ(4, sub.sizeAt(3)); + ASSERT_EQ(2, sub.sizeAt(0)); + ASSERT_EQ(2, sub.sizeAt(1)); + ASSERT_EQ(4, sub.sizeAt(2)); + ASSERT_EQ(4, sub.sizeAt(3)); - - ASSERT_EQ(64, sub.lengthOf()); - ASSERT_NEAR(32, sub.e(0), 1e-5); - ASSERT_NEAR(112, sub.e(32), 1e-5); + ASSERT_EQ(64, sub.lengthOf()); + ASSERT_NEAR(32, sub.e(0), 1e-5); + ASSERT_NEAR(112, sub.e(32), 1e-5); } TEST_F(NDArrayTest, TestReshapeNegative1) { - std::unique_ptr array(NDArrayFactory::create_('c', {2, 3, 4, 64})); + std::unique_ptr array( + NDArrayFactory::create_('c', {2, 3, 4, 64})); - array->reshapei('c', {-1, 64}); + array->reshapei('c', {-1, 64}); - ASSERT_EQ(24, array->sizeAt(0)); - ASSERT_EQ(64, array->sizeAt(1)); + ASSERT_EQ(24, array->sizeAt(0)); + ASSERT_EQ(64, array->sizeAt(1)); } TEST_F(NDArrayTest, TestReshapeNegative2) { - std::unique_ptr array(NDArrayFactory::create_('c', {2, 3, 4, 64})); + std::unique_ptr array( + NDArrayFactory::create_('c', {2, 3, 4, 64})); - auto reshaped = array->reshape('c', {-1, 64}); + auto reshaped = array->reshape('c', {-1, 64}); - ASSERT_EQ(24, reshaped.sizeAt(0)); - ASSERT_EQ(64, reshaped.sizeAt(1)); + ASSERT_EQ(24, reshaped.sizeAt(0)); + ASSERT_EQ(64, reshaped.sizeAt(1)); } ////////////////////////////////////////////////////////////////////// // TEST_F(NDArrayTest, SVD1) { // double arrA[8] = {1, 2, 3, 4, 5, 6, 7, 8}; -// double arrU[8] = {-0.822647, 0.152483, -0.421375, 0.349918, -0.020103, 0.547354, 0.381169, 0.744789}; -// double arrS[2] = {0.626828, 14.269095}; +// double arrU[8] = {-0.822647, 0.152483, -0.421375, 0.349918, -0.020103, +// 0.547354, 0.381169, 0.744789}; double arrS[2] = {0.626828, 14.269095}; // double arrVt[4] = {0.767187,-0.641423, 0.641423, 0.767187}; // int shapeA[8] = {2, 4, 2, 2, 1, 0, 1, 99}; @@ -1448,9 +1520,11 @@ TEST_F(NDArrayTest, TestReshapeNegative2) { // TEST_F(NDArrayTest, SVD2) { // double arrA[6] = {1, 2, 3, 4, 5, 6}; -// double arrU[6] = {-0.386318, -0.922366, 0.000000, -0.922366, 0.386318, 0.000000}; -// double arrS[3] = {9.508032, 0.77287, 0.000}; -// double arrVt[9] = {-0.428667, -0.566307, -0.703947, 0.805964, 0.112382, -0.581199, 0.408248, -0.816497, 0.408248}; +// double arrU[6] = {-0.386318, -0.922366, 0.000000, -0.922366, 0.386318, +// 0.000000}; double arrS[3] = {9.508032, 0.77287, 0.000}; double arrVt[9] +// = +// {-0.428667, -0.566307, -0.703947, 0.805964, 0.112382, -0.581199, 0.408248, +// -0.816497, 0.408248}; // int shapeA[8] = {2, 2, 3, 3, 1, 0, 1, 99}; // int shapeS[8] = {2, 1, 3, 3, 1, 0, 1, 99}; @@ -1473,8 +1547,8 @@ TEST_F(NDArrayTest, TestReshapeNegative2) { // TEST_F(NDArrayTest, SVD3) { // double arrA[8] = {1, 2, 3, 4, 5, 6, 7, 8}; -// double arrU[8] = {-0.822647, 0.152483, -0.421375, 0.349918, -0.020103, 0.547354, 0.381169, 0.744789}; -// double arrS[2] = {0.626828, 14.269095}; +// double arrU[8] = {-0.822647, 0.152483, -0.421375, 0.349918, -0.020103, +// 0.547354, 0.381169, 0.744789}; double arrS[2] = {0.626828, 14.269095}; // double arrVt[4] = {0.767187,-0.641423, 0.641423, 0.767187}; // int shapeA[8] = {2, 4, 2, 2, 1, 0, 1, 99}; @@ -1498,9 +1572,11 @@ TEST_F(NDArrayTest, TestReshapeNegative2) { // TEST_F(NDArrayTest, SVD4) { // double arrA[6] = {1, 2, 3, 4, 5, 6}; -// double arrU[6] = {-0.386318, -0.922366, 0.000000, -0.922366, 0.386318, 0.000000}; -// double arrS[3] = {9.508032, 0.77287, 0.000}; -// double arrVt[9] = {-0.428667, -0.566307, -0.703947, 0.805964, 0.112382, -0.581199, 0.408248, -0.816497, 0.408248}; +// double arrU[6] = {-0.386318, -0.922366, 0.000000, -0.922366, 0.386318, +// 0.000000}; double arrS[3] = {9.508032, 0.77287, 0.000}; double arrVt[9] +// = +// {-0.428667, -0.566307, -0.703947, 0.805964, 0.112382, -0.581199, 0.408248, +// -0.816497, 0.408248}; // int shapeA[8] = {2, 2, 3, 3, 1, 0, 1, 99}; // int shapeS[8] = {2, 1, 3, 3, 1, 0, 1, 99}; @@ -1521,1176 +1597,1171 @@ TEST_F(NDArrayTest, TestReshapeNegative2) { //////////////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestStdDev1) { - auto array = NDArrayFactory::create('c', {1, 5}); - for (int e = 0; e < array.lengthOf(); e++) - array.p(e, e+1); + auto array = NDArrayFactory::create('c', {1, 5}); + for (int e = 0; e < array.lengthOf(); e++) array.p(e, e + 1); - auto std = array.varianceNumber(variance::SummaryStatsStandardDeviation, true).e(0); - ASSERT_NEAR(std, 1.58109, 1e-4); + auto std = array.varianceNumber(variance::SummaryStatsStandardDeviation, true) + .e(0); + ASSERT_NEAR(std, 1.58109, 1e-4); } //////////////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestStdDev2) { - auto array = NDArrayFactory::create('c', {5, 6}); - auto tad = array(0, {1}); + auto array = NDArrayFactory::create('c', {5, 6}); + auto tad = array(0, {1}); - ASSERT_EQ(5, tad.lengthOf()); + ASSERT_EQ(5, tad.lengthOf()); - for (int e = 0; e < tad.lengthOf(); e++) - tad.p(e, e+1); + for (int e = 0; e < tad.lengthOf(); e++) tad.p(e, e + 1); - ASSERT_NEAR(15, tad.sumNumber().e(0), 1e-5); + ASSERT_NEAR(15, tad.sumNumber().e(0), 1e-5); - auto std = tad.varianceNumber(variance::SummaryStatsStandardDeviation, true).e(0); - ASSERT_NEAR(std, 1.58109, 1e-4); + auto std = tad.varianceNumber(variance::SummaryStatsStandardDeviation, true) + .e(0); + ASSERT_NEAR(std, 1.58109, 1e-4); } TEST_F(NDArrayTest, TestStdDev3) { - auto array = NDArrayFactory::create('c', {1, 50000}); - for (int e = 0; e < array.lengthOf(); e++) - array.p(e, 1.f + (e%2?0.5f:-0.5f)); + auto array = NDArrayFactory::create('c', {1, 50000}); + for (int e = 0; e < array.lengthOf(); e++) + array.p(e, 1.f + (e % 2 ? 0.5f : -0.5f)); - auto std = array.varianceNumber(variance::SummaryStatsStandardDeviation, true).e(0); - // nd4j_printf("Variance is %f\n", std); - ASSERT_NEAR(std, 0.5f, 1.0e-5f); + auto std = array.varianceNumber(variance::SummaryStatsStandardDeviation, true) + .e(0); + // nd4j_printf("Variance is %f\n", std); + ASSERT_NEAR(std, 0.5f, 1.0e-5f); } TEST_F(NDArrayTest, TestStdDev4) { - auto array = NDArrayFactory::create('c', {1, 20000}); - float const ethalon = 1 / 3.f; - float x = ethalon; - int total = array.lengthOf(); - for (int e = 0; e < total; e++) { - array.p(e, 1.0f + (e % 2?ethalon:-ethalon)); - x *= (e % 2? 2.f: 0.5f); - } - x = 0.f; - for (int e = 0; e < total; ++e) { - x += array.e(e); - } - x /= array.lengthOf(); - float y = 0; - double M2 = 0; - for (int e = 0; e < total; ++e) { + auto array = NDArrayFactory::create('c', {1, 20000}); + float const ethalon = 1 / 3.f; + float x = ethalon; + int total = array.lengthOf(); + for (int e = 0; e < total; e++) { + array.p(e, 1.0f + (e % 2 ? ethalon : -ethalon)); + x *= (e % 2 ? 2.f : 0.5f); + } + x = 0.f; + for (int e = 0; e < total; ++e) { + x += array.e(e); + } + x /= array.lengthOf(); + float y = 0; + double M2 = 0; + for (int e = 0; e < total; ++e) { // y += sd::math::nd4j_abs(array(e) - x); - M2 += (array.e(e) - x) * (array.e(e) - x); - } - //y /= total; - M2 /= total; + M2 += (array.e(e) - x) * (array.e(e) - x); + } + // y /= total; + M2 /= total; - y = M2; - auto a = array.varianceNumber(variance::SummaryStatsStandardDeviation, false); - auto std = a.e(0); -// float bY = array.varianceNumber(); - float bY = 0.3333333f; - // nd4j_printf("Variance is %f, res is %f, internal is %f\n, deviance is %f(%f)\n", std, x, bY, y, sd::math::nd4j_sqrt(M2)); - ASSERT_NEAR(std, 0.3333333f, 1.0e-5f); + y = M2; + auto a = array.varianceNumber(variance::SummaryStatsStandardDeviation, false); + auto std = a.e(0); + // float bY = array.varianceNumber(); + float bY = 0.3333333f; + // nd4j_printf("Variance is %f, res is %f, internal is %f\n, deviance is + // %f(%f)\n", std, x, bY, y, sd::math::nd4j_sqrt(M2)); + ASSERT_NEAR(std, 0.3333333f, 1.0e-5f); } TEST_F(NDArrayTest, TestStdDev5) { - auto array = NDArrayFactory::create('c', {1, 10000}); //00000}); - auto arrayD = NDArrayFactory::create('c', {1, 10000}); //00000}); - for (int e = 0; e < array.lengthOf(); e++) { - array.p(e, 1.f + (e%2?1/5.f:-1/5.f)); - arrayD.p(e, 1.0 + (e%2?1/5.:-1/5.)); - } - float stdF = array.varianceNumber(variance::SummaryStatsStandardDeviation, false).e(0); - double stdD = arrayD.varianceNumber(variance::SummaryStatsStandardDeviation, false).e(0); - // nd4j_printf("Variance is %f(%f)\n", stdF, stdD); - ASSERT_NEAR(stdD, 0.2, 1.0e-8); // 1/5 = 0.2 - ASSERT_NEAR(stdF, 0.2f, 1.0e-5f); // 1/5 = 0.2 + auto array = NDArrayFactory::create('c', {1, 10000}); // 00000}); + auto arrayD = NDArrayFactory::create('c', {1, 10000}); // 00000}); + for (int e = 0; e < array.lengthOf(); e++) { + array.p(e, 1.f + (e % 2 ? 1 / 5.f : -1 / 5.f)); + arrayD.p(e, 1.0 + (e % 2 ? 1 / 5. : -1 / 5.)); + } + float stdF = + array.varianceNumber(variance::SummaryStatsStandardDeviation, false) + .e(0); + double stdD = + arrayD.varianceNumber(variance::SummaryStatsStandardDeviation, false) + .e(0); + // nd4j_printf("Variance is %f(%f)\n", stdF, stdD); + ASSERT_NEAR(stdD, 0.2, 1.0e-8); // 1/5 = 0.2 + ASSERT_NEAR(stdF, 0.2f, 1.0e-5f); // 1/5 = 0.2 } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestApplyIndexReduce1) { - float xBuff[] = {1, 5, 2, 12, 9, 3, 10, 7, 4, 11, 6, 8}; - Nd4jLong xShapeInfo[] = {3, 2, 3, 2, 6, 2, 1, 8192, 1, 99}; - std::vector dim = {0,1}; + float xBuff[] = {1, 5, 2, 12, 9, 3, 10, 7, 4, 11, 6, 8}; + Nd4jLong xShapeInfo[] = {3, 2, 3, 2, 6, 2, 1, 8192, 1, 99}; + std::vector dim = {0, 1}; - NDArray x(xBuff, xShapeInfo); - auto exp = NDArrayFactory::create({3, 1}); + NDArray x(xBuff, xShapeInfo); + auto exp = NDArrayFactory::create({3, 1}); - auto result = x.applyIndexReduce(indexreduce::IndexMax, dim); - ASSERT_TRUE(exp.isSameShapeStrict(result)); - ASSERT_TRUE(exp.equalsTo(result)); + auto result = x.applyIndexReduce(indexreduce::IndexMax, dim); + ASSERT_TRUE(exp.isSameShapeStrict(result)); + ASSERT_TRUE(exp.equalsTo(result)); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, applyReduce3Dot) { - float xBuff[] = {1, 2, 3, 4, 5, 6}; - float yBuff[] = {2, 2, 2, 2, 2, 2}; - Nd4jLong xShapeInfo[] = {2, 2, 3, 3, 1, 8192, 1, 99}; + float xBuff[] = {1, 2, 3, 4, 5, 6}; + float yBuff[] = {2, 2, 2, 2, 2, 2}; + Nd4jLong xShapeInfo[] = {2, 2, 3, 3, 1, 8192, 1, 99}; - NDArray x(xBuff, xShapeInfo); - NDArray y(yBuff, xShapeInfo); + NDArray x(xBuff, xShapeInfo); + NDArray y(yBuff, xShapeInfo); - auto result = x.applyReduce3(reduce3::Dot, y); - ASSERT_TRUE(result.lengthOf() == 1); - ASSERT_NEAR(42, result.e(0), 1e-5); + auto result = x.applyReduce3(reduce3::Dot, y); + ASSERT_TRUE(result.lengthOf() == 1); + ASSERT_NEAR(42, result.e(0), 1e-5); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, applyAllReduce3EuclideanDistance) { - float xBuff[] = {1, 2, 3, 4, 5, 6}; - float yBuff[] = {2, 2, 2, 2, 2, 2}; - float expBuff[] = {1.414214f, 1.414214f, 5.385165f, 5.385165f}; - Nd4jLong expShapeInfo[] = {2, 2, 2, 2, 1, 8192, 1, 99}; - Nd4jLong xShapeInfo[] = {2, 2, 3, 3, 1, 8192, 1, 99}; + float xBuff[] = {1, 2, 3, 4, 5, 6}; + float yBuff[] = {2, 2, 2, 2, 2, 2}; + float expBuff[] = {1.414214f, 1.414214f, 5.385165f, 5.385165f}; + Nd4jLong expShapeInfo[] = {2, 2, 2, 2, 1, 8192, 1, 99}; + Nd4jLong xShapeInfo[] = {2, 2, 3, 3, 1, 8192, 1, 99}; - NDArray x(xBuff, xShapeInfo); - NDArray y(yBuff, xShapeInfo); - auto exp = NDArrayFactory::create('c', {2, 2}, {1.414214f, 1.414214f, 5.385165f, 5.385165f}); + NDArray x(xBuff, xShapeInfo); + NDArray y(yBuff, xShapeInfo); + auto exp = NDArrayFactory::create( + 'c', {2, 2}, {1.414214f, 1.414214f, 5.385165f, 5.385165f}); - auto result = x.applyAllReduce3(reduce3::EuclideanDistance, y, {1}); + auto result = x.applyAllReduce3(reduce3::EuclideanDistance, y, {1}); - ASSERT_TRUE(exp.isSameShapeStrict(result)); - ASSERT_TRUE(exp.equalsTo(result)); + ASSERT_TRUE(exp.isSameShapeStrict(result)); + ASSERT_TRUE(exp.equalsTo(result)); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, applyReduce3EuclideanDistance) { - float xBuff[] = {1, 2, 3, 4, 5, 6}; - float yBuff[] = {2, 2, 2, 2, 2, 2}; - float expBuff[] = {1.414214f, 1.414214f, 5.385165f, 5.385165f}; - Nd4jLong expShapeInfo[] = {2, 2, 2, 2, 1, 8192, 1, 99}; - Nd4jLong xShapeInfo[] = {2, 2, 3, 3, 1, 8192, 1, 99}; + float xBuff[] = {1, 2, 3, 4, 5, 6}; + float yBuff[] = {2, 2, 2, 2, 2, 2}; + float expBuff[] = {1.414214f, 1.414214f, 5.385165f, 5.385165f}; + Nd4jLong expShapeInfo[] = {2, 2, 2, 2, 1, 8192, 1, 99}; + Nd4jLong xShapeInfo[] = {2, 2, 3, 3, 1, 8192, 1, 99}; - NDArray x(xBuff, xShapeInfo); - NDArray y(yBuff, xShapeInfo); - NDArray exp(expBuff, expShapeInfo); + NDArray x(xBuff, xShapeInfo); + NDArray y(yBuff, xShapeInfo); + NDArray exp(expBuff, expShapeInfo); - auto result = x.applyAllReduce3(reduce3::EuclideanDistance, y ,{1}); + auto result = x.applyAllReduce3(reduce3::EuclideanDistance, y, {1}); - ASSERT_TRUE(exp.isSameShapeStrict(result)); - ASSERT_TRUE(exp.equalsTo(result)); + ASSERT_TRUE(exp.isSameShapeStrict(result)); + ASSERT_TRUE(exp.equalsTo(result)); } - ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestVarianceAlongDimension1) { + float xBuff[] = {1, 2, 3, 4, 5, 6}; + float expBuff[] = {0.816497f, 0.816497f}; + Nd4jLong xShapeInfo[] = {2, 2, 3, 3, 1, 8192, 1, 99}; + Nd4jLong expShapeInfo[] = {1, 2, 1, 8192, 1, 99}; - float xBuff[] = {1, 2, 3, 4, 5, 6}; - float expBuff[] = {0.816497f, 0.816497f}; - Nd4jLong xShapeInfo[] = {2, 2, 3, 3, 1, 8192, 1, 99}; - Nd4jLong expShapeInfo[] = {1, 2, 1, 8192, 1, 99}; - - NDArray x(xBuff, xShapeInfo); - NDArray exp(expBuff, expShapeInfo); + NDArray x(xBuff, xShapeInfo); + NDArray exp(expBuff, expShapeInfo); - auto result = x.varianceAlongDimension(variance::SummaryStatsStandardDeviation, false, {1}); + auto result = x.varianceAlongDimension( + variance::SummaryStatsStandardDeviation, false, {1}); - ASSERT_TRUE(exp.isSameShapeStrict(result)); - ASSERT_TRUE(exp.equalsTo(result)); + ASSERT_TRUE(exp.isSameShapeStrict(result)); + ASSERT_TRUE(exp.equalsTo(result)); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestVarianceAlongDimension2) { - float xBuff[] = {1, 2, 3, 4, 5, 6}; - float expBuff[] = {0.666667f, 0.666667f}; - Nd4jLong xShapeInfo[] = {2, 2, 3, 3, 1, 8192, 1, 99}; - Nd4jLong expShapeInfo[] = {1, 2, 1, 8192, 1, 99}; + float xBuff[] = {1, 2, 3, 4, 5, 6}; + float expBuff[] = {0.666667f, 0.666667f}; + Nd4jLong xShapeInfo[] = {2, 2, 3, 3, 1, 8192, 1, 99}; + Nd4jLong expShapeInfo[] = {1, 2, 1, 8192, 1, 99}; + NDArray x(xBuff, xShapeInfo); + NDArray exp(expBuff, expShapeInfo); - NDArray x(xBuff, xShapeInfo); - NDArray exp(expBuff, expShapeInfo); - - auto result = x.varianceAlongDimension(variance::SummaryStatsVariance, false, {1}); - ASSERT_TRUE(exp.isSameShapeStrict(result)); - ASSERT_TRUE(exp.equalsTo(result)); + auto result = + x.varianceAlongDimension(variance::SummaryStatsVariance, false, {1}); + ASSERT_TRUE(exp.isSameShapeStrict(result)); + ASSERT_TRUE(exp.equalsTo(result)); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestVarianceAlongDimension3) { - - - NDArray x = NDArrayFactory::create('c', {10, 10});//(xBuff, xShapeInfo); - NDArray exp = NDArrayFactory::create('c', {10});//(expBuff, expShapeInfo); - x.linspace(1); // 1, 2, 3, ..., 100 - exp.assign(825.f); - auto result = x.varianceAlongDimension(variance::SummaryStatsVariance, false, {0}); - ASSERT_TRUE(exp.isSameShapeStrict(result)); - ASSERT_TRUE(exp.equalsTo(result)); + NDArray x = + NDArrayFactory::create('c', {10, 10}); //(xBuff, xShapeInfo); + NDArray exp = + NDArrayFactory::create('c', {10}); //(expBuff, expShapeInfo); + x.linspace(1); // 1, 2, 3, ..., 100 + exp.assign(825.f); + auto result = + x.varianceAlongDimension(variance::SummaryStatsVariance, false, {0}); + ASSERT_TRUE(exp.isSameShapeStrict(result)); + ASSERT_TRUE(exp.equalsTo(result)); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestVarianceAlongDimension4) { - - - NDArray x = NDArrayFactory::create('c', {12, 1, 12});//(xBuff, xShapeInfo); - NDArray exp = NDArrayFactory::create('c', {1,12});//(expBuff, expShapeInfo); - x.linspace(1); // 1, 2, 3, ..., 100 - exp.assign(1716.); - auto result = x.varianceAlongDimension(variance::SummaryStatsVariance, false, {0}); - ASSERT_TRUE(exp.isSameShapeStrict(result)); - ASSERT_TRUE(exp.equalsTo(result)); + NDArray x = + NDArrayFactory::create('c', {12, 1, 12}); //(xBuff, xShapeInfo); + NDArray exp = + NDArrayFactory::create('c', {1, 12}); //(expBuff, expShapeInfo); + x.linspace(1); // 1, 2, 3, ..., 100 + exp.assign(1716.); + auto result = + x.varianceAlongDimension(variance::SummaryStatsVariance, false, {0}); + ASSERT_TRUE(exp.isSameShapeStrict(result)); + ASSERT_TRUE(exp.equalsTo(result)); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestSubRowVector1) { - float xBuff[] = {6, 7, 8, 9}; - float yBuff[] = {1, 2}; - float expBuff[] = {5, 5, 7, 7}; - Nd4jLong xShapeInfo[] = {2, 2, 2, 2, 1, 8192, 1, 99}; - Nd4jLong yShapeInfo[] = {2, 1, 2, 2, 1, 8192, 1, 99}; + float xBuff[] = {6, 7, 8, 9}; + float yBuff[] = {1, 2}; + float expBuff[] = {5, 5, 7, 7}; + Nd4jLong xShapeInfo[] = {2, 2, 2, 2, 1, 8192, 1, 99}; + Nd4jLong yShapeInfo[] = {2, 1, 2, 2, 1, 8192, 1, 99}; - NDArray x(xBuff, xShapeInfo); - NDArray y(yBuff, yShapeInfo); - NDArray target(x); - NDArray exp(expBuff, xShapeInfo); + NDArray x(xBuff, xShapeInfo); + NDArray y(yBuff, yShapeInfo); + NDArray target(x); + NDArray exp(expBuff, xShapeInfo); - x.subRowVector(y, target); + x.subRowVector(y, target); - ASSERT_TRUE(exp.isSameShapeStrict(target)); - ASSERT_TRUE(exp.equalsTo(&target)); + ASSERT_TRUE(exp.isSameShapeStrict(target)); + ASSERT_TRUE(exp.equalsTo(&target)); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestDivRowVector1) { - float xBuff[] = {6, 8, 10, 12}; - float yBuff[] = {2, 4}; - float expBuff[] = {3, 2, 5, 3}; - Nd4jLong xShapeInfo[] = {2, 2, 2, 2, 1, 8192, 1, 99}; - Nd4jLong yShapeInfo[] = {2, 1, 2, 2, 1, 8192, 1, 99}; + float xBuff[] = {6, 8, 10, 12}; + float yBuff[] = {2, 4}; + float expBuff[] = {3, 2, 5, 3}; + Nd4jLong xShapeInfo[] = {2, 2, 2, 2, 1, 8192, 1, 99}; + Nd4jLong yShapeInfo[] = {2, 1, 2, 2, 1, 8192, 1, 99}; - NDArray x(xBuff, xShapeInfo); - NDArray y(yBuff, yShapeInfo); - NDArray target(x); - NDArray exp(expBuff, xShapeInfo); + NDArray x(xBuff, xShapeInfo); + NDArray y(yBuff, yShapeInfo); + NDArray target(x); + NDArray exp(expBuff, xShapeInfo); - x.divRowVector(y, target); + x.divRowVector(y, target); - ASSERT_TRUE(exp.isSameShapeStrict(target)); - ASSERT_TRUE(exp.equalsTo(&target)); + ASSERT_TRUE(exp.isSameShapeStrict(target)); + ASSERT_TRUE(exp.equalsTo(&target)); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestMulRowVector1) { - float xBuff[] = {6, 8, 10, 12}; - float yBuff[] = {2, 4}; - float expBuff[] = {12, 32, 20, 48}; - Nd4jLong xShapeInfo[] = {2, 2, 2, 2, 1, 8192, 1, 99}; - Nd4jLong yShapeInfo[] = {2, 1, 2, 2, 1, 8192, 1, 99}; + float xBuff[] = {6, 8, 10, 12}; + float yBuff[] = {2, 4}; + float expBuff[] = {12, 32, 20, 48}; + Nd4jLong xShapeInfo[] = {2, 2, 2, 2, 1, 8192, 1, 99}; + Nd4jLong yShapeInfo[] = {2, 1, 2, 2, 1, 8192, 1, 99}; - NDArray x(xBuff, xShapeInfo); - NDArray y(yBuff, yShapeInfo); - NDArray target(x); - NDArray exp(expBuff, xShapeInfo); + NDArray x(xBuff, xShapeInfo); + NDArray y(yBuff, yShapeInfo); + NDArray target(x); + NDArray exp(expBuff, xShapeInfo); - x.mulRowVector(y, target); + x.mulRowVector(y, target); - ASSERT_TRUE(exp.isSameShapeStrict(target)); - ASSERT_TRUE(exp.equalsTo(&target)); + ASSERT_TRUE(exp.isSameShapeStrict(target)); + ASSERT_TRUE(exp.equalsTo(&target)); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestTensorDotAgain_1) { - int sY = 1; - int sX = 1; - int pY = 0; - int pX = 0; - int iC = 2; - int oC = 2; - int kY = 3; - int kX = 3; - int iY = 2; - int iX = 2; - int oY = 6; - int oX = 6; - int eD = iC * oC; - int B = 2; - - /* - input = np.linspace(1, B * iC * iY * iX, B * iC * iY * iX).reshape(B, iC, iY, iX) - weights = np.linspace(1, iC * oC * kY * kX, iC * oC * kY * kX).reshape(iC, oC, kY, kX) - */ - double _expB[] = {96.0, 116.0, 136.0, 156.0, 256.0, 276.0, 296.0, 316.0, 102.0, 124.0, 146.0, 168.0, 278.0, 300.0, 322.0, 344.0, 108.0, 132.0, 156.0, 180.0, 300.0, 324.0, 348.0, 372.0, 114.0, 140.0, 166.0, 192.0, 322.0, 348.0, 374.0, 400.0, 120.0, 148.0, 176.0, 204.0, 344.0, 372.0, 400.0, 428.0, 126.0, 156.0, 186.0, 216.0, 366.0, 396.0, 426.0, 456.0, 132.0, 164.0, 196.0, 228.0, 388.0, 420.0, 452.0, 484.0, 138.0, 172.0, 206.0, 240.0, 410.0, 444.0, 478.0, 512.0, 144.0, 180.0, 216.0, 252.0, 432.0, 468.0, 504.0, 540.0, 150.0, 188.0, 226.0, 264.0, 454.0, 492.0, 530.0, 568.0, 156.0, 196.0, 236.0, 276.0, 476.0, 516.0, 556.0, 596.0, 162.0, 204.0, 246.0, 288.0, 498.0, 540.0, 582.0, 624.0, 168.0, 212.0, 256.0, 300.0, 520.0, 564.0, 608.0, 652.0, 174.0, 220.0, 266.0, 312.0, 542.0, 588.0, 634.0, 680.0, 180.0, 228.0, 276.0, 324.0, 564.0, 612.0, 660.0, 708.0, 186.0, 236.0, 286.0, 336.0, 586.0, 636.0, 686.0, 736.0, 192.0, 244.0, 296.0, 348.0, 608.0, 660.0, 712.0, 764.0, 198.0, 252.0, 306.0, 360.0, 630.0, 684.0, 738.0, 792.0}; - - Nd4jLong _expS[] = {6, 2, 3, 3, 2, 2, 2, 72, 24, 8, 4, 2, 1, 16384, 1, 99}; - NDArray exp(_expB, _expS, sd::LaunchContext ::defaultContext(), false); - - auto input = NDArrayFactory::create('c', {B, iC, iY, iX}); - auto weights = NDArrayFactory::create('c', {iC, oC, kY, kX}); - - input.linspace(1); - weights.linspace(1); - - auto result = MmulHelper::tensorDot(&weights, &input, {0}, {1}); - - //result->printShapeInfo("result shape"); - ASSERT_TRUE(exp.isSameShape(result)); - -// exp.printBuffer("Expctd buffer"); -// result->printBuffer("Result buffer"); - ASSERT_TRUE(exp.equalsTo(result)); - - delete result; + int sY = 1; + int sX = 1; + int pY = 0; + int pX = 0; + int iC = 2; + int oC = 2; + int kY = 3; + int kX = 3; + int iY = 2; + int iX = 2; + int oY = 6; + int oX = 6; + int eD = iC * oC; + int B = 2; + + /* + input = np.linspace(1, B * iC * iY * iX, B * iC * iY * iX).reshape(B, iC, iY, + iX) weights = np.linspace(1, iC * oC * kY * kX, iC * oC * kY * kX).reshape(iC, + oC, kY, kX) + */ + double _expB[] = { + 96.0, 116.0, 136.0, 156.0, 256.0, 276.0, 296.0, 316.0, 102.0, 124.0, + 146.0, 168.0, 278.0, 300.0, 322.0, 344.0, 108.0, 132.0, 156.0, 180.0, + 300.0, 324.0, 348.0, 372.0, 114.0, 140.0, 166.0, 192.0, 322.0, 348.0, + 374.0, 400.0, 120.0, 148.0, 176.0, 204.0, 344.0, 372.0, 400.0, 428.0, + 126.0, 156.0, 186.0, 216.0, 366.0, 396.0, 426.0, 456.0, 132.0, 164.0, + 196.0, 228.0, 388.0, 420.0, 452.0, 484.0, 138.0, 172.0, 206.0, 240.0, + 410.0, 444.0, 478.0, 512.0, 144.0, 180.0, 216.0, 252.0, 432.0, 468.0, + 504.0, 540.0, 150.0, 188.0, 226.0, 264.0, 454.0, 492.0, 530.0, 568.0, + 156.0, 196.0, 236.0, 276.0, 476.0, 516.0, 556.0, 596.0, 162.0, 204.0, + 246.0, 288.0, 498.0, 540.0, 582.0, 624.0, 168.0, 212.0, 256.0, 300.0, + 520.0, 564.0, 608.0, 652.0, 174.0, 220.0, 266.0, 312.0, 542.0, 588.0, + 634.0, 680.0, 180.0, 228.0, 276.0, 324.0, 564.0, 612.0, 660.0, 708.0, + 186.0, 236.0, 286.0, 336.0, 586.0, 636.0, 686.0, 736.0, 192.0, 244.0, + 296.0, 348.0, 608.0, 660.0, 712.0, 764.0, 198.0, 252.0, 306.0, 360.0, + 630.0, 684.0, 738.0, 792.0}; + + Nd4jLong _expS[] = {6, 2, 3, 3, 2, 2, 2, 72, 24, 8, 4, 2, 1, 16384, 1, 99}; + NDArray exp(_expB, _expS, sd::LaunchContext ::defaultContext(), false); + + auto input = NDArrayFactory::create('c', {B, iC, iY, iX}); + auto weights = NDArrayFactory::create('c', {iC, oC, kY, kX}); + + input.linspace(1); + weights.linspace(1); + + auto result = MmulHelper::tensorDot(&weights, &input, {0}, {1}); + + // result->printShapeInfo("result shape"); + ASSERT_TRUE(exp.isSameShape(result)); + + // exp.printBuffer("Expctd buffer"); + // result->printBuffer("Result buffer"); + ASSERT_TRUE(exp.equalsTo(result)); + + delete result; } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestBroadcast_1) { - double _expB[] = {1.000000, 1.000000, 1.000000, 1.000000, 2.000000, 2.000000, 2.000000, 2.000000, 3.000000, 3.000000, 3.000000, 3.000000, 1.000000, 1.000000, 1.000000, 1.000000, 2.000000, 2.000000, 2.000000, 2.000000, 3.000000, 3.000000, 3.000000, 3.000000}; - Nd4jLong _expS[] = {4, 2, 3, 2, 2, 12, 4, 2, 1, 16384, 1, 99}; - NDArray exp(_expB, _expS, sd::LaunchContext ::defaultContext(), false); + double _expB[] = {1.000000, 1.000000, 1.000000, 1.000000, 2.000000, 2.000000, + 2.000000, 2.000000, 3.000000, 3.000000, 3.000000, 3.000000, + 1.000000, 1.000000, 1.000000, 1.000000, 2.000000, 2.000000, + 2.000000, 2.000000, 3.000000, 3.000000, 3.000000, 3.000000}; + Nd4jLong _expS[] = {4, 2, 3, 2, 2, 12, 4, 2, 1, 16384, 1, 99}; + NDArray exp(_expB, _expS, sd::LaunchContext ::defaultContext(), false); - auto input = NDArrayFactory::create('c',{ 2, 3, 2, 2}); - auto bias = NDArrayFactory::create('c', {1, 3}); + auto input = NDArrayFactory::create('c', {2, 3, 2, 2}); + auto bias = NDArrayFactory::create('c', {1, 3}); - bias.linspace(1); + bias.linspace(1); - input.applyBroadcast(broadcast::Add, {1}, bias, input); + input.applyBroadcast(broadcast::Add, {1}, bias, input); - //input.printBuffer("result"); - ASSERT_TRUE(exp.equalsTo(&input)); + // input.printBuffer("result"); + ASSERT_TRUE(exp.equalsTo(&input)); } TEST_F(NDArrayTest, TestTranspose_11) { - auto x = NDArrayFactory::create('c', {2, 3, 4}); - x.transposei(); + auto x = NDArrayFactory::create('c', {2, 3, 4}); + x.transposei(); - ASSERT_EQ(4, x.sizeAt(0)); - ASSERT_EQ(3, x.sizeAt(1)); - ASSERT_EQ(2, x.sizeAt(2)); + ASSERT_EQ(4, x.sizeAt(0)); + ASSERT_EQ(3, x.sizeAt(1)); + ASSERT_EQ(2, x.sizeAt(2)); } - TEST_F(NDArrayTest, TestTranspose_12) { - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto y = x.transpose(); + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto y = x.transpose(); - ASSERT_EQ(4, y.sizeAt(0)); - ASSERT_EQ(3, y.sizeAt(1)); - ASSERT_EQ(2, y.sizeAt(2)); + ASSERT_EQ(4, y.sizeAt(0)); + ASSERT_EQ(3, y.sizeAt(1)); + ASSERT_EQ(2, y.sizeAt(2)); - ASSERT_EQ(2, x.sizeAt(0)); - ASSERT_EQ(3, x.sizeAt(1)); - ASSERT_EQ(4, x.sizeAt(2)); + ASSERT_EQ(2, x.sizeAt(0)); + ASSERT_EQ(3, x.sizeAt(1)); + ASSERT_EQ(4, x.sizeAt(2)); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestMMulMultiDim) { - const int bS=2; - const int K=3; - const int N=4; + const int bS = 2; + const int K = 3; + const int N = 4; - auto input = NDArrayFactory::create('c', {bS, K, N}); - auto weights = NDArrayFactory::create('c', {3*K, K}); - auto expected = NDArrayFactory::create('c', {bS, 3*K, N}, { 38, 44, 50, 56, 83, 98, 113, 128, 128, 152, 176, 200, 173, 206, 239, 272, 218, 260, 302, 344, 263, 314, 365, 416, 308, 368, 428, 488, 353, 422, 491, 560, 398, 476, 554, 632, 110, 116, 122, 128, 263, 278, 293, 308, 416, 440, 464, 488, 569, 602, 635, 668, 722, 764, 806, 848, 875, 926, 977, 1028, 1028, 1088, 1148, 1208, 1181, 1250, 1319, 1388, 1334, 1412, 1490, 1568}); + auto input = NDArrayFactory::create('c', {bS, K, N}); + auto weights = NDArrayFactory::create('c', {3 * K, K}); + auto expected = NDArrayFactory::create( + 'c', {bS, 3 * K, N}, + {38, 44, 50, 56, 83, 98, 113, 128, 128, 152, 176, 200, + 173, 206, 239, 272, 218, 260, 302, 344, 263, 314, 365, 416, + 308, 368, 428, 488, 353, 422, 491, 560, 398, 476, 554, 632, + 110, 116, 122, 128, 263, 278, 293, 308, 416, 440, 464, 488, + 569, 602, 635, 668, 722, 764, 806, 848, 875, 926, 977, 1028, + 1028, 1088, 1148, 1208, 1181, 1250, 1319, 1388, 1334, 1412, 1490, 1568}); - input.linspace(1); - weights.linspace(1); + input.linspace(1); + weights.linspace(1); - auto result = MmulHelper::mmul(&weights, &input, nullptr, 1., 0.); - // result must have such shape [bS x 3K x N] + auto result = MmulHelper::mmul(&weights, &input, nullptr, 1., 0.); + // result must have such shape [bS x 3K x N] - ASSERT_TRUE(result->isSameShape(&expected)); + ASSERT_TRUE(result->isSameShape(&expected)); - //result->printShapeInfo("result shape"); - // result->printBuffer("result buffer"); - ASSERT_TRUE(result->equalsTo(&expected)); - delete result; + // result->printShapeInfo("result shape"); + // result->printBuffer("result buffer"); + ASSERT_TRUE(result->equalsTo(&expected)); + delete result; } - TEST_F(NDArrayTest, AdditionOperator1) { + auto input1 = NDArrayFactory::create('c', {2, 2}); + auto input2 = NDArrayFactory::create('c', {2, 2}); + auto expected = NDArrayFactory::create('c', {2, 2}); - auto input1 = NDArrayFactory::create('c', {2,2}); - auto input2 = NDArrayFactory::create('c', {2,2}); - auto expected = NDArrayFactory::create('c', {2,2}); - - input1.assign(1.5); - input2.assign(2.); - expected.assign(3.5); + input1.assign(1.5); + input2.assign(2.); + expected.assign(3.5); - input2 = input1 + input2; - - ASSERT_TRUE(input2.equalsTo(&expected)); + input2 = input1 + input2; + ASSERT_TRUE(input2.equalsTo(&expected)); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestMatmMul_Again_1) { - auto a = NDArrayFactory::create('c', {3, 4, 1}); - auto b = NDArrayFactory::create('c', {3, 1, 5}); + auto a = NDArrayFactory::create('c', {3, 4, 1}); + auto b = NDArrayFactory::create('c', {3, 1, 5}); - a.linspace(1); - b.linspace(1); + a.linspace(1); + b.linspace(1); - float _expB[] = {1.f, 2.f, 3.f, 4.f, 5.f, 2.f, 4.f, 6.f, 8.f, 10.f, 3.f, 6.f, 9.f, 12.f, 15.f, 4.f, 8.f, 12.f, 16.f, 20.f, 30.f, 35.f, 40.f, 45.f, 50.f, 36.f, 42.f, 48.f, 54.f, 60.f, 42.f, 49.f, 56.f, 63.f, 70.f, 48.f, 56.f, 64.f, 72.f, 80.f, 99.f, 108.f, 117.f, 126.f, 135.f, 110.f, 120.f, 130.f, 140.f, 150.f, 121.f, 132.f, 143.f, 154.f, 165.f, 132.f, 144.f, 156.f, 168.f, 180.f}; - Nd4jLong _expS[] = {3, 3, 4, 5, 20, 5, 1, 8192, 1, 99}; - NDArray c(_expB, _expS, sd::LaunchContext ::defaultContext(), false); + float _expB[] = { + 1.f, 2.f, 3.f, 4.f, 5.f, 2.f, 4.f, 6.f, 8.f, 10.f, + 3.f, 6.f, 9.f, 12.f, 15.f, 4.f, 8.f, 12.f, 16.f, 20.f, + 30.f, 35.f, 40.f, 45.f, 50.f, 36.f, 42.f, 48.f, 54.f, 60.f, + 42.f, 49.f, 56.f, 63.f, 70.f, 48.f, 56.f, 64.f, 72.f, 80.f, + 99.f, 108.f, 117.f, 126.f, 135.f, 110.f, 120.f, 130.f, 140.f, 150.f, + 121.f, 132.f, 143.f, 154.f, 165.f, 132.f, 144.f, 156.f, 168.f, 180.f}; + Nd4jLong _expS[] = {3, 3, 4, 5, 20, 5, 1, 8192, 1, 99}; + NDArray c(_expB, _expS, sd::LaunchContext ::defaultContext(), false); - auto c_ = MmulHelper::mmul(&a, &b); + auto c_ = MmulHelper::mmul(&a, &b); - ASSERT_TRUE(c.isSameShape(c_)); - ASSERT_TRUE(c.equalsTo(c_)); + ASSERT_TRUE(c.isSameShape(c_)); + ASSERT_TRUE(c.equalsTo(c_)); - delete c_; + delete c_; } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestMatmMul_Again_2) { - auto a = NDArrayFactory::create('c', {2, 5, 4}); - auto b = NDArrayFactory::create('c', {2, 4, 1}); + auto a = NDArrayFactory::create('c', {2, 5, 4}); + auto b = NDArrayFactory::create('c', {2, 4, 1}); - a.linspace(1); - b.linspace(1); + a.linspace(1); + b.linspace(1); - double _expB[] = {30.f, 70.f, 110.f, 150.f, 190.f, 590.f, 694.f, 798.f, 902.f, 1006.f}; - Nd4jLong _expS[] = {3, 2, 5, 1, 5, 1, 1, 16384, 1, 99}; - NDArray c(_expB, _expS); + double _expB[] = {30.f, 70.f, 110.f, 150.f, 190.f, + 590.f, 694.f, 798.f, 902.f, 1006.f}; + Nd4jLong _expS[] = {3, 2, 5, 1, 5, 1, 1, 16384, 1, 99}; + NDArray c(_expB, _expS); - auto c_ = MmulHelper::mmul(&a, &b); + auto c_ = MmulHelper::mmul(&a, &b); - ASSERT_TRUE(c.isSameShape(c_)); + ASSERT_TRUE(c.isSameShape(c_)); - ASSERT_TRUE(c.equalsTo(c_)); + ASSERT_TRUE(c.equalsTo(c_)); - delete c_; + delete c_; } ////////////////////////////////////////////////////////////////////// -TEST_F(NDArrayTest, Operator_Plus_Test_1) -{ - double expBuff[] = {2., 3, 3., 4., 4., 5, 5., 6., 6., 7, 7., 8.}; +TEST_F(NDArrayTest, Operator_Plus_Test_1) { + double expBuff[] = {2., 3, 3., 4., 4., 5, 5., 6., 6., 7, 7., 8.}; - auto x = NDArrayFactory::create('c', {3, 1, 2}); - auto y = NDArrayFactory::create('c', {2, 1}); - auto expected = NDArrayFactory::create(expBuff, 'c', {3, 2, 2}); + auto x = NDArrayFactory::create('c', {3, 1, 2}); + auto y = NDArrayFactory::create('c', {2, 1}); + auto expected = NDArrayFactory::create(expBuff, 'c', {3, 2, 2}); - //x.printShapeInfo("x shape"); - //y.printShapeInfo("y shape"); - //expected.printShapeInfo("e shape"); - //expected.printIndexedBuffer("e"); + // x.printShapeInfo("x shape"); + // y.printShapeInfo("y shape"); + // expected.printShapeInfo("e shape"); + // expected.printIndexedBuffer("e"); - x.linspace(1); - y.linspace(1); + x.linspace(1); + y.linspace(1); - auto result = x + y; + auto result = x + y; - //result.printIndexedBuffer("result"); + // result.printIndexedBuffer("result"); - ASSERT_TRUE(expected.isSameShape(&result)); - ASSERT_TRUE(expected.equalsTo(&result)); + ASSERT_TRUE(expected.isSameShape(&result)); + ASSERT_TRUE(expected.equalsTo(&result)); } - ////////////////////////////////////////////////////////////////////// -TEST_F(NDArrayTest, Operator_Plus_Test_2) -{ - double expBuff[] = {2., 3, 3., 4., 4., 5, 5., 6., 6., 7, 7., 8.}; +TEST_F(NDArrayTest, Operator_Plus_Test_2) { + double expBuff[] = {2., 3, 3., 4., 4., 5, 5., 6., 6., 7, 7., 8.}; - auto x = NDArrayFactory::create('c', {3, 2, 1}); - auto y = NDArrayFactory::create('c', {1, 2}); - auto expected = NDArrayFactory::create(expBuff, 'c', {3, 2, 2}); + auto x = NDArrayFactory::create('c', {3, 2, 1}); + auto y = NDArrayFactory::create('c', {1, 2}); + auto expected = NDArrayFactory::create(expBuff, 'c', {3, 2, 2}); - x.linspace(1); - y.linspace(1); + x.linspace(1); + y.linspace(1); - auto result = x + y; - // result.printIndexedBuffer(); + auto result = x + y; + // result.printIndexedBuffer(); - ASSERT_TRUE(expected.isSameShape(&result)); - ASSERT_TRUE(expected.equalsTo(&result)); + ASSERT_TRUE(expected.isSameShape(&result)); + ASSERT_TRUE(expected.equalsTo(&result)); } - ////////////////////////////////////////////////////////////////////// -TEST_F(NDArrayTest, Operator_Plus_Test_3) -{ - double expBuff[] = {2., 3, 3., 4., 4., 5, 5., 6., 6., 7, 7., 8.}; +TEST_F(NDArrayTest, Operator_Plus_Test_3) { + double expBuff[] = {2., 3, 3., 4., 4., 5, 5., 6., 6., 7, 7., 8.}; - auto x = NDArrayFactory::create('c', {3, 2, 1}); - auto y = NDArrayFactory::create('c', {1, 2}); - auto expected = NDArrayFactory::create(expBuff, 'c', {3, 2, 2}); + auto x = NDArrayFactory::create('c', {3, 2, 1}); + auto y = NDArrayFactory::create('c', {1, 2}); + auto expected = NDArrayFactory::create(expBuff, 'c', {3, 2, 2}); - x.linspace(1); - y.linspace(1); + x.linspace(1); + y.linspace(1); - auto result = x + y; - // result.printIndexedBuffer(); - ASSERT_TRUE(expected.isSameShape(&result)); - ASSERT_TRUE(expected.equalsTo(&result)); + auto result = x + y; + // result.printIndexedBuffer(); + ASSERT_TRUE(expected.isSameShape(&result)); + ASSERT_TRUE(expected.equalsTo(&result)); } ////////////////////////////////////////////////////////////////////// -TEST_F(NDArrayTest, Operator_Plus_Test_4) -{ - double expBuff[] = {11.,12., 12.,13., 13.,14., 14.,15., 13.,14., 14.,15., 15.,16., 16.,17., 15.,16., 16.,17., 17.,18., 18.,19.}; +TEST_F(NDArrayTest, Operator_Plus_Test_4) { + double expBuff[] = {11., 12., 12., 13., 13., 14., 14., 15., + 13., 14., 14., 15., 15., 16., 16., 17., + 15., 16., 16., 17., 17., 18., 18., 19.}; - auto x = NDArrayFactory::create('c', {3, 1, 2}); - auto y = NDArrayFactory::create('c', {4, 1}); - auto expected = NDArrayFactory::create(expBuff, 'c', {3, 4, 2}); + auto x = NDArrayFactory::create('c', {3, 1, 2}); + auto y = NDArrayFactory::create('c', {4, 1}); + auto expected = NDArrayFactory::create(expBuff, 'c', {3, 4, 2}); - x.linspace(10); - y.linspace(1); + x.linspace(10); + y.linspace(1); - auto result = x + y; + auto result = x + y; - ASSERT_TRUE(expected.isSameShape(&result)); - ASSERT_TRUE(expected.equalsTo(&result)); + ASSERT_TRUE(expected.isSameShape(&result)); + ASSERT_TRUE(expected.equalsTo(&result)); } - ////////////////////////////////////////////////////////////////////// -TEST_F(NDArrayTest, Operator_Minus_Test_1) -{ - double expBuff[] = {9. ,10., 10.,11., 11.,12., 12.,13., 17.,18., 18.,19., 19.,20., 20.,21., 25.,26., 26.,27., 27.,28., 28.,29.}; +TEST_F(NDArrayTest, Operator_Minus_Test_1) { + double expBuff[] = {9., 10., 10., 11., 11., 12., 12., 13., + 17., 18., 18., 19., 19., 20., 20., 21., + 25., 26., 26., 27., 27., 28., 28., 29.}; - auto x = NDArrayFactory::create('c', {3, 4, 2}); - auto y = NDArrayFactory::create('c', {4, 1}); - auto expected = NDArrayFactory::create(expBuff, 'c', {3, 4, 2}); + auto x = NDArrayFactory::create('c', {3, 4, 2}); + auto y = NDArrayFactory::create('c', {4, 1}); + auto expected = NDArrayFactory::create(expBuff, 'c', {3, 4, 2}); - x.linspace(10); - y.linspace(1); + x.linspace(10); + y.linspace(1); - auto result = x - y; + auto result = x - y; - ASSERT_TRUE(expected.isSameShape(&result)); - ASSERT_TRUE(expected.equalsTo(&result)); + ASSERT_TRUE(expected.isSameShape(&result)); + ASSERT_TRUE(expected.equalsTo(&result)); } - ////////////////////////////////////////////////////////////////////// -TEST_F(NDArrayTest, Operator_Minus_Test_2) -{ - double expBuff[] = {9., 8., 7., 6., 6., 5., 4., 3., 11.,10., 9., 8., 8., 7., 6., 5., 13.,12.,11.,10., 10., 9., 8., 7.}; +TEST_F(NDArrayTest, Operator_Minus_Test_2) { + double expBuff[] = {9., 8., 7., 6., 6., 5., 4., 3., 11., 10., 9., 8., + 8., 7., 6., 5., 13., 12., 11., 10., 10., 9., 8., 7.}; - auto x = NDArrayFactory::create('c', {3, 2, 1}); - auto y = NDArrayFactory::create('c', {1, 2, 4}); - auto expected = NDArrayFactory::create(expBuff, 'c', {3, 2, 4}); + auto x = NDArrayFactory::create('c', {3, 2, 1}); + auto y = NDArrayFactory::create('c', {1, 2, 4}); + auto expected = NDArrayFactory::create(expBuff, 'c', {3, 2, 4}); - x.linspace(10); - y.linspace(1); + x.linspace(10); + y.linspace(1); - auto result = x - y; + auto result = x - y; - ASSERT_TRUE(expected.isSameShape(&result)); - ASSERT_TRUE(expected.equalsTo(&result)); + ASSERT_TRUE(expected.isSameShape(&result)); + ASSERT_TRUE(expected.equalsTo(&result)); } ////////////////////////////////////////////////////////////////////// -TEST_F(NDArrayTest, Operator_Minus_Test_3) -{ - double expBuff[] = {9., 8., 7., 6., 6., 5., 4., 3., 11.,10., 9., 8., 8., 7., 6., 5., 13.,12.,11.,10., 10., 9., 8., 7.}; +TEST_F(NDArrayTest, Operator_Minus_Test_3) { + double expBuff[] = {9., 8., 7., 6., 6., 5., 4., 3., 11., 10., 9., 8., + 8., 7., 6., 5., 13., 12., 11., 10., 10., 9., 8., 7.}; - auto x = NDArrayFactory::create('c', {3, 2, 1}); - auto y = NDArrayFactory::create('c', {2, 4}); - auto expected = NDArrayFactory::create(expBuff, 'c', {3, 2, 4}); + auto x = NDArrayFactory::create('c', {3, 2, 1}); + auto y = NDArrayFactory::create('c', {2, 4}); + auto expected = NDArrayFactory::create(expBuff, 'c', {3, 2, 4}); - x.linspace(10); - y.linspace(1); + x.linspace(10); + y.linspace(1); - auto result = x - y; + auto result = x - y; - ASSERT_TRUE(expected.isSameShape(&result)); - ASSERT_TRUE(expected.equalsTo(&result)); + ASSERT_TRUE(expected.isSameShape(&result)); + ASSERT_TRUE(expected.equalsTo(&result)); } ////////////////////////////////////////////////////////////////////// -TEST_F(NDArrayTest, Operator_Minus_Test_4) -{ - double expBuff[] = {9.,10., 8., 9., 11.,12.,10.,11., 13.,14.,12.,13.}; +TEST_F(NDArrayTest, Operator_Minus_Test_4) { + double expBuff[] = {9., 10., 8., 9., 11., 12., 10., 11., 13., 14., 12., 13.}; - auto x = NDArrayFactory::create('c', {3, 1, 2}); - auto y = NDArrayFactory::create('c', {2, 1}); - auto expected = NDArrayFactory::create(expBuff, 'c', {3, 2, 2}); + auto x = NDArrayFactory::create('c', {3, 1, 2}); + auto y = NDArrayFactory::create('c', {2, 1}); + auto expected = NDArrayFactory::create(expBuff, 'c', {3, 2, 2}); - x.linspace(10); - y.linspace(1); + x.linspace(10); + y.linspace(1); - auto result = x - y; + auto result = x - y; - ASSERT_TRUE(expected.isSameShape(&result)); - ASSERT_TRUE(expected.equalsTo(&result)); + ASSERT_TRUE(expected.isSameShape(&result)); + ASSERT_TRUE(expected.equalsTo(&result)); } - ////////////////////////////////////////////////////////////////////// -TEST_F(NDArrayTest, Operator_Minus_Test_5) -{ - double expBuff[] = {9. ,8 ,10., 9., 11.,10, 12.,11., 13.,12, 14.,13.}; +TEST_F(NDArrayTest, Operator_Minus_Test_5) { + double expBuff[] = {9., 8, 10., 9., 11., 10, 12., 11., 13., 12, 14., 13.}; - auto x = NDArrayFactory::create('c', {3, 2, 1}); - auto y = NDArrayFactory::create('c', {1, 2}); - auto expected = NDArrayFactory::create(expBuff, 'c', {3, 2, 2}); + auto x = NDArrayFactory::create('c', {3, 2, 1}); + auto y = NDArrayFactory::create('c', {1, 2}); + auto expected = NDArrayFactory::create(expBuff, 'c', {3, 2, 2}); - x.linspace(10); - y.linspace(1); + x.linspace(10); + y.linspace(1); - auto result = x - y; + auto result = x - y; - ASSERT_TRUE(expected.isSameShape(&result)); - ASSERT_TRUE(expected.equalsTo(&result)); + ASSERT_TRUE(expected.isSameShape(&result)); + ASSERT_TRUE(expected.equalsTo(&result)); } - ////////////////////////////////////////////////////////////////////// -TEST_F(NDArrayTest, Operator_Minus_Test_6) -{ - double expBuff[] = {9., 8, 10., 9, 11.,10, 12.,11., 13.,12, 14.,13, 15.,14, 16.,15., 17.,16, 18.,17, 19.,18, 20.,19.}; +TEST_F(NDArrayTest, Operator_Minus_Test_6) { + double expBuff[] = {9., 8, 10., 9, 11., 10, 12., 11., 13., 12, 14., 13, + 15., 14, 16., 15., 17., 16, 18., 17, 19., 18, 20., 19.}; - auto x = NDArrayFactory::create('c', {3, 4, 1}); - auto y = NDArrayFactory::create('c', {1, 1, 2}); - auto expected = NDArrayFactory::create(expBuff, 'c', {3, 4, 2}); + auto x = NDArrayFactory::create('c', {3, 4, 1}); + auto y = NDArrayFactory::create('c', {1, 1, 2}); + auto expected = NDArrayFactory::create(expBuff, 'c', {3, 4, 2}); - x.linspace(10); - y.linspace(1); + x.linspace(10); + y.linspace(1); - auto result = x - y; + auto result = x - y; - ASSERT_TRUE(expected.isSameShape(&result)); - ASSERT_TRUE(expected.equalsTo(&result)); + ASSERT_TRUE(expected.isSameShape(&result)); + ASSERT_TRUE(expected.equalsTo(&result)); } ////////////////////////////////////////////////////////////////////// -TEST_F(NDArrayTest, Operator_Multiply_Test_1) -{ - double expBuff[] = {10., 11., 24., 26., 42., 45., 64., 68., 18., 19., 40., 42., 66., 69., 96.,100., 26., 27., 56., 58., 90., 93., 128.,132.}; +TEST_F(NDArrayTest, Operator_Multiply_Test_1) { + double expBuff[] = {10., 11., 24., 26., 42., 45., 64., 68., + 18., 19., 40., 42., 66., 69., 96., 100., + 26., 27., 56., 58., 90., 93., 128., 132.}; - auto x = NDArrayFactory::create('c', {3, 4, 2}); - auto y = NDArrayFactory::create('c', {4, 1}); - auto expected = NDArrayFactory::create(expBuff, 'c', {3, 4, 2}); + auto x = NDArrayFactory::create('c', {3, 4, 2}); + auto y = NDArrayFactory::create('c', {4, 1}); + auto expected = NDArrayFactory::create(expBuff, 'c', {3, 4, 2}); - x.linspace(10); - y.linspace(1); + x.linspace(10); + y.linspace(1); - auto result = x * y; + auto result = x * y; - ASSERT_TRUE(expected.isSameShape(&result)); - ASSERT_TRUE(expected.equalsTo(&result)); + ASSERT_TRUE(expected.isSameShape(&result)); + ASSERT_TRUE(expected.equalsTo(&result)); } ////////////////////////////////////////////////////////////////////// -TEST_F(NDArrayTest, Operator_Multiply_Test_2) -{ - double expBuff[] = {10.,20., 30., 40., 55.,66., 77., 88., 12.,24., 36., 48., 65.,78., 91.,104., 14.,28., 42., 56., 75.,90.,105.,120.}; +TEST_F(NDArrayTest, Operator_Multiply_Test_2) { + double expBuff[] = {10., 20., 30., 40., 55., 66., 77., 88., + 12., 24., 36., 48., 65., 78., 91., 104., + 14., 28., 42., 56., 75., 90., 105., 120.}; - auto x = NDArrayFactory::create('c', {3, 2, 1}); - auto y = NDArrayFactory::create('c', {1, 2, 4}); - auto expected = NDArrayFactory::create(expBuff, 'c', {3, 2, 4}); + auto x = NDArrayFactory::create('c', {3, 2, 1}); + auto y = NDArrayFactory::create('c', {1, 2, 4}); + auto expected = NDArrayFactory::create(expBuff, 'c', {3, 2, 4}); - x.linspace(10); - y.linspace(1); + x.linspace(10); + y.linspace(1); - auto result = x * y; + auto result = x * y; - ASSERT_TRUE(expected.isSameShape(&result)); - ASSERT_TRUE(expected.equalsTo(&result)); + ASSERT_TRUE(expected.isSameShape(&result)); + ASSERT_TRUE(expected.equalsTo(&result)); } - ////////////////////////////////////////////////////////////////////// -TEST_F(NDArrayTest, Operator_Multiply_Test_3) -{ - double expBuff[] = {10.,20., 30., 40.,55.,66., 77., 88., 12.,24., 36., 48.,65.,78., 91.,104., 14.,28., 42., 56.,75.,90.,105.,120.}; +TEST_F(NDArrayTest, Operator_Multiply_Test_3) { + double expBuff[] = {10., 20., 30., 40., 55., 66., 77., 88., + 12., 24., 36., 48., 65., 78., 91., 104., + 14., 28., 42., 56., 75., 90., 105., 120.}; - auto x = NDArrayFactory::create('c', {3, 2, 1}); - auto y = NDArrayFactory::create('c', {2, 4}); - auto expected = NDArrayFactory::create(expBuff, 'c', {3, 2, 4}); + auto x = NDArrayFactory::create('c', {3, 2, 1}); + auto y = NDArrayFactory::create('c', {2, 4}); + auto expected = NDArrayFactory::create(expBuff, 'c', {3, 2, 4}); - x.linspace(10); - y.linspace(1); + x.linspace(10); + y.linspace(1); - auto result = x * y; + auto result = x * y; - ASSERT_TRUE(expected.isSameShape(&result)); - ASSERT_TRUE(expected.equalsTo(&result)); + ASSERT_TRUE(expected.isSameShape(&result)); + ASSERT_TRUE(expected.equalsTo(&result)); } - ////////////////////////////////////////////////////////////////////// -TEST_F(NDArrayTest, Operator_Multiply_Test_4) -{ - double expBuff[] = {10.,11.,20.,22., 12.,13.,24.,26., 14.,15.,28.,30.}; +TEST_F(NDArrayTest, Operator_Multiply_Test_4) { + double expBuff[] = {10., 11., 20., 22., 12., 13., + 24., 26., 14., 15., 28., 30.}; - auto x = NDArrayFactory::create('c', {3, 1, 2}); - auto y = NDArrayFactory::create('c', {2, 1}); - auto expected = NDArrayFactory::create(expBuff, 'c', {3, 2, 2}); + auto x = NDArrayFactory::create('c', {3, 1, 2}); + auto y = NDArrayFactory::create('c', {2, 1}); + auto expected = NDArrayFactory::create(expBuff, 'c', {3, 2, 2}); - x.linspace(10); - y.linspace(1); + x.linspace(10); + y.linspace(1); - auto result = x * y; + auto result = x * y; - ASSERT_TRUE(expected.isSameShape(&result)); - ASSERT_TRUE(expected.equalsTo(&result)); + ASSERT_TRUE(expected.isSameShape(&result)); + ASSERT_TRUE(expected.equalsTo(&result)); } - - ////////////////////////////////////////////////////////////////////// -TEST_F(NDArrayTest, Operator_Multiply_Test_5) -{ - double expBuff[] = {10.,20.,11.,22., 12.,24.,13.,26., 14.,28.,15.,30.}; +TEST_F(NDArrayTest, Operator_Multiply_Test_5) { + double expBuff[] = {10., 20., 11., 22., 12., 24., + 13., 26., 14., 28., 15., 30.}; - auto x = NDArrayFactory::create('c', {3, 2, 1}); - auto y = NDArrayFactory::create('c', {1, 2}); - auto expected = NDArrayFactory::create(expBuff, 'c', {3, 2, 2}); + auto x = NDArrayFactory::create('c', {3, 2, 1}); + auto y = NDArrayFactory::create('c', {1, 2}); + auto expected = NDArrayFactory::create(expBuff, 'c', {3, 2, 2}); - x.linspace(10); - y.linspace(1); + x.linspace(10); + y.linspace(1); - auto result = x * y; + auto result = x * y; - ASSERT_TRUE(expected.isSameShape(&result)); - ASSERT_TRUE(expected.equalsTo(&result)); + ASSERT_TRUE(expected.isSameShape(&result)); + ASSERT_TRUE(expected.equalsTo(&result)); } - - ////////////////////////////////////////////////////////////////////// -TEST_F(NDArrayTest, Operator_Multiply_Test_6) -{ - double expBuff[] = {10,11.,12.,13.,28.,30.,32.,34.,54.,57.,60.,63.}; +TEST_F(NDArrayTest, Operator_Multiply_Test_6) { + double expBuff[] = {10, 11., 12., 13., 28., 30., + 32., 34., 54., 57., 60., 63.}; - auto x = NDArrayFactory::create('c', {3, 4, 1}); - auto y = NDArrayFactory::create('c', {3, 1, 1}); - auto expected = NDArrayFactory::create(expBuff, 'c', {3, 4, 1}); + auto x = NDArrayFactory::create('c', {3, 4, 1}); + auto y = NDArrayFactory::create('c', {3, 1, 1}); + auto expected = NDArrayFactory::create(expBuff, 'c', {3, 4, 1}); - x.linspace(10); - y.linspace(1); + x.linspace(10); + y.linspace(1); - auto result = x * y; + auto result = x * y; - ASSERT_TRUE(expected.isSameShape(&result)); - ASSERT_TRUE(expected.equalsTo(&result)); + ASSERT_TRUE(expected.isSameShape(&result)); + ASSERT_TRUE(expected.equalsTo(&result)); } - ////////////////////////////////////////////////////////////////////// -TEST_F(NDArrayTest, Operator_Divide_Test_1) -{ - double expBuff[] = {10. ,11. , 6. , 6.5 , 4.6666, 5. , 4. , 4.25 , 18. ,19. , 10. ,10.5 , 7.3333, 7.6666, 6. , 6.25 , 26. ,27. , 14. ,14.5 , 10. ,10.3333, 8. , 8.25}; +TEST_F(NDArrayTest, Operator_Divide_Test_1) { + double expBuff[] = {10., 11., 6., 6.5, 4.6666, 5., 4., 4.25, + 18., 19., 10., 10.5, 7.3333, 7.6666, 6., 6.25, + 26., 27., 14., 14.5, 10., 10.3333, 8., 8.25}; - auto x = NDArrayFactory::create('c', {3, 4, 2}); - auto y = NDArrayFactory::create('c', {4, 1}); - auto expected = NDArrayFactory::create(expBuff, 'c', {3, 4, 2}); + auto x = NDArrayFactory::create('c', {3, 4, 2}); + auto y = NDArrayFactory::create('c', {4, 1}); + auto expected = NDArrayFactory::create(expBuff, 'c', {3, 4, 2}); - x.linspace(10); - y.linspace(1); + x.linspace(10); + y.linspace(1); - auto result = x / y; + auto result = x / y; - ASSERT_TRUE(expected.isSameShape(&result)); - ASSERT_TRUE(expected.equalsTo(&result,1e-4)); + ASSERT_TRUE(expected.isSameShape(&result)); + ASSERT_TRUE(expected.equalsTo(&result, 1e-4)); } - ////////////////////////////////////////////////////////////////////// -TEST_F(NDArrayTest, Operator_Divide_Test_2) -{ - double expBuff[] = {10. ,5. ,3.333333,2.5 , 2.2,1.83333,1.571428,1.375, 12. ,6. ,4. ,3. , 2.6,2.16666,1.857142,1.625, 14. ,7. ,4.666666,3.5 , 3. ,2.5 ,2.142857,1.875}; +TEST_F(NDArrayTest, Operator_Divide_Test_2) { + double expBuff[] = {10., 5., 3.333333, 2.5, 2.2, 1.83333, 1.571428, 1.375, + 12., 6., 4., 3., 2.6, 2.16666, 1.857142, 1.625, + 14., 7., 4.666666, 3.5, 3., 2.5, 2.142857, 1.875}; - auto x = NDArrayFactory::create('c', {3, 2, 1}); - auto y = NDArrayFactory::create('c', {1, 2, 4}); - auto expected = NDArrayFactory::create(expBuff, 'c', {3, 2, 4}); + auto x = NDArrayFactory::create('c', {3, 2, 1}); + auto y = NDArrayFactory::create('c', {1, 2, 4}); + auto expected = NDArrayFactory::create(expBuff, 'c', {3, 2, 4}); - x.linspace(10); - y.linspace(1); + x.linspace(10); + y.linspace(1); - auto result = x / y; + auto result = x / y; - ASSERT_TRUE(expected.isSameShape(&result)); - ASSERT_TRUE(expected.equalsTo(&result)); + ASSERT_TRUE(expected.isSameShape(&result)); + ASSERT_TRUE(expected.equalsTo(&result)); } - ////////////////////////////////////////////////////////////////////// -TEST_F(NDArrayTest, Operator_Divide_Test_3) -{ - double expBuff[] = {10. ,5. ,3.333333,2.5 , 2.2,1.833333,1.571428,1.375, 12. ,6. ,4. ,3. , 2.6,2.166666,1.857142,1.625, 14. ,7. ,4.666666,3.5 , 3. ,2.5 ,2.142857,1.875}; +TEST_F(NDArrayTest, Operator_Divide_Test_3) { + double expBuff[] = {10., 5., 3.333333, 2.5, 2.2, 1.833333, 1.571428, 1.375, + 12., 6., 4., 3., 2.6, 2.166666, 1.857142, 1.625, + 14., 7., 4.666666, 3.5, 3., 2.5, 2.142857, 1.875}; - auto x = NDArrayFactory::create('c', {3, 2, 1}); - auto y = NDArrayFactory::create('c', {2, 4}); - auto expected = NDArrayFactory::create(expBuff, 'c', {3, 2, 4}); + auto x = NDArrayFactory::create('c', {3, 2, 1}); + auto y = NDArrayFactory::create('c', {2, 4}); + auto expected = NDArrayFactory::create(expBuff, 'c', {3, 2, 4}); - x.linspace(10); - y.linspace(1); + x.linspace(10); + y.linspace(1); - auto result = x / y; + auto result = x / y; - ASSERT_TRUE(expected.isSameShape(&result)); - ASSERT_TRUE(expected.equalsTo(&result)); + ASSERT_TRUE(expected.isSameShape(&result)); + ASSERT_TRUE(expected.equalsTo(&result)); } - ////////////////////////////////////////////////////////////////////// -TEST_F(NDArrayTest, Operator_Divide_Test_4) -{ - double expBuff[] = {10.,11., 5., 5.5, 12.,13., 6., 6.5, 14.,15., 7., 7.5}; +TEST_F(NDArrayTest, Operator_Divide_Test_4) { + double expBuff[] = {10., 11., 5., 5.5, 12., 13., 6., 6.5, 14., 15., 7., 7.5}; - auto x = NDArrayFactory::create('c', {3, 1, 2}); - auto y = NDArrayFactory::create('c', {2, 1}); - auto expected = NDArrayFactory::create(expBuff, 'c', {3, 2, 2}); + auto x = NDArrayFactory::create('c', {3, 1, 2}); + auto y = NDArrayFactory::create('c', {2, 1}); + auto expected = NDArrayFactory::create(expBuff, 'c', {3, 2, 2}); - x.linspace(10); - y.linspace(1); + x.linspace(10); + y.linspace(1); - auto result = x / y; + auto result = x / y; - ASSERT_TRUE(expected.isSameShape(&result)); - ASSERT_TRUE(expected.equalsTo(&result)); + ASSERT_TRUE(expected.isSameShape(&result)); + ASSERT_TRUE(expected.equalsTo(&result)); } - ////////////////////////////////////////////////////////////////////// -TEST_F(NDArrayTest, Operator_Divide_Test_5) -{ - double expBuff[] = {10.,5., 11.,5.5, 12.,6., 13.,6.5, 14.,7., 15.,7.5}; +TEST_F(NDArrayTest, Operator_Divide_Test_5) { + double expBuff[] = {10., 5., 11., 5.5, 12., 6., 13., 6.5, 14., 7., 15., 7.5}; - auto x = NDArrayFactory::create('c', {3, 2, 1}); - auto y = NDArrayFactory::create('c', {1, 2}); - auto expected = NDArrayFactory::create(expBuff, 'c', {3, 2, 2}); + auto x = NDArrayFactory::create('c', {3, 2, 1}); + auto y = NDArrayFactory::create('c', {1, 2}); + auto expected = NDArrayFactory::create(expBuff, 'c', {3, 2, 2}); - x.linspace(10); - y.linspace(1); + x.linspace(10); + y.linspace(1); - auto result = x / y; + auto result = x / y; - ASSERT_TRUE(expected.isSameShape(&result)); - ASSERT_TRUE(expected.equalsTo(&result)); + ASSERT_TRUE(expected.isSameShape(&result)); + ASSERT_TRUE(expected.equalsTo(&result)); } - ////////////////////////////////////////////////////////////////////// -TEST_F(NDArrayTest, Operator_Divide_Test_6) -{ - double expBuff[] = {10. , 5.5 , 4. , 3.25 ,14. , 7.5 , 5.333333, 4.25 ,18. , 9.5 , 6.666666, 5.25}; +TEST_F(NDArrayTest, Operator_Divide_Test_6) { + double expBuff[] = {10., 5.5, 4., 3.25, 14., 7.5, + 5.333333, 4.25, 18., 9.5, 6.666666, 5.25}; - auto x = NDArrayFactory::create('c', {3, 4, 1}); - auto y = NDArrayFactory::create('c', {1, 4, 1}); - auto expected = NDArrayFactory::create(expBuff, 'c', {3, 4, 1}); + auto x = NDArrayFactory::create('c', {3, 4, 1}); + auto y = NDArrayFactory::create('c', {1, 4, 1}); + auto expected = NDArrayFactory::create(expBuff, 'c', {3, 4, 1}); - x.linspace(10); - y.linspace(1); + x.linspace(10); + y.linspace(1); - auto result = x / y; + auto result = x / y; - ASSERT_TRUE(expected.isSameShape(&result)); - ASSERT_TRUE(expected.equalsTo(&result)); + ASSERT_TRUE(expected.isSameShape(&result)); + ASSERT_TRUE(expected.equalsTo(&result)); } - ////////////////////////////////////////////////////////////////////// -TEST_F(NDArrayTest, Operator_Divide_Test_7) -{ - double expBuff[] = {10., 5. ,3.333333,2.5 ,11., 5.5,3.666666,2.75,12., 6. ,4. ,3. ,13., 6.5,4.333333,3.25, 14., 7. ,4.666666,3.5 ,15., 7.5,5. ,3.75,16., 8. ,5.333333,4. ,17., 8.5,5.666666,4.25, 18., 9. ,6. ,4.5 ,19., 9.5,6.333333,4.75,20.,10. ,6.666666,5. ,21.,10.5,7. ,5.25}; +TEST_F(NDArrayTest, Operator_Divide_Test_7) { + double expBuff[] = { + 10., 5., 3.333333, 2.5, 11., 5.5, 3.666666, 2.75, 12., 6., 4., 3., + 13., 6.5, 4.333333, 3.25, 14., 7., 4.666666, 3.5, 15., 7.5, 5., 3.75, + 16., 8., 5.333333, 4., 17., 8.5, 5.666666, 4.25, 18., 9., 6., 4.5, + 19., 9.5, 6.333333, 4.75, 20., 10., 6.666666, 5., 21., 10.5, 7., 5.25}; - auto x = NDArrayFactory::create('c', {3, 4, 1}); - auto y = NDArrayFactory::create('c', {1, 1, 4}); - auto expected = NDArrayFactory::create(expBuff, 'c', {3, 4, 4}); + auto x = NDArrayFactory::create('c', {3, 4, 1}); + auto y = NDArrayFactory::create('c', {1, 1, 4}); + auto expected = NDArrayFactory::create(expBuff, 'c', {3, 4, 4}); - x.linspace(10); - y.linspace(1); + x.linspace(10); + y.linspace(1); - auto result = x / y; + auto result = x / y; - ASSERT_TRUE(expected.isSameShape(&result)); - ASSERT_TRUE(expected.equalsTo(&result)); + ASSERT_TRUE(expected.isSameShape(&result)); + ASSERT_TRUE(expected.equalsTo(&result)); } #ifndef __CUDABLAS__ ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, Test_Lambda_1) { - auto x = NDArrayFactory::create('c', {1, 5}, {1, 2, 3, 4, 5}); - auto exp = NDArrayFactory::create('c', {1, 5}, {4, 5, 6, 7, 8}); + auto x = NDArrayFactory::create('c', {1, 5}, {1, 2, 3, 4, 5}); + auto exp = NDArrayFactory::create('c', {1, 5}, {4, 5, 6, 7, 8}); - auto lambda = LAMBDA_F(_val) { - return _val + 3.0f; - }; + auto lambda = LAMBDA_F(_val) { return _val + 3.0f; }; - x.applyLambda(lambda, x); + x.applyLambda(lambda, x); - ASSERT_TRUE(exp.equalsTo(&x)); + ASSERT_TRUE(exp.equalsTo(&x)); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, Test_Lambda_2) { - auto x = NDArrayFactory::create('c', {1, 5}, {1, 2, 1, 2, 1}); - auto y = NDArrayFactory::create('c', {1, 5}, {1, 2, 1, 2, 1}); - auto exp = NDArrayFactory::create('c', {1, 5}, {3, 5, 3, 5, 3}); + auto x = NDArrayFactory::create('c', {1, 5}, {1, 2, 1, 2, 1}); + auto y = NDArrayFactory::create('c', {1, 5}, {1, 2, 1, 2, 1}); + auto exp = NDArrayFactory::create('c', {1, 5}, {3, 5, 3, 5, 3}); - auto lambda = LAMBDA_FF(_x, _y) { - return _x + _y + 1.0f; - }; + auto lambda = LAMBDA_FF(_x, _y) { return _x + _y + 1.0f; }; - x.applyPairwiseLambda(y, lambda, x); + x.applyPairwiseLambda(y, lambda, x); - ASSERT_TRUE(exp.equalsTo(&x)); + ASSERT_TRUE(exp.equalsTo(&x)); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, Test_Lambda_3) { - auto x = NDArrayFactory::create('c', {1, 5}, {1, 2, 1, 2, 1}); - auto y = NDArrayFactory::create('c', {1, 5}, {1, 2, 1, 2, 1}); - auto exp = NDArrayFactory::create('c', {1, 5}, {4, 8, 4, 8, 4}); + auto x = NDArrayFactory::create('c', {1, 5}, {1, 2, 1, 2, 1}); + auto y = NDArrayFactory::create('c', {1, 5}, {1, 2, 1, 2, 1}); + auto exp = NDArrayFactory::create('c', {1, 5}, {4, 8, 4, 8, 4}); - auto lambda = LAMBDA_DD(_x, _y) { - return (_x + _y) * 2; - }; + auto lambda = LAMBDA_DD(_x, _y) { return (_x + _y) * 2; }; - x.applyPairwiseLambda(y, lambda, x); + x.applyPairwiseLambda(y, lambda, x); - ASSERT_TRUE(exp.equalsTo(&x)); + ASSERT_TRUE(exp.equalsTo(&x)); } #endif ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, Test_swapUnsafe_1) { + auto x = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); + auto y = NDArrayFactory::create('c', {1, 4}, {5, 6, 7, 8}); + auto expX = NDArrayFactory::create('c', {2, 2}, {5, 6, 7, 8}); + auto expY = NDArrayFactory::create('c', {1, 4}, {1, 2, 3, 4}); - auto x = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); - auto y = NDArrayFactory::create('c', {1, 4}, {5, 6, 7, 8}); - auto expX = NDArrayFactory::create('c', {2, 2}, {5, 6, 7, 8}); - auto expY = NDArrayFactory::create('c', {1, 4}, {1, 2, 3, 4}); - - x.swapUnsafe(y); + x.swapUnsafe(y); - ASSERT_TRUE(expX.equalsTo(&x)); - ASSERT_TRUE(expY.equalsTo(&y)); + ASSERT_TRUE(expX.equalsTo(&x)); + ASSERT_TRUE(expY.equalsTo(&y)); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, Test_diagonal_1) { + auto x = NDArrayFactory::create('c', {2, 3}, {1, 2, 3, 4, 5, 6}); + auto exp = NDArrayFactory::create('c', {2, 1}, {1, 5}); - auto x = NDArrayFactory::create('c', {2, 3}, {1, 2, 3, 4, 5, 6}); - auto exp = NDArrayFactory::create('c', {2, 1}, {1, 5}); + auto diag = x.diagonal('c'); - auto diag = x.diagonal('c'); - - ASSERT_TRUE(exp.isSameShape(diag)); - ASSERT_TRUE(exp.equalsTo(diag)); + ASSERT_TRUE(exp.isSameShape(diag)); + ASSERT_TRUE(exp.equalsTo(diag)); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, Test_diagonal_2) { + auto x = NDArrayFactory::create('f', {2, 3}); + auto exp = NDArrayFactory::create('f', {2, 1}, {1, 5}); + x.linspace(1); - auto x = NDArrayFactory::create('f', {2, 3}); - auto exp = NDArrayFactory::create('f', {2, 1}, {1, 5}); - x.linspace(1); - - auto diag = x.diagonal('c'); + auto diag = x.diagonal('c'); - ASSERT_TRUE(exp.isSameShape(diag)); - ASSERT_TRUE(exp.equalsTo(diag)); + ASSERT_TRUE(exp.isSameShape(diag)); + ASSERT_TRUE(exp.equalsTo(diag)); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, Test_diagonal_3) { + auto x = NDArrayFactory::create('c', {2, 2}); + x.linspace(1); + auto exp = NDArrayFactory::create('c', {1, 2}, {1, 4}); - auto x = NDArrayFactory::create('c', {2, 2}); - x.linspace(1); - auto exp = NDArrayFactory::create('c', {1, 2}, {1, 4}); + auto diag = x.diagonal('r'); - auto diag = x.diagonal('r'); - - ASSERT_TRUE(exp.isSameShape(diag)); - ASSERT_TRUE(exp.equalsTo(diag)); + ASSERT_TRUE(exp.isSameShape(diag)); + ASSERT_TRUE(exp.equalsTo(diag)); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, Test_diagonal_4) { + auto x = NDArrayFactory::create('f', {2, 2}); + x.linspace(1); + auto exp = NDArrayFactory::create('f', {1, 2}, {1, 4}); - auto x = NDArrayFactory::create('f', {2, 2}); - x.linspace(1); - auto exp = NDArrayFactory::create('f', {1, 2}, {1, 4}); - - auto diag = x.diagonal('r'); + auto diag = x.diagonal('r'); - ASSERT_TRUE(exp.isSameShape(diag)); - ASSERT_TRUE(exp.equalsTo(diag)); + ASSERT_TRUE(exp.isSameShape(diag)); + ASSERT_TRUE(exp.equalsTo(diag)); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, Test_diagonal_5) { + auto x = NDArrayFactory::create('c', {2, 2, 2}); + x.linspace(1); + auto exp = NDArrayFactory::create('c', {1, 2}, {1, 8}); - auto x = NDArrayFactory::create('c', {2, 2, 2}); - x.linspace(1); - auto exp = NDArrayFactory::create('c', {1, 2}, {1, 8}); - - auto diag = x.diagonal('r'); + auto diag = x.diagonal('r'); - ASSERT_TRUE(exp.isSameShape(diag)); - ASSERT_TRUE(exp.equalsTo(diag)); + ASSERT_TRUE(exp.isSameShape(diag)); + ASSERT_TRUE(exp.equalsTo(diag)); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, Test_diagonal_6) { + auto x = NDArrayFactory::create('f', {2, 2, 2}); + x.linspace(1); + auto exp = NDArrayFactory::create('f', {1, 2}, {1, 8}); - auto x = NDArrayFactory::create('f', {2, 2, 2}); - x.linspace(1); - auto exp = NDArrayFactory::create('f', {1, 2}, {1, 8}); + auto diag = x.diagonal('r'); - auto diag = x.diagonal('r'); - - ASSERT_TRUE(exp.isSameShape(diag)); - ASSERT_TRUE(exp.equalsTo(diag)); + ASSERT_TRUE(exp.isSameShape(diag)); + ASSERT_TRUE(exp.equalsTo(diag)); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, Test_diagonal_7) { + auto x = NDArrayFactory::create('f', {2, 2, 2}); + x.linspace(1); + auto exp = NDArrayFactory::create('f', {2, 1}, {1, 8}); - auto x = NDArrayFactory::create('f', {2, 2, 2}); - x.linspace(1); - auto exp = NDArrayFactory::create('f', {2, 1}, {1, 8}); - - auto diag = x.diagonal('c'); + auto diag = x.diagonal('c'); - ASSERT_TRUE(exp.isSameShape(diag)); - ASSERT_TRUE(exp.equalsTo(diag)); + ASSERT_TRUE(exp.isSameShape(diag)); + ASSERT_TRUE(exp.equalsTo(diag)); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, Test_diagonal_8) { + auto x = NDArrayFactory::create('c', {2, 3}); + x.linspace(1); + auto exp = NDArrayFactory::create('c', {1, 2}, {1, 5}); - auto x = NDArrayFactory::create('c', {2, 3}); - x.linspace(1); - auto exp = NDArrayFactory::create('c', {1, 2}, {1, 5}); + auto diag = x.diagonal('r'); - auto diag = x.diagonal('r'); - - ASSERT_TRUE(exp.isSameShape(diag)); - ASSERT_TRUE(exp.equalsTo(diag)); + ASSERT_TRUE(exp.isSameShape(diag)); + ASSERT_TRUE(exp.equalsTo(diag)); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, Test_diagonal_9) { + auto x = NDArrayFactory::create('c', {2, 2}); + x.linspace(1); + auto exp = NDArrayFactory::create('c', {2, 1}, {1, 4}); - auto x = NDArrayFactory::create('c', {2, 2}); - x.linspace(1); - auto exp = NDArrayFactory::create('c', {2, 1}, {1, 4}); - - auto diag = x.diagonal('c'); + auto diag = x.diagonal('c'); - ASSERT_TRUE(exp.isSameShape(diag)); - ASSERT_TRUE(exp.equalsTo(diag)); + ASSERT_TRUE(exp.isSameShape(diag)); + ASSERT_TRUE(exp.equalsTo(diag)); } - ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, Test_diagonal_10) { + auto x = NDArrayFactory::create('f', {2, 2}); + x.linspace(1); + auto exp = NDArrayFactory::create('f', {2, 1}, {1, 4}); - auto x = NDArrayFactory::create('f', {2, 2}); - x.linspace(1); - auto exp = NDArrayFactory::create('f', {2, 1}, {1, 4}); - - auto diag = x.diagonal('c'); + auto diag = x.diagonal('c'); - ASSERT_TRUE(exp.isSameShape(diag)); - ASSERT_TRUE(exp.equalsTo(diag)); + ASSERT_TRUE(exp.isSameShape(diag)); + ASSERT_TRUE(exp.equalsTo(diag)); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, Test_diagonal_11) { + auto x = NDArrayFactory::create('f', {3, 3}); + x.linspace(1); + auto exp = NDArrayFactory::create('f', {3, 1}, {1, 5, 9}); - auto x = NDArrayFactory::create('f', {3, 3}); - x.linspace(1); - auto exp = NDArrayFactory::create('f', {3, 1}, {1, 5, 9}); + auto diag = x.diagonal('c'); - auto diag = x.diagonal('c'); - - ASSERT_TRUE(exp.isSameShape(diag)); - ASSERT_TRUE(exp.equalsTo(diag)); + ASSERT_TRUE(exp.isSameShape(diag)); + ASSERT_TRUE(exp.equalsTo(diag)); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, Test_diagonal_12) { + auto x = NDArrayFactory::create('c', {3, 3}); + x.linspace(1); + auto exp = NDArrayFactory::create('c', {1, 3}, {1, 5, 9}); - auto x = NDArrayFactory::create('c', {3, 3}); - x.linspace(1); - auto exp = NDArrayFactory::create('c', {1, 3}, {1, 5, 9}); - - auto diag = x.diagonal('r'); + auto diag = x.diagonal('r'); - ASSERT_TRUE(exp.isSameShape(diag)); - ASSERT_TRUE(exp.equalsTo(diag)); + ASSERT_TRUE(exp.isSameShape(diag)); + ASSERT_TRUE(exp.equalsTo(diag)); } //////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, Test_diagonal_13) { + auto x = NDArrayFactory::create('c', {3, 3, 4}); + x.linspace(1); + auto exp = NDArrayFactory::create('c', {3, 1}, {1, 18, 35}); - auto x = NDArrayFactory::create('c', {3, 3, 4}); - x.linspace(1); - auto exp = NDArrayFactory::create('c', {3, 1}, {1,18,35}); + auto diag = x.diagonal('c'); - auto diag = x.diagonal('c'); - - ASSERT_TRUE(exp.isSameShape(diag)); - ASSERT_TRUE(exp.equalsTo(diag)); + ASSERT_TRUE(exp.isSameShape(diag)); + ASSERT_TRUE(exp.equalsTo(diag)); } //////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, Test_diagonal_14) { + auto x = NDArrayFactory::create('c', {3, 3, 4}); + x.linspace(1); + auto exp = NDArrayFactory::create('c', {1, 3}, {1, 18, 35}); - auto x = NDArrayFactory::create('c', {3, 3, 4}); - x.linspace(1); - auto exp = NDArrayFactory::create('c', {1, 3}, {1,18,35}); - - auto diag = x.diagonal('r'); + auto diag = x.diagonal('r'); - ASSERT_TRUE(exp.isSameShape(diag)); - ASSERT_TRUE(exp.equalsTo(diag)); + ASSERT_TRUE(exp.isSameShape(diag)); + ASSERT_TRUE(exp.equalsTo(diag)); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, Test_diagonal_15) { + auto x = NDArrayFactory::create('f', {3, 3, 4}); + x.linspace(1); + auto exp = NDArrayFactory::create('f', {1, 3}, {1, 18, 35}); - auto x = NDArrayFactory::create('f', {3, 3, 4}); - x.linspace(1); - auto exp = NDArrayFactory::create('f', {1, 3}, {1,18,35}); + auto diag = x.diagonal('r'); - auto diag = x.diagonal('r'); - - ASSERT_TRUE(exp.isSameShape(diag)); - ASSERT_TRUE(exp.equalsTo(diag)); + ASSERT_TRUE(exp.isSameShape(diag)); + ASSERT_TRUE(exp.equalsTo(diag)); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, Test_diagonal_16) { + auto x = NDArrayFactory::create('f', {1, 5}); + x.linspace(1); + auto exp = NDArrayFactory::create('f', {1, 1}, {1}); - auto x = NDArrayFactory::create('f', {1, 5}); - x.linspace(1); - auto exp = NDArrayFactory::create('f', {1, 1}, {1}); - - auto diag = x.diagonal('c'); + auto diag = x.diagonal('c'); - ASSERT_TRUE(exp.isSameShape(diag)); - ASSERT_TRUE(exp.equalsTo(diag)); + ASSERT_TRUE(exp.isSameShape(diag)); + ASSERT_TRUE(exp.equalsTo(diag)); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, Test_diagonal_17) { + auto x = NDArrayFactory::create('c', {5, 1}); + x.linspace(1); + auto exp = NDArrayFactory::create('c', {1, 1}, {1}); - auto x = NDArrayFactory::create('c', {5, 1}); - x.linspace(1); - auto exp = NDArrayFactory::create('c', {1, 1}, {1}); + auto diag = x.diagonal('r'); - auto diag = x.diagonal('r'); - - ASSERT_TRUE(exp.isSameShape(diag)); - ASSERT_TRUE(exp.equalsTo(diag)); + ASSERT_TRUE(exp.isSameShape(diag)); + ASSERT_TRUE(exp.equalsTo(diag)); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, Test_diagonal_18) { + auto x = NDArrayFactory::create('f', {1, 1}); + x.linspace(1); + auto exp = NDArrayFactory::create('f', {1, 1}, {1}); - auto x = NDArrayFactory::create('f', {1, 1}); - x.linspace(1); - auto exp = NDArrayFactory::create('f', {1, 1}, {1}); - - auto diag = x.diagonal('r'); + auto diag = x.diagonal('r'); - ASSERT_TRUE(exp.isSameShape(diag)); - ASSERT_TRUE(exp.equalsTo(diag)); + ASSERT_TRUE(exp.isSameShape(diag)); + ASSERT_TRUE(exp.equalsTo(diag)); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, assign_test1) { + NDArray x('c', {2, 3}, {1, 2, 3, 4, 5, 6}); + NDArray y('c', {2, 3}, {10, 20, 30, 40, 50, 60}); + y.reshapei('c', {3, 2}); - NDArray x('c', {2, 3}, {1,2,3,4,5,6}); - NDArray y('c', {2, 3}, {10,20,30,40,50,60}); - y.reshapei('c',{3, 2}); - - x.assign(y); - x.reshapei('c',{3, 2}); - ASSERT_TRUE(x.equalsTo(y)); + x.assign(y); + x.reshapei('c', {3, 2}); + ASSERT_TRUE(x.equalsTo(y)); } diff --git a/libnd4j/tests_cpu/layers_tests/NDArrayTests2.cpp b/libnd4j/tests_cpu/layers_tests/NDArrayTests2.cpp index 4dd4c3abee4d..a5cb2b4ba320 100644 --- a/libnd4j/tests_cpu/layers_tests/NDArrayTests2.cpp +++ b/libnd4j/tests_cpu/layers_tests/NDArrayTests2.cpp @@ -18,1290 +18,1343 @@ // Created by raver119 on 21.11.17. // -#include "testlayers.h" -#include #include #include #include +#include + +#include "testlayers.h" + using namespace sd; ////////////////////////////////////////////////////////////////////// class NDArrayTest2 : public testing::Test { -public: - + public: }; - TEST_F(NDArrayTest2, Test_ByteVector_1) { - auto x = NDArrayFactory::create('c', {10, 10}); - x.linspace(1); - - auto vec = x.asByteVector(); + auto x = NDArrayFactory::create('c', {10, 10}); + x.linspace(1); - auto restored = new NDArray((float *)vec.data(), x.shapeInfo(), x.getContext(), false); + auto vec = x.asByteVector(); + auto restored = + new NDArray((float *)vec.data(), x.shapeInfo(), x.getContext(), false); - ASSERT_TRUE(x.equalsTo(restored)); + ASSERT_TRUE(x.equalsTo(restored)); - delete restored; + delete restored; } TEST_F(NDArrayTest2, Test_ByteVector_2) { - auto x = NDArrayFactory::create('c', {10, 10}); - x.linspace(1); + auto x = NDArrayFactory::create('c', {10, 10}); + x.linspace(1); - auto vec = x.asByteVector(); + auto vec = x.asByteVector(); - auto restored = new NDArray((bfloat16 *)vec.data(), x.shapeInfo(), x.getContext(), false); + auto restored = + new NDArray((bfloat16 *)vec.data(), x.shapeInfo(), x.getContext(), false); - ASSERT_TRUE(x.equalsTo(restored)); + ASSERT_TRUE(x.equalsTo(restored)); - delete restored; + delete restored; } TEST_F(NDArrayTest2, Test_ByteVector_3) { - auto x = NDArrayFactory::create('c', {10, 10}); - x.linspace(1); + auto x = NDArrayFactory::create('c', {10, 10}); + x.linspace(1); - auto vec = x.asByteVector(); + auto vec = x.asByteVector(); - auto restored = new NDArray((double *)vec.data(), x.shapeInfo(), x.getContext(), false); + auto restored = + new NDArray((double *)vec.data(), x.shapeInfo(), x.getContext(), false); - ASSERT_TRUE(x.equalsTo(restored)); + ASSERT_TRUE(x.equalsTo(restored)); - delete restored; + delete restored; } TEST_F(NDArrayTest2, Test_Reshape_Scalar_1) { - auto x = NDArrayFactory::create('c', {1, 1}, {1.0}); - auto e = NDArrayFactory::create(1.0); + auto x = NDArrayFactory::create('c', {1, 1}, {1.0}); + auto e = NDArrayFactory::create(1.0); - x.reshapei({}); + x.reshapei({}); - ASSERT_EQ(e, x); - ASSERT_EQ(e.rankOf(), x.rankOf()); + ASSERT_EQ(e, x); + ASSERT_EQ(e.rankOf(), x.rankOf()); } TEST_F(NDArrayTest2, Test_Reshape_Scalar_2) { - auto x = NDArrayFactory::create('c', {1, 1}, {1.0}); - auto e = NDArrayFactory::create('c', {1}, {1.0}); + auto x = NDArrayFactory::create('c', {1, 1}, {1.0}); + auto e = NDArrayFactory::create('c', {1}, {1.0}); - x.reshapei({1}); + x.reshapei({1}); - ASSERT_EQ(e, x); - ASSERT_EQ(e.rankOf(), x.rankOf()); + ASSERT_EQ(e, x); + ASSERT_EQ(e.rankOf(), x.rankOf()); } TEST_F(NDArrayTest2, Test_IndexReduce_1) { - auto x = NDArrayFactory::create('c', {1, 5}, {1, 2, 3, 4, 5}); + auto x = NDArrayFactory::create('c', {1, 5}, {1, 2, 3, 4, 5}); - ExtraArguments extras({3.0, 0.0, 10.0}); - int idx = x.indexReduceNumber(indexreduce::FirstIndex, &extras).e(0); + ExtraArguments extras({3.0, 0.0, 10.0}); + int idx = x.indexReduceNumber(indexreduce::FirstIndex, &extras).e(0); - ASSERT_EQ(2, idx); + ASSERT_EQ(2, idx); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest2, SetIdentity_test_1) { + auto x = NDArrayFactory::create('c', {1, 5}); + auto xExp = NDArrayFactory::create('c', {1, 5}, {1, 0, 0, 0, 0}); - auto x = NDArrayFactory::create('c', {1, 5}); - auto xExp = NDArrayFactory::create('c', {1, 5}, {1, 0, 0, 0, 0}); - - x.setIdentity(); - ASSERT_TRUE(x.equalsTo(&xExp)); + x.setIdentity(); + ASSERT_TRUE(x.equalsTo(&xExp)); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest2, SetIdentity_test_2) { + auto x = NDArrayFactory::create('f', {1, 5}); + auto xExp = NDArrayFactory::create('f', {1, 5}, {1, 0, 0, 0, 0}); - auto x = NDArrayFactory::create('f', {1, 5}); - auto xExp = NDArrayFactory::create('f', {1, 5}, {1, 0, 0, 0, 0}); - - x.setIdentity(); + x.setIdentity(); - ASSERT_TRUE(x.equalsTo(&xExp)); + ASSERT_TRUE(x.equalsTo(&xExp)); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest2, SetIdentity_test_3) { + auto x = NDArrayFactory::create('f', {1, 1}); + auto xExp = NDArrayFactory::create('f', {1, 1}, {1}); - auto x = NDArrayFactory::create('f', {1, 1}); - auto xExp = NDArrayFactory::create('f', {1, 1}, {1}); + x.setIdentity(); - x.setIdentity(); - - ASSERT_TRUE(x.equalsTo(&xExp)); + ASSERT_TRUE(x.equalsTo(&xExp)); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest2, SetIdentity_test_4) { + auto x = NDArrayFactory::create('f', {2, 1}); + auto xExp = NDArrayFactory::create('f', {2, 1}, {1, 0}); - auto x = NDArrayFactory::create('f', {2, 1}); - auto xExp = NDArrayFactory::create('f', {2, 1}, {1,0}); - - x.setIdentity(); + x.setIdentity(); - ASSERT_TRUE(x.equalsTo(&xExp)); + ASSERT_TRUE(x.equalsTo(&xExp)); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest2, SetIdentity_test_5) { + auto x = NDArrayFactory::create('f', {2, 2}); + auto xExp = NDArrayFactory::create('f', {2, 2}, {1, 0, 0, 1}); - auto x = NDArrayFactory::create('f', {2, 2}); - auto xExp = NDArrayFactory::create('f', {2, 2}, {1,0,0,1}); + x.setIdentity(); - x.setIdentity(); - - ASSERT_TRUE(x.equalsTo(&xExp)); + ASSERT_TRUE(x.equalsTo(&xExp)); } //////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest2, SetIdentity_test_6) { + auto x = NDArrayFactory::create('c', {3, 2}); + auto xExp = NDArrayFactory::create('c', {3, 2}, + {1.f, 0.f, 0.f, 1.f, 0.f, 0.f}); - auto x = NDArrayFactory::create('c', {3, 2}); - auto xExp = NDArrayFactory::create('c', {3, 2}, {1.f, 0.f, 0.f, 1.f, 0.f, 0.f}); - - x.setIdentity(); + x.setIdentity(); - ASSERT_TRUE(x.equalsTo(&xExp)); + ASSERT_TRUE(x.equalsTo(&xExp)); } //////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest2, SetIdentity_test_7) { + auto x = NDArrayFactory::create('c', {3, 4}); + auto xExp = NDArrayFactory::create( + 'c', {3, 4}, + {1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f}); - auto x = NDArrayFactory::create('c', {3, 4}); - auto xExp = NDArrayFactory::create('c', {3, 4}, {1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f}); + x.setIdentity(); - x.setIdentity(); - - ASSERT_TRUE(x.equalsTo(&xExp)); + ASSERT_TRUE(x.equalsTo(&xExp)); } #ifdef ALLOWED_3D_IDENTITY //////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest2, SetIdentity_test_8) { + auto x = NDArrayFactory::create('c', {3, 3, 3}); + auto xExp = NDArrayFactory::create( + 'c', {3, 3, 3}, {1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.}); + x.setIdentity(); - auto x = NDArrayFactory::create('c', {3, 3, 3}); - auto xExp = NDArrayFactory::create('c', {3, 3, 3}, {1.,0.,0. ,0.,0.,0., 0.,0.,0., 0.,0.,0. ,0.,1.,0., 0.,0.,0., 0.,0.,0. ,0.,0.,0., 0.,0.,1.}); - x.setIdentity(); - - ASSERT_TRUE(x.equalsTo(&xExp)); + ASSERT_TRUE(x.equalsTo(&xExp)); } #endif //////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest2, Test_AllReduce3_1) { - auto x = NDArrayFactory::create('c', {2, 3}, {1, 2, 3, 1, 2, 3}); - auto y = NDArrayFactory::create('c', {2, 3}, {2, 3, 4, 2, 3, 4}); - auto exp = NDArrayFactory::create('c', {2, 2}, {1.73205, 1.73205, 1.73205, 1.73205}); + auto x = NDArrayFactory::create('c', {2, 3}, {1, 2, 3, 1, 2, 3}); + auto y = NDArrayFactory::create('c', {2, 3}, {2, 3, 4, 2, 3, 4}); + auto exp = NDArrayFactory::create( + 'c', {2, 2}, {1.73205, 1.73205, 1.73205, 1.73205}); - auto z = x.applyAllReduce3(reduce3::EuclideanDistance, y, {1}); + auto z = x.applyAllReduce3(reduce3::EuclideanDistance, y, {1}); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest2, Test_AllReduce3_2) { - auto x = NDArrayFactory::create('c', {2, 3}, {1, 2, 3, 2, 3, 4 }); - auto y = NDArrayFactory::create('c', {2, 3}, {1, 2, 3, 2, 3, 4}); - auto exp = NDArrayFactory::create('c', {2, 2}, {0., 1.73205, 1.73205, 0.}); + auto x = NDArrayFactory::create('c', {2, 3}, {1, 2, 3, 2, 3, 4}); + auto y = NDArrayFactory::create('c', {2, 3}, {1, 2, 3, 2, 3, 4}); + auto exp = + NDArrayFactory::create('c', {2, 2}, {0., 1.73205, 1.73205, 0.}); - auto z = x.applyAllReduce3(reduce3::EuclideanDistance, y, {1}); + auto z = x.applyAllReduce3(reduce3::EuclideanDistance, y, {1}); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest2, mmul_test1) { + auto x = NDArrayFactory::create('c', {4, 1}, {1, 2, 3, 4}); + auto y = NDArrayFactory::create('c', {1, 4}, {1, 2, 3, 4}); + auto exp = NDArrayFactory::create( + 'c', {4, 4}, {1, 2, 3, 4, 2, 4, 6, 8, 3, 6, 9, 12, 4, 8, 12, 16}); - auto x = NDArrayFactory::create('c', {4, 1}, {1, 2, 3, 4}); - auto y = NDArrayFactory::create('c', {1, 4}, {1, 2, 3, 4}); - auto exp = NDArrayFactory::create('c', {4, 4}, {1,2, 3, 4,2,4, 6, 8,3,6, 9,12,4,8,12,16}); - - auto result = mmul(x, y); - ASSERT_TRUE(exp.isSameShape(&result)); - ASSERT_TRUE(exp.equalsTo(&result)); - + auto result = mmul(x, y); + ASSERT_TRUE(exp.isSameShape(&result)); + ASSERT_TRUE(exp.equalsTo(&result)); } //////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest2, mmul_test2) { + auto x = NDArrayFactory::create('c', {4, 1}, {1, 2, 3, 4}); + auto y = NDArrayFactory::create('c', {1, 4}, {1, 2, 3, 4}); + auto exp = NDArrayFactory::create('c', {1, 1}, {30}); - auto x = NDArrayFactory::create('c', {4, 1}, {1, 2, 3, 4}); - auto y = NDArrayFactory::create('c', {1, 4}, {1, 2, 3, 4}); - auto exp = NDArrayFactory::create('c', {1, 1}, {30}); - - auto result = mmul(y ,x); - - ASSERT_TRUE(exp.isSameShape(&result)); - ASSERT_TRUE(exp.equalsTo(&result)); + auto result = mmul(y, x); + ASSERT_TRUE(exp.isSameShape(&result)); + ASSERT_TRUE(exp.equalsTo(&result)); } //////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest2, mmul_test3) { + auto x = NDArrayFactory::create('c', {4, 1}, {1, 2, 3, 4}); + auto exp = NDArrayFactory::create( + 'c', {4, 4}, + {1., 0.2, 0.3, 0.4, 0.2, 0.04, 0.06, 0.08, 0.3, 0.06, 0.09, 0.12, 0.4, + 0.08, 0.12, 0.16}); + auto w = NDArrayFactory::create(x.ordering(), {(int)x.lengthOf(), 1}, + x.getContext()); // column-vector + auto wT = NDArrayFactory::create( + x.ordering(), {1, (int)x.lengthOf()}, + x.getContext()); // row-vector (transposed w) - auto x = NDArrayFactory::create('c', {4, 1}, {1, 2, 3, 4}); - auto exp = NDArrayFactory::create('c', {4, 4}, {1. ,0.2 ,0.3 ,0.4 ,0.2,0.04,0.06,0.08,0.3,0.06,0.09,0.12,0.4,0.08,0.12,0.16}); - auto w = NDArrayFactory::create( x.ordering(), {(int)x.lengthOf(), 1}, x.getContext()); // column-vector - auto wT = NDArrayFactory::create(x.ordering(), {1, (int)x.lengthOf()}, x.getContext()); // row-vector (transposed w) - - w = x / (float)10.; - w.p(0, 1.); - wT.assign(&w); + w = x / (float)10.; + w.p(0, 1.); + wT.assign(&w); - auto result = mmul(w ,wT); - - ASSERT_TRUE(exp.isSameShape(&result)); - ASSERT_TRUE(exp.equalsTo(&result)); + auto result = mmul(w, wT); + ASSERT_TRUE(exp.isSameShape(&result)); + ASSERT_TRUE(exp.equalsTo(&result)); } - TEST_F(NDArrayTest2, Test_Streamline_1) { - auto x = NDArrayFactory::create('c', {3, 4, 6}); - auto y = NDArrayFactory::create('c', {3, 4, 6}); - x.linspace(1); - y.linspace(1); + auto x = NDArrayFactory::create('c', {3, 4, 6}); + auto y = NDArrayFactory::create('c', {3, 4, 6}); + x.linspace(1); + y.linspace(1); - x.permutei({1, 0, 2}); - y.permutei({1, 0, 2}); + x.permutei({1, 0, 2}); + y.permutei({1, 0, 2}); - y.streamline(); + y.streamline(); - ASSERT_TRUE(x.isSameShape(&y)); - ASSERT_TRUE(x.equalsTo(&y)); - ASSERT_FALSE(x.isSameShapeStrict(y)); + ASSERT_TRUE(x.isSameShape(&y)); + ASSERT_TRUE(x.equalsTo(&y)); + ASSERT_FALSE(x.isSameShapeStrict(y)); } - TEST_F(NDArrayTest2, Test_Streamline_2) { - auto x = NDArrayFactory::create('c', {3, 4, 6}); - auto y = NDArrayFactory::create('f', {3, 4, 6}); - x.linspace(1); - y.linspace(1); + auto x = NDArrayFactory::create('c', {3, 4, 6}); + auto y = NDArrayFactory::create('f', {3, 4, 6}); + x.linspace(1); + y.linspace(1); - ASSERT_TRUE(x.isSameShape(&y)); - ASSERT_TRUE(x.equalsTo(&y)); + ASSERT_TRUE(x.isSameShape(&y)); + ASSERT_TRUE(x.equalsTo(&y)); - y.streamline('c'); + y.streamline('c'); - ASSERT_TRUE(x.isSameShape(&y)); - ASSERT_TRUE(x.equalsTo(&y)); + ASSERT_TRUE(x.isSameShape(&y)); + ASSERT_TRUE(x.equalsTo(&y)); } TEST_F(NDArrayTest2, Test_Enforce_1) { - auto x = NDArrayFactory::create('c', {4, 1, 1, 4}); - auto exp = NDArrayFactory::create('c', {4, 4}); + auto x = NDArrayFactory::create('c', {4, 1, 1, 4}); + auto exp = NDArrayFactory::create('c', {4, 4}); - x.linspace(1); - exp.linspace(1); + x.linspace(1); + exp.linspace(1); - x.enforce({4, 4}, 'c'); + x.enforce({4, 4}, 'c'); - ASSERT_TRUE(exp.isSameShapeStrict(x)); - ASSERT_TRUE(exp.equalsTo(&x)); + ASSERT_TRUE(exp.isSameShapeStrict(x)); + ASSERT_TRUE(exp.equalsTo(&x)); } TEST_F(NDArrayTest2, TestVector_1) { - auto x = NDArrayFactory::create('c', {2, 3}); - auto row = NDArrayFactory::create('c', {3}, {1, 2, 3}); - auto exp = NDArrayFactory::create('c', {2, 3}, {1, 2, 3, 1, 2, 3}); + auto x = NDArrayFactory::create('c', {2, 3}); + auto row = NDArrayFactory::create('c', {3}, {1, 2, 3}); + auto exp = NDArrayFactory::create('c', {2, 3}, {1, 2, 3, 1, 2, 3}); - x.addiRowVector(row); + x.addiRowVector(row); - ASSERT_TRUE(exp.equalsTo(&x)); + ASSERT_TRUE(exp.equalsTo(&x)); } ////////////////////////////////////////////////////////////////////// -TEST_F(NDArrayTest2, Operator_Plus_Test_5) -{ - - auto x = NDArrayFactory::create('c', {8, 8, 8}); - auto y = NDArrayFactory::create('c', {8, 1, 8}); - auto expected = NDArrayFactory::create('c', {8, 8, 8}); +TEST_F(NDArrayTest2, Operator_Plus_Test_5) { + auto x = NDArrayFactory::create('c', {8, 8, 8}); + auto y = NDArrayFactory::create('c', {8, 1, 8}); + auto expected = NDArrayFactory::create('c', {8, 8, 8}); - x = 1.; - y = 2.; - expected = 3.; + x = 1.; + y = 2.; + expected = 3.; - auto result = x + y; + auto result = x + y; - ASSERT_TRUE(expected.isSameShape(&result)); - ASSERT_TRUE(expected.equalsTo(&result)); + ASSERT_TRUE(expected.isSameShape(&result)); + ASSERT_TRUE(expected.equalsTo(&result)); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest2, Operator_Plus_Test_6) { + auto x = NDArrayFactory::create('c', {3, 3, 3}); + auto y = NDArrayFactory::create('c', {3, 1, 3}); + auto expected = NDArrayFactory::create( + 'c', {3, 3, 3}, + {2., 4., 6., 5., 7., 9., 8., 10., 12., 14., 16., 18., 17., 19., + 21., 20., 22., 24., 26., 28., 30., 29., 31., 33., 32., 34., 36.}); + x.linspace(1); + y.linspace(1); - auto x = NDArrayFactory::create('c', {3, 3, 3}); - auto y = NDArrayFactory::create('c', {3, 1, 3}); - auto expected = NDArrayFactory::create('c', {3, 3, 3}, {2., 4., 6., 5., 7., 9., 8.,10.,12., 14.,16.,18.,17.,19.,21.,20.,22.,24., 26.,28.,30.,29.,31.,33.,32.,34.,36.}); - x.linspace(1); - y.linspace(1); + auto result = x + y; - auto result = x + y; - - ASSERT_TRUE(expected.isSameShape(&result)); - ASSERT_TRUE(expected.equalsTo(&result)); + ASSERT_TRUE(expected.isSameShape(&result)); + ASSERT_TRUE(expected.equalsTo(&result)); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest2, tileToShape_test1) { + auto x = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); + auto exp = + NDArrayFactory::create('c', {2, 2, 2}, {1, 2, 3, 4, 1, 2, 3, 4}); - auto x = NDArrayFactory::create('c', {2, 2}, {1,2,3,4}); - auto exp = NDArrayFactory::create('c', {2, 2, 2}, {1,2,3,4,1,2,3,4}); - - x.tileToShape({2,2,2}, x); + x.tileToShape({2, 2, 2}, x); - ASSERT_TRUE(x.isSameShape(&exp)); - ASSERT_TRUE(x.equalsTo(&exp)); + ASSERT_TRUE(x.isSameShape(&exp)); + ASSERT_TRUE(x.equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest2, tileToShape_test2) { + auto x = NDArrayFactory::create('c', {2, 1, 2}, {1, 2, 3, 4}); + auto exp = NDArrayFactory::create( + 'c', {2, 3, 2}, {1, 2, 1, 2, 1, 2, 3, 4, 3, 4, 3, 4}); - auto x = NDArrayFactory::create('c', {2, 1, 2}, {1,2,3,4}); - auto exp = NDArrayFactory::create('c', {2, 3, 2}, {1,2,1,2,1,2,3,4,3,4,3,4}); + x.tileToShape({2, 3, 2}, x); - x.tileToShape({2,3,2}, x); - - ASSERT_TRUE(x.isSameShape(&exp)); - ASSERT_TRUE(x.equalsTo(&exp)); + ASSERT_TRUE(x.isSameShape(&exp)); + ASSERT_TRUE(x.equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest2, tileToShape_test3) { + auto x = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); + auto result = NDArrayFactory::create('c', {2, 2, 2}); + auto exp = + NDArrayFactory::create('c', {2, 2, 2}, {1, 2, 3, 4, 1, 2, 3, 4}); - auto x = NDArrayFactory::create('c', {2, 2}, {1,2,3,4}); - auto result = NDArrayFactory::create('c', {2, 2, 2}); - auto exp = NDArrayFactory::create('c', {2, 2, 2}, {1,2,3,4,1,2,3,4}); - - x.tileToShape({2,2,2}, result); - // result.printIndexedBuffer(); + x.tileToShape({2, 2, 2}, result); + // result.printIndexedBuffer(); - ASSERT_TRUE(result.isSameShape(&exp)); - ASSERT_TRUE(result.equalsTo(&exp)); + ASSERT_TRUE(result.isSameShape(&exp)); + ASSERT_TRUE(result.equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest2, tileToShape_test4) { + auto x = NDArrayFactory::create('c', {2, 1, 2}, {1, 2, 3, 4}); + auto result = NDArrayFactory::create('c', {2, 3, 2}); + auto exp = NDArrayFactory::create( + 'c', {2, 3, 2}, {1, 2, 1, 2, 1, 2, 3, 4, 3, 4, 3, 4}); - auto x = NDArrayFactory::create('c', {2, 1, 2}, {1,2,3,4}); - auto result = NDArrayFactory::create('c', {2, 3, 2}); - auto exp = NDArrayFactory::create('c', {2, 3, 2}, {1,2,1,2,1,2,3,4,3,4,3,4}); + x.tileToShape({2, 3, 2}, result); - x.tileToShape({2,3,2}, result); - - ASSERT_TRUE(result.isSameShape(&exp)); - ASSERT_TRUE(result.equalsTo(&exp)); + ASSERT_TRUE(result.isSameShape(&exp)); + ASSERT_TRUE(result.equalsTo(&exp)); } #ifndef __CUDABLAS__ TEST_F(NDArrayTest2, Test_TriplewiseLambda_1) { - auto t = NDArrayFactory::create('c', {3, 3}, {1, 1, 1, 1, 1, 1, 1, 1, 1}); - auto u = NDArrayFactory::create('c', {3, 3}, {2, 2, 2, 2, 2, 2, 2, 2, 2}); - auto v = NDArrayFactory::create('c', {3, 3}, {3, 3, 3, 3, 3, 3, 3, 3, 3}); - auto exp = NDArrayFactory::create('c', {3, 3}, {7, 7, 7, 7, 7, 7, 7, 7, 7}); + auto t = + NDArrayFactory::create('c', {3, 3}, {1, 1, 1, 1, 1, 1, 1, 1, 1}); + auto u = + NDArrayFactory::create('c', {3, 3}, {2, 2, 2, 2, 2, 2, 2, 2, 2}); + auto v = + NDArrayFactory::create('c', {3, 3}, {3, 3, 3, 3, 3, 3, 3, 3, 3}); + auto exp = + NDArrayFactory::create('c', {3, 3}, {7, 7, 7, 7, 7, 7, 7, 7, 7}); - float extra = 1.0f; + float extra = 1.0f; - auto la = LAMBDA_DDD(_t, _u, _v, extra) { - return _t + _u + _v + extra; - }; + auto la = LAMBDA_DDD(_t, _u, _v, extra) { return _t + _u + _v + extra; }; - t.applyTriplewiseLambda(u, v, la, t); + t.applyTriplewiseLambda(u, v, la, t); - ASSERT_TRUE(t.equalsTo(&exp)); + ASSERT_TRUE(t.equalsTo(&exp)); } - TEST_F(NDArrayTest2, Test_TriplewiseLambda_2) { - auto t = NDArrayFactory::create('c', {3, 3}, {1, 1, 1, 1, 1, 1, 1, 1, 1}); - auto u = NDArrayFactory::create('f', {3, 3}, {2, 2, 2, 2, 2, 2, 2, 2, 2}); - auto v = NDArrayFactory::create('c', {3, 3}, {3, 3, 3, 3, 3, 3, 3, 3, 3}); - auto exp = NDArrayFactory::create('c', {3, 3}, {7, 7, 7, 7, 7, 7, 7, 7, 7}); + auto t = + NDArrayFactory::create('c', {3, 3}, {1, 1, 1, 1, 1, 1, 1, 1, 1}); + auto u = + NDArrayFactory::create('f', {3, 3}, {2, 2, 2, 2, 2, 2, 2, 2, 2}); + auto v = + NDArrayFactory::create('c', {3, 3}, {3, 3, 3, 3, 3, 3, 3, 3, 3}); + auto exp = + NDArrayFactory::create('c', {3, 3}, {7, 7, 7, 7, 7, 7, 7, 7, 7}); - float extra = 1.0f; + float extra = 1.0f; - auto la = LAMBDA_DDD(_t, _u, _v, extra) { - return _t + _u + _v + extra; - }; + auto la = LAMBDA_DDD(_t, _u, _v, extra) { return _t + _u + _v + extra; }; - t.applyTriplewiseLambda(u, v, la, t); + t.applyTriplewiseLambda(u, v, la, t); - ASSERT_TRUE(t.equalsTo(&exp)); + ASSERT_TRUE(t.equalsTo(&exp)); } //////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest2, Test_Indexed_Lambda) { - auto x = NDArrayFactory::create('c', {2, 2}); - auto exp = NDArrayFactory::create('c', {2, 2}, {0, 1, 2, 3}); + auto x = NDArrayFactory::create('c', {2, 2}); + auto exp = NDArrayFactory::create('c', {2, 2}, {0, 1, 2, 3}); - auto lambda = ILAMBDA_D(_x) { - return (float) _idx; - }; + auto lambda = ILAMBDA_D(_x) { return (float)_idx; }; - x.applyIndexedLambda(lambda, x); + x.applyIndexedLambda(lambda, x); - ASSERT_TRUE(exp.equalsTo(&x)); + ASSERT_TRUE(exp.equalsTo(&x)); } #endif TEST_F(NDArrayTest2, Test_PermuteEquality_1) { - auto x = NDArrayFactory::create('c', {1, 60}); - auto exp = NDArrayFactory::create('c', {3, 5, 4}, {1.0, 6.0, 11.0, 16.0, 2.0, 7.0, 12.0, 17.0, 3.0, 8.0, 13.0, 18.0, 4.0, 9.0, 14.0, 19.0, 5.0, 10.0, 15.0, 20.0, 21.0, 26.0, 31.0, 36.0, 22.0, 27.0, 32.0, 37.0, 23.0, 28.0, 33.0, 38.0, 24.0, 29.0, 34.0, 39.0, 25.0, 30.0, 35.0, 40.0, 41.0, 46.0, 51.0, 56.0, 42.0, 47.0, 52.0, 57.0, 43.0, 48.0, 53.0, 58.0, 44.0, 49.0, 54.0, 59.0, 45.0, 50.0, 55.0, 60.0}); - x.linspace(1); - x.reshapei('c', {3, 4, 5}); + auto x = NDArrayFactory::create('c', {1, 60}); + auto exp = NDArrayFactory::create( + 'c', {3, 5, 4}, + {1.0, 6.0, 11.0, 16.0, 2.0, 7.0, 12.0, 17.0, 3.0, 8.0, 13.0, 18.0, + 4.0, 9.0, 14.0, 19.0, 5.0, 10.0, 15.0, 20.0, 21.0, 26.0, 31.0, 36.0, + 22.0, 27.0, 32.0, 37.0, 23.0, 28.0, 33.0, 38.0, 24.0, 29.0, 34.0, 39.0, + 25.0, 30.0, 35.0, 40.0, 41.0, 46.0, 51.0, 56.0, 42.0, 47.0, 52.0, 57.0, + 43.0, 48.0, 53.0, 58.0, 44.0, 49.0, 54.0, 59.0, 45.0, 50.0, 55.0, 60.0}); + x.linspace(1); + x.reshapei('c', {3, 4, 5}); - x.permutei({0, 2, 1}); - x.streamline(); + x.permutei({0, 2, 1}); + x.streamline(); -// x.printShapeInfo("{0, 2, 1} shape"); -// x.printBuffer("{0, 2, 1} data"); + // x.printShapeInfo("{0, 2, 1} shape"); + // x.printBuffer("{0, 2, 1} data"); - ASSERT_TRUE(exp.isSameShape(&x)); - ASSERT_TRUE(exp.equalsTo(&x)); + ASSERT_TRUE(exp.isSameShape(&x)); + ASSERT_TRUE(exp.equalsTo(&x)); } TEST_F(NDArrayTest2, Test_PermuteEquality_0) { - auto x = NDArrayFactory::create('c', {1, 60}); - x.linspace(1); - auto exp = NDArrayFactory::create('c', {3, 4, 5}, {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, 32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0, 40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0, 48.0, 49.0, 50.0, 51.0, 52.0, 53.0, 54.0, 55.0, 56.0, 57.0, 58.0, 59.0, 60.0}); - x.reshapei('c', {3, 4, 5}); + auto x = NDArrayFactory::create('c', {1, 60}); + x.linspace(1); + auto exp = NDArrayFactory::create( + 'c', {3, 4, 5}, + {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, + 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, + 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, 32.0, 33.0, 34.0, 35.0, 36.0, + 37.0, 38.0, 39.0, 40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0, 48.0, + 49.0, 50.0, 51.0, 52.0, 53.0, 54.0, 55.0, 56.0, 57.0, 58.0, 59.0, 60.0}); + x.reshapei('c', {3, 4, 5}); - x.permutei({0, 1, 2}); - x.streamline(); + x.permutei({0, 1, 2}); + x.streamline(); -// x.printShapeInfo("{0, 1, 2} shape"); -// x.printBuffer("{0, 1, 2} data"); + // x.printShapeInfo("{0, 1, 2} shape"); + // x.printBuffer("{0, 1, 2} data"); - ASSERT_TRUE(exp.isSameShape(&x)); - ASSERT_TRUE(exp.equalsTo(&x)); + ASSERT_TRUE(exp.isSameShape(&x)); + ASSERT_TRUE(exp.equalsTo(&x)); } - TEST_F(NDArrayTest2, Test_PermuteEquality_2) { - auto x = NDArrayFactory::create('c', {1, 60}); - x.linspace(1); - auto exp = NDArrayFactory::create('c', {4, 3, 5}, {1.0, 2.0, 3.0, 4.0, 5.0, 21.0, 22.0, 23.0, 24.0, 25.0, 41.0, 42.0, 43.0, 44.0, 45.0, 6.0, 7.0, 8.0, 9.0, 10.0, 26.0, 27.0, 28.0, 29.0, 30.0, 46.0, 47.0, 48.0, 49.0, 50.0, 11.0, 12.0, 13.0, 14.0, 15.0, 31.0, 32.0, 33.0, 34.0, 35.0, 51.0, 52.0, 53.0, 54.0, 55.0, 16.0, 17.0, 18.0, 19.0, 20.0, 36.0, 37.0, 38.0, 39.0, 40.0, 56.0, 57.0, 58.0, 59.0, 60.0}); - x.reshapei('c', {3, 4, 5}); + auto x = NDArrayFactory::create('c', {1, 60}); + x.linspace(1); + auto exp = NDArrayFactory::create( + 'c', {4, 3, 5}, + {1.0, 2.0, 3.0, 4.0, 5.0, 21.0, 22.0, 23.0, 24.0, 25.0, 41.0, 42.0, + 43.0, 44.0, 45.0, 6.0, 7.0, 8.0, 9.0, 10.0, 26.0, 27.0, 28.0, 29.0, + 30.0, 46.0, 47.0, 48.0, 49.0, 50.0, 11.0, 12.0, 13.0, 14.0, 15.0, 31.0, + 32.0, 33.0, 34.0, 35.0, 51.0, 52.0, 53.0, 54.0, 55.0, 16.0, 17.0, 18.0, + 19.0, 20.0, 36.0, 37.0, 38.0, 39.0, 40.0, 56.0, 57.0, 58.0, 59.0, 60.0}); + x.reshapei('c', {3, 4, 5}); - x.permutei({1, 0, 2}); - x.streamline(); + x.permutei({1, 0, 2}); + x.streamline(); -// x.printShapeInfo("{1, 0, 2} shape"); -// x.printBuffer("{1, 0, 2} data"); + // x.printShapeInfo("{1, 0, 2} shape"); + // x.printBuffer("{1, 0, 2} data"); - ASSERT_TRUE(exp.isSameShape(&x)); - ASSERT_TRUE(exp.equalsTo(&x)); + ASSERT_TRUE(exp.isSameShape(&x)); + ASSERT_TRUE(exp.equalsTo(&x)); } TEST_F(NDArrayTest2, Test_PermuteEquality_3) { - auto x = NDArrayFactory::create('c', {1, 60}); - x.linspace(1); - auto exp = NDArrayFactory::create('c', {4, 5, 3}, {1.0, 21.0, 41.0, 2.0, 22.0, 42.0, 3.0, 23.0, 43.0, 4.0, 24.0, 44.0, 5.0, 25.0, 45.0, 6.0, 26.0, 46.0, 7.0, 27.0, 47.0, 8.0, 28.0, 48.0, 9.0, 29.0, 49.0, 10.0, 30.0, 50.0, 11.0, 31.0, 51.0, 12.0, 32.0, 52.0, 13.0, 33.0, 53.0, 14.0, 34.0, 54.0, 15.0, 35.0, 55.0, 16.0, 36.0, 56.0, 17.0, 37.0, 57.0, 18.0, 38.0, 58.0, 19.0, 39.0, 59.0, 20.0, 40.0, 60.0}); - x.reshapei('c', {3, 4, 5}); + auto x = NDArrayFactory::create('c', {1, 60}); + x.linspace(1); + auto exp = NDArrayFactory::create( + 'c', {4, 5, 3}, + {1.0, 21.0, 41.0, 2.0, 22.0, 42.0, 3.0, 23.0, 43.0, 4.0, 24.0, 44.0, + 5.0, 25.0, 45.0, 6.0, 26.0, 46.0, 7.0, 27.0, 47.0, 8.0, 28.0, 48.0, + 9.0, 29.0, 49.0, 10.0, 30.0, 50.0, 11.0, 31.0, 51.0, 12.0, 32.0, 52.0, + 13.0, 33.0, 53.0, 14.0, 34.0, 54.0, 15.0, 35.0, 55.0, 16.0, 36.0, 56.0, + 17.0, 37.0, 57.0, 18.0, 38.0, 58.0, 19.0, 39.0, 59.0, 20.0, 40.0, 60.0}); + x.reshapei('c', {3, 4, 5}); - x.permutei({1, 2, 0}); - x.streamline(); + x.permutei({1, 2, 0}); + x.streamline(); -// x.printShapeInfo("{1, 2, 0} shape"); -// x.printBuffer("{1, 2, 0} data"); + // x.printShapeInfo("{1, 2, 0} shape"); + // x.printBuffer("{1, 2, 0} data"); - ASSERT_TRUE(exp.isSameShape(&x)); - ASSERT_TRUE(exp.equalsTo(&x)); + ASSERT_TRUE(exp.isSameShape(&x)); + ASSERT_TRUE(exp.equalsTo(&x)); } TEST_F(NDArrayTest2, Test_PermuteEquality_4) { - auto x = NDArrayFactory::create('c', {1, 60}); - x.linspace(1); - auto exp = NDArrayFactory::create('c', {5, 3, 4}, {1.0, 6.0, 11.0, 16.0, 21.0, 26.0, 31.0, 36.0, 41.0, 46.0, 51.0, 56.0, 2.0, 7.0, 12.0, 17.0, 22.0, 27.0, 32.0, 37.0, 42.0, 47.0, 52.0, 57.0, 3.0, 8.0, 13.0, 18.0, 23.0, 28.0, 33.0, 38.0, 43.0, 48.0, 53.0, 58.0, 4.0, 9.0, 14.0, 19.0, 24.0, 29.0, 34.0, 39.0, 44.0, 49.0, 54.0, 59.0, 5.0, 10.0, 15.0, 20.0, 25.0, 30.0, 35.0, 40.0, 45.0, 50.0, 55.0, 60.0}); - x.reshapei('c', {3, 4, 5}); + auto x = NDArrayFactory::create('c', {1, 60}); + x.linspace(1); + auto exp = NDArrayFactory::create( + 'c', {5, 3, 4}, + {1.0, 6.0, 11.0, 16.0, 21.0, 26.0, 31.0, 36.0, 41.0, 46.0, 51.0, 56.0, + 2.0, 7.0, 12.0, 17.0, 22.0, 27.0, 32.0, 37.0, 42.0, 47.0, 52.0, 57.0, + 3.0, 8.0, 13.0, 18.0, 23.0, 28.0, 33.0, 38.0, 43.0, 48.0, 53.0, 58.0, + 4.0, 9.0, 14.0, 19.0, 24.0, 29.0, 34.0, 39.0, 44.0, 49.0, 54.0, 59.0, + 5.0, 10.0, 15.0, 20.0, 25.0, 30.0, 35.0, 40.0, 45.0, 50.0, 55.0, 60.0}); + x.reshapei('c', {3, 4, 5}); - x.permutei({2, 0, 1}); - x.streamline(); + x.permutei({2, 0, 1}); + x.streamline(); -// x.printShapeInfo("{2, 0, 1} shape"); -// x.printBuffer("{2, 0, 1} data"); + // x.printShapeInfo("{2, 0, 1} shape"); + // x.printBuffer("{2, 0, 1} data"); - ASSERT_TRUE(exp.isSameShape(&x)); - ASSERT_TRUE(exp.equalsTo(&x)); + ASSERT_TRUE(exp.isSameShape(&x)); + ASSERT_TRUE(exp.equalsTo(&x)); } TEST_F(NDArrayTest2, Test_PermuteEquality_5) { - auto x = NDArrayFactory::create('c', {1, 60}); - x.linspace(1); - auto exp = NDArrayFactory::create('c', {5, 4, 3}, - {1.0, 21.0, 41.0, 6.0, 26.0, 46.0, 11.0, 31.0, 51.0, 16.0, 36.0, 56.0, 2.0, 22.0, 42.0, 7.0, - 27.0, 47.0, 12.0, 32.0, 52.0, 17.0, 37.0, 57.0, 3.0, 23.0, 43.0, 8.0, 28.0, 48.0, 13.0, 33.0, - 53.0, 18.0, 38.0, 58.0, 4.0, 24.0, 44.0, 9.0, 29.0, 49.0, 14.0, 34.0, 54.0, 19.0, 39.0, 59.0, - 5.0, 25.0, 45.0, 10.0, 30.0, 50.0, 15.0, 35.0, 55.0, 20.0, 40.0, 60.0}); - x.reshapei('c', {3, 4, 5}); - - x.permutei({2, 1, 0}); - x.streamline(); + auto x = NDArrayFactory::create('c', {1, 60}); + x.linspace(1); + auto exp = NDArrayFactory::create( + 'c', {5, 4, 3}, + {1.0, 21.0, 41.0, 6.0, 26.0, 46.0, 11.0, 31.0, 51.0, 16.0, 36.0, 56.0, + 2.0, 22.0, 42.0, 7.0, 27.0, 47.0, 12.0, 32.0, 52.0, 17.0, 37.0, 57.0, + 3.0, 23.0, 43.0, 8.0, 28.0, 48.0, 13.0, 33.0, 53.0, 18.0, 38.0, 58.0, + 4.0, 24.0, 44.0, 9.0, 29.0, 49.0, 14.0, 34.0, 54.0, 19.0, 39.0, 59.0, + 5.0, 25.0, 45.0, 10.0, 30.0, 50.0, 15.0, 35.0, 55.0, 20.0, 40.0, 60.0}); + x.reshapei('c', {3, 4, 5}); -// x.printShapeInfo("{2, 0, 1} shape"); -// x.printBuffer("{2, 0, 1} data"); + x.permutei({2, 1, 0}); + x.streamline(); - ASSERT_TRUE(exp.isSameShape(&x)); - ASSERT_TRUE(exp.equalsTo(&x)); + // x.printShapeInfo("{2, 0, 1} shape"); + // x.printBuffer("{2, 0, 1} data"); + ASSERT_TRUE(exp.isSameShape(&x)); + ASSERT_TRUE(exp.equalsTo(&x)); } //////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest2, fillAsTriangular_test1) { + auto x = NDArrayFactory::create( + 'c', {4, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); + auto exp = NDArrayFactory::create( + 'c', {4, 4}, {1, 0, 0, 0, 5, 6, 0, 0, 9, 10, 11, 0, 13, 14, 15, 16}); - auto x = NDArrayFactory::create('c', {4, 4}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16}); - auto exp = NDArrayFactory::create('c', {4, 4}, {1,0,0,0,5,6,0,0,9,10,11,0 ,13,14,15,16}); - - x.fillAsTriangular(0., 0, 0, x, 'u'); - - ASSERT_TRUE(exp.isSameShape(&x)); - ASSERT_TRUE(exp.equalsTo(&x)); + x.fillAsTriangular(0., 0, 0, x, 'u'); + ASSERT_TRUE(exp.isSameShape(&x)); + ASSERT_TRUE(exp.equalsTo(&x)); } //////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest2, fillAsTriangular_test2) { + auto x = NDArrayFactory::create( + 'c', {4, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); + auto exp = NDArrayFactory::create( + 'c', {4, 4}, {0, 0, 0, 0, 5, 0, 0, 0, 9, 10, 0, 0, 13, 14, 15, 0}); - auto x = NDArrayFactory::create('c', {4, 4}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16}); - auto exp = NDArrayFactory::create('c', {4, 4}, {0,0,0,0,5,0,0,0,9,10,0 ,0 ,13,14,15,0}); - - x.fillAsTriangular(0., 0, -1, x, 'u'); - - ASSERT_TRUE(exp.isSameShape(&x)); - ASSERT_TRUE(exp.equalsTo(&x)); + x.fillAsTriangular(0., 0, -1, x, 'u'); + ASSERT_TRUE(exp.isSameShape(&x)); + ASSERT_TRUE(exp.equalsTo(&x)); } //////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest2, fillAsTriangular_test3) { + auto x = NDArrayFactory::create( + 'c', {4, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); + auto exp = NDArrayFactory::create( + 'c', {4, 4}, {1, 2, 3, 4, 0, 6, 7, 8, 0, 0, 11, 12, 0, 0, 0, 16}); - auto x = NDArrayFactory::create('c', {4, 4}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16}); - auto exp = NDArrayFactory::create('c', {4, 4}, {1,2,3,4,0,6,7,8,0,0 ,11,12,0 ,0 , 0,16}); - - x.fillAsTriangular(0., 0, 0, x, 'l'); - - ASSERT_TRUE(exp.isSameShape(&x)); - ASSERT_TRUE(exp.equalsTo(&x)); + x.fillAsTriangular(0., 0, 0, x, 'l'); + ASSERT_TRUE(exp.isSameShape(&x)); + ASSERT_TRUE(exp.equalsTo(&x)); } //////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest2, fillAsTriangular_test4) { + auto x = NDArrayFactory::create( + 'c', {4, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); + auto exp = NDArrayFactory::create( + 'c', {4, 4}, {0, 2, 3, 4, 0, 0, 7, 8, 0, 0, 0, 12, 0, 0, 0, 0}); - auto x = NDArrayFactory::create('c', {4, 4}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16}); - auto exp = NDArrayFactory::create('c', {4, 4}, {0,2,3,4,0,0,7,8,0,0 , 0,12, 0, 0, 0, 0}); - - x.fillAsTriangular(0., 1, 0, x, 'l'); + x.fillAsTriangular(0., 1, 0, x, 'l'); - ASSERT_TRUE(exp.isSameShape(&x)); - ASSERT_TRUE(exp.equalsTo(&x)); + ASSERT_TRUE(exp.isSameShape(&x)); + ASSERT_TRUE(exp.equalsTo(&x)); } //////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest2, Test_DType_Conversion_1) { - auto x = NDArrayFactory::create('c', {2, 3}, {1, 2, 3, 4, 5, 6}); + auto x = NDArrayFactory::create('c', {2, 3}, {1, 2, 3, 4, 5, 6}); - auto xd = x.template asT(); + auto xd = x.template asT(); - auto xf = xd.template asT(); + auto xf = xd.template asT(); - ASSERT_TRUE(x.isSameShape(xf)); - ASSERT_TRUE(x.equalsTo(xf)); + ASSERT_TRUE(x.isSameShape(xf)); + ASSERT_TRUE(x.equalsTo(xf)); } //////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest2, Test_ScalarArray_Assign_1) { - auto x = NDArrayFactory::create('c', {2, 2}); - auto y = NDArrayFactory::create(2.0f); - auto exp = NDArrayFactory::create('c', {2, 2}, {2.0f, 2.0f, 2.0f, 2.0f}); + auto x = NDArrayFactory::create('c', {2, 2}); + auto y = NDArrayFactory::create(2.0f); + auto exp = + NDArrayFactory::create('c', {2, 2}, {2.0f, 2.0f, 2.0f, 2.0f}); - x.assign(y); + x.assign(y); - ASSERT_TRUE(exp.isSameShape(&x)); - ASSERT_TRUE(exp.equalsTo(&x)); + ASSERT_TRUE(exp.isSameShape(&x)); + ASSERT_TRUE(exp.equalsTo(&x)); } //////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest2, Test_Reshape_To_Vector_1) { - auto x = NDArrayFactory::create('c', {2, 3}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); - auto exp = NDArrayFactory::create('c', {6}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); + auto x = NDArrayFactory::create('c', {2, 3}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); + auto exp = + NDArrayFactory::create('c', {6}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); - x.reshapei({-1}); + x.reshapei({-1}); - ASSERT_TRUE(exp.isSameShape(x)); - ASSERT_TRUE(exp.equalsTo(x)); + ASSERT_TRUE(exp.isSameShape(x)); + ASSERT_TRUE(exp.equalsTo(x)); } - TEST_F(NDArrayTest2, Test_toIndexedString_1) { - auto x = NDArrayFactory::create('c', {2, 2}, {1.5f, 2.5f, 3.f, 4.5f}); + auto x = NDArrayFactory::create('c', {2, 2}, {1.5f, 2.5f, 3.f, 4.5f}); - auto str = x.asIndexedString(); - std::string exp = "[1.5, 2.5, 3, 4.5]"; + auto str = x.asIndexedString(); + std::string exp = "[1.5, 2.5, 3, 4.5]"; - ASSERT_EQ(exp, str); + ASSERT_EQ(exp, str); } - ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest2, permute_test4) { + Nd4jLong arr1ShapeInfo[] = {6, 1, 1, 4, 3, 2, 2, 48, + 48, 12, 4, 2, 1, 8192, 1, 99}; + Nd4jLong arr2ShapeInfo[] = {6, 1, 2, 2, 1, 4, 3, 48, + 2, 1, 48, 12, 4, 8192, 0, 99}; - Nd4jLong arr1ShapeInfo[] = {6, 1, 1, 4, 3, 2, 2, 48, 48, 12, 4, 2, 1, 8192, 1, 99}; - Nd4jLong arr2ShapeInfo[] = {6, 1, 2, 2, 1, 4, 3, 48, 2, 1, 48, 12, 4, 8192, 0, 99}; + auto arr1Buffer = new float[786432]; + auto arr2Buffer = new float[786432]; + NDArray arr1(arr1Buffer, arr1ShapeInfo, sd::LaunchContext ::defaultContext()); + NDArray arr2(arr2Buffer, arr2ShapeInfo, sd::LaunchContext ::defaultContext()); - auto arr1Buffer = new float[786432]; - auto arr2Buffer = new float[786432]; + const std::vector perm = {0, 4, 5, 1, 2, 3}; + auto arr1P = arr1.permute(perm); + // arr1P->printShapeInfo(); - NDArray arr1(arr1Buffer, arr1ShapeInfo, sd::LaunchContext ::defaultContext()); - NDArray arr2(arr2Buffer, arr2ShapeInfo, sd::LaunchContext ::defaultContext()); - - const std::vector perm = {0, 4, 5, 1, 2, 3}; - auto arr1P = arr1.permute(perm); - // arr1P->printShapeInfo(); - - // ASSERT_TRUE(arr1.isSameShapeStrict(&arr2)); - ASSERT_TRUE(arr1P.isSameShapeStrict(arr2)); - delete []arr1Buffer; - delete []arr2Buffer; + // ASSERT_TRUE(arr1.isSameShapeStrict(&arr2)); + ASSERT_TRUE(arr1P.isSameShapeStrict(arr2)); + delete[] arr1Buffer; + delete[] arr2Buffer; } //////////////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest2, TestStdDev3) { + // autoarray('c', {10, 10}); + auto array = NDArrayFactory::create('c', {2, 2}, + {0.2946, 0.2084, 0.0345, 0.7368}); + const int len = array.lengthOf(); - // autoarray('c', {10, 10}); - auto array = NDArrayFactory::create('c', {2, 2}, {0.2946, 0.2084, 0.0345, 0.7368}); - const int len = array.lengthOf(); + double sum = 0.; + for (int i = 0; i < len; ++i) sum += array.e(i); - double sum = 0.; - for(int i=0; i < len; ++i) - sum += array.e(i); + const double mean = sum / len; - const double mean = sum / len; + double diffSquared = 0.; + for (int i = 0; i < len; ++i) + diffSquared += (array.e(i) - mean) * (array.e(i) - mean); - double diffSquared = 0.; - for(int i=0; i < len; ++i) - diffSquared += (array.e(i) - mean) * (array.e(i) - mean); + const double trueVariance = + math::nd4j_sqrt(diffSquared / len); + const double trueVarianceCorr = + math::nd4j_sqrt(diffSquared / (len - 1)); - const double trueVariance = math::nd4j_sqrt(diffSquared / len); - const double trueVarianceCorr = math::nd4j_sqrt(diffSquared / (len - 1)); + const double variance = + array.varianceNumber(variance::SummaryStatsStandardDeviation, false) + .e(0); + const double varianceCorr = + array.varianceNumber(variance::SummaryStatsStandardDeviation, true) + .e(0); - const double variance = array.varianceNumber(variance::SummaryStatsStandardDeviation, false).e(0); - const double varianceCorr = array.varianceNumber(variance::SummaryStatsStandardDeviation, true).e(0); + // printf("%s expected %.10f calculated %.10f\n","variance :", + // trueVariance, variance ); printf("%s expected %.10f calculated + // %.10f\n","variance corrected:", trueVarianceCorr, varianceCorr); - // printf("%s expected %.10f calculated %.10f\n","variance :", trueVariance, variance ); - // printf("%s expected %.10f calculated %.10f\n","variance corrected:", trueVarianceCorr, varianceCorr); - - ASSERT_NEAR(trueVariance, variance, 1e-8); - ASSERT_NEAR(trueVarianceCorr, varianceCorr, 1e-8); + ASSERT_NEAR(trueVariance, variance, 1e-8); + ASSERT_NEAR(trueVarianceCorr, varianceCorr, 1e-8); } //////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest2, Test_Linspace_1) { - auto exp = NDArrayFactory::create('c',{1,5}, {1., 2., 3., 4., 5.}); - auto x = NDArrayFactory::create('c', {1, 5}); - x.linspace(1); + auto exp = NDArrayFactory::create('c', {1, 5}, {1., 2., 3., 4., 5.}); + auto x = NDArrayFactory::create('c', {1, 5}); + x.linspace(1); - ASSERT_TRUE(x.equalsTo(&exp)); + ASSERT_TRUE(x.equalsTo(&exp)); } //////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest2, Test_Linspace_2) { - auto exp = NDArrayFactory::create('c',{1,5}, {1., 3., 5., 7., 9.}); - auto x = NDArrayFactory::create('c', {1, 5}); + auto exp = NDArrayFactory::create('c', {1, 5}, {1., 3., 5., 7., 9.}); + auto x = NDArrayFactory::create('c', {1, 5}); - x.linspace(1, 2); + x.linspace(1, 2); - ASSERT_TRUE(x.equalsTo(&exp)); + ASSERT_TRUE(x.equalsTo(&exp)); } //////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest2, Test_Linspace_3) { + auto exp = + NDArrayFactory::create('c', {1, 5}, {1., 4., 7., 10., 13.}); - auto exp = NDArrayFactory::create('c',{1,5}, {1., 4., 7., 10., 13.}); - - auto x = NDArrayFactory::create('c', {1, 5}); - x.linspace(1,3); + auto x = NDArrayFactory::create('c', {1, 5}); + x.linspace(1, 3); - ASSERT_TRUE(x.equalsTo(&exp)); + ASSERT_TRUE(x.equalsTo(&exp)); } //////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest2, Test_Linspace_4) { - auto exp = NDArrayFactory::create('c',{1,5}, {-1., -2., -3., -4., -5.}); + auto exp = + NDArrayFactory::create('c', {1, 5}, {-1., -2., -3., -4., -5.}); - auto x = NDArrayFactory::create('c', {1, 5}); - x.linspace(-1, -1); + auto x = NDArrayFactory::create('c', {1, 5}); + x.linspace(-1, -1); - ASSERT_TRUE(x.equalsTo(&exp)); + ASSERT_TRUE(x.equalsTo(&exp)); } //////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest2, Test_Linspace_5) { - auto exp = NDArrayFactory::create('c',{1,5}, {9., 8., 7., 6., 5.}); + auto exp = NDArrayFactory::create('c', {1, 5}, {9., 8., 7., 6., 5.}); - auto x = NDArrayFactory::create('c', {1, 5}); - x.linspace(9, -1); + auto x = NDArrayFactory::create('c', {1, 5}); + x.linspace(9, -1); - ASSERT_TRUE(x.equalsTo(&exp)); + ASSERT_TRUE(x.equalsTo(&exp)); } - //////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest2, allTensorsAlongDimension_test1) { + auto x = NDArrayFactory::create('c', {4}, {1, 2, 3, 4}); + auto exp = NDArrayFactory::create('c', {4}, {1, 2, 3, 4}); - auto x = NDArrayFactory::create('c', {4}, {1, 2, 3, 4}); - auto exp = NDArrayFactory::create('c', {4}, {1, 2, 3, 4}); - - auto set = x.allTensorsAlongDimension({0}); - // set->at(0)->printShapeInfo(); - // set->at(0)->printIndexedBuffer(); + auto set = x.allTensorsAlongDimension({0}); + // set->at(0)->printShapeInfo(); + // set->at(0)->printIndexedBuffer(); - ASSERT_TRUE(set.size() == 1); - ASSERT_TRUE(exp.isSameShape(set.at(0))); - ASSERT_TRUE(exp.equalsTo(set.at(0))); + ASSERT_TRUE(set.size() == 1); + ASSERT_TRUE(exp.isSameShape(set.at(0))); + ASSERT_TRUE(exp.equalsTo(set.at(0))); } //////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest2, scalar_get_test1) { + auto scalar1 = NDArrayFactory::create(20.f); - auto scalar1 = NDArrayFactory::create(20.f); + NDArray arr('c', {2, 2}, {0., 10., 20., 30.}, sd::DataType::FLOAT32); - NDArray arr('c', {2,2}, {0., 10., 20., 30.}, sd::DataType::FLOAT32); + NDArray scalar2 = arr.e(2); - NDArray scalar2 = arr.e(2); - - ASSERT_TRUE(scalar1.isSameShape(scalar2)); - ASSERT_TRUE(scalar1.equalsTo(scalar2)); - ASSERT_TRUE(scalar1.dataType() == scalar2.dataType()); + ASSERT_TRUE(scalar1.isSameShape(scalar2)); + ASSERT_TRUE(scalar1.equalsTo(scalar2)); + ASSERT_TRUE(scalar1.dataType() == scalar2.dataType()); } //////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest2, scalar_get_test2) { + auto scalar1 = NDArrayFactory::create(20.f); - auto scalar1 = NDArrayFactory::create(20.f); - - NDArray arr('f', {2,2}, {0., 10., 20., 30.}, sd::DataType::FLOAT32); + NDArray arr('f', {2, 2}, {0., 10., 20., 30.}, sd::DataType::FLOAT32); - NDArray scalar2 = arr.e(1); + NDArray scalar2 = arr.e(1); - ASSERT_TRUE(scalar1.isSameShape(scalar2)); - ASSERT_TRUE(scalar1.equalsTo(scalar2)); - ASSERT_TRUE(scalar1.dataType() == scalar2.dataType()); + ASSERT_TRUE(scalar1.isSameShape(scalar2)); + ASSERT_TRUE(scalar1.equalsTo(scalar2)); + ASSERT_TRUE(scalar1.dataType() == scalar2.dataType()); } //////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest2, scalar_set_test1) { + NDArray scalar1 = NDArrayFactory::create(20.f); - NDArray scalar1 = NDArrayFactory::create(20.f); + NDArray arr('c', {2, 2}, {0., 10., -20., 30.}, sd::DataType::FLOAT32); + NDArray exp('c', {2, 2}, {0., 10., 20., 30.}, sd::DataType::FLOAT32); - NDArray arr('c', {2,2}, {0., 10., -20., 30.}, sd::DataType::FLOAT32); - NDArray exp('c', {2,2}, {0., 10., 20., 30.}, sd::DataType::FLOAT32); + arr.p(2, scalar1); - arr.p(2, scalar1); - - ASSERT_TRUE(exp.equalsTo(arr)); + ASSERT_TRUE(exp.equalsTo(arr)); } - //////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest2, scalar_set_test2) { + NDArray scalar1 = NDArrayFactory::create(20.f); - NDArray scalar1 = NDArrayFactory::create(20.f); - - NDArray arr('f', {2,2}, {0., 10., -20., 30.}, sd::DataType::FLOAT32); - NDArray exp('f', {2,2}, {0., 10., 20., 30.}, sd::DataType::FLOAT32); + NDArray arr('f', {2, 2}, {0., 10., -20., 30.}, sd::DataType::FLOAT32); + NDArray exp('f', {2, 2}, {0., 10., 20., 30.}, sd::DataType::FLOAT32); - arr.p(1, scalar1); + arr.p(1, scalar1); - ASSERT_TRUE(exp.equalsTo(arr)); + ASSERT_TRUE(exp.equalsTo(arr)); } TEST_F(NDArrayTest2, big_dup_test) { - // auto arr = NDArrayFactory::linspace(1.0f, 10000000.0f, 100000000); - auto arr = NDArrayFactory::linspace(1.0f, 1000.0f, 10000); - auto dup = new NDArray(arr->dup('c')); + // auto arr = NDArrayFactory::linspace(1.0f, 10000000.0f, 100000000); + auto arr = NDArrayFactory::linspace(1.0f, 1000.0f, 10000); + auto dup = new NDArray(arr->dup('c')); - ASSERT_EQ(*arr, *dup); + ASSERT_EQ(*arr, *dup); - delete arr; - delete dup; + delete arr; + delete dup; } TEST_F(NDArrayTest2, debugInfoTest_1) { - NDArray testArray('c', {2, 4, 4, 4}, { - 91., 82., 37., 64., 55., 46., 73., 28., 119., 12., 112., 13., 14., 114., 16., 117., - 51., 42., 67., 24., 15., 56., 93., 28., 109., 82., 12., 113., 114., 14., 116., 11., - 31., 22., 87., 44., 55., 46., 73., 28., -119., 12., 112., 13., 14., 114., 16., 117., - 91., -82., 37., 64., -55.1, 0, 73., 28., -119., 12., 112., 13., 14., 114., 16.2, 117., - 91., -82., 37., 64., 55., 46., 73., 28., -119., 12., 112., 13., 14., 114., 16., 117., - 51., 42., 67., 24., 15., 0., 93., 28., 109., 82., 12., 113., 114., 14., 116., 11., - 31., 22., 87., 44., 55., 46., 73., 28., 119., 12., 112., 13., 14., 114., 16., 117., - 91., 82., 37., 64., -3, 0, 73., 28., 119., 12., 112., 13., 140., 110., 160., 107.}, sd::DataType::DOUBLE); - NDArray res(sd::DataType::DOUBLE); - DebugInfo info = DebugHelper::debugStatistics(&testArray); - DebugInfo exp; // = {} - sd::ops::reduce_min minOp; - sd::ops::reduce_mean meanOp; - sd::ops::reduce_max maxOp; - sd::ops::reduce_stdev stdevOp; - - minOp.execute({&testArray}, {&res}, {}, {}, {}); - exp._minValue = res.e(0); - meanOp.execute({&testArray}, {&res}, {}, {}, {}); - exp._meanValue = res.e(0); - maxOp.execute({&testArray}, {&res}, {}, {}, {}); - exp._maxValue = res.e(0); - stdevOp.execute({&testArray}, {&res}, {}, {}, {}); - exp._stdDevValue = res.e(0); - exp._zeroCount = 3; - exp._negativeCount = 7; - exp._positiveCount = 118; - exp._infCount = 0; - exp._nanCount = 0; - printf("Output statistics %lf %lf %lf %lf\n", info._minValue, info._maxValue, info._meanValue, info._stdDevValue); - printf("Expect statistics %lf %lf %lf %lf\n", exp._minValue, exp._maxValue, exp._meanValue, exp._stdDevValue); - printf("%lld %lld %lld %lld %lld\n", info._zeroCount, info._negativeCount, info._positiveCount, info._infCount, info._nanCount); - ASSERT_EQ(exp, info); + NDArray testArray( + 'c', {2, 4, 4, 4}, + {91., 82., 37., 64., 55., 46., 73., 28., 119., 12., 112., + 13., 14., 114., 16., 117., 51., 42., 67., 24., 15., 56., + 93., 28., 109., 82., 12., 113., 114., 14., 116., 11., 31., + 22., 87., 44., 55., 46., 73., 28., -119., 12., 112., 13., + 14., 114., 16., 117., 91., -82., 37., 64., -55.1, 0, 73., + 28., -119., 12., 112., 13., 14., 114., 16.2, 117., 91., -82., + 37., 64., 55., 46., 73., 28., -119., 12., 112., 13., 14., + 114., 16., 117., 51., 42., 67., 24., 15., 0., 93., 28., + 109., 82., 12., 113., 114., 14., 116., 11., 31., 22., 87., + 44., 55., 46., 73., 28., 119., 12., 112., 13., 14., 114., + 16., 117., 91., 82., 37., 64., -3, 0, 73., 28., 119., + 12., 112., 13., 140., 110., 160., 107.}, + sd::DataType::DOUBLE); + NDArray res(sd::DataType::DOUBLE); + DebugInfo info = DebugHelper::debugStatistics(&testArray); + DebugInfo exp; // = {} + sd::ops::reduce_min minOp; + sd::ops::reduce_mean meanOp; + sd::ops::reduce_max maxOp; + sd::ops::reduce_stdev stdevOp; + + minOp.execute({&testArray}, {&res}, {}, {}, {}); + exp._minValue = res.e(0); + meanOp.execute({&testArray}, {&res}, {}, {}, {}); + exp._meanValue = res.e(0); + maxOp.execute({&testArray}, {&res}, {}, {}, {}); + exp._maxValue = res.e(0); + stdevOp.execute({&testArray}, {&res}, {}, {}, {}); + exp._stdDevValue = res.e(0); + exp._zeroCount = 3; + exp._negativeCount = 7; + exp._positiveCount = 118; + exp._infCount = 0; + exp._nanCount = 0; + printf("Output statistics %lf %lf %lf %lf\n", info._minValue, info._maxValue, + info._meanValue, info._stdDevValue); + printf("Expect statistics %lf %lf %lf %lf\n", exp._minValue, exp._maxValue, + exp._meanValue, exp._stdDevValue); + printf("%lld %lld %lld %lld %lld\n", info._zeroCount, info._negativeCount, + info._positiveCount, info._infCount, info._nanCount); + ASSERT_EQ(exp, info); } TEST_F(NDArrayTest2, debugInfoTest_2) { - NDArray testArray('c', {2, 4, 4, 4}, { - 91., 82., 37., 64., 55., 46., 73., 28., 119., 12., 112., 13., 14., 114., 16., 117., - 51., 42., 67., 24., 15., 56., 93., 28., 109., 82., 12., 113., 114., 14., 116., 11., - 31., 22., 87., 44., 55., 46., 73., 28., -119., 12., 112., 13., 14., 114., 16., 117., - 91., -82., 37., 64., -55.1, 0, 73., 28., -119., 12., 112., 13., 14., 114., 16.2, 117., - 91., -82., 37., 64., 55., 46., 73., 28., -119., 12., 112., 13., 14., 114., 16., 117., - 51., 42., 67., 24., 15., 0., 93., 28., 109., 82., 12., 113., 114., 14., 116., 11., - 31., 22., 87., 44., 55., 46., 73., 28., 119., 12., 112., 13., 14., 114., 16., 117., - 91., 82., 37., 64., -3, 0, 73., 28., 119., 12., 112., 13., 140., 110., 160., 107.}, sd::DataType::DOUBLE); - - DebugInfo info; - DebugInfo exp; // = {} - exp._minValue = -119; - exp._maxValue = 160.; - exp._meanValue = 51.328906; - exp._stdDevValue = 52.385694; - exp._zeroCount = 3; - exp._negativeCount = 7; - exp._positiveCount = 118; - exp._infCount = 0; - exp._nanCount = 0; - DebugHelper::retrieveDebugStatistics(&info, &testArray); - printf("Output statistics %lf %lf %lf %lf\n", info._minValue, info._maxValue, info._meanValue, info._stdDevValue); - printf("Expect statistics %lf %lf %lf %lf\n", exp._minValue, exp._maxValue, exp._meanValue, exp._stdDevValue); - printf("%lld %lld %lld %lld %lld\n", info._zeroCount, info._negativeCount, info._positiveCount, info._infCount, info._nanCount); - //printf("%lf %lf %lf %lf\n", info._minValue, info._maxValue, info._meanValue, info._stdDevValue); - //printf("%lld %lld %lld %lld %lld\n", info._zeroCount, info._negativeCount, info._positiveCount, info._infCount, info._nanCount); - ASSERT_EQ(exp, info); + NDArray testArray( + 'c', {2, 4, 4, 4}, + {91., 82., 37., 64., 55., 46., 73., 28., 119., 12., 112., + 13., 14., 114., 16., 117., 51., 42., 67., 24., 15., 56., + 93., 28., 109., 82., 12., 113., 114., 14., 116., 11., 31., + 22., 87., 44., 55., 46., 73., 28., -119., 12., 112., 13., + 14., 114., 16., 117., 91., -82., 37., 64., -55.1, 0, 73., + 28., -119., 12., 112., 13., 14., 114., 16.2, 117., 91., -82., + 37., 64., 55., 46., 73., 28., -119., 12., 112., 13., 14., + 114., 16., 117., 51., 42., 67., 24., 15., 0., 93., 28., + 109., 82., 12., 113., 114., 14., 116., 11., 31., 22., 87., + 44., 55., 46., 73., 28., 119., 12., 112., 13., 14., 114., + 16., 117., 91., 82., 37., 64., -3, 0, 73., 28., 119., + 12., 112., 13., 140., 110., 160., 107.}, + sd::DataType::DOUBLE); + + DebugInfo info; + DebugInfo exp; // = {} + exp._minValue = -119; + exp._maxValue = 160.; + exp._meanValue = 51.328906; + exp._stdDevValue = 52.385694; + exp._zeroCount = 3; + exp._negativeCount = 7; + exp._positiveCount = 118; + exp._infCount = 0; + exp._nanCount = 0; + DebugHelper::retrieveDebugStatistics(&info, &testArray); + printf("Output statistics %lf %lf %lf %lf\n", info._minValue, info._maxValue, + info._meanValue, info._stdDevValue); + printf("Expect statistics %lf %lf %lf %lf\n", exp._minValue, exp._maxValue, + exp._meanValue, exp._stdDevValue); + printf("%lld %lld %lld %lld %lld\n", info._zeroCount, info._negativeCount, + info._positiveCount, info._infCount, info._nanCount); + // printf("%lf %lf %lf %lf\n", info._minValue, info._maxValue, + // info._meanValue, info._stdDevValue); printf("%lld %lld %lld %lld %lld\n", + // info._zeroCount, info._negativeCount, info._positiveCount, info._infCount, + // info._nanCount); + ASSERT_EQ(exp, info); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest2, test_subarray_ews_1) { + NDArray x('c', {10, 5}, sd::DataType::FLOAT32); + auto subArr1 = x.subarray({NDIndex::all(), NDIndex::point(2)}); - NDArray x('c', {10, 5}, sd::DataType::FLOAT32); - auto subArr1 = x.subarray({NDIndex::all(), NDIndex::point(2)}); - - ASSERT_EQ(5, subArr1.ews()); + ASSERT_EQ(5, subArr1.ews()); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest2, test_subarray_ews_2) { + NDArray x('f', {10, 5}, sd::DataType::FLOAT32); + auto subArr1 = x.subarray({NDIndex::all(), NDIndex::point(2)}); - NDArray x('f', {10, 5}, sd::DataType::FLOAT32); - auto subArr1 = x.subarray({NDIndex::all(), NDIndex::point(2)}); - - ASSERT_EQ(1, subArr1.ews()); + ASSERT_EQ(1, subArr1.ews()); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest2, test_subarray_ews_3) { + NDArray x('c', {10, 5}, sd::DataType::FLOAT32); + auto subArr1 = x.subarray({NDIndex::point(2), NDIndex::all()}); - NDArray x('c', {10, 5}, sd::DataType::FLOAT32); - auto subArr1 = x.subarray({NDIndex::point(2), NDIndex::all()}); - - ASSERT_EQ(1, subArr1.ews()); + ASSERT_EQ(1, subArr1.ews()); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest2, test_subarray_ews_4) { + NDArray x('f', {10, 5}, sd::DataType::FLOAT32); + auto subArr1 = x.subarray({NDIndex::point(2), NDIndex::all()}); - NDArray x('f', {10, 5}, sd::DataType::FLOAT32); - auto subArr1 = x.subarray({NDIndex::point(2), NDIndex::all()}); - - ASSERT_EQ(10, subArr1.ews()); + ASSERT_EQ(10, subArr1.ews()); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest2, subarray_1) { - - NDArray x('c', {2,3,4}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24}, sd::DataType::FLOAT32); - NDArray y('f', {2,3,4}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24}, sd::DataType::FLOAT32); - - Nd4jLong shapeExpX0[] = {1, 2, 12, 8192, 12, 99}; - float buffExpX0[] = {1.000000, 13.000000}; - float buffExpX1[] = {2.000000, 14.000000}; - Nd4jLong shapeExpX2[] = {3, 2, 1, 1, 12, 4, 1, 8192, 12, 99}; - float buffExpX2[] = {1.000000, 13.000000}; - Nd4jLong shapeExpX3[] = {2, 2, 4, 12, 1, 8192, 0, 99}; - float buffExpX3[] = {9.000000, 10.000000, 11.000000, 12.000000, 21.000000, 22.000000, 23.000000, 24.000000}; - Nd4jLong shapeExpX4[] = {3, 2, 1, 4, 12, 4, 1, 8192, 0, 99}; - float buffExpX4[] = {9.000000, 10.000000, 11.000000, 12.000000, 21.000000, 22.000000, 23.000000, 24.000000}; - Nd4jLong shapeExpX5[] = {2, 2, 3, 12, 4, 8192, 4, 99}; - float buffExpX5[] = {4.000000, 8.000000, 12.000000, 16.000000, 20.000000, 24.000000}; - - Nd4jLong shapeExpY0[] = {1, 2, 1, 8192, 1, 102}; - float buffExpY0[] = {1.000000, 2.000000}; - float buffExpY1[] = {7.000000, 8.000000}; - Nd4jLong shapeExpY2[] = {3, 2, 1, 1, 1, 2, 6, 8192, 1, 102}; - float buffExpY2[] = {1.000000, 2.000000}; - Nd4jLong shapeExpY3[] = {2, 2, 4, 1, 6, 8192, 0, 102}; - float buffExpY3[] = {5.000000, 11.000000, 17.000000, 23.000000, 6.000000, 12.000000, 18.000000, 24.000000}; - Nd4jLong shapeExpY4[] = {3, 2, 1, 4, 1, 2, 6, 8192, 0, 102}; - float buffExpY4[] = {5.000000, 11.000000, 17.000000, 23.000000, 6.000000, 12.000000, 18.000000, 24.000000}; - Nd4jLong shapeExpY5[] = {2, 2, 3, 1, 2, 8192, 1, 102}; - float buffExpY5[] = {19.000000, 21.000000, 23.000000, 20.000000, 22.000000, 24.000000}; - - - NDArray x0 = x(0, {1,2}); - for(int i = 0; i < shape::shapeInfoLength(x0.rankOf()); ++i) - ASSERT_TRUE(x0.shapeInfo()[i] == shapeExpX0[i]); - for(int i = 0; i < x0.lengthOf(); ++i) - ASSERT_TRUE(x0.e(i) == buffExpX0[i]); - - NDArray x1 = x(1, {1,2}); - for(int i = 0; i < shape::shapeInfoLength(x1.rankOf()); ++i) - ASSERT_TRUE(x1.shapeInfo()[i] == shapeExpX0[i]); - for(int i = 0; i < x1.lengthOf(); ++i) - ASSERT_TRUE(x1.e(i) == buffExpX1[i]); - - NDArray x2 = x(0, {1,2}, true); - for(int i = 0; i < shape::shapeInfoLength(x2.rankOf()); ++i) - ASSERT_TRUE(x2.shapeInfo()[i] == shapeExpX2[i]); - for(int i = 0; i < x2.lengthOf(); ++i) - ASSERT_TRUE(x2.e(i) == buffExpX2[i]); - - NDArray x3 = x(2, {1}); - for(int i = 0; i < shape::shapeInfoLength(x3.rankOf()); ++i) - ASSERT_TRUE(x3.shapeInfo()[i] == shapeExpX3[i]); - for(int i = 0; i < x3.lengthOf(); ++i) - ASSERT_TRUE(x3.e(i) == buffExpX3[i]); - - NDArray x4 = x(2, {1}, true); - for(int i = 0; i < shape::shapeInfoLength(x4.rankOf()); ++i) - ASSERT_TRUE(x4.shapeInfo()[i] == shapeExpX4[i]); - for(int i = 0; i < x4.lengthOf(); ++i) - ASSERT_TRUE(x4.e(i) == buffExpX4[i]); - - NDArray x5 = x(3, {2}); - for(int i = 0; i < shape::shapeInfoLength(x5.rankOf()); ++i) - ASSERT_TRUE(x5.shapeInfo()[i] == shapeExpX5[i]); - for(int i = 0; i < x5.lengthOf(); ++i) - ASSERT_TRUE(x5.e(i) == buffExpX5[i]); - - // ******************* // - NDArray y0 = y(0, {1,2}); - for(int i = 0; i < shape::shapeInfoLength(y0.rankOf()); ++i) - ASSERT_TRUE(y0.shapeInfo()[i] == shapeExpY0[i]); - for(int i = 0; i < y0.lengthOf(); ++i) - ASSERT_TRUE(y0.e(i) == buffExpY0[i]); - - NDArray y1 = y(1, {1,2}); - for(int i = 0; i < shape::shapeInfoLength(y1.rankOf()); ++i) - ASSERT_TRUE(y1.shapeInfo()[i] == shapeExpY0[i]); - for(int i = 0; i < y1.lengthOf(); ++i) - ASSERT_TRUE(y1.e(i) == buffExpY1[i]); - - NDArray y2 = y(0, {1,2}, true); - for(int i = 0; i < shape::shapeInfoLength(y2.rankOf()); ++i) - ASSERT_TRUE(y2.shapeInfo()[i] == shapeExpY2[i]); - for(int i = 0; i < y2.lengthOf(); ++i) - ASSERT_TRUE(y2.e(i) == buffExpY2[i]); - - NDArray y3 = y(2, {1}); - for(int i = 0; i < shape::shapeInfoLength(y3.rankOf()); ++i) - ASSERT_TRUE(y3.shapeInfo()[i] == shapeExpY3[i]); - for(int i = 0; i < y3.lengthOf(); ++i) - ASSERT_TRUE(y3.e(i) == buffExpY3[i]); - - NDArray y4 = y(2, {1}, true); - for(int i = 0; i < shape::shapeInfoLength(y4.rankOf()); ++i) - ASSERT_TRUE(y4.shapeInfo()[i] == shapeExpY4[i]); - for(int i = 0; i < y4.lengthOf(); ++i) - ASSERT_TRUE(y4.e(i) == buffExpY4[i]); - - NDArray y5 = y(3, {2}); - for(int i = 0; i < shape::shapeInfoLength(y5.rankOf()); ++i) - ASSERT_TRUE(y5.shapeInfo()[i] == shapeExpY5[i]); - for(int i = 0; i < y5.lengthOf(); ++i) - ASSERT_TRUE(y5.e(i) == buffExpY5[i]); - + NDArray x('c', {2, 3, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}, + sd::DataType::FLOAT32); + NDArray y('f', {2, 3, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}, + sd::DataType::FLOAT32); + + Nd4jLong shapeExpX0[] = {1, 2, 12, 8192, 12, 99}; + float buffExpX0[] = {1.000000, 13.000000}; + float buffExpX1[] = {2.000000, 14.000000}; + Nd4jLong shapeExpX2[] = {3, 2, 1, 1, 12, 4, 1, 8192, 12, 99}; + float buffExpX2[] = {1.000000, 13.000000}; + Nd4jLong shapeExpX3[] = {2, 2, 4, 12, 1, 8192, 0, 99}; + float buffExpX3[] = {9.000000, 10.000000, 11.000000, 12.000000, + 21.000000, 22.000000, 23.000000, 24.000000}; + Nd4jLong shapeExpX4[] = {3, 2, 1, 4, 12, 4, 1, 8192, 0, 99}; + float buffExpX4[] = {9.000000, 10.000000, 11.000000, 12.000000, + 21.000000, 22.000000, 23.000000, 24.000000}; + Nd4jLong shapeExpX5[] = {2, 2, 3, 12, 4, 8192, 4, 99}; + float buffExpX5[] = {4.000000, 8.000000, 12.000000, + 16.000000, 20.000000, 24.000000}; + + Nd4jLong shapeExpY0[] = {1, 2, 1, 8192, 1, 102}; + float buffExpY0[] = {1.000000, 2.000000}; + float buffExpY1[] = {7.000000, 8.000000}; + Nd4jLong shapeExpY2[] = {3, 2, 1, 1, 1, 2, 6, 8192, 1, 102}; + float buffExpY2[] = {1.000000, 2.000000}; + Nd4jLong shapeExpY3[] = {2, 2, 4, 1, 6, 8192, 0, 102}; + float buffExpY3[] = {5.000000, 11.000000, 17.000000, 23.000000, + 6.000000, 12.000000, 18.000000, 24.000000}; + Nd4jLong shapeExpY4[] = {3, 2, 1, 4, 1, 2, 6, 8192, 0, 102}; + float buffExpY4[] = {5.000000, 11.000000, 17.000000, 23.000000, + 6.000000, 12.000000, 18.000000, 24.000000}; + Nd4jLong shapeExpY5[] = {2, 2, 3, 1, 2, 8192, 1, 102}; + float buffExpY5[] = {19.000000, 21.000000, 23.000000, + 20.000000, 22.000000, 24.000000}; + + NDArray x0 = x(0, {1, 2}); + for (int i = 0; i < shape::shapeInfoLength(x0.rankOf()); ++i) + ASSERT_TRUE(x0.shapeInfo()[i] == shapeExpX0[i]); + for (int i = 0; i < x0.lengthOf(); ++i) + ASSERT_TRUE(x0.e(i) == buffExpX0[i]); + + NDArray x1 = x(1, {1, 2}); + for (int i = 0; i < shape::shapeInfoLength(x1.rankOf()); ++i) + ASSERT_TRUE(x1.shapeInfo()[i] == shapeExpX0[i]); + for (int i = 0; i < x1.lengthOf(); ++i) + ASSERT_TRUE(x1.e(i) == buffExpX1[i]); + + NDArray x2 = x(0, {1, 2}, true); + for (int i = 0; i < shape::shapeInfoLength(x2.rankOf()); ++i) + ASSERT_TRUE(x2.shapeInfo()[i] == shapeExpX2[i]); + for (int i = 0; i < x2.lengthOf(); ++i) + ASSERT_TRUE(x2.e(i) == buffExpX2[i]); + + NDArray x3 = x(2, {1}); + for (int i = 0; i < shape::shapeInfoLength(x3.rankOf()); ++i) + ASSERT_TRUE(x3.shapeInfo()[i] == shapeExpX3[i]); + for (int i = 0; i < x3.lengthOf(); ++i) + ASSERT_TRUE(x3.e(i) == buffExpX3[i]); + + NDArray x4 = x(2, {1}, true); + for (int i = 0; i < shape::shapeInfoLength(x4.rankOf()); ++i) + ASSERT_TRUE(x4.shapeInfo()[i] == shapeExpX4[i]); + for (int i = 0; i < x4.lengthOf(); ++i) + ASSERT_TRUE(x4.e(i) == buffExpX4[i]); + + NDArray x5 = x(3, {2}); + for (int i = 0; i < shape::shapeInfoLength(x5.rankOf()); ++i) + ASSERT_TRUE(x5.shapeInfo()[i] == shapeExpX5[i]); + for (int i = 0; i < x5.lengthOf(); ++i) + ASSERT_TRUE(x5.e(i) == buffExpX5[i]); + + // ******************* // + NDArray y0 = y(0, {1, 2}); + for (int i = 0; i < shape::shapeInfoLength(y0.rankOf()); ++i) + ASSERT_TRUE(y0.shapeInfo()[i] == shapeExpY0[i]); + for (int i = 0; i < y0.lengthOf(); ++i) + ASSERT_TRUE(y0.e(i) == buffExpY0[i]); + + NDArray y1 = y(1, {1, 2}); + for (int i = 0; i < shape::shapeInfoLength(y1.rankOf()); ++i) + ASSERT_TRUE(y1.shapeInfo()[i] == shapeExpY0[i]); + for (int i = 0; i < y1.lengthOf(); ++i) + ASSERT_TRUE(y1.e(i) == buffExpY1[i]); + + NDArray y2 = y(0, {1, 2}, true); + for (int i = 0; i < shape::shapeInfoLength(y2.rankOf()); ++i) + ASSERT_TRUE(y2.shapeInfo()[i] == shapeExpY2[i]); + for (int i = 0; i < y2.lengthOf(); ++i) + ASSERT_TRUE(y2.e(i) == buffExpY2[i]); + + NDArray y3 = y(2, {1}); + for (int i = 0; i < shape::shapeInfoLength(y3.rankOf()); ++i) + ASSERT_TRUE(y3.shapeInfo()[i] == shapeExpY3[i]); + for (int i = 0; i < y3.lengthOf(); ++i) + ASSERT_TRUE(y3.e(i) == buffExpY3[i]); + + NDArray y4 = y(2, {1}, true); + for (int i = 0; i < shape::shapeInfoLength(y4.rankOf()); ++i) + ASSERT_TRUE(y4.shapeInfo()[i] == shapeExpY4[i]); + for (int i = 0; i < y4.lengthOf(); ++i) + ASSERT_TRUE(y4.e(i) == buffExpY4[i]); + + NDArray y5 = y(3, {2}); + for (int i = 0; i < shape::shapeInfoLength(y5.rankOf()); ++i) + ASSERT_TRUE(y5.shapeInfo()[i] == shapeExpY5[i]); + for (int i = 0; i < y5.lengthOf(); ++i) + ASSERT_TRUE(y5.e(i) == buffExpY5[i]); } TEST_F(NDArrayTest2, test_subarray_interval_1) { + NDArray x('f', {10, 10}, sd::DataType::FLOAT32); + auto subArr1 = x.subarray({NDIndex::all(), NDIndex::interval(0, 9)}); - NDArray x('f', {10, 10}, sd::DataType::FLOAT32); - auto subArr1 = x.subarray({NDIndex::all(), NDIndex::interval(0,9)}); - - ASSERT_EQ(10, subArr1.sizeAt(0)); - ASSERT_EQ(9, subArr1.sizeAt(1)); + ASSERT_EQ(10, subArr1.sizeAt(0)); + ASSERT_EQ(9, subArr1.sizeAt(1)); } TEST_F(NDArrayTest2, test_subarray_interval_2) { + NDArray x('c', {10, 10}, sd::DataType::FLOAT32); + auto subArr1 = x.subarray({NDIndex::all(), NDIndex::interval(0, 9)}); - NDArray x('c', {10, 10}, sd::DataType::FLOAT32); - auto subArr1 = x.subarray({NDIndex::all(), NDIndex::interval(0,9)}); - - ASSERT_EQ(10, subArr1.sizeAt(0)); - ASSERT_EQ(9, subArr1.sizeAt(1)); + ASSERT_EQ(10, subArr1.sizeAt(0)); + ASSERT_EQ(9, subArr1.sizeAt(1)); } TEST_F(NDArrayTest2, test_subarray_3d_cf) { - NDArray f('f', {10, 20, 30}, sd::DataType::FLOAT32); - NDArray c('c', {10, 20, 30}, sd::DataType::FLOAT32); + NDArray f('f', {10, 20, 30}, sd::DataType::FLOAT32); + NDArray c('c', {10, 20, 30}, sd::DataType::FLOAT32); - auto subarrayF = f({0,0, 0,0, 2,3}, true); + auto subarrayF = f({0, 0, 0, 0, 2, 3}, true); - auto subarrayC = c({2,3, 0,0, 0,0}, true); + auto subarrayC = c({2, 3, 0, 0, 0, 0}, true); } TEST_F(NDArrayTest2, test_broadcast_row_1) { - auto x = NDArrayFactory::create('c', {10, 5}); - auto y = NDArrayFactory::create('c', {5}, {1.f, 1.f, 1.f, 1.f, 1.f}); - auto e = NDArrayFactory::create('c', {10, 5}); - e.assign(1.0f); + auto x = NDArrayFactory::create('c', {10, 5}); + auto y = NDArrayFactory::create('c', {5}, {1.f, 1.f, 1.f, 1.f, 1.f}); + auto e = NDArrayFactory::create('c', {10, 5}); + e.assign(1.0f); - x += y; + x += y; - ASSERT_EQ(e, x); + ASSERT_EQ(e, x); } TEST_F(NDArrayTest2, test_broadcast_column_1) { - auto x = NDArrayFactory::create('c', {5, 10}); - auto y = NDArrayFactory::create('c', {5, 1}, {1.f, 1.f, 1.f, 1.f, 1.f}); - auto e = NDArrayFactory::create('c', {5, 10}); - e.assign(1.0f); + auto x = NDArrayFactory::create('c', {5, 10}); + auto y = + NDArrayFactory::create('c', {5, 1}, {1.f, 1.f, 1.f, 1.f, 1.f}); + auto e = NDArrayFactory::create('c', {5, 10}); + e.assign(1.0f); - x += y; + x += y; - ASSERT_EQ(e, x); + ASSERT_EQ(e, x); } TEST_F(NDArrayTest2, test_broadcast_column_2) { - auto x = NDArrayFactory::create('c', {5, 10}); - auto y = NDArrayFactory::create('c', {5, 1}, {1.f, 1.f, 1.f, 1.f, 1.f}); - auto e = NDArrayFactory::create('c', {5, 10}); - e.assign(1.0f); + auto x = NDArrayFactory::create('c', {5, 10}); + auto y = + NDArrayFactory::create('c', {5, 1}, {1.f, 1.f, 1.f, 1.f, 1.f}); + auto e = NDArrayFactory::create('c', {5, 10}); + e.assign(1.0f); - x.applyTrueBroadcast(BroadcastOpsTuple::Add(), y, x, false); + x.applyTrueBroadcast(BroadcastOpsTuple::Add(), y, x, false); - ASSERT_EQ(e, x); + ASSERT_EQ(e, x); } TEST_F(NDArrayTest2, test_broadcast_column_3) { - auto x = NDArrayFactory::create('c', {5, 10}); - auto y = NDArrayFactory::create('c', {5, 1}, {1.f, 1.f, 1.f, 1.f, 1.f}); - auto e = NDArrayFactory::create('c', {5, 10}); - e.assign(1.0f); + auto x = NDArrayFactory::create('c', {5, 10}); + auto y = + NDArrayFactory::create('c', {5, 1}, {1.f, 1.f, 1.f, 1.f, 1.f}); + auto e = NDArrayFactory::create('c', {5, 10}); + e.assign(1.0f); - x.applyTrueBroadcast(BroadcastOpsTuple::Add(), y, x); + x.applyTrueBroadcast(BroadcastOpsTuple::Add(), y, x); - ASSERT_EQ(e, x); + ASSERT_EQ(e, x); } TEST_F(NDArrayTest2, test_broadcast_column_4) { - auto x = NDArrayFactory::create('f', {10, 5}); - auto y = NDArrayFactory::create('f', {5}, {1.f, 1.f, 1.f, 1.f, 1.f}); - auto e = NDArrayFactory::create('f', {10, 5}); - e.assign(1.0f); + auto x = NDArrayFactory::create('f', {10, 5}); + auto y = NDArrayFactory::create('f', {5}, {1.f, 1.f, 1.f, 1.f, 1.f}); + auto e = NDArrayFactory::create('f', {10, 5}); + e.assign(1.0f); - x.applyTrueBroadcast(BroadcastOpsTuple::Add(), y, x); + x.applyTrueBroadcast(BroadcastOpsTuple::Add(), y, x); - ASSERT_EQ(e, x); + ASSERT_EQ(e, x); } TEST_F(NDArrayTest2, test_not_tiled_1) { - auto x = NDArrayFactory::create('c', {4, 12, 128, 128}); - auto y = NDArrayFactory::create('c', {4, 1, 128, 128}); - auto e = NDArrayFactory::create('c', {4, 12, 128, 128}); - y.assign(1.0f); - e.assign(1.0f); + auto x = NDArrayFactory::create('c', {4, 12, 128, 128}); + auto y = NDArrayFactory::create('c', {4, 1, 128, 128}); + auto e = NDArrayFactory::create('c', {4, 12, 128, 128}); + y.assign(1.0f); + e.assign(1.0f); - x += y; + x += y; - ASSERT_EQ(e, x); + ASSERT_EQ(e, x); } TEST_F(NDArrayTest2, test_not_tiled_2) { - auto x = NDArrayFactory::create('c', {4, 128, 768}); - auto y = NDArrayFactory::create('c', {4, 128, 1}); - auto e = NDArrayFactory::create('c', {4, 128, 768}); - y.assign(1.0f); - e.assign(1.0f); + auto x = NDArrayFactory::create('c', {4, 128, 768}); + auto y = NDArrayFactory::create('c', {4, 128, 1}); + auto e = NDArrayFactory::create('c', {4, 128, 768}); + y.assign(1.0f); + e.assign(1.0f); - x += y; + x += y; - ASSERT_EQ(e, x); + ASSERT_EQ(e, x); } TEST_F(NDArrayTest2, test_long_sum_1) { - auto x = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); + auto x = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); - auto z = x.reduceAlongDimension(reduce::Sum, {0}); + auto z = x.reduceAlongDimension(reduce::Sum, {0}); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest2, reshapei_1) { + Nd4jLong shapeInfo1[] = {6, 2, 1, 2, 1, 7, 1, 7, + 7, 14, 28, 1, 1, 8192, 0, 99}; + Nd4jLong shapeInfo2[] = {2, 4, 7, 7, 1, 8192, 1, 99}; - Nd4jLong shapeInfo1[] = {6, 2,1,2,1,7,1, 7,7,14,28,1,1, 8192, 0, 99}; - Nd4jLong shapeInfo2[] = {2, 4, 7, 7, 1, 8192, 1, 99}; + auto buffer = new float[shape::length(shapeInfo1)]; + NDArray x(buffer, shapeInfo1); - auto buffer = new float[shape::length(shapeInfo1)]; - NDArray x(buffer, shapeInfo1); + const bool canReshape = x.reshapei({4, 7}); - const bool canReshape = x.reshapei({4,7}); + ASSERT_FALSE(canReshape); + ASSERT_TRUE(shape::equalsStrict(x.shapeInfo(), shapeInfo2)); - ASSERT_FALSE(canReshape); - ASSERT_TRUE(shape::equalsStrict(x.shapeInfo(), shapeInfo2)); - - delete[] buffer; + delete[] buffer; } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest2, reshapei_2) { + Nd4jLong shapeInfo1[] = {6, 1, 2, 1, 2, 7, 1, 28, + 7, 7, 14, 1, 1, 8192, 0, 99}; + Nd4jLong shapeInfo2[] = {2, 4, 7, 7, 1, 8192, 1, 99}; - Nd4jLong shapeInfo1[] = {6, 1,2,1,2,7,1, 28,7,7,14,1,1, 8192, 0, 99}; - Nd4jLong shapeInfo2[] = {2, 4, 7, 7, 1, 8192, 1, 99}; - - auto buffer = new float[shape::length(shapeInfo1)]; - NDArray x(buffer, shapeInfo1); + auto buffer = new float[shape::length(shapeInfo1)]; + NDArray x(buffer, shapeInfo1); - const bool canReshape = x.reshapei({4,7}); + const bool canReshape = x.reshapei({4, 7}); - ASSERT_FALSE(canReshape); - ASSERT_TRUE(shape::equalsStrict(x.shapeInfo(), shapeInfo2)); + ASSERT_FALSE(canReshape); + ASSERT_TRUE(shape::equalsStrict(x.shapeInfo(), shapeInfo2)); - delete[] buffer; + delete[] buffer; } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest2, trueBroadcast_1) { + NDArray x('f', {2, 3}, {1., 2., 3., 4., 5., 6.}); + NDArray y('f', {1, 3}, {5., 4., 3.}); + NDArray z('c', {2, 3}, sd::DataType::DOUBLE); - NDArray x('f', {2, 3}, {1., 2., 3., 4., 5., 6.}); - NDArray y('f', {1, 3}, {5., 4., 3.}); - NDArray z('c', {2, 3}, sd::DataType::DOUBLE); + auto exp = x - y; + x.applyTrueBroadcast(sd::BroadcastOpsTuple::Subtract(), y, z); - auto exp = x - y; - x.applyTrueBroadcast(sd::BroadcastOpsTuple::Subtract(), y, z); + // exp.printIndexedBuffer(); + // z.printIndexedBuffer(); - // exp.printIndexedBuffer(); - // z.printIndexedBuffer(); - - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest2, reduce_1) { + NDArray arr6('f', {1, 1, 4, 4, 4, 4}, sd::DataType::DOUBLE); + NDArray exp('f', {1, 1, 4, 4}, sd::DataType::DOUBLE); - NDArray arr6('f', {1, 1, 4, 4, 4, 4}, sd::DataType::DOUBLE); - NDArray exp('f', {1, 1, 4, 4}, sd::DataType::DOUBLE); - - arr6.linspace(1); + arr6.linspace(1); - NDArray arr6s = arr6.reduceAlongDimension(sd::reduce::Sum, {2,3}); + NDArray arr6s = arr6.reduceAlongDimension(sd::reduce::Sum, {2, 3}); - for (int i = 0; i < 4; i++) { - for (int j = 0; j < 4; j++) { - double sum = 0; - for (int x = 0; x < 4; x++) { - for (int y = 0; y < 4; y++) { - Nd4jLong indices[] = {0, 0, x, y, i, j}; - Nd4jLong offset = shape::getOffset(arr6.shapeInfo(), indices); - sum += ((double*)arr6.buffer())[offset]; - } - } - exp.p(0, 0, i, j, sum); + for (int i = 0; i < 4; i++) { + for (int j = 0; j < 4; j++) { + double sum = 0; + for (int x = 0; x < 4; x++) { + for (int y = 0; y < 4; y++) { + Nd4jLong indices[] = {0, 0, x, y, i, j}; + Nd4jLong offset = shape::getOffset(arr6.shapeInfo(), indices); + sum += ((double *)arr6.buffer())[offset]; } + } + exp.p(0, 0, i, j, sum); } + } - // arr6s->printShapeInfo(); - // exp.printShapeInfo(); - // exp.printIndexedBuffer(); - // arr6s->printIndexedBuffer(); + // arr6s->printShapeInfo(); + // exp.printShapeInfo(); + // exp.printIndexedBuffer(); + // arr6s->printIndexedBuffer(); - ASSERT_TRUE(exp.equalsTo(arr6s)); + ASSERT_TRUE(exp.equalsTo(arr6s)); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest2, reduce3_1) { + NDArray x('c', {1, 4}, {1, 2, 3, 4}); + NDArray y('c', {1, 4}, {2, 3, 4, 5}); + NDArray exp('c', {4}, {1, 1, 1, 1}); - NDArray x('c', {1,4}, {1,2,3,4}); - NDArray y('c', {1,4}, {2,3,4,5}); - NDArray exp('c', {4}, {1,1,1,1}); + NDArray z = x.applyReduce3(sd::reduce3::EuclideanDistance, y, {0}, nullptr); - NDArray z = x.applyReduce3(sd::reduce3::EuclideanDistance, y, {0}, nullptr); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(NDArrayTest2, all_tads_1) { - auto x = NDArrayFactory::create('c', {3, 5}); + auto x = NDArrayFactory::create('c', {3, 5}); - auto arrays = x.allTensorsAlongDimension({1}); - ASSERT_EQ(3, arrays.size()); + auto arrays = x.allTensorsAlongDimension({1}); + ASSERT_EQ(3, arrays.size()); } TEST_F(NDArrayTest2, test_trueBroadcast_empty_1) { - auto x = NDArrayFactory::create('c', {0, 2}); - auto y = NDArrayFactory::create('c', {1, 2}); + auto x = NDArrayFactory::create('c', {0, 2}); + auto y = NDArrayFactory::create('c', {1, 2}); - auto z = x + y; + auto z = x + y; - ASSERT_EQ(x, z); + ASSERT_EQ(x, z); } TEST_F(NDArrayTest2, test_trueBroadcast_empty_2) { - auto x = NDArrayFactory::create('c', {0, 2}); - auto y = NDArrayFactory::create('c', {1, 2}); + auto x = NDArrayFactory::create('c', {0, 2}); + auto y = NDArrayFactory::create('c', {1, 2}); - auto z = y + x; + auto z = y + x; - ASSERT_EQ(x, z); + ASSERT_EQ(x, z); } TEST_F(NDArrayTest2, test_subarray_followed_by_reshape_1) { + NDArray x('c', {5, 1, 3}, sd::DataType::FLOAT32); + NDArray e('c', {1, 3}, {7.f, 8.f, 9.f}, sd::DataType::FLOAT32); - NDArray x('c', {5, 1, 3}, sd::DataType::FLOAT32); - NDArray e('c', {1, 3}, {7.f, 8.f, 9.f}, sd::DataType::FLOAT32); - - x.linspace(1.); + x.linspace(1.); - auto s = x({2,3, 0,0, 0,0}); + auto s = x({2, 3, 0, 0, 0, 0}); - // s.printIndexedBuffer("s"); + // s.printIndexedBuffer("s"); - auto r = s.reshape(x.ordering(), {1, 3}); - // r.printIndexedBuffer("r"); + auto r = s.reshape(x.ordering(), {1, 3}); + // r.printIndexedBuffer("r"); - ASSERT_EQ(e, r); + ASSERT_EQ(e, r); } TEST_F(NDArrayTest2, test_numpy_import_1) { - std::string fname("./resources/arr_3,4_float32.npy"); - auto exp = NDArrayFactory::create('c', {3, 4}); - exp.linspace(0); + std::string fname("./resources/arr_3,4_float32.npy"); + auto exp = NDArrayFactory::create('c', {3, 4}); + exp.linspace(0); - auto array = NDArrayFactory::fromNpyFile(fname.c_str()); + auto array = NDArrayFactory::fromNpyFile(fname.c_str()); - ASSERT_EQ(exp, array); + ASSERT_EQ(exp, array); } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/NativeOpsTests.cpp b/libnd4j/tests_cpu/layers_tests/NativeOpsTests.cpp index 2f87b509952c..9295db2af829 100644 --- a/libnd4j/tests_cpu/layers_tests/NativeOpsTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/NativeOpsTests.cpp @@ -18,1455 +18,1431 @@ // Created by GS on 22.07.2019. // -#include "testlayers.h" #include +#include #include -#include -#include -#include -#include -#include -#include -#include -#include #include -#include +#include #include #include +#include +#include +#include +#include +#include +#include +#include + +#include "testlayers.h" using namespace sd; using namespace sd::ops; class NativeOpsTests : public testing::Test { -public: - + public: }; - TEST_F(NativeOpsTests, CreateContextTests_1) { -// auto x = NDArrayFactory::create('c', {5, 5}); -// x.assign(1.0); -// auto z = NDArrayFactory::create('c', {5,5}); -// auto exp = NDArrayFactory::create('c', {5, 5}); - auto context = ::createContext(); - ASSERT_TRUE(context == nullptr); - //delete context; + // auto x = NDArrayFactory::create('c', {5, 5}); + // x.assign(1.0); + // auto z = NDArrayFactory::create('c', {5,5}); + // auto exp = NDArrayFactory::create('c', {5, 5}); + auto context = ::createContext(); + ASSERT_TRUE(context == nullptr); + // delete context; } TEST_F(NativeOpsTests, CreateContextTests_2) { -// auto x = NDArrayFactory::create('c', {5, 5}); -// x.assign(1.0); -// auto z = NDArrayFactory::create('c', {5,5}); -// auto exp = NDArrayFactory::create('c', {5, 5}); - auto context1 = ::createContext(); - auto context2 = ::createContext(); - ASSERT_TRUE(context1 == context2); - //delete context1; - //delete context2; + // auto x = NDArrayFactory::create('c', {5, 5}); + // x.assign(1.0); + // auto z = NDArrayFactory::create('c', {5,5}); + // auto exp = NDArrayFactory::create('c', {5, 5}); + auto context1 = ::createContext(); + auto context2 = ::createContext(); + ASSERT_TRUE(context1 == context2); + // delete context1; + // delete context2; } TEST_F(NativeOpsTests, PointerTests_1) { - auto x = NDArrayFactory::create('c', {5}, {1,2,3,4,5}); + auto x = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); // x.linspace(1.0); #ifdef __CUDABLAS__ -printf("Unsupported for cuda now.\n"); + printf("Unsupported for cuda now.\n"); #else - ::tryPointer(nullptr, x.buffer(), 4); + ::tryPointer(nullptr, x.buffer(), 4); #endif -// auto exp = NDArrayFactory::create('c', {5, 5}); -// exp.assign(-1.0); -// -// sd::ops::LegacyTransformSameOp op(transform::Neg); // Neg -// auto result = op.execute({&x}, {}, {}); -// -// ASSERT_EQ(1, result->size()); -// -// auto z = result->at(0); -// -// ASSERT_TRUE(exp.equalsTo(z)); -// -// delete result; + // auto exp = NDArrayFactory::create('c', {5, 5}); + // exp.assign(-1.0); + // + // sd::ops::LegacyTransformSameOp op(transform::Neg); // Neg + // auto result = op.execute({&x}, {}, {}); + // + // ASSERT_EQ(1, result->size()); + // + // auto z = result->at(0); + // + // ASSERT_TRUE(exp.equalsTo(z)); + // + // delete result; } TEST_F(NativeOpsTests, ThresholdTests_1) { // auto x = NDArrayFactory::create('c', {5}, {1,2,3,4,5}); // x.linspace(1.0); #ifdef __CUDABLAS__ - printf("Unsupported for cuda now.\n"); + printf("Unsupported for cuda now.\n"); #else - ::setElementThreshold(4); - ASSERT_TRUE(4 == sd::Environment::getInstance().elementwiseThreshold()); + ::setElementThreshold(4); + ASSERT_TRUE(4 == sd::Environment::getInstance().elementwiseThreshold()); #endif - } TEST_F(NativeOpsTests, ThresholdTests_2) { // auto x = NDArrayFactory::create('c', {5}, {1,2,3,4,5}); // x.linspace(1.0); #ifdef __CUDABLAS__ - printf("Unsupported for cuda now.\n"); + printf("Unsupported for cuda now.\n"); #else - ::setTADThreshold(4); - ASSERT_TRUE(4 == sd::Environment::getInstance().tadThreshold()); + ::setTADThreshold(4); + ASSERT_TRUE(4 == sd::Environment::getInstance().tadThreshold()); #endif - } TEST_F(NativeOpsTests, ExecIndexReduce_1) { - auto x = NDArrayFactory::create('c', {5}, {1,2,3,4,5}); - auto exp = NDArrayFactory::create(120); - x.linspace(1.0); + auto x = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); + auto exp = NDArrayFactory::create(120); + x.linspace(1.0); #ifdef __CUDABLAS__ - printf("Unsupported for cuda now.\n"); + printf("Unsupported for cuda now.\n"); #else - OpaqueDataBuffer xBuf(x.dataBuffer()); - OpaqueDataBuffer expBuf(exp.dataBuffer()); - ::execIndexReduceScalar(nullptr, - indexreduce::IndexMax, - &xBuf, x.shapeInfo(), - nullptr, - nullptr, - &expBuf, exp.shapeInfo(), - nullptr); - - ASSERT_TRUE(exp.e(0) == 4LL); -#endif + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + ::execIndexReduceScalar(nullptr, indexreduce::IndexMax, &xBuf, x.shapeInfo(), + nullptr, nullptr, &expBuf, exp.shapeInfo(), nullptr); + ASSERT_TRUE(exp.e(0) == 4LL); +#endif } TEST_F(NativeOpsTests, ExecIndexReduce_2) { - auto x = NDArrayFactory::create('c', {5, 5}); - auto exp = NDArrayFactory::create(120); - x.linspace(1.0); + auto x = NDArrayFactory::create('c', {5, 5}); + auto exp = NDArrayFactory::create(120); + x.linspace(1.0); #ifdef __CUDABLAS__ - printf("Unsupported for cuda now.\n"); + printf("Unsupported for cuda now.\n"); #else - NDArray dimension = NDArrayFactory::create({}); - OpaqueDataBuffer xBuf(x.dataBuffer()); - OpaqueDataBuffer expBuf(exp.dataBuffer()); - OpaqueDataBuffer dimensionBuf(dimension.dataBuffer()); - - ::execIndexReduce(nullptr, - indexreduce::IndexMax, - &xBuf, x.shapeInfo(), nullptr, - nullptr, - &expBuf, exp.shapeInfo(), - nullptr, - &dimensionBuf, dimension.shapeInfo(), - nullptr); - - ASSERT_TRUE(exp.e(0) == 24LL); -#endif + NDArray dimension = NDArrayFactory::create({}); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + OpaqueDataBuffer dimensionBuf(dimension.dataBuffer()); + ::execIndexReduce(nullptr, indexreduce::IndexMax, &xBuf, x.shapeInfo(), + nullptr, nullptr, &expBuf, exp.shapeInfo(), nullptr, + &dimensionBuf, dimension.shapeInfo(), nullptr); + + ASSERT_TRUE(exp.e(0) == 24LL); +#endif } TEST_F(NativeOpsTests, ExecBroadcast_1) { - auto x = NDArrayFactory::create('c', {5, 5}); - auto y = NDArrayFactory::create('c', {5, 1}); - auto exp = NDArrayFactory::create('c', {5, 5}); - x.linspace(1.0); - y.linspace(2,2); + auto x = NDArrayFactory::create('c', {5, 5}); + auto y = NDArrayFactory::create('c', {5, 1}); + auto exp = NDArrayFactory::create('c', {5, 5}); + x.linspace(1.0); + y.linspace(2, 2); #ifdef __CUDABLAS__ - printf("Unsupported for cuda now.\n"); + printf("Unsupported for cuda now.\n"); #else - auto dimension = NDArrayFactory::create('c', {1}, {1}); - - OpaqueDataBuffer xBuf(x.dataBuffer()); - OpaqueDataBuffer yBuf(y.dataBuffer()); - OpaqueDataBuffer expBuf(exp.dataBuffer()); - OpaqueDataBuffer dimBuf(dimension.dataBuffer()); - - ::execBroadcast(nullptr, - broadcast::Add, - &xBuf, x.shapeInfo(), - nullptr, - &yBuf, y.shapeInfo(), - nullptr, - &expBuf, exp.shapeInfo(), - nullptr, - &dimBuf, dimension.shapeInfo(), - nullptr); - - ASSERT_TRUE(exp.e(0) == 3.); -#endif + auto dimension = NDArrayFactory::create('c', {1}, {1}); + + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer yBuf(y.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + OpaqueDataBuffer dimBuf(dimension.dataBuffer()); + ::execBroadcast(nullptr, broadcast::Add, &xBuf, x.shapeInfo(), nullptr, &yBuf, + y.shapeInfo(), nullptr, &expBuf, exp.shapeInfo(), nullptr, + &dimBuf, dimension.shapeInfo(), nullptr); + + ASSERT_TRUE(exp.e(0) == 3.); +#endif } TEST_F(NativeOpsTests, ExecBroadcast_2) { - auto x = NDArrayFactory::create('c', {5, 5}); - auto y = NDArrayFactory::create('c', {5, 1}); - auto exp = NDArrayFactory::create('c', {5, 5}); - x.linspace(1.0); - y.linspace(2,2); + auto x = NDArrayFactory::create('c', {5, 5}); + auto y = NDArrayFactory::create('c', {5, 1}); + auto exp = NDArrayFactory::create('c', {5, 5}); + x.linspace(1.0); + y.linspace(2, 2); #ifdef __CUDABLAS__ -printf("Unsupported for cuda now.\n"); + printf("Unsupported for cuda now.\n"); #else - int dimd = 0; - auto dimension = NDArrayFactory::create('c', {1}, {dimd}); - - OpaqueDataBuffer xBuf(x.dataBuffer()); - OpaqueDataBuffer yBuf(y.dataBuffer()); - OpaqueDataBuffer expBuf(exp.dataBuffer()); - OpaqueDataBuffer dimBuf(dimension.dataBuffer()); - - ::execBroadcastBool(nullptr, - broadcast::EqualTo, - &xBuf, x.shapeInfo(), nullptr, - &yBuf, y.shapeInfo(), nullptr, - &expBuf, exp.shapeInfo(), nullptr, nullptr, - &dimBuf, dimension.shapeInfo(), - nullptr); - ASSERT_TRUE(exp.e(1) && !exp.e(0)); + int dimd = 0; + auto dimension = NDArrayFactory::create('c', {1}, {dimd}); + + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer yBuf(y.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + OpaqueDataBuffer dimBuf(dimension.dataBuffer()); + + ::execBroadcastBool(nullptr, broadcast::EqualTo, &xBuf, x.shapeInfo(), + nullptr, &yBuf, y.shapeInfo(), nullptr, &expBuf, + exp.shapeInfo(), nullptr, nullptr, &dimBuf, + dimension.shapeInfo(), nullptr); + ASSERT_TRUE(exp.e(1) && !exp.e(0)); #endif - } TEST_F(NativeOpsTests, ExecPairwise_1) { - auto x = NDArrayFactory::create('c', {5, 5}); - auto y = NDArrayFactory::create('c', {5, 5}); - auto exp = NDArrayFactory::create('c', {5, 5}); - x.linspace(1.0); - y.assign(2.); + auto x = NDArrayFactory::create('c', {5, 5}); + auto y = NDArrayFactory::create('c', {5, 5}); + auto exp = NDArrayFactory::create('c', {5, 5}); + x.linspace(1.0); + y.assign(2.); #ifdef __CUDABLAS__ - printf("Unsupported for cuda now.\n"); + printf("Unsupported for cuda now.\n"); #else - OpaqueDataBuffer xBuf(x.dataBuffer()); - OpaqueDataBuffer yBuf(y.dataBuffer()); - OpaqueDataBuffer expBuf(exp.dataBuffer()); - - ::execPairwiseTransform(nullptr, - pairwise::Add, - &xBuf, x.shapeInfo(), nullptr, - &yBuf, y.shapeInfo(), nullptr, - &expBuf, exp.shapeInfo(), nullptr, - nullptr); - ASSERT_TRUE(exp.e(5) == 8.); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer yBuf(y.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + + ::execPairwiseTransform(nullptr, pairwise::Add, &xBuf, x.shapeInfo(), nullptr, + &yBuf, y.shapeInfo(), nullptr, &expBuf, + exp.shapeInfo(), nullptr, nullptr); + ASSERT_TRUE(exp.e(5) == 8.); #endif - } TEST_F(NativeOpsTests, ExecPairwise_2) { - auto x = NDArrayFactory::create('c', {5, 5}); - auto y = NDArrayFactory::create('c', {5, 5}); - auto exp = NDArrayFactory::create('c', {5, 5}); - x.assign(true); - y.assign(false); - y.r(5) = true; + auto x = NDArrayFactory::create('c', {5, 5}); + auto y = NDArrayFactory::create('c', {5, 5}); + auto exp = NDArrayFactory::create('c', {5, 5}); + x.assign(true); + y.assign(false); + y.r(5) = true; #ifdef __CUDABLAS__ - printf("Unsupported for cuda now.\n"); + printf("Unsupported for cuda now.\n"); #else - OpaqueDataBuffer xBuf(x.dataBuffer()); - OpaqueDataBuffer yBuf(y.dataBuffer()); - OpaqueDataBuffer expBuf(exp.dataBuffer()); - - ::execPairwiseTransformBool(nullptr, - pairwise::And, - &xBuf, x.shapeInfo(), nullptr, - &yBuf, y.shapeInfo(), nullptr, - &expBuf, exp.shapeInfo(), nullptr, - nullptr); - ASSERT_TRUE(exp.e(5) && !exp.e(4)); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer yBuf(y.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + + ::execPairwiseTransformBool(nullptr, pairwise::And, &xBuf, x.shapeInfo(), + nullptr, &yBuf, y.shapeInfo(), nullptr, &expBuf, + exp.shapeInfo(), nullptr, nullptr); + ASSERT_TRUE(exp.e(5) && !exp.e(4)); #endif - } TEST_F(NativeOpsTests, ReduceTest_1) { - auto x = NDArrayFactory::create('c', {5, 5}); - auto exp = NDArrayFactory::create(120.); - x.linspace(1.0); + auto x = NDArrayFactory::create('c', {5, 5}); + auto exp = NDArrayFactory::create(120.); + x.linspace(1.0); #ifdef __CUDABLAS__ - printf("Unsupported for cuda now.\n"); + printf("Unsupported for cuda now.\n"); #else - auto dimension = NDArrayFactory::create('c', {1}, {1}); - OpaqueDataBuffer xBuf(x.dataBuffer()); - OpaqueDataBuffer expBuf(exp.dataBuffer()); - - ::execReduceFloat(nullptr, - reduce::Mean, - &xBuf, x.shapeInfo(), nullptr, - nullptr, - &expBuf, exp.shapeInfo(), nullptr); -// x.printIndexedBuffer("Input"); -// exp.printIndexedBuffer("Reduce Mean"); - ASSERT_TRUE(exp.e(0) == 13.); + auto dimension = NDArrayFactory::create('c', {1}, {1}); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + + ::execReduceFloat(nullptr, reduce::Mean, &xBuf, x.shapeInfo(), nullptr, + nullptr, &expBuf, exp.shapeInfo(), nullptr); + // x.printIndexedBuffer("Input"); + // exp.printIndexedBuffer("Reduce Mean"); + ASSERT_TRUE(exp.e(0) == 13.); #endif - } TEST_F(NativeOpsTests, ReduceTest_2) { - auto x = NDArrayFactory::create('c', {5, 5}); - auto exp = NDArrayFactory::create(120.); - x.linspace(1.0); + auto x = NDArrayFactory::create('c', {5, 5}); + auto exp = NDArrayFactory::create(120.); + x.linspace(1.0); #ifdef __CUDABLAS__ - printf("Unsupported for cuda now.\n"); + printf("Unsupported for cuda now.\n"); #else - OpaqueDataBuffer xBuf(x.dataBuffer()); - OpaqueDataBuffer expBuf(exp.dataBuffer()); - - ::execReduceSame(nullptr, - reduce::Sum, - &xBuf, x.shapeInfo(), nullptr, - nullptr, - &expBuf, exp.shapeInfo(), nullptr); -// x.printIndexedBuffer("Input"); -// exp.printIndexedBuffer("Reduce Sum"); - ASSERT_TRUE(exp.e(0) == 325.); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + + ::execReduceSame(nullptr, reduce::Sum, &xBuf, x.shapeInfo(), nullptr, nullptr, + &expBuf, exp.shapeInfo(), nullptr); + // x.printIndexedBuffer("Input"); + // exp.printIndexedBuffer("Reduce Sum"); + ASSERT_TRUE(exp.e(0) == 325.); #endif - } TEST_F(NativeOpsTests, ReduceTest_3) { - auto x = NDArrayFactory::create('c', {5, 5}); - auto exp = NDArrayFactory::create(false); - x.linspace(1.0); + auto x = NDArrayFactory::create('c', {5, 5}); + auto exp = NDArrayFactory::create(false); + x.linspace(1.0); #ifdef __CUDABLAS__ - printf("Unsupported for cuda now.\n"); + printf("Unsupported for cuda now.\n"); #else - OpaqueDataBuffer xBuf(x.dataBuffer()); - OpaqueDataBuffer expBuf(exp.dataBuffer()); - - ::execReduceBool(nullptr, - reduce::All, - &xBuf, x.shapeInfo(), nullptr, - nullptr, - &expBuf, exp.shapeInfo(), nullptr); -// x.printIndexedBuffer("Input"); -// exp.printIndexedBuffer("Reduce All"); - ASSERT_TRUE(exp.e(0) == true); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + + ::execReduceBool(nullptr, reduce::All, &xBuf, x.shapeInfo(), nullptr, nullptr, + &expBuf, exp.shapeInfo(), nullptr); + // x.printIndexedBuffer("Input"); + // exp.printIndexedBuffer("Reduce All"); + ASSERT_TRUE(exp.e(0) == true); #endif - } TEST_F(NativeOpsTests, ReduceTest_4) { - auto x = NDArrayFactory::create('c', {5, 5}); - auto exp = NDArrayFactory::create(120LL); - x.linspace(1.0); + auto x = NDArrayFactory::create('c', {5, 5}); + auto exp = NDArrayFactory::create(120LL); + x.linspace(1.0); #ifdef __CUDABLAS__ - printf("Unsupported for cuda now.\n"); + printf("Unsupported for cuda now.\n"); #else - OpaqueDataBuffer xBuf(x.dataBuffer()); - OpaqueDataBuffer expBuf(exp.dataBuffer()); - - ::execReduceLong(nullptr, - reduce::CountNonZero, - &xBuf, x.shapeInfo(), nullptr, - nullptr, - &expBuf, exp.shapeInfo(), nullptr); -// x.printIndexedBuffer("Input"); -// exp.printIndexedBuffer("Reduce CountNonZero"); - ASSERT_TRUE(exp.e(0) == 25LL); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + + ::execReduceLong(nullptr, reduce::CountNonZero, &xBuf, x.shapeInfo(), nullptr, + nullptr, &expBuf, exp.shapeInfo(), nullptr); + // x.printIndexedBuffer("Input"); + // exp.printIndexedBuffer("Reduce CountNonZero"); + ASSERT_TRUE(exp.e(0) == 25LL); #endif - } TEST_F(NativeOpsTests, ReduceTest_5) { - auto x = NDArrayFactory::create('c', {5, 5}); - auto exp = NDArrayFactory::create(120LL); - x.linspace(1.0); + auto x = NDArrayFactory::create('c', {5, 5}); + auto exp = NDArrayFactory::create(120LL); + x.linspace(1.0); #ifdef __CUDABLAS__ - printf("Unsupported for cuda now.\n"); + printf("Unsupported for cuda now.\n"); #else - auto dimension = NDArrayFactory::create({0, 1}); - OpaqueDataBuffer xBuf(x.dataBuffer()); - OpaqueDataBuffer expBuf(exp.dataBuffer()); - OpaqueDataBuffer dimBuf(dimension.dataBuffer()); - - ::execReduceLong2(nullptr, - reduce::CountNonZero, - &xBuf, x.shapeInfo(), x.specialShapeInfo(), - nullptr, - &expBuf, exp.shapeInfo(), exp.specialShapeInfo(), - &dimBuf, dimension.shapeInfo(), dimension.specialShapeInfo()); -// x.printIndexedBuffer("Input"); -// exp.printIndexedBuffer("Reduce CountNonZero"); - ASSERT_TRUE(exp.e(0) == 25LL); + auto dimension = NDArrayFactory::create({0, 1}); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + OpaqueDataBuffer dimBuf(dimension.dataBuffer()); + + ::execReduceLong2(nullptr, reduce::CountNonZero, &xBuf, x.shapeInfo(), + x.specialShapeInfo(), nullptr, &expBuf, exp.shapeInfo(), + exp.specialShapeInfo(), &dimBuf, dimension.shapeInfo(), + dimension.specialShapeInfo()); + // x.printIndexedBuffer("Input"); + // exp.printIndexedBuffer("Reduce CountNonZero"); + ASSERT_TRUE(exp.e(0) == 25LL); #endif - } TEST_F(NativeOpsTests, ReduceTest_6) { - auto x = NDArrayFactory::create('c', {5, 5}); - auto z = NDArrayFactory::create({5, 4, 3, 2, 1}); - auto exp = NDArrayFactory::create({1,2,3,4,6}); - x.linspace(1.0); + auto x = NDArrayFactory::create('c', {5, 5}); + auto z = NDArrayFactory::create({5, 4, 3, 2, 1}); + auto exp = NDArrayFactory::create({1, 2, 3, 4, 6}); + x.linspace(1.0); #ifdef __CUDABLAS__ - printf("Unsupported for cuda now.\n"); + printf("Unsupported for cuda now.\n"); #else - auto dimension = NDArrayFactory::create('c', {1}, {1}); - x.p(5, 0); - x.p(10, 0); x.p(11, 0); - x.p(15, 0); x.p(16, 0); x.p(17, 0); - x.p(20, 0); x.p(21, 0); x.p(22, 0); x.p(23, 0); - - OpaqueDataBuffer xBuf(x.dataBuffer()); - OpaqueDataBuffer dimBuf(dimension.dataBuffer()); - OpaqueDataBuffer expBuf(exp.dataBuffer()); - - ::execReduceLong2(nullptr, - reduce::CountNonZero, - &xBuf, x.shapeInfo(), nullptr, - nullptr, - &expBuf, exp.shapeInfo(), nullptr, - &dimBuf, dimension.shapeInfo(), dimension.specialShapeInfo()); -// x.printIndexedBuffer("Input"); -// exp.printIndexedBuffer("Reduce CountNonZero"); - ASSERT_TRUE(exp.equalsTo(z)); + auto dimension = NDArrayFactory::create('c', {1}, {1}); + x.p(5, 0); + x.p(10, 0); + x.p(11, 0); + x.p(15, 0); + x.p(16, 0); + x.p(17, 0); + x.p(20, 0); + x.p(21, 0); + x.p(22, 0); + x.p(23, 0); + + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer dimBuf(dimension.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + + ::execReduceLong2(nullptr, reduce::CountNonZero, &xBuf, x.shapeInfo(), + nullptr, nullptr, &expBuf, exp.shapeInfo(), nullptr, + &dimBuf, dimension.shapeInfo(), + dimension.specialShapeInfo()); + // x.printIndexedBuffer("Input"); + // exp.printIndexedBuffer("Reduce CountNonZero"); + ASSERT_TRUE(exp.equalsTo(z)); #endif - } TEST_F(NativeOpsTests, ReduceTest_7) { - auto x = NDArrayFactory::create('c', {5, 5}); - auto exp = NDArrayFactory::create(120.); - auto z = NDArrayFactory::create(13.); - + auto x = NDArrayFactory::create('c', {5, 5}); + auto exp = NDArrayFactory::create(120.); + auto z = NDArrayFactory::create(13.); - auto dimension = NDArrayFactory::create('c', {2}, {0, 1}); - Nd4jPointer extra[6]; + auto dimension = NDArrayFactory::create('c', {2}, {0, 1}); + Nd4jPointer extra[6]; #ifdef __CUDABLAS__ - x.syncToHost(); - extra[1] = x.getContext()->getCudaStream(); - extra[0] = extra[2] = extra[3] = extra[4] = extra[5] = nullptr; + x.syncToHost(); + extra[1] = x.getContext()->getCudaStream(); + extra[0] = extra[2] = extra[3] = extra[4] = extra[5] = nullptr; #endif - x.linspace(1.0); - x.syncToDevice(); - dimension.syncToHost(); - OpaqueDataBuffer xBuf(x.dataBuffer()); - OpaqueDataBuffer dimBuf(dimension.dataBuffer()); - OpaqueDataBuffer expBuf(exp.dataBuffer()); - - ::execReduceFloat2(extra, - reduce::Mean, - &xBuf, x.shapeInfo(), x.specialShapeInfo(), - nullptr, - &expBuf, exp.shapeInfo(), exp.specialShapeInfo(), - &dimBuf, dimension.shapeInfo(), dimension.specialShapeInfo()); -// x.printIndexedBuffer("Input"); -// exp.printIndexedBuffer("Reduce Mean"); - ASSERT_TRUE(exp.equalsTo(z)); - + x.linspace(1.0); + x.syncToDevice(); + dimension.syncToHost(); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer dimBuf(dimension.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + + ::execReduceFloat2(extra, reduce::Mean, &xBuf, x.shapeInfo(), + x.specialShapeInfo(), nullptr, &expBuf, exp.shapeInfo(), + exp.specialShapeInfo(), &dimBuf, dimension.shapeInfo(), + dimension.specialShapeInfo()); + // x.printIndexedBuffer("Input"); + // exp.printIndexedBuffer("Reduce Mean"); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(NativeOpsTests, ReduceTest_8) { - auto x = NDArrayFactory::create('c', {5, 5}); - auto z = NDArrayFactory::create(120.); - auto exp = NDArrayFactory::create(325.); - + auto x = NDArrayFactory::create('c', {5, 5}); + auto z = NDArrayFactory::create(120.); + auto exp = NDArrayFactory::create(325.); - auto dimension = NDArrayFactory::create('c', {2}, {0, 1}); - Nd4jPointer extra[6]; + auto dimension = NDArrayFactory::create('c', {2}, {0, 1}); + Nd4jPointer extra[6]; #ifdef __CUDABLAS__ - extra[1] = x.getContext()->getCudaStream(); - extra[0] = extra[2] = extra[3] = extra[4] = extra[5] = nullptr; - x.syncToHost(); + extra[1] = x.getContext()->getCudaStream(); + extra[0] = extra[2] = extra[3] = extra[4] = extra[5] = nullptr; + x.syncToHost(); #endif - x.linspace(1.0); - x.syncToDevice(); + x.linspace(1.0); + x.syncToDevice(); - dimension.syncToHost(); - OpaqueDataBuffer xBuf(x.dataBuffer()); - OpaqueDataBuffer dimBuf(dimension.dataBuffer()); - OpaqueDataBuffer zBuf(z.dataBuffer()); - - ::execReduceSame2(extra, - reduce::Sum, - &xBuf, x.shapeInfo(), x.specialShapeInfo(), - nullptr, - &zBuf, z.shapeInfo(), z.specialShapeInfo(), - &dimBuf, dimension.shapeInfo(), dimension.specialShapeInfo()); -// x.printIndexedBuffer("Input"); -// exp.printIndexedBuffer("Reduce Sum"); - ASSERT_TRUE(exp.equalsTo(z)); + dimension.syncToHost(); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer dimBuf(dimension.dataBuffer()); + OpaqueDataBuffer zBuf(z.dataBuffer()); + ::execReduceSame2(extra, reduce::Sum, &xBuf, x.shapeInfo(), + x.specialShapeInfo(), nullptr, &zBuf, z.shapeInfo(), + z.specialShapeInfo(), &dimBuf, dimension.shapeInfo(), + dimension.specialShapeInfo()); + // x.printIndexedBuffer("Input"); + // exp.printIndexedBuffer("Reduce Sum"); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(NativeOpsTests, ReduceTest_9) { - auto x = NDArrayFactory::create('c', {5, 5}); - auto exp = NDArrayFactory::create(false); - auto z = NDArrayFactory::create(true); + auto x = NDArrayFactory::create('c', {5, 5}); + auto exp = NDArrayFactory::create(false); + auto z = NDArrayFactory::create(true); - auto dimension = NDArrayFactory::create('c', {2}, {0, 1}); - Nd4jPointer extra[6]; + auto dimension = NDArrayFactory::create('c', {2}, {0, 1}); + Nd4jPointer extra[6]; #ifdef __CUDABLAS__ - extra[1] = x.getContext()->getCudaStream(); - extra[0] = extra[2] = extra[3] = extra[4] = extra[5] = nullptr; - x.syncToHost(); + extra[1] = x.getContext()->getCudaStream(); + extra[0] = extra[2] = extra[3] = extra[4] = extra[5] = nullptr; + x.syncToHost(); #endif - x.linspace(1.0); - x.syncToDevice(); + x.linspace(1.0); + x.syncToDevice(); - dimension.syncToHost(); + dimension.syncToHost(); - OpaqueDataBuffer xBuf(x.dataBuffer()); - OpaqueDataBuffer dimBuf(dimension.dataBuffer()); - OpaqueDataBuffer expBuf(exp.dataBuffer()); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer dimBuf(dimension.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); - ::execReduceBool2(extra, - reduce::All, - &xBuf, x.shapeInfo(), x.specialShapeInfo(), - nullptr, - &expBuf, exp.shapeInfo(), exp.specialShapeInfo(), - &dimBuf, dimension.shapeInfo(), dimension.specialShapeInfo()); -// x.printIndexedBuffer("Input"); -// exp.printIndexedBuffer("Reduce All"); - ASSERT_TRUE(exp.equalsTo(z)); + ::execReduceBool2(extra, reduce::All, &xBuf, x.shapeInfo(), + x.specialShapeInfo(), nullptr, &expBuf, exp.shapeInfo(), + exp.specialShapeInfo(), &dimBuf, dimension.shapeInfo(), + dimension.specialShapeInfo()); + // x.printIndexedBuffer("Input"); + // exp.printIndexedBuffer("Reduce All"); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(NativeOpsTests, Reduce3Test_1) { - auto x = NDArrayFactory::create('c', {5, 5}); - auto y = NDArrayFactory::create('c', {5, 5}); - auto exp = NDArrayFactory::create(120.); - auto z = NDArrayFactory::create(650.); + auto x = NDArrayFactory::create('c', {5, 5}); + auto y = NDArrayFactory::create('c', {5, 5}); + auto exp = NDArrayFactory::create(120.); + auto z = NDArrayFactory::create(650.); - auto dimension = NDArrayFactory::create('c', {2}, {0, 1}); - Nd4jPointer extra[6]; + auto dimension = NDArrayFactory::create('c', {2}, {0, 1}); + Nd4jPointer extra[6]; #ifdef __CUDABLAS__ - extra[1] = x.getContext()->getCudaStream(); - extra[0] = extra[2] = extra[3] = extra[4] = extra[5] = nullptr; - x.syncToHost(); - printf("Unsupported for CUDA platform yet.\n"); - return; + extra[1] = x.getContext()->getCudaStream(); + extra[0] = extra[2] = extra[3] = extra[4] = extra[5] = nullptr; + x.syncToHost(); + printf("Unsupported for CUDA platform yet.\n"); + return; #endif - x.linspace(1.0); - y.assign(2.); - x.syncToDevice(); + x.linspace(1.0); + y.assign(2.); + x.syncToDevice(); - OpaqueDataBuffer xBuf(x.dataBuffer()); - OpaqueDataBuffer yBuf(y.dataBuffer()); - OpaqueDataBuffer expBuf(exp.dataBuffer()); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer yBuf(y.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); - ::execReduce3(extra, - reduce3::Dot, - &xBuf, x.shapeInfo(), x.specialShapeInfo(), - nullptr, - &yBuf, y.shapeInfo(), y.specialShapeInfo(), - &expBuf, exp.shapeInfo(), exp.specialShapeInfo()); - //z.printIndexedBuffer("Z"); - //exp.printIndexedBuffer("Reduce3 Dot"); - ASSERT_TRUE(exp.equalsTo(z)); + ::execReduce3(extra, reduce3::Dot, &xBuf, x.shapeInfo(), x.specialShapeInfo(), + nullptr, &yBuf, y.shapeInfo(), y.specialShapeInfo(), &expBuf, + exp.shapeInfo(), exp.specialShapeInfo()); + // z.printIndexedBuffer("Z"); + // exp.printIndexedBuffer("Reduce3 Dot"); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(NativeOpsTests, Reduce3Test_2) { - auto x = NDArrayFactory::create('c', {5, 5}); - auto y = NDArrayFactory::create('c', {5, 5}); - auto exp = NDArrayFactory::create(120.); - auto z = NDArrayFactory::create(650.); + auto x = NDArrayFactory::create('c', {5, 5}); + auto y = NDArrayFactory::create('c', {5, 5}); + auto exp = NDArrayFactory::create(120.); + auto z = NDArrayFactory::create(650.); - auto dimension = NDArrayFactory::create('c', {2}, {0, 1}); - Nd4jPointer extra[6]; + auto dimension = NDArrayFactory::create('c', {2}, {0, 1}); + Nd4jPointer extra[6]; #ifdef __CUDABLAS__ - extra[1] = x.getContext()->getCudaStream(); - extra[0] = extra[2] = extra[3] = extra[4] = extra[5] = nullptr; - x.syncToHost(); - printf("Unsupported for CUDA platform yet.\n"); - return; + extra[1] = x.getContext()->getCudaStream(); + extra[0] = extra[2] = extra[3] = extra[4] = extra[5] = nullptr; + x.syncToHost(); + printf("Unsupported for CUDA platform yet.\n"); + return; #endif - x.linspace(1.0); - y.assign(2.); - x.syncToDevice(); + x.linspace(1.0); + y.assign(2.); + x.syncToDevice(); - OpaqueDataBuffer xBuf(x.dataBuffer()); - OpaqueDataBuffer yBuf(y.dataBuffer()); - OpaqueDataBuffer expBuf(exp.dataBuffer()); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer yBuf(y.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); - ::execReduce3Scalar(extra, - reduce3::Dot, - &xBuf, x.shapeInfo(), x.specialShapeInfo(), - nullptr, - &yBuf, y.shapeInfo(), y.specialShapeInfo(), - &expBuf, exp.shapeInfo(), exp.specialShapeInfo()); -// x.printIndexedBuffer("Input"); -// exp.printIndexedBuffer("Reduce3 Dot"); - ASSERT_TRUE(exp.equalsTo(z)); + ::execReduce3Scalar(extra, reduce3::Dot, &xBuf, x.shapeInfo(), + x.specialShapeInfo(), nullptr, &yBuf, y.shapeInfo(), + y.specialShapeInfo(), &expBuf, exp.shapeInfo(), + exp.specialShapeInfo()); + // x.printIndexedBuffer("Input"); + // exp.printIndexedBuffer("Reduce3 Dot"); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(NativeOpsTests, Reduce3Test_3) { - auto x = NDArrayFactory::create('c', {5, 5}); - auto y = NDArrayFactory::create('c', {5, 5}); - auto exp = NDArrayFactory::create(120.); - auto z = NDArrayFactory::create(650.); + auto x = NDArrayFactory::create('c', {5, 5}); + auto y = NDArrayFactory::create('c', {5, 5}); + auto exp = NDArrayFactory::create(120.); + auto z = NDArrayFactory::create(650.); - auto dimension = NDArrayFactory::create('c', {2}, {0, 1}); - Nd4jPointer extra[6]; + auto dimension = NDArrayFactory::create('c', {2}, {0, 1}); + Nd4jPointer extra[6]; #ifdef __CUDABLAS__ - extra[1] = x.getContext()->getCudaStream(); - extra[0] = extra[2] = extra[3] = extra[4] = extra[5] = nullptr; - x.syncToHost(); - printf("Unsupported for CUDA platform yet.\n"); - return; + extra[1] = x.getContext()->getCudaStream(); + extra[0] = extra[2] = extra[3] = extra[4] = extra[5] = nullptr; + x.syncToHost(); + printf("Unsupported for CUDA platform yet.\n"); + return; #endif - x.linspace(1.0); - y.assign(2.); - x.syncToDevice(); - dimension.syncToHost(); - - OpaqueDataBuffer xBuf(x.dataBuffer()); - OpaqueDataBuffer yBuf(y.dataBuffer()); - OpaqueDataBuffer expBuf(exp.dataBuffer()); - OpaqueDataBuffer dimBuf(dimension.dataBuffer()); - - ::execReduce3Tad(extra, - reduce3::Dot, - &xBuf, x.shapeInfo(), x.specialShapeInfo(), - nullptr, - &yBuf, y.shapeInfo(), y.specialShapeInfo(), - &expBuf, exp.shapeInfo(), exp.specialShapeInfo(), - &dimBuf, dimension.shapeInfo(), dimension.specialShapeInfo(), - nullptr, nullptr, nullptr, nullptr); -// x.printIndexedBuffer("Input"); -// exp.printIndexedBuffer("Reduce All"); - ASSERT_TRUE(exp.equalsTo(z)); + x.linspace(1.0); + y.assign(2.); + x.syncToDevice(); + dimension.syncToHost(); + + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer yBuf(y.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + OpaqueDataBuffer dimBuf(dimension.dataBuffer()); + + ::execReduce3Tad( + extra, reduce3::Dot, &xBuf, x.shapeInfo(), x.specialShapeInfo(), nullptr, + &yBuf, y.shapeInfo(), y.specialShapeInfo(), &expBuf, exp.shapeInfo(), + exp.specialShapeInfo(), &dimBuf, dimension.shapeInfo(), + dimension.specialShapeInfo(), nullptr, nullptr, nullptr, nullptr); + // x.printIndexedBuffer("Input"); + // exp.printIndexedBuffer("Reduce All"); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(NativeOpsTests, Reduce3Test_4) { - auto x = NDArrayFactory::create('c', {5, 5}); - auto y = NDArrayFactory::create('c', {5, 5}); - auto exp = NDArrayFactory::create(120.); - auto z = NDArrayFactory::create(650.); + auto x = NDArrayFactory::create('c', {5, 5}); + auto y = NDArrayFactory::create('c', {5, 5}); + auto exp = NDArrayFactory::create(120.); + auto z = NDArrayFactory::create(650.); - auto dimension = NDArrayFactory::create('c', {2}, {0, 1}); - Nd4jPointer extra[6]; + auto dimension = NDArrayFactory::create('c', {2}, {0, 1}); + Nd4jPointer extra[6]; #ifdef __CUDABLAS__ - extra[1] = x.getContext()->getCudaStream(); - extra[0] = extra[2] = extra[3] = extra[4] = extra[5] = nullptr; - x.syncToHost(); - printf("Unsupported for CUDA platform yet.\n"); - return; + extra[1] = x.getContext()->getCudaStream(); + extra[0] = extra[2] = extra[3] = extra[4] = extra[5] = nullptr; + x.syncToHost(); + printf("Unsupported for CUDA platform yet.\n"); + return; #endif - x.linspace(1.0); - y.assign(2.); - x.syncToDevice(); - dimension.syncToHost(); - int* dimensions = reinterpret_cast(dimension.buffer()); - auto tadPackX = sd::ConstantTadHelper::getInstance().tadForDimensions(x.shapeInfo(), dimensions, dimension.lengthOf()); - auto tadPackY = sd::ConstantTadHelper::getInstance().tadForDimensions(y.shapeInfo(), dimensions, dimension.lengthOf()); - - auto hTADShapeInfoX = tadPackX.primaryShapeInfo(); - auto hTADOffsetsX = tadPackX.primaryOffsets(); - auto hTADShapeInfoY = tadPackY.primaryShapeInfo(); - auto hTADOffsetsY = tadPackY.primaryOffsets(); - - OpaqueDataBuffer xBuf(x.dataBuffer()); - OpaqueDataBuffer yBuf(y.dataBuffer()); - OpaqueDataBuffer expBuf(exp.dataBuffer()); - OpaqueDataBuffer dimBuf(dimension.dataBuffer()); - - ::execReduce3All(extra, - reduce3::Dot, - &xBuf, x.shapeInfo(), x.specialShapeInfo(), - nullptr, - &yBuf, y.shapeInfo(), y.specialShapeInfo(), - &expBuf, exp.shapeInfo(), exp.specialShapeInfo(), - &dimBuf, dimension.shapeInfo(), dimension.specialShapeInfo(), - hTADShapeInfoX, hTADOffsetsX, hTADShapeInfoY, hTADOffsetsY); -// x.printIndexedBuffer("Input"); -// exp.printIndexedBuffer("Reduce All"); - ASSERT_TRUE(exp.equalsTo(z)); + x.linspace(1.0); + y.assign(2.); + x.syncToDevice(); + dimension.syncToHost(); + int *dimensions = reinterpret_cast(dimension.buffer()); + auto tadPackX = sd::ConstantTadHelper::getInstance().tadForDimensions( + x.shapeInfo(), dimensions, dimension.lengthOf()); + auto tadPackY = sd::ConstantTadHelper::getInstance().tadForDimensions( + y.shapeInfo(), dimensions, dimension.lengthOf()); + + auto hTADShapeInfoX = tadPackX.primaryShapeInfo(); + auto hTADOffsetsX = tadPackX.primaryOffsets(); + auto hTADShapeInfoY = tadPackY.primaryShapeInfo(); + auto hTADOffsetsY = tadPackY.primaryOffsets(); + + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer yBuf(y.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + OpaqueDataBuffer dimBuf(dimension.dataBuffer()); + + ::execReduce3All(extra, reduce3::Dot, &xBuf, x.shapeInfo(), + x.specialShapeInfo(), nullptr, &yBuf, y.shapeInfo(), + y.specialShapeInfo(), &expBuf, exp.shapeInfo(), + exp.specialShapeInfo(), &dimBuf, dimension.shapeInfo(), + dimension.specialShapeInfo(), hTADShapeInfoX, hTADOffsetsX, + hTADShapeInfoY, hTADOffsetsY); + // x.printIndexedBuffer("Input"); + // exp.printIndexedBuffer("Reduce All"); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(NativeOpsTests, ScalarTest_1) { - auto x = NDArrayFactory::create('c', {5, 5}); - auto y = NDArrayFactory::create(10.); - auto exp = NDArrayFactory::create('c', {5,5}); - auto z = NDArrayFactory::create('c', {5,5}); + auto x = NDArrayFactory::create('c', {5, 5}); + auto y = NDArrayFactory::create(10.); + auto exp = NDArrayFactory::create('c', {5, 5}); + auto z = NDArrayFactory::create('c', {5, 5}); - Nd4jPointer extra[6]; + Nd4jPointer extra[6]; #ifdef __CUDABLAS__ - extra[1] = x.getContext()->getCudaStream(); - extra[0] = extra[2] = extra[3] = extra[4] = extra[5] = nullptr; - x.syncToHost(); - y.syncToHost(); - printf("Unsupported for CUDA platform yet.\n"); - return; + extra[1] = x.getContext()->getCudaStream(); + extra[0] = extra[2] = extra[3] = extra[4] = extra[5] = nullptr; + x.syncToHost(); + y.syncToHost(); + printf("Unsupported for CUDA platform yet.\n"); + return; #endif - x.linspace(1.0); - z.linspace(10., 10.); - //y.assign(2.); - x.syncToDevice(); - z.syncToDevice(); - - OpaqueDataBuffer xBuf(x.dataBuffer()); - OpaqueDataBuffer yBuf(y.dataBuffer()); - OpaqueDataBuffer expBuf(exp.dataBuffer()); - - ::execScalar(extra, - scalar::Multiply, - &xBuf, x.shapeInfo(), x.specialShapeInfo(), - &expBuf, exp.shapeInfo(), exp.specialShapeInfo(), - &yBuf, y.shapeInfo(), y.specialShapeInfo(), nullptr); -// x.printIndexedBuffer("Input"); -// exp.printIndexedBuffer("Reduce All"); - ASSERT_TRUE(exp.equalsTo(z)); + x.linspace(1.0); + z.linspace(10., 10.); + // y.assign(2.); + x.syncToDevice(); + z.syncToDevice(); + + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer yBuf(y.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + + ::execScalar(extra, scalar::Multiply, &xBuf, x.shapeInfo(), + x.specialShapeInfo(), &expBuf, exp.shapeInfo(), + exp.specialShapeInfo(), &yBuf, y.shapeInfo(), + y.specialShapeInfo(), nullptr); + // x.printIndexedBuffer("Input"); + // exp.printIndexedBuffer("Reduce All"); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(NativeOpsTests, ScalarTest_2) { - auto x = NDArrayFactory::create('c', {5, 5}); - auto y = NDArrayFactory::create(10.f); - auto exp = NDArrayFactory::create('c', {5,5}); - auto z = NDArrayFactory::create('c', {5,5}); + auto x = NDArrayFactory::create('c', {5, 5}); + auto y = NDArrayFactory::create(10.f); + auto exp = NDArrayFactory::create('c', {5, 5}); + auto z = NDArrayFactory::create('c', {5, 5}); - Nd4jPointer extra[6]; + Nd4jPointer extra[6]; #ifdef __CUDABLAS__ - extra[1] = x.getContext()->getCudaStream(); - extra[0] = extra[2] = extra[3] = extra[4] = extra[5] = nullptr; - x.syncToHost(); - y.syncToHost(); - printf("Unsupported for CUDA platform yet.\n"); - return; + extra[1] = x.getContext()->getCudaStream(); + extra[0] = extra[2] = extra[3] = extra[4] = extra[5] = nullptr; + x.syncToHost(); + y.syncToHost(); + printf("Unsupported for CUDA platform yet.\n"); + return; #endif - x.linspace(1.0); - z.assign(false); - //y.assign(2.); - x.syncToDevice(); - z.syncToDevice(); - - OpaqueDataBuffer xBuf(x.dataBuffer()); - OpaqueDataBuffer yBuf(y.dataBuffer()); - OpaqueDataBuffer expBuf(exp.dataBuffer()); - - ::execScalarBool(extra, - scalar::GreaterThan, - &xBuf, x.shapeInfo(), x.specialShapeInfo(), - &expBuf, exp.shapeInfo(), exp.specialShapeInfo(), - &yBuf, y.shapeInfo(), y.specialShapeInfo(), nullptr); -// x.printIndexedBuffer("Input"); -// exp.printIndexedBuffer("Reduce All"); - ASSERT_TRUE(exp.e(5) == z.e(5) && exp.e(15) != z.e(15)); + x.linspace(1.0); + z.assign(false); + // y.assign(2.); + x.syncToDevice(); + z.syncToDevice(); + + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer yBuf(y.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + + ::execScalarBool(extra, scalar::GreaterThan, &xBuf, x.shapeInfo(), + x.specialShapeInfo(), &expBuf, exp.shapeInfo(), + exp.specialShapeInfo(), &yBuf, y.shapeInfo(), + y.specialShapeInfo(), nullptr); + // x.printIndexedBuffer("Input"); + // exp.printIndexedBuffer("Reduce All"); + ASSERT_TRUE(exp.e(5) == z.e(5) && + exp.e(15) != z.e(15)); } TEST_F(NativeOpsTests, SummaryStatsScalarTest_1) { - auto x = NDArrayFactory::create('c', {5, 5}, {0.1f, 0.2f, 0.3f, -0.3f, -0.5f, 0.5f, 0.7f, 0.9f, 0.8f, 0.1f, 0.11f, 0.12f, 0.5f, -0.8f, -0.9f, 0.4f, 0.1f, 0.2f, 0.3f, -0.3f, -0.5f, 0.2f, 0.3f, -0.3f, -0.5f}); - auto exp = NDArrayFactory::create(0.9f); - auto z = NDArrayFactory::create(0.21587136f); - - Nd4jPointer extra[6]; + auto x = NDArrayFactory::create( + 'c', {5, 5}, {0.1f, 0.2f, 0.3f, -0.3f, -0.5f, 0.5f, 0.7f, 0.9f, 0.8f, + 0.1f, 0.11f, 0.12f, 0.5f, -0.8f, -0.9f, 0.4f, 0.1f, 0.2f, + 0.3f, -0.3f, -0.5f, 0.2f, 0.3f, -0.3f, -0.5f}); + auto exp = NDArrayFactory::create(0.9f); + auto z = NDArrayFactory::create(0.21587136f); + + Nd4jPointer extra[6]; #ifdef __CUDABLAS__ - extra[1] = x.getContext()->getCudaStream(); - extra[0] = extra[2] = extra[3] = extra[4] = extra[5] = nullptr; - x.syncToHost(); - printf("Unsupported for CUDA platform yet.\n"); - return; + extra[1] = x.getContext()->getCudaStream(); + extra[0] = extra[2] = extra[3] = extra[4] = extra[5] = nullptr; + x.syncToHost(); + printf("Unsupported for CUDA platform yet.\n"); + return; #endif - OpaqueDataBuffer xBuf(x.dataBuffer()); - OpaqueDataBuffer expBuf(exp.dataBuffer()); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); - ::execSummaryStatsScalar(extra, - variance::SummaryStatsVariance, - &xBuf, x.shapeInfo(), x.specialShapeInfo(), - nullptr, - &expBuf, exp.shapeInfo(), exp.specialShapeInfo(), false); -// x.printIndexedBuffer("Input"); -// exp.printIndexedBuffer("Standard Variance"); - ASSERT_TRUE(exp.equalsTo(z)); + ::execSummaryStatsScalar(extra, variance::SummaryStatsVariance, &xBuf, + x.shapeInfo(), x.specialShapeInfo(), nullptr, + &expBuf, exp.shapeInfo(), exp.specialShapeInfo(), + false); + // x.printIndexedBuffer("Input"); + // exp.printIndexedBuffer("Standard Variance"); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(NativeOpsTests, SummaryStatsScalarTest_2) { - auto x = NDArrayFactory::create('c', {5, 5}, {0.1, 0.2, 0.3, -0.3, -0.5, 0.5, 0.7, 0.9, 0.8, 0.1, 0.11, 0.12, 0.5, -0.8, -0.9, 0.4, 0.1, 0.2, 0.3, -0.3, -0.5, 0.2, 0.3, -0.3, -0.5}); - auto exp = NDArrayFactory::create(0.9); - auto z = NDArrayFactory::create(0.21587136); - - Nd4jPointer extra[6]; + auto x = NDArrayFactory::create( + 'c', {5, 5}, + {0.1, 0.2, 0.3, -0.3, -0.5, 0.5, 0.7, 0.9, 0.8, 0.1, 0.11, 0.12, 0.5, + -0.8, -0.9, 0.4, 0.1, 0.2, 0.3, -0.3, -0.5, 0.2, 0.3, -0.3, -0.5}); + auto exp = NDArrayFactory::create(0.9); + auto z = NDArrayFactory::create(0.21587136); + + Nd4jPointer extra[6]; #ifdef __CUDABLAS__ - extra[1] = x.getContext()->getCudaStream(); - extra[0] = extra[2] = extra[3] = extra[4] = extra[5] = nullptr; - x.syncToHost(); - printf("Unsupported for CUDA platform yet.\n"); - return; + extra[1] = x.getContext()->getCudaStream(); + extra[0] = extra[2] = extra[3] = extra[4] = extra[5] = nullptr; + x.syncToHost(); + printf("Unsupported for CUDA platform yet.\n"); + return; #endif - OpaqueDataBuffer xBuf(x.dataBuffer()); - OpaqueDataBuffer expBuf(exp.dataBuffer()); - ::execSummaryStats(extra, - variance::SummaryStatsVariance, - &xBuf, x.shapeInfo(), x.specialShapeInfo(), - nullptr, - &expBuf, exp.shapeInfo(), exp.specialShapeInfo(), false); -// x.printIndexedBuffer("Input"); -// exp.printIndexedBuffer("Standard Variance"); - ASSERT_TRUE(exp.equalsTo(z)); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + ::execSummaryStats(extra, variance::SummaryStatsVariance, &xBuf, + x.shapeInfo(), x.specialShapeInfo(), nullptr, &expBuf, + exp.shapeInfo(), exp.specialShapeInfo(), false); + // x.printIndexedBuffer("Input"); + // exp.printIndexedBuffer("Standard Variance"); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(NativeOpsTests, SummaryStatsScalarTest_3) { - auto x = NDArrayFactory::create('c', {5, 5}, {0.1, 0.2, 0.3, -0.3, -0.5, 0.5, 0.7, 0.9, 0.8, 0.1, 0.11, 0.12, 0.5, -0.8, -0.9, 0.4, 0.1, 0.2, 0.3, -0.3, -0.5, 0.2, 0.3, -0.3, -0.5}); - auto exp = NDArrayFactory::create(0.9); - auto z = NDArrayFactory::create(0.21587136); - - Nd4jPointer extra[6]; + auto x = NDArrayFactory::create( + 'c', {5, 5}, + {0.1, 0.2, 0.3, -0.3, -0.5, 0.5, 0.7, 0.9, 0.8, 0.1, 0.11, 0.12, 0.5, + -0.8, -0.9, 0.4, 0.1, 0.2, 0.3, -0.3, -0.5, 0.2, 0.3, -0.3, -0.5}); + auto exp = NDArrayFactory::create(0.9); + auto z = NDArrayFactory::create(0.21587136); + + Nd4jPointer extra[6]; #ifdef __CUDABLAS__ - extra[1] = x.getContext()->getCudaStream(); - extra[0] = extra[2] = extra[3] = extra[4] = extra[5] = nullptr; - x.syncToHost(); - printf("Unsupported for CUDA platform yet.\n"); - return; + extra[1] = x.getContext()->getCudaStream(); + extra[0] = extra[2] = extra[3] = extra[4] = extra[5] = nullptr; + x.syncToHost(); + printf("Unsupported for CUDA platform yet.\n"); + return; #endif - auto dimensions = NDArrayFactory::create({0, 1}); - OpaqueDataBuffer xBuf(x.dataBuffer()); - OpaqueDataBuffer expBuf(exp.dataBuffer()); - OpaqueDataBuffer dimBuf(dimensions.dataBuffer()); - - ::execSummaryStatsTad(extra, - variance::SummaryStatsVariance, - &xBuf, x.shapeInfo(), x.specialShapeInfo(), - nullptr, - &expBuf, exp.shapeInfo(), exp.specialShapeInfo(), - &dimBuf, dimensions.shapeInfo(), dimensions.specialShapeInfo(), - false, - nullptr, nullptr); -// x.printIndexedBuffer("Input"); -// exp.printIndexedBuffer("Standard Variance"); - ASSERT_TRUE(exp.equalsTo(z)); + auto dimensions = NDArrayFactory::create({0, 1}); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + OpaqueDataBuffer dimBuf(dimensions.dataBuffer()); + + ::execSummaryStatsTad(extra, variance::SummaryStatsVariance, &xBuf, + x.shapeInfo(), x.specialShapeInfo(), nullptr, &expBuf, + exp.shapeInfo(), exp.specialShapeInfo(), &dimBuf, + dimensions.shapeInfo(), dimensions.specialShapeInfo(), + false, nullptr, nullptr); + // x.printIndexedBuffer("Input"); + // exp.printIndexedBuffer("Standard Variance"); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(NativeOpsTests, TransformTest_1) { - auto x = NDArrayFactory::create('c', {5, 5}, {1, 4, 9, 16, 25, 36, 49, 64, 81, 100, 121, 144, 169, 196, 225, 256, 289, 324, 361, 400, 441, 484, 529, 576, 625}); - auto exp = NDArrayFactory::create('c', {5, 5}); - auto z = NDArrayFactory::create('c', {5,5}); - - Nd4jPointer extra[6]; + auto x = NDArrayFactory::create( + 'c', {5, 5}, + {1, 4, 9, 16, 25, 36, 49, 64, 81, 100, 121, 144, 169, + 196, 225, 256, 289, 324, 361, 400, 441, 484, 529, 576, 625}); + auto exp = NDArrayFactory::create('c', {5, 5}); + auto z = NDArrayFactory::create('c', {5, 5}); + + Nd4jPointer extra[6]; #ifdef __CUDABLAS__ - extra[1] = x.getContext()->getCudaStream(); - extra[0] = extra[2] = extra[3] = extra[4] = extra[5] = nullptr; - x.syncToHost(); - printf("Unsupported for CUDA platform yet.\n"); - return; + extra[1] = x.getContext()->getCudaStream(); + extra[0] = extra[2] = extra[3] = extra[4] = extra[5] = nullptr; + x.syncToHost(); + printf("Unsupported for CUDA platform yet.\n"); + return; #endif - z.linspace(1.); + z.linspace(1.); - OpaqueDataBuffer xBuf(x.dataBuffer()); - OpaqueDataBuffer zBuf(z.dataBuffer()); - OpaqueDataBuffer expBuf(exp.dataBuffer()); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer zBuf(z.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); - ::execTransformFloat(extra, - transform::Sqrt, - &xBuf, x.shapeInfo(), x.specialShapeInfo(), - &expBuf, exp.shapeInfo(), exp.specialShapeInfo(), - nullptr); -// x.printIndexedBuffer("Input"); -// exp.printIndexedBuffer("Sqrt is"); - ASSERT_TRUE(exp.equalsTo(z)); + ::execTransformFloat(extra, transform::Sqrt, &xBuf, x.shapeInfo(), + x.specialShapeInfo(), &expBuf, exp.shapeInfo(), + exp.specialShapeInfo(), nullptr); + // x.printIndexedBuffer("Input"); + // exp.printIndexedBuffer("Sqrt is"); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(NativeOpsTests, TransformTest_2) { - auto x = NDArrayFactory::create('c', {5, 5}, {1.f, 4.f, 9.f, 16.f, 25.f, 36.f, 49.f, 64.f, 81.f, 100.f, 121.f, 144.f, 169.f, 196.f, 225.f, 256.f, 289.f, 324.f, 361.f, 400.f, 441.f, 484.f, 529.f, 576.f, 625.f}); - auto exp = NDArrayFactory::create('c', {5, 5}); - auto z = NDArrayFactory::create('c', {5,5}); - - Nd4jPointer extra[6]; + auto x = NDArrayFactory::create( + 'c', {5, 5}, + {1.f, 4.f, 9.f, 16.f, 25.f, 36.f, 49.f, 64.f, 81.f, + 100.f, 121.f, 144.f, 169.f, 196.f, 225.f, 256.f, 289.f, 324.f, + 361.f, 400.f, 441.f, 484.f, 529.f, 576.f, 625.f}); + auto exp = NDArrayFactory::create('c', {5, 5}); + auto z = NDArrayFactory::create('c', {5, 5}); + + Nd4jPointer extra[6]; #ifdef __CUDABLAS__ - extra[1] = x.getContext()->getCudaStream(); - extra[0] = extra[2] = extra[3] = extra[4] = extra[5] = nullptr; - x.syncToHost(); - printf("Unsupported for CUDA platform yet.\n"); - return; + extra[1] = x.getContext()->getCudaStream(); + extra[0] = extra[2] = extra[3] = extra[4] = extra[5] = nullptr; + x.syncToHost(); + printf("Unsupported for CUDA platform yet.\n"); + return; #endif - z.linspace(1.); + z.linspace(1.); - OpaqueDataBuffer xBuf(x.dataBuffer()); - OpaqueDataBuffer zBuf(z.dataBuffer()); - OpaqueDataBuffer expBuf(exp.dataBuffer()); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer zBuf(z.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); - ::execTransformSame(extra, - transform::Square, - &zBuf, z.shapeInfo(), z.specialShapeInfo(), - &expBuf, exp.shapeInfo(), exp.specialShapeInfo(), - nullptr); -// x.printIndexedBuffer("Input"); -// exp.printIndexedBuffer("Square is"); - ASSERT_TRUE(exp.equalsTo(x)); + ::execTransformSame(extra, transform::Square, &zBuf, z.shapeInfo(), + z.specialShapeInfo(), &expBuf, exp.shapeInfo(), + exp.specialShapeInfo(), nullptr); + // x.printIndexedBuffer("Input"); + // exp.printIndexedBuffer("Square is"); + ASSERT_TRUE(exp.equalsTo(x)); } TEST_F(NativeOpsTests, TransformTest_3) { - auto x = NDArrayFactory::create('c', {5, 5}); - auto exp = NDArrayFactory::create('c', {5, 5}); - auto z = NDArrayFactory::create('c', {5,5}); + auto x = NDArrayFactory::create('c', {5, 5}); + auto exp = NDArrayFactory::create('c', {5, 5}); + auto z = NDArrayFactory::create('c', {5, 5}); - Nd4jPointer extra[6]; + Nd4jPointer extra[6]; #ifdef __CUDABLAS__ - extra[1] = x.getContext()->getCudaStream(); - extra[0] = extra[2] = extra[3] = extra[4] = extra[5] = nullptr; - x.syncToHost(); - printf("Unsupported for CUDA platform yet.\n"); - return; + extra[1] = x.getContext()->getCudaStream(); + extra[0] = extra[2] = extra[3] = extra[4] = extra[5] = nullptr; + x.syncToHost(); + printf("Unsupported for CUDA platform yet.\n"); + return; #endif - x.linspace(1.); - z.assign(true); - x.p(24, -25); - z.p(24, false); + x.linspace(1.); + z.assign(true); + x.p(24, -25); + z.p(24, false); - OpaqueDataBuffer xBuf(x.dataBuffer()); - OpaqueDataBuffer expBuf(exp.dataBuffer()); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); - ::execTransformBool(extra, - transform::IsPositive, - &xBuf, x.shapeInfo(), x.specialShapeInfo(), - &expBuf, exp.shapeInfo(), exp.specialShapeInfo(), - nullptr); -// x.printIndexedBuffer("Input"); -// exp.printIndexedBuffer("IsPositive"); - ASSERT_TRUE(exp.equalsTo(z)); + ::execTransformBool(extra, transform::IsPositive, &xBuf, x.shapeInfo(), + x.specialShapeInfo(), &expBuf, exp.shapeInfo(), + exp.specialShapeInfo(), nullptr); + // x.printIndexedBuffer("Input"); + // exp.printIndexedBuffer("IsPositive"); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(NativeOpsTests, TransformTest_4) { - auto x = NDArrayFactory::create('c', {5, 5}, {0, 1, 2, 3, 2, 1, 0, 1.57, 1.57, 1.57, 3.141592, 3.141592, - 3.141592, 0, 0, 0, 0, 1, 1, 2, 2, 2, 1, 0, 0}); - auto exp = NDArrayFactory::create('c', {5, 5}); - auto z = NDArrayFactory::create('c', {5,5}, {1., 0.540302, -0.416147, -0.989992, -0.416147, 0.540302, 1.0, - 0.000796, 0.000796, 0.000796, -1, -1, -1, 1., 1., 1.0, 1.0, - 0.540302, 0.540302, -0.416147, -0.416147, -0.416147, 0.540302, 1., 1.}); - - Nd4jPointer extra[6]; + auto x = NDArrayFactory::create( + 'c', {5, 5}, + {0, 1, 2, 3, 2, 1, 0, 1.57, 1.57, 1.57, 3.141592, 3.141592, 3.141592, + 0, 0, 0, 0, 1, 1, 2, 2, 2, 1, 0, 0}); + auto exp = NDArrayFactory::create('c', {5, 5}); + auto z = NDArrayFactory::create( + 'c', {5, 5}, {1., 0.540302, -0.416147, -0.989992, -0.416147, + 0.540302, 1.0, 0.000796, 0.000796, 0.000796, + -1, -1, -1, 1., 1., + 1.0, 1.0, 0.540302, 0.540302, -0.416147, + -0.416147, -0.416147, 0.540302, 1., 1.}); + + Nd4jPointer extra[6]; #ifdef __CUDABLAS__ - extra[1] = x.getContext()->getCudaStream(); - extra[0] = extra[2] = extra[3] = extra[4] = extra[5] = nullptr; - x.syncToHost(); - printf("Unsupported for CUDA platform yet.\n"); - return; + extra[1] = x.getContext()->getCudaStream(); + extra[0] = extra[2] = extra[3] = extra[4] = extra[5] = nullptr; + x.syncToHost(); + printf("Unsupported for CUDA platform yet.\n"); + return; #endif - //z.linspace(1.); - OpaqueDataBuffer xBuf(x.dataBuffer()); - OpaqueDataBuffer expBuf(exp.dataBuffer()); + // z.linspace(1.); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); - ::execTransformStrict(extra, - transform::Cosine, - &xBuf, x.shapeInfo(), x.specialShapeInfo(), - &expBuf, exp.shapeInfo(), exp.specialShapeInfo(), - nullptr); -// x.printIndexedBuffer("Input"); -// exp.printIndexedBuffer("Cosine"); - ASSERT_TRUE(exp.equalsTo(z)); + ::execTransformStrict(extra, transform::Cosine, &xBuf, x.shapeInfo(), + x.specialShapeInfo(), &expBuf, exp.shapeInfo(), + exp.specialShapeInfo(), nullptr); + // x.printIndexedBuffer("Input"); + // exp.printIndexedBuffer("Cosine"); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(NativeOpsTests, ScalarTadTest_1) { - auto x = NDArrayFactory::create('c', {5, 5}); - auto y = NDArrayFactory::create(10.f); - auto exp = NDArrayFactory::create('c', {5,5}); - auto z = NDArrayFactory::create('c', {5,5}); + auto x = NDArrayFactory::create('c', {5, 5}); + auto y = NDArrayFactory::create(10.f); + auto exp = NDArrayFactory::create('c', {5, 5}); + auto z = NDArrayFactory::create('c', {5, 5}); - Nd4jPointer extra[6]; + Nd4jPointer extra[6]; #ifdef __CUDABLAS__ - extra[1] = x.getContext()->getCudaStream(); - extra[0] = extra[2] = extra[3] = extra[4] = extra[5] = nullptr; - x.syncToHost(); - y.syncToHost(); - printf("Unsupported for CUDA platform yet.\n"); - return; + extra[1] = x.getContext()->getCudaStream(); + extra[0] = extra[2] = extra[3] = extra[4] = extra[5] = nullptr; + x.syncToHost(); + y.syncToHost(); + printf("Unsupported for CUDA platform yet.\n"); + return; #endif - x.linspace(1.0); - z.linspace(10., 10.); - //y.assign(2.); - x.syncToDevice(); - z.syncToDevice(); - auto dimension = NDArrayFactory::create({0, 1}); - auto dimensions = reinterpret_cast(dimension.buffer()); - auto tadPackX = sd::ConstantTadHelper::getInstance().tadForDimensions(x.shapeInfo(), dimensions, dimension.lengthOf()); - auto tadPackZ = sd::ConstantTadHelper::getInstance().tadForDimensions(z.shapeInfo(), dimensions, dimension.lengthOf()); - - OpaqueDataBuffer xBuf(x.dataBuffer()); - OpaqueDataBuffer yBuf(y.dataBuffer()); - OpaqueDataBuffer expBuf(exp.dataBuffer()); - OpaqueDataBuffer dimBuf(dimension.dataBuffer()); - - ::execScalarTad(extra, - scalar::Multiply, - &xBuf, x.shapeInfo(), x.specialShapeInfo(), - &expBuf, exp.shapeInfo(), exp.specialShapeInfo(), - &yBuf, y.shapeInfo(), y.specialShapeInfo(), - nullptr, - &dimBuf, dimension.shapeInfo(), dimension.specialShapeInfo(), - tadPackX.primaryShapeInfo(), tadPackX.primaryOffsets(), tadPackZ.primaryShapeInfo(), tadPackZ.primaryOffsets()); -// x.printIndexedBuffer("Input"); -// exp.printIndexedBuffer("Reduce All"); - ASSERT_TRUE(exp.equalsTo(z)); + x.linspace(1.0); + z.linspace(10., 10.); + // y.assign(2.); + x.syncToDevice(); + z.syncToDevice(); + auto dimension = NDArrayFactory::create({0, 1}); + auto dimensions = reinterpret_cast(dimension.buffer()); + auto tadPackX = sd::ConstantTadHelper::getInstance().tadForDimensions( + x.shapeInfo(), dimensions, dimension.lengthOf()); + auto tadPackZ = sd::ConstantTadHelper::getInstance().tadForDimensions( + z.shapeInfo(), dimensions, dimension.lengthOf()); + + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer yBuf(y.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + OpaqueDataBuffer dimBuf(dimension.dataBuffer()); + + ::execScalarTad(extra, scalar::Multiply, &xBuf, x.shapeInfo(), + x.specialShapeInfo(), &expBuf, exp.shapeInfo(), + exp.specialShapeInfo(), &yBuf, y.shapeInfo(), + y.specialShapeInfo(), nullptr, &dimBuf, dimension.shapeInfo(), + dimension.specialShapeInfo(), tadPackX.primaryShapeInfo(), + tadPackX.primaryOffsets(), tadPackZ.primaryShapeInfo(), + tadPackZ.primaryOffsets()); + // x.printIndexedBuffer("Input"); + // exp.printIndexedBuffer("Reduce All"); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(NativeOpsTests, ScalarTadTest_2) { - auto x = NDArrayFactory::create('c', {5, 5}); - auto y = NDArrayFactory::create(true); - auto exp = NDArrayFactory::create('c', {5,5}); - auto z = NDArrayFactory::create('c', {5, 5}); + auto x = NDArrayFactory::create('c', {5, 5}); + auto y = NDArrayFactory::create(true); + auto exp = NDArrayFactory::create('c', {5, 5}); + auto z = NDArrayFactory::create('c', {5, 5}); - Nd4jPointer extra[6]; + Nd4jPointer extra[6]; #ifdef __CUDABLAS__ - extra[1] = x.getContext()->getCudaStream(); - extra[0] = extra[2] = extra[3] = extra[4] = extra[5] = nullptr; - x.syncToHost(); - y.syncToHost(); - printf("Unsupported for CUDA platform yet.\n"); - return; + extra[1] = x.getContext()->getCudaStream(); + extra[0] = extra[2] = extra[3] = extra[4] = extra[5] = nullptr; + x.syncToHost(); + y.syncToHost(); + printf("Unsupported for CUDA platform yet.\n"); + return; #endif - x.assign(false); - x.p(5, true); - x.p(15, true); - //z.linspace(10., 10.); - //y.assign(2.); - x.syncToDevice(); - z.syncToDevice(); - auto dimension = NDArrayFactory::create({0, 1}); - auto dimensions = reinterpret_cast(dimension.buffer()); - auto tadPackX = sd::ConstantTadHelper::getInstance().tadForDimensions(x.shapeInfo(), dimensions, dimension.lengthOf()); - auto tadPackZ = sd::ConstantTadHelper::getInstance().tadForDimensions(z.shapeInfo(), dimensions, dimension.lengthOf()); - z.assign(true); - - OpaqueDataBuffer xBuf(x.dataBuffer()); - OpaqueDataBuffer yBuf(y.dataBuffer()); - OpaqueDataBuffer expBuf(exp.dataBuffer()); - OpaqueDataBuffer dimBuf(dimension.dataBuffer()); - - ::execScalarBoolTad(extra, - scalar::And, - &xBuf, x.shapeInfo(), x.specialShapeInfo(), - &expBuf, exp.shapeInfo(), - exp.specialShapeInfo(), - &yBuf, y.shapeInfo(), - y.specialShapeInfo(), - nullptr, - &dimBuf, dimension.shapeInfo(), - dimension.specialShapeInfo(), - tadPackX.primaryShapeInfo(), tadPackX.primaryOffsets(), tadPackZ.primaryShapeInfo(), tadPackZ.primaryOffsets()); -// x.printIndexedBuffer("Input"); -// exp.printIndexedBuffer("And"); - ASSERT_TRUE(exp.e(5) == z.e(5) && exp.e(15)); + x.assign(false); + x.p(5, true); + x.p(15, true); + // z.linspace(10., 10.); + // y.assign(2.); + x.syncToDevice(); + z.syncToDevice(); + auto dimension = NDArrayFactory::create({0, 1}); + auto dimensions = reinterpret_cast(dimension.buffer()); + auto tadPackX = sd::ConstantTadHelper::getInstance().tadForDimensions( + x.shapeInfo(), dimensions, dimension.lengthOf()); + auto tadPackZ = sd::ConstantTadHelper::getInstance().tadForDimensions( + z.shapeInfo(), dimensions, dimension.lengthOf()); + z.assign(true); + + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer yBuf(y.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + OpaqueDataBuffer dimBuf(dimension.dataBuffer()); + + ::execScalarBoolTad(extra, scalar::And, &xBuf, x.shapeInfo(), + x.specialShapeInfo(), &expBuf, exp.shapeInfo(), + exp.specialShapeInfo(), &yBuf, y.shapeInfo(), + y.specialShapeInfo(), nullptr, &dimBuf, + dimension.shapeInfo(), dimension.specialShapeInfo(), + tadPackX.primaryShapeInfo(), tadPackX.primaryOffsets(), + tadPackZ.primaryShapeInfo(), tadPackZ.primaryOffsets()); + // x.printIndexedBuffer("Input"); + // exp.printIndexedBuffer("And"); + ASSERT_TRUE(exp.e(5) == z.e(5) && exp.e(15)); } TEST_F(NativeOpsTests, ConcatTest_2) { - auto x = NDArrayFactory::create('c', {5, 5}); - auto y = NDArrayFactory::create('c', {5, 5}); - auto exp = NDArrayFactory::create('c', {10,5}); - auto z = NDArrayFactory::create('c', {10,5}); + auto x = NDArrayFactory::create('c', {5, 5}); + auto y = NDArrayFactory::create('c', {5, 5}); + auto exp = NDArrayFactory::create('c', {10, 5}); + auto z = NDArrayFactory::create('c', {10, 5}); - Nd4jPointer extra[6]; + Nd4jPointer extra[6]; #ifdef __CUDABLAS__ - extra[1] = x.getContext()->getCudaStream(); - extra[0] = extra[2] = extra[3] = extra[4] = extra[5] = nullptr; - x.syncToHost(); - y.syncToHost(); - printf("Unsupported for CUDA platform yet.\n"); - return; + extra[1] = x.getContext()->getCudaStream(); + extra[0] = extra[2] = extra[3] = extra[4] = extra[5] = nullptr; + x.syncToHost(); + y.syncToHost(); + printf("Unsupported for CUDA platform yet.\n"); + return; #endif - x.linspace(1.0); - y.linspace(26); - - //y.assign(2.); - x.syncToDevice(); - z.syncToDevice(); - int d = 0; - auto dimension = NDArrayFactory::create('c', {1}, {d}); - auto dimensions = reinterpret_cast(dimension.buffer()); - //auto tadPackX = sd::ConstantTadHelper::getInstance().tadForDimensions(x.shapeInfo(), dimensions, dimension.lengthOf()); - auto tadPackZ = sd::ConstantTadHelper::getInstance().tadForDimensions(z.shapeInfo(), dimensions, dimension.lengthOf()); - exp.linspace(1); - Nd4jPointer datas[] = {x.buffer(), y.buffer()}; - Nd4jPointer shapes[] = {(Nd4jPointer)x.shapeInfo(), (Nd4jPointer)y.shapeInfo()}; - - ::specialConcat(extra, 0, 2, datas, shapes, z.buffer(), z.shapeInfo(), nullptr, nullptr); - -// exp.printIndexedBuffer("Exp"); -// z.printIndexedBuffer("Concat"); - ASSERT_TRUE(exp.equalsTo(z)); + x.linspace(1.0); + y.linspace(26); + + // y.assign(2.); + x.syncToDevice(); + z.syncToDevice(); + int d = 0; + auto dimension = NDArrayFactory::create('c', {1}, {d}); + auto dimensions = reinterpret_cast(dimension.buffer()); + // auto tadPackX = + // sd::ConstantTadHelper::getInstance().tadForDimensions(x.shapeInfo(), + // dimensions, dimension.lengthOf()); + auto tadPackZ = sd::ConstantTadHelper::getInstance().tadForDimensions( + z.shapeInfo(), dimensions, dimension.lengthOf()); + exp.linspace(1); + Nd4jPointer datas[] = {x.buffer(), y.buffer()}; + Nd4jPointer shapes[] = {(Nd4jPointer)x.shapeInfo(), + (Nd4jPointer)y.shapeInfo()}; + + ::specialConcat(extra, 0, 2, datas, shapes, z.buffer(), z.shapeInfo(), + nullptr, nullptr); + + // exp.printIndexedBuffer("Exp"); + // z.printIndexedBuffer("Concat"); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(NativeOpsTests, InitializeTest_1) { -// ::initializeDevicesAndFunctions(); + // ::initializeDevicesAndFunctions(); } TEST_F(NativeOpsTests, MallocTest_1) { - auto a = ::mallocHost(16, 0); - ::freeHost(a); - auto dA = ::mallocDevice(16, 0, 0); - ::freeDevice(dA, 0); + auto a = ::mallocHost(16, 0); + ::freeHost(a); + auto dA = ::mallocDevice(16, 0, 0); + ::freeDevice(dA, 0); } TEST_F(NativeOpsTests, OMPTest_1) { - auto maxThreads = ::ompGetMaxThreads(); - auto numThreads = ::ompGetNumThreads(); - //::setOmpMinThreads(maxThreads); - //::setOmpNumThreads(numThreads); + auto maxThreads = ::ompGetMaxThreads(); + auto numThreads = ::ompGetNumThreads(); + //::setOmpMinThreads(maxThreads); + //::setOmpNumThreads(numThreads); } TEST_F(NativeOpsTests, CreateTest_1) { - auto xx = ::createContext(); - auto yy = ::createStream(); - auto zz = ::createEvent(); - ::destroyEvent(zz); - if (xx) - delete (LaunchContext*)xx; - if (yy) - printf("Stream should be destoyed before."); - + auto xx = ::createContext(); + auto yy = ::createStream(); + auto zz = ::createEvent(); + ::destroyEvent(zz); + if (xx) delete (LaunchContext *)xx; + if (yy) printf("Stream should be destoyed before."); } TEST_F(NativeOpsTests, MemTest_1) { - auto x = NDArrayFactory::create({10, 20, 30, 40, 50}); - auto y = NDArrayFactory::create({20, 20, 20, 20, 20}); + auto x = NDArrayFactory::create({10, 20, 30, 40, 50}); + auto y = NDArrayFactory::create({20, 20, 20, 20, 20}); #ifdef __CUDABLAS__ - return ; + return; #endif - //ASSERT_TRUE(0 == ::memcpy(x.buffer(), y.buffer(), x.lengthOf() * sizeof(double), 0, nullptr)); - ASSERT_TRUE(0 == ::memcpyAsync(x.buffer(), y.buffer(), x.lengthOf() * sizeof(double), 0, nullptr)); - //ASSERT_TRUE(0 == ::memset(x.buffer(), 119, x.lengthOf() * sizeof(double), 0, nullptr)); - ASSERT_TRUE(0 == ::memsetAsync(x.buffer(), 119, x.lengthOf() * sizeof(double), 0, nullptr)); - + // ASSERT_TRUE(0 == ::memcpy(x.buffer(), y.buffer(), x.lengthOf() * + // sizeof(double), 0, nullptr)); + ASSERT_TRUE(0 == ::memcpyAsync(x.buffer(), y.buffer(), + x.lengthOf() * sizeof(double), 0, nullptr)); + // ASSERT_TRUE(0 == ::memset(x.buffer(), 119, x.lengthOf() * sizeof(double), + // 0, nullptr)); + ASSERT_TRUE(0 == ::memsetAsync(x.buffer(), 119, x.lengthOf() * sizeof(double), + 0, nullptr)); } TEST_F(NativeOpsTests, PullRowsTest_1) { - NDArray x('c', {5, 1}, {0,1,2,3,4}); - NDArray z('c', {4, 1}, sd::DataType::DOUBLE); - NDArray exp('c', {4, 1}, {0,2,3,4}); + NDArray x('c', {5, 1}, {0, 1, 2, 3, 4}); + NDArray z('c', {4, 1}, sd::DataType::DOUBLE); + NDArray exp('c', {4, 1}, {0, 2, 3, 4}); - Nd4jLong indexes[] = {0,2,3,4}; - PointersManager pm(LaunchContext::defaultContext(), "NativeOpsTests::pullRows"); - auto pidx = reinterpret_cast(pm.replicatePointer(indexes, 4 * sizeof(Nd4jLong))); + Nd4jLong indexes[] = {0, 2, 3, 4}; + PointersManager pm(LaunchContext::defaultContext(), + "NativeOpsTests::pullRows"); + auto pidx = reinterpret_cast( + pm.replicatePointer(indexes, 4 * sizeof(Nd4jLong))); - std::vector dims = {1}; + std::vector dims = {1}; - auto xTadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(x.shapeInfo(), dims); - auto zTadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(z.shapeInfo(), dims); + auto xTadPack = sd::ConstantTadHelper::getInstance().tadForDimensions( + x.shapeInfo(), dims); + auto zTadPack = sd::ConstantTadHelper::getInstance().tadForDimensions( + z.shapeInfo(), dims); - Nd4jPointer nativeStart[2]; + Nd4jPointer nativeStart[2]; #ifdef __CUDABLAS__ - nativeStart[1] = (x.getContext()->getCudaStream()); + nativeStart[1] = (x.getContext()->getCudaStream()); #endif - OpaqueDataBuffer xBuf(x.dataBuffer()); - OpaqueDataBuffer zBuf(z.dataBuffer()); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer zBuf(z.dataBuffer()); - pullRows(nativeStart, &xBuf, x.shapeInfo(), x.specialShapeInfo(), - &zBuf, z.shapeInfo(), z.specialShapeInfo(), - 4, pidx, - xTadPack.platformShapeInfo(), xTadPack.platformOffsets(), - zTadPack.platformShapeInfo(), zTadPack.platformOffsets()); + pullRows(nativeStart, &xBuf, x.shapeInfo(), x.specialShapeInfo(), &zBuf, + z.shapeInfo(), z.specialShapeInfo(), 4, pidx, + xTadPack.platformShapeInfo(), xTadPack.platformOffsets(), + zTadPack.platformShapeInfo(), zTadPack.platformOffsets()); - ASSERT_TRUE(z.equalsTo(exp)); - pm.synchronize(); + ASSERT_TRUE(z.equalsTo(exp)); + pm.synchronize(); } TEST_F(NativeOpsTests, TadPackTest_1) { - int dimension[] = {1}; - int const dimensionLength = 1; - auto x = NDArrayFactory::create('c', {2,3,4}); - sd::TadPack* pack = ::tadOnlyShapeInfo(x.shapeInfo(), - dimension, - dimensionLength); - ASSERT_TRUE(pack != nullptr); - delete pack; + int dimension[] = {1}; + int const dimensionLength = 1; + auto x = NDArrayFactory::create('c', {2, 3, 4}); + sd::TadPack *pack = + ::tadOnlyShapeInfo(x.shapeInfo(), dimension, dimensionLength); + ASSERT_TRUE(pack != nullptr); + delete pack; } TEST_F(NativeOpsTests, AverageTest_1) { - auto x = NDArrayFactory::create('c', {5, 5}); - auto y = NDArrayFactory::create('c', {5, 5}); - auto exp = NDArrayFactory::create('c', {5,5}); - auto z = NDArrayFactory::create('c', {5,5}); + auto x = NDArrayFactory::create('c', {5, 5}); + auto y = NDArrayFactory::create('c', {5, 5}); + auto exp = NDArrayFactory::create('c', {5, 5}); + auto z = NDArrayFactory::create('c', {5, 5}); #ifdef __CUDABLAS__ - return; + return; #endif - x.linspace(1); - exp.linspace(1); - Nd4jPointer xList[] = {x.buffer(), x.buffer()}; - Nd4jPointer dxList[] = {x.specialBuffer(), x.specialBuffer()}; - ::average(nullptr, - xList, x.shapeInfo(), - dxList, x.specialShapeInfo(), - z.buffer(), z.shapeInfo(), - z.specialBuffer(), z.specialShapeInfo(), - 2, - x.lengthOf(), - true); -// z.printIndexedBuffer("RES"); - ASSERT_TRUE(z.equalsTo(exp)); + x.linspace(1); + exp.linspace(1); + Nd4jPointer xList[] = {x.buffer(), x.buffer()}; + Nd4jPointer dxList[] = {x.specialBuffer(), x.specialBuffer()}; + ::average(nullptr, xList, x.shapeInfo(), dxList, x.specialShapeInfo(), + z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), + 2, x.lengthOf(), true); + // z.printIndexedBuffer("RES"); + ASSERT_TRUE(z.equalsTo(exp)); } TEST_F(NativeOpsTests, AccumulateTest_1) { - auto x = NDArrayFactory::create('c', {5, 5}); - auto y = NDArrayFactory::create('c', {5, 5}); - auto exp = NDArrayFactory::create('c', {5,5}); - auto z = NDArrayFactory::create('c', {5,5}); + auto x = NDArrayFactory::create('c', {5, 5}); + auto y = NDArrayFactory::create('c', {5, 5}); + auto exp = NDArrayFactory::create('c', {5, 5}); + auto z = NDArrayFactory::create('c', {5, 5}); #ifdef __CUDABLAS__ - return; + return; #endif - x.linspace(1); - exp.linspace(2,2); - Nd4jPointer xList[] = {x.buffer(), x.buffer()}; - Nd4jPointer dxList[] = {x.specialBuffer(), x.specialBuffer()}; - ::accumulate(nullptr, - xList, x.shapeInfo(), - dxList, x.specialShapeInfo(), - z.buffer(), z.shapeInfo(), - z.specialBuffer(), z.specialShapeInfo(), - 2, - x.lengthOf()); -// z.printIndexedBuffer("RES"); - ASSERT_TRUE(z.equalsTo(exp)); + x.linspace(1); + exp.linspace(2, 2); + Nd4jPointer xList[] = {x.buffer(), x.buffer()}; + Nd4jPointer dxList[] = {x.specialBuffer(), x.specialBuffer()}; + ::accumulate(nullptr, xList, x.shapeInfo(), dxList, x.specialShapeInfo(), + z.buffer(), z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo(), 2, x.lengthOf()); + // z.printIndexedBuffer("RES"); + ASSERT_TRUE(z.equalsTo(exp)); } TEST_F(NativeOpsTests, P2PTest_1) { - ::enableP2P(true); - ::checkP2P(); - ::isP2PAvailable(); + ::enableP2P(true); + ::checkP2P(); + ::isP2PAvailable(); } TEST_F(NativeOpsTests, ShuffleTest_1) { - auto x = NDArrayFactory::create('c', {5, 5}); - auto y = NDArrayFactory::create('c', {5, 5}); - auto exp = NDArrayFactory::create('c', {5,5}); - auto z = NDArrayFactory::create('c', {5,5}); + auto x = NDArrayFactory::create('c', {5, 5}); + auto y = NDArrayFactory::create('c', {5, 5}); + auto exp = NDArrayFactory::create('c', {5, 5}); + auto z = NDArrayFactory::create('c', {5, 5}); #ifdef __CUDABLAS__ - return; + return; #endif - x.linspace(1); - y.linspace(34); - exp.linspace(2,2); - Nd4jPointer xList[] = {x.buffer(), x.buffer()}; - Nd4jPointer dxList[] = {x.specialBuffer(), y.specialBuffer()}; - Nd4jPointer xShapeList[] = {(Nd4jPointer)x.shapeInfo(), (Nd4jPointer)y.shapeInfo()}; - Nd4jPointer dxShapeList[] = {(Nd4jPointer)x.specialShapeInfo(), (Nd4jPointer)y.specialShapeInfo()}; - Nd4jPointer zList[] = {z.buffer(), z.buffer()}; - Nd4jPointer dzList[] = {z.specialBuffer(), z.specialBuffer()}; - Nd4jPointer zShapeList[] = {(Nd4jPointer)z.shapeInfo(), (Nd4jPointer)z.shapeInfo()}; - Nd4jPointer dzShapeList[] = {(Nd4jPointer)z.specialShapeInfo(), (Nd4jPointer)z.specialShapeInfo()}; - int shuffleMap[] = {1, 0, 4, 3, 2}; - auto zTadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(x.shapeInfo(), {1}); - Nd4jPointer zListOffset[] = {(Nd4jPointer)zTadPack.platformOffsets(), (Nd4jPointer)zTadPack.platformOffsets()}; - Nd4jPointer zListTADs[] = {(Nd4jPointer)zTadPack.platformShapeInfo(), (Nd4jPointer)zTadPack.platformShapeInfo()}; - ::shuffle(nullptr, - xList, xShapeList, - dxList, dxShapeList, - zList, zShapeList, - dzList, dzShapeList, - 2, - shuffleMap, zListTADs, zListOffset); -// z.printIndexedBuffer("RES"); -// x.printIndexedBuffer("INPUT shuffled"); -// y.printIndexedBuffer("INPUT 2 shuffled"); -// ASSERT_TRUE(z.equalsTo(exp)); + x.linspace(1); + y.linspace(34); + exp.linspace(2, 2); + Nd4jPointer xList[] = {x.buffer(), x.buffer()}; + Nd4jPointer dxList[] = {x.specialBuffer(), y.specialBuffer()}; + Nd4jPointer xShapeList[] = {(Nd4jPointer)x.shapeInfo(), + (Nd4jPointer)y.shapeInfo()}; + Nd4jPointer dxShapeList[] = {(Nd4jPointer)x.specialShapeInfo(), + (Nd4jPointer)y.specialShapeInfo()}; + Nd4jPointer zList[] = {z.buffer(), z.buffer()}; + Nd4jPointer dzList[] = {z.specialBuffer(), z.specialBuffer()}; + Nd4jPointer zShapeList[] = {(Nd4jPointer)z.shapeInfo(), + (Nd4jPointer)z.shapeInfo()}; + Nd4jPointer dzShapeList[] = {(Nd4jPointer)z.specialShapeInfo(), + (Nd4jPointer)z.specialShapeInfo()}; + int shuffleMap[] = {1, 0, 4, 3, 2}; + auto zTadPack = sd::ConstantTadHelper::getInstance().tadForDimensions( + x.shapeInfo(), 1); + Nd4jPointer zListOffset[] = {(Nd4jPointer)zTadPack.platformOffsets(), + (Nd4jPointer)zTadPack.platformOffsets()}; + Nd4jPointer zListTADs[] = {(Nd4jPointer)zTadPack.platformShapeInfo(), + (Nd4jPointer)zTadPack.platformShapeInfo()}; + ::shuffle(nullptr, xList, xShapeList, dxList, dxShapeList, zList, zShapeList, + dzList, dzShapeList, 2, shuffleMap, zListTADs, zListOffset); + // z.printIndexedBuffer("RES"); + // x.printIndexedBuffer("INPUT shuffled"); + // y.printIndexedBuffer("INPUT 2 shuffled"); + // ASSERT_TRUE(z.equalsTo(exp)); } TEST_F(NativeOpsTests, ConvertTypesTest_1) { - auto x = NDArrayFactory::create('c', {5, 5}); + auto x = NDArrayFactory::create('c', {5, 5}); - auto exp = NDArrayFactory::create('c', {5, 5}); - auto z = NDArrayFactory::create('c', {5, 5}); + auto exp = NDArrayFactory::create('c', {5, 5}); + auto z = NDArrayFactory::create('c', {5, 5}); #ifdef __CUDABLAS__ - return; + return; #endif - x.linspace(2, 2); - exp.linspace(2, 2); - ::convertTypes(nullptr, ND4J_FLOAT32, x.buffer(), x.lengthOf(), ND4J_DOUBLE, z.buffer()); - ASSERT_TRUE(z.equalsTo(exp)); + x.linspace(2, 2); + exp.linspace(2, 2); + ::convertTypes(nullptr, ND4J_FLOAT32, x.buffer(), x.lengthOf(), ND4J_DOUBLE, + z.buffer()); + ASSERT_TRUE(z.equalsTo(exp)); } -//TEST_F(NativeOpsTests, Test_Aggregations_1) { +// TEST_F(NativeOpsTests, Test_Aggregations_1) { // NativeOps ops; // auto x = NDArrayFactory::create('c', {5,5}); // auto y = NDArrayFactory::create('c', {5,5}); // // -// ops.execAggregate(nullptr, 0, maxArgs, maxShapes, maxIntArrays, maxIntArraySize, maxIndexArguments, maxRealArguments, pointer.data(), sd::DataType::FLOAT32); -// void **arguments, -// int numArguments, -// Nd4jLong **shapeArguments, -// int numShapeArguments, -// int *indexArguments, -// int numIndexArguments, -// int **intArrays, -// int numIntArrays, -// void *realArguments, +// ops.execAggregate(nullptr, 0, maxArgs, maxShapes, maxIntArrays, +// maxIntArraySize, maxIndexArguments, maxRealArguments, pointer.data(), +// sd::DataType::FLOAT32); void **arguments, int numArguments, Nd4jLong +// **shapeArguments, int numShapeArguments, int *indexArguments, int +// numIndexArguments, int **intArrays, int numIntArrays, void *realArguments, // int numRealArguments, // sd::DataType dtype //} TEST_F(NativeOpsTests, RandomTest_1) { - auto z = NDArrayFactory::create('c', {100}); - Nd4jPointer extra[] = {nullptr, nullptr}; + auto z = NDArrayFactory::create('c', {100}); + Nd4jPointer extra[] = {nullptr, nullptr}; #ifdef __CUDABLAS__ - return; - extra[1] = z.getContext()->getCudaStream(); + return; + extra[1] = z.getContext()->getCudaStream(); #endif - graph::RandomGenerator rng(1023, 119); - double p = 0.5; - OpaqueDataBuffer zBuf(z.dataBuffer()); + graph::RandomGenerator rng(1023, 119); + double p = 0.5; + OpaqueDataBuffer zBuf(z.dataBuffer()); - ::execRandom(extra, random::BernoulliDistribution, &rng, &zBuf, z.shapeInfo(), z.specialShapeInfo(), &p); + ::execRandom(extra, random::BernoulliDistribution, &rng, &zBuf, z.shapeInfo(), + z.specialShapeInfo(), &p); } TEST_F(NativeOpsTests, RandomTest_2) { - auto x = NDArrayFactory::create('c', {100}); - auto z = NDArrayFactory::create('c', {100}); - Nd4jPointer extra[] = {nullptr, nullptr}; + auto x = NDArrayFactory::create('c', {100}); + auto z = NDArrayFactory::create('c', {100}); + Nd4jPointer extra[] = {nullptr, nullptr}; #ifdef __CUDABLAS__ - return; - extra[1] = z.getContext()->getCudaStream(); + return; + extra[1] = z.getContext()->getCudaStream(); #endif - x.linspace(0, 0.01); - graph::RandomGenerator rng(1023, 119); - double p = 0.5; - OpaqueDataBuffer xBuf(x.dataBuffer()); - OpaqueDataBuffer zBuf(z.dataBuffer()); + x.linspace(0, 0.01); + graph::RandomGenerator rng(1023, 119); + double p = 0.5; + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer zBuf(z.dataBuffer()); - ::execRandom2(extra, random::DropOut, &rng, &xBuf, x.shapeInfo(), x.specialShapeInfo(), &zBuf, z.shapeInfo(), z.specialShapeInfo(), &p); + ::execRandom2(extra, random::DropOut, &rng, &xBuf, x.shapeInfo(), + x.specialShapeInfo(), &zBuf, z.shapeInfo(), + z.specialShapeInfo(), &p); } TEST_F(NativeOpsTests, RandomTest_3) { - auto x = NDArrayFactory::create('c', {100}); - auto y = NDArrayFactory::create('c', {100}); - auto z = NDArrayFactory::create('c', {100}); - Nd4jPointer extra[] = {nullptr, nullptr}; + auto x = NDArrayFactory::create('c', {100}); + auto y = NDArrayFactory::create('c', {100}); + auto z = NDArrayFactory::create('c', {100}); + Nd4jPointer extra[] = {nullptr, nullptr}; #ifdef __CUDABLAS__ - return; - extra[1] = z.getContext()->getCudaStream(); + return; + extra[1] = z.getContext()->getCudaStream(); #endif - x.linspace(0, 0.01); - x.linspace(1, -0.01); - graph::RandomGenerator rng(1023, 119); - double p = 0.5; - OpaqueDataBuffer xBuf(x.dataBuffer()); - OpaqueDataBuffer yBuf(y.dataBuffer()); - OpaqueDataBuffer zBuf(z.dataBuffer()); + x.linspace(0, 0.01); + x.linspace(1, -0.01); + graph::RandomGenerator rng(1023, 119); + double p = 0.5; + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer yBuf(y.dataBuffer()); + OpaqueDataBuffer zBuf(z.dataBuffer()); - ::execRandom3(extra, random::ProbablisticMerge, &rng, &xBuf, x.shapeInfo(), x.specialShapeInfo(), &yBuf, - y.shapeInfo(), y.specialShapeInfo(), &zBuf, z.shapeInfo(), z.specialShapeInfo(), &p); + ::execRandom3(extra, random::ProbablisticMerge, &rng, &xBuf, x.shapeInfo(), + x.specialShapeInfo(), &yBuf, y.shapeInfo(), + y.specialShapeInfo(), &zBuf, z.shapeInfo(), + z.specialShapeInfo(), &p); } TEST_F(NativeOpsTests, RandomTest_4) { #ifdef __CUDABLAS__ - return ; + return; #endif - graph::RandomGenerator* rng = (graph::RandomGenerator*)::initRandom(nullptr, 1023, 0, nullptr); - ::refreshBuffer(nullptr, 1203L, rng); - ::reSeedBuffer(nullptr, 3113L, rng); - ::destroyRandom(rng); + graph::RandomGenerator *rng = + (graph::RandomGenerator *)::initRandom(nullptr, 1023, 0, nullptr); + ::refreshBuffer(nullptr, 1203L, rng); + ::reSeedBuffer(nullptr, 3113L, rng); + ::destroyRandom(rng); } TEST_F(NativeOpsTests, SortTest_1) { #ifdef __CUDABLAS__ - return ; + return; #endif - auto sortedVals = NDArrayFactory::create( - {10, 1, 5, 120, 34, 5, 78, 138, 3, 111, 331, 29, 91, 71, 73, 50, 56, 4}); - auto exp = NDArrayFactory::create({1, 3, 4, 5, 5, 10, 29, 34, 50, 56, 71, 73, 78, 91, 111, 120, 138, 331}); + auto sortedVals = NDArrayFactory::create( + {10, 1, 5, 120, 34, 5, 78, 138, 3, 111, 331, 29, 91, 71, 73, 50, 56, 4}); + auto exp = NDArrayFactory::create( + {1, 3, 4, 5, 5, 10, 29, 34, 50, 56, 71, 73, 78, 91, 111, 120, 138, 331}); - ::sort(nullptr, sortedVals.buffer(), sortedVals.shapeInfo(), sortedVals.specialBuffer(), - sortedVals.specialShapeInfo(), false); - ASSERT_TRUE(sortedVals.equalsTo(exp)); + ::sort(nullptr, sortedVals.buffer(), sortedVals.shapeInfo(), + sortedVals.specialBuffer(), sortedVals.specialShapeInfo(), false); + ASSERT_TRUE(sortedVals.equalsTo(exp)); } TEST_F(NativeOpsTests, SortTests_2) { - auto k = NDArrayFactory::create('c', {10}, {1, 3, 5, 9, 0, 2, 4, 6, 7, 8}); - auto v = NDArrayFactory::create('c', {10}, {1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5}); - - auto ek = NDArrayFactory::create('c', {10}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); - auto ev = NDArrayFactory::create('c', {10}, {0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5}); - Nd4jPointer extras[2]; + auto k = NDArrayFactory::create('c', {10}, + {1, 3, 5, 9, 0, 2, 4, 6, 7, 8}); + auto v = NDArrayFactory::create( + 'c', {10}, {1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5}); + + auto ek = NDArrayFactory::create('c', {10}, + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto ev = NDArrayFactory::create( + 'c', {10}, {0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5}); + Nd4jPointer extras[2]; #ifdef __CUDABLAS__ - extras[1] = LaunchContext::defaultContext()->getCudaStream(); + extras[1] = LaunchContext::defaultContext()->getCudaStream(); #endif -// OpaqueDataBuffer xBuf(x.dataBuffer()); -// OpaqueDataBuffer yBuf(y.dataBuffer()); -// OpaqueDataBuffer expBuf(exp.dataBuffer()); -// OpaqueDataBuffer dimBuf(exp.dataBuffer()); + // OpaqueDataBuffer xBuf(x.dataBuffer()); + // OpaqueDataBuffer yBuf(y.dataBuffer()); + // OpaqueDataBuffer expBuf(exp.dataBuffer()); + // OpaqueDataBuffer dimBuf(exp.dataBuffer()); - ::sortByKey(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), false); - k.tickWriteDevice(); - v.tickWriteDevice(); + ::sortByKey(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), + k.specialShapeInfo(), v.buffer(), v.shapeInfo(), + v.specialBuffer(), v.specialShapeInfo(), false); + k.tickWriteDevice(); + v.tickWriteDevice(); - ASSERT_EQ(ek, k); - ASSERT_EQ(ev, v); + ASSERT_EQ(ek, k); + ASSERT_EQ(ev, v); } TEST_F(NativeOpsTests, SortTest_3) { - auto k = NDArrayFactory::create('c', {10}, {1, 3, 5, 9, 0, 2, 4, 6, 7, 8}); - auto v = NDArrayFactory::create('c', {10}, {1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5}); + auto k = NDArrayFactory::create('c', {10}, + {1, 3, 5, 9, 0, 2, 4, 6, 7, 8}); + auto v = NDArrayFactory::create( + 'c', {10}, {1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5}); - auto ek = NDArrayFactory::create('c', {10}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); - auto ev = NDArrayFactory::create('c', {10}, {0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5}); + auto ek = NDArrayFactory::create('c', {10}, + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto ev = NDArrayFactory::create( + 'c', {10}, {0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5}); #ifdef __CUDABLAS__ - Nd4jPointer extras[2] = {nullptr, LaunchContext::defaultContext()->getCudaStream()}; + Nd4jPointer extras[2] = {nullptr, + LaunchContext::defaultContext()->getCudaStream()}; #else - Nd4jPointer extras[2]; + Nd4jPointer extras[2]; #endif - ::sortByKey(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), false); - k.tickWriteDevice(); - v.tickWriteDevice(); + ::sortByKey(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), + k.specialShapeInfo(), v.buffer(), v.shapeInfo(), + v.specialBuffer(), v.specialShapeInfo(), false); + k.tickWriteDevice(); + v.tickWriteDevice(); - ASSERT_EQ(ek, k); - ASSERT_EQ(ev, v); + ASSERT_EQ(ek, k); + ASSERT_EQ(ev, v); } TEST_F(NativeOpsTests, SortTest_4) { #ifdef __CUDABLAS__ - return ; + return; #endif - auto sortedVals = NDArrayFactory::create('c', {3, 6}, - { 10, 1, 5, 120, 34, 5, - 78, 138, 3, 111, 331, 29, - 91, 71, 73, 50, 56, 4}); - auto exp = NDArrayFactory::create('c', {3, 6}, {1, 5, 5, 10, 34, 120, 3, 29, 78, 111, 138, 331, 4, 50, 56, 71, 73, 91}); - - std::vector dims({1}); - auto packX = ConstantTadHelper::getInstance().tadForDimensions(sortedVals.shapeInfo(), {1}); - ::sortTad(nullptr, sortedVals.buffer(), sortedVals.shapeInfo(), sortedVals.specialBuffer(), - sortedVals.specialShapeInfo(), dims.data(), dims.size(), packX.platformShapeInfo(), packX.platformOffsets(), false); -// sortedVals.printBuffer("OUT"); -// exp.printIndexedBuffer("EXP"); - ASSERT_TRUE(sortedVals.equalsTo(exp)); + auto sortedVals = NDArrayFactory::create( + 'c', {3, 6}, + {10, 1, 5, 120, 34, 5, 78, 138, 3, 111, 331, 29, 91, 71, 73, 50, 56, 4}); + auto exp = NDArrayFactory::create( + 'c', {3, 6}, + {1, 5, 5, 10, 34, 120, 3, 29, 78, 111, 138, 331, 4, 50, 56, 71, 73, 91}); + + std::vector dims({1}); + auto packX = ConstantTadHelper::getInstance().tadForDimensions( + sortedVals.shapeInfo(), 1); + ::sortTad(nullptr, sortedVals.buffer(), sortedVals.shapeInfo(), + sortedVals.specialBuffer(), sortedVals.specialShapeInfo(), + dims.data(), dims.size(), packX.platformShapeInfo(), + packX.platformOffsets(), false); + // sortedVals.printBuffer("OUT"); + // exp.printIndexedBuffer("EXP"); + ASSERT_TRUE(sortedVals.equalsTo(exp)); } TEST_F(NativeOpsTests, SortTests_5) { - auto k = NDArrayFactory::create('c', {2, 10}, {1, 3, 5, 9, 0, 2, 4, 6, 7, 8, 1, 3, 5, 9, 0, 2, 4, 6, 7, 8}); - auto v = NDArrayFactory::create('c', {2, 10}, {1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5, 1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5}); - - auto ek = NDArrayFactory::create('c', {2, 10}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); - auto ev = NDArrayFactory::create('c', {2, 10}, {0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5, 0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5}); - - Nd4jPointer extras[2]; + auto k = NDArrayFactory::create( + 'c', {2, 10}, + {1, 3, 5, 9, 0, 2, 4, 6, 7, 8, 1, 3, 5, 9, 0, 2, 4, 6, 7, 8}); + auto v = NDArrayFactory::create( + 'c', {2, 10}, {1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5, + 1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5}); + + auto ek = NDArrayFactory::create( + 'c', {2, 10}, + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto ev = NDArrayFactory::create( + 'c', {2, 10}, {0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5, + 0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5}); + + Nd4jPointer extras[2]; #ifdef __CUDABLAS__ - extras[1] = LaunchContext::defaultContext()->getCudaStream(); + extras[1] = LaunchContext::defaultContext()->getCudaStream(); #endif - int axis = 1; + int axis = 1; - ::sortTadByKey(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), &axis, 1, false); - k.tickWriteDevice(); - v.tickWriteDevice(); + ::sortTadByKey(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), + k.specialShapeInfo(), v.buffer(), v.shapeInfo(), + v.specialBuffer(), v.specialShapeInfo(), &axis, 1, false); + k.tickWriteDevice(); + v.tickWriteDevice(); -// k.printIndexedBuffer("k"); -// v.printIndexedBuffer("v"); + // k.printIndexedBuffer("k"); + // v.printIndexedBuffer("v"); - ASSERT_EQ(ek, k); - ASSERT_EQ(ev, v); + ASSERT_EQ(ek, k); + ASSERT_EQ(ev, v); } TEST_F(NativeOpsTests, SortTests_6) { - auto k = NDArrayFactory::create('c', {2, 10}, {1, 3, 5, 9, 0, 2, 4, 6, 7, 8, 1, 3, 5, 9, 0, 2, 4, 6, 7, 8}); - auto v = NDArrayFactory::create('c', {2, 10}, {1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5, 1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5}); - - auto ek = NDArrayFactory::create('c', {2, 10}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); - auto ev = NDArrayFactory::create('c', {2, 10}, {0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5, 0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5}); - - Nd4jPointer extras[2]; + auto k = NDArrayFactory::create( + 'c', {2, 10}, + {1, 3, 5, 9, 0, 2, 4, 6, 7, 8, 1, 3, 5, 9, 0, 2, 4, 6, 7, 8}); + auto v = NDArrayFactory::create( + 'c', {2, 10}, {1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5, + 1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5}); + + auto ek = NDArrayFactory::create( + 'c', {2, 10}, + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto ev = NDArrayFactory::create( + 'c', {2, 10}, {0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5, + 0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5}); + + Nd4jPointer extras[2]; #ifdef __CUDABLAS__ - extras[1] = LaunchContext::defaultContext()->getCudaStream(); + extras[1] = LaunchContext::defaultContext()->getCudaStream(); #endif - int axis = 1; + int axis = 1; - ::sortTadByValue(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), &axis, 1, false); - k.tickWriteDevice(); - v.tickWriteDevice(); + ::sortTadByValue(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), + k.specialShapeInfo(), v.buffer(), v.shapeInfo(), + v.specialBuffer(), v.specialShapeInfo(), &axis, 1, false); + k.tickWriteDevice(); + v.tickWriteDevice(); - ASSERT_EQ(ek, k); - ASSERT_EQ(ev, v); + ASSERT_EQ(ek, k); + ASSERT_EQ(ev, v); } -//TEST_F(NativeOpsTests, MapTests_1) { +// TEST_F(NativeOpsTests, MapTests_1) { //#ifdef __CUDABLAS__ // return ; //#endif @@ -1479,131 +1455,153 @@ TEST_F(NativeOpsTests, SortTests_6) { //} TEST_F(NativeOpsTests, MapTests_1) { - //printf("Custom ops: %s\n", ::getAllCustomOps()); - //printf("All ops: %s\n", ::getAllOperations()); + // printf("Custom ops: %s\n", ::getAllCustomOps()); + // printf("All ops: %s\n", ::getAllOperations()); - ::getAllCustomOps(); - ::getAllOperations(); + ::getAllCustomOps(); + ::getAllOperations(); } TEST_F(NativeOpsTests, CustomOpTest_1) { - auto x = NDArrayFactory::create('c', {1, 6}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); - auto z = NDArrayFactory::create('c', {6}); - auto e = NDArrayFactory::create('c', {6}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); + auto x = NDArrayFactory::create('c', {1, 6}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); + auto z = NDArrayFactory::create('c', {6}); + auto e = + NDArrayFactory::create('c', {6}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); - sd::ops::squeeze op; + sd::ops::squeeze op; - Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) x.buffer(), x.specialBuffer()}; - Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) x.shapeInfo(), (Nd4jPointer)x.specialShapeInfo()}; + Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer)x.buffer(), x.specialBuffer()}; + Nd4jPointer ptrsInShapes[] = {(Nd4jPointer)x.shapeInfo(), + (Nd4jPointer)x.specialShapeInfo()}; - Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) z.buffer(), z.specialBuffer()}; - Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) z.shapeInfo(), (Nd4jPointer)z.specialShapeInfo()}; + Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer)z.buffer(), z.specialBuffer()}; + Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer)z.shapeInfo(), + (Nd4jPointer)z.specialShapeInfo()}; + auto status = ::execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, + ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, + 1, nullptr, 0, nullptr, 0, nullptr, 0, false); + ASSERT_EQ(Status::OK(), status); - auto status = ::execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false); - ASSERT_EQ(Status::OK(), status); - - ASSERT_EQ(e, z); + ASSERT_EQ(e, z); } TEST_F(NativeOpsTests, CustomOpTests_2) { - auto array0 = NDArrayFactory::create('c', {3, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); - auto array1 = NDArrayFactory::create('c', {3, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); - auto z = NDArrayFactory::create('c', {3, 2}); + auto array0 = NDArrayFactory::create('c', {3, 2}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); + auto array1 = NDArrayFactory::create('c', {3, 2}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); + auto z = NDArrayFactory::create('c', {3, 2}); - auto exp = NDArrayFactory::create('c', {3, 2}, {2.f, 4.f, 6.f, 8.f, 10.f, 12.f}); - Context ctx(1); + auto exp = NDArrayFactory::create('c', {3, 2}, + {2.f, 4.f, 6.f, 8.f, 10.f, 12.f}); + Context ctx(1); - NDArray::prepareSpecialUse({&z}, {&array0, &array1}); + NDArray::prepareSpecialUse({&z}, {&array0, &array1}); - ctx.setInputArray(0, array0.buffer(), array0.shapeInfo(), array0.specialBuffer(), array0.specialShapeInfo()); - ctx.setInputArray(1, array1.buffer(), array1.shapeInfo(), array1.specialBuffer(), array1.specialShapeInfo()); - ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo()); + ctx.setInputArray(0, array0.buffer(), array0.shapeInfo(), + array0.specialBuffer(), array0.specialShapeInfo()); + ctx.setInputArray(1, array1.buffer(), array1.shapeInfo(), + array1.specialBuffer(), array1.specialShapeInfo()); + ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo()); - ASSERT_EQ(2, ctx.width()); + ASSERT_EQ(2, ctx.width()); - sd::ops::add op; - ::execCustomOp2(nullptr, op.getOpHash(), &ctx); + sd::ops::add op; + ::execCustomOp2(nullptr, op.getOpHash(), &ctx); - NDArray::registerSpecialUse({&z}, {&array0, &array1}); + NDArray::registerSpecialUse({&z}, {&array0, &array1}); - ASSERT_EQ(exp, z); + ASSERT_EQ(exp, z); } TEST_F(NativeOpsTests, CalculateOutputShapeTests_1) { - auto input = NDArrayFactory::create('c', {1, 2, 5, 4}); - auto weights = NDArrayFactory::create('c', {2, 2, 2, 3}); - auto exp = NDArrayFactory::create('c', {1, 3, 5, 4}); + auto input = NDArrayFactory::create('c', {1, 2, 5, 4}); + auto weights = NDArrayFactory::create('c', {2, 2, 2, 3}); + auto exp = NDArrayFactory::create('c', {1, 3, 5, 4}); - sd::ops::conv2d op; + sd::ops::conv2d op; - std::vector tArgs({}); - std::vector iArgs({2, 2, 1, 1, 0, 0, 1, 1, 1}); + std::vector tArgs({}); + std::vector iArgs({2, 2, 1, 1, 0, 0, 1, 1, 1}); - Nd4jPointer ptrs[] = {(Nd4jPointer) input.shapeInfo(), (Nd4jPointer) weights.shapeInfo()}; + Nd4jPointer ptrs[] = {(Nd4jPointer)input.shapeInfo(), + (Nd4jPointer)weights.shapeInfo()}; #ifdef __CUDABLAS__ - return; + return; #endif - auto shapeList = ::calculateOutputShapes(nullptr, op.getOpHash(), ptrs, 2, tArgs.data(), tArgs.size(), iArgs.data(), iArgs.size()); + auto shapeList = + ::calculateOutputShapes(nullptr, op.getOpHash(), ptrs, 2, tArgs.data(), + tArgs.size(), iArgs.data(), iArgs.size()); - ASSERT_EQ(1, shapeList->size()); + ASSERT_EQ(1, shapeList->size()); - ASSERT_EQ(exp.rankOf(), shape::rank((Nd4jLong *)shapeList->at(0))); - ASSERT_EQ(exp.sizeAt(0), shape::shapeOf((Nd4jLong *)shapeList->at(0))[0]); - ASSERT_EQ(exp.sizeAt(1), shape::shapeOf((Nd4jLong *)shapeList->at(0))[1]); - ASSERT_EQ(exp.sizeAt(2), shape::shapeOf((Nd4jLong *)shapeList->at(0))[2]); - ASSERT_EQ(exp.sizeAt(3), shape::shapeOf((Nd4jLong *)shapeList->at(0))[3]); + ASSERT_EQ(exp.rankOf(), shape::rank((Nd4jLong *)shapeList->at(0))); + ASSERT_EQ(exp.sizeAt(0), shape::shapeOf((Nd4jLong *)shapeList->at(0))[0]); + ASSERT_EQ(exp.sizeAt(1), shape::shapeOf((Nd4jLong *)shapeList->at(0))[1]); + ASSERT_EQ(exp.sizeAt(2), shape::shapeOf((Nd4jLong *)shapeList->at(0))[2]); + ASSERT_EQ(exp.sizeAt(3), shape::shapeOf((Nd4jLong *)shapeList->at(0))[3]); - //int *ptr = (int *) shapeList[0]; - //delete[] ptr; - //delete shapeList; + // int *ptr = (int *) shapeList[0]; + // delete[] ptr; + // delete shapeList; - ::deleteShapeList((Nd4jPointer) shapeList); + ::deleteShapeList((Nd4jPointer)shapeList); } TEST_F(NativeOpsTests, CalculateOutputShapeTests_2) { - auto input = NDArrayFactory::create('c', {1, 2, 5, 4}); - auto weights = NDArrayFactory::create('c', {2, 2, 2, 3}); - auto exp = NDArrayFactory::create('c', {1, 3, 5, 4}); + auto input = NDArrayFactory::create('c', {1, 2, 5, 4}); + auto weights = NDArrayFactory::create('c', {2, 2, 2, 3}); + auto exp = NDArrayFactory::create('c', {1, 3, 5, 4}); - sd::ops::conv2d op; + sd::ops::conv2d op; - std::vector tArgs({}); - std::vector bArgsF({}); - std::vector iArgs({2, 2, 1, 1, 0, 0, 1, 1, 1}); + std::vector tArgs({}); + std::vector bArgsF({}); + std::vector iArgs({2, 2, 1, 1, 0, 0, 1, 1, 1}); - Nd4jPointer shapePtrs[] = {(Nd4jPointer) input.shapeInfo(), (Nd4jPointer) weights.shapeInfo()}; - Nd4jPointer dataPtrs[] = {(Nd4jPointer)input.buffer(), (Nd4jPointer)weights.buffer()}; + Nd4jPointer shapePtrs[] = {(Nd4jPointer)input.shapeInfo(), + (Nd4jPointer)weights.shapeInfo()}; + Nd4jPointer dataPtrs[] = {(Nd4jPointer)input.buffer(), + (Nd4jPointer)weights.buffer()}; #ifdef __CUDABLAS__ - return; + return; #endif - auto shapeList = ::calculateOutputShapes2(nullptr, op.getOpHash(), dataPtrs, shapePtrs, 2, const_cast(tArgs.data()), tArgs.size(), - const_cast(iArgs.data()), iArgs.size(), nullptr, bArgsF.size(), nullptr, 0); -// Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs - ASSERT_EQ(1, shapeList->size()); + auto shapeList = ::calculateOutputShapes2( + nullptr, op.getOpHash(), dataPtrs, shapePtrs, 2, + const_cast(tArgs.data()), tArgs.size(), + const_cast(iArgs.data()), iArgs.size(), nullptr, + bArgsF.size(), nullptr, 0); + // Nd4jPointer* extraPointers, Nd4jLong hash, + // Nd4jPointer* inputBuffers, Nd4jPointer* + // inputShapes, int numInputShapes, double* + // tArgs, int numTArgs, Nd4jLong *iArgs, int + // numIArgs, bool *bArgs, int numBArgs + ASSERT_EQ(1, shapeList->size()); - ASSERT_EQ(exp.rankOf(), shape::rank((Nd4jLong *)shapeList->at(0))); - ASSERT_EQ(exp.sizeAt(0), shape::shapeOf((Nd4jLong *)shapeList->at(0))[0]); - ASSERT_EQ(exp.sizeAt(1), shape::shapeOf((Nd4jLong *)shapeList->at(0))[1]); - ASSERT_EQ(exp.sizeAt(2), shape::shapeOf((Nd4jLong *)shapeList->at(0))[2]); - ASSERT_EQ(exp.sizeAt(3), shape::shapeOf((Nd4jLong *)shapeList->at(0))[3]); + ASSERT_EQ(exp.rankOf(), shape::rank((Nd4jLong *)shapeList->at(0))); + ASSERT_EQ(exp.sizeAt(0), shape::shapeOf((Nd4jLong *)shapeList->at(0))[0]); + ASSERT_EQ(exp.sizeAt(1), shape::shapeOf((Nd4jLong *)shapeList->at(0))[1]); + ASSERT_EQ(exp.sizeAt(2), shape::shapeOf((Nd4jLong *)shapeList->at(0))[2]); + ASSERT_EQ(exp.sizeAt(3), shape::shapeOf((Nd4jLong *)shapeList->at(0))[3]); - //int *ptr = (int *) shapeList[0]; - //delete[] ptr; - //delete shapeList; + // int *ptr = (int *) shapeList[0]; + // delete[] ptr; + // delete shapeList; - ::deleteShapeList((Nd4jPointer) shapeList); + ::deleteShapeList((Nd4jPointer)shapeList); } - TEST_F(NativeOpsTests, interop_databuffer_tests_1) { - auto idb = ::allocateDataBuffer(100, 10, false); - auto ptr = ::dbPrimaryBuffer(idb); - ::deleteDataBuffer(idb); + auto idb = ::allocateDataBuffer(100, 10, false); + auto ptr = ::dbPrimaryBuffer(idb); + ::deleteDataBuffer(idb); } -//Uncomment when needed only - massive calculations -//TEST_F(NativeOpsTests, BenchmarkTests_1) { +// Uncomment when needed only - massive calculations +// TEST_F(NativeOpsTests, BenchmarkTests_1) { // // printf("%s\n", ::runLightBenchmarkSuit(true)); // printf("%s\n", ::runFullBenchmarkSuit(true)); diff --git a/libnd4j/tests_cpu/layers_tests/NlpTests.cpp b/libnd4j/tests_cpu/layers_tests/NlpTests.cpp index 2325e24455ed..b8afdfcf01eb 100644 --- a/libnd4j/tests_cpu/layers_tests/NlpTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/NlpTests.cpp @@ -18,457 +18,476 @@ // @author raver119@gmail.com // -#include "testlayers.h" -#include #include -#include #include #include +#include +#include +#include "testlayers.h" using namespace sd; - class NlpTests : public testing::Test { -public: - - NlpTests() { - printf("\n"); - fflush(stdout); - } + public: + NlpTests() { + printf("\n"); + fflush(stdout); + } }; TEST_F(NlpTests, basic_sg_hs_test_1) { - auto exp0 = NDArrayFactory::create('c', {1, 10}); - auto exp1 = NDArrayFactory::create('c', {1, 10}); - - exp0.assign(0.01001f); - exp1.assign(0.020005f); - - auto target = NDArrayFactory::create(0); - auto ngStarter = NDArrayFactory::empty(); - auto indices = NDArrayFactory::create('c', {1}, {1}); - auto codes = NDArrayFactory::create('c', {1}); - auto syn0 = NDArrayFactory::create('c', {100, 10}); - auto syn1 = NDArrayFactory::create('c', {100, 10}); - auto syn1Neg = NDArrayFactory::empty(); - auto expTable = NDArrayFactory::create('c', {10000}); - auto negTable = NDArrayFactory::empty(); - auto neu1e = NDArrayFactory::create('c', {10}); - - syn0.assign(0.01); - syn1.assign(0.02); - expTable.assign(0.5); - - auto alpha = NDArrayFactory::create(0.001); - auto randomValue = NDArrayFactory::create(1L); - auto inferenceVector = NDArrayFactory::empty(); - - sd::ops::skipgram op; - auto result = op.evaluate({&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {}, {false}, {}, true); - ASSERT_EQ(Status::OK(), result.status()); - - auto row0 = syn0({0,1, 0,0}, true); - auto row1 = syn1({1,2, 0,0}, true); - - ASSERT_EQ(exp0, row0); - ASSERT_EQ(exp1, row1); - - + auto exp0 = NDArrayFactory::create('c', {1, 10}); + auto exp1 = NDArrayFactory::create('c', {1, 10}); + + exp0.assign(0.01001f); + exp1.assign(0.020005f); + + auto target = NDArrayFactory::create(0); + auto ngStarter = NDArrayFactory::empty(); + auto indices = NDArrayFactory::create('c', {1}, {1}); + auto codes = NDArrayFactory::create('c', {1}); + auto syn0 = NDArrayFactory::create('c', {100, 10}); + auto syn1 = NDArrayFactory::create('c', {100, 10}); + auto syn1Neg = NDArrayFactory::empty(); + auto expTable = NDArrayFactory::create('c', {10000}); + auto negTable = NDArrayFactory::empty(); + auto neu1e = NDArrayFactory::create('c', {10}); + + syn0.assign(0.01); + syn1.assign(0.02); + expTable.assign(0.5); + + auto alpha = NDArrayFactory::create(0.001); + auto randomValue = NDArrayFactory::create(1L); + auto inferenceVector = NDArrayFactory::empty(); + + sd::ops::skipgram op; + auto result = op.evaluate( + {&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, + &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, + {}, {}, {false}, {}, true); + ASSERT_EQ(Status::OK(), result.status()); + + auto row0 = syn0({0, 1, 0, 0}, true); + auto row1 = syn1({1, 2, 0, 0}, true); + + ASSERT_EQ(exp0, row0); + ASSERT_EQ(exp1, row1); } TEST_F(NlpTests, basic_sg_hs_test_2) { - auto exp0 = NDArrayFactory::create('c', {1, 10}); - auto exp1 = NDArrayFactory::create('c', {1, 10}); - auto exp2 = NDArrayFactory::create('c', {1, 10}); - - exp0.assign(0.01f); - exp1.assign(0.020005f); - exp2.assign(0.019995f); - - auto target = NDArrayFactory::create(0); - auto ngStarter = NDArrayFactory::empty(); - auto indices = NDArrayFactory::create('c', {2}, {1, 2}); - auto codes = NDArrayFactory::create('c', {2}, {0, 1}); - auto syn0 = NDArrayFactory::create('c', {100, 10}); - auto syn1 = NDArrayFactory::create('c', {100, 10}); - auto syn1Neg = NDArrayFactory::empty(); - auto expTable = NDArrayFactory::create('c', {10000}); - auto negTable = NDArrayFactory::empty(); - auto neu1e = NDArrayFactory::create('c', {10}); - - syn0.assign(0.01); - syn1.assign(0.02); - expTable.assign(0.5); - - auto alpha = NDArrayFactory::create(0.001); - auto randomValue = NDArrayFactory::create(1L); - auto inferenceVector = NDArrayFactory::empty(); - - sd::ops::skipgram op; - auto result = op.evaluate({&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {}, {false}, {}, true); - ASSERT_EQ(Status::OK(), result.status()); - - auto row0 = syn0({0,1, 0,0}, true); - auto row1 = syn1({1,2, 0,0}, true); - auto row2 = syn1({2,3, 0,0}, true); - - ASSERT_EQ(exp0, row0); - ASSERT_EQ(exp1, row1); - ASSERT_EQ(exp2, row2); - - + auto exp0 = NDArrayFactory::create('c', {1, 10}); + auto exp1 = NDArrayFactory::create('c', {1, 10}); + auto exp2 = NDArrayFactory::create('c', {1, 10}); + + exp0.assign(0.01f); + exp1.assign(0.020005f); + exp2.assign(0.019995f); + + auto target = NDArrayFactory::create(0); + auto ngStarter = NDArrayFactory::empty(); + auto indices = NDArrayFactory::create('c', {2}, {1, 2}); + auto codes = NDArrayFactory::create('c', {2}, {0, 1}); + auto syn0 = NDArrayFactory::create('c', {100, 10}); + auto syn1 = NDArrayFactory::create('c', {100, 10}); + auto syn1Neg = NDArrayFactory::empty(); + auto expTable = NDArrayFactory::create('c', {10000}); + auto negTable = NDArrayFactory::empty(); + auto neu1e = NDArrayFactory::create('c', {10}); + + syn0.assign(0.01); + syn1.assign(0.02); + expTable.assign(0.5); + + auto alpha = NDArrayFactory::create(0.001); + auto randomValue = NDArrayFactory::create(1L); + auto inferenceVector = NDArrayFactory::empty(); + + sd::ops::skipgram op; + auto result = op.evaluate( + {&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, + &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, + {}, {}, {false}, {}, true); + ASSERT_EQ(Status::OK(), result.status()); + + auto row0 = syn0({0, 1, 0, 0}, true); + auto row1 = syn1({1, 2, 0, 0}, true); + auto row2 = syn1({2, 3, 0, 0}, true); + + ASSERT_EQ(exp0, row0); + ASSERT_EQ(exp1, row1); + ASSERT_EQ(exp2, row2); } TEST_F(NlpTests, basic_sg_hs_test_3) { - auto exp0 = NDArrayFactory::create('c', {1, 10}); - auto exp1 = NDArrayFactory::create('c', {1, 10}); - auto exp2 = NDArrayFactory::create('c', {1, 10}); - - exp0.assign(0.01f); - exp1.assign(0.020005f); - exp2.assign(0.019995f); - - auto target = NDArrayFactory::create(0); - auto ngStarter = NDArrayFactory::empty(); - auto indices0 = NDArrayFactory::create('c', {3}, {1, 2, 3}); - auto indices1 = NDArrayFactory::create('c', {3}, {3, 1, 2}); - auto codes00 = NDArrayFactory::create('c', {3}, {0, 1, 1}); - auto codes01 = NDArrayFactory::create('c', {3}, {1, 0, 1}); - auto syn00 = NDArrayFactory::create('c', {100, 10}); - auto syn01 = NDArrayFactory::create('c', {100, 10}); - auto syn10 = NDArrayFactory::create('c', {100, 10}); - auto syn11 = NDArrayFactory::create('c', {100, 10}); - auto syn1Neg = NDArrayFactory::empty(); - auto expTable = NDArrayFactory::create('c', {10000}); - auto negTable = NDArrayFactory::empty(); - auto neu1e = NDArrayFactory::create('c', {10}); - - RandomGenerator rng(119L, 198L); - RandomLauncher::fillUniform(LaunchContext::defaultContext(), rng, &syn00, 0.0, 1.0); - RandomLauncher::fillUniform(LaunchContext::defaultContext(), rng, &syn10, 0.0, 1.0); - - syn01.assign(syn00); - syn11.assign(syn10); - expTable.assign(0.5); - - auto alpha = NDArrayFactory::create(0.001); - auto randomValue = NDArrayFactory::create(1L); - auto inferenceVector = NDArrayFactory::empty(); - - sd::ops::skipgram op; - auto result0 = op.evaluate({&target, &ngStarter, &indices0, &codes00, &syn00, &syn10, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {}, {false}, {}, true); - auto result1 = op.evaluate({&target, &ngStarter, &indices1, &codes01, &syn01, &syn11, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {}, {false}, {}, true); - ASSERT_EQ(Status::OK(), result0.status()); - - auto row00 = syn00({0,1, 0,0}, true); - auto row01 = syn01({0,1, 0,0}, true); - auto row1 = syn10({1,2, 0,0}, true); - auto row2 = syn11({1,2, 0,0}, true); - - ASSERT_EQ(row2, row1); - ASSERT_EQ(row00, row01); + auto exp0 = NDArrayFactory::create('c', {1, 10}); + auto exp1 = NDArrayFactory::create('c', {1, 10}); + auto exp2 = NDArrayFactory::create('c', {1, 10}); + + exp0.assign(0.01f); + exp1.assign(0.020005f); + exp2.assign(0.019995f); + + auto target = NDArrayFactory::create(0); + auto ngStarter = NDArrayFactory::empty(); + auto indices0 = NDArrayFactory::create('c', {3}, {1, 2, 3}); + auto indices1 = NDArrayFactory::create('c', {3}, {3, 1, 2}); + auto codes00 = NDArrayFactory::create('c', {3}, {0, 1, 1}); + auto codes01 = NDArrayFactory::create('c', {3}, {1, 0, 1}); + auto syn00 = NDArrayFactory::create('c', {100, 10}); + auto syn01 = NDArrayFactory::create('c', {100, 10}); + auto syn10 = NDArrayFactory::create('c', {100, 10}); + auto syn11 = NDArrayFactory::create('c', {100, 10}); + auto syn1Neg = NDArrayFactory::empty(); + auto expTable = NDArrayFactory::create('c', {10000}); + auto negTable = NDArrayFactory::empty(); + auto neu1e = NDArrayFactory::create('c', {10}); + + RandomGenerator rng(119L, 198L); + RandomLauncher::fillUniform(LaunchContext::defaultContext(), rng, &syn00, 0.0, + 1.0); + RandomLauncher::fillUniform(LaunchContext::defaultContext(), rng, &syn10, 0.0, + 1.0); + + syn01.assign(syn00); + syn11.assign(syn10); + expTable.assign(0.5); + + auto alpha = NDArrayFactory::create(0.001); + auto randomValue = NDArrayFactory::create(1L); + auto inferenceVector = NDArrayFactory::empty(); + + sd::ops::skipgram op; + auto result0 = op.evaluate( + {&target, &ngStarter, &indices0, &codes00, &syn00, &syn10, &syn1Neg, + &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, + {}, {}, {false}, {}, true); + auto result1 = op.evaluate( + {&target, &ngStarter, &indices1, &codes01, &syn01, &syn11, &syn1Neg, + &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, + {}, {}, {false}, {}, true); + ASSERT_EQ(Status::OK(), result0.status()); + + auto row00 = syn00({0, 1, 0, 0}, true); + auto row01 = syn01({0, 1, 0, 0}, true); + auto row1 = syn10({1, 2, 0, 0}, true); + auto row2 = syn11({1, 2, 0, 0}, true); + + ASSERT_EQ(row2, row1); + ASSERT_EQ(row00, row01); } TEST_F(NlpTests, basic_sg_hs_ns_test_1) { - auto target = NDArrayFactory::create(0); - auto ngStarter = NDArrayFactory::create(1); - auto indices = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); - auto codes = NDArrayFactory::create('c', {5}, {1, 1, 0, 1, 1}); - auto syn0 = NDArrayFactory::create('c', {100, 150}); - auto syn1 = NDArrayFactory::create('c', {100, 150}); - auto syn1Neg = NDArrayFactory::create('c', {100, 150}); - auto expTable = NDArrayFactory::create('c', {1000}); - auto negTable = NDArrayFactory::create('c', {1000}); - auto neu1e = NDArrayFactory::create('c', {10}); - negTable.linspace(1.0); - - auto alpha = NDArrayFactory::create(1.25); - auto randomValue = NDArrayFactory::create(119L); - auto inferenceVector = NDArrayFactory::empty(); - - sd::ops::skipgram op; - auto result = op.evaluate({&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {3}, {false}, {}, true); - ASSERT_EQ(Status::OK(), result.status()); - - + auto target = NDArrayFactory::create(0); + auto ngStarter = NDArrayFactory::create(1); + auto indices = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); + auto codes = NDArrayFactory::create('c', {5}, {1, 1, 0, 1, 1}); + auto syn0 = NDArrayFactory::create('c', {100, 150}); + auto syn1 = NDArrayFactory::create('c', {100, 150}); + auto syn1Neg = NDArrayFactory::create('c', {100, 150}); + auto expTable = NDArrayFactory::create('c', {1000}); + auto negTable = NDArrayFactory::create('c', {1000}); + auto neu1e = NDArrayFactory::create('c', {10}); + negTable.linspace(1.0); + + auto alpha = NDArrayFactory::create(1.25); + auto randomValue = NDArrayFactory::create(119L); + auto inferenceVector = NDArrayFactory::empty(); + + sd::ops::skipgram op; + auto result = op.evaluate( + {&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, + &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, + {}, {3}, {false}, {}, true); + ASSERT_EQ(Status::OK(), result.status()); } TEST_F(NlpTests, basic_sg_ns_test_1) { - auto exp0 = NDArrayFactory::create('c', {1, 10}); - - exp0.assign(0.01); - - auto target = NDArrayFactory::create(1); - auto ngStarter = NDArrayFactory::create(3); - auto indices = NDArrayFactory::empty(); - auto codes = NDArrayFactory::empty(); - auto syn0 = NDArrayFactory::create('c', {10, 10}); - auto syn1 = NDArrayFactory::empty(); - auto syn1Neg = NDArrayFactory::create('c', {10, 10}); - auto expTable = NDArrayFactory::create('c', {1000}); - auto negTable = NDArrayFactory::create('c', {1000}); - auto neu1e = NDArrayFactory::create('c', {10}); - - auto syn1Neg2 = NDArrayFactory::create('c', {10, 10}); - - syn0.assign(0.01); - syn1.assign(0.02); - syn1Neg.assign(0.03); - syn1Neg2.assign(0.03); - expTable.assign(0.5); - - auto alpha = NDArrayFactory::create(0.001); - auto randomValue = NDArrayFactory::create(2L); - auto inferenceVector = NDArrayFactory::empty(); - - sd::ops::skipgram op; - auto result = op.evaluate({&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {1, 1}, {false}, {}, true); - ASSERT_EQ(Status::OK(), result.status()); - - auto row0 = syn0({1,2, 0,0}, true); - - ASSERT_EQ(exp0, row0); - ASSERT_FALSE(syn1Neg2.equalsTo(syn1Neg, 1e-6)); - - + auto exp0 = NDArrayFactory::create('c', {1, 10}); + + exp0.assign(0.01); + + auto target = NDArrayFactory::create(1); + auto ngStarter = NDArrayFactory::create(3); + auto indices = NDArrayFactory::empty(); + auto codes = NDArrayFactory::empty(); + auto syn0 = NDArrayFactory::create('c', {10, 10}); + auto syn1 = NDArrayFactory::empty(); + auto syn1Neg = NDArrayFactory::create('c', {10, 10}); + auto expTable = NDArrayFactory::create('c', {1000}); + auto negTable = NDArrayFactory::create('c', {1000}); + auto neu1e = NDArrayFactory::create('c', {10}); + + auto syn1Neg2 = NDArrayFactory::create('c', {10, 10}); + + syn0.assign(0.01); + syn1.assign(0.02); + syn1Neg.assign(0.03); + syn1Neg2.assign(0.03); + expTable.assign(0.5); + + auto alpha = NDArrayFactory::create(0.001); + auto randomValue = NDArrayFactory::create(2L); + auto inferenceVector = NDArrayFactory::empty(); + + sd::ops::skipgram op; + auto result = op.evaluate( + {&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, + &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, + {}, {1, 1}, {false}, {}, true); + ASSERT_EQ(Status::OK(), result.status()); + + auto row0 = syn0({1, 2, 0, 0}, true); + + ASSERT_EQ(exp0, row0); + ASSERT_FALSE(syn1Neg2.equalsTo(syn1Neg, 1e-6)); } TEST_F(NlpTests, basic_cb_hs_test_1) { - auto exp0 = NDArrayFactory::create('c', {1, 10}); - auto exp1 = NDArrayFactory::create('c', {1, 10}); - auto exp2 = NDArrayFactory::create('c', {1, 10}); - - exp0.assign(0.0095f); - exp1.assign(0.019875f); - exp2.assign(0.02f); - - auto target = NDArrayFactory::create(0); - auto ngStarter = NDArrayFactory::empty(); - auto context = NDArrayFactory::create('c', {3}, {0, 1, 2}); - auto locked = NDArrayFactory::create('c', {3}); - auto indices = NDArrayFactory::create('c', {2}, {4, 5}); - auto codes = NDArrayFactory::create('c', {2}, {1, 1}); - auto syn0 = NDArrayFactory::create('c', {100, 10}); - auto syn1 = NDArrayFactory::create('c', {100, 10}); - auto syn1Neg = NDArrayFactory::empty(); - auto expTable = NDArrayFactory::create('c', {10000}); - auto negTable = NDArrayFactory::empty(); - auto numWords = NDArrayFactory::create('c', {1}, {1}); - - syn0.assign(0.01); - syn1.assign(0.02); - expTable.assign(0.5); - - auto alpha = NDArrayFactory::create(0.025); - auto randomValue = NDArrayFactory::create(2L); - auto inferenceVector = NDArrayFactory::empty(); - - sd::ops::cbow op; - auto result = op.evaluate({&target, &ngStarter, &context, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &numWords, &locked, &inferenceVector}, {}, {}, {true}, {}, true); - ASSERT_EQ(Status::OK(), result.status()); - - auto row_s0_0 = syn0({0,1, 0,0}, true); - auto row_s0_1 = syn0({1,2, 0,0}, true); - auto row_s0_2 = syn0({2,3, 0,0}, true); - - auto row_s1_4 = syn1({4,5, 0,0}, true); - auto row_s1_5 = syn1({5,6, 0,0}, true); - auto row_s1_6 = syn1({6,7, 0,0}, true); - - ASSERT_EQ(exp0, row_s0_0); - ASSERT_EQ(exp0, row_s0_1); - ASSERT_EQ(exp0, row_s0_2); - - ASSERT_EQ(exp1, row_s1_4); - ASSERT_EQ(exp1, row_s1_5); - ASSERT_EQ(exp2, row_s1_6); - - + auto exp0 = NDArrayFactory::create('c', {1, 10}); + auto exp1 = NDArrayFactory::create('c', {1, 10}); + auto exp2 = NDArrayFactory::create('c', {1, 10}); + + exp0.assign(0.0095f); + exp1.assign(0.019875f); + exp2.assign(0.02f); + + auto target = NDArrayFactory::create(0); + auto ngStarter = NDArrayFactory::empty(); + auto context = NDArrayFactory::create('c', {3}, {0, 1, 2}); + auto locked = NDArrayFactory::create('c', {3}); + auto indices = NDArrayFactory::create('c', {2}, {4, 5}); + auto codes = NDArrayFactory::create('c', {2}, {1, 1}); + auto syn0 = NDArrayFactory::create('c', {100, 10}); + auto syn1 = NDArrayFactory::create('c', {100, 10}); + auto syn1Neg = NDArrayFactory::empty(); + auto expTable = NDArrayFactory::create('c', {10000}); + auto negTable = NDArrayFactory::empty(); + auto numWords = NDArrayFactory::create('c', {1}, {1}); + + syn0.assign(0.01); + syn1.assign(0.02); + expTable.assign(0.5); + + auto alpha = NDArrayFactory::create(0.025); + auto randomValue = NDArrayFactory::create(2L); + auto inferenceVector = NDArrayFactory::empty(); + + sd::ops::cbow op; + auto result = + op.evaluate({&target, &ngStarter, &context, &indices, &codes, &syn0, + &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, + &numWords, &locked, &inferenceVector}, + {}, {}, {true}, {}, true); + ASSERT_EQ(Status::OK(), result.status()); + + auto row_s0_0 = syn0({0, 1, 0, 0}, true); + auto row_s0_1 = syn0({1, 2, 0, 0}, true); + auto row_s0_2 = syn0({2, 3, 0, 0}, true); + + auto row_s1_4 = syn1({4, 5, 0, 0}, true); + auto row_s1_5 = syn1({5, 6, 0, 0}, true); + auto row_s1_6 = syn1({6, 7, 0, 0}, true); + + ASSERT_EQ(exp0, row_s0_0); + ASSERT_EQ(exp0, row_s0_1); + ASSERT_EQ(exp0, row_s0_2); + + ASSERT_EQ(exp1, row_s1_4); + ASSERT_EQ(exp1, row_s1_5); + ASSERT_EQ(exp2, row_s1_6); } TEST_F(NlpTests, basic_cb_ns_test_1) { - auto exp0 = NDArrayFactory::create('c', {1, 10}); - auto exp1 = NDArrayFactory::create('c', {1, 10}); - auto exp2 = NDArrayFactory::create('c', {1, 10}); - - exp0.assign(0.0096265625); - exp1.assign(0.01); - exp2.assign(0.030125f); - - auto target = NDArrayFactory::create(0); - auto ngStarter = NDArrayFactory::create(6); - auto context = NDArrayFactory::create('c', {3}, {0, 1, 2}); - auto locked = NDArrayFactory::create('c', {3}); - auto indices = NDArrayFactory::empty(); - auto codes = NDArrayFactory::empty(); - auto syn0 = NDArrayFactory::create('c', {100, 10}); - auto syn1 = NDArrayFactory::create('c', {100, 10}); - auto syn1Neg = NDArrayFactory::create('c', {100, 10}); - auto expTable = NDArrayFactory::create('c', {10000}); - auto negTable = NDArrayFactory::create('c', {100000}); - auto numWords = NDArrayFactory::create('c', {2}, {1, 2}); - - syn0.assign(0.01); - syn1.assign(0.02); - syn1Neg.assign(0.03); - expTable.assign(0.5); - - auto alpha = NDArrayFactory::create(0.025); - auto randomValue = NDArrayFactory::create(2L); - auto inferenceVector = NDArrayFactory::empty(); - - sd::ops::cbow op; - auto result = op.evaluate({&target, &ngStarter, &context, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &numWords, &locked, &inferenceVector}, {}, {1, 2, 0}, {true}, {}, true); - ASSERT_EQ(Status::OK(), result.status()); - - auto row_s0_0 = syn0({0,1, 0,0}, true); - auto row_s0_1 = syn0({1,2, 0,0}, true); - auto row_s0_2 = syn0({2,3, 0,0}, true); - - auto row_s1_4 = syn1({4,5, 0,0}, true); - auto row_s1_5 = syn1({5,6, 0,0}, true); - auto row_s1_6 = syn1Neg({6,7, 0,0}, true); - - - ASSERT_EQ(exp0, row_s0_0); - ASSERT_EQ(exp0, row_s0_1); - ASSERT_EQ(exp0, row_s0_2); - ASSERT_EQ(exp2, row_s1_6); - - + auto exp0 = NDArrayFactory::create('c', {1, 10}); + auto exp1 = NDArrayFactory::create('c', {1, 10}); + auto exp2 = NDArrayFactory::create('c', {1, 10}); + + exp0.assign(0.0096265625); + exp1.assign(0.01); + exp2.assign(0.030125f); + + auto target = NDArrayFactory::create(0); + auto ngStarter = NDArrayFactory::create(6); + auto context = NDArrayFactory::create('c', {3}, {0, 1, 2}); + auto locked = NDArrayFactory::create('c', {3}); + auto indices = NDArrayFactory::empty(); + auto codes = NDArrayFactory::empty(); + auto syn0 = NDArrayFactory::create('c', {100, 10}); + auto syn1 = NDArrayFactory::create('c', {100, 10}); + auto syn1Neg = NDArrayFactory::create('c', {100, 10}); + auto expTable = NDArrayFactory::create('c', {10000}); + auto negTable = NDArrayFactory::create('c', {100000}); + auto numWords = NDArrayFactory::create('c', {2}, {1, 2}); + + syn0.assign(0.01); + syn1.assign(0.02); + syn1Neg.assign(0.03); + expTable.assign(0.5); + + auto alpha = NDArrayFactory::create(0.025); + auto randomValue = NDArrayFactory::create(2L); + auto inferenceVector = NDArrayFactory::empty(); + + sd::ops::cbow op; + auto result = + op.evaluate({&target, &ngStarter, &context, &indices, &codes, &syn0, + &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, + &numWords, &locked, &inferenceVector}, + {}, {1, 2, 0}, {true}, {}, true); + ASSERT_EQ(Status::OK(), result.status()); + + auto row_s0_0 = syn0({0, 1, 0, 0}, true); + auto row_s0_1 = syn0({1, 2, 0, 0}, true); + auto row_s0_2 = syn0({2, 3, 0, 0}, true); + + auto row_s1_4 = syn1({4, 5, 0, 0}, true); + auto row_s1_5 = syn1({5, 6, 0, 0}, true); + auto row_s1_6 = syn1Neg({6, 7, 0, 0}, true); + + ASSERT_EQ(exp0, row_s0_0); + ASSERT_EQ(exp0, row_s0_1); + ASSERT_EQ(exp0, row_s0_2); + ASSERT_EQ(exp2, row_s1_6); } TEST_F(NlpTests, test_sg_hs_batch_1) { - auto exp0 = NDArrayFactory::create('c', {1, 10}); - auto exp1 = NDArrayFactory::create('c', {1, 10}); - auto exp2 = NDArrayFactory::create('c', {1, 10}); - - exp0.assign(0.01f); - exp1.assign(0.020005f); - exp2.assign(0.019995f); - - auto target = NDArrayFactory::create('c', {2}, {0, 5}); - auto ngStarter = NDArrayFactory::empty(); - auto indices = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); - auto codes = NDArrayFactory::create('c', {2, 2}, {0, 1, 1, 1}); - auto syn0 = NDArrayFactory::create('c', {100, 10}); - auto syn1 = NDArrayFactory::create('c', {100, 10}); - auto syn1Neg = NDArrayFactory::empty(); - auto expTable = NDArrayFactory::create('c', {10000}); - auto negTable = NDArrayFactory::empty(); - - auto alpha = NDArrayFactory::create('c', {2}, {0.001, 0.024}); - auto randomValue = NDArrayFactory::create('c', {2}, {1L, 3L}); - auto inferenceVector = NDArrayFactory::empty(); - auto neu1e = NDArrayFactory::create('c', {2, 10}); - - syn0.assign(0.01); - syn1.assign(0.02); - expTable.assign(0.5); - - sd::ops::skipgram op; - auto result = op.evaluate({&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {}, {false, true}, {}, true); - ASSERT_EQ(Status::OK(), result.status()); - - auto row0 = syn0({0,1, 0,0}, true); - auto row1 = syn1({1,2, 0,0}, true); - auto row2 = syn1({2,3, 0,0}, true); - - ASSERT_TRUE(exp0.equalsTo(row0, 1e-6)); - ASSERT_TRUE(exp1.equalsTo(row1, 1e-6)); - ASSERT_TRUE(exp2.equalsTo(row2, 1e-6)); - - + auto exp0 = NDArrayFactory::create('c', {1, 10}); + auto exp1 = NDArrayFactory::create('c', {1, 10}); + auto exp2 = NDArrayFactory::create('c', {1, 10}); + + exp0.assign(0.01f); + exp1.assign(0.020005f); + exp2.assign(0.019995f); + + auto target = NDArrayFactory::create('c', {2}, {0, 5}); + auto ngStarter = NDArrayFactory::empty(); + auto indices = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); + auto codes = NDArrayFactory::create('c', {2, 2}, {0, 1, 1, 1}); + auto syn0 = NDArrayFactory::create('c', {100, 10}); + auto syn1 = NDArrayFactory::create('c', {100, 10}); + auto syn1Neg = NDArrayFactory::empty(); + auto expTable = NDArrayFactory::create('c', {10000}); + auto negTable = NDArrayFactory::empty(); + + auto alpha = NDArrayFactory::create('c', {2}, {0.001, 0.024}); + auto randomValue = NDArrayFactory::create('c', {2}, {1L, 3L}); + auto inferenceVector = NDArrayFactory::empty(); + auto neu1e = NDArrayFactory::create('c', {2, 10}); + + syn0.assign(0.01); + syn1.assign(0.02); + expTable.assign(0.5); + + sd::ops::skipgram op; + auto result = op.evaluate( + {&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, + &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, + {}, {}, {false, true}, {}, true); + ASSERT_EQ(Status::OK(), result.status()); + + auto row0 = syn0({0, 1, 0, 0}, true); + auto row1 = syn1({1, 2, 0, 0}, true); + auto row2 = syn1({2, 3, 0, 0}, true); + + ASSERT_TRUE(exp0.equalsTo(row0, 1e-6)); + ASSERT_TRUE(exp1.equalsTo(row1, 1e-6)); + ASSERT_TRUE(exp2.equalsTo(row2, 1e-6)); } TEST_F(NlpTests, test_sg_ns_batch_1) { - auto exp0 = NDArrayFactory::create('c', {1, 10}); - auto exp1 = NDArrayFactory::create('c', {1, 10}); - auto exp2 = NDArrayFactory::create('c', {1, 10}); - - exp0.assign(0.01f); - exp1.assign(0.020005f); - exp2.assign(0.019995f); - - auto target = NDArrayFactory::create('c', {2}, {0, 5}); - auto ngStarter = NDArrayFactory::create('c', {2}, {3, 8}); - auto indices = NDArrayFactory::empty(); - auto codes = NDArrayFactory::empty(); - auto syn0 = NDArrayFactory::create('c', {100, 10}); - auto syn1Neg = NDArrayFactory::create('c', {100, 10}); - auto syn1 = NDArrayFactory::empty(); - auto expTable = NDArrayFactory::create('c', {10000}); - auto negTable = NDArrayFactory::create('c', {100000}); - - auto alpha = NDArrayFactory::create('c', {2}, {0.001, 0.024}); - auto randomValue = NDArrayFactory::create('c', {2}, {1L, 3L}); - auto inferenceVector = NDArrayFactory::empty(); - auto neu1e = NDArrayFactory::create('c', {2, 10}); - - syn0.assign(0.01); - syn1.assign(0.02); - expTable.assign(0.5); - negTable.linspace(0.0); - - sd::ops::skipgram op; - auto result = op.evaluate({&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {4, 5}, {false, true}, {}, true); - ASSERT_EQ(Status::OK(), result.status()); - - + auto exp0 = NDArrayFactory::create('c', {1, 10}); + auto exp1 = NDArrayFactory::create('c', {1, 10}); + auto exp2 = NDArrayFactory::create('c', {1, 10}); + + exp0.assign(0.01f); + exp1.assign(0.020005f); + exp2.assign(0.019995f); + + auto target = NDArrayFactory::create('c', {2}, {0, 5}); + auto ngStarter = NDArrayFactory::create('c', {2}, {3, 8}); + auto indices = NDArrayFactory::empty(); + auto codes = NDArrayFactory::empty(); + auto syn0 = NDArrayFactory::create('c', {100, 10}); + auto syn1Neg = NDArrayFactory::create('c', {100, 10}); + auto syn1 = NDArrayFactory::empty(); + auto expTable = NDArrayFactory::create('c', {10000}); + auto negTable = NDArrayFactory::create('c', {100000}); + + auto alpha = NDArrayFactory::create('c', {2}, {0.001, 0.024}); + auto randomValue = NDArrayFactory::create('c', {2}, {1L, 3L}); + auto inferenceVector = NDArrayFactory::empty(); + auto neu1e = NDArrayFactory::create('c', {2, 10}); + + syn0.assign(0.01); + syn1.assign(0.02); + expTable.assign(0.5); + negTable.linspace(0.0); + + sd::ops::skipgram op; + auto result = op.evaluate( + {&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, + &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, + {}, {4, 5}, {false, true}, {}, true); + ASSERT_EQ(Status::OK(), result.status()); } TEST_F(NlpTests, test_cbow_hs_batch_1) { #ifdef __CUDABLAS__ - return ; + return; #endif - auto target = NDArrayFactory::create(0); - auto ngStarter = NDArrayFactory::empty(); - auto context = NDArrayFactory::create('c', {2, 3}, {0, 1, 2, 100, 101, 102}); - auto locked = NDArrayFactory::create('c', {2, 3}); - auto indices = NDArrayFactory::create('c', {2, 2}, {4, 5, 40, 50}); - auto codes = NDArrayFactory::create('c', {2, 2}, {1, 1, 1, 1}); - auto syn0 = NDArrayFactory::create('c', {244, 10}); - auto syn1 = NDArrayFactory::create('c', {244, 10}); - auto syn1Neg = NDArrayFactory::empty(); - auto expTable = NDArrayFactory::create('c', {10000}); - auto negTable = NDArrayFactory::empty(); - auto numWords = NDArrayFactory::create('c', {2}, {1, 2}); - - syn0.assign(0.01); - syn1.assign(0.02); - expTable.assign(0.5); - - auto alpha = NDArrayFactory::create('c', {2}, {0.025, 0.025}); - auto randomValue = NDArrayFactory::create('c', {2}, {2L, 2L}); - auto inferenceVector = NDArrayFactory::empty(); - - sd::ops::cbow op; - auto result = op.evaluate({&target, &ngStarter, &context, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &numWords, &locked, &inferenceVector}, {}, {}, {true}, {}, true); - ASSERT_EQ(Status::OK(), result.status()); - - auto exp0 = NDArrayFactory::create('c', {1, 10}); - auto exp1 = NDArrayFactory::create('c', {1, 10}); - auto exp2 = NDArrayFactory::create('c', {1, 10}); - - exp0.assign(0.0095f); - exp1.assign(0.019875f); - exp2.assign(0.02f); - - auto row_s0_0 = syn0({0,1, 0,0}, true); - auto row_s0_1 = syn0({1,2, 0,0}, true); - auto row_s0_2 = syn0({2,3, 0,0}, true); - - auto row_s1_4 = syn1({4,5, 0,0}, true); - auto row_s1_5 = syn1({5,6, 0,0}, true); - auto row_s1_6 = syn1({6,7, 0,0}, true); - - ASSERT_EQ(exp0, row_s0_0); - ASSERT_EQ(exp0, row_s0_1); - ASSERT_EQ(exp0, row_s0_2); - ASSERT_EQ(exp1, row_s1_4); - ASSERT_EQ(exp1, row_s1_5); - ASSERT_EQ(exp2, row_s1_6); - + auto target = NDArrayFactory::create(0); + auto ngStarter = NDArrayFactory::empty(); + auto context = + NDArrayFactory::create('c', {2, 3}, {0, 1, 2, 100, 101, 102}); + auto locked = NDArrayFactory::create('c', {2, 3}); + auto indices = NDArrayFactory::create('c', {2, 2}, {4, 5, 40, 50}); + auto codes = NDArrayFactory::create('c', {2, 2}, {1, 1, 1, 1}); + auto syn0 = NDArrayFactory::create('c', {244, 10}); + auto syn1 = NDArrayFactory::create('c', {244, 10}); + auto syn1Neg = NDArrayFactory::empty(); + auto expTable = NDArrayFactory::create('c', {10000}); + auto negTable = NDArrayFactory::empty(); + auto numWords = NDArrayFactory::create('c', {2}, {1, 2}); + + syn0.assign(0.01); + syn1.assign(0.02); + expTable.assign(0.5); + + auto alpha = NDArrayFactory::create('c', {2}, {0.025, 0.025}); + auto randomValue = NDArrayFactory::create('c', {2}, {2L, 2L}); + auto inferenceVector = NDArrayFactory::empty(); + + sd::ops::cbow op; + auto result = + op.evaluate({&target, &ngStarter, &context, &indices, &codes, &syn0, + &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, + &numWords, &locked, &inferenceVector}, + {}, {}, {true}, {}, true); + ASSERT_EQ(Status::OK(), result.status()); + + auto exp0 = NDArrayFactory::create('c', {1, 10}); + auto exp1 = NDArrayFactory::create('c', {1, 10}); + auto exp2 = NDArrayFactory::create('c', {1, 10}); + + exp0.assign(0.0095f); + exp1.assign(0.019875f); + exp2.assign(0.02f); + + auto row_s0_0 = syn0({0, 1, 0, 0}, true); + auto row_s0_1 = syn0({1, 2, 0, 0}, true); + auto row_s0_2 = syn0({2, 3, 0, 0}, true); + + auto row_s1_4 = syn1({4, 5, 0, 0}, true); + auto row_s1_5 = syn1({5, 6, 0, 0}, true); + auto row_s1_6 = syn1({6, 7, 0, 0}, true); + + ASSERT_EQ(exp0, row_s0_0); + ASSERT_EQ(exp0, row_s0_1); + ASSERT_EQ(exp0, row_s0_2); + ASSERT_EQ(exp1, row_s1_4); + ASSERT_EQ(exp1, row_s1_5); + ASSERT_EQ(exp2, row_s1_6); } diff --git a/libnd4j/tests_cpu/layers_tests/NodeTests.cpp b/libnd4j/tests_cpu/layers_tests/NodeTests.cpp index 8f4a8ae70bde..59d7db737528 100644 --- a/libnd4j/tests_cpu/layers_tests/NodeTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/NodeTests.cpp @@ -18,56 +18,93 @@ // Created by raver119 on 21.02.18. // -#include "testlayers.h" #include -#include #include +#include #include +#include "testlayers.h" + using namespace sd; using namespace sd::graph; class NodeTests : public testing::Test { -public: - + public: }; -TEST_F(NodeTests, Test_Dtype_Conversion_1) { - auto nodeA = new Node(OpType_TRANSFORM_SAME, 0, 1, {-1}, {2}); +TEST_F(NodeTests, test_copy_1) { + Node a(sd::ops::add(), "add"); - auto nd = nodeA->asT(); - auto nf = nd->asT(); + Node b(sd::ops::divide(), "div"); - ASSERT_EQ(nodeA->id(), nf->id()); - ASSERT_EQ(*nodeA->name(), *nf->name()); - ASSERT_EQ(nodeA->getOpClass(), nf->getOpClass()); - ASSERT_EQ(nodeA->opType(), nf->opType()); - ASSERT_EQ(nodeA->opNum(), nf->opNum()); + ASSERT_NE(a.name(), b.name()); + ASSERT_NE(a.customOp()->getOpName(), b.customOp()->getOpName()); + ASSERT_NE(a.contextPrototype().name(), b.contextPrototype().name()); - delete nodeA; - delete nd; - delete nf; + b = a; + + ASSERT_EQ(a.name(), b.name()); + ASSERT_EQ(a.customOp()->getOpName(), b.customOp()->getOpName()); + ASSERT_EQ(a.contextPrototype().name(), b.contextPrototype().name()); + + ASSERT_NE(&a.contextPrototype(), &b.contextPrototype()); +} + +static FORCEINLINE Node copy(const Node& node) { + return Node(node); } +TEST_F(NodeTests, test_copy_2) { + Node a(sd::ops::add(), "add"); + + Node b(sd::ops::divide(), "div"); + + ASSERT_NE(a.name(), b.name()); + ASSERT_NE(a.customOp()->getOpName(), b.customOp()->getOpName()); + ASSERT_NE(a.contextPrototype().name(), b.contextPrototype().name()); + + b = copy(a); + + ASSERT_EQ(a.name(), b.name()); + ASSERT_EQ(a.customOp()->getOpName(), b.customOp()->getOpName()); + ASSERT_EQ(a.contextPrototype().name(), b.contextPrototype().name()); + + ASSERT_NE(&a.contextPrototype(), &b.contextPrototype()); +} + +TEST_F(NodeTests, test_copy_3) { + Node a(sd::ops::add(), "add"); + + Node b = copy(a); + + ASSERT_EQ(a.name(), b.name()); + ASSERT_EQ(a.customOp()->getOpName(), b.customOp()->getOpName()); + ASSERT_EQ(a.contextPrototype().name(), b.contextPrototype().name()); + + ASSERT_NE(&a.contextPrototype(), &b.contextPrototype()); +} + +static MAP_IMPL modifier(MAP_IMPL map) { + if (map.size() < 5) { + map[3] = Node(sd::ops::multiply(), "mul"); + return map; + } else { + map.erase(1); + return map; + } +} -TEST_F(NodeTests, Test_Dtype_Conversion_2) { - sd::ops::add opA; +TEST_F(NodeTests, test_copy_4) { + MAP_IMPL map; + map[1] = Node(sd::ops::add(), "add"); + map[2] = Node(sd::ops::divide(), "div"); - //auto nodeA = new Node(OpType_CUSTOM, 0, 1, {-1}, {2}); - auto nodeA = new Node(&opA, 1, {-1}, {2}); - //nodeA->setCustomOp(&op); + auto other = modifier(map); - auto nd = nodeA->asT(); - auto nf = nd->asT(); + ASSERT_EQ(3, other.size()); - ASSERT_EQ(nodeA->id(), nf->id()); - ASSERT_EQ(*nodeA->name(), *nf->name()); -// ASSERT_EQ(nodeA->getOpClass(), nf->getOpClass()); - ASSERT_EQ(nodeA->opType(), nf->opType()); - ASSERT_EQ(nodeA->opNum(), nf->opNum()); - ASSERT_EQ(nodeA->getCustomOp()->getOpHash(), nf->getCustomOp()->getOpHash()); + ASSERT_EQ(map[1].name(), other[1].name()); + ASSERT_EQ(map[1].contextPrototype().name(), other[1].contextPrototype().name()); - delete nodeA; - delete nd; - delete nf; + ASSERT_NE(&map[1].contextPrototype(), &other[1].contextPrototype()); } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/OmpLaunchHelperTests.cpp b/libnd4j/tests_cpu/layers_tests/OmpLaunchHelperTests.cpp index af327d653c8b..0b10246fbb17 100644 --- a/libnd4j/tests_cpu/layers_tests/OmpLaunchHelperTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/OmpLaunchHelperTests.cpp @@ -18,106 +18,107 @@ // Created by raver119 on 30.06.18. // -#include "testlayers.h" #include #include +#include "testlayers.h" using namespace sd; using namespace sd::graph; class OmpLaunchHelperTests : public testing::Test { -private: - int ewt = 0; -public: - OmpLaunchHelperTests() { - this->ewt = Environment::getInstance().elementwiseThreshold(); - Environment::getInstance().setElementwiseThreshold(1000); - }; - - ~OmpLaunchHelperTests() { - Environment::getInstance().setElementwiseThreshold(this->ewt); - } + private: + int ewt = 0; + + public: + OmpLaunchHelperTests() { + this->ewt = Environment::getInstance().elementwiseThreshold(); + Environment::getInstance().setElementwiseThreshold(1000); + }; + + ~OmpLaunchHelperTests() { + Environment::getInstance().setElementwiseThreshold(this->ewt); + } }; TEST_F(OmpLaunchHelperTests, Test_BetterSpan_1) { - auto span = OmpLaunchHelper::betterSpan(1000, 4); - ASSERT_EQ(250, span); + auto span = OmpLaunchHelper::betterSpan(1000, 4); + ASSERT_EQ(250, span); } TEST_F(OmpLaunchHelperTests, Test_BetterSpan_2) { - auto span = OmpLaunchHelper::betterSpan(1001, 4); - ASSERT_EQ(251, span); + auto span = OmpLaunchHelper::betterSpan(1001, 4); + ASSERT_EQ(251, span); } TEST_F(OmpLaunchHelperTests, Test_BetterSpan_3) { - auto span = OmpLaunchHelper::betterSpan(1002, 4); - ASSERT_EQ(251, span); + auto span = OmpLaunchHelper::betterSpan(1002, 4); + ASSERT_EQ(251, span); } TEST_F(OmpLaunchHelperTests, Test_BetterSpan_5) { - auto span = OmpLaunchHelper::betterSpan(1003, 4); - ASSERT_EQ(251, span); + auto span = OmpLaunchHelper::betterSpan(1003, 4); + ASSERT_EQ(251, span); } TEST_F(OmpLaunchHelperTests, Test_BetterSpan_6) { - auto span = OmpLaunchHelper::betterSpan(1004, 4); - ASSERT_EQ(251, span); + auto span = OmpLaunchHelper::betterSpan(1004, 4); + ASSERT_EQ(251, span); } - TEST_F(OmpLaunchHelperTests, Test_BetterThreads_1) { - auto n = OmpLaunchHelper::betterThreads(4000, 6); - ASSERT_EQ(4, n); + auto n = OmpLaunchHelper::betterThreads(4000, 6); + ASSERT_EQ(4, n); } TEST_F(OmpLaunchHelperTests, Test_BetterThreads_2) { - auto n = OmpLaunchHelper::betterThreads(12000, 6); - ASSERT_EQ(6, n); + auto n = OmpLaunchHelper::betterThreads(12000, 6); + ASSERT_EQ(6, n); } TEST_F(OmpLaunchHelperTests, Test_BetterThreads_3) { - auto n = OmpLaunchHelper::betterThreads(899, 6); - ASSERT_EQ(1, n); + auto n = OmpLaunchHelper::betterThreads(899, 6); + ASSERT_EQ(1, n); } TEST_F(OmpLaunchHelperTests, test_tad_threads_1) { - Nd4jLong numTads = 16; - Nd4jLong tadLength = 16; + Nd4jLong numTads = 16; + Nd4jLong tadLength = 16; -// nd4j_printf("TT: [%i]; ET: [%i];\n", Environment::getInstance().tadThreshold(), Environment::getInstance().elementwiseThreshold()); - ASSERT_EQ(1, OmpLaunchHelper::tadThreads(tadLength, numTads)); + // nd4j_printf("TT: [%i]; ET: [%i];\n", + // Environment::getInstance().tadThreshold(), + // Environment::getInstance().elementwiseThreshold()); + ASSERT_EQ(1, OmpLaunchHelper::tadThreads(tadLength, numTads)); } TEST_F(OmpLaunchHelperTests, test_tad_threads_2) { - if (omp_get_max_threads() <= 1) - return; + if (omp_get_max_threads() <= 1) return; - Nd4jLong numTads = 2; - Nd4jLong tadLength = Environment::getInstance().elementwiseThreshold(); + Nd4jLong numTads = 2; + Nd4jLong tadLength = Environment::getInstance().elementwiseThreshold(); - ASSERT_EQ(2, OmpLaunchHelper::tadThreads(tadLength, numTads)); + ASSERT_EQ(2, OmpLaunchHelper::tadThreads(tadLength, numTads)); } TEST_F(OmpLaunchHelperTests, test_tad_threads_3) { - Nd4jLong numTads = 2; - Nd4jLong tadLength = 128; + Nd4jLong numTads = 2; + Nd4jLong tadLength = 128; - ASSERT_EQ(1, OmpLaunchHelper::tadThreads(tadLength, numTads)); + ASSERT_EQ(1, OmpLaunchHelper::tadThreads(tadLength, numTads)); } TEST_F(OmpLaunchHelperTests, test_tad_threads_4) { - Nd4jLong numTads = 4; - Nd4jLong tadLength = 64; + Nd4jLong numTads = 4; + Nd4jLong tadLength = 64; - ASSERT_EQ(1, OmpLaunchHelper::tadThreads(tadLength, numTads)); + ASSERT_EQ(1, OmpLaunchHelper::tadThreads(tadLength, numTads)); } TEST_F(OmpLaunchHelperTests, test_tad_threads_5) { - auto exp = omp_get_max_threads(); + auto exp = omp_get_max_threads(); - Nd4jLong numTads = exp; - Nd4jLong tadLength = Environment::getInstance().elementwiseThreshold(); + Nd4jLong numTads = exp; + Nd4jLong tadLength = Environment::getInstance().elementwiseThreshold(); - ASSERT_EQ(exp, OmpLaunchHelper::tadThreads(tadLength, numTads)); + ASSERT_EQ(exp, OmpLaunchHelper::tadThreads(tadLength, numTads)); } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/OneOffTests.cpp b/libnd4j/tests_cpu/layers_tests/OneOffTests.cpp index e1cf4ec52663..4f630ffcc826 100644 --- a/libnd4j/tests_cpu/layers_tests/OneOffTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/OneOffTests.cpp @@ -18,371 +18,207 @@ // Created by raver119 on 11.10.2017. // -#include "testlayers.h" -#include -#include -#include -#include -#include +#include #include #include -#include +#include +#include +#include + +#include + +#include "testlayers.h" using namespace sd; using namespace sd::ops; class OneOffTests : public testing::Test { -public: - + public: }; TEST_F(OneOffTests, test_avg_pool_3d_1) { - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/avg_pooling3d.fb"); + auto graph = Graph::fromFlatBuffers("./resources/avg_pooling3d.fb"); - ASSERT_TRUE(graph != nullptr); + graph.printOut(); + graph.execute(); +} - // graph->printOut(); +TEST_F(OneOffTests, test_avg_pool_3d_2) { + auto graph = Graph::fromFlatBuffers("./resources/avg_pooling3d.fb"); - Nd4jStatus status = GraphExecutioner::execute(graph); - ASSERT_EQ(Status::OK(), status); - delete graph; + graph.execute(); } TEST_F(OneOffTests, test_non2d_0A_1) { - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/non2d_0A.fb"); + auto graph = Graph::fromFlatBuffers("./resources/non2d_0A.fb"); - ASSERT_TRUE(graph != nullptr); - - // graph->printOut(); - - Nd4jStatus status = GraphExecutioner::execute(graph); - ASSERT_EQ(Status::OK(), status); - delete graph; + graph.execute(); } -/* -TEST_F(OneOffTests, test_assert_scalar_float32_1) { - sd::ops::Assert op; - sd::ops::identity op1; - sd::ops::noop op2; - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/scalar_float32.fb"); - - ASSERT_TRUE(graph != nullptr); - - graph->printOut(); - - Nd4jStatus status = GraphExecutioner::execute(graph); - ASSERT_EQ(Status::OK(), status); - delete graph; -}*/ - TEST_F(OneOffTests, test_assert_scalar_float32_2) { - sd::ops::Assert op; - sd::ops::identity op1; - sd::ops::noop op2; - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/assertsomething.fb"); - - ASSERT_TRUE(graph != nullptr); + sd::ops::Assert op; + sd::ops::identity op1; + sd::ops::noop op2; + auto graph = Graph::fromFlatBuffers("./resources/assertsomething.fb"); - // graph->printOut(); + graph.printOut(); - Nd4jStatus status = GraphExecutioner::execute(graph); - ASSERT_EQ(Status::OK(), status); - delete graph; + // graph.execute(); } - TEST_F(OneOffTests, test_pad_1D_1) { - auto e = NDArrayFactory::create('c', {7}, {10.f,0.778786f, 0.801198f, 0.724375f, 0.230894f, 0.727141f,10.f}); - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/pad_1D.fb"); - - ASSERT_TRUE(graph != nullptr); - - // graph->printOut(); - - Nd4jStatus status = GraphExecutioner::execute(graph); - ASSERT_EQ(Status::OK(), status); - - ASSERT_TRUE(graph->getVariableSpace()->hasVariable(4)); - - auto z = graph->getVariableSpace()->getVariable(4)->getNDArray(); - ASSERT_TRUE(z != nullptr); - - // z->printIndexedBuffer("z"); - - ASSERT_EQ(e, *z); - delete graph; -} -/* -TEST_F(OneOffTests, test_scatter_nd_update_1) { - - auto e = NDArrayFactory::create('c', {10, 7}, {1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 0.20446908f, 0.37918627f, 0.99792874f, 0.71881700f, 0.18677747f, - 0.78299069f, 0.55216062f, 0.40746713f, 0.92128086f, 0.57195139f, 0.44686234f, 0.30861020f, 0.31026053f, 0.09293187f, - 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 0.95073712f, 0.45613325f, 0.95149803f, 0.88341522f, 0.54366302f, 0.50060666f, 0.39031255f, - 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, - 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}); - - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/scatter_nd_update.fb"); - ASSERT_TRUE(graph != nullptr); - - graph->printOut(); - - Nd4jStatus status = GraphExecutioner::execute(graph); - ASSERT_EQ(Status::OK(), status); - - ASSERT_TRUE(graph->getVariableSpace()->hasVariable(6)); - - auto z = graph->getVariableSpace()->getVariable(6)->getNDArray(); - ASSERT_TRUE(z != nullptr); - - z->printIndexedBuffer("z"); - - ASSERT_EQ(e, *z); - - delete graph; -} - */ - -TEST_F(OneOffTests, test_conv2d_nhwc_failed_1) { - auto e = NDArrayFactory::create('c', {1, 5, 5, 6}, {0.55744928f, 0.76827729f, 1.09401524f, 0.00000000f, 0.00000000f, 0.00000000f, 0.56373537f, 0.90029907f, 0.78997850f, 0.00000000f, 0.00000000f, 0.00000000f, 0.14252824f, 0.95961076f, 0.87750554f, 0.00000000f, 0.00000000f, 0.00000000f, 0.44874173f, 0.99537718f, 1.17154264f, 0.00000000f, 0.00000000f, 0.00000000f, 0.60377145f, 0.79939061f, 0.56031001f, 0.00000000f, 0.00000000f, 0.00000000f, 0.52975273f, 0.90678585f, 0.73763013f, 0.00000000f, 0.00000000f, 0.00000000f, 0.22146404f, 0.82499605f, 0.47222072f, 0.00000000f, 0.00000000f, 0.00000000f, 0.42772964f, 0.39793295f, 0.71436501f, 0.00000000f, 0.00000000f, 0.00000000f, 0.48836520f, 1.01658893f, 0.74419701f, 0.00000000f, 0.00000000f, 0.00000000f, 0.78984612f, 0.94083673f, 0.83841157f, 0.00000000f, 0.00000000f, 0.00000000f, 0.40448499f, 0.67732805f, 0.75499672f, 0.00000000f, 0.00000000f, 0.00000000f, 0.43675962f, 0.79476535f, 0.72976631f, 0.00000000f, 0.00000000f, 0.00000000f, 0.58808053f, 0.65222591f, 0.72552216f, 0.00000000f, 0.00000000f, 0.00000000f, 0.37445742f, 1.22581339f, 1.05341125f, 0.00000000f, 0.00000000f, 0.00000000f, 0.30095795f, 0.59941679f, 0.63323414f, 0.00000000f, 0.00000000f, 0.00000000f, 0.24199286f, 1.02546394f, 0.69537812f, 0.00000000f, 0.00000000f, 0.00000000f, 0.23628944f, 0.90791851f, 1.01209974f, 0.00000000f, 0.00000000f, 0.00000000f, 0.62740159f, 0.56518674f, 0.76692569f, 0.00000000f, 0.00000000f, 0.00000000f, 0.13327584f, 0.32628393f, 0.10280430f, 0.00000000f, 0.00000000f, 0.00000000f, 0.42691272f, 0.25625113f, 0.30524066f, 0.00000000f, 0.00000000f, 0.00000000f, 0.17797673f, 0.84179950f, 0.80061519f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00199084f, 0.51838887f, 0.43932241f, 0.00000000f, 0.00000000f, 0.00000000f, 0.16684581f, 0.50822425f, 0.48668745f, 0.00000000f, 0.00000000f, 0.00000000f, 0.16749343f, 0.93093169f, 0.86871749f, 0.00000000f, 0.00000000f, 0.00000000f, 0.17486368f, 0.44460732f, 0.44499981f, 0.00000000f, 0.00000000f, 0.00000000f}); - - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/channels_last_b1_k2_s1_d1_SAME_crelu.fb"); - ASSERT_TRUE(graph != nullptr); + auto e = NDArrayFactory::create( + 'c', {7}, + {10.f, 0.778786f, 0.801198f, 0.724375f, 0.230894f, 0.727141f, 10.f}); + auto graph = Graph::fromFlatBuffers("./resources/pad_1D.fb"); - // graph->printOut(); + graph.execute(); - Nd4jStatus status = GraphExecutioner::execute(graph); - ASSERT_EQ(Status::OK(), status); + ASSERT_TRUE(graph.variableSpace().hasVariable(4)); - ASSERT_TRUE(graph->getVariableSpace()->hasVariable(9)); + auto z = graph.variableSpace().getVariable(4)->getNDArray(); + ASSERT_TRUE(z != nullptr); - auto z = graph->getVariableSpace()->getVariable(9)->getNDArray(); - ASSERT_TRUE(z != nullptr); - - ASSERT_EQ(e, *z); - - delete graph; + ASSERT_EQ(e, *z); } TEST_F(OneOffTests, test_tensor_array_1) { - auto e = NDArrayFactory::create('c', {2, 3}, {0.77878559f, 0.80119777f, 0.72437465f, 0.23089433f, 0.72714126f, 0.18039072f}); - - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/tensor_array_close_sz1_float32_nodynamic_noname_noshape.fb"); - ASSERT_TRUE(graph != nullptr); + auto e = + NDArrayFactory::create('c', {2, 3}, + {0.77878559f, 0.80119777f, 0.72437465f, + 0.23089433f, 0.72714126f, 0.18039072f}); - // graph->printOut(); + auto graph = Graph::fromFlatBuffers( + "./resources/tensor_array_close_sz1_float32_nodynamic_noname_noshape.fb"); - Nd4jStatus status = GraphExecutioner::execute(graph); - ASSERT_EQ(Status::OK(), status); - ASSERT_TRUE(graph->getVariableSpace()->hasVariable(5)); + graph.execute(); - auto z = graph->getVariableSpace()->getVariable(5)->getNDArray(); - ASSERT_TRUE(z != nullptr); + ASSERT_TRUE(graph.variableSpace().hasVariable(5)); - ASSERT_EQ(e, *z); + auto z = graph.variableSpace().getVariable(5)->getNDArray(); + ASSERT_TRUE(z != nullptr); - delete graph; + ASSERT_EQ(e, *z); } TEST_F(OneOffTests, test_tensor_array_2) { - auto e = NDArrayFactory::create('c', {2, 3}, {0.77878559f, 0.80119777f, 0.72437465f, 0.23089433f, 0.72714126f, 0.18039072f}); + auto e = + NDArrayFactory::create('c', {2, 3}, + {0.77878559f, 0.80119777f, 0.72437465f, + 0.23089433f, 0.72714126f, 0.18039072f}); - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/tensor_array_split_sz1_float32_nodynamic_noname_noshape.fb"); - ASSERT_TRUE(graph != nullptr); + auto graph = Graph::fromFlatBuffers( + "./resources/tensor_array_split_sz1_float32_nodynamic_noname_noshape.fb"); - // graph->printOut(); + graph.execute(); - Nd4jStatus status = GraphExecutioner::execute(graph); - ASSERT_EQ(Status::OK(), status); - ASSERT_TRUE(graph->getVariableSpace()->hasVariable(6)); + ASSERT_TRUE(graph.variableSpace().hasVariable(6)); - auto z = graph->getVariableSpace()->getVariable(6)->getNDArray(); - ASSERT_TRUE(z != nullptr); + auto z = graph.variableSpace().getVariable(6)->getNDArray(); + ASSERT_TRUE(z != nullptr); - ASSERT_EQ(e, *z); - - delete graph; + ASSERT_EQ(e, *z); } TEST_F(OneOffTests, test_tensor_array_3) { - auto e = NDArrayFactory::create('c', {3, 2, 3}, {7, 2, 9, 4, 3, 3, 8, 7, 0, 0, 6, 8, 7, 9, 0, 1, 1, 4}); - - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/tensor_array_stack_sz3-1_int32_dynamic_name_shape.fb"); - ASSERT_TRUE(graph != nullptr); - - // graph->printOut(); + auto e = NDArrayFactory::create( + 'c', {3, 2, 3}, {7, 2, 9, 4, 3, 3, 8, 7, 0, 0, 6, 8, 7, 9, 0, 1, 1, 4}); + auto graph = Graph::fromFlatBuffers( + "./resources/tensor_array_stack_sz3-1_int32_dynamic_name_shape.fb"); - Nd4jStatus status = GraphExecutioner::execute(graph); - ASSERT_EQ(Status::OK(), status); - ASSERT_TRUE(graph->getVariableSpace()->hasVariable(15)); + graph.execute(); - auto z = graph->getVariableSpace()->getVariable(15)->getNDArray(); - ASSERT_TRUE(z != nullptr); + ASSERT_TRUE(graph.variableSpace().hasVariable(15)); - ASSERT_EQ(e, *z); + auto z = graph.variableSpace().getVariable(15)->getNDArray(); + ASSERT_TRUE(z != nullptr); - delete graph; + ASSERT_EQ(e, *z); } TEST_F(OneOffTests, test_tensor_array_4) { - auto e = NDArrayFactory::create('c', {2, 3}, {4, 3, 1, 1, 1, 0}); - - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/tensor_array_unstack_sz1_int64_nodynamic_noname_shape2-3.fb"); - ASSERT_TRUE(graph != nullptr); - - // graph->printOut(); + auto e = NDArrayFactory::create('c', {2, 3}, {4, 3, 1, 1, 1, 0}); + auto graph = Graph::fromFlatBuffers( + "./resources/" + "tensor_array_unstack_sz1_int64_nodynamic_noname_shape2-3.fb"); - Nd4jStatus status = GraphExecutioner::execute(graph); - ASSERT_EQ(Status::OK(), status); - ASSERT_TRUE(graph->getVariableSpace()->hasVariable(11)); + graph.execute(); - auto z = graph->getVariableSpace()->getVariable(11)->getNDArray(); - ASSERT_TRUE(z != nullptr); + ASSERT_TRUE(graph.variableSpace().hasVariable(11)); - ASSERT_EQ(e, *z); + auto z = graph.variableSpace().getVariable(11)->getNDArray(); + ASSERT_TRUE(z != nullptr); - delete graph; + ASSERT_EQ(e, *z); } TEST_F(OneOffTests, test_assert_4) { - auto e = NDArrayFactory::create('c', {2, 2}, {1, 1, 1, 1}); + auto e = NDArrayFactory::create('c', {2, 2}, {1, 1, 1, 1}); - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/assert_type_rank2_int64.fb"); - ASSERT_TRUE(graph != nullptr); + auto graph = Graph::fromFlatBuffers("./resources/assert_type_rank2_int64.fb"); - // graph->printOut(); + graph.execute(); + ASSERT_TRUE(graph.variableSpace().hasVariable(1)); - Nd4jStatus status = GraphExecutioner::execute(graph); - ASSERT_EQ(Status::OK(), status); - ASSERT_TRUE(graph->getVariableSpace()->hasVariable(1)); + auto z = graph.variableSpace().getVariable(1)->getNDArray(); + ASSERT_TRUE(z != nullptr); - auto z = graph->getVariableSpace()->getVariable(1)->getNDArray(); - ASSERT_TRUE(z != nullptr); - - ASSERT_EQ(e, *z); - - delete graph; + ASSERT_EQ(e, *z); } -// TEST_F(OneOffTests, test_cond_true_1) { -// auto e = NDArrayFactory::create('c', {5}, {1.f, 2.f, 3.f, 4.f, 5.f}); - -// auto graph = GraphExecutioner::importFromFlatBuffers("./resources/cond_true.fb"); -// ASSERT_TRUE(graph != nullptr); - -// graph->printOut(); - - -// Nd4jStatus status = GraphExecutioner::execute(graph); -// ASSERT_EQ(Status::OK(), status); -// ASSERT_TRUE(graph->getVariableSpace()->hasVariable(6)); - -// auto z = graph->getVariableSpace()->getVariable(6)->getNDArray(); -// ASSERT_TRUE(z != nullptr); - -// z->printIndexedBuffer("z buffer"); - -// ASSERT_EQ(e, *z); - -// delete graph; -// } - -/* -TEST_F(OneOffTests, test_cond_false_1) { - auto e = NDArrayFactory::create('c', {5}, {1.f, 1.f, 1.f, 1.f, 1.f}); - - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/cond_false.fb"); - ASSERT_TRUE(graph != nullptr); - - graph->printOut(); - - - Nd4jStatus status = GraphExecutioner::execute(graph); - ASSERT_EQ(Status::OK(), status); - ASSERT_TRUE(graph->getVariableSpace()->hasVariable(6)); - - auto z = graph->getVariableSpace()->getVariable(6)->getNDArray(); - ASSERT_TRUE(z != nullptr); - - z->printIndexedBuffer("z buffer"); - - ASSERT_EQ(e, *z); - - delete graph; -} -*/ - TEST_F(OneOffTests, test_identity_n_2) { - auto e = NDArrayFactory::create('c', {2, 3}, {0.77878559f, 0.80119777f, 0.72437465f, 0.23089433f, 0.72714126f, 0.18039072f}); - - sd::ops::identity_n op; - - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/identity_n_2.fb"); - ASSERT_TRUE(graph != nullptr); + auto e = + NDArrayFactory::create('c', {2, 3}, + {0.77878559f, 0.80119777f, 0.72437465f, + 0.23089433f, 0.72714126f, 0.18039072f}); - // graph->printOut(); + sd::ops::identity_n op; + auto graph = Graph::fromFlatBuffers("./resources/identity_n_2.fb"); - Nd4jStatus status = GraphExecutioner::execute(graph); - ASSERT_EQ(Status::OK(), status); - ASSERT_TRUE(graph->getVariableSpace()->hasVariable(1)); - ASSERT_TRUE(graph->getVariableSpace()->hasVariable(1, 1)); + graph.execute(); - auto z = graph->getVariableSpace()->getVariable(1)->getNDArray(); - ASSERT_TRUE(z != nullptr); + ASSERT_TRUE(graph.variableSpace().hasVariable(1)); + ASSERT_TRUE(graph.variableSpace().hasVariable(1, 1)); - ASSERT_EQ(e, *z); + auto z = graph.variableSpace().getVariable(1)->getNDArray(); + ASSERT_TRUE(z != nullptr); - delete graph; + ASSERT_EQ(e, *z); } TEST_F(OneOffTests, test_non2d_1) { - auto e = NDArrayFactory::create('c', {1, 1}, {5.42746449f}); - - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/non2d_1.fb"); - ASSERT_TRUE(graph != nullptr); - - // graph->printOut(); + auto e = NDArrayFactory::create('c', {1, 2}, {2.07706356f, 2.66380072f}); - Nd4jStatus status = GraphExecutioner::execute(graph); - ASSERT_EQ(Status::OK(), status); + auto graph = Graph::fromFlatBuffers("./resources/non2d_1.fb"); - ASSERT_TRUE(graph->getVariableSpace()->hasVariable(3)); + graph.execute(); - auto z = graph->getVariableSpace()->getVariable(3)->getNDArray(); - ASSERT_TRUE(z != nullptr); + ASSERT_TRUE(graph.variableSpace().hasVariable(6)); - ASSERT_EQ(e, *z); + auto z = graph.variableSpace().getVariable(6)->getNDArray(); + ASSERT_TRUE(z != nullptr); - - delete graph; + ASSERT_EQ(e, *z); } TEST_F(OneOffTests, test_reduce_all_1) { - auto e = NDArrayFactory::create('c', {1, 4}, {true, false, false, false}); - - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/reduce_all_rank2_d0_keep.fb"); - ASSERT_TRUE(graph != nullptr); - - // graph->printOut(); - - Nd4jStatus status = GraphExecutioner::execute(graph); - ASSERT_EQ(Status::OK(), status); - - ASSERT_TRUE(graph->getVariableSpace()->hasVariable(1)); + auto e = NDArrayFactory::create('c', {1, 4}, {true, false, false, false}); - ASSERT_TRUE(graph->getVariableSpace()->hasVariable(2)); - auto in = graph->getVariableSpace()->getVariable(2)->getNDArray(); + auto graph = Graph::fromFlatBuffers("./resources/reduce_all_rank2_d0_keep.fb"); + graph.execute(); - auto z = graph->getVariableSpace()->getVariable(1)->getNDArray(); - ASSERT_TRUE(z != nullptr); + ASSERT_TRUE(graph.variableSpace().hasVariable(1)); - ASSERT_EQ(e, *z); + ASSERT_TRUE(graph.variableSpace().hasVariable(2)); + auto in = graph.variableSpace().getVariable(2)->getNDArray(); + auto z = graph.variableSpace().getVariable(1)->getNDArray(); + ASSERT_TRUE(z != nullptr); - delete graph; + ASSERT_EQ(e, *z); } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/OpSequenceTests.cpp b/libnd4j/tests_cpu/layers_tests/OpSequenceTests.cpp new file mode 100644 index 000000000000..910e69c89b28 --- /dev/null +++ b/libnd4j/tests_cpu/layers_tests/OpSequenceTests.cpp @@ -0,0 +1,95 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include +#include +#include +#include +#include +#include + +#include + +#include "testlayers.h" + +using namespace sd; +using namespace sd::ops; +using namespace sd::graph; + +class OpSequenceTests : public testing::Test { + public: + OpSequenceTests() {} +}; + +TEST_F(OpSequenceTests, test_append_1) { + OpSequence sequenceA; + OpSequence sequenceB; + + ASSERT_EQ(0, sequenceA.length()); + + Context ctx1(1); + Context ctx2(2); + + sequenceA.append(Node(sd::ops::add(), "add"), ctx1); + sequenceB.append(Node(sd::ops::multiply(), "mul"), ctx2); + + ASSERT_EQ(1, sequenceA.length()); + + sequenceA.append(sequenceB); + + ASSERT_EQ(2, sequenceA.length()); +} + +TEST_F(OpSequenceTests, test_iterator_1) { + Graph graph; + OpSequence sequence; + + ASSERT_EQ(0, sequence.length()); + + Context ctx1(1); + Context ctx2(2); + + sequence.append(Node(ops::add(), "add"), ctx1); + sequence.append(Node(ops::divide(), "div"), ctx2); + + ASSERT_EQ(2, sequence.length()); + + int cnt = 1; + for (const auto &v : sequence) { + ASSERT_EQ(cnt++, v.protoContext().nodeId()); + } + + ASSERT_EQ(3, cnt); + + OptimizedGraph optimizedGraph; + ASSERT_EQ(0, optimizedGraph.layers()); + + optimizedGraph.append(sequence); + ASSERT_EQ(1, optimizedGraph.layers()); + + auto layer = optimizedGraph.layer(0); + + // we expect exactly 1 sequence in this layer + ASSERT_EQ(1, layer.width()); + + auto seq = layer[0]; + + ASSERT_EQ(2, seq.length()); +} diff --git a/libnd4j/tests_cpu/layers_tests/OpTrackerTests.cpp b/libnd4j/tests_cpu/layers_tests/OpTrackerTests.cpp index a14971ad5112..488e7d30381e 100644 --- a/libnd4j/tests_cpu/layers_tests/OpTrackerTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/OpTrackerTests.cpp @@ -17,53 +17,51 @@ // // Created by raver119 on 15.12.17. // -#include "testlayers.h" #include -#include #include #include #include +#include + +#include "testlayers.h" + using namespace sd; using namespace sd::ops; using namespace sd::graph; class OpTrackerTests : public testing::Test { -public: - int numIterations = 10; - int poolSize = 10; + public: + int numIterations = 10; + int poolSize = 10; - OpTrackerTests() { - printf("\n"); - fflush(stdout); - } + OpTrackerTests() {} }; TEST_F(OpTrackerTests, Test_Existence_1) { - sd::_loader loader; + sd::_loader loader; - // nd4j_printf("Groups: %i; Operations: %i\n", OpTracker::getInstance().totalGroups(), OpTracker::getInstance().totalOperations()); + // nd4j_printf("Groups: %i; Operations: %i\n", + // OpTracker::getInstance().totalGroups(), + // OpTracker::getInstance().totalOperations()); - ASSERT_TRUE(OpTracker::getInstance().totalGroups() > 0); - ASSERT_TRUE(OpTracker::getInstance().totalOperations() > 0); + ASSERT_TRUE(OpTracker::getInstance().totalGroups() > 0); + ASSERT_TRUE(OpTracker::getInstance().totalOperations() > 0); - OpTracker::getInstance().exportOperations(); + OpTracker::getInstance().exportOperations(); } TEST_F(OpTrackerTests, Test_Ops_List_1) { - sd::ops::less op; - auto vec = OpRegistrator::getInstance().getAllHashes(); + sd::ops::less op; + auto vec = OpRegistrator::getInstance().getAllHashes(); - // nd4j_printf("Total ops: %lld\n", vec.size()); - // nd4j_printf("Less hash: %lld\n", op.getOpHash()); + // nd4j_printf("Total ops: %lld\n", vec.size()); + // nd4j_printf("Less hash: %lld\n", op.getOpHash()); - for (const auto &v: vec) { - if (v == 5484196977525668316L) { - auto op = OpRegistrator::getInstance().getOperation(v); - // nd4j_printf("OpName: %s\n", op->getOpName()->c_str()); - } + for (const auto &v : vec) { + if (v == 5484196977525668316L) { + auto op = OpRegistrator::getInstance().getOperation(v); + // nd4j_printf("OpName: %s\n", op->getOpName()->c_str()); } + } } - - - diff --git a/libnd4j/tests_cpu/layers_tests/OpTupleTests.cpp b/libnd4j/tests_cpu/layers_tests/OpTupleTests.cpp index bec75f056dfc..6c947c001444 100644 --- a/libnd4j/tests_cpu/layers_tests/OpTupleTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/OpTupleTests.cpp @@ -18,42 +18,39 @@ // Created by raver119 on 11.10.2017. // -#include "testlayers.h" #include #include +#include "testlayers.h" + using namespace sd; using namespace sd::ops; class OpTupleTests : public testing::Test { - public: + public: }; TEST_F(OpTupleTests, DirectConstructorTest1) { - auto alpha = NDArrayFactory::create_('c', {1, 2}); - auto beta = NDArrayFactory::create_('c', {1, 2}); - OpTuple tuple("dummy", {alpha, beta}, {12.0f}, {1,2, 3}); - - ASSERT_EQ("dummy", tuple._opName); - ASSERT_EQ(2, tuple._inputs.size()); - ASSERT_EQ(0, tuple._outputs.size()); - ASSERT_EQ(1, tuple._tArgs.size()); - ASSERT_EQ(3, tuple._iArgs.size()); + auto alpha = NDArrayFactory::create_('c', {1, 2}); + auto beta = NDArrayFactory::create_('c', {1, 2}); + OpTuple tuple("dummy", {alpha, beta}, {12.0f}, {1, 2, 3}); + + ASSERT_EQ("dummy", tuple._opName); + ASSERT_EQ(2, tuple._inputs.size()); + ASSERT_EQ(0, tuple._outputs.size()); + ASSERT_EQ(1, tuple._tArgs.size()); + ASSERT_EQ(3, tuple._iArgs.size()); } TEST_F(OpTupleTests, BuilderTest1) { - auto alpha = NDArrayFactory::create_('c', {1, 2}); - auto beta = NDArrayFactory::create_('c', {1, 2}); - OpTuple tuple("dummy"); - tuple.addInput(alpha) - ->addInput(beta) - ->setTArgs({12.0f}) - ->setIArgs({1, 2, 3}); - - - ASSERT_EQ("dummy", tuple._opName); - ASSERT_EQ(2, tuple._inputs.size()); - ASSERT_EQ(0, tuple._outputs.size()); - ASSERT_EQ(1, tuple._tArgs.size()); - ASSERT_EQ(3, tuple._iArgs.size()); + auto alpha = NDArrayFactory::create_('c', {1, 2}); + auto beta = NDArrayFactory::create_('c', {1, 2}); + OpTuple tuple("dummy"); + tuple.addInput(alpha)->addInput(beta)->setTArgs({12.0f})->setIArgs({1, 2, 3}); + + ASSERT_EQ("dummy", tuple._opName); + ASSERT_EQ(2, tuple._inputs.size()); + ASSERT_EQ(0, tuple._outputs.size()); + ASSERT_EQ(1, tuple._tArgs.size()); + ASSERT_EQ(3, tuple._iArgs.size()); } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/PairwiseTests.cpp b/libnd4j/tests_cpu/layers_tests/PairwiseTests.cpp index 0d73b369bae3..9a973bedfcb1 100644 --- a/libnd4j/tests_cpu/layers_tests/PairwiseTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/PairwiseTests.cpp @@ -17,34 +17,31 @@ // // Created by agibsonccc on 1/17/17. // -#include "testinclude.h" #include +#include "testinclude.h" + class EqualsTest : public testing::Test { -public: - const Nd4jLong firstShapeBuffer[8] = {2,1,2,1,1,0,1,102}; - float data[2] = {1.0f, 7.0f}; - const Nd4jLong secondShapeBuffer[8] = {2,2,1,6,1,0,6,99}; - float dataSecond[12] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f}; - int opNum = 4; - float extraArgs[1] = {1e-6f}; - int dimension[1] = {2147483647}; - int dimensionLength = 1; + public: + const Nd4jLong firstShapeBuffer[8] = {2, 1, 2, 1, 1, 0, 1, 102}; + float data[2] = {1.0f, 7.0f}; + const Nd4jLong secondShapeBuffer[8] = {2, 2, 1, 6, 1, 0, 6, 99}; + float dataSecond[12] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, + 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f}; + int opNum = 4; + float extraArgs[1] = {1e-6f}; + int dimension[1] = {2147483647}; + int dimensionLength = 1; }; #ifndef __CUDABLAS__ -TEST_F(EqualsTest,Eps) { - auto val = sd::NDArrayFactory::create(0.0f); - functions::reduce3::Reduce3::execScalar(opNum, - data, - firstShapeBuffer, - extraArgs, - dataSecond, - secondShapeBuffer, - val.buffer(), - val.shapeInfo()); - ASSERT_TRUE(val.e(0) < 0.5); +TEST_F(EqualsTest, Eps) { + auto val = sd::NDArrayFactory::create(0.0f); + functions::reduce3::Reduce3::execScalar( + opNum, data, firstShapeBuffer, extraArgs, dataSecond, secondShapeBuffer, + val.buffer(), val.shapeInfo()); + ASSERT_TRUE(val.e(0) < 0.5); } #endif diff --git a/libnd4j/tests_cpu/layers_tests/ParityOpsTests.cpp b/libnd4j/tests_cpu/layers_tests/ParityOpsTests.cpp index 089b4a92f5db..f8dc57f087c2 100644 --- a/libnd4j/tests_cpu/layers_tests/ParityOpsTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/ParityOpsTests.cpp @@ -18,1681 +18,1825 @@ // Created by raver119 on 12.10.2017. // -#include "testlayers.h" #include #include +#include "testlayers.h" using namespace sd; using namespace sd::ops; class ParityOpsTests : public testing::Test { -public: - + public: }; - TEST_F(ParityOpsTests, TestZeroAs1) { - auto x = NDArrayFactory::create('c', {10, 10}); - x.assign(1.0); + auto x = NDArrayFactory::create('c', {10, 10}); + x.assign(1.0); - auto exp = NDArrayFactory::create('c', {10, 10}); - exp.assign(0.0f); + auto exp = NDArrayFactory::create('c', {10, 10}); + exp.assign(0.0f); - sd::ops::zeros_as op; + sd::ops::zeros_as op; - auto result = op.evaluate({&x}, {}, {}); - - auto z = result.at(0); - - ASSERT_TRUE(z->isSameShape(&x)); - ASSERT_TRUE(z->equalsTo(&exp)); + auto result = op.evaluate({&x}, {}, {}); + auto z = result.at(0); + ASSERT_TRUE(z.isSameShape(&x)); + ASSERT_TRUE(z.equalsTo(&exp)); } TEST_F(ParityOpsTests, TestMaximum1) { - auto x = NDArrayFactory::create('c', {10, 10}); - x.assign(1.0); - - auto y = NDArrayFactory::create('c', {10, 10}); - y.assign(2.0); + auto x = NDArrayFactory::create('c', {10, 10}); + x.assign(1.0); - sd::ops::maximum op; + auto y = NDArrayFactory::create('c', {10, 10}); + y.assign(2.0); - auto result = op.evaluate({&x, &y}, {}, {}); + sd::ops::maximum op; - auto z = result.at(0); - - ASSERT_TRUE(y.equalsTo(z)); + auto result = op.evaluate({&x, &y}, {}, {}); + auto z = result.at(0); + ASSERT_TRUE(y.equalsTo(z)); } - TEST_F(ParityOpsTests, TestMinimum1) { - auto x = NDArrayFactory::create('c', {10, 10}); - x.assign(1.0f); - - auto y = NDArrayFactory::create('c', {10, 10}); - y.assign(-2.0f); - + auto x = NDArrayFactory::create('c', {10, 10}); + x.assign(1.0f); - sd::ops::minimum op; + auto y = NDArrayFactory::create('c', {10, 10}); + y.assign(-2.0f); - auto result = op.evaluate({&x, &y}, {}, {}); + sd::ops::minimum op; - auto z = result.at(0); - - ASSERT_TRUE(y.equalsTo(z)); + auto result = op.evaluate({&x, &y}, {}, {}); + auto z = result.at(0); + ASSERT_TRUE(y.equalsTo(z)); } TEST_F(ParityOpsTests, TestTear1) { - auto input = NDArrayFactory::create('c', {10, 5}); - auto tads = input.allTensorsAlongDimension({1}); - for (int e = 0; e < tads.size(); e++) { - ASSERT_EQ(5, tads.at(e)->lengthOf()); - tads.at(e)->assign((float) e + 1); - } - - sd::ops::tear op; - - auto result = op.evaluate({&input}, {}, {1}); + auto input = NDArrayFactory::create('c', {10, 5}); + auto tads = input.allTensorsAlongDimension({1}); + for (int e = 0; e < tads.size(); e++) { + ASSERT_EQ(5, tads.at(e).lengthOf()); + tads.at(e).assign((float)e + 1); + } - ASSERT_EQ(10, result.size()); + sd::ops::tear op; - for (int e = 0; e < result.size(); e++) - ASSERT_TRUE(tads.at(e)->equalsTo(result.at(e))); + auto result = op.evaluate({&input}, {}, {1}); + ASSERT_EQ(10, result.size()); + for (int e = 0; e < result.size(); e++) + ASSERT_TRUE(tads.at(e).equalsTo(result.at(e))); } TEST_F(ParityOpsTests, TestUnstack1) { - auto input = NDArrayFactory::create('c', {10, 5}); - auto tads = input.allTensorsAlongDimension({1}); - for (int e = 0; e < tads.size(); e++) { - ASSERT_EQ(5, tads.at(e)->lengthOf()); - tads.at(e)->assign((float) e + 1); - } + auto input = NDArrayFactory::create('c', {10, 5}); + auto tads = input.allTensorsAlongDimension({1}); + for (int e = 0; e < tads.size(); e++) { + ASSERT_EQ(5, tads.at(e).lengthOf()); + tads.at(e).assign((float)e + 1); + } - sd::ops::unstack op; + sd::ops::unstack op; - auto result = op.evaluate({&input}, {}, {0}); - - ASSERT_EQ(10, result.size()); - - for (int e = 0; e < result.size(); e++) - ASSERT_TRUE(tads.at(e)->equalsTo(result.at(e))); + auto result = op.evaluate({&input}, {}, {0}); + ASSERT_EQ(10, result.size()); + for (int e = 0; e < result.size(); e++) + ASSERT_TRUE(tads.at(e).equalsTo(result.at(e))); } - - TEST_F(ParityOpsTests, TestUnstack2) { - auto input = NDArrayFactory::create('c', {5,2,6}); - auto tads = input.allTensorsAlongDimension({0,1}); - for (int e = 0; e < tads.size(); e++) { - ASSERT_EQ(10, tads.at(e)->lengthOf()); - tads.at(e)->assign((float) e + 1); - } + auto input = NDArrayFactory::create('c', {5, 2, 6}); + auto tads = input.allTensorsAlongDimension({0, 1}); + for (int e = 0; e < tads.size(); e++) { + ASSERT_EQ(10, tads.at(e).lengthOf()); + tads.at(e).assign((float)e + 1); + } - sd::ops::unstack op; + sd::ops::unstack op; - auto result = op.evaluate({&input}, {}, {2}); - - ASSERT_EQ(6, result.size()); - - for (int e = 0; e < result.size(); e++) - ASSERT_TRUE(tads.at(e)->equalsTo(result.at(e))); + auto result = op.evaluate({&input}, {}, {2}); + ASSERT_EQ(6, result.size()); + for (int e = 0; e < result.size(); e++) + ASSERT_TRUE(tads.at(e).equalsTo(result.at(e))); } TEST_F(ParityOpsTests, TestUnstack3) { - auto input = NDArrayFactory::create('c', {3,2,3}); - auto exp = NDArrayFactory::create('c', {3, 2}, {1.f, 4., 7., 10.f, 13.f, 16.f}); - input.linspace(1); - - sd::ops::unstack op; + auto input = NDArrayFactory::create('c', {3, 2, 3}); + auto exp = NDArrayFactory::create('c', {3, 2}, + {1.f, 4., 7., 10.f, 13.f, 16.f}); + input.linspace(1); - auto result = op.evaluate({&input}, {}, {2}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::unstack op; - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto result = op.evaluate({&input}, {}, {2}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(ParityOpsTests, TestUnstack4) { - auto input = NDArrayFactory::create('c', {3,2,3}); - auto exp = NDArrayFactory::create('c', {3, 3}, { 1, 2, 3, 7, 8, 9, 13, 14, 15.}); - input.linspace(1); - - sd::ops::unstack op; - - auto result = op.evaluate({&input}, {}, {1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto input = NDArrayFactory::create('c', {3, 2, 3}); + auto exp = NDArrayFactory::create('c', {3, 3}, + {1, 2, 3, 7, 8, 9, 13, 14, 15.}); + input.linspace(1); - auto z = result.at(0); + sd::ops::unstack op; - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto result = op.evaluate({&input}, {}, {1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(ParityOpsTests, TestUnstack5) { - auto input = NDArrayFactory::create('c', {3,2,3}); - auto exp = NDArrayFactory::create('c', {2, 3}, { 1, 2, 3, 4, 5, 6}); - input.linspace(1); + auto input = NDArrayFactory::create('c', {3, 2, 3}); + auto exp = NDArrayFactory::create('c', {2, 3}, {1, 2, 3, 4, 5, 6}); + input.linspace(1); - sd::ops::unstack op; + sd::ops::unstack op; - auto result = op.evaluate({&input}, {}, {0}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto result = op.evaluate({&input}, {}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(ParityOpsTests, TestUnstack6) { - auto input = NDArrayFactory::create('c', {1, 1, 1}); - auto exp = NDArrayFactory::create('c', {1, 1}, {1}); - input.linspace(1); - - sd::ops::unstack op; + auto input = NDArrayFactory::create('c', {1, 1, 1}); + auto exp = NDArrayFactory::create('c', {1, 1}, {1}); + input.linspace(1); - auto result = op.evaluate({&input}, {}, {0}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::unstack op; - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto result = op.evaluate({&input}, {}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(ParityOpsTests, TestUnstack7) { - auto input = NDArrayFactory::create('c', {1, 1, 1}); - auto exp = NDArrayFactory::create('c', {1, 1}, {1}); - input.linspace(1); - - sd::ops::unstack op; + auto input = NDArrayFactory::create('c', {1, 1, 1}); + auto exp = NDArrayFactory::create('c', {1, 1}, {1}); + input.linspace(1); - auto result = op.evaluate({&input}, {}, {1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::unstack op; - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto result = op.evaluate({&input}, {}, {1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(ParityOpsTests, TestUnstack8) { - auto input = NDArrayFactory::create('c', {1, 1}); - auto exp = NDArrayFactory::create('c', {1}, {1}); - input.linspace(1); - - sd::ops::unstack op; + auto input = NDArrayFactory::create('c', {1, 1}); + auto exp = NDArrayFactory::create('c', {1}, {1}); + input.linspace(1); - auto result = op.evaluate({&input}, {}, {0}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::unstack op; - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto result = op.evaluate({&input}, {}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(ParityOpsTests, TestUnstack9) { - auto input = NDArrayFactory::create('c', {1, 1}); - auto exp = NDArrayFactory::create('c', {1}, {1}); - input.linspace(1); - - sd::ops::unstack op; - - auto result = op.evaluate({&input}, {}, {1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto input = NDArrayFactory::create('c', {1, 1}); + auto exp = NDArrayFactory::create('c', {1}, {1}); + input.linspace(1); - auto z = result.at(0); + sd::ops::unstack op; - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto result = op.evaluate({&input}, {}, {1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////// TEST_F(ParityOpsTests, TestUnstack10) { + auto input = NDArrayFactory::create('c', {3, 0, 2}); + auto exp = NDArrayFactory::create('c', {0, 2}); - auto input = NDArrayFactory::create('c', {3, 0, 2}); - auto exp = NDArrayFactory::create('c', {0,2}); - - sd::ops::unstack op; - - auto result = op.evaluate({&input}, {}, {0}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - ASSERT_TRUE(exp.isSameShape(result.at(0))); - ASSERT_TRUE(exp.isSameShape(result.at(1))); - ASSERT_TRUE(exp.isSameShape(result.at(2))); + sd::ops::unstack op; + auto result = op.evaluate({&input}, {}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_TRUE(exp.isSameShape(result.at(0))); + ASSERT_TRUE(exp.isSameShape(result.at(1))); + ASSERT_TRUE(exp.isSameShape(result.at(2))); } //////////////////////////////////////////////////////////////////////// TEST_F(ParityOpsTests, TestUnstack11) { + auto input = NDArrayFactory::create('c', {3, 0, 2}); + auto exp = NDArrayFactory::create('c', {3, 0}); - auto input = NDArrayFactory::create('c', {3, 0, 2}); - auto exp = NDArrayFactory::create('c', {3,0}); - - sd::ops::unstack op; - - auto result = op.evaluate({&input}, {}, {2}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - ASSERT_TRUE(exp.isSameShape(result.at(0))); - ASSERT_TRUE(exp.isSameShape(result.at(1))); + sd::ops::unstack op; + auto result = op.evaluate({&input}, {}, {2}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_TRUE(exp.isSameShape(result.at(0))); + ASSERT_TRUE(exp.isSameShape(result.at(1))); } //////////////////////////////////////////////////////////////////////// TEST_F(ParityOpsTests, TestUnstack12) { + auto input = NDArrayFactory::create('c', {3, 0, 2}); - auto input = NDArrayFactory::create('c', {3, 0, 2}); - - sd::ops::unstack op; - - auto result = op.evaluate({&input}, {}, {1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - ASSERT_TRUE(result.size() == 0); + sd::ops::unstack op; + auto result = op.evaluate({&input}, {}, {1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_TRUE(result.size() == 0); } TEST_F(ParityOpsTests, TestUnstack13) { + auto x = NDArrayFactory::create('c', {2, 3}); - auto x = NDArrayFactory::create('c', {2, 3}); - - sd::ops::unstack op; - auto result = op.evaluate({&x}, {}, {1}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::unstack op; + auto result = op.evaluate({&x}, {}, {1}); - ASSERT_EQ(3, result.size()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - for (int e = 0; e < 3; e++) - ASSERT_EQ(1, result.at(e)->rankOf()); + ASSERT_EQ(3, result.size()); + for (int e = 0; e < 3; e++) ASSERT_EQ(1, result.at(e).rankOf()); } - - TEST_F(ParityOpsTests, ExpandDimsTest1) { - auto input = NDArrayFactory::create('c', {5, 5}); - input.linspace(1); - auto reshaped = input.reshape('c', {5, 1, 5}); - - sd::ops::expand_dims op; - auto result = op.evaluate({&input}, {}, {1}); + auto input = NDArrayFactory::create('c', {5, 5}); + input.linspace(1); + auto reshaped = input.reshape('c', {5, 1, 5}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::expand_dims op; + auto result = op.evaluate({&input}, {}, {1}); - auto z = result.at(0); - - ASSERT_TRUE(reshaped.isSameShape(z)); - ASSERT_TRUE(reshaped.equalsTo(z)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(reshaped.isSameShape(z)); + ASSERT_TRUE(reshaped.equalsTo(z)); } - TEST_F(ParityOpsTests, ExpandDimsTest2) { - auto input = NDArrayFactory::create('c', {3, 4}); - input.linspace(1); - auto reshaped = input.reshape('c', {1, 3, 4}); - - sd::ops::expand_dims op; - auto result = op.evaluate({&input}, {}, {0}); + auto input = NDArrayFactory::create('c', {3, 4}); + input.linspace(1); + auto reshaped = input.reshape('c', {1, 3, 4}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::expand_dims op; + auto result = op.evaluate({&input}, {}, {0}); - auto z = result.at(0); - - ASSERT_TRUE(reshaped.isSameShape(z)); - ASSERT_TRUE(reshaped.equalsTo(z)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(reshaped.isSameShape(z)); + ASSERT_TRUE(reshaped.equalsTo(z)); } - TEST_F(ParityOpsTests, ExpandDimsTest3) { - auto input = NDArrayFactory::create('c', {3, 4}); - input.linspace(1); - auto reshaped = input.reshape('c', {3, 1, 4}); - - sd::ops::expand_dims op; - auto result = op.evaluate({&input}, {}, {-2}); + auto input = NDArrayFactory::create('c', {3, 4}); + input.linspace(1); + auto reshaped = input.reshape('c', {3, 1, 4}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::expand_dims op; + auto result = op.evaluate({&input}, {}, {-2}); - auto z = result.at(0); - - ASSERT_TRUE(reshaped.isSameShape(z)); - ASSERT_TRUE(reshaped.equalsTo(z)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(reshaped.isSameShape(z)); + ASSERT_TRUE(reshaped.equalsTo(z)); } TEST_F(ParityOpsTests, ExpandDimsTest4) { - auto input = NDArrayFactory::create('c', {3, 4}); - input.linspace(1); - auto reshaped = input.reshape('c', {1, 3, 4}); - - sd::ops::expand_dims op; - auto result = op.evaluate({&input}, {}, {-3}); + auto input = NDArrayFactory::create('c', {3, 4}); + input.linspace(1); + auto reshaped = input.reshape('c', {1, 3, 4}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::expand_dims op; + auto result = op.evaluate({&input}, {}, {-3}); - auto z = result.at(0); - - ASSERT_TRUE(reshaped.isSameShape(z)); - ASSERT_TRUE(reshaped.equalsTo(z)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(reshaped.isSameShape(z)); + ASSERT_TRUE(reshaped.equalsTo(z)); } - TEST_F(ParityOpsTests, Test_Shape_1) { - auto x = NDArrayFactory::create('c', {3, 4, 5, 6}); - auto exp = NDArrayFactory::create('c', {4}, {3, 4, 5, 6}); + auto x = NDArrayFactory::create('c', {3, 4, 5, 6}); + auto exp = NDArrayFactory::create('c', {4}, {3, 4, 5, 6}); - sd::ops::shape_of op; - auto result = op.evaluate({&x}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::shape_of op; + auto result = op.evaluate({&x}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(ParityOpsTests, Test_Equals_1) { - auto x = NDArrayFactory::create('c', {1, 5}, {1, 2, 3, 4, 5}); - auto y = NDArrayFactory::create('c', {1, 5}, {1, 0, 3, 0, 5}); - auto exp = NDArrayFactory::create('c', {1, 5}, {1, 0, 1, 0, 1}); + auto x = NDArrayFactory::create('c', {1, 5}, {1, 2, 3, 4, 5}); + auto y = NDArrayFactory::create('c', {1, 5}, {1, 0, 3, 0, 5}); + auto exp = NDArrayFactory::create('c', {1, 5}, {1, 0, 1, 0, 1}); - sd::ops::equals op; - auto result = op.evaluate({&x, &y}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::equals op; + auto result = op.evaluate({&x, &y}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(ParityOpsTests, Test_NotEquals_1) { - auto x = NDArrayFactory::create('c', {1, 5}, {1, 2, 3, 4, 5}); - auto y = NDArrayFactory::create('c', {1, 5}, {1, 0, 3, 0, 5}); - auto exp = NDArrayFactory::create('c', {1, 5}, {0, 1, 0, 1, 0}); - - sd::ops::not_equals op; - auto result = op.evaluate({&x, &y}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto x = NDArrayFactory::create('c', {1, 5}, {1, 2, 3, 4, 5}); + auto y = NDArrayFactory::create('c', {1, 5}, {1, 0, 3, 0, 5}); + auto exp = NDArrayFactory::create('c', {1, 5}, {0, 1, 0, 1, 0}); - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::not_equals op; + auto result = op.evaluate({&x, &y}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(ParityOpsTests, Test_Less_1) { - auto x = NDArrayFactory::create('c', {1, 5}, {1, 2, 3, 4, 5}); - auto y = NDArrayFactory::create('c', {1, 5}, {5, 4, 3, 2, 1}); - auto exp = NDArrayFactory::create('c', {1, 5}, {1, 1, 0, 0, 0}); - - sd::ops::less op; - auto result = op.evaluate({&x, &y}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); + auto x = NDArrayFactory::create('c', {1, 5}, {1, 2, 3, 4, 5}); + auto y = NDArrayFactory::create('c', {1, 5}, {5, 4, 3, 2, 1}); + auto exp = NDArrayFactory::create('c', {1, 5}, {1, 1, 0, 0, 0}); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::less op; + auto result = op.evaluate({&x, &y}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(ParityOpsTests, Test_LessEquals_1) { - auto x = NDArrayFactory::create('c', {1, 5}, {1, 2, 3, 4, 5}); - auto y = NDArrayFactory::create('c', {1, 5}, {5, 4, 3, 2, 1}); - auto exp = NDArrayFactory::create('c', {1, 5}, {1, 1, 1, 0, 0}); + auto x = NDArrayFactory::create('c', {1, 5}, {1, 2, 3, 4, 5}); + auto y = NDArrayFactory::create('c', {1, 5}, {5, 4, 3, 2, 1}); + auto exp = NDArrayFactory::create('c', {1, 5}, {1, 1, 1, 0, 0}); - sd::ops::less_equal op; - auto result = op.evaluate({&x, &y}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::less_equal op; + auto result = op.evaluate({&x, &y}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(ParityOpsTests, Test_GreaterEquals_1) { - auto x = NDArrayFactory::create('c', {1, 5}, {1, 2, 3, 4, 5}); - auto y = NDArrayFactory::create('c', {1, 5}, {5, 4, 3, 2, 1}); - auto exp = NDArrayFactory::create('c', {1, 5}, {0, 0, 1, 1, 1}); - - sd::ops::greater_equal op; - auto result = op.evaluate({&x, &y}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); + auto x = NDArrayFactory::create('c', {1, 5}, {1, 2, 3, 4, 5}); + auto y = NDArrayFactory::create('c', {1, 5}, {5, 4, 3, 2, 1}); + auto exp = NDArrayFactory::create('c', {1, 5}, {0, 0, 1, 1, 1}); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::greater_equal op; + auto result = op.evaluate({&x, &y}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(ParityOpsTests, Test_GreaterEquals_2) { - auto x = NDArrayFactory::create('c', {1, 5}, {1, 2, 3, 4, 5}); - auto y = NDArrayFactory::create('c', {1, 5}, {5, 4, 3, 2, 1}); - auto exp = NDArrayFactory::create('c', {1, 5}, {0, 0, 1, 1, 1}); + auto x = NDArrayFactory::create('c', {1, 5}, {1, 2, 3, 4, 5}); + auto y = NDArrayFactory::create('c', {1, 5}, {5, 4, 3, 2, 1}); + auto exp = NDArrayFactory::create('c', {1, 5}, {0, 0, 1, 1, 1}); - sd::ops::greater_equal op; - auto result = op.evaluate({&x, &y}, {}, {}, {}, {}, false); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::greater_equal op; + auto result = op.evaluate({&x, &y}, {}, {}, {}, {}, false); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(ParityOpsTests, Test_Greater_1) { - auto x = NDArrayFactory::create('c', {1, 5}, {1, 2, 3, 4, 5}); - auto y = NDArrayFactory::create('c', {1, 5}, {5, 4, 3, 2, 1}); - auto exp = NDArrayFactory::create('c', {1, 5}, {0, 0, 0, 1, 1}); - - sd::ops::greater op; - auto result = op.evaluate({&x, &y}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto x = NDArrayFactory::create('c', {1, 5}, {1, 2, 3, 4, 5}); + auto y = NDArrayFactory::create('c', {1, 5}, {5, 4, 3, 2, 1}); + auto exp = NDArrayFactory::create('c', {1, 5}, {0, 0, 0, 1, 1}); - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::greater op; + auto result = op.evaluate({&x, &y}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(ParityOpsTests, Test_Where_1) { - auto mask = NDArrayFactory::create('c', {3, 3}, {1, 1, 1, 0, 0, 0, 1, 1, 1}); - auto x = NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); - auto y = NDArrayFactory::create('c', {3, 3}, {9, 8, 7, 6, 5, 4, 3, 2, 1}); - auto exp = NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 6, 5, 4, 7, 8, 9}); - - sd::ops::Where op; - auto result = op.evaluate({&mask, &x, &y}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); + auto mask = + NDArrayFactory::create('c', {3, 3}, {1, 1, 1, 0, 0, 0, 1, 1, 1}); + auto x = + NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto y = + NDArrayFactory::create('c', {3, 3}, {9, 8, 7, 6, 5, 4, 3, 2, 1}); + auto exp = + NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 6, 5, 4, 7, 8, 9}); - // z->printIndexedBuffer("result"); + sd::ops::Where op; + auto result = op.evaluate({&mask, &x, &y}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + // z->printIndexedBuffer("result"); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(ParityOpsTests, Test_Where_2) { - auto mask = NDArrayFactory::create('c', {1, 3}, {1, 0, 0}); - auto x = NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); - auto y = NDArrayFactory::create('c', {3, 3}, {9, 8, 7, 6, 5, 4, 3, 2, 1}); - auto exp = NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 6, 5, 4, 3, 2, 1}); + auto mask = NDArrayFactory::create('c', {1, 3}, {1, 0, 0}); + auto x = + NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto y = + NDArrayFactory::create('c', {3, 3}, {9, 8, 7, 6, 5, 4, 3, 2, 1}); + auto exp = + NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 6, 5, 4, 3, 2, 1}); - sd::ops::Where op; - auto result = op.evaluate({&mask, &x, &y}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::Where op; + auto result = op.evaluate({&mask, &x, &y}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(ParityOpsTests, Test_Where_3) { - auto mask = NDArrayFactory::create('c', {2, 2, 3}, {0, 1, 1, 0, 1, 0, 1, 0, 0, 0, 0, 1}); - auto exp = NDArrayFactory::create('c', {5, 3}, {0, 0, 1, 0, 0, 2, 0, 1, 1, 1, 0, 0, 1, 1, 2}); + auto mask = NDArrayFactory::create( + 'c', {2, 2, 3}, {0, 1, 1, 0, 1, 0, 1, 0, 0, 0, 0, 1}); + auto exp = NDArrayFactory::create( + 'c', {5, 3}, {0, 0, 1, 0, 0, 2, 0, 1, 1, 1, 0, 0, 1, 1, 2}); - sd::ops::Where op; - auto result = op.evaluate({&mask}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::Where op; + auto result = op.evaluate({&mask}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); - - // z->printShapeInfo("z"); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + // z->printShapeInfo("z"); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(ParityOpsTests, Test_Select_1) { - auto mask = NDArrayFactory::create('c', {1, 3}, {1, 0, 0}); - auto x = NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); - auto y = NDArrayFactory::create('c', {3, 3}, {9, 8, 7, 6, 5, 4, 3, 2, 1}); - auto exp = NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 6, 5, 4, 3, 2, 1}); - - sd::ops::select op; - auto result = op.evaluate({&mask, &x, &y}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); + auto mask = NDArrayFactory::create('c', {1, 3}, {1, 0, 0}); + auto x = + NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto y = + NDArrayFactory::create('c', {3, 3}, {9, 8, 7, 6, 5, 4, 3, 2, 1}); + auto exp = + NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 6, 5, 4, 3, 2, 1}); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::select op; + auto result = op.evaluate({&mask, &x, &y}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(ParityOpsTests, Test_Select_2) { - auto mask = NDArrayFactory::create('c', {2, 2}, {1, 0, 1, 0}); - auto x = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4 }); - auto y = NDArrayFactory::create('c', {2, 2}, {9, 8, 7, 6}); - auto exp = NDArrayFactory::create('c', {2, 2}, {1, 8, 3, 6}); + auto mask = NDArrayFactory::create('c', {2, 2}, {1, 0, 1, 0}); + auto x = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); + auto y = NDArrayFactory::create('c', {2, 2}, {9, 8, 7, 6}); + auto exp = NDArrayFactory::create('c', {2, 2}, {1, 8, 3, 6}); - sd::ops::select op; - auto result = op.evaluate({&mask, &x, &y}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::select op; + auto result = op.evaluate({&mask, &x, &y}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(ParityOpsTests, Test_Select_3) { - bool value = false; - auto mask = NDArrayFactory::create('c', {1, 1}, {value}); - auto x = NDArrayFactory::create('c', {1, 1}, {1}); - auto y = NDArrayFactory::create('c', {1, 1}, {2}); - auto exp = NDArrayFactory::create('c', {1, 1}, {2}); - - sd::ops::select op; - auto result = op.evaluate({&mask, &x, &y}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + bool value = false; + auto mask = NDArrayFactory::create('c', {1, 1}, {value}); + auto x = NDArrayFactory::create('c', {1, 1}, {1}); + auto y = NDArrayFactory::create('c', {1, 1}, {2}); + auto exp = NDArrayFactory::create('c', {1, 1}, {2}); - auto z = result.at(0); + sd::ops::select op; + auto result = op.evaluate({&mask, &x, &y}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(ParityOpsTests, Test_Bias_Add_1) { - auto x = NDArrayFactory::create('c', {10, 5}); - x.assign(0.0); - auto bias = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); - sd::ops::biasadd op; + auto x = NDArrayFactory::create('c', {10, 5}); + x.assign(0.0); + auto bias = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); + sd::ops::biasadd op; - auto result = op.evaluate({&x, &bias}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto result = op.evaluate({&x, &bias}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); - - auto tads = z->allTensorsAlongDimension({1}); - for (int e = 0; e < tads.size(); e++) { - ASSERT_TRUE(bias.equalsTo(tads.at(e))); - } + auto z = result.at(0); + auto tads = z.allTensorsAlongDimension({1}); + for (int e = 0; e < tads.size(); e++) { + ASSERT_TRUE(bias.equalsTo(tads.at(e))); + } } TEST_F(ParityOpsTests, Test_Scatter_Add_1) { - auto matrix = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); - NDArray idc('c', {1}, std::vector({0}), sd::DataType::INT64); - auto updates = NDArrayFactory::create('c', {1, 2}, {1, 1}); - auto exp = NDArrayFactory::create('c', {2, 2}, {2, 3, 3, 4}); - - sd::ops::scatter_add op; - auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto matrix = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); + NDArray idc('c', {1}, std::vector({0}), sd::DataType::INT64); + auto updates = NDArrayFactory::create('c', {1, 2}, {1, 1}); + auto exp = NDArrayFactory::create('c', {2, 2}, {2, 3, 3, 4}); - auto z = result.at(0); + sd::ops::scatter_add op; + auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(ParityOpsTests, Test_Scatter_Add_2) { + auto vec = NDArrayFactory::create('c', {4}, {1, 2, 3, 4}); + NDArray idc('c', {1, 4}, {0., 1, 2, 3}, sd::DataType::INT64); + auto updates = NDArrayFactory::create('c', {1, 4}, {1, 1, 1, 1}); + auto exp = NDArrayFactory::create('c', {1, 4}, {2, 3, 4, 5}); - auto vec = NDArrayFactory::create('c', {4}, {1, 2, 3, 4}); - NDArray idc('c', {1, 4}, {0., 1, 2, 3}, sd::DataType::INT64); - auto updates = NDArrayFactory::create('c', {1, 4}, {1, 1, 1, 1}); - auto exp = NDArrayFactory::create('c', {1, 4}, {2, 3, 4, 5}); - - sd::ops::scatter_add op; - auto result = op.evaluate({&vec, &idc, &updates}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); + sd::ops::scatter_add op; + auto result = op.evaluate({&vec, &idc, &updates}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(ParityOpsTests, Test_Scatter_Add_3) { - auto matrix = NDArrayFactory::create('c', {2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8}); - NDArray idc('c', {1}, std::vector({0}), sd::DataType::INT64); - auto updates = NDArrayFactory::create('c', {1, 2, 2}, {1, 1, 1, 1}); - auto exp = NDArrayFactory::create('c', {2, 2, 2}, {2, 3, 4, 5, 5, 6, 7, 8}); + auto matrix = + NDArrayFactory::create('c', {2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8}); + NDArray idc('c', {1}, std::vector({0}), sd::DataType::INT64); + auto updates = NDArrayFactory::create('c', {1, 2, 2}, {1, 1, 1, 1}); + auto exp = + NDArrayFactory::create('c', {2, 2, 2}, {2, 3, 4, 5, 5, 6, 7, 8}); - sd::ops::scatter_add op; - auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::scatter_add op; + auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); - - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(ParityOpsTests, Test_Scatter_Add_4) { - auto matrix = NDArrayFactory::create('c', {2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8}); - NDArray idc('c', {1, 2}, std::vector{0, 0}, sd::DataType::INT64); - auto updates = NDArrayFactory::create('c', {1, 2, 2, 2}, {1, 1, 1, 1, 1, 1, 1, 1}); - auto exp = NDArrayFactory::create('c', {2, 2, 2}, {3, 4, 5, 6, 5, 6, 7, 8}); - - sd::ops::scatter_add op; - auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}, {true, true}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto matrix = + NDArrayFactory::create('c', {2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8}); + NDArray idc('c', {1, 2}, std::vector{0, 0}, sd::DataType::INT64); + auto updates = NDArrayFactory::create('c', {1, 2, 2, 2}, + {1, 1, 1, 1, 1, 1, 1, 1}); + auto exp = + NDArrayFactory::create('c', {2, 2, 2}, {3, 4, 5, 6, 5, 6, 7, 8}); - auto z = result.at(0); + sd::ops::scatter_add op; + auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}, {true, true}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(ParityOpsTests, Test_Scatter_Add_5) { - auto matrix = NDArrayFactory::create('c', {2, 2, 3}, {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}); - NDArray idc('c', {2, 2}, {1., 1, 0, 0}, sd::DataType::INT64); - auto updates = NDArrayFactory::create('c', {2, 2, 2, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); - auto exp = NDArrayFactory::create('c', {2, 2, 3}, {9., 11., 13.,15., 17., 19., 9., 11., 13.,15., 17., 19.}); - - sd::ops::scatter_add op; - auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}, {true}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto matrix = NDArrayFactory::create( + 'c', {2, 2, 3}, {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}); + NDArray idc('c', {2, 2}, {1., 1, 0, 0}, sd::DataType::INT64); + auto updates = NDArrayFactory::create( + 'c', {2, 2, 2, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + auto exp = NDArrayFactory::create( + 'c', {2, 2, 3}, + {9., 11., 13., 15., 17., 19., 9., 11., 13., 15., 17., 19.}); - auto z = result.at(0); - // z->printBuffer(); + sd::ops::scatter_add op; + auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}, {true}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + // z->printBuffer(); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(ParityOpsTests, Test_Scatter_Add_6) { - auto matrix = NDArrayFactory::create('c', {2, 2, 2}, {1, 1, 1, 1, 1, 1, 1, 1}); - NDArray idc('c', {2, 2}, {1, 1, 0, 0}, sd::DataType::INT64); - auto updates = NDArrayFactory::create('c', {2, 2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8}); - auto exp = NDArrayFactory::create('c', {2, 2, 2}, {7, 9, 11, 13, 7, 9, 11, 13}); + auto matrix = + NDArrayFactory::create('c', {2, 2, 2}, {1, 1, 1, 1, 1, 1, 1, 1}); + NDArray idc('c', {2, 2}, {1, 1, 0, 0}, sd::DataType::INT64); + auto updates = NDArrayFactory::create( + 'c', {2, 2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8}); + auto exp = NDArrayFactory::create('c', {2, 2, 2}, + {7, 9, 11, 13, 7, 9, 11, 13}); - sd::ops::scatter_add op; - auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}, {true, true}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::scatter_add op; + auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}, {true, true}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); - - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(ParityOpsTests, Test_Scatter_Add_7) { - auto matrix = NDArrayFactory::create('c', {10, 3}, {1.f,2.f,3.f,4.f,5.f,6.f,7.f,8.f,9.f,10.f,11.f,12.f,13.f,14.f,15.f,16.f,17.f,18.f,19.f,20.f,21.f,22.f,23.f,24.f,25.f,26.f,27.f,28.f,29.f,30.f}); - NDArray idc('c', {}, std::vector{5}, sd::DataType::INT64); - auto updates = NDArrayFactory::create('c', {3}, {10.f, 20.f, 30.f}); - auto exp = NDArrayFactory::create('c', {10, 3}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f,11.f,12.f, 13.f,14.f,15.f, 26.f,37.f,48.f, 19.f,20.f,21.f, 22.f,23.f,24.f, 25.f,26.f,27.f, 28.f,29.f,30.f}); - - sd::ops::scatter_add op; - auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto matrix = NDArrayFactory::create( + 'c', {10, 3}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, + 11.f, 12.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, + 21.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f}); + NDArray idc('c', {}, std::vector{5}, sd::DataType::INT64); + auto updates = NDArrayFactory::create('c', {3}, {10.f, 20.f, 30.f}); + auto exp = NDArrayFactory::create( + 'c', {10, 3}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, + 11.f, 12.f, 13.f, 14.f, 15.f, 26.f, 37.f, 48.f, 19.f, 20.f, + 21.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f}); - auto z = result.at(0); + sd::ops::scatter_add op; + auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////// TEST_F(ParityOpsTests, Test_Scatter_Add_8) { + NDArray input('c', {8}, {1, 1, 1, 1, 1, 1, 1, 1}, sd::DataType::FLOAT32); + NDArray indices('c', {4}, {1, 1, 1, 1}, sd::DataType::INT32); + NDArray updates('c', {4}, {1, 2, 3, 4}, sd::DataType::FLOAT32); + NDArray expected('c', {8}, {1.f, 11.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}, + sd::DataType::FLOAT32); - NDArray input('c', {8}, {1,1,1,1,1,1,1,1}, sd::DataType::FLOAT32); - NDArray indices('c', {4}, {1, 1, 1, 1}, sd::DataType::INT32); - NDArray updates('c', {4}, {1,2,3,4}, sd::DataType::FLOAT32); - NDArray expected('c', {8}, {1.f, 11.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}, sd::DataType::FLOAT32); + NDArray z('c', {8}, sd::DataType::FLOAT32); - NDArray z('c', {8}, sd::DataType::FLOAT32); + sd::ops::scatter_add op; + Nd4jStatus status = + op.execute({&input, &indices, &updates}, {&z}, {}, {}, {true}); + // z.printBuffer(); - sd::ops::scatter_add op; - Nd4jStatus status = op.execute({&input, &indices, &updates}, {&z}, {}, {}, {true}); - // z.printBuffer(); - - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(expected.isSameShapeStrict(z)); - ASSERT_TRUE(expected.equalsTo(z)); + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.isSameShapeStrict(z)); + ASSERT_TRUE(expected.equalsTo(z)); } //////////////////////////////////////////////////////////////////// TEST_F(ParityOpsTests, Test_Scatter_Add_9) { - auto matrix = NDArrayFactory::create('c', {2, 2, 3}, {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}); - NDArray idc('c', {2, 2}, {1, 10, 0, 0}, sd::DataType::INT64); - auto updates = NDArrayFactory::create('c', {2, 2, 2, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); - auto output = NDArrayFactory::create('c', {2, 2, 3}); + auto matrix = NDArrayFactory::create( + 'c', {2, 2, 3}, {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}); + NDArray idc('c', {2, 2}, {1, 10, 0, 0}, sd::DataType::INT64); + auto updates = NDArrayFactory::create( + 'c', {2, 2, 2, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + auto output = NDArrayFactory::create('c', {2, 2, 3}); - sd::ops::scatter_add op; + sd::ops::scatter_add op; - ASSERT_ANY_THROW(op.execute({&matrix, &idc, &updates}, {&output}, {}, {}, {true, true})); + ASSERT_ANY_THROW( + op.execute({&matrix, &idc, &updates}, {&output}, {}, {}, {true, true})); } //////////////////////////////////////////////////////////////////// TEST_F(ParityOpsTests, scatterMax_test1) { - auto matrix = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); - NDArray idc('c', {1}, std::vector{0.}, sd::DataType::INT64); - auto updates = NDArrayFactory::create('c', {1, 2}, {10, 1}); - auto exp = NDArrayFactory::create('c', {2, 2}, {10, 2, 3, 4}); - - sd::ops::scatter_max op; - auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}, {true}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto matrix = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); + NDArray idc('c', {1}, std::vector{0.}, sd::DataType::INT64); + auto updates = NDArrayFactory::create('c', {1, 2}, {10, 1}); + auto exp = NDArrayFactory::create('c', {2, 2}, {10, 2, 3, 4}); - auto z = result.at(0); + sd::ops::scatter_max op; + auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}, {true}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(ParityOpsTests, scatterMax_test2) { - auto vec = NDArrayFactory::create('c', {4}, {1, 2, 3, 4}); - NDArray idc('c', {1, 4}, {0, 1, 2, 3}, sd::DataType::INT64); - auto updates = NDArrayFactory::create('c', {1, 4}, {10, 1, 30, 1}); - auto exp = NDArrayFactory::create('c', {1, 4}, {10, 2, 30, 4}); + auto vec = NDArrayFactory::create('c', {4}, {1, 2, 3, 4}); + NDArray idc('c', {1, 4}, {0, 1, 2, 3}, sd::DataType::INT64); + auto updates = NDArrayFactory::create('c', {1, 4}, {10, 1, 30, 1}); + auto exp = NDArrayFactory::create('c', {1, 4}, {10, 2, 30, 4}); - sd::ops::scatter_max op; - auto result = op.evaluate({&vec, &idc, &updates}, {}, {}, {true}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::scatter_max op; + auto result = op.evaluate({&vec, &idc, &updates}, {}, {}, {true}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); - - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(ParityOpsTests, scatterMax_test3) { - auto matrix = NDArrayFactory::create('c', {2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8}); - NDArray idc('c', {1}, std::vector({0}), sd::DataType::INT64); - auto updates = NDArrayFactory::create('c', {1, 2, 2}, {10, 1, 30, 1}); - auto exp = NDArrayFactory::create('c', {2, 2, 2}, {10, 2, 30, 4, 5, 6, 7, 8}); - - sd::ops::scatter_max op; - auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}, {true}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); + auto matrix = + NDArrayFactory::create('c', {2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8}); + NDArray idc('c', {1}, std::vector({0}), sd::DataType::INT64); + auto updates = NDArrayFactory::create('c', {1, 2, 2}, {10, 1, 30, 1}); + auto exp = + NDArrayFactory::create('c', {2, 2, 2}, {10, 2, 30, 4, 5, 6, 7, 8}); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::scatter_max op; + auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}, {true}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(ParityOpsTests, scatterMax_test4) { - auto matrix = NDArrayFactory::create('c', {2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8}); - NDArray idc('c', {1,2}, std::vector{0.,0}, sd::DataType::INT32); - auto updates = NDArrayFactory::create('c', {1, 2, 2, 2}, {1,10,1,10, 1,1,10,1.}); - auto exp = NDArrayFactory::create('c', {2, 2, 2}, {1, 10, 10, 10, 5, 6, 7, 8}); + auto matrix = + NDArrayFactory::create('c', {2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8}); + NDArray idc('c', {1, 2}, std::vector{0., 0}, sd::DataType::INT32); + auto updates = NDArrayFactory::create('c', {1, 2, 2, 2}, + {1, 10, 1, 10, 1, 1, 10, 1.}); + auto exp = NDArrayFactory::create('c', {2, 2, 2}, + {1, 10, 10, 10, 5, 6, 7, 8}); - sd::ops::scatter_max op; - auto result = op.evaluate({&matrix, &idc, &updates}, {}, {true}, {true}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::scatter_max op; + auto result = op.evaluate({&matrix, &idc, &updates}, {}, {true}, {true}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); - - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(ParityOpsTests, scatterMax_test5) { - auto matrix = NDArrayFactory::create('c', {2, 2, 3}, {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}); - NDArray idc('c', {2, 2}, {1, 1, 0, 0}, sd::DataType::INT32); - auto updates = NDArrayFactory::create('c', {2, 2, 2, 3}, {2,10,1,10, 2,10,1,10, 2,10,1,10, 10,2,10,1, 10,2,10,1, 10,2,10,1.}); - auto exp = NDArrayFactory::create('c', {2, 2, 3}, {10, 2, 10, 2, 10, 2, 2, 10, 2, 10, 2, 10}); - - sd::ops::scatter_max op; - auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}, {true}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto matrix = NDArrayFactory::create( + 'c', {2, 2, 3}, {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}); + NDArray idc('c', {2, 2}, {1, 1, 0, 0}, sd::DataType::INT32); + auto updates = NDArrayFactory::create( + 'c', {2, 2, 2, 3}, {2, 10, 1, 10, 2, 10, 1, 10, 2, 10, 1, 10, + 10, 2, 10, 1, 10, 2, 10, 1, 10, 2, 10, 1.}); + auto exp = NDArrayFactory::create( + 'c', {2, 2, 3}, {10, 2, 10, 2, 10, 2, 2, 10, 2, 10, 2, 10}); - auto z = result.at(0); + sd::ops::scatter_max op; + auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}, {true}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(ParityOpsTests, scatterMax_test6) { - auto matrix = NDArrayFactory::create('c', {2, 2, 2}, {1, 1, 1, 1, 1, 1, 1, 1}); - NDArray idc('c', {2, 2}, {1, 1, 0, 0}, sd::DataType::INT32); - auto updates = NDArrayFactory::create('c', {2, 2, 2, 2}, {0,2,0,2, 0,2,0,2, 2,0,2,0., 2,0,2,0}); - auto exp = NDArrayFactory::create('c', {2, 2, 2}, {2, 1, 2, 1, 1, 2, 1, 2}); - - sd::ops::scatter_max op; - auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}, {true}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto matrix = + NDArrayFactory::create('c', {2, 2, 2}, {1, 1, 1, 1, 1, 1, 1, 1}); + NDArray idc('c', {2, 2}, {1, 1, 0, 0}, sd::DataType::INT32); + auto updates = NDArrayFactory::create( + 'c', {2, 2, 2, 2}, {0, 2, 0, 2, 0, 2, 0, 2, 2, 0, 2, 0., 2, 0, 2, 0}); + auto exp = + NDArrayFactory::create('c', {2, 2, 2}, {2, 1, 2, 1, 1, 2, 1, 2}); - auto z = result.at(0); + sd::ops::scatter_max op; + auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}, {true}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(ParityOpsTests, scatterMin_test1) { - auto matrix = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); - NDArray idc('c', {1}, std::vector({0}), sd::DataType::INT32); - auto updates = NDArrayFactory::create('c', {1, 2}, {-1, 1}); - auto exp = NDArrayFactory::create('c', {2, 2}, {-1, 1, 3, 4}); - - sd::ops::scatter_min op; - auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}, {true}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); + auto matrix = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); + NDArray idc('c', {1}, std::vector({0}), sd::DataType::INT32); + auto updates = NDArrayFactory::create('c', {1, 2}, {-1, 1}); + auto exp = NDArrayFactory::create('c', {2, 2}, {-1, 1, 3, 4}); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::scatter_min op; + auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}, {true}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(ParityOpsTests, scatterMin_test2) { - auto vec = NDArrayFactory::create('c', {4}, {1, 2, 3, 4}); - NDArray idc('c', {1, 4}, {0, 1, 2, 3}, sd::DataType::INT32); - auto updates = NDArrayFactory::create('c', {1, 4}, {10, 1, 30, 1}); - auto exp = NDArrayFactory::create('c', {1, 4}, {1, 1, 3, 1}); + auto vec = NDArrayFactory::create('c', {4}, {1, 2, 3, 4}); + NDArray idc('c', {1, 4}, {0, 1, 2, 3}, sd::DataType::INT32); + auto updates = NDArrayFactory::create('c', {1, 4}, {10, 1, 30, 1}); + auto exp = NDArrayFactory::create('c', {1, 4}, {1, 1, 3, 1}); - sd::ops::scatter_min op; - auto result = op.evaluate({&vec, &idc, &updates}, {}, {}, {true}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::scatter_min op; + auto result = op.evaluate({&vec, &idc, &updates}, {}, {}, {true}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); - - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(ParityOpsTests, scatterMin_test3) { - auto matrix = NDArrayFactory::create('c', {2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8}); - NDArray idc('c', {1}, std::vector({0}), sd::DataType::INT32); - auto updates = NDArrayFactory::create('c', {1, 2, 2}, {10, 1, 30, 2}); - auto exp = NDArrayFactory::create('c', {2, 2, 2}, {1, 1, 3, 2, 5, 6, 7, 8}); - - sd::ops::scatter_min op; - auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}, {true}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto matrix = + NDArrayFactory::create('c', {2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8}); + NDArray idc('c', {1}, std::vector({0}), sd::DataType::INT32); + auto updates = NDArrayFactory::create('c', {1, 2, 2}, {10, 1, 30, 2}); + auto exp = + NDArrayFactory::create('c', {2, 2, 2}, {1, 1, 3, 2, 5, 6, 7, 8}); - auto z = result.at(0); + sd::ops::scatter_min op; + auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}, {true}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(ParityOpsTests, scatterMin_test4) { - auto matrix = NDArrayFactory::create('c', {2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8}); - NDArray idc('c', {1,2}, std::vector{0.,0}, sd::DataType::INT32); - auto updates = NDArrayFactory::create('c', {1, 2, 2, 2}, {1,10,1,10, 1,1,10,1.}); - auto exp = NDArrayFactory::create('c', {2, 2, 2}, {1, 1, 1, 1, 5, 6, 7, 8}); + auto matrix = + NDArrayFactory::create('c', {2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8}); + NDArray idc('c', {1, 2}, std::vector{0., 0}, sd::DataType::INT32); + auto updates = NDArrayFactory::create('c', {1, 2, 2, 2}, + {1, 10, 1, 10, 1, 1, 10, 1.}); + auto exp = + NDArrayFactory::create('c', {2, 2, 2}, {1, 1, 1, 1, 5, 6, 7, 8}); - sd::ops::scatter_min op; - auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}, {true}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::scatter_min op; + auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}, {true}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); - // z->printBuffer(); - - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + // z->printBuffer(); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////// TEST_F(ParityOpsTests, scatterMin_test5) { - auto matrix = NDArrayFactory::create('c', {2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8}); - NDArray idc('c', {1,2}, {10,10}, sd::DataType::INT32); - auto updates = NDArrayFactory::create('c', {1, 2, 2, 2}, {1,10,1,10, 1,1,10,1.}); - auto output = NDArrayFactory::create('c', {2, 2, 2}); + auto matrix = + NDArrayFactory::create('c', {2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8}); + NDArray idc('c', {1, 2}, {10, 10}, sd::DataType::INT32); + auto updates = NDArrayFactory::create('c', {1, 2, 2, 2}, + {1, 10, 1, 10, 1, 1, 10, 1.}); + auto output = NDArrayFactory::create('c', {2, 2, 2}); - sd::ops::scatter_min op; + sd::ops::scatter_min op; - ASSERT_ANY_THROW(op.execute({&matrix, &idc, &updates}, {&output}, {}, {}, {true, true})); + ASSERT_ANY_THROW( + op.execute({&matrix, &idc, &updates}, {&output}, {}, {}, {true, true})); } //////////////////////////////////////////////////////////////////////// TEST_F(ParityOpsTests, scatterND_test1) { + NDArray indices('c', {2, 1}, {1., 0.}, sd::DataType::INT32); + auto updates = NDArrayFactory::create( + 'c', {2, 4}, {10.f, 20.f, 30.f, 40.f, 50.f, 60.f, 70.f, 80.f}); + auto shape = NDArrayFactory::create('c', {2}, {3, 4}); + auto exp = NDArrayFactory::create( + 'c', {3, 4}, + {50.f, 60.f, 70.f, 80.f, 10.f, 20.f, 30.f, 40.f, 0.f, 0.f, 0.f, 0.f}); - NDArray indices('c', {2, 1}, {1., 0.}, sd::DataType::INT32); - auto updates = NDArrayFactory::create('c', {2, 4}, {10.f, 20.f, 30.f, 40.f, 50.f, 60.f, 70.f, 80.f}); - auto shape = NDArrayFactory::create('c', {2}, {3, 4}); - auto exp = NDArrayFactory::create('c', {3, 4}, {50.f, 60.f, 70.f, 80.f, 10.f, 20.f, 30.f, 40.f, 0.f, 0.f, 0.f, 0.f}); - - sd::ops::scatter_nd op; - auto result = op.evaluate({&indices, &updates, &shape}, {}, {false, true}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::scatter_nd op; + auto result = op.evaluate({&indices, &updates, &shape}, {}, {false, true}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); - // z->printBuffer(); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + // z->printBuffer(); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////// TEST_F(ParityOpsTests, scatterND_test2) { + NDArray indices('c', {3, 1}, {4., 2., 0.}, sd::DataType::INT32); + auto updates = NDArrayFactory::create('c', {3, 4}); + auto shape = NDArrayFactory::create('c', {2}, {5, 4}); + auto exp = NDArrayFactory::create( + 'c', {5, 4}, {9.f, 10.f, 11.f, 12.f, 0.f, 0.f, 0.f, 0.f, 5.f, 6.f, + 7.f, 8.f, 0.f, 0.f, 0.f, 0.f, 1.f, 2.f, 3.f, 4.f}); + updates.linspace(1.f); - NDArray indices('c', {3, 1}, {4., 2., 0.}, sd::DataType::INT32); - auto updates = NDArrayFactory::create('c', {3, 4}); - auto shape = NDArrayFactory::create('c', {2}, {5, 4}); - auto exp = NDArrayFactory::create('c', {5, 4}, {9.f,10.f,11.f,12.f, 0.f, 0.f, 0.f, 0.f, 5.f, 6.f, 7.f, 8.f, 0.f, 0.f, 0.f, 0.f, 1.f, 2.f, 3.f, 4.f}); - updates.linspace(1.f); - - sd::ops::scatter_nd op; - auto result = op.evaluate({&indices, &updates, &shape}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); + sd::ops::scatter_nd op; + auto result = op.evaluate({&indices, &updates, &shape}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////// TEST_F(ParityOpsTests, scatterND_test3) { - - NDArray indices('c', {2, 3, 1}, {0., 2., 7., 3., 6., 9.}, sd::DataType::INT32); - auto updates = NDArrayFactory::create('c', {2,3, 3,4}); - auto shape = NDArrayFactory::create('c', {3}, {10, 3, 4}); - auto exp = NDArrayFactory::create('c', {10, 3, 4}, {1.f, 2.f, 3.f, 4., 5.f, 6.f, 7.f, 8., 9.f, 10.f, 11.f, 12., 0.f, 0.f, 0.f, 0., 0.f, 0.f, 0.f, 0., 0.f, 0.f, 0.f, 0., - 13.f, 14.f, 15.f, 16.,17.f, 18.f, 19.f, 20.,21.f, 22.f, 23.f, 24.,37.f, 38.f, 39.f, 40.,41.f, 42.f, 43.f, 44.,45.f, 46.f, 47.f, 48., - 0.f, 0.f, 0.f, 0., 0.f, 0.f, 0.f, 0., 0.f, 0.f, 0.f, 0., 0.f, 0.f, 0.f, 0., 0.f, 0.f, 0.f, 0., 0.f, 0.f, 0.f, 0., - 49.f, 50.f, 51.f, 52.,53.f, 54.f, 55.f, 56.,57.f, 58.f, 59.f, 60.,25.f, 26.f, 27.f, 28.,29.f, 30.f, 31.f, 32.,33.f, 34.f, 35.f, 36., - 0.f, 0.f, 0.f, 0., 0.f, 0.f, 0.f, 0., 0.f, 0.f, 0.f, 0.,61.f, 62.f, 63.f, 64.,65.f, 66.f, 67.f, 68.,69.f, 70.f, 71.f, 72.,}); - updates.linspace(1.f); - - sd::ops::scatter_nd op; - auto result = op.evaluate({&indices, &updates, &shape}, {}, {false, true}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - + NDArray indices('c', {2, 3, 1}, {0., 2., 7., 3., 6., 9.}, + sd::DataType::INT32); + auto updates = NDArrayFactory::create('c', {2, 3, 3, 4}); + auto shape = NDArrayFactory::create('c', {3}, {10, 3, 4}); + auto exp = NDArrayFactory::create( + 'c', {10, 3, 4}, + { + 1.f, 2.f, 3.f, 4., 5.f, 6.f, 7.f, 8., 9.f, 10.f, 11.f, 12., + 0.f, 0.f, 0.f, 0., 0.f, 0.f, 0.f, 0., 0.f, 0.f, 0.f, 0., + 13.f, 14.f, 15.f, 16., 17.f, 18.f, 19.f, 20., 21.f, 22.f, 23.f, 24., + 37.f, 38.f, 39.f, 40., 41.f, 42.f, 43.f, 44., 45.f, 46.f, 47.f, 48., + 0.f, 0.f, 0.f, 0., 0.f, 0.f, 0.f, 0., 0.f, 0.f, 0.f, 0., + 0.f, 0.f, 0.f, 0., 0.f, 0.f, 0.f, 0., 0.f, 0.f, 0.f, 0., + 49.f, 50.f, 51.f, 52., 53.f, 54.f, 55.f, 56., 57.f, 58.f, 59.f, 60., + 25.f, 26.f, 27.f, 28., 29.f, 30.f, 31.f, 32., 33.f, 34.f, 35.f, 36., + 0.f, 0.f, 0.f, 0., 0.f, 0.f, 0.f, 0., 0.f, 0.f, 0.f, 0., + 61.f, 62.f, 63.f, 64., 65.f, 66.f, 67.f, 68., 69.f, 70.f, 71.f, 72., + }); + updates.linspace(1.f); + + sd::ops::scatter_nd op; + auto result = op.evaluate({&indices, &updates, &shape}, {}, {false, true}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////// TEST_F(ParityOpsTests, scatterND_test4) { + NDArray indices('c', {4, 1}, {4., 3., 1., 7.}, sd::DataType::INT32); + auto updates = + NDArrayFactory::create('c', {4}, {9.f, 10.f, 11.f, 12.f}); + auto shape = NDArrayFactory::create('c', {1}, {8}); + auto exp = NDArrayFactory::create( + 'c', {8}, {0.f, 11.f, 0.f, 10.f, 9.f, 0.f, 0.f, 12.f}); - NDArray indices('c', {4, 1}, {4., 3., 1., 7.}, sd::DataType::INT32); - auto updates = NDArrayFactory::create('c', {4}, {9.f, 10.f, 11.f, 12.f}); - auto shape = NDArrayFactory::create('c', {1}, {8}); - auto exp = NDArrayFactory::create('c', {8}, {0.f, 11.f, 0.f, 10.f, 9.f, 0.f, 0.f, 12.f}); - - sd::ops::scatter_nd op; - auto result = op.evaluate({&indices, &updates, &shape}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); + sd::ops::scatter_nd op; + auto result = op.evaluate({&indices, &updates, &shape}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////// TEST_F(ParityOpsTests, scatterND_test5) { + NDArray indices('c', {4, 1}, {1, 1, 1, 1}, sd::DataType::INT32); + auto updates = NDArrayFactory::create('c', {4}, {1.f, 2.f, 3.f, 4.f}); + auto shape = NDArrayFactory::create('c', {1}, {8}); + auto exp = NDArrayFactory::create( + 'c', {8}, {0.f, 10.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}); - NDArray indices('c', {4, 1}, {1, 1, 1, 1}, sd::DataType::INT32); - auto updates = NDArrayFactory::create('c', {4}, {1.f, 2.f, 3.f, 4.f}); - auto shape = NDArrayFactory::create('c', {1}, {8}); - auto exp = NDArrayFactory::create('c', {8}, {0.f, 10.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}); + sd::ops::scatter_nd op; + auto result = op.evaluate({&indices, &updates, &shape}, {}, {}, {true}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - sd::ops::scatter_nd op; - auto result = op.evaluate({&indices, &updates, &shape}, {}, {}, {true}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - // z->printBuffer(); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + // z->printBuffer(); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////// TEST_F(ParityOpsTests, scatterND_test6) { + NDArray indices('c', {3, 2}, {0, 1, 1, 0, 3, 2}, sd::DataType::INT32); + NDArray updates('c', {3, 2, 3}, sd::DataType::FLOAT32); + NDArray shape('c', {4}, {5, 4, 2, 3}, sd::DataType::INT32); - NDArray indices('c', {3, 2}, {0,1,1,0,3,2}, sd::DataType::INT32); - NDArray updates('c', {3, 2, 3}, sd::DataType::FLOAT32); - NDArray shape('c', {4}, {5,4,2,3}, sd::DataType::INT32); - - NDArray exp('c', {5,4,2,3}, {0., 0., 0.,0., 0., 0.,1., 2., 3.,4., 5., 6.,0., 0., 0.,0., 0., 0., 0., 0., 0.,0., 0., 0., - 7., 8., 9., 10., 11., 12., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., - 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., - 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 13., 14., 15., 16., 17., 18., 0., 0., 0., 0., 0., 0., - 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.}, sd::DataType::FLOAT32); - updates.linspace(1); - - sd::ops::scatter_nd op; - auto result = op.evaluate({&indices, &updates, &shape}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + NDArray exp('c', {5, 4, 2, 3}, + {0., 0., 0., 0., 0., 0., 1., 2., 3., 4., 5., 6., 0., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 7., 8., 9., 10., 11., 12., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 13., 14., 15., 16., 17., 18., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.}, + sd::DataType::FLOAT32); + updates.linspace(1); - auto z = result.at(0); - // z->printBuffer(); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::scatter_nd op; + auto result = op.evaluate({&indices, &updates, &shape}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + // z->printBuffer(); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////// TEST_F(ParityOpsTests, scatterND_test7) { - - NDArray indices('c', {4,3,2}, {0,1,1,0,3,2,1,0,0,1,1,0,3,2,1,0,0,1,1,0,3,2,1,0}, sd::DataType::INT32); - NDArray updates('c', {4,3,2,3}, sd::DataType::FLOAT32); - NDArray shape('c', {4}, {5,4,2,3}, sd::DataType::INT32); - - NDArray exp('c', {5,4,2,3}, {0., 0., 0., 0., 0., 0., 75., 78., 81., 84., 87., 90., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., - 222., 228., 234., 240., 246., 252., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., - 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., - 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 111., 114., 117., 120., 123., 126., 0., 0., 0., 0., 0., 0., - 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.}, sd::DataType::FLOAT32); - updates.linspace(1); - - sd::ops::scatter_nd op; - auto result = op.evaluate({&indices, &updates, &shape}, {}, {}, {true, true}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - // z->printBuffer(); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - - + NDArray indices('c', {4, 3, 2}, {0, 1, 1, 0, 3, 2, 1, 0, 0, 1, 1, 0, + 3, 2, 1, 0, 0, 1, 1, 0, 3, 2, 1, 0}, + sd::DataType::INT32); + NDArray updates('c', {4, 3, 2, 3}, sd::DataType::FLOAT32); + NDArray shape('c', {4}, {5, 4, 2, 3}, sd::DataType::INT32); + + NDArray exp('c', {5, 4, 2, 3}, + {0., 0., 0., 0., 0., 0., 75., 78., 81., 84., 87., 90., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 222., 228., 234., 240., 246., 252., 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 111., 114., 117., 120., 123., 126., 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.}, + sd::DataType::FLOAT32); + updates.linspace(1); + + sd::ops::scatter_nd op; + auto result = op.evaluate({&indices, &updates, &shape}, {}, {}, {true, true}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + // z->printBuffer(); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////// TEST_F(ParityOpsTests, scatterND_test8) { + NDArray indices('c', {3, 2}, {0, 0, 1, 1, 2, 2}, sd::DataType::INT32); + auto updates = NDArrayFactory::create('c', {3}, {1.f, 2.f, 3.f}); + auto shape = NDArrayFactory::create('c', {2}, {6, 4}); + auto exp = NDArrayFactory::create( + 'c', {6, 4}, + {1, 0, 0, 0, 0, 2, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}); - NDArray indices('c', {3, 2}, {0,0, 1,1, 2,2}, sd::DataType::INT32); - auto updates = NDArrayFactory::create('c', {3}, {1.f, 2.f, 3.f}); - auto shape = NDArrayFactory::create('c', {2}, {6,4}); - auto exp = NDArrayFactory::create('c', {6,4}, {1, 0, 0, 0, 0, 2, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}); - - sd::ops::scatter_nd op; - auto result = op.evaluate({&indices, &updates, &shape}, {}, {true}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - // z->printBuffer(); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::scatter_nd op; + auto result = op.evaluate({&indices, &updates, &shape}, {}, {true}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + // z->printBuffer(); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////// TEST_F(ParityOpsTests, scatterND_test9) { + NDArray indices('c', {2, 3, 1}, {0., 20., 7., 30., 6., 90.}, + sd::DataType::INT32); + auto updates = NDArrayFactory::create('c', {2, 3, 3, 4}); + auto shape = NDArrayFactory::create('c', {3}, {10, 3, 4}); + auto output = NDArrayFactory::create('c', {10, 3, 4}); - NDArray indices('c', {2, 3, 1}, {0., 20., 7., 30., 6., 90.}, sd::DataType::INT32); - auto updates = NDArrayFactory::create('c', {2,3, 3,4}); - auto shape = NDArrayFactory::create('c', {3}, {10, 3, 4}); - auto output = NDArrayFactory::create('c', {10, 3, 4}); - - sd::ops::scatter_nd op; + sd::ops::scatter_nd op; - ASSERT_ANY_THROW(auto result = op.execute({&indices, &updates, &shape}, {&output}, {}, {}, {false, true})); + ASSERT_ANY_THROW(auto result = op.execute({&indices, &updates, &shape}, + {&output}, {}, {}, {false, true})); } - //////////////////////////////////////////////////////////////////////// TEST_F(ParityOpsTests, scatterND_add_test1) { + auto input = NDArrayFactory::create( + 'c', {8}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f}); + NDArray indices('c', {4, 1}, {4., 3., 1., 7.}, sd::DataType::INT32); + auto updates = + NDArrayFactory::create('c', {4}, {9.f, 10.f, 11.f, 12.f}); + auto exp = NDArrayFactory::create( + 'c', {8}, {1.f, 13.f, 3.f, 14.f, 14.f, 6.f, 7.f, 20.f}); - auto input = NDArrayFactory::create('c', {8}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f}); - NDArray indices('c', {4, 1}, {4., 3., 1., 7.}, sd::DataType::INT32); - auto updates = NDArrayFactory::create('c', {4}, {9.f, 10.f, 11.f, 12.f}); - auto exp = NDArrayFactory::create('c', {8}, {1.f, 13.f, 3.f, 14.f, 14.f, 6.f, 7.f, 20.f}); - - sd::ops::scatter_nd_add op; - auto result = op.evaluate({&input, &indices, &updates}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::scatter_nd_add op; + auto result = op.evaluate({&input, &indices, &updates}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////// TEST_F(ParityOpsTests, scatterND_add_test2) { + auto input = NDArrayFactory::create('c', {6, 4}); + NDArray indices('c', {3, 3, 2}, + {0.f, 0.f, 1.f, 1.f, 2.f, 2.f, 3.f, 3.f, 4.f, 0.f, 5.f, 1.f, + 0.f, 2.f, 1.f, 3.f, 2.f, 0.f}, + sd::DataType::INT32); + auto updates = NDArrayFactory::create('c', {3, 3}); + auto exp = NDArrayFactory::create( + 'c', {6, 4}, + {1.f, 0.f, 7.f, 0.f, 0.f, 2.f, 0.f, 8.f, 9.f, 0.f, 3.f, 0.f, + 0.f, 0.f, 0.f, 4.f, 5.f, 0.f, 0.f, 0.f, 0.f, 6.f, 0.f, 0.f}); - auto input = NDArrayFactory::create('c', {6, 4}); - NDArray indices('c', {3, 3, 2}, {0.f,0.f, 1.f,1.f, 2.f,2.f, 3.f,3.f, 4.f,0.f, 5.f,1.f, 0.f,2.f, 1.f,3.f, 2.f,0.f}, sd::DataType::INT32); - auto updates = NDArrayFactory::create('c', {3,3}); - auto exp = NDArrayFactory::create('c', {6,4}, {1.f,0.f,7.f,0.f, 0.f,2.f,0.f,8.f, 9.f,0.f,3.f,0.f, 0.f,0.f,0.f,4.f, 5.f,0.f,0.f,0.f, 0.f,6.f,0.f,0.f}); - - input = 0.f; - updates.linspace(1.f); + input = 0.f; + updates.linspace(1.f); - sd::ops::scatter_nd_add op; - auto result = op.evaluate({&input, &indices, &updates}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - // z->printIndexedBuffer(); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::scatter_nd_add op; + auto result = op.evaluate({&input, &indices, &updates}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + // z->printIndexedBuffer(); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////// TEST_F(ParityOpsTests, scatterND_add_test3) { + auto input = NDArrayFactory::create('c', {6, 4}); + NDArray indices('c', {2, 3, 1}, {5.f, 1.f, 2.f, 3.f, 4.f, 0.f}, + sd::DataType::INT32); + auto updates = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create( + 'c', {6, 4}, + {21.f, 22.f, 23.f, 24.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, + 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, 1.f, 2.f, 3.f, 4.f}); - auto input = NDArrayFactory::create('c', {6, 4}); - NDArray indices('c', {2, 3, 1}, {5.f, 1.f, 2.f, 3.f, 4.f, 0.f}, sd::DataType::INT32); - auto updates = NDArrayFactory::create('c', {2,3,4}); - auto exp = NDArrayFactory::create('c', {6,4}, {21.f, 22.f, 23.f, 24.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f,13.f, 14.f, 15.f, 16.f,17.f, 18.f, 19.f, 20.f, 1.f, 2.f, 3.f, 4.f}); - - input = 0.f; - updates.linspace(1.f); - - sd::ops::scatter_nd_add op; - auto result = op.evaluate({&input, &indices, &updates}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + input = 0.f; + updates.linspace(1.f); - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::scatter_nd_add op; + auto result = op.evaluate({&input, &indices, &updates}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(ParityOpsTests, scatterND_add_test4) { - - auto input = NDArrayFactory::create('c', {6, 4, 5}); - NDArray indices('c', {3, 3, 2}, {0.f,0.f, 1.f,1.f, 2.f,2.f, 3.f,3.f, 4.f,0.f, 5.f,1.f, 0.f,2.f, 1.f,3.f, 2.f,0.f}, sd::DataType::INT32); - auto updates = NDArrayFactory::create('c', {3,3,5}); - auto exp = NDArrayFactory::create('c', {6,4,5}, {1.f, 2.f, 3.f, 4.f, 5.f, 0.f, 0.f, 0.f, 0.f, 0.f,31.f, 32.f, 33.f, 34.f, 35.f, 0.f, 0.f, 0.f, 0.f, 0.f, - 0.f, 0.f, 0.f, 0.f, 0.f, 6.f, 7.f, 8.f, 9.f, 10.f, 0.f, 0.f, 0.f, 0.f, 0.f,36.f, 37.f, 38.f, 39.f, 40.f, - 41.f, 42.f, 43.f, 44.f, 45.f, 0.f, 0.f, 0.f, 0.f, 0.f,11.f, 12.f, 13.f, 14.f, 15.f, 0.f, 0.f, 0.f, 0.f, 0.f, - 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f,16.f, 17.f, 18.f, 19.f, 20.f, - 21.f, 22.f, 23.f, 24.f, 25.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, - 0.f, 0.f, 0.f, 0.f, 0.f,26.f, 27.f, 28.f, 29.f, 30.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}); - input = 0.f; - updates.linspace(1.f); - - sd::ops::scatter_nd_add op; - auto result = op.evaluate({&input, &indices, &updates}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - - + auto input = NDArrayFactory::create('c', {6, 4, 5}); + NDArray indices('c', {3, 3, 2}, + {0.f, 0.f, 1.f, 1.f, 2.f, 2.f, 3.f, 3.f, 4.f, 0.f, 5.f, 1.f, + 0.f, 2.f, 1.f, 3.f, 2.f, 0.f}, + sd::DataType::INT32); + auto updates = NDArrayFactory::create('c', {3, 3, 5}); + auto exp = NDArrayFactory::create( + 'c', {6, 4, 5}, + {1.f, 2.f, 3.f, 4.f, 5.f, 0.f, 0.f, 0.f, 0.f, 0.f, 31.f, 32.f, + 33.f, 34.f, 35.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 6.f, 7.f, 8.f, 9.f, 10.f, 0.f, 0.f, 0.f, 0.f, 0.f, 36.f, + 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 43.f, 44.f, 45.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 11.f, 12.f, 13.f, 14.f, 15.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 16.f, 17.f, 18.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, + 25.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 26.f, 27.f, 28.f, + 29.f, 30.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}); + input = 0.f; + updates.linspace(1.f); + + sd::ops::scatter_nd_add op; + auto result = op.evaluate({&input, &indices, &updates}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(ParityOpsTests, scatterND_add_test5) { - - auto input = NDArrayFactory::create('c', {6,5,4,3,2}); - NDArray indices('c', {2,2,3}, {0.f,0.f,0.f, 1.f,1.f,1.f, 2.f,2.f,2.f, 3.f,3.f,3.f}, sd::DataType::INT32); - auto updates = NDArrayFactory::create('c', {2,2,3,2}); - auto exp = NDArrayFactory::create('c', {6,5,4,3,2}, { 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, -0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, -0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, -0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 7.f, 8.f, 9.f, 10.f,11.f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, -0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, -0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, -0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f,13.f, 14.f,15.f, 16.f,17.f, 18.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, -0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, -0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, -0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f,19.f, 20.f,21.f, 22.f,23.f, 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, -0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, -0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, -0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, -0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, -0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}); - input = 0.f; - updates.linspace(1.f); - - sd::ops::scatter_nd_add op; - auto result = op.evaluate({&input, &indices, &updates}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - - + auto input = NDArrayFactory::create('c', {6, 5, 4, 3, 2}); + NDArray indices('c', {2, 2, 3}, + {0.f, 0.f, 0.f, 1.f, 1.f, 1.f, 2.f, 2.f, 2.f, 3.f, 3.f, 3.f}, + sd::DataType::INT32); + auto updates = NDArrayFactory::create('c', {2, 2, 3, 2}); + auto exp = NDArrayFactory::create( + 'c', {6, 5, 4, 3, 2}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}); + input = 0.f; + updates.linspace(1.f); + + sd::ops::scatter_nd_add op; + auto result = op.evaluate({&input, &indices, &updates}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////// TEST_F(ParityOpsTests, scatterND_add_test6) { + auto input = NDArrayFactory::create('c', {6, 4}); + NDArray indices('c', {2, 3, 1}, {50.f, 1.f, 2.f, 3.f, 40.f, 0.f}, + sd::DataType::INT32); + auto updates = NDArrayFactory::create('c', {2, 3, 4}); + auto output = NDArrayFactory::create('c', {6, 4}); - auto input = NDArrayFactory::create('c', {6, 4}); - NDArray indices('c', {2, 3, 1}, {50.f, 1.f, 2.f, 3.f, 40.f, 0.f}, sd::DataType::INT32); - auto updates = NDArrayFactory::create('c', {2,3,4}); - auto output = NDArrayFactory::create('c', {6,4}); + sd::ops::scatter_nd_add op; - sd::ops::scatter_nd_add op; - - ASSERT_ANY_THROW(op.execute({&input, &indices, &updates}, {&output}, {}, {}, {false, true})); + ASSERT_ANY_THROW(op.execute({&input, &indices, &updates}, {&output}, {}, {}, + {false, true})); } //////////////////////////////////////////////////////////////////////// TEST_F(ParityOpsTests, scatterND_sub_test1) { + auto input = NDArrayFactory::create( + 'c', {8}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f}); + NDArray indices('c', {4, 1}, {4.f, 3.f, 1.f, 7.f}, sd::DataType::INT32); + auto updates = + NDArrayFactory::create('c', {4}, {9.f, 10.f, 11.f, 12.f}); + auto exp = NDArrayFactory::create( + 'c', {8}, {1.f, -9.f, 3.f, -6.f, -4.f, 6.f, 7.f, -4.f}); - auto input = NDArrayFactory::create('c', {8}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f}); - NDArray indices('c', {4, 1}, {4.f, 3.f, 1.f, 7.f}, sd::DataType::INT32); - auto updates = NDArrayFactory::create('c', {4}, {9.f, 10.f, 11.f, 12.f}); - auto exp = NDArrayFactory::create('c', {8}, {1.f, -9.f, 3.f, -6.f, -4.f, 6.f, 7.f, -4.f}); - - sd::ops::scatter_nd_sub op; - auto result = op.evaluate({&input, &indices, &updates}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::scatter_nd_sub op; + auto result = op.evaluate({&input, &indices, &updates}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////// TEST_F(ParityOpsTests, scatterND_sub_test2) { + auto input = NDArrayFactory::create('c', {6, 4}); + NDArray indices('c', {3, 3, 2}, + {0.f, 0.f, 1.f, 1.f, 2.f, 2.f, 3.f, 3.f, 4.f, 0.f, 5.f, 1.f, + 0.f, 2.f, 1.f, 3.f, 2.f, 0.f}, + sd::DataType::INT32); + auto updates = NDArrayFactory::create('c', {3, 3}); + auto exp = NDArrayFactory::create( + 'c', {6, 4}, + {-1.f, 0.f, -7.f, 0.f, 0.f, -2.f, 0.f, -8.f, -9.f, 0.f, -3.f, 0.f, + 0.f, 0.f, 0.f, -4.f, -5.f, 0.f, 0.f, 0.f, 0.f, -6.f, 0.f, 0.f}); - auto input = NDArrayFactory::create('c', {6, 4}); - NDArray indices('c', {3, 3, 2}, {0.f,0.f, 1.f,1.f, 2.f,2.f, 3.f,3.f, 4.f,0.f, 5.f,1.f, 0.f,2.f, 1.f,3.f, 2.f,0.f}, sd::DataType::INT32); - auto updates = NDArrayFactory::create('c', {3,3}); - auto exp = NDArrayFactory::create('c', {6,4}, {-1.f,0.f,-7.f,0.f, 0.f,-2.f,0.f,-8.f, -9.f,0.f,-3.f,0.f, 0.f,0.f,0.f,-4.f, -5.f,0.f,0.f,0.f, 0.f,-6.f,0.f,0.f}); - - input = 0.f; - updates.linspace(1.f); - - sd::ops::scatter_nd_sub op; - auto result = op.evaluate({&input, &indices, &updates}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - //exp.printIndexedBuffer("e"); - //z->printIndexedBuffer("z"); + input = 0.f; + updates.linspace(1.f); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::scatter_nd_sub op; + auto result = op.evaluate({&input, &indices, &updates}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + // exp.printIndexedBuffer("e"); + // z->printIndexedBuffer("z"); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////// TEST_F(ParityOpsTests, scatterND_sub_test3) { + auto input = NDArrayFactory::create('c', {6, 4}); + NDArray indices('c', {2, 3, 1}, {5.f, 1.f, 2.f, 3.f, 4.f, 0.f}, + sd::DataType::INT32); + auto updates = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create( + 'c', {6, 4}, {-21.f, -22.f, -23.f, -24., -5.f, -6.f, -7.f, -8., + -9.f, -10.f, -11.f, -12., -13.f, -14.f, -15.f, -16., + -17.f, -18.f, -19.f, -20., -1.f, -2.f, -3.f, -4.f}); - auto input = NDArrayFactory::create('c', {6, 4}); - NDArray indices('c', {2, 3, 1}, {5.f, 1.f, 2.f, 3.f,4.f, 0.f}, sd::DataType::INT32); - auto updates = NDArrayFactory::create('c', {2,3,4}); - auto exp = NDArrayFactory::create('c', {6,4}, {-21.f,-22.f,-23.f,-24., -5.f, -6.f, -7.f, -8., -9.f,-10.f,-11.f,-12., -13.f,-14.f,-15.f,-16., -17.f,-18.f,-19.f,-20., -1.f, -2.f, -3.f, -4.f}); - - input = 0.f; - updates.linspace(1.f); - - sd::ops::scatter_nd_sub op; - auto result = op.evaluate({&input, &indices, &updates}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); + input = 0.f; + updates.linspace(1.f); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::scatter_nd_sub op; + auto result = op.evaluate({&input, &indices, &updates}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(ParityOpsTests, scatterND_sub_test4) { - - auto input = NDArrayFactory::create('c', {6, 4, 5}); - NDArray indices('c', {3, 3, 2}, {0.f,0.f, 1.f,1.f, 2.f,2.f, 3.f,3.f, 4.f,0.f, 5.f,1.f, 0.f,2.f, 1.f,3.f, 2.f,0.f}, sd::DataType::INT32); - auto updates = NDArrayFactory::create('c', {3,3,5}); - auto exp = NDArrayFactory::create('c', {6,4,5}, {-1.f, -2.f, -3.f, -4.f, -5.f, 0.f, 0.f, 0.f, 0.f, 0.f,-31.f, -32.f, -33.f, -34.f, -35.f, 0.f, 0.f, 0.f, 0.f, 0.f, - 0.f, 0.f, 0.f, 0.f, 0.f, -6.f, -7.f, -8.f, -9.f, -10.f, 0.f, 0.f, 0.f, 0.f, 0.f,-36.f, -37.f, -38.f, -39.f, -40.f, - -41.f, -42.f, -43.f, -44.f, -45.f, 0.f, 0.f, 0.f, 0.f, 0.f,-11.f, -12.f, -13.f, -14.f, -15.f, 0.f, 0.f, 0.f, 0.f, 0.f, - 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f,-16.f, -17.f, -18.f, -19.f, -20.f, - -21.f, -22.f, -23.f, -24.f, -25.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, - 0.f, 0.f, 0.f, 0.f, 0.f,-26.f, -27.f, -28.f, -29.f, -30.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}); - input = 0.f; - updates.linspace(1.f); - - sd::ops::scatter_nd_sub op; - auto result = op.evaluate({&input, &indices, &updates}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - - + auto input = NDArrayFactory::create('c', {6, 4, 5}); + NDArray indices('c', {3, 3, 2}, + {0.f, 0.f, 1.f, 1.f, 2.f, 2.f, 3.f, 3.f, 4.f, 0.f, 5.f, 1.f, + 0.f, 2.f, 1.f, 3.f, 2.f, 0.f}, + sd::DataType::INT32); + auto updates = NDArrayFactory::create('c', {3, 3, 5}); + auto exp = NDArrayFactory::create( + 'c', {6, 4, 5}, + {-1.f, -2.f, -3.f, -4.f, -5.f, 0.f, 0.f, 0.f, 0.f, 0.f, + -31.f, -32.f, -33.f, -34.f, -35.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, -6.f, -7.f, -8.f, -9.f, -10.f, + 0.f, 0.f, 0.f, 0.f, 0.f, -36.f, -37.f, -38.f, -39.f, -40.f, + -41.f, -42.f, -43.f, -44.f, -45.f, 0.f, 0.f, 0.f, 0.f, 0.f, + -11.f, -12.f, -13.f, -14.f, -15.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, -16.f, -17.f, -18.f, -19.f, -20.f, + -21.f, -22.f, -23.f, -24.f, -25.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, -26.f, -27.f, -28.f, -29.f, -30.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}); + input = 0.f; + updates.linspace(1.f); + + sd::ops::scatter_nd_sub op; + auto result = op.evaluate({&input, &indices, &updates}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(ParityOpsTests, scatterND_sub_test5) { - - auto input = NDArrayFactory::create('c', {6,5,4,3,2}); - NDArray indices('c', {2,2,3}, {0.f,0.f,0.f, 1.f,1.f,1.f, 2.f,2.f,2.f, 3.f,3.f,3.f}, sd::DataType::INT32); - auto updates = NDArrayFactory::create('c', {2,2,3,2}); - auto exp = NDArrayFactory::create('c', {6,5,4,3,2}, { -1.f, -2.f, -3.f, -4.f, -5.f, -6.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, -0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, -0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, -0.f, 0.f, 0.f, 0.f, 0.f, 0.f, -7.f, -8.f, -9.f, -10.f,-11.f, -12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, -0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, -0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, -0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f,-13.f, -14.f,-15.f, -16.f,-17.f, -18.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, -0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, -0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, -0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f,-19.f, -20.f,-21.f, -22.f,-23.f,-24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, -0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, -0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, -0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, -0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, -0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}); - input = 0.f; - updates.linspace(1.f); - - sd::ops::scatter_nd_sub op; - auto result = op.evaluate({&input, &indices, &updates}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - - + auto input = NDArrayFactory::create('c', {6, 5, 4, 3, 2}); + NDArray indices('c', {2, 2, 3}, + {0.f, 0.f, 0.f, 1.f, 1.f, 1.f, 2.f, 2.f, 2.f, 3.f, 3.f, 3.f}, + sd::DataType::INT32); + auto updates = NDArrayFactory::create('c', {2, 2, 3, 2}); + auto exp = NDArrayFactory::create( + 'c', {6, 5, 4, 3, 2}, + {-1.f, -2.f, -3.f, -4.f, -5.f, -6.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + -7.f, -8.f, -9.f, -10.f, -11.f, -12.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + -13.f, -14.f, -15.f, -16.f, -17.f, -18.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + -19.f, -20.f, -21.f, -22.f, -23.f, -24.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}); + input = 0.f; + updates.linspace(1.f); + + sd::ops::scatter_nd_sub op; + auto result = op.evaluate({&input, &indices, &updates}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - //////////////////////////////////////////////////////////////////////// TEST_F(ParityOpsTests, scatterND_update_test1) { + auto input = NDArrayFactory::create( + 'c', {8}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f}); + NDArray indices('c', {4, 1}, {4.f, 3.f, 1.f, 7.f}, sd::DataType::INT32); + auto updates = + NDArrayFactory::create('c', {4}, {9.f, 10.f, 11.f, 12.f}); + auto exp = NDArrayFactory::create( + 'c', {8}, {1.f, 11.f, 3.f, 10.f, 9.f, 6.f, 7.f, 12.f}); - auto input = NDArrayFactory::create('c', {8}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f}); - NDArray indices('c', {4, 1}, {4.f, 3.f, 1.f, 7.f}, sd::DataType::INT32); - auto updates = NDArrayFactory::create('c', {4}, {9.f, 10.f, 11.f, 12.f}); - auto exp = NDArrayFactory::create('c', {8}, {1.f, 11.f, 3.f, 10.f, 9.f, 6.f, 7.f, 12.f}); - - sd::ops::scatter_nd_update op; - auto result = op.evaluate({&input, &indices, &updates}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::scatter_nd_update op; + auto result = op.evaluate({&input, &indices, &updates}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////// TEST_F(ParityOpsTests, scatterND_update_test2) { + auto input = NDArrayFactory::create('c', {6, 4}); + NDArray indices('c', {3, 3, 2}, + {0.f, 0.f, 1.f, 1.f, 2.f, 2.f, 3.f, 3.f, 4.f, 0.f, 5.f, 1.f, + 0.f, 2.f, 1.f, 3.f, 2.f, 0.f}, + sd::DataType::INT32); + auto updates = NDArrayFactory::create('c', {3, 3}); + auto exp = NDArrayFactory::create( + 'c', {6, 4}, + {1.f, -1.f, 7.f, -1.f, -1.f, 2.f, -1.f, 8.f, 9.f, -1.f, 3.f, -1.f, + -1.f, -1.f, -1.f, 4.f, 5.f, -1.f, -1.f, -1.f, -1.f, 6.f, -1.f, -1.f}); - auto input = NDArrayFactory::create('c', {6, 4}); - NDArray indices('c', {3, 3, 2}, {0.f,0.f, 1.f,1.f, 2.f,2.f, 3.f,3.f, 4.f,0.f, 5.f,1.f, 0.f,2.f, 1.f,3.f, 2.f,0.f}, sd::DataType::INT32); - auto updates = NDArrayFactory::create('c', {3,3}); - auto exp = NDArrayFactory::create('c', {6,4}, {1.f,-1.f,7.f,-1.f, -1.f,2.f,-1.f,8.f, 9.f,-1.f,3.f,-1.f, -1.f,-1.f,-1.f,4.f, 5.f,-1.f,-1.f,-1.f, -1.f,6.f,-1.f,-1.f}); - - input = -1.f; - updates.linspace(1.f); + input = -1.f; + updates.linspace(1.f); - sd::ops::scatter_nd_update op; - auto result = op.evaluate({&input, &indices, &updates}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - // z->printIndexedBuffer(); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::scatter_nd_update op; + auto result = op.evaluate({&input, &indices, &updates}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + // z->printIndexedBuffer(); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////// TEST_F(ParityOpsTests, scatterND_update_test3) { + auto input = NDArrayFactory::create('c', {6, 4}); + NDArray indices('c', {2, 3, 1}, {5.f, 1.f, 2.f, 3.f, 4.f, 0.f}, + sd::DataType::INT32); + auto updates = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create( + 'c', {6, 4}, + { + 21.f, 22.f, 23.f, 24.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, + 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, 1.f, 2.f, 3.f, 4.f, + }); - auto input = NDArrayFactory::create('c', {6, 4}); - NDArray indices('c', {2, 3, 1}, {5.f, 1.f, 2.f, 3.f, 4.f, 0.f}, sd::DataType::INT32); - auto updates = NDArrayFactory::create('c', {2,3,4}); - auto exp = NDArrayFactory::create('c', {6,4}, {21.f, 22.f, 23.f, 24.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f,13.f, 14.f, 15.f, 16.f,17.f, 18.f, 19.f, 20.f, 1.f, 2.f, 3.f, 4.f,}); - - input = -1.f; - updates.linspace(1.f); + input = -1.f; + updates.linspace(1.f); - sd::ops::scatter_nd_update op; - auto result = op.evaluate({&input, &indices, &updates}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - // z->printBuffer(); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::scatter_nd_update op; + auto result = op.evaluate({&input, &indices, &updates}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + // z->printBuffer(); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(ParityOpsTests, scatterND_update_test4) { - - auto input = NDArrayFactory::create('c', {6, 4, 5}); - NDArray indices('c', {3, 3, 2}, {0.f,0.f, 1.f,1.f, 2.f,2.f, 3.f,3.f, 4.f,0.f, 5.f,1.f, 0.f,2.f, 1.f,3.f, 2.f,0.f}, sd::DataType::INT32); - auto updates = NDArrayFactory::create('c', {3,3,5}); - auto exp = NDArrayFactory::create('c', {6,4,5}, {1.f, 2.f, 3.f, 4.f, 5.f, -1.f, -1.f, -1.f, -1.f, -1.f,31.f, 32.f, 33.f, 34.f, 35.f, -1.f, -1.f, -1.f, -1.f, -1.f, - -1.f, -1.f, -1.f, -1.f, -1.f, 6.f, 7.f, 8.f, 9.f, 10.f, -1.f, -1.f, -1.f, -1.f, -1.f,36.f, 37.f, 38.f, 39.f, 40.f, - 41.f, 42.f, 43.f, 44.f, 45.f, -1.f, -1.f, -1.f, -1.f, -1.f,11.f, 12.f, 13.f, 14.f, 15.f, -1.f, -1.f, -1.f, -1.f, -1.f, - -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f,16.f, 17.f, 18.f, 19.f, 20.f, - 21.f, 22.f, 23.f, 24.f, 25.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, - -1.f, -1.f, -1.f, -1.f, -1.f,26.f, 27.f, 28.f, 29.f, 30.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f}); - input = -1.f; - updates.linspace(1.f); - - sd::ops::scatter_nd_update op; - auto result = op.evaluate({&input, &indices, &updates}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - - + auto input = NDArrayFactory::create('c', {6, 4, 5}); + NDArray indices('c', {3, 3, 2}, + {0.f, 0.f, 1.f, 1.f, 2.f, 2.f, 3.f, 3.f, 4.f, 0.f, 5.f, 1.f, + 0.f, 2.f, 1.f, 3.f, 2.f, 0.f}, + sd::DataType::INT32); + auto updates = NDArrayFactory::create('c', {3, 3, 5}); + auto exp = NDArrayFactory::create( + 'c', {6, 4, 5}, + {1.f, 2.f, 3.f, 4.f, 5.f, -1.f, -1.f, -1.f, -1.f, -1.f, 31.f, 32.f, + 33.f, 34.f, 35.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, 6.f, 7.f, 8.f, 9.f, 10.f, -1.f, -1.f, -1.f, -1.f, -1.f, 36.f, + 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 43.f, 44.f, 45.f, -1.f, -1.f, -1.f, + -1.f, -1.f, 11.f, 12.f, 13.f, 14.f, 15.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, 16.f, 17.f, 18.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, + 25.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, 26.f, 27.f, 28.f, + 29.f, 30.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f}); + input = -1.f; + updates.linspace(1.f); + + sd::ops::scatter_nd_update op; + auto result = op.evaluate({&input, &indices, &updates}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(ParityOpsTests, scatterND_update_test5) { - - auto input = NDArrayFactory::create('c', {6,5,4,3,2}); - NDArray indices('c', {2,2,3}, {0.f,0.f,0.f, 1.f,1.f,1.f, 2.f,2.f,2.f, 3.f,3.f,3.f}, sd::DataType::INT32); - auto updates = NDArrayFactory::create('c', {2,2,3,2}); - auto exp = NDArrayFactory::create('c', {6,5,4,3,2}, { 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, --1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, --1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, --1.f, -1.f, -1.f, -1.f, -1.f, -1.f, 7.f, 8.f, 9.f, 10.f,11.f, 12.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, --1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, --1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, --1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f,13.f, 14.f,15.f, 16.f,17.f, 18.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, --1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, --1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, --1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f,19.f, 20.f,21.f, 22.f,23.f, 24.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, --1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, --1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, --1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, --1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, --1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f}); - input = -1.f; - updates.linspace(1.f); - - sd::ops::scatter_nd_update op; - auto result = op.evaluate({&input, &indices, &updates}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - - + auto input = NDArrayFactory::create('c', {6, 5, 4, 3, 2}); + NDArray indices('c', {2, 2, 3}, + {0.f, 0.f, 0.f, 1.f, 1.f, 1.f, 2.f, 2.f, 2.f, 3.f, 3.f, 3.f}, + sd::DataType::INT32); + auto updates = NDArrayFactory::create('c', {2, 2, 3, 2}); + auto exp = NDArrayFactory::create( + 'c', {6, 5, 4, 3, 2}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f}); + input = -1.f; + updates.linspace(1.f); + + sd::ops::scatter_nd_update op; + auto result = op.evaluate({&input, &indices, &updates}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////// TEST_F(ParityOpsTests, scatterND_update_test6) { + auto input = NDArrayFactory::create('c', {6, 4}); + NDArray indices('c', {3, 3, 2}, + {0.f, 0.f, 10.f, 1.f, 20.f, 2.f, 30.f, 3.f, 40.f, 0.f, 50.f, + 1.f, 0.f, 2.f, 1.f, 3.f, 2.f, 0.f}, + sd::DataType::INT32); + auto updates = NDArrayFactory::create('c', {3, 3}); + auto output = NDArrayFactory::create('c', {6, 4}); - auto input = NDArrayFactory::create('c', {6, 4}); - NDArray indices('c', {3, 3, 2}, {0.f,0.f, 10.f,1.f, 20.f,2.f, 30.f,3.f, 40.f,0.f, 50.f,1.f, 0.f,2.f, 1.f,3.f, 2.f,0.f}, sd::DataType::INT32); - auto updates = NDArrayFactory::create('c', {3,3}); - auto output = NDArrayFactory::create('c', {6,4}); + sd::ops::scatter_nd_update op; - sd::ops::scatter_nd_update op; - - ASSERT_ANY_THROW(op.execute({&input, &indices, &updates}, {&output}, {}, {}, {true, true})); + ASSERT_ANY_THROW(op.execute({&input, &indices, &updates}, {&output}, {}, {}, + {true, true})); } ////////////////////////////////////////////////////////////////////// TEST_F(ParityOpsTests, scatter_update_1) { + NDArray x('c', {2, 2}, {1, 2, 3, 4}, sd::DataType::INT32); + NDArray updates('c', {2, 2}, {10, 20, 30, 40}, sd::DataType::INT32); - NDArray x('c', {2,2}, {1,2,3,4}, sd::DataType::INT32); - NDArray updates('c', {2,2}, {10,20,30,40}, sd::DataType::INT32); - - NDArray exp('c', {2,2}, {30,40,10,20}, sd::DataType::INT32); - - sd::ops::scatter_update op; - auto results = op.evaluate({&x, &updates}, {}, {6, 1,1, 2,1,0}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - // x.printBuffer(); + NDArray exp('c', {2, 2}, {30, 40, 10, 20}, sd::DataType::INT32); - ASSERT_TRUE(exp.isSameShape(x)); - ASSERT_TRUE(exp.equalsTo(x)); + sd::ops::scatter_update op; + auto results = op.evaluate({&x, &updates}, {}, {6, 1, 1, 2, 1, 0}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + // x.printBuffer(); + ASSERT_TRUE(exp.isSameShape(x)); + ASSERT_TRUE(exp.equalsTo(x)); } ////////////////////////////////////////////////////////////////////// TEST_F(ParityOpsTests, scatter_update_2) { + NDArray x('c', {2, 2}, {1, 2, 3, 4}, sd::DataType::INT32); + NDArray updates('c', {2, 2}, {10, 20, 30, 40}, sd::DataType::INT32); - NDArray x('c', {2,2}, {1,2,3,4}, sd::DataType::INT32); - NDArray updates('c', {2,2}, {10,20,30,40}, sd::DataType::INT32); - - NDArray exp('c', {2,2}, {20,10,40,30}, sd::DataType::INT32); - - sd::ops::scatter_update op; - auto results = op.evaluate({&x, &updates}, {}, {6, 1,0, 2,1,0}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + NDArray exp('c', {2, 2}, {20, 10, 40, 30}, sd::DataType::INT32); - ASSERT_TRUE(exp.isSameShape(x)); - ASSERT_TRUE(exp.equalsTo(x)); + sd::ops::scatter_update op; + auto results = op.evaluate({&x, &updates}, {}, {6, 1, 0, 2, 1, 0}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_TRUE(exp.isSameShape(x)); + ASSERT_TRUE(exp.equalsTo(x)); } ////////////////////////////////////////////////////////////////////// TEST_F(ParityOpsTests, scatter_update_3) { + NDArray x('c', {2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8}, sd::DataType::INT32); + NDArray updates('c', {2, 2, 2}, {10, 20, 30, 40, 50, 60, 70, 80}, + sd::DataType::INT32); - NDArray x('c', {2,2,2}, {1,2,3,4,5,6,7,8}, sd::DataType::INT32); - NDArray updates('c', {2,2,2}, {10,20,30,40,50,60,70,80}, sd::DataType::INT32); + NDArray exp('c', {2, 2, 2}, {50, 60, 70, 80, 10, 20, 30, 40}, + sd::DataType::INT32); - NDArray exp('c', {2,2,2}, {50,60,70,80,10,20,30,40}, sd::DataType::INT32); - - sd::ops::scatter_update op; - auto results = op.evaluate({&x, &updates}, {}, {6, 2,1,2, 2,1,0}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - ASSERT_TRUE(exp.isSameShape(x)); - ASSERT_TRUE(exp.equalsTo(x)); + sd::ops::scatter_update op; + auto results = op.evaluate({&x, &updates}, {}, {6, 2, 1, 2, 2, 1, 0}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_TRUE(exp.isSameShape(x)); + ASSERT_TRUE(exp.equalsTo(x)); } ////////////////////////////////////////////////////////////////////// TEST_F(ParityOpsTests, scatter_update_4) { + NDArray x('c', {2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8}, sd::DataType::INT32); + NDArray updates('c', {2, 2, 2}, {10, 20, 30, 40, 50, 60, 70, 80}, + sd::DataType::INT32); - NDArray x('c', {2,2,2}, {1,2,3,4,5,6,7,8}, sd::DataType::INT32); - NDArray updates('c', {2,2,2}, {10,20,30,40,50,60,70,80}, sd::DataType::INT32); - - NDArray exp('c', {2,2,2}, {20,2,3,10,60,6,7,50}, sd::DataType::INT32); - - sd::ops::scatter_update op; - auto results = op.evaluate({&x, &updates}, {}, {6, 1,0, 2,3,0}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + NDArray exp('c', {2, 2, 2}, {20, 2, 3, 10, 60, 6, 7, 50}, + sd::DataType::INT32); - ASSERT_TRUE(exp.isSameShape(x)); - ASSERT_TRUE(exp.equalsTo(x)); + sd::ops::scatter_update op; + auto results = op.evaluate({&x, &updates}, {}, {6, 1, 0, 2, 3, 0}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_TRUE(exp.isSameShape(x)); + ASSERT_TRUE(exp.equalsTo(x)); } diff --git a/libnd4j/tests_cpu/layers_tests/PerformanceTests.cpp b/libnd4j/tests_cpu/layers_tests/PerformanceTests.cpp index c6155eb0c235..32d87ba4f705 100644 --- a/libnd4j/tests_cpu/layers_tests/PerformanceTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/PerformanceTests.cpp @@ -18,127 +18,129 @@ // @author raver119@gmail.com // -#include "testlayers.h" +#include #include -#include #include -#include #include -#include -#include -#include -#include -#include +#include +#include +#include #include -#include #include +#include +#include #include - -#include +#include +#include +#include +#include +#include #include -#include -#include -#include +#include #include #include -#include -#include +#include +#include + +#include "testlayers.h" using namespace sd; using namespace sd::graph; class PerformanceTests : public testing::Test { -public: - int numIterations = 100; + public: + int numIterations = 100; - PerformanceTests() { - samediff::ThreadPool::getInstance(); - } + PerformanceTests() { samediff::ThreadPool::getInstance(); } }; - #ifdef RELEASE_BUILD TEST_F(PerformanceTests, test_matmul_c_f_1) { - int iterations = 500; - std::vector valuesC, valuesF; - for (int e = 0; e < iterations; e++) { - auto xc = NDArrayFactory::create('c', {512, 2048}); - auto yc = NDArrayFactory::create('c', {2048, 512}); - auto zc = NDArrayFactory::create('c', {512, 512}); + int iterations = 500; + std::vector valuesC, valuesF; + for (int e = 0; e < iterations; e++) { + auto xc = NDArrayFactory::create('c', {512, 2048}); + auto yc = NDArrayFactory::create('c', {2048, 512}); + auto zc = NDArrayFactory::create('c', {512, 512}); - auto xf = NDArrayFactory::create('f', {512, 2048}); - auto yf = NDArrayFactory::create('f', {2048, 512}); - auto zf = NDArrayFactory::create('f', {512, 512}); + auto xf = NDArrayFactory::create('f', {512, 2048}); + auto yf = NDArrayFactory::create('f', {2048, 512}); + auto zf = NDArrayFactory::create('f', {512, 512}); - auto warm = xc.like(); - warm.linspace(1.0); + auto warm = xc.like(); + warm.linspace(1.0); - //zc.linspace(1.0); - //zf.linspace(1.0); + // zc.linspace(1.0); + // zf.linspace(1.0); - sd::ops::matmul op; + sd::ops::matmul op; - auto timeStartF = std::chrono::system_clock::now(); + auto timeStartF = std::chrono::system_clock::now(); - op.execute({&xf, &yf}, {&zf}); + op.execute({&xf, &yf}, {&zf}); - auto timeEndF = std::chrono::system_clock::now(); - auto outerTimeF = std::chrono::duration_cast(timeEndF - timeStartF).count(); + auto timeEndF = std::chrono::system_clock::now(); + auto outerTimeF = std::chrono::duration_cast( + timeEndF - timeStartF) + .count(); + auto timeStartC = std::chrono::system_clock::now(); - auto timeStartC = std::chrono::system_clock::now(); + op.execute({&xc, &yc}, {&zc}); - op.execute({&xc, &yc}, {&zc}); + auto timeEndC = std::chrono::system_clock::now(); + auto outerTimeC = std::chrono::duration_cast( + timeEndC - timeStartC) + .count(); - auto timeEndC = std::chrono::system_clock::now(); - auto outerTimeC = std::chrono::duration_cast(timeEndC - timeStartC).count(); + valuesF.emplace_back(outerTimeF); + valuesC.emplace_back(outerTimeC); + } - valuesF.emplace_back(outerTimeF); - valuesC.emplace_back(outerTimeC); - } + std::sort(valuesC.begin(), valuesC.end()); + std::sort(valuesF.begin(), valuesF.end()); - std::sort(valuesC.begin(), valuesC.end()); - std::sort(valuesF.begin(), valuesF.end()); - - - nd4j_printf("Median time C: [%lld]; Median time F: [%lld];", valuesC[valuesC.size() / 2], valuesF[valuesF.size() / 2]); + nd4j_printf("Median time C: [%lld]; Median time F: [%lld];", + valuesC[valuesC.size() / 2], valuesF[valuesF.size() / 2]); } TEST_F(PerformanceTests, test_maxpooling2d_1) { - std::vector valuesX; - // auto x = NDArrayFactory::create('c', {32, 3, 224, 224}); - // auto z = NDArrayFactory::create('c', {32, 3, 224, 224}); - auto x = NDArrayFactory::create('c', {8, 3, 64, 64}); - auto z = NDArrayFactory::create('c', {8, 3, 64, 64}); - x.linspace(1.0f); - Nd4jLong k = 5; - - - Nd4jLong iArgs[] {k,k, 1,1, 0,0, 1,1, 1}; - Context ctx(1); - ctx.setInputArray(0, &x); - ctx.setOutputArray(0, &z); - ctx.setIArguments(iArgs, 9); - - sd::ops::maxpool2d op; - - for (int i = 0; i < numIterations; i++) { - auto timeStart = std::chrono::system_clock::now(); - - op.execute(&ctx); - - auto timeEnd = std::chrono::system_clock::now(); - auto outerTime = std::chrono::duration_cast(timeEnd - timeStart).count(); - valuesX.emplace_back(outerTime); - - if ((i + 1) % 1000 == 0) - nd4j_printf("Iteration %i finished...\n", i + 1); - } - - std::sort(valuesX.begin(), valuesX.end()); - nd4j_printf("Execution time: %lld; Min: %lld; Max: %lld;\n", valuesX[valuesX.size() / 2], valuesX[0], valuesX[valuesX.size() - 1]); + std::vector valuesX; + // auto x = NDArrayFactory::create('c', {32, 3, 224, 224}); + // auto z = NDArrayFactory::create('c', {32, 3, 224, 224}); + auto x = NDArrayFactory::create('c', {8, 3, 64, 64}); + auto z = NDArrayFactory::create('c', {8, 3, 64, 64}); + x.linspace(1.0f); + Nd4jLong k = 5; + + Nd4jLong iArgs[]{k, k, 1, 1, 0, 0, 1, 1, 1}; + Context ctx(1); + ctx.setInputArray(0, &x); + ctx.setOutputArray(0, &z); + ctx.setIArguments(iArgs, 9); + + sd::ops::maxpool2d op; + + for (int i = 0; i < numIterations; i++) { + auto timeStart = std::chrono::system_clock::now(); + + op.execute(&ctx); + + auto timeEnd = std::chrono::system_clock::now(); + auto outerTime = std::chrono::duration_cast( + timeEnd - timeStart) + .count(); + valuesX.emplace_back(outerTime); + + if ((i + 1) % 1000 == 0) nd4j_printf("Iteration %i finished...\n", i + 1); + } + + std::sort(valuesX.begin(), valuesX.end()); + nd4j_printf("Execution time: %lld; Min: %lld; Max: %lld;\n", + valuesX[valuesX.size() / 2], valuesX[0], + valuesX[valuesX.size() - 1]); } #endif \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp b/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp index e07a0496d57d..9deb691507a2 100644 --- a/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp @@ -19,28 +19,27 @@ // Created by raver119 on 20.11.17. // -#include "testlayers.h" #include -#include #include -#include #include -#include -#include -#include -#include -#include +#include +#include +#include +#include #include -#include #include +#include +#include #include +#include +#include +#include +#include #include - -#include +#include +#include #include -#include -#include -#include +#include #include #include #include @@ -50,230 +49,222 @@ #include #include +#include "testlayers.h" + using namespace sd; using namespace sd::graph; class PlaygroundTests : public testing::Test { -public: - int numIterations = 3; - int poolSize = 10; + public: + int numIterations = 3; + int poolSize = 10; - PlaygroundTests() { - } + PlaygroundTests() { + + } }; TEST_F(PlaygroundTests, test_avx) { - nd4j_printf("Optimal level: %i; Binary level: %i;\n", ::optimalLevel(), ::binaryLevel()); + nd4j_printf("Optimal level: %i; Binary level: %i;\n", ::optimalLevel(), + ::binaryLevel()); } - TEST_F(PlaygroundTests, test_biasAdd_1) { - auto x = NDArrayFactory::create('c', {512, 3072}); - auto y = NDArrayFactory::create('c', {3072}); + auto x = NDArrayFactory::create('c', {512, 3072}); + auto y = NDArrayFactory::create('c', {3072}); - std::vector values; + std::vector values; - sd::ops::biasadd op; + sd::ops::biasadd op; - for (int e = 0; e < 100; e++) { - auto timeStart = std::chrono::system_clock::now(); + for (int e = 0; e < 100; e++) { + auto timeStart = std::chrono::system_clock::now(); - op.execute({&x, &y}, {&x}); + op.execute({&x, &y}, {&x}); - auto timeEnd = std::chrono::system_clock::now(); - auto outerTime = std::chrono::duration_cast(timeEnd - timeStart).count(); - values.emplace_back(outerTime); - } + auto timeEnd = std::chrono::system_clock::now(); + auto outerTime = std::chrono::duration_cast( + timeEnd - timeStart) + .count(); + values.emplace_back(outerTime); + } - std::sort(values.begin(), values.end()); + std::sort(values.begin(), values.end()); - nd4j_printf("Time: %lld us;\n", values[values.size() / 2]); + nd4j_printf("Time: %lld us;\n", values[values.size() / 2]); } - TEST_F(PlaygroundTests, test_bert_full_1) { -#ifdef _RELEASE - - // this test will run ONLY if this model exists - if (sd::graph::getFileSize("/home/raver119/Downloads/BertFull/model.fb") < 0) - return; +//#ifdef _RELEASE - auto graph = GraphExecutioner::importFromFlatBuffers("/home/raver119/Downloads/BertFull/model.fb"); + // this test will run ONLY if this model exists + if (!FileUtils::fileExists("/home/raver119/Downloads/BertFull/model.fb")) + return; - auto t = NDArrayFactory::fromNpyFile("/home/raver119/Downloads/BertFull/in0_IteratorGetNext.npy"); - auto u = NDArrayFactory::fromNpyFile("/home/raver119/Downloads/BertFull/in1_IteratorGetNext_1.npy"); - auto v = NDArrayFactory::fromNpyFile("/home/raver119/Downloads/BertFull/in2_IteratorGetNext_4.npy"); - auto z = NDArrayFactory::fromNpyFile("/home/raver119/Downloads/BertFull/out_loss-Softmax.npy"); + auto graph = + Graph::fromFlatBuffers("/home/raver119/Downloads/BertFull/model.fb"); - //graph->printOut(); + auto t = NDArrayFactory::fromNpyFile( + "/home/raver119/Downloads/BertFull/in0_IteratorGetNext.npy"); + auto u = NDArrayFactory::fromNpyFile( + "/home/raver119/Downloads/BertFull/in1_IteratorGetNext_1.npy"); + auto v = NDArrayFactory::fromNpyFile( + "/home/raver119/Downloads/BertFull/in2_IteratorGetNext_4.npy"); + auto z = NDArrayFactory::fromNpyFile( + "/home/raver119/Downloads/BertFull/out_loss-Softmax.npy"); - graph->tagInplaceNodes(); - - graph->getVariableSpace()->putVariable(658,0, t); - graph->getVariableSpace()->putVariable(659,0, u); - graph->getVariableSpace()->putVariable(660,0, v); + graph.printOut(); /* - // validating graph now - auto status = GraphExecutioner::execute(graph); - ASSERT_EQ(Status::OK(), status); - ASSERT_TRUE(graph->getVariableSpace()->hasVariable(1620)); + // validating graph now + auto results = graph.execute({{"IteratorGetNext", t}, {"IteratorGetNext:1", u}, {"IteratorGetNext:4", v}}, {"loss/Softmax"}); + ASSERT_EQ(z, results["loss/Softmax"]); +*/ - auto array = graph->getVariableSpace()->getVariable(1620)->getNDArray(); - ASSERT_EQ(z, *array); -*/ +// sd::Environment::getInstance().setProfiling(true); + //auto profile = GraphProfilingHelper::profile(graph, 1); - sd::Environment::getInstance().setProfiling(true); - auto profile = GraphProfilingHelper::profile(graph, 1); + //profile->printOut(); - profile->printOut(); + //sd::Environment::getInstance().setProfiling(false); + //delete profile; - sd::Environment::getInstance().setProfiling(false); - delete profile; + /* + std::vector values; -/* - std::vector values; + for (int e = 0; e < 1; e++) { + auto timeStart = std::chrono::system_clock::now(); - for (int e = 0; e < 1; e++) { - auto timeStart = std::chrono::system_clock::now(); + GraphExecutioner::execute(graph); - GraphExecutioner::execute(graph); + auto timeEnd = std::chrono::system_clock::now(); + auto outerTime = + std::chrono::duration_cast(timeEnd - + timeStart).count(); values.emplace_back(outerTime); + } - auto timeEnd = std::chrono::system_clock::now(); - auto outerTime = std::chrono::duration_cast(timeEnd - timeStart).count(); - values.emplace_back(outerTime); - } + std::sort(values.begin(), values.end()); - std::sort(values.begin(), values.end()); + nd4j_printf("Time: %lld us;\n", values[values.size() / 2]); + */ - nd4j_printf("Time: %lld us;\n", values[values.size() / 2]); -*/ - delete graph; -#endif +//#endif } - TEST_F(PlaygroundTests, test_bert_1) { #ifdef _RELEASE - // this test will run ONLY if this model exists - if (sd::graph::getFileSize("/home/raver119/Downloads/Bert_minimal_model/bert_minimal_model.fb") < 0) - return; + // this test will run ONLY if this model exists + if (!FileUtils::fileExists( + "/home/raver119/Downloads/Bert_minimal_model/bert_minimal_model.fb")) + return; - auto graph = GraphExecutioner::importFromFlatBuffers("/home/raver119/Downloads/Bert_minimal_model/bert_minimal_model.fb"); + auto graph = Graph::fromFlatBuffers( + "/home/raver119/Downloads/Bert_minimal_model/bert_minimal_model.fb"); - auto t = NDArrayFactory::fromNpyFile("/home/raver119/Downloads/Bert_minimal_model/bert_minimal_input_IteratorGetNext.numpy"); - auto u = NDArrayFactory::fromNpyFile("/home/raver119/Downloads/Bert_minimal_model/bert_minimal_input_IteratorGetNext_1.numpy"); - auto v = NDArrayFactory::fromNpyFile("/home/raver119/Downloads/Bert_minimal_model/bert_minimal_input_IteratorGetNext_4.numpy"); - auto z = NDArrayFactory::fromNpyFile("/home/raver119/Downloads/Bert_minimal_model/bert_minimal_model_output.numpy"); + auto t = NDArrayFactory::fromNpyFile( + "/home/raver119/Downloads/Bert_minimal_model/" + "bert_minimal_input_IteratorGetNext.numpy"); + auto u = NDArrayFactory::fromNpyFile( + "/home/raver119/Downloads/Bert_minimal_model/" + "bert_minimal_input_IteratorGetNext_1.numpy"); + auto v = NDArrayFactory::fromNpyFile( + "/home/raver119/Downloads/Bert_minimal_model/" + "bert_minimal_input_IteratorGetNext_4.numpy"); + auto z = NDArrayFactory::fromNpyFile( + "/home/raver119/Downloads/Bert_minimal_model/" + "bert_minimal_model_output.numpy"); - //graph->printOut(); + // graph->printOut(); - graph->tagInplaceNodes(); - graph->getVariableSpace()->putVariable(85,0, t); - graph->getVariableSpace()->putVariable(86,0, u); - graph->getVariableSpace()->putVariable(87,0, v); + graph.variableSpace().putVariable(85, 0, t); + graph.variableSpace().putVariable(86, 0, u); + graph.variableSpace().putVariable(87, 0, v); -/* - // validating graph now - auto status = GraphExecutioner::execute(graph); - ASSERT_EQ(Status::OK(), status); - ASSERT_TRUE(graph->getVariableSpace()->hasVariable(198)); + /* + // validating graph now + auto status = GraphExecutioner::execute(graph); + ASSERT_EQ(Status::OK(), status); + ASSERT_TRUE(graph->variableSpace()->hasVariable(198)); - auto array = graph->getVariableSpace()->getVariable(198)->getNDArray(); - ASSERT_EQ(z, *array); + auto array = graph->variableSpace()->getVariable(198)->getNDArray(); + ASSERT_EQ(z, *array); -*/ - sd::Environment::getInstance().setProfiling(true); - auto profile = GraphProfilingHelper::profile(graph, 1); - - profile->printOut(); + */ - sd::Environment::getInstance().setProfiling(false); - delete profile; + /* + std::vector values; -/* - std::vector values; + for (int e = 0; e < 1; e++) { + auto timeStart = std::chrono::system_clock::now(); - for (int e = 0; e < 1; e++) { - auto timeStart = std::chrono::system_clock::now(); + GraphExecutioner::execute(graph); - GraphExecutioner::execute(graph); + auto timeEnd = std::chrono::system_clock::now(); + auto outerTime = + std::chrono::duration_cast(timeEnd - + timeStart).count(); values.emplace_back(outerTime); + } - auto timeEnd = std::chrono::system_clock::now(); - auto outerTime = std::chrono::duration_cast(timeEnd - timeStart).count(); - values.emplace_back(outerTime); - } + std::sort(values.begin(), values.end()); - std::sort(values.begin(), values.end()); + nd4j_printf("Time: %lld us;\n", values[values.size() / 2]); + */ - nd4j_printf("Time: %lld us;\n", values[values.size() / 2]); -*/ - delete graph; #endif } TEST_F(PlaygroundTests, test_bert_2) { #ifdef _RELEASE - // this test will run ONLY if this model exists - if (sd::graph::getFileSize("/home/raver119/Downloads/Bert_minimal_model/bert_like_ops.fb") < 0) - return; - auto graph = GraphExecutioner::importFromFlatBuffers("/home/raver119/Downloads/Bert_minimal_model/bert_like_ops.fb"); + // this test will run ONLY if this model exists + if (!FileUtils::fileExists( + "/home/raver119/Downloads/Bert_minimal_model/bert_like_ops.fb")) + return; - //graph->printOut(); + auto graph = Graph::fromFlatBuffers( + "/home/raver119/Downloads/Bert_minimal_model/bert_like_ops.fb"); - graph->tagInplaceNodes(); + /* + // validating graph now + auto status = GraphExecutioner::execute(graph); + ASSERT_EQ(Status::OK(), status); + ASSERT_TRUE(graph->variableSpace()->hasVariable(198)); + auto array = graph->variableSpace()->getVariable(198)->getNDArray(); + ASSERT_EQ(z, *array); + */ -/* - // validating graph now - auto status = GraphExecutioner::execute(graph); - ASSERT_EQ(Status::OK(), status); - ASSERT_TRUE(graph->getVariableSpace()->hasVariable(198)); + /* + std::vector values; - auto array = graph->getVariableSpace()->getVariable(198)->getNDArray(); - ASSERT_EQ(z, *array); -*/ - - sd::Environment::getInstance().setProfiling(true); - auto profile = GraphProfilingHelper::profile(graph, 1); - - profile->printOut(); - - sd::Environment::getInstance().setProfiling(false); - delete profile; - -/* - std::vector values; + for (int e = 0; e < 1; e++) { + auto timeStart = std::chrono::system_clock::now(); - for (int e = 0; e < 1; e++) { - auto timeStart = std::chrono::system_clock::now(); + GraphExecutioner::execute(graph); - GraphExecutioner::execute(graph); + auto timeEnd = std::chrono::system_clock::now(); + auto outerTime = + std::chrono::duration_cast(timeEnd - + timeStart).count(); values.emplace_back(outerTime); + } - auto timeEnd = std::chrono::system_clock::now(); - auto outerTime = std::chrono::duration_cast(timeEnd - timeStart).count(); - values.emplace_back(outerTime); - } + std::sort(values.begin(), values.end()); - std::sort(values.begin(), values.end()); + nd4j_printf("Time: %lld us;\n", values[values.size() / 2]); + */ - nd4j_printf("Time: %lld us;\n", values[values.size() / 2]); -*/ - delete graph; #endif } - TEST_F(PlaygroundTests, test_one_off_ops_1) { - auto x = NDArrayFactory::create('c', {4, 128, 768}); - auto y = NDArrayFactory::create('c', {4, 128, 1}); - auto z = x.ulike(); + auto x = NDArrayFactory::create('c', {4, 128, 768}); + auto y = NDArrayFactory::create('c', {4, 128, 1}); + auto z = x.ulike(); - sd::ops::squaredsubtract op; - op.execute({&x, &y}, {&z}); + sd::ops::squaredsubtract op; + op.execute({&x, &y}, {&z}); } #if defined(INDEX_REDUCTIONS_BENCH_TESTS) @@ -560,8 +551,9 @@ TEST_F(PlaygroundTests, test_broadcast_1) { sd::ops::helpers::addBias(ctx, *x, *y, *z, false); auto timeEnd = std::chrono::system_clock::now(); - auto outerTime = std::chrono::duration_cast(timeEnd - timeStart).count(); - values.emplace_back(outerTime); + auto outerTime = +std::chrono::duration_cast(timeEnd - +timeStart).count(); values.emplace_back(outerTime); } std::sort(values.begin(), values.end()); @@ -575,8 +567,6 @@ TEST_F(PlaygroundTests, test_broadcast_1) { } } - -/* TEST_F(PlaygroundTests, test_broadcast_1) { int pool = 500; std::vector aX(pool); @@ -607,8 +597,9 @@ TEST_F(PlaygroundTests, test_broadcast_1) { x->applyTransform(transform::Tanh, *z, nullptr); auto timeEnd = std::chrono::system_clock::now(); - auto outerTime = std::chrono::duration_cast(timeEnd - timeStart).count(); - values.emplace_back(outerTime); + auto outerTime = +std::chrono::duration_cast(timeEnd - +timeStart).count(); values.emplace_back(outerTime); } std::sort(values.begin(), values.end()); @@ -626,8 +617,8 @@ TEST_F(PlaygroundTests, test_broadcast_1) { /* TEST_F(PlaygroundTests, test_s_0) { - std::vector> shapes = {{32, 224, 224, 3}, {32, 56, 56, 64}, {32, 7, 7, 512}}; - std::vector threads = {1, 2, 4, 8, 16}; + std::vector> shapes = {{32, 224, 224, 3}, {32, 56, 56, +64}, {32, 7, 7, 512}}; std::vector threads = {1, 2, 4, 8, 16}; for (auto shape: shapes) { for (auto t: threads) { @@ -653,20 +644,23 @@ TEST_F(PlaygroundTests, test_s_0) { sd::ops::helpers::addBias(ctx, x, y, z, false); auto timeEnd = std::chrono::system_clock::now(); - auto outerTime = std::chrono::duration_cast(timeEnd - timeStart).count(); - values.emplace_back(outerTime); + auto outerTime = +std::chrono::duration_cast(timeEnd - +timeStart).count(); values.emplace_back(outerTime); } std::sort(values.begin(), values.end()); - nd4j_printf("Shape: [%lld, %lld, %lld, %lld]; Threads: [%i]; Time: %lld us;\n", shape[0], shape[1], shape[2], shape[3], t, values[values.size() / 2]); + nd4j_printf("Shape: [%lld, %lld, %lld, %lld]; Threads: [%i]; Time: +%lld us;\n", shape[0], shape[1], shape[2], shape[3], t, values[values.size() / +2]); } } } TEST_F(PlaygroundTests, test_s_1) { - std::vector> shapes = {{32, 3, 224, 224}, {32, 64, 56, 56}, {32, 512, 7, 7}}; - std::vector threads = {1, 2, 4, 8, 16}; + std::vector> shapes = {{32, 3, 224, 224}, {32, 64, 56, +56}, {32, 512, 7, 7}}; std::vector threads = {1, 2, 4, 8, 16}; for (auto shape: shapes) { for (auto t: threads) { @@ -692,13 +686,16 @@ TEST_F(PlaygroundTests, test_s_1) { sd::ops::helpers::addBias(ctx, x, y, z, true); auto timeEnd = std::chrono::system_clock::now(); - auto outerTime = std::chrono::duration_cast(timeEnd - timeStart).count(); - values.emplace_back(outerTime); + auto outerTime = +std::chrono::duration_cast(timeEnd - +timeStart).count(); values.emplace_back(outerTime); } std::sort(values.begin(), values.end()); - nd4j_printf("Shape: [%lld, %lld, %lld, %lld]; Threads: [%i]; Time: %lld us;\n", shape[0], shape[1], shape[2], shape[3], t, values[values.size() / 2]); + nd4j_printf("Shape: [%lld, %lld, %lld, %lld]; Threads: [%i]; Time: +%lld us;\n", shape[0], shape[1], shape[2], shape[3], t, values[values.size() / +2]); } } } @@ -725,8 +722,8 @@ TEST_F(PlaygroundTests, test_s_0) { op.execute(&ctx); auto timeEnd = std::chrono::system_clock::now(); - auto outerTime = std::chrono::duration_cast (timeEnd - timeStart).count(); - values.emplace_back(outerTime); + auto outerTime = std::chrono::duration_cast +(timeEnd - timeStart).count(); values.emplace_back(outerTime); } std::sort(values.begin(), values.end()); @@ -769,8 +766,8 @@ TEST_F(PlaygroundTests, test_s_1) { op.execute(&ctx); auto timeEnd = std::chrono::system_clock::now(); - auto outerTime = std::chrono::duration_cast (timeEnd - timeStart).count(); - values.emplace_back(outerTime); + auto outerTime = std::chrono::duration_cast +(timeEnd - timeStart).count(); values.emplace_back(outerTime); } @@ -806,8 +803,8 @@ TEST_F(PlaygroundTests, test_s_2) { } auto timeEnd = std::chrono::system_clock::now(); - auto outerTime = std::chrono::duration_cast (timeEnd - timeStart).count(); - values.emplace_back(outerTime); + auto outerTime = std::chrono::duration_cast +(timeEnd - timeStart).count(); values.emplace_back(outerTime); }; std::sort(values.begin(), values.end()); @@ -855,8 +852,9 @@ TEST_F(PlaygroundTests, test_s_4) { } } auto timeEnd = std::chrono::system_clock::now(); - auto outerTime = std::chrono::duration_cast(timeEnd - timeStart).count(); - valuesX.emplace_back(outerTime); + auto outerTime = +std::chrono::duration_cast(timeEnd - +timeStart).count(); valuesX.emplace_back(outerTime); } @@ -868,7 +866,8 @@ TEST_F(PlaygroundTests, test_s_4) { for (auto k = 0; k < xs2; k++) { for (auto l = 0; l < xs3; l++) { - zbuffer[thread_id] += buffer[i * j + (k * l)] * 2.5f; + zbuffer[thread_id] += buffer[i * j + (k * l)] +* 2.5f; } } } @@ -877,18 +876,21 @@ TEST_F(PlaygroundTests, test_s_4) { samediff::Threads::parallel_for(f2d, 0, xs0, 1, 0, xs1, 1); auto timeEnd = std::chrono::system_clock::now(); - auto outerTime = std::chrono::duration_cast(timeEnd - timeStart).count(); - valuesY.emplace_back(outerTime); + auto outerTime = +std::chrono::duration_cast(timeEnd - +timeStart).count(); valuesY.emplace_back(outerTime); } if (valuesX.size() > 0) { std::sort(valuesX.begin(), valuesX.end()); - nd4j_printf("OpenMP time: %lld; Min: %lld; Max: %lld;\n", valuesX[valuesX.size() / 2], valuesX[0], valuesX[valuesX.size() - 1]); + nd4j_printf("OpenMP time: %lld; Min: %lld; Max: %lld;\n", +valuesX[valuesX.size() / 2], valuesX[0], valuesX[valuesX.size() - 1]); } if (valuesY.size() > 0) { std::sort(valuesY.begin(), valuesY.end()); - nd4j_printf("Threads time: %lld; Min: %lld; Max: %lld;\n", valuesY[valuesY.size() / 2], valuesY[0], valuesY[valuesY.size() - 1]); + nd4j_printf("Threads time: %lld; Min: %lld; Max: %lld;\n", +valuesY[valuesY.size() / 2], valuesY[0], valuesY[valuesY.size() - 1]); } nd4j_printf("Sum: %f\n", z.sumNumber().e(0)); @@ -921,17 +923,20 @@ TEST_F(PlaygroundTests, test_s_5) { auto timeStart = std::chrono::system_clock::now(); // picking best fit here - auto splitLoop = samediff::ThreadsHelper::pickLoop2d(numThreads, itersX, itersY); - auto span = samediff::Span2::build(splitLoop, 0, numThreads, startX, stopX, incX, startY, stopY, incY); + auto splitLoop = samediff::ThreadsHelper::pickLoop2d(numThreads, itersX, +itersY); auto span = samediff::Span2::build(splitLoop, 0, numThreads, startX, +stopX, incX, startY, stopY, incY); auto timeEnd = std::chrono::system_clock::now(); - auto outerTime = std::chrono::duration_cast(timeEnd - timeStart).count(); - values.emplace_back(outerTime); + auto outerTime = +std::chrono::duration_cast(timeEnd - +timeStart).count(); values.emplace_back(outerTime); } std::sort(values.begin(), values.end()); - nd4j_printf("Calculations time: [Median: %lld; Min: %lld; Max: %lld;]\n", values[values.size() / 2], values[0], values[values.size()-1]); + nd4j_printf("Calculations time: [Median: %lld; Min: %lld; Max: %lld;]\n", +values[values.size() / 2], values[0], values[values.size()-1]); } @@ -951,13 +956,15 @@ TEST_F(PlaygroundTests, test_s_6) { } auto timeEnd = std::chrono::system_clock::now(); - auto outerTime = std::chrono::duration_cast(timeEnd - timeStart).count(); - values.emplace_back(outerTime); + auto outerTime = +std::chrono::duration_cast(timeEnd - +timeStart).count(); values.emplace_back(outerTime); } std::sort(values.begin(), values.end()); - nd4j_printf("Calculations time: [Median: %lld; Min: %lld; Max: %lld;]\n", values[values.size() / 2], values[0], values[values.size()-1]); + nd4j_printf("Calculations time: [Median: %lld; Min: %lld; Max: %lld;]\n", +values[values.size() / 2], values[0], values[values.size()-1]); } @@ -981,19 +988,21 @@ TEST_F(PlaygroundTests, test_relubp_1) { auto y = x.ulike(); auto z = x.ulike(); RandomGenerator rng(119, 120); - RandomLauncher::fillUniform(LaunchContext::defaultContext(), rng, &x, -1.0, 1.0); - RandomLauncher::fillUniform(LaunchContext::defaultContext(), rng, &y, -1.0, 1.0); + RandomLauncher::fillUniform(LaunchContext::defaultContext(), rng, &x, +-1.0, 1.0); RandomLauncher::fillUniform(LaunchContext::defaultContext(), rng, +&y, -1.0, 1.0); int iterations = 10; auto timeStart = std::chrono::system_clock::now(); for (int e = 0; e < iterations; e++) - ops::helpers::reluDerivative(LaunchContext::defaultContext(), &x, &y, &z); - auto timeEnd = std::chrono::system_clock::now(); + ops::helpers::reluDerivative(LaunchContext::defaultContext(), &x, &y, +&z); auto timeEnd = std::chrono::system_clock::now(); - auto outerTime = std::chrono::duration_cast (timeEnd - timeStart).count(); - auto time = (Nd4jLong) outerTime / iterations; - auto bw = (1000000L * (float) (x.lengthOf() * x.sizeOfT()) / time) / 1024 / 1024 / 1024; + auto outerTime = std::chrono::duration_cast +(timeEnd - timeStart).count(); auto time = (Nd4jLong) outerTime / iterations; + auto bw = (1000000L * (float) (x.lengthOf() * x.sizeOfT()) / time) / 1024 / +1024 / 1024; nd4j_printf("Time: %lld; BW: %f GB/s\n", time, bw); } @@ -1001,10 +1010,11 @@ TEST_F(PlaygroundTests, test_relubp_1) { ////////////////////////////////////////////////////////////////////// TEST_F(PlaygroundTests, my) { - int bS=8, iD=32,iH=32,iW=32, iC=128, kD=2,kH=2,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=2,dH=2,dW=2; - int oD,oH,oW; + int bS=8, iD=32,iH=32,iW=32, iC=128, kD=2,kH=2,kW=2, sD=1,sH=1,sW=1, +pD=0,pH=0,pW=0, dD=2,dH=2,dW=2; int oD,oH,oW; - sd::ops::ConvolutionUtils::calcOutSizeDeconv3D(oD, oH, oW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, 0); + sd::ops::ConvolutionUtils::calcOutSizeDeconv3D(oD, oH, oW, kD, kH, kW, sD, +sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, 0); printf("!!%i, %i, %i\n", oD,oH,oW); @@ -1018,9 +1028,10 @@ TEST_F(PlaygroundTests, my) { auto block = new Context(1, variableSpace, false); // not-in-place auto timeStart = std::chrono::system_clock::now(); - sd::ops::ConvolutionUtils::col2vol(*block, col, vol, sD, sH, sW, pD, pH, pW, dD, dH, dW); - auto timeEnd = std::chrono::system_clock::now(); - auto time = std::chrono::duration_cast (timeEnd - timeStart).count(); + sd::ops::ConvolutionUtils::col2vol(*block, col, vol, sD, sH, sW, pD, pH, pW, +dD, dH, dW); auto timeEnd = std::chrono::system_clock::now(); auto time = +std::chrono::duration_cast (timeEnd - +timeStart).count(); printf("time: %i \n", time); @@ -1030,11 +1041,13 @@ TEST_F(PlaygroundTests, my) { TEST_F(PlaygroundTests, my) { - int bS=32, iD=32,iH=64,iW=64, iC=128, kD=2,kH=2,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=2,dH=2,dW=2; - int oD,oH,oW; + int bS=32, iD=32,iH=64,iW=64, iC=128, kD=2,kH=2,kW=2, sD=1,sH=1,sW=1, +pD=0,pH=0,pW=0, dD=2,dH=2,dW=2; int oD,oH,oW; - // sd::ops::ConvolutionUtils::calcOutSizeDeconv3D(oD, oH, oW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, 0); - sd::ops::ConvolutionUtils::calcOutSizeDeconv2D(oH, oW, kH, kW, sH, sW, pH, pW,dH, dW, iH, iW, 0); + // sd::ops::ConvolutionUtils::calcOutSizeDeconv3D(oD, oH, oW, kD, kH, kW, +sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, 0); + sd::ops::ConvolutionUtils::calcOutSizeDeconv2D(oH, oW, kH, kW, sH, sW, pH, +pW,dH, dW, iH, iW, 0); printf("!!%i, %i, %i\n", oD,oH,oW); @@ -1051,10 +1064,11 @@ TEST_F(PlaygroundTests, my) { auto block = new Context(1, variableSpace, false); // not-in-place auto timeStart = std::chrono::system_clock::now(); - // sd::ops::ConvolutionUtils::col2vol(*block, col, vol, sD, sH, sW, pD, pH, pW, dD, dH, dW); - sd::ops::helpers::col2im(*col.getContext(), col, im, sH, sW, pH, pW, iH, iW, dH, dW); - auto timeEnd = std::chrono::system_clock::now(); - auto time = std::chrono::duration_cast (timeEnd - timeStart).count(); + // sd::ops::ConvolutionUtils::col2vol(*block, col, vol, sD, sH, sW, pD, pH, +pW, dD, dH, dW); sd::ops::helpers::col2im(*col.getContext(), col, im, sH, sW, +pH, pW, iH, iW, dH, dW); auto timeEnd = std::chrono::system_clock::now(); auto +time = std::chrono::duration_cast (timeEnd - +timeStart).count(); printf("time: %i \n", time); @@ -1072,15 +1086,17 @@ TEST_F(PlaygroundTests, lstmLayerCellBp_1) { // const int nOut = 6; const float cellClip = 1.1; // clipping value - const Nd4jLong gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates - const float gateAlpha = 0; // alpha value for activation for gates, not required for sigmoid - const float gateBeta = 0; // beta value for activation for gates, not required for sigmoid - const Nd4jLong cellAct = 0; // tanh activation for cell state - const float cellAlpha = 0; // alpha value for cell state activation, not required for tanh - const float cellBeta = 0; // beta value for cell state activation, not required for tanh - const Nd4jLong outAct = 0; // tanh activation for output - const float outAlpha = 0; // alpha value for output activation, not required for tanh - const float outBeta = 0; // beta value for output activation, not required for tanh + const Nd4jLong gateAct = 2; // sigmoid activation for input (i), +forget (f) and output (o) gates const float gateAlpha = 0; // alpha value +for activation for gates, not required for sigmoid const float gateBeta = 0; // +beta value for activation for gates, not required for sigmoid const Nd4jLong +cellAct = 0; // tanh activation for cell state const float cellAlpha = 0; +// alpha value for cell state activation, not required for tanh const float +cellBeta = 0; // beta value for cell state activation, not required for +tanh const Nd4jLong outAct = 0; // tanh activation for output const +float outAlpha = 0; // alpha value for output activation, not required for +tanh const float outBeta = 0; // beta value for output activation, not +required for tanh NDArray x ('c', {bS, nIn}, sd::DataType::DOUBLE); NDArray hI('c', {bS, nOut}, sd::DataType::DOUBLE); @@ -1119,17 +1135,21 @@ TEST_F(PlaygroundTests, lstmLayerCellBp_1) { std::vector iArgs = {gateAct, cellAct, outAct}; // std::vector bArgs = {false, false}; - // const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &hI, &cI}, tArgs, iArgs, bArgs); - // const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &hI, &cI, &dLdh}, tArgs, iArgs, bArgs); + // const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &hI, &cI}, tArgs, iArgs, +bArgs); + // const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &hI, &cI, &dLdh}, tArgs, +iArgs, bArgs); std::vector bArgs = {true, true}; - const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); - const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &hI, &cI, &Wp, &dLdh}, tArgs, iArgs, bArgs); + const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, +iArgs, bArgs); const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &hI, &cI, &Wp, +&dLdh}, tArgs, iArgs, bArgs); sd::ops::lstmLayerCell opFF; sd::ops::lstmLayerCellBp opBP; - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {true, true, true, true, true, true, true}); + const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, +argsHolderBP, {true, true, true, true, true, true, true}); } diff --git a/libnd4j/tests_cpu/layers_tests/ProtoBufTests.cpp b/libnd4j/tests_cpu/layers_tests/ProtoBufTests.cpp index fe2f97bb6c0e..6ded3e6ffeca 100644 --- a/libnd4j/tests_cpu/layers_tests/ProtoBufTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/ProtoBufTests.cpp @@ -18,9 +18,7 @@ // @author raver119@gmail.com // - #include "testlayers.h" -#include /* @@ -33,7 +31,8 @@ class ProtoBufTests : public testing::Test { TEST_F(ProtoBufTests, TestBinaryLoad1) { GOOGLE_PROTOBUF_VERIFY_VERSION; - auto graph = GraphExecutioner::importFromTensorFlow("../../../tests/resources/tensorflow_inception_graph.pb"); + auto graph = +GraphExecutioner::importFromTensorFlow("../../../tests/resources/tensorflow_inception_graph.pb"); ASSERT_FALSE(graph == nullptr); } @@ -41,7 +40,8 @@ TEST_F(ProtoBufTests, TestBinaryLoad1) { TEST_F(ProtoBufTests, TestTextLoad1) { GOOGLE_PROTOBUF_VERIFY_VERSION; - auto graph = GraphExecutioner::importFromTensorFlow("../../../tests/resources/max_graph.pb.txt"); + auto graph = +GraphExecutioner::importFromTensorFlow("../../../tests/resources/max_graph.pb.txt"); ASSERT_FALSE(graph == nullptr); } @@ -50,14 +50,15 @@ TEST_F(ProtoBufTests, TestTextLoad1) { TEST_F(ProtoBufTests, TestTextLoad2) { GOOGLE_PROTOBUF_VERIFY_VERSION; - auto graph = GraphExecutioner::importFromTensorFlow("../../../tests/resources/max_add_2.pb.txt"); + auto graph = +GraphExecutioner::importFromTensorFlow("../../../tests/resources/max_add_2.pb.txt"); ASSERT_FALSE(graph == nullptr); - ASSERT_EQ(2, graph->getVariableSpace()->externalEntries()); + ASSERT_EQ(2, graph->variableSpace()->externalEntries()); - auto var0 = graph->getVariableSpace()->getVariable(new std::string("zeros")); - auto var1 = graph->getVariableSpace()->getVariable(new std::string("ones")); + auto var0 = graph->variableSpace()->getVariable(new std::string("zeros")); + auto var1 = graph->variableSpace()->getVariable(new std::string("ones")); // first we're veryfying variable states @@ -70,9 +71,11 @@ TEST_F(ProtoBufTests, TestTextLoad2) { ASSERT_EQ(12, var0->getNDArray()->lengthOf()); ASSERT_EQ(12, var1->getNDArray()->lengthOf()); - ASSERT_NEAR(0.0f, var0->getNDArray()->reduceNumber>(), 1e-5); - ASSERT_NEAR(12.0f, var1->getNDArray()->reduceNumber>(), 1e-5); - ASSERT_NEAR(1.0f, var1->getNDArray()->reduceNumber>(), 1e-5); + ASSERT_NEAR(0.0f, var0->getNDArray()->reduceNumber>(), +1e-5); ASSERT_NEAR(12.0f, +var1->getNDArray()->reduceNumber>(), 1e-5); + ASSERT_NEAR(1.0f, var1->getNDArray()->reduceNumber>(), +1e-5); // now we're veryfying op graph @@ -80,22 +83,25 @@ TEST_F(ProtoBufTests, TestTextLoad2) { GraphExecutioner::execute(graph); - ASSERT_NEAR(12.0f, var0->getNDArray()->reduceNumber>(), 1e-5); - ASSERT_NEAR(1.0f, var0->getNDArray()->reduceNumber>(), 1e-5); + ASSERT_NEAR(12.0f, var0->getNDArray()->reduceNumber>(), +1e-5); ASSERT_NEAR(1.0f, +var0->getNDArray()->reduceNumber>(), 1e-5); } TEST_F(ProtoBufTests, TestTextLoad3) { GOOGLE_PROTOBUF_VERIFY_VERSION; - auto graph = GraphExecutioner::importFromTensorFlow("../../../tests/resources/max_multiply.pb.txt"); + auto graph = +GraphExecutioner::importFromTensorFlow("../../../tests/resources/max_multiply.pb.txt"); ASSERT_FALSE(graph == nullptr); - ASSERT_EQ(2, graph->getVariableSpace()->externalEntries()); + ASSERT_EQ(2, graph->variableSpace()->externalEntries()); - auto var0 = graph->getVariableSpace()->getVariable(new std::string("Placeholder")); - auto var1 = graph->getVariableSpace()->getVariable(new std::string("Placeholder_1")); + auto var0 = graph->variableSpace()->getVariable(new +std::string("Placeholder")); auto var1 = graph->variableSpace()->getVariable(new +std::string("Placeholder_1")); ASSERT_TRUE(var0 != nullptr); ASSERT_TRUE(var1 != nullptr); diff --git a/libnd4j/tests_cpu/layers_tests/QuantizationTests.cpp b/libnd4j/tests_cpu/layers_tests/QuantizationTests.cpp index 97f6cd8cd3b6..5e1ef8de14b0 100644 --- a/libnd4j/tests_cpu/layers_tests/QuantizationTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/QuantizationTests.cpp @@ -18,53 +18,49 @@ // @author raver119@protonmail.com // - -#include "testlayers.h" #include #include +#include "testlayers.h" using namespace sd; -class QuantizationTests : public testing::Test { - -}; +class QuantizationTests : public testing::Test {}; TEST_F(QuantizationTests, Basic_Test_1) { #ifndef __CUDABLAS__ - auto s = TypeCast::estimateQuantizedSize(10); - ASSERT_EQ(18, s); + auto s = TypeCast::estimateQuantizedSize(10); + ASSERT_EQ(18, s); #endif } TEST_F(QuantizationTests, Basic_Test_2) { #ifndef __CUDABLAS__ - auto s = TypeCast::estimateQuantizedSize(1); - ASSERT_EQ(9, s); + auto s = TypeCast::estimateQuantizedSize(1); + ASSERT_EQ(9, s); #endif } TEST_F(QuantizationTests, Compression_Test_1) { +#ifndef __CUDABLAS__ - #ifndef __CUDABLAS__ - - auto x = NDArrayFactory::create('c', {10}); - auto z = NDArrayFactory::create('c', {10}); - x.linspace(1.0f); + auto x = NDArrayFactory::create('c', {10}); + auto z = NDArrayFactory::create('c', {10}); + x.linspace(1.0f); - auto q = new char[TypeCast::estimateQuantizedSize(x.lengthOf())]; + auto q = new char[TypeCast::estimateQuantizedSize(x.lengthOf())]; - TypeCast::convertToQuantized(nullptr, x.buffer(), x.lengthOf(), q); - TypeCast::convertFromQuantized(nullptr, q, x.lengthOf(), z.buffer()); + TypeCast::convertToQuantized(nullptr, x.buffer(), x.lengthOf(), q); + TypeCast::convertFromQuantized(nullptr, q, x.lengthOf(), z.buffer()); - ASSERT_TRUE(x.equalsTo(z, 0.1)); + ASSERT_TRUE(x.equalsTo(z, 0.1)); - auto fq = reinterpret_cast(q); + auto fq = reinterpret_cast(q); - ASSERT_NEAR(1.0f, fq[0], 1e-5); - ASSERT_NEAR(10.0f, fq[1], 1e-5); + ASSERT_NEAR(1.0f, fq[0], 1e-5); + ASSERT_NEAR(10.0f, fq[1], 1e-5); - delete[] q; + delete[] q; - #endif +#endif } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/RNGTests.cpp b/libnd4j/tests_cpu/layers_tests/RNGTests.cpp index a2c33374a83d..224414af35a9 100644 --- a/libnd4j/tests_cpu/layers_tests/RNGTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/RNGTests.cpp @@ -19,205 +19,215 @@ // @author raver119@gmail.com // -#include "testlayers.h" -#include #include #include -#include #include +#include + +#include + +#include "testlayers.h" using namespace sd; class RNGTests : public testing::Test { -private: - //Nd4jLong *_bufferA; - //Nd4jLong *_bufferB; - -public: - long _seed = 119L; - //sd::random::RandomBuffer *_rngA; - //sd::random::RandomBuffer *_rngB; - sd::graph::RandomGenerator _rngA; - sd::graph::RandomGenerator _rngB; - - NDArray* nexp0 = NDArrayFactory::create_('c', {10, 10}); - NDArray* nexp1 = NDArrayFactory::create_('c', {10, 10}); - NDArray* nexp2 = NDArrayFactory::create_('c', {10, 10}); - - RNGTests() { - //_bufferA = new Nd4jLong[100000]; - //_bufferB = new Nd4jLong[100000]; - //_rngA = (sd::random::RandomBuffer *) initRandom(nullptr, _seed, 100000, (Nd4jPointer) _bufferA); - //_rngB = (sd::random::RandomBuffer *) initRandom(nullptr, _seed, 100000, (Nd4jPointer) _bufferB); - _rngA.setStates(_seed * 0xDEADBEEF * 13, _seed * 0xDEADBEEF * 7); - _rngB.setStates(_seed * 0xDEADBEEF * 13, _seed * 0xDEADBEEF * 7); - nexp0->assign(-1.0f); - nexp1->assign(-2.0f); - nexp2->assign(-3.0f); - } + private: + // Nd4jLong *_bufferA; + // Nd4jLong *_bufferB; + + public: + long _seed = 119L; + sd::graph::RandomGenerator _rngA; + sd::graph::RandomGenerator _rngB; + + NDArray* nexp0 = NDArrayFactory::create_('c', {10, 10}); + NDArray* nexp1 = NDArrayFactory::create_('c', {10, 10}); + NDArray* nexp2 = NDArrayFactory::create_('c', {10, 10}); + + RNGTests() { + //_bufferA = new Nd4jLong[100000]; + //_bufferB = new Nd4jLong[100000]; + //_rngA = (sd::random::RandomBuffer *) initRandom(nullptr, _seed, 100000, + //(Nd4jPointer) _bufferA); _rngB = (sd::random::RandomBuffer *) + // initRandom(nullptr, _seed, 100000, (Nd4jPointer) _bufferB); + _rngA.setStates(_seed * 0xDEADBEEF * 13, _seed * 0xDEADBEEF * 7); + _rngB.setStates(_seed * 0xDEADBEEF * 13, _seed * 0xDEADBEEF * 7); + nexp0->assign(-1.0f); + nexp1->assign(-2.0f); + nexp2->assign(-3.0f); + } - ~RNGTests() { - //destroyRandom(_rngA); - //destroyRandom(_rngB); - //delete[] _bufferA; - //delete[] _bufferB; + ~RNGTests() { + // destroyRandom(_rngA); + // destroyRandom(_rngB); + // delete[] _bufferA; + // delete[] _bufferB; - delete nexp0; - delete nexp1; - delete nexp2; - } + delete nexp0; + delete nexp1; + delete nexp2; + } }; TEST_F(RNGTests, TestSeeds_1) { - RandomGenerator generator(123L, 456L); + RandomGenerator generator(123L, 456L); - ASSERT_EQ(123, generator.rootState()); - ASSERT_EQ(456, generator.nodeState()); + ASSERT_EQ(123, generator.rootState()); + ASSERT_EQ(456, generator.nodeState()); - Nd4jPointer ptr = malloc(sizeof(RandomGenerator)); - memcpy(ptr, &generator, sizeof(RandomGenerator)); + Nd4jPointer ptr = malloc(sizeof(RandomGenerator)); + memcpy(ptr, &generator, sizeof(RandomGenerator)); - auto cast = reinterpret_cast(ptr); - ASSERT_EQ(123, cast->rootState()); - ASSERT_EQ(456, cast->nodeState()); + auto cast = reinterpret_cast(ptr); + ASSERT_EQ(123, cast->rootState()); + ASSERT_EQ(456, cast->nodeState()); - free(ptr); + free(ptr); } TEST_F(RNGTests, TestSeeds_2) { - RandomGenerator generator(12, 13); + RandomGenerator generator(12, 13); - generator.setStates(123L, 456L); + generator.setStates(123L, 456L); - ASSERT_EQ(123, generator.rootState()); - ASSERT_EQ(456, generator.nodeState()); + ASSERT_EQ(123, generator.rootState()); + ASSERT_EQ(456, generator.nodeState()); } TEST_F(RNGTests, TestGenerator_SGA_1) { - RandomGenerator generator(12, 13); - auto array= NDArrayFactory::create('c',{10000000}); - generator.setStates(123L, 456L); - for (auto idx = 0; idx < array.lengthOf(); idx++) { - float x = generator.relativeT(idx, -sd::DataTypeUtils::template max() / 10, - sd::DataTypeUtils::template max() / 10); - array.r(idx) = x; - } - auto minimum = array.reduceNumber(reduce::AMin); - minimum.printBuffer("Randomly float min on 1M array"); - ASSERT_EQ(123, generator.rootState()); - ASSERT_EQ(456, generator.nodeState()); + RandomGenerator generator(12, 13); + auto array = NDArrayFactory::create('c', {10000000}); + generator.setStates(123L, 456L); + for (auto idx = 0; idx < array.lengthOf(); idx++) { + float x = + generator.relativeT(idx, -sd::DataTypeUtils::template max() / 10, + sd::DataTypeUtils::template max() / 10); + array.r(idx) = x; + } + auto minimum = array.reduceNumber(reduce::AMin); + minimum.printBuffer("Randomly float min on 1M array"); + ASSERT_EQ(123, generator.rootState()); + ASSERT_EQ(456, generator.nodeState()); } - TEST_F(RNGTests, Test_Dropout_1) { - auto x0 = NDArrayFactory::create('c', {10, 10}); - auto x1 = NDArrayFactory::create('c', {10, 10}); - - x0.linspace(1); - x1.linspace(1); - - float prob[] = {0.5f}; - - //x0.applyRandom(random::DropOut, _rngA, nullptr, &x0, prob); - //x1.applyRandom(random::DropOut, _rngB, nullptr, &x1, prob); - RandomLauncher::applyDropOut(LaunchContext::defaultContext(), _rngA, &x0, 0.5); - RandomLauncher::applyDropOut(LaunchContext::defaultContext(), _rngB, &x1, 0.5); - ASSERT_TRUE(x0.equalsTo(&x1)); - //x0.printIndexedBuffer("Dropout"); - // this check is required to ensure we're calling wrong signature - ASSERT_FALSE(x0.equalsTo(nexp0)); - ASSERT_FALSE(x0.equalsTo(nexp1)); - ASSERT_FALSE(x0.equalsTo(nexp2)); + auto x0 = NDArrayFactory::create('c', {10, 10}); + auto x1 = NDArrayFactory::create('c', {10, 10}); + + x0.linspace(1); + x1.linspace(1); + + float prob[] = {0.5f}; + + // x0.applyRandom(random::DropOut, _rngA, nullptr, &x0, prob); + // x1.applyRandom(random::DropOut, _rngB, nullptr, &x1, prob); + RandomLauncher::applyDropOut(LaunchContext::defaultContext(), _rngA, &x0, + 0.5); + RandomLauncher::applyDropOut(LaunchContext::defaultContext(), _rngB, &x1, + 0.5); + ASSERT_TRUE(x0.equalsTo(&x1)); + // x0.printIndexedBuffer("Dropout"); + // this check is required to ensure we're calling wrong signature + ASSERT_FALSE(x0.equalsTo(nexp0)); + ASSERT_FALSE(x0.equalsTo(nexp1)); + ASSERT_FALSE(x0.equalsTo(nexp2)); } TEST_F(RNGTests, Test_DropoutInverted_1) { - auto x0 = NDArrayFactory::create('c', {10, 10}); - auto x1 = NDArrayFactory::create('c', {10, 10}); - - x0.linspace(1); - x1.linspace(1); - - float prob[] = {0.5f}; - - //x0.template applyRandom>(_rngA, nullptr, &x0, prob); - //x1.template applyRandom>(_rngB, nullptr, &x1, prob); - RandomLauncher::applyInvertedDropOut(LaunchContext::defaultContext(), _rngA, &x0, 0.5); - RandomLauncher::applyInvertedDropOut(LaunchContext::defaultContext(), _rngB, &x1, 0.5); - ASSERT_TRUE(x0.equalsTo(&x1)); - //x0.printIndexedBuffer("DropoutInverted"); - // this check is required to ensure we're calling wrong signature - ASSERT_FALSE(x0.equalsTo(nexp0)); - ASSERT_FALSE(x0.equalsTo(nexp1)); - ASSERT_FALSE(x0.equalsTo(nexp2)); + auto x0 = NDArrayFactory::create('c', {10, 10}); + auto x1 = NDArrayFactory::create('c', {10, 10}); + + x0.linspace(1); + x1.linspace(1); + + float prob[] = {0.5f}; + + // x0.template applyRandom>(_rngA, nullptr, + // &x0, prob); x1.template + // applyRandom>(_rngB, nullptr, &x1, prob); + RandomLauncher::applyInvertedDropOut(LaunchContext::defaultContext(), _rngA, + &x0, 0.5); + RandomLauncher::applyInvertedDropOut(LaunchContext::defaultContext(), _rngB, + &x1, 0.5); + ASSERT_TRUE(x0.equalsTo(&x1)); + // x0.printIndexedBuffer("DropoutInverted"); + // this check is required to ensure we're calling wrong signature + ASSERT_FALSE(x0.equalsTo(nexp0)); + ASSERT_FALSE(x0.equalsTo(nexp1)); + ASSERT_FALSE(x0.equalsTo(nexp2)); } - TEST_F(RNGTests, Test_Launcher_1) { - auto x0 = NDArrayFactory::create('c', {10, 10}); - auto x1 = NDArrayFactory::create('c', {10, 10}); + auto x0 = NDArrayFactory::create('c', {10, 10}); + auto x1 = NDArrayFactory::create('c', {10, 10}); - RandomLauncher::applyDropOut(LaunchContext::defaultContext(), _rngA, &x0, 0.5f); - RandomLauncher::applyDropOut(LaunchContext::defaultContext(), _rngB, &x1, 0.5f); + RandomLauncher::applyDropOut(LaunchContext::defaultContext(), _rngA, &x0, + 0.5f); + RandomLauncher::applyDropOut(LaunchContext::defaultContext(), _rngB, &x1, + 0.5f); - ASSERT_TRUE(x0.equalsTo(&x1)); + ASSERT_TRUE(x0.equalsTo(&x1)); - ASSERT_FALSE(x0.equalsTo(nexp0)); - ASSERT_FALSE(x0.equalsTo(nexp1)); - ASSERT_FALSE(x0.equalsTo(nexp2)); + ASSERT_FALSE(x0.equalsTo(nexp0)); + ASSERT_FALSE(x0.equalsTo(nexp1)); + ASSERT_FALSE(x0.equalsTo(nexp2)); } - TEST_F(RNGTests, Test_Launcher_2) { - auto x0 = NDArrayFactory::create('c', {10, 10}); - auto x1 = NDArrayFactory::create('c', {10, 10}); + auto x0 = NDArrayFactory::create('c', {10, 10}); + auto x1 = NDArrayFactory::create('c', {10, 10}); - RandomLauncher::applyInvertedDropOut(LaunchContext::defaultContext(), _rngA, &x0, 0.5f); - RandomLauncher::applyInvertedDropOut(LaunchContext::defaultContext(), _rngB, &x1, 0.5f); + RandomLauncher::applyInvertedDropOut(LaunchContext::defaultContext(), _rngA, + &x0, 0.5f); + RandomLauncher::applyInvertedDropOut(LaunchContext::defaultContext(), _rngB, + &x1, 0.5f); - ASSERT_TRUE(x0.equalsTo(&x1)); + ASSERT_TRUE(x0.equalsTo(&x1)); - ASSERT_FALSE(x0.equalsTo(nexp0)); - ASSERT_FALSE(x0.equalsTo(nexp1)); - ASSERT_FALSE(x0.equalsTo(nexp2)); + ASSERT_FALSE(x0.equalsTo(nexp0)); + ASSERT_FALSE(x0.equalsTo(nexp1)); + ASSERT_FALSE(x0.equalsTo(nexp2)); } - TEST_F(RNGTests, Test_Launcher_3) { - auto x0 = NDArrayFactory::create('c', {10, 10}); - auto x1 = NDArrayFactory::create('c', {10, 10}); + auto x0 = NDArrayFactory::create('c', {10, 10}); + auto x1 = NDArrayFactory::create('c', {10, 10}); - RandomLauncher::applyAlphaDropOut(LaunchContext::defaultContext(), _rngA, &x0, 0.5f, 0.2f, 0.1f, 0.3f); - RandomLauncher::applyAlphaDropOut(LaunchContext::defaultContext(), _rngB, &x1, 0.5f, 0.2f, 0.1f, 0.3f); + RandomLauncher::applyAlphaDropOut(LaunchContext::defaultContext(), _rngA, &x0, + 0.5f, 0.2f, 0.1f, 0.3f); + RandomLauncher::applyAlphaDropOut(LaunchContext::defaultContext(), _rngB, &x1, + 0.5f, 0.2f, 0.1f, 0.3f); - //x1.printIndexedBuffer("x1"); - ASSERT_TRUE(x0.equalsTo(&x1)); + // x1.printIndexedBuffer("x1"); + ASSERT_TRUE(x0.equalsTo(&x1)); - ASSERT_FALSE(x0.equalsTo(nexp0)); - ASSERT_FALSE(x0.equalsTo(nexp1)); - ASSERT_FALSE(x0.equalsTo(nexp2)); + ASSERT_FALSE(x0.equalsTo(nexp0)); + ASSERT_FALSE(x0.equalsTo(nexp1)); + ASSERT_FALSE(x0.equalsTo(nexp2)); } TEST_F(RNGTests, Test_Uniform_1) { - auto x0 = NDArrayFactory::create('c', {10, 10}); - auto x1 = NDArrayFactory::create('c', {10, 10}); + auto x0 = NDArrayFactory::create('c', {10, 10}); + auto x1 = NDArrayFactory::create('c', {10, 10}); - RandomLauncher::fillUniform(LaunchContext::defaultContext(), _rngA, &x0, 1.0f, 2.0f); - RandomLauncher::fillUniform(LaunchContext::defaultContext(), _rngB, &x1, 1.0f, 2.0f); + RandomLauncher::fillUniform(LaunchContext::defaultContext(), _rngA, &x0, 1.0f, + 2.0f); + RandomLauncher::fillUniform(LaunchContext::defaultContext(), _rngB, &x1, 1.0f, + 2.0f); - x0.printLinearBuffer(); + x0.printLinearBuffer(); x1.printLinearBuffer(); ASSERT_TRUE(x0.equalsTo(&x1)); - ASSERT_FALSE(x0.equalsTo(nexp0)); - ASSERT_FALSE(x0.equalsTo(nexp1)); - ASSERT_FALSE(x0.equalsTo(nexp2)); + ASSERT_FALSE(x0.equalsTo(nexp0)); + ASSERT_FALSE(x0.equalsTo(nexp1)); + ASSERT_FALSE(x0.equalsTo(nexp2)); - for (int e = 0; e < x0.lengthOf(); e++) { - float v = x0.e(e); - nd4j_printf("%f\n", v); - ASSERT_TRUE(v >= 1.0f && v <= 2.0f); - } + for (int e = 0; e < x0.lengthOf(); e++) { + float v = x0.e(e); + nd4j_printf("%f\n", v);ASSERT_TRUE(v >= 1.0f && v <= 2.0f); + } } TEST_F(RNGTests, Test_Uniform_10) { @@ -292,750 +302,745 @@ TEST_F(RNGTests, Test_Uniform_13) { } TEST_F(RNGTests, Test_Uniform_3) { - auto x0 = NDArrayFactory::create('c', {1000000}); + auto x0 = NDArrayFactory::create('c', {1000000}); - RandomLauncher::fillUniform(LaunchContext::defaultContext(), _rngA, &x0, 1.0f, 2.0f); + RandomLauncher::fillUniform(LaunchContext::defaultContext(), _rngA, &x0, 1.0f, + 2.0f); - for (int e = 0; e < x0.lengthOf(); e++) { - auto v = x0.t(e); - ASSERT_TRUE(v >= 1.0 && v <= 2.0); - } + for (int e = 0; e < x0.lengthOf(); e++) { + auto v = x0.t(e); + ASSERT_TRUE(v >= 1.0 && v <= 2.0); + } } TEST_F(RNGTests, Test_Bernoulli_1) { - auto x0 = NDArrayFactory::create('c', {10, 10}); - auto x1 = NDArrayFactory::create('c', {10, 10}); + auto x0 = NDArrayFactory::create('c', {10, 10}); + auto x1 = NDArrayFactory::create('c', {10, 10}); - RandomLauncher::fillBernoulli(LaunchContext::defaultContext(), _rngA, &x0, 1.0f); - RandomLauncher::fillBernoulli(LaunchContext::defaultContext(), _rngB, &x1, 1.0f); + RandomLauncher::fillBernoulli(LaunchContext::defaultContext(), _rngA, &x0, + 1.0f); + RandomLauncher::fillBernoulli(LaunchContext::defaultContext(), _rngB, &x1, + 1.0f); - ASSERT_TRUE(x0.equalsTo(&x1)); + ASSERT_TRUE(x0.equalsTo(&x1)); - ASSERT_FALSE(x0.equalsTo(nexp0)); - ASSERT_FALSE(x0.equalsTo(nexp1)); - ASSERT_FALSE(x0.equalsTo(nexp2)); + ASSERT_FALSE(x0.equalsTo(nexp0)); + ASSERT_FALSE(x0.equalsTo(nexp1)); + ASSERT_FALSE(x0.equalsTo(nexp2)); } TEST_F(RNGTests, Test_Gaussian_1) { - auto x0 = NDArrayFactory::create('c', {10, 10}); - auto x1 = NDArrayFactory::create('c', {10, 10}); + auto x0 = NDArrayFactory::create('c', {10, 10}); + auto x1 = NDArrayFactory::create('c', {10, 10}); - RandomLauncher::fillGaussian(LaunchContext::defaultContext(), _rngA, &x0, 1.0f, 2.0f); - RandomLauncher::fillGaussian(LaunchContext::defaultContext(), _rngB, &x1, 1.0f, 2.0f); + RandomLauncher::fillGaussian(LaunchContext::defaultContext(), _rngA, &x0, + 1.0f, 2.0f); + RandomLauncher::fillGaussian(LaunchContext::defaultContext(), _rngB, &x1, + 1.0f, 2.0f); - //x0.printIndexedBuffer("x0"); - //x1.printIndexedBuffer("x1"); - ASSERT_TRUE(x0.equalsTo(&x1)); + // x0.printIndexedBuffer("x0"); + // x1.printIndexedBuffer("x1"); + ASSERT_TRUE(x0.equalsTo(&x1)); - ASSERT_FALSE(x0.equalsTo(nexp0)); - ASSERT_FALSE(x0.equalsTo(nexp1)); - ASSERT_FALSE(x0.equalsTo(nexp2)); + ASSERT_FALSE(x0.equalsTo(nexp0)); + ASSERT_FALSE(x0.equalsTo(nexp1)); + ASSERT_FALSE(x0.equalsTo(nexp2)); } TEST_F(RNGTests, Test_Gaussian_21) { - auto x0 = NDArrayFactory::create('c', {1000, 1000}); - auto x1 = NDArrayFactory::create('c', {1000, 1000}); - - RandomLauncher::fillGaussian(LaunchContext::defaultContext(), _rngA, &x0, 0.0f, 1.0f); - RandomLauncher::fillGaussian(LaunchContext::defaultContext(), _rngB, &x1, 0.0f, 1.0f); - -// x0.printIndexedBuffer("x0"); -// x1.printIndexedBuffer("x1"); - ASSERT_TRUE(x0.equalsTo(&x1)); - - ASSERT_FALSE(x0.equalsTo(nexp0)); - ASSERT_FALSE(x0.equalsTo(nexp1)); - ASSERT_FALSE(x0.equalsTo(nexp2)); - sd::ops::moments op; - auto result = op.evaluate({&x0}, {}, {}); - //x0.printIndexedBuffer("X0 Normal"); - //x1.printIndexedBuffer("X1 Normal"); - ASSERT_TRUE(result.status() == Status::OK()); - auto mean = result.at(0); - auto variance = result.at(1); - - // mean->printIndexedBuffer("Mean"); - // variance->printIndexedBuffer("Variance"); - - ASSERT_NEAR(sd::math::nd4j_abs(mean->e(0)), 0.f, 0.2f); - ASSERT_NEAR(variance->e(0), 1.0f, 0.2f); - - + auto x0 = NDArrayFactory::create('c', {1000, 1000}); + auto x1 = NDArrayFactory::create('c', {1000, 1000}); + + RandomLauncher::fillGaussian(LaunchContext::defaultContext(), _rngA, &x0, + 0.0f, 1.0f); + RandomLauncher::fillGaussian(LaunchContext::defaultContext(), _rngB, &x1, + 0.0f, 1.0f); + + // x0.printIndexedBuffer("x0"); + // x1.printIndexedBuffer("x1"); + ASSERT_TRUE(x0.equalsTo(&x1)); + + ASSERT_FALSE(x0.equalsTo(nexp0)); + ASSERT_FALSE(x0.equalsTo(nexp1)); + ASSERT_FALSE(x0.equalsTo(nexp2)); + sd::ops::moments op; + auto result = op.evaluate({&x0}, {}, {}); + // x0.printIndexedBuffer("X0 Normal"); + // x1.printIndexedBuffer("X1 Normal"); + ASSERT_TRUE(result.status() == Status::OK()); + auto mean = result.at(0); + auto variance = result.at(1); + + // mean->printIndexedBuffer("Mean"); + // variance->printIndexedBuffer("Variance"); + + ASSERT_NEAR(sd::math::nd4j_abs(mean.e(0)), 0.f, 0.2f); + ASSERT_NEAR(variance.e(0), 1.0f, 0.2f); } #ifdef DEBUG_BUILD TEST_F(RNGTests, Test_Gaussian_22) { - auto x0 = NDArrayFactory::create('c', {1000, 800}); - auto x1 = NDArrayFactory::create('c', {1000, 800}); - - RandomLauncher::fillGaussian(sd::LaunchContext::defaultContext(), _rngA, &x0, 0.0f, 1.0f); - RandomLauncher::fillGaussian(LaunchContext::defaultContext(), _rngB, &x1, 0.0f, 1.0f); - - //x0.printIndexedBuffer("x0"); - //x1.printIndexedBuffer("x1"); - ASSERT_TRUE(x0.equalsTo(&x1)); - - ASSERT_FALSE(x0.equalsTo(nexp0)); - ASSERT_FALSE(x0.equalsTo(nexp1)); - ASSERT_FALSE(x0.equalsTo(nexp2)); - sd::ops::moments op; - auto result = op.evaluate({&x0}, {}, {}); - //x0.printIndexedBuffer("X0 Normal"); - //x1.printIndexedBuffer("X1 Normal"); - ASSERT_TRUE(result.status() == Status::OK()); - auto mean0 = result.at(0); - auto variance0 = result.at(1); - - //mean0->printIndexedBuffer("Mean"); - //variance0->printIndexedBuffer("Variance"); - ASSERT_NEAR(sd::math::nd4j_abs(mean0->e(0)), 0.f, 1.0e-3f); - ASSERT_NEAR(variance0->e(0), 1.0f, 1.e-3f); - + auto x0 = NDArrayFactory::create('c', {1000, 800}); + auto x1 = NDArrayFactory::create('c', {1000, 800}); + + RandomLauncher::fillGaussian(sd::LaunchContext::defaultContext(), _rngA, &x0, + 0.0f, 1.0f); + RandomLauncher::fillGaussian(LaunchContext::defaultContext(), _rngB, &x1, + 0.0f, 1.0f); + + // x0.printIndexedBuffer("x0"); + // x1.printIndexedBuffer("x1"); + ASSERT_TRUE(x0.equalsTo(&x1)); + + ASSERT_FALSE(x0.equalsTo(nexp0)); + ASSERT_FALSE(x0.equalsTo(nexp1)); + ASSERT_FALSE(x0.equalsTo(nexp2)); + sd::ops::moments op; + auto result = op.evaluate({&x0}, {}, {}); + // x0.printIndexedBuffer("X0 Normal"); + // x1.printIndexedBuffer("X1 Normal"); + ASSERT_TRUE(result.status() == Status::OK()); + auto mean0 = result.at(0); + auto variance0 = result.at(1); + + // mean0->printIndexedBuffer("Mean"); + // variance0->printIndexedBuffer("Variance"); + ASSERT_NEAR(sd::math::nd4j_abs(mean0.e(0)), 0.f, 1.0e-3f); + ASSERT_NEAR(variance0.e(0), 1.0f, 1.e-3f); } TEST_F(RNGTests, Test_Gaussian_3) { - auto x0 = NDArrayFactory::create('c', {800000}); - - RandomLauncher::fillGaussian(LaunchContext::defaultContext(), _rngA, &x0, 0.0, 1.0); - - auto mean = x0.meanNumber(); //.e(0); - auto stdev = x0.varianceNumber(sd::variance::SummaryStatsStandardDeviation, false);//.e(0); - auto meanExp = NDArrayFactory::create(0.); - auto devExp = NDArrayFactory::create(1.); - ASSERT_TRUE(meanExp.equalsTo(mean, 1.e-3)); - ASSERT_TRUE(devExp.equalsTo(stdev, 1.e-3)); + auto x0 = NDArrayFactory::create('c', {800000}); + + RandomLauncher::fillGaussian(LaunchContext::defaultContext(), _rngA, &x0, 0.0, + 1.0); + + auto mean = x0.meanNumber(); //.e(0); + auto stdev = x0.varianceNumber(sd::variance::SummaryStatsStandardDeviation, + false); //.e(0); + auto meanExp = NDArrayFactory::create(0.); + auto devExp = NDArrayFactory::create(1.); + ASSERT_TRUE(meanExp.equalsTo(mean, 1.e-3)); + ASSERT_TRUE(devExp.equalsTo(stdev, 1.e-3)); } TEST_F(RNGTests, Test_LogNormal_1) { - auto x0 = NDArrayFactory::create('c', {10, 10}); - auto x1 = NDArrayFactory::create('c', {10, 10}); + auto x0 = NDArrayFactory::create('c', {10, 10}); + auto x1 = NDArrayFactory::create('c', {10, 10}); - RandomLauncher::fillLogNormal(LaunchContext::defaultContext(), _rngA, &x0, 1.0f, 2.0f); - RandomLauncher::fillLogNormal(LaunchContext::defaultContext(), _rngB, &x1, 1.0f, 2.0f); + RandomLauncher::fillLogNormal(LaunchContext::defaultContext(), _rngA, &x0, + 1.0f, 2.0f); + RandomLauncher::fillLogNormal(LaunchContext::defaultContext(), _rngB, &x1, + 1.0f, 2.0f); - ASSERT_TRUE(x0.equalsTo(&x1)); + ASSERT_TRUE(x0.equalsTo(&x1)); - ASSERT_FALSE(x0.equalsTo(nexp0)); - ASSERT_FALSE(x0.equalsTo(nexp1)); - ASSERT_FALSE(x0.equalsTo(nexp2)); + ASSERT_FALSE(x0.equalsTo(nexp0)); + ASSERT_FALSE(x0.equalsTo(nexp1)); + ASSERT_FALSE(x0.equalsTo(nexp2)); } TEST_F(RNGTests, Test_Truncated_1) { - auto x0 = NDArrayFactory::create('c', {10, 10}); - auto x1 = NDArrayFactory::create('c', {10, 10}); - - RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngA, &x0, 1.0f, 2.0f); - RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngB, &x1, 1.0f, 2.0f); - - ASSERT_TRUE(x0.equalsTo(&x1)); - - ASSERT_FALSE(x0.equalsTo(nexp0)); - ASSERT_FALSE(x0.equalsTo(nexp1)); - ASSERT_FALSE(x0.equalsTo(nexp2)); - - /* Check up distribution */ - auto mean = x1.reduceNumber(reduce::Mean); - // mean.printIndexedBuffer("Mean 1.0"); - auto sumA = x1 - mean; //.reduceNumber(reduce::Sum); - - auto deviation = x1.varianceNumber(variance::SummaryStatsStandardDeviation, false); - //deviation /= (double)x1.lengthOf(); - // deviation.printIndexedBuffer("Deviation should be 2.0"); - // x1.printIndexedBuffer("Distribution TN"); - + auto x0 = NDArrayFactory::create('c', {10, 10}); + auto x1 = NDArrayFactory::create('c', {10, 10}); + + RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngA, + &x0, 1.0f, 2.0f); + RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngB, + &x1, 1.0f, 2.0f); + + ASSERT_TRUE(x0.equalsTo(&x1)); + + ASSERT_FALSE(x0.equalsTo(nexp0)); + ASSERT_FALSE(x0.equalsTo(nexp1)); + ASSERT_FALSE(x0.equalsTo(nexp2)); + + /* Check up distribution */ + auto mean = x1.reduceNumber(reduce::Mean); + // mean.printIndexedBuffer("Mean 1.0"); + auto sumA = x1 - mean; //.reduceNumber(reduce::Sum); + + auto deviation = + x1.varianceNumber(variance::SummaryStatsStandardDeviation, false); + // deviation /= (double)x1.lengthOf(); + // deviation.printIndexedBuffer("Deviation should be 2.0"); + // x1.printIndexedBuffer("Distribution TN"); } TEST_F(RNGTests, Test_Truncated_2) { - auto x0 = NDArrayFactory::create('c', {1000, 1000}); - auto x1 = NDArrayFactory::create('c', {1000, 1000}); - - RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngA, &x0, 1.0f, 2.0f); - RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngB, &x1, 1.0f, 2.0f); - - ASSERT_TRUE(x0.equalsTo(&x1)); - - //ASSERT_FALSE(x0.equalsTo(nexp0)); - //ASSERT_FALSE(x0.equalsTo(nexp1)); - //ASSERT_FALSE(x0.equalsTo(nexp2)); - - /* Check up distribution */ - auto mean = x1.reduceNumber(reduce::Mean); - // mean.printIndexedBuffer("Mean 1.0"); - //auto sumA = x1 - mean; //.reduceNumber(reduce::Sum); - - auto deviation = x1.varianceNumber(variance::SummaryStatsStandardDeviation, false); - //deviation /= (double)x1.lengthOf(); - // deviation.printIndexedBuffer("Deviation should be 2.0"); - //x1.printIndexedBuffer("Distribution TN"); - ASSERT_NEAR(mean.e(0), 1.f, 0.5); - ASSERT_NEAR(deviation.e(0), 2.f, 0.5); + auto x0 = NDArrayFactory::create('c', {1000, 1000}); + auto x1 = NDArrayFactory::create('c', {1000, 1000}); + + RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngA, + &x0, 1.0f, 2.0f); + RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngB, + &x1, 1.0f, 2.0f); + + ASSERT_TRUE(x0.equalsTo(&x1)); + + // ASSERT_FALSE(x0.equalsTo(nexp0)); + // ASSERT_FALSE(x0.equalsTo(nexp1)); + // ASSERT_FALSE(x0.equalsTo(nexp2)); + + /* Check up distribution */ + auto mean = x1.reduceNumber(reduce::Mean); + // mean.printIndexedBuffer("Mean 1.0"); + // auto sumA = x1 - mean; //.reduceNumber(reduce::Sum); + + auto deviation = + x1.varianceNumber(variance::SummaryStatsStandardDeviation, false); + // deviation /= (double)x1.lengthOf(); + // deviation.printIndexedBuffer("Deviation should be 2.0"); + // x1.printIndexedBuffer("Distribution TN"); + ASSERT_NEAR(mean.e(0), 1.f, 0.5); + ASSERT_NEAR(deviation.e(0), 2.f, 0.5); } TEST_F(RNGTests, Test_Truncated_21) { - auto x0 = NDArrayFactory::create('c', {100, 100}); - auto x1 = NDArrayFactory::create('c', {100, 100}); - - RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngA, &x0, 1.0f, 2.0f); - RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngB, &x1, 1.0f, 2.0f); - - ASSERT_TRUE(x0.equalsTo(&x1)); - - auto mean0 = x0.reduceNumber(reduce::Mean); - // mean0.printIndexedBuffer("0Mean 1.0"); - //auto sumA = x1 - mean; //.reduceNumber(reduce::Sum); - - auto deviation0 = x0.varianceNumber(variance::SummaryStatsStandardDeviation, false); - // deviation0.printIndexedBuffer("0Deviation should be 2.0"); - - //ASSERT_FALSE(x0.equalsTo(nexp0)); - //ASSERT_FALSE(x0.equalsTo(nexp1)); - //ASSERT_FALSE(x0.equalsTo(nexp2)); - - /* Check up distribution */ - auto mean = x1.reduceNumber(reduce::Mean); - // mean.printIndexedBuffer("Mean 1.0"); - //auto sumA = x1 - mean; //.reduceNumber(reduce::Sum); - - auto deviation = x1.varianceNumber(variance::SummaryStatsStandardDeviation, false); - //deviation /= (double)x1.lengthOf(); - // deviation.printIndexedBuffer("Deviation should be 2.0"); - //x1.printIndexedBuffer("Distribution TN"); - ASSERT_NEAR(mean.e(0), 1.f, 0.002); - ASSERT_NEAR(deviation.e(0), 2.f, 0.5); - sd::ops::moments op; - auto result = op.evaluate({&x0}, {}, {}, {}, {}, false); - - // result.at(0)->printBuffer("MEAN"); - // result.at(1)->printBuffer("VARIANCE"); - - sd::ops::reduce_min minOp; - sd::ops::reduce_max maxOp; - - auto minRes = minOp.evaluate({&x1}, {}, {}, {}); - auto maxRes = maxOp.evaluate({&x0}, {}, {}, {}); - // minRes->at(0)->printBuffer("MIN for Truncated"); - // maxRes->at(0)->printBuffer("MAX for Truncated"); + auto x0 = NDArrayFactory::create('c', {100, 100}); + auto x1 = NDArrayFactory::create('c', {100, 100}); + + RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngA, + &x0, 1.0f, 2.0f); + RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngB, + &x1, 1.0f, 2.0f); + + ASSERT_TRUE(x0.equalsTo(&x1)); + + auto mean0 = x0.reduceNumber(reduce::Mean); + // mean0.printIndexedBuffer("0Mean 1.0"); + // auto sumA = x1 - mean; //.reduceNumber(reduce::Sum); + + auto deviation0 = + x0.varianceNumber(variance::SummaryStatsStandardDeviation, false); + // deviation0.printIndexedBuffer("0Deviation should be 2.0"); + + // ASSERT_FALSE(x0.equalsTo(nexp0)); + // ASSERT_FALSE(x0.equalsTo(nexp1)); + // ASSERT_FALSE(x0.equalsTo(nexp2)); + + /* Check up distribution */ + auto mean = x1.reduceNumber(reduce::Mean); + // mean.printIndexedBuffer("Mean 1.0"); + // auto sumA = x1 - mean; //.reduceNumber(reduce::Sum); + + auto deviation = + x1.varianceNumber(variance::SummaryStatsStandardDeviation, false); + // deviation /= (double)x1.lengthOf(); + // deviation.printIndexedBuffer("Deviation should be 2.0"); + // x1.printIndexedBuffer("Distribution TN"); + ASSERT_NEAR(mean.e(0), 1.f, 0.002); + ASSERT_NEAR(deviation.e(0), 2.f, 0.5); + sd::ops::moments op; + auto result = op.evaluate({&x0}, {}, {}, {}, {}, false); + + // result.at(0)->printBuffer("MEAN"); + // result.at(1)->printBuffer("VARIANCE"); + + sd::ops::reduce_min minOp; + sd::ops::reduce_max maxOp; + + auto minRes = minOp.evaluate({&x1}, {}, {}, {}); + auto maxRes = maxOp.evaluate({&x0}, {}, {}, {}); + // minRes->at(0)->printBuffer("MIN for Truncated"); + // maxRes->at(0)->printBuffer("MAX for Truncated"); } TEST_F(RNGTests, Test_Truncated_22) { - auto x0 = NDArrayFactory::create('c', {100, 100}); - auto x1 = NDArrayFactory::create('c', {100, 100}); - - RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngA, &x0, 2.0f, 4.0f); - RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngB, &x1, 2.0f, 4.0f); - - ASSERT_TRUE(x0.equalsTo(&x1)); - - auto mean0 = x0.reduceNumber(reduce::Mean); - // mean0.printIndexedBuffer("0Mean 2.0"); - //auto sumA = x1 - mean; //.reduceNumber(reduce::Sum); - - auto deviation0 = x0.varianceNumber(variance::SummaryStatsStandardDeviation, false); - // deviation0.printIndexedBuffer("0Deviation should be 4.0"); - - //ASSERT_FALSE(x0.equalsTo(nexp0)); - //ASSERT_FALSE(x0.equalsTo(nexp1)); - //ASSERT_FALSE(x0.equalsTo(nexp2)); - - /* Check up distribution */ - auto mean = x1.reduceNumber(reduce::Mean); - // mean.printIndexedBuffer("Mean 2.0"); - //auto sumA = x1 - mean; //.reduceNumber(reduce::Sum); - - auto deviation = x1.varianceNumber(variance::SummaryStatsStandardDeviation, false); - //deviation /= (double)x1.lengthOf(); - // deviation.printIndexedBuffer("Deviation should be 4.0"); - //x1.printIndexedBuffer("Distribution TN"); - ASSERT_NEAR(mean.e(0), 2.f, 0.01); - ASSERT_NEAR(deviation.e(0), 4.f, 0.52); - sd::ops::moments op; - auto result = op.evaluate({&x0}, {}, {}, {}, {}, false); - - sd::ops::reduce_min minOp; - sd::ops::reduce_max maxOp; - - auto minRes = minOp.evaluate({&x1}, {}, {}, {}); - auto maxRes = maxOp.evaluate({&x0}, {}, {}, {}); - // minRes->at(0)->printBuffer("MIN for Truncated2"); - // maxRes->at(0)->printBuffer("MAX for Truncated2"); - + auto x0 = NDArrayFactory::create('c', {100, 100}); + auto x1 = NDArrayFactory::create('c', {100, 100}); + + RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngA, + &x0, 2.0f, 4.0f); + RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngB, + &x1, 2.0f, 4.0f); + + ASSERT_TRUE(x0.equalsTo(&x1)); + + auto mean0 = x0.reduceNumber(reduce::Mean); + // mean0.printIndexedBuffer("0Mean 2.0"); + // auto sumA = x1 - mean; //.reduceNumber(reduce::Sum); + + auto deviation0 = + x0.varianceNumber(variance::SummaryStatsStandardDeviation, false); + // deviation0.printIndexedBuffer("0Deviation should be 4.0"); + + // ASSERT_FALSE(x0.equalsTo(nexp0)); + // ASSERT_FALSE(x0.equalsTo(nexp1)); + // ASSERT_FALSE(x0.equalsTo(nexp2)); + + /* Check up distribution */ + auto mean = x1.reduceNumber(reduce::Mean); + // mean.printIndexedBuffer("Mean 2.0"); + // auto sumA = x1 - mean; //.reduceNumber(reduce::Sum); + + auto deviation = + x1.varianceNumber(variance::SummaryStatsStandardDeviation, false); + // deviation /= (double)x1.lengthOf(); + // deviation.printIndexedBuffer("Deviation should be 4.0"); + // x1.printIndexedBuffer("Distribution TN"); + ASSERT_NEAR(mean.e(0), 2.f, 0.01); + ASSERT_NEAR(deviation.e(0), 4.f, 0.52); + sd::ops::moments op; + auto result = op.evaluate({&x0}, {}, {}, {}, {}, false); + + sd::ops::reduce_min minOp; + sd::ops::reduce_max maxOp; + + auto minRes = minOp.evaluate({&x1}, {}, {}, {}); + auto maxRes = maxOp.evaluate({&x0}, {}, {}, {}); + // minRes->at(0)->printBuffer("MIN for Truncated2"); + // maxRes->at(0)->printBuffer("MAX for Truncated2"); } TEST_F(RNGTests, Test_Truncated_23) { - auto x0 = NDArrayFactory::create('c', {1000, 1000}); - auto x1 = NDArrayFactory::create('c', {1000, 1000}); - - RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngA, &x0, 0.0f, 1.0f); - RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngB, &x1, 0.0f, 1.0f); - - ASSERT_TRUE(x0.equalsTo(&x1)); - - auto mean0 = x0.reduceNumber(reduce::Mean); - // mean0.printIndexedBuffer("0Mean 2.0"); - //auto sumA = x1 - mean; //.reduceNumber(reduce::Sum); - - auto deviation0 = x0.varianceNumber(variance::SummaryStatsStandardDeviation, false); - // deviation0.printIndexedBuffer("0Deviation should be 4.0"); - - //ASSERT_FALSE(x0.equalsTo(nexp0)); - //ASSERT_FALSE(x0.equalsTo(nexp1)); - //ASSERT_FALSE(x0.equalsTo(nexp2)); - - /* Check up distribution */ - auto mean = x1.reduceNumber(reduce::Mean); - // mean.printIndexedBuffer("Mean 2.0"); - //auto sumA = x1 - mean; //.reduceNumber(reduce::Sum); - - auto deviation = x1.varianceNumber(variance::SummaryStatsStandardDeviation, false); - //deviation /= (double)x1.lengthOf(); - // deviation.printIndexedBuffer("Deviation should be 4.0"); - //x1.printIndexedBuffer("Distribution TN"); - ASSERT_NEAR(mean.e(0), 0.f, 0.01); - ASSERT_NEAR(deviation.e(0), 1.f, 0.5); - sd::ops::moments op; - auto result = op.evaluate({&x0}); - // result->at(0)->printBuffer("MEAN"); - // result->at(1)->printBuffer("VARIANCE"); - sd::ops::reduce_min minOp; - sd::ops::reduce_max maxOp; - - auto minRes = minOp.evaluate({&x1}, {}, {}, {}); - auto maxRes = maxOp.evaluate({&x0}, {}, {}, {}); - // minRes->at(0)->printBuffer("MIN for Truncated3"); - // maxRes->at(0)->printBuffer("MAX for Truncated3"); - + auto x0 = NDArrayFactory::create('c', {1000, 1000}); + auto x1 = NDArrayFactory::create('c', {1000, 1000}); + + RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngA, + &x0, 0.0f, 1.0f); + RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngB, + &x1, 0.0f, 1.0f); + + ASSERT_TRUE(x0.equalsTo(&x1)); + + auto mean0 = x0.reduceNumber(reduce::Mean); + // mean0.printIndexedBuffer("0Mean 2.0"); + // auto sumA = x1 - mean; //.reduceNumber(reduce::Sum); + + auto deviation0 = + x0.varianceNumber(variance::SummaryStatsStandardDeviation, false); + // deviation0.printIndexedBuffer("0Deviation should be 4.0"); + + // ASSERT_FALSE(x0.equalsTo(nexp0)); + // ASSERT_FALSE(x0.equalsTo(nexp1)); + // ASSERT_FALSE(x0.equalsTo(nexp2)); + + /* Check up distribution */ + auto mean = x1.reduceNumber(reduce::Mean); + // mean.printIndexedBuffer("Mean 2.0"); + // auto sumA = x1 - mean; //.reduceNumber(reduce::Sum); + + auto deviation = + x1.varianceNumber(variance::SummaryStatsStandardDeviation, false); + // deviation /= (double)x1.lengthOf(); + // deviation.printIndexedBuffer("Deviation should be 4.0"); + // x1.printIndexedBuffer("Distribution TN"); + ASSERT_NEAR(mean.e(0), 0.f, 0.01); + ASSERT_NEAR(deviation.e(0), 1.f, 0.5); + sd::ops::moments op; + auto result = op.evaluate({&x0}); + // result->at(0)->printBuffer("MEAN"); + // result->at(1)->printBuffer("VARIANCE"); + sd::ops::reduce_min minOp; + sd::ops::reduce_max maxOp; + + auto minRes = minOp.evaluate({&x1}, {}, {}, {}); + auto maxRes = maxOp.evaluate({&x0}, {}, {}, {}); + // minRes->at(0)->printBuffer("MIN for Truncated3"); + // maxRes->at(0)->printBuffer("MAX for Truncated3"); } TEST_F(RNGTests, Test_Truncated_3) { - auto x0 = NDArrayFactory::create('c', {2000, 2000}); - auto x1 = NDArrayFactory::create('c', {2000, 2000}); + auto x0 = NDArrayFactory::create('c', {2000, 2000}); + auto x1 = NDArrayFactory::create('c', {2000, 2000}); - RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngA, &x0, 1.0f, 2.0f); - RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngB, &x1, 1.0f, 2.0f); + RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngA, + &x0, 1.0f, 2.0f); + RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngB, + &x1, 1.0f, 2.0f); - ASSERT_TRUE(x0.equalsTo(&x1)); + ASSERT_TRUE(x0.equalsTo(&x1)); - // Check up distribution - auto mean = x1.reduceNumber(reduce::Mean); - // mean.printIndexedBuffer("Mean 1.0"); - //auto sumA = x1 - mean; //.reduceNumber(reduce::Sum); + // Check up distribution + auto mean = x1.reduceNumber(reduce::Mean); + // mean.printIndexedBuffer("Mean 1.0"); + // auto sumA = x1 - mean; //.reduceNumber(reduce::Sum); - auto deviation = x1.varianceNumber(variance::SummaryStatsStandardDeviation, false); - ASSERT_NEAR(mean.e(0), 1.f, 0.001); - ASSERT_NEAR(deviation.e(0), 2.f, 0.3); + auto deviation = + x1.varianceNumber(variance::SummaryStatsStandardDeviation, false); + ASSERT_NEAR(mean.e(0), 1.f, 0.001); + ASSERT_NEAR(deviation.e(0), 2.f, 0.3); } #endif TEST_F(RNGTests, Test_Binomial_1) { - auto x0 = NDArrayFactory::create('c', {10, 10}); - auto x1 = NDArrayFactory::create('c', {10, 10}); + auto x0 = NDArrayFactory::create('c', {10, 10}); + auto x1 = NDArrayFactory::create('c', {10, 10}); - RandomLauncher::fillBinomial(LaunchContext::defaultContext(), _rngA, &x0, 3, 2.0f); - RandomLauncher::fillBinomial(LaunchContext::defaultContext(), _rngB, &x1, 3, 2.0f); + RandomLauncher::fillBinomial(LaunchContext::defaultContext(), _rngA, &x0, 3, + 2.0f); + RandomLauncher::fillBinomial(LaunchContext::defaultContext(), _rngB, &x1, 3, + 2.0f); - ASSERT_TRUE(x0.equalsTo(&x1)); + ASSERT_TRUE(x0.equalsTo(&x1)); - //nexp2->printIndexedBuffer("nexp2"); - //x0.printIndexedBuffer("x0"); + // nexp2->printIndexedBuffer("nexp2"); + // x0.printIndexedBuffer("x0"); - ASSERT_FALSE(x0.equalsTo(nexp0)); - ASSERT_FALSE(x0.equalsTo(nexp1)); - ASSERT_FALSE(x0.equalsTo(nexp2)); + ASSERT_FALSE(x0.equalsTo(nexp0)); + ASSERT_FALSE(x0.equalsTo(nexp1)); + ASSERT_FALSE(x0.equalsTo(nexp2)); } - TEST_F(RNGTests, Test_Uniform_2) { - auto input = NDArrayFactory::create('c', {1, 2}, {10, 10}); - auto x1 = NDArrayFactory::create('c', {10, 10}); - - RandomLauncher::fillUniform(LaunchContext::defaultContext(), _rngB, &x1, 1.0f, 2.0f); + auto input = NDArrayFactory::create('c', {1, 2}, {10, 10}); + auto x1 = NDArrayFactory::create('c', {10, 10}); - auto op = new sd::ops::LegacyRandomOp(0); - auto result = op->execute(_rngA, {&input}, {1.0f, 2.0f}, {}); - - ASSERT_EQ(Status::OK(), result.status()); + RandomLauncher::fillUniform(LaunchContext::defaultContext(), _rngB, &x1, 1.0f, + 2.0f); - auto z = result.at(0); + sd::ops::LegacyRandomOp op(0); + auto result = + op.execute(_rngA, {&input}, {1.0f, 2.0f}, {}, {sd::DataType::FLOAT32}); - ASSERT_TRUE(x1.isSameShape(z)); - ASSERT_TRUE(x1.equalsTo(z)); + ASSERT_EQ(Status::OK(), result.status()); - delete op; + auto z = result.at(0); + ASSERT_TRUE(x1.isSameShape(z)); + ASSERT_TRUE(x1.equalsTo(z)); } TEST_F(RNGTests, Test_Uniform_SGA_3) { - //auto input = NDArrayFactory::create('c', {1, 2}, {10, 10}); - auto x1 = NDArrayFactory::create('c', {10, 10}); - - RandomLauncher::fillUniform(LaunchContext::defaultContext(), _rngB, &x1, -sd::DataTypeUtils::template max(), sd::DataTypeUtils::template max()); - auto minimumU = x1.reduceNumber(reduce::AMin); - minimumU.printBuffer("\nMinimum"); + // auto input = NDArrayFactory::create('c', {1, 2}, {10, 10}); + auto x1 = NDArrayFactory::create('c', {10, 10}); + + RandomLauncher::fillUniform(LaunchContext::defaultContext(), _rngB, &x1, + -sd::DataTypeUtils::template max(), + sd::DataTypeUtils::template max()); + auto minimumU = x1.reduceNumber(reduce::AMin); + minimumU.printBuffer("\nMinimum"); } TEST_F(RNGTests, Test_Gaussian_2) { - auto input = NDArrayFactory::create('c', {1, 2}, {10, 10}); - auto x1 = NDArrayFactory::create('c', {10, 10}); - - RandomLauncher::fillGaussian(LaunchContext::defaultContext(), _rngB, &x1, 1.0f, 2.0f); - - auto op = new sd::ops::LegacyRandomOp(random::GaussianDistribution); - auto result = op->execute(_rngA, {&input}, {1.0f, 2.0f}, {}); + auto input = NDArrayFactory::create('c', {1, 2}, {10, 10}); + auto x1 = NDArrayFactory::create('c', {10, 10}); - ASSERT_EQ(Status::OK(), result.status()); + RandomLauncher::fillGaussian(LaunchContext::defaultContext(), _rngB, &x1, + 1.0f, 2.0f); - auto z = result.at(0); + sd::ops::LegacyRandomOp op(random::GaussianDistribution); + auto result = op.execute(_rngA, {&input}, {1.0f, 2.0f}, {}); - ASSERT_TRUE(x1.isSameShape(z)); - ASSERT_TRUE(x1.equalsTo(z)); + ASSERT_EQ(Status::OK(), result.status()); - delete op; + auto z = result.at(0); + ASSERT_TRUE(x1.isSameShape(z)); + ASSERT_TRUE(x1.equalsTo(z)); } TEST_F(RNGTests, Test_LogNorm_2) { - auto input = NDArrayFactory::create('c', {1, 2}, {10, 10}); - auto x1 = NDArrayFactory::create('c', {10, 10}); + auto input = NDArrayFactory::create('c', {1, 2}, {10, 10}); + auto x1 = NDArrayFactory::create('c', {10, 10}); - RandomLauncher::fillLogNormal(LaunchContext::defaultContext(), _rngB, &x1, 1.0f, 2.0f); + RandomLauncher::fillLogNormal(LaunchContext::defaultContext(), _rngB, &x1, + 1.0f, 2.0f); - auto op = new sd::ops::LegacyRandomOp(random::LogNormalDistribution); - auto result = op->execute(_rngA, {&input}, {1.0f, 2.0f}, {}); + sd::ops::LegacyRandomOp op(random::LogNormalDistribution); + auto result = op.execute(_rngA, {&input}, {1.0f, 2.0f}, {}); - ASSERT_EQ(Status::OK(), result.status()); + ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); - - ASSERT_TRUE(x1.isSameShape(z)); - ASSERT_TRUE(x1.equalsTo(z)); - - delete op; + auto z = result.at(0); + ASSERT_TRUE(x1.isSameShape(z)); + ASSERT_TRUE(x1.equalsTo(z)); } TEST_F(RNGTests, Test_TruncatedNorm_2) { - auto input = NDArrayFactory::create('c', {1, 2}, {10, 10}); - auto x1 = NDArrayFactory::create('c', {10, 10}); + auto input = NDArrayFactory::create('c', {1, 2}, {10, 10}); + auto x1 = NDArrayFactory::create('c', {10, 10}); - RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngB, &x1, 1.0f, 2.0f); + RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngB, + &x1, 1.0f, 2.0f); - auto op = new sd::ops::LegacyRandomOp(random::TruncatedNormalDistribution); - auto result = op->execute(_rngA, {&input}, {1.0f, 2.0f}, {}); - - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::LegacyRandomOp op(random::TruncatedNormalDistribution); + auto result = op.execute(_rngA, {&input}, {1.0f, 2.0f}, {}); - auto z = result.at(0); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(x1.isSameShape(z)); - ASSERT_TRUE(x1.equalsTo(z)); - delete op; + auto z = result.at(0); + ASSERT_TRUE(x1.isSameShape(z)); + ASSERT_TRUE(x1.equalsTo(z)); } - TEST_F(RNGTests, Test_Binomial_2) { - auto input = NDArrayFactory::create('c', {1, 2}, {10, 10}); - auto x1 = NDArrayFactory::create('c', {10, 10}); - - RandomLauncher::fillBinomial(LaunchContext::defaultContext(), _rngB, &x1, 3, 0.5f); - - auto op = new sd::ops::LegacyRandomOp(random::BinomialDistributionEx); - auto result = op->execute(_rngA, {&input}, {0.5f}, {3}); + auto input = NDArrayFactory::create('c', {1, 2}, {10, 10}); + auto x1 = NDArrayFactory::create('c', {10, 10}); - ASSERT_EQ(Status::OK(), result.status()); + RandomLauncher::fillBinomial(LaunchContext::defaultContext(), _rngB, &x1, 3, + 0.5f); - auto z = result.at(0); + sd::ops::LegacyRandomOp op(random::BinomialDistributionEx); + auto result = op.execute(_rngA, {&input}, {0.5f}, {3}); - ASSERT_TRUE(x1.isSameShape(z)); - ASSERT_TRUE(x1.equalsTo(z)); + ASSERT_EQ(Status::OK(), result.status()); - delete op; + auto z = result.at(0); + ASSERT_TRUE(x1.isSameShape(z)); + ASSERT_TRUE(x1.equalsTo(z)); } - TEST_F(RNGTests, Test_Bernoulli_2) { - auto input = NDArrayFactory::create('c', {1, 2}, {10, 10}); - auto x1 = NDArrayFactory::create('c', {10, 10}); - - RandomLauncher::fillBernoulli(LaunchContext::defaultContext(), _rngB, &x1, 0.5f); + auto input = NDArrayFactory::create('c', {1, 2}, {10, 10}); + auto x1 = NDArrayFactory::create('c', {10, 10}); - auto op = new sd::ops::LegacyRandomOp(random::BernoulliDistribution); - auto result = op->execute(_rngA, {&input}, {0.5f}, {}); + RandomLauncher::fillBernoulli(LaunchContext::defaultContext(), _rngB, &x1, + 0.5f); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); + sd::ops::LegacyRandomOp op(random::BernoulliDistribution); + auto result = op.execute(_rngA, {&input}, {0.5f}, {}); - ASSERT_TRUE(x1.isSameShape(z)); - ASSERT_TRUE(x1.equalsTo(z)); + ASSERT_EQ(Status::OK(), result.status()); - delete op; + auto z = result.at(0); + ASSERT_TRUE(x1.isSameShape(z)); + ASSERT_TRUE(x1.equalsTo(z)); } TEST_F(RNGTests, Test_GaussianDistribution_1) { - auto x = NDArrayFactory::create('c', {2}, {10, 10}); - auto exp0 = NDArrayFactory::create('c', {10, 10}); + auto x = NDArrayFactory::create('c', {2}, {10, 10}); + auto exp0 = NDArrayFactory::create('c', {10, 10}); + sd::ops::random_normal op; + auto result = op.evaluate({&x}, {0.0, 1.0f}, {}); + ASSERT_EQ(Status::OK(), result.status()); - sd::ops::random_normal op; - auto result = op.evaluate({&x}, {0.0, 1.0f}, {}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - ASSERT_TRUE(exp0.isSameShape(z)); - ASSERT_FALSE(exp0.equalsTo(z)); - - - ASSERT_FALSE(nexp0->equalsTo(z)); - ASSERT_FALSE(nexp1->equalsTo(z)); - ASSERT_FALSE(nexp2->equalsTo(z)); - + auto z = result.at(0); + ASSERT_TRUE(exp0.isSameShape(z)); + ASSERT_FALSE(exp0.equalsTo(z)); + ASSERT_FALSE(nexp0->equalsTo(z)); + ASSERT_FALSE(nexp1->equalsTo(z)); + ASSERT_FALSE(nexp2->equalsTo(z)); } TEST_F(RNGTests, Test_BernoulliDistribution_1) { - auto x = NDArrayFactory::create('c', {2}, {10, 10}); - auto exp0 = NDArrayFactory::create('c', {10, 10}); + auto x = NDArrayFactory::create('c', {2}, {10, 10}); + auto exp0 = NDArrayFactory::create('c', {10, 10}); + sd::ops::random_bernoulli op; + auto result = op.evaluate({&x}, {0.5f}, {}); + ASSERT_EQ(Status::OK(), result.status()); - sd::ops::random_bernoulli op; - auto result = op.evaluate({&x}, {0.5f}, {}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - - ASSERT_FALSE(exp0.equalsTo(z)); - - ASSERT_FALSE(nexp0->equalsTo(z)); - ASSERT_FALSE(nexp1->equalsTo(z)); - ASSERT_FALSE(nexp2->equalsTo(z)); + auto z = result.at(0); + ASSERT_FALSE(exp0.equalsTo(z)); + ASSERT_FALSE(nexp0->equalsTo(z)); + ASSERT_FALSE(nexp1->equalsTo(z)); + ASSERT_FALSE(nexp2->equalsTo(z)); } - TEST_F(RNGTests, Test_ExponentialDistribution_1) { - auto x = NDArrayFactory::create('c', {2}, {10, 10}); - auto exp0 = NDArrayFactory::create('c', {10, 10}); - - - sd::ops::random_exponential op; - auto result = op.evaluate({&x}, {0.25f}, {0}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - ASSERT_TRUE(exp0.isSameShape(z)); - ASSERT_FALSE(exp0.equalsTo(z)); - // - z->printBuffer("\nExponential1"); - auto mean = z->reduceNumber(reduce::Mean); - auto variance = z->varianceNumber(variance::SummaryStatsVariance, false); - mean.printBuffer("Mean for exponential with param 0.25 (4 exp) is"); - variance.printBuffer("Variance for exponential with param 0.25 (16 exp) is"); - ASSERT_FALSE(nexp0->equalsTo(z)); - ASSERT_FALSE(nexp1->equalsTo(z)); - ASSERT_FALSE(nexp2->equalsTo(z)); - -// delete result; + auto x = NDArrayFactory::create('c', {2}, {10, 10}); + auto exp0 = NDArrayFactory::create('c', {10, 10}); + + sd::ops::random_exponential op; + auto result = op.evaluate({&x}, {0.25f}, {0}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + ASSERT_TRUE(exp0.isSameShape(z)); + ASSERT_FALSE(exp0.equalsTo(z)); + // + // z->printBuffer("\nExponential1"); + auto mean = z.reduceNumber(reduce::Mean); + auto variance = z.varianceNumber(variance::SummaryStatsVariance, false); + mean.printBuffer("Mean for exponential with param 0.25 (4 exp) is"); + variance.printBuffer("Variance for exponential with param 0.25 (16 exp) is"); + ASSERT_FALSE(nexp0->equalsTo(z)); + ASSERT_FALSE(nexp1->equalsTo(z)); + ASSERT_FALSE(nexp2->equalsTo(z)); + + // delete result; } TEST_F(RNGTests, Test_ExponentialDistribution_1_SGA) { - auto x = NDArrayFactory::create('c', {2}, {10, 10}); - auto exp0 = NDArrayFactory::create('c', {10, 10}); - - - sd::ops::random_exponential op; - auto result = op.evaluate({&x}, {1.f}, {0}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - ASSERT_TRUE(exp0.isSameShape(z)); - ASSERT_FALSE(exp0.equalsTo(z)); - // - z->printBuffer("\nExponential2"); - auto mean = z->reduceNumber(reduce::Mean); - auto variance = z->varianceNumber(variance::SummaryStatsVariance, false); - mean.printBuffer("Mean for exponential with param 1.0 (1 exp) is"); - variance.printBuffer("Variance for exponential with param 1. (1 exp) is"); - ASSERT_FALSE(nexp0->equalsTo(z)); - ASSERT_FALSE(nexp1->equalsTo(z)); - ASSERT_FALSE(nexp2->equalsTo(z)); - - + auto x = NDArrayFactory::create('c', {2}, {10, 10}); + auto exp0 = NDArrayFactory::create('c', {10, 10}); + + sd::ops::random_exponential op; + auto result = op.evaluate({&x}, {1.f}, {0}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + ASSERT_TRUE(exp0.isSameShape(z)); + ASSERT_FALSE(exp0.equalsTo(z)); + // + // z->printBuffer("\nExponential2"); + auto mean = z.reduceNumber(reduce::Mean); + auto variance = z.varianceNumber(variance::SummaryStatsVariance, false); + mean.printBuffer("Mean for exponential with param 1.0 (1 exp) is"); + variance.printBuffer("Variance for exponential with param 1. (1 exp) is"); + ASSERT_FALSE(nexp0->equalsTo(z)); + ASSERT_FALSE(nexp1->equalsTo(z)); + ASSERT_FALSE(nexp2->equalsTo(z)); } TEST_F(RNGTests, Test_ExponentialDistribution_2_SGA) { - auto x = NDArrayFactory::create('c', {2}, {10, 10}); - auto exp0 = NDArrayFactory::create('c', {10, 10}); - RandomGenerator oc(2716049175077475646L, -6182841917129177862L); - - sd::ops::random_exponential op; - RandomLauncher::fillExponential(x.getContext(), oc, &exp0, 2.f); - auto result = op.evaluate({&x}, {1.f}, {0}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - ASSERT_TRUE(exp0.isSameShape(z)); - ASSERT_FALSE(exp0.equalsTo(z)); - // -// z->printBuffer("\nExponential2+"); - auto mean = z->reduceNumber(reduce::Mean); - auto variance = z->varianceNumber(variance::SummaryStatsVariance, false); - mean.printBuffer("Mean for exponential with param 1.0 (1 exp) is"); - variance.printBuffer("Variance for exponential with param 1. (1 exp) is"); - ASSERT_FALSE(nexp0->equalsTo(z)); - ASSERT_FALSE(nexp1->equalsTo(z)); - ASSERT_FALSE(nexp2->equalsTo(z)); - mean = exp0.reduceNumber(reduce::Mean); - variance = exp0.varianceNumber(variance::SummaryStatsVariance, false); - mean.printBuffer("Mean for exponential with param 2.0 (1/2 exp) is"); - variance.printBuffer("Variance for exponential with param 2. (1/2 exp) is"); + auto x = NDArrayFactory::create('c', {2}, {10, 10}); + auto exp0 = NDArrayFactory::create('c', {10, 10}); + RandomGenerator oc(2716049175077475646L, -6182841917129177862L); + + sd::ops::random_exponential op; + RandomLauncher::fillExponential(x.getContext(), oc, &exp0, 2.f); + auto result = op.evaluate({&x}, {1.f}, {0}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + ASSERT_TRUE(exp0.isSameShape(z)); + ASSERT_FALSE(exp0.equalsTo(z)); + // + // z->printBuffer("\nExponential2+"); + auto mean = z.reduceNumber(reduce::Mean); + auto variance = z.varianceNumber(variance::SummaryStatsVariance, false); + mean.printBuffer("Mean for exponential with param 1.0 (1 exp) is"); + variance.printBuffer("Variance for exponential with param 1. (1 exp) is"); + ASSERT_FALSE(nexp0->equalsTo(z)); + ASSERT_FALSE(nexp1->equalsTo(z)); + ASSERT_FALSE(nexp2->equalsTo(z)); + mean = exp0.reduceNumber(reduce::Mean); + variance = exp0.varianceNumber(variance::SummaryStatsVariance, false); + mean.printBuffer("Mean for exponential with param 2.0 (1/2 exp) is"); + variance.printBuffer("Variance for exponential with param 2. (1/2 exp) is"); } TEST_F(RNGTests, Test_ExponentialDistribution_3_SGA) { - auto x = NDArrayFactory::create('c', {2}, {1000, 1000}); - auto exp0 = NDArrayFactory::create('c', {1000, 1000}); - RandomGenerator oc(2716049175077475646L, -6182841917129177862L); - auto expMean = NDArrayFactory::create(0.5f); - auto expVar = NDArrayFactory::create(0.25f); - sd::ops::random_exponential op; - RandomLauncher::fillExponential(exp0.getContext(), oc, &exp0, 2.f); - - auto result = op.evaluate({&x}, {1.}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - //ASSERT_TRUE(exp0.isSameShape(z)); - //ASSERT_FALSE(exp0.equalsTo(z)); - // -// z->printBuffer("\nExponential2+"); - auto mean = z->reduceNumber(reduce::Mean); - auto variance = z->varianceNumber(variance::SummaryStatsVariance, false); - mean.printBuffer("Mean"); - variance.printBuffer("Variance"); - ASSERT_NEAR(mean.e(0), 1.f, 1.e-2f); - ASSERT_NEAR(variance.e(0), 1.f, 1.e-2f); -// mean.printBuffer("Mean for exponential with param 1.0 (1 exp) is"); -// variance.printBuffer("Variance for exponential with param 1. (1 exp) is"); -// ASSERT_FALSE(nexp0->equalsTo(z)); -// ASSERT_FALSE(nexp1->equalsTo(z)); -// ASSERT_FALSE(nexp2->equalsTo(z)); - mean = exp0.reduceNumber(reduce::Mean); - variance = exp0.varianceNumber(variance::SummaryStatsVariance, false); - mean.printBuffer("Mean for exponential with param 2.0 (1/2 exp) is"); - variance.printBuffer("Variance for exponential with param 2. (1/4 exp) is"); - ASSERT_TRUE(mean.equalsTo(expMean, 1.e-3)); - ASSERT_TRUE(variance.equalsTo(expVar, 1.e-3)); - RandomLauncher::fillExponential(exp0.getContext(), oc, &exp0, 1.f); - mean = exp0.reduceNumber(reduce::Mean); - variance = exp0.varianceNumber(variance::SummaryStatsVariance, false); - mean.printBuffer("Mean for exponential with param 1.0 (1 exp) is"); - variance.printBuffer("Variance for exponential with param 1.0 (1 exp) is"); + auto x = NDArrayFactory::create('c', {2}, {1000, 1000}); + auto exp0 = NDArrayFactory::create('c', {1000, 1000}); + RandomGenerator oc(2716049175077475646L, -6182841917129177862L); + auto expMean = NDArrayFactory::create(0.5f); + auto expVar = NDArrayFactory::create(0.25f); + sd::ops::random_exponential op; + RandomLauncher::fillExponential(exp0.getContext(), oc, &exp0, 2.f); + + auto result = op.evaluate({&x}, {1.}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + // ASSERT_TRUE(exp0.isSameShape(z)); + // ASSERT_FALSE(exp0.equalsTo(z)); + // + // z->printBuffer("\nExponential2+"); + auto mean = z.reduceNumber(reduce::Mean); + auto variance = z.varianceNumber(variance::SummaryStatsVariance, false); + mean.printBuffer("Mean"); + variance.printBuffer("Variance"); + ASSERT_NEAR(mean.e(0), 1.f, 1.e-2f); + ASSERT_NEAR(variance.e(0), 1.f, 1.e-2f); + // mean.printBuffer("Mean for exponential with param 1.0 (1 exp) is"); + // variance.printBuffer("Variance for exponential with param 1. (1 exp) + // is"); ASSERT_FALSE(nexp0->equalsTo(z)); + // ASSERT_FALSE(nexp1->equalsTo(z)); + // ASSERT_FALSE(nexp2->equalsTo(z)); + mean = exp0.reduceNumber(reduce::Mean); + variance = exp0.varianceNumber(variance::SummaryStatsVariance, false); + mean.printBuffer("Mean for exponential with param 2.0 (1/2 exp) is"); + variance.printBuffer("Variance for exponential with param 2. (1/4 exp) is"); + ASSERT_TRUE(mean.equalsTo(expMean, 1.e-3)); + ASSERT_TRUE(variance.equalsTo(expVar, 1.e-3)); + RandomLauncher::fillExponential(exp0.getContext(), oc, &exp0, 1.f); + mean = exp0.reduceNumber(reduce::Mean); + variance = exp0.varianceNumber(variance::SummaryStatsVariance, false); + mean.printBuffer("Mean for exponential with param 1.0 (1 exp) is"); + variance.printBuffer("Variance for exponential with param 1.0 (1 exp) is"); } TEST_F(RNGTests, Test_ExponentialDistribution_2) { - auto x = NDArrayFactory::create('c', {2}, {10, 10}); - auto y = NDArrayFactory::create('c', {10, 10}); - auto exp0 = NDArrayFactory::create('c', {10, 10}); - - y.assign(1.0); - - - sd::ops::random_exponential op; - auto result = op.evaluate({&x, &y}, {0.25f}, {0}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - ASSERT_TRUE(exp0.isSameShape(z)); - ASSERT_FALSE(exp0.equalsTo(z)); + auto x = NDArrayFactory::create('c', {2}, {10, 10}); + auto y = NDArrayFactory::create('c', {10, 10}); + auto exp0 = NDArrayFactory::create('c', {10, 10}); + y.assign(1.0); - ASSERT_FALSE(nexp0->equalsTo(z)); - ASSERT_FALSE(nexp1->equalsTo(z)); - ASSERT_FALSE(nexp2->equalsTo(z)); + sd::ops::random_exponential op; + auto result = op.evaluate({&x, &y}, {0.25f}, {0}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp0.isSameShape(z)); + ASSERT_FALSE(exp0.equalsTo(z)); + ASSERT_FALSE(nexp0->equalsTo(z)); + ASSERT_FALSE(nexp1->equalsTo(z)); + ASSERT_FALSE(nexp2->equalsTo(z)); } TEST_F(RNGTests, Test_PoissonDistribution_1) { - auto x = NDArrayFactory::create('c', {1}, {10}); - auto la = NDArrayFactory::create('c', {2, 3}); - auto exp0 = NDArrayFactory::create('c', {10, 2, 3}); - - la.linspace(1.0); - - - sd::ops::random_poisson op; - auto result = op.evaluate({&x, &la}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); + auto x = NDArrayFactory::create('c', {1}, {10}); + auto la = NDArrayFactory::create('c', {2, 3}); + auto exp0 = NDArrayFactory::create('c', {10, 2, 3}); - auto z = result.at(0); -// z->printIndexedBuffer("Poisson distribution"); - ASSERT_TRUE(exp0.isSameShape(z)); - ASSERT_FALSE(exp0.equalsTo(z)); + la.linspace(1.0); + sd::ops::random_poisson op; + auto result = op.evaluate({&x, &la}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + // z->printIndexedBuffer("Poisson distribution"); + ASSERT_TRUE(exp0.isSameShape(z)); + ASSERT_FALSE(exp0.equalsTo(z)); } TEST_F(RNGTests, Test_GammaDistribution_1) { - auto x = NDArrayFactory::create('c', {1}, {10}); - auto al = NDArrayFactory::create('c', {2, 3}); - auto exp0 = NDArrayFactory::create('c', {10, 2, 3}); + auto x = NDArrayFactory::create('c', {1}, {10}); + auto al = NDArrayFactory::create('c', {2, 3}); + auto exp0 = NDArrayFactory::create('c', {10, 2, 3}); - al.linspace(1.0); - - - sd::ops::random_gamma op; - auto result = op.evaluate({&x, &al}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); -// z->printIndexedBuffer("Gamma distribution"); - ASSERT_TRUE(exp0.isSameShape(z)); - ASSERT_FALSE(exp0.equalsTo(z)); + al.linspace(1.0); + sd::ops::random_gamma op; + auto result = op.evaluate({&x, &al}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + // z->printIndexedBuffer("Gamma distribution"); + ASSERT_TRUE(exp0.isSameShape(z)); + ASSERT_FALSE(exp0.equalsTo(z)); } TEST_F(RNGTests, Test_GammaDistribution_2) { - auto x = NDArrayFactory::create('c', {1}, {10}); - auto al = NDArrayFactory::create('c', {2, 3}); - auto be = NDArrayFactory::create('c', {2, 3}); - auto exp0 = NDArrayFactory::create('c', {10, 2, 3}); - - al.linspace(1.0); - be.assign(1.0); - - sd::ops::random_gamma op; - auto result = op.evaluate({&x, &al, &be}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); -// z->printIndexedBuffer("Gamma distribution"); - ASSERT_TRUE(exp0.isSameShape(z)); - ASSERT_FALSE(exp0.equalsTo(z)); + auto x = NDArrayFactory::create('c', {1}, {10}); + auto al = NDArrayFactory::create('c', {2, 3}); + auto be = NDArrayFactory::create('c', {2, 3}); + auto exp0 = NDArrayFactory::create('c', {10, 2, 3}); + + al.linspace(1.0); + be.assign(1.0); + + sd::ops::random_gamma op; + auto result = op.evaluate({&x, &al, &be}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + // z->printIndexedBuffer("Gamma distribution"); + ASSERT_TRUE(exp0.isSameShape(z)); + ASSERT_FALSE(exp0.equalsTo(z)); } TEST_F(RNGTests, Test_GammaDistribution_3) { - auto x = NDArrayFactory::create('c', {1}, {10}); - auto al = NDArrayFactory::create('c', {3, 1}); - auto be = NDArrayFactory::create('c', {1, 2}); - auto exp0 = NDArrayFactory::create('c', {10, 3, 2}); - - al.linspace(1.0); - be.assign(2.0); - - sd::ops::random_gamma op; - auto result = op.evaluate({&x, &al, &be}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); -// z->printIndexedBuffer("Gamma distribution"); - ASSERT_TRUE(exp0.isSameShape(z)); - ASSERT_FALSE(exp0.equalsTo(z)); - - + auto x = NDArrayFactory::create('c', {1}, {10}); + auto al = NDArrayFactory::create('c', {3, 1}); + auto be = NDArrayFactory::create('c', {1, 2}); + auto exp0 = NDArrayFactory::create('c', {10, 3, 2}); + + al.linspace(1.0); + be.assign(2.0); + + sd::ops::random_gamma op; + auto result = op.evaluate({&x, &al, &be}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + // z->printIndexedBuffer("Gamma distribution"); + ASSERT_TRUE(exp0.isSameShape(z)); + ASSERT_FALSE(exp0.equalsTo(z)); } TEST_F(RNGTests, Test_GammaDistribution_4) { @@ -1057,12 +1062,12 @@ TEST_F(RNGTests, Test_GammaDistribution_4) { ASSERT_FALSE(exp0.equalsTo(z)); sd::ops::reduce_mean testOps1; sd::ops::reduce_variance testOps2; - auto testRes1 = testOps1.evaluate({z}); - auto testRes2 = testOps2.evaluate({z}); + auto testRes1 = testOps1.evaluate({&z}); + auto testRes2 = testOps2.evaluate({&z}); // testRes1[0]->printBuffer("Mean (expected 1.0)"); // testRes2[0]->printBuffer("Variance (expected 0.5)"); - ASSERT_NEAR(testRes1[0]->t(0), 1.0f, 0.01); - ASSERT_NEAR(testRes2[0]->t(0), 0.5f, 0.02); + ASSERT_NEAR(testRes1[0].t(0), 1.0f, 0.01); + ASSERT_NEAR(testRes2[0].t(0), 0.5f, 0.02); } TEST_F(RNGTests, Test_GammaDistribution_5) { @@ -1085,29 +1090,27 @@ TEST_F(RNGTests, Test_GammaDistribution_5) { // z->printIndexedBuffer("Gamma distributed"); sd::ops::reduce_mean testOps1; sd::ops::reduce_variance testOps2; - auto testRes1 = testOps1.evaluate({z}); - auto testRes2 = testOps2.evaluate({z}); + auto testRes1 = testOps1.evaluate({&z}); + auto testRes2 = testOps2.evaluate({&z}); // testRes1[0]->printBuffer("Mean (expected 0.1)"); // testRes2[0]->printBuffer("Variance (expected 0.05)"); - ASSERT_NEAR(testRes1[0]->t(0), 0.1f, 0.02); - ASSERT_NEAR(testRes2[0]->t(0), 0.05f, 0.02); + ASSERT_NEAR(testRes1[0].t(0), 0.1f, 0.02); + ASSERT_NEAR(testRes2[0].t(0), 0.05f, 0.02); } TEST_F(RNGTests, Test_UniformDistribution_04) { - auto x = NDArrayFactory::create('c', {1}, {10}); - auto al = NDArrayFactory::create(1); - auto be = NDArrayFactory::create(20); - auto exp0 = NDArrayFactory::create('c', {10}); - - - sd::ops::randomuniform op; - auto result = op.evaluate({&x, &al, &be}, {}, {DataType::INT32}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - ASSERT_TRUE(exp0.isSameShape(z)); - ASSERT_FALSE(exp0.equalsTo(z)); - + auto x = NDArrayFactory::create('c', {1}, {10}); + auto al = NDArrayFactory::create(1); + auto be = NDArrayFactory::create(20); + auto exp0 = NDArrayFactory::create('c', {10}); + + sd::ops::randomuniform op; + auto result = op.evaluate({&x, &al, &be}, {}, {DataType::INT32}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + ASSERT_TRUE(exp0.isSameShape(z)); + ASSERT_FALSE(exp0.equalsTo(z)); } TEST_F(RNGTests, Test_UniformDistribution_05) { @@ -1126,341 +1129,383 @@ TEST_F(RNGTests, Test_UniformDistribution_05) { ASSERT_FALSE(exp0.equalsTo(z)); sd::ops::reduce_max checkOp; - auto checkResult = checkOp.evaluate({z}); - checkResult[0]->printIndexedBuffer("Max on uniform with 0 to 1 on 100M cases is"); + auto checkResult = checkOp.evaluate({&z}); + checkResult[0].printIndexedBuffer("Max on uniform with 0 to 1 on 100M cases is"); } namespace sd { - namespace tests { - static void fillList(Nd4jLong seed, int numberOfArrays, std::vector &shape, std::vector &list, sd::graph::RandomGenerator *rng) { - rng->setSeed((int) seed); - - for (int i = 0; i < numberOfArrays; i++) { - auto arrayI = NDArrayFactory::create(shape); - auto arrayR = NDArrayFactory::create_('c', shape); - auto min = NDArrayFactory::create(0.0); - auto max = NDArrayFactory::create(1.0); - sd::ops::randomuniform op; - op.execute(*rng, {&arrayI, &min, &max}, {arrayR}, {}, {DataType::DOUBLE}, {}, {}, false); - - list.emplace_back(arrayR); - } - }; - } -} +namespace tests { +static void fillList(Nd4jLong seed, int numberOfArrays, + std::vector& shape, std::vector& list, + sd::graph::RandomGenerator* rng) { + rng->setSeed((int)seed); + + for (int i = 0; i < numberOfArrays; i++) { + auto arrayI = NDArrayFactory::create(shape); + auto arrayR = NDArrayFactory::create_('c', shape); + auto min = NDArrayFactory::create(0.0); + auto max = NDArrayFactory::create(1.0); + sd::ops::randomuniform op; + op.execute(*rng, {&arrayI, &min, &max}, {arrayR}, {}, {DataType::DOUBLE}, + {}, {}, false); + + list.emplace_back(arrayR); + } +}; +} // namespace tests +} // namespace sd TEST_F(RNGTests, Test_Reproducibility_1) { - Nd4jLong seed = 123; + Nd4jLong seed = 123; - std::vector shape = {32, 3, 28, 28}; - sd::graph::RandomGenerator rng; + std::vector shape = {32, 3, 28, 28}; + sd::graph::RandomGenerator rng; - std::vector expList; - sd::tests::fillList(seed, 10, shape, expList, &rng); + std::vector expList; + sd::tests::fillList(seed, 10, shape, expList, &rng); - for (int e = 0; e < 2; e++) { - std::vector trialList; - sd::tests::fillList(seed, 10, shape, trialList, &rng); + for (int e = 0; e < 2; e++) { + std::vector trialList; + sd::tests::fillList(seed, 10, shape, trialList, &rng); - for (int a = 0; a < expList.size(); a++) { - auto arrayE = expList[a]; - auto arrayT = trialList[a]; + for (int a = 0; a < expList.size(); a++) { + auto arrayE = expList[a]; + auto arrayT = trialList[a]; - bool t = arrayE->equalsTo(arrayT); - if (!t) { - // nd4j_printf("Failed at iteration [%i] for array [%i]\n", e, a); - ASSERT_TRUE(false); - } + bool t = arrayE->equalsTo(arrayT); + if (!t) { + // nd4j_printf("Failed at iteration [%i] for array [%i]\n", e, a); + ASSERT_TRUE(false); + } - delete arrayT; - } + delete arrayT; } + } - for (auto v: expList) - delete v; + for (auto v : expList) delete v; } #ifndef DEBUG_BUILD TEST_F(RNGTests, Test_Reproducibility_2) { - Nd4jLong seed = 123; + Nd4jLong seed = 123; - std::vector shape = {32, 3, 64, 64}; - sd::graph::RandomGenerator rng; + std::vector shape = {32, 3, 64, 64}; + sd::graph::RandomGenerator rng; - std::vector expList; - sd::tests::fillList(seed, 10, shape, expList, &rng); + std::vector expList; + sd::tests::fillList(seed, 10, shape, expList, &rng); - for (int e = 0; e < 2; e++) { - std::vector trialList; - sd::tests::fillList(seed, 10, shape, trialList, &rng); + for (int e = 0; e < 2; e++) { + std::vector trialList; + sd::tests::fillList(seed, 10, shape, trialList, &rng); - for (int a = 0; a < expList.size(); a++) { - auto arrayE = expList[a]; - auto arrayT = trialList[a]; + for (int a = 0; a < expList.size(); a++) { + auto arrayE = expList[a]; + auto arrayT = trialList[a]; - bool t = arrayE->equalsTo(arrayT); - if (!t) { - // nd4j_printf("Failed at iteration [%i] for array [%i]\n", e, a); + bool t = arrayE->equalsTo(arrayT); + if (!t) { + // nd4j_printf("Failed at iteration [%i] for array [%i]\n", e, a); - for (Nd4jLong f = 0; f < arrayE->lengthOf(); f++) { - double x = arrayE->e(f); - double y = arrayT->e(f); + for (Nd4jLong f = 0; f < arrayE->lengthOf(); f++) { + double x = arrayE->e(f); + double y = arrayT->e(f); - if (sd::math::nd4j_re(x, y) > 0.1) { - // nd4j_printf("E[%lld] %f != T[%lld] %f\n", (long long) f, (float) x, (long long) f, (float) y); - throw std::runtime_error("boom"); - } - } + if (sd::math::nd4j_re(x, y) > 0.1) { + // nd4j_printf("E[%lld] %f != T[%lld] %f\n", (long long) f, (float) + // x, (long long) f, (float) y); + throw std::runtime_error("boom"); + } + } - // just breaker, since test failed - ASSERT_TRUE(false); - } + // just breaker, since test failed + ASSERT_TRUE(false); + } - delete arrayT; - } + delete arrayT; } + } - for (auto v: expList) - delete v; + for (auto v : expList) delete v; } TEST_F(RNGTests, Test_Uniform_4) { - auto x1 = NDArrayFactory::create('c', {1000000}); + auto x1 = NDArrayFactory::create('c', {1000000}); - RandomLauncher::fillUniform(LaunchContext::defaultContext(), _rngB, &x1, 1.0, 2.0); + RandomLauncher::fillUniform(LaunchContext::defaultContext(), _rngB, &x1, 1.0, + 2.0); - /* Check up distribution */ - auto mean = x1.reduceNumber(reduce::Mean); - // mean.printIndexedBuffer("Mean should be 1.5"); - auto sumA = x1 - mean; //.reduceNumber(reduce::Sum); + /* Check up distribution */ + auto mean = x1.reduceNumber(reduce::Mean); + // mean.printIndexedBuffer("Mean should be 1.5"); + auto sumA = x1 - mean; //.reduceNumber(reduce::Sum); - auto deviation = x1.varianceNumber(variance::SummaryStatsVariance, false); - //deviation /= (double)x1.lengthOf(); - // deviation.printIndexedBuffer("Deviation should be 1/12 (0.083333)"); + auto deviation = x1.varianceNumber(variance::SummaryStatsVariance, false); + // deviation /= (double)x1.lengthOf(); + // deviation.printIndexedBuffer("Deviation should be 1/12 (0.083333)"); - ASSERT_NEAR(mean.e(0), 1.5, 1e-3); - ASSERT_NEAR(1/12., deviation.e(0), 1e-3); + ASSERT_NEAR(mean.e(0), 1.5, 1e-3); + ASSERT_NEAR(1 / 12., deviation.e(0), 1e-3); } #endif TEST_F(RNGTests, test_choice_1) { - const auto x = NDArrayFactory::linspace(0, 10, 11); - const auto prob = NDArrayFactory::valueOf({11}, 1.0/11, 'c'); - auto z = NDArrayFactory::create('c', {1000}); - - RandomGenerator rng(119, 256); - NativeOpExecutioner::execRandom(sd::LaunchContext ::defaultContext(), random::Choice, &rng, x->buffer(), x->shapeInfo(), x->specialBuffer(), x->specialShapeInfo(), prob->buffer(), prob->shapeInfo(), prob->specialBuffer(), prob->specialShapeInfo(), z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr); - - // z.printIndexedBuffer("z"); - - delete x; - delete prob; + const auto x = NDArrayFactory::linspace(0, 10, 11); + const auto prob = NDArrayFactory::valueOf({11}, 1.0 / 11, 'c'); + auto z = NDArrayFactory::create('c', {1000}); + + RandomGenerator rng(119, 256); + NativeOpExecutioner::execRandom( + sd::LaunchContext ::defaultContext(), random::Choice, &rng, x->buffer(), + x->shapeInfo(), x->specialBuffer(), x->specialShapeInfo(), prob->buffer(), + prob->shapeInfo(), prob->specialBuffer(), prob->specialShapeInfo(), + z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), + nullptr); + + // z.printIndexedBuffer("z"); + + delete x; + delete prob; } TEST_F(RNGTests, test_uniform_119) { - auto x = NDArrayFactory::create('c', {2}, {1, 5}); - auto z = NDArrayFactory::create('c', {1, 5}); - + auto x = NDArrayFactory::create('c', {2}, {1, 5}); + auto z = NDArrayFactory::create('c', {1, 5}); - sd::ops::randomuniform op; - auto status = op.execute({&x}, {&z}, {1.0, 2.0}, {}, {}); - ASSERT_EQ(Status::OK(), status); + sd::ops::randomuniform op; + auto status = op.execute({&x}, {&z}, {1.0, 2.0}, {}, {}); + ASSERT_EQ(Status::OK(), status); } TEST_F(RNGTests, test_multinomial_1) { - - NDArray probs('f', { 3, 3 }, { 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3 }, sd::DataType::FLOAT32); - NDArray expected('f', { 3, 3 }, { 0., 1, 2, 2, 0, 0, 1, 2, 1 }, sd::DataType::INT64); - NDArray output('f', { 3, 3 }, sd::DataType::INT64); - NDArray samples('f', { 1 }, std::vector({3}), sd::DataType::INT32); - - sd::ops::random_multinomial op; - RandomGenerator rng(1234, 1234); - ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &output }, {}, { 0, INT64}, {}, {}, false) ); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - - NDArray probsZ('c', { 1, 3 }, { 0.3, 0.3, 0.3 }, sd::DataType::FLOAT32); - NDArray expectedZ('c', { 3, 3 }, { 0., 0, 0, 0, 0, 0, 0, 0, 0 }, sd::DataType::INT64); - - auto result = op.evaluate({ &probsZ, &samples }, { }, { 1, INT64 }); - auto outputZ = result.at(0); - - ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(expectedZ.isSameShape(outputZ)); - ASSERT_TRUE(expectedZ.equalsTo(outputZ)); + NDArray probs('f', {3, 3}, {0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3}, + sd::DataType::FLOAT32); + NDArray expected('f', {3, 3}, {0., 1, 2, 2, 0, 0, 1, 2, 1}, + sd::DataType::INT64); + NDArray output('f', {3, 3}, sd::DataType::INT64); + NDArray samples('f', {1}, std::vector({3}), sd::DataType::INT32); + + sd::ops::random_multinomial op; + RandomGenerator rng(1234, 1234); + ASSERT_EQ(Status::OK(), op.execute(rng, {&probs, &samples}, {&output}, {}, + {0}, {}, {INT64}, false)); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + NDArray probsZ('c', {1, 3}, {0.3, 0.3, 0.3}, sd::DataType::FLOAT32); + NDArray expectedZ('c', {3, 3}, {0., 0, 0, 0, 0, 0, 0, 0, 0}, + sd::DataType::INT64); + + auto result = op.evaluate({&probsZ, &samples}, {}, {1}, {}, {INT64}); + auto outputZ = result.at(0); + + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(expectedZ.isSameShape(outputZ)); + ASSERT_TRUE(expectedZ.equalsTo(outputZ)); } TEST_F(RNGTests, test_multinomial_2) { - - NDArray samples('c', { 1 }, std::vector{ 20 }, sd::DataType::INT32); - NDArray probs('c', { 3, 5 }, { 0.2, 0.3, 0.5, 0.3, 0.5, 0.2, 0.5, 0.2, 0.3, 0.35, 0.25, 0.3, 0.25, 0.25, 0.5 }, sd::DataType::FLOAT32); - NDArray expected('c', { 3, 20 }, { 0, 2, 0, 2, 0, 4, 2, 0, 1, 2, 0, 2, 3, 0, 0, 2, 4, 4, 1, 0, 2, 3, 2, 3, 0, 1, 3, 1, 1, 1, 2, 4, 3, 3, 1, 4, 4, 2, 0, 0, 3, 3, 3, 0, 0, 2, 2, 3, 3, 0, 0, 2, 3, 4, 2, 2, 3, 2, 1, 2 }, sd::DataType::INT64); - NDArray output('c', { 3, 20 }, sd::DataType::INT64); - - sd::ops::random_multinomial op; - RandomGenerator rng(1234, 1234); - ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &output }, {}, { 0, INT64 }, {}, {}, false)); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - - NDArray probs2('c', { 5, 3 }, { 0.2, 0.3, 0.5, 0.3, 0.5, 0.2, 0.5, 0.2, 0.3, 0.35, 0.25, 0.3, 0.25, 0.25, 0.5 }, sd::DataType::FLOAT32); - NDArray expected2('c', { 20, 3 }, { 0, 2, 3, 2, 3, 3, 0, 2, 3, 2, 3, 0, 0, 0, 0, 4, 1, 2, 2, 3, 2, 3, 1, 3, 1, 1, 3, 2, 1, 0, 0, 2, 0, 2, 4, 2, 3, 3, 3, 0, 3, 4, 0, 1, 2, 2, 0, 2, 4, 4, 0, 4, 2, 2, 1, 0, 1, 0, 0, 2 }, sd::DataType::INT64); - NDArray output2('c', { 20, 3 }, sd::DataType::INT64); - - rng.setStates(1234, 1234); - ASSERT_EQ(Status::OK(), op.execute(rng, { &probs2, &samples }, { &output2 }, {}, { 1, INT64 }, {}, {}, false)); - ASSERT_TRUE(expected2.isSameShape(output2)); - ASSERT_TRUE(expected2.equalsTo(output2)); + NDArray samples('c', {1}, std::vector{20}, sd::DataType::INT32); + NDArray probs('c', {3, 5}, + {0.2, 0.3, 0.5, 0.3, 0.5, 0.2, 0.5, 0.2, 0.3, 0.35, 0.25, 0.3, + 0.25, 0.25, 0.5}, + sd::DataType::FLOAT32); + NDArray expected('c', {3, 20}, + {0, 2, 0, 2, 0, 4, 2, 0, 1, 2, 0, 2, 3, 0, 0, 2, 4, 4, 1, 0, + 2, 3, 2, 3, 0, 1, 3, 1, 1, 1, 2, 4, 3, 3, 1, 4, 4, 2, 0, 0, + 3, 3, 3, 0, 0, 2, 2, 3, 3, 0, 0, 2, 3, 4, 2, 2, 3, 2, 1, 2}, + sd::DataType::INT64); + NDArray output('c', {3, 20}, sd::DataType::INT64); + + sd::ops::random_multinomial op; + RandomGenerator rng(1234, 1234); + ASSERT_EQ(Status::OK(), op.execute(rng, {&probs, &samples}, {&output}, {}, + {0}, {}, {INT64}, false)); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + NDArray probs2('c', {5, 3}, + {0.2, 0.3, 0.5, 0.3, 0.5, 0.2, 0.5, 0.2, 0.3, 0.35, 0.25, 0.3, + 0.25, 0.25, 0.5}, + sd::DataType::FLOAT32); + NDArray expected2('c', {20, 3}, {0, 2, 3, 2, 3, 3, 0, 2, 3, 2, 3, 0, 0, 0, 0, + 4, 1, 2, 2, 3, 2, 3, 1, 3, 1, 1, 3, 2, 1, 0, + 0, 2, 0, 2, 4, 2, 3, 3, 3, 0, 3, 4, 0, 1, 2, + 2, 0, 2, 4, 4, 0, 4, 2, 2, 1, 0, 1, 0, 0, 2}, + sd::DataType::INT64); + NDArray output2('c', {20, 3}, sd::DataType::INT64); + + rng.setStates(1234, 1234); + ASSERT_EQ(Status::OK(), op.execute(rng, {&probs2, &samples}, {&output2}, {}, + {1}, {}, {INT64}, false)); + ASSERT_TRUE(expected2.isSameShape(output2)); + ASSERT_TRUE(expected2.equalsTo(output2)); } TEST_F(RNGTests, test_multinomial_3) { - - NDArray probs('c', { 4, 3 }, { 0.3, 0.3, 0.4, 0.3, 0.4, 0.3, 0.3, 0.3, 0.4, 0.4, 0.3, 0.3 }, sd::DataType::FLOAT32); - NDArray expected('c', { 4, 5 }, sd::DataType::INT64); - NDArray output('c', { 4, 5 }, sd::DataType::INT64); - NDArray samples('c', { 1 }, std::vector{ 5 }, sd::DataType::INT32); - RandomGenerator rng(1234, 1234); - - sd::ops::random_multinomial op; - - ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &expected }, {}, { 0, INT64 }, {}, {}, false)); - - rng.setStates(1234, 1234); - ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &output }, {}, { 0, INT64 }, {}, {}, false)); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + NDArray probs('c', {4, 3}, + {0.3, 0.3, 0.4, 0.3, 0.4, 0.3, 0.3, 0.3, 0.4, 0.4, 0.3, 0.3}, + sd::DataType::FLOAT32); + NDArray expected('c', {4, 5}, sd::DataType::INT64); + NDArray output('c', {4, 5}, sd::DataType::INT64); + NDArray samples('c', {1}, std::vector{5}, sd::DataType::INT32); + RandomGenerator rng(1234, 1234); + + sd::ops::random_multinomial op; + + ASSERT_EQ(Status::OK(), op.execute(rng, {&probs, &samples}, {&expected}, {}, + {0}, {}, {INT64}, false)); + + rng.setStates(1234, 1234); + ASSERT_EQ(Status::OK(), op.execute(rng, {&probs, &samples}, {&output}, {}, + {0}, {}, {INT64}, false)); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } TEST_F(RNGTests, test_multinomial_4) { - - NDArray probs('c', { 3, 4 }, { 0.3, 0.3, 0.4, 0.3, 0.4, 0.3, 0.3, 0.3, 0.4, 0.4, 0.3, 0.3 }, sd::DataType::FLOAT32); - NDArray expected('c', { 5, 4 }, sd::DataType::INT64); - NDArray output('c', { 5, 4 }, sd::DataType::INT64); - NDArray samples('c', { 1 }, std::vector{ 5 }, sd::DataType::INT32); - - RandomGenerator rng(1234, 1234); - sd::ops::random_multinomial op; - ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &expected }, {}, { 1, INT64 }, {}, {}, false)); - - rng.setStates(1234, 1234); - ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &output }, {}, { 1, INT64 }, {}, {}, false)); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + NDArray probs('c', {3, 4}, + {0.3, 0.3, 0.4, 0.3, 0.4, 0.3, 0.3, 0.3, 0.4, 0.4, 0.3, 0.3}, + sd::DataType::FLOAT32); + NDArray expected('c', {5, 4}, sd::DataType::INT64); + NDArray output('c', {5, 4}, sd::DataType::INT64); + NDArray samples('c', {1}, std::vector{5}, sd::DataType::INT32); + + RandomGenerator rng(1234, 1234); + sd::ops::random_multinomial op; + ASSERT_EQ(Status::OK(), op.execute(rng, {&probs, &samples}, {&expected}, {}, + {1}, {}, {INT64}, false)); + + rng.setStates(1234, 1234); + ASSERT_EQ(Status::OK(), op.execute(rng, {&probs, &samples}, {&output}, {}, + {1}, {}, {INT64}, false)); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } TEST_F(RNGTests, test_multinomial_5) { - // multinomial as binomial if 2 classes used - int batchValue = 1; - int ClassValue = 2; - int Samples = 100000; + // multinomial as binomial if 2 classes used + int batchValue = 1; + int ClassValue = 2; + int Samples = 100000; - NDArray samples('c', { 1 }, std::vector{ 1.*Samples }, sd::DataType::INT32); + NDArray samples('c', {1}, std::vector{1. * Samples}, + sd::DataType::INT32); - NDArray probs('c', { ClassValue, batchValue }, { 1.0, 1.0 }, sd::DataType::FLOAT32); + NDArray probs('c', {ClassValue, batchValue}, {1.0, 1.0}, + sd::DataType::FLOAT32); - sd::ops::random_multinomial op; + sd::ops::random_multinomial op; - NDArray output('c', { Samples, batchValue }, sd::DataType::INT64); - RandomGenerator rng(1234, 1234); + NDArray output('c', {Samples, batchValue}, sd::DataType::INT64); + RandomGenerator rng(1234, 1234); - ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &output }, {}, { 1 }, {}, {}, false)); - auto deviation = output.varianceNumber(variance::SummaryStatsStandardDeviation, false); - auto mean = output.meanNumber(); - // printf("Var: %f Mean: %f \n", deviation.e(0), mean.e(0)); - // theoretical values for binomial - ASSERT_NEAR(0.5, deviation.e(0), 4e-3); // 1000000 3e-3); - ASSERT_NEAR(0.5, mean.e(0), 4e-3); // 1000000 3e-3); - - for (int i = 0; i < output.lengthOf(); i++) { - auto value = output.e(i); - ASSERT_TRUE(value >= 0 && value < ClassValue); - } + ASSERT_EQ(Status::OK(), op.execute(rng, {&probs, &samples}, {&output}, {}, + {1}, {}, {}, false)); - auto resultR = op.evaluate({ &probs, &samples }, { }, { 1 }); - auto outputR = resultR.at(0); - ASSERT_EQ(Status::OK(), resultR.status()); + auto deviation = + output.varianceNumber(variance::SummaryStatsStandardDeviation, false); + auto mean = output.meanNumber(); + // printf("Var: %f Mean: %f \n", deviation.e(0), mean.e(0)); + // theoretical values for binomial + ASSERT_NEAR(0.5, deviation.e(0), 4e-3); // 1000000 3e-3); + ASSERT_NEAR(0.5, mean.e(0), 4e-3); // 1000000 3e-3); - deviation = outputR->varianceNumber(variance::SummaryStatsStandardDeviation, false); - mean = outputR->meanNumber(); - // printf("Random seed - Var: %f Mean: %f \n", deviation.e(0), mean.e(0)); - ASSERT_NEAR(0.5, deviation.e(0), 45e-3); // 1000000 35e-3); - ASSERT_NEAR(0.5, mean.e(0), 45e-3); // 1000000 35e-3); - - for (int i = 0; i < outputR->lengthOf(); i++) { - auto value = outputR->e(i); - ASSERT_TRUE(value >= 0 && value < ClassValue); - } + for (int i = 0; i < output.lengthOf(); i++) { + auto value = output.e(i); + ASSERT_TRUE(value >= 0 && value < ClassValue); + } + auto resultR = op.evaluate({&probs, &samples}, {}, {1}); + auto outputR = resultR.at(0); + ASSERT_EQ(Status::OK(), resultR.status()); + + deviation = + outputR.varianceNumber(variance::SummaryStatsStandardDeviation, false); + mean = outputR.meanNumber(); + // printf("Random seed - Var: %f Mean: %f \n", deviation.e(0), + // mean.e(0)); + ASSERT_NEAR(0.5, deviation.e(0), 45e-3); // 1000000 35e-3); + ASSERT_NEAR(0.5, mean.e(0), 45e-3); // 1000000 35e-3); + + for (int i = 0; i < outputR.lengthOf(); i++) { + auto value = outputR.e(i); + ASSERT_TRUE(value >= 0 && value < ClassValue); + } } - TEST_F(RNGTests, test_multinomial_6) { + int batchValue = 1; + int ClassValue = 5; + int Samples = 100000; - int batchValue = 1; - int ClassValue = 5; - int Samples = 100000; - - NDArray samples('c', { 1 }, std::vector{ 1. * Samples }, sd::DataType::INT32); + NDArray samples('c', {1}, std::vector{1. * Samples}, + sd::DataType::INT32); - sd::ops::random_multinomial op; - NDArray probExpect('c', { ClassValue }, { 0.058, 0.096, 0.1576, 0.2598, 0.4287 }, sd::DataType::DOUBLE); + sd::ops::random_multinomial op; + NDArray probExpect('c', {ClassValue}, {0.058, 0.096, 0.1576, 0.2598, 0.4287}, + sd::DataType::DOUBLE); - // without seed - NDArray probsR('c', { batchValue, ClassValue }, { 1., 1.5, 2., 2.5, 3. }, sd::DataType::FLOAT32); + // without seed + NDArray probsR('c', {batchValue, ClassValue}, {1., 1.5, 2., 2.5, 3.}, + sd::DataType::FLOAT32); - auto resultR = op.evaluate({ &probsR, &samples }, { }, { 0 }); - auto outputR = resultR.at(0); - ASSERT_EQ(Status::OK(), resultR.status()); + auto resultR = op.evaluate({&probsR, &samples}, {}, {0}); + auto outputR = resultR.at(0); + ASSERT_EQ(Status::OK(), resultR.status()); - NDArray countsR('c', { ClassValue }, { 0., 0, 0, 0, 0 }, sd::DataType::DOUBLE); - - for (int i = 0; i < outputR->lengthOf(); i++) { - auto value = outputR->e(i); - ASSERT_TRUE(value >= 0 && value < ClassValue); - double* z = countsR.bufferAsT(); - z[value] += 1; - } + NDArray countsR('c', {ClassValue}, {0., 0, 0, 0, 0}, sd::DataType::DOUBLE); - for (int i = 0; i < countsR.lengthOf(); i++) { - auto c = countsR.e(i); - auto p = probExpect.e(i); - // printf("Get freq : %f Expect freq: %f \n", c / Samples, p); - ASSERT_NEAR((c / Samples), p, 45e-3); // 1000000 35e-3); - } + for (int i = 0; i < outputR.lengthOf(); i++) { + auto value = outputR.e(i); + ASSERT_TRUE(value >= 0 && value < ClassValue); + double* z = countsR.bufferAsT(); + z[value] += 1; + } - auto deviation = outputR->varianceNumber(variance::SummaryStatsStandardDeviation, false); - auto mean = outputR->meanNumber(); - // printf("Var: %f Mean: %f \n", deviation.e(0), mean.e(0)); - ASSERT_NEAR(1.2175, deviation.e(0), 45e-3); // 1000000 35e-3); - ASSERT_NEAR(2.906, mean.e(0), 45e-3); // 1000000 35e-3); + for (int i = 0; i < countsR.lengthOf(); i++) { + auto c = countsR.e(i); + auto p = probExpect.e(i); + // printf("Get freq : %f Expect freq: %f \n", c / Samples, p); + ASSERT_NEAR((c / Samples), p, 45e-3); // 1000000 35e-3); + } + auto deviation = + outputR.varianceNumber(variance::SummaryStatsStandardDeviation, false); + auto mean = outputR.meanNumber(); + // printf("Var: %f Mean: %f \n", deviation.e(0), mean.e(0)); + ASSERT_NEAR(1.2175, deviation.e(0), 45e-3); // 1000000 35e-3); + ASSERT_NEAR(2.906, mean.e(0), 45e-3); // 1000000 35e-3); - RandomGenerator rng(1234, 1234); - NDArray probs('c', { batchValue, ClassValue }, { 1., 1.5, 2., 2.5, 3. }, sd::DataType::FLOAT32); - NDArray output('c', { batchValue, Samples }, sd::DataType::INT64); + RandomGenerator rng(1234, 1234); + NDArray probs('c', {batchValue, ClassValue}, {1., 1.5, 2., 2.5, 3.}, + sd::DataType::FLOAT32); + NDArray output('c', {batchValue, Samples}, sd::DataType::INT64); - ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &output }, {}, { 0, INT64 }, {}, {}, false)); + ASSERT_EQ(Status::OK(), op.execute(rng, {&probs, &samples}, {&output}, {}, + {0}, {}, {INT64}, false)); - NDArray counts('c', { ClassValue }, { 0., 0, 0, 0, 0 }, sd::DataType::DOUBLE); + NDArray counts('c', {ClassValue}, {0., 0, 0, 0, 0}, sd::DataType::DOUBLE); - for (int i = 0; i < output.lengthOf(); i++) { - auto value = output.e(i); - ASSERT_TRUE(value >= 0 && value < ClassValue); - double* z = counts.bufferAsT(); - z[value] += 1; - } + for (int i = 0; i < output.lengthOf(); i++) { + auto value = output.e(i); + ASSERT_TRUE(value >= 0 && value < ClassValue); + double* z = counts.bufferAsT(); + z[value] += 1; + } - for (int i = 0; i < counts.lengthOf(); i++) { - auto c = counts.e(i); - auto p = probExpect.e(i); - // printf("Get freq : %f Expect freq: %f \n", c / Samples, p); - ASSERT_NEAR((c / Samples), p, 4e-3); // 1000000 3e-3); - } + for (int i = 0; i < counts.lengthOf(); i++) { + auto c = counts.e(i); + auto p = probExpect.e(i); + // printf("Get freq : %f Expect freq: %f \n", c / Samples, p); + ASSERT_NEAR((c / Samples), p, 4e-3); // 1000000 3e-3); + } - deviation = output.varianceNumber(variance::SummaryStatsStandardDeviation, false); - mean = output.meanNumber(); - // printf("Var: %f Mean: %f \n", deviation.e(0), mean.e(0)); - ASSERT_NEAR(1.2175, deviation.e(0), 5e-3); // 1000000 3e-3); - ASSERT_NEAR(2.906, mean.e(0), 5e-3); // 1000000 3e-3); + deviation = + output.varianceNumber(variance::SummaryStatsStandardDeviation, false); + mean = output.meanNumber(); + // printf("Var: %f Mean: %f \n", deviation.e(0), mean.e(0)); + ASSERT_NEAR(1.2175, deviation.e(0), 5e-3); // 1000000 3e-3); + ASSERT_NEAR(2.906, mean.e(0), 5e-3); // 1000000 3e-3); } diff --git a/libnd4j/tests_cpu/layers_tests/ResultSetTests.cpp b/libnd4j/tests_cpu/layers_tests/ResultSetTests.cpp index 4ca8a3806572..ef4706aaaed5 100644 --- a/libnd4j/tests_cpu/layers_tests/ResultSetTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/ResultSetTests.cpp @@ -18,32 +18,31 @@ // Created by raver on 4/18/2019. // -#include "testlayers.h" #include #include #include +#include "testlayers.h" + using namespace sd; using namespace sd::graph; class ResultSetTests : public testing::Test { -public: - + public: }; TEST_F(ResultSetTests, basic_test_1) { - auto x = NDArrayFactory::create('c', {3, 5}); + auto x = NDArrayFactory::create('c', {3, 5}); - auto tensors = x.allTensorsAlongDimension({1}); - ASSERT_EQ(3, tensors.size()); + auto tensors = x.allTensorsAlongDimension({1}); + ASSERT_EQ(3, tensors.size()); - ResultSet set = tensors; - ASSERT_EQ(3, tensors.size()); - ASSERT_EQ(3, set.size()); + ResultSet set = tensors; + ASSERT_EQ(3, tensors.size()); + ASSERT_EQ(3, set.size()); - for (int e = 0; e < set.size(); e++) - ASSERT_EQ(5, set.at(e)->lengthOf()); + for (int e = 0; e < set.size(); e++) ASSERT_EQ(5, set.at(e).lengthOf()); - for (int e = 0; e < tensors.size(); e++) - ASSERT_EQ(5, tensors.at(e)->lengthOf()); + for (int e = 0; e < tensors.size(); e++) + ASSERT_EQ(5, tensors.at(e).lengthOf()); } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/SanityTests.cpp b/libnd4j/tests_cpu/layers_tests/SanityTests.cpp index 7ca6732fe62b..2c076bcd73e5 100644 --- a/libnd4j/tests_cpu/layers_tests/SanityTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/SanityTests.cpp @@ -18,45 +18,37 @@ // Created by raver119 on 13/11/17. // -#include "testlayers.h" #include #include #include +#include "testlayers.h" + using namespace sd; using namespace sd::graph; class SanityTests : public testing::Test { -public: - + public: }; - -TEST_F(SanityTests, VariableSpace_1) { - VariableSpace variableSpace; - variableSpace.putVariable(1, new Variable()); - variableSpace.putVariable(1, 1, new Variable()); - - std::pair pair(1, 2); - variableSpace.putVariable(pair, new Variable()); -} - TEST_F(SanityTests, VariableSpace_2) { - VariableSpace variableSpace; - variableSpace.putVariable(1, new Variable(NDArrayFactory::create_('c', {3, 3}))); - variableSpace.putVariable(1, 1, new Variable(NDArrayFactory::create_('c', {3, 3}))); + VariableSpace variableSpace; + variableSpace.putVariable(1, NDArrayFactory::create('c', {3, 3})); + variableSpace.putVariable({1, 1}, NDArrayFactory::create('c', {3, 3})); - std::pair pair(1, 2); - variableSpace.putVariable(pair, new Variable(NDArrayFactory::create_('c', {3, 3}))); + std::pair pair(1, 2); + variableSpace.putVariable(pair, NDArrayFactory::create('c', {3, 3})); } - TEST_F(SanityTests, Graph_1) { - Graph graph; + Graph graph; - graph.getVariableSpace()->putVariable(1, new Variable(NDArrayFactory::create_('c', {3, 3}))); - graph.getVariableSpace()->putVariable(1, 1, new Variable(NDArrayFactory::create_('c', {3, 3}))); + graph.variableSpace().putVariable(1, + NDArrayFactory::create('c', {3, 3})); + graph.variableSpace().putVariable({1, 1}, + NDArrayFactory::create('c', {3, 3})); - std::pair pair(1, 2); - graph.getVariableSpace()->putVariable(pair, new Variable(NDArrayFactory::create_('c', {3, 3}))); + std::pair pair(1, 2); + graph.variableSpace().putVariable(pair, + NDArrayFactory::create('c', {3, 3})); } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/ScalarTests.cpp b/libnd4j/tests_cpu/layers_tests/ScalarTests.cpp index 937ca4675608..fde5a69bbd2b 100644 --- a/libnd4j/tests_cpu/layers_tests/ScalarTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/ScalarTests.cpp @@ -18,219 +18,202 @@ // @author raver119@gmail.com // -#include "testlayers.h" -#include #include -#include #include +#include +#include + +#include "testlayers.h" using namespace sd; using namespace sd::graph; class ScalarTests : public testing::Test { -public: - + public: }; TEST_F(ScalarTests, Test_Create_1) { - auto x = NDArrayFactory::create(2.0f); - - ASSERT_EQ(0, x.rankOf()); - ASSERT_EQ(1, x.lengthOf()); - ASSERT_TRUE(x.isScalar()); - ASSERT_FALSE(x.isVector()); - ASSERT_FALSE(x.isRowVector()); - ASSERT_FALSE(x.isColumnVector()); - ASSERT_FALSE(x.isMatrix()); + auto x = NDArrayFactory::create(2.0f); + + ASSERT_EQ(0, x.rankOf()); + ASSERT_EQ(1, x.lengthOf()); + ASSERT_TRUE(x.isScalar()); + ASSERT_FALSE(x.isVector()); + ASSERT_FALSE(x.isRowVector()); + ASSERT_FALSE(x.isColumnVector()); + ASSERT_FALSE(x.isMatrix()); } TEST_F(ScalarTests, Test_Add_1) { - auto x = NDArrayFactory::create(2.0f); - auto exp = NDArrayFactory::create(5.0f); + auto x = NDArrayFactory::create(2.0f); + auto exp = NDArrayFactory::create(5.0f); - x += 3.0f; + x += 3.0f; - ASSERT_NEAR(5.0f, x.e(0), 1e-5f); - ASSERT_TRUE(exp.isSameShape(&x)); - ASSERT_TRUE(exp.equalsTo(&x)); + ASSERT_NEAR(5.0f, x.e(0), 1e-5f); + ASSERT_TRUE(exp.isSameShape(&x)); + ASSERT_TRUE(exp.equalsTo(&x)); } TEST_F(ScalarTests, Test_Add_2) { - auto x = NDArrayFactory::create(2.0f); - auto y = NDArrayFactory::create(3.0f); - auto exp = NDArrayFactory::create(5.0f); + auto x = NDArrayFactory::create(2.0f); + auto y = NDArrayFactory::create(3.0f); + auto exp = NDArrayFactory::create(5.0f); - x += y; + x += y; - ASSERT_NEAR(5.0f, x.e(0), 1e-5f); - ASSERT_TRUE(exp.isSameShape(&x)); - ASSERT_TRUE(exp.equalsTo(&x)); + ASSERT_NEAR(5.0f, x.e(0), 1e-5f); + ASSERT_TRUE(exp.isSameShape(&x)); + ASSERT_TRUE(exp.equalsTo(&x)); } TEST_F(ScalarTests, Test_Add_3) { - auto x = NDArrayFactory::create('c', {3}, {1, 2, 3}); - auto y = NDArrayFactory::create(3.0f); - auto exp = NDArrayFactory::create('c', {3}, {4, 5, 6}); + auto x = NDArrayFactory::create('c', {3}, {1, 2, 3}); + auto y = NDArrayFactory::create(3.0f); + auto exp = NDArrayFactory::create('c', {3}, {4, 5, 6}); - x += y; + x += y; - ASSERT_TRUE(exp.isSameShape(&x)); + ASSERT_TRUE(exp.isSameShape(&x)); - ASSERT_TRUE(exp.equalsTo(&x)); + ASSERT_TRUE(exp.equalsTo(&x)); } TEST_F(ScalarTests, Test_EQ_1) { - auto x = NDArrayFactory::create(2.0f); - auto y = NDArrayFactory::create(3.0f); + auto x = NDArrayFactory::create(2.0f); + auto y = NDArrayFactory::create(3.0f); - ASSERT_TRUE(y.isSameShape(&x)); - ASSERT_FALSE(y.equalsTo(&x)); + ASSERT_TRUE(y.isSameShape(&x)); + ASSERT_FALSE(y.equalsTo(&x)); } TEST_F(ScalarTests, Test_Concat_1) { - auto t = NDArrayFactory::create(1.0f); - auto u = NDArrayFactory::create(2.0f); - auto v = NDArrayFactory::create(3.0f); - auto exp = NDArrayFactory::create('c', {3}, {1, 2, 3}); - - sd::ops::concat op; - auto result = op.evaluate({&t, &u, &v}, {}, {0}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto t = NDArrayFactory::create(1.0f); + auto u = NDArrayFactory::create(2.0f); + auto v = NDArrayFactory::create(3.0f); + auto exp = NDArrayFactory::create('c', {3}, {1, 2, 3}); - auto z = result.at(0); + sd::ops::concat op; + auto result = op.evaluate({&t, &u, &v}, {}, {0}); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(ScalarTests, Test_Concat_2) { - auto t = NDArrayFactory::create(1.0f); - auto u = NDArrayFactory::create('c', {3}, {2, 3, 4}); - auto v = NDArrayFactory::create(5.0f); - auto exp = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); - - sd::ops::concat op; - auto result = op.evaluate({&t, &u, &v}, {}, {0}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto t = NDArrayFactory::create(1.0f); + auto u = NDArrayFactory::create('c', {3}, {2, 3, 4}); + auto v = NDArrayFactory::create(5.0f); + auto exp = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); - auto z = result.at(0); - // z->printIndexedBuffer(); + sd::ops::concat op; + auto result = op.evaluate({&t, &u, &v}, {}, {0}); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + // z->printIndexedBuffer(); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(ScalarTests, Test_Concat_3) { - auto t = NDArrayFactory::create('c', {3}, {1, 2, 3}); - auto u = NDArrayFactory::create(4.0f); - auto v = NDArrayFactory::create(5.0f); - auto exp = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); - - sd::ops::concat op; - auto result = op.evaluate({&t, &u, &v}, {}, {0}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto t = NDArrayFactory::create('c', {3}, {1, 2, 3}); + auto u = NDArrayFactory::create(4.0f); + auto v = NDArrayFactory::create(5.0f); + auto exp = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); - auto z = result.at(0); + sd::ops::concat op; + auto result = op.evaluate({&t, &u, &v}, {}, {0}); - //z->printShapeInfo("z"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + // z->printShapeInfo("z"); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(ScalarTests, Test_ExpandDims_1) { - auto x = NDArrayFactory::create(2.0f); - auto exp = NDArrayFactory::create('c', {1}, {2.0f}); + auto x = NDArrayFactory::create(2.0f); + auto exp = NDArrayFactory::create('c', {1}, {2.0f}); - sd::ops::expand_dims op; - auto result = op.evaluate({&x}, {}, {0}); + sd::ops::expand_dims op; + auto result = op.evaluate({&x}, {}, {0}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(ScalarTests, Test_Squeeze_1) { - auto x = NDArrayFactory::create(2.0f); - auto exp = NDArrayFactory::create(2.0f); - - sd::ops::squeeze op; - auto result = op.evaluate({&x}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto x = NDArrayFactory::create(2.0f); + auto exp = NDArrayFactory::create(2.0f); - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::squeeze op; + auto result = op.evaluate({&x}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(ScalarTests, Test_Permute_1) { - auto x = NDArrayFactory::create(3.0f); - auto exp = NDArrayFactory::create(3.0f); - - sd::ops::permute op; - auto result = op.evaluate({&x}, {}, {0}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); + auto x = NDArrayFactory::create(3.0f); + auto exp = NDArrayFactory::create(3.0f); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::permute op; + auto result = op.evaluate({&x}, {}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(ScalarTests, Test_Concat_Scalar_1) { - auto t = NDArrayFactory::create('c', {1, 1}, {1.0f}); - auto u = NDArrayFactory::create('c', {1, 1}, {2.0f}); - auto v = NDArrayFactory::create('c', {1, 1}, {3.0f}); - auto w = NDArrayFactory::create('c', {1, 1}, {4.0f}); - auto exp = NDArrayFactory::create('c', {4, 1}, {1, 2, 3, 4}); + auto t = NDArrayFactory::create('c', {1, 1}, {1.0f}); + auto u = NDArrayFactory::create('c', {1, 1}, {2.0f}); + auto v = NDArrayFactory::create('c', {1, 1}, {3.0f}); + auto w = NDArrayFactory::create('c', {1, 1}, {4.0f}); + auto exp = NDArrayFactory::create('c', {4, 1}, {1, 2, 3, 4}); - sd::ops::concat op; - auto result = op.evaluate({&t, &u, &v, &w}, {}, {0}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::concat op; + auto result = op.evaluate({&t, &u, &v, &w}, {}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(ScalarTests, Test_Concat_Scalar_2) { - auto t = NDArrayFactory::create('c', {1, 1}, {1.0f}); - auto u = NDArrayFactory::create('c', {1, 1}, {2.0f}); - auto v = NDArrayFactory::create('c', {1, 1}, {3.0f}); - auto w = NDArrayFactory::create('c', {1, 1}, {4.0f}); - auto exp = NDArrayFactory::create('c', {1, 4}, {1, 2, 3, 4}); - - sd::ops::concat op; - auto result = op.evaluate({&t, &u, &v, &w}, {}, {1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto t = NDArrayFactory::create('c', {1, 1}, {1.0f}); + auto u = NDArrayFactory::create('c', {1, 1}, {2.0f}); + auto v = NDArrayFactory::create('c', {1, 1}, {3.0f}); + auto w = NDArrayFactory::create('c', {1, 1}, {4.0f}); + auto exp = NDArrayFactory::create('c', {1, 4}, {1, 2, 3, 4}); - auto z = result.at(0); + sd::ops::concat op; + auto result = op.evaluate({&t, &u, &v, &w}, {}, {1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/ScopeTests.cpp b/libnd4j/tests_cpu/layers_tests/ScopeTests.cpp deleted file mode 100644 index 6c83e869e91d..000000000000 --- a/libnd4j/tests_cpu/layers_tests/ScopeTests.cpp +++ /dev/null @@ -1,165 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -// -// Created by raver119 on 15.10.2017. -// - -#include "testlayers.h" -#include -#include -#include - -using namespace sd; -using namespace sd::graph; - -class ScopeTests : public testing::Test { -public: - -}; - -TEST_F(ScopeTests, BasicTests_1) { - Graph graph; - - auto x = NDArrayFactory::create_('c', {2, 2}); - x->assign(0.0f); - - auto variableSpace = graph.getVariableSpace(); - variableSpace->putVariable(-1, x); - - sd::ops::Scope opScope; - - auto scopeBody = new Node(OpType_LOGIC, 10, 1); - scopeBody->setName("scopeBody"); - scopeBody->setCustomOp(&opScope); - - - graph.addNode(scopeBody); - - ASSERT_EQ(1, graph.totalNodes()); - - auto scopedB0 = new Node(OpType_SCALAR, 0, 6, {-1}, {}, {}, 1.0f); - scopedB0->markInplace(true); - scopedB0->setScopeInfo(1, "scopeBody"); - - graph.addNode(scopedB0); - - ASSERT_EQ(1, graph.totalNodes()); - -} -/* -TEST_F(ScopeTests, RealTests_1) { - Graph graph; - - auto x = NDArrayFactory::create_('c', {2, 2}); - x->assign(0.0f); - - auto y = NDArrayFactory::create_('c', {2, 2}); - y->assign(0.0); - -// auto scalar = NDArrayFactory::create_('c', {1, 1}); - auto scalar = NDArrayFactory::create_(10.f); - //scalar->p(0, 10); - - auto variableSpace = graph.getVariableSpace(); - variableSpace->putVariable(-1, x); - variableSpace->putVariable(-2, y); - variableSpace->putVariable(-3, scalar); - - // just few ops coming before while - auto nodeA = new Node(OpType_TRANSFORM_SAME, transform::OneMinus, 1, {-1}); - auto nodeB = new Node(OpType_SCALAR, scalar::Add, 2, {1}, {}, {}, 1.0); - - // - auto scopeCondition = new Node(OpType_LOGIC, logic::Scope, 3); - scopeCondition->setName("scopeCondition"); - sd::ops::Scope opScope; - scopeCondition->setCustomOp(&opScope); - - // this is scope of the body, it'll be executed multiple times - auto scopeBody = new Node(OpType_LOGIC, logic::Scope, 10); - scopeBody->setName("scopeBody"); - scopeBody->setCustomOp(&opScope); - -//////////////////////////////////////////////////////////////////////////////////////////////////// -//// filling out condition scope -//////////////////////////////////////////////////////////////////////////////////////////////////// - // this is Sum accumulation, which feed - auto scopedA0 = new Node(OpType_REDUCE_SAME, reduce::Sum, 4, {12}); - scopedA0->setScopeInfo(3, "scopeCondition"); - - // this op compares LT A0 result with variable `scalar` which is 10; - sd::ops::lt_scalar op; - auto scopedA1 = new Node(&op, 5, {4, -3}); - scopedA1->setScopeInfo(3, "scopeCondition"); - - -//////////////////////////////////////////////////////////////////////////////////////////////////// -//// filling out body scope -//////////////////////////////////////////////////////////////////////////////////////////////////// - auto scopedB0 = new Node(OpType_SCALAR, scalar::Add, 6, {12}, {}, {}, 1.0f); - scopedB0->markInplace(false); - scopedB0->setScopeInfo(10, "scopeBody"); - - auto nodeReturn = new Node(OpType_LOGIC, logic::Return, 7, {6}, {12}); - sd::ops::Return opReturn; - nodeReturn->setCustomOp(&opReturn); - nodeReturn->setScopeInfo(10, "scopeBody"); - - // WHILE operations takes 2 scopes - :0 is condition scope, and :1 is loop body scope - auto nodeWhile = new Node(OpType_LOGIC, logic::While, 12, {-2, 3, 10}); - sd::ops::While opWhile; - nodeWhile->setCustomOp(&opWhile); - - // adding root nodes first, nothing unusual expected here - graph.addNode(nodeA); - graph.addNode(nodeB); - - // now we're registering our scopes - graph.addNode(scopeCondition); - graph.addNode(scopeBody); - - // at this moment graph should have 4 (four) nodes registered - ASSERT_EQ(4, graph.totalNodes()); - - // adding node that's attached to some scope. so it should be pushed to specific scope - graph.addNode(scopedA0); - - // we should still have 4 ops in graph, because node added above - goes directly into the scope - // thus falls out of the graph direct execution - it can be executed only via Scope - ASSERT_EQ(4, graph.totalNodes()); - - graph.addNode(scopedA1); - graph.addNode(scopedB0); - graph.addNode(nodeReturn); - - // should be still 4. no options here. - ASSERT_EQ(4, graph.totalNodes()); - - // WHILE is valid node, so we expect nodes counter to go up - graph.addNode(nodeWhile); - ASSERT_EQ(5, graph.totalNodes()); - - // now, let's try to execute graph - Nd4jStatus status = GraphExecutioner::execute(&graph); - ASSERT_EQ(ND4J_STATUS_OK, status); - - auto w = variableSpace->getVariable(12, 0)->getNDArray(); - - w->printShapeInfo("w shape"); - ASSERT_NEAR(12.f, w->sumNumber().e(0), 1e-5f); -} -*/ \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/ServerRelatedTests.cpp b/libnd4j/tests_cpu/layers_tests/ServerRelatedTests.cpp index 50c1f4b19adb..667279397a8a 100644 --- a/libnd4j/tests_cpu/layers_tests/ServerRelatedTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/ServerRelatedTests.cpp @@ -18,25 +18,25 @@ // @author raver119@gmail.com // -#include "testlayers.h" -#include #include #include +#include "testlayers.h" + using namespace sd; using namespace sd::graph; class ServerRelatedTests : public testing::Test { -public: - ServerRelatedTests() { - Environment::getInstance().setDebug(true); - Environment::getInstance().setVerbose(true); - } - - ~ServerRelatedTests() { - Environment::getInstance().setDebug(false); - Environment::getInstance().setVerbose(false); - } + public: + ServerRelatedTests() { + Environment::getInstance().setDebug(true); + Environment::getInstance().setVerbose(true); + } + + ~ServerRelatedTests() { + Environment::getInstance().setDebug(false); + Environment::getInstance().setVerbose(false); + } }; /* TEST_F(ServerRelatedTests, Basic_Output_Test_1) { @@ -83,106 +83,111 @@ TEST_F(ServerRelatedTests, Basic_Output_Test_1) { */ #if GRAPH_FILES_OK TEST_F(ServerRelatedTests, Basic_Execution_Test_1) { - flatbuffers::FlatBufferBuilder builder(4096); - auto oGraph = GraphExecutioner::importFromFlatBuffers("./resources/reduce_dim_false.fb"); - oGraph->printOut(); + flatbuffers::FlatBufferBuilder builder(4096); + auto oGraph = GraphExecutioner::importFromFlatBuffers( + "./resources/reduce_dim_false.fb"); + oGraph->printOut(); - auto exp = NDArrayFactory::create('c', {3}, {3.f, 3.f, 3.f}); + auto exp = NDArrayFactory::create('c', {3}, {3.f, 3.f, 3.f}); - GraphHolder::getInstance().registerGraph(11901L, oGraph); + GraphHolder::getInstance().registerGraph(11901L, oGraph); - auto cGraph = GraphHolder::getInstance().cloneGraph(11901L); + auto cGraph = GraphHolder::getInstance().cloneGraph(11901L); - ASSERT_TRUE(oGraph != cGraph); + ASSERT_TRUE(oGraph != cGraph); - auto flatResult = GraphExecutioner::execute(cGraph, builder, nullptr); + auto flatResult = GraphExecutioner::execute(cGraph, builder, nullptr); - builder.Finish(flatResult); - auto ptr = builder.GetBufferPointer(); - auto received = GetFlatResult(ptr); + builder.Finish(flatResult); + auto ptr = builder.GetBufferPointer(); + auto received = GetFlatResult(ptr); - ExecutionResult restored(received); - ASSERT_EQ(1, restored.size()); + ExecutionResult restored(received); + ASSERT_EQ(1, restored.size()); - ASSERT_EQ(exp, *restored.at(0)->getNDArray()); + ASSERT_EQ(exp, *restored.at(0)->getNDArray()); - delete cGraph; + delete cGraph; - GraphHolder::getInstance().dropGraphAny(11901L); + GraphHolder::getInstance().dropGraphAny(11901L); } TEST_F(ServerRelatedTests, Basic_Execution_Test_2) { - flatbuffers::FlatBufferBuilder builder(4096); - flatbuffers::FlatBufferBuilder otherBuilder(4096); - auto oGraph = GraphExecutioner::importFromFlatBuffers("./resources/reduce_dim_false.fb"); - oGraph->printOut(); + flatbuffers::FlatBufferBuilder builder(4096); + flatbuffers::FlatBufferBuilder otherBuilder(4096); + auto oGraph = GraphExecutioner::importFromFlatBuffers( + "./resources/reduce_dim_false.fb"); + oGraph->printOut(); - auto input0 = NDArrayFactory::create('c', {3, 3}, {2.f,2.f,2.f, 2.f,2.f,2.f, 2.f,2.f,2.f}); - auto exp = NDArrayFactory::create('c', {3}, {6.f, 6.f, 6.f}); + auto input0 = NDArrayFactory::create( + 'c', {3, 3}, {2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f}); + auto exp = NDArrayFactory::create('c', {3}, {6.f, 6.f, 6.f}); - GraphHolder::getInstance().registerGraph(11902L, oGraph); + GraphHolder::getInstance().registerGraph(11902L, oGraph); - auto cGraph = GraphHolder::getInstance().cloneGraph(11902L); + auto cGraph = GraphHolder::getInstance().cloneGraph(11902L); - ASSERT_TRUE(oGraph != cGraph); + ASSERT_TRUE(oGraph != cGraph); - // mastering InferenceRequest - InferenceRequest ir(11902L); - ir.appendVariable(1, 0, &input0); + // mastering InferenceRequest + InferenceRequest ir(11902L); + ir.appendVariable(1, 0, &input0); - auto af = ir.asFlatInferenceRequest(otherBuilder); - otherBuilder.Finish(af); - auto fptr = otherBuilder.GetBufferPointer(); - auto fir = GetFlatInferenceRequest(fptr); + auto af = ir.asFlatInferenceRequest(otherBuilder); + otherBuilder.Finish(af); + auto fptr = otherBuilder.GetBufferPointer(); + auto fir = GetFlatInferenceRequest(fptr); - auto flatResult = GraphExecutioner::execute(cGraph, builder, fir); + auto flatResult = GraphExecutioner::execute(cGraph, builder, fir); - builder.Finish(flatResult); - auto ptr = builder.GetBufferPointer(); - auto received = GetFlatResult(ptr); + builder.Finish(flatResult); + auto ptr = builder.GetBufferPointer(); + auto received = GetFlatResult(ptr); - ExecutionResult restored(received); - ASSERT_EQ(1, restored.size()); + ExecutionResult restored(received); + ASSERT_EQ(1, restored.size()); - ASSERT_EQ(exp, *restored.at(0)->getNDArray()); + ASSERT_EQ(exp, *restored.at(0)->getNDArray()); - delete cGraph; + delete cGraph; - GraphHolder::getInstance().dropGraphAny(11902L); + GraphHolder::getInstance().dropGraphAny(11902L); } TEST_F(ServerRelatedTests, BasicExecutionTests_3) { - flatbuffers::FlatBufferBuilder builder(4096); - flatbuffers::FlatBufferBuilder otherBuilder(4096); - auto oGraph = GraphExecutioner::importFromFlatBuffers("./resources/reduce_dim_false.fb"); - oGraph->printOut(); + flatbuffers::FlatBufferBuilder builder(4096); + flatbuffers::FlatBufferBuilder otherBuilder(4096); + auto oGraph = GraphExecutioner::importFromFlatBuffers( + "./resources/reduce_dim_false.fb"); + oGraph->printOut(); - auto input0 = NDArrayFactory::create('c', {3, 3}, {2.f,2.f,2.f, 2.f,2.f,2.f, 2.f,2.f,2.f}); - auto exp = NDArrayFactory::create('c', {3}, {6.f, 6.f, 6.f}); + auto input0 = NDArrayFactory::create( + 'c', {3, 3}, {2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f}); + auto exp = NDArrayFactory::create('c', {3}, {6.f, 6.f, 6.f}); - GraphHolder::getInstance().registerGraph(11903L, oGraph); + GraphHolder::getInstance().registerGraph(11903L, oGraph); - // mastering InferenceRequest - InferenceRequest ir(11903L); - ir.appendVariable(1, 0, &input0); + // mastering InferenceRequest + InferenceRequest ir(11903L); + ir.appendVariable(1, 0, &input0); - auto af = ir.asFlatInferenceRequest(otherBuilder); - otherBuilder.Finish(af); - auto fptr = otherBuilder.GetBufferPointer(); - auto fir = GetFlatInferenceRequest(fptr); + auto af = ir.asFlatInferenceRequest(otherBuilder); + otherBuilder.Finish(af); + auto fptr = otherBuilder.GetBufferPointer(); + auto fir = GetFlatInferenceRequest(fptr); + auto flatResult = + GraphHolder::getInstance().execute(fir->id(), builder, fir); - auto flatResult = GraphHolder::getInstance().execute(fir->id(), builder, fir); + builder.Finish(flatResult); + auto ptr = builder.GetBufferPointer(); + auto received = GetFlatResult(ptr); - builder.Finish(flatResult); - auto ptr = builder.GetBufferPointer(); - auto received = GetFlatResult(ptr); - - ExecutionResult restored(received); - ASSERT_EQ(1, restored.size()); + ExecutionResult restored(received); + ASSERT_EQ(1, restored.size()); - ASSERT_EQ(exp, *restored.at(0)->getNDArray()); + ASSERT_EQ(exp, *restored.at(0)->getNDArray()); - GraphHolder::getInstance().dropGraphAny(11903L); + GraphHolder::getInstance().dropGraphAny(11903L); } #endif diff --git a/libnd4j/tests_cpu/layers_tests/SessionLocalTests.cpp b/libnd4j/tests_cpu/layers_tests/SessionLocalTests.cpp deleted file mode 100644 index 8481dfde5cef..000000000000 --- a/libnd4j/tests_cpu/layers_tests/SessionLocalTests.cpp +++ /dev/null @@ -1,93 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -// -// @author raver119@gmail.com -// - -#ifndef LIBND4J_SESSIONLOCALTESTS_H -#define LIBND4J_SESSIONLOCALTESTS_H - -#include "testlayers.h" -#include -#include - -using namespace sd::graph; - -class SessionLocalTests : public testing::Test { -public: - -}; - -TEST_F(SessionLocalTests, BasicTests_1) { - VariableSpace variableSpace; - SessionLocalStorage storage(&variableSpace, nullptr); - - if (omp_get_max_threads() <= 1) - return; - - PRAGMA_OMP_PARALLEL_FOR_THREADS(4) - for (int e = 0; e < 4; e++) { - storage.startSession(); - } - - ASSERT_EQ(4, storage.numberOfSessions()); - - PRAGMA_OMP_PARALLEL_FOR_THREADS(4) - for (int e = 0; e < 4; e++) { - storage.endSession(); - } - - ASSERT_EQ(0, storage.numberOfSessions()); -} - - -TEST_F(SessionLocalTests, BasicTests_2) { - VariableSpace variableSpace; - SessionLocalStorage storage(&variableSpace, nullptr); - - if (omp_get_max_threads() <= 1) - return; - - auto alpha = sd::NDArrayFactory::create_('c',{5,5}); - alpha->assign(0.0); - - variableSpace.putVariable(-1, alpha); - - PRAGMA_OMP_PARALLEL_FOR_THREADS(4) - for (int e = 0; e < 4; e++) { - storage.startSession(); - - auto varSpace = storage.localVariableSpace(); - - auto arr = varSpace->getVariable(-1)->getNDArray(); - arr->applyScalar(sd::scalar::Add, (float) e+1, *arr); - } - - float lastValue = 0.0f; - for (int e = 1; e <= 4; e++) { - auto varSpace = storage.localVariableSpace((Nd4jLong) e); - - auto arr = varSpace->getVariable(-1)->getNDArray(); - - //nd4j_printf("Last value: %f; Current value: %f\n", lastValue, arr->e(0)); - - ASSERT_NE(lastValue, arr->e(0)); - lastValue = arr->e(0); - } -} - -#endif //LIBND4J_SESSIONLOCALTESTS_H diff --git a/libnd4j/tests_cpu/layers_tests/ShapeTests.cpp b/libnd4j/tests_cpu/layers_tests/ShapeTests.cpp index 46edc50d8bd1..b4e314767daf 100644 --- a/libnd4j/tests_cpu/layers_tests/ShapeTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/ShapeTests.cpp @@ -19,316 +19,305 @@ // #include -#include "testlayers.h" #include +#include "testlayers.h" + using namespace sd; using namespace sd::graph; class ShapeTests : public testing::Test { -public: - + public: }; - TEST_F(ShapeTests, Test_Basics_1) { - Nd4jLong shape[] = {2, 5, 3, 3, 1, 0, 1, 99}; + Nd4jLong shape[] = {2, 5, 3, 3, 1, 0, 1, 99}; - ASSERT_EQ(2, shape::rank(shape)); - ASSERT_EQ(1, shape::elementWiseStride(shape)); - ASSERT_EQ(5, shape::sizeAt(shape, 0)); - ASSERT_EQ(3, shape::sizeAt(shape, 1)); - ASSERT_EQ('c', shape::order(shape)); + ASSERT_EQ(2, shape::rank(shape)); + ASSERT_EQ(1, shape::elementWiseStride(shape)); + ASSERT_EQ(5, shape::sizeAt(shape, 0)); + ASSERT_EQ(3, shape::sizeAt(shape, 1)); + ASSERT_EQ('c', shape::order(shape)); } - TEST_F(ShapeTests, Test_Basics_2) { - Nd4jLong shape[] = {4, 2, 3, 4, 5, 60, 20, 5, 1, 0, -1, 102}; - - ASSERT_EQ(4, shape::rank(shape)); - ASSERT_EQ(-1, shape::elementWiseStride(shape)); - ASSERT_EQ(2, shape::sizeAt(shape, 0)); - ASSERT_EQ(3, shape::sizeAt(shape, 1)); - ASSERT_EQ(4, shape::sizeAt(shape, 2)); - ASSERT_EQ(5, shape::sizeAt(shape, 3)); - ASSERT_EQ('f', shape::order(shape)); + Nd4jLong shape[] = {4, 2, 3, 4, 5, 60, 20, 5, 1, 0, -1, 102}; + + ASSERT_EQ(4, shape::rank(shape)); + ASSERT_EQ(-1, shape::elementWiseStride(shape)); + ASSERT_EQ(2, shape::sizeAt(shape, 0)); + ASSERT_EQ(3, shape::sizeAt(shape, 1)); + ASSERT_EQ(4, shape::sizeAt(shape, 2)); + ASSERT_EQ(5, shape::sizeAt(shape, 3)); + ASSERT_EQ('f', shape::order(shape)); } - TEST_F(ShapeTests, Test_tadLength_1) { - Nd4jLong shape[] = {4, 2, 3, 4, 5, 60, 20, 5, 1, 0, -1, 102}; - int axis[] = {2, 3}; + Nd4jLong shape[] = {4, 2, 3, 4, 5, 60, 20, 5, 1, 0, -1, 102}; + int axis[] = {2, 3}; - ASSERT_EQ(20, shape::tadLength(shape, axis, 2)); + ASSERT_EQ(20, shape::tadLength(shape, axis, 2)); } - TEST_F(ShapeTests, Test_ShapeEquality_1) { - Nd4jLong shape[] = {4, 2, 3, 4, 5, 60, 20, 5, 1, 0, -1, 102}; - Nd4jLong shape_GOOD[] = {4, 2, 3, 4, 5, 60, 20, 5, 1, 0, 1, 99}; - Nd4jLong shape_BAD[] = {4, 3, 3, 4, 5, 60, 20, 5, 1, 0, -1, 102}; + Nd4jLong shape[] = {4, 2, 3, 4, 5, 60, 20, 5, 1, 0, -1, 102}; + Nd4jLong shape_GOOD[] = {4, 2, 3, 4, 5, 60, 20, 5, 1, 0, 1, 99}; + Nd4jLong shape_BAD[] = {4, 3, 3, 4, 5, 60, 20, 5, 1, 0, -1, 102}; - - ASSERT_TRUE(shape::equalsSoft(shape, shape_GOOD)); - ASSERT_FALSE(shape::equalsSoft(shape, shape_BAD)); + ASSERT_TRUE(shape::equalsSoft(shape, shape_GOOD)); + ASSERT_FALSE(shape::equalsSoft(shape, shape_BAD)); } TEST_F(ShapeTests, Test_ShapeEquality_2) { - Nd4jLong shape[] = {4, 2, 3, 4, 5, 60, 20, 5, 1, 0, -1, 102}; - Nd4jLong shape_GOOD[] = {4, 2, 3, 4, 5, 60, 20, 5, 1, 0, -1, 102}; - Nd4jLong shape_BAD[] = {4, 2, 3, 4, 5, 60, 20, 5, 1, 0, -1, 99}; - + Nd4jLong shape[] = {4, 2, 3, 4, 5, 60, 20, 5, 1, 0, -1, 102}; + Nd4jLong shape_GOOD[] = {4, 2, 3, 4, 5, 60, 20, 5, 1, 0, -1, 102}; + Nd4jLong shape_BAD[] = {4, 2, 3, 4, 5, 60, 20, 5, 1, 0, -1, 99}; - ASSERT_TRUE(shape::equalsStrict(shape, shape_GOOD)); - ASSERT_FALSE(shape::equalsStrict(shape, shape_BAD)); + ASSERT_TRUE(shape::equalsStrict(shape, shape_GOOD)); + ASSERT_FALSE(shape::equalsStrict(shape, shape_BAD)); } TEST_F(ShapeTests, Test_Ind2SubC_1) { - Nd4jLong shape[] = {3, 5}; - Nd4jLong c0[2]; - shape::index2coords(0, 2, shape, c0); + Nd4jLong shape[] = {3, 5}; + Nd4jLong c0[2]; + shape::index2coords(0, 2, shape, c0); - ASSERT_EQ(0, c0[0]); - ASSERT_EQ(0, c0[1]); + ASSERT_EQ(0, c0[0]); + ASSERT_EQ(0, c0[1]); - Nd4jLong c1[2]; - shape::index2coords(1, 2, shape, c1); + Nd4jLong c1[2]; + shape::index2coords(1, 2, shape, c1); - ASSERT_EQ(0, c1[0]); - ASSERT_EQ(1, c1[1]); + ASSERT_EQ(0, c1[0]); + ASSERT_EQ(1, c1[1]); - Nd4jLong c6[2]; - shape::index2coords(5, 2, shape, c6); + Nd4jLong c6[2]; + shape::index2coords(5, 2, shape, c6); - ASSERT_EQ(1, c6[0]); - ASSERT_EQ(0, c6[1]); + ASSERT_EQ(1, c6[0]); + ASSERT_EQ(0, c6[1]); } - TEST_F(ShapeTests, Test_ShapeDetector_1) { - Nd4jLong shape[] = {2, 5, 3, 3, 1, 0, 1, 99}; + Nd4jLong shape[] = {2, 5, 3, 3, 1, 0, 1, 99}; - ASSERT_TRUE(shape::isMatrix(shape)); + ASSERT_TRUE(shape::isMatrix(shape)); } TEST_F(ShapeTests, Test_ShapeDetector_2) { - Nd4jLong shape[] = {3, 2, 5, 3, 15, 3, 1, 0, 1, 99}; + Nd4jLong shape[] = {3, 2, 5, 3, 15, 3, 1, 0, 1, 99}; - ASSERT_FALSE(shape::isMatrix(shape)); + ASSERT_FALSE(shape::isMatrix(shape)); } TEST_F(ShapeTests, Test_ShapeDetector_3) { - Nd4jLong shape[] = {2, 1, 3, 3, 1, 0, 1, 99}; + Nd4jLong shape[] = {2, 1, 3, 3, 1, 0, 1, 99}; - ASSERT_FALSE(shape::isColumnVector(shape)); - ASSERT_TRUE(shape::isVector(shape)); - ASSERT_TRUE(shape::isRowVector(shape)); - ASSERT_FALSE(shape::isMatrix(shape)); + ASSERT_FALSE(shape::isColumnVector(shape)); + ASSERT_TRUE(shape::isVector(shape)); + ASSERT_TRUE(shape::isRowVector(shape)); + ASSERT_FALSE(shape::isMatrix(shape)); } - TEST_F(ShapeTests, Test_ShapeDetector_4) { - Nd4jLong shape[] = {2, 3, 1, 1, 1, 0, 1, 99}; + Nd4jLong shape[] = {2, 3, 1, 1, 1, 0, 1, 99}; - ASSERT_TRUE(shape::isColumnVector(shape)); - ASSERT_TRUE(shape::isVector(shape)); - ASSERT_FALSE(shape::isRowVector(shape)); - ASSERT_FALSE(shape::isMatrix(shape)); + ASSERT_TRUE(shape::isColumnVector(shape)); + ASSERT_TRUE(shape::isVector(shape)); + ASSERT_FALSE(shape::isRowVector(shape)); + ASSERT_FALSE(shape::isMatrix(shape)); } TEST_F(ShapeTests, Test_ShapeDetector_5) { - Nd4jLong shape[] = {2, 1, 1, 1, 1, 0, 1, 99}; + Nd4jLong shape[] = {2, 1, 1, 1, 1, 0, 1, 99}; - ASSERT_TRUE(shape::isScalar(shape)); - ASSERT_FALSE(shape::isMatrix(shape)); + ASSERT_TRUE(shape::isScalar(shape)); + ASSERT_FALSE(shape::isMatrix(shape)); - // edge case here. Technicaly it's still a vector with length of 1 - ASSERT_TRUE(shape::isVector(shape)); + // edge case here. Technicaly it's still a vector with length of 1 + ASSERT_TRUE(shape::isVector(shape)); } TEST_F(ShapeTests, Test_ShapeDetector_6) { - Nd4jLong shape[] = {2, 1, 1, 1, 1, 0, 1, 99}; + Nd4jLong shape[] = {2, 1, 1, 1, 1, 0, 1, 99}; - ASSERT_EQ(8, shape::shapeInfoLength(shape)); - ASSERT_EQ(64, shape::shapeInfoByteLength(shape)); + ASSERT_EQ(8, shape::shapeInfoLength(shape)); + ASSERT_EQ(64, shape::shapeInfoByteLength(shape)); } TEST_F(ShapeTests, Test_ShapeDetector_7) { - Nd4jLong shape[] = {3, 1, 1, 1, 1, 1, 1, 0, 1, 99}; + Nd4jLong shape[] = {3, 1, 1, 1, 1, 1, 1, 0, 1, 99}; - ASSERT_EQ(10, shape::shapeInfoLength(shape)); - ASSERT_EQ(80, shape::shapeInfoByteLength(shape)); + ASSERT_EQ(10, shape::shapeInfoLength(shape)); + ASSERT_EQ(80, shape::shapeInfoByteLength(shape)); } TEST_F(ShapeTests, Test_Transpose_1) { - Nd4jLong shape[] = {3, 2, 5, 3, 15, 3, 1, 0, 1, 99}; - Nd4jLong exp[] = {3, 3, 5, 2, 1, 3, 15, 0, 1, 102}; + Nd4jLong shape[] = {3, 2, 5, 3, 15, 3, 1, 0, 1, 99}; + Nd4jLong exp[] = {3, 3, 5, 2, 1, 3, 15, 0, 1, 102}; - shape::transposeInplace(shape); + shape::transposeInplace(shape); - ASSERT_TRUE(shape::equalsStrict(exp, shape)); + ASSERT_TRUE(shape::equalsStrict(exp, shape)); } TEST_F(ShapeTests, Test_Transpose_2) { - Nd4jLong shape[] = {2, 5, 3, 3, 1, 0, 1, 99}; - Nd4jLong exp[] = {2, 3, 5, 1, 3, 0, 1, 102}; + Nd4jLong shape[] = {2, 5, 3, 3, 1, 0, 1, 99}; + Nd4jLong exp[] = {2, 3, 5, 1, 3, 0, 1, 102}; - shape::transposeInplace(shape); + shape::transposeInplace(shape); - ASSERT_TRUE(shape::equalsStrict(exp, shape)); + ASSERT_TRUE(shape::equalsStrict(exp, shape)); } TEST_F(ShapeTests, Test_Transpose_3) { - Nd4jLong shape[] = {2, 1, 3, 3, 1, 0, 1, 99}; - Nd4jLong exp[] = {2, 3, 1, 1, 3, 0, 1, 102}; + Nd4jLong shape[] = {2, 1, 3, 3, 1, 0, 1, 99}; + Nd4jLong exp[] = {2, 3, 1, 1, 3, 0, 1, 102}; - shape::transposeInplace(shape); + shape::transposeInplace(shape); - ASSERT_TRUE(shape::equalsStrict(exp, shape)); + ASSERT_TRUE(shape::equalsStrict(exp, shape)); } - TEST_F(ShapeTests, Test_Transpose_4) { - Nd4jLong shape[] = {4, 2, 3, 4, 5, 5, 4, 3, 2, 0, 1, 99}; - Nd4jLong exp[] = {4, 5, 4, 3, 2, 2, 3, 4, 5, 0, 1, 102}; + Nd4jLong shape[] = {4, 2, 3, 4, 5, 5, 4, 3, 2, 0, 1, 99}; + Nd4jLong exp[] = {4, 5, 4, 3, 2, 2, 3, 4, 5, 0, 1, 102}; - shape::transposeInplace(shape); + shape::transposeInplace(shape); - ASSERT_TRUE(shape::equalsStrict(exp, shape)); + ASSERT_TRUE(shape::equalsStrict(exp, shape)); } TEST_F(ShapeTests, Test_Edge_1) { - auto x = NDArrayFactory::create('f', {1, 4, 1, 4}); - x.linspace(1); + auto x = NDArrayFactory::create('f', {1, 4, 1, 4}); + x.linspace(1); - x.reshapei('c', {4, 4}); + x.reshapei('c', {4, 4}); - //x.printShapeInfo("reshape0"); - //x.printIndexedBuffer("x i"); - //x.printBuffer("x r"); + // x.printShapeInfo("reshape0"); + // x.printIndexedBuffer("x i"); + // x.printBuffer("x r"); - x.reshapei({4, 1, 1, 4}); + x.reshapei({4, 1, 1, 4}); - //x.printShapeInfo("reshape1"); + // x.printShapeInfo("reshape1"); } TEST_F(ShapeTests, Test_Edge_2) { - auto x = NDArrayFactory::create('c', {1, 4, 1, 3}); + auto x = NDArrayFactory::create('c', {1, 4, 1, 3}); - x.reshapei('c', {3, 4}); + x.reshapei('c', {3, 4}); - //x.printShapeInfo("reshape0"); + // x.printShapeInfo("reshape0"); - x.reshapei({3, 1, 1, 4}); + x.reshapei({3, 1, 1, 4}); - //x.printShapeInfo("reshape1"); + // x.printShapeInfo("reshape1"); } - TEST_F(ShapeTests, Test_Remove_Index_1) { - int array[] = {1, 2, 3}; - int idx[] = {0}; - int result[2]; - shape::removeIndex(array, idx, 3, 1, result); + int array[] = {1, 2, 3}; + int idx[] = {0}; + int result[2]; + shape::removeIndex(array, idx, 3, 1, result); - ASSERT_EQ(2, result[0]); - ASSERT_EQ(3, result[1]); + ASSERT_EQ(2, result[0]); + ASSERT_EQ(3, result[1]); } TEST_F(ShapeTests, Test_Remove_Index_2) { - int array[] = {1, 2, 3}; - int idx[] = {1}; - int result[2]; - shape::removeIndex(array, idx, 3, 1, result); + int array[] = {1, 2, 3}; + int idx[] = {1}; + int result[2]; + shape::removeIndex(array, idx, 3, 1, result); - ASSERT_EQ(1, result[0]); - ASSERT_EQ(3, result[1]); + ASSERT_EQ(1, result[0]); + ASSERT_EQ(3, result[1]); } TEST_F(ShapeTests, Test_Remove_Index_3) { - int array[] = {1, 2, 3}; - int idx[] = {2}; - int result[2]; - shape::removeIndex(array, idx, 3, 1, result); + int array[] = {1, 2, 3}; + int idx[] = {2}; + int result[2]; + shape::removeIndex(array, idx, 3, 1, result); - ASSERT_EQ(1, result[0]); - ASSERT_EQ(2, result[1]); + ASSERT_EQ(1, result[0]); + ASSERT_EQ(2, result[1]); } TEST_F(ShapeTests, Test_Remove_Index_4) { - int array[] = {1, 2, 3}; - int idx[] = {0, 2}; - int result[1]; - shape::removeIndex(array, idx, 3, 2, result); + int array[] = {1, 2, 3}; + int idx[] = {0, 2}; + int result[1]; + shape::removeIndex(array, idx, 3, 2, result); - ASSERT_EQ(2, result[0]); + ASSERT_EQ(2, result[0]); } TEST_F(ShapeTests, Test_Remove_Index_5) { - int array[] = {1, 2, 3}; - int idx[] = {1, 0}; - int result[1]; - shape::removeIndex(array, idx, 3, 2, result); + int array[] = {1, 2, 3}; + int idx[] = {1, 0}; + int result[1]; + shape::removeIndex(array, idx, 3, 2, result); - ASSERT_EQ(3, result[0]); + ASSERT_EQ(3, result[0]); } TEST_F(ShapeTests, Test_Remove_Index_6) { - int array[] = {1, 2, 3}; - int idx[] = {1, 2}; - int result[1]; - shape::removeIndex(array, idx, 3, 2, result); + int array[] = {1, 2, 3}; + int idx[] = {1, 2}; + int result[1]; + shape::removeIndex(array, idx, 3, 2, result); - ASSERT_EQ(1, result[0]); + ASSERT_EQ(1, result[0]); } TEST_F(ShapeTests, Tests_Transpose_119_1) { - auto x = NDArrayFactory::create('c', {3, 2}); - auto y = NDArrayFactory::create('c', {2}, {1.0f, 0.0f}); - auto z = NDArrayFactory::create('c', {2, 3}); + auto x = NDArrayFactory::create('c', {3, 2}); + auto y = NDArrayFactory::create('c', {2}, {1.0f, 0.0f}); + auto z = NDArrayFactory::create('c', {2, 3}); - x.linspace(1.f); + x.linspace(1.f); - auto e = x.permute({1, 0}); - e.streamline('c'); + auto e = x.permute({1, 0}); + e.streamline('c'); - sd::ops::transpose op; - auto result = op.execute({&x, &y}, {&z}, {}, {}, {}); + sd::ops::transpose op; + auto result = op.execute({&x, &y}, {&z}, {}, {}, {}); - ASSERT_EQ(Status::OK(), result); - ASSERT_TRUE(e.isSameShape(z)); - ASSERT_TRUE(e.equalsTo(z)); + ASSERT_EQ(Status::OK(), result); + ASSERT_TRUE(e.isSameShape(z)); + ASSERT_TRUE(e.equalsTo(z)); } TEST_F(ShapeTests, Tests_Transpose_119_2) { - auto x = NDArrayFactory::create('c', {3, 5}); - x.linspace(1.f); - - auto exp = x.transpose(); + auto x = NDArrayFactory::create('c', {3, 5}); + x.linspace(1.f); - sd::ops::transpose op; - auto result = op.evaluate({&x}); - ASSERT_EQ(Status::OK(), result.status()); + auto exp = x.transpose(); - auto z = result.at(0); + sd::ops::transpose op; + auto result = op.evaluate({&x}); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(ShapeTests, Tests_Transpose_119_3) { - auto x = NDArrayFactory::create('c', {3, 5}); - x.linspace(1.f); + auto x = NDArrayFactory::create('c', {3, 5}); + x.linspace(1.f); - auto z = NDArrayFactory::create('c', {5, 3}); + auto z = NDArrayFactory::create('c', {5, 3}); - auto exp = x.transpose(); + auto exp = x.transpose(); - sd::ops::transpose op; - auto result = op.execute({&x}, {&z}, {}, {}, {}); - ASSERT_EQ(Status::OK(), result); + sd::ops::transpose op; + auto result = op.execute({&x}, {&z}, {}, {}, {}); + ASSERT_EQ(Status::OK(), result); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/ShapeTests2.cpp b/libnd4j/tests_cpu/layers_tests/ShapeTests2.cpp index b84213342126..79940359f808 100644 --- a/libnd4j/tests_cpu/layers_tests/ShapeTests2.cpp +++ b/libnd4j/tests_cpu/layers_tests/ShapeTests2.cpp @@ -17,255 +17,233 @@ // // Created by agibsonccc on 1/6/17. // +#include +#include #include + #include "testinclude.h" -#include -#include class OnesTest : public testing::Test { -public: - Nd4jLong shapeBuffer[12] = {4,4,3,1,1,3,1,1,1,0,1,99}; - int dimension[3] = {0,2,3}; - Nd4jLong tadAssertionShape[10] = {3,1,1,4,1,1,3,0,3,99}; - int dimensionLength = 3; + public: + Nd4jLong shapeBuffer[12] = {4, 4, 3, 1, 1, 3, 1, 1, 1, 0, 1, 99}; + int dimension[3] = {0, 2, 3}; + Nd4jLong tadAssertionShape[10] = {3, 1, 1, 4, 1, 1, 3, 0, 3, 99}; + int dimensionLength = 3; }; class LabelTest : public testing::Test { -public: - float labels[450] = {1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0}; - Nd4jLong shapeInfo[8] = {2,150,3,1,150,16384,1,102}; - int dimension[1] = {1}; - int dimensionLength = 1; - Nd4jLong tadShapeInfoAssert[8] = {2,1,3,1,150,16384,150,102}; + public: + float labels[450] = { + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0}; + Nd4jLong shapeInfo[8] = {2, 150, 3, 1, 150, 16384, 1, 102}; + int dimension[1] = {1}; + int dimensionLength = 1; + Nd4jLong tadShapeInfoAssert[8] = {2, 1, 3, 1, 150, 16384, 150, 102}; }; class ThreeDTest : public testing::Test { -public: - Nd4jLong shape[3] = {3,4,5}; - Nd4jLong *shapeBuffer; - ThreeDTest() { - shapeBuffer = sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', 3, shape); - } - ~ThreeDTest() { - delete[] shapeBuffer; - } + public: + Nd4jLong shape[3] = {3, 4, 5}; + Nd4jLong *shapeBuffer; + ThreeDTest() { + shapeBuffer = sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', + 3, shape); + } + ~ThreeDTest() { delete[] shapeBuffer; } }; -class VectorTest : public testing::Test { - -}; +class VectorTest : public testing::Test {}; class NumTadTests : public testing::Test { -public: - Nd4jLong shape[3] = {3,4,5}; - int dimension = 0; + public: + Nd4jLong shape[3] = {3, 4, 5}; + int dimension = 0; }; -class ShapeTest : public testing::Test { -public: - Nd4jLong vectorShape[2] = {1,2}; +class ShapeTest : public testing::Test { + public: + Nd4jLong vectorShape[2] = {1, 2}; }; class MatrixTest : public testing::Test { -public: - int rows = 3; - int cols = 4; - int rank = 2; - int dims[2] = {0,1}; - Nd4jLong expectedShapes[2][2] = { - {1,3}, - {1,4} - }; - Nd4jLong expectedStrides[2][2] = { - {1,4}, - {1,1} - }; + public: + int rows = 3; + int cols = 4; + int rank = 2; + int dims[2] = {0, 1}; + Nd4jLong expectedShapes[2][2] = {{1, 3}, {1, 4}}; + Nd4jLong expectedStrides[2][2] = {{1, 4}, {1, 1}}; }; class TADStall : public testing::Test { -public: - Nd4jLong shape[4] = {3,3,4,5}; - int dimensions[3] = {1,2,3}; + public: + Nd4jLong shape[4] = {3, 3, 4, 5}; + int dimensions[3] = {1, 2, 3}; }; class TensorOneDimTest : public testing::Test { -public: - int rows = 3; - int cols = 4; - int dim2 = 5; - int rank = 3; - int dims[3] = {0,1,2}; - Nd4jLong expectedShapes[3][2] = { - {1,3}, - {1,4}, - {1,5} - }; - Nd4jLong expectedStrides[3][2] = { - {1,20}, - {1,5}, - {1,1} - }; + public: + int rows = 3; + int cols = 4; + int dim2 = 5; + int rank = 3; + int dims[3] = {0, 1, 2}; + Nd4jLong expectedShapes[3][2] = {{1, 3}, {1, 4}, {1, 5}}; + Nd4jLong expectedStrides[3][2] = {{1, 20}, {1, 5}, {1, 1}}; }; class TensorTwoDimTest : public testing::Test { -public: - //From a 3d array: - int rows = 3; - int cols = 4; - int dim2 = 5; - int dimensionLength = 2; - int dims[3][2] = { - {0,1},{0,2},{1,2} - }; - - Nd4jLong shape[3] {rows,cols,dim2}; - - //Along dimension 0,1: expect matrix with shape [rows,cols] - //Along dimension 0,2: expect matrix with shape [rows,dim2] - //Along dimension 1,2: expect matrix with shape [cols,dim2] - Nd4jLong expectedShapes[3][2] = { - {rows,cols}, - {rows,dim2}, - {cols,dim2} - }; - - Nd4jLong expectedStrides[3][2] = { - {20,5}, - {20,1}, - {5,1} - }; - + public: + // From a 3d array: + int rows = 3; + int cols = 4; + int dim2 = 5; + int dimensionLength = 2; + int dims[3][2] = {{0, 1}, {0, 2}, {1, 2}}; + + Nd4jLong shape[3]{rows, cols, dim2}; + + // Along dimension 0,1: expect matrix with shape [rows,cols] + // Along dimension 0,2: expect matrix with shape [rows,dim2] + // Along dimension 1,2: expect matrix with shape [cols,dim2] + Nd4jLong expectedShapes[3][2] = {{rows, cols}, {rows, dim2}, {cols, dim2}}; + + Nd4jLong expectedStrides[3][2] = {{20, 5}, {20, 1}, {5, 1}}; }; class TensorTwoFromFourDDimTest : public testing::Test { -public: - //From a 3d array: - int rows = 3; - int cols = 4; - int dim2 = 5; - int dim3 = 6; - Nd4jLong shape[4] = {rows,cols,dim2,dim3}; - int dimensionLength = 2; - //Along dimension 0,1: expect matrix with shape [rows,cols] - //Along dimension 0,2: expect matrix with shape [rows,dim2] - //Along dimension 0,3: expect matrix with shape [rows,dim3] - //Along dimension 1,2: expect matrix with shape [cols,dim2] - //Along dimension 1,3: expect matrix with shape [cols,dim3] - //Along dimension 2,3: expect matrix with shape [dim2,dim3] - - int dims[6][2] = { - {0,1}, - {0,2}, - {0,3}, - {1,2}, - {1,3}, - {2,3} - }; - - Nd4jLong expectedShapes[6][2] = { - {rows,cols}, - {rows,dim2}, - {rows,dim3}, - {cols,dim2}, - {cols,dim3} - ,{dim2,dim3} - }; - - Nd4jLong expectedStrides[6][2] = { - {120,30}, - {120,6}, - {120,1}, - {30,6}, - {30,1}, - {6,1} - }; + public: + // From a 3d array: + int rows = 3; + int cols = 4; + int dim2 = 5; + int dim3 = 6; + Nd4jLong shape[4] = {rows, cols, dim2, dim3}; + int dimensionLength = 2; + // Along dimension 0,1: expect matrix with shape [rows,cols] + // Along dimension 0,2: expect matrix with shape [rows,dim2] + // Along dimension 0,3: expect matrix with shape [rows,dim3] + // Along dimension 1,2: expect matrix with shape [cols,dim2] + // Along dimension 1,3: expect matrix with shape [cols,dim3] + // Along dimension 2,3: expect matrix with shape [dim2,dim3] + + int dims[6][2] = {{0, 1}, {0, 2}, {0, 3}, {1, 2}, {1, 3}, {2, 3}}; + + Nd4jLong expectedShapes[6][2] = {{rows, cols}, {rows, dim2}, {rows, dim3}, + {cols, dim2}, {cols, dim3}, {dim2, dim3}}; + + Nd4jLong expectedStrides[6][2] = {{120, 30}, {120, 6}, {120, 1}, + {30, 6}, {30, 1}, {6, 1}}; }; - class OrderTest : public testing::Test { -public: - Nd4jLong expected[8] = {2,3,4,1,3,0,0,102}; - Nd4jLong test[8] = {2,3,4,1,3,0,0,102}; - + public: + Nd4jLong expected[8] = {2, 3, 4, 1, 3, 0, 0, 102}; + Nd4jLong test[8] = {2, 3, 4, 1, 3, 0, 0, 102}; }; - class LeadingOnes : public testing::Test { -public: - Nd4jLong shapeBufferF[16] = {4,1,1,4,4,1,1,1,4,16384,1,102}; // shapes with data type DOUBLE - Nd4jLong shapeBufferC[16] = {4,1,1,4,4,16,16,4,1,16384,1,99}; - int dimensionLength = 2; - int dimension[2] = {2,3}; - Nd4jLong tadAssertionC[10] = {3,4,4,1,4,1,16,16384,1,99}; - Nd4jLong tadCAssertionF[10] = {3,4,4,1,1,4,1,16384,1,102}; + public: + Nd4jLong shapeBufferF[16] = { + 4, 1, 1, 4, 4, 1, + 1, 1, 4, 16384, 1, 102}; // shapes with data type DOUBLE + Nd4jLong shapeBufferC[16] = {4, 1, 1, 4, 4, 16, 16, 4, 1, 16384, 1, 99}; + int dimensionLength = 2; + int dimension[2] = {2, 3}; + Nd4jLong tadAssertionC[10] = {3, 4, 4, 1, 4, 1, 16, 16384, 1, 99}; + Nd4jLong tadCAssertionF[10] = {3, 4, 4, 1, 1, 4, 1, 16384, 1, 102}; }; - -TEST_F(LeadingOnes,OnesTest) { - - shape::TAD *cTad = new shape::TAD; - cTad->init(shapeBufferC,dimension,dimensionLength); - cTad->createTadOnlyShapeInfo(); - cTad->createOffsets(); - shape::TAD *fTad = new shape::TAD; - fTad->init(shapeBufferF,dimension,dimensionLength); - fTad->createTadOnlyShapeInfo(); - fTad->createOffsets(); - // shape::printShapeInfoLinear(cTad->tadOnlyShapeInfo); - // shape::printShapeInfoLinear(fTad->tadOnlyShapeInfo); - ASSERT_TRUE(arrsEquals(10, tadCAssertionF, fTad->tadOnlyShapeInfo)); - ASSERT_TRUE(arrsEquals(10, tadAssertionC, cTad->tadOnlyShapeInfo)); - - delete cTad; - delete fTad; +TEST_F(LeadingOnes, OnesTest) { + shape::TAD *cTad = new shape::TAD; + cTad->init(shapeBufferC, dimension, dimensionLength); + cTad->createTadOnlyShapeInfo(); + cTad->createOffsets(); + shape::TAD *fTad = new shape::TAD; + fTad->init(shapeBufferF, dimension, dimensionLength); + fTad->createTadOnlyShapeInfo(); + fTad->createOffsets(); + // shape::printShapeInfoLinear(cTad->tadOnlyShapeInfo); + // shape::printShapeInfoLinear(fTad->tadOnlyShapeInfo); + ASSERT_TRUE(arrsEquals(10, tadCAssertionF, fTad->tadOnlyShapeInfo)); + ASSERT_TRUE(arrsEquals(10, tadAssertionC, cTad->tadOnlyShapeInfo)); + + delete cTad; + delete fTad; } - class NormalThreeFourFive : public testing::Test { -public: - Nd4jLong assertionBuffer[8] = {2, 3, 4, 20, 5, 16384, 5, 99}; - Nd4jLong inputShapeBuffer[10] = {3,3,4,5,20,5,1,16384,1,99}; - int dimensionLength = 2; - int dimension[2] = {0,1}; + public: + Nd4jLong assertionBuffer[8] = {2, 3, 4, 20, 5, 16384, 5, 99}; + Nd4jLong inputShapeBuffer[10] = {3, 3, 4, 5, 20, 5, 1, 16384, 1, 99}; + int dimensionLength = 2; + int dimension[2] = {0, 1}; }; +TEST_F(NormalThreeFourFive, DimensionTest) { + shape::TAD *tad = new shape::TAD; + tad->init(inputShapeBuffer, dimension, dimensionLength); + tad->createTadOnlyShapeInfo(); + tad->createOffsets(); + ASSERT_TRUE(arrsEquals(8, assertionBuffer, tad->tadOnlyShapeInfo)); -TEST_F(NormalThreeFourFive,DimensionTest) { - shape::TAD *tad = new shape::TAD; - tad->init(inputShapeBuffer,dimension,dimensionLength); - tad->createTadOnlyShapeInfo(); - tad->createOffsets(); - ASSERT_TRUE(arrsEquals(8,assertionBuffer,tad->tadOnlyShapeInfo)); - - delete tad; + delete tad; } class DimensionWarning : public testing::Test { -public: - int dimensionLength = 2; - int dimensions[2] = {0,1}; - Nd4jLong shape[3] = {1,5,1}; - Nd4jLong *shapeBuffer = sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', 3, shape); - - ~DimensionWarning() { - delete[] shapeBuffer; - } + public: + int dimensionLength = 2; + int dimensions[2] = {0, 1}; + Nd4jLong shape[3] = {1, 5, 1}; + Nd4jLong *shapeBuffer = + sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', 3, shape); + + ~DimensionWarning() { delete[] shapeBuffer; } }; - -TEST_F(DimensionWarning,ShapeWarning) { - shape::TAD *tad = new shape::TAD; - tad->init(shapeBuffer,dimensions,dimensionLength); - tad->createTadOnlyShapeInfo(); - tad->createOffsets(); - delete tad; +TEST_F(DimensionWarning, ShapeWarning) { + shape::TAD *tad = new shape::TAD; + tad->init(shapeBuffer, dimensions, dimensionLength); + tad->createTadOnlyShapeInfo(); + tad->createOffsets(); + delete tad; } - class TadRank : public testing::Test { - Nd4jLong shapeBuffer[12] = {4,2,1,3,3,9,9,3,1,0,1,99}; - int dimensionLength = 2; - int dimension[2] = {2,3}; - + Nd4jLong shapeBuffer[12] = {4, 2, 1, 3, 3, 9, 9, 3, 1, 0, 1, 99}; + int dimensionLength = 2; + int dimension[2] = {2, 3}; }; class TestRemoveIndex : public testing::Test {}; @@ -281,95 +259,88 @@ class SliceMatrixTest : public testing::Test {}; class SliceTensorTest : public testing::Test {}; class ElementWiseStrideTest : public testing::Test { -public: - Nd4jLong shape[3] = {3,4,5}; - Nd4jLong stride[2] = {20,5}; - int elementWiseStrideAssertion = -1; + public: + Nd4jLong shape[3] = {3, 4, 5}; + Nd4jLong stride[2] = {20, 5}; + int elementWiseStrideAssertion = -1; }; -class PermuteTest : public testing::Test{}; +class PermuteTest : public testing::Test {}; -class LengthPerSliceTest : public testing::Test{}; +class LengthPerSliceTest : public testing::Test {}; class ExpectedValuesTest : public testing::Test { -public: - Nd4jLong mainShape[4] = {9,7,5,3}; - int testDimensions[3] = {0,2,3}; - + public: + Nd4jLong mainShape[4] = {9, 7, 5, 3}; + int testDimensions[3] = {0, 2, 3}; }; class BeginOneTadTest : public testing::Test { -public: - Nd4jLong assertionShapeBuffer[8] = {2,3,5,1,3,16384,1,102}; - Nd4jLong inputShapeBuffer[10] = {3,1,3,5,1,1,3,16384,0,102}; - int dimensionLength = 2; - int dimension[2] = {1,2}; - //error: [2,1,1,1,1,0,1,97] + public: + Nd4jLong assertionShapeBuffer[8] = {2, 3, 5, 1, 3, 16384, 1, 102}; + Nd4jLong inputShapeBuffer[10] = {3, 1, 3, 5, 1, 1, 3, 16384, 0, 102}; + int dimensionLength = 2; + int dimension[2] = {1, 2}; + // error: [2,1,1,1,1,0,1,97] }; class FourDTest : public testing::Test { - /** - * INDArray array3d = Nd4j.ones(1, 10, 10); + /** + * INDArray array3d = Nd4j.ones(1, 10, 10); array3d.sum(1); INDArray array4d = Nd4j.ones(1, 10, 10, 10); INDArray sum40 = array4d.sum(0); - */ -public: - Nd4jLong threeDShape[3] = {1,10,10}; - Nd4jLong fourDShape[4] = {1,10,10,10}; - Nd4jLong *threeDShapeBuffer = nullptr,*fourDShapeBuffer = nullptr; - int dimensionThree = 1; - int dimensionThreeTwo = 0; - int dimensionFour = 0; - int dimensionLength = 1; - FourDTest() { - threeDShapeBuffer = sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'f', 3, threeDShape); - fourDShapeBuffer = sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'f', 4, fourDShape); - } - ~FourDTest() { - if(threeDShapeBuffer != nullptr) - delete[] threeDShapeBuffer; - if(fourDShapeBuffer != nullptr) - delete[] fourDShapeBuffer; - } - - - + */ + public: + Nd4jLong threeDShape[3] = {1, 10, 10}; + Nd4jLong fourDShape[4] = {1, 10, 10, 10}; + Nd4jLong *threeDShapeBuffer = nullptr, *fourDShapeBuffer = nullptr; + int dimensionThree = 1; + int dimensionThreeTwo = 0; + int dimensionFour = 0; + int dimensionLength = 1; + FourDTest() { + threeDShapeBuffer = sd::ShapeBuilders::createShapeInfo( + sd::DataType::FLOAT32, 'f', 3, threeDShape); + fourDShapeBuffer = sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, + 'f', 4, fourDShape); + } + ~FourDTest() { + if (threeDShapeBuffer != nullptr) delete[] threeDShapeBuffer; + if (fourDShapeBuffer != nullptr) delete[] fourDShapeBuffer; + } }; - - -TEST_F(FourDTest,ThreeDFourDTest) { - shape::TAD *threeTadTwo = new shape::TAD; - threeTadTwo->init(threeDShapeBuffer,&dimensionThreeTwo,dimensionLength); - threeTadTwo->createTadOnlyShapeInfo(); - threeTadTwo->createOffsets(); - - shape::TAD *threeTad = new shape::TAD; - threeTad->init(threeDShapeBuffer,&dimensionThree,dimensionLength); - threeTad->createTadOnlyShapeInfo(); - threeTad->createOffsets(); - - shape::TAD *fourTad = new shape::TAD; - fourTad->init(fourDShapeBuffer,&dimensionFour,dimensionLength); - fourTad->createTadOnlyShapeInfo(); - fourTad->createOffsets(); - - delete threeTadTwo; - delete threeTad; - delete fourTad; +TEST_F(FourDTest, ThreeDFourDTest) { + shape::TAD *threeTadTwo = new shape::TAD; + threeTadTwo->init(threeDShapeBuffer, &dimensionThreeTwo, dimensionLength); + threeTadTwo->createTadOnlyShapeInfo(); + threeTadTwo->createOffsets(); + + shape::TAD *threeTad = new shape::TAD; + threeTad->init(threeDShapeBuffer, &dimensionThree, dimensionLength); + threeTad->createTadOnlyShapeInfo(); + threeTad->createOffsets(); + + shape::TAD *fourTad = new shape::TAD; + fourTad->init(fourDShapeBuffer, &dimensionFour, dimensionLength); + fourTad->createTadOnlyShapeInfo(); + fourTad->createOffsets(); + + delete threeTadTwo; + delete threeTad; + delete fourTad; } - - class RowVectorOnesTest : public testing::Test { -public: - Nd4jLong shapeBuffer[12] = {4,4,3,1,1,3,1,1,1,8192,1,99}; // float32 type of shape - float data[12] = {1,2,3,4,5,6,7,8,9,10,11,12}; - Nd4jLong assertionBuffer[10] = {3,4,1,1,3,1,1,8192,0,99}; - int dimensionLength = 3; - int dimension[3] = {0,2,3}; + public: + Nd4jLong shapeBuffer[12] = {4, 4, 3, 1, 1, 3, + 1, 1, 1, 8192, 1, 99}; // float32 type of shape + float data[12] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; + Nd4jLong assertionBuffer[10] = {3, 4, 1, 1, 3, 1, 1, 8192, 0, 99}; + int dimensionLength = 3; + int dimension[3] = {0, 2, 3}; }; // TEST_F(RowVectorOnesTest,TadShape) { @@ -380,57 +351,58 @@ class RowVectorOnesTest : public testing::Test { // delete tad; // } - - class SixDTest : public testing::Test { -public: - Nd4jLong inputShapeBuffer[16] = {6,1,1,4,4,4,4,1,1,1,4,16,64,16384,1,102}; // shape with double data type - int dimensionLength = 2; - int dimension[2] = {2,3}; - Nd4jLong assertionShapeBuffer[8] = {2,4,4,1,4,16384,1,102}; // also double typed shape + public: + Nd4jLong inputShapeBuffer[16] = { + 6, 1, 1, 4, 4, 4, 4, 1, + 1, 1, 4, 16, 64, 16384, 1, 102}; // shape with double data type + int dimensionLength = 2; + int dimension[2] = {2, 3}; + Nd4jLong assertionShapeBuffer[8] = { + 2, 4, 4, 1, 4, 16384, 1, 102}; // also double typed shape }; TEST_F(SixDTest, SixDWithOnes) { - shape::TAD *tad = new shape::TAD; - tad->init(inputShapeBuffer,dimension,dimensionLength); - tad->createTadOnlyShapeInfo(); - tad->createOffsets(); - // shape::printShapeInfoLinear(inputShapeBuffer); - // shape::printShapeInfoLinear(tad->tadOnlyShapeInfo); - //[2,1,1,1,1,0,1,97] - ASSERT_TRUE(arrsEquals(8,assertionShapeBuffer,tad->tadOnlyShapeInfo)); - delete tad; + shape::TAD *tad = new shape::TAD; + tad->init(inputShapeBuffer, dimension, dimensionLength); + tad->createTadOnlyShapeInfo(); + tad->createOffsets(); + // shape::printShapeInfoLinear(inputShapeBuffer); + // shape::printShapeInfoLinear(tad->tadOnlyShapeInfo); + //[2,1,1,1,1,0,1,97] + ASSERT_TRUE(arrsEquals(8, assertionShapeBuffer, tad->tadOnlyShapeInfo)); + delete tad; } class TrailingTest : public testing::Test { -public: - Nd4jLong inputShapeBuffer[12] = {4,5,5,5,1,1,5,25,125,16384,1,102}; - int dimensionLength = 1; - int dimension[1] = {0}; - Nd4jLong assertionShapeBuffer[8] = {2,1,5,125,1,16384,1,102}; + public: + Nd4jLong inputShapeBuffer[12] = {4, 5, 5, 5, 1, 1, 5, 25, 125, 16384, 1, 102}; + int dimensionLength = 1; + int dimension[1] = {0}; + Nd4jLong assertionShapeBuffer[8] = {2, 1, 5, 125, 1, 16384, 1, 102}; }; -TEST_F(TrailingTest,TrailingTest2) { - shape::TAD *tad = new shape::TAD; - tad->init(inputShapeBuffer,dimension,dimensionLength); - tad->createTadOnlyShapeInfo(); - tad->createOffsets(); - //[2,1,1,1,1,0,1,97] - ASSERT_TRUE(arrsEquals(8,assertionShapeBuffer,tad->tadOnlyShapeInfo)); - delete tad; +TEST_F(TrailingTest, TrailingTest2) { + shape::TAD *tad = new shape::TAD; + tad->init(inputShapeBuffer, dimension, dimensionLength); + tad->createTadOnlyShapeInfo(); + tad->createOffsets(); + //[2,1,1,1,1,0,1,97] + ASSERT_TRUE(arrsEquals(8, assertionShapeBuffer, tad->tadOnlyShapeInfo)); + delete tad; } - class ScalarTest : public testing::Test { -public: - Nd4jLong inputShapeBuffer[12] = {3,2,3,4,12,4,1,16384,1,99}; - int dimensionLength = 1; - int dimension[1] = {1}; - Nd4jLong assertionShapeBuffer[8] = {2,1,1,1,1,16384,1,99}; + public: + Nd4jLong inputShapeBuffer[12] = {3, 2, 3, 4, 12, 4, 1, 16384, 1, 99}; + int dimensionLength = 1; + int dimension[1] = {1}; + Nd4jLong assertionShapeBuffer[8] = {2, 1, 1, 1, 1, 16384, 1, 99}; }; /* TEST_F(ScalarTest,ScalarTest2) { - shape::TAD *tad = new shape::TAD(inputShapeBuffer,dimension,dimensionLength); + shape::TAD *tad = new +shape::TAD(inputShapeBuffer,dimension,dimensionLength); tad->createTadOnlyShapeInfo(); tad ->createOffsets(); //[2,1,1,1,1,0,1,97] @@ -439,37 +411,34 @@ TEST_F(ScalarTest,ScalarTest2) { } */ - - class ThreeTest : public testing::Test { -public: - Nd4jLong inputShapeBuffer[10] = {3,4,3,2,6,2,1,16384,1,99}; - int dimensionLength = 1; - int dimension[1] = {0}; - Nd4jLong assertionShapeBuffer[8] = {2,1,4,1,6,16384,6,99}; + public: + Nd4jLong inputShapeBuffer[10] = {3, 4, 3, 2, 6, 2, 1, 16384, 1, 99}; + int dimensionLength = 1; + int dimension[1] = {0}; + Nd4jLong assertionShapeBuffer[8] = {2, 1, 4, 1, 6, 16384, 6, 99}; }; -TEST_F(ThreeTest,ThreeTest ) { - shape::TAD *tad = new shape::TAD; - tad->init(inputShapeBuffer,dimension,dimensionLength); - tad->createTadOnlyShapeInfo(); - tad->createOffsets(); - //[2,1,1,1,1,0,1,97] - ASSERT_TRUE(arrsEquals(8,assertionShapeBuffer,tad->tadOnlyShapeInfo)); - delete tad; +TEST_F(ThreeTest, ThreeTest) { + shape::TAD *tad = new shape::TAD; + tad->init(inputShapeBuffer, dimension, dimensionLength); + tad->createTadOnlyShapeInfo(); + tad->createOffsets(); + //[2,1,1,1,1,0,1,97] + ASSERT_TRUE(arrsEquals(8, assertionShapeBuffer, tad->tadOnlyShapeInfo)); + delete tad; } - TEST_F(BeginOneTadTest, TadTest) { - shape::TAD *tad = new shape::TAD; - tad->init(inputShapeBuffer,dimension,dimensionLength); - tad->createTadOnlyShapeInfo(); - auto tadShapeBuffer = tad->tadOnlyShapeInfo; - // shape::printShapeInfoLinear(tadShapeBuffer); - //[2,1,1,1,1,0,1,97] - ASSERT_TRUE(arrsEquals(8,assertionShapeBuffer,tadShapeBuffer)); - - delete tad; + shape::TAD *tad = new shape::TAD; + tad->init(inputShapeBuffer, dimension, dimensionLength); + tad->createTadOnlyShapeInfo(); + auto tadShapeBuffer = tad->tadOnlyShapeInfo; + // shape::printShapeInfoLinear(tadShapeBuffer); + //[2,1,1,1,1,0,1,97] + ASSERT_TRUE(arrsEquals(8, assertionShapeBuffer, tadShapeBuffer)); + + delete tad; } /* @@ -481,338 +450,352 @@ TEST_F(OnesTest,OnesTadTest) { } */ -TEST_F(LabelTest,LabelTad) { - shape::TAD *tad = new shape::TAD; - tad->init(shapeInfo,dimension,dimensionLength); - tad->createTadOnlyShapeInfo(); - auto tadShapeInfo = tad->tadOnlyShapeInfo; - ASSERT_TRUE(arrsEquals(8,tadShapeInfoAssert,tadShapeInfo)); +TEST_F(LabelTest, LabelTad) { + shape::TAD *tad = new shape::TAD; + tad->init(shapeInfo, dimension, dimensionLength); + tad->createTadOnlyShapeInfo(); + auto tadShapeInfo = tad->tadOnlyShapeInfo; + ASSERT_TRUE(arrsEquals(8, tadShapeInfoAssert, tadShapeInfo)); - delete tad; + delete tad; } -TEST_F(ExpectedValuesTest,TadTest) { - auto shapeBuffer = sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', 4, mainShape); - shape::TAD *tad = new shape::TAD; - tad->init(shapeBuffer,testDimensions,3); - tad->createTadOnlyShapeInfo(); - auto shapeInfo = tad->tadOnlyShapeInfo; - - delete tad; - delete[] shapeBuffer; -} +TEST_F(ExpectedValuesTest, TadTest) { + auto shapeBuffer = sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, + 'c', 4, mainShape); + shape::TAD *tad = new shape::TAD; + tad->init(shapeBuffer, testDimensions, 3); + tad->createTadOnlyShapeInfo(); + auto shapeInfo = tad->tadOnlyShapeInfo; -TEST_F(OrderTest,testOrder) { - int rank = shape::rank(expected); - auto expectedShape = shape::shapeOf(expected); - auto expectedStride = shape::stride(expected); - int realOrder = shape::getOrder(rank,expectedShape,expectedStride,1); - int expectedOrder = 102; - ASSERT_EQ(expectedOrder,realOrder); + delete tad; + delete[] shapeBuffer; } - -TEST_F(ThreeDTest,TensorAlongDimensionTest) { - int dimension[2] = {0,2}; - Nd4jLong tadShapeAssertion[2] = {3,5}; - Nd4jLong strideAssertion[2] = {20,1}; - shape::TAD *tad = new shape::TAD; - tad->init(0,this->shapeBuffer,dimension,2); - tad->createTadOnlyShapeInfo(); - auto shapeBufferTest = tad->tadOnlyShapeInfo; - auto shapeTest = shape::shapeOf(shapeBufferTest); - auto strideTest = shape::stride(shapeBufferTest); - ASSERT_TRUE(arrsEquals(2,tadShapeAssertion,shapeTest)); - ASSERT_TRUE(arrsEquals(2,strideAssertion,strideTest)); - delete tad; +TEST_F(OrderTest, testOrder) { + int rank = shape::rank(expected); + auto expectedShape = shape::shapeOf(expected); + auto expectedStride = shape::stride(expected); + int realOrder = shape::getOrder(rank, expectedShape, expectedStride, 1); + int expectedOrder = 102; + ASSERT_EQ(expectedOrder, realOrder); } - -TEST_F(NumTadTests,TadTest) { - auto shape = sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', 3, this->shape); - shape::TAD *tad = new shape::TAD; - tad->init(shape,&dimension,1); - int numTads = shape::tensorsAlongDimension(shape,&dimension,1); - ASSERT_EQ(20,numTads); - delete[] shape; - delete tad; +TEST_F(ThreeDTest, TensorAlongDimensionTest) { + int dimension[2] = {0, 2}; + Nd4jLong tadShapeAssertion[2] = {3, 5}; + Nd4jLong strideAssertion[2] = {20, 1}; + shape::TAD *tad = new shape::TAD; + tad->init(0, this->shapeBuffer, dimension, 2); + tad->createTadOnlyShapeInfo(); + auto shapeBufferTest = tad->tadOnlyShapeInfo; + auto shapeTest = shape::shapeOf(shapeBufferTest); + auto strideTest = shape::stride(shapeBufferTest); + ASSERT_TRUE(arrsEquals(2, tadShapeAssertion, shapeTest)); + ASSERT_TRUE(arrsEquals(2, strideAssertion, strideTest)); + delete tad; } -TEST_F(TADStall,TestStall) { - auto shapeInfo = sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', 4, shape); - shape::TAD *tad = new shape::TAD; - tad->init(0,shapeInfo,this->dimensions,3); - tad->createTadOnlyShapeInfo(); - Nd4jLong *test = tad->tadOnlyShapeInfo; - - delete[] shapeInfo; - delete tad; +TEST_F(NumTadTests, TadTest) { + auto shape = sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', 3, + this->shape); + shape::TAD *tad = new shape::TAD; + tad->init(shape, &dimension, 1); + int numTads = shape::tensorsAlongDimension(shape, &dimension, 1); + ASSERT_EQ(20, numTads); + delete[] shape; + delete tad; } +TEST_F(TADStall, TestStall) { + auto shapeInfo = + sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', 4, shape); + shape::TAD *tad = new shape::TAD; + tad->init(0, shapeInfo, this->dimensions, 3); + tad->createTadOnlyShapeInfo(); + Nd4jLong *test = tad->tadOnlyShapeInfo; -TEST_F(LengthPerSliceTest,TestLengthPerSlice) { - Nd4jLong firstShape[2] = {5,3}; - int lengthPerSliceAssertionFirst = 3; - int firstDimension = 0; - int lengthPerSliceTest = shape::lengthPerSlice(2,firstShape,&firstDimension,1); - ASSERT_EQ(lengthPerSliceAssertionFirst,lengthPerSliceTest); + delete[] shapeInfo; + delete tad; } -TEST_F(PermuteTest,PermuteShapeBufferTest) { - int permuteOrder[4] = {3,2,1,0}; - int normalOrder[4] = {0,1,2,3}; - Nd4jLong shapeToPermute[4] = {5,3,2,6}; - Nd4jLong permutedOrder[4] = {6,2,3,5}; - auto shapeBufferOriginal = sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', 4, shapeToPermute); - auto assertionShapeBuffer = sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', 4, shapeToPermute); - shape::permuteShapeBufferInPlace(shapeBufferOriginal,normalOrder,shapeBufferOriginal); - EXPECT_TRUE(arrsEquals(4,assertionShapeBuffer,shapeBufferOriginal)); - - auto backwardsAssertion = sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', 4, permutedOrder); - auto permuted = shape::permuteShapeBuffer(assertionShapeBuffer, permuteOrder); - EXPECT_TRUE(arrsEquals(4, backwardsAssertion, permuted)); - - - delete[] permuted; - delete[] backwardsAssertion; - delete[] shapeBufferOriginal; - delete[] assertionShapeBuffer; +TEST_F(LengthPerSliceTest, TestLengthPerSlice) { + Nd4jLong firstShape[2] = {5, 3}; + int lengthPerSliceAssertionFirst = 3; + int firstDimension = 0; + int lengthPerSliceTest = + shape::lengthPerSlice(2, firstShape, &firstDimension, 1); + ASSERT_EQ(lengthPerSliceAssertionFirst, lengthPerSliceTest); } -TEST_F(ElementWiseStrideTest,ElementWiseStrideTest) { - +TEST_F(PermuteTest, PermuteShapeBufferTest) { + int permuteOrder[4] = {3, 2, 1, 0}; + int normalOrder[4] = {0, 1, 2, 3}; + Nd4jLong shapeToPermute[4] = {5, 3, 2, 6}; + Nd4jLong permutedOrder[4] = {6, 2, 3, 5}; + auto shapeBufferOriginal = sd::ShapeBuilders::createShapeInfo( + sd::DataType::FLOAT32, 'c', 4, shapeToPermute); + auto assertionShapeBuffer = sd::ShapeBuilders::createShapeInfo( + sd::DataType::FLOAT32, 'c', 4, shapeToPermute); + shape::permuteShapeBufferInPlace(shapeBufferOriginal, normalOrder, + shapeBufferOriginal); + EXPECT_TRUE(arrsEquals(4, assertionShapeBuffer, shapeBufferOriginal)); + + auto backwardsAssertion = sd::ShapeBuilders::createShapeInfo( + sd::DataType::FLOAT32, 'c', 4, permutedOrder); + auto permuted = shape::permuteShapeBuffer(assertionShapeBuffer, permuteOrder); + EXPECT_TRUE(arrsEquals(4, backwardsAssertion, permuted)); + + delete[] permuted; + delete[] backwardsAssertion; + delete[] shapeBufferOriginal; + delete[] assertionShapeBuffer; } -TEST_F(SliceVectorTest,RowColumnVectorTest) { - Nd4jLong rowVectorShape[2] = {1,5}; - auto rowVectorShapeInfo = sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', 2, rowVectorShape); - Nd4jLong colVectorShape[2] = {5,1}; - auto colVectorShapeInfo = sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', 2, colVectorShape); - Nd4jLong *sliceRow = shape::sliceOfShapeBuffer(0,rowVectorShapeInfo); - EXPECT_TRUE(arrsEquals(2,rowVectorShapeInfo,sliceRow)); - Nd4jLong *scalarSliceInfo = shape::createScalarShapeInfo(); - Nd4jLong *scalarColumnAssertion = shape::createScalarShapeInfo(); - scalarColumnAssertion[shape::shapeInfoLength(2) - 3] = 1; - Nd4jLong *scalarColumnTest = shape::sliceOfShapeBuffer(1L,colVectorShapeInfo); - EXPECT_TRUE(arrsEquals(2,scalarColumnAssertion,scalarColumnTest)); - - delete[] scalarColumnTest; - delete[] scalarColumnAssertion; - delete[] scalarSliceInfo; - delete[] sliceRow; - delete[] rowVectorShapeInfo; - delete[] colVectorShapeInfo; +TEST_F(ElementWiseStrideTest, ElementWiseStrideTest) {} + +TEST_F(SliceVectorTest, RowColumnVectorTest) { + Nd4jLong rowVectorShape[2] = {1, 5}; + auto rowVectorShapeInfo = sd::ShapeBuilders::createShapeInfo( + sd::DataType::FLOAT32, 'c', 2, rowVectorShape); + Nd4jLong colVectorShape[2] = {5, 1}; + auto colVectorShapeInfo = sd::ShapeBuilders::createShapeInfo( + sd::DataType::FLOAT32, 'c', 2, colVectorShape); + Nd4jLong *sliceRow = shape::sliceOfShapeBuffer(0, rowVectorShapeInfo); + EXPECT_TRUE(arrsEquals(2, rowVectorShapeInfo, sliceRow)); + Nd4jLong *scalarSliceInfo = shape::createScalarShapeInfo(); + Nd4jLong *scalarColumnAssertion = shape::createScalarShapeInfo(); + scalarColumnAssertion[shape::shapeInfoLength(2) - 3] = 1; + Nd4jLong *scalarColumnTest = + shape::sliceOfShapeBuffer(1L, colVectorShapeInfo); + EXPECT_TRUE(arrsEquals(2, scalarColumnAssertion, scalarColumnTest)); + + delete[] scalarColumnTest; + delete[] scalarColumnAssertion; + delete[] scalarSliceInfo; + delete[] sliceRow; + delete[] rowVectorShapeInfo; + delete[] colVectorShapeInfo; } -TEST_F(SliceTensorTest,TestSlice) { - Nd4jLong shape[3] = {3,3,2}; - auto shapeBuffer = sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', 3, shape); - Nd4jLong sliceShape[2] = {3,2}; - auto sliceShapeBuffer = sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', 2, sliceShape); - Nd4jLong *testSlice = shape::sliceOfShapeBuffer(0,shapeBuffer); - EXPECT_TRUE(arrsEquals(2,sliceShapeBuffer,testSlice)); - delete[] testSlice; - delete[] shapeBuffer; - delete[] sliceShapeBuffer; - -} - -TEST_F(SliceMatrixTest,TestSlice) { - Nd4jLong shape[2] = {3,2}; - auto shapeBuffer = sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', 2, shape); - Nd4jLong sliceShape[2] = {1,2}; - auto sliceShapeBuffer = sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', 2, sliceShape); - Nd4jLong *testSlice = shape::sliceOfShapeBuffer(0,shapeBuffer); - EXPECT_TRUE(arrsEquals(2,sliceShapeBuffer,testSlice)); - delete[] testSlice; - delete[] shapeBuffer; - delete[] sliceShapeBuffer; - +TEST_F(SliceTensorTest, TestSlice) { + Nd4jLong shape[3] = {3, 3, 2}; + auto shapeBuffer = + sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', 3, shape); + Nd4jLong sliceShape[2] = {3, 2}; + auto sliceShapeBuffer = sd::ShapeBuilders::createShapeInfo( + sd::DataType::FLOAT32, 'c', 2, sliceShape); + Nd4jLong *testSlice = shape::sliceOfShapeBuffer(0, shapeBuffer); + EXPECT_TRUE(arrsEquals(2, sliceShapeBuffer, testSlice)); + delete[] testSlice; + delete[] shapeBuffer; + delete[] sliceShapeBuffer; } - -TEST_F(TestConcat,ConcatTest) { - Nd4jLong firstArr[2] = {1,2}; - Nd4jLong secondConcat[2] = {3,4}; - Nd4jLong concatAssertion[4] = {1,2,3,4}; - Nd4jLong *concatTest = shape::concat(firstArr,2,secondConcat,2); - EXPECT_TRUE(arrsEquals(4,concatAssertion,concatTest)); - delete[] concatTest; +TEST_F(SliceMatrixTest, TestSlice) { + Nd4jLong shape[2] = {3, 2}; + auto shapeBuffer = + sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', 2, shape); + Nd4jLong sliceShape[2] = {1, 2}; + auto sliceShapeBuffer = sd::ShapeBuilders::createShapeInfo( + sd::DataType::FLOAT32, 'c', 2, sliceShape); + Nd4jLong *testSlice = shape::sliceOfShapeBuffer(0, shapeBuffer); + EXPECT_TRUE(arrsEquals(2, sliceShapeBuffer, testSlice)); + delete[] testSlice; + delete[] shapeBuffer; + delete[] sliceShapeBuffer; } -TEST_F(TestReverseCopy,ReverseCopyTest) { - Nd4jLong toCopy[5] = {0,1,2,3,4}; - Nd4jLong reverseAssertion[5] = {4,3,2,1,0}; - Nd4jLong *reverseCopyTest = shape::reverseCopy(toCopy,5); - EXPECT_TRUE(arrsEquals(5,reverseAssertion,reverseCopyTest)); - delete[] reverseCopyTest; +TEST_F(TestConcat, ConcatTest) { + Nd4jLong firstArr[2] = {1, 2}; + Nd4jLong secondConcat[2] = {3, 4}; + Nd4jLong concatAssertion[4] = {1, 2, 3, 4}; + Nd4jLong *concatTest = shape::concat(firstArr, 2, secondConcat, 2); + EXPECT_TRUE(arrsEquals(4, concatAssertion, concatTest)); + delete[] concatTest; } -TEST_F(TestRemoveIndex,Remove) { - Nd4jLong input[5] = {0,1,2,3,4}; - Nd4jLong indexesToRemove[3] = {0,1,2}; - Nd4jLong indexesToRemoveAssertion[2] = {3,4}; - Nd4jLong *indexesToRemoveTest = shape::removeIndex(input,indexesToRemove, (Nd4jLong) 5, (Nd4jLong) 3); - EXPECT_TRUE(arrsEquals(2,indexesToRemoveAssertion,indexesToRemoveTest)); - delete[] indexesToRemoveTest; +TEST_F(TestReverseCopy, ReverseCopyTest) { + Nd4jLong toCopy[5] = {0, 1, 2, 3, 4}; + Nd4jLong reverseAssertion[5] = {4, 3, 2, 1, 0}; + Nd4jLong *reverseCopyTest = shape::reverseCopy(toCopy, 5); + EXPECT_TRUE(arrsEquals(5, reverseAssertion, reverseCopyTest)); + delete[] reverseCopyTest; } -TEST_F(TensorTwoFromFourDDimTest,TadTwoFromFourDimTest) { - //Along dimension 0,1: expect matrix with shape [rows,cols] - //Along dimension 0,2: expect matrix with shape [rows,dim2] - //Along dimension 0,3: expect matrix with shape [rows,dim3] - //Along dimension 1,2: expect matrix with shape [cols,dim2] - //Along dimension 1,3: expect matrix with shape [cols,dim3] - //Along dimension 2,3: expect matrix with shape [dim2,dim3] - auto baseShapeBuffer = sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', 4, shape); - for(int i = 0; i < 3; i++) { - int *dimArr = dims[i]; - Nd4jLong *expectedShape = expectedShapes[i]; - shape::TAD *tad = new shape::TAD; - tad->init(baseShapeBuffer,dimArr,dimensionLength); - auto expectedShapeBuffer = sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', dimensionLength, expectedShape); - tad->createTadOnlyShapeInfo(); - Nd4jLong *testShapeBuffer = tad->tadOnlyShapeInfo; - EXPECT_TRUE(arrsEquals(shape::rank(expectedShapeBuffer),expectedShape,shape::shapeOf(testShapeBuffer))); - EXPECT_TRUE(arrsEquals(shape::rank(expectedShapeBuffer),expectedStrides[i],shape::stride(testShapeBuffer))); - - delete[] expectedShapeBuffer; - delete tad; - } - - delete[] baseShapeBuffer; +TEST_F(TestRemoveIndex, Remove) { + Nd4jLong input[5] = {0, 1, 2, 3, 4}; + Nd4jLong indexesToRemove[3] = {0, 1, 2}; + Nd4jLong indexesToRemoveAssertion[2] = {3, 4}; + Nd4jLong *indexesToRemoveTest = shape::removeIndex( + input, indexesToRemove, (Nd4jLong)5, (Nd4jLong)3); + EXPECT_TRUE(arrsEquals(2, indexesToRemoveAssertion, indexesToRemoveTest)); + delete[] indexesToRemoveTest; } -TEST_F(TensorTwoDimTest,TadTwoDimTest) { - //Along dimension 0,1: expect matrix with shape [rows,cols] - //Along dimension 0,2: expect matrix with shape [rows,dim2] - //Along dimension 1,2: expect matrix with shape [cols,dim2] - auto baseShapeBuffer = sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', 3, shape); - - for(int i = 0; i < 3; i++) { - int *dimArr = dims[i]; - Nd4jLong *expectedShape = expectedShapes[i]; - shape::TAD *tad = new shape::TAD; - tad->init(baseShapeBuffer,dimArr,dimensionLength); - auto expectedShapeBuffer = sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', dimensionLength, expectedShape); - tad->createTadOnlyShapeInfo(); - Nd4jLong *testShapeBuffer = tad->tadOnlyShapeInfo; - Nd4jLong *expectedStride = expectedStrides[i]; - Nd4jLong *testShape = shape::shapeOf(testShapeBuffer); - Nd4jLong *testStride = shape::stride(testShapeBuffer); - EXPECT_TRUE(arrsEquals(shape::rank(expectedShapeBuffer),expectedShape,testShape)); - EXPECT_TRUE(arrsEquals(shape::rank(testShapeBuffer),expectedStride,testStride)); +TEST_F(TensorTwoFromFourDDimTest, TadTwoFromFourDimTest) { + // Along dimension 0,1: expect matrix with shape [rows,cols] + // Along dimension 0,2: expect matrix with shape [rows,dim2] + // Along dimension 0,3: expect matrix with shape [rows,dim3] + // Along dimension 1,2: expect matrix with shape [cols,dim2] + // Along dimension 1,3: expect matrix with shape [cols,dim3] + // Along dimension 2,3: expect matrix with shape [dim2,dim3] + auto baseShapeBuffer = + sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', 4, shape); + for (int i = 0; i < 3; i++) { + int *dimArr = dims[i]; + Nd4jLong *expectedShape = expectedShapes[i]; + shape::TAD *tad = new shape::TAD; + tad->init(baseShapeBuffer, dimArr, dimensionLength); + auto expectedShapeBuffer = sd::ShapeBuilders::createShapeInfo( + sd::DataType::FLOAT32, 'c', dimensionLength, expectedShape); + tad->createTadOnlyShapeInfo(); + Nd4jLong *testShapeBuffer = tad->tadOnlyShapeInfo; + EXPECT_TRUE(arrsEquals(shape::rank(expectedShapeBuffer), expectedShape, + shape::shapeOf(testShapeBuffer))); + EXPECT_TRUE(arrsEquals(shape::rank(expectedShapeBuffer), expectedStrides[i], + shape::stride(testShapeBuffer))); - delete[] expectedShapeBuffer; - delete tad; + delete[] expectedShapeBuffer; + delete tad; + } - } + delete[] baseShapeBuffer; +} - delete[] baseShapeBuffer; +TEST_F(TensorTwoDimTest, TadTwoDimTest) { + // Along dimension 0,1: expect matrix with shape [rows,cols] + // Along dimension 0,2: expect matrix with shape [rows,dim2] + // Along dimension 1,2: expect matrix with shape [cols,dim2] + auto baseShapeBuffer = + sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', 3, shape); + for (int i = 0; i < 3; i++) { + int *dimArr = dims[i]; + Nd4jLong *expectedShape = expectedShapes[i]; + shape::TAD *tad = new shape::TAD; + tad->init(baseShapeBuffer, dimArr, dimensionLength); + auto expectedShapeBuffer = sd::ShapeBuilders::createShapeInfo( + sd::DataType::FLOAT32, 'c', dimensionLength, expectedShape); + tad->createTadOnlyShapeInfo(); + Nd4jLong *testShapeBuffer = tad->tadOnlyShapeInfo; + Nd4jLong *expectedStride = expectedStrides[i]; + Nd4jLong *testShape = shape::shapeOf(testShapeBuffer); + Nd4jLong *testStride = shape::stride(testShapeBuffer); + EXPECT_TRUE( + arrsEquals(shape::rank(expectedShapeBuffer), expectedShape, testShape)); + EXPECT_TRUE( + arrsEquals(shape::rank(testShapeBuffer), expectedStride, testStride)); + + delete[] expectedShapeBuffer; + delete tad; + } + delete[] baseShapeBuffer; } -TEST_F(TensorOneDimTest,TadDimensionsForTensor) { - Nd4jLong shape[3] = {rows,cols,dim2}; - auto shapeBuffer = sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', rank, shape); - - for(int i = 0; i < rank; i++) { - //Along dimension 0: expect row vector with length 'dims[i]' - shape::TAD *zero = new shape::TAD; - zero->init(shapeBuffer,&dims[i],1); - zero->createTadOnlyShapeInfo(); - Nd4jLong *testDimZeroShapeBuffer = zero->tadOnlyShapeInfo; - Nd4jLong *testShape = shape::shapeOf(testDimZeroShapeBuffer); - Nd4jLong *testStride = shape::stride(testDimZeroShapeBuffer); - EXPECT_TRUE(arrsEquals(2,expectedShapes[i],testShape)); - EXPECT_TRUE(arrsEquals(2,expectedStrides[i],testStride)); - - delete zero; - } - - delete[] shapeBuffer; +TEST_F(TensorOneDimTest, TadDimensionsForTensor) { + Nd4jLong shape[3] = {rows, cols, dim2}; + auto shapeBuffer = sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, + 'c', rank, shape); + + for (int i = 0; i < rank; i++) { + // Along dimension 0: expect row vector with length 'dims[i]' + shape::TAD *zero = new shape::TAD; + zero->init(shapeBuffer, &dims[i], 1); + zero->createTadOnlyShapeInfo(); + Nd4jLong *testDimZeroShapeBuffer = zero->tadOnlyShapeInfo; + Nd4jLong *testShape = shape::shapeOf(testDimZeroShapeBuffer); + Nd4jLong *testStride = shape::stride(testDimZeroShapeBuffer); + EXPECT_TRUE(arrsEquals(2, expectedShapes[i], testShape)); + EXPECT_TRUE(arrsEquals(2, expectedStrides[i], testStride)); + + delete zero; + } + + delete[] shapeBuffer; } - -TEST_F(MatrixTest,TadDimensionsForMatrix) { - Nd4jLong shape[2] = {rows,cols}; - auto shapeBuffer = sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', rank, shape); - - shape::TAD *dimZero = new shape::TAD; - dimZero->init(shapeBuffer,&dims[0],1); - shape::TAD *dimOne = new shape::TAD; - dimOne->init(shapeBuffer,&dims[1],1); - //Along dimension 0: expect row vector with length 'rows' - Nd4jLong rowVectorShape[2] = {1,rows}; - auto expectedDimZeroShape = sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', 2, rowVectorShape); - dimZero->createTadOnlyShapeInfo(); - Nd4jLong *testDimZero = dimZero->tadOnlyShapeInfo; - EXPECT_TRUE(arrsEquals(2,expectedShapes[0],shape::shapeOf(testDimZero))); - EXPECT_TRUE(arrsEquals(2,expectedStrides[0],shape::stride(testDimZero))); - - delete[] expectedDimZeroShape; - //Along dimension 1: expect row vector with length 'cols' - Nd4jLong rowVectorColShape[2] {1,cols}; - auto expectedDimOneShape = sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', 2, rowVectorColShape); - dimOne->createTadOnlyShapeInfo(); - Nd4jLong *testDimOneShape = dimOne->tadOnlyShapeInfo; - EXPECT_TRUE(arrsEquals(2,expectedShapes[1],shape::shapeOf(testDimOneShape))); - EXPECT_TRUE(arrsEquals(2,expectedStrides[1],shape::stride(testDimOneShape))); - - delete[] expectedDimOneShape; - delete dimOne; - delete dimZero; - delete[] shapeBuffer; +TEST_F(MatrixTest, TadDimensionsForMatrix) { + Nd4jLong shape[2] = {rows, cols}; + auto shapeBuffer = sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, + 'c', rank, shape); + + shape::TAD *dimZero = new shape::TAD; + dimZero->init(shapeBuffer, &dims[0], 1); + shape::TAD *dimOne = new shape::TAD; + dimOne->init(shapeBuffer, &dims[1], 1); + // Along dimension 0: expect row vector with length 'rows' + Nd4jLong rowVectorShape[2] = {1, rows}; + auto expectedDimZeroShape = sd::ShapeBuilders::createShapeInfo( + sd::DataType::FLOAT32, 'c', 2, rowVectorShape); + dimZero->createTadOnlyShapeInfo(); + Nd4jLong *testDimZero = dimZero->tadOnlyShapeInfo; + EXPECT_TRUE(arrsEquals(2, expectedShapes[0], shape::shapeOf(testDimZero))); + EXPECT_TRUE(arrsEquals(2, expectedStrides[0], shape::stride(testDimZero))); + + delete[] expectedDimZeroShape; + // Along dimension 1: expect row vector with length 'cols' + Nd4jLong rowVectorColShape[2]{1, cols}; + auto expectedDimOneShape = sd::ShapeBuilders::createShapeInfo( + sd::DataType::FLOAT32, 'c', 2, rowVectorColShape); + dimOne->createTadOnlyShapeInfo(); + Nd4jLong *testDimOneShape = dimOne->tadOnlyShapeInfo; + EXPECT_TRUE( + arrsEquals(2, expectedShapes[1], shape::shapeOf(testDimOneShape))); + EXPECT_TRUE( + arrsEquals(2, expectedStrides[1], shape::stride(testDimOneShape))); + + delete[] expectedDimOneShape; + delete dimOne; + delete dimZero; + delete[] shapeBuffer; } -TEST_F(VectorTest,VectorTadShape) { - Nd4jLong rowVector[2] = {2,2}; - auto rowBuffer = sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', 2, rowVector); - int rowDimension = 1; - - Nd4jLong columnVector[2] = {2,2}; - auto colShapeBuffer = sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', 2, columnVector); - int colDimension = 0; - - - shape::TAD *rowTad = new shape::TAD; - rowTad->init(rowBuffer,&rowDimension,1); - rowTad->createTadOnlyShapeInfo(); - Nd4jLong *rowTadShapeBuffer = rowTad->tadOnlyShapeInfo; - Nd4jLong *rowTadShape = shape::shapeOf(rowTadShapeBuffer); - shape::TAD *colTad = new shape::TAD; - colTad->init(colShapeBuffer,&colDimension,1); - colTad->createTadOnlyShapeInfo(); - Nd4jLong *colTadShapeBuffer = colTad->tadOnlyShapeInfo; - Nd4jLong *colTadShape = shape::shapeOf(colTadShapeBuffer); - Nd4jLong assertionShape[2] = {1,2}; - Nd4jLong assertionStride[2] = {1,1}; - EXPECT_TRUE(arrsEquals(2,assertionShape,rowTadShape)); - EXPECT_TRUE(arrsEquals(2,assertionStride,shape::stride(rowTadShapeBuffer))); - EXPECT_TRUE(arrsEquals(2,assertionShape,colTadShape)); - - delete[] rowBuffer; - delete[] colShapeBuffer; - delete rowTad; - delete colTad; +TEST_F(VectorTest, VectorTadShape) { + Nd4jLong rowVector[2] = {2, 2}; + auto rowBuffer = sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, + 'c', 2, rowVector); + int rowDimension = 1; + + Nd4jLong columnVector[2] = {2, 2}; + auto colShapeBuffer = sd::ShapeBuilders::createShapeInfo( + sd::DataType::FLOAT32, 'c', 2, columnVector); + int colDimension = 0; + + shape::TAD *rowTad = new shape::TAD; + rowTad->init(rowBuffer, &rowDimension, 1); + rowTad->createTadOnlyShapeInfo(); + Nd4jLong *rowTadShapeBuffer = rowTad->tadOnlyShapeInfo; + Nd4jLong *rowTadShape = shape::shapeOf(rowTadShapeBuffer); + shape::TAD *colTad = new shape::TAD; + colTad->init(colShapeBuffer, &colDimension, 1); + colTad->createTadOnlyShapeInfo(); + Nd4jLong *colTadShapeBuffer = colTad->tadOnlyShapeInfo; + Nd4jLong *colTadShape = shape::shapeOf(colTadShapeBuffer); + Nd4jLong assertionShape[2] = {1, 2}; + Nd4jLong assertionStride[2] = {1, 1}; + EXPECT_TRUE(arrsEquals(2, assertionShape, rowTadShape)); + EXPECT_TRUE(arrsEquals(2, assertionStride, shape::stride(rowTadShapeBuffer))); + EXPECT_TRUE(arrsEquals(2, assertionShape, colTadShape)); + + delete[] rowBuffer; + delete[] colShapeBuffer; + delete rowTad; + delete colTad; } +TEST_F(ShapeTest, IsVector) { ASSERT_TRUE(shape::isVector(vectorShape, 2)); } +TEST_F(VectorTest, LinspaceCombinationTest) { + int rows = 3; + int cols = 4; + int len = rows * cols; + double *linspaced = linspace(1, rows * cols, len); + Nd4jLong shape[2] = {rows, cols}; + auto shapeBuffer = + sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', 2, shape); - -TEST_F(ShapeTest,IsVector) { - ASSERT_TRUE(shape::isVector(vectorShape,2)); -} - -TEST_F(VectorTest,LinspaceCombinationTest) { - int rows = 3; - int cols = 4; - int len = rows * cols; - double *linspaced = linspace(1,rows * cols,len); - Nd4jLong shape[2] = {rows,cols}; - auto shapeBuffer = sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', 2, shape); - - delete[] shapeBuffer; - delete[] linspaced; + delete[] shapeBuffer; + delete[] linspaced; } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/ShapeUtilsTests.cpp b/libnd4j/tests_cpu/layers_tests/ShapeUtilsTests.cpp index 25f4f2c18001..8644960932a7 100644 --- a/libnd4j/tests_cpu/layers_tests/ShapeUtilsTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/ShapeUtilsTests.cpp @@ -18,276 +18,253 @@ // Created by raver119 on 01.11.2017. // -#include "testlayers.h" -#include #include +#include +#include "testlayers.h" using namespace sd; using namespace sd::graph; class ShapeUtilsTests : public testing::Test { -public: - + public: }; ////////////////////////////////////////////////////////////////// TEST_F(ShapeUtilsTests, evalDimsToExclude_1) { - std::vector res = ShapeUtils::evalDimsToExclude(3, {0}); + std::vector res = ShapeUtils::evalDimsToExclude(3, {0}); - ASSERT_EQ(2, res.size()); - ASSERT_EQ(1, res.at(0)); - ASSERT_EQ(2, res.at(1)); + ASSERT_EQ(2, res.size()); + ASSERT_EQ(1, res.at(0)); + ASSERT_EQ(2, res.at(1)); } ////////////////////////////////////////////////////////////////// TEST_F(ShapeUtilsTests, evalDimsToExclude_2) { - std::vector res = ShapeUtils::evalDimsToExclude(4, {2, 3}); + std::vector res = ShapeUtils::evalDimsToExclude(4, {2, 3}); - ASSERT_EQ(2, res.size()); - ASSERT_EQ(0, res.at(0)); - ASSERT_EQ(1, res.at(1)); + ASSERT_EQ(2, res.size()); + ASSERT_EQ(0, res.at(0)); + ASSERT_EQ(1, res.at(1)); } ////////////////////////////////////////////////////////////////// -TEST_F(ShapeUtilsTests, EvalBroadcastShapeInfo_1) -{ - - Nd4jLong xShapeInfo[] = {3, 3, 2, 2, 4, 2, 1, 8192, 1, 99}; - Nd4jLong yShapeInfo[] = {2, 1, 2, 2, 1, 8192, 1, 99}; - Nd4jLong expShapeInfo[] = {3, 3, 2, 2, 4, 2, 1, 8192, 1, 99}; +TEST_F(ShapeUtilsTests, EvalBroadcastShapeInfo_1) { + Nd4jLong xShapeInfo[] = {3, 3, 2, 2, 4, 2, 1, 8192, 1, 99}; + Nd4jLong yShapeInfo[] = {2, 1, 2, 2, 1, 8192, 1, 99}; + Nd4jLong expShapeInfo[] = {3, 3, 2, 2, 4, 2, 1, 8192, 1, 99}; - NDArray x(xShapeInfo); - NDArray y(yShapeInfo); + NDArray x(xShapeInfo); + NDArray y(yShapeInfo); - const Nd4jLong *newShapeInfo = nullptr; - ShapeUtils::evalBroadcastShapeInfo(x, y, false, newShapeInfo, nullptr); + const Nd4jLong *newShapeInfo = nullptr; + ShapeUtils::evalBroadcastShapeInfo(x, y, false, newShapeInfo, nullptr); - ASSERT_TRUE(shape::equalsStrict(expShapeInfo, newShapeInfo)); + ASSERT_TRUE(shape::equalsStrict(expShapeInfo, newShapeInfo)); } ////////////////////////////////////////////////////////////////// -TEST_F(ShapeUtilsTests, EvalBroadcastShapeInfo_2) -{ +TEST_F(ShapeUtilsTests, EvalBroadcastShapeInfo_2) { + Nd4jLong xShapeInfo[] = {4, 8, 1, 6, 1, 6, 6, 1, 1, 8192, 1, 99}; + Nd4jLong yShapeInfo[] = {3, 7, 1, 5, 5, 5, 1, 8192, 1, 99}; + Nd4jLong expShapeInfo[] = {4, 8, 7, 6, 5, 210, 30, 5, 1, 8192, 1, 99}; - Nd4jLong xShapeInfo[] = {4, 8, 1, 6, 1, 6, 6, 1, 1, 8192, 1, 99}; - Nd4jLong yShapeInfo[] = {3, 7, 1, 5, 5, 5, 1, 8192, 1, 99}; - Nd4jLong expShapeInfo[] = {4, 8, 7, 6, 5, 210, 30, 5, 1, 8192, 1, 99}; + NDArray x(xShapeInfo); + NDArray y(yShapeInfo); - NDArray x(xShapeInfo); - NDArray y(yShapeInfo); + const Nd4jLong *newShapeInfo = nullptr; + ShapeUtils::evalBroadcastShapeInfo(x, y, false, newShapeInfo, nullptr); - const Nd4jLong *newShapeInfo = nullptr; - ShapeUtils::evalBroadcastShapeInfo(x, y, false, newShapeInfo, nullptr); - - ASSERT_TRUE(shape::equalsStrict(expShapeInfo, newShapeInfo)); + ASSERT_TRUE(shape::equalsStrict(expShapeInfo, newShapeInfo)); } ////////////////////////////////////////////////////////////////// -TEST_F(ShapeUtilsTests, EvalBroadcastShapeInfo_3) -{ - - Nd4jLong xShapeInfo[] = {3, 15, 3, 5, 15, 5, 1, 8192, 1, 99}; - Nd4jLong yShapeInfo[] = {3, 15, 1, 5, 5, 5, 1, 8192, 1, 99}; - Nd4jLong expShapeInfo[] = {3, 15, 3, 5, 15, 5, 1, 8192, 1, 99}; +TEST_F(ShapeUtilsTests, EvalBroadcastShapeInfo_3) { + Nd4jLong xShapeInfo[] = {3, 15, 3, 5, 15, 5, 1, 8192, 1, 99}; + Nd4jLong yShapeInfo[] = {3, 15, 1, 5, 5, 5, 1, 8192, 1, 99}; + Nd4jLong expShapeInfo[] = {3, 15, 3, 5, 15, 5, 1, 8192, 1, 99}; - NDArray x(xShapeInfo); - NDArray y(yShapeInfo); + NDArray x(xShapeInfo); + NDArray y(yShapeInfo); - const Nd4jLong *newShapeInfo = nullptr; - ShapeUtils::evalBroadcastShapeInfo(x, y, false, newShapeInfo, nullptr); + const Nd4jLong *newShapeInfo = nullptr; + ShapeUtils::evalBroadcastShapeInfo(x, y, false, newShapeInfo, nullptr); - ASSERT_TRUE(shape::equalsStrict(expShapeInfo, newShapeInfo)); + ASSERT_TRUE(shape::equalsStrict(expShapeInfo, newShapeInfo)); } ////////////////////////////////////////////////////////////////// -TEST_F(ShapeUtilsTests, EvalBroadcastShapeInfo_4) -{ +TEST_F(ShapeUtilsTests, EvalBroadcastShapeInfo_4) { + Nd4jLong xShapeInfo[] = {3, 8, 1, 3, 3, 3, 1, 8192, 1, 99}; + Nd4jLong yShapeInfo[] = {2, 4, 3, 3, 1, 8192, 1, 99}; + Nd4jLong expShapeInfo[] = {3, 8, 4, 3, 12, 3, 1, 8192, 1, 99}; - Nd4jLong xShapeInfo[] = {3, 8, 1, 3, 3, 3, 1, 8192, 1, 99}; - Nd4jLong yShapeInfo[] = {2, 4, 3, 3, 1, 8192, 1, 99}; - Nd4jLong expShapeInfo[] = {3, 8, 4, 3, 12, 3, 1, 8192, 1, 99}; + NDArray x(xShapeInfo); + NDArray y(yShapeInfo); - NDArray x(xShapeInfo); - NDArray y(yShapeInfo); + const Nd4jLong *newShapeInfo = nullptr; + ShapeUtils::evalBroadcastShapeInfo(x, y, false, newShapeInfo, nullptr); + // for(int i=0; i<2*newShapeInfo[0]+4; ++i) + // std::cout<('c',{2,3,4,5}); - auto expected = NDArrayFactory::create('c', {2,4,5}); - std::vector dimensions = {1}; +TEST_F(ShapeUtilsTests, evalReduceShapeInfo_test1) { + auto x = NDArrayFactory::create('c', {2, 3, 4, 5}); + auto expected = NDArrayFactory::create('c', {2, 4, 5}); + std::vector dimensions = {1}; - auto newShapeInfo = ShapeUtils::evalReduceShapeInfo('c', dimensions, x.shapeInfo()); + auto newShapeInfo = + ShapeUtils::evalReduceShapeInfo('c', dimensions, x.shapeInfo()); - ASSERT_TRUE(shape::shapeEquals(expected.shapeInfo(), newShapeInfo)); + ASSERT_TRUE(shape::shapeEquals(expected.shapeInfo(), newShapeInfo)); } ////////////////////////////////////////////////////////////////// -TEST_F(ShapeUtilsTests, evalReduceShapeInfo_test2) -{ +TEST_F(ShapeUtilsTests, evalReduceShapeInfo_test2) { + auto x = NDArrayFactory::create('c', {2, 3, 4, 5}); + auto expected = NDArrayFactory::create('c', {2, 1, 4, 5}); + std::vector dimensions = {1}; - auto x = NDArrayFactory::create('c',{2,3,4,5}); - auto expected = NDArrayFactory::create('c', {2,1,4,5}); - std::vector dimensions = {1}; + auto newShapeInfo = + ShapeUtils::evalReduceShapeInfo('c', dimensions, x.shapeInfo(), true); - auto newShapeInfo = ShapeUtils::evalReduceShapeInfo('c', dimensions, x.shapeInfo(), true); - - ASSERT_TRUE(shape::shapeEquals(expected.shapeInfo(), newShapeInfo)); + ASSERT_TRUE(shape::shapeEquals(expected.shapeInfo(), newShapeInfo)); } ////////////////////////////////////////////////////////////////// -TEST_F(ShapeUtilsTests, evalReduceShapeInfo_test3) -{ - - auto x = NDArrayFactory::create('c',{2,3,4,5}); - auto expected = NDArrayFactory::create('c', {1,1,1,5}); - std::vector dimensions = {0,1,2}; +TEST_F(ShapeUtilsTests, evalReduceShapeInfo_test3) { + auto x = NDArrayFactory::create('c', {2, 3, 4, 5}); + auto expected = NDArrayFactory::create('c', {1, 1, 1, 5}); + std::vector dimensions = {0, 1, 2}; - auto newShapeInfo = ShapeUtils::evalReduceShapeInfo('c', dimensions, x.shapeInfo(), true); - - ASSERT_TRUE(shape::shapeEquals(expected.shapeInfo(), newShapeInfo)); + auto newShapeInfo = + ShapeUtils::evalReduceShapeInfo('c', dimensions, x.shapeInfo(), true); + ASSERT_TRUE(shape::shapeEquals(expected.shapeInfo(), newShapeInfo)); } ////////////////////////////////////////////////////////////////// -TEST_F(ShapeUtilsTests, evalReduceShapeInfo_test4) -{ - - auto x = NDArrayFactory::create('c',{2,3,4,5}); - auto expected = NDArrayFactory::create('c', {1,1,1,1}); - std::vector dimensions = {0,1,2,3}; +TEST_F(ShapeUtilsTests, evalReduceShapeInfo_test4) { + auto x = NDArrayFactory::create('c', {2, 3, 4, 5}); + auto expected = NDArrayFactory::create('c', {1, 1, 1, 1}); + std::vector dimensions = {0, 1, 2, 3}; - auto newShapeInfo = ShapeUtils::evalReduceShapeInfo('c', dimensions, x.shapeInfo(), true); + auto newShapeInfo = + ShapeUtils::evalReduceShapeInfo('c', dimensions, x.shapeInfo(), true); - ASSERT_TRUE(shape::shapeEquals(expected.shapeInfo(), newShapeInfo)); + ASSERT_TRUE(shape::shapeEquals(expected.shapeInfo(), newShapeInfo)); } TEST_F(ShapeUtilsTests, Test_Strings_1) { - auto x = NDArrayFactory::create('c', {2, 3, 4, 5}); - std::string exp("[2, 3, 4, 5]"); + auto x = NDArrayFactory::create('c', {2, 3, 4, 5}); + std::string exp("[2, 3, 4, 5]"); - auto s = ShapeUtils::shapeAsString(&x); + auto s = ShapeUtils::shapeAsString(&x); - ASSERT_EQ(exp, s); + ASSERT_EQ(exp, s); } TEST_F(ShapeUtilsTests, Test_Backward_Axis_1) { - auto x = NDArrayFactory::create('c', {2, 4, 3}); - auto y = NDArrayFactory::create('c', {4, 3}); - std::vector exp({0}); + auto x = NDArrayFactory::create('c', {2, 4, 3}); + auto y = NDArrayFactory::create('c', {4, 3}); + std::vector exp({0}); - auto z = ShapeUtils::evalBroadcastBackwardAxis(y.shapeInfo(), x.shapeInfo()); + auto z = ShapeUtils::evalBroadcastBackwardAxis(y.shapeInfo(), x.shapeInfo()); - ASSERT_EQ(exp, z); + ASSERT_EQ(exp, z); } TEST_F(ShapeUtilsTests, Test_Backward_Axis_2) { - auto x = NDArrayFactory::create('c', {2, 4, 4, 3}); - auto y = NDArrayFactory::create('c', {4, 1, 3}); - std::vector exp({0, 2}); + auto x = NDArrayFactory::create('c', {2, 4, 4, 3}); + auto y = NDArrayFactory::create('c', {4, 1, 3}); + std::vector exp({0, 2}); - auto z = ShapeUtils::evalBroadcastBackwardAxis(y.shapeInfo(), x.shapeInfo()); + auto z = ShapeUtils::evalBroadcastBackwardAxis(y.shapeInfo(), x.shapeInfo()); - ASSERT_EQ(exp, z); + ASSERT_EQ(exp, z); } - TEST_F(ShapeUtilsTests, Test_Backward_Axis_3) { - auto x = NDArrayFactory::create('c', {2, 4, 4, 3}); - auto y = NDArrayFactory::create('c', {2, 1, 1, 3}); - std::vector exp({1, 2}); + auto x = NDArrayFactory::create('c', {2, 4, 4, 3}); + auto y = NDArrayFactory::create('c', {2, 1, 1, 3}); + std::vector exp({1, 2}); - auto z = ShapeUtils::evalBroadcastBackwardAxis(y.shapeInfo(), x.shapeInfo()); + auto z = ShapeUtils::evalBroadcastBackwardAxis(y.shapeInfo(), x.shapeInfo()); - ASSERT_EQ(exp, z); + ASSERT_EQ(exp, z); } ////////////////////////////////////////////////////////////////// TEST_F(ShapeUtilsTests, evalPermutFromTo_test1) { + int a = 1, b = 2, c = 3, d = 4; + std::vector expected = {2, 3, 0, 1}; - int a=1, b=2, c=3, d=4; - std::vector expected = {2, 3, 0, 1}; - - std::vector result = ShapeUtils::evalPermutFromTo({a,b,c,d}, {c,d,a,b}); - - ASSERT_TRUE(std::equal(begin(expected), end(expected), begin(result))); + std::vector result = + ShapeUtils::evalPermutFromTo({a, b, c, d}, {c, d, a, b}); + ASSERT_TRUE(std::equal(begin(expected), end(expected), begin(result))); } ////////////////////////////////////////////////////////////////// TEST_F(ShapeUtilsTests, evalPermutFromTo_test2) { + int a = 1, b = 2, c = 3, d = 4; + std::vector expected = {0, 1, 3, 2}; - int a=1, b=2, c=3, d=4; - std::vector expected = {0, 1, 3, 2}; - - std::vector result = ShapeUtils::evalPermutFromTo({a,b,c,d}, {a,b,d,c}); - - ASSERT_TRUE(std::equal(begin(expected), end(expected), begin(result))); + std::vector result = + ShapeUtils::evalPermutFromTo({a, b, c, d}, {a, b, d, c}); + ASSERT_TRUE(std::equal(begin(expected), end(expected), begin(result))); } ////////////////////////////////////////////////////////////////// TEST_F(ShapeUtilsTests, evalPermutFromTo_test3) { + int a = 2, b = 2, c = 3, d = 2; + std::vector expected = {0, 1, 3, 2}; - int a=2, b=2, c=3, d=2; - std::vector expected = {0, 1, 3, 2}; - - std::vector result = ShapeUtils::evalPermutFromTo({a,b,c,d}, {a,b,d,c}); - - ASSERT_TRUE(std::equal(begin(expected), end(expected), begin(result))); + std::vector result = + ShapeUtils::evalPermutFromTo({a, b, c, d}, {a, b, d, c}); + ASSERT_TRUE(std::equal(begin(expected), end(expected), begin(result))); } ////////////////////////////////////////////////////////////////// TEST_F(ShapeUtilsTests, evalPermutFromTo_test4) { + int a = 2, b = 3, c = 4, d = 5; - int a=2, b=3, c=4, d=5; - - std::vector result = ShapeUtils::evalPermutFromTo({a,b,c,d}, {a,b,c,d}); - - ASSERT_TRUE(result.empty()); + std::vector result = + ShapeUtils::evalPermutFromTo({a, b, c, d}, {a, b, c, d}); + ASSERT_TRUE(result.empty()); } ////////////////////////////////////////////////////////////////// TEST_F(ShapeUtilsTests, evalPermutFromTo_test5) { + int a = 1, b = 2, c = 3, d = 4; - int a=1, b=2, c=3, d=4; - - // EXPECT_THROW(ShapeUtils::evalPermutFromTo({a,b,c,d}, {c,d,a,8}), const char*); - ASSERT_TRUE(1); + // EXPECT_THROW(ShapeUtils::evalPermutFromTo({a,b,c,d}, {c,d,a,8}), const + // char*); + ASSERT_TRUE(1); } ////////////////////////////////////////////////////////////////// TEST_F(ShapeUtilsTests, evalPermutFromTo_test6) { + int a = 1, b = 2, c = 3, d = 4; - int a=1, b=2, c=3, d=4; - - // EXPECT_THROW(ShapeUtils::evalPermutFromTo({a,b,c,d}, {a,b,c,d,d}), const char*); - ASSERT_TRUE(1); + // EXPECT_THROW(ShapeUtils::evalPermutFromTo({a,b,c,d}, {a,b,c,d,d}), const + // char*); + ASSERT_TRUE(1); } ////////////////////////////////////////////////////////////////// TEST_F(ShapeUtilsTests, isPermutNecessary_test1) { - - ASSERT_TRUE(ShapeUtils::isPermutNecessary({1,0,2,3})); + ASSERT_TRUE(ShapeUtils::isPermutNecessary({1, 0, 2, 3})); } ////////////////////////////////////////////////////////////////// TEST_F(ShapeUtilsTests, isPermutNecessary_test2) { - - ASSERT_TRUE(!ShapeUtils::isPermutNecessary({0,1,2,3})); + ASSERT_TRUE(!ShapeUtils::isPermutNecessary({0, 1, 2, 3})); } - - diff --git a/libnd4j/tests_cpu/layers_tests/SingleDimTests.cpp b/libnd4j/tests_cpu/layers_tests/SingleDimTests.cpp index cc13f3529df5..ff1034a25f58 100644 --- a/libnd4j/tests_cpu/layers_tests/SingleDimTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/SingleDimTests.cpp @@ -18,168 +18,152 @@ // @author raver119@gmail.com // -#include "testlayers.h" -#include #include -#include #include +#include +#include + +#include "testlayers.h" using namespace sd; using namespace sd::graph; class SingleDimTests : public testing::Test { -public: - + public: }; TEST_F(SingleDimTests, Test_Create_1) { - auto x = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); - ASSERT_EQ(5, x.lengthOf()); - ASSERT_EQ(1, x.rankOf()); - ASSERT_TRUE(x.isVector()); - ASSERT_TRUE(x.isRowVector()); - ASSERT_FALSE(x.isMatrix()); + auto x = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); + ASSERT_EQ(5, x.lengthOf()); + ASSERT_EQ(1, x.rankOf()); + ASSERT_TRUE(x.isVector()); + ASSERT_TRUE(x.isRowVector()); + ASSERT_FALSE(x.isMatrix()); } TEST_F(SingleDimTests, Test_Add_1) { - auto x = NDArrayFactory::create('c', {3}, {1, 2, 3}); - auto exp = NDArrayFactory::create('c', {3}, {2, 3, 4}); + auto x = NDArrayFactory::create('c', {3}, {1, 2, 3}); + auto exp = NDArrayFactory::create('c', {3}, {2, 3, 4}); - x += 1.0f; + x += 1.0f; - ASSERT_TRUE(exp.isSameShape(&x)); - ASSERT_TRUE(exp.equalsTo(&x)); + ASSERT_TRUE(exp.isSameShape(&x)); + ASSERT_TRUE(exp.equalsTo(&x)); } - TEST_F(SingleDimTests, Test_Pairwise_1) { - auto x = NDArrayFactory::create('c', {3}, {1, 2, 3}); - auto exp = NDArrayFactory::create('c', {3}, {2, 4, 6}); + auto x = NDArrayFactory::create('c', {3}, {1, 2, 3}); + auto exp = NDArrayFactory::create('c', {3}, {2, 4, 6}); - x += x; + x += x; - ASSERT_TRUE(exp.isSameShape(&x)); - ASSERT_TRUE(exp.equalsTo(&x)); + ASSERT_TRUE(exp.isSameShape(&x)); + ASSERT_TRUE(exp.equalsTo(&x)); } TEST_F(SingleDimTests, Test_Concat_1) { - auto x = NDArrayFactory::create('c', {3}, {1, 2, 3}); - auto y = NDArrayFactory::create('c', {3}, {4, 5, 6}); - auto exp = NDArrayFactory::create('c', {6}, {1, 2, 3, 4, 5, 6}); + auto x = NDArrayFactory::create('c', {3}, {1, 2, 3}); + auto y = NDArrayFactory::create('c', {3}, {4, 5, 6}); + auto exp = NDArrayFactory::create('c', {6}, {1, 2, 3, 4, 5, 6}); - sd::ops::concat op; - auto result = op.evaluate({&x, &y}, {}, {0}); + sd::ops::concat op; + auto result = op.evaluate({&x, &y}, {}, {0}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(SingleDimTests, Test_Reduce_1) { - auto x = NDArrayFactory::create('c', {3}, {1, 2, 3}); + auto x = NDArrayFactory::create('c', {3}, {1, 2, 3}); - float r = x.reduceNumber(reduce::Sum).e(0); + float r = x.reduceNumber(reduce::Sum).e(0); - ASSERT_NEAR(6.0f, r, 1e-5f); + ASSERT_NEAR(6.0f, r, 1e-5f); } TEST_F(SingleDimTests, Test_IndexReduce_1) { - auto x = NDArrayFactory::create('c', {3}, {1, 2, 3}); + auto x = NDArrayFactory::create('c', {3}, {1, 2, 3}); - auto r = x.indexReduceNumber(indexreduce::IndexMax).e(0); + auto r = x.indexReduceNumber(indexreduce::IndexMax).e(0); - ASSERT_NEAR(2, r, 1e-5f); + ASSERT_NEAR(2, r, 1e-5f); } - TEST_F(SingleDimTests, Test_ExpandDims_1) { - auto x = NDArrayFactory::create('c', {3}, {1, 2, 3}); - auto exp = NDArrayFactory::create('c', {1, 3}, {1, 2, 3}); + auto x = NDArrayFactory::create('c', {3}, {1, 2, 3}); + auto exp = NDArrayFactory::create('c', {1, 3}, {1, 2, 3}); - sd::ops::expand_dims op; - auto result = op.evaluate({&x}, {}, {0}); + sd::ops::expand_dims op; + auto result = op.evaluate({&x}, {}, {0}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(SingleDimTests, Test_ExpandDims_2) { - auto x = NDArrayFactory::create('c', {3}, {1, 2, 3}); - auto exp = NDArrayFactory::create('c', {3, 1}, {1, 2, 3}); + auto x = NDArrayFactory::create('c', {3}, {1, 2, 3}); + auto exp = NDArrayFactory::create('c', {3, 1}, {1, 2, 3}); - sd::ops::expand_dims op; - auto result = op.evaluate({&x}, {}, {1}); + sd::ops::expand_dims op; + auto result = op.evaluate({&x}, {}, {1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(SingleDimTests, Test_Squeeze_1) { - std::vector vecS({1}); - std::vector vecB({3.0f}); - auto x = NDArrayFactory::create('c', vecS, vecB); - auto exp = NDArrayFactory::create(3.0f); + std::vector vecS({1}); + std::vector vecB({3.0f}); + auto x = NDArrayFactory::create('c', vecS, vecB); + auto exp = NDArrayFactory::create(3.0f); - sd::ops::squeeze op; - auto result = op.evaluate({&x}, {}, {}); + sd::ops::squeeze op; + auto result = op.evaluate({&x}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - - ASSERT_EQ(exp.rankOf(), z->rankOf()); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_EQ(exp.rankOf(), z.rankOf()); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(SingleDimTests, Test_Squeeze_2) { - auto x = NDArrayFactory::create('c', {3}, {1, 2, 3}); - auto exp = NDArrayFactory::create('c', {3}, {1, 2, 3}); - - sd::ops::squeeze op; - auto result = op.evaluate({&x}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto x = NDArrayFactory::create('c', {3}, {1, 2, 3}); + auto exp = NDArrayFactory::create('c', {3}, {1, 2, 3}); - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::squeeze op; + auto result = op.evaluate({&x}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(SingleDimTests, Test_Permute_1) { - auto x = NDArrayFactory::create('c', {3}, {1, 2, 3}); - auto exp = NDArrayFactory::create('c', {3}, {1, 2, 3}); - - sd::ops::permute op; - auto result = op.evaluate({&x}, {}, {0}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); + auto x = NDArrayFactory::create('c', {3}, {1, 2, 3}); + auto exp = NDArrayFactory::create('c', {3}, {1, 2, 3}); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::permute op; + auto result = op.evaluate({&x}, {}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/SortCpuTests.cpp b/libnd4j/tests_cpu/layers_tests/SortCpuTests.cpp index a31547561a3e..9f43fdf7c22a 100644 --- a/libnd4j/tests_cpu/layers_tests/SortCpuTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/SortCpuTests.cpp @@ -18,87 +18,110 @@ // @author raver119@gmail.com // -#include "testlayers.h" -#include #include -#include #include +#include +#include + +#include "testlayers.h" using namespace sd; using namespace sd::graph; class SortCpuTests : public testing::Test { -public: - + public: }; - TEST_F(SortCpuTests, test_linear_sort_by_key_1) { - if (!Environment::getInstance().isCPU()) - return; + if (!Environment::getInstance().isCPU()) return; - auto k = NDArrayFactory::create('c', {10}, {1, 3, 5, 9, 0, 2, 4, 6, 7, 8}); - auto v = NDArrayFactory::create('c', {10}, {1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5}); + auto k = NDArrayFactory::create('c', {10}, + {1, 3, 5, 9, 0, 2, 4, 6, 7, 8}); + auto v = NDArrayFactory::create( + 'c', {10}, {1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5}); - auto ek = NDArrayFactory::create('c', {10}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); - auto ev = NDArrayFactory::create('c', {10}, {0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5}); + auto ek = NDArrayFactory::create('c', {10}, + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto ev = NDArrayFactory::create( + 'c', {10}, {0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5}); + sortByKey(nullptr, k.buffer(), k.shapeInfo(), k.specialBuffer(), + k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), + v.specialShapeInfo(), false); - sortByKey(nullptr, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), false); - - ASSERT_EQ(ek, k); - ASSERT_EQ(ev, v); + ASSERT_EQ(ek, k); + ASSERT_EQ(ev, v); } TEST_F(SortCpuTests, test_linear_sort_by_val_1) { - if (!Environment::getInstance().isCPU()) - return; - - auto k = NDArrayFactory::create('c', {10}, {1, 3, 5, 9, 0, 2, 4, 6, 7, 8}); - auto v = NDArrayFactory::create('c', {10}, {1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5}); + if (!Environment::getInstance().isCPU()) return; - auto ek = NDArrayFactory::create('c', {10}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); - auto ev = NDArrayFactory::create('c', {10}, {0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5}); + auto k = NDArrayFactory::create('c', {10}, + {1, 3, 5, 9, 0, 2, 4, 6, 7, 8}); + auto v = NDArrayFactory::create( + 'c', {10}, {1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5}); + auto ek = NDArrayFactory::create('c', {10}, + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto ev = NDArrayFactory::create( + 'c', {10}, {0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5}); - sortByValue(nullptr, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), false); + sortByValue(nullptr, k.buffer(), k.shapeInfo(), k.specialBuffer(), + k.specialShapeInfo(), v.buffer(), v.shapeInfo(), + v.specialBuffer(), v.specialShapeInfo(), false); - ASSERT_EQ(ek, k); - ASSERT_EQ(ev, v); + ASSERT_EQ(ek, k); + ASSERT_EQ(ev, v); } TEST_F(SortCpuTests, test_tad_sort_by_key_1) { - if (!Environment::getInstance().isCPU()) - return; - - auto k = NDArrayFactory::create('c', {2, 10}, {1, 3, 5, 9, 0, 2, 4, 6, 7, 8, 1, 3, 5, 9, 0, 2, 4, 6, 7, 8}); - auto v = NDArrayFactory::create('c', {2, 10}, {1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5, 1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5}); - - auto ek = NDArrayFactory::create('c', {2, 10}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); - auto ev = NDArrayFactory::create('c', {2, 10}, {0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5, 0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5}); - - - int axis = 1; - sortTadByKey(nullptr, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), &axis, 1, false); - - ASSERT_EQ(ek, k); - ASSERT_EQ(ev, v); + if (!Environment::getInstance().isCPU()) return; + + auto k = NDArrayFactory::create( + 'c', {2, 10}, + {1, 3, 5, 9, 0, 2, 4, 6, 7, 8, 1, 3, 5, 9, 0, 2, 4, 6, 7, 8}); + auto v = NDArrayFactory::create( + 'c', {2, 10}, {1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5, + 1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5}); + + auto ek = NDArrayFactory::create( + 'c', {2, 10}, + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto ev = NDArrayFactory::create( + 'c', {2, 10}, {0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5, + 0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5}); + + int axis = 1; + sortTadByKey(nullptr, k.buffer(), k.shapeInfo(), k.specialBuffer(), + k.specialShapeInfo(), v.buffer(), v.shapeInfo(), + v.specialBuffer(), v.specialShapeInfo(), &axis, 1, false); + + ASSERT_EQ(ek, k); + ASSERT_EQ(ev, v); } TEST_F(SortCpuTests, test_tad_sort_by_val_1) { - if (!Environment::getInstance().isCPU()) - return; - - auto k = NDArrayFactory::create('c', {2, 10}, {1, 3, 5, 9, 0, 2, 4, 6, 7, 8, 1, 3, 5, 9, 0, 2, 4, 6, 7, 8}); - auto v = NDArrayFactory::create('c', {2, 10}, {1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5, 1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5}); - - auto ek = NDArrayFactory::create('c', {2, 10}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); - auto ev = NDArrayFactory::create('c', {2, 10}, {0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5, 0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5}); - - - int axis = 1; - sortTadByValue(nullptr, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), &axis, 1, false); - - ASSERT_EQ(ek, k); - ASSERT_EQ(ev, v); + if (!Environment::getInstance().isCPU()) return; + + auto k = NDArrayFactory::create( + 'c', {2, 10}, + {1, 3, 5, 9, 0, 2, 4, 6, 7, 8, 1, 3, 5, 9, 0, 2, 4, 6, 7, 8}); + auto v = NDArrayFactory::create( + 'c', {2, 10}, {1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5, + 1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5}); + + auto ek = NDArrayFactory::create( + 'c', {2, 10}, + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto ev = NDArrayFactory::create( + 'c', {2, 10}, {0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5, + 0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5}); + + int axis = 1; + sortTadByValue(nullptr, k.buffer(), k.shapeInfo(), k.specialBuffer(), + k.specialShapeInfo(), v.buffer(), v.shapeInfo(), + v.specialBuffer(), v.specialShapeInfo(), &axis, 1, false); + + ASSERT_EQ(ek, k); + ASSERT_EQ(ev, v); } diff --git a/libnd4j/tests_cpu/layers_tests/SortCudaTests.cu b/libnd4j/tests_cpu/layers_tests/SortCudaTests.cu index 5a5e75b1b31e..129f5acede02 100644 --- a/libnd4j/tests_cpu/layers_tests/SortCudaTests.cu +++ b/libnd4j/tests_cpu/layers_tests/SortCudaTests.cu @@ -18,107 +18,148 @@ // @author raver119@gmail.com // -#include "testlayers.h" -#include #include -#include #include +#include +#include + +#include "testlayers.h" using namespace sd; using namespace sd::graph; class SortCudaTests : public testing::Test { -public: - + public: }; - TEST_F(SortCudaTests, test_linear_sort_by_key_1) { - auto k = NDArrayFactory::create('c', {10}, {1, 3, 5, 9, 0, 2, 4, 6, 7, 8}); - auto v = NDArrayFactory::create('c', {10}, {1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5}); - - auto ek = NDArrayFactory::create('c', {10}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); - auto ev = NDArrayFactory::create('c', {10}, {0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5}); - - Nd4jPointer extras[2] = {nullptr, LaunchContext::defaultContext()->getCudaStream()}; - - sortByKey(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), false); - k.tickWriteDevice(); - v.tickWriteDevice(); - - ASSERT_EQ(ek, k); - ASSERT_EQ(ev, v); + auto k = NDArrayFactory::create('c', {10}, + {1, 3, 5, 9, 0, 2, 4, 6, 7, 8}); + auto v = NDArrayFactory::create( + 'c', {10}, {1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5}); + + auto ek = NDArrayFactory::create('c', {10}, + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto ev = NDArrayFactory::create( + 'c', {10}, {0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5}); + + Nd4jPointer extras[2] = {nullptr, + LaunchContext::defaultContext()->getCudaStream()}; + + sortByKey(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), + k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), + v.specialShapeInfo(), false); + k.tickWriteDevice(); + v.tickWriteDevice(); + + ASSERT_EQ(ek, k); + ASSERT_EQ(ev, v); } TEST_F(SortCudaTests, test_linear_sort_by_val_1) { - auto k = NDArrayFactory::create('c', {10}, {1, 3, 5, 9, 0, 2, 4, 6, 7, 8}); - auto v = NDArrayFactory::create('c', {10}, {1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5}); - - auto ek = NDArrayFactory::create('c', {10}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); - auto ev = NDArrayFactory::create('c', {10}, {0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5}); - - Nd4jPointer extras[2] = {nullptr, LaunchContext::defaultContext()->getCudaStream()}; - - sortByValue(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), false); - k.tickWriteDevice(); - v.tickWriteDevice(); - - ASSERT_EQ(ek, k); - ASSERT_EQ(ev, v); + auto k = NDArrayFactory::create('c', {10}, + {1, 3, 5, 9, 0, 2, 4, 6, 7, 8}); + auto v = NDArrayFactory::create( + 'c', {10}, {1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5}); + + auto ek = NDArrayFactory::create('c', {10}, + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto ev = NDArrayFactory::create( + 'c', {10}, {0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5}); + + Nd4jPointer extras[2] = {nullptr, + LaunchContext::defaultContext()->getCudaStream()}; + + sortByValue(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), + k.specialShapeInfo(), v.buffer(), v.shapeInfo(), + v.specialBuffer(), v.specialShapeInfo(), false); + k.tickWriteDevice(); + v.tickWriteDevice(); + + ASSERT_EQ(ek, k); + ASSERT_EQ(ev, v); } TEST_F(SortCudaTests, test_linear_sort_by_val_2) { - auto k = NDArrayFactory::create('c', {6}, {0, 1, 2, 3, 4, 5}); -// auto v = NDArrayFactory::create('c', {6}, {1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5}); - NDArray v = NDArrayFactory::create('c', {6}, {0.9f, .75f, .6f, .95f, .5f, .3f}); - auto ek = NDArrayFactory::create('c', {6}, {3, 0, 1, 2, 4, 5}); - auto ev = NDArrayFactory::create('c', {6}, {0.95, 0.9, 0.75, 0.6, 0.5, 0.3}); - - Nd4jPointer extras[2] = {nullptr, LaunchContext::defaultContext()->getCudaStream()}; - - sortByValue(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), true); - k.tickWriteDevice(); - v.tickWriteDevice(); - // k.printIndexedBuffer("KEYS"); - ASSERT_EQ(ek, k); - ASSERT_EQ(ev, v); + auto k = NDArrayFactory::create('c', {6}, {0, 1, 2, 3, 4, 5}); + // auto v = NDArrayFactory::create('c', {6}, {1.5, 3.5, 5.5, 9.5, + // 0.5, 2.5, 4.5, 6.5, 7.5, 8.5}); + NDArray v = NDArrayFactory::create('c', {6}, + {0.9f, .75f, .6f, .95f, .5f, .3f}); + auto ek = NDArrayFactory::create('c', {6}, {3, 0, 1, 2, 4, 5}); + auto ev = NDArrayFactory::create('c', {6}, + {0.95, 0.9, 0.75, 0.6, 0.5, 0.3}); + + Nd4jPointer extras[2] = {nullptr, + LaunchContext::defaultContext()->getCudaStream()}; + + sortByValue(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), + k.specialShapeInfo(), v.buffer(), v.shapeInfo(), + v.specialBuffer(), v.specialShapeInfo(), true); + k.tickWriteDevice(); + v.tickWriteDevice(); + // k.printIndexedBuffer("KEYS"); + ASSERT_EQ(ek, k); + ASSERT_EQ(ev, v); } TEST_F(SortCudaTests, test_tad_sort_by_key_1) { - auto k = NDArrayFactory::create('c', {2, 10}, {1, 3, 5, 9, 0, 2, 4, 6, 7, 8, 1, 3, 5, 9, 0, 2, 4, 6, 7, 8}); - auto v = NDArrayFactory::create('c', {2, 10}, {1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5, 1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5}); - - auto ek = NDArrayFactory::create('c', {2, 10}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); - auto ev = NDArrayFactory::create('c', {2, 10}, {0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5, 0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5}); - - Nd4jPointer extras[2] = {nullptr, LaunchContext::defaultContext()->getCudaStream()}; - - int axis = 1; - sortTadByKey(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), &axis, 1, false); - k.tickWriteDevice(); - v.tickWriteDevice(); - - // k.printIndexedBuffer("k"); - // v.printIndexedBuffer("v"); - - ASSERT_EQ(ek, k); - ASSERT_EQ(ev, v); + auto k = NDArrayFactory::create( + 'c', {2, 10}, + {1, 3, 5, 9, 0, 2, 4, 6, 7, 8, 1, 3, 5, 9, 0, 2, 4, 6, 7, 8}); + auto v = NDArrayFactory::create( + 'c', {2, 10}, {1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5, + 1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5}); + + auto ek = NDArrayFactory::create( + 'c', {2, 10}, + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto ev = NDArrayFactory::create( + 'c', {2, 10}, {0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5, + 0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5}); + + Nd4jPointer extras[2] = {nullptr, + LaunchContext::defaultContext()->getCudaStream()}; + + int axis = 1; + sortTadByKey(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), + k.specialShapeInfo(), v.buffer(), v.shapeInfo(), + v.specialBuffer(), v.specialShapeInfo(), &axis, 1, false); + k.tickWriteDevice(); + v.tickWriteDevice(); + + // k.printIndexedBuffer("k"); + // v.printIndexedBuffer("v"); + + ASSERT_EQ(ek, k); + ASSERT_EQ(ev, v); } TEST_F(SortCudaTests, test_tad_sort_by_val_1) { - auto k = NDArrayFactory::create('c', {2, 10}, {1, 3, 5, 9, 0, 2, 4, 6, 7, 8, 1, 3, 5, 9, 0, 2, 4, 6, 7, 8}); - auto v = NDArrayFactory::create('c', {2, 10}, {1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5, 1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5}); - - auto ek = NDArrayFactory::create('c', {2, 10}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); - auto ev = NDArrayFactory::create('c', {2, 10}, {0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5, 0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5}); - - Nd4jPointer extras[2] = {nullptr, LaunchContext::defaultContext()->getCudaStream()}; - - int axis = 1; - sortTadByValue(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), &axis, 1, false); - k.tickWriteDevice(); - v.tickWriteDevice(); - - ASSERT_EQ(ek, k); - ASSERT_EQ(ev, v); + auto k = NDArrayFactory::create( + 'c', {2, 10}, + {1, 3, 5, 9, 0, 2, 4, 6, 7, 8, 1, 3, 5, 9, 0, 2, 4, 6, 7, 8}); + auto v = NDArrayFactory::create( + 'c', {2, 10}, {1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5, + 1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5}); + + auto ek = NDArrayFactory::create( + 'c', {2, 10}, + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto ev = NDArrayFactory::create( + 'c', {2, 10}, {0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5, + 0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5}); + + Nd4jPointer extras[2] = {nullptr, + LaunchContext::defaultContext()->getCudaStream()}; + + int axis = 1; + sortTadByValue(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), + k.specialShapeInfo(), v.buffer(), v.shapeInfo(), + v.specialBuffer(), v.specialShapeInfo(), &axis, 1, false); + k.tickWriteDevice(); + v.tickWriteDevice(); + + ASSERT_EQ(ek, k); + ASSERT_EQ(ev, v); } diff --git a/libnd4j/tests_cpu/layers_tests/SparseUtilsTest.cpp b/libnd4j/tests_cpu/layers_tests/SparseUtilsTest.cpp index 37f52568f42f..505b5557bffd 100644 --- a/libnd4j/tests_cpu/layers_tests/SparseUtilsTest.cpp +++ b/libnd4j/tests_cpu/layers_tests/SparseUtilsTest.cpp @@ -18,130 +18,66 @@ // Created by raver119 on 04.08.17. // -#include "testlayers.h" -#include #include + +#include + #include "ops/specials_sparse.h" +#include "testlayers.h" using namespace sd; ////////////////////////////////////////////////////////////////////// class SparseUtilsTest : public testing::Test { -public: - static const Nd4jLong nnz = 40; - static const int rank = 3; + public: + static const Nd4jLong nnz = 40; + static const int rank = 3; }; - ////////////////////////////////////////////////////////////////////// TEST_F(SparseUtilsTest, SortCOOindices_Test) { - - #ifndef __CUDABLAS__ - - Nd4jLong * indicesArr = new Nd4jLong[nnz * rank]{ - 0,2,7, - 2,36,35, - 3,30,17, - 5,12,22, - 5,43,45, - 6,32,11, - 8,8,32, - 9,29,11, - 5,11,22, - 15,26,16, - 17,48,49, - 24,28,31, - 26,6,23, - 31,21,31, - 35,46,45, - 37,13,14, - 6,38,18, - 7,28,20, - 8,29,39, - 8,32,30, - 9,42,43, - 11,15,18, - 13,18,45, - 29,26,39, - 30,8,25, - 42,31,24, - 28,33,5, - 31,27,1, - 35,43,26, - 36,8,37, - 39,22,14, - 39,24,42, - 42,48,2, - 43,26,48, - 44,23,49, - 45,18,34, - 46,28,5, - 46,32,17, - 48,34,44, - 49,38,39, - }; - - Nd4jLong * expIndicesArr = new Nd4jLong[nnz * rank]{ - 0, 2, 7, - 2, 36, 35, - 3, 30, 17, - 5, 11, 22, - 5, 12, 22, - 5, 43, 45, - 6, 32, 11, - 6, 38, 18, - 7, 28, 20, - 8, 8, 32, - 8, 29, 39, - 8, 32, 30, - 9, 29, 11, - 9, 42, 43, - 11, 15, 18, - 13, 18, 45, - 15, 26, 16, - 17, 48, 49, - 24, 28, 31, - 26, 6, 23, - 28, 33, 5, - 29, 26, 39, - 30, 8, 25, - 31, 21, 31, - 31, 27, 1, - 35, 43, 26, - 35, 46, 45, - 36, 8, 37, - 37, 13, 14, - 39, 22, 14, - 39, 24, 42, - 42, 31, 24, - 42, 48, 2, - 43, 26, 48, - 44, 23, 49, - 45, 18, 34, - 46, 28, 5, - 46, 32, 17, - 48, 34, 44, - 49, 38, 39, - }; - - auto values = NDArrayFactory::create('c', {40}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, - 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39}); - - auto expValues = NDArrayFactory::create('c', {40}, {0, 1, 2, 8, 3, 4, 5, 16, 17, 6, 18, 19, 7, 20, 21, 22, 9, - 10, 11, 12, 26, 23, 24, 13, 27, 28, 14, 29, 15, 30, 31, 25, 32, 33, - 34, 35, 36, 37, 38, 39 - }); - - sd::sparse::SparseUtils::sortCooIndicesGeneric(indicesArr, reinterpret_cast(values.buffer()), nnz, rank); - - for ( int i = 0; i < rank * nnz; ++i){ - ASSERT_EQ(expIndicesArr[i], indicesArr[i]); - } - - ASSERT_TRUE(expValues.equalsTo(values)); - - - delete[] indicesArr; - delete[] expIndicesArr; - - #endif +#ifndef __CUDABLAS__ + + Nd4jLong* indicesArr = new Nd4jLong[nnz * rank]{ + 0, 2, 7, 2, 36, 35, 3, 30, 17, 5, 12, 22, 5, 43, 45, 6, 32, 11, + 8, 8, 32, 9, 29, 11, 5, 11, 22, 15, 26, 16, 17, 48, 49, 24, 28, 31, + 26, 6, 23, 31, 21, 31, 35, 46, 45, 37, 13, 14, 6, 38, 18, 7, 28, 20, + 8, 29, 39, 8, 32, 30, 9, 42, 43, 11, 15, 18, 13, 18, 45, 29, 26, 39, + 30, 8, 25, 42, 31, 24, 28, 33, 5, 31, 27, 1, 35, 43, 26, 36, 8, 37, + 39, 22, 14, 39, 24, 42, 42, 48, 2, 43, 26, 48, 44, 23, 49, 45, 18, 34, + 46, 28, 5, 46, 32, 17, 48, 34, 44, 49, 38, 39, + }; + + Nd4jLong* expIndicesArr = new Nd4jLong[nnz * rank]{ + 0, 2, 7, 2, 36, 35, 3, 30, 17, 5, 11, 22, 5, 12, 22, 5, 43, 45, + 6, 32, 11, 6, 38, 18, 7, 28, 20, 8, 8, 32, 8, 29, 39, 8, 32, 30, + 9, 29, 11, 9, 42, 43, 11, 15, 18, 13, 18, 45, 15, 26, 16, 17, 48, 49, + 24, 28, 31, 26, 6, 23, 28, 33, 5, 29, 26, 39, 30, 8, 25, 31, 21, 31, + 31, 27, 1, 35, 43, 26, 35, 46, 45, 36, 8, 37, 37, 13, 14, 39, 22, 14, + 39, 24, 42, 42, 31, 24, 42, 48, 2, 43, 26, 48, 44, 23, 49, 45, 18, 34, + 46, 28, 5, 46, 32, 17, 48, 34, 44, 49, 38, 39, + }; + + auto values = NDArrayFactory::create( + 'c', {40}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, + 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, + 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39}); + + auto expValues = NDArrayFactory::create( + 'c', {40}, {0, 1, 2, 8, 3, 4, 5, 16, 17, 6, 18, 19, 7, 20, + 21, 22, 9, 10, 11, 12, 26, 23, 24, 13, 27, 28, 14, 29, + 15, 30, 31, 25, 32, 33, 34, 35, 36, 37, 38, 39}); + + sd::sparse::SparseUtils::sortCooIndicesGeneric( + indicesArr, reinterpret_cast(values.buffer()), nnz, rank); + + for (int i = 0; i < rank * nnz; ++i) { + ASSERT_EQ(expIndicesArr[i], indicesArr[i]); + } + + ASSERT_TRUE(expValues.equalsTo(values)); + + delete[] indicesArr; + delete[] expIndicesArr; + +#endif } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/StashTests.cpp b/libnd4j/tests_cpu/layers_tests/StashTests.cpp index 2cba6682dd11..5c2b8651016c 100644 --- a/libnd4j/tests_cpu/layers_tests/StashTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/StashTests.cpp @@ -22,67 +22,65 @@ #define LIBND4J_STASHTESTS_H #include -#include "testlayers.h" #include +#include "testlayers.h" + using namespace sd; using namespace sd::graph; class StashTests : public testing::Test { -public: - + public: }; TEST_F(StashTests, BasicTests_1) { - Stash stash; + Stash stash; - auto alpha = NDArrayFactory::create_('c',{5, 5}); - alpha->assign(1.0); + auto alpha = NDArrayFactory::create_('c', {5, 5}); + alpha->assign(1.0); - auto beta = NDArrayFactory::create_('c',{5, 5}); - beta->assign(2.0); + auto beta = NDArrayFactory::create_('c', {5, 5}); + beta->assign(2.0); - auto cappa = NDArrayFactory::create_('c',{5, 5}); - cappa->assign(3.0); + auto cappa = NDArrayFactory::create_('c', {5, 5}); + cappa->assign(3.0); - stash.storeArray(1, "alpha", alpha); - stash.storeArray(2, "alpha", beta); - stash.storeArray(3, "cappa", cappa); + stash.storeArray(1, "alpha", alpha); + stash.storeArray(2, "alpha", beta); + stash.storeArray(3, "cappa", cappa); - ASSERT_TRUE(stash.checkStash(1, "alpha")); - ASSERT_TRUE(stash.checkStash(2, "alpha")); - ASSERT_TRUE(stash.checkStash(3, "cappa")); + ASSERT_TRUE(stash.checkStash(1, "alpha")); + ASSERT_TRUE(stash.checkStash(2, "alpha")); + ASSERT_TRUE(stash.checkStash(3, "cappa")); - ASSERT_FALSE(stash.checkStash(3, "alpha")); - ASSERT_FALSE(stash.checkStash(2, "beta")); - ASSERT_FALSE(stash.checkStash(1, "cappa")); + ASSERT_FALSE(stash.checkStash(3, "alpha")); + ASSERT_FALSE(stash.checkStash(2, "beta")); + ASSERT_FALSE(stash.checkStash(1, "cappa")); } - TEST_F(StashTests, BasicTests_2) { - Stash stash; - - auto alpha = NDArrayFactory::create_('c',{5, 5}); - alpha->assign(1.0); + Stash stash; - auto beta = NDArrayFactory::create_('c',{5, 5}); - beta->assign(2.0); + auto alpha = NDArrayFactory::create_('c', {5, 5}); + alpha->assign(1.0); - auto cappa = NDArrayFactory::create_('c',{5, 5}); - cappa->assign(3.0); + auto beta = NDArrayFactory::create_('c', {5, 5}); + beta->assign(2.0); - stash.storeArray(1, "alpha", alpha); - stash.storeArray(1, "beta", beta); - stash.storeArray(1, "cappa", cappa); + auto cappa = NDArrayFactory::create_('c', {5, 5}); + cappa->assign(3.0); - ASSERT_FALSE(stash.checkStash(2, "alpha")); - ASSERT_FALSE(stash.checkStash(2, "beta")); - ASSERT_FALSE(stash.checkStash(2, "cappa")); + stash.storeArray(1, "alpha", alpha); + stash.storeArray(1, "beta", beta); + stash.storeArray(1, "cappa", cappa); - ASSERT_TRUE(alpha == stash.extractArray(1, "alpha")); - ASSERT_TRUE(beta == stash.extractArray(1, "beta")); - ASSERT_TRUE(cappa == stash.extractArray(1, "cappa")); + ASSERT_FALSE(stash.checkStash(2, "alpha")); + ASSERT_FALSE(stash.checkStash(2, "beta")); + ASSERT_FALSE(stash.checkStash(2, "cappa")); + ASSERT_TRUE(alpha == stash.extractArray(1, "alpha")); + ASSERT_TRUE(beta == stash.extractArray(1, "beta")); + ASSERT_TRUE(cappa == stash.extractArray(1, "cappa")); } -#endif //LIBND4J_STASHTESTS_H +#endif // LIBND4J_STASHTESTS_H diff --git a/libnd4j/tests_cpu/layers_tests/StringTests.cpp b/libnd4j/tests_cpu/layers_tests/StringTests.cpp index 41352246ebf1..f2d2de852e19 100644 --- a/libnd4j/tests_cpu/layers_tests/StringTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/StringTests.cpp @@ -20,850 +20,861 @@ // @author Oleg Semeniv // - #include #include -#include "testlayers.h" #include #include #include +#include "testlayers.h" + using namespace sd; class StringTests : public testing::Test { -public: - + public: }; ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Basic_Test_1) { - std::string f("alpha"); - auto array = NDArrayFactory::string(f); - ASSERT_EQ(sd::DataType::UTF8, array.dataType()); + std::string f("alpha"); + auto array = NDArrayFactory::string(f); + ASSERT_EQ(sd::DataType::UTF8, array.dataType()); - ASSERT_EQ(1, array.lengthOf()); - ASSERT_EQ(0, array.rankOf()); + ASSERT_EQ(1, array.lengthOf()); + ASSERT_EQ(0, array.rankOf()); - auto z = array.e(0); + auto z = array.e(0); - ASSERT_EQ(f, z); + ASSERT_EQ(f, z); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Basic_Test_2) { - std::string f("alpha"); - auto array = NDArrayFactory::string(f.c_str()); - ASSERT_EQ(sd::DataType::UTF8, array.dataType()); + std::string f("alpha"); + auto array = NDArrayFactory::string(f.c_str()); + ASSERT_EQ(sd::DataType::UTF8, array.dataType()); - ASSERT_EQ(1, array.lengthOf()); - ASSERT_EQ(0, array.rankOf()); + ASSERT_EQ(1, array.lengthOf()); + ASSERT_EQ(0, array.rankOf()); - auto z = array.e(0); + auto z = array.e(0); - ASSERT_EQ(f, z); + ASSERT_EQ(f, z); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Basic_Test_3) { + auto array = NDArrayFactory::string( + {3, 2}, {"alpha", "beta", "gamma", "phi", "theta", "omega"}); - auto array = NDArrayFactory::string({3, 2}, {"alpha", "beta", "gamma", "phi", "theta", "omega"}); - - ASSERT_EQ(6, array.lengthOf()); - ASSERT_EQ(2, array.rankOf()); + ASSERT_EQ(6, array.lengthOf()); + ASSERT_EQ(2, array.rankOf()); - array.printIndexedBuffer("String array"); + array.printIndexedBuffer("String array"); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Basic_Test_4) { + NDArray array({3, 2}, + std::vector{U"alpha", U"beta", U"gamma€한", + U"pÿqwe", U"ß水𝄋", U"omega"}); + ASSERT_EQ(6, array.lengthOf()); + ASSERT_EQ(2, array.rankOf()); - NDArray array( { 3, 2 }, std::vector{ U"alpha", U"beta", U"gamma€한", U"pÿqwe", U"ß水𝄋", U"omega" }); - ASSERT_EQ(6, array.lengthOf()); - ASSERT_EQ(2, array.rankOf()); - - array.printIndexedBuffer("String array"); + array.printIndexedBuffer("String array"); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Basic_Test_5) { + NDArray array({3, 2}, + std::vector{u"alpha", u"beta", u"gamma€한", + u"pÿqwe", u"ß水𝄋", u"omega"}); + ASSERT_EQ(6, array.lengthOf()); + ASSERT_EQ(2, array.rankOf()); - NDArray array( { 3, 2 }, std::vector{ u"alpha", u"beta", u"gamma€한", u"pÿqwe", u"ß水𝄋", u"omega" }); - ASSERT_EQ(6, array.lengthOf()); - ASSERT_EQ(2, array.rankOf()); - - array.printIndexedBuffer("String array"); + array.printIndexedBuffer("String array"); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Basic_Test_6) { + NDArray array({3, 2}, std::vector{"alpha", "beta", "gamma€한", + "pÿqwe", "ß水𝄋", "omega"}); + ASSERT_EQ(6, array.lengthOf()); + ASSERT_EQ(2, array.rankOf()); - NDArray array( { 3, 2 }, std::vector{ "alpha", "beta", "gamma€한", "pÿqwe", "ß水𝄋", "omega" }); - ASSERT_EQ(6, array.lengthOf()); - ASSERT_EQ(2, array.rankOf()); - - array.printIndexedBuffer("String array"); + array.printIndexedBuffer("String array"); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Basic_Test_7) { + NDArray array({3, 2}, + std::vector{U"alpha", U"beta", U"gamma€한", + U"pÿqwe", U"ß水𝄋", U"omega"}); + ASSERT_EQ(6, array.lengthOf()); + ASSERT_EQ(2, array.rankOf()); - NDArray array( { 3, 2 }, std::vector{ U"alpha", U"beta", U"gamma€한", U"pÿqwe", U"ß水𝄋", U"omega" }); - ASSERT_EQ(6, array.lengthOf()); - ASSERT_EQ(2, array.rankOf()); - - array.printIndexedBuffer("String array"); + array.printIndexedBuffer("String array"); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Basic_Test_8) { + NDArray array({3, 2}, + std::vector{u"alpha", u"beta", u"gamma€한", + u"pÿqwe", u"ß水𝄋", u"omega"}); + ASSERT_EQ(6, array.lengthOf()); + ASSERT_EQ(2, array.rankOf()); - NDArray array( { 3, 2 }, std::vector{ u"alpha", u"beta", u"gamma€한", u"pÿqwe", u"ß水𝄋", u"omega" }); - ASSERT_EQ(6, array.lengthOf()); - ASSERT_EQ(2, array.rankOf()); - - array.printIndexedBuffer("String array"); + array.printIndexedBuffer("String array"); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Basic_Test_9) { + NDArray array({3, 2}, std::vector{"alpha", "beta", "gamma€한", + "pÿqwe", "ß水𝄋", "omega"}); + ASSERT_EQ(6, array.lengthOf()); + ASSERT_EQ(2, array.rankOf()); - NDArray array( { 3, 2 }, std::vector{ "alpha", "beta", "gamma€한", "pÿqwe", "ß水𝄋", "omega" }); - ASSERT_EQ(6, array.lengthOf()); - ASSERT_EQ(2, array.rankOf()); - - array.printIndexedBuffer("String array"); + array.printIndexedBuffer("String array"); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Basic_Test_10) { - - NDArray array(std::u32string(U"gamma€한")); - ASSERT_EQ(1, array.lengthOf()); - ASSERT_EQ(0, array.rankOf()); - array.printIndexedBuffer("String array"); + NDArray array(std::u32string(U"gamma€한")); + ASSERT_EQ(1, array.lengthOf()); + ASSERT_EQ(0, array.rankOf()); + array.printIndexedBuffer("String array"); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Basic_Test_11) { + NDArray array(U"gamma€한"); + ASSERT_EQ(1, array.lengthOf()); + ASSERT_EQ(0, array.rankOf()); - NDArray array(U"gamma€한"); - ASSERT_EQ(1, array.lengthOf()); - ASSERT_EQ(0, array.rankOf()); - - array.printIndexedBuffer("String array"); + array.printIndexedBuffer("String array"); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Basic_Test_12) { - - NDArray array(std::u16string(u"gamma€한")); - ASSERT_EQ(1, array.lengthOf()); - ASSERT_EQ(0, array.rankOf()); - array.printIndexedBuffer("String array"); + NDArray array(std::u16string(u"gamma€한")); + ASSERT_EQ(1, array.lengthOf()); + ASSERT_EQ(0, array.rankOf()); + array.printIndexedBuffer("String array"); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Basic_Test_13) { + NDArray array(u"gamma€한"); + ASSERT_EQ(1, array.lengthOf()); + ASSERT_EQ(0, array.rankOf()); - NDArray array(u"gamma€한"); - ASSERT_EQ(1, array.lengthOf()); - ASSERT_EQ(0, array.rankOf()); - - array.printIndexedBuffer("String array"); + array.printIndexedBuffer("String array"); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Basic_Test_14) { - - NDArray array(std::string("gamma€한")); - ASSERT_EQ(1, array.lengthOf()); - ASSERT_EQ(0, array.rankOf()); - array.printIndexedBuffer("String array"); + NDArray array(std::string("gamma€한")); + ASSERT_EQ(1, array.lengthOf()); + ASSERT_EQ(0, array.rankOf()); + array.printIndexedBuffer("String array"); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Basic_Test_15) { + NDArray array("gamma€한"); + ASSERT_EQ(1, array.lengthOf()); + ASSERT_EQ(0, array.rankOf()); - NDArray array("gamma€한"); - ASSERT_EQ(1, array.lengthOf()); - ASSERT_EQ(0, array.rankOf()); - - array.printIndexedBuffer("String array"); + array.printIndexedBuffer("String array"); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Basic_Test_16) { + auto array = NDArrayFactory::string( + {3, 2}, std::vector{"alpha", "beta", "gamma", "phi", "theta", + "omega"}); - auto array = NDArrayFactory::string( { 3, 2 }, std::vector{ "alpha", "beta", "gamma", "phi", "theta", "omega" }); - - ASSERT_EQ(6, array.lengthOf()); - ASSERT_EQ(2, array.rankOf()); + ASSERT_EQ(6, array.lengthOf()); + ASSERT_EQ(2, array.rankOf()); - array.printIndexedBuffer("String array"); + array.printIndexedBuffer("String array"); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Basic_Test_17) { + auto array = NDArrayFactory::string( + {3, 2}, std::vector{"alpha", "beta", "gamma", "phi", "theta", + "omega"}); - auto array = NDArrayFactory::string({ 3, 2 }, std::vector{ "alpha", "beta", "gamma", "phi", "theta", "omega" }); - - ASSERT_EQ(6, array.lengthOf()); - ASSERT_EQ(2, array.rankOf()); + ASSERT_EQ(6, array.lengthOf()); + ASSERT_EQ(2, array.rankOf()); - array.printIndexedBuffer("String array"); + array.printIndexedBuffer("String array"); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Basic_Test_18) { + auto array = NDArrayFactory::string( + {3, 2}, std::vector{u"alpha", u"beta", u"gamma", u"phi", + u"theta", u"omega"}); - auto array = NDArrayFactory::string({ 3, 2 }, std::vector{ u"alpha", u"beta", u"gamma", u"phi", u"theta", u"omega" }); + ASSERT_EQ(6, array.lengthOf()); + ASSERT_EQ(2, array.rankOf()); - ASSERT_EQ(6, array.lengthOf()); - ASSERT_EQ(2, array.rankOf()); - - array.printIndexedBuffer("String array"); + array.printIndexedBuffer("String array"); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Basic_Test_19) { + auto array = NDArrayFactory::string( + {3, 2}, std::vector{u"alpha", u"beta", u"gamma", u"phi", + u"theta", u"omega"}); - auto array = NDArrayFactory::string( { 3, 2 }, std::vector{ u"alpha", u"beta", u"gamma", u"phi", u"theta", u"omega" }); - - ASSERT_EQ(6, array.lengthOf()); - ASSERT_EQ(2, array.rankOf()); + ASSERT_EQ(6, array.lengthOf()); + ASSERT_EQ(2, array.rankOf()); - array.printIndexedBuffer("String array"); + array.printIndexedBuffer("String array"); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Basic_Test_20) { + auto array = NDArrayFactory::string( + {3, 2}, std::vector{U"alpha", U"beta", U"gamma", U"phi", + U"theta", U"omega"}); - auto array = NDArrayFactory::string( { 3, 2 }, std::vector{ U"alpha", U"beta", U"gamma", U"phi", U"theta", U"omega" }); + ASSERT_EQ(6, array.lengthOf()); + ASSERT_EQ(2, array.rankOf()); - ASSERT_EQ(6, array.lengthOf()); - ASSERT_EQ(2, array.rankOf()); - - array.printIndexedBuffer("String array"); + array.printIndexedBuffer("String array"); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Basic_Test_21) { + auto array = NDArrayFactory::string( + {3, 2}, std::vector{U"alpha", U"òèçùà12345¤z", + U"ß水𝄋ÿ€한𐍈®кею90ощъ]ї", U"phi", + U"theta", U"omega"}); - auto array = NDArrayFactory::string( { 3, 2 }, std::vector{ U"alpha", U"òèçùà12345¤z", U"ß水𝄋ÿ€한𐍈®кею90ощъ]ї", U"phi", U"theta", U"omega" }); - - ASSERT_EQ(6, array.lengthOf()); - ASSERT_EQ(2, array.rankOf()); + ASSERT_EQ(6, array.lengthOf()); + ASSERT_EQ(2, array.rankOf()); - array.printIndexedBuffer("String array"); + array.printIndexedBuffer("String array"); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Basic_Test_22) { - std::u16string f(u"ß水𝄋ÿ€한𐍈®кею90ощъ]ї"); - auto array = NDArrayFactory::string(f.c_str()); - ASSERT_EQ(sd::DataType::UTF16, array.dataType()); + std::u16string f(u"ß水𝄋ÿ€한𐍈®кею90ощъ]ї"); + auto array = NDArrayFactory::string(f.c_str()); + ASSERT_EQ(sd::DataType::UTF16, array.dataType()); - ASSERT_EQ(1, array.lengthOf()); - ASSERT_EQ(0, array.rankOf()); + ASSERT_EQ(1, array.lengthOf()); + ASSERT_EQ(0, array.rankOf()); - auto z = array.e(0); + auto z = array.e(0); - ASSERT_EQ(f, z); + ASSERT_EQ(f, z); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Basic_Test_23) { - std::u32string f(U"ß水𝄋ÿ€한𐍈®кею90ощъ]ї"); - auto array = NDArrayFactory::string(f.c_str()); - ASSERT_EQ(sd::DataType::UTF32, array.dataType()); + std::u32string f(U"ß水𝄋ÿ€한𐍈®кею90ощъ]ї"); + auto array = NDArrayFactory::string(f.c_str()); + ASSERT_EQ(sd::DataType::UTF32, array.dataType()); - ASSERT_EQ(1, array.lengthOf()); - ASSERT_EQ(0, array.rankOf()); + ASSERT_EQ(1, array.lengthOf()); + ASSERT_EQ(0, array.rankOf()); - auto z = array.e(0); + auto z = array.e(0); - ASSERT_EQ(f, z); + ASSERT_EQ(f, z); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Export_Test_1) { - auto array = NDArrayFactory::string( {3}, {"alpha", "beta", "gamma"}); - auto vector = array.asByteVector(); + auto array = NDArrayFactory::string({3}, {"alpha", "beta", "gamma"}); + auto vector = array.asByteVector(); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Basic_dup_1) { - std::string f("alpha"); - auto array = NDArrayFactory::string(f); - ASSERT_EQ(sd::DataType::UTF8, array.dataType()); + std::string f("alpha"); + auto array = NDArrayFactory::string(f); + ASSERT_EQ(sd::DataType::UTF8, array.dataType()); - ASSERT_EQ(1, array.lengthOf()); - ASSERT_EQ(0, array.rankOf()); + ASSERT_EQ(1, array.lengthOf()); + ASSERT_EQ(0, array.rankOf()); - auto dup = new NDArray(array.dup()); + auto dup = new NDArray(array.dup()); - auto z0 = array.e(0); - auto z1 = dup->e(0); + auto z0 = array.e(0); + auto z1 = dup->e(0); - ASSERT_EQ(f, z0); - ASSERT_EQ(f, z1); + ASSERT_EQ(f, z0); + ASSERT_EQ(f, z1); - delete dup; + delete dup; } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, byte_length_test_1) { - std::string f("alpha"); - auto array = NDArrayFactory::string(f); + std::string f("alpha"); + auto array = NDArrayFactory::string(f); - ASSERT_EQ(f.length(), StringUtils::byteLength(array)); + ASSERT_EQ(f.length(), StringUtils::byteLength(array)); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, byte_length_test_2) { - auto array = NDArrayFactory::string( {2}, {"alpha", "beta"}); + auto array = NDArrayFactory::string({2}, {"alpha", "beta"}); - ASSERT_EQ(9, StringUtils::byteLength(array)); + ASSERT_EQ(9, StringUtils::byteLength(array)); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, test_split_1) { - auto split = StringUtils::split("alpha beta gamma", " "); + auto split = StringUtils::split("alpha beta gamma", " "); - ASSERT_EQ(3, split.size()); - ASSERT_EQ(std::string("alpha"), split[0]); - ASSERT_EQ(std::string("beta"), split[1]); - ASSERT_EQ(std::string("gamma"), split[2]); + ASSERT_EQ(3, split.size()); + ASSERT_EQ(std::string("alpha"), split[0]); + ASSERT_EQ(std::string("beta"), split[1]); + ASSERT_EQ(std::string("gamma"), split[2]); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, test_unicode_utf8_utf16) { + std::string utf8 = u8"\nòèçùà12345¤zß水𝄋ÿ€한𐍈®кею90ощъ]їїщkk1q\n\t\rop~"; + std::u16string utf16Exp = + u"\nòèçùà12345¤zß水𝄋ÿ€한𐍈®кею90ощъ]їїщkk1q\n\t\rop~"; - std::string utf8 = u8"\nòèçùà12345¤zß水𝄋ÿ€한𐍈®кею90ощъ]їїщkk1q\n\t\rop~"; - std::u16string utf16Exp = u"\nòèçùà12345¤zß水𝄋ÿ€한𐍈®кею90ощъ]їїщkk1q\n\t\rop~"; - - std::u16string utf16Res; - ASSERT_TRUE(StringUtils::u8StringToU16String(utf8, utf16Res)); + std::u16string utf16Res; + ASSERT_TRUE(StringUtils::u8StringToU16String(utf8, utf16Res)); - ASSERT_EQ(utf16Res.size(), utf16Exp.size()); - for (auto i = 0; i < utf16Exp.size(); i++) { - ASSERT_EQ(utf16Exp[i], utf16Res[i]); - } + ASSERT_EQ(utf16Res.size(), utf16Exp.size()); + for (auto i = 0; i < utf16Exp.size(); i++) { + ASSERT_EQ(utf16Exp[i], utf16Res[i]); + } } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, test_unicode_utf8_utf32) { + std::string utf8 = u8"\nòèçùà12345¤zß水𝄋ÿ€한𐍈®кею90ощъ]їїщkk1q\n\t\rop~"; + std::u32string utf32Exp = + U"\nòèçùà12345¤zß水𝄋ÿ€한𐍈®кею90ощъ]їїщkk1q\n\t\rop~"; - std::string utf8 = u8"\nòèçùà12345¤zß水𝄋ÿ€한𐍈®кею90ощъ]їїщkk1q\n\t\rop~"; - std::u32string utf32Exp = U"\nòèçùà12345¤zß水𝄋ÿ€한𐍈®кею90ощъ]їїщkk1q\n\t\rop~"; + std::u32string utf32Res; + ASSERT_TRUE(StringUtils::u8StringToU32String(utf8, utf32Res)); - std::u32string utf32Res; - ASSERT_TRUE(StringUtils::u8StringToU32String(utf8, utf32Res)); - - ASSERT_EQ(utf32Res.size(), utf32Exp.size()); - for (auto i = 0; i < utf32Exp.size(); i++) { - ASSERT_EQ(utf32Exp[i], utf32Res[i]); - } + ASSERT_EQ(utf32Res.size(), utf32Exp.size()); + for (auto i = 0; i < utf32Exp.size(); i++) { + ASSERT_EQ(utf32Exp[i], utf32Res[i]); + } } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, test_unicode_utf16_utf8) { + std::string utf8Exp = u8"\nòèçùà12345¤zß水𝄋ÿ€한𐍈®кею90ощъ]їїщkk1q\n\t\rop~"; + std::u16string utf16 = u"\nòèçùà12345¤zß水𝄋ÿ€한𐍈®кею90ощъ]їїщkk1q\n\t\rop~"; - std::string utf8Exp = u8"\nòèçùà12345¤zß水𝄋ÿ€한𐍈®кею90ощъ]їїщkk1q\n\t\rop~"; - std::u16string utf16 = u"\nòèçùà12345¤zß水𝄋ÿ€한𐍈®кею90ощъ]їїщkk1q\n\t\rop~"; - - std::string utf8Res; - ASSERT_TRUE(StringUtils::u16StringToU8String(utf16, utf8Res)); + std::string utf8Res; + ASSERT_TRUE(StringUtils::u16StringToU8String(utf16, utf8Res)); - ASSERT_EQ(utf8Res.size(), utf8Exp.size()); - for (auto i = 0; i < utf8Exp.size(); i++) { - ASSERT_EQ(utf8Exp[i], utf8Res[i]); - } + ASSERT_EQ(utf8Res.size(), utf8Exp.size()); + for (auto i = 0; i < utf8Exp.size(); i++) { + ASSERT_EQ(utf8Exp[i], utf8Res[i]); + } } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, test_unicode_utf32_utf8) { + std::string utf8Exp = u8"\nòèçùà12345¤zß水𝄋ÿ€한𐍈®кею 90ощъ]їїщkk1q\n\t\rop~"; + std::u32string utf32 = U"\nòèçùà12345¤zß水𝄋ÿ€한𐍈®кею 90ощъ]їїщkk1q\n\t\rop~"; - std::string utf8Exp = u8"\nòèçùà12345¤zß水𝄋ÿ€한𐍈®кею 90ощъ]їїщkk1q\n\t\rop~"; - std::u32string utf32 = U"\nòèçùà12345¤zß水𝄋ÿ€한𐍈®кею 90ощъ]їїщkk1q\n\t\rop~"; - - std::string utf8Res; - ASSERT_TRUE(StringUtils::u32StringToU8String(utf32, utf8Res)); + std::string utf8Res; + ASSERT_TRUE(StringUtils::u32StringToU8String(utf32, utf8Res)); - ASSERT_EQ(utf8Res.size(), utf8Exp.size()); - for (auto i = 0; i < utf8Exp.size(); i++) { - ASSERT_EQ(utf8Exp[i], utf8Res[i]); - } + ASSERT_EQ(utf8Res.size(), utf8Exp.size()); + for (auto i = 0; i < utf8Exp.size(); i++) { + ASSERT_EQ(utf8Exp[i], utf8Res[i]); + } } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, test_unicode_utf16_utf32) { + std::u32string utf32Exp = + U"\nòèçùà12345¤zß水𝄋ÿ€한𐍈®кею90ощъ]їїщkk1q\n\t\rop~"; + std::u16string utf16 = u"\nòèçùà12345¤zß水𝄋ÿ€한𐍈®кею90ощъ]їїщkk1q\n\t\rop~"; - std::u32string utf32Exp = U"\nòèçùà12345¤zß水𝄋ÿ€한𐍈®кею90ощъ]їїщkk1q\n\t\rop~"; - std::u16string utf16 = u"\nòèçùà12345¤zß水𝄋ÿ€한𐍈®кею90ощъ]їїщkk1q\n\t\rop~"; + std::u32string utf32Res; + ASSERT_TRUE(StringUtils::u16StringToU32String(utf16, utf32Res)); - std::u32string utf32Res; - ASSERT_TRUE(StringUtils::u16StringToU32String(utf16, utf32Res)); - - ASSERT_EQ(utf32Res.size(), utf32Exp.size()); - for (auto i = 0; i < utf32Exp.size(); i++) { - ASSERT_EQ(utf32Exp[i], utf32Res[i]); - } + ASSERT_EQ(utf32Res.size(), utf32Exp.size()); + for (auto i = 0; i < utf32Exp.size(); i++) { + ASSERT_EQ(utf32Exp[i], utf32Res[i]); + } } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, test_unicode_utf32_utf16) { + std::u16string utf16Exp = + u"\nòèçùà12345¤zß水𝄋ÿ€한𐍈®кею90ощъ]їїщkk1q\n\t\rop~"; + std::u32string utf32 = U"\nòèçùà12345¤zß水𝄋ÿ€한𐍈®кею90ощъ]їїщkk1q\n\t\rop~"; - std::u16string utf16Exp = u"\nòèçùà12345¤zß水𝄋ÿ€한𐍈®кею90ощъ]їїщkk1q\n\t\rop~"; - std::u32string utf32 = U"\nòèçùà12345¤zß水𝄋ÿ€한𐍈®кею90ощъ]їїщkk1q\n\t\rop~"; - - std::u16string utf16Res; - ASSERT_TRUE(StringUtils::u32StringToU16String(utf32, utf16Res)); + std::u16string utf16Res; + ASSERT_TRUE(StringUtils::u32StringToU16String(utf32, utf16Res)); - ASSERT_EQ(utf16Res.size(), utf16Exp.size()); - for (auto i = 0; i < utf16Exp.size(); i++) { - ASSERT_EQ(utf16Exp[i], utf16Res[i]); - } + ASSERT_EQ(utf16Res.size(), utf16Exp.size()); + for (auto i = 0; i < utf16Exp.size(); i++) { + ASSERT_EQ(utf16Exp[i], utf16Res[i]); + } } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, byte_length_test_Default) { - - std::string f("alpha"); - auto array = NDArrayFactory::string(f); + std::string f("alpha"); + auto array = NDArrayFactory::string(f); + + ASSERT_EQ(f.length(), StringUtils::byteLength(array)); - ASSERT_EQ(f.length(), StringUtils::byteLength(array)); + std::u16string f16(u"alpha"); + auto array16 = NDArrayFactory::string(f16); - std::u16string f16(u"alpha"); - auto array16 = NDArrayFactory::string(f16); - - ASSERT_EQ(sizeof(char16_t)*f16.length(), StringUtils::byteLength(array16)); + ASSERT_EQ(sizeof(char16_t) * f16.length(), StringUtils::byteLength(array16)); - std::u32string f32(U"alpha"); - auto array32 = NDArrayFactory::string(f32); + std::u32string f32(U"alpha"); + auto array32 = NDArrayFactory::string(f32); - ASSERT_EQ(sizeof(char32_t) * f32.length(), StringUtils::byteLength(array32)); + ASSERT_EQ(sizeof(char32_t) * f32.length(), StringUtils::byteLength(array32)); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, byte_length_test_UTF16) { - std::string f(u8"alpha"); - auto array = NDArrayFactory::string(f, sd::DataType::UTF16); + std::string f(u8"alpha"); + auto array = NDArrayFactory::string(f, sd::DataType::UTF16); - ASSERT_EQ(sizeof(char16_t) * f.length(), StringUtils::byteLength(array)); + ASSERT_EQ(sizeof(char16_t) * f.length(), StringUtils::byteLength(array)); - std::u16string f16(u"alpha"); - auto array16 = NDArrayFactory::string(f16, sd::DataType::UTF16); + std::u16string f16(u"alpha"); + auto array16 = NDArrayFactory::string(f16, sd::DataType::UTF16); - ASSERT_EQ(sizeof(char16_t) * f16.length(), StringUtils::byteLength(array16)); + ASSERT_EQ(sizeof(char16_t) * f16.length(), StringUtils::byteLength(array16)); - std::u32string f32(U"alpha"); - auto array32 = NDArrayFactory::string(f32, sd::DataType::UTF16); + std::u32string f32(U"alpha"); + auto array32 = NDArrayFactory::string(f32, sd::DataType::UTF16); - ASSERT_EQ(sizeof(char16_t) * f32.length(), StringUtils::byteLength(array32)); + ASSERT_EQ(sizeof(char16_t) * f32.length(), StringUtils::byteLength(array32)); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Basic_Test_UTF16toU8) { + std::u16string f16(u"alpha水𝄋ÿ€한𐍈®кею"); + auto array = NDArrayFactory::string(f16, sd::DataType::UTF8); + ASSERT_EQ(sd::DataType::UTF8, array.dataType()); - std::u16string f16(u"alpha水𝄋ÿ€한𐍈®кею"); - auto array = NDArrayFactory::string(f16, sd::DataType::UTF8); - ASSERT_EQ(sd::DataType::UTF8, array.dataType()); + ASSERT_EQ(1, array.lengthOf()); + ASSERT_EQ(0, array.rankOf()); - ASSERT_EQ(1, array.lengthOf()); - ASSERT_EQ(0, array.rankOf()); + auto z = array.e(0); - auto z = array.e(0); - - std::string f(u8"alpha水𝄋ÿ€한𐍈®кею"); - ASSERT_EQ(f, z); + std::string f(u8"alpha水𝄋ÿ€한𐍈®кею"); + ASSERT_EQ(f, z); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Basic_Test_UTF32toU8) { - std::u32string f32(U"alpha水𝄋ÿ€한𐍈®кею"); - auto array = NDArrayFactory::string(f32.c_str(), sd::DataType::UTF8); - ASSERT_EQ(sd::DataType::UTF8, array.dataType()); + std::u32string f32(U"alpha水𝄋ÿ€한𐍈®кею"); + auto array = NDArrayFactory::string(f32.c_str(), sd::DataType::UTF8); + ASSERT_EQ(sd::DataType::UTF8, array.dataType()); - ASSERT_EQ(1, array.lengthOf()); - ASSERT_EQ(0, array.rankOf()); + ASSERT_EQ(1, array.lengthOf()); + ASSERT_EQ(0, array.rankOf()); - auto z = array.e(0); - std::string f(u8"alpha水𝄋ÿ€한𐍈®кею"); - ASSERT_EQ(f, z); + auto z = array.e(0); + std::string f(u8"alpha水𝄋ÿ€한𐍈®кею"); + ASSERT_EQ(f, z); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Basic_Test_UTF16toU16) { + std::u16string f16(u"€alpha水𝄋ÿ€한𐍈®кею"); + auto array = NDArrayFactory::string(f16, sd::DataType::UTF16); + ASSERT_EQ(sd::DataType::UTF16, array.dataType()); - std::u16string f16(u"€alpha水𝄋ÿ€한𐍈®кею"); - auto array = NDArrayFactory::string(f16, sd::DataType::UTF16); - ASSERT_EQ(sd::DataType::UTF16, array.dataType()); + ASSERT_EQ(1, array.lengthOf()); + ASSERT_EQ(0, array.rankOf()); + auto z = array.e(0); - ASSERT_EQ(1, array.lengthOf()); - ASSERT_EQ(0, array.rankOf()); - auto z = array.e(0); - - ASSERT_EQ(z, f16); + ASSERT_EQ(z, f16); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Basic_Test_UTF32toU16) { + std::u32string f32(U"€alpha水𝄋ÿ€한𐍈®кею"); + auto array = NDArrayFactory::string(f32, sd::DataType::UTF16); + ASSERT_EQ(sd::DataType::UTF16, array.dataType()); - std::u32string f32(U"€alpha水𝄋ÿ€한𐍈®кею"); - auto array = NDArrayFactory::string(f32, sd::DataType::UTF16); - ASSERT_EQ(sd::DataType::UTF16, array.dataType()); - - ASSERT_EQ(1, array.lengthOf()); - ASSERT_EQ(0, array.rankOf()); - auto z = array.e(0); - std::u16string f16(u"€alpha水𝄋ÿ€한𐍈®кею"); - ASSERT_EQ(z, f16); + ASSERT_EQ(1, array.lengthOf()); + ASSERT_EQ(0, array.rankOf()); + auto z = array.e(0); + std::u16string f16(u"€alpha水𝄋ÿ€한𐍈®кею"); + ASSERT_EQ(z, f16); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Basic_Test_UTF16toU32) { + std::u16string f16(u"€alpha水𝄋ÿ€한𐍈®кею"); + auto array = NDArrayFactory::string(f16, sd::DataType::UTF32); + ASSERT_EQ(sd::DataType::UTF32, array.dataType()); - std::u16string f16(u"€alpha水𝄋ÿ€한𐍈®кею"); - auto array = NDArrayFactory::string(f16, sd::DataType::UTF32); - ASSERT_EQ(sd::DataType::UTF32, array.dataType()); + ASSERT_EQ(1, array.lengthOf()); + ASSERT_EQ(0, array.rankOf()); - ASSERT_EQ(1, array.lengthOf()); - ASSERT_EQ(0, array.rankOf()); - - auto z = array.e(0); - std::u32string fres(U"€alpha水𝄋ÿ€한𐍈®кею"); - ASSERT_EQ(z, fres); + auto z = array.e(0); + std::u32string fres(U"€alpha水𝄋ÿ€한𐍈®кею"); + ASSERT_EQ(z, fres); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Basic_Test_UTF32toU32) { + std::u32string f32(U"€alpha水𝄋ÿ€한𐍈®кею"); + auto array = NDArrayFactory::string(f32); + ASSERT_EQ(sd::DataType::UTF32, array.dataType()); - std::u32string f32(U"€alpha水𝄋ÿ€한𐍈®кею"); - auto array = NDArrayFactory::string(f32); - ASSERT_EQ(sd::DataType::UTF32, array.dataType()); - - ASSERT_EQ(1, array.lengthOf()); - ASSERT_EQ(0, array.rankOf()); - auto z = array.e(0); - ASSERT_EQ(f32, z); + ASSERT_EQ(1, array.lengthOf()); + ASSERT_EQ(0, array.rankOf()); + auto z = array.e(0); + ASSERT_EQ(f32, z); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Basic_Test_UTF8toU32) { + std::string f(u8"€alpha水𝄋ÿ€한𐍈®кею"); + auto array = NDArrayFactory::string(f, sd::DataType::UTF32); + ASSERT_EQ(sd::DataType::UTF32, array.dataType()); - std::string f(u8"€alpha水𝄋ÿ€한𐍈®кею"); - auto array = NDArrayFactory::string(f, sd::DataType::UTF32); - ASSERT_EQ(sd::DataType::UTF32, array.dataType()); - - ASSERT_EQ(1, array.lengthOf()); - ASSERT_EQ(0, array.rankOf()); - std::u32string f32(U"€alpha水𝄋ÿ€한𐍈®кею"); - auto z = array.e(0); - ASSERT_EQ(f32, z); + ASSERT_EQ(1, array.lengthOf()); + ASSERT_EQ(0, array.rankOf()); + std::u32string f32(U"€alpha水𝄋ÿ€한𐍈®кею"); + auto z = array.e(0); + ASSERT_EQ(f32, z); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Basic_Test_StringVecU8toUTF16) { - auto array = NDArrayFactory::string({ 3, 2 }, { "alpha€", "beta", "gamma水", "phi", "theta", "omega水" }, sd::DataType::UTF16); + auto array = NDArrayFactory::string( + {3, 2}, {"alpha€", "beta", "gamma水", "phi", "theta", "omega水"}, + sd::DataType::UTF16); - ASSERT_EQ(6, array.lengthOf()); - ASSERT_EQ(2, array.rankOf()); + ASSERT_EQ(6, array.lengthOf()); + ASSERT_EQ(2, array.rankOf()); - array.printIndexedBuffer("String array"); + array.printIndexedBuffer("String array"); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Basic_Test_StringVecU8toUTF32) { - auto array = NDArrayFactory::string( { 3, 2 }, { "alpha€", "beta水", "gamma", "phi", "theta", "omega" }, sd::DataType::UTF32); + auto array = NDArrayFactory::string( + {3, 2}, {"alpha€", "beta水", "gamma", "phi", "theta", "omega"}, + sd::DataType::UTF32); - ASSERT_EQ(6, array.lengthOf()); - ASSERT_EQ(2, array.rankOf()); + ASSERT_EQ(6, array.lengthOf()); + ASSERT_EQ(2, array.rankOf()); - array.printIndexedBuffer("String array"); + array.printIndexedBuffer("String array"); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Export_Test_U8toUTF16) { - auto array = NDArrayFactory::string({ 3 }, { "alpha", "beta", "gamma" }, sd::DataType::UTF16); + auto array = NDArrayFactory::string({3}, {"alpha", "beta", "gamma"}, + sd::DataType::UTF16); - auto vector = array.asByteVector(); + auto vector = array.asByteVector(); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Export_Test_U8toUTF32) { - auto array = NDArrayFactory::string({ 3 }, { "alpha", "beta", "gamma" }, sd::DataType::UTF32); + auto array = NDArrayFactory::string({3}, {"alpha", "beta", "gamma"}, + sd::DataType::UTF32); - auto vector = array.asByteVector(); + auto vector = array.asByteVector(); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Basic_Test_StringVecU16toUTF16) { - auto array = NDArrayFactory::string({ 3, 2 }, { u"alpha水", u"beta", u"gamma", u"phi", u"theta水", u"omega" }, sd::DataType::UTF16); + auto array = NDArrayFactory::string( + {3, 2}, {u"alpha水", u"beta", u"gamma", u"phi", u"theta水", u"omega"}, + sd::DataType::UTF16); - ASSERT_EQ(6, array.lengthOf()); - ASSERT_EQ(2, array.rankOf()); + ASSERT_EQ(6, array.lengthOf()); + ASSERT_EQ(2, array.rankOf()); - array.printIndexedBuffer("String array"); + array.printIndexedBuffer("String array"); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Basic_Test_StringVecU16toUTF32) { - auto array = NDArrayFactory::string( { 3, 2 }, { u"alpha水", u"beta", u"gamma水", u"phi", u"theta", u"omega" }, sd::DataType::UTF32); + auto array = NDArrayFactory::string( + {3, 2}, {u"alpha水", u"beta", u"gamma水", u"phi", u"theta", u"omega"}, + sd::DataType::UTF32); - ASSERT_EQ(6, array.lengthOf()); - ASSERT_EQ(2, array.rankOf()); + ASSERT_EQ(6, array.lengthOf()); + ASSERT_EQ(2, array.rankOf()); - array.printIndexedBuffer("String array"); + array.printIndexedBuffer("String array"); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Basic_Test_StringVecU16toUTF8) { - auto array = NDArrayFactory::string( { 3, 2 }, { u"alpha€", u"beta水", u"gamma", u"phi水", u"theta", u"omega" }, sd::DataType::UTF8); + auto array = NDArrayFactory::string( + {3, 2}, {u"alpha€", u"beta水", u"gamma", u"phi水", u"theta", u"omega"}, + sd::DataType::UTF8); - ASSERT_EQ(6, array.lengthOf()); - ASSERT_EQ(2, array.rankOf()); + ASSERT_EQ(6, array.lengthOf()); + ASSERT_EQ(2, array.rankOf()); - array.printIndexedBuffer("String array"); + array.printIndexedBuffer("String array"); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Export_Test_U16toUTF8) { - auto array = NDArrayFactory::string( { 3 }, { u"alpha", u"beta", u"gamma" }, sd::DataType::UTF8); + auto array = NDArrayFactory::string({3}, {u"alpha", u"beta", u"gamma"}, + sd::DataType::UTF8); - auto vector = array.asByteVector(); + auto vector = array.asByteVector(); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Export_Test_U16toUTF16) { - auto array = NDArrayFactory::string( { 3 }, { u"alpha", u"beta", u"gamma" }, sd::DataType::UTF16); + auto array = NDArrayFactory::string({3}, {u"alpha", u"beta", u"gamma"}, + sd::DataType::UTF16); - auto vector = array.asByteVector(); + auto vector = array.asByteVector(); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Export_Test_U16toUTF32) { - auto array = NDArrayFactory::string( { 3 }, { u"alpha水", u"beta", u"gamma水" }, sd::DataType::UTF32); + auto array = NDArrayFactory::string({3}, {u"alpha水", u"beta", u"gamma水"}, + sd::DataType::UTF32); - auto vector = array.asByteVector(); + auto vector = array.asByteVector(); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Basic_Test_StringVecU32toUTF32) { - auto array = NDArrayFactory::string( { 3, 2 }, { U"alpha€", U"beta水", U"gamma", U"phi", U"theta", U"omega水" }, sd::DataType::UTF32); + auto array = NDArrayFactory::string( + {3, 2}, {U"alpha€", U"beta水", U"gamma", U"phi", U"theta", U"omega水"}, + sd::DataType::UTF32); - ASSERT_EQ(6, array.lengthOf()); - ASSERT_EQ(2, array.rankOf()); + ASSERT_EQ(6, array.lengthOf()); + ASSERT_EQ(2, array.rankOf()); - array.printIndexedBuffer("String array"); + array.printIndexedBuffer("String array"); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Basic_Test_StringVecU32toUTF16) { - auto array = NDArrayFactory::string({ 3, 2 }, { U"alpha水", U"水beta", U"gamma", U"phi水", U"theta", U"omega" }, sd::DataType::UTF16); + auto array = NDArrayFactory::string( + {3, 2}, {U"alpha水", U"水beta", U"gamma", U"phi水", U"theta", U"omega"}, + sd::DataType::UTF16); - ASSERT_EQ(6, array.lengthOf()); - ASSERT_EQ(2, array.rankOf()); + ASSERT_EQ(6, array.lengthOf()); + ASSERT_EQ(2, array.rankOf()); - array.printIndexedBuffer("String array"); + array.printIndexedBuffer("String array"); - printf("Array elements size: \n"); - for (int e = 0; e < array.lengthOf(); e++) { - printf("Element %d size: %d\n", e, static_cast(array.e(e).size())); - } + printf("Array elements size: \n"); + for (int e = 0; e < array.lengthOf(); e++) { + printf("Element %d size: %d\n", e, + static_cast(array.e(e).size())); + } } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Basic_Test_StringVecU32toUTF8) { - auto array = NDArrayFactory::string( { 3, 2 }, { U"alpha水", U"beta", U"gamma水", U"phi", U"theta", U"omega" }, sd::DataType::UTF8); + auto array = NDArrayFactory::string( + {3, 2}, {U"alpha水", U"beta", U"gamma水", U"phi", U"theta", U"omega"}, + sd::DataType::UTF8); - ASSERT_EQ(6, array.lengthOf()); - ASSERT_EQ(2, array.rankOf()); + ASSERT_EQ(6, array.lengthOf()); + ASSERT_EQ(2, array.rankOf()); - array.printIndexedBuffer("String array"); + array.printIndexedBuffer("String array"); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Export_Test_U32toUTF32) { - auto array = NDArrayFactory::string( { 3 }, { U"alpha", U"beta", U"gamma" }, sd::DataType::UTF32); + auto array = NDArrayFactory::string({3}, {U"alpha", U"beta", U"gamma"}, + sd::DataType::UTF32); - auto vector = array.asByteVector(); + auto vector = array.asByteVector(); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Export_Test_U32toUTF16) { - auto array = NDArrayFactory::string( { 3 }, { U"alpha", U"beta水", U"gamma水" }, sd::DataType::UTF16); + auto array = NDArrayFactory::string({3}, {U"alpha", U"beta水", U"gamma水"}, + sd::DataType::UTF16); - auto vector = array.asByteVector(); + auto vector = array.asByteVector(); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Export_Test_U32toUTF8) { - auto array = NDArrayFactory::string( { 3 }, { U"alpha", U"beta", U"gamma水" }, sd::DataType::UTF8); + auto array = NDArrayFactory::string({3}, {U"alpha", U"beta", U"gamma水"}, + sd::DataType::UTF8); - auto vector = array.asByteVector(); + auto vector = array.asByteVector(); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Basic_dup_UTF16) { - std::u16string f(u"€alpha水𝄋ÿ€한𐍈®кею"); - auto array = NDArrayFactory::string(f); - ASSERT_EQ(sd::DataType::UTF16, array.dataType()); + std::u16string f(u"€alpha水𝄋ÿ€한𐍈®кею"); + auto array = NDArrayFactory::string(f); + ASSERT_EQ(sd::DataType::UTF16, array.dataType()); - ASSERT_EQ(1, array.lengthOf()); - ASSERT_EQ(0, array.rankOf()); + ASSERT_EQ(1, array.lengthOf()); + ASSERT_EQ(0, array.rankOf()); - auto dup = new NDArray(array.dup()); + auto dup = new NDArray(array.dup()); - auto z0 = array.e(0); - auto z1 = dup->e(0); + auto z0 = array.e(0); + auto z1 = dup->e(0); - ASSERT_EQ(f, z0); - ASSERT_EQ(f, z1); + ASSERT_EQ(f, z0); + ASSERT_EQ(f, z1); - delete dup; + delete dup; } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Basic_dup_UTF32) { - std::u32string f(U"€alpha水𝄋ÿ€한𐍈®кею"); - auto array = NDArrayFactory::string(f); - ASSERT_EQ(sd::DataType::UTF32, array.dataType()); + std::u32string f(U"€alpha水𝄋ÿ€한𐍈®кею"); + auto array = NDArrayFactory::string(f); + ASSERT_EQ(sd::DataType::UTF32, array.dataType()); - ASSERT_EQ(1, array.lengthOf()); - ASSERT_EQ(0, array.rankOf()); + ASSERT_EQ(1, array.lengthOf()); + ASSERT_EQ(0, array.rankOf()); - auto dup = new NDArray(array.dup()); + auto dup = new NDArray(array.dup()); - auto z0 = array.e(0); - auto z1 = dup->e(0); + auto z0 = array.e(0); + auto z1 = dup->e(0); - ASSERT_EQ(f, z0); - ASSERT_EQ(f, z1); + ASSERT_EQ(f, z0); + ASSERT_EQ(f, z1); - delete dup; + delete dup; } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Basic_cast_UTF32toUTF8) { - - std::u32string u32(U"€alpha水𝄋ÿ€한𐍈®кею"); - - std::string u8(u8"€alpha水𝄋ÿ€한𐍈®кею"); - - auto array = NDArrayFactory::string(u32); - ASSERT_EQ(sd::DataType::UTF32, array.dataType()); + std::u32string u32(U"€alpha水𝄋ÿ€한𐍈®кею"); + + std::string u8(u8"€alpha水𝄋ÿ€한𐍈®кею"); + + auto array = NDArrayFactory::string(u32); + ASSERT_EQ(sd::DataType::UTF32, array.dataType()); - ASSERT_EQ(1, array.lengthOf()); - ASSERT_EQ(0, array.rankOf()); + ASSERT_EQ(1, array.lengthOf()); + ASSERT_EQ(0, array.rankOf()); - auto aCast = array.cast(sd::DataType::UTF8); + auto aCast = array.cast(sd::DataType::UTF8); - auto z0 = array.e(0); - auto z1 = aCast.e(0); + auto z0 = array.e(0); + auto z1 = aCast.e(0); - ASSERT_EQ(u32, z0); - ASSERT_EQ(u8, z1); + ASSERT_EQ(u32, z0); + ASSERT_EQ(u8, z1); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Basic_cast_UTF32toUTF16) { + std::u32string u32(U"€alpha水𝄋ÿ€한𐍈®кею"); - std::u32string u32(U"€alpha水𝄋ÿ€한𐍈®кею"); + std::u16string u16(u"€alpha水𝄋ÿ€한𐍈®кею"); - std::u16string u16(u"€alpha水𝄋ÿ€한𐍈®кею"); + auto array = NDArrayFactory::string(u32); + ASSERT_EQ(sd::DataType::UTF32, array.dataType()); - auto array = NDArrayFactory::string(u32); - ASSERT_EQ(sd::DataType::UTF32, array.dataType()); + ASSERT_EQ(1, array.lengthOf()); + ASSERT_EQ(0, array.rankOf()); - ASSERT_EQ(1, array.lengthOf()); - ASSERT_EQ(0, array.rankOf()); - - auto aCast = array.cast(sd::DataType::UTF16); + auto aCast = array.cast(sd::DataType::UTF16); - auto z0 = array.e(0); - auto z1 = aCast.e(0); + auto z0 = array.e(0); + auto z1 = aCast.e(0); - ASSERT_EQ(u32, z0); - ASSERT_EQ(u16, z1); + ASSERT_EQ(u32, z0); + ASSERT_EQ(u16, z1); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Basic_cast_UTF32toUTF32) { + std::u32string u32(U"€alpha水𝄋ÿ€한𐍈®кею"); - std::u32string u32(U"€alpha水𝄋ÿ€한𐍈®кею"); + auto array = NDArrayFactory::string(u32); + ASSERT_EQ(sd::DataType::UTF32, array.dataType()); - auto array = NDArrayFactory::string(u32); - ASSERT_EQ(sd::DataType::UTF32, array.dataType()); + ASSERT_EQ(1, array.lengthOf()); + ASSERT_EQ(0, array.rankOf()); - ASSERT_EQ(1, array.lengthOf()); - ASSERT_EQ(0, array.rankOf()); + auto aCast = array.cast(sd::DataType::UTF32); - auto aCast = array.cast(sd::DataType::UTF32); + auto z0 = array.e(0); + auto z1 = aCast.e(0); - auto z0 = array.e(0); - auto z1 = aCast.e(0); - - ASSERT_EQ(u32, z0); - ASSERT_EQ(u32, z1); + ASSERT_EQ(u32, z0); + ASSERT_EQ(u32, z1); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Basic_cast_UTF16toUTF16) { + std::u16string u16(u"€alpha水𝄋ÿ€한𐍈®кею"); - std::u16string u16(u"€alpha水𝄋ÿ€한𐍈®кею"); - - auto array = NDArrayFactory::string(u16); - ASSERT_EQ(sd::DataType::UTF16, array.dataType()); + auto array = NDArrayFactory::string(u16); + ASSERT_EQ(sd::DataType::UTF16, array.dataType()); - ASSERT_EQ(1, array.lengthOf()); - ASSERT_EQ(0, array.rankOf()); + ASSERT_EQ(1, array.lengthOf()); + ASSERT_EQ(0, array.rankOf()); - auto aCast = array.cast(sd::DataType::UTF16); + auto aCast = array.cast(sd::DataType::UTF16); - auto z0 = array.e(0); - auto z1 = aCast.e(0); + auto z0 = array.e(0); + auto z1 = aCast.e(0); - ASSERT_EQ(u16, z0); - ASSERT_EQ(u16, z1); + ASSERT_EQ(u16, z0); + ASSERT_EQ(u16, z1); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Basic_cast_UTF16toUTF32) { + std::u32string u32(U"€alpha水𝄋ÿ€한𐍈®кею"); - std::u32string u32(U"€alpha水𝄋ÿ€한𐍈®кею"); + std::u16string u16(u"€alpha水𝄋ÿ€한𐍈®кею"); - std::u16string u16(u"€alpha水𝄋ÿ€한𐍈®кею"); + auto array = NDArrayFactory::string(u16); + ASSERT_EQ(sd::DataType::UTF16, array.dataType()); - auto array = NDArrayFactory::string(u16); - ASSERT_EQ(sd::DataType::UTF16, array.dataType()); + ASSERT_EQ(1, array.lengthOf()); + ASSERT_EQ(0, array.rankOf()); - ASSERT_EQ(1, array.lengthOf()); - ASSERT_EQ(0, array.rankOf()); + auto aCast = array.cast(sd::DataType::UTF32); - auto aCast = array.cast(sd::DataType::UTF32); + auto z0 = array.e(0); + auto z1 = aCast.e(0); - auto z0 = array.e(0); - auto z1 = aCast.e(0); - - ASSERT_EQ(u32, z1); - ASSERT_EQ(u16, z0); + ASSERT_EQ(u32, z1); + ASSERT_EQ(u16, z0); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Basic_cast_UTF16toUTF8) { + std::string u8(u8"€alpha水𝄋ÿ€한𐍈®кею"); - std::string u8(u8"€alpha水𝄋ÿ€한𐍈®кею"); - - std::u16string u16(u"€alpha水𝄋ÿ€한𐍈®кею"); + std::u16string u16(u"€alpha水𝄋ÿ€한𐍈®кею"); - auto array = NDArrayFactory::string(u16); - ASSERT_EQ(sd::DataType::UTF16, array.dataType()); + auto array = NDArrayFactory::string(u16); + ASSERT_EQ(sd::DataType::UTF16, array.dataType()); - ASSERT_EQ(1, array.lengthOf()); - ASSERT_EQ(0, array.rankOf()); + ASSERT_EQ(1, array.lengthOf()); + ASSERT_EQ(0, array.rankOf()); - auto aCast = array.cast(sd::DataType::UTF8); + auto aCast = array.cast(sd::DataType::UTF8); - auto z0 = array.e(0); - auto z1 = aCast.e(0); + auto z0 = array.e(0); + auto z1 = aCast.e(0); - ASSERT_EQ(u8, z1); - ASSERT_EQ(u16, z0); + ASSERT_EQ(u8, z1); + ASSERT_EQ(u16, z0); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Basic_cast_UTF8toUTF8) { + std::string u8("€alpha水𝄋ÿ€한𐍈®кею"); - std::string u8("€alpha水𝄋ÿ€한𐍈®кею"); - - auto array = NDArrayFactory::string(u8); - ASSERT_EQ(sd::DataType::UTF8, array.dataType()); + auto array = NDArrayFactory::string(u8); + ASSERT_EQ(sd::DataType::UTF8, array.dataType()); - ASSERT_EQ(1, array.lengthOf()); - ASSERT_EQ(0, array.rankOf()); + ASSERT_EQ(1, array.lengthOf()); + ASSERT_EQ(0, array.rankOf()); - auto aCast = array.cast(sd::DataType::UTF8); + auto aCast = array.cast(sd::DataType::UTF8); - auto z0 = array.e(0); - auto z1 = aCast.e(0); + auto z0 = array.e(0); + auto z1 = aCast.e(0); - ASSERT_EQ(u8, z1); - ASSERT_EQ(u8, z0); + ASSERT_EQ(u8, z1); + ASSERT_EQ(u8, z0); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Basic_cast_UTF8toUTF16) { + std::string u8(u8"€alpha水𝄋ÿ€한𐍈®кею"); - std::string u8(u8"€alpha水𝄋ÿ€한𐍈®кею"); + std::u16string u16(u"€alpha水𝄋ÿ€한𐍈®кею"); - std::u16string u16(u"€alpha水𝄋ÿ€한𐍈®кею"); + auto array = NDArrayFactory::string(u8); + ASSERT_EQ(sd::DataType::UTF8, array.dataType()); - auto array = NDArrayFactory::string(u8); - ASSERT_EQ(sd::DataType::UTF8, array.dataType()); + ASSERT_EQ(1, array.lengthOf()); + ASSERT_EQ(0, array.rankOf()); - ASSERT_EQ(1, array.lengthOf()); - ASSERT_EQ(0, array.rankOf()); + auto aCast = array.cast(sd::DataType::UTF16); - auto aCast = array.cast(sd::DataType::UTF16); + auto z0 = array.e(0); + auto z1 = aCast.e(0); - auto z0 = array.e(0); - auto z1 = aCast.e(0); - - ASSERT_EQ(u8, z0); - ASSERT_EQ(u16, z1); + ASSERT_EQ(u8, z0); + ASSERT_EQ(u16, z1); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Basic_cast_UTF8toUTF32) { + std::string u8(u8"€alpha水𝄋ÿ€한𐍈®кею"); - std::string u8(u8"€alpha水𝄋ÿ€한𐍈®кею"); - - std::u32string u32(U"€alpha水𝄋ÿ€한𐍈®кею"); + std::u32string u32(U"€alpha水𝄋ÿ€한𐍈®кею"); - auto array = NDArrayFactory::string(u8); - ASSERT_EQ(sd::DataType::UTF8, array.dataType()); + auto array = NDArrayFactory::string(u8); + ASSERT_EQ(sd::DataType::UTF8, array.dataType()); - ASSERT_EQ(1, array.lengthOf()); - ASSERT_EQ(0, array.rankOf()); + ASSERT_EQ(1, array.lengthOf()); + ASSERT_EQ(0, array.rankOf()); - auto aCast = array.cast(sd::DataType::UTF32); + auto aCast = array.cast(sd::DataType::UTF32); - auto z0 = array.e(0); - auto z1 = aCast.e(0); + auto z0 = array.e(0); + auto z1 = aCast.e(0); - ASSERT_EQ(u8, z0); - ASSERT_EQ(u32, z1); + ASSERT_EQ(u8, z0); + ASSERT_EQ(u32, z1); } TEST_F(StringTests, test_bit_string_1) { diff --git a/libnd4j/tests_cpu/layers_tests/SwitchTests.cpp b/libnd4j/tests_cpu/layers_tests/SwitchTests.cpp deleted file mode 100644 index 8d6a8d180e63..000000000000 --- a/libnd4j/tests_cpu/layers_tests/SwitchTests.cpp +++ /dev/null @@ -1,251 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -// -// Created by raver119 on 13.10.2017. -// - -#include "testlayers.h" -#include - -using namespace sd; -using namespace sd::ops; -using namespace sd::graph; - -class SwitchTests : public testing::Test { -public: - -}; - -TEST_F(SwitchTests, SwitchTest1) { - Graph graph; - - FlowPath flowPath; - - auto variableSpace = graph.getVariableSpace(); - variableSpace->setFlowPath(&flowPath); - - auto input = NDArrayFactory::create_('c',{32, 100}); - input->assign(-119.0f); - - auto condtionX = NDArrayFactory::create_('c', {1, 1}); - condtionX->p(0, 0.0f); - auto condtionY = NDArrayFactory::create_('c', {1, 1}); - condtionY->p(0, 0.0f); - - variableSpace->putVariable(-1, input); - variableSpace->putVariable(-2, condtionX); - variableSpace->putVariable(-3, condtionY); - - // this is just 2 ops, that are executed sequentially. We don't really care bout them - auto nodeA = new Node(OpType_TRANSFORM_SAME, transform::Abs, 1, {-1}, {2}); - auto nodeB = new Node(OpType_TRANSFORM_SAME, transform::Abs, 2, {1}, {3}); - - // this is our condition op, we'll be using Equals condition, on variables conditionX and conditionY (ids -2 and -3 respectively) - // we're creating this op manually in tests, as always. - sd::ops::eq_scalar eqOp; - auto nodeCondition = new Node(&eqOp, 119, {-2, -3}); - //nodeCondition->setOpType(OpType_BOOLEAN); - - // now, this is Switch operation. It takes BooleanOperation operation in, - // and based on evaluation result (true/false) - it'll pass data via :0 or :1 output - // other idx will be considered disabled, and that graph branch won't be executed - sd::ops::Switch switchOp; - auto nodeSwitch = new Node(&switchOp, 3, {2, 119}, {4, 5}); - - // these 2 ops are connected to FALSE and TRUE outputs. output :0 considered FALSE, and output :1 considered TRUE - auto nodeZ0 = new Node(OpType_TRANSFORM_SAME, transform::Abs, 4, {}, {}); - nodeZ0->pickInput(3, 0); - auto nodeZ1 = new Node(OpType_TRANSFORM_SAME, transform::OneMinus, 5, {}, {}); - nodeZ1->pickInput(3, 1); - - - graph.addNode(nodeA); - graph.addNode(nodeB); - graph.addNode(nodeCondition); - graph.addNode(nodeSwitch); - graph.addNode(nodeZ0); - graph.addNode(nodeZ1); - - graph.buildGraph(); - - // we're making sure nodes connected to the Switch have no other inputs in this Graph - ASSERT_EQ(1, nodeZ0->input()->size()); - ASSERT_EQ(1, nodeZ1->input()->size()); - - // just validating topo sort - ASSERT_EQ(0, nodeA->getLayer()); - ASSERT_EQ(0, nodeCondition->getLayer()); - ASSERT_EQ(1, nodeB->getLayer()); - ASSERT_EQ(2, nodeSwitch->getLayer()); - ASSERT_EQ(3, nodeZ0->getLayer()); - ASSERT_EQ(3, nodeZ1->getLayer()); - - // executing graph - Nd4jStatus status = GraphExecutioner::execute(&graph); - - ASSERT_EQ(ND4J_STATUS_OK, status); - - // nd4j_printf("Z0: [%i]; Z1: [%i]\n", flowPath.isNodeActive(nodeZ0->id()), flowPath.isNodeActive(nodeZ1->id())); - - // we know that Switch got TRUE evaluation, so :0 should be inactive - ASSERT_FALSE(flowPath.isNodeActive(nodeZ0->id())); - - // and :1 should be active - ASSERT_TRUE(flowPath.isNodeActive(nodeZ1->id())); - - std::pair unexpected(4,0); - std::pair expectedResultIndex(5,0); - ASSERT_TRUE(variableSpace->hasVariable(expectedResultIndex)); - - // getting output of nodeZ1 - auto output = variableSpace->getVariable(expectedResultIndex)->getNDArray(); - - // and veryfing it against known expected value - ASSERT_NEAR(-118.0f, output->e(0), 1e-5f); -} - -TEST_F(SwitchTests, SwitchTest2) { - Graph graph; - - FlowPath flowPath; - auto variableSpace = graph.getVariableSpace(); - variableSpace->setFlowPath(&flowPath); - - auto input = NDArrayFactory::create_('c',{32, 100}); - input->assign(-119.0f); - - auto condtionX = NDArrayFactory::create_('c', {1, 1}); - condtionX->p(0, 1.0f); - auto condtionY = NDArrayFactory::create_('c', {1, 1}); - condtionY->p(0, 1.0f); - - - variableSpace->putVariable(-1, input); - variableSpace->putVariable(-2, condtionX); - variableSpace->putVariable(-3, condtionY); - - - auto nodeA = new Node(OpType_TRANSFORM_SAME, transform::Abs, 1, {-1}, {2}); - auto nodeB = new Node(OpType_TRANSFORM_SAME, transform::Abs, 2, {1}, {3}); - - auto scopeCondition = new Node(OpType_LOGIC, logic::Scope, 3); - scopeCondition->setName("scopeCondition"); - - auto nodeCondition = new Node(OpType_LOGIC, logic::Scope, 119, {-2, -3}); - nodeCondition->setScopeInfo(3, "scopeCondition"); - - sd::ops::eq_scalar eqOp; - nodeCondition->setCustomOp(&eqOp); - - auto nodeSwitch = new Node(OpType_LOGIC, logic::Switch, 5, {3, 2}); - - sd::ops::Switch switchOp; - nodeSwitch->setCustomOp(&switchOp); - - - // these 2 ops are connected to FALSE and TRUE outputs. output :0 considered FALSE, and output :1 considered TRUE - auto nodeZ0 = new Node(OpType_TRANSFORM_SAME, transform::Abs, 6, {}, {}); - nodeZ0->pickInput(5, 0); - auto nodeZ1 = new Node(OpType_TRANSFORM_SAME, transform::OneMinus, 7, {}, {}); - nodeZ1->pickInput(5, 1); - - graph.addNode(nodeA); - graph.addNode(nodeB); - graph.addNode(scopeCondition); - graph.addNode(nodeCondition); - graph.addNode(nodeSwitch); - graph.addNode(nodeZ0); - graph.addNode(nodeZ1); - - Nd4jStatus status = GraphExecutioner::execute(&graph); - - ASSERT_EQ(ND4J_STATUS_OK, status); - - ASSERT_TRUE(!flowPath.isNodeActive(nodeZ0->id())); - ASSERT_TRUE(flowPath.isNodeActive(nodeZ1->id())); - - auto z = graph.getVariableSpace()->getVariable(7)->getNDArray(); - - // abs(-119) = 119; 1 - 119 = -118 - ASSERT_NEAR(-118.f, z->e(0), 1e-5); -} - -TEST_F(SwitchTests, SwitchTest3) { - Graph graph; - - FlowPath flowPath; - auto variableSpace = graph.getVariableSpace(); - variableSpace->setFlowPath(&flowPath); - - auto input = NDArrayFactory::create_('c',{32, 100}); - input->assign(-119.0f); - - auto condtionX = NDArrayFactory::create_('c', {1, 1}); - condtionX->p(0, 2.0f); - auto condtionY = NDArrayFactory::create_('c', {1, 1}); - condtionY->p(0, 1.0f); - - - variableSpace->putVariable(-1, input); - variableSpace->putVariable(-2, condtionX); - variableSpace->putVariable(-3, condtionY); - - - auto nodeA = new Node(OpType_TRANSFORM_SAME, transform::Abs, 1, {-1}, {2}); - auto nodeB = new Node(OpType_TRANSFORM_SAME, transform::Abs, 2, {1}, {3}); - - auto scopeCondition = new Node(OpType_LOGIC, logic::Scope, 3); - scopeCondition->setName("scopeCondition"); - - auto nodeCondition = new Node(OpType_LOGIC, logic::Scope, 119, {-2, -3}); - nodeCondition->setScopeInfo(3, "scopeCondition"); - - sd::ops::eq_scalar eqOp; - nodeCondition->setCustomOp(&eqOp); - - auto nodeSwitch = new Node(OpType_LOGIC, logic::Switch, 5, {3, 2}); - - sd::ops::Switch switchOp; - nodeSwitch->setCustomOp(&switchOp); - - - // these 2 ops are connected to FALSE and TRUE outputs. output :0 considered FALSE, and output :1 considered TRUE - auto nodeZ0 = new Node(OpType_TRANSFORM_SAME, transform::Neg, 6, {}, {}); - nodeZ0->pickInput(5, 0); - auto nodeZ1 = new Node(OpType_TRANSFORM_SAME, transform::OneMinus, 7, {}, {}); - nodeZ1->pickInput(5, 1); - - graph.addNode(nodeA); - graph.addNode(nodeB); - graph.addNode(scopeCondition); - graph.addNode(nodeCondition); - graph.addNode(nodeSwitch); - graph.addNode(nodeZ0); - graph.addNode(nodeZ1); - - Nd4jStatus status = GraphExecutioner::execute(&graph); - - ASSERT_EQ(ND4J_STATUS_OK, status); - - ASSERT_TRUE(flowPath.isNodeActive(nodeZ0->id())); - ASSERT_TRUE(!flowPath.isNodeActive(nodeZ1->id())); - - auto z = graph.getVariableSpace()->getVariable(6)->getNDArray(); - - // abs(-119) = 119; Neg(119) = 119 - ASSERT_NEAR(-119.f, z->e(0), 1e-5); -} diff --git a/libnd4j/tests_cpu/layers_tests/TadTests.cpp b/libnd4j/tests_cpu/layers_tests/TadTests.cpp index 947927bfbfbb..3bf0b6e9fcf5 100644 --- a/libnd4j/tests_cpu/layers_tests/TadTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/TadTests.cpp @@ -21,423 +21,433 @@ #ifndef LIBND4J_TADTESTS_H #define LIBND4J_TADTESTS_H -#include "testlayers.h" #include +#include #include + #include -#include + +#include "testlayers.h" using namespace sd; class TadTests : public testing::Test { -public: - int numLoops = 100000000; + public: + int numLoops = 100000000; - int extLoops = 1000; - int intLoops = 1000; + int extLoops = 1000; + int intLoops = 1000; }; TEST_F(TadTests, Test4DTad1) { + NDArray* arraySource = sd::NDArrayFactory::linspace(1.0f, 10000.0f, 10000); - NDArray* arraySource = sd::NDArrayFactory::linspace(1.0f, 10000.0f, 10000); - - Nd4jLong badShape[] = {4, 2, 1, 4, 4, 80, 16, 4, 1, 8192, -1, 99}; - Nd4jLong goodShape[] = {4, 2, 1, 4, 4, 16, 16, 4, 1, 8192, 1, 99}; + Nd4jLong badShape[] = {4, 2, 1, 4, 4, 80, 16, 4, 1, 8192, -1, 99}; + Nd4jLong goodShape[] = {4, 2, 1, 4, 4, 16, 16, 4, 1, 8192, 1, 99}; - std::vector buff = arraySource->getBufferAsVector(); + std::vector buff = arraySource->getBufferAsVector(); - NDArray* arrayExp = new NDArray(buff.data(), goodShape); - NDArray* arrayBad = new NDArray(buff.data(), badShape); + NDArray* arrayExp = new NDArray(buff.data(), goodShape); + NDArray* arrayBad = new NDArray(buff.data(), badShape); - int dim = 1; - shape::TAD tad; - tad.init(arrayBad->shapeInfo(), &dim, 1); - tad.createTadOnlyShapeInfo(); - tad.createOffsets(); + int dim = 1; + shape::TAD tad; + tad.init(arrayBad->shapeInfo(), &dim, 1); + tad.createTadOnlyShapeInfo(); + tad.createOffsets(); - int exp[] = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95 }; - for (int e = 0; e < 32; e++) - ASSERT_EQ((int) tad.tadOffsets[e], exp[e]); + int exp[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95}; + for (int e = 0; e < 32; e++) ASSERT_EQ((int)tad.tadOffsets[e], exp[e]); - delete arrayExp; - delete arrayBad; - delete arraySource; + delete arrayExp; + delete arrayBad; + delete arraySource; } TEST_F(TadTests, TestNumTads1) { - auto x = NDArrayFactory::create('c', {2, 3}); - auto y = NDArrayFactory::create('c', {2, 2}); + auto x = NDArrayFactory::create('c', {2, 3}); + auto y = NDArrayFactory::create('c', {2, 2}); - std::vector dim({0}); + std::vector dim({0}); - Nd4jLong tadLengthX = shape::tadLength(x.shapeInfo(), dim.data(), dim.size()); - Nd4jLong numTadsX = x.lengthOf() / tadLengthX; + Nd4jLong tadLengthX = shape::tadLength(x.shapeInfo(), dim.data(), dim.size()); + Nd4jLong numTadsX = x.lengthOf() / tadLengthX; - Nd4jLong tadLengthY = shape::tadLength(y.shapeInfo(), dim.data(), dim.size()); - Nd4jLong numTadsY = y.lengthOf() / tadLengthY; + Nd4jLong tadLengthY = shape::tadLength(y.shapeInfo(), dim.data(), dim.size()); + Nd4jLong numTadsY = y.lengthOf() / tadLengthY; - ASSERT_EQ(2, tadLengthX); - ASSERT_EQ(3, numTadsX); + ASSERT_EQ(2, tadLengthX); + ASSERT_EQ(3, numTadsX); - ASSERT_EQ(2, tadLengthY); - ASSERT_EQ(2, numTadsY); + ASSERT_EQ(2, tadLengthY); + ASSERT_EQ(2, numTadsY); } TEST_F(TadTests, TestShapeTad_1) { + float buff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 16, 16, 17, 18, 19, 20, 21, 22, 23, 24}; + Nd4jLong shapeInfo[] = {3, 2, 3, 4, 12, 4, 1, 8192, 1, 99}; - float buff[] = {1,2,3,4,5,6,7,8,9,10,11,12,13,14,16,16,17,18,19,20,21,22,23,24}; - Nd4jLong shapeInfo[] = {3, 2, 3, 4, 12, 4, 1, 8192, 1, 99}; + NDArray input(buff, shapeInfo); - NDArray input(buff, shapeInfo); + std::vector dimensions = {0, 1, 2}; + Nd4jLong tadLength = + shape::tadLength(input.shapeInfo(), dimensions.data(), dimensions.size()); + Nd4jLong numTads = input.lengthOf() / tadLength; - std::vector dimensions = {0,1,2}; - Nd4jLong tadLength = shape::tadLength(input.shapeInfo(), dimensions.data(), dimensions.size()); - Nd4jLong numTads = input.lengthOf() / tadLength; + shape::TAD tad; + tad.init(input.shapeInfo(), dimensions.data(), dimensions.size()); + tad.createTadOnlyShapeInfo(); + tad.createOffsets(); - shape::TAD tad; - tad.init(input.shapeInfo(), dimensions.data(), dimensions.size()); - tad.createTadOnlyShapeInfo(); - tad.createOffsets(); + auto tadShapeInfo = + new Nd4jLong[shape::shapeInfoLength(tad.tadOnlyShapeInfo[0])]; + std::memcpy(tadShapeInfo, tad.tadOnlyShapeInfo, + shape::shapeInfoByteLength(tad.tadOnlyShapeInfo)); - auto tadShapeInfo = new Nd4jLong[shape::shapeInfoLength(tad.tadOnlyShapeInfo[0])]; - std::memcpy(tadShapeInfo, tad.tadOnlyShapeInfo, shape::shapeInfoByteLength(tad.tadOnlyShapeInfo)); + float* tadBuff = reinterpret_cast(input.buffer()) + tad.tadOffsets[0]; + NDArray tadArr(tadBuff, tadShapeInfo); - float* tadBuff = reinterpret_cast(input.buffer()) + tad.tadOffsets[0]; - NDArray tadArr(tadBuff, tadShapeInfo); + ASSERT_TRUE(numTads == 1); + ASSERT_TRUE(input.isSameShapeStrict(tadArr)); + ASSERT_TRUE(input.equalsTo(&tadArr)); - ASSERT_TRUE(numTads==1); - ASSERT_TRUE(input.isSameShapeStrict(tadArr)); - ASSERT_TRUE(input.equalsTo(&tadArr)); - - delete[] tadShapeInfo; + delete[] tadShapeInfo; } TEST_F(TadTests, TadNoAxis_1) { - auto array = NDArrayFactory::create('c', {2, 3}); + auto array = NDArrayFactory::create('c', {2, 3}); - shape::TAD tad; - tad.init(array.shapeInfo(), nullptr, 0); - tad.createTadOnlyShapeInfo(); - tad.createOffsets(); + shape::TAD tad; + tad.init(array.shapeInfo(), nullptr, 0); + tad.createTadOnlyShapeInfo(); + tad.createOffsets(); - ASSERT_TRUE(tad.wholeThing); + ASSERT_TRUE(tad.wholeThing); - ASSERT_TRUE(shape::equalsStrict(tad.tadOnlyShapeInfo, array.shapeInfo())); + ASSERT_TRUE(shape::equalsStrict(tad.tadOnlyShapeInfo, array.shapeInfo())); } TEST_F(TadTests, TadEdgeCase_1) { - auto array = NDArrayFactory::create('c', {5, 4, 1}); - auto exp = NDArrayFactory::create('c', {5, 4}); - array.linspace(1); + auto array = NDArrayFactory::create('c', {5, 4, 1}); + auto exp = NDArrayFactory::create('c', {5, 4}); + array.linspace(1); - auto tad = array(0, {2}); + auto tad = array(0, {2}); - ASSERT_TRUE(exp.isSameShape(tad)); + ASSERT_TRUE(exp.isSameShape(tad)); } TEST_F(TadTests, TestEdgeCase_2) { + auto array = + NDArrayFactory::create('f', {2, 3, 1}, {1, 4, 2, 5, 3, 6}); - auto array = NDArrayFactory::create('f', {2, 3, 1}, {1, 4, 2, 5, 3, 6}); - - for (int e = 0 ; e < array.lengthOf(); e++) { - auto tad = array(e, {0,1}); - ASSERT_NEAR(tad.e(0), array.e(e), 1e-5); - } + for (int e = 0; e < array.lengthOf(); e++) { + auto tad = array(e, {0, 1}); + ASSERT_NEAR(tad.e(0), array.e(e), 1e-5); + } } TEST_F(TadTests, TadEdgeCase_2) { - auto array = NDArrayFactory::create('c', {2, 3, 4}); + auto array = NDArrayFactory::create('c', {2, 3, 4}); - auto tad = array(0, {0,2}); + auto tad = array(0, {0, 2}); - ASSERT_EQ(3, tad.lengthOf()); + ASSERT_EQ(3, tad.lengthOf()); } - TEST_F(TadTests, test_Tad_Ews_optimization_1) { - shape::TAD xTad; + shape::TAD xTad; - std::array array = {1,2}; - ASSERT_TRUE(xTad.dimensionsDescending(3, array.data(), array.size())); + std::array array = {1, 2}; + ASSERT_TRUE(xTad.dimensionsDescending(3, array.data(), array.size())); } TEST_F(TadTests, test_Tad_Ews_optimization_2) { - shape::TAD xTad; + shape::TAD xTad; - std::array array = {0,2}; - ASSERT_FALSE(xTad.dimensionsDescending(3, array.data(), array.size())); + std::array array = {0, 2}; + ASSERT_FALSE(xTad.dimensionsDescending(3, array.data(), array.size())); } TEST_F(TadTests, test_Tad_Ews_optimization_3) { - shape::TAD xTad; + shape::TAD xTad; - std::array array = {1}; - ASSERT_TRUE(xTad.dimensionsDescending(2, array.data(), array.size())); + std::array array = {1}; + ASSERT_TRUE(xTad.dimensionsDescending(2, array.data(), array.size())); } TEST_F(TadTests, test_Tad_Ews_optimization_4) { - shape::TAD xTad; + shape::TAD xTad; - std::array array = {0}; - ASSERT_TRUE(xTad.dimensionsDescending(1, array.data(), array.size())); + std::array array = {0}; + ASSERT_TRUE(xTad.dimensionsDescending(1, array.data(), array.size())); } TEST_F(TadTests, test_Tad_Ews_optimization_5) { - shape::TAD xTad; + shape::TAD xTad; - std::array array = {2,3}; - ASSERT_TRUE(xTad.dimensionsDescending(4, array.data(), array.size())); + std::array array = {2, 3}; + ASSERT_TRUE(xTad.dimensionsDescending(4, array.data(), array.size())); } TEST_F(TadTests, test_TAD_empty_dims_1) { - Nd4jLong xShape[8] = {2, 150, 1, 3, 1, 16384, 3, 99}; - shape::TAD xTad; - xTad.init(xShape, reinterpret_cast(112L), 0); - xTad.createTadOnlyShapeInfo(); - xTad.createOffsets(); + Nd4jLong xShape[8] = {2, 150, 1, 3, 1, 16384, 3, 99}; + shape::TAD xTad; + xTad.init(xShape, reinterpret_cast(112L), 0); + xTad.createTadOnlyShapeInfo(); + xTad.createOffsets(); } TEST_F(TadTests, test_tad_order_1) { - Nd4jLong xShape[8] = {2, 150, 10, 10, 1, 8192, 1, 99}; - Nd4jLong tShape[8] = {2, 1, 10, 1, 1, 8192, 1, 99}; - shape::TAD xTad; - int dim = 1; - xTad.init(xShape, &dim, 1); - xTad.createTadOnlyShapeInfo(); - - ASSERT_TRUE(shape::equalsStrict(tShape, xTad.tadOnlyShapeInfo)); + Nd4jLong xShape[8] = {2, 150, 10, 10, 1, 8192, 1, 99}; + Nd4jLong tShape[8] = {2, 1, 10, 1, 1, 8192, 1, 99}; + shape::TAD xTad; + int dim = 1; + xTad.init(xShape, &dim, 1); + xTad.createTadOnlyShapeInfo(); + + ASSERT_TRUE(shape::equalsStrict(tShape, xTad.tadOnlyShapeInfo)); } TEST_F(TadTests, test_tad_order_2) { - Nd4jLong xShape[8] = {2, 150, 10, 10, 1, 8192, 1, 99}; - Nd4jLong tShape[8] = {2, 1, 150, 1, 10, 8192, 10, 99}; - shape::TAD xTad; - int dim = 0; - xTad.init(xShape, &dim, 1); - xTad.createTadOnlyShapeInfo(); - - ASSERT_TRUE(shape::equalsStrict(tShape, xTad.tadOnlyShapeInfo)); + Nd4jLong xShape[8] = {2, 150, 10, 10, 1, 8192, 1, 99}; + Nd4jLong tShape[8] = {2, 1, 150, 1, 10, 8192, 10, 99}; + shape::TAD xTad; + int dim = 0; + xTad.init(xShape, &dim, 1); + xTad.createTadOnlyShapeInfo(); + + ASSERT_TRUE(shape::equalsStrict(tShape, xTad.tadOnlyShapeInfo)); } - TEST_F(TadTests, test_tad_order_3) { - Nd4jLong xShape[10] = {3, 10, 20, 30, 600 ,30, 1, 8192, 1, 99}; - Nd4jLong tShape[8] = {2, 1, 30, 1, 1, 8192, 1, 99}; - shape::TAD xTad; - int dim = 2; - xTad.init(xShape, &dim, 1); - xTad.createTadOnlyShapeInfo(); - - ASSERT_TRUE(shape::equalsStrict(tShape, xTad.tadOnlyShapeInfo)); + Nd4jLong xShape[10] = {3, 10, 20, 30, 600, 30, 1, 8192, 1, 99}; + Nd4jLong tShape[8] = {2, 1, 30, 1, 1, 8192, 1, 99}; + shape::TAD xTad; + int dim = 2; + xTad.init(xShape, &dim, 1); + xTad.createTadOnlyShapeInfo(); + + ASSERT_TRUE(shape::equalsStrict(tShape, xTad.tadOnlyShapeInfo)); } - TEST_F(TadTests, test_tad_order_4) { - Nd4jLong xShape[10] = {3, 10, 20, 30, 600 ,30, 1, 8192, 1, 99}; - Nd4jLong tShape[8] = {2, 20, 30, 30, 1, 8192, 1, 99}; - shape::TAD xTad; - int dim[2] = {1, 2}; - xTad.init(xShape, dim, 2); - xTad.createTadOnlyShapeInfo(); - - ASSERT_TRUE(shape::equalsStrict(tShape, xTad.tadOnlyShapeInfo)); + Nd4jLong xShape[10] = {3, 10, 20, 30, 600, 30, 1, 8192, 1, 99}; + Nd4jLong tShape[8] = {2, 20, 30, 30, 1, 8192, 1, 99}; + shape::TAD xTad; + int dim[2] = {1, 2}; + xTad.init(xShape, dim, 2); + xTad.createTadOnlyShapeInfo(); + + ASSERT_TRUE(shape::equalsStrict(tShape, xTad.tadOnlyShapeInfo)); } TEST_F(TadTests, test_column_1) { - auto x = NDArrayFactory::create('c', {5, 2}); - auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(x.shapeInfo(), 0); + auto x = NDArrayFactory::create('c', {5, 2}); + auto tadPack = + sd::ConstantTadHelper::getInstance().tadForDimensions(x.shapeInfo(), 0); - ASSERT_EQ(1, shape::rank(tadPack.primaryShapeInfo())); - ASSERT_EQ(5, shape::length(tadPack.primaryShapeInfo())); - ASSERT_TRUE(shape::isVector(tadPack.primaryShapeInfo())); + ASSERT_EQ(1, shape::rank(tadPack.primaryShapeInfo())); + ASSERT_EQ(5, shape::length(tadPack.primaryShapeInfo())); + ASSERT_TRUE(shape::isVector(tadPack.primaryShapeInfo())); - auto scalarViewPack = sd::ConstantTadHelper::getInstance().tadForDimensions(tadPack.primaryShapeInfo(), 0); + auto scalarViewPack = sd::ConstantTadHelper::getInstance().tadForDimensions( + tadPack.primaryShapeInfo(), 0); - ASSERT_TRUE(shape::equalsStrict(tadPack.primaryShapeInfo(), scalarViewPack.primaryShapeInfo())); + ASSERT_TRUE(shape::equalsStrict(tadPack.primaryShapeInfo(), + scalarViewPack.primaryShapeInfo())); } /////////////////////////////////////////////////////////////////// TEST_F(TadTests, calcOffsets_1) { + Nd4jLong shapeInfoF[10] = {3, 2, 3, 4, 1, 2, 6, 8192, 1, 102}; + Nd4jLong shapeInfoC[10] = {3, 2, 3, 4, 12, 4, 1, 8192, 1, 99}; + Nd4jLong shapeInfoFC[10] = {3, 2, 3, 4, 1, 2, 6, 8192, 1, 99}; + ; - Nd4jLong shapeInfoF[10] = {3, 2,3,4, 1,2,6, 8192, 1, 102}; - Nd4jLong shapeInfoC[10] = {3, 2,3,4, 12,4,1, 8192, 1, 99}; - Nd4jLong shapeInfoFC[10] = {3, 2,3,4, 1,2,6, 8192, 1, 99};; - - Nd4jLong expOffsetsF[24] = {0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23}; - Nd4jLong expOffsetsC[24] = {0,12,4,16,8,20,1,13,5,17,9,21,2,14,6,18,10,22,3,15,7,19,11,23}; + Nd4jLong expOffsetsF[24] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}; + Nd4jLong expOffsetsC[24] = {0, 12, 4, 16, 8, 20, 1, 13, 5, 17, 9, 21, + 2, 14, 6, 18, 10, 22, 3, 15, 7, 19, 11, 23}; - Nd4jLong offsets[24]; + Nd4jLong offsets[24]; - shape::calcOffsets(shapeInfoF, offsets, 'f'); + shape::calcOffsets(shapeInfoF, offsets, 'f'); - for (int e = 0; e < 24; e++) - ASSERT_TRUE(offsets[e] == expOffsetsF[e]); + for (int e = 0; e < 24; e++) ASSERT_TRUE(offsets[e] == expOffsetsF[e]); - shape::calcOffsets(shapeInfoC, offsets, 'f'); + shape::calcOffsets(shapeInfoC, offsets, 'f'); - for (int e = 0; e < 24; e++) - ASSERT_TRUE(offsets[e] == expOffsetsC[e]); + for (int e = 0; e < 24; e++) ASSERT_TRUE(offsets[e] == expOffsetsC[e]); - shape::calcOffsets(shapeInfoFC, offsets, 'f'); + shape::calcOffsets(shapeInfoFC, offsets, 'f'); - for (int e = 0; e < 24; e++) - ASSERT_TRUE(offsets[e] == expOffsetsF[e]); + for (int e = 0; e < 24; e++) ASSERT_TRUE(offsets[e] == expOffsetsF[e]); } - ///////////////////////////////////////////////////////////////// TEST_F(TadTests, outerArrayIndexes_1) { - - NDArray x('c', {2,3,4,5}, sd::DataType::FLOAT32); - int maxIdxs[120]; - - NDArray y1('c', {3,5}, sd::DataType::FLOAT32); - const std::vector dimsToExclude1 = {0,2}; - const int n1[] = {20,25,30,35, 80,85,90,95}; - int minIdx = 5; - - int N = shape::outerArrayIndexes(maxIdxs, minIdx, x.shapeInfo(), y1.shapeInfo(), dimsToExclude1.data()); - ASSERT_TRUE(N == x.lengthOf()/y1.lengthOf()); - for(int i = 0; i < N; ++i) - ASSERT_TRUE(n1[i] == maxIdxs[i]); - - NDArray y2('c', {4,5}, sd::DataType::FLOAT32); - const std::vector dimsToExclude2 = {0,1}; - const int n2[] = {12,32,52, 72,92,112}; - minIdx = 12; - - N = shape::outerArrayIndexes(maxIdxs, minIdx, x.shapeInfo(), y2.shapeInfo(), dimsToExclude2.data()); - ASSERT_TRUE(N == x.lengthOf()/y2.lengthOf()); - for(int i = 0; i < N; ++i) - ASSERT_TRUE(n2[i] == maxIdxs[i]); - - NDArray y3('c', {2,5}, sd::DataType::FLOAT32); - const std::vector dimsToExclude3 = {1,2}; - const int n3[] = {64,69,74,79,84,89,94,99,104,109,114,119}; - minIdx = 9; - - N = shape::outerArrayIndexes(maxIdxs, minIdx, x.shapeInfo(), y3.shapeInfo(), dimsToExclude3.data()); - ASSERT_TRUE(N == x.lengthOf()/y3.lengthOf()); - for(int i = 0; i < N; ++i) - ASSERT_TRUE(n3[i] == maxIdxs[i]); - - NDArray y4('c', {2,3}, sd::DataType::FLOAT32); - const std::vector dimsToExclude4 = {2,3}; - const int n4[] = {20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39}; - minIdx = 1; - - N = shape::outerArrayIndexes(maxIdxs, minIdx, x.shapeInfo(), y4.shapeInfo(), dimsToExclude4.data()); - ASSERT_TRUE(N == x.lengthOf()/y4.lengthOf()); - for(int i = 0; i < N; ++i) - ASSERT_TRUE(n4[i] == maxIdxs[i]); - - NDArray y5('c', {2,4}, sd::DataType::FLOAT32); - const std::vector dimsToExclude5 = {1,3}; - const int n5[] = {65,66,67,68,69, 85,86,87,88,89, 105,106,107,108,109}; - minIdx = 5; - - N = shape::outerArrayIndexes(maxIdxs, minIdx, x.shapeInfo(), y5.shapeInfo(), dimsToExclude5.data()); - ASSERT_TRUE(N == x.lengthOf()/y5.lengthOf()); - for(int i = 0; i < N; ++i) - ASSERT_TRUE(n5[i] == maxIdxs[i]); - - NDArray y6('c', {2,3,4}, sd::DataType::FLOAT32); - const std::vector dimsToExclude6 = {3}; - const int n6[] = {65,66,67,68,69}; - minIdx = 13; - - N = shape::outerArrayIndexes(maxIdxs, minIdx, x.shapeInfo(), y6.shapeInfo(), dimsToExclude6.data()); - ASSERT_TRUE(N == x.lengthOf()/y6.lengthOf()); - for(int i = 0; i < N; ++i) - ASSERT_TRUE(n6[i] == maxIdxs[i]); - - NDArray y7('c', {4}, sd::DataType::FLOAT32); - const std::vector dimsToExclude7 = {0,1,3}; - const int n7[] = {15,16,17,18,19, 35,36,37,38,39, 55,56,57,58,59, 75,76,77,78,79, 95,96,97,98,99, 115,116,117,118,119}; - minIdx = 3; - - N = shape::outerArrayIndexes(maxIdxs, minIdx, x.shapeInfo(), y7.shapeInfo(), dimsToExclude7.data()); - ASSERT_TRUE(N == x.lengthOf()/y7.lengthOf()); - for(int i = 0; i < N; ++i) - ASSERT_TRUE(n7[i] == maxIdxs[i]); - - NDArray y8('c', {5}, sd::DataType::FLOAT32); - const std::vector dimsToExclude8 = {0,1,2}; - const int n8[] = {0,5,10,15, 20,25,30,35, 40,45,50,55, 60,65,70,75, 80,85,90,95, 100,105,110,115}; - minIdx = 0; - - N = shape::outerArrayIndexes(maxIdxs, minIdx, x.shapeInfo(), y8.shapeInfo(), dimsToExclude8.data()); - ASSERT_TRUE(N == x.lengthOf()/y8.lengthOf()); - for(int i = 0; i < N; ++i) - ASSERT_TRUE(n8[i] == maxIdxs[i]); - - NDArray y9('c', {2}, sd::DataType::FLOAT32); - const std::vector dimsToExclude9 = {1,2,3}; - const int n9[] = {60,61,62,63,64,65,66,67,68,69,70,71,72,73,74,75,76,77,78,79,80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95,96,97,98,99,100,101,102,103,104,105,106,107,108,109,110,111,112,113,114,115,116,117,118,119}; - minIdx = 1; - - N = shape::outerArrayIndexes(maxIdxs, minIdx, x.shapeInfo(), y9.shapeInfo(), dimsToExclude9.data()); - ASSERT_TRUE(N == x.lengthOf()/y9.lengthOf()); - for(int i = 0; i < N; ++i) - ASSERT_TRUE(n9[i] == maxIdxs[i]); - - NDArray y10('c', {3,4,5}, sd::DataType::FLOAT32); - const std::vector dimsToExclude10 = {0}; - const int n10[] = {11, 71}; - minIdx = 11; - - N = shape::outerArrayIndexes(maxIdxs, minIdx, x.shapeInfo(), y10.shapeInfo(), dimsToExclude10.data()); - ASSERT_TRUE(N == x.lengthOf()/y10.lengthOf()); - for(int i = 0; i < N; ++i) - ASSERT_TRUE(n10[i] == maxIdxs[i]); - - NDArray y11('c', {2,4,5}, sd::DataType::FLOAT32); - const std::vector dimsToExclude11 = {1}; - const int n11[] = {66, 86, 106}; - minIdx = 26; - - N = shape::outerArrayIndexes(maxIdxs, minIdx, x.shapeInfo(), y11.shapeInfo(), dimsToExclude11.data()); - ASSERT_TRUE(N == x.lengthOf()/y11.lengthOf()); - for(int i = 0; i < N; ++i) - ASSERT_TRUE(n11[i] == maxIdxs[i]); - - NDArray y12('c', {3,2}, sd::DataType::FLOAT32); - const std::vector dimsToExclude12 = {0,2}; - const int n12[] = {0,2,4,5,7,9,10,12,14,15,17,19,60,62,64,65,67,69,70,72,74,75,77,79}; - minIdx = 0; - - N = shape::outerArrayIndexes(maxIdxs, minIdx, x.shapeInfo(), y12.shapeInfo(), dimsToExclude12.data()); - for(int i = 0; i < N; ++i) - ASSERT_TRUE(n12[i] == maxIdxs[i]); - - NDArray y13('c', {3,2}, sd::DataType::FLOAT32); - const std::vector dimsToExclude13 = {0,2}; - const int n13[] = {1,3,6,8,11,13,16,18,61,63,66,68,71,73,76,78}; - minIdx = 1; - - N = shape::outerArrayIndexes(maxIdxs, minIdx, x.shapeInfo(), y13.shapeInfo(), dimsToExclude13.data()); - for(int i = 0; i < N; ++i) - ASSERT_TRUE(n13[i] == maxIdxs[i]); - - NDArray y14('c', {4,5}, sd::DataType::FLOAT32); - const int n14[] = {12,32,52, 72,92,112}; - minIdx = 12; - - N = shape::outerArrayIndexes(maxIdxs, minIdx, x.shapeInfo(), y14.shapeInfo(), nullptr); - ASSERT_TRUE(N == x.lengthOf()/y14.lengthOf()); - for(int i = 0; i < N; ++i) - ASSERT_TRUE(n14[i] == maxIdxs[i]); - - NDArray y15('c', {3,4,5}, sd::DataType::FLOAT32); - const int n15[] = {11, 71}; - minIdx = 11; - - N = shape::outerArrayIndexes(maxIdxs, minIdx, x.shapeInfo(), y15.shapeInfo(), nullptr); - ASSERT_TRUE(N == x.lengthOf()/y15.lengthOf()); - for(int i = 0; i < N; ++i) - ASSERT_TRUE(n15[i] == maxIdxs[i]); + NDArray x('c', {2, 3, 4, 5}, sd::DataType::FLOAT32); + int maxIdxs[120]; + + NDArray y1('c', {3, 5}, sd::DataType::FLOAT32); + const std::vector dimsToExclude1 = {0, 2}; + const int n1[] = {20, 25, 30, 35, 80, 85, 90, 95}; + int minIdx = 5; + + int N = shape::outerArrayIndexes(maxIdxs, minIdx, x.shapeInfo(), + y1.shapeInfo(), dimsToExclude1.data()); + ASSERT_TRUE(N == x.lengthOf() / y1.lengthOf()); + for (int i = 0; i < N; ++i) ASSERT_TRUE(n1[i] == maxIdxs[i]); + + NDArray y2('c', {4, 5}, sd::DataType::FLOAT32); + const std::vector dimsToExclude2 = {0, 1}; + const int n2[] = {12, 32, 52, 72, 92, 112}; + minIdx = 12; + + N = shape::outerArrayIndexes(maxIdxs, minIdx, x.shapeInfo(), y2.shapeInfo(), + dimsToExclude2.data()); + ASSERT_TRUE(N == x.lengthOf() / y2.lengthOf()); + for (int i = 0; i < N; ++i) ASSERT_TRUE(n2[i] == maxIdxs[i]); + + NDArray y3('c', {2, 5}, sd::DataType::FLOAT32); + const std::vector dimsToExclude3 = {1, 2}; + const int n3[] = {64, 69, 74, 79, 84, 89, 94, 99, 104, 109, 114, 119}; + minIdx = 9; + + N = shape::outerArrayIndexes(maxIdxs, minIdx, x.shapeInfo(), y3.shapeInfo(), + dimsToExclude3.data()); + ASSERT_TRUE(N == x.lengthOf() / y3.lengthOf()); + for (int i = 0; i < N; ++i) ASSERT_TRUE(n3[i] == maxIdxs[i]); + + NDArray y4('c', {2, 3}, sd::DataType::FLOAT32); + const std::vector dimsToExclude4 = {2, 3}; + const int n4[] = {20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39}; + minIdx = 1; + + N = shape::outerArrayIndexes(maxIdxs, minIdx, x.shapeInfo(), y4.shapeInfo(), + dimsToExclude4.data()); + ASSERT_TRUE(N == x.lengthOf() / y4.lengthOf()); + for (int i = 0; i < N; ++i) ASSERT_TRUE(n4[i] == maxIdxs[i]); + + NDArray y5('c', {2, 4}, sd::DataType::FLOAT32); + const std::vector dimsToExclude5 = {1, 3}; + const int n5[] = {65, 66, 67, 68, 69, 85, 86, 87, + 88, 89, 105, 106, 107, 108, 109}; + minIdx = 5; + + N = shape::outerArrayIndexes(maxIdxs, minIdx, x.shapeInfo(), y5.shapeInfo(), + dimsToExclude5.data()); + ASSERT_TRUE(N == x.lengthOf() / y5.lengthOf()); + for (int i = 0; i < N; ++i) ASSERT_TRUE(n5[i] == maxIdxs[i]); + + NDArray y6('c', {2, 3, 4}, sd::DataType::FLOAT32); + const std::vector dimsToExclude6 = {3}; + const int n6[] = {65, 66, 67, 68, 69}; + minIdx = 13; + + N = shape::outerArrayIndexes(maxIdxs, minIdx, x.shapeInfo(), y6.shapeInfo(), + dimsToExclude6.data()); + ASSERT_TRUE(N == x.lengthOf() / y6.lengthOf()); + for (int i = 0; i < N; ++i) ASSERT_TRUE(n6[i] == maxIdxs[i]); + + NDArray y7('c', {4}, sd::DataType::FLOAT32); + const std::vector dimsToExclude7 = {0, 1, 3}; + const int n7[] = {15, 16, 17, 18, 19, 35, 36, 37, 38, 39, + 55, 56, 57, 58, 59, 75, 76, 77, 78, 79, + 95, 96, 97, 98, 99, 115, 116, 117, 118, 119}; + minIdx = 3; + + N = shape::outerArrayIndexes(maxIdxs, minIdx, x.shapeInfo(), y7.shapeInfo(), + dimsToExclude7.data()); + ASSERT_TRUE(N == x.lengthOf() / y7.lengthOf()); + for (int i = 0; i < N; ++i) ASSERT_TRUE(n7[i] == maxIdxs[i]); + + NDArray y8('c', {5}, sd::DataType::FLOAT32); + const std::vector dimsToExclude8 = {0, 1, 2}; + const int n8[] = {0, 5, 10, 15, 20, 25, 30, 35, 40, 45, 50, 55, + 60, 65, 70, 75, 80, 85, 90, 95, 100, 105, 110, 115}; + minIdx = 0; + + N = shape::outerArrayIndexes(maxIdxs, minIdx, x.shapeInfo(), y8.shapeInfo(), + dimsToExclude8.data()); + ASSERT_TRUE(N == x.lengthOf() / y8.lengthOf()); + for (int i = 0; i < N; ++i) ASSERT_TRUE(n8[i] == maxIdxs[i]); + + NDArray y9('c', {2}, sd::DataType::FLOAT32); + const std::vector dimsToExclude9 = {1, 2, 3}; + const int n9[] = {60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, + 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, + 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, + 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, + 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119}; + minIdx = 1; + + N = shape::outerArrayIndexes(maxIdxs, minIdx, x.shapeInfo(), y9.shapeInfo(), + dimsToExclude9.data()); + ASSERT_TRUE(N == x.lengthOf() / y9.lengthOf()); + for (int i = 0; i < N; ++i) ASSERT_TRUE(n9[i] == maxIdxs[i]); + + NDArray y10('c', {3, 4, 5}, sd::DataType::FLOAT32); + const std::vector dimsToExclude10 = {0}; + const int n10[] = {11, 71}; + minIdx = 11; + + N = shape::outerArrayIndexes(maxIdxs, minIdx, x.shapeInfo(), y10.shapeInfo(), + dimsToExclude10.data()); + ASSERT_TRUE(N == x.lengthOf() / y10.lengthOf()); + for (int i = 0; i < N; ++i) ASSERT_TRUE(n10[i] == maxIdxs[i]); + + NDArray y11('c', {2, 4, 5}, sd::DataType::FLOAT32); + const std::vector dimsToExclude11 = {1}; + const int n11[] = {66, 86, 106}; + minIdx = 26; + + N = shape::outerArrayIndexes(maxIdxs, minIdx, x.shapeInfo(), y11.shapeInfo(), + dimsToExclude11.data()); + ASSERT_TRUE(N == x.lengthOf() / y11.lengthOf()); + for (int i = 0; i < N; ++i) ASSERT_TRUE(n11[i] == maxIdxs[i]); + + NDArray y12('c', {3, 2}, sd::DataType::FLOAT32); + const std::vector dimsToExclude12 = {0, 2}; + const int n12[] = {0, 2, 4, 5, 7, 9, 10, 12, 14, 15, 17, 19, + 60, 62, 64, 65, 67, 69, 70, 72, 74, 75, 77, 79}; + minIdx = 0; + + N = shape::outerArrayIndexes(maxIdxs, minIdx, x.shapeInfo(), y12.shapeInfo(), + dimsToExclude12.data()); + for (int i = 0; i < N; ++i) ASSERT_TRUE(n12[i] == maxIdxs[i]); + + NDArray y13('c', {3, 2}, sd::DataType::FLOAT32); + const std::vector dimsToExclude13 = {0, 2}; + const int n13[] = {1, 3, 6, 8, 11, 13, 16, 18, + 61, 63, 66, 68, 71, 73, 76, 78}; + minIdx = 1; + + N = shape::outerArrayIndexes(maxIdxs, minIdx, x.shapeInfo(), y13.shapeInfo(), + dimsToExclude13.data()); + for (int i = 0; i < N; ++i) ASSERT_TRUE(n13[i] == maxIdxs[i]); + + NDArray y14('c', {4, 5}, sd::DataType::FLOAT32); + const int n14[] = {12, 32, 52, 72, 92, 112}; + minIdx = 12; + + N = shape::outerArrayIndexes(maxIdxs, minIdx, x.shapeInfo(), y14.shapeInfo(), + nullptr); + ASSERT_TRUE(N == x.lengthOf() / y14.lengthOf()); + for (int i = 0; i < N; ++i) ASSERT_TRUE(n14[i] == maxIdxs[i]); + + NDArray y15('c', {3, 4, 5}, sd::DataType::FLOAT32); + const int n15[] = {11, 71}; + minIdx = 11; + + N = shape::outerArrayIndexes(maxIdxs, minIdx, x.shapeInfo(), y15.shapeInfo(), + nullptr); + ASSERT_TRUE(N == x.lengthOf() / y15.lengthOf()); + for (int i = 0; i < N; ++i) ASSERT_TRUE(n15[i] == maxIdxs[i]); } - - -#endif //LIBND4J_TADTESTS_H +#endif // LIBND4J_TADTESTS_H diff --git a/libnd4j/tests_cpu/layers_tests/ThreadsTests.cpp b/libnd4j/tests_cpu/layers_tests/ThreadsTests.cpp index 71957bc59763..8e2a1aebe756 100644 --- a/libnd4j/tests_cpu/layers_tests/ThreadsTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/ThreadsTests.cpp @@ -18,12 +18,14 @@ // @author raver119@gmail.com // -#include "testlayers.h" -#include -#include +#include #include +#include +#include + #include -#include + +#include "testlayers.h" using namespace samediff; using namespace sd; @@ -31,181 +33,179 @@ using namespace sd::ops; using namespace sd::graph; class ThreadsTests : public testing::Test { -public: - ThreadsTests() { - nd4j_printf("\n",""); - } + public: + ThreadsTests() { nd4j_printf("\n", ""); } }; TEST_F(ThreadsTests, th_test_1) { - ASSERT_EQ(1, ThreadsHelper::numberOfThreads(6, 1023)); - ASSERT_EQ(1, ThreadsHelper::numberOfThreads(6, 1024)); - ASSERT_EQ(1, ThreadsHelper::numberOfThreads(6, 1026)); + ASSERT_EQ(1, ThreadsHelper::numberOfThreads(6, 1023)); + ASSERT_EQ(1, ThreadsHelper::numberOfThreads(6, 1024)); + ASSERT_EQ(1, ThreadsHelper::numberOfThreads(6, 1026)); - ASSERT_EQ(1, ThreadsHelper::numberOfThreads(6, 2043)); - ASSERT_EQ(2, ThreadsHelper::numberOfThreads(6, 2048)); + ASSERT_EQ(1, ThreadsHelper::numberOfThreads(6, 2043)); + ASSERT_EQ(2, ThreadsHelper::numberOfThreads(6, 2048)); } - TEST_F(ThreadsTests, th_test_2) { - // in this case we'll get better split over second loop - exactly 32 elements per thread - ASSERT_EQ(2, ThreadsHelper::pickLoop2d(32, 48, 1024)); - ASSERT_EQ(2, ThreadsHelper::pickLoop2d(6, 4, 16384)); + // in this case we'll get better split over second loop - exactly 32 elements + // per thread + ASSERT_EQ(2, ThreadsHelper::pickLoop2d(32, 48, 1024)); + ASSERT_EQ(2, ThreadsHelper::pickLoop2d(6, 4, 16384)); - // in this case we'll get better split over first loop - 2 loops/2048 elements per thread - ASSERT_EQ(1, ThreadsHelper::pickLoop2d(32, 64, 1024)); - ASSERT_EQ(1, ThreadsHelper::pickLoop2d(6, 6, 16384)); + // in this case we'll get better split over first loop - 2 loops/2048 elements + // per thread + ASSERT_EQ(1, ThreadsHelper::pickLoop2d(32, 64, 1024)); + ASSERT_EQ(1, ThreadsHelper::pickLoop2d(6, 6, 16384)); - // in this case none of loops are good enough, but second loop is too small for split - ASSERT_EQ(1, ThreadsHelper::pickLoop2d(6, 64, 32)); + // in this case none of loops are good enough, but second loop is too small + // for split + ASSERT_EQ(1, ThreadsHelper::pickLoop2d(6, 64, 32)); - // all loops are good enough, but we go with bigger one, since small - ASSERT_EQ(1, ThreadsHelper::pickLoop2d(2, 64, 32)); + // all loops are good enough, but we go with bigger one, since small + ASSERT_EQ(1, ThreadsHelper::pickLoop2d(2, 64, 32)); - // obviously split goes into second loop, to give 1024 elements per thread - ASSERT_EQ(2, ThreadsHelper::pickLoop2d(2, 1, 2048)); + // obviously split goes into second loop, to give 1024 elements per thread + ASSERT_EQ(2, ThreadsHelper::pickLoop2d(2, 1, 2048)); } TEST_F(ThreadsTests, th_test_3) { - // typical conv cases - ASSERT_EQ(1, ThreadsHelper::pickLoop3d(4, 32, 3, 128)); - ASSERT_EQ(2, ThreadsHelper::pickLoop3d(4, 1, 128, 64)); - ASSERT_EQ(3, ThreadsHelper::pickLoop3d(4, 1, 3, 128)); - - // checking for optimal threads for conv inference - ASSERT_EQ(6, ThreadsHelper::numberOfThreads3d(6, 1, 3, 128)); - ASSERT_EQ(4, ThreadsHelper::numberOfThreads3d(4, 1, 3, 128)); - ASSERT_EQ(8, ThreadsHelper::numberOfThreads3d(8, 1, 3, 128)); - - // checking for optimal threads for conv training - ASSERT_EQ(6, ThreadsHelper::numberOfThreads3d(6, 16, 3, 128)); - ASSERT_EQ(6, ThreadsHelper::numberOfThreads3d(6, 8, 3, 128)); - - - ASSERT_EQ(6, ThreadsHelper::numberOfThreads3d(6, 8, 3, 64)); - ASSERT_EQ(1, ThreadsHelper::pickLoop3d(6, 8, 3, 64)); + // typical conv cases + ASSERT_EQ(1, ThreadsHelper::pickLoop3d(4, 32, 3, 128)); + ASSERT_EQ(2, ThreadsHelper::pickLoop3d(4, 1, 128, 64)); + ASSERT_EQ(3, ThreadsHelper::pickLoop3d(4, 1, 3, 128)); + + // checking for optimal threads for conv inference + ASSERT_EQ(6, ThreadsHelper::numberOfThreads3d(6, 1, 3, 128)); + ASSERT_EQ(4, ThreadsHelper::numberOfThreads3d(4, 1, 3, 128)); + ASSERT_EQ(8, ThreadsHelper::numberOfThreads3d(8, 1, 3, 128)); + + // checking for optimal threads for conv training + ASSERT_EQ(6, ThreadsHelper::numberOfThreads3d(6, 16, 3, 128)); + ASSERT_EQ(6, ThreadsHelper::numberOfThreads3d(6, 8, 3, 128)); + + ASSERT_EQ(6, ThreadsHelper::numberOfThreads3d(6, 8, 3, 64)); + ASSERT_EQ(1, ThreadsHelper::pickLoop3d(6, 8, 3, 64)); } TEST_F(ThreadsTests, th_test_5) { - ASSERT_EQ(6, ThreadsHelper::numberOfThreads3d(6, 32, 112, 112)); + ASSERT_EQ(6, ThreadsHelper::numberOfThreads3d(6, 32, 112, 112)); - ASSERT_EQ(1, ThreadsHelper::pickLoop3d(6, 32, 112, 112)); + ASSERT_EQ(1, ThreadsHelper::pickLoop3d(6, 32, 112, 112)); - for (auto e = 0; e < 6; e++) { - auto span = Span3::build(1, e, 6, 0, 32, 1, 0, 112, 1, 0, 112, 1); + for (auto e = 0; e < 6; e++) { + auto span = Span3::build(1, e, 6, 0, 32, 1, 0, 112, 1, 0, 112, 1); - nd4j_printf("Span start: %lld; stop: %lld\n", span.startX(), span.stopX()); - } + nd4j_printf("Span start: %lld; stop: %lld\n", span.startX(), span.stopX()); + } } TEST_F(ThreadsTests, th_test_4) { - // typical conv cases - ASSERT_EQ(2, ThreadsHelper::numberOfThreads2d(2, 32, 3)); - ASSERT_EQ(4, ThreadsHelper::numberOfThreads2d(4, 32, 3)); - ASSERT_EQ(6, ThreadsHelper::numberOfThreads2d(6, 32, 1)); - ASSERT_EQ(8, ThreadsHelper::numberOfThreads2d(8, 16, 64)); + // typical conv cases + ASSERT_EQ(2, ThreadsHelper::numberOfThreads2d(2, 32, 3)); + ASSERT_EQ(4, ThreadsHelper::numberOfThreads2d(4, 32, 3)); + ASSERT_EQ(6, ThreadsHelper::numberOfThreads2d(6, 32, 1)); + ASSERT_EQ(8, ThreadsHelper::numberOfThreads2d(8, 16, 64)); - ASSERT_EQ(1, ThreadsHelper::pickLoop2d(4, 32, 1)); - ASSERT_EQ(1, ThreadsHelper::pickLoop2d(8, 19, 17)); + ASSERT_EQ(1, ThreadsHelper::pickLoop2d(4, 32, 1)); + ASSERT_EQ(1, ThreadsHelper::pickLoop2d(8, 19, 17)); - // primes edge cases - ASSERT_EQ(6, ThreadsHelper::numberOfThreads2d(6, 19, 17)); - ASSERT_EQ(8, ThreadsHelper::numberOfThreads2d(8, 19, 17)); + // primes edge cases + ASSERT_EQ(6, ThreadsHelper::numberOfThreads2d(6, 19, 17)); + ASSERT_EQ(8, ThreadsHelper::numberOfThreads2d(8, 19, 17)); - ASSERT_EQ(1, ThreadsHelper::pickLoop2d(8, 19, 17)); + ASSERT_EQ(1, ThreadsHelper::pickLoop2d(8, 19, 17)); - for (auto e = 0; e < 6; e++) { - auto span = Span2::build(1, e, 6, 0, 19, 1, 0, 17, 1); + for (auto e = 0; e < 6; e++) { + auto span = Span2::build(1, e, 6, 0, 19, 1, 0, 17, 1); - nd4j_printf("Span start: %lld; stop: %lld\n", span.startX(), span.stopX()); - } + nd4j_printf("Span start: %lld; stop: %lld\n", span.startX(), span.stopX()); + } - nd4j_printf("-----------------------\n",""); - for (auto e = 0; e < 6; e++) { - auto span = Span2::build(1, e, 6, 0, 32, 1, 0, 3, 1); + nd4j_printf("-----------------------\n", ""); + for (auto e = 0; e < 6; e++) { + auto span = Span2::build(1, e, 6, 0, 32, 1, 0, 3, 1); - nd4j_printf("Span start: %lld; stop: %lld\n", span.startX(), span.stopX()); - } + nd4j_printf("Span start: %lld; stop: %lld\n", span.startX(), span.stopX()); + } } - TEST_F(ThreadsTests, test_span_converage_1) { - for (int b = 1; b <= 128; b++) { - for (int c = 1; c <= 64; c++) { - for (int t = 1; t <= 64; t++) { - - auto threads = ThreadsHelper::numberOfThreads2d(t, b, c); - auto loop = ThreadsHelper::pickLoop2d(threads, b, c); - - if (t > 1 && threads == 1 && (b > 1 && c > 1)) { - nd4j_printf("Got 1 thread for [%i, %i] loop; initial max threads: %i\n", b, c, t) - } - - auto sum = 0; - for (auto a = 0; a < threads; a++) { - auto span = Span2::build(loop, a,threads, 0, b, 1, 0, c, 1); - - if (loop == 1) - sum += span.stopX() - span.startX(); - else if (loop == 2) - sum += span.stopY() - span.startY(); - else - throw std::runtime_error("Bad loop!"); - } - - if (loop == 1) - ASSERT_EQ(b, sum); - else - ASSERT_EQ(c, sum); - } + for (int b = 1; b <= 128; b++) { + for (int c = 1; c <= 64; c++) { + for (int t = 1; t <= 64; t++) { + auto threads = ThreadsHelper::numberOfThreads2d(t, b, c); + auto loop = ThreadsHelper::pickLoop2d(threads, b, c); + + if (t > 1 && threads == 1 && (b > 1 && c > 1)) { + nd4j_printf( + "Got 1 thread for [%i, %i] loop; initial max threads: %i\n", b, c, + t) + } + + auto sum = 0; + for (auto a = 0; a < threads; a++) { + auto span = Span2::build(loop, a, threads, 0, b, 1, 0, c, 1); + + if (loop == 1) + sum += span.stopX() - span.startX(); + else if (loop == 2) + sum += span.stopY() - span.startY(); + else + throw std::runtime_error("Bad loop!"); } + + if (loop == 1) + ASSERT_EQ(b, sum); + else + ASSERT_EQ(c, sum); + } } + } } TEST_F(ThreadsTests, validation_test_2d_1) { - if (1 > 0) - return; - - std::vector threads({1, 2, 4, 6, 8, 12, 16, 20, 32, 48, 64}); + if (1 > 0) return; - for (int e = 1; e < 1024; e++) { - for (int i = 1; i <= 1024; i++ ) { - for (auto t:threads) { - std::atomic sum; - sum.store(0); + std::vector threads({1, 2, 4, 6, 8, 12, 16, 20, 32, 48, 64}); - auto func = PRAGMA_THREADS_FOR_2D { - for (auto x = start_x; x < stop_x; x += inc_x) { - for (auto y = start_y; y < stop_y; y += inc_y) { - sum++; - } - } - }; + for (int e = 1; e < 1024; e++) { + for (int i = 1; i <= 1024; i++) { + for (auto t : threads) { + std::atomic sum; + sum.store(0); - samediff::Threads::parallel_for(func, 0, e, 1, 0, i, 1, t, true); - - ASSERT_EQ(e * i, sum.load()); + auto func = PRAGMA_THREADS_FOR_2D { + for (auto x = start_x; x < stop_x; x += inc_x) { + for (auto y = start_y; y < stop_y; y += inc_y) { + sum++; } - } + } + }; - nd4j_printf("Finished iteration %i\n", e); + samediff::Threads::parallel_for(func, 0, e, 1, 0, i, 1, t, true); + + ASSERT_EQ(e * i, sum.load()); + } } + + nd4j_printf("Finished iteration %i\n", e); + } } TEST_F(ThreadsTests, reduction_test_1) { + auto func = PRAGMA_REDUCE_LONG { + int64_t sum = 0; - auto func = PRAGMA_REDUCE_LONG { - int64_t sum = 0; - - for (auto e = start; e < stop; e++) { - sum++; - }; - - return sum; + for (auto e = start; e < stop; e++) { + sum++; }; - auto sum = samediff::Threads::parallel_long(func, LAMBDA_AL {return _old + _new;}, 0, 8192, 1, 4); - ASSERT_EQ(8192, sum); + return sum; + }; + + auto sum = samediff::Threads::parallel_long( + func, LAMBDA_AL { return _old + _new; }, 0, 8192, 1, 4); + ASSERT_EQ(8192, sum); } static void _code(int thread_id) { @@ -252,7 +252,9 @@ TEST_F(ThreadsTests, basic_test_1) { auto timeStartThreads = std::chrono::system_clock::now(); samediff::Threads::parallel_for(func, 0, array.lengthOf()); auto timeEndThreads = std::chrono::system_clock::now(); - auto outerTimeThreads = std::chrono::duration_cast (timeEndThreads - timeStartThreads).count(); + auto outerTimeThreads = +std::chrono::duration_cast (timeEndThreads - +timeStartThreads).count(); auto timeStartOmp = std::chrono::system_clock::now(); PRAGMA_OMP_PARALLEL_FOR_SIMD @@ -260,10 +262,12 @@ TEST_F(ThreadsTests, basic_test_1) { lbuffer[e] += 1.0f; } auto timeEndOmp = std::chrono::system_clock::now(); - auto outerTimeOmp = std::chrono::duration_cast (timeEndOmp - timeStartOmp).count(); + auto outerTimeOmp = std::chrono::duration_cast +(timeEndOmp - timeStartOmp).count(); ASSERT_NEAR((float) array.lengthOf(), array.sumNumber().e(0), 1e-5f); - nd4j_printf("Threads time: %lld us; OMP time: %lld us; %p\n", outerTimeThreads, outerTimeOmp, instance) + nd4j_printf("Threads time: %lld us; OMP time: %lld us; %p\n", +outerTimeThreads, outerTimeOmp, instance) } */ \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/TypeCastTests.cpp b/libnd4j/tests_cpu/layers_tests/TypeCastTests.cpp index 2c27f95f9066..9b7dd6cfde10 100644 --- a/libnd4j/tests_cpu/layers_tests/TypeCastTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/TypeCastTests.cpp @@ -18,55 +18,56 @@ // Created by raver119 on 02/07/18. // -#include "testlayers.h" -#include #include +#include + +#include "testlayers.h" using namespace sd; using namespace sd::ops; using namespace sd::graph; class TypeCastTests : public testing::Test { -public: - + public: }; TEST_F(TypeCastTests, Test_Cast_1) { #ifndef __CUDABLAS__ - const int limit = 100; - auto src = new double[limit]; - auto z = new float[limit]; - auto exp = new float[limit]; - - for (int e = 0; e < limit; e++) { - src[e] = static_cast(e); - exp[e] = static_cast(e); - } - - TypeCast::convertGeneric(nullptr, reinterpret_cast(src), limit, reinterpret_cast(z)); - - for (int e = 0; e < limit; e++) { - ASSERT_NEAR(exp[e], z[e], 1e-5f); - } - - delete[] src; - delete[] z; - delete[] exp; + const int limit = 100; + auto src = new double[limit]; + auto z = new float[limit]; + auto exp = new float[limit]; + + for (int e = 0; e < limit; e++) { + src[e] = static_cast(e); + exp[e] = static_cast(e); + } + + TypeCast::convertGeneric(nullptr, + reinterpret_cast(src), limit, + reinterpret_cast(z)); + + for (int e = 0; e < limit; e++) { + ASSERT_NEAR(exp[e], z[e], 1e-5f); + } + + delete[] src; + delete[] z; + delete[] exp; #endif } TEST_F(TypeCastTests, Test_ConvertDtype_1) { +#ifndef __CUDABLAS__ - #ifndef __CUDABLAS__ - - float src[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f}; - float16 dst[5]; - float16 exp[] = {(float16) 1.0f, (float16) 2.0f, (float16) 3.0f, (float16) 4.0f, (float16) 5.0f}; + float src[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f}; + float16 dst[5]; + float16 exp[] = {(float16)1.0f, (float16)2.0f, (float16)3.0f, (float16)4.0f, + (float16)5.0f}; - convertTypes(nullptr, ND4J_FLOAT32, src, 5, ND4J_FLOAT16, dst); + convertTypes(nullptr, ND4J_FLOAT32, src, 5, ND4J_FLOAT16, dst); - for (int e = 0; e < 5; e++) - ASSERT_NEAR(exp[e], dst[e], (float16) 0.01f); + for (int e = 0; e < 5; e++) ASSERT_NEAR(exp[e], dst[e], (float16)0.01f); - #endif +#endif } diff --git a/libnd4j/tests_cpu/layers_tests/VariableProxyTests.cpp b/libnd4j/tests_cpu/layers_tests/VariableProxyTests.cpp index 16e7cf7ac310..fb4c0711bbeb 100644 --- a/libnd4j/tests_cpu/layers_tests/VariableProxyTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/VariableProxyTests.cpp @@ -18,156 +18,185 @@ // @author raver119@gmail.com // -#include "testlayers.h" #include +#include "testlayers.h" + using namespace sd; using namespace sd::graph; class VariableProxyTests : public testing::Test { -public: - + public: }; - TEST_F(VariableProxyTests, Test_Simple_1) { - auto x = NDArrayFactory::create_('c', {2, 2}, {1, 2, 3, 4}); - VariableSpace ref; + auto x = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); + VariableSpace ref; - ref.putVariable(119, x); + ref.putVariable(119, x); - ASSERT_TRUE(ref.hasVariable(119)); + ASSERT_TRUE(ref.hasVariable(119)); - VariableProxy proxy(&ref); + VariableProxy proxy(&ref); - ASSERT_TRUE(proxy.hasVariable(119)); + ASSERT_TRUE(proxy.hasVariable(119)); } - TEST_F(VariableProxyTests, Test_Simple_2) { - auto x = NDArrayFactory::create_('c', {2, 2}, {1, 2, 3, 4}); - VariableSpace ref; + auto x = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); + VariableSpace ref; - ASSERT_FALSE(ref.hasVariable(119)); + ASSERT_FALSE(ref.hasVariable(119)); - VariableProxy proxy(&ref); + VariableProxy proxy(&ref); - ASSERT_FALSE(proxy.hasVariable(119)); + ASSERT_FALSE(proxy.hasVariable(119)); - proxy.putVariable(119, x); + proxy.putVariable(119, x); - ASSERT_FALSE(ref.hasVariable(119)); + ASSERT_FALSE(ref.hasVariable(119)); - ASSERT_TRUE(proxy.hasVariable(119)); + ASSERT_TRUE(proxy.hasVariable(119)); } - TEST_F(VariableProxyTests, Test_Simple_3) { - auto x = NDArrayFactory::create_('c', {2, 2}, {1, 2, 3, 4}); - auto y = NDArrayFactory::create_('c', {2, 2}, {4, 2, 3, 1}); - VariableSpace ref; + auto x = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); + auto y = NDArrayFactory::create('c', {2, 2}, {4, 2, 3, 1}); + VariableSpace ref; - ref.putVariable(119, x); + ref.putVariable(119, x); - ASSERT_TRUE(ref.hasVariable(119)); + ASSERT_TRUE(ref.hasVariable(119)); - VariableProxy proxy(&ref); + VariableProxy proxy(&ref); - ASSERT_TRUE(proxy.hasVariable(119)); + ASSERT_TRUE(proxy.hasVariable(119)); - proxy.putVariable(119, y); + proxy.putVariable(119, y); - ASSERT_TRUE(ref.hasVariable(119)); + ASSERT_TRUE(ref.hasVariable(119)); - ASSERT_TRUE(proxy.hasVariable(119)); + ASSERT_TRUE(proxy.hasVariable(119)); - auto z0 = ref.getVariable(119)->getNDArray(); - auto z1 = proxy.getVariable(119)->getNDArray(); + auto z0 = ref.getVariable(119)->getNDArray(); + auto z1 = proxy.getVariable(119)->getNDArray(); - ASSERT_FALSE(z0 == z1); - ASSERT_TRUE(y == z1); - ASSERT_TRUE(x == z0); + ASSERT_FALSE(z0 == z1); + ASSERT_TRUE(y == *z1); + ASSERT_TRUE(x == *z0); } TEST_F(VariableProxyTests, Test_Simple_4) { - auto x = NDArrayFactory::create_('c', {2, 2}, {1, 2, 3, 4}); - auto y = NDArrayFactory::create_('c', {2, 2}, {4, 2, 3, 1}); - auto z = NDArrayFactory::create_('c', {2, 2}, {4, 1, 3, 2}); - VariableSpace ref; + auto x = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); + auto y = NDArrayFactory::create('c', {2, 2}, {4, 2, 3, 1}); + auto z = NDArrayFactory::create('c', {2, 2}, {4, 1, 3, 2}); + VariableSpace ref; - ref.putVariable(119, x); - ref.putVariable(118, z); + ref.putVariable(119, x); + ref.putVariable(118, z); - ASSERT_TRUE(ref.hasVariable(119)); + ASSERT_TRUE(ref.hasVariable(119)); - VariableProxy proxy(&ref); + VariableProxy proxy(&ref); - ASSERT_TRUE(proxy.hasVariable(119)); + ASSERT_TRUE(proxy.hasVariable(119)); - proxy.putVariable(119, y); + proxy.putVariable(119, y); - ASSERT_TRUE(ref.hasVariable(119)); - ASSERT_TRUE(ref.hasVariable(118)); + ASSERT_TRUE(ref.hasVariable(119)); + ASSERT_TRUE(ref.hasVariable(118)); - ASSERT_TRUE(proxy.hasVariable(119)); - ASSERT_TRUE(proxy.hasVariable(118)); + ASSERT_TRUE(proxy.hasVariable(119)); + ASSERT_TRUE(proxy.hasVariable(118)); - auto z0 = ref.getVariable(119)->getNDArray(); - auto z1 = proxy.getVariable(119)->getNDArray(); + auto z0 = ref.getVariable(119)->getNDArray(); + auto z1 = proxy.getVariable(119)->getNDArray(); - ASSERT_FALSE(z0 == z1); - ASSERT_TRUE(y == z1); - ASSERT_TRUE(x == z0); + ASSERT_FALSE(z0 == z1); + ASSERT_TRUE(y == *z1); + ASSERT_TRUE(x == *z0); } - TEST_F(VariableProxyTests, Test_Cast_1) { - auto x = NDArrayFactory::create_('c', {2, 2}, {1, 2, 3, 4}); - auto y = NDArrayFactory::create_('c', {2, 2}, {4, 2, 3, 1}); - VariableSpace ref; + auto x = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); + auto y = NDArrayFactory::create('c', {2, 2}, {4, 2, 3, 1}); + VariableSpace ref; - ref.putVariable(-119, x); + ref.putVariable(-119, x); - ASSERT_TRUE(ref.hasVariable(-119)); + ASSERT_TRUE(ref.hasVariable(-119)); - VariableProxy proxy(&ref); - auto cast = (VariableSpace *) &proxy; + VariableProxy proxy(&ref); + auto cast = (VariableSpace *)&proxy; - ASSERT_TRUE(cast->hasVariable(-119)); + ASSERT_TRUE(cast->hasVariable(-119)); - cast->putVariable(-119, y); + cast->putVariable(-119, y); - ASSERT_TRUE(ref.hasVariable(-119)); + ASSERT_TRUE(ref.hasVariable(-119)); - ASSERT_TRUE(cast->hasVariable(-119)); + ASSERT_TRUE(cast->hasVariable(-119)); - auto z0 = ref.getVariable(-119)->getNDArray(); - auto z1 = cast->getVariable(-119)->getNDArray(); + auto z0 = ref.getVariable(-119)->getNDArray(); + auto z1 = cast->getVariable(-119)->getNDArray(); - ASSERT_FALSE(z0 == z1); - ASSERT_TRUE(y == z1); - ASSERT_TRUE(x == z0); + ASSERT_FALSE(z0 == z1); + ASSERT_TRUE(y == *z1); + ASSERT_TRUE(x == *z0); } +TEST_F(VariableProxyTests, test_update_1) { + VariableSpace ref; + + auto x = NDArrayFactory::create(1); + auto A = NDArrayFactory::create(2); + auto B = NDArrayFactory::create(3); + + // set initial states for all 3 VariableSpaces/Proxies + ref.putVariable(2, x); + + VariableProxy proxyA(&ref); + proxyA.putVariable(2, A); -TEST_F(VariableProxyTests, Test_Clone_1) { - auto x = NDArrayFactory::create_('c', {2, 2}, {1, 2, 3, 4}); - auto y = NDArrayFactory::create_('c', {2, 2}, {4, 2, 3, 1}); - VariableSpace ref; + VariableProxy proxyB(&proxyA); + proxyB.putVariable(2, B); + + // check out initial state + ASSERT_EQ(x, *ref.getVariable(2)->getNDArray().get()); + ASSERT_EQ(A, *proxyA.getVariable(2)->getNDArray().get()); + ASSERT_EQ(B, *proxyB.getVariable(2)->getNDArray().get()); + + // update state now and check result + proxyB.pushTo(proxyA); + ASSERT_EQ(x, *ref.getVariable(2)->getNDArray().get()); + ASSERT_EQ(B, *proxyA.getVariable(2)->getNDArray().get()); + ASSERT_EQ(B, *proxyB.getVariable(2)->getNDArray().get()); +} - ref.putVariable(118, x); +TEST_F(VariableProxyTests, test_update_2) { + VariableSpace ref; - VariableProxy proxy(&ref); + auto x = NDArrayFactory::create(1); + auto A = NDArrayFactory::create(2); + auto B = NDArrayFactory::create(3); - proxy.putVariable(119, y); + // set initial states for all 3 VariableSpaces/Proxies + ref.putVariable(2, x); - ASSERT_TRUE(proxy.hasVariable(118)); - ASSERT_TRUE(proxy.hasVariable(119)); + VariableProxy proxyA(&ref); + proxyA.putVariable(2, A); - auto clone = proxy.clone(); + VariableProxy proxyB(&proxyA); + proxyB.putVariable(2, B); - ASSERT_TRUE(clone->hasVariable(118)); - ASSERT_TRUE(clone->hasVariable(119)); + // check out initial state + ASSERT_EQ(x, *ref.getVariable(2)->getNDArray().get()); + ASSERT_EQ(A, *proxyA.getVariable(2)->getNDArray().get()); + ASSERT_EQ(B, *proxyB.getVariable(2)->getNDArray().get()); - delete clone; + // update state now and check result + proxyB.pullFrom(proxyA); + ASSERT_EQ(x, *ref.getVariable(2)->getNDArray().get()); + ASSERT_EQ(A, *proxyA.getVariable(2)->getNDArray().get()); + ASSERT_EQ(A, *proxyB.getVariable(2)->getNDArray().get()); } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/VariableSpaceTests.cpp b/libnd4j/tests_cpu/layers_tests/VariableSpaceTests.cpp index ec10f3db097b..bd9df549124a 100644 --- a/libnd4j/tests_cpu/layers_tests/VariableSpaceTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/VariableSpaceTests.cpp @@ -18,203 +18,102 @@ // @author raver119@gmail.com // -#include "testlayers.h" +#include #include -#include -#include -#include #include -#include +#include +#include +#include + +#include "testlayers.h" using namespace sd; using namespace sd::graph; class VariableSpaceTest : public testing::Test { -public: - int *cShape = new int[8]{2, 2, 2, 2, 1, 0, 1, 99}; - int *fShape = new int[8]{2, 2, 2, 1, 2, 0, 1, 102}; - - - ~VariableSpaceTest() { - delete[] cShape; - delete[] fShape; - } + public: + int *cShape = new int[8]{2, 2, 2, 2, 1, 0, 1, 99}; + int *fShape = new int[8]{2, 2, 2, 1, 2, 0, 1, 102}; + + ~VariableSpaceTest() { + delete[] cShape; + delete[] fShape; + } }; - TEST_F(VariableSpaceTest, SettersGettersTest1) { - auto space1 = new VariableSpace(); - auto arrayA = NDArrayFactory::create_('c', {5, 5}); - auto arrayB = NDArrayFactory::create_('c', {3, 3}); + auto space1 = new VariableSpace(); + auto arrayA = NDArrayFactory::create('c', {5, 5}); + auto arrayB = NDArrayFactory::create('c', {3, 3}); - space1->putVariable(1, arrayA); - space1->putVariable(2, arrayB); + space1->putVariable(1, arrayA); + space1->putVariable(2, arrayB); - auto arrayRA = space1->getVariable(1); - auto arrayRB = space1->getVariable(2); + auto arrayRA = space1->getVariable(1); + auto arrayRB = space1->getVariable(2); - ASSERT_TRUE(arrayA == arrayRA->getNDArray()); - ASSERT_TRUE(arrayB == arrayRB->getNDArray()); + ASSERT_TRUE(arrayA == *arrayRA->getNDArray()); + ASSERT_TRUE(arrayB == *arrayRB->getNDArray()); - // we should survive this call - delete space1; + // we should survive this call + delete space1; } - TEST_F(VariableSpaceTest, SettersGettersTest2) { - auto space1 = new VariableSpace(); - auto arrayA = NDArrayFactory::create_('c', {5, 5}); - auto arrayB = NDArrayFactory::create_('c', {3, 3}); + auto space1 = new VariableSpace(); + auto arrayA = NDArrayFactory::create('c', {5, 5}); + auto arrayB = NDArrayFactory::create('c', {3, 3}); - auto varA = new Variable(arrayA); - auto varB = new Variable(arrayB); + space1->putVariable(-1, arrayA); + space1->putVariable(2, arrayB); - varA->markExternal(true); + Nd4jLong expExternal = (25 * 4) + (8 * 8); + Nd4jLong expInternal = (9 * 4) + (8 * 8); - space1->putVariable(-1, varA); - space1->putVariable(2, varB); + ASSERT_EQ(expExternal, space1->externalMemory()); + ASSERT_EQ(expInternal, space1->internalMemory()); - Nd4jLong expExternal = (25 * 4) + (8 * 8); - Nd4jLong expInternal = (9 * 4) + (8 * 8); - - ASSERT_EQ(expExternal, space1->externalMemory()); - ASSERT_EQ(expInternal, space1->internalMemory()); - - delete space1; + delete space1; } TEST_F(VariableSpaceTest, EqualityTest1) { - VariableSpace space; + VariableSpace space; - std::string name("myvar"); + std::string name("myvar"); - auto arrayA = NDArrayFactory::create_('c', {3, 3}); - auto variableA = new Variable(arrayA, name.c_str()); + auto arrayA = NDArrayFactory::create('c', {3, 3}); + auto variableA = std::make_shared(arrayA, name, 1); - space.putVariable(1, variableA); + space.putVariable(1, variableA); - std::pair pair(1,0); + std::pair pair(1, 0); - ASSERT_TRUE(space.hasVariable(1)); - ASSERT_TRUE(space.hasVariable(pair)); - ASSERT_TRUE(space.hasVariable(&name)); + ASSERT_TRUE(space.hasVariable(1)); + ASSERT_TRUE(space.hasVariable(pair)); + ASSERT_TRUE(space.hasVariable(name)); - auto rV1 = space.getVariable(1); - auto rV2 = space.getVariable(pair); - auto rV3 = space.getVariable(&name); + auto rV1 = space.getVariable(1); + auto rV2 = space.getVariable(pair); + auto rV3 = space.getVariable(name); - ASSERT_TRUE(rV1 == rV2); - ASSERT_TRUE(rV2 == rV3); + ASSERT_TRUE(rV1 == rV2); + ASSERT_TRUE(rV2 == rV3); } TEST_F(VariableSpaceTest, EqualityTest2) { - VariableSpace space; - - auto arrayA = NDArrayFactory::create_('c', {3, 3}); - - space.putVariable(1, arrayA); - - std::pair pair(1,0); - - ASSERT_TRUE(space.hasVariable(1)); - ASSERT_TRUE(space.hasVariable(pair)); - - auto rV1 = space.getVariable(1); - auto rV2 = space.getVariable(pair); - - ASSERT_TRUE(rV1 == rV2); -} - -TEST_F(VariableSpaceTest, CloneTests_1) { - VariableSpace spaceA; - - auto arrayA = NDArrayFactory::create_('c', {3, 3}); - arrayA->assign(1.0); - - spaceA.putVariable(1, arrayA); - - auto spaceB = spaceA.clone(); - - std::pair pair(1,0); - - ASSERT_TRUE(spaceB->hasVariable(1)); - ASSERT_TRUE(spaceB->hasVariable(pair)); - - auto arrayB = spaceB->getVariable(1)->getNDArray(); - - ASSERT_TRUE(arrayA->equalsTo(arrayB)); - - arrayB->assign(2.0); - - ASSERT_FALSE(arrayA->equalsTo(arrayB)); - - delete spaceB; -} - -TEST_F(VariableSpaceTest, CloneTests_2) { - VariableSpace spaceA; - - auto arrayA = NDArrayFactory::create_('c', {3, 3}); - arrayA->assign(1.0); - - auto variableA = new Variable(arrayA, "alpha"); - - std::string str("alpha"); - std::pair pair(2, 3); - - spaceA.putVariable(pair, variableA); - - ASSERT_TRUE(spaceA.hasVariable(&str)); - ASSERT_TRUE(spaceA.hasVariable(pair)); - - auto spaceB = spaceA.clone(); - - ASSERT_FALSE(spaceB->hasVariable(1)); - ASSERT_FALSE(spaceB->hasVariable(2)); - ASSERT_TRUE(spaceB->hasVariable(pair)); - ASSERT_TRUE(spaceB->hasVariable(&str)); - - auto arrayB = spaceB->getVariable(pair)->getNDArray(); - - ASSERT_TRUE(arrayA->equalsTo(arrayB)); - - arrayB->assign(2.0); - - ASSERT_FALSE(arrayA->equalsTo(arrayB)); - - delete spaceB; - - ASSERT_TRUE(spaceA.hasVariable(&str)); - ASSERT_TRUE(spaceA.hasVariable(pair)); -} - - -TEST_F(VariableSpaceTest, Test_DType_Conversion_1) { - /* - VariableSpace spaceA; - - auto arrayA = NDArrayFactory::create_('c', {3, 3}); - arrayA->assign(1.0); - - auto variableA = new Variable(arrayA, "alpha"); - - std::string str("alpha"); - std::pair pair(2, 3); - - spaceA.putVariable(pair, variableA); + VariableSpace space; + auto arrayA = NDArrayFactory::create('c', {3, 3}); - auto sd = spaceA.template asT(); - auto sf = sd->template asT(); + space.putVariable(1, arrayA); - ASSERT_TRUE(sf->hasVariable(pair)); + std::pair pair(1, 0); - auto xf = sf->getVariable(pair)->getNDArray(); + ASSERT_TRUE(space.hasVariable(1)); + ASSERT_TRUE(space.hasVariable(pair)); - ASSERT_TRUE(arrayA->isSameShape(xf)); - ASSERT_TRUE(arrayA->equalsTo(xf)); + auto rV1 = space.getVariable(1); + auto rV2 = space.getVariable(pair); - delete sd; - delete sf; - */ + ASSERT_TRUE(rV1 == rV2); } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/VariableTests.cpp b/libnd4j/tests_cpu/layers_tests/VariableTests.cpp index 49b9b02d6dc3..99342f5c9e6b 100644 --- a/libnd4j/tests_cpu/layers_tests/VariableTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/VariableTests.cpp @@ -21,151 +21,128 @@ #ifndef LIBND4J_VARIABLETESTS_H #define LIBND4J_VARIABLETESTS_H -#include "testlayers.h" #include -#include #include +#include + +#include "testlayers.h" using namespace sd; using namespace sd::graph; class VariableTests : public testing::Test { -public: - + public: }; -TEST_F(VariableTests, TestClone_1) { - auto array1 = NDArrayFactory::create_('c', {5, 5}); - array1->assign(1.0); - - auto var1 = new Variable(array1, "alpha"); - var1->setId(119); - - - auto var2 = var1->clone(); - - ASSERT_FALSE(var1->getNDArray() == var2->getNDArray()); - auto array2 = var2->getNDArray(); - - ASSERT_TRUE(array1->equalsTo(array2)); - ASSERT_EQ(var1->id(), var2->id()); - ASSERT_EQ(*var1->getName(), *var2->getName()); - - delete var1; - - std::string str("alpha"); - ASSERT_EQ(*var2->getName(), str); - array2->assign(2.0); - - ASSERT_NEAR(2.0, array2->meanNumber().e(0), 1e-5); - - delete var2; -} - TEST_F(VariableTests, Test_FlatVariableDataType_1) { - flatbuffers::FlatBufferBuilder builder(1024); - auto original = NDArrayFactory::create('c', {5, 10}); - original.linspace(1); + flatbuffers::FlatBufferBuilder builder(1024); + auto original = NDArrayFactory::create('c', {5, 10}); + original.linspace(1); - auto vec = original.asByteVector(); + auto vec = original.asByteVector(); - auto fShape = builder.CreateVector(original.getShapeInfoAsFlatVector()); - auto fBuffer = builder.CreateVector(vec); - auto fVid = CreateIntPair(builder, 1, 12); + auto fShape = builder.CreateVector(original.getShapeInfoAsFlatVector()); + auto fBuffer = builder.CreateVector(vec); + auto fVid = CreateIntPair(builder, 1, 12); - auto fArray = CreateFlatArray(builder, fShape, fBuffer, sd::graph::DType::DType_FLOAT); + auto fArray = + CreateFlatArray(builder, fShape, fBuffer, sd::graph::DType::DType_FLOAT); - auto flatVar = CreateFlatVariable(builder, fVid, 0, sd::graph::DType::DType_FLOAT, 0, fArray); + auto flatVar = CreateFlatVariable(builder, fVid, 0, + sd::graph::DType::DType_FLOAT, 0, fArray); - builder.Finish(flatVar); + builder.Finish(flatVar); - auto ptr = builder.GetBufferPointer(); + auto ptr = builder.GetBufferPointer(); - auto restoredVar = GetFlatVariable(ptr); + auto restoredVar = GetFlatVariable(ptr); - auto rv = new Variable(restoredVar); + auto rv = new Variable(restoredVar); - ASSERT_EQ(1, rv->id()); - ASSERT_EQ(12, rv->index()); + ASSERT_EQ(1, rv->id()); + ASSERT_EQ(12, rv->index()); - auto restoredArray = rv->getNDArray(); + auto restoredArray = rv->getNDArray(); - ASSERT_TRUE(original.isSameShape(restoredArray)); - ASSERT_TRUE(original.equalsTo(restoredArray)); + ASSERT_TRUE(original.isSameShape(*restoredArray)); + ASSERT_TRUE(original.equalsTo(*restoredArray)); - delete rv; + delete rv; } TEST_F(VariableTests, Test_FlatVariableDataType_2) { - flatbuffers::FlatBufferBuilder builder(1024); - auto original = NDArrayFactory::create('c', {5, 10}); - original.linspace(1); + flatbuffers::FlatBufferBuilder builder(1024); + auto original = NDArrayFactory::create('c', {5, 10}); + original.linspace(1); - auto vec = original.asByteVector(); + auto vec = original.asByteVector(); - auto fShape = builder.CreateVector(original.getShapeInfoAsFlatVector()); - auto fBuffer = builder.CreateVector(vec); - auto fVid = CreateIntPair(builder, 1, 12); + auto fShape = builder.CreateVector(original.getShapeInfoAsFlatVector()); + auto fBuffer = builder.CreateVector(vec); + auto fVid = CreateIntPair(builder, 1, 12); - auto fArray = CreateFlatArray(builder, fShape, fBuffer, sd::graph::DType::DType_DOUBLE); + auto fArray = + CreateFlatArray(builder, fShape, fBuffer, sd::graph::DType::DType_DOUBLE); - auto flatVar = CreateFlatVariable(builder, fVid, 0, sd::graph::DType::DType_DOUBLE, 0, fArray); + auto flatVar = CreateFlatVariable(builder, fVid, 0, + sd::graph::DType::DType_DOUBLE, 0, fArray); - builder.Finish(flatVar); + builder.Finish(flatVar); - auto ptr = builder.GetBufferPointer(); + auto ptr = builder.GetBufferPointer(); - auto restoredVar = GetFlatVariable(ptr); + auto restoredVar = GetFlatVariable(ptr); - auto rv = new Variable(restoredVar); + auto rv = new Variable(restoredVar); - ASSERT_EQ(1, rv->id()); - ASSERT_EQ(12, rv->index()); + ASSERT_EQ(1, rv->id()); + ASSERT_EQ(12, rv->index()); - auto restoredArray = rv->getNDArray(); + auto restoredArray = rv->getNDArray(); - ASSERT_TRUE(original.isSameShape(restoredArray)); - ASSERT_TRUE(original.equalsTo(restoredArray)); + ASSERT_TRUE(original.isSameShape(*restoredArray)); + ASSERT_TRUE(original.equalsTo(*restoredArray)); - delete rv; + delete rv; } - TEST_F(VariableTests, Test_FlatVariableDataType_3) { - flatbuffers::FlatBufferBuilder builder(1024); - auto original = NDArrayFactory::create('c', {5, 10}); - auto floating = NDArrayFactory::create('c', {5, 10}); - original.linspace(1); - floating.linspace(1); + flatbuffers::FlatBufferBuilder builder(1024); + auto original = NDArrayFactory::create('c', {5, 10}); + auto floating = NDArrayFactory::create('c', {5, 10}); + original.linspace(1); + floating.linspace(1); - auto vec = original.asByteVector(); + auto vec = original.asByteVector(); - auto fShape = builder.CreateVector(original.getShapeInfoAsFlatVector()); - auto fBuffer = builder.CreateVector(vec); - auto fVid = CreateIntPair(builder, 1, 12); + auto fShape = builder.CreateVector(original.getShapeInfoAsFlatVector()); + auto fBuffer = builder.CreateVector(vec); + auto fVid = CreateIntPair(builder, 1, 12); - auto fArray = CreateFlatArray(builder, fShape, fBuffer, sd::graph::DType::DType_DOUBLE); + auto fArray = + CreateFlatArray(builder, fShape, fBuffer, sd::graph::DType::DType_DOUBLE); - auto flatVar = CreateFlatVariable(builder, fVid, 0, sd::graph::DType::DType_DOUBLE, 0, fArray); + auto flatVar = CreateFlatVariable(builder, fVid, 0, + sd::graph::DType::DType_DOUBLE, 0, fArray); - builder.Finish(flatVar); + builder.Finish(flatVar); - auto ptr = builder.GetBufferPointer(); + auto ptr = builder.GetBufferPointer(); - auto restoredVar = GetFlatVariable(ptr); + auto restoredVar = GetFlatVariable(ptr); - auto rv = new Variable(restoredVar); + auto rv = new Variable(restoredVar); - ASSERT_EQ(1, rv->id()); - ASSERT_EQ(12, rv->index()); + ASSERT_EQ(1, rv->id()); + ASSERT_EQ(12, rv->index()); - auto restoredArray = rv->getNDArray(); - auto conv = restoredArray->asT(); + auto restoredArray = rv->getNDArray(); + auto conv = restoredArray->asT(); - ASSERT_TRUE(floating.isSameShape(restoredArray)); - ASSERT_TRUE(floating.equalsTo(conv)); + ASSERT_TRUE(floating.isSameShape(*restoredArray)); + ASSERT_TRUE(floating.equalsTo(conv)); - delete rv; + delete rv; } /* @@ -179,7 +156,8 @@ TEST_F(VariableTests, Test_FlatVariableDataType_4) { auto fShape = builder.CreateVector(original.getShapeAsFlatVector()); auto fVid = CreateIntPair(builder, 37, 12); - auto flatVar = CreateFlatVariable(builder, fVid, 0, sd::graph::DType::DType_FLOAT, fShape, 0, 0, VarType_PLACEHOLDER); + auto flatVar = CreateFlatVariable(builder, fVid, 0, +sd::graph::DType::DType_FLOAT, fShape, 0, 0, VarType_PLACEHOLDER); builder.Finish(flatVar); @@ -202,24 +180,4 @@ TEST_F(VariableTests, Test_FlatVariableDataType_4) { delete rv; } */ -TEST_F(VariableTests, Test_Dtype_Conversion_1) { - auto x = NDArrayFactory::create_('c', {2, 3}, {1, 2, 3, 4, 5, 6}); - Variable v(x, "alpha", 12, 3); - - auto vd = v.template asT(); - auto vf = vd->template asT(); - - ASSERT_EQ(*v.getName(), *vf->getName()); - ASSERT_EQ(v.id(), vf->id()); - ASSERT_EQ(v.index(), vf->index()); - - auto xf = vf->getNDArray(); - - ASSERT_TRUE(x->isSameShape(xf)); - ASSERT_TRUE(x->equalsTo(xf)); - - delete vd; - delete vf; -} - -#endif //LIBND4J_VARIABLETESTS_H +#endif // LIBND4J_VARIABLETESTS_H diff --git a/libnd4j/tests_cpu/layers_tests/WorkspaceTests.cpp b/libnd4j/tests_cpu/layers_tests/WorkspaceTests.cpp index b291e5fbb20e..5b0e13a7c356 100644 --- a/libnd4j/tests_cpu/layers_tests/WorkspaceTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/WorkspaceTests.cpp @@ -21,260 +21,248 @@ #ifndef LIBND4J_WORKSPACETESTS_H #define LIBND4J_WORKSPACETESTS_H -#include "testlayers.h" #include -#include -#include #include +#include +#include + +#include "testlayers.h" using namespace sd; using namespace sd::memory; -class WorkspaceTests : public testing::Test { - -}; - +class WorkspaceTests : public testing::Test {}; TEST_F(WorkspaceTests, BasicInitialization1) { - Workspace workspace(1024); + Workspace workspace(1024); - ASSERT_EQ(1024, workspace.getCurrentSize()); - ASSERT_EQ(0, workspace.getCurrentOffset()); + ASSERT_EQ(1024, workspace.getCurrentSize()); + ASSERT_EQ(0, workspace.getCurrentOffset()); } TEST_F(WorkspaceTests, BasicInitialization2) { - Workspace workspace(65536); + Workspace workspace(65536); - ASSERT_EQ(0, workspace.getCurrentOffset()); - LaunchContext ctx; - ctx.setWorkspace(&workspace); - auto array = NDArrayFactory::create('c', {5, 5}, &ctx); + ASSERT_EQ(0, workspace.getCurrentOffset()); + LaunchContext ctx; + ctx.setWorkspace(&workspace); + auto array = NDArrayFactory::create('c', {5, 5}, &ctx); - array.p(0, 1.0f); - array.p(5, 1.0f); + array.p(0, 1.0f); + array.p(5, 1.0f); - auto v = array.reduceNumber(reduce::Sum); - auto f = v.e(0); + auto v = array.reduceNumber(reduce::Sum); + auto f = v.e(0); - ASSERT_NEAR(2.0f, f, 1e-5); + ASSERT_NEAR(2.0f, f, 1e-5); - ASSERT_TRUE(workspace.getCurrentOffset() > 0); + ASSERT_TRUE(workspace.getCurrentOffset() > 0); } - TEST_F(WorkspaceTests, BasicInitialization3) { - Workspace workspace; + Workspace workspace; - ASSERT_EQ(0, workspace.getCurrentOffset()); - LaunchContext ctx; - ctx.setWorkspace(&workspace); + ASSERT_EQ(0, workspace.getCurrentOffset()); + LaunchContext ctx; + ctx.setWorkspace(&workspace); - auto array = NDArrayFactory::create('c', {5, 5}, &ctx); + auto array = NDArrayFactory::create('c', {5, 5}, &ctx); - array.p(0, 1.0f); - array.p(5, 1.0f); + array.p(0, 1.0f); + array.p(5, 1.0f); - auto v = array.reduceNumber(reduce::Sum); - auto f = v.e(0); + auto v = array.reduceNumber(reduce::Sum); + auto f = v.e(0); - ASSERT_NEAR(2.0f, array.reduceNumber(reduce::Sum).e(0), 1e-5); + ASSERT_NEAR(2.0f, array.reduceNumber(reduce::Sum).e(0), 1e-5); - ASSERT_TRUE(workspace.getCurrentOffset() == 0); + ASSERT_TRUE(workspace.getCurrentOffset() == 0); } - TEST_F(WorkspaceTests, ResetTest1) { - Workspace workspace(65536); - LaunchContext ctx; - ctx.setWorkspace(&workspace); + Workspace workspace(65536); + LaunchContext ctx; + ctx.setWorkspace(&workspace); - auto array = NDArrayFactory::create('c', {5, 5}, &ctx); - array.p(0, 1.0f); - array.p(5, 1.0f); + auto array = NDArrayFactory::create('c', {5, 5}, &ctx); + array.p(0, 1.0f); + array.p(5, 1.0f); - workspace.scopeOut(); - for (int e = 0; e < 5; e++) { - workspace.scopeIn(); + workspace.scopeOut(); + for (int e = 0; e < 5; e++) { + workspace.scopeIn(); - auto array2 = NDArrayFactory::create('c', {5, 5}, &ctx); - array2.p(0, 1.0f); - array2.p(5, 1.0f); + auto array2 = NDArrayFactory::create('c', {5, 5}, &ctx); + array2.p(0, 1.0f); + array2.p(5, 1.0f); - ASSERT_NEAR(2.0f, array2.reduceNumber(reduce::Sum).e(0), 1e-5); + ASSERT_NEAR(2.0f, array2.reduceNumber(reduce::Sum).e(0), 1e-5); - workspace.scopeOut(); - } + workspace.scopeOut(); + } - ASSERT_EQ(65536, workspace.getCurrentSize()); - ASSERT_EQ(0, workspace.getCurrentOffset()); - ASSERT_EQ(0, workspace.getSpilledSize()); + ASSERT_EQ(65536, workspace.getCurrentSize()); + ASSERT_EQ(0, workspace.getCurrentOffset()); + ASSERT_EQ(0, workspace.getSpilledSize()); } - TEST_F(WorkspaceTests, StretchTest1) { - if (!Environment::getInstance().isCPU()) - return; + if (!Environment::getInstance().isCPU()) return; - Workspace workspace(128); - void* ptr = workspace.allocateBytes(8); - workspace.scopeOut(); - ASSERT_EQ(0, workspace.getSpilledSize()); - ASSERT_EQ(0, workspace.getSpilledSecondarySize()); - ASSERT_EQ(0, workspace.getCurrentOffset()); - ASSERT_EQ(0, workspace.getCurrentSecondaryOffset()); - - - workspace.scopeIn(); - for (int e = 0; e < 10; e++) { - - workspace.allocateBytes(128); - - } - ASSERT_EQ(128 * 9, workspace.getSpilledSize()); - workspace.scopeOut(); - workspace.scopeIn(); + Workspace workspace(128); + void* ptr = workspace.allocateBytes(8); + workspace.scopeOut(); + ASSERT_EQ(0, workspace.getSpilledSize()); + ASSERT_EQ(0, workspace.getSpilledSecondarySize()); + ASSERT_EQ(0, workspace.getCurrentOffset()); + ASSERT_EQ(0, workspace.getCurrentSecondaryOffset()); - ASSERT_EQ(0, workspace.getCurrentOffset()); + workspace.scopeIn(); + for (int e = 0; e < 10; e++) { + workspace.allocateBytes(128); + } + ASSERT_EQ(128 * 9, workspace.getSpilledSize()); + workspace.scopeOut(); + workspace.scopeIn(); - // we should have absolutely different pointer here, due to reallocation - void* ptr2 = workspace.allocateBytes(8); + ASSERT_EQ(0, workspace.getCurrentOffset()); - //ASSERT_FALSE(ptr == ptr2); + // we should have absolutely different pointer here, due to reallocation + void* ptr2 = workspace.allocateBytes(8); + // ASSERT_FALSE(ptr == ptr2); - ASSERT_EQ(1280, workspace.getCurrentSize()); - ASSERT_EQ(0, workspace.getSpilledSize()); + ASSERT_EQ(1280, workspace.getCurrentSize()); + ASSERT_EQ(0, workspace.getSpilledSize()); } TEST_F(WorkspaceTests, NewInWorkspaceTest1) { - if (!Environment::getInstance().isCPU()) - return; + if (!Environment::getInstance().isCPU()) return; - Workspace ws(65536); + Workspace ws(65536); - ASSERT_EQ(65536, ws.getCurrentSize()); - ASSERT_EQ(0, ws.getCurrentOffset()); + ASSERT_EQ(65536, ws.getCurrentSize()); + ASSERT_EQ(0, ws.getCurrentOffset()); - ASSERT_FALSE(MemoryRegistrator::getInstance().hasWorkspaceAttached()); + ASSERT_FALSE(MemoryRegistrator::getInstance().hasWorkspaceAttached()); - MemoryRegistrator::getInstance().attachWorkspace(&ws); + MemoryRegistrator::getInstance().attachWorkspace(&ws); - ASSERT_TRUE(MemoryRegistrator::getInstance().hasWorkspaceAttached()); + ASSERT_TRUE(MemoryRegistrator::getInstance().hasWorkspaceAttached()); - auto ast = NDArrayFactory::create_('c', {5, 5}); + auto ast = NDArrayFactory::create_('c', {5, 5}); - ASSERT_TRUE(ws.getCurrentOffset() > 0); + ASSERT_TRUE(ws.getCurrentOffset() > 0); - delete ast; + delete ast; - MemoryRegistrator::getInstance().forgetWorkspace(); + MemoryRegistrator::getInstance().forgetWorkspace(); - ASSERT_FALSE(MemoryRegistrator::getInstance().hasWorkspaceAttached()); - ASSERT_TRUE(MemoryRegistrator::getInstance().getWorkspace() == nullptr); + ASSERT_FALSE(MemoryRegistrator::getInstance().hasWorkspaceAttached()); + ASSERT_TRUE(MemoryRegistrator::getInstance().getWorkspace() == nullptr); } - TEST_F(WorkspaceTests, NewInWorkspaceTest2) { - Workspace ws(65536); - LaunchContext ctx; - ctx.setWorkspace(&ws); + Workspace ws(65536); + LaunchContext ctx; + ctx.setWorkspace(&ws); - ASSERT_EQ(65536, ws.getCurrentSize()); - ASSERT_EQ(0, ws.getCurrentOffset()); + ASSERT_EQ(65536, ws.getCurrentSize()); + ASSERT_EQ(0, ws.getCurrentOffset()); - MemoryRegistrator::getInstance().attachWorkspace(&ws); + MemoryRegistrator::getInstance().attachWorkspace(&ws); - auto ast = NDArrayFactory::create_('c', {5, 5}, &ctx); + auto ast = NDArrayFactory::create_('c', {5, 5}, &ctx); - ASSERT_TRUE(ws.getCurrentOffset() > 0); + ASSERT_TRUE(ws.getCurrentOffset() > 0); - delete ast; + delete ast; - MemoryRegistrator::getInstance().forgetWorkspace(); + MemoryRegistrator::getInstance().forgetWorkspace(); } TEST_F(WorkspaceTests, CloneTest1) { - if (!Environment::getInstance().isCPU()) - return; + if (!Environment::getInstance().isCPU()) return; - Workspace ws(65536); + Workspace ws(65536); - ws.allocateBytes(65536 * 2); + ws.allocateBytes(65536 * 2); - ASSERT_EQ(65536 * 2, ws.getSpilledSize()); + ASSERT_EQ(65536 * 2, ws.getSpilledSize()); - auto clone = ws.clone(); + auto clone = ws.clone(); - ASSERT_EQ(65536 * 2, clone->getCurrentSize()); - ASSERT_EQ(0, clone->getCurrentOffset()); - ASSERT_EQ(0, clone->getSpilledSize()); + ASSERT_EQ(65536 * 2, clone->getCurrentSize()); + ASSERT_EQ(0, clone->getCurrentOffset()); + ASSERT_EQ(0, clone->getSpilledSize()); - delete clone; + delete clone; } TEST_F(WorkspaceTests, Test_Arrays_1) { - Workspace ws(65536); - LaunchContext ctx; - ctx.setWorkspace(&ws); - - auto x = NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}, &ctx); + Workspace ws(65536); + LaunchContext ctx; + ctx.setWorkspace(&ws); - // x.printIndexedBuffer("x0"); + auto x = NDArrayFactory::create('c', {3, 3}, + {1, 2, 3, 4, 5, 6, 7, 8, 9}, &ctx); - auto y = NDArrayFactory::create('c', {3, 3}, {-1, -2, -3, -4, -5, -6, -7, -8, -9}, &ctx); + // x.printIndexedBuffer("x0"); - // x.printIndexedBuffer("x2"); + auto y = NDArrayFactory::create( + 'c', {3, 3}, {-1, -2, -3, -4, -5, -6, -7, -8, -9}, &ctx); - auto z = NDArrayFactory::create('c', {3, 3}, {0, 0, 0, 0, 0, 0, 0, 0, 0}, &ctx); + // x.printIndexedBuffer("x2"); - MmulHelper::mmul(&x, &y, &z); + auto z = NDArrayFactory::create('c', {3, 3}, + {0, 0, 0, 0, 0, 0, 0, 0, 0}, &ctx); - y.assign(&x); + MmulHelper::mmul(&x, &y, &z); + y.assign(&x); - // x.printIndexedBuffer("x3"); - // y.printIndexedBuffer("y"); - // z.printIndexedBuffer("z"); + // x.printIndexedBuffer("x3"); + // y.printIndexedBuffer("y"); + // z.printIndexedBuffer("z"); } #ifdef GRAPH_FILES_OK TEST_F(WorkspaceTests, Test_Graph_1) { - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/ae_00.fb"); - auto workspace = graph->getVariableSpace()->workspace(); + auto graph = GraphExecutioner::importFromFlatBuffers("./resources/ae_00.fb"); + auto workspace = graph->variableSpace()->workspace(); - auto status = GraphExecutioner::execute(graph); - ASSERT_EQ(Status::OK(), status); + auto status = GraphExecutioner::execute(graph); + ASSERT_EQ(Status::OK(), status); - delete graph; + delete graph; } #endif TEST_F(WorkspaceTests, Test_Externalized_1) { - if (!Environment::getInstance().isCPU()) - return; + if (!Environment::getInstance().isCPU()) return; - char buffer[10000]; - ExternalWorkspace pojo((Nd4jPointer) buffer, 10000, nullptr, 0); + char buffer[10000]; + ExternalWorkspace pojo((Nd4jPointer)buffer, 10000, nullptr, 0); - ASSERT_EQ(10000, pojo.sizeHost()); - ASSERT_EQ(0, pojo.sizeDevice()); + ASSERT_EQ(10000, pojo.sizeHost()); + ASSERT_EQ(0, pojo.sizeDevice()); - Workspace ws(&pojo); - ASSERT_EQ(10000, ws.getCurrentSize()); - ASSERT_EQ(10000, ws.getAllocatedSize()); - LaunchContext ctx; - ctx.setWorkspace(&ws); + Workspace ws(&pojo); + ASSERT_EQ(10000, ws.getCurrentSize()); + ASSERT_EQ(10000, ws.getAllocatedSize()); + LaunchContext ctx; + ctx.setWorkspace(&ws); - auto x = NDArrayFactory::create('c', {10, 10}, &ctx); + auto x = NDArrayFactory::create('c', {10, 10}, &ctx); - // only buffer size goes into account - ASSERT_EQ(400, ws.getUsedSize()); - ASSERT_EQ(400, ws.getCurrentOffset()); + // only buffer size goes into account + ASSERT_EQ(400, ws.getUsedSize()); + ASSERT_EQ(400, ws.getCurrentOffset()); - x.assign(2.0); + x.assign(2.0); - float m = x.meanNumber().e(0); - ASSERT_NEAR(2.0f, m, 1e-5); + float m = x.meanNumber().e(0); + ASSERT_NEAR(2.0f, m, 1e-5); } // TODO: uncomment this test once long shapes are introduced @@ -285,5 +273,4 @@ TEST_F(WorkspaceTests, Test_Big_Allocation_1) { } */ - -#endif //LIBND4J_WORKSPACETESTS_H \ No newline at end of file +#endif // LIBND4J_WORKSPACETESTS_H \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/WorkspaceTests.cu b/libnd4j/tests_cpu/layers_tests/WorkspaceTests.cu index 6fe157ac87ae..2f04fbfe2b6c 100644 --- a/libnd4j/tests_cpu/layers_tests/WorkspaceTests.cu +++ b/libnd4j/tests_cpu/layers_tests/WorkspaceTests.cu @@ -18,43 +18,46 @@ // @author raver119@gmail.com // -#include "testlayers.h" #include -#include -#include #include +#include +#include + +#include "testlayers.h" using namespace sd; using namespace sd::memory; -class CudaWorkspaceTests : public testing::Test { - -}; +class CudaWorkspaceTests : public testing::Test {}; TEST_F(CudaWorkspaceTests, Basic_Tests_1) { - Workspace workspace(65536, 65536); + Workspace workspace(65536, 65536); - ASSERT_EQ(0, workspace.getCurrentOffset()); - LaunchContext ctx; - ctx.setWorkspace(&workspace); - auto array = NDArrayFactory::create('c', {5, 5}, &ctx); + ASSERT_EQ(0, workspace.getCurrentOffset()); + LaunchContext ctx; + ctx.setWorkspace(&workspace); + auto array = NDArrayFactory::create('c', {5, 5}, &ctx); - ASSERT_EQ(108, workspace.getCurrentOffset()); - ASSERT_EQ(0, workspace.getCurrentSecondaryOffset()); + ASSERT_EQ(108, workspace.getCurrentOffset()); + ASSERT_EQ(0, workspace.getCurrentSecondaryOffset()); - array.e(0); + array.e(0); - ASSERT_EQ(100, workspace.getCurrentSecondaryOffset()); + ASSERT_EQ(100, workspace.getCurrentSecondaryOffset()); } TEST_F(CudaWorkspaceTests, Basic_Tests_2) { - Workspace workspace(65536, 65536); - - ASSERT_EQ(0, workspace.getCurrentOffset()); - LaunchContext ctx; - ctx.setWorkspace(&workspace); - auto array = NDArrayFactory::create('c', {5, 5}, {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f}, &ctx); - - ASSERT_EQ(108, workspace.getCurrentOffset()); - ASSERT_EQ(0, workspace.getCurrentSecondaryOffset()); + Workspace workspace(65536, 65536); + + ASSERT_EQ(0, workspace.getCurrentOffset()); + LaunchContext ctx; + ctx.setWorkspace(&workspace); + auto array = NDArrayFactory::create( + 'c', {5, 5}, {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f}, + &ctx); + + ASSERT_EQ(108, workspace.getCurrentOffset()); + ASSERT_EQ(0, workspace.getCurrentSecondaryOffset()); } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/testinclude.h b/libnd4j/tests_cpu/layers_tests/testinclude.h index b266019a9e31..48b23f3548e3 100644 --- a/libnd4j/tests_cpu/layers_tests/testinclude.h +++ b/libnd4j/tests_cpu/layers_tests/testinclude.h @@ -20,33 +20,38 @@ #ifndef LIBND4J_TESTINCLUDE_H #define LIBND4J_TESTINCLUDE_H -#include "testlayers.h" -#include #include -//https://stackoverflow.com/questions/228005/alternative-to-itoa-for-converting-integer-to-string-c -FORCEINLINE std::string int_array_to_string(Nd4jLong int_array[], Nd4jLong size_of_array) { - std::string returnstring = "["; - for (int temp = 0; temp < size_of_array; temp++) { - returnstring += std::to_string(int_array[temp]); - if(temp < size_of_array - 1) - returnstring += ","; - } - returnstring += "]"; - return returnstring; -} - -FORCEINLINE ::testing::AssertionResult arrsEquals(Nd4jLong n, Nd4jLong *assertion,Nd4jLong *other) { - for(int i = 0; i < n; i++) { - if(assertion[i] != other[i]) { - std::string message = std::string("Failure at index ") + std::to_string(i) + std::string(" assertion: ") + int_array_to_string(assertion,n) + std::string(" and test array ") + int_array_to_string(other,n) + std::string(" is not equal"); - return ::testing::AssertionFailure() << message; - } +#include - } - return ::testing::AssertionSuccess(); +#include "testlayers.h" +// https://stackoverflow.com/questions/228005/alternative-to-itoa-for-converting-integer-to-string-c +FORCEINLINE std::string int_array_to_string(Nd4jLong int_array[], + Nd4jLong size_of_array) { + std::string returnstring = "["; + for (int temp = 0; temp < size_of_array; temp++) { + returnstring += std::to_string(int_array[temp]); + if (temp < size_of_array - 1) returnstring += ","; + } + returnstring += "]"; + return returnstring; } +FORCEINLINE ::testing::AssertionResult arrsEquals(Nd4jLong n, + Nd4jLong *assertion, + Nd4jLong *other) { + for (int i = 0; i < n; i++) { + if (assertion[i] != other[i]) { + std::string message = + std::string("Failure at index ") + std::to_string(i) + + std::string(" assertion: ") + int_array_to_string(assertion, n) + + std::string(" and test array ") + int_array_to_string(other, n) + + std::string(" is not equal"); + return ::testing::AssertionFailure() << message; + } + } + return ::testing::AssertionSuccess(); +} -#endif //LIBND4J_TESTINCLUDE_H +#endif // LIBND4J_TESTINCLUDE_H diff --git a/libnd4j/tests_cpu/layers_tests/testlayers.h b/libnd4j/tests_cpu/layers_tests/testlayers.h index 9106223d861b..6e4f03fdfcc1 100644 --- a/libnd4j/tests_cpu/layers_tests/testlayers.h +++ b/libnd4j/tests_cpu/layers_tests/testlayers.h @@ -21,21 +21,21 @@ #ifndef LIBND4J_TESTLAYERS_H #define LIBND4J_TESTLAYERS_H -#include -#include -#include -#include +#include +#include +#include #include #include #include -#include -#include -#include +#include #include +#include +#include #include -#include -#include -#include +#include +#include +#include + #include -#endif //LIBND4J_TESTLAYERS_H +#endif // LIBND4J_TESTLAYERS_H diff --git a/libnd4j/tests_cpu/libnd4j_tests/CMakeLists.txt b/libnd4j/tests_cpu/libnd4j_tests/CMakeLists.txt index 7e01e284732c..9fad1242c33b 100644 --- a/libnd4j/tests_cpu/libnd4j_tests/CMakeLists.txt +++ b/libnd4j/tests_cpu/libnd4j_tests/CMakeLists.txt @@ -225,12 +225,12 @@ if (CMAKE_BUILD_TYPE STREQUAL "Debug" AND NOT(MINGW) AND NOT(APPLE)) SET(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -export-dynamic") endif() - file(GLOB_RECURSE COMPILATION_UNITS false ../include/ops/declarable/helpers/cpu/compilation_units/*.cpp.in + file(GLOB_RECURSE COMPILATION_UNITS false ../include/ops/declarable/helpers/cpu/compilation_units/*.cpp.in ../include/loops/cpu/compilation_units/*.cpp.in ../include/helpers/cpu/loops/*.cpp.in) - foreach(FL_ITEM ${COMPILATION_UNITS}) + foreach(FL_ITEM ${COMPILATION_UNITS}) genCompilation(FL_ITEM) - endforeach() + endforeach() # this function strips path from file name, basically making up short file name, i.e. file.cpp function(SHORTNAME LONG_NAME OUTPUT) diff --git a/libnd4j/tests_cpu/resources/cond_false.fb b/libnd4j/tests_cpu/resources/cond_false.fb new file mode 100644 index 000000000000..65629204d743 Binary files /dev/null and b/libnd4j/tests_cpu/resources/cond_false.fb differ diff --git a/libnd4j/tests_cpu/resources/cond_true.fb b/libnd4j/tests_cpu/resources/cond_true.fb index 003f7868a2d6..75a354c14192 100644 Binary files a/libnd4j/tests_cpu/resources/cond_true.fb and b/libnd4j/tests_cpu/resources/cond_true.fb differ diff --git a/libnd4j/tests_cpu/resources/while_iter1.fb b/libnd4j/tests_cpu/resources/while_iter1.fb new file mode 100644 index 000000000000..d81f0b4b3770 Binary files /dev/null and b/libnd4j/tests_cpu/resources/while_iter1.fb differ diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOps.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOps.java index b4bd62096c25..2619f9a933aa 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOps.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOps.java @@ -1078,11 +1078,6 @@ void sortTad(PointerPointer extraPointers, void deleteVariablesSet(OpaqueVariablesSet pointer); - // GraphState creation - Pointer getGraphState(long id); - - void deleteGraphState(Pointer state); - int estimateThreshold(PointerPointer extraPointers, Pointer x, LongPointer xShapeInfo, int N, float threshold); // this method executes op that requires scope to be present: if/while/cond/whatever diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java index 2926c06b97a3..2799508c9607 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java @@ -273,16 +273,179 @@ public IntIntPair put(int firstValue, int secondValue) { // Created by raver119 on 07.05.19. // -// #ifndef DEV_TESTS_MEMORYTYPE_H -// #define DEV_TESTS_MEMORYTYPE_H - /** enum sd::memory::MemoryType */ - public static final int - HOST = 0, - DEVICE = 10; - +// #ifndef SD_MEMORYTYPE_H +// #define SD_MEMORYTYPE_H +/** enum sd::memory::MemoryType */ +public static final int + HOST = 0, + DEVICE = 10; + + // namespace sd + +// #endif // SD_MEMORYTYPE_H + + +// Parsed from memory/MemoryZone.h + +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +// #ifndef SD_MEMORYZONE_H +// #define SD_MEMORYZONE_H +/** enum sd::memory::MemoryZone */ +public static final int + COLD = 0, + WARM = 10, + HOT = 20; + + // namespace sd + +// #endif // SD_MEMORYZONE_H + + +// Parsed from memory/MemoryDescriptor.h + +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +// #ifndef SD_MEMORYDESCRIPTOR_H +// #define SD_MEMORYDESCRIPTOR_H + +// #include +// #include + +// #include +@Namespace("sd::memory") @NoOffset public static class MemoryDescriptor extends Pointer { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public MemoryDescriptor(Pointer p) { super(p); } + + public MemoryDescriptor(Pointer ptr, @Cast("sd::memory::MemoryZone") int zone, @Cast("uint64_t") long bytes) { super((Pointer)null); allocate(ptr, zone, bytes); } + private native void allocate(Pointer ptr, @Cast("sd::memory::MemoryZone") int zone, @Cast("uint64_t") long bytes); + + public MemoryDescriptor(@Const @ByRef MemoryDescriptor other) { super((Pointer)null); allocate(other); } + @NoException private native void allocate(@Const @ByRef MemoryDescriptor other); + + public native @ByRef @Name("operator =") @NoException MemoryDescriptor put(@Const @ByRef MemoryDescriptor other); + + // move constructor + + // move assignment operator + + public native @Name("address") Pointer _address(); + public native @Cast("sd::memory::MemoryZone") int zone(); + public native @Cast("uint64_t") long bytes(); +} + // namespace memory + // namespace sd + +// #endif // SD_MEMORYDESCRIPTOR_H + + +// Parsed from memory/GraphMemoryManager.h +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +// #ifndef SD_GRAPHMEMORYMANAGER_H +// #define SD_GRAPHMEMORYMANAGER_H + +// #include +// #include +// #include +// #include +// #include +// #include +// #include +@Namespace("sd::graph") @NoOffset public static class GraphMemoryManager extends Pointer { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public GraphMemoryManager(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public GraphMemoryManager(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public GraphMemoryManager position(long position) { + return (GraphMemoryManager)super.position(position); + } + + public GraphMemoryManager() { super((Pointer)null); allocate(); } + private native void allocate(); + + /** + * This method does allocation (probably) and returns structure that describes + * it + * @param numBytes - number of bytes to be allocated + * @param zone - memory zone for allocation + * @return + */ + public native @ByVal @Name("allocate") MemoryDescriptor _allocate(@Cast("size_t") long numBytes, @Cast("sd::memory::MemoryZone") int zone); + + /** + * This method releases (probably) memory chunk described by given descriptor + * @param descriptor + */ + public native void release(@ByRef MemoryDescriptor descriptor); + + /** + * This method allows to store reference to certain memory regions and keep them alive as long as Graph is alive + * @param ptr + */ +} + // namespace graph + // namespace sd -// #endif //DEV_TESTS_MEMORYTYPE_H +// #endif // SD_GRAPHMEMORYMANAGER_H // Parsed from array/DataType.h @@ -309,31 +472,31 @@ public IntIntPair put(int firstValue, int secondValue) { // #ifndef ND4J_DATATYPE_H // #define ND4J_DATATYPE_H - /** enum sd::DataType */ - public static final int - INHERIT = 0, - BOOL = 1, - FLOAT8 = 2, - HALF = 3, - HALF2 = 4, - FLOAT32 = 5, - DOUBLE = 6, - INT8 = 7, - INT16 = 8, - INT32 = 9, - INT64 = 10, - UINT8 = 11, - UINT16 = 12, - UINT32 = 13, - UINT64 = 14, - QINT8 = 15, - QINT16 = 16, - BFLOAT16 = 17, - UTF8 = 50, - UTF16 = 51, - UTF32 = 52, - ANY = 100, - AUTO = 200; +/** enum sd::DataType */ +public static final int + INHERIT = 0, + BOOL = 1, + FLOAT8 = 2, + HALF = 3, + HALF2 = 4, + FLOAT32 = 5, + DOUBLE = 6, + INT8 = 7, + INT16 = 8, + INT32 = 9, + INT64 = 10, + UINT8 = 11, + UINT16 = 12, + UINT32 = 13, + UINT64 = 14, + QINT8 = 15, + QINT16 = 16, + BFLOAT16 = 17, + UTF8 = 50, + UTF16 = 51, + UTF32 = 52, + ANY = 100, + AUTO = 200; // #endif @@ -361,16 +524,17 @@ public IntIntPair put(int firstValue, int secondValue) { // @author Yurii Shyrma (iuriish@yahoo.com) // -// #ifndef DEV_TESTS_DATABUFFER_H -// #define DEV_TESTS_DATABUFFER_H +// #ifndef SD_DATABUFFER_H +// #define SD_DATABUFFER_H -// #include -// #include -// #include -// #include // #include -// #include // #include +// #include +// #include +// #include +// #include + +// #include @Namespace("sd") @NoOffset public static class DataBuffer extends Pointer { static { Loader.load(); } @@ -383,112 +547,117 @@ public IntIntPair put(int firstValue, int secondValue) { return (DataBuffer)super.position(position); } + public DataBuffer(Pointer primary, Pointer special, @Cast("const size_t") long lenInBytes, + @Cast("const sd::DataType") int dataType, @Cast("const bool") boolean isOwnerPrimary/*=false*/, + @Cast("const bool") boolean isOwnerSpecial/*=false*/, + Workspace workspace/*=nullptr*/) { super((Pointer)null); allocate(primary, special, lenInBytes, dataType, isOwnerPrimary, isOwnerSpecial, workspace); } + private native void allocate(Pointer primary, Pointer special, @Cast("const size_t") long lenInBytes, + @Cast("const sd::DataType") int dataType, @Cast("const bool") boolean isOwnerPrimary/*=false*/, + @Cast("const bool") boolean isOwnerSpecial/*=false*/, + Workspace workspace/*=nullptr*/); + public DataBuffer(Pointer primary, Pointer special, @Cast("const size_t") long lenInBytes, + @Cast("const sd::DataType") int dataType) { super((Pointer)null); allocate(primary, special, lenInBytes, dataType); } + private native void allocate(Pointer primary, Pointer special, @Cast("const size_t") long lenInBytes, + @Cast("const sd::DataType") int dataType); + + public DataBuffer(Pointer primary, @Cast("const size_t") long lenInBytes, @Cast("const sd::DataType") int dataType, + @Cast("const bool") boolean isOwnerPrimary/*=false*/, + Workspace workspace/*=nullptr*/) { super((Pointer)null); allocate(primary, lenInBytes, dataType, isOwnerPrimary, workspace); } + private native void allocate(Pointer primary, @Cast("const size_t") long lenInBytes, @Cast("const sd::DataType") int dataType, + @Cast("const bool") boolean isOwnerPrimary/*=false*/, + Workspace workspace/*=nullptr*/); + public DataBuffer(Pointer primary, @Cast("const size_t") long lenInBytes, @Cast("const sd::DataType") int dataType) { super((Pointer)null); allocate(primary, lenInBytes, dataType); } + private native void allocate(Pointer primary, @Cast("const size_t") long lenInBytes, @Cast("const sd::DataType") int dataType); + + public DataBuffer(@Const Pointer hostBuffer, + @Cast("const sd::DataType") int dataType, @Cast("const size_t") long lenInBytes, + Workspace workspace/*=nullptr*/) { super((Pointer)null); allocate(hostBuffer, dataType, lenInBytes, workspace); } + private native void allocate(@Const Pointer hostBuffer, + @Cast("const sd::DataType") int dataType, @Cast("const size_t") long lenInBytes, + Workspace workspace/*=nullptr*/); + public DataBuffer(@Const Pointer hostBuffer, + @Cast("const sd::DataType") int dataType, @Cast("const size_t") long lenInBytes) { super((Pointer)null); allocate(hostBuffer, dataType, lenInBytes); } + private native void allocate(@Const Pointer hostBuffer, + @Cast("const sd::DataType") int dataType, @Cast("const size_t") long lenInBytes); + + public DataBuffer(@Cast("const size_t") long lenInBytes, @Cast("const sd::DataType") int dataType, + Workspace workspace/*=nullptr*/, + @Cast("const bool") boolean allocBoth/*=false*/) { super((Pointer)null); allocate(lenInBytes, dataType, workspace, allocBoth); } + private native void allocate(@Cast("const size_t") long lenInBytes, @Cast("const sd::DataType") int dataType, + Workspace workspace/*=nullptr*/, + @Cast("const bool") boolean allocBoth/*=false*/); + public DataBuffer(@Cast("const size_t") long lenInBytes, @Cast("const sd::DataType") int dataType) { super((Pointer)null); allocate(lenInBytes, dataType); } + private native void allocate(@Cast("const size_t") long lenInBytes, @Cast("const sd::DataType") int dataType); + + public DataBuffer(@Const @ByRef DataBuffer other) { super((Pointer)null); allocate(other); } + private native void allocate(@Const @ByRef DataBuffer other); + public DataBuffer() { super((Pointer)null); allocate(); } + private native void allocate(); - public DataBuffer(Pointer primary, Pointer special, - @Cast("const size_t") long lenInBytes, @Cast("const sd::DataType") int dataType, - @Cast("const bool") boolean isOwnerPrimary/*=false*/, @Cast("const bool") boolean isOwnerSpecial/*=false*/, - Workspace workspace/*=nullptr*/) { super((Pointer)null); allocate(primary, special, lenInBytes, dataType, isOwnerPrimary, isOwnerSpecial, workspace); } - private native void allocate(Pointer primary, Pointer special, - @Cast("const size_t") long lenInBytes, @Cast("const sd::DataType") int dataType, - @Cast("const bool") boolean isOwnerPrimary/*=false*/, @Cast("const bool") boolean isOwnerSpecial/*=false*/, - Workspace workspace/*=nullptr*/); - public DataBuffer(Pointer primary, Pointer special, - @Cast("const size_t") long lenInBytes, @Cast("const sd::DataType") int dataType) { super((Pointer)null); allocate(primary, special, lenInBytes, dataType); } - private native void allocate(Pointer primary, Pointer special, - @Cast("const size_t") long lenInBytes, @Cast("const sd::DataType") int dataType); - - public DataBuffer(Pointer primary, - @Cast("const size_t") long lenInBytes, @Cast("const sd::DataType") int dataType, - @Cast("const bool") boolean isOwnerPrimary/*=false*/, - Workspace workspace/*=nullptr*/) { super((Pointer)null); allocate(primary, lenInBytes, dataType, isOwnerPrimary, workspace); } - private native void allocate(Pointer primary, - @Cast("const size_t") long lenInBytes, @Cast("const sd::DataType") int dataType, - @Cast("const bool") boolean isOwnerPrimary/*=false*/, - Workspace workspace/*=nullptr*/); - public DataBuffer(Pointer primary, - @Cast("const size_t") long lenInBytes, @Cast("const sd::DataType") int dataType) { super((Pointer)null); allocate(primary, lenInBytes, dataType); } - private native void allocate(Pointer primary, - @Cast("const size_t") long lenInBytes, @Cast("const sd::DataType") int dataType); - - public DataBuffer(@Const Pointer hostBuffer, - @Cast("const sd::DataType") int dataType, @Cast("const size_t") long lenInBytes, - Workspace workspace/*=nullptr*/) { super((Pointer)null); allocate(hostBuffer, dataType, lenInBytes, workspace); } - private native void allocate(@Const Pointer hostBuffer, - @Cast("const sd::DataType") int dataType, @Cast("const size_t") long lenInBytes, - Workspace workspace/*=nullptr*/); - public DataBuffer(@Const Pointer hostBuffer, - @Cast("const sd::DataType") int dataType, @Cast("const size_t") long lenInBytes) { super((Pointer)null); allocate(hostBuffer, dataType, lenInBytes); } - private native void allocate(@Const Pointer hostBuffer, - @Cast("const sd::DataType") int dataType, @Cast("const size_t") long lenInBytes); - - public DataBuffer(@Cast("const size_t") long lenInBytes, @Cast("const sd::DataType") int dataType, Workspace workspace/*=nullptr*/, @Cast("const bool") boolean allocBoth/*=false*/) { super((Pointer)null); allocate(lenInBytes, dataType, workspace, allocBoth); } - private native void allocate(@Cast("const size_t") long lenInBytes, @Cast("const sd::DataType") int dataType, Workspace workspace/*=nullptr*/, @Cast("const bool") boolean allocBoth/*=false*/); - public DataBuffer(@Cast("const size_t") long lenInBytes, @Cast("const sd::DataType") int dataType) { super((Pointer)null); allocate(lenInBytes, dataType); } - private native void allocate(@Cast("const size_t") long lenInBytes, @Cast("const sd::DataType") int dataType); - - public DataBuffer(@Const @ByRef DataBuffer other) { super((Pointer)null); allocate(other); } - private native void allocate(@Const @ByRef DataBuffer other); - public DataBuffer() { super((Pointer)null); allocate(); } - private native void allocate(); - - public native @ByRef @Name("operator =") DataBuffer put(@Const @ByRef DataBuffer other); - - public native @Cast("sd::DataType") int getDataType(); - public native void setDataType(@Cast("sd::DataType") int dataType); - public native @Cast("size_t") long getLenInBytes(); - - public native Pointer primary(); - public native Pointer special(); - - public native void allocatePrimary(); - public native void allocateSpecial(); - - public native void writePrimary(); - public native void writeSpecial(); - public native void readPrimary(); - public native void readSpecial(); - public native @Cast("bool") boolean isPrimaryActual(); - public native @Cast("bool") boolean isSpecialActual(); - - public native void expand(@Cast("const uint64_t") long size); - - public native int deviceId(); - public native void setDeviceId(int deviceId); - public native void migrate(); - - public native void syncToPrimary(@Const LaunchContext context, @Cast("const bool") boolean forceSync/*=false*/); - public native void syncToPrimary(@Const LaunchContext context); - public native void syncToSpecial(@Cast("const bool") boolean forceSync/*=false*/); - public native void syncToSpecial(); - - public native void setToZeroBuffers(@Cast("const bool") boolean both/*=false*/); - public native void setToZeroBuffers(); - - public native void copyBufferFrom(@Const @ByRef DataBuffer other, @Cast("size_t") long sizeToCopyinBytes/*=0*/, @Cast("const Nd4jLong") long offsetThis/*=0*/, @Cast("const Nd4jLong") long offsetOther/*=0*/); - public native void copyBufferFrom(@Const @ByRef DataBuffer other); - - public static native void memcpy(@Const @ByRef DataBuffer dst, @Const @ByRef DataBuffer src); - - public native void setPrimaryBuffer(Pointer buffer, @Cast("size_t") long length); - public native void setSpecialBuffer(Pointer buffer, @Cast("size_t") long length); + public native @ByRef @Name("operator =") DataBuffer put(@Const @ByRef DataBuffer other); - /** - * This method deletes buffers, if we're owners - */ - public native @Name("close") void _close(); + public native @Cast("sd::DataType") int getDataType(); + public native void setDataType(@Cast("sd::DataType") int dataType); + public native @Cast("size_t") long getLenInBytes(); + + public native Pointer primary(); + public native Pointer special(); + public native Pointer platform(); + + public native void allocatePrimary(); + public native void allocateSpecial(); + + public native void writePrimary(); + public native void writeSpecial(); + public native void readPrimary(); + public native void readSpecial(); + public native @Cast("bool") boolean isPrimaryActual(); + public native @Cast("bool") boolean isSpecialActual(); + + public native void expand(@Cast("const uint64_t") long size); + + public native int deviceId(); + public native void setDeviceId(int deviceId); + + public native void migrate(); + + public native void syncToPrimary(@Const LaunchContext context, + @Cast("const bool") boolean forceSync/*=false*/); + public native void syncToPrimary(@Const LaunchContext context); + public native void syncToSpecial(@Cast("const bool") boolean forceSync/*=false*/); + public native void syncToSpecial(); + + public native void setToZeroBuffers(@Cast("const bool") boolean both/*=false*/); + public native void setToZeroBuffers(); + + public native void copyBufferFrom(@Const @ByRef DataBuffer other, @Cast("size_t") long sizeToCopyinBytes/*=0*/, + @Cast("const Nd4jLong") long offsetThis/*=0*/, + @Cast("const Nd4jLong") long offsetOther/*=0*/); + public native void copyBufferFrom(@Const @ByRef DataBuffer other); + + public static native void memcpy(@Const @ByRef DataBuffer dst, @Const @ByRef DataBuffer src); + + public native void setPrimaryBuffer(Pointer buffer, @Cast("size_t") long length); + public native void setSpecialBuffer(Pointer buffer, @Cast("size_t") long length); + + /** + * This method deletes buffers, if we're owners + */ + public native @Name("close") void _close(); } ///// IMLEMENTATION OF INLINE METHODS ///// //////////////////////////////////////////////////////////////////////// - + //////////////////////////////////////////////////////////////////////// - +//////////////////////////////////////////////////////////////////////// + // namespace sd -// #endif //DEV_TESTS_DATABUFFER_H +// #endif // SD_DATABUFFER_H // Parsed from array/PointerDeallocator.h @@ -587,33 +756,33 @@ private native void allocate(@Const Pointer hostBuffer, // #include // #include // #include - @Namespace("sd") @NoOffset public static class ConstantDataBuffer extends Pointer { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public ConstantDataBuffer(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public ConstantDataBuffer(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public ConstantDataBuffer position(long position) { - return (ConstantDataBuffer)super.position(position); - } - - public ConstantDataBuffer(@Const @ByRef ConstantDataBuffer other) { super((Pointer)null); allocate(other); } - private native void allocate(@Const @ByRef ConstantDataBuffer other); - public ConstantDataBuffer() { super((Pointer)null); allocate(); } - private native void allocate(); +@Namespace("sd") @NoOffset public static class ConstantDataBuffer extends Pointer { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public ConstantDataBuffer(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public ConstantDataBuffer(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public ConstantDataBuffer position(long position) { + return (ConstantDataBuffer)super.position(position); + } - public native @Cast("uint8_t") byte sizeOf(); - public native @Cast("uint64_t") long length(); + public ConstantDataBuffer(@Const @ByRef ConstantDataBuffer other) { super((Pointer)null); allocate(other); } + private native void allocate(@Const @ByRef ConstantDataBuffer other); + public ConstantDataBuffer() { super((Pointer)null); allocate(); } + private native void allocate(); - public native Pointer primary(); - public native Pointer special(); + public native @Cast("uint8_t") byte sizeOf(); + public native @Cast("uint64_t") long length(); - public native @ByRef @Name("operator =") ConstantDataBuffer put(@Const @ByRef ConstantDataBuffer other); - } + public native Pointer primary(); + public native Pointer special(); + public native @ByRef @Name("operator =") ConstantDataBuffer put(@Const @ByRef ConstantDataBuffer other); +} + // namespace sd -// #endif //DEV_TESTS_CONSTANTDATABUFFER_H +// #endif // SD_CONSTANTDATABUFFER_H // Parsed from array/ConstantShapeBuffer.h @@ -746,68 +915,68 @@ private native void allocate(@Const Pointer hostBuffer, // @author raver119@gmail.com // -// #ifndef DEV_TESTS_CONSTANTDESCRIPTOR_H -// #define DEV_TESTS_CONSTANTDESCRIPTOR_H +// #ifndef SD_CONSTANTDESCRIPTOR_H +// #define SD_CONSTANTDESCRIPTOR_H +// #include // #include -// #include -// #include -// #include // #include -// #include - @Namespace("sd") @NoOffset public static class ConstantDescriptor extends Pointer { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public ConstantDescriptor(Pointer p) { super(p); } - - public ConstantDescriptor(DoublePointer values, int length) { super((Pointer)null); allocate(values, length); } - private native void allocate(DoublePointer values, int length); - public ConstantDescriptor(DoubleBuffer values, int length) { super((Pointer)null); allocate(values, length); } - private native void allocate(DoubleBuffer values, int length); - public ConstantDescriptor(double[] values, int length) { super((Pointer)null); allocate(values, length); } - private native void allocate(double[] values, int length); - public ConstantDescriptor(@Cast("const Nd4jLong*") LongPointer values, int length) { super((Pointer)null); allocate(values, length); } - private native void allocate(@Cast("const Nd4jLong*") LongPointer values, int length); - public ConstantDescriptor(@Cast("const Nd4jLong*") LongBuffer values, int length) { super((Pointer)null); allocate(values, length); } - private native void allocate(@Cast("const Nd4jLong*") LongBuffer values, int length); - public ConstantDescriptor(@Cast("const Nd4jLong*") long[] values, int length) { super((Pointer)null); allocate(values, length); } - private native void allocate(@Cast("const Nd4jLong*") long[] values, int length); - - public ConstantDescriptor(@Cast("Nd4jLong*") @StdVector LongPointer values) { super((Pointer)null); allocate(values); } - private native void allocate(@Cast("Nd4jLong*") @StdVector LongPointer values); - public ConstantDescriptor(@Cast("Nd4jLong*") @StdVector LongBuffer values) { super((Pointer)null); allocate(values); } - private native void allocate(@Cast("Nd4jLong*") @StdVector LongBuffer values); - public ConstantDescriptor(@Cast("Nd4jLong*") @StdVector long[] values) { super((Pointer)null); allocate(values); } - private native void allocate(@Cast("Nd4jLong*") @StdVector long[] values); - public ConstantDescriptor(@StdVector DoublePointer values) { super((Pointer)null); allocate(values); } - private native void allocate(@StdVector DoublePointer values); - public ConstantDescriptor(@StdVector DoubleBuffer values) { super((Pointer)null); allocate(values); } - private native void allocate(@StdVector DoubleBuffer values); - public ConstantDescriptor(@StdVector double[] values) { super((Pointer)null); allocate(values); } - private native void allocate(@StdVector double[] values); - - // equal to operator - public native @Cast("bool") @Name("operator ==") boolean equals(@Const @ByRef ConstantDescriptor other); - - // less than operator - public native @Cast("bool") @Name("operator <") boolean lessThan(@Const @ByRef ConstantDescriptor other); - - public native @Cast("bool") boolean isInteger(); - public native @Cast("bool") boolean isFloat(); - - public native @Cast("Nd4jLong") long length(); - - public native @Cast("Nd4jLong*") @StdVector LongPointer integerValues(); - public native @StdVector DoublePointer floatValues(); - } +// #include +// #include +// #include +@Namespace("sd") @NoOffset public static class ConstantDescriptor extends Pointer { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public ConstantDescriptor(Pointer p) { super(p); } + + public ConstantDescriptor(DoublePointer values, int length) { super((Pointer)null); allocate(values, length); } + private native void allocate(DoublePointer values, int length); + public ConstantDescriptor(DoubleBuffer values, int length) { super((Pointer)null); allocate(values, length); } + private native void allocate(DoubleBuffer values, int length); + public ConstantDescriptor(double[] values, int length) { super((Pointer)null); allocate(values, length); } + private native void allocate(double[] values, int length); + public ConstantDescriptor(@Cast("const Nd4jLong*") LongPointer values, int length) { super((Pointer)null); allocate(values, length); } + private native void allocate(@Cast("const Nd4jLong*") LongPointer values, int length); + public ConstantDescriptor(@Cast("const Nd4jLong*") LongBuffer values, int length) { super((Pointer)null); allocate(values, length); } + private native void allocate(@Cast("const Nd4jLong*") LongBuffer values, int length); + public ConstantDescriptor(@Cast("const Nd4jLong*") long[] values, int length) { super((Pointer)null); allocate(values, length); } + private native void allocate(@Cast("const Nd4jLong*") long[] values, int length); + + public ConstantDescriptor(@Cast("Nd4jLong*") @StdVector LongPointer values) { super((Pointer)null); allocate(values); } + private native void allocate(@Cast("Nd4jLong*") @StdVector LongPointer values); + public ConstantDescriptor(@Cast("Nd4jLong*") @StdVector LongBuffer values) { super((Pointer)null); allocate(values); } + private native void allocate(@Cast("Nd4jLong*") @StdVector LongBuffer values); + public ConstantDescriptor(@Cast("Nd4jLong*") @StdVector long[] values) { super((Pointer)null); allocate(values); } + private native void allocate(@Cast("Nd4jLong*") @StdVector long[] values); + public ConstantDescriptor(@StdVector DoublePointer values) { super((Pointer)null); allocate(values); } + private native void allocate(@StdVector DoublePointer values); + public ConstantDescriptor(@StdVector DoubleBuffer values) { super((Pointer)null); allocate(values); } + private native void allocate(@StdVector DoubleBuffer values); + public ConstantDescriptor(@StdVector double[] values) { super((Pointer)null); allocate(values); } + private native void allocate(@StdVector double[] values); + + // equal to operator + public native @Cast("bool") @Name("operator ==") boolean equals(@Const @ByRef ConstantDescriptor other); + + // less than operator + public native @Cast("bool") @Name("operator <") boolean lessThan(@Const @ByRef ConstantDescriptor other); + + public native @Cast("bool") boolean isInteger(); + public native @Cast("bool") boolean isFloat(); + + public native @Cast("Nd4jLong") long length(); + + public native @Cast("Nd4jLong*") @StdVector LongPointer integerValues(); + public native @StdVector DoublePointer floatValues(); +} + // namespace sd // #ifndef __JAVACPP_HACK__ // #endif - -// #endif //DEV_TESTS_CONSTANTDESCRIPTOR_H +// #endif // SD_CONSTANTDESCRIPTOR_H // Parsed from array/TadPack.h @@ -832,47 +1001,49 @@ private native void allocate(@Const Pointer hostBuffer, // @author raver119@gmail.com // -// #ifndef DEV_TESTS_TADPACK_H -// #define DEV_TESTS_TADPACK_H +// #ifndef SD_TADPACK_H +// #define SD_TADPACK_H // #include // #include - @Namespace("sd") @NoOffset public static class TadPack extends Pointer { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public TadPack(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public TadPack(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public TadPack position(long position) { - return (TadPack)super.position(position); - } - - public TadPack(@Const @ByRef ConstantShapeBuffer shapes, @Const @ByRef ConstantOffsetsBuffer offets, @Cast("Nd4jLong") long numTads) { super((Pointer)null); allocate(shapes, offets, numTads); } - private native void allocate(@Const @ByRef ConstantShapeBuffer shapes, @Const @ByRef ConstantOffsetsBuffer offets, @Cast("Nd4jLong") long numTads); - public TadPack() { super((Pointer)null); allocate(); } - private native void allocate(); - - public native @Cast("const Nd4jLong*") LongPointer primaryShapeInfo(); - public native @Cast("const Nd4jLong*") LongPointer primaryOffsets(); +@Namespace("sd") @NoOffset public static class TadPack extends Pointer { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public TadPack(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public TadPack(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public TadPack position(long position) { + return (TadPack)super.position(position); + } - public native @Cast("const Nd4jLong*") LongPointer specialShapeInfo(); - public native @Cast("const Nd4jLong*") LongPointer specialOffsets(); + public TadPack(@Const @ByRef ConstantShapeBuffer shapes, @Const @ByRef ConstantOffsetsBuffer offets, + @Cast("Nd4jLong") long numTads) { super((Pointer)null); allocate(shapes, offets, numTads); } + private native void allocate(@Const @ByRef ConstantShapeBuffer shapes, @Const @ByRef ConstantOffsetsBuffer offets, + @Cast("Nd4jLong") long numTads); + public TadPack() { super((Pointer)null); allocate(); } + private native void allocate(); - public native @Cast("Nd4jLong") long numberOfTads(); - public native int shapeInfoLength(); + public native @Cast("const Nd4jLong*") LongPointer primaryShapeInfo(); + public native @Cast("const Nd4jLong*") LongPointer primaryOffsets(); - /** - * These methods return either primary or special pointers depending on platform binaries were compiled for - * @return - */ - public native @Cast("const Nd4jLong*") LongPointer platformShapeInfo(); - public native @Cast("const Nd4jLong*") LongPointer platformOffsets(); - } + public native @Cast("const Nd4jLong*") LongPointer specialShapeInfo(); + public native @Cast("const Nd4jLong*") LongPointer specialOffsets(); + public native @Cast("Nd4jLong") long numberOfTads(); + public native int shapeInfoLength(); + /** + * These methods return either primary or special pointers depending on + * platform binaries were compiled for + * @return + */ + public native @Cast("const Nd4jLong*") LongPointer platformShapeInfo(); + public native @Cast("const Nd4jLong*") LongPointer platformOffsets(); +} + // namespace sd -// #endif //DEV_TESTS_TADPACK_H +// #endif // SD_TADPACK_H // Parsed from execution/ErrorReference.h @@ -897,36 +1068,36 @@ private native void allocate(@Const Pointer hostBuffer, // @author raver119@gmail.com // -// #ifndef DEV_TESTS_ERRORREFERENCE_H -// #define DEV_TESTS_ERRORREFERENCE_H +// #ifndef SD_ERRORREFERENCE_H +// #define SD_ERRORREFERENCE_H -// #include // #include - @Namespace("sd") @NoOffset public static class ErrorReference extends Pointer { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public ErrorReference(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public ErrorReference(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public ErrorReference position(long position) { - return (ErrorReference)super.position(position); - } - - public ErrorReference() { super((Pointer)null); allocate(); } - private native void allocate(); - - public native int errorCode(); - public native @Cast("char*") String errorMessage(); - public native void setErrorCode(int errorCode); - public native void setErrorMessage(@StdString BytePointer message); - public native void setErrorMessage(@StdString String message); +// #include +@Namespace("sd") @NoOffset public static class ErrorReference extends Pointer { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public ErrorReference(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public ErrorReference(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public ErrorReference position(long position) { + return (ErrorReference)super.position(position); } + public ErrorReference() { super((Pointer)null); allocate(); } + private native void allocate(); + + public native int errorCode(); + public native @Cast("char*") String errorMessage(); + public native void setErrorCode(int errorCode); + public native void setErrorMessage(@StdString BytePointer message); + public native void setErrorMessage(@StdString String message); +} + // namespace sd -// #endif //DEV_TESTS_ERRORREFERENCE_H +// #endif // SD_ERRORREFERENCE_H // Parsed from execution/Engine.h @@ -953,13 +1124,13 @@ private native void allocate(@Const Pointer hostBuffer, // #ifndef SD_ENGINE_H // #define SD_ENGINE_H - /** enum samediff::Engine */ - public static final int - ENGINE_CPU = 0, - ENGINE_CUDA = 1; +/** enum samediff::Engine */ +public static final int + ENGINE_CPU = 0, + ENGINE_CUDA = 1; -// #endif //SD_ENGINE_H +// #endif // SD_ENGINE_H // Parsed from execution/ExecutionMode.h @@ -986,14 +1157,14 @@ private native void allocate(@Const Pointer hostBuffer, // #ifndef SD_EXECUTIONMODE_H // #define SD_EXECUTIONMODE_H - /** enum samediff::ExecutionMode */ - public static final int - MODE_UNDEFINED = 0, - MODE_TRAINING = 1, - MODE_INFERENCE = 2; +/** enum samediff::ExecutionMode */ +public static final int + MODE_UNDEFINED = 0, + MODE_TRAINING = 1, + MODE_INFERENCE = 2; -// #endif //SD_EXECUTIONMODE_H +// #endif // SD_EXECUTIONMODE_H // Parsed from system/Environment.h @@ -1021,100 +1192,101 @@ private native void allocate(@Const Pointer hostBuffer, // #ifndef LIBND4J_ENVIRONMENT_H // #define LIBND4J_ENVIRONMENT_H -// #include -// #include -// #include -// #include // #include -// #include +// #include // #include - @Namespace("sd") @NoOffset public static class Environment extends Pointer { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public Environment(Pointer p) { super(p); } - - /** - * These 3 fields are mostly for CUDA/cuBLAS version tracking - */ - public native int _blasMajorVersion(); public native Environment _blasMajorVersion(int setter); - public native int _blasMinorVersion(); public native Environment _blasMinorVersion(int setter); - public native int _blasPatchVersion(); public native Environment _blasPatchVersion(int setter); - - public static native @ByRef Environment getInstance(); - - public native @Cast("bool") boolean isVerbose(); - public native void setVerbose(@Cast("bool") boolean reallyVerbose); - public native @Cast("bool") boolean isDebug(); - public native @Cast("bool") boolean isProfiling(); - public native @Cast("bool") boolean isDetectingLeaks(); - public native @Cast("bool") boolean isDebugAndVerbose(); - public native void setDebug(@Cast("bool") boolean reallyDebug); - public native void setProfiling(@Cast("bool") boolean reallyProfile); - public native void setLeaksDetector(@Cast("bool") boolean reallyDetect); - public native @Cast("bool") boolean helpersAllowed(); - public native void allowHelpers(@Cast("bool") boolean reallyAllow); - - public native @Cast("bool") boolean blasFallback(); - - public native int tadThreshold(); - public native void setTadThreshold(int threshold); +// #include - public native int elementwiseThreshold(); - public native void setElementwiseThreshold(int threshold); +// #include +// #include +// #include +@Namespace("sd") @NoOffset public static class Environment extends Pointer { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public Environment(Pointer p) { super(p); } - public native int maxThreads(); - public native void setMaxThreads(int max); - public native int maxMasterThreads(); - public native void setMaxMasterThreads(int max); + /** + * These 3 fields are mostly for CUDA/cuBLAS version tracking + */ + public native int _blasMajorVersion(); public native Environment _blasMajorVersion(int setter); + public native int _blasMinorVersion(); public native Environment _blasMinorVersion(int setter); + public native int _blasPatchVersion(); public native Environment _blasPatchVersion(int setter); - /* - * Legacy memory limits API, still used in new API as simplified version - */ - public native void setMaxPrimaryMemory(@Cast("uint64_t") long maxBytes); - public native void setMaxSpecialyMemory(@Cast("uint64_t") long maxBytes); - public native void setMaxDeviceMemory(@Cast("uint64_t") long maxBytes); + public static native @ByRef Environment getInstance(); - public native @Cast("uint64_t") long maxPrimaryMemory(); - public native @Cast("uint64_t") long maxSpecialMemory(); - //////////////////////// + public native @Cast("bool") boolean isVerbose(); + public native void setVerbose(@Cast("bool") boolean reallyVerbose); + public native @Cast("bool") boolean isDebug(); + public native @Cast("bool") boolean isProfiling(); + public native @Cast("bool") boolean isDetectingLeaks(); + public native @Cast("bool") boolean isDebugAndVerbose(); + public native void setDebug(@Cast("bool") boolean reallyDebug); + public native void setProfiling(@Cast("bool") boolean reallyProfile); + public native void setLeaksDetector(@Cast("bool") boolean reallyDetect); + public native @Cast("bool") boolean helpersAllowed(); + public native void allowHelpers(@Cast("bool") boolean reallyAllow); - /* - * Methods for memory limits/counters - */ - public native void setGroupLimit(int group, @Cast("Nd4jLong") long numBytes); - public native void setDeviceLimit(int deviceId, @Cast("Nd4jLong") long numBytes); + public native @Cast("bool") boolean blasFallback(); - public native @Cast("Nd4jLong") long getGroupLimit(int group); - public native @Cast("Nd4jLong") long getDeviceLimit(int deviceId); + public native int tadThreshold(); + public native void setTadThreshold(int threshold); - public native @Cast("Nd4jLong") long getGroupCounter(int group); - public native @Cast("Nd4jLong") long getDeviceCounter(int deviceId); - //////////////////////// + public native int elementwiseThreshold(); + public native void setElementwiseThreshold(int threshold); - public native @Cast("bool") boolean isUseMKLDNN(); - public native void setUseMKLDNN(@Cast("bool") boolean useMKLDNN); + public native int maxThreads(); + public native void setMaxThreads(int max); - public native @Cast("sd::DataType") int defaultFloatDataType(); - public native void setDefaultFloatDataType(@Cast("sd::DataType") int dtype); + public native int maxMasterThreads(); + public native void setMaxMasterThreads(int max); - public native @Cast("bool") boolean precisionBoostAllowed(); - public native void allowPrecisionBoost(@Cast("bool") boolean reallyAllow); + /* + * Legacy memory limits API, still used in new API as simplified version + */ + public native void setMaxPrimaryMemory(@Cast("uint64_t") long maxBytes); + public native void setMaxSpecialyMemory(@Cast("uint64_t") long maxBytes); + public native void setMaxDeviceMemory(@Cast("uint64_t") long maxBytes); + + public native @Cast("uint64_t") long maxPrimaryMemory(); + public native @Cast("uint64_t") long maxSpecialMemory(); + //////////////////////// - public native @Cast("bool") boolean isExperimentalBuild(); + /* + * Methods for memory limits/counters + */ + public native void setGroupLimit(int group, @Cast("Nd4jLong") long numBytes); + public native void setDeviceLimit(int deviceId, @Cast("Nd4jLong") long numBytes); - public native @Cast("bool") boolean isCPU(); + public native @Cast("Nd4jLong") long getGroupLimit(int group); + public native @Cast("Nd4jLong") long getDeviceLimit(int deviceId); - public native int blasMajorVersion(); - public native int blasMinorVersion(); - public native int blasPatchVersion(); + public native @Cast("Nd4jLong") long getGroupCounter(int group); + public native @Cast("Nd4jLong") long getDeviceCounter(int deviceId); + //////////////////////// - public native @StdVector Pair capabilities(); - } + public native @Cast("bool") boolean isUseMKLDNN(); + public native void setUseMKLDNN(@Cast("bool") boolean useMKLDNN); + + public native @Cast("sd::DataType") int defaultFloatDataType(); + public native void setDefaultFloatDataType(@Cast("sd::DataType") int dtype); + + public native @Cast("bool") boolean precisionBoostAllowed(); + public native void allowPrecisionBoost(@Cast("bool") boolean reallyAllow); + + public native @Cast("bool") boolean isExperimentalBuild(); + + public native @Cast("bool") boolean isCPU(); + public native int blasMajorVersion(); + public native int blasMinorVersion(); + public native int blasPatchVersion(); + public native @StdVector Pair capabilities(); +} + // namespace sd -// #endif //LIBND4J_ENVIRONMENT_H +// #endif // LIBND4J_ENVIRONMENT_H // Parsed from types/utf8string.h @@ -1139,44 +1311,44 @@ private native void allocate(@Const Pointer hostBuffer, // @author raver119@gmail.com // -// #ifndef DEV_TESTS_UTF8STRING_H -// #define DEV_TESTS_UTF8STRING_H +// #ifndef SD_UTF8STRING_H +// #define SD_UTF8STRING_H -// #include // #include - @Namespace("sd") @NoOffset public static class utf8string extends Pointer { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public utf8string(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public utf8string(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public utf8string position(long position) { - return (utf8string)super.position(position); - } - - public native @Cast("char*") BytePointer _buffer(); public native utf8string _buffer(BytePointer setter); - public native @Cast("unsigned int") int _length(); public native utf8string _length(int setter); - public utf8string() { super((Pointer)null); allocate(); } - private native void allocate(); - - public utf8string(@Cast("char*") String string, int length) { super((Pointer)null); allocate(string, length); } - private native void allocate(@Cast("char*") String string, int length); - public utf8string(@Cast("char*") BytePointer string, int length) { super((Pointer)null); allocate(string, length); } - private native void allocate(@Cast("char*") BytePointer string, int length); - public utf8string(@StdString BytePointer string) { super((Pointer)null); allocate(string); } - private native void allocate(@StdString BytePointer string); - public utf8string(@StdString String string) { super((Pointer)null); allocate(string); } - private native void allocate(@StdString String string); - public utf8string(@Const @ByRef utf8string other) { super((Pointer)null); allocate(other); } - private native void allocate(@Const @ByRef utf8string other); - public native @ByRef @Name("operator =") utf8string put(@Const @ByRef utf8string other); +// #include +@Namespace("sd") @NoOffset public static class utf8string extends Pointer { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public utf8string(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public utf8string(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public utf8string position(long position) { + return (utf8string)super.position(position); } + public native @Cast("char*") BytePointer _buffer(); public native utf8string _buffer(BytePointer setter); + public native @Cast("unsigned int") int _length(); public native utf8string _length(int setter); + + public utf8string() { super((Pointer)null); allocate(); } + private native void allocate(); + public utf8string(@Cast("char*") String string, int length) { super((Pointer)null); allocate(string, length); } + private native void allocate(@Cast("char*") String string, int length); + public utf8string(@Cast("char*") BytePointer string, int length) { super((Pointer)null); allocate(string, length); } + private native void allocate(@Cast("char*") BytePointer string, int length); + public utf8string(@StdString BytePointer string) { super((Pointer)null); allocate(string); } + private native void allocate(@StdString BytePointer string); + public utf8string(@StdString String string) { super((Pointer)null); allocate(string); } + private native void allocate(@StdString String string); + public utf8string(@Const @ByRef utf8string other) { super((Pointer)null); allocate(other); } + private native void allocate(@Const @ByRef utf8string other); + public native @ByRef @Name("operator =") utf8string put(@Const @ByRef utf8string other); +} + // namespace sd -// #endif //DEV_TESTS_UTF8STRING_H +// #endif // SD_UTF8STRING_H // Parsed from legacy/NativeOps.h @@ -1214,7 +1386,7 @@ private native void allocate(@Const Pointer hostBuffer, defined __DMC__ || \ defined __BORLANDC__ ) # define thread_local __declspec(thread) -// note that ICC (linux) and Clang are covered by __GNUC__ +// note that ICC (linux) and Clang are covered by __GNUC__ # elif defined __GNUC__ || \ defined __SUNPRO_C || \ defined __xlC__ @@ -1225,17 +1397,17 @@ private native void allocate(@Const Pointer hostBuffer, #endif */ +// #include // #include // #include -// #include -//DO NOT REMOVE: THIS IS AN EDITOR SEMANTICS THING FOR CLION -//IT DEFINES THE EXPORT MACRO FOR THE EDITOR AND THEN -//RE ADDS THE DEFINITION VIA dll.h -// #ifdef _WIN32 -// #define ND4J_EXPORT __declspec(dllexport) +// DO NOT REMOVE: THIS IS AN EDITOR SEMANTICS THING FOR CLION +// IT DEFINES THE EXPORT MACRO FOR THE EDITOR AND THEN +// RE ADDS THE DEFINITION VIA dll.h +// #ifdef _WIN32 +// #define SD_EXPORT __declspec(dllexport) // #else -// #define ND4J_EXPORT +// #define SD_EXPORT // #endif // #include @@ -1247,17 +1419,16 @@ private native void allocate(@Const Pointer hostBuffer, bool verbose = false; */ -// #include -// #include -// #include // #include +// #include // #include -// #include +// #include // #include -// #include -// #include -// #include // #include +// #include +// #include +// #include +// #include // #include // #include @@ -1293,27 +1464,33 @@ private native void allocate(@Const Pointer hostBuffer, public native void setTADThreshold(int num); /** - * - * @param opNum - * @param x - * @param xShapeInfo - * @param extraParams - */ -public native void execIndexReduceScalar(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongPointer hXShapeInfo, @Cast("const Nd4jLong*") LongPointer dXShapeInfo, - Pointer extraParams, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongPointer hZShapeInfo, @Cast("const Nd4jLong*") LongPointer dZShapeInfo); -public native void execIndexReduceScalar(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, - Pointer extraParams, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, @Cast("const Nd4jLong*") LongBuffer dZShapeInfo); -public native void execIndexReduceScalar(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") long[] hXShapeInfo, @Cast("const Nd4jLong*") long[] dXShapeInfo, - Pointer extraParams, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") long[] hZShapeInfo, @Cast("const Nd4jLong*") long[] dZShapeInfo); + * + * @param opNum + * @param x + * @param xShapeInfo + * @param extraParams + */ +public native void execIndexReduceScalar(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, + OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongPointer hXShapeInfo, + @Cast("const Nd4jLong*") LongPointer dXShapeInfo, + Pointer extraParams, OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") LongPointer hZShapeInfo, + @Cast("const Nd4jLong*") LongPointer dZShapeInfo); +public native void execIndexReduceScalar(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, + OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, + @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, + Pointer extraParams, OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, + @Cast("const Nd4jLong*") LongBuffer dZShapeInfo); +public native void execIndexReduceScalar(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, + OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") long[] hXShapeInfo, + @Cast("const Nd4jLong*") long[] dXShapeInfo, + Pointer extraParams, OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") long[] hZShapeInfo, + @Cast("const Nd4jLong*") long[] dZShapeInfo); /** * @@ -1326,24 +1503,24 @@ public native void execIndexReduceScalar(@Cast("Nd4jPointer*") PointerPointer ex * @param dimension * @param dimensionLength */ -public native void execIndexReduce(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongPointer hXShapeInfo, @Cast("const Nd4jLong*") LongPointer dXShapeInfo, - Pointer extraParams, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongPointer hZShapeInfo, @Cast("const Nd4jLong*") LongPointer dZShapeInfo, - OpaqueDataBuffer dbDimension, @Cast("const Nd4jLong*") LongPointer hDimensionShape, @Cast("const Nd4jLong*") LongPointer dDimensionShape); -public native void execIndexReduce(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, - Pointer extraParams, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, @Cast("const Nd4jLong*") LongBuffer dZShapeInfo, - OpaqueDataBuffer dbDimension, @Cast("const Nd4jLong*") LongBuffer hDimensionShape, @Cast("const Nd4jLong*") LongBuffer dDimensionShape); -public native void execIndexReduce(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") long[] hXShapeInfo, @Cast("const Nd4jLong*") long[] dXShapeInfo, - Pointer extraParams, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") long[] hZShapeInfo, @Cast("const Nd4jLong*") long[] dZShapeInfo, - OpaqueDataBuffer dbDimension, @Cast("const Nd4jLong*") long[] hDimensionShape, @Cast("const Nd4jLong*") long[] dDimensionShape); +public native void execIndexReduce( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongPointer hXShapeInfo, @Cast("const Nd4jLong*") LongPointer dXShapeInfo, Pointer extraParams, + OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongPointer hZShapeInfo, + @Cast("const Nd4jLong*") LongPointer dZShapeInfo, OpaqueDataBuffer dbDimension, + @Cast("const Nd4jLong*") LongPointer hDimensionShape, @Cast("const Nd4jLong*") LongPointer dDimensionShape); +public native void execIndexReduce( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, Pointer extraParams, + OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, + @Cast("const Nd4jLong*") LongBuffer dZShapeInfo, OpaqueDataBuffer dbDimension, + @Cast("const Nd4jLong*") LongBuffer hDimensionShape, @Cast("const Nd4jLong*") LongBuffer dDimensionShape); +public native void execIndexReduce( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") long[] hXShapeInfo, @Cast("const Nd4jLong*") long[] dXShapeInfo, Pointer extraParams, + OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") long[] hZShapeInfo, + @Cast("const Nd4jLong*") long[] dZShapeInfo, OpaqueDataBuffer dbDimension, + @Cast("const Nd4jLong*") long[] hDimensionShape, @Cast("const Nd4jLong*") long[] dDimensionShape); /** * @@ -1357,53 +1534,61 @@ public native void execIndexReduce(@Cast("Nd4jPointer*") PointerPointer extraPoi * @param dimension * @param dimensionLength */ -public native void execBroadcast( - @Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongPointer hXShapeInfo, @Cast("const Nd4jLong*") LongPointer dXShapeInfo, - OpaqueDataBuffer dbY, @Cast("const Nd4jLong*") LongPointer hYShapeInfo, @Cast("const Nd4jLong*") LongPointer dYShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongPointer hZShapeInfo, @Cast("const Nd4jLong*") LongPointer dZShapeInfo, - OpaqueDataBuffer dbDimension, @Cast("const Nd4jLong*") LongPointer hDimensionShape, @Cast("const Nd4jLong*") LongPointer dDimensionShape); -public native void execBroadcast( - @Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, - OpaqueDataBuffer dbY, @Cast("const Nd4jLong*") LongBuffer hYShapeInfo, @Cast("const Nd4jLong*") LongBuffer dYShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, @Cast("const Nd4jLong*") LongBuffer dZShapeInfo, - OpaqueDataBuffer dbDimension, @Cast("const Nd4jLong*") LongBuffer hDimensionShape, @Cast("const Nd4jLong*") LongBuffer dDimensionShape); -public native void execBroadcast( - @Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") long[] hXShapeInfo, @Cast("const Nd4jLong*") long[] dXShapeInfo, - OpaqueDataBuffer dbY, @Cast("const Nd4jLong*") long[] hYShapeInfo, @Cast("const Nd4jLong*") long[] dYShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") long[] hZShapeInfo, @Cast("const Nd4jLong*") long[] dZShapeInfo, - OpaqueDataBuffer dbDimension, @Cast("const Nd4jLong*") long[] hDimensionShape, @Cast("const Nd4jLong*") long[] dDimensionShape); - +public native void execBroadcast(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, + OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongPointer hXShapeInfo, + @Cast("const Nd4jLong*") LongPointer dXShapeInfo, OpaqueDataBuffer dbY, + @Cast("const Nd4jLong*") LongPointer hYShapeInfo, + @Cast("const Nd4jLong*") LongPointer dYShapeInfo, OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") LongPointer hZShapeInfo, + @Cast("const Nd4jLong*") LongPointer dZShapeInfo, + OpaqueDataBuffer dbDimension, + @Cast("const Nd4jLong*") LongPointer hDimensionShape, + @Cast("const Nd4jLong*") LongPointer dDimensionShape); +public native void execBroadcast(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, + OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, + @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, OpaqueDataBuffer dbY, + @Cast("const Nd4jLong*") LongBuffer hYShapeInfo, + @Cast("const Nd4jLong*") LongBuffer dYShapeInfo, OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, + @Cast("const Nd4jLong*") LongBuffer dZShapeInfo, + OpaqueDataBuffer dbDimension, + @Cast("const Nd4jLong*") LongBuffer hDimensionShape, + @Cast("const Nd4jLong*") LongBuffer dDimensionShape); +public native void execBroadcast(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, + OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") long[] hXShapeInfo, + @Cast("const Nd4jLong*") long[] dXShapeInfo, OpaqueDataBuffer dbY, + @Cast("const Nd4jLong*") long[] hYShapeInfo, + @Cast("const Nd4jLong*") long[] dYShapeInfo, OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") long[] hZShapeInfo, + @Cast("const Nd4jLong*") long[] dZShapeInfo, + OpaqueDataBuffer dbDimension, + @Cast("const Nd4jLong*") long[] hDimensionShape, + @Cast("const Nd4jLong*") long[] dDimensionShape); public native void execBroadcastBool( - @Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongPointer hXShapeInfo, @Cast("const Nd4jLong*") LongPointer dXShapeInfo, - OpaqueDataBuffer dbY, @Cast("const Nd4jLong*") LongPointer hYShapeInfo, @Cast("const Nd4jLong*") LongPointer dYShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongPointer hZShapeInfo, @Cast("const Nd4jLong*") LongPointer dZShapeInfo, - Pointer extraParams, - OpaqueDataBuffer dbDimension, @Cast("const Nd4jLong*") LongPointer hDimensionShape, @Cast("const Nd4jLong*") LongPointer dDimensionShape); + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongPointer hXShapeInfo, @Cast("const Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbY, @Cast("const Nd4jLong*") LongPointer hYShapeInfo, + @Cast("const Nd4jLong*") LongPointer dYShapeInfo, OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") LongPointer hZShapeInfo, @Cast("const Nd4jLong*") LongPointer dZShapeInfo, Pointer extraParams, + OpaqueDataBuffer dbDimension, @Cast("const Nd4jLong*") LongPointer hDimensionShape, + @Cast("const Nd4jLong*") LongPointer dDimensionShape); public native void execBroadcastBool( - @Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, - OpaqueDataBuffer dbY, @Cast("const Nd4jLong*") LongBuffer hYShapeInfo, @Cast("const Nd4jLong*") LongBuffer dYShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, @Cast("const Nd4jLong*") LongBuffer dZShapeInfo, - Pointer extraParams, - OpaqueDataBuffer dbDimension, @Cast("const Nd4jLong*") LongBuffer hDimensionShape, @Cast("const Nd4jLong*") LongBuffer dDimensionShape); + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbY, @Cast("const Nd4jLong*") LongBuffer hYShapeInfo, + @Cast("const Nd4jLong*") LongBuffer dYShapeInfo, OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, @Cast("const Nd4jLong*") LongBuffer dZShapeInfo, Pointer extraParams, + OpaqueDataBuffer dbDimension, @Cast("const Nd4jLong*") LongBuffer hDimensionShape, + @Cast("const Nd4jLong*") LongBuffer dDimensionShape); public native void execBroadcastBool( - @Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") long[] hXShapeInfo, @Cast("const Nd4jLong*") long[] dXShapeInfo, - OpaqueDataBuffer dbY, @Cast("const Nd4jLong*") long[] hYShapeInfo, @Cast("const Nd4jLong*") long[] dYShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") long[] hZShapeInfo, @Cast("const Nd4jLong*") long[] dZShapeInfo, - Pointer extraParams, - OpaqueDataBuffer dbDimension, @Cast("const Nd4jLong*") long[] hDimensionShape, @Cast("const Nd4jLong*") long[] dDimensionShape); + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") long[] hXShapeInfo, @Cast("const Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbY, @Cast("const Nd4jLong*") long[] hYShapeInfo, + @Cast("const Nd4jLong*") long[] dYShapeInfo, OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") long[] hZShapeInfo, @Cast("const Nd4jLong*") long[] dZShapeInfo, Pointer extraParams, + OpaqueDataBuffer dbDimension, @Cast("const Nd4jLong*") long[] hDimensionShape, + @Cast("const Nd4jLong*") long[] dDimensionShape); /** * @@ -1418,48 +1603,48 @@ public native void execBroadcastBool( * @param n */ public native void execPairwiseTransform( - @Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongPointer hXShapeInfo, @Cast("const Nd4jLong*") LongPointer dXShapeInfo, - OpaqueDataBuffer dbY, @Cast("const Nd4jLong*") LongPointer hYShapeInfo, @Cast("const Nd4jLong*") LongPointer dYShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongPointer hZShapeInfo, @Cast("const Nd4jLong*") LongPointer dZShapeInfo, - Pointer extraParams); + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongPointer hXShapeInfo, @Cast("const Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbY, @Cast("const Nd4jLong*") LongPointer hYShapeInfo, + @Cast("const Nd4jLong*") LongPointer dYShapeInfo, OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") LongPointer hZShapeInfo, @Cast("const Nd4jLong*") LongPointer dZShapeInfo, + Pointer extraParams); public native void execPairwiseTransform( - @Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, - OpaqueDataBuffer dbY, @Cast("const Nd4jLong*") LongBuffer hYShapeInfo, @Cast("const Nd4jLong*") LongBuffer dYShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, @Cast("const Nd4jLong*") LongBuffer dZShapeInfo, - Pointer extraParams); + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbY, @Cast("const Nd4jLong*") LongBuffer hYShapeInfo, + @Cast("const Nd4jLong*") LongBuffer dYShapeInfo, OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, @Cast("const Nd4jLong*") LongBuffer dZShapeInfo, + Pointer extraParams); public native void execPairwiseTransform( - @Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") long[] hXShapeInfo, @Cast("const Nd4jLong*") long[] dXShapeInfo, - OpaqueDataBuffer dbY, @Cast("const Nd4jLong*") long[] hYShapeInfo, @Cast("const Nd4jLong*") long[] dYShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") long[] hZShapeInfo, @Cast("const Nd4jLong*") long[] dZShapeInfo, - Pointer extraParams); + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") long[] hXShapeInfo, @Cast("const Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbY, @Cast("const Nd4jLong*") long[] hYShapeInfo, + @Cast("const Nd4jLong*") long[] dYShapeInfo, OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") long[] hZShapeInfo, @Cast("const Nd4jLong*") long[] dZShapeInfo, + Pointer extraParams); public native void execPairwiseTransformBool( - @Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongPointer hXShapeInfo, @Cast("const Nd4jLong*") LongPointer dXShapeInfo, - OpaqueDataBuffer dbY, @Cast("const Nd4jLong*") LongPointer hYShapeInfo, @Cast("const Nd4jLong*") LongPointer dYShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongPointer hZShapeInfo, @Cast("const Nd4jLong*") LongPointer dZShapeInfo, - Pointer extraParams); + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongPointer hXShapeInfo, @Cast("const Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbY, @Cast("const Nd4jLong*") LongPointer hYShapeInfo, + @Cast("const Nd4jLong*") LongPointer dYShapeInfo, OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") LongPointer hZShapeInfo, @Cast("const Nd4jLong*") LongPointer dZShapeInfo, + Pointer extraParams); public native void execPairwiseTransformBool( - @Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, - OpaqueDataBuffer dbY, @Cast("const Nd4jLong*") LongBuffer hYShapeInfo, @Cast("const Nd4jLong*") LongBuffer dYShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, @Cast("const Nd4jLong*") LongBuffer dZShapeInfo, - Pointer extraParams); + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbY, @Cast("const Nd4jLong*") LongBuffer hYShapeInfo, + @Cast("const Nd4jLong*") LongBuffer dYShapeInfo, OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, @Cast("const Nd4jLong*") LongBuffer dZShapeInfo, + Pointer extraParams); public native void execPairwiseTransformBool( - @Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") long[] hXShapeInfo, @Cast("const Nd4jLong*") long[] dXShapeInfo, - OpaqueDataBuffer dbY, @Cast("const Nd4jLong*") long[] hYShapeInfo, @Cast("const Nd4jLong*") long[] dYShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") long[] hZShapeInfo, @Cast("const Nd4jLong*") long[] dZShapeInfo, - Pointer extraParams); + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") long[] hXShapeInfo, @Cast("const Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbY, @Cast("const Nd4jLong*") long[] hYShapeInfo, + @Cast("const Nd4jLong*") long[] dYShapeInfo, OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") long[] hZShapeInfo, @Cast("const Nd4jLong*") long[] dZShapeInfo, + Pointer extraParams); /** * @@ -1470,70 +1655,93 @@ public native void execPairwiseTransformBool( * @param result * @param resultShapeInfo */ -public native void execReduceFloat(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongPointer hXShapeInfo, @Cast("const Nd4jLong*") LongPointer dXShapeInfo, - Pointer extraParams, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongPointer hZShapeInfo, @Cast("const Nd4jLong*") LongPointer dZShapeInfo); -public native void execReduceFloat(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, - Pointer extraParams, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, @Cast("const Nd4jLong*") LongBuffer dZShapeInfo); -public native void execReduceFloat(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") long[] hXShapeInfo, @Cast("const Nd4jLong*") long[] dXShapeInfo, - Pointer extraParams, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") long[] hZShapeInfo, @Cast("const Nd4jLong*") long[] dZShapeInfo); - -public native void execReduceSame(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongPointer hXShapeInfo, @Cast("const Nd4jLong*") LongPointer dXShapeInfo, - Pointer extraParams, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongPointer hZShapeInfo, @Cast("const Nd4jLong*") LongPointer dZShapeInfo); -public native void execReduceSame(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, - Pointer extraParams, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, @Cast("const Nd4jLong*") LongBuffer dZShapeInfo); -public native void execReduceSame(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") long[] hXShapeInfo, @Cast("const Nd4jLong*") long[] dXShapeInfo, - Pointer extraParams, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") long[] hZShapeInfo, @Cast("const Nd4jLong*") long[] dZShapeInfo); - -public native void execReduceBool(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongPointer hXShapeInfo, @Cast("const Nd4jLong*") LongPointer dXShapeInfo, - Pointer extraParams, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongPointer hZShapeInfo, @Cast("const Nd4jLong*") LongPointer dZShapeInfo); -public native void execReduceBool(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, - Pointer extraParams, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, @Cast("const Nd4jLong*") LongBuffer dZShapeInfo); -public native void execReduceBool(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") long[] hXShapeInfo, @Cast("const Nd4jLong*") long[] dXShapeInfo, - Pointer extraParams, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") long[] hZShapeInfo, @Cast("const Nd4jLong*") long[] dZShapeInfo); - - -public native void execReduceLong(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongPointer hXShapeInfo, @Cast("const Nd4jLong*") LongPointer dXShapeInfo, - Pointer extraParams, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongPointer hZShapeInfo, @Cast("const Nd4jLong*") LongPointer dZShapeInfo); -public native void execReduceLong(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, - Pointer extraParams, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, @Cast("const Nd4jLong*") LongBuffer dZShapeInfo); -public native void execReduceLong(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") long[] hXShapeInfo, @Cast("const Nd4jLong*") long[] dXShapeInfo, - Pointer extraParams, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") long[] hZShapeInfo, @Cast("const Nd4jLong*") long[] dZShapeInfo); +public native void execReduceFloat(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, + OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongPointer hXShapeInfo, + @Cast("const Nd4jLong*") LongPointer dXShapeInfo, Pointer extraParams, + OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") LongPointer hZShapeInfo, + @Cast("const Nd4jLong*") LongPointer dZShapeInfo); +public native void execReduceFloat(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, + OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, + @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, Pointer extraParams, + OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, + @Cast("const Nd4jLong*") LongBuffer dZShapeInfo); +public native void execReduceFloat(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, + OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") long[] hXShapeInfo, + @Cast("const Nd4jLong*") long[] dXShapeInfo, Pointer extraParams, + OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") long[] hZShapeInfo, + @Cast("const Nd4jLong*") long[] dZShapeInfo); + +public native void execReduceSame(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, + OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongPointer hXShapeInfo, + @Cast("const Nd4jLong*") LongPointer dXShapeInfo, Pointer extraParams, + OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") LongPointer hZShapeInfo, + @Cast("const Nd4jLong*") LongPointer dZShapeInfo); +public native void execReduceSame(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, + OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, + @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, Pointer extraParams, + OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, + @Cast("const Nd4jLong*") LongBuffer dZShapeInfo); +public native void execReduceSame(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, + OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") long[] hXShapeInfo, + @Cast("const Nd4jLong*") long[] dXShapeInfo, Pointer extraParams, + OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") long[] hZShapeInfo, + @Cast("const Nd4jLong*") long[] dZShapeInfo); + +public native void execReduceBool(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, + OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongPointer hXShapeInfo, + @Cast("const Nd4jLong*") LongPointer dXShapeInfo, Pointer extraParams, + OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") LongPointer hZShapeInfo, + @Cast("const Nd4jLong*") LongPointer dZShapeInfo); +public native void execReduceBool(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, + OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, + @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, Pointer extraParams, + OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, + @Cast("const Nd4jLong*") LongBuffer dZShapeInfo); +public native void execReduceBool(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, + OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") long[] hXShapeInfo, + @Cast("const Nd4jLong*") long[] dXShapeInfo, Pointer extraParams, + OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") long[] hZShapeInfo, + @Cast("const Nd4jLong*") long[] dZShapeInfo); + +public native void execReduceLong(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, + OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongPointer hXShapeInfo, + @Cast("const Nd4jLong*") LongPointer dXShapeInfo, Pointer extraParams, + OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") LongPointer hZShapeInfo, + @Cast("const Nd4jLong*") LongPointer dZShapeInfo); +public native void execReduceLong(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, + OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, + @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, Pointer extraParams, + OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, + @Cast("const Nd4jLong*") LongBuffer dZShapeInfo); +public native void execReduceLong(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, + OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") long[] hXShapeInfo, + @Cast("const Nd4jLong*") long[] dXShapeInfo, Pointer extraParams, + OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") long[] hZShapeInfo, + @Cast("const Nd4jLong*") long[] dZShapeInfo); /** * @@ -1544,84 +1752,81 @@ public native void execReduceLong(@Cast("Nd4jPointer*") PointerPointer extraPoin * @param result * @param resultShapeInfo */ -public native void execReduceFloat2(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongPointer hXShapeInfo, @Cast("const Nd4jLong*") LongPointer dXShapeInfo, - Pointer extraParams, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongPointer hZShapeInfo, @Cast("const Nd4jLong*") LongPointer dZShapeInfo, - OpaqueDataBuffer dbDimension, @Cast("const Nd4jLong*") LongPointer hDimensionShape, @Cast("const Nd4jLong*") LongPointer dDimensionShape); -public native void execReduceFloat2(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, - Pointer extraParams, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, @Cast("const Nd4jLong*") LongBuffer dZShapeInfo, - OpaqueDataBuffer dbDimension, @Cast("const Nd4jLong*") LongBuffer hDimensionShape, @Cast("const Nd4jLong*") LongBuffer dDimensionShape); -public native void execReduceFloat2(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") long[] hXShapeInfo, @Cast("const Nd4jLong*") long[] dXShapeInfo, - Pointer extraParams, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") long[] hZShapeInfo, @Cast("const Nd4jLong*") long[] dZShapeInfo, - OpaqueDataBuffer dbDimension, @Cast("const Nd4jLong*") long[] hDimensionShape, @Cast("const Nd4jLong*") long[] dDimensionShape); - - -public native void execReduceSame2(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongPointer hXShapeInfo, @Cast("const Nd4jLong*") LongPointer dXShapeInfo, - Pointer extraParams, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongPointer hZShapeInfo, @Cast("const Nd4jLong*") LongPointer dZShapeInfo, - OpaqueDataBuffer dbDimension, @Cast("const Nd4jLong*") LongPointer hDimensionShape, @Cast("const Nd4jLong*") LongPointer dDimensionShape); -public native void execReduceSame2(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, - Pointer extraParams, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, @Cast("const Nd4jLong*") LongBuffer dZShapeInfo, - OpaqueDataBuffer dbDimension, @Cast("const Nd4jLong*") LongBuffer hDimensionShape, @Cast("const Nd4jLong*") LongBuffer dDimensionShape); -public native void execReduceSame2(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") long[] hXShapeInfo, @Cast("const Nd4jLong*") long[] dXShapeInfo, - Pointer extraParams, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") long[] hZShapeInfo, @Cast("const Nd4jLong*") long[] dZShapeInfo, - OpaqueDataBuffer dbDimension, @Cast("const Nd4jLong*") long[] hDimensionShape, @Cast("const Nd4jLong*") long[] dDimensionShape); - - -public native void execReduceBool2(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongPointer hXShapeInfo, @Cast("const Nd4jLong*") LongPointer dXShapeInfo, - Pointer extraParams, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongPointer hZShapeInfo, @Cast("const Nd4jLong*") LongPointer dZShapeInfo, - OpaqueDataBuffer dbDimension, @Cast("const Nd4jLong*") LongPointer hDimensionShape, @Cast("const Nd4jLong*") LongPointer dDimensionShape); -public native void execReduceBool2(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, - Pointer extraParams, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, @Cast("const Nd4jLong*") LongBuffer dZShapeInfo, - OpaqueDataBuffer dbDimension, @Cast("const Nd4jLong*") LongBuffer hDimensionShape, @Cast("const Nd4jLong*") LongBuffer dDimensionShape); -public native void execReduceBool2(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") long[] hXShapeInfo, @Cast("const Nd4jLong*") long[] dXShapeInfo, - Pointer extraParams, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") long[] hZShapeInfo, @Cast("const Nd4jLong*") long[] dZShapeInfo, - OpaqueDataBuffer dbDimension, @Cast("const Nd4jLong*") long[] hDimensionShape, @Cast("const Nd4jLong*") long[] dDimensionShape); - - -public native void execReduceLong2(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongPointer hXShapeInfo, @Cast("const Nd4jLong*") LongPointer dXShapeInfo, - Pointer extraParams, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongPointer hZShapeInfo, @Cast("const Nd4jLong*") LongPointer dZShapeInfo, - OpaqueDataBuffer dbDimension, @Cast("const Nd4jLong*") LongPointer hDimensionShape, @Cast("const Nd4jLong*") LongPointer dDimensionShape); -public native void execReduceLong2(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, - Pointer extraParams, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, @Cast("const Nd4jLong*") LongBuffer dZShapeInfo, - OpaqueDataBuffer dbDimension, @Cast("const Nd4jLong*") LongBuffer hDimensionShape, @Cast("const Nd4jLong*") LongBuffer dDimensionShape); -public native void execReduceLong2(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") long[] hXShapeInfo, @Cast("const Nd4jLong*") long[] dXShapeInfo, - Pointer extraParams, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") long[] hZShapeInfo, @Cast("const Nd4jLong*") long[] dZShapeInfo, - OpaqueDataBuffer dbDimension, @Cast("const Nd4jLong*") long[] hDimensionShape, @Cast("const Nd4jLong*") long[] dDimensionShape); +public native void execReduceFloat2( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongPointer hXShapeInfo, @Cast("const Nd4jLong*") LongPointer dXShapeInfo, Pointer extraParams, + OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongPointer hZShapeInfo, + @Cast("const Nd4jLong*") LongPointer dZShapeInfo, OpaqueDataBuffer dbDimension, + @Cast("const Nd4jLong*") LongPointer hDimensionShape, @Cast("const Nd4jLong*") LongPointer dDimensionShape); +public native void execReduceFloat2( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, Pointer extraParams, + OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, + @Cast("const Nd4jLong*") LongBuffer dZShapeInfo, OpaqueDataBuffer dbDimension, + @Cast("const Nd4jLong*") LongBuffer hDimensionShape, @Cast("const Nd4jLong*") LongBuffer dDimensionShape); +public native void execReduceFloat2( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") long[] hXShapeInfo, @Cast("const Nd4jLong*") long[] dXShapeInfo, Pointer extraParams, + OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") long[] hZShapeInfo, + @Cast("const Nd4jLong*") long[] dZShapeInfo, OpaqueDataBuffer dbDimension, + @Cast("const Nd4jLong*") long[] hDimensionShape, @Cast("const Nd4jLong*") long[] dDimensionShape); + +public native void execReduceSame2( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongPointer hXShapeInfo, @Cast("const Nd4jLong*") LongPointer dXShapeInfo, Pointer extraParams, + OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongPointer hZShapeInfo, + @Cast("const Nd4jLong*") LongPointer dZShapeInfo, OpaqueDataBuffer dbDimension, + @Cast("const Nd4jLong*") LongPointer hDimensionShape, @Cast("const Nd4jLong*") LongPointer dDimensionShape); +public native void execReduceSame2( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, Pointer extraParams, + OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, + @Cast("const Nd4jLong*") LongBuffer dZShapeInfo, OpaqueDataBuffer dbDimension, + @Cast("const Nd4jLong*") LongBuffer hDimensionShape, @Cast("const Nd4jLong*") LongBuffer dDimensionShape); +public native void execReduceSame2( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") long[] hXShapeInfo, @Cast("const Nd4jLong*") long[] dXShapeInfo, Pointer extraParams, + OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") long[] hZShapeInfo, + @Cast("const Nd4jLong*") long[] dZShapeInfo, OpaqueDataBuffer dbDimension, + @Cast("const Nd4jLong*") long[] hDimensionShape, @Cast("const Nd4jLong*") long[] dDimensionShape); + +public native void execReduceBool2( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongPointer hXShapeInfo, @Cast("const Nd4jLong*") LongPointer dXShapeInfo, Pointer extraParams, + OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongPointer hZShapeInfo, + @Cast("const Nd4jLong*") LongPointer dZShapeInfo, OpaqueDataBuffer dbDimension, + @Cast("const Nd4jLong*") LongPointer hDimensionShape, @Cast("const Nd4jLong*") LongPointer dDimensionShape); +public native void execReduceBool2( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, Pointer extraParams, + OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, + @Cast("const Nd4jLong*") LongBuffer dZShapeInfo, OpaqueDataBuffer dbDimension, + @Cast("const Nd4jLong*") LongBuffer hDimensionShape, @Cast("const Nd4jLong*") LongBuffer dDimensionShape); +public native void execReduceBool2( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") long[] hXShapeInfo, @Cast("const Nd4jLong*") long[] dXShapeInfo, Pointer extraParams, + OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") long[] hZShapeInfo, + @Cast("const Nd4jLong*") long[] dZShapeInfo, OpaqueDataBuffer dbDimension, + @Cast("const Nd4jLong*") long[] hDimensionShape, @Cast("const Nd4jLong*") long[] dDimensionShape); + +public native void execReduceLong2( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongPointer hXShapeInfo, @Cast("const Nd4jLong*") LongPointer dXShapeInfo, Pointer extraParams, + OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongPointer hZShapeInfo, + @Cast("const Nd4jLong*") LongPointer dZShapeInfo, OpaqueDataBuffer dbDimension, + @Cast("const Nd4jLong*") LongPointer hDimensionShape, @Cast("const Nd4jLong*") LongPointer dDimensionShape); +public native void execReduceLong2( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, Pointer extraParams, + OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, + @Cast("const Nd4jLong*") LongBuffer dZShapeInfo, OpaqueDataBuffer dbDimension, + @Cast("const Nd4jLong*") LongBuffer hDimensionShape, @Cast("const Nd4jLong*") LongBuffer dDimensionShape); +public native void execReduceLong2( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") long[] hXShapeInfo, @Cast("const Nd4jLong*") long[] dXShapeInfo, Pointer extraParams, + OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") long[] hZShapeInfo, + @Cast("const Nd4jLong*") long[] dZShapeInfo, OpaqueDataBuffer dbDimension, + @Cast("const Nd4jLong*") long[] hDimensionShape, @Cast("const Nd4jLong*") long[] dDimensionShape); /** * @@ -1634,24 +1839,27 @@ public native void execReduceLong2(@Cast("Nd4jPointer*") PointerPointer extraPoi * @param result * @param resultShapeInfo */ -public native void execReduce3(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongPointer hXShapeInfo, @Cast("const Nd4jLong*") LongPointer dXShapeInfo, - Pointer extraParamsVals, - OpaqueDataBuffer dbY, @Cast("const Nd4jLong*") LongPointer hYShapeInfo, @Cast("const Nd4jLong*") LongPointer dYShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongPointer hZShapeInfo, @Cast("const Nd4jLong*") LongPointer dZShapeInfo); -public native void execReduce3(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, - Pointer extraParamsVals, - OpaqueDataBuffer dbY, @Cast("const Nd4jLong*") LongBuffer hYShapeInfo, @Cast("const Nd4jLong*") LongBuffer dYShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, @Cast("const Nd4jLong*") LongBuffer dZShapeInfo); -public native void execReduce3(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") long[] hXShapeInfo, @Cast("const Nd4jLong*") long[] dXShapeInfo, - Pointer extraParamsVals, - OpaqueDataBuffer dbY, @Cast("const Nd4jLong*") long[] hYShapeInfo, @Cast("const Nd4jLong*") long[] dYShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") long[] hZShapeInfo, @Cast("const Nd4jLong*") long[] dZShapeInfo); +public native void execReduce3(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, + OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongPointer hXShapeInfo, + @Cast("const Nd4jLong*") LongPointer dXShapeInfo, Pointer extraParamsVals, + OpaqueDataBuffer dbY, @Cast("const Nd4jLong*") LongPointer hYShapeInfo, + @Cast("const Nd4jLong*") LongPointer dYShapeInfo, OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") LongPointer hZShapeInfo, + @Cast("const Nd4jLong*") LongPointer dZShapeInfo); +public native void execReduce3(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, + OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, + @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, Pointer extraParamsVals, + OpaqueDataBuffer dbY, @Cast("const Nd4jLong*") LongBuffer hYShapeInfo, + @Cast("const Nd4jLong*") LongBuffer dYShapeInfo, OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, + @Cast("const Nd4jLong*") LongBuffer dZShapeInfo); +public native void execReduce3(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, + OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") long[] hXShapeInfo, + @Cast("const Nd4jLong*") long[] dXShapeInfo, Pointer extraParamsVals, + OpaqueDataBuffer dbY, @Cast("const Nd4jLong*") long[] hYShapeInfo, + @Cast("const Nd4jLong*") long[] dYShapeInfo, OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") long[] hZShapeInfo, + @Cast("const Nd4jLong*") long[] dZShapeInfo); /** * @@ -1662,24 +1870,24 @@ public native void execReduce3(@Cast("Nd4jPointer*") PointerPointer extraPointer * @param y * @param yShapeInfo */ -public native void execReduce3Scalar(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongPointer hXShapeInfo, @Cast("const Nd4jLong*") LongPointer dXShapeInfo, - Pointer extraParamsVals, - OpaqueDataBuffer dbY, @Cast("const Nd4jLong*") LongPointer hYShapeInfo, @Cast("const Nd4jLong*") LongPointer dYShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongPointer hZShapeInfo, @Cast("const Nd4jLong*") LongPointer dZShapeInfo); -public native void execReduce3Scalar(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, - Pointer extraParamsVals, - OpaqueDataBuffer dbY, @Cast("const Nd4jLong*") LongBuffer hYShapeInfo, @Cast("const Nd4jLong*") LongBuffer dYShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, @Cast("const Nd4jLong*") LongBuffer dZShapeInfo); -public native void execReduce3Scalar(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") long[] hXShapeInfo, @Cast("const Nd4jLong*") long[] dXShapeInfo, - Pointer extraParamsVals, - OpaqueDataBuffer dbY, @Cast("const Nd4jLong*") long[] hYShapeInfo, @Cast("const Nd4jLong*") long[] dYShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") long[] hZShapeInfo, @Cast("const Nd4jLong*") long[] dZShapeInfo); +public native void execReduce3Scalar( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongPointer hXShapeInfo, @Cast("const Nd4jLong*") LongPointer dXShapeInfo, + Pointer extraParamsVals, OpaqueDataBuffer dbY, @Cast("const Nd4jLong*") LongPointer hYShapeInfo, + @Cast("const Nd4jLong*") LongPointer dYShapeInfo, OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") LongPointer hZShapeInfo, @Cast("const Nd4jLong*") LongPointer dZShapeInfo); +public native void execReduce3Scalar( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, + Pointer extraParamsVals, OpaqueDataBuffer dbY, @Cast("const Nd4jLong*") LongBuffer hYShapeInfo, + @Cast("const Nd4jLong*") LongBuffer dYShapeInfo, OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, @Cast("const Nd4jLong*") LongBuffer dZShapeInfo); +public native void execReduce3Scalar( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") long[] hXShapeInfo, @Cast("const Nd4jLong*") long[] dXShapeInfo, + Pointer extraParamsVals, OpaqueDataBuffer dbY, @Cast("const Nd4jLong*") long[] hYShapeInfo, + @Cast("const Nd4jLong*") long[] dYShapeInfo, OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") long[] hZShapeInfo, @Cast("const Nd4jLong*") long[] dZShapeInfo); /** * * @param opNum @@ -1693,62 +1901,67 @@ public native void execReduce3Scalar(@Cast("Nd4jPointer*") PointerPointer extraP * @param dimension * @param dimensionLength */ -public native void execReduce3Tad(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongPointer hXShapeInfo, @Cast("const Nd4jLong*") LongPointer dXShapeInfo, - Pointer extraParamsVals, - OpaqueDataBuffer dbY, @Cast("const Nd4jLong*") LongPointer hYShapeInfo, @Cast("const Nd4jLong*") LongPointer dYShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongPointer hZShapeInfo, @Cast("const Nd4jLong*") LongPointer dZShapeInfo, - OpaqueDataBuffer dbDimension, @Cast("const Nd4jLong*") LongPointer hDimensionShape, @Cast("const Nd4jLong*") LongPointer dDimensionShape, - @Cast("const Nd4jLong*") LongPointer tadOnlyShapeInfo, @Cast("const Nd4jLong*") LongPointer tadOffsets, - @Cast("const Nd4jLong*") LongPointer yTadOnlyShapeInfo, @Cast("const Nd4jLong*") LongPointer yTadOffsets); -public native void execReduce3Tad(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, - Pointer extraParamsVals, - OpaqueDataBuffer dbY, @Cast("const Nd4jLong*") LongBuffer hYShapeInfo, @Cast("const Nd4jLong*") LongBuffer dYShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, @Cast("const Nd4jLong*") LongBuffer dZShapeInfo, - OpaqueDataBuffer dbDimension, @Cast("const Nd4jLong*") LongBuffer hDimensionShape, @Cast("const Nd4jLong*") LongBuffer dDimensionShape, - @Cast("const Nd4jLong*") LongBuffer tadOnlyShapeInfo, @Cast("const Nd4jLong*") LongBuffer tadOffsets, - @Cast("const Nd4jLong*") LongBuffer yTadOnlyShapeInfo, @Cast("const Nd4jLong*") LongBuffer yTadOffsets); -public native void execReduce3Tad(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") long[] hXShapeInfo, @Cast("const Nd4jLong*") long[] dXShapeInfo, - Pointer extraParamsVals, - OpaqueDataBuffer dbY, @Cast("const Nd4jLong*") long[] hYShapeInfo, @Cast("const Nd4jLong*") long[] dYShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") long[] hZShapeInfo, @Cast("const Nd4jLong*") long[] dZShapeInfo, - OpaqueDataBuffer dbDimension, @Cast("const Nd4jLong*") long[] hDimensionShape, @Cast("const Nd4jLong*") long[] dDimensionShape, - @Cast("const Nd4jLong*") long[] tadOnlyShapeInfo, @Cast("const Nd4jLong*") long[] tadOffsets, - @Cast("const Nd4jLong*") long[] yTadOnlyShapeInfo, @Cast("const Nd4jLong*") long[] yTadOffsets); - - -public native void execReduce3All(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongPointer hXShapeInfo, @Cast("const Nd4jLong*") LongPointer dXShapeInfo, - Pointer extraParamsVals, - OpaqueDataBuffer dbY, @Cast("const Nd4jLong*") LongPointer hYShapeInfo, @Cast("const Nd4jLong*") LongPointer dYShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongPointer hZShapeInfo, @Cast("const Nd4jLong*") LongPointer dZShapeInfo, - OpaqueDataBuffer dbDimension, @Cast("const Nd4jLong*") LongPointer hDimensionShape, @Cast("const Nd4jLong*") LongPointer dDimensionShape, - @Cast("const Nd4jLong*") LongPointer xTadShapeInfo, @Cast("const Nd4jLong*") LongPointer xOffsets, - @Cast("const Nd4jLong*") LongPointer yTadShapeInfo, @Cast("const Nd4jLong*") LongPointer yOffsets); -public native void execReduce3All(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, - Pointer extraParamsVals, - OpaqueDataBuffer dbY, @Cast("const Nd4jLong*") LongBuffer hYShapeInfo, @Cast("const Nd4jLong*") LongBuffer dYShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, @Cast("const Nd4jLong*") LongBuffer dZShapeInfo, - OpaqueDataBuffer dbDimension, @Cast("const Nd4jLong*") LongBuffer hDimensionShape, @Cast("const Nd4jLong*") LongBuffer dDimensionShape, - @Cast("const Nd4jLong*") LongBuffer xTadShapeInfo, @Cast("const Nd4jLong*") LongBuffer xOffsets, - @Cast("const Nd4jLong*") LongBuffer yTadShapeInfo, @Cast("const Nd4jLong*") LongBuffer yOffsets); -public native void execReduce3All(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") long[] hXShapeInfo, @Cast("const Nd4jLong*") long[] dXShapeInfo, - Pointer extraParamsVals, - OpaqueDataBuffer dbY, @Cast("const Nd4jLong*") long[] hYShapeInfo, @Cast("const Nd4jLong*") long[] dYShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") long[] hZShapeInfo, @Cast("const Nd4jLong*") long[] dZShapeInfo, - OpaqueDataBuffer dbDimension, @Cast("const Nd4jLong*") long[] hDimensionShape, @Cast("const Nd4jLong*") long[] dDimensionShape, - @Cast("const Nd4jLong*") long[] xTadShapeInfo, @Cast("const Nd4jLong*") long[] xOffsets, - @Cast("const Nd4jLong*") long[] yTadShapeInfo, @Cast("const Nd4jLong*") long[] yOffsets); +public native void execReduce3Tad( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongPointer hXShapeInfo, @Cast("const Nd4jLong*") LongPointer dXShapeInfo, + Pointer extraParamsVals, OpaqueDataBuffer dbY, @Cast("const Nd4jLong*") LongPointer hYShapeInfo, + @Cast("const Nd4jLong*") LongPointer dYShapeInfo, OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") LongPointer hZShapeInfo, @Cast("const Nd4jLong*") LongPointer dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("const Nd4jLong*") LongPointer hDimensionShape, + @Cast("const Nd4jLong*") LongPointer dDimensionShape, @Cast("const Nd4jLong*") LongPointer tadOnlyShapeInfo, + @Cast("const Nd4jLong*") LongPointer tadOffsets, @Cast("const Nd4jLong*") LongPointer yTadOnlyShapeInfo, + @Cast("const Nd4jLong*") LongPointer yTadOffsets); +public native void execReduce3Tad( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, + Pointer extraParamsVals, OpaqueDataBuffer dbY, @Cast("const Nd4jLong*") LongBuffer hYShapeInfo, + @Cast("const Nd4jLong*") LongBuffer dYShapeInfo, OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, @Cast("const Nd4jLong*") LongBuffer dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("const Nd4jLong*") LongBuffer hDimensionShape, + @Cast("const Nd4jLong*") LongBuffer dDimensionShape, @Cast("const Nd4jLong*") LongBuffer tadOnlyShapeInfo, + @Cast("const Nd4jLong*") LongBuffer tadOffsets, @Cast("const Nd4jLong*") LongBuffer yTadOnlyShapeInfo, + @Cast("const Nd4jLong*") LongBuffer yTadOffsets); +public native void execReduce3Tad( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") long[] hXShapeInfo, @Cast("const Nd4jLong*") long[] dXShapeInfo, + Pointer extraParamsVals, OpaqueDataBuffer dbY, @Cast("const Nd4jLong*") long[] hYShapeInfo, + @Cast("const Nd4jLong*") long[] dYShapeInfo, OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") long[] hZShapeInfo, @Cast("const Nd4jLong*") long[] dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("const Nd4jLong*") long[] hDimensionShape, + @Cast("const Nd4jLong*") long[] dDimensionShape, @Cast("const Nd4jLong*") long[] tadOnlyShapeInfo, + @Cast("const Nd4jLong*") long[] tadOffsets, @Cast("const Nd4jLong*") long[] yTadOnlyShapeInfo, + @Cast("const Nd4jLong*") long[] yTadOffsets); + +public native void execReduce3All( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongPointer hXShapeInfo, @Cast("const Nd4jLong*") LongPointer dXShapeInfo, + Pointer extraParamsVals, OpaqueDataBuffer dbY, @Cast("const Nd4jLong*") LongPointer hYShapeInfo, + @Cast("const Nd4jLong*") LongPointer dYShapeInfo, OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") LongPointer hZShapeInfo, @Cast("const Nd4jLong*") LongPointer dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("const Nd4jLong*") LongPointer hDimensionShape, + @Cast("const Nd4jLong*") LongPointer dDimensionShape, @Cast("const Nd4jLong*") LongPointer xTadShapeInfo, + @Cast("const Nd4jLong*") LongPointer xOffsets, @Cast("const Nd4jLong*") LongPointer yTadShapeInfo, + @Cast("const Nd4jLong*") LongPointer yOffsets); +public native void execReduce3All( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, + Pointer extraParamsVals, OpaqueDataBuffer dbY, @Cast("const Nd4jLong*") LongBuffer hYShapeInfo, + @Cast("const Nd4jLong*") LongBuffer dYShapeInfo, OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, @Cast("const Nd4jLong*") LongBuffer dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("const Nd4jLong*") LongBuffer hDimensionShape, + @Cast("const Nd4jLong*") LongBuffer dDimensionShape, @Cast("const Nd4jLong*") LongBuffer xTadShapeInfo, + @Cast("const Nd4jLong*") LongBuffer xOffsets, @Cast("const Nd4jLong*") LongBuffer yTadShapeInfo, + @Cast("const Nd4jLong*") LongBuffer yOffsets); +public native void execReduce3All( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") long[] hXShapeInfo, @Cast("const Nd4jLong*") long[] dXShapeInfo, + Pointer extraParamsVals, OpaqueDataBuffer dbY, @Cast("const Nd4jLong*") long[] hYShapeInfo, + @Cast("const Nd4jLong*") long[] dYShapeInfo, OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") long[] hZShapeInfo, @Cast("const Nd4jLong*") long[] dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("const Nd4jLong*") long[] hDimensionShape, + @Cast("const Nd4jLong*") long[] dDimensionShape, @Cast("const Nd4jLong*") long[] xTadShapeInfo, + @Cast("const Nd4jLong*") long[] xOffsets, @Cast("const Nd4jLong*") long[] yTadShapeInfo, + @Cast("const Nd4jLong*") long[] yOffsets); /** * @@ -1761,43 +1974,52 @@ public native void execReduce3All(@Cast("Nd4jPointer*") PointerPointer extraPoin * @param extraParams * @param n */ -public native void execScalar(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongPointer hXShapeInfo, @Cast("const Nd4jLong*") LongPointer dXShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongPointer hZShapeInfo, @Cast("const Nd4jLong*") LongPointer dZShapeInfo, - OpaqueDataBuffer dbScalar, @Cast("const Nd4jLong*") LongPointer hSscalarShapeInfo, @Cast("const Nd4jLong*") LongPointer dSscalarShapeInfo, - Pointer extraParams); -public native void execScalar(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, @Cast("const Nd4jLong*") LongBuffer dZShapeInfo, - OpaqueDataBuffer dbScalar, @Cast("const Nd4jLong*") LongBuffer hSscalarShapeInfo, @Cast("const Nd4jLong*") LongBuffer dSscalarShapeInfo, - Pointer extraParams); -public native void execScalar(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") long[] hXShapeInfo, @Cast("const Nd4jLong*") long[] dXShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") long[] hZShapeInfo, @Cast("const Nd4jLong*") long[] dZShapeInfo, - OpaqueDataBuffer dbScalar, @Cast("const Nd4jLong*") long[] hSscalarShapeInfo, @Cast("const Nd4jLong*") long[] dSscalarShapeInfo, - Pointer extraParams); - -public native void execScalarBool(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongPointer hXShapeInfo, @Cast("const Nd4jLong*") LongPointer dXShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongPointer hZShapeInfo, @Cast("const Nd4jLong*") LongPointer dZShapeInfo, - OpaqueDataBuffer dbScalar, @Cast("const Nd4jLong*") LongPointer hSscalarShapeInfo, @Cast("const Nd4jLong*") LongPointer dSscalarShapeInfo, - Pointer extraParams); -public native void execScalarBool(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, @Cast("const Nd4jLong*") LongBuffer dZShapeInfo, - OpaqueDataBuffer dbScalar, @Cast("const Nd4jLong*") LongBuffer hSscalarShapeInfo, @Cast("const Nd4jLong*") LongBuffer dSscalarShapeInfo, - Pointer extraParams); -public native void execScalarBool(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") long[] hXShapeInfo, @Cast("const Nd4jLong*") long[] dXShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") long[] hZShapeInfo, @Cast("const Nd4jLong*") long[] dZShapeInfo, - OpaqueDataBuffer dbScalar, @Cast("const Nd4jLong*") long[] hSscalarShapeInfo, @Cast("const Nd4jLong*") long[] dSscalarShapeInfo, - Pointer extraParams); +public native void execScalar(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, + OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongPointer hXShapeInfo, + @Cast("const Nd4jLong*") LongPointer dXShapeInfo, OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") LongPointer hZShapeInfo, + @Cast("const Nd4jLong*") LongPointer dZShapeInfo, + OpaqueDataBuffer dbScalar, + @Cast("const Nd4jLong*") LongPointer hSscalarShapeInfo, + @Cast("const Nd4jLong*") LongPointer dSscalarShapeInfo, Pointer extraParams); +public native void execScalar(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, + OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, + @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, + @Cast("const Nd4jLong*") LongBuffer dZShapeInfo, + OpaqueDataBuffer dbScalar, + @Cast("const Nd4jLong*") LongBuffer hSscalarShapeInfo, + @Cast("const Nd4jLong*") LongBuffer dSscalarShapeInfo, Pointer extraParams); +public native void execScalar(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, + OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") long[] hXShapeInfo, + @Cast("const Nd4jLong*") long[] dXShapeInfo, OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") long[] hZShapeInfo, + @Cast("const Nd4jLong*") long[] dZShapeInfo, + OpaqueDataBuffer dbScalar, + @Cast("const Nd4jLong*") long[] hSscalarShapeInfo, + @Cast("const Nd4jLong*") long[] dSscalarShapeInfo, Pointer extraParams); + +public native void execScalarBool( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongPointer hXShapeInfo, @Cast("const Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongPointer hZShapeInfo, + @Cast("const Nd4jLong*") LongPointer dZShapeInfo, OpaqueDataBuffer dbScalar, + @Cast("const Nd4jLong*") LongPointer hSscalarShapeInfo, @Cast("const Nd4jLong*") LongPointer dSscalarShapeInfo, + Pointer extraParams); +public native void execScalarBool( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, + @Cast("const Nd4jLong*") LongBuffer dZShapeInfo, OpaqueDataBuffer dbScalar, + @Cast("const Nd4jLong*") LongBuffer hSscalarShapeInfo, @Cast("const Nd4jLong*") LongBuffer dSscalarShapeInfo, + Pointer extraParams); +public native void execScalarBool( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") long[] hXShapeInfo, @Cast("const Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") long[] hZShapeInfo, + @Cast("const Nd4jLong*") long[] dZShapeInfo, OpaqueDataBuffer dbScalar, + @Cast("const Nd4jLong*") long[] hSscalarShapeInfo, @Cast("const Nd4jLong*") long[] dSscalarShapeInfo, + Pointer extraParams); /** * @@ -1806,24 +2028,21 @@ public native void execScalarBool(@Cast("Nd4jPointer*") PointerPointer extraPoin * @param xShapeInfo * @param extraParams */ -public native void execSummaryStatsScalar(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongPointer hXShapeInfo, @Cast("const Nd4jLong*") LongPointer dXShapeInfo, - Pointer extraParams, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongPointer hZShapeInfo, @Cast("const Nd4jLong*") LongPointer dZShapeInfo, - @Cast("bool") boolean biasCorrected); -public native void execSummaryStatsScalar(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, - Pointer extraParams, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, @Cast("const Nd4jLong*") LongBuffer dZShapeInfo, - @Cast("bool") boolean biasCorrected); -public native void execSummaryStatsScalar(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") long[] hXShapeInfo, @Cast("const Nd4jLong*") long[] dXShapeInfo, - Pointer extraParams, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") long[] hZShapeInfo, @Cast("const Nd4jLong*") long[] dZShapeInfo, - @Cast("bool") boolean biasCorrected); +public native void execSummaryStatsScalar( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongPointer hXShapeInfo, @Cast("const Nd4jLong*") LongPointer dXShapeInfo, Pointer extraParams, + OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongPointer hZShapeInfo, + @Cast("const Nd4jLong*") LongPointer dZShapeInfo, @Cast("bool") boolean biasCorrected); +public native void execSummaryStatsScalar( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, Pointer extraParams, + OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, + @Cast("const Nd4jLong*") LongBuffer dZShapeInfo, @Cast("bool") boolean biasCorrected); +public native void execSummaryStatsScalar( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") long[] hXShapeInfo, @Cast("const Nd4jLong*") long[] dXShapeInfo, Pointer extraParams, + OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") long[] hZShapeInfo, + @Cast("const Nd4jLong*") long[] dZShapeInfo, @Cast("bool") boolean biasCorrected); /** * * @param opNum @@ -1833,24 +2052,21 @@ public native void execSummaryStatsScalar(@Cast("Nd4jPointer*") PointerPointer e * @param result * @param resultShapeInfo */ -public native void execSummaryStats(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongPointer hXShapeInfo, @Cast("const Nd4jLong*") LongPointer dXShapeInfo, - Pointer extraParams, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongPointer hZShapeInfo, @Cast("const Nd4jLong*") LongPointer dZShapeInfo, - @Cast("bool") boolean biasCorrected); -public native void execSummaryStats(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, - Pointer extraParams, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, @Cast("const Nd4jLong*") LongBuffer dZShapeInfo, - @Cast("bool") boolean biasCorrected); -public native void execSummaryStats(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") long[] hXShapeInfo, @Cast("const Nd4jLong*") long[] dXShapeInfo, - Pointer extraParams, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") long[] hZShapeInfo, @Cast("const Nd4jLong*") long[] dZShapeInfo, - @Cast("bool") boolean biasCorrected); +public native void execSummaryStats( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongPointer hXShapeInfo, @Cast("const Nd4jLong*") LongPointer dXShapeInfo, Pointer extraParams, + OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongPointer hZShapeInfo, + @Cast("const Nd4jLong*") LongPointer dZShapeInfo, @Cast("bool") boolean biasCorrected); +public native void execSummaryStats( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, Pointer extraParams, + OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, + @Cast("const Nd4jLong*") LongBuffer dZShapeInfo, @Cast("bool") boolean biasCorrected); +public native void execSummaryStats( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") long[] hXShapeInfo, @Cast("const Nd4jLong*") long[] dXShapeInfo, Pointer extraParams, + OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") long[] hZShapeInfo, + @Cast("const Nd4jLong*") long[] dZShapeInfo, @Cast("bool") boolean biasCorrected); /** * * @param opNum @@ -1862,30 +2078,30 @@ public native void execSummaryStats(@Cast("Nd4jPointer*") PointerPointer extraPo * @param dimension * @param dimensionLength */ -public native void execSummaryStatsTad(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongPointer hXShapeInfo, @Cast("const Nd4jLong*") LongPointer dXShapeInfo, - Pointer extraParams, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongPointer hZShapeInfo, @Cast("const Nd4jLong*") LongPointer dZShapeInfo, - OpaqueDataBuffer dbDimension, @Cast("const Nd4jLong*") LongPointer hDimensionShape, @Cast("const Nd4jLong*") LongPointer dDimensionShape, - @Cast("bool") boolean biasCorrected, - @Cast("const Nd4jLong*") LongPointer tadShapeInfo, @Cast("const Nd4jLong*") LongPointer tadOffsets); -public native void execSummaryStatsTad(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, - Pointer extraParams, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, @Cast("const Nd4jLong*") LongBuffer dZShapeInfo, - OpaqueDataBuffer dbDimension, @Cast("const Nd4jLong*") LongBuffer hDimensionShape, @Cast("const Nd4jLong*") LongBuffer dDimensionShape, - @Cast("bool") boolean biasCorrected, - @Cast("const Nd4jLong*") LongBuffer tadShapeInfo, @Cast("const Nd4jLong*") LongBuffer tadOffsets); -public native void execSummaryStatsTad(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") long[] hXShapeInfo, @Cast("const Nd4jLong*") long[] dXShapeInfo, - Pointer extraParams, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") long[] hZShapeInfo, @Cast("const Nd4jLong*") long[] dZShapeInfo, - OpaqueDataBuffer dbDimension, @Cast("const Nd4jLong*") long[] hDimensionShape, @Cast("const Nd4jLong*") long[] dDimensionShape, - @Cast("bool") boolean biasCorrected, - @Cast("const Nd4jLong*") long[] tadShapeInfo, @Cast("const Nd4jLong*") long[] tadOffsets); +public native void execSummaryStatsTad( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongPointer hXShapeInfo, @Cast("const Nd4jLong*") LongPointer dXShapeInfo, Pointer extraParams, + OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongPointer hZShapeInfo, + @Cast("const Nd4jLong*") LongPointer dZShapeInfo, OpaqueDataBuffer dbDimension, + @Cast("const Nd4jLong*") LongPointer hDimensionShape, @Cast("const Nd4jLong*") LongPointer dDimensionShape, + @Cast("bool") boolean biasCorrected, @Cast("const Nd4jLong*") LongPointer tadShapeInfo, + @Cast("const Nd4jLong*") LongPointer tadOffsets); +public native void execSummaryStatsTad( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, Pointer extraParams, + OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, + @Cast("const Nd4jLong*") LongBuffer dZShapeInfo, OpaqueDataBuffer dbDimension, + @Cast("const Nd4jLong*") LongBuffer hDimensionShape, @Cast("const Nd4jLong*") LongBuffer dDimensionShape, + @Cast("bool") boolean biasCorrected, @Cast("const Nd4jLong*") LongBuffer tadShapeInfo, + @Cast("const Nd4jLong*") LongBuffer tadOffsets); +public native void execSummaryStatsTad( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") long[] hXShapeInfo, @Cast("const Nd4jLong*") long[] dXShapeInfo, Pointer extraParams, + OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") long[] hZShapeInfo, + @Cast("const Nd4jLong*") long[] dZShapeInfo, OpaqueDataBuffer dbDimension, + @Cast("const Nd4jLong*") long[] hDimensionShape, @Cast("const Nd4jLong*") long[] dDimensionShape, + @Cast("bool") boolean biasCorrected, @Cast("const Nd4jLong*") long[] tadShapeInfo, + @Cast("const Nd4jLong*") long[] tadOffsets); /** * @@ -1897,85 +2113,91 @@ public native void execSummaryStatsTad(@Cast("Nd4jPointer*") PointerPointer extr * @param extraParams * @param n */ -public native void execTransformFloat(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongPointer hXShapeInfo, @Cast("const Nd4jLong*") LongPointer dXShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongPointer hZShapeInfo, @Cast("const Nd4jLong*") LongPointer dZShapeInfo, - Pointer extraParams); -public native void execTransformFloat(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, @Cast("const Nd4jLong*") LongBuffer dZShapeInfo, - Pointer extraParams); -public native void execTransformFloat(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") long[] hXShapeInfo, @Cast("const Nd4jLong*") long[] dXShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") long[] hZShapeInfo, @Cast("const Nd4jLong*") long[] dZShapeInfo, - Pointer extraParams); - -public native void execTransformSame(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongPointer hXShapeInfo, @Cast("const Nd4jLong*") LongPointer dXShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongPointer hZShapeInfo, @Cast("const Nd4jLong*") LongPointer dZShapeInfo, - Pointer extraParams); -public native void execTransformSame(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, @Cast("const Nd4jLong*") LongBuffer dZShapeInfo, - Pointer extraParams); -public native void execTransformSame(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") long[] hXShapeInfo, @Cast("const Nd4jLong*") long[] dXShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") long[] hZShapeInfo, @Cast("const Nd4jLong*") long[] dZShapeInfo, - Pointer extraParams); - -public native void execTransformBool(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongPointer hXShapeInfo, @Cast("const Nd4jLong*") LongPointer dXShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongPointer hZShapeInfo, @Cast("const Nd4jLong*") LongPointer dZShapeInfo, - Pointer extraParams); -public native void execTransformBool(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, @Cast("const Nd4jLong*") LongBuffer dZShapeInfo, - Pointer extraParams); -public native void execTransformBool(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") long[] hXShapeInfo, @Cast("const Nd4jLong*") long[] dXShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") long[] hZShapeInfo, @Cast("const Nd4jLong*") long[] dZShapeInfo, - Pointer extraParams); - -public native void execTransformAny(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongPointer hXShapeInfo, @Cast("const Nd4jLong*") LongPointer dXShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongPointer hZShapeInfo, @Cast("const Nd4jLong*") LongPointer dZShapeInfo, - Pointer extraParams); -public native void execTransformAny(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, @Cast("const Nd4jLong*") LongBuffer dZShapeInfo, - Pointer extraParams); -public native void execTransformAny(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") long[] hXShapeInfo, @Cast("const Nd4jLong*") long[] dXShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") long[] hZShapeInfo, @Cast("const Nd4jLong*") long[] dZShapeInfo, - Pointer extraParams); - -public native void execTransformStrict(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongPointer hXShapeInfo, @Cast("const Nd4jLong*") LongPointer dXShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongPointer hZShapeInfo, @Cast("const Nd4jLong*") LongPointer dZShapeInfo, - Pointer extraParams); -public native void execTransformStrict(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, @Cast("const Nd4jLong*") LongBuffer dZShapeInfo, - Pointer extraParams); -public native void execTransformStrict(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") long[] hXShapeInfo, @Cast("const Nd4jLong*") long[] dXShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") long[] hZShapeInfo, @Cast("const Nd4jLong*") long[] dZShapeInfo, - Pointer extraParams); +public native void execTransformFloat( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongPointer hXShapeInfo, @Cast("const Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongPointer hZShapeInfo, + @Cast("const Nd4jLong*") LongPointer dZShapeInfo, Pointer extraParams); +public native void execTransformFloat( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, + @Cast("const Nd4jLong*") LongBuffer dZShapeInfo, Pointer extraParams); +public native void execTransformFloat( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") long[] hXShapeInfo, @Cast("const Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") long[] hZShapeInfo, + @Cast("const Nd4jLong*") long[] dZShapeInfo, Pointer extraParams); + +public native void execTransformSame( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongPointer hXShapeInfo, @Cast("const Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongPointer hZShapeInfo, + @Cast("const Nd4jLong*") LongPointer dZShapeInfo, Pointer extraParams); +public native void execTransformSame( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, + @Cast("const Nd4jLong*") LongBuffer dZShapeInfo, Pointer extraParams); +public native void execTransformSame( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") long[] hXShapeInfo, @Cast("const Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") long[] hZShapeInfo, + @Cast("const Nd4jLong*") long[] dZShapeInfo, Pointer extraParams); + +public native void execTransformBool( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongPointer hXShapeInfo, @Cast("const Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongPointer hZShapeInfo, + @Cast("const Nd4jLong*") LongPointer dZShapeInfo, Pointer extraParams); +public native void execTransformBool( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, + @Cast("const Nd4jLong*") LongBuffer dZShapeInfo, Pointer extraParams); +public native void execTransformBool( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") long[] hXShapeInfo, @Cast("const Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") long[] hZShapeInfo, + @Cast("const Nd4jLong*") long[] dZShapeInfo, Pointer extraParams); + +public native void execTransformAny(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, + OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongPointer hXShapeInfo, + @Cast("const Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") LongPointer hZShapeInfo, + @Cast("const Nd4jLong*") LongPointer dZShapeInfo, Pointer extraParams); +public native void execTransformAny(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, + OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, + @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, + @Cast("const Nd4jLong*") LongBuffer dZShapeInfo, Pointer extraParams); +public native void execTransformAny(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, + OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") long[] hXShapeInfo, + @Cast("const Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") long[] hZShapeInfo, + @Cast("const Nd4jLong*") long[] dZShapeInfo, Pointer extraParams); + +public native void execTransformStrict( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongPointer hXShapeInfo, @Cast("const Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongPointer hZShapeInfo, + @Cast("const Nd4jLong*") LongPointer dZShapeInfo, Pointer extraParams); +public native void execTransformStrict( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, + @Cast("const Nd4jLong*") LongBuffer dZShapeInfo, Pointer extraParams); +public native void execTransformStrict( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") long[] hXShapeInfo, @Cast("const Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") long[] hZShapeInfo, + @Cast("const Nd4jLong*") long[] dZShapeInfo, Pointer extraParams); /** * @@ -1990,92 +2212,86 @@ public native void execTransformStrict(@Cast("Nd4jPointer*") PointerPointer extr * @param dimension * @param dimensionLength */ -public native void execScalarTad(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongPointer hXShapeInfo, @Cast("const Nd4jLong*") LongPointer dXShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongPointer hZShapeInfo, @Cast("const Nd4jLong*") LongPointer dZShapeInfo, - OpaqueDataBuffer dbScalars, @Cast("const Nd4jLong*") LongPointer hScalarShapeInfo, @Cast("const Nd4jLong*") LongPointer dScalarShapeInfo, - Pointer extraParams, - OpaqueDataBuffer dbDimension, @Cast("const Nd4jLong*") LongPointer hDimensionShape, @Cast("const Nd4jLong*") LongPointer dDimensionShape, - @Cast("const Nd4jLong*") LongPointer tadShapeInfo, @Cast("const Nd4jLong*") LongPointer tadOffsets, - @Cast("const Nd4jLong*") LongPointer tadShapeInfoZ, @Cast("const Nd4jLong*") LongPointer tadOffsetsZ); -public native void execScalarTad(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, @Cast("const Nd4jLong*") LongBuffer dZShapeInfo, - OpaqueDataBuffer dbScalars, @Cast("const Nd4jLong*") LongBuffer hScalarShapeInfo, @Cast("const Nd4jLong*") LongBuffer dScalarShapeInfo, - Pointer extraParams, - OpaqueDataBuffer dbDimension, @Cast("const Nd4jLong*") LongBuffer hDimensionShape, @Cast("const Nd4jLong*") LongBuffer dDimensionShape, - @Cast("const Nd4jLong*") LongBuffer tadShapeInfo, @Cast("const Nd4jLong*") LongBuffer tadOffsets, - @Cast("const Nd4jLong*") LongBuffer tadShapeInfoZ, @Cast("const Nd4jLong*") LongBuffer tadOffsetsZ); -public native void execScalarTad(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") long[] hXShapeInfo, @Cast("const Nd4jLong*") long[] dXShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") long[] hZShapeInfo, @Cast("const Nd4jLong*") long[] dZShapeInfo, - OpaqueDataBuffer dbScalars, @Cast("const Nd4jLong*") long[] hScalarShapeInfo, @Cast("const Nd4jLong*") long[] dScalarShapeInfo, - Pointer extraParams, - OpaqueDataBuffer dbDimension, @Cast("const Nd4jLong*") long[] hDimensionShape, @Cast("const Nd4jLong*") long[] dDimensionShape, - @Cast("const Nd4jLong*") long[] tadShapeInfo, @Cast("const Nd4jLong*") long[] tadOffsets, - @Cast("const Nd4jLong*") long[] tadShapeInfoZ, @Cast("const Nd4jLong*") long[] tadOffsetsZ); - -public native void execScalarBoolTad(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongPointer hXShapeInfo, @Cast("const Nd4jLong*") LongPointer dXShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongPointer hZShapeInfo, @Cast("const Nd4jLong*") LongPointer dZShapeInfo, - OpaqueDataBuffer dbScalars, @Cast("const Nd4jLong*") LongPointer hScalarShapeInfo, @Cast("const Nd4jLong*") LongPointer dScalarShapeInfo, - Pointer extraParams, - OpaqueDataBuffer dbDimension, @Cast("const Nd4jLong*") LongPointer hDimensionShape, @Cast("const Nd4jLong*") LongPointer dDimensionShape, - @Cast("const Nd4jLong*") LongPointer tadShapeInfo, @Cast("const Nd4jLong*") LongPointer tadOffsets, - @Cast("const Nd4jLong*") LongPointer tadShapeInfoZ, @Cast("const Nd4jLong*") LongPointer tadOffsetsZ); -public native void execScalarBoolTad(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, @Cast("const Nd4jLong*") LongBuffer dZShapeInfo, - OpaqueDataBuffer dbScalars, @Cast("const Nd4jLong*") LongBuffer hScalarShapeInfo, @Cast("const Nd4jLong*") LongBuffer dScalarShapeInfo, - Pointer extraParams, - OpaqueDataBuffer dbDimension, @Cast("const Nd4jLong*") LongBuffer hDimensionShape, @Cast("const Nd4jLong*") LongBuffer dDimensionShape, - @Cast("const Nd4jLong*") LongBuffer tadShapeInfo, @Cast("const Nd4jLong*") LongBuffer tadOffsets, - @Cast("const Nd4jLong*") LongBuffer tadShapeInfoZ, @Cast("const Nd4jLong*") LongBuffer tadOffsetsZ); -public native void execScalarBoolTad(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") long[] hXShapeInfo, @Cast("const Nd4jLong*") long[] dXShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") long[] hZShapeInfo, @Cast("const Nd4jLong*") long[] dZShapeInfo, - OpaqueDataBuffer dbScalars, @Cast("const Nd4jLong*") long[] hScalarShapeInfo, @Cast("const Nd4jLong*") long[] dScalarShapeInfo, - Pointer extraParams, - OpaqueDataBuffer dbDimension, @Cast("const Nd4jLong*") long[] hDimensionShape, @Cast("const Nd4jLong*") long[] dDimensionShape, - @Cast("const Nd4jLong*") long[] tadShapeInfo, @Cast("const Nd4jLong*") long[] tadOffsets, - @Cast("const Nd4jLong*") long[] tadShapeInfoZ, @Cast("const Nd4jLong*") long[] tadOffsetsZ); - -public native void specialConcat( - @Cast("Nd4jPointer*") PointerPointer extraPointers, - int dimension, - int numArrays, - @Cast("Nd4jPointer*") PointerPointer data, - @Cast("Nd4jPointer*") PointerPointer inputShapeInfo, - Pointer result, - @Cast("const Nd4jLong*") LongPointer resultShapeInfo, - @Cast("Nd4jPointer*") PointerPointer tadPointers, - @Cast("Nd4jPointer*") PointerPointer offsetPointers); -public native void specialConcat( - @Cast("Nd4jPointer*") PointerPointer extraPointers, - int dimension, - int numArrays, - @Cast("Nd4jPointer*") PointerPointer data, - @Cast("Nd4jPointer*") PointerPointer inputShapeInfo, - Pointer result, - @Cast("const Nd4jLong*") LongBuffer resultShapeInfo, - @Cast("Nd4jPointer*") PointerPointer tadPointers, - @Cast("Nd4jPointer*") PointerPointer offsetPointers); -public native void specialConcat( - @Cast("Nd4jPointer*") PointerPointer extraPointers, - int dimension, - int numArrays, - @Cast("Nd4jPointer*") PointerPointer data, - @Cast("Nd4jPointer*") PointerPointer inputShapeInfo, - Pointer result, - @Cast("const Nd4jLong*") long[] resultShapeInfo, - @Cast("Nd4jPointer*") PointerPointer tadPointers, - @Cast("Nd4jPointer*") PointerPointer offsetPointers); +public native void execScalarTad( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongPointer hXShapeInfo, @Cast("const Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongPointer hZShapeInfo, + @Cast("const Nd4jLong*") LongPointer dZShapeInfo, OpaqueDataBuffer dbScalars, + @Cast("const Nd4jLong*") LongPointer hScalarShapeInfo, @Cast("const Nd4jLong*") LongPointer dScalarShapeInfo, + Pointer extraParams, OpaqueDataBuffer dbDimension, + @Cast("const Nd4jLong*") LongPointer hDimensionShape, @Cast("const Nd4jLong*") LongPointer dDimensionShape, + @Cast("const Nd4jLong*") LongPointer tadShapeInfo, @Cast("const Nd4jLong*") LongPointer tadOffsets, + @Cast("const Nd4jLong*") LongPointer tadShapeInfoZ, @Cast("const Nd4jLong*") LongPointer tadOffsetsZ); +public native void execScalarTad( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, + @Cast("const Nd4jLong*") LongBuffer dZShapeInfo, OpaqueDataBuffer dbScalars, + @Cast("const Nd4jLong*") LongBuffer hScalarShapeInfo, @Cast("const Nd4jLong*") LongBuffer dScalarShapeInfo, + Pointer extraParams, OpaqueDataBuffer dbDimension, + @Cast("const Nd4jLong*") LongBuffer hDimensionShape, @Cast("const Nd4jLong*") LongBuffer dDimensionShape, + @Cast("const Nd4jLong*") LongBuffer tadShapeInfo, @Cast("const Nd4jLong*") LongBuffer tadOffsets, + @Cast("const Nd4jLong*") LongBuffer tadShapeInfoZ, @Cast("const Nd4jLong*") LongBuffer tadOffsetsZ); +public native void execScalarTad( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") long[] hXShapeInfo, @Cast("const Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") long[] hZShapeInfo, + @Cast("const Nd4jLong*") long[] dZShapeInfo, OpaqueDataBuffer dbScalars, + @Cast("const Nd4jLong*") long[] hScalarShapeInfo, @Cast("const Nd4jLong*") long[] dScalarShapeInfo, + Pointer extraParams, OpaqueDataBuffer dbDimension, + @Cast("const Nd4jLong*") long[] hDimensionShape, @Cast("const Nd4jLong*") long[] dDimensionShape, + @Cast("const Nd4jLong*") long[] tadShapeInfo, @Cast("const Nd4jLong*") long[] tadOffsets, + @Cast("const Nd4jLong*") long[] tadShapeInfoZ, @Cast("const Nd4jLong*") long[] tadOffsetsZ); + +public native void execScalarBoolTad( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongPointer hXShapeInfo, @Cast("const Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongPointer hZShapeInfo, + @Cast("const Nd4jLong*") LongPointer dZShapeInfo, OpaqueDataBuffer dbScalars, + @Cast("const Nd4jLong*") LongPointer hScalarShapeInfo, @Cast("const Nd4jLong*") LongPointer dScalarShapeInfo, + Pointer extraParams, OpaqueDataBuffer dbDimension, + @Cast("const Nd4jLong*") LongPointer hDimensionShape, @Cast("const Nd4jLong*") LongPointer dDimensionShape, + @Cast("const Nd4jLong*") LongPointer tadShapeInfo, @Cast("const Nd4jLong*") LongPointer tadOffsets, + @Cast("const Nd4jLong*") LongPointer tadShapeInfoZ, @Cast("const Nd4jLong*") LongPointer tadOffsetsZ); +public native void execScalarBoolTad( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, + @Cast("const Nd4jLong*") LongBuffer dZShapeInfo, OpaqueDataBuffer dbScalars, + @Cast("const Nd4jLong*") LongBuffer hScalarShapeInfo, @Cast("const Nd4jLong*") LongBuffer dScalarShapeInfo, + Pointer extraParams, OpaqueDataBuffer dbDimension, + @Cast("const Nd4jLong*") LongBuffer hDimensionShape, @Cast("const Nd4jLong*") LongBuffer dDimensionShape, + @Cast("const Nd4jLong*") LongBuffer tadShapeInfo, @Cast("const Nd4jLong*") LongBuffer tadOffsets, + @Cast("const Nd4jLong*") LongBuffer tadShapeInfoZ, @Cast("const Nd4jLong*") LongBuffer tadOffsetsZ); +public native void execScalarBoolTad( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") long[] hXShapeInfo, @Cast("const Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") long[] hZShapeInfo, + @Cast("const Nd4jLong*") long[] dZShapeInfo, OpaqueDataBuffer dbScalars, + @Cast("const Nd4jLong*") long[] hScalarShapeInfo, @Cast("const Nd4jLong*") long[] dScalarShapeInfo, + Pointer extraParams, OpaqueDataBuffer dbDimension, + @Cast("const Nd4jLong*") long[] hDimensionShape, @Cast("const Nd4jLong*") long[] dDimensionShape, + @Cast("const Nd4jLong*") long[] tadShapeInfo, @Cast("const Nd4jLong*") long[] tadOffsets, + @Cast("const Nd4jLong*") long[] tadShapeInfoZ, @Cast("const Nd4jLong*") long[] tadOffsetsZ); + +public native void specialConcat(@Cast("Nd4jPointer*") PointerPointer extraPointers, int dimension, + int numArrays, @Cast("Nd4jPointer*") PointerPointer data, + @Cast("Nd4jPointer*") PointerPointer inputShapeInfo, Pointer result, + @Cast("const Nd4jLong*") LongPointer resultShapeInfo, + @Cast("Nd4jPointer*") PointerPointer tadPointers, + @Cast("Nd4jPointer*") PointerPointer offsetPointers); +public native void specialConcat(@Cast("Nd4jPointer*") PointerPointer extraPointers, int dimension, + int numArrays, @Cast("Nd4jPointer*") PointerPointer data, + @Cast("Nd4jPointer*") PointerPointer inputShapeInfo, Pointer result, + @Cast("const Nd4jLong*") LongBuffer resultShapeInfo, + @Cast("Nd4jPointer*") PointerPointer tadPointers, + @Cast("Nd4jPointer*") PointerPointer offsetPointers); +public native void specialConcat(@Cast("Nd4jPointer*") PointerPointer extraPointers, int dimension, + int numArrays, @Cast("Nd4jPointer*") PointerPointer data, + @Cast("Nd4jPointer*") PointerPointer inputShapeInfo, Pointer result, + @Cast("const Nd4jLong*") long[] resultShapeInfo, + @Cast("Nd4jPointer*") PointerPointer tadPointers, + @Cast("Nd4jPointer*") PointerPointer offsetPointers); /** * This method implementation exists only for cuda. @@ -2099,10 +2315,12 @@ public native void specialConcat( * * @param pointer pointer that'll be used for allocation * @param memorySize memory size, in bytes - * @param ptrToDeviceId pointer to deviceId. For cuda that's just and int, for OpenCL that's pointer to device_id, etc + * @param ptrToDeviceId pointer to deviceId. For cuda that's just and int, for + * OpenCL that's pointer to device_id, etc * @param flags optional parameter */ -public native @Cast("Nd4jPointer") Pointer mallocDevice(@Cast("Nd4jLong") long memorySize, int deviceId, int flags); +public native @Cast("Nd4jPointer") Pointer mallocDevice(@Cast("Nd4jLong") long memorySize, int deviceId, + int flags); /** * This method releases previously allocated host memory space @@ -2143,7 +2361,6 @@ public native void specialConcat( */ public native void setOmpMinThreads(int threads); - public native @Cast("bool") boolean isBlasVersionMatches(int major, int minor, int build); /** @@ -2263,11 +2480,8 @@ public native void specialConcat( * @param reserved * @return */ -public native int memcpySync(@Cast("Nd4jPointer") Pointer dst, - @Cast("Nd4jPointer") Pointer src, - @Cast("Nd4jLong") long size, - int flags, - @Cast("Nd4jPointer") Pointer reserved); +public native int memcpySync(@Cast("Nd4jPointer") Pointer dst, @Cast("Nd4jPointer") Pointer src, @Cast("Nd4jLong") long size, + int flags, @Cast("Nd4jPointer") Pointer reserved); /** * @@ -2278,11 +2492,8 @@ public native int memcpySync(@Cast("Nd4jPointer") Pointer dst, * @param reserved * @return */ -public native int memcpyAsync(@Cast("Nd4jPointer") Pointer dst, - @Cast("Nd4jPointer") Pointer src, - @Cast("Nd4jLong") long size, - int flags, - @Cast("Nd4jPointer") Pointer reserved); +public native int memcpyAsync(@Cast("Nd4jPointer") Pointer dst, @Cast("Nd4jPointer") Pointer src, @Cast("Nd4jLong") long size, + int flags, @Cast("Nd4jPointer") Pointer reserved); /** * @@ -2293,11 +2504,8 @@ public native int memcpyAsync(@Cast("Nd4jPointer") Pointer dst, * @param reserved * @return */ -public native int memsetSync(@Cast("Nd4jPointer") Pointer dst, - int value, - @Cast("Nd4jLong") long size, - int flags, - @Cast("Nd4jPointer") Pointer reserved); +public native int memsetSync(@Cast("Nd4jPointer") Pointer dst, int value, @Cast("Nd4jLong") long size, int flags, + @Cast("Nd4jPointer") Pointer reserved); /** * @@ -2308,11 +2516,8 @@ public native int memsetSync(@Cast("Nd4jPointer") Pointer dst, * @param reserved * @return */ -public native int memsetAsync(@Cast("Nd4jPointer") Pointer dst, - int value, - @Cast("Nd4jLong") long size, - int flags, - @Cast("Nd4jPointer") Pointer reserved); +public native int memsetAsync(@Cast("Nd4jPointer") Pointer dst, int value, @Cast("Nd4jLong") long size, int flags, + @Cast("Nd4jPointer") Pointer reserved); /** * @@ -2323,11 +2528,8 @@ public native int memsetAsync(@Cast("Nd4jPointer") Pointer dst, * @param reserved * @return */ -public native int memcpyConstantAsync(@Cast("Nd4jLong") long dst, - @Cast("Nd4jPointer") Pointer src, - @Cast("Nd4jLong") long size, - int flags, - @Cast("Nd4jPointer") Pointer reserved); +public native int memcpyConstantAsync(@Cast("Nd4jLong") long dst, @Cast("Nd4jPointer") Pointer src, @Cast("Nd4jLong") long size, + int flags, @Cast("Nd4jPointer") Pointer reserved); /** * @@ -2368,14 +2570,11 @@ public native int memcpyConstantAsync(@Cast("Nd4jLong") long dst, * @param offsetsBuffer */ public native OpaqueTadPack tadOnlyShapeInfo(@Cast("const Nd4jLong*") LongPointer xShapeInfo, - IntPointer dimension, - int dimensionLength); + IntPointer dimension, int dimensionLength); public native OpaqueTadPack tadOnlyShapeInfo(@Cast("const Nd4jLong*") LongBuffer xShapeInfo, - IntBuffer dimension, - int dimensionLength); + IntBuffer dimension, int dimensionLength); public native OpaqueTadPack tadOnlyShapeInfo(@Cast("const Nd4jLong*") long[] xShapeInfo, - int[] dimension, - int dimensionLength); + int[] dimension, int dimensionLength); public native @Cast("const Nd4jLong*") LongPointer getPrimaryShapeInfo(OpaqueTadPack pack); public native @Cast("const Nd4jLong*") LongPointer getPrimaryOffsets(OpaqueTadPack pack); @@ -2404,33 +2603,30 @@ public native OpaqueTadPack tadOnlyShapeInfo(@Cast("const Nd4jLong*") long[] xSh * @param zTadShapeInfo * @param zTadOffsets */ -public native void pullRows(@Cast("Nd4jPointer*") PointerPointer extraPointers, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongPointer xShapeInfo, @Cast("const Nd4jLong*") LongPointer dxShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongPointer zShapeInfo, @Cast("const Nd4jLong*") LongPointer dzShapeInfo, - @Cast("Nd4jLong") long n, - @Cast("Nd4jLong*") LongPointer indexes, - @Cast("const Nd4jLong*") LongPointer tadShapeInfo, - @Cast("const Nd4jLong*") LongPointer tadOffsets, - @Cast("const Nd4jLong*") LongPointer zTadShapeInfo, - @Cast("const Nd4jLong*") LongPointer zTadOffsets); -public native void pullRows(@Cast("Nd4jPointer*") PointerPointer extraPointers, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongBuffer xShapeInfo, @Cast("const Nd4jLong*") LongBuffer dxShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongBuffer zShapeInfo, @Cast("const Nd4jLong*") LongBuffer dzShapeInfo, - @Cast("Nd4jLong") long n, - @Cast("Nd4jLong*") LongBuffer indexes, - @Cast("const Nd4jLong*") LongBuffer tadShapeInfo, - @Cast("const Nd4jLong*") LongBuffer tadOffsets, - @Cast("const Nd4jLong*") LongBuffer zTadShapeInfo, - @Cast("const Nd4jLong*") LongBuffer zTadOffsets); -public native void pullRows(@Cast("Nd4jPointer*") PointerPointer extraPointers, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") long[] xShapeInfo, @Cast("const Nd4jLong*") long[] dxShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") long[] zShapeInfo, @Cast("const Nd4jLong*") long[] dzShapeInfo, - @Cast("Nd4jLong") long n, - @Cast("Nd4jLong*") long[] indexes, - @Cast("const Nd4jLong*") long[] tadShapeInfo, - @Cast("const Nd4jLong*") long[] tadOffsets, - @Cast("const Nd4jLong*") long[] zTadShapeInfo, - @Cast("const Nd4jLong*") long[] zTadOffsets); +public native void pullRows(@Cast("Nd4jPointer*") PointerPointer extraPointers, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongPointer xShapeInfo, @Cast("const Nd4jLong*") LongPointer dxShapeInfo, + OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongPointer zShapeInfo, + @Cast("const Nd4jLong*") LongPointer dzShapeInfo, @Cast("Nd4jLong") long n, + @Cast("Nd4jLong*") LongPointer indexes, @Cast("const Nd4jLong*") LongPointer tadShapeInfo, + @Cast("const Nd4jLong*") LongPointer tadOffsets, + @Cast("const Nd4jLong*") LongPointer zTadShapeInfo, + @Cast("const Nd4jLong*") LongPointer zTadOffsets); +public native void pullRows(@Cast("Nd4jPointer*") PointerPointer extraPointers, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongBuffer xShapeInfo, @Cast("const Nd4jLong*") LongBuffer dxShapeInfo, + OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongBuffer zShapeInfo, + @Cast("const Nd4jLong*") LongBuffer dzShapeInfo, @Cast("Nd4jLong") long n, + @Cast("Nd4jLong*") LongBuffer indexes, @Cast("const Nd4jLong*") LongBuffer tadShapeInfo, + @Cast("const Nd4jLong*") LongBuffer tadOffsets, + @Cast("const Nd4jLong*") LongBuffer zTadShapeInfo, + @Cast("const Nd4jLong*") LongBuffer zTadOffsets); +public native void pullRows(@Cast("Nd4jPointer*") PointerPointer extraPointers, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") long[] xShapeInfo, @Cast("const Nd4jLong*") long[] dxShapeInfo, + OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") long[] zShapeInfo, + @Cast("const Nd4jLong*") long[] dzShapeInfo, @Cast("Nd4jLong") long n, + @Cast("Nd4jLong*") long[] indexes, @Cast("const Nd4jLong*") long[] tadShapeInfo, + @Cast("const Nd4jLong*") long[] tadOffsets, + @Cast("const Nd4jLong*") long[] zTadShapeInfo, + @Cast("const Nd4jLong*") long[] zTadOffsets); /** * @@ -2441,54 +2637,40 @@ public native void pullRows(@Cast("Nd4jPointer*") PointerPointer extraPointers, * @param length * @param propagate */ -public native void average(@Cast("Nd4jPointer*") PointerPointer extras, - @Cast("Nd4jPointer*") PointerPointer x, @Cast("const Nd4jLong*") LongPointer xShapeInfo, - @Cast("Nd4jPointer*") PointerPointer dx, @Cast("const Nd4jLong*") LongPointer dxShapeInfo, - Pointer z, @Cast("const Nd4jLong*") LongPointer zShapeInfo, - Pointer dz, @Cast("const Nd4jLong*") LongPointer dzShapeInfo, - int n, - @Cast("Nd4jLong") long length, - @Cast("bool") boolean propagate); -public native void average(@Cast("Nd4jPointer*") PointerPointer extras, - @Cast("Nd4jPointer*") PointerPointer x, @Cast("const Nd4jLong*") LongBuffer xShapeInfo, - @Cast("Nd4jPointer*") PointerPointer dx, @Cast("const Nd4jLong*") LongBuffer dxShapeInfo, - Pointer z, @Cast("const Nd4jLong*") LongBuffer zShapeInfo, - Pointer dz, @Cast("const Nd4jLong*") LongBuffer dzShapeInfo, - int n, - @Cast("Nd4jLong") long length, - @Cast("bool") boolean propagate); -public native void average(@Cast("Nd4jPointer*") PointerPointer extras, - @Cast("Nd4jPointer*") PointerPointer x, @Cast("const Nd4jLong*") long[] xShapeInfo, - @Cast("Nd4jPointer*") PointerPointer dx, @Cast("const Nd4jLong*") long[] dxShapeInfo, - Pointer z, @Cast("const Nd4jLong*") long[] zShapeInfo, - Pointer dz, @Cast("const Nd4jLong*") long[] dzShapeInfo, - int n, - @Cast("Nd4jLong") long length, - @Cast("bool") boolean propagate); - - -public native void accumulate(@Cast("Nd4jPointer*") PointerPointer extras, - @Cast("Nd4jPointer*") PointerPointer x, @Cast("const Nd4jLong*") LongPointer xShapeInfo, - @Cast("Nd4jPointer*") PointerPointer dx, @Cast("const Nd4jLong*") LongPointer dxShapeInfo, - Pointer z, @Cast("const Nd4jLong*") LongPointer zShapeInfo, - Pointer dz, @Cast("const Nd4jLong*") LongPointer dzShapeInfo, - int n, - @Cast("Nd4jLong") long length); -public native void accumulate(@Cast("Nd4jPointer*") PointerPointer extras, - @Cast("Nd4jPointer*") PointerPointer x, @Cast("const Nd4jLong*") LongBuffer xShapeInfo, - @Cast("Nd4jPointer*") PointerPointer dx, @Cast("const Nd4jLong*") LongBuffer dxShapeInfo, - Pointer z, @Cast("const Nd4jLong*") LongBuffer zShapeInfo, - Pointer dz, @Cast("const Nd4jLong*") LongBuffer dzShapeInfo, - int n, - @Cast("Nd4jLong") long length); -public native void accumulate(@Cast("Nd4jPointer*") PointerPointer extras, - @Cast("Nd4jPointer*") PointerPointer x, @Cast("const Nd4jLong*") long[] xShapeInfo, - @Cast("Nd4jPointer*") PointerPointer dx, @Cast("const Nd4jLong*") long[] dxShapeInfo, - Pointer z, @Cast("const Nd4jLong*") long[] zShapeInfo, - Pointer dz, @Cast("const Nd4jLong*") long[] dzShapeInfo, - int n, - @Cast("Nd4jLong") long length); - +public native void average(@Cast("Nd4jPointer*") PointerPointer extras, @Cast("Nd4jPointer*") PointerPointer x, + @Cast("const Nd4jLong*") LongPointer xShapeInfo, @Cast("Nd4jPointer*") PointerPointer dx, + @Cast("const Nd4jLong*") LongPointer dxShapeInfo, Pointer z, + @Cast("const Nd4jLong*") LongPointer zShapeInfo, Pointer dz, + @Cast("const Nd4jLong*") LongPointer dzShapeInfo, int n, @Cast("Nd4jLong") long length, + @Cast("bool") boolean propagate); +public native void average(@Cast("Nd4jPointer*") PointerPointer extras, @Cast("Nd4jPointer*") PointerPointer x, + @Cast("const Nd4jLong*") LongBuffer xShapeInfo, @Cast("Nd4jPointer*") PointerPointer dx, + @Cast("const Nd4jLong*") LongBuffer dxShapeInfo, Pointer z, + @Cast("const Nd4jLong*") LongBuffer zShapeInfo, Pointer dz, + @Cast("const Nd4jLong*") LongBuffer dzShapeInfo, int n, @Cast("Nd4jLong") long length, + @Cast("bool") boolean propagate); +public native void average(@Cast("Nd4jPointer*") PointerPointer extras, @Cast("Nd4jPointer*") PointerPointer x, + @Cast("const Nd4jLong*") long[] xShapeInfo, @Cast("Nd4jPointer*") PointerPointer dx, + @Cast("const Nd4jLong*") long[] dxShapeInfo, Pointer z, + @Cast("const Nd4jLong*") long[] zShapeInfo, Pointer dz, + @Cast("const Nd4jLong*") long[] dzShapeInfo, int n, @Cast("Nd4jLong") long length, + @Cast("bool") boolean propagate); + +public native void accumulate(@Cast("Nd4jPointer*") PointerPointer extras, @Cast("Nd4jPointer*") PointerPointer x, + @Cast("const Nd4jLong*") LongPointer xShapeInfo, @Cast("Nd4jPointer*") PointerPointer dx, + @Cast("const Nd4jLong*") LongPointer dxShapeInfo, Pointer z, + @Cast("const Nd4jLong*") LongPointer zShapeInfo, Pointer dz, + @Cast("const Nd4jLong*") LongPointer dzShapeInfo, int n, @Cast("Nd4jLong") long length); +public native void accumulate(@Cast("Nd4jPointer*") PointerPointer extras, @Cast("Nd4jPointer*") PointerPointer x, + @Cast("const Nd4jLong*") LongBuffer xShapeInfo, @Cast("Nd4jPointer*") PointerPointer dx, + @Cast("const Nd4jLong*") LongBuffer dxShapeInfo, Pointer z, + @Cast("const Nd4jLong*") LongBuffer zShapeInfo, Pointer dz, + @Cast("const Nd4jLong*") LongBuffer dzShapeInfo, int n, @Cast("Nd4jLong") long length); +public native void accumulate(@Cast("Nd4jPointer*") PointerPointer extras, @Cast("Nd4jPointer*") PointerPointer x, + @Cast("const Nd4jLong*") long[] xShapeInfo, @Cast("Nd4jPointer*") PointerPointer dx, + @Cast("const Nd4jLong*") long[] dxShapeInfo, Pointer z, + @Cast("const Nd4jLong*") long[] zShapeInfo, Pointer dz, + @Cast("const Nd4jLong*") long[] dzShapeInfo, int n, @Cast("Nd4jLong") long length); /** * P2P enabler @@ -2526,34 +2708,24 @@ public native void accumulate(@Cast("Nd4jPointer*") PointerPointer extras, * @param tadShapeInfo * @param tadOffsets */ -public native void shuffle(@Cast("Nd4jPointer*") PointerPointer extras, - @Cast("Nd4jPointer*") PointerPointer x, @Cast("Nd4jPointer*") PointerPointer xShapeInfo, - @Cast("Nd4jPointer*") PointerPointer dx, @Cast("Nd4jPointer*") PointerPointer dxShapeInfo, - @Cast("Nd4jPointer*") PointerPointer z, @Cast("Nd4jPointer*") PointerPointer zShapeInfo, - @Cast("Nd4jPointer*") PointerPointer dz, @Cast("Nd4jPointer*") PointerPointer dzShapeInfo, - int N, - IntPointer shuffleMap, - @Cast("Nd4jPointer*") PointerPointer tadShapeInfo, - @Cast("Nd4jPointer*") PointerPointer tadOffsets); -public native void shuffle(@Cast("Nd4jPointer*") PointerPointer extras, - @Cast("Nd4jPointer*") PointerPointer x, @Cast("Nd4jPointer*") PointerPointer xShapeInfo, - @Cast("Nd4jPointer*") PointerPointer dx, @Cast("Nd4jPointer*") PointerPointer dxShapeInfo, - @Cast("Nd4jPointer*") PointerPointer z, @Cast("Nd4jPointer*") PointerPointer zShapeInfo, - @Cast("Nd4jPointer*") PointerPointer dz, @Cast("Nd4jPointer*") PointerPointer dzShapeInfo, - int N, - IntBuffer shuffleMap, - @Cast("Nd4jPointer*") PointerPointer tadShapeInfo, - @Cast("Nd4jPointer*") PointerPointer tadOffsets); -public native void shuffle(@Cast("Nd4jPointer*") PointerPointer extras, - @Cast("Nd4jPointer*") PointerPointer x, @Cast("Nd4jPointer*") PointerPointer xShapeInfo, - @Cast("Nd4jPointer*") PointerPointer dx, @Cast("Nd4jPointer*") PointerPointer dxShapeInfo, - @Cast("Nd4jPointer*") PointerPointer z, @Cast("Nd4jPointer*") PointerPointer zShapeInfo, - @Cast("Nd4jPointer*") PointerPointer dz, @Cast("Nd4jPointer*") PointerPointer dzShapeInfo, - int N, - int[] shuffleMap, - @Cast("Nd4jPointer*") PointerPointer tadShapeInfo, - @Cast("Nd4jPointer*") PointerPointer tadOffsets); - +public native void shuffle(@Cast("Nd4jPointer*") PointerPointer extras, @Cast("Nd4jPointer*") PointerPointer x, + @Cast("Nd4jPointer*") PointerPointer xShapeInfo, @Cast("Nd4jPointer*") PointerPointer dx, + @Cast("Nd4jPointer*") PointerPointer dxShapeInfo, @Cast("Nd4jPointer*") PointerPointer z, + @Cast("Nd4jPointer*") PointerPointer zShapeInfo, @Cast("Nd4jPointer*") PointerPointer dz, + @Cast("Nd4jPointer*") PointerPointer dzShapeInfo, int N, IntPointer shuffleMap, + @Cast("Nd4jPointer*") PointerPointer tadShapeInfo, @Cast("Nd4jPointer*") PointerPointer tadOffsets); +public native void shuffle(@Cast("Nd4jPointer*") PointerPointer extras, @Cast("Nd4jPointer*") PointerPointer x, + @Cast("Nd4jPointer*") PointerPointer xShapeInfo, @Cast("Nd4jPointer*") PointerPointer dx, + @Cast("Nd4jPointer*") PointerPointer dxShapeInfo, @Cast("Nd4jPointer*") PointerPointer z, + @Cast("Nd4jPointer*") PointerPointer zShapeInfo, @Cast("Nd4jPointer*") PointerPointer dz, + @Cast("Nd4jPointer*") PointerPointer dzShapeInfo, int N, IntBuffer shuffleMap, + @Cast("Nd4jPointer*") PointerPointer tadShapeInfo, @Cast("Nd4jPointer*") PointerPointer tadOffsets); +public native void shuffle(@Cast("Nd4jPointer*") PointerPointer extras, @Cast("Nd4jPointer*") PointerPointer x, + @Cast("Nd4jPointer*") PointerPointer xShapeInfo, @Cast("Nd4jPointer*") PointerPointer dx, + @Cast("Nd4jPointer*") PointerPointer dxShapeInfo, @Cast("Nd4jPointer*") PointerPointer z, + @Cast("Nd4jPointer*") PointerPointer zShapeInfo, @Cast("Nd4jPointer*") PointerPointer dz, + @Cast("Nd4jPointer*") PointerPointer dzShapeInfo, int N, int[] shuffleMap, + @Cast("Nd4jPointer*") PointerPointer tadShapeInfo, @Cast("Nd4jPointer*") PointerPointer tadOffsets); /** * Type Conversions @@ -2568,8 +2740,8 @@ public native void shuffle(@Cast("Nd4jPointer*") PointerPointer extras, * @param dstType * @param z */ -public native void convertTypes(@Cast("Nd4jPointer*") PointerPointer extras, int srcType, @Cast("Nd4jPointer") Pointer x, @Cast("Nd4jLong") long N, int dstType, @Cast("Nd4jPointer") Pointer z); - +public native void convertTypes(@Cast("Nd4jPointer*") PointerPointer extras, int srcType, @Cast("Nd4jPointer") Pointer x, + @Cast("Nd4jLong") long N, int dstType, @Cast("Nd4jPointer") Pointer z); /** * @@ -2596,83 +2768,46 @@ public native void shuffle(@Cast("Nd4jPointer*") PointerPointer extras, * @param realArguments * @param numRealArguments */ -public native void execAggregate(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - @Cast("void**") PointerPointer arguments, - int numArguments, - @Cast("Nd4jLong**") PointerPointer shapeArguments, - int numShapeArguments, - IntPointer indexArguments, - int numIndexArguments, - @Cast("int**") PointerPointer intArrays, - int numIntArrays, - Pointer realArguments, - int numRealArguments, - @Cast("sd::DataType") int dtype); -public native void execAggregate(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - @Cast("void**") @ByPtrPtr Pointer arguments, - int numArguments, - @Cast("Nd4jLong**") @ByPtrPtr LongPointer shapeArguments, - int numShapeArguments, - IntPointer indexArguments, - int numIndexArguments, - @ByPtrPtr IntPointer intArrays, - int numIntArrays, - Pointer realArguments, - int numRealArguments, - @Cast("sd::DataType") int dtype); -public native void execAggregate(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - @Cast("void**") @ByPtrPtr Pointer arguments, - int numArguments, - @Cast("Nd4jLong**") @ByPtrPtr LongBuffer shapeArguments, - int numShapeArguments, - IntBuffer indexArguments, - int numIndexArguments, - @ByPtrPtr IntBuffer intArrays, - int numIntArrays, - Pointer realArguments, - int numRealArguments, - @Cast("sd::DataType") int dtype); -public native void execAggregate(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - @Cast("void**") @ByPtrPtr Pointer arguments, - int numArguments, - @Cast("Nd4jLong**") @ByPtrPtr long[] shapeArguments, - int numShapeArguments, - int[] indexArguments, - int numIndexArguments, - @ByPtrPtr int[] intArrays, - int numIntArrays, - Pointer realArguments, - int numRealArguments, - @Cast("sd::DataType") int dtype); - - -public native void batchExecutor(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int numAggregates, - int opNum, - int maxArgs, - int maxShapes, - int maxIntArrays, - int maxIntArraySize, - int maxIdx, - int maxReals, - Pointer ptrToArguments, - @Cast("sd::DataType") int dtype); - -public native void execAggregateBatch(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int numAggregates, - int opNum, - int maxArgs, - int maxShapes, - int maxIntArrays, - int maxIntArraySize, - int maxIdx, - int maxReals, - Pointer ptrToArguments, - @Cast("sd::DataType") int dtype); +public native void execAggregate(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, + @Cast("void**") PointerPointer arguments, int numArguments, + @Cast("Nd4jLong**") PointerPointer shapeArguments, int numShapeArguments, + IntPointer indexArguments, int numIndexArguments, + @Cast("int**") PointerPointer intArrays, int numIntArrays, + Pointer realArguments, int numRealArguments, + @Cast("sd::DataType") int dtype); +public native void execAggregate(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, + @Cast("void**") @ByPtrPtr Pointer arguments, int numArguments, + @Cast("Nd4jLong**") @ByPtrPtr LongPointer shapeArguments, int numShapeArguments, + IntPointer indexArguments, int numIndexArguments, + @ByPtrPtr IntPointer intArrays, int numIntArrays, + Pointer realArguments, int numRealArguments, + @Cast("sd::DataType") int dtype); +public native void execAggregate(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, + @Cast("void**") @ByPtrPtr Pointer arguments, int numArguments, + @Cast("Nd4jLong**") @ByPtrPtr LongBuffer shapeArguments, int numShapeArguments, + IntBuffer indexArguments, int numIndexArguments, + @ByPtrPtr IntBuffer intArrays, int numIntArrays, + Pointer realArguments, int numRealArguments, + @Cast("sd::DataType") int dtype); +public native void execAggregate(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, + @Cast("void**") @ByPtrPtr Pointer arguments, int numArguments, + @Cast("Nd4jLong**") @ByPtrPtr long[] shapeArguments, int numShapeArguments, + int[] indexArguments, int numIndexArguments, + @ByPtrPtr int[] intArrays, int numIntArrays, + Pointer realArguments, int numRealArguments, + @Cast("sd::DataType") int dtype); + +public native void batchExecutor(@Cast("Nd4jPointer*") PointerPointer extraPointers, int numAggregates, + int opNum, int maxArgs, int maxShapes, + int maxIntArrays, int maxIntArraySize, int maxIdx, + int maxReals, Pointer ptrToArguments, + @Cast("sd::DataType") int dtype); + +public native void execAggregateBatch(@Cast("Nd4jPointer*") PointerPointer extraPointers, int numAggregates, + int opNum, int maxArgs, int maxShapes, + int maxIntArrays, int maxIntArraySize, + int maxIdx, int maxReals, + Pointer ptrToArguments, @Cast("sd::DataType") int dtype); /** * Random operations @@ -2687,21 +2822,18 @@ public native void execAggregateBatch(@Cast("Nd4jPointer*") PointerPointer extra * @param zShapeBuffer * @param extraArguments */ -public native void execRandom(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - @Cast("Nd4jPointer") Pointer state, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongPointer hZShapeBuffer, @Cast("const Nd4jLong*") LongPointer dZShapeBuffer, - Pointer extraArguments); -public native void execRandom(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - @Cast("Nd4jPointer") Pointer state, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongBuffer hZShapeBuffer, @Cast("const Nd4jLong*") LongBuffer dZShapeBuffer, - Pointer extraArguments); -public native void execRandom(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - @Cast("Nd4jPointer") Pointer state, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") long[] hZShapeBuffer, @Cast("const Nd4jLong*") long[] dZShapeBuffer, - Pointer extraArguments); +public native void execRandom(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, + @Cast("Nd4jPointer") Pointer state, OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") LongPointer hZShapeBuffer, + @Cast("const Nd4jLong*") LongPointer dZShapeBuffer, Pointer extraArguments); +public native void execRandom(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, + @Cast("Nd4jPointer") Pointer state, OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") LongBuffer hZShapeBuffer, + @Cast("const Nd4jLong*") LongBuffer dZShapeBuffer, Pointer extraArguments); +public native void execRandom(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, + @Cast("Nd4jPointer") Pointer state, OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") long[] hZShapeBuffer, + @Cast("const Nd4jLong*") long[] dZShapeBuffer, Pointer extraArguments); /** * @@ -2716,27 +2848,30 @@ public native void execRandom(@Cast("Nd4jPointer*") PointerPointer extraPointers * @param zShapeBuffer * @param extraArguments */ -public native void execRandom3(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - @Cast("Nd4jPointer") Pointer state, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongPointer hXShapeBuffer, @Cast("const Nd4jLong*") LongPointer dXShapeBuffer, - OpaqueDataBuffer dbY, @Cast("const Nd4jLong*") LongPointer hYShapeBuffer, @Cast("const Nd4jLong*") LongPointer dYShapeBuffer, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongPointer hZShapeBuffer, @Cast("const Nd4jLong*") LongPointer dZShapeBuffer, - Pointer extraArguments); -public native void execRandom3(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - @Cast("Nd4jPointer") Pointer state, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongBuffer hXShapeBuffer, @Cast("const Nd4jLong*") LongBuffer dXShapeBuffer, - OpaqueDataBuffer dbY, @Cast("const Nd4jLong*") LongBuffer hYShapeBuffer, @Cast("const Nd4jLong*") LongBuffer dYShapeBuffer, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongBuffer hZShapeBuffer, @Cast("const Nd4jLong*") LongBuffer dZShapeBuffer, - Pointer extraArguments); -public native void execRandom3(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - @Cast("Nd4jPointer") Pointer state, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") long[] hXShapeBuffer, @Cast("const Nd4jLong*") long[] dXShapeBuffer, - OpaqueDataBuffer dbY, @Cast("const Nd4jLong*") long[] hYShapeBuffer, @Cast("const Nd4jLong*") long[] dYShapeBuffer, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") long[] hZShapeBuffer, @Cast("const Nd4jLong*") long[] dZShapeBuffer, - Pointer extraArguments); +public native void execRandom3(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, + @Cast("Nd4jPointer") Pointer state, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongPointer hXShapeBuffer, + @Cast("const Nd4jLong*") LongPointer dXShapeBuffer, OpaqueDataBuffer dbY, + @Cast("const Nd4jLong*") LongPointer hYShapeBuffer, + @Cast("const Nd4jLong*") LongPointer dYShapeBuffer, OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") LongPointer hZShapeBuffer, + @Cast("const Nd4jLong*") LongPointer dZShapeBuffer, Pointer extraArguments); +public native void execRandom3(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, + @Cast("Nd4jPointer") Pointer state, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongBuffer hXShapeBuffer, + @Cast("const Nd4jLong*") LongBuffer dXShapeBuffer, OpaqueDataBuffer dbY, + @Cast("const Nd4jLong*") LongBuffer hYShapeBuffer, + @Cast("const Nd4jLong*") LongBuffer dYShapeBuffer, OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") LongBuffer hZShapeBuffer, + @Cast("const Nd4jLong*") LongBuffer dZShapeBuffer, Pointer extraArguments); +public native void execRandom3(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, + @Cast("Nd4jPointer") Pointer state, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") long[] hXShapeBuffer, + @Cast("const Nd4jLong*") long[] dXShapeBuffer, OpaqueDataBuffer dbY, + @Cast("const Nd4jLong*") long[] hYShapeBuffer, + @Cast("const Nd4jLong*") long[] dYShapeBuffer, OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") long[] hZShapeBuffer, + @Cast("const Nd4jLong*") long[] dZShapeBuffer, Pointer extraArguments); /** * @@ -2749,25 +2884,24 @@ public native void execRandom3(@Cast("Nd4jPointer*") PointerPointer extraPointer * @param zShapeBuffer * @param extraArguments */ -public native void execRandom2(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - @Cast("Nd4jPointer") Pointer state, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongPointer hXShapeBuffer, @Cast("const Nd4jLong*") LongPointer dXShapeBuffer, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongPointer hZShapeBuffer, @Cast("const Nd4jLong*") LongPointer dZShapeBuffer, - Pointer extraArguments); -public native void execRandom2(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - @Cast("Nd4jPointer") Pointer state, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongBuffer hXShapeBuffer, @Cast("const Nd4jLong*") LongBuffer dXShapeBuffer, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongBuffer hZShapeBuffer, @Cast("const Nd4jLong*") LongBuffer dZShapeBuffer, - Pointer extraArguments); -public native void execRandom2(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - @Cast("Nd4jPointer") Pointer state, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") long[] hXShapeBuffer, @Cast("const Nd4jLong*") long[] dXShapeBuffer, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") long[] hZShapeBuffer, @Cast("const Nd4jLong*") long[] dZShapeBuffer, - Pointer extraArguments); - +public native void execRandom2(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, + @Cast("Nd4jPointer") Pointer state, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongPointer hXShapeBuffer, + @Cast("const Nd4jLong*") LongPointer dXShapeBuffer, OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") LongPointer hZShapeBuffer, + @Cast("const Nd4jLong*") LongPointer dZShapeBuffer, Pointer extraArguments); +public native void execRandom2(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, + @Cast("Nd4jPointer") Pointer state, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongBuffer hXShapeBuffer, + @Cast("const Nd4jLong*") LongBuffer dXShapeBuffer, OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") LongBuffer hZShapeBuffer, + @Cast("const Nd4jLong*") LongBuffer dZShapeBuffer, Pointer extraArguments); +public native void execRandom2(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, + @Cast("Nd4jPointer") Pointer state, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") long[] hXShapeBuffer, + @Cast("const Nd4jLong*") long[] dXShapeBuffer, OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") long[] hZShapeBuffer, + @Cast("const Nd4jLong*") long[] dZShapeBuffer, Pointer extraArguments); /** * @@ -2777,10 +2911,8 @@ public native void execRandom2(@Cast("Nd4jPointer*") PointerPointer extraPointer * @param ptrToBuffer * @return */ -public native @Cast("Nd4jPointer") Pointer initRandom(@Cast("Nd4jPointer*") PointerPointer extraPointers, - long seed, - long bufferSize, - @Cast("Nd4jPointer") Pointer ptrToBuffer); +public native @Cast("Nd4jPointer") Pointer initRandom(@Cast("Nd4jPointer*") PointerPointer extraPointers, long seed, + long bufferSize, @Cast("Nd4jPointer") Pointer ptrToBuffer); /** * @@ -2788,9 +2920,8 @@ public native void execRandom2(@Cast("Nd4jPointer*") PointerPointer extraPointer * @param seed * @param ptrRandom */ -public native void refreshBuffer(@Cast("Nd4jPointer*") PointerPointer extraPointers, - long seed, - @Cast("Nd4jPointer") Pointer ptrRandom); +public native void refreshBuffer(@Cast("Nd4jPointer*") PointerPointer extraPointers, long seed, + @Cast("Nd4jPointer") Pointer ptrRandom); /** * @@ -2798,9 +2929,8 @@ public native void refreshBuffer(@Cast("Nd4jPointer*") PointerPointer extraPoint * @param seed * @param ptrRandom */ -public native void reSeedBuffer(@Cast("Nd4jPointer*") PointerPointer extraPointers, - long seed, - @Cast("Nd4jPointer") Pointer ptrRandom); +public native void reSeedBuffer(@Cast("Nd4jPointer*") PointerPointer extraPointers, long seed, + @Cast("Nd4jPointer") Pointer ptrRandom); /** * @@ -2809,95 +2939,93 @@ public native void reSeedBuffer(@Cast("Nd4jPointer*") PointerPointer extraPointe public native void destroyRandom(@Cast("Nd4jPointer") Pointer ptrRandom); /** -* -* @param data -* @param shapeBuffer -* @param wordSize -* @param headerSize -* @return -*/ + * + * @param data + * @param shapeBuffer + * @param wordSize + * @param headerSize + * @return + */ -public native @Cast("Nd4jPointer") Pointer numpyHeaderForNd4j(@Cast("Nd4jPointer") Pointer data,@Cast("Nd4jPointer") Pointer shapeBuffer,@Cast("Nd4jLong") long wordSize,@Cast("Nd4jLong*") LongPointer headerSize); -public native @Cast("Nd4jPointer") Pointer numpyHeaderForNd4j(@Cast("Nd4jPointer") Pointer data,@Cast("Nd4jPointer") Pointer shapeBuffer,@Cast("Nd4jLong") long wordSize,@Cast("Nd4jLong*") LongBuffer headerSize); -public native @Cast("Nd4jPointer") Pointer numpyHeaderForNd4j(@Cast("Nd4jPointer") Pointer data,@Cast("Nd4jPointer") Pointer shapeBuffer,@Cast("Nd4jLong") long wordSize,@Cast("Nd4jLong*") long[] headerSize); +public native @Cast("Nd4jPointer") Pointer numpyHeaderForNd4j(@Cast("Nd4jPointer") Pointer data, @Cast("Nd4jPointer") Pointer shapeBuffer, + @Cast("Nd4jLong") long wordSize, @Cast("Nd4jLong*") LongPointer headerSize); +public native @Cast("Nd4jPointer") Pointer numpyHeaderForNd4j(@Cast("Nd4jPointer") Pointer data, @Cast("Nd4jPointer") Pointer shapeBuffer, + @Cast("Nd4jLong") long wordSize, @Cast("Nd4jLong*") LongBuffer headerSize); +public native @Cast("Nd4jPointer") Pointer numpyHeaderForNd4j(@Cast("Nd4jPointer") Pointer data, @Cast("Nd4jPointer") Pointer shapeBuffer, + @Cast("Nd4jLong") long wordSize, @Cast("Nd4jLong*") long[] headerSize); /** -* Load numpy from a header -* based on the cnpy parse from header method. -* @param data the header data to parse -* @return a pointer to a numpy cnpy:NpyArray struct -*/ + * Load numpy from a header + * based on the cnpy parse from header method. + * @param data the header data to parse + * @return a pointer to a numpy cnpy:NpyArray struct + */ public native @Cast("Nd4jPointer") Pointer loadNpyFromHeader(@Cast("Nd4jPointer") Pointer data); /** -* Create a numpy array from an nd4j -* array -* @param data a pointer to the data -* @param shapeBuffer the shapebuffer for the nd4j array -* @param wordSize the word size (4 for float, 8 for doubles) -* @return a pointer to a numpy array -*/ - -public native @Cast("Nd4jPointer") Pointer numpyFromNd4j(@Cast("Nd4jPointer") Pointer data,@Cast("Nd4jPointer") Pointer shapeBuffer,@Cast("Nd4jLong") long wordSize); + * Create a numpy array from an nd4j + * array + * @param data a pointer to the data + * @param shapeBuffer the shapebuffer for the nd4j array + * @param wordSize the word size (4 for float, 8 for doubles) + * @return a pointer to a numpy array + */ +public native @Cast("Nd4jPointer") Pointer numpyFromNd4j(@Cast("Nd4jPointer") Pointer data, @Cast("Nd4jPointer") Pointer shapeBuffer, + @Cast("Nd4jLong") long wordSize); /** -* -* @param npyArray -* @return -*/ + * + * @param npyArray + * @return + */ public native @Cast("Nd4jPointer") Pointer shapeBufferForNumpy(@Cast("Nd4jPointer") Pointer npyArray); - /** -* Get the shape buffer from a -* numpy array. -* **Warning** this allocates memory -* @param npyArray -* @return -*/ + * Get the shape buffer from a + * numpy array. + * **Warning** this allocates memory + * @param npyArray + * @return + */ public native @Cast("Nd4jPointer") Pointer shapeBufferForNumpyHeader(@Cast("Nd4jPointer") Pointer npyArray); - - /** -* -* @param npyArray -* @return -*/ + * + * @param npyArray + * @return + */ public native @Cast("Nd4jPointer") Pointer dataPointForNumpyHeader(@Cast("Nd4jPointer") Pointer npyArray); /** -* -* @param npyArray -* @return -*/ + * + * @param npyArray + * @return + */ public native @Cast("Nd4jPointer") Pointer dataPointForNumpyStruct(@Cast("Nd4jPointer") Pointer npyArrayStruct); /** -* -* @param npyArray -* @param fromFile -* @return -*/ + * + * @param npyArray + * @param fromFile + * @return + */ public native @Cast("Nd4jPointer") Pointer dataPointForNumpy(@Cast("Nd4jPointer") Pointer npyArray); /** -* Load a numpy array from a file -* and return it as an Nd4jPointer -* @param path -* @return -*/ + * Load a numpy array from a file + * and return it as an Nd4jPointer + * @param path + * @return + */ public native @Cast("Nd4jPointer") Pointer numpyFromFile(@StdString BytePointer path); public native @Cast("Nd4jPointer") Pointer numpyFromFile(@StdString String path); - ////// NPZ ////// public native Pointer mapFromNpzFile(@StdString BytePointer path); public native Pointer mapFromNpzFile(@StdString String path); - public native int getNumNpyArraysInMap(Pointer map); public native @Cast("char*") String getNpyArrayNameFromMap(Pointer map, int index); @@ -2922,26 +3050,23 @@ public native void reSeedBuffer(@Cast("Nd4jPointer*") PointerPointer extraPointe ////// /** -* Get the element size for a numpy array -* @param npyArray the numpy array's address -* to get the length for -* @return -*/ + * Get the element size for a numpy array + * @param npyArray the numpy array's address + * to get the length for + * @return + */ public native int elementSizeForNpyArray(@Cast("Nd4jPointer") Pointer npyArray); - /** -* Get the element size for a numpy array -* @param npyArray the numpy array's address -* to get the length for -* @return -*/ + * Get the element size for a numpy array + * @param npyArray the numpy array's address + * to get the length for + * @return + */ public native int elementSizeForNpyArrayHeader(@Cast("Nd4jPointer") Pointer npyArray); - public native void releaseNumpy(@Cast("Nd4jPointer") Pointer npyArray); - /** * Return the length of a shape buffer * based on the pointer @@ -2950,18 +3075,18 @@ public native void reSeedBuffer(@Cast("Nd4jPointer*") PointerPointer extraPointe */ public native int lengthForShapeBufferPointer(@Cast("Nd4jPointer") Pointer buffer); - - /** -* The pointer to get the address for -* -* @param address the address to get the pointer -* @return the pointer for the given address -*/ +/** + * The pointer to get the address for + * + * @param address the address to get the pointer + * @return the pointer for the given address + */ public native @Cast("Nd4jPointer") Pointer pointerForAddress(@Cast("Nd4jLong") long _address); /** - * This method takes single N-dimensional tensor, and copies its TADs to target arrays + * This method takes single N-dimensional tensor, and copies its TADs to target + * arrays * * @param x * @param xShapeInfo @@ -2969,164 +3094,138 @@ public native void reSeedBuffer(@Cast("Nd4jPointer*") PointerPointer extraPointe * @param zShapeInfo * @return */ -public native void tear(@Cast("Nd4jPointer*") PointerPointer extraPointers, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongPointer xShapeInfo, @Cast("const Nd4jLong*") LongPointer dxShapeInfo, - @Cast("Nd4jPointer*") PointerPointer targets, @Cast("const Nd4jLong*") LongPointer zShapeInfo, - @Cast("const Nd4jLong*") LongPointer tadShapeInfo, - @Cast("const Nd4jLong*") LongPointer tadOffsets); -public native void tear(@Cast("Nd4jPointer*") PointerPointer extraPointers, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongBuffer xShapeInfo, @Cast("const Nd4jLong*") LongBuffer dxShapeInfo, - @Cast("Nd4jPointer*") PointerPointer targets, @Cast("const Nd4jLong*") LongBuffer zShapeInfo, - @Cast("const Nd4jLong*") LongBuffer tadShapeInfo, - @Cast("const Nd4jLong*") LongBuffer tadOffsets); -public native void tear(@Cast("Nd4jPointer*") PointerPointer extraPointers, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") long[] xShapeInfo, @Cast("const Nd4jLong*") long[] dxShapeInfo, - @Cast("Nd4jPointer*") PointerPointer targets, @Cast("const Nd4jLong*") long[] zShapeInfo, - @Cast("const Nd4jLong*") long[] tadShapeInfo, - @Cast("const Nd4jLong*") long[] tadOffsets); - -public native void sort(@Cast("Nd4jPointer*") PointerPointer extraPointers, - Pointer x, @Cast("const Nd4jLong*") LongPointer xShapeInfo, - Pointer dx, @Cast("const Nd4jLong*") LongPointer dxShapeInfo, - @Cast("bool") boolean descending); -public native void sort(@Cast("Nd4jPointer*") PointerPointer extraPointers, - Pointer x, @Cast("const Nd4jLong*") LongBuffer xShapeInfo, - Pointer dx, @Cast("const Nd4jLong*") LongBuffer dxShapeInfo, - @Cast("bool") boolean descending); -public native void sort(@Cast("Nd4jPointer*") PointerPointer extraPointers, - Pointer x, @Cast("const Nd4jLong*") long[] xShapeInfo, - Pointer dx, @Cast("const Nd4jLong*") long[] dxShapeInfo, - @Cast("bool") boolean descending); - -public native void sortByKey(@Cast("Nd4jPointer*") PointerPointer extraPointers, - Pointer x, @Cast("const Nd4jLong*") LongPointer xShapeInfo, - Pointer dx, @Cast("const Nd4jLong*") LongPointer dxShapeInfo, - Pointer y, @Cast("const Nd4jLong*") LongPointer yShapeInfo, - Pointer dy, @Cast("const Nd4jLong*") LongPointer dyShapeInfo, - @Cast("bool") boolean descending); -public native void sortByKey(@Cast("Nd4jPointer*") PointerPointer extraPointers, - Pointer x, @Cast("const Nd4jLong*") LongBuffer xShapeInfo, - Pointer dx, @Cast("const Nd4jLong*") LongBuffer dxShapeInfo, - Pointer y, @Cast("const Nd4jLong*") LongBuffer yShapeInfo, - Pointer dy, @Cast("const Nd4jLong*") LongBuffer dyShapeInfo, - @Cast("bool") boolean descending); -public native void sortByKey(@Cast("Nd4jPointer*") PointerPointer extraPointers, - Pointer x, @Cast("const Nd4jLong*") long[] xShapeInfo, - Pointer dx, @Cast("const Nd4jLong*") long[] dxShapeInfo, - Pointer y, @Cast("const Nd4jLong*") long[] yShapeInfo, - Pointer dy, @Cast("const Nd4jLong*") long[] dyShapeInfo, - @Cast("bool") boolean descending); - -public native void sortByValue(@Cast("Nd4jPointer*") PointerPointer extraPointers, - Pointer x, @Cast("const Nd4jLong*") LongPointer xShapeInfo, - Pointer dx, @Cast("const Nd4jLong*") LongPointer dxShapeInfo, - Pointer y, @Cast("const Nd4jLong*") LongPointer yShapeInfo, - Pointer dy, @Cast("const Nd4jLong*") LongPointer dyShapeInfo, - @Cast("bool") boolean descending); -public native void sortByValue(@Cast("Nd4jPointer*") PointerPointer extraPointers, - Pointer x, @Cast("const Nd4jLong*") LongBuffer xShapeInfo, - Pointer dx, @Cast("const Nd4jLong*") LongBuffer dxShapeInfo, - Pointer y, @Cast("const Nd4jLong*") LongBuffer yShapeInfo, - Pointer dy, @Cast("const Nd4jLong*") LongBuffer dyShapeInfo, - @Cast("bool") boolean descending); -public native void sortByValue(@Cast("Nd4jPointer*") PointerPointer extraPointers, - Pointer x, @Cast("const Nd4jLong*") long[] xShapeInfo, - Pointer dx, @Cast("const Nd4jLong*") long[] dxShapeInfo, - Pointer y, @Cast("const Nd4jLong*") long[] yShapeInfo, - Pointer dy, @Cast("const Nd4jLong*") long[] dyShapeInfo, - @Cast("bool") boolean descending); - -public native void sortTad(@Cast("Nd4jPointer*") PointerPointer extraPointers, - Pointer x, @Cast("const Nd4jLong*") LongPointer xShapeInfo, - Pointer dx, @Cast("const Nd4jLong*") LongPointer dxShapeInfo, - IntPointer dimension, - int dimensionLength, - @Cast("const Nd4jLong*") LongPointer tadShapeInfo, - @Cast("const Nd4jLong*") LongPointer tadOffsets, - @Cast("bool") boolean descending); -public native void sortTad(@Cast("Nd4jPointer*") PointerPointer extraPointers, - Pointer x, @Cast("const Nd4jLong*") LongBuffer xShapeInfo, - Pointer dx, @Cast("const Nd4jLong*") LongBuffer dxShapeInfo, - IntBuffer dimension, - int dimensionLength, - @Cast("const Nd4jLong*") LongBuffer tadShapeInfo, - @Cast("const Nd4jLong*") LongBuffer tadOffsets, - @Cast("bool") boolean descending); -public native void sortTad(@Cast("Nd4jPointer*") PointerPointer extraPointers, - Pointer x, @Cast("const Nd4jLong*") long[] xShapeInfo, - Pointer dx, @Cast("const Nd4jLong*") long[] dxShapeInfo, - int[] dimension, - int dimensionLength, - @Cast("const Nd4jLong*") long[] tadShapeInfo, - @Cast("const Nd4jLong*") long[] tadOffsets, - @Cast("bool") boolean descending); - -public native void sortTadByKey(@Cast("Nd4jPointer*") PointerPointer extraPointers, - Pointer x, @Cast("const Nd4jLong*") LongPointer xShapeInfo, - Pointer dx, @Cast("const Nd4jLong*") LongPointer dxShapeInfo, - Pointer y, @Cast("const Nd4jLong*") LongPointer yShapeInfo, - Pointer dy, @Cast("const Nd4jLong*") LongPointer dyShapeInfo, - IntPointer dimension, - int dimensionLength, - @Cast("bool") boolean descending); -public native void sortTadByKey(@Cast("Nd4jPointer*") PointerPointer extraPointers, - Pointer x, @Cast("const Nd4jLong*") LongBuffer xShapeInfo, - Pointer dx, @Cast("const Nd4jLong*") LongBuffer dxShapeInfo, - Pointer y, @Cast("const Nd4jLong*") LongBuffer yShapeInfo, - Pointer dy, @Cast("const Nd4jLong*") LongBuffer dyShapeInfo, - IntBuffer dimension, - int dimensionLength, - @Cast("bool") boolean descending); -public native void sortTadByKey(@Cast("Nd4jPointer*") PointerPointer extraPointers, - Pointer x, @Cast("const Nd4jLong*") long[] xShapeInfo, - Pointer dx, @Cast("const Nd4jLong*") long[] dxShapeInfo, - Pointer y, @Cast("const Nd4jLong*") long[] yShapeInfo, - Pointer dy, @Cast("const Nd4jLong*") long[] dyShapeInfo, - int[] dimension, - int dimensionLength, - @Cast("bool") boolean descending); - -public native void sortTadByValue(@Cast("Nd4jPointer*") PointerPointer extraPointers, - Pointer x, @Cast("const Nd4jLong*") LongPointer xShapeInfo, - Pointer dx, @Cast("const Nd4jLong*") LongPointer dxShapeInfo, - Pointer y, @Cast("const Nd4jLong*") LongPointer yShapeInfo, - Pointer dy, @Cast("const Nd4jLong*") LongPointer dyShapeInfo, - IntPointer dimension, - int dimensionLength, - @Cast("bool") boolean descending); -public native void sortTadByValue(@Cast("Nd4jPointer*") PointerPointer extraPointers, - Pointer x, @Cast("const Nd4jLong*") LongBuffer xShapeInfo, - Pointer dx, @Cast("const Nd4jLong*") LongBuffer dxShapeInfo, - Pointer y, @Cast("const Nd4jLong*") LongBuffer yShapeInfo, - Pointer dy, @Cast("const Nd4jLong*") LongBuffer dyShapeInfo, - IntBuffer dimension, - int dimensionLength, - @Cast("bool") boolean descending); -public native void sortTadByValue(@Cast("Nd4jPointer*") PointerPointer extraPointers, - Pointer x, @Cast("const Nd4jLong*") long[] xShapeInfo, - Pointer dx, @Cast("const Nd4jLong*") long[] dxShapeInfo, - Pointer y, @Cast("const Nd4jLong*") long[] yShapeInfo, - Pointer dy, @Cast("const Nd4jLong*") long[] dyShapeInfo, - int[] dimension, - int dimensionLength, - @Cast("bool") boolean descending); - +public native void tear(@Cast("Nd4jPointer*") PointerPointer extraPointers, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongPointer xShapeInfo, @Cast("const Nd4jLong*") LongPointer dxShapeInfo, + @Cast("Nd4jPointer*") PointerPointer targets, @Cast("const Nd4jLong*") LongPointer zShapeInfo, + @Cast("const Nd4jLong*") LongPointer tadShapeInfo, @Cast("const Nd4jLong*") LongPointer tadOffsets); +public native void tear(@Cast("Nd4jPointer*") PointerPointer extraPointers, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongBuffer xShapeInfo, @Cast("const Nd4jLong*") LongBuffer dxShapeInfo, + @Cast("Nd4jPointer*") PointerPointer targets, @Cast("const Nd4jLong*") LongBuffer zShapeInfo, + @Cast("const Nd4jLong*") LongBuffer tadShapeInfo, @Cast("const Nd4jLong*") LongBuffer tadOffsets); +public native void tear(@Cast("Nd4jPointer*") PointerPointer extraPointers, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") long[] xShapeInfo, @Cast("const Nd4jLong*") long[] dxShapeInfo, + @Cast("Nd4jPointer*") PointerPointer targets, @Cast("const Nd4jLong*") long[] zShapeInfo, + @Cast("const Nd4jLong*") long[] tadShapeInfo, @Cast("const Nd4jLong*") long[] tadOffsets); + +public native void sort(@Cast("Nd4jPointer*") PointerPointer extraPointers, Pointer x, + @Cast("const Nd4jLong*") LongPointer xShapeInfo, Pointer dx, + @Cast("const Nd4jLong*") LongPointer dxShapeInfo, @Cast("bool") boolean descending); +public native void sort(@Cast("Nd4jPointer*") PointerPointer extraPointers, Pointer x, + @Cast("const Nd4jLong*") LongBuffer xShapeInfo, Pointer dx, + @Cast("const Nd4jLong*") LongBuffer dxShapeInfo, @Cast("bool") boolean descending); +public native void sort(@Cast("Nd4jPointer*") PointerPointer extraPointers, Pointer x, + @Cast("const Nd4jLong*") long[] xShapeInfo, Pointer dx, + @Cast("const Nd4jLong*") long[] dxShapeInfo, @Cast("bool") boolean descending); + +public native void sortByKey(@Cast("Nd4jPointer*") PointerPointer extraPointers, Pointer x, + @Cast("const Nd4jLong*") LongPointer xShapeInfo, Pointer dx, + @Cast("const Nd4jLong*") LongPointer dxShapeInfo, Pointer y, + @Cast("const Nd4jLong*") LongPointer yShapeInfo, Pointer dy, + @Cast("const Nd4jLong*") LongPointer dyShapeInfo, @Cast("bool") boolean descending); +public native void sortByKey(@Cast("Nd4jPointer*") PointerPointer extraPointers, Pointer x, + @Cast("const Nd4jLong*") LongBuffer xShapeInfo, Pointer dx, + @Cast("const Nd4jLong*") LongBuffer dxShapeInfo, Pointer y, + @Cast("const Nd4jLong*") LongBuffer yShapeInfo, Pointer dy, + @Cast("const Nd4jLong*") LongBuffer dyShapeInfo, @Cast("bool") boolean descending); +public native void sortByKey(@Cast("Nd4jPointer*") PointerPointer extraPointers, Pointer x, + @Cast("const Nd4jLong*") long[] xShapeInfo, Pointer dx, + @Cast("const Nd4jLong*") long[] dxShapeInfo, Pointer y, + @Cast("const Nd4jLong*") long[] yShapeInfo, Pointer dy, + @Cast("const Nd4jLong*") long[] dyShapeInfo, @Cast("bool") boolean descending); + +public native void sortByValue(@Cast("Nd4jPointer*") PointerPointer extraPointers, Pointer x, + @Cast("const Nd4jLong*") LongPointer xShapeInfo, Pointer dx, + @Cast("const Nd4jLong*") LongPointer dxShapeInfo, Pointer y, + @Cast("const Nd4jLong*") LongPointer yShapeInfo, Pointer dy, + @Cast("const Nd4jLong*") LongPointer dyShapeInfo, @Cast("bool") boolean descending); +public native void sortByValue(@Cast("Nd4jPointer*") PointerPointer extraPointers, Pointer x, + @Cast("const Nd4jLong*") LongBuffer xShapeInfo, Pointer dx, + @Cast("const Nd4jLong*") LongBuffer dxShapeInfo, Pointer y, + @Cast("const Nd4jLong*") LongBuffer yShapeInfo, Pointer dy, + @Cast("const Nd4jLong*") LongBuffer dyShapeInfo, @Cast("bool") boolean descending); +public native void sortByValue(@Cast("Nd4jPointer*") PointerPointer extraPointers, Pointer x, + @Cast("const Nd4jLong*") long[] xShapeInfo, Pointer dx, + @Cast("const Nd4jLong*") long[] dxShapeInfo, Pointer y, + @Cast("const Nd4jLong*") long[] yShapeInfo, Pointer dy, + @Cast("const Nd4jLong*") long[] dyShapeInfo, @Cast("bool") boolean descending); + +public native void sortTad(@Cast("Nd4jPointer*") PointerPointer extraPointers, Pointer x, + @Cast("const Nd4jLong*") LongPointer xShapeInfo, Pointer dx, + @Cast("const Nd4jLong*") LongPointer dxShapeInfo, IntPointer dimension, + int dimensionLength, @Cast("const Nd4jLong*") LongPointer tadShapeInfo, + @Cast("const Nd4jLong*") LongPointer tadOffsets, @Cast("bool") boolean descending); +public native void sortTad(@Cast("Nd4jPointer*") PointerPointer extraPointers, Pointer x, + @Cast("const Nd4jLong*") LongBuffer xShapeInfo, Pointer dx, + @Cast("const Nd4jLong*") LongBuffer dxShapeInfo, IntBuffer dimension, + int dimensionLength, @Cast("const Nd4jLong*") LongBuffer tadShapeInfo, + @Cast("const Nd4jLong*") LongBuffer tadOffsets, @Cast("bool") boolean descending); +public native void sortTad(@Cast("Nd4jPointer*") PointerPointer extraPointers, Pointer x, + @Cast("const Nd4jLong*") long[] xShapeInfo, Pointer dx, + @Cast("const Nd4jLong*") long[] dxShapeInfo, int[] dimension, + int dimensionLength, @Cast("const Nd4jLong*") long[] tadShapeInfo, + @Cast("const Nd4jLong*") long[] tadOffsets, @Cast("bool") boolean descending); + +public native void sortTadByKey(@Cast("Nd4jPointer*") PointerPointer extraPointers, Pointer x, + @Cast("const Nd4jLong*") LongPointer xShapeInfo, Pointer dx, + @Cast("const Nd4jLong*") LongPointer dxShapeInfo, Pointer y, + @Cast("const Nd4jLong*") LongPointer yShapeInfo, Pointer dy, + @Cast("const Nd4jLong*") LongPointer dyShapeInfo, IntPointer dimension, + int dimensionLength, @Cast("bool") boolean descending); +public native void sortTadByKey(@Cast("Nd4jPointer*") PointerPointer extraPointers, Pointer x, + @Cast("const Nd4jLong*") LongBuffer xShapeInfo, Pointer dx, + @Cast("const Nd4jLong*") LongBuffer dxShapeInfo, Pointer y, + @Cast("const Nd4jLong*") LongBuffer yShapeInfo, Pointer dy, + @Cast("const Nd4jLong*") LongBuffer dyShapeInfo, IntBuffer dimension, + int dimensionLength, @Cast("bool") boolean descending); +public native void sortTadByKey(@Cast("Nd4jPointer*") PointerPointer extraPointers, Pointer x, + @Cast("const Nd4jLong*") long[] xShapeInfo, Pointer dx, + @Cast("const Nd4jLong*") long[] dxShapeInfo, Pointer y, + @Cast("const Nd4jLong*") long[] yShapeInfo, Pointer dy, + @Cast("const Nd4jLong*") long[] dyShapeInfo, int[] dimension, + int dimensionLength, @Cast("bool") boolean descending); + +public native void sortTadByValue(@Cast("Nd4jPointer*") PointerPointer extraPointers, Pointer x, + @Cast("const Nd4jLong*") LongPointer xShapeInfo, Pointer dx, + @Cast("const Nd4jLong*") LongPointer dxShapeInfo, Pointer y, + @Cast("const Nd4jLong*") LongPointer yShapeInfo, Pointer dy, + @Cast("const Nd4jLong*") LongPointer dyShapeInfo, IntPointer dimension, + int dimensionLength, @Cast("bool") boolean descending); +public native void sortTadByValue(@Cast("Nd4jPointer*") PointerPointer extraPointers, Pointer x, + @Cast("const Nd4jLong*") LongBuffer xShapeInfo, Pointer dx, + @Cast("const Nd4jLong*") LongBuffer dxShapeInfo, Pointer y, + @Cast("const Nd4jLong*") LongBuffer yShapeInfo, Pointer dy, + @Cast("const Nd4jLong*") LongBuffer dyShapeInfo, IntBuffer dimension, + int dimensionLength, @Cast("bool") boolean descending); +public native void sortTadByValue(@Cast("Nd4jPointer*") PointerPointer extraPointers, Pointer x, + @Cast("const Nd4jLong*") long[] xShapeInfo, Pointer dx, + @Cast("const Nd4jLong*") long[] dxShapeInfo, Pointer y, + @Cast("const Nd4jLong*") long[] yShapeInfo, Pointer dy, + @Cast("const Nd4jLong*") long[] dyShapeInfo, int[] dimension, + int dimensionLength, @Cast("bool") boolean descending); // special sort impl for sorting out COO indices and values -public native void sortCooIndices(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong*") LongPointer indices, Pointer values, @Cast("Nd4jLong") long length, int rank); -public native void sortCooIndices(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong*") LongBuffer indices, Pointer values, @Cast("Nd4jLong") long length, int rank); -public native void sortCooIndices(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong*") long[] indices, Pointer values, @Cast("Nd4jLong") long length, int rank); - - -public native @Cast("Nd4jLong*") LongPointer mmapFile(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("char*") String fileName, @Cast("Nd4jLong") long length); -public native @Cast("Nd4jLong*") LongBuffer mmapFile(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("char*") BytePointer fileName, @Cast("Nd4jLong") long length); - -public native void munmapFile(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong*") LongPointer ptrMap, @Cast("Nd4jLong") long length); -public native void munmapFile(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong*") LongBuffer ptrMap, @Cast("Nd4jLong") long length); -public native void munmapFile(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong*") long[] ptrMap, @Cast("Nd4jLong") long length); +public native void sortCooIndices(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong*") LongPointer indices, + Pointer values, @Cast("Nd4jLong") long length, int rank); +public native void sortCooIndices(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong*") LongBuffer indices, + Pointer values, @Cast("Nd4jLong") long length, int rank); +public native void sortCooIndices(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong*") long[] indices, + Pointer values, @Cast("Nd4jLong") long length, int rank); + +public native @Cast("Nd4jLong*") LongPointer mmapFile(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("char*") String fileName, + @Cast("Nd4jLong") long length); +public native @Cast("Nd4jLong*") LongBuffer mmapFile(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("char*") BytePointer fileName, + @Cast("Nd4jLong") long length); + +public native void munmapFile(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong*") LongPointer ptrMap, + @Cast("Nd4jLong") long length); +public native void munmapFile(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong*") LongBuffer ptrMap, + @Cast("Nd4jLong") long length); +public native void munmapFile(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong*") long[] ptrMap, + @Cast("Nd4jLong") long length); // flatbuffers execution -public native OpaqueResultWrapper executeFlatGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer flatBufferPointer); +public native OpaqueResultWrapper executeFlatGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, + @Cast("Nd4jPointer") Pointer flatBufferPointer); public native @Cast("Nd4jLong") long getResultWrapperSize(OpaqueResultWrapper ptr); public native @Cast("Nd4jPointer") Pointer getResultWrapperPointer(OpaqueResultWrapper ptr); @@ -3136,34 +3235,117 @@ public native void sortTadByValue(@Cast("Nd4jPointer*") PointerPointer extraPoin public native @Cast("char*") String getAllOperations(); // customOp executioner -public native int execCustomOp(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputs, @Cast("Nd4jPointer*") PointerPointer outputBuffers, @Cast("Nd4jPointer*") PointerPointer outputShapes, int numOutputs, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs, @Cast("bool") boolean isInplace); -public native int execCustomOp(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputs, @Cast("Nd4jPointer*") PointerPointer outputBuffers, @Cast("Nd4jPointer*") PointerPointer outputShapes, int numOutputs, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs, @Cast("bool") boolean isInplace); -public native int execCustomOp(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputs, @Cast("Nd4jPointer*") PointerPointer outputBuffers, @Cast("Nd4jPointer*") PointerPointer outputShapes, int numOutputs, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs, @Cast("bool") boolean isInplace); -public native int execCustomOp(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputs, @Cast("Nd4jPointer*") PointerPointer outputBuffers, @Cast("Nd4jPointer*") PointerPointer outputShapes, int numOutputs, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs, @Cast("bool") boolean isInplace); -public native int execCustomOp(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputs, @Cast("Nd4jPointer*") PointerPointer outputBuffers, @Cast("Nd4jPointer*") PointerPointer outputShapes, int numOutputs, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs, @Cast("bool") boolean isInplace); -public native int execCustomOp(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputs, @Cast("Nd4jPointer*") PointerPointer outputBuffers, @Cast("Nd4jPointer*") PointerPointer outputShapes, int numOutputs, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs, @Cast("bool") boolean isInplace); -public native int execCustomOp2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer") Pointer opContext); - -public native OpaqueShapeList calculateOutputShapes(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs); -public native OpaqueShapeList calculateOutputShapes(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs); -public native OpaqueShapeList calculateOutputShapes(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs); -public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs, IntPointer dArgs, int numDArgs); -public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs, IntBuffer dArgs, int numDArgs); -public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs, int[] dArgs, int numDArgs); -public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs, IntPointer dArgs, int numDArgs); -public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs, IntBuffer dArgs, int numDArgs); -public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs, int[] dArgs, int numDArgs); +public native int execCustomOp(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, + @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, + int numInputs, @Cast("Nd4jPointer*") PointerPointer outputBuffers, + @Cast("Nd4jPointer*") PointerPointer outputShapes, int numOutputs, + DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, + int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs, + @Cast("bool") boolean isInplace); +public native int execCustomOp(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, + @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, + int numInputs, @Cast("Nd4jPointer*") PointerPointer outputBuffers, + @Cast("Nd4jPointer*") PointerPointer outputShapes, int numOutputs, + DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, + int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs, + @Cast("bool") boolean isInplace); +public native int execCustomOp(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, + @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, + int numInputs, @Cast("Nd4jPointer*") PointerPointer outputBuffers, + @Cast("Nd4jPointer*") PointerPointer outputShapes, int numOutputs, + double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, + int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs, + @Cast("bool") boolean isInplace); +public native int execCustomOp(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, + @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, + int numInputs, @Cast("Nd4jPointer*") PointerPointer outputBuffers, + @Cast("Nd4jPointer*") PointerPointer outputShapes, int numOutputs, + DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, + int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs, + @Cast("bool") boolean isInplace); +public native int execCustomOp(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, + @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, + int numInputs, @Cast("Nd4jPointer*") PointerPointer outputBuffers, + @Cast("Nd4jPointer*") PointerPointer outputShapes, int numOutputs, + DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, + int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs, + @Cast("bool") boolean isInplace); +public native int execCustomOp(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, + @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, + int numInputs, @Cast("Nd4jPointer*") PointerPointer outputBuffers, + @Cast("Nd4jPointer*") PointerPointer outputShapes, int numOutputs, + double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, + int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs, + @Cast("bool") boolean isInplace); +public native int execCustomOp2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, + @Cast("Nd4jPointer") Pointer opContext); + +public native OpaqueShapeList calculateOutputShapes(@Cast("Nd4jPointer*") PointerPointer extraPointers, + @Cast("Nd4jLong") long hash, + @Cast("Nd4jPointer*") PointerPointer inputShapes, + int numInputShapes, + DoublePointer tArgs, int numTArgs, + @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs); +public native OpaqueShapeList calculateOutputShapes(@Cast("Nd4jPointer*") PointerPointer extraPointers, + @Cast("Nd4jLong") long hash, + @Cast("Nd4jPointer*") PointerPointer inputShapes, + int numInputShapes, + DoubleBuffer tArgs, int numTArgs, + @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs); +public native OpaqueShapeList calculateOutputShapes(@Cast("Nd4jPointer*") PointerPointer extraPointers, + @Cast("Nd4jLong") long hash, + @Cast("Nd4jPointer*") PointerPointer inputShapes, + int numInputShapes, + double[] tArgs, int numTArgs, + @Cast("Nd4jLong*") long[] iArgs, int numIArgs); +public native OpaqueShapeList calculateOutputShapes2( + @Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, + @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, + @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs, IntPointer dArgs, + int numDArgs); +public native OpaqueShapeList calculateOutputShapes2( + @Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, + @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, + @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs, IntBuffer dArgs, + int numDArgs); +public native OpaqueShapeList calculateOutputShapes2( + @Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, + @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, + @Cast("Nd4jLong*") long[] iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs, int[] dArgs, + int numDArgs); +public native OpaqueShapeList calculateOutputShapes2( + @Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, + @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, + @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs, IntPointer dArgs, + int numDArgs); +public native OpaqueShapeList calculateOutputShapes2( + @Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, + @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, + @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs, IntBuffer dArgs, + int numDArgs); +public native OpaqueShapeList calculateOutputShapes2( + @Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, + @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, + @Cast("Nd4jLong*") long[] iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs, int[] dArgs, + int numDArgs); public native @Cast("Nd4jLong") long getShapeListSize(OpaqueShapeList list); public native @Cast("const Nd4jLong*") LongPointer getShape(OpaqueShapeList list, @Cast("Nd4jLong") long i); public native void deleteShapeList(@Cast("Nd4jPointer") Pointer shapeList); -public native int registerGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer") Pointer flatBufferPointer); +public native int registerGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, + @Cast("Nd4jPointer") Pointer flatBufferPointer); -public native OpaqueVariablesSet executeStoredGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, IntPointer inputIndices, int numInputs); -public native OpaqueVariablesSet executeStoredGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, IntBuffer inputIndices, int numInputs); -public native OpaqueVariablesSet executeStoredGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int[] inputIndices, int numInputs); +public native OpaqueVariablesSet executeStoredGraph( + @Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, + @Cast("Nd4jPointer*") PointerPointer inputShapes, IntPointer inputIndices, int numInputs); +public native OpaqueVariablesSet executeStoredGraph( + @Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, + @Cast("Nd4jPointer*") PointerPointer inputShapes, IntBuffer inputIndices, int numInputs); +public native OpaqueVariablesSet executeStoredGraph( + @Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, + @Cast("Nd4jPointer*") PointerPointer inputShapes, int[] inputIndices, int numInputs); public native @Cast("Nd4jLong") long getVariablesSetSize(OpaqueVariablesSet set); public native @Cast("Nd4jStatus") int getVariablesSetStatus(OpaqueVariablesSet set); @@ -3183,68 +3365,123 @@ public native void sortTadByValue(@Cast("Nd4jPointer*") PointerPointer extraPoin public native void deleteVariablesSet(OpaqueVariablesSet pointer); -// GraphState creation -public native @Cast("Nd4jPointer") Pointer getGraphState(@Cast("Nd4jLong") long id); - -public native void deleteGraphState(@Cast("Nd4jPointer") Pointer state); - public native void deleteResultWrapper(@Cast("Nd4jPointer") Pointer ptr); -public native int estimateThreshold(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer x, @Cast("const Nd4jLong*") LongPointer xShapeInfo, int N, float threshold); -public native int estimateThreshold(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer x, @Cast("const Nd4jLong*") LongBuffer xShapeInfo, int N, float threshold); -public native int estimateThreshold(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer x, @Cast("const Nd4jLong*") long[] xShapeInfo, int N, float threshold); - -// this method executes op that requires scope to be present: if/while/cond/whatever -public native @Cast("Nd4jStatus") int execCustomOpWithScope(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer state, @Cast("Nd4jLong") long opHash, @Cast("Nd4jLong*") LongPointer scopes, int numScopes, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputs, @Cast("Nd4jPointer*") PointerPointer outputBuffers, @Cast("Nd4jPointer*") PointerPointer outputShapes, int numOutputs); -public native @Cast("Nd4jStatus") int execCustomOpWithScope(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer state, @Cast("Nd4jLong") long opHash, @Cast("Nd4jLong*") LongBuffer scopes, int numScopes, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputs, @Cast("Nd4jPointer*") PointerPointer outputBuffers, @Cast("Nd4jPointer*") PointerPointer outputShapes, int numOutputs); -public native @Cast("Nd4jStatus") int execCustomOpWithScope(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer state, @Cast("Nd4jLong") long opHash, @Cast("Nd4jLong*") long[] scopes, int numScopes, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputs, @Cast("Nd4jPointer*") PointerPointer outputBuffers, @Cast("Nd4jPointer*") PointerPointer outputShapes, int numOutputs); - -//void fillUtf8String(Nd4jPointer *extraPointers, const char **string, int numStrings, Nd4jPointer buffer); -public native @Cast("Nd4jPointer") Pointer createUtf8String(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("char*") String string, int length); -public native @Cast("Nd4jPointer") Pointer createUtf8String(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("char*") BytePointer string, int length); -public native @Cast("Nd4jLong") long getUtf8StringLength(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer ptr); -public native @Cast("char*") BytePointer getUtf8StringBuffer(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer ptr); +public native int estimateThreshold(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer x, + @Cast("const Nd4jLong*") LongPointer xShapeInfo, int N, + float threshold); +public native int estimateThreshold(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer x, + @Cast("const Nd4jLong*") LongBuffer xShapeInfo, int N, + float threshold); +public native int estimateThreshold(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer x, + @Cast("const Nd4jLong*") long[] xShapeInfo, int N, + float threshold); + +// this method executes op that requires scope to be present: +// if/while/cond/whatever +public native @Cast("Nd4jStatus") int execCustomOpWithScope( + @Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer state, @Cast("Nd4jLong") long opHash, + @Cast("Nd4jLong*") LongPointer scopes, int numScopes, @Cast("Nd4jPointer*") PointerPointer inputBuffers, + @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputs, @Cast("Nd4jPointer*") PointerPointer outputBuffers, + @Cast("Nd4jPointer*") PointerPointer outputShapes, int numOutputs); +public native @Cast("Nd4jStatus") int execCustomOpWithScope( + @Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer state, @Cast("Nd4jLong") long opHash, + @Cast("Nd4jLong*") LongBuffer scopes, int numScopes, @Cast("Nd4jPointer*") PointerPointer inputBuffers, + @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputs, @Cast("Nd4jPointer*") PointerPointer outputBuffers, + @Cast("Nd4jPointer*") PointerPointer outputShapes, int numOutputs); +public native @Cast("Nd4jStatus") int execCustomOpWithScope( + @Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer state, @Cast("Nd4jLong") long opHash, + @Cast("Nd4jLong*") long[] scopes, int numScopes, @Cast("Nd4jPointer*") PointerPointer inputBuffers, + @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputs, @Cast("Nd4jPointer*") PointerPointer outputBuffers, + @Cast("Nd4jPointer*") PointerPointer outputShapes, int numOutputs); + +// void fillUtf8String(Nd4jPointer *extraPointers, const char **string, int +// numStrings, Nd4jPointer buffer); +public native @Cast("Nd4jPointer") Pointer createUtf8String(@Cast("Nd4jPointer*") PointerPointer extraPointers, + @Cast("char*") String string, int length); +public native @Cast("Nd4jPointer") Pointer createUtf8String(@Cast("Nd4jPointer*") PointerPointer extraPointers, + @Cast("char*") BytePointer string, int length); +public native @Cast("Nd4jLong") long getUtf8StringLength(@Cast("Nd4jPointer*") PointerPointer extraPointers, + @Cast("Nd4jPointer") Pointer ptr); +public native @Cast("char*") BytePointer getUtf8StringBuffer(@Cast("Nd4jPointer*") PointerPointer extraPointers, + @Cast("Nd4jPointer") Pointer ptr); public native void deleteUtf8String(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer ptr); -public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opCode, int numOfSubArrs, - Pointer hX, @Cast("const Nd4jLong*") LongPointer hXShapeInfo, @Cast("const Nd4jLong*") LongPointer hXOffsets, - Pointer dX, @Cast("const Nd4jLong*") LongPointer dXShapeInfo, @Cast("const Nd4jLong*") LongPointer dXOffsets, - Pointer hY, @Cast("const Nd4jLong*") LongPointer hYShapeInfo, @Cast("const Nd4jLong*") LongPointer hYOffsets, - Pointer dY, @Cast("const Nd4jLong*") LongPointer dYShapeInfo, @Cast("const Nd4jLong*") LongPointer dYOffsets, - Pointer hIindexes, @Cast("const Nd4jLong*") LongPointer hIndicesShapeInfo, Pointer dIindexes, @Cast("const Nd4jLong*") LongPointer dIndicesShapeInfo); -public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opCode, int numOfSubArrs, - Pointer hX, @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, @Cast("const Nd4jLong*") LongBuffer hXOffsets, - Pointer dX, @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, @Cast("const Nd4jLong*") LongBuffer dXOffsets, - Pointer hY, @Cast("const Nd4jLong*") LongBuffer hYShapeInfo, @Cast("const Nd4jLong*") LongBuffer hYOffsets, - Pointer dY, @Cast("const Nd4jLong*") LongBuffer dYShapeInfo, @Cast("const Nd4jLong*") LongBuffer dYOffsets, - Pointer hIindexes, @Cast("const Nd4jLong*") LongBuffer hIndicesShapeInfo, Pointer dIindexes, @Cast("const Nd4jLong*") LongBuffer dIndicesShapeInfo); -public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opCode, int numOfSubArrs, - Pointer hX, @Cast("const Nd4jLong*") long[] hXShapeInfo, @Cast("const Nd4jLong*") long[] hXOffsets, - Pointer dX, @Cast("const Nd4jLong*") long[] dXShapeInfo, @Cast("const Nd4jLong*") long[] dXOffsets, - Pointer hY, @Cast("const Nd4jLong*") long[] hYShapeInfo, @Cast("const Nd4jLong*") long[] hYOffsets, - Pointer dY, @Cast("const Nd4jLong*") long[] dYShapeInfo, @Cast("const Nd4jLong*") long[] dYOffsets, - Pointer hIindexes, @Cast("const Nd4jLong*") long[] hIndicesShapeInfo, Pointer dIindexes, @Cast("const Nd4jLong*") long[] dIndicesShapeInfo); - -public native void inspectArray(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer buffer, @Cast("Nd4jLong*") LongPointer shapeInfo, @Cast("Nd4jPointer") Pointer specialBuffer, @Cast("Nd4jLong*") LongPointer specialShapeInfo, @Cast("Nd4jPointer") Pointer debugInfo); -public native void inspectArray(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer buffer, @Cast("Nd4jLong*") LongBuffer shapeInfo, @Cast("Nd4jPointer") Pointer specialBuffer, @Cast("Nd4jLong*") LongBuffer specialShapeInfo, @Cast("Nd4jPointer") Pointer debugInfo); -public native void inspectArray(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer buffer, @Cast("Nd4jLong*") long[] shapeInfo, @Cast("Nd4jPointer") Pointer specialBuffer, @Cast("Nd4jLong*") long[] specialShapeInfo, @Cast("Nd4jPointer") Pointer debugInfo); - -public native OpaqueConstantShapeBuffer shapeBuffer(int rank, @Cast("Nd4jLong*") LongPointer shape, @Cast("Nd4jLong*") LongPointer strides, @Cast("sd::DataType") int dtype, char order, @Cast("Nd4jLong") long ews, @Cast("bool") boolean empty); -public native OpaqueConstantShapeBuffer shapeBuffer(int rank, @Cast("Nd4jLong*") LongBuffer shape, @Cast("Nd4jLong*") LongBuffer strides, @Cast("sd::DataType") int dtype, char order, @Cast("Nd4jLong") long ews, @Cast("bool") boolean empty); -public native OpaqueConstantShapeBuffer shapeBuffer(int rank, @Cast("Nd4jLong*") long[] shape, @Cast("Nd4jLong*") long[] strides, @Cast("sd::DataType") int dtype, char order, @Cast("Nd4jLong") long ews, @Cast("bool") boolean empty); - -public native OpaqueConstantDataBuffer constantBufferLong(@Cast("sd::DataType") int dtype, @Cast("const Nd4jLong*") LongPointer data, int length); -public native OpaqueConstantDataBuffer constantBufferLong(@Cast("sd::DataType") int dtype, @Cast("const Nd4jLong*") LongBuffer data, int length); -public native OpaqueConstantDataBuffer constantBufferLong(@Cast("sd::DataType") int dtype, @Cast("const Nd4jLong*") long[] data, int length); -public native OpaqueConstantDataBuffer constantBufferDouble(@Cast("sd::DataType") int dtype, DoublePointer data, int length); -public native OpaqueConstantDataBuffer constantBufferDouble(@Cast("sd::DataType") int dtype, DoubleBuffer data, int length); -public native OpaqueConstantDataBuffer constantBufferDouble(@Cast("sd::DataType") int dtype, double[] data, int length); -public native OpaqueConstantDataBuffer constantBuffer(@Cast("sd::DataType") int dtype, ConstantDescriptor descriptor); +public native void scatterUpdate( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opCode, int numOfSubArrs, Pointer hX, + @Cast("const Nd4jLong*") LongPointer hXShapeInfo, @Cast("const Nd4jLong*") LongPointer hXOffsets, Pointer dX, + @Cast("const Nd4jLong*") LongPointer dXShapeInfo, @Cast("const Nd4jLong*") LongPointer dXOffsets, Pointer hY, + @Cast("const Nd4jLong*") LongPointer hYShapeInfo, @Cast("const Nd4jLong*") LongPointer hYOffsets, Pointer dY, + @Cast("const Nd4jLong*") LongPointer dYShapeInfo, @Cast("const Nd4jLong*") LongPointer dYOffsets, Pointer hIindexes, + @Cast("const Nd4jLong*") LongPointer hIndicesShapeInfo, Pointer dIindexes, + @Cast("const Nd4jLong*") LongPointer dIndicesShapeInfo); +public native void scatterUpdate( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opCode, int numOfSubArrs, Pointer hX, + @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, @Cast("const Nd4jLong*") LongBuffer hXOffsets, Pointer dX, + @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, @Cast("const Nd4jLong*") LongBuffer dXOffsets, Pointer hY, + @Cast("const Nd4jLong*") LongBuffer hYShapeInfo, @Cast("const Nd4jLong*") LongBuffer hYOffsets, Pointer dY, + @Cast("const Nd4jLong*") LongBuffer dYShapeInfo, @Cast("const Nd4jLong*") LongBuffer dYOffsets, Pointer hIindexes, + @Cast("const Nd4jLong*") LongBuffer hIndicesShapeInfo, Pointer dIindexes, + @Cast("const Nd4jLong*") LongBuffer dIndicesShapeInfo); +public native void scatterUpdate( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opCode, int numOfSubArrs, Pointer hX, + @Cast("const Nd4jLong*") long[] hXShapeInfo, @Cast("const Nd4jLong*") long[] hXOffsets, Pointer dX, + @Cast("const Nd4jLong*") long[] dXShapeInfo, @Cast("const Nd4jLong*") long[] dXOffsets, Pointer hY, + @Cast("const Nd4jLong*") long[] hYShapeInfo, @Cast("const Nd4jLong*") long[] hYOffsets, Pointer dY, + @Cast("const Nd4jLong*") long[] dYShapeInfo, @Cast("const Nd4jLong*") long[] dYOffsets, Pointer hIindexes, + @Cast("const Nd4jLong*") long[] hIndicesShapeInfo, Pointer dIindexes, + @Cast("const Nd4jLong*") long[] dIndicesShapeInfo); + +public native void inspectArray(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer buffer, + @Cast("Nd4jLong*") LongPointer shapeInfo, @Cast("Nd4jPointer") Pointer specialBuffer, + @Cast("Nd4jLong*") LongPointer specialShapeInfo, @Cast("Nd4jPointer") Pointer debugInfo); +public native void inspectArray(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer buffer, + @Cast("Nd4jLong*") LongBuffer shapeInfo, @Cast("Nd4jPointer") Pointer specialBuffer, + @Cast("Nd4jLong*") LongBuffer specialShapeInfo, @Cast("Nd4jPointer") Pointer debugInfo); +public native void inspectArray(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer buffer, + @Cast("Nd4jLong*") long[] shapeInfo, @Cast("Nd4jPointer") Pointer specialBuffer, + @Cast("Nd4jLong*") long[] specialShapeInfo, @Cast("Nd4jPointer") Pointer debugInfo); + +public native OpaqueConstantShapeBuffer shapeBuffer(int rank, @Cast("Nd4jLong*") LongPointer shape, + @Cast("Nd4jLong*") LongPointer strides, + @Cast("sd::DataType") int dtype, char order, + @Cast("Nd4jLong") long ews, @Cast("bool") boolean empty); +public native OpaqueConstantShapeBuffer shapeBuffer(int rank, @Cast("Nd4jLong*") LongBuffer shape, + @Cast("Nd4jLong*") LongBuffer strides, + @Cast("sd::DataType") int dtype, char order, + @Cast("Nd4jLong") long ews, @Cast("bool") boolean empty); +public native OpaqueConstantShapeBuffer shapeBuffer(int rank, @Cast("Nd4jLong*") long[] shape, + @Cast("Nd4jLong*") long[] strides, + @Cast("sd::DataType") int dtype, char order, + @Cast("Nd4jLong") long ews, @Cast("bool") boolean empty); + +public native OpaqueConstantDataBuffer constantBufferLong(@Cast("sd::DataType") int dtype, + @Cast("const Nd4jLong*") LongPointer data, + int length); +public native OpaqueConstantDataBuffer constantBufferLong(@Cast("sd::DataType") int dtype, + @Cast("const Nd4jLong*") LongBuffer data, + int length); +public native OpaqueConstantDataBuffer constantBufferLong(@Cast("sd::DataType") int dtype, + @Cast("const Nd4jLong*") long[] data, + int length); +public native OpaqueConstantDataBuffer constantBufferDouble(@Cast("sd::DataType") int dtype, + DoublePointer data, + int length); +public native OpaqueConstantDataBuffer constantBufferDouble(@Cast("sd::DataType") int dtype, + DoubleBuffer data, + int length); +public native OpaqueConstantDataBuffer constantBufferDouble(@Cast("sd::DataType") int dtype, + double[] data, + int length); +public native OpaqueConstantDataBuffer constantBuffer( + @Cast("sd::DataType") int dtype, ConstantDescriptor descriptor); public native @Cast("Nd4jPointer") Pointer getConstantDataBufferPrimary(OpaqueConstantDataBuffer dbf); public native @Cast("Nd4jPointer") Pointer getConstantDataBufferSpecial(OpaqueConstantDataBuffer dbf); public native @Cast("Nd4jLong") long getConstantDataBufferLength(OpaqueConstantDataBuffer dbf); +public native @Cast("Nd4jLong") long getConstantDataBufferSizeOf(OpaqueConstantDataBuffer dbf); +public native void deleteShapeBuffer(OpaqueConstantDataBuffer ptr); public native @Cast("Nd4jPointer") Pointer getConstantShapeBufferPrimary(OpaqueConstantShapeBuffer dbf); public native @Cast("Nd4jPointer") Pointer getConstantShapeBufferSpecial(OpaqueConstantShapeBuffer dbf); @@ -3252,28 +3489,58 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint public native void deleteConstantDataBuffer(OpaqueConstantDataBuffer ptr); public native OpaqueContext createGraphContext(int nodeId); -public native OpaqueRandomGenerator getGraphContextRandomGenerator(OpaqueContext ptr); +public native OpaqueRandomGenerator getGraphContextRandomGenerator( + OpaqueContext ptr); public native void ctxAllowHelpers(OpaqueContext ptr, @Cast("bool") boolean reallyAllow); -public native void ctxShapeFunctionOverride(OpaqueContext ptr, @Cast("bool") boolean reallyOverride); +public native void ctxShapeFunctionOverride(OpaqueContext ptr, + @Cast("bool") boolean reallyOverride); public native void ctxSetExecutionMode(OpaqueContext ptr, int execMode); public native void ctxPurge(OpaqueContext ptr); public native void markGraphContextInplace(OpaqueContext ptr, @Cast("bool") boolean reallyInplace); -public native void setGraphContextCudaContext(OpaqueContext ptr, Pointer stream, Pointer reductionPointer, Pointer allocationPointer); -public native void setGraphContextInputArray(OpaqueContext ptr, int index, Pointer buffer, Pointer shapeInfo, Pointer specialBuffer, Pointer specialShapeInfo); -public native void setGraphContextOutputArray(OpaqueContext ptr, int index, Pointer buffer, Pointer shapeInfo, Pointer specialBuffer, Pointer specialShapeInfo); -public native void setGraphContextInputBuffer(OpaqueContext ptr, int index, OpaqueDataBuffer buffer, Pointer shapeInfo, Pointer specialShapeInfo); -public native void setGraphContextOutputBuffer(OpaqueContext ptr, int index, OpaqueDataBuffer buffer, Pointer shapeInfo, Pointer specialShapeInfo); -public native void setGraphContextDArguments(OpaqueContext ptr, IntPointer arguments, int numberOfArguments); -public native void setGraphContextDArguments(OpaqueContext ptr, IntBuffer arguments, int numberOfArguments); -public native void setGraphContextDArguments(OpaqueContext ptr, int[] arguments, int numberOfArguments); -public native void setGraphContextTArguments(OpaqueContext ptr, DoublePointer arguments, int numberOfArguments); -public native void setGraphContextTArguments(OpaqueContext ptr, DoubleBuffer arguments, int numberOfArguments); -public native void setGraphContextTArguments(OpaqueContext ptr, double[] arguments, int numberOfArguments); -public native void setGraphContextIArguments(OpaqueContext ptr, @Cast("Nd4jLong*") LongPointer arguments, int numberOfArguments); -public native void setGraphContextIArguments(OpaqueContext ptr, @Cast("Nd4jLong*") LongBuffer arguments, int numberOfArguments); -public native void setGraphContextIArguments(OpaqueContext ptr, @Cast("Nd4jLong*") long[] arguments, int numberOfArguments); -public native void setGraphContextBArguments(OpaqueContext ptr, @Cast("bool*") BooleanPointer arguments, int numberOfArguments); -public native void setGraphContextBArguments(OpaqueContext ptr, @Cast("bool*") boolean[] arguments, int numberOfArguments); +public native void setGraphContextCudaContext(OpaqueContext ptr, Pointer stream, + Pointer reductionPointer, + Pointer allocationPointer); +public native void setGraphContextInputArray(OpaqueContext ptr, int index, + Pointer buffer, Pointer shapeInfo, + Pointer specialBuffer, + Pointer specialShapeInfo); +public native void setGraphContextOutputArray(OpaqueContext ptr, int index, + Pointer buffer, Pointer shapeInfo, + Pointer specialBuffer, + Pointer specialShapeInfo); +public native void setGraphContextInputBuffer(OpaqueContext ptr, int index, + OpaqueDataBuffer buffer, + Pointer shapeInfo, + Pointer specialShapeInfo); +public native void setGraphContextOutputBuffer(OpaqueContext ptr, int index, + OpaqueDataBuffer buffer, + Pointer shapeInfo, + Pointer specialShapeInfo); +public native void setGraphContextDArguments(OpaqueContext ptr, IntPointer arguments, + int numberOfArguments); +public native void setGraphContextDArguments(OpaqueContext ptr, IntBuffer arguments, + int numberOfArguments); +public native void setGraphContextDArguments(OpaqueContext ptr, int[] arguments, + int numberOfArguments); +public native void setGraphContextTArguments(OpaqueContext ptr, DoublePointer arguments, + int numberOfArguments); +public native void setGraphContextTArguments(OpaqueContext ptr, DoubleBuffer arguments, + int numberOfArguments); +public native void setGraphContextTArguments(OpaqueContext ptr, double[] arguments, + int numberOfArguments); +public native void setGraphContextIArguments(OpaqueContext ptr, + @Cast("Nd4jLong*") LongPointer arguments, + int numberOfArguments); +public native void setGraphContextIArguments(OpaqueContext ptr, + @Cast("Nd4jLong*") LongBuffer arguments, + int numberOfArguments); +public native void setGraphContextIArguments(OpaqueContext ptr, + @Cast("Nd4jLong*") long[] arguments, + int numberOfArguments); +public native void setGraphContextBArguments(OpaqueContext ptr, @Cast("bool*") BooleanPointer arguments, + int numberOfArguments); +public native void setGraphContextBArguments(OpaqueContext ptr, @Cast("bool*") boolean[] arguments, + int numberOfArguments); public native void deleteGraphContext(OpaqueContext ptr); public native OpaqueRandomGenerator createRandomGenerator(@Cast("Nd4jLong") long rootSeed/*=0*/, @Cast("Nd4jLong") long nodeSeed/*=0*/); @@ -3300,17 +3567,26 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint public native @Cast("Nd4jPointer") Pointer lcBlasHandle(OpaqueLaunchContext lc); public native @Cast("Nd4jPointer") Pointer lcSolverHandle(OpaqueLaunchContext lc); -public native OpaqueDataBuffer allocateDataBuffer(@Cast("Nd4jLong") long elements, int dataType, @Cast("bool") boolean allocateBoth); -public native OpaqueDataBuffer dbAllocateDataBuffer(@Cast("Nd4jLong") long elements, int dataType, @Cast("bool") boolean allocateBoth); -public native OpaqueDataBuffer dbCreateExternalDataBuffer(@Cast("Nd4jLong") long elements, int dataType, @Cast("Nd4jPointer") Pointer primary, @Cast("Nd4jPointer") Pointer special); -public native OpaqueDataBuffer dbCreateView(OpaqueDataBuffer dataBuffer, @Cast("Nd4jLong") long length, @Cast("Nd4jLong") long offset); +public native OpaqueDataBuffer allocateDataBuffer(@Cast("Nd4jLong") long elements, int dataType, + @Cast("bool") boolean allocateBoth); +public native OpaqueDataBuffer dbAllocateDataBuffer(@Cast("Nd4jLong") long elements, + int dataType, + @Cast("bool") boolean allocateBoth); +public native OpaqueDataBuffer dbCreateExternalDataBuffer(@Cast("Nd4jLong") long elements, + int dataType, + @Cast("Nd4jPointer") Pointer primary, + @Cast("Nd4jPointer") Pointer special); +public native OpaqueDataBuffer dbCreateView(OpaqueDataBuffer dataBuffer, + @Cast("Nd4jLong") long length, @Cast("Nd4jLong") long offset); public native @Cast("Nd4jPointer") Pointer dbPrimaryBuffer(OpaqueDataBuffer dataBuffer); public native @Cast("Nd4jPointer") Pointer dbSpecialBuffer(OpaqueDataBuffer dataBuffer); public native void dbExpandBuffer(OpaqueDataBuffer dataBuffer, @Cast("Nd4jLong") long elements); public native void dbAllocatePrimaryBuffer(OpaqueDataBuffer dataBuffer); public native void dbAllocateSpecialBuffer(OpaqueDataBuffer dataBuffer); -public native void dbSetPrimaryBuffer(OpaqueDataBuffer dataBuffer, @Cast("Nd4jPointer") Pointer primaryBuffer, @Cast("Nd4jLong") long numBytes); -public native void dbSetSpecialBuffer(OpaqueDataBuffer dataBuffer, @Cast("Nd4jPointer") Pointer specialBuffer, @Cast("Nd4jLong") long numBytes); +public native void dbSetPrimaryBuffer(OpaqueDataBuffer dataBuffer, + @Cast("Nd4jPointer") Pointer primaryBuffer, @Cast("Nd4jLong") long numBytes); +public native void dbSetSpecialBuffer(OpaqueDataBuffer dataBuffer, + @Cast("Nd4jPointer") Pointer specialBuffer, @Cast("Nd4jLong") long numBytes); public native void dbSyncToSpecial(OpaqueDataBuffer dataBuffer); public native void dbSyncToPrimary(OpaqueDataBuffer dataBuffer); public native int dbLocality(OpaqueDataBuffer dataBuffer); @@ -3324,14 +3600,13 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint public native void deleteDataBuffer(OpaqueDataBuffer dataBuffer); public native void dbExpand(OpaqueDataBuffer dataBuffer, @Cast("Nd4jLong") long elements); - public native int binaryLevel(); public native int optimalLevel(); public native @Cast("bool") boolean isMinimalRequirementsMet(); public native @Cast("bool") boolean isOptimalRequirementsMet(); -// #endif //NATIVEOPERATIONS_NATIVEOPS_H +// #endif // NATIVEOPERATIONS_NATIVEOPS_H // Parsed from memory/ExternalWorkspace.h @@ -3359,33 +3634,35 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #ifndef LIBND4J_EXTERNALWORKSPACE_H // #define LIBND4J_EXTERNALWORKSPACE_H -// #include // #include - @Namespace("sd::memory") @NoOffset public static class ExternalWorkspace extends Pointer { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public ExternalWorkspace(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public ExternalWorkspace(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public ExternalWorkspace position(long position) { - return (ExternalWorkspace)super.position(position); - } - - public ExternalWorkspace() { super((Pointer)null); allocate(); } - private native void allocate(); +// #include +@Namespace("sd::memory") @NoOffset public static class ExternalWorkspace extends Pointer { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public ExternalWorkspace(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public ExternalWorkspace(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public ExternalWorkspace position(long position) { + return (ExternalWorkspace)super.position(position); + } - public ExternalWorkspace(@Cast("Nd4jPointer") Pointer ptrH, @Cast("Nd4jLong") long sizeH, @Cast("Nd4jPointer") Pointer ptrD, @Cast("Nd4jLong") long sizeD) { super((Pointer)null); allocate(ptrH, sizeH, ptrD, sizeD); } - private native void allocate(@Cast("Nd4jPointer") Pointer ptrH, @Cast("Nd4jLong") long sizeH, @Cast("Nd4jPointer") Pointer ptrD, @Cast("Nd4jLong") long sizeD); - - public native Pointer pointerHost(); - public native Pointer pointerDevice(); + public ExternalWorkspace() { super((Pointer)null); allocate(); } + private native void allocate(); - public native @Cast("Nd4jLong") long sizeHost(); - public native @Cast("Nd4jLong") long sizeDevice(); - } - + public ExternalWorkspace(@Cast("Nd4jPointer") Pointer ptrH, @Cast("Nd4jLong") long sizeH, @Cast("Nd4jPointer") Pointer ptrD, + @Cast("Nd4jLong") long sizeD) { super((Pointer)null); allocate(ptrH, sizeH, ptrD, sizeD); } + private native void allocate(@Cast("Nd4jPointer") Pointer ptrH, @Cast("Nd4jLong") long sizeH, @Cast("Nd4jPointer") Pointer ptrD, + @Cast("Nd4jLong") long sizeD); + public native Pointer pointerHost(); + public native Pointer pointerDevice(); + + public native @Cast("Nd4jLong") long sizeHost(); + public native @Cast("Nd4jLong") long sizeDevice(); +} + // namespace memory + // namespace sd // #endif @@ -3417,67 +3694,69 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #ifndef LIBND4J_WORKSPACE_H // #define LIBND4J_WORKSPACE_H -// #include -// #include -// #include +// #include +// #include // #include // #include // #include -// #include -// #include - @Namespace("sd::memory") @NoOffset public static class Workspace extends Pointer { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public Workspace(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public Workspace(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public Workspace position(long position) { - return (Workspace)super.position(position); - } - - public Workspace(ExternalWorkspace external) { super((Pointer)null); allocate(external); } - private native void allocate(ExternalWorkspace external); - public Workspace(@Cast("Nd4jLong") long initialSize/*=0L*/, @Cast("Nd4jLong") long secondaryBytes/*=0L*/) { super((Pointer)null); allocate(initialSize, secondaryBytes); } - private native void allocate(@Cast("Nd4jLong") long initialSize/*=0L*/, @Cast("Nd4jLong") long secondaryBytes/*=0L*/); - public Workspace() { super((Pointer)null); allocate(); } - private native void allocate(); - - public native @Cast("Nd4jLong") long getAllocatedSize(); - public native @Cast("Nd4jLong") long getCurrentSize(); - public native @Cast("Nd4jLong") long getCurrentOffset(); - public native @Cast("Nd4jLong") long getSpilledSize(); - public native @Cast("Nd4jLong") long getUsedSize(); - - public native @Cast("Nd4jLong") long getAllocatedSecondarySize(); - public native @Cast("Nd4jLong") long getCurrentSecondarySize(); - public native @Cast("Nd4jLong") long getCurrentSecondaryOffset(); - public native @Cast("Nd4jLong") long getSpilledSecondarySize(); - public native @Cast("Nd4jLong") long getUsedSecondarySize(); - - public native void expandBy(@Cast("Nd4jLong") long primaryBytes, @Cast("Nd4jLong") long secondaryBytes/*=0L*/); - public native void expandBy(@Cast("Nd4jLong") long primaryBytes); - public native void expandTo(@Cast("Nd4jLong") long primaryBytes, @Cast("Nd4jLong") long secondaryBytes/*=0L*/); - public native void expandTo(@Cast("Nd4jLong") long primaryBytes); - -// bool resizeSupported(); - - public native Pointer allocateBytes(@Cast("Nd4jLong") long numBytes); - public native Pointer allocateBytes(@Cast("sd::memory::MemoryType") int type, @Cast("Nd4jLong") long numBytes); - - public native void scopeIn(); - public native void scopeOut(); - - /* - * This method creates NEW workspace of the same memory size and returns pointer to it - */ - public native Workspace clone(); - } - +// #include +// #include +// #include +@Namespace("sd::memory") @NoOffset public static class Workspace extends Pointer { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public Workspace(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public Workspace(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public Workspace position(long position) { + return (Workspace)super.position(position); + } + + public Workspace(ExternalWorkspace external) { super((Pointer)null); allocate(external); } + private native void allocate(ExternalWorkspace external); + public Workspace(@Cast("Nd4jLong") long initialSize/*=0L*/, @Cast("Nd4jLong") long secondaryBytes/*=0L*/) { super((Pointer)null); allocate(initialSize, secondaryBytes); } + private native void allocate(@Cast("Nd4jLong") long initialSize/*=0L*/, @Cast("Nd4jLong") long secondaryBytes/*=0L*/); + public Workspace() { super((Pointer)null); allocate(); } + private native void allocate(); + + public native @Cast("Nd4jLong") long getAllocatedSize(); + public native @Cast("Nd4jLong") long getCurrentSize(); + public native @Cast("Nd4jLong") long getCurrentOffset(); + public native @Cast("Nd4jLong") long getSpilledSize(); + public native @Cast("Nd4jLong") long getUsedSize(); + + public native @Cast("Nd4jLong") long getAllocatedSecondarySize(); + public native @Cast("Nd4jLong") long getCurrentSecondarySize(); + public native @Cast("Nd4jLong") long getCurrentSecondaryOffset(); + public native @Cast("Nd4jLong") long getSpilledSecondarySize(); + public native @Cast("Nd4jLong") long getUsedSecondarySize(); + + public native void expandBy(@Cast("Nd4jLong") long primaryBytes, @Cast("Nd4jLong") long secondaryBytes/*=0L*/); + public native void expandBy(@Cast("Nd4jLong") long primaryBytes); + public native void expandTo(@Cast("Nd4jLong") long primaryBytes, @Cast("Nd4jLong") long secondaryBytes/*=0L*/); + public native void expandTo(@Cast("Nd4jLong") long primaryBytes); + + // bool resizeSupported(); + + public native Pointer allocateBytes(@Cast("Nd4jLong") long numBytes); + public native Pointer allocateBytes(@Cast("sd::memory::MemoryType") int type, @Cast("Nd4jLong") long numBytes); + + public native void scopeIn(); + public native void scopeOut(); + + /* + * This method creates NEW workspace of the same memory size and returns + * pointer to it + */ + public native Workspace clone(); +} + // namespace memory + // namespace sd -// #endif //LIBND4J_WORKSPACE_H +// #endif // LIBND4J_WORKSPACE_H // Parsed from indexing/NDIndex.h @@ -3505,79 +3784,77 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #ifndef LIBND4J_NDINDEX_H // #define LIBND4J_NDINDEX_H +// #include // #include + // #include -// #include - @Namespace("sd") @NoOffset public static class NDIndex extends Pointer { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public NDIndex(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public NDIndex(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public NDIndex position(long position) { - return (NDIndex)super.position(position); - } - - public NDIndex() { super((Pointer)null); allocate(); } - private native void allocate(); - - public native @Cast("bool") boolean isAll(); - public native @Cast("bool") boolean isPoint(); - public native @Cast("bool") boolean isInterval(); - - public native @Cast("Nd4jLong*") @StdVector LongPointer getIndices(); - public native @Cast("Nd4jLong") long stride(); - - public static native NDIndex all(); - public static native NDIndex point(@Cast("Nd4jLong") long pt); - public static native NDIndex interval(@Cast("Nd4jLong") long start, @Cast("Nd4jLong") long end, @Cast("Nd4jLong") long stride/*=1*/); - public static native NDIndex interval(@Cast("Nd4jLong") long start, @Cast("Nd4jLong") long end); - } - - @Namespace("sd") public static class NDIndexAll extends NDIndex { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public NDIndexAll(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public NDIndexAll(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public NDIndexAll position(long position) { - return (NDIndexAll)super.position(position); - } - - public NDIndexAll() { super((Pointer)null); allocate(); } - private native void allocate(); - public native @Cast("bool") boolean isInterval(); +@Namespace("sd") @NoOffset public static class NDIndex extends Pointer { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public NDIndex(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public NDIndex(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public NDIndex position(long position) { + return (NDIndex)super.position(position); } + public NDIndex() { super((Pointer)null); allocate(); } + private native void allocate(); + + public native @Cast("bool") boolean isAll(); + public native @Cast("bool") boolean isPoint(); + public native @Cast("bool") boolean isInterval(); - @Namespace("sd") public static class NDIndexPoint extends NDIndex { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public NDIndexPoint(Pointer p) { super(p); } - - public NDIndexPoint(@Cast("Nd4jLong") long point) { super((Pointer)null); allocate(point); } - private native void allocate(@Cast("Nd4jLong") long point); - public native @Cast("bool") boolean isInterval(); - } + public native @Cast("Nd4jLong*") @StdVector LongPointer getIndices(); + public native @Cast("Nd4jLong") long stride(); - @Namespace("sd") public static class NDIndexInterval extends NDIndex { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public NDIndexInterval(Pointer p) { super(p); } - - public NDIndexInterval(@Cast("Nd4jLong") long start, @Cast("Nd4jLong") long end, @Cast("Nd4jLong") long stride/*=1*/) { super((Pointer)null); allocate(start, end, stride); } - private native void allocate(@Cast("Nd4jLong") long start, @Cast("Nd4jLong") long end, @Cast("Nd4jLong") long stride/*=1*/); - public NDIndexInterval(@Cast("Nd4jLong") long start, @Cast("Nd4jLong") long end) { super((Pointer)null); allocate(start, end); } - private native void allocate(@Cast("Nd4jLong") long start, @Cast("Nd4jLong") long end); - public native @Cast("bool") boolean isInterval(); + public static native NDIndex all(); + public static native NDIndex point(@Cast("Nd4jLong") long pt); + public static native NDIndex interval(@Cast("Nd4jLong") long start, @Cast("Nd4jLong") long end, @Cast("Nd4jLong") long stride/*=1*/); + public static native NDIndex interval(@Cast("Nd4jLong") long start, @Cast("Nd4jLong") long end); +} + +@Namespace("sd") public static class NDIndexAll extends NDIndex { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public NDIndexAll(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public NDIndexAll(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public NDIndexAll position(long position) { + return (NDIndexAll)super.position(position); } + public NDIndexAll() { super((Pointer)null); allocate(); } + private native void allocate(); + public native @Cast("bool") boolean isInterval(); +} + +@Namespace("sd") public static class NDIndexPoint extends NDIndex { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public NDIndexPoint(Pointer p) { super(p); } + + public NDIndexPoint(@Cast("Nd4jLong") long point) { super((Pointer)null); allocate(point); } + private native void allocate(@Cast("Nd4jLong") long point); + public native @Cast("bool") boolean isInterval(); +} +@Namespace("sd") public static class NDIndexInterval extends NDIndex { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public NDIndexInterval(Pointer p) { super(p); } + public NDIndexInterval(@Cast("Nd4jLong") long start, @Cast("Nd4jLong") long end, @Cast("Nd4jLong") long stride/*=1*/) { super((Pointer)null); allocate(start, end, stride); } + private native void allocate(@Cast("Nd4jLong") long start, @Cast("Nd4jLong") long end, @Cast("Nd4jLong") long stride/*=1*/); + public NDIndexInterval(@Cast("Nd4jLong") long start, @Cast("Nd4jLong") long end) { super((Pointer)null); allocate(start, end); } + private native void allocate(@Cast("Nd4jLong") long start, @Cast("Nd4jLong") long end); + public native @Cast("bool") boolean isInterval(); +} + // namespace sd -// #endif //LIBND4J_NDINDEX_H +// #endif // LIBND4J_NDINDEX_H // Parsed from indexing/IndicesList.h @@ -3606,20 +3883,21 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #define LIBND4J_INDICESLIST_H // #include + // #include "NDIndex.h" - @Namespace("sd") @NoOffset public static class IndicesList extends Pointer { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public IndicesList(Pointer p) { super(p); } - +@Namespace("sd") @NoOffset public static class IndicesList extends Pointer { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public IndicesList(Pointer p) { super(p); } - public native int size(); - public native NDIndex at(int idx); - public native void push_back(NDIndex idx); - public native @Cast("bool") boolean isScalar(); - } -// #endif //LIBND4J_INDICESLIST_H + public native int size(); + public native NDIndex at(int idx); + public native void push_back(NDIndex idx); + public native @Cast("bool") boolean isScalar(); +} + // namespace sd +// #endif // LIBND4J_INDICESLIST_H // Parsed from graph/VariableType.h @@ -3646,15 +3924,15 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #ifndef ND4J_VARIABLE_TYPE_H // #define ND4J_VARIABLE_TYPE_H - /** enum sd::graph::VariableType */ - public static final int - NDARRAY = 0, - ARRAY_LIST = 1, - FLOW = 2, - CONSTANT = 3, - PLACEHOLDER = 4; - +/** enum sd::graph::VariableType */ +public static final int + NDARRAY = 0, + ARRAY_LIST = 1, + FLOW = 2, + CONSTANT = 3, + PLACEHOLDER = 4; + // namespace sd // #endif @@ -3683,44 +3961,45 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #ifndef LIBND4J_INPUTLIST_H // #define LIBND4J_INPUTLIST_H +// #include // #include // #include -// #include -// #include // #include - @Namespace("sd::graph") @NoOffset public static class ArgumentsList extends Pointer { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public ArgumentsList(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public ArgumentsList(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public ArgumentsList position(long position) { - return (ArgumentsList)super.position(position); - } - - public ArgumentsList() { super((Pointer)null); allocate(); } - private native void allocate(); - - /** - * This method returns number of argument pairs available - * - * @return - */ - public native int size(); - /** - * This method returns Pair at specified index - * - * @param index - * @return - */ - public native @ByRef Pair at(int index); +// #include +@Namespace("sd::graph") @NoOffset public static class ArgumentsList extends Pointer { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public ArgumentsList(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public ArgumentsList(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public ArgumentsList position(long position) { + return (ArgumentsList)super.position(position); } + public ArgumentsList() { super((Pointer)null); allocate(); } + private native void allocate(); + + /** + * This method returns number of argument pairs available + * + * @return + */ + public native int size(); + /** + * This method returns Pair at specified index + * + * @param index + * @return + */ + public native @ByRef Pair at(int index); +} + // namespace graph + // namespace sd -// #endif //LIBND4J_INPUTLIST_H +// #endif // LIBND4J_INPUTLIST_H // Parsed from types/pair.h @@ -3749,29 +4028,28 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #define LIBND4J_PAIR_H // #include - @Namespace("sd") @NoOffset public static class Pair extends Pointer { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public Pair(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public Pair(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public Pair position(long position) { - return (Pair)super.position(position); - } - - public Pair(int first/*=0*/, int second/*=0*/) { super((Pointer)null); allocate(first, second); } - private native void allocate(int first/*=0*/, int second/*=0*/); - public Pair() { super((Pointer)null); allocate(); } - private native void allocate(); - - public native int first(); - public native int second(); +@Namespace("sd") @NoOffset public static class Pair extends Pointer { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public Pair(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public Pair(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public Pair position(long position) { + return (Pair)super.position(position); } + public Pair(int first/*=0*/, int second/*=0*/) { super((Pointer)null); allocate(first, second); } + private native void allocate(int first/*=0*/, int second/*=0*/); + public Pair() { super((Pointer)null); allocate(); } + private native void allocate(); + public native int first(); + public native int second(); +} + // namespace sd -// #endif //LIBND4J_PAIR_H +// #endif // LIBND4J_PAIR_H // Parsed from array/NDArray.h @@ -3795,1110 +4073,1338 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #ifndef NDARRAY_H // #define NDARRAY_H -// #include -// #include -// #include -// #include -// #include "legacy/NativeOpExecutioner.h" -// #include -// #include -// #include -// #include -// #include -// #include // #include // #include +// #include +// #include +// #include +// #include +// #include // #include +// #include +// #include +// #include +// #include +// #include // #include -// #include -// #include +// #include +// #include +// #include +// #include // #include // #include -// #include -// #include -// #include -// #include -// #include -// #include +// #include +// #include +// #include +// #include + +// #include +// #include // #include // #include // #include // #include +// #include "legacy/NativeOpExecutioner.h" +@Namespace("sd") public static native @ByVal NDArray mmul(@Const @ByRef NDArray arg0, @Const @ByRef NDArray arg1); +@Namespace("sd") @NoOffset public static class NDArray extends Pointer { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public NDArray(Pointer p) { super(p); } - @Namespace("sd") public static native @ByVal NDArray mmul(@Const @ByRef NDArray arg0, @Const @ByRef NDArray arg1); - - @Namespace("sd") @NoOffset public static class NDArray extends Pointer { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public NDArray(Pointer p) { super(p); } - - public NDArray() { super((Pointer)null); allocate(); } - private native void allocate(); + public NDArray() { super((Pointer)null); allocate(); } + private native void allocate(); - /** - * do not allocate memory, memory for array is passed from outside - */ + /** + * do not allocate memory, memory for array is passed from outside + */ // #ifndef __JAVACPP_HACK__ // #endif - /** - * do not allocate memory, memory for array is passed from outside - */ - public NDArray(Pointer buffer, @Cast("Nd4jLong*") LongPointer shapeInfo, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("bool") boolean isBuffAlloc/*=false*/) { super((Pointer)null); allocate(buffer, shapeInfo, context, isBuffAlloc); } - private native void allocate(Pointer buffer, @Cast("Nd4jLong*") LongPointer shapeInfo, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("bool") boolean isBuffAlloc/*=false*/); - public NDArray(Pointer buffer, @Cast("Nd4jLong*") LongPointer shapeInfo) { super((Pointer)null); allocate(buffer, shapeInfo); } - private native void allocate(Pointer buffer, @Cast("Nd4jLong*") LongPointer shapeInfo); - public NDArray(Pointer buffer, @Cast("Nd4jLong*") LongBuffer shapeInfo, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("bool") boolean isBuffAlloc/*=false*/) { super((Pointer)null); allocate(buffer, shapeInfo, context, isBuffAlloc); } - private native void allocate(Pointer buffer, @Cast("Nd4jLong*") LongBuffer shapeInfo, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("bool") boolean isBuffAlloc/*=false*/); - public NDArray(Pointer buffer, @Cast("Nd4jLong*") LongBuffer shapeInfo) { super((Pointer)null); allocate(buffer, shapeInfo); } - private native void allocate(Pointer buffer, @Cast("Nd4jLong*") LongBuffer shapeInfo); - public NDArray(Pointer buffer, @Cast("Nd4jLong*") long[] shapeInfo, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("bool") boolean isBuffAlloc/*=false*/) { super((Pointer)null); allocate(buffer, shapeInfo, context, isBuffAlloc); } - private native void allocate(Pointer buffer, @Cast("Nd4jLong*") long[] shapeInfo, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("bool") boolean isBuffAlloc/*=false*/); - public NDArray(Pointer buffer, @Cast("Nd4jLong*") long[] shapeInfo) { super((Pointer)null); allocate(buffer, shapeInfo); } - private native void allocate(Pointer buffer, @Cast("Nd4jLong*") long[] shapeInfo); + /** + * do not allocate memory, memory for array is passed from outside + */ + public NDArray(Pointer buffer, @Cast("Nd4jLong*") LongPointer shapeInfo, + LaunchContext context/*=sd::LaunchContext::defaultContext()*/, + @Cast("bool") boolean isBuffAlloc/*=false*/) { super((Pointer)null); allocate(buffer, shapeInfo, context, isBuffAlloc); } + private native void allocate(Pointer buffer, @Cast("Nd4jLong*") LongPointer shapeInfo, + LaunchContext context/*=sd::LaunchContext::defaultContext()*/, + @Cast("bool") boolean isBuffAlloc/*=false*/); + public NDArray(Pointer buffer, @Cast("Nd4jLong*") LongPointer shapeInfo) { super((Pointer)null); allocate(buffer, shapeInfo); } + private native void allocate(Pointer buffer, @Cast("Nd4jLong*") LongPointer shapeInfo); + public NDArray(Pointer buffer, @Cast("Nd4jLong*") LongBuffer shapeInfo, + LaunchContext context/*=sd::LaunchContext::defaultContext()*/, + @Cast("bool") boolean isBuffAlloc/*=false*/) { super((Pointer)null); allocate(buffer, shapeInfo, context, isBuffAlloc); } + private native void allocate(Pointer buffer, @Cast("Nd4jLong*") LongBuffer shapeInfo, + LaunchContext context/*=sd::LaunchContext::defaultContext()*/, + @Cast("bool") boolean isBuffAlloc/*=false*/); + public NDArray(Pointer buffer, @Cast("Nd4jLong*") LongBuffer shapeInfo) { super((Pointer)null); allocate(buffer, shapeInfo); } + private native void allocate(Pointer buffer, @Cast("Nd4jLong*") LongBuffer shapeInfo); + public NDArray(Pointer buffer, @Cast("Nd4jLong*") long[] shapeInfo, + LaunchContext context/*=sd::LaunchContext::defaultContext()*/, + @Cast("bool") boolean isBuffAlloc/*=false*/) { super((Pointer)null); allocate(buffer, shapeInfo, context, isBuffAlloc); } + private native void allocate(Pointer buffer, @Cast("Nd4jLong*") long[] shapeInfo, + LaunchContext context/*=sd::LaunchContext::defaultContext()*/, + @Cast("bool") boolean isBuffAlloc/*=false*/); + public NDArray(Pointer buffer, @Cast("Nd4jLong*") long[] shapeInfo) { super((Pointer)null); allocate(buffer, shapeInfo); } + private native void allocate(Pointer buffer, @Cast("Nd4jLong*") long[] shapeInfo); - /** - * do not allocate memory, memory for array is passed from outside - * we suppose the content of both (device and host) buffers is identical - */ - public NDArray(Pointer buffer, Pointer bufferD, @Cast("const Nd4jLong*") LongPointer shapeInfo, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("bool") boolean isBuffAlloc/*=false*/, @Cast("bool") boolean isBuffDAlloc/*=false*/) { super((Pointer)null); allocate(buffer, bufferD, shapeInfo, context, isBuffAlloc, isBuffDAlloc); } - private native void allocate(Pointer buffer, Pointer bufferD, @Cast("const Nd4jLong*") LongPointer shapeInfo, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("bool") boolean isBuffAlloc/*=false*/, @Cast("bool") boolean isBuffDAlloc/*=false*/); - public NDArray(Pointer buffer, Pointer bufferD, @Cast("const Nd4jLong*") LongPointer shapeInfo) { super((Pointer)null); allocate(buffer, bufferD, shapeInfo); } - private native void allocate(Pointer buffer, Pointer bufferD, @Cast("const Nd4jLong*") LongPointer shapeInfo); - public NDArray(Pointer buffer, Pointer bufferD, @Cast("const Nd4jLong*") LongBuffer shapeInfo, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("bool") boolean isBuffAlloc/*=false*/, @Cast("bool") boolean isBuffDAlloc/*=false*/) { super((Pointer)null); allocate(buffer, bufferD, shapeInfo, context, isBuffAlloc, isBuffDAlloc); } - private native void allocate(Pointer buffer, Pointer bufferD, @Cast("const Nd4jLong*") LongBuffer shapeInfo, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("bool") boolean isBuffAlloc/*=false*/, @Cast("bool") boolean isBuffDAlloc/*=false*/); - public NDArray(Pointer buffer, Pointer bufferD, @Cast("const Nd4jLong*") LongBuffer shapeInfo) { super((Pointer)null); allocate(buffer, bufferD, shapeInfo); } - private native void allocate(Pointer buffer, Pointer bufferD, @Cast("const Nd4jLong*") LongBuffer shapeInfo); - public NDArray(Pointer buffer, Pointer bufferD, @Cast("const Nd4jLong*") long[] shapeInfo, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("bool") boolean isBuffAlloc/*=false*/, @Cast("bool") boolean isBuffDAlloc/*=false*/) { super((Pointer)null); allocate(buffer, bufferD, shapeInfo, context, isBuffAlloc, isBuffDAlloc); } - private native void allocate(Pointer buffer, Pointer bufferD, @Cast("const Nd4jLong*") long[] shapeInfo, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("bool") boolean isBuffAlloc/*=false*/, @Cast("bool") boolean isBuffDAlloc/*=false*/); - public NDArray(Pointer buffer, Pointer bufferD, @Cast("const Nd4jLong*") long[] shapeInfo) { super((Pointer)null); allocate(buffer, bufferD, shapeInfo); } - private native void allocate(Pointer buffer, Pointer bufferD, @Cast("const Nd4jLong*") long[] shapeInfo); + /** + * do not allocate memory, memory for array is passed from outside + * we suppose the content of both (device and host) buffers is identical + */ + public NDArray(Pointer buffer, Pointer bufferD, @Cast("const Nd4jLong*") LongPointer shapeInfo, + LaunchContext context/*=sd::LaunchContext::defaultContext()*/, + @Cast("bool") boolean isBuffAlloc/*=false*/, @Cast("bool") boolean isBuffDAlloc/*=false*/) { super((Pointer)null); allocate(buffer, bufferD, shapeInfo, context, isBuffAlloc, isBuffDAlloc); } + private native void allocate(Pointer buffer, Pointer bufferD, @Cast("const Nd4jLong*") LongPointer shapeInfo, + LaunchContext context/*=sd::LaunchContext::defaultContext()*/, + @Cast("bool") boolean isBuffAlloc/*=false*/, @Cast("bool") boolean isBuffDAlloc/*=false*/); + public NDArray(Pointer buffer, Pointer bufferD, @Cast("const Nd4jLong*") LongPointer shapeInfo) { super((Pointer)null); allocate(buffer, bufferD, shapeInfo); } + private native void allocate(Pointer buffer, Pointer bufferD, @Cast("const Nd4jLong*") LongPointer shapeInfo); + public NDArray(Pointer buffer, Pointer bufferD, @Cast("const Nd4jLong*") LongBuffer shapeInfo, + LaunchContext context/*=sd::LaunchContext::defaultContext()*/, + @Cast("bool") boolean isBuffAlloc/*=false*/, @Cast("bool") boolean isBuffDAlloc/*=false*/) { super((Pointer)null); allocate(buffer, bufferD, shapeInfo, context, isBuffAlloc, isBuffDAlloc); } + private native void allocate(Pointer buffer, Pointer bufferD, @Cast("const Nd4jLong*") LongBuffer shapeInfo, + LaunchContext context/*=sd::LaunchContext::defaultContext()*/, + @Cast("bool") boolean isBuffAlloc/*=false*/, @Cast("bool") boolean isBuffDAlloc/*=false*/); + public NDArray(Pointer buffer, Pointer bufferD, @Cast("const Nd4jLong*") LongBuffer shapeInfo) { super((Pointer)null); allocate(buffer, bufferD, shapeInfo); } + private native void allocate(Pointer buffer, Pointer bufferD, @Cast("const Nd4jLong*") LongBuffer shapeInfo); + public NDArray(Pointer buffer, Pointer bufferD, @Cast("const Nd4jLong*") long[] shapeInfo, + LaunchContext context/*=sd::LaunchContext::defaultContext()*/, + @Cast("bool") boolean isBuffAlloc/*=false*/, @Cast("bool") boolean isBuffDAlloc/*=false*/) { super((Pointer)null); allocate(buffer, bufferD, shapeInfo, context, isBuffAlloc, isBuffDAlloc); } + private native void allocate(Pointer buffer, Pointer bufferD, @Cast("const Nd4jLong*") long[] shapeInfo, + LaunchContext context/*=sd::LaunchContext::defaultContext()*/, + @Cast("bool") boolean isBuffAlloc/*=false*/, @Cast("bool") boolean isBuffDAlloc/*=false*/); + public NDArray(Pointer buffer, Pointer bufferD, @Cast("const Nd4jLong*") long[] shapeInfo) { super((Pointer)null); allocate(buffer, bufferD, shapeInfo); } + private native void allocate(Pointer buffer, Pointer bufferD, @Cast("const Nd4jLong*") long[] shapeInfo); - /** - * copy constructor - */ - public NDArray(@Const @ByRef NDArray other) { super((Pointer)null); allocate(other); } - private native void allocate(@Const @ByRef NDArray other); + /** + * copy constructor + */ + public NDArray(@Const @ByRef NDArray other) { super((Pointer)null); allocate(other); } + private native void allocate(@Const @ByRef NDArray other); - /** - * move constructor - */ + /** + * move constructor + */ - /** - * constructor, create array stored at given workspace - */ - public NDArray(LaunchContext context) { super((Pointer)null); allocate(context); } - private native void allocate(LaunchContext context); + /** + * constructor, create array stored at given workspace + */ + public NDArray(LaunchContext context) { super((Pointer)null); allocate(context); } + private native void allocate(LaunchContext context); + /** + * constructor creates new NDArray using shape information from "shapeInfo", + * set all elements in new array to zeros, if copyStrides is true then use + * stride values from "shapeInfo", else calculate strides independently + */ + public NDArray(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("bool") boolean copyStrides/*=false*/, + LaunchContext context/*=sd::LaunchContext::defaultContext()*/, + @Cast("bool") boolean nullify/*=true*/) { super((Pointer)null); allocate(shapeInfo, copyStrides, context, nullify); } + private native void allocate(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("bool") boolean copyStrides/*=false*/, + LaunchContext context/*=sd::LaunchContext::defaultContext()*/, + @Cast("bool") boolean nullify/*=true*/); + public NDArray(@Cast("const Nd4jLong*") LongPointer shapeInfo) { super((Pointer)null); allocate(shapeInfo); } + private native void allocate(@Cast("const Nd4jLong*") LongPointer shapeInfo); + public NDArray(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("bool") boolean copyStrides/*=false*/, + LaunchContext context/*=sd::LaunchContext::defaultContext()*/, + @Cast("bool") boolean nullify/*=true*/) { super((Pointer)null); allocate(shapeInfo, copyStrides, context, nullify); } + private native void allocate(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("bool") boolean copyStrides/*=false*/, + LaunchContext context/*=sd::LaunchContext::defaultContext()*/, + @Cast("bool") boolean nullify/*=true*/); + public NDArray(@Cast("const Nd4jLong*") LongBuffer shapeInfo) { super((Pointer)null); allocate(shapeInfo); } + private native void allocate(@Cast("const Nd4jLong*") LongBuffer shapeInfo); + public NDArray(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("bool") boolean copyStrides/*=false*/, + LaunchContext context/*=sd::LaunchContext::defaultContext()*/, + @Cast("bool") boolean nullify/*=true*/) { super((Pointer)null); allocate(shapeInfo, copyStrides, context, nullify); } + private native void allocate(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("bool") boolean copyStrides/*=false*/, + LaunchContext context/*=sd::LaunchContext::defaultContext()*/, + @Cast("bool") boolean nullify/*=true*/); + public NDArray(@Cast("const Nd4jLong*") long[] shapeInfo) { super((Pointer)null); allocate(shapeInfo); } + private native void allocate(@Cast("const Nd4jLong*") long[] shapeInfo); - /** - * constructor creates new NDArray using shape information from "shapeInfo", set all elements in new array to zeros, if copyStrides is true then use stride values from "shapeInfo", else calculate strides independently - */ - public NDArray(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("bool") boolean nullify/*=true*/) { super((Pointer)null); allocate(shapeInfo, copyStrides, context, nullify); } - private native void allocate(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("bool") boolean nullify/*=true*/); - public NDArray(@Cast("const Nd4jLong*") LongPointer shapeInfo) { super((Pointer)null); allocate(shapeInfo); } - private native void allocate(@Cast("const Nd4jLong*") LongPointer shapeInfo); - public NDArray(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("bool") boolean nullify/*=true*/) { super((Pointer)null); allocate(shapeInfo, copyStrides, context, nullify); } - private native void allocate(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("bool") boolean nullify/*=true*/); - public NDArray(@Cast("const Nd4jLong*") LongBuffer shapeInfo) { super((Pointer)null); allocate(shapeInfo); } - private native void allocate(@Cast("const Nd4jLong*") LongBuffer shapeInfo); - public NDArray(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("bool") boolean nullify/*=true*/) { super((Pointer)null); allocate(shapeInfo, copyStrides, context, nullify); } - private native void allocate(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("bool") boolean nullify/*=true*/); - public NDArray(@Cast("const Nd4jLong*") long[] shapeInfo) { super((Pointer)null); allocate(shapeInfo); } - private native void allocate(@Cast("const Nd4jLong*") long[] shapeInfo); - - /** - * constructor creates new NDArray using shape information from "shapeInfo", set all elements in new array to be zeros, if copyStrides is true then use stride values from "shapeInfo", else calculate strides independently - * set dtype as array type - */ - public NDArray(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("sd::DataType") int dtype, @Cast("bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("bool") boolean nullify/*=true*/) { super((Pointer)null); allocate(shapeInfo, dtype, copyStrides, context, nullify); } - private native void allocate(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("sd::DataType") int dtype, @Cast("bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("bool") boolean nullify/*=true*/); - public NDArray(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("sd::DataType") int dtype) { super((Pointer)null); allocate(shapeInfo, dtype); } - private native void allocate(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("sd::DataType") int dtype); - public NDArray(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("sd::DataType") int dtype, @Cast("bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("bool") boolean nullify/*=true*/) { super((Pointer)null); allocate(shapeInfo, dtype, copyStrides, context, nullify); } - private native void allocate(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("sd::DataType") int dtype, @Cast("bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("bool") boolean nullify/*=true*/); - public NDArray(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("sd::DataType") int dtype) { super((Pointer)null); allocate(shapeInfo, dtype); } - private native void allocate(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("sd::DataType") int dtype); - public NDArray(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("sd::DataType") int dtype, @Cast("bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("bool") boolean nullify/*=true*/) { super((Pointer)null); allocate(shapeInfo, dtype, copyStrides, context, nullify); } - private native void allocate(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("sd::DataType") int dtype, @Cast("bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("bool") boolean nullify/*=true*/); - public NDArray(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("sd::DataType") int dtype) { super((Pointer)null); allocate(shapeInfo, dtype); } - private native void allocate(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("sd::DataType") int dtype); - - /** - * this constructor creates new array using shape information contained in vector argument - */ - public NDArray(char order, @Cast("Nd4jLong*") @StdVector LongPointer shape, @Cast("sd::DataType") int dtype/*=sd::DOUBLE*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/) { super((Pointer)null); allocate(order, shape, dtype, context); } - private native void allocate(char order, @Cast("Nd4jLong*") @StdVector LongPointer shape, @Cast("sd::DataType") int dtype/*=sd::DOUBLE*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/); - public NDArray(char order, @Cast("Nd4jLong*") @StdVector LongPointer shape) { super((Pointer)null); allocate(order, shape); } - private native void allocate(char order, @Cast("Nd4jLong*") @StdVector LongPointer shape); - public NDArray(char order, @Cast("Nd4jLong*") @StdVector LongBuffer shape, @Cast("sd::DataType") int dtype/*=sd::DOUBLE*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/) { super((Pointer)null); allocate(order, shape, dtype, context); } - private native void allocate(char order, @Cast("Nd4jLong*") @StdVector LongBuffer shape, @Cast("sd::DataType") int dtype/*=sd::DOUBLE*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/); - public NDArray(char order, @Cast("Nd4jLong*") @StdVector LongBuffer shape) { super((Pointer)null); allocate(order, shape); } - private native void allocate(char order, @Cast("Nd4jLong*") @StdVector LongBuffer shape); - public NDArray(char order, @Cast("Nd4jLong*") @StdVector long[] shape, @Cast("sd::DataType") int dtype/*=sd::DOUBLE*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/) { super((Pointer)null); allocate(order, shape, dtype, context); } - private native void allocate(char order, @Cast("Nd4jLong*") @StdVector long[] shape, @Cast("sd::DataType") int dtype/*=sd::DOUBLE*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/); - public NDArray(char order, @Cast("Nd4jLong*") @StdVector long[] shape) { super((Pointer)null); allocate(order, shape); } - private native void allocate(char order, @Cast("Nd4jLong*") @StdVector long[] shape); - - /** - * This constructor creates new array with elements copied from data and using shape information stored in shape, elements from data will be casted to dtype - */ - public NDArray(char order, @Cast("Nd4jLong*") @StdVector LongPointer shape, @StdVector DoublePointer data, @Cast("sd::DataType") int dtype/*=sd::DOUBLE*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/) { super((Pointer)null); allocate(order, shape, data, dtype, context); } - private native void allocate(char order, @Cast("Nd4jLong*") @StdVector LongPointer shape, @StdVector DoublePointer data, @Cast("sd::DataType") int dtype/*=sd::DOUBLE*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/); - public NDArray(char order, @Cast("Nd4jLong*") @StdVector LongPointer shape, @StdVector DoublePointer data) { super((Pointer)null); allocate(order, shape, data); } - private native void allocate(char order, @Cast("Nd4jLong*") @StdVector LongPointer shape, @StdVector DoublePointer data); - public NDArray(char order, @Cast("Nd4jLong*") @StdVector LongBuffer shape, @StdVector DoubleBuffer data, @Cast("sd::DataType") int dtype/*=sd::DOUBLE*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/) { super((Pointer)null); allocate(order, shape, data, dtype, context); } - private native void allocate(char order, @Cast("Nd4jLong*") @StdVector LongBuffer shape, @StdVector DoubleBuffer data, @Cast("sd::DataType") int dtype/*=sd::DOUBLE*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/); - public NDArray(char order, @Cast("Nd4jLong*") @StdVector LongBuffer shape, @StdVector DoubleBuffer data) { super((Pointer)null); allocate(order, shape, data); } - private native void allocate(char order, @Cast("Nd4jLong*") @StdVector LongBuffer shape, @StdVector DoubleBuffer data); - public NDArray(char order, @Cast("Nd4jLong*") @StdVector long[] shape, @StdVector double[] data, @Cast("sd::DataType") int dtype/*=sd::DOUBLE*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/) { super((Pointer)null); allocate(order, shape, data, dtype, context); } - private native void allocate(char order, @Cast("Nd4jLong*") @StdVector long[] shape, @StdVector double[] data, @Cast("sd::DataType") int dtype/*=sd::DOUBLE*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/); - public NDArray(char order, @Cast("Nd4jLong*") @StdVector long[] shape, @StdVector double[] data) { super((Pointer)null); allocate(order, shape, data); } - private native void allocate(char order, @Cast("Nd4jLong*") @StdVector long[] shape, @StdVector double[] data); - - /** - * this constructor creates new array using given buffer (without memory allocation) and shape information stored in shape - */ - public NDArray(Pointer buffer, char order, @Cast("Nd4jLong*") @StdVector LongPointer shape, @Cast("sd::DataType") int dtype, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("const bool") boolean isBuffAlloc/*=false*/) { super((Pointer)null); allocate(buffer, order, shape, dtype, context, isBuffAlloc); } - private native void allocate(Pointer buffer, char order, @Cast("Nd4jLong*") @StdVector LongPointer shape, @Cast("sd::DataType") int dtype, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("const bool") boolean isBuffAlloc/*=false*/); - public NDArray(Pointer buffer, char order, @Cast("Nd4jLong*") @StdVector LongPointer shape, @Cast("sd::DataType") int dtype) { super((Pointer)null); allocate(buffer, order, shape, dtype); } - private native void allocate(Pointer buffer, char order, @Cast("Nd4jLong*") @StdVector LongPointer shape, @Cast("sd::DataType") int dtype); - public NDArray(Pointer buffer, char order, @Cast("Nd4jLong*") @StdVector LongBuffer shape, @Cast("sd::DataType") int dtype, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("const bool") boolean isBuffAlloc/*=false*/) { super((Pointer)null); allocate(buffer, order, shape, dtype, context, isBuffAlloc); } - private native void allocate(Pointer buffer, char order, @Cast("Nd4jLong*") @StdVector LongBuffer shape, @Cast("sd::DataType") int dtype, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("const bool") boolean isBuffAlloc/*=false*/); - public NDArray(Pointer buffer, char order, @Cast("Nd4jLong*") @StdVector LongBuffer shape, @Cast("sd::DataType") int dtype) { super((Pointer)null); allocate(buffer, order, shape, dtype); } - private native void allocate(Pointer buffer, char order, @Cast("Nd4jLong*") @StdVector LongBuffer shape, @Cast("sd::DataType") int dtype); - public NDArray(Pointer buffer, char order, @Cast("Nd4jLong*") @StdVector long[] shape, @Cast("sd::DataType") int dtype, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("const bool") boolean isBuffAlloc/*=false*/) { super((Pointer)null); allocate(buffer, order, shape, dtype, context, isBuffAlloc); } - private native void allocate(Pointer buffer, char order, @Cast("Nd4jLong*") @StdVector long[] shape, @Cast("sd::DataType") int dtype, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("const bool") boolean isBuffAlloc/*=false*/); - public NDArray(Pointer buffer, char order, @Cast("Nd4jLong*") @StdVector long[] shape, @Cast("sd::DataType") int dtype) { super((Pointer)null); allocate(buffer, order, shape, dtype); } - private native void allocate(Pointer buffer, char order, @Cast("Nd4jLong*") @StdVector long[] shape, @Cast("sd::DataType") int dtype); - - /** - * This method returns new array with the same shape & data type - * @return - */ - public native @ByVal NDArray like(); + /** + * constructor creates new NDArray using shape information from "shapeInfo", + * set all elements in new array to be zeros, if copyStrides is true then use + * stride values from "shapeInfo", else calculate strides independently set + * dtype as array type + */ + public NDArray(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("sd::DataType") int dtype, + @Cast("bool") boolean copyStrides/*=false*/, + LaunchContext context/*=sd::LaunchContext::defaultContext()*/, + @Cast("bool") boolean nullify/*=true*/) { super((Pointer)null); allocate(shapeInfo, dtype, copyStrides, context, nullify); } + private native void allocate(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("sd::DataType") int dtype, + @Cast("bool") boolean copyStrides/*=false*/, + LaunchContext context/*=sd::LaunchContext::defaultContext()*/, + @Cast("bool") boolean nullify/*=true*/); + public NDArray(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("sd::DataType") int dtype) { super((Pointer)null); allocate(shapeInfo, dtype); } + private native void allocate(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("sd::DataType") int dtype); + public NDArray(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("sd::DataType") int dtype, + @Cast("bool") boolean copyStrides/*=false*/, + LaunchContext context/*=sd::LaunchContext::defaultContext()*/, + @Cast("bool") boolean nullify/*=true*/) { super((Pointer)null); allocate(shapeInfo, dtype, copyStrides, context, nullify); } + private native void allocate(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("sd::DataType") int dtype, + @Cast("bool") boolean copyStrides/*=false*/, + LaunchContext context/*=sd::LaunchContext::defaultContext()*/, + @Cast("bool") boolean nullify/*=true*/); + public NDArray(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("sd::DataType") int dtype) { super((Pointer)null); allocate(shapeInfo, dtype); } + private native void allocate(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("sd::DataType") int dtype); + public NDArray(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("sd::DataType") int dtype, + @Cast("bool") boolean copyStrides/*=false*/, + LaunchContext context/*=sd::LaunchContext::defaultContext()*/, + @Cast("bool") boolean nullify/*=true*/) { super((Pointer)null); allocate(shapeInfo, dtype, copyStrides, context, nullify); } + private native void allocate(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("sd::DataType") int dtype, + @Cast("bool") boolean copyStrides/*=false*/, + LaunchContext context/*=sd::LaunchContext::defaultContext()*/, + @Cast("bool") boolean nullify/*=true*/); + public NDArray(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("sd::DataType") int dtype) { super((Pointer)null); allocate(shapeInfo, dtype); } + private native void allocate(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("sd::DataType") int dtype); - /** - * This method returns new uninitialized array with the same shape & data type - * @return - */ - public native @ByVal NDArray ulike(); + /** + * this constructor creates new array using shape information contained in + * vector argument + */ + public NDArray(char order, @Cast("Nd4jLong*") @StdVector LongPointer shape, + @Cast("sd::DataType") int dtype/*=sd::DOUBLE*/, + LaunchContext context/*=sd::LaunchContext::defaultContext()*/) { super((Pointer)null); allocate(order, shape, dtype, context); } + private native void allocate(char order, @Cast("Nd4jLong*") @StdVector LongPointer shape, + @Cast("sd::DataType") int dtype/*=sd::DOUBLE*/, + LaunchContext context/*=sd::LaunchContext::defaultContext()*/); + public NDArray(char order, @Cast("Nd4jLong*") @StdVector LongPointer shape) { super((Pointer)null); allocate(order, shape); } + private native void allocate(char order, @Cast("Nd4jLong*") @StdVector LongPointer shape); + public NDArray(char order, @Cast("Nd4jLong*") @StdVector LongBuffer shape, + @Cast("sd::DataType") int dtype/*=sd::DOUBLE*/, + LaunchContext context/*=sd::LaunchContext::defaultContext()*/) { super((Pointer)null); allocate(order, shape, dtype, context); } + private native void allocate(char order, @Cast("Nd4jLong*") @StdVector LongBuffer shape, + @Cast("sd::DataType") int dtype/*=sd::DOUBLE*/, + LaunchContext context/*=sd::LaunchContext::defaultContext()*/); + public NDArray(char order, @Cast("Nd4jLong*") @StdVector LongBuffer shape) { super((Pointer)null); allocate(order, shape); } + private native void allocate(char order, @Cast("Nd4jLong*") @StdVector LongBuffer shape); + public NDArray(char order, @Cast("Nd4jLong*") @StdVector long[] shape, + @Cast("sd::DataType") int dtype/*=sd::DOUBLE*/, + LaunchContext context/*=sd::LaunchContext::defaultContext()*/) { super((Pointer)null); allocate(order, shape, dtype, context); } + private native void allocate(char order, @Cast("Nd4jLong*") @StdVector long[] shape, + @Cast("sd::DataType") int dtype/*=sd::DOUBLE*/, + LaunchContext context/*=sd::LaunchContext::defaultContext()*/); + public NDArray(char order, @Cast("Nd4jLong*") @StdVector long[] shape) { super((Pointer)null); allocate(order, shape); } + private native void allocate(char order, @Cast("Nd4jLong*") @StdVector long[] shape); + /** + * This constructor creates new array with elements copied from data and using + * shape information stored in shape, elements from data will be casted to + * dtype + */ + public NDArray(char order, @Cast("Nd4jLong*") @StdVector LongPointer shape, + @StdVector DoublePointer data, @Cast("sd::DataType") int dtype/*=sd::DOUBLE*/, + LaunchContext context/*=sd::LaunchContext::defaultContext()*/) { super((Pointer)null); allocate(order, shape, data, dtype, context); } + private native void allocate(char order, @Cast("Nd4jLong*") @StdVector LongPointer shape, + @StdVector DoublePointer data, @Cast("sd::DataType") int dtype/*=sd::DOUBLE*/, + LaunchContext context/*=sd::LaunchContext::defaultContext()*/); + public NDArray(char order, @Cast("Nd4jLong*") @StdVector LongPointer shape, + @StdVector DoublePointer data) { super((Pointer)null); allocate(order, shape, data); } + private native void allocate(char order, @Cast("Nd4jLong*") @StdVector LongPointer shape, + @StdVector DoublePointer data); + public NDArray(char order, @Cast("Nd4jLong*") @StdVector LongBuffer shape, + @StdVector DoubleBuffer data, @Cast("sd::DataType") int dtype/*=sd::DOUBLE*/, + LaunchContext context/*=sd::LaunchContext::defaultContext()*/) { super((Pointer)null); allocate(order, shape, data, dtype, context); } + private native void allocate(char order, @Cast("Nd4jLong*") @StdVector LongBuffer shape, + @StdVector DoubleBuffer data, @Cast("sd::DataType") int dtype/*=sd::DOUBLE*/, + LaunchContext context/*=sd::LaunchContext::defaultContext()*/); + public NDArray(char order, @Cast("Nd4jLong*") @StdVector LongBuffer shape, + @StdVector DoubleBuffer data) { super((Pointer)null); allocate(order, shape, data); } + private native void allocate(char order, @Cast("Nd4jLong*") @StdVector LongBuffer shape, + @StdVector DoubleBuffer data); + public NDArray(char order, @Cast("Nd4jLong*") @StdVector long[] shape, + @StdVector double[] data, @Cast("sd::DataType") int dtype/*=sd::DOUBLE*/, + LaunchContext context/*=sd::LaunchContext::defaultContext()*/) { super((Pointer)null); allocate(order, shape, data, dtype, context); } + private native void allocate(char order, @Cast("Nd4jLong*") @StdVector long[] shape, + @StdVector double[] data, @Cast("sd::DataType") int dtype/*=sd::DOUBLE*/, + LaunchContext context/*=sd::LaunchContext::defaultContext()*/); + public NDArray(char order, @Cast("Nd4jLong*") @StdVector long[] shape, + @StdVector double[] data) { super((Pointer)null); allocate(order, shape, data); } + private native void allocate(char order, @Cast("Nd4jLong*") @StdVector long[] shape, + @StdVector double[] data); - /** - * this constructor creates new NDArray with shape matching "other" array, - * doesn't copy "other" elements into new array !!! - */ - public NDArray(@Const NDArray other, @Cast("bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/) { super((Pointer)null); allocate(other, copyStrides, context); } - private native void allocate(@Const NDArray other, @Cast("bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/); + /** + * this constructor creates new array using given buffer (without memory + * allocation) and shape information stored in shape + */ + public NDArray(Pointer buffer, char order, @Cast("Nd4jLong*") @StdVector LongPointer shape, + @Cast("sd::DataType") int dtype, + LaunchContext context/*=sd::LaunchContext::defaultContext()*/, + @Cast("const bool") boolean isBuffAlloc/*=false*/) { super((Pointer)null); allocate(buffer, order, shape, dtype, context, isBuffAlloc); } + private native void allocate(Pointer buffer, char order, @Cast("Nd4jLong*") @StdVector LongPointer shape, + @Cast("sd::DataType") int dtype, + LaunchContext context/*=sd::LaunchContext::defaultContext()*/, + @Cast("const bool") boolean isBuffAlloc/*=false*/); + public NDArray(Pointer buffer, char order, @Cast("Nd4jLong*") @StdVector LongPointer shape, + @Cast("sd::DataType") int dtype) { super((Pointer)null); allocate(buffer, order, shape, dtype); } + private native void allocate(Pointer buffer, char order, @Cast("Nd4jLong*") @StdVector LongPointer shape, + @Cast("sd::DataType") int dtype); + public NDArray(Pointer buffer, char order, @Cast("Nd4jLong*") @StdVector LongBuffer shape, + @Cast("sd::DataType") int dtype, + LaunchContext context/*=sd::LaunchContext::defaultContext()*/, + @Cast("const bool") boolean isBuffAlloc/*=false*/) { super((Pointer)null); allocate(buffer, order, shape, dtype, context, isBuffAlloc); } + private native void allocate(Pointer buffer, char order, @Cast("Nd4jLong*") @StdVector LongBuffer shape, + @Cast("sd::DataType") int dtype, + LaunchContext context/*=sd::LaunchContext::defaultContext()*/, + @Cast("const bool") boolean isBuffAlloc/*=false*/); + public NDArray(Pointer buffer, char order, @Cast("Nd4jLong*") @StdVector LongBuffer shape, + @Cast("sd::DataType") int dtype) { super((Pointer)null); allocate(buffer, order, shape, dtype); } + private native void allocate(Pointer buffer, char order, @Cast("Nd4jLong*") @StdVector LongBuffer shape, + @Cast("sd::DataType") int dtype); + public NDArray(Pointer buffer, char order, @Cast("Nd4jLong*") @StdVector long[] shape, + @Cast("sd::DataType") int dtype, + LaunchContext context/*=sd::LaunchContext::defaultContext()*/, + @Cast("const bool") boolean isBuffAlloc/*=false*/) { super((Pointer)null); allocate(buffer, order, shape, dtype, context, isBuffAlloc); } + private native void allocate(Pointer buffer, char order, @Cast("Nd4jLong*") @StdVector long[] shape, + @Cast("sd::DataType") int dtype, + LaunchContext context/*=sd::LaunchContext::defaultContext()*/, + @Cast("const bool") boolean isBuffAlloc/*=false*/); + public NDArray(Pointer buffer, char order, @Cast("Nd4jLong*") @StdVector long[] shape, + @Cast("sd::DataType") int dtype) { super((Pointer)null); allocate(buffer, order, shape, dtype); } + private native void allocate(Pointer buffer, char order, @Cast("Nd4jLong*") @StdVector long[] shape, + @Cast("sd::DataType") int dtype); - /** - * this constructor creates scalar(and set its value = 0) or empty array depending on bool argument isScalar - */ - public NDArray(@Cast("sd::DataType") int dtype, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("bool") boolean isScalar/*=true*/) { super((Pointer)null); allocate(dtype, context, isScalar); } - private native void allocate(@Cast("sd::DataType") int dtype, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("bool") boolean isScalar/*=true*/); - public NDArray(@Cast("sd::DataType") int dtype) { super((Pointer)null); allocate(dtype); } - private native void allocate(@Cast("sd::DataType") int dtype); + /** + * This method returns new array with the same shape & data type + * @return + */ + public native @ByVal NDArray like(); - /** - * This method blocks until asynchronous operation finishes - */ - public native void synchronize(@Cast("char*") String msg); - public native void synchronize(@Cast("char*") BytePointer msg); + /** + * This method returns new uninitialized array with the same shape & data type + * @return + */ + public native @ByVal NDArray ulike(); - /** - * This method allows to set _isAttached flag - * @param reallyAttached - */ - public native void setAttached(@Cast("bool") boolean reallyAttached); + /** + * this constructor creates new NDArray with shape matching "other" array, + * doesn't copy "other" elements into new array !!! + */ + public NDArray( + @Const NDArray other, @Cast("bool") boolean copyStrides/*=false*/, + LaunchContext context/*=sd::LaunchContext::defaultContext()*/) { super((Pointer)null); allocate(other, copyStrides, context); } + private native void allocate( + @Const NDArray other, @Cast("bool") boolean copyStrides/*=false*/, + LaunchContext context/*=sd::LaunchContext::defaultContext()*/); - public native void tickWriteHost(); - public native void tickWriteDevice(); - public native void tickReadHost(); - public native void tickReadDevice(); - public native void tickBothActual(); - public native @Cast("bool") boolean isActualOnHostSide(); - public native @Cast("bool") boolean isActualOnDeviceSide(); - public native void makeBothBuffersActual(); + /** + * this constructor creates scalar(and set its value = 0) or empty array + * depending on bool argument isScalar + */ + public NDArray(@Cast("sd::DataType") int dtype, + LaunchContext context/*=sd::LaunchContext::defaultContext()*/, + @Cast("bool") boolean isScalar/*=true*/) { super((Pointer)null); allocate(dtype, context, isScalar); } + private native void allocate(@Cast("sd::DataType") int dtype, + LaunchContext context/*=sd::LaunchContext::defaultContext()*/, + @Cast("bool") boolean isScalar/*=true*/); + public NDArray(@Cast("sd::DataType") int dtype) { super((Pointer)null); allocate(dtype); } + private native void allocate(@Cast("sd::DataType") int dtype); - public native void syncToHost(); - public native void syncToDevice(); - public native void syncShape(); + /** + * This method blocks until asynchronous operation finishes + */ + public native void synchronize(@Cast("char*") String msg); + public native void synchronize(@Cast("char*") BytePointer msg); - /** - * This method can be used on architectures that use special buffers - * @param writeList - * @param readList - */ - public static native void registerSpecialUse(@Const @ByRef ConstNDArrayVector writeList, @Const @ByRef(nullValue = "std::vector({})") ConstNDArrayVector readList); - public static native void registerSpecialUse(@Const @ByRef ConstNDArrayVector writeList); - public static native void prepareSpecialUse(@Const @ByRef ConstNDArrayVector writeList, @Const @ByRef(nullValue = "std::vector({})") ConstNDArrayVector readList, @Cast("bool") boolean synchronizeWritables/*=false*/); - public static native void prepareSpecialUse(@Const @ByRef ConstNDArrayVector writeList); + /** + * This method allows to set _isAttached flag + * @param reallyAttached + */ + public native void setAttached(@Cast("bool") boolean reallyAttached); - public static native void registerPrimaryUse(@Const @ByRef ConstNDArrayVector writeList, @Const @ByRef(nullValue = "std::vector({})") ConstNDArrayVector readList); - public static native void registerPrimaryUse(@Const @ByRef ConstNDArrayVector writeList); - public static native void preparePrimaryUse(@Const @ByRef ConstNDArrayVector writeList, @Const @ByRef(nullValue = "std::vector({})") ConstNDArrayVector readList, @Cast("bool") boolean synchronizeWritables/*=false*/); - public static native void preparePrimaryUse(@Const @ByRef ConstNDArrayVector writeList); + public native void tickWriteHost(); + public native void tickWriteDevice(); + public native void tickReadHost(); + public native void tickReadDevice(); + public native void tickBothActual(); + public native @Cast("bool") boolean isActualOnHostSide(); + public native @Cast("bool") boolean isActualOnDeviceSide(); + public native void makeBothBuffersActual(); - /** - * This method returns buffer pointer offset by given number of elements, wrt own data type - * @param offset - * @return - */ - public native Pointer bufferWithOffset(@Cast("Nd4jLong") long offset); - public native Pointer specialBufferWithOffset(@Cast("Nd4jLong") long offset); - /** - * copy assignment operator - * in particular, when _dataType != other._dataType and both shapes are the same, there will be allocation of new _buffer and _dataType acquires other._dataType - */ - public native @ByRef @Name("operator =") NDArray put(@Const @ByRef NDArray other); + public native void syncToHost(); + public native void syncToDevice(); + public native void syncShape(); - /** - * move assignment operator - */ + /** + * This method can be used on architectures that use special buffers + * @param writeList + * @param readList + */ + public static native void registerSpecialUse(@Const @ByRef ConstNDArrayVector writeList, + @Const @ByRef(nullValue = "std::vector({})") ConstNDArrayVector readList); + public static native void registerSpecialUse(@Const @ByRef ConstNDArrayVector writeList); + public static native void prepareSpecialUse(@Const @ByRef ConstNDArrayVector writeList, + @Const @ByRef(nullValue = "std::vector({})") ConstNDArrayVector readList, + @Cast("bool") boolean synchronizeWritables/*=false*/); + public static native void prepareSpecialUse(@Const @ByRef ConstNDArrayVector writeList); + + public static native void registerPrimaryUse(@Const @ByRef ConstNDArrayVector writeList, + @Const @ByRef(nullValue = "std::vector({})") ConstNDArrayVector readList); + public static native void registerPrimaryUse(@Const @ByRef ConstNDArrayVector writeList); + public static native void preparePrimaryUse(@Const @ByRef ConstNDArrayVector writeList, + @Const @ByRef(nullValue = "std::vector({})") ConstNDArrayVector readList, + @Cast("bool") boolean synchronizeWritables/*=false*/); + public static native void preparePrimaryUse(@Const @ByRef ConstNDArrayVector writeList); - /** - * assignment operator, assigns the same scalar to all array elements - */ + /** + * This method returns buffer pointer offset by given number of elements, wrt + * own data type + * @param offset + * @return + */ + public native Pointer bufferWithOffset(@Cast("Nd4jLong") long offset); + public native Pointer specialBufferWithOffset(@Cast("Nd4jLong") long offset); + /** + * copy assignment operator + * in particular, when _dataType != other._dataType and both shapes are the + * same, there will be allocation of new _buffer and _dataType acquires + * other._dataType + */ + public native @ByRef @Name("operator =") NDArray put(@Const @ByRef NDArray other); + /** + * move assignment operator + */ - /** - * operators for memory allocation and deletion - */ - public native @Name("operator new") Pointer _new(@Cast("size_t") long i); - public native @Name("operator delete") void _delete(Pointer p); + /** + * assignment operator, assigns the same scalar to all array elements + */ + /** + * operators for memory allocation and deletion + */ + public native @Name("operator new") Pointer _new(@Cast("size_t") long i); + public native @Name("operator delete") void _delete(Pointer p); - public native void setContext(LaunchContext context); + public native void setContext(LaunchContext context); - /** - * create a new array by replicating current array by repeats times along given dimension - * axis - axis along which to repeat elements - * repeats - number of repetitions - */ - public native @ByVal NDArray repeat(int axis, @StdVector IntPointer repeats); - public native @ByVal NDArray repeat(int axis, @StdVector IntBuffer repeats); - public native @ByVal NDArray repeat(int axis, @StdVector int[] repeats); + /** + * create a new array by replicating current array by repeats times along + * given dimension axis - axis along which to repeat elements repeats - number + * of repetitions + */ + public native @ByVal NDArray repeat(int axis, @StdVector IntPointer repeats); + public native @ByVal NDArray repeat(int axis, @StdVector IntBuffer repeats); + public native @ByVal NDArray repeat(int axis, @StdVector int[] repeats); - /** - * This method fills this array with zeros - */ - public native void nullify(); + /** + * This method fills this array with zeros + */ + public native void nullify(); - /** - * This method returns quantized copy of given array - * - * @param array - * @return - */ - public static native @ByVal NDArray quantize(@Const @ByRef NDArray array); + /** + * This method returns quantized copy of given array + * + * @param array + * @return + */ + public static native @ByVal NDArray quantize(@Const @ByRef NDArray array); - /** - * fill target array by repeating current array - * axis - axis along which to repeat elements - * repeats - vector containing numbers of repetition for elements at given axis - */ - public native void repeat(int axis, @StdVector IntPointer repeats, @ByRef NDArray target); - public native void repeat(int axis, @StdVector IntBuffer repeats, @ByRef NDArray target); - public native void repeat(int axis, @StdVector int[] repeats, @ByRef NDArray target); + /** + * fill target array by repeating current array + * axis - axis along which to repeat elements + * repeats - vector containing numbers of repetition for elements at given + * axis + */ + public native void repeat(int axis, @StdVector IntPointer repeats, + @ByRef NDArray target); + public native void repeat(int axis, @StdVector IntBuffer repeats, + @ByRef NDArray target); + public native void repeat(int axis, @StdVector int[] repeats, + @ByRef NDArray target); - /** - * creates array which points on certain sub-range of this array, sub-range is defined by given indices - */ - - - + /** + * creates array which points on certain sub-range of this array, sub-range + * is defined by given indices + */ + + + - /** - * cast array elements to given dtype - */ - public native @ByVal NDArray cast(@Cast("sd::DataType") int dtype); + /** + * cast array elements to given dtype + */ + public native @ByVal NDArray cast(@Cast("sd::DataType") int dtype); - public native void cast(@ByRef NDArray target, @Cast("sd::DataType") int dtype); + public native void cast(@ByRef NDArray target, @Cast("sd::DataType") int dtype); - /** - * returns _context - */ - public native LaunchContext getContext(); + /** + * returns _context + */ + public native LaunchContext getContext(); // #ifndef __JAVACPP_HACK__ // #endif - /** - * returns host buffer - */ - public native Pointer buffer(); - - - /** - * returns buffer offset (offset is the same for host and device buffers) - */ - public native @Cast("Nd4jLong") long bufferOffset(); - - /** - * if _bufferD==nullptr return _buffer, else return _bufferD - */ - public native Pointer specialBuffer(); + /** + * returns host buffer + */ + public native Pointer buffer(); - /** - * returns device buffer if compilation is for cuda case, otherwise returns host buffer - */ - public native Pointer platformBuffer(); + /** + * returns buffer offset (offset is the same for host and device buffers) + */ + public native @Cast("Nd4jLong") long bufferOffset(); - /** - * returns _shapeInfo - */ - public native @Cast("const Nd4jLong*") LongPointer shapeInfo(); + /** + * if _bufferD==nullptr return _buffer, else return _bufferD + */ + public native Pointer specialBuffer(); + /** + * returns device buffer if compilation is for cuda case, otherwise returns + * host buffer + */ + public native Pointer platformBuffer(); - /** - * Returns True if it's legally empty NDArray, or false otherwise - * @return - */ - public native @Cast("bool") boolean isEmpty(); + /** + * returns _shapeInfo + */ + public native @Cast("const Nd4jLong*") LongPointer shapeInfo(); - /** - * if _shapeInfoD==nullptr return _shapeInfo, else return _shapeInfoD - */ - public native @Cast("const Nd4jLong*") LongPointer specialShapeInfo(); + /** + * Returns True if it's legally empty NDArray, or false otherwise + * @return + */ + public native @Cast("bool") boolean isEmpty(); - public native @Cast("const Nd4jLong*") LongPointer platformShapeInfo(); + /** + * if _shapeInfoD==nullptr return _shapeInfo, else return _shapeInfoD + */ + public native @Cast("const Nd4jLong*") LongPointer specialShapeInfo(); - /** - * permutes (in-place) the dimensions in array according to "dimensions" array - */ - public native @Cast("bool") boolean permutei(@StdVector IntPointer dimensions); - public native @Cast("bool") boolean permutei(@StdVector IntBuffer dimensions); - public native @Cast("bool") boolean permutei(@StdVector int[] dimensions); - public native @Cast("bool") boolean permutei(@Const IntPointer dimensions, int rank); - public native @Cast("bool") boolean permutei(@Const IntBuffer dimensions, int rank); - public native @Cast("bool") boolean permutei(@Const int[] dimensions, int rank); - public native @Cast("bool") boolean permutei(@Cast("Nd4jLong*") @StdVector LongPointer dimensions); - public native @Cast("bool") boolean permutei(@Cast("Nd4jLong*") @StdVector LongBuffer dimensions); - public native @Cast("bool") boolean permutei(@Cast("Nd4jLong*") @StdVector long[] dimensions); - public native @Cast("bool") boolean permutei(@Cast("const Nd4jLong*") LongPointer dimensions, int rank); - public native @Cast("bool") boolean permutei(@Cast("const Nd4jLong*") LongBuffer dimensions, int rank); - public native @Cast("bool") boolean permutei(@Cast("const Nd4jLong*") long[] dimensions, int rank); - - public native @Cast("bool") boolean isFinite(); - public native @Cast("bool") boolean hasNaNs(); - public native @Cast("bool") boolean hasInfs(); - - public native void copyBuffersContinuouslyFrom(@Const @ByRef NDArray other, @Cast("size_t") long sizeToCopyInBytes/*=0*/, @Cast("Nd4jLong") long offsetThis/*=0*/, @Cast("Nd4jLong") long offsetOther/*=0*/); - public native void copyBuffersContinuouslyFrom(@Const @ByRef NDArray other); + public native @Cast("const Nd4jLong*") LongPointer platformShapeInfo(); - /** - * permutes the dimensions in array according to "dimensions" array, new array points on _buffer of this array - */ - public native @ByVal NDArray permute(@StdVector IntPointer dimensions); - public native @ByVal NDArray permute(@StdVector IntBuffer dimensions); - public native @ByVal NDArray permute(@StdVector int[] dimensions); - public native @ByVal NDArray permute(@Const IntPointer dimensions, int rank); - public native @ByVal NDArray permute(@Const IntBuffer dimensions, int rank); - public native @ByVal NDArray permute(@Const int[] dimensions, int rank); - - - + /** + * permutes (in-place) the dimensions in array according to "dimensions" + * array + */ + public native @Cast("bool") boolean permutei(@StdVector IntPointer dimensions); + public native @Cast("bool") boolean permutei(@StdVector IntBuffer dimensions); + public native @Cast("bool") boolean permutei(@StdVector int[] dimensions); + public native @Cast("bool") boolean permutei(@Const IntPointer dimensions, int rank); + public native @Cast("bool") boolean permutei(@Const IntBuffer dimensions, int rank); + public native @Cast("bool") boolean permutei(@Const int[] dimensions, int rank); + public native @Cast("bool") boolean permutei(@Cast("Nd4jLong*") @StdVector LongPointer dimensions); + public native @Cast("bool") boolean permutei(@Cast("Nd4jLong*") @StdVector LongBuffer dimensions); + public native @Cast("bool") boolean permutei(@Cast("Nd4jLong*") @StdVector long[] dimensions); + public native @Cast("bool") boolean permutei(@Cast("const Nd4jLong*") LongPointer dimensions, int rank); + public native @Cast("bool") boolean permutei(@Cast("const Nd4jLong*") LongBuffer dimensions, int rank); + public native @Cast("bool") boolean permutei(@Cast("const Nd4jLong*") long[] dimensions, int rank); + + public native @Cast("bool") boolean isFinite(); + public native @Cast("bool") boolean hasNaNs(); + public native @Cast("bool") boolean hasInfs(); + + public native void copyBuffersContinuouslyFrom(@Const @ByRef NDArray other, + @Cast("size_t") long sizeToCopyInBytes/*=0*/, + @Cast("Nd4jLong") long offsetThis/*=0*/, + @Cast("Nd4jLong") long offsetOther/*=0*/); + public native void copyBuffersContinuouslyFrom(@Const @ByRef NDArray other); - public native void permute(@Const IntPointer dimensions, int rank, @ByRef NDArray target); - public native void permute(@Const IntBuffer dimensions, int rank, @ByRef NDArray target); - public native void permute(@Const int[] dimensions, int rank, @ByRef NDArray target); - public native void permute(@StdVector IntPointer dimensions, @ByRef NDArray target); - public native void permute(@StdVector IntBuffer dimensions, @ByRef NDArray target); - public native void permute(@StdVector int[] dimensions, @ByRef NDArray target); - public native @ByVal NDArray permute(@Cast("Nd4jLong*") @StdVector LongPointer dimensions); - public native @ByVal NDArray permute(@Cast("Nd4jLong*") @StdVector LongBuffer dimensions); - public native @ByVal NDArray permute(@Cast("Nd4jLong*") @StdVector long[] dimensions); - public native @ByVal NDArray permute(@Cast("const Nd4jLong*") LongPointer dimensions, int rank); - public native @ByVal NDArray permute(@Cast("const Nd4jLong*") LongBuffer dimensions, int rank); - public native @ByVal NDArray permute(@Cast("const Nd4jLong*") long[] dimensions, int rank); - - - + /** + * permutes the dimensions in array according to "dimensions" array, new + * array points on _buffer of this array + */ + public native @ByVal NDArray permute(@StdVector IntPointer dimensions); + public native @ByVal NDArray permute(@StdVector IntBuffer dimensions); + public native @ByVal NDArray permute(@StdVector int[] dimensions); + public native @ByVal NDArray permute(@Const IntPointer dimensions, int rank); + public native @ByVal NDArray permute(@Const IntBuffer dimensions, int rank); + public native @ByVal NDArray permute(@Const int[] dimensions, int rank); + + + + + public native void permute(@Const IntPointer dimensions, int rank, @ByRef NDArray target); + public native void permute(@Const IntBuffer dimensions, int rank, @ByRef NDArray target); + public native void permute(@Const int[] dimensions, int rank, @ByRef NDArray target); + public native void permute(@StdVector IntPointer dimensions, @ByRef NDArray target); + public native void permute(@StdVector IntBuffer dimensions, @ByRef NDArray target); + public native void permute(@StdVector int[] dimensions, @ByRef NDArray target); + public native @ByVal NDArray permute(@Cast("Nd4jLong*") @StdVector LongPointer dimensions); + public native @ByVal NDArray permute(@Cast("Nd4jLong*") @StdVector LongBuffer dimensions); + public native @ByVal NDArray permute(@Cast("Nd4jLong*") @StdVector long[] dimensions); + public native @ByVal NDArray permute(@Cast("const Nd4jLong*") LongPointer dimensions, int rank); + public native @ByVal NDArray permute(@Cast("const Nd4jLong*") LongBuffer dimensions, int rank); + public native @ByVal NDArray permute(@Cast("const Nd4jLong*") long[] dimensions, int rank); + + + + + public native void permute(@Cast("const Nd4jLong*") LongPointer dimensions, int rank, + @ByRef NDArray target); + public native void permute(@Cast("const Nd4jLong*") LongBuffer dimensions, int rank, + @ByRef NDArray target); + public native void permute(@Cast("const Nd4jLong*") long[] dimensions, int rank, + @ByRef NDArray target); + public native void permute(@Cast("Nd4jLong*") @StdVector LongPointer dimensions, @ByRef NDArray target); + public native void permute(@Cast("Nd4jLong*") @StdVector LongBuffer dimensions, @ByRef NDArray target); + public native void permute(@Cast("Nd4jLong*") @StdVector long[] dimensions, @ByRef NDArray target); - public native void permute(@Cast("const Nd4jLong*") LongPointer dimensions, int rank, @ByRef NDArray target); - public native void permute(@Cast("const Nd4jLong*") LongBuffer dimensions, int rank, @ByRef NDArray target); - public native void permute(@Cast("const Nd4jLong*") long[] dimensions, int rank, @ByRef NDArray target); - public native void permute(@Cast("Nd4jLong*") @StdVector LongPointer dimensions, @ByRef NDArray target); - public native void permute(@Cast("Nd4jLong*") @StdVector LongBuffer dimensions, @ByRef NDArray target); - public native void permute(@Cast("Nd4jLong*") @StdVector long[] dimensions, @ByRef NDArray target); + /** + * This method streamlines given view or permuted array, and reallocates + * buffer + */ + public native void streamline(char order/*='a'*/); + public native void streamline(); - /** - * This method streamlines given view or permuted array, and reallocates buffer - */ - public native void streamline(char order/*='a'*/); - public native void streamline(); + /** + * prints information about array shape + * msg - message to print out + */ + public native void printShapeInfo(@Cast("char*") String msg/*=nullptr*/); + public native void printShapeInfo(); + public native void printShapeInfo(@Cast("char*") BytePointer msg/*=nullptr*/); - /** - * prints information about array shape - * msg - message to print out - */ - public native void printShapeInfo(@Cast("char*") String msg/*=nullptr*/); - public native void printShapeInfo(); - public native void printShapeInfo(@Cast("char*") BytePointer msg/*=nullptr*/); + /** + * prints buffer elements + * msg - message to print out + * limit - number of array elements to print out + * sync - if true check whether host buffer is actual, if it is not then make + * it so + */ + public native void printBuffer(@Cast("char*") String msg/*=nullptr*/, @Cast("Nd4jLong") long _limit/*=-1*/, + @Cast("const bool") boolean sync/*=true*/); + public native void printBuffer(); + public native void printBuffer(@Cast("char*") BytePointer msg/*=nullptr*/, @Cast("Nd4jLong") long _limit/*=-1*/, + @Cast("const bool") boolean sync/*=true*/); - /** - * prints buffer elements - * msg - message to print out - * limit - number of array elements to print out - * sync - if true check whether host buffer is actual, if it is not then make it so - */ - public native void printBuffer(@Cast("char*") String msg/*=nullptr*/, @Cast("Nd4jLong") long _limit/*=-1*/, @Cast("const bool") boolean sync/*=true*/); - public native void printBuffer(); - public native void printBuffer(@Cast("char*") BytePointer msg/*=nullptr*/, @Cast("Nd4jLong") long _limit/*=-1*/, @Cast("const bool") boolean sync/*=true*/); + /** + * make strings for current ndarray - linear or structured by dimensions + * */ + public native @StdString BytePointer linearString(@Cast("Nd4jLong") long _limit/*=-1*/); + public native @StdString BytePointer linearString(); + public native @StdString BytePointer indexedBufferString(@Cast("Nd4jLong") long _limit/*=-1*/); + public native @StdString BytePointer indexedBufferString(); - /** - * print element by element consequently in a way they (elements) are stored in physical memory - */ - public native void printLinearBuffer(); + /** + * print element by element consequently in a way they (elements) are stored + * in physical memory + */ + public native void printLinearBuffer(); - /** - * prints _buffer (if host = true) or _bufferD (if host = false) as it is, that is in current state without checking buffer status - */ + /** + * prints _buffer (if host = true) or _bufferD (if host = false) as it is, + * that is in current state without checking buffer status + */ - /** - * prints buffer elements, takes into account offset between elements (element-wise-stride) - * msg - message to print out - * limit - number of array elements to print out - */ - public native void printIndexedBuffer(@Cast("char*") String msg/*=nullptr*/, @Cast("Nd4jLong") long _limit/*=-1*/); - public native void printIndexedBuffer(); - public native void printIndexedBuffer(@Cast("char*") BytePointer msg/*=nullptr*/, @Cast("Nd4jLong") long _limit/*=-1*/); + /** + * prints buffer elements, takes into account offset between elements + * (element-wise-stride) msg - message to print out limit - number of array + * elements to print out + */ + public native void printIndexedBuffer(@Cast("char*") String msg/*=nullptr*/, @Cast("Nd4jLong") long _limit/*=-1*/); + public native void printIndexedBuffer(); + public native void printIndexedBuffer(@Cast("char*") BytePointer msg/*=nullptr*/, @Cast("Nd4jLong") long _limit/*=-1*/); - public native @StdString BytePointer asIndexedString(@Cast("Nd4jLong") long _limit/*=-1*/); - public native @StdString BytePointer asIndexedString(); - public native @StdString BytePointer asString(@Cast("Nd4jLong") long _limit/*=-1*/); - public native @StdString BytePointer asString(); + public native @StdString BytePointer asIndexedString(@Cast("Nd4jLong") long _limit/*=-1*/); + public native @StdString BytePointer asIndexedString(); + public native @StdString BytePointer asString(@Cast("Nd4jLong") long _limit/*=-1*/); + public native @StdString BytePointer asString(); - /** - * this method assigns values of given array to this one - */ - public native void assign(@Const NDArray other, @Cast("bool") boolean allowParallelism/*=true*/); - public native void assign(@Const NDArray other); + /** + * this method assigns values of given array to this one + */ + public native void assign(@Const NDArray other, @Cast("bool") boolean allowParallelism/*=true*/); + public native void assign(@Const NDArray other); - /** - * this method assigns values of given array to this one - */ + /** + * this method assigns values of given array to this one + */ - /** - * this method assigns given value to all elements in array - */ + /** + * this method assigns given value to all elements in array + */ - /** - * returns new copy of this array, optionally in different order - */ - public native @ByVal NDArray dup(byte newOrder/*='a'*/); - public native @ByVal NDArray dup(); + /** + * returns new copy of this array, optionally in different order + */ + public native @ByVal NDArray dup(byte newOrder/*='a'*/); + public native @ByVal NDArray dup(); - /** - * returns sum of all elements of array - */ - public native @ByVal NDArray sumNumber(); + /** + * returns sum of all elements of array + */ + public native @ByVal NDArray sumNumber(); - /** - * returns mean number of array - */ - public native @ByVal NDArray meanNumber(); + /** + * returns mean number of array + */ + public native @ByVal NDArray meanNumber(); // #ifndef __JAVACPP_HACK__ // #endif - /** - * apply transpose operation to the copy of this array, that is this array remains unaffected - */ - public native @ByVal NDArray transpose(); - - - /** - * perform transpose operation and store result in target, this array remains unaffected - * target - where to store result - */ - public native void transpose(@ByRef NDArray target); + /** + * apply transpose operation to the copy of this array, that is this array + * remains unaffected + */ + public native @ByVal NDArray transpose(); + - /** - * apply in-place transpose operation to this array, so this array becomes transposed - */ - public native void transposei(); + /** + * perform transpose operation and store result in target, this array remains + * unaffected target - where to store result + */ + public native void transpose(@ByRef NDArray target); - /** - * returns the number of arrays pointing on specified dimension(s) - * dimensions - array of dimensions to point on - */ - public native @Cast("Nd4jLong") long tensorsAlongDimension(@StdVector IntPointer dimensions); - public native @Cast("Nd4jLong") long tensorsAlongDimension(@StdVector IntBuffer dimensions); - public native @Cast("Nd4jLong") long tensorsAlongDimension(@StdVector int[] dimensions); + /** + * apply in-place transpose operation to this array, so this array becomes + * transposed + */ + public native void transposei(); - /** - * returns true if elements of two arrays are equal to within given epsilon value - * other - input array to compare - * eps - epsilon, this value defines the precision of elements comparison - */ - public native @Cast("bool") boolean equalsTo(@Const NDArray other, double eps/*=1e-5*/); - public native @Cast("bool") boolean equalsTo(@Const NDArray other); + /** + * returns the number of arrays pointing on specified dimension(s) + * dimensions - array of dimensions to point on + */ + public native @Cast("Nd4jLong") long tensorsAlongDimension(@StdVector IntPointer dimensions); + public native @Cast("Nd4jLong") long tensorsAlongDimension(@StdVector IntBuffer dimensions); + public native @Cast("Nd4jLong") long tensorsAlongDimension(@StdVector int[] dimensions); - /** - * add given row vector to all rows of this array - * row - row vector to add - */ - public native void addiRowVector(@Const @ByRef NDArray row); + /** + * returns true if elements of two arrays are equal to within given epsilon + * value other - input array to compare eps - epsilon, this value defines the + * precision of elements comparison + */ + public native @Cast("bool") boolean equalsTo(@Const NDArray other, double eps/*=1e-5*/); + public native @Cast("bool") boolean equalsTo(@Const NDArray other); - /** - * add given row vector to all rows of this array, store result in target - * row - row vector to add - * target - where to store result - */ - public native void addRowVector(@Const @ByRef NDArray row, @ByRef NDArray target); + /** + * add given row vector to all rows of this array + * row - row vector to add + */ + public native void addiRowVector(@Const @ByRef NDArray row); - /** - * subtract given row vector from all rows of this array, store result in target - * row - row vector to subtract - * target - where to store result - */ - public native void subRowVector(@Const @ByRef NDArray row, @ByRef NDArray target); + /** + * add given row vector to all rows of this array, store result in target + * row - row vector to add + * target - where to store result + */ + public native void addRowVector(@Const @ByRef NDArray row, @ByRef NDArray target); - /** - * multiply all rows of this array on given row vector, store result in target - * row - row vector to multiply on - * target - where to store result - */ - public native void mulRowVector(@Const @ByRef NDArray row, @ByRef NDArray target); + /** + * subtract given row vector from all rows of this array, store result in + * target row - row vector to subtract target - where to store result + */ + public native void subRowVector(@Const @ByRef NDArray row, @ByRef NDArray target); - /** - * divide all rows of this array on given row vector, store result in target - * row - row vector to divide on - * target - where to store result - */ - public native void divRowVector(@Const @ByRef NDArray row, @ByRef NDArray target); + /** + * multiply all rows of this array on given row vector, store result in + * target row - row vector to multiply on target - where to store result + */ + public native void mulRowVector(@Const @ByRef NDArray row, @ByRef NDArray target); - /** - * add given column vector to all columns of this array, store result in target - * column - column vector to add - * target - where to store result - */ - public native void addColumnVector(@Const @ByRef NDArray column, @ByRef NDArray target); + /** + * divide all rows of this array on given row vector, store result in target + * row - row vector to divide on + * target - where to store result + */ + public native void divRowVector(@Const @ByRef NDArray row, @ByRef NDArray target); - /** - * add given column vector to all columns of this array, this array becomes affected (in-place operation) - * column - column vector to add - */ - public native void addiColumnVector(@Const @ByRef NDArray column); + /** + * add given column vector to all columns of this array, store result in + * target column - column vector to add target - where to store result + */ + public native void addColumnVector(@Const @ByRef NDArray column, @ByRef NDArray target); - /** - * multiply all columns of this array on given column vector, this array becomes affected (in-place operation) - * column - column vector to multiply on - */ - public native void muliColumnVector(@Const @ByRef NDArray column); + /** + * add given column vector to all columns of this array, this array becomes + * affected (in-place operation) column - column vector to add + */ + public native void addiColumnVector(@Const @ByRef NDArray column); - /** - * returns number of bytes used by _buffer & _shapeInfo - */ - public native @Cast("Nd4jLong") long memoryFootprint(); + /** + * multiply all columns of this array on given column vector, this array + * becomes affected (in-place operation) column - column vector to multiply on + */ + public native void muliColumnVector(@Const @ByRef NDArray column); - /** - * these methods suited for FlatBuffers use - */ - public native @Cast("Nd4jLong*") @StdVector LongPointer getShapeAsVector(); - public native @StdVector IntPointer getShapeAsVectorInt(); - public native @Cast("Nd4jLong*") @StdVector LongPointer getShapeInfoAsVector(); - public native @Cast("int64_t*") @StdVector LongPointer getShapeInfoAsFlatVector(); - public native @Cast("int64_t*") @StdVector LongPointer getShapeAsFlatVector(); + /** + * returns number of bytes used by _buffer & _shapeInfo + */ + public native @Cast("Nd4jLong") long memoryFootprint(); - /** - * set new order and shape in case of suitable array length (in-place operation) - * order - order to set - * shape - shape to set - * copyToNewBuff - if true then old buffer will be copied to new buffer if last one will be allocated after reshaping - * if there was permute applied before or there are weird strides, then new buffer is allocated for array - */ - public native @Cast("bool") boolean reshapei(byte order, @Cast("Nd4jLong*") @StdVector LongPointer shape, @Cast("const bool") boolean copyToNewBuff/*=true*/); - public native @Cast("bool") boolean reshapei(byte order, @Cast("Nd4jLong*") @StdVector LongPointer shape); - public native @Cast("bool") boolean reshapei(byte order, @Cast("Nd4jLong*") @StdVector LongBuffer shape, @Cast("const bool") boolean copyToNewBuff/*=true*/); - public native @Cast("bool") boolean reshapei(byte order, @Cast("Nd4jLong*") @StdVector LongBuffer shape); - public native @Cast("bool") boolean reshapei(byte order, @Cast("Nd4jLong*") @StdVector long[] shape, @Cast("const bool") boolean copyToNewBuff/*=true*/); - public native @Cast("bool") boolean reshapei(byte order, @Cast("Nd4jLong*") @StdVector long[] shape); - public native @Cast("bool") boolean reshapei(@Cast("Nd4jLong*") @StdVector LongPointer shape, @Cast("const bool") boolean copyToNewBuff/*=true*/); - public native @Cast("bool") boolean reshapei(@Cast("Nd4jLong*") @StdVector LongPointer shape); - public native @Cast("bool") boolean reshapei(@Cast("Nd4jLong*") @StdVector LongBuffer shape, @Cast("const bool") boolean copyToNewBuff/*=true*/); - public native @Cast("bool") boolean reshapei(@Cast("Nd4jLong*") @StdVector LongBuffer shape); - public native @Cast("bool") boolean reshapei(@Cast("Nd4jLong*") @StdVector long[] shape, @Cast("const bool") boolean copyToNewBuff/*=true*/); - public native @Cast("bool") boolean reshapei(@Cast("Nd4jLong*") @StdVector long[] shape); + /** + * these methods suited for FlatBuffers use + */ + public native @Cast("Nd4jLong*") @StdVector LongPointer getShapeAsVector(); + public native @StdVector IntPointer getShapeAsVectorInt(); + public native @Cast("Nd4jLong*") @StdVector LongPointer getShapeInfoAsVector(); + public native @Cast("int64_t*") @StdVector LongPointer getShapeInfoAsFlatVector(); + public native @Cast("int64_t*") @StdVector LongPointer getShapeAsFlatVector(); - /** - * creates new array with corresponding order and shape, new array will point on _buffer of this array - * order - order to set - * shape - shape to set - * - * if permute have been applied before or there are weird strides, then new buffer is allocated for new array - */ - public native @ByVal NDArray reshape(byte order, @Cast("Nd4jLong*") @StdVector LongPointer shape, @Cast("const bool") boolean copyToNewBuff/*=true*/); - public native @ByVal NDArray reshape(byte order, @Cast("Nd4jLong*") @StdVector LongPointer shape); - public native @ByVal NDArray reshape(byte order, @Cast("Nd4jLong*") @StdVector LongBuffer shape, @Cast("const bool") boolean copyToNewBuff/*=true*/); - public native @ByVal NDArray reshape(byte order, @Cast("Nd4jLong*") @StdVector LongBuffer shape); - public native @ByVal NDArray reshape(byte order, @Cast("Nd4jLong*") @StdVector long[] shape, @Cast("const bool") boolean copyToNewBuff/*=true*/); - public native @ByVal NDArray reshape(byte order, @Cast("Nd4jLong*") @StdVector long[] shape); - + /** + * set new order and shape in case of suitable array length (in-place + * operation) order - order to set shape - shape to set copyToNewBuff - if + * true then old buffer will be copied to new buffer if last one will be + * allocated after reshaping if there was permute applied before or there are + * weird strides, then new buffer is allocated for array + */ + public native @Cast("bool") boolean reshapei(byte order, @Cast("Nd4jLong*") @StdVector LongPointer shape, + @Cast("const bool") boolean copyToNewBuff/*=true*/); + public native @Cast("bool") boolean reshapei(byte order, @Cast("Nd4jLong*") @StdVector LongPointer shape); + public native @Cast("bool") boolean reshapei(byte order, @Cast("Nd4jLong*") @StdVector LongBuffer shape, + @Cast("const bool") boolean copyToNewBuff/*=true*/); + public native @Cast("bool") boolean reshapei(byte order, @Cast("Nd4jLong*") @StdVector LongBuffer shape); + public native @Cast("bool") boolean reshapei(byte order, @Cast("Nd4jLong*") @StdVector long[] shape, + @Cast("const bool") boolean copyToNewBuff/*=true*/); + public native @Cast("bool") boolean reshapei(byte order, @Cast("Nd4jLong*") @StdVector long[] shape); + public native @Cast("bool") boolean reshapei(@Cast("Nd4jLong*") @StdVector LongPointer shape, + @Cast("const bool") boolean copyToNewBuff/*=true*/); + public native @Cast("bool") boolean reshapei(@Cast("Nd4jLong*") @StdVector LongPointer shape); + public native @Cast("bool") boolean reshapei(@Cast("Nd4jLong*") @StdVector LongBuffer shape, + @Cast("const bool") boolean copyToNewBuff/*=true*/); + public native @Cast("bool") boolean reshapei(@Cast("Nd4jLong*") @StdVector LongBuffer shape); + public native @Cast("bool") boolean reshapei(@Cast("Nd4jLong*") @StdVector long[] shape, + @Cast("const bool") boolean copyToNewBuff/*=true*/); + public native @Cast("bool") boolean reshapei(@Cast("Nd4jLong*") @StdVector long[] shape); - /** - * calculate strides and set given order - * order - order to set - */ - public native void updateStrides(byte order); + /** + * creates new array with corresponding order and shape, new array will point + * on _buffer of this array order - order to set shape - shape to set + * + * if permute have been applied before or there are weird strides, then new + * buffer is allocated for new array + */ + public native @ByVal NDArray reshape(byte order, @Cast("Nd4jLong*") @StdVector LongPointer shape, + @Cast("const bool") boolean copyToNewBuff/*=true*/); + public native @ByVal NDArray reshape(byte order, @Cast("Nd4jLong*") @StdVector LongPointer shape); + public native @ByVal NDArray reshape(byte order, @Cast("Nd4jLong*") @StdVector LongBuffer shape, + @Cast("const bool") boolean copyToNewBuff/*=true*/); + public native @ByVal NDArray reshape(byte order, @Cast("Nd4jLong*") @StdVector LongBuffer shape); + public native @ByVal NDArray reshape(byte order, @Cast("Nd4jLong*") @StdVector long[] shape, + @Cast("const bool") boolean copyToNewBuff/*=true*/); + public native @ByVal NDArray reshape(byte order, @Cast("Nd4jLong*") @StdVector long[] shape); + - /** - * change an array by repeating it the number of times given by reps (in-place operation) - * repeats - contains numbers of repetitions - */ - public native void tilei(@Cast("Nd4jLong*") @StdVector LongPointer repeats); - public native void tilei(@Cast("Nd4jLong*") @StdVector LongBuffer repeats); - public native void tilei(@Cast("Nd4jLong*") @StdVector long[] repeats); + /** + * calculate strides and set given order + * order - order to set + */ + public native void updateStrides(byte order); - /** - * returns new array which is created by repeating of this array the number of times given by reps - * repeats - contains numbers of repetitions - */ - public native @ByVal NDArray tile(@Cast("Nd4jLong*") @StdVector LongPointer repeats); - public native @ByVal NDArray tile(@Cast("Nd4jLong*") @StdVector LongBuffer repeats); - public native @ByVal NDArray tile(@Cast("Nd4jLong*") @StdVector long[] repeats); + /** + * change an array by repeating it the number of times given by reps + * (in-place operation) repeats - contains numbers of repetitions + */ + public native void tilei(@Cast("Nd4jLong*") @StdVector LongPointer repeats); + public native void tilei(@Cast("Nd4jLong*") @StdVector LongBuffer repeats); + public native void tilei(@Cast("Nd4jLong*") @StdVector long[] repeats); - /** - * change an array by repeating it the number of times given by reps (in-place operation) - * repeats - contains numbers of repetitions - * target - where to store result - */ - public native void tile(@Cast("Nd4jLong*") @StdVector LongPointer repeats, @ByRef NDArray target); - public native void tile(@Cast("Nd4jLong*") @StdVector LongBuffer repeats, @ByRef NDArray target); - public native void tile(@Cast("Nd4jLong*") @StdVector long[] repeats, @ByRef NDArray target); + /** + * returns new array which is created by repeating of this array the number + * of times given by reps repeats - contains numbers of repetitions + */ + public native @ByVal NDArray tile(@Cast("Nd4jLong*") @StdVector LongPointer repeats); + public native @ByVal NDArray tile(@Cast("Nd4jLong*") @StdVector LongBuffer repeats); + public native @ByVal NDArray tile(@Cast("Nd4jLong*") @StdVector long[] repeats); - /** - * change an array by repeating it the number of times to acquire the new shape which is the same as target shape - * target - where to store result - */ - public native void tile(@ByRef NDArray target); + /** + * change an array by repeating it the number of times given by reps + * (in-place operation) repeats - contains numbers of repetitions target - + * where to store result + */ + public native void tile(@Cast("Nd4jLong*") @StdVector LongPointer repeats, @ByRef NDArray target); + public native void tile(@Cast("Nd4jLong*") @StdVector LongBuffer repeats, @ByRef NDArray target); + public native void tile(@Cast("Nd4jLong*") @StdVector long[] repeats, @ByRef NDArray target); - /** - * check whether array is identity matrix - */ - public native @Cast("bool") boolean isIdentityMatrix(); + /** + * change an array by repeating it the number of times to acquire the new + * shape which is the same as target shape target - where to store result + */ + public native void tile(@ByRef NDArray target); - /** - * check whether array is unitary matrix - */ - public native @Cast("bool") boolean isUnitary(); + /** + * check whether array is identity matrix + */ + public native @Cast("bool") boolean isIdentityMatrix(); - /** - * operator returns subarray with buffer pointing at this->_buffer with offset defined by given intervals - * idx - intervals of indexes which define the subarrays to point on, idx has form {dim0Start,dim0End, dim1Start,dim1End, ....} and length (2 * this->rankOf()) - * when (dimStart == dimEnd) then whole range will be used for current dimension - * keepUnitiesInShape - if false then eliminate unities from resulting array shape, for example {1,a,1,b} -> {a,b} - * isStrided - if true then idx has length (3 * this->rankOf()) and contains additional stride numbers which correspond to stride between dimStart and dimEnd, - * so structure of idx is like {dim0Start,dim0End,dim0Stride, dim1Start,dim1End,dim1Stride, ....} - */ - public native @ByVal @Name("operator ()") NDArray apply(@Cast("Nd4jLong*") @StdVector LongPointer idx, @Cast("const bool") boolean keepUnitiesInShape/*=false*/, @Cast("const bool") boolean isStrided/*=false*/); - public native @ByVal @Name("operator ()") NDArray apply(@Cast("Nd4jLong*") @StdVector LongPointer idx); - public native @ByVal @Name("operator ()") NDArray apply(@Cast("Nd4jLong*") @StdVector LongBuffer idx, @Cast("const bool") boolean keepUnitiesInShape/*=false*/, @Cast("const bool") boolean isStrided/*=false*/); - public native @ByVal @Name("operator ()") NDArray apply(@Cast("Nd4jLong*") @StdVector LongBuffer idx); - public native @ByVal @Name("operator ()") NDArray apply(@Cast("Nd4jLong*") @StdVector long[] idx, @Cast("const bool") boolean keepUnitiesInShape/*=false*/, @Cast("const bool") boolean isStrided/*=false*/); - public native @ByVal @Name("operator ()") NDArray apply(@Cast("Nd4jLong*") @StdVector long[] idx); + /** + * check whether array is unitary matrix + */ + public native @Cast("bool") boolean isUnitary(); - /** - * evaluates subarray with buffer pointing at this->_buffer and offset defined by given sequential index subArrIdx and dimensions in dimsToExclude - * subArrIdx - index of current sub-array - * dimsToExclude - MUST BE SORTED, dimensions to evaluate sub-array along, i.e. when shape is [2,3,4,5] and dimsToExclude={0,2}, then there will be 8 sub-arrays with shape [3,5], and subArrIdx must be in range [0,7] - * if dimsToExclude is empty then idxRanges containing all zeros (means whole array) will be returned. - * keepUnitiesInShape - if false then eliminate unities from resulting array shape, for example {1,a,1,b} -> {a,b} - */ - public native @ByVal @Name("operator ()") NDArray apply(@Cast("const Nd4jLong") long subArrIdx, @StdVector IntPointer dimsToExclude, @Cast("bool") boolean keepUnitiesInShape/*=false*/); - public native @ByVal @Name("operator ()") NDArray apply(@Cast("const Nd4jLong") long subArrIdx, @StdVector IntPointer dimsToExclude); - public native @ByVal @Name("operator ()") NDArray apply(@Cast("const Nd4jLong") long subArrIdx, @StdVector IntBuffer dimsToExclude, @Cast("bool") boolean keepUnitiesInShape/*=false*/); - public native @ByVal @Name("operator ()") NDArray apply(@Cast("const Nd4jLong") long subArrIdx, @StdVector IntBuffer dimsToExclude); - public native @ByVal @Name("operator ()") NDArray apply(@Cast("const Nd4jLong") long subArrIdx, @StdVector int[] dimsToExclude, @Cast("bool") boolean keepUnitiesInShape/*=false*/); - public native @ByVal @Name("operator ()") NDArray apply(@Cast("const Nd4jLong") long subArrIdx, @StdVector int[] dimsToExclude); + /** + * operator returns subarray with buffer pointing at this->_buffer with + * offset defined by given intervals idx - intervals of indexes which define + * the subarrays to point on, idx has form {dim0Start,dim0End, + * dim1Start,dim1End, ....} and length (2 * this->rankOf()) when (dimStart == + * dimEnd) then whole range will be used for current dimension + * keepUnitiesInShape - if false then eliminate unities from resulting array + * shape, for example {1,a,1,b} -> {a,b} isStrided - if true then idx has + * length (3 * this->rankOf()) and contains additional stride numbers which + * correspond to stride between dimStart and dimEnd, so structure of idx is + * like {dim0Start,dim0End,dim0Stride, dim1Start,dim1End,dim1Stride, ....} + */ + public native @ByVal @Name("operator ()") NDArray apply(@Cast("Nd4jLong*") @StdVector LongPointer idx, + @Cast("const bool") boolean keepUnitiesInShape/*=false*/, + @Cast("const bool") boolean isStrided/*=false*/); + public native @ByVal @Name("operator ()") NDArray apply(@Cast("Nd4jLong*") @StdVector LongPointer idx); + public native @ByVal @Name("operator ()") NDArray apply(@Cast("Nd4jLong*") @StdVector LongBuffer idx, + @Cast("const bool") boolean keepUnitiesInShape/*=false*/, + @Cast("const bool") boolean isStrided/*=false*/); + public native @ByVal @Name("operator ()") NDArray apply(@Cast("Nd4jLong*") @StdVector LongBuffer idx); + public native @ByVal @Name("operator ()") NDArray apply(@Cast("Nd4jLong*") @StdVector long[] idx, + @Cast("const bool") boolean keepUnitiesInShape/*=false*/, + @Cast("const bool") boolean isStrided/*=false*/); + public native @ByVal @Name("operator ()") NDArray apply(@Cast("Nd4jLong*") @StdVector long[] idx); - /** - * processes whole set of sub-arrays - * evaluates shapeInfo of sub-arrays (all sub-arrays have the same shapeInfo) and their buffer offsets (each sub-array has its own unique offset from original this-buffer) - * dimsToExclude - MUST BE SORTED, dimensions to evaluate sub-array along, i.e. when shape is [2,3,4,5] and dimsToExclude={0,2}, then there will be 8 sub-arrays with shape [3,5] - * if dimsToExclude.size() = array rank it means sub-array is whole array and copy of original_shapeInfo will be returned and one zero offset - * subArrShapeInfo - output argument, contains shapeInfo common for all sub-arrays - * subArrOffsets - output argument, contains successive sub-arrays offsets from original this-buffer - * keepUnitiesInShape - if false then eliminate unities from sub-array shapeInfo, for example {1,a,1,b} -> {a,b} - */ - public native void getSubArrShapeAndOffsets(@StdVector IntPointer dimsToExclude, @Cast("Nd4jLong*&") @ByPtrRef LongPointer subArrShapeInfo, @Cast("Nd4jLong*&") @ByPtrRef LongPointer subArrOffsets, @Cast("bool") boolean keepUnitiesInShape/*=false*/); - public native void getSubArrShapeAndOffsets(@StdVector IntPointer dimsToExclude, @Cast("Nd4jLong*&") @ByPtrRef LongPointer subArrShapeInfo, @Cast("Nd4jLong*&") @ByPtrRef LongPointer subArrOffsets); - public native void getSubArrShapeAndOffsets(@StdVector IntBuffer dimsToExclude, @Cast("Nd4jLong*&") @ByPtrRef LongBuffer subArrShapeInfo, @Cast("Nd4jLong*&") @ByPtrRef LongBuffer subArrOffsets, @Cast("bool") boolean keepUnitiesInShape/*=false*/); - public native void getSubArrShapeAndOffsets(@StdVector IntBuffer dimsToExclude, @Cast("Nd4jLong*&") @ByPtrRef LongBuffer subArrShapeInfo, @Cast("Nd4jLong*&") @ByPtrRef LongBuffer subArrOffsets); - public native void getSubArrShapeAndOffsets(@StdVector int[] dimsToExclude, @Cast("Nd4jLong*&") @ByPtrRef long[] subArrShapeInfo, @Cast("Nd4jLong*&") @ByPtrRef long[] subArrOffsets, @Cast("bool") boolean keepUnitiesInShape/*=false*/); - public native void getSubArrShapeAndOffsets(@StdVector int[] dimsToExclude, @Cast("Nd4jLong*&") @ByPtrRef long[] subArrShapeInfo, @Cast("Nd4jLong*&") @ByPtrRef long[] subArrOffsets); + /** + * evaluates subarray with buffer pointing at this->_buffer and offset + * defined by given sequential index subArrIdx and dimensions in dimsToExclude + * subArrIdx - index of current sub-array + * dimsToExclude - MUST BE SORTED, dimensions to evaluate sub-array along, + * i.e. when shape is [2,3,4,5] and dimsToExclude={0,2}, then there will be 8 + * sub-arrays with shape [3,5], and subArrIdx must be in range [0,7] if + * dimsToExclude is empty then idxRanges containing all zeros (means whole + * array) will be returned. keepUnitiesInShape - if false then eliminate + * unities from resulting array shape, for example {1,a,1,b} -> {a,b} + */ + public native @ByVal @Name("operator ()") NDArray apply(@Cast("const Nd4jLong") long subArrIdx, + @StdVector IntPointer dimsToExclude, + @Cast("bool") boolean keepUnitiesInShape/*=false*/); + public native @ByVal @Name("operator ()") NDArray apply(@Cast("const Nd4jLong") long subArrIdx, + @StdVector IntPointer dimsToExclude); + public native @ByVal @Name("operator ()") NDArray apply(@Cast("const Nd4jLong") long subArrIdx, + @StdVector IntBuffer dimsToExclude, + @Cast("bool") boolean keepUnitiesInShape/*=false*/); + public native @ByVal @Name("operator ()") NDArray apply(@Cast("const Nd4jLong") long subArrIdx, + @StdVector IntBuffer dimsToExclude); + public native @ByVal @Name("operator ()") NDArray apply(@Cast("const Nd4jLong") long subArrIdx, + @StdVector int[] dimsToExclude, + @Cast("bool") boolean keepUnitiesInShape/*=false*/); + public native @ByVal @Name("operator ()") NDArray apply(@Cast("const Nd4jLong") long subArrIdx, + @StdVector int[] dimsToExclude); - /** - * addition unary operator array += other - * other - input array to add - */ - public native @Name("operator +=") void addPut(@Const @ByRef NDArray other); + /** + * processes whole set of sub-arrays + * evaluates shapeInfo of sub-arrays (all sub-arrays have the same shapeInfo) + * and their buffer offsets (each sub-array has its own unique offset from + * original this-buffer) dimsToExclude - MUST BE SORTED, dimensions to + * evaluate sub-array along, i.e. when shape is [2,3,4,5] and + * dimsToExclude={0,2}, then there will be 8 sub-arrays with shape [3,5] if + * dimsToExclude.size() = array rank it means sub-array is whole array and + * copy of original_shapeInfo will be returned and one zero offset + * subArrShapeInfo - output argument, contains shapeInfo common for all + * sub-arrays subArrOffsets - output argument, contains successive + * sub-arrays offsets from original this-buffer keepUnitiesInShape - if false + * then eliminate unities from sub-array shapeInfo, for example {1,a,1,b} -> + * {a,b} + */ + public native void getSubArrShapeAndOffsets(@StdVector IntPointer dimsToExclude, + @Cast("Nd4jLong*&") @ByPtrRef LongPointer subArrShapeInfo, + @Cast("Nd4jLong*&") @ByPtrRef LongPointer subArrOffsets, + @Cast("bool") boolean keepUnitiesInShape/*=false*/); + public native void getSubArrShapeAndOffsets(@StdVector IntPointer dimsToExclude, + @Cast("Nd4jLong*&") @ByPtrRef LongPointer subArrShapeInfo, + @Cast("Nd4jLong*&") @ByPtrRef LongPointer subArrOffsets); + public native void getSubArrShapeAndOffsets(@StdVector IntBuffer dimsToExclude, + @Cast("Nd4jLong*&") @ByPtrRef LongBuffer subArrShapeInfo, + @Cast("Nd4jLong*&") @ByPtrRef LongBuffer subArrOffsets, + @Cast("bool") boolean keepUnitiesInShape/*=false*/); + public native void getSubArrShapeAndOffsets(@StdVector IntBuffer dimsToExclude, + @Cast("Nd4jLong*&") @ByPtrRef LongBuffer subArrShapeInfo, + @Cast("Nd4jLong*&") @ByPtrRef LongBuffer subArrOffsets); + public native void getSubArrShapeAndOffsets(@StdVector int[] dimsToExclude, + @Cast("Nd4jLong*&") @ByPtrRef long[] subArrShapeInfo, + @Cast("Nd4jLong*&") @ByPtrRef long[] subArrOffsets, + @Cast("bool") boolean keepUnitiesInShape/*=false*/); + public native void getSubArrShapeAndOffsets(@StdVector int[] dimsToExclude, + @Cast("Nd4jLong*&") @ByPtrRef long[] subArrShapeInfo, + @Cast("Nd4jLong*&") @ByPtrRef long[] subArrOffsets); - /** - * subtraction unary operator array -= other - * other - input array to add - */ - public native @Name("operator -=") void subtractPut(@Const @ByRef NDArray other); + /** + * addition unary operator array += other + * other - input array to add + */ + public native @Name("operator +=") void addPut(@Const @ByRef NDArray other); - /** - * negative operator, it changes sign of all array elements on opposite - */ - public native @ByVal @Name("operator -") NDArray subtract(); - + /** + * subtraction unary operator array -= other + * other - input array to add + */ + public native @Name("operator -=") void subtractPut(@Const @ByRef NDArray other); - /** - * pairwise multiplication unary operator array *= other - * other - input array to multiply on - */ - public native @Name("operator *=") void multiplyPut(@Const @ByRef NDArray other); + /** + * negative operator, it changes sign of all array elements on opposite + */ + public native @ByVal @Name("operator -") NDArray subtract(); + - /** - * multiplication unary operator array *= scalar - * scalar - input scalar to multiply on - */ + /** + * pairwise multiplication unary operator array *= other + * other - input array to multiply on + */ + public native @Name("operator *=") void multiplyPut(@Const @ByRef NDArray other); - /** - * pairwise division unary operator: array /= other - * other - input array to divide on - */ - public native @Name("operator /=") void dividePut(@Const @ByRef NDArray other); + /** + * multiplication unary operator array *= scalar + * scalar - input scalar to multiply on + */ - /** - * division unary operator: array /= scalar - * scalar - input scalar to divide on - */ + /** + * pairwise division unary operator: array /= other + * other - input array to divide on + */ + public native @Name("operator /=") void dividePut(@Const @ByRef NDArray other); - /** - * friend function which implements mathematical multiplication of two arrays - * left - input array - * right - input array - */ - + /** + * division unary operator: array /= scalar + * scalar - input scalar to divide on + */ - /** - * return vector containing _buffer as flat binary array - */ - public native @StdVector BytePointer asByteVector(); + /** + * friend function which implements mathematical multiplication of two arrays + * left - input array + * right - input array + */ + - /** - * makes array to be identity matrix (not necessarily square), that is set all diagonal elements = 1, rest = 0 - */ - public native void setIdentity(); + /** + * return vector containing _buffer as flat binary array + */ + public native @StdVector BytePointer asByteVector(); - /** - * swaps the contents of tow arrays, - * PLEASE NOTE: method doesn't take into account the shapes of arrays, shapes may be different except one condition: arrays lengths must be the same - */ - public native void swapUnsafe(@ByRef NDArray other); + /** + * makes array to be identity matrix (not necessarily square), that is set + * all diagonal elements = 1, rest = 0 + */ + public native void setIdentity(); - /** - * return vector with buffer which points on corresponding diagonal elements of array - * type - means of vector to be returned: column ('c') or row ('r') - */ - public native @ByVal NDArray diagonal(byte type ); + /** + * swaps the contents of tow arrays, + * PLEASE NOTE: method doesn't take into account the shapes of arrays, shapes + * may be different except one condition: arrays lengths must be the same + */ + public native void swapUnsafe(@ByRef NDArray other); - /** - * fill target matrix with given value in one or two directions from main diagonal: - * - down from main diagonal starting at subdiagonal number "lower" if direction = 'l' (down) or 'b' (both) - * - up from main diagonal starting at superdiagonal number "upper"if direction = 'u' (up) or 'b' (both) - * direction - in what direction to fill matrix. There are 3 possible directions: - * 'u' - fill up, mathematically this corresponds to lower triangular matrix, subdiagonal "lower" unaffected - * 'l' - fill down, mathematically this corresponds to upper triangular matrix, superdiagonal "upper" remains unaffected - * 'b' - fill in both directions, both "lower" and "upper" are taken into account - * rest of target elements are equal to this array elements - * target and this array should have same shapes, except when this_rank = 1 (in that case should be target_rank = 2) - */ + /** + * return vector with buffer which points on corresponding diagonal elements + * of array type - means of vector to be returned: column ('c') or row ('r') + */ + public native @ByVal NDArray diagonal(byte type); - /** - * change an array by repeating it the number of times in order to acquire new shape equal to the input shape - * - * shape - contains new shape to broadcast array to - * target - optional argument, if target != nullptr the resulting array will be placed in target, in opposite case tile operation is done in place - */ - public native @ByVal NDArray tileToShape(@Cast("const Nd4jLong*") LongPointer shapeInfo); - public native @ByVal NDArray tileToShape(@Cast("const Nd4jLong*") LongBuffer shapeInfo); - public native @ByVal NDArray tileToShape(@Cast("const Nd4jLong*") long[] shapeInfo); - public native void tileToShape(@Cast("Nd4jLong*") @StdVector LongPointer shape, @ByRef NDArray target); - public native void tileToShape(@Cast("Nd4jLong*") @StdVector LongBuffer shape, @ByRef NDArray target); - public native void tileToShape(@Cast("Nd4jLong*") @StdVector long[] shape, @ByRef NDArray target); + /** + * fill target matrix with given value in one or two directions from main + * diagonal: + * - down from main diagonal starting at subdiagonal number "lower" if + * direction = 'l' (down) or 'b' (both) + * - up from main diagonal starting at superdiagonal number "upper"if + * direction = 'u' (up) or 'b' (both) direction - in what direction to fill + * matrix. There are 3 possible directions: 'u' - fill up, mathematically this + * corresponds to lower triangular matrix, subdiagonal "lower" unaffected 'l' + * - fill down, mathematically this corresponds to upper triangular matrix, + * superdiagonal "upper" remains unaffected 'b' - fill in both directions, + * both "lower" and "upper" are taken into account rest of target elements are + * equal to this array elements target and this array should have same shapes, + * except when this_rank = 1 (in that case should be target_rank = 2) + */ + + /** + * change an array by repeating it the number of times in order to acquire + * new shape equal to the input shape + * + * shape - contains new shape to broadcast array to + * target - optional argument, if target != nullptr the resulting array will + * be placed in target, in opposite case tile operation is done in place + */ + public native @ByVal NDArray tileToShape(@Cast("const Nd4jLong*") LongPointer shapeInfo); + public native @ByVal NDArray tileToShape(@Cast("const Nd4jLong*") LongBuffer shapeInfo); + public native @ByVal NDArray tileToShape(@Cast("const Nd4jLong*") long[] shapeInfo); + public native void tileToShape(@Cast("Nd4jLong*") @StdVector LongPointer shape, @ByRef NDArray target); + public native void tileToShape(@Cast("Nd4jLong*") @StdVector LongBuffer shape, @ByRef NDArray target); + public native void tileToShape(@Cast("Nd4jLong*") @StdVector long[] shape, @ByRef NDArray target); // #ifndef __JAVACPP_HACK__ // #endif - public native @ByVal NDArray asT(@Cast("sd::DataType") int dtype); + public native @ByVal NDArray asT(@Cast("sd::DataType") int dtype); + public native void linspace(double start); - public native void linspace(double start); + public native void linspace(double start, double step); - public native void linspace(double start, double step); + /** + * calculates the trace of an array, that is sum of elements on main diagonal + * = sum array[i, i, i, ...] + */ + public native double getTrace(); - /** - * calculates the trace of an array, that is sum of elements on main diagonal = sum array[i, i, i, ...] - */ - public native double getTrace(); + public native @ByVal ResultSet multipleTensorsAlongDimension( + @StdVector IntPointer indices, + @StdVector IntPointer dimensions); + public native @ByVal ResultSet multipleTensorsAlongDimension( + @StdVector IntBuffer indices, + @StdVector IntBuffer dimensions); + public native @ByVal ResultSet multipleTensorsAlongDimension( + @StdVector int[] indices, + @StdVector int[] dimensions); - public native @ByVal ResultSet multipleTensorsAlongDimension(@StdVector IntPointer indices, @StdVector IntPointer dimensions); - public native @ByVal ResultSet multipleTensorsAlongDimension(@StdVector IntBuffer indices, @StdVector IntBuffer dimensions); - public native @ByVal ResultSet multipleTensorsAlongDimension(@StdVector int[] indices, @StdVector int[] dimensions); + public native @ByVal ResultSet allTensorsAlongDimension(@StdVector IntPointer dimensions); + public native @ByVal ResultSet allTensorsAlongDimension(@StdVector IntBuffer dimensions); + public native @ByVal ResultSet allTensorsAlongDimension(@StdVector int[] dimensions); - public native @ByVal ResultSet allTensorsAlongDimension(@StdVector IntPointer dimensions); - public native @ByVal ResultSet allTensorsAlongDimension(@StdVector IntBuffer dimensions); - public native @ByVal ResultSet allTensorsAlongDimension(@StdVector int[] dimensions); + public native @ByVal ResultSet allExamples(); - public native @ByVal ResultSet allExamples(); + /** + * set _shapeInfo + */ + public native void setShapeInfo(@Cast("const Nd4jLong*") LongPointer shapeInfo); + public native void setShapeInfo(@Cast("const Nd4jLong*") LongBuffer shapeInfo); + public native void setShapeInfo(@Cast("const Nd4jLong*") long[] shapeInfo); + public native void setShapeInfo(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("const sd::DataType") int dtype); + public native void setShapeInfo(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("const sd::DataType") int dtype); + public native void setShapeInfo(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("const sd::DataType") int dtype); + public native void setShapeInfo(@Const @ByRef ShapeDescriptor descriptor); + public native void setShapeInfo(@Const @ByRef ConstantShapeBuffer shapeBuffer); - /** - * set _shapeInfo - */ - public native void setShapeInfo(@Cast("const Nd4jLong*") LongPointer shapeInfo); - public native void setShapeInfo(@Cast("const Nd4jLong*") LongBuffer shapeInfo); - public native void setShapeInfo(@Cast("const Nd4jLong*") long[] shapeInfo); - public native void setShapeInfo(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("const sd::DataType") int dtype); - public native void setShapeInfo(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("const sd::DataType") int dtype); - public native void setShapeInfo(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("const sd::DataType") int dtype); - public native void setShapeInfo(@Const @ByRef ShapeDescriptor descriptor); - public native void setShapeInfo(@Const @ByRef ConstantShapeBuffer shapeBuffer); + /** + * returns absolute offset which corresponds to given sequential index + */ + public native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong") long i); - /** - * returns absolute offset which corresponds to given sequential index - */ - public native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong") long i); + /** + * returns reference on array element with given index + */ - /** - * returns reference on array element with given index - */ + /** + * returns array element with given index + * i - element index in array + */ + /** + * default destructor + */ - /** - * returns array element with given index - * i - element index in array - */ + /** + * set _shapeInfo + */ + /** + * returns the value of "dim" dimension + */ + public native @Cast("Nd4jLong") long sizeAt(int dim); - /** - * default destructor - */ + /** + * returns stride of "dim" dimension + */ + public native @Cast("Nd4jLong") long strideAt(int dim); - /** - * set _shapeInfo - */ + /** + * returns order of array + */ + public native char ordering(); - /** - * returns the value of "dim" dimension - */ - public native @Cast("Nd4jLong") long sizeAt(int dim); + /** + * return _isView + */ + public native @Cast("bool") boolean isView(); - /** - * returns stride of "dim" dimension - */ - public native @Cast("Nd4jLong") long strideAt(int dim); + /** + * returns shape portion of shapeInfo + */ + public native @Cast("Nd4jLong*") LongPointer shapeOf(); - /** - * returns order of array - */ - public native char ordering(); + /** + * returns strides portion of shapeInfo + */ + public native @Cast("Nd4jLong*") LongPointer stridesOf(); - /** - * return _isView - */ - public native @Cast("bool") boolean isView(); + /** + * returns rank of array + */ + public native int rankOf(); - /** - * returns shape portion of shapeInfo - */ - public native @Cast("Nd4jLong*") LongPointer shapeOf(); + /** + * returns length of array + */ + public native @Cast("Nd4jLong") long lengthOf(); - /** - * returns strides portion of shapeInfo - */ - public native @Cast("Nd4jLong*") LongPointer stridesOf(); + /** + * returns number of rows in array + */ + public native @Cast("Nd4jLong") long rows(); - /** - * returns rank of array - */ - public native int rankOf(); + /** + * returns number of columns in array + */ + public native @Cast("Nd4jLong") long columns(); - /** - * returns length of array - */ - public native @Cast("Nd4jLong") long lengthOf(); + /** + * returns size of array elements type + */ + public native @Cast("size_t") long sizeOfT(); - /** - * returns number of rows in array - */ - public native @Cast("Nd4jLong") long rows(); + /** + * returns element-wise-stride + */ + public native @Cast("Nd4jLong") long ews(); - /** - * returns number of columns in array - */ - public native @Cast("Nd4jLong") long columns(); + // returns true if arrays have same shape + public native @Cast("bool") boolean isSameShape(@Const NDArray other); + public native @Cast("bool") boolean isSameShape(@Cast("Nd4jLong*") @StdVector LongPointer shape); + public native @Cast("bool") boolean isSameShape(@Cast("Nd4jLong*") @StdVector LongBuffer shape); + public native @Cast("bool") boolean isSameShape(@Cast("Nd4jLong*") @StdVector long[] shape); + public native @Cast("bool") boolean areSameShapeAndType(@Const @ByRef NDArray other); - /** - * returns size of array elements type - */ - public native @Cast("size_t") long sizeOfT(); + /** + * returns true if these two NDArrays have same rank, dimensions, strides, + * ews and order + */ + public native @Cast("bool") boolean isSameShapeStrict(@Const @ByRef NDArray other); - /** - * returns element-wise-stride - */ - public native @Cast("Nd4jLong") long ews(); + /** + * returns true if buffer && shapeInfo were defined (non nullptr) + */ + public native @Cast("bool") boolean nonNull(); - // returns true if arrays have same shape - public native @Cast("bool") boolean isSameShape(@Const NDArray other); - public native @Cast("bool") boolean isSameShape(@Cast("Nd4jLong*") @StdVector LongPointer shape); - public native @Cast("bool") boolean isSameShape(@Cast("Nd4jLong*") @StdVector LongBuffer shape); - public native @Cast("bool") boolean isSameShape(@Cast("Nd4jLong*") @StdVector long[] shape); - public native @Cast("bool") boolean areSameShapeAndType(@Const @ByRef NDArray other); + /** + * returns array element with given index from linear buffer + * i - element index in array + */ - /** - * returns true if these two NDArrays have same rank, dimensions, strides, ews and order - */ - public native @Cast("bool") boolean isSameShapeStrict(@Const @ByRef NDArray other); + /** + * returns element with given indexes from 2D array + * i - number of row + * j - number of column + */ - /** - * returns true if buffer && shapeInfo were defined (non nullptr) - */ - public native @Cast("bool") boolean nonNull(); + /** + * returns element with given indexes from 3D array + * i - height + * j - width + * k - depth + */ - /** - * returns array element with given index from linear buffer - * i - element index in array - */ + /** + * returns element with given indexes from DD array + */ - /** - * returns element with given indexes from 2D array - * i - number of row - * j - number of column - */ + /** + * returns array-scalar containing element of this array with given index + * i - element index in array + */ + public native @ByVal NDArray e(@Cast("const Nd4jLong") long i); - /** - * returns element with given indexes from 3D array - * i - height - * j - width - * k - depth - */ + /** + * assigns given scalar to array element by given index, regards array buffer + * as linear i - element index in array value - scalar value to assign + */ - /** - * returns element with given indexes from DD array - */ + public native void p(@Cast("const Nd4jLong") long i, @Const @ByRef NDArray value); - /** - * returns array-scalar containing element of this array with given index - * i - element index in array - */ - public native @ByVal NDArray e(@Cast("const Nd4jLong") long i); + /** + * assigns given scalar to 2D array element by given indexes + * i - number of row + * j - number of row + * value - scalar value to assign + */ - /** - * assigns given scalar to array element by given index, regards array buffer as linear - * i - element index in array - * value - scalar value to assign - */ + /** + * assigns given scalar to 3D array element by given indexes + * i - height + * j - width + * k - depth + * value - scalar value to assign + */ + public native void p(@Cast("const Nd4jLong") long i, @Cast("const Nd4jLong") long j, @Cast("const Nd4jLong") long k, @Cast("const Nd4jLong") long l, + @Const @ByRef NDArray value); - public native void p(@Cast("const Nd4jLong") long i, @Const @ByRef NDArray value); + /** + * returns true if array is 2D + */ + public native @Cast("bool") boolean isMatrix(); - /** - * assigns given scalar to 2D array element by given indexes - * i - number of row - * j - number of row - * value - scalar value to assign - */ + /** + * returns true if array is vector + */ + public native @Cast("bool") boolean isVector(); - /** - * assigns given scalar to 3D array element by given indexes - * i - height - * j - width - * k - depth - * value - scalar value to assign - */ - public native void p(@Cast("const Nd4jLong") long i, @Cast("const Nd4jLong") long j, @Cast("const Nd4jLong") long k, @Cast("const Nd4jLong") long l, @Const @ByRef NDArray value); + /** + * returns true if array is column vector + */ + public native @Cast("bool") boolean isColumnVector(); - /** - * returns true if array is 2D - */ - public native @Cast("bool") boolean isMatrix(); + /** + * returns true if array is row vector + */ + public native @Cast("bool") boolean isRowVector(); - /** - * returns true if array is vector - */ - public native @Cast("bool") boolean isVector(); + /** + * returns true if all dimensions of array except one are unities, for + * example: [1,1,n,1], [n,1,1], [n], ... posOfNonUnityDim - one dimension with + * value > 1 + */ + public native @Cast("bool") boolean isCommonVector(@ByRef IntPointer posOfNonUnityDim); + public native @Cast("bool") boolean isCommonVector(@ByRef IntBuffer posOfNonUnityDim); + public native @Cast("bool") boolean isCommonVector(@ByRef int[] posOfNonUnityDim); - /** - * returns true if array is column vector - */ - public native @Cast("bool") boolean isColumnVector(); + /** + * returns true if array is scalar + */ + public native @Cast("bool") boolean isScalar(); - /** - * returns true if array is row vector - */ - public native @Cast("bool") boolean isRowVector(); + /** + * Returns data type of this array + * @return + */ + public native @Cast("sd::DataType") int dataType(); - /** - * returns true if all dimensions of array except one are unities, for example: [1,1,n,1], [n,1,1], [n], ... - * posOfNonUnityDim - one dimension with value > 1 - */ - public native @Cast("bool") boolean isCommonVector(@ByRef IntPointer posOfNonUnityDim); - public native @Cast("bool") boolean isCommonVector(@ByRef IntBuffer posOfNonUnityDim); - public native @Cast("bool") boolean isCommonVector(@ByRef int[] posOfNonUnityDim); - - - /** - * returns true if array is scalar - */ - public native @Cast("bool") boolean isScalar(); - - /** - * Returns data type of this array - * @return - */ - public native @Cast("sd::DataType") int dataType(); - - /** - * This method returns true if value is from Integer space - * @return - */ - public native @Cast("bool") boolean isZ(); - - /** - * This method returns true if array is from Real space - * @return - */ - public native @Cast("bool") boolean isR(); + /** + * This method returns true if value is from Integer space + * @return + */ + public native @Cast("bool") boolean isZ(); - /** - * This method returns true if array is from Boolean space - * @return - */ - public native @Cast("bool") boolean isB(); + /** + * This method returns true if array is from Real space + * @return + */ + public native @Cast("bool") boolean isR(); - /** - * This method returns true if array contains Complex numbers - * @return - */ - public native @Cast("bool") boolean isC(); + /** + * This method returns true if array is from Boolean space + * @return + */ + public native @Cast("bool") boolean isB(); - /** - * This method returns true if array contains String - * @return - */ - public native @Cast("bool") boolean isS(); + /** + * This method returns true if array contains Complex numbers + * @return + */ + public native @Cast("bool") boolean isC(); - public native @Cast("bool") boolean isAttached(); + /** + * This method returns true if array contains String + * @return + */ + public native @Cast("bool") boolean isS(); - public native NDArray detach(); + public native @Cast("bool") boolean isAttached(); - public native @Cast("bool") @Name("operator ==") boolean equals(@Const @ByRef NDArray other); + public native @ByVal NDArray detach(); - public native @Cast("bool") @Name("operator !=") boolean notEquals(@Const @ByRef NDArray other); - } + /** + * This method returns true if array is valid array with some shape etc + * @return + */ + public native @Cast("bool") boolean defined(); + public native @Cast("bool") boolean undefined(); + public native @Cast("bool") @Name("operator ==") boolean equals(@Const @ByRef NDArray other); + public native @Cast("bool") @Name("operator !=") boolean notEquals(@Const @ByRef NDArray other); +} // class NDArray +@Namespace("sd") public static native @Cast("std::ostream*") @ByRef @Name("operator <<") Pointer shiftLeft(@Cast("std::ostream*") @ByRef Pointer os, @Const @ByRef NDArray m); ////////////////////////////////////////////////////////////////////////// ///// IMLEMENTATION OF INLINE METHODS ///// @@ -4965,7 +5471,6 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint ////////////////////////////////////////////////////////////////////////// - ////////////////////////////////////////////////////////////////////////// @@ -5043,13 +5548,12 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint //////////////////////////////////////////////////////////////////////// - -// #if defined(__CUDACC__) //&& defined(BUILD_TESTS) +// #if defined(__CUDACC__) //&& defined(BUILD_TESTS) // for CUDA we need stil stuff inline // #include // #endif - + // namespace sd // #endif @@ -5081,51 +5585,68 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #ifndef NDARRAY_LIST_H // #define NDARRAY_LIST_H -// #include -// #include -// #include // #include // #include // #include - @Namespace("sd") @NoOffset public static class NDArrayList extends Pointer { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public NDArrayList(Pointer p) { super(p); } - - public NDArrayList(int height, @Cast("bool") boolean expandable/*=false*/) { super((Pointer)null); allocate(height, expandable); } - private native void allocate(int height, @Cast("bool") boolean expandable/*=false*/); - public NDArrayList(int height) { super((Pointer)null); allocate(height); } - private native void allocate(int height); - public native @Cast("sd::DataType") int dataType(); +// #include +// #include +// #include +@Namespace("sd") @NoOffset public static class NDArrayList extends Pointer { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public NDArrayList(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public NDArrayList(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public NDArrayList position(long position) { + return (NDArrayList)super.position(position); + } + + public NDArrayList(int height/*=0*/, @Cast("bool") boolean expandable/*=false*/) { super((Pointer)null); allocate(height, expandable); } + private native void allocate(int height/*=0*/, @Cast("bool") boolean expandable/*=false*/); + public NDArrayList() { super((Pointer)null); allocate(); } + private native void allocate(); + + public NDArrayList(@Const @ByRef NDArrayList other) { super((Pointer)null); allocate(other); } + private native void allocate(@Const @ByRef NDArrayList other); - public native NDArray read(int idx); - public native NDArray readRaw(int idx); - public native @Cast("Nd4jStatus") int write(int idx, NDArray array); - public native NDArray pick(@StdVector IntPointer indices); - public native NDArray pick(@StdVector IntBuffer indices); - public native NDArray pick(@StdVector int[] indices); - public native @Cast("bool") boolean isWritten(int index); + public native @ByRef @Name("operator =") @NoException NDArrayList put(@Const @ByRef NDArrayList other); - public native @Cast("Nd4jLong*") @StdVector LongPointer shape(); + // move assignment operator - public native NDArray stack(); - public native void unstack(NDArray array, int axis); + public native @Cast("sd::DataType") int dataType(); - public native @ByRef IntIntPair id(); - public native @StdString @ByRef @Cast({"char*", "std::string*"}) BytePointer name(); - //sd::memory::Workspace* workspace(); - public native LaunchContext context(); - public native NDArrayList clone(); + public native @ByVal NDArray read(int idx); + public native @ByVal NDArray readRaw(int idx); + public native @Cast("Nd4jStatus") int write(int idx, @Const @ByRef NDArray array); - public native @Cast("bool") boolean equals(@ByRef NDArrayList other); + public native @ByVal NDArray pick(@StdVector IntPointer indices); + public native @ByVal NDArray pick(@StdVector IntBuffer indices); + public native @ByVal NDArray pick(@StdVector int[] indices); + public native @Cast("bool") boolean isWritten(int index); - public native int elements(); - public native int height(); + public native @Cast("Nd4jLong*") @StdVector LongPointer shape(); + public native void setShape(@Cast("Nd4jLong*") @StdVector LongPointer shape); + public native void setShape(@Cast("Nd4jLong*") @StdVector LongBuffer shape); + public native void setShape(@Cast("Nd4jLong*") @StdVector long[] shape); - public native int counter(); - } + public native @ByVal NDArray stack(); + public native void unstack(@Const @ByRef NDArray array, int axis); + + public native @Const @ByRef IntIntPair id(); + public native @StdString BytePointer name(); + + public native @ByVal NDArrayList clone(); + + public native @Cast("bool") boolean equals(@ByRef NDArrayList other); + + public native int elements(); + public native int height(); + public native int counter(); +} + // namespace sd // #endif @@ -5158,56 +5679,58 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #ifndef LIBND4J_RESULTSET_H // #define LIBND4J_RESULTSET_H -// #include // #include +// #include // #include -// #include // forward declaration of template class NDArray - - @Namespace("sd") @NoOffset public static class ResultSet extends Pointer { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public ResultSet(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public ResultSet(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public ResultSet position(long position) { - return (ResultSet)super.position(position); - } - - public ResultSet() { super((Pointer)null); allocate(); } - private native void allocate(); + +// #include // forward declaration of template class NDArray + +@Namespace("sd") @NoOffset public static class ResultSet extends Pointer { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public ResultSet(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public ResultSet(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public ResultSet position(long position) { + return (ResultSet)super.position(position); + } + + public ResultSet() { super((Pointer)null); allocate(); } + private native void allocate(); // #ifndef __JAVACPP_HACK__ // #endif - public ResultSet(@Const @ByRef ResultSet other) { super((Pointer)null); allocate(other); } - @NoException private native void allocate(@Const @ByRef ResultSet other); + public ResultSet(@Const @ByRef ResultSet other) { super((Pointer)null); allocate(other); } + @NoException private native void allocate(@Const @ByRef ResultSet other); - public native @ByRef @Name("operator =") @NoException ResultSet put(@Const @ByRef ResultSet other); + public native @ByRef @Name("operator =") @NoException ResultSet put(@Const @ByRef ResultSet other); - // move constructor + // move constructor - // move assignment operator + // move assignment operator - public native int size(); - public native NDArray at(@Cast("const unsigned long") long idx); - public native @Name("operator []") NDArray get(@Cast("const unsigned long") long idx); - public native void push_back(NDArray array); - - public native @Cast("Nd4jStatus") int status(); - public native void setStatus(@Cast("Nd4jStatus") int status); - public native void purge(); - public native void setNonRemovable(); - } + public native int size(); + public native @ByRef NDArray at(@Cast("const unsigned long") long idx); + public native @ByRef @Name("operator []") NDArray get(@Cast("const unsigned long") long idx); + public native void push_back(@Const @ByRef NDArray array); + public native @Cast("Nd4jStatus") int status(); + public native void setStatus(@Cast("Nd4jStatus") int status); + public native void purge(); + public native void setNonRemovable(); +} + // namespace sd -// #endif //LIBND4J_RESULTSET_H +// #endif // LIBND4J_RESULTSET_H // Parsed from graph/RandomGenerator.h /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -5237,6 +5760,12 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #include // #include // #include +// #include +// #include +// #include +// #include + +// #include // #include // #include @@ -5244,110 +5773,113 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #endif // #ifdef __CUDACC__ // #else - @Namespace("sd::graph") @NoOffset public static class RandomGenerator extends Pointer { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public RandomGenerator(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public RandomGenerator(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public RandomGenerator position(long position) { - return (RandomGenerator)super.position(position); - } - +@Namespace("sd::graph") @NoOffset public static class RandomGenerator extends Pointer { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public RandomGenerator(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public RandomGenerator(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public RandomGenerator position(long position) { + return (RandomGenerator)super.position(position); + } + public native @Cast("uint32_t") int xoroshiro32(@Cast("uint64_t") long index); - public native @Cast("uint64_t") long xoroshiro64(@Cast("uint64_t") long index); - public RandomGenerator(@Cast("Nd4jLong") long rootSeed/*=0*/, @Cast("Nd4jLong") long nodeSeed/*=0*/) { super((Pointer)null); allocate(rootSeed, nodeSeed); } - private native void allocate(@Cast("Nd4jLong") long rootSeed/*=0*/, @Cast("Nd4jLong") long nodeSeed/*=0*/); - public RandomGenerator() { super((Pointer)null); allocate(); } - private native void allocate(); + public native @Cast("uint64_t") long xoroshiro64(@Cast("uint64_t") long index); + public RandomGenerator(@Cast("Nd4jLong") long rootSeed/*=0*/, @Cast("Nd4jLong") long nodeSeed/*=0*/) { super((Pointer)null); allocate(rootSeed, nodeSeed); } + private native void allocate(@Cast("Nd4jLong") long rootSeed/*=0*/, @Cast("Nd4jLong") long nodeSeed/*=0*/); + public RandomGenerator() { super((Pointer)null); allocate(); } + private native void allocate(); + + public RandomGenerator(@Const @ByRef RandomGenerator other) { super((Pointer)null); allocate(other); } + @NoException private native void allocate(@Const @ByRef RandomGenerator other); + + public native @ByRef @Name("operator =") @NoException RandomGenerator put(@Const @ByRef RandomGenerator other); + + // move constructor + + // move assignment operator - /** - * This method allows to change graph-level state in runtime. - * PLEASE NOTE: this method will change state of node as well. - */ - public native void setStates(@Cast("Nd4jLong") long rootSeed, @Cast("Nd4jLong") long nodeState/*=0*/); - public native void setStates(@Cast("Nd4jLong") long rootSeed); + /** + * This method allows to change graph-level state in runtime. + * PLEASE NOTE: this method will change state of node as well. + */ + public native void setStates(@Cast("Nd4jLong") long rootSeed, @Cast("Nd4jLong") long nodeState/*=0*/); + public native void setStates(@Cast("Nd4jLong") long rootSeed); - + /** + * This method returns T value between from and to + */ - /** - * This method returns T value between from and to - */ + /** + * This method returns T value between 0 and MAX_T + */ - /** - * This method returns T value between 0 and MAX_T - */ + /** + * These two methods are made for JVM + * @param index + * @return + */ + public native int relativeInt(@Cast("Nd4jLong") long index); + public native @Cast("Nd4jLong") long relativeLong(@Cast("Nd4jLong") long index); - /** - * These two methods are made for JVM - * @param index - * @return - */ - public native int relativeInt(@Cast("Nd4jLong") long index); - public native @Cast("Nd4jLong") long relativeLong(@Cast("Nd4jLong") long index); + public native void rewindH(@Cast("uint64_t") long steps); - public native void rewindH(@Cast("uint64_t") long steps); + /** + * These methods set up only node states, with non-changed root ones + */ + public native void setSeed(int seed); - /** - * These methods set up only node states, with non-changed root ones - */ - public native void setSeed(int seed); + public native void setSeed(@Cast("uint64_t") long seed); - public native void setSeed(@Cast("uint64_t") long seed); + public native @Cast("Nd4jLong") long rootState(); - public native @Cast("Nd4jLong") long rootState(); + public native @Cast("Nd4jLong") long nodeState(); +} - public native @Cast("Nd4jLong") long nodeState(); - } - - - - - - - - - - - - - - - ////// - @Namespace("sd::graph") public static native @Cast("uint32_t") int rotl(@Cast("const uint32_t") int x, int k); - @Namespace("sd::graph") public static native @Cast("uint64_t") long rotl(@Cast("const uint64_t") long x, int k); - @Namespace("sd::graph") public static native @Cast("uint32_t") int next(@Cast("uint32_t") int s0, @Cast("uint32_t") int s1, @Cast("uint32_t") int s2, @Cast("uint32_t") int s3); - - + + + + + +////// +@Namespace("sd::graph") public static native @Cast("uint32_t") int rotl(@Cast("const uint32_t") int x, int k); + +@Namespace("sd::graph") public static native @Cast("uint64_t") long rotl(@Cast("const uint64_t") long x, int k); + +@Namespace("sd::graph") public static native @Cast("uint32_t") int next(@Cast("uint32_t") int s0, @Cast("uint32_t") int s1, @Cast("uint32_t") int s2, @Cast("uint32_t") int s3); - + + + // namespace graph + // namespace sd + // #endif @@ -5376,98 +5908,113 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #ifndef LIBND4J_VARIABLE_H // #define LIBND4J_VARIABLE_H -// #include // #include // #include // #include // #include -// #include // #include +// #include + +// #include // #ifndef __JAVACPP_HACK__ // #endif - @Namespace("sd::graph") @NoOffset public static class Variable extends Pointer { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public Variable(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public Variable(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public Variable position(long position) { - return (Variable)super.position(position); - } - - public Variable(@Cast("bool") boolean placeHolder) { super((Pointer)null); allocate(placeHolder); } - private native void allocate(@Cast("bool") boolean placeHolder); - public Variable(NDArray arrayw, @Cast("char*") String name, int id, int idx/*=0*/) { super((Pointer)null); allocate(arrayw, name, id, idx); } - private native void allocate(NDArray arrayw, @Cast("char*") String name, int id, int idx/*=0*/); - public Variable(NDArray arrayw, @Cast("char*") String name, int id) { super((Pointer)null); allocate(arrayw, name, id); } - private native void allocate(NDArray arrayw, @Cast("char*") String name, int id); - public Variable(NDArray arrayw, @Cast("char*") BytePointer name, int id, int idx/*=0*/) { super((Pointer)null); allocate(arrayw, name, id, idx); } - private native void allocate(NDArray arrayw, @Cast("char*") BytePointer name, int id, int idx/*=0*/); - public Variable(NDArray arrayw, @Cast("char*") BytePointer name, int id) { super((Pointer)null); allocate(arrayw, name, id); } - private native void allocate(NDArray arrayw, @Cast("char*") BytePointer name, int id); - public Variable(NDArray array/*=nullptr*/, @Cast("char*") String name/*=nullptr*/) { super((Pointer)null); allocate(array, name); } - private native void allocate(NDArray array/*=nullptr*/, @Cast("char*") String name/*=nullptr*/); - public Variable() { super((Pointer)null); allocate(); } - private native void allocate(); - public Variable(NDArray array/*=nullptr*/, @Cast("char*") BytePointer name/*=nullptr*/) { super((Pointer)null); allocate(array, name); } - private native void allocate(NDArray array/*=nullptr*/, @Cast("char*") BytePointer name/*=nullptr*/); +@Namespace("sd::graph") @NoOffset public static class Variable extends Pointer { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public Variable(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public Variable(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public Variable position(long position) { + return (Variable)super.position(position); + } + + public Variable(@Cast("bool") boolean placeHolder, @Cast("sd::DataType") int dataType/*=sd::DataType::ANY*/, @Cast("Nd4jLong*") @StdVector LongPointer shape/*={}*/) { super((Pointer)null); allocate(placeHolder, dataType, shape); } + private native void allocate(@Cast("bool") boolean placeHolder, @Cast("sd::DataType") int dataType/*=sd::DataType::ANY*/, @Cast("Nd4jLong*") @StdVector LongPointer shape/*={}*/); + public Variable(@Cast("bool") boolean placeHolder) { super((Pointer)null); allocate(placeHolder); } + private native void allocate(@Cast("bool") boolean placeHolder); + public Variable(@Cast("bool") boolean placeHolder, @Cast("sd::DataType") int dataType/*=sd::DataType::ANY*/, @Cast("Nd4jLong*") @StdVector LongBuffer shape/*={}*/) { super((Pointer)null); allocate(placeHolder, dataType, shape); } + private native void allocate(@Cast("bool") boolean placeHolder, @Cast("sd::DataType") int dataType/*=sd::DataType::ANY*/, @Cast("Nd4jLong*") @StdVector LongBuffer shape/*={}*/); + public Variable(@Cast("bool") boolean placeHolder, @Cast("sd::DataType") int dataType/*=sd::DataType::ANY*/, @Cast("Nd4jLong*") @StdVector long[] shape/*={}*/) { super((Pointer)null); allocate(placeHolder, dataType, shape); } + private native void allocate(@Cast("bool") boolean placeHolder, @Cast("sd::DataType") int dataType/*=sd::DataType::ANY*/, @Cast("Nd4jLong*") @StdVector long[] shape/*={}*/); + public Variable(@Const @ByRef NDArray array, @StdString BytePointer name, int id, int idx/*=0*/) { super((Pointer)null); allocate(array, name, id, idx); } + private native void allocate(@Const @ByRef NDArray array, @StdString BytePointer name, int id, int idx/*=0*/); + public Variable(@Const @ByRef NDArray array, @StdString BytePointer name, int id) { super((Pointer)null); allocate(array, name, id); } + private native void allocate(@Const @ByRef NDArray array, @StdString BytePointer name, int id); + public Variable(@Const @ByRef NDArray array, @StdString String name, int id, int idx/*=0*/) { super((Pointer)null); allocate(array, name, id, idx); } + private native void allocate(@Const @ByRef NDArray array, @StdString String name, int id, int idx/*=0*/); + public Variable(@Const @ByRef NDArray array, @StdString String name, int id) { super((Pointer)null); allocate(array, name, id); } + private native void allocate(@Const @ByRef NDArray array, @StdString String name, int id); + public Variable(@SharedPtr NDArrayList array, @StdString BytePointer name, int id, int idx/*=0*/) { super((Pointer)null); allocate(array, name, id, idx); } + private native void allocate(@SharedPtr NDArrayList array, @StdString BytePointer name, int id, int idx/*=0*/); + public Variable(@SharedPtr NDArrayList array, @StdString BytePointer name, int id) { super((Pointer)null); allocate(array, name, id); } + private native void allocate(@SharedPtr NDArrayList array, @StdString BytePointer name, int id); + public Variable(@SharedPtr NDArrayList array, @StdString String name, int id, int idx/*=0*/) { super((Pointer)null); allocate(array, name, id, idx); } + private native void allocate(@SharedPtr NDArrayList array, @StdString String name, int id, int idx/*=0*/); + public Variable(@SharedPtr NDArrayList array, @StdString String name, int id) { super((Pointer)null); allocate(array, name, id); } + private native void allocate(@SharedPtr NDArrayList array, @StdString String name, int id); + public Variable(@SharedPtr NDArray array, @Cast("char*") String name/*=nullptr*/) { super((Pointer)null); allocate(array, name); } + private native void allocate(@SharedPtr NDArray array, @Cast("char*") String name/*=nullptr*/); + public Variable(@SharedPtr NDArray array) { super((Pointer)null); allocate(array); } + private native void allocate(@SharedPtr NDArray array); + public Variable(@SharedPtr NDArray array, @Cast("char*") BytePointer name/*=nullptr*/) { super((Pointer)null); allocate(array, name); } + private native void allocate(@SharedPtr NDArray array, @Cast("char*") BytePointer name/*=nullptr*/); + public Variable() { super((Pointer)null); allocate(); } + private native void allocate(); // #ifndef __JAVACPP_HACK__ // #endif - public native Variable clone(); + public native @Cast("bool") boolean hasNDArray(); + public native @SharedPtr NDArray getNDArray(); + public native void setNDArray(@SharedPtr NDArray array); - public native @Cast("bool") boolean hasNDArray(); - public native NDArray getNDArray(); - public native void setNDArray(NDArray array); + public native @Cast("bool") boolean hasNDArrayList(); + public native @SharedPtr NDArrayList getNDArrayList(); + public native void setNDArrayList(@SharedPtr NDArrayList list); - public native @Cast("bool") boolean hasNDArrayList(); - public native NDArrayList getNDArrayList(); - public native void setNDArrayList(NDArrayList list); + public native @Cast("bool") boolean isExternal(); + public native @Cast("bool") boolean isReadOnly(); + public native @Cast("bool") boolean isEmpty(); + public native @Cast("bool") boolean isRemovable(); - public native @Cast("bool") boolean isExternal(); - public native @Cast("bool") boolean isReadOnly(); - public native @Cast("bool") boolean isEmpty(); - public native @Cast("bool") boolean isRemovable(); + public native @Cast("bool") boolean isPlaceholder(); - public native @Cast("bool") boolean isPlaceholder(); + public native @Cast("sd::graph::VariableType") int variableType(); + public native void setVariableType(@Cast("sd::graph::VariableType") int variableType); - public native @Cast("sd::graph::VariableType") int variableType(); - public native void setVariableType(@Cast("sd::graph::VariableType") int variableType); + public native void markExternal(@Cast("bool") boolean reallyExternal); + public native void markReadOnly(@Cast("bool") boolean reallyReadOnly); + public native void markRemovable(@Cast("bool") boolean reallyRemovable); - /** - * This method returns InputType of this variable - */ - //InputType variableType() { - // return _variableType; - //} + public native int id(); + public native int index(); + public native void setIndex(int index); + public native void setId(int id); + public native void setId(int id, int idx); - public native void markExternal(@Cast("bool") boolean reallyExternal); - public native void markReadOnly(@Cast("bool") boolean reallyReadOnly); - public native void markRemovable(@Cast("bool") boolean reallyRemovable); + public native @StdString BytePointer name(); + public native @StdString BytePointer getName(); + public native void setName(@StdString BytePointer name); + public native void setName(@StdString String name); - public native int id(); - public native int index(); - public native void setIndex(int index); - public native void setId(int id); - public native void setId(int id, int idx); + public native @Cast("Nd4jLong*") @StdVector LongPointer shape(); + public native @Cast("sd::DataType") int dataType(); - public native @StdString @Cast({"char*", "std::string*"}) BytePointer getName(); - public native void setName(@StdString @Cast({"char*", "std::string*"}) BytePointer name); - - public native @Cast("Nd4jLong*") @StdVector LongPointer shape(); + public native @StdVector IntIntPair dependencies(); // #ifndef __JAVACPP_HACK__ // #endif - } - - +// #ifndef __JAVACPP_HACK__ +// #endif +} + // namespace graph + // namespace sd -// #endif //LIBND4J_VARIABLE_H +// #endif // LIBND4J_VARIABLE_H // Parsed from graph/VariablesSet.h @@ -5495,36 +6042,34 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #ifndef LIBND4J_VARIABLESSET_H // #define LIBND4J_VARIABLESSET_H -// #include -// #include -// #include -// #include // #include - @Namespace("sd::graph") @NoOffset public static class VariablesSet extends Pointer { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public VariablesSet(Pointer p) { super(p); } - - public VariablesSet(@Cast("Nd4jStatus") int status/*=ND4J_STATUS_OK*/) { super((Pointer)null); allocate(status); } - private native void allocate(@Cast("Nd4jStatus") int status/*=ND4J_STATUS_OK*/); - public VariablesSet() { super((Pointer)null); allocate(); } - private native void allocate(); - - public native @Cast("Nd4jStatus") int status(); - - public native int size(); +// #include +// #include - public native void push_back(Variable variable); +// #include +// #include +@Namespace("sd::graph") @NoOffset public static class VariablesSet extends Pointer { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public VariablesSet(Pointer p) { super(p); } - public native Variable at(int index); + public VariablesSet(@Cast("Nd4jStatus") int status/*=ND4J_STATUS_OK*/) { super((Pointer)null); allocate(status); } + private native void allocate(@Cast("Nd4jStatus") int status/*=ND4J_STATUS_OK*/); + public VariablesSet() { super((Pointer)null); allocate(); } + private native void allocate(); - } - + public native @Cast("Nd4jStatus") int status(); + public native int size(); + public native void push_back(Variable variable); + public native Variable at(int index); +} + // namespace graph + // namespace sd -// #endif //LIBND4J_VARIABLESSET_H +// #endif // LIBND4J_VARIABLESSET_H // Parsed from graph/FlowPath.h @@ -5552,68 +6097,68 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #ifndef LIBND4J_FLOWPATH_H // #define LIBND4J_FLOWPATH_H -// #include -// #include -// #include -// #include -// #include // #include +// #include // #include // #include - @Namespace("sd::graph") @NoOffset public static class FlowPath extends Pointer { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public FlowPath(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public FlowPath(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public FlowPath position(long position) { - return (FlowPath)super.position(position); - } - - public FlowPath() { super((Pointer)null); allocate(); } - private native void allocate(); +// #include +// #include - public native void setInnerTime(int nodeId, @Cast("Nd4jLong") long time); - public native void setOuterTime(int nodeId, @Cast("Nd4jLong") long time); +// #include +// #include +@Namespace("sd::graph") @NoOffset public static class FlowPath extends Pointer { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public FlowPath(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public FlowPath(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public FlowPath position(long position) { + return (FlowPath)super.position(position); + } - public native @Cast("Nd4jLong") long innerTime(int nodeId); - public native @Cast("Nd4jLong") long outerTime(int nodeId); + public FlowPath() { super((Pointer)null); allocate(); } + private native void allocate(); - public native @Cast("bool") boolean isNodeActive(int nodeId); - public native void markNodeActive(int nodeId, @Cast("bool") boolean isActive); + public native void setInnerTime(int nodeId, @Cast("Nd4jLong") long time); + public native void setOuterTime(int nodeId, @Cast("Nd4jLong") long time); - public native @Cast("bool") boolean wasExecuted(int nodeId); - public native void markExecuted(int nodeId, @Cast("bool") boolean wasExecuted); + public native @Cast("Nd4jLong") long innerTime(int nodeId); + public native @Cast("Nd4jLong") long outerTime(int nodeId); - public native int branch(int nodeId); - public native void markBranch(int nodeId, int index); + public native @Cast("bool") boolean isNodeActive(int nodeId); + public native void markNodeActive(int nodeId, @Cast("bool") boolean isActive); - // Frame-related methods + public native @Cast("bool") boolean wasExecuted(int nodeId); + public native void markExecuted(int nodeId, @Cast("bool") boolean wasExecuted); - public native void registerFrame(@Cast("Nd4jLong") long frameId); - public native void forgetFrame(@Cast("Nd4jLong") long frameId); + public native int branch(int nodeId); + public native void markBranch(int nodeId, int index); - public native @Cast("bool") boolean isFrameActive(@Cast("Nd4jLong") long frameId); - public native void markFrameActive(@Cast("Nd4jLong") long frameId, @Cast("bool") boolean isActive); + // Frame-related methods - public native @Cast("bool") boolean isRewindPlanned(@Cast("Nd4jLong") long frameId); - public native void planRewind(@Cast("Nd4jLong") long frameId, @Cast("bool") boolean reallyRewind); + public native void registerFrame(@Cast("Nd4jLong") long frameId); + public native void forgetFrame(@Cast("Nd4jLong") long frameId); - public native int getRewindPosition(@Cast("Nd4jLong") long frameId); - public native void setRewindPosition(@Cast("Nd4jLong") long frameId, int _position); - public native void setRewindPositionOnce(@Cast("Nd4jLong") long frameId, int _position); + public native @Cast("bool") boolean isFrameActive(@Cast("Nd4jLong") long frameId); + public native void markFrameActive(@Cast("Nd4jLong") long frameId, @Cast("bool") boolean isActive); - public native void incrementNumberOfCycles(@Cast("Nd4jLong") long frameId); - public native @Cast("Nd4jLong") long getNumberOfCycles(@Cast("Nd4jLong") long frameId); + public native @Cast("bool") boolean isRewindPlanned(@Cast("Nd4jLong") long frameId); + public native void planRewind(@Cast("Nd4jLong") long frameId, @Cast("bool") boolean reallyRewind); - public native GraphProfile profile(); - } - + public native int getRewindPosition(@Cast("Nd4jLong") long frameId); + public native void setRewindPosition(@Cast("Nd4jLong") long frameId, int _position); + public native void setRewindPositionOnce(@Cast("Nd4jLong") long frameId, int _position); + public native void incrementNumberOfCycles(@Cast("Nd4jLong") long frameId); + public native @Cast("Nd4jLong") long getNumberOfCycles(@Cast("Nd4jLong") long frameId); + public native GraphProfile profile(); +} + // namespace graph + // namespace sd -// #endif //LIBND4J_FLOWPATH_H +// #endif // LIBND4J_FLOWPATH_H // Parsed from graph/Intervals.h @@ -5641,43 +6186,41 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #ifndef LIBND4J_INTERVALS_H // #define LIBND4J_INTERVALS_H -// #include -// #include -// #include // #include +// #include - @Namespace("sd") @NoOffset public static class Intervals extends Pointer { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public Intervals(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public Intervals(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public Intervals position(long position) { - return (Intervals)super.position(position); - } - +// #include +// #include - // default constructor - public Intervals() { super((Pointer)null); allocate(); } - private native void allocate(); - - // constructor - public Intervals(@Const @ByRef LongVectorVector content ) { super((Pointer)null); allocate(content); } - private native void allocate(@Const @ByRef LongVectorVector content ); - - // accessing operator - public native @Cast("Nd4jLong*") @StdVector @Name("operator []") LongPointer get(@Cast("const Nd4jLong") long i); +@Namespace("sd") @NoOffset public static class Intervals extends Pointer { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public Intervals(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public Intervals(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public Intervals position(long position) { + return (Intervals)super.position(position); + } - // returns size of _content - public native int size(); + // default constructor + public Intervals() { super((Pointer)null); allocate(); } + private native void allocate(); - } + // constructor + public Intervals(@Const @ByRef LongVectorVector content) { super((Pointer)null); allocate(content); } + private native void allocate(@Const @ByRef LongVectorVector content); + // accessing operator + public native @Cast("Nd4jLong*") @StdVector @Name("operator []") LongPointer get(@Cast("const Nd4jLong") long i); + // returns size of _content + public native int size(); +} + // namespace sd -// #endif //LIBND4J_INTERVALS_H +// #endif // LIBND4J_INTERVALS_H // Parsed from graph/Stash.h @@ -5707,195 +6250,78 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint //#include // #include -// #include -// #include -// #include +// #include + // #include // #include -// #include - @Namespace("sd::graph") @NoOffset public static class KeyPair extends Pointer { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public KeyPair(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public KeyPair(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public KeyPair position(long position) { - return (KeyPair)super.position(position); - } - - public KeyPair(int node/*=0*/, @Cast("char*") String name/*=nullptr*/) { super((Pointer)null); allocate(node, name); } - private native void allocate(int node/*=0*/, @Cast("char*") String name/*=nullptr*/); - public KeyPair() { super((Pointer)null); allocate(); } - private native void allocate(); - public KeyPair(int node/*=0*/, @Cast("char*") BytePointer name/*=nullptr*/) { super((Pointer)null); allocate(node, name); } - private native void allocate(int node/*=0*/, @Cast("char*") BytePointer name/*=nullptr*/); +// #include +// #include +// #include +@Namespace("sd::graph") @NoOffset public static class KeyPair extends Pointer { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public KeyPair(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public KeyPair(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public KeyPair position(long position) { + return (KeyPair)super.position(position); + } - public native @Cast("bool") @Name("operator <") boolean lessThan(@Const @ByRef KeyPair other); + public KeyPair(int node/*=0*/, @Cast("char*") String name/*=nullptr*/) { super((Pointer)null); allocate(node, name); } + private native void allocate(int node/*=0*/, @Cast("char*") String name/*=nullptr*/); + public KeyPair() { super((Pointer)null); allocate(); } + private native void allocate(); + public KeyPair(int node/*=0*/, @Cast("char*") BytePointer name/*=nullptr*/) { super((Pointer)null); allocate(node, name); } + private native void allocate(int node/*=0*/, @Cast("char*") BytePointer name/*=nullptr*/); - public native @Cast("bool") @Name("operator ==") boolean equals(@Const @ByRef KeyPair other); + public native @Cast("bool") @Name("operator <") boolean lessThan(@Const @ByRef KeyPair other); - public native int key(); - public native @StdString BytePointer name(); - } - + public native @Cast("bool") @Name("operator ==") boolean equals(@Const @ByRef KeyPair other); + public native int key(); + public native @StdString BytePointer name(); +} + // namespace graph + // namespace sd // #ifndef __JAVACPP_HACK__ // #endif - @Namespace("sd::graph") @NoOffset public static class Stash extends Pointer { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public Stash(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public Stash(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public Stash position(long position) { - return (Stash)super.position(position); - } - - public Stash() { super((Pointer)null); allocate(); } - private native void allocate(); - - //void storeArray(sd::graph::Block& block, const char *name, sd::NDArray *array); - public native void storeArray(int nodeId, @Cast("char*") String name, NDArray array); - public native void storeArray(int nodeId, @Cast("char*") BytePointer name, NDArray array); - - //bool checkStash(sd::graph::Block& block, const char *name); - public native @Cast("bool") boolean checkStash(int nodeId, @Cast("char*") String name); - public native @Cast("bool") boolean checkStash(int nodeId, @Cast("char*") BytePointer name); - - //sd::NDArray* extractArray(sd::graph::Block& block, const char *name); - public native NDArray extractArray(int nodeId, @Cast("char*") String name); - public native NDArray extractArray(int nodeId, @Cast("char*") BytePointer name); - - public native void clear(); - } - - - - - - - -// #endif //LIBND4J_STASH_H - - -// Parsed from graph/GraphState.h - -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -// -// Created by raver119 on 23.01.18. -// - -// #ifndef LIBND4J_GRAPHSTATE_H -// #define LIBND4J_GRAPHSTATE_H - -// #include -// #include -// #include -// #include -// #include -// #include -// #include -// #include -// #include -// #include -// #include -// #include -// #include - - @Namespace("sd::graph") @NoOffset public static class GraphState extends Pointer { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public GraphState(Pointer p) { super(p); } - - public GraphState(@Cast("Nd4jLong") long id) { super((Pointer)null); allocate(id); } - private native void allocate(@Cast("Nd4jLong") long id); - - /** - * - * @return - */ - public native @Cast("Nd4jLong") long id(); - - /** - * This method adds scope to this state tracker - * - * @param scopeId - * @return - */ - public native @Cast("Nd4jStatus") int registerScope(int scopeId); - - /** - * This method cheks if scope with given ID exists - * - * @param scopeId - ID of the scope - * @return - TRUE if scope exists, FALSE otherwise - */ - public native @Cast("bool") boolean hasScope(int scopeId); - - /** - * This method removes specified scope from this state tracker - * - * @param scopeId - * @return - */ - public native @Cast("Nd4jStatus") int forgetScope(int scopeId); - -// #ifndef __JAVACPP_HACK__ -// #endif - /** - * This method adds given op to the end of specified scope - * - * @param scopeId - * @param opNum - * @param type - * @return - */ - public native @Cast("Nd4jStatus") int attachOpToScope(int scopeId, @Cast("Nd4jLong") long opNum, int type, @ByVal ArgumentsList inputs); +@Namespace("sd::graph") @NoOffset public static class Stash extends Pointer { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public Stash(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public Stash(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public Stash position(long position) { + return (Stash)super.position(position); + } - /** - * This method adds return statement to specified scope - * - * PLEASE NOTE: should be used only in body scopes - * - * @param scopeId - * @param nodeId - * @param args - * @return - */ - public native @Cast("Nd4jStatus") int defineReturn(int scopeId, int nodeId, @ByVal ArgumentsList args); + public Stash() { super((Pointer)null); allocate(); } + private native void allocate(); - /** - * This method returns current variable space of this state holder - * - * @return - */ - public native VariableSpace variableSpace(); - } + // void storeArray(sd::graph::Block& block, const char *name, + // sd::NDArray *array); + public native void storeArray(int nodeId, @Cast("char*") String name, NDArray array); + public native void storeArray(int nodeId, @Cast("char*") BytePointer name, NDArray array); + // bool checkStash(sd::graph::Block& block, const char *name); + public native @Cast("bool") boolean checkStash(int nodeId, @Cast("char*") String name); + public native @Cast("bool") boolean checkStash(int nodeId, @Cast("char*") BytePointer name); + // sd::NDArray* extractArray(sd::graph::Block& block, const char *name); + public native NDArray extractArray(int nodeId, @Cast("char*") String name); + public native NDArray extractArray(int nodeId, @Cast("char*") BytePointer name); + public native void clear(); +} + // namespace graph + // namespace sd -// #endif //LIBND4J_GRAPHSTATE_H +// #endif // LIBND4J_STASH_H // Parsed from graph/VariableSpace.h @@ -5923,102 +6349,102 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #ifndef LIBND4J_VARIABLESPACE_H // #define LIBND4J_VARIABLESPACE_H -// #include -// #include -// #include -// #include -// #include -// #include -// #include // #include // #include +// #include +// #include // #include +// #include +// #include // #include -// #include -// #include - @Namespace("sd::graph") @NoOffset public static class VariableSpace extends Pointer { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public VariableSpace(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public VariableSpace(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public VariableSpace position(long position) { - return (VariableSpace)super.position(position); - } - - public VariableSpace() { super((Pointer)null); allocate(); } - private native void allocate(); - - public native @ByRef @Name("operator =") VariableSpace put(@Const @ByRef VariableSpace other); - - public native int numberOfPlaceholders(); - public native @Cast("sd::graph::Variable**") @StdVector PointerPointer getPlaceholders(); - public native void setWorkspace(Workspace workspace); - - public native LaunchContext launchContext(); - - public native @Cast("bool") boolean hasExternalVariable(int it); - public native @Cast("bool") boolean hasExternalVariable(@ByRef IntIntPair pair); - public native @Cast("bool") boolean hasExternalVariable(@StdString @Cast({"char*", "std::string*"}) BytePointer symbol); - - public native @Cast("bool") boolean hasVariable(int id); - public native @Cast("bool") boolean hasVariable(int id, int idx); - public native @Cast("bool") boolean hasVariable(@ByRef IntIntPair pair); - public native @Cast("bool") boolean hasVariable(@StdString @Cast({"char*", "std::string*"}) BytePointer symbol); - - public native Variable getVariable(int id); - public native Variable getVariable(int id, int idx); - public native Variable getVariable(@ByRef IntIntPair pair); - public native Variable getVariable(@StdString @Cast({"char*", "std::string*"}) BytePointer symbol); - - public native @Cast("sd::graph::Variable**") @StdVector PointerPointer getVariables(); - - public native Variable putVariable(@ByRef IntIntPair pair, NDArray array); - public native void putVariable(@ByRef IntIntPair pair, Variable variable); - public native void putVariable(int id, Variable variable); - public native void putVariable(int id, NDArray array); - public native Variable putVariable(int id, int idx, NDArray array); - public native void putVariable(int id, int idx, Variable array); - - public native void dropVariable(@ByRef IntIntPair pair); - public native void dropVariable(int id, int idx); - - public native void trackList(NDArrayList list); - - public native void putOutputVariable(Variable variable); - - public native void replaceVariable(Variable variable); - - // memory-related statistics - public native @Cast("Nd4jLong") long externalMemory(); - public native @Cast("Nd4jLong") long internalMemory(); - public native @Cast("Nd4jLong") long totalMemory(); - - public native int externalEntries(); - public native int internalEntries(); - public native int totalEntries(); - - public native VariableSpace clone(); - - public native @Cast("sd::graph::Variable**") @StdVector PointerPointer handles(); +// #include +// #include +// #include +// #include +// #include +@Namespace("sd::graph") @NoOffset public static class VariableSpace extends Pointer { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public VariableSpace(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public VariableSpace(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public VariableSpace position(long position) { + return (VariableSpace)super.position(position); + } - public native VariableSpace asT(); - public native void injectVariable(@ByRef IntIntPair pair, Variable variable); + public VariableSpace() { super((Pointer)null); allocate(); } + private native void allocate(); - public native Stash getStash(); + public VariableSpace(@Const @ByRef VariableSpace variableSpace) { super((Pointer)null); allocate(variableSpace); } + private native void allocate(@Const @ByRef VariableSpace variableSpace); - public native @Cast("sd::graph::Variable**") @StdVector PointerPointer getExternalVariables(); + public native @ByRef @Name("operator =") VariableSpace put(@Const @ByRef VariableSpace other); - public native void setFlowPath(FlowPath timers); - public native FlowPath flowPath(); - } - + public native int numberOfPlaceholders(); +// #ifndef __JAVACPP_HACK__ +// #endif + public native @Cast("bool") boolean hasExternalVariable(int it); + public native @Cast("bool") boolean hasExternalVariable(@Const @ByRef IntIntPair pair); + public native @Cast("bool") boolean hasExternalVariable(@StdString BytePointer symbol); + public native @Cast("bool") boolean hasExternalVariable(@StdString String symbol); + + public native @Cast("bool") boolean hasVariable(int id); + public native @Cast("bool") boolean hasVariable(int id, int idx); + public native @Cast("bool") boolean hasVariable(@Const @ByRef IntIntPair pair); + public native @Cast("bool") boolean hasVariable(@StdString BytePointer symbol); + public native @Cast("bool") boolean hasVariable(@StdString String symbol); + + public native @SharedPtr Variable getVariable(int id); + public native @SharedPtr Variable getVariable(int id, int idx); + public native @SharedPtr Variable getVariable( + @Const @ByRef IntIntPair pair); + public native @SharedPtr Variable getVariable( + @StdString BytePointer symbol); + public native @SharedPtr Variable getVariable( + @StdString String symbol); + + public native @SharedPtr Variable putVariable(@Const @ByRef IntIntPair pair, @Const @ByRef NDArray array); + public native @SharedPtr Variable putVariable(int id, @Const @ByRef NDArray array); + public native @SharedPtr Variable putVariable(int id, int idx, @SharedPtr NDArray array); + public native @SharedPtr Variable putVariable(@StdString BytePointer name, int id, int idx, @Const @ByRef NDArray array); + public native @SharedPtr Variable putVariable(@StdString String name, int id, int idx, @Const @ByRef NDArray array); + + public native void putVariable(@StdString BytePointer name, int id, int idx, @SharedPtr Variable variable); + public native void putVariable(@StdString String name, int id, int idx, @SharedPtr Variable variable); + public native void putVariable(@Const @ByRef IntIntPair pair, @SharedPtr Variable variable); + public native void putVariable(int id, @SharedPtr Variable variable); + + public native void dropVariable(@StdString BytePointer pair); + public native void dropVariable(@StdString String pair); + public native void dropVariable(@Const @ByRef IntIntPair pair); + public native void dropVariable(int id, int idx); + + public native void putOutputVariable(@SharedPtr Variable variable); + + public native void replaceVariable(@SharedPtr Variable variable); + + // memory-related statistics + public native @Cast("Nd4jLong") long externalMemory(); + public native @Cast("Nd4jLong") long internalMemory(); + public native @Cast("Nd4jLong") long totalMemory(); + + public native int externalEntries(); + public native int internalEntries(); + public native int totalEntries(); + + public native void injectVariable(@Const @ByRef IntIntPair pair, + @SharedPtr Variable variable); + + public native Stash stash(); +} + // namespace graph + // namespace sd -// #endif //LIBND4J_VARIABLESPACE_H +// #endif // LIBND4J_VARIABLESPACE_H // Parsed from helpers/helper_generator.h @@ -6046,10 +6472,10 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #ifndef LIBND4J_HELPER_GENERATOR_H // #define LIBND4J_HELPER_GENERATOR_H -// #include -// #include // #include // #include +// #include +// #include // #ifdef _MSC_VER // include for uint64_t on MSVC @@ -6059,14 +6485,13 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #ifndef UINT64_C // #if defined(__LP64__) -// #define UINT64_C(c) c ## UL +// #define UINT64_C(c) c##UL // #else -// #define UINT64_C(c) c ## ULL -// #endif //LP64 -// #endif // UINT64 - -// #endif // MSVC/ANDROID +// #define UINT64_C(c) c##ULL +// #endif // LP64 +// #endif // UINT64 +// #endif // MSVC/ANDROID // #ifdef __GNUC__ // #include @@ -6074,200 +6499,194 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #ifdef __CUDACC__ // #else - @Namespace("sd::random") @NoOffset public static class RandomBuffer extends Pointer { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public RandomBuffer(Pointer p) { super(p); } - - /** - * This method allocates buffer of size * sizeof(Nd4jLong) - * - * @param size - * @return - */ +@Namespace("sd::random") @NoOffset public static class RandomBuffer extends Pointer { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public RandomBuffer(Pointer p) { super(p); } + + /** + * This method allocates buffer of size * sizeof(Nd4jLong) + * + * @param size + * @return + */ // #ifdef __CUDACC__ // #endif - public RandomBuffer(@Cast("Nd4jLong") long seed, @Cast("Nd4jLong") long size, @Cast("uint64_t*") LongPointer buffer) { super((Pointer)null); allocate(seed, size, buffer); } - private native void allocate(@Cast("Nd4jLong") long seed, @Cast("Nd4jLong") long size, @Cast("uint64_t*") LongPointer buffer); - public RandomBuffer(@Cast("Nd4jLong") long seed, @Cast("Nd4jLong") long size, @Cast("uint64_t*") LongBuffer buffer) { super((Pointer)null); allocate(seed, size, buffer); } - private native void allocate(@Cast("Nd4jLong") long seed, @Cast("Nd4jLong") long size, @Cast("uint64_t*") LongBuffer buffer); - public RandomBuffer(@Cast("Nd4jLong") long seed, @Cast("Nd4jLong") long size, @Cast("uint64_t*") long[] buffer) { super((Pointer)null); allocate(seed, size, buffer); } - private native void allocate(@Cast("Nd4jLong") long seed, @Cast("Nd4jLong") long size, @Cast("uint64_t*") long[] buffer); + public RandomBuffer(@Cast("Nd4jLong") long seed, @Cast("Nd4jLong") long size, @Cast("uint64_t*") LongPointer buffer) { super((Pointer)null); allocate(seed, size, buffer); } + private native void allocate(@Cast("Nd4jLong") long seed, @Cast("Nd4jLong") long size, @Cast("uint64_t*") LongPointer buffer); + public RandomBuffer(@Cast("Nd4jLong") long seed, @Cast("Nd4jLong") long size, @Cast("uint64_t*") LongBuffer buffer) { super((Pointer)null); allocate(seed, size, buffer); } + private native void allocate(@Cast("Nd4jLong") long seed, @Cast("Nd4jLong") long size, @Cast("uint64_t*") LongBuffer buffer); + public RandomBuffer(@Cast("Nd4jLong") long seed, @Cast("Nd4jLong") long size, @Cast("uint64_t*") long[] buffer) { super((Pointer)null); allocate(seed, size, buffer); } + private native void allocate(@Cast("Nd4jLong") long seed, @Cast("Nd4jLong") long size, @Cast("uint64_t*") long[] buffer); - public native @Cast("uint64_t*") LongPointer getBuffer(); + public native @Cast("uint64_t*") LongPointer getBuffer(); - public native @Cast("uint64_t*") LongPointer getDeviceBuffer(); + public native @Cast("uint64_t*") LongPointer getDeviceBuffer(); // #ifdef __CUDACC__ // #endif - public native @Cast("Nd4jLong") long getSize(); - - public native @Cast("Nd4jLong") long getSeed(); + public native @Cast("Nd4jLong") long getSize(); - public native void setSeed(@Cast("Nd4jLong") long seed); + public native @Cast("Nd4jLong") long getSeed(); - public native @Cast("Nd4jLong") long getAllocatedSize(); + public native void setSeed(@Cast("Nd4jLong") long seed); - public native @Cast("Nd4jLong") long getOffset(); + public native @Cast("Nd4jLong") long getAllocatedSize(); - public native void setOffset(@Cast("Nd4jLong") long offset); + public native @Cast("Nd4jLong") long getOffset(); - public native void reSeed(@Cast("Nd4jLong") long amplifier); + public native void setOffset(@Cast("Nd4jLong") long offset); - public native @Cast("uint64_t") long getElement(@Cast("Nd4jLong") long _position); + public native void reSeed(@Cast("Nd4jLong") long amplifier); - public native @Cast("uint64_t") long next64(@Cast("uint64_t") long shiftedSeed); + public native @Cast("uint64_t") long getElement(@Cast("Nd4jLong") long _position); - public static native @Cast("uint64_t") long rotl(@Cast("const uint64_t") long x, @Cast("uint64_t") long k); + public native @Cast("uint64_t") long next64(@Cast("uint64_t") long shiftedSeed); - public static native @Cast("uint64_t") long safeShift(@Cast("uint64_t") long x, @Cast("uint64_t") long y); + public static native @Cast("uint64_t") long rotl(@Cast("const uint64_t") long x, @Cast("uint64_t") long k); - public native @Cast("uint64_t") long seedConv(@Cast("Nd4jLong") long seed); + public static native @Cast("uint64_t") long safeShift(@Cast("uint64_t") long x, @Cast("uint64_t") long y); - public native void incrementGeneration(); + public native @Cast("uint64_t") long seedConv(@Cast("Nd4jLong") long seed); - public native @Cast("Nd4jLong") long getNextIndex(); + public native void incrementGeneration(); - public native @Cast("uint64_t") long getNextElement(); + public native @Cast("Nd4jLong") long getNextIndex(); + public native @Cast("uint64_t") long getNextElement(); - /** - * This method skips X elements from buffer - * - * @param numberOfElements number of elements to skip - */ + /** + * This method skips X elements from buffer + * + * @param numberOfElements number of elements to skip + */ // #ifdef __CUDACC__ // #endif - public native void rewindH(@Cast("Nd4jLong") long numberOfElements); - - /** - * This method returns random int in range [0..MAX_INT] - * @return - */ - public native int nextInt(); - - public native @Cast("uint64_t") long nextUInt64(); - - /** - * This method returns random int in range [0..to] - * @param to - * @return - */ - public native int nextInt(int to); - - /** - * This method returns random int in range [from..to] - * @param from - * @param to - * @return - */ - public native int nextInt(int from, int to); - - - /** - * This method returns random T in range of [0..1] - * @return - */ - - /** - * This method returns random T in range of [0..to] - * @param to - * @return - */ - - /** - * This method returns random T in range [from..to] - * @param from - * @param to - * @return - */ - - public native @Cast("uint64_t") long relativeUInt64(@Cast("Nd4jLong") long index); - - /** - * relative methods are made as workaround for lock-free concurrent execution - */ - public native int relativeInt(@Cast("Nd4jLong") long index); - - /** - * This method returns random int within [0..to] - * - * @param index - * @param to - * @return - */ - public native int relativeInt(@Cast("Nd4jLong") long index, int to); - - /** - * This method returns random int within [from..to] - * - * @param index - * @param to - * @param from - * @return - */ - public native int relativeInt(@Cast("Nd4jLong") long index, int from, int to); - - /** - * This method returns random T within [0..1] - * - * @param index - * @return - */ - -/** - * This method returns random T within [0..to] - * - * @param index - * @param to - * @return - */ + public native void rewindH(@Cast("Nd4jLong") long numberOfElements); -/** - * This method returns random T within [from..to] - * - * @param index - * @param from - * @param to - * @return - */ + /** + * This method returns random int in range [0..MAX_INT] + * @return + */ + public native int nextInt(); - } + public native @Cast("uint64_t") long nextUInt64(); - @Namespace("sd::random") @NoOffset public static class IGenerator extends Pointer { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public IGenerator(Pointer p) { super(p); } - + /** + * This method returns random int in range [0..to] + * @param to + * @return + */ + public native int nextInt(int to); + /** + * This method returns random int in range [from..to] + * @param from + * @param to + * @return + */ + public native int nextInt(int from, int to); - public native RandomBuffer getBuffer(); + /** + * This method returns random T in range of [0..1] + * @return + */ - public native void setOffset(@Cast("Nd4jLong") long offset); + /** + * This method returns random T in range of [0..to] + * @param to + * @return + */ - public native @Cast("Nd4jLong") long getElementAbsolute(@Cast("Nd4jLong") long _position); + /** + * This method returns random T in range [from..to] + * @param from + * @param to + * @return + */ - public native @Cast("Nd4jLong") long getElementRelative(@Cast("Nd4jLong") long _position); + public native @Cast("uint64_t") long relativeUInt64(@Cast("Nd4jLong") long index); - public native void refreshBuffer(); - } + /** + * relative methods are made as workaround for lock-free concurrent execution + */ + public native int relativeInt(@Cast("Nd4jLong") long index); + /** + * This method returns random int within [0..to] + * + * @param index + * @param to + * @return + */ + public native int relativeInt(@Cast("Nd4jLong") long index, int to); + /** + * This method returns random int within [from..to] + * + * @param index + * @param to + * @param from + * @return + */ + public native int relativeInt(@Cast("Nd4jLong") long index, int from, int to); - @Namespace("sd::random") @NoOffset public static class Xoroshiro128 extends IGenerator { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public Xoroshiro128(Pointer p) { super(p); } - - public Xoroshiro128(RandomBuffer buffer) { super((Pointer)null); allocate(buffer); } - private native void allocate(RandomBuffer buffer); + /** + * This method returns random T within [0..1] + * + * @param index + * @return + */ - public native void refreshBuffer(); - } - + /** + * This method returns random T within [0..to] + * + * @param index + * @param to + * @return + */ + + /** + * This method returns random T within [from..to] + * + * @param index + * @param from + * @param to + * @return + */ +} + +@Namespace("sd::random") @NoOffset public static class IGenerator extends Pointer { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public IGenerator(Pointer p) { super(p); } + + + public native RandomBuffer getBuffer(); + + public native void setOffset(@Cast("Nd4jLong") long offset); + + public native @Cast("Nd4jLong") long getElementAbsolute(@Cast("Nd4jLong") long _position); + + public native @Cast("Nd4jLong") long getElementRelative(@Cast("Nd4jLong") long _position); + + public native void refreshBuffer(); +} + +@Namespace("sd::random") @NoOffset public static class Xoroshiro128 extends IGenerator { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public Xoroshiro128(Pointer p) { super(p); } -// #endif //LIBND4J_HELPER_GENERATOR_H + public Xoroshiro128(RandomBuffer buffer) { super((Pointer)null); allocate(buffer); } + private native void allocate(RandomBuffer buffer); + + public native void refreshBuffer(); +} + // namespace random + // namespace sd +// #endif // LIBND4J_HELPER_GENERATOR_H // Parsed from graph/profiling/GraphProfile.h @@ -6295,102 +6714,105 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #ifndef ND4J_GRAPH_PROFILE_H // #define ND4J_GRAPH_PROFILE_H -// #include "NodeProfile.h" -// #include // #include -// #include -// #include -// #include +// #include + // #include - @Namespace("sd::graph") @NoOffset public static class GraphProfile extends Pointer { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public GraphProfile(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public GraphProfile(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public GraphProfile position(long position) { - return (GraphProfile)super.position(position); - } - - public GraphProfile() { super((Pointer)null); allocate(); } - private native void allocate(); - - /** - * These methods just adding amount of bytes to various counters - */ - public native void addToTotal(@Cast("Nd4jLong") long bytes); - public native void addToActivations(@Cast("Nd4jLong") long bytes); - public native void addToTemporary(@Cast("Nd4jLong") long bytes); - public native void addToObjects(@Cast("Nd4jLong") long bytes); - - /** - * This method allows to set graph construction (i.e. deserialization) time in nanoseconds - */ - public native void setBuildTime(@Cast("Nd4jLong") long nanos); - - /** - * This method sets graph execution time in nanoseconds. - */ - public native void setExecutionTime(@Cast("Nd4jLong") long nanos); - - public native void startEvent(@Cast("char*") String name); - public native void startEvent(@Cast("char*") BytePointer name); - public native void recordEvent(@Cast("char*") String name); - public native void recordEvent(@Cast("char*") BytePointer name); - public native void deleteEvent(@Cast("char*") String name); - public native void deleteEvent(@Cast("char*") BytePointer name); - - /** - * This method saves time as delta from last saved time - */ - public native void spotEvent(@Cast("char*") String name); - public native void spotEvent(@Cast("char*") BytePointer name); - - /** - * This method returns pointer to NodeProfile by ID - * PLEASE NOTE: this method will create new NodeProfile if there's none - */ - public native NodeProfile nodeById(int id, @Cast("char*") String name/*=nullptr*/); - public native NodeProfile nodeById(int id); - public native NodeProfile nodeById(int id, @Cast("char*") BytePointer name/*=nullptr*/); - public native @Cast("bool") boolean nodeExists(int id); - - /** - * This method merges values from other profile report - * @param other - */ - public native void merge(GraphProfile other); - public native void assign(GraphProfile other); - - /** - * These methods are just utility methods for time - */ - public static native @Cast("Nd4jLong") long currentTime(); - public static native @Cast("Nd4jLong") long relativeTime(@Cast("Nd4jLong") long time); - - public native void printOut(); - } - +// #include +// #include +// #include +// #include "NodeProfile.h" +@Namespace("sd::graph") @NoOffset public static class GraphProfile extends Pointer { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public GraphProfile(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public GraphProfile(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public GraphProfile position(long position) { + return (GraphProfile)super.position(position); + } -// #endif + public GraphProfile() { super((Pointer)null); allocate(); } + private native void allocate(); -// Parsed from graph/profiling/NodeProfile.h + /** + * These methods just adding amount of bytes to various counters + */ + public native void addToTotal(@Cast("Nd4jLong") long bytes); + public native void addToActivations(@Cast("Nd4jLong") long bytes); + public native void addToTemporary(@Cast("Nd4jLong") long bytes); + public native void addToObjects(@Cast("Nd4jLong") long bytes); -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * + /** + * This method allows to set graph construction (i.e. deserialization) time in + * nanoseconds + */ + public native void setBuildTime(@Cast("Nd4jLong") long nanos); + + /** + * This method sets graph execution time in nanoseconds. + */ + public native void setExecutionTime(@Cast("Nd4jLong") long nanos); + + public native void startEvent(@Cast("char*") String name); + public native void startEvent(@Cast("char*") BytePointer name); + public native void recordEvent(@Cast("char*") String name); + public native void recordEvent(@Cast("char*") BytePointer name); + public native void deleteEvent(@Cast("char*") String name); + public native void deleteEvent(@Cast("char*") BytePointer name); + + /** + * This method saves time as delta from last saved time + */ + public native void spotEvent(@Cast("char*") String name); + public native void spotEvent(@Cast("char*") BytePointer name); + + /** + * This method returns pointer to NodeProfile by ID + * PLEASE NOTE: this method will create new NodeProfile if there's none + */ + public native NodeProfile nodeById(int id, @Cast("char*") String name/*=nullptr*/); + public native NodeProfile nodeById(int id); + public native NodeProfile nodeById(int id, @Cast("char*") BytePointer name/*=nullptr*/); + public native @Cast("bool") boolean nodeExists(int id); + + /** + * This method merges values from other profile report + * @param other + */ + public native void merge(GraphProfile other); + public native void assign(GraphProfile other); + + /** + * These methods are just utility methods for time + */ + public static native @Cast("Nd4jLong") long currentTime(); + public static native @Cast("Nd4jLong") long relativeTime(@Cast("Nd4jLong") long time); + + public native void printOut(); +} + // namespace graph + // namespace sd + +// #endif + +// Parsed from graph/profiling/NodeProfile.h + +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ @@ -6401,65 +6823,66 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #ifndef LIBND4J_NODE_PROFILE_H // #define LIBND4J_NODE_PROFILE_H -// #include // #include +// #include + // #include // #include - @Namespace("sd::graph") @NoOffset public static class NodeProfile extends Pointer { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public NodeProfile(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public NodeProfile(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public NodeProfile position(long position) { - return (NodeProfile)super.position(position); - } - - public NodeProfile() { super((Pointer)null); allocate(); } - private native void allocate(); - - public NodeProfile(int id, @Cast("char*") String name) { super((Pointer)null); allocate(id, name); } - private native void allocate(int id, @Cast("char*") String name); - public NodeProfile(int id, @Cast("char*") BytePointer name) { super((Pointer)null); allocate(id, name); } - private native void allocate(int id, @Cast("char*") BytePointer name); - - public native void setBuildTime(@Cast("Nd4jLong") long time); - public native void setPreparationTime(@Cast("Nd4jLong") long time); - public native void setExecutionTime(@Cast("Nd4jLong") long time); - public native void setTotalTime(@Cast("Nd4jLong") long time); - public native void setShapeFunctionTime(@Cast("Nd4jLong") long time); - public native void setArrayTime(@Cast("Nd4jLong") long time); - public native void setInputTime(@Cast("Nd4jLong") long time); - - public native void setActivationsSize(@Cast("Nd4jLong") long bytes); - public native void setTemporarySize(@Cast("Nd4jLong") long bytes); - public native void setObjectsSize(@Cast("Nd4jLong") long bytes); - public native void setTotalSize(@Cast("Nd4jLong") long bytes); - - public native void addInputShape(@Cast("const Nd4jLong*") LongPointer shapeInfo); - public native void addInputShape(@Cast("const Nd4jLong*") LongBuffer shapeInfo); - public native void addInputShape(@Cast("const Nd4jLong*") long[] shapeInfo); - public native void addOutputShape(@Cast("const Nd4jLong*") LongPointer shapeInfo); - public native void addOutputShape(@Cast("const Nd4jLong*") LongBuffer shapeInfo); - public native void addOutputShape(@Cast("const Nd4jLong*") long[] shapeInfo); - - public native @Cast("Nd4jLong") long getActivationsSize(); - public native @Cast("Nd4jLong") long getTemporarySize(); - public native @Cast("Nd4jLong") long getObjectsSize(); - public native @Cast("Nd4jLong") long getTotalSize(); - - public native @Cast("Nd4jLong") long getExecutionTime(); - - public native @StdString @ByRef @Cast({"char*", "std::string*"}) BytePointer name(); - - public native void merge(NodeProfile other); - public native void assign(NodeProfile other); - - public native void printOut(); - } - +@Namespace("sd::graph") @NoOffset public static class NodeProfile extends Pointer { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public NodeProfile(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public NodeProfile(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public NodeProfile position(long position) { + return (NodeProfile)super.position(position); + } + + public NodeProfile() { super((Pointer)null); allocate(); } + private native void allocate(); + + public NodeProfile(int id, @Cast("char*") String name) { super((Pointer)null); allocate(id, name); } + private native void allocate(int id, @Cast("char*") String name); + public NodeProfile(int id, @Cast("char*") BytePointer name) { super((Pointer)null); allocate(id, name); } + private native void allocate(int id, @Cast("char*") BytePointer name); + + public native void setBuildTime(@Cast("Nd4jLong") long time); + public native void setPreparationTime(@Cast("Nd4jLong") long time); + public native void setExecutionTime(@Cast("Nd4jLong") long time); + public native void setTotalTime(@Cast("Nd4jLong") long time); + public native void setShapeFunctionTime(@Cast("Nd4jLong") long time); + public native void setArrayTime(@Cast("Nd4jLong") long time); + public native void setInputTime(@Cast("Nd4jLong") long time); + + public native void setActivationsSize(@Cast("Nd4jLong") long bytes); + public native void setTemporarySize(@Cast("Nd4jLong") long bytes); + public native void setObjectsSize(@Cast("Nd4jLong") long bytes); + public native void setTotalSize(@Cast("Nd4jLong") long bytes); + + public native void addInputShape(@Cast("const Nd4jLong*") LongPointer shapeInfo); + public native void addInputShape(@Cast("const Nd4jLong*") LongBuffer shapeInfo); + public native void addInputShape(@Cast("const Nd4jLong*") long[] shapeInfo); + public native void addOutputShape(@Cast("const Nd4jLong*") LongPointer shapeInfo); + public native void addOutputShape(@Cast("const Nd4jLong*") LongBuffer shapeInfo); + public native void addOutputShape(@Cast("const Nd4jLong*") long[] shapeInfo); + + public native @Cast("Nd4jLong") long getActivationsSize(); + public native @Cast("Nd4jLong") long getTemporarySize(); + public native @Cast("Nd4jLong") long getObjectsSize(); + public native @Cast("Nd4jLong") long getTotalSize(); + + public native @Cast("Nd4jLong") long getExecutionTime(); + + public native @StdString @ByRef @Cast({"char*", "std::string*"}) BytePointer name(); + + public native void merge(NodeProfile other); + public native void assign(NodeProfile other); + public native void printOut(); +} + // namespace graph + // namespace sd // #endif @@ -6489,214 +6912,209 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #ifndef LIBND4J_CONTEXT_H // #define LIBND4J_CONTEXT_H -// #include // #include +// #include +// #include // #include // #include -// #include +// #include // #include -// #include + +// #include // CUDA-specific includes // #ifdef __CUDACC__ // #endif - /** - * This class defines input desired for any given node/operation within graph - */ - @Namespace("sd::graph") @NoOffset public static class Context extends ContextPrototype { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public Context(Pointer p) { super(p); } - - public Context(ContextPrototype prototype, VariableSpace variableSpace) { super((Pointer)null); allocate(prototype, variableSpace); } - private native void allocate(ContextPrototype prototype, VariableSpace variableSpace); - - public Context(int nodeId, VariableSpace variableSpace/*=nullptr*/) { super((Pointer)null); allocate(nodeId, variableSpace); } - private native void allocate(int nodeId, VariableSpace variableSpace/*=nullptr*/); - public Context(int nodeId) { super((Pointer)null); allocate(nodeId); } - private native void allocate(int nodeId); - public Context(int nodeId, VariableSpace variableSpace, @Cast("bool") boolean isInplace) { super((Pointer)null); allocate(nodeId, variableSpace, isInplace); } - private native void allocate(int nodeId, VariableSpace variableSpace, @Cast("bool") boolean isInplace); - - // default destructor - - // these methods are for execution timing - public native void setOuterTime(@Cast("Nd4jLong") long time); - public native void setInnerTime(@Cast("Nd4jLong") long time); - public native @Cast("Nd4jLong") long getOuterTime(); - public native @Cast("Nd4jLong") long getInnerTime(); - - public native @Cast("sd::DataType") int dataType(); - - public native @Cast("sd::DataType") int dataType(int index); - public native void setDataType(int index, @Cast("sd::DataType") int type); - // these methods are related to Workspace abstraction - public native @Cast("bool") boolean hasWorkspaceProvided(); - public native void attachWorkspace(Workspace workspace); - public native void forgetWorkspace(); - - // these methods return full-time workspace - public native Workspace getWorkspace(); - public native Workspace workspace(); - public native Workspace fWorkspace(); - - // this method returns workspace for temporary allocations - public native Workspace tWorkspace(); - - // this method returns workspace for object allocations - public native Workspace oWorkspace(); - - public native void setVariableSpace(VariableSpace variableSpace); - - public native RandomBuffer getRNG(); - public native void setRNG(RandomBuffer rng); +/** + * This class defines input desired for any given node/operation within graph + */ +@Namespace("sd::graph") @NoOffset public static class Context extends ContextPrototype { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public Context(Pointer p) { super(p); } - public native void setTargetEngine(@Cast("samediff::Engine") int engine); + public Context(@Const @ByRef ContextPrototype prototype, VariableSpace variableSpace, + GraphMemoryManager memoryManager/*=nullptr*/) { super((Pointer)null); allocate(prototype, variableSpace, memoryManager); } + private native void allocate(@Const @ByRef ContextPrototype prototype, VariableSpace variableSpace, + GraphMemoryManager memoryManager/*=nullptr*/); + public Context(@Const @ByRef ContextPrototype prototype, VariableSpace variableSpace) { super((Pointer)null); allocate(prototype, variableSpace); } + private native void allocate(@Const @ByRef ContextPrototype prototype, VariableSpace variableSpace); - public native VariableSpace getVariableSpace(); + public Context(int nodeId, VariableSpace variableSpace/*=nullptr*/) { super((Pointer)null); allocate(nodeId, variableSpace); } + private native void allocate(int nodeId, VariableSpace variableSpace/*=nullptr*/); + public Context(int nodeId) { super((Pointer)null); allocate(nodeId); } + private native void allocate(int nodeId); + public Context(int nodeId, VariableSpace variableSpace, @Cast("bool") boolean isInplace) { super((Pointer)null); allocate(nodeId, variableSpace, isInplace); } + private native void allocate(int nodeId, VariableSpace variableSpace, @Cast("bool") boolean isInplace); - public native LaunchContext launchContext(); + // default destructor - // these fields define, if we can execute specific node in-place, without generating new array + // these methods are for execution timing + public native void setOuterTime(@Cast("Nd4jLong") long time); + public native void setInnerTime(@Cast("Nd4jLong") long time); + public native @Cast("Nd4jLong") long outerTime(); + public native @Cast("Nd4jLong") long innerTime(); + // these methods are related to Workspace abstraction + public native @Cast("bool") boolean hasWorkspaceProvided(); + public native void attachWorkspace(Workspace workspace); - // these variables are only for Divergent Nodes - public native int getBranch(); - public native void setBranch(int branch); + public native Workspace workspace(); - /** - * - * @return - */ - public native Stash getStash(); + public native void setVariableSpace(VariableSpace variableSpace); - /** - * - */ - public native void trackList(NDArrayList list); + public native void setTargetEngine(@Cast("samediff::Engine") int engine); + public native VariableSpace getVariableSpace(); - /** - * This method returns variable for a given input index for this block - * @param idx - * @return - */ - public native Variable getVariable(int idx); - public native Variable variable(int idx); + public native LaunchContext launchContext(); - /** - * This method is shortcut to getVariable(int idx); - * - * + it check fastpath for array availability (preferred) - * @return - */ - public native NDArray getNDArray(int idx); - public native NDArray array(int idx); + /** + * + * @return + */ + public native Stash stash(); + /** + * This method returns variable for a given input index for this block + * @param idx + * @return + */ + public native @SharedPtr Variable getVariable(int idx); + public native @SharedPtr Variable variable(int idx); - /** - * This method fetches variable from VariableSpace DIRECTLY - * @param p - * @return - */ - public native Variable variable(int node, int index); - public native Variable variable(@ByRef IntIntPair p); + /** + * This method is shortcut to getVariable(int idx); + * + * + it check fastpath for array availability (preferred) + * @return + */ + public native @SharedPtr NDArray getNDArray(int idx); + public native @SharedPtr NDArray array(int idx); + /** + * This is special method, used only within Graph + * @param idx + * @return + */ + public native NDArray arrayForOp(int idx); - public native void pushNDArrayToVariableSpace(int nodeId, int index, NDArray array, @Cast("bool") boolean removable/*=true*/); - public native void pushNDArrayToVariableSpace(int nodeId, int index, NDArray array); - public native void pushNDArrayToVariableSpace(@ByRef IntIntPair pair, NDArray array, @Cast("bool") boolean removable/*=true*/); - public native void pushNDArrayToVariableSpace(@ByRef IntIntPair pair, NDArray array); + /** + * This method fetches variable from VariableSpace DIRECTLY + * @param p + * @return + */ + public native @SharedPtr Variable variable(int node, int index); + public native @SharedPtr Variable variable(@Const @ByRef IntIntPair p); + + public native void pushNDArrayToVariableSpace(int nodeId, int index, @Const @ByRef NDArray array); + public native void pushNDArrayToVariableSpace(@Const @ByRef IntIntPair pair, + @Const @ByRef NDArray array); + + public native void pushNDArrayListToVariableSpace(int nodeId, int index, @SharedPtr NDArrayList list); + public native void pushNDArrayListToVariableSpace(int nodeId, int index, + @Const @ByRef NDArrayList list, + @Cast("bool") boolean track/*=true*/); + public native void pushNDArrayListToVariableSpace(@Const @ByRef IntIntPair pair, + @Const @ByRef NDArrayList list, + @Cast("bool") boolean track/*=true*/); + public native void pushNDArrayListToVariableSpace(@Const @ByRef IntIntPair pair, + @Const @ByRef NDArrayList list); + + public native @Cast("bool") boolean isValueAvailable(@StdString BytePointer name, int id, int idx/*=0*/); + public native @Cast("bool") boolean isValueAvailable(@StdString BytePointer name, int id); + public native @Cast("bool") boolean isValueAvailable(@StdString String name, int id, int idx/*=0*/); + public native @Cast("bool") boolean isValueAvailable(@StdString String name, int id); + + public native @SharedPtr Variable ensureVariable(@StdString BytePointer name, int id, + int idx/*=0*/); + public native @SharedPtr Variable ensureVariable(@StdString BytePointer name, int id); + public native @SharedPtr Variable ensureVariable(@StdString String name, int id, + int idx/*=0*/); + public native @SharedPtr Variable ensureVariable(@StdString String name, int id); + + public native @Cast("unsigned long") long width(); + + // methods used in java interop + /** + * This method checks if Context uses fastpath variable access + * @return + */ + public native @Cast("bool") boolean isFastPath(); - public native void pushNDArrayListToVariableSpace(int nodeId, int index, NDArrayList list, @Cast("bool") boolean track/*=true*/); - public native void pushNDArrayListToVariableSpace(int nodeId, int index, NDArrayList list); - public native void pushNDArrayListToVariableSpace(@ByRef IntIntPair pair, NDArrayList list, @Cast("bool") boolean track/*=true*/); - public native void pushNDArrayListToVariableSpace(@ByRef IntIntPair pair, NDArrayList list); + /** + * Method allows to forbid FastPath execution + * @param reallyForbid + */ + public native void forbidFastPath(@Cast("bool") boolean reallyForbid); - public native @Cast("bool") boolean isValueAvailable(int idx/*=0*/); - public native @Cast("bool") boolean isValueAvailable(); +// #ifndef __JAVACPP_HACK__ +// #endif - public native Variable ensureVariable(int idx/*=0*/); - public native Variable ensureVariable(); + public native void setInputArray(int index, @SharedPtr NDArray array); + public native void setInputArray(int index, Pointer buffer, @Const Pointer shapeInfo, + Pointer specialBuffer, @Const Pointer specialShapeInfo); + public native void setInputArray(int index, Pointer databuffer, @Const Pointer shapeInfo, + @Const Pointer specialShapeInfo); + + public native void setOutputArray(int index, @SharedPtr NDArray array); + public native void setOutputArray(int index, Pointer buffer, @Const Pointer shapeInfo, + Pointer specialBuffer, @Const Pointer specialShapeInfo); + public native void setOutputArray(int index, Pointer databuffer, @Const Pointer shapeInfo, + @Const Pointer specialShapeInfo); + + public native void setTArguments(DoublePointer arguments, int numberOfArguments); + public native void setTArguments(DoubleBuffer arguments, int numberOfArguments); + public native void setTArguments(double[] arguments, int numberOfArguments); + public native void setIArguments(@Cast("Nd4jLong*") LongPointer arguments, int numberOfArguments); + public native void setIArguments(@Cast("Nd4jLong*") LongBuffer arguments, int numberOfArguments); + public native void setIArguments(@Cast("Nd4jLong*") long[] arguments, int numberOfArguments); + public native void setBArguments(@Cast("bool*") BooleanPointer arguments, int numberOfArguments); + public native void setBArguments(@Cast("bool*") boolean[] arguments, int numberOfArguments); + public native void setDArguments(@Cast("sd::DataType*") IntPointer arguments, int numberOfArguments); + public native void setDArguments(@Cast("sd::DataType*") IntBuffer arguments, int numberOfArguments); + public native void setDArguments(@Cast("sd::DataType*") int[] arguments, int numberOfArguments); + + public native void setTArguments(@StdVector DoublePointer tArgs); + public native void setTArguments(@StdVector DoubleBuffer tArgs); + public native void setTArguments(@StdVector double[] tArgs); + public native void setIArguments(@Cast("Nd4jLong*") @StdVector LongPointer tArgs); + public native void setIArguments(@Cast("Nd4jLong*") @StdVector LongBuffer tArgs); + public native void setIArguments(@Cast("Nd4jLong*") @StdVector long[] tArgs); + public native void setBArguments(@Cast("bool*") @StdVector BooleanPointer tArgs); + public native void setBArguments(@Cast("bool*") @StdVector boolean[] tArgs); + public native void setDArguments(@Cast("sd::DataType*") @StdVector IntPointer dArgs); + public native void setDArguments(@Cast("sd::DataType*") @StdVector IntBuffer dArgs); + public native void setDArguments(@Cast("sd::DataType*") @StdVector int[] dArgs); - public native @Cast("unsigned long") long width(); + /** + * This method purges fastpath in/out contents and releases all the handles. + * + * PLEASE NOTE: I/T/B/D args will stay intact + */ + public native void clearFastPath(); - // methods used in java interop - /** - * This method checks if Context uses fastpath variable access - * @return - */ - public native @Cast("bool") boolean isFastPath(); + public native void setCudaContext(@Cast("Nd4jPointer") Pointer cudaStream, @Cast("Nd4jPointer") Pointer reductionPointer, + @Cast("Nd4jPointer") Pointer allocationPointer); - /** - * Method allows to forbid FastPath execution - * @param reallyForbid - */ - public native void forbidFastPath(@Cast("bool") boolean reallyForbid); + public native void allowHelpers(@Cast("bool") boolean reallyAllow); + public native @Cast("bool") boolean helpersAllowed(); -// #ifndef __JAVACPP_HACK__ -// #endif + public native void setShapeFunctionOverride(@Cast("bool") boolean reallyOverride); + public native @Cast("bool") boolean shapeFunctionOverride(); - public native void setInputArray(int index, NDArray array, @Cast("bool") boolean removable/*=false*/); - public native void setInputArray(int index, NDArray array); - public native void setInputArray(int index, Pointer buffer, @Const Pointer shapeInfo, Pointer specialBuffer, @Const Pointer specialShapeInfo); - public native void setInputArray(int index, Pointer databuffer, @Const Pointer shapeInfo, @Const Pointer specialShapeInfo); - - public native void setOutputArray(int index, NDArray array, @Cast("bool") boolean removable/*=false*/); - public native void setOutputArray(int index, NDArray array); - public native void setOutputArray(int index, Pointer buffer, @Const Pointer shapeInfo, Pointer specialBuffer, @Const Pointer specialShapeInfo); - public native void setOutputArray(int index, Pointer databuffer, @Const Pointer shapeInfo, @Const Pointer specialShapeInfo); - - public native void setTArguments(DoublePointer arguments, int numberOfArguments); - public native void setTArguments(DoubleBuffer arguments, int numberOfArguments); - public native void setTArguments(double[] arguments, int numberOfArguments); - public native void setIArguments(@Cast("Nd4jLong*") LongPointer arguments, int numberOfArguments); - public native void setIArguments(@Cast("Nd4jLong*") LongBuffer arguments, int numberOfArguments); - public native void setIArguments(@Cast("Nd4jLong*") long[] arguments, int numberOfArguments); - public native void setBArguments(@Cast("bool*") BooleanPointer arguments, int numberOfArguments); - public native void setBArguments(@Cast("bool*") boolean[] arguments, int numberOfArguments); - public native void setDArguments(@Cast("sd::DataType*") IntPointer arguments, int numberOfArguments); - public native void setDArguments(@Cast("sd::DataType*") IntBuffer arguments, int numberOfArguments); - public native void setDArguments(@Cast("sd::DataType*") int[] arguments, int numberOfArguments); - - public native void setTArguments(@StdVector DoublePointer tArgs); - public native void setTArguments(@StdVector DoubleBuffer tArgs); - public native void setTArguments(@StdVector double[] tArgs); - public native void setIArguments(@Cast("Nd4jLong*") @StdVector LongPointer tArgs); - public native void setIArguments(@Cast("Nd4jLong*") @StdVector LongBuffer tArgs); - public native void setIArguments(@Cast("Nd4jLong*") @StdVector long[] tArgs); - public native void setBArguments(@Cast("bool*") @StdVector BooleanPointer tArgs); - public native void setBArguments(@Cast("bool*") @StdVector boolean[] tArgs); - public native void setDArguments(@Cast("sd::DataType*") @StdVector IntPointer dArgs); - public native void setDArguments(@Cast("sd::DataType*") @StdVector IntBuffer dArgs); - public native void setDArguments(@Cast("sd::DataType*") @StdVector int[] dArgs); - - /** - * This method purges fastpath in/out contents and releases all the handles. - * - * PLEASE NOTE: I/T/B/D args will stay intact - */ - public native void clearFastPath(); - - public native void setCudaContext(@Cast("Nd4jPointer") Pointer cudaStream, @Cast("Nd4jPointer") Pointer reductionPointer, @Cast("Nd4jPointer") Pointer allocationPointer); - - public native void allowHelpers(@Cast("bool") boolean reallyAllow); - public native @Cast("bool") boolean helpersAllowed(); - - public native void setShapeFunctionOverride(@Cast("bool") boolean reallyOverride); - public native @Cast("bool") boolean shapeFunctionOverride(); - - public native @Cast("samediff::ExecutionMode") int executionMode(); - public native void setExecutionMode(@Cast("samediff::ExecutionMode") int executionMode); - - public native @Cast("bool") boolean isTraining(); - public native @Cast("bool") boolean isInference(); - } - + public native @Cast("samediff::ExecutionMode") int executionMode(); + public native void setExecutionMode(@Cast("samediff::ExecutionMode") int executionMode); + public native @Cast("bool") boolean isTraining(); + public native @Cast("bool") boolean isInference(); + public native @Const @ByRef GraphMemoryManager memoryManager(); +} + // namespace graph + // namespace sd -// #endif //LIBND4J_BLOCK_H +// #endif // LIBND4J_BLOCK_H // Parsed from graph/ContextPrototype.h @@ -6725,99 +7143,130 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #ifndef ND4J_CONTEXT_PROTOTYPE_H // #define ND4J_CONTEXT_PROTOTYPE_H -// #include -// #include // #include -// #include -// #include -// #include // #include // #include +// #include +// #include +// #include +// #include + +// #include // #ifndef __STANDALONE_BUILD__ // #include // #endif - @Namespace("sd::graph") @NoOffset public static class ContextPrototype extends Pointer { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public ContextPrototype(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public ContextPrototype(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public ContextPrototype position(long position) { - return (ContextPrototype)super.position(position); - } - - public ContextPrototype(OpDescriptor opDescriptor/*=nullptr*/, int nodeId/*=1*/, @Cast("bool") boolean inPlace/*=false*/) { super((Pointer)null); allocate(opDescriptor, nodeId, inPlace); } - private native void allocate(OpDescriptor opDescriptor/*=nullptr*/, int nodeId/*=1*/, @Cast("bool") boolean inPlace/*=false*/); - public ContextPrototype() { super((Pointer)null); allocate(); } - private native void allocate(); - - public native int getNodeId(); - public native int nodeId(); - - // this method returns true, if inputs are defined - public native @Cast("bool") boolean hasVariablesFilled(); - - public native void setOpDescriptor(OpDescriptor opDescriptor); - - public native @Cast("sd::DataType") int dataType(); - public native @Cast("sd::DataType") int dataType(int index); - public native void setDataType(int index, @Cast("sd::DataType") int type); - - public native @Cast("bool") boolean isInplace(); - public native void markInplace(@Cast("bool") boolean reallyInplace); - - public native void pickInput(int input); - public native void pickInput(int input, int index); - public native void pickInput(@ByRef IntIntPair p); - public native void fillInputs(@StdVector IntPointer inputs); - public native void fillInputs(@StdVector IntBuffer inputs); - public native void fillInputs(@StdVector int[] inputs); - public native @StdVector IntIntPair inputs(); - - public native @StdVector DoublePointer getTArguments(); - public native @StdVector IntPointer getIArguments(); - public native @Cast("bool*") @StdVector BooleanPointer getBArguments(); - public native @Cast("sd::DataType*") @StdVector IntPointer getDArguments(); - public native @StdVector IntPointer getAxis(); - - public native @Cast("samediff::Engine") int engine(); - - public native @Cast("size_t") long numT(); - public native @Cast("size_t") long numI(); - public native @Cast("size_t") long numB(); - public native @Cast("size_t") long numD(); - - public native IntIntPair input(int idx); - - public native int opNum(); - public native void setOpNum(int opNum); - - public native @Cast("bool") boolean isUseMKLDNN(); - public native void setUseMKLDNN(@Cast("bool") boolean useMKLDNN); - - /** - * This method returns number of inputs available in this block - * @return - */ - public native @Cast("unsigned long") long width(); - - // just a clone - public native ContextPrototype clone(); - - public native @ByRef RandomGenerator randomGenerator(); - public native @Const @ByRef RandomGenerator getRng(); - public native void setRng(@Const @ByRef RandomGenerator anotherRng); - public native void setRandomGenerator(@Const @ByRef RandomGenerator anotherRng); - public native @Cast("uint64_t") long randomSeed(); - public native void setRandomSeed(@Cast("uint64_t") long seed); - } - +@Namespace("sd::graph") @NoOffset public static class ContextPrototype extends Pointer { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public ContextPrototype(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public ContextPrototype(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public ContextPrototype position(long position) { + return (ContextPrototype)super.position(position); + } + + public ContextPrototype(OpDescriptor opDescriptor/*=nullptr*/, + int nodeId/*=1*/, @Cast("bool") boolean inPlace/*=false*/) { super((Pointer)null); allocate(opDescriptor, nodeId, inPlace); } + private native void allocate(OpDescriptor opDescriptor/*=nullptr*/, + int nodeId/*=1*/, @Cast("bool") boolean inPlace/*=false*/); + public ContextPrototype() { super((Pointer)null); allocate(); } + private native void allocate(); + + public ContextPrototype(@Const @ByRef ContextPrototype other) { super((Pointer)null); allocate(other); } + @NoException private native void allocate(@Const @ByRef ContextPrototype other); + + public native @ByRef @Name("operator =") @NoException ContextPrototype put(@Const @ByRef ContextPrototype other); + + // move constructor + + // move assignment operator + + public native int getNodeId(); + public native int nodeId(); + public native void setNodeId(int id); + + // this method returns true, if inputs are defined + public native @Cast("bool") boolean hasVariablesFilled(); + + public native void setOpDescriptor(OpDescriptor opDescriptor); + + public native @Cast("bool") boolean isInplace(); + public native void markInplace(@Cast("bool") boolean reallyInplace); + + public native void pickInput(int input); + public native void pickInput(int input, int index); + public native void pickInput(@Const @ByRef IntIntPair p); + public native void fillInputs(@StdVector IntPointer inputs); + public native void fillInputs(@StdVector IntBuffer inputs); + public native void fillInputs(@StdVector int[] inputs); + public native @StdVector IntIntPair inputs(); + + public native @StdVector DoublePointer getTArguments(); + public native @StdVector IntPointer getIArguments(); + public native @Cast("bool*") @StdVector BooleanPointer getBArguments(); + public native @Cast("sd::DataType*") @StdVector IntPointer getDArguments(); + public native @StdVector IntPointer getAxis(); + + public native void appendI(@Cast("Nd4jLong*") @StdVector LongPointer value); + public native void appendI(@Cast("Nd4jLong*") @StdVector LongBuffer value); + public native void appendI(@Cast("Nd4jLong*") @StdVector long[] value); + public native void appendT(@StdVector DoublePointer value); + public native void appendT(@StdVector DoubleBuffer value); + public native void appendT(@StdVector double[] value); + public native void appendB(@Cast("bool*") @StdVector BooleanPointer value); + public native void appendB(@Cast("bool*") @StdVector boolean[] value); + public native void appendD(@Cast("sd::DataType*") @StdVector IntPointer value); + public native void appendD(@Cast("sd::DataType*") @StdVector IntBuffer value); + public native void appendD(@Cast("sd::DataType*") @StdVector int[] value); + + public native void appendA(@Cast("Nd4jLong") long value); + public native void appendI(@Cast("Nd4jLong") long value); + public native void appendT(double value); + public native void appendB(@Cast("bool") boolean value); + public native void appendD(@Cast("sd::DataType") int value); + + public native @Cast("samediff::Engine") int engine(); + + public native @Cast("size_t") long numT(); + public native @Cast("size_t") long numI(); + public native @Cast("size_t") long numB(); + public native @Cast("size_t") long numD(); + + public native @Const @ByRef IntIntPair input(int idx); + + public native int opNum(); + public native void setOpNum(int opNum); + + public native @Cast("bool") boolean isUseMKLDNN(); + public native void setUseMKLDNN(@Cast("bool") boolean useMKLDNN); + public native @StdString BytePointer name(); + public native void setName(@StdString BytePointer name); + public native void setName(@StdString String name); -// #endif //ND4J_CONTEXT_PROTOTYPE_H + /** + * This method returns number of inputs available in this block + * @return + */ + public native @Cast("unsigned long") long width(); + + // just a clone + public native ContextPrototype clone(); + + public native @ByRef RandomGenerator randomGenerator(); + public native @Const @ByRef RandomGenerator getRng(); + public native void setRng(@Const @ByRef RandomGenerator anotherRng); + public native void setRandomGenerator(@Const @ByRef RandomGenerator anotherRng); + public native @Cast("uint64_t") long randomSeed(); + public native void setRandomSeed(@Cast("uint64_t") long seed); +} + // namespace graph + // namespace sd + +// #endif // ND4J_CONTEXT_PROTOTYPE_H // Parsed from graph/ResultWrapper.h @@ -6845,26 +7294,25 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #ifndef LIBND4J_RESULTWRAPPER_H // #define LIBND4J_RESULTWRAPPER_H +// #include // #include // #include -// #include - @Namespace("sd::graph") @NoOffset public static class ResultWrapper extends org.nd4j.nativeblas.ResultWrapperAbstraction { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public ResultWrapper(Pointer p) { super(p); } - - public ResultWrapper(@Cast("Nd4jLong") long size, @Cast("Nd4jPointer") Pointer ptr) { super((Pointer)null); allocate(size, ptr); } - private native void allocate(@Cast("Nd4jLong") long size, @Cast("Nd4jPointer") Pointer ptr); - - public native @Cast("Nd4jLong") long size(); +@Namespace("sd::graph") @NoOffset public static class ResultWrapper extends org.nd4j.nativeblas.ResultWrapperAbstraction { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public ResultWrapper(Pointer p) { super(p); } - public native @Cast("Nd4jPointer") Pointer pointer(); - } - + public ResultWrapper(@Cast("Nd4jLong") long size, @Cast("Nd4jPointer") Pointer ptr) { super((Pointer)null); allocate(size, ptr); } + private native void allocate(@Cast("Nd4jLong") long size, @Cast("Nd4jPointer") Pointer ptr); + public native @Cast("Nd4jLong") long size(); + public native @Cast("Nd4jPointer") Pointer pointer(); +} + // namespace graph + // namespace sd -// #endif //LIBND4J_RESULTWRAPPER_H +// #endif // LIBND4J_RESULTWRAPPER_H // Parsed from helpers/shape.h @@ -6895,205 +7343,309 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #ifndef SHAPE_H_ // #define SHAPE_H_ -// #include +// #include + // #include +// #include + +// #include "../cnpy/cnpy.h" +// #include "../helpers/logger.h" +// #include "math/templatemath.h" // #include "system/dll.h" // #include "system/nd4jmalloc.h" -// #include "math/templatemath.h" -// #include "../helpers/logger.h" // #include "system/pointercast.h" -// #include "../cnpy/cnpy.h" -// #include public static final int MAX_DIMENSION = 0x7fffffff; -public static final int MAX_NUM_THREADS = 1024; +public static final int MAX_NUM_THREADS = 1024; public static final int MAX_RANK = 32; -public static final int MAX_SHAPEINFOLENGTH = 2*MAX_RANK+4; +public static final int MAX_SHAPEINFOLENGTH = 2 * MAX_RANK + 4; public static final int MAX_COORD = 3; public static final int PREALLOC_SIZE = 33554432; // #ifdef __CUDACC__ // #endif - // #ifdef __CUDACC__ // #else // #define INLINEDEF inline // #endif -// #include "system/pairwise_util.h" -// #include // #include +// #include + +// #include "system/pairwise_util.h" /** * Shape information approximating * the information on an ndarray */ - @Namespace("shape") @NoOffset public static class ShapeInformation extends Pointer { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public ShapeInformation(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public ShapeInformation(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public ShapeInformation position(long position) { - return (ShapeInformation)super.position(position); - } - - public ShapeInformation(@Cast("Nd4jLong*") LongPointer shape_/*=nullptr*/, @Cast("Nd4jLong*") LongPointer stride_/*=nullptr*/, char order_/*=0*/, int rank_/*=0*/, int offset_/*=0*/, int elementWiseStride_/*=0*/) { super((Pointer)null); allocate(shape_, stride_, order_, rank_, offset_, elementWiseStride_); } - private native void allocate(@Cast("Nd4jLong*") LongPointer shape_/*=nullptr*/, @Cast("Nd4jLong*") LongPointer stride_/*=nullptr*/, char order_/*=0*/, int rank_/*=0*/, int offset_/*=0*/, int elementWiseStride_/*=0*/); - public ShapeInformation() { super((Pointer)null); allocate(); } - private native void allocate(); - public ShapeInformation(@Cast("Nd4jLong*") LongBuffer shape_/*=nullptr*/, @Cast("Nd4jLong*") LongBuffer stride_/*=nullptr*/, char order_/*=0*/, int rank_/*=0*/, int offset_/*=0*/, int elementWiseStride_/*=0*/) { super((Pointer)null); allocate(shape_, stride_, order_, rank_, offset_, elementWiseStride_); } - private native void allocate(@Cast("Nd4jLong*") LongBuffer shape_/*=nullptr*/, @Cast("Nd4jLong*") LongBuffer stride_/*=nullptr*/, char order_/*=0*/, int rank_/*=0*/, int offset_/*=0*/, int elementWiseStride_/*=0*/); - public ShapeInformation(@Cast("Nd4jLong*") long[] shape_/*=nullptr*/, @Cast("Nd4jLong*") long[] stride_/*=nullptr*/, char order_/*=0*/, int rank_/*=0*/, int offset_/*=0*/, int elementWiseStride_/*=0*/) { super((Pointer)null); allocate(shape_, stride_, order_, rank_, offset_, elementWiseStride_); } - private native void allocate(@Cast("Nd4jLong*") long[] shape_/*=nullptr*/, @Cast("Nd4jLong*") long[] stride_/*=nullptr*/, char order_/*=0*/, int rank_/*=0*/, int offset_/*=0*/, int elementWiseStride_/*=0*/); - - public native @Cast("Nd4jLong*") LongPointer shape(); public native ShapeInformation shape(LongPointer setter); - public native @Cast("Nd4jLong*") LongPointer stride(); public native ShapeInformation stride(LongPointer setter); - public native char order(); public native ShapeInformation order(char setter); - public native int rank(); public native ShapeInformation rank(int setter); - public native int offset(); public native ShapeInformation offset(int setter); - public native int elementWiseStride(); public native ShapeInformation elementWiseStride(int setter); +@Namespace("shape") @NoOffset public static class ShapeInformation extends Pointer { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public ShapeInformation(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public ShapeInformation(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public ShapeInformation position(long position) { + return (ShapeInformation)super.position(position); } + public ShapeInformation(@Cast("Nd4jLong*") LongPointer shape_/*=nullptr*/, + @Cast("Nd4jLong*") LongPointer stride_/*=nullptr*/, char order_/*=0*/, + int rank_/*=0*/, int offset_/*=0*/, + int elementWiseStride_/*=0*/) { super((Pointer)null); allocate(shape_, stride_, order_, rank_, offset_, elementWiseStride_); } + private native void allocate(@Cast("Nd4jLong*") LongPointer shape_/*=nullptr*/, + @Cast("Nd4jLong*") LongPointer stride_/*=nullptr*/, char order_/*=0*/, + int rank_/*=0*/, int offset_/*=0*/, + int elementWiseStride_/*=0*/); + public ShapeInformation() { super((Pointer)null); allocate(); } + private native void allocate(); + public ShapeInformation(@Cast("Nd4jLong*") LongBuffer shape_/*=nullptr*/, + @Cast("Nd4jLong*") LongBuffer stride_/*=nullptr*/, char order_/*=0*/, + int rank_/*=0*/, int offset_/*=0*/, + int elementWiseStride_/*=0*/) { super((Pointer)null); allocate(shape_, stride_, order_, rank_, offset_, elementWiseStride_); } + private native void allocate(@Cast("Nd4jLong*") LongBuffer shape_/*=nullptr*/, + @Cast("Nd4jLong*") LongBuffer stride_/*=nullptr*/, char order_/*=0*/, + int rank_/*=0*/, int offset_/*=0*/, + int elementWiseStride_/*=0*/); + public ShapeInformation(@Cast("Nd4jLong*") long[] shape_/*=nullptr*/, + @Cast("Nd4jLong*") long[] stride_/*=nullptr*/, char order_/*=0*/, + int rank_/*=0*/, int offset_/*=0*/, + int elementWiseStride_/*=0*/) { super((Pointer)null); allocate(shape_, stride_, order_, rank_, offset_, elementWiseStride_); } + private native void allocate(@Cast("Nd4jLong*") long[] shape_/*=nullptr*/, + @Cast("Nd4jLong*") long[] stride_/*=nullptr*/, char order_/*=0*/, + int rank_/*=0*/, int offset_/*=0*/, + int elementWiseStride_/*=0*/); + + public native @Cast("Nd4jLong*") LongPointer shape(); public native ShapeInformation shape(LongPointer setter); + public native @Cast("Nd4jLong*") LongPointer stride(); public native ShapeInformation stride(LongPointer setter); + public native char order(); public native ShapeInformation order(char setter); + public native int rank(); public native ShapeInformation rank(int setter); + public native int offset(); public native ShapeInformation offset(int setter); + public native int elementWiseStride(); public native ShapeInformation elementWiseStride(int setter); +} + /** * Indexing information * for bounds checking */ - @Namespace("shape") public static class CurrentIndexing extends Pointer { - static { Loader.load(); } - /** Default native constructor. */ - public CurrentIndexing() { super((Pointer)null); allocate(); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public CurrentIndexing(long size) { super((Pointer)null); allocateArray(size); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public CurrentIndexing(Pointer p) { super(p); } - private native void allocate(); - private native void allocateArray(long size); - @Override public CurrentIndexing position(long position) { - return (CurrentIndexing)super.position(position); - } - - public native int numElementsPerThread(); public native CurrentIndexing numElementsPerThread(int setter); - public native int blockStartingIndex(); public native CurrentIndexing blockStartingIndex(int setter); - public native int startingThreadIndex(); public native CurrentIndexing startingThreadIndex(int setter); - public native int endingThreadIndex(); public native CurrentIndexing endingThreadIndex(int setter); - +@Namespace("shape") public static class CurrentIndexing extends Pointer { + static { Loader.load(); } + /** Default native constructor. */ + public CurrentIndexing() { super((Pointer)null); allocate(); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public CurrentIndexing(long size) { super((Pointer)null); allocateArray(size); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public CurrentIndexing(Pointer p) { super(p); } + private native void allocate(); + private native void allocateArray(long size); + @Override public CurrentIndexing position(long position) { + return (CurrentIndexing)super.position(position); } + public native int numElementsPerThread(); public native CurrentIndexing numElementsPerThread(int setter); + public native int blockStartingIndex(); public native CurrentIndexing blockStartingIndex(int setter); + public native int startingThreadIndex(); public native CurrentIndexing startingThreadIndex(int setter); + public native int endingThreadIndex(); public native CurrentIndexing endingThreadIndex(int setter); +} +@Namespace("shape") public static native @Cast("bool") boolean shapeEquals(int shape1Rank, + @Cast("const Nd4jLong*") LongPointer shape1, + int shape2Rank, + @Cast("const Nd4jLong*") LongPointer shape2); +@Namespace("shape") public static native @Cast("bool") boolean shapeEquals(int shape1Rank, + @Cast("const Nd4jLong*") LongBuffer shape1, + int shape2Rank, + @Cast("const Nd4jLong*") LongBuffer shape2); +@Namespace("shape") public static native @Cast("bool") boolean shapeEquals(int shape1Rank, + @Cast("const Nd4jLong*") long[] shape1, + int shape2Rank, + @Cast("const Nd4jLong*") long[] shape2); + +@Namespace("shape") public static native @Cast("const Nd4jLong*") LongPointer detachShape(@Cast("const Nd4jLong*") LongPointer originalShape); +@Namespace("shape") public static native @Cast("const Nd4jLong*") LongBuffer detachShape(@Cast("const Nd4jLong*") LongBuffer originalShape); +@Namespace("shape") public static native @Cast("const Nd4jLong*") long[] detachShape(@Cast("const Nd4jLong*") long[] originalShape); + +@Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer copyShape(@Cast("const Nd4jLong*") LongPointer originalShape); +@Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer copyShape(@Cast("const Nd4jLong*") LongBuffer originalShape); +@Namespace("shape") public static native @Cast("Nd4jLong*") long[] copyShape(@Cast("const Nd4jLong*") long[] originalShape); + +@Namespace("shape") public static native @Cast("bool") boolean shapeEquals(@Cast("const Nd4jLong*") LongPointer shapeInfo1, + @Cast("const Nd4jLong*") LongPointer shapeInfo2); +@Namespace("shape") public static native @Cast("bool") boolean shapeEquals(@Cast("const Nd4jLong*") LongBuffer shapeInfo1, + @Cast("const Nd4jLong*") LongBuffer shapeInfo2); +@Namespace("shape") public static native @Cast("bool") boolean shapeEquals(@Cast("const Nd4jLong*") long[] shapeInfo1, + @Cast("const Nd4jLong*") long[] shapeInfo2); + +@Namespace("shape") public static native @Cast("bool") boolean shapeEquals(@Cast("const Nd4jLong*") LongPointer shapeInfo1, + @Cast("const Nd4jLong*") LongPointer shapeInfo2, + @Cast("const Nd4jLong*") LongPointer shapeInfo3); +@Namespace("shape") public static native @Cast("bool") boolean shapeEquals(@Cast("const Nd4jLong*") LongBuffer shapeInfo1, + @Cast("const Nd4jLong*") LongBuffer shapeInfo2, + @Cast("const Nd4jLong*") LongBuffer shapeInfo3); +@Namespace("shape") public static native @Cast("bool") boolean shapeEquals(@Cast("const Nd4jLong*") long[] shapeInfo1, + @Cast("const Nd4jLong*") long[] shapeInfo2, + @Cast("const Nd4jLong*") long[] shapeInfo3); + +@Namespace("shape") public static native @Cast("bool") boolean strideEquals(int shape1Rank, + @Cast("const Nd4jLong*") LongPointer shape1, + int shape2Rank, + @Cast("const Nd4jLong*") LongPointer shape2); +@Namespace("shape") public static native @Cast("bool") boolean strideEquals(int shape1Rank, + @Cast("const Nd4jLong*") LongBuffer shape1, + int shape2Rank, + @Cast("const Nd4jLong*") LongBuffer shape2); +@Namespace("shape") public static native @Cast("bool") boolean strideEquals(int shape1Rank, + @Cast("const Nd4jLong*") long[] shape1, + int shape2Rank, + @Cast("const Nd4jLong*") long[] shape2); + +@Namespace("shape") public static native @Cast("bool") boolean strideEquals(@Cast("const Nd4jLong*") LongPointer shapeInfo1, + @Cast("const Nd4jLong*") LongPointer shapeInfo2); +@Namespace("shape") public static native @Cast("bool") boolean strideEquals(@Cast("const Nd4jLong*") LongBuffer shapeInfo1, + @Cast("const Nd4jLong*") LongBuffer shapeInfo2); +@Namespace("shape") public static native @Cast("bool") boolean strideEquals(@Cast("const Nd4jLong*") long[] shapeInfo1, + @Cast("const Nd4jLong*") long[] shapeInfo2); + +@Namespace("shape") public static native @Cast("bool") boolean strideEquals(@Cast("const Nd4jLong*") LongPointer stride1, int rank1, + @Cast("const Nd4jLong*") LongPointer stride2, int rank2); +@Namespace("shape") public static native @Cast("bool") boolean strideEquals(@Cast("const Nd4jLong*") LongBuffer stride1, int rank1, + @Cast("const Nd4jLong*") LongBuffer stride2, int rank2); +@Namespace("shape") public static native @Cast("bool") boolean strideEquals(@Cast("const Nd4jLong*") long[] stride1, int rank1, + @Cast("const Nd4jLong*") long[] stride2, int rank2); + +@Namespace("shape") public static native @Cast("bool") boolean equalsSoft(@Cast("const Nd4jLong*") LongPointer shapeA, + @Cast("const Nd4jLong*") LongPointer shapeB); +@Namespace("shape") public static native @Cast("bool") boolean equalsSoft(@Cast("const Nd4jLong*") LongBuffer shapeA, + @Cast("const Nd4jLong*") LongBuffer shapeB); +@Namespace("shape") public static native @Cast("bool") boolean equalsSoft(@Cast("const Nd4jLong*") long[] shapeA, + @Cast("const Nd4jLong*") long[] shapeB); + +@Namespace("shape") public static native @Cast("bool") boolean equalsTypesAndShapesSoft(@Cast("const Nd4jLong*") LongPointer shapeA, + @Cast("const Nd4jLong*") LongPointer shapeB); +@Namespace("shape") public static native @Cast("bool") boolean equalsTypesAndShapesSoft(@Cast("const Nd4jLong*") LongBuffer shapeA, + @Cast("const Nd4jLong*") LongBuffer shapeB); +@Namespace("shape") public static native @Cast("bool") boolean equalsTypesAndShapesSoft(@Cast("const Nd4jLong*") long[] shapeA, + @Cast("const Nd4jLong*") long[] shapeB); + +@Namespace("shape") public static native @Cast("bool") boolean equalsStrict(@Cast("const Nd4jLong*") LongPointer shapeA, + @Cast("const Nd4jLong*") LongPointer shapeB); +@Namespace("shape") public static native @Cast("bool") boolean equalsStrict(@Cast("const Nd4jLong*") LongBuffer shapeA, + @Cast("const Nd4jLong*") LongBuffer shapeB); +@Namespace("shape") public static native @Cast("bool") boolean equalsStrict(@Cast("const Nd4jLong*") long[] shapeA, + @Cast("const Nd4jLong*") long[] shapeB); + +// returns true if ranks, shapes and strides are the same +@Namespace("shape") public static native @Cast("bool") boolean haveSameShapeAndStrides(@Cast("const Nd4jLong*") LongPointer shapeInfo1, + @Cast("const Nd4jLong*") LongPointer shapeInfo2); +@Namespace("shape") public static native @Cast("bool") boolean haveSameShapeAndStrides(@Cast("const Nd4jLong*") LongBuffer shapeInfo1, + @Cast("const Nd4jLong*") LongBuffer shapeInfo2); +@Namespace("shape") public static native @Cast("bool") boolean haveSameShapeAndStrides(@Cast("const Nd4jLong*") long[] shapeInfo1, + @Cast("const Nd4jLong*") long[] shapeInfo2); +@Namespace("shape") public static native @Cast("bool") boolean haveSameShapeAndStrides(@Cast("const Nd4jLong*") LongPointer shapeInfo1, + @Cast("const Nd4jLong*") LongPointer shapeInfo2, + @Cast("const Nd4jLong*") LongPointer shapeInfo3); +@Namespace("shape") public static native @Cast("bool") boolean haveSameShapeAndStrides(@Cast("const Nd4jLong*") LongBuffer shapeInfo1, + @Cast("const Nd4jLong*") LongBuffer shapeInfo2, + @Cast("const Nd4jLong*") LongBuffer shapeInfo3); +@Namespace("shape") public static native @Cast("bool") boolean haveSameShapeAndStrides(@Cast("const Nd4jLong*") long[] shapeInfo1, + @Cast("const Nd4jLong*") long[] shapeInfo2, + @Cast("const Nd4jLong*") long[] shapeInfo3); + +@Namespace("shape") public static native int sizeAt(@Cast("const Nd4jLong*") LongPointer shapeInfo, int dim); +@Namespace("shape") public static native int sizeAt(@Cast("const Nd4jLong*") LongBuffer shapeInfo, int dim); +@Namespace("shape") public static native int sizeAt(@Cast("const Nd4jLong*") long[] shapeInfo, int dim); +@Namespace("shape") public static native @Cast("Nd4jLong") long strideAt(@Cast("const Nd4jLong*") LongPointer shapeInfo, int dim); +@Namespace("shape") public static native @Cast("Nd4jLong") long strideAt(@Cast("const Nd4jLong*") LongBuffer shapeInfo, int dim); +@Namespace("shape") public static native @Cast("Nd4jLong") long strideAt(@Cast("const Nd4jLong*") long[] shapeInfo, int dim); + +@Namespace("shape") public static native void traceNew(int id); + +@Namespace("shape") public static native int tadIndexForLinear(int linearIndex, int tadLength); + +@Namespace("shape") public static native @Cast("Nd4jLong") long tadLength(@Cast("const Nd4jLong*") LongPointer shapeInfo, IntPointer dimension, + int dimensionLength); +@Namespace("shape") public static native @Cast("Nd4jLong") long tadLength(@Cast("const Nd4jLong*") LongBuffer shapeInfo, IntBuffer dimension, + int dimensionLength); +@Namespace("shape") public static native @Cast("Nd4jLong") long tadLength(@Cast("const Nd4jLong*") long[] shapeInfo, int[] dimension, + int dimensionLength); + +@Namespace("shape") public static native @Cast("bool") boolean canReshape(int oldRank, @Cast("Nd4jLong*") LongPointer oldShape, + int newRank, @Cast("Nd4jLong*") LongPointer newShape, + @Cast("bool") boolean isFOrder); +@Namespace("shape") public static native @Cast("bool") boolean canReshape(int oldRank, @Cast("Nd4jLong*") LongBuffer oldShape, + int newRank, @Cast("Nd4jLong*") LongBuffer newShape, + @Cast("bool") boolean isFOrder); +@Namespace("shape") public static native @Cast("bool") boolean canReshape(int oldRank, @Cast("Nd4jLong*") long[] oldShape, + int newRank, @Cast("Nd4jLong*") long[] newShape, + @Cast("bool") boolean isFOrder); + +@Namespace("shape") public static native @Cast("bool") boolean reshapeC(@Cast("const Nd4jLong*") LongPointer oldShapeInfo, + byte newOrder, int newRank, + @Cast("const Nd4jLong*") LongPointer newShape, + @Cast("Nd4jLong*") LongPointer newShapeInfo); +@Namespace("shape") public static native @Cast("bool") boolean reshapeC(@Cast("const Nd4jLong*") LongBuffer oldShapeInfo, + byte newOrder, int newRank, + @Cast("const Nd4jLong*") LongBuffer newShape, + @Cast("Nd4jLong*") LongBuffer newShapeInfo); +@Namespace("shape") public static native @Cast("bool") boolean reshapeC(@Cast("const Nd4jLong*") long[] oldShapeInfo, + byte newOrder, int newRank, + @Cast("const Nd4jLong*") long[] newShape, + @Cast("Nd4jLong*") long[] newShapeInfo); +/** + * newShapeInfo contains rank, shape and order only, no strides/ews/type + */ +@Namespace("shape") public static native @Cast("bool") boolean reshapeC(@Cast("const Nd4jLong*") LongPointer oldShapeInfo, + @Cast("Nd4jLong*") LongPointer newShapeInfo); +@Namespace("shape") public static native @Cast("bool") boolean reshapeC(@Cast("const Nd4jLong*") LongBuffer oldShapeInfo, + @Cast("Nd4jLong*") LongBuffer newShapeInfo); +@Namespace("shape") public static native @Cast("bool") boolean reshapeC(@Cast("const Nd4jLong*") long[] oldShapeInfo, + @Cast("Nd4jLong*") long[] newShapeInfo); - @Namespace("shape") public static native @Cast("bool") boolean shapeEquals(int shape1Rank, @Cast("const Nd4jLong*") LongPointer shape1, int shape2Rank, @Cast("const Nd4jLong*") LongPointer shape2); - @Namespace("shape") public static native @Cast("bool") boolean shapeEquals(int shape1Rank, @Cast("const Nd4jLong*") LongBuffer shape1, int shape2Rank, @Cast("const Nd4jLong*") LongBuffer shape2); - @Namespace("shape") public static native @Cast("bool") boolean shapeEquals(int shape1Rank, @Cast("const Nd4jLong*") long[] shape1, int shape2Rank, @Cast("const Nd4jLong*") long[] shape2); - - @Namespace("shape") public static native @Cast("const Nd4jLong*") LongPointer detachShape(@Cast("const Nd4jLong*") LongPointer originalShape); - @Namespace("shape") public static native @Cast("const Nd4jLong*") LongBuffer detachShape(@Cast("const Nd4jLong*") LongBuffer originalShape); - @Namespace("shape") public static native @Cast("const Nd4jLong*") long[] detachShape(@Cast("const Nd4jLong*") long[] originalShape); - - @Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer copyShape(@Cast("const Nd4jLong*") LongPointer originalShape); - @Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer copyShape(@Cast("const Nd4jLong*") LongBuffer originalShape); - @Namespace("shape") public static native @Cast("Nd4jLong*") long[] copyShape(@Cast("const Nd4jLong*") long[] originalShape); - - @Namespace("shape") public static native @Cast("bool") boolean shapeEquals(@Cast("const Nd4jLong*") LongPointer shapeInfo1, @Cast("const Nd4jLong*") LongPointer shapeInfo2); - @Namespace("shape") public static native @Cast("bool") boolean shapeEquals(@Cast("const Nd4jLong*") LongBuffer shapeInfo1, @Cast("const Nd4jLong*") LongBuffer shapeInfo2); - @Namespace("shape") public static native @Cast("bool") boolean shapeEquals(@Cast("const Nd4jLong*") long[] shapeInfo1, @Cast("const Nd4jLong*") long[] shapeInfo2); - - @Namespace("shape") public static native @Cast("bool") boolean shapeEquals(@Cast("const Nd4jLong*") LongPointer shapeInfo1, @Cast("const Nd4jLong*") LongPointer shapeInfo2, @Cast("const Nd4jLong*") LongPointer shapeInfo3); - @Namespace("shape") public static native @Cast("bool") boolean shapeEquals(@Cast("const Nd4jLong*") LongBuffer shapeInfo1, @Cast("const Nd4jLong*") LongBuffer shapeInfo2, @Cast("const Nd4jLong*") LongBuffer shapeInfo3); - @Namespace("shape") public static native @Cast("bool") boolean shapeEquals(@Cast("const Nd4jLong*") long[] shapeInfo1, @Cast("const Nd4jLong*") long[] shapeInfo2, @Cast("const Nd4jLong*") long[] shapeInfo3); - - @Namespace("shape") public static native @Cast("bool") boolean strideEquals(int shape1Rank,@Cast("const Nd4jLong*") LongPointer shape1,int shape2Rank, @Cast("const Nd4jLong*") LongPointer shape2); - @Namespace("shape") public static native @Cast("bool") boolean strideEquals(int shape1Rank,@Cast("const Nd4jLong*") LongBuffer shape1,int shape2Rank, @Cast("const Nd4jLong*") LongBuffer shape2); - @Namespace("shape") public static native @Cast("bool") boolean strideEquals(int shape1Rank,@Cast("const Nd4jLong*") long[] shape1,int shape2Rank, @Cast("const Nd4jLong*") long[] shape2); - - @Namespace("shape") public static native @Cast("bool") boolean strideEquals(@Cast("const Nd4jLong*") LongPointer shapeInfo1, @Cast("const Nd4jLong*") LongPointer shapeInfo2); - @Namespace("shape") public static native @Cast("bool") boolean strideEquals(@Cast("const Nd4jLong*") LongBuffer shapeInfo1, @Cast("const Nd4jLong*") LongBuffer shapeInfo2); - @Namespace("shape") public static native @Cast("bool") boolean strideEquals(@Cast("const Nd4jLong*") long[] shapeInfo1, @Cast("const Nd4jLong*") long[] shapeInfo2); - - @Namespace("shape") public static native @Cast("bool") boolean strideEquals(@Cast("const Nd4jLong*") LongPointer stride1,int rank1, @Cast("const Nd4jLong*") LongPointer stride2, int rank2); - @Namespace("shape") public static native @Cast("bool") boolean strideEquals(@Cast("const Nd4jLong*") LongBuffer stride1,int rank1, @Cast("const Nd4jLong*") LongBuffer stride2, int rank2); - @Namespace("shape") public static native @Cast("bool") boolean strideEquals(@Cast("const Nd4jLong*") long[] stride1,int rank1, @Cast("const Nd4jLong*") long[] stride2, int rank2); - - @Namespace("shape") public static native @Cast("bool") boolean equalsSoft(@Cast("const Nd4jLong*") LongPointer shapeA, @Cast("const Nd4jLong*") LongPointer shapeB); - @Namespace("shape") public static native @Cast("bool") boolean equalsSoft(@Cast("const Nd4jLong*") LongBuffer shapeA, @Cast("const Nd4jLong*") LongBuffer shapeB); - @Namespace("shape") public static native @Cast("bool") boolean equalsSoft(@Cast("const Nd4jLong*") long[] shapeA, @Cast("const Nd4jLong*") long[] shapeB); - - @Namespace("shape") public static native @Cast("bool") boolean equalsTypesAndShapesSoft(@Cast("const Nd4jLong*") LongPointer shapeA, @Cast("const Nd4jLong*") LongPointer shapeB); - @Namespace("shape") public static native @Cast("bool") boolean equalsTypesAndShapesSoft(@Cast("const Nd4jLong*") LongBuffer shapeA, @Cast("const Nd4jLong*") LongBuffer shapeB); - @Namespace("shape") public static native @Cast("bool") boolean equalsTypesAndShapesSoft(@Cast("const Nd4jLong*") long[] shapeA, @Cast("const Nd4jLong*") long[] shapeB); - - @Namespace("shape") public static native @Cast("bool") boolean equalsStrict(@Cast("const Nd4jLong*") LongPointer shapeA, @Cast("const Nd4jLong*") LongPointer shapeB); - @Namespace("shape") public static native @Cast("bool") boolean equalsStrict(@Cast("const Nd4jLong*") LongBuffer shapeA, @Cast("const Nd4jLong*") LongBuffer shapeB); - @Namespace("shape") public static native @Cast("bool") boolean equalsStrict(@Cast("const Nd4jLong*") long[] shapeA, @Cast("const Nd4jLong*") long[] shapeB); - - // returns true if ranks, shapes and strides are the same - @Namespace("shape") public static native @Cast("bool") boolean haveSameShapeAndStrides(@Cast("const Nd4jLong*") LongPointer shapeInfo1, @Cast("const Nd4jLong*") LongPointer shapeInfo2); - @Namespace("shape") public static native @Cast("bool") boolean haveSameShapeAndStrides(@Cast("const Nd4jLong*") LongBuffer shapeInfo1, @Cast("const Nd4jLong*") LongBuffer shapeInfo2); - @Namespace("shape") public static native @Cast("bool") boolean haveSameShapeAndStrides(@Cast("const Nd4jLong*") long[] shapeInfo1, @Cast("const Nd4jLong*") long[] shapeInfo2); - @Namespace("shape") public static native @Cast("bool") boolean haveSameShapeAndStrides(@Cast("const Nd4jLong*") LongPointer shapeInfo1, @Cast("const Nd4jLong*") LongPointer shapeInfo2, @Cast("const Nd4jLong*") LongPointer shapeInfo3); - @Namespace("shape") public static native @Cast("bool") boolean haveSameShapeAndStrides(@Cast("const Nd4jLong*") LongBuffer shapeInfo1, @Cast("const Nd4jLong*") LongBuffer shapeInfo2, @Cast("const Nd4jLong*") LongBuffer shapeInfo3); - @Namespace("shape") public static native @Cast("bool") boolean haveSameShapeAndStrides(@Cast("const Nd4jLong*") long[] shapeInfo1, @Cast("const Nd4jLong*") long[] shapeInfo2, @Cast("const Nd4jLong*") long[] shapeInfo3); - - @Namespace("shape") public static native int sizeAt(@Cast("const Nd4jLong*") LongPointer shapeInfo, int dim); - @Namespace("shape") public static native int sizeAt(@Cast("const Nd4jLong*") LongBuffer shapeInfo, int dim); - @Namespace("shape") public static native int sizeAt(@Cast("const Nd4jLong*") long[] shapeInfo, int dim); - @Namespace("shape") public static native @Cast("Nd4jLong") long strideAt(@Cast("const Nd4jLong*") LongPointer shapeInfo, int dim); - @Namespace("shape") public static native @Cast("Nd4jLong") long strideAt(@Cast("const Nd4jLong*") LongBuffer shapeInfo, int dim); - @Namespace("shape") public static native @Cast("Nd4jLong") long strideAt(@Cast("const Nd4jLong*") long[] shapeInfo, int dim); - - @Namespace("shape") public static native void traceNew(int id); - - - @Namespace("shape") public static native int tadIndexForLinear(int linearIndex, int tadLength); - - @Namespace("shape") public static native @Cast("Nd4jLong") long tadLength(@Cast("const Nd4jLong*") LongPointer shapeInfo, IntPointer dimension, int dimensionLength); - @Namespace("shape") public static native @Cast("Nd4jLong") long tadLength(@Cast("const Nd4jLong*") LongBuffer shapeInfo, IntBuffer dimension, int dimensionLength); - @Namespace("shape") public static native @Cast("Nd4jLong") long tadLength(@Cast("const Nd4jLong*") long[] shapeInfo, int[] dimension, int dimensionLength); - - @Namespace("shape") public static native @Cast("bool") boolean canReshape(int oldRank, @Cast("Nd4jLong*") LongPointer oldShape, int newRank, @Cast("Nd4jLong*") LongPointer newShape, @Cast("bool") boolean isFOrder); - @Namespace("shape") public static native @Cast("bool") boolean canReshape(int oldRank, @Cast("Nd4jLong*") LongBuffer oldShape, int newRank, @Cast("Nd4jLong*") LongBuffer newShape, @Cast("bool") boolean isFOrder); - @Namespace("shape") public static native @Cast("bool") boolean canReshape(int oldRank, @Cast("Nd4jLong*") long[] oldShape, int newRank, @Cast("Nd4jLong*") long[] newShape, @Cast("bool") boolean isFOrder); - - @Namespace("shape") public static native @Cast("bool") boolean reshapeC(@Cast("const Nd4jLong*") LongPointer oldShapeInfo, byte newOrder, int newRank, @Cast("const Nd4jLong*") LongPointer newShape, @Cast("Nd4jLong*") LongPointer newShapeInfo); - @Namespace("shape") public static native @Cast("bool") boolean reshapeC(@Cast("const Nd4jLong*") LongBuffer oldShapeInfo, byte newOrder, int newRank, @Cast("const Nd4jLong*") LongBuffer newShape, @Cast("Nd4jLong*") LongBuffer newShapeInfo); - @Namespace("shape") public static native @Cast("bool") boolean reshapeC(@Cast("const Nd4jLong*") long[] oldShapeInfo, byte newOrder, int newRank, @Cast("const Nd4jLong*") long[] newShape, @Cast("Nd4jLong*") long[] newShapeInfo); - /** - * newShapeInfo contains rank, shape and order only, no strides/ews/type - */ - @Namespace("shape") public static native @Cast("bool") boolean reshapeC(@Cast("const Nd4jLong*") LongPointer oldShapeInfo, @Cast("Nd4jLong*") LongPointer newShapeInfo); - @Namespace("shape") public static native @Cast("bool") boolean reshapeC(@Cast("const Nd4jLong*") LongBuffer oldShapeInfo, @Cast("Nd4jLong*") LongBuffer newShapeInfo); - @Namespace("shape") public static native @Cast("bool") boolean reshapeC(@Cast("const Nd4jLong*") long[] oldShapeInfo, @Cast("Nd4jLong*") long[] newShapeInfo); - - /** - * Get the shape info buffer - * for the given rank and shape. - */ - @Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer shapeBuffer(int rank, @Cast("sd::DataType") int dtype, @Cast("const Nd4jLong*") LongPointer shape); - @Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer shapeBuffer(int rank, @Cast("sd::DataType") int dtype, @Cast("const Nd4jLong*") LongBuffer shape); - @Namespace("shape") public static native @Cast("Nd4jLong*") long[] shapeBuffer(int rank, @Cast("sd::DataType") int dtype, @Cast("const Nd4jLong*") long[] shape); - - @Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer shapeBuffer(int rank, @Cast("sd::DataType") int dtype, @Cast("const Nd4jLong*") LongPointer shape, @Cast("Nd4jLong*") LongPointer buffer); - @Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer shapeBuffer(int rank, @Cast("sd::DataType") int dtype, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("Nd4jLong*") LongBuffer buffer); - @Namespace("shape") public static native @Cast("Nd4jLong*") long[] shapeBuffer(int rank, @Cast("sd::DataType") int dtype, @Cast("const Nd4jLong*") long[] shape, @Cast("Nd4jLong*") long[] buffer); +/** + * Get the shape info buffer + * for the given rank and shape. + */ +@Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer shapeBuffer(int rank, @Cast("sd::DataType") int dtype, + @Cast("const Nd4jLong*") LongPointer shape); +@Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer shapeBuffer(int rank, @Cast("sd::DataType") int dtype, + @Cast("const Nd4jLong*") LongBuffer shape); +@Namespace("shape") public static native @Cast("Nd4jLong*") long[] shapeBuffer(int rank, @Cast("sd::DataType") int dtype, + @Cast("const Nd4jLong*") long[] shape); + +@Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer shapeBuffer(int rank, @Cast("sd::DataType") int dtype, + @Cast("const Nd4jLong*") LongPointer shape, + @Cast("Nd4jLong*") LongPointer buffer); +@Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer shapeBuffer(int rank, @Cast("sd::DataType") int dtype, + @Cast("const Nd4jLong*") LongBuffer shape, + @Cast("Nd4jLong*") LongBuffer buffer); +@Namespace("shape") public static native @Cast("Nd4jLong*") long[] shapeBuffer(int rank, @Cast("sd::DataType") int dtype, + @Cast("const Nd4jLong*") long[] shape, + @Cast("Nd4jLong*") long[] buffer); - /** - * Get the shape info buffer - * for the given rank and shape. - */ - @Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer shapeBufferFortran(int rank, @Cast("sd::DataType") int dtype, @Cast("const Nd4jLong*") LongPointer shape); - @Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer shapeBufferFortran(int rank, @Cast("sd::DataType") int dtype, @Cast("const Nd4jLong*") LongBuffer shape); - @Namespace("shape") public static native @Cast("Nd4jLong*") long[] shapeBufferFortran(int rank, @Cast("sd::DataType") int dtype, @Cast("const Nd4jLong*") long[] shape); - - @Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer shapeBufferFortran(int rank, @Cast("sd::DataType") int dtype, @Cast("const Nd4jLong*") LongPointer shape, @Cast("Nd4jLong*") LongPointer output); - @Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer shapeBufferFortran(int rank, @Cast("sd::DataType") int dtype, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("Nd4jLong*") LongBuffer output); - @Namespace("shape") public static native @Cast("Nd4jLong*") long[] shapeBufferFortran(int rank, @Cast("sd::DataType") int dtype, @Cast("const Nd4jLong*") long[] shape, @Cast("Nd4jLong*") long[] output); +/** + * Get the shape info buffer + * for the given rank and shape. + */ +@Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer shapeBufferFortran(int rank, @Cast("sd::DataType") int dtype, + @Cast("const Nd4jLong*") LongPointer shape); +@Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer shapeBufferFortran(int rank, @Cast("sd::DataType") int dtype, + @Cast("const Nd4jLong*") LongBuffer shape); +@Namespace("shape") public static native @Cast("Nd4jLong*") long[] shapeBufferFortran(int rank, @Cast("sd::DataType") int dtype, + @Cast("const Nd4jLong*") long[] shape); + +@Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer shapeBufferFortran(int rank, @Cast("sd::DataType") int dtype, + @Cast("const Nd4jLong*") LongPointer shape, + @Cast("Nd4jLong*") LongPointer output); +@Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer shapeBufferFortran(int rank, @Cast("sd::DataType") int dtype, + @Cast("const Nd4jLong*") LongBuffer shape, + @Cast("Nd4jLong*") LongBuffer output); +@Namespace("shape") public static native @Cast("Nd4jLong*") long[] shapeBufferFortran(int rank, @Cast("sd::DataType") int dtype, + @Cast("const Nd4jLong*") long[] shape, + @Cast("Nd4jLong*") long[] output); // #ifdef __CUDACC__ // #endif - - /** * Computes the standard packed array strides for a given shape. * @@ -7101,13 +7653,19 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint * @param startNum the start number for the strides * @return the strides for a matrix of n dimensions */ - @Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer calcStridesFortran(@Cast("const Nd4jLong*") LongPointer shape, int rank); - @Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer calcStridesFortran(@Cast("const Nd4jLong*") LongBuffer shape, int rank); - @Namespace("shape") public static native @Cast("Nd4jLong*") long[] calcStridesFortran(@Cast("const Nd4jLong*") long[] shape, int rank); - - @Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer calcStridesFortran(@Cast("const Nd4jLong*") LongPointer shape, int rank, @Cast("Nd4jLong*") LongPointer ret); - @Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer calcStridesFortran(@Cast("const Nd4jLong*") LongBuffer shape, int rank, @Cast("Nd4jLong*") LongBuffer ret); - @Namespace("shape") public static native @Cast("Nd4jLong*") long[] calcStridesFortran(@Cast("const Nd4jLong*") long[] shape, int rank, @Cast("Nd4jLong*") long[] ret); +@Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer calcStridesFortran(@Cast("const Nd4jLong*") LongPointer shape, + int rank); +@Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer calcStridesFortran(@Cast("const Nd4jLong*") LongBuffer shape, + int rank); +@Namespace("shape") public static native @Cast("Nd4jLong*") long[] calcStridesFortran(@Cast("const Nd4jLong*") long[] shape, + int rank); + +@Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer calcStridesFortran(@Cast("const Nd4jLong*") LongPointer shape, int rank, + @Cast("Nd4jLong*") LongPointer ret); +@Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer calcStridesFortran(@Cast("const Nd4jLong*") LongBuffer shape, int rank, + @Cast("Nd4jLong*") LongBuffer ret); +@Namespace("shape") public static native @Cast("Nd4jLong*") long[] calcStridesFortran(@Cast("const Nd4jLong*") long[] shape, int rank, + @Cast("Nd4jLong*") long[] ret); /** * Computes the standard packed array strides for a given shape. @@ -7117,23 +7675,29 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint * @return the strides for a matrix of n dimensions */ - @Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer calcStrides(@Cast("const Nd4jLong*") LongPointer shape, int rank); - @Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer calcStrides(@Cast("const Nd4jLong*") LongBuffer shape, int rank); - @Namespace("shape") public static native @Cast("Nd4jLong*") long[] calcStrides(@Cast("const Nd4jLong*") long[] shape, int rank); - - @Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer calcStrides(@Cast("const Nd4jLong*") LongPointer shape, int rank, @Cast("Nd4jLong*") LongPointer ret); - @Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer calcStrides(@Cast("const Nd4jLong*") LongBuffer shape, int rank, @Cast("Nd4jLong*") LongBuffer ret); - @Namespace("shape") public static native @Cast("Nd4jLong*") long[] calcStrides(@Cast("const Nd4jLong*") long[] shape, int rank, @Cast("Nd4jLong*") long[] ret); - - @Namespace("shape") public static native void updateStrides(@Cast("Nd4jLong*") LongPointer shape, byte order); - @Namespace("shape") public static native void updateStrides(@Cast("Nd4jLong*") LongBuffer shape, byte order); - @Namespace("shape") public static native void updateStrides(@Cast("Nd4jLong*") long[] shape, byte order); - @Namespace("shape") public static native void updateStrides(int rank, @Cast("const Nd4jLong*") LongPointer shapeOnly, @Cast("Nd4jLong*") LongPointer stridesOnly, byte order); - @Namespace("shape") public static native void updateStrides(int rank, @Cast("const Nd4jLong*") LongBuffer shapeOnly, @Cast("Nd4jLong*") LongBuffer stridesOnly, byte order); - @Namespace("shape") public static native void updateStrides(int rank, @Cast("const Nd4jLong*") long[] shapeOnly, @Cast("Nd4jLong*") long[] stridesOnly, byte order); - - -// check whether input dimensions are permuted, not permuted dimensions order have to be 0,....,rank-1 +@Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer calcStrides(@Cast("const Nd4jLong*") LongPointer shape, int rank); +@Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer calcStrides(@Cast("const Nd4jLong*") LongBuffer shape, int rank); +@Namespace("shape") public static native @Cast("Nd4jLong*") long[] calcStrides(@Cast("const Nd4jLong*") long[] shape, int rank); + +@Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer calcStrides(@Cast("const Nd4jLong*") LongPointer shape, int rank, + @Cast("Nd4jLong*") LongPointer ret); +@Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer calcStrides(@Cast("const Nd4jLong*") LongBuffer shape, int rank, + @Cast("Nd4jLong*") LongBuffer ret); +@Namespace("shape") public static native @Cast("Nd4jLong*") long[] calcStrides(@Cast("const Nd4jLong*") long[] shape, int rank, + @Cast("Nd4jLong*") long[] ret); + +@Namespace("shape") public static native void updateStrides(@Cast("Nd4jLong*") LongPointer shape, byte order); +@Namespace("shape") public static native void updateStrides(@Cast("Nd4jLong*") LongBuffer shape, byte order); +@Namespace("shape") public static native void updateStrides(@Cast("Nd4jLong*") long[] shape, byte order); +@Namespace("shape") public static native void updateStrides(int rank, @Cast("const Nd4jLong*") LongPointer shapeOnly, + @Cast("Nd4jLong*") LongPointer stridesOnly, byte order); +@Namespace("shape") public static native void updateStrides(int rank, @Cast("const Nd4jLong*") LongBuffer shapeOnly, + @Cast("Nd4jLong*") LongBuffer stridesOnly, byte order); +@Namespace("shape") public static native void updateStrides(int rank, @Cast("const Nd4jLong*") long[] shapeOnly, + @Cast("Nd4jLong*") long[] stridesOnly, byte order); + +// check whether input dimensions are permuted, not permuted dimensions order +// have to be 0,....,rank-1 /** * Computes the standard packed array strides for a given shape. @@ -7142,13 +7706,19 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint * @param startNum the start number for the strides * @return the strides for a matrix of n dimensions */ - @Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer calcStridesFortran(@Cast("const Nd4jLong*") LongPointer shape, int rank, int startNum); - @Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer calcStridesFortran(@Cast("const Nd4jLong*") LongBuffer shape, int rank, int startNum); - @Namespace("shape") public static native @Cast("Nd4jLong*") long[] calcStridesFortran(@Cast("const Nd4jLong*") long[] shape, int rank, int startNum); - - @Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer calcStridesFortran(@Cast("const Nd4jLong*") LongPointer shape, int rank, int startNum, @Cast("Nd4jLong*") LongPointer ret); - @Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer calcStridesFortran(@Cast("const Nd4jLong*") LongBuffer shape, int rank, int startNum, @Cast("Nd4jLong*") LongBuffer ret); - @Namespace("shape") public static native @Cast("Nd4jLong*") long[] calcStridesFortran(@Cast("const Nd4jLong*") long[] shape, int rank, int startNum, @Cast("Nd4jLong*") long[] ret); +@Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer calcStridesFortran(@Cast("const Nd4jLong*") LongPointer shape, int rank, + int startNum); +@Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer calcStridesFortran(@Cast("const Nd4jLong*") LongBuffer shape, int rank, + int startNum); +@Namespace("shape") public static native @Cast("Nd4jLong*") long[] calcStridesFortran(@Cast("const Nd4jLong*") long[] shape, int rank, + int startNum); + +@Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer calcStridesFortran(@Cast("const Nd4jLong*") LongPointer shape, int rank, + int startNum, @Cast("Nd4jLong*") LongPointer ret); +@Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer calcStridesFortran(@Cast("const Nd4jLong*") LongBuffer shape, int rank, + int startNum, @Cast("Nd4jLong*") LongBuffer ret); +@Namespace("shape") public static native @Cast("Nd4jLong*") long[] calcStridesFortran(@Cast("const Nd4jLong*") long[] shape, int rank, + int startNum, @Cast("Nd4jLong*") long[] ret); /** * Computes the standard packed array strides for a given shape. @@ -7157,38 +7727,44 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint * @param startNum the start number for the strides * @return the strides for a matrix of n dimensions */ - @Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer calcStrides(@Cast("const Nd4jLong*") LongPointer shape, int rank, int startNum); - @Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer calcStrides(@Cast("const Nd4jLong*") LongBuffer shape, int rank, int startNum); - @Namespace("shape") public static native @Cast("Nd4jLong*") long[] calcStrides(@Cast("const Nd4jLong*") long[] shape, int rank, int startNum); - - @Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer calcStrides(@Cast("const Nd4jLong*") LongPointer shape, int rank, int startNum, @Cast("Nd4jLong*") LongPointer ret); - @Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer calcStrides(@Cast("const Nd4jLong*") LongBuffer shape, int rank, int startNum, @Cast("Nd4jLong*") LongBuffer ret); - @Namespace("shape") public static native @Cast("Nd4jLong*") long[] calcStrides(@Cast("const Nd4jLong*") long[] shape, int rank, int startNum, @Cast("Nd4jLong*") long[] ret); +@Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer calcStrides(@Cast("const Nd4jLong*") LongPointer shape, int rank, + int startNum); +@Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer calcStrides(@Cast("const Nd4jLong*") LongBuffer shape, int rank, + int startNum); +@Namespace("shape") public static native @Cast("Nd4jLong*") long[] calcStrides(@Cast("const Nd4jLong*") long[] shape, int rank, + int startNum); + +@Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer calcStrides(@Cast("const Nd4jLong*") LongPointer shape, int rank, + int startNum, @Cast("Nd4jLong*") LongPointer ret); +@Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer calcStrides(@Cast("const Nd4jLong*") LongBuffer shape, int rank, + int startNum, @Cast("Nd4jLong*") LongBuffer ret); +@Namespace("shape") public static native @Cast("Nd4jLong*") long[] calcStrides(@Cast("const Nd4jLong*") long[] shape, int rank, + int startNum, @Cast("Nd4jLong*") long[] ret); /** * @param toCopy the shape to copy * @return a copy of the original struct */ - @Namespace("shape") public static native ShapeInformation shapeCopy( ShapeInformation toCopy); - - - @Namespace("shape") public static native @Cast("bool") boolean strideDescendingCAscendingF(@Cast("const Nd4jLong*") LongPointer shapeBuffer); - @Namespace("shape") public static native @Cast("bool") boolean strideDescendingCAscendingF(@Cast("const Nd4jLong*") LongBuffer shapeBuffer); - @Namespace("shape") public static native @Cast("bool") boolean strideDescendingCAscendingF(@Cast("const Nd4jLong*") long[] shapeBuffer); +@Namespace("shape") public static native ShapeInformation shapeCopy(ShapeInformation toCopy); - @Namespace("shape") public static native @Cast("bool") boolean isContiguous(@Cast("const Nd4jLong*") LongPointer shapeInfo); - @Namespace("shape") public static native @Cast("bool") boolean isContiguous(@Cast("const Nd4jLong*") LongBuffer shapeInfo); - @Namespace("shape") public static native @Cast("bool") boolean isContiguous(@Cast("const Nd4jLong*") long[] shapeInfo); +@Namespace("shape") public static native @Cast("bool") boolean strideDescendingCAscendingF( + @Cast("const Nd4jLong*") LongPointer shapeBuffer); +@Namespace("shape") public static native @Cast("bool") boolean strideDescendingCAscendingF( + @Cast("const Nd4jLong*") LongBuffer shapeBuffer); +@Namespace("shape") public static native @Cast("bool") boolean strideDescendingCAscendingF( + @Cast("const Nd4jLong*") long[] shapeBuffer); +@Namespace("shape") public static native @Cast("bool") boolean isContiguous(@Cast("const Nd4jLong*") LongPointer shapeInfo); +@Namespace("shape") public static native @Cast("bool") boolean isContiguous(@Cast("const Nd4jLong*") LongBuffer shapeInfo); +@Namespace("shape") public static native @Cast("bool") boolean isContiguous(@Cast("const Nd4jLong*") long[] shapeInfo); /** * copy-past from java hasDefaultStridesForShape function * check whether array is not permuted and has contiguous elements in memory */ - @Namespace("shape") public static native @Cast("bool") boolean areStridesDefault(@Cast("const Nd4jLong*") LongPointer shapeInfo); - @Namespace("shape") public static native @Cast("bool") boolean areStridesDefault(@Cast("const Nd4jLong*") LongBuffer shapeInfo); - @Namespace("shape") public static native @Cast("bool") boolean areStridesDefault(@Cast("const Nd4jLong*") long[] shapeInfo); - +@Namespace("shape") public static native @Cast("bool") boolean areStridesDefault(@Cast("const Nd4jLong*") LongPointer shapeInfo); +@Namespace("shape") public static native @Cast("bool") boolean areStridesDefault(@Cast("const Nd4jLong*") LongBuffer shapeInfo); +@Namespace("shape") public static native @Cast("bool") boolean areStridesDefault(@Cast("const Nd4jLong*") long[] shapeInfo); /** * Compute the element wise stride @@ -7201,9 +7777,15 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint * @return 0 if there is no element wise stride the * element wise stride of reshape(1,length) otherwise */ - @Namespace("shape") public static native int computeElementWiseStride(int rank, @Cast("const Nd4jLong*") LongPointer shape, @Cast("const Nd4jLong*") LongPointer stride, int isFOrder); - @Namespace("shape") public static native int computeElementWiseStride(int rank, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("const Nd4jLong*") LongBuffer stride, int isFOrder); - @Namespace("shape") public static native int computeElementWiseStride(int rank, @Cast("const Nd4jLong*") long[] shape, @Cast("const Nd4jLong*") long[] stride, int isFOrder); +@Namespace("shape") public static native int computeElementWiseStride(int rank, @Cast("const Nd4jLong*") LongPointer shape, + @Cast("const Nd4jLong*") LongPointer stride, + int isFOrder); +@Namespace("shape") public static native int computeElementWiseStride(int rank, @Cast("const Nd4jLong*") LongBuffer shape, + @Cast("const Nd4jLong*") LongBuffer stride, + int isFOrder); +@Namespace("shape") public static native int computeElementWiseStride(int rank, @Cast("const Nd4jLong*") long[] shape, + @Cast("const Nd4jLong*") long[] stride, + int isFOrder); /** * Compute the element wise stride @@ -7216,17 +7798,41 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint * @return 0 if there is no element wise stride the * element wise stride of reshape(1,length) otherwise */ - @Namespace("shape") public static native int computeElementWiseStride(int rank, @Cast("const Nd4jLong*") LongPointer shape, @Cast("const Nd4jLong*") LongPointer stride, int isFOrder, @Cast("const Nd4jLong*") LongPointer dimension, int dimensionLength); - @Namespace("shape") public static native int computeElementWiseStride(int rank, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("const Nd4jLong*") LongBuffer stride, int isFOrder, @Cast("const Nd4jLong*") LongBuffer dimension, int dimensionLength); - @Namespace("shape") public static native int computeElementWiseStride(int rank, @Cast("const Nd4jLong*") long[] shape, @Cast("const Nd4jLong*") long[] stride, int isFOrder, @Cast("const Nd4jLong*") long[] dimension, int dimensionLength); - - @Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer shapeInfoOnlyShapeAndStride(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("Nd4jLong*") LongPointer dimension, int dimensionLength,@Cast("bool") boolean reverseCopyStride); - @Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer shapeInfoOnlyShapeAndStride(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("Nd4jLong*") LongBuffer dimension, int dimensionLength,@Cast("bool") boolean reverseCopyStride); - @Namespace("shape") public static native @Cast("Nd4jLong*") long[] shapeInfoOnlyShapeAndStride(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("Nd4jLong*") long[] dimension, int dimensionLength,@Cast("bool") boolean reverseCopyStride); - - @Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer shapeInfoOnlyShapeAndStride(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("Nd4jLong*") LongPointer dimension, int dimensionLength,@Cast("bool") boolean reverseCopyStride, @Cast("Nd4jLong*") LongPointer buffer); - @Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer shapeInfoOnlyShapeAndStride(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("Nd4jLong*") LongBuffer dimension, int dimensionLength,@Cast("bool") boolean reverseCopyStride, @Cast("Nd4jLong*") LongBuffer buffer); - @Namespace("shape") public static native @Cast("Nd4jLong*") long[] shapeInfoOnlyShapeAndStride(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("Nd4jLong*") long[] dimension, int dimensionLength,@Cast("bool") boolean reverseCopyStride, @Cast("Nd4jLong*") long[] buffer); +@Namespace("shape") public static native int computeElementWiseStride(int rank, @Cast("const Nd4jLong*") LongPointer shape, + @Cast("const Nd4jLong*") LongPointer stride, + int isFOrder, + @Cast("const Nd4jLong*") LongPointer dimension, + int dimensionLength); +@Namespace("shape") public static native int computeElementWiseStride(int rank, @Cast("const Nd4jLong*") LongBuffer shape, + @Cast("const Nd4jLong*") LongBuffer stride, + int isFOrder, + @Cast("const Nd4jLong*") LongBuffer dimension, + int dimensionLength); +@Namespace("shape") public static native int computeElementWiseStride(int rank, @Cast("const Nd4jLong*") long[] shape, + @Cast("const Nd4jLong*") long[] stride, + int isFOrder, + @Cast("const Nd4jLong*") long[] dimension, + int dimensionLength); + +@Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer shapeInfoOnlyShapeAndStride( + @Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("Nd4jLong*") LongPointer dimension, int dimensionLength, + @Cast("bool") boolean reverseCopyStride); +@Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer shapeInfoOnlyShapeAndStride( + @Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("Nd4jLong*") LongBuffer dimension, int dimensionLength, + @Cast("bool") boolean reverseCopyStride); +@Namespace("shape") public static native @Cast("Nd4jLong*") long[] shapeInfoOnlyShapeAndStride( + @Cast("const Nd4jLong*") long[] shapeInfo, @Cast("Nd4jLong*") long[] dimension, int dimensionLength, + @Cast("bool") boolean reverseCopyStride); + +@Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer shapeInfoOnlyShapeAndStride( + @Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("Nd4jLong*") LongPointer dimension, int dimensionLength, + @Cast("bool") boolean reverseCopyStride, @Cast("Nd4jLong*") LongPointer buffer); +@Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer shapeInfoOnlyShapeAndStride( + @Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("Nd4jLong*") LongBuffer dimension, int dimensionLength, + @Cast("bool") boolean reverseCopyStride, @Cast("Nd4jLong*") LongBuffer buffer); +@Namespace("shape") public static native @Cast("Nd4jLong*") long[] shapeInfoOnlyShapeAndStride( + @Cast("const Nd4jLong*") long[] shapeInfo, @Cast("Nd4jLong*") long[] dimension, int dimensionLength, + @Cast("bool") boolean reverseCopyStride, @Cast("Nd4jLong*") long[] buffer); /** * * @param length @@ -7234,11 +7840,12 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint * @param rearrange * @return */ - @Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer doPermuteSwap(int length, @Cast("Nd4jLong*") LongPointer shape, IntPointer rearrange); - @Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer doPermuteSwap(int length, @Cast("Nd4jLong*") LongBuffer shape, IntBuffer rearrange); - @Namespace("shape") public static native @Cast("Nd4jLong*") long[] doPermuteSwap(int length, @Cast("Nd4jLong*") long[] shape, int[] rearrange); - - +@Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer doPermuteSwap(int length, @Cast("Nd4jLong*") LongPointer shape, + IntPointer rearrange); +@Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer doPermuteSwap(int length, @Cast("Nd4jLong*") LongBuffer shape, + IntBuffer rearrange); +@Namespace("shape") public static native @Cast("Nd4jLong*") long[] doPermuteSwap(int length, @Cast("Nd4jLong*") long[] shape, + int[] rearrange); /** * In place permute swap @@ -7246,55 +7853,82 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint * @param shape * @param rearrange */ - @Namespace("shape") public static native void doPermuteSwap(int length, @Cast("Nd4jLong**") PointerPointer shape, IntPointer rearrange); - - @Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer permuteShapeBuffer(@Cast("const Nd4jLong*") LongPointer shapeBuffer, IntPointer rearrange); - @Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer permuteShapeBuffer(@Cast("const Nd4jLong*") LongBuffer shapeBuffer, IntBuffer rearrange); - @Namespace("shape") public static native @Cast("Nd4jLong*") long[] permuteShapeBuffer(@Cast("const Nd4jLong*") long[] shapeBuffer, int[] rearrange); - - @Namespace("shape") public static native void permuteShapeBufferInPlace(@Cast("Nd4jLong*") LongPointer shapeBuffer, IntPointer rearrange, @Cast("Nd4jLong*") LongPointer out); - @Namespace("shape") public static native void permuteShapeBufferInPlace(@Cast("Nd4jLong*") LongBuffer shapeBuffer, IntBuffer rearrange, @Cast("Nd4jLong*") LongBuffer out); - @Namespace("shape") public static native void permuteShapeBufferInPlace(@Cast("Nd4jLong*") long[] shapeBuffer, int[] rearrange, @Cast("Nd4jLong*") long[] out); - - @Namespace("shape") public static native void doPermuteShapeInfo(@Cast("Nd4jLong*") LongPointer shapeBuffer, @Const IntPointer rearrange, @Cast("Nd4jLong") long len/*=-1*/); - @Namespace("shape") public static native void doPermuteShapeInfo(@Cast("Nd4jLong*") LongPointer shapeBuffer, @Const IntPointer rearrange); - @Namespace("shape") public static native void doPermuteShapeInfo(@Cast("Nd4jLong*") LongBuffer shapeBuffer, @Const IntBuffer rearrange, @Cast("Nd4jLong") long len/*=-1*/); - @Namespace("shape") public static native void doPermuteShapeInfo(@Cast("Nd4jLong*") LongBuffer shapeBuffer, @Const IntBuffer rearrange); - @Namespace("shape") public static native void doPermuteShapeInfo(@Cast("Nd4jLong*") long[] shapeBuffer, @Const int[] rearrange, @Cast("Nd4jLong") long len/*=-1*/); - @Namespace("shape") public static native void doPermuteShapeInfo(@Cast("Nd4jLong*") long[] shapeBuffer, @Const int[] rearrange); - - /** - * Rearrange the permute indexes - * according to which dimensions are specified. - * - * For example, dimension is implicitly: - * 0,1,2 - * - * If you want to do a reduce along dimensions 0 and 1, - * you need to permute the indexes to be: - * 2,0,1 - * - * which will give us the ability to ierate along an element - * wise stride. - */ - - @Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer createPermuteIndexes(int originalRank, IntPointer dimension,int dimensionLength); - @Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer createPermuteIndexes(int originalRank, IntBuffer dimension,int dimensionLength); - @Namespace("shape") public static native @Cast("Nd4jLong*") long[] createPermuteIndexes(int originalRank, int[] dimension,int dimensionLength); - - @Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer computeResultShape(@Cast("const Nd4jLong*") LongPointer originalShapeBuffer, IntPointer dimension,int dimensionLength); - @Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer computeResultShape(@Cast("const Nd4jLong*") LongBuffer originalShapeBuffer, IntBuffer dimension,int dimensionLength); - @Namespace("shape") public static native @Cast("Nd4jLong*") long[] computeResultShape(@Cast("const Nd4jLong*") long[] originalShapeBuffer, int[] dimension,int dimensionLength); - - /** - * This method does inplace transpose of given shapeBuffer - * - * @param shapeBuffer - */ - @Namespace("shape") public static native void transposeInplace(@Cast("Nd4jLong*") LongPointer shapeBuffer); - @Namespace("shape") public static native void transposeInplace(@Cast("Nd4jLong*") LongBuffer shapeBuffer); - @Namespace("shape") public static native void transposeInplace(@Cast("Nd4jLong*") long[] shapeBuffer); +@Namespace("shape") public static native void doPermuteSwap(int length, @Cast("Nd4jLong**") PointerPointer shape, + IntPointer rearrange); + +@Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer permuteShapeBuffer(@Cast("const Nd4jLong*") LongPointer shapeBuffer, + IntPointer rearrange); +@Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer permuteShapeBuffer(@Cast("const Nd4jLong*") LongBuffer shapeBuffer, + IntBuffer rearrange); +@Namespace("shape") public static native @Cast("Nd4jLong*") long[] permuteShapeBuffer(@Cast("const Nd4jLong*") long[] shapeBuffer, + int[] rearrange); + +@Namespace("shape") public static native void permuteShapeBufferInPlace(@Cast("Nd4jLong*") LongPointer shapeBuffer, + IntPointer rearrange, + @Cast("Nd4jLong*") LongPointer out); +@Namespace("shape") public static native void permuteShapeBufferInPlace(@Cast("Nd4jLong*") LongBuffer shapeBuffer, + IntBuffer rearrange, + @Cast("Nd4jLong*") LongBuffer out); +@Namespace("shape") public static native void permuteShapeBufferInPlace(@Cast("Nd4jLong*") long[] shapeBuffer, + int[] rearrange, + @Cast("Nd4jLong*") long[] out); + +@Namespace("shape") public static native void doPermuteShapeInfo(@Cast("Nd4jLong*") LongPointer shapeBuffer, + @Const IntPointer rearrange, + @Cast("Nd4jLong") long len/*=-1*/); +@Namespace("shape") public static native void doPermuteShapeInfo(@Cast("Nd4jLong*") LongPointer shapeBuffer, + @Const IntPointer rearrange); +@Namespace("shape") public static native void doPermuteShapeInfo(@Cast("Nd4jLong*") LongBuffer shapeBuffer, + @Const IntBuffer rearrange, + @Cast("Nd4jLong") long len/*=-1*/); +@Namespace("shape") public static native void doPermuteShapeInfo(@Cast("Nd4jLong*") LongBuffer shapeBuffer, + @Const IntBuffer rearrange); +@Namespace("shape") public static native void doPermuteShapeInfo(@Cast("Nd4jLong*") long[] shapeBuffer, + @Const int[] rearrange, + @Cast("Nd4jLong") long len/*=-1*/); +@Namespace("shape") public static native void doPermuteShapeInfo(@Cast("Nd4jLong*") long[] shapeBuffer, + @Const int[] rearrange); + +/** + * Rearrange the permute indexes + * according to which dimensions are specified. + * + * For example, dimension is implicitly: + * 0,1,2 + * + * If you want to do a reduce along dimensions 0 and 1, + * you need to permute the indexes to be: + * 2,0,1 + * + * which will give us the ability to ierate along an element + * wise stride. + */ + +@Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer createPermuteIndexes(int originalRank, + IntPointer dimension, + int dimensionLength); +@Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer createPermuteIndexes(int originalRank, + IntBuffer dimension, + int dimensionLength); +@Namespace("shape") public static native @Cast("Nd4jLong*") long[] createPermuteIndexes(int originalRank, + int[] dimension, + int dimensionLength); + +@Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer computeResultShape( + @Cast("const Nd4jLong*") LongPointer originalShapeBuffer, IntPointer dimension, int dimensionLength); +@Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer computeResultShape( + @Cast("const Nd4jLong*") LongBuffer originalShapeBuffer, IntBuffer dimension, int dimensionLength); +@Namespace("shape") public static native @Cast("Nd4jLong*") long[] computeResultShape( + @Cast("const Nd4jLong*") long[] originalShapeBuffer, int[] dimension, int dimensionLength); +/** + * This method does inplace transpose of given shapeBuffer + * + * @param shapeBuffer + */ +@Namespace("shape") public static native void transposeInplace(@Cast("Nd4jLong*") LongPointer shapeBuffer); +@Namespace("shape") public static native void transposeInplace(@Cast("Nd4jLong*") LongBuffer shapeBuffer); +@Namespace("shape") public static native void transposeInplace(@Cast("Nd4jLong*") long[] shapeBuffer); /** * Get the ordering for the device @@ -7304,9 +7938,12 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint * @param elementStride * @return */ - @Namespace("shape") public static native char getOrder(int length, @Cast("Nd4jLong*") LongPointer shape, @Cast("Nd4jLong*") LongPointer stride, int elementStride); - @Namespace("shape") public static native char getOrder(int length, @Cast("Nd4jLong*") LongBuffer shape, @Cast("Nd4jLong*") LongBuffer stride, int elementStride); - @Namespace("shape") public static native char getOrder(int length, @Cast("Nd4jLong*") long[] shape, @Cast("Nd4jLong*") long[] stride, int elementStride); +@Namespace("shape") public static native char getOrder(int length, @Cast("Nd4jLong*") LongPointer shape, @Cast("Nd4jLong*") LongPointer stride, + int elementStride); +@Namespace("shape") public static native char getOrder(int length, @Cast("Nd4jLong*") LongBuffer shape, @Cast("Nd4jLong*") LongBuffer stride, + int elementStride); +@Namespace("shape") public static native char getOrder(int length, @Cast("Nd4jLong*") long[] shape, @Cast("Nd4jLong*") long[] stride, + int elementStride); /** * Ensure that every value in the re arrange @@ -7324,10 +7961,14 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint * @param rearrange the order to re arrange * @param rank the rank of the rearrange array */ - @Namespace("shape") public static native void permute(@Cast("shape::ShapeInformation**") PointerPointer info, IntPointer rearrange, int rank); - @Namespace("shape") public static native void permute(@ByPtrPtr ShapeInformation info, IntPointer rearrange, int rank); - @Namespace("shape") public static native void permute(@ByPtrPtr ShapeInformation info, IntBuffer rearrange, int rank); - @Namespace("shape") public static native void permute(@ByPtrPtr ShapeInformation info, int[] rearrange, int rank); +@Namespace("shape") public static native void permute(@Cast("shape::ShapeInformation**") PointerPointer info, IntPointer rearrange, + int rank); +@Namespace("shape") public static native void permute(@ByPtrPtr ShapeInformation info, IntPointer rearrange, + int rank); +@Namespace("shape") public static native void permute(@ByPtrPtr ShapeInformation info, IntBuffer rearrange, + int rank); +@Namespace("shape") public static native void permute(@ByPtrPtr ShapeInformation info, int[] rearrange, + int rank); /** * Returns whether the @@ -7335,72 +7976,80 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint * @param shape the shape of the array * @param rank the rank of cthe shape */ - @Namespace("shape") public static native int isVector(@Cast("const Nd4jLong*") LongPointer shape, int rank); - @Namespace("shape") public static native int isVector(@Cast("const Nd4jLong*") LongBuffer shape, int rank); - @Namespace("shape") public static native int isVector(@Cast("const Nd4jLong*") long[] shape, int rank); - - - /** - * When 1 dimension is the whole length of the - * array - */ - @Namespace("shape") public static native int oneDimEqualToLength(@Cast("Nd4jLong*") LongPointer shape, int rank); - @Namespace("shape") public static native int oneDimEqualToLength(@Cast("Nd4jLong*") LongBuffer shape, int rank); - @Namespace("shape") public static native int oneDimEqualToLength(@Cast("Nd4jLong*") long[] shape, int rank); - - @Namespace("shape") public static native int oneDimEqualToLength(@Cast("Nd4jLong*") LongPointer shapeInfo); - @Namespace("shape") public static native int oneDimEqualToLength(@Cast("Nd4jLong*") LongBuffer shapeInfo); - @Namespace("shape") public static native int oneDimEqualToLength(@Cast("Nd4jLong*") long[] shapeInfo); - - @Namespace("shape") public static native int isVector(@Cast("const Nd4jLong*") LongPointer shapeInfo); - @Namespace("shape") public static native int isVector(@Cast("const Nd4jLong*") LongBuffer shapeInfo); - @Namespace("shape") public static native int isVector(@Cast("const Nd4jLong*") long[] shapeInfo); - - @Namespace("shape") public static native @Cast("bool") boolean isLikeVector(@Cast("const Nd4jLong*") LongPointer shapeInfo, @ByRef IntPointer posOfNonUnityDim); - @Namespace("shape") public static native @Cast("bool") boolean isLikeVector(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @ByRef IntBuffer posOfNonUnityDim); - @Namespace("shape") public static native @Cast("bool") boolean isLikeVector(@Cast("const Nd4jLong*") long[] shapeInfo, @ByRef int[] posOfNonUnityDim); +@Namespace("shape") public static native int isVector(@Cast("const Nd4jLong*") LongPointer shape, int rank); +@Namespace("shape") public static native int isVector(@Cast("const Nd4jLong*") LongBuffer shape, int rank); +@Namespace("shape") public static native int isVector(@Cast("const Nd4jLong*") long[] shape, int rank); - @Namespace("shape") public static native @Cast("bool") boolean isCommonVector(@Cast("const Nd4jLong*") LongPointer shapeInfo, @ByRef IntPointer posOfNonUnityDim); - @Namespace("shape") public static native @Cast("bool") boolean isCommonVector(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @ByRef IntBuffer posOfNonUnityDim); - @Namespace("shape") public static native @Cast("bool") boolean isCommonVector(@Cast("const Nd4jLong*") long[] shapeInfo, @ByRef int[] posOfNonUnityDim); - - @Namespace("shape") public static native @Cast("bool") boolean isRowVector(@Cast("const Nd4jLong*") LongPointer shapeInfo); - @Namespace("shape") public static native @Cast("bool") boolean isRowVector(@Cast("const Nd4jLong*") LongBuffer shapeInfo); - @Namespace("shape") public static native @Cast("bool") boolean isRowVector(@Cast("const Nd4jLong*") long[] shapeInfo); - - @Namespace("shape") public static native @Cast("bool") boolean isColumnVector(@Cast("const Nd4jLong*") LongPointer shapeInfo); - @Namespace("shape") public static native @Cast("bool") boolean isColumnVector(@Cast("const Nd4jLong*") LongBuffer shapeInfo); - @Namespace("shape") public static native @Cast("bool") boolean isColumnVector(@Cast("const Nd4jLong*") long[] shapeInfo); +/** + * When 1 dimension is the whole length of the + * array + */ +@Namespace("shape") public static native int oneDimEqualToLength(@Cast("Nd4jLong*") LongPointer shape, int rank); +@Namespace("shape") public static native int oneDimEqualToLength(@Cast("Nd4jLong*") LongBuffer shape, int rank); +@Namespace("shape") public static native int oneDimEqualToLength(@Cast("Nd4jLong*") long[] shape, int rank); + +@Namespace("shape") public static native int oneDimEqualToLength(@Cast("Nd4jLong*") LongPointer shapeInfo); +@Namespace("shape") public static native int oneDimEqualToLength(@Cast("Nd4jLong*") LongBuffer shapeInfo); +@Namespace("shape") public static native int oneDimEqualToLength(@Cast("Nd4jLong*") long[] shapeInfo); + +@Namespace("shape") public static native int isVector(@Cast("const Nd4jLong*") LongPointer shapeInfo); +@Namespace("shape") public static native int isVector(@Cast("const Nd4jLong*") LongBuffer shapeInfo); +@Namespace("shape") public static native int isVector(@Cast("const Nd4jLong*") long[] shapeInfo); + +@Namespace("shape") public static native @Cast("bool") boolean isLikeVector(@Cast("const Nd4jLong*") LongPointer shapeInfo, + @ByRef IntPointer posOfNonUnityDim); +@Namespace("shape") public static native @Cast("bool") boolean isLikeVector(@Cast("const Nd4jLong*") LongBuffer shapeInfo, + @ByRef IntBuffer posOfNonUnityDim); +@Namespace("shape") public static native @Cast("bool") boolean isLikeVector(@Cast("const Nd4jLong*") long[] shapeInfo, + @ByRef int[] posOfNonUnityDim); + +@Namespace("shape") public static native @Cast("bool") boolean isCommonVector(@Cast("const Nd4jLong*") LongPointer shapeInfo, + @ByRef IntPointer posOfNonUnityDim); +@Namespace("shape") public static native @Cast("bool") boolean isCommonVector(@Cast("const Nd4jLong*") LongBuffer shapeInfo, + @ByRef IntBuffer posOfNonUnityDim); +@Namespace("shape") public static native @Cast("bool") boolean isCommonVector(@Cast("const Nd4jLong*") long[] shapeInfo, + @ByRef int[] posOfNonUnityDim); + +@Namespace("shape") public static native @Cast("bool") boolean isRowVector(@Cast("const Nd4jLong*") LongPointer shapeInfo); +@Namespace("shape") public static native @Cast("bool") boolean isRowVector(@Cast("const Nd4jLong*") LongBuffer shapeInfo); +@Namespace("shape") public static native @Cast("bool") boolean isRowVector(@Cast("const Nd4jLong*") long[] shapeInfo); + +@Namespace("shape") public static native @Cast("bool") boolean isColumnVector(@Cast("const Nd4jLong*") LongPointer shapeInfo); +@Namespace("shape") public static native @Cast("bool") boolean isColumnVector(@Cast("const Nd4jLong*") LongBuffer shapeInfo); +@Namespace("shape") public static native @Cast("bool") boolean isColumnVector(@Cast("const Nd4jLong*") long[] shapeInfo); - /** - * shape - input inShape is shape only, not shapeInfo - * returns number of non-unity dimensions in inShape - */ - @Namespace("shape") public static native int numOfNonUnitDims(int rank, @Cast("const Nd4jLong*") LongPointer inShape); - @Namespace("shape") public static native int numOfNonUnitDims(int rank, @Cast("const Nd4jLong*") LongBuffer inShape); - @Namespace("shape") public static native int numOfNonUnitDims(int rank, @Cast("const Nd4jLong*") long[] inShape); +/** + * shape - input inShape is shape only, not shapeInfo + * returns number of non-unity dimensions in inShape + */ +@Namespace("shape") public static native int numOfNonUnitDims(int rank, + @Cast("const Nd4jLong*") LongPointer inShape); +@Namespace("shape") public static native int numOfNonUnitDims(int rank, + @Cast("const Nd4jLong*") LongBuffer inShape); +@Namespace("shape") public static native int numOfNonUnitDims(int rank, + @Cast("const Nd4jLong*") long[] inShape); - /** +/** * Returns whether the * given shape is a vector or not * @param shape the shape of the array * @param rank the rank of the shape */ - @Namespace("shape") public static native int isMatrix(@Cast("const Nd4jLong*") LongPointer shape, int rank); - @Namespace("shape") public static native int isMatrix(@Cast("const Nd4jLong*") LongBuffer shape, int rank); - @Namespace("shape") public static native int isMatrix(@Cast("const Nd4jLong*") long[] shape, int rank); +@Namespace("shape") public static native int isMatrix(@Cast("const Nd4jLong*") LongPointer shape, int rank); +@Namespace("shape") public static native int isMatrix(@Cast("const Nd4jLong*") LongBuffer shape, int rank); +@Namespace("shape") public static native int isMatrix(@Cast("const Nd4jLong*") long[] shape, int rank); - @Namespace("shape") public static native int isMatrix(@Cast("const Nd4jLong*") LongPointer shapeInfo); - @Namespace("shape") public static native int isMatrix(@Cast("const Nd4jLong*") LongBuffer shapeInfo); - @Namespace("shape") public static native int isMatrix(@Cast("const Nd4jLong*") long[] shapeInfo); +@Namespace("shape") public static native int isMatrix(@Cast("const Nd4jLong*") LongPointer shapeInfo); +@Namespace("shape") public static native int isMatrix(@Cast("const Nd4jLong*") LongBuffer shapeInfo); +@Namespace("shape") public static native int isMatrix(@Cast("const Nd4jLong*") long[] shapeInfo); /** * Returns the shape portion of an information * buffer */ - @Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer shapeOf(@Cast("Nd4jLong*") LongPointer shapeInfo); - @Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer shapeOf(@Cast("Nd4jLong*") LongBuffer shapeInfo); - @Namespace("shape") public static native @Cast("Nd4jLong*") long[] shapeOf(@Cast("Nd4jLong*") long[] shapeInfo); +@Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer shapeOf(@Cast("Nd4jLong*") LongPointer shapeInfo); +@Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer shapeOf(@Cast("Nd4jLong*") LongBuffer shapeInfo); +@Namespace("shape") public static native @Cast("Nd4jLong*") long[] shapeOf(@Cast("Nd4jLong*") long[] shapeInfo); /** * Return a copy of a buffer. @@ -7408,19 +8057,22 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint * that must be freed elsewhere. */ - /** +/** * Return a copy of a buffer. * This buffer allocates memory * that must be freed elsewhere. */ - /** -* Return a copy of a buffer. -* This buffer allocates memory -* that must be freed elsewhere. -*/ - @Namespace("shape") public static native void copyTo(int length, @Cast("const Nd4jLong*") LongPointer from, @Cast("Nd4jLong*") LongPointer to, @Cast("Nd4jLong*") LongPointer indexes); - @Namespace("shape") public static native void copyTo(int length, @Cast("const Nd4jLong*") LongBuffer from, @Cast("Nd4jLong*") LongBuffer to, @Cast("Nd4jLong*") LongBuffer indexes); - @Namespace("shape") public static native void copyTo(int length, @Cast("const Nd4jLong*") long[] from, @Cast("Nd4jLong*") long[] to, @Cast("Nd4jLong*") long[] indexes); +/** + * Return a copy of a buffer. + * This buffer allocates memory + * that must be freed elsewhere. + */ +@Namespace("shape") public static native void copyTo(int length, @Cast("const Nd4jLong*") LongPointer from, @Cast("Nd4jLong*") LongPointer to, + @Cast("Nd4jLong*") LongPointer indexes); +@Namespace("shape") public static native void copyTo(int length, @Cast("const Nd4jLong*") LongBuffer from, @Cast("Nd4jLong*") LongBuffer to, + @Cast("Nd4jLong*") LongBuffer indexes); +@Namespace("shape") public static native void copyTo(int length, @Cast("const Nd4jLong*") long[] from, @Cast("Nd4jLong*") long[] to, + @Cast("Nd4jLong*") long[] indexes); /** * Permute the given strides @@ -7431,24 +8083,28 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint * and all must be filled in) * @return the rearranged array */ - //ND4J_EXPORT _CUDA_HD Nd4jLong *permutedStrides(Nd4jLong *toPermute, int shapeRank, Nd4jLong *rearrange); +// SD_EXPORT _CUDA_HD Nd4jLong *permutedStrides(Nd4jLong *toPermute, int +// shapeRank, Nd4jLong *rearrange); /** * Return the slice (shape + 1 in pointer arithmetic) * @param shape the shape to take the slice of * @return the shape array - the first entry */ - @Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer slice(@Cast("Nd4jLong*") LongPointer shape); - @Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer slice(@Cast("Nd4jLong*") LongBuffer shape); - @Namespace("shape") public static native @Cast("Nd4jLong*") long[] slice(@Cast("Nd4jLong*") long[] shape); - - @Namespace("shape") public static native int slices(@Cast("Nd4jLong*") LongPointer shapeBuffer); - @Namespace("shape") public static native int slices(@Cast("Nd4jLong*") LongBuffer shapeBuffer); - @Namespace("shape") public static native int slices(@Cast("Nd4jLong*") long[] shapeBuffer); - - @Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer sliceOfShapeBuffer(@Cast("Nd4jLong") long sliceIdx, @Cast("Nd4jLong*") LongPointer shapeBuffer); - @Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer sliceOfShapeBuffer(@Cast("Nd4jLong") long sliceIdx, @Cast("Nd4jLong*") LongBuffer shapeBuffer); - @Namespace("shape") public static native @Cast("Nd4jLong*") long[] sliceOfShapeBuffer(@Cast("Nd4jLong") long sliceIdx, @Cast("Nd4jLong*") long[] shapeBuffer); +@Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer slice(@Cast("Nd4jLong*") LongPointer shape); +@Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer slice(@Cast("Nd4jLong*") LongBuffer shape); +@Namespace("shape") public static native @Cast("Nd4jLong*") long[] slice(@Cast("Nd4jLong*") long[] shape); + +@Namespace("shape") public static native int slices(@Cast("Nd4jLong*") LongPointer shapeBuffer); +@Namespace("shape") public static native int slices(@Cast("Nd4jLong*") LongBuffer shapeBuffer); +@Namespace("shape") public static native int slices(@Cast("Nd4jLong*") long[] shapeBuffer); + +@Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer sliceOfShapeBuffer(@Cast("Nd4jLong") long sliceIdx, + @Cast("Nd4jLong*") LongPointer shapeBuffer); +@Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer sliceOfShapeBuffer(@Cast("Nd4jLong") long sliceIdx, + @Cast("Nd4jLong*") LongBuffer shapeBuffer); +@Namespace("shape") public static native @Cast("Nd4jLong*") long[] sliceOfShapeBuffer(@Cast("Nd4jLong") long sliceIdx, + @Cast("Nd4jLong*") long[] shapeBuffer); /** * Returns the length of the * shape information buffer: @@ -7457,35 +8113,35 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint * info length for * @return rank * 2 + 4 */ - @Namespace("shape") public static native int shapeInfoLength(int rank); +@Namespace("shape") public static native int shapeInfoLength(int rank); - @Namespace("shape") public static native int shapeInfoLength(@Cast("Nd4jLong*") LongPointer shapeInfo); - @Namespace("shape") public static native int shapeInfoLength(@Cast("Nd4jLong*") LongBuffer shapeInfo); - @Namespace("shape") public static native int shapeInfoLength(@Cast("Nd4jLong*") long[] shapeInfo); +@Namespace("shape") public static native int shapeInfoLength(@Cast("Nd4jLong*") LongPointer shapeInfo); +@Namespace("shape") public static native int shapeInfoLength(@Cast("Nd4jLong*") LongBuffer shapeInfo); +@Namespace("shape") public static native int shapeInfoLength(@Cast("Nd4jLong*") long[] shapeInfo); - @Namespace("shape") public static native @Cast("size_t") long shapeInfoByteLength(int rank); +@Namespace("shape") public static native @Cast("size_t") long shapeInfoByteLength(int rank); - @Namespace("shape") public static native @Cast("size_t") long shapeInfoByteLength(@Cast("const Nd4jLong*") LongPointer shapeInfo); - @Namespace("shape") public static native @Cast("size_t") long shapeInfoByteLength(@Cast("const Nd4jLong*") LongBuffer shapeInfo); - @Namespace("shape") public static native @Cast("size_t") long shapeInfoByteLength(@Cast("const Nd4jLong*") long[] shapeInfo); +@Namespace("shape") public static native @Cast("size_t") long shapeInfoByteLength(@Cast("const Nd4jLong*") LongPointer shapeInfo); +@Namespace("shape") public static native @Cast("size_t") long shapeInfoByteLength(@Cast("const Nd4jLong*") LongBuffer shapeInfo); +@Namespace("shape") public static native @Cast("size_t") long shapeInfoByteLength(@Cast("const Nd4jLong*") long[] shapeInfo); /** * Returns the rank portion of * an information buffer */ - @Namespace("shape") public static native int rank(@Cast("const Nd4jLong*") LongPointer shapeInfo); - @Namespace("shape") public static native int rank(@Cast("const Nd4jLong*") LongBuffer shapeInfo); - @Namespace("shape") public static native int rank(@Cast("const Nd4jLong*") long[] shapeInfo); - @Namespace("shape") public static native int rank(@Const IntPointer shapeInfo); - @Namespace("shape") public static native int rank(@Const IntBuffer shapeInfo); - @Namespace("shape") public static native int rank(@Const int[] shapeInfo); +@Namespace("shape") public static native int rank(@Cast("const Nd4jLong*") LongPointer shapeInfo); +@Namespace("shape") public static native int rank(@Cast("const Nd4jLong*") LongBuffer shapeInfo); +@Namespace("shape") public static native int rank(@Cast("const Nd4jLong*") long[] shapeInfo); +@Namespace("shape") public static native int rank(@Const IntPointer shapeInfo); +@Namespace("shape") public static native int rank(@Const IntBuffer shapeInfo); +@Namespace("shape") public static native int rank(@Const int[] shapeInfo); - /** - * returns pointer on elementWiseStride - */ - @Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer ews(@Cast("Nd4jLong*") LongPointer shapeInfo); - @Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer ews(@Cast("Nd4jLong*") LongBuffer shapeInfo); - @Namespace("shape") public static native @Cast("Nd4jLong*") long[] ews(@Cast("Nd4jLong*") long[] shapeInfo); +/** + * returns pointer on elementWiseStride + */ +@Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer ews(@Cast("Nd4jLong*") LongPointer shapeInfo); +@Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer ews(@Cast("Nd4jLong*") LongBuffer shapeInfo); +@Namespace("shape") public static native @Cast("Nd4jLong*") long[] ews(@Cast("Nd4jLong*") long[] shapeInfo); /** * Converts a raw int buffer of the layout: @@ -7497,81 +8153,83 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint * * where shape and stride are both straight int pointers */ - @Namespace("shape") public static native ShapeInformation infoFromBuffer(@Cast("Nd4jLong*") LongPointer buffer); - @Namespace("shape") public static native ShapeInformation infoFromBuffer(@Cast("Nd4jLong*") LongBuffer buffer); - @Namespace("shape") public static native ShapeInformation infoFromBuffer(@Cast("Nd4jLong*") long[] buffer); +@Namespace("shape") public static native ShapeInformation infoFromBuffer(@Cast("Nd4jLong*") LongPointer buffer); +@Namespace("shape") public static native ShapeInformation infoFromBuffer(@Cast("Nd4jLong*") LongBuffer buffer); +@Namespace("shape") public static native ShapeInformation infoFromBuffer(@Cast("Nd4jLong*") long[] buffer); /** * Returns the stride portion of an information * buffer */ - @Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer stride(@Cast("Nd4jLong*") LongPointer buffer); - @Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer stride(@Cast("Nd4jLong*") LongBuffer buffer); - @Namespace("shape") public static native @Cast("Nd4jLong*") long[] stride(@Cast("Nd4jLong*") long[] buffer); +@Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer stride(@Cast("Nd4jLong*") LongPointer buffer); +@Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer stride(@Cast("Nd4jLong*") LongBuffer buffer); +@Namespace("shape") public static native @Cast("Nd4jLong*") long[] stride(@Cast("Nd4jLong*") long[] buffer); /** * Compute the length of the given shape */ - @Namespace("shape") public static native @Cast("bool") boolean isEmpty(@Cast("const Nd4jLong*") LongPointer shapeInfo); - @Namespace("shape") public static native @Cast("bool") boolean isEmpty(@Cast("const Nd4jLong*") LongBuffer shapeInfo); - @Namespace("shape") public static native @Cast("bool") boolean isEmpty(@Cast("const Nd4jLong*") long[] shapeInfo); +@Namespace("shape") public static native @Cast("bool") boolean isEmpty(@Cast("const Nd4jLong*") LongPointer shapeInfo); +@Namespace("shape") public static native @Cast("bool") boolean isEmpty(@Cast("const Nd4jLong*") LongBuffer shapeInfo); +@Namespace("shape") public static native @Cast("bool") boolean isEmpty(@Cast("const Nd4jLong*") long[] shapeInfo); - @Namespace("shape") public static native @Cast("Nd4jLong") long length(@Cast("const Nd4jLong*") LongPointer shapeInfo); - @Namespace("shape") public static native @Cast("Nd4jLong") long length(@Cast("const Nd4jLong*") LongBuffer shapeInfo); - @Namespace("shape") public static native @Cast("Nd4jLong") long length(@Cast("const Nd4jLong*") long[] shapeInfo); +@Namespace("shape") public static native @Cast("Nd4jLong") long length(@Cast("const Nd4jLong*") LongPointer shapeInfo); +@Namespace("shape") public static native @Cast("Nd4jLong") long length(@Cast("const Nd4jLong*") LongBuffer shapeInfo); +@Namespace("shape") public static native @Cast("Nd4jLong") long length(@Cast("const Nd4jLong*") long[] shapeInfo); /*** * Returns the offset portion of an information buffer */ - @Namespace("shape") public static native @Cast("Nd4jLong") long offset(@Cast("Nd4jLong*") LongPointer buffer); - @Namespace("shape") public static native @Cast("Nd4jLong") long offset(@Cast("Nd4jLong*") LongBuffer buffer); - @Namespace("shape") public static native @Cast("Nd4jLong") long offset(@Cast("Nd4jLong*") long[] buffer); +@Namespace("shape") public static native @Cast("Nd4jLong") long offset(@Cast("Nd4jLong*") LongPointer buffer); +@Namespace("shape") public static native @Cast("Nd4jLong") long offset(@Cast("Nd4jLong*") LongBuffer buffer); +@Namespace("shape") public static native @Cast("Nd4jLong") long offset(@Cast("Nd4jLong*") long[] buffer); - @Namespace("shape") public static native @Cast("Nd4jLong*") @ByRef LongPointer extra(@Cast("Nd4jLong*") LongPointer buffer); - @Namespace("shape") public static native @Cast("Nd4jLong*") @ByRef LongBuffer extra(@Cast("Nd4jLong*") LongBuffer buffer); - @Namespace("shape") public static native @Cast("Nd4jLong*") @ByRef long[] extra(@Cast("Nd4jLong*") long[] buffer); +@Namespace("shape") public static native @Cast("Nd4jLong*") @ByRef LongPointer extra(@Cast("Nd4jLong*") LongPointer buffer); +@Namespace("shape") public static native @Cast("Nd4jLong*") @ByRef LongBuffer extra(@Cast("Nd4jLong*") LongBuffer buffer); +@Namespace("shape") public static native @Cast("Nd4jLong*") @ByRef long[] extra(@Cast("Nd4jLong*") long[] buffer); /** * Returns the ordering * for this shape information buffer */ - @Namespace("shape") public static native char order(@Cast("const Nd4jLong*") LongPointer buffer); - @Namespace("shape") public static native char order(@Cast("const Nd4jLong*") LongBuffer buffer); - @Namespace("shape") public static native char order(@Cast("const Nd4jLong*") long[] buffer); +@Namespace("shape") public static native char order(@Cast("const Nd4jLong*") LongPointer buffer); +@Namespace("shape") public static native char order(@Cast("const Nd4jLong*") LongBuffer buffer); +@Namespace("shape") public static native char order(@Cast("const Nd4jLong*") long[] buffer); /** * Returns the type */ - @Namespace("shape") public static native @Cast("Nd4jLong") long type(@Cast("const Nd4jLong*") LongPointer shapeInfo); - @Namespace("shape") public static native @Cast("Nd4jLong") long type(@Cast("const Nd4jLong*") LongBuffer shapeInfo); - @Namespace("shape") public static native @Cast("Nd4jLong") long type(@Cast("const Nd4jLong*") long[] shapeInfo); +@Namespace("shape") public static native @Cast("Nd4jLong") long type(@Cast("const Nd4jLong*") LongPointer shapeInfo); +@Namespace("shape") public static native @Cast("Nd4jLong") long type(@Cast("const Nd4jLong*") LongBuffer shapeInfo); +@Namespace("shape") public static native @Cast("Nd4jLong") long type(@Cast("const Nd4jLong*") long[] shapeInfo); /** * Returns the element wise stride for this information * buffer */ - @Namespace("shape") public static native @Cast("Nd4jLong") long elementWiseStride(@Cast("const Nd4jLong*") LongPointer shapeInfo); - @Namespace("shape") public static native @Cast("Nd4jLong") long elementWiseStride(@Cast("const Nd4jLong*") LongBuffer shapeInfo); - @Namespace("shape") public static native @Cast("Nd4jLong") long elementWiseStride(@Cast("const Nd4jLong*") long[] shapeInfo); +@Namespace("shape") public static native @Cast("Nd4jLong") long elementWiseStride(@Cast("const Nd4jLong*") LongPointer shapeInfo); +@Namespace("shape") public static native @Cast("Nd4jLong") long elementWiseStride(@Cast("const Nd4jLong*") LongBuffer shapeInfo); +@Namespace("shape") public static native @Cast("Nd4jLong") long elementWiseStride(@Cast("const Nd4jLong*") long[] shapeInfo); - - /** +/** * Returns the element wise stride for this information * buffer - * relative to a dimension and ordering for a reduction index + * relative to a dimension and ordering for a reduction index */ - @Namespace("shape") public static native @Cast("Nd4jLong") long reductionIndexElementWiseStride(@Cast("Nd4jLong*") LongPointer buffer, IntPointer dimension, int dimensionLength); - @Namespace("shape") public static native @Cast("Nd4jLong") long reductionIndexElementWiseStride(@Cast("Nd4jLong*") LongBuffer buffer, IntBuffer dimension, int dimensionLength); - @Namespace("shape") public static native @Cast("Nd4jLong") long reductionIndexElementWiseStride(@Cast("Nd4jLong*") long[] buffer, int[] dimension, int dimensionLength); +@Namespace("shape") public static native @Cast("Nd4jLong") long reductionIndexElementWiseStride( + @Cast("Nd4jLong*") LongPointer buffer, IntPointer dimension, int dimensionLength); +@Namespace("shape") public static native @Cast("Nd4jLong") long reductionIndexElementWiseStride( + @Cast("Nd4jLong*") LongBuffer buffer, IntBuffer dimension, int dimensionLength); +@Namespace("shape") public static native @Cast("Nd4jLong") long reductionIndexElementWiseStride( + @Cast("Nd4jLong*") long[] buffer, int[] dimension, int dimensionLength); /** * Returns whether * the given shape info buffer * represents a scalar shape */ - @Namespace("shape") public static native int isScalar(@Cast("const Nd4jLong*") LongPointer info); - @Namespace("shape") public static native int isScalar(@Cast("const Nd4jLong*") LongBuffer info); - @Namespace("shape") public static native int isScalar(@Cast("const Nd4jLong*") long[] info); +@Namespace("shape") public static native int isScalar(@Cast("const Nd4jLong*") LongPointer info); +@Namespace("shape") public static native int isScalar(@Cast("const Nd4jLong*") LongBuffer info); +@Namespace("shape") public static native int isScalar(@Cast("const Nd4jLong*") long[] info); /** * Returns whether @@ -7579,7 +8237,7 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint * represents a scalar * shape or not */ - @Namespace("shape") public static native int isScalar(ShapeInformation info); +@Namespace("shape") public static native int isScalar(ShapeInformation info); /** * Return a copy of this array with the @@ -7594,7 +8252,7 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint * item */ - /** +/** * Return a copy of this array with the * given index omitted * @@ -7607,20 +8265,26 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint * item */ - /** - * Iterate over a given set of indexes - * the begin and end indexes are 0 based. - * 1 padding is automatically assumed for the ending. - * - * For example if you want to iterate over 0 to 4 - * it will go to 4 rather than 3. - * - * indexes should be the indexes to exclude - * indexes length should be the length of indexes - */ - @Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer everyIndexBut(@Cast("const Nd4jLong*") LongPointer indexes,int indexesLength,int begin,int end); - @Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer everyIndexBut(@Cast("const Nd4jLong*") LongBuffer indexes,int indexesLength,int begin,int end); - @Namespace("shape") public static native @Cast("Nd4jLong*") long[] everyIndexBut(@Cast("const Nd4jLong*") long[] indexes,int indexesLength,int begin,int end); +/** + * Iterate over a given set of indexes + * the begin and end indexes are 0 based. + * 1 padding is automatically assumed for the ending. + * + * For example if you want to iterate over 0 to 4 + * it will go to 4 rather than 3. + * + * indexes should be the indexes to exclude + * indexes length should be the length of indexes + */ +@Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer everyIndexBut(@Cast("const Nd4jLong*") LongPointer indexes, + int indexesLength, int begin, + int end); +@Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer everyIndexBut(@Cast("const Nd4jLong*") LongBuffer indexes, + int indexesLength, int begin, + int end); +@Namespace("shape") public static native @Cast("Nd4jLong*") long[] everyIndexBut(@Cast("const Nd4jLong*") long[] indexes, + int indexesLength, int begin, + int end); /** * Computes the offset for accessing @@ -7630,7 +8294,7 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint //#ifdef __CUDACC__ // __device__ //#endif -// ND4J_EXPORT int tadOffset(shape::ShapeInformation *xInfo, int offset); +// SD_EXPORT int tadOffset(shape::ShapeInformation *xInfo, int offset); /** * Returns a shape @@ -7640,15 +8304,15 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint * for the shape to be returned as * @return the new shape */ - @Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer ensureVectorShape(@Cast("Nd4jLong*") LongPointer shape); - @Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer ensureVectorShape(@Cast("Nd4jLong*") LongBuffer shape); - @Namespace("shape") public static native @Cast("Nd4jLong*") long[] ensureVectorShape(@Cast("Nd4jLong*") long[] shape); +@Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer ensureVectorShape(@Cast("Nd4jLong*") LongPointer shape); +@Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer ensureVectorShape(@Cast("Nd4jLong*") LongBuffer shape); +@Namespace("shape") public static native @Cast("Nd4jLong*") long[] ensureVectorShape(@Cast("Nd4jLong*") long[] shape); - @Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer createScalarShapeInfo(); +@Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer createScalarShapeInfo(); - @Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer createScalarShapeInfo(@Cast("Nd4jLong*") LongPointer ret); - @Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer createScalarShapeInfo(@Cast("Nd4jLong*") LongBuffer ret); - @Namespace("shape") public static native @Cast("Nd4jLong*") long[] createScalarShapeInfo(@Cast("Nd4jLong*") long[] ret); +@Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer createScalarShapeInfo(@Cast("Nd4jLong*") LongPointer ret); +@Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer createScalarShapeInfo(@Cast("Nd4jLong*") LongBuffer ret); +@Namespace("shape") public static native @Cast("Nd4jLong*") long[] createScalarShapeInfo(@Cast("Nd4jLong*") long[] ret); /** * Generate an int buffer @@ -7666,9 +8330,12 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint * Keep the given indexes * in the data */ - @Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer keep(@Cast("Nd4jLong*") LongPointer data, @Const IntPointer index, int indexLength, int dataLength); - @Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer keep(@Cast("Nd4jLong*") LongBuffer data, @Const IntBuffer index, int indexLength, int dataLength); - @Namespace("shape") public static native @Cast("Nd4jLong*") long[] keep(@Cast("Nd4jLong*") long[] data, @Const int[] index, int indexLength, int dataLength); +@Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer keep(@Cast("Nd4jLong*") LongPointer data, @Const IntPointer index, + int indexLength, int dataLength); +@Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer keep(@Cast("Nd4jLong*") LongBuffer data, @Const IntBuffer index, + int indexLength, int dataLength); +@Namespace("shape") public static native @Cast("Nd4jLong*") long[] keep(@Cast("Nd4jLong*") long[] data, @Const int[] index, + int indexLength, int dataLength); /** * Generate reverse copy of the data @@ -7706,9 +8373,15 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint * @return the length per slice of the given shape * along the given dimension */ - @Namespace("shape") public static native @Cast("Nd4jLong") long lengthPerSlice(int rank, @Cast("const Nd4jLong*") LongPointer shape, @Const IntPointer dimension, int dimensionLength); - @Namespace("shape") public static native @Cast("Nd4jLong") long lengthPerSlice(int rank, @Cast("const Nd4jLong*") LongBuffer shape, @Const IntBuffer dimension, int dimensionLength); - @Namespace("shape") public static native @Cast("Nd4jLong") long lengthPerSlice(int rank, @Cast("const Nd4jLong*") long[] shape, @Const int[] dimension, int dimensionLength); +@Namespace("shape") public static native @Cast("Nd4jLong") long lengthPerSlice(int rank, @Cast("const Nd4jLong*") LongPointer shape, + @Const IntPointer dimension, + int dimensionLength); +@Namespace("shape") public static native @Cast("Nd4jLong") long lengthPerSlice(int rank, @Cast("const Nd4jLong*") LongBuffer shape, + @Const IntBuffer dimension, + int dimensionLength); +@Namespace("shape") public static native @Cast("Nd4jLong") long lengthPerSlice(int rank, @Cast("const Nd4jLong*") long[] shape, + @Const int[] dimension, + int dimensionLength); /** * calculates the offset for a tensor @@ -7717,27 +8390,15 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint * @param tensorShape * @return */ - @Namespace("shape") public static native @Cast("Nd4jLong") long sliceOffsetForTensor(int rank, - int index, - @Cast("const Nd4jLong*") LongPointer shape, - @Cast("const Nd4jLong*") LongPointer tensorShape, - int tensorShapeLength, - @Const IntPointer dimension, - int dimensionLength); - @Namespace("shape") public static native @Cast("Nd4jLong") long sliceOffsetForTensor(int rank, - int index, - @Cast("const Nd4jLong*") LongBuffer shape, - @Cast("const Nd4jLong*") LongBuffer tensorShape, - int tensorShapeLength, - @Const IntBuffer dimension, - int dimensionLength); - @Namespace("shape") public static native @Cast("Nd4jLong") long sliceOffsetForTensor(int rank, - int index, - @Cast("const Nd4jLong*") long[] shape, - @Cast("const Nd4jLong*") long[] tensorShape, - int tensorShapeLength, - @Const int[] dimension, - int dimensionLength); +@Namespace("shape") public static native @Cast("Nd4jLong") long sliceOffsetForTensor( + int rank, int index, @Cast("const Nd4jLong*") LongPointer shape, @Cast("const Nd4jLong*") LongPointer tensorShape, + int tensorShapeLength, @Const IntPointer dimension, int dimensionLength); +@Namespace("shape") public static native @Cast("Nd4jLong") long sliceOffsetForTensor( + int rank, int index, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("const Nd4jLong*") LongBuffer tensorShape, + int tensorShapeLength, @Const IntBuffer dimension, int dimensionLength); +@Namespace("shape") public static native @Cast("Nd4jLong") long sliceOffsetForTensor( + int rank, int index, @Cast("const Nd4jLong*") long[] shape, @Cast("const Nd4jLong*") long[] tensorShape, + int tensorShapeLength, @Const int[] dimension, int dimensionLength); /** * calculates the offset for a tensor @@ -7746,53 +8407,55 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint * @param tensorShape * @return */ - @Namespace("shape") public static native @Cast("Nd4jLong") long sliceOffsetForTensor(int index,int tensorLength,int lengthPerSlice2); +@Namespace("shape") public static native @Cast("Nd4jLong") long sliceOffsetForTensor(int index, int tensorLength, + int lengthPerSlice2); /** * Computes the tensor along dimension * offset * @param index the index to get the offset for the tad for * @param rank the rank of the shapes and strides * @param info the shape information to use for tad - * @param dimension the dimensions to use for computing the tensor along dimensions + * @param dimension the dimensions to use for computing the tensor along + * dimensions */ -// ND4J_EXPORT _CUDA_HD int offset(int index, +// SD_EXPORT _CUDA_HD int offset(int index, // int rank, // shape::ShapeInformation *info, // Nd4jLong *dimension, // int dimensionLength); - /** * Computes the number * of tensors along * a given dimension */ - @Namespace("shape") public static native @Cast("Nd4jLong") long tensorsAlongDimension(int rank, - int length, - @Cast("Nd4jLong*") LongPointer shape, - IntPointer dimension, - int dimensionLength); - @Namespace("shape") public static native @Cast("Nd4jLong") long tensorsAlongDimension(int rank, - int length, - @Cast("Nd4jLong*") LongBuffer shape, - IntBuffer dimension, - int dimensionLength); - @Namespace("shape") public static native @Cast("Nd4jLong") long tensorsAlongDimension(int rank, - int length, - @Cast("Nd4jLong*") long[] shape, - int[] dimension, - int dimensionLength); +@Namespace("shape") public static native @Cast("Nd4jLong") long tensorsAlongDimension(int rank, int length, + @Cast("Nd4jLong*") LongPointer shape, + IntPointer dimension, + int dimensionLength); +@Namespace("shape") public static native @Cast("Nd4jLong") long tensorsAlongDimension(int rank, int length, + @Cast("Nd4jLong*") LongBuffer shape, + IntBuffer dimension, + int dimensionLength); +@Namespace("shape") public static native @Cast("Nd4jLong") long tensorsAlongDimension(int rank, int length, + @Cast("Nd4jLong*") long[] shape, + int[] dimension, + int dimensionLength); /** * Computes the number * of tensors along * a given dimension */ - @Namespace("shape") public static native @Cast("Nd4jLong") long tensorsAlongDimension(@Cast("Nd4jLong*") LongPointer shapeInfo, IntPointer dimension, int dimensionLength); - @Namespace("shape") public static native @Cast("Nd4jLong") long tensorsAlongDimension(@Cast("Nd4jLong*") LongBuffer shapeInfo, IntBuffer dimension, int dimensionLength); - @Namespace("shape") public static native @Cast("Nd4jLong") long tensorsAlongDimension(@Cast("Nd4jLong*") long[] shapeInfo, int[] dimension, int dimensionLength); - - +@Namespace("shape") public static native @Cast("Nd4jLong") long tensorsAlongDimension(@Cast("Nd4jLong*") LongPointer shapeInfo, + IntPointer dimension, + int dimensionLength); +@Namespace("shape") public static native @Cast("Nd4jLong") long tensorsAlongDimension(@Cast("Nd4jLong*") LongBuffer shapeInfo, + IntBuffer dimension, + int dimensionLength); +@Namespace("shape") public static native @Cast("Nd4jLong") long tensorsAlongDimension(@Cast("Nd4jLong*") long[] shapeInfo, + int[] dimension, + int dimensionLength); /** * Returns the tensor along dimension @@ -7802,26 +8465,30 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint * @param i * @return */ - @Namespace("shape") public static native int tadForBlockIndex(int blockSize, int blockIdx, int i); +@Namespace("shape") public static native int tadForBlockIndex(int blockSize, int blockIdx, int i); /** * Computes the number of tads per block * */ - @Namespace("shape") public static native int tadsPerBlock(int blockSize, int tads); +@Namespace("shape") public static native int tadsPerBlock(int blockSize, int tads); -// ND4J_EXPORT _CUDA_HD Nd4jLong *tadShapeInfo(int index, Nd4jLong *xShapeInfo, Nd4jLong *dimension, +// SD_EXPORT _CUDA_HD Nd4jLong *tadShapeInfo(int index, Nd4jLong *xShapeInfo, +// Nd4jLong *dimension, // int dimensionLength); /** * Returns a shape buffer * for the shape information metadata. */ - @Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer toShapeBuffer( ShapeInformation info); +@Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer toShapeBuffer(ShapeInformation info); - @Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer toShapeBuffer( ShapeInformation info, @Cast("Nd4jLong*") LongPointer ret); - @Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer toShapeBuffer( ShapeInformation info, @Cast("Nd4jLong*") LongBuffer ret); - @Namespace("shape") public static native @Cast("Nd4jLong*") long[] toShapeBuffer( ShapeInformation info, @Cast("Nd4jLong*") long[] ret); +@Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer toShapeBuffer(ShapeInformation info, + @Cast("Nd4jLong*") LongPointer ret); +@Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer toShapeBuffer(ShapeInformation info, + @Cast("Nd4jLong*") LongBuffer ret); +@Namespace("shape") public static native @Cast("Nd4jLong*") long[] toShapeBuffer(ShapeInformation info, + @Cast("Nd4jLong*") long[] ret); /** * Returns the number of elements per thread @@ -7872,25 +8539,29 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint * @param numElementsPerTad the number of elements * per tad */ - @Namespace("shape") public static native int tadIndex(int i, int elementWiseStride, int numElementsPerTad); +@Namespace("shape") public static native int tadIndex(int i, int elementWiseStride, + int numElementsPerTad); /** * Map a tad to a * reduction index. * @param tadIndexForOriginal the original tad index for the * split up problem (eg: split is dimension 3 mapping to a 2,3 problem) - * @param tadsForReduced the number of tads for the shrunk down problem (eg: 2,3) + * @param tadsForReduced the number of tads for the shrunk down problem (eg: + * 2,3) * @param tadsForOriginal the number of tads for the smaller problem (eg: 3) */ - @Namespace("shape") public static native int reductionIndexForTad(int tadIndexForOriginal, int tadsForReduced, - int tadsForOriginal); +@Namespace("shape") public static native int reductionIndexForTad(int tadIndexForOriginal, + int tadsForReduced, + int tadsForOriginal); /** * Computes the number of tads * per reduce index for the * reduction tad. */ - @Namespace("shape") public static native int tadsPerReduceIndex(int tadsForReduce, int tadsForOriginal); +@Namespace("shape") public static native int tadsPerReduceIndex(int tadsForReduce, + int tadsForOriginal); /** * Maps a linear index to a reduction index @@ -7900,351 +8571,653 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint * @param tadNum the number of tads for the shrunken problem * @param originalTadNum the tad number for the reduced version of the problem */ - @Namespace("shape") public static native int reductionIndexForLinear(int i, int elementWiseStride, int numElementsPerTad, - int tadNum, int originalTadNum); +@Namespace("shape") public static native int reductionIndexForLinear(int i, int elementWiseStride, + int numElementsPerTad, + int tadNum, int originalTadNum); /** * Returns the prod of the data * up to the given length */ - @Namespace("shape") public static native @Cast("Nd4jLong") long prodLong(@Cast("const Nd4jLong*") LongPointer data, int length); - @Namespace("shape") public static native @Cast("Nd4jLong") long prodLong(@Cast("const Nd4jLong*") LongBuffer data, int length); - @Namespace("shape") public static native @Cast("Nd4jLong") long prodLong(@Cast("const Nd4jLong*") long[] data, int length); - - /** - * Returns the rear most left over item not present in - * the dimension array. This assumes that the dimension array is sorted. - * - * For example, given a dimension array of: - * 0,2 - * - * and - * - * 12,4,2,1 in data - * - * You end up with 1 (data[3]) - * since the first item won't match - * the last item of the dimension array - */ - -// ND4J_EXPORT _CUDA_HD int rearMostLeftOverItem(Nd4jLong *data,int length,Nd4jLong *dimension,int dimensionLength); - - /** -* Get an offset for retrieval -* from a data buffer -* based on the given -* shape stride and given indices -* @param baseOffset the offset to start from -* @param shape the shape of the array -* @param stride the stride of the array -* @param indices the indices to iterate over -* @return the double at the specified index -*/ +@Namespace("shape") public static native @Cast("Nd4jLong") long prodLong(@Cast("const Nd4jLong*") LongPointer data, int length); +@Namespace("shape") public static native @Cast("Nd4jLong") long prodLong(@Cast("const Nd4jLong*") LongBuffer data, int length); +@Namespace("shape") public static native @Cast("Nd4jLong") long prodLong(@Cast("const Nd4jLong*") long[] data, int length); - @Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("const Nd4jLong*") LongPointer coords, @Cast("Nd4jLong") long baseOffset/*=0*/); - @Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("const Nd4jLong*") LongPointer coords); - @Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("const Nd4jLong*") LongBuffer coords, @Cast("Nd4jLong") long baseOffset/*=0*/); - @Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("const Nd4jLong*") LongBuffer coords); - @Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("const Nd4jLong*") long[] coords, @Cast("Nd4jLong") long baseOffset/*=0*/); - @Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("const Nd4jLong*") long[] coords); - @Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Const IntPointer coords, @Cast("Nd4jLong") long baseOffset/*=0*/); - @Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Const IntPointer coords); - @Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Const IntBuffer coords, @Cast("Nd4jLong") long baseOffset/*=0*/); - @Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Const IntBuffer coords); - @Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") long[] shapeInfo, @Const int[] coords, @Cast("Nd4jLong") long baseOffset/*=0*/); - @Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") long[] shapeInfo, @Const int[] coords); - - @Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer createShapeInfo(@Cast("Nd4jLong*") LongPointer shape, @Cast("Nd4jLong*") LongPointer stride, int rank); - @Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer createShapeInfo(@Cast("Nd4jLong*") LongBuffer shape, @Cast("Nd4jLong*") LongBuffer stride, int rank); - @Namespace("shape") public static native @Cast("Nd4jLong*") long[] createShapeInfo(@Cast("Nd4jLong*") long[] shape, @Cast("Nd4jLong*") long[] stride, int rank); - - @Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer createShapeInfo(@Cast("Nd4jLong*") LongPointer shape, @Cast("Nd4jLong*") LongPointer stride, int rank, @Cast("Nd4jLong*") LongPointer buffer); - @Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer createShapeInfo(@Cast("Nd4jLong*") LongBuffer shape, @Cast("Nd4jLong*") LongBuffer stride, int rank, @Cast("Nd4jLong*") LongBuffer buffer); - @Namespace("shape") public static native @Cast("Nd4jLong*") long[] createShapeInfo(@Cast("Nd4jLong*") long[] shape, @Cast("Nd4jLong*") long[] stride, int rank, @Cast("Nd4jLong*") long[] buffer); - - /** - * Convert a linear index to the corresponding coordinates - * for example if shape is {2, 4}, then index 5 corresponds to coordinates [1, 1] - */ - @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("Nd4jLong*") LongPointer coords); - @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("Nd4jLong*") LongBuffer coords); - @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") long[] shapeInfo, @Cast("Nd4jLong*") long[] coords); - @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer shapeInfo, IntPointer coords); - @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer shapeInfo, IntBuffer coords); - @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") long[] shapeInfo, int[] coords); - @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") LongPointer shape, @Cast("Nd4jLong*") LongPointer coords); - @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("Nd4jLong*") LongBuffer coords); - @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") long[] shape, @Cast("Nd4jLong*") long[] coords); - @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") LongPointer shape, IntPointer coords); - @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") LongBuffer shape, IntBuffer coords); - @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") long[] shape, int[] coords); - - @Namespace("shape") public static native void index2coordsCPU(@Cast("const Nd4jLong") long startIndex, @Cast("const Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("Nd4jLong*") LongPointer coords); - @Namespace("shape") public static native void index2coordsCPU(@Cast("const Nd4jLong") long startIndex, @Cast("const Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("Nd4jLong*") LongBuffer coords); - @Namespace("shape") public static native void index2coordsCPU(@Cast("const Nd4jLong") long startIndex, @Cast("const Nd4jLong") long index, @Cast("const Nd4jLong*") long[] shapeInfo, @Cast("Nd4jLong*") long[] coords); - @Namespace("shape") public static native void index2coordsCPU(@Cast("const Nd4jLong") long startIndex, @Cast("const Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer shapeInfo, IntPointer coords); - @Namespace("shape") public static native void index2coordsCPU(@Cast("const Nd4jLong") long startIndex, @Cast("const Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer shapeInfo, IntBuffer coords); - @Namespace("shape") public static native void index2coordsCPU(@Cast("const Nd4jLong") long startIndex, @Cast("const Nd4jLong") long index, @Cast("const Nd4jLong*") long[] shapeInfo, int[] coords); - - /** - * take into account only dimensions stored in tadDims, tadDims must be sorted in increasing order! - */ - @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer shapeInfo, IntPointer coords, int dimsSize, @Const IntPointer tadDims); - @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer shapeInfo, IntBuffer coords, int dimsSize, @Const IntBuffer tadDims); - @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") long[] shapeInfo, int[] coords, int dimsSize, @Const int[] tadDims); - - /** - * Convert coordinates to the corresponding linear index (sequence number in other words) - * for example if shape is {2, 4} and coordinates [1, 1] then index 5 is returned - */ - @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("const Nd4jLong*") LongPointer coords); - @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("const Nd4jLong*") LongBuffer coords); - @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("const Nd4jLong*") long[] coords); - @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Const IntPointer coords); - @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Const IntBuffer coords); - @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") long[] shapeInfo, @Const int[] coords); - @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") LongPointer shape, @Const IntPointer coords); - @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") LongBuffer shape, @Const IntBuffer coords); - @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") long[] shape, @Const int[] coords); - /** - * take into account only dimensions stored in tadDims, tadDims must be sorted in increasing order! - */ - @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Const IntPointer coords, int dimsSize, @Const IntPointer tadDims); - @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Const IntBuffer coords, int dimsSize, @Const IntBuffer tadDims); - @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") long[] shapeInfo, @Const int[] coords, int dimsSize, @Const int[] tadDims); - - /** - * increment n-dimensional array by one iteration by changing coord appropriately - * for example we have array with shape {2, 3}: - * - if input coord = {0,1}, then output coord = {0,2} - * - if input coord = {0,2}, then output coord = {1,0} - * so the aim is to produce following subsequence of coord: {0,0}, {0,1}, {0,2}, {1,0}, {1,1}, {1,2} - */ +/** + * Returns the rear most left over item not present in + * the dimension array. This assumes that the dimension array is sorted. + * + * For example, given a dimension array of: + * 0,2 + * + * and + * + * 12,4,2,1 in data + * + * You end up with 1 (data[3]) + * since the first item won't match + * the last item of the dimension array + */ - /* calculates an array buffer offset for given "index" using following formula: offset = coord_0*stride_0 + coord_1*stride_1 + ... + coord_{rank-1}*stride_{rank-1} - */ - @Namespace("shape") public static native @Cast("uint") int getIndexOffset(@Cast("uint") int index, @Cast("const uint*") IntPointer shapeInfo); - @Namespace("shape") public static native @Cast("uint") int getIndexOffset(@Cast("uint") int index, @Cast("const uint*") IntBuffer shapeInfo); - @Namespace("shape") public static native @Cast("uint") int getIndexOffset(@Cast("uint") int index, @Cast("const uint*") int[] shapeInfo); - @Namespace("shape") public static native @Cast("Nd4jLong") long getIndexOffset(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer shapeInfo); - @Namespace("shape") public static native @Cast("Nd4jLong") long getIndexOffset(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer shapeInfo); - @Namespace("shape") public static native @Cast("Nd4jLong") long getIndexOffset(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") long[] shapeInfo); - @Namespace("shape") public static native @Cast("Nd4jLong") long indexOffset(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer lShapeInfo, @Cast("const uint*") IntPointer uShapeInfo, @Cast("const bool") boolean useUnsigned); - @Namespace("shape") public static native @Cast("Nd4jLong") long indexOffset(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer lShapeInfo, @Cast("const uint*") IntBuffer uShapeInfo, @Cast("const bool") boolean useUnsigned); - @Namespace("shape") public static native @Cast("Nd4jLong") long indexOffset(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") long[] lShapeInfo, @Cast("const uint*") int[] uShapeInfo, @Cast("const bool") boolean useUnsigned); - - @Namespace("shape") public static native void printShapeInfo(@Cast("Nd4jLong*") LongPointer shapeInfo); - @Namespace("shape") public static native void printShapeInfo(@Cast("Nd4jLong*") LongBuffer shapeInfo); - @Namespace("shape") public static native void printShapeInfo(@Cast("Nd4jLong*") long[] shapeInfo); - - @Namespace("shape") public static native void printShapeInfoLinear(@Cast("const Nd4jLong*") LongPointer shapeInfo); - @Namespace("shape") public static native void printShapeInfoLinear(@Cast("const Nd4jLong*") LongBuffer shapeInfo); - @Namespace("shape") public static native void printShapeInfoLinear(@Cast("const Nd4jLong*") long[] shapeInfo); - - @Namespace("shape") public static native void printShapeInfoLinear(@Cast("char*") String msg, @Cast("const Nd4jLong*") LongPointer shapeInfo); - @Namespace("shape") public static native void printShapeInfoLinear(@Cast("char*") BytePointer msg, @Cast("const Nd4jLong*") LongBuffer shapeInfo); - @Namespace("shape") public static native void printShapeInfoLinear(@Cast("char*") String msg, @Cast("const Nd4jLong*") long[] shapeInfo); - @Namespace("shape") public static native void printShapeInfoLinear(@Cast("char*") BytePointer msg, @Cast("const Nd4jLong*") LongPointer shapeInfo); - @Namespace("shape") public static native void printShapeInfoLinear(@Cast("char*") String msg, @Cast("const Nd4jLong*") LongBuffer shapeInfo); - @Namespace("shape") public static native void printShapeInfoLinear(@Cast("char*") BytePointer msg, @Cast("const Nd4jLong*") long[] shapeInfo); - - @Namespace("shape") public static native void printShapeInfoLinear(@Cast("char*") String msg, int rank, @Cast("const Nd4jLong*") LongPointer shape, @Cast("const Nd4jLong*") LongPointer strides); - @Namespace("shape") public static native void printShapeInfoLinear(@Cast("char*") BytePointer msg, int rank, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("const Nd4jLong*") LongBuffer strides); - @Namespace("shape") public static native void printShapeInfoLinear(@Cast("char*") String msg, int rank, @Cast("const Nd4jLong*") long[] shape, @Cast("const Nd4jLong*") long[] strides); - @Namespace("shape") public static native void printShapeInfoLinear(@Cast("char*") BytePointer msg, int rank, @Cast("const Nd4jLong*") LongPointer shape, @Cast("const Nd4jLong*") LongPointer strides); - @Namespace("shape") public static native void printShapeInfoLinear(@Cast("char*") String msg, int rank, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("const Nd4jLong*") LongBuffer strides); - @Namespace("shape") public static native void printShapeInfoLinear(@Cast("char*") BytePointer msg, int rank, @Cast("const Nd4jLong*") long[] shape, @Cast("const Nd4jLong*") long[] strides); - - @Namespace("shape") public static native void printIntArray(@Cast("const Nd4jLong*") LongPointer arr, int length); - @Namespace("shape") public static native void printIntArray(@Cast("const Nd4jLong*") LongBuffer arr, int length); - @Namespace("shape") public static native void printIntArray(@Cast("const Nd4jLong*") long[] arr, int length); - @Namespace("shape") public static native void printIntArray(@Const IntPointer arr, int length); - @Namespace("shape") public static native void printIntArray(@Const IntBuffer arr, int length); - @Namespace("shape") public static native void printIntArray(@Const int[] arr, int length); - - @Namespace("shape") public static native void printArray(FloatPointer arr,int length); - @Namespace("shape") public static native void printArray(FloatBuffer arr,int length); - @Namespace("shape") public static native void printArray(float[] arr,int length); - - @Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer shapeBufferOfNpy(int rank, @Cast("unsigned int*") IntPointer shape,@Cast("bool") boolean fortranOrder); - @Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer shapeBufferOfNpy(int rank, @Cast("unsigned int*") IntBuffer shape,@Cast("bool") boolean fortranOrder); - @Namespace("shape") public static native @Cast("Nd4jLong*") long[] shapeBufferOfNpy(int rank, @Cast("unsigned int*") int[] shape,@Cast("bool") boolean fortranOrder); - -// ND4J_EXPORT _CUDA_HD Nd4jLong *shapeBufferOfNpyBuffer(char *buffer); - - - // this function checks the consistence of dimensions with array rank (negative dimensions, too large dimensions, too big number of dimensions) - // also sort input array of dimensions, this operation is also necessary for creating TAD object - @Namespace("shape") public static native void checkDimensions(int rank, @StdVector IntPointer dimensions); - @Namespace("shape") public static native void checkDimensions(int rank, @StdVector IntBuffer dimensions); - @Namespace("shape") public static native void checkDimensions(int rank, @StdVector int[] dimensions); - - // function calculates linear index of array min, min is sub-array of max, index to be returned is min-array's index and corresponds to maxIdx of max array - // dimsToExclude - should be sorted in increasing order - @Namespace("shape") public static native @Cast("Nd4jLong") long subArrayIndex(@Cast("const Nd4jLong") long maxIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, @Const IntPointer dimsToExclude/*=nullptr*/, int dimsLen/*=-1*/); - @Namespace("shape") public static native @Cast("Nd4jLong") long subArrayIndex(@Cast("const Nd4jLong") long maxIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo); - @Namespace("shape") public static native @Cast("Nd4jLong") long subArrayIndex(@Cast("const Nd4jLong") long maxIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, @Const IntBuffer dimsToExclude/*=nullptr*/, int dimsLen/*=-1*/); - @Namespace("shape") public static native @Cast("Nd4jLong") long subArrayIndex(@Cast("const Nd4jLong") long maxIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo); - @Namespace("shape") public static native @Cast("Nd4jLong") long subArrayIndex(@Cast("const Nd4jLong") long maxIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, @Const int[] dimsToExclude/*=nullptr*/, int dimsLen/*=-1*/); - @Namespace("shape") public static native @Cast("Nd4jLong") long subArrayIndex(@Cast("const Nd4jLong") long maxIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo); - - // function calculates absolute offset of min array, min is sub-array of max, offset to be returned corresponds to maxIdx of max array - // dimsToExclude - should be sorted in increasing order - @Namespace("shape") public static native @Cast("Nd4jLong") long subArrayOffset(@Cast("const Nd4jLong") long maxIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, @Const IntPointer dimsToExclude/*=nullptr*/, int dimsLen/*=-1*/); - @Namespace("shape") public static native @Cast("Nd4jLong") long subArrayOffset(@Cast("const Nd4jLong") long maxIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo); - @Namespace("shape") public static native @Cast("Nd4jLong") long subArrayOffset(@Cast("const Nd4jLong") long maxIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, @Const IntBuffer dimsToExclude/*=nullptr*/, int dimsLen/*=-1*/); - @Namespace("shape") public static native @Cast("Nd4jLong") long subArrayOffset(@Cast("const Nd4jLong") long maxIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo); - @Namespace("shape") public static native @Cast("Nd4jLong") long subArrayOffset(@Cast("const Nd4jLong") long maxIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, @Const int[] dimsToExclude/*=nullptr*/, int dimsLen/*=-1*/); - @Namespace("shape") public static native @Cast("Nd4jLong") long subArrayOffset(@Cast("const Nd4jLong") long maxIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo); - - // max array is outer for min array, min array is sub-array of max array - // function calculates the coordinates of min array (and saves them into minIdxs) given coordinates of max array (already stored in maxIdxs) - // dimsToExclude - should be sorted in increasing order - // dimsLen - length of dimsToExclude, if not set (= -1), then it is calculated as maxRank - minRank - @Namespace("shape") public static native void maxIndToMinInd(IntPointer maxIdxs, IntPointer minIdxs, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, @Const IntPointer dimsToExclude/*=nullptr*/, int dimsLen/*=-1*/); - @Namespace("shape") public static native void maxIndToMinInd(IntPointer maxIdxs, IntPointer minIdxs, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo); - @Namespace("shape") public static native void maxIndToMinInd(IntBuffer maxIdxs, IntBuffer minIdxs, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, @Const IntBuffer dimsToExclude/*=nullptr*/, int dimsLen/*=-1*/); - @Namespace("shape") public static native void maxIndToMinInd(IntBuffer maxIdxs, IntBuffer minIdxs, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo); - @Namespace("shape") public static native void maxIndToMinInd(int[] maxIdxs, int[] minIdxs, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, @Const int[] dimsToExclude/*=nullptr*/, int dimsLen/*=-1*/); - @Namespace("shape") public static native void maxIndToMinInd(int[] maxIdxs, int[] minIdxs, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo); - - // calculate indexes of max-array, these output indexes correspond to one minIdx index of min-array which is sub-array of max-array - // dimsToExclude - should be sorted in increasing order - @Namespace("shape") public static native int outerArrayIndexes(IntPointer maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, @Const IntPointer dimsToExclude/*=nullptr*/); - @Namespace("shape") public static native int outerArrayIndexes(IntPointer maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo); - @Namespace("shape") public static native int outerArrayIndexes(IntBuffer maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, @Const IntBuffer dimsToExclude/*=nullptr*/); - @Namespace("shape") public static native int outerArrayIndexes(IntBuffer maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo); - @Namespace("shape") public static native int outerArrayIndexes(int[] maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, @Const int[] dimsToExclude/*=nullptr*/); - @Namespace("shape") public static native int outerArrayIndexes(int[] maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo); - - // calculate offsets of max-array, these offsets correspond to one minIdx index of min-array which is sub-array of max-array - // maxOffsets - will contain calculated offsets of max-array, buffer for maxOffsets should be allocated beforehand - // dimsToExclude - should be sorted in increasing order - // memBuff - auxiliary memory buffer (size = 2 * max_rank) for coordinates and increments storing, should be allocated beforehand - @Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongPointer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, IntPointer memBuff, @Const IntPointer dimsToExclude/*=nullptr*/); - @Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongPointer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, IntPointer memBuff); - @Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongBuffer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, IntBuffer memBuff, @Const IntBuffer dimsToExclude/*=nullptr*/); - @Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongBuffer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, IntBuffer memBuff); - @Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") long[] maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, int[] memBuff, @Const int[] dimsToExclude/*=nullptr*/); - @Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") long[] maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, int[] memBuff); - - // calculates offsets for entities (elements or sub-arrays), shape in context of sub-array means dimensions excluded from outer array - // rank is equal to size of shape - @Namespace("shape") public static native void calcOffsets(int rank, @Cast("const Nd4jLong*") LongPointer shape, @Cast("const Nd4jLong*") LongPointer strides, @Cast("Nd4jLong*") LongPointer offsets, byte order/*='c'*/); - @Namespace("shape") public static native void calcOffsets(int rank, @Cast("const Nd4jLong*") LongPointer shape, @Cast("const Nd4jLong*") LongPointer strides, @Cast("Nd4jLong*") LongPointer offsets); - @Namespace("shape") public static native void calcOffsets(int rank, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("const Nd4jLong*") LongBuffer strides, @Cast("Nd4jLong*") LongBuffer offsets, byte order/*='c'*/); - @Namespace("shape") public static native void calcOffsets(int rank, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("const Nd4jLong*") LongBuffer strides, @Cast("Nd4jLong*") LongBuffer offsets); - @Namespace("shape") public static native void calcOffsets(int rank, @Cast("const Nd4jLong*") long[] shape, @Cast("const Nd4jLong*") long[] strides, @Cast("Nd4jLong*") long[] offsets, byte order/*='c'*/); - @Namespace("shape") public static native void calcOffsets(int rank, @Cast("const Nd4jLong*") long[] shape, @Cast("const Nd4jLong*") long[] strides, @Cast("Nd4jLong*") long[] offsets); - @Namespace("shape") public static native void calcOffsets(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("Nd4jLong*") LongPointer offsets, byte order/*='c'*/); - @Namespace("shape") public static native void calcOffsets(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("Nd4jLong*") LongPointer offsets); - @Namespace("shape") public static native void calcOffsets(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("Nd4jLong*") LongBuffer offsets, byte order/*='c'*/); - @Namespace("shape") public static native void calcOffsets(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("Nd4jLong*") LongBuffer offsets); - @Namespace("shape") public static native void calcOffsets(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("Nd4jLong*") long[] offsets, byte order/*='c'*/); - @Namespace("shape") public static native void calcOffsets(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("Nd4jLong*") long[] offsets); - // ND4J_EXPORT void calcOffsets(const Nd4jLong *xShapeInfo, Nd4jLong*& xOffsets, const Nd4jLong *yShapeInfo, Nd4jLong*& yOffsets, const char order = 'c'); - // ND4J_EXPORT void calcOffsets(const Nd4jLong *xShapeInfo, Nd4jLong*& xOffsets, const Nd4jLong *yShapeInfo, Nd4jLong*& yOffsets, const Nd4jLong* zShapeInfo, Nd4jLong*& zOffsets, const char order = 'c'); - @Namespace("shape") public static native void shapeOldScalar(@Cast("sd::DataType") int dtype, @Cast("Nd4jLong*const") LongPointer buffer, byte order); - @Namespace("shape") public static native void shapeOldScalar(@Cast("sd::DataType") int dtype, @Cast("Nd4jLong*const") LongBuffer buffer, byte order); - @Namespace("shape") public static native void shapeOldScalar(@Cast("sd::DataType") int dtype, @Cast("Nd4jLong*const") long[] buffer, byte order); - - // deduce order and element-wise stride - // if array is scalar or unit length vector then ews = 1 and order is preserved - // if array is common vector then ews = stride of non-unity dimension and order is preserved - // if strides are normal/contiguous then ews = 1 and corresponding order is set, otherwise ews = 0 and order is preserved - @Namespace("shape") public static native void checkStridesEwsAndOrder(@Cast("Nd4jLong*") LongPointer shapeInfo, byte proposedOrder, int numOfNonUnitDims, @Cast("const Nd4jLong*") LongPointer shapeNoUnities, @Cast("const Nd4jLong*") LongPointer stridesNoUnities); - @Namespace("shape") public static native void checkStridesEwsAndOrder(@Cast("Nd4jLong*") LongBuffer shapeInfo, byte proposedOrder, int numOfNonUnitDims, @Cast("const Nd4jLong*") LongBuffer shapeNoUnities, @Cast("const Nd4jLong*") LongBuffer stridesNoUnities); - @Namespace("shape") public static native void checkStridesEwsAndOrder(@Cast("Nd4jLong*") long[] shapeInfo, byte proposedOrder, int numOfNonUnitDims, @Cast("const Nd4jLong*") long[] shapeNoUnities, @Cast("const Nd4jLong*") long[] stridesNoUnities); - @Namespace("shape") public static native void checkStridesEwsAndOrder(@Cast("Nd4jLong*") LongPointer shapeInfo); - @Namespace("shape") public static native void checkStridesEwsAndOrder(@Cast("Nd4jLong*") LongBuffer shapeInfo); - @Namespace("shape") public static native void checkStridesEwsAndOrder(@Cast("Nd4jLong*") long[] shapeInfo); - - /** - * processes whole set of sub-arrays - * evaluates shapeInfo of sub-arrays (all sub-arrays have the same shapeInfo) and their buffer offsets (each sub-array has its own unique offset from original this-buffer) - * arguments: - * wholeShapeInfo - original shapeInfo of whole array - * numOfSubArrs - number of sub-arrays, size of subArrOffsets is equal to numOfSubArrs - * dimsSize - size of dimsToExclude, if dimsSize = array rank or dimsSize = 0 it means sub-array is whole array, copy of wholeShapeInfo and one zero offset will be returned - * dimsToExclude - MUST BE SORTED, dimensions to evaluate sub-array along, i.e. when shape is [2,3,4,5] and dimsToExclude={0,2}, then there will be 8 sub-arrays with shape [3,5] - * subArrShapeInfo - output argument, contains shapeInfo (same for all sub-arrays) - * subArrOffsets - output argument, contains successive sub-arrays offsets from original this-buffer - * keepUnitiesInShape - if false then eliminate unities from sub-array shapeInfo, for example {1,a,1,b} -> {a,b} - */ - @Namespace("shape") public static native void calcSubArrsShapeInfoAndOffsets(@Cast("const Nd4jLong*") LongPointer wholeShapeInfo, @Cast("const Nd4jLong") long numOfSubArrs, int dimsSize, @Const IntPointer dimsToExclude, @Cast("Nd4jLong*") LongPointer subArrShapeInfo, @Cast("Nd4jLong*") LongPointer subArrOffsets, @Cast("bool") boolean keepUnitiesInShape/*=false*/); - @Namespace("shape") public static native void calcSubArrsShapeInfoAndOffsets(@Cast("const Nd4jLong*") LongPointer wholeShapeInfo, @Cast("const Nd4jLong") long numOfSubArrs, int dimsSize, @Const IntPointer dimsToExclude, @Cast("Nd4jLong*") LongPointer subArrShapeInfo, @Cast("Nd4jLong*") LongPointer subArrOffsets); - @Namespace("shape") public static native void calcSubArrsShapeInfoAndOffsets(@Cast("const Nd4jLong*") LongBuffer wholeShapeInfo, @Cast("const Nd4jLong") long numOfSubArrs, int dimsSize, @Const IntBuffer dimsToExclude, @Cast("Nd4jLong*") LongBuffer subArrShapeInfo, @Cast("Nd4jLong*") LongBuffer subArrOffsets, @Cast("bool") boolean keepUnitiesInShape/*=false*/); - @Namespace("shape") public static native void calcSubArrsShapeInfoAndOffsets(@Cast("const Nd4jLong*") LongBuffer wholeShapeInfo, @Cast("const Nd4jLong") long numOfSubArrs, int dimsSize, @Const IntBuffer dimsToExclude, @Cast("Nd4jLong*") LongBuffer subArrShapeInfo, @Cast("Nd4jLong*") LongBuffer subArrOffsets); - @Namespace("shape") public static native void calcSubArrsShapeInfoAndOffsets(@Cast("const Nd4jLong*") long[] wholeShapeInfo, @Cast("const Nd4jLong") long numOfSubArrs, int dimsSize, @Const int[] dimsToExclude, @Cast("Nd4jLong*") long[] subArrShapeInfo, @Cast("Nd4jLong*") long[] subArrOffsets, @Cast("bool") boolean keepUnitiesInShape/*=false*/); - @Namespace("shape") public static native void calcSubArrsShapeInfoAndOffsets(@Cast("const Nd4jLong*") long[] wholeShapeInfo, @Cast("const Nd4jLong") long numOfSubArrs, int dimsSize, @Const int[] dimsToExclude, @Cast("Nd4jLong*") long[] subArrShapeInfo, @Cast("Nd4jLong*") long[] subArrOffsets); - - /** - * processes only one sub-array, evaluates shapeInfo of sub-array and its buffer offset from original array - * arguments: - * idx - input argument, intervals of indexes which define the sub-array to point on, - * when isStrided = false then idx has form {dim0Start,dim0End, dim1Start,dim1End, ....} and length (2 * maxRank) - * when isStrided = true then idx has form {dim0Start,dim0End,dim0Stride, dim1Start,dim1End,dim1Stride, ....} and length (3 * maxRank) - * when (dimStart == dimEnd) then whole range will be used for current dimension - * maxShapeInfo - input argument, shapeInfo of original array - * minShapeInfo - output argument, shapeInfo of sub-array to be deduced - * minOffset - output argument, offset of sub-array buffer offsets from original buffer - * keepUnitiesInShape - input argument, if false then eliminate unities from sub-array shapeInfo, for example {1,a,1,b} -> {a,b} - * isStrided - input argument, if true then idx has length (3 * this->rankOf()) and contains additional stride numbers which correspond to stride between dimStart and dimEnd, - * numOfUntiesInMinShape - input argument, number of occurrences in idx when (dimEnd - dimStart) = 1 - */ - @Namespace("shape") public static native void calcSubArrShapeInfoAndOffset(@Cast("const Nd4jLong*") LongPointer idx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("Nd4jLong*") LongPointer minShapeInfo, @Cast("Nd4jLong*") @ByRef LongPointer minOffset, @Cast("const bool") boolean keepUnitiesInShape/*=false*/, @Cast("const bool") boolean isStrided/*=false*/, int numOfUntiesInMinShape/*=0*/); - @Namespace("shape") public static native void calcSubArrShapeInfoAndOffset(@Cast("const Nd4jLong*") LongPointer idx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("Nd4jLong*") LongPointer minShapeInfo, @Cast("Nd4jLong*") @ByRef LongPointer minOffset); - @Namespace("shape") public static native void calcSubArrShapeInfoAndOffset(@Cast("const Nd4jLong*") LongBuffer idx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("Nd4jLong*") LongBuffer minShapeInfo, @Cast("Nd4jLong*") @ByRef LongBuffer minOffset, @Cast("const bool") boolean keepUnitiesInShape/*=false*/, @Cast("const bool") boolean isStrided/*=false*/, int numOfUntiesInMinShape/*=0*/); - @Namespace("shape") public static native void calcSubArrShapeInfoAndOffset(@Cast("const Nd4jLong*") LongBuffer idx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("Nd4jLong*") LongBuffer minShapeInfo, @Cast("Nd4jLong*") @ByRef LongBuffer minOffset); - @Namespace("shape") public static native void calcSubArrShapeInfoAndOffset(@Cast("const Nd4jLong*") long[] idx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("Nd4jLong*") long[] minShapeInfo, @Cast("Nd4jLong*") @ByRef long[] minOffset, @Cast("const bool") boolean keepUnitiesInShape/*=false*/, @Cast("const bool") boolean isStrided/*=false*/, int numOfUntiesInMinShape/*=0*/); - @Namespace("shape") public static native void calcSubArrShapeInfoAndOffset(@Cast("const Nd4jLong*") long[] idx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("Nd4jLong*") long[] minShapeInfo, @Cast("Nd4jLong*") @ByRef long[] minOffset); - - /** - * for example inShapeInfo is {3, 2,1,4, 4,4,1, 16384,1,99} - * then output shapeNoUnities will contain {2,4, 4,1} - that is only shape and strides, no rank/type/ews/order - * stridesNoUnities will point on strides in shapeNoUnities that is on {4,1} - * returns number of non-unity dimensions in inShapeInfo - * if there is no unities in inShapeInfo, then no copy procedure will be performed and shapeNoUnities/stridesNoUnities will point on corresponding places in inShapeInfo - */ - @Namespace("shape") public static native int excludeUnitiesFromShapeInfo(@Cast("const Nd4jLong*") LongPointer inShapeInfo, @Cast("Nd4jLong*&") @ByPtrRef LongPointer shapeNoUnities, @Cast("Nd4jLong*&") @ByPtrRef LongPointer stridesNoUnities); - @Namespace("shape") public static native int excludeUnitiesFromShapeInfo(@Cast("const Nd4jLong*") LongBuffer inShapeInfo, @Cast("Nd4jLong*&") @ByPtrRef LongBuffer shapeNoUnities, @Cast("Nd4jLong*&") @ByPtrRef LongBuffer stridesNoUnities); - @Namespace("shape") public static native int excludeUnitiesFromShapeInfo(@Cast("const Nd4jLong*") long[] inShapeInfo, @Cast("Nd4jLong*&") @ByPtrRef long[] shapeNoUnities, @Cast("Nd4jLong*&") @ByPtrRef long[] stridesNoUnities); - - /** - * for example inShapeInfo is {3, 2,1,3,1,4, 12,12,4,4,1, 16384,1,99}, dimsToExclude = {1,3}, dimsSize = 2 - * then outShapeInfo will contain {3, 2,3,4, 12,4,1, 16384,1,99} - */ - @Namespace("shape") public static native void excludeUnitiesFromShapeInfo(@Cast("const Nd4jLong*") LongPointer inShapeInfo, int dimsSize, @Const IntPointer dimsToExclude, @Cast("Nd4jLong*") LongPointer outShapeInfo); - @Namespace("shape") public static native void excludeUnitiesFromShapeInfo(@Cast("const Nd4jLong*") LongBuffer inShapeInfo, int dimsSize, @Const IntBuffer dimsToExclude, @Cast("Nd4jLong*") LongBuffer outShapeInfo); - @Namespace("shape") public static native void excludeUnitiesFromShapeInfo(@Cast("const Nd4jLong*") long[] inShapeInfo, int dimsSize, @Const int[] dimsToExclude, @Cast("Nd4jLong*") long[] outShapeInfo); - - /** - * get stride over contiguous axis (contiguous axis must have stride = 1) - * for example when inShapeInfo is {4, 2,5,4,3, 60,1,5,20, 16384,0,99} then output is 5 (that is smallest stride in inShapeInfo except those equal to 1) - */ - // INLINEDEF _CUDA_HD Nd4jLong strideOverContigAxis(const int axis, const Nd4jLong* inShapeInfo); - - - - - - -//END HEADERS - - - //BEGIN IMPLEMENTATIONS +// SD_EXPORT _CUDA_HD int rearMostLeftOverItem(Nd4jLong *data,int +// length,Nd4jLong *dimension,int dimensionLength); +/** + * Get an offset for retrieval + * from a data buffer + * based on the given + * shape stride and given indices + * @param baseOffset the offset to start from + * @param shape the shape of the array + * @param stride the stride of the array + * @param indices the indices to iterate over + * @return the double at the specified index + */ +@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongPointer shapeInfo, + @Cast("const Nd4jLong*") LongPointer coords, + @Cast("Nd4jLong") long baseOffset/*=0*/); +@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongPointer shapeInfo, + @Cast("const Nd4jLong*") LongPointer coords); +@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongBuffer shapeInfo, + @Cast("const Nd4jLong*") LongBuffer coords, + @Cast("Nd4jLong") long baseOffset/*=0*/); +@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongBuffer shapeInfo, + @Cast("const Nd4jLong*") LongBuffer coords); +@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") long[] shapeInfo, + @Cast("const Nd4jLong*") long[] coords, + @Cast("Nd4jLong") long baseOffset/*=0*/); +@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") long[] shapeInfo, + @Cast("const Nd4jLong*") long[] coords); +@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongPointer shapeInfo, + @Const IntPointer coords, + @Cast("Nd4jLong") long baseOffset/*=0*/); +@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongPointer shapeInfo, + @Const IntPointer coords); +@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongBuffer shapeInfo, + @Const IntBuffer coords, + @Cast("Nd4jLong") long baseOffset/*=0*/); +@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongBuffer shapeInfo, + @Const IntBuffer coords); +@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") long[] shapeInfo, + @Const int[] coords, + @Cast("Nd4jLong") long baseOffset/*=0*/); +@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") long[] shapeInfo, + @Const int[] coords); + +@Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer createShapeInfo(@Cast("Nd4jLong*") LongPointer shape, @Cast("Nd4jLong*") LongPointer stride, + int rank); +@Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer createShapeInfo(@Cast("Nd4jLong*") LongBuffer shape, @Cast("Nd4jLong*") LongBuffer stride, + int rank); +@Namespace("shape") public static native @Cast("Nd4jLong*") long[] createShapeInfo(@Cast("Nd4jLong*") long[] shape, @Cast("Nd4jLong*") long[] stride, + int rank); + +@Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer createShapeInfo(@Cast("Nd4jLong*") LongPointer shape, @Cast("Nd4jLong*") LongPointer stride, + int rank, @Cast("Nd4jLong*") LongPointer buffer); +@Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer createShapeInfo(@Cast("Nd4jLong*") LongBuffer shape, @Cast("Nd4jLong*") LongBuffer stride, + int rank, @Cast("Nd4jLong*") LongBuffer buffer); +@Namespace("shape") public static native @Cast("Nd4jLong*") long[] createShapeInfo(@Cast("Nd4jLong*") long[] shape, @Cast("Nd4jLong*") long[] stride, + int rank, @Cast("Nd4jLong*") long[] buffer); -// #ifdef __CUDACC__ -// #endif +/** + * Convert a linear index to the corresponding coordinates + * for example if shape is {2, 4}, then index 5 corresponds to coordinates [1, + * 1] + */ +@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer shapeInfo, + @Cast("Nd4jLong*") LongPointer coords); +@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer shapeInfo, + @Cast("Nd4jLong*") LongBuffer coords); +@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") long[] shapeInfo, + @Cast("Nd4jLong*") long[] coords); +@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer shapeInfo, + IntPointer coords); +@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer shapeInfo, + IntBuffer coords); +@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") long[] shapeInfo, + int[] coords); +@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, + @Cast("const Nd4jLong*") LongPointer shape, @Cast("Nd4jLong*") LongPointer coords); +@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, + @Cast("const Nd4jLong*") LongBuffer shape, @Cast("Nd4jLong*") LongBuffer coords); +@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, + @Cast("const Nd4jLong*") long[] shape, @Cast("Nd4jLong*") long[] coords); +@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, + @Cast("const Nd4jLong*") LongPointer shape, IntPointer coords); +@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, + @Cast("const Nd4jLong*") LongBuffer shape, IntBuffer coords); +@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, + @Cast("const Nd4jLong*") long[] shape, int[] coords); + +@Namespace("shape") public static native void index2coordsCPU(@Cast("const Nd4jLong") long startIndex, + @Cast("const Nd4jLong") long index, + @Cast("const Nd4jLong*") LongPointer shapeInfo, + @Cast("Nd4jLong*") LongPointer coords); +@Namespace("shape") public static native void index2coordsCPU(@Cast("const Nd4jLong") long startIndex, + @Cast("const Nd4jLong") long index, + @Cast("const Nd4jLong*") LongBuffer shapeInfo, + @Cast("Nd4jLong*") LongBuffer coords); +@Namespace("shape") public static native void index2coordsCPU(@Cast("const Nd4jLong") long startIndex, + @Cast("const Nd4jLong") long index, + @Cast("const Nd4jLong*") long[] shapeInfo, + @Cast("Nd4jLong*") long[] coords); +@Namespace("shape") public static native void index2coordsCPU(@Cast("const Nd4jLong") long startIndex, + @Cast("const Nd4jLong") long index, + @Cast("const Nd4jLong*") LongPointer shapeInfo, IntPointer coords); +@Namespace("shape") public static native void index2coordsCPU(@Cast("const Nd4jLong") long startIndex, + @Cast("const Nd4jLong") long index, + @Cast("const Nd4jLong*") LongBuffer shapeInfo, IntBuffer coords); +@Namespace("shape") public static native void index2coordsCPU(@Cast("const Nd4jLong") long startIndex, + @Cast("const Nd4jLong") long index, + @Cast("const Nd4jLong*") long[] shapeInfo, int[] coords); /** -* Length of a tad given -* the shape information -*/ + * take into account only dimensions stored in tadDims, tadDims must be sorted + * in increasing order! + */ +@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer shapeInfo, + IntPointer coords, int dimsSize, + @Const IntPointer tadDims); +@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer shapeInfo, + IntBuffer coords, int dimsSize, + @Const IntBuffer tadDims); +@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") long[] shapeInfo, + int[] coords, int dimsSize, + @Const int[] tadDims); + +/** + * Convert coordinates to the corresponding linear index (sequence number in + * other words) for example if shape is {2, 4} and coordinates [1, 1] then index + * 5 is returned + */ +@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongPointer shapeInfo, + @Cast("const Nd4jLong*") LongPointer coords); +@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongBuffer shapeInfo, + @Cast("const Nd4jLong*") LongBuffer coords); +@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") long[] shapeInfo, + @Cast("const Nd4jLong*") long[] coords); +@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongPointer shapeInfo, + @Const IntPointer coords); +@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongBuffer shapeInfo, + @Const IntBuffer coords); +@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") long[] shapeInfo, + @Const int[] coords); +@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") LongPointer shape, + @Const IntPointer coords); +@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") LongBuffer shape, + @Const IntBuffer coords); +@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") long[] shape, + @Const int[] coords); +/** + * take into account only dimensions stored in tadDims, tadDims must be sorted + * in increasing order! + */ +@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongPointer shapeInfo, + @Const IntPointer coords, int dimsSize, + @Const IntPointer tadDims); +@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongBuffer shapeInfo, + @Const IntBuffer coords, int dimsSize, + @Const IntBuffer tadDims); +@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") long[] shapeInfo, + @Const int[] coords, int dimsSize, + @Const int[] tadDims); + +/** + * increment n-dimensional array by one iteration by changing coord + * appropriately for example we have array with shape {2, 3}: + * - if input coord = {0,1}, then output coord = {0,2} + * - if input coord = {0,2}, then output coord = {1,0} + * so the aim is to produce following subsequence of coord: {0,0}, {0,1}, {0,2}, + * {1,0}, {1,1}, {1,2} + */ + +/* calculates an array buffer offset for given "index" using following formula: + * offset = coord_0*stride_0 + coord_1*stride_1 + ... + + * coord_{rank-1}*stride_{rank-1} + */ +@Namespace("shape") public static native @Cast("uint") int getIndexOffset(@Cast("uint") int index, @Cast("const uint*") IntPointer shapeInfo); +@Namespace("shape") public static native @Cast("uint") int getIndexOffset(@Cast("uint") int index, @Cast("const uint*") IntBuffer shapeInfo); +@Namespace("shape") public static native @Cast("uint") int getIndexOffset(@Cast("uint") int index, @Cast("const uint*") int[] shapeInfo); +@Namespace("shape") public static native @Cast("Nd4jLong") long getIndexOffset(@Cast("Nd4jLong") long index, + @Cast("const Nd4jLong*") LongPointer shapeInfo); +@Namespace("shape") public static native @Cast("Nd4jLong") long getIndexOffset(@Cast("Nd4jLong") long index, + @Cast("const Nd4jLong*") LongBuffer shapeInfo); +@Namespace("shape") public static native @Cast("Nd4jLong") long getIndexOffset(@Cast("Nd4jLong") long index, + @Cast("const Nd4jLong*") long[] shapeInfo); +@Namespace("shape") public static native @Cast("Nd4jLong") long indexOffset(@Cast("Nd4jLong") long index, + @Cast("const Nd4jLong*") LongPointer lShapeInfo, + @Cast("const uint*") IntPointer uShapeInfo, + @Cast("const bool") boolean useUnsigned); +@Namespace("shape") public static native @Cast("Nd4jLong") long indexOffset(@Cast("Nd4jLong") long index, + @Cast("const Nd4jLong*") LongBuffer lShapeInfo, + @Cast("const uint*") IntBuffer uShapeInfo, + @Cast("const bool") boolean useUnsigned); +@Namespace("shape") public static native @Cast("Nd4jLong") long indexOffset(@Cast("Nd4jLong") long index, + @Cast("const Nd4jLong*") long[] lShapeInfo, + @Cast("const uint*") int[] uShapeInfo, + @Cast("const bool") boolean useUnsigned); + +@Namespace("shape") public static native void printShapeInfo(@Cast("Nd4jLong*") LongPointer shapeInfo); +@Namespace("shape") public static native void printShapeInfo(@Cast("Nd4jLong*") LongBuffer shapeInfo); +@Namespace("shape") public static native void printShapeInfo(@Cast("Nd4jLong*") long[] shapeInfo); + +@Namespace("shape") public static native void printShapeInfoLinear(@Cast("const Nd4jLong*") LongPointer shapeInfo); +@Namespace("shape") public static native void printShapeInfoLinear(@Cast("const Nd4jLong*") LongBuffer shapeInfo); +@Namespace("shape") public static native void printShapeInfoLinear(@Cast("const Nd4jLong*") long[] shapeInfo); + +@Namespace("shape") public static native void printShapeInfoLinear(@Cast("char*") String msg, + @Cast("const Nd4jLong*") LongPointer shapeInfo); +@Namespace("shape") public static native void printShapeInfoLinear(@Cast("char*") BytePointer msg, + @Cast("const Nd4jLong*") LongBuffer shapeInfo); +@Namespace("shape") public static native void printShapeInfoLinear(@Cast("char*") String msg, + @Cast("const Nd4jLong*") long[] shapeInfo); +@Namespace("shape") public static native void printShapeInfoLinear(@Cast("char*") BytePointer msg, + @Cast("const Nd4jLong*") LongPointer shapeInfo); +@Namespace("shape") public static native void printShapeInfoLinear(@Cast("char*") String msg, + @Cast("const Nd4jLong*") LongBuffer shapeInfo); +@Namespace("shape") public static native void printShapeInfoLinear(@Cast("char*") BytePointer msg, + @Cast("const Nd4jLong*") long[] shapeInfo); + +@Namespace("shape") public static native void printShapeInfoLinear(@Cast("char*") String msg, int rank, + @Cast("const Nd4jLong*") LongPointer shape, + @Cast("const Nd4jLong*") LongPointer strides); +@Namespace("shape") public static native void printShapeInfoLinear(@Cast("char*") BytePointer msg, int rank, + @Cast("const Nd4jLong*") LongBuffer shape, + @Cast("const Nd4jLong*") LongBuffer strides); +@Namespace("shape") public static native void printShapeInfoLinear(@Cast("char*") String msg, int rank, + @Cast("const Nd4jLong*") long[] shape, + @Cast("const Nd4jLong*") long[] strides); +@Namespace("shape") public static native void printShapeInfoLinear(@Cast("char*") BytePointer msg, int rank, + @Cast("const Nd4jLong*") LongPointer shape, + @Cast("const Nd4jLong*") LongPointer strides); +@Namespace("shape") public static native void printShapeInfoLinear(@Cast("char*") String msg, int rank, + @Cast("const Nd4jLong*") LongBuffer shape, + @Cast("const Nd4jLong*") LongBuffer strides); +@Namespace("shape") public static native void printShapeInfoLinear(@Cast("char*") BytePointer msg, int rank, + @Cast("const Nd4jLong*") long[] shape, + @Cast("const Nd4jLong*") long[] strides); + +@Namespace("shape") public static native void printIntArray(@Cast("const Nd4jLong*") LongPointer arr, int length); +@Namespace("shape") public static native void printIntArray(@Cast("const Nd4jLong*") LongBuffer arr, int length); +@Namespace("shape") public static native void printIntArray(@Cast("const Nd4jLong*") long[] arr, int length); +@Namespace("shape") public static native void printIntArray(@Const IntPointer arr, int length); +@Namespace("shape") public static native void printIntArray(@Const IntBuffer arr, int length); +@Namespace("shape") public static native void printIntArray(@Const int[] arr, int length); + +@Namespace("shape") public static native void printArray(FloatPointer arr, int length); +@Namespace("shape") public static native void printArray(FloatBuffer arr, int length); +@Namespace("shape") public static native void printArray(float[] arr, int length); + +@Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer shapeBufferOfNpy(int rank, @Cast("unsigned int*") IntPointer shape, + @Cast("bool") boolean fortranOrder); +@Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer shapeBufferOfNpy(int rank, @Cast("unsigned int*") IntBuffer shape, + @Cast("bool") boolean fortranOrder); +@Namespace("shape") public static native @Cast("Nd4jLong*") long[] shapeBufferOfNpy(int rank, @Cast("unsigned int*") int[] shape, + @Cast("bool") boolean fortranOrder); + +// SD_EXPORT _CUDA_HD Nd4jLong *shapeBufferOfNpyBuffer(char *buffer); + +// this function checks the consistence of dimensions with array rank (negative +// dimensions, too large dimensions, too big number of dimensions) also sort +// input array of dimensions, this operation is also necessary for creating TAD +// object +@Namespace("shape") public static native void checkDimensions(int rank, + @StdVector IntPointer dimensions); +@Namespace("shape") public static native void checkDimensions(int rank, + @StdVector IntBuffer dimensions); +@Namespace("shape") public static native void checkDimensions(int rank, + @StdVector int[] dimensions); + +// function calculates linear index of array min, min is sub-array of max, index +// to be returned is min-array's index and corresponds to maxIdx of max array +// dimsToExclude - should be sorted in increasing order +@Namespace("shape") public static native @Cast("Nd4jLong") long subArrayIndex(@Cast("const Nd4jLong") long maxIdx, + @Cast("const Nd4jLong*") LongPointer maxShapeInfo, + @Cast("const Nd4jLong*") LongPointer minShapeInfo, + @Const IntPointer dimsToExclude/*=nullptr*/, + int dimsLen/*=-1*/); +@Namespace("shape") public static native @Cast("Nd4jLong") long subArrayIndex(@Cast("const Nd4jLong") long maxIdx, + @Cast("const Nd4jLong*") LongPointer maxShapeInfo, + @Cast("const Nd4jLong*") LongPointer minShapeInfo); +@Namespace("shape") public static native @Cast("Nd4jLong") long subArrayIndex(@Cast("const Nd4jLong") long maxIdx, + @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, + @Cast("const Nd4jLong*") LongBuffer minShapeInfo, + @Const IntBuffer dimsToExclude/*=nullptr*/, + int dimsLen/*=-1*/); +@Namespace("shape") public static native @Cast("Nd4jLong") long subArrayIndex(@Cast("const Nd4jLong") long maxIdx, + @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, + @Cast("const Nd4jLong*") LongBuffer minShapeInfo); +@Namespace("shape") public static native @Cast("Nd4jLong") long subArrayIndex(@Cast("const Nd4jLong") long maxIdx, + @Cast("const Nd4jLong*") long[] maxShapeInfo, + @Cast("const Nd4jLong*") long[] minShapeInfo, + @Const int[] dimsToExclude/*=nullptr*/, + int dimsLen/*=-1*/); +@Namespace("shape") public static native @Cast("Nd4jLong") long subArrayIndex(@Cast("const Nd4jLong") long maxIdx, + @Cast("const Nd4jLong*") long[] maxShapeInfo, + @Cast("const Nd4jLong*") long[] minShapeInfo); + +// function calculates absolute offset of min array, min is sub-array of max, +// offset to be returned corresponds to maxIdx of max array dimsToExclude - +// should be sorted in increasing order +@Namespace("shape") public static native @Cast("Nd4jLong") long subArrayOffset(@Cast("const Nd4jLong") long maxIdx, + @Cast("const Nd4jLong*") LongPointer maxShapeInfo, + @Cast("const Nd4jLong*") LongPointer minShapeInfo, + @Const IntPointer dimsToExclude/*=nullptr*/, + int dimsLen/*=-1*/); +@Namespace("shape") public static native @Cast("Nd4jLong") long subArrayOffset(@Cast("const Nd4jLong") long maxIdx, + @Cast("const Nd4jLong*") LongPointer maxShapeInfo, + @Cast("const Nd4jLong*") LongPointer minShapeInfo); +@Namespace("shape") public static native @Cast("Nd4jLong") long subArrayOffset(@Cast("const Nd4jLong") long maxIdx, + @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, + @Cast("const Nd4jLong*") LongBuffer minShapeInfo, + @Const IntBuffer dimsToExclude/*=nullptr*/, + int dimsLen/*=-1*/); +@Namespace("shape") public static native @Cast("Nd4jLong") long subArrayOffset(@Cast("const Nd4jLong") long maxIdx, + @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, + @Cast("const Nd4jLong*") LongBuffer minShapeInfo); +@Namespace("shape") public static native @Cast("Nd4jLong") long subArrayOffset(@Cast("const Nd4jLong") long maxIdx, + @Cast("const Nd4jLong*") long[] maxShapeInfo, + @Cast("const Nd4jLong*") long[] minShapeInfo, + @Const int[] dimsToExclude/*=nullptr*/, + int dimsLen/*=-1*/); +@Namespace("shape") public static native @Cast("Nd4jLong") long subArrayOffset(@Cast("const Nd4jLong") long maxIdx, + @Cast("const Nd4jLong*") long[] maxShapeInfo, + @Cast("const Nd4jLong*") long[] minShapeInfo); + +// max array is outer for min array, min array is sub-array of max array +// function calculates the coordinates of min array (and saves them into +// minIdxs) given coordinates of max array (already stored in maxIdxs) +// dimsToExclude - should be sorted in increasing order +// dimsLen - length of dimsToExclude, if not set (= -1), then it is calculated +// as maxRank - minRank +@Namespace("shape") public static native void maxIndToMinInd(IntPointer maxIdxs, IntPointer minIdxs, + @Cast("const Nd4jLong*") LongPointer maxShapeInfo, + @Cast("const Nd4jLong*") LongPointer minShapeInfo, + @Const IntPointer dimsToExclude/*=nullptr*/, + int dimsLen/*=-1*/); +@Namespace("shape") public static native void maxIndToMinInd(IntPointer maxIdxs, IntPointer minIdxs, + @Cast("const Nd4jLong*") LongPointer maxShapeInfo, + @Cast("const Nd4jLong*") LongPointer minShapeInfo); +@Namespace("shape") public static native void maxIndToMinInd(IntBuffer maxIdxs, IntBuffer minIdxs, + @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, + @Cast("const Nd4jLong*") LongBuffer minShapeInfo, + @Const IntBuffer dimsToExclude/*=nullptr*/, + int dimsLen/*=-1*/); +@Namespace("shape") public static native void maxIndToMinInd(IntBuffer maxIdxs, IntBuffer minIdxs, + @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, + @Cast("const Nd4jLong*") LongBuffer minShapeInfo); +@Namespace("shape") public static native void maxIndToMinInd(int[] maxIdxs, int[] minIdxs, + @Cast("const Nd4jLong*") long[] maxShapeInfo, + @Cast("const Nd4jLong*") long[] minShapeInfo, + @Const int[] dimsToExclude/*=nullptr*/, + int dimsLen/*=-1*/); +@Namespace("shape") public static native void maxIndToMinInd(int[] maxIdxs, int[] minIdxs, + @Cast("const Nd4jLong*") long[] maxShapeInfo, + @Cast("const Nd4jLong*") long[] minShapeInfo); + +// calculate indexes of max-array, these output indexes correspond to one minIdx +// index of min-array which is sub-array of max-array dimsToExclude - should be +// sorted in increasing order +@Namespace("shape") public static native int outerArrayIndexes(IntPointer maxIdxs, @Cast("const Nd4jLong") long minIdx, + @Cast("const Nd4jLong*") LongPointer maxShapeInfo, + @Cast("const Nd4jLong*") LongPointer minShapeInfo, + @Const IntPointer dimsToExclude/*=nullptr*/); +@Namespace("shape") public static native int outerArrayIndexes(IntPointer maxIdxs, @Cast("const Nd4jLong") long minIdx, + @Cast("const Nd4jLong*") LongPointer maxShapeInfo, + @Cast("const Nd4jLong*") LongPointer minShapeInfo); +@Namespace("shape") public static native int outerArrayIndexes(IntBuffer maxIdxs, @Cast("const Nd4jLong") long minIdx, + @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, + @Cast("const Nd4jLong*") LongBuffer minShapeInfo, + @Const IntBuffer dimsToExclude/*=nullptr*/); +@Namespace("shape") public static native int outerArrayIndexes(IntBuffer maxIdxs, @Cast("const Nd4jLong") long minIdx, + @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, + @Cast("const Nd4jLong*") LongBuffer minShapeInfo); +@Namespace("shape") public static native int outerArrayIndexes(int[] maxIdxs, @Cast("const Nd4jLong") long minIdx, + @Cast("const Nd4jLong*") long[] maxShapeInfo, + @Cast("const Nd4jLong*") long[] minShapeInfo, + @Const int[] dimsToExclude/*=nullptr*/); +@Namespace("shape") public static native int outerArrayIndexes(int[] maxIdxs, @Cast("const Nd4jLong") long minIdx, + @Cast("const Nd4jLong*") long[] maxShapeInfo, + @Cast("const Nd4jLong*") long[] minShapeInfo); + +// calculate offsets of max-array, these offsets correspond to one minIdx index +// of min-array which is sub-array of max-array maxOffsets - will contain +// calculated offsets of max-array, buffer for maxOffsets should be allocated +// beforehand dimsToExclude - should be sorted in increasing order memBuff - +// auxiliary memory buffer (size = 2 * max_rank) for coordinates and increments +// storing, should be allocated beforehand +@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongPointer maxOffsets, + @Cast("const Nd4jLong") long minIdx, + @Cast("const Nd4jLong*") LongPointer maxShapeInfo, + @Cast("const Nd4jLong*") LongPointer minShapeInfo, + IntPointer memBuff, + @Const IntPointer dimsToExclude/*=nullptr*/); +@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongPointer maxOffsets, + @Cast("const Nd4jLong") long minIdx, + @Cast("const Nd4jLong*") LongPointer maxShapeInfo, + @Cast("const Nd4jLong*") LongPointer minShapeInfo, + IntPointer memBuff); +@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongBuffer maxOffsets, + @Cast("const Nd4jLong") long minIdx, + @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, + @Cast("const Nd4jLong*") LongBuffer minShapeInfo, + IntBuffer memBuff, + @Const IntBuffer dimsToExclude/*=nullptr*/); +@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongBuffer maxOffsets, + @Cast("const Nd4jLong") long minIdx, + @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, + @Cast("const Nd4jLong*") LongBuffer minShapeInfo, + IntBuffer memBuff); +@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") long[] maxOffsets, + @Cast("const Nd4jLong") long minIdx, + @Cast("const Nd4jLong*") long[] maxShapeInfo, + @Cast("const Nd4jLong*") long[] minShapeInfo, + int[] memBuff, + @Const int[] dimsToExclude/*=nullptr*/); +@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") long[] maxOffsets, + @Cast("const Nd4jLong") long minIdx, + @Cast("const Nd4jLong*") long[] maxShapeInfo, + @Cast("const Nd4jLong*") long[] minShapeInfo, + int[] memBuff); + +// calculates offsets for entities (elements or sub-arrays), shape in context of +// sub-array means dimensions excluded from outer array rank is equal to size of +// shape +@Namespace("shape") public static native void calcOffsets(int rank, @Cast("const Nd4jLong*") LongPointer shape, + @Cast("const Nd4jLong*") LongPointer strides, @Cast("Nd4jLong*") LongPointer offsets, + byte order/*='c'*/); +@Namespace("shape") public static native void calcOffsets(int rank, @Cast("const Nd4jLong*") LongPointer shape, + @Cast("const Nd4jLong*") LongPointer strides, @Cast("Nd4jLong*") LongPointer offsets); +@Namespace("shape") public static native void calcOffsets(int rank, @Cast("const Nd4jLong*") LongBuffer shape, + @Cast("const Nd4jLong*") LongBuffer strides, @Cast("Nd4jLong*") LongBuffer offsets, + byte order/*='c'*/); +@Namespace("shape") public static native void calcOffsets(int rank, @Cast("const Nd4jLong*") LongBuffer shape, + @Cast("const Nd4jLong*") LongBuffer strides, @Cast("Nd4jLong*") LongBuffer offsets); +@Namespace("shape") public static native void calcOffsets(int rank, @Cast("const Nd4jLong*") long[] shape, + @Cast("const Nd4jLong*") long[] strides, @Cast("Nd4jLong*") long[] offsets, + byte order/*='c'*/); +@Namespace("shape") public static native void calcOffsets(int rank, @Cast("const Nd4jLong*") long[] shape, + @Cast("const Nd4jLong*") long[] strides, @Cast("Nd4jLong*") long[] offsets); +@Namespace("shape") public static native void calcOffsets(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("Nd4jLong*") LongPointer offsets, + byte order/*='c'*/); +@Namespace("shape") public static native void calcOffsets(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("Nd4jLong*") LongPointer offsets); +@Namespace("shape") public static native void calcOffsets(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("Nd4jLong*") LongBuffer offsets, + byte order/*='c'*/); +@Namespace("shape") public static native void calcOffsets(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("Nd4jLong*") LongBuffer offsets); +@Namespace("shape") public static native void calcOffsets(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("Nd4jLong*") long[] offsets, + byte order/*='c'*/); +@Namespace("shape") public static native void calcOffsets(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("Nd4jLong*") long[] offsets); +// SD_EXPORT void calcOffsets(const Nd4jLong *xShapeInfo, Nd4jLong*& xOffsets, +// const Nd4jLong *yShapeInfo, Nd4jLong*& yOffsets, const char order = 'c'); +// SD_EXPORT void calcOffsets(const Nd4jLong *xShapeInfo, Nd4jLong*& xOffsets, +// const Nd4jLong *yShapeInfo, Nd4jLong*& yOffsets, const Nd4jLong* zShapeInfo, +// Nd4jLong*& zOffsets, const char order = 'c'); +@Namespace("shape") public static native void shapeOldScalar(@Cast("sd::DataType") int dtype, + @Cast("Nd4jLong*const") LongPointer buffer, + byte order); +@Namespace("shape") public static native void shapeOldScalar(@Cast("sd::DataType") int dtype, + @Cast("Nd4jLong*const") LongBuffer buffer, + byte order); +@Namespace("shape") public static native void shapeOldScalar(@Cast("sd::DataType") int dtype, + @Cast("Nd4jLong*const") long[] buffer, + byte order); + +// deduce order and element-wise stride +// if array is scalar or unit length vector then ews = 1 and order is preserved +// if array is common vector then ews = stride of non-unity dimension and order +// is preserved if strides are normal/contiguous then ews = 1 and corresponding +// order is set, otherwise ews = 0 and order is preserved +@Namespace("shape") public static native void checkStridesEwsAndOrder( + @Cast("Nd4jLong*") LongPointer shapeInfo, byte proposedOrder, int numOfNonUnitDims, + @Cast("const Nd4jLong*") LongPointer shapeNoUnities, @Cast("const Nd4jLong*") LongPointer stridesNoUnities); +@Namespace("shape") public static native void checkStridesEwsAndOrder( + @Cast("Nd4jLong*") LongBuffer shapeInfo, byte proposedOrder, int numOfNonUnitDims, + @Cast("const Nd4jLong*") LongBuffer shapeNoUnities, @Cast("const Nd4jLong*") LongBuffer stridesNoUnities); +@Namespace("shape") public static native void checkStridesEwsAndOrder( + @Cast("Nd4jLong*") long[] shapeInfo, byte proposedOrder, int numOfNonUnitDims, + @Cast("const Nd4jLong*") long[] shapeNoUnities, @Cast("const Nd4jLong*") long[] stridesNoUnities); +@Namespace("shape") public static native void checkStridesEwsAndOrder(@Cast("Nd4jLong*") LongPointer shapeInfo); +@Namespace("shape") public static native void checkStridesEwsAndOrder(@Cast("Nd4jLong*") LongBuffer shapeInfo); +@Namespace("shape") public static native void checkStridesEwsAndOrder(@Cast("Nd4jLong*") long[] shapeInfo); + +/** + * processes whole set of sub-arrays + * evaluates shapeInfo of sub-arrays (all sub-arrays have the same shapeInfo) + * and their buffer offsets (each sub-array has its own unique offset from + * original this-buffer) arguments: wholeShapeInfo - original shapeInfo of whole + * array numOfSubArrs - number of sub-arrays, size of subArrOffsets is equal to + * numOfSubArrs dimsSize - size of dimsToExclude, if dimsSize = array rank or + * dimsSize = 0 it means sub-array is whole array, copy of wholeShapeInfo and + * one zero offset will be returned dimsToExclude - MUST BE SORTED, dimensions + * to evaluate sub-array along, i.e. when shape is [2,3,4,5] and + * dimsToExclude={0,2}, then there will be 8 sub-arrays with shape [3,5] + * subArrShapeInfo - output argument, contains shapeInfo (same for all + * sub-arrays) subArrOffsets - output argument, contains successive + * sub-arrays offsets from original this-buffer keepUnitiesInShape - if false + * then eliminate unities from sub-array shapeInfo, for example {1,a,1,b} -> + * {a,b} + */ +@Namespace("shape") public static native void calcSubArrsShapeInfoAndOffsets( + @Cast("const Nd4jLong*") LongPointer wholeShapeInfo, @Cast("const Nd4jLong") long numOfSubArrs, + int dimsSize, @Const IntPointer dimsToExclude, @Cast("Nd4jLong*") LongPointer subArrShapeInfo, + @Cast("Nd4jLong*") LongPointer subArrOffsets, @Cast("bool") boolean keepUnitiesInShape/*=false*/); +@Namespace("shape") public static native void calcSubArrsShapeInfoAndOffsets( + @Cast("const Nd4jLong*") LongPointer wholeShapeInfo, @Cast("const Nd4jLong") long numOfSubArrs, + int dimsSize, @Const IntPointer dimsToExclude, @Cast("Nd4jLong*") LongPointer subArrShapeInfo, + @Cast("Nd4jLong*") LongPointer subArrOffsets); +@Namespace("shape") public static native void calcSubArrsShapeInfoAndOffsets( + @Cast("const Nd4jLong*") LongBuffer wholeShapeInfo, @Cast("const Nd4jLong") long numOfSubArrs, + int dimsSize, @Const IntBuffer dimsToExclude, @Cast("Nd4jLong*") LongBuffer subArrShapeInfo, + @Cast("Nd4jLong*") LongBuffer subArrOffsets, @Cast("bool") boolean keepUnitiesInShape/*=false*/); +@Namespace("shape") public static native void calcSubArrsShapeInfoAndOffsets( + @Cast("const Nd4jLong*") LongBuffer wholeShapeInfo, @Cast("const Nd4jLong") long numOfSubArrs, + int dimsSize, @Const IntBuffer dimsToExclude, @Cast("Nd4jLong*") LongBuffer subArrShapeInfo, + @Cast("Nd4jLong*") LongBuffer subArrOffsets); +@Namespace("shape") public static native void calcSubArrsShapeInfoAndOffsets( + @Cast("const Nd4jLong*") long[] wholeShapeInfo, @Cast("const Nd4jLong") long numOfSubArrs, + int dimsSize, @Const int[] dimsToExclude, @Cast("Nd4jLong*") long[] subArrShapeInfo, + @Cast("Nd4jLong*") long[] subArrOffsets, @Cast("bool") boolean keepUnitiesInShape/*=false*/); +@Namespace("shape") public static native void calcSubArrsShapeInfoAndOffsets( + @Cast("const Nd4jLong*") long[] wholeShapeInfo, @Cast("const Nd4jLong") long numOfSubArrs, + int dimsSize, @Const int[] dimsToExclude, @Cast("Nd4jLong*") long[] subArrShapeInfo, + @Cast("Nd4jLong*") long[] subArrOffsets); + +/** + * processes only one sub-array, evaluates shapeInfo of sub-array and its buffer + * offset from original array arguments: idx - input argument, intervals of + * indexes which define the sub-array to point on, when isStrided = false then + * idx has form {dim0Start,dim0End, dim1Start,dim1End, ....} and length (2 * + * maxRank) when isStrided = true then idx has form + * {dim0Start,dim0End,dim0Stride, dim1Start,dim1End,dim1Stride, ....} and + * length (3 * maxRank) when (dimStart == dimEnd) then whole range will be used + * for current dimension maxShapeInfo - input argument, shapeInfo of original + * array minShapeInfo - output argument, shapeInfo of sub-array to be deduced + * minOffset - output argument, offset of sub-array buffer offsets from original + * buffer keepUnitiesInShape - input argument, if false then eliminate unities + * from sub-array shapeInfo, for example {1,a,1,b} -> {a,b} isStrided - input + * argument, if true then idx has length (3 * this->rankOf()) and contains + * additional stride numbers which correspond to stride between dimStart and + * dimEnd, numOfUntiesInMinShape - input argument, number of occurrences in idx + * when (dimEnd - dimStart) = 1 + */ +@Namespace("shape") public static native void calcSubArrShapeInfoAndOffset( + @Cast("const Nd4jLong*") LongPointer idx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("Nd4jLong*") LongPointer minShapeInfo, + @Cast("Nd4jLong*") @ByRef LongPointer minOffset, @Cast("const bool") boolean keepUnitiesInShape/*=false*/, + @Cast("const bool") boolean isStrided/*=false*/, int numOfUntiesInMinShape/*=0*/); +@Namespace("shape") public static native void calcSubArrShapeInfoAndOffset( + @Cast("const Nd4jLong*") LongPointer idx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("Nd4jLong*") LongPointer minShapeInfo, + @Cast("Nd4jLong*") @ByRef LongPointer minOffset); +@Namespace("shape") public static native void calcSubArrShapeInfoAndOffset( + @Cast("const Nd4jLong*") LongBuffer idx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("Nd4jLong*") LongBuffer minShapeInfo, + @Cast("Nd4jLong*") @ByRef LongBuffer minOffset, @Cast("const bool") boolean keepUnitiesInShape/*=false*/, + @Cast("const bool") boolean isStrided/*=false*/, int numOfUntiesInMinShape/*=0*/); +@Namespace("shape") public static native void calcSubArrShapeInfoAndOffset( + @Cast("const Nd4jLong*") LongBuffer idx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("Nd4jLong*") LongBuffer minShapeInfo, + @Cast("Nd4jLong*") @ByRef LongBuffer minOffset); +@Namespace("shape") public static native void calcSubArrShapeInfoAndOffset( + @Cast("const Nd4jLong*") long[] idx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("Nd4jLong*") long[] minShapeInfo, + @Cast("Nd4jLong*") @ByRef long[] minOffset, @Cast("const bool") boolean keepUnitiesInShape/*=false*/, + @Cast("const bool") boolean isStrided/*=false*/, int numOfUntiesInMinShape/*=0*/); +@Namespace("shape") public static native void calcSubArrShapeInfoAndOffset( + @Cast("const Nd4jLong*") long[] idx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("Nd4jLong*") long[] minShapeInfo, + @Cast("Nd4jLong*") @ByRef long[] minOffset); + +/** + * for example inShapeInfo is {3, 2,1,4, 4,4,1, 16384,1,99} + * then output shapeNoUnities will contain {2,4, 4,1} - that is only shape and + * strides, no rank/type/ews/order stridesNoUnities will point on strides in + * shapeNoUnities that is on {4,1} returns number of non-unity dimensions in + * inShapeInfo if there is no unities in inShapeInfo, then no copy procedure + * will be performed and shapeNoUnities/stridesNoUnities will point on + * corresponding places in inShapeInfo + */ +@Namespace("shape") public static native int excludeUnitiesFromShapeInfo(@Cast("const Nd4jLong*") LongPointer inShapeInfo, + @Cast("Nd4jLong*&") @ByPtrRef LongPointer shapeNoUnities, + @Cast("Nd4jLong*&") @ByPtrRef LongPointer stridesNoUnities); +@Namespace("shape") public static native int excludeUnitiesFromShapeInfo(@Cast("const Nd4jLong*") LongBuffer inShapeInfo, + @Cast("Nd4jLong*&") @ByPtrRef LongBuffer shapeNoUnities, + @Cast("Nd4jLong*&") @ByPtrRef LongBuffer stridesNoUnities); +@Namespace("shape") public static native int excludeUnitiesFromShapeInfo(@Cast("const Nd4jLong*") long[] inShapeInfo, + @Cast("Nd4jLong*&") @ByPtrRef long[] shapeNoUnities, + @Cast("Nd4jLong*&") @ByPtrRef long[] stridesNoUnities); + +/** + * for example inShapeInfo is {3, 2,1,3,1,4, 12,12,4,4,1, 16384,1,99}, + * dimsToExclude = {1,3}, dimsSize = 2 then outShapeInfo will contain {3, 2,3,4, + * 12,4,1, 16384,1,99} + */ +@Namespace("shape") public static native void excludeUnitiesFromShapeInfo(@Cast("const Nd4jLong*") LongPointer inShapeInfo, + int dimsSize, + @Const IntPointer dimsToExclude, + @Cast("Nd4jLong*") LongPointer outShapeInfo); +@Namespace("shape") public static native void excludeUnitiesFromShapeInfo(@Cast("const Nd4jLong*") LongBuffer inShapeInfo, + int dimsSize, + @Const IntBuffer dimsToExclude, + @Cast("Nd4jLong*") LongBuffer outShapeInfo); +@Namespace("shape") public static native void excludeUnitiesFromShapeInfo(@Cast("const Nd4jLong*") long[] inShapeInfo, + int dimsSize, + @Const int[] dimsToExclude, + @Cast("Nd4jLong*") long[] outShapeInfo); + +/** + * get stride over contiguous axis (contiguous axis must have stride = 1) + * for example when inShapeInfo is {4, 2,5,4,3, 60,1,5,20, 16384,0,99} then + * output is 5 (that is smallest stride in inShapeInfo except those equal to 1) + */ +// INLINEDEF _CUDA_HD Nd4jLong strideOverContigAxis(const int axis, const +// Nd4jLong* inShapeInfo); +// END HEADERS +// BEGIN IMPLEMENTATIONS + +// #ifdef __CUDACC__ +// #endif + +/** + * Length of a tad given + * the shape information + */ /** * Tad element wise stride: @@ -8272,9 +9245,12 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint * Again: this may not preserve ordering of the tad * but maybe used for reductions. */ - @Namespace("shape") public static native int tadElementWiseStride(@Cast("Nd4jLong*") LongPointer shapeInfo, IntPointer dimension,int dimensionLength); - @Namespace("shape") public static native int tadElementWiseStride(@Cast("Nd4jLong*") LongBuffer shapeInfo, IntBuffer dimension,int dimensionLength); - @Namespace("shape") public static native int tadElementWiseStride(@Cast("Nd4jLong*") long[] shapeInfo, int[] dimension,int dimensionLength); +@Namespace("shape") public static native int tadElementWiseStride(@Cast("Nd4jLong*") LongPointer shapeInfo, IntPointer dimension, + int dimensionLength); +@Namespace("shape") public static native int tadElementWiseStride(@Cast("Nd4jLong*") LongBuffer shapeInfo, IntBuffer dimension, + int dimensionLength); +@Namespace("shape") public static native int tadElementWiseStride(@Cast("Nd4jLong*") long[] shapeInfo, int[] dimension, + int dimensionLength); /** * Computes the standard packed array strides for a given shape. @@ -8312,9 +9288,8 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint ////////////////////////////////////////////////////////////////////// - -// check whether input dimensions are permuted, not permuted dimensions order have to be 0,....,rank-1 - +// check whether input dimensions are permuted, not permuted dimensions order +// have to be 0,....,rank-1 /** * @param toCopy the shape to copy @@ -8326,16 +9301,16 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint * for the given rank and shape. */ - /** - * This is special method, it returns ONLY 2D shapebuffer. - * - * This method is used only for SoftMax - */ +/** + * This is special method, it returns ONLY 2D shapebuffer. + * + * This method is used only for SoftMax + */ /** -* Get the shape info buffer -* for the given rank and shape. -*/ + * Get the shape info buffer + * for the given rank and shape. + */ ////////////////////////////////////////////////////////////////////// @@ -8345,9 +9320,9 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint ////////////////////////////////////////////////////////////////////// - // ////////////////////////////////////////////////////////////////////// -// INLINEDEF _CUDA_HD Nd4jLong getIndexOffset(Nd4jLong index, const Nd4jLong *shapeInfo, Nd4jLong arrLen) { +// INLINEDEF _CUDA_HD Nd4jLong getIndexOffset(Nd4jLong index, const Nd4jLong +// *shapeInfo, Nd4jLong arrLen) { // const Nd4jLong ews = shapeInfo[shapeInfo[0] + shapeInfo[0] + 2]; @@ -8369,7 +9344,8 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // return offset; // } -// INLINEDEF _CUDA_HD uint getIndexOffset(uint index, const uint *shapeInfo, uint arrLen) { +// INLINEDEF _CUDA_HD uint getIndexOffset(uint index, const uint *shapeInfo, +// uint arrLen) { // const uint rank = shapeInfo[0]; // const uint ews = shapeInfo[rank + rank + 2]; @@ -8396,7 +9372,6 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint ////////////////////////////////////////////////////////////////////// - ////////////////////////////////////////////////////////////////////// /** @@ -8424,10 +9399,6 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint * @return */ - - - - /** * Ensure that every value in the re arrange * array is unique @@ -8455,11 +9426,11 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint ////////////////////////////////////////////////////////////////////// /** -* Returns whether the -* given shape is a vector or not -* @param shape the shape of the array -* @param rank the rank of the shape -*/ + * Returns whether the + * given shape is a vector or not + * @param shape the shape of the array + * @param rank the rank of the shape + */ /** * Returns the shape portion of an information @@ -8473,16 +9444,16 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint */ /** -* Return a copy of a buffer. -* This buffer allocates memory -* that must be freed elsewhere. -*/ + * Return a copy of a buffer. + * This buffer allocates memory + * that must be freed elsewhere. + */ /** -* Return a copy of a buffer. -* This buffer allocates memory -* that must be freed elsewhere. -*/ + * Return a copy of a buffer. + * This buffer allocates memory + * that must be freed elsewhere. + */ /** * Permute the given strides @@ -8493,15 +9464,14 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint * and all must be filled in) * @return the rearranged array */ - /* - INLINEDEF _CUDA_HD Nd4jLong *permutedStrides(Nd4jLong *toPermute, int shapeRank, int *rearrange) { - Nd4jLong *strideCopy = copyOf(shapeRank, toPermute); - checkArrangeArray(rearrange, shapeRank, shapeRank); - Nd4jLong *newStride = doPermuteSwap(shapeRank, strideCopy, rearrange); - delete[] strideCopy; - return newStride; - } - */ +/* + INLINEDEF _CUDA_HD Nd4jLong *permutedStrides(Nd4jLong *toPermute, int + shapeRank, int *rearrange) { Nd4jLong *strideCopy = copyOf(shapeRank, + toPermute); checkArrangeArray(rearrange, shapeRank, shapeRank); Nd4jLong + *newStride = doPermuteSwap(shapeRank, strideCopy, rearrange); delete[] + strideCopy; return newStride; + } + */ /** * Return the slice (shape + 1 in pointer arithmetic) @@ -8539,7 +9509,6 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint * buffer */ - /** * Compute the length of the given shape */ @@ -8549,7 +9518,6 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint * portion of an information buffer */ - /** * Returns the ordering * for this shape information buffer @@ -8565,9 +9533,9 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint */ /** -* Returns the element wise stride for this information -* buffer relative to a dimension and reduction index -*/ + * Returns the element wise stride for this information + * buffer relative to a dimension and reduction index + */ /** * Returns whether @@ -8595,7 +9563,7 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint * item */ - /** +/** * Return a copy of this array with the * given index omitted * @@ -8624,9 +9592,9 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint * for the shape to be returned as * @return the new shape */ - @Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer ensureVectorShape(@Cast("Nd4jLong*") LongPointer shape, int dimension); - @Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer ensureVectorShape(@Cast("Nd4jLong*") LongBuffer shape, int dimension); - @Namespace("shape") public static native @Cast("Nd4jLong*") long[] ensureVectorShape(@Cast("Nd4jLong*") long[] shape, int dimension); +@Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer ensureVectorShape(@Cast("Nd4jLong*") LongPointer shape, int dimension); +@Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer ensureVectorShape(@Cast("Nd4jLong*") LongBuffer shape, int dimension); +@Namespace("shape") public static native @Cast("Nd4jLong*") long[] ensureVectorShape(@Cast("Nd4jLong*") long[] shape, int dimension); /** * Returns a shape @@ -8637,23 +9605,24 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint * @return the new shape */ - /** - * This method does STRICT comparison for two shape buffers - * - * @param shape - * @return - */ +/** + * This method does STRICT comparison for two shape buffers + * + * @param shape + * @return + */ ////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////// - /** - * This method does SOFT comparison for two shape buffers, we compare only rank & shapes - * - * @param shape - * @return - */ +/** + * This method does SOFT comparison for two shape buffers, we compare only rank + * & shapes + * + * @param shape + * @return + */ /** * Generate an int buffer @@ -8724,7 +9693,7 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint * @return */ - /** +/** * calculates the offset for a tensor * @param index * @param arr @@ -8732,14 +9701,9 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint * @return */ - // #ifdef __CUDACC__ // #endif - - - - /** * Computes the number * of tensors along @@ -8752,20 +9716,17 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint * a given dimension */ - - - /** -* Get an offset for retrieval -* from a data buffer -* based on the given -* shape stride and given indices -* @param baseOffset the offset to start from -* @param shape the shape of the array -* @param stride the stride of the array -* @param indices the indices to iterate over -* @return the double at the specified index -*/ + * Get an offset for retrieval + * from a data buffer + * based on the given + * shape stride and given indices + * @param baseOffset the offset to start from + * @param shape the shape of the array + * @param stride the stride of the array + * @param indices the indices to iterate over + * @return the double at the specified index + */ ////////////////////////////////////////////////////////////////////////// @@ -8773,7 +9734,6 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint ////////////////////////////////////////////////////////////////////////// - /** * Returns the tensor along dimension * for the given block index @@ -8807,7 +9767,8 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint * reduction index. * @param tadIndexForOriginal the original tad index for the * split up problem (eg: split is dimension 3 mapping to a 2,3 problem) - * @param tadsForReduced the number of tads for the shrunk down problem (eg: 2,3) + * @param tadsForReduced the number of tads for the shrunk down problem (eg: + * 2,3) * @param tadsForOriginal the number of tads for the smaller problem (eg: 3) */ @@ -8833,24 +9794,21 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint * @param originalTadNum the tad number for the reduced version of the problem */ - /** * Returns the prod of the data * up to the given length */ - @Namespace("shape") public static native int rearMostLeftOverItem(@Cast("Nd4jLong*") LongPointer data, @Cast("Nd4jLong*") LongPointer dimension,int dimensionLength); - @Namespace("shape") public static native int rearMostLeftOverItem(@Cast("Nd4jLong*") LongBuffer data, @Cast("Nd4jLong*") LongBuffer dimension,int dimensionLength); - @Namespace("shape") public static native int rearMostLeftOverItem(@Cast("Nd4jLong*") long[] data, @Cast("Nd4jLong*") long[] dimension,int dimensionLength); +@Namespace("shape") public static native int rearMostLeftOverItem(@Cast("Nd4jLong*") LongPointer data, @Cast("Nd4jLong*") LongPointer dimension, + int dimensionLength); +@Namespace("shape") public static native int rearMostLeftOverItem(@Cast("Nd4jLong*") LongBuffer data, @Cast("Nd4jLong*") LongBuffer dimension, + int dimensionLength); +@Namespace("shape") public static native int rearMostLeftOverItem(@Cast("Nd4jLong*") long[] data, @Cast("Nd4jLong*") long[] dimension, + int dimensionLength); // #ifdef __CUDACC__ // #endif - - - - - // INLINEDEF _CUDA_HD Nd4jLong *shapeBufferOfNpyBuffer(char *buffer) { // unsigned Nd4jLong *shape; // unsigned int ndims, wordSize; @@ -8864,17 +9822,17 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint ////////////////////////////////////////////////////////////////////////// // copy-past from java hasDefaultStridesForShape function -// INLINEDEF _CUDA_H bool reshapeC(const int oldRank, Nd4jLong* oldShape, const int newRank, Nd4jLong* newShapeOf, bool isFOrder, Nd4jLong* target) { +// INLINEDEF _CUDA_H bool reshapeC(const int oldRank, Nd4jLong* oldShape, const +// int newRank, Nd4jLong* newShapeOf, bool isFOrder, Nd4jLong* target) { // int oldnd; // Nd4jLong* olddims = shape::copyOf(oldRank, shape::shapeOf(oldShape)); -// Nd4jLong* oldstrides = shape::copyOf(oldRank, shape::stride(oldShape)); -// int np, op, last_stride; -// int oi, oj, ok, ni, nj, nk; -// Nd4jLong* newStrides = new Nd4jLong[newRank]; -// oldnd = 0; +// Nd4jLong* oldstrides = shape::copyOf(oldRank, +// shape::stride(oldShape)); int np, op, last_stride; int oi, oj, ok, +// ni, nj, nk; Nd4jLong* newStrides = new Nd4jLong[newRank]; oldnd = 0; // /* -// * Remove axes with dimension 1 from the old array. They have no effect +// * Remove axes with dimension 1 from the old array. They have no +// effect // * but would need special cases since their strides do not matter. // */ // for (oi = 0; oi < oldRank; oi++) { @@ -8911,11 +9869,8 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // return false; // } -// /* oi to oj and ni to nj give the axis ranges currently worked with */ -// oi = 0; -// oj = 1; -// ni = 0; -// nj = 1; +// /* oi to oj and ni to nj give the axis ranges currently worked with +// */ oi = 0; oj = 1; ni = 0; nj = 1; // while (ni < newRank && oi < oldnd) { // np = newShapeOf[ni]; @@ -8943,7 +9898,8 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // } // } else { // /* C order */ -// if (oldstrides[ok] != olddims[ok + 1] * oldstrides[ok + 1]) { +// if (oldstrides[ok] != olddims[ok + 1] * oldstrides[ok + +// 1]) { // /* not contiguous enough */ // delete[] olddims; // delete[] oldstrides; @@ -8994,7 +9950,8 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // target[shape::shapeInfoLength(newRank) - 3] = 0; // target[shape::shapeInfoLength(newRank) - 2] = 0; // target[shape::shapeInfoLength(newRank) - 1] = isFOrder ? 102 : 99; -// sd::ArrayOptions::setDataType(target, sd::ArrayOptions::dataType(oldShape)); +// sd::ArrayOptions::setDataType(target, +// sd::ArrayOptions::dataType(oldShape)); // delete[] olddims; // delete[] oldstrides; @@ -9004,18 +9961,26 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // } ////////////////////////////////////////////////////////////////////// -// INLINEDEF _CUDA_H bool reshapeC(const int oldRank, const Nd4jLong* oldShapeInfo, const int newRank, const Nd4jLong* newShape, Nd4jLong* newShapeInfo) { +// INLINEDEF _CUDA_H bool reshapeC(const int oldRank, const Nd4jLong* +// oldShapeInfo, const int newRank, const Nd4jLong* newShape, Nd4jLong* +// newShapeInfo) { -// // PLEASE NOTE !: reshaping not-permuted (ews=1) array in f order (except insertion/elimination of unities) will definitely cause allocation of new buffer for array elements -// // also this function takes into account identical shapes automatically, namely in that case oldShapeInfo is completely copied to newShapeInfo +// // PLEASE NOTE !: reshaping not-permuted (ews=1) array in f order +// (except insertion/elimination of unities) will definitely cause +// allocation of new buffer for array elements +// // also this function takes into account identical shapes +// automatically, namely in that case oldShapeInfo is completely copied +// to newShapeInfo // newShapeInfo[0] = newRank; // memcpy(newShapeInfo + 1, newShape, newRank * sizeof(Nd4jLong)); // Nd4jLong* newStrides = shape::stride(newShapeInfo); -// const Nd4jLong* oldShape = shape::shapeOf(const_cast(oldShapeInfo)); -// const Nd4jLong* oldStrides = shape::stride(const_cast(oldShapeInfo)); -// Nd4jLong oldStart(0), oldStop(1), newStart(0), newStop(1), newDim, oldDim; +// const Nd4jLong* oldShape = +// shape::shapeOf(const_cast(oldShapeInfo)); const Nd4jLong* +// oldStrides = shape::stride(const_cast(oldShapeInfo)); +// Nd4jLong oldStart(0), oldStop(1), newStart(0), newStop(1), newDim, +// oldDim; // while (newStart < newRank && oldStart < oldRank) { @@ -9026,13 +9991,16 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // if (newDim < oldDim) newDim *= newShape[newStop++]; // else oldDim *= oldShape[oldStop++]; -// // ------ Check whether the original axes can be combined ------ // -// for (int step = 1, i = oldStart; i < oldStop - 1; ++i) { -// if(oldShape[i] == 1) // skip unity-dimension and its stride +// // ------ Check whether the original axes can be combined ------ +// // for (int step = 1, i = oldStart; i < oldStop - 1; ++i) { +// if(oldShape[i] == 1) // skip unity-dimension +// and its stride // continue; // while((i + step) < oldRank && oldShape[i + step] == 1) -// ++step; // skip following unity-dimensions and its strides if such are present -// if((i + step) < oldRank && oldStrides[i] != oldShape[i + step] * oldStrides[i + step]) +// ++step; // skip following +// unity-dimensions and its strides if such are present +// if((i + step) < oldRank && oldStrides[i] != oldShape[i + +// step] * oldStrides[i + step]) // return false; // not contiguous enough // } @@ -9044,13 +10012,15 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // oldStart = oldStop++; // } -// // rest of strides should be unities (if there is remainder in strides space, that is newStart < newRank) -// for (int i = newStart; i < newRank; ++i) +// // rest of strides should be unities (if there is remainder in +// strides space, that is newStart < newRank) for (int i = newStart; i < +// newRank; ++i) // newStrides[i] = 1; -// newShapeInfo[2 * newRank + 3] = shape::order(oldShapeInfo); // order -// newShapeInfo[2 * newRank + 2] = shape::elementWiseStride(oldShapeInfo); // ews -// newShapeInfo[2 * newRank + 1] = shape::type(oldShapeInfo); // type +// newShapeInfo[2 * newRank + 3] = shape::order(oldShapeInfo); // order +// newShapeInfo[2 * newRank + 2] = +// shape::elementWiseStride(oldShapeInfo); // ews newShapeInfo[2 * +// newRank + 1] = shape::type(oldShapeInfo); // type // return true; // } @@ -9059,20 +10029,22 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint ////////////////////////////////////////////////////////////////////// - // this function checks the consistence of dimensions with array rank (negative dimensions, too large dimensions, too big number of dimensions) - // also it sorts input array of dimensions, this operation is also necessary for creating TAD object - +// this function checks the consistence of dimensions with array rank (negative +// dimensions, too large dimensions, too big number of dimensions) also it sorts +// input array of dimensions, this operation is also necessary for creating TAD +// object // max array is outer for min array, min array is sub-array of max array -// function calculates the coordinates of min array (and saves them into minIdxs) given coordinates of max array (already stored in maxIdxs) +// function calculates the coordinates of min array (and saves them into +// minIdxs) given coordinates of max array (already stored in maxIdxs) - ////////////////////////////////////////////////////////////////////// +////////////////////////////////////////////////////////////////////// - ////////////////////////////////////////////////////////////////////// +////////////////////////////////////////////////////////////////////// - ////////////////////////////////////////////////////////////////////// +////////////////////////////////////////////////////////////////////// - ////////////////////////////////////////////////////////////////////// +////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////// @@ -9103,7 +10075,9 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint ////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////// -// INLINEDEF _CUDA_HD void calcOffsets(const Nd4jLong *xShapeInfo, Nd4jLong*& xOffsets, const Nd4jLong *yShapeInfo, Nd4jLong*& yOffsets, const Nd4jLong* zShapeInfo, Nd4jLong*& zOffsets, const char order) { +// INLINEDEF _CUDA_HD void calcOffsets(const Nd4jLong *xShapeInfo, Nd4jLong*& +// xOffsets, const Nd4jLong *yShapeInfo, Nd4jLong*& yOffsets, const Nd4jLong* +// zShapeInfo, Nd4jLong*& zOffsets, const char order) { // // we assume all array have same length // const Nd4jLong len = shape::length(xShapeInfo); @@ -9116,22 +10090,27 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // const char yOrder = shape::order(yShapeInfo); // const char zOrder = shape::order(zShapeInfo); -// const bool shapesSame = shape::shapeEquals(xShapeInfo, yShapeInfo, zShapeInfo); +// const bool shapesSame = shape::shapeEquals(xShapeInfo, yShapeInfo, +// zShapeInfo); -// if (xEws == 1 && yEws == 1 && zEws == 1 && xOrder == yOrder && xOrder == zOrder && (xOrder == 'c' || shapesSame)) { +// if (xEws == 1 && yEws == 1 && zEws == 1 && xOrder == yOrder && xOrder == +// zOrder && (xOrder == 'c' || shapesSame)) { // xOffsets = yOffsets = zOffsets = nullptr; // } -// else if(xEws == 1 && yEws == 1 && xOrder == yOrder && (xOrder == 'c' || shape::shapeEquals(xShapeInfo, yShapeInfo))) { +// else if(xEws == 1 && yEws == 1 && xOrder == yOrder && (xOrder == 'c' || +// shape::shapeEquals(xShapeInfo, yShapeInfo))) { // xOffsets = yOffsets = nullptr; // zOffsets = new Nd4jLong[len]; // shape::calcOffsets(zShapeInfo, zOffsets, xOrder); // } -// else if(xEws == 1 && zEws == 1 && xOrder == zOrder && (xOrder == 'c' || shape::shapeEquals(xShapeInfo, zShapeInfo))) { +// else if(xEws == 1 && zEws == 1 && xOrder == zOrder && (xOrder == 'c' || +// shape::shapeEquals(xShapeInfo, zShapeInfo))) { // xOffsets = zOffsets = nullptr; // yOffsets = new Nd4jLong[len]; // shape::calcOffsets(yShapeInfo, yOffsets, xOrder); // } -// else if(yEws == 1 && zEws == 1 && yOrder == zOrder && (yOrder == 'c' || shape::shapeEquals(yShapeInfo, zShapeInfo))) { +// else if(yEws == 1 && zEws == 1 && yOrder == zOrder && (yOrder == 'c' || +// shape::shapeEquals(yShapeInfo, zShapeInfo))) { // yOffsets = zOffsets = nullptr; // xOffsets = new Nd4jLong[len]; // shape::calcOffsets(xShapeInfo, xOffsets, yOrder); @@ -9184,7 +10163,8 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // } // } // } -// else if(shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo, zShapeInfo)) { +// else if(shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo, +// zShapeInfo)) { // xOffsets = new Nd4jLong[len]; // shape::calcOffsets(xShapeInfo, xOffsets); // yOffsets = zOffsets = xOffsets; @@ -9244,7 +10224,9 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // } ////////////////////////////////////////////////////////////////////// -// INLINEDEF _CUDA_HD void calcOffsets(const Nd4jLong *xShapeInfo, Nd4jLong*& xOffsets, const Nd4jLong *yShapeInfo, Nd4jLong*& yOffsets, const char order) { +// INLINEDEF _CUDA_HD void calcOffsets(const Nd4jLong *xShapeInfo, Nd4jLong*& +// xOffsets, const Nd4jLong *yShapeInfo, Nd4jLong*& yOffsets, const char order) +// { // // we assume all array have same length // const Nd4jLong len = shape::length(xShapeInfo); @@ -9257,7 +10239,8 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // const bool shapesSame = shape::shapeEquals(xShapeInfo, yShapeInfo); -// if (xEws == 1 && yEws == 1 && xOrder == yOrder && (xOrder == 'c' || shapesSame)) { +// if (xEws == 1 && yEws == 1 && xOrder == yOrder && (xOrder == 'c' || +// shapesSame)) { // xOffsets = yOffsets = nullptr; // } // else if(xEws == 1) { @@ -9296,9 +10279,9 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint ////////////////////////////////////////////////////////////////////// - ////////////////////////////////////////////////////////////////////// -// INLINEDEF _CUDA_HD Nd4jLong strideOverContigAxis(const int axis, const Nd4jLong* inShapeInfo) { +// INLINEDEF _CUDA_HD Nd4jLong strideOverContigAxis(const int axis, const +// Nd4jLong* inShapeInfo) { // Nd4jLong result = 9223372036854775807LL; @@ -9316,9 +10299,7 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // return result == 9223372036854775807LL ? 1 : result; // } - - - + // namespace shape // #endif /* SHAPE_H_ */ @@ -9348,7 +10329,6 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #ifndef LIBND4J_OPARGSHOLDER_H // #define LIBND4J_OPARGSHOLDER_H - // #include // #include @@ -9363,68 +10343,99 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint return (OpArgsHolder)super.position(position); } + // default constructor + public OpArgsHolder() { super((Pointer)null); allocate(); } + private native void allocate(); - // default constructor - public OpArgsHolder() { super((Pointer)null); allocate(); } - private native void allocate(); - - // copy constructor - public OpArgsHolder(@Const @ByRef OpArgsHolder other) { super((Pointer)null); allocate(other); } - private native void allocate(@Const @ByRef OpArgsHolder other); - - // constructor - public OpArgsHolder(@Const @ByRef NDArrayVector inArrs, @StdVector DoublePointer tArgs/*=std::vector()*/, @Cast("Nd4jLong*") @StdVector LongPointer iArgs/*=std::vector()*/, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector()*/) { super((Pointer)null); allocate(inArrs, tArgs, iArgs, bArgs); } - private native void allocate(@Const @ByRef NDArrayVector inArrs, @StdVector DoublePointer tArgs/*=std::vector()*/, @Cast("Nd4jLong*") @StdVector LongPointer iArgs/*=std::vector()*/, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector()*/); - public OpArgsHolder(@Const @ByRef NDArrayVector inArrs) { super((Pointer)null); allocate(inArrs); } - private native void allocate(@Const @ByRef NDArrayVector inArrs); - public OpArgsHolder(@Const @ByRef NDArrayVector inArrs, @StdVector DoubleBuffer tArgs/*=std::vector()*/, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs/*=std::vector()*/, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector()*/) { super((Pointer)null); allocate(inArrs, tArgs, iArgs, bArgs); } - private native void allocate(@Const @ByRef NDArrayVector inArrs, @StdVector DoubleBuffer tArgs/*=std::vector()*/, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs/*=std::vector()*/, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector()*/); - public OpArgsHolder(@Const @ByRef NDArrayVector inArrs, @StdVector double[] tArgs/*=std::vector()*/, @Cast("Nd4jLong*") @StdVector long[] iArgs/*=std::vector()*/, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector()*/) { super((Pointer)null); allocate(inArrs, tArgs, iArgs, bArgs); } - private native void allocate(@Const @ByRef NDArrayVector inArrs, @StdVector double[] tArgs/*=std::vector()*/, @Cast("Nd4jLong*") @StdVector long[] iArgs/*=std::vector()*/, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector()*/); - public OpArgsHolder(@Const @ByRef NDArrayVector inArrs, @StdVector DoublePointer tArgs/*=std::vector()*/, @Cast("Nd4jLong*") @StdVector LongPointer iArgs/*=std::vector()*/, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector()*/) { super((Pointer)null); allocate(inArrs, tArgs, iArgs, bArgs); } - private native void allocate(@Const @ByRef NDArrayVector inArrs, @StdVector DoublePointer tArgs/*=std::vector()*/, @Cast("Nd4jLong*") @StdVector LongPointer iArgs/*=std::vector()*/, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector()*/); - public OpArgsHolder(@Const @ByRef NDArrayVector inArrs, @StdVector DoubleBuffer tArgs/*=std::vector()*/, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs/*=std::vector()*/, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector()*/) { super((Pointer)null); allocate(inArrs, tArgs, iArgs, bArgs); } - private native void allocate(@Const @ByRef NDArrayVector inArrs, @StdVector DoubleBuffer tArgs/*=std::vector()*/, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs/*=std::vector()*/, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector()*/); - public OpArgsHolder(@Const @ByRef NDArrayVector inArrs, @StdVector double[] tArgs/*=std::vector()*/, @Cast("Nd4jLong*") @StdVector long[] iArgs/*=std::vector()*/, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector()*/) { super((Pointer)null); allocate(inArrs, tArgs, iArgs, bArgs); } - private native void allocate(@Const @ByRef NDArrayVector inArrs, @StdVector double[] tArgs/*=std::vector()*/, @Cast("Nd4jLong*") @StdVector long[] iArgs/*=std::vector()*/, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector()*/); - - // move constructor - - // assignment operator - public native @ByRef @Name("operator =") OpArgsHolder put(@Const @ByRef OpArgsHolder other); - - // move assignment operator - - public native @Const @ByRef NDArrayVector getInArrs(); - - public native @StdVector DoublePointer getTArgs(); - - public native @Cast("Nd4jLong*") @StdVector LongPointer getIArgs(); - - public native @Cast("bool*") @StdVector BooleanPointer getBArgs(); - - public native @Cast("bool*") @StdVector BooleanPointer getAllocInfo(); - - public native int getNumInArrs(); - - public native int getNumTArgs(); - - public native int getNumIArgs(); - - public native int getNumBArgs(); - - public native @ByVal OpArgsHolder createArgsHolderForBP(@Const @ByRef NDArrayVector inGradArrs, @Cast("const bool") boolean isInPlace/*=false*/); - public native @ByVal OpArgsHolder createArgsHolderForBP(@Const @ByRef NDArrayVector inGradArrs); - + // copy constructor + public OpArgsHolder(@Const @ByRef OpArgsHolder other) { super((Pointer)null); allocate(other); } + private native void allocate(@Const @ByRef OpArgsHolder other); + + // constructor + public OpArgsHolder(@Const @ByRef NDArrayVector inArrs, + @StdVector DoublePointer tArgs/*=std::vector()*/, + @Cast("Nd4jLong*") @StdVector LongPointer iArgs/*=std::vector()*/, + @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector()*/) { super((Pointer)null); allocate(inArrs, tArgs, iArgs, bArgs); } + private native void allocate(@Const @ByRef NDArrayVector inArrs, + @StdVector DoublePointer tArgs/*=std::vector()*/, + @Cast("Nd4jLong*") @StdVector LongPointer iArgs/*=std::vector()*/, + @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector()*/); + public OpArgsHolder(@Const @ByRef NDArrayVector inArrs) { super((Pointer)null); allocate(inArrs); } + private native void allocate(@Const @ByRef NDArrayVector inArrs); + public OpArgsHolder(@Const @ByRef NDArrayVector inArrs, + @StdVector DoubleBuffer tArgs/*=std::vector()*/, + @Cast("Nd4jLong*") @StdVector LongBuffer iArgs/*=std::vector()*/, + @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector()*/) { super((Pointer)null); allocate(inArrs, tArgs, iArgs, bArgs); } + private native void allocate(@Const @ByRef NDArrayVector inArrs, + @StdVector DoubleBuffer tArgs/*=std::vector()*/, + @Cast("Nd4jLong*") @StdVector LongBuffer iArgs/*=std::vector()*/, + @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector()*/); + public OpArgsHolder(@Const @ByRef NDArrayVector inArrs, + @StdVector double[] tArgs/*=std::vector()*/, + @Cast("Nd4jLong*") @StdVector long[] iArgs/*=std::vector()*/, + @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector()*/) { super((Pointer)null); allocate(inArrs, tArgs, iArgs, bArgs); } + private native void allocate(@Const @ByRef NDArrayVector inArrs, + @StdVector double[] tArgs/*=std::vector()*/, + @Cast("Nd4jLong*") @StdVector long[] iArgs/*=std::vector()*/, + @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector()*/); + public OpArgsHolder(@Const @ByRef NDArrayVector inArrs, + @StdVector DoublePointer tArgs/*=std::vector()*/, + @Cast("Nd4jLong*") @StdVector LongPointer iArgs/*=std::vector()*/, + @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector()*/) { super((Pointer)null); allocate(inArrs, tArgs, iArgs, bArgs); } + private native void allocate(@Const @ByRef NDArrayVector inArrs, + @StdVector DoublePointer tArgs/*=std::vector()*/, + @Cast("Nd4jLong*") @StdVector LongPointer iArgs/*=std::vector()*/, + @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector()*/); + public OpArgsHolder(@Const @ByRef NDArrayVector inArrs, + @StdVector DoubleBuffer tArgs/*=std::vector()*/, + @Cast("Nd4jLong*") @StdVector LongBuffer iArgs/*=std::vector()*/, + @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector()*/) { super((Pointer)null); allocate(inArrs, tArgs, iArgs, bArgs); } + private native void allocate(@Const @ByRef NDArrayVector inArrs, + @StdVector DoubleBuffer tArgs/*=std::vector()*/, + @Cast("Nd4jLong*") @StdVector LongBuffer iArgs/*=std::vector()*/, + @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector()*/); + public OpArgsHolder(@Const @ByRef NDArrayVector inArrs, + @StdVector double[] tArgs/*=std::vector()*/, + @Cast("Nd4jLong*") @StdVector long[] iArgs/*=std::vector()*/, + @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector()*/) { super((Pointer)null); allocate(inArrs, tArgs, iArgs, bArgs); } + private native void allocate(@Const @ByRef NDArrayVector inArrs, + @StdVector double[] tArgs/*=std::vector()*/, + @Cast("Nd4jLong*") @StdVector long[] iArgs/*=std::vector()*/, + @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector()*/); + + // move constructor + + // assignment operator + public native @ByRef @Name("operator =") OpArgsHolder put(@Const @ByRef OpArgsHolder other); + + // move assignment operator + + public native @Const @ByRef NDArrayVector getInArrs(); + + public native @StdVector DoublePointer getTArgs(); + + public native @Cast("Nd4jLong*") @StdVector LongPointer getIArgs(); + + public native @Cast("bool*") @StdVector BooleanPointer getBArgs(); + + public native @Cast("bool*") @StdVector BooleanPointer getAllocInfo(); + + public native int getNumInArrs(); + + public native int getNumTArgs(); + + public native int getNumIArgs(); + + public native int getNumBArgs(); + + public native @ByVal OpArgsHolder createArgsHolderForBP(@Const @ByRef NDArrayVector inGradArrs, + @Cast("const bool") boolean isInPlace/*=false*/); + public native @ByVal OpArgsHolder createArgsHolderForBP(@Const @ByRef NDArrayVector inGradArrs); } + // namespace sd - - - - - -// #endif //LIBND4J_OPARGSHOLDER_H +// #endif // LIBND4J_OPARGSHOLDER_H // Parsed from array/ShapeList.h @@ -9452,57 +10463,58 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #ifndef LIBND4J_SHAPELIST_H // #define LIBND4J_SHAPELIST_H -// #include // #include // #include - @Namespace("sd") @NoOffset public static class ShapeList extends Pointer { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public ShapeList(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public ShapeList(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public ShapeList position(long position) { - return (ShapeList)super.position(position); - } - - public ShapeList(@Cast("const Nd4jLong*") LongPointer shape/*=nullptr*/) { super((Pointer)null); allocate(shape); } - private native void allocate(@Cast("const Nd4jLong*") LongPointer shape/*=nullptr*/); - public ShapeList() { super((Pointer)null); allocate(); } - private native void allocate(); - public ShapeList(@Cast("const Nd4jLong*") LongBuffer shape/*=nullptr*/) { super((Pointer)null); allocate(shape); } - private native void allocate(@Cast("const Nd4jLong*") LongBuffer shape/*=nullptr*/); - public ShapeList(@Cast("const Nd4jLong*") long[] shape/*=nullptr*/) { super((Pointer)null); allocate(shape); } - private native void allocate(@Cast("const Nd4jLong*") long[] shape/*=nullptr*/); - public ShapeList(@Cast("const Nd4jLong**") @StdVector PointerPointer shapes, @Cast("bool") boolean isWorkspace) { super((Pointer)null); allocate(shapes, isWorkspace); } - private native void allocate(@Cast("const Nd4jLong**") @StdVector PointerPointer shapes, @Cast("bool") boolean isWorkspace); - public ShapeList(@Cast("const Nd4jLong**") @StdVector @ByPtrPtr LongPointer shapes, @Cast("bool") boolean isWorkspace) { super((Pointer)null); allocate(shapes, isWorkspace); } - private native void allocate(@Cast("const Nd4jLong**") @StdVector @ByPtrPtr LongPointer shapes, @Cast("bool") boolean isWorkspace); - public ShapeList(@Cast("const Nd4jLong**") @StdVector @ByPtrPtr LongBuffer shapes, @Cast("bool") boolean isWorkspace) { super((Pointer)null); allocate(shapes, isWorkspace); } - private native void allocate(@Cast("const Nd4jLong**") @StdVector @ByPtrPtr LongBuffer shapes, @Cast("bool") boolean isWorkspace); - public ShapeList(@Cast("const Nd4jLong**") @StdVector @ByPtrPtr long[] shapes, @Cast("bool") boolean isWorkspace) { super((Pointer)null); allocate(shapes, isWorkspace); } - private native void allocate(@Cast("const Nd4jLong**") @StdVector @ByPtrPtr long[] shapes, @Cast("bool") boolean isWorkspace); - public ShapeList(@Cast("const Nd4jLong**") @StdVector PointerPointer shapes) { super((Pointer)null); allocate(shapes); } - private native void allocate(@Cast("const Nd4jLong**") @StdVector PointerPointer shapes); - //ShapeList(bool autoRemovable); - - public native @Cast("const Nd4jLong**") @StdVector PointerPointer asVector(); - public native void destroy(); - public native int size(); - public native @Cast("const Nd4jLong*") LongPointer at(int idx); - public native void push_back(@Cast("const Nd4jLong*") LongPointer shape); - public native void push_back(@Cast("const Nd4jLong*") LongBuffer shape); - public native void push_back(@Cast("const Nd4jLong*") long[] shape); - /** - * PLEASE NOTE: This method should be called ONLY if shapes were generated at workspaces. Otherwise you'll get memory leak - */ - public native void detach(); +// #include +@Namespace("sd") @NoOffset public static class ShapeList extends Pointer { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public ShapeList(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public ShapeList(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public ShapeList position(long position) { + return (ShapeList)super.position(position); } + public ShapeList(@Cast("const Nd4jLong*") LongPointer shape/*=nullptr*/) { super((Pointer)null); allocate(shape); } + private native void allocate(@Cast("const Nd4jLong*") LongPointer shape/*=nullptr*/); + public ShapeList() { super((Pointer)null); allocate(); } + private native void allocate(); + public ShapeList(@Cast("const Nd4jLong*") LongBuffer shape/*=nullptr*/) { super((Pointer)null); allocate(shape); } + private native void allocate(@Cast("const Nd4jLong*") LongBuffer shape/*=nullptr*/); + public ShapeList(@Cast("const Nd4jLong*") long[] shape/*=nullptr*/) { super((Pointer)null); allocate(shape); } + private native void allocate(@Cast("const Nd4jLong*") long[] shape/*=nullptr*/); + public ShapeList(@Cast("const Nd4jLong**") @StdVector PointerPointer shapes, @Cast("bool") boolean isWorkspace) { super((Pointer)null); allocate(shapes, isWorkspace); } + private native void allocate(@Cast("const Nd4jLong**") @StdVector PointerPointer shapes, @Cast("bool") boolean isWorkspace); + public ShapeList(@Cast("const Nd4jLong**") @StdVector @ByPtrPtr LongPointer shapes, @Cast("bool") boolean isWorkspace) { super((Pointer)null); allocate(shapes, isWorkspace); } + private native void allocate(@Cast("const Nd4jLong**") @StdVector @ByPtrPtr LongPointer shapes, @Cast("bool") boolean isWorkspace); + public ShapeList(@Cast("const Nd4jLong**") @StdVector @ByPtrPtr LongBuffer shapes, @Cast("bool") boolean isWorkspace) { super((Pointer)null); allocate(shapes, isWorkspace); } + private native void allocate(@Cast("const Nd4jLong**") @StdVector @ByPtrPtr LongBuffer shapes, @Cast("bool") boolean isWorkspace); + public ShapeList(@Cast("const Nd4jLong**") @StdVector @ByPtrPtr long[] shapes, @Cast("bool") boolean isWorkspace) { super((Pointer)null); allocate(shapes, isWorkspace); } + private native void allocate(@Cast("const Nd4jLong**") @StdVector @ByPtrPtr long[] shapes, @Cast("bool") boolean isWorkspace); + public ShapeList(@Cast("const Nd4jLong**") @StdVector PointerPointer shapes) { super((Pointer)null); allocate(shapes); } + private native void allocate(@Cast("const Nd4jLong**") @StdVector PointerPointer shapes); + // ShapeList(bool autoRemovable); + + public native @Cast("const Nd4jLong**") @StdVector PointerPointer asVector(); + public native void destroy(); + public native int size(); + public native @Cast("const Nd4jLong*") LongPointer at(int idx); + public native void push_back(@Cast("const Nd4jLong*") LongPointer shape); + public native void push_back(@Cast("const Nd4jLong*") LongBuffer shape); + public native void push_back(@Cast("const Nd4jLong*") long[] shape); + /** + * PLEASE NOTE: This method should be called ONLY if shapes were generated at + * workspaces. Otherwise you'll get memory leak + */ + public native void detach(); +} + // namespace sd -// #endif //LIBND4J_SHAPELIST_H +// #endif // LIBND4J_SHAPELIST_H // Parsed from system/type_boilerplate.h @@ -9535,539 +10547,1513 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #define EXPAND3(...) __VA_ARGS__ // #define EXTRACT(...) EXTRACT __VA_ARGS__ // #define NOTHING_EXTRACT -// #define PASTE(x, ...) x ## __VA_ARGS__ -// #define PASTE2(x, ...) x ## __VA_ARGS__ -// #define PASTE3(x, ...) x ## __VA_ARGS__ +// #define PASTE(x, ...) x##__VA_ARGS__ +// #define PASTE2(x, ...) x##__VA_ARGS__ +// #define PASTE3(x, ...) x##__VA_ARGS__ // #define EVALUATING_PASTE(x, ...) PASTE(x, __VA_ARGS__) // #define EVALUATING_PASTE2(x, ...) PASTE2(x, __VA_ARGS__) // #define EVALUATING_PASTE3(x, ...) PASTE3(x, __VA_ARGS__) // #define UNPAREN(x) EVALUATING_PASTE(NOTHING_, EXTRACT x) // #define UNPAREN2(x) EVALUATING_PASTE2(NOTHING_, EXTRACT x) // #define UNPAREN3(x) EVALUATING_PASTE3(NOTHING_, EXTRACT x) -// #define EVAL( x ) x -// #define EVALX( x ) x -// #define EVAL0(...) EVAL1(EVAL1(EVAL1(__VA_ARGS__))) +// #define EVAL(x) x +// #define EVALX(x) x +// #define EVAL0(...) EVAL1(EVAL1(EVAL1(__VA_ARGS__))) // #define EVAL1(...) EVAL2(EVAL2(EVAL2(__VA_ARGS__))) // #define EVAL2(...) EVAL3(EVAL3(EVAL3(__VA_ARGS__))) // #define EVAL3(...) EVAL4(EVAL4(EVAL4(__VA_ARGS__))) // #define EVAL4(...) EVAL5(EVAL5(EVAL5(__VA_ARGS__))) // #define EVAL5(...) __VA_ARGS__ - // #define SEL_T_1(WHAT, NAME, SIGNATURE, TYPE_A) WHAT(NAME, SIGNATURE, TYPE_A) -// #define SEL_T_2(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(SEL_T_1(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define SEL_T_3(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(SEL_T_2(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define SEL_T_4(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(SEL_T_3(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define SEL_T_5(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(SEL_T_4(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define SEL_T_6(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(SEL_T_5(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define SEL_T_7(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(SEL_T_6(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define SEL_T_8(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(SEL_T_7(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define SEL_T_9(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(SEL_T_8(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define SEL_T_10(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(SEL_T_9(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define SEL_T_11(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(SEL_T_10(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define SEL_T_12(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(SEL_T_11(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define SEL_T_13(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(SEL_T_12(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define SEL_T_14(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(SEL_T_13(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define SEL_T_15(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(SEL_T_14(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define SEL_T_16(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(SEL_T_15(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define SEL_T_17(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(SEL_T_16(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define SEL_T_18(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(SEL_T_17(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define SEL_T_19(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(SEL_T_18(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define SEL_T_20(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(SEL_T_19(WHAT, NAME, SIGNATURE, __VA_ARGS__)) - - -// #define SEL_TT1_1(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) -// #define SEL_TT1_2(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT1_1(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_TT1_3(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT1_2(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_TT1_4(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT1_3(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_TT1_5(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT1_4(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_TT1_6(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT1_5(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_TT1_7(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT1_6(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_TT1_8(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT1_7(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_TT1_9(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT1_8(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_TT1_10(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT1_9(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_TT1_11(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT1_10(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_TT1_12(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT1_11(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_TT1_13(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT1_12(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_TT1_14(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT1_13(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_TT1_15(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT1_14(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_TT1_16(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT1_15(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_TT1_17(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT1_16(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_TT1_18(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT1_17(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_TT1_19(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT1_18(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_TT1_20(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT1_19(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) - - -// #define SEL_P1_1(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) -// #define SEL_P1_2(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P1_1(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_P1_3(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P1_2(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_P1_4(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P1_3(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_P1_5(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P1_4(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_P1_6(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P1_5(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_P1_7(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P1_6(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_P1_8(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P1_7(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_P1_9(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P1_8(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_P1_10(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P1_9(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_P1_11(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P1_10(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_P1_12(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P1_11(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_P1_13(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P1_12(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_P1_14(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P1_13(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_P1_15(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P1_14(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_P1_16(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P1_15(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_P1_17(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P1_16(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_P1_18(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P1_17(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_P1_19(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P1_18(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_P1_20(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P1_19(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) - -// #define SEL_P2_1(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) -// #define SEL_P2_2(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P2_1(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_P2_3(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P2_2(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_P2_4(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P2_3(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_P2_5(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P2_4(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_P2_6(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P2_5(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_P2_7(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P2_6(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_P2_8(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P2_7(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_P2_9(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P2_8(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_P2_10(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P2_9(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_P2_11(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P2_10(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_P2_12(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P2_11(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_P2_13(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P2_12(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_P2_14(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P2_13(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_P2_15(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P2_14(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_P2_16(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P2_15(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_P2_17(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P2_16(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_P2_18(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P2_17(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_P2_19(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P2_18(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_P2_20(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P2_19(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) - - - - -// #define SEL_TT2_1(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) -// #define SEL_TT2_2(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT2_1(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_TT2_3(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT2_2(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_TT2_4(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT2_3(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_TT2_5(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT2_4(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_TT2_6(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT2_5(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_TT2_7(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT2_6(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_TT2_8(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT2_7(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_TT2_9(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT2_8(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_TT2_10(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT2_9(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_TT2_11(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT2_10(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_TT2_12(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT2_11(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_TT2_13(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT2_12(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_TT2_14(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT2_13(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_TT2_15(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT2_14(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_TT2_16(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT2_15(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_TT2_17(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT2_16(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_TT2_18(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT2_17(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_TT2_19(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT2_18(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_TT2_20(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT2_19(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) - +// #define SEL_T_2(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) +// EVAL(SEL_T_1(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define SEL_T_3(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) +// EVAL(SEL_T_2(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define SEL_T_4(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) +// EVAL(SEL_T_3(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define SEL_T_5(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) +// EVAL(SEL_T_4(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define SEL_T_6(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) +// EVAL(SEL_T_5(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define SEL_T_7(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) +// EVAL(SEL_T_6(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define SEL_T_8(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) +// EVAL(SEL_T_7(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define SEL_T_9(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) +// EVAL(SEL_T_8(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define SEL_T_10(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) +// EVAL(SEL_T_9(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define SEL_T_11(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) +// EVAL(SEL_T_10(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define SEL_T_12(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) +// EVAL(SEL_T_11(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define SEL_T_13(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) +// EVAL(SEL_T_12(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define SEL_T_14(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) +// EVAL(SEL_T_13(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define SEL_T_15(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) +// EVAL(SEL_T_14(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define SEL_T_16(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) +// EVAL(SEL_T_15(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define SEL_T_17(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) +// EVAL(SEL_T_16(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define SEL_T_18(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) +// EVAL(SEL_T_17(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define SEL_T_19(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) +// EVAL(SEL_T_18(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define SEL_T_20(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) +// EVAL(SEL_T_19(WHAT, NAME, SIGNATURE, __VA_ARGS__)) + +// #define SEL_TT1_1(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// #define SEL_TT1_2(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_TT1_1(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define SEL_TT1_3(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_TT1_2(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define SEL_TT1_4(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_TT1_3(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define SEL_TT1_5(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_TT1_4(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define SEL_TT1_6(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_TT1_5(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define SEL_TT1_7(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_TT1_6(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define SEL_TT1_8(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_TT1_7(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define SEL_TT1_9(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_TT1_8(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define SEL_TT1_10(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_TT1_9(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define SEL_TT1_11(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_TT1_10(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define SEL_TT1_12(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_TT1_11(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define SEL_TT1_13(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_TT1_12(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define SEL_TT1_14(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_TT1_13(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define SEL_TT1_15(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_TT1_14(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define SEL_TT1_16(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_TT1_15(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define SEL_TT1_17(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_TT1_16(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define SEL_TT1_18(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_TT1_17(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define SEL_TT1_19(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_TT1_18(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define SEL_TT1_20(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_TT1_19(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) + +// #define SEL_P1_1(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// #define SEL_P1_2(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// ...) +// WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_P1_1(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, +// __VA_ARGS__)) +// #define SEL_P1_3(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// ...) +// WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_P1_2(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, +// __VA_ARGS__)) +// #define SEL_P1_4(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// ...) +// WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_P1_3(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, +// __VA_ARGS__)) +// #define SEL_P1_5(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// ...) +// WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_P1_4(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, +// __VA_ARGS__)) +// #define SEL_P1_6(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// ...) +// WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_P1_5(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, +// __VA_ARGS__)) +// #define SEL_P1_7(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// ...) +// WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_P1_6(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, +// __VA_ARGS__)) +// #define SEL_P1_8(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// ...) +// WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_P1_7(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, +// __VA_ARGS__)) +// #define SEL_P1_9(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// ...) +// WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_P1_8(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, +// __VA_ARGS__)) +// #define SEL_P1_10(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// ...) +// WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_P1_9(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, +// __VA_ARGS__)) +// #define SEL_P1_11(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// ...) +// WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_P1_10(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, +// __VA_ARGS__)) +// #define SEL_P1_12(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// ...) +// WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_P1_11(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, +// __VA_ARGS__)) +// #define SEL_P1_13(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// ...) +// WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_P1_12(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, +// __VA_ARGS__)) +// #define SEL_P1_14(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// ...) +// WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_P1_13(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, +// __VA_ARGS__)) +// #define SEL_P1_15(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// ...) +// WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_P1_14(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, +// __VA_ARGS__)) +// #define SEL_P1_16(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// ...) +// WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_P1_15(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, +// __VA_ARGS__)) +// #define SEL_P1_17(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// ...) +// WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_P1_16(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, +// __VA_ARGS__)) +// #define SEL_P1_18(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// ...) +// WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_P1_17(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, +// __VA_ARGS__)) +// #define SEL_P1_19(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// ...) +// WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_P1_18(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, +// __VA_ARGS__)) +// #define SEL_P1_20(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// ...) +// WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_P1_19(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, +// __VA_ARGS__)) + +// #define SEL_P2_1(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// #define SEL_P2_2(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// ...) +// WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_P2_1(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, +// __VA_ARGS__)) +// #define SEL_P2_3(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// ...) +// WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_P2_2(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, +// __VA_ARGS__)) +// #define SEL_P2_4(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// ...) +// WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_P2_3(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, +// __VA_ARGS__)) +// #define SEL_P2_5(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// ...) +// WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_P2_4(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, +// __VA_ARGS__)) +// #define SEL_P2_6(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// ...) +// WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_P2_5(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, +// __VA_ARGS__)) +// #define SEL_P2_7(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// ...) +// WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_P2_6(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, +// __VA_ARGS__)) +// #define SEL_P2_8(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// ...) +// WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_P2_7(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, +// __VA_ARGS__)) +// #define SEL_P2_9(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// ...) +// WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_P2_8(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, +// __VA_ARGS__)) +// #define SEL_P2_10(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// ...) +// WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_P2_9(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, +// __VA_ARGS__)) +// #define SEL_P2_11(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// ...) +// WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_P2_10(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, +// __VA_ARGS__)) +// #define SEL_P2_12(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// ...) +// WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_P2_11(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, +// __VA_ARGS__)) +// #define SEL_P2_13(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// ...) +// WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_P2_12(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, +// __VA_ARGS__)) +// #define SEL_P2_14(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// ...) +// WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_P2_13(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, +// __VA_ARGS__)) +// #define SEL_P2_15(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// ...) +// WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_P2_14(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, +// __VA_ARGS__)) +// #define SEL_P2_16(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// ...) +// WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_P2_15(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, +// __VA_ARGS__)) +// #define SEL_P2_17(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// ...) +// WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_P2_16(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, +// __VA_ARGS__)) +// #define SEL_P2_18(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// ...) +// WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_P2_17(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, +// __VA_ARGS__)) +// #define SEL_P2_19(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// ...) +// WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_P2_18(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, +// __VA_ARGS__)) +// #define SEL_P2_20(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// ...) +// WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_P2_19(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, +// __VA_ARGS__)) + +// #define SEL_TT2_1(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// #define SEL_TT2_2(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_TT2_1(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define SEL_TT2_3(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_TT2_2(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define SEL_TT2_4(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_TT2_3(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define SEL_TT2_5(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_TT2_4(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define SEL_TT2_6(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_TT2_5(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define SEL_TT2_7(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_TT2_6(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define SEL_TT2_8(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_TT2_7(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define SEL_TT2_9(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_TT2_8(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define SEL_TT2_10(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_TT2_9(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define SEL_TT2_11(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_TT2_10(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define SEL_TT2_12(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_TT2_11(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define SEL_TT2_13(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_TT2_12(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define SEL_TT2_14(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_TT2_13(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define SEL_TT2_15(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_TT2_14(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define SEL_TT2_16(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_TT2_15(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define SEL_TT2_17(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_TT2_16(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define SEL_TT2_18(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_TT2_17(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define SEL_TT2_19(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_TT2_18(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define SEL_TT2_20(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_TT2_19(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) // #define DS_1(WHAT, NAME, SIGNATURE, TYPE_A) WHAT(NAME, SIGNATURE, TYPE_A) -// #define DS_2(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DS_1(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DS_3(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DS_2(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DS_4(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DS_3(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DS_5(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DS_4(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DS_6(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DS_5(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DS_7(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DS_6(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DS_8(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DS_7(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DS_9(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DS_8(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DS_10(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DS_9(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DS_11(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DS_10(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DS_12(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DS_11(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DS_13(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DS_12(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DS_14(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DS_13(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DS_15(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DS_14(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DS_16(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DS_15(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DS_17(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DS_16(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DS_18(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DS_17(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DS_19(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DS_18(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DS_20(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DS_19(WHAT, NAME, SIGNATURE, __VA_ARGS__)) - +// #define DS_2(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DS_1(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DS_3(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DS_2(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DS_4(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DS_3(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DS_5(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DS_4(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DS_6(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DS_5(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DS_7(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DS_6(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DS_8(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DS_7(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DS_9(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DS_8(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DS_10(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DS_9(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DS_11(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DS_10(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DS_12(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DS_11(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DS_13(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DS_12(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DS_14(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DS_13(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DS_15(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DS_14(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DS_16(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DS_15(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DS_17(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DS_16(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DS_18(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DS_17(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DS_19(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DS_18(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DS_20(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DS_19(WHAT, NAME, SIGNATURE, __VA_ARGS__)) // #define DP_1(WHAT, NAME, SIGNATURE, TYPE_A) WHAT(NAME, SIGNATURE, TYPE_A) -// #define DP_2(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_1(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_3(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_2(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_4(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_3(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_5(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_4(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_6(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_5(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_7(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_6(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_8(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_7(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_9(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_8(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_10(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_9(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_11(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_10(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_12(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_11(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_13(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_12(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_14(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_13(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_15(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_14(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_16(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_15(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_17(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_16(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_18(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_17(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_19(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_18(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_20(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_19(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_21(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_20(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_22(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_21(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_23(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_22(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_24(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_23(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_25(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_24(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_26(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_25(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_27(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_26(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_28(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_27(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_29(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_28(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_30(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_29(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_31(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_30(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_32(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_31(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_33(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_32(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_34(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_33(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_35(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_34(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_36(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_35(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_37(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_36(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_38(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_37(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_39(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_38(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_40(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_39(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_41(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_40(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_42(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_41(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_43(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_42(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_44(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_43(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_45(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_44(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_46(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_45(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_47(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_46(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_48(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_47(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_49(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_48(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_50(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_49(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_51(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_50(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_52(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_51(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_53(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_52(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_54(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_53(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_55(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_54(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_56(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_55(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_57(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_56(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_58(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_57(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_59(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_58(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_60(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_59(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_61(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_60(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_62(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_61(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_63(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_62(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_64(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_63(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_65(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_64(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_66(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_65(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_67(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_66(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_68(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_67(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_69(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_68(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_70(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_69(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_71(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_70(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_72(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_71(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_73(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_72(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_74(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_73(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_75(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_74(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_76(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_75(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_77(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_76(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_78(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_77(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_79(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_78(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_80(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_79(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_81(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_80(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_82(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_81(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_83(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_82(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_84(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_83(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_85(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_84(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_86(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_85(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_87(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_86(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_88(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_87(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_89(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_88(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_90(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_89(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_91(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_90(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_92(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_91(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_93(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_92(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_94(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_93(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_95(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_94(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_96(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_95(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_97(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_96(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_98(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_97(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_99(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_98(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_100(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_99(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_101(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_100(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_102(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_101(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_103(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_102(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_104(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_103(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_105(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_104(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_106(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_105(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_107(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_106(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_108(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_107(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_109(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_108(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_110(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_109(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_111(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_110(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_112(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_111(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_113(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_112(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_114(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_113(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_115(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_114(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_116(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_115(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_117(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_116(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_118(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_117(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_119(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_118(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_120(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_119(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_121(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_120(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_122(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_121(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_123(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_122(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_124(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_123(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_125(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_124(WHAT, NAME, SIGNATURE, __VA_ARGS__)) - - -// #define DT_1(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) -// #define DT_2(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT_1(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define DT_3(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT_2(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define DT_4(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT_3(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define DT_5(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT_4(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define DT_6(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT_5(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define DT_7(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT_6(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define DT_8(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT_7(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define DT_9(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT_8(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define DT_10(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT_9(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define DT_11(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT_10(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define DT_12(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT_11(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define DT_13(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT_12(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define DT_14(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT_13(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define DT_15(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT_14(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define DT_16(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT_15(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define DT_17(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT_16(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define DT_18(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT_17(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define DT_19(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT_18(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define DT_20(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT_19(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) - -// #define DT2_1(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) -// #define DT2_2(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT2_1(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define DT2_3(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT2_2(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define DT2_4(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT2_3(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define DT2_5(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT2_4(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define DT2_6(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT2_5(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define DT2_7(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT2_6(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define DT2_8(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT2_7(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define DT2_9(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT2_8(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define DT2_10(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT2_9(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define DT2_11(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT2_10(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define DT2_12(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT2_11(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define DT2_13(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT2_12(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define DT2_14(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT2_13(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define DT2_15(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT2_14(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define DT2_16(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT2_15(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define DT2_17(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT2_16(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define DT2_18(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT2_17(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define DT2_19(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT2_18(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define DT2_20(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT2_19(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) - -// #define TTT1_1(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) -// #define TTT1_2(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT1_1(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT1_3(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT1_2(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT1_4(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT1_3(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT1_5(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT1_4(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT1_6(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT1_5(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT1_7(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT1_6(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT1_8(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT1_7(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT1_9(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT1_8(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT1_10(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT1_9(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT1_11(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT1_10(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT1_12(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT1_11(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT1_13(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT1_12(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT1_14(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT1_13(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT1_15(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT1_14(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT1_16(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT1_15(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT1_17(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT1_16(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT1_18(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT1_17(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT1_19(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT1_18(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT1_20(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT1_19(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) - -// #define TTT2_1(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) -// #define TTT2_2(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT2_1(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT2_3(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT2_2(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT2_4(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT2_3(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT2_5(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT2_4(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT2_6(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT2_5(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT2_7(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT2_6(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT2_8(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT2_7(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT2_9(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT2_8(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT2_10(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT2_9(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT2_11(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT2_10(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT2_12(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT2_11(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT2_13(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT2_12(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT2_14(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT2_13(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT2_15(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT2_14(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT2_16(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT2_15(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT2_17(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT2_16(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT2_18(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT2_17(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT2_19(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT2_18(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT2_20(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT2_19(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) - - -// #define TTT3_1(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) -// #define TTT3_2(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT3_1(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT3_3(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT3_2(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT3_4(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT3_3(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT3_5(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT3_4(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT3_6(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT3_5(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT3_7(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT3_6(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT3_8(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT3_7(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT3_9(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT3_8(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT3_10(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT3_9(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT3_11(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT3_10(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT3_12(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT3_11(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT3_13(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT3_12(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT3_14(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT3_13(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT3_15(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT3_14(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT3_16(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT3_15(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT3_17(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT3_16(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT3_18(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT3_17(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT3_19(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT3_18(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT3_20(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT3_19(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) - - - -// #define TT1_1(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) -// #define TT1_2(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT1_1(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT1_3(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT1_2(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT1_4(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT1_3(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT1_5(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT1_4(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT1_6(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT1_5(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT1_7(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT1_6(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT1_8(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT1_7(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT1_9(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT1_8(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT1_10(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT1_9(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT1_11(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT1_10(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT1_12(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT1_11(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT1_13(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT1_12(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT1_14(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT1_13(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT1_15(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT1_14(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT1_16(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT1_15(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT1_17(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT1_16(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT1_18(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT1_17(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT1_19(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT1_18(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT1_20(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT1_19(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) - -// #define TT2_1(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) -// #define TT2_2(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT2_1(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT2_3(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT2_2(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT2_4(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT2_3(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT2_5(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT2_4(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT2_6(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT2_5(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT2_7(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT2_6(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT2_8(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT2_7(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT2_9(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT2_8(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT2_10(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT2_9(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT2_11(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT2_10(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT2_12(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT2_11(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT2_13(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT2_12(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT2_14(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT2_13(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT2_15(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT2_14(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT2_16(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT2_15(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT2_17(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT2_16(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT2_18(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT2_17(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT2_19(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT2_18(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT2_20(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT2_19(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) - -// #define TT3_1(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) -// #define TT3_2(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT3_1(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT3_3(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT3_2(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT3_4(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT3_3(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT3_5(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT3_4(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT3_6(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT3_5(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT3_7(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT3_6(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT3_8(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT3_7(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT3_9(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT3_8(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT3_10(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT3_9(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT3_11(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT3_10(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT3_12(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT3_11(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT3_13(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT3_12(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT3_14(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT3_13(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT3_15(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT3_14(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT3_16(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT3_15(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT3_17(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT3_16(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT3_18(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT3_17(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT3_19(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT3_18(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT3_20(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT3_19(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) - - -// #define GET_MACRO_SEL_T(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, NAME,...) NAME -// #define GET_MACRO_SEL_P1(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, NAME,...) NAME -// #define GET_MACRO_SEL_P2(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, NAME,...) NAME -// #define GET_MACRO_SEL_TT1(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, NAME,...) NAME -// #define GET_MACRO_SEL_TT2(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, NAME,...) NAME -// #define GET_MACRO_DS(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, NAME,...) NAME -// #define GET_MACRO_DT(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, NAME,...) NAME -// #define GET_MACRO_DP(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51, _52, _53, _54, _55, _56, _57, _58, _59, _60, _61, _62, _63, _64, _65, _66, _67, _68, _69, _70, _71, _72, _73, _74, _75, _76, _77, _78, _79, _80, _81, _82, _83, _84, _85, _86, _87, _88, _89, _90, _91, _92, _93, _94, _95, _96, _97, _98, _99, _100, _101, _102, _103, _104, _105, _106, _107, _108, _109, _110, _111, _112, _113, _114, _115, _116, _117, _118, _119, _120, _121, _122, _123, _124, _125, NAME,...) NAME -// #define GET_MACRO_DT2(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, NAME,...) NAME - - -// #define GET_MACRO_TT1(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, NAME,...) NAME -// #define GET_MACRO_TT2(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, NAME,...) NAME -// #define GET_MACRO_TT3(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, NAME,...) NAME - -// #define GET_MACRO_TTT1(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, NAME,...) NAME -// #define GET_MACRO_TTT2(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, NAME,...) NAME -// #define GET_MACRO_TTT3(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, NAME,...) NAME - -// #define FOR_EACH_S1(WHAT, NAME, SIGNATURE, ...) EXPAND(GET_MACRO_SEL_T(__VA_ARGS__, SEL_T_20, SEL_T_19, SEL_T_18, SEL_T_17, SEL_T_16, SEL_T_15, SEL_T_14, SEL_T_13, SEL_T_12, SEL_T_11, SEL_T_10, SEL_T_9, SEL_T_8, SEL_T_7, SEL_T_6, SEL_T_5, SEL_T_4, SEL_T_3, SEL_T_2, SEL_T_1)(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define FOR_EACH_S2(WHAT, YTYPE, NAME, SIGNATURE, TYPES_A, ...) EXPAND(GET_MACRO_SEL_TT1(__VA_ARGS__, SEL_TT1_20, SEL_TT1_19, SEL_TT1_18, SEL_TT1_17, SEL_TT1_16, SEL_TT1_15, SEL_TT1_14, SEL_TT1_13, SEL_TT1_12, SEL_TT1_11, SEL_TT1_10, SEL_TT1_9, SEL_TT1_8, SEL_TT1_7, SEL_TT1_6, SEL_TT1_5, SEL_TT1_4, SEL_TT1_3, SEL_TT1_2, SEL_TT1_1)(WHAT, YTYPE, NAME, SIGNATURE, TYPES_A, __VA_ARGS__)) -// #define FOR_EACH_P1(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_A, ...) EXPAND(GET_MACRO_SEL_P1(__VA_ARGS__, SEL_P1_20, SEL_P1_19, SEL_P1_18, SEL_P1_17, SEL_P1_16, SEL_P1_15, SEL_P1_14, SEL_P1_13, SEL_P1_12, SEL_P1_11, SEL_P1_10, SEL_P1_9, SEL_P1_8, SEL_P1_7, SEL_P1_6, SEL_P1_5, SEL_P1_4, SEL_P1_3, SEL_P1_2, SEL_P1_1)(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_A, __VA_ARGS__)) -// #define FOR_EACH_P2(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_A, ...) EXPAND2(GET_MACRO_SEL_P2(__VA_ARGS__, SEL_P2_20, SEL_P2_19, SEL_P2_18, SEL_P2_17, SEL_P2_16, SEL_P2_15, SEL_P2_14, SEL_P2_13, SEL_P2_12, SEL_P2_11, SEL_P2_10, SEL_P2_9, SEL_P2_8, SEL_P2_7, SEL_P2_6, SEL_P2_5, SEL_P2_4, SEL_P2_3, SEL_P2_2, SEL_P2_1)(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_A, __VA_ARGS__)) -// #define FOR_EACH_S3(WHAT, NAME, SIGNATURE, TYPE_A, ...) EXPAND(GET_MACRO_SEL_TT2(__VA_ARGS__, SEL_TT2_20, SEL_TT2_19, SEL_TT2_18, SEL_TT2_17, SEL_TT2_16, SEL_TT2_15, SEL_TT2_14, SEL_TT2_13, SEL_TT2_12, SEL_TT2_11, SEL_TT2_10, SEL_TT2_9, SEL_TT2_8, SEL_TT2_7, SEL_TT2_6, SEL_TT2_5, SEL_TT2_4, SEL_TT2_3, SEL_TT2_2, SEL_TT2_1)(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define FOR_EACH_DS(WHAT, NAME, SIGNATURE, ...) EXPAND(GET_MACRO_DS(__VA_ARGS__, DS_20, DS_19, DS_18, DS_17, DS_16, DS_15, DS_14, DS_13, DS_12, DS_11, DS_10, DS_9, DS_8, DS_7, DS_6, DS_5, DS_4, DS_3, DS_2, DS_1)(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define FOR_EACH_DT(WHAT, NAME, SIGNATURE, TYPES_A, ...) EXPAND(GET_MACRO_DT(__VA_ARGS__, DT_20, DT_19, DT_18, DT_17, DT_16, DT_15, DT_14, DT_13, DT_12, DT_11, DT_10, DT_9, DT_8, DT_7, DT_6, DT_5, DT_4, DT_3, DT_2, DT_1)(WHAT, NAME, SIGNATURE, TYPES_A, __VA_ARGS__)) -// #define FOR_EACH_DT2(WHAT, NAME, SIGNATURE, TYPE_A, ...) EXPAND(GET_MACRO_DT2(__VA_ARGS__, DT2_20, DT2_19, DT2_18, DT2_17, DT2_16, DT2_15, DT2_14, DT2_13, DT2_12, DT2_11, DT2_10, DT2_9, DT2_8, DT2_7, DT2_6, DT2_5, DT2_4, DT2_3, DT2_2, DT2_1)(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define FOR_EACH_DP(WHAT, NAME, SIGNATURE, ...) EXPAND(GET_MACRO_DP(__VA_ARGS__, DP_125, DP_124, DP_123, DP_122, DP_121, DP_120, DP_119, DP_118, DP_117, DP_116, DP_115, DP_114, DP_113, DP_112, DP_111, DP_110, DP_109, DP_108, DP_107, DP_106, DP_105, DP_104, DP_103, DP_102, DP_101, DP_100, DP_99, DP_98, DP_97, DP_96, DP_95, DP_94, DP_93, DP_92, DP_91, DP_90, DP_89, DP_88, DP_87, DP_86, DP_85, DP_84, DP_83, DP_82, DP_81, DP_80, DP_79, DP_78, DP_77, DP_76, DP_75, DP_74, DP_73, DP_72, DP_71, DP_70, DP_69, DP_68, DP_67, DP_66, DP_65, DP_64, DP_63, DP_62, DP_61, DP_60, DP_59, DP_58, DP_57, DP_56, DP_55, DP_54, DP_53, DP_52, DP_51, DP_50, DP_49, DP_48, DP_47, DP_46, DP_45, DP_44, DP_43, DP_42, DP_41, DP_40, DP_39, DP_38, DP_37, DP_36, DP_35, DP_34, DP_33, DP_32, DP_31, DP_30, DP_29, DP_28, DP_27, DP_26, DP_25, DP_24, DP_23, DP_22, DP_21, DP_20, DP_19, DP_18, DP_17, DP_16, DP_15, DP_14, DP_13, DP_12, DP_11, DP_10, DP_9, DP_8, DP_7, DP_6, DP_5, DP_4, DP_3, DP_2, DP_1)(WHAT, NAME, SIGNATURE, __VA_ARGS__)) - - -// #define FOR_EACH_TT1(WHAT, NAME, SIGNATURE, TYPES_X, TYPES_Y, ...) EXPAND(GET_MACRO_TT1(__VA_ARGS__, TT1_20, TT1_19, TT1_18, TT1_17, TT1_16, TT1_15, TT1_14, TT1_13, TT1_12, TT1_11, TT1_10, TT1_9, TT1_8, TT1_7, TT1_6, TT1_5, TT1_4, TT1_3, TT1_2, TT1_1)(WHAT, NAME, SIGNATURE, TYPES_X, TYPES_Y, __VA_ARGS__)) -// #define FOR_EACH_TT2(WHAT, NAME, SIGNATURE, TYPE_Z, TYPES_X, ...) EXPAND(GET_MACRO_TT2(__VA_ARGS__, TT2_20, TT2_19, TT2_18, TT2_17, TT2_16, TT2_15, TT2_14, TT2_13, TT2_12, TT2_11, TT2_10, TT2_9, TT2_8, TT2_7, TT2_6, TT2_5, TT2_4, TT2_3, TT2_2, TT2_1)(WHAT, NAME, SIGNATURE, TYPE_Z, TYPES_X, __VA_ARGS__)) -// #define FOR_EACH_TT3(WHAT, NAME, SIGNATURE, TYPE_Z, TYPE_Y, ...) EXPAND(GET_MACRO_TT3(__VA_ARGS__, TT3_20, TT3_19, TT3_18, TT3_17, TT3_16, TT3_15, TT3_14, TT3_13, TT3_12, TT3_11, TT3_10, TT3_9, TT3_8, TT3_7, TT3_6, TT3_5, TT3_4, TT3_3, TT3_2, TT3_1)(WHAT, NAME, SIGNATURE, TYPE_Z, TYPE_Y, __VA_ARGS__)) - -// #define FOR_EACH_TTT1(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_Z, TYPES_Y, ...) EXPAND(GET_MACRO_TTT1(__VA_ARGS__, TTT1_20, TTT1_19, TTT1_18, TTT1_17, TTT1_16, TTT1_15, TTT1_14, TTT1_13, TTT1_12, TTT1_11, TTT1_10, TTT1_9, TTT1_8, TTT1_7, TTT1_6, TTT1_5, TTT1_4, TTT1_3, TTT1_2, TTT1_1)(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_Z, TYPES_Y, __VA_ARGS__)) -// #define FOR_EACH_TTT2(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_X, TYPES_Z, ...) EXPAND(GET_MACRO_TTT2(__VA_ARGS__, TTT2_20, TTT2_19, TTT2_18, TTT2_17, TTT2_16, TTT2_15, TTT2_14, TTT2_13, TTT2_12, TTT2_11, TTT2_10, TTT2_9, TTT2_8, TTT2_7, TTT2_6, TTT2_5, TTT2_4, TTT2_3, TTT2_2, TTT2_1)(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_X, TYPES_Z, __VA_ARGS__)) -// #define FOR_EACH_TTT3(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_X, TYPE_Y, ...) EXPAND(GET_MACRO_TTT3(__VA_ARGS__, TTT3_20, TTT3_19, TTT3_18, TTT3_17, TTT3_16, TTT3_15, TTT3_14, TTT3_13, TTT3_12, TTT3_11, TTT3_10, TTT3_9, TTT3_8, TTT3_7, TTT3_6, TTT3_5, TTT3_4, TTT3_3, TTT3_2, TTT3_1)(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_X, TYPE_Y, __VA_ARGS__)) - -// #define _EXEC_SELECTOR_T(WHAT, NAME, SIGNATURE, ...) EVAL(FOR_EACH_S1(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define _EXEC_SELECTOR_P_1(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_A, ...) EVAL(FOR_EACH_P1(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_A, __VA_ARGS__)) -// #define _EXEC_SELECTOR_P_2(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, ...) EVAL(FOR_EACH_P2(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define _EXEC_SELECTOR_TT_1(WHAT, YTYPE, NAME, SIGNATURE, TYPES_A, ...) EVAL(FOR_EACH_S2(WHAT, YTYPE, NAME, SIGNATURE, TYPES_A, __VA_ARGS__)) -// #define _EXEC_SELECTOR_TT_2(WHAT, NAME, SIGNATURE, TYPE_A, ...) EVAL(FOR_EACH_S3(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define _EXEC_SINGLE_T(WHAT, NAME, SIGNATURE, ...) EVAL(FOR_EACH_DS(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define _EXEC_DOUBLE_T(WHAT, NAME, SIGNATURE, TYPES_A, ...) EVAL(FOR_EACH_DT(WHAT, NAME, SIGNATURE, LIST(TYPES_A), __VA_ARGS__)) -// #define _EXEC_DOUBLE_T2(WHAT, NAME, SIGNATURE, TYPE_A, ...) EVAL(FOR_EACH_DT2(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define _EXEC_DOUBLE_P(WHAT, NAME, SIGNATURE, ...) EVAL(FOR_EACH_DP(WHAT, NAME, SIGNATURE, __VA_ARGS__)) - -// #define _EXEC_SELECTOR_TTT_1(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_Z, TYPES_Y, ...) EVAL(FOR_EACH_TTT1(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_Z, TYPES_Y, __VA_ARGS__)) -// #define _EXEC_SELECTOR_TTT_2(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_X, TYPES_Z, ...) EVAL(FOR_EACH_TTT2(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_X, TYPES_Z, __VA_ARGS__)) -// #define _EXEC_SELECTOR_TTT_3(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_X, TYPE_Y, ...) EVAL(FOR_EACH_TTT3(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_X, TYPE_Y, __VA_ARGS__)) - -// #define _EXEC_TRIPLE_T3(WHAT, NAME, SIGNATURE, TYPE_Z, TYPE_Y, ...) EVAL(FOR_EACH_TT3(WHAT, NAME, SIGNATURE, TYPE_Z, TYPE_Y, __VA_ARGS__)) -// #define _EXEC_TRIPLE_T2(WHAT, NAME, SIGNATURE, TYPE_Z, TYPES_X, ...) EVAL(FOR_EACH_TT2(WHAT, NAME, SIGNATURE, TYPE_Z, LIST(TYPES_X), __VA_ARGS__)) -// #define _EXEC_TRIPLE_T1(WHAT, NAME, SIGNATURE, TYPES_X, TYPES_Y, ...) EVAL(FOR_EACH_TT1(WHAT, NAME, SIGNATURE, LIST(TYPES_X), LIST(TYPES_Y), __VA_ARGS__)) - -// #define DISPATCH_PAIRWISE(NAME, SIGNATURE, TYPE, TYPES_B) EVAL(_EXEC_DOUBLE_T2(RANDOMPAIRWISE2, NAME, SIGNATURE, TYPE, TYPES_B)) -// #define DISPATCH_PAIRWISE2(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE, ...) EVAL(_EXEC_SELECTOR_P_2(SELECTOR_PAIRWISE_2, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE, __VA_ARGS__)) - -// #define DISPATCH_DTYPES(NAME, SIGNATURE, TYPE, TYPES_B) EVAL(_EXEC_DOUBLE_T2(RANDOMDOUBLE2, NAME, SIGNATURE, TYPE, TYPES_B)) -// #define DISPATCH_DTYPES2(NAME, SIGNATURE, TYPE, ...) EVAL(_EXEC_SELECTOR_TT_2(SELECTOR_DOUBLE_2, NAME, SIGNATURE, TYPE, __VA_ARGS__)) - -// #define DISPATCH_TTYPES2(ZTYPE, NAME, SIGNATURE, TYPE_X, TYPES_Z, ...) EVAL(_EXEC_SELECTOR_TTT_2(SELECTOR_TRIPLE_2, ZTYPE, NAME, SIGNATURE, TYPE_X, TYPES_Z, __VA_ARGS__)) -// #define DISPATCH_TTYPES3(ZTYPE, NAME, SIGNATURE, TYPE_X, TYPE_Y, ...) EVAL(_EXEC_SELECTOR_TTT_3(SELECTOR_TRIPLE_3, ZTYPE, NAME, SIGNATURE, TYPE_X, TYPE_Y, __VA_ARGS__)) - +// #define DP_2(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_1(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_3(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_2(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_4(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_3(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_5(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_4(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_6(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_5(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_7(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_6(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_8(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_7(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_9(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_8(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_10(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_9(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_11(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_10(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_12(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_11(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_13(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_12(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_14(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_13(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_15(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_14(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_16(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_15(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_17(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_16(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_18(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_17(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_19(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_18(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_20(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_19(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_21(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_20(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_22(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_21(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_23(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_22(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_24(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_23(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_25(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_24(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_26(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_25(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_27(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_26(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_28(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_27(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_29(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_28(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_30(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_29(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_31(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_30(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_32(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_31(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_33(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_32(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_34(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_33(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_35(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_34(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_36(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_35(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_37(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_36(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_38(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_37(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_39(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_38(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_40(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_39(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_41(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_40(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_42(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_41(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_43(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_42(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_44(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_43(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_45(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_44(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_46(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_45(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_47(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_46(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_48(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_47(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_49(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_48(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_50(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_49(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_51(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_50(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_52(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_51(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_53(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_52(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_54(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_53(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_55(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_54(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_56(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_55(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_57(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_56(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_58(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_57(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_59(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_58(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_60(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_59(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_61(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_60(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_62(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_61(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_63(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_62(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_64(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_63(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_65(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_64(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_66(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_65(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_67(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_66(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_68(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_67(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_69(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_68(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_70(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_69(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_71(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_70(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_72(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_71(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_73(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_72(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_74(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_73(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_75(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_74(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_76(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_75(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_77(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_76(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_78(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_77(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_79(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_78(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_80(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_79(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_81(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_80(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_82(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_81(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_83(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_82(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_84(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_83(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_85(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_84(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_86(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_85(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_87(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_86(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_88(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_87(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_89(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_88(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_90(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_89(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_91(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_90(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_92(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_91(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_93(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_92(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_94(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_93(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_95(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_94(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_96(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_95(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_97(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_96(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_98(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_97(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_99(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_98(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_100(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_99(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_101(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_100(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_102(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_101(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_103(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_102(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_104(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_103(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_105(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_104(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_106(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_105(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_107(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_106(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_108(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_107(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_109(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_108(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_110(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_109(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_111(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_110(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_112(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_111(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_113(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_112(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_114(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_113(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_115(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_114(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_116(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_115(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_117(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_116(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_118(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_117(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_119(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_118(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_120(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_119(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_121(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_120(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_122(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_121(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_123(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_122(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_124(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_123(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_125(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_124(WHAT, NAME, SIGNATURE, __VA_ARGS__)) + +// #define DT_1(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// #define DT_2(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(DT_1(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define DT_3(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(DT_2(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define DT_4(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(DT_3(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define DT_5(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(DT_4(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define DT_6(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(DT_5(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define DT_7(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(DT_6(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define DT_8(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(DT_7(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define DT_9(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(DT_8(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define DT_10(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(DT_9(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define DT_11(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(DT_10(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define DT_12(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(DT_11(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define DT_13(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(DT_12(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define DT_14(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(DT_13(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define DT_15(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(DT_14(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define DT_16(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(DT_15(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define DT_17(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(DT_16(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define DT_18(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(DT_17(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define DT_19(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(DT_18(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define DT_20(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(DT_19(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) + +// #define DT2_1(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// #define DT2_2(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(DT2_1(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define DT2_3(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(DT2_2(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define DT2_4(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(DT2_3(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define DT2_5(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(DT2_4(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define DT2_6(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(DT2_5(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define DT2_7(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(DT2_6(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define DT2_8(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(DT2_7(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define DT2_9(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(DT2_8(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define DT2_10(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(DT2_9(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define DT2_11(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(DT2_10(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define DT2_12(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(DT2_11(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define DT2_13(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(DT2_12(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define DT2_14(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(DT2_13(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define DT2_15(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(DT2_14(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define DT2_16(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(DT2_15(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define DT2_17(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(DT2_16(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define DT2_18(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(DT2_17(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define DT2_19(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(DT2_18(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define DT2_20(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(DT2_19(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) + +// #define TTT1_1(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// #define TTT1_2(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, +// ...) +// WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT1_1(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// __VA_ARGS__)) +// #define TTT1_3(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, +// ...) +// WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT1_2(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// __VA_ARGS__)) +// #define TTT1_4(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, +// ...) +// WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT1_3(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// __VA_ARGS__)) +// #define TTT1_5(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, +// ...) +// WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT1_4(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// __VA_ARGS__)) +// #define TTT1_6(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, +// ...) +// WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT1_5(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// __VA_ARGS__)) +// #define TTT1_7(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, +// ...) +// WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT1_6(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// __VA_ARGS__)) +// #define TTT1_8(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, +// ...) +// WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT1_7(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// __VA_ARGS__)) +// #define TTT1_9(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, +// ...) +// WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT1_8(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// __VA_ARGS__)) +// #define TTT1_10(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, +// ...) +// WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT1_9(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// __VA_ARGS__)) +// #define TTT1_11(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, +// ...) +// WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT1_10(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// __VA_ARGS__)) +// #define TTT1_12(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, +// ...) +// WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT1_11(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// __VA_ARGS__)) +// #define TTT1_13(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, +// ...) +// WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT1_12(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// __VA_ARGS__)) +// #define TTT1_14(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, +// ...) +// WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT1_13(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// __VA_ARGS__)) +// #define TTT1_15(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, +// ...) +// WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT1_14(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// __VA_ARGS__)) +// #define TTT1_16(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, +// ...) +// WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT1_15(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// __VA_ARGS__)) +// #define TTT1_17(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, +// ...) +// WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT1_16(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// __VA_ARGS__)) +// #define TTT1_18(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, +// ...) +// WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT1_17(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// __VA_ARGS__)) +// #define TTT1_19(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, +// ...) +// WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT1_18(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// __VA_ARGS__)) +// #define TTT1_20(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, +// ...) +// WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT1_19(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// __VA_ARGS__)) + +// #define TTT2_1(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// #define TTT2_2(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT2_1(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TTT2_3(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT2_2(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TTT2_4(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT2_3(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TTT2_5(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT2_4(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TTT2_6(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT2_5(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TTT2_7(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT2_6(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TTT2_8(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT2_7(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TTT2_9(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT2_8(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TTT2_10(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT2_9(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TTT2_11(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT2_10(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TTT2_12(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT2_11(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TTT2_13(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT2_12(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TTT2_14(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT2_13(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TTT2_15(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT2_14(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TTT2_16(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT2_15(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TTT2_17(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT2_16(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TTT2_18(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT2_17(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TTT2_19(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT2_18(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TTT2_20(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT2_19(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) + +// #define TTT3_1(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// #define TTT3_2(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT3_1(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TTT3_3(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT3_2(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TTT3_4(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT3_3(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TTT3_5(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT3_4(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TTT3_6(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT3_5(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TTT3_7(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT3_6(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TTT3_8(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT3_7(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TTT3_9(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT3_8(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TTT3_10(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT3_9(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TTT3_11(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT3_10(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TTT3_12(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT3_11(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TTT3_13(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT3_12(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TTT3_14(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT3_13(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TTT3_15(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT3_14(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TTT3_16(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT3_15(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TTT3_17(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT3_16(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TTT3_18(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT3_17(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TTT3_19(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT3_18(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TTT3_20(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT3_19(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) + +// #define TT1_1(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// #define TT1_2(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT1_1(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT1_3(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT1_2(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT1_4(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT1_3(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT1_5(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT1_4(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT1_6(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT1_5(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT1_7(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT1_6(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT1_8(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT1_7(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT1_9(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT1_8(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT1_10(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT1_9(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT1_11(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT1_10(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT1_12(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT1_11(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT1_13(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT1_12(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT1_14(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT1_13(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT1_15(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT1_14(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT1_16(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT1_15(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT1_17(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT1_16(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT1_18(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT1_17(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT1_19(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT1_18(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT1_20(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT1_19(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) + +// #define TT2_1(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// #define TT2_2(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT2_1(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT2_3(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT2_2(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT2_4(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT2_3(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT2_5(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT2_4(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT2_6(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT2_5(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT2_7(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT2_6(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT2_8(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT2_7(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT2_9(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT2_8(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT2_10(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT2_9(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT2_11(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT2_10(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT2_12(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT2_11(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT2_13(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT2_12(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT2_14(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT2_13(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT2_15(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT2_14(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT2_16(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT2_15(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT2_17(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT2_16(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT2_18(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT2_17(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT2_19(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT2_18(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT2_20(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT2_19(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) + +// #define TT3_1(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// #define TT3_2(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT3_1(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT3_3(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT3_2(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT3_4(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT3_3(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT3_5(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT3_4(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT3_6(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT3_5(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT3_7(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT3_6(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT3_8(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT3_7(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT3_9(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT3_8(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT3_10(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT3_9(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT3_11(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT3_10(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT3_12(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT3_11(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT3_13(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT3_12(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT3_14(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT3_13(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT3_15(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT3_14(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT3_16(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT3_15(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT3_17(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT3_16(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT3_18(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT3_17(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT3_19(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT3_18(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT3_20(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT3_19(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) + +// #define GET_MACRO_SEL_T(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, +// _13, _14, _15, _16, _17, _18, _19, _20, NAME, ...) +// NAME +// #define GET_MACRO_SEL_P1(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, +// _13, _14, _15, _16, _17, _18, _19, _20, NAME, ...) +// NAME +// #define GET_MACRO_SEL_P2(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, +// _13, _14, _15, _16, _17, _18, _19, _20, NAME, ...) +// NAME +// #define GET_MACRO_SEL_TT1(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, +// _13, _14, _15, _16, _17, _18, _19, _20, NAME, ...) +// NAME +// #define GET_MACRO_SEL_TT2(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, +// _13, _14, _15, _16, _17, _18, _19, _20, NAME, ...) +// NAME +// #define GET_MACRO_DS(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, +// _14, _15, _16, _17, _18, _19, _20, NAME, ...) +// NAME +// #define GET_MACRO_DT(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, +// _14, _15, _16, _17, _18, _19, _20, NAME, ...) +// NAME +// #define GET_MACRO_DP( +// _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, +// _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, +// _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, +// _47, _48, _49, _50, _51, _52, _53, _54, _55, _56, _57, _58, _59, _60, _61, +// _62, _63, _64, _65, _66, _67, _68, _69, _70, _71, _72, _73, _74, _75, _76, +// _77, _78, _79, _80, _81, _82, _83, _84, _85, _86, _87, _88, _89, _90, _91, +// _92, _93, _94, _95, _96, _97, _98, _99, _100, _101, _102, _103, _104, +// _105, _106, _107, _108, _109, _110, _111, _112, _113, _114, _115, _116, +// _117, _118, _119, _120, _121, _122, _123, _124, _125, NAME, ...) +// NAME +// #define GET_MACRO_DT2(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, +// _14, _15, _16, _17, _18, _19, _20, NAME, ...) +// NAME + +// #define GET_MACRO_TT1(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, +// _14, _15, _16, _17, _18, _19, _20, NAME, ...) +// NAME +// #define GET_MACRO_TT2(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, +// _14, _15, _16, _17, _18, _19, _20, NAME, ...) +// NAME +// #define GET_MACRO_TT3(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, +// _14, _15, _16, _17, _18, _19, _20, NAME, ...) +// NAME + +// #define GET_MACRO_TTT1(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, +// _14, _15, _16, _17, _18, _19, _20, NAME, ...) +// NAME +// #define GET_MACRO_TTT2(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, +// _14, _15, _16, _17, _18, _19, _20, NAME, ...) +// NAME +// #define GET_MACRO_TTT3(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, +// _14, _15, _16, _17, _18, _19, _20, NAME, ...) +// NAME + +// #define FOR_EACH_S1(WHAT, NAME, SIGNATURE, ...) +// EXPAND(GET_MACRO_SEL_T(__VA_ARGS__, SEL_T_20, SEL_T_19, SEL_T_18, SEL_T_17, +// SEL_T_16, SEL_T_15, SEL_T_14, SEL_T_13, SEL_T_12, +// SEL_T_11, SEL_T_10, SEL_T_9, SEL_T_8, SEL_T_7, +// SEL_T_6, SEL_T_5, SEL_T_4, SEL_T_3, SEL_T_2, +// SEL_T_1)(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define FOR_EACH_S2(WHAT, YTYPE, NAME, SIGNATURE, TYPES_A, ...) +// EXPAND(GET_MACRO_SEL_TT1( +// __VA_ARGS__, SEL_TT1_20, SEL_TT1_19, SEL_TT1_18, SEL_TT1_17, SEL_TT1_16, +// SEL_TT1_15, SEL_TT1_14, SEL_TT1_13, SEL_TT1_12, SEL_TT1_11, SEL_TT1_10, +// SEL_TT1_9, SEL_TT1_8, SEL_TT1_7, SEL_TT1_6, SEL_TT1_5, SEL_TT1_4, +// SEL_TT1_3, SEL_TT1_2, +// SEL_TT1_1)(WHAT, YTYPE, NAME, SIGNATURE, TYPES_A, __VA_ARGS__)) +// #define FOR_EACH_P1(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_A, ...) +// EXPAND(GET_MACRO_SEL_P1(__VA_ARGS__, SEL_P1_20, SEL_P1_19, SEL_P1_18, +// SEL_P1_17, SEL_P1_16, SEL_P1_15, SEL_P1_14, +// SEL_P1_13, SEL_P1_12, SEL_P1_11, SEL_P1_10, +// SEL_P1_9, SEL_P1_8, SEL_P1_7, SEL_P1_6, SEL_P1_5, +// SEL_P1_4, SEL_P1_3, SEL_P1_2, SEL_P1_1)( +// WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_A, __VA_ARGS__)) +// #define FOR_EACH_P2(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_A, ...) +// EXPAND2(GET_MACRO_SEL_P2(__VA_ARGS__, SEL_P2_20, SEL_P2_19, SEL_P2_18, +// SEL_P2_17, SEL_P2_16, SEL_P2_15, SEL_P2_14, +// SEL_P2_13, SEL_P2_12, SEL_P2_11, SEL_P2_10, +// SEL_P2_9, SEL_P2_8, SEL_P2_7, SEL_P2_6, SEL_P2_5, +// SEL_P2_4, SEL_P2_3, SEL_P2_2, SEL_P2_1)( +// WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_A, __VA_ARGS__)) +// #define FOR_EACH_S3(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// EXPAND(GET_MACRO_SEL_TT2( +// __VA_ARGS__, SEL_TT2_20, SEL_TT2_19, SEL_TT2_18, SEL_TT2_17, SEL_TT2_16, +// SEL_TT2_15, SEL_TT2_14, SEL_TT2_13, SEL_TT2_12, SEL_TT2_11, SEL_TT2_10, +// SEL_TT2_9, SEL_TT2_8, SEL_TT2_7, SEL_TT2_6, SEL_TT2_5, SEL_TT2_4, +// SEL_TT2_3, SEL_TT2_2, +// SEL_TT2_1)(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define FOR_EACH_DS(WHAT, NAME, SIGNATURE, ...) +// EXPAND(GET_MACRO_DS(__VA_ARGS__, DS_20, DS_19, DS_18, DS_17, DS_16, DS_15, +// DS_14, DS_13, DS_12, DS_11, DS_10, DS_9, DS_8, DS_7, +// DS_6, DS_5, DS_4, DS_3, DS_2, +// DS_1)(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define FOR_EACH_DT(WHAT, NAME, SIGNATURE, TYPES_A, ...) +// EXPAND(GET_MACRO_DT(__VA_ARGS__, DT_20, DT_19, DT_18, DT_17, DT_16, DT_15, +// DT_14, DT_13, DT_12, DT_11, DT_10, DT_9, DT_8, DT_7, +// DT_6, DT_5, DT_4, DT_3, DT_2, +// DT_1)(WHAT, NAME, SIGNATURE, TYPES_A, __VA_ARGS__)) +// #define FOR_EACH_DT2(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// EXPAND(GET_MACRO_DT2(__VA_ARGS__, DT2_20, DT2_19, DT2_18, DT2_17, DT2_16, +// DT2_15, DT2_14, DT2_13, DT2_12, DT2_11, DT2_10, DT2_9, +// DT2_8, DT2_7, DT2_6, DT2_5, DT2_4, DT2_3, DT2_2, +// DT2_1)(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define FOR_EACH_DP(WHAT, NAME, SIGNATURE, ...) +// EXPAND(GET_MACRO_DP( +// __VA_ARGS__, DP_125, DP_124, DP_123, DP_122, DP_121, DP_120, DP_119, +// DP_118, DP_117, DP_116, DP_115, DP_114, DP_113, DP_112, DP_111, DP_110, +// DP_109, DP_108, DP_107, DP_106, DP_105, DP_104, DP_103, DP_102, DP_101, +// DP_100, DP_99, DP_98, DP_97, DP_96, DP_95, DP_94, DP_93, DP_92, DP_91, +// DP_90, DP_89, DP_88, DP_87, DP_86, DP_85, DP_84, DP_83, DP_82, DP_81, +// DP_80, DP_79, DP_78, DP_77, DP_76, DP_75, DP_74, DP_73, DP_72, DP_71, +// DP_70, DP_69, DP_68, DP_67, DP_66, DP_65, DP_64, DP_63, DP_62, DP_61, +// DP_60, DP_59, DP_58, DP_57, DP_56, DP_55, DP_54, DP_53, DP_52, DP_51, +// DP_50, DP_49, DP_48, DP_47, DP_46, DP_45, DP_44, DP_43, DP_42, DP_41, +// DP_40, DP_39, DP_38, DP_37, DP_36, DP_35, DP_34, DP_33, DP_32, DP_31, +// DP_30, DP_29, DP_28, DP_27, DP_26, DP_25, DP_24, DP_23, DP_22, DP_21, +// DP_20, DP_19, DP_18, DP_17, DP_16, DP_15, DP_14, DP_13, DP_12, DP_11, +// DP_10, DP_9, DP_8, DP_7, DP_6, DP_5, DP_4, DP_3, DP_2, +// DP_1)(WHAT, NAME, SIGNATURE, __VA_ARGS__)) + +// #define FOR_EACH_TT1(WHAT, NAME, SIGNATURE, TYPES_X, TYPES_Y, ...) +// EXPAND(GET_MACRO_TT1(__VA_ARGS__, TT1_20, TT1_19, TT1_18, TT1_17, TT1_16, +// TT1_15, TT1_14, TT1_13, TT1_12, TT1_11, TT1_10, TT1_9, +// TT1_8, TT1_7, TT1_6, TT1_5, TT1_4, TT1_3, TT1_2, +// TT1_1)(WHAT, NAME, SIGNATURE, TYPES_X, TYPES_Y, +// __VA_ARGS__)) +// #define FOR_EACH_TT2(WHAT, NAME, SIGNATURE, TYPE_Z, TYPES_X, ...) +// EXPAND(GET_MACRO_TT2(__VA_ARGS__, TT2_20, TT2_19, TT2_18, TT2_17, TT2_16, +// TT2_15, TT2_14, TT2_13, TT2_12, TT2_11, TT2_10, TT2_9, +// TT2_8, TT2_7, TT2_6, TT2_5, TT2_4, TT2_3, TT2_2, +// TT2_1)(WHAT, NAME, SIGNATURE, TYPE_Z, TYPES_X, +// __VA_ARGS__)) +// #define FOR_EACH_TT3(WHAT, NAME, SIGNATURE, TYPE_Z, TYPE_Y, ...) +// EXPAND(GET_MACRO_TT3(__VA_ARGS__, TT3_20, TT3_19, TT3_18, TT3_17, TT3_16, +// TT3_15, TT3_14, TT3_13, TT3_12, TT3_11, TT3_10, TT3_9, +// TT3_8, TT3_7, TT3_6, TT3_5, TT3_4, TT3_3, TT3_2, +// TT3_1)(WHAT, NAME, SIGNATURE, TYPE_Z, TYPE_Y, +// __VA_ARGS__)) + +// #define FOR_EACH_TTT1(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_Z, TYPES_Y, +// ...) +// EXPAND(GET_MACRO_TTT1(__VA_ARGS__, TTT1_20, TTT1_19, TTT1_18, TTT1_17, +// TTT1_16, TTT1_15, TTT1_14, TTT1_13, TTT1_12, TTT1_11, +// TTT1_10, TTT1_9, TTT1_8, TTT1_7, TTT1_6, TTT1_5, +// TTT1_4, TTT1_3, TTT1_2, TTT1_1)( +// WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_Z, TYPES_Y, __VA_ARGS__)) +// #define FOR_EACH_TTT2(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_X, TYPES_Z, ...) +// EXPAND(GET_MACRO_TTT2(__VA_ARGS__, TTT2_20, TTT2_19, TTT2_18, TTT2_17, +// TTT2_16, TTT2_15, TTT2_14, TTT2_13, TTT2_12, TTT2_11, +// TTT2_10, TTT2_9, TTT2_8, TTT2_7, TTT2_6, TTT2_5, +// TTT2_4, TTT2_3, TTT2_2, TTT2_1)( +// WHAT, ZTYPE, NAME, SIGNATURE, TYPE_X, TYPES_Z, __VA_ARGS__)) +// #define FOR_EACH_TTT3(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_X, TYPE_Y, ...) +// EXPAND(GET_MACRO_TTT3(__VA_ARGS__, TTT3_20, TTT3_19, TTT3_18, TTT3_17, +// TTT3_16, TTT3_15, TTT3_14, TTT3_13, TTT3_12, TTT3_11, +// TTT3_10, TTT3_9, TTT3_8, TTT3_7, TTT3_6, TTT3_5, +// TTT3_4, TTT3_3, TTT3_2, TTT3_1)( +// WHAT, ZTYPE, NAME, SIGNATURE, TYPE_X, TYPE_Y, __VA_ARGS__)) + +// #define _EXEC_SELECTOR_T(WHAT, NAME, SIGNATURE, ...) +// EVAL(FOR_EACH_S1(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define _EXEC_SELECTOR_P_1(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, +// TYPES_A, ...) +// EVAL(FOR_EACH_P1(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_A, +// __VA_ARGS__)) +// #define _EXEC_SELECTOR_P_2(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, +// ...) +// EVAL(FOR_EACH_P2(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, +// __VA_ARGS__)) +// #define _EXEC_SELECTOR_TT_1(WHAT, YTYPE, NAME, SIGNATURE, TYPES_A, ...) +// EVAL(FOR_EACH_S2(WHAT, YTYPE, NAME, SIGNATURE, TYPES_A, __VA_ARGS__)) +// #define _EXEC_SELECTOR_TT_2(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// EVAL(FOR_EACH_S3(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define _EXEC_SINGLE_T(WHAT, NAME, SIGNATURE, ...) +// EVAL(FOR_EACH_DS(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define _EXEC_DOUBLE_T(WHAT, NAME, SIGNATURE, TYPES_A, ...) +// EVAL(FOR_EACH_DT(WHAT, NAME, SIGNATURE, LIST(TYPES_A), __VA_ARGS__)) +// #define _EXEC_DOUBLE_T2(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// EVAL(FOR_EACH_DT2(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define _EXEC_DOUBLE_P(WHAT, NAME, SIGNATURE, ...) +// EVAL(FOR_EACH_DP(WHAT, NAME, SIGNATURE, __VA_ARGS__)) + +// #define _EXEC_SELECTOR_TTT_1(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_Z, +// TYPES_Y, ...) +// EVAL(FOR_EACH_TTT1(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_Z, TYPES_Y, +// __VA_ARGS__)) +// #define _EXEC_SELECTOR_TTT_2(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_X, TYPES_Z, +// ...) +// EVAL(FOR_EACH_TTT2(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_X, TYPES_Z, +// __VA_ARGS__)) +// #define _EXEC_SELECTOR_TTT_3(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_X, TYPE_Y, +// ...) +// EVAL(FOR_EACH_TTT3(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_X, TYPE_Y, __VA_ARGS__)) + +// #define _EXEC_TRIPLE_T3(WHAT, NAME, SIGNATURE, TYPE_Z, TYPE_Y, ...) +// EVAL(FOR_EACH_TT3(WHAT, NAME, SIGNATURE, TYPE_Z, TYPE_Y, __VA_ARGS__)) +// #define _EXEC_TRIPLE_T2(WHAT, NAME, SIGNATURE, TYPE_Z, TYPES_X, ...) +// EVAL(FOR_EACH_TT2(WHAT, NAME, SIGNATURE, TYPE_Z, LIST(TYPES_X), __VA_ARGS__)) +// #define _EXEC_TRIPLE_T1(WHAT, NAME, SIGNATURE, TYPES_X, TYPES_Y, ...) +// EVAL(FOR_EACH_TT1(WHAT, NAME, SIGNATURE, LIST(TYPES_X), LIST(TYPES_Y), +// __VA_ARGS__)) + +// #define DISPATCH_PAIRWISE(NAME, SIGNATURE, TYPE, TYPES_B) +// EVAL(_EXEC_DOUBLE_T2(RANDOMPAIRWISE2, NAME, SIGNATURE, TYPE, TYPES_B)) +// #define DISPATCH_PAIRWISE2(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE, ...) +// EVAL(_EXEC_SELECTOR_P_2(SELECTOR_PAIRWISE_2, XTYPE, YTYPE, ZTYPE, NAME, +// SIGNATURE, TYPE, __VA_ARGS__)) + +// #define DISPATCH_DTYPES(NAME, SIGNATURE, TYPE, TYPES_B) +// EVAL(_EXEC_DOUBLE_T2(RANDOMDOUBLE2, NAME, SIGNATURE, TYPE, TYPES_B)) +// #define DISPATCH_DTYPES2(NAME, SIGNATURE, TYPE, ...) +// EVAL(_EXEC_SELECTOR_TT_2(SELECTOR_DOUBLE_2, NAME, SIGNATURE, TYPE, +// __VA_ARGS__)) + +// #define DISPATCH_TTYPES2(ZTYPE, NAME, SIGNATURE, TYPE_X, TYPES_Z, ...) +// EVAL(_EXEC_SELECTOR_TTT_2(SELECTOR_TRIPLE_2, ZTYPE, NAME, SIGNATURE, TYPE_X, +// TYPES_Z, __VA_ARGS__)) +// #define DISPATCH_TTYPES3(ZTYPE, NAME, SIGNATURE, TYPE_X, TYPE_Y, ...) +// EVAL(_EXEC_SELECTOR_TTT_3(SELECTOR_TRIPLE_3, ZTYPE, NAME, SIGNATURE, TYPE_X, +// TYPE_Y, __VA_ARGS__)) // #ifndef __CLION_IDE__ -// #define BUILD_SINGLE_UNCHAINED_TEMPLATE(NAME, SIGNATURE, TYPES) EVAL(_EXEC_SINGLE_T(RANDOMSINGLEU, NAME, (SIGNATURE), TYPES)) -// #define BUILD_SINGLE_TEMPLATE(NAME, SIGNATURE, TYPES) EVAL(_EXEC_SINGLE_T(RANDOMSINGLE, NAME, (SIGNATURE), TYPES)) -// #define BUILD_SINGLE_TEMPLATE_TWICE(NAME, SIGNATURE, TYPES) EVAL(_EXEC_SELECTOR_T(TEMPLATE_SINGLE_TWICE, NAME, SIGNATURE, TYPES)) -// #define BUILD_DOUBLE_TEMPLATE(NAME, SIGNATURE, TYPES_A, TYPES_B) EVAL(_EXEC_DOUBLE_T(RANDOMDOUBLE, NAME, (SIGNATURE), (TYPES_A), TYPES_B)) -// #define BUILD_SINGLE_SELECTOR(XTYPE, NAME, SIGNATURE, TYPES) switch(XTYPE) { EVAL(_EXEC_SELECTOR_T(SELECTOR_SINGLE, NAME, SIGNATURE, TYPES)); default: {printf("[ERROR] Unknown dtypeX=%d on %s:%d", XTYPE, __FILE__, __LINE__); fflush(stdout); throw std::runtime_error("bad data type");}} -// #define BUILD_SINGLE_SELECTOR_TWICE(XTYPE, NAME, SIGNATURE, TYPES) switch(XTYPE) { EVAL(_EXEC_SELECTOR_T(SELECTOR_SINGLE_TWICE, NAME, SIGNATURE, TYPES)); default: {printf("[ERROR] Unknown dtypeX=%d on %s:%d", XTYPE, __FILE__, __LINE__); fflush(stdout); throw std::runtime_error("bad data type");}} -// #define BUILD_SINGLE_SELECTOR_THRICE(XTYPE, NAME, SIGNATURE, TYPES) switch(XTYPE) { EVAL(_EXEC_SELECTOR_T(SELECTOR_SINGLE_THRICE, NAME, SIGNATURE, TYPES)); default: {printf("[ERROR] Unknown dtypeX=%d on %s:%d", XTYPE, __FILE__, __LINE__); fflush(stdout); throw std::runtime_error("bad data type");}} - - -// #define BUILD_SINGLE_PARTIAL_SELECTOR(XTYPE, NAME, SIGNATURE, TYPES) switch(XTYPE) { EVAL(_EXEC_SELECTOR_T(SELECTOR_PARTIAL_SINGLE, NAME, SIGNATURE, TYPES)); default: {printf("[ERROR] Unknown dtypeX=%d on %s:%d", XTYPE, __FILE__, __LINE__); fflush(stdout); throw std::runtime_error("bad data type"); }} -// #define BUILD_DOUBLE_SELECTOR(XTYPE, YTYPE, NAME, SIGNATURE, TYPES_A, TYPES_B) switch(XTYPE) { EVAL(_EXEC_SELECTOR_TT_1(SELECTOR_DOUBLE, YTYPE, NAME, (SIGNATURE), (TYPES_B), TYPES_A)); default: {printf("[ERROR] Unknown dtypeX=%d on %s:%d", XTYPE, __FILE__, __LINE__); fflush(stdout); throw std::runtime_error("bad data type");}} -// #define BUILD_TRIPLE_SELECTOR(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_X, TYPES_Y, TYPES_Z) switch(XTYPE) { EVAL(_EXEC_SELECTOR_TTT_1(SELECTOR_TRIPLE, YTYPE, ZTYPE, NAME, SIGNATURE, (TYPES_Z), (TYPES_Y), TYPES_X)); default: {printf("[ERROR] Unknown dtypeX=%d on %s:%d", XTYPE, __FILE__, __LINE__); fflush(stdout); throw std::runtime_error("bad data type"); } } -// #define BUILD_TRIPLE_TEMPLATE(NAME, SIGNATURE, TYPES_X, TYPES_Y, TYPES_Z) EVAL(_EXEC_TRIPLE_T1(RANDOMTRIPLE, NAME, (SIGNATURE), (TYPES_X), (TYPES_Y), TYPES_Z)) -// #define BUILD_PAIRWISE_TEMPLATE(NAME, SIGNATURE, TYPES_A) EVAL(_EXEC_DOUBLE_P(RANDOMPAIRWISE, NAME, (SIGNATURE), TYPES_A)) -// #define BUILD_PAIRWISE_SELECTOR(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_A, TYPES_B) switch(XTYPE) { EVAL(_EXEC_SELECTOR_P_1(SELECTOR_PAIRWISE, XTYPE, YTYPE, ZTYPE, NAME, (SIGNATURE), (TYPES_B), TYPES_A)); default: {printf("[ERROR] Unknown dtypeX=%d on %s:%d", XTYPE, __FILE__, __LINE__); fflush(stdout); throw std::runtime_error("bad data type"); }} +// #define BUILD_SINGLE_UNCHAINED_TEMPLATE(NAME, SIGNATURE, TYPES) +// EVAL(_EXEC_SINGLE_T(RANDOMSINGLEU, NAME, (SIGNATURE), TYPES)) +// #define BUILD_SINGLE_TEMPLATE(NAME, SIGNATURE, TYPES) +// EVAL(_EXEC_SINGLE_T(RANDOMSINGLE, NAME, (SIGNATURE), TYPES)) +// #define BUILD_SINGLE_TEMPLATE_TWICE(NAME, SIGNATURE, TYPES) +// EVAL(_EXEC_SELECTOR_T(TEMPLATE_SINGLE_TWICE, NAME, SIGNATURE, TYPES)) +// #define BUILD_DOUBLE_TEMPLATE(NAME, SIGNATURE, TYPES_A, TYPES_B) +// EVAL(_EXEC_DOUBLE_T(RANDOMDOUBLE, NAME, (SIGNATURE), (TYPES_A), TYPES_B)) +// #define BUILD_SINGLE_SELECTOR(XTYPE, NAME, SIGNATURE, TYPES) +// switch (XTYPE) { +// EVAL(_EXEC_SELECTOR_T(SELECTOR_SINGLE, NAME, SIGNATURE, TYPES)); +// default: { +// printf("[ERROR] Unknown dtypeX=%d on %s:%d", XTYPE, __FILE__, __LINE__); +// fflush(stdout); +// throw std::runtime_error("bad data type"); +// } +// } +// #define BUILD_SINGLE_SELECTOR_TWICE(XTYPE, NAME, SIGNATURE, TYPES) +// switch (XTYPE) { +// EVAL(_EXEC_SELECTOR_T(SELECTOR_SINGLE_TWICE, NAME, SIGNATURE, TYPES)); +// default: { +// printf("[ERROR] Unknown dtypeX=%d on %s:%d", XTYPE, __FILE__, __LINE__); +// fflush(stdout); +// throw std::runtime_error("bad data type"); +// } +// } +// #define BUILD_SINGLE_SELECTOR_THRICE(XTYPE, NAME, SIGNATURE, TYPES) +// switch (XTYPE) { +// EVAL(_EXEC_SELECTOR_T(SELECTOR_SINGLE_THRICE, NAME, SIGNATURE, TYPES)); +// default: { +// printf("[ERROR] Unknown dtypeX=%d on %s:%d", XTYPE, __FILE__, __LINE__); +// fflush(stdout); +// throw std::runtime_error("bad data type"); +// } +// } + +// #define BUILD_SINGLE_PARTIAL_SELECTOR(XTYPE, NAME, SIGNATURE, TYPES) +// switch (XTYPE) { +// EVAL(_EXEC_SELECTOR_T(SELECTOR_PARTIAL_SINGLE, NAME, SIGNATURE, TYPES)); +// default: { +// printf("[ERROR] Unknown dtypeX=%d on %s:%d", XTYPE, __FILE__, __LINE__); +// fflush(stdout); +// throw std::runtime_error("bad data type"); +// } +// } +// #define BUILD_DOUBLE_SELECTOR(XTYPE, YTYPE, NAME, SIGNATURE, TYPES_A, TYPES_B) +// switch (XTYPE) { +// EVAL(_EXEC_SELECTOR_TT_1(SELECTOR_DOUBLE, YTYPE, NAME, (SIGNATURE), +// (TYPES_B), TYPES_A)); +// default: { +// printf("[ERROR] Unknown dtypeX=%d on %s:%d", XTYPE, __FILE__, __LINE__); +// fflush(stdout); +// throw std::runtime_error("bad data type"); +// } +// } +// #define BUILD_TRIPLE_SELECTOR(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_X, +// TYPES_Y, TYPES_Z) +// switch (XTYPE) { +// EVAL(_EXEC_SELECTOR_TTT_1(SELECTOR_TRIPLE, YTYPE, ZTYPE, NAME, SIGNATURE, +// (TYPES_Z), (TYPES_Y), TYPES_X)); +// default: { +// printf("[ERROR] Unknown dtypeX=%d on %s:%d", XTYPE, __FILE__, __LINE__); +// fflush(stdout); +// throw std::runtime_error("bad data type"); +// } +// } +// #define BUILD_TRIPLE_TEMPLATE(NAME, SIGNATURE, TYPES_X, TYPES_Y, TYPES_Z) +// EVAL(_EXEC_TRIPLE_T1(RANDOMTRIPLE, NAME, (SIGNATURE), (TYPES_X), (TYPES_Y), +// TYPES_Z)) +// #define BUILD_PAIRWISE_TEMPLATE(NAME, SIGNATURE, TYPES_A) +// EVAL(_EXEC_DOUBLE_P(RANDOMPAIRWISE, NAME, (SIGNATURE), TYPES_A)) +// #define BUILD_PAIRWISE_SELECTOR(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_A, +// TYPES_B) +// switch (XTYPE) { +// EVAL(_EXEC_SELECTOR_P_1(SELECTOR_PAIRWISE, XTYPE, YTYPE, ZTYPE, NAME, +// (SIGNATURE), (TYPES_B), TYPES_A)); +// default: { +// printf("[ERROR] Unknown dtypeX=%d on %s:%d", XTYPE, __FILE__, __LINE__); +// fflush(stdout); +// throw std::runtime_error("bad data type"); +// } +// } // #else // #define BUILD_SINGLE_UNCHAINED_TEMPLATE(NAME, SIGNATURE, TYPES) // #define BUILD_SINGLE_TEMPLATE(NAME, SIGNATURE, TYPES) @@ -10078,77 +12064,208 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #define BUILD_SINGLE_SELECTOR_THRICE(XTYPE, NAME, SIGNATURE, TYPES) // #define BUILD_SINGLE_PARTIAL_SELECTOR(XTYPE, NAME, SIGNATURE, TYPES) // #define BUILD_DOUBLE_SELECTOR(XTYPE, YTYPE, NAME, SIGNATURE, TYPES_A, TYPES_B) -// #define BUILD_TRIPLE_SELECTOR(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_X, TYPES_Y, TYPES_Z) +// #define BUILD_TRIPLE_SELECTOR(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_X, +// TYPES_Y, TYPES_Z) // #define BUILD_TRIPLE_TEMPLATE(NAME, SIGNATURE, TYPES_X, TYPES_Y, TYPES_Z) // #define BUILD_PAIRWISE_TEMPLATE(NAME, SIGNATURE, TYPES_A) -// #define BUILD_PAIRWISE_SELECTOR(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_A, TYPES_B) +// #define BUILD_PAIRWISE_SELECTOR(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_A, +// TYPES_B) // #endif // #define LIST(...) __VA_ARGS__ -// #define _SELECTOR_DOUBLE_2(NAME, SIGNATURE, TYPE_A, ENUM, TYPE_B) case ENUM: { NAME SIGNATURE; break; }; -// #define SELECTOR_DOUBLE_2(NAME, SIGNATURE, TYPE_A, TYPE_B) EVALUATING_PASTE2(_SELECT, OR_DOUBLE_2(NAME, UNPAREN3(SIGNATURE), TYPE_A, UNPAREN3(TYPE_B))) - -// #define _SELECTOR_DOUBLE(YTYPE, NAME, SIGNATURE, ENUM, TYPE_A, ...) case ENUM: { switch(YTYPE) { EXPAND(DISPATCH_DTYPES2(NAME, SIGNATURE, TYPE_A, __VA_ARGS__)); default: {printf("[ERROR] Unknown dtypeX=%d on %s:%d\n", YTYPE, __FILE__, __LINE__); fflush(stdout);}}; break; }; -// #define SELECTOR_DOUBLE(YTYPE, NAME, SIGNATURE, TYPES_B, TYPE_A) EVALUATING_PASTE(_SELECTOR, _DOUBLE(YTYPE, NAME, SIGNATURE, UNPAREN(TYPE_A), UNPAREN(TYPES_B))) - -// #define _SELECTOR_PAIRWISE_2(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, ENUM, TYPE_B) case ENUM: { if (ZTYPE == YTYPE) {NAME SIGNATURE;} else if (XTYPE == ZTYPE ){NAME SIGNATURE;} else {printf("[ERROR] Unknown dtypeX=%d on %s:%d\n", YTYPE, __FILE__, __LINE__); fflush(stdout); throw std::runtime_error("Unknown Z operand");}; break; }; -// #define SELECTOR_PAIRWISE_2(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) EVALUATING_PASTE2(_SELECT, OR_PAIRWISE_2(XTYPE, YTYPE, ZTYPE, NAME, UNPAREN3(SIGNATURE), TYPE_A, UNPAREN3(TYPE_B))) -// #define _SELECTOR_PAIRWISE(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, ENUM, TYPE_A, ...) case ENUM: { switch(YTYPE) { EXPAND(DISPATCH_PAIRWISE2(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)); default: {printf("[ERROR] Unknown dtypeX=%d on %s:%d\n", YTYPE, __FILE__, __LINE__); fflush(stdout);}}; break; }; -// #define SELECTOR_PAIRWISE(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_B, TYPE_A) EVALUATING_PASTE(_SELECTOR, _PAIRWISE(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, UNPAREN(TYPE_A), UNPAREN(TYPES_B))) - -// #define _SELECTOR_TRIPLE_3(NAME, SIGNATURE, TYPE_X, TYPE_Y, ENUM_Z, TYPE_Z) case ENUM_Z: { NAMESIGNATURE;}; break; -// #define SELECTOR_TRIPLE_3(ZTYPE, NAME, SIGNATURE, TYPE_X, TYPE_Y, TYPE_Z) EVALUATING_PASTE3(_SELECTOR, _TRIPLE_3(NAME, SIGNATURE, TYPE_X, TYPE_Y, UNPAREN3(TYPE_Z))) -// #define _SELECTOR_TRIPLE_2(ZTYPE, NAME, SIGNATURE, TYPE_X, ENUM_Y, TYPE_Y, TYPES_Z) case ENUM_Y: { switch (ZTYPE) { EXPAND2(DISPATCH_TTYPES3(ZTYPE, NAME, SIGNATURE, TYPE_X, TYPE_Y, UNPAREN3(TYPES_Z))); default: {printf("[ERROR] Unknown dtypeZ=%d on %s:%d\n", ZTYPE, __FILE__, __LINE__); ; fflush(stdout);} } break; }; -// #define SELECTOR_TRIPLE_2(ZTYPE, NAME, SIGNATURE, TYPE_X, TYPES_Z, TYPE_Y) EVALUATING_PASTE2(_SELECTOR, _TRIPLE_2(ZTYPE, NAME, SIGNATURE, TYPE_X, UNPAREN2(TYPE_Y), TYPES_Z)) -// #define _SELECTOR_TRIPLE(YTYPE, ZTYPE, NAME, SIGNATURE, ENUM_X, TYPE_X, TYPES_Z, ...) case ENUM_X: { switch (YTYPE) { EXPAND(DISPATCH_TTYPES2(ZTYPE, NAME, SIGNATURE, TYPE_X, TYPES_Z, __VA_ARGS__ )); default: {printf("[ERROR] Unknown dtypeY=%d on %s:%d\n", YTYPE, __FILE__, __LINE__); ; fflush(stdout);} } break; }; -// #define SELECTOR_TRIPLE(YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_Z, TYPES_Y, TYPE_X) EVALUATING_PASTE(_SELECTOR, _TRIPLE(YTYPE, ZTYPE, NAME, SIGNATURE, UNPAREN(TYPE_X), TYPES_Z, UNPAREN(TYPES_Y))) - -// #define _SELECTOR_SINGLE(A, B, C, D) case C: {AB; break;}; -// #define SELECTOR_SINGLE(A, B, C) EVALUATING_PASTE(_SEL, ECTOR_SINGLE(A, B, UNPAREN(C))) - -// #define _SELECTOR_SINGLE_THRICE(A, B, C, D) case C: {AB; break;}; -// #define SELECTOR_SINGLE_THRICE(A, B, C) EVALUATING_PASTE(_SEL, ECTOR_SINGLE_THRICE(A, B, UNPAREN(C))) - -// #define _SELECTOR_SINGLE_TWICE(A, B, C, D) case C: {AB; break;}; -// #define SELECTOR_SINGLE_TWICE(A, B, C) EVALUATING_PASTE(_SEL, ECTOR_SINGLE_TWICE(A, B, UNPAREN(C))) - -// #define _TEMPLATE_SINGLE_TWICE(A, B, C, D) AB; -// #define TEMPLATE_SINGLE_TWICE(A, B, C) EVALUATING_PASTE(_TEM, PLATE_SINGLE_TWICE(A, B, UNPAREN(C))) - -// #define _SELECTOR_PARTIAL_SINGLE(A, B, C, D) case C: {A D, UNPAREN2(B); break;}; -// #define SELECTOR_PARTIAL_SINGLE(A, B, C) EVALUATING_PASTE(_SEL, ECTOR_PARTIAL_SINGLE(A, B, UNPAREN(C))) - -// #define _RANDOMSINGLE(A, B, C, D) AB; +// #define _SELECTOR_DOUBLE_2(NAME, SIGNATURE, TYPE_A, ENUM, TYPE_B) +// case ENUM: { +// NAME SIGNATURE; +// break; +// }; +// #define SELECTOR_DOUBLE_2(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVALUATING_PASTE2(_SELECT, OR_DOUBLE_2(NAME, UNPAREN3(SIGNATURE), TYPE_A, +// UNPAREN3(TYPE_B))) + +// #define _SELECTOR_DOUBLE(YTYPE, NAME, SIGNATURE, ENUM, TYPE_A, ...) +// case ENUM: { +// switch (YTYPE) { +// EXPAND(DISPATCH_DTYPES2(NAME, SIGNATURE, TYPE_A, __VA_ARGS__)); +// default: { +// printf("[ERROR] Unknown dtypeX=%d on %s:%d\n", YTYPE, __FILE__, +// __LINE__); +// fflush(stdout); +// } +// }; +// break; +// }; +// #define SELECTOR_DOUBLE(YTYPE, NAME, SIGNATURE, TYPES_B, TYPE_A) +// EVALUATING_PASTE(_SELECTOR, _DOUBLE(YTYPE, NAME, SIGNATURE, UNPAREN(TYPE_A), +// UNPAREN(TYPES_B))) + +// #define _SELECTOR_PAIRWISE_2(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, +// ENUM, TYPE_B) +// case ENUM: { +// if (ZTYPE == YTYPE) { +// NAME SIGNATURE; +// } else if (XTYPE == ZTYPE) { +// NAME SIGNATURE; +// } else { +// printf("[ERROR] Unknown dtypeX=%d on %s:%d\n", YTYPE, __FILE__, +// __LINE__); +// fflush(stdout); +// throw std::runtime_error("Unknown Z operand"); +// }; +// break; +// }; +// #define SELECTOR_PAIRWISE_2(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, +// TYPE_B) +// EVALUATING_PASTE2( +// _SELECT, OR_PAIRWISE_2(XTYPE, YTYPE, ZTYPE, NAME, UNPAREN3(SIGNATURE), +// TYPE_A, UNPAREN3(TYPE_B))) +// #define _SELECTOR_PAIRWISE(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, ENUM, TYPE_A, +// ...) +// case ENUM: { +// switch (YTYPE) { +// EXPAND(DISPATCH_PAIRWISE2(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, +// __VA_ARGS__)); +// default: { +// printf("[ERROR] Unknown dtypeX=%d on %s:%d\n", YTYPE, __FILE__, +// __LINE__); +// fflush(stdout); +// } +// }; +// break; +// }; +// #define SELECTOR_PAIRWISE(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_B, +// TYPE_A) +// EVALUATING_PASTE(_SELECTOR, _PAIRWISE(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, +// UNPAREN(TYPE_A), UNPAREN(TYPES_B))) + +// #define _SELECTOR_TRIPLE_3(NAME, SIGNATURE, TYPE_X, TYPE_Y, ENUM_Z, TYPE_Z) +// case ENUM_Z: { +// NAME SIGNATURE; +// }; break; +// #define SELECTOR_TRIPLE_3(ZTYPE, NAME, SIGNATURE, TYPE_X, TYPE_Y, TYPE_Z) +// EVALUATING_PASTE3( +// _SELECTOR, _TRIPLE_3(NAME, SIGNATURE, TYPE_X, TYPE_Y, UNPAREN3(TYPE_Z))) +// #define _SELECTOR_TRIPLE_2(ZTYPE, NAME, SIGNATURE, TYPE_X, ENUM_Y, TYPE_Y, +// TYPES_Z) +// case ENUM_Y: { +// switch (ZTYPE) { +// EXPAND2(DISPATCH_TTYPES3(ZTYPE, NAME, SIGNATURE, TYPE_X, TYPE_Y, +// UNPAREN3(TYPES_Z))); +// default: { +// printf("[ERROR] Unknown dtypeZ=%d on %s:%d\n", ZTYPE, __FILE__, +// __LINE__); +// ; +// fflush(stdout); +// } +// } +// break; +// }; +// #define SELECTOR_TRIPLE_2(ZTYPE, NAME, SIGNATURE, TYPE_X, TYPES_Z, TYPE_Y) +// EVALUATING_PASTE2(_SELECTOR, _TRIPLE_2(ZTYPE, NAME, SIGNATURE, TYPE_X, +// UNPAREN2(TYPE_Y), TYPES_Z)) +// #define _SELECTOR_TRIPLE(YTYPE, ZTYPE, NAME, SIGNATURE, ENUM_X, TYPE_X, +// TYPES_Z, ...) +// case ENUM_X: { +// switch (YTYPE) { +// EXPAND(DISPATCH_TTYPES2(ZTYPE, NAME, SIGNATURE, TYPE_X, TYPES_Z, +// __VA_ARGS__)); +// default: { +// printf("[ERROR] Unknown dtypeY=%d on %s:%d\n", YTYPE, __FILE__, +// __LINE__); +// ; +// fflush(stdout); +// } +// } +// break; +// }; +// #define SELECTOR_TRIPLE(YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_Z, TYPES_Y, +// TYPE_X) +// EVALUATING_PASTE(_SELECTOR, +// _TRIPLE(YTYPE, ZTYPE, NAME, SIGNATURE, UNPAREN(TYPE_X), +// TYPES_Z, UNPAREN(TYPES_Y))) + +// #define _SELECTOR_SINGLE(A, B, C, D) +// case C: { +// A B; +// break; +// }; +// #define SELECTOR_SINGLE(A, B, C) +// EVALUATING_PASTE(_SEL, ECTOR_SINGLE(A, B, UNPAREN(C))) + +// #define _SELECTOR_SINGLE_THRICE(A, B, C, D) +// case C: { +// A B; +// break; +// }; +// #define SELECTOR_SINGLE_THRICE(A, B, C) +// EVALUATING_PASTE(_SEL, ECTOR_SINGLE_THRICE(A, B, UNPAREN(C))) + +// #define _SELECTOR_SINGLE_TWICE(A, B, C, D) +// case C: { +// A B; +// break; +// }; +// #define SELECTOR_SINGLE_TWICE(A, B, C) +// EVALUATING_PASTE(_SEL, ECTOR_SINGLE_TWICE(A, B, UNPAREN(C))) + +// #define _TEMPLATE_SINGLE_TWICE(A, B, C, D) A B; +// #define TEMPLATE_SINGLE_TWICE(A, B, C) +// EVALUATING_PASTE(_TEM, PLATE_SINGLE_TWICE(A, B, UNPAREN(C))) + +// #define _SELECTOR_PARTIAL_SINGLE(A, B, C, D) +// case C: { +// A D, UNPAREN2(B); +// break; +// }; +// #define SELECTOR_PARTIAL_SINGLE(A, B, C) +// EVALUATING_PASTE(_SEL, ECTOR_PARTIAL_SINGLE(A, B, UNPAREN(C))) + +// #define _RANDOMSINGLE(A, B, C, D) A B; // #define _RANDOMSINGLEU(A, B, C, D) A D B; -// #define RANDOMSINGLE(A, B, C) EVALUATING_PASTE(_RAND, OMSINGLE(A, UNPAREN(B), UNPAREN(C))) -// #define RANDOMSINGLEU(A, B, C) EVALUATING_PASTE(_RAND, OMSINGLEU(A, UNPAREN(B), UNPAREN(C))) -// #define RANDOMDOUBLE(A, B, C, D) EXPAND(DISPATCH_DTYPES(A, UNPAREN(B), D, UNPAREN(C))) - -// #define _RANDOMDOUBLE2(A, B, C, D, E, F) AB; -// #define RANDOMDOUBLE2(A, B, C, D) EVALUATING_PASTE(_RAND, OMDOUBLE2(A, B, UNPAREN(C), UNPAREN(D))) - -// #define _RANDOMPAIRWISE2(A, B, C, D, E) AE; -// #define RANDOMPAIRWISE(A, B, C) EVALUATING_PASTE(_RANDOM, PAIRWISE2(A, UNPAREN(C), UNPAREN(B))) - -// #define _RANDOMTRIPLE3(A, B, ZN, ZT, YN, YT, XN, XT) AB; -// #define RANDOMTRIPLE3(A, B, Z, Y, X) EVALUATING_PASTE(_RANDOM, TRIPLE3(A, UNPAREN(B), UNPAREN(Z), UNPAREN(Y), UNPAREN(X))) - -// #define _RANDOMTRIPLE2(NAME, SIGNATURE, TYPE_Z, TYPE_Y, TYPES_X) EVALX(_EXEC_TRIPLE_T3(RANDOMTRIPLE3, NAME, SIGNATURE, TYPE_Z, TYPE_Y, UNPAREN(TYPES_X))) -// #define RANDOMTRIPLE2(NAME, SIGNATURE, TYPE_Z, TYPES_X, TYPE_Y) _RANDOMTRIPLE2(NAME, SIGNATURE, TYPE_Z, TYPE_Y, TYPES_X) -// #define _RANDOMTRIPLE(NAME, SIGNATURE, TYPE_Z, TYPES_X, TYPES_Y) EVAL(_EXEC_TRIPLE_T2(RANDOMTRIPLE2, NAME, SIGNATURE, TYPE_Z, TYPES_X, UNPAREN(TYPES_Y))) -// #define RANDOMTRIPLE(NAME, SIGNATURE, TYPES_X, TYPES_Y, TYPE_Z) _RANDOMTRIPLE(NAME, SIGNATURE, TYPE_Z, TYPES_X, TYPES_Y) - - -// #define BROADCAST(NAME) sd::BroadcastOpsTuple::custom(sd::scalar::NAME, sd::pairwise::NAME, sd::broadcast::NAME) -// #define BROADCAST_BOOL(NAME) sd::BroadcastBoolOpsTuple::custom(sd::scalar::NAME, sd::pairwise::NAME, sd::broadcast::NAME) +// #define RANDOMSINGLE(A, B, C) +// EVALUATING_PASTE(_RAND, OMSINGLE(A, UNPAREN(B), UNPAREN(C))) +// #define RANDOMSINGLEU(A, B, C) +// EVALUATING_PASTE(_RAND, OMSINGLEU(A, UNPAREN(B), UNPAREN(C))) +// #define RANDOMDOUBLE(A, B, C, D) +// EXPAND(DISPATCH_DTYPES(A, UNPAREN(B), D, UNPAREN(C))) + +// #define _RANDOMDOUBLE2(A, B, C, D, E, F) A B; +// #define RANDOMDOUBLE2(A, B, C, D) +// EVALUATING_PASTE(_RAND, OMDOUBLE2(A, B, UNPAREN(C), UNPAREN(D))) + +// #define _RANDOMPAIRWISE2(A, B, C, D, E) A E; +// #define RANDOMPAIRWISE(A, B, C) +// EVALUATING_PASTE(_RANDOM, PAIRWISE2(A, UNPAREN(C), UNPAREN(B))) + +// #define _RANDOMTRIPLE3(A, B, ZN, ZT, YN, YT, XN, XT) A B; +// #define RANDOMTRIPLE3(A, B, Z, Y, X) +// EVALUATING_PASTE(_RANDOM, +// TRIPLE3(A, UNPAREN(B), UNPAREN(Z), UNPAREN(Y), UNPAREN(X))) + +// #define _RANDOMTRIPLE2(NAME, SIGNATURE, TYPE_Z, TYPE_Y, TYPES_X) +// EVALX(_EXEC_TRIPLE_T3(RANDOMTRIPLE3, NAME, SIGNATURE, TYPE_Z, TYPE_Y, +// UNPAREN(TYPES_X))) +// #define RANDOMTRIPLE2(NAME, SIGNATURE, TYPE_Z, TYPES_X, TYPE_Y) +// _RANDOMTRIPLE2(NAME, SIGNATURE, TYPE_Z, TYPE_Y, TYPES_X) +// #define _RANDOMTRIPLE(NAME, SIGNATURE, TYPE_Z, TYPES_X, TYPES_Y) +// EVAL(_EXEC_TRIPLE_T2(RANDOMTRIPLE2, NAME, SIGNATURE, TYPE_Z, TYPES_X, +// UNPAREN(TYPES_Y))) +// #define RANDOMTRIPLE(NAME, SIGNATURE, TYPES_X, TYPES_Y, TYPE_Z) +// _RANDOMTRIPLE(NAME, SIGNATURE, TYPE_Z, TYPES_X, TYPES_Y) + +// #define BROADCAST(NAME) +// sd::BroadcastOpsTuple::custom(sd::scalar::NAME, sd::pairwise::NAME, +// sd::broadcast::NAME) +// #define BROADCAST_BOOL(NAME) +// sd::BroadcastBoolOpsTuple::custom(sd::scalar::NAME, sd::pairwise::NAME, +// sd::broadcast::NAME) public static final int ALL_STRINGS =UTF32; public static final int ALL_INDICES =INT64; public static final int ALL_INTS =UINT64; public static final int ALL_FLOATS =BFLOAT16; -// #endif //TESTS_CPU_TYPE_BOILERPLATE_H +// #endif // TESTS_CPU_TYPE_BOILERPLATE_H // Parsed from system/op_boilerplate.h @@ -11366,12 +13483,12 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #define REQUIRE_OK(A) if (sd::ops::resultHelper( (A), #A, __FILE__, __LINE__ ) != 0) return ND4J_STATUS_VALIDATION; // #define REQUIRE_TRUE(COND, ...) if (!(COND)) { if (sd::ops::conditionHelper(__FILE__, __LINE__, COND, __VA_ARGS__) != 0) throw std::invalid_argument("Op validation failed");}; -// #define DECLARE_ENTRY(NAME, ...) template struct ND4J_EXPORT __registratorFloat>; -// template struct ND4J_EXPORT __registratorHalf>; -// template struct ND4J_EXPORT __registratorDouble>; -// template struct ND4J_EXPORT __registratorSynonymHalf>; -// template struct ND4J_EXPORT __registratorSynonymDouble>; -// template struct ND4J_EXPORT __registratorSynonymFloat>; +// #define DECLARE_ENTRY(NAME, ...) template struct SD_EXPORT __registratorFloat>; +// template struct SD_EXPORT __registratorHalf>; +// template struct SD_EXPORT __registratorDouble>; +// template struct SD_EXPORT __registratorSynonymHalf>; +// template struct SD_EXPORT __registratorSynonymDouble>; +// template struct SD_EXPORT __registratorSynonymFloat>; // #if defined(_MSC_VER) || defined(_WIN64) || defined(_WIN32) || defined(__CLION_IDE__) || defined(__VSCODE__) @@ -11389,7 +13506,7 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #define REGISTER_H(NAME) template // struct __registrator_##NAME { // __registrator_##NAME() { -// OpName *ptr = new OpName(); +// auto ptr = std::make_shared(); // OpRegistrator::getInstance().registerOperation(ptr); // } // }; @@ -11403,7 +13520,7 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #define REGISTER_C(NAME) // #endif -// #define DECLARE_OP(NAME, NIN, NOUT, INPLACEABLE) class ND4J_EXPORT NAME: public sd::ops::DeclarableOp { +// #define DECLARE_OP(NAME, NIN, NOUT, INPLACEABLE) class SD_EXPORT NAME: public sd::ops::DeclarableOp { // public: // NAME(); // sd::ShapeList* calculateOutputShape(sd::ShapeList* inputShape, sd::graph::Context& block); @@ -11413,7 +13530,7 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // }; // REGISTER_H(NAME) -// #define DECLARE_BOOLEAN_OP(NAME, NIN, SCALAR) class ND4J_EXPORT NAME: public sd::ops::BooleanOp { +// #define DECLARE_BOOLEAN_OP(NAME, NIN, SCALAR) class SD_EXPORT NAME: public sd::ops::BooleanOp { // public: // NAME(); // protected: @@ -11426,7 +13543,7 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // REGISTER_C(NAME) // Nd4jStatus sd::ops::NAME::validateAndExecute(sd::graph::Context& block) -// #define DECLARE_LIST_OP(NAME, NIN, NOUT, TARGS, IARGS) class ND4J_EXPORT NAME: public sd::ops::DeclarableListOp { +// #define DECLARE_LIST_OP(NAME, NIN, NOUT, TARGS, IARGS) class SD_EXPORT NAME: public sd::ops::DeclarableListOp { // public: // NAME(); // protected: @@ -11438,7 +13555,7 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // REGISTER_C(NAME) // Nd4jStatus sd::ops::NAME::validateAndExecute(sd::graph::Context& block) -// #define DECLARE_LOGIC_OP(NAME) class ND4J_EXPORT NAME: public sd::ops::LogicOp { +// #define DECLARE_LOGIC_OP(NAME) class SD_EXPORT NAME: public sd::ops::LogicOp { // public: // NAME(); // protected: @@ -11469,7 +13586,7 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #define DECLARE_SYN(NAME, ORIGINAL) template // struct __registratorSynonym_##NAME { // __registratorSynonym_##NAME(const char *name, const char *oname) { -// auto ptr = reinterpret_cast(OpRegistrator::getInstance().getOperation(oname)); +// auto ptr = OpRegistrator::getInstance().getOperation(oname); // if (ptr == nullptr) { // std::string newName(name); // std::string oldName(oname); @@ -11481,7 +13598,7 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // }; // static sd::ops::__registratorSynonym_##NAME zzz_register_opd_##NAME(#NAME, #ORIGINAL) -// #define DECLARE_DIVERGENT_OP(NAME, NIN, NOUT, INPLACEABLE) class ND4J_EXPORT NAME: public sd::ops::DeclarableOp { +// #define DECLARE_DIVERGENT_OP(NAME, NIN, NOUT, INPLACEABLE) class SD_EXPORT NAME: public sd::ops::DeclarableOp { // public: // NAME(); // sd::ShapeList* calculateOutputShape(sd::ShapeList* inputShape, sd::graph::Context& block); @@ -11504,7 +13621,7 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // } // Nd4jStatus sd::ops::NAME::validateAndExecute(sd::graph::Context& block) -// #define DECLARE_CONFIGURABLE_OP(NAME, NIN, NOUT, INPLACEABLE, TARGS, IARGS) class ND4J_EXPORT NAME: public sd::ops::DeclarableOp { +// #define DECLARE_CONFIGURABLE_OP(NAME, NIN, NOUT, INPLACEABLE, TARGS, IARGS) class SD_EXPORT NAME: public sd::ops::DeclarableOp { // public: // NAME(); // sd::ShapeList* calculateOutputShape(sd::ShapeList* inputShape, sd::graph::Context& block); @@ -11527,7 +13644,7 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // } // Nd4jStatus sd::ops::NAME::validateAndExecute(Context& block) -// #define DECLARE_REDUCTION_OP(NAME, NIN, NOUT, INPLACEABLE, TARGS, IARGS) class ND4J_EXPORT NAME: public sd::ops::DeclarableReductionOp { +// #define DECLARE_REDUCTION_OP(NAME, NIN, NOUT, INPLACEABLE, TARGS, IARGS) class SD_EXPORT NAME: public sd::ops::DeclarableReductionOp { // public: // NAME(); // protected: @@ -11541,7 +13658,7 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // Nd4jStatus sd::ops::NAME::validateAndExecute(sd::graph::Context& block) -// #define DECLARE_CUSTOM_OP(NAME, NIN, NOUT, INPLACEABLE, TARGS, IARGS) class ND4J_EXPORT NAME: public sd::ops::DeclarableCustomOp { +// #define DECLARE_CUSTOM_OP(NAME, NIN, NOUT, INPLACEABLE, TARGS, IARGS) class SD_EXPORT NAME: public sd::ops::DeclarableCustomOp { // protected: // void registerTypes(); // Nd4jStatus validateAndExecute(Context& block); @@ -11563,7 +13680,7 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #define DECLARE_TYPES(NAME) void sd::ops::NAME::registerTypes() -// #define DECLARE_BROADCASTABLE_OP(NAME,TARGS, IARGS) class ND4J_EXPORT NAME: public sd::ops::BroadcastableOp { +// #define DECLARE_BROADCASTABLE_OP(NAME,TARGS, IARGS) class SD_EXPORT NAME: public sd::ops::BroadcastableOp { // protected: // void registerTypes(); // Nd4jStatus validateAndExecute(Context& block); @@ -11572,7 +13689,7 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // }; // REGISTER_H(NAME) -// #define DECLARE_BROADCASTABLE_BOOL_OP(NAME,TARGS, IARGS) class ND4J_EXPORT NAME: public sd::ops::BroadcastableBoolOp { +// #define DECLARE_BROADCASTABLE_BOOL_OP(NAME,TARGS, IARGS) class SD_EXPORT NAME: public sd::ops::BroadcastableBoolOp { // protected: // void registerTypes(); // Nd4jStatus validateAndExecute(Context& block); @@ -11637,20 +13754,20 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #define CHECK_STASH(NAME) block.getStash()->checkStash(block.getNodeId(), NAME); // #define UNSTASH(NAME) block.getStash()->extractArray(block.getNodeId(), NAME); -// #define INPUT_VARIABLE(INDEX) block.array(INDEX) +// #define INPUT_VARIABLE(INDEX) block.arrayForOp(INDEX) // #define OUTPUT_VARIABLE(INDEX) reinterpret_cast(this->getZ(block, INDEX)) // #define OUTPUT_NULLIFIED(INDEX) reinterpret_cast(this->getNullifiedZ(block, INDEX)) -// #define INPUT_LIST(INDEX) reinterpret_cast(block.getVariable(INDEX)->getNDArrayList()) +// #define INPUT_LIST(INDEX) reinterpret_cast(block.getVariable(INDEX)->getNDArrayList().get()) -// #define D_ARG(INDEX) block.getDArguments()->at(INDEX) -// #define INT_ARG(INDEX) block.getIArguments()->at(INDEX) +// #define D_ARG(INDEX) block.getDArguments().at(INDEX) +// #define INT_ARG(INDEX) block.getIArguments().at(INDEX) // #define I_ARG(INDEX) INT_ARG(INDEX) -// #define T_ARG(INDEX) block.getTArguments()->at(INDEX) -// #define B_ARG(INDEX) block.getBArguments()->at(INDEX) +// #define T_ARG(INDEX) block.getTArguments().at(INDEX) +// #define B_ARG(INDEX) block.getBArguments().at(INDEX) -// #define COPY_SHAPE(SRC, TGT) TGT = ShapeBuilders::copyShapeInfo(SRC, true, block.getWorkspace()) +// #define COPY_SHAPE(SRC, TGT) TGT = ShapeBuilders::copyShapeInfo(SRC, true, block.workspace()) // #define COPY_SHAPE_EX(SRC, TGT, WORKSPACE) TGT = ShapeBuilders::copyShapeInfo(SRC, true, WORKSPACE) @@ -11754,15 +13871,15 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #ifndef ND4J_INPUTTYPE_H // #define ND4J_INPUTTYPE_H - /** enum sd::ops::InputType */ - public static final int - InputType_BOOLEAN = 0, - InputType_NUMERIC = 1, - InputType_STRINGULAR = 2, - InputType_NUMERIC_SET = 3, - InputType_STRINGULAR_SET = 4; - +/** enum sd::ops::InputType */ +public static final int + InputType_BOOLEAN = 0, + InputType_NUMERIC = 1, + InputType_STRINGULAR = 2, + InputType_NUMERIC_SET = 3, + InputType_STRINGULAR_SET = 4; + // namespace sd // #endif @@ -11791,134 +13908,155 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #ifndef LIBND4J_OPDESCRIPTOR_H // #define LIBND4J_OPDESCRIPTOR_H -// #include -// #include -// #include +// #include +// #include // #include // #include -// #include -// #include - - /** - * This class is very basic info holder for ops. bean/pojo pretty much. - * - */ - @Namespace("sd::ops") @NoOffset public static class OpDescriptor extends Pointer { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public OpDescriptor(Pointer p) { super(p); } - - // default constructor - public OpDescriptor(int numInputs, int numOutputs, @StdString BytePointer opName, @Cast("bool") boolean allowsInplace) { super((Pointer)null); allocate(numInputs, numOutputs, opName, allowsInplace); } - private native void allocate(int numInputs, int numOutputs, @StdString BytePointer opName, @Cast("bool") boolean allowsInplace); - public OpDescriptor(int numInputs, int numOutputs, @StdString String opName, @Cast("bool") boolean allowsInplace) { super((Pointer)null); allocate(numInputs, numOutputs, opName, allowsInplace); } - private native void allocate(int numInputs, int numOutputs, @StdString String opName, @Cast("bool") boolean allowsInplace); - - // constructor for boolean ops - public OpDescriptor(int numInputs, @StdString BytePointer opName, @Cast("bool") boolean isScalar) { super((Pointer)null); allocate(numInputs, opName, isScalar); } - private native void allocate(int numInputs, @StdString BytePointer opName, @Cast("bool") boolean isScalar); - public OpDescriptor(int numInputs, @StdString String opName, @Cast("bool") boolean isScalar) { super((Pointer)null); allocate(numInputs, opName, isScalar); } - private native void allocate(int numInputs, @StdString String opName, @Cast("bool") boolean isScalar); - - // default constructor - // constructor for configurable op - public OpDescriptor(int numInputs, int numOutputs, @Cast("char*") String opName, @Cast("bool") boolean allowsInplace, int tArgs, int iArgs) { super((Pointer)null); allocate(numInputs, numOutputs, opName, allowsInplace, tArgs, iArgs); } - private native void allocate(int numInputs, int numOutputs, @Cast("char*") String opName, @Cast("bool") boolean allowsInplace, int tArgs, int iArgs); - public OpDescriptor(int numInputs, int numOutputs, @Cast("char*") BytePointer opName, @Cast("bool") boolean allowsInplace, int tArgs, int iArgs) { super((Pointer)null); allocate(numInputs, numOutputs, opName, allowsInplace, tArgs, iArgs); } - private native void allocate(int numInputs, int numOutputs, @Cast("char*") BytePointer opName, @Cast("bool") boolean allowsInplace, int tArgs, int iArgs); - - // constructor for non-configurable divergent op - public OpDescriptor(int numInputs, int numOutputs, @StdString BytePointer opName, @Cast("bool") boolean allowsInplace, @Cast("bool") boolean divergent) { super((Pointer)null); allocate(numInputs, numOutputs, opName, allowsInplace, divergent); } - private native void allocate(int numInputs, int numOutputs, @StdString BytePointer opName, @Cast("bool") boolean allowsInplace, @Cast("bool") boolean divergent); - public OpDescriptor(int numInputs, int numOutputs, @StdString String opName, @Cast("bool") boolean allowsInplace, @Cast("bool") boolean divergent) { super((Pointer)null); allocate(numInputs, numOutputs, opName, allowsInplace, divergent); } - private native void allocate(int numInputs, int numOutputs, @StdString String opName, @Cast("bool") boolean allowsInplace, @Cast("bool") boolean divergent); - - // constructor for non-configurable divergent op - - // constructor for configurable divergent op - public OpDescriptor(int numInputs, int numOutputs, @Cast("char*") String opName, @Cast("bool") boolean allowsInplace, @Cast("bool") boolean divergent, int tArgs, int iArgs) { super((Pointer)null); allocate(numInputs, numOutputs, opName, allowsInplace, divergent, tArgs, iArgs); } - private native void allocate(int numInputs, int numOutputs, @Cast("char*") String opName, @Cast("bool") boolean allowsInplace, @Cast("bool") boolean divergent, int tArgs, int iArgs); - public OpDescriptor(int numInputs, int numOutputs, @Cast("char*") BytePointer opName, @Cast("bool") boolean allowsInplace, @Cast("bool") boolean divergent, int tArgs, int iArgs) { super((Pointer)null); allocate(numInputs, numOutputs, opName, allowsInplace, divergent, tArgs, iArgs); } - private native void allocate(int numInputs, int numOutputs, @Cast("char*") BytePointer opName, @Cast("bool") boolean allowsInplace, @Cast("bool") boolean divergent, int tArgs, int iArgs); - - // constructor for logical ops (while, scope, etc) - public OpDescriptor(@Cast("char*") String opName, @Cast("bool") boolean isLogic) { super((Pointer)null); allocate(opName, isLogic); } - private native void allocate(@Cast("char*") String opName, @Cast("bool") boolean isLogic); - public OpDescriptor(@Cast("char*") BytePointer opName, @Cast("bool") boolean isLogic) { super((Pointer)null); allocate(opName, isLogic); } - private native void allocate(@Cast("char*") BytePointer opName, @Cast("bool") boolean isLogic); - - public native @Cast("bool") @Name("operator ==") boolean equals(@Const @ByRef OpDescriptor other); - - // default destructor - - // this method returns minimal expected number of T arguments - public native int getNumberOfTArgs(); - - // this method returns minimal expected number of Integer arguments - public native int getNumberOfIArgs(); - - // this method returns minimal expected number of inputs - public native int getNumberOfInputs(); - - // this method returns hash code for this operation - public native @Cast("Nd4jLong") long getHash(); - - // this method returns minimal expected number of outputs - public native int getNumberOfOutputs(); - - // this method returns opName (can be empty) - public native @StdString @Cast({"char*", "std::string*"}) BytePointer getOpName(); - - // returns TRUE if this op is divergent. FALSE otherwise - public native @Cast("bool") boolean isDivergent(); - - // returns TRUE if this op allows in-place execution - public native @Cast("bool") boolean allowsInplace(); - - // this method allows you to enable/disable inplace call for a given op - public native void allowInplace(@Cast("bool") boolean reallyAllow); - - // this method returns opNum (applicable for legacy XYZ ops only) - public native int getOpNum(); - - // this method allows to set specifc opNum - public native void setOpNum(int opNum); - - public native void setHash(@Cast("Nd4jLong") long hash); - - public native @Cast("sd::ops::InputType") int inputType(); - - - - public native OpDescriptor setInputType(@Cast("sd::ops::InputType") int type); - public native OpDescriptor setAllowedInputTypes(int index, @Cast("sd::DataType*") @StdVector IntPointer dtype); - public native OpDescriptor setAllowedInputTypes(int index, @Cast("sd::DataType*") @StdVector IntBuffer dtype); - public native OpDescriptor setAllowedInputTypes(int index, @Cast("sd::DataType*") @StdVector int[] dtype); - public native OpDescriptor setAllowedOutputTypes(int index, @Cast("sd::DataType*") @StdVector IntPointer dtype); - public native OpDescriptor setAllowedOutputTypes(int index, @Cast("sd::DataType*") @StdVector IntBuffer dtype); - public native OpDescriptor setAllowedOutputTypes(int index, @Cast("sd::DataType*") @StdVector int[] dtype); - public native OpDescriptor setAllowedInputTypes(int index, @Cast("sd::DataType") int dtype); - public native OpDescriptor setAllowedOutputTypes(int index, @Cast("sd::DataType") int dtype); - public native OpDescriptor setAllowedInputTypes(@Cast("sd::DataType") int dtype); - public native OpDescriptor setAllowedOutputTypes(@Cast("sd::DataType") int dtype); - public native OpDescriptor allowOverride(@Cast("bool") boolean reallyAllow); - public native OpDescriptor setSameMode(@Cast("bool") boolean reallySame); - public native OpDescriptor setInputType(int idx, @Cast("sd::DataType") int dtype); - public native OpDescriptor setOutputType(int idx, @Cast("sd::DataType") int dtype); - - public native @Cast("sd::DataType*") @StdVector IntPointer getOutputTypesForOutput(int index); - - public native @Cast("bool") boolean checkInputMatch(int index, @Cast("sd::DataType") int dataType); - public native @Cast("bool") boolean checkOutputMatch(int index, @Cast("sd::DataType") int dataType); - public native @Cast("bool") boolean isSameMode(); - - public native @Cast("bool") boolean isInherit(int index); - } - +// #include +// #include +// #include +/** + * This class is very basic info holder for ops. bean/pojo pretty much. + * + */ +@Namespace("sd::ops") @NoOffset public static class OpDescriptor extends Pointer { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public OpDescriptor(Pointer p) { super(p); } + + // default constructor + public OpDescriptor(int numInputs, int numOutputs, @StdString BytePointer opName, + @Cast("bool") boolean allowsInplace) { super((Pointer)null); allocate(numInputs, numOutputs, opName, allowsInplace); } + private native void allocate(int numInputs, int numOutputs, @StdString BytePointer opName, + @Cast("bool") boolean allowsInplace); + public OpDescriptor(int numInputs, int numOutputs, @StdString String opName, + @Cast("bool") boolean allowsInplace) { super((Pointer)null); allocate(numInputs, numOutputs, opName, allowsInplace); } + private native void allocate(int numInputs, int numOutputs, @StdString String opName, + @Cast("bool") boolean allowsInplace); + + // constructor for boolean ops + public OpDescriptor(int numInputs, @StdString BytePointer opName, @Cast("bool") boolean isScalar) { super((Pointer)null); allocate(numInputs, opName, isScalar); } + private native void allocate(int numInputs, @StdString BytePointer opName, @Cast("bool") boolean isScalar); + public OpDescriptor(int numInputs, @StdString String opName, @Cast("bool") boolean isScalar) { super((Pointer)null); allocate(numInputs, opName, isScalar); } + private native void allocate(int numInputs, @StdString String opName, @Cast("bool") boolean isScalar); + + // default constructor + + // constructor for configurable op + public OpDescriptor(int numInputs, int numOutputs, @Cast("char*") String opName, + @Cast("bool") boolean allowsInplace, int tArgs, int iArgs) { super((Pointer)null); allocate(numInputs, numOutputs, opName, allowsInplace, tArgs, iArgs); } + private native void allocate(int numInputs, int numOutputs, @Cast("char*") String opName, + @Cast("bool") boolean allowsInplace, int tArgs, int iArgs); + public OpDescriptor(int numInputs, int numOutputs, @Cast("char*") BytePointer opName, + @Cast("bool") boolean allowsInplace, int tArgs, int iArgs) { super((Pointer)null); allocate(numInputs, numOutputs, opName, allowsInplace, tArgs, iArgs); } + private native void allocate(int numInputs, int numOutputs, @Cast("char*") BytePointer opName, + @Cast("bool") boolean allowsInplace, int tArgs, int iArgs); + + // constructor for non-configurable divergent op + public OpDescriptor(int numInputs, int numOutputs, @StdString BytePointer opName, + @Cast("bool") boolean allowsInplace, @Cast("bool") boolean divergent) { super((Pointer)null); allocate(numInputs, numOutputs, opName, allowsInplace, divergent); } + private native void allocate(int numInputs, int numOutputs, @StdString BytePointer opName, + @Cast("bool") boolean allowsInplace, @Cast("bool") boolean divergent); + public OpDescriptor(int numInputs, int numOutputs, @StdString String opName, + @Cast("bool") boolean allowsInplace, @Cast("bool") boolean divergent) { super((Pointer)null); allocate(numInputs, numOutputs, opName, allowsInplace, divergent); } + private native void allocate(int numInputs, int numOutputs, @StdString String opName, + @Cast("bool") boolean allowsInplace, @Cast("bool") boolean divergent); + + // constructor for non-configurable divergent op + + // constructor for configurable divergent op + public OpDescriptor(int numInputs, int numOutputs, @Cast("char*") String opName, + @Cast("bool") boolean allowsInplace, @Cast("bool") boolean divergent, int tArgs, int iArgs) { super((Pointer)null); allocate(numInputs, numOutputs, opName, allowsInplace, divergent, tArgs, iArgs); } + private native void allocate(int numInputs, int numOutputs, @Cast("char*") String opName, + @Cast("bool") boolean allowsInplace, @Cast("bool") boolean divergent, int tArgs, int iArgs); + public OpDescriptor(int numInputs, int numOutputs, @Cast("char*") BytePointer opName, + @Cast("bool") boolean allowsInplace, @Cast("bool") boolean divergent, int tArgs, int iArgs) { super((Pointer)null); allocate(numInputs, numOutputs, opName, allowsInplace, divergent, tArgs, iArgs); } + private native void allocate(int numInputs, int numOutputs, @Cast("char*") BytePointer opName, + @Cast("bool") boolean allowsInplace, @Cast("bool") boolean divergent, int tArgs, int iArgs); + + // constructor for logical ops (while, scope, etc) + public OpDescriptor(@Cast("char*") String opName, @Cast("bool") boolean isLogic) { super((Pointer)null); allocate(opName, isLogic); } + private native void allocate(@Cast("char*") String opName, @Cast("bool") boolean isLogic); + public OpDescriptor(@Cast("char*") BytePointer opName, @Cast("bool") boolean isLogic) { super((Pointer)null); allocate(opName, isLogic); } + private native void allocate(@Cast("char*") BytePointer opName, @Cast("bool") boolean isLogic); + + public native @Cast("bool") @Name("operator ==") boolean equals(@Const @ByRef OpDescriptor other); + + // default destructor + + // this method returns minimal expected number of T arguments + public native int getNumberOfTArgs(); + + // this method returns minimal expected number of Integer arguments + public native int getNumberOfIArgs(); + + // this method returns minimal expected number of inputs + public native int getNumberOfInputs(); + + // this method returns hash code for this operation + public native @Cast("Nd4jLong") long getHash(); + + // this method returns minimal expected number of outputs + public native int getNumberOfOutputs(); + + // this method returns opName (can be empty) + public native @StdString @Cast({"char*", "std::string*"}) BytePointer getOpName(); + + // returns TRUE if this op is divergent. FALSE otherwise + public native @Cast("bool") boolean isDivergent(); + + // returns TRUE if this op allows in-place execution + public native @Cast("bool") boolean allowsInplace(); + + // this method allows you to enable/disable inplace call for a given op + public native void allowInplace(@Cast("bool") boolean reallyAllow); + + // this method returns opNum (applicable for legacy XYZ ops only) + public native int getOpNum(); + + // this method allows to set specifc opNum + public native void setOpNum(int opNum); + + public native void setHash(@Cast("Nd4jLong") long hash); + + public native @Cast("sd::ops::InputType") int inputType(); + + public native OpDescriptor setInputType(@Cast("sd::ops::InputType") int type); + public native OpDescriptor setAllowedInputTypes(int index, + @Cast("sd::DataType*") @StdVector IntPointer dtype); + public native OpDescriptor setAllowedInputTypes(int index, + @Cast("sd::DataType*") @StdVector IntBuffer dtype); + public native OpDescriptor setAllowedInputTypes(int index, + @Cast("sd::DataType*") @StdVector int[] dtype); + public native OpDescriptor setAllowedOutputTypes(int index, + @Cast("sd::DataType*") @StdVector IntPointer dtype); + public native OpDescriptor setAllowedOutputTypes(int index, + @Cast("sd::DataType*") @StdVector IntBuffer dtype); + public native OpDescriptor setAllowedOutputTypes(int index, + @Cast("sd::DataType*") @StdVector int[] dtype); + public native OpDescriptor setAllowedInputTypes(int index, @Cast("sd::DataType") int dtype); + public native OpDescriptor setAllowedOutputTypes(int index, @Cast("sd::DataType") int dtype); + public native OpDescriptor setAllowedInputTypes(@Cast("sd::DataType") int dtype); + public native OpDescriptor setAllowedOutputTypes(@Cast("sd::DataType") int dtype); + public native OpDescriptor allowOverride(@Cast("bool") boolean reallyAllow); + public native OpDescriptor setSameMode(@Cast("bool") boolean reallySame); + public native OpDescriptor setInputType(int idx, @Cast("sd::DataType") int dtype); + public native OpDescriptor setOutputType(int idx, @Cast("sd::DataType") int dtype); + + public native @Cast("sd::DataType*") @StdVector IntPointer getOutputTypesForOutput(int index); + + public native @Cast("bool") boolean checkInputMatch(int index, @Cast("sd::DataType") int dataType); + public native @Cast("bool") boolean checkOutputMatch(int index, @Cast("sd::DataType") int dataType); + public native @Cast("bool") boolean isSameMode(); + + public native @Cast("bool") boolean isInherit(int index); +} + // namespace ops + // namespace sd -// #endif //LIBND4J_OPDESCRIPTOR_H +// #endif // LIBND4J_OPDESCRIPTOR_H // Parsed from ops/declarable/PlatformHelper.h @@ -11946,65 +14084,68 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #ifndef SD_PLATFORMHELPER_H // #define SD_PLATFORMHELPER_H -// #include // #include // #include -// #include -// #include +// #include // #include - /** - * This abstract class defines methods used by platform-specific helpers implementations - */ - @Namespace("sd::ops::platforms") @NoOffset public static class PlatformHelper extends Pointer { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public PlatformHelper(Pointer p) { super(p); } - - - public native @StdString BytePointer name(); - - public native @Cast("samediff::Engine") int engine(); - - public native @Cast("Nd4jLong") long hash(); - - /** - * This method checks, if given helper can be used with given input/output/configuration options - * - * @param context - * @return - */ - public native @Cast("bool") boolean isUsable(@ByRef Context context); - - /** - * This method invokes helper. Typically this method replaces actual op execution - * - * @param context - * @return - */ - public native @Cast("Nd4jStatus") int invokeHelper(@ByRef Context context); - - /** - * Helper method, needed for compatibility with DeclarableOp macros - * @param ctx - * @param inputId - * @return - */ - public native NDArray getZ(@ByRef Context ctx, int inputId); - - /** - * Helper method, needed for compatibility with DeclarableOp macros - * @param ctx - * @param inputId - * @return - */ - public native NDArray getNullifiedZ(@ByRef Context ctx, int inputId); - } - - +// #include + +// #include +/** + * This abstract class defines methods used by platform-specific helpers + * implementations + */ +@Namespace("sd::ops::platforms") @NoOffset public static class PlatformHelper extends Pointer { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public PlatformHelper(Pointer p) { super(p); } + + + public native @StdString BytePointer name(); + + public native @Cast("samediff::Engine") int engine(); + + public native @Cast("Nd4jLong") long hash(); + + /** + * This method checks, if given helper can be used with given + * input/output/configuration options + * + * @param context + * @return + */ + public native @Cast("bool") boolean isUsable(@ByRef Context context); + + /** + * This method invokes helper. Typically this method replaces actual op + * execution + * + * @param context + * @return + */ + public native @Cast("Nd4jStatus") int invokeHelper(@ByRef Context context); + /** + * Helper method, needed for compatibility with DeclarableOp macros + * @param ctx + * @param inputId + * @return + */ + public native NDArray getZ(@ByRef Context ctx, int inputId); + /** + * Helper method, needed for compatibility with DeclarableOp macros + * @param ctx + * @param inputId + * @return + */ + public native NDArray getNullifiedZ(@ByRef Context ctx, int inputId); +} + // namespace platforms + // namespace ops + // namespace sd -// #endif //SD_PLATFORMHELPER_H +// #endif // SD_PLATFORMHELPER_H // Parsed from ops/declarable/BroadcastableOp.h @@ -12033,22 +14174,23 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #define LIBND4J_BROADCASTABLEOP_H // #include -// #include "OpDescriptor.h" -// #include "DeclarableOp.h" -// #include "DeclarableCustomOp.h" - @Namespace("sd::ops") public static class BroadcastableOp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public BroadcastableOp(Pointer p) { super(p); } - - public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); - } - +// #include "DeclarableCustomOp.h" +// #include "DeclarableOp.h" +// #include "OpDescriptor.h" +@Namespace("sd::ops") public static class BroadcastableOp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public BroadcastableOp(Pointer p) { super(p); } + public native ShapeList calculateOutputShape(ShapeList inputShape, + @ByRef Context block); +} + // namespace ops + // namespace sd -// #endif //LIBND4J_BROADCASTABLEOP_H +// #endif // LIBND4J_BROADCASTABLEOP_H // Parsed from ops/declarable/BroadcastableBoolOp.h @@ -12077,22 +14219,23 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #define SD_BROADCASTABLEBOOLOP_H // #include -// #include "OpDescriptor.h" -// #include "DeclarableOp.h" -// #include "DeclarableCustomOp.h" - @Namespace("sd::ops") public static class BroadcastableBoolOp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public BroadcastableBoolOp(Pointer p) { super(p); } - - public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); - } - +// #include "DeclarableCustomOp.h" +// #include "DeclarableOp.h" +// #include "OpDescriptor.h" +@Namespace("sd::ops") public static class BroadcastableBoolOp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public BroadcastableBoolOp(Pointer p) { super(p); } + public native ShapeList calculateOutputShape(ShapeList inputShape, + @ByRef Context block); +} + // namespace ops + // namespace sd -// #endif //SD_BROADCASTABLEBOOLOP_H +// #endif // SD_BROADCASTABLEBOOLOP_H // Parsed from ops/declarable/DeclarableOp.h @@ -12120,156 +14263,290 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #ifndef LIBND4J_DECLARABLE_OPS_H // #define LIBND4J_DECLARABLE_OPS_H -// #include -// #include -// #include // #include -// #include -// #include "OpDescriptor.h" -// #include -// #include // #include +// #include +// #include // #include -// #include +// #include // #include +// #include +// #include +// #include + +// #include + +// #include "OpDescriptor.h" //#include // #include // #include // #include - @Namespace("sd::ops") public static native @Cast("Nd4jStatus") int conditionHelper(@Cast("char*") String file, int line, int condition, int argNumber, @Cast("char*") String format); - @Namespace("sd::ops") public static native @Cast("Nd4jStatus") int conditionHelper(@Cast("char*") BytePointer file, int line, int condition, int argNumber, @Cast("char*") BytePointer format); +@Namespace("sd::ops") public static native @Cast("Nd4jStatus") int conditionHelper(@Cast("char*") String file, int line, int condition, + int argNumber, @Cast("char*") String format); +@Namespace("sd::ops") public static native @Cast("Nd4jStatus") int conditionHelper(@Cast("char*") BytePointer file, int line, int condition, + int argNumber, @Cast("char*") BytePointer format); - /** - * This class is the basic building block of Graph Operations. Any CustomOp out there is built on top of this "abstract" class. - * - */ - @Namespace("sd::ops") @NoOffset public static class DeclarableOp extends Pointer { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public DeclarableOp(Pointer p) { super(p); } - - // for special cases, like BooleanOps - - // regular constructors - - // for LogicalOps - - // default testructor - - // this method returns OpDescriptor, describing this Op instance - public native OpDescriptor getOpDescriptor(); - - public native @Cast("Nd4jStatus") int validateDataTypes(@ByRef Context block); - - /** - * This method should be available in each implemented Op, and should return Op output shape(s), for a given input shape(s) - */ - public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); - - /** - * Returns opName - * - * @return - */ - public native @StdString @Cast({"char*", "std::string*"}) BytePointer getOpName(); - - /** - * Returns opHash - */ - public native @Cast("Nd4jLong") long getOpHash(); - - /** - * This method sets arguments for op - */ -// void setArguments(); - - /** - * This method returns pointer to results - */ -// void getResults(); - - /** - * This method executes given Op - * - * @param block - * @return 0 if OK, error code otherwise - */ - public native @Cast("Nd4jStatus") int execute(Context block); - - public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs); - - public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector()*/, @Cast("sd::DataType*") @StdVector IntPointer dArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/); - public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs); - public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector()*/, @Cast("sd::DataType*") @StdVector IntBuffer dArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/); - public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs); - public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector()*/, @Cast("sd::DataType*") @StdVector int[] dArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/); - public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs); - public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector()*/, @Cast("sd::DataType*") @StdVector IntPointer dArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/); - public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector()*/, @Cast("sd::DataType*") @StdVector IntBuffer dArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/); - public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector()*/, @Cast("sd::DataType*") @StdVector int[] dArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/); - - public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs); - - public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector()*/, @Cast("sd::DataType*") @StdVector IntPointer dArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/); - public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs); - public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector()*/, @Cast("sd::DataType*") @StdVector IntBuffer dArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/); - public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs); - public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector()*/, @Cast("sd::DataType*") @StdVector int[] dArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/); - public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs); - public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector()*/, @Cast("sd::DataType*") @StdVector IntPointer dArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/); - public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector()*/, @Cast("sd::DataType*") @StdVector IntBuffer dArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/); - public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector()*/, @Cast("sd::DataType*") @StdVector int[] dArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/); - - public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs, @Cast("sd::DataType*") @StdVector IntPointer dArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("sd::DataType") int type/*=sd::DataType::FLOAT32*/); - public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs); - public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector boolean[] bArgs, @Cast("sd::DataType*") @StdVector IntBuffer dArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("sd::DataType") int type/*=sd::DataType::FLOAT32*/); - public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector boolean[] bArgs); - public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs, @Cast("sd::DataType*") @StdVector int[] dArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("sd::DataType") int type/*=sd::DataType::FLOAT32*/); - public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs); - public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector boolean[] bArgs, @Cast("sd::DataType*") @StdVector IntPointer dArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("sd::DataType") int type/*=sd::DataType::FLOAT32*/); - public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector boolean[] bArgs); - public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs, @Cast("sd::DataType*") @StdVector IntBuffer dArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("sd::DataType") int type/*=sd::DataType::FLOAT32*/); - public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs); - public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs, @Cast("sd::DataType*") @StdVector int[] dArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("sd::DataType") int type/*=sd::DataType::FLOAT32*/); - public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs); - - public native @ByVal ResultSet execute(@Const @ByRef OpArgsHolder holder, @Cast("bool") boolean isInplace/*=false*/); - public native @ByVal ResultSet execute(@Const @ByRef OpArgsHolder holder); - - - // There methods provide various validation options - public native @Cast("Nd4jStatus") int validateNonEmptyInput(@ByRef Context block); - - // this method checks if all input arrays have equal lengths - public native @Cast("Nd4jStatus") int validateInputLengthMatch(@ByRef Context block); - - // this method checks if all input arrays have the same shapes (orders/strides are NOT checked) - public native @Cast("Nd4jStatus") int validateInputDimensionsMatch(@ByRef Context block); - - // this method check if all input arrays have the same orders - public native @Cast("Nd4jStatus") int validateOrdersMatch(@ByRef Context block); - - // this method checks if all input arrays are 2D - public native @Cast("Nd4jStatus") int validateInput2D(@ByRef Context block); - - // this method checks if all input arrays are 3D - public native @Cast("Nd4jStatus") int validateInput3D(@ByRef Context block); - - // this method checks if all input arrays are 4D - public native @Cast("Nd4jStatus") int validateInput4D(@ByRef Context block); - - // this method checks if all input arrays are ND - public native @Cast("Nd4jStatus") int validateInputDimensions(@ByRef Context block, int rank); - - // this method checks if number of available arguments matches op expectations - public native @Cast("Nd4jStatus") int validateArguments(@ByRef Context block); - } - +/** + * This class is the basic building block of Graph Operations. Any CustomOp out + * there is built on top of this "abstract" class. + * + */ +@Namespace("sd::ops") @NoOffset public static class DeclarableOp extends Pointer { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public DeclarableOp(Pointer p) { super(p); } + + // for special cases, like BooleanOps + + // regular constructors + // for LogicalOps -// #endif //LIBND4J_DECLARABLE_OPS_H + // default testructor + + // this method returns OpDescriptor, describing this Op instance + public native OpDescriptor getOpDescriptor(); + + public native @Cast("Nd4jStatus") int validateDataTypes(@ByRef Context block); + + /** + * This method should be available in each implemented Op, and should return + * Op output shape(s), for a given input shape(s) + */ + public native ShapeList calculateOutputShape(ShapeList inputShape, + @ByRef Context block); + + /** + * Returns opName + * + * @return + */ + public native @StdString BytePointer getOpName(); + + /** + * Returns opHash + */ + public native @Cast("Nd4jLong") long getOpHash(); + + /** + * This method sets arguments for op + */ + // void setArguments(); + + /** + * This method returns pointer to results + */ + // void getResults(); + + /** + * This method executes given Op + * + * @param block + * @return 0 if OK, error code otherwise + */ + public native @Cast("Nd4jStatus") int execute(Context block); + + public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, + @Const @ByRef NDArrayVector outputs); + + public native @Cast("Nd4jStatus") int execute( + @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, + @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, + @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector()*/, + @Cast("sd::DataType*") @StdVector IntPointer dArgs/*=std::vector()*/, + @Cast("bool") boolean isInplace/*=false*/); + public native @Cast("Nd4jStatus") int execute( + @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, + @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs); + public native @Cast("Nd4jStatus") int execute( + @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, + @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, + @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector()*/, + @Cast("sd::DataType*") @StdVector IntBuffer dArgs/*=std::vector()*/, + @Cast("bool") boolean isInplace/*=false*/); + public native @Cast("Nd4jStatus") int execute( + @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, + @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs); + public native @Cast("Nd4jStatus") int execute( + @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, + @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, + @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector()*/, + @Cast("sd::DataType*") @StdVector int[] dArgs/*=std::vector()*/, + @Cast("bool") boolean isInplace/*=false*/); + public native @Cast("Nd4jStatus") int execute( + @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, + @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs); + public native @Cast("Nd4jStatus") int execute( + @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, + @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, + @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector()*/, + @Cast("sd::DataType*") @StdVector IntPointer dArgs/*=std::vector()*/, + @Cast("bool") boolean isInplace/*=false*/); + public native @Cast("Nd4jStatus") int execute( + @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, + @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, + @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector()*/, + @Cast("sd::DataType*") @StdVector IntBuffer dArgs/*=std::vector()*/, + @Cast("bool") boolean isInplace/*=false*/); + public native @Cast("Nd4jStatus") int execute( + @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, + @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, + @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector()*/, + @Cast("sd::DataType*") @StdVector int[] dArgs/*=std::vector()*/, + @Cast("bool") boolean isInplace/*=false*/); + + public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs); + + public native @ByVal ResultSet evaluate( + @Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, + @Cast("Nd4jLong*") @StdVector LongPointer iArgs, + @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector()*/, + @Cast("sd::DataType*") @StdVector IntPointer dArgs/*=std::vector()*/, + @Cast("bool") boolean isInplace/*=false*/); + public native @ByVal ResultSet evaluate( + @Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, + @Cast("Nd4jLong*") @StdVector LongPointer iArgs); + public native @ByVal ResultSet evaluate( + @Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, + @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, + @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector()*/, + @Cast("sd::DataType*") @StdVector IntBuffer dArgs/*=std::vector()*/, + @Cast("bool") boolean isInplace/*=false*/); + public native @ByVal ResultSet evaluate( + @Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, + @Cast("Nd4jLong*") @StdVector LongBuffer iArgs); + public native @ByVal ResultSet evaluate( + @Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, + @Cast("Nd4jLong*") @StdVector long[] iArgs, + @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector()*/, + @Cast("sd::DataType*") @StdVector int[] dArgs/*=std::vector()*/, + @Cast("bool") boolean isInplace/*=false*/); + public native @ByVal ResultSet evaluate( + @Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, + @Cast("Nd4jLong*") @StdVector long[] iArgs); + public native @ByVal ResultSet evaluate( + @Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, + @Cast("Nd4jLong*") @StdVector LongPointer iArgs, + @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector()*/, + @Cast("sd::DataType*") @StdVector IntPointer dArgs/*=std::vector()*/, + @Cast("bool") boolean isInplace/*=false*/); + public native @ByVal ResultSet evaluate( + @Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, + @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, + @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector()*/, + @Cast("sd::DataType*") @StdVector IntBuffer dArgs/*=std::vector()*/, + @Cast("bool") boolean isInplace/*=false*/); + public native @ByVal ResultSet evaluate( + @Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, + @Cast("Nd4jLong*") @StdVector long[] iArgs, + @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector()*/, + @Cast("sd::DataType*") @StdVector int[] dArgs/*=std::vector()*/, + @Cast("bool") boolean isInplace/*=false*/); + + public native @Cast("Nd4jStatus") int execute( + @ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, + @Const @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, + @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs, + @Cast("sd::DataType*") @StdVector IntPointer dArgs/*=std::vector()*/, + @Cast("bool") boolean isInplace/*=false*/, @Cast("sd::DataType") int type/*=sd::DataType::FLOAT32*/); + public native @Cast("Nd4jStatus") int execute( + @ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, + @Const @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, + @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs); + public native @Cast("Nd4jStatus") int execute( + @ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, + @Const @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, + @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector boolean[] bArgs, + @Cast("sd::DataType*") @StdVector IntBuffer dArgs/*=std::vector()*/, + @Cast("bool") boolean isInplace/*=false*/, @Cast("sd::DataType") int type/*=sd::DataType::FLOAT32*/); + public native @Cast("Nd4jStatus") int execute( + @ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, + @Const @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, + @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector boolean[] bArgs); + public native @Cast("Nd4jStatus") int execute( + @ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, + @Const @ByRef NDArrayVector outputs, @StdVector double[] tArgs, + @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs, + @Cast("sd::DataType*") @StdVector int[] dArgs/*=std::vector()*/, + @Cast("bool") boolean isInplace/*=false*/, @Cast("sd::DataType") int type/*=sd::DataType::FLOAT32*/); + public native @Cast("Nd4jStatus") int execute( + @ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, + @Const @ByRef NDArrayVector outputs, @StdVector double[] tArgs, + @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs); + public native @Cast("Nd4jStatus") int execute( + @ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, + @Const @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, + @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector boolean[] bArgs, + @Cast("sd::DataType*") @StdVector IntPointer dArgs/*=std::vector()*/, + @Cast("bool") boolean isInplace/*=false*/, @Cast("sd::DataType") int type/*=sd::DataType::FLOAT32*/); + public native @Cast("Nd4jStatus") int execute( + @ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, + @Const @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, + @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector boolean[] bArgs); + public native @Cast("Nd4jStatus") int execute( + @ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, + @Const @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, + @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs, + @Cast("sd::DataType*") @StdVector IntBuffer dArgs/*=std::vector()*/, + @Cast("bool") boolean isInplace/*=false*/, @Cast("sd::DataType") int type/*=sd::DataType::FLOAT32*/); + public native @Cast("Nd4jStatus") int execute( + @ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, + @Const @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, + @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs); + public native @Cast("Nd4jStatus") int execute( + @ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, + @Const @ByRef NDArrayVector outputs, @StdVector double[] tArgs, + @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs, + @Cast("sd::DataType*") @StdVector int[] dArgs/*=std::vector()*/, + @Cast("bool") boolean isInplace/*=false*/, @Cast("sd::DataType") int type/*=sd::DataType::FLOAT32*/); + public native @Cast("Nd4jStatus") int execute( + @ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, + @Const @ByRef NDArrayVector outputs, @StdVector double[] tArgs, + @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs); + + public native @ByVal ResultSet execute(@Const @ByRef OpArgsHolder holder, @Cast("bool") boolean isInplace/*=false*/); + public native @ByVal ResultSet execute(@Const @ByRef OpArgsHolder holder); + + // There methods provide various validation options + public native @Cast("Nd4jStatus") int validateNonEmptyInput(@ByRef Context block); + + // this method checks if all input arrays have equal lengths + public native @Cast("Nd4jStatus") int validateInputLengthMatch(@ByRef Context block); + + // this method checks if all input arrays have the same shapes (orders/strides + // are NOT checked) + public native @Cast("Nd4jStatus") int validateInputDimensionsMatch(@ByRef Context block); + + // this method check if all input arrays have the same orders + public native @Cast("Nd4jStatus") int validateOrdersMatch(@ByRef Context block); + + // this method checks if all input arrays are 2D + public native @Cast("Nd4jStatus") int validateInput2D(@ByRef Context block); + + // this method checks if all input arrays are 3D + public native @Cast("Nd4jStatus") int validateInput3D(@ByRef Context block); + + // this method checks if all input arrays are 4D + public native @Cast("Nd4jStatus") int validateInput4D(@ByRef Context block); + + // this method checks if all input arrays are ND + public native @Cast("Nd4jStatus") int validateInputDimensions(@ByRef Context block, int rank); + + // this method checks if number of available arguments matches op expectations + public native @Cast("Nd4jStatus") int validateArguments(@ByRef Context block); + + /** + * This method pre-allocates NDArrays for Op output, in case they are not + * available at op execution time + */ + public native int prepareOutputs(@ByRef Context block); +} + // namespace ops + // namespace sd + +// #endif // LIBND4J_DECLARABLE_OPS_H // Parsed from ops/declarable/DeclarableListOp.h @@ -12299,24 +14576,36 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #include // #include -// #include // #include - @Namespace("sd::ops") public static class DeclarableListOp extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public DeclarableListOp(Pointer p) { super(p); } - - - - public native @Cast("Nd4jStatus") int execute(Context block); - public native @ByVal ResultSet execute(NDArrayList list, @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @StdVector IntPointer iArgs); - public native @ByVal ResultSet execute(NDArrayList list, @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @StdVector IntBuffer iArgs); - public native @ByVal ResultSet execute(NDArrayList list, @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @StdVector int[] iArgs); - - public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); - } - - +// #include +@Namespace("sd::ops") public static class DeclarableListOp extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public DeclarableListOp(Pointer p) { super(p); } + + + public native @Cast("Nd4jStatus") int execute(Context block); + + public native @ByVal ResultSet execute(@Const @ByRef NDArrayList list, + @Const @ByRef NDArrayVector inputs, + @StdVector DoublePointer tArgs/*={}*/, + @StdVector IntPointer iArgs/*={}*/); + public native @ByVal ResultSet execute(@Const @ByRef NDArrayList list, + @Const @ByRef NDArrayVector inputs); + public native @ByVal ResultSet execute(@Const @ByRef NDArrayList list, + @Const @ByRef NDArrayVector inputs, + @StdVector DoubleBuffer tArgs/*={}*/, + @StdVector IntBuffer iArgs/*={}*/); + public native @ByVal ResultSet execute(@Const @ByRef NDArrayList list, + @Const @ByRef NDArrayVector inputs, + @StdVector double[] tArgs/*={}*/, + @StdVector int[] iArgs/*={}*/); + + public native ShapeList calculateOutputShape(ShapeList inputShape, + @ByRef Context block); +} + // namespace ops + // namespace sd // #endif @@ -12346,18 +14635,19 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #define LIBND4J_DECLARABLE_REDUCTION_OP_H // #include - @Namespace("sd::ops") public static class DeclarableReductionOp extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public DeclarableReductionOp(Pointer p) { super(p); } - +@Namespace("sd::ops") public static class DeclarableReductionOp extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public DeclarableReductionOp(Pointer p) { super(p); } - public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); - } - + public native ShapeList calculateOutputShape(ShapeList inputShape, + @ByRef Context block); +} + // namespace ops + // namespace sd -// #endif //LIBND4J_DECLARABLE_REDUCTION_OP_H +// #endif // LIBND4J_DECLARABLE_REDUCTION_OP_H // Parsed from ops/declarable/DeclarableCustomOp.h @@ -12386,18 +14676,19 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #define LIBND4J_DECLARABLECUSTOMOP_H // #include - @Namespace("sd::ops") public static class DeclarableCustomOp extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public DeclarableCustomOp(Pointer p) { super(p); } - +@Namespace("sd::ops") public static class DeclarableCustomOp extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public DeclarableCustomOp(Pointer p) { super(p); } - public native ShapeList calculateOutputShape(ShapeList inputShapes, @ByRef Context block); - } - + public native ShapeList calculateOutputShape(ShapeList inputShapes, + @ByRef Context block); +} + // namespace ops + // namespace sd -// #endif //LIBND4J_DECLARABLECUSTOMOP_H +// #endif // LIBND4J_DECLARABLECUSTOMOP_H // Parsed from ops/declarable/BooleanOp.h @@ -12426,27 +14717,27 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #define LIBND4J_BOOLEANOP_H // #include -// #include "OpDescriptor.h" -// #include "DeclarableOp.h" - @Namespace("sd::ops") @NoOffset public static class BooleanOp extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public BooleanOp(Pointer p) { super(p); } - - - public native @Cast("bool") boolean verify(@Const @ByRef NDArrayVector args); - public native @Cast("bool") boolean verify(@ByRef Context block); - public native @Cast("Nd4jStatus") int execute(Context block); +// #include "DeclarableOp.h" +// #include "OpDescriptor.h" +@Namespace("sd::ops") @NoOffset public static class BooleanOp extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public BooleanOp(Pointer p) { super(p); } - public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); - } - + public native @Cast("bool") boolean verify(@Const @ByRef NDArrayVector args); + public native @Cast("bool") boolean verify(@ByRef Context block); + public native @Cast("Nd4jStatus") int execute(Context block); + public native ShapeList calculateOutputShape(ShapeList inputShape, + @ByRef Context block); +} + // namespace ops + // namespace sd -// #endif //LIBND4J_BOOLEANOP_H +// #endif // LIBND4J_BOOLEANOP_H // Parsed from ops/declarable/LogicOp.h @@ -12475,29 +14766,31 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #include "DeclarableOp.h" - /** - * Logic ops are unique snowflakes in any Graph. They dramatically change Graph Execution process, by introducing loops, conditions, etc. - * - * Their code is the part of GraphExecutioner logic. But we still want them to be expressed via Graph - * \tparam T - */ - @Namespace("sd::ops") public static class LogicOp extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public LogicOp(Pointer p) { super(p); } - - public LogicOp(@Cast("char*") String name) { super((Pointer)null); allocate(name); } - private native void allocate(@Cast("char*") String name); - public LogicOp(@Cast("char*") BytePointer name) { super((Pointer)null); allocate(name); } - private native void allocate(@Cast("char*") BytePointer name); - - public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); - } - +/** + * Logic ops are unique snowflakes in any Graph. They dramatically change Graph + * Execution process, by introducing loops, conditions, etc. + * + * Their code is the part of GraphExecutioner logic. But we still want them to + * be expressed via Graph + * \tparam T + */ +@Namespace("sd::ops") public static class LogicOp extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public LogicOp(Pointer p) { super(p); } + public LogicOp(@Cast("char*") String name) { super((Pointer)null); allocate(name); } + private native void allocate(@Cast("char*") String name); + public LogicOp(@Cast("char*") BytePointer name) { super((Pointer)null); allocate(name); } + private native void allocate(@Cast("char*") BytePointer name); + public native ShapeList calculateOutputShape(ShapeList inputShape, + @ByRef Context block); +} + // namespace ops + // namespace sd -// #endif //LIBND4J_LOGICOP_H +// #endif // LIBND4J_LOGICOP_H // Parsed from ops/declarable/OpRegistrator.h @@ -12525,76 +14818,84 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #ifndef LIBND4J_OPREGISTRATOR_H // #define LIBND4J_OPREGISTRATOR_H -// #include -// #include -// #include -// #include +// #include // #include // #include -// #include +// #include + +// #include +// #include +// #include // handlers part -// #include // #include +// #include // #ifndef __JAVACPP_HACK__ // #endif - /** - * This class provides runtime ops lookup, based on opName or opHash. - * To build lookup directory we use *_OP_IMPL macro, which puts static structs at compile time in .cpp files, - * so once binary is executed, static objects are initialized automatically, and we get list of all ops - * available at runtime via this singleton. - * - */ - @Namespace("sd::ops") @NoOffset public static class OpRegistrator extends Pointer { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public OpRegistrator(Pointer p) { super(p); } - +/** + * This class provides runtime ops lookup, based on opName or opHash. + * To build lookup directory we use *_OP_IMPL macro, which puts static structs + * at compile time in .cpp files, so once binary is executed, static objects are + * initialized automatically, and we get list of all ops available at runtime + * via this singleton. + * + */ +@Namespace("sd::ops") @NoOffset public static class OpRegistrator extends Pointer { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public OpRegistrator(Pointer p) { super(p); } - public static native @ByRef OpRegistrator getInstance(); - public static native void exitHandler(); - public static native void sigIntHandler(int sig); - public static native void sigSegVHandler(int sig); + public static native @ByRef OpRegistrator getInstance(); - - public native @Cast("char*") String getAllCustomOperations(); + public static native void exitHandler(); + public static native void sigIntHandler(int sig); + public static native void sigSegVHandler(int sig); - /** - * This method registers operation in our registry, so we can use them later - * - * @param op - */ - public native @Cast("bool") boolean registerOperation(@Cast("char*") String name, DeclarableOp op); - public native @Cast("bool") boolean registerOperation(@Cast("char*") BytePointer name, DeclarableOp op); - public native @Cast("bool") boolean registerOperation(DeclarableOp op); + + public native @Cast("char*") String getAllCustomOperations(); - public native void registerHelper(PlatformHelper op); + /** + * This method registers operation in our registry, so we can use them later + * + * @param op + */ + public native @Cast("bool") boolean registerOperation(@StdString BytePointer opName, + @SharedPtr DeclarableOp op); + public native @Cast("bool") boolean registerOperation(@StdString String opName, + @SharedPtr DeclarableOp op); + public native @Cast("bool") boolean registerOperation(@SharedPtr DeclarableOp op); - public native @Cast("bool") boolean hasHelper(@Cast("Nd4jLong") long hash, @Cast("samediff::Engine") int engine); + public native void registerHelper(PlatformHelper op); - public native DeclarableOp getOperation(@Cast("char*") String name); - public native DeclarableOp getOperation(@Cast("char*") BytePointer name); - public native DeclarableOp getOperation(@Cast("Nd4jLong") long hash); + public native @Cast("bool") boolean hasHelper(@Cast("Nd4jLong") long hash, @Cast("samediff::Engine") int engine); - public native PlatformHelper getPlatformHelper(@Cast("Nd4jLong") long hash, @Cast("samediff::Engine") int engine); + public native @SharedPtr DeclarableOp getOperation(@Cast("Nd4jLong") long hash); + public native @SharedPtr DeclarableOp getOperation(@StdString BytePointer name); + public native @SharedPtr DeclarableOp getOperation(@StdString String name); - public native @Cast("Nd4jLong*") @StdVector LongPointer getAllHashes(); + public native @Cast("bool") boolean hasOperation(@StdString BytePointer opName); + public native @Cast("bool") boolean hasOperation(@StdString String opName); + public native @Cast("bool") boolean hasOperation(@Cast("const Nd4jLong") long opName); - public native int numberOfOperations(); - } + public native PlatformHelper getPlatformHelper( + @Cast("Nd4jLong") long hash, @Cast("samediff::Engine") int engine); + public native @Cast("Nd4jLong*") @StdVector LongPointer getAllHashes(); - /* - * These structs are used to "register" our ops in OpRegistrator. - */ + public native int numberOfOperations(); +} - +/* + * These structs are used to "register" our ops in OpRegistrator. + */ + // namespace ops + // namespace sd -// #endif //LIBND4J_OPREGISTRATOR_H +// #endif // LIBND4J_OPREGISTRATOR_H // Parsed from ops/declarable/CustomOperations.h @@ -12622,156 +14923,156 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #ifndef LIBND4J_CUSTOMOPERATIONS_H // #define LIBND4J_CUSTOMOPERATIONS_H +// #include +// #include +// #include +// #include +// #include +// #include +// #include +// #include +// #include +// #include // #include +// #include +// #include // #include // #include -// #include +// #include // #include +// #include +// #include +// #include +// #include // #include -// #include -// #include -// #include -// #include -// #include +// #include // #include // #include -// #include -// #include -// #include -// #include -// #include -// #include -// #include +// #include +// #include +// #include +// #include // #include -// #include -// #include -// #include -// #include +// #include +// #include +// #include // #include +// #include // #include -// #include -// #include -// #include -// #include -// #include -// #include -// #include -// #include -// #include - @Namespace("sd") public static class _loader extends Pointer { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public _loader(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public _loader(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public _loader position(long position) { - return (_loader)super.position(position); - } - - public _loader() { super((Pointer)null); allocate(); } - private native void allocate(); +@Namespace("sd") public static class _loader extends Pointer { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public _loader(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public _loader(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public _loader position(long position) { + return (_loader)super.position(position); + } + + public _loader() { super((Pointer)null); allocate(); } + private native void allocate(); +} + +// logic ops +@Namespace("sd::ops") public static class Switch extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public Switch(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public Switch(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public Switch position(long position) { + return (Switch)super.position(position); } - // logic ops - @Namespace("sd::ops") public static class Switch extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public Switch(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public Switch(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public Switch position(long position) { - return (Switch)super.position(position); - } - public Switch() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class While extends LogicOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public While(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public While(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public While position(long position) { - return (While)super.position(position); - } - +@Namespace("sd::ops") public static class While extends LogicOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public While(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public While(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public While position(long position) { + return (While)super.position(position); + } + public While() { super((Pointer)null); allocate(); } private native void allocate(); } - @Namespace("sd::ops") public static class Scope extends LogicOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public Scope(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public Scope(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public Scope position(long position) { - return (Scope)super.position(position); - } - +@Namespace("sd::ops") public static class Scope extends LogicOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public Scope(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public Scope(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public Scope position(long position) { + return (Scope)super.position(position); + } + public Scope() { super((Pointer)null); allocate(); } private native void allocate(); } - @Namespace("sd::ops") public static class Conditional extends LogicOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public Conditional(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public Conditional(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public Conditional position(long position) { - return (Conditional)super.position(position); - } - +@Namespace("sd::ops") public static class Conditional extends LogicOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public Conditional(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public Conditional(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public Conditional position(long position) { + return (Conditional)super.position(position); + } + public Conditional() { super((Pointer)null); allocate(); } private native void allocate(); } - @Namespace("sd::ops") public static class Return extends LogicOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public Return(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public Return(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public Return position(long position) { - return (Return)super.position(position); - } - +@Namespace("sd::ops") public static class Return extends LogicOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public Return(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public Return(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public Return position(long position) { + return (Return)super.position(position); + } + public Return() { super((Pointer)null); allocate(); } private native void allocate(); } +/** + * This operations exposes given arguments as it's own outputs, but does it only + * once. Subsequent calls will be served directly by this op. + * + * PLEASE NOTE: This operation is internal graph operation, and shouldn't be + * used directly usually. + */ +@Namespace("sd::ops") public static class expose extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public expose(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public expose(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public expose position(long position) { + return (expose)super.position(position); + } - /** - * This operations exposes given arguments as it's own outputs, but does it only once. - * Subsequent calls will be served directly by this op. - * - * PLEASE NOTE: This operation is internal graph operation, and shouldn't be used directly usually. - */ - @Namespace("sd::ops") public static class expose extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public expose(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public expose(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public expose position(long position) { - return (expose)super.position(position); - } - public expose() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - + // namespace ops + // namespace sd - -// #endif //LIBND4J_CUSTOMOPERATIONS_H +// #endif // LIBND4J_CUSTOMOPERATIONS_H // Parsed from ops/declarable/headers/activations.h @@ -12799,696 +15100,693 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #ifndef LIBND4J_HEADERS_ACTIVATIONS_H // #define LIBND4J_HEADERS_ACTIVATIONS_H - // #include - /** - * This is Sigmoid activation function implementation - * Math is: 1 / 1 + exp(-x) - */ -// #if NOT_EXCLUDED(OP_sigmoid) - @Namespace("sd::ops") public static class sigmoid extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public sigmoid(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public sigmoid(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public sigmoid position(long position) { - return (sigmoid)super.position(position); - } - +/** + * This is Sigmoid activation function implementation + * Math is: 1 / 1 + exp(-x) + */ +// #if NOT_EXCLUDED(OP_sigmoid) +@Namespace("sd::ops") public static class sigmoid extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public sigmoid(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public sigmoid(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public sigmoid position(long position) { + return (sigmoid)super.position(position); + } + public sigmoid() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class sigmoid_bp extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public sigmoid_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public sigmoid_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public sigmoid_bp position(long position) { - return (sigmoid_bp)super.position(position); - } - +@Namespace("sd::ops") public static class sigmoid_bp extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public sigmoid_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public sigmoid_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public sigmoid_bp position(long position) { + return (sigmoid_bp)super.position(position); + } + public sigmoid_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This is Softsign activation function implementation + * Math is: x / 1 + abs(x) + */ +// #if NOT_EXCLUDED(OP_softsign) +@Namespace("sd::ops") public static class softsign extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public softsign(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public softsign(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public softsign position(long position) { + return (softsign)super.position(position); + } - /** - * This is Softsign activation function implementation - * Math is: x / 1 + abs(x) - */ -// #if NOT_EXCLUDED(OP_softsign) - @Namespace("sd::ops") public static class softsign extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public softsign(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public softsign(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public softsign position(long position) { - return (softsign)super.position(position); - } - public softsign() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class softsign_bp extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public softsign_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public softsign_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public softsign_bp position(long position) { - return (softsign_bp)super.position(position); - } - +@Namespace("sd::ops") public static class softsign_bp extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public softsign_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public softsign_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public softsign_bp position(long position) { + return (softsign_bp)super.position(position); + } + public softsign_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This is Tanh activation function implementation + */ +// #if NOT_EXCLUDED(OP_tanh) +@Namespace("sd::ops") public static class tanh extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public tanh(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public tanh(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public tanh position(long position) { + return (tanh)super.position(position); + } - /** - * This is Tanh activation function implementation - */ -// #if NOT_EXCLUDED(OP_tanh) - @Namespace("sd::ops") public static class tanh extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public tanh(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public tanh(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public tanh position(long position) { - return (tanh)super.position(position); - } - public tanh() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class tanh_bp extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public tanh_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public tanh_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public tanh_bp position(long position) { - return (tanh_bp)super.position(position); - } - +@Namespace("sd::ops") public static class tanh_bp extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public tanh_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public tanh_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public tanh_bp position(long position) { + return (tanh_bp)super.position(position); + } + public tanh_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This is Softplus activation function implementation + * Math is: log(1 + exp(x)) + */ +// #if NOT_EXCLUDED(OP_softplus) +@Namespace("sd::ops") public static class softplus extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public softplus(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public softplus(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public softplus position(long position) { + return (softplus)super.position(position); + } - /** - * This is Softplus activation function implementation - * Math is: log(1 + exp(x)) - */ -// #if NOT_EXCLUDED(OP_softplus) - @Namespace("sd::ops") public static class softplus extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public softplus(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public softplus(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public softplus position(long position) { - return (softplus)super.position(position); - } - public softplus() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class softplus_bp extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public softplus_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public softplus_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public softplus_bp position(long position) { - return (softplus_bp)super.position(position); - } - +@Namespace("sd::ops") public static class softplus_bp extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public softplus_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public softplus_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public softplus_bp position(long position) { + return (softplus_bp)super.position(position); + } + public softplus_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This is RELU activation function implementation + */ +// #if NOT_EXCLUDED(OP_relu) +@Namespace("sd::ops") public static class relu extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public relu(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public relu(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public relu position(long position) { + return (relu)super.position(position); + } - /** - * This is RELU activation function implementation - */ -// #if NOT_EXCLUDED(OP_relu) - @Namespace("sd::ops") public static class relu extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public relu(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public relu(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public relu position(long position) { - return (relu)super.position(position); - } - public relu() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class relu_bp extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public relu_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public relu_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public relu_bp position(long position) { - return (relu_bp)super.position(position); - } - +@Namespace("sd::ops") public static class relu_bp extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public relu_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public relu_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public relu_bp position(long position) { + return (relu_bp)super.position(position); + } + public relu_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This is SELU activation function implementation + */ +// #if NOT_EXCLUDED(OP_selu) +@Namespace("sd::ops") public static class selu extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public selu(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public selu(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public selu position(long position) { + return (selu)super.position(position); + } - /** - * This is SELU activation function implementation - */ -// #if NOT_EXCLUDED(OP_selu) - @Namespace("sd::ops") public static class selu extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public selu(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public selu(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public selu position(long position) { - return (selu)super.position(position); - } - public selu() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class selu_bp extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public selu_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public selu_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public selu_bp position(long position) { - return (selu_bp)super.position(position); - } - +@Namespace("sd::ops") public static class selu_bp extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public selu_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public selu_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public selu_bp position(long position) { + return (selu_bp)super.position(position); + } + public selu_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This is Leaky RELU activation function. + * Math is: x < 0 ? alpha * x : x; + */ +// #if NOT_EXCLUDED(OP_lrelu) +@Namespace("sd::ops") public static class lrelu extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public lrelu(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public lrelu(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public lrelu position(long position) { + return (lrelu)super.position(position); + } - /** - * This is Leaky RELU activation function. - * Math is: x < 0 ? alpha * x : x; - */ -// #if NOT_EXCLUDED(OP_lrelu) - @Namespace("sd::ops") public static class lrelu extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public lrelu(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public lrelu(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public lrelu position(long position) { - return (lrelu)super.position(position); - } - public lrelu() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class lrelu_bp extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public lrelu_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public lrelu_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public lrelu_bp position(long position) { - return (lrelu_bp)super.position(position); - } - +@Namespace("sd::ops") public static class lrelu_bp extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public lrelu_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public lrelu_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public lrelu_bp position(long position) { + return (lrelu_bp)super.position(position); + } + public lrelu_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This op is ELU activation function. + * Math is: x >= 0 ? x : exp(x) - 1; + */ +// #if NOT_EXCLUDED(OP_elu) +@Namespace("sd::ops") public static class elu extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public elu(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public elu(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public elu position(long position) { + return (elu)super.position(position); + } - /** - * This op is ELU activation function. - * Math is: x >= 0 ? x : exp(x) - 1; - */ -// #if NOT_EXCLUDED(OP_elu) - @Namespace("sd::ops") public static class elu extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public elu(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public elu(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public elu position(long position) { - return (elu)super.position(position); - } - public elu() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class elu_bp extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public elu_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public elu_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public elu_bp position(long position) { - return (elu_bp)super.position(position); - } - +@Namespace("sd::ops") public static class elu_bp extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public elu_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public elu_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public elu_bp position(long position) { + return (elu_bp)super.position(position); + } + public elu_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This is Cube activation function. + * Math is: x^3 + */ +// #if NOT_EXCLUDED(OP_cube) +@Namespace("sd::ops") public static class cube extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public cube(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public cube(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public cube position(long position) { + return (cube)super.position(position); + } - /** - * This is Cube activation function. - * Math is: x^3 - */ -// #if NOT_EXCLUDED(OP_cube) - @Namespace("sd::ops") public static class cube extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public cube(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public cube(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public cube position(long position) { - return (cube)super.position(position); - } - public cube() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class cube_bp extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public cube_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public cube_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public cube_bp position(long position) { - return (cube_bp)super.position(position); - } - +@Namespace("sd::ops") public static class cube_bp extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public cube_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public cube_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public cube_bp position(long position) { + return (cube_bp)super.position(position); + } + public cube_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This is RectifiedTanh activation function. + * Math is: max(0, tanh(x)) + */ +// #if NOT_EXCLUDED(OP_rectifiedtanh) +@Namespace("sd::ops") public static class rectifiedtanh extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public rectifiedtanh(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public rectifiedtanh(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public rectifiedtanh position(long position) { + return (rectifiedtanh)super.position(position); + } - /** - * This is RectifiedTanh activation function. - * Math is: max(0, tanh(x)) - */ -// #if NOT_EXCLUDED(OP_rectifiedtanh) - @Namespace("sd::ops") public static class rectifiedtanh extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public rectifiedtanh(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public rectifiedtanh(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public rectifiedtanh position(long position) { - return (rectifiedtanh)super.position(position); - } - public rectifiedtanh() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class rectifiedtanh_bp extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public rectifiedtanh_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public rectifiedtanh_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public rectifiedtanh_bp position(long position) { - return (rectifiedtanh_bp)super.position(position); - } - +@Namespace("sd::ops") public static class rectifiedtanh_bp extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public rectifiedtanh_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public rectifiedtanh_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public rectifiedtanh_bp position(long position) { + return (rectifiedtanh_bp)super.position(position); + } + public rectifiedtanh_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This is RationalTanh activation function. + */ +// #if NOT_EXCLUDED(OP_rationaltanh) +@Namespace("sd::ops") public static class rationaltanh extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public rationaltanh(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public rationaltanh(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public rationaltanh position(long position) { + return (rationaltanh)super.position(position); + } - /** - * This is RationalTanh activation function. - */ -// #if NOT_EXCLUDED(OP_rationaltanh) - @Namespace("sd::ops") public static class rationaltanh extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public rationaltanh(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public rationaltanh(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public rationaltanh position(long position) { - return (rationaltanh)super.position(position); - } - public rationaltanh() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class rationaltanh_bp extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public rationaltanh_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public rationaltanh_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public rationaltanh_bp position(long position) { - return (rationaltanh_bp)super.position(position); - } - +@Namespace("sd::ops") public static class rationaltanh_bp extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public rationaltanh_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public rationaltanh_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public rationaltanh_bp position(long position) { + return (rationaltanh_bp)super.position(position); + } + public rationaltanh_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This is HardTanh activation function. + * Math is: x < -1.0 ? -1.0 : x > 1.0 ? 1.0 : x; + */ +// #if NOT_EXCLUDED(OP_hardtanh) +@Namespace("sd::ops") public static class hardtanh extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public hardtanh(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public hardtanh(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public hardtanh position(long position) { + return (hardtanh)super.position(position); + } - /** - * This is HardTanh activation function. - * Math is: x < -1.0 ? -1.0 : x > 1.0 ? 1.0 : x; - */ -// #if NOT_EXCLUDED(OP_hardtanh) - @Namespace("sd::ops") public static class hardtanh extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public hardtanh(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public hardtanh(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public hardtanh position(long position) { - return (hardtanh)super.position(position); - } - public hardtanh() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class hardtanh_bp extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public hardtanh_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public hardtanh_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public hardtanh_bp position(long position) { - return (hardtanh_bp)super.position(position); - } - +@Namespace("sd::ops") public static class hardtanh_bp extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public hardtanh_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public hardtanh_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public hardtanh_bp position(long position) { + return (hardtanh_bp)super.position(position); + } + public hardtanh_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This is HardSigmoid activation function. + * Math is: min(1, max(0, 0.2 * x + 0.5)) + */ +// #if NOT_EXCLUDED(OP_hardsigmoid) +@Namespace("sd::ops") public static class hardsigmoid extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public hardsigmoid(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public hardsigmoid(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public hardsigmoid position(long position) { + return (hardsigmoid)super.position(position); + } - /** - * This is HardSigmoid activation function. - * Math is: min(1, max(0, 0.2 * x + 0.5)) - */ -// #if NOT_EXCLUDED(OP_hardsigmoid) - @Namespace("sd::ops") public static class hardsigmoid extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public hardsigmoid(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public hardsigmoid(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public hardsigmoid position(long position) { - return (hardsigmoid)super.position(position); - } - public hardsigmoid() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class hardsigmoid_bp extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public hardsigmoid_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public hardsigmoid_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public hardsigmoid_bp position(long position) { - return (hardsigmoid_bp)super.position(position); - } - +@Namespace("sd::ops") public static class hardsigmoid_bp extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public hardsigmoid_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public hardsigmoid_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public hardsigmoid_bp position(long position) { + return (hardsigmoid_bp)super.position(position); + } + public hardsigmoid_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This is Indentity operation. It passes signal umodified in both directions. + */ +// #if NOT_EXCLUDED(OP_identity) +@Namespace("sd::ops") public static class identity extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public identity(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public identity(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public identity position(long position) { + return (identity)super.position(position); + } - /** - * This is Indentity operation. It passes signal umodified in both directions. - */ -// #if NOT_EXCLUDED(OP_identity) - @Namespace("sd::ops") public static class identity extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public identity(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public identity(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public identity position(long position) { - return (identity)super.position(position); - } - public identity() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class identity_bp extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public identity_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public identity_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public identity_bp position(long position) { - return (identity_bp)super.position(position); - } - +@Namespace("sd::ops") public static class identity_bp extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public identity_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public identity_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public identity_bp position(long position) { + return (identity_bp)super.position(position); + } + public identity_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This is Indentity operation. It passes signal umodified in both directions. + */ +// #if NOT_EXCLUDED(OP_identity_n) +@Namespace("sd::ops") public static class identity_n extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public identity_n(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public identity_n(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public identity_n position(long position) { + return (identity_n)super.position(position); + } - /** - * This is Indentity operation. It passes signal umodified in both directions. - */ -// #if NOT_EXCLUDED(OP_identity_n) - @Namespace("sd::ops") public static class identity_n extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public identity_n(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public identity_n(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public identity_n position(long position) { - return (identity_n)super.position(position); - } - public identity_n() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This is Concatenated RELU implementation. + * What happens inside: RELU(Concat((x, -x, {-1}))) + * + * PLEASE NOTE: Concatenation will double amount of features available in input + */ +// #if NOT_EXCLUDED(OP_crelu) +@Namespace("sd::ops") public static class crelu extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public crelu(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public crelu(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public crelu position(long position) { + return (crelu)super.position(position); + } - /** - * This is Concatenated RELU implementation. - * What happens inside: RELU(Concat((x, -x, {-1}))) - * - * PLEASE NOTE: Concatenation will double amount of features available in input - */ -// #if NOT_EXCLUDED(OP_crelu) - @Namespace("sd::ops") public static class crelu extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public crelu(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public crelu(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public crelu position(long position) { - return (crelu)super.position(position); - } - public crelu() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class crelu_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public crelu_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public crelu_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public crelu_bp position(long position) { - return (crelu_bp)super.position(position); - } - +@Namespace("sd::ops") public static class crelu_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public crelu_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public crelu_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public crelu_bp position(long position) { + return (crelu_bp)super.position(position); + } + public crelu_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This is RELU6 activation function implementation + */ +// #if NOT_EXCLUDED(OP_relu6) +@Namespace("sd::ops") public static class relu6 extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public relu6(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public relu6(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public relu6 position(long position) { + return (relu6)super.position(position); + } - /** - * This is RELU6 activation function implementation - */ -// #if NOT_EXCLUDED(OP_relu6) - @Namespace("sd::ops") public static class relu6 extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public relu6(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public relu6(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public relu6 position(long position) { - return (relu6)super.position(position); - } - public relu6() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class relu6_bp extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public relu6_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public relu6_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public relu6_bp position(long position) { - return (relu6_bp)super.position(position); - } - +@Namespace("sd::ops") public static class relu6_bp extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public relu6_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public relu6_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public relu6_bp position(long position) { + return (relu6_bp)super.position(position); + } + public relu6_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif +/** + * Parametric Rectified Linear Unit + * f(x) = alpha * x for x < 0, f(x) = x for x >= 0 + */ +// #if NOT_EXCLUDED(OP_prelu) +@Namespace("sd::ops") public static class prelu extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public prelu(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public prelu(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public prelu position(long position) { + return (prelu)super.position(position); + } - /** - * Parametric Rectified Linear Unit - * f(x) = alpha * x for x < 0, f(x) = x for x >= 0 - */ -// #if NOT_EXCLUDED(OP_prelu) - @Namespace("sd::ops") public static class prelu extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public prelu(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public prelu(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public prelu position(long position) { - return (prelu)super.position(position); - } - public prelu() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class prelu_bp extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public prelu_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public prelu_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public prelu_bp position(long position) { - return (prelu_bp)super.position(position); - } - +@Namespace("sd::ops") public static class prelu_bp extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public prelu_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public prelu_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public prelu_bp position(long position) { + return (prelu_bp)super.position(position); + } + public prelu_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * Thresholded Rectified Linear Unit + * f(x) = x for x > theta, f(x) = 0 otherwise + * theta must be >= 0 + */ +// #if NOT_EXCLUDED(OP_thresholdedrelu) +@Namespace("sd::ops") public static class thresholdedrelu extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public thresholdedrelu(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public thresholdedrelu(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public thresholdedrelu position(long position) { + return (thresholdedrelu)super.position(position); + } - /** - * Thresholded Rectified Linear Unit - * f(x) = x for x > theta, f(x) = 0 otherwise - * theta must be >= 0 - */ -// #if NOT_EXCLUDED(OP_thresholdedrelu) - @Namespace("sd::ops") public static class thresholdedrelu extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public thresholdedrelu(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public thresholdedrelu(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public thresholdedrelu position(long position) { - return (thresholdedrelu)super.position(position); - } - public thresholdedrelu() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class thresholdedrelu_bp extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public thresholdedrelu_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public thresholdedrelu_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public thresholdedrelu_bp position(long position) { - return (thresholdedrelu_bp)super.position(position); - } - +@Namespace("sd::ops") public static class thresholdedrelu_bp extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public thresholdedrelu_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public thresholdedrelu_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public thresholdedrelu_bp position(long position) { + return (thresholdedrelu_bp)super.position(position); + } + public thresholdedrelu_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif - - - +// #endif + // namespace ops + // namespace sd // #endif @@ -13519,319 +15817,323 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #include - /** - * This is scalar boolean op. - * Both operands should be scalars. - * - * Returns true if x < y - */ -// #if NOT_EXCLUDED(OP_lt_scalar) - @Namespace("sd::ops") public static class lt_scalar extends BooleanOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public lt_scalar(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public lt_scalar(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public lt_scalar position(long position) { - return (lt_scalar)super.position(position); - } - +/** + * This is scalar boolean op. + * Both operands should be scalars. + * + * Returns true if x < y + */ +// #if NOT_EXCLUDED(OP_lt_scalar) +@Namespace("sd::ops") public static class lt_scalar extends BooleanOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public lt_scalar(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public lt_scalar(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public lt_scalar position(long position) { + return (lt_scalar)super.position(position); + } + public lt_scalar() { super((Pointer)null); allocate(); } private native void allocate(); } -// #endif +// #endif + +/** + * This is scalar boolean op. + * Both operands should be scalars. + * + * Returns true if x > y + */ +// #if NOT_EXCLUDED(OP_gt_scalar) +@Namespace("sd::ops") public static class gt_scalar extends BooleanOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public gt_scalar(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public gt_scalar(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public gt_scalar position(long position) { + return (gt_scalar)super.position(position); + } - /** - * This is scalar boolean op. - * Both operands should be scalars. - * - * Returns true if x > y - */ -// #if NOT_EXCLUDED(OP_gt_scalar) - @Namespace("sd::ops") public static class gt_scalar extends BooleanOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public gt_scalar(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public gt_scalar(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public gt_scalar position(long position) { - return (gt_scalar)super.position(position); - } - public gt_scalar() { super((Pointer)null); allocate(); } private native void allocate(); } -// #endif +// #endif + +/** + * This is scalar boolean op. + * Both operands should be scalars. + * + * Returns true if x <= y + */ +// #if NOT_EXCLUDED(OP_lte_scalar) +@Namespace("sd::ops") public static class lte_scalar extends BooleanOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public lte_scalar(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public lte_scalar(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public lte_scalar position(long position) { + return (lte_scalar)super.position(position); + } - /** - * This is scalar boolean op. - * Both operands should be scalars. - * - * Returns true if x <= y - */ -// #if NOT_EXCLUDED(OP_lte_scalar) - @Namespace("sd::ops") public static class lte_scalar extends BooleanOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public lte_scalar(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public lte_scalar(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public lte_scalar position(long position) { - return (lte_scalar)super.position(position); - } - public lte_scalar() { super((Pointer)null); allocate(); } private native void allocate(); } -// #endif +// #endif + +/** + * This is scalar boolean op. + * Both operands should be scalars. + * + * Returns true if x >= y + */ +// #if NOT_EXCLUDED(OP_gte_scalar) +@Namespace("sd::ops") public static class gte_scalar extends BooleanOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public gte_scalar(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public gte_scalar(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public gte_scalar position(long position) { + return (gte_scalar)super.position(position); + } - /** - * This is scalar boolean op. - * Both operands should be scalars. - * - * Returns true if x >= y - */ -// #if NOT_EXCLUDED(OP_gte_scalar) - @Namespace("sd::ops") public static class gte_scalar extends BooleanOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public gte_scalar(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public gte_scalar(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public gte_scalar position(long position) { - return (gte_scalar)super.position(position); - } - public gte_scalar() { super((Pointer)null); allocate(); } private native void allocate(); } -// #endif +// #endif + +/** + * This is scalar boolean op. + * Both operands should be scalars. + * + * Returns true if both operands are equal. + */ +// #if NOT_EXCLUDED(OP_eq_scalar) +@Namespace("sd::ops") public static class eq_scalar extends BooleanOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public eq_scalar(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public eq_scalar(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public eq_scalar position(long position) { + return (eq_scalar)super.position(position); + } - /** - * This is scalar boolean op. - * Both operands should be scalars. - * - * Returns true if both operands are equal. - */ -// #if NOT_EXCLUDED(OP_eq_scalar) - @Namespace("sd::ops") public static class eq_scalar extends BooleanOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public eq_scalar(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public eq_scalar(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public eq_scalar position(long position) { - return (eq_scalar)super.position(position); - } - public eq_scalar() { super((Pointer)null); allocate(); } private native void allocate(); } -// #endif +// #endif + +/** + * This is scalar boolean op. + * Both operands should be scalars. + * + * Returns true if x != y + */ +// #if NOT_EXCLUDED(OP_neq_scalar) +@Namespace("sd::ops") public static class neq_scalar extends BooleanOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public neq_scalar(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public neq_scalar(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public neq_scalar position(long position) { + return (neq_scalar)super.position(position); + } - /** - * This is scalar boolean op. - * Both operands should be scalars. - * - * Returns true if x != y - */ -// #if NOT_EXCLUDED(OP_neq_scalar) - @Namespace("sd::ops") public static class neq_scalar extends BooleanOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public neq_scalar(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public neq_scalar(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public neq_scalar position(long position) { - return (neq_scalar)super.position(position); - } - public neq_scalar() { super((Pointer)null); allocate(); } private native void allocate(); } -// #endif +// #endif + +/** + * This op takes 2 n-dimensional arrays as input, and return + * array of the same shape, with elements, either from x or y, depending on the + * condition. + */ +// #if NOT_EXCLUDED(OP_where) +@Namespace("sd::ops") public static class Where extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public Where(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public Where(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public Where position(long position) { + return (Where)super.position(position); + } - /** - * This op takes 2 n-dimensional arrays as input, and return - * array of the same shape, with elements, either from x or y, depending on the condition. - */ -// #if NOT_EXCLUDED(OP_where) - @Namespace("sd::ops") public static class Where extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public Where(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public Where(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public Where position(long position) { - return (Where)super.position(position); - } - public Where() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_where_np) +@Namespace("sd::ops") public static class where_np extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public where_np(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public where_np(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public where_np position(long position) { + return (where_np)super.position(position); + } -// #if NOT_EXCLUDED(OP_where_np) - @Namespace("sd::ops") public static class where_np extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public where_np(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public where_np(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public where_np position(long position) { - return (where_np)super.position(position); - } - public where_np() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This op takes 2 n-dimensional arrays as input, and return + * array of the same shape, with elements, either from x or y, depending on the + * condition. + */ +// #if NOT_EXCLUDED(OP_select) +@Namespace("sd::ops") public static class select extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public select(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public select(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public select position(long position) { + return (select)super.position(position); + } - /** - * This op takes 2 n-dimensional arrays as input, and return - * array of the same shape, with elements, either from x or y, depending on the condition. - */ -// #if NOT_EXCLUDED(OP_select) - @Namespace("sd::ops") public static class select extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public select(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public select(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public select position(long position) { - return (select)super.position(position); - } - public select() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This op takes either 1 argument and 1 scalar + * or 1 argument and another comparison array + * and runs a pre defined conditional op. + * + * The output of the op is dynamic in size and returns a flat vector of + * elements that return true on the given condition. In numpy parlance, most + * people might understand: a[a > 2] where a is a numpy array and the condition + * is true when an element is > 2. Libnd4j already implements a number of pre + * defined conditions. + * \tparam T + */ +// #if NOT_EXCLUDED(OP_choose) +@Namespace("sd::ops") public static class choose extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public choose(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public choose(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public choose position(long position) { + return (choose)super.position(position); + } - /** - * This op takes either 1 argument and 1 scalar - * or 1 argument and another comparison array - * and runs a pre defined conditional op. - * - * The output of the op is dynamic in size and returns a flat vector of elements - * that return true on the given condition. - * In numpy parlance, most people might understand: - * a[a > 2] - * where a is a numpy array and the condition is true when an element is - * > 2. Libnd4j already implements a number of pre defined conditions. - * \tparam T - */ -// #if NOT_EXCLUDED(OP_choose) - @Namespace("sd::ops") public static class choose extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public choose(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public choose(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public choose position(long position) { - return (choose)super.position(position); - } - public choose() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This op takes 1 n-dimensional array as input, and returns true if for every + * adjacent pair we have x[i] <= x[i+1]. + */ +// #if NOT_EXCLUDED(OP_is_non_decreasing) +@Namespace("sd::ops") public static class is_non_decreasing extends BooleanOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public is_non_decreasing(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public is_non_decreasing(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public is_non_decreasing position(long position) { + return (is_non_decreasing)super.position(position); + } - /** - * This op takes 1 n-dimensional array as input, and returns true if for every adjacent pair we have x[i] <= x[i+1]. - */ -// #if NOT_EXCLUDED(OP_is_non_decreasing) - @Namespace("sd::ops") public static class is_non_decreasing extends BooleanOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public is_non_decreasing(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public is_non_decreasing(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public is_non_decreasing position(long position) { - return (is_non_decreasing)super.position(position); - } - public is_non_decreasing() { super((Pointer)null); allocate(); } private native void allocate(); } -// #endif +// #endif + +/** + * This op takes 1 n-dimensional array as input, and returns true if for every + * adjacent pair we have x[i] < x[i+1]. + */ +// #if NOT_EXCLUDED(OP_is_strictly_increasing) +@Namespace("sd::ops") public static class is_strictly_increasing extends BooleanOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public is_strictly_increasing(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public is_strictly_increasing(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public is_strictly_increasing position(long position) { + return (is_strictly_increasing)super.position(position); + } - /** - * This op takes 1 n-dimensional array as input, and returns true if for every adjacent pair we have x[i] < x[i+1]. - */ -// #if NOT_EXCLUDED(OP_is_strictly_increasing) - @Namespace("sd::ops") public static class is_strictly_increasing extends BooleanOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public is_strictly_increasing(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public is_strictly_increasing(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public is_strictly_increasing position(long position) { - return (is_strictly_increasing)super.position(position); - } - public is_strictly_increasing() { super((Pointer)null); allocate(); } private native void allocate(); } -// #endif +// #endif + +/** + * This op takes 1 n-dimensional array as input, and returns true if input is a + * numeric array. + */ +// #if NOT_EXCLUDED(OP_is_numeric_tensor) +@Namespace("sd::ops") public static class is_numeric_tensor extends BooleanOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public is_numeric_tensor(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public is_numeric_tensor(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public is_numeric_tensor position(long position) { + return (is_numeric_tensor)super.position(position); + } - /** - * This op takes 1 n-dimensional array as input, and returns true if input is a numeric array. - */ -// #if NOT_EXCLUDED(OP_is_numeric_tensor) - @Namespace("sd::ops") public static class is_numeric_tensor extends BooleanOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public is_numeric_tensor(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public is_numeric_tensor(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public is_numeric_tensor position(long position) { - return (is_numeric_tensor)super.position(position); - } - public is_numeric_tensor() { super((Pointer)null); allocate(); } private native void allocate(); } -// #endif +// #endif + +/** + * + */ +// #if NOT_EXCLUDED(OP_boolean_not) +@Namespace("sd::ops") public static class boolean_not extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public boolean_not(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public boolean_not(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public boolean_not position(long position) { + return (boolean_not)super.position(position); + } - /** - * - */ -// #if NOT_EXCLUDED(OP_boolean_not) - @Namespace("sd::ops") public static class boolean_not extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public boolean_not(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public boolean_not(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public boolean_not position(long position) { - return (boolean_not)super.position(position); - } - public boolean_not() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif - - +// #endif + // namespace ops + // namespace sd // #endif @@ -13860,280 +16162,293 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #ifndef LIBND4J_HEADERS_BROADCASTABLE_H // #define LIBND4J_HEADERS_BROADCASTABLE_H -// #include // #include -// #include +// #include // #include - // TODO: make broadcastables separate class +// #include +// TODO: make broadcastables separate class + +/** + * This is one of auto-broadcastable operations. It accepts 2 operands, and + * operation is applied based on their shapes: 1) if shapes are equal that's + * pairwise operation, result will have the same shape. 2) if shape X is scalar + * and shape Y is array - result will have shape equal to Y. 3) if shape X is + * array and shape Y is scalar - result will have shape equal to X. 4) if shape + * X and Y are both arrays, but shapes aren't equal - result shape will be + * broadcast result. + * + * This operation returns Z = Max(X, Y) + */ +// #if NOT_EXCLUDED(OP_maximum) +@Namespace("sd::ops") public static class maximum extends BroadcastableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public maximum(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public maximum(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public maximum position(long position) { + return (maximum)super.position(position); + } - /** - * This is one of auto-broadcastable operations. It accepts 2 operands, and operation is applied based on their shapes: - * 1) if shapes are equal that's pairwise operation, result will have the same shape. - * 2) if shape X is scalar and shape Y is array - result will have shape equal to Y. - * 3) if shape X is array and shape Y is scalar - result will have shape equal to X. - * 4) if shape X and Y are both arrays, but shapes aren't equal - result shape will be broadcast result. - * - * This operation returns Z = Max(X, Y) - */ -// #if NOT_EXCLUDED(OP_maximum) - @Namespace("sd::ops") public static class maximum extends BroadcastableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public maximum(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public maximum(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public maximum position(long position) { - return (maximum)super.position(position); - } - public maximum() { super((Pointer)null); allocate(); } private native void allocate(); } - @Namespace("sd::ops") public static class maximum_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public maximum_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public maximum_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public maximum_bp position(long position) { - return (maximum_bp)super.position(position); - } - +@Namespace("sd::ops") public static class maximum_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public maximum_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public maximum_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public maximum_bp position(long position) { + return (maximum_bp)super.position(position); + } + public maximum_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This is one of auto-broadcastable operations. It accepts 2 operands, and + * operation is applied based on their shapes: 1) if shapes are equal that's + * pairwise operation, result will have the same shape. 2) if shape X is scalar + * and shape Y is array - result will have shape equal to Y. 3) if shape X is + * array and shape Y is scalar - result will have shape equal to X. 4) if shape + * X and Y are both arrays, but shapes aren't equal - result shape will be + * broadcast result. + * + * This operation returns Z = Min(X, Y) + */ +// #if NOT_EXCLUDED(OP_minimum) +@Namespace("sd::ops") public static class minimum extends BroadcastableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public minimum(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public minimum(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public minimum position(long position) { + return (minimum)super.position(position); + } - /** - * This is one of auto-broadcastable operations. It accepts 2 operands, and operation is applied based on their shapes: - * 1) if shapes are equal that's pairwise operation, result will have the same shape. - * 2) if shape X is scalar and shape Y is array - result will have shape equal to Y. - * 3) if shape X is array and shape Y is scalar - result will have shape equal to X. - * 4) if shape X and Y are both arrays, but shapes aren't equal - result shape will be broadcast result. - * - * This operation returns Z = Min(X, Y) - */ -// #if NOT_EXCLUDED(OP_minimum) - @Namespace("sd::ops") public static class minimum extends BroadcastableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public minimum(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public minimum(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public minimum position(long position) { - return (minimum)super.position(position); - } - public minimum() { super((Pointer)null); allocate(); } private native void allocate(); } - @Namespace("sd::ops") public static class minimum_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public minimum_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public minimum_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public minimum_bp position(long position) { - return (minimum_bp)super.position(position); - } - +@Namespace("sd::ops") public static class minimum_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public minimum_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public minimum_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public minimum_bp position(long position) { + return (minimum_bp)super.position(position); + } + public minimum_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This is one of auto-broadcastable operations. It accepts 2 operands, and + * operation is applied based on their shapes: 1) if shapes are equal that's + * pairwise operation, result will have the same shape. 2) if shape X is scalar + * and shape Y is array - result will have shape equal to Y. 3) if shape X is + * array and shape Y is scalar - result will have shape equal to X. 4) if shape + * X and Y are both arrays, but shapes aren't equal - result shape will be + * broadcast result. + * + * This operation returns Z = Add(X, Y) + */ +// #if NOT_EXCLUDED(OP_add) +@Namespace("sd::ops") public static class add extends BroadcastableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public add(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public add(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public add position(long position) { + return (add)super.position(position); + } - /** - * This is one of auto-broadcastable operations. It accepts 2 operands, and operation is applied based on their shapes: - * 1) if shapes are equal that's pairwise operation, result will have the same shape. - * 2) if shape X is scalar and shape Y is array - result will have shape equal to Y. - * 3) if shape X is array and shape Y is scalar - result will have shape equal to X. - * 4) if shape X and Y are both arrays, but shapes aren't equal - result shape will be broadcast result. - * - * This operation returns Z = Add(X, Y) - */ -// #if NOT_EXCLUDED(OP_add) - @Namespace("sd::ops") public static class add extends BroadcastableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public add(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public add(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public add position(long position) { - return (add)super.position(position); - } - public add() { super((Pointer)null); allocate(); } private native void allocate(); } - @Namespace("sd::ops") public static class add_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public add_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public add_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public add_bp position(long position) { - return (add_bp)super.position(position); - } - +@Namespace("sd::ops") public static class add_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public add_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public add_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public add_bp position(long position) { + return (add_bp)super.position(position); + } + public add_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This is one of auto-broadcastable operations. It accepts 2 operands, and + * operation is applied based on their shapes: 1) if shapes are equal that's + * pairwise operation, result will have the same shape. 2) if shape X is scalar + * and shape Y is array - result will have shape equal to Y. 3) if shape X is + * array and shape Y is scalar - result will have shape equal to X. 4) if shape + * X and Y are both arrays, but shapes aren't equal - result shape will be + * broadcast result. + * + * This operation returns Z = Subtract(X, Y) + */ +// #if NOT_EXCLUDED(OP_subtract) +@Namespace("sd::ops") public static class subtract extends BroadcastableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public subtract(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public subtract(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public subtract position(long position) { + return (subtract)super.position(position); + } - /** - * This is one of auto-broadcastable operations. It accepts 2 operands, and operation is applied based on their shapes: - * 1) if shapes are equal that's pairwise operation, result will have the same shape. - * 2) if shape X is scalar and shape Y is array - result will have shape equal to Y. - * 3) if shape X is array and shape Y is scalar - result will have shape equal to X. - * 4) if shape X and Y are both arrays, but shapes aren't equal - result shape will be broadcast result. - * - * This operation returns Z = Subtract(X, Y) - */ -// #if NOT_EXCLUDED(OP_subtract) - @Namespace("sd::ops") public static class subtract extends BroadcastableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public subtract(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public subtract(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public subtract position(long position) { - return (subtract)super.position(position); - } - public subtract() { super((Pointer)null); allocate(); } private native void allocate(); } - @Namespace("sd::ops") public static class subtract_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public subtract_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public subtract_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public subtract_bp position(long position) { - return (subtract_bp)super.position(position); - } - +@Namespace("sd::ops") public static class subtract_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public subtract_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public subtract_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public subtract_bp position(long position) { + return (subtract_bp)super.position(position); + } + public subtract_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This is one of auto-broadcastable operations. It accepts 2 operands, and + * operation is applied based on their shapes: 1) if shapes are equal that's + * pairwise operation, result will have the same shape. 2) if shape X is scalar + * and shape Y is array - result will have shape equal to Y. 3) if shape X is + * array and shape Y is scalar - result will have shape equal to X. 4) if shape + * X and Y are both arrays, but shapes aren't equal - result shape will be + * broadcast result. + * + * This operation returns Z = Subtract(Y, X) + */ +// #if NOT_EXCLUDED(OP_reversesubtract) +@Namespace("sd::ops") public static class reversesubtract extends BroadcastableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public reversesubtract(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public reversesubtract(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public reversesubtract position(long position) { + return (reversesubtract)super.position(position); + } - /** - * This is one of auto-broadcastable operations. It accepts 2 operands, and operation is applied based on their shapes: - * 1) if shapes are equal that's pairwise operation, result will have the same shape. - * 2) if shape X is scalar and shape Y is array - result will have shape equal to Y. - * 3) if shape X is array and shape Y is scalar - result will have shape equal to X. - * 4) if shape X and Y are both arrays, but shapes aren't equal - result shape will be broadcast result. - * - * This operation returns Z = Subtract(Y, X) - */ -// #if NOT_EXCLUDED(OP_reversesubtract) - @Namespace("sd::ops") public static class reversesubtract extends BroadcastableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public reversesubtract(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public reversesubtract(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public reversesubtract position(long position) { - return (reversesubtract)super.position(position); - } - public reversesubtract() { super((Pointer)null); allocate(); } private native void allocate(); } - @Namespace("sd::ops") public static class reversesubtract_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public reversesubtract_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public reversesubtract_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public reversesubtract_bp position(long position) { - return (reversesubtract_bp)super.position(position); - } - +@Namespace("sd::ops") public static class reversesubtract_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public reversesubtract_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public reversesubtract_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public reversesubtract_bp position(long position) { + return (reversesubtract_bp)super.position(position); + } + public reversesubtract_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This is one of auto-broadcastable operations. It accepts 2 operands, and + * operation is applied based on their shapes: 1) if shapes are equal that's + * pairwise operation, result will have the same shape. 2) if shape X is scalar + * and shape Y is array - result will have shape equal to Y. 3) if shape X is + * array and shape Y is scalar - result will have shape equal to X. 4) if shape + * X and Y are both arrays, but shapes aren't equal - result shape will be + * broadcast result. + * + * This operation returns Z = ReverseMod(X, Y) == Mod(Y, X) + */ +// #if NOT_EXCLUDED(OP_reversemod) +@Namespace("sd::ops") public static class reversemod extends BroadcastableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public reversemod(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public reversemod(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public reversemod position(long position) { + return (reversemod)super.position(position); + } - /** - * This is one of auto-broadcastable operations. It accepts 2 operands, and operation is applied based on their shapes: - * 1) if shapes are equal that's pairwise operation, result will have the same shape. - * 2) if shape X is scalar and shape Y is array - result will have shape equal to Y. - * 3) if shape X is array and shape Y is scalar - result will have shape equal to X. - * 4) if shape X and Y are both arrays, but shapes aren't equal - result shape will be broadcast result. - * - * This operation returns Z = ReverseMod(X, Y) == Mod(Y, X) - */ -// #if NOT_EXCLUDED(OP_reversemod) - @Namespace("sd::ops") public static class reversemod extends BroadcastableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public reversemod(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public reversemod(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public reversemod position(long position) { - return (reversemod)super.position(position); - } - public reversemod() { super((Pointer)null); allocate(); } private native void allocate(); } - @Namespace("sd::ops") public static class reversemod_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public reversemod_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public reversemod_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public reversemod_bp position(long position) { - return (reversemod_bp)super.position(position); - } - +@Namespace("sd::ops") public static class reversemod_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public reversemod_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public reversemod_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public reversemod_bp position(long position) { + return (reversemod_bp)super.position(position); + } + public reversemod_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif +/** + * This is one of auto-broadcastable operations. It accepts 2 operands, and + * operation is applied based on their shapes: 1) if shapes are equal that's + * pairwise operation, result will have the same shape. 2) if shape X is scalar + * and shape Y is array - result will have shape equal to Y. 3) if shape X is + * array and shape Y is scalar - result will have shape equal to X. 4) if shape + * X and Y are both arrays, but shapes aren't equal - result shape will be + * broadcast result. + * + * This operation returns Z = Subtract(X, Y) * Subtract(X, Y) + */ +// #if NOT_EXCLUDED(OP_squaredsubtract) +@Namespace("sd::ops") public static class squaredsubtract extends BroadcastableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public squaredsubtract(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public squaredsubtract(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public squaredsubtract position(long position) { + return (squaredsubtract)super.position(position); + } - /** - * This is one of auto-broadcastable operations. It accepts 2 operands, and operation is applied based on their shapes: - * 1) if shapes are equal that's pairwise operation, result will have the same shape. - * 2) if shape X is scalar and shape Y is array - result will have shape equal to Y. - * 3) if shape X is array and shape Y is scalar - result will have shape equal to X. - * 4) if shape X and Y are both arrays, but shapes aren't equal - result shape will be broadcast result. - * - * This operation returns Z = Subtract(X, Y) * Subtract(X, Y) - */ -// #if NOT_EXCLUDED(OP_squaredsubtract) - @Namespace("sd::ops") public static class squaredsubtract extends BroadcastableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public squaredsubtract(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public squaredsubtract(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public squaredsubtract position(long position) { - return (squaredsubtract)super.position(position); - } - public squaredsubtract() { super((Pointer)null); allocate(); } private native void allocate(); } @@ -14152,250 +16467,262 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This is one of auto-broadcastable operations. It accepts 2 operands, and + * operation is applied based on their shapes: 1) if shapes are equal that's + * pairwise operation, result will have the same shape. 2) if shape X is scalar + * and shape Y is array - result will have shape equal to Y. 3) if shape X is + * array and shape Y is scalar - result will have shape equal to X. 4) if shape + * X and Y are both arrays, but shapes aren't equal - result shape will be + * broadcast result. + * + * This operation returns Z = Multiply(X, Y) + */ +// #if NOT_EXCLUDED(OP_multiply) +@Namespace("sd::ops") public static class multiply extends BroadcastableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public multiply(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public multiply(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public multiply position(long position) { + return (multiply)super.position(position); + } - /** - * This is one of auto-broadcastable operations. It accepts 2 operands, and operation is applied based on their shapes: - * 1) if shapes are equal that's pairwise operation, result will have the same shape. - * 2) if shape X is scalar and shape Y is array - result will have shape equal to Y. - * 3) if shape X is array and shape Y is scalar - result will have shape equal to X. - * 4) if shape X and Y are both arrays, but shapes aren't equal - result shape will be broadcast result. - * - * This operation returns Z = Multiply(X, Y) - */ -// #if NOT_EXCLUDED(OP_multiply) - @Namespace("sd::ops") public static class multiply extends BroadcastableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public multiply(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public multiply(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public multiply position(long position) { - return (multiply)super.position(position); - } - public multiply() { super((Pointer)null); allocate(); } private native void allocate(); } - @Namespace("sd::ops") public static class multiply_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public multiply_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public multiply_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public multiply_bp position(long position) { - return (multiply_bp)super.position(position); - } - +@Namespace("sd::ops") public static class multiply_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public multiply_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public multiply_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public multiply_bp position(long position) { + return (multiply_bp)super.position(position); + } + public multiply_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This is one of auto-broadcastable operations. It accepts 2 operands, and + * operation is applied based on their shapes: 1) if shapes are equal that's + * pairwise operation, result will have the same shape. 2) if shape X is scalar + * and shape Y is array - result will have shape equal to Y. 3) if shape X is + * array and shape Y is scalar - result will have shape equal to X. 4) if shape + * X and Y are both arrays, but shapes aren't equal - result shape will be + * broadcast result. + * + * This operation returns Z = Divide(X, Y) + */ +// #if NOT_EXCLUDED(OP_divide) +@Namespace("sd::ops") public static class divide extends BroadcastableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public divide(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public divide(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public divide position(long position) { + return (divide)super.position(position); + } - /** - * This is one of auto-broadcastable operations. It accepts 2 operands, and operation is applied based on their shapes: - * 1) if shapes are equal that's pairwise operation, result will have the same shape. - * 2) if shape X is scalar and shape Y is array - result will have shape equal to Y. - * 3) if shape X is array and shape Y is scalar - result will have shape equal to X. - * 4) if shape X and Y are both arrays, but shapes aren't equal - result shape will be broadcast result. - * - * This operation returns Z = Divide(X, Y) - */ -// #if NOT_EXCLUDED(OP_divide) - @Namespace("sd::ops") public static class divide extends BroadcastableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public divide(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public divide(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public divide position(long position) { - return (divide)super.position(position); - } - public divide() { super((Pointer)null); allocate(); } private native void allocate(); } - @Namespace("sd::ops") public static class divide_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public divide_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public divide_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public divide_bp position(long position) { - return (divide_bp)super.position(position); - } - +@Namespace("sd::ops") public static class divide_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public divide_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public divide_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public divide_bp position(long position) { + return (divide_bp)super.position(position); + } + public divide_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This is one of auto-broadcastable operations. It accepts 2 operands, and + * operation is applied based on their shapes: 1) if shapes are equal that's + * pairwise operation, result will have the same shape. 2) if shape X is scalar + * and shape Y is array - result will have shape equal to Y. 3) if shape X is + * array and shape Y is scalar - result will have shape equal to X. 4) if shape + * X and Y are both arrays, but shapes aren't equal - result shape will be + * broadcast result. + * + * This operation returns Z = Divide(X, Y) with exception, 0 if Y = 0 + */ +// #if NOT_EXCLUDED(OP_divide_no_nan) +@Namespace("sd::ops") public static class divide_no_nan extends BroadcastableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public divide_no_nan(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public divide_no_nan(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public divide_no_nan position(long position) { + return (divide_no_nan)super.position(position); + } - /** - * This is one of auto-broadcastable operations. It accepts 2 operands, and operation is applied based on their shapes: - * 1) if shapes are equal that's pairwise operation, result will have the same shape. - * 2) if shape X is scalar and shape Y is array - result will have shape equal to Y. - * 3) if shape X is array and shape Y is scalar - result will have shape equal to X. - * 4) if shape X and Y are both arrays, but shapes aren't equal - result shape will be broadcast result. - * - * This operation returns Z = Divide(X, Y) with exception, 0 if Y = 0 - */ -// #if NOT_EXCLUDED(OP_divide_no_nan) - @Namespace("sd::ops") public static class divide_no_nan extends BroadcastableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public divide_no_nan(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public divide_no_nan(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public divide_no_nan position(long position) { - return (divide_no_nan)super.position(position); - } - public divide_no_nan() { super((Pointer)null); allocate(); } private native void allocate(); } -// #endif - /** - * This is one of auto-broadcastable operations. It accepts 2 operands, and operation is applied based on their shapes: - * 1) if shapes are equal that's pairwise operation, result will have the same shape. - * 2) if shape X is scalar and shape Y is array - result will have shape equal to Y. - * 3) if shape X is array and shape Y is scalar - result will have shape equal to X. - * 4) if shape X and Y are both arrays, but shapes aren't equal - result shape will be broadcast result. - * - * This operation returns Z = Divide(Y, x) - */ -// #if NOT_EXCLUDED(OP_reversedivide) - @Namespace("sd::ops") public static class reversedivide extends BroadcastableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public reversedivide(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public reversedivide(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public reversedivide position(long position) { - return (reversedivide)super.position(position); - } - +// #endif +/** + * This is one of auto-broadcastable operations. It accepts 2 operands, and + * operation is applied based on their shapes: 1) if shapes are equal that's + * pairwise operation, result will have the same shape. 2) if shape X is scalar + * and shape Y is array - result will have shape equal to Y. 3) if shape X is + * array and shape Y is scalar - result will have shape equal to X. 4) if shape + * X and Y are both arrays, but shapes aren't equal - result shape will be + * broadcast result. + * + * This operation returns Z = Divide(Y, x) + */ +// #if NOT_EXCLUDED(OP_reversedivide) +@Namespace("sd::ops") public static class reversedivide extends BroadcastableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public reversedivide(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public reversedivide(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public reversedivide position(long position) { + return (reversedivide)super.position(position); + } + public reversedivide() { super((Pointer)null); allocate(); } private native void allocate(); } - @Namespace("sd::ops") public static class reversedivide_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public reversedivide_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public reversedivide_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public reversedivide_bp position(long position) { - return (reversedivide_bp)super.position(position); - } - +@Namespace("sd::ops") public static class reversedivide_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public reversedivide_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public reversedivide_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public reversedivide_bp position(long position) { + return (reversedivide_bp)super.position(position); + } + public reversedivide_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This is one of auto-broadcastable operations. It accepts 2 operands, and + * operation is applied based on their shapes: 1) if shapes are equal that's + * pairwise operation, result will have the same shape. 2) if shape X is scalar + * and shape Y is array - result will have shape equal to Y. 3) if shape X is + * array and shape Y is scalar - result will have shape equal to X. 4) if shape + * X and Y are both arrays, but shapes aren't equal - result shape will be + * broadcast result. + * + * This operation returns Z = FloorMod(X, Y) + */ +// #if NOT_EXCLUDED(OP_floormod) +@Namespace("sd::ops") public static class floormod extends BroadcastableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public floormod(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public floormod(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public floormod position(long position) { + return (floormod)super.position(position); + } - /** - * This is one of auto-broadcastable operations. It accepts 2 operands, and operation is applied based on their shapes: - * 1) if shapes are equal that's pairwise operation, result will have the same shape. - * 2) if shape X is scalar and shape Y is array - result will have shape equal to Y. - * 3) if shape X is array and shape Y is scalar - result will have shape equal to X. - * 4) if shape X and Y are both arrays, but shapes aren't equal - result shape will be broadcast result. - * - * This operation returns Z = FloorMod(X, Y) - */ -// #if NOT_EXCLUDED(OP_floormod) - @Namespace("sd::ops") public static class floormod extends BroadcastableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public floormod(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public floormod(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public floormod position(long position) { - return (floormod)super.position(position); - } - public floormod() { super((Pointer)null); allocate(); } private native void allocate(); } - @Namespace("sd::ops") public static class floormod_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public floormod_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public floormod_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public floormod_bp position(long position) { - return (floormod_bp)super.position(position); - } - +@Namespace("sd::ops") public static class floormod_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public floormod_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public floormod_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public floormod_bp position(long position) { + return (floormod_bp)super.position(position); + } + public floormod_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_mod) +@Namespace("sd::ops") public static class mod extends BroadcastableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public mod(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public mod(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public mod position(long position) { + return (mod)super.position(position); + } -// #if NOT_EXCLUDED(OP_mod) - @Namespace("sd::ops") public static class mod extends BroadcastableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public mod(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public mod(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public mod position(long position) { - return (mod)super.position(position); - } - public mod() { super((Pointer)null); allocate(); } private native void allocate(); } - @Namespace("sd::ops") public static class mod_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public mod_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public mod_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public mod_bp position(long position) { - return (mod_bp)super.position(position); - } - +@Namespace("sd::ops") public static class mod_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public mod_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public mod_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public mod_bp position(long position) { + return (mod_bp)super.position(position); + } + public mod_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This is one of auto-broadcastable operations. It accepts 2 operands, and + * operation is applied based on their shapes: 1) if shapes are equal that's + * pairwise operation, result will have the same shape. 2) if shape X is scalar + * and shape Y is array - result will have shape equal to Y. 3) if shape X is + * array and shape Y is scalar - result will have shape equal to X. 4) if shape + * X and Y are both arrays, but shapes aren't equal - result shape will be + * broadcast result. + * + * This operation returns Z = FloorDiv(X, Y) + */ +// #if NOT_EXCLUDED(OP_floordiv) +@Namespace("sd::ops") public static class floordiv extends BroadcastableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public floordiv(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public floordiv(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public floordiv position(long position) { + return (floordiv)super.position(position); + } - /** - * This is one of auto-broadcastable operations. It accepts 2 operands, and operation is applied based on their shapes: - * 1) if shapes are equal that's pairwise operation, result will have the same shape. - * 2) if shape X is scalar and shape Y is array - result will have shape equal to Y. - * 3) if shape X is array and shape Y is scalar - result will have shape equal to X. - * 4) if shape X and Y are both arrays, but shapes aren't equal - result shape will be broadcast result. - * - * This operation returns Z = FloorDiv(X, Y) - */ -// #if NOT_EXCLUDED(OP_floordiv) - @Namespace("sd::ops") public static class floordiv extends BroadcastableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public floordiv(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public floordiv(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public floordiv position(long position) { - return (floordiv)super.position(position); - } - public floordiv() { super((Pointer)null); allocate(); } private native void allocate(); } @@ -14414,453 +16741,458 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - // #endif + // #endif + +/** + * This is one of auto-broadcastable operations. It accepts 2 operands, and + * operation is applied based on their shapes: 1) if shapes are equal that's + * pairwise operation, result will have the same shape. 2) if shape X is scalar + * and shape Y is array - result will have shape equal to Y. 3) if shape X is + * array and shape Y is scalar - result will have shape equal to X. 4) if shape + * X and Y are both arrays, but shapes aren't equal - result shape will be + * broadcast result. + * + * This operation returns Z = Divide(X, Y) + */ +// #if NOT_EXCLUDED(OP_realdiv) +@Namespace("sd::ops") public static class realdiv extends BroadcastableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public realdiv(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public realdiv(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public realdiv position(long position) { + return (realdiv)super.position(position); + } - /** - * This is one of auto-broadcastable operations. It accepts 2 operands, and operation is applied based on their shapes: - * 1) if shapes are equal that's pairwise operation, result will have the same shape. - * 2) if shape X is scalar and shape Y is array - result will have shape equal to Y. - * 3) if shape X is array and shape Y is scalar - result will have shape equal to X. - * 4) if shape X and Y are both arrays, but shapes aren't equal - result shape will be broadcast result. - * - * This operation returns Z = Divide(X, Y) - */ -// #if NOT_EXCLUDED(OP_realdiv) - @Namespace("sd::ops") public static class realdiv extends BroadcastableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public realdiv(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public realdiv(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public realdiv position(long position) { - return (realdiv)super.position(position); - } - public realdiv() { super((Pointer)null); allocate(); } private native void allocate(); } - @Namespace("sd::ops") public static class realdiv_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public realdiv_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public realdiv_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public realdiv_bp position(long position) { - return (realdiv_bp)super.position(position); - } - +@Namespace("sd::ops") public static class realdiv_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public realdiv_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public realdiv_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public realdiv_bp position(long position) { + return (realdiv_bp)super.position(position); + } + public realdiv_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif +/** + * + * + * \tparam T + */ +@Namespace("sd::ops") public static class truncatediv extends BroadcastableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public truncatediv(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public truncatediv(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public truncatediv position(long position) { + return (truncatediv)super.position(position); + } - /** - * - * - * \tparam T - */ - @Namespace("sd::ops") public static class truncatediv extends BroadcastableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public truncatediv(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public truncatediv(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public truncatediv position(long position) { - return (truncatediv)super.position(position); - } - public truncatediv() { super((Pointer)null); allocate(); } private native void allocate(); } - /** - * This is one of auto-broadcastable operations. It accepts 2 operands, and operation is applied based on their shapes: - * 1) if shapes are equal that's pairwise operation, result will have the same shape. - * 2) if shape X is scalar and shape Y is array - result will have shape equal to Y. - * 3) if shape X is array and shape Y is scalar - result will have shape equal to X. - * 4) if shape X and Y are both arrays, but shapes aren't equal - result shape will be broadcast result. - * - * This operation returns Z = Assign(X, Y) - */ -// #if NOT_EXCLUDED(OP_assign) - @Namespace("sd::ops") public static class assign extends BroadcastableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public assign(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public assign(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public assign position(long position) { - return (assign)super.position(position); - } - +/** + * This is one of auto-broadcastable operations. It accepts 2 operands, and + * operation is applied based on their shapes: 1) if shapes are equal that's + * pairwise operation, result will have the same shape. 2) if shape X is scalar + * and shape Y is array - result will have shape equal to Y. 3) if shape X is + * array and shape Y is scalar - result will have shape equal to X. 4) if shape + * X and Y are both arrays, but shapes aren't equal - result shape will be + * broadcast result. + * + * This operation returns Z = Assign(X, Y) + */ +// #if NOT_EXCLUDED(OP_assign) +@Namespace("sd::ops") public static class assign extends BroadcastableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public assign(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public assign(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public assign position(long position) { + return (assign)super.position(position); + } + public assign() { super((Pointer)null); allocate(); } private native void allocate(); } - @Namespace("sd::ops") public static class assign_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public assign_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public assign_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public assign_bp position(long position) { - return (assign_bp)super.position(position); - } - +@Namespace("sd::ops") public static class assign_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public assign_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public assign_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public assign_bp position(long position) { + return (assign_bp)super.position(position); + } + public assign_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_meshgrid) +@Namespace("sd::ops") public static class meshgrid extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public meshgrid(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public meshgrid(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public meshgrid position(long position) { + return (meshgrid)super.position(position); + } -// #if NOT_EXCLUDED(OP_meshgrid) - @Namespace("sd::ops") public static class meshgrid extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public meshgrid(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public meshgrid(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public meshgrid position(long position) { - return (meshgrid)super.position(position); - } - public meshgrid() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This op takes 2 equally shaped arrays as input, and provides binary matrix as + * output. Math is: _x == _y ? (T) 1.0f : (T) 0.0f; + * + */ +// #if NOT_EXCLUDED(OP_equals) +@Namespace("sd::ops") public static class equals extends BroadcastableBoolOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public equals(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public equals(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public equals position(long position) { + return (equals)super.position(position); + } - /** - * This op takes 2 equally shaped arrays as input, and provides binary matrix as output. - * Math is: _x == _y ? (T) 1.0f : (T) 0.0f; - * - */ -// #if NOT_EXCLUDED(OP_equals) - @Namespace("sd::ops") public static class equals extends BroadcastableBoolOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public equals(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public equals(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public equals position(long position) { - return (equals)super.position(position); - } - public equals() { super((Pointer)null); allocate(); } private native void allocate(); } -// #endif +// #endif + +/** + * This op takes 2 equally shaped arrays as input, and provides binary matrix as + * output. Math is: _x != _y ? (T) 1.0f : (T) 0.0f; + */ +// #if NOT_EXCLUDED(OP_not_equals) +@Namespace("sd::ops") public static class not_equals extends BroadcastableBoolOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public not_equals(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public not_equals(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public not_equals position(long position) { + return (not_equals)super.position(position); + } - /** - * This op takes 2 equally shaped arrays as input, and provides binary matrix as output. - * Math is: _x != _y ? (T) 1.0f : (T) 0.0f; - */ -// #if NOT_EXCLUDED(OP_not_equals) - @Namespace("sd::ops") public static class not_equals extends BroadcastableBoolOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public not_equals(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public not_equals(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public not_equals position(long position) { - return (not_equals)super.position(position); - } - public not_equals() { super((Pointer)null); allocate(); } private native void allocate(); } -// #endif +// #endif + +/** + * This op takes 2 equally shaped arrays as input, and provides binary matrix as + * output. Math is: _x <= _y ? (T) 1.0f : (T) 0.0f; + */ +// #if NOT_EXCLUDED(OP_less_equal) +@Namespace("sd::ops") public static class less_equal extends BroadcastableBoolOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public less_equal(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public less_equal(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public less_equal position(long position) { + return (less_equal)super.position(position); + } - /** - * This op takes 2 equally shaped arrays as input, and provides binary matrix as output. - * Math is: _x <= _y ? (T) 1.0f : (T) 0.0f; - */ -// #if NOT_EXCLUDED(OP_less_equal) - @Namespace("sd::ops") public static class less_equal extends BroadcastableBoolOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public less_equal(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public less_equal(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public less_equal position(long position) { - return (less_equal)super.position(position); - } - public less_equal() { super((Pointer)null); allocate(); } private native void allocate(); } -// #endif +// #endif + +/** + * This op takes 2 equally shaped arrays as input, and provides binary matrix as + * output. Math is: _x >= _y ? (T) 1.0f : (T) 0.0f; + */ +// #if NOT_EXCLUDED(OP_greater_equal) +@Namespace("sd::ops") public static class greater_equal extends BroadcastableBoolOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public greater_equal(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public greater_equal(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public greater_equal position(long position) { + return (greater_equal)super.position(position); + } - /** - * This op takes 2 equally shaped arrays as input, and provides binary matrix as output. - * Math is: _x >= _y ? (T) 1.0f : (T) 0.0f; - */ -// #if NOT_EXCLUDED(OP_greater_equal) - @Namespace("sd::ops") public static class greater_equal extends BroadcastableBoolOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public greater_equal(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public greater_equal(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public greater_equal position(long position) { - return (greater_equal)super.position(position); - } - public greater_equal() { super((Pointer)null); allocate(); } private native void allocate(); } -// #endif +// #endif + +/** + * This op takes 2 equally shaped arrays as input, and provides binary matrix as + * output. Math is: _x < _y ? (T) 1.0f : (T) 0.0f; + */ +// #if NOT_EXCLUDED(OP_less) +@Namespace("sd::ops") public static class less extends BroadcastableBoolOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public less(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public less(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public less position(long position) { + return (less)super.position(position); + } - /** - * This op takes 2 equally shaped arrays as input, and provides binary matrix as output. - * Math is: _x < _y ? (T) 1.0f : (T) 0.0f; - */ -// #if NOT_EXCLUDED(OP_less) - @Namespace("sd::ops") public static class less extends BroadcastableBoolOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public less(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public less(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public less position(long position) { - return (less)super.position(position); - } - public less() { super((Pointer)null); allocate(); } private native void allocate(); } -// #endif +// #endif + +/** + * This op takes 2 equally shaped arrays as input, and provides binary matrix as + * output. Math is: _x > _y ? (T) 1.0f : (T) 0.0f; + */ +// #if NOT_EXCLUDED(OP_greater) +@Namespace("sd::ops") public static class greater extends BroadcastableBoolOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public greater(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public greater(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public greater position(long position) { + return (greater)super.position(position); + } - /** - * This op takes 2 equally shaped arrays as input, and provides binary matrix as output. - * Math is: _x > _y ? (T) 1.0f : (T) 0.0f; - */ -// #if NOT_EXCLUDED(OP_greater) - @Namespace("sd::ops") public static class greater extends BroadcastableBoolOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public greater(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public greater(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public greater position(long position) { - return (greater)super.position(position); - } - public greater() { super((Pointer)null); allocate(); } private native void allocate(); } -// #endif +// #endif + +/** + * + */ +// #if NOT_EXCLUDED(OP_boolean_and) +@Namespace("sd::ops") public static class boolean_and extends BroadcastableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public boolean_and(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public boolean_and(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public boolean_and position(long position) { + return (boolean_and)super.position(position); + } - /** - * - */ -// #if NOT_EXCLUDED(OP_boolean_and) - @Namespace("sd::ops") public static class boolean_and extends BroadcastableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public boolean_and(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public boolean_and(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public boolean_and position(long position) { - return (boolean_and)super.position(position); - } - public boolean_and() { super((Pointer)null); allocate(); } private native void allocate(); } -// #endif +// #endif + +/** + * + */ +// #if NOT_EXCLUDED(OP_boolean_or) +@Namespace("sd::ops") public static class boolean_or extends BroadcastableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public boolean_or(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public boolean_or(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public boolean_or position(long position) { + return (boolean_or)super.position(position); + } - /** - * - */ -// #if NOT_EXCLUDED(OP_boolean_or) - @Namespace("sd::ops") public static class boolean_or extends BroadcastableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public boolean_or(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public boolean_or(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public boolean_or position(long position) { - return (boolean_or)super.position(position); - } - public boolean_or() { super((Pointer)null); allocate(); } private native void allocate(); } -// #endif +// #endif + +/** + * + */ +// #if NOT_EXCLUDED(OP_boolean_xor) +@Namespace("sd::ops") public static class boolean_xor extends BroadcastableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public boolean_xor(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public boolean_xor(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public boolean_xor position(long position) { + return (boolean_xor)super.position(position); + } - /** - * - */ -// #if NOT_EXCLUDED(OP_boolean_xor) - @Namespace("sd::ops") public static class boolean_xor extends BroadcastableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public boolean_xor(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public boolean_xor(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public boolean_xor position(long position) { - return (boolean_xor)super.position(position); - } - public boolean_xor() { super((Pointer)null); allocate(); } private native void allocate(); } -// #endif +// #endif + +/** + * This operation performs calculation of percentile of input array along given + * axises + * + * Input - tensor with rank N > 0 + * Output - tensor with rank (N - length(axis)) or scalar if number of Integer + * arguments is zero Float arguments: 0: percentile (scalar) in range [0,100] + * (inclusively) 1: interpolation (optional), possible values are 0-"lower", + * 1-"higher", 2-"nearest"(default) 2: keepDims (optional), if it is non zero, + * then unities are kept in reduced resulting shape of output array, default is + * 0 Integer arguments - axis - the sequence of axises to calculate percentile + * along, if sequence is empty then calculate percentile for whole input tensor + * and return result as scalar + * + */ +// #if NOT_EXCLUDED(OP_percentile) +@Namespace("sd::ops") public static class percentile extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public percentile(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public percentile(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public percentile position(long position) { + return (percentile)super.position(position); + } - /** - * This operation performs calculation of percentile of input array along given axises - * - * Input - tensor with rank N > 0 - * Output - tensor with rank (N - length(axis)) or scalar if number of Integer arguments is zero - * Float arguments: - * 0: percentile (scalar) in range [0,100] (inclusively) - * 1: interpolation (optional), possible values are 0-"lower", 1-"higher", 2-"nearest"(default) - * 2: keepDims (optional), if it is non zero, then unities are kept in reduced resulting shape of output array, default is 0 - * Integer arguments - axis - the sequence of axises to calculate percentile along, if sequence is empty then calculate percentile for whole input tensor and return result as scalar - * - */ -// #if NOT_EXCLUDED(OP_percentile) - @Namespace("sd::ops") public static class percentile extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public percentile(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public percentile(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public percentile position(long position) { - return (percentile)super.position(position); - } - public percentile() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif +/** + * Special atan2 op impl for TF's args order + * \tparam T + */ +// #if NOT_EXCLUDED(OP_tf_atan2) +@Namespace("sd::ops") public static class tf_atan2 extends BroadcastableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public tf_atan2(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public tf_atan2(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public tf_atan2 position(long position) { + return (tf_atan2)super.position(position); + } - /** - * Special atan2 op impl for TF's args order - * \tparam T - */ -// #if NOT_EXCLUDED(OP_tf_atan2) - @Namespace("sd::ops") public static class tf_atan2 extends BroadcastableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public tf_atan2(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public tf_atan2(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public tf_atan2 position(long position) { - return (tf_atan2)super.position(position); - } - public tf_atan2() { super((Pointer)null); allocate(); } private native void allocate(); } -// #endif +// #endif + +/** + * Broadcastable pow implementation + * \tparam T + */ +// #if NOT_EXCLUDED(OP_Pow) +@Namespace("sd::ops") public static class Pow extends BroadcastableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public Pow(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public Pow(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public Pow position(long position) { + return (Pow)super.position(position); + } - /** - * Broadcastable pow implementation - * \tparam T - */ -// #if NOT_EXCLUDED(OP_Pow) - @Namespace("sd::ops") public static class Pow extends BroadcastableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public Pow(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public Pow(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public Pow position(long position) { - return (Pow)super.position(position); - } - public Pow() { super((Pointer)null); allocate(); } private native void allocate(); } - @Namespace("sd::ops") public static class Pow_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public Pow_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public Pow_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public Pow_bp position(long position) { - return (Pow_bp)super.position(position); - } - +@Namespace("sd::ops") public static class Pow_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public Pow_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public Pow_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public Pow_bp position(long position) { + return (Pow_bp)super.position(position); + } + public Pow_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * Broadcastable igamma implementation + * + * igamma(a, x) = gamma(а, x) / Gamma(a) - Gamma distribution function P(a,x) + * Gamma(a) = int from 0 to infinity { t ^ {a - 1} e^{-t}dt } + * gamma(a, x) = int from 0 to x { t ^ {a - 1} e^{-t}dt } + * \tparam T + */ +// #if NOT_EXCLUDED(OP_igamma) +@Namespace("sd::ops") public static class igamma extends BroadcastableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public igamma(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public igamma(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public igamma position(long position) { + return (igamma)super.position(position); + } - /** - * Broadcastable igamma implementation - * - * igamma(a, x) = gamma(а, x) / Gamma(a) - Gamma distribution function P(a,x) - * Gamma(a) = int from 0 to infinity { t ^ {a - 1} e^{-t}dt } - * gamma(a, x) = int from 0 to x { t ^ {a - 1} e^{-t}dt } - * \tparam T - */ -// #if NOT_EXCLUDED(OP_igamma) - @Namespace("sd::ops") public static class igamma extends BroadcastableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public igamma(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public igamma(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public igamma position(long position) { - return (igamma)super.position(position); - } - public igamma() { super((Pointer)null); allocate(); } private native void allocate(); } -// #endif - /** - * Broadcastable igammac implementation - * igammac(a, x) = Gamma(a,x)/Gamma(а) - Gamma distribution function Q(a,x) - * Gamma(a) = int from 0 to infinity { t ^ {a - 1} e^{-t}dt } - * Gamma(a, x) = int from x to infinity { t ^ {a - 1} e^{-t}dt } - * \tparam T - */ -// #if NOT_EXCLUDED(OP_igammac) - @Namespace("sd::ops") public static class igammac extends BroadcastableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public igammac(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public igammac(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public igammac position(long position) { - return (igammac)super.position(position); - } - +// #endif +/** + * Broadcastable igammac implementation + * igammac(a, x) = Gamma(a,x)/Gamma(а) - Gamma distribution function Q(a,x) + * Gamma(a) = int from 0 to infinity { t ^ {a - 1} e^{-t}dt } + * Gamma(a, x) = int from x to infinity { t ^ {a - 1} e^{-t}dt } + * \tparam T + */ +// #if NOT_EXCLUDED(OP_igammac) +@Namespace("sd::ops") public static class igammac extends BroadcastableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public igammac(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public igammac(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public igammac position(long position) { + return (igammac)super.position(position); + } + public igammac() { super((Pointer)null); allocate(); } private native void allocate(); } -// #endif - - +// #endif + // namespace ops + // namespace sd // #endif @@ -14891,824 +17223,823 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #include - /** - * 1D temporal convolution implementation - * Expected input: - * x: 3D array - * weight: 3D Array - * bias: optional vector - * - * Int args: - * 0: kernel - * 1: stride - * 2: padding - */ -// #if NOT_EXCLUDED(OP_conv1d) - @Namespace("sd::ops") public static class conv1d extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public conv1d(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public conv1d(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public conv1d position(long position) { - return (conv1d)super.position(position); - } - +/** + * 1D temporal convolution implementation + * Expected input: + * x: 3D array + * weight: 3D Array + * bias: optional vector + * + * Int args: + * 0: kernel + * 1: stride + * 2: padding + */ +// #if NOT_EXCLUDED(OP_conv1d) +@Namespace("sd::ops") public static class conv1d extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public conv1d(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public conv1d(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public conv1d position(long position) { + return (conv1d)super.position(position); + } + public conv1d() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class conv1d_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public conv1d_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public conv1d_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public conv1d_bp position(long position) { - return (conv1d_bp)super.position(position); - } - +@Namespace("sd::ops") public static class conv1d_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public conv1d_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public conv1d_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public conv1d_bp position(long position) { + return (conv1d_bp)super.position(position); + } + public conv1d_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * 2D convolution implementation + * Expected input: + * x: 4D array + * weight: 4D Array + * bias: optional vector, length of outputChannels + * + * IntArgs: + * 0: kernel height + * 1: kernel width + * 2: stride height + * 3: stride width + * 4: padding height + * 5: padding width + * 6: dilation height + * 7: dilation width + * 8: same mode: 1 true, 0 false + * 9: data format: 1 NHWC, 0 NCHW + */ +// #if NOT_EXCLUDED(OP_conv2d) +@Namespace("sd::ops") public static class conv2d extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public conv2d(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public conv2d(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public conv2d position(long position) { + return (conv2d)super.position(position); + } - /** - * 2D convolution implementation - * Expected input: - * x: 4D array - * weight: 4D Array - * bias: optional vector, length of outputChannels - * - * IntArgs: - * 0: kernel height - * 1: kernel width - * 2: stride height - * 3: stride width - * 4: padding height - * 5: padding width - * 6: dilation height - * 7: dilation width - * 8: same mode: 1 true, 0 false - * 9: data format: 1 NHWC, 0 NCHW - */ -// #if NOT_EXCLUDED(OP_conv2d) - @Namespace("sd::ops") public static class conv2d extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public conv2d(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public conv2d(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public conv2d position(long position) { - return (conv2d)super.position(position); - } - public conv2d() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class conv2d_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public conv2d_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public conv2d_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public conv2d_bp position(long position) { - return (conv2d_bp)super.position(position); - } - +@Namespace("sd::ops") public static class conv2d_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public conv2d_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public conv2d_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public conv2d_bp position(long position) { + return (conv2d_bp)super.position(position); + } + public conv2d_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class conv2d_input_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public conv2d_input_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public conv2d_input_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public conv2d_input_bp position(long position) { - return (conv2d_input_bp)super.position(position); - } - - public conv2d_input_bp() { super((Pointer)null); allocate(); } - private native void allocate(); +@Namespace("sd::ops") public static class conv2d_input_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public conv2d_input_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public conv2d_input_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public conv2d_input_bp position(long position) { + return (conv2d_input_bp)super.position(position); + } + + public conv2d_input_bp() { super((Pointer)null); allocate(); } + private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * Depthwise convolution2d op: + * Expected inputs: + * x: 4D array, NCHW format + * weightsDepth: 4D array, + * weightsPointwise: optional, 4D array + * bias: optional, vector + */ +// #if NOT_EXCLUDED(OP_sconv2d) +@Namespace("sd::ops") public static class sconv2d extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public sconv2d(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public sconv2d(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public sconv2d position(long position) { + return (sconv2d)super.position(position); + } - /** - * Depthwise convolution2d op: - * Expected inputs: - * x: 4D array, NCHW format - * weightsDepth: 4D array, - * weightsPointwise: optional, 4D array - * bias: optional, vector - */ -// #if NOT_EXCLUDED(OP_sconv2d) - @Namespace("sd::ops") public static class sconv2d extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public sconv2d(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public sconv2d(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public sconv2d position(long position) { - return (sconv2d)super.position(position); - } - public sconv2d() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class sconv2d_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public sconv2d_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public sconv2d_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public sconv2d_bp position(long position) { - return (sconv2d_bp)super.position(position); - } - +@Namespace("sd::ops") public static class sconv2d_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public sconv2d_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public sconv2d_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public sconv2d_bp position(long position) { + return (sconv2d_bp)super.position(position); + } + public sconv2d_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * 2D deconvolution implementation + * + * IntArgs: + * 0: kernel height + * 1: kernel width + * 2: stride height + * 3: stride width + * 4: padding height + * 5: padding width + * 6: dilation height + * 7: dilation width + * 8: same mode: 0 false, 1 true + */ +// #if NOT_EXCLUDED(OP_deconv2d) +@Namespace("sd::ops") public static class deconv2d extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public deconv2d(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public deconv2d(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public deconv2d position(long position) { + return (deconv2d)super.position(position); + } - /** - * 2D deconvolution implementation - * - * IntArgs: - * 0: kernel height - * 1: kernel width - * 2: stride height - * 3: stride width - * 4: padding height - * 5: padding width - * 6: dilation height - * 7: dilation width - * 8: same mode: 0 false, 1 true - */ -// #if NOT_EXCLUDED(OP_deconv2d) - @Namespace("sd::ops") public static class deconv2d extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public deconv2d(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public deconv2d(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public deconv2d position(long position) { - return (deconv2d)super.position(position); - } - public deconv2d() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class deconv2d_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public deconv2d_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public deconv2d_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public deconv2d_bp position(long position) { - return (deconv2d_bp)super.position(position); - } - +@Namespace("sd::ops") public static class deconv2d_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public deconv2d_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public deconv2d_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public deconv2d_bp position(long position) { + return (deconv2d_bp)super.position(position); + } + public deconv2d_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif - /** - * 3D deconvolution implementation - * - * IntArgs: - * 0: filter(kernel) depth - * 1: filter(kernel) height - * 2: filter(kernel) width - * 3: strides depth - * 4: strides height - * 5: strides width - * 6: paddings depth - * 7: paddings height - * 8: paddings width - * 9: dilations depth - * 10: dilations height - * 11: dilations width - * 12: same mode: 0 false, 1 true - * 13: data format (optional): 0-NDHWC, 1-NCDHW, default is 1 - */ +/** + * 3D deconvolution implementation + * + * IntArgs: + * 0: filter(kernel) depth + * 1: filter(kernel) height + * 2: filter(kernel) width + * 3: strides depth + * 4: strides height + * 5: strides width + * 6: paddings depth + * 7: paddings height + * 8: paddings width + * 9: dilations depth + * 10: dilations height + * 11: dilations width + * 12: same mode: 0 false, 1 true + * 13: data format (optional): 0-NDHWC, 1-NCDHW, default is 1 + */ + +// #if NOT_EXCLUDED(OP_deconv3d) +@Namespace("sd::ops") public static class deconv3d extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public deconv3d(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public deconv3d(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public deconv3d position(long position) { + return (deconv3d)super.position(position); + } -// #if NOT_EXCLUDED(OP_deconv3d) - @Namespace("sd::ops") public static class deconv3d extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public deconv3d(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public deconv3d(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public deconv3d position(long position) { - return (deconv3d)super.position(position); - } - public deconv3d() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class deconv3d_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public deconv3d_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public deconv3d_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public deconv3d_bp position(long position) { - return (deconv3d_bp)super.position(position); - } - +@Namespace("sd::ops") public static class deconv3d_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public deconv3d_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public deconv3d_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public deconv3d_bp position(long position) { + return (deconv3d_bp)super.position(position); + } + public deconv3d_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif +/** + * This op implements max pooling for convolution networks. + * Expected Input: 4D array, NCHW format. + * + * IntArgs: + * 0: kernel height + * 1: kernel width + * 2: stride height + * 3: stride width + * 4: padding height + * 5: padding width + * 6: dilation height + * 7: dilation width + * 8: same mode: 0 false, 1 true + */ +// #if NOT_EXCLUDED(OP_maxpool2d) +@Namespace("sd::ops") public static class maxpool2d extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public maxpool2d(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public maxpool2d(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public maxpool2d position(long position) { + return (maxpool2d)super.position(position); + } - /** - * This op implements max pooling for convolution networks. - * Expected Input: 4D array, NCHW format. - * - * IntArgs: - * 0: kernel height - * 1: kernel width - * 2: stride height - * 3: stride width - * 4: padding height - * 5: padding width - * 6: dilation height - * 7: dilation width - * 8: same mode: 0 false, 1 true - */ -// #if NOT_EXCLUDED(OP_maxpool2d) - @Namespace("sd::ops") public static class maxpool2d extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public maxpool2d(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public maxpool2d(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public maxpool2d position(long position) { - return (maxpool2d)super.position(position); - } - public maxpool2d() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class maxpool2d_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public maxpool2d_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public maxpool2d_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public maxpool2d_bp position(long position) { - return (maxpool2d_bp)super.position(position); - } - +@Namespace("sd::ops") public static class maxpool2d_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public maxpool2d_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public maxpool2d_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public maxpool2d_bp position(long position) { + return (maxpool2d_bp)super.position(position); + } + public maxpool2d_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This op implements average pooling for convolution networks. + * Expected Input: 4D array, NCHW format. + * + * IntArgs: + * 0: kernel height + * 1: kernel width + * 2: stride height + * 3: stride width + * 4: padding height + * 5: padding width + * 6: dilation height + * 7: dilation width + * 8: same mode: 0 false, 1 true + */ +// #if NOT_EXCLUDED(OP_avgpool2d) +@Namespace("sd::ops") public static class avgpool2d extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public avgpool2d(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public avgpool2d(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public avgpool2d position(long position) { + return (avgpool2d)super.position(position); + } - /** - * This op implements average pooling for convolution networks. - * Expected Input: 4D array, NCHW format. - * - * IntArgs: - * 0: kernel height - * 1: kernel width - * 2: stride height - * 3: stride width - * 4: padding height - * 5: padding width - * 6: dilation height - * 7: dilation width - * 8: same mode: 0 false, 1 true - */ -// #if NOT_EXCLUDED(OP_avgpool2d) - @Namespace("sd::ops") public static class avgpool2d extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public avgpool2d(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public avgpool2d(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public avgpool2d position(long position) { - return (avgpool2d)super.position(position); - } - public avgpool2d() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class avgpool2d_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public avgpool2d_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public avgpool2d_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public avgpool2d_bp position(long position) { - return (avgpool2d_bp)super.position(position); - } - +@Namespace("sd::ops") public static class avgpool2d_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public avgpool2d_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public avgpool2d_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public avgpool2d_bp position(long position) { + return (avgpool2d_bp)super.position(position); + } + public avgpool2d_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This op implements pnorm pooling for convolution networks. + * Expected Input: 4D array, NCHW format. + * + * IntArgs: + * 0: kernel height + * 1: kernel width + * 2: stride height + * 3: stride width + * 4: padding height + * 5: padding width + * 6: dilation height + * 7: dilation width + * 8: same mode: 0 false, 1 true + * 9: p for p-norm + */ +// #if NOT_EXCLUDED(OP_pnormpool2d) +@Namespace("sd::ops") public static class pnormpool2d extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public pnormpool2d(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public pnormpool2d(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public pnormpool2d position(long position) { + return (pnormpool2d)super.position(position); + } - /** - * This op implements pnorm pooling for convolution networks. - * Expected Input: 4D array, NCHW format. - * - * IntArgs: - * 0: kernel height - * 1: kernel width - * 2: stride height - * 3: stride width - * 4: padding height - * 5: padding width - * 6: dilation height - * 7: dilation width - * 8: same mode: 0 false, 1 true - * 9: p for p-norm - */ -// #if NOT_EXCLUDED(OP_pnormpool2d) - @Namespace("sd::ops") public static class pnormpool2d extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public pnormpool2d(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public pnormpool2d(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public pnormpool2d position(long position) { - return (pnormpool2d)super.position(position); - } - public pnormpool2d() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class pnormpool2d_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public pnormpool2d_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public pnormpool2d_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public pnormpool2d_bp position(long position) { - return (pnormpool2d_bp)super.position(position); - } - +@Namespace("sd::ops") public static class pnormpool2d_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public pnormpool2d_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public pnormpool2d_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public pnormpool2d_bp position(long position) { + return (pnormpool2d_bp)super.position(position); + } + public pnormpool2d_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This op implements im2col algorithm, widely used in convolution neural + * networks Input: 4D input expected + * + * Int args: + * 0: kernel height + * 1: kernel width + * 2: stride height + * 3: stride width + * 4: padding height + * 5: padding width + * 6: dilation height + * 7: dilation width + * 8: isSameMode + */ +// #if NOT_EXCLUDED(OP_im2col) +@Namespace("sd::ops") public static class im2col extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public im2col(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public im2col(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public im2col position(long position) { + return (im2col)super.position(position); + } - /** - * This op implements im2col algorithm, widely used in convolution neural networks - * Input: 4D input expected - * - * Int args: - * 0: kernel height - * 1: kernel width - * 2: stride height - * 3: stride width - * 4: padding height - * 5: padding width - * 6: dilation height - * 7: dilation width - * 8: isSameMode - */ -// #if NOT_EXCLUDED(OP_im2col) - @Namespace("sd::ops") public static class im2col extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public im2col(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public im2col(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public im2col position(long position) { - return (im2col)super.position(position); - } - public im2col() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class im2col_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public im2col_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public im2col_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public im2col_bp position(long position) { - return (im2col_bp)super.position(position); - } - +@Namespace("sd::ops") public static class im2col_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public im2col_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public im2col_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public im2col_bp position(long position) { + return (im2col_bp)super.position(position); + } + public im2col_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This op implements col2im algorithm, widely used in convolution neural + * networks Input: 6D input expected (like output of im2col op) + * + * Int args: + * 0: stride height + * 1: stride width + * 2: padding height + * 3: padding width + * 4: image height + * 5: image width + * 6: dilation height + * 7: dilation width + */ +// #if NOT_EXCLUDED(OP_col2im) +@Namespace("sd::ops") public static class col2im extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public col2im(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public col2im(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public col2im position(long position) { + return (col2im)super.position(position); + } - /** - * This op implements col2im algorithm, widely used in convolution neural networks - * Input: 6D input expected (like output of im2col op) - * - * Int args: - * 0: stride height - * 1: stride width - * 2: padding height - * 3: padding width - * 4: image height - * 5: image width - * 6: dilation height - * 7: dilation width - */ -// #if NOT_EXCLUDED(OP_col2im) - @Namespace("sd::ops") public static class col2im extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public col2im(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public col2im(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public col2im position(long position) { - return (col2im)super.position(position); - } - public col2im() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * Expected input: 4D array + * + * IntArgs: + * 0: scale factor for rows (height) + * 1: scale factor for columns (width) + * 2: data format: 0 NHWC (default), 1 NCHW + */ +// #if NOT_EXCLUDED(OP_upsampling2d) +@Namespace("sd::ops") public static class upsampling2d extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public upsampling2d(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public upsampling2d(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public upsampling2d position(long position) { + return (upsampling2d)super.position(position); + } + + public upsampling2d() { super((Pointer)null); allocate(); } + private native void allocate(); + public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); + } +@Namespace("sd::ops") public static class upsampling2d_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public upsampling2d_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public upsampling2d_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public upsampling2d_bp position(long position) { + return (upsampling2d_bp)super.position(position); + } - /** - * Expected input: 4D array - * - * IntArgs: - * 0: scale factor for rows (height) - * 1: scale factor for columns (width) - * 2: data format: 0 NHWC (default), 1 NCHW - */ -// #if NOT_EXCLUDED(OP_upsampling2d) - @Namespace("sd::ops") public static class upsampling2d extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public upsampling2d(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public upsampling2d(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public upsampling2d position(long position) { - return (upsampling2d)super.position(position); - } - - public upsampling2d() { super((Pointer)null); allocate(); } - private native void allocate(); - public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); - } - @Namespace("sd::ops") public static class upsampling2d_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public upsampling2d_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public upsampling2d_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public upsampling2d_bp position(long position) { - return (upsampling2d_bp)super.position(position); - } - public upsampling2d_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * Expected input: 4D array + * + * IntArgs: + * 0: scale factor for depth + * 1: scale factor for rows (height) + * 2: scale factor for columns (width) + * 3: data format: 0 NDHWC (default), 1 NCDHW + */ +// #if NOT_EXCLUDED(OP_upsampling3d) +@Namespace("sd::ops") public static class upsampling3d extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public upsampling3d(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public upsampling3d(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public upsampling3d position(long position) { + return (upsampling3d)super.position(position); + } - /** - * Expected input: 4D array - * - * IntArgs: - * 0: scale factor for depth - * 1: scale factor for rows (height) - * 2: scale factor for columns (width) - * 3: data format: 0 NDHWC (default), 1 NCDHW - */ -// #if NOT_EXCLUDED(OP_upsampling3d) - @Namespace("sd::ops") public static class upsampling3d extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public upsampling3d(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public upsampling3d(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public upsampling3d position(long position) { - return (upsampling3d)super.position(position); - } - public upsampling3d() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class upsampling3d_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public upsampling3d_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public upsampling3d_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public upsampling3d_bp position(long position) { - return (upsampling3d_bp)super.position(position); - } - +@Namespace("sd::ops") public static class upsampling3d_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public upsampling3d_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public upsampling3d_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public upsampling3d_bp position(long position) { + return (upsampling3d_bp)super.position(position); + } + public upsampling3d_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This op produces binary matrix wrt to target dimension. + * Maximum value within each TAD is replaced with 1, other values are set to + * true. + * + * Int args: + * 0: axis + */ +// #if NOT_EXCLUDED(OP_ismax) +@Namespace("sd::ops") public static class ismax extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public ismax(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public ismax(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public ismax position(long position) { + return (ismax)super.position(position); + } - /** - * This op produces binary matrix wrt to target dimension. - * Maximum value within each TAD is replaced with 1, other values are set to true. - * - * Int args: - * 0: axis - */ -// #if NOT_EXCLUDED(OP_ismax) - @Namespace("sd::ops") public static class ismax extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public ismax(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public ismax(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public ismax position(long position) { - return (ismax)super.position(position); - } - public ismax() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * Dilation2D op + * + * Int args: + * 0: isSameMode + */ +// #if NOT_EXCLUDED(OP_dilation2d) +@Namespace("sd::ops") public static class dilation2d extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public dilation2d(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public dilation2d(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public dilation2d position(long position) { + return (dilation2d)super.position(position); + } - /** - * Dilation2D op - * - * Int args: - * 0: isSameMode - */ -// #if NOT_EXCLUDED(OP_dilation2d) - @Namespace("sd::ops") public static class dilation2d extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public dilation2d(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public dilation2d(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public dilation2d position(long position) { - return (dilation2d)super.position(position); - } - public dilation2d() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_conv3dnew) +@Namespace("sd::ops") public static class conv3dnew extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public conv3dnew(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public conv3dnew(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public conv3dnew position(long position) { + return (conv3dnew)super.position(position); + } -// #if NOT_EXCLUDED(OP_conv3dnew) - @Namespace("sd::ops") public static class conv3dnew extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public conv3dnew(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public conv3dnew(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public conv3dnew position(long position) { - return (conv3dnew)super.position(position); - } - public conv3dnew() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class conv3dnew_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public conv3dnew_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public conv3dnew_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public conv3dnew_bp position(long position) { - return (conv3dnew_bp)super.position(position); - } - +@Namespace("sd::ops") public static class conv3dnew_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public conv3dnew_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public conv3dnew_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public conv3dnew_bp position(long position) { + return (conv3dnew_bp)super.position(position); + } + public conv3dnew_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_avgpool3dnew) +@Namespace("sd::ops") public static class avgpool3dnew extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public avgpool3dnew(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public avgpool3dnew(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public avgpool3dnew position(long position) { + return (avgpool3dnew)super.position(position); + } -// #if NOT_EXCLUDED(OP_avgpool3dnew) - @Namespace("sd::ops") public static class avgpool3dnew extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public avgpool3dnew(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public avgpool3dnew(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public avgpool3dnew position(long position) { - return (avgpool3dnew)super.position(position); - } - public avgpool3dnew() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class avgpool3dnew_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public avgpool3dnew_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public avgpool3dnew_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public avgpool3dnew_bp position(long position) { - return (avgpool3dnew_bp)super.position(position); - } - +@Namespace("sd::ops") public static class avgpool3dnew_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public avgpool3dnew_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public avgpool3dnew_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public avgpool3dnew_bp position(long position) { + return (avgpool3dnew_bp)super.position(position); + } + public avgpool3dnew_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_maxpool3dnew) +@Namespace("sd::ops") public static class maxpool3dnew extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public maxpool3dnew(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public maxpool3dnew(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public maxpool3dnew position(long position) { + return (maxpool3dnew)super.position(position); + } -// #if NOT_EXCLUDED(OP_maxpool3dnew) - @Namespace("sd::ops") public static class maxpool3dnew extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public maxpool3dnew(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public maxpool3dnew(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public maxpool3dnew position(long position) { - return (maxpool3dnew)super.position(position); - } - public maxpool3dnew() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class maxpool3dnew_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public maxpool3dnew_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public maxpool3dnew_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public maxpool3dnew_bp position(long position) { - return (maxpool3dnew_bp)super.position(position); - } - +@Namespace("sd::ops") public static class maxpool3dnew_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public maxpool3dnew_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public maxpool3dnew_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public maxpool3dnew_bp position(long position) { + return (maxpool3dnew_bp)super.position(position); + } + public maxpool3dnew_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This op same as maxpool2d with a variant to return a matrix of indexes for + * max values + * + * Input - 4D tensor + * Output: + * 0 - 4D tensor as input + * 1 - 4D tensor with max value indexes + * + * Int params: + * 9 int with 2x4 vectors and 1 bool value + */ +// #if NOT_EXCLUDED(OP_max_pool_woth_argmax) +@Namespace("sd::ops") public static class max_pool_with_argmax extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public max_pool_with_argmax(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public max_pool_with_argmax(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public max_pool_with_argmax position(long position) { + return (max_pool_with_argmax)super.position(position); + } - /** - * This op same as maxpool2d with a variant to return a matrix of indexes for max values - * - * Input - 4D tensor - * Output: - * 0 - 4D tensor as input - * 1 - 4D tensor with max value indexes - * - * Int params: - * 9 int with 2x4 vectors and 1 bool value - */ -// #if NOT_EXCLUDED(OP_max_pool_woth_argmax) - @Namespace("sd::ops") public static class max_pool_with_argmax extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public max_pool_with_argmax(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public max_pool_with_argmax(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public max_pool_with_argmax position(long position) { - return (max_pool_with_argmax)super.position(position); - } - public max_pool_with_argmax() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif +// #if NOT_EXCLUDED(OP_depthwise_conv2d) +@Namespace("sd::ops") public static class depthwise_conv2d extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public depthwise_conv2d(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public depthwise_conv2d(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public depthwise_conv2d position(long position) { + return (depthwise_conv2d)super.position(position); + } -// #if NOT_EXCLUDED(OP_depthwise_conv2d) - @Namespace("sd::ops") public static class depthwise_conv2d extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public depthwise_conv2d(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public depthwise_conv2d(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public depthwise_conv2d position(long position) { - return (depthwise_conv2d)super.position(position); - } - public depthwise_conv2d() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class depthwise_conv2d_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public depthwise_conv2d_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public depthwise_conv2d_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public depthwise_conv2d_bp position(long position) { - return (depthwise_conv2d_bp)super.position(position); - } - +@Namespace("sd::ops") public static class depthwise_conv2d_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public depthwise_conv2d_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public depthwise_conv2d_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public depthwise_conv2d_bp position(long position) { + return (depthwise_conv2d_bp)super.position(position); + } + public depthwise_conv2d_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * point-wise 2D convolution + * Expected input: + * x: 4D array + * weight: 4D Array [1, 1, iC, oC] (NHWC) or [oC, iC, 1, 1] (NCHW) + * bias: optional vector, length of oC + * + * IntArgs: + * 0: data format: 1 NHWC, 0 NCHW (optional, by default = NHWC) + */ +@Namespace("sd::ops") public static class pointwise_conv2d extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public pointwise_conv2d(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public pointwise_conv2d(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public pointwise_conv2d position(long position) { + return (pointwise_conv2d)super.position(position); + } - /** - * point-wise 2D convolution - * Expected input: - * x: 4D array - * weight: 4D Array [1, 1, iC, oC] (NHWC) or [oC, iC, 1, 1] (NCHW) - * bias: optional vector, length of oC - * - * IntArgs: - * 0: data format: 1 NHWC, 0 NCHW (optional, by default = NHWC) - */ - @Namespace("sd::ops") public static class pointwise_conv2d extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public pointwise_conv2d(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public pointwise_conv2d(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public pointwise_conv2d position(long position) { - return (pointwise_conv2d)super.position(position); - } - public pointwise_conv2d() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class deconv2d_tf extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public deconv2d_tf(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public deconv2d_tf(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public deconv2d_tf position(long position) { - return (deconv2d_tf)super.position(position); - } - +@Namespace("sd::ops") public static class deconv2d_tf extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public deconv2d_tf(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public deconv2d_tf(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public deconv2d_tf position(long position) { + return (deconv2d_tf)super.position(position); + } + public deconv2d_tf() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - - - + // namespace ops + // namespace sd // #endif @@ -15738,251 +18069,248 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #define LIBND4J_HEADERS_LIST_H // #include - // list operations, basically all around NDArrayList +// list operations, basically all around NDArrayList + +/** + * This operations puts given NDArray into (optionally) given NDArrayList. + * If no NDArrayList was provided - new one will be created + */ +// #if NOT_EXCLUDED(OP_write_list) +@Namespace("sd::ops") public static class write_list extends DeclarableListOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public write_list(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public write_list(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public write_list position(long position) { + return (write_list)super.position(position); + } - /** - * This operations puts given NDArray into (optionally) given NDArrayList. - * If no NDArrayList was provided - new one will be created - */ -// #if NOT_EXCLUDED(OP_write_list) - @Namespace("sd::ops") public static class write_list extends DeclarableListOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public write_list(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public write_list(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public write_list position(long position) { - return (write_list)super.position(position); - } - public write_list() { super((Pointer)null); allocate(); } private native void allocate(); } -// #endif +// #endif + +/** + * This operation concatenates given NDArrayList, and returns NDArray as result + */ +// #if NOT_EXCLUDED(OP_stack_list) +@Namespace("sd::ops") public static class stack_list extends DeclarableListOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public stack_list(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public stack_list(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public stack_list position(long position) { + return (stack_list)super.position(position); + } - /** - * This operation concatenates given NDArrayList, and returns NDArray as result - */ -// #if NOT_EXCLUDED(OP_stack_list) - @Namespace("sd::ops") public static class stack_list extends DeclarableListOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public stack_list(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public stack_list(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public stack_list position(long position) { - return (stack_list)super.position(position); - } - public stack_list() { super((Pointer)null); allocate(); } private native void allocate(); } -// #endif +// #endif + +/** + * This operations selects specified index fron NDArrayList and returns it as + * NDArray Expected arguments: x: non-empty list indices: optional, scalar with + * index + * + * Int args: + * optional, index + */ +// #if NOT_EXCLUDED(OP_read_list) +@Namespace("sd::ops") public static class read_list extends DeclarableListOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public read_list(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public read_list(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public read_list position(long position) { + return (read_list)super.position(position); + } - /** - * This operations selects specified index fron NDArrayList and returns it as NDArray - * Expected arguments: - * x: non-empty list - * indices: optional, scalar with index - * - * Int args: - * optional, index - */ -// #if NOT_EXCLUDED(OP_read_list) - @Namespace("sd::ops") public static class read_list extends DeclarableListOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public read_list(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public read_list(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public read_list position(long position) { - return (read_list)super.position(position); - } - public read_list() { super((Pointer)null); allocate(); } private native void allocate(); } -// #endif +// #endif + +/** + * This operations selects specified indices fron NDArrayList and returns them + * as NDArray Expected arguments: x: non-empty list indices: optional, vector + * with indices + * + * Int args: + * optional, indices + */ +// #if NOT_EXCLUDED(OP_pick_list) +@Namespace("sd::ops") public static class pick_list extends DeclarableListOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public pick_list(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public pick_list(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public pick_list position(long position) { + return (pick_list)super.position(position); + } - /** - * This operations selects specified indices fron NDArrayList and returns them as NDArray - * Expected arguments: - * x: non-empty list - * indices: optional, vector with indices - * - * Int args: - * optional, indices - */ -// #if NOT_EXCLUDED(OP_pick_list) - @Namespace("sd::ops") public static class pick_list extends DeclarableListOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public pick_list(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public pick_list(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public pick_list position(long position) { - return (pick_list)super.position(position); - } - public pick_list() { super((Pointer)null); allocate(); } private native void allocate(); } -// #endif +// #endif + +/** + * This operations returns scalar, with number of existing arrays within given + * NDArrayList Expected arguments: x: list + */ +// #if NOT_EXCLUDED(OP_size_list) +@Namespace("sd::ops") public static class size_list extends DeclarableListOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public size_list(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public size_list(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public size_list position(long position) { + return (size_list)super.position(position); + } - /** - * This operations returns scalar, with number of existing arrays within given NDArrayList - * Expected arguments: - * x: list - */ -// #if NOT_EXCLUDED(OP_size_list) - @Namespace("sd::ops") public static class size_list extends DeclarableListOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public size_list(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public size_list(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public size_list position(long position) { - return (size_list)super.position(position); - } - public size_list() { super((Pointer)null); allocate(); } private native void allocate(); } -// #endif +// #endif + +/** + * This operation creates new empty NDArrayList + */ +// #if NOT_EXCLUDED(OP_create_list) +@Namespace("sd::ops") public static class create_list extends DeclarableListOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public create_list(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public create_list(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public create_list position(long position) { + return (create_list)super.position(position); + } - /** - * This operation creates new empty NDArrayList - */ -// #if NOT_EXCLUDED(OP_create_list) - @Namespace("sd::ops") public static class create_list extends DeclarableListOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public create_list(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public create_list(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public create_list position(long position) { - return (create_list)super.position(position); - } - public create_list() { super((Pointer)null); allocate(); } private native void allocate(); } -// #endif +// #endif + +/** + * This operation unpacks given NDArray into specified NDArrayList wrt specified + * indices + */ +// #if NOT_EXCLUDED(OP_scatter_list) +@Namespace("sd::ops") public static class scatter_list extends DeclarableListOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public scatter_list(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public scatter_list(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public scatter_list position(long position) { + return (scatter_list)super.position(position); + } - /** - * This operation unpacks given NDArray into specified NDArrayList wrt specified indices - */ -// #if NOT_EXCLUDED(OP_scatter_list) - @Namespace("sd::ops") public static class scatter_list extends DeclarableListOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public scatter_list(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public scatter_list(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public scatter_list position(long position) { - return (scatter_list)super.position(position); - } - public scatter_list() { super((Pointer)null); allocate(); } private native void allocate(); } -// #endif +// #endif + +/** + * This operation splits given NDArray into chunks, and stores them into given + * NDArrayList wert sizes Expected arguments: list: optional, NDArrayList. if + * not available - new NDArrayList will be created array: array to be split + * sizes: vector with sizes for each chunk + */ +// #if NOT_EXCLUDED(OP_split_list) +@Namespace("sd::ops") public static class split_list extends DeclarableListOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public split_list(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public split_list(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public split_list position(long position) { + return (split_list)super.position(position); + } - /** - * This operation splits given NDArray into chunks, and stores them into given NDArrayList wert sizes - * Expected arguments: - * list: optional, NDArrayList. if not available - new NDArrayList will be created - * array: array to be split - * sizes: vector with sizes for each chunk - */ -// #if NOT_EXCLUDED(OP_split_list) - @Namespace("sd::ops") public static class split_list extends DeclarableListOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public split_list(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public split_list(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public split_list position(long position) { - return (split_list)super.position(position); - } - public split_list() { super((Pointer)null); allocate(); } private native void allocate(); } -// #endif +// #endif + +/** + * This operation builds NDArray from NDArrayList using indices + * Expected arguments: + * x: non-empty list + * indices: vector with indices for gather operation + */ +// #if NOT_EXCLUDED(OP_gather_list) +@Namespace("sd::ops") public static class gather_list extends DeclarableListOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public gather_list(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public gather_list(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public gather_list position(long position) { + return (gather_list)super.position(position); + } - /** - * This operation builds NDArray from NDArrayList using indices - * Expected arguments: - * x: non-empty list - * indices: vector with indices for gather operation - */ -// #if NOT_EXCLUDED(OP_gather_list) - @Namespace("sd::ops") public static class gather_list extends DeclarableListOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public gather_list(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public gather_list(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public gather_list position(long position) { - return (gather_list)super.position(position); - } - public gather_list() { super((Pointer)null); allocate(); } private native void allocate(); } -// #endif +// #endif + +/** + * This operation clones given NDArrayList + */ +// #if NOT_EXCLUDED(OP_clone_list) +@Namespace("sd::ops") public static class clone_list extends DeclarableListOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public clone_list(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public clone_list(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public clone_list position(long position) { + return (clone_list)super.position(position); + } - /** - * This operation clones given NDArrayList - */ -// #if NOT_EXCLUDED(OP_clone_list) - @Namespace("sd::ops") public static class clone_list extends DeclarableListOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public clone_list(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public clone_list(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public clone_list position(long position) { - return (clone_list)super.position(position); - } - public clone_list() { super((Pointer)null); allocate(); } private native void allocate(); } -// #endif +// #endif + +/** + * This operation unstacks given NDArray into NDArrayList by the first dimension + */ +// #if NOT_EXCLUDED(OP_unstack_list) +@Namespace("sd::ops") public static class unstack_list extends DeclarableListOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public unstack_list(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public unstack_list(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public unstack_list position(long position) { + return (unstack_list)super.position(position); + } - /** - * This operation unstacks given NDArray into NDArrayList by the first dimension - */ -// #if NOT_EXCLUDED(OP_unstack_list) - @Namespace("sd::ops") public static class unstack_list extends DeclarableListOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public unstack_list(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public unstack_list(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public unstack_list position(long position) { - return (unstack_list)super.position(position); - } - public unstack_list() { super((Pointer)null); allocate(); } private native void allocate(); } -// #endif - - +// #endif + // namespace ops + // namespace sd // #endif @@ -16013,714 +18341,742 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #include - ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of operation for Simple Recurrent Unit: "Training RNNs as Fast as CNNs" Tao Lei, Yu Zhang, Yoav Artzi - * - * Input arrays: - * 0: input 3d tensor with shape [bS x K x N], N - number of time steps, bS - batch size, K - number of features - * 1: 2d tensor of weights [3K x K] - * 2: row of biases with twice length [1 x 2K] - * 3: 2d tensor of previous cell state [bS x K] - * 4: optional, 2d tensor of dropout mask [bS x K] - * - * Output arrays: - * 0: 3d tensor of cell output [bS x K x N] - * 1: 3d tensor of cell state [bS x K x N] - */ -// #if NOT_EXCLUDED(OP_sru) - @Namespace("sd::ops") public static class sru extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public sru(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public sru(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public sru position(long position) { - return (sru)super.position(position); - } - +////////////////////////////////////////////////////////////////////////// +/** + * Implementation of operation for Simple Recurrent Unit: "Training RNNs as Fast + * as CNNs" Tao Lei, Yu Zhang, Yoav Artzi + * + * Input arrays: + * 0: input 3d tensor with shape [bS x K x N], N - number of time steps, bS - + * batch size, K - number of features 1: 2d tensor of weights [3K x K] 2: row of + * biases with twice length [1 x 2K] 3: 2d tensor of previous cell state [bS x + * K] 4: optional, 2d tensor of dropout mask [bS x K] + * + * Output arrays: + * 0: 3d tensor of cell output [bS x K x N] + * 1: 3d tensor of cell state [bS x K x N] + */ +// #if NOT_EXCLUDED(OP_sru) +@Namespace("sd::ops") public static class sru extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public sru(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public sru(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public sru position(long position) { + return (sru)super.position(position); + } + public sru() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +////////////////////////////////////////////////////////////////////////// +/** + * Implementation of operation for Simple Recurrent Unit (bidirectional case): + * "Training RNNs as Fast as CNNs" Tao Lei, Yu Zhang, Yoav Artzi + * + * Input arrays: + * 0: input 3d tensor with shape [N x bS x 2K], N - number of time steps, bS + * - batch size, K - number of features 1: 2d tensor of weights [2K x 6K] 2: row + * of biases with twice length [1 x 4K] 3: 2d tensor of previous cell state [bS + * x 2K] 4: optional, 2d tensor of dropout mask [bS x 2K] + * + * Output arrays: + * 0: 3d tensor of cell output [N x bS x 2K] + * 1: 3d tensor of cell state [N x bS x 2K] + */ +// #if NOT_EXCLUDED(OP_sru_bi) +@Namespace("sd::ops") public static class sru_bi extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public sru_bi(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public sru_bi(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public sru_bi position(long position) { + return (sru_bi)super.position(position); + } - ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of operation for Simple Recurrent Unit (bidirectional case): "Training RNNs as Fast as CNNs" Tao Lei, Yu Zhang, Yoav Artzi - * - * Input arrays: - * 0: input 3d tensor with shape [N x bS x 2K], N - number of time steps, bS - batch size, K - number of features - * 1: 2d tensor of weights [2K x 6K] - * 2: row of biases with twice length [1 x 4K] - * 3: 2d tensor of previous cell state [bS x 2K] - * 4: optional, 2d tensor of dropout mask [bS x 2K] - * - * Output arrays: - * 0: 3d tensor of cell output [N x bS x 2K] - * 1: 3d tensor of cell state [N x bS x 2K] - */ -// #if NOT_EXCLUDED(OP_sru_bi) - @Namespace("sd::ops") public static class sru_bi extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public sru_bi(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public sru_bi(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public sru_bi position(long position) { - return (sru_bi)super.position(position); - } - public sru_bi() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif +////////////////////////////////////////////////////////////////////////// +/** + * Implementation of operation for back propagation in Simple Recurrent Unit: + * "Training RNNs as Fast as CNNs" Tao Lei, Yu Zhang, Yoav Artzi + * + * Input arrays: + * 0: input 3d tensor with shape [bS x K x N], N - number of time steps, bS - + * batch size, K - number of features 1: 2d tensor of weights [3K x K] 2: row of + * biases with twice length [1 x 2K] 3: 2d tensor of previous cell state [bS x + * K] 4: 3d tensor of cell state [bS x K x N] 5: 2d tensor of cell state + * gradients [bS x K] 6: 3d tensor of state output gradients [bS x K x N] 7: + * optional, 2d tensor of dropout mask [bS x K] + * + * Output arrays: + * 0: 3d tensor of input gradients [bS x K x N] + * 1: 3d tensor of weights gradients [bS x 3K x K] + * 2: 2d, row of biases gradients [1 x 2K] + * 3: 2d, tensor of state gradients [bS x K] + */ +// #if NOT_EXCLUDED(OP_sru) +@Namespace("sd::ops") public static class sru_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public sru_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public sru_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public sru_bp position(long position) { + return (sru_bp)super.position(position); + } - ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of operation for back propagation in Simple Recurrent Unit: "Training RNNs as Fast as CNNs" Tao Lei, Yu Zhang, Yoav Artzi - * - * Input arrays: - * 0: input 3d tensor with shape [bS x K x N], N - number of time steps, bS - batch size, K - number of features - * 1: 2d tensor of weights [3K x K] - * 2: row of biases with twice length [1 x 2K] - * 3: 2d tensor of previous cell state [bS x K] - * 4: 3d tensor of cell state [bS x K x N] - * 5: 2d tensor of cell state gradients [bS x K] - * 6: 3d tensor of state output gradients [bS x K x N] - * 7: optional, 2d tensor of dropout mask [bS x K] - * - * Output arrays: - * 0: 3d tensor of input gradients [bS x K x N] - * 1: 3d tensor of weights gradients [bS x 3K x K] - * 2: 2d, row of biases gradients [1 x 2K] - * 3: 2d, tensor of state gradients [bS x K] - */ -// #if NOT_EXCLUDED(OP_sru) - @Namespace("sd::ops") public static class sru_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public sru_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public sru_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public sru_bp position(long position) { - return (sru_bp)super.position(position); - } - public sru_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +////////////////////////////////////////////////////////////////////////// +/** + * Implementation of operation for back propagation in Simple Recurrent Unit + * (bidirectional case): "Training RNNs as Fast as CNNs" Tao Lei, Yu Zhang, Yoav + * Artzi + * + * Input arrays: + * 0: input 3d tensor with shape [N x bS x 2K], N - number of time steps, bS + * - batch size, K - number of features 1: 2d tensor of weights [2K x 6K] 2: row + * of biases with twice length [1 x 4K] 3: 2d tensor of previous cell state [bS + * x 2K] 4: 3d tensor of cell state [N x bS x 2K] 5: 2d tensor of cell state + * gradients [bS x 2K] 6: 3d tensor of state output gradients [N x bS x 2K] 7: + * optional, 2d tensor of dropout mask [bS x 2K] + * + * Output arrays: + * 0: 3d tensor of input gradients [N x bS x 2K] + * 1: 3d tensor of weights gradients [N x 2K x 6K] + * 2: 2d, row of biases gradients [1 x 4K] + * 3: 2d, tensor of state gradients [bS x 2K] + */ +// #if NOT_EXCLUDED(OP_sru_bi) +@Namespace("sd::ops") public static class sru_bi_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public sru_bi_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public sru_bi_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public sru_bi_bp position(long position) { + return (sru_bi_bp)super.position(position); + } - ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of operation for back propagation in Simple Recurrent Unit (bidirectional case): "Training RNNs as Fast as CNNs" Tao Lei, Yu Zhang, Yoav Artzi - * - * Input arrays: - * 0: input 3d tensor with shape [N x bS x 2K], N - number of time steps, bS - batch size, K - number of features - * 1: 2d tensor of weights [2K x 6K] - * 2: row of biases with twice length [1 x 4K] - * 3: 2d tensor of previous cell state [bS x 2K] - * 4: 3d tensor of cell state [N x bS x 2K] - * 5: 2d tensor of cell state gradients [bS x 2K] - * 6: 3d tensor of state output gradients [N x bS x 2K] - * 7: optional, 2d tensor of dropout mask [bS x 2K] - * - * Output arrays: - * 0: 3d tensor of input gradients [N x bS x 2K] - * 1: 3d tensor of weights gradients [N x 2K x 6K] - * 2: 2d, row of biases gradients [1 x 4K] - * 3: 2d, tensor of state gradients [bS x 2K] - */ -// #if NOT_EXCLUDED(OP_sru_bi) - @Namespace("sd::ops") public static class sru_bi_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public sru_bi_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public sru_bi_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public sru_bi_bp position(long position) { - return (sru_bi_bp)super.position(position); - } - public sru_bi_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif +////////////////////////////////////////////////////////////////////////// +/** + * Implementation of operation for LSTM cell with peep hole connections: + * S. Hochreiter and J. Schmidhuber. "Long Short-Term Memory". Neural + * Computation and https://research.google.com/pubs/archive/43905.pdf Hasim Sak, + * Andrew Senior, and Francoise Beaufays. "Long short-term memory recurrent + * neural network architectures for large scale acoustic modeling." INTERSPEECH, + * 2014. + * + * Input arrays: + * 0: input with shape [batchSize x inSize], batchSize - batch size, inSize - + * number of features 1: previous cell output [batchSize x numProj], that is at + * previous time step t-1, in case of projection=false -> numProj=numUnits!!! 2: + * previous cell state [batchSize x numUnits], that is at previous time step + * t-1 3: input-to-hidden weights, [inSize x 4*numUnits] 4: hidden-to-hidden + * weights, [numProj x 4*numUnits] 5: diagonal weights for peephole connections + * [3*numUnits] 6: projection weights [numUnits x numProj] 7: biases, + * [4*numUnits] + * + * Input integer arguments: + * 0: if not zero, provide peephole connections + * 1: if not zero, then projection is performed, if zero then + * numProj==numUnits is mandatory! + * + * Input float arguments: + * 0: clipping value for cell state, if it is not equal to zero, then cell + * state is clipped 1: clipping value for projected cell output, if it is not + * equal to zero, then projected cell output is clipped 2: the bias added to + * forget gates in order to reduce the scale of forgetting in the beginning of + * the training + * + * Output arrays: + * 0: current cell output [batchSize x numProj], that is at current time step + * t 1: current cell state [batchSize x numUnits], that is at current time step + * t + */ +// #if NOT_EXCLUDED(OP_lstmCell) +@Namespace("sd::ops") public static class lstmCell extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public lstmCell(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public lstmCell(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public lstmCell position(long position) { + return (lstmCell)super.position(position); + } - ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of operation for LSTM cell with peep hole connections: - * S. Hochreiter and J. Schmidhuber. "Long Short-Term Memory". Neural Computation - * and - * https://research.google.com/pubs/archive/43905.pdf - * Hasim Sak, Andrew Senior, and Francoise Beaufays. "Long short-term memory recurrent neural network architectures for large scale acoustic modeling." INTERSPEECH, 2014. - * - * Input arrays: - * 0: input with shape [batchSize x inSize], batchSize - batch size, inSize - number of features - * 1: previous cell output [batchSize x numProj], that is at previous time step t-1, in case of projection=false -> numProj=numUnits!!! - * 2: previous cell state [batchSize x numUnits], that is at previous time step t-1 - * 3: input-to-hidden weights, [inSize x 4*numUnits] - * 4: hidden-to-hidden weights, [numProj x 4*numUnits] - * 5: diagonal weights for peephole connections [3*numUnits] - * 6: projection weights [numUnits x numProj] - * 7: biases, [4*numUnits] - * - * Input integer arguments: - * 0: if not zero, provide peephole connections - * 1: if not zero, then projection is performed, if zero then numProj==numUnits is mandatory! - * - * Input float arguments: - * 0: clipping value for cell state, if it is not equal to zero, then cell state is clipped - * 1: clipping value for projected cell output, if it is not equal to zero, then projected cell output is clipped - * 2: the bias added to forget gates in order to reduce the scale of forgetting in the beginning of the training - * - * Output arrays: - * 0: current cell output [batchSize x numProj], that is at current time step t - * 1: current cell state [batchSize x numUnits], that is at current time step t - */ -// #if NOT_EXCLUDED(OP_lstmCell) - @Namespace("sd::ops") public static class lstmCell extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public lstmCell(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public lstmCell(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public lstmCell position(long position) { - return (lstmCell)super.position(position); - } - public lstmCell() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_lstmLayerCell) +@Namespace("sd::ops") public static class lstmLayerCell extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public lstmLayerCell(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public lstmLayerCell(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public lstmLayerCell position(long position) { + return (lstmLayerCell)super.position(position); + } -// #if NOT_EXCLUDED(OP_lstmLayerCell) - @Namespace("sd::ops") public static class lstmLayerCell extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public lstmLayerCell(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public lstmLayerCell(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public lstmLayerCell position(long position) { - return (lstmLayerCell)super.position(position); - } - public lstmLayerCell() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif -// #if NOT_EXCLUDED(OP_lstmLayerCell) - @Namespace("sd::ops") public static class lstmLayerCellBp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public lstmLayerCellBp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public lstmLayerCellBp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public lstmLayerCellBp position(long position) { - return (lstmLayerCellBp)super.position(position); - } - +// #endif +// #if NOT_EXCLUDED(OP_lstmLayerCell) +@Namespace("sd::ops") public static class lstmLayerCellBp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public lstmLayerCellBp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public lstmLayerCellBp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public lstmLayerCellBp position(long position) { + return (lstmLayerCellBp)super.position(position); + } + public lstmLayerCellBp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif +////////////////////////////////////////////////////////////////////////// +/** + * Implementation of operation for LSTM cell with optional peep hole + * connections: S. Hochreiter and J. Schmidhuber. "Long Short-Term Memory". + * Neural Computation and https://research.google.com/pubs/archive/43905.pdf + * Hasim Sak, Andrew Senior, and Francoise Beaufays. "Long short-term memory + * recurrent neural network architectures for large scale acoustic modeling." + * INTERSPEECH, 2014. See also: https://arxiv.org/pdf/1503.04069.pdf + * + * Input arrays: + * 0: input [bS, inSize] at time t + * 1: previous cell state [bS, numUnits], time t-1 + * 2: previous output [bS, numUnits], time t-1 + * 3: Weights - concatenated (input-to-hidden, hidden-to-hidden weights) + * weights, [(inSize+numUnits), 4*numUnits] 4: weights - cell peephole (t-1) + * connections to input modulation gate, [numUnits] 5: weights - cell peephole + * (t-1) connections to forget gate, [numUnits] 6: weights - cell peephole (t) + * connections to output gate, [numUnits] 7: biases, shape [4*numUnits] + * + * Input integer arguments: + * 0: if not zero, provide peephole connections + * + * Input float arguments: + * 0: the bias added to forget gates in order to reduce the scale of + * forgetting in the beginning of the training 1: clipping value for cell state, + * if it is not equal to zero, then cell state is clipped + * + * Output arrays: + * 0: i - Input modulation gate activations [bS, numUnits] + * 1: c (cs) - Cell state (pre tanh) [bs, numUnits] (cs) + * 2: f - Output - forget gate activations [bs, numUnits] + * 3: o - Output - output gate activations [bs, numUnits] + * 4: z (ci) - Output - block input [bs, numUnits] + * 5: h (co) - Cell state, post tanh [bs, numUnits] + * 6: y (h) - Current cell output [bS, numUnits], time t + */ +// #if NOT_EXCLUDED(OP_lstmBlockCell) +@Namespace("sd::ops") public static class lstmBlockCell extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public lstmBlockCell(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public lstmBlockCell(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public lstmBlockCell position(long position) { + return (lstmBlockCell)super.position(position); + } - ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of operation for LSTM cell with optional peep hole connections: - * S. Hochreiter and J. Schmidhuber. "Long Short-Term Memory". Neural Computation - * and - * https://research.google.com/pubs/archive/43905.pdf - * Hasim Sak, Andrew Senior, and Francoise Beaufays. "Long short-term memory recurrent neural network architectures for large scale acoustic modeling." INTERSPEECH, 2014. - * See also: https://arxiv.org/pdf/1503.04069.pdf - * - * Input arrays: - * 0: input [bS, inSize] at time t - * 1: previous cell state [bS, numUnits], time t-1 - * 2: previous output [bS, numUnits], time t-1 - * 3: Weights - concatenated (input-to-hidden, hidden-to-hidden weights) weights, [(inSize+numUnits), 4*numUnits] - * 4: weights - cell peephole (t-1) connections to input modulation gate, [numUnits] - * 5: weights - cell peephole (t-1) connections to forget gate, [numUnits] - * 6: weights - cell peephole (t) connections to output gate, [numUnits] - * 7: biases, shape [4*numUnits] - * - * Input integer arguments: - * 0: if not zero, provide peephole connections - * - * Input float arguments: - * 0: the bias added to forget gates in order to reduce the scale of forgetting in the beginning of the training - * 1: clipping value for cell state, if it is not equal to zero, then cell state is clipped - * - * Output arrays: - * 0: i - Input modulation gate activations [bS, numUnits] - * 1: c (cs) - Cell state (pre tanh) [bs, numUnits] (cs) - * 2: f - Output - forget gate activations [bs, numUnits] - * 3: o - Output - output gate activations [bs, numUnits] - * 4: z (ci) - Output - block input [bs, numUnits] - * 5: h (co) - Cell state, post tanh [bs, numUnits] - * 6: y (h) - Current cell output [bS, numUnits], time t - */ -// #if NOT_EXCLUDED(OP_lstmBlockCell) - @Namespace("sd::ops") public static class lstmBlockCell extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public lstmBlockCell(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public lstmBlockCell(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public lstmBlockCell position(long position) { - return (lstmBlockCell)super.position(position); - } - public lstmBlockCell() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +////////////////////////////////////////////////////////////////////////// +/** + * Implementation of operation for LSTM layer with optional peep hole + * connections. See lstmBlockCell for details. lstmBlockCell is used internally + * for computation. This method expects as input (and returns as output) + * sequences in one of 3 formats, depending on the data format arg: dataFormat = + * 0 -> TNS: shape [timeLength, numExamples, inOutSize] - sometimes referred to + * as "time major" dataFormat = 1 -> NST: shape [numExamples, inOutSize, + * timeLength] dataFormat = 2 -> NTS: shape [numExamples, timeLength, inOutSize] + * - TF "time_major=false" layout + * + * + * Input arrays: + * 0: max sequence length; long/int64 scalar + * 1: input [seqLength, bS, inSize] at time t + * 2: previous/initial cell state [bS, numUnits] + * 3: previous/initial output [bS, numUnits] + * 4: Weights - concatenated (input-to-hidden, hidden-to-hidden weights) + * weights, [(inSize+numUnits), 4*numUnits] 5: weights - cell peephole (t-1) + * connections to input modulation gate, [numUnits] 6: weights - cell peephole + * (t-1) connections to forget gate, [numUnits] 7: weights - cell peephole (t) + * connections to output gate, [numUnits] 8: biases, Shape [4*numUnits] + * + * Input integer arguments: + * 0: if not zero, provide peephole connections + * 1: Data format - 0=TNS=[seqLen,mb,size]; 1=NST=[mb,size,seqLen]; + * 2=NTS=[mb,seqLen,size] + * + * Input float arguments: + * 0: the bias added to forget gates in order to reduce the scale of + * forgetting in the beginning of the training 1: clipping value for cell state, + * if it is not equal to zero, then cell state is clipped + * + * Output arrays: + * 0: i - Input modulation gate activations, rank 3, shape as per + * dataFormat 1: c (cs) - Cell state (pre tanh), rank 3, shape as per dataFormat + * 2: f - Output - forget gate activations, rank 3, shape as per + * dataFormat 3: o - Output - output gate activations, rank 3, shape as per + * dataFormat 4: z (ci) - Output - block input, rank 3, shape as per dataFormat + * 5: h (co) - Cell state, post tanh, rank 3, shape as per dataFormat + * 6: y (h) - Current cell output, rank 3, shape as per dataFormat + */ +// #if NOT_EXCLUDED(OP_lstmBlock) +@Namespace("sd::ops") public static class lstmBlock extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public lstmBlock(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public lstmBlock(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public lstmBlock position(long position) { + return (lstmBlock)super.position(position); + } - ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of operation for LSTM layer with optional peep hole connections. - * See lstmBlockCell for details. lstmBlockCell is used internally for computation. - * This method expects as input (and returns as output) sequences in one of 3 formats, depending on the data format arg: - * dataFormat = 0 -> TNS: shape [timeLength, numExamples, inOutSize] - sometimes referred to as "time major" - * dataFormat = 1 -> NST: shape [numExamples, inOutSize, timeLength] - * dataFormat = 2 -> NTS: shape [numExamples, timeLength, inOutSize] - TF "time_major=false" layout - * - * - * Input arrays: - * 0: max sequence length; long/int64 scalar - * 1: input [seqLength, bS, inSize] at time t - * 2: previous/initial cell state [bS, numUnits] - * 3: previous/initial output [bS, numUnits] - * 4: Weights - concatenated (input-to-hidden, hidden-to-hidden weights) weights, [(inSize+numUnits), 4*numUnits] - * 5: weights - cell peephole (t-1) connections to input modulation gate, [numUnits] - * 6: weights - cell peephole (t-1) connections to forget gate, [numUnits] - * 7: weights - cell peephole (t) connections to output gate, [numUnits] - * 8: biases, Shape [4*numUnits] - * - * Input integer arguments: - * 0: if not zero, provide peephole connections - * 1: Data format - 0=TNS=[seqLen,mb,size]; 1=NST=[mb,size,seqLen]; 2=NTS=[mb,seqLen,size] - * - * Input float arguments: - * 0: the bias added to forget gates in order to reduce the scale of forgetting in the beginning of the training - * 1: clipping value for cell state, if it is not equal to zero, then cell state is clipped - * - * Output arrays: - * 0: i - Input modulation gate activations, rank 3, shape as per dataFormat - * 1: c (cs) - Cell state (pre tanh), rank 3, shape as per dataFormat - * 2: f - Output - forget gate activations, rank 3, shape as per dataFormat - * 3: o - Output - output gate activations, rank 3, shape as per dataFormat - * 4: z (ci) - Output - block input, rank 3, shape as per dataFormat - * 5: h (co) - Cell state, post tanh, rank 3, shape as per dataFormat - * 6: y (h) - Current cell output, rank 3, shape as per dataFormat - */ -// #if NOT_EXCLUDED(OP_lstmBlock) - @Namespace("sd::ops") public static class lstmBlock extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public lstmBlock(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public lstmBlock(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public lstmBlock position(long position) { - return (lstmBlock)super.position(position); - } - public lstmBlock() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +////////////////////////////////////////////////////////////////////////// +// #if NOT_EXCLUDED(OP_lstmLayer) +@Namespace("sd::ops") public static class lstmLayer extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public lstmLayer(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public lstmLayer(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public lstmLayer position(long position) { + return (lstmLayer)super.position(position); + } - ////////////////////////////////////////////////////////////////////////// -// #if NOT_EXCLUDED(OP_lstmLayer) - @Namespace("sd::ops") public static class lstmLayer extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public lstmLayer(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public lstmLayer(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public lstmLayer position(long position) { - return (lstmLayer)super.position(position); - } - public lstmLayer() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +////////////////////////////////////////////////////////////////////////// +// #if NOT_EXCLUDED(OP_lstmLayer) +@Namespace("sd::ops") public static class lstmLayer_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public lstmLayer_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public lstmLayer_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public lstmLayer_bp position(long position) { + return (lstmLayer_bp)super.position(position); + } - ////////////////////////////////////////////////////////////////////////// -// #if NOT_EXCLUDED(OP_lstmLayer) - @Namespace("sd::ops") public static class lstmLayer_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public lstmLayer_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public lstmLayer_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public lstmLayer_bp position(long position) { - return (lstmLayer_bp)super.position(position); - } - public lstmLayer_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif +////////////////////////////////////////////////////////////////////////// +/** + * Implementation of operations for Simple Recurrent Unit cell: "Training RNNs + * as Fast as CNNs" Tao Lei, Yu Zhang, Yoav Artzi + * + * Input arrays: + * 0: input with shape [batchSize x inSize], batchSize - batch size, inSize - + * number of features 1: previous cell state [batchSize x inSize], that is at + * previous time step t-1 2: weights [inSize x 3*inSize] 3: biases [1 x + * 2*inSize] + * + * Output arrays: + * 0: current cell output [batchSize x inSize], that is at current time step + * t 1: current cell state [batchSize x inSize], that is at current time step t + */ +// #if NOT_EXCLUDED(OP_sruCell) +@Namespace("sd::ops") public static class sruCell extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public sruCell(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public sruCell(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public sruCell position(long position) { + return (sruCell)super.position(position); + } - ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of operations for Simple Recurrent Unit cell: "Training RNNs as Fast as CNNs" Tao Lei, Yu Zhang, Yoav Artzi - * - * Input arrays: - * 0: input with shape [batchSize x inSize], batchSize - batch size, inSize - number of features - * 1: previous cell state [batchSize x inSize], that is at previous time step t-1 - * 2: weights [inSize x 3*inSize] - * 3: biases [1 x 2*inSize] - * - * Output arrays: - * 0: current cell output [batchSize x inSize], that is at current time step t - * 1: current cell state [batchSize x inSize], that is at current time step t - */ -// #if NOT_EXCLUDED(OP_sruCell) - @Namespace("sd::ops") public static class sruCell extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public sruCell(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public sruCell(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public sruCell position(long position) { - return (sruCell)super.position(position); - } - public sruCell() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif +////////////////////////////////////////////////////////////////////////// +/** + * Implementation of gated Recurrent Unit cell: + * Kyunghyun Cho, Bart van Merrienboer, Caglar Gulcehre, Dzmitry Bahdanau, + * Fethi Bougares, Holger Schwenk, Yoshua Bengio "Learning Phrase + * Representations using RNN Encoder-Decoder for Statistical Machine + * Translation" + * + * Input arrays: + * 0: input with shape [batchSize x inSize], batchSize - batch size, inSize - + * number of features 1: previous cell output [batchSize x numUnits], that is + * at previous time step t-1 2: RU weights - [(inSize+numUnits), 2*numUnits] - + * reset and update gates (input/recurrent weights) 3: C weights - + * [(inSize+numUnits), numUnits] - cell gate (input/recurrent weights) 4: reset + * and update biases, [2*numUnits] - reset and update gates 5: cell biases, + * [numUnits] + * + * Output arrays: + * 0: Reset gate output [bS, numUnits] + * 1: Update gate output [bS, numUnits] + * 2: Cell gate output [bS, numUnits] + * 3: Current cell output [bS, numUnits] + */ +// #if NOT_EXCLUDED(OP_gruCell) +@Namespace("sd::ops") public static class gruCell extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public gruCell(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public gruCell(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public gruCell position(long position) { + return (gruCell)super.position(position); + } - ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of gated Recurrent Unit cell: - * Kyunghyun Cho, Bart van Merrienboer, Caglar Gulcehre, Dzmitry Bahdanau, Fethi Bougares, Holger Schwenk, Yoshua Bengio - * "Learning Phrase Representations using RNN Encoder-Decoder for Statistical Machine Translation" - * - * Input arrays: - * 0: input with shape [batchSize x inSize], batchSize - batch size, inSize - number of features - * 1: previous cell output [batchSize x numUnits], that is at previous time step t-1 - * 2: RU weights - [(inSize+numUnits), 2*numUnits] - reset and update gates (input/recurrent weights) - * 3: C weights - [(inSize+numUnits), numUnits] - cell gate (input/recurrent weights) - * 4: reset and update biases, [2*numUnits] - reset and update gates - * 5: cell biases, [numUnits] - * - * Output arrays: - * 0: Reset gate output [bS, numUnits] - * 1: Update gate output [bS, numUnits] - * 2: Cell gate output [bS, numUnits] - * 3: Current cell output [bS, numUnits] - */ -// #if NOT_EXCLUDED(OP_gruCell) - @Namespace("sd::ops") public static class gruCell extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public gruCell(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public gruCell(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public gruCell position(long position) { - return (gruCell)super.position(position); - } - public gruCell() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_gruCell) +@Namespace("sd::ops") public static class gruCell_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public gruCell_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public gruCell_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public gruCell_bp position(long position) { + return (gruCell_bp)super.position(position); + } -// #if NOT_EXCLUDED(OP_gruCell) - @Namespace("sd::ops") public static class gruCell_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public gruCell_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public gruCell_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public gruCell_bp position(long position) { - return (gruCell_bp)super.position(position); - } - public gruCell_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +////////////////////////////////////////////////////////////////////////// +/** + * Implementation of operation "LSTM time sequences" with peep hole connections: + * + * Input arrays: + * 0: input with shape [time x batchSize x inSize], time - number of time + * steps, batchSize - batch size, inSize - number of features 1: initial cell + * output [batchSize x numProj], that is at time step = 0, in case of + * projection=false -> numProj=numUnits!!! 2: initial cell state [batchSize x + * numUnits], that is at time step = 0 3: input-to-hidden weights, [inSize x + * 4*numUnits] 4: hidden-to-hidden weights, [numProj x 4*numUnits] 5: diagonal + * weights for peephole connections [3*numUnits] 6: projection weights [numUnits + * x numProj] 7: biases, [4*numUnits] + * + * Input integer arguments: + * 0: if not zero, provide peephole connections + * 1: if not zero, then projection is performed, if zero then + * numProj==numUnits is mandatory! + * + * Input float arguments: + * 0: clipping value for cell state, if it is not equal to zero, then cell + * state is clipped 1: clipping value for projected cell output, if it is not + * equal to zero, then projected cell output is clipped 2: the bias added to + * forget gates in order to reduce the scale of forgetting in the beginning of + * the training + * + * Output arrays: + * 0: cell outputs [time x batchSize x numProj], that is per each time step + * 1: cell states [time x batchSize x numUnits], that is per each time step + */ +// #if NOT_EXCLUDED(OP_lstm) +@Namespace("sd::ops") public static class lstm extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public lstm(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public lstm(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public lstm position(long position) { + return (lstm)super.position(position); + } - ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of operation "LSTM time sequences" with peep hole connections: - * - * Input arrays: - * 0: input with shape [time x batchSize x inSize], time - number of time steps, batchSize - batch size, inSize - number of features - * 1: initial cell output [batchSize x numProj], that is at time step = 0, in case of projection=false -> numProj=numUnits!!! - * 2: initial cell state [batchSize x numUnits], that is at time step = 0 - * 3: input-to-hidden weights, [inSize x 4*numUnits] - * 4: hidden-to-hidden weights, [numProj x 4*numUnits] - * 5: diagonal weights for peephole connections [3*numUnits] - * 6: projection weights [numUnits x numProj] - * 7: biases, [4*numUnits] - * - * Input integer arguments: - * 0: if not zero, provide peephole connections - * 1: if not zero, then projection is performed, if zero then numProj==numUnits is mandatory! - * - * Input float arguments: - * 0: clipping value for cell state, if it is not equal to zero, then cell state is clipped - * 1: clipping value for projected cell output, if it is not equal to zero, then projected cell output is clipped - * 2: the bias added to forget gates in order to reduce the scale of forgetting in the beginning of the training - * - * Output arrays: - * 0: cell outputs [time x batchSize x numProj], that is per each time step - * 1: cell states [time x batchSize x numUnits], that is per each time step - */ -// #if NOT_EXCLUDED(OP_lstm) - @Namespace("sd::ops") public static class lstm extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public lstm(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public lstm(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public lstm position(long position) { - return (lstm)super.position(position); - } - public lstm() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +////////////////////////////////////////////////////////////////////////// +/** + * Implementation of gated Recurrent Unit: + * + * Input arrays: + * 0: input with shape [time x batchSize x inSize], time - number of time + * steps, batchSize - batch size, inSize - number of features 1: initial cell + * output [batchSize x numUnits], that is at time step = 0 2: input-to-hidden + * weights, [inSize x 3*numUnits] 3: hidden-to-hidden weights, [numUnits x + * 3*numUnits] 4: biases, [3*numUnits] + * + * Output arrays: + * 0: cell outputs [time x batchSize x numUnits], that is per each time step + */ +// #if NOT_EXCLUDED(OP_gru) +@Namespace("sd::ops") public static class gru extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public gru(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public gru(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public gru position(long position) { + return (gru)super.position(position); + } - ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of gated Recurrent Unit: - * - * Input arrays: - * 0: input with shape [time x batchSize x inSize], time - number of time steps, batchSize - batch size, inSize - number of features - * 1: initial cell output [batchSize x numUnits], that is at time step = 0 - * 2: input-to-hidden weights, [inSize x 3*numUnits] - * 3: hidden-to-hidden weights, [numUnits x 3*numUnits] - * 4: biases, [3*numUnits] - * - * Output arrays: - * 0: cell outputs [time x batchSize x numUnits], that is per each time step - */ -// #if NOT_EXCLUDED(OP_gru) - @Namespace("sd::ops") public static class gru extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public gru(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public gru(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public gru position(long position) { - return (gru)super.position(position); - } - public gru() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_gru) +@Namespace("sd::ops") public static class gru_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public gru_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public gru_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public gru_bp position(long position) { + return (gru_bp)super.position(position); + } -// #if NOT_EXCLUDED(OP_gru) - @Namespace("sd::ops") public static class gru_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public gru_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public gru_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public gru_bp position(long position) { - return (gru_bp)super.position(position); - } - public gru_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +////////////////////////////////////////////////////////////////////////// +/** + * Implementation of operation "static RNN time sequences" with peep hole + * connections: + * + * Input arrays: + * 0: input with shape [time x batchSize x inSize], time - number of time + * steps, batchSize - batch size, inSize - number of features 1: input-to-hidden + * weights, [inSize x numUnits] 2: hidden-to-hidden weights, [numUnits x + * numUnits] 3: biases, [2*numUnits] 4: (optional) initial cell output + * [batchSize x numUnits], that is at time step = 0 5: (optional) vector with + * shape [batchSize] containing integer values within [0,time), each element of + * this vector set max time step per each input in batch, this provides no + * calculations for time >= maxTimeStep + * + * Output arrays: + * 0: cell outputs [time x batchSize x numUnits] + * 1: cell final non-zero output [batchSize x numUnits] + */ +@Namespace("sd::ops") public static class static_rnn extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public static_rnn(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public static_rnn(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public static_rnn position(long position) { + return (static_rnn)super.position(position); + } - ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of operation "static RNN time sequences" with peep hole connections: - * - * Input arrays: - * 0: input with shape [time x batchSize x inSize], time - number of time steps, batchSize - batch size, inSize - number of features - * 1: input-to-hidden weights, [inSize x numUnits] - * 2: hidden-to-hidden weights, [numUnits x numUnits] - * 3: biases, [2*numUnits] - * 4: (optional) initial cell output [batchSize x numUnits], that is at time step = 0 - * 5: (optional) vector with shape [batchSize] containing integer values within [0,time), each element of this vector set max time step per each input in batch, this provides no calculations for time >= maxTimeStep - * - * Output arrays: - * 0: cell outputs [time x batchSize x numUnits] - * 1: cell final non-zero output [batchSize x numUnits] - */ - @Namespace("sd::ops") public static class static_rnn extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public static_rnn(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public static_rnn(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public static_rnn position(long position) { - return (static_rnn)super.position(position); - } - public static_rnn() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of operation "static RNN time sequences" with peep hole connections: - * - * Input arrays: - * 0: input with shape [time x batchSize x inSize] or [batchSize x time x numUnits], time - number of time steps, batchSize - batch size, inSize - number of features - * 1: input-to-hidden weights, [inSize x numUnits] - * 2: hidden-to-hidden weights, [numUnits x numUnits] - * 3: biases, [2*numUnits] - * 4: (optional) initial cell output [batchSize x numUnits], that is at time step = 0 - * 5: (optional) vector with shape [batchSize] containing integer values within [0,time), each element of this vector set max time step per each input in batch, this provides no calculations for time >= maxTimeStep - * - * Input integer arguments: - * 0: (optional) timeMajor - if non zero then input shape is [time, batchSize, ...], else [batchSize, time, ...] - * - * Output arrays: - * 0: cell outputs [time x batchSize x numUnits] or [batchSize x time x numUnits] - * 1: cell final non-zero output [batchSize x numUnits] - */ - @Namespace("sd::ops") public static class dynamic_rnn extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public dynamic_rnn(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public dynamic_rnn(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public dynamic_rnn position(long position) { - return (dynamic_rnn)super.position(position); - } - +////////////////////////////////////////////////////////////////////////// +/** + * Implementation of operation "static RNN time sequences" with peep hole + * connections: + * + * Input arrays: + * 0: input with shape [time x batchSize x inSize] or [batchSize x time x + * numUnits], time - number of time steps, batchSize - batch size, inSize - + * number of features 1: input-to-hidden weights, [inSize x numUnits] 2: + * hidden-to-hidden weights, [numUnits x numUnits] 3: biases, [2*numUnits] 4: + * (optional) initial cell output [batchSize x numUnits], that is at time step = + * 0 5: (optional) vector with shape [batchSize] containing integer values + * within [0,time), each element of this vector set max time step per each input + * in batch, this provides no calculations for time >= maxTimeStep + * + * Input integer arguments: + * 0: (optional) timeMajor - if non zero then input shape is [time, + * batchSize, ...], else [batchSize, time, ...] + * + * Output arrays: + * 0: cell outputs [time x batchSize x numUnits] or [batchSize x time x + * numUnits] 1: cell final non-zero output [batchSize x numUnits] + */ +@Namespace("sd::ops") public static class dynamic_rnn extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public dynamic_rnn(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public dynamic_rnn(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public dynamic_rnn position(long position) { + return (dynamic_rnn)super.position(position); + } + public dynamic_rnn() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of operation "static RNN time sequences" with peep hole connections: - * - * Input arrays: - * 0: input with shape [time x batchSize x inSize], time - number of time steps, batchSize - batch size, inSize - number of features - * 1: input-to-hidden weights for forward RNN, [inSize x numUnitsFW] - * 2: hidden-to-hidden weights for forward RNN, [numUnitsFW x numUnitsFW] - * 3: biases for forward RNN, [2*numUnitsFW] - * 4: input-to-hidden weights for backward RNN, [inSize x numUnitsBW] - * 5: hidden-to-hidden weights for backward RNN, [numUnitsBW x numUnitsBW] - * 6: biases for backward RNN, [2*numUnitsBW] - * 7: (optional) initial cell output for forward RNN [batchSize x numUnitsFW], that is at time step = 0 - * 8: (optional) initial cell output for backward RNN [batchSize x numUnitsBW], that is at time step = 0 - * 9: (optional) vector with shape [batchSize] containing integer values within [0,time), each element of this vector set max time step per each input in batch, this provides no calculations for time >= maxTimeStep - * - * Output arrays: - * 0: cell outputs [time x batchSize x (numUnitsFW + numUnitsBW)] - * 1: cell final non-zero output for forward RNN [batchSize x numUnitsFW] - * 2: cell final non-zero output for backward RNN [batchSize x numUnitsBW] - */ - @Namespace("sd::ops") public static class static_bidirectional_rnn extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public static_bidirectional_rnn(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public static_bidirectional_rnn(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public static_bidirectional_rnn position(long position) { - return (static_bidirectional_rnn)super.position(position); - } - - public static_bidirectional_rnn() { super((Pointer)null); allocate(); } - private native void allocate(); +////////////////////////////////////////////////////////////////////////// +/** + * Implementation of operation "static RNN time sequences" with peep hole + * connections: + * + * Input arrays: + * 0: input with shape [time x batchSize x inSize], time - number of time + * steps, batchSize - batch size, inSize - number of features 1: input-to-hidden + * weights for forward RNN, [inSize x numUnitsFW] 2: hidden-to-hidden weights + * for forward RNN, [numUnitsFW x numUnitsFW] 3: biases for forward RNN, + * [2*numUnitsFW] 4: input-to-hidden weights for backward RNN, [inSize x + * numUnitsBW] 5: hidden-to-hidden weights for backward RNN, [numUnitsBW x + * numUnitsBW] 6: biases for backward RNN, [2*numUnitsBW] 7: (optional) initial + * cell output for forward RNN [batchSize x numUnitsFW], that is at time step = + * 0 8: (optional) initial cell output for backward RNN [batchSize x + * numUnitsBW], that is at time step = 0 9: (optional) vector with shape + * [batchSize] containing integer values within [0,time), each element of this + * vector set max time step per each input in batch, this provides no + * calculations for time >= maxTimeStep + * + * Output arrays: + * 0: cell outputs [time x batchSize x (numUnitsFW + numUnitsBW)] + * 1: cell final non-zero output for forward RNN [batchSize x numUnitsFW] + * 2: cell final non-zero output for backward RNN [batchSize x numUnitsBW] + */ +@Namespace("sd::ops") public static class static_bidirectional_rnn extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public static_bidirectional_rnn(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public static_bidirectional_rnn(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public static_bidirectional_rnn position(long position) { + return (static_bidirectional_rnn)super.position(position); + } + + public static_bidirectional_rnn() { super((Pointer)null); allocate(); } + private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of operation "static RNN time sequences" with peep hole connections: - * - * Input arrays: - * 0: input with shape [time x batchSize x inSize] or [batchSize x time x inSize], time - number of time steps, batchSize - batch size, inSize - number of features - * 1: input-to-hidden weights for forward RNN, [inSize x numUnitsFW] - * 2: hidden-to-hidden weights for forward RNN, [numUnitsFW x numUnitsFW] - * 3: biases for forward RNN, [2*numUnitsFW] - * 4: input-to-hidden weights for backward RNN, [inSize x numUnitsBW] - * 5: hidden-to-hidden weights for backward RNN, [numUnitsBW x numUnitsBW] - * 6: biases for backward RNN, [2*numUnitsBW] - * 7: (optional) initial cell output for forward RNN [batchSize x numUnitsFW], that is at time step = 0 - * 8: (optional) initial cell output for backward RNN [batchSize x numUnitsBW], that is at time step = 0 - * 9: (optional) vector with shape [batchSize] containing integer values within [0,time), each element of this vector set max time step per each input in batch, this provides no calculations for time >= maxTimeStep - * - * Input integer arguments: - * 0: (optional) timeMajor - if non zero then input shape is [time, batchSize, ...], else [batchSize, time, ...] - * - * Output arrays: - * 0: cell outputs for forward RNN [time x batchSize x numUnitsFW] or [batchSize x time x numUnitsFW] - * 1: cell outputs for backward RNN [time x batchSize x numUnitsBW] or [batchSize x time x numUnitsBW] - * 2: cell final non-zero output for forward RNN [batchSize x numUnitsFW] - * 3: cell final non-zero output for backward RNN [batchSize x numUnitsBW] - */ - @Namespace("sd::ops") public static class dynamic_bidirectional_rnn extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public dynamic_bidirectional_rnn(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public dynamic_bidirectional_rnn(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public dynamic_bidirectional_rnn position(long position) { - return (dynamic_bidirectional_rnn)super.position(position); - } - +////////////////////////////////////////////////////////////////////////// +/** + * Implementation of operation "static RNN time sequences" with peep hole + * connections: + * + * Input arrays: + * 0: input with shape [time x batchSize x inSize] or [batchSize x time x + * inSize], time - number of time steps, batchSize - batch size, inSize - number + * of features 1: input-to-hidden weights for forward RNN, [inSize x + * numUnitsFW] 2: hidden-to-hidden weights for forward RNN, [numUnitsFW x + * numUnitsFW] 3: biases for forward RNN, [2*numUnitsFW] 4: input-to-hidden + * weights for backward RNN, [inSize x numUnitsBW] 5: hidden-to-hidden weights + * for backward RNN, [numUnitsBW x numUnitsBW] 6: biases for backward RNN, + * [2*numUnitsBW] 7: (optional) initial cell output for forward RNN [batchSize x + * numUnitsFW], that is at time step = 0 8: (optional) initial cell output for + * backward RNN [batchSize x numUnitsBW], that is at time step = 0 9: (optional) + * vector with shape [batchSize] containing integer values within [0,time), each + * element of this vector set max time step per each input in batch, this + * provides no calculations for time >= maxTimeStep + * + * Input integer arguments: + * 0: (optional) timeMajor - if non zero then input shape is [time, + * batchSize, ...], else [batchSize, time, ...] + * + * Output arrays: + * 0: cell outputs for forward RNN [time x batchSize x numUnitsFW] or + * [batchSize x time x numUnitsFW] 1: cell outputs for backward RNN [time x + * batchSize x numUnitsBW] or [batchSize x time x numUnitsBW] 2: cell final + * non-zero output for forward RNN [batchSize x numUnitsFW] 3: cell final + * non-zero output for backward RNN [batchSize x numUnitsBW] + */ +@Namespace("sd::ops") public static class dynamic_bidirectional_rnn extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public dynamic_bidirectional_rnn(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public dynamic_bidirectional_rnn(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public dynamic_bidirectional_rnn position(long position) { + return (dynamic_bidirectional_rnn)super.position(position); + } + public dynamic_bidirectional_rnn() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - - + // namespace ops + // namespace sd // #endif // Parsed from ops/declarable/headers/transforms.h @@ -16749,839 +19105,840 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #define LIBND4J_HEADERS_TRANSFORMS_H // #include -// #if NOT_EXCLUDED(OP_clipbyvalue) - @Namespace("sd::ops") public static class clipbyvalue extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public clipbyvalue(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public clipbyvalue(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public clipbyvalue position(long position) { - return (clipbyvalue)super.position(position); - } - +// #if NOT_EXCLUDED(OP_clipbyvalue) +@Namespace("sd::ops") public static class clipbyvalue extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public clipbyvalue(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public clipbyvalue(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public clipbyvalue position(long position) { + return (clipbyvalue)super.position(position); + } + public clipbyvalue() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_clipbynorm) +@Namespace("sd::ops") public static class clipbynorm extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public clipbynorm(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public clipbynorm(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public clipbynorm position(long position) { + return (clipbynorm)super.position(position); + } -// #if NOT_EXCLUDED(OP_clipbynorm) - @Namespace("sd::ops") public static class clipbynorm extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public clipbynorm(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public clipbynorm(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public clipbynorm position(long position) { - return (clipbynorm)super.position(position); - } - public clipbynorm() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class clipbynorm_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public clipbynorm_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public clipbynorm_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public clipbynorm_bp position(long position) { - return (clipbynorm_bp)super.position(position); - } - +@Namespace("sd::ops") public static class clipbynorm_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public clipbynorm_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public clipbynorm_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public clipbynorm_bp position(long position) { + return (clipbynorm_bp)super.position(position); + } + public clipbynorm_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_clipbyavgnorm) +@Namespace("sd::ops") public static class clipbyavgnorm extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public clipbyavgnorm(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public clipbyavgnorm(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public clipbyavgnorm position(long position) { + return (clipbyavgnorm)super.position(position); + } -// #if NOT_EXCLUDED(OP_clipbyavgnorm) - @Namespace("sd::ops") public static class clipbyavgnorm extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public clipbyavgnorm(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public clipbyavgnorm(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public clipbyavgnorm position(long position) { - return (clipbyavgnorm)super.position(position); - } - public clipbyavgnorm() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class clipbyavgnorm_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public clipbyavgnorm_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public clipbyavgnorm_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public clipbyavgnorm_bp position(long position) { - return (clipbyavgnorm_bp)super.position(position); - } - +@Namespace("sd::ops") public static class clipbyavgnorm_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public clipbyavgnorm_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public clipbyavgnorm_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public clipbyavgnorm_bp position(long position) { + return (clipbyavgnorm_bp)super.position(position); + } + public clipbyavgnorm_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_cumsum) +@Namespace("sd::ops") public static class cumsum extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public cumsum(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public cumsum(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public cumsum position(long position) { + return (cumsum)super.position(position); + } -// #if NOT_EXCLUDED(OP_cumsum) - @Namespace("sd::ops") public static class cumsum extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public cumsum(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public cumsum(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public cumsum position(long position) { - return (cumsum)super.position(position); - } - public cumsum() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_cumprod) +@Namespace("sd::ops") public static class cumprod extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public cumprod(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public cumprod(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public cumprod position(long position) { + return (cumprod)super.position(position); + } -// #if NOT_EXCLUDED(OP_cumprod) - @Namespace("sd::ops") public static class cumprod extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public cumprod(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public cumprod(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public cumprod position(long position) { - return (cumprod)super.position(position); - } - public cumprod() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_tile) +@Namespace("sd::ops") public static class tile extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public tile(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public tile(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public tile position(long position) { + return (tile)super.position(position); + } -// #if NOT_EXCLUDED(OP_tile) - @Namespace("sd::ops") public static class tile extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public tile(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public tile(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public tile position(long position) { - return (tile)super.position(position); - } - public tile() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class tile_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public tile_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public tile_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public tile_bp position(long position) { - return (tile_bp)super.position(position); - } - +@Namespace("sd::ops") public static class tile_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public tile_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public tile_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public tile_bp position(long position) { + return (tile_bp)super.position(position); + } + public tile_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_repeat) +@Namespace("sd::ops") public static class repeat extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public repeat(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public repeat(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public repeat position(long position) { + return (repeat)super.position(position); + } -// #if NOT_EXCLUDED(OP_repeat) - @Namespace("sd::ops") public static class repeat extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public repeat(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public repeat(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public repeat position(long position) { - return (repeat)super.position(position); - } - public repeat() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_invert_permutation) +@Namespace("sd::ops") public static class invert_permutation extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public invert_permutation(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public invert_permutation(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public invert_permutation position(long position) { + return (invert_permutation)super.position(position); + } -// #if NOT_EXCLUDED(OP_invert_permutation) - @Namespace("sd::ops") public static class invert_permutation extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public invert_permutation(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public invert_permutation(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public invert_permutation position(long position) { - return (invert_permutation)super.position(position); - } - public invert_permutation() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +@Namespace("sd::ops") public static class concat extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public concat(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public concat(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public concat position(long position) { + return (concat)super.position(position); + } - @Namespace("sd::ops") public static class concat extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public concat(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public concat(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public concat position(long position) { - return (concat)super.position(position); - } - public concat() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class concat_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public concat_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public concat_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public concat_bp position(long position) { - return (concat_bp)super.position(position); - } - +@Namespace("sd::ops") public static class concat_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public concat_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public concat_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public concat_bp position(long position) { + return (concat_bp)super.position(position); + } + public concat_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #if NOT_EXCLUDED(OP_mergemax) - @Namespace("sd::ops") public static class mergemax extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public mergemax(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public mergemax(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public mergemax position(long position) { - return (mergemax)super.position(position); - } - +// #if NOT_EXCLUDED(OP_mergemax) +@Namespace("sd::ops") public static class mergemax extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public mergemax(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public mergemax(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public mergemax position(long position) { + return (mergemax)super.position(position); + } + public mergemax() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class mergemax_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public mergemax_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public mergemax_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public mergemax_bp position(long position) { - return (mergemax_bp)super.position(position); - } - +@Namespace("sd::ops") public static class mergemax_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public mergemax_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public mergemax_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public mergemax_bp position(long position) { + return (mergemax_bp)super.position(position); + } + public mergemax_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif - /* - * Complete tensor with max indices merged from all input tensors list - * - * INPUT: tensors with the same shape - * OUTPUT: integer tensor with the same shape - * INT_ARG: result type (one of int), INT32 by default - */ -// #if NOT_EXCLUDED(OP_mergemaxindex) - @Namespace("sd::ops") public static class mergemaxindex extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public mergemaxindex(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public mergemaxindex(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public mergemaxindex position(long position) { - return (mergemaxindex)super.position(position); - } - +// #endif +/* + * Complete tensor with max indices merged from all input tensors list + * + * INPUT: tensors with the same shape + * OUTPUT: integer tensor with the same shape + * INT_ARG: result type (one of int), INT32 by default + */ +// #if NOT_EXCLUDED(OP_mergemaxindex) +@Namespace("sd::ops") public static class mergemaxindex extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public mergemaxindex(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public mergemaxindex(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public mergemaxindex position(long position) { + return (mergemaxindex)super.position(position); + } + public mergemaxindex() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_mergeadd) +@Namespace("sd::ops") public static class mergeadd extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public mergeadd(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public mergeadd(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public mergeadd position(long position) { + return (mergeadd)super.position(position); + } + + public mergeadd() { super((Pointer)null); allocate(); } + private native void allocate(); + public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); + } +@Namespace("sd::ops") public static class mergeadd_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public mergeadd_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public mergeadd_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public mergeadd_bp position(long position) { + return (mergeadd_bp)super.position(position); + } -// #if NOT_EXCLUDED(OP_mergeadd) - @Namespace("sd::ops") public static class mergeadd extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public mergeadd(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public mergeadd(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public mergeadd position(long position) { - return (mergeadd)super.position(position); - } - - public mergeadd() { super((Pointer)null); allocate(); } - private native void allocate(); - public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); - } - @Namespace("sd::ops") public static class mergeadd_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public mergeadd_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public mergeadd_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public mergeadd_bp position(long position) { - return (mergeadd_bp)super.position(position); - } - public mergeadd_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_mergeavg) +@Namespace("sd::ops") public static class mergeavg extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public mergeavg(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public mergeavg(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public mergeavg position(long position) { + return (mergeavg)super.position(position); + } -// #if NOT_EXCLUDED(OP_mergeavg) - @Namespace("sd::ops") public static class mergeavg extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public mergeavg(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public mergeavg(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public mergeavg position(long position) { - return (mergeavg)super.position(position); - } - public mergeavg() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class mergeavg_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public mergeavg_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public mergeavg_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public mergeavg_bp position(long position) { - return (mergeavg_bp)super.position(position); - } - +@Namespace("sd::ops") public static class mergeavg_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public mergeavg_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public mergeavg_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public mergeavg_bp position(long position) { + return (mergeavg_bp)super.position(position); + } + public mergeavg_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_scatter_update) +@Namespace("sd::ops") public static class scatter_update extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public scatter_update(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public scatter_update(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public scatter_update position(long position) { + return (scatter_update)super.position(position); + } -// #if NOT_EXCLUDED(OP_scatter_update) - @Namespace("sd::ops") public static class scatter_update extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public scatter_update(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public scatter_update(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public scatter_update position(long position) { - return (scatter_update)super.position(position); - } - public scatter_update() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_Floor) +@Namespace("sd::ops") public static class Floor extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public Floor(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public Floor(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public Floor position(long position) { + return (Floor)super.position(position); + } -// #if NOT_EXCLUDED(OP_Floor) - @Namespace("sd::ops") public static class Floor extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public Floor(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public Floor(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public Floor position(long position) { - return (Floor)super.position(position); - } - public Floor() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_Log1p) +@Namespace("sd::ops") public static class Log1p extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public Log1p(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public Log1p(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public Log1p position(long position) { + return (Log1p)super.position(position); + } -// #if NOT_EXCLUDED(OP_Log1p) - @Namespace("sd::ops") public static class Log1p extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public Log1p(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public Log1p(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public Log1p position(long position) { - return (Log1p)super.position(position); - } - public Log1p() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_reverse) +@Namespace("sd::ops") public static class reverse extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public reverse(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public reverse(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public reverse position(long position) { + return (reverse)super.position(position); + } -// #if NOT_EXCLUDED(OP_reverse) - @Namespace("sd::ops") public static class reverse extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public reverse(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public reverse(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public reverse position(long position) { - return (reverse)super.position(position); - } - public reverse() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class reverse_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public reverse_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public reverse_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public reverse_bp position(long position) { - return (reverse_bp)super.position(position); - } - +@Namespace("sd::ops") public static class reverse_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public reverse_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public reverse_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public reverse_bp position(long position) { + return (reverse_bp)super.position(position); + } + public reverse_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_gather) +@Namespace("sd::ops") public static class gather extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public gather(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public gather(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public gather position(long position) { + return (gather)super.position(position); + } -// #if NOT_EXCLUDED(OP_gather) - @Namespace("sd::ops") public static class gather extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public gather(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public gather(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public gather position(long position) { - return (gather)super.position(position); - } - public gather() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_pad) +@Namespace("sd::ops") public static class pad extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public pad(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public pad(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public pad position(long position) { + return (pad)super.position(position); + } -// #if NOT_EXCLUDED(OP_pad) - @Namespace("sd::ops") public static class pad extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public pad(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public pad(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public pad position(long position) { - return (pad)super.position(position); - } - public pad() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * creates identity 2D matrix or batch of identical 2D identity matrices + * + * Input array: + * provide some array - in any case operation simply neglects it + * + * Input float argument (if passed): + * TArgs[0] - type of elements of output array, default value is 5 (float) + * + * Input integer arguments: + * IArgs[0] - order of output identity matrix, 99 -> 'c'-order, 102 -> + * 'f'-order IArgs[1] - the number of rows in output inner-most 2D + * identity matrix IArgs[2] - optional, the number of columns in output + * inner-most 2D identity matrix, if this argument is not provided then it is + * taken to be equal to number of rows IArgs[3,4,...] - optional, shape of + * batch, output matrix will have leading batch dimensions of this shape + */ +// #if NOT_EXCLUDED(OP_eye) +@Namespace("sd::ops") public static class eye extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public eye(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public eye(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public eye position(long position) { + return (eye)super.position(position); + } - /** - * creates identity 2D matrix or batch of identical 2D identity matrices - * - * Input array: - * provide some array - in any case operation simply neglects it - * - * Input float argument (if passed): - * TArgs[0] - type of elements of output array, default value is 5 (float) - * - * Input integer arguments: - * IArgs[0] - order of output identity matrix, 99 -> 'c'-order, 102 -> 'f'-order - * IArgs[1] - the number of rows in output inner-most 2D identity matrix - * IArgs[2] - optional, the number of columns in output inner-most 2D identity matrix, if this argument is not provided then it is taken to be equal to number of rows - * IArgs[3,4,...] - optional, shape of batch, output matrix will have leading batch dimensions of this shape - */ -// #if NOT_EXCLUDED(OP_eye) - @Namespace("sd::ops") public static class eye extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public eye(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public eye(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public eye position(long position) { - return (eye)super.position(position); - } - public eye() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_gather_nd) +@Namespace("sd::ops") public static class gather_nd extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public gather_nd(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public gather_nd(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public gather_nd position(long position) { + return (gather_nd)super.position(position); + } -// #if NOT_EXCLUDED(OP_gather_nd) - @Namespace("sd::ops") public static class gather_nd extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public gather_nd(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public gather_nd(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public gather_nd position(long position) { - return (gather_nd)super.position(position); - } - public gather_nd() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_reverse_sequence) +@Namespace("sd::ops") public static class reverse_sequence extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public reverse_sequence(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public reverse_sequence(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public reverse_sequence position(long position) { + return (reverse_sequence)super.position(position); + } -// #if NOT_EXCLUDED(OP_reverse_sequence) - @Namespace("sd::ops") public static class reverse_sequence extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public reverse_sequence(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public reverse_sequence(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public reverse_sequence position(long position) { - return (reverse_sequence)super.position(position); - } - public reverse_sequence() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_trace) +@Namespace("sd::ops") public static class trace extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public trace(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public trace(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public trace position(long position) { + return (trace)super.position(position); + } -// #if NOT_EXCLUDED(OP_trace) - @Namespace("sd::ops") public static class trace extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public trace(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public trace(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public trace position(long position) { - return (trace)super.position(position); - } - public trace() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_random_shuffle) +@Namespace("sd::ops") public static class random_shuffle extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public random_shuffle(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public random_shuffle(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public random_shuffle position(long position) { + return (random_shuffle)super.position(position); + } -// #if NOT_EXCLUDED(OP_random_shuffle) - @Namespace("sd::ops") public static class random_shuffle extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public random_shuffle(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public random_shuffle(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public random_shuffle position(long position) { - return (random_shuffle)super.position(position); - } - public random_shuffle() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * clip a list of given tensors with given average norm when needed + * + * Input: + * a list of tensors (at least one) + * + * Input floating point argument: + * clip_norm - a value that used as threshold value and norm to be used + * + * return a list of clipped tensors + * and global_norm as scalar tensor at the end + */ +// #if NOT_EXCLUDED(OP_clip_by_global_norm) +@Namespace("sd::ops") public static class clip_by_global_norm extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public clip_by_global_norm(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public clip_by_global_norm(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public clip_by_global_norm position(long position) { + return (clip_by_global_norm)super.position(position); + } - /** - * clip a list of given tensors with given average norm when needed - * - * Input: - * a list of tensors (at least one) - * - * Input floating point argument: - * clip_norm - a value that used as threshold value and norm to be used - * - * return a list of clipped tensors - * and global_norm as scalar tensor at the end - */ -// #if NOT_EXCLUDED(OP_clip_by_global_norm) - @Namespace("sd::ops") public static class clip_by_global_norm extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public clip_by_global_norm(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public clip_by_global_norm(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public clip_by_global_norm position(long position) { - return (clip_by_global_norm)super.position(position); - } - public clip_by_global_norm() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +@Namespace("sd::ops") public static class tri extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public tri(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public tri(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public tri position(long position) { + return (tri)super.position(position); + } - @Namespace("sd::ops") public static class tri extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public tri(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public tri(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public tri position(long position) { - return (tri)super.position(position); - } - public tri() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class triu extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public triu(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public triu(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public triu position(long position) { - return (triu)super.position(position); - } - +@Namespace("sd::ops") public static class triu extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public triu(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public triu(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public triu position(long position) { + return (triu)super.position(position); + } + public triu() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class triu_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public triu_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public triu_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public triu_bp position(long position) { - return (triu_bp)super.position(position); - } - +@Namespace("sd::ops") public static class triu_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public triu_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public triu_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public triu_bp position(long position) { + return (triu_bp)super.position(position); + } + public triu_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #if NOT_EXCLUDED(OP_mirror_pad) - @Namespace("sd::ops") public static class mirror_pad extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public mirror_pad(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public mirror_pad(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public mirror_pad position(long position) { - return (mirror_pad)super.position(position); - } - +// #if NOT_EXCLUDED(OP_mirror_pad) +@Namespace("sd::ops") public static class mirror_pad extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public mirror_pad(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public mirror_pad(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public mirror_pad position(long position) { + return (mirror_pad)super.position(position); + } + public mirror_pad() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_cumsum) +@Namespace("sd::ops") public static class cumsum_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public cumsum_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public cumsum_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public cumsum_bp position(long position) { + return (cumsum_bp)super.position(position); + } -// #if NOT_EXCLUDED(OP_cumsum) - @Namespace("sd::ops") public static class cumsum_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public cumsum_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public cumsum_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public cumsum_bp position(long position) { - return (cumsum_bp)super.position(position); - } - public cumsum_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_cumprod) +@Namespace("sd::ops") public static class cumprod_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public cumprod_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public cumprod_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public cumprod_bp position(long position) { + return (cumprod_bp)super.position(position); + } -// #if NOT_EXCLUDED(OP_cumprod) - @Namespace("sd::ops") public static class cumprod_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public cumprod_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public cumprod_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public cumprod_bp position(long position) { - return (cumprod_bp)super.position(position); - } - public cumprod_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif +// #if NOT_EXCLUDED(OP_flatten) +@Namespace("sd::ops") public static class flatten extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public flatten(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public flatten(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public flatten position(long position) { + return (flatten)super.position(position); + } -// #if NOT_EXCLUDED(OP_flatten) - @Namespace("sd::ops") public static class flatten extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public flatten(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public flatten(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public flatten position(long position) { - return (flatten)super.position(position); - } - public flatten() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * returns histogram (as 1D array) with fixed bins width + * + * Input arrays: + * - input array with elements to be binned into output histogram + * - range array with first element being bottom limit and second element being + top limit of histogram, please note that input_value <= range[0] will be mapped + to histogram[0], input_value >= range[1] will be mapped to histogram[-1] + * + * Input integer arguments: + * nbins (optional) - number of histogram bins, default value is 100 + */ +// #if NOT_EXCLUDED(OP_histogram_fixed_width) +@Namespace("sd::ops") public static class histogram_fixed_width extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public histogram_fixed_width(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public histogram_fixed_width(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public histogram_fixed_width position(long position) { + return (histogram_fixed_width)super.position(position); + } - /** - * returns histogram (as 1D array) with fixed bins width - * - * Input arrays: - * - input array with elements to be binned into output histogram - * - range array with first element being bottom limit and second element being top limit of histogram, - please note that input_value <= range[0] will be mapped to histogram[0], input_value >= range[1] will be mapped to histogram[-1] - * - * Input integer arguments: - * nbins (optional) - number of histogram bins, default value is 100 - */ -// #if NOT_EXCLUDED(OP_histogram_fixed_width) - @Namespace("sd::ops") public static class histogram_fixed_width extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public histogram_fixed_width(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public histogram_fixed_width(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public histogram_fixed_width position(long position) { - return (histogram_fixed_width)super.position(position); - } - public histogram_fixed_width() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif +/** + * standardizes input array to be zero mean unit variance along the given axis + * + * + */ +// #if NOT_EXCLUDED(OP_standardize) +@Namespace("sd::ops") public static class standardize extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public standardize(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public standardize(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public standardize position(long position) { + return (standardize)super.position(position); + } - /** - * standardizes input array to be zero mean unit variance along the given axis - * - * - */ -// #if NOT_EXCLUDED(OP_standardize) - @Namespace("sd::ops") public static class standardize extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public standardize(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public standardize(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public standardize position(long position) { - return (standardize)super.position(position); - } - public standardize() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class standardize_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public standardize_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public standardize_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public standardize_bp position(long position) { - return (standardize_bp)super.position(position); - } - +@Namespace("sd::ops") public static class standardize_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public standardize_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public standardize_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public standardize_bp position(long position) { + return (standardize_bp)super.position(position); + } + public standardize_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This operation calculates hash code, optionally along dimension + */ +// #if NOT_EXCLUDED(OP_hashcode) +@Namespace("sd::ops") public static class hashcode extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public hashcode(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public hashcode(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public hashcode position(long position) { + return (hashcode)super.position(position); + } - /** - * This operation calculates hash code, optionally along dimension - */ -// #if NOT_EXCLUDED(OP_hashcode) - @Namespace("sd::ops") public static class hashcode extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public hashcode(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public hashcode(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public hashcode position(long position) { - return (hashcode)super.position(position); - } - public hashcode() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This operation calculates number of entries per bin + */ +// #if NOT_EXCLUDED(OP_histogram) +@Namespace("sd::ops") public static class histogram extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public histogram(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public histogram(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public histogram position(long position) { + return (histogram)super.position(position); + } - /** - * This operation calculates number of entries per bin - */ -// #if NOT_EXCLUDED(OP_histogram) - @Namespace("sd::ops") public static class histogram extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public histogram(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public histogram(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public histogram position(long position) { - return (histogram)super.position(position); - } - public histogram() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif - - +// #endif + // namespace ops + // namespace sd // #endif @@ -17612,61 +19969,59 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #define LIBND4J_HEADERS_PARITY_H // #include - /** - * This operation returns index of max element in a given NDArray (optionally: along given dimension(s)) - * Expected input: - * 0: N-dimensional array - * 1: optional axis vector - * - * Int args: - * 0: optional axis - */ -// #if NOT_EXCLUDED(OP_argmax) - @Namespace("sd::ops") public static class argmax extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public argmax(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public argmax(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public argmax position(long position) { - return (argmax)super.position(position); - } - +/** + * This operation returns index of max element in a given NDArray (optionally: + * along given dimension(s)) Expected input: 0: N-dimensional array 1: optional + * axis vector + * + * Int args: + * 0: optional axis + */ +// #if NOT_EXCLUDED(OP_argmax) +@Namespace("sd::ops") public static class argmax extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public argmax(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public argmax(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public argmax position(long position) { + return (argmax)super.position(position); + } + public argmax() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This operation returns index of min element in a given NDArray (optionally: + * along given dimension(s)) Expected input: 0: N-dimensional array 1: optional + * axis vector + * + * Int args: + * 0: optional axis + */ +// #if NOT_EXCLUDED(OP_argmin) +@Namespace("sd::ops") public static class argmin extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public argmin(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public argmin(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public argmin position(long position) { + return (argmin)super.position(position); + } - /** - * This operation returns index of min element in a given NDArray (optionally: along given dimension(s)) - * Expected input: - * 0: N-dimensional array - * 1: optional axis vector - * - * Int args: - * 0: optional axis - */ -// #if NOT_EXCLUDED(OP_argmin) - @Namespace("sd::ops") public static class argmin extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public argmin(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public argmin(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public argmin position(long position) { - return (argmin)super.position(position); - } - public argmin() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif - /** +/** * This operation returns index of absolute max element in a given NDArray (optionally: along given dimension(s)) * Expected input: * 0: N-dimensional array @@ -17721,3654 +20076,3720 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #endif /** - * This operation provides various normalization modes: - * 0: frobenius - * 1: euclidean (norm2) - * 2: norm1 - * 3: norm2 - * 4: inf-norm - * 5: p-norm - * - * Expected arguments: - * input: N-dimensional array - * - * - * Int args: - * 0...: axis - * - * T args: - * 0: norm mode - * 1: p for p-norm - */ -// #if NOT_EXCLUDED(OP_norm) - @Namespace("sd::ops") public static class norm extends DeclarableReductionOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public norm(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public norm(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public norm position(long position) { - return (norm)super.position(position); - } - + * This operation provides various normalization modes: + * 0: frobenius + * 1: euclidean (norm2) + * 2: norm1 + * 3: norm2 + * 4: inf-norm + * 5: p-norm + * + * Expected arguments: + * input: N-dimensional array + * + * + * Int args: + * 0...: axis + * + * T args: + * 0: norm mode + * 1: p for p-norm + */ +// #if NOT_EXCLUDED(OP_norm) +@Namespace("sd::ops") public static class norm extends DeclarableReductionOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public norm(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public norm(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public norm position(long position) { + return (norm)super.position(position); + } + public norm() { super((Pointer)null); allocate(); } private native void allocate(); } -// #endif +// #endif + +/** + * Inserts elements provided by diagonal array into the main diagonal of + * innermost matrices of input array + * + * Input arrays: + * 0: input array, considered as batch of matrices + * 1: diagonal array containing elements to be inserted into input array, + * following rank condition should be satisfied: diagonal_rank = input_rank + * - 1, the shapes of diagonal and input arrays must be equal except last + * dimension of input array, for example if input_shape = [A,B,C,D] then + * diagonal_shape = [A,B,C], also last dimension of diagonal array should be + * equal to smaller of last and last but one input dimensions that is: + * diagonal_shape[-1] = min(input_shape[-1], input_shape[-2]) + * + * Output array: + * 0: has the same shape as input, corresponding diagonal elements are + * substituted + */ +// #if NOT_EXCLUDED(OP_matrix_set_diag) +@Namespace("sd::ops") public static class matrix_set_diag extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public matrix_set_diag(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public matrix_set_diag(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public matrix_set_diag position(long position) { + return (matrix_set_diag)super.position(position); + } - /** - * Inserts elements provided by diagonal array into the main diagonal of innermost matrices of input array - * - * Input arrays: - * 0: input array, considered as batch of matrices - * 1: diagonal array containing elements to be inserted into input array, - * following rank condition should be satisfied: diagonal_rank = input_rank - 1, - * the shapes of diagonal and input arrays must be equal except last dimension of input array, - * for example if input_shape = [A,B,C,D] then diagonal_shape = [A,B,C], - * also last dimension of diagonal array should be equal to smaller of last and last but one input dimensions - * that is: diagonal_shape[-1] = min(input_shape[-1], input_shape[-2]) - * - * Output array: - * 0: has the same shape as input, corresponding diagonal elements are substituted - */ -// #if NOT_EXCLUDED(OP_matrix_set_diag) - @Namespace("sd::ops") public static class matrix_set_diag extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public matrix_set_diag(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public matrix_set_diag(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public matrix_set_diag position(long position) { - return (matrix_set_diag)super.position(position); - } - public matrix_set_diag() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * Inserts elements provided by diagonal array into the main diagonal of + * innermost matrices of output array, rest output elements are set to zeros + * + * Input array: + * diagonal: array containing elements to be inserted into output array, + * following rank condition is present: diagonal_rank = ouput_rank + * - 1 + * + * Output array: + * 0: is considered as batch of matrices, if for example diagonal array has + * shape [A,B,C] then output array has shape [A,B,C,C] + */ +@Namespace("sd::ops") public static class matrix_diag extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public matrix_diag(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public matrix_diag(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public matrix_diag position(long position) { + return (matrix_diag)super.position(position); + } - /** - * Inserts elements provided by diagonal array into the main diagonal of innermost matrices of output array, - * rest output elements are set to zeros - * - * Input array: - * diagonal: array containing elements to be inserted into output array, - * following rank condition is present: diagonal_rank = ouput_rank - 1 - * - * Output array: - * 0: is considered as batch of matrices, if for example diagonal array has shape [A,B,C] then output array has shape [A,B,C,C] - */ - @Namespace("sd::ops") public static class matrix_diag extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public matrix_diag(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public matrix_diag(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public matrix_diag position(long position) { - return (matrix_diag)super.position(position); - } - public matrix_diag() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - /** - * This op calculates regularized incomplete beta integral Ix(a, b). - * Implementation is based on two algorithms depending on input values of a and b: - * - when a and b are both > maxValue (3000.), then Gauss-Legendre quadrature method is applied - * - when a and b are both <= maxValue (3000.), then modified Lentz’s algorithm for continued fractions is applied - * - * Input arrays: - * a: defines power t^{a-1}, must be > 0, type float. - * b: defines power (1-t)^{b-1}, must be > 0, type float. - * x: defines upper limit of integration, must be within (0 <= x <= 1) range, type float. - * - * Output array: - * 0: values of regularized incomplete beta integral that corresponds to variable upper limit x, type float - * - * Three input and one output arrays must have the same shape - */ -// #if NOT_EXCLUDED(OP_betainc) - @Namespace("sd::ops") public static class betainc extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public betainc(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public betainc(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public betainc position(long position) { - return (betainc)super.position(position); - } - +/** + * This op calculates regularized incomplete beta integral Ix(a, b). + * Implementation is based on two algorithms depending on input values of a and + * b: + * - when a and b are both > maxValue (3000.), then Gauss-Legendre quadrature + * method is applied + * - when a and b are both <= maxValue (3000.), then modified Lentz’s algorithm + * for continued fractions is applied + * + * Input arrays: + * a: defines power t^{a-1}, must be > 0, type float. + * b: defines power (1-t)^{b-1}, must be > 0, type float. + * x: defines upper limit of integration, must be within (0 <= x <= 1) range, + * type float. + * + * Output array: + * 0: values of regularized incomplete beta integral that corresponds to + * variable upper limit x, type float + * + * Three input and one output arrays must have the same shape + */ +// #if NOT_EXCLUDED(OP_betainc) +@Namespace("sd::ops") public static class betainc extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public betainc(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public betainc(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public betainc position(long position) { + return (betainc)super.position(position); + } + public betainc() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This operation is added for compatibility purposes mostly. + * PLEASE NOTE: Please consider using Add instead + * Expected arguments: + * 0: N-dimensional input + * 1: bias vector + */ +// #if NOT_EXCLUDED(OP_biasadd) +@Namespace("sd::ops") public static class biasadd extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public biasadd(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public biasadd(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public biasadd position(long position) { + return (biasadd)super.position(position); + } - /** - * This operation is added for compatibility purposes mostly. - * PLEASE NOTE: Please consider using Add instead - * Expected arguments: - * 0: N-dimensional input - * 1: bias vector - */ -// #if NOT_EXCLUDED(OP_biasadd) - @Namespace("sd::ops") public static class biasadd extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public biasadd(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public biasadd(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public biasadd position(long position) { - return (biasadd)super.position(position); - } - public biasadd() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class biasadd_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public biasadd_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public biasadd_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public biasadd_bp position(long position) { - return (biasadd_bp)super.position(position); - } - +@Namespace("sd::ops") public static class biasadd_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public biasadd_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public biasadd_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public biasadd_bp position(long position) { + return (biasadd_bp)super.position(position); + } + public biasadd_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * Returns a diagonal tensor with a given diagonal values. Given a diagonal, + * this operation returns a tensor with the diagonal and everything else padded + * with zeros. + */ +// #if NOT_EXCLUDED(OP_diag) +@Namespace("sd::ops") public static class diag extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public diag(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public diag(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public diag position(long position) { + return (diag)super.position(position); + } - /** - * Returns a diagonal tensor with a given diagonal values. Given a diagonal, this operation returns a tensor with the diagonal and everything else padded with zeros. - */ -// #if NOT_EXCLUDED(OP_diag) - @Namespace("sd::ops") public static class diag extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public diag(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public diag(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public diag position(long position) { - return (diag)super.position(position); - } - public diag() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * Returns a diagonal tensor with a given diagonal values. Given a diagonal, + * this operation returns a tensor with the diagonal and everything else padded + * with zeros. + */ +// #if NOT_EXCLUDED(OP_diag_part) +@Namespace("sd::ops") public static class diag_part extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public diag_part(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public diag_part(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public diag_part position(long position) { + return (diag_part)super.position(position); + } - /** - * Returns a diagonal tensor with a given diagonal values. Given a diagonal, this operation returns a tensor with the diagonal and everything else padded with zeros. - */ -// #if NOT_EXCLUDED(OP_diag_part) - @Namespace("sd::ops") public static class diag_part extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public diag_part(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public diag_part(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public diag_part position(long position) { - return (diag_part)super.position(position); - } - public diag_part() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * Returns a diagonal vector for any submatricies with in a given tensor. + * It is an op inverse to matrix_set_giag. + * Using input tensor as batched 2D diagonals flat them to vector (1D) with + * diagonal values. + * + * Input : batched tensor with rank >=2 + * Output: tensor with rank lesser by 1 from input + */ +// #if NOT_EXCLUDED(OP_matrix_diag_part) +@Namespace("sd::ops") public static class matrix_diag_part extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public matrix_diag_part(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public matrix_diag_part(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public matrix_diag_part position(long position) { + return (matrix_diag_part)super.position(position); + } - /** - * Returns a diagonal vector for any submatricies with in a given tensor. - * It is an op inverse to matrix_set_giag. - * Using input tensor as batched 2D diagonals flat them to vector (1D) with diagonal values. - * - * Input : batched tensor with rank >=2 - * Output: tensor with rank lesser by 1 from input - */ -// #if NOT_EXCLUDED(OP_matrix_diag_part) - @Namespace("sd::ops") public static class matrix_diag_part extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public matrix_diag_part(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public matrix_diag_part(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public matrix_diag_part position(long position) { - return (matrix_diag_part)super.position(position); - } - public matrix_diag_part() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * QR decomposition: A = QR, where Q is ortogonal (Q * QT = I) and R is upper + * triangular. For A (MxN) Q is M x M and R is (NxN). + * + * Input : + * 0 - float (or complex float) tensor with shape {.,..,...,M,N} - batch of + * float matricies + * + * Output: + * 0 - float tensor with shape {.,..,...,MxN} - batch of ortogonal matricies + * {Qs} 1 - float tensor with shape {.,..,...,NxN} - batch of upper triangular + * matricies {Rs} + */ +// #if NOT_EXCLUDED(OP_qr) +@Namespace("sd::ops") public static class qr extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public qr(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public qr(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public qr position(long position) { + return (qr)super.position(position); + } - /** - * QR decomposition: A = QR, where Q is ortogonal (Q * QT = I) and R is upper triangular. - * For A (MxN) Q is M x M and R is (NxN). - * - * Input : - * 0 - float (or complex float) tensor with shape {.,..,...,M,N} - batch of float matricies - * - * Output: - * 0 - float tensor with shape {.,..,...,MxN} - batch of ortogonal matricies {Qs} - * 1 - float tensor with shape {.,..,...,NxN} - batch of upper triangular matricies {Rs} - */ -// #if NOT_EXCLUDED(OP_qr) - @Namespace("sd::ops") public static class qr extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public qr(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public qr(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public qr position(long position) { - return (qr)super.position(position); - } - public qr() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This operation takes 2 arrays: original values, and values to be excluded. + * And returns 2 arrays: values left after exclusion, and indices in original + * array for surivals. Expected arguments: 0: vector with original values 1: + * vector with values to exclude + */ +// #if NOT_EXCLUDED(OP_listdiff) +@Namespace("sd::ops") public static class listdiff extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public listdiff(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public listdiff(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public listdiff position(long position) { + return (listdiff)super.position(position); + } - /** - * This operation takes 2 arrays: original values, and values to be excluded. And returns 2 arrays: values left after exclusion, and indices in original array for surivals. - * Expected arguments: - * 0: vector with original values - * 1: vector with values to exclude - */ -// #if NOT_EXCLUDED(OP_listdiff) - @Namespace("sd::ops") public static class listdiff extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public listdiff(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public listdiff(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public listdiff position(long position) { - return (listdiff)super.position(position); - } - public listdiff() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This operation applies Add operation to specific inputs wrt indices + * Expected arguments: + * input: array to be updated + * indices: array containing indexes for first dimension of input + * updates: array containing elements to be interfered with input + */ +// #if NOT_EXCLUDED(OP_scatter_add) +@Namespace("sd::ops") public static class scatter_add extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public scatter_add(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public scatter_add(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public scatter_add position(long position) { + return (scatter_add)super.position(position); + } - /** - * This operation applies Add operation to specific inputs wrt indices - * Expected arguments: - * input: array to be updated - * indices: array containing indexes for first dimension of input - * updates: array containing elements to be interfered with input - */ -// #if NOT_EXCLUDED(OP_scatter_add) - @Namespace("sd::ops") public static class scatter_add extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public scatter_add(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public scatter_add(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public scatter_add position(long position) { - return (scatter_add)super.position(position); - } - public scatter_add() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This operation applies Subtract operation to specific inputs wrt indices + * Expected arguments: + * input: array to be updated + * indices: array containing indexes for first dimension of input + * updates: array containing elements to be interfered with input + */ +// #if NOT_EXCLUDED(OP_scatter_sub) +@Namespace("sd::ops") public static class scatter_sub extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public scatter_sub(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public scatter_sub(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public scatter_sub position(long position) { + return (scatter_sub)super.position(position); + } - /** - * This operation applies Subtract operation to specific inputs wrt indices - * Expected arguments: - * input: array to be updated - * indices: array containing indexes for first dimension of input - * updates: array containing elements to be interfered with input - */ -// #if NOT_EXCLUDED(OP_scatter_sub) - @Namespace("sd::ops") public static class scatter_sub extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public scatter_sub(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public scatter_sub(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public scatter_sub position(long position) { - return (scatter_sub)super.position(position); - } - public scatter_sub() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This operation applies Multiply operation to specific inputs wrt indices + * Expected arguments: + * input: array to be updated + * indices: array containing indexes for first dimension of input + * updates: array containing elements to be interfered with input + */ +// #if NOT_EXCLUDED(OP_scatter_mul) +@Namespace("sd::ops") public static class scatter_mul extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public scatter_mul(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public scatter_mul(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public scatter_mul position(long position) { + return (scatter_mul)super.position(position); + } - /** - * This operation applies Multiply operation to specific inputs wrt indices - * Expected arguments: - * input: array to be updated - * indices: array containing indexes for first dimension of input - * updates: array containing elements to be interfered with input - */ -// #if NOT_EXCLUDED(OP_scatter_mul) - @Namespace("sd::ops") public static class scatter_mul extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public scatter_mul(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public scatter_mul(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public scatter_mul position(long position) { - return (scatter_mul)super.position(position); - } - public scatter_mul() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This operation applies Divide operation to specific inputs wrt indices + * Expected arguments: + * input: array to be updated + * indices: array containing indexes for first dimension of input + * updates: array containing elements to be interfered with input + */ +// #if NOT_EXCLUDED(OP_scatter_div) +@Namespace("sd::ops") public static class scatter_div extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public scatter_div(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public scatter_div(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public scatter_div position(long position) { + return (scatter_div)super.position(position); + } - /** - * This operation applies Divide operation to specific inputs wrt indices - * Expected arguments: - * input: array to be updated - * indices: array containing indexes for first dimension of input - * updates: array containing elements to be interfered with input - */ -// #if NOT_EXCLUDED(OP_scatter_div) - @Namespace("sd::ops") public static class scatter_div extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public scatter_div(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public scatter_div(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public scatter_div position(long position) { - return (scatter_div)super.position(position); - } - public scatter_div() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This operation applies Assign operation to specific inputs wrt indices + * Expected arguments: + * input: array to be updated + * indices: array containing indexes for first dimension of input + * updates: array containing elements to be interfered with input + */ +// #if NOT_EXCLUDED(OP_scatter_upd) +@Namespace("sd::ops") public static class scatter_upd extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public scatter_upd(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public scatter_upd(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public scatter_upd position(long position) { + return (scatter_upd)super.position(position); + } - /** - * This operation applies Assign operation to specific inputs wrt indices - * Expected arguments: - * input: array to be updated - * indices: array containing indexes for first dimension of input - * updates: array containing elements to be interfered with input - */ -// #if NOT_EXCLUDED(OP_scatter_upd) - @Namespace("sd::ops") public static class scatter_upd extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public scatter_upd(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public scatter_upd(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public scatter_upd position(long position) { - return (scatter_upd)super.position(position); - } - public scatter_upd() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This operation applies Max operation to specific inputs through given indices + * Expected arguments: + * input: array to be updated + * indices: array containing indexes for first dimension of input + * updates: array containing elements to be interfered with input + */ +// #if NOT_EXCLUDED(OP_scatter_max) +@Namespace("sd::ops") public static class scatter_max extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public scatter_max(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public scatter_max(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public scatter_max position(long position) { + return (scatter_max)super.position(position); + } - /** - * This operation applies Max operation to specific inputs through given indices - * Expected arguments: - * input: array to be updated - * indices: array containing indexes for first dimension of input - * updates: array containing elements to be interfered with input - */ -// #if NOT_EXCLUDED(OP_scatter_max) - @Namespace("sd::ops") public static class scatter_max extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public scatter_max(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public scatter_max(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public scatter_max position(long position) { - return (scatter_max)super.position(position); - } - public scatter_max() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This operation applies Min operation to specific inputs through given indices + * Expected arguments: + * input: array to be updated + * indices: array containing indexes for first dimension of input + * updates: array containing elements to be interfered with input + */ +// #if NOT_EXCLUDED(OP_scatter_min) +@Namespace("sd::ops") public static class scatter_min extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public scatter_min(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public scatter_min(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public scatter_min position(long position) { + return (scatter_min)super.position(position); + } - /** - * This operation applies Min operation to specific inputs through given indices - * Expected arguments: - * input: array to be updated - * indices: array containing indexes for first dimension of input - * updates: array containing elements to be interfered with input - */ -// #if NOT_EXCLUDED(OP_scatter_min) - @Namespace("sd::ops") public static class scatter_min extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public scatter_min(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public scatter_min(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public scatter_min position(long position) { - return (scatter_min)super.position(position); - } - public scatter_min() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This operation scatter "updates" elements into new output array according to + * given "indices" Expected arguments: indices: array containing elements/slices + * indexes of output array to put "updates" elements into, the rest output + * elements will be zeros updates: array containing elements to be inserted into + * output array shape: contains shape of output array + */ +// #if NOT_EXCLUDED(OP_scatter_nd) +@Namespace("sd::ops") public static class scatter_nd extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public scatter_nd(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public scatter_nd(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public scatter_nd position(long position) { + return (scatter_nd)super.position(position); + } - /** - * This operation scatter "updates" elements into new output array according to given "indices" - * Expected arguments: - * indices: array containing elements/slices indexes of output array to put "updates" elements into, the rest output elements will be zeros - * updates: array containing elements to be inserted into output array - * shape: contains shape of output array - */ -// #if NOT_EXCLUDED(OP_scatter_nd) - @Namespace("sd::ops") public static class scatter_nd extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public scatter_nd(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public scatter_nd(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public scatter_nd position(long position) { - return (scatter_nd)super.position(position); - } - public scatter_nd() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This operation scatter "updates" elements into input array along given + * "indices" Expected arguments: input: array to be updated indices: array + * containing elements/slices indexes of input array to put "updates" elements + * into updates: array containing elements to be inserted into input array + */ +// #if NOT_EXCLUDED(OP_scatter_nd_update) +@Namespace("sd::ops") public static class scatter_nd_update extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public scatter_nd_update(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public scatter_nd_update(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public scatter_nd_update position(long position) { + return (scatter_nd_update)super.position(position); + } - /** - * This operation scatter "updates" elements into input array along given "indices" - * Expected arguments: - * input: array to be updated - * indices: array containing elements/slices indexes of input array to put "updates" elements into - * updates: array containing elements to be inserted into input array - */ -// #if NOT_EXCLUDED(OP_scatter_nd_update) - @Namespace("sd::ops") public static class scatter_nd_update extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public scatter_nd_update(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public scatter_nd_update(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public scatter_nd_update position(long position) { - return (scatter_nd_update)super.position(position); - } - public scatter_nd_update() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This operation adds "updates" elements to input array along given "indices" + * Expected arguments: + * input: array to be updated + * indices: array containing elements/slices indexes of input array to add + * "updates" elements to updates: array containing elements to be interfered + * with input + */ +// #if NOT_EXCLUDED(OP_scatter_add) +@Namespace("sd::ops") public static class scatter_nd_add extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public scatter_nd_add(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public scatter_nd_add(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public scatter_nd_add position(long position) { + return (scatter_nd_add)super.position(position); + } - /** - * This operation adds "updates" elements to input array along given "indices" - * Expected arguments: - * input: array to be updated - * indices: array containing elements/slices indexes of input array to add "updates" elements to - * updates: array containing elements to be interfered with input - */ -// #if NOT_EXCLUDED(OP_scatter_add) - @Namespace("sd::ops") public static class scatter_nd_add extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public scatter_nd_add(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public scatter_nd_add(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public scatter_nd_add position(long position) { - return (scatter_nd_add)super.position(position); - } - public scatter_nd_add() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This operation subtract "updates" elements from input array along given + * "indices" Expected arguments: input: array to be updated indices: array + * containing elements/slices indexes of input array to subtract "updates" + * elements from updates: array containing elements to be interfered with input + */ +// #if NOT_EXCLUDED(OP_scatter_sub) +@Namespace("sd::ops") public static class scatter_nd_sub extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public scatter_nd_sub(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public scatter_nd_sub(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public scatter_nd_sub position(long position) { + return (scatter_nd_sub)super.position(position); + } - /** - * This operation subtract "updates" elements from input array along given "indices" - * Expected arguments: - * input: array to be updated - * indices: array containing elements/slices indexes of input array to subtract "updates" elements from - * updates: array containing elements to be interfered with input - */ -// #if NOT_EXCLUDED(OP_scatter_sub) - @Namespace("sd::ops") public static class scatter_nd_sub extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public scatter_nd_sub(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public scatter_nd_sub(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public scatter_nd_sub position(long position) { - return (scatter_nd_sub)super.position(position); - } - public scatter_nd_sub() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This operation takes input's shape, and returns new NDArray filled with + * specified value Expected arguments: input: N-dimensional array + * + * T args: + * 0: scalar value, used to fill NDArray + */ +// #if NOT_EXCLUDED(OP_fill_as) +@Namespace("sd::ops") public static class fill_as extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public fill_as(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public fill_as(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public fill_as position(long position) { + return (fill_as)super.position(position); + } - /** - * This operation takes input's shape, and returns new NDArray filled with specified value - * Expected arguments: - * input: N-dimensional array - * - * T args: - * 0: scalar value, used to fill NDArray - */ -// #if NOT_EXCLUDED(OP_fill_as) - @Namespace("sd::ops") public static class fill_as extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public fill_as(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public fill_as(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public fill_as position(long position) { - return (fill_as)super.position(position); - } - public fill_as() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This operation applies element-wise rint (round to integral value) operation + */ +// #if NOT_EXCLUDED(OP_rint) +@Namespace("sd::ops") public static class rint extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public rint(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public rint(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public rint position(long position) { + return (rint)super.position(position); + } - /** - * This operation applies element-wise rint (round to integral value) operation - */ -// #if NOT_EXCLUDED(OP_rint) - @Namespace("sd::ops") public static class rint extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public rint(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public rint(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public rint position(long position) { - return (rint)super.position(position); - } - public rint() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This operation returns unique elements from input array as vector, and their + * original indices in input array Expected input: input: N-dimensional array + */ +// #if NOT_EXCLUDED(OP_unique) +@Namespace("sd::ops") public static class unique extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public unique(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public unique(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public unique position(long position) { + return (unique)super.position(position); + } - /** - * This operation returns unique elements from input array as vector, and their original indices in input array - * Expected input: - * input: N-dimensional array - */ -// #if NOT_EXCLUDED(OP_unique) - @Namespace("sd::ops") public static class unique extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public unique(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public unique(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public unique position(long position) { - return (unique)super.position(position); - } - public unique() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This operation returns 3 1D arrays for given 1D array with unique element + * count and indexes input: 0 - 1D array + * + * output: + * 0 - 1D array with unique values + * 1 - 1D array with ids for values in array above + * 2 - 1D array with counts for values in array above + */ +// #if NOT_EXCLUDED(OP_unique_with_counts) +@Namespace("sd::ops") public static class unique_with_counts extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public unique_with_counts(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public unique_with_counts(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public unique_with_counts position(long position) { + return (unique_with_counts)super.position(position); + } - /** - * This operation returns 3 1D arrays for given 1D array with unique element count and indexes - * input: - * 0 - 1D array - * - * output: - * 0 - 1D array with unique values - * 1 - 1D array with ids for values in array above - * 2 - 1D array with counts for values in array above - */ -// #if NOT_EXCLUDED(OP_unique_with_counts) - @Namespace("sd::ops") public static class unique_with_counts extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public unique_with_counts(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public unique_with_counts(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public unique_with_counts position(long position) { - return (unique_with_counts)super.position(position); - } - public unique_with_counts() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This operation splits input NDArray into multiple TADs along given dimensions + * Expected arguments: + * input: N-dimensional array + * + * Int args: + * 0..: TAD axis + */ +// #if NOT_EXCLUDED(OP_tear) +@Namespace("sd::ops") public static class tear extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public tear(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public tear(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public tear position(long position) { + return (tear)super.position(position); + } - /** - * This operation splits input NDArray into multiple TADs along given dimensions - * Expected arguments: - * input: N-dimensional array - * - * Int args: - * 0..: TAD axis - */ -// #if NOT_EXCLUDED(OP_tear) - @Namespace("sd::ops") public static class tear extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public tear(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public tear(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public tear position(long position) { - return (tear)super.position(position); - } - public tear() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This op does the same as tear, just uses different input format: + * \tparam T + */ +// #if NOT_EXCLUDED(OP_unstack) +@Namespace("sd::ops") public static class unstack extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public unstack(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public unstack(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public unstack position(long position) { + return (unstack)super.position(position); + } - /** - * This op does the same as tear, just uses different input format: - * \tparam T - */ -// #if NOT_EXCLUDED(OP_unstack) - @Namespace("sd::ops") public static class unstack extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public unstack(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public unstack(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public unstack position(long position) { - return (unstack)super.position(position); - } - public unstack() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This operation extracts a strided (optionally) slice from a tensor, + */ +// #if NOT_EXCLUDED(OP_strided_slice) +@Namespace("sd::ops") public static class strided_slice extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public strided_slice(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public strided_slice(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public strided_slice position(long position) { + return (strided_slice)super.position(position); + } - /** - * This operation extracts a strided (optionally) slice from a tensor, - */ -// #if NOT_EXCLUDED(OP_strided_slice) - @Namespace("sd::ops") public static class strided_slice extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public strided_slice(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public strided_slice(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public strided_slice position(long position) { - return (strided_slice)super.position(position); - } - public strided_slice() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); - } // TODO: new op type needed. that returns VIEW - @Namespace("sd::ops") public static class strided_slice_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public strided_slice_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public strided_slice_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public strided_slice_bp position(long position) { - return (strided_slice_bp)super.position(position); - } - + } // TODO: new op type needed. that returns VIEW +@Namespace("sd::ops") public static class strided_slice_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public strided_slice_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public strided_slice_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public strided_slice_bp position(long position) { + return (strided_slice_bp)super.position(position); + } + public strided_slice_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This operation extracts a slice from a tensor. + * + */ +// #if NOT_EXCLUDED(OP_slice) +@Namespace("sd::ops") public static class slice extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public slice(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public slice(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public slice position(long position) { + return (slice)super.position(position); + } - /** - * This operation extracts a slice from a tensor. - * - */ -// #if NOT_EXCLUDED(OP_slice) - @Namespace("sd::ops") public static class slice extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public slice(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public slice(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public slice position(long position) { - return (slice)super.position(position); - } - public slice() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class slice_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public slice_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public slice_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public slice_bp position(long position) { - return (slice_bp)super.position(position); - } - +@Namespace("sd::ops") public static class slice_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public slice_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public slice_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public slice_bp position(long position) { + return (slice_bp)super.position(position); + } + public slice_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This operation generate sequences. Basically from......to, with step used as + * increment. Expected arguments: start: optional scalar with starting value + * stop: optional scalar with end value + * step: optional scalar witn step value + * + * Int args: (optional) + * 0: optional scalar with starting value + * 1: optional scalar with end value + * 1: optional scalar witn step value + * + * T args: (optional) + * 0: optional scalar with starting value + * 1: optional scalar with end value + * 1: optional scalar witn step value + */ +// #if NOT_EXCLUDED(OP_range) +@Namespace("sd::ops") public static class range extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public range(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public range(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public range position(long position) { + return (range)super.position(position); + } - /** - * This operation generate sequences. Basically from......to, with step used as increment. - * Expected arguments: - * start: optional scalar with starting value - * stop: optional scalar with end value - * step: optional scalar witn step value - * - * Int args: (optional) - * 0: optional scalar with starting value - * 1: optional scalar with end value - * 1: optional scalar witn step value - * - * T args: (optional) - * 0: optional scalar with starting value - * 1: optional scalar with end value - * 1: optional scalar witn step value - */ -// #if NOT_EXCLUDED(OP_range) - @Namespace("sd::ops") public static class range extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public range(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public range(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public range position(long position) { - return (range)super.position(position); - } - public range() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This operation return one-hot encoded n-dimensional array + * Expected arguments: + * input: N-dimensional array + * + * T args: + * 0: 'on' value + * 1: 'off' value + * + * Int args: + * 0: depth + * 1: axis + */ +// #if NOT_EXCLUDED(OP_onehot) +@Namespace("sd::ops") public static class onehot extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public onehot(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public onehot(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public onehot position(long position) { + return (onehot)super.position(position); + } - /** - * This operation return one-hot encoded n-dimensional array - * Expected arguments: - * input: N-dimensional array - * - * T args: - * 0: 'on' value - * 1: 'off' value - * - * Int args: - * 0: depth - * 1: axis - */ -// #if NOT_EXCLUDED(OP_onehot) - @Namespace("sd::ops") public static class onehot extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public onehot(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public onehot(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public onehot position(long position) { - return (onehot)super.position(position); - } - public onehot() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif +/** + * This operation calculate the confusion matrix for a + * pair of prediction and label 1-D arrays. + * Expected arguments: + * Input arrays: + * 0 - predictions: 1-D array + * 1 - labels: 1-D array + * 2 - weights : optional + * Int args: + * 0 - num_classes: optional + * + */ +// #if NOT_EXCLUDED(OP_confusion_matrix) +@Namespace("sd::ops") public static class confusion_matrix extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public confusion_matrix(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public confusion_matrix(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public confusion_matrix position(long position) { + return (confusion_matrix)super.position(position); + } - /** - * This operation calculate the confusion matrix for a - * pair of prediction and label 1-D arrays. - * Expected arguments: - * Input arrays: - * 0 - predictions: 1-D array - * 1 - labels: 1-D array - * 2 - weights : optional - * Int args: - * 0 - num_classes: optional - * - */ -// #if NOT_EXCLUDED(OP_confusion_matrix) - @Namespace("sd::ops") public static class confusion_matrix extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public confusion_matrix(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public confusion_matrix(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public confusion_matrix position(long position) { - return (confusion_matrix)super.position(position); - } - public confusion_matrix() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This operation stacks a list of rank tensors into one rank-(R+1) tensor. + * Expected arguments: + * 0...: N-Dimensional arrays to stack + * + */ +// #if NOT_EXCLUDED(OP_stack) +@Namespace("sd::ops") public static class stack extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public stack(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public stack(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public stack position(long position) { + return (stack)super.position(position); + } - /** - * This operation stacks a list of rank tensors into one rank-(R+1) tensor. - * Expected arguments: - * 0...: N-Dimensional arrays to stack - * - */ -// #if NOT_EXCLUDED(OP_stack) - @Namespace("sd::ops") public static class stack extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public stack(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public stack(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public stack position(long position) { - return (stack)super.position(position); - } - public stack() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This operation returns length of input array + * Expected arguments: + * input: N-dimensional array + * + * TODO: make this operation reduction, to allow TAD -> size + */ +// #if NOT_EXCLUDED(OP_size) +@Namespace("sd::ops") public static class size extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public size(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public size(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public size position(long position) { + return (size)super.position(position); + } - /** - * This operation returns length of input array - * Expected arguments: - * input: N-dimensional array - * - * TODO: make this operation reduction, to allow TAD -> size - */ -// #if NOT_EXCLUDED(OP_size) - @Namespace("sd::ops") public static class size extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public size(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public size(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public size position(long position) { - return (size)super.position(position); - } - public size() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); - } // add DeclarableScalarOp? -// #endif + } // add DeclarableScalarOp? +// #endif +/** + * This operation returns rank of input array as scalar value. + */ +// #if NOT_EXCLUDED(OP_rank) +@Namespace("sd::ops") public static class rank extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public rank(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public rank(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public rank position(long position) { + return (rank)super.position(position); + } - /** - * This operation returns rank of input array as scalar value. - */ -// #if NOT_EXCLUDED(OP_rank) - @Namespace("sd::ops") public static class rank extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public rank(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public rank(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public rank position(long position) { - return (rank)super.position(position); - } - public rank() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); - } // ^ -// #endif + } // ^ +// #endif +// #if NOT_EXCLUDED(OP_broadcastgradientargs) +@Namespace("sd::ops") public static class broadcastgradientargs extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public broadcastgradientargs(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public broadcastgradientargs(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public broadcastgradientargs position(long position) { + return (broadcastgradientargs)super.position(position); + } -// #if NOT_EXCLUDED(OP_broadcastgradientargs) - @Namespace("sd::ops") public static class broadcastgradientargs extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public broadcastgradientargs(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public broadcastgradientargs(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public broadcastgradientargs position(long position) { - return (broadcastgradientargs)super.position(position); - } - public broadcastgradientargs() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This operation takes input's shape, and returns new NDArray filled with zeros + * Expected arguments: + * input: N-dimensional array + * + */ +// #if NOT_EXCLUDED(OP_zeros_as) +@Namespace("sd::ops") public static class zeros_as extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public zeros_as(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public zeros_as(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public zeros_as position(long position) { + return (zeros_as)super.position(position); + } - /** - * This operation takes input's shape, and returns new NDArray filled with zeros - * Expected arguments: - * input: N-dimensional array - * - */ -// #if NOT_EXCLUDED(OP_zeros_as) - @Namespace("sd::ops") public static class zeros_as extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public zeros_as(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public zeros_as(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public zeros_as position(long position) { - return (zeros_as)super.position(position); - } - public zeros_as() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This operation takes input's shape, and returns new NDArray filled with ones + * Expected arguments: + * input: N-dimensional array + * + */ +// #if NOT_EXCLUDED(OP_ones_as) +@Namespace("sd::ops") public static class ones_as extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public ones_as(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public ones_as(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public ones_as position(long position) { + return (ones_as)super.position(position); + } - /** - * This operation takes input's shape, and returns new NDArray filled with ones - * Expected arguments: - * input: N-dimensional array - * - */ -// #if NOT_EXCLUDED(OP_ones_as) - @Namespace("sd::ops") public static class ones_as extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public ones_as(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public ones_as(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public ones_as position(long position) { - return (ones_as)super.position(position); - } - public ones_as() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This operation applies element-wise pow(x, 2) to the given input + * Expected arguments: + * input: N-Dimensional array + */ +// #if NOT_EXCLUDED(OP_square) +@Namespace("sd::ops") public static class square extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public square(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public square(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public square position(long position) { + return (square)super.position(position); + } - /** - * This operation applies element-wise pow(x, 2) to the given input - * Expected arguments: - * input: N-Dimensional array - */ -// #if NOT_EXCLUDED(OP_square) - @Namespace("sd::ops") public static class square extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public square(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public square(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public square position(long position) { - return (square)super.position(position); - } - public square() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This op calculates Hurwitz zeta function zeta(x, q) = sum_{n=0}^{inf} (q + + * n)^{-x} Implementation is based on Euler-Maclaurin summation formula + * + * Input arrays: + * x: define power {-x}, must be > 1, type float. + * q: define summand in denominator, must be > 0, type float. + * + * Output array: + * 0: corresponding values of Hurwitz zeta function + * + * Two input and one output arrays must have the same shape + */ +// #if NOT_EXCLUDED(OP_zeta) +@Namespace("sd::ops") public static class zeta extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public zeta(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public zeta(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public zeta position(long position) { + return (zeta)super.position(position); + } - /** - * This op calculates Hurwitz zeta function zeta(x, q) = sum_{n=0}^{inf} (q + n)^{-x} - * Implementation is based on Euler-Maclaurin summation formula - * - * Input arrays: - * x: define power {-x}, must be > 1, type float. - * q: define summand in denominator, must be > 0, type float. - * - * Output array: - * 0: corresponding values of Hurwitz zeta function - * - * Two input and one output arrays must have the same shape - */ -// #if NOT_EXCLUDED(OP_zeta) - @Namespace("sd::ops") public static class zeta extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public zeta(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public zeta(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public zeta position(long position) { - return (zeta)super.position(position); - } - public zeta() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This op calculates polygamma function psi^(n)(x). Implementation is based on + * serial representation written in terms of the Hurwitz zeta function: + * polygamma = (-1)^{n+1} * n! * zeta(n+1, x). + * + * Input arrays: + * 0: n - define derivative order (n+1), type integer (however currently is + * implemented as float casted to integer) 1: x - abscissa points where to + * evaluate the polygamma function, type float + * + * Output array: + * 0: values of polygamma function at corresponding x, type float + * + * Two input and one output arrays have the same shape + */ +// #if NOT_EXCLUDED(OP_polygamma) +@Namespace("sd::ops") public static class polygamma extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public polygamma(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public polygamma(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public polygamma position(long position) { + return (polygamma)super.position(position); + } - /** - * This op calculates polygamma function psi^(n)(x). Implementation is based on serial representation written in - * terms of the Hurwitz zeta function: polygamma = (-1)^{n+1} * n! * zeta(n+1, x). - * - * Input arrays: - * 0: n - define derivative order (n+1), type integer (however currently is implemented as float casted to integer) - * 1: x - abscissa points where to evaluate the polygamma function, type float - * - * Output array: - * 0: values of polygamma function at corresponding x, type float - * - * Two input and one output arrays have the same shape - */ -// #if NOT_EXCLUDED(OP_polygamma) - @Namespace("sd::ops") public static class polygamma extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public polygamma(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public polygamma(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public polygamma position(long position) { - return (polygamma)super.position(position); - } - public polygamma() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This op calculates lgamma function lgamma(x) = log(Gamma(x)) + * + * Input arrays: + * 0: x - input matrix + * + * Output array: + * 0: log of Gamma(x) + * + */ +// #if NOT_EXCLUDED(OP_lgamma) +@Namespace("sd::ops") public static class lgamma extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public lgamma(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public lgamma(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public lgamma position(long position) { + return (lgamma)super.position(position); + } - /** - * This op calculates lgamma function lgamma(x) = log(Gamma(x)) - * - * Input arrays: - * 0: x - input matrix - * - * Output array: - * 0: log of Gamma(x) - * - */ -// #if NOT_EXCLUDED(OP_lgamma) - @Namespace("sd::ops") public static class lgamma extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public lgamma(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public lgamma(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public lgamma position(long position) { - return (lgamma)super.position(position); - } - public lgamma() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This op calculates digamma function psi(x) = derivative of log(Gamma(x)) + * + * Input arrays: + * 0: x - abscissa points where to evaluate the digamma function, type float + * + * Output array: + * 0: values of digamma function at corresponding x, type float + * + */ +// #if NOT_EXCLUDED(OP_digamma) +@Namespace("sd::ops") public static class digamma extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public digamma(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public digamma(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public digamma position(long position) { + return (digamma)super.position(position); + } - /** - * This op calculates digamma function psi(x) = derivative of log(Gamma(x)) - * - * Input arrays: - * 0: x - abscissa points where to evaluate the digamma function, type float - * - * Output array: - * 0: values of digamma function at corresponding x, type float - * - */ -// #if NOT_EXCLUDED(OP_digamma) - @Namespace("sd::ops") public static class digamma extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public digamma(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public digamma(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public digamma position(long position) { - return (digamma)super.position(position); - } - public digamma() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This operation takes shape as first argument, and returns new NDArray filled + * with specific scalar value. Input arrays: 0 - shape vector 1 - optional + * scalar NDArray + * + * T arguments: + * 0 - optional scalar value + * + */ +// #if NOT_EXCLUDED(OP_fill) +@Namespace("sd::ops") public static class fill extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public fill(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public fill(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public fill position(long position) { + return (fill)super.position(position); + } - /** - * This operation takes shape as first argument, and returns new NDArray filled with specific scalar value. - * Input arrays: - * 0 - shape vector - * 1 - optional scalar NDArray - * - * T arguments: - * 0 - optional scalar value - * - */ -// #if NOT_EXCLUDED(OP_fill) - @Namespace("sd::ops") public static class fill extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public fill(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public fill(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public fill position(long position) { - return (fill)super.position(position); - } - public fill() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This operation splits given NDArray into chunks of specific size, along given + * dimension Input arrays: 0 - input array 1 - array of sizes 2 - optional axis + * + * Integer arguments: + * 0 - optional axis + * + */ +// #if NOT_EXCLUDED(OP_split_v) +@Namespace("sd::ops") public static class split_v extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public split_v(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public split_v(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public split_v position(long position) { + return (split_v)super.position(position); + } - /** - * This operation splits given NDArray into chunks of specific size, along given dimension - * Input arrays: - * 0 - input array - * 1 - array of sizes - * 2 - optional axis - * - * Integer arguments: - * 0 - optional axis - * - */ -// #if NOT_EXCLUDED(OP_split_v) - @Namespace("sd::ops") public static class split_v extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public split_v(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public split_v(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public split_v position(long position) { - return (split_v)super.position(position); - } - public split_v() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This operation splits given NDArray into chunks of specific size, along given + * dimension 0 - input array 1 - optional axis + * + * Integer arguments: + * 0 - number of splits + * 1 - optional axis + */ +// #if NOT_EXCLUDED(OP_split) +@Namespace("sd::ops") public static class split extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public split(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public split(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public split position(long position) { + return (split)super.position(position); + } - /** - * This operation splits given NDArray into chunks of specific size, along given dimension - * 0 - input array - * 1 - optional axis - * - * Integer arguments: - * 0 - number of splits - * 1 - optional axis - */ -// #if NOT_EXCLUDED(OP_split) - @Namespace("sd::ops") public static class split extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public split(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public split(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public split position(long position) { - return (split)super.position(position); - } - public split() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif +/** + * This operation adjusts image hue by delta + * Input arrays: + * 0 - input array with rank >= 3, must have at least one dimension equal 3, + * that is dimension containing channels. 1 - optional argument, input + * scalar-array containing delta + * + * T arguments: + * 0 - optional argument, delta value + * + * Int arguments: + * 0 - optional argument, corresponds to dimension with 3 channels + */ +// #if NOT_EXCLUDED(OP_adjust_hue) +@Namespace("sd::ops") public static class adjust_hue extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public adjust_hue(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public adjust_hue(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public adjust_hue position(long position) { + return (adjust_hue)super.position(position); + } - /** - * This operation adjusts image hue by delta - * Input arrays: - * 0 - input array with rank >= 3, must have at least one dimension equal 3, that is dimension containing channels. - * 1 - optional argument, input scalar-array containing delta - * - * T arguments: - * 0 - optional argument, delta value - * - * Int arguments: - * 0 - optional argument, corresponds to dimension with 3 channels - */ -// #if NOT_EXCLUDED(OP_adjust_hue) - @Namespace("sd::ops") public static class adjust_hue extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public adjust_hue(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public adjust_hue(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public adjust_hue position(long position) { - return (adjust_hue)super.position(position); - } - public adjust_hue() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This operation adjusts image saturation by delta + * Input arrays: + * 0 - input array with rank >= 3, must have at least one dimension equal 3, + * that is dimension containing channels. 1 - optional argument, input + * scalar-array containing saturation factor + * + * T arguments: + * 0 - optional argument, saturation factor + * + * Int arguments: + * 0 - optional argument, corresponds to dimension with 3 channels + */ +// #if NOT_EXCLUDED(OP_adjust_saturation) +@Namespace("sd::ops") public static class adjust_saturation extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public adjust_saturation(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public adjust_saturation(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public adjust_saturation position(long position) { + return (adjust_saturation)super.position(position); + } - /** - * This operation adjusts image saturation by delta - * Input arrays: - * 0 - input array with rank >= 3, must have at least one dimension equal 3, that is dimension containing channels. - * 1 - optional argument, input scalar-array containing saturation factor - * - * T arguments: - * 0 - optional argument, saturation factor - * - * Int arguments: - * 0 - optional argument, corresponds to dimension with 3 channels - */ -// #if NOT_EXCLUDED(OP_adjust_saturation) - @Namespace("sd::ops") public static class adjust_saturation extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public adjust_saturation(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public adjust_saturation(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public adjust_saturation position(long position) { - return (adjust_saturation)super.position(position); - } - public adjust_saturation() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This operation adjusts image contrast by given factor ( z = (x - mean) * + * factor + mean ) Input arrays: 0 - input array with rank >= 3, must have last + * one dimension equal 3, that is dimension containing channels. 1 - optional + * argument, input scalar-array containing saturation contrast factor + * + * T arguments: + * 0 - optional argument, contrast factor + * + */ +// #if NOT_EXCLUDED(OP_adjust_contrast) +@Namespace("sd::ops") public static class adjust_contrast extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public adjust_contrast(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public adjust_contrast(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public adjust_contrast position(long position) { + return (adjust_contrast)super.position(position); + } - /** - * This operation adjusts image contrast by given factor ( z = (x - mean) * factor + mean ) - * Input arrays: - * 0 - input array with rank >= 3, must have last one dimension equal 3, that is dimension containing channels. - * 1 - optional argument, input scalar-array containing saturation contrast factor - * - * T arguments: - * 0 - optional argument, contrast factor - * - */ -// #if NOT_EXCLUDED(OP_adjust_contrast) - @Namespace("sd::ops") public static class adjust_contrast extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public adjust_contrast(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public adjust_contrast(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public adjust_contrast position(long position) { - return (adjust_contrast)super.position(position); - } - public adjust_contrast() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class adjust_contrast_v2 extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public adjust_contrast_v2(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public adjust_contrast_v2(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public adjust_contrast_v2 position(long position) { - return (adjust_contrast_v2)super.position(position); - } - +@Namespace("sd::ops") public static class adjust_contrast_v2 extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public adjust_contrast_v2(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public adjust_contrast_v2(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public adjust_contrast_v2 position(long position) { + return (adjust_contrast_v2)super.position(position); + } + public adjust_contrast_v2() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif - - +// #endif +/** + * This operation rearranges data from depth into blocks of spatial data. This + * is the reverse transformation of space_to_depth op. This op output is a copy + * of the input tensor where values from the depth dimension are moved in + * spatial blocks to the height and width dimensions. Int attr 0 indicates the + * input block size and how the data is moved. Input: 0 - 4D tensor on given + * type Output: 0 - 4D tensor of given type and proper shape + * + * Int arguments: + * 0 - block size + * 1 - output data format: 0 ("NHWC"): shape{ batch, height, width, channels + * } 1 ("NCHW"): shape{ batch, channels, height, width } 2 ("NCHW_VECT_C"): int8 + * shape{ batch, channels / 4, height, width, 4 } optional (default 0) + */ +// #if NOT_EXCLUDED(OP_depth_to_space) +@Namespace("sd::ops") public static class depth_to_space extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public depth_to_space(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public depth_to_space(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public depth_to_space position(long position) { + return (depth_to_space)super.position(position); + } - /** - * This operation rearranges data from depth into blocks of spatial data. This is the reverse transformation - * of space_to_depth op. This op output is a copy of the input tensor where values from the depth dimension - * are moved in spatial blocks to the height and width dimensions. Int attr 0 indicates the input - * block size and how the data is moved. - * Input: - * 0 - 4D tensor on given type - * Output: - * 0 - 4D tensor of given type and proper shape - * - * Int arguments: - * 0 - block size - * 1 - output data format: 0 ("NHWC"): shape{ batch, height, width, channels } - * 1 ("NCHW"): shape{ batch, channels, height, width } - * 2 ("NCHW_VECT_C"): int8 shape{ batch, channels / 4, height, width, 4 } - * optional (default 0) - */ -// #if NOT_EXCLUDED(OP_depth_to_space) - @Namespace("sd::ops") public static class depth_to_space extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public depth_to_space(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public depth_to_space(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public depth_to_space position(long position) { - return (depth_to_space)super.position(position); - } - public depth_to_space() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This operation rearranges blocks of spatial data, into depth.This op output + * is a copy of the input tensor where values from the height and width + * dimensions are moved to the depth dimension. Int attr 0 indicates the input + * block size. + * + * Input: + * - 4D tensor of given type + * Output: + * - 4D tensor + * + * Int arguments: + * 0 - block size + * 1 - output data format: 0 ("NHWC"): shape{ batch, height, width, channels + * } 1 ("NCHW"): shape{ batch, channels, height, width } 2 ("NCHW_VECT_C"): int8 + * shape{ batch, channels / 4, height, width, 4 } optional (default 0) + * + */ +// #if NOT_EXCLUDED(OP_space_to_depth) +@Namespace("sd::ops") public static class space_to_depth extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public space_to_depth(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public space_to_depth(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public space_to_depth position(long position) { + return (space_to_depth)super.position(position); + } - /** - * This operation rearranges blocks of spatial data, into depth.This op output is a copy of the input tensor - * where values from the height and width dimensions are moved to the depth dimension. Int attr 0 indicates - * the input block size. - * - * Input: - * - 4D tensor of given type - * Output: - * - 4D tensor - * - * Int arguments: - * 0 - block size - * 1 - output data format: 0 ("NHWC"): shape{ batch, height, width, channels } - * 1 ("NCHW"): shape{ batch, channels, height, width } - * 2 ("NCHW_VECT_C"): int8 shape{ batch, channels / 4, height, width, 4 } - * optional (default 0) - * - */ -// #if NOT_EXCLUDED(OP_space_to_depth) - @Namespace("sd::ops") public static class space_to_depth extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public space_to_depth(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public space_to_depth(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public space_to_depth position(long position) { - return (space_to_depth)super.position(position); - } - public space_to_depth() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif - /** - * This op calculates cross-product between input arguments - * Input arguments - * 0 - vector or tensor A - * 1 - vector or tensor B - */ -// #if NOT_EXCLUDED(OP_cross) - @Namespace("sd::ops") public static class cross extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public cross(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public cross(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public cross position(long position) { - return (cross)super.position(position); - } - - public cross() { super((Pointer)null); allocate(); } - private native void allocate(); +/** + * This op calculates cross-product between input arguments + * Input arguments + * 0 - vector or tensor A + * 1 - vector or tensor B + */ +// #if NOT_EXCLUDED(OP_cross) +@Namespace("sd::ops") public static class cross extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public cross(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public cross(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public cross position(long position) { + return (cross)super.position(position); + } + + public cross() { super((Pointer)null); allocate(); } + private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * Zero-pads and then rearranges (permutes) blocks of spatial data into batch. + * More specifically, this op outputs a copy of the input tensor where values + * from the height and width dimensions are moved to the batch dimension. After + * the zero-padding, both height and width of the input must be divisible by the + * block size. + * + * Inputs: + * 0 - input tensor + * 1 - 2D paddings tensor (shape {M, 2}) + * + * Output: + * - result tensor + * + * Int args: + * 0 - block size (M) + * + */ +// #if NOT_EXCLUDED(OP_space_to_batch) +@Namespace("sd::ops") public static class space_to_batch extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public space_to_batch(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public space_to_batch(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public space_to_batch position(long position) { + return (space_to_batch)super.position(position); + } - /** - * Zero-pads and then rearranges (permutes) blocks of spatial data into batch. More specifically, this op - * outputs a copy of the input tensor where values from the height and width dimensions are moved to the - * batch dimension. After the zero-padding, both height and width of the input must be divisible by the block - * size. - * - * Inputs: - * 0 - input tensor - * 1 - 2D paddings tensor (shape {M, 2}) - * - * Output: - * - result tensor - * - * Int args: - * 0 - block size (M) - * - */ -// #if NOT_EXCLUDED(OP_space_to_batch) - @Namespace("sd::ops") public static class space_to_batch extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public space_to_batch(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public space_to_batch(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public space_to_batch position(long position) { - return (space_to_batch)super.position(position); - } - public space_to_batch() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/* + * This operation divides "spatial" dimensions [1, ..., M] of the input into a + * grid of blocks of shape block_shape, and interleaves these blocks with the + * "batch" dimension (0) such that in the output, the spatial dimensions [1, + * ..., M] correspond to the position within the grid, and the batch dimension + * combines both the position within a spatial block and the original batch + * position. Prior to division into blocks, the spatial dimensions of the input + * are optionally zero padded according to paddings. + * + * Inputs: + * 0 - input (N-D tensor) + * 1 - block_shape - int 1D tensor with M length + * 2 - paddings - int 2D tensor with shape {M, 2} + * + * Output: + * - N-D tensor with the same type as input 0. + * + * */ +// #if NOT_EXCLUDED(OP_space_to_batch_nd) +@Namespace("sd::ops") public static class space_to_batch_nd extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public space_to_batch_nd(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public space_to_batch_nd(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public space_to_batch_nd position(long position) { + return (space_to_batch_nd)super.position(position); + } - /* - * This operation divides "spatial" dimensions [1, ..., M] of the input into a grid of blocks of shape - * block_shape, and interleaves these blocks with the "batch" dimension (0) such that in the output, - * the spatial dimensions [1, ..., M] correspond to the position within the grid, and the batch dimension - * combines both the position within a spatial block and the original batch position. Prior to division into - * blocks, the spatial dimensions of the input are optionally zero padded according to paddings. - * - * Inputs: - * 0 - input (N-D tensor) - * 1 - block_shape - int 1D tensor with M length - * 2 - paddings - int 2D tensor with shape {M, 2} - * - * Output: - * - N-D tensor with the same type as input 0. - * - * */ -// #if NOT_EXCLUDED(OP_space_to_batch_nd) - @Namespace("sd::ops") public static class space_to_batch_nd extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public space_to_batch_nd(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public space_to_batch_nd(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public space_to_batch_nd position(long position) { - return (space_to_batch_nd)super.position(position); - } - public space_to_batch_nd() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * + * + */ +// #if NOT_EXCLUDED(OP_batch_to_space) +@Namespace("sd::ops") public static class batch_to_space extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public batch_to_space(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public batch_to_space(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public batch_to_space position(long position) { + return (batch_to_space)super.position(position); + } - /** - * - * - */ -// #if NOT_EXCLUDED(OP_batch_to_space) - @Namespace("sd::ops") public static class batch_to_space extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public batch_to_space(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public batch_to_space(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public batch_to_space position(long position) { - return (batch_to_space)super.position(position); - } - public batch_to_space() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif -// #if NOT_EXCLUDED(OP_batch_to_space_nd) - @Namespace("sd::ops") public static class batch_to_space_nd extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public batch_to_space_nd(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public batch_to_space_nd(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public batch_to_space_nd position(long position) { - return (batch_to_space_nd)super.position(position); - } - +// #endif +// #if NOT_EXCLUDED(OP_batch_to_space_nd) +@Namespace("sd::ops") public static class batch_to_space_nd extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public batch_to_space_nd(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public batch_to_space_nd(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public batch_to_space_nd position(long position) { + return (batch_to_space_nd)super.position(position); + } + public batch_to_space_nd() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * top_k operation returns a vector of k top values for + * given NDArray as tensor with default boolean (true) + * as sort for result index array + * will be sorted by the values in descending order. + * The first parameter is a NDArray for working. + * The second is k (default 1) - optional + * The third is boolean value(default is true) (0 - as is, 1 - sorted by value) + * optional + */ +// #if NOT_EXCLUDED(OP_top_k) +@Namespace("sd::ops") public static class top_k extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public top_k(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public top_k(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public top_k position(long position) { + return (top_k)super.position(position); + } - /** - * top_k operation returns a vector of k top values for - * given NDArray as tensor with default boolean (true) - * as sort for result index array - * will be sorted by the values in descending order. - * The first parameter is a NDArray for working. - * The second is k (default 1) - optional - * The third is boolean value(default is true) (0 - as is, 1 - sorted by value) optional - */ -// #if NOT_EXCLUDED(OP_top_k) - @Namespace("sd::ops") public static class top_k extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public top_k(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public top_k(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public top_k position(long position) { - return (top_k)super.position(position); - } - public top_k() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * in_top_k operation returns a vector of k boolean values for + * given NDArray as 2D matrix of predicted in the NDArray k top values + * The first parameter is a NDArray of predicted values (2d array). + * The second is NDArray as vector of indeces k top values will be search. + * The third is k + */ +// #if NOT_EXCLUDED(OP_in_top_k) +@Namespace("sd::ops") public static class in_top_k extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public in_top_k(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public in_top_k(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public in_top_k position(long position) { + return (in_top_k)super.position(position); + } - /** - * in_top_k operation returns a vector of k boolean values for - * given NDArray as 2D matrix of predicted in the NDArray k top values - * The first parameter is a NDArray of predicted values (2d array). - * The second is NDArray as vector of indeces k top values will be search. - * The third is k - */ -// #if NOT_EXCLUDED(OP_in_top_k) - @Namespace("sd::ops") public static class in_top_k extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public in_top_k(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public in_top_k(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public in_top_k position(long position) { - return (in_top_k)super.position(position); - } - public in_top_k() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * moments operation calculate a mean and variation for given NDArray + * with reduce a result according to axis array given. + * For full axis the result is both mean and variance of all members in array. + * Otherwise there are two NDArrays with means and variances for + * Axes can be put as the second NDArray or as int vector. + * + * the optional flag "keep_dims" can be set as T param + */ +// #if NOT_EXCLUDED(OP_moments) +@Namespace("sd::ops") public static class moments extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public moments(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public moments(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public moments position(long position) { + return (moments)super.position(position); + } - /** - * moments operation calculate a mean and variation for given NDArray - * with reduce a result according to axis array given. - * For full axis the result is both mean and variance of all members in array. - * Otherwise there are two NDArrays with means and variances for - * Axes can be put as the second NDArray or as int vector. - * - * the optional flag "keep_dims" can be set as T param - */ -// #if NOT_EXCLUDED(OP_moments) - @Namespace("sd::ops") public static class moments extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public moments(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public moments(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public moments position(long position) { - return (moments)super.position(position); - } - public moments() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * embedding_lookup - search for submatrices in given matrix and retunts them + * accordingly to index array given. + */ +// #if NOT_EXCLUDED(OP_embedding_lookup) +@Namespace("sd::ops") public static class embedding_lookup extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public embedding_lookup(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public embedding_lookup(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public embedding_lookup position(long position) { + return (embedding_lookup)super.position(position); + } - /** - * embedding_lookup - search for submatrices in given matrix and retunts them - * accordingly to index array given. - */ -// #if NOT_EXCLUDED(OP_embedding_lookup) - @Namespace("sd::ops") public static class embedding_lookup extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public embedding_lookup(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public embedding_lookup(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public embedding_lookup position(long position) { - return (embedding_lookup)super.position(position); - } - public embedding_lookup() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * dynamic_partition - partition a input tensor onto num_partitions + * accordingly to index array given. + * + * the first param - NDArray to be partitioned. + * the second param - index array + * the third param (integer param) - num or partitions. + * + * returns a num of NDArrays as output + */ +// #if NOT_EXCLUDED(OP_dynamic_partition) +@Namespace("sd::ops") public static class dynamic_partition extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public dynamic_partition(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public dynamic_partition(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public dynamic_partition position(long position) { + return (dynamic_partition)super.position(position); + } - /** - * dynamic_partition - partition a input tensor onto num_partitions - * accordingly to index array given. - * - * the first param - NDArray to be partitioned. - * the second param - index array - * the third param (integer param) - num or partitions. - * - * returns a num of NDArrays as output - */ -// #if NOT_EXCLUDED(OP_dynamic_partition) - @Namespace("sd::ops") public static class dynamic_partition extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public dynamic_partition(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public dynamic_partition(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public dynamic_partition position(long position) { - return (dynamic_partition)super.position(position); - } - public dynamic_partition() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_dynamic_partition_bp) +@Namespace("sd::ops") public static class dynamic_partition_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public dynamic_partition_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public dynamic_partition_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public dynamic_partition_bp position(long position) { + return (dynamic_partition_bp)super.position(position); + } -// #if NOT_EXCLUDED(OP_dynamic_partition_bp) - @Namespace("sd::ops") public static class dynamic_partition_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public dynamic_partition_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public dynamic_partition_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public dynamic_partition_bp position(long position) { - return (dynamic_partition_bp)super.position(position); - } - public dynamic_partition_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * dynamic_stitch - merge partitions from the second param a input tensor + * into a single tensor accordingly to index array given. + * + * the first param - index array + * the second params - tensors to be merged + * + * returns a num of NDArrays as output + * + * the operation is inversion od dynamic_partition + */ +// #if NOT_EXCLUDED(OP_dynamic_stitch) +@Namespace("sd::ops") public static class dynamic_stitch extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public dynamic_stitch(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public dynamic_stitch(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public dynamic_stitch position(long position) { + return (dynamic_stitch)super.position(position); + } - /** - * dynamic_stitch - merge partitions from the second param a input tensor - * into a single tensor accordingly to index array given. - * - * the first param - index array - * the second params - tensors to be merged - * - * returns a num of NDArrays as output - * - * the operation is inversion od dynamic_partition - */ -// #if NOT_EXCLUDED(OP_dynamic_stitch) - @Namespace("sd::ops") public static class dynamic_stitch extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public dynamic_stitch(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public dynamic_stitch(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public dynamic_stitch position(long position) { - return (dynamic_stitch)super.position(position); - } - public dynamic_stitch() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * zero_fraction op. + * compute a fraction of zeros in given array + * + * input param - an array (tensor) + * output value - a real number with given type (e.g. float or double) + */ +// #if NOT_EXCLUDED(OP_zero_fraction) +@Namespace("sd::ops") public static class zero_fraction extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public zero_fraction(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public zero_fraction(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public zero_fraction position(long position) { + return (zero_fraction)super.position(position); + } - /** - * zero_fraction op. - * compute a fraction of zeros in given array - * - * input param - an array (tensor) - * output value - a real number with given type (e.g. float or double) - */ -// #if NOT_EXCLUDED(OP_zero_fraction) - @Namespace("sd::ops") public static class zero_fraction extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public zero_fraction(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public zero_fraction(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public zero_fraction position(long position) { - return (zero_fraction)super.position(position); - } - public zero_fraction() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif - /** - * xw_plus_b op. - * multiply two first matrices and add third vector to each row of result - * - * input params: - * - 2D matrix NxM - * - 2D matrix MxN - * - 1D vector with N elements - * output value - 2D matrix NxN as multiply of matrixes and add vector - * Int args: - * 0 - optional switcher of weights format, if int arg == 1 - mkldnn, else mmul - */ -// #if NOT_EXCLUDED(OP_xw_plus_b) - @Namespace("sd::ops") public static class xw_plus_b extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public xw_plus_b(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public xw_plus_b(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public xw_plus_b position(long position) { - return (xw_plus_b)super.position(position); - } - - public xw_plus_b() { super((Pointer)null); allocate(); } - private native void allocate(); - public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); - } - @Namespace("sd::ops") public static class xw_plus_b_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public xw_plus_b_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public xw_plus_b_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public xw_plus_b_bp position(long position) { - return (xw_plus_b_bp)super.position(position); - } - - public xw_plus_b_bp() { super((Pointer)null); allocate(); } - private native void allocate(); +/** + * xw_plus_b op. + * multiply two first matrices and add third vector to each row of result + * + * input params: + * - 2D matrix NxM + * - 2D matrix MxN + * - 1D vector with N elements + * output value - 2D matrix NxN as multiply of matrixes and add vector + * Int args: + * 0 - optional switcher of weights format, if int arg == 1 - mkldnn, else + * mmul + */ +// #if NOT_EXCLUDED(OP_xw_plus_b) +@Namespace("sd::ops") public static class xw_plus_b extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public xw_plus_b(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public xw_plus_b(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public xw_plus_b position(long position) { + return (xw_plus_b)super.position(position); + } + + public xw_plus_b() { super((Pointer)null); allocate(); } + private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +@Namespace("sd::ops") public static class xw_plus_b_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public xw_plus_b_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public xw_plus_b_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public xw_plus_b_bp position(long position) { + return (xw_plus_b_bp)super.position(position); + } + + public xw_plus_b_bp() { super((Pointer)null); allocate(); } + private native void allocate(); + public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); + } +// #endif + +/** + * This operation is missed due it simplicy. + * Input and output params are the same after operation. + * Input - NDArray, output - NDArray with the same shape. + */ +// #if NOT_EXCLUDED(OP_stop_gradient) +@Namespace("sd::ops") public static class stop_gradient extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public stop_gradient(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public stop_gradient(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public stop_gradient position(long position) { + return (stop_gradient)super.position(position); + } - /** - * This operation is missed due it simplicy. - * Input and output params are the same after operation. - * Input - NDArray, output - NDArray with the same shape. - */ -// #if NOT_EXCLUDED(OP_stop_gradient) - @Namespace("sd::ops") public static class stop_gradient extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public stop_gradient(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public stop_gradient(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public stop_gradient position(long position) { - return (stop_gradient)super.position(position); - } - public stop_gradient() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_parallel_stack) +@Namespace("sd::ops") public static class parallel_stack extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public parallel_stack(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public parallel_stack(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public parallel_stack position(long position) { + return (parallel_stack)super.position(position); + } -// #if NOT_EXCLUDED(OP_parallel_stack) - @Namespace("sd::ops") public static class parallel_stack extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public parallel_stack(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public parallel_stack(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public parallel_stack position(long position) { - return (parallel_stack)super.position(position); - } - public parallel_stack() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * normalize_moments operation normalize already calculated mean and variation + * accordingly to shift and count. + * input params: + * - count of data + * - tensor with mean + * - tensor with variance (the same shape as before) + * + * - optional floating point param shift. + * + * returns a normalized pair mean and variance with the same shapes as input + */ +// #if NOT_EXCLUDED(OP_normalize_moments) +@Namespace("sd::ops") public static class normalize_moments extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public normalize_moments(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public normalize_moments(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public normalize_moments position(long position) { + return (normalize_moments)super.position(position); + } - /** - * normalize_moments operation normalize already calculated mean and variation - * accordingly to shift and count. - * input params: - * - count of data - * - tensor with mean - * - tensor with variance (the same shape as before) - * - * - optional floating point param shift. - * - * returns a normalized pair mean and variance with the same shapes as input - */ -// #if NOT_EXCLUDED(OP_normalize_moments) - @Namespace("sd::ops") public static class normalize_moments extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public normalize_moments(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public normalize_moments(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public normalize_moments position(long position) { - return (normalize_moments)super.position(position); - } - public normalize_moments() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * sufficient_statistics operation return calculated mean and variation with + * data count. this operation is invert for moments accordingly to shift and + * count. input params: + * - input tensor + * - axes vector + * + * + * - optional floating point param shift. + * - optional int (as bool) keep_dimension + * + * returns four tensors: + * - scalar tensor (data count) + * - sum elements of input (accross axises) + * - sum of squares of input (accross axises) + * - shift (if was given by input floating param) + */ +// #if NOT_EXCLUDED(OP_sufficient_statistics) +@Namespace("sd::ops") public static class sufficient_statistics extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public sufficient_statistics(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public sufficient_statistics(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public sufficient_statistics position(long position) { + return (sufficient_statistics)super.position(position); + } - /** - * sufficient_statistics operation return calculated mean and variation with data count. - * this operation is invert for moments - * accordingly to shift and count. - * input params: - * - input tensor - * - axes vector - * - * - * - optional floating point param shift. - * - optional int (as bool) keep_dimension - * - * returns four tensors: - * - scalar tensor (data count) - * - sum elements of input (accross axises) - * - sum of squares of input (accross axises) - * - shift (if was given by input floating param) - */ -// #if NOT_EXCLUDED(OP_sufficient_statistics) - @Namespace("sd::ops") public static class sufficient_statistics extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public sufficient_statistics(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public sufficient_statistics(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public sufficient_statistics position(long position) { - return (sufficient_statistics)super.position(position); - } - public sufficient_statistics() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This op calculates weighted logarithmic loss of input + * Input arguments + * 0 - target + * 1 - input + * 2 - weights (scalar or vector with same as last dimension) + * + * return value - a tensor with the same shape as target or input + */ +// #if NOT_EXCLUDED(OP_weighted_cross_entropy_with_logits) +@Namespace("sd::ops") public static class weighted_cross_entropy_with_logits extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public weighted_cross_entropy_with_logits(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public weighted_cross_entropy_with_logits(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public weighted_cross_entropy_with_logits position(long position) { + return (weighted_cross_entropy_with_logits)super.position(position); + } - /** - * This op calculates weighted logarithmic loss of input - * Input arguments - * 0 - target - * 1 - input - * 2 - weights (scalar or vector with same as last dimension) - * - * return value - a tensor with the same shape as target or input - */ -// #if NOT_EXCLUDED(OP_weighted_cross_entropy_with_logits) - @Namespace("sd::ops") public static class weighted_cross_entropy_with_logits extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public weighted_cross_entropy_with_logits(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public weighted_cross_entropy_with_logits(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public weighted_cross_entropy_with_logits position(long position) { - return (weighted_cross_entropy_with_logits)super.position(position); - } - public weighted_cross_entropy_with_logits() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This op calculates dropout of input + * Input arguments + * 0 - input tensor + * 1 - noise_shape - (vector with shape to reduce) - optional + * + * int parameter - seed for random numbers + * T parameter - probability (should be between 0 and 1) + * return value - a tensor with the same shape as target or input + */ +// #if NOT_EXCLUDED(OP_dropout) +@Namespace("sd::ops") public static class dropout extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public dropout(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public dropout(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public dropout position(long position) { + return (dropout)super.position(position); + } - /** - * This op calculates dropout of input - * Input arguments - * 0 - input tensor - * 1 - noise_shape - (vector with shape to reduce) - optional - * - * int parameter - seed for random numbers - * T parameter - probability (should be between 0 and 1) - * return value - a tensor with the same shape as target or input - */ -// #if NOT_EXCLUDED(OP_dropout) - @Namespace("sd::ops") public static class dropout extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public dropout(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public dropout(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public dropout position(long position) { - return (dropout)super.position(position); - } - public dropout() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif -// #if NOT_EXCLUDED(OP_dropout_bp) - @Namespace("sd::ops") public static class dropout_bp extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public dropout_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public dropout_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public dropout_bp position(long position) { - return (dropout_bp)super.position(position); - } - +// #endif +// #if NOT_EXCLUDED(OP_dropout_bp) +@Namespace("sd::ops") public static class dropout_bp extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public dropout_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public dropout_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public dropout_bp position(long position) { + return (dropout_bp)super.position(position); + } + public dropout_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/* Calculates alpha weighted dropout + T params: + 0 - drop probability + 1 - alpha value + 2 - alpha' value + 3 - beta value + */ +// #if NOT_EXCLUDED(OP_alpha_dropout_bp) +@Namespace("sd::ops") public static class alpha_dropout_bp extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public alpha_dropout_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public alpha_dropout_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public alpha_dropout_bp position(long position) { + return (alpha_dropout_bp)super.position(position); + } - /* Calculates alpha weighted dropout - T params: - 0 - drop probability - 1 - alpha value - 2 - alpha' value - 3 - beta value - */ -// #if NOT_EXCLUDED(OP_alpha_dropout_bp) - @Namespace("sd::ops") public static class alpha_dropout_bp extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public alpha_dropout_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public alpha_dropout_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public alpha_dropout_bp position(long position) { - return (alpha_dropout_bp)super.position(position); - } - public alpha_dropout_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif +/** + * bincount operation return a vector with element counted. + * + * input params: + * - input tensor - only int part are accepted + * - weights - the same shape tensor with integer weights for element + * (optional) default weight - 1,1,1..,1 for all values in the tensor + * + * optional ints: + * - min_length - zero or greater + * - max_length - between min_length and max(input) + 1 + * + * returns four tensors: + * - vector tensor with length to min(max_len, max(input) + 1) with count + * of values in indexed place + * + */ +// #if NOT_EXCLUDED(OP_bincount) +@Namespace("sd::ops") public static class bincount extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public bincount(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public bincount(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public bincount position(long position) { + return (bincount)super.position(position); + } - /** - * bincount operation return a vector with element counted. - * - * input params: - * - input tensor - only int part are accepted - * - weights - the same shape tensor with integer weights for element (optional) - * default weight - 1,1,1..,1 for all values in the tensor - * - * optional ints: - * - min_length - zero or greater - * - max_length - between min_length and max(input) + 1 - * - * returns four tensors: - * - vector tensor with length to min(max_len, max(input) + 1) with count - * of values in indexed place - * - */ -// #if NOT_EXCLUDED(OP_bincount) - @Namespace("sd::ops") public static class bincount extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public bincount(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public bincount(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public bincount position(long position) { - return (bincount)super.position(position); - } - public bincount() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * broadcast_dynamic_shape op. + * + * input params: + * 0 - the first shape (vector with shape) + * 1 - the second shape (vector with shape) + * + * return value: + * vector with broadcasted shape + */ +// #if NOT_EXCLUDED(OP_broadcast_dynamic_shape) +@Namespace("sd::ops") public static class broadcast_dynamic_shape extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public broadcast_dynamic_shape(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public broadcast_dynamic_shape(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public broadcast_dynamic_shape position(long position) { + return (broadcast_dynamic_shape)super.position(position); + } - /** - * broadcast_dynamic_shape op. - * - * input params: - * 0 - the first shape (vector with shape) - * 1 - the second shape (vector with shape) - * - * return value: - * vector with broadcasted shape - */ -// #if NOT_EXCLUDED(OP_broadcast_dynamic_shape) - @Namespace("sd::ops") public static class broadcast_dynamic_shape extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public broadcast_dynamic_shape(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public broadcast_dynamic_shape(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public broadcast_dynamic_shape position(long position) { - return (broadcast_dynamic_shape)super.position(position); - } - public broadcast_dynamic_shape() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * matrix_determinant op. + * + * input params: + * 0 - the tensor with dimension (x * y * z * ::: * M * M) + * + * return value: + * tensor with dimension (x * y * z * ::: *) with determinant for all + * M x M matricies + */ +// #if NOT_EXCLUDED(OP_matrix_determinant) +@Namespace("sd::ops") public static class matrix_determinant extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public matrix_determinant(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public matrix_determinant(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public matrix_determinant position(long position) { + return (matrix_determinant)super.position(position); + } - /** - * matrix_determinant op. - * - * input params: - * 0 - the tensor with dimension (x * y * z * ::: * M * M) - * - * return value: - * tensor with dimension (x * y * z * ::: *) with determinant for all - * M x M matricies - */ -// #if NOT_EXCLUDED(OP_matrix_determinant) - @Namespace("sd::ops") public static class matrix_determinant extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public matrix_determinant(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public matrix_determinant(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public matrix_determinant position(long position) { - return (matrix_determinant)super.position(position); - } - public matrix_determinant() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif - /** - * log_matrix_determinant op. - * - * input params: - * 0 - the tensor with dimension (x * y * z * ::: * M * M) - * - * return value: - * tensor with dimension (x * y * z * ::: *) with log determinant for all - * M x M matricies - */ +/** + * log_matrix_determinant op. + * + * input params: + * 0 - the tensor with dimension (x * y * z * ::: * M * M) + * + * return value: + * tensor with dimension (x * y * z * ::: *) with log determinant for all + * M x M matricies + */ -// #if NOT_EXCLUDED(OP_log_matrix_determinant) - @Namespace("sd::ops") public static class log_matrix_determinant extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public log_matrix_determinant(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public log_matrix_determinant(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public log_matrix_determinant position(long position) { - return (log_matrix_determinant)super.position(position); - } - - public log_matrix_determinant() { super((Pointer)null); allocate(); } - private native void allocate(); - public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); - } -// #endif +// #if NOT_EXCLUDED(OP_log_matrix_determinant) +@Namespace("sd::ops") public static class log_matrix_determinant extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public log_matrix_determinant(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public log_matrix_determinant(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public log_matrix_determinant position(long position) { + return (log_matrix_determinant)super.position(position); + } - /** - * logdet op. Logarithm of the determinant of hermitian positive matricies. - * - * input params: - * 0 - the tensor with dimension (x * y * z * ::: * M * M) - * - * return value: - * tensor with dimension (x * y * z * ::: *) with log determinant for all - * M x M matricies - */ + public log_matrix_determinant() { super((Pointer)null); allocate(); } + private native void allocate(); + public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); + } +// #endif + +/** + * logdet op. Logarithm of the determinant of hermitian positive matricies. + * + * input params: + * 0 - the tensor with dimension (x * y * z * ::: * M * M) + * + * return value: + * tensor with dimension (x * y * z * ::: *) with log determinant for all + * M x M matricies + */ + +// #if NOT_EXCLUDED(OP_logdet) +@Namespace("sd::ops") public static class logdet extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public logdet(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public logdet(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public logdet position(long position) { + return (logdet)super.position(position); + } -// #if NOT_EXCLUDED(OP_logdet) - @Namespace("sd::ops") public static class logdet extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public logdet(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public logdet(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public logdet position(long position) { - return (logdet)super.position(position); - } - public logdet() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * matrix_solve_ls op (lstsq) - solves one or more linear least-squares + * problems. + * + * input params: + * 0 - the tensor with dimension (x * y * z * ::: * M * N) - left parts of + * equations 1 - the tensor with dimension (x * y * z * ::: * M * K) - right + * parts of equations + * + * float args: + * 0 - l2_regularizer (default 0. and only for 0 implemented) + * + * boolean args: + * 0 - fast - default is true (optional) - use Cholesky decomposition instead + * QR decomposition of matricies. + * + * return value: + * tensor with dimension (x * y * z * ::: * N * K) with solutions + * + */ +// #if NOT_EXCLUDED(OP_lstsq) +@Namespace("sd::ops") public static class lstsq extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public lstsq(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public lstsq(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public lstsq position(long position) { + return (lstsq)super.position(position); + } - /** - * matrix_solve_ls op (lstsq) - solves one or more linear least-squares problems. - * - * input params: - * 0 - the tensor with dimension (x * y * z * ::: * M * N) - left parts of equations - * 1 - the tensor with dimension (x * y * z * ::: * M * K) - right parts of equations - * - * float args: - * 0 - l2_regularizer (default 0. and only for 0 implemented) - * - * boolean args: - * 0 - fast - default is true (optional) - use Cholesky decomposition instead QR decomposition of matricies. - * - * return value: - * tensor with dimension (x * y * z * ::: * N * K) with solutions - * - */ -// #if NOT_EXCLUDED(OP_lstsq) - @Namespace("sd::ops") public static class lstsq extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public lstsq(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public lstsq(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public lstsq position(long position) { - return (lstsq)super.position(position); - } - public lstsq() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/* solve_ls - analog of lstsq op with another solution approach + * + * input params: + * 0 - the tensor with dimension (x * y * z * ::: * M * N) - left parts of + * equations 1 - the tensor with dimension (x * y * z * ::: * M * K) - right + * parts of equations + * + * float args: + * 0 - l2_regularizer (default 0. and only for 0 implemented) + * + * boolean args: + * 0 - fast - default is true (optional) - use Cholesky decomposition instead + * QR decomposition of matricies. + * + * return value: + * tensor with dimension (x * y * z * ::: * N * K) with solutions + * + * Note: if fast is false - then l2_regularizer arg is ignored and used lstsq + * method due QR decomposition + * */ +// #if NOT_EXCLUDED(OP_solve_ls) +@Namespace("sd::ops") public static class solve_ls extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public solve_ls(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public solve_ls(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public solve_ls position(long position) { + return (solve_ls)super.position(position); + } - /* solve_ls - analog of lstsq op with another solution approach - * - * input params: - * 0 - the tensor with dimension (x * y * z * ::: * M * N) - left parts of equations - * 1 - the tensor with dimension (x * y * z * ::: * M * K) - right parts of equations - * - * float args: - * 0 - l2_regularizer (default 0. and only for 0 implemented) - * - * boolean args: - * 0 - fast - default is true (optional) - use Cholesky decomposition instead QR decomposition of matricies. - * - * return value: - * tensor with dimension (x * y * z * ::: * N * K) with solutions - * - * Note: if fast is false - then l2_regularizer arg is ignored and used lstsq method due QR decomposition - * */ -// #if NOT_EXCLUDED(OP_solve_ls) - @Namespace("sd::ops") public static class solve_ls extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public solve_ls(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public solve_ls(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public solve_ls position(long position) { - return (solve_ls)super.position(position); - } - public solve_ls() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * matrix_inverse op. - make inverse for all 2D square matricies found in the + * input tensor + * + * input params: + * 0 - the tensor with dimension (x * y * z * ::: * M * M) + * + * return value: + * tensor with dimension (x * y * z * ::: * M * M) with inverse M x M + * matricies in it + */ +// #if NOT_EXCLUDED(OP_matrix_inverse) +@Namespace("sd::ops") public static class matrix_inverse extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public matrix_inverse(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public matrix_inverse(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public matrix_inverse position(long position) { + return (matrix_inverse)super.position(position); + } - /** - * matrix_inverse op. - make inverse for all 2D square matricies found in the input tensor - * - * input params: - * 0 - the tensor with dimension (x * y * z * ::: * M * M) - * - * return value: - * tensor with dimension (x * y * z * ::: * M * M) with inverse M x M matricies in it - */ -// #if NOT_EXCLUDED(OP_matrix_inverse) - @Namespace("sd::ops") public static class matrix_inverse extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public matrix_inverse(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public matrix_inverse(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public matrix_inverse position(long position) { - return (matrix_inverse)super.position(position); - } - public matrix_inverse() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * triangular_solve op. - reverse Gaussian method for solve systems of linear + * equations. + * + * input params: + * 0 - the tensor with dimension (x * y * z * ::: * M * M) - left parts of + * equations 1 - the tensor with dimension (x * y * z * ::: * M * K) - right + * parts of equations + * + * boolean args: + * 0 - lower - default is true (optional) - left part is lower triangular + * matrix 1 - adjoint - default is false (optional) - indicate input matrix or + * its adjoint (hermitian addition) should be used + * + * return value: + * tensor with dimension (x * y * z * ::: * M * K) with solutions + * + */ +// #if NOT_EXCLUDED(OP_triangular_solve) +@Namespace("sd::ops") public static class triangular_solve extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public triangular_solve(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public triangular_solve(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public triangular_solve position(long position) { + return (triangular_solve)super.position(position); + } - /** - * triangular_solve op. - reverse Gaussian method for solve systems of linear equations. - * - * input params: - * 0 - the tensor with dimension (x * y * z * ::: * M * M) - left parts of equations - * 1 - the tensor with dimension (x * y * z * ::: * M * K) - right parts of equations - * - * boolean args: - * 0 - lower - default is true (optional) - left part is lower triangular matrix - * 1 - adjoint - default is false (optional) - indicate input matrix or its adjoint (hermitian addition) should be used - * - * return value: - * tensor with dimension (x * y * z * ::: * M * K) with solutions - * - */ -// #if NOT_EXCLUDED(OP_triangular_solve) - @Namespace("sd::ops") public static class triangular_solve extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public triangular_solve(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public triangular_solve(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public triangular_solve position(long position) { - return (triangular_solve)super.position(position); - } - public triangular_solve() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * solve op. - solve systems of linear equations - general method. + * + * input params: + * 0 - the tensor with dimension (x * y * z * ::: * M * M) - left parts of + * equations 1 - the tensor with dimension (x * y * z * ::: * M * K) - right + * parts of equations + * + * boolean args: + * 0 - adjoint - default is false (optional) - indicate input matrix or its + * adjoint (hermitian addition) should be used + * + * return value: + * tensor with dimension (x * y * z * ::: * M * K) with solutions + * + */ +// #if NOT_EXCLUDED(OP_solve) +@Namespace("sd::ops") public static class solve extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public solve(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public solve(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public solve position(long position) { + return (solve)super.position(position); + } - /** - * solve op. - solve systems of linear equations - general method. - * - * input params: - * 0 - the tensor with dimension (x * y * z * ::: * M * M) - left parts of equations - * 1 - the tensor with dimension (x * y * z * ::: * M * K) - right parts of equations - * - * boolean args: - * 0 - adjoint - default is false (optional) - indicate input matrix or its adjoint (hermitian addition) should be used - * - * return value: - * tensor with dimension (x * y * z * ::: * M * K) with solutions - * - */ -// #if NOT_EXCLUDED(OP_solve) - @Namespace("sd::ops") public static class solve extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public solve(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public solve(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public solve position(long position) { - return (solve)super.position(position); - } - public solve() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif - /** - * lu op. - make LUP decomposition of given batch of 2D square matricies - * - * input params: - * 0 - float tensor with dimension (x * y * z * ::: * M * M) - * - * return value: - * 0 - float tensor with dimension (x * y * z * ::: * M * M) with LU M x M matricies in it - * 1 - int (32 or 64) batched vector of permutations with length M - shape (x * y * z * ::: * M) - * - * int argument: - * 0 - data type of output permutaion vector (int32 or int64), optional, default INT32 - */ +/** + * lu op. - make LUP decomposition of given batch of 2D square matricies + * + * input params: + * 0 - float tensor with dimension (x * y * z * ::: * M * M) + * + * return value: + * 0 - float tensor with dimension (x * y * z * ::: * M * M) with LU M x M + * matricies in it 1 - int (32 or 64) batched vector of permutations with length + * M - shape (x * y * z * ::: * M) + * + * int argument: + * 0 - data type of output permutaion vector (int32 or int64), optional, + * default INT32 + */ + +// #if NOT_EXCLUDED(OP_matrix_inverse) +@Namespace("sd::ops") public static class lu extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public lu(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public lu(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public lu position(long position) { + return (lu)super.position(position); + } -// #if NOT_EXCLUDED(OP_matrix_inverse) - @Namespace("sd::ops") public static class lu extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public lu(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public lu(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public lu position(long position) { - return (lu)super.position(position); - } - public lu() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * sequence_mask op. - make mask for given tensor filled by (j > x[i_1, + * i_2,...,i_n]) -> z[i_1, i_2,...,i_n,j] + * + * input params: + * 0 - the ND-tensor filled by integer-like values + * + * optional int param - maxlength (maxlength >= max(x)). By default maxlength = + * max(x). return value: (N+1)D tensor filled by 0 and 1 accordingly the mask + */ +// #if NOT_EXCLUDED(OP_sequence_mask) +@Namespace("sd::ops") public static class sequence_mask extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public sequence_mask(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public sequence_mask(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public sequence_mask position(long position) { + return (sequence_mask)super.position(position); + } - /** - * sequence_mask op. - make mask for given tensor filled by (j > x[i_1, i_2,...,i_n]) -> z[i_1, i_2,...,i_n,j] - * - * input params: - * 0 - the ND-tensor filled by integer-like values - * - * optional int param - maxlength (maxlength >= max(x)). By default maxlength = max(x). - * return value: - * (N+1)D tensor filled by 0 and 1 accordingly the mask - */ -// #if NOT_EXCLUDED(OP_sequence_mask) - @Namespace("sd::ops") public static class sequence_mask extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public sequence_mask(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public sequence_mask(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public sequence_mask position(long position) { - return (sequence_mask)super.position(position); - } - public sequence_mask() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif - /** - * segment_max op. - make a tensor filled by max values according to index tensor given. - * - * input params: - * 0 - the tensor with data; - * 1 - the tensor with indices. - * - * return value: - * tensor with max values according to indices sets. - */ +// #endif +/** + * segment_max op. - make a tensor filled by max values according to index + * tensor given. + * + * input params: + * 0 - the tensor with data; + * 1 - the tensor with indices. + * + * return value: + * tensor with max values according to indices sets. + */ + +// #if NOT_EXCLUDED(OP_segment_max) +@Namespace("sd::ops") public static class segment_max extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public segment_max(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public segment_max(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public segment_max position(long position) { + return (segment_max)super.position(position); + } -// #if NOT_EXCLUDED(OP_segment_max) - @Namespace("sd::ops") public static class segment_max extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public segment_max(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public segment_max(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public segment_max position(long position) { - return (segment_max)super.position(position); - } - public segment_max() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif -// #if NOT_EXCLUDED(OP_segment_max_bp) - @Namespace("sd::ops") public static class segment_max_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public segment_max_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public segment_max_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public segment_max_bp position(long position) { - return (segment_max_bp)super.position(position); - } - +// #endif +// #if NOT_EXCLUDED(OP_segment_max_bp) +@Namespace("sd::ops") public static class segment_max_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public segment_max_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public segment_max_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public segment_max_bp position(long position) { + return (segment_max_bp)super.position(position); + } + public segment_max_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * segment_min op. - make a tensor filled by min values according to index + * tensor given. + * + * input params: + * 0 - the tensor with data; + * 1 - the tensor with indices. + * + * return value: + * tensor with min values according to indices sets. + */ +// #if NOT_EXCLUDED(OP_segment_min) +@Namespace("sd::ops") public static class segment_min extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public segment_min(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public segment_min(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public segment_min position(long position) { + return (segment_min)super.position(position); + } - /** - * segment_min op. - make a tensor filled by min values according to index tensor given. - * - * input params: - * 0 - the tensor with data; - * 1 - the tensor with indices. - * - * return value: - * tensor with min values according to indices sets. - */ -// #if NOT_EXCLUDED(OP_segment_min) - @Namespace("sd::ops") public static class segment_min extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public segment_min(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public segment_min(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public segment_min position(long position) { - return (segment_min)super.position(position); - } - public segment_min() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif -// #if NOT_EXCLUDED(OP_segment_min_bp) - @Namespace("sd::ops") public static class segment_min_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public segment_min_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public segment_min_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public segment_min_bp position(long position) { - return (segment_min_bp)super.position(position); - } - +// #endif +// #if NOT_EXCLUDED(OP_segment_min_bp) +@Namespace("sd::ops") public static class segment_min_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public segment_min_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public segment_min_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public segment_min_bp position(long position) { + return (segment_min_bp)super.position(position); + } + public segment_min_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * segment_sum op. - make a tensor filled by sum of values according to index + * tensor given. + * + * input params: + * 0 - the tensor with data; + * 1 - the tensor with indices. + * + * return value: + * tensor with sum of values according to indices sets. + */ +// #if NOT_EXCLUDED(OP_segment_sum) +@Namespace("sd::ops") public static class segment_sum extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public segment_sum(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public segment_sum(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public segment_sum position(long position) { + return (segment_sum)super.position(position); + } - /** - * segment_sum op. - make a tensor filled by sum of values according to index tensor given. - * - * input params: - * 0 - the tensor with data; - * 1 - the tensor with indices. - * - * return value: - * tensor with sum of values according to indices sets. - */ -// #if NOT_EXCLUDED(OP_segment_sum) - @Namespace("sd::ops") public static class segment_sum extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public segment_sum(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public segment_sum(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public segment_sum position(long position) { - return (segment_sum)super.position(position); - } - public segment_sum() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif -// #if NOT_EXCLUDED(OP_segment_sum_bp) - @Namespace("sd::ops") public static class segment_sum_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public segment_sum_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public segment_sum_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public segment_sum_bp position(long position) { - return (segment_sum_bp)super.position(position); - } - +// #endif +// #if NOT_EXCLUDED(OP_segment_sum_bp) +@Namespace("sd::ops") public static class segment_sum_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public segment_sum_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public segment_sum_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public segment_sum_bp position(long position) { + return (segment_sum_bp)super.position(position); + } + public segment_sum_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * segment_prod op. - make a tensor filled by product of values according to + * index tensor given. + * + * input params: + * 0 - the tensor with data; + * 1 - the tensor with indices. + * + * return value: + * tensor with product of values according to indices sets. + */ +// #if NOT_EXCLUDED(OP_segment_prod) +@Namespace("sd::ops") public static class segment_prod extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public segment_prod(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public segment_prod(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public segment_prod position(long position) { + return (segment_prod)super.position(position); + } - /** - * segment_prod op. - make a tensor filled by product of values according to index tensor given. - * - * input params: - * 0 - the tensor with data; - * 1 - the tensor with indices. - * - * return value: - * tensor with product of values according to indices sets. - */ -// #if NOT_EXCLUDED(OP_segment_prod) - @Namespace("sd::ops") public static class segment_prod extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public segment_prod(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public segment_prod(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public segment_prod position(long position) { - return (segment_prod)super.position(position); - } - public segment_prod() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif -// #if NOT_EXCLUDED(OP_segment_prod_bp) - @Namespace("sd::ops") public static class segment_prod_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public segment_prod_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public segment_prod_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public segment_prod_bp position(long position) { - return (segment_prod_bp)super.position(position); - } - +// #endif +// #if NOT_EXCLUDED(OP_segment_prod_bp) +@Namespace("sd::ops") public static class segment_prod_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public segment_prod_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public segment_prod_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public segment_prod_bp position(long position) { + return (segment_prod_bp)super.position(position); + } + public segment_prod_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif - /** - * segment_mean op. - make a tensor filled by average of values according to index tensor given. - * - * input params: - * 0 - the tensor with data; - * 1 - the tensor with indices. - * - * return value: - * tensor with average of values according to indices sets. - */ -// #if NOT_EXCLUDED(OP_segment_mean) - @Namespace("sd::ops") public static class segment_mean extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public segment_mean(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public segment_mean(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public segment_mean position(long position) { - return (segment_mean)super.position(position); - } - +// #endif +/** + * segment_mean op. - make a tensor filled by average of values according to + * index tensor given. + * + * input params: + * 0 - the tensor with data; + * 1 - the tensor with indices. + * + * return value: + * tensor with average of values according to indices sets. + */ +// #if NOT_EXCLUDED(OP_segment_mean) +@Namespace("sd::ops") public static class segment_mean extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public segment_mean(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public segment_mean(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public segment_mean position(long position) { + return (segment_mean)super.position(position); + } + public segment_mean() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif -// #if NOT_EXCLUDED(OP_segment_mean_bp) - @Namespace("sd::ops") public static class segment_mean_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public segment_mean_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public segment_mean_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public segment_mean_bp position(long position) { - return (segment_mean_bp)super.position(position); - } - +// #endif +// #if NOT_EXCLUDED(OP_segment_mean_bp) +@Namespace("sd::ops") public static class segment_mean_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public segment_mean_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public segment_mean_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public segment_mean_bp position(long position) { + return (segment_mean_bp)super.position(position); + } + public segment_mean_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * unsorted_segment_max op. - make a tensor filled by max values according to + * index tensor given. + * + * input params: + * 0 - the tensor with data; + * 1 - the tensor with indices. + * + * return value: + * tensor with max values according to indices sets. + */ +// #if NOT_EXCLUDED(OP_unsorted_segment_max) +@Namespace("sd::ops") public static class unsorted_segment_max extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public unsorted_segment_max(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public unsorted_segment_max(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public unsorted_segment_max position(long position) { + return (unsorted_segment_max)super.position(position); + } - /** - * unsorted_segment_max op. - make a tensor filled by max values according to index tensor given. - * - * input params: - * 0 - the tensor with data; - * 1 - the tensor with indices. - * - * return value: - * tensor with max values according to indices sets. - */ -// #if NOT_EXCLUDED(OP_unsorted_segment_max) - @Namespace("sd::ops") public static class unsorted_segment_max extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public unsorted_segment_max(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public unsorted_segment_max(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public unsorted_segment_max position(long position) { - return (unsorted_segment_max)super.position(position); - } - public unsorted_segment_max() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif -// #if NOT_EXCLUDED(OP_unsorted_segment_max_bp) - @Namespace("sd::ops") public static class unsorted_segment_max_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public unsorted_segment_max_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public unsorted_segment_max_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public unsorted_segment_max_bp position(long position) { - return (unsorted_segment_max_bp)super.position(position); - } - +// #endif +// #if NOT_EXCLUDED(OP_unsorted_segment_max_bp) +@Namespace("sd::ops") public static class unsorted_segment_max_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public unsorted_segment_max_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public unsorted_segment_max_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public unsorted_segment_max_bp position(long position) { + return (unsorted_segment_max_bp)super.position(position); + } + public unsorted_segment_max_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * unsorted_segment_min op. - make a tensor filled by min values according to + * index tensor given. + * + * input params: + * 0 - the tensor with data; + * 1 - the tensor with indices. + * + * integer param: + * 0 - num of segments + * + * return value: + * tensor with min values according to indices sets. + */ +// #if NOT_EXCLUDED(OP_unsorted_segment_min_bp) +@Namespace("sd::ops") public static class unsorted_segment_min extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public unsorted_segment_min(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public unsorted_segment_min(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public unsorted_segment_min position(long position) { + return (unsorted_segment_min)super.position(position); + } - /** - * unsorted_segment_min op. - make a tensor filled by min values according to index tensor given. - * - * input params: - * 0 - the tensor with data; - * 1 - the tensor with indices. - * - * integer param: - * 0 - num of segments - * - * return value: - * tensor with min values according to indices sets. - */ -// #if NOT_EXCLUDED(OP_unsorted_segment_min_bp) - @Namespace("sd::ops") public static class unsorted_segment_min extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public unsorted_segment_min(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public unsorted_segment_min(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public unsorted_segment_min position(long position) { - return (unsorted_segment_min)super.position(position); - } - public unsorted_segment_min() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif -// #if NOT_EXCLUDED(OP_unsorted_segment_min_bp) - @Namespace("sd::ops") public static class unsorted_segment_min_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public unsorted_segment_min_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public unsorted_segment_min_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public unsorted_segment_min_bp position(long position) { - return (unsorted_segment_min_bp)super.position(position); - } - +// #endif +// #if NOT_EXCLUDED(OP_unsorted_segment_min_bp) +@Namespace("sd::ops") public static class unsorted_segment_min_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public unsorted_segment_min_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public unsorted_segment_min_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public unsorted_segment_min_bp position(long position) { + return (unsorted_segment_min_bp)super.position(position); + } + public unsorted_segment_min_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * unsorted_segment_sum op. - make a tensor filled by sum of values according to + * index tensor given. + * + * input params: + * 0 - the tensor with data; + * 1 - the tensor with indices. + * + * integer param: + * 0 - num of segments + * + * return value: + * tensor with sum of values according to indices sets. + */ +// #if NOT_EXCLUDED(OP_unsorted_segment_sum) +@Namespace("sd::ops") public static class unsorted_segment_sum extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public unsorted_segment_sum(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public unsorted_segment_sum(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public unsorted_segment_sum position(long position) { + return (unsorted_segment_sum)super.position(position); + } - /** - * unsorted_segment_sum op. - make a tensor filled by sum of values according to index tensor given. - * - * input params: - * 0 - the tensor with data; - * 1 - the tensor with indices. - * - * integer param: - * 0 - num of segments - * - * return value: - * tensor with sum of values according to indices sets. - */ -// #if NOT_EXCLUDED(OP_unsorted_segment_sum) - @Namespace("sd::ops") public static class unsorted_segment_sum extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public unsorted_segment_sum(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public unsorted_segment_sum(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public unsorted_segment_sum position(long position) { - return (unsorted_segment_sum)super.position(position); - } - public unsorted_segment_sum() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif -// #if NOT_EXCLUDED(OP_unsorted_segment_sum_bp) - @Namespace("sd::ops") public static class unsorted_segment_sum_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public unsorted_segment_sum_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public unsorted_segment_sum_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public unsorted_segment_sum_bp position(long position) { - return (unsorted_segment_sum_bp)super.position(position); - } - +// #endif +// #if NOT_EXCLUDED(OP_unsorted_segment_sum_bp) +@Namespace("sd::ops") public static class unsorted_segment_sum_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public unsorted_segment_sum_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public unsorted_segment_sum_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public unsorted_segment_sum_bp position(long position) { + return (unsorted_segment_sum_bp)super.position(position); + } + public unsorted_segment_sum_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * unsorted_segment_prod op. - make a tensor filled by product of values + * according to index tensor given. + * + * input params: + * 0 - the tensor with data; + * 1 - the tensor with indices. + * + * integer param: + * 0 - num of segments + * + * return value: + * tensor with product of values according to indices sets. + */ +// #if NOT_EXCLUDED(OP_unsorted_segment_prod) +@Namespace("sd::ops") public static class unsorted_segment_prod extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public unsorted_segment_prod(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public unsorted_segment_prod(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public unsorted_segment_prod position(long position) { + return (unsorted_segment_prod)super.position(position); + } + + public unsorted_segment_prod() { super((Pointer)null); allocate(); } + private native void allocate(); + public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); + } +// #endif +// #if NOT_EXCLUDED(OP_unsorted_segment_prod_bp) +@Namespace("sd::ops") public static class unsorted_segment_prod_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public unsorted_segment_prod_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public unsorted_segment_prod_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public unsorted_segment_prod_bp position(long position) { + return (unsorted_segment_prod_bp)super.position(position); + } - /** - * unsorted_segment_prod op. - make a tensor filled by product of values according to index tensor given. - * - * input params: - * 0 - the tensor with data; - * 1 - the tensor with indices. - * - * integer param: - * 0 - num of segments - * - * return value: - * tensor with product of values according to indices sets. - */ -// #if NOT_EXCLUDED(OP_unsorted_segment_prod) - @Namespace("sd::ops") public static class unsorted_segment_prod extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public unsorted_segment_prod(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public unsorted_segment_prod(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public unsorted_segment_prod position(long position) { - return (unsorted_segment_prod)super.position(position); - } - - public unsorted_segment_prod() { super((Pointer)null); allocate(); } - private native void allocate(); - public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); - } -// #endif -// #if NOT_EXCLUDED(OP_unsorted_segment_prod_bp) - @Namespace("sd::ops") public static class unsorted_segment_prod_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public unsorted_segment_prod_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public unsorted_segment_prod_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public unsorted_segment_prod_bp position(long position) { - return (unsorted_segment_prod_bp)super.position(position); - } - public unsorted_segment_prod_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * unsorted_segment_mean op. - make a tensor filled by average of values + * according to index tensor given. + * + * input params: + * 0 - the tensor with data; + * 1 - the tensor with indices. + * + * integer param: + * 0 - num of segments + * + * return value: + * tensor with average of values according to indices sets. + */ +// #if NOT_EXCLUDED(OP_unsorted_segment_mean) +@Namespace("sd::ops") public static class unsorted_segment_mean extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public unsorted_segment_mean(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public unsorted_segment_mean(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public unsorted_segment_mean position(long position) { + return (unsorted_segment_mean)super.position(position); + } - /** - * unsorted_segment_mean op. - make a tensor filled by average of values according to index tensor given. - * - * input params: - * 0 - the tensor with data; - * 1 - the tensor with indices. - * - * integer param: - * 0 - num of segments - * - * return value: - * tensor with average of values according to indices sets. - */ -// #if NOT_EXCLUDED(OP_unsorted_segment_mean) - @Namespace("sd::ops") public static class unsorted_segment_mean extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public unsorted_segment_mean(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public unsorted_segment_mean(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public unsorted_segment_mean position(long position) { - return (unsorted_segment_mean)super.position(position); - } - public unsorted_segment_mean() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif -// #if NOT_EXCLUDED(OP_unsorted_segment_mean_bp) - @Namespace("sd::ops") public static class unsorted_segment_mean_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public unsorted_segment_mean_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public unsorted_segment_mean_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public unsorted_segment_mean_bp position(long position) { - return (unsorted_segment_mean_bp)super.position(position); - } - +// #endif +// #if NOT_EXCLUDED(OP_unsorted_segment_mean_bp) +@Namespace("sd::ops") public static class unsorted_segment_mean_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public unsorted_segment_mean_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public unsorted_segment_mean_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public unsorted_segment_mean_bp position(long position) { + return (unsorted_segment_mean_bp)super.position(position); + } + public unsorted_segment_mean_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * unsorted_segment_sqrt_n op. - computes the sum along segments of a tensor + * divided by the sqrt(N). + * + * input params: + * 0 - the tensor with data; + * 1 - the tensor with indices. + * + * integer param: + * 0 - num of segments + * + * return value: + * tensor with average of values according to indices sets. + */ +// #if NOT_EXCLUDED(OP_unsorted_segment_sqrt) +@Namespace("sd::ops") public static class unsorted_segment_sqrt_n extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public unsorted_segment_sqrt_n(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public unsorted_segment_sqrt_n(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public unsorted_segment_sqrt_n position(long position) { + return (unsorted_segment_sqrt_n)super.position(position); + } - /** - * unsorted_segment_sqrt_n op. - computes the sum along segments of a tensor divided by the sqrt(N). - * - * input params: - * 0 - the tensor with data; - * 1 - the tensor with indices. - * - * integer param: - * 0 - num of segments - * - * return value: - * tensor with average of values according to indices sets. - */ -// #if NOT_EXCLUDED(OP_unsorted_segment_sqrt) - @Namespace("sd::ops") public static class unsorted_segment_sqrt_n extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public unsorted_segment_sqrt_n(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public unsorted_segment_sqrt_n(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public unsorted_segment_sqrt_n position(long position) { - return (unsorted_segment_sqrt_n)super.position(position); - } - public unsorted_segment_sqrt_n() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif -// #if NOT_EXCLUDED(OP_unsorted_segment_sqrt_n_bp) - @Namespace("sd::ops") public static class unsorted_segment_sqrt_n_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public unsorted_segment_sqrt_n_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public unsorted_segment_sqrt_n_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public unsorted_segment_sqrt_n_bp position(long position) { - return (unsorted_segment_sqrt_n_bp)super.position(position); - } - +// #endif +// #if NOT_EXCLUDED(OP_unsorted_segment_sqrt_n_bp) +@Namespace("sd::ops") public static class unsorted_segment_sqrt_n_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public unsorted_segment_sqrt_n_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public unsorted_segment_sqrt_n_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public unsorted_segment_sqrt_n_bp position(long position) { + return (unsorted_segment_sqrt_n_bp)super.position(position); + } + public unsorted_segment_sqrt_n_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * extract_image_patches op - Extract patches from images and put them in the + * "depth" output dimension. + * + * input params: + * 0 - images tensor (4D) + * + * int params: + * 0 - ksize_rows + * 1 - ksize_cols + * 2 - strides_rows + * 3 - strides_cols + * 4 - rates_rows + * 5 - rates_cols + * 6 - padding_type - 0 - equiv 'VALID', 1 - 'SAME' + */ +// #if NOT_EXCLUDED(OP_extract_image_patches) +@Namespace("sd::ops") public static class extract_image_patches extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public extract_image_patches(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public extract_image_patches(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public extract_image_patches position(long position) { + return (extract_image_patches)super.position(position); + } - /** - * extract_image_patches op - Extract patches from images and put them in the "depth" output dimension. - * - * input params: - * 0 - images tensor (4D) - * - * int params: - * 0 - ksize_rows - * 1 - ksize_cols - * 2 - strides_rows - * 3 - strides_cols - * 4 - rates_rows - * 5 - rates_cols - * 6 - padding_type - 0 - equiv 'VALID', 1 - 'SAME' - */ -// #if NOT_EXCLUDED(OP_extract_image_patches) - @Namespace("sd::ops") public static class extract_image_patches extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public extract_image_patches(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public extract_image_patches(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public extract_image_patches position(long position) { - return (extract_image_patches)super.position(position); - } - public extract_image_patches() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * draw_bounding_boxes op - modified input image with given colors exept given + * boxes. + * + * input params: + * 0 - images tensor (4D) with shape {batch, width, height, channels}, where + * channes is 1 (BW image), 3 (RGB) or 4 (RGBA) 1 - boxes tensor (3D) with shape + * {batch, number_of_boxes, 4} where last dimension encoded as (y_min, x_min, + * y_max, x_max), all values in between 0. and 1. 2 - colours tensor (2D) with + * shape {number_of_boxes, channels} -- bordering color set (palette) + * + * output: + * 0 - 4D tensor with same shape as images (input 0) + */ +// #if NOT_EXCLUDED(OP_draw_bounding_boxes) +@Namespace("sd::ops") public static class draw_bounding_boxes extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public draw_bounding_boxes(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public draw_bounding_boxes(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public draw_bounding_boxes position(long position) { + return (draw_bounding_boxes)super.position(position); + } - /** - * draw_bounding_boxes op - modified input image with given colors exept given boxes. - * - * input params: - * 0 - images tensor (4D) with shape {batch, width, height, channels}, where channes is 1 (BW image), - * 3 (RGB) or 4 (RGBA) - * 1 - boxes tensor (3D) with shape {batch, number_of_boxes, 4} where last dimension encoded as - * (y_min, x_min, y_max, x_max), all values in between 0. and 1. - * 2 - colours tensor (2D) with shape {number_of_boxes, channels} -- bordering color set (palette) - * - * output: - * 0 - 4D tensor with same shape as images (input 0) - */ -// #if NOT_EXCLUDED(OP_draw_bounding_boxes) - @Namespace("sd::ops") public static class draw_bounding_boxes extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public draw_bounding_boxes(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public draw_bounding_boxes(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public draw_bounding_boxes position(long position) { - return (draw_bounding_boxes)super.position(position); - } - public draw_bounding_boxes() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * roll - op porting from numpy + * (https://docs.scipy.org/doc/numpy-1.14.0/reference/generated/numpy.roll.html) + * + * input params: + * 0 - NDArray + * + * int params: + * 0 - shift + * 1 - axe 1 + * 2 - axe 2 + * ... + * N - axe N + * + * All axes are optional and should be between 0 and input->rankOf(). Of + * course, all axes can be repeated. + * + * output: + * 0 - NDArray with the same shape as input. + */ +// #if NOT_EXCLUDED(OP_roll) +@Namespace("sd::ops") public static class roll extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public roll(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public roll(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public roll position(long position) { + return (roll)super.position(position); + } - /** - * roll - op porting from numpy (https://docs.scipy.org/doc/numpy-1.14.0/reference/generated/numpy.roll.html) - * - * input params: - * 0 - NDArray - * - * int params: - * 0 - shift - * 1 - axe 1 - * 2 - axe 2 - * ... - * N - axe N - * - * All axes are optional and should be between 0 and input->rankOf(). Of course, all axes can be repeated. - * - * output: - * 0 - NDArray with the same shape as input. - */ -// #if NOT_EXCLUDED(OP_roll) - @Namespace("sd::ops") public static class roll extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public roll(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public roll(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public roll position(long position) { - return (roll)super.position(position); - } - public roll() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * lin_space - op porting from TF + * (https://www.tensorflow.org/api_docs/python/tf/lin_space) + * + * optional input params: + * 0 - startVal - NDArray scalar (float point) + * 1 - finishVal - NDArray scalar (float point) + * 2 - numOfElements - NDArray scalar (integer) + * Optional: + * T args + * 0 - startVal + * 1 - finishVal] + * 2 - numOfElements + * output: + * 0 - 1D NDArray with the same type as input and length as given with + * numOfElements param. + */ +// #if NOT_EXCLUDED(OP_lin_space) +@Namespace("sd::ops") public static class lin_space extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public lin_space(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public lin_space(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public lin_space position(long position) { + return (lin_space)super.position(position); + } - /** - * lin_space - op porting from TF (https://www.tensorflow.org/api_docs/python/tf/lin_space) - * - * optional input params: - * 0 - startVal - NDArray scalar (float point) - * 1 - finishVal - NDArray scalar (float point) - * 2 - numOfElements - NDArray scalar (integer) - * Optional: - * T args - * 0 - startVal - * 1 - finishVal] - * 2 - numOfElements - * output: - * 0 - 1D NDArray with the same type as input and length as given with numOfElements param. - */ -// #if NOT_EXCLUDED(OP_lin_space) - @Namespace("sd::ops") public static class lin_space extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public lin_space(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public lin_space(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public lin_space position(long position) { - return (lin_space)super.position(position); - } - public lin_space() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * reduction_sum - tf.reduction_sum operation + * + * input params: + * 0 - NDArray + * + * T_ARG param (optional): + * 0 - keep_dims != 0. + * + * int params (optional): + * 0 - axe 1 + * 1 - axe 2 + * ... + * N-1 axe N + * + * All axes are optional and should be between 0 and input->rankOf() - 1 + * + * output: + * 0 - NDArray with reduces shape accordingly to axes (the scalar in default + * case). + */ +// #if NOT_EXCLUDED(OP_reduce_sum) +@Namespace("sd::ops") public static class reduce_sum extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public reduce_sum(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public reduce_sum(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public reduce_sum position(long position) { + return (reduce_sum)super.position(position); + } - /** - * reduction_sum - tf.reduction_sum operation - * - * input params: - * 0 - NDArray - * - * T_ARG param (optional): - * 0 - keep_dims != 0. - * - * int params (optional): - * 0 - axe 1 - * 1 - axe 2 - * ... - * N-1 axe N - * - * All axes are optional and should be between 0 and input->rankOf() - 1 - * - * output: - * 0 - NDArray with reduces shape accordingly to axes (the scalar in default case). - */ -// #if NOT_EXCLUDED(OP_reduce_sum) - @Namespace("sd::ops") public static class reduce_sum extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public reduce_sum(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public reduce_sum(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public reduce_sum position(long position) { - return (reduce_sum)super.position(position); - } - public reduce_sum() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_reduce_sum_bp) +@Namespace("sd::ops") public static class reduce_sum_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public reduce_sum_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public reduce_sum_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public reduce_sum_bp position(long position) { + return (reduce_sum_bp)super.position(position); + } -// #if NOT_EXCLUDED(OP_reduce_sum_bp) - @Namespace("sd::ops") public static class reduce_sum_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public reduce_sum_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public reduce_sum_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public reduce_sum_bp position(long position) { - return (reduce_sum_bp)super.position(position); - } - public reduce_sum_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * reduction_prod - tf.reduction_prod operation + * + * input params: + * 0 - NDArray + * + * T_ARG param (optional): + * 0 - keep_dims != 0. + * + * int params (optional): + * 0 - axe 1 + * 1 - axe 2 + * ... + * N-1 axe N + * + * All axes are optional and should be between 0 and input->rankOf() - 1 + * + * output: + * 0 - NDArray with reduces shape accordingly to axes (the scalar in default + * case). + */ +// #if NOT_EXCLUDED(OP_reduce_prod) +@Namespace("sd::ops") public static class reduce_prod extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public reduce_prod(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public reduce_prod(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public reduce_prod position(long position) { + return (reduce_prod)super.position(position); + } - /** - * reduction_prod - tf.reduction_prod operation - * - * input params: - * 0 - NDArray - * - * T_ARG param (optional): - * 0 - keep_dims != 0. - * - * int params (optional): - * 0 - axe 1 - * 1 - axe 2 - * ... - * N-1 axe N - * - * All axes are optional and should be between 0 and input->rankOf() - 1 - * - * output: - * 0 - NDArray with reduces shape accordingly to axes (the scalar in default case). - */ -// #if NOT_EXCLUDED(OP_reduce_prod) - @Namespace("sd::ops") public static class reduce_prod extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public reduce_prod(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public reduce_prod(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public reduce_prod position(long position) { - return (reduce_prod)super.position(position); - } - public reduce_prod() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_reduce_prod_bp) +@Namespace("sd::ops") public static class reduce_prod_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public reduce_prod_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public reduce_prod_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public reduce_prod_bp position(long position) { + return (reduce_prod_bp)super.position(position); + } -// #if NOT_EXCLUDED(OP_reduce_prod_bp) - @Namespace("sd::ops") public static class reduce_prod_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public reduce_prod_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public reduce_prod_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public reduce_prod_bp position(long position) { - return (reduce_prod_bp)super.position(position); - } - public reduce_prod_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This op calculates min of elements along given dimensions + * + * input array: + * x: tensor to calculate mins for + * + * float arguments: + * keepDims: if non zero, then keep reduced dimensions with length = 1, + * default value is zero + * + * int arguments: + * list of integers - dimensions to calculate min along, default corresponds + * to empty list in which case calculation is performed for all dimensions and + * scalar is returned + * + * output array: + * reduced tensor with calculated mins + */ +// #if NOT_EXCLUDED(OP_reduce_min) +@Namespace("sd::ops") public static class reduce_min extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public reduce_min(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public reduce_min(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public reduce_min position(long position) { + return (reduce_min)super.position(position); + } - /** - * This op calculates min of elements along given dimensions - * - * input array: - * x: tensor to calculate mins for - * - * float arguments: - * keepDims: if non zero, then keep reduced dimensions with length = 1, default value is zero - * - * int arguments: - * list of integers - dimensions to calculate min along, default corresponds to empty list in which case calculation is performed for all dimensions and scalar is returned - * - * output array: - * reduced tensor with calculated mins - */ -// #if NOT_EXCLUDED(OP_reduce_min) - @Namespace("sd::ops") public static class reduce_min extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public reduce_min(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public reduce_min(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public reduce_min position(long position) { - return (reduce_min)super.position(position); - } - public reduce_min() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif -// #if NOT_EXCLUDED(OP_reduce_min_bp) - @Namespace("sd::ops") public static class reduce_min_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public reduce_min_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public reduce_min_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public reduce_min_bp position(long position) { - return (reduce_min_bp)super.position(position); - } - +// #endif +// #if NOT_EXCLUDED(OP_reduce_min_bp) +@Namespace("sd::ops") public static class reduce_min_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public reduce_min_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public reduce_min_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public reduce_min_bp position(long position) { + return (reduce_min_bp)super.position(position); + } + public reduce_min_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This op calculates max of elements along given dimensions + * + * input array: + * x: tensor to calculate maxes for + * + * float arguments: + * keepDims: if non zero, then keep reduced dimensions with length = 1, + * default value is zero + * + * int arguments: + * list of integers - dimensions to calculate max along, default corresponds + * to empty list in which case calculation is performed for all dimensions and + * scalar is returned + * + * output array: + * reduced tensor with calculated maxes + */ +// #if NOT_EXCLUDED(OP_reduce_max) +@Namespace("sd::ops") public static class reduce_max extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public reduce_max(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public reduce_max(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public reduce_max position(long position) { + return (reduce_max)super.position(position); + } - /** - * This op calculates max of elements along given dimensions - * - * input array: - * x: tensor to calculate maxes for - * - * float arguments: - * keepDims: if non zero, then keep reduced dimensions with length = 1, default value is zero - * - * int arguments: - * list of integers - dimensions to calculate max along, default corresponds to empty list in which case calculation is performed for all dimensions and scalar is returned - * - * output array: - * reduced tensor with calculated maxes - */ -// #if NOT_EXCLUDED(OP_reduce_max) - @Namespace("sd::ops") public static class reduce_max extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public reduce_max(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public reduce_max(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public reduce_max position(long position) { - return (reduce_max)super.position(position); - } - public reduce_max() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif -// #if NOT_EXCLUDED(OP_reduce_max_bp) - @Namespace("sd::ops") public static class reduce_max_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public reduce_max_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public reduce_max_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public reduce_max_bp position(long position) { - return (reduce_max_bp)super.position(position); - } - +// #endif +// #if NOT_EXCLUDED(OP_reduce_max_bp) +@Namespace("sd::ops") public static class reduce_max_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public reduce_max_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public reduce_max_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public reduce_max_bp position(long position) { + return (reduce_max_bp)super.position(position); + } + public reduce_max_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This op calculates norm1 of elements along given dimensions + * + * input array: + * x: tensor to calculate norm1 for + * + * float arguments: + * keepDims: if non zero, then keep reduced dimensions with length = 1, + * default value is zero + * + * int arguments: + * list of integers - dimensions to calculate norm1 along, default + * corresponds to empty list in which case calculation is performed for all + * dimensions and scalar is returned + * + * output array: + * reduced tensor with calculated norm1 + */ +// #if NOT_EXCLUDED(OP_reduce_norm1) +@Namespace("sd::ops") public static class reduce_norm1 extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public reduce_norm1(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public reduce_norm1(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public reduce_norm1 position(long position) { + return (reduce_norm1)super.position(position); + } - /** - * This op calculates norm1 of elements along given dimensions - * - * input array: - * x: tensor to calculate norm1 for - * - * float arguments: - * keepDims: if non zero, then keep reduced dimensions with length = 1, default value is zero - * - * int arguments: - * list of integers - dimensions to calculate norm1 along, default corresponds to empty list in which case calculation is performed for all dimensions and scalar is returned - * - * output array: - * reduced tensor with calculated norm1 - */ -// #if NOT_EXCLUDED(OP_reduce_norm1) - @Namespace("sd::ops") public static class reduce_norm1 extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public reduce_norm1(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public reduce_norm1(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public reduce_norm1 position(long position) { - return (reduce_norm1)super.position(position); - } - public reduce_norm1() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif -// #if NOT_EXCLUDED(OP_reduce_norm1_bp) - @Namespace("sd::ops") public static class reduce_norm1_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public reduce_norm1_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public reduce_norm1_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public reduce_norm1_bp position(long position) { - return (reduce_norm1_bp)super.position(position); - } - +// #endif +// #if NOT_EXCLUDED(OP_reduce_norm1_bp) +@Namespace("sd::ops") public static class reduce_norm1_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public reduce_norm1_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public reduce_norm1_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public reduce_norm1_bp position(long position) { + return (reduce_norm1_bp)super.position(position); + } + public reduce_norm1_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This op calculates norm2 of elements along given dimensions + * + * input array: + * x: tensor to calculate norm2 for + * + * float arguments: + * keepDims: if non zero, then keep reduced dimensions with length = 1, + * default value is zero + * + * int arguments: + * list of integers - dimensions to calculate norm2 along, default + * corresponds to empty list in which case calculation is performed for all + * dimensions and scalar is returned + * + * output array: + * reduced tensor with calculated norm2 + */ +// #if NOT_EXCLUDED(OP_reduce_norm2) +@Namespace("sd::ops") public static class reduce_norm2 extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public reduce_norm2(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public reduce_norm2(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public reduce_norm2 position(long position) { + return (reduce_norm2)super.position(position); + } - /** - * This op calculates norm2 of elements along given dimensions - * - * input array: - * x: tensor to calculate norm2 for - * - * float arguments: - * keepDims: if non zero, then keep reduced dimensions with length = 1, default value is zero - * - * int arguments: - * list of integers - dimensions to calculate norm2 along, default corresponds to empty list in which case calculation is performed for all dimensions and scalar is returned - * - * output array: - * reduced tensor with calculated norm2 - */ -// #if NOT_EXCLUDED(OP_reduce_norm2) - @Namespace("sd::ops") public static class reduce_norm2 extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public reduce_norm2(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public reduce_norm2(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public reduce_norm2 position(long position) { - return (reduce_norm2)super.position(position); - } - public reduce_norm2() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif -// #if NOT_EXCLUDED(OP_reduce_norm2_bp) - @Namespace("sd::ops") public static class reduce_norm2_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public reduce_norm2_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public reduce_norm2_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public reduce_norm2_bp position(long position) { - return (reduce_norm2_bp)super.position(position); - } - +// #endif +// #if NOT_EXCLUDED(OP_reduce_norm2_bp) +@Namespace("sd::ops") public static class reduce_norm2_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public reduce_norm2_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public reduce_norm2_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public reduce_norm2_bp position(long position) { + return (reduce_norm2_bp)super.position(position); + } + public reduce_norm2_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif +/** + * This op calculates squared norm of elements along given dimensions + * + * input array: + * x: tensor to calculate squared norm for + * + * float arguments: + * keepDims: if non zero, then keep reduced dimensions with length = 1, + * default value is zero + * + * int arguments: + * list of integers - dimensions to calculate squared norm along, default + * corresponds to empty list in which case calculation is performed for all + * dimensions and scalar is returned + * + * output array: + * reduced tensor with calculated norm + */ +// #if NOT_EXCLUDED(OP_reduce_sqnorm) +@Namespace("sd::ops") public static class reduce_sqnorm extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public reduce_sqnorm(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public reduce_sqnorm(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public reduce_sqnorm position(long position) { + return (reduce_sqnorm)super.position(position); + } - /** - * This op calculates squared norm of elements along given dimensions - * - * input array: - * x: tensor to calculate squared norm for - * - * float arguments: - * keepDims: if non zero, then keep reduced dimensions with length = 1, default value is zero - * - * int arguments: - * list of integers - dimensions to calculate squared norm along, default corresponds to empty list in which case calculation is performed for all dimensions and scalar is returned - * - * output array: - * reduced tensor with calculated norm - */ -// #if NOT_EXCLUDED(OP_reduce_sqnorm) - @Namespace("sd::ops") public static class reduce_sqnorm extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public reduce_sqnorm(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public reduce_sqnorm(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public reduce_sqnorm position(long position) { - return (reduce_sqnorm)super.position(position); - } - public reduce_sqnorm() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif -// #if NOT_EXCLUDED(OP_reduce_sqnorm_bp) - @Namespace("sd::ops") public static class reduce_sqnorm_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public reduce_sqnorm_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public reduce_sqnorm_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public reduce_sqnorm_bp position(long position) { - return (reduce_sqnorm_bp)super.position(position); - } - +// #endif +// #if NOT_EXCLUDED(OP_reduce_sqnorm_bp) +@Namespace("sd::ops") public static class reduce_sqnorm_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public reduce_sqnorm_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public reduce_sqnorm_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public reduce_sqnorm_bp position(long position) { + return (reduce_sqnorm_bp)super.position(position); + } + public reduce_sqnorm_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This op calculates norm max of elements along given dimensions + * + * input array: + * x: tensor to calculate norm max for + * + * float arguments: + * keepDims: if non zero, then keep reduced dimensions with length = 1, + * default value is zero + * + * int arguments: + * list of integers - dimensions to calculate norm max along, default + * corresponds to empty list in which case calculation is performed for all + * dimensions and scalar is returned + * + * output array: + * reduced tensor with calculated norm + */ +// #if NOT_EXCLUDED(OP_reduce_norm_max) +@Namespace("sd::ops") public static class reduce_norm_max extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public reduce_norm_max(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public reduce_norm_max(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public reduce_norm_max position(long position) { + return (reduce_norm_max)super.position(position); + } - /** - * This op calculates norm max of elements along given dimensions - * - * input array: - * x: tensor to calculate norm max for - * - * float arguments: - * keepDims: if non zero, then keep reduced dimensions with length = 1, default value is zero - * - * int arguments: - * list of integers - dimensions to calculate norm max along, default corresponds to empty list in which case calculation is performed for all dimensions and scalar is returned - * - * output array: - * reduced tensor with calculated norm - */ -// #if NOT_EXCLUDED(OP_reduce_norm_max) - @Namespace("sd::ops") public static class reduce_norm_max extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public reduce_norm_max(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public reduce_norm_max(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public reduce_norm_max position(long position) { - return (reduce_norm_max)super.position(position); - } - public reduce_norm_max() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif -// #if NOT_EXCLUDED(OP_reduce_norm_max_bp) - @Namespace("sd::ops") public static class reduce_norm_max_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public reduce_norm_max_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public reduce_norm_max_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public reduce_norm_max_bp position(long position) { - return (reduce_norm_max_bp)super.position(position); - } - +// #endif +// #if NOT_EXCLUDED(OP_reduce_norm_max_bp) +@Namespace("sd::ops") public static class reduce_norm_max_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public reduce_norm_max_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public reduce_norm_max_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public reduce_norm_max_bp position(long position) { + return (reduce_norm_max_bp)super.position(position); + } + public reduce_norm_max_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This op calculates mean of elements along given dimensions + * + * input array: + * x: tensor to calculate mean for + * + * float arguments: + * keepDims: if non zero, then keep reduced dimensions with length = 1, + * default value is zero + * + * int arguments: + * list of integers - dimensions to calculate mean along, default corresponds + * to empty list in which case calculation is performed for all dimensions and + * scalar is returned + * + * output array: + * reduced tensor with calculated means + */ +// #if NOT_EXCLUDED(OP_reduce_mean) +@Namespace("sd::ops") public static class reduce_mean extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public reduce_mean(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public reduce_mean(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public reduce_mean position(long position) { + return (reduce_mean)super.position(position); + } - /** - * This op calculates mean of elements along given dimensions - * - * input array: - * x: tensor to calculate mean for - * - * float arguments: - * keepDims: if non zero, then keep reduced dimensions with length = 1, default value is zero - * - * int arguments: - * list of integers - dimensions to calculate mean along, default corresponds to empty list in which case calculation is performed for all dimensions and scalar is returned - * - * output array: - * reduced tensor with calculated means - */ -// #if NOT_EXCLUDED(OP_reduce_mean) - @Namespace("sd::ops") public static class reduce_mean extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public reduce_mean(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public reduce_mean(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public reduce_mean position(long position) { - return (reduce_mean)super.position(position); - } - public reduce_mean() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_reduce_mean_bp) +@Namespace("sd::ops") public static class reduce_mean_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public reduce_mean_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public reduce_mean_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public reduce_mean_bp position(long position) { + return (reduce_mean_bp)super.position(position); + } -// #if NOT_EXCLUDED(OP_reduce_mean_bp) - @Namespace("sd::ops") public static class reduce_mean_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public reduce_mean_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public reduce_mean_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public reduce_mean_bp position(long position) { - return (reduce_mean_bp)super.position(position); - } - public reduce_mean_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - // #endif - /** - * This op calculates sample variance of elements along given dimensions - * - * input array: - * x: tensor to calculate mean for - * - * float arguments: - * keepDims: if non zero, then keep reduced dimensions with length = 1, default value is zero - * biasCorrected - if non zero, then bias correction will be applied, default value is zero - * - * int arguments: - * list of integers - dimensions to calculate mean along, default corresponds to empty list in which case calculation is performed for all dimensions and scalar is returned - * - * output array: - * reduced tensor with calculated means - */ - @Namespace("sd::ops") public static class reduce_variance extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public reduce_variance(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public reduce_variance(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public reduce_variance position(long position) { - return (reduce_variance)super.position(position); - } - + // #endif +/** + * This op calculates sample variance of elements along given dimensions + * + * input array: + * x: tensor to calculate mean for + * + * float arguments: + * keepDims: if non zero, then keep reduced dimensions with length = 1, + * default value is zero biasCorrected - if non zero, then bias correction will + * be applied, default value is zero + * + * int arguments: + * list of integers - dimensions to calculate mean along, default corresponds + * to empty list in which case calculation is performed for all dimensions and + * scalar is returned + * + * output array: + * reduced tensor with calculated means + */ +@Namespace("sd::ops") public static class reduce_variance extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public reduce_variance(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public reduce_variance(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public reduce_variance position(long position) { + return (reduce_variance)super.position(position); + } + public reduce_variance() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class reduce_variance_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public reduce_variance_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public reduce_variance_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public reduce_variance_bp position(long position) { - return (reduce_variance_bp)super.position(position); - } - +@Namespace("sd::ops") public static class reduce_variance_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public reduce_variance_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public reduce_variance_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public reduce_variance_bp position(long position) { + return (reduce_variance_bp)super.position(position); + } + public reduce_variance_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } /** - * This op calculates sample standard deviation of elements along given dimensions - * - * input array: - * x: tensor to calculate mean for - * - * float arguments: - * keepDims: if non zero, then keep reduced dimensions with length = 1, default value is zero - * biasCorrected - if non zero, then bias correction will be applied, default value is zero - * - * int arguments: - * list of integers - dimensions to calculate mean along, default corresponds to empty list in which case calculation is performed for all dimensions and scalar is returned - * - * output array: - * reduced tensor with calculated means - */ - @Namespace("sd::ops") public static class reduce_stdev extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public reduce_stdev(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public reduce_stdev(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public reduce_stdev position(long position) { - return (reduce_stdev)super.position(position); - } - + * This op calculates sample standard deviation of elements along given + * dimensions + * + * input array: + * x: tensor to calculate mean for + * + * float arguments: + * keepDims: if non zero, then keep reduced dimensions with length = 1, + * default value is zero biasCorrected - if non zero, then bias correction will + * be applied, default value is zero + * + * int arguments: + * list of integers - dimensions to calculate mean along, default corresponds + * to empty list in which case calculation is performed for all dimensions and + * scalar is returned + * + * output array: + * reduced tensor with calculated means + */ +@Namespace("sd::ops") public static class reduce_stdev extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public reduce_stdev(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public reduce_stdev(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public reduce_stdev position(long position) { + return (reduce_stdev)super.position(position); + } + public reduce_stdev() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class reduce_stdev_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public reduce_stdev_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public reduce_stdev_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public reduce_stdev_bp position(long position) { - return (reduce_stdev_bp)super.position(position); - } - +@Namespace("sd::ops") public static class reduce_stdev_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public reduce_stdev_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public reduce_stdev_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public reduce_stdev_bp position(long position) { + return (reduce_stdev_bp)super.position(position); + } + public reduce_stdev_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } /** - * This op calculates backprop dot for two tensors along given dimensions - * - * input array: - * x: tensor to calculate dot for - * y: tensor to calculate dot for - * z: tensor with gradient output of the FF dot for x and y - * - * int arguments: - * list of integers - dimensions to calculate dot along, - * default corresponds to empty list in which case calculation - * is performed for all dimensions and scalar is returned. - * - * output array: - * the tensor with calculated backproped dots - * - */ + * This op calculates backprop dot for two tensors along given dimensions + * + * input array: + * x: tensor to calculate dot for + * y: tensor to calculate dot for + * z: tensor with gradient output of the FF dot for x and y + * + * int arguments: + * list of integers - dimensions to calculate dot along, + * default corresponds to empty list in which case calculation + * is performed for all dimensions and scalar is returned. + * + * output array: + * the tensor with calculated backproped dots + * + */ + +// #if NOT_EXCLUDED(OP_reduce_dot_bp) +@Namespace("sd::ops") public static class reduce_dot_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public reduce_dot_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public reduce_dot_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public reduce_dot_bp position(long position) { + return (reduce_dot_bp)super.position(position); + } -// #if NOT_EXCLUDED(OP_reduce_dot_bp) - @Namespace("sd::ops") public static class reduce_dot_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public reduce_dot_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public reduce_dot_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public reduce_dot_bp position(long position) { - return (reduce_dot_bp)super.position(position); - } - public reduce_dot_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif - /** - * reduce_logsumexp - tf.reduce_logsumexe operation - * - * input params: - * 0 - NDArray (input) - * 1 - 1D NDArray (axis) (optional) - integer array - * - * T_ARG param (optional): - * 0 - keep_dims != 0. - * - * int params (optional): - * 0 - axe 1 - * 1 - axe 2 - * ... - * N-1 axe N - * - * CAUTION: All axes are optional and should be between 0 and input->rankOf() - 1 - * and put either with second param or as integers but not both - * - * output: - * 0 - NDArray with reduces shape accordingly to axes (the scalar in default case). - */ -// #if NOT_EXCLUDED(OP_reduce_logsumexp) - @Namespace("sd::ops") public static class reduce_logsumexp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public reduce_logsumexp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public reduce_logsumexp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public reduce_logsumexp position(long position) { - return (reduce_logsumexp)super.position(position); - } - +// #endif +/** + * reduce_logsumexp - tf.reduce_logsumexe operation + * + * input params: + * 0 - NDArray (input) + * 1 - 1D NDArray (axis) (optional) - integer array + * + * T_ARG param (optional): + * 0 - keep_dims != 0. + * + * int params (optional): + * 0 - axe 1 + * 1 - axe 2 + * ... + * N-1 axe N + * + * CAUTION: All axes are optional and should be between 0 and input->rankOf() - + * 1 and put either with second param or as integers but not both + * + * output: + * 0 - NDArray with reduces shape accordingly to axes (the scalar in default + * case). + */ +// #if NOT_EXCLUDED(OP_reduce_logsumexp) +@Namespace("sd::ops") public static class reduce_logsumexp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public reduce_logsumexp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public reduce_logsumexp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public reduce_logsumexp position(long position) { + return (reduce_logsumexp)super.position(position); + } + public reduce_logsumexp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif /** * Copy a tensor setting everything outside a central band in each innermost matrix @@ -21385,294 +23806,297 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint * */ -// #if NOT_EXCLUDED(OP_matrix_band_part) - @Namespace("sd::ops") public static class matrix_band_part extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public matrix_band_part(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public matrix_band_part(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public matrix_band_part position(long position) { - return (matrix_band_part)super.position(position); - } - +// #if NOT_EXCLUDED(OP_matrix_band_part) +@Namespace("sd::ops") public static class matrix_band_part extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public matrix_band_part(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public matrix_band_part(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public matrix_band_part position(long position) { + return (matrix_band_part)super.position(position); + } + public matrix_band_part() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif +// #if NOT_EXCLUDED(OP_Assert) +@Namespace("sd::ops") public static class Assert extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public Assert(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public Assert(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public Assert position(long position) { + return (Assert)super.position(position); + } -// #if NOT_EXCLUDED(OP_Assert) - @Namespace("sd::ops") public static class Assert extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public Assert(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public Assert(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public Assert position(long position) { - return (Assert)super.position(position); - } - public Assert() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * image.non_max_suppression ops. + * input: + * 0 - boxes - 2D-tensor with shape (num_boxes, 4) by float type + * 1 - scales - 1D-tensor with shape (num_boxes) by float type + * 2 - output_size - 0D-tensor by int type (optional) + * float args: + * 0 - overlap_threshold - threshold value for overlap checks (optional, by + * default 0.5) 1 - score_threshold - the threshold for deciding when to remove + * boxes based on score (optional, by default -inf) int args: 0 - output_size - + * as arg 2 used for same target. Eigher this or arg 2 should be provided. + * + * output: + * - vector with size M, where M <= output_size by int type + * + * */ +// #if NOT_EXCLUDED(OP_image_non_max_suppression) +@Namespace("sd::ops") public static class non_max_suppression extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public non_max_suppression(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public non_max_suppression(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public non_max_suppression position(long position) { + return (non_max_suppression)super.position(position); + } - /** - * image.non_max_suppression ops. - * input: - * 0 - boxes - 2D-tensor with shape (num_boxes, 4) by float type - * 1 - scales - 1D-tensor with shape (num_boxes) by float type - * 2 - output_size - 0D-tensor by int type (optional) - * float args: - * 0 - overlap_threshold - threshold value for overlap checks (optional, by default 0.5) - * 1 - score_threshold - the threshold for deciding when to remove boxes based on score (optional, by default -inf) - * int args: - * 0 - output_size - as arg 2 used for same target. Eigher this or arg 2 should be provided. - * - * output: - * - vector with size M, where M <= output_size by int type - * - * */ -// #if NOT_EXCLUDED(OP_image_non_max_suppression) - @Namespace("sd::ops") public static class non_max_suppression extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public non_max_suppression(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public non_max_suppression(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public non_max_suppression position(long position) { - return (non_max_suppression)super.position(position); - } - public non_max_suppression() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif -// #if NOT_EXCLUDED(OP_image_non_max_suppression_v3) - @Namespace("sd::ops") public static class non_max_suppression_v3 extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public non_max_suppression_v3(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public non_max_suppression_v3(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public non_max_suppression_v3 position(long position) { - return (non_max_suppression_v3)super.position(position); - } - +// #endif +// #if NOT_EXCLUDED(OP_image_non_max_suppression_v3) +@Namespace("sd::ops") public static class non_max_suppression_v3 extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public non_max_suppression_v3(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public non_max_suppression_v3(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public non_max_suppression_v3 position(long position) { + return (non_max_suppression_v3)super.position(position); + } + public non_max_suppression_v3() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/* + * image.non_max_suppression_overlaps op. + * input: + * 0 - boxes - 2D-tensor with shape (num_boxes, 4) by float type + * 1 - scales - 1D-tensor with shape (num_boxes) by float type + * 2 - output_size - 0D-tensor by int type (optional) + * float args: + * 0 - overlap_threshold - threshold value for overlap checks (optional, by + * default 0.5) 1 - score_threshold - the threshold for deciding when to remove + * boxes based on score (optional, by default -inf) int args: 0 - output_size - + * as arg 2 used for same target. Eigher this or arg 2 should be provided. + * + * output: + * 0 - 1D integer tensor with shape [M], epresenting the selected indices + * from the overlaps tensor, where M <= max_output_size + * */ +// #if NOT_EXCLUDED(OP_image_non_max_suppression_overlaps) +@Namespace("sd::ops") public static class non_max_suppression_overlaps extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public non_max_suppression_overlaps(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public non_max_suppression_overlaps(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public non_max_suppression_overlaps position(long position) { + return (non_max_suppression_overlaps)super.position(position); + } - /* - * image.non_max_suppression_overlaps op. - * input: - * 0 - boxes - 2D-tensor with shape (num_boxes, 4) by float type - * 1 - scales - 1D-tensor with shape (num_boxes) by float type - * 2 - output_size - 0D-tensor by int type (optional) - * float args: - * 0 - overlap_threshold - threshold value for overlap checks (optional, by default 0.5) - * 1 - score_threshold - the threshold for deciding when to remove boxes based on score (optional, by default -inf) - * int args: - * 0 - output_size - as arg 2 used for same target. Eigher this or arg 2 should be provided. - * - * output: - * 0 - 1D integer tensor with shape [M], epresenting the selected indices from the overlaps tensor, where M <= max_output_size - * */ -// #if NOT_EXCLUDED(OP_image_non_max_suppression_overlaps) - @Namespace("sd::ops") public static class non_max_suppression_overlaps extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public non_max_suppression_overlaps(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public non_max_suppression_overlaps(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public non_max_suppression_overlaps position(long position) { - return (non_max_suppression_overlaps)super.position(position); - } - public non_max_suppression_overlaps() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/* + * cholesky op - decomposite positive square symetric matrix (or matricies when + * rank > 2). input: 0 - matricies - tensor with shape (..., N, N) by float type + * + * output - lower triangular matrix (matricies when rank > 2) with the same + * shape as input. + * */ +// #if NOT_EXCLUDED(OP_cholesky) +@Namespace("sd::ops") public static class cholesky extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public cholesky(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public cholesky(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public cholesky position(long position) { + return (cholesky)super.position(position); + } - /* - * cholesky op - decomposite positive square symetric matrix (or matricies when rank > 2). - * input: - * 0 - matricies - tensor with shape (..., N, N) by float type - * - * output - lower triangular matrix (matricies when rank > 2) with the same shape as input. - * */ -// #if NOT_EXCLUDED(OP_cholesky) - @Namespace("sd::ops") public static class cholesky extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public cholesky(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public cholesky(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public cholesky position(long position) { - return (cholesky)super.position(position); - } - public cholesky() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif - /* - * nth_element - apply nth_element for last dimension of input tensor - * input array: - * 0 - input array - * 1 - scalar tensor with n for operation. n should be less than last dimension - * - * output: - * 0 - NDArray with the same shape as input - */ -// #if NOT_EXCLUDED(OP_nth_element) - @Namespace("sd::ops") public static class nth_element extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public nth_element(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public nth_element(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public nth_element position(long position) { - return (nth_element)super.position(position); - } - +// #endif +/* + * nth_element - apply nth_element for last dimension of input tensor + * input array: + * 0 - input array + * 1 - scalar tensor with n for operation. n should be less than last + * dimension + * + * output: + * 0 - NDArray with the same shape as input + */ +// #if NOT_EXCLUDED(OP_nth_element) +@Namespace("sd::ops") public static class nth_element extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public nth_element(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public nth_element(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public nth_element position(long position) { + return (nth_element)super.position(position); + } + public nth_element() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This op checks for Inf/NaN values within input array, and throws exception if + * there's at least one + */ +// #if NOT_EXCLUDED(OP_check_numerics) +@Namespace("sd::ops") public static class check_numerics extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public check_numerics(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public check_numerics(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public check_numerics position(long position) { + return (check_numerics)super.position(position); + } - /** - * This op checks for Inf/NaN values within input array, and throws exception if there's at least one - */ -// #if NOT_EXCLUDED(OP_check_numerics) - @Namespace("sd::ops") public static class check_numerics extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public check_numerics(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public check_numerics(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public check_numerics position(long position) { - return (check_numerics)super.position(position); - } - public check_numerics() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif /** - * fake_quant_with_min_max_vals - tf.quantization.fake_quant_with_min_max_vars - * - * input params: - * 0 - NDArray (input) - * 1 - 0D Tensor - min value - * 2 - 0D Tensor - max value - * - * int params (optional): - * 0 - num_bits (allowed interval [2, 16], default 8) - * 1 - narrow_range (default False) - * - * output: - * 0 - NDArray with the same shape as input - */ -// #if NOT_EXCLUDED(OP_fake_quant_with_min_max_vars) - @Namespace("sd::ops") public static class fake_quant_with_min_max_vars extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public fake_quant_with_min_max_vars(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public fake_quant_with_min_max_vars(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public fake_quant_with_min_max_vars position(long position) { - return (fake_quant_with_min_max_vars)super.position(position); - } - + * fake_quant_with_min_max_vals - tf.quantization.fake_quant_with_min_max_vars + * + * input params: + * 0 - NDArray (input) + * 1 - 0D Tensor - min value + * 2 - 0D Tensor - max value + * + * int params (optional): + * 0 - num_bits (allowed interval [2, 16], default 8) + * 1 - narrow_range (default False) + * + * output: + * 0 - NDArray with the same shape as input + */ +// #if NOT_EXCLUDED(OP_fake_quant_with_min_max_vars) +@Namespace("sd::ops") public static class fake_quant_with_min_max_vars extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public fake_quant_with_min_max_vars(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public fake_quant_with_min_max_vars(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public fake_quant_with_min_max_vars position(long position) { + return (fake_quant_with_min_max_vars)super.position(position); + } + public fake_quant_with_min_max_vars() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif /** - * fake_quant_with_min_max_vals_per_channel - tf.quantization.fake_quant_with_min_max_vars_per_channel - * - * input params: - * 0 - NDArray (input) - at least 2D. - * 1 - 1D Tensor - min values (min length equals to last dim of input) - * 2 - 1D Tensor - max value (length equals to min) - * - * int params (optional): - * 0 - num_bits (allowed interval [2, 16], default 8) - * 1 - narrow_range (default False) - * - * output: - * 0 - NDArray with the same shape as input - */ -// #if NOT_EXCLUDED(OP_fake_quant_with_min_max_vars_per_channel) - @Namespace("sd::ops") public static class fake_quant_with_min_max_vars_per_channel extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public fake_quant_with_min_max_vars_per_channel(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public fake_quant_with_min_max_vars_per_channel(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public fake_quant_with_min_max_vars_per_channel position(long position) { - return (fake_quant_with_min_max_vars_per_channel)super.position(position); - } - + * fake_quant_with_min_max_vals_per_channel - + * tf.quantization.fake_quant_with_min_max_vars_per_channel + * + * input params: + * 0 - NDArray (input) - at least 2D. + * 1 - 1D Tensor - min values (min length equals to last dim of input) + * 2 - 1D Tensor - max value (length equals to min) + * + * int params (optional): + * 0 - num_bits (allowed interval [2, 16], default 8) + * 1 - narrow_range (default False) + * + * output: + * 0 - NDArray with the same shape as input + */ +// #if NOT_EXCLUDED(OP_fake_quant_with_min_max_vars_per_channel) +@Namespace("sd::ops") public static class fake_quant_with_min_max_vars_per_channel extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public fake_quant_with_min_max_vars_per_channel(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public fake_quant_with_min_max_vars_per_channel(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public fake_quant_with_min_max_vars_per_channel position(long position) { + return (fake_quant_with_min_max_vars_per_channel)super.position(position); + } + public fake_quant_with_min_max_vars_per_channel() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * compare_and_bitpack - compare with greater and pack result with uint8 + * + * input params: + * 0 - NDArray (input) + * 1 - 0D Tensor - threshold + * + * + * output: + * 0 - NDArray with the same shape as input and type uint8 + */ +// #if NOT_EXCLUDED(OP_compare_and_bitpack) +@Namespace("sd::ops") public static class compare_and_bitpack extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public compare_and_bitpack(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public compare_and_bitpack(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public compare_and_bitpack position(long position) { + return (compare_and_bitpack)super.position(position); + } - /** - * compare_and_bitpack - compare with greater and pack result with uint8 - * - * input params: - * 0 - NDArray (input) - * 1 - 0D Tensor - threshold - * - * - * output: - * 0 - NDArray with the same shape as input and type uint8 - */ -// #if NOT_EXCLUDED(OP_compare_and_bitpack) - @Namespace("sd::ops") public static class compare_and_bitpack extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public compare_and_bitpack(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public compare_and_bitpack(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public compare_and_bitpack position(long position) { - return (compare_and_bitpack)super.position(position); - } - public compare_and_bitpack() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif - - +// #endif + // namespace ops + // namespace sd // #endif @@ -21703,308 +24127,307 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #define LIBND4J_HEADERS_SHAPE_H // #include -// #if NOT_EXCLUDED(OP_permute) - @Namespace("sd::ops") public static class permute extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public permute(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public permute(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public permute position(long position) { - return (permute)super.position(position); - } - +// #if NOT_EXCLUDED(OP_permute) +@Namespace("sd::ops") public static class permute extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public permute(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public permute(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public permute position(long position) { + return (permute)super.position(position); + } + public permute() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_reshapeas) +@Namespace("sd::ops") public static class reshapeas extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public reshapeas(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public reshapeas(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public reshapeas position(long position) { + return (reshapeas)super.position(position); + } -// #if NOT_EXCLUDED(OP_reshapeas) - @Namespace("sd::ops") public static class reshapeas extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public reshapeas(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public reshapeas(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public reshapeas position(long position) { - return (reshapeas)super.position(position); - } - public reshapeas() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_transpose) +@Namespace("sd::ops") public static class transpose extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public transpose(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public transpose(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public transpose position(long position) { + return (transpose)super.position(position); + } -// #if NOT_EXCLUDED(OP_transpose) - @Namespace("sd::ops") public static class transpose extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public transpose(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public transpose(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public transpose position(long position) { - return (transpose)super.position(position); - } - public transpose() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_shape_of) +@Namespace("sd::ops") public static class shape_of extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public shape_of(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public shape_of(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public shape_of position(long position) { + return (shape_of)super.position(position); + } -// #if NOT_EXCLUDED(OP_shape_of) - @Namespace("sd::ops") public static class shape_of extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public shape_of(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public shape_of(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public shape_of position(long position) { - return (shape_of)super.position(position); - } - public shape_of() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_shapes_of) +@Namespace("sd::ops") public static class shapes_of extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public shapes_of(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public shapes_of(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public shapes_of position(long position) { + return (shapes_of)super.position(position); + } -// #if NOT_EXCLUDED(OP_shapes_of) - @Namespace("sd::ops") public static class shapes_of extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public shapes_of(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public shapes_of(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public shapes_of position(long position) { - return (shapes_of)super.position(position); - } - public shapes_of() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_squeeze) +@Namespace("sd::ops") public static class squeeze extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public squeeze(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public squeeze(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public squeeze position(long position) { + return (squeeze)super.position(position); + } -// #if NOT_EXCLUDED(OP_squeeze) - @Namespace("sd::ops") public static class squeeze extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public squeeze(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public squeeze(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public squeeze position(long position) { - return (squeeze)super.position(position); - } - public squeeze() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_expand_dims) +@Namespace("sd::ops") public static class expand_dims extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public expand_dims(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public expand_dims(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public expand_dims position(long position) { + return (expand_dims)super.position(position); + } -// #if NOT_EXCLUDED(OP_expand_dims) - @Namespace("sd::ops") public static class expand_dims extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public expand_dims(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public expand_dims(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public expand_dims position(long position) { - return (expand_dims)super.position(position); - } - public expand_dims() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_reshape) +@Namespace("sd::ops") public static class reshape extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public reshape(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public reshape(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public reshape position(long position) { + return (reshape)super.position(position); + } -// #if NOT_EXCLUDED(OP_reshape) - @Namespace("sd::ops") public static class reshape extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public reshape(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public reshape(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public reshape position(long position) { - return (reshape)super.position(position); - } - public reshape() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_size_at) +@Namespace("sd::ops") public static class size_at extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public size_at(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public size_at(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public size_at position(long position) { + return (size_at)super.position(position); + } -// #if NOT_EXCLUDED(OP_size_at) - @Namespace("sd::ops") public static class size_at extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public size_at(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public size_at(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public size_at position(long position) { - return (size_at)super.position(position); - } - public size_at() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This op changes order of given array to specified order. + * In other words: C/F order switch + * + * Int args: + * 0 - isForder. set to 1 for F order output, or 0 for C order output + * + * \tparam T + */ +// #if NOT_EXCLUDED(OP_order) +@Namespace("sd::ops") public static class order extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public order(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public order(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public order position(long position) { + return (order)super.position(position); + } - /** - * This op changes order of given array to specified order. - * In other words: C/F order switch - * - * Int args: - * 0 - isForder. set to 1 for F order output, or 0 for C order output - * - * \tparam T - */ -// #if NOT_EXCLUDED(OP_order) - @Namespace("sd::ops") public static class order extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public order(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public order(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public order position(long position) { - return (order)super.position(position); - } - public order() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This op boosts specified input up to specified shape + * + * \tparam T + */ +// #if NOT_EXCLUDED(OP_tile_to_shape) +@Namespace("sd::ops") public static class tile_to_shape extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public tile_to_shape(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public tile_to_shape(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public tile_to_shape position(long position) { + return (tile_to_shape)super.position(position); + } - /** - * This op boosts specified input up to specified shape - * - * \tparam T - */ -// #if NOT_EXCLUDED(OP_tile_to_shape) - @Namespace("sd::ops") public static class tile_to_shape extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public tile_to_shape(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public tile_to_shape(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public tile_to_shape position(long position) { - return (tile_to_shape)super.position(position); - } - public tile_to_shape() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class tile_to_shape_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public tile_to_shape_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public tile_to_shape_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public tile_to_shape_bp position(long position) { - return (tile_to_shape_bp)super.position(position); - } - +@Namespace("sd::ops") public static class tile_to_shape_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public tile_to_shape_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public tile_to_shape_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public tile_to_shape_bp position(long position) { + return (tile_to_shape_bp)super.position(position); + } + public tile_to_shape_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This op broadcast given input up to given shape + * + * inputs: + * input array - array to be broadcasted to given shape + * shape array - array containing shape be broadcasted to + */ +// #if NOT_EXCLUDED(OP_broadcast_to) +@Namespace("sd::ops") public static class broadcast_to extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public broadcast_to(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public broadcast_to(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public broadcast_to position(long position) { + return (broadcast_to)super.position(position); + } - /** - * This op broadcast given input up to given shape - * - * inputs: - * input array - array to be broadcasted to given shape - * shape array - array containing shape be broadcasted to - */ -// #if NOT_EXCLUDED(OP_broadcast_to) - @Namespace("sd::ops") public static class broadcast_to extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public broadcast_to(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public broadcast_to(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public broadcast_to position(long position) { - return (broadcast_to)super.position(position); - } - public broadcast_to() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); - } -// #endif + } +// #endif +// #if NOT_EXCLUDED(OP_evaluate_reduction_shape) +@Namespace("sd::ops") public static class evaluate_reduction_shape extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public evaluate_reduction_shape(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public evaluate_reduction_shape(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public evaluate_reduction_shape position(long position) { + return (evaluate_reduction_shape)super.position(position); + } -// #if NOT_EXCLUDED(OP_evaluate_reduction_shape) - @Namespace("sd::ops") public static class evaluate_reduction_shape extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public evaluate_reduction_shape(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public evaluate_reduction_shape(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public evaluate_reduction_shape position(long position) { - return (evaluate_reduction_shape)super.position(position); - } - public evaluate_reduction_shape() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This operation creates new array + * Input: + * array with shape values + * + * IArgs: + * order value + * data type value + * + * BArgs: + * initialization option + */ +// #if NOT_EXCLUDED(OP_create) +@Namespace("sd::ops") public static class create extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public create(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public create(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public create position(long position) { + return (create)super.position(position); + } - /** - * This operation creates new array - * Input: - * array with shape values - * - * IArgs: - * order value - * data type value - * - * BArgs: - * initialization option - */ -// #if NOT_EXCLUDED(OP_create) - @Namespace("sd::ops") public static class create extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public create(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public create(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public create position(long position) { - return (create)super.position(position); - } - public create() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif - - +// #endif + // namespace ops + // namespace sd // #endif @@ -22035,218 +24458,219 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #define LIBND4J_HEADERS_RANDOM_H // #include -// #if NOT_EXCLUDED(OP_set_seed) - @Namespace("sd::ops") public static class set_seed extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public set_seed(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public set_seed(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public set_seed position(long position) { - return (set_seed)super.position(position); - } - +// #if NOT_EXCLUDED(OP_set_seed) +@Namespace("sd::ops") public static class set_seed extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public set_seed(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public set_seed(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public set_seed position(long position) { + return (set_seed)super.position(position); + } + public set_seed() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_get_seed) +@Namespace("sd::ops") public static class get_seed extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public get_seed(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public get_seed(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public get_seed position(long position) { + return (get_seed)super.position(position); + } -// #if NOT_EXCLUDED(OP_get_seed) - @Namespace("sd::ops") public static class get_seed extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public get_seed(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public get_seed(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public get_seed position(long position) { - return (get_seed)super.position(position); - } - public get_seed() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/* + * random_uniform distribution for types int32,int64, float16, float and double + * by default dtype is float32 + * + * input: + * 0 - shape of output (1D int tensor) + * 1 - min val (0D of output type) - optional (0 as default) + * 2 - max val (0D of output type) - optional (inf as default) + * + * output: + * 0 - uniformly distributed values of given type (between min and max) + */ +// #if NOT_EXCLUDED(OP_randomuniform) +@Namespace("sd::ops") public static class randomuniform extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public randomuniform(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public randomuniform(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public randomuniform position(long position) { + return (randomuniform)super.position(position); + } - /* - * random_uniform distribution for types int32,int64, float16, float and double - * by default dtype is float32 - * - * input: - * 0 - shape of output (1D int tensor) - * 1 - min val (0D of output type) - optional (0 as default) - * 2 - max val (0D of output type) - optional (inf as default) - * - * output: - * 0 - uniformly distributed values of given type (between min and max) - */ -// #if NOT_EXCLUDED(OP_randomuniform) - @Namespace("sd::ops") public static class randomuniform extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public randomuniform(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public randomuniform(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public randomuniform position(long position) { - return (randomuniform)super.position(position); - } - public randomuniform() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif - /* - * multinomial (categorical) random generator draws samples from a multinomial distribution - * - * Input array: - * 0 - 2D ndarray with unnormalized log-probabilities with shape [batch_size (N), num_classes (K)] - * 1 - array with one int value of samples number, number of independent samples to draw for each experiment 1,N. - * Int arguments: - * 0 - optional argument, corresponds to dimension with batch_size - * 1 - optional argument, integer type to use for the output. Default int64. - * - * Output array: - * 0 - 2D ndarray with the drawn samples of shape [batch_size, num_samples] - */ -// #if NOT_EXCLUDED(OP_random_multinomial) - @Namespace("sd::ops") public static class random_multinomial extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public random_multinomial(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public random_multinomial(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public random_multinomial position(long position) { - return (random_multinomial)super.position(position); - } - +// #endif +/* + * multinomial (categorical) random generator draws samples from a multinomial + * distribution + * + * Input array: + * 0 - 2D ndarray with unnormalized log-probabilities with shape [batch_size + * (N), num_classes (K)] 1 - array with one int value of samples number, number + * of independent samples to draw for each experiment 1,N. Int arguments: 0 - + * optional argument, corresponds to dimension with batch_size 1 - optional + * argument, integer type to use for the output. Default int64. + * + * Output array: + * 0 - 2D ndarray with the drawn samples of shape [batch_size, num_samples] + */ +// #if NOT_EXCLUDED(OP_random_multinomial) +@Namespace("sd::ops") public static class random_multinomial extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public random_multinomial(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public random_multinomial(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public random_multinomial position(long position) { + return (random_multinomial)super.position(position); + } + public random_multinomial() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_random_normal) +@Namespace("sd::ops") public static class random_normal extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public random_normal(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public random_normal(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public random_normal position(long position) { + return (random_normal)super.position(position); + } -// #if NOT_EXCLUDED(OP_random_normal) - @Namespace("sd::ops") public static class random_normal extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public random_normal(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public random_normal(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public random_normal position(long position) { - return (random_normal)super.position(position); - } - public random_normal() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_random_bernoulli) +@Namespace("sd::ops") public static class random_bernoulli extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public random_bernoulli(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public random_bernoulli(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public random_bernoulli position(long position) { + return (random_bernoulli)super.position(position); + } -// #if NOT_EXCLUDED(OP_random_bernoulli) - @Namespace("sd::ops") public static class random_bernoulli extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public random_bernoulli(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public random_bernoulli(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public random_bernoulli position(long position) { - return (random_bernoulli)super.position(position); - } - public random_bernoulli() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_random_exponential) +@Namespace("sd::ops") public static class random_exponential extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public random_exponential(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public random_exponential(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public random_exponential position(long position) { + return (random_exponential)super.position(position); + } -// #if NOT_EXCLUDED(OP_random_exponential) - @Namespace("sd::ops") public static class random_exponential extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public random_exponential(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public random_exponential(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public random_exponential position(long position) { - return (random_exponential)super.position(position); - } - public random_exponential() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_random_crop) +@Namespace("sd::ops") public static class random_crop extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public random_crop(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public random_crop(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public random_crop position(long position) { + return (random_crop)super.position(position); + } -// #if NOT_EXCLUDED(OP_random_crop) - @Namespace("sd::ops") public static class random_crop extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public random_crop(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public random_crop(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public random_crop position(long position) { - return (random_crop)super.position(position); - } - public random_crop() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * random_gamma op. + */ +// #if NOT_EXCLUDED(OP_random_gamma) +@Namespace("sd::ops") public static class random_gamma extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public random_gamma(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public random_gamma(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public random_gamma position(long position) { + return (random_gamma)super.position(position); + } - /** - * random_gamma op. - */ -// #if NOT_EXCLUDED(OP_random_gamma) - @Namespace("sd::ops") public static class random_gamma extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public random_gamma(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public random_gamma(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public random_gamma position(long position) { - return (random_gamma)super.position(position); - } - public random_gamma() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * random_poisson op. + */ +// #if NOT_EXCLUDED(OP_random_poisson) +@Namespace("sd::ops") public static class random_poisson extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public random_poisson(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public random_poisson(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public random_poisson position(long position) { + return (random_poisson)super.position(position); + } - /** - * random_poisson op. - */ -// #if NOT_EXCLUDED(OP_random_poisson) - @Namespace("sd::ops") public static class random_poisson extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public random_poisson(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public random_poisson(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public random_poisson position(long position) { - return (random_poisson)super.position(position); - } - public random_poisson() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif - - +// #endif + // namespace ops + // namespace sd // #endif @@ -22277,473 +24701,481 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #include -// #if NOT_EXCLUDED(OP_softmax) - @Namespace("sd::ops") public static class softmax extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public softmax(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public softmax(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public softmax position(long position) { - return (softmax)super.position(position); - } - +// #if NOT_EXCLUDED(OP_softmax) +@Namespace("sd::ops") public static class softmax extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public softmax(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public softmax(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public softmax position(long position) { + return (softmax)super.position(position); + } + public softmax() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class softmax_bp extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public softmax_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public softmax_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public softmax_bp position(long position) { - return (softmax_bp)super.position(position); - } - +@Namespace("sd::ops") public static class softmax_bp extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public softmax_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public softmax_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public softmax_bp position(long position) { + return (softmax_bp)super.position(position); + } + public softmax_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * Local response normalization implementation as TF. + * input: 4D array + * + * T args: + * + * 0: bias + * 1: alpha + * 2: beta + * + * Int arg: depth - optional local radius + * + * output - 4D array + */ +// #if NOT_EXCLUDED(OP_lrn) +@Namespace("sd::ops") public static class lrn extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public lrn(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public lrn(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public lrn position(long position) { + return (lrn)super.position(position); + } - /** - * Local response normalization implementation as TF. - * input: 4D array - * - * T args: - * - * 0: bias - * 1: alpha - * 2: beta - * - * Int arg: depth - optional local radius - * - * output - 4D array - */ -// #if NOT_EXCLUDED(OP_lrn) - @Namespace("sd::ops") public static class lrn extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public lrn(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public lrn(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public lrn position(long position) { - return (lrn)super.position(position); - } - public lrn() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * Local response normalization - backprop variant. + * input: + * 0 - 4D array of data + * 1 - epsilon - 4D array of approximation + * + * T args: + * + * 0: bias + * 1: alpha + * 2: beta + * + * Int arg: depth - optional local radius + * + * output - next approximation as 4D array + */ +// #if NOT_EXCLUDED(OP_lrn) +@Namespace("sd::ops") public static class lrn_bp extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public lrn_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public lrn_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public lrn_bp position(long position) { + return (lrn_bp)super.position(position); + } - /** - * Local response normalization - backprop variant. - * input: - * 0 - 4D array of data - * 1 - epsilon - 4D array of approximation - * - * T args: - * - * 0: bias - * 1: alpha - * 2: beta - * - * Int arg: depth - optional local radius - * - * output - next approximation as 4D array - */ -// #if NOT_EXCLUDED(OP_lrn) - @Namespace("sd::ops") public static class lrn_bp extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public lrn_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public lrn_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public lrn_bp position(long position) { - return (lrn_bp)super.position(position); - } - public lrn_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * Batch normalization implementation. + * Reference: https://arxiv.org/abs/1502.03167v3 + * + * Expected arguments: + * input: input array (any number of dimensions) + * mean: + * variance: + * gamma: + * beta: + * + * Int args: + * 0: apply scale + * 1: apply offset + * + * + * T args: + * 0: epsilon + */ +// #if NOT_EXCLUDED(OP_batchnorm) +@Namespace("sd::ops") public static class batchnorm extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public batchnorm(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public batchnorm(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public batchnorm position(long position) { + return (batchnorm)super.position(position); + } - /** - * Batch normalization implementation. - * Reference: https://arxiv.org/abs/1502.03167v3 - * - * Expected arguments: - * input: input array (any number of dimensions) - * mean: - * variance: - * gamma: - * beta: - * - * Int args: - * 0: apply scale - * 1: apply offset - * - * - * T args: - * 0: epsilon - */ -// #if NOT_EXCLUDED(OP_batchnorm) - @Namespace("sd::ops") public static class batchnorm extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public batchnorm(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public batchnorm(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public batchnorm position(long position) { - return (batchnorm)super.position(position); - } - public batchnorm() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * back prop in batch normalization + * + * Expected arguments: + * input: input array (any number of dimensions) + * mean: + * variance: + * gamma: optional + * beta: optional + * dLdOut: next epsilon + * + * Int args: + * 0: apply scale + * 1: apply offset + * + * T args: + * 0: epsilon + * + * output arrays: + * dL/dInput + * dL/dMean + * dL/dVariance + * dL/dGamma, optional + * dL/dBeta, optional + */ +// #if NOT_EXCLUDED(OP_batchnorm) +@Namespace("sd::ops") public static class batchnorm_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public batchnorm_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public batchnorm_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public batchnorm_bp position(long position) { + return (batchnorm_bp)super.position(position); + } - /** - * back prop in batch normalization - * - * Expected arguments: - * input: input array (any number of dimensions) - * mean: - * variance: - * gamma: optional - * beta: optional - * dLdOut: next epsilon - * - * Int args: - * 0: apply scale - * 1: apply offset - * - * T args: - * 0: epsilon - * - * output arrays: - * dL/dInput - * dL/dMean - * dL/dVariance - * dL/dGamma, optional - * dL/dBeta, optional - */ -// #if NOT_EXCLUDED(OP_batchnorm) - @Namespace("sd::ops") public static class batchnorm_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public batchnorm_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public batchnorm_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public batchnorm_bp position(long position) { - return (batchnorm_bp)super.position(position); - } - public batchnorm_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif +/** + * This operation updates parameters with provided gradients, wrt learning rate + * Expected arguments: + * x: parameters, any shape + * y: gradients. same shape as x + * lr: optional, learning rate + * + * T args: + * 0: optional, learning rate + */ +// #if NOT_EXCLUDED(OP_apply_sgd) +@Namespace("sd::ops") public static class apply_sgd extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public apply_sgd(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public apply_sgd(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public apply_sgd position(long position) { + return (apply_sgd)super.position(position); + } - /** - * This operation updates parameters with provided gradients, wrt learning rate - * Expected arguments: - * x: parameters, any shape - * y: gradients. same shape as x - * lr: optional, learning rate - * - * T args: - * 0: optional, learning rate - */ -// #if NOT_EXCLUDED(OP_apply_sgd) - @Namespace("sd::ops") public static class apply_sgd extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public apply_sgd(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public apply_sgd(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public apply_sgd position(long position) { - return (apply_sgd)super.position(position); - } - public apply_sgd() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This operation performs batch normalization of layer, it is based on + * following article https://arxiv.org/abs/1502.03167. Expected arguments: x: + * input 4D array of shape [bS,iH,iW,iD] (data format = NHWC) or [bS,iD,iH,iW] + * (data format = NCHW), where bS - batch size iH - input height iW - input + * width iD - input depth (or number of channels) scale: 1D input array of + * scale factors, shape [iD] offset: 1D input array of offsets (shifts), shape + * [iD] mean: 1D input array of population mean used for inference, shape [iD], + * this array is required only if isTraining = false variance: 1D input array of + * population mean used for inference, shape [iD], this array is required only + * if isTraining = false + * + * T input arguments: + * 0: epsilon, it is optional argument, default value is 0.001, this is small + * number to be added to the variance of x + * + * integer input arguments: + * 0: dataFormat, may have two values: zero -> NHWC, unity -> NCHW + * 1: isTraining, may have two values: zero -> inference, unity -> training + */ +// #if NOT_EXCLUDED(OP_fused_batch_norm) +@Namespace("sd::ops") public static class fused_batch_norm extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public fused_batch_norm(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public fused_batch_norm(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public fused_batch_norm position(long position) { + return (fused_batch_norm)super.position(position); + } - /** - * This operation performs batch normalization of layer, it is based on following article https://arxiv.org/abs/1502.03167. - * Expected arguments: - * x: input 4D array of shape [bS,iH,iW,iD] (data format = NHWC) or [bS,iD,iH,iW] (data format = NCHW), where - * bS - batch size - * iH - input height - * iW - input width - * iD - input depth (or number of channels) - * scale: 1D input array of scale factors, shape [iD] - * offset: 1D input array of offsets (shifts), shape [iD] - * mean: 1D input array of population mean used for inference, shape [iD], this array is required only if isTraining = false - * variance: 1D input array of population mean used for inference, shape [iD], this array is required only if isTraining = false - * - * T input arguments: - * 0: epsilon, it is optional argument, default value is 0.001, this is small number to be added to the variance of x - * - * integer input arguments: - * 0: dataFormat, may have two values: zero -> NHWC, unity -> NCHW - * 1: isTraining, may have two values: zero -> inference, unity -> training - */ -// #if NOT_EXCLUDED(OP_fused_batch_norm) - @Namespace("sd::ops") public static class fused_batch_norm extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public fused_batch_norm(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public fused_batch_norm(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public fused_batch_norm position(long position) { - return (fused_batch_norm)super.position(position); - } - public fused_batch_norm() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_log_softmax) +@Namespace("sd::ops") public static class log_softmax extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public log_softmax(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public log_softmax(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public log_softmax position(long position) { + return (log_softmax)super.position(position); + } -// #if NOT_EXCLUDED(OP_log_softmax) - @Namespace("sd::ops") public static class log_softmax extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public log_softmax(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public log_softmax(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public log_softmax position(long position) { - return (log_softmax)super.position(position); - } - public log_softmax() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class log_softmax_bp extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public log_softmax_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public log_softmax_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public log_softmax_bp position(long position) { - return (log_softmax_bp)super.position(position); - } - +@Namespace("sd::ops") public static class log_softmax_bp extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public log_softmax_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public log_softmax_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public log_softmax_bp position(long position) { + return (log_softmax_bp)super.position(position); + } + public log_softmax_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif +/** + * relu_layer = relu(x*w + b) + */ +@Namespace("sd::ops") public static class relu_layer extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public relu_layer(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public relu_layer(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public relu_layer position(long position) { + return (relu_layer)super.position(position); + } - /** - * relu_layer = relu(x*w + b) - */ - @Namespace("sd::ops") public static class relu_layer extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public relu_layer(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public relu_layer(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public relu_layer position(long position) { - return (relu_layer)super.position(position); - } - public relu_layer() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - /** - * applies layer normalization to input - * y = g * standardize(x) + b - * - * see sd::ops::standardize - * - */ -// #if NOT_EXCLUDED(OP_layer_norm) - @Namespace("sd::ops") public static class layer_norm extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public layer_norm(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public layer_norm(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public layer_norm position(long position) { - return (layer_norm)super.position(position); - } - +/** + * applies layer normalization to input + * y = g * standardize(x) + b + * + * see sd::ops::standardize + * + */ +// #if NOT_EXCLUDED(OP_layer_norm) +@Namespace("sd::ops") public static class layer_norm extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public layer_norm(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public layer_norm(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public layer_norm position(long position) { + return (layer_norm)super.position(position); + } + public layer_norm() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class layer_norm_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public layer_norm_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public layer_norm_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public layer_norm_bp position(long position) { - return (layer_norm_bp)super.position(position); - } - +@Namespace("sd::ops") public static class layer_norm_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public layer_norm_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public layer_norm_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public layer_norm_bp position(long position) { + return (layer_norm_bp)super.position(position); + } + public layer_norm_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This operation performs dot product attention on the given timeseries input + * with the given queries out = sum(similarity(k_i, q) * v_i) + * + * similarity(k, q) = softmax(k * q) where x * q is the dot product of x and q + * + * Optionally with normalization step: + * similarity(k, q) = softmax(k * q / sqrt(size(q)) + * + * See also "Attention is all you need" (https://arxiv.org/abs/1706.03762, p. 4, + * eq. 1) + * + * Note: This supports multiple queries at once, if only one query is available + * the queries vector still has to be 3D but can have queryCount = 1 + * + * Note: keys and values usually is the same array. If you want to use it as the + * same array, simply pass it for both. + * + * Expected arguments: + * q: input 3D array "queries" of shape [batchSize, featureKeys, queryCount] or + * 4D array of shape [batchSize, numHeads, featureKeys, queryCount] k: input 3D + * array "keys" of shape [batchSize, featureKeys, timesteps] or 4D array of + * shape [batchSize, numHeads, featureKeys, timesteps] v: input 3D array + * "values" of shape [batchSize, featureValues, timesteps] or 4D array of shape + * [batchSize, numHeads, featureValues, timesteps] mask: OPTIONAL; array that + * defines which values should be skipped of shape [batchSize, timesteps] + * + * integer input arguments: + * 0: normalization, may have two values: zero -> do not apply normalization, + * one -> apply normalization 1: withWeights, may have two values: zero -> do + * not return weights, one -> return weights + * + * Output Arrays: + * 0: Attention result arrays of shape [batchSize, featureValues, queryCount] or + * [batchSize, numHeads, featureValues, queryCount] 1: OPTIONAL; Attention + * weights of shape [batchSize, timesteps, queryCount] or [batchSize, numHeads, + * timesteps, queryCount] + */ +// #if NOT_EXCLUDED(OP_dot_product_attention) +@Namespace("sd::ops") public static class dot_product_attention extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public dot_product_attention(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public dot_product_attention(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public dot_product_attention position(long position) { + return (dot_product_attention)super.position(position); + } - /** - * This operation performs dot product attention on the given timeseries input with the given queries - * out = sum(similarity(k_i, q) * v_i) - * - * similarity(k, q) = softmax(k * q) where x * q is the dot product of x and q - * - * Optionally with normalization step: - * similarity(k, q) = softmax(k * q / sqrt(size(q)) - * - * See also "Attention is all you need" (https://arxiv.org/abs/1706.03762, p. 4, eq. 1) - * - * Note: This supports multiple queries at once, if only one query is available the queries vector still has to - * be 3D but can have queryCount = 1 - * - * Note: keys and values usually is the same array. If you want to use it as the same array, simply pass it for - * both. - * - * Expected arguments: - * q: input 3D array "queries" of shape [batchSize, featureKeys, queryCount] or 4D array of shape [batchSize, numHeads, featureKeys, queryCount] - * k: input 3D array "keys" of shape [batchSize, featureKeys, timesteps] or 4D array of shape [batchSize, numHeads, featureKeys, timesteps] - * v: input 3D array "values" of shape [batchSize, featureValues, timesteps] or 4D array of shape [batchSize, numHeads, featureValues, timesteps] - * mask: OPTIONAL; array that defines which values should be skipped of shape [batchSize, timesteps] - * - * integer input arguments: - * 0: normalization, may have two values: zero -> do not apply normalization, one -> apply normalization - * 1: withWeights, may have two values: zero -> do not return weights, one -> return weights - * - * Output Arrays: - * 0: Attention result arrays of shape [batchSize, featureValues, queryCount] or [batchSize, numHeads, featureValues, queryCount] - * 1: OPTIONAL; Attention weights of shape [batchSize, timesteps, queryCount] or [batchSize, numHeads, timesteps, queryCount] - */ -// #if NOT_EXCLUDED(OP_dot_product_attention) - @Namespace("sd::ops") public static class dot_product_attention extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public dot_product_attention(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public dot_product_attention(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public dot_product_attention position(long position) { - return (dot_product_attention)super.position(position); - } - public dot_product_attention() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class dot_product_attention_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public dot_product_attention_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public dot_product_attention_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public dot_product_attention_bp position(long position) { - return (dot_product_attention_bp)super.position(position); - } - +@Namespace("sd::ops") public static class dot_product_attention_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public dot_product_attention_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public dot_product_attention_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public dot_product_attention_bp position(long position) { + return (dot_product_attention_bp)super.position(position); + } + public dot_product_attention_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif +/** + * This performs multi-headed dot product attention on the given timeseries + * input out = concat(head_1, head_2, ..., head_n) * Wo head_i = + * dot_product_attention(Wq_i*q, Wk_i*k, Wv_i*v) + * + * Optionally with normalization when calculating the attention for each head. + * + * See also "Attention is all you need" (https://arxiv.org/abs/1706.03762, pp. + * 4,5, "3.2.2 Multi-Head Attention") + * + * This makes use of dot_product_attention OP support for rank 4 inputs. + * + * Expected arguments: + * q: input 3D array "queries" of shape [batchSize, featureKeys, queryCount] + * k: input 3D array "keys" of shape [batchSize, featureKeys, timesteps] + * v: input 3D array "values" of shape [batchSize, featureValues, timesteps] + * Wq: input query projection weights of shape [numHeads, projectedKeys, + * featureKeys] Wk: input key projection weights of shape [numHeads, + * projectedKeys, featureKeys] Wv: input value projection weights of shape + * [numHeads, projectedValues, featureValues] Wo: output projection weights of + * shape [numHeads * projectedValues, outSize] mask: OPTIONAL; array that + * defines which values should be skipped of shape [batchSize, timesteps] + * + * integer input arguments: + * 0: normalization, may have two values: zero -> do not apply normalization, + * one -> apply normalization 1: withWeights, may have two values: zero -> do + * not return weights, one -> return weights + * + * Output Arrays: + * 0: Attention result arrays of shape [batchSize, outSize, queryCount] + * 1: OPTIONAL; Attention weights of shape [batchSize, numHeads, timesteps, + * queryCount] + */ +// #if NOT_EXCLUDED(OP_multi_head_dot_product_attention) +@Namespace("sd::ops") public static class multi_head_dot_product_attention extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public multi_head_dot_product_attention(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public multi_head_dot_product_attention(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public multi_head_dot_product_attention position(long position) { + return (multi_head_dot_product_attention)super.position(position); + } - /** - * This performs multi-headed dot product attention on the given timeseries input - * out = concat(head_1, head_2, ..., head_n) * Wo - * head_i = dot_product_attention(Wq_i*q, Wk_i*k, Wv_i*v) - * - * Optionally with normalization when calculating the attention for each head. - * - * See also "Attention is all you need" (https://arxiv.org/abs/1706.03762, pp. 4,5, "3.2.2 Multi-Head Attention") - * - * This makes use of dot_product_attention OP support for rank 4 inputs. - * - * Expected arguments: - * q: input 3D array "queries" of shape [batchSize, featureKeys, queryCount] - * k: input 3D array "keys" of shape [batchSize, featureKeys, timesteps] - * v: input 3D array "values" of shape [batchSize, featureValues, timesteps] - * Wq: input query projection weights of shape [numHeads, projectedKeys, featureKeys] - * Wk: input key projection weights of shape [numHeads, projectedKeys, featureKeys] - * Wv: input value projection weights of shape [numHeads, projectedValues, featureValues] - * Wo: output projection weights of shape [numHeads * projectedValues, outSize] - * mask: OPTIONAL; array that defines which values should be skipped of shape [batchSize, timesteps] - * - * integer input arguments: - * 0: normalization, may have two values: zero -> do not apply normalization, one -> apply normalization - * 1: withWeights, may have two values: zero -> do not return weights, one -> return weights - * - * Output Arrays: - * 0: Attention result arrays of shape [batchSize, outSize, queryCount] - * 1: OPTIONAL; Attention weights of shape [batchSize, numHeads, timesteps, queryCount] - */ -// #if NOT_EXCLUDED(OP_multi_head_dot_product_attention) - @Namespace("sd::ops") public static class multi_head_dot_product_attention extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public multi_head_dot_product_attention(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public multi_head_dot_product_attention(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public multi_head_dot_product_attention position(long position) { - return (multi_head_dot_product_attention)super.position(position); - } - public multi_head_dot_product_attention() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class multi_head_dot_product_attention_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public multi_head_dot_product_attention_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public multi_head_dot_product_attention_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public multi_head_dot_product_attention_bp position(long position) { - return (multi_head_dot_product_attention_bp)super.position(position); - } - +@Namespace("sd::ops") public static class multi_head_dot_product_attention_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public multi_head_dot_product_attention_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public multi_head_dot_product_attention_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public multi_head_dot_product_attention_bp position(long position) { + return (multi_head_dot_product_attention_bp)super.position(position); + } + public multi_head_dot_product_attention_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif - - +// #endif + // namespace ops + // namespace sd // #endif @@ -22773,150 +25205,149 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #include - /** - * This op is general matmum implementation. Depending on inputs dimensionality output result might be different. - * matrix x matrix = BLAS gemm - * vector x matrix = BLAS gemm - * vector x vector = BLAS dot - * vector x scalar = element-wise mul - * scalar x vector = element-wise mul - * - * Optional T arguments: - * 0: alpha (where applicable) - * 1: beta (where applicable) - * - * Optional Integer arguments: - * 0: transA (where applicable) - * 1: transB (where applicable) - */ -// #if NOT_EXCLUDED(OP_matmul) - @Namespace("sd::ops") public static class matmul extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public matmul(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public matmul(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public matmul position(long position) { - return (matmul)super.position(position); - } - +/** + * This op is general matmum implementation. Depending on inputs dimensionality + * output result might be different. matrix x matrix = BLAS gemm vector x matrix + * = BLAS gemm vector x vector = BLAS dot vector x scalar = element-wise mul + * scalar x vector = element-wise mul + * + * Optional T arguments: + * 0: alpha (where applicable) + * 1: beta (where applicable) + * + * Optional Integer arguments: + * 0: transA (where applicable) + * 1: transB (where applicable) + */ +// #if NOT_EXCLUDED(OP_matmul) +@Namespace("sd::ops") public static class matmul extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public matmul(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public matmul(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public matmul position(long position) { + return (matmul)super.position(position); + } + public matmul() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class matmul_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public matmul_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public matmul_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public matmul_bp position(long position) { - return (matmul_bp)super.position(position); - } - +@Namespace("sd::ops") public static class matmul_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public matmul_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public matmul_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public matmul_bp position(long position) { + return (matmul_bp)super.position(position); + } + public matmul_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * tensorMmul/tensorDot operation + * takes 2 ndarrays, and 2 sets of axes + * + * Integer argumens map: + * IArgs[0] - number of axes along for first array + * IArgs[1]... axes values for first array + * IArgs[] - number of axes along for second array + * IArgs[1]... axes values for second array + */ +// #if NOT_EXCLUDED(OP_tensormmul) +@Namespace("sd::ops") public static class tensormmul extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public tensormmul(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public tensormmul(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public tensormmul position(long position) { + return (tensormmul)super.position(position); + } - /** - * tensorMmul/tensorDot operation - * takes 2 ndarrays, and 2 sets of axes - * - * Integer argumens map: - * IArgs[0] - number of axes along for first array - * IArgs[1]... axes values for first array - * IArgs[] - number of axes along for second array - * IArgs[1]... axes values for second array - */ -// #if NOT_EXCLUDED(OP_tensormmul) - @Namespace("sd::ops") public static class tensormmul extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public tensormmul(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public tensormmul(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public tensormmul position(long position) { - return (tensormmul)super.position(position); - } - public tensormmul() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class tensormmul_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public tensormmul_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public tensormmul_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public tensormmul_bp position(long position) { - return (tensormmul_bp)super.position(position); - } - +@Namespace("sd::ops") public static class tensormmul_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public tensormmul_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public tensormmul_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public tensormmul_bp position(long position) { + return (tensormmul_bp)super.position(position); + } + public tensormmul_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This op is simple implementation of BLAS AXPY method. + * Math is: y += a * x; + */ +// #if NOT_EXCLUDED(OP_axpy) +@Namespace("sd::ops") public static class axpy extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public axpy(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public axpy(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public axpy position(long position) { + return (axpy)super.position(position); + } - /** - * This op is simple implementation of BLAS AXPY method. - * Math is: y += a * x; - */ -// #if NOT_EXCLUDED(OP_axpy) - @Namespace("sd::ops") public static class axpy extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public axpy(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public axpy(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public axpy position(long position) { - return (axpy)super.position(position); - } - public axpy() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This operation implements batched matrix multiplication + * Expected arguments: + * alpha: vector of T + * beta: vector of T + * ...: A, B matrices sequentially. i.e: AAAAABBBBB + * + * Integer arguments: + * transA, transB, M, N, K, ldA, ldB, ldC - usual BLAS gemm arguments + * batchCount - number of operations in this batch + * + * PLEASE NOTE: M, N, K, ldA, ldB, ldC should be equal for all matrices within + * batch. + */ +// #if NOT_EXCLUDED(OP_batched_gemm) +@Namespace("sd::ops") public static class batched_gemm extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public batched_gemm(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public batched_gemm(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public batched_gemm position(long position) { + return (batched_gemm)super.position(position); + } - /** - * This operation implements batched matrix multiplication - * Expected arguments: - * alpha: vector of T - * beta: vector of T - * ...: A, B matrices sequentially. i.e: AAAAABBBBB - * - * Integer arguments: - * transA, transB, M, N, K, ldA, ldB, ldC - usual BLAS gemm arguments - * batchCount - number of operations in this batch - * - * PLEASE NOTE: M, N, K, ldA, ldB, ldC should be equal for all matrices within batch. - */ -// #if NOT_EXCLUDED(OP_batched_gemm) - @Namespace("sd::ops") public static class batched_gemm extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public batched_gemm(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public batched_gemm(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public batched_gemm position(long position) { - return (batched_gemm)super.position(position); - } - public batched_gemm() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif /** * performs singular value decomposition (SVD) of one or more matrices, evaluates the SVD of each inner-most 2D matrix in input array: @@ -23009,114 +25440,114 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // @author raver119@gmail.com // // #include -// #if NOT_EXCLUDED(OP_test_output_reshape) - @Namespace("sd::ops") public static class test_output_reshape extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public test_output_reshape(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public test_output_reshape(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public test_output_reshape position(long position) { - return (test_output_reshape)super.position(position); - } - +// #if NOT_EXCLUDED(OP_test_output_reshape) +@Namespace("sd::ops") public static class test_output_reshape extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public test_output_reshape(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public test_output_reshape(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public test_output_reshape position(long position) { + return (test_output_reshape)super.position(position); + } + public test_output_reshape() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_test_scalar) +@Namespace("sd::ops") public static class test_scalar extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public test_scalar(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public test_scalar(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public test_scalar position(long position) { + return (test_scalar)super.position(position); + } -// #if NOT_EXCLUDED(OP_test_scalar) - @Namespace("sd::ops") public static class test_scalar extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public test_scalar(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public test_scalar(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public test_scalar position(long position) { - return (test_scalar)super.position(position); - } - public test_scalar() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_testreduction) +@Namespace("sd::ops") public static class testreduction extends DeclarableReductionOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public testreduction(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public testreduction(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public testreduction position(long position) { + return (testreduction)super.position(position); + } -// #if NOT_EXCLUDED(OP_testreduction) - @Namespace("sd::ops") public static class testreduction extends DeclarableReductionOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public testreduction(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public testreduction(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public testreduction position(long position) { - return (testreduction)super.position(position); - } - public testreduction() { super((Pointer)null); allocate(); } private native void allocate(); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_noop) +@Namespace("sd::ops") public static class noop extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public noop(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public noop(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public noop position(long position) { + return (noop)super.position(position); + } -// #if NOT_EXCLUDED(OP_noop) - @Namespace("sd::ops") public static class noop extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public noop(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public noop(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public noop position(long position) { - return (noop)super.position(position); - } - public noop() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_testop2i2o) +@Namespace("sd::ops") public static class testop2i2o extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public testop2i2o(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public testop2i2o(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public testop2i2o position(long position) { + return (testop2i2o)super.position(position); + } -// #if NOT_EXCLUDED(OP_testop2i2o) - @Namespace("sd::ops") public static class testop2i2o extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public testop2i2o(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public testop2i2o(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public testop2i2o position(long position) { - return (testop2i2o)super.position(position); - } - public testop2i2o() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_testcustom) +@Namespace("sd::ops") public static class testcustom extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public testcustom(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public testcustom(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public testcustom position(long position) { + return (testcustom)super.position(position); + } -// #if NOT_EXCLUDED(OP_testcustom) - @Namespace("sd::ops") public static class testcustom extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public testcustom(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public testcustom(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public testcustom position(long position) { - return (testcustom)super.position(position); - } - public testcustom() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif - - +// #endif + // namespace ops + // namespace sd // Parsed from ops/declarable/headers/bitwise.h @@ -23144,226 +25575,228 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #define LIBND4J_HEADERS_BITWISE_H // #include - /** - * This operation toggles individual bits of each element in array - * - * PLEASE NOTE: This operation is possible only on integer data types - * - * \tparam T - */ -// #if NOT_EXCLUDED(OP_toggle_bits) - @Namespace("sd::ops") public static class toggle_bits extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public toggle_bits(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public toggle_bits(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public toggle_bits position(long position) { - return (toggle_bits)super.position(position); - } - +/** + * This operation toggles individual bits of each element in array + * + * PLEASE NOTE: This operation is possible only on integer data types + * + * \tparam T + */ +// #if NOT_EXCLUDED(OP_toggle_bits) +@Namespace("sd::ops") public static class toggle_bits extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public toggle_bits(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public toggle_bits(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public toggle_bits position(long position) { + return (toggle_bits)super.position(position); + } + public toggle_bits() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif +/** + * This operation shift individual bits of each element in array to the left: << + * + * PLEASE NOTE: This operation is applicable only to integer data types + * + * \tparam T + */ +// #if NOT_EXCLUDED(OP_shift_bits) +@Namespace("sd::ops") public static class shift_bits extends BroadcastableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public shift_bits(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public shift_bits(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public shift_bits position(long position) { + return (shift_bits)super.position(position); + } - /** - * This operation shift individual bits of each element in array to the left: << - * - * PLEASE NOTE: This operation is applicable only to integer data types - * - * \tparam T - */ -// #if NOT_EXCLUDED(OP_shift_bits) - @Namespace("sd::ops") public static class shift_bits extends BroadcastableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public shift_bits(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public shift_bits(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public shift_bits position(long position) { - return (shift_bits)super.position(position); - } - public shift_bits() { super((Pointer)null); allocate(); } private native void allocate(); } -// #endif +// #endif + +/** + * This operation shift individual bits of each element in array to the right: + * >> + * + * PLEASE NOTE: This operation is applicable only to integer data types + * + * \tparam T + */ +// #if NOT_EXCLUDED(OP_rshift_bits) +@Namespace("sd::ops") public static class rshift_bits extends BroadcastableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public rshift_bits(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public rshift_bits(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public rshift_bits position(long position) { + return (rshift_bits)super.position(position); + } - /** - * This operation shift individual bits of each element in array to the right: >> - * - * PLEASE NOTE: This operation is applicable only to integer data types - * - * \tparam T - */ -// #if NOT_EXCLUDED(OP_rshift_bits) - @Namespace("sd::ops") public static class rshift_bits extends BroadcastableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public rshift_bits(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public rshift_bits(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public rshift_bits position(long position) { - return (rshift_bits)super.position(position); - } - public rshift_bits() { super((Pointer)null); allocate(); } private native void allocate(); } -// #endif +// #endif + +/** + * This operation shift individual bits of each element in array, shifting to + * the left + * + * PLEASE NOTE: This operation is applicable only to integer data types + * + * \tparam T + */ +// #if NOT_EXCLUDED(OP_cyclic_shift_bits) +@Namespace("sd::ops") public static class cyclic_shift_bits extends BroadcastableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public cyclic_shift_bits(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public cyclic_shift_bits(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public cyclic_shift_bits position(long position) { + return (cyclic_shift_bits)super.position(position); + } - /** - * This operation shift individual bits of each element in array, shifting to the left - * - * PLEASE NOTE: This operation is applicable only to integer data types - * - * \tparam T - */ -// #if NOT_EXCLUDED(OP_cyclic_shift_bits) - @Namespace("sd::ops") public static class cyclic_shift_bits extends BroadcastableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public cyclic_shift_bits(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public cyclic_shift_bits(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public cyclic_shift_bits position(long position) { - return (cyclic_shift_bits)super.position(position); - } - public cyclic_shift_bits() { super((Pointer)null); allocate(); } private native void allocate(); } -// #endif +// #endif + +/** + * This operation shift individual bits of each element in array, shifting to + * the right + * + * PLEASE NOTE: This operation is applicable only to integer data types + * + * \tparam T + */ +// #if NOT_EXCLUDED(OP_cyclic_rshift_bits) +@Namespace("sd::ops") public static class cyclic_rshift_bits extends BroadcastableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public cyclic_rshift_bits(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public cyclic_rshift_bits(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public cyclic_rshift_bits position(long position) { + return (cyclic_rshift_bits)super.position(position); + } - /** - * This operation shift individual bits of each element in array, shifting to the right - * - * PLEASE NOTE: This operation is applicable only to integer data types - * - * \tparam T - */ -// #if NOT_EXCLUDED(OP_cyclic_rshift_bits) - @Namespace("sd::ops") public static class cyclic_rshift_bits extends BroadcastableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public cyclic_rshift_bits(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public cyclic_rshift_bits(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public cyclic_rshift_bits position(long position) { - return (cyclic_rshift_bits)super.position(position); - } - public cyclic_rshift_bits() { super((Pointer)null); allocate(); } private native void allocate(); } -// #endif +// #endif + +/** + * This operation applies bitwise AND + * + * PLEASE NOTE: This operation is applicable only to integer data types + * + * \tparam T + */ +// #if NOT_EXCLUDED(OP_bitwise_and) +@Namespace("sd::ops") public static class bitwise_and extends BroadcastableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public bitwise_and(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public bitwise_and(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public bitwise_and position(long position) { + return (bitwise_and)super.position(position); + } - /** - * This operation applies bitwise AND - * - * PLEASE NOTE: This operation is applicable only to integer data types - * - * \tparam T - */ -// #if NOT_EXCLUDED(OP_bitwise_and) - @Namespace("sd::ops") public static class bitwise_and extends BroadcastableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public bitwise_and(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public bitwise_and(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public bitwise_and position(long position) { - return (bitwise_and)super.position(position); - } - public bitwise_and() { super((Pointer)null); allocate(); } private native void allocate(); } -// #endif +// #endif + +/** + * This operation applies bitwise OR + * + * PLEASE NOTE: This operation is applicable only to integer data types + * + * \tparam T + */ +// #if NOT_EXCLUDED(OP_bitwise_or) +@Namespace("sd::ops") public static class bitwise_or extends BroadcastableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public bitwise_or(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public bitwise_or(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public bitwise_or position(long position) { + return (bitwise_or)super.position(position); + } - /** - * This operation applies bitwise OR - * - * PLEASE NOTE: This operation is applicable only to integer data types - * - * \tparam T - */ -// #if NOT_EXCLUDED(OP_bitwise_or) - @Namespace("sd::ops") public static class bitwise_or extends BroadcastableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public bitwise_or(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public bitwise_or(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public bitwise_or position(long position) { - return (bitwise_or)super.position(position); - } - public bitwise_or() { super((Pointer)null); allocate(); } private native void allocate(); } -// #endif +// #endif + +/** + * This operation applies bitwise XOR + * + * PLEASE NOTE: This operation is applicable only to integer data types + * + * \tparam T + */ +// #if NOT_EXCLUDED(OP_bitwise_xor) +@Namespace("sd::ops") public static class bitwise_xor extends BroadcastableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public bitwise_xor(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public bitwise_xor(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public bitwise_xor position(long position) { + return (bitwise_xor)super.position(position); + } - /** - * This operation applies bitwise XOR - * - * PLEASE NOTE: This operation is applicable only to integer data types - * - * \tparam T - */ -// #if NOT_EXCLUDED(OP_bitwise_xor) - @Namespace("sd::ops") public static class bitwise_xor extends BroadcastableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public bitwise_xor(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public bitwise_xor(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public bitwise_xor position(long position) { - return (bitwise_xor)super.position(position); - } - public bitwise_xor() { super((Pointer)null); allocate(); } private native void allocate(); } -// #endif +// #endif + +/** + * This operation returns hamming distance based on bits + * + * PLEASE NOTE: This operation is applicable only to integer data types + * + * \tparam T + */ +// #if NOT_EXCLUDED(OP_bits_hamming_distance) +@Namespace("sd::ops") public static class bits_hamming_distance extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public bits_hamming_distance(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public bits_hamming_distance(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public bits_hamming_distance position(long position) { + return (bits_hamming_distance)super.position(position); + } - /** - * This operation returns hamming distance based on bits - * - * PLEASE NOTE: This operation is applicable only to integer data types - * - * \tparam T - */ -// #if NOT_EXCLUDED(OP_bits_hamming_distance) - @Namespace("sd::ops") public static class bits_hamming_distance extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public bits_hamming_distance(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public bits_hamming_distance(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public bits_hamming_distance position(long position) { - return (bits_hamming_distance)super.position(position); - } - public bits_hamming_distance() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif - - +// #endif + // namespace ops + // namespace sd // #endif @@ -23393,692 +25826,731 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #define LIBND4J_HEADERS_LOSS_H // #include - - ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of hinge loss function max(0, 1 - labels*logits) - * - * Input arrays: - * 0: logits - logits, type float - * 1: weights - is used for weighting (multiplying) of loss values, type float. - * Can be single scalar or has the same rank as labels and must be broadcastable to labels. - * 2: labels - ground truth vales, expected to be 0. or 1., type float. - * Must have the same shape as logits. - * - * Input integer arguments: - * 0: type of reduction to apply to loss - * 0 - "none", unreduced weighted losses with the same shape as logits. - * 1 - "weighted_sum", output is scalar and equal to sum of all elements of weightedLosses array - * 2 - "weighted_mean", output is scalar and equal to sum of all elements of weightedLosses array divided by sum of all elements of weightsBroad array - * 3 - "weighted_sum_by_nonzero_weights", output is scalar and equal to scalar sum of all elements of weightedLosses array divided by number of non-zero weights - * - * Output array: - * 0: loss values, type float. - * Can be an array with the same shape as logits or just single scalar, depending on reduction mode (see input integer argument) - */ -// #if NOT_EXCLUDED(OP_hinge_loss) - @Namespace("sd::ops") public static class hinge_loss extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public hinge_loss(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public hinge_loss(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public hinge_loss position(long position) { - return (hinge_loss)super.position(position); - } - + +////////////////////////////////////////////////////////////////////////// +/** + * Implementation of hinge loss function max(0, 1 - labels*logits) + * + * Input arrays: + * 0: logits - logits, type float + * 1: weights - is used for weighting (multiplying) of loss values, type + * float. Can be single scalar or has the same rank as labels and must be + * broadcastable to labels. 2: labels - ground truth vales, expected to be 0. + * or 1., type float. Must have the same shape as logits. + * + * Input integer arguments: + * 0: type of reduction to apply to loss + * 0 - "none", unreduced weighted losses with the same shape as logits. + * 1 - "weighted_sum", output is scalar and equal to sum of all elements + * of weightedLosses array 2 - "weighted_mean", output is scalar and equal to + * sum of all elements of weightedLosses array divided by sum of all elements of + * weightsBroad array 3 - "weighted_sum_by_nonzero_weights", output is scalar + * and equal to scalar sum of all elements of weightedLosses array divided by + * number of non-zero weights + * + * Output array: + * 0: loss values, type float. + * Can be an array with the same shape as logits or just single scalar, + * depending on reduction mode (see input integer argument) + */ +// #if NOT_EXCLUDED(OP_hinge_loss) +@Namespace("sd::ops") public static class hinge_loss extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public hinge_loss(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public hinge_loss(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public hinge_loss position(long position) { + return (hinge_loss)super.position(position); + } + public hinge_loss() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class hinge_loss_grad extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public hinge_loss_grad(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public hinge_loss_grad(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public hinge_loss_grad position(long position) { - return (hinge_loss_grad)super.position(position); - } - +@Namespace("sd::ops") public static class hinge_loss_grad extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public hinge_loss_grad(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public hinge_loss_grad(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public hinge_loss_grad position(long position) { + return (hinge_loss_grad)super.position(position); + } + public hinge_loss_grad() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif +////////////////////////////////////////////////////////////////////////// +/** + * Implementation of Huber loss function: + * 0.5 * (labels-predictions)^2 if + * |labels-predictions| <= delta 0.5 * delta^2 + delta * (|labels-predictions| - + * delta) if |labels-predictions| > delta + * + * Input arrays: + * 0: predictions - the predicted values, type float + * 1: weights - is used for weighting (multiplying) of loss values, type + * float. Can be single scalar or has the same rank as labels, and must be + * broadcastable to labels. 2: labels - ground truth vales, type float. Must + * have the same shape as predictions. + * + * Input integer arguments: + * 0: type of reduction to apply to loss + * 0 - "none", unreduced weighted losses with the same shape as + * predictions 1 - "weighted_sum", output is scalar and equal to sum of all + * elements of weightedLosses array 2 - "weighted_mean", output is scalar and + * equal to sum of all elements of weightedLosses array divided by sum of all + * elements of weightsBroad array 3 - "weighted_sum_by_nonzero_weights", output + * is scalar and equal to scalar sum of all elements of weightedLosses array + * divided by number of non-zero weights + * + * Input float arguments: + * 0: point where the huber loss function changes from a quadratic to linear. + * + * Output array: + * 0: loss values, type float. + * Can be an array with the same shape as predictions or just single + * scalar, depending on reduction mode (see input integer argument) + */ +// #if NOT_EXCLUDED(OP_huber_loss) +@Namespace("sd::ops") public static class huber_loss extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public huber_loss(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public huber_loss(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public huber_loss position(long position) { + return (huber_loss)super.position(position); + } - ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of Huber loss function: - * 0.5 * (labels-predictions)^2 if |labels-predictions| <= delta - * 0.5 * delta^2 + delta * (|labels-predictions| - delta) if |labels-predictions| > delta - * - * Input arrays: - * 0: predictions - the predicted values, type float - * 1: weights - is used for weighting (multiplying) of loss values, type float. - * Can be single scalar or has the same rank as labels, and must be broadcastable to labels. - * 2: labels - ground truth vales, type float. - * Must have the same shape as predictions. - * - * Input integer arguments: - * 0: type of reduction to apply to loss - * 0 - "none", unreduced weighted losses with the same shape as predictions - * 1 - "weighted_sum", output is scalar and equal to sum of all elements of weightedLosses array - * 2 - "weighted_mean", output is scalar and equal to sum of all elements of weightedLosses array divided by sum of all elements of weightsBroad array - * 3 - "weighted_sum_by_nonzero_weights", output is scalar and equal to scalar sum of all elements of weightedLosses array divided by number of non-zero weights - * - * Input float arguments: - * 0: point where the huber loss function changes from a quadratic to linear. - * - * Output array: - * 0: loss values, type float. - * Can be an array with the same shape as predictions or just single scalar, depending on reduction mode (see input integer argument) - */ -// #if NOT_EXCLUDED(OP_huber_loss) - @Namespace("sd::ops") public static class huber_loss extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public huber_loss(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public huber_loss(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public huber_loss position(long position) { - return (huber_loss)super.position(position); - } - public huber_loss() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class huber_loss_grad extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public huber_loss_grad(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public huber_loss_grad(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public huber_loss_grad position(long position) { - return (huber_loss_grad)super.position(position); - } - +@Namespace("sd::ops") public static class huber_loss_grad extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public huber_loss_grad(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public huber_loss_grad(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public huber_loss_grad position(long position) { + return (huber_loss_grad)super.position(position); + } + public huber_loss_grad() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +////////////////////////////////////////////////////////////////////////// +/** + * Implementation of logarithmic loss function ( y_i * log(p_i) + (1 - y_i) * + * log(1 - p_i) ) + * + * Input arrays: + * 0: predictions - the predicted values, type float + * 1: weights - is used for weighting (multiplying) of loss values, type + * float. Can be single scalar or has the same rank as labels, and must be + * broadcastable to labels. 2: labels - ground truth vales, type float. Must + * have the same shape as predictions. + * + * Input integer arguments: + * 0: type of reduction to apply to loss + * 0 - "none", unreduced weighted losses with the same shape as + * predictions 1 - "weighted_sum", output is scalar and equal to sum of all + * elements of weightedLosses array 2 - "weighted_mean", output is scalar and + * equal to sum of all elements of weightedLosses array divided by sum of all + * elements of weightsBroad array 3 - "weighted_sum_by_nonzero_weights", output + * is scalar and equal to scalar sum of all elements of weightedLosses array + * divided by number of non-zero weights + * + * Input float arguments: + * 0: a small increment to add to avoid taking a log of zero. + * + * Output array: + * 0: loss values, type float. + * Can be an array with the same shape as predictions or just single + * scalar, depending on reduction mode (see input integer argument) + */ +// #if NOT_EXCLUDED(OP_log_loss) +@Namespace("sd::ops") public static class log_loss extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public log_loss(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public log_loss(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public log_loss position(long position) { + return (log_loss)super.position(position); + } - - ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of logarithmic loss function ( y_i * log(p_i) + (1 - y_i) * log(1 - p_i) ) - * - * Input arrays: - * 0: predictions - the predicted values, type float - * 1: weights - is used for weighting (multiplying) of loss values, type float. - * Can be single scalar or has the same rank as labels, and must be broadcastable to labels. - * 2: labels - ground truth vales, type float. - * Must have the same shape as predictions. - * - * Input integer arguments: - * 0: type of reduction to apply to loss - * 0 - "none", unreduced weighted losses with the same shape as predictions - * 1 - "weighted_sum", output is scalar and equal to sum of all elements of weightedLosses array - * 2 - "weighted_mean", output is scalar and equal to sum of all elements of weightedLosses array divided by sum of all elements of weightsBroad array - * 3 - "weighted_sum_by_nonzero_weights", output is scalar and equal to scalar sum of all elements of weightedLosses array divided by number of non-zero weights - * - * Input float arguments: - * 0: a small increment to add to avoid taking a log of zero. - * - * Output array: - * 0: loss values, type float. - * Can be an array with the same shape as predictions or just single scalar, depending on reduction mode (see input integer argument) - */ -// #if NOT_EXCLUDED(OP_log_loss) - @Namespace("sd::ops") public static class log_loss extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public log_loss(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public log_loss(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public log_loss position(long position) { - return (log_loss)super.position(position); - } - public log_loss() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class log_loss_grad extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public log_loss_grad(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public log_loss_grad(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public log_loss_grad position(long position) { - return (log_loss_grad)super.position(position); - } - +@Namespace("sd::ops") public static class log_loss_grad extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public log_loss_grad(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public log_loss_grad(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public log_loss_grad position(long position) { + return (log_loss_grad)super.position(position); + } + public log_loss_grad() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * l2_loss op. + * compute a l2 norm for given array. + * + * input param - an array (tensor) + * output value - a real number with given type (e.g. float or double) + */ +// #if NOT_EXCLUDED(OP_l2_loss) +@Namespace("sd::ops") public static class l2_loss extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public l2_loss(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public l2_loss(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public l2_loss position(long position) { + return (l2_loss)super.position(position); + } - /** - * l2_loss op. - * compute a l2 norm for given array. - * - * input param - an array (tensor) - * output value - a real number with given type (e.g. float or double) - */ -// #if NOT_EXCLUDED(OP_l2_loss) - @Namespace("sd::ops") public static class l2_loss extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public l2_loss(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public l2_loss(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public l2_loss position(long position) { - return (l2_loss)super.position(position); - } - public l2_loss() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif +/** + * This op calculates logarithmic loss of poisson distributed input. + * Input arrays: + * 0: log_predictions - must be already pre-transformed to log(x) + * 1: weights - is used for weighting (multiplying) of loss values, type + * float. Can be single scalar or has the same rank as labels and must be + * broadcastable to labels. 2: labels - ground truth vales, expected to be 0. + * or 1., type float. Must have the same shape as logits. + * + * Input integer arguments: + * 0: type of reduction to apply to loss + * 0 - "none", unreduced weighted losses with the same shape as logits. + * 1 - "weighted_sum", output is scalar and equal to sum of all elements + * of weightedLosses array 2 - "weighted_mean", output is scalar and equal to + * sum of all elements of weightedLosses array divided by sum of all elements of + * weightsBroad array 3 - "weighted_sum_by_nonzero_weights", output is scalar + * and equal to scalar sum of all elements of weightedLosses array divided by + * number of non-zero weights 1: optional - boolean value compute_full_loss: 0 + * (default) or 1 (compute) + * + * Output array: + * 0: loss values, type float. + * Can be an array with the same shape as log_predictions or just single + * scalar, depending on reduction mode (see input integer argument) + */ +// #if NOT_EXCLUDED(OP_log_poisson_loss) +@Namespace("sd::ops") public static class log_poisson_loss extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public log_poisson_loss(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public log_poisson_loss(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public log_poisson_loss position(long position) { + return (log_poisson_loss)super.position(position); + } - /** - * This op calculates logarithmic loss of poisson distributed input. - * Input arrays: - * 0: log_predictions - must be already pre-transformed to log(x) - * 1: weights - is used for weighting (multiplying) of loss values, type float. - * Can be single scalar or has the same rank as labels and must be broadcastable to labels. - * 2: labels - ground truth vales, expected to be 0. or 1., type float. - * Must have the same shape as logits. - * - * Input integer arguments: - * 0: type of reduction to apply to loss - * 0 - "none", unreduced weighted losses with the same shape as logits. - * 1 - "weighted_sum", output is scalar and equal to sum of all elements of weightedLosses array - * 2 - "weighted_mean", output is scalar and equal to sum of all elements of weightedLosses array divided by sum of all elements of weightsBroad array - * 3 - "weighted_sum_by_nonzero_weights", output is scalar and equal to scalar sum of all elements of weightedLosses array divided by number of non-zero weights - * 1: optional - boolean value compute_full_loss: 0 (default) or 1 (compute) - * - * Output array: - * 0: loss values, type float. - * Can be an array with the same shape as log_predictions or just single scalar, depending on reduction mode (see input integer argument) - */ -// #if NOT_EXCLUDED(OP_log_poisson_loss) - @Namespace("sd::ops") public static class log_poisson_loss extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public log_poisson_loss(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public log_poisson_loss(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public log_poisson_loss position(long position) { - return (log_poisson_loss)super.position(position); - } - public log_poisson_loss() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class log_poisson_loss_grad extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public log_poisson_loss_grad(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public log_poisson_loss_grad(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public log_poisson_loss_grad position(long position) { - return (log_poisson_loss_grad)super.position(position); - } - +@Namespace("sd::ops") public static class log_poisson_loss_grad extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public log_poisson_loss_grad(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public log_poisson_loss_grad(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public log_poisson_loss_grad position(long position) { + return (log_poisson_loss_grad)super.position(position); + } + public log_poisson_loss_grad() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +////////////////////////////////////////////////////////////////////////// +/** + * Implementation of pairwise-errors-squared loss function + * + * Input arrays: + * 0: predictions - the predicted values, type float. + * 1: weights - is used for weighting (multiplying) of loss values, type + * float. Can be single scalar or has the same rank as labels and must be + * broadcastable to labels. 2: labels - ground truth vales, type float. Must + * have the same shape as predictions. + * + * Output array: + * 0: loss value, it is just single scalar, type float. + */ +// #if NOT_EXCLUDED(OP_mean_pairwssqerr_loss) +@Namespace("sd::ops") public static class mean_pairwssqerr_loss extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public mean_pairwssqerr_loss(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public mean_pairwssqerr_loss(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public mean_pairwssqerr_loss position(long position) { + return (mean_pairwssqerr_loss)super.position(position); + } - ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of pairwise-errors-squared loss function - * - * Input arrays: - * 0: predictions - the predicted values, type float. - * 1: weights - is used for weighting (multiplying) of loss values, type float. - * Can be single scalar or has the same rank as labels and must be broadcastable to labels. - * 2: labels - ground truth vales, type float. - * Must have the same shape as predictions. - * - * Output array: - * 0: loss value, it is just single scalar, type float. - */ -// #if NOT_EXCLUDED(OP_mean_pairwssqerr_loss) - @Namespace("sd::ops") public static class mean_pairwssqerr_loss extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public mean_pairwssqerr_loss(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public mean_pairwssqerr_loss(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public mean_pairwssqerr_loss position(long position) { - return (mean_pairwssqerr_loss)super.position(position); - } - public mean_pairwssqerr_loss() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class mean_pairwssqerr_loss_grad extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public mean_pairwssqerr_loss_grad(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public mean_pairwssqerr_loss_grad(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public mean_pairwssqerr_loss_grad position(long position) { - return (mean_pairwssqerr_loss_grad)super.position(position); - } - +@Namespace("sd::ops") public static class mean_pairwssqerr_loss_grad extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public mean_pairwssqerr_loss_grad(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public mean_pairwssqerr_loss_grad(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public mean_pairwssqerr_loss_grad position(long position) { + return (mean_pairwssqerr_loss_grad)super.position(position); + } + public mean_pairwssqerr_loss_grad() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +////////////////////////////////////////////////////////////////////////// +/** + * Implementation of Sum-of-Squares loss function 1/N * + * sum_{i}^{N}(predictions_i - labels_i)^2 + * + * Input arrays: + * 0: predictions - the predicted values, type float + * 1: weights - is used for weighting (multiplying) of loss values, type + * float. Can be single scalar or has the same rank as labels and must be + * broadcastable to labels. 2: labels - ground truth vales, type float. Must + * have the same shape as predictions. + * + * Input integer arguments: + * 0: type of reduction to apply to loss + * 0 - "none", unreduced weighted losses with the same shape as + * predictions 1 - "weighted_sum", output is scalar and equal to sum of all + * elements of weightedLosses array 2 - "weighted_mean", output is scalar and + * equal to sum of all elements of weightedLosses array divided by sum of all + * elements of weightsBroad array 3 - "weighted_sum_by_nonzero_weights", output + * is scalar and equal to scalar sum of all elements of weightedLosses array + * divided by number of non-zero weights + * + * Output array: + * 0: loss values, type float. + * Can be an array with the same shape as predictions or just single + * scalar, depending on reduction mode (see input integer argument) + */ +// #if NOT_EXCLUDED(OP_mean_sqerr_loss) +@Namespace("sd::ops") public static class mean_sqerr_loss extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public mean_sqerr_loss(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public mean_sqerr_loss(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public mean_sqerr_loss position(long position) { + return (mean_sqerr_loss)super.position(position); + } - - ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of Sum-of-Squares loss function 1/N * sum_{i}^{N}(predictions_i - labels_i)^2 - * - * Input arrays: - * 0: predictions - the predicted values, type float - * 1: weights - is used for weighting (multiplying) of loss values, type float. - * Can be single scalar or has the same rank as labels and must be broadcastable to labels. - * 2: labels - ground truth vales, type float. - * Must have the same shape as predictions. - * - * Input integer arguments: - * 0: type of reduction to apply to loss - * 0 - "none", unreduced weighted losses with the same shape as predictions - * 1 - "weighted_sum", output is scalar and equal to sum of all elements of weightedLosses array - * 2 - "weighted_mean", output is scalar and equal to sum of all elements of weightedLosses array divided by sum of all elements of weightsBroad array - * 3 - "weighted_sum_by_nonzero_weights", output is scalar and equal to scalar sum of all elements of weightedLosses array divided by number of non-zero weights - * - * Output array: - * 0: loss values, type float. - * Can be an array with the same shape as predictions or just single scalar, depending on reduction mode (see input integer argument) - */ -// #if NOT_EXCLUDED(OP_mean_sqerr_loss) - @Namespace("sd::ops") public static class mean_sqerr_loss extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public mean_sqerr_loss(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public mean_sqerr_loss(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public mean_sqerr_loss position(long position) { - return (mean_sqerr_loss)super.position(position); - } - public mean_sqerr_loss() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class mean_sqerr_loss_grad extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public mean_sqerr_loss_grad(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public mean_sqerr_loss_grad(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public mean_sqerr_loss_grad position(long position) { - return (mean_sqerr_loss_grad)super.position(position); - } - +@Namespace("sd::ops") public static class mean_sqerr_loss_grad extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public mean_sqerr_loss_grad(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public mean_sqerr_loss_grad(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public mean_sqerr_loss_grad position(long position) { + return (mean_sqerr_loss_grad)super.position(position); + } + public mean_sqerr_loss_grad() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif +////////////////////////////////////////////////////////////////////////// +/** + * Implementation of sigmoid cross-entropy loss function max(logits, 0.) - + * logits * labels + log(1. + exp(-abs(logits))); + * + * Input arrays: + * 0: logits - logits, type float + * 1: weights - is used for weighting (multiplying) of loss values, type + * float. Can be single scalar or has the same rank as labels, and must be + * broadcastable to labels. 2: labels - ground truth vales, expected to be 0. + * or 1., type float. Must have the same shape as logits. + * + * Input integer arguments: + * 0: type of reduction to apply to loss + * 0 - "none", unreduced weighted losses with the same shape as logits. + * 1 - "weighted_sum", output is scalar and equal to sum of all elements + * of weightedLosses array 2 - "weighted_mean", output is scalar and equal to + * sum of all elements of weightedLosses array divided by sum of all elements of + * weightsBroad array 3 - "weighted_sum_by_nonzero_weights", output is scalar + * and equal to scalar sum of all elements of weightedLosses array divided by + * number of non-zero weights + * + * Input float arguments: + * 0: smoothing value, if it is greater than 0 then apply smoothing to the + * labels (smooth the labels towards 1/2): new_labels = labels * (1 - + * labelsSmoothing)+ 0.5 * labelsSmoothing + * + * Output array: + * 0: loss values, type float. + * Can be an array with the same shape as logits or just single scalar, + * depending on reduction mode (see input integer argument) + */ +// #if NOT_EXCLUDED(OP_sigm_cross_entropy_loss) +@Namespace("sd::ops") public static class sigm_cross_entropy_loss extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public sigm_cross_entropy_loss(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public sigm_cross_entropy_loss(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public sigm_cross_entropy_loss position(long position) { + return (sigm_cross_entropy_loss)super.position(position); + } - ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of sigmoid cross-entropy loss function max(logits, 0.) - logits * labels + log(1. + exp(-abs(logits))); - * - * Input arrays: - * 0: logits - logits, type float - * 1: weights - is used for weighting (multiplying) of loss values, type float. - * Can be single scalar or has the same rank as labels, and must be broadcastable to labels. - * 2: labels - ground truth vales, expected to be 0. or 1., type float. - * Must have the same shape as logits. - * - * Input integer arguments: - * 0: type of reduction to apply to loss - * 0 - "none", unreduced weighted losses with the same shape as logits. - * 1 - "weighted_sum", output is scalar and equal to sum of all elements of weightedLosses array - * 2 - "weighted_mean", output is scalar and equal to sum of all elements of weightedLosses array divided by sum of all elements of weightsBroad array - * 3 - "weighted_sum_by_nonzero_weights", output is scalar and equal to scalar sum of all elements of weightedLosses array divided by number of non-zero weights - * - * Input float arguments: - * 0: smoothing value, if it is greater than 0 then apply smoothing to the labels (smooth the labels towards 1/2): new_labels = labels * (1 - labelsSmoothing)+ 0.5 * labelsSmoothing - * - * Output array: - * 0: loss values, type float. - * Can be an array with the same shape as logits or just single scalar, depending on reduction mode (see input integer argument) - */ -// #if NOT_EXCLUDED(OP_sigm_cross_entropy_loss) - @Namespace("sd::ops") public static class sigm_cross_entropy_loss extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public sigm_cross_entropy_loss(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public sigm_cross_entropy_loss(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public sigm_cross_entropy_loss position(long position) { - return (sigm_cross_entropy_loss)super.position(position); - } - public sigm_cross_entropy_loss() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class sigm_cross_entropy_loss_grad extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public sigm_cross_entropy_loss_grad(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public sigm_cross_entropy_loss_grad(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public sigm_cross_entropy_loss_grad position(long position) { - return (sigm_cross_entropy_loss_grad)super.position(position); - } - +@Namespace("sd::ops") public static class sigm_cross_entropy_loss_grad extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public sigm_cross_entropy_loss_grad(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public sigm_cross_entropy_loss_grad(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public sigm_cross_entropy_loss_grad position(long position) { + return (sigm_cross_entropy_loss_grad)super.position(position); + } + public sigm_cross_entropy_loss_grad() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif - +// #endif + +////////////////////////////////////////////////////////////////////////// +/** + * Implementation of softmax cross-entropy loss function max(logits, 0.) - + * logits * labels + log(1. + exp(-abs(logits))); + * + * Input arrays: + * 0: logits - logits, type float + * 1: weights - is used for weighting (multiplying) of loss values, type + * float. Can be single scalar or has the same rank as labels, and must be + * broadcastable to labels. 2: labels - ground truth vales, expected to be 0. + * or 1., type float. Must have the same shape as logits. + * + * Input integer arguments: + * 0: type of reduction to apply to loss + * 0 - "none", unreduced weighted losses with the same shape as logits. + * 1 - "weighted_sum", output is scalar and equal to sum of all elements + * of weightedLosses array 2 - "weighted_mean", output is scalar and equal to + * sum of all elements of weightedLosses array divided by sum of all elements of + * weightsBroad array 3 - "weighted_sum_by_nonzero_weights", output is scalar + * and equal to scalar sum of all elements of weightedLosses array divided by + * number of non-zero weights + * + * Input float arguments: + * 0: smoothing value, if it is greater than 0 then apply smoothing to the + * labels (smooth the labels towards 1/numClasses): new_labels = labels * (1 - + * labelsSmoothing) + labelsSmoothing / numClasses + * + * Output array: + * 0: loss values, type float. + * Can be an array with shape as in logits except last dimension is equal + * to unity or just single scalar, depending on reduction mode (see input + * integer argument) + */ +// #if NOT_EXCLUDED(OP_softmax_cross_entropy_loss) +@Namespace("sd::ops") public static class softmax_cross_entropy_loss extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public softmax_cross_entropy_loss(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public softmax_cross_entropy_loss(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public softmax_cross_entropy_loss position(long position) { + return (softmax_cross_entropy_loss)super.position(position); + } - ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of softmax cross-entropy loss function max(logits, 0.) - logits * labels + log(1. + exp(-abs(logits))); - * - * Input arrays: - * 0: logits - logits, type float - * 1: weights - is used for weighting (multiplying) of loss values, type float. - * Can be single scalar or has the same rank as labels, and must be broadcastable to labels. - * 2: labels - ground truth vales, expected to be 0. or 1., type float. - * Must have the same shape as logits. - * - * Input integer arguments: - * 0: type of reduction to apply to loss - * 0 - "none", unreduced weighted losses with the same shape as logits. - * 1 - "weighted_sum", output is scalar and equal to sum of all elements of weightedLosses array - * 2 - "weighted_mean", output is scalar and equal to sum of all elements of weightedLosses array divided by sum of all elements of weightsBroad array - * 3 - "weighted_sum_by_nonzero_weights", output is scalar and equal to scalar sum of all elements of weightedLosses array divided by number of non-zero weights - * - * Input float arguments: - * 0: smoothing value, if it is greater than 0 then apply smoothing to the labels (smooth the labels towards 1/numClasses): new_labels = labels * (1 - labelsSmoothing) + labelsSmoothing / numClasses - * - * Output array: - * 0: loss values, type float. - * Can be an array with shape as in logits except last dimension is equal to unity or just single scalar, depending on reduction mode (see input integer argument) - */ -// #if NOT_EXCLUDED(OP_softmax_cross_entropy_loss) - @Namespace("sd::ops") public static class softmax_cross_entropy_loss extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public softmax_cross_entropy_loss(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public softmax_cross_entropy_loss(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public softmax_cross_entropy_loss position(long position) { - return (softmax_cross_entropy_loss)super.position(position); - } - public softmax_cross_entropy_loss() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); - } - @Namespace("sd::ops") public static class softmax_cross_entropy_loss_grad extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public softmax_cross_entropy_loss_grad(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public softmax_cross_entropy_loss_grad(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public softmax_cross_entropy_loss_grad position(long position) { - return (softmax_cross_entropy_loss_grad)super.position(position); - } - + } +@Namespace("sd::ops") public static class softmax_cross_entropy_loss_grad extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public softmax_cross_entropy_loss_grad(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public softmax_cross_entropy_loss_grad(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public softmax_cross_entropy_loss_grad position(long position) { + return (softmax_cross_entropy_loss_grad)super.position(position); + } + public softmax_cross_entropy_loss_grad() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); - } -// #endif + } +// #endif +////////////////////////////////////////////////////////////////////////// +/** + * Implementation of Absolute Difference loss function |predictions - labels| + * + * Input arrays: + * 0: predictions - the predicted values, type float. + * 1: weights - is used for weighting (multiplying) of loss values, type + * float. Can be single scalar or has the same rank as labels and must be + * broadcastable to labels. 2: labels - ground truth vales, type float. Must + * have the same shape as predictions. + * + * Input integer arguments: + * 0: type of reduction to apply to loss + * 0 - "none", unreduced weighted losses with the same shape as + * predictions 1 - "weighted_sum", output is scalar and equal to sum of all + * elements of weightedLosses array 2 - "weighted_mean", output is scalar and + * equal to sum of all elements of weightedLosses array divided by sum of all + * elements of weightsBroad array 3 - "weighted_sum_by_nonzero_weights", output + * is scalar and equal to scalar sum of all elements of weightedLosses array + * divided by number of non-zero weights + * + * Output array: + * 0: loss values, type float. + * Can be an array with the same shape as predictions or just single + * scalar, depending on reduction mode (see input integer argument) + */ +// #if NOT_EXCLUDED(OP_absolute_difference_loss) +@Namespace("sd::ops") public static class absolute_difference_loss extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public absolute_difference_loss(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public absolute_difference_loss(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public absolute_difference_loss position(long position) { + return (absolute_difference_loss)super.position(position); + } - ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of Absolute Difference loss function |predictions - labels| - * - * Input arrays: - * 0: predictions - the predicted values, type float. - * 1: weights - is used for weighting (multiplying) of loss values, type float. - * Can be single scalar or has the same rank as labels and must be broadcastable to labels. - * 2: labels - ground truth vales, type float. - * Must have the same shape as predictions. - * - * Input integer arguments: - * 0: type of reduction to apply to loss - * 0 - "none", unreduced weighted losses with the same shape as predictions - * 1 - "weighted_sum", output is scalar and equal to sum of all elements of weightedLosses array - * 2 - "weighted_mean", output is scalar and equal to sum of all elements of weightedLosses array divided by sum of all elements of weightsBroad array - * 3 - "weighted_sum_by_nonzero_weights", output is scalar and equal to scalar sum of all elements of weightedLosses array divided by number of non-zero weights - * - * Output array: - * 0: loss values, type float. - * Can be an array with the same shape as predictions or just single scalar, depending on reduction mode (see input integer argument) - */ -// #if NOT_EXCLUDED(OP_absolute_difference_loss) - @Namespace("sd::ops") public static class absolute_difference_loss extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public absolute_difference_loss(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public absolute_difference_loss(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public absolute_difference_loss position(long position) { - return (absolute_difference_loss)super.position(position); - } - public absolute_difference_loss() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class absolute_difference_loss_grad extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public absolute_difference_loss_grad(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public absolute_difference_loss_grad(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public absolute_difference_loss_grad position(long position) { - return (absolute_difference_loss_grad)super.position(position); - } - +@Namespace("sd::ops") public static class absolute_difference_loss_grad extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public absolute_difference_loss_grad(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public absolute_difference_loss_grad(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public absolute_difference_loss_grad position(long position) { + return (absolute_difference_loss_grad)super.position(position); + } + public absolute_difference_loss_grad() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif +////////////////////////////////////////////////////////////////////////// +/** + * Implementation of cosine-distance loss function 1. - (predictions * + * labels).reduce_sum_along(dimension) + * + * Input arrays: + * 0: predictions - the predicted values, type float + * 1: weights - is used for weighting (multiplying) of loss values, type + * float. Can be single scalar or has the same rank as labels and must be + * broadcastable to labels. 2: labels - ground truth vales, type float. Must + * have the same shape as predictions. + * + * Input integer arguments: + * 0: type of reduction to apply to loss + * 0 - "none", unreduced weighted losses with the same shape as + * predictions 1 - "weighted_sum", output is scalar and equal to sum of all + * elements of weightedLosses array 2 - "weighted_mean", output is scalar and + * equal to sum of all elements of weightedLosses array divided by sum of all + * elements of weightsBroad array 3 - "weighted_sum_by_nonzero_weights", output + * is scalar and equal to scalar sum of all elements of weightedLosses array + * divided by number of non-zero weights 1: dimension along which the cosine + * distance is computed. + * + * Output array: + * 0: loss values, type float. + * Can be an array with the same shape as predictions or just single + * scalar, depending on reduction mode (see input integer argument) + */ +// #if NOT_EXCLUDED(OP_cosine_distance_loss) +@Namespace("sd::ops") public static class cosine_distance_loss extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public cosine_distance_loss(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public cosine_distance_loss(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public cosine_distance_loss position(long position) { + return (cosine_distance_loss)super.position(position); + } - ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of cosine-distance loss function 1. - (predictions * labels).reduce_sum_along(dimension) - * - * Input arrays: - * 0: predictions - the predicted values, type float - * 1: weights - is used for weighting (multiplying) of loss values, type float. - * Can be single scalar or has the same rank as labels and must be broadcastable to labels. - * 2: labels - ground truth vales, type float. - * Must have the same shape as predictions. - * - * Input integer arguments: - * 0: type of reduction to apply to loss - * 0 - "none", unreduced weighted losses with the same shape as predictions - * 1 - "weighted_sum", output is scalar and equal to sum of all elements of weightedLosses array - * 2 - "weighted_mean", output is scalar and equal to sum of all elements of weightedLosses array divided by sum of all elements of weightsBroad array - * 3 - "weighted_sum_by_nonzero_weights", output is scalar and equal to scalar sum of all elements of weightedLosses array divided by number of non-zero weights - * 1: dimension along which the cosine distance is computed. - * - * Output array: - * 0: loss values, type float. - * Can be an array with the same shape as predictions or just single scalar, depending on reduction mode (see input integer argument) - */ -// #if NOT_EXCLUDED(OP_cosine_distance_loss) - @Namespace("sd::ops") public static class cosine_distance_loss extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public cosine_distance_loss(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public cosine_distance_loss(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public cosine_distance_loss position(long position) { - return (cosine_distance_loss)super.position(position); - } - public cosine_distance_loss() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class cosine_distance_loss_grad extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public cosine_distance_loss_grad(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public cosine_distance_loss_grad(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public cosine_distance_loss_grad position(long position) { - return (cosine_distance_loss_grad)super.position(position); - } - +@Namespace("sd::ops") public static class cosine_distance_loss_grad extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public cosine_distance_loss_grad(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public cosine_distance_loss_grad(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public cosine_distance_loss_grad position(long position) { + return (cosine_distance_loss_grad)super.position(position); + } + public cosine_distance_loss_grad() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +////////////////////////////////////////////////////////////////////////// +/** + * Implementation of softmax cross-entropy loss function + * + * Input arrays: + * 0: logits - logits, type float + * 1: labels - ground truth vales, expected to be 0. or 1., type float. + * Must have the same shape as logits. + * + * Input integer arguments: + * 0: optional (default is last dimension) dimension with classes + * + * Output array: + * 0: loss values, type float. An array with shape resulting from reducing of + * logits shape along dimension with classes + */ +// #if NOT_EXCLUDED(OP_softmax_cross_entropy_loss_with_logits) +@Namespace("sd::ops") public static class softmax_cross_entropy_loss_with_logits extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public softmax_cross_entropy_loss_with_logits(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public softmax_cross_entropy_loss_with_logits(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public softmax_cross_entropy_loss_with_logits position(long position) { + return (softmax_cross_entropy_loss_with_logits)super.position(position); + } - ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of softmax cross-entropy loss function - * - * Input arrays: - * 0: logits - logits, type float - * 1: labels - ground truth vales, expected to be 0. or 1., type float. - * Must have the same shape as logits. - * - * Input integer arguments: - * 0: optional (default is last dimension) dimension with classes - * - * Output array: - * 0: loss values, type float. An array with shape resulting from reducing of logits shape along dimension with classes - */ -// #if NOT_EXCLUDED(OP_softmax_cross_entropy_loss_with_logits) - @Namespace("sd::ops") public static class softmax_cross_entropy_loss_with_logits extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public softmax_cross_entropy_loss_with_logits(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public softmax_cross_entropy_loss_with_logits(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public softmax_cross_entropy_loss_with_logits position(long position) { - return (softmax_cross_entropy_loss_with_logits)super.position(position); - } - public softmax_cross_entropy_loss_with_logits() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class softmax_cross_entropy_loss_with_logits_grad extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public softmax_cross_entropy_loss_with_logits_grad(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public softmax_cross_entropy_loss_with_logits_grad(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public softmax_cross_entropy_loss_with_logits_grad position(long position) { - return (softmax_cross_entropy_loss_with_logits_grad)super.position(position); - } - +@Namespace("sd::ops") public static class softmax_cross_entropy_loss_with_logits_grad extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public softmax_cross_entropy_loss_with_logits_grad(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public softmax_cross_entropy_loss_with_logits_grad(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public softmax_cross_entropy_loss_with_logits_grad position(long position) { + return (softmax_cross_entropy_loss_with_logits_grad)super.position(position); + } + public softmax_cross_entropy_loss_with_logits_grad() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +////////////////////////////////////////////////////////////////////////// +/** + * Implementation of sparse softmax cross-entropy loss function + * + * Input arrays: + * 0: labels - ground truth vales, expected to be within range [0, + * num_classes), type float. Must have rank equal logits rank minus 1. 1: logits + * - logits, type float + * + * Output array: + * 0: loss values, type float. Has the same shape as labels + */ +// #if NOT_EXCLUDED(OP_sparse_softmax_cross_entropy_loss_with_logits) +@Namespace("sd::ops") public static class sparse_softmax_cross_entropy_loss_with_logits extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public sparse_softmax_cross_entropy_loss_with_logits(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public sparse_softmax_cross_entropy_loss_with_logits(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public sparse_softmax_cross_entropy_loss_with_logits position(long position) { + return (sparse_softmax_cross_entropy_loss_with_logits)super.position(position); + } - ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of sparse softmax cross-entropy loss function - * - * Input arrays: - * 0: labels - ground truth vales, expected to be within range [0, num_classes), type float. - * Must have rank equal logits rank minus 1. - * 1: logits - logits, type float - * - * Output array: - * 0: loss values, type float. Has the same shape as labels - */ -// #if NOT_EXCLUDED(OP_sparse_softmax_cross_entropy_loss_with_logits) - @Namespace("sd::ops") public static class sparse_softmax_cross_entropy_loss_with_logits extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public sparse_softmax_cross_entropy_loss_with_logits(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public sparse_softmax_cross_entropy_loss_with_logits(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public sparse_softmax_cross_entropy_loss_with_logits position(long position) { - return (sparse_softmax_cross_entropy_loss_with_logits)super.position(position); - } - public sparse_softmax_cross_entropy_loss_with_logits() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class sparse_softmax_cross_entropy_loss_with_logits_grad extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public sparse_softmax_cross_entropy_loss_with_logits_grad(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public sparse_softmax_cross_entropy_loss_with_logits_grad(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public sparse_softmax_cross_entropy_loss_with_logits_grad position(long position) { - return (sparse_softmax_cross_entropy_loss_with_logits_grad)super.position(position); - } - +@Namespace("sd::ops") public static class sparse_softmax_cross_entropy_loss_with_logits_grad extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public sparse_softmax_cross_entropy_loss_with_logits_grad(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public sparse_softmax_cross_entropy_loss_with_logits_grad(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public sparse_softmax_cross_entropy_loss_with_logits_grad position(long position) { + return (sparse_softmax_cross_entropy_loss_with_logits_grad)super.position(position); + } + public sparse_softmax_cross_entropy_loss_with_logits_grad() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif - - - +// #endif + // namespace ops + // namespace sd // #endif @@ -24106,218 +26578,221 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #ifndef LIBND4J_HEADERS_DTYPE_H // #define LIBND4J_HEADERS_DTYPE_H -// #include - /** - * This operation casts elements of input array to double data type - * - * PLEASE NOTE: This op is disabled atm, and reserved for future releases. - */ -// #if NOT_EXCLUDED(OP_to_double) - @Namespace("sd::ops") public static class to_double extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public to_double(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public to_double(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public to_double position(long position) { - return (to_double)super.position(position); - } - +// #include +/** + * This operation casts elements of input array to double data type + * + * PLEASE NOTE: This op is disabled atm, and reserved for future releases. + */ +// #if NOT_EXCLUDED(OP_to_double) +@Namespace("sd::ops") public static class to_double extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public to_double(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public to_double(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public to_double position(long position) { + return (to_double)super.position(position); + } + public to_double() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This operation casts elements of input array to float16 data type + * + * PLEASE NOTE: This op is disabled atm, and reserved for future releases. + */ +// #if NOT_EXCLUDED(OP_to_float16) +@Namespace("sd::ops") public static class to_float16 extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public to_float16(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public to_float16(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public to_float16 position(long position) { + return (to_float16)super.position(position); + } - /** - * This operation casts elements of input array to float16 data type - * - * PLEASE NOTE: This op is disabled atm, and reserved for future releases. - */ -// #if NOT_EXCLUDED(OP_to_float16) - @Namespace("sd::ops") public static class to_float16 extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public to_float16(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public to_float16(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public to_float16 position(long position) { - return (to_float16)super.position(position); - } - public to_float16() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This operation casts elements of input array to float data type + * + * PLEASE NOTE: This op is disabled atm, and reserved for future releases. + */ +// #if NOT_EXCLUDED(OP_to_float32) +@Namespace("sd::ops") public static class to_float32 extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public to_float32(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public to_float32(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public to_float32 position(long position) { + return (to_float32)super.position(position); + } - /** - * This operation casts elements of input array to float data type - * - * PLEASE NOTE: This op is disabled atm, and reserved for future releases. - */ -// #if NOT_EXCLUDED(OP_to_float32) - @Namespace("sd::ops") public static class to_float32 extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public to_float32(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public to_float32(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public to_float32 position(long position) { - return (to_float32)super.position(position); - } - public to_float32() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This operation casts elements of input array to int32 data type + * + * PLEASE NOTE: This op is disabled atm, and reserved for future releases. + */ +// #if NOT_EXCLUDED(OP_to_int32) +@Namespace("sd::ops") public static class to_int32 extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public to_int32(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public to_int32(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public to_int32 position(long position) { + return (to_int32)super.position(position); + } - /** - * This operation casts elements of input array to int32 data type - * - * PLEASE NOTE: This op is disabled atm, and reserved for future releases. - */ -// #if NOT_EXCLUDED(OP_to_int32) - @Namespace("sd::ops") public static class to_int32 extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public to_int32(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public to_int32(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public to_int32 position(long position) { - return (to_int32)super.position(position); - } - public to_int32() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This operation casts elements of input array to int64 (aka long long) data + * type + * + * PLEASE NOTE: This op is disabled atm, and reserved for future releases. + */ +// #if NOT_EXCLUDED(OP_to_int64) +@Namespace("sd::ops") public static class to_int64 extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public to_int64(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public to_int64(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public to_int64 position(long position) { + return (to_int64)super.position(position); + } - /** - * This operation casts elements of input array to int64 (aka long long) data type - * - * PLEASE NOTE: This op is disabled atm, and reserved for future releases. - */ -// #if NOT_EXCLUDED(OP_to_int64) - @Namespace("sd::ops") public static class to_int64 extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public to_int64(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public to_int64(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public to_int64 position(long position) { - return (to_int64)super.position(position); - } - public to_int64() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This operation casts elements of input array to unsinged int32 data type + * + * PLEASE NOTE: This op is disabled atm, and reserved for future releases. + */ +// #if NOT_EXCLUDED(OP_to_uint32) +@Namespace("sd::ops") public static class to_uint32 extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public to_uint32(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public to_uint32(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public to_uint32 position(long position) { + return (to_uint32)super.position(position); + } - /** - * This operation casts elements of input array to unsinged int32 data type - * - * PLEASE NOTE: This op is disabled atm, and reserved for future releases. - */ -// #if NOT_EXCLUDED(OP_to_uint32) - @Namespace("sd::ops") public static class to_uint32 extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public to_uint32(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public to_uint32(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public to_uint32 position(long position) { - return (to_uint32)super.position(position); - } - public to_uint32() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This operation casts elements of input array to unsigned int64 (aka unsigned + * long long) data type + * + * PLEASE NOTE: This op is disabled atm, and reserved for future releases. + */ +// #if NOT_EXCLUDED(OP_to_uint64) +@Namespace("sd::ops") public static class to_uint64 extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public to_uint64(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public to_uint64(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public to_uint64 position(long position) { + return (to_uint64)super.position(position); + } - /** - * This operation casts elements of input array to unsigned int64 (aka unsigned long long) data type - * - * PLEASE NOTE: This op is disabled atm, and reserved for future releases. - */ -// #if NOT_EXCLUDED(OP_to_uint64) - @Namespace("sd::ops") public static class to_uint64 extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public to_uint64(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public to_uint64(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public to_uint64 position(long position) { - return (to_uint64)super.position(position); - } - public to_uint64() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This operation casts elements of input array to specified data type + * + * PLEASE NOTE: This op is disabled atm, and reserved for future releases. + * + * + * Int args: + * 0: target DataType + */ +// #if NOT_EXCLUDED(OP_cast) +@Namespace("sd::ops") public static class cast extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public cast(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public cast(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public cast position(long position) { + return (cast)super.position(position); + } - /** - * This operation casts elements of input array to specified data type - * - * PLEASE NOTE: This op is disabled atm, and reserved for future releases. - * - * - * Int args: - * 0: target DataType - */ -// #if NOT_EXCLUDED(OP_cast) - @Namespace("sd::ops") public static class cast extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public cast(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public cast(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public cast position(long position) { - return (cast)super.position(position); - } - public cast() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif - /** - * This operation change type of input and modified shape of output to conform with given data type - * - * all as above op - * */ -// #if NOT_EXCLUDED(OP_bitcast) - @Namespace("sd::ops") public static class bitcast extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public bitcast(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public bitcast(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public bitcast position(long position) { - return (bitcast)super.position(position); - } - +// #endif +/** + * This operation change type of input and modified shape of output to conform + * with given data type + * + * all as above op + * */ +// #if NOT_EXCLUDED(OP_bitcast) +@Namespace("sd::ops") public static class bitcast extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public bitcast(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public bitcast(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public bitcast position(long position) { + return (bitcast)super.position(position); + } + public bitcast() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif - - +// #endif + // namespace ops + // namespace sd // #endif @@ -24346,56 +26821,57 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #ifndef LIBND4J_CONTEXTBUFFERS_H // #define LIBND4J_CONTEXTBUFFERS_H +// #include // #include // #include -// #include - @Namespace("sd") @NoOffset public static class ContextBuffers extends Pointer { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public ContextBuffers(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public ContextBuffers(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public ContextBuffers position(long position) { - return (ContextBuffers)super.position(position); - } - - public ContextBuffers() { super((Pointer)null); allocate(); } - private native void allocate(); - public ContextBuffers(@Const @ByRef ContextBuffers other) { super((Pointer)null); allocate(other); } - private native void allocate(@Const @ByRef ContextBuffers other); - public ContextBuffers(Pointer rPointer, Pointer sPointer, Pointer aPointer, @Cast("bool") boolean isOwner/*=false*/) { super((Pointer)null); allocate(rPointer, sPointer, aPointer, isOwner); } - private native void allocate(Pointer rPointer, Pointer sPointer, Pointer aPointer, @Cast("bool") boolean isOwner/*=false*/); - public ContextBuffers(Pointer rPointer, Pointer sPointer, Pointer aPointer) { super((Pointer)null); allocate(rPointer, sPointer, aPointer); } - private native void allocate(Pointer rPointer, Pointer sPointer, Pointer aPointer); - - public native @ByRef @Name("operator =") ContextBuffers put(@Const @ByRef ContextBuffers other); +@Namespace("sd") @NoOffset public static class ContextBuffers extends Pointer { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public ContextBuffers(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public ContextBuffers(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public ContextBuffers position(long position) { + return (ContextBuffers)super.position(position); + } - public native void release(); + public ContextBuffers() { super((Pointer)null); allocate(); } + private native void allocate(); + public ContextBuffers(@Const @ByRef ContextBuffers other) { super((Pointer)null); allocate(other); } + private native void allocate(@Const @ByRef ContextBuffers other); + public ContextBuffers(Pointer rPointer, Pointer sPointer, Pointer aPointer, + @Cast("bool") boolean isOwner/*=false*/) { super((Pointer)null); allocate(rPointer, sPointer, aPointer, isOwner); } + private native void allocate(Pointer rPointer, Pointer sPointer, Pointer aPointer, + @Cast("bool") boolean isOwner/*=false*/); + public ContextBuffers(Pointer rPointer, Pointer sPointer, Pointer aPointer) { super((Pointer)null); allocate(rPointer, sPointer, aPointer); } + private native void allocate(Pointer rPointer, Pointer sPointer, Pointer aPointer); - public native Pointer reductionBuffer(); - public native Pointer scalarBuffer(); - public native Pointer allocationBuffer(); + public native @ByRef @Name("operator =") ContextBuffers put(@Const @ByRef ContextBuffers other); - public native Pointer execStream(); - public native Pointer specialStream(); + public native void release(); - public native void setReductionBuffer(Pointer pointer); - public native void setScalarBuffer(Pointer pointer); - public native void setAllocationBuffer(Pointer pointer); + public native Pointer reductionBuffer(); + public native Pointer scalarBuffer(); + public native Pointer allocationBuffer(); - public native ErrorReference errorReference(); + public native Pointer execStream(); + public native Pointer specialStream(); - public native void triggerOwnership(@Cast("bool") boolean isOwner); + public native void setReductionBuffer(Pointer pointer); + public native void setScalarBuffer(Pointer pointer); + public native void setAllocationBuffer(Pointer pointer); - public native int deviceId(); + public native ErrorReference errorReference(); - public native @Cast("bool") boolean isInitialized(); - } + public native void triggerOwnership(@Cast("bool") boolean isOwner); + public native int deviceId(); + public native @Cast("bool") boolean isInitialized(); +} + // namespace sd -// #endif //DEV_TESTS_CONTEXTBUFFERS_H +// #endif // SD_CONTEXTBUFFERS_H // Parsed from execution/LaunchContext.h @@ -24423,7 +26899,6 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #ifndef LIBND4J_CUDACONTEXT_H // #define LIBND4J_CUDACONTEXT_H - // #ifdef __CUDABLAS__ // #endif @@ -24432,14 +26907,15 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #include "config.h" // #endif +// #include +// #include +// #include // #include -// #include // #include -// #include -// #include + +// #include // #include -// #include -// #include +// #include @Namespace("sd") @NoOffset public static class LaunchContext extends Pointer { static { Loader.load(); } @@ -24452,41 +26928,41 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #ifdef __CUDABLAS__ -// #endif // CUDA - public LaunchContext(@Cast("Nd4jPointer") Pointer cudaStream, @Cast("Nd4jPointer") Pointer reductionPointer/*=nullptr*/, @Cast("Nd4jPointer") Pointer scalarPointer/*=nullptr*/, @Cast("Nd4jPointer") Pointer allocationPointer/*=nullptr*/) { super((Pointer)null); allocate(cudaStream, reductionPointer, scalarPointer, allocationPointer); } - private native void allocate(@Cast("Nd4jPointer") Pointer cudaStream, @Cast("Nd4jPointer") Pointer reductionPointer/*=nullptr*/, @Cast("Nd4jPointer") Pointer scalarPointer/*=nullptr*/, @Cast("Nd4jPointer") Pointer allocationPointer/*=nullptr*/); - public LaunchContext(@Cast("Nd4jPointer") Pointer cudaStream) { super((Pointer)null); allocate(cudaStream); } - private native void allocate(@Cast("Nd4jPointer") Pointer cudaStream); - public LaunchContext() { super((Pointer)null); allocate(); } - private native void allocate(); - public native Workspace getWorkspace(); - public native void setWorkspace(Workspace theWorkspace); +// #endif // CUDA + public LaunchContext(@Cast("Nd4jPointer") Pointer cudaStream, @Cast("Nd4jPointer") Pointer reductionPointer/*=nullptr*/, + @Cast("Nd4jPointer") Pointer scalarPointer/*=nullptr*/, + @Cast("Nd4jPointer") Pointer allocationPointer/*=nullptr*/) { super((Pointer)null); allocate(cudaStream, reductionPointer, scalarPointer, allocationPointer); } + private native void allocate(@Cast("Nd4jPointer") Pointer cudaStream, @Cast("Nd4jPointer") Pointer reductionPointer/*=nullptr*/, + @Cast("Nd4jPointer") Pointer scalarPointer/*=nullptr*/, + @Cast("Nd4jPointer") Pointer allocationPointer/*=nullptr*/); + public LaunchContext(@Cast("Nd4jPointer") Pointer cudaStream) { super((Pointer)null); allocate(cudaStream); } + private native void allocate(@Cast("Nd4jPointer") Pointer cudaStream); + public LaunchContext() { super((Pointer)null); allocate(); } + private native void allocate(); + public native Workspace getWorkspace(); + public native void setWorkspace(Workspace theWorkspace); - public native Pointer engine(); + public native Pointer engine(); - public native int getDeviceID(); - public native void setDeviceID(int deviceID); - public native ErrorReference errorReference(); + public native int getDeviceID(); + public native void setDeviceID(int deviceID); + public native ErrorReference errorReference(); // #ifndef __JAVACPP_HACK__ // #endif - public static native @Cast("bool") boolean isInitialized(); - public static native void releaseBuffers(); - - - public static native LaunchContext defaultContext(); + public static native @Cast("bool") boolean isInitialized(); + public static native void releaseBuffers(); + public static native LaunchContext defaultContext(); - public static native void swapContextBuffers(@ByRef ContextBuffers buffers); - + public static native void swapContextBuffers(@ByRef ContextBuffers buffers); } + // namespace sd - - -// #endif //LIBND4J_CUDACONTEXT_H +// #endif // LIBND4J_CUDACONTEXT_H // Parsed from array/ShapeDescriptor.h @@ -24511,15 +26987,16 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // @author raver119@gmail.com // -// #ifndef DEV_TESTS_SHAPEDESCRIPTOR_H -// #define DEV_TESTS_SHAPEDESCRIPTOR_H +// #ifndef SD_SHAPEDESCRIPTOR_H +// #define SD_SHAPEDESCRIPTOR_H -// #include -// #include +// #include // #include // #include -// #include + // #include +// #include +// #include @Namespace("sd") @NoOffset public static class ShapeDescriptor extends Pointer { static { Loader.load(); } @@ -24532,108 +27009,185 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint return (ShapeDescriptor)super.position(position); } - public ShapeDescriptor(@Const @ByRef ShapeDescriptor other) { super((Pointer)null); allocate(other); } - private native void allocate(@Const @ByRef ShapeDescriptor other); - public ShapeDescriptor(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("bool") boolean inheritDtype/*=true*/) { super((Pointer)null); allocate(shapeInfo, inheritDtype); } - private native void allocate(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("bool") boolean inheritDtype/*=true*/); - public ShapeDescriptor(@Cast("const Nd4jLong*") LongPointer shapeInfo) { super((Pointer)null); allocate(shapeInfo); } - private native void allocate(@Cast("const Nd4jLong*") LongPointer shapeInfo); - public ShapeDescriptor(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("bool") boolean inheritDtype/*=true*/) { super((Pointer)null); allocate(shapeInfo, inheritDtype); } - private native void allocate(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("bool") boolean inheritDtype/*=true*/); - public ShapeDescriptor(@Cast("const Nd4jLong*") LongBuffer shapeInfo) { super((Pointer)null); allocate(shapeInfo); } - private native void allocate(@Cast("const Nd4jLong*") LongBuffer shapeInfo); - public ShapeDescriptor(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("bool") boolean inheritDtype/*=true*/) { super((Pointer)null); allocate(shapeInfo, inheritDtype); } - private native void allocate(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("bool") boolean inheritDtype/*=true*/); - public ShapeDescriptor(@Cast("const Nd4jLong*") long[] shapeInfo) { super((Pointer)null); allocate(shapeInfo); } - private native void allocate(@Cast("const Nd4jLong*") long[] shapeInfo); - public ShapeDescriptor(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("const sd::DataType") int dtypeOverride) { super((Pointer)null); allocate(shapeInfo, dtypeOverride); } - private native void allocate(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("const sd::DataType") int dtypeOverride); - public ShapeDescriptor(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("const sd::DataType") int dtypeOverride) { super((Pointer)null); allocate(shapeInfo, dtypeOverride); } - private native void allocate(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("const sd::DataType") int dtypeOverride); - public ShapeDescriptor(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("const sd::DataType") int dtypeOverride) { super((Pointer)null); allocate(shapeInfo, dtypeOverride); } - private native void allocate(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("const sd::DataType") int dtypeOverride); - public ShapeDescriptor(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("const Nd4jLong*") LongPointer dtypeOverride) { super((Pointer)null); allocate(shapeInfo, dtypeOverride); } - private native void allocate(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("const Nd4jLong*") LongPointer dtypeOverride); - public ShapeDescriptor(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("const Nd4jLong*") LongBuffer dtypeOverride) { super((Pointer)null); allocate(shapeInfo, dtypeOverride); } - private native void allocate(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("const Nd4jLong*") LongBuffer dtypeOverride); - public ShapeDescriptor(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("const Nd4jLong*") long[] dtypeOverride) { super((Pointer)null); allocate(shapeInfo, dtypeOverride); } - private native void allocate(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("const Nd4jLong*") long[] dtypeOverride); - public ShapeDescriptor(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("const Nd4jLong*") LongPointer dtypeOverride, @Cast("const Nd4jLong*") LongPointer orderOverride) { super((Pointer)null); allocate(shapeInfo, dtypeOverride, orderOverride); } - private native void allocate(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("const Nd4jLong*") LongPointer dtypeOverride, @Cast("const Nd4jLong*") LongPointer orderOverride); - public ShapeDescriptor(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("const Nd4jLong*") LongBuffer dtypeOverride, @Cast("const Nd4jLong*") LongBuffer orderOverride) { super((Pointer)null); allocate(shapeInfo, dtypeOverride, orderOverride); } - private native void allocate(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("const Nd4jLong*") LongBuffer dtypeOverride, @Cast("const Nd4jLong*") LongBuffer orderOverride); - public ShapeDescriptor(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("const Nd4jLong*") long[] dtypeOverride, @Cast("const Nd4jLong*") long[] orderOverride) { super((Pointer)null); allocate(shapeInfo, dtypeOverride, orderOverride); } - private native void allocate(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("const Nd4jLong*") long[] dtypeOverride, @Cast("const Nd4jLong*") long[] orderOverride); - public ShapeDescriptor(@Cast("const sd::DataType") int type, @Cast("const Nd4jLong") long length) { super((Pointer)null); allocate(type, length); } - private native void allocate(@Cast("const sd::DataType") int type, @Cast("const Nd4jLong") long length); - public ShapeDescriptor(@Cast("const sd::DataType") int type, byte order, @Cast("const Nd4jLong*") LongPointer shape, int rank) { super((Pointer)null); allocate(type, order, shape, rank); } - private native void allocate(@Cast("const sd::DataType") int type, byte order, @Cast("const Nd4jLong*") LongPointer shape, int rank); - public ShapeDescriptor(@Cast("const sd::DataType") int type, byte order, @Cast("const Nd4jLong*") LongBuffer shape, int rank) { super((Pointer)null); allocate(type, order, shape, rank); } - private native void allocate(@Cast("const sd::DataType") int type, byte order, @Cast("const Nd4jLong*") LongBuffer shape, int rank); - public ShapeDescriptor(@Cast("const sd::DataType") int type, byte order, @Cast("const Nd4jLong*") long[] shape, int rank) { super((Pointer)null); allocate(type, order, shape, rank); } - private native void allocate(@Cast("const sd::DataType") int type, byte order, @Cast("const Nd4jLong*") long[] shape, int rank); - public ShapeDescriptor(@Cast("const sd::DataType") int type, byte order, @Cast("const Nd4jLong*") LongPointer shape, @Cast("const Nd4jLong*") LongPointer strides, int rank, @Cast("Nd4jLong") long ews, @Cast("const bool") boolean empty) { super((Pointer)null); allocate(type, order, shape, strides, rank, ews, empty); } - private native void allocate(@Cast("const sd::DataType") int type, byte order, @Cast("const Nd4jLong*") LongPointer shape, @Cast("const Nd4jLong*") LongPointer strides, int rank, @Cast("Nd4jLong") long ews, @Cast("const bool") boolean empty); - public ShapeDescriptor(@Cast("const sd::DataType") int type, byte order, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("const Nd4jLong*") LongBuffer strides, int rank, @Cast("Nd4jLong") long ews, @Cast("const bool") boolean empty) { super((Pointer)null); allocate(type, order, shape, strides, rank, ews, empty); } - private native void allocate(@Cast("const sd::DataType") int type, byte order, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("const Nd4jLong*") LongBuffer strides, int rank, @Cast("Nd4jLong") long ews, @Cast("const bool") boolean empty); - public ShapeDescriptor(@Cast("const sd::DataType") int type, byte order, @Cast("const Nd4jLong*") long[] shape, @Cast("const Nd4jLong*") long[] strides, int rank, @Cast("Nd4jLong") long ews, @Cast("const bool") boolean empty) { super((Pointer)null); allocate(type, order, shape, strides, rank, ews, empty); } - private native void allocate(@Cast("const sd::DataType") int type, byte order, @Cast("const Nd4jLong*") long[] shape, @Cast("const Nd4jLong*") long[] strides, int rank, @Cast("Nd4jLong") long ews, @Cast("const bool") boolean empty); - public ShapeDescriptor(@Cast("const sd::DataType") int type, byte order, @Cast("Nd4jLong*") @StdVector LongPointer shape) { super((Pointer)null); allocate(type, order, shape); } - private native void allocate(@Cast("const sd::DataType") int type, byte order, @Cast("Nd4jLong*") @StdVector LongPointer shape); - public ShapeDescriptor(@Cast("const sd::DataType") int type, byte order, @Cast("Nd4jLong*") @StdVector LongBuffer shape) { super((Pointer)null); allocate(type, order, shape); } - private native void allocate(@Cast("const sd::DataType") int type, byte order, @Cast("Nd4jLong*") @StdVector LongBuffer shape); - public ShapeDescriptor(@Cast("const sd::DataType") int type, byte order, @Cast("Nd4jLong*") @StdVector long[] shape) { super((Pointer)null); allocate(type, order, shape); } - private native void allocate(@Cast("const sd::DataType") int type, byte order, @Cast("Nd4jLong*") @StdVector long[] shape); - public ShapeDescriptor(@Cast("const sd::DataType") int type, byte order, @Cast("Nd4jLong*") @StdVector LongPointer shape, @Cast("Nd4jLong*") @StdVector LongPointer strides) { super((Pointer)null); allocate(type, order, shape, strides); } - private native void allocate(@Cast("const sd::DataType") int type, byte order, @Cast("Nd4jLong*") @StdVector LongPointer shape, @Cast("Nd4jLong*") @StdVector LongPointer strides); - public ShapeDescriptor(@Cast("const sd::DataType") int type, byte order, @Cast("Nd4jLong*") @StdVector LongBuffer shape, @Cast("Nd4jLong*") @StdVector LongBuffer strides) { super((Pointer)null); allocate(type, order, shape, strides); } - private native void allocate(@Cast("const sd::DataType") int type, byte order, @Cast("Nd4jLong*") @StdVector LongBuffer shape, @Cast("Nd4jLong*") @StdVector LongBuffer strides); - public ShapeDescriptor(@Cast("const sd::DataType") int type, byte order, @Cast("Nd4jLong*") @StdVector long[] shape, @Cast("Nd4jLong*") @StdVector long[] strides) { super((Pointer)null); allocate(type, order, shape, strides); } - private native void allocate(@Cast("const sd::DataType") int type, byte order, @Cast("Nd4jLong*") @StdVector long[] shape, @Cast("Nd4jLong*") @StdVector long[] strides); - public ShapeDescriptor(@Cast("const sd::DataType") int type, byte order, @Cast("Nd4jLong*") @StdVector LongPointer shape, @Cast("Nd4jLong*") @StdVector LongPointer strides, @Cast("const Nd4jLong") long ews) { super((Pointer)null); allocate(type, order, shape, strides, ews); } - private native void allocate(@Cast("const sd::DataType") int type, byte order, @Cast("Nd4jLong*") @StdVector LongPointer shape, @Cast("Nd4jLong*") @StdVector LongPointer strides, @Cast("const Nd4jLong") long ews); - public ShapeDescriptor(@Cast("const sd::DataType") int type, byte order, @Cast("Nd4jLong*") @StdVector LongBuffer shape, @Cast("Nd4jLong*") @StdVector LongBuffer strides, @Cast("const Nd4jLong") long ews) { super((Pointer)null); allocate(type, order, shape, strides, ews); } - private native void allocate(@Cast("const sd::DataType") int type, byte order, @Cast("Nd4jLong*") @StdVector LongBuffer shape, @Cast("Nd4jLong*") @StdVector LongBuffer strides, @Cast("const Nd4jLong") long ews); - public ShapeDescriptor(@Cast("const sd::DataType") int type, byte order, @Cast("Nd4jLong*") @StdVector long[] shape, @Cast("Nd4jLong*") @StdVector long[] strides, @Cast("const Nd4jLong") long ews) { super((Pointer)null); allocate(type, order, shape, strides, ews); } - private native void allocate(@Cast("const sd::DataType") int type, byte order, @Cast("Nd4jLong*") @StdVector long[] shape, @Cast("Nd4jLong*") @StdVector long[] strides, @Cast("const Nd4jLong") long ews); - public ShapeDescriptor() { super((Pointer)null); allocate(); } - private native void allocate(); - - public native int rank(); - public native @Cast("Nd4jLong") long ews(); - public native @Cast("Nd4jLong") long arrLength(); - public native char order(); - public native @Cast("sd::DataType") int dataType(); - public native @Cast("bool") boolean isEmpty(); - public native @Cast("Nd4jLong*") @StdVector LongPointer shape(); - public native @Cast("Nd4jLong*") @StdVector LongPointer strides(); - - // we use default copy assignment operator - public native @ByRef @Name("operator =") ShapeDescriptor put(@Const @ByRef ShapeDescriptor other); - - // we use default move assignment operator - - // equal to operator - public native @Cast("bool") @Name("operator ==") boolean equals(@Const @ByRef ShapeDescriptor other); - - // less than operator - public native @Cast("bool") @Name("operator <") boolean lessThan(@Const @ByRef ShapeDescriptor other); - - public native @Cast("Nd4jLong*") LongPointer toShapeInfo(); - - - public static native @ByVal ShapeDescriptor emptyDescriptor(@Cast("const sd::DataType") int type); - public static native @ByVal ShapeDescriptor scalarDescriptor(@Cast("const sd::DataType") int type); - public static native @ByVal ShapeDescriptor vectorDescriptor(@Cast("const Nd4jLong") long length, @Cast("const sd::DataType") int type); - } + public ShapeDescriptor(@Const @ByRef ShapeDescriptor other) { super((Pointer)null); allocate(other); } + private native void allocate(@Const @ByRef ShapeDescriptor other); + public ShapeDescriptor(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("bool") boolean inheritDtype/*=true*/) { super((Pointer)null); allocate(shapeInfo, inheritDtype); } + private native void allocate(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("bool") boolean inheritDtype/*=true*/); + public ShapeDescriptor(@Cast("const Nd4jLong*") LongPointer shapeInfo) { super((Pointer)null); allocate(shapeInfo); } + private native void allocate(@Cast("const Nd4jLong*") LongPointer shapeInfo); + public ShapeDescriptor(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("bool") boolean inheritDtype/*=true*/) { super((Pointer)null); allocate(shapeInfo, inheritDtype); } + private native void allocate(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("bool") boolean inheritDtype/*=true*/); + public ShapeDescriptor(@Cast("const Nd4jLong*") LongBuffer shapeInfo) { super((Pointer)null); allocate(shapeInfo); } + private native void allocate(@Cast("const Nd4jLong*") LongBuffer shapeInfo); + public ShapeDescriptor(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("bool") boolean inheritDtype/*=true*/) { super((Pointer)null); allocate(shapeInfo, inheritDtype); } + private native void allocate(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("bool") boolean inheritDtype/*=true*/); + public ShapeDescriptor(@Cast("const Nd4jLong*") long[] shapeInfo) { super((Pointer)null); allocate(shapeInfo); } + private native void allocate(@Cast("const Nd4jLong*") long[] shapeInfo); + public ShapeDescriptor(@Cast("const Nd4jLong*") LongPointer shapeInfo, + @Cast("const sd::DataType") int dtypeOverride) { super((Pointer)null); allocate(shapeInfo, dtypeOverride); } + private native void allocate(@Cast("const Nd4jLong*") LongPointer shapeInfo, + @Cast("const sd::DataType") int dtypeOverride); + public ShapeDescriptor(@Cast("const Nd4jLong*") LongBuffer shapeInfo, + @Cast("const sd::DataType") int dtypeOverride) { super((Pointer)null); allocate(shapeInfo, dtypeOverride); } + private native void allocate(@Cast("const Nd4jLong*") LongBuffer shapeInfo, + @Cast("const sd::DataType") int dtypeOverride); + public ShapeDescriptor(@Cast("const Nd4jLong*") long[] shapeInfo, + @Cast("const sd::DataType") int dtypeOverride) { super((Pointer)null); allocate(shapeInfo, dtypeOverride); } + private native void allocate(@Cast("const Nd4jLong*") long[] shapeInfo, + @Cast("const sd::DataType") int dtypeOverride); + public ShapeDescriptor(@Cast("const Nd4jLong*") LongPointer shapeInfo, + @Cast("const Nd4jLong*") LongPointer dtypeOverride) { super((Pointer)null); allocate(shapeInfo, dtypeOverride); } + private native void allocate(@Cast("const Nd4jLong*") LongPointer shapeInfo, + @Cast("const Nd4jLong*") LongPointer dtypeOverride); + public ShapeDescriptor(@Cast("const Nd4jLong*") LongBuffer shapeInfo, + @Cast("const Nd4jLong*") LongBuffer dtypeOverride) { super((Pointer)null); allocate(shapeInfo, dtypeOverride); } + private native void allocate(@Cast("const Nd4jLong*") LongBuffer shapeInfo, + @Cast("const Nd4jLong*") LongBuffer dtypeOverride); + public ShapeDescriptor(@Cast("const Nd4jLong*") long[] shapeInfo, + @Cast("const Nd4jLong*") long[] dtypeOverride) { super((Pointer)null); allocate(shapeInfo, dtypeOverride); } + private native void allocate(@Cast("const Nd4jLong*") long[] shapeInfo, + @Cast("const Nd4jLong*") long[] dtypeOverride); + public ShapeDescriptor(@Cast("const Nd4jLong*") LongPointer shapeInfo, + @Cast("const Nd4jLong*") LongPointer dtypeOverride, + @Cast("const Nd4jLong*") LongPointer orderOverride) { super((Pointer)null); allocate(shapeInfo, dtypeOverride, orderOverride); } + private native void allocate(@Cast("const Nd4jLong*") LongPointer shapeInfo, + @Cast("const Nd4jLong*") LongPointer dtypeOverride, + @Cast("const Nd4jLong*") LongPointer orderOverride); + public ShapeDescriptor(@Cast("const Nd4jLong*") LongBuffer shapeInfo, + @Cast("const Nd4jLong*") LongBuffer dtypeOverride, + @Cast("const Nd4jLong*") LongBuffer orderOverride) { super((Pointer)null); allocate(shapeInfo, dtypeOverride, orderOverride); } + private native void allocate(@Cast("const Nd4jLong*") LongBuffer shapeInfo, + @Cast("const Nd4jLong*") LongBuffer dtypeOverride, + @Cast("const Nd4jLong*") LongBuffer orderOverride); + public ShapeDescriptor(@Cast("const Nd4jLong*") long[] shapeInfo, + @Cast("const Nd4jLong*") long[] dtypeOverride, + @Cast("const Nd4jLong*") long[] orderOverride) { super((Pointer)null); allocate(shapeInfo, dtypeOverride, orderOverride); } + private native void allocate(@Cast("const Nd4jLong*") long[] shapeInfo, + @Cast("const Nd4jLong*") long[] dtypeOverride, + @Cast("const Nd4jLong*") long[] orderOverride); + public ShapeDescriptor(@Cast("const sd::DataType") int type, @Cast("const Nd4jLong") long length) { super((Pointer)null); allocate(type, length); } + private native void allocate(@Cast("const sd::DataType") int type, @Cast("const Nd4jLong") long length); + public ShapeDescriptor(@Cast("const sd::DataType") int type, byte order, + @Cast("const Nd4jLong*") LongPointer shape, int rank) { super((Pointer)null); allocate(type, order, shape, rank); } + private native void allocate(@Cast("const sd::DataType") int type, byte order, + @Cast("const Nd4jLong*") LongPointer shape, int rank); + public ShapeDescriptor(@Cast("const sd::DataType") int type, byte order, + @Cast("const Nd4jLong*") LongBuffer shape, int rank) { super((Pointer)null); allocate(type, order, shape, rank); } + private native void allocate(@Cast("const sd::DataType") int type, byte order, + @Cast("const Nd4jLong*") LongBuffer shape, int rank); + public ShapeDescriptor(@Cast("const sd::DataType") int type, byte order, + @Cast("const Nd4jLong*") long[] shape, int rank) { super((Pointer)null); allocate(type, order, shape, rank); } + private native void allocate(@Cast("const sd::DataType") int type, byte order, + @Cast("const Nd4jLong*") long[] shape, int rank); + public ShapeDescriptor(@Cast("const sd::DataType") int type, byte order, + @Cast("const Nd4jLong*") LongPointer shape, @Cast("const Nd4jLong*") LongPointer strides, + int rank, @Cast("Nd4jLong") long ews, @Cast("const bool") boolean empty) { super((Pointer)null); allocate(type, order, shape, strides, rank, ews, empty); } + private native void allocate(@Cast("const sd::DataType") int type, byte order, + @Cast("const Nd4jLong*") LongPointer shape, @Cast("const Nd4jLong*") LongPointer strides, + int rank, @Cast("Nd4jLong") long ews, @Cast("const bool") boolean empty); + public ShapeDescriptor(@Cast("const sd::DataType") int type, byte order, + @Cast("const Nd4jLong*") LongBuffer shape, @Cast("const Nd4jLong*") LongBuffer strides, + int rank, @Cast("Nd4jLong") long ews, @Cast("const bool") boolean empty) { super((Pointer)null); allocate(type, order, shape, strides, rank, ews, empty); } + private native void allocate(@Cast("const sd::DataType") int type, byte order, + @Cast("const Nd4jLong*") LongBuffer shape, @Cast("const Nd4jLong*") LongBuffer strides, + int rank, @Cast("Nd4jLong") long ews, @Cast("const bool") boolean empty); + public ShapeDescriptor(@Cast("const sd::DataType") int type, byte order, + @Cast("const Nd4jLong*") long[] shape, @Cast("const Nd4jLong*") long[] strides, + int rank, @Cast("Nd4jLong") long ews, @Cast("const bool") boolean empty) { super((Pointer)null); allocate(type, order, shape, strides, rank, ews, empty); } + private native void allocate(@Cast("const sd::DataType") int type, byte order, + @Cast("const Nd4jLong*") long[] shape, @Cast("const Nd4jLong*") long[] strides, + int rank, @Cast("Nd4jLong") long ews, @Cast("const bool") boolean empty); + public ShapeDescriptor(@Cast("const sd::DataType") int type, byte order, + @Cast("Nd4jLong*") @StdVector LongPointer shape) { super((Pointer)null); allocate(type, order, shape); } + private native void allocate(@Cast("const sd::DataType") int type, byte order, + @Cast("Nd4jLong*") @StdVector LongPointer shape); + public ShapeDescriptor(@Cast("const sd::DataType") int type, byte order, + @Cast("Nd4jLong*") @StdVector LongBuffer shape) { super((Pointer)null); allocate(type, order, shape); } + private native void allocate(@Cast("const sd::DataType") int type, byte order, + @Cast("Nd4jLong*") @StdVector LongBuffer shape); + public ShapeDescriptor(@Cast("const sd::DataType") int type, byte order, + @Cast("Nd4jLong*") @StdVector long[] shape) { super((Pointer)null); allocate(type, order, shape); } + private native void allocate(@Cast("const sd::DataType") int type, byte order, + @Cast("Nd4jLong*") @StdVector long[] shape); + public ShapeDescriptor(@Cast("const sd::DataType") int type, byte order, + @Cast("Nd4jLong*") @StdVector LongPointer shape, + @Cast("Nd4jLong*") @StdVector LongPointer strides) { super((Pointer)null); allocate(type, order, shape, strides); } + private native void allocate(@Cast("const sd::DataType") int type, byte order, + @Cast("Nd4jLong*") @StdVector LongPointer shape, + @Cast("Nd4jLong*") @StdVector LongPointer strides); + public ShapeDescriptor(@Cast("const sd::DataType") int type, byte order, + @Cast("Nd4jLong*") @StdVector LongBuffer shape, + @Cast("Nd4jLong*") @StdVector LongBuffer strides) { super((Pointer)null); allocate(type, order, shape, strides); } + private native void allocate(@Cast("const sd::DataType") int type, byte order, + @Cast("Nd4jLong*") @StdVector LongBuffer shape, + @Cast("Nd4jLong*") @StdVector LongBuffer strides); + public ShapeDescriptor(@Cast("const sd::DataType") int type, byte order, + @Cast("Nd4jLong*") @StdVector long[] shape, + @Cast("Nd4jLong*") @StdVector long[] strides) { super((Pointer)null); allocate(type, order, shape, strides); } + private native void allocate(@Cast("const sd::DataType") int type, byte order, + @Cast("Nd4jLong*") @StdVector long[] shape, + @Cast("Nd4jLong*") @StdVector long[] strides); + public ShapeDescriptor(@Cast("const sd::DataType") int type, byte order, + @Cast("Nd4jLong*") @StdVector LongPointer shape, + @Cast("Nd4jLong*") @StdVector LongPointer strides, + @Cast("const Nd4jLong") long ews) { super((Pointer)null); allocate(type, order, shape, strides, ews); } + private native void allocate(@Cast("const sd::DataType") int type, byte order, + @Cast("Nd4jLong*") @StdVector LongPointer shape, + @Cast("Nd4jLong*") @StdVector LongPointer strides, + @Cast("const Nd4jLong") long ews); + public ShapeDescriptor(@Cast("const sd::DataType") int type, byte order, + @Cast("Nd4jLong*") @StdVector LongBuffer shape, + @Cast("Nd4jLong*") @StdVector LongBuffer strides, + @Cast("const Nd4jLong") long ews) { super((Pointer)null); allocate(type, order, shape, strides, ews); } + private native void allocate(@Cast("const sd::DataType") int type, byte order, + @Cast("Nd4jLong*") @StdVector LongBuffer shape, + @Cast("Nd4jLong*") @StdVector LongBuffer strides, + @Cast("const Nd4jLong") long ews); + public ShapeDescriptor(@Cast("const sd::DataType") int type, byte order, + @Cast("Nd4jLong*") @StdVector long[] shape, + @Cast("Nd4jLong*") @StdVector long[] strides, + @Cast("const Nd4jLong") long ews) { super((Pointer)null); allocate(type, order, shape, strides, ews); } + private native void allocate(@Cast("const sd::DataType") int type, byte order, + @Cast("Nd4jLong*") @StdVector long[] shape, + @Cast("Nd4jLong*") @StdVector long[] strides, + @Cast("const Nd4jLong") long ews); + public ShapeDescriptor() { super((Pointer)null); allocate(); } + private native void allocate(); + + public native int rank(); + public native @Cast("Nd4jLong") long ews(); + public native @Cast("Nd4jLong") long arrLength(); + public native char order(); + public native @Cast("sd::DataType") int dataType(); + public native @Cast("bool") boolean isEmpty(); + public native @Cast("Nd4jLong*") @StdVector LongPointer shape(); + public native @Cast("Nd4jLong*") @StdVector LongPointer strides(); + + // we use default copy assignment operator + public native @ByRef @Name("operator =") ShapeDescriptor put(@Const @ByRef ShapeDescriptor other); + + // we use default move assignment operator + + // equal to operator + public native @Cast("bool") @Name("operator ==") boolean equals(@Const @ByRef ShapeDescriptor other); + // less than operator + public native @Cast("bool") @Name("operator <") boolean lessThan(@Const @ByRef ShapeDescriptor other); + + public native @Cast("Nd4jLong*") LongPointer toShapeInfo(); + + public static native @ByVal ShapeDescriptor emptyDescriptor(@Cast("const sd::DataType") int type); + public static native @ByVal ShapeDescriptor scalarDescriptor(@Cast("const sd::DataType") int type); + public static native @ByVal ShapeDescriptor vectorDescriptor(@Cast("const Nd4jLong") long length, + @Cast("const sd::DataType") int type); +} + // namespace sd // #ifndef __JAVACPP_HACK__ // #endif - -// #endif //DEV_TESTS_SHAPEDESCRIPTOR_H +// #endif // SD_SHAPEDESCRIPTOR_H // Parsed from array/TadDescriptor.h @@ -24658,67 +27212,103 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // @author raver119@gmail.com // -// #ifndef DEV_TESTS_TADDESCRIPTOR_H -// #define DEV_TESTS_TADDESCRIPTOR_H +// #ifndef SD_TADDESCRIPTOR_H +// #define SD_TADDESCRIPTOR_H -// #include "ShapeDescriptor.h" // #include - @Namespace("sd") @NoOffset public static class TadDescriptor extends Pointer { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public TadDescriptor(Pointer p) { super(p); } - - public TadDescriptor(@Cast("const Nd4jLong*") LongPointer originalShape, @Const IntPointer dimensions, int length, @Cast("const bool") boolean keepUnitiesInShape/*=false*/) { super((Pointer)null); allocate(originalShape, dimensions, length, keepUnitiesInShape); } - private native void allocate(@Cast("const Nd4jLong*") LongPointer originalShape, @Const IntPointer dimensions, int length, @Cast("const bool") boolean keepUnitiesInShape/*=false*/); - public TadDescriptor(@Cast("const Nd4jLong*") LongPointer originalShape, @Const IntPointer dimensions, int length) { super((Pointer)null); allocate(originalShape, dimensions, length); } - private native void allocate(@Cast("const Nd4jLong*") LongPointer originalShape, @Const IntPointer dimensions, int length); - public TadDescriptor(@Cast("const Nd4jLong*") LongBuffer originalShape, @Const IntBuffer dimensions, int length, @Cast("const bool") boolean keepUnitiesInShape/*=false*/) { super((Pointer)null); allocate(originalShape, dimensions, length, keepUnitiesInShape); } - private native void allocate(@Cast("const Nd4jLong*") LongBuffer originalShape, @Const IntBuffer dimensions, int length, @Cast("const bool") boolean keepUnitiesInShape/*=false*/); - public TadDescriptor(@Cast("const Nd4jLong*") LongBuffer originalShape, @Const IntBuffer dimensions, int length) { super((Pointer)null); allocate(originalShape, dimensions, length); } - private native void allocate(@Cast("const Nd4jLong*") LongBuffer originalShape, @Const IntBuffer dimensions, int length); - public TadDescriptor(@Cast("const Nd4jLong*") long[] originalShape, @Const int[] dimensions, int length, @Cast("const bool") boolean keepUnitiesInShape/*=false*/) { super((Pointer)null); allocate(originalShape, dimensions, length, keepUnitiesInShape); } - private native void allocate(@Cast("const Nd4jLong*") long[] originalShape, @Const int[] dimensions, int length, @Cast("const bool") boolean keepUnitiesInShape/*=false*/); - public TadDescriptor(@Cast("const Nd4jLong*") long[] originalShape, @Const int[] dimensions, int length) { super((Pointer)null); allocate(originalShape, dimensions, length); } - private native void allocate(@Cast("const Nd4jLong*") long[] originalShape, @Const int[] dimensions, int length); - public TadDescriptor(@Const @ByRef ShapeDescriptor descriptor, @StdVector IntPointer dimensions, @Cast("const bool") boolean keepUnitiesInShape/*=false*/) { super((Pointer)null); allocate(descriptor, dimensions, keepUnitiesInShape); } - private native void allocate(@Const @ByRef ShapeDescriptor descriptor, @StdVector IntPointer dimensions, @Cast("const bool") boolean keepUnitiesInShape/*=false*/); - public TadDescriptor(@Const @ByRef ShapeDescriptor descriptor, @StdVector IntPointer dimensions) { super((Pointer)null); allocate(descriptor, dimensions); } - private native void allocate(@Const @ByRef ShapeDescriptor descriptor, @StdVector IntPointer dimensions); - public TadDescriptor(@Const @ByRef ShapeDescriptor descriptor, @StdVector IntBuffer dimensions, @Cast("const bool") boolean keepUnitiesInShape/*=false*/) { super((Pointer)null); allocate(descriptor, dimensions, keepUnitiesInShape); } - private native void allocate(@Const @ByRef ShapeDescriptor descriptor, @StdVector IntBuffer dimensions, @Cast("const bool") boolean keepUnitiesInShape/*=false*/); - public TadDescriptor(@Const @ByRef ShapeDescriptor descriptor, @StdVector IntBuffer dimensions) { super((Pointer)null); allocate(descriptor, dimensions); } - private native void allocate(@Const @ByRef ShapeDescriptor descriptor, @StdVector IntBuffer dimensions); - public TadDescriptor(@Const @ByRef ShapeDescriptor descriptor, @StdVector int[] dimensions, @Cast("const bool") boolean keepUnitiesInShape/*=false*/) { super((Pointer)null); allocate(descriptor, dimensions, keepUnitiesInShape); } - private native void allocate(@Const @ByRef ShapeDescriptor descriptor, @StdVector int[] dimensions, @Cast("const bool") boolean keepUnitiesInShape/*=false*/); - public TadDescriptor(@Const @ByRef ShapeDescriptor descriptor, @StdVector int[] dimensions) { super((Pointer)null); allocate(descriptor, dimensions); } - private native void allocate(@Const @ByRef ShapeDescriptor descriptor, @StdVector int[] dimensions); - public TadDescriptor(@Const @ByRef TadDescriptor other) { super((Pointer)null); allocate(other); } - private native void allocate(@Const @ByRef TadDescriptor other); - - // we use default copy assignment operator - public native @ByRef @Name("operator =") TadDescriptor put(@Const @ByRef TadDescriptor other); - - // we use default move assignment operator - - // equal to operator - public native @Cast("bool") @Name("operator ==") boolean equals(@Const @ByRef TadDescriptor other); - - // less than operator - public native @Cast("bool") @Name("operator <") boolean lessThan(@Const @ByRef TadDescriptor other); - - public native @StdVector IntPointer axis(); - public native @ByRef ShapeDescriptor originalShape(); - public native @Const @ByRef ShapeDescriptor originalShapeConst(); - public native @Cast("bool") boolean areUnitiesinShape(); - } +// #include "ShapeDescriptor.h" +@Namespace("sd") @NoOffset public static class TadDescriptor extends Pointer { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public TadDescriptor(Pointer p) { super(p); } + + public TadDescriptor(@Cast("const Nd4jLong*") LongPointer originalShape, @Const IntPointer dimensions, + int length, + @Cast("const bool") boolean keepUnitiesInShape/*=false*/) { super((Pointer)null); allocate(originalShape, dimensions, length, keepUnitiesInShape); } + private native void allocate(@Cast("const Nd4jLong*") LongPointer originalShape, @Const IntPointer dimensions, + int length, + @Cast("const bool") boolean keepUnitiesInShape/*=false*/); + public TadDescriptor(@Cast("const Nd4jLong*") LongPointer originalShape, @Const IntPointer dimensions, + int length) { super((Pointer)null); allocate(originalShape, dimensions, length); } + private native void allocate(@Cast("const Nd4jLong*") LongPointer originalShape, @Const IntPointer dimensions, + int length); + public TadDescriptor(@Cast("const Nd4jLong*") LongBuffer originalShape, @Const IntBuffer dimensions, + int length, + @Cast("const bool") boolean keepUnitiesInShape/*=false*/) { super((Pointer)null); allocate(originalShape, dimensions, length, keepUnitiesInShape); } + private native void allocate(@Cast("const Nd4jLong*") LongBuffer originalShape, @Const IntBuffer dimensions, + int length, + @Cast("const bool") boolean keepUnitiesInShape/*=false*/); + public TadDescriptor(@Cast("const Nd4jLong*") LongBuffer originalShape, @Const IntBuffer dimensions, + int length) { super((Pointer)null); allocate(originalShape, dimensions, length); } + private native void allocate(@Cast("const Nd4jLong*") LongBuffer originalShape, @Const IntBuffer dimensions, + int length); + public TadDescriptor(@Cast("const Nd4jLong*") long[] originalShape, @Const int[] dimensions, + int length, + @Cast("const bool") boolean keepUnitiesInShape/*=false*/) { super((Pointer)null); allocate(originalShape, dimensions, length, keepUnitiesInShape); } + private native void allocate(@Cast("const Nd4jLong*") long[] originalShape, @Const int[] dimensions, + int length, + @Cast("const bool") boolean keepUnitiesInShape/*=false*/); + public TadDescriptor(@Cast("const Nd4jLong*") long[] originalShape, @Const int[] dimensions, + int length) { super((Pointer)null); allocate(originalShape, dimensions, length); } + private native void allocate(@Cast("const Nd4jLong*") long[] originalShape, @Const int[] dimensions, + int length); + public TadDescriptor(@Const @ByRef ShapeDescriptor descriptor, + @StdVector IntPointer dimensions, + @Cast("const bool") boolean keepUnitiesInShape/*=false*/) { super((Pointer)null); allocate(descriptor, dimensions, keepUnitiesInShape); } + private native void allocate(@Const @ByRef ShapeDescriptor descriptor, + @StdVector IntPointer dimensions, + @Cast("const bool") boolean keepUnitiesInShape/*=false*/); + public TadDescriptor(@Const @ByRef ShapeDescriptor descriptor, + @StdVector IntPointer dimensions) { super((Pointer)null); allocate(descriptor, dimensions); } + private native void allocate(@Const @ByRef ShapeDescriptor descriptor, + @StdVector IntPointer dimensions); + public TadDescriptor(@Const @ByRef ShapeDescriptor descriptor, + @StdVector IntBuffer dimensions, + @Cast("const bool") boolean keepUnitiesInShape/*=false*/) { super((Pointer)null); allocate(descriptor, dimensions, keepUnitiesInShape); } + private native void allocate(@Const @ByRef ShapeDescriptor descriptor, + @StdVector IntBuffer dimensions, + @Cast("const bool") boolean keepUnitiesInShape/*=false*/); + public TadDescriptor(@Const @ByRef ShapeDescriptor descriptor, + @StdVector IntBuffer dimensions) { super((Pointer)null); allocate(descriptor, dimensions); } + private native void allocate(@Const @ByRef ShapeDescriptor descriptor, + @StdVector IntBuffer dimensions); + public TadDescriptor(@Const @ByRef ShapeDescriptor descriptor, + @StdVector int[] dimensions, + @Cast("const bool") boolean keepUnitiesInShape/*=false*/) { super((Pointer)null); allocate(descriptor, dimensions, keepUnitiesInShape); } + private native void allocate(@Const @ByRef ShapeDescriptor descriptor, + @StdVector int[] dimensions, + @Cast("const bool") boolean keepUnitiesInShape/*=false*/); + public TadDescriptor(@Const @ByRef ShapeDescriptor descriptor, + @StdVector int[] dimensions) { super((Pointer)null); allocate(descriptor, dimensions); } + private native void allocate(@Const @ByRef ShapeDescriptor descriptor, + @StdVector int[] dimensions); + public TadDescriptor(@Const @ByRef TadDescriptor other) { super((Pointer)null); allocate(other); } + private native void allocate(@Const @ByRef TadDescriptor other); + + // we use default copy assignment operator + public native @ByRef @Name("operator =") TadDescriptor put(@Const @ByRef TadDescriptor other); + + // we use default move assignment operator + + // equal to operator + public native @Cast("bool") @Name("operator ==") boolean equals(@Const @ByRef TadDescriptor other); + + // less than operator + public native @Cast("bool") @Name("operator <") boolean lessThan(@Const @ByRef TadDescriptor other); + + public native @StdVector IntPointer axis(); + public native @ByRef ShapeDescriptor originalShape(); + public native @Const @ByRef ShapeDescriptor originalShapeConst(); + public native @Cast("bool") boolean areUnitiesinShape(); +} + // namespace sd // #ifndef __JAVACPP_HACK__ // #endif - -// #endif //DEV_TESTS_TADDESCRIPTOR_H +// #endif // SD_TADDESCRIPTOR_H // Parsed from helpers/DebugInfo.h @@ -24746,48 +27336,48 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #ifndef LIBND4J__DEBUG_INFO_HELPER__H // #define LIBND4J__DEBUG_INFO_HELPER__H -// #include -// #include -// #include // #include -// #include -// #include // #include +// #include +// #include +// #include +// #include + +// #include // #ifdef __CUDACC__ // #endif - @Namespace("sd") public static class DebugInfo extends Pointer { - static { Loader.load(); } - /** Default native constructor. */ - public DebugInfo() { super((Pointer)null); allocate(); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public DebugInfo(long size) { super((Pointer)null); allocateArray(size); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public DebugInfo(Pointer p) { super(p); } - private native void allocate(); - private native void allocateArray(long size); - @Override public DebugInfo position(long position) { - return (DebugInfo)super.position(position); - } - - public native double _minValue(); public native DebugInfo _minValue(double setter); - public native double _maxValue(); public native DebugInfo _maxValue(double setter); - public native double _meanValue(); public native DebugInfo _meanValue(double setter); - public native double _stdDevValue(); public native DebugInfo _stdDevValue(double setter); - public native @Cast("Nd4jLong") long _zeroCount(); public native DebugInfo _zeroCount(long setter); - public native @Cast("Nd4jLong") long _positiveCount(); public native DebugInfo _positiveCount(long setter); - public native @Cast("Nd4jLong") long _negativeCount(); public native DebugInfo _negativeCount(long setter); - public native @Cast("Nd4jLong") long _infCount(); public native DebugInfo _infCount(long setter); - public native @Cast("Nd4jLong") long _nanCount(); public native DebugInfo _nanCount(long setter); +@Namespace("sd") public static class DebugInfo extends Pointer { + static { Loader.load(); } + /** Default native constructor. */ + public DebugInfo() { super((Pointer)null); allocate(); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public DebugInfo(long size) { super((Pointer)null); allocateArray(size); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public DebugInfo(Pointer p) { super(p); } + private native void allocate(); + private native void allocateArray(long size); + @Override public DebugInfo position(long position) { + return (DebugInfo)super.position(position); } - @Namespace("sd") public static native @Cast("bool") @Name("operator ==") boolean equals(@Const @ByRef DebugInfo first, @Const @ByRef DebugInfo second); - + public native double _minValue(); public native DebugInfo _minValue(double setter); + public native double _maxValue(); public native DebugInfo _maxValue(double setter); + public native double _meanValue(); public native DebugInfo _meanValue(double setter); + public native double _stdDevValue(); public native DebugInfo _stdDevValue(double setter); + public native @Cast("Nd4jLong") long _zeroCount(); public native DebugInfo _zeroCount(long setter); + public native @Cast("Nd4jLong") long _positiveCount(); public native DebugInfo _positiveCount(long setter); + public native @Cast("Nd4jLong") long _negativeCount(); public native DebugInfo _negativeCount(long setter); + public native @Cast("Nd4jLong") long _infCount(); public native DebugInfo _infCount(long setter); + public native @Cast("Nd4jLong") long _nanCount(); public native DebugInfo _nanCount(long setter); +} +@Namespace("sd") public static native @Cast("bool") @Name("operator ==") boolean equals(@Const @ByRef DebugInfo first, @Const @ByRef DebugInfo second); + // namespace sd -// #endif //LIBND4J_DEBUGHELPER_H +// #endif // LIBND4J_DEBUGHELPER_H // Parsed from ops/declarable/headers/third_party.h @@ -24816,25 +27406,25 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #define LIBND4J_HEADERS_TPARTY_H // #include -// #if NOT_EXCLUDED(OP_firas_sparse) - @Namespace("sd::ops") public static class firas_sparse extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public firas_sparse(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public firas_sparse(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public firas_sparse position(long position) { - return (firas_sparse)super.position(position); - } - +// #if NOT_EXCLUDED(OP_firas_sparse) +@Namespace("sd::ops") public static class firas_sparse extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public firas_sparse(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public firas_sparse(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public firas_sparse position(long position) { + return (firas_sparse)super.position(position); + } + public firas_sparse() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif - - +// #endif + // namespace ops + // namespace sd // #endif diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpuPresets.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpuPresets.java index f10410314368..a2514d522ab9 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpuPresets.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpuPresets.java @@ -35,6 +35,9 @@ @Properties(inherit = openblas.class, target = "org.nd4j.nativeblas.Nd4jCpu", helper = "org.nd4j.nativeblas.Nd4jCpuHelper", value = {@Platform(define = "LIBND4J_ALL_OPS", include = { "memory/MemoryType.h", + "memory/MemoryZone.h", + "memory/MemoryDescriptor.h", + "memory/GraphMemoryManager.h", "array/DataType.h", "array/DataBuffer.h", "array/PointerDeallocator.h", @@ -67,7 +70,6 @@ "graph/FlowPath.h", "graph/Intervals.h", "graph/Stash.h", - "graph/GraphState.h", "graph/VariableSpace.h", "helpers/helper_generator.h", "graph/profiling/GraphProfile.h", @@ -80,8 +82,6 @@ "array/ShapeList.h", "system/type_boilerplate.h", "system/op_boilerplate.h", - //"enum_boilerplate.h", - //"op_enums.h", "ops/InputType.h", "ops/declarable/OpDescriptor.h", "ops/declarable/PlatformHelper.h", @@ -159,7 +159,7 @@ public void init(Logger logger, java.util.Properties properties, String encoding @Override public void map(InfoMap infoMap) { - infoMap.put(new Info("thread_local", "ND4J_EXPORT", "INLINEDEF", "CUBLASWINAPI", "FORCEINLINE", + infoMap.put(new Info("thread_local", "ND4J_EXPORT", "SD_EXPORT", "INLINEDEF", "CUBLASWINAPI", "FORCEINLINE", "_CUDA_H", "_CUDA_D", "_CUDA_G", "_CUDA_HD", "LIBND4J_ALL_OPS", "NOT_EXCLUDED").cppTypes().annotations()) .put(new Info("NativeOps.h").objectify()) .put(new Info("OpaqueTadPack").pointerTypes("OpaqueTadPack")) @@ -174,19 +174,13 @@ public void map(InfoMap infoMap) { .put(new Info("OpaqueContext").pointerTypes("OpaqueContext")) .put(new Info("OpaqueRandomGenerator").pointerTypes("OpaqueRandomGenerator")) .put(new Info("OpaqueLaunchContext").pointerTypes("OpaqueLaunchContext")) - .put(new Info("const char").valueTypes("byte").pointerTypes("@Cast(\"char*\") String", - "@Cast(\"char*\") BytePointer")) - .put(new Info("char").valueTypes("char").pointerTypes("@Cast(\"char*\") BytePointer", - "@Cast(\"char*\") String")) + .put(new Info("const char").valueTypes("byte").pointerTypes("@Cast(\"char*\") String", "@Cast(\"char*\") BytePointer")) + .put(new Info("char").valueTypes("char").pointerTypes("@Cast(\"char*\") BytePointer", "@Cast(\"char*\") String")) .put(new Info("Nd4jPointer").cast().valueTypes("Pointer").pointerTypes("PointerPointer")) - .put(new Info("Nd4jLong").cast().valueTypes("long").pointerTypes("LongPointer", "LongBuffer", - "long[]")) - .put(new Info("Nd4jStatus").cast().valueTypes("int").pointerTypes("IntPointer", "IntBuffer", - "int[]")) - .put(new Info("float16").cast().valueTypes("short").pointerTypes("ShortPointer", "ShortBuffer", - "short[]")) - .put(new Info("bfloat16").cast().valueTypes("short").pointerTypes("ShortPointer", "ShortBuffer", - "short[]")); + .put(new Info("Nd4jLong").cast().valueTypes("long").pointerTypes("LongPointer", "LongBuffer", "long[]")) + .put(new Info("Nd4jStatus").cast().valueTypes("int").pointerTypes("IntPointer", "IntBuffer", "int[]")) + .put(new Info("float16").cast().valueTypes("short").pointerTypes("ShortPointer", "ShortBuffer", "short[]")) + .put(new Info("bfloat16").cast().valueTypes("short").pointerTypes("ShortPointer", "ShortBuffer", "short[]")); infoMap.put(new Info("__CUDACC__", "MAX_UINT", "HAVE_MKLDNN", "__CUDABLAS__").define(false)) .put(new Info("__JAVACPP_HACK__", "LIBND4J_ALL_OPS").define(true)) diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TensorFlowImportTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TensorFlowImportTest.java index 45ecc80fc325..9c002ed41b7c 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TensorFlowImportTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TensorFlowImportTest.java @@ -1119,4 +1119,10 @@ public void testControlDependencies1() throws Exception { assertEquals(variables.get("cond/LinSpace/num"), Collections.singletonList("cond/switch_t")); assertEquals(variables.get("cond/ones"), Collections.singletonList("cond/switch_f")); } + + @Test + public void testWhile() throws Exception { + SameDiff sd = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/examples/while1/iter_1/frozen_model.pb").getInputStream()); + sd.save(new File("../../../libnd4j/tests_cpu/resources/while_iter1.fb"), true); + } } \ No newline at end of file diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/update/Gradients.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/update/Gradients.java index 4307efe1ed8b..cae7ab60f875 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/update/Gradients.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/update/Gradients.java @@ -1,26 +1,26 @@ -/******************************************************************************* - * Copyright (c) 2020 Konduit K.K. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ -package org.deeplearning4j.rl4j.agent.update; - -import lombok.Value; -import org.deeplearning4j.nn.gradient.Gradient; - -// Work in progress -@Value -public class Gradients { - private Gradient[] gradients; // Temporary: we'll need something better than a Gradient[] - private int batchSize; -} +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.deeplearning4j.rl4j.agent.update; + +import lombok.Value; +import org.deeplearning4j.nn.gradient.Gradient; + +// Work in progress +@Value +public class Gradients { + private Gradient[] gradients; // Temporary: we'll need something better than a Gradient[] + private int batchSize; +}